Skip to content

Commit

Permalink
Merge pull request #2537 from activeloopai/add_seed_to_random_split
Browse files Browse the repository at this point in the history
Allow user-control of random seeds
  • Loading branch information
levongh committed Aug 24, 2023
2 parents 8f008fc + 1122a87 commit dadfe2c
Show file tree
Hide file tree
Showing 13 changed files with 133 additions and 12 deletions.
3 changes: 3 additions & 0 deletions deeplake/__init__.py
Expand Up @@ -24,6 +24,7 @@
from .core.dataset import Dataset
from .core.transform import compute, compose
from .core.tensor import Tensor
from .core.seed import DeeplakeRandom
from .util.bugout_reporter import deeplake_reporter
from .compression import SUPPORTED_COMPRESSIONS
from .htype import HTYPE_CONFIGURATIONS
Expand All @@ -50,6 +51,7 @@
ingest_huggingface = huggingface.ingest_huggingface
dataset = api_dataset.init # type: ignore
tensor = Tensor
random = DeeplakeRandom()

__all__ = [
"tensor",
Expand All @@ -76,6 +78,7 @@
"delete",
"copy",
"rename",
"random",
]


Expand Down
1 change: 1 addition & 0 deletions deeplake/core/dataset/deeplake_query_dataset.py
Expand Up @@ -331,5 +331,6 @@ def __del__(self):
def random_split(self, lengths: Sequence[Union[int, float]]):
    """Split this query dataset into non-overlapping subsets of the given lengths.

    ``lengths`` may be absolute sample counts, or fractions that sum to ~1.0
    (they are converted to absolute counts first). Returns a list of
    ``DeepLakeQueryDataset`` views, one per requested length.
    """
    # Fractions summing (within float tolerance) to 1.0 are converted to
    # absolute sample counts before delegating to the indra backend.
    if math.isclose(sum(lengths), 1) and sum(lengths) <= 1:
        lengths = calculate_absolute_lengths(lengths, len(self))

    vs = self.indra_ds.random_split(lengths)
    return [DeepLakeQueryDataset(self.deeplake_ds, v) for v in vs]
25 changes: 25 additions & 0 deletions deeplake/core/seed.py
@@ -0,0 +1,25 @@
import numpy as np
from typing import Optional

class DeeplakeRandom(object):
    """Singleton holding the user-controlled global Deep Lake random seed.

    The seed is stored on the single instance and, when the optional
    ``indra`` backend is available, forwarded to it via ``set_seed`` so that
    enterprise operations (e.g. ``random_split``) are reproducible.
    """

    def __new__(cls):
        # Classic lazy singleton: the first call creates and initializes the
        # one instance; every later call returns it unchanged.
        if not hasattr(cls, "instance"):
            cls.instance = super(DeeplakeRandom, cls).__new__(cls)
            cls.instance.internal_seed = None
            cls.instance.indra_api = None
        return cls.instance

    def seed(self, seed: Optional[int] = None):
        """Set (or clear, with ``None``) the global Deep Lake seed.

        Args:
            seed: the new seed, or ``None`` to fall back to entropy-based seeding.

        Raises:
            TypeError: if ``seed`` is neither ``None`` nor an ``int``.
        """
        # NOTE: the previous check `isinstance(seed, Optional[int])` raises
        # TypeError("Subscripted generics cannot be used with ... instance
        # checks") on Python < 3.10, so both accepted cases are tested
        # explicitly instead.
        if seed is None or isinstance(seed, int):
            self.internal_seed = seed
            if self.indra_api is None:
                from deeplake.enterprise.convert_to_libdeeplake import (
                    import_indra_api_silent,
                )

                self.indra_api = import_indra_api_silent()
            # `import_indra_api_silent` may return the import-time Exception
            # instead of the api module; calling `set_seed` on it would raise
            # AttributeError, so skip forwarding in that case.
            if self.indra_api is not None and not isinstance(
                self.indra_api, Exception
            ):
                self.indra_api.set_seed(self.internal_seed)
        else:
            raise TypeError(
                f"provided seed type `{type(seed)}` is incorrect, seed must be an integer"
            )

    def get_seed(self) -> Optional[int]:
        """Return the currently configured seed (``None`` if unset)."""
        return self.internal_seed
