From 51295fda5f16e7211a6379a71d602af8d411d94f Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Wed, 7 Oct 2020 12:29:59 -0400 Subject: [PATCH] feat: add 'update_transforms' to '_helpers.pbs_for_create' Toward #217. --- google/cloud/firestore_v1/_helpers.py | 25 +++++------- tests/unit/v1/test__helpers.py | 58 ++++++++++++++++++++------- 2 files changed, 54 insertions(+), 29 deletions(-) diff --git a/google/cloud/firestore_v1/_helpers.py b/google/cloud/firestore_v1/_helpers.py index f9f01e7b9..a327a5fa7 100644 --- a/google/cloud/firestore_v1/_helpers.py +++ b/google/cloud/firestore_v1/_helpers.py @@ -495,7 +495,9 @@ def get_update_pb( return update_pb - def get_transform_pb(self, document_path, exists=None) -> types.write.Write: + def get_field_transform_pbs( + self, document_path + ) -> List[types.write.DocumentTransform.FieldTransform]: def make_array_value(values): value_list = [encode_value(element) for element in values] return document.ArrayValue(values=value_list) @@ -559,9 +561,10 @@ def make_array_value(values): for path, value in self.minimums.items() ] ) - field_transforms = [ - transform for path, transform in sorted(path_field_transforms) - ] + return [transform for path, transform in sorted(path_field_transforms)] + + def get_transform_pb(self, document_path, exists=None) -> types.write.Write: + field_transforms = self.get_field_transform_pbs(document_path) transform_pb = write.Write( transform=write.DocumentTransform( document=document_path, field_transforms=field_transforms @@ -592,19 +595,13 @@ def pbs_for_create(document_path, document_data) -> List[types.write.Write]: if extractor.deleted_fields: raise ValueError("Cannot apply DELETE_FIELD in a create request.") - write_pbs = [] - - # Conformance tests require skipping the 'update_pb' if the document - # contains only transforms. - if extractor.empty_document or extractor.set_fields: - write_pbs.append(extractor.get_update_pb(document_path, exists=False)) + create_pb = extractor.get_update_pb(document_path, exists=False) if extractor.has_transforms: - exists = None if write_pbs else False - transform_pb = extractor.get_transform_pb(document_path, exists) - write_pbs.append(transform_pb) + field_transform_pbs = extractor.get_field_transform_pbs(document_path) + create_pb.update_transforms.extend(field_transform_pbs) - return write_pbs + return [create_pb] def pbs_for_set_no_merge(document_path, document_data) -> List[types.write.Write]: diff --git a/tests/unit/v1/test__helpers.py b/tests/unit/v1/test__helpers.py index 55b74f89d..dc8462e7f 100644 --- a/tests/unit/v1/test__helpers.py +++ b/tests/unit/v1/test__helpers.py @@ -1270,6 +1270,42 @@ def test_get_update_pb_wo_exists_precondition(self): self.assertEqual(update_pb.update.fields, encode_dict(document_data)) self.assertFalse(update_pb._pb.HasField("current_document")) + def test_get_field_transform_pbs_miss(self): + from google.cloud.firestore_v1.types import write + from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP + from google.cloud.firestore_v1._helpers import REQUEST_TIME_ENUM + + document_data = {"a": 1} + inst = self._make_one(document_data) + document_path = ( + "projects/project-id/databases/(default)/" "documents/document-id" + ) + + field_transform_pbs = inst.get_field_transform_pbs(document_path) + + self.assertEqual(field_transform_pbs, []) + + def test_get_field_transform_pbs_w_server_timestamp(self): + from google.cloud.firestore_v1.types import write + from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP + from google.cloud.firestore_v1._helpers import REQUEST_TIME_ENUM + + document_data = {"a": SERVER_TIMESTAMP} + inst = self._make_one(document_data) + document_path = ( + "projects/project-id/databases/(default)/" "documents/document-id" + ) + + field_transform_pbs = inst.get_field_transform_pbs(document_path) + + self.assertEqual(len(field_transform_pbs), 1) + field_transform_pb = field_transform_pbs[0] + self.assertIsInstance( + field_transform_pb, write.DocumentTransform.FieldTransform + ) + self.assertEqual(field_transform_pb.field_path, "a") + self.assertEqual(field_transform_pb.set_to_server_value, REQUEST_TIME_ENUM) + def test_get_transform_pb_w_server_timestamp_w_exists_precondition(self): from google.cloud.firestore_v1.types import write from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP @@ -1526,23 +1562,17 @@ def _make_write_w_document(document_path, **data): ) @staticmethod - def _make_write_w_transform(document_path, fields): + def _add_field_transforms(update_pb, fields): from google.cloud.firestore_v1.types import write from google.cloud.firestore_v1 import DocumentTransform server_val = DocumentTransform.FieldTransform.ServerValue - transforms = [ - write.DocumentTransform.FieldTransform( - field_path=field, set_to_server_value=server_val.REQUEST_TIME - ) - for field in fields - ] - - return write.Write( - transform=write.DocumentTransform( - document=document_path, field_transforms=transforms + for field in fields: + update_pb.update_transforms.append( + DocumentTransform.FieldTransform( + field_path=field, set_to_server_value=server_val.REQUEST_TIME + ) ) - ) def _helper(self, do_transform=False, empty_val=False): from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP @@ -1569,9 +1599,7 @@ def _helper(self, do_transform=False, empty_val=False): expected_pbs = [update_pb] if do_transform: - expected_pbs.append( - self._make_write_w_transform(document_path, fields=["butter"]) - ) + self._add_field_transforms(update_pb, fields=["butter"]) self.assertEqual(write_pbs, expected_pbs)