Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
feat(firestore): surface new 'IN' and 'ARRAY_CONTAINS_ANY' operators (#9541)
  • Loading branch information
tseaver committed Oct 30, 2019
1 parent d69ec57 commit 5e9fe4f
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 53 deletions.
2 changes: 2 additions & 0 deletions google/cloud/firestore_v1/query.py
Expand Up @@ -43,6 +43,8 @@
">=": _operator_enum.GREATER_THAN_OR_EQUAL,
">": _operator_enum.GREATER_THAN,
"array_contains": _operator_enum.ARRAY_CONTAINS,
"in": _operator_enum.IN,
"array_contains_any": _operator_enum.ARRAY_CONTAINS_ANY,
}
_BAD_OP_STRING = "Operator string {!r} is invalid. Valid choices are: {}."
_BAD_OP_NAN_NULL = 'Only an equality filter ("==") can be used with None or NaN values'
Expand Down
155 changes: 105 additions & 50 deletions tests/system/test_system.py
Expand Up @@ -492,11 +492,13 @@ def test_collection_add(client, cleanup):
assert set(collection3.list_documents()) == {document_ref5}


def test_query_stream(client, cleanup):
@pytest.fixture
def query_docs(client):
collection_id = "qs" + UNIQUE_RESOURCE_ID
sub_collection = "child" + UNIQUE_RESOURCE_ID
collection = client.collection(collection_id, "doc", sub_collection)

cleanup = []
stored = {}
num_vals = 5
allowed_vals = six.moves.xrange(num_vals)
Expand All @@ -505,38 +507,82 @@ def test_query_stream(client, cleanup):
document_data = {
"a": a_val,
"b": b_val,
"c": [a_val, num_vals * 100],
"stats": {"sum": a_val + b_val, "product": a_val * b_val},
}
_, doc_ref = collection.add(document_data)
# Add to clean-up.
cleanup(doc_ref.delete)
cleanup.append(doc_ref.delete)
stored[doc_ref.id] = document_data

# 0. Limit to snapshots where ``a==1``.
query0 = collection.where("a", "==", 1)
values0 = {snapshot.id: snapshot.to_dict() for snapshot in query0.stream()}
assert len(values0) == num_vals
for key, value in six.iteritems(values0):
yield collection, stored, allowed_vals

for operation in cleanup:
operation()


def test_query_stream_w_simple_field_eq_op(query_docs):
    """Stream an ``a == 1`` query and verify every returned snapshot.

    The number of results must equal the number of allowed values, and
    each snapshot must match the data recorded by the ``query_docs``
    fixture.
    """
    collection, stored, allowed_vals = query_docs
    matches = collection.where("a", "==", 1)

    # Index the streamed snapshots by document ID.
    found = {}
    for snapshot in matches.stream():
        found[snapshot.id] = snapshot.to_dict()

    assert len(found) == len(allowed_vals)
    for doc_id, data in six.iteritems(found):
        assert stored[doc_id] == data
        assert data["a"] == 1


def test_query_stream_w_simple_field_array_contains_op(query_docs):
    """Stream a ``c array_contains 1`` query and verify the snapshots.

    The number of results must equal the number of allowed values; each
    snapshot must match the stored data and have ``a == 1``.
    """
    collection, stored, allowed_vals = query_docs
    matches = collection.where("c", "array_contains", 1)

    # Index the streamed snapshots by document ID.
    found = {}
    for snapshot in matches.stream():
        found[snapshot.id] = snapshot.to_dict()

    assert len(found) == len(allowed_vals)
    for doc_id, data in six.iteritems(found):
        assert stored[doc_id] == data
        assert data["a"] == 1


def test_query_stream_w_simple_field_in_op(query_docs):
    """Stream an ``a in [...]`` query and verify the snapshots.

    The candidate list pairs ``1`` with ``len(allowed_vals) + 100``; the
    assertions require every result to have ``a == 1``.
    """
    collection, stored, allowed_vals = query_docs
    num_vals = len(allowed_vals)
    candidates = [1, num_vals + 100]
    matches = collection.where("a", "in", candidates)

    # Index the streamed snapshots by document ID.
    found = {}
    for snapshot in matches.stream():
        found[snapshot.id] = snapshot.to_dict()

    assert len(found) == len(allowed_vals)
    for doc_id, data in six.iteritems(found):
        assert stored[doc_id] == data
        assert data["a"] == 1