36 changes: 36 additions & 0 deletions deeplake/core/tests/test_deeplake_indra_dataset.py
Expand Up @@ -293,6 +293,41 @@ def test_query_tensors_polygon_htype_consistency(local_auth_ds_generator):
assert np.all(i == j)


@requires_libdeeplake
def test_random_split_with_seed(local_auth_ds_generator):
    """Same seed must reproduce the split; a different seed must change it."""
    deeplake_ds = local_auth_ds_generator()
    from deeplake.core.seed import DeeplakeRandom

    with deeplake_ds:
        deeplake_ds.create_tensor("label", htype="generic", dtype=np.int32)
        for value in range(1000):
            deeplake_ds.label.append(int(value % 100))

    deeplake_indra_ds = deeplake_ds.query("SELECT * GROUP BY label")

    # Remember the global numpy RNG state so the test leaves no side effects.
    initial_state = np.random.get_state()

    def split_with_seed(seed_value):
        # Seed, split 20/20/60, and sanity-check the shape of the result.
        DeeplakeRandom().seed(seed_value)
        parts = deeplake_indra_ds.random_split([0.2, 0.2, 0.6])
        assert len(parts) == 3
        assert len(parts[0]) == 20
        return parts

    split1 = split_with_seed(100)
    split2 = split_with_seed(101)
    split3 = split_with_seed(100)

    for part1, part2, part3 in zip(split1, split2, split3):
        assert np.all(part1.label.numpy() == part3.label.numpy())
        assert not np.all(part1.label.numpy() == part2.label.numpy())

    np.random.set_state(initial_state)


@requires_libdeeplake
def test_random_split(local_auth_ds_generator):
deeplake_ds = local_auth_ds_generator()
Expand All @@ -306,6 +341,7 @@ def test_random_split(local_auth_ds_generator):
split = deeplake_indra_ds.random_split([0.2, 0.2, 0.6])
assert len(split) == 3
assert len(split[0]) == 20

l = split[0].dataloader().pytorch()
for b in l:
pass
Expand Down
17 changes: 14 additions & 3 deletions deeplake/enterprise/convert_to_libdeeplake.py
Expand Up @@ -13,19 +13,30 @@
INDRA_API = None


def import_indra_api():
def import_indra_api_silent():
    """Attempt to load the `indra` api without ever raising.

    Returns the (cached) api module on success, ``None`` when the `indra`
    package is not installed, or the caught Exception when the import itself
    failed — callers decide how to surface the failure.
    """
    global INDRA_API
    # Serve the module cached by a previous successful call.
    if INDRA_API:
        return INDRA_API
    if importlib.util.find_spec("indra") is None:
        return None
    try:
        from indra import api  # type: ignore
    except Exception as exc:
        # Hand the failure back instead of raising here.
        return exc
    INDRA_API = api
    return api


def import_indra_api():
    """Load the `indra` api, raising a helpful installation error on failure."""
    result = import_indra_api_silent()
    if isinstance(result, Exception):
        # Import was attempted but failed: surface the original cause.
        raise_indra_installation_error(result)
    if result is None:
        # Package is not installed at all.
        raise_indra_installation_error()  # type: ignore
    return result


