Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: correctly set resume token when restarting streams #314

Merged
merged 4 commits into from Apr 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 3 additions & 3 deletions google/cloud/spanner_v1/database.py
Expand Up @@ -518,11 +518,11 @@ def execute_pdml():
param_types=param_types,
query_options=query_options,
)
restart = functools.partial(
api.execute_streaming_sql, request=request, metadata=metadata,
method = functools.partial(
api.execute_streaming_sql, metadata=metadata,
)

iterator = _restart_on_unavailable(restart)
iterator = _restart_on_unavailable(method, request)

result_set = StreamedResultSet(iterator)
list(result_set) # consume all partials
Expand Down
26 changes: 19 additions & 7 deletions google/cloud/spanner_v1/snapshot.py
Expand Up @@ -41,16 +41,21 @@
)


def _restart_on_unavailable(restart, trace_name=None, session=None, attributes=None):
def _restart_on_unavailable(
method, request, trace_name=None, session=None, attributes=None
):
"""Restart iteration after :exc:`.ServiceUnavailable`.

:type restart: callable
:param restart: curried function returning iterator
:type method: callable
:param method: function returning iterator

:type request: proto
:param request: request proto to call the method with
"""
resume_token = b""
item_buffer = []
with trace_call(trace_name, session, attributes):
iterator = restart()
iterator = method(request=request)
while True:
try:
for item in iterator:
Expand All @@ -61,7 +66,8 @@ def _restart_on_unavailable(restart, trace_name=None, session=None, attributes=N
except ServiceUnavailable:
del item_buffer[:]
with trace_call(trace_name, session, attributes):
iterator = restart(resume_token=resume_token)
request.resume_token = resume_token
iterator = method(request=request)
continue
except InternalServerError as exc:
resumable_error = any(
Expand All @@ -72,7 +78,8 @@ def _restart_on_unavailable(restart, trace_name=None, session=None, attributes=N
raise
del item_buffer[:]
with trace_call(trace_name, session, attributes):
iterator = restart(resume_token=resume_token)
request.resume_token = resume_token
iterator = method(request=request)
continue

if len(item_buffer) == 0:
Expand Down Expand Up @@ -189,7 +196,11 @@ def read(

trace_attributes = {"table_id": table, "columns": columns}
iterator = _restart_on_unavailable(
restart, "CloudSpanner.ReadOnlyTransaction", self._session, trace_attributes
restart,
request,
"CloudSpanner.ReadOnlyTransaction",
self._session,
trace_attributes,
)

self._read_request_count += 1
Expand Down Expand Up @@ -302,6 +313,7 @@ def execute_sql(
trace_attributes = {"db.statement": sql}
iterator = _restart_on_unavailable(
restart,
request,
"CloudSpanner.ReadWriteTransaction",
self._session,
trace_attributes,
Expand Down
92 changes: 54 additions & 38 deletions tests/unit/test_snapshot.py
Expand Up @@ -47,10 +47,12 @@


class Test_restart_on_unavailable(OpenTelemetryBase):
def _call_fut(self, restart, span_name=None, session=None, attributes=None):
def _call_fut(
self, restart, request, span_name=None, session=None, attributes=None
):
from google.cloud.spanner_v1.snapshot import _restart_on_unavailable

return _restart_on_unavailable(restart, span_name, session, attributes)
return _restart_on_unavailable(restart, request, span_name, session, attributes)

def _make_item(self, value, resume_token=b""):
return mock.Mock(
Expand All @@ -59,18 +61,21 @@ def _make_item(self, value, resume_token=b""):

def test_iteration_w_empty_raw(self):
raw = _MockIterator()
request = mock.Mock(test="test", spec=["test", "resume_token"])
restart = mock.Mock(spec=[], return_value=raw)
resumable = self._call_fut(restart)
resumable = self._call_fut(restart, request)
self.assertEqual(list(resumable), [])
restart.assert_called_once_with(request=request)
self.assertNoSpans()

def test_iteration_w_non_empty_raw(self):
ITEMS = (self._make_item(0), self._make_item(1))
raw = _MockIterator(*ITEMS)
request = mock.Mock(test="test", spec=["test", "resume_token"])
restart = mock.Mock(spec=[], return_value=raw)
resumable = self._call_fut(restart)
resumable = self._call_fut(restart, request)
self.assertEqual(list(resumable), list(ITEMS))
restart.assert_called_once_with()
restart.assert_called_once_with(request=request)
self.assertNoSpans()

def test_iteration_w_raw_w_resume_tken(self):
Expand All @@ -81,10 +86,11 @@ def test_iteration_w_raw_w_resume_tken(self):
self._make_item(3),
)
raw = _MockIterator(*ITEMS)
request = mock.Mock(test="test", spec=["test", "resume_token"])
restart = mock.Mock(spec=[], return_value=raw)
resumable = self._call_fut(restart)
resumable = self._call_fut(restart, request)
self.assertEqual(list(resumable), list(ITEMS))
restart.assert_called_once_with()
restart.assert_called_once_with(request=request)
self.assertNoSpans()

def test_iteration_w_raw_raising_unavailable_no_token(self):
Expand All @@ -97,10 +103,12 @@ def test_iteration_w_raw_raising_unavailable_no_token(self):
)
before = _MockIterator(fail_after=True, error=ServiceUnavailable("testing"))
after = _MockIterator(*ITEMS)
request = mock.Mock(test="test", spec=["test", "resume_token"])
restart = mock.Mock(spec=[], side_effect=[before, after])
resumable = self._call_fut(restart)
resumable = self._call_fut(restart, request)
self.assertEqual(list(resumable), list(ITEMS))
self.assertEqual(restart.mock_calls, [mock.call(), mock.call(resume_token=b"")])
self.assertEqual(len(restart.mock_calls), 2)
self.assertEqual(request.resume_token, b"")
self.assertNoSpans()

def test_iteration_w_raw_raising_retryable_internal_error_no_token(self):
Expand All @@ -118,10 +126,12 @@ def test_iteration_w_raw_raising_retryable_internal_error_no_token(self):
),
)
after = _MockIterator(*ITEMS)
request = mock.Mock(test="test", spec=["test", "resume_token"])
restart = mock.Mock(spec=[], side_effect=[before, after])
resumable = self._call_fut(restart)
resumable = self._call_fut(restart, request)
self.assertEqual(list(resumable), list(ITEMS))
self.assertEqual(restart.mock_calls, [mock.call(), mock.call(resume_token=b"")])
self.assertEqual(len(restart.mock_calls), 2)
self.assertEqual(request.resume_token, b"")
self.assertNoSpans()

def test_iteration_w_raw_raising_non_retryable_internal_error_no_token(self):
Expand All @@ -134,11 +144,12 @@ def test_iteration_w_raw_raising_non_retryable_internal_error_no_token(self):
)
before = _MockIterator(fail_after=True, error=InternalServerError("testing"))
after = _MockIterator(*ITEMS)
request = mock.Mock(spec=["resume_token"])
restart = mock.Mock(spec=[], side_effect=[before, after])
resumable = self._call_fut(restart)
resumable = self._call_fut(restart, request)
with self.assertRaises(InternalServerError):
list(resumable)
self.assertEqual(restart.mock_calls, [mock.call()])
restart.assert_called_once_with(request=request)
self.assertNoSpans()

def test_iteration_w_raw_raising_unavailable(self):
Expand All @@ -151,12 +162,12 @@ def test_iteration_w_raw_raising_unavailable(self):
*(FIRST + SECOND), fail_after=True, error=ServiceUnavailable("testing")
)
after = _MockIterator(*LAST)
request = mock.Mock(test="test", spec=["test", "resume_token"])
restart = mock.Mock(spec=[], side_effect=[before, after])
resumable = self._call_fut(restart)
resumable = self._call_fut(restart, request)
self.assertEqual(list(resumable), list(FIRST + LAST))
self.assertEqual(
restart.mock_calls, [mock.call(), mock.call(resume_token=RESUME_TOKEN)]
)
self.assertEqual(len(restart.mock_calls), 2)
self.assertEqual(request.resume_token, RESUME_TOKEN)
self.assertNoSpans()

def test_iteration_w_raw_raising_retryable_internal_error(self):
Expand All @@ -173,12 +184,12 @@ def test_iteration_w_raw_raising_retryable_internal_error(self):
)
)
after = _MockIterator(*LAST)
request = mock.Mock(test="test", spec=["test", "resume_token"])
restart = mock.Mock(spec=[], side_effect=[before, after])
resumable = self._call_fut(restart)
resumable = self._call_fut(restart, request)
self.assertEqual(list(resumable), list(FIRST + LAST))
self.assertEqual(
restart.mock_calls, [mock.call(), mock.call(resume_token=RESUME_TOKEN)]
)
self.assertEqual(len(restart.mock_calls), 2)
self.assertEqual(request.resume_token, RESUME_TOKEN)
self.assertNoSpans()

