/
client.py
636 lines (564 loc) · 22.7 KB
/
client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
import os
import deeplake
import requests # type: ignore
import textwrap
from typing import Any, Optional, Dict, List, Union
from deeplake.util.exceptions import (
AgreementNotAcceptedError,
AuthorizationException,
LoginException,
InvalidPasswordException,
ManagedCredentialsNotFoundError,
NotLoggedInAgreementError,
ResourceNotFoundException,
InvalidTokenException,
UserNotLoggedInException,
TokenPermissionError,
)
from deeplake.client.utils import (
check_response_status,
JobResponseStatusSchema,
)
from deeplake.client.config import (
ACCEPT_AGREEMENTS_SUFFIX,
REJECT_AGREEMENTS_SUFFIX,
GET_MANAGED_CREDS_SUFFIX,
HUB_REST_ENDPOINT,
HUB_REST_ENDPOINT_LOCAL,
HUB_REST_ENDPOINT_DEV,
GET_TOKEN_SUFFIX,
HUB_REST_ENDPOINT_STAGING,
REGISTER_USER_SUFFIX,
DEFAULT_REQUEST_TIMEOUT,
GET_DATASET_CREDENTIALS_SUFFIX,
CREATE_DATASET_SUFFIX,
DATASET_SUFFIX,
GET_USER_PROFILE,
SEND_EVENT_SUFFIX,
UPDATE_SUFFIX,
GET_PRESIGNED_URL_SUFFIX,
CONNECT_DATASET_SUFFIX,
REMOTE_QUERY_SUFFIX,
ORG_PERMISSION_SUFFIX,
DEEPLAKE_AUTH_TOKEN,
)
from deeplake.client.log import logger
import jwt # should add it to requirements.txt
# HTTP status codes for which `DeepLakeBackendClient.request` re-sends a request.
# Each request is attempted at most 3 times in total (i.e. up to 2 retries).
retry_status_codes = {502}
class DeepLakeBackendClient:
    """Communicates with Activeloop Backend.

    Every backend call goes through :meth:`request`, which targets the URL
    returned by :meth:`endpoint`, attaches the client version and
    ``Authorization`` headers, re-sends responses whose status code is in
    ``retry_status_codes`` and validates the final response with
    ``check_response_status``.
    """

    def __init__(self, token: Optional[str] = None):
        """Creates a client and resolves the auth token.

        Args:
            token (str, optional): Activeloop auth token. When not given, the
                environment variable named by ``DEEPLAKE_AUTH_TOKEN`` is used,
                falling back to a public placeholder token.

        Note:
            The constructor performs network requests: it fetches the user's
            organizations and, for a non-public user, the user profile, whose
            ``name`` is saved to the reporting config when it changed.
        """
        from deeplake.util.bugout_reporter import (
            save_reporting_config,
            get_reporting_config,
            set_username,
        )

        self.version = deeplake.__version__
        self.auth_header = None
        self.token = (
            token
            or os.environ.get(DEEPLAKE_AUTH_TOKEN)
            or "PUBLIC_TOKEN_" + ("_" * 150)
        )
        self.auth_header = f"Bearer {self.token}"

        # remove public token, otherwise env var will be ignored
        # we can remove this after a while
        orgs = self.get_user_organizations()
        if orgs == ["public"]:
            self.token = token or self.get_token()
            self.auth_header = f"Bearer {self.token}"
        else:
            username = self.get_user_profile()["name"]
            if get_reporting_config().get("username") != username:
                save_reporting_config(True, username=username)
                set_username(username)

    def get_token(self):
        """Returns the auth token currently used by this client."""
        return self.token

    def request(
        self,
        method: str,
        relative_url: str,
        endpoint: Optional[str] = None,
        params: Optional[dict] = None,
        data: Optional[dict] = None,
        files: Optional[dict] = None,
        json: Optional[dict] = None,
        headers: Optional[dict] = None,
        timeout: Optional[int] = DEFAULT_REQUEST_TIMEOUT,
    ):
        """Sends a request to the backend.

        Responses whose status code is in ``retry_status_codes`` are re-sent;
        each request is attempted at most 3 times in total. The last response
        is validated with ``check_response_status`` before being returned.

        Args:
            method (str): The method for sending the request. Should be one of 'GET', 'OPTIONS', 'HEAD', 'POST', 'PUT',
                'PATCH', or 'DELETE'.
            relative_url (str): The suffix to be appended to the end of the endpoint url.
            endpoint(str, optional): The endpoint to send the request to. Defaults to :meth:`endpoint`.
            params (dict, optional): Dictionary to send in the query string for the request.
            data (dict, optional): Dictionary to send in the body of the request.
            files (dict, optional): Dictionary of 'name': file-like-objects (or {'name': file-tuple}) for multipart
                encoding upload.
                file-tuple can be a 2-tuple (filename, fileobj), 3-tuple (filename, fileobj, content_type)
                or a 4-tuple (filename, fileobj, content_type, custom_headers), where 'content-type' is a string
                defining the content type of the given file and 'custom_headers' a dict-like object containing
                additional headers to add for the file.
            json (dict, optional): A JSON serializable Python object to send in the body of the request.
            headers (dict, optional): Dictionary of HTTP Headers to send with the request.
                The caller's dictionary is not modified.
            timeout (int or float, optional): How many seconds to wait for the server to send data before giving up.

        Raises:
            InvalidPasswordException: `password` cannot be `None` inside `json`.

        Returns:
            requests.Response: The response received from the server.
        """
        params = params or {}
        data = data or None
        files = files or None
        json = json or None

        endpoint = (endpoint or self.endpoint()).strip("/")
        relative_url = relative_url.strip("/")
        request_url = f"{endpoint}/{relative_url}"

        # Work on a copy: the caller's dict must not pick up the headers below.
        headers = dict(headers) if headers else {}
        headers["hub-cli-version"] = self.version
        headers["Authorization"] = self.auth_header

        # clearer error than `ServerUnderMaintenence`
        if json is not None and "password" in json and json["password"] is None:
            # do NOT pass in the password here. `None` is explicitly typed.
            raise InvalidPasswordException("Password cannot be `None`.")

        max_attempts = 3
        for _ in range(max_attempts):
            response = requests.request(
                method,
                request_url,
                params=params,
                data=data,
                json=json,
                headers=headers,
                files=files,
                timeout=timeout,
            )
            if response.status_code not in retry_status_codes:
                break

        check_response_status(response)
        return response

    def endpoint(self):
        """Returns the backend base URL selected by the flags in ``deeplake.client.config``.

        ``USE_LOCAL_HOST`` takes precedence over ``USE_DEV_ENVIRONMENT``, which
        takes precedence over ``USE_STAGING_ENVIRONMENT``; otherwise the
        production endpoint is returned.
        """
        if deeplake.client.config.USE_LOCAL_HOST:
            return HUB_REST_ENDPOINT_LOCAL
        if deeplake.client.config.USE_DEV_ENVIRONMENT:
            return HUB_REST_ENDPOINT_DEV
        if deeplake.client.config.USE_STAGING_ENVIRONMENT:
            return HUB_REST_ENDPOINT_STAGING
        return HUB_REST_ENDPOINT

    def request_auth_token(self, username: str, password: str):
        """Sends a request to backend to retrieve auth token.

        Args:
            username (str): The Activeloop username to request token for.
            password (str): The password of the account.

        Returns:
            string: The auth token corresponding to the account.

        Raises:
            UserNotLoggedInException: if user is not authorised
            LoginException: If there is an issue retrieving the auth token.
        """
        json = {"username": username, "password": password}
        response = self.request("POST", GET_TOKEN_SUFFIX, json=json)
        try:
            token_dict = response.json()
            token = token_dict["token"]
        except Exception as e:
            # Body is not JSON, or lacks a "token" field.
            raise LoginException() from e
        return token

    def send_register_request(self, username: str, email: str, password: str):
        """Sends a request to backend to register a new user.

        Args:
            username (str): The Activeloop username to create account for.
            email (str): The email id to link with the Activeloop account.
            password (str): The new password of the account. Should be at least 6 characters long.
        """
        json = {"username": username, "email": email, "password": password}
        self.request("POST", REGISTER_USER_SUFFIX, json=json)

    def get_dataset_credentials(
        self,
        org_id: str,
        ds_name: str,
        mode: Optional[str] = None,
        db_engine: Optional[dict] = None,
        no_cache: bool = False,
    ):
        """Retrieves temporary 12 hour credentials for the required dataset from the backend.

        Args:
            org_id (str): The name of the user/organization to which the dataset belongs.
            ds_name (str): The name of the dataset being accessed.
            mode (str, optional): The mode in which the user has requested to open the dataset.
                If not provided, the backend will set mode to 'a' if user has write permission, else 'r'.
            db_engine (dict, optional): The database engine args to use for the dataset.
            no_cache (bool): If True, cached creds are ignored and new creds are returned. Default False.

        Returns:
            tuple: ``(full_url, creds, mode, expiration, repository)`` — the full url to the dataset,
            the credentials, the mode, the credentials' expiration (``None`` when there are no creds)
            and the repository.

        Raises:
            UserNotLoggedInException: When user is not authenticated
            InvalidTokenException: If the specified token is invalid
            TokenPermissionError: when there are permission or other errors related to token
            AgreementNotAcceptedError: when user has not accepted the agreement
            NotLoggedInAgreementError: when user is not authenticated and dataset has agreement which needs to be signed
        """
        import json

        db_engine = db_engine or {}
        relative_url = GET_DATASET_CREDENTIALS_SUFFIX.format(org_id, ds_name)
        try:
            response = self.request(
                "GET",
                relative_url,
                endpoint=self.endpoint(),
                params={
                    "mode": mode,
                    "no_cache": no_cache,
                    "db_engine": json.dumps(db_engine),
                },
            ).json()
        except AuthorizationException as e:
            # Translate the backend's authorization failure into a specific error.
            authorization_exception_prompt = "You don't have permission"
            response_data = e.response.json()
            code = response_data.get("code")
            if code == 1:
                agreements = [
                    agreement["text"] for agreement in response_data["agreements"]
                ]
                raise AgreementNotAcceptedError(agreements) from e
            if code == 2:
                raise NotLoggedInAgreementError from e

            try:
                decoded_token = jwt.decode(
                    self.token, options={"verify_signature": False}
                )
            except Exception:
                raise InvalidTokenException
            # .get(): a missing field must not mask the intended error with a KeyError.
            description = response_data.get("description") or ""
            if (
                authorization_exception_prompt.lower() in description.lower()
                and decoded_token.get("id") == "public"
            ):
                raise UserNotLoggedInException()
            raise TokenPermissionError()

        full_url = response.get("path")
        repository = response.get("repository")
        creds = response["creds"]
        mode = response["mode"]
        expiration = creds["expiration"] if creds else None
        return full_url, creds, mode, expiration, repository

    def send_event(self, event_json: dict):
        """Sends an event to the backend.

        Args:
            event_json (dict): The event to be sent.
        """
        self.request("POST", SEND_EVENT_SUFFIX, json=event_json)

    def create_dataset_entry(
        self, username, dataset_name, meta, public=True, repository=None
    ):
        """Creates (or rewrites) the backend entry for ``username/dataset_name``.

        Args:
            username (str): The user/organization owning the dataset.
            dataset_name (str): The name of the dataset.
            meta: Dataset metadata sent as-is in the request body.
            public (bool): Whether the dataset is public. Default True.
            repository (str, optional): Repository for the entry. Defaults to ``protected/<username>``.
        """
        tag = f"{username}/{dataset_name}"
        if repository is None:
            repository = f"protected/{username}"

        response = self.request(
            "POST",
            CREATE_DATASET_SUFFIX,
            json={
                "tag": tag,
                "public": public,
                "rewrite": True,
                "meta": meta,
                "repository": repository,
            },
            endpoint=self.endpoint(),
        )

        if response.status_code == 200:
            logger.info("Your Deep Lake dataset has been successfully created!")

    def get_managed_creds(self, org_id, creds_key):
        """Retrieves the managed credentials for the given org_id and creds_key.

        Backend key names are renamed to the names used by the storage layer
        (e.g. ``access_key`` -> ``aws_access_key_id``); an ``access_token`` is
        turned into an ``Authorization: Bearer <token>`` entry. Unknown keys
        are kept unchanged.

        Args:
            org_id (str): The name of the user/organization to which the dataset belongs.
            creds_key (str): The key corresponding to the managed credentials.

        Returns:
            dict: The managed credentials.

        Raises:
            ManagedCredentialsNotFoundError: If the managed credentials do not exist for the given organization.
        """
        relative_url = GET_MANAGED_CREDS_SUFFIX.format(org_id)
        try:
            resp = self.request(
                "GET",
                relative_url,
                endpoint=self.endpoint(),
                params={"query": creds_key},
            ).json()
        except ResourceNotFoundException:
            raise ManagedCredentialsNotFoundError(org_id, creds_key) from None
        creds = resp["creds"]
        key_mapping = {
            "access_key": "aws_access_key_id",
            "secret_key": "aws_secret_access_key",
            "session_token": "aws_session_token",
            "token": "aws_session_token",
            "region": "aws_region",
        }
        final_creds = {}
        for key, value in creds.items():
            if key == "access_token":
                key = "Authorization"
                value = f"Bearer {value}"
            elif key in key_mapping:
                key = key_mapping[key]
            final_creds[key] = value
        return final_creds

    def delete_dataset_entry(self, username, dataset_name):
        """Deletes the backend entry for ``username/dataset_name``."""
        tag = f"{username}/{dataset_name}"
        suffix = f"{DATASET_SUFFIX}/{tag}"
        self.request(
            "DELETE",
            suffix,
            endpoint=self.endpoint(),
        ).json()

    def accept_agreements(self, org_id, ds_name):
        """Accepts the agreements for the given org_id and ds_name.

        Args:
            org_id (str): The name of the user/organization to which the dataset belongs.
            ds_name (str): The name of the dataset being accessed.
        """
        relative_url = ACCEPT_AGREEMENTS_SUFFIX.format(org_id, ds_name)
        self.request(
            "POST",
            relative_url,
            endpoint=self.endpoint(),
        ).json()

    def reject_agreements(self, org_id, ds_name):
        """Rejects the agreements for the given org_id and ds_name.

        Args:
            org_id (str): The name of the user/organization to which the dataset belongs.
            ds_name (str): The name of the dataset being accessed.
        """
        relative_url = REJECT_AGREEMENTS_SUFFIX.format(org_id, ds_name)
        self.request(
            "POST",
            relative_url,
            endpoint=self.endpoint(),
        ).json()

    def rename_dataset_entry(self, username, old_name, new_name):
        """Renames the backend entry ``username/old_name`` to ``new_name``."""
        suffix = UPDATE_SUFFIX.format(username, old_name)
        self.request(
            "PUT", suffix, endpoint=self.endpoint(), json={"basename": new_name}
        )

    def get_user_organizations(self):
        """Get list of user organizations from the backend. If user is not authenticated, returns ['public'].

        Returns:
            list: user/organization names
        """
        response = self.request(
            "GET", GET_USER_PROFILE, endpoint=self.endpoint()
        ).json()
        return response["organizations"]

    def get_workspace_datasets(
        self, workspace: str, suffix_public: str, suffix_user: str
    ):
        """Lists the datasets of ``workspace``.

        If the user is a member of ``workspace`` the ``suffix_user`` route is
        queried; otherwise a notice is printed and the ``suffix_public`` route
        is queried instead.

        Args:
            workspace (str): The organization whose datasets are listed.
            suffix_public (str): Route used for non-members.
            suffix_user (str): Route used for members.

        Returns:
            The decoded JSON response.
        """
        organizations = self.get_user_organizations()
        if workspace in organizations:
            response = self.request(
                "GET",
                suffix_user,
                endpoint=self.endpoint(),
                params={"organization": workspace},
            ).json()
        else:
            print(
                f'You are not a member of organization "{workspace}". List of accessible datasets from "{workspace}": ',
            )
            response = self.request(
                "GET",
                suffix_public,
                endpoint=self.endpoint(),
                params={"organization": workspace},
            ).json()
        return response

    def update_privacy(self, username: str, dataset_name: str, public: bool):
        """Sets whether the dataset ``username/dataset_name`` is public."""
        suffix = UPDATE_SUFFIX.format(username, dataset_name)
        self.request("PUT", suffix, endpoint=self.endpoint(), json={"public": public})

    def get_presigned_url(self, org_id, ds_id, chunk_path, expiration=3600):
        """Requests a presigned URL for ``chunk_path`` of dataset ``ds_id``.

        Args:
            org_id (str): The organization owning the dataset.
            ds_id (str): The dataset id.
            chunk_path (str): Path of the chunk to sign.
            expiration (int): Forwarded to the backend as-is (presumably seconds — TODO confirm). Default 3600.

        Returns:
            The ``data`` field of the JSON response.
        """
        relative_url = GET_PRESIGNED_URL_SUFFIX.format(org_id, ds_id)
        response = self.request(
            "GET",
            relative_url,
            endpoint=self.endpoint(),
            params={"chunk_path": chunk_path, "expiration": expiration},
        ).json()
        presigned_url = response["data"]
        return presigned_url

    def get_user_profile(self):
        """Returns the authenticated user's profile as decoded JSON."""
        response = self.request(
            "GET",
            "/api/user/profile",
            endpoint=self.endpoint(),
        )
        return response.json()

    def connect_dataset_entry(
        self,
        src_path: str,
        org_id: str,
        ds_name: Optional[str] = None,
        creds_key: Optional[str] = None,
    ) -> str:
        """Creates a new dataset entry that can be accessed with a hub path, but points to the original ``src_path``.

        Args:
            src_path (str): The path at which the source dataset resides.
            org_id (str): The organization into which the dataset entry is put and where the credentials are searched.
            ds_name (Optional[str]): Name of the dataset entry. Can be inferred from the source path.
            creds_key (Optional[str]): Name of the managed credentials that will be used to access the source path.

        Returns:
            str: The id of the dataset entry that was created.
        """
        response = self.request(
            "POST",
            CONNECT_DATASET_SUFFIX,
            json={
                "src_path": src_path,
                "org_id": org_id,
                "ds_name": ds_name,
                "creds_key": creds_key,
            },
            endpoint=self.endpoint(),
        ).json()

        return response["generated_id"]

    def remote_query(
        self, org_id: str, ds_name: str, query_string: str
    ) -> Dict[str, Any]:
        """Queries a remote dataset.

        Args:
            org_id (str): The organization to which the dataset belongs.
            ds_name (str): The name of the dataset.
            query_string (str): The query string.

        Returns:
            Dict[str, Any]: The json response containing matching indices and data from virtual tensors.
        """
        response = self.request(
            "POST",
            REMOTE_QUERY_SUFFIX.format(org_id, ds_name),
            json={"query": query_string},
            endpoint=self.endpoint(),
        ).json()

        return response

    def has_indra_org_permission(self, org_id: str) -> Dict[str, Any]:
        """Retrieves the organization's permissions from the ``ORG_PERMISSION_SUFFIX`` route.

        Args:
            org_id (str): The organization whose permissions are requested.

        Returns:
            Dict[str, Any]: The json response containing org permissions.
        """
        response = self.request(
            "GET",
            ORG_PERMISSION_SUFFIX.format(org_id),
            endpoint=self.endpoint(),
        ).json()

        return response
