Skip to content

Commit

Permalink
Merge pull request #2788 from activeloopai/auth-context
Browse files Browse the repository at this point in the history
Add basic Auth context
  • Loading branch information
dgaloop committed Mar 19, 2024
2 parents 712dfb0 + e02ec88 commit 7536e9a
Show file tree
Hide file tree
Showing 12 changed files with 209 additions and 80 deletions.
1 change: 0 additions & 1 deletion deeplake/api/dataset.py
Expand Up @@ -11,7 +11,6 @@
from deeplake.auto.unstructured.image_classification import ImageClassification
from deeplake.auto.unstructured.coco.coco import CocoDataset
from deeplake.auto.unstructured.yolo.yolo import YoloDataset
from deeplake.client.client import DeepLakeBackendClient
from deeplake.client.log import logger
from deeplake.core.dataset import Dataset, dataset_factory
from deeplake.core.tensor import Tensor
Expand Down
19 changes: 19 additions & 0 deletions deeplake/client/auth/__init__.py
@@ -0,0 +1,19 @@
import os

from deeplake.client.auth.auth_context import AuthContext, AuthProviderType
from deeplake.client.auth.activeloop import ActiveLoopAuthContext
from deeplake.client.auth.azure import AzureAuthContext
from deeplake.client.config import DEEPLAKE_AUTH_PROVIDER


def initialize_auth_context(*args, **kwargs) -> AuthContext:
    """Build the auth context selected through the environment.

    The provider name is read from the environment variable whose name is
    stored in ``DEEPLAKE_AUTH_PROVIDER`` and compared case-insensitively
    with the Azure provider value.

    Args:
        *args: Forwarded to ``ActiveLoopAuthContext`` when it is selected.
        **kwargs: Forwarded to ``ActiveLoopAuthContext`` when it is selected.

    Returns:
        AuthContext: An ``AzureAuthContext`` (built without arguments) when
        the variable names Azure; otherwise an ``ActiveLoopAuthContext``,
        which is also the result when the variable is unset or unknown.
    """
    requested_provider = os.environ.get(DEEPLAKE_AUTH_PROVIDER, "").lower()
    azure_provider = AuthProviderType.AZURE.value.lower()

    # Anything that is not explicitly Azure falls back to Activeloop.
    if requested_provider != azure_provider:
        return ActiveLoopAuthContext(*args, **kwargs)

    return AzureAuthContext()


__all__ = ["initialize_auth_context"]
26 changes: 26 additions & 0 deletions deeplake/client/auth/activeloop.py
@@ -0,0 +1,26 @@
import os
from typing import Optional

from deeplake.client.auth.auth_context import AuthContext, AuthProviderType
from deeplake.client.config import DEEPLAKE_AUTH_TOKEN


class ActiveLoopAuthContext(AuthContext):
    """Auth context that hands out an Activeloop token.

    The token is taken from the constructor argument, then from the
    environment variable named by ``DEEPLAKE_AUTH_TOKEN``, and finally
    falls back to a fixed public placeholder token.
    """

    def __init__(self, token: Optional[str] = None):
        """Store the explicitly supplied token, if any."""
        self.token = token

    def get_provider_type(self) -> AuthProviderType:
        """Identify this context as the Activeloop provider."""
        return AuthProviderType.ACTIVELOOP

    def authenticate(self) -> None:
        """Resolve ``self.token`` when it is missing or empty.

        A truthy token is left untouched. Otherwise the environment
        variable is consulted, and an empty or missing value there yields
        the public placeholder token.
        """
        if self.token:
            return

        env_token = os.environ.get(DEEPLAKE_AUTH_TOKEN)
        if env_token:
            self.token = env_token
        else:
            self.token = "PUBLIC_TOKEN_" + ("_" * 150)

    def get_token(self) -> Optional[str]:
        """Return the token, resolving it first when it is still ``None``."""
        if self.token is not None:
            return self.token

        self.authenticate()
        return self.token
32 changes: 32 additions & 0 deletions deeplake/client/auth/auth_context.py
@@ -0,0 +1,32 @@
from enum import Enum
from typing import Optional
from abc import ABC, abstractmethod


