Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement any and all metadata filters for weaviate vector store #13365

Merged
merged 7 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
50 changes: 40 additions & 10 deletions llama-index-core/llama_index/core/vector_stores/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
BasePydanticVectorStore,
MetadataFilters,
FilterCondition,
FilterOperator,
VectorStoreQuery,
VectorStoreQueryMode,
VectorStoreQueryResult,
Expand All @@ -47,27 +48,56 @@ def _build_metadata_filter_fn(
metadata_filters: Optional[MetadataFilters] = None,
) -> Callable[[str], bool]:
"""Build metadata filter function."""
filter_list = metadata_filters.legacy_filters() if metadata_filters else []
filter_list = metadata_filters.filters if metadata_filters else []
if not filter_list:
return lambda _: True

filter_condition = cast(MetadataFilters, metadata_filters.condition)

def filter_fn(node_id: str) -> bool:
def _process_filter_match(
    operator: FilterOperator, value: Any, metadata_value: Any
) -> bool:
    """Decide whether a single metadata entry satisfies a single filter.

    Args:
        operator: The comparison to apply.
        value: The value carried by the filter.
        metadata_value: The node's metadata entry for the filter key, or
            ``None`` when the node has no such entry.

    Returns:
        ``True`` when ``metadata_value`` satisfies the filter. A ``None``
        metadata value never matches, regardless of the operator (this
        includes the negative operators ``NE`` and ``NIN``).

    Raises:
        ValueError: If ``operator`` is not one of the handled operators
            (only reached when ``metadata_value`` is not ``None``).
    """
    if metadata_value is None:
        return False

    # Plain comparisons between the stored metadata and the filter value.
    if operator == FilterOperator.EQ:
        return metadata_value == value
    if operator == FilterOperator.NE:
        return metadata_value != value
    if operator == FilterOperator.GT:
        return metadata_value > value
    if operator == FilterOperator.GTE:
        return metadata_value >= value
    if operator == FilterOperator.LT:
        return metadata_value < value
    if operator == FilterOperator.LTE:
        return metadata_value <= value

    if operator == FilterOperator.IN:
        # With a list of candidates, IN asks whether the metadata value is
        # one of them. The original code always evaluated
        # `value in metadata_value`, which made IN a duplicate of CONTAINS
        # and could never match a scalar metadata entry against a list.
        if isinstance(value, (list, tuple)):
            return metadata_value in value
        # Scalar filter value: keep the previous containment behaviour so
        # existing callers are unaffected.
        return value in metadata_value
    if operator == FilterOperator.NIN:
        # Exact negation of the IN branch above.
        if isinstance(value, (list, tuple)):
            return metadata_value not in value
        return value not in metadata_value

    if operator == FilterOperator.CONTAINS:
        return value in metadata_value
    if operator == FilterOperator.TEXT_MATCH:
        # Case-insensitive substring test; both operands must provide
        # `.lower()`.
        return value.lower() in metadata_value.lower()
    if operator == FilterOperator.ALL:
        return all(val in metadata_value for val in value)
    if operator == FilterOperator.ANY:
        return any(val in metadata_value for val in value)
    raise ValueError(f"Invalid operator: {operator}")

metadata = metadata_lookup_fn(node_id)

filter_matches_list = []
for filter_ in filter_list:
filter_matches = True
metadata_value = metadata.get(filter_.key, None)
if metadata_value is None:
filter_matches = False
elif isinstance(metadata_value, list):
if filter_.value not in metadata_value:
filter_matches = False
elif isinstance(metadata_value, (int, float, str, bool)):
if metadata_value != filter_.value:
filter_matches = False

filter_matches = _process_filter_match(
operator=filter_.operator,
value=filter_.value,
metadata_value=metadata.get(filter_.key, None),
)

filter_matches_list.append(filter_matches)

if filter_condition == FilterCondition.AND:
Expand Down
14 changes: 5 additions & 9 deletions llama-index-core/llama_index/core/vector_stores/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Vector store index types."""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
Expand Down Expand Up @@ -69,8 +68,10 @@ class FilterOperator(str, Enum):
NE = "!=" # not equal to (string, int, float)
GTE = ">=" # greater than or equal to (int, float)
LTE = "<=" # less than or equal to (int, float)
IN = "in" # metadata in value array (string or number)
NIN = "nin" # metadata not in value array (string or number)
IN = "in" # In array (string or number)
NIN = "nin" # Not in array (string or number)
ANY = "any" # Contains any (array of strings)
ALL = "all" # Contains all (array of strings)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@brenkehoe @logan-markewich Personally, I think the names "any" and "all" are too generic. Maybe "contains_any" and "contains_all" are more specific and easier to understand.

For example, Milvus also supports "array_contains", "array_contains_any" and "array_contains_all".

TEXT_MATCH = "text_match" # full text match (allows you to search for a specific substring, token or phrase within the text field)
CONTAINS = "contains" # metadata array contains value (string or number)
brenkehoe marked this conversation as resolved.
Show resolved Hide resolved

Expand All @@ -93,12 +94,7 @@ class MetadataFilter(BaseModel):
"""

