Merge pull request #2728 from activeloopai/fy_faster_import
Faster import
FayazRahman committed Jan 4, 2024
2 parents 9deea3d + f78e263 commit 8582788
Showing 17 changed files with 141 additions and 132 deletions.
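Every file in this diff applies the same pattern: heavy optional imports (av, nibabel, pandas, and the Azure and Google Cloud SDKs) move out of module scope, where they ran on every `import deeplake`, and into the functions that use them. A condensed before/after sketch of the pattern, modeled on the compression.py hunks below (illustrative, not the exact diff text):

# Before: the optional dependency is imported when deeplake is imported,
# and a module-level flag records whether it is available.
try:
    import av  # type: ignore
    _PYAV_INSTALLED = True
except ImportError:
    _PYAV_INSTALLED = False

# After: the import runs on first use; later calls hit the sys.modules cache.
def _open_video(file):
    try:
        import av  # type: ignore
    except ImportError:
        raise ModuleNotFoundError(
            "PyAV is not installed. Run `pip install deeplake[video]`."
        )
    ...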
9 changes: 0 additions & 9 deletions deeplake/api/tests/test_api.py
@@ -1968,15 +1968,6 @@ def test_hub_exists(ds_generator, path, hub_token, convert_to_pathlib):
assert deeplake.exists(f"{path}_does_not_exist", token=hub_token) == False


def test_pyav_not_installed(local_ds, video_paths):
pyav_installed = deeplake.core.compression._PYAV_INSTALLED
deeplake.core.compression._PYAV_INSTALLED = False
local_ds.create_tensor("videos", htype="video", sample_compression="mp4")
with pytest.raises(SampleAppendError):
local_ds.videos.append(deeplake.read(video_paths["mp4"][0]))
deeplake.core.compression._PYAV_INSTALLED = pyav_installed


@pytest.mark.slow
def test_partial_read_then_write(s3_ds_generator):
ds = s3_ds_generator()
3 changes: 1 addition & 2 deletions deeplake/auto/structured/dataframe.py
Expand Up @@ -12,8 +12,6 @@
from deeplake.core.linked_sample import LinkedSample
import pathlib

import pandas as pd # type: ignore


from deeplake.client.log import logger

@@ -127,6 +125,7 @@ def _parse_tensor_params(self, key: str, inspect_limit: int = 1000):

def _get_extend_values(self, tensor_params: dict, key: str): # type: ignore
"""Method creates a list of values to be extended to the tensor, based on the tensor parameters and the data in the dataframe column"""
import pandas as pd # type: ignore

column_data = self.source[key]
column_data = column_data.where(pd.notnull(column_data), None).values.tolist()
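A function-local import like the one above is nearly free after the first call: Python caches imported modules in sys.modules, so re-executing `import pandas as pd` amounts to a dictionary lookup. A quick illustration of the caching behavior (standalone snippet, not part of this PR):

import sys
import time

t0 = time.perf_counter()
import pandas as pd  # first import does the real work
first = time.perf_counter() - t0

t0 = time.perf_counter()
import pandas as pd  # already cached in sys.modules
cached = time.perf_counter() - t0

print(f"first: {first:.3f}s  cached: {cached:.6f}s")
assert "pandas" in sys.modules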
41 changes: 22 additions & 19 deletions deeplake/core/compression.py
@@ -36,22 +36,6 @@
from pathlib import Path
from gzip import GzipFile

try:
import av # type: ignore

_PYAV_INSTALLED = True
except ImportError:
_PYAV_INSTALLED = False


try:
import nibabel as nib # type: ignore
from nibabel import FileHolder, Nifti1Image, Nifti2Image # type: ignore

_NIBABEL_INSTALLED = True
except ImportError:
_NIBABEL_INSTALLED = False

if sys.byteorder == "little":
_NATIVE_INT32 = "<i4"
_NATIVE_FLOAT32 = "<f4"
@@ -864,10 +848,13 @@ def _frame_to_stamp(nframe, stream):


def _open_video(file: Union[str, bytes, memoryview]):
if not _PYAV_INSTALLED:
try:
import av # type: ignore
except ImportError:
raise ModuleNotFoundError(
"PyAV is not installed. Run `pip install deeplake[video]`."
)

if isinstance(file, str):
container = av.open(
file, options={"protocol_whitelist": "file,http,https,tcp,tls,subfile"}
@@ -886,6 +873,8 @@ def _open_video(file: Union[str, bytes, memoryview]):


def _read_metadata_from_vstream(container, vstream):
import av

duration = vstream.duration
if duration is None:
duration = container.duration
@@ -920,6 +909,8 @@ def _decompress_video(
step: int,
reverse: bool,
):
import av

container, vstream = _open_video(file)
nframes, height, width, _ = _read_metadata_from_vstream(container, vstream)[0]

@@ -983,6 +974,8 @@ def _read_timestamps(
step: int,
reverse: bool,
) -> np.ndarray:
import av

container, vstream = _open_video(file)

nframes = math.ceil((stop - start) / step)
@@ -1039,10 +1032,13 @@ def _read_timestamps(


def _open_audio(file: Union[str, bytes, memoryview]):
if not _PYAV_INSTALLED:
try:
import av
except ImportError:
raise ModuleNotFoundError(
"PyAV is not installed. Please run `pip install deeplake[audio]`"
)

if isinstance(file, str):
container = av.open(
file, options={"protocol_whitelist": "file,http,https,tcp,tls,subfile"}
@@ -1061,6 +1057,8 @@ def _open_audio(file: Union[str, bytes, memoryview]):


def _read_shape_from_astream(container, astream):
import av

nchannels = astream.channels
duration = astream.duration
if duration is None:
@@ -1090,6 +1088,8 @@ def _read_audio_shape(
def _read_audio_meta(
file: Union[bytes, memoryview, str],
) -> dict:
import av

container, astream = _open_audio(file)
meta = {}
if astream.duration:
@@ -1163,7 +1163,10 @@ def _read_3d_data_meta(file: Union[bytes, memoryview, str]):


def _open_nifti(file: Union[bytes, memoryview, str], gz: bool = False):
if not _NIBABEL_INSTALLED:
try:
import nibabel as nib # type: ignore
from nibabel import FileHolder, Nifti1Image, Nifti2Image # type: ignore
except ImportError:
raise ModuleNotFoundError(
"nibabel is not installed. Please run `pip install deeplake[medical]`"
)
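One tradeoff the compression.py hunks above accept: a missing optional dependency now surfaces on first use instead of at import time, which is why each entry point wraps its import and re-raises with an install hint. Where many functions share a dependency, a small helper keeps the hint in one place; a sketch using a hypothetical _require_av helper (not part of this PR):

def _require_av():
    """Import PyAV on demand, failing with an actionable message."""
    try:
        import av  # type: ignore
    except ImportError:
        raise ModuleNotFoundError(
            "PyAV is not installed. Run `pip install deeplake[video]`."
        )
    return av


def _open_video(file):
    av = _require_av()  # cheap after the first call; sys.modules caches it
    return av.open(file)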
40 changes: 21 additions & 19 deletions deeplake/core/storage/azure.py
@@ -11,28 +11,17 @@
from deeplake.util.path import relpath
from concurrent import futures

try:
from azure.identity import DefaultAzureCredential # type: ignore
from azure.storage.blob import ( # type: ignore
BlobServiceClient,
BlobSasPermissions,
ContainerSasPermissions,
generate_blob_sas,
generate_container_sas,
)
from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential

logger = logging.getLogger("azure.identity")
logger.setLevel(logging.ERROR)

_AZURE_PACKAGES_INSTALLED = True
except ImportError:
_AZURE_PACKAGES_INSTALLED = False


class AzureProvider(StorageProvider):
def __init__(self, root: str, creds: Dict = {}, token: Optional[str] = None):
if not _AZURE_PACKAGES_INSTALLED:
try:
import azure.identity
import azure.storage.blob
import azure.core

logger = logging.getLogger("azure.identity")
logger.setLevel(logging.ERROR)
except ImportError:
raise ImportError(
"Azure packages not installed. Run `pip install deeplake[azure]`."
)
@@ -69,6 +58,9 @@ def _get_attrs(self, path: str) -> Tuple[str, str, str]:
return account_name, container_name, root_folder

def _set_credential(self, creds: Dict[str, str]):
from azure.identity import DefaultAzureCredential # type: ignore
from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential

self.account_name = (
creds.get("account_name") or self.account_name
) # account name in creds can override account name in path
@@ -86,6 +78,8 @@ def _set_credential(self, creds: Dict[str, str]):
self.credential = DefaultAzureCredential()

def _set_clients(self):
from azure.storage.blob import BlobServiceClient # type: ignore

self.blob_service_client = BlobServiceClient(
self.account_url, credential=self.credential
)
@@ -162,6 +156,8 @@ def clear(self, prefix=""):
self.container_client.delete_blobs(*batch)

def get_sas_token(self):
from azure.storage.blob import generate_container_sas, ContainerSasPermissions # type: ignore

self._check_update_creds()
if self.sas_token:
return self.sas_token
@@ -288,6 +284,8 @@ def get_object_size(self, path: str) -> int:
return blob_client.get_blob_properties().size

def get_clients_from_full_path(self, url: str):
from azure.storage.blob import BlobServiceClient # type: ignore

self._check_update_creds()
account_name, container_name, blob_path = self._get_attrs(url)
account_url = f"https://{account_name}.blob.core.windows.net"
@@ -298,6 +296,8 @@ def get_clients_from_full_path(self, url: str):
return blob_client, blob_service_client

def get_presigned_url(self, path: str, full: bool = False) -> str:
from azure.storage.blob import BlobSasPermissions, generate_blob_sas # type: ignore

self._check_update_creds()
if full:
blob_client, blob_service_client = self.get_clients_from_full_path(path)
@@ -356,6 +356,8 @@ def _check_update_creds(self, force=False):
"""If the client has an expiration time, check if creds are expired and fetch new ones.
This would only happen for datasets stored on Deep Lake storage for which temporary 12 hour credentials are generated.
"""
from azure.core.credentials import AzureSasCredential # type: ignore

if self.expiration and (
force or float(self.expiration) < datetime.now(timezone.utc).timestamp()
):
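To quantify what deferring the Azure SDK (and the other heavy imports) saves, CPython's -X importtime flag prints per-module self and cumulative import costs to stderr. A small harness that ranks the slowest imports (illustrative tooling, not from this PR):

import subprocess
import sys

result = subprocess.run(
    [sys.executable, "-X", "importtime", "-c", "import deeplake"],
    capture_output=True,
    text=True,
)
# stderr lines look like: "import time:  self [us] | cumulative | package"
rows = []
for line in result.stderr.splitlines():
    parts = line.partition("import time:")[2].split("|")
    if len(parts) == 3 and parts[1].strip().isdigit():
        rows.append((parts[2].strip(), int(parts[1])))
for name, cumulative_us in sorted(rows, key=lambda r: r[1], reverse=True)[:10]:
    print(f"{cumulative_us:>10} us  {name}")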
39 changes: 23 additions & 16 deletions deeplake/core/storage/gcs.py
@@ -9,21 +9,6 @@

from deeplake.util.path import relpath

try:
from google.cloud import storage # type: ignore
from google.api_core import retry # type: ignore
from google.oauth2 import service_account # type: ignore
import google.auth as gauth # type: ignore
import google.auth.compute_engine # type: ignore
import google.auth.credentials # type: ignore
import google.auth.exceptions # type: ignore
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
from google.api_core.exceptions import NotFound # type: ignore

_GOOGLE_PACKAGES_INSTALLED = True
except ImportError:
_GOOGLE_PACKAGES_INSTALLED = False


from deeplake.core.storage.provider import StorageProvider
from deeplake.util.exceptions import (
@@ -70,6 +55,8 @@ def _connect_google_default(self):
ValueError: If the name of the default project doesn't match the GCSProvider project name.
DefaultCredentialsError: If no credentials are found.
"""
import google.auth as gauth # type: ignore

credentials, project = gauth.default(scopes=[self.scope])
if self.project and self.project != project:
raise ValueError(
@@ -113,6 +100,8 @@ def _connect_token(self, token: Optional[Union[str, Dict]] = None):
FileNotFoundError: If token file doesn't exist.
ValueError: If token format isn't supported by gauth.
"""
import google.auth.credentials # type: ignore

if isinstance(token, str):
if not os.path.exists(token):
raise FileNotFoundError(token)
@@ -132,6 +121,8 @@ def _connect_token(self, token: Optional[Union[str, Dict]] = None):
self.credentials = credentials

def _connect_service(self, fn):
from google.oauth2 import service_account # type: ignore

credentials = service_account.Credentials.from_service_account_file(
fn, scopes=[self.scope]
)
@@ -145,6 +136,8 @@ def _connect_browser(self):
Raises:
GCSDefaultCredsNotFoundError: If application default credentials can't be found.
"""
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore

try:
if os.name == "nt":
path = os.path.join(
@@ -255,10 +248,19 @@ def __init__(
Raises:
ModuleNotFoundError: If google cloud packages aren't installed.
"""
if not _GOOGLE_PACKAGES_INSTALLED:

try:
import google.cloud.storage # type: ignore
import google.api_core # type: ignore
import google.oauth2 # type: ignore
import google.auth # type: ignore
import google_auth_oauthlib # type: ignore
from google.api_core.exceptions import NotFound # type: ignore
except ImportError:
raise ModuleNotFoundError(
"Google cloud packages are not installed. Run `pip install deeplake[gcp]`."
)

self.root = root
self.token: Union[str, Dict, None] = token
self.tag: Optional[str] = None
@@ -290,6 +292,9 @@ def subdir(self, path: str, read_only: bool = False):
return sd

def _initialize_provider(self):
from google.cloud import storage # type: ignore
from google.api_core import retry # type: ignore

self._set_bucket_and_path()
if not self.token:
self.token = None
@@ -442,6 +447,8 @@ def __delitem__(self, key):

def __contains__(self, key):
"""Checks if key exists in mapping."""
from google.cloud import storage # type: ignore

stats = storage.Blob(
bucket=self.client_bucket, name=self._get_path_from_key(key)
).exists(self.client_bucket.client)
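A complementary spot check for changes like the GCS one above: immediately after `import deeplake`, none of the deferred dependencies should be loaded yet (illustrative; results depend on which extras are installed and on other eager imports):

import sys

import deeplake

for mod in ("av", "nibabel", "pandas", "azure", "google.cloud"):
    print(mod, "loaded" if mod in sys.modules else "deferred")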
