diff --git a/google/cloud/firestore_v1/transaction.py b/google/cloud/firestore_v1/transaction.py index 9d4068c75..04485a84c 100644 --- a/google/cloud/firestore_v1/transaction.py +++ b/google/cloud/firestore_v1/transaction.py @@ -213,7 +213,7 @@ def get_all(self, references): .DocumentSnapshot: The next document snapshot that fulfills the query, or :data:`None` if the document does not exist. """ - return self._client.get_all(references, transaction=self._id) + return self._client.get_all(references, transaction=self) def get(self, ref_or_query): """ @@ -225,9 +225,9 @@ def get(self, ref_or_query): query, or :data:`None` if the document does not exist. """ if isinstance(ref_or_query, DocumentReference): - return self._client.get_all([ref_or_query], transaction=self._id) + return self._client.get_all([ref_or_query], transaction=self) elif isinstance(ref_or_query, Query): - return ref_or_query.stream(transaction=self._id) + return ref_or_query.stream(transaction=self) else: raise ValueError( 'Value for argument "ref_or_query" must be a DocumentReference or a Query.' diff --git a/tests/unit/v1/test_transaction.py b/tests/unit/v1/test_transaction.py index 8cae24a23..da3c2d0b0 100644 --- a/tests/unit/v1/test_transaction.py +++ b/tests/unit/v1/test_transaction.py @@ -333,7 +333,7 @@ def test_get_all(self): transaction = self._make_one(client) ref1, ref2 = mock.Mock(), mock.Mock() result = transaction.get_all([ref1, ref2]) - client.get_all.assert_called_once_with([ref1, ref2], transaction=transaction.id) + client.get_all.assert_called_once_with([ref1, ref2], transaction=transaction) self.assertIs(result, client.get_all.return_value) def test_get_document_ref(self): @@ -343,7 +343,7 @@ def test_get_document_ref(self): transaction = self._make_one(client) ref = DocumentReference("documents", "doc-id") result = transaction.get(ref) - client.get_all.assert_called_once_with([ref], transaction=transaction.id) + client.get_all.assert_called_once_with([ref], transaction=transaction) self.assertIs(result, client.get_all.return_value) def test_get_w_query(self): @@ -354,7 +354,7 @@ def test_get_w_query(self): query = Query(parent=mock.Mock(spec=[])) query.stream = mock.MagicMock() result = transaction.get(query) - query.stream.assert_called_once_with(transaction=transaction.id) + query.stream.assert_called_once_with(transaction=transaction) self.assertIs(result, query.stream.return_value) def test_get_failure(self):