Deep memory fixes (#2662)
* fix

* eval fix

* mypy fixes

* changes

* test fixes

* sonar cloud fix

* sonar fixes

* fix

---------

Co-authored-by: adolkhan <adilkhan.sarsen@alumni.nu.edu.kz>
adolkhan and adolkhan committed Oct 24, 2023
1 parent 16f051c commit 008dbe8
Showing 12 changed files with 528 additions and 303 deletions.
4 changes: 2 additions & 2 deletions deeplake/client/test_client.py
@@ -168,7 +168,7 @@ def test_deepmemory_print_status_and_list_jobs(capsys, precomputed_jobs_list):
         progress=None,
     )
     response_schema = JobResponseStatusSchema(response=pending_response)
-    response_schema.print_status(job_id)
+    response_schema.print_status(job_id, recall=None, importvement=None)
     captured = capsys.readouterr()
     assert captured.out == Status.pending

@@ -204,7 +204,7 @@ def test_deepmemory_print_status_and_list_jobs(capsys, precomputed_jobs_list):
         },
     )
     response_schema = JobResponseStatusSchema(response=failed_response)
-    response_schema.print_status(job_id)
+    response_schema.print_status(job_id, recall=None, importvement=None)
     captured = capsys.readouterr()
     assert captured.out == Status.failed

21 changes: 8 additions & 13 deletions deeplake/client/utils.py
@@ -144,8 +144,8 @@ def validate_status_response(self):
     def print_status(
         self,
         job_id: Union[str, List[str]],
-        recall: Optional[str] = None,
-        importvement: Optional[str] = None,
+        recall: str,
+        importvement: str,
     ):
         if not isinstance(job_id, List):
             job_id = [job_id]
@@ -157,8 +157,8 @@ def print_status(

         if response["status"] == "completed":
             response["results"] = get_results(
-                response,
-                " " * 30,
+                response=response,
+                indent=" " * 30,
                 add_vertical_bars=True,
                 recall=recall,
                 improvement=importvement,
@@ -217,8 +217,8 @@ def print_jobs(
             )
             if response_status == "completed":
                 response_results = get_results(
-                    response,
-                    "",
+                    response=response,
+                    indent="",
                     add_vertical_bars=False,
                     width=15,
                     recall=recalls[response_id],
@@ -271,20 +271,15 @@ def print_jobs(

 def get_results(
     response: Dict[str, Any],
+    improvement: str,
+    recall: str,
     indent: str,
     add_vertical_bars: bool,
     width: int = 21,
-    improvement: Optional[str] = None,
-    recall: Optional[str] = None,
 ):
     progress = response["progress"]
     for progress_key, progress_value in progress.items():
         if progress_key == BEST_RECALL:
             curr_recall, curr_improvement = progress_value.split("%")[:2]

-            recall = recall or curr_recall
-            improvement = improvement or curr_improvement
-
             if "(" not in improvement:
                 improvement = f"(+{improvement}%)"
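The split("%") parsing retained in get_results works like the sketch below, assuming a progress value shaped like "90.5% (+2.1%)" (that payload shape is inferred from this diff, not confirmed elsewhere):

# Hypothetical progress value; the shape is an assumption based on the hunk above.
progress_value = "90.5% (+2.1%)"
curr_recall, curr_improvement = progress_value.split("%")[:2]
# curr_recall == "90.5", curr_improvement == " (+2.1"
improvement = curr_improvement.strip()
if "(" not in improvement:  # already parenthesized here, so left unchanged
    improvement = f"(+{improvement}%)"
print(curr_recall, improvement)  # 90.5 (+2.1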
1 change: 1 addition & 0 deletions deeplake/constants.py
@@ -328,3 +328,4 @@
         "M": 32,
     },
 }
+VECTORSTORE_EXTEND_BATCH_SIZE = 500
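A hedged guess at the intent of this constant: capping how many rows a single vector store extend call sends at once. The chunking helper below is illustrative only and is not part of this commit:

from typing import Iterator, Sequence

VECTORSTORE_EXTEND_BATCH_SIZE = 500  # the value added above

def batched(rows: Sequence, size: int = VECTORSTORE_EXTEND_BATCH_SIZE) -> Iterator[Sequence]:
    # Yield fixed-size slices so each extend call stays under the cap.
    for start in range(0, len(rows), size):
        yield rows[start : start + size]

for chunk in batched(list(range(1200))):
    print(len(chunk))  # 500, 500, 200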
144 changes: 87 additions & 57 deletions deeplake/core/vectorstore/deep_memory.py
@@ -1,20 +1,30 @@
 import uuid
+from collections import defaultdict
 from typing import Any, Dict, Optional, List, Union, Callable, Tuple
 from time import time

 import numpy as np

 import deeplake
 from deeplake.enterprise.dataloader import indra_available
-from deeplake.constants import DEFAULT_QUERIES_VECTORSTORE_TENSORS
-from deeplake.util.remove_cache import get_base_storage
+from deeplake.constants import (
+    DEFAULT_QUERIES_VECTORSTORE_TENSORS,
+    DEFAULT_MEMORY_CACHE_SIZE,
+    DEFAULT_LOCAL_CACHE_SIZE,
+)
+from deeplake.util.storage import get_storage_and_cache_chain
 from deeplake.core.dataset import Dataset
+from deeplake.core.dataset.deeplake_cloud_dataset import DeepLakeCloudDataset
 from deeplake.core.vectorstore.deeplake_vectorstore import VectorStore
 from deeplake.client.client import DeepMemoryBackendClient
 from deeplake.client.utils import JobResponseStatusSchema
 from deeplake.util.bugout_reporter import (
     feature_report_path,
 )
+from deeplake.util.dataset import try_flushing
 from deeplake.util.path import get_path_type
 from deeplake.util.version_control import load_meta


 class DeepMemory:
@@ -114,7 +124,6 @@ def train(
             path=queries_path,
             overwrite=True,
             runtime=runtime,
-            embedding_function=embedding_function,
             token=token or self.token,
             creds=self.creds,
         )
@@ -125,6 +134,7 @@ def train(
                 {"relevance": relevance_per_doc} for relevance_per_doc in relevance
             ],
             embedding_data=[query for query in queries],
+            embedding_function=embedding_function,
         )

         # do some rest_api calls to train the model
@@ -206,9 +216,22 @@ def status(self, job_id: str):
             },
             token=self.token,
         )
+
+        _, storage = get_storage_and_cache_chain(
+            path=self.dataset.path,
+            db_engine={"tensor_db": True},
+            read_only=False,
+            creds=self.creds,
+            token=self.dataset.token,
+            memory_cache_size=DEFAULT_MEMORY_CACHE_SIZE,
+            local_cache_size=DEFAULT_LOCAL_CACHE_SIZE,
+        )
+
+        loaded_dataset = DeepLakeCloudDataset(storage=storage)
+
         try:
             recall, improvement = _get_best_model(
-                self.dataset.embedding, job_id, latest_job=True
+                loaded_dataset.embedding, job_id, latest_job=True
             )

             recall = "{:.2f}".format(100 * recall)
@@ -228,6 +251,17 @@ def list_jobs(self, debug=False):
             },
             token=self.token,
         )
+        _, storage = get_storage_and_cache_chain(
+            path=self.dataset.path,
+            db_engine={"tensor_db": True},
+            read_only=False,
+            creds=self.creds,
+            token=self.dataset.token,
+            memory_cache_size=DEFAULT_MEMORY_CACHE_SIZE,
+            local_cache_size=DEFAULT_LOCAL_CACHE_SIZE,
+        )
+        loaded_dataset = DeepLakeCloudDataset(storage=storage)
+
         response = self.client.list_jobs(
             dataset_path=self.dataset.path,
         )
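Both status and list_jobs now rebuild the cloud dataset through a fresh storage/cache chain before reading embedding info, instead of touching self.dataset directly. The shared pattern, factored into a standalone helper for illustration (the helper itself is not in this commit; the arguments mirror the hunks above):

from deeplake.constants import DEFAULT_LOCAL_CACHE_SIZE, DEFAULT_MEMORY_CACHE_SIZE
from deeplake.core.dataset.deeplake_cloud_dataset import DeepLakeCloudDataset
from deeplake.util.storage import get_storage_and_cache_chain

def load_cloud_dataset(path, creds, token):
    # Build the same storage/cache chain the two hunks above construct inline.
    _, storage = get_storage_and_cache_chain(
        path=path,
        db_engine={"tensor_db": True},
        read_only=False,
        creds=creds,
        token=token,
        memory_cache_size=DEFAULT_MEMORY_CACHE_SIZE,
        local_cache_size=DEFAULT_LOCAL_CACHE_SIZE,
    )
    return DeepLakeCloudDataset(storage=storage)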
@@ -243,7 +277,7 @@ def list_jobs(self, debug=False):
         for job in jobs:
             try:
                 recall, delta = _get_best_model(
-                    self.dataset.embedding,
+                    loaded_dataset.embedding,
                     job,
                     latest_job=job == latest_job,
                 )
@@ -352,6 +386,7 @@ def evaluate(
             },
             token=self.token,
         )
+        try_flushing(self.dataset)
         try:
             from indra import api  # type: ignore

@@ -373,9 +408,10 @@ def evaluate(

         start = time()
         query_embs: Union[List[np.ndarray], List[List[float]]]
+
         if embedding is not None:
             query_embs = embedding
-        elif embedding is None:
+        else:
             if self.embedding_function is not None:
                 embedding_function = (
                     embedding_function or self.embedding_function.embed_documents
@@ -404,26 +440,20 @@ def evaluate(
         ]:
             eval_type = "with" if use_model else "without"
             print(f"---- Evaluating {eval_type} model ---- ")
-            callect_data = False
-            for k in top_k:
-                callect_data = k == 10
-
-                recall, queries_dict = recall_at_k(
-                    self.dataset,
-                    indra_dataset,
-                    relevance,
-                    top_k=k,
-                    query_embs=query_embs,
-                    metric=metric,
-                    collect_data=callect_data,
-                    use_model=use_model,
-                )
+            avg_recalls, queries_dict = recall_at_k(
+                indra_dataset,
+                relevance,
+                top_k=top_k,
+                query_embs=query_embs,
+                metric=metric,
+                use_model=use_model,
+            )

-                if callect_data:
-                    queries_data.update(queries_dict)
+            queries_data.update(queries_dict)

-                print(f"Recall@{k}:\t {100*recall: .1f}%")
-                recalls[f"{eval_type} model"][f"recall@{k}"] = recall
+            for recall, recall_value in avg_recalls.items():
+                print(f"Recall@{recall}:\t {100*recall_value: .1f}%")
+                recalls[f"{eval_type} model"][f"recall@{recall}"] = recall_value

             log_queries = parsed_qvs_params.get("log_queries")
             branch = parsed_qvs_params.get("branch")
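With this refactor, recall_at_k evaluates every k in one pass and hands back a dict of per-k averages, so the caller only formats results. A toy sketch of consuming that return shape (the numbers are invented):

avg_recalls = {"1": 0.42, "3": 0.55, "10": 0.61}  # hypothetical recall_at_k output
model_recalls = {}
for k, value in avg_recalls.items():
    print(f"Recall@{k}:\t {100 * value: .1f}%")
    model_recalls[f"recall@{k}"] = value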
@@ -454,16 +484,14 @@


 def recall_at_k(
-    dataset: Dataset,
     indra_dataset: Any,
     relevance: List[List[Tuple[str, int]]],
     query_embs: Union[List[np.ndarray], List[List[float]]],
     metric: str,
-    top_k: int = 10,
-    collect_data: bool = False,
+    top_k: List[int] = [1, 3, 5, 10, 50, 100],
     use_model: bool = False,
 ):
-    recalls = []
+    recalls = defaultdict(list)
     top_k_list = []

     for query_idx, _ in enumerate(query_embs):
@@ -473,54 +501,56 @@ def recall_at_k(
         correct_labels = [rel[0] for rel in query_relevance]

         # Compute the cosine similarity between the query and all data points
-        view_top_k = get_view_top_k(
+        view = get_view(
             metric=metric,
             query_emb=query_emb,
-            top_k=top_k,
             indra_dataset=indra_dataset,
         )

-        top_k_retrieved = [
-            sample.id.numpy() for sample in view_top_k
-        ]  # TODO: optimize this
-
-        # Compute the recall: the fraction of relevant items found in the top k
-        num_relevant_in_top_k = len(
-            set(correct_labels).intersection(set(top_k_retrieved))
-        )
-        if len(correct_labels) == 0:
-            continue
-        recall = num_relevant_in_top_k / len(correct_labels)
-
-        if collect_data:
-            top_k_list.append(top_k_retrieved)
-        recalls.append(recall)
-
-    # Average the recalls for each query
-    avg_recall = np.mean(np.array(recalls))
-    queries_data = {}
-    if collect_data:
-        model_type = "deep_memory" if use_model else "vector_search"
-        queries_data = {
-            f"{model_type}_top_10": top_k_list,
-            f"{model_type}_recall": recalls,
-        }
-    return avg_recall, queries_data
+        for k in top_k:
+            collect_data = k == 10
+            view_top_k = view[:k]
+
+            top_k_retrieved = [
+                sample.id.numpy() for sample in view_top_k
+            ]  # TODO: optimize this
+
+            # Compute the recall: the fraction of relevant items found in the top k
+            num_relevant_in_top_k = len(
+                set(correct_labels).intersection(set(top_k_retrieved))
+            )
+            if len(correct_labels) == 0:
+                continue
+            recall = num_relevant_in_top_k / len(correct_labels)
+
+            if collect_data:
+                top_k_list.append(top_k_retrieved)
+            recalls[k].append(recall)
+
+    # Average the recalls for each query
+    avg_recalls = {
+        f"{recall}": np.mean(np.array(recall_list))
+        for recall, recall_list in recalls.items()
+    }
+    model_type = "deep_memory" if use_model else "vector_search"
+    queries_data = {
+        f"{model_type}_top_10": top_k_list,
+        f"{model_type}_recall": recalls[10],
+    }
+    return avg_recalls, queries_data


-def get_view_top_k(
+def get_view(
     metric: str,
     query_emb: Union[List[float], np.ndarray],
-    top_k: int,
     indra_dataset: Any,
     return_tensors: List[str] = ["text", "metadata", "id"],
     tql_filter: str = "",
 ):
     tql_filter_str = tql_filter if tql_filter == "" else " where " + tql_filter
     query_emb_str = ",".join([f"{q}" for q in query_emb])
     return_tensors_str = ", ".join(return_tensors)
-    tql = f"SELECT * FROM (SELECT {return_tensors_str}, ROW_NUMBER() as indices, {metric}(embedding, ARRAY[{query_emb_str}]) as score {tql_filter_str} order by {metric}(embedding, ARRAY[{query_emb_str}]) desc limit {top_k})"
+    tql = f"SELECT * FROM (SELECT {return_tensors_str}, ROW_NUMBER() as indices, {metric}(embedding, ARRAY[{query_emb_str}]) as score {tql_filter_str} order by {metric}(embedding, ARRAY[{query_emb_str}]) desc limit 100)"
     indra_view = indra_dataset.query(tql)
     return indra_view
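For reference, the recall@k computed per query above is the fraction of that query's relevant ids that appear among the top k retrieved ids. A self-contained toy example of the same math (the data is invented):

def single_query_recall(relevant: set, ranked_ids: list, k: int) -> float:
    # Fraction of relevant items that made it into the top k.
    return len(relevant.intersection(ranked_ids[:k])) / len(relevant)

ranked = ["d3", "d7", "d1", "d9"]
print(single_query_recall({"d1", "d2"}, ranked, 3))  # 0.5 -> only d1 is in the top 3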
