Manage temp tensor files in memory rather than sending them to storage #2819

Draft · wants to merge 6 commits into main
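The provider-side dispatch that makes this possible is not part of the diff shown here (it would live in the shared StorageProvider base class, e.g. deeplake/core/storage/provider.py). The following is only a minimal sketch of how that routing is presumably wired, inferred from the renamed *_impl methods and the new _temp_data field in the providers below; every name and behaviour in it should be read as an assumption, not as the actual implementation.

from typing import Dict, Set


class StorageProviderSketch:
    """Hypothetical stand-in for the real StorageProvider base class."""

    def __init__(self):
        # Assumption: payloads of "__temp*" tensors stay in process memory.
        self._temp_data: Dict[str, bytes] = {}

    @staticmethod
    def _is_temp(path: str) -> bool:
        # Assumption: temp keys contain a path component starting with "__temp".
        return any(part.startswith("__temp") for part in path.split("/"))

    def __setitem__(self, path: str, content: bytes) -> None:
        if self._is_temp(path):
            self._temp_data[path] = bytes(content)
        else:
            self._setitem_impl(path, content)

    def __getitem__(self, path: str) -> bytes:
        if self._is_temp(path):
            if path not in self._temp_data:
                raise KeyError(path)
            return self._temp_data[path]
        return self._getitem_impl(path)

    def __delitem__(self, path: str) -> None:
        if self._is_temp(path):
            del self._temp_data[path]
        else:
            self._delitem_impl(path)

    def __iter__(self):
        yield from self._all_keys()

    def __len__(self) -> int:
        return len(self._all_keys())

    def _all_keys(self, refresh: bool = False) -> Set[str]:
        # Merge keys reported by the backend with the in-memory temp keys.
        return set(self._all_keys_impl(refresh=refresh)) | set(self._temp_data)

    def clear(self, prefix: str = "") -> None:
        # Drop matching temp keys locally, then clear the persisted data.
        self._temp_data = {
            k: v for k, v in self._temp_data.items() if not k.startswith(prefix)
        }
        self._clear_impl(prefix=prefix)

    def _getstate_prepare(self) -> None:
        # Hook called by the providers' __getstate__ below; its real body is
        # not visible in this diff.
        pass

    # Concrete providers (Azure, GCS, Google Drive, ...) override these.
    def _setitem_impl(self, path, content):
        raise NotImplementedError

    def _getitem_impl(self, path):
        raise NotImplementedError

    def _delitem_impl(self, path):
        raise NotImplementedError

    def _all_keys_impl(self, refresh: bool = False):
        raise NotImplementedError

    def _clear_impl(self, prefix: str = ""):
        raise NotImplementedError

Under this split, each concrete provider only implements the _*_impl variants and includes _temp_data in its pickle state, which matches the shape of the Azure, GCS, and Google Drive changes below.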
13 changes: 5 additions & 8 deletions deeplake/api/tests/test_api.py
@@ -2247,7 +2247,7 @@ def test_ignore_temp_tensors(local_path):
create_shape_tensor=False,
create_id_tensor=False,
)
ds.__temptensor.append(123)
# ds.__temptensor.append(123)

with deeplake.load(local_path) as ds:
assert list(ds.tensors) == []
@@ -2270,9 +2270,8 @@ def test_ignore_temp_tensors(local_path):

with deeplake.load(local_path, read_only=True) as ds:
assert list(ds.tensors) == []
assert list(ds._tensors()) == ["__temptensor"]
assert ds.meta.hidden_tensors == ["__temptensor"]
assert ds.__temptensor[0].numpy() == 123
assert list(ds._tensors()) == []
assert ds.meta.hidden_tensors == []


@pytest.mark.slow
@@ -2550,7 +2549,7 @@ def test_invalid_ds_name():
verify_dataset_name("hub://test/data-set_123")


def test_pickle_bug(local_ds):
def test_pickle_loses_temp_tensors(local_ds):
import pickle

file = BytesIO()
@@ -2564,9 +2563,7 @@ def test_pickle_bug(local_ds):
file.seek(0)
ds = pickle.load(file)

np.testing.assert_array_equal(
ds["__temp_123"].numpy(), np.array([1, 2, 3, 4, 5]).reshape(-1, 1)
)
assert "__temp_123" not in ds


def test_max_view(memory_ds):
2 changes: 2 additions & 0 deletions deeplake/core/dataset/dataset.py
@@ -1061,6 +1061,8 @@ def _delete_tensor(self, name: str, large_ok: bool = False):
raise TensorDoesNotExistError(name)

if not tensor_exists(key, self.storage, self.version_state["commit_id"]):
if key.startswith("__temp"):
return
raise TensorDoesNotExistError(name)

if not self._is_root():
39 changes: 36 additions & 3 deletions deeplake/core/meta/dataset_meta.py
@@ -44,18 +44,51 @@ def allow_delete(self, value):
self.is_dirty = True

def __getstate__(self) -> Dict[str, Any]:
# d = super().__getstate__()
# d["tensors"] = self.tensors.copy()
# d["groups"] = self.groups.copy()
# d["tensor_names"] = self.tensor_names.copy()
# d["hidden_tensors"] = self.hidden_tensors.copy()
# d["default_index"] = self.default_index.copy()
# d["allow_delete"] = self._allow_delete
# return d

d = super().__getstate__()
d["tensors"] = self.tensors.copy()
d["tensors"] = list(
filter(lambda x: (not x.startswith("__temp")), self.tensors)
)
d["groups"] = self.groups.copy()
d["tensor_names"] = self.tensor_names.copy()
d["hidden_tensors"] = self.hidden_tensors.copy()

d["tensor_names"] = {
k: v for k, v in self.tensor_names.items() if not k.startswith("__temp")
}

d["hidden_tensors"] = list(
filter(lambda x: (not x.startswith("__temp")), self.hidden_tensors)
)
d["default_index"] = self.default_index.copy()
d["allow_delete"] = self._allow_delete
return d

def __setstate__(self, d):
if "allow_delete" in d:
d["_allow_delete"] = d.pop("allow_delete")
#
# if "hidden_tensors" in d:
# d["hidden_tensors"] = list(
# filter(lambda x: (not x.startswith("__temp")), d["hidden_tensors"])
# )
#
# if "tensors" in d:
# d["tensors"] = list(
# filter(lambda x: (not x.startswith("__temp")), d["tensors"])
# )
#
# if "tensor_names" in d:
# d["tensor_names"] = {
# k: v for k, v in d["tensor_names"].items() if not k.startswith("__temp")
# }

self.__dict__.update(d)

def add_tensor(self, name, key, hidden=False):
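A small self-contained illustration of the effect of the __getstate__ filtering above (plain containers stand in for the real DatasetMeta attributes; the tensor names are made up):

tensors = ["images", "labels", "__temptensor"]
tensor_names = {"images": "images", "labels": "labels", "__temptensor": "__temptensor"}
hidden_tensors = ["__temptensor"]

# Mirrors the filtering performed in DatasetMeta.__getstate__ above.
state = {
    "tensors": [t for t in tensors if not t.startswith("__temp")],
    "tensor_names": {
        k: v for k, v in tensor_names.items() if not k.startswith("__temp")
    },
    "hidden_tensors": [t for t in hidden_tensors if not t.startswith("__temp")],
}

assert state["tensors"] == ["images", "labels"]
assert state["hidden_tensors"] == []  # temp tensors never reach the persisted meta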
27 changes: 13 additions & 14 deletions deeplake/core/storage/azure.py
@@ -14,6 +14,8 @@

class AzureProvider(StorageProvider):
def __init__(self, root: str, creds: Dict = {}, token: Optional[str] = None):
super().__init__()

try:
import azure.identity
import azure.storage.blob
@@ -87,7 +89,7 @@ def _set_clients(self):
self.container_name
)

def __setitem__(self, path, content):
def _setitem_impl(self, path, content):
self.check_readonly()
self._check_update_creds()
if isinstance(content, memoryview):
@@ -99,10 +101,10 @@ def __setitem__(self, path, content):
)
blob_client.upload_blob(content, overwrite=True)

def __getitem__(self, path):
def _getitem_impl(self, path):
return self.get_bytes(path)

def __delitem__(self, path):
def _delitem_impl(self, path):
self.check_readonly()
blob_client = self.container_client.get_blob_client(
f"{self.root_folder}/{path}"
@@ -111,7 +113,7 @@ def __delitem__(self, path):
raise KeyError(path)
blob_client.delete_blob()

def get_bytes(
def _get_bytes_impl(
self,
path: str,
start_byte: Optional[int] = None,
@@ -144,11 +146,12 @@ def get_bytes(
byts = blob_client.download_blob(offset=offset, length=length).readall()
return byts

def clear(self, prefix=""):
def _clear_impl(self, prefix=""):
self.check_readonly()
self._check_update_creds()
blobs = [
posixpath.join(self.root_folder, key) for key in self._all_keys(prefix)
posixpath.join(self.root_folder, key)
for key in self._all_keys_impl(prefix=prefix)
]
# delete_blobs can only delete 256 blobs at a time
batches = [blobs[i : i + 256] for i in range(0, len(blobs), 256)]
@@ -176,7 +179,7 @@ def get_sas_token(self):
)
return sas_token

def _all_keys(self, prefix: str = ""):
def _all_keys_impl(self, refresh: bool = False, prefix: str = ""):
self._check_update_creds()
prefix = posixpath.join(self.root_folder, prefix)
return {
@@ -189,14 +192,9 @@ def _all_keys(self, prefix: str = ""):
) # https://github.com/Azure/azure-sdk-for-python/issues/24814
}

def __iter__(self):
yield from self._all_keys()

def __len__(self):
self._check_update_creds()
return len(self._all_keys())

def __getstate__(self):
super()._getstate_prepare()

return {
"root": self.root,
"creds": self.creds,
Expand All @@ -206,6 +204,7 @@ def __getstate__(self):
"db_engine": self.db_engine,
"repository": self.repository,
"expiration": self.expiration,
"_temp_data": self._temp_data,
}

def __setstate__(self, state):
32 changes: 15 additions & 17 deletions deeplake/core/storage/gcs.py
@@ -248,6 +248,7 @@ def __init__(
Raises:
ModuleNotFoundError: If google cloud packages aren't installed.
"""
super().__init__()

try:
import google.cloud.storage # type: ignore
@@ -323,9 +324,10 @@ def _set_bucket_and_path(self):
def _get_path_from_key(self, key):
return posixpath.join(self.path, key)

def _all_keys(self):
def _all_keys_impl(self, refresh: bool = False):
self._blob_objects = self.client_bucket.list_blobs(prefix=self.path)
return {posixpath.relpath(obj.name, self.path) for obj in self._blob_objects}
all = {posixpath.relpath(obj.name, self.path) for obj in self._blob_objects}
return [f for f in all if not f.endswith("/")]

def _set_hub_creds_info(
self,
@@ -349,7 +351,7 @@ def _set_hub_creds_info(
self.db_engine = db_engine
self.repository = repository

def clear(self, prefix=""):
def _clear_impl(self, prefix=""):
"""Remove all keys with given prefix below root - empties out mapping.

Warning:
@@ -384,11 +386,11 @@ def rename(self, root):
if not self.path.endswith("/"):
self.path += "/"

def __getitem__(self, key):
def _getitem_impl(self, key):
"""Retrieve data."""
return self.get_bytes(key)
return self._get_bytes_impl(key)

def get_bytes(
def _get_bytes_impl(
self,
path: str,
start_byte: Optional[int] = None,
@@ -418,7 +420,7 @@ def get_bytes(
except self.missing_exceptions:
raise KeyError(path)

def __setitem__(self, key, value):
def _setitem_impl(self, key: str, value: bytes):
"""Store value in key."""
self.check_readonly()
blob = self.client_bucket.blob(self._get_path_from_key(key))
@@ -428,15 +430,7 @@ def __setitem__(self, key, value):
value = bytes(value)
blob.upload_from_string(value, retry=self.retry)

def __iter__(self):
"""Iterating over the structure."""
yield from [f for f in self._all_keys() if not f.endswith("/")]

def __len__(self):
"""Returns length of the structure."""
return len(self._all_keys())

def __delitem__(self, key):
def _delitem_impl(self, key):
"""Remove key."""
self.check_readonly()
blob = self.client_bucket.blob(self._get_path_from_key(key))
@@ -445,7 +439,7 @@ def __delitem__(self, key):
except self.missing_exceptions:
raise KeyError(key)

def __contains__(self, key):
def _contains_impl(self, key):
"""Checks if key exists in mapping."""
from google.cloud import storage # type: ignore

@@ -455,6 +449,8 @@ def __contains__(self, key):
return stats

def __getstate__(self):
super()._getstate_prepare()

return (
self.root,
self.token,
@@ -463,6 +459,7 @@ def __getstate__(self):
self.read_only,
self.db_engine,
self.repository,
self._temp_data,
)

def __setstate__(self, state):
@@ -473,6 +470,7 @@ def __setstate__(self, state):
self.read_only = state[4]
self.db_engine = state[5]
self.repository = state[6]
self._temp_data = state[7]
self._initialize_provider()

def get_presigned_url(self, key, full=False):
21 changes: 10 additions & 11 deletions deeplake/core/storage/google_drive.py
@@ -106,6 +106,7 @@ def __init__(
- Due to limits on requests per 100 seconds on google drive api, continuous requests such as uploading many small files can be slow.
- Users can request to increase their quotas on their google cloud platform.
"""
super().__init__()
try:
import googleapiclient # type: ignore
from google.auth.transport.requests import Request # type: ignore
@@ -279,7 +280,7 @@ def get_object_by_id(self, id):
file.seek(0)
return file.read()

def __getitem__(self, path):
def _getitem_impl(self, path):
id = self._get_id(path)
if not id:
raise KeyError(path)
@@ -307,7 +308,7 @@ def _unlock_creation(self, path):
lock_hash = "." + hash_inputs(self.root_id, path)
os.remove(lock_hash)

def __setitem__(self, path, content):
def _setitem_impl(self, path, content):
self.check_readonly()
id = self._get_id(path)
if not id:
@@ -330,20 +331,23 @@ def __setitem__(self, path, content):
self._write_to_file(id, content)
return

def __delitem__(self, path):
def _delitem_impl(self, path):
self.check_readonly()
id = self._pop_id(path)
if not id:
raise KeyError(path)
self._delete_file(id)

def __getstate__(self):
super()._getstate_prepare()

return (
self.root,
self.root_id,
self.client_id,
self.client_secret,
self.refresh_token,
self._temp_data,
)

def __setstate__(self, state):
@@ -352,19 +356,14 @@ def __setstate__(self, state):
self.client_id = state[2]
self.client_secret = state[3]
self.refresh_token = state[4]
self._temp_data = state[5]
self._init_from_state()

def _all_keys(self):
def _all_keys_impl(self, refresh: bool = False):
keys = set(self.gid.path_id_map.keys())
return keys

def __iter__(self):
yield from self._all_keys()

def __len__(self):
return len(self._all_keys())

def clear(self, prefix=""):
def _clear_impl(self, prefix=""):
self.check_readonly()
for key in self._all_keys():
if key.startswith(prefix):