class AuthProviderType(Enum):
    """Identity providers an auth context can report.

    The member values are the strings sent in the
    ``X-Activeloop-Provider-Type`` request header.
    """

    # Token-based Activeloop authentication.
    ACTIVELOOP = "activeloop"
    # Azure credential based authentication.
    AZURE = "azure"


class AuthContext(ABC):
    """Interface implemented by every authentication provider.

    Subclasses supply the token, the authentication step and the provider
    type; this base class turns those into HTTP request headers.
    """

    @abstractmethod
    def get_provider_type(self) -> AuthProviderType:
        """Return the ``AuthProviderType`` member identifying the provider."""

    @abstractmethod
    def authenticate(self) -> None:
        """Try to authenticate using the necessary configuration.

        Implementations should raise ``InvalidAuthContextError`` when the
        authentication fails.
        """

    @abstractmethod
    def get_token(self) -> Optional[str]:
        """Return the bearer token (the annotation allows ``None``)."""

    def get_auth_headers(self) -> dict:
        """Build the authorization headers for a backend request.

        Returns:
            dict: ``Authorization`` carrying ``Bearer <token>`` and
            ``X-Activeloop-Provider-Type`` carrying the provider's value.
        """
        # The token is fetched before the provider type, as the header order implies.
        bearer = self.get_token()
        provider = self.get_provider_type()
        return {
            "Authorization": f"Bearer {bearer}",
            "X-Activeloop-Provider-Type": provider.value,
        }
57 changes: 57 additions & 0 deletions deeplake/client/auth/azure.py
@@ -0,0 +1,57 @@
import os
from datetime import datetime, timedelta
from azure.identity import DefaultAzureCredential, ManagedIdentityCredential
from azure.core.exceptions import ClientAuthenticationError

from deeplake.client.auth.auth_context import AuthContext, AuthProviderType
from deeplake.util.exceptions import InvalidAuthContextError

ID_TOKEN_CACHE_MINUTES = 5


class AzureAuthContext(AuthContext):
    """Auth context that obtains bearer tokens from an Azure credential.

    A fetched token is reused for ``ID_TOKEN_CACHE_MINUTES`` minutes before
    the credential is asked for a new one.
    """

    def __init__(self):
        self.credential = self._get_azure_credential()
        self.token = None
        # ``time.monotonic()`` reading taken at the last successful token
        # fetch; ``None`` until the first fetch succeeds.
        self._last_auth_time = None

    def _get_azure_credential(self):
        """Pick the Azure credential object based on ``AZURE_*`` variables.

        Returns:
            ``ManagedIdentityCredential`` when ``AZURE_CLIENT_ID`` is the
            only ``AZURE_*`` environment variable, otherwise
            ``DefaultAzureCredential``.
        """
        azure_keys = [key for key in os.environ if key.startswith("AZURE_")]
        if azure_keys == ["AZURE_CLIENT_ID"]:
            # Explicitly set client_id, to avoid any warnings coming from DefaultAzureCredential
            return ManagedIdentityCredential(
                client_id=os.environ.get("AZURE_CLIENT_ID")
            )

        return DefaultAzureCredential()

    def get_token(self) -> str:
        """Return a token, refreshing it first when the cache has expired.

        Raises:
            InvalidAuthContextError: If a refresh is needed and fails.
        """
        self.authenticate()

        return self.token

    def authenticate(self) -> None:
        """Fetch a new token unless the cached one is still fresh.

        The cache age is measured with a monotonic clock so that wall-clock
        changes (DST shifts, manual or NTP adjustments) can neither extend
        nor shorten the cache window.

        Raises:
            InvalidAuthContextError: If the credential rejects the request
                or any other error occurs while fetching the token.
        """
        # Imported locally so this block does not depend on a change to the
        # module's import section.
        import time

        cache_seconds = ID_TOKEN_CACHE_MINUTES * 60
        if (
            self._last_auth_time is not None
            and time.monotonic() - self._last_auth_time < cache_seconds
        ):
            return

        try:
            response = self.credential.get_token(
                "https://management.azure.com/.default"
            )
            self.token = response.token
            self._last_auth_time = time.monotonic()
        except ClientAuthenticationError as e:
            raise InvalidAuthContextError(
                f"Failed to authenticate with Azure. Please check your credentials. \n {e.message}",
            ) from e
        except Exception as e:
            raise InvalidAuthContextError(
                "Failed to authenticate with Azure. An unexpected error occured."
            ) from e

    def get_provider_type(self) -> AuthProviderType:
        """Identify this context as the Azure provider."""
        return AuthProviderType.AZURE
