Skip to content

Commit

Permalink
Deep Memory recall fix (#2698)
Browse files Browse the repository at this point in the history
fixing deep memory recall

---------

Co-authored-by: adolkhan <adilkhan.sarsen@alumni.nu.edu.kz>
  • Loading branch information
adolkhan and adolkhan committed Nov 21, 2023
1 parent 8eb512c commit 6f92e9d
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 8 deletions.
50 changes: 46 additions & 4 deletions deeplake/client/test_client.py
Expand Up @@ -128,6 +128,19 @@ class Status:
"--------------------------------------------------------------\n\n\n"
)

completed_no_improvement = (
"--------------------------------------------------------------\n"
"| 1338464cd80cab681bfcfw23 |\n"
"--------------------------------------------------------------\n"
"| status | completed |\n"
"--------------------------------------------------------------\n"
"| progress | eta: 100.3 seconds |\n"
"| | recall@10: 100.0% (+0.0%) |\n"
"--------------------------------------------------------------\n"
"| results | recall@10: 100.0% (+0.0%) |\n"
"--------------------------------------------------------------\n\n\n"
)

failed = (
"--------------------------------------------------------------\n"
"| 1338464cd80cab681bfcfff3 |\n"
Expand Down Expand Up @@ -168,15 +181,15 @@ def test_deepmemory_print_status_and_list_jobs(capsys, precomputed_jobs_list):
progress=None,
)
response_schema = JobResponseStatusSchema(response=pending_response)
response_schema.print_status(job_id, recall=None, importvement=None)
response_schema.print_status(job_id, recall=None, improvement=None)
captured = capsys.readouterr()
assert captured.out == Status.pending

# for training that is in progress
job_id = "3218464cd80cab681bfcfff3"
training_response = create_response(job_id=job_id)
response_schema = JobResponseStatusSchema(response=training_response)
response_schema.print_status(job_id, recall="85.5", importvement="2.6")
response_schema.print_status(job_id, recall="85.5", improvement="2.6")
captured = capsys.readouterr()
assert captured.out == Status.training

Expand All @@ -187,10 +200,36 @@ def test_deepmemory_print_status_and_list_jobs(capsys, precomputed_jobs_list):
status="completed",
)
response_schema = JobResponseStatusSchema(response=completed_response)
response_schema.print_status(job_id, recall="85.5", importvement="2.6")
response_schema.print_status(job_id, recall="85.5", improvement="2.6")
captured = capsys.readouterr()
assert captured.out == Status.completed

job_id = "1338464cd80cab681bfcfw23"
completed_no_improvement_response = create_response(
job_id=job_id,
status="completed",
progress={
"eta": 100.34,
"last_update_at": "2021-08-31T15:00:00.000000",
"error": None,
"train_recall@10": "87.8%",
"best_recall@10": "100.0% (+0.0)%",
"epoch": 0,
"base_val_recall@10": 0.8292181491851807,
"val_recall@10": "85.5%",
"dataset": "query",
"split": 0,
"loss": -0.05437087118625641,
"delta": 2.572011947631836,
},
)
response_schema = JobResponseStatusSchema(
response=completed_no_improvement_response
)
response_schema.print_status(job_id, recall="0.0", improvement="0.0")
captured = capsys.readouterr()
assert captured.out == Status.completed_no_improvement

# for jobs that failed
job_id = "1338464cd80cab681bfcfff3"
failed_response = create_response(
Expand All @@ -204,7 +243,7 @@ def test_deepmemory_print_status_and_list_jobs(capsys, precomputed_jobs_list):
},
)
response_schema = JobResponseStatusSchema(response=failed_response)
response_schema.print_status(job_id, recall=None, importvement=None)
response_schema.print_status(job_id, recall=None, improvement=None)
captured = capsys.readouterr()
assert captured.out == Status.failed

Expand All @@ -213,18 +252,21 @@ def test_deepmemory_print_status_and_list_jobs(capsys, precomputed_jobs_list):
training_response,
completed_response,
failed_response,
completed_no_improvement_response,
]
recalls = {
"1238464cd80cab681bfcfff3": None,
"3218464cd80cab681bfcfff3": "85.5",
"2138464cd80cab681bfcfff3": "85.5",
"1338464cd80cab681bfcfff3": None,
"1338464cd80cab681bfcfw23": "0.0",
}
improvements = {
"1238464cd80cab681bfcfff3": None,
"3218464cd80cab681bfcfff3": "2.6",
"2138464cd80cab681bfcfff3": "2.6",
"1338464cd80cab681bfcfff3": None,
"1338464cd80cab681bfcfw23": "0.0",
}
response_schema = JobResponseStatusSchema(response=responses)
output_str = response_schema.print_jobs(
Expand Down
8 changes: 5 additions & 3 deletions deeplake/client/utils.py
Expand Up @@ -145,7 +145,7 @@ def print_status(
self,
job_id: Union[str, List[str]],
recall: str,
importvement: str,
improvement: str,
):
if not isinstance(job_id, List):
job_id = [job_id]
Expand All @@ -161,7 +161,7 @@ def print_status(
indent=" " * 30,
add_vertical_bars=True,
recall=recall,
improvement=importvement,
improvement=improvement,
)

print(line)
Expand All @@ -174,7 +174,7 @@ def print_status(
" " * 30,
add_vertical_bars=True,
recall=recall,
improvement=importvement,
improvement=improvement,
)
progress_string = "| {:<27}| {:<30}"
if progress == "None":
Expand Down Expand Up @@ -298,6 +298,8 @@ def get_best_recall_improvement(recall, improvement, best_recall):
elif float(improvement) < float(bimprovement):
return brecall, bimprovement
else:
if brecall > recall:
return brecall, bimprovement
return recall, improvement


Expand Down
Expand Up @@ -5,4 +5,6 @@ ID STATUS RESULTS PROGRESS
2138464cd80cab681bfcfff3 completed recall@10: 85.5% (+2.6%) eta: 100.3 seconds
recall@10: 85.5% (+2.6%)
1338464cd80cab681bfcfff3 failed not available yet eta: None seconds
error: list indices must beintegers or slices,not str
error: list indices must beintegers or slices,not str
1338464cd80cab681bfcfw23 completed recall@10: 100.0% (+0.0%) eta: 100.3 seconds
recall@10: 100.0% (+0.0%)

0 comments on commit 6f92e9d

Please sign in to comment.