Merge pull request #1016 from activeloopai/fix/2.0/pytorch_old
Fixes issues that prevented pytorch_old from running with num_workers > 0
AbhinavTuli committed Jul 8, 2021
2 parents fa8e734 + 0b79dd4 commit 75f9aff
Showing 4 changed files with 38 additions and 85 deletions.
14 changes: 4 additions & 10 deletions hub/api/dataset.py
@@ -4,17 +4,13 @@
import numpy as np

from hub.api.tensor import Tensor
from hub.constants import (
DEFAULT_MEMORY_CACHE_SIZE,
DEFAULT_LOCAL_CACHE_SIZE,
MB,
)
from hub.constants import DEFAULT_MEMORY_CACHE_SIZE, DEFAULT_LOCAL_CACHE_SIZE, MB

from hub.core.meta.dataset_meta import DatasetMeta

from hub.core.typing import StorageProvider
from hub.core.index import Index
from hub.integrations import dataset_to_pytorch, dataset_to_tensorflow
from hub.integrations import dataset_to_tensorflow
from hub.util.keys import dataset_exists, get_dataset_meta_key, tensor_exists
from hub.util.bugout_reporter import hub_reporter
from hub.util.cache_chain import generate_chain
@@ -263,7 +259,6 @@ def pytorch(
self,
transform: Optional[Callable] = None,
num_workers: int = 1,
tensors: Optional[List[str]] = None,
batch_size: Optional[int] = 1,
drop_last: Optional[bool] = False,
collate_fn: Optional[Callable] = None,
@@ -278,8 +273,6 @@
Args:
transform (Callable, optional) : Transformation function to be applied to each sample.
num_workers (int): The number of workers to use for fetching data in parallel.
tensors (List, optional): Optionally provide a list of tensor names in the ordering that your training script expects.
For example, if the dataset that has "image" and "label" tensors and `tensors=["image", "label"]`, your training script should expect each batch will be provided as a tuple of (image, label).
batch_size (int, optional): Number of samples per batch to load. Default value is 1.
drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size.
If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. Default value is False.
@@ -292,11 +285,12 @@
Returns:
A torch.utils.data.DataLoader object.
"""
from hub.integrations import dataset_to_pytorch

return dataset_to_pytorch(
self,
transform,
num_workers=num_workers,
tensors=tensors,
batch_size=batch_size,
drop_last=drop_last,
collate_fn=collate_fn,
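
For reference, the pytorch() docstring above describes the surviving arguments (the tensors argument is removed by this commit). A minimal usage sketch follows; it assumes hub.Dataset accepts a local path here, and the path and argument values are illustrative only, not part of this diff:

    import hub

    ds = hub.Dataset("./my_dataset")  # hypothetical local dataset, for illustration

    # Wrap the dataset in a torch DataLoader. Note that tensors= is no longer accepted.
    dl = ds.pytorch(
        transform=None,   # optional per-sample transform
        num_workers=2,    # parallel fetch workers
        batch_size=8,
        drop_last=False,
    )

    for batch in dl:
        pass  # each batch is built from every tensor in the dataset
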
29 changes: 8 additions & 21 deletions hub/integrations/pytorch.py
@@ -6,7 +6,6 @@
from hub.core.storage import StorageProvider, S3Provider, MemoryProvider
from hub.core.meta.tensor_meta import TensorMeta
from hub.util.remove_cache import get_base_storage
from hub.util.subscript_namedtuple import subscript_namedtuple as namedtuple
from itertools import repeat
from collections import defaultdict
from typing import Any, Callable, List, Optional, Set, Dict, Union, Tuple, Sequence
@@ -75,15 +74,16 @@ def dataset_to_pytorch(
dataset,
transform: Optional[Callable] = None,
num_workers: int = 1,
tensors: Optional[Sequence[str]] = None,
batch_size: Optional[int] = 1,
drop_last: Optional[bool] = False,
collate_fn: Optional[Callable] = None,
pin_memory: Optional[bool] = False,
):
dataset.flush()
_import_torch()
pytorch_ds = TorchDataset(dataset, transform, num_workers, tensors)
# TODO new pytorch approach doesn't support 0 workers currently
num_workers = max(num_workers, 1)
pytorch_ds = TorchDataset(dataset, transform, num_workers)
return torch.utils.data.DataLoader( # type: ignore
pytorch_ds,
batch_size=batch_size,
@@ -99,25 +99,12 @@ def __init__(
dataset,
transform: Optional[Callable] = None,
num_workers: int = 1,
tensors: Optional[Sequence[str]] = None,
):
self.transform = transform
self.num_workers: int = num_workers
self.map = ProcessPool(nodes=num_workers).map
self.length = len(dataset)
self.keys = list(dataset.tensors)

self.tensor_keys: List[str]
if tensors is not None:
for t in tensors:
if t not in dataset.tensors:
raise TensorDoesNotExistError(t)
self.tensor_keys = list(tensors)
else:
self.tensor_keys = list(dataset.tensors)

self._return_type = namedtuple("Tensors", self.tensor_keys)

self.tensor_keys = list(dataset.tensors)
self.storage = get_base_storage(dataset.storage)
if isinstance(self.storage, MemoryProvider):
raise DatasetUnsupportedPytorch(
@@ -199,7 +186,7 @@ def _load_all_chunk_engines(self):
# creating a cache around base storage to pass to ChunkEngine
return {
key: ChunkEngine(key, LRUCache(MemoryProvider(), self.storage, 16 * MB))
for key in self.keys
for key in self.tensor_keys
}

def _load_all_meta(self):
@@ -313,9 +300,9 @@ def _process_samples(self):
last_index = min(self.last_index_meta[key] for key in self.tensor_keys)
samples = []
for i in range(first_index, last_index + 1):
sample = self._return_type(
**{key: self.all_index_value_maps[key][i] for key in self.tensor_keys}
)
sample = {
key: self.all_index_value_maps[key][i] for key in self.tensor_keys
}
samples.append(sample)
self.processed_samples = samples
self.processed_range = slice(first_index, last_index)
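
With the namedtuple return type gone, _process_samples above now emits each sample as a plain dict keyed by tensor name, so any transform or collate_fn handed to dataset_to_pytorch should index samples by key (torch's default collation already handles dicts of arrays). A minimal sketch of such a transform; the tensor names "image" and "label" are illustrative, not taken from this diff:

    import numpy as np

    def dict_to_pair(sample):
        # `sample` is a plain dict, e.g. {"image": np.ndarray, "label": np.ndarray}
        return sample["image"].astype(np.float32), sample["label"]
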
41 changes: 18 additions & 23 deletions hub/integrations/pytorch_old.py
@@ -7,14 +7,13 @@
ModuleNotInstalledException,
TensorDoesNotExistError,
)
from hub.util.subscript_namedtuple import subscript_namedtuple as namedtuple
import hub


def dataset_to_pytorch(
dataset,
transform: Optional[Callable] = None,
num_workers: int = 1,
tensors: Optional[Sequence[str]] = None,
batch_size: Optional[int] = 1,
drop_last: Optional[bool] = False,
collate_fn: Optional[Callable] = None,
@@ -33,14 +32,9 @@ def dataset_to_pytorch(
pytorch_ds = TorchDataset(
dataset,
transform,
tensors,
python_version_warning=python_version_warning,
)
# TODO add pytorch for num_workers > 1
if num_workers > 0:
raise NotImplementedError(
"Multiproccessed pytorch training is not support for Python version < 3.8. Please set num_workers equal to 0 or upgrade to python 3.8."
)

return torch.utils.data.DataLoader( # type: ignore
pytorch_ds,
num_workers=num_workers,
@@ -56,7 +50,6 @@ def __init__(
self,
dataset,
transform: Optional[Callable] = None,
tensors: Optional[Sequence[str]] = None,
python_version_warning: bool = True,
):

@@ -65,36 +58,38 @@ def __init__(
"Python version<3.8 detected. Pytorch iteration speeds will be slow. Use newer Python versions for faster data streaming to Pytorch."
)

self.dataset = dataset
self.dataset = None

base_storage = get_base_storage(dataset.storage)
if isinstance(base_storage, MemoryProvider):
self.storage = get_base_storage(dataset.storage)
self.index = dataset.index
if isinstance(self.storage, MemoryProvider):
raise DatasetUnsupportedPytorch(
"Datasets whose underlying storage is MemoryProvider are not supported for Pytorch iteration."
)

self.transform = transform
self.tensor_keys: List[str]
if tensors is not None:
for t in tensors:
if t not in dataset.tensors:
raise TensorDoesNotExistError(t)
self.tensor_keys = list(tensors)
else:
self.tensor_keys = list(dataset.tensors)
self._return_type = namedtuple("Tensors", self.tensor_keys)
self.tensor_keys = list(dataset.tensors)

def _apply_transform(self, sample: Union[Dict, Tuple]):
return self.transform(sample) if self.transform else sample

def _init_ds(self):
"""
For each process, dataset should be independently loaded
"""
if self.dataset is None:
self.dataset = hub.Dataset(storage=self.storage, index=self.index)

def __len__(self):
self._init_ds()
return len(self.dataset)

def __getitem__(self, index: int):
sample = self._return_type()
self._init_ds()
sample = {}
# pytorch doesn't support certain dtypes, which are type casted to another dtype below
for key in self.tensor_keys:
item = self.dataset[key][index].numpy()
item = self.dataset[key][index].numpy() # type: ignore
if item.dtype == "uint16":
item = item.astype("int32")
elif item.dtype in ["uint32", "uint64"]:
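
The core of the fix in pytorch_old.py is visible in the hunks above: the live dataset is no longer stored on the TorchDataset at construction time; only the base storage and index are kept, and each DataLoader worker lazily rebuilds its own dataset through _init_ds, which lets the object be handed to worker processes when num_workers > 0. A condensed sketch of that pattern, assembled from the changed lines (transform handling and dtype casting omitted):

    import hub
    from hub.util.remove_cache import get_base_storage

    class TorchDataset:
        def __init__(self, dataset):
            self.dataset = None                      # rebuilt lazily, once per worker
            self.storage = get_base_storage(dataset.storage)
            self.index = dataset.index
            self.tensor_keys = list(dataset.tensors)

        def _init_ds(self):
            # Each worker process loads the dataset independently.
            if self.dataset is None:
                self.dataset = hub.Dataset(storage=self.storage, index=self.index)

        def __len__(self):
            self._init_ds()
            return len(self.dataset)

        def __getitem__(self, index):
            self._init_ds()
            return {key: self.dataset[key][index].numpy() for key in self.tensor_keys}
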
39 changes: 8 additions & 31 deletions hub/integrations/tests/test_pytorch.py
@@ -12,6 +12,10 @@
from hub.core.tests.common import parametrize_all_dataset_storages


def to_tuple(sample):
return sample["image"], sample["image2"]


@requires_torch
@parametrize_all_dataset_storages
def test_pytorch_small(ds):
@@ -26,11 +30,6 @@ def test_pytorch_small(ds):
dl = ds.pytorch(num_workers=2)
return

if sys.version_info < (3, 8):
with pytest.raises(NotImplementedError):
dl = ds.pytorch(num_workers=2)
return

dl = ds.pytorch(num_workers=2, batch_size=1)

for i, batch in enumerate(dl):
@@ -87,19 +86,11 @@ def test_pytorch_transform(ds):
ds.create_tensor("image2")
ds.image2.extend(np.array([i * np.ones((100, 100)) for i in range(256)]))

def to_tuple(sample):
return sample["image"], sample["image2"]

if isinstance(get_base_storage(ds.storage), MemoryProvider):
with pytest.raises(DatasetUnsupportedPytorch):
dl = ds.pytorch(num_workers=2)
return

if sys.version_info < (3, 8):
with pytest.raises(NotImplementedError):
dl = ds.pytorch(num_workers=2)
return

dl = ds.pytorch(num_workers=2, transform=to_tuple, batch_size=1)

for i, batch in enumerate(dl):
@@ -127,11 +118,6 @@ def test_pytorch_with_compression(ds: Dataset):
dl = ds.pytorch(num_workers=2)
return

if sys.version_info < (3, 8):
with pytest.raises(NotImplementedError):
dl = ds.pytorch(num_workers=2)
return

dl = ds.pytorch(num_workers=2, batch_size=1)

for batch in dl:
@@ -153,13 +139,13 @@ def test_pytorch_small_old(ds):
if isinstance(get_base_storage(ds.storage), MemoryProvider):
with pytest.raises(DatasetUnsupportedPytorch):
dl = dataset_to_pytorch(
ds, num_workers=0, batch_size=1, python_version_warning=False
ds, num_workers=2, batch_size=1, python_version_warning=False
)
return

# .pytorch will automatically switch depending on version, this syntax is being used to ensure testing of old code on Python 3.8
dl = dataset_to_pytorch(
ds, num_workers=0, batch_size=1, python_version_warning=False
ds, num_workers=2, batch_size=1, python_version_warning=False
)

for i, batch in enumerate(dl):
@@ -173,11 +159,7 @@ def test_pytorch_small_old(ds):

@requires_torch
@parametrize_all_dataset_storages
@pytest.mark.xfail(
sys.version_info < (3, 8),
raises=NotImplementedError,
reason="requires python3.8 or higher",
)
@pytest.mark.skip(reason="future")
def test_custom_tensor_order(ds):
with ds:
tensors = ["a", "b", "c", "d"]
@@ -187,17 +169,12 @@ def test_custom_tensor_order(ds):

if isinstance(get_base_storage(ds.storage), MemoryProvider):
with pytest.raises(DatasetUnsupportedPytorch):
ptds = ds.pytorch(num_workers=2)
return

if sys.version_info < (3, 8):
with pytest.raises(NotImplementedError):
dl = ds.pytorch(num_workers=2)
return

dl_new = ds.pytorch(num_workers=2, tensors=["c", "d", "a"])
dl_old = dataset_to_pytorch(
ds, num_workers=0, tensors=["c", "d", "a"], python_version_warning=False
ds, num_workers=2, tensors=["c", "d", "a"], python_version_warning=False
)
for dl in [dl_new, dl_old]:
for i, batch in enumerate(dl):
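
Taken together, the updated tests exercise the old code path with multiple workers. A minimal sketch of the same usage, assuming ds is a dataset that has an "image" tensor and that dataset_to_pytorch is imported from hub.integrations.pytorch_old as in the tests:

    from hub.integrations.pytorch_old import dataset_to_pytorch

    # The old path now accepts num_workers > 0 instead of raising NotImplementedError.
    dl = dataset_to_pytorch(ds, num_workers=2, batch_size=1, python_version_warning=False)

    for batch in dl:
        images = batch["image"]  # samples come back as dicts keyed by tensor name
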
