Skip to content

Commit

Permalink
Accept Sequence[float] as query_vector in FindNearest
Browse files Browse the repository at this point in the history
  • Loading branch information
Sichen Liu committed Apr 8, 2024
1 parent 6886f2b commit dad74e9
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 3 deletions.
3 changes: 2 additions & 1 deletion google/cloud/firestore_v1/base_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
Iterator,
Iterable,
NoReturn,
Sequence,
Tuple,
Union,
TYPE_CHECKING,
Expand Down Expand Up @@ -549,7 +550,7 @@ def avg(self, field_ref: str | FieldPath, alias=None):
def find_nearest(
self,
vector_field: str,
query_vector: Vector,
query_vector: Union[Vector, Sequence[float]],
limit: int,
distance_measure: DistanceMeasure,
) -> VectorQuery:
Expand Down
6 changes: 4 additions & 2 deletions google/cloud/firestore_v1/base_vector_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from abc import ABC
from enum import Enum
from typing import Iterable, Optional, Tuple, Union
from typing import Iterable, Optional, Sequence, Tuple, Union
from google.api_core import gapic_v1
from google.api_core import retry as retries
from google.cloud.firestore_v1.base_document import DocumentSnapshot
Expand Down Expand Up @@ -107,11 +107,13 @@ def get(
def find_nearest(
self,
vector_field: str,
query_vector: Vector,
query_vector: Union[Vector, Sequence[float]],
limit: int,
distance_measure: DistanceMeasure,
):
"""Finds the closest vector embeddings to the given query vector."""
if not isinstance (query_vector, Vector):
self._query_vector = Vector(query_vector)
self._vector_field = vector_field
self._query_vector = query_vector
self._limit = limit
Expand Down
36 changes: 36 additions & 0 deletions tests/unit/v1/test_vector_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,42 @@ def test_vector_query_collection_group(distance_measure, expected_distance):
**kwargs,
)

def test_vector_query_list_as_query_vector():
# Create a minimal fake GAPIC with a dummy response.
firestore_api = mock.Mock(spec=["run_query"])
response_pb = _make_query_response(name="xxx/test_doc", data=data)
run_query_response = iter([response_pb])
firestore_api.run_query.return_value = run_query_response

# Attach the fake GAPIC to a real client.
client = make_client()
client._firestore_api_internal = firestore_api

# Make a **real** collection reference as parent.
parent = client.collection("dah", "dah", "dum")
vector_query = parent.where("snooze", "==", 10).find_nearest(
vector_field="embedding",
query_vector=[1.0, 2.0, 3.0],
distance_measure=DistanceMeasure.EUCLIDEAN,
limit=5,
)

get_response = vector_query.stream()
assert isinstance(get_response, types.GeneratorType)
assert list(get_response) == []

# Verify the mock call.
parent_path, _ = parent._parent_info()
firestore_api.run_query.assert_called_once_with(
request={
"parent": parent_path,
"structured_query": vector_query._to_protobuf(),
"transaction": None,
},
metadata=client._rpc_metadata,
)



def test_query_stream_multiple_empty_response_in_stream():
# Create a minimal fake GAPIC with a dummy response.
Expand Down

0 comments on commit dad74e9

Please sign in to comment.