INDRA_INSTALLED = bool(importlib.util.find_spec("indra"))
Expand Down
24 changes: 22 additions & 2 deletions deeplake/enterprise/dataloader.py
Expand Up @@ -136,6 +136,9 @@ def __init__(
self.__initialized = True
self._IterableDataset_len_called = None
self._iterator = None
self._worker_init_fn = None

self._internal_iterator = None

@property
def batch_size(self):
Expand Down Expand Up @@ -167,7 +170,13 @@ def timeout(self):

@property
def worker_init_fn(self):
    """Callable run in each dataloader worker at startup, or ``None``."""
    # NOTE: removed a stale `return None` left over from the pre-change
    # version of this getter — it made every line below it unreachable and
    # the property always report None.
    return self._worker_init_fn

@worker_init_fn.setter
def worker_init_fn(self, fn):
    self._worker_init_fn = fn
    # Keep an already-constructed underlying dataloader in sync so setting
    # the hook after iteration has begun still takes effect.
    if self._dataloader is not None:
        self._dataloader.worker_init_fn = fn

@property # type: ignore
def multiprocessing_context(self):
Expand Down Expand Up @@ -696,9 +705,20 @@ def __iter__(self):
htype_dict=htype_dict,
ndim_dict=ndim_dict,
tensor_info_dict=tensor_info_dict,
worker_init_fn=self.worker_init_fn,
)
dataset_read(self._orig_dataset)
return iter(self._dataloader)

if self._internal_iterator is not None:
self._internal_iterator = iter(self._internal_iterator)
return self

def __next__(self):
    # Allows consuming the loader with bare `next(...)` calls: lazily build
    # the underlying dataloader via __iter__ on first use, then keep a
    # persistent iterator over it between calls.
    if self._dataloader is None:
        self.__iter__()
    if self._internal_iterator is None:
        self._internal_iterator = iter(self._dataloader)
    return next(self._internal_iterator)


def dataloader(dataset, ignore_errors: bool = False) -> DeepLakeDataLoader:
Expand Down
14 changes: 14 additions & 0 deletions deeplake/enterprise/test_pytorch.py
Expand Up @@ -2,6 +2,7 @@
import deeplake
import numpy as np
import pytest
from functools import partial
from deeplake.util.exceptions import EmptyTensorError, TensorDoesNotExistError

from deeplake.util.remove_cache import get_base_storage
Expand Down Expand Up @@ -62,6 +63,19 @@ def index_transform(sample):
return sample["index"], sample["xyz"]


def dummy_init_fn(arg):
    """Trivial worker-init stand-in that reports the argument it received."""
    message = "function called with arg {}".format(arg)
    return message


@requires_libdeeplake
def test_setting_woker_init_function(local_auth_ds):
    """worker_init_fn must default to None and round-trip through the setter."""
    # NOTE(review): "woker" typo in the test name kept so test selection by
    # name stays stable; rename in a dedicated change if desired.
    dl = local_auth_ds.dataloader().pytorch()

    # `is None`, not `== None` — identity check is the correct idiom.
    assert dl.worker_init_fn is None
    dl.worker_init_fn = partial(dummy_init_fn, 1024)
    assert dl.worker_init_fn() == "function called with arg 1024"


