Skip to content

Commit

Permalink
Azure Managed Credentials Bug Fix (#2718)
Browse files Browse the repository at this point in the history
* changing vectorstore's exec_option autocasting logic to use root instead of path

* fixing azure + managed credentials

* adding tests

* removing unnecessary changes

* fixing azure_creds_key

* increasing code coverage
  • Loading branch information
AdkSarsen committed Dec 12, 2023
1 parent fdc0b0b commit 2d36d92
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 20 deletions.
6 changes: 5 additions & 1 deletion deeplake/constants.py
Expand Up @@ -90,7 +90,11 @@

# Environment variable names for Azure settings.
ENV_AZURE_CLIENT_ID = "AZURE_CLIENT_ID"
ENV_AZURE_TENANT_ID = "AZURE_TENANT_ID"
ENV_AZURE_SUBSCRIPTION_ID = "AZURE_SUBSCRIPTION_ID"
ENV_AZURE_CLIENT_SECRET = "AZURE_CLIENT_SECRET"

# Environment variable names for AWS / S3 settings.
# NOTE(review): the endpoint variable is literally "ENDPOINT_URL", with no
# "AWS_" prefix unlike its siblings — confirm this is intentional.
ENV_AWS_ACCESS_KEY = "AWS_ACCESS_KEY"
ENV_AWS_SECRETS_ACCESS_KEY = "AWS_SECRETS_ACCESS_KEY"
ENV_AWS_ENDPOINT_URL = "ENDPOINT_URL"

# Presumably environment variable names for Google Drive client settings —
# confirm against the code that reads them.
ENV_GDRIVE_CLIENT_ID = "GDRIVE_CLIENT_ID"
ENV_GDRIVE_CLIENT_SECRET = "GDRIVE_CLIENT_SECRET"
Expand Down
27 changes: 9 additions & 18 deletions deeplake/core/vectorstore/test_deeplake_vectorstore.py
Expand Up @@ -20,6 +20,7 @@
from deeplake.constants import (
DEFAULT_VECTORSTORE_TENSORS,
DEFAULT_VECTORSTORE_DISTANCE_METRIC,
HUB_CLOUD_DEV_USERNAME,
)
from deeplake.constants import MB
from deeplake.util.exceptions import (
Expand Down Expand Up @@ -2795,36 +2796,26 @@ def test_exec_option_cli(

@requires_libdeeplake
@pytest.mark.parametrize(
"path",
"path, creds",
[
"s3_path",
"gcs_path",
"azure_path",
("s3_path", "s3_creds"),
("gcs_path", "gcs_creds"),
("azure_path", "azure_creds"),
],
indirect=True,
)
def test_exec_option_with_connected_datasets(
hub_cloud_dev_token,
hub_cloud_path,
hub_cloud_dev_managed_creds_key,
path,
creds,
):
runner = CliRunner()

db = VectorStore(path, overwrite=True)
assert db.exec_option == "python"

runner.invoke(login, f"-t {hub_cloud_dev_token}")
assert db.exec_option == "python"
db = VectorStore(path, overwrite=True, creds=creds)

db.dataset_handler.dataset.connect(
creds_key=hub_cloud_dev_managed_creds_key,
creds_key=creds,
dest_path=hub_cloud_path,
token=hub_cloud_dev_token,
)
db.dataset_handler.dataset.add_creds_key(
hub_cloud_dev_managed_creds_key, managed=True
)
db.dataset_handler.dataset.add_creds_key(creds, managed=True)
assert db.exec_option == "compute_engine"


Expand Down
2 changes: 1 addition & 1 deletion deeplake/core/vectorstore/vector_search/utils.py
Expand Up @@ -55,7 +55,7 @@ def get_exec_option(self):
# option 1: dataset is created in vector_db:
if (
isinstance(self.dataset, DeepLakeCloudDataset)
and "vectordb/" in self.dataset.base_storage.path
and "vectordb/" in self.dataset.base_storage.root
and self.token is not None
):
return "tensor_db"
Expand Down
20 changes: 20 additions & 0 deletions deeplake/tests/client_fixtures.py
@@ -1,6 +1,10 @@
from deeplake.constants import (
ENV_HUB_DEV_MANAGED_CREDS_KEY,
AZURE_OPT,
HUB_CLOUD_OPT,
ENV_AZURE_CLIENT_ID,
ENV_AZURE_CLIENT_SECRET,
ENV_AZURE_TENANT_ID,
ENV_HUB_DEV_USERNAME,
ENV_HUB_DEV_PASSWORD,
ENV_KAGGLE_USERNAME,
Expand Down Expand Up @@ -74,3 +78,19 @@ def hub_cloud_dev_managed_creds_key(request):

creds_key = os.getenv(ENV_HUB_DEV_MANAGED_CREDS_KEY, "aws_creds")
return creds_key


@pytest.fixture(scope="session")
def azure_creds_key(request):
    """Session-scoped fixture returning Azure credentials read from the environment.

    Skips the requesting test when the ``AZURE_OPT`` option is not enabled.
    Each value is whatever ``os.getenv`` yields for the matching variable, so
    an unset variable results in ``None`` for that key.
    """
    azure_enabled = is_opt_true(request, AZURE_OPT)
    if not azure_enabled:
        pytest.skip(f"{AZURE_OPT} flag not set")

    # Credential key -> name of the environment variable that supplies it.
    env_name_by_key = (
        ("azure_client_id", ENV_AZURE_CLIENT_ID),
        ("azure_tenant_id", ENV_AZURE_TENANT_ID),
        ("azure_client_secret", ENV_AZURE_CLIENT_SECRET),
    )
    return {key: os.getenv(env_name) for key, env_name in env_name_by_key}
52 changes: 52 additions & 0 deletions deeplake/tests/path_fixtures.py
Expand Up @@ -27,6 +27,12 @@
ENV_GDRIVE_CLIENT_SECRET,
ENV_GDRIVE_REFRESH_TOKEN,
HUB_CLOUD_DEV_USERNAME,
ENV_AZURE_CLIENT_ID,
ENV_AZURE_TENANT_ID,
ENV_AZURE_CLIENT_SECRET,
ENV_AWS_ACCESS_KEY,
ENV_AWS_SECRETS_ACCESS_KEY,
ENV_AWS_ENDPOINT_URL,
)
from deeplake import VectorStore
from deeplake.client.client import DeepMemoryBackendClient
Expand Down Expand Up @@ -306,6 +312,23 @@ def s3_path(request):
S3Provider(path).clear()


@pytest.fixture
def s3_creds(request):
    """Return a credentials dict for S3 built from environment variables.

    Skips the requesting test when the ``S3_OPT`` option is not enabled.

    Returns:
        dict: keys ``aws_access_key``, ``aws_secrets_key`` and
        ``endpoint_url``; a value is ``None`` when the corresponding
        environment variable is not set.
    """
    if not is_opt_true(request, S3_OPT):
        # pytest.skip raises, so nothing after it in this branch would run.
        pytest.skip(f"{S3_OPT} flag not set")

    return {
        "aws_access_key": os.environ.get(ENV_AWS_ACCESS_KEY),
        "aws_secrets_key": os.environ.get(ENV_AWS_SECRETS_ACCESS_KEY),
        "endpoint_url": os.environ.get(ENV_AWS_ENDPOINT_URL),
    }


@pytest.fixture
def s3_vstream_path(request):
if not is_opt_true(request, S3_OPT):
Expand Down Expand Up @@ -384,6 +407,23 @@ def gdrive_creds():
return creds


@pytest.fixture
def azure_creds(request):
    """Return a credentials dict for Azure built from environment variables.

    Skips the requesting test when the ``AZURE_OPT`` option is not enabled.

    Returns:
        dict: keys ``azure_client_id``, ``azure_tenant_id`` and
        ``azure_client_secret``; a value is ``None`` when the corresponding
        environment variable is not set.
    """
    if not is_opt_true(request, AZURE_OPT):
        # pytest.skip raises, so nothing after it in this branch would run.
        pytest.skip(f"{AZURE_OPT} flag not set")

    return {
        "azure_client_id": os.environ.get(ENV_AZURE_CLIENT_ID),
        "azure_tenant_id": os.environ.get(ENV_AZURE_TENANT_ID),
        "azure_client_secret": os.environ.get(ENV_AZURE_CLIENT_SECRET),
    }


@pytest.fixture
def gdrive_path(request, gdrive_creds):
if not is_opt_true(request, GDRIVE_OPT):
Expand Down Expand Up @@ -691,6 +731,18 @@ def path(request):
return request.getfixturevalue(request.param)


@pytest.fixture
def dest_path(request):
    """Resolve the fixture named by the parametrize value and return its value.

    Meant to be used with ``pytest.mark.parametrize`` so that each parameter
    is the name of a dataset-path fixture.
    """
    fixture_name = request.param
    return request.getfixturevalue(fixture_name)


@pytest.fixture
def creds(request):
    """Resolve the fixture named by the parametrize value and return its value.

    Meant to be used with ``pytest.mark.parametrize`` so that each parameter
    is the name of a dataset-credentials fixture.
    """
    fixture_name = request.param
    return request.getfixturevalue(fixture_name)


@pytest.fixture
def hub_token(request):
"""Used with parametrize to get hub_cloud_dev_token if hub-cloud option is True else None"""
Expand Down

0 comments on commit 2d36d92

Please sign in to comment.