Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2788 from activeloopai/auth-context
Add basic Auth context
- Loading branch information
Showing
12 changed files
with
209 additions
and
80 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import os | ||
|
||
from deeplake.client.auth.auth_context import AuthContext, AuthProviderType | ||
from deeplake.client.auth.activeloop import ActiveLoopAuthContext | ||
from deeplake.client.auth.azure import AzureAuthContext | ||
from deeplake.client.config import DEEPLAKE_AUTH_PROVIDER | ||
|
||
|
||
def initialize_auth_context(*args, **kwargs) -> AuthContext: | ||
if ( | ||
os.environ.get(DEEPLAKE_AUTH_PROVIDER, "").lower() | ||
== AuthProviderType.AZURE.value.lower() | ||
): | ||
return AzureAuthContext() | ||
|
||
return ActiveLoopAuthContext(*args, **kwargs) | ||
|
||
|
||
__all__ = ["initialize_auth_context"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import os | ||
from typing import Optional | ||
|
||
from deeplake.client.auth.auth_context import AuthContext, AuthProviderType | ||
from deeplake.client.config import DEEPLAKE_AUTH_TOKEN | ||
|
||
|
||
class ActiveLoopAuthContext(AuthContext): | ||
def __init__(self, token: Optional[str] = None): | ||
self.token = token | ||
|
||
def get_token(self) -> Optional[str]: | ||
if self.token is None: | ||
self.authenticate() | ||
|
||
return self.token | ||
|
||
def authenticate(self) -> None: | ||
self.token = ( | ||
self.token | ||
or os.environ.get(DEEPLAKE_AUTH_TOKEN) | ||
or "PUBLIC_TOKEN_" + ("_" * 150) | ||
) | ||
|
||
def get_provider_type(self) -> AuthProviderType: | ||
return AuthProviderType.ACTIVELOOP |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from enum import Enum | ||
from typing import Optional | ||
from abc import ABC, abstractmethod | ||
|
||
|
||
class AuthProviderType(Enum): | ||
ACTIVELOOP = "activeloop" | ||
AZURE = "azure" | ||
|
||
|
||
class AuthContext(ABC): | ||
def get_auth_headers(self) -> dict: | ||
return { | ||
"Authorization": f"Bearer {self.get_token()}", | ||
"X-Activeloop-Provider-Type": self.get_provider_type().value, | ||
} | ||
|
||
@abstractmethod | ||
def get_token(self) -> Optional[str]: | ||
pass | ||
|
||
@abstractmethod | ||
def authenticate(self) -> None: | ||
""" | ||
Try to authenticate using the necessary configuration. | ||
If the authentication fails, an `InvalidAuthContext` exception should be raised. | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
def get_provider_type(self) -> AuthProviderType: | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import os | ||
from datetime import datetime, timedelta | ||
from azure.identity import DefaultAzureCredential, ManagedIdentityCredential | ||
from azure.core.exceptions import ClientAuthenticationError | ||
|
||
from deeplake.client.auth.auth_context import AuthContext, AuthProviderType | ||
from deeplake.util.exceptions import InvalidAuthContextError | ||
|
||
ID_TOKEN_CACHE_MINUTES = 5 | ||
|
||
|
||
class AzureAuthContext(AuthContext): | ||
def __init__(self): | ||
self.credential = self._get_azure_credential() | ||
self.token = None | ||
self._last_auth_time = None | ||
|
||
def _get_azure_credential(self): | ||
azure_keys = [i for i in os.environ if i.startswith("AZURE_")] | ||
if "AZURE_CLIENT_ID" in azure_keys and len(azure_keys) == 1: | ||
# Explicitly set client_id, to avoid any warnings coming from DefaultAzureCredential | ||
return ManagedIdentityCredential( | ||
client_id=os.environ.get("AZURE_CLIENT_ID") | ||
) | ||
|
||
return DefaultAzureCredential() | ||
|
||
def get_token(self) -> str: | ||
self.authenticate() | ||
|
||
return self.token | ||
|
||
def authenticate(self) -> None: | ||
if ( | ||
self._last_auth_time is not None | ||
and datetime.now() - self._last_auth_time | ||
< timedelta(minutes=ID_TOKEN_CACHE_MINUTES) | ||
): | ||
return | ||
|
||
try: | ||
response = self.credential.get_token( | ||
"https://management.azure.com/.default" | ||
) | ||
self.token = response.token | ||
self._last_auth_time = datetime.now() | ||
except ClientAuthenticationError as e: | ||
raise InvalidAuthContextError( | ||
f"Failed to authenticate with Azure. Please check your credentials. \n {e.message}", | ||
) from e | ||
except Exception as e: | ||
raise InvalidAuthContextError( | ||
"Failed to authenticate with Azure. An unexpected error occured." | ||
) from e | ||
|
||
def get_provider_type(self) -> AuthProviderType: | ||
return AuthProviderType.AZURE |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import os | ||
import pytest | ||
|
||
from deeplake.client.auth import AuthProviderType, initialize_auth_context | ||
from deeplake.client.auth.azure import AzureAuthContext | ||
from deeplake.client.config import DEEPLAKE_AUTH_PROVIDER | ||
|
||
|
||
def test_initialize_auth_context(): | ||
context = initialize_auth_context(token="dummy") | ||
assert context.get_provider_type() == AuthProviderType.ACTIVELOOP | ||
|
||
context = initialize_auth_context() | ||
assert context.get_provider_type() == AuthProviderType.ACTIVELOOP | ||
|
||
os.environ[DEEPLAKE_AUTH_PROVIDER] = "dummy" | ||
context = initialize_auth_context() | ||
assert context.get_provider_type() == AuthProviderType.ACTIVELOOP | ||
|
||
os.environ[DEEPLAKE_AUTH_PROVIDER] = "azure" | ||
context = initialize_auth_context() | ||
assert isinstance(context, AzureAuthContext) | ||
assert context.get_provider_type() == AuthProviderType.AZURE | ||
|
||
os.environ[DEEPLAKE_AUTH_PROVIDER] = "AzUrE" | ||
context = initialize_auth_context() | ||
assert isinstance(context, AzureAuthContext) | ||
assert context.get_provider_type() == AuthProviderType.AZURE | ||
|
||
del os.environ[DEEPLAKE_AUTH_PROVIDER] | ||
|
||
|
||
@pytest.mark.skip( | ||
reason="This test requires not having Azure credentials and fails in the CI environment." | ||
) | ||
def test_azure_auth_context_exceptions(): | ||
azure_envs = [i for i in os.environ if i.startswith("AZURE_")] | ||
values = {i: os.environ[i] for i in azure_envs} | ||
context = AzureAuthContext() | ||
|
||
for key in azure_envs: | ||
del os.environ[key] | ||
|
||
with pytest.raises(Exception): | ||
context.authenticate() | ||
|
||
for key, value in values.items(): | ||
os.environ[key] = value |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.