48 changes: 48 additions & 0 deletions deeplake/client/auth/test_auth_context.py
@@ -0,0 +1,48 @@
import os
import pytest

from deeplake.client.auth import AuthProviderType, initialize_auth_context
from deeplake.client.auth.azure import AzureAuthContext
from deeplake.client.config import DEEPLAKE_AUTH_PROVIDER


def test_initialize_auth_context():
    """Check provider selection for unset, unknown and Azure settings.

    The provider environment variable is saved and cleared up front and
    restored in ``finally``, so the test neither depends on the caller's
    environment nor leaks changes into other tests when an assertion fails.
    """
    previous_provider = os.environ.pop(DEEPLAKE_AUTH_PROVIDER, None)
    try:
        context = initialize_auth_context(token="dummy")
        assert context.get_provider_type() == AuthProviderType.ACTIVELOOP

        context = initialize_auth_context()
        assert context.get_provider_type() == AuthProviderType.ACTIVELOOP

        # An unknown provider name falls back to Activeloop.
        os.environ[DEEPLAKE_AUTH_PROVIDER] = "dummy"
        context = initialize_auth_context()
        assert context.get_provider_type() == AuthProviderType.ACTIVELOOP

        # The provider name is matched case-insensitively.
        for provider_name in ("azure", "AzUrE"):
            os.environ[DEEPLAKE_AUTH_PROVIDER] = provider_name
            context = initialize_auth_context()
            assert isinstance(context, AzureAuthContext)
            assert context.get_provider_type() == AuthProviderType.AZURE
    finally:
        os.environ.pop(DEEPLAKE_AUTH_PROVIDER, None)
        if previous_provider is not None:
            os.environ[DEEPLAKE_AUTH_PROVIDER] = previous_provider


@pytest.mark.skip(
    reason="This test requires not having Azure credentials and fails in the CI environment."
)
def test_azure_auth_context_exceptions():
    """Check that authenticating without Azure settings raises.

    All ``AZURE_*`` environment variables are removed for the duration of
    the check and restored in ``finally``, so they survive even when the
    expected exception is not raised.
    """
    saved_env = {
        key: value for key, value in os.environ.items() if key.startswith("AZURE_")
    }
    # NOTE(review): the context is created before the variables are removed,
    # as in the original test; confirm whether the credential object reads
    # the environment at construction time.
    context = AzureAuthContext()

    try:
        for key in saved_env:
            del os.environ[key]

        with pytest.raises(Exception):
            context.authenticate()
    finally:
        os.environ.update(saved_env)
68 changes: 8 additions & 60 deletions deeplake/client/client.py
Expand Up @@ -27,9 +27,8 @@
HUB_REST_ENDPOINT,
HUB_REST_ENDPOINT_LOCAL,
HUB_REST_ENDPOINT_DEV,
GET_TOKEN_SUFFIX,
HUB_REST_ENDPOINT_TESTING,
HUB_REST_ENDPOINT_STAGING,
REGISTER_USER_SUFFIX,
DEFAULT_REQUEST_TIMEOUT,
GET_DATASET_CREDENTIALS_SUFFIX,
CREATE_DATASET_SUFFIX,
Expand All @@ -41,9 +40,9 @@
CONNECT_DATASET_SUFFIX,
REMOTE_QUERY_SUFFIX,
ORG_PERMISSION_SUFFIX,
DEEPLAKE_AUTH_TOKEN,
)
from deeplake.client.log import logger
from deeplake.client.auth import initialize_auth_context
import jwt # should add it to requirements.txt

# for these codes, we will retry requests up to 3 times
Expand All @@ -61,28 +60,21 @@ def __init__(self, token: Optional[str] = None):
)