# 1. Order by ``b``.
query1 = collection.order_by("b", direction=query0.DESCENDING)
values1 = [(snapshot.id, snapshot.to_dict()) for snapshot in query1.stream()]
assert len(values1) == len(stored)
b_vals1 = []
for key, value in values1:

def test_query_stream_w_simple_field_array_contains_any_op(query_docs):
collection, stored, allowed_vals = query_docs
num_vals = len(allowed_vals)
query = collection.where("c", "array_contains_any", [1, num_vals * 200])
values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()}
assert len(values) == len(allowed_vals)
for key, value in six.iteritems(values):
assert stored[key] == value
b_vals1.append(value["b"])
assert value["a"] == 1


def test_query_stream_w_order_by(query_docs):
collection, stored, allowed_vals = query_docs
query = collection.order_by("b", direction=firestore.Query.DESCENDING)
values = [(snapshot.id, snapshot.to_dict()) for snapshot in query.stream()]
assert len(values) == len(stored)
b_vals = []
for key, value in values:
assert stored[key] == value
b_vals.append(value["b"])
# Make sure the ``b``-values are in DESCENDING order.
assert sorted(b_vals1, reverse=True) == b_vals1
assert sorted(b_vals, reverse=True) == b_vals


# 2. Limit to snapshots where ``stats.sum > 4`` (a field path).
query2 = collection.where("stats.sum", ">", 4)
values2 = {snapshot.id: snapshot.to_dict() for snapshot in query2.stream()}
assert len(values2) == 10
def test_query_stream_w_field_path(query_docs):
collection, stored, allowed_vals = query_docs
query = collection.where("stats.sum", ">", 4)
values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()}
assert len(values) == 10
ab_pairs2 = set()
for key, value in six.iteritems(values2):
for key, value in six.iteritems(values):
assert stored[key] == value
ab_pairs2.add((value["a"], value["b"]))

Expand All @@ -550,63 +596,72 @@ def test_query_stream(client, cleanup):
)
assert expected_ab_pairs == ab_pairs2

# 3. Use a start and end cursor.
query3 = (

def test_query_stream_w_start_end_cursor(query_docs):
collection, stored, allowed_vals = query_docs
num_vals = len(allowed_vals)
query = (
collection.order_by("a")
.start_at({"a": num_vals - 2})
.end_before({"a": num_vals - 1})
)
values3 = [(snapshot.id, snapshot.to_dict()) for snapshot in query3.stream()]
assert len(values3) == num_vals
for key, value in values3:
values = [(snapshot.id, snapshot.to_dict()) for snapshot in query.stream()]
assert len(values) == num_vals
for key, value in values:
assert stored[key] == value
assert value["a"] == num_vals - 2
b_vals1.append(value["b"])

# 4. Send a query with no results.
query4 = collection.where("b", "==", num_vals + 100)
values4 = list(query4.stream())
assert len(values4) == 0

# 5. Select a subset of fields.
query5 = collection.where("b", "<=", 1)
query5 = query5.select(["a", "stats.product"])
values5 = {snapshot.id: snapshot.to_dict() for snapshot in query5.stream()}
assert len(values5) == num_vals * 2 # a ANY, b in (0, 1)
for key, value in six.iteritems(values5):


def test_query_stream_wo_results(query_docs):
    """Stream a query on ``b == len(allowed_vals) + 100`` and expect nothing."""
    collection, _stored, allowed_vals = query_docs
    unmatched_b = len(allowed_vals) + 100
    empty_query = collection.where("b", "==", unmatched_b)
    snapshots = list(empty_query.stream())
    assert len(snapshots) == 0


def test_query_stream_w_projection(query_docs):
    """Stream a ``b <= 1`` query that selects only ``a`` and ``stats.product``.

    Each returned dict must contain exactly the projected fields, with
    values equal to those recorded by the ``query_docs`` fixture.
    """
    collection, stored, allowed_vals = query_docs
    num_vals = len(allowed_vals)
    filtered = collection.where("b", "<=", 1)
    projected = filtered.select(["a", "stats.product"])

    # Index the streamed snapshots by document ID.
    found = {}
    for snapshot in projected.stream():
        found[snapshot.id] = snapshot.to_dict()

    assert len(found) == num_vals * 2  # a ANY, b in (0, 1)
    for doc_id, data in six.iteritems(found):
        # Build the expected projection from the stored document.
        expected = {
            "a": stored[doc_id]["a"],
            "stats": {"product": stored[doc_id]["stats"]["product"]},
        }
        assert expected == data

# 6. Add multiple filters via ``where()``.
query6 = collection.where("stats.product", ">", 5)
query6 = query6.where("stats.product", "<", 10)
values6 = {snapshot.id: snapshot.to_dict() for snapshot in query6.stream()}

def test_query_stream_w_multiple_filters(query_docs):
collection, stored, allowed_vals = query_docs
query = collection.where("stats.product", ">", 5).where("stats.product", "<", 10)
values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()}
matching_pairs = [
(a_val, b_val)
for a_val in allowed_vals
for b_val in allowed_vals
if 5 < a_val * b_val < 10
]
assert len(values6) == len(matching_pairs)
for key, value in six.iteritems(values6):
assert len(values) == len(matching_pairs)
for key, value in six.iteritems(values):
assert stored[key] == value
pair = (value["a"], value["b"])
assert pair in matching_pairs

