diff --git a/google/cloud/firestore_v1/watch.py b/google/cloud/firestore_v1/watch.py index 2216acd45..103732223 100644 --- a/google/cloud/firestore_v1/watch.py +++ b/google/cloud/firestore_v1/watch.py @@ -213,9 +213,9 @@ def __init__( self._closing = threading.Lock() self._closed = False - initial_request = firestore_pb2.ListenRequest( - database=self._firestore._database_string, add_target=self._targets - ) + self.resume_token = None + + rpc_request = self._get_rpc_request if ResumableBidiRpc is None: ResumableBidiRpc = self.ResumableBidiRpc # FBO unit tests @@ -224,7 +224,7 @@ def __init__( self._api.transport.listen, should_recover=_should_recover, should_terminate=_should_terminate, - initial_request=initial_request, + initial_request=rpc_request, metadata=self._firestore._rpc_metadata, ) @@ -252,13 +252,19 @@ def __init__( self.has_pushed = False # The server assigns and updates the resume token. - self.resume_token = None if BackgroundConsumer is None: # FBO unit tests BackgroundConsumer = self.BackgroundConsumer self._consumer = BackgroundConsumer(self._rpc, self.on_snapshot) self._consumer.start() + def _get_rpc_request(self): + if self.resume_token is not None: + self._targets["resume_token"] = self.resume_token + return firestore_pb2.ListenRequest( + database=self._firestore._database_string, add_target=self._targets + ) + @property def is_active(self): """bool: True if this manager is actively streaming. diff --git a/tests/unit/v1/test_watch.py b/tests/unit/v1/test_watch.py index afd88b813..0778717bc 100644 --- a/tests/unit/v1/test_watch.py +++ b/tests/unit/v1/test_watch.py @@ -776,6 +776,12 @@ def test__reset_docs(self): self.assertEqual(inst.resume_token, None) self.assertFalse(inst.current) + def test_resume_token_sent_on_recovery(self): + inst = self._makeOne() + inst.resume_token = b"ABCD0123" + request = inst._get_rpc_request() + self.assertEqual(request.add_target.resume_token, b"ABCD0123") + class DummyFirestoreStub(object): def Listen(self): # pragma: NO COVER @@ -922,7 +928,7 @@ def __init__( self.start_rpc = start_rpc self.should_recover = should_recover self.should_terminate = should_terminate - self.initial_request = initial_request + self.initial_request = initial_request() self.metadata = metadata self.closed = False self.callbacks = []