diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index 188c15b6a..38d08dd14 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -56,10 +56,12 @@ "<": _operator_enum.LESS_THAN, "<=": _operator_enum.LESS_THAN_OR_EQUAL, _EQ_OP: _operator_enum.EQUAL, + "!=": _operator_enum.NOT_EQUAL, ">=": _operator_enum.GREATER_THAN_OR_EQUAL, ">": _operator_enum.GREATER_THAN, "array_contains": _operator_enum.ARRAY_CONTAINS, "in": _operator_enum.IN, + "not-in": _operator_enum.NOT_IN, "array_contains_any": _operator_enum.ARRAY_CONTAINS_ANY, } _BAD_OP_STRING = "Operator string {!r} is invalid. Valid choices are: {}." @@ -255,8 +257,8 @@ def where(self, field_path: str, op_string: str, value) -> "BaseQuery": field_path (str): A field path (``.``-delimited list of field names) for the field to filter on. op_string (str): A comparison operation in the form of a string. - Acceptable values are ``<``, ``<=``, ``==``, ``>=``, ``>``, - ``in``, ``array_contains`` and ``array_contains_any``. + Acceptable values are ``<``, ``<=``, ``==``, ``!=``, ``>=``, ``>``, + ``in``, ``not-in``, ``array_contains`` and ``array_contains_any``. value (Any): The value to compare the field against in the filter. If ``value`` is :data:`None` or a NaN, then ``==`` is the only allowed operation. @@ -864,7 +866,7 @@ def _enum_from_op_string(op_string: str) -> Any: Args: op_string (str): A comparison operation in the form of a string. - Acceptable values are ``<``, ``<=``, ``==``, ``>=`` + Acceptable values are ``<``, ``<=``, ``==``, ``!=``, ``>=`` and ``>``. Returns: diff --git a/tests/system/test_system.py b/tests/system/test_system.py index 988fa082c..355c5aebb 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -582,6 +582,36 @@ def test_query_stream_w_simple_field_in_op(query_docs): assert value["a"] == 1 +def test_query_stream_w_not_eq_op(query_docs): + collection, stored, allowed_vals = query_docs + query = collection.where("stats.sum", "!=", 4) + values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()} + assert len(values) == 20 + ab_pairs2 = set() + for key, value in values.items(): + assert stored[key] == value + ab_pairs2.add((value["a"], value["b"])) + + expected_ab_pairs = set( + [ + (a_val, b_val) + for a_val in allowed_vals + for b_val in allowed_vals + if a_val + b_val != 4 + ] + ) + assert expected_ab_pairs == ab_pairs2 + + +def test_query_stream_w_simple_not_in_op(query_docs): + collection, stored, allowed_vals = query_docs + num_vals = len(allowed_vals) + query = collection.where("stats.sum", "not-in", [2, num_vals + 100]) + values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()} + + assert len(values) == 22 + + def test_query_stream_w_simple_field_array_contains_any_op(query_docs): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) diff --git a/tests/unit/v1/test_base_query.py b/tests/unit/v1/test_base_query.py index 59578af39..4b22f6cd8 100644 --- a/tests/unit/v1/test_base_query.py +++ b/tests/unit/v1/test_base_query.py @@ -1186,6 +1186,14 @@ def test_array_contains_any(self): self._call_fut("array_contains_any"), op_class.ARRAY_CONTAINS_ANY ) + def test_not_in(self): + op_class = self._get_op_class() + self.assertEqual(self._call_fut("not-in"), op_class.NOT_IN) + + def test_not_eq(self): + op_class = self._get_op_class() + self.assertEqual(self._call_fut("!="), op_class.NOT_EQUAL) + def test_invalid(self): with self.assertRaises(ValueError): self._call_fut("?") diff --git a/tests/unit/v1/testdata/query-invalid-operator.json b/tests/unit/v1/testdata/query-invalid-operator.json index 064164dc0..c53e5c2bd 100644 --- a/tests/unit/v1/testdata/query-invalid-operator.json +++ b/tests/unit/v1/testdata/query-invalid-operator.json @@ -2,7 +2,7 @@ "tests": [ { "description": "query: invalid operator in Where clause", - "comment": "The != operator is not supported.", + "comment": "The |~| operator is not supported.", "query": { "collPath": "projects/projectID/databases/(default)/documents/C", "clauses": [ @@ -13,7 +13,7 @@ "a" ] }, - "op": "!=", + "op": "|~|", "jsonValue": "4" } }