Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement "any" and "all" metadata filters for Weaviate vector store #13365

Merged
merged 7 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
15 changes: 12 additions & 3 deletions llama-index-core/llama_index/core/vector_stores/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def _build_metadata_filter_fn(
metadata_filters: Optional[MetadataFilters] = None,
) -> Callable[[str], bool]:
"""Build metadata filter function."""
filter_list = metadata_filters.legacy_filters() if metadata_filters else []
filter_list = metadata_filters.filters if metadata_filters else []
if not filter_list:
return lambda _: True

Expand All @@ -62,8 +62,17 @@ def filter_fn(node_id: str) -> bool:
if metadata_value is None:
filter_matches = False
elif isinstance(metadata_value, list):
if filter_.value not in metadata_value:
filter_matches = False
match filter_.operator:
brenkehoe marked this conversation as resolved.
Show resolved Hide resolved
case "any":
if not any(value in metadata_value for value in filter_.value):
filter_matches = False
case "all":
for value in filter_.value:
if value not in metadata_value:
filter_matches = False
case _:
if filter_.value not in metadata_value:
filter_matches = False
elif isinstance(metadata_value, (int, float, str, bool)):
if metadata_value != filter_.value:
filter_matches = False
Expand Down
15 changes: 5 additions & 10 deletions llama-index-core/llama_index/core/vector_stores/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Vector store index types."""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
Expand Down Expand Up @@ -69,10 +68,11 @@ class FilterOperator(str, Enum):
NE = "!=" # not equal to (string, int, float)
GTE = ">=" # greater than or equal to (int, float)
LTE = "<=" # less than or equal to (int, float)
IN = "in" # metadata in value array (string or number)
NIN = "nin" # metadata not in value array (string or number)
IN = "in" # In array (string or number)
NIN = "nin" # Not in array (string or number)
ANY = "any" # Contains any (array of strings)
ALL = "all" # Contains all (array of strings)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@brenkehoe @logan-markewich Personally, I think the names "any" and "all" are too generic. Maybe "contains_any" and "contains_all" are more specific and easier to understand.

For example, Milvus also supports "array_contains", "array_contains_any" and "array_contains_all".

TEXT_MATCH = "text_match" # full text match (allows you to search for a specific substring, token or phrase within the text field)
CONTAINS = "contains" # metadata array contains value (string or number)
brenkehoe marked this conversation as resolved.
Show resolved Hide resolved


class FilterCondition(str, Enum):
Expand All @@ -93,12 +93,7 @@ class MetadataFilter(BaseModel):
"""

key: str
value: Union[
StrictInt,
StrictFloat,
StrictStr,
List[Union[StrictInt, StrictFloat, StrictStr]],
]
value: Union[StrictInt, StrictFloat, StrictStr, List[StrictStr]]
brenkehoe marked this conversation as resolved.
Show resolved Hide resolved
operator: FilterOperator = FilterOperator.EQ

@classmethod
Expand Down
44 changes: 41 additions & 3 deletions llama-index-core/tests/vector_stores/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
MetadataFilters,
VectorStoreQuery,
FilterCondition,
MetadataFilter,
FilterOperator,
)

_NODE_ID_WEIGHT_1_RANK_A = "AF3BE6C4-5F43-4D74-B075-6B0E07900DE8"
Expand All @@ -22,21 +24,21 @@ def _node_embeddings_for_test() -> List[TextNode]:
id_=_NODE_ID_WEIGHT_1_RANK_A,
embedding=[1.0, 0.0],
relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="test-0")},
metadata={"weight": 1.0, "rank": "a"},
metadata={"weight": 1.0, "rank": "a", "quality": ["medium", "high"]},
),
TextNode(
text="lorem ipsum",
id_=_NODE_ID_WEIGHT_2_RANK_C,
embedding=[0.0, 1.0],
relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="test-1")},
metadata={"weight": 2.0, "rank": "c"},
metadata={"weight": 2.0, "rank": "c", "quality": ["medium"]},
),
TextNode(
text="lorem ipsum",
id_=_NODE_ID_WEIGHT_3_RANK_C,
embedding=[1.0, 1.0],
relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="test-2")},
metadata={"weight": 3.0, "rank": "c"},
metadata={"weight": 3.0, "rank": "c", "quality": ["low", "medium", "high"]},
),
]