def test_iteration_w_raw_raising_non_retryable_internal_error(self):
Expand All @@ -191,11 +202,12 @@ def test_iteration_w_raw_raising_non_retryable_internal_error(self):
*(FIRST + SECOND), fail_after=True, error=InternalServerError("testing")
)
after = _MockIterator(*LAST)
request = mock.Mock(test="test", spec=["test", "resume_token"])
restart = mock.Mock(spec=[], side_effect=[before, after])
resumable = self._call_fut(restart)
resumable = self._call_fut(restart, request)
with self.assertRaises(InternalServerError):
list(resumable)
self.assertEqual(restart.mock_calls, [mock.call()])
restart.assert_called_once_with(request=request)
self.assertNoSpans()

def test_iteration_w_raw_raising_unavailable_after_token(self):
Expand All @@ -207,12 +219,12 @@ def test_iteration_w_raw_raising_unavailable_after_token(self):
*FIRST, fail_after=True, error=ServiceUnavailable("testing")
)
after = _MockIterator(*SECOND)
request = mock.Mock(test="test", spec=["test", "resume_token"])
restart = mock.Mock(spec=[], side_effect=[before, after])
resumable = self._call_fut(restart)
resumable = self._call_fut(restart, request)
self.assertEqual(list(resumable), list(FIRST + SECOND))
self.assertEqual(
restart.mock_calls, [mock.call(), mock.call(resume_token=RESUME_TOKEN)]
)
self.assertEqual(len(restart.mock_calls), 2)
self.assertEqual(request.resume_token, RESUME_TOKEN)
self.assertNoSpans()