# 7. Skip the first three results, when ``b==2``
query7 = collection.where("b", "==", 2)

def test_query_stream_w_offset(query_docs):
collection, stored, allowed_vals = query_docs
num_vals = len(allowed_vals)
offset = 3
query7 = query7.offset(offset)
values7 = {snapshot.id: snapshot.to_dict() for snapshot in query7.stream()}
query = collection.where("b", "==", 2).offset(offset)
values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()}
# NOTE: We don't check the ``a``-values, since that would require
# an ``order_by('a')``, which combined with the ``b == 2``
# filter would necessitate an index.
assert len(values7) == num_vals - offset
for key, value in six.iteritems(values7):
assert len(values) == num_vals - offset
for key, value in six.iteritems(values):
assert stored[key] == value
assert value["b"] == 2

Expand Down
35 changes: 32 additions & 3 deletions tests/unit/v1/test_query.py
Expand Up @@ -1464,18 +1464,47 @@ def _call_fut(op_string):

return _enum_from_op_string(op_string)

def test_success(self):
@staticmethod
def _get_op_class():
from google.cloud.firestore_v1.gapic import enums

op_class = enums.StructuredQuery.FieldFilter.Operator
return enums.StructuredQuery.FieldFilter.Operator

def test_lt(self):
    """``"<"`` maps to the ``LESS_THAN`` operator enum."""
    operators = self._get_op_class()
    actual = self._call_fut("<")
    self.assertEqual(actual, operators.LESS_THAN)

def test_le(self):
    """``"<="`` maps to the ``LESS_THAN_OR_EQUAL`` operator enum."""
    operators = self._get_op_class()
    actual = self._call_fut("<=")
    self.assertEqual(actual, operators.LESS_THAN_OR_EQUAL)

def test_eq(self):
    """``"=="`` maps to the ``EQUAL`` operator enum."""
    operators = self._get_op_class()
    actual = self._call_fut("==")
    self.assertEqual(actual, operators.EQUAL)

def test_ge(self):
    """``">="`` maps to the ``GREATER_THAN_OR_EQUAL`` operator enum."""
    operators = self._get_op_class()
    actual = self._call_fut(">=")
    self.assertEqual(actual, operators.GREATER_THAN_OR_EQUAL)

def test_gt(self):
    """``">"`` maps to the ``GREATER_THAN`` operator enum."""
    operators = self._get_op_class()
    actual = self._call_fut(">")
    self.assertEqual(actual, operators.GREATER_THAN)

def test_array_contains(self):
    """``"array_contains"`` maps to the ``ARRAY_CONTAINS`` operator enum."""
    operators = self._get_op_class()
    actual = self._call_fut("array_contains")
    self.assertEqual(actual, operators.ARRAY_CONTAINS)

def test_failure(self):
def test_in(self):
    """``"in"`` maps to the ``IN`` operator enum."""
    operators = self._get_op_class()
    actual = self._call_fut("in")
    self.assertEqual(actual, operators.IN)

def test_array_contains_any(self):
    """``"array_contains_any"`` maps to the ``ARRAY_CONTAINS_ANY`` enum."""
    operators = self._get_op_class()
    actual = self._call_fut("array_contains_any")
    self.assertEqual(actual, operators.ARRAY_CONTAINS_ANY)

def test_invalid(self):
    """An unrecognized operator string raises ``ValueError``."""
    unknown_op = "?"
    with self.assertRaises(ValueError):
        self._call_fut(unknown_op)

Expand Down

0 comments on commit 5e9fe4f

Please sign in to comment.