/
search_algorithm.py
214 lines (187 loc) · 7.19 KB
/
search_algorithm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
import numpy as np
from abc import ABC, abstractmethod
from typing import Union, Dict, List, Optional
from deeplake.enterprise.util import INDRA_INSTALLED
from deeplake.core.vectorstore.vector_search.indra import query
from deeplake.core.vectorstore.vector_search import utils
from deeplake.core.dataset import Dataset as DeepLakeDataset
from deeplake.core.dataset.indra_dataset_view import IndraDatasetView
class SearchBasic(ABC):
    def __init__(
        self,
        deeplake_dataset: DeepLakeDataset,
        org_id: Optional[str] = None,
        token: Optional[str] = None,
        runtime: Optional[Dict] = None,
        deep_memory: bool = False,
    ):
        """Abstract base class shared by all search strategies.

        Args:
            deeplake_dataset (DeepLakeDataset): Dataset to run the search against.
            org_id (Optional[str], optional): Organization ID; only needed for local datasets. Defaults to None.
            token (Optional[str], optional): Authentication token. Defaults to None.
            runtime (Optional[Dict], optional): Controls whether the query runs on managed_db or indra. Defaults to None.
            deep_memory (bool): Whether to use DeepMemory for the search. Defaults to False.
        """
        self.deeplake_dataset = deeplake_dataset
        self.org_id = org_id
        self.token = token
        self.runtime = runtime
        self.deep_memory = deep_memory

    def run(
        self,
        tql_string: str,
        return_view: bool,
        return_tql: bool,
        distance_metric: str,
        k: int,
        query_embedding: np.ndarray,
        embedding_tensor: str,
        tql_filter: str,
        return_tensors: List[str],
    ):
        """Build the TQL query, execute it, and shape the result.

        Returns either the raw dataset view (``return_view=True``), a dict of
        tensor data, or that dict bundled with the TQL string
        (``return_tql=True``).
        """
        tql_query = self._create_tql_string(
            tql_string,
            distance_metric,
            k,
            query_embedding,
            embedding_tensor,
            tql_filter,
            return_tensors,
        )
        view = self._get_view(tql_query, runtime=self.runtime)

        # Callers may want the raw view instead of materialized data.
        if return_view:
            return view

        return_data = self._collect_return_data(view)
        if not return_tql:
            return return_data
        return {"data": return_data, "tql": tql_query}

    @abstractmethod
    def _collect_return_data(
        self,
        view: DeepLakeDataset,
    ):
        pass

    @staticmethod
    def _create_tql_string(
        tql_string: str,
        distance_metric: str,
        k: int,
        query_embedding: np.ndarray,
        embedding_tensor: str,
        tql_filter: str,
        return_tensors: List[str],
    ):
        """Return the TQL to execute: a standalone query wins over a generated one."""
        if tql_string:
            return tql_string
        return query.parse_query(
            distance_metric,
            k,
            query_embedding,
            embedding_tensor,
            tql_filter,
            return_tensors,
        )

    @abstractmethod
    def _get_view(self, tql_query: str, runtime: Optional[Dict] = None):
        pass
class SearchIndra(SearchBasic):
    """Search strategy executed through the indra (libdeeplake) engine."""

    def _get_view(self, tql_query, runtime: Optional[Dict] = None):
        # Execute the query on the libdeeplake dataset and wrap the result
        # back into a view the rest of the code can consume.
        indra_ds = self._get_indra_dataset()
        result = indra_ds.query(tql_query)
        view = IndraDatasetView(indra_ds=result)
        view._tql_query = tql_query
        return view

    def _get_indra_dataset(self):
        """Return the libdeeplake dataset, converting on first use."""
        if not INDRA_INSTALLED:
            from deeplake.enterprise.util import raise_indra_installation_error

            raise raise_indra_installation_error(indra_import_error=None)

        # Reuse the already-converted dataset when available.
        cached = self.deeplake_dataset.libdeeplake_dataset
        if cached is not None:
            return cached

        from deeplake.enterprise.convert_to_libdeeplake import (
            dataset_to_libdeeplake,
        )

        # Propagate credentials before conversion, if they were provided.
        if self.org_id is not None:
            self.deeplake_dataset.org_id = self.org_id
        if self.token is not None:
            self.deeplake_dataset.set_token(self.token)
        return dataset_to_libdeeplake(self.deeplake_dataset)

    def _collect_return_data(
        self,
        view: DeepLakeDataset,
    ):
        # Materialize every tensor in the view into plain data.
        return {
            tensor: utils.parse_tensor_return(view[tensor])
            for tensor in view.tensors
        }
class SearchManaged(SearchBasic):
    """Search strategy executed by the managed tensor database (db_engine)."""

    def _get_view(self, tql_query, runtime: Optional[Dict] = None):
        # The managed engine hands back the view and its data in one call;
        # stash the data so _collect_return_data can return it later.
        view, data = self.deeplake_dataset.query(
            tql_query, runtime=runtime, return_data=True
        )
        self.data = data
        return view

    def _collect_return_data(
        self,
        view: DeepLakeDataset,
    ):
        return self.data
def search(
    query_embedding: np.ndarray,
    distance_metric: str,
    deeplake_dataset: DeepLakeDataset,
    k: int,
    tql_string: str,
    tql_filter: str,
    embedding_tensor: str,
    runtime: dict,
    return_tensors: List[str],
    return_view: bool = False,
    token: Optional[str] = None,
    org_id: Optional[str] = None,
    return_tql: bool = False,
) -> Union[Dict, DeepLakeDataset]:
    """Generalized indra-backed search combining vector search with TQL filters.

    Args:
        query_embedding (Optional[Union[List[float], np.ndarray): Embedding representation of the query.
        distance_metric (str): Distance metric used to compare the query embedding with dataset embeddings.
        deeplake_dataset (DeepLakeDataset): DeepLake dataset object.
        k (int): Number of samples to return after the search.
        tql_string (str): Standalone TQL query for execution without other filters.
        tql_filter (str): Additional filter using TQL syntax.
        embedding_tensor (str): Name of the tensor in the dataset with `htype = "embedding"`.
        runtime (dict): Runtime parameters for the query.
        return_tensors (List[str]): List of tensors to return data for.
        return_view (bool): Return a Deep Lake dataset view that satisfied the search parameters, instead of a dictionary with data. Defaults to False.
        token (Optional[str], optional): Token used for authentication. Defaults to None.
        org_id (Optional[str], optional): Organization ID, is needed only for local datasets. Defaults to None.
        return_tql (bool): Return the TQL query used for the search. Defaults to False.

    Raises:
        ValueError: If both tql_string and tql_filter are specified.
            (NOTE(review): this validation is not performed here — presumably
            enforced by the caller; confirm.)
        raise_indra_installation_error: If indra is not installed.

    Returns:
        Union[Dict, DeepLakeDataset]: Dictionary mapping tensor names to search
        results, or a Deep Lake dataset view.
    """
    # db_engine in the runtime config routes the query to the managed engine;
    # everything else goes through indra.
    use_managed = bool(runtime) and bool(runtime.get("db_engine", False))

    searcher: SearchBasic
    if use_managed:
        searcher = SearchManaged(deeplake_dataset, org_id, token, runtime=runtime)
    else:
        searcher = SearchIndra(deeplake_dataset, org_id, token)

    return searcher.run(
        tql_string=tql_string,
        return_view=return_view,
        return_tql=return_tql,
        distance_metric=distance_metric,
        k=k,
        query_embedding=query_embedding,
        embedding_tensor=embedding_tensor,
        tql_filter=tql_filter,
        return_tensors=return_tensors,
    )