def test_iteration_w_raw_raising_retryable_internal_error_after_token(self):
Expand All @@ -228,12 +240,12 @@ def test_iteration_w_raw_raising_retryable_internal_error_after_token(self):
)
)
after = _MockIterator(*SECOND)
request = mock.Mock(test="test", spec=["test", "resume_token"])
restart = mock.Mock(spec=[], side_effect=[before, after])
resumable = self._call_fut(restart)
resumable = self._call_fut(restart, request)
self.assertEqual(list(resumable), list(FIRST + SECOND))
self.assertEqual(
restart.mock_calls, [mock.call(), mock.call(resume_token=RESUME_TOKEN)]
)
self.assertEqual(len(restart.mock_calls), 2)
self.assertEqual(request.resume_token, RESUME_TOKEN)
self.assertNoSpans()

def test_iteration_w_raw_raising_non_retryable_internal_error_after_token(self):
Expand All @@ -245,19 +257,23 @@ def test_iteration_w_raw_raising_non_retryable_internal_error_after_token(self):
*FIRST, fail_after=True, error=InternalServerError("testing")
)
after = _MockIterator(*SECOND)
request = mock.Mock(test="test", spec=["test", "resume_token"])
restart = mock.Mock(spec=[], side_effect=[before, after])
resumable = self._call_fut(restart)
resumable = self._call_fut(restart, request)
with self.assertRaises(InternalServerError):
list(resumable)
self.assertEqual(restart.mock_calls, [mock.call()])
restart.assert_called_once_with(request=request)
self.assertNoSpans()

def test_iteration_w_span_creation(self):
name = "TestSpan"
extra_atts = {"test_att": 1}
raw = _MockIterator()
request = mock.Mock(test="test", spec=["test", "resume_token"])
restart = mock.Mock(spec=[], return_value=raw)
resumable = self._call_fut(restart, name, _Session(_Database()), extra_atts)
resumable = self._call_fut(
restart, request, name, _Session(_Database()), extra_atts
)
self.assertEqual(list(resumable), [])
self.assertSpanAttributes(name, attributes=dict(BASE_ATTRIBUTES, test_att=1))

Expand All @@ -272,13 +288,13 @@ def test_iteration_w_multiple_span_creation(self):
*(FIRST + SECOND), fail_after=True, error=ServiceUnavailable("testing")
)
after = _MockIterator(*LAST)
request = mock.Mock(test="test", spec=["test", "resume_token"])
restart = mock.Mock(spec=[], side_effect=[before, after])
name = "TestSpan"
resumable = self._call_fut(restart, name, _Session(_Database()))
resumable = self._call_fut(restart, request, name, _Session(_Database()))
self.assertEqual(list(resumable), list(FIRST + LAST))
self.assertEqual(
restart.mock_calls, [mock.call(), mock.call(resume_token=RESUME_TOKEN)]
)
self.assertEqual(len(restart.mock_calls), 2)
self.assertEqual(request.resume_token, RESUME_TOKEN)

span_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(span_list), 2)
Expand Down