From 0fcfc2301246d3f20b6fbffc1deae06f16721ec7 Mon Sep 17 00:00:00 2001 From: larkee <31196561+larkee@users.noreply.github.com> Date: Mon, 26 Apr 2021 17:43:21 +1000 Subject: [PATCH] fix: correctly set resume token when restarting streams (#314) * fix: correctly set resume token for restarting streams * style: fix lint * docs: update docstring * test: fix assertion Co-authored-by: larkee --- google/cloud/spanner_v1/database.py | 6 +- google/cloud/spanner_v1/snapshot.py | 26 +++++--- tests/unit/test_snapshot.py | 92 +++++++++++++++++------------ 3 files changed, 76 insertions(+), 48 deletions(-) diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 1e76bf218f..5eb688d9c6 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -518,11 +518,11 @@ def execute_pdml(): param_types=param_types, query_options=query_options, ) - restart = functools.partial( - api.execute_streaming_sql, request=request, metadata=metadata, + method = functools.partial( + api.execute_streaming_sql, metadata=metadata, ) - iterator = _restart_on_unavailable(restart) + iterator = _restart_on_unavailable(method, request) result_set = StreamedResultSet(iterator) list(result_set) # consume all partials diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index 1b3ae8097d..f926d7836d 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -41,16 +41,21 @@ ) -def _restart_on_unavailable(restart, trace_name=None, session=None, attributes=None): +def _restart_on_unavailable( + method, request, trace_name=None, session=None, attributes=None +): """Restart iteration after :exc:`.ServiceUnavailable`. - :type restart: callable - :param restart: curried function returning iterator + :type method: callable + :param method: function returning iterator + + :type request: proto + :param request: request proto to call the method with """ resume_token = b"" item_buffer = [] with trace_call(trace_name, session, attributes): - iterator = restart() + iterator = method(request=request) while True: try: for item in iterator: @@ -61,7 +66,8 @@ def _restart_on_unavailable(restart, trace_name=None, session=None, attributes=N except ServiceUnavailable: del item_buffer[:] with trace_call(trace_name, session, attributes): - iterator = restart(resume_token=resume_token) + request.resume_token = resume_token + iterator = method(request=request) continue except InternalServerError as exc: resumable_error = any( @@ -72,7 +78,8 @@ def _restart_on_unavailable(restart, trace_name=None, session=None, attributes=N raise del item_buffer[:] with trace_call(trace_name, session, attributes): - iterator = restart(resume_token=resume_token) + request.resume_token = resume_token + iterator = method(request=request) continue if len(item_buffer) == 0: @@ -189,7 +196,11 @@ def read( trace_attributes = {"table_id": table, "columns": columns} iterator = _restart_on_unavailable( - restart, "CloudSpanner.ReadOnlyTransaction", self._session, trace_attributes + restart, + request, + "CloudSpanner.ReadOnlyTransaction", + self._session, + trace_attributes, ) self._read_request_count += 1 @@ -302,6 +313,7 @@ def execute_sql( trace_attributes = {"db.statement": sql} iterator = _restart_on_unavailable( restart, + request, "CloudSpanner.ReadWriteTransaction", self._session, trace_attributes, diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index cc9a67cb4d..24f87a30fc 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -47,10 +47,12 @@ class Test_restart_on_unavailable(OpenTelemetryBase): - def _call_fut(self, restart, span_name=None, session=None, attributes=None): + def _call_fut( + self, restart, request, span_name=None, session=None, attributes=None + ): from google.cloud.spanner_v1.snapshot import _restart_on_unavailable - return _restart_on_unavailable(restart, span_name, session, attributes) + return _restart_on_unavailable(restart, request, span_name, session, attributes) def _make_item(self, value, resume_token=b""): return mock.Mock( @@ -59,18 +61,21 @@ def _make_item(self, value, resume_token=b""): def test_iteration_w_empty_raw(self): raw = _MockIterator() + request = mock.Mock(test="test", spec=["test", "resume_token"]) restart = mock.Mock(spec=[], return_value=raw) - resumable = self._call_fut(restart) + resumable = self._call_fut(restart, request) self.assertEqual(list(resumable), []) + restart.assert_called_once_with(request=request) self.assertNoSpans() def test_iteration_w_non_empty_raw(self): ITEMS = (self._make_item(0), self._make_item(1)) raw = _MockIterator(*ITEMS) + request = mock.Mock(test="test", spec=["test", "resume_token"]) restart = mock.Mock(spec=[], return_value=raw) - resumable = self._call_fut(restart) + resumable = self._call_fut(restart, request) self.assertEqual(list(resumable), list(ITEMS)) - restart.assert_called_once_with() + restart.assert_called_once_with(request=request) self.assertNoSpans() def test_iteration_w_raw_w_resume_tken(self): @@ -81,10 +86,11 @@ def test_iteration_w_raw_w_resume_tken(self): self._make_item(3), ) raw = _MockIterator(*ITEMS) + request = mock.Mock(test="test", spec=["test", "resume_token"]) restart = mock.Mock(spec=[], return_value=raw) - resumable = self._call_fut(restart) + resumable = self._call_fut(restart, request) self.assertEqual(list(resumable), list(ITEMS)) - restart.assert_called_once_with() + restart.assert_called_once_with(request=request) self.assertNoSpans() def test_iteration_w_raw_raising_unavailable_no_token(self): @@ -97,10 +103,12 @@ def test_iteration_w_raw_raising_unavailable_no_token(self): ) before = _MockIterator(fail_after=True, error=ServiceUnavailable("testing")) after = _MockIterator(*ITEMS) + request = mock.Mock(test="test", spec=["test", "resume_token"]) restart = mock.Mock(spec=[], side_effect=[before, after]) - resumable = self._call_fut(restart) + resumable = self._call_fut(restart, request) self.assertEqual(list(resumable), list(ITEMS)) - self.assertEqual(restart.mock_calls, [mock.call(), mock.call(resume_token=b"")]) + self.assertEqual(len(restart.mock_calls), 2) + self.assertEqual(request.resume_token, b"") self.assertNoSpans() def test_iteration_w_raw_raising_retryable_internal_error_no_token(self): @@ -118,10 +126,12 @@ def test_iteration_w_raw_raising_retryable_internal_error_no_token(self): ), ) after = _MockIterator(*ITEMS) + request = mock.Mock(test="test", spec=["test", "resume_token"]) restart = mock.Mock(spec=[], side_effect=[before, after]) - resumable = self._call_fut(restart) + resumable = self._call_fut(restart, request) self.assertEqual(list(resumable), list(ITEMS)) - self.assertEqual(restart.mock_calls, [mock.call(), mock.call(resume_token=b"")]) + self.assertEqual(len(restart.mock_calls), 2) + self.assertEqual(request.resume_token, b"") self.assertNoSpans() def test_iteration_w_raw_raising_non_retryable_internal_error_no_token(self): @@ -134,11 +144,12 @@ def test_iteration_w_raw_raising_non_retryable_internal_error_no_token(self): ) before = _MockIterator(fail_after=True, error=InternalServerError("testing")) after = _MockIterator(*ITEMS) + request = mock.Mock(spec=["resume_token"]) restart = mock.Mock(spec=[], side_effect=[before, after]) - resumable = self._call_fut(restart) + resumable = self._call_fut(restart, request) with self.assertRaises(InternalServerError): list(resumable) - self.assertEqual(restart.mock_calls, [mock.call()]) + restart.assert_called_once_with(request=request) self.assertNoSpans() def test_iteration_w_raw_raising_unavailable(self): @@ -151,12 +162,12 @@ def test_iteration_w_raw_raising_unavailable(self): *(FIRST + SECOND), fail_after=True, error=ServiceUnavailable("testing") ) after = _MockIterator(*LAST) + request = mock.Mock(test="test", spec=["test", "resume_token"]) restart = mock.Mock(spec=[], side_effect=[before, after]) - resumable = self._call_fut(restart) + resumable = self._call_fut(restart, request) self.assertEqual(list(resumable), list(FIRST + LAST)) - self.assertEqual( - restart.mock_calls, [mock.call(), mock.call(resume_token=RESUME_TOKEN)] - ) + self.assertEqual(len(restart.mock_calls), 2) + self.assertEqual(request.resume_token, RESUME_TOKEN) self.assertNoSpans() def test_iteration_w_raw_raising_retryable_internal_error(self): @@ -173,12 +184,12 @@ def test_iteration_w_raw_raising_retryable_internal_error(self): ) ) after = _MockIterator(*LAST) + request = mock.Mock(test="test", spec=["test", "resume_token"]) restart = mock.Mock(spec=[], side_effect=[before, after]) - resumable = self._call_fut(restart) + resumable = self._call_fut(restart, request) self.assertEqual(list(resumable), list(FIRST + LAST)) - self.assertEqual( - restart.mock_calls, [mock.call(), mock.call(resume_token=RESUME_TOKEN)] - ) + self.assertEqual(len(restart.mock_calls), 2) + self.assertEqual(request.resume_token, RESUME_TOKEN) self.assertNoSpans() def test_iteration_w_raw_raising_non_retryable_internal_error(self): @@ -191,11 +202,12 @@ def test_iteration_w_raw_raising_non_retryable_internal_error(self): *(FIRST + SECOND), fail_after=True, error=InternalServerError("testing") ) after = _MockIterator(*LAST) + request = mock.Mock(test="test", spec=["test", "resume_token"]) restart = mock.Mock(spec=[], side_effect=[before, after]) - resumable = self._call_fut(restart) + resumable = self._call_fut(restart, request) with self.assertRaises(InternalServerError): list(resumable) - self.assertEqual(restart.mock_calls, [mock.call()]) + restart.assert_called_once_with(request=request) self.assertNoSpans() def test_iteration_w_raw_raising_unavailable_after_token(self): @@ -207,12 +219,12 @@ def test_iteration_w_raw_raising_unavailable_after_token(self): *FIRST, fail_after=True, error=ServiceUnavailable("testing") ) after = _MockIterator(*SECOND) + request = mock.Mock(test="test", spec=["test", "resume_token"]) restart = mock.Mock(spec=[], side_effect=[before, after]) - resumable = self._call_fut(restart) + resumable = self._call_fut(restart, request) self.assertEqual(list(resumable), list(FIRST + SECOND)) - self.assertEqual( - restart.mock_calls, [mock.call(), mock.call(resume_token=RESUME_TOKEN)] - ) + self.assertEqual(len(restart.mock_calls), 2) + self.assertEqual(request.resume_token, RESUME_TOKEN) self.assertNoSpans() def test_iteration_w_raw_raising_retryable_internal_error_after_token(self): @@ -228,12 +240,12 @@ def test_iteration_w_raw_raising_retryable_internal_error_after_token(self): ) ) after = _MockIterator(*SECOND) + request = mock.Mock(test="test", spec=["test", "resume_token"]) restart = mock.Mock(spec=[], side_effect=[before, after]) - resumable = self._call_fut(restart) + resumable = self._call_fut(restart, request) self.assertEqual(list(resumable), list(FIRST + SECOND)) - self.assertEqual( - restart.mock_calls, [mock.call(), mock.call(resume_token=RESUME_TOKEN)] - ) + self.assertEqual(len(restart.mock_calls), 2) + self.assertEqual(request.resume_token, RESUME_TOKEN) self.assertNoSpans() def test_iteration_w_raw_raising_non_retryable_internal_error_after_token(self): @@ -245,19 +257,23 @@ def test_iteration_w_raw_raising_non_retryable_internal_error_after_token(self): *FIRST, fail_after=True, error=InternalServerError("testing") ) after = _MockIterator(*SECOND) + request = mock.Mock(test="test", spec=["test", "resume_token"]) restart = mock.Mock(spec=[], side_effect=[before, after]) - resumable = self._call_fut(restart) + resumable = self._call_fut(restart, request) with self.assertRaises(InternalServerError): list(resumable) - self.assertEqual(restart.mock_calls, [mock.call()]) + restart.assert_called_once_with(request=request) self.assertNoSpans() def test_iteration_w_span_creation(self): name = "TestSpan" extra_atts = {"test_att": 1} raw = _MockIterator() + request = mock.Mock(test="test", spec=["test", "resume_token"]) restart = mock.Mock(spec=[], return_value=raw) - resumable = self._call_fut(restart, name, _Session(_Database()), extra_atts) + resumable = self._call_fut( + restart, request, name, _Session(_Database()), extra_atts + ) self.assertEqual(list(resumable), []) self.assertSpanAttributes(name, attributes=dict(BASE_ATTRIBUTES, test_att=1)) @@ -272,13 +288,13 @@ def test_iteration_w_multiple_span_creation(self): *(FIRST + SECOND), fail_after=True, error=ServiceUnavailable("testing") ) after = _MockIterator(*LAST) + request = mock.Mock(test="test", spec=["test", "resume_token"]) restart = mock.Mock(spec=[], side_effect=[before, after]) name = "TestSpan" - resumable = self._call_fut(restart, name, _Session(_Database())) + resumable = self._call_fut(restart, request, name, _Session(_Database())) self.assertEqual(list(resumable), list(FIRST + LAST)) - self.assertEqual( - restart.mock_calls, [mock.call(), mock.call(resume_token=RESUME_TOKEN)] - ) + self.assertEqual(len(restart.mock_calls), 2) + self.assertEqual(request.resume_token, RESUME_TOKEN) span_list = self.memory_exporter.get_finished_spans() self.assertEqual(len(span_list), 2)