key: str
value: Union[
StrictInt,
StrictFloat,
StrictStr,
List[Union[StrictInt, StrictFloat, StrictStr]],
]
value: Union[StrictInt, StrictFloat, StrictStr, List[StrictStr]]
brenkehoe marked this conversation as resolved.
Show resolved Hide resolved
operator: FilterOperator = FilterOperator.EQ

@classmethod
Expand Down
224 changes: 221 additions & 3 deletions llama-index-core/tests/vector_stores/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
MetadataFilters,
VectorStoreQuery,
FilterCondition,
MetadataFilter,
FilterOperator,
)

_NODE_ID_WEIGHT_1_RANK_A = "AF3BE6C4-5F43-4D74-B075-6B0E07900DE8"
Expand All @@ -22,21 +24,36 @@ def _node_embeddings_for_test() -> List[TextNode]:
id_=_NODE_ID_WEIGHT_1_RANK_A,
embedding=[1.0, 0.0],
relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="test-0")},
metadata={"weight": 1.0, "rank": "a"},
metadata={
"weight": 1.0,
"rank": "a",
"quality": ["medium", "high"],
"identifier": "6FTR78Yun",
},
),
TextNode(
text="lorem ipsum",
id_=_NODE_ID_WEIGHT_2_RANK_C,
embedding=[0.0, 1.0],
relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="test-1")},
metadata={"weight": 2.0, "rank": "c"},
metadata={
"weight": 2.0,
"rank": "c",
"quality": ["medium"],
"identifier": "6FTR78Ygl",
},
),
TextNode(
text="lorem ipsum",
id_=_NODE_ID_WEIGHT_3_RANK_C,
embedding=[1.0, 1.0],
relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="test-2")},
metadata={"weight": 3.0, "rank": "c"},
metadata={
"weight": 3.0,
"rank": "c",
"quality": ["low", "medium", "high"],
"identifier": "6FTR78Ztl",
},
),
]

Expand Down Expand Up @@ -182,6 +199,207 @@ def test_query_with_filters_with_filter_condition(self) -> None:
result = simple_vector_store.query(query)
self.assertEqual(len(result.ids), 0)

def test_query_with_equal_filter_returns_matches(self) -> None:
    """An EQ filter on ``weight`` with value 1.0 yields exactly one id."""
    store = SimpleVectorStore()
    store.add(_node_embeddings_for_test())

    weight_filter = MetadataFilter(
        key="weight", operator=FilterOperator.EQ, value=1.0
    )
    result = store.query(
        VectorStoreQuery(
            query_embedding=[1.0, 1.0],
            filters=MetadataFilters(filters=[weight_filter]),
            similarity_top_k=3,
        )
    )

    assert result.ids is not None
    self.assertEqual(len(result.ids), 1)

def test_query_with_notequal_filter_returns_matches(self) -> None:
    """An NE filter on ``weight`` with value 1.0 yields two ids."""
    store = SimpleVectorStore()
    store.add(_node_embeddings_for_test())

    weight_filter = MetadataFilter(
        key="weight", operator=FilterOperator.NE, value=1.0
    )
    result = store.query(
        VectorStoreQuery(
            query_embedding=[1.0, 1.0],
            filters=MetadataFilters(filters=[weight_filter]),
            similarity_top_k=3,
        )
    )

    assert result.ids is not None
    self.assertEqual(len(result.ids), 2)

def test_query_with_greaterthan_filter_returns_matches(self) -> None:
    """A GT filter on ``weight`` with value 1.5 yields two ids."""
    store = SimpleVectorStore()
    store.add(_node_embeddings_for_test())

    weight_filter = MetadataFilter(
        key="weight", operator=FilterOperator.GT, value=1.5
    )
    result = store.query(
        VectorStoreQuery(
            query_embedding=[1.0, 1.0],
            filters=MetadataFilters(filters=[weight_filter]),
            similarity_top_k=3,
        )
    )

    assert result.ids is not None
    self.assertEqual(len(result.ids), 2)

def test_query_with_greaterthanequal_filter_returns_matches(self) -> None:
    """A GTE filter on ``weight`` with value 1.0 yields three ids."""
    store = SimpleVectorStore()
    store.add(_node_embeddings_for_test())

    weight_filter = MetadataFilter(
        key="weight", operator=FilterOperator.GTE, value=1.0
    )
    result = store.query(
        VectorStoreQuery(
            query_embedding=[1.0, 1.0],
            filters=MetadataFilters(filters=[weight_filter]),
            similarity_top_k=3,
        )
    )

    assert result.ids is not None
    self.assertEqual(len(result.ids), 3)

