Skip to content

Commit

Permalink
Fix sub-dataset length calculation (#2533)
Browse files Browse the repository at this point in the history
* fix sub-dataset dataloader length
* bump libdeeplake version to 0.0.68
  • Loading branch information
levongh committed Aug 9, 2023
1 parent 8c160c1 commit 631b463
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 4 deletions.
9 changes: 6 additions & 3 deletions deeplake/enterprise/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,10 +217,13 @@ def collate_fn(self):
return get_collate_fn(self._collate, self._mode)

def __len__(self):
round_fn = math.floor if self._drop_last else math.ceil
return round_fn(
len(self._orig_dataset) / ((self.batch_size) * self._world_size)
len_ds = (
len(self._orig_dataset[self._tensors])
if self._tensors is not None
else len(self._orig_dataset)
)
round_fn = math.floor if self._drop_last else math.ceil
return round_fn(len_ds / ((self.batch_size) * self._world_size))

def batch(self, batch_size: int, drop_last: bool = False):
"""Returns a batched :class:`DeepLakeDataLoader` object.
Expand Down
13 changes: 13 additions & 0 deletions deeplake/enterprise/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,19 @@ def test_pytorch_transform(hub_cloud_ds):
np.testing.assert_array_equal(actual_image2, expected_image2)


@requires_libdeeplake
def test_inequal_tensors_dataloader_length(local_auth_ds):
with local_auth_ds as ds:
ds.create_tensor("images")
ds.create_tensor("label")
ds.images.extend(([i * np.ones((i + 1, i + 1)) for i in range(16)]))

ld = local_auth_ds.dataloader().batch(1).pytorch()
assert len(ld) == 0
ld1 = local_auth_ds.dataloader().batch(2).pytorch(tensors=["images"])
assert len(ld1) == 8


@requires_torch
@requires_libdeeplake
def test_pytorch_transform_dict(hub_cloud_ds):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def libdeeplake_availabe():
extras_require["all"] = [req_map[r] for r in all_extras]

if libdeeplake_availabe():
libdeeplake = "libdeeplake==0.0.66"
libdeeplake = "libdeeplake==0.0.68"
extras_require["enterprise"] = [libdeeplake, "pyjwt"]
extras_require["all"].append(libdeeplake)

Expand Down

0 comments on commit 631b463

Please sign in to comment.