From 0c2f950022befac99e2b9f52eb63d43d887a9682 Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Wed, 7 Oct 2020 13:08:24 -0400 Subject: [PATCH] feat: add 'update_transforms' to '_helpers.pbs_for_set_with_merge' Toward #217. --- google/cloud/firestore_v1/_helpers.py | 15 +++----- tests/unit/v1/test__helpers.py | 53 +++++++++++++++------------ 2 files changed, 34 insertions(+), 34 deletions(-) diff --git a/google/cloud/firestore_v1/_helpers.py b/google/cloud/firestore_v1/_helpers.py index d8e6df772..d68ec7ba0 100644 --- a/google/cloud/firestore_v1/_helpers.py +++ b/google/cloud/firestore_v1/_helpers.py @@ -794,19 +794,14 @@ def pbs_for_set_with_merge( extractor.apply_merge(merge) merge_empty = not document_data + allow_empty_mask = merge_empty or extractor.transform_paths - write_pbs = [] - - if extractor.has_updates or merge_empty: - write_pbs.append( - extractor.get_update_pb(document_path, allow_empty_mask=merge_empty) - ) - + set_pb = extractor.get_update_pb(document_path, allow_empty_mask=allow_empty_mask) if extractor.transform_paths: - transform_pb = extractor.get_transform_pb(document_path) - write_pbs.append(transform_pb) + field_transform_pbs = extractor.get_field_transform_pbs(document_path) + set_pb.update_transforms.extend(field_transform_pbs) - return write_pbs + return [set_pb] class DocumentExtractorForUpdate(DocumentExtractor): diff --git a/tests/unit/v1/test__helpers.py b/tests/unit/v1/test__helpers.py index f6a4d8284..56dbc3287 100644 --- a/tests/unit/v1/test__helpers.py +++ b/tests/unit/v1/test__helpers.py @@ -1924,23 +1924,17 @@ def _make_write_w_document(document_path, **data): ) @staticmethod - def _make_write_w_transform(document_path, fields): + def _add_field_transforms(update_pb, fields): from google.cloud.firestore_v1.types import write from google.cloud.firestore_v1 import DocumentTransform server_val = DocumentTransform.FieldTransform.ServerValue - transforms = [ - write.DocumentTransform.FieldTransform( - field_path=field, set_to_server_value=server_val.REQUEST_TIME - ) - for field in fields - ] - - return write.Write( - transform=write.DocumentTransform( - document=document_path, field_transforms=transforms + for field in fields: + update_pb.update_transforms.append( + DocumentTransform.FieldTransform( + field_path=field, set_to_server_value=server_val.REQUEST_TIME + ) ) - ) @staticmethod def _update_document_mask(update_pb, field_paths): @@ -1974,6 +1968,20 @@ def test_with_merge_field_wo_transform(self): expected_pbs = [update_pb] self.assertEqual(write_pbs, expected_pbs) + def test_with_merge_true_w_only_transform(self): + from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP + + document_path = _make_ref_string(u"little", u"town", u"of", u"ham") + document_data = {"butter": SERVER_TIMESTAMP} + + write_pbs = self._call_fut(document_path, document_data, merge=True) + + update_pb = self._make_write_w_document(document_path) + self._update_document_mask(update_pb, field_paths=()) + self._add_field_transforms(update_pb, fields=["butter"]) + expected_pbs = [update_pb] + self.assertEqual(write_pbs, expected_pbs) + def test_with_merge_true_w_transform(self): from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP @@ -1986,8 +1994,8 @@ def test_with_merge_true_w_transform(self): update_pb = self._make_write_w_document(document_path, **update_data) self._update_document_mask(update_pb, field_paths=sorted(update_data)) - transform_pb = self._make_write_w_transform(document_path, fields=["butter"]) - expected_pbs = [update_pb, transform_pb] + self._add_field_transforms(update_pb, fields=["butter"]) + expected_pbs = [update_pb] self.assertEqual(write_pbs, expected_pbs) def test_with_merge_field_w_transform(self): @@ -2006,8 +2014,8 @@ def test_with_merge_field_w_transform(self): document_path, cheese=document_data["cheese"] ) self._update_document_mask(update_pb, ["cheese"]) - transform_pb = self._make_write_w_transform(document_path, fields=["butter"]) - expected_pbs = [update_pb, transform_pb] + self._add_field_transforms(update_pb, fields=["butter"]) + expected_pbs = [update_pb] self.assertEqual(write_pbs, expected_pbs) def test_with_merge_field_w_transform_masking_simple(self): @@ -2021,10 +2029,9 @@ def test_with_merge_field_w_transform_masking_simple(self): write_pbs = self._call_fut(document_path, document_data, merge=["butter.pecan"]) update_pb = self._make_write_w_document(document_path) - transform_pb = self._make_write_w_transform( - document_path, fields=["butter.pecan"] - ) - expected_pbs = [update_pb, transform_pb] + self._update_document_mask(update_pb, field_paths=()) + self._add_field_transforms(update_pb, fields=["butter.pecan"]) + expected_pbs = [update_pb] self.assertEqual(write_pbs, expected_pbs) def test_with_merge_field_w_transform_parent(self): @@ -2043,10 +2050,8 @@ def test_with_merge_field_w_transform_parent(self): document_path, cheese=update_data["cheese"], butter={"popcorn": "yum"} ) self._update_document_mask(update_pb, ["cheese", "butter"]) - transform_pb = self._make_write_w_transform( - document_path, fields=["butter.pecan"] - ) - expected_pbs = [update_pb, transform_pb] + self._add_field_transforms(update_pb, fields=["butter.pecan"]) + expected_pbs = [update_pb] self.assertEqual(write_pbs, expected_pbs)