Expand Down Expand Up @@ -181,3 +183,39 @@ def test_query_with_filters_with_filter_condition(self) -> None:
)
result = simple_vector_store.query(query)
self.assertEqual(len(result.ids), 0)

def test_query_with_any_filter_returns_matches(self) -> None:
    """ANY filter keeps nodes whose list metadata shares a value with the filter.

    The fixture's "quality" lists are ["medium", "high"], ["medium"] and
    ["low", "medium", "high"], so filtering on ANY of ["high", "low"] must
    keep the first and third nodes and drop the "medium"-only node.
    """
    simple_vector_store = SimpleVectorStore()
    simple_vector_store.add(_node_embeddings_for_test())

    filters = MetadataFilters(
        filters=[
            MetadataFilter(
                key="quality", operator=FilterOperator.ANY, value=["high", "low"]
            )
        ]
    )
    query = VectorStoreQuery(
        query_embedding=[1.0, 1.0], filters=filters, similarity_top_k=3
    )
    result = simple_vector_store.query(query)
    # Plain assert (rather than assertIsNotNone) so type checkers narrow
    # Optional[List[str]] to List[str] for the call below.
    assert result.ids is not None
    # Check which nodes matched, not only how many: a count-only assertion
    # would still pass if the filter returned the wrong pair of nodes.
    # Ranking order is irrelevant to the filter, hence assertCountEqual.
    self.assertCountEqual(
        result.ids, [_NODE_ID_WEIGHT_1_RANK_A, _NODE_ID_WEIGHT_3_RANK_C]
    )

def test_query_with_all_filter_returns_matches(self) -> None:
    """ALL filter keeps only nodes whose list metadata contains every filter value.

    The fixture's "quality" lists are ["medium", "high"], ["medium"] and
    ["low", "medium", "high"], so filtering on ALL of ["medium", "high"] must
    keep the first and third nodes and drop the node that lacks "high".
    """
    simple_vector_store = SimpleVectorStore()
    simple_vector_store.add(_node_embeddings_for_test())

    filters = MetadataFilters(
        filters=[
            MetadataFilter(
                key="quality", operator=FilterOperator.ALL, value=["medium", "high"]
            )
        ]
    )
    query = VectorStoreQuery(
        query_embedding=[1.0, 1.0], filters=filters, similarity_top_k=3
    )
    result = simple_vector_store.query(query)
    # Plain assert (rather than assertIsNotNone) so type checkers narrow
    # Optional[List[str]] to List[str] for the call below.
    assert result.ids is not None
    # Check which nodes matched, not only how many: a count-only assertion
    # would still pass if the filter returned the wrong pair of nodes.
    # Ranking order is irrelevant to the filter, hence assertCountEqual.
    self.assertCountEqual(
        result.ids, [_NODE_ID_WEIGHT_1_RANK_A, _NODE_ID_WEIGHT_3_RANK_C]
    )
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ def _transform_weaviate_filter_operator(operator: str) -> str:
return "GreaterThanEqual"
elif operator == "<=":
return "LessThanEqual"
elif operator == "all":
return "ContainsAll"
elif operator == "any":
return "ContainsAny"
else:
raise ValueError(f"Filter operator {operator} not supported")

Expand All @@ -77,6 +81,8 @@ def _to_weaviate_filter(standard_filters: MetadataFilters) -> Dict[str, Any]:
elif isinstance(filter.value, str) and filter.value.isnumeric():
filter.value = float(filter.value)
value_type = "valueNumber"
if filter.operator in ["any", "all"]:
value_type = "valueTextArray"
filters_list.append(
{
"path": filter.key,
Expand Down