self.version = deeplake.__version__
self.auth_header = None
self.token = (
token
or os.environ.get(DEEPLAKE_AUTH_TOKEN)
or "PUBLIC_TOKEN_" + ("_" * 150)
)
self.auth_header = f"Bearer {self.token}"
self.auth_context = initialize_auth_context(token=token)

# remove public token, otherwise env var will be ignored
# we can remove this after a while
orgs = self.get_user_organizations()
if orgs == ["public"]:
self.token = token or self.get_token()
self.auth_header = f"Bearer {self.token}"
else:
username = self.get_user_profile()["name"]
if get_reporting_config().get("username") != username:
save_reporting_config(True, username=username)
set_username(username)

def get_token(self):
return self.token
return self.auth_context.get_token()

def request(
self,
Expand Down Expand Up @@ -131,7 +123,7 @@ def request(
request_url = f"{endpoint}/{relative_url}"
headers = headers or {}
headers["hub-cli-version"] = self.version
headers["Authorization"] = self.auth_header
headers = {**headers, **self.auth_context.get_auth_headers()}

# clearer error than `ServerUnderMaintenence`
if json is not None and "password" in json and json["password"] is None:
Expand Down Expand Up @@ -161,48 +153,13 @@ def endpoint(self):
return HUB_REST_ENDPOINT_LOCAL
if deeplake.client.config.USE_DEV_ENVIRONMENT:
return HUB_REST_ENDPOINT_DEV
if deeplake.client.config.USE_TESTING_ENVIRONMENT:
return HUB_REST_ENDPOINT_TESTING
if deeplake.client.config.USE_STAGING_ENVIRONMENT:
return HUB_REST_ENDPOINT_STAGING

return HUB_REST_ENDPOINT

def request_auth_token(self, username: str, password: str):
    """Ask the backend for an auth token for the given credentials.

    Args:
        username (str): The Activeloop username to request a token for.
        password (str): The password of the account.

    Returns:
        string: The auth token corresponding to the account.

    Raises:
        LoginException: If the response body cannot be parsed or does not
            contain a ``token`` entry.

    Note:
        Errors raised by ``self.request`` itself propagate unchanged; the
        original documentation lists ``UserNotLoggedInException`` for
        unauthorised users, which is not visible in this block.
    """
    payload = {"username": username, "password": password}
    response = self.request("POST", GET_TOKEN_SUFFIX, json=payload)

    try:
        return response.json()["token"]
    except Exception:
        raise LoginException()

def send_register_request(self, username: str, email: str, password: str):
    """Send a request to the backend to register a new user.

    Args:
        username (str): The Activeloop username to create an account for.
        email (str): The email address to link with the Activeloop account.
        password (str): The new password of the account. Should be at least
            6 characters long (not validated in this method).
    """
    payload = dict(username=username, email=email, password=password)
    self.request("POST", REGISTER_USER_SUFFIX, json=payload)

def get_dataset_credentials(
self,
org_id: str,
Expand Down Expand Up @@ -248,7 +205,6 @@ def get_dataset_credentials(
).json()
except Exception as e:
if isinstance(e, AuthorizationException):
authorization_exception_prompt = "You don't have permission"
response_data = e.response.json()
code = response_data.get("code")
if code == 1:
Expand All @@ -259,18 +215,10 @@ def get_dataset_credentials(
raise NotLoggedInAgreementError from e
else:
try:
decoded_token = jwt.decode(
self.token, options={"verify_signature": False}
)
jwt.decode(self.token, options={"verify_signature": False})
except Exception:
raise InvalidTokenException

if (
authorization_exception_prompt.lower()
in response_data["description"].lower()
and decoded_token["id"] == "public"
):
raise UserNotLoggedInException()
raise TokenPermissionError()
raise
full_url = response.get("path")
Expand Down
14 changes: 3 additions & 11 deletions deeplake/client/config.py
Expand Up @@ -7,13 +7,13 @@
# Base URLs of the backend REST API. The client's `endpoint` lookup returns
# HUB_REST_ENDPOINT when none of the environment flags below selects another.
HUB_REST_ENDPOINT = "https://app.activeloop.ai"
HUB_REST_ENDPOINT_STAGING = "https://app-staging.activeloop.dev"
HUB_REST_ENDPOINT_DEV = "https://app-dev.activeloop.dev"
HUB_REST_ENDPOINT_TESTING = "https://testing.activeloop.dev"
HUB_REST_ENDPOINT_LOCAL = "http://localhost:7777"
# Flags read by the client to choose one of the non-default endpoints above
# (e.g. USE_DEV_ENVIRONMENT selects HUB_REST_ENDPOINT_DEV); all off by default.
# NOTE(review): USE_LOCAL_HOST presumably selects HUB_REST_ENDPOINT_LOCAL — confirm in the client.
USE_LOCAL_HOST = False
USE_DEV_ENVIRONMENT = False
USE_TESTING_ENVIRONMENT = False
USE_STAGING_ENVIRONMENT = False

# URL path suffixes for backend requests; "{}" placeholders are filled in
# by the caller (presumably org id, then dataset name — TODO confirm).
GET_TOKEN_SUFFIX = "/api/user/token"
REGISTER_USER_SUFFIX = "/api/user/register"
GET_DATASET_CREDENTIALS_SUFFIX = "/api/org/{}/ds/{}/creds"
GET_PRESIGNED_URL_SUFFIX = "/api/org/{}/ds/{}/chunks/url/presigned"
CREATE_DATASET_SUFFIX = "/api/dataset/create"
Expand All @@ -30,13 +30,5 @@
# Default timeout for backend requests (presumably in seconds — TODO confirm).
DEFAULT_REQUEST_TIMEOUT = 170

# Name of the environment variable the Activeloop auth context reads the token from.
DEEPLAKE_AUTH_TOKEN = "ACTIVELOOP_TOKEN"
# Name of the environment variable that selects the auth provider
# (compared case-insensitively against the provider values, e.g. "azure").
DEEPLAKE_AUTH_PROVIDER = "ACTIVELOOP_AUTH_PROVIDER"
# Path template with one "{}" placeholder (presumably the organization id — confirm).
ORG_PERMISSION_SUFFIX = "/api/organizations/{}/features/dataset_query"

# ManagedService Endpoints
INIT_VECTORSTORE_SUFFIX = "/api/dlserver/vectorstore/init"
GET_VECTORSTORE_SUMMARY_SUFFIX = "/api/dlserver/vectorstore/{}/{}/summary"
DELETE_VECTORSTORE_SUFFIX = "/api/dlserver/vectorstore"

VECTORSTORE_SEARCH_SUFFIX = "/api/dlserver/vectorstore/search"
VECTORSTORE_ADD_SUFFIX = "/api/dlserver/vectorstore/add"
VECTORSTORE_REMOVE_ROWS_SUFFIX = "/api/dlserver/vectorstore/remove"
2 changes: 0 additions & 2 deletions deeplake/client/utils.py
Expand Up @@ -8,8 +8,6 @@

from deeplake.client.config import (
REPORTING_CONFIG_FILE_PATH,
TOKEN_FILE_PATH,
DEEPLAKE_AUTH_TOKEN,
)
from deeplake.util.exceptions import (
AuthenticationException,
Expand Down
7 changes: 3 additions & 4 deletions deeplake/integrations/mmdet/mmdet_.py
Expand Up @@ -203,7 +203,6 @@
from mmcv.parallel import collate # type: ignore
from functools import partial
from deeplake.integrations.pytorch.dataset import TorchDataset
from deeplake.client.client import DeepLakeBackendClient
from deeplake.core.ipc import _get_free_port
from mmdet.core import BitmapMasks # type: ignore
import deeplake as dp
Expand Down Expand Up @@ -794,9 +793,9 @@ def load_ds_from_cfg(cfg: mmcv.utils.config.ConfigDict):
if token is None:
uname = creds.get("username")
if uname is not None:
pword = creds["password"]
client = DeepLakeBackendClient()
token = client.request_auth_token(username=uname, password=pword)
raise NotImplementedError(
"Username/Password based authentication from deeplake has been deprecated. Please specify a token in the config."
)
ds_path = cfg.deeplake_path
ds = dp.load(ds_path, token=token, read_only=True)
deeplake_commit = cfg.get("deeplake_commit")
Expand Down

0 comments on commit 7536e9a

Please sign in to comment.