Fix disable shuffling. #5287

Open. Wants to merge 1 commit into base: master.
90 changes: 51 additions & 39 deletions tensorflow_datasets/core/shuffle.py
@@ -15,14 +15,15 @@

"""To shuffle records (stable)."""

from collections.abc import Iterator, Sequence
import math
import os
import struct
from typing import Iterator, List, Optional
from typing import List, Optional
import uuid

from absl import logging
import six
from etils import epath
from tensorflow_datasets.core import hashing
from tensorflow_datasets.core.utils import file_utils
from tensorflow_datasets.core.utils import type_utils
@@ -56,14 +57,14 @@ def __init__(self, item1, item2):
self.item2 = item2


def _hkey_to_bytes(hkey):
def _hkey_to_bytes(hkey: int) -> bytes:
"""Converts 128 bits integer hkey to binary representation."""
max_int64 = 0xFFFFFFFFFFFFFFFF
return struct.pack('=QQ', (hkey >> 64) & max_int64, hkey & max_int64)


def _read_hkey(buff):
"""Reads from fobj and returns hkey (128 bites integer)."""
def _read_hkey(buff: bytes) -> int:
"""Reads from fobj and returns hkey (128 bits integer)."""
a, b = struct.unpack('=QQ', buff)
return (a << 64) | b
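
To sanity check the packing scheme above, here is a hedged standalone round trip (the functions are copied from the diff with the leading underscores dropped; '=' in the struct format means native byte order with standard sizes):

import struct

def hkey_to_bytes(hkey: int) -> bytes:
    """Packs a 128-bit key as two unsigned 64-bit words."""
    max_int64 = 0xFFFFFFFFFFFFFFFF
    return struct.pack('=QQ', (hkey >> 64) & max_int64, hkey & max_int64)

def read_hkey(buff: bytes) -> int:
    """Inverse of hkey_to_bytes: reassembles the two 64-bit halves."""
    a, b = struct.unpack('=QQ', buff)
    return (a << 64) | b

key = (123 << 64) | 456
assert read_hkey(hkey_to_bytes(key)) == key  # lossless for any 128-bit key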

@@ -98,7 +99,7 @@ def _increase_open_files_limit():


def get_bucket_number(
hkey,
hkey: int,
num_buckets: int,
max_hkey: Optional[int] = None,
) -> int:
@@ -110,7 +111,7 @@
return min(math.trunc((hkey * num_buckets) / max_hkey), num_buckets - 1)
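
To make the scaling concrete, here is a minimal standalone sketch with two boundary cases (num_buckets=250 and the 128-bit max_hkey default are assumptions consistent with the key format above, not necessarily the module's constants):

import math

MAX_HKEY = 2**128 - 1  # hashed keys are 128-bit integers

def get_bucket_number(hkey: int, num_buckets: int, max_hkey: int = MAX_HKEY) -> int:
    # Linearly scale the key range onto [0, num_buckets); the min() clamp
    # keeps the largest key from landing in a nonexistent bucket.
    return min(math.trunc((hkey * num_buckets) / max_hkey), num_buckets - 1)

print(get_bucket_number(0, 250))         # 0: the smallest key maps to the first bucket
print(get_bucket_number(MAX_HKEY, 250))  # 249: the largest key is clamped into the last bucket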


class _Bucket(object):
class _Bucket:
"""Holds (key, binary value) tuples to disk, fast.

Bucket instances are designed to be used either:
@@ -129,43 +130,43 @@ class _Bucket(object):
...
"""

def __init__(self, path):
def __init__(self, path: epath.Path):
"""Initialize a _Bucket instance.

Args:
path (str): path to bucket file, where to write to or read from.
path: Path to bucket file, where to write to or read from.
"""
self._path = path
self._fobj = None
self._length = 0
self._size = 0

@property
def size(self):
def size(self) -> int:
return self._size

def __len__(self):
def __len__(self) -> int:
return self._length

def add(self, key, data):
"""Adds (key, data) to bucket.
def add(self, hkey: type_utils.Key, data: bytes):
"""Adds (hkey, data) to bucket.

Args:
key (int): the key.
data (binary): the data.
hkey: The hashed key.
data: The data.
"""
if not self._fobj:
file_utils.makedirs_cached(os.path.dirname(self._path))
self._fobj = tf.io.gfile.GFile(self._path, mode='wb')
data_size = len(data)

try:
self._fobj.write(_hkey_to_bytes(key))
self._fobj.write(_hkey_to_bytes(hkey))
except tf.errors.ResourceExhaustedError as error:
# catch "Too many open files"
if error.message.endswith('Too many open files'):
_increase_open_files_limit()
self._fobj.write(_hkey_to_bytes(key))
self._fobj.write(_hkey_to_bytes(hkey))
else:
raise error
# http://docs.python.org/3/library/struct.html#byte-order-size-and-alignment
@@ -185,7 +186,7 @@ def flush(self):
self._fobj.flush()
self._fobj.close()

def read_values(self):
def read_values(self) -> Iterator[type_utils.KeySerializedExample]:
"""Yields (hkey, data) tuples stored in bucket."""
self.flush()
path = self._path
@@ -210,61 +211,73 @@ def del_file(self):
tf.io.gfile.remove(self._path)


class Shuffler(object):
class Shuffler:
"""Stores data in temp buckets, restitute it shuffled."""

def __init__(self, dirpath, hash_salt, disable_shuffling: bool = False):
def __init__(
self,
dirpath: epath.PathLike,
hash_salt: str | bytes,
disable_shuffling: bool = False,
):
"""Initialize Shuffler.

Args:
dirpath (string): directory in which to store temporary files.
hash_salt (string or bytes): salt to hash keys.
disable_shuffling (bool): specify whether to shuffle by hashing the key.
dirpath: Directory in which to store temporary files.
hash_salt: Salt to hash keys.
disable_shuffling: Specify whether to shuffle by hashing the key.
"""
grp_name = uuid.uuid4()
self._hasher = hashing.Hasher(hash_salt)
self._disable_shuffling = disable_shuffling
self._buckets: List[_Bucket] = []
for i in range(BUCKETS_NUMBER):
bucket_name = 'bucket_%s_%03d.tmp' % (grp_name, i)
path = os.path.join(dirpath, bucket_name)
path = epath.Path(dirpath) / bucket_name
self._buckets.append(_Bucket(path))
self._read_only = False
self._total_bytes = 0
# To keep data in memory until enough data has been gathered.
self._in_memory = True
self._mem_buffer = []
self._mem_buffer: List[type_utils.KeySerializedExample] = []

@property
def size(self):
def size(self) -> int:
"""Return total size in bytes of records (not keys)."""
return self._total_bytes

@property
def bucket_lengths(self):
def bucket_lengths(self) -> Sequence[int]:
if self._in_memory:
return [len(self._mem_buffer)]
return [len(b) for b in self._buckets]

def _add_to_bucket(self, hkey, data):
def _add_to_bucket(self, hkey: type_utils.Key, data: bytes):
# TODO(tfds): Support arbitrary keys.
# https://github.com/tensorflow/datasets/issues/5002
if not isinstance(hkey, int):
raise AssertionError(
f'Only int (not {type(hkey)}) can be used as key in Shuffler when'
' adding to bucket!'
)
bucket_number = get_bucket_number(hkey=hkey, num_buckets=BUCKETS_NUMBER)
self._buckets[bucket_number].add(hkey, data)

def _add_to_mem_buffer(self, hkey, data):
def _add_to_mem_buffer(self, hkey: type_utils.Key, data: bytes):
self._mem_buffer.append((hkey, data))
if self._total_bytes > MAX_MEM_BUFFER_SIZE:
for hkey, data in self._mem_buffer:
self._add_to_bucket(hkey, data)
self._mem_buffer = None
self._in_memory = False

def add(self, key, data):
def add(self, key: type_utils.Key, data: bytes):
"""Add (key, data) to shuffler."""
if self._read_only:
raise AssertionError('add() cannot be called after __iter__.')
if not isinstance(data, six.binary_type):
if not isinstance(data, bytes):
raise AssertionError(
'Only bytes (not %s) can be stored in Shuffler!' % (type(data))
f'Only bytes (not {type(data)}) can be stored in Shuffler!'
)
if self._disable_shuffling:
hkey = key
@@ -281,20 +294,19 @@ def __iter__(self) -> Iterator[type_utils.KeySerializedExample]:
previous_hkey = None
previous_data = None
iterator = self._iter_mem() if self._in_memory else self._iter_buckets()
if not self._disable_shuffling:
iterator = sorted(iterator)
for hkey, data in iterator:
if hkey == previous_hkey:
raise DuplicatedKeysError(data, previous_data)
previous_hkey = hkey
yield hkey, data
previous_data = data

def _iter_mem(self):
for hkey, data in sorted(self._mem_buffer):
yield hkey, data
def _iter_mem(self) -> Iterator[type_utils.KeySerializedExample]:
yield from self._mem_buffer

def _iter_buckets(self):
def _iter_buckets(self) -> Iterator[type_utils.KeySerializedExample]:
for bucket in self._buckets:
bucket_data = sorted(bucket.read_values())
yield from bucket.read_values()
bucket.del_file()
for hkey, data in bucket_data:
yield hkey, data
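
The net effect of the change is easiest to see in a small usage sketch (hedged: it exercises only the in-memory path, and the salt, keys, and payloads are made up):

import tempfile

from tensorflow_datasets.core import shuffle

with tempfile.TemporaryDirectory() as tmp_dir:
  shuffler = shuffle.Shuffler(
      dirpath=tmp_dir, hash_salt='salt', disable_shuffling=True)
  shuffler.add(2, b'second')
  shuffler.add(1, b'first')
  # With disable_shuffling=True the raw keys are kept and __iter__ skips
  # sorting, so examples come back in insertion order.
  assert list(shuffler) == [(2, b'second'), (1, b'first')]

With disable_shuffling=False the same calls would hash each key with the salt and sort before yielding, producing a deterministic shuffle instead.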
21 changes: 21 additions & 0 deletions tensorflow_datasets/core/shuffle_test.py
@@ -249,6 +249,27 @@ def test_duplicate_key(self):
with self.assertRaises(shuffle.DuplicatedKeysError):
next(iterator)

def test_disable_shuffling(self):
self._test_items(
'split1',
_ITEMS,
[value for _, value in _ITEMS],
disable_shuffling=True,
)
with mock.patch.object(
shuffle, 'MAX_MEM_BUFFER_SIZE', 0
), self.assertRaisesWithLiteralMatch(
AssertionError,
"Only int (not <class 'str'>) can be used as key in Shuffler when"
' adding to bucket!',
):
self._test_items(
'split1',
_ITEMS,
[value for _, value in _ITEMS],
disable_shuffling=True,
)


if __name__ == '__main__':
testing.test_main()
11 changes: 6 additions & 5 deletions tensorflow_datasets/core/writer.py
@@ -15,12 +15,13 @@

"""To write records into sharded records files."""

from collections.abc import Iterable, Sequence
import dataclasses
import functools
import itertools
import json
import os
from typing import Any, Iterable, List, Optional, Tuple
from typing import Any, List, Optional, Tuple

from absl import logging
from etils import epath
@@ -97,7 +98,7 @@ def _get_index_path(path: str) -> epath.PathLike:
def _get_shard_specs(
num_examples: int,
total_size: int,
bucket_lengths: List[int],
bucket_lengths: Sequence[int],
filename_template: naming.ShardedFileTemplate,
shard_config: shard_utils.ShardConfig,
) -> List[_ShardSpec]:
@@ -225,16 +226,16 @@ def __init__(
self._file_format = file_format
self._shard_config = shard_config or shard_utils.ShardConfig()

def write(self, key, example):
def write(self, key: type_utils.Key, example):
"""Writes given Example.

The given example is not directly written to the tfrecord file, but to a
temporary file (or memory). The finalize() method does write the tfrecord
files.

Args:
key (int|bytes): the key associated with the example. Used for shuffling.
example: the Example to write to the tfrecord file.
key: The key associated with the example. Used for shuffling.
example: The Example to write to the tfrecord file.
"""
serialized_example = self._serializer.serialize_example(example=example)
self._shuffler.add(key, serialized_example)
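
To illustrate the serialize-then-buffer flow the docstring describes, here is a hedged sketch; FakeSerializer is a stand-in invented for this example, not the library's serializer:

class FakeSerializer:
  """Illustrative stand-in for the serializer a Writer is built with."""

  def serialize_example(self, example) -> bytes:
    return repr(example).encode('utf-8')

def write(shuffler, serializer, key: int, example) -> None:
  # Mirrors Writer.write: serialize the example first, then hand the
  # (key, bytes) pair to the shuffler, which buffers it until the final
  # shard files are written.
  serialized_example = serializer.serialize_example(example=example)
  shuffler.add(key, serialized_example)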