Skip to content

Commit

Permalink
feat: add 'update_transforms' to '_helpers.pbs_for_create'
Browse files Browse the repository at this point in the history
Toward #217.
  • Loading branch information
tseaver committed Oct 7, 2020
1 parent c3acd4a commit 51295fd
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 29 deletions.
25 changes: 11 additions & 14 deletions google/cloud/firestore_v1/_helpers.py
Expand Up @@ -495,7 +495,9 @@ def get_update_pb(

return update_pb

def get_transform_pb(self, document_path, exists=None) -> types.write.Write:
def get_field_transform_pbs(
self, document_path
) -> List[types.write.DocumentTransform.FieldTransform]:
def make_array_value(values):
value_list = [encode_value(element) for element in values]
return document.ArrayValue(values=value_list)
Expand Down Expand Up @@ -559,9 +561,10 @@ def make_array_value(values):
for path, value in self.minimums.items()
]
)
field_transforms = [
transform for path, transform in sorted(path_field_transforms)
]
return [transform for path, transform in sorted(path_field_transforms)]

def get_transform_pb(self, document_path, exists=None) -> types.write.Write:
field_transforms = self.get_field_transform_pbs(document_path)
transform_pb = write.Write(
transform=write.DocumentTransform(
document=document_path, field_transforms=field_transforms
Expand Down Expand Up @@ -592,19 +595,13 @@ def pbs_for_create(document_path, document_data) -> List[types.write.Write]:
if extractor.deleted_fields:
raise ValueError("Cannot apply DELETE_FIELD in a create request.")

write_pbs = []

# Conformance tests require skipping the 'update_pb' if the document
# contains only transforms.
if extractor.empty_document or extractor.set_fields:
write_pbs.append(extractor.get_update_pb(document_path, exists=False))
create_pb = extractor.get_update_pb(document_path, exists=False)

if extractor.has_transforms:
exists = None if write_pbs else False
transform_pb = extractor.get_transform_pb(document_path, exists)
write_pbs.append(transform_pb)
field_transform_pbs = extractor.get_field_transform_pbs(document_path)
create_pb.update_transforms.extend(field_transform_pbs)

return write_pbs
return [create_pb]


def pbs_for_set_no_merge(document_path, document_data) -> List[types.write.Write]:
Expand Down
58 changes: 43 additions & 15 deletions tests/unit/v1/test__helpers.py
Expand Up @@ -1270,6 +1270,42 @@ def test_get_update_pb_wo_exists_precondition(self):
self.assertEqual(update_pb.update.fields, encode_dict(document_data))
self.assertFalse(update_pb._pb.HasField("current_document"))

def test_get_field_transform_pbs_miss(self):
from google.cloud.firestore_v1.types import write
from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP
from google.cloud.firestore_v1._helpers import REQUEST_TIME_ENUM

document_data = {"a": 1}
inst = self._make_one(document_data)
document_path = (
"projects/project-id/databases/(default)/" "documents/document-id"
)

field_transform_pbs = inst.get_field_transform_pbs(document_path)

self.assertEqual(field_transform_pbs, [])

def test_get_field_transform_pbs_w_server_timestamp(self):
from google.cloud.firestore_v1.types import write
from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP
from google.cloud.firestore_v1._helpers import REQUEST_TIME_ENUM

document_data = {"a": SERVER_TIMESTAMP}
inst = self._make_one(document_data)
document_path = (
"projects/project-id/databases/(default)/" "documents/document-id"
)

field_transform_pbs = inst.get_field_transform_pbs(document_path)

self.assertEqual(len(field_transform_pbs), 1)
field_transform_pb = field_transform_pbs[0]
self.assertIsInstance(
field_transform_pb, write.DocumentTransform.FieldTransform
)
self.assertEqual(field_transform_pb.field_path, "a")
self.assertEqual(field_transform_pb.set_to_server_value, REQUEST_TIME_ENUM)

def test_get_transform_pb_w_server_timestamp_w_exists_precondition(self):
from google.cloud.firestore_v1.types import write
from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP
Expand Down Expand Up @@ -1526,23 +1562,17 @@ def _make_write_w_document(document_path, **data):
)

@staticmethod
def _make_write_w_transform(document_path, fields):
def _add_field_transforms(update_pb, fields):
from google.cloud.firestore_v1.types import write
from google.cloud.firestore_v1 import DocumentTransform

server_val = DocumentTransform.FieldTransform.ServerValue
transforms = [
write.DocumentTransform.FieldTransform(
field_path=field, set_to_server_value=server_val.REQUEST_TIME
)
for field in fields
]

return write.Write(
transform=write.DocumentTransform(
document=document_path, field_transforms=transforms
for field in fields:
update_pb.update_transforms.append(
DocumentTransform.FieldTransform(
field_path=field, set_to_server_value=server_val.REQUEST_TIME
)
)
)

def _helper(self, do_transform=False, empty_val=False):
from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP
Expand All @@ -1569,9 +1599,7 @@ def _helper(self, do_transform=False, empty_val=False):
expected_pbs = [update_pb]

if do_transform:
expected_pbs.append(
self._make_write_w_transform(document_path, fields=["butter"])
)
self._add_field_transforms(update_pb, fields=["butter"])

self.assertEqual(write_pbs, expected_pbs)

Expand Down

0 comments on commit 51295fd

Please sign in to comment.