def test_query_with_lessthan_filter_returns_matches(self) -> None:
    """An LT filter on ``weight`` with value 1.1 yields exactly one id.

    The fixture nodes carry weights 1.0, 2.0 and 3.0, so only the first
    one is strictly below 1.1.
    """
    simple_vector_store = SimpleVectorStore()
    simple_vector_store.add(_node_embeddings_for_test())

    filters = MetadataFilters(
        filters=[
            MetadataFilter(key="weight", operator=FilterOperator.LT, value=1.1)
        ]
    )
    query = VectorStoreQuery(
        query_embedding=[1.0, 1.0], filters=filters, similarity_top_k=3
    )
    result = simple_vector_store.query(query)
    assert result.ids is not None
    # Without a count check this test would pass even if LT matched
    # nothing or everything.
    self.assertEqual(len(result.ids), 1)

def test_query_with_lessthanequal_filter_returns_matches(self) -> None:
    """An LTE filter on ``weight`` with value 1.0 yields exactly one id."""
    store = SimpleVectorStore()
    store.add(_node_embeddings_for_test())

    weight_filter = MetadataFilter(
        key="weight", operator=FilterOperator.LTE, value=1.0
    )
    result = store.query(
        VectorStoreQuery(
            query_embedding=[1.0, 1.0],
            filters=MetadataFilters(filters=[weight_filter]),
            similarity_top_k=3,
        )
    )

    assert result.ids is not None
    self.assertEqual(len(result.ids), 1)

def test_query_with_in_filter_returns_matches(self) -> None:
    """An IN filter on ``quality`` with value "high" yields two ids."""
    store = SimpleVectorStore()
    store.add(_node_embeddings_for_test())

    quality_filter = MetadataFilter(
        key="quality", operator=FilterOperator.IN, value="high"
    )
    result = store.query(
        VectorStoreQuery(
            query_embedding=[1.0, 1.0],
            filters=MetadataFilters(filters=[quality_filter]),
            similarity_top_k=3,
        )
    )

    assert result.ids is not None
    self.assertEqual(len(result.ids), 2)

def test_query_with_notin_filter_returns_matches(self) -> None:
    """A NIN filter on ``quality`` with value "high" yields one id."""
    store = SimpleVectorStore()
    store.add(_node_embeddings_for_test())

    quality_filter = MetadataFilter(
        key="quality", operator=FilterOperator.NIN, value="high"
    )
    result = store.query(
        VectorStoreQuery(
            query_embedding=[1.0, 1.0],
            filters=MetadataFilters(filters=[quality_filter]),
            similarity_top_k=3,
        )
    )

    assert result.ids is not None
    self.assertEqual(len(result.ids), 1)

def test_query_with_contains_filter_returns_matches(self) -> None:
    """A CONTAINS filter on ``quality`` with value "high" yields two ids."""
    store = SimpleVectorStore()
    store.add(_node_embeddings_for_test())

    quality_filter = MetadataFilter(
        key="quality", operator=FilterOperator.CONTAINS, value="high"
    )
    result = store.query(
        VectorStoreQuery(
            query_embedding=[1.0, 1.0],
            filters=MetadataFilters(filters=[quality_filter]),
            similarity_top_k=3,
        )
    )

    assert result.ids is not None
    self.assertEqual(len(result.ids), 2)

def test_query_with_textmatch_filter_returns_matches(self) -> None:
    """A TEXT_MATCH filter on ``identifier`` for "6FTR78Y" yields two ids."""
    store = SimpleVectorStore()
    store.add(_node_embeddings_for_test())

    identifier_filter = MetadataFilter(
        key="identifier",
        operator=FilterOperator.TEXT_MATCH,
        value="6FTR78Y",
    )
    result = store.query(
        VectorStoreQuery(
            query_embedding=[1.0, 1.0],
            filters=MetadataFilters(filters=[identifier_filter]),
            similarity_top_k=3,
        )
    )

    assert result.ids is not None
    self.assertEqual(len(result.ids), 2)

def test_query_with_any_filter_returns_matches(self) -> None:
    """An ANY filter on ``quality`` for ["high", "low"] yields two ids."""
    store = SimpleVectorStore()
    store.add(_node_embeddings_for_test())

    quality_filter = MetadataFilter(
        key="quality", operator=FilterOperator.ANY, value=["high", "low"]
    )
    result = store.query(
        VectorStoreQuery(
            query_embedding=[1.0, 1.0],
            filters=MetadataFilters(filters=[quality_filter]),
            similarity_top_k=3,
        )
    )

    assert result.ids is not None
    self.assertEqual(len(result.ids), 2)

def test_query_with_all_filter_returns_matches(self) -> None:
    """An ALL filter on ``quality`` for ["medium", "high"] yields two ids."""
    store = SimpleVectorStore()
    store.add(_node_embeddings_for_test())

    quality_filter = MetadataFilter(
        key="quality", operator=FilterOperator.ALL, value=["medium", "high"]
    )
    result = store.query(
        VectorStoreQuery(
            query_embedding=[1.0, 1.0],
            filters=MetadataFilters(filters=[quality_filter]),
            similarity_top_k=3,
        )
    )

    assert result.ids is not None
    self.assertEqual(len(result.ids), 2)

def test_clear(self) -> None:
simple_vector_store = SimpleVectorStore()
simple_vector_store.add(_node_embeddings_for_test())
Expand Down