Set VectorStore.search embedding argument type (#2538)
adolkhan committed Aug 15, 2023
1 parent 4a01722 commit 453d86a
Showing 3 changed files with 27 additions and 8 deletions.
9 changes: 2 additions & 7 deletions deeplake/core/vectorstore/deeplake_vectorstore.py
@@ -301,7 +301,7 @@ def add(
 
     def search(
         self,
-        embedding_data=None,
+        embedding_data: Union[str, List[str]] = None,
         embedding_function: Optional[Callable] = None,
         embedding: Optional[Union[List[float], np.ndarray]] = None,
         k: int = 4,
@@ -344,7 +344,7 @@ def search(
         Args:
             embedding (Union[np.ndarray, List[float]], optional): Embedding representation for performing the search. Defaults to None. The ``embedding_data`` and ``embedding`` cannot both be specified.
-            embedding_data: Data against which the search will be performed by embedding it using the `embedding_function`. Defaults to None. The `embedding_data` and `embedding` cannot both be specified.
+            embedding_data (List[str]): Data against which the search will be performed by embedding it using the `embedding_function`. Defaults to None. The `embedding_data` and `embedding` cannot both be specified.
             embedding_function (Optional[Callable], optional): function for converting `embedding_data` into embedding. Only valid if `embedding_data` is specified
             k (int): Number of elements to return after running query. Defaults to 4.
             distance_metric (str): Type of distance metric to use for sorting the data. Avaliable options are: ``"L1", "L2", "COS", "MAX"``. Defaults to ``"COS"``.
@@ -424,11 +424,6 @@ def search(
             embedding_data,
             embedding_function=embedding_function or self.embedding_function,
         )
-        if isinstance(query_emb, np.ndarray):
-            assert (
-                query_emb.ndim == 1 or query_emb.shape[0] == 1
-            ), "Query embedding must be 1-dimensional. Please consider using another embedding function for converting query string to embedding."
-
         return vector_search.search(
             query=query,
             logger=logger,
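For context, a minimal usage sketch of the retyped argument, assuming the `VectorStore` class defined in `deeplake_vectorstore.py`; the dataset path, sample texts, and `embedding_fn` below are illustrative placeholders rather than part of the commit:

import numpy as np
from deeplake.core.vectorstore.deeplake_vectorstore import VectorStore

def embedding_fn(texts):
    # Toy embedding function for illustration: one fixed-size vector per input string.
    return [np.random.rand(100).astype(np.float32) for _ in texts]

texts = ["first document", "second document"]
db = VectorStore(path="./example_vectorstore")  # illustrative local path
db.add(text=texts, embedding=embedding_fn(texts))

# `embedding_data` may now be a single query string:
results = db.search(embedding_data="first document", embedding_function=embedding_fn, k=2)

# Passing more than one query string is rejected further down the call chain:
# db.search(embedding_data=texts, embedding_function=embedding_fn)
# -> NotImplementedError: Searching batched queries is not supported yet.
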
20 changes: 19 additions & 1 deletion deeplake/core/vectorstore/test_deeplake_vectorstore.py
@@ -433,6 +433,24 @@ def filter_fn(x):
     assert isinstance(data.text[0].data()["value"], str)
     assert data.embedding[0].numpy().size > 0
 
+    data = vector_store.search(
+        embedding_function=embedding_fn3,
+        embedding_data="dummy",
+        return_view=True,
+        k=2,
+    )
+    assert len(data) == 2
+    assert isinstance(data.text[0].data()["value"], str)
+    assert data.embedding[0].numpy().size > 0
+
+    with pytest.raises(NotImplementedError):
+        data = vector_store.search(
+            embedding_function=embedding_fn3,
+            embedding_data=["dummy", "dummy2"],
+            return_view=True,
+            k=2,
+        )
+
     data = vector_store.search(
         filter={"metadata": {"abcdefh": 1}},
         embedding=None,
@@ -1738,7 +1756,7 @@ def test_query_dim(local_path):
     )
 
     vector_store.add(text=texts, embedding=embeddings)
-    with pytest.raises(AssertionError):
+    with pytest.raises(NotImplementedError):
        vector_store.search([texts[0], texts[0]], embedding_fn3, k=1)
 
     vector_store.search([texts[0]], embedding_fn4, k=1)
6 changes: 6 additions & 0 deletions deeplake/core/vectorstore/vector_search/dataset/dataset.py
@@ -241,11 +241,17 @@ def fetch_embeddings(view, embedding_tensor: str = "embedding"):
 
 
 def get_embedding(embedding, embedding_data, embedding_function=None):
+    if isinstance(embedding_data, str):
+        embedding_data = [embedding_data]
+
     if (
         embedding is None
         and embedding_function is not None
         and embedding_data is not None
     ):
+        if len(embedding_data) > 1:
+            raise NotImplementedError("Searching batched queries is not supported yet.")
+
         embedding = embedding_function(embedding_data)  # type: ignore
 
     if embedding is not None and (
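A hedged illustration of the new normalization and guard in `get_embedding`; the `embed` callable and query strings are placeholders, and the exact return value depends on code outside this hunk:

import numpy as np
from deeplake.core.vectorstore.vector_search.dataset.dataset import get_embedding

def embed(texts):
    # Placeholder embedding function: one 1-D float32 vector for a single query.
    return np.random.rand(100).astype(np.float32)

# A bare string is wrapped into a one-element list before being embedded:
query_emb = get_embedding(None, "what is deep lake?", embedding_function=embed)

# Two or more query strings are rejected until batched search is supported:
try:
    get_embedding(None, ["query one", "query two"], embedding_function=embed)
except NotImplementedError as err:
    print(err)  # Searching batched queries is not supported yet.
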