class DeepMemoryBackendClient(DeepLakeBackendClient):
    """Backend client for the DeepMemory endpoints (availability check and training jobs)."""

    def __init__(self, token: Optional[str] = None):
        """Creates the client; see ``DeepLakeBackendClient.__init__`` for token resolution."""
        super().__init__(token=token)

    def deepmemory_is_available(self, org_id: str):
        """Checks if DeepMemory is available for the user.

        Args:
            org_id (str): The name of the user/organization to which the dataset belongs.

        Returns:
            bool: True if DeepMemory is available, False otherwise (including on any request error).
        """
        try:
            response = self.request(
                "GET",
                f"/api/organizations/{org_id}/features/deepmemory",
                endpoint=self.endpoint(),
            )
            return response.json()["available"]
        except Exception:
            # Best-effort probe: any failure means "not available".
            return False

    def start_taining(
        self,
        corpus_path: str,
        queries_path: str,
    ) -> Dict[str, Any]:
        """Starts training of DeepMemory model.

        Note:
            The method name is misspelled ("taining"); it is kept as is for
            backward compatibility with existing callers.

        Args:
            corpus_path (str): The path to the corpus dataset.
            queries_path (str): The path to the queries dataset.

        Returns:
            Dict[str, Any]: The json response containing job_id.
        """
        response = self.request(
            method="POST",
            relative_url="/api/deepmemory/v1/train",
            json={"corpus_dataset": corpus_path, "query_dataset": queries_path},
        )
        check_response_status(response)
        return response.json()

    def cancel_job(self, job_id: str):
        """Cancels a job with job_id.

        Errors are printed rather than raised.

        Args:
            job_id (str): The job_id of the job to be cancelled.

        Returns:
            bool: True if job was cancelled successfully, False otherwise.
        """
        try:
            response = self.request(
                method="POST",
                relative_url=f"/api/deepmemory/v1/jobs/{job_id}/cancel",
            )
            check_response_status(response)
        except Exception as e:
            print(f"Job with job_id='{job_id}' was not cancelled!\n Error: {e}")
            return False
        print("Job cancelled successfully")
        return True

    def check_status(self, job_id: str, recall: str, improvement: str):
        """Checks status of a job with job_id and prints it.

        Args:
            job_id (str): The job_id of the job to be checked.
            recall (str): Current best top 10 recall
            improvement (str): Current best improvement over baseline

        Returns:
            Dict[str, Any]: The json response containing job status.
        """
        response = self.request(
            method="GET",
            relative_url=f"/api/deepmemory/v1/jobs/{job_id}/status",
        )
        check_response_status(response)
        response_status_schema = JobResponseStatusSchema(response=response.json())
        response_status_schema.print_status(job_id, recall, improvement)
        return response.json()

    def list_jobs(self, dataset_path: str):
        """Lists all jobs for a dataset.

        Args:
            dataset_path (str): The path to the dataset, e.g. ``hub://org/ds``.
                A path without the ``hub://`` prefix (``org/ds``) is also accepted.

        Returns:
            Dict[str, Any]: The json response containing list of jobs.
        """
        # Strip the prefix only when present; blindly slicing off 6 characters
        # corrupted any path not starting with "hub://".
        hub_prefix = "hub://"
        if dataset_path.startswith(hub_prefix):
            dataset_id = dataset_path[len(hub_prefix) :]
        else:
            dataset_id = dataset_path
        response = self.request(
            method="GET",
            relative_url=f"/api/deepmemory/v1/{dataset_id}/jobs",
        )
        check_response_status(response)
        return response.json()

    def delete_job(self, job_id: str):
        """Deletes a job with job_id.

        Errors are printed rather than raised.

        Args:
            job_id (str): The job_id of the job to be deleted.

        Returns:
            bool: True if job was deleted successfully, False otherwise.
        """
        try:
            response = self.request(
                method="DELETE",
                relative_url=f"/api/deepmemory/v1/jobs/{job_id}",
            )
            check_response_status(response)
            return True
        except Exception as e:
            print(f"Job with job_id='{job_id}' was not deleted!\n Error: {e}")
            return False