@requires_torch
@requires_libdeeplake
@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion deeplake/integrations/pytorch/common.py
Expand Up @@ -142,7 +142,7 @@ def validate_decode_method(
for tensor_name, decode_method in decode_method.items():
if tensor_name not in all_tensor_keys:
raise ValueError(
"tensor {tensor_name} specified in decode_method not found in tensors."
f"tensor {tensor_name} specified in decode_method not found in tensors."
)
if tensor_name in jpeg_png_compressed_tensors_set:
if decode_method not in jpeg_png_supported_decode_methods:
Expand Down
9 changes: 5 additions & 4 deletions deeplake/integrations/pytorch/shuffle_buffer.py
@@ -1,5 +1,5 @@
from typing import List, Any, Sequence
from random import randrange
from random import Random
from functools import reduce
from operator import mul
import numpy as np
Expand Down Expand Up @@ -30,7 +30,8 @@ class ShuffleBuffer:
def __init__(self, size: int) -> None:
if size <= 0:
raise ValueError("Buffer size should be positive value more than zero")

from deeplake.core.seed import DeeplakeRandom
self.random = Random(DeeplakeRandom().get_seed())
self.size = size
self.buffer: List[Any] = list()
self.buffer_used = 0
Expand Down Expand Up @@ -83,7 +84,7 @@ def exchange(self, sample):
return sample

# exchange samples with shuffle buffer
selected = randrange(buffer_len)
selected = self.random.randrange(buffer_len)
val = self.buffer[selected]
self.buffer[selected] = sample

Expand All @@ -97,7 +98,7 @@ def exchange(self, sample):
self.close_buffer_pbar()
if buffer_len > 0:
# return random selection
selected = randrange(buffer_len)
selected = self.random.randrange(buffer_len)
val = self.buffer.pop(selected)
self.buffer_used -= self._sample_size(val)

Expand Down
4 changes: 4 additions & 0 deletions deeplake/util/scheduling.py
Expand Up @@ -4,6 +4,7 @@
import numpy as np
from collections import defaultdict
from deeplake.core.meta.encode.chunk_id import ChunkIdEncoder
from deeplake.core.seed import DeeplakeRandom


def find_primary_tensor(dataset):
Expand Down Expand Up @@ -37,6 +38,8 @@ def create_fetching_schedule(dataset, primary_tensor_name, shuffle_within_chunks
enc_array = chunk_id_encoder.array
num_chunks = chunk_id_encoder.num_chunks
# pick chunks randomly, one by one
prev_state = np.random.get_state()
np.random.seed(DeeplakeRandom().get_seed())
chunk_order = np.random.choice(num_chunks, num_chunks, replace=False)
schedule = []
for chunk_idx in chunk_order:
Expand All @@ -52,6 +55,7 @@ def create_fetching_schedule(dataset, primary_tensor_name, shuffle_within_chunks
elif isinstance(index_struct, dict):
idxs = filter(lambda idx: idx in index_struct, schedule)
schedule = [int(idx) for idx in idxs for _ in range(index_struct[idx])]
np.random.set_state(prev_state)
return schedule


Expand Down
5 changes: 5 additions & 0 deletions deeplake/util/shuffle.py
@@ -1,8 +1,13 @@
import numpy as np
from deeplake.core.seed import DeeplakeRandom



def shuffle(ds):
    """Return a shuffled wrapper of the given Dataset.

    Uses the Deep Lake global seed (``DeeplakeRandom``) so the shuffle is
    reproducible when a seed has been set.
    """
    # A dedicated legacy RandomState seeded with the Deep Lake seed yields
    # exactly the permutation that `np.random.seed(s); np.random.shuffle`
    # produced (same MT19937 stream), but never mutates numpy's global RNG,
    # so the previous save/restore of the global state is unnecessary and
    # concurrent users of np.random are unaffected.
    rng = np.random.RandomState(DeeplakeRandom().get_seed())
    idxs = np.arange(len(ds))
    rng.shuffle(idxs)
    return ds[idxs.tolist()]
3 changes: 2 additions & 1 deletion deeplake/util/tests/test_shuffle.py
@@ -1,4 +1,5 @@
import numpy as np
import deeplake
from deeplake.util import shuffle


Expand All @@ -7,7 +8,7 @@ def test_shuffle(memory_ds):
ds.create_tensor("ints", dtype="int64")
ds.ints.extend(np.arange(10, dtype="int64").reshape((10, 1)))

np.random.seed(0)
deeplake.random.seed(0)
ds = shuffle(ds)
expected = [[2], [8], [4], [9], [1], [6], [7], [3], [0], [5]]
assert ds.ints.numpy().tolist() == expected
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -70,7 +70,7 @@ def libdeeplake_availabe():
extras_require["all"] = [req_map[r] for r in all_extras]

if libdeeplake_availabe():
libdeeplake = "libdeeplake==0.0.68"
libdeeplake = "libdeeplake==0.0.70"
extras_require["enterprise"] = [libdeeplake, "pyjwt"]
extras_require["all"].append(libdeeplake)

Expand Down

0 comments on commit dadfe2c

Please sign in to comment.