From 29aeb673dc566a757caa2eaad5e5d13f714769de Mon Sep 17 00:00:00 2001 From: larkee Date: Thu, 22 Apr 2021 19:46:39 +1000 Subject: [PATCH 1/4] fix: correctly set resume token for restarting streams --- google/cloud/spanner_v1/database.py | 6 +-- google/cloud/spanner_v1/snapshot.py | 24 ++++++--- tests/unit/test_snapshot.py | 84 +++++++++++++++++------------ 3 files changed, 69 insertions(+), 45 deletions(-) diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 1e76bf218f..5eb688d9c6 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -518,11 +518,11 @@ def execute_pdml(): param_types=param_types, query_options=query_options, ) - restart = functools.partial( - api.execute_streaming_sql, request=request, metadata=metadata, + method = functools.partial( + api.execute_streaming_sql, metadata=metadata, ) - iterator = _restart_on_unavailable(restart) + iterator = _restart_on_unavailable(method, request) result_set = StreamedResultSet(iterator) list(result_set) # consume all partials diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index 1b3ae8097d..60311f5ff1 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -41,16 +41,19 @@ ) -def _restart_on_unavailable(restart, trace_name=None, session=None, attributes=None): +def _restart_on_unavailable(method, request, trace_name=None, session=None, attributes=None): """Restart iteration after :exc:`.ServiceUnavailable`. - :type restart: callable - :param restart: curried function returning iterator + :type method: callable + :param method: curried function returning iterator + + :type request: callable + :param request: curried function returning iterator """ resume_token = b"" item_buffer = [] with trace_call(trace_name, session, attributes): - iterator = restart() + iterator = method(request=request) while True: try: for item in iterator: @@ -61,7 +64,8 @@ def _restart_on_unavailable(restart, trace_name=None, session=None, attributes=N except ServiceUnavailable: del item_buffer[:] with trace_call(trace_name, session, attributes): - iterator = restart(resume_token=resume_token) + request.resume_token = resume_token + iterator = method(request=request) continue except InternalServerError as exc: resumable_error = any( @@ -72,7 +76,8 @@ def _restart_on_unavailable(restart, trace_name=None, session=None, attributes=N raise del item_buffer[:] with trace_call(trace_name, session, attributes): - iterator = restart(resume_token=resume_token) + request.resume_token = resume_token + iterator = method(request=request) continue if len(item_buffer) == 0: @@ -189,7 +194,11 @@ def read( trace_attributes = {"table_id": table, "columns": columns} iterator = _restart_on_unavailable( - restart, "CloudSpanner.ReadOnlyTransaction", self._session, trace_attributes + restart, + request, + "CloudSpanner.ReadOnlyTransaction", + self._session, + trace_attributes ) self._read_request_count += 1 @@ -302,6 +311,7 @@ def execute_sql( trace_attributes = {"db.statement": sql} iterator = _restart_on_unavailable( restart, + request, "CloudSpanner.ReadWriteTransaction", self._session, trace_attributes, diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index cc9a67cb4d..da129e69c0 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -47,10 +47,10 @@ class Test_restart_on_unavailable(OpenTelemetryBase): - def _call_fut(self, restart, span_name=None, session=None, attributes=None): + def _call_fut(self, restart, request, span_name=None, session=None, attributes=None): from google.cloud.spanner_v1.snapshot import _restart_on_unavailable - return _restart_on_unavailable(restart, span_name, session, attributes) + return _restart_on_unavailable(restart, request, span_name, session, attributes) def _make_item(self, value, resume_token=b""): return mock.Mock( @@ -59,18 +59,21 @@ def _make_item(self, value, resume_token=b""): def test_iteration_w_empty_raw(self): raw = _MockIterator() + request = mock.Mock(test="test", spec=['test', 'resume_token']) restart = mock.Mock(spec=[], return_value=raw) - resumable = self._call_fut(restart) + resumable = self._call_fut(restart, request) self.assertEqual(list(resumable), []) + restart.assert_called_once_with(request=request) self.assertNoSpans() def test_iteration_w_non_empty_raw(self): ITEMS = (self._make_item(0), self._make_item(1)) raw = _MockIterator(*ITEMS) + request = mock.Mock(test="test", spec=['test', 'resume_token']) restart = mock.Mock(spec=[], return_value=raw) - resumable = self._call_fut(restart) + resumable = self._call_fut(restart, request) self.assertEqual(list(resumable), list(ITEMS)) - restart.assert_called_once_with() + restart.assert_called_once_with(request=request) self.assertNoSpans() def test_iteration_w_raw_w_resume_tken(self): @@ -81,10 +84,11 @@ def test_iteration_w_raw_w_resume_tken(self): self._make_item(3), ) raw = _MockIterator(*ITEMS) + request = mock.Mock(test="test", spec=['test', 'resume_token']) restart = mock.Mock(spec=[], return_value=raw) - resumable = self._call_fut(restart) + resumable = self._call_fut(restart, request) self.assertEqual(list(resumable), list(ITEMS)) - restart.assert_called_once_with() + restart.assert_called_once_with(request=request) self.assertNoSpans() def test_iteration_w_raw_raising_unavailable_no_token(self): @@ -97,10 +101,12 @@ def test_iteration_w_raw_raising_unavailable_no_token(self): ) before = _MockIterator(fail_after=True, error=ServiceUnavailable("testing")) after = _MockIterator(*ITEMS) + request = mock.Mock(test="test", spec=['test', 'resume_token']) restart = mock.Mock(spec=[], side_effect=[before, after]) - resumable = self._call_fut(restart) + resumable = self._call_fut(restart, request) self.assertEqual(list(resumable), list(ITEMS)) - self.assertEqual(restart.mock_calls, [mock.call(), mock.call(resume_token=b"")]) + self.assertEqual(len(restart.mock_calls), 2) + self.assertEqual(request.resume_token, b'') self.assertNoSpans() def test_iteration_w_raw_raising_retryable_internal_error_no_token(self): @@ -118,10 +124,12 @@ def test_iteration_w_raw_raising_retryable_internal_error_no_token(self): ), ) after = _MockIterator(*ITEMS) + request = mock.Mock(test="test", spec=['test', 'resume_token']) restart = mock.Mock(spec=[], side_effect=[before, after]) - resumable = self._call_fut(restart) + resumable = self._call_fut(restart, request) self.assertEqual(list(resumable), list(ITEMS)) - self.assertEqual(restart.mock_calls, [mock.call(), mock.call(resume_token=b"")]) + self.assertEqual(len(restart.mock_calls), 2) + self.assertEqual(request.resume_token, b'') self.assertNoSpans() def test_iteration_w_raw_raising_non_retryable_internal_error_no_token(self): @@ -134,11 +142,12 @@ def test_iteration_w_raw_raising_non_retryable_internal_error_no_token(self): ) before = _MockIterator(fail_after=True, error=InternalServerError("testing")) after = _MockIterator(*ITEMS) + request = mock.Mock(spec=['resume_token']) restart = mock.Mock(spec=[], side_effect=[before, after]) - resumable = self._call_fut(restart) + resumable = self._call_fut(restart, request) with self.assertRaises(InternalServerError): list(resumable) - self.assertEqual(restart.mock_calls, [mock.call()]) + restart.assert_called_once_with(request=request) self.assertNoSpans() def test_iteration_w_raw_raising_unavailable(self): @@ -151,12 +160,12 @@ def test_iteration_w_raw_raising_unavailable(self): *(FIRST + SECOND), fail_after=True, error=ServiceUnavailable("testing") ) after = _MockIterator(*LAST) + request = mock.Mock(test="test", spec=['test', 'resume_token']) restart = mock.Mock(spec=[], side_effect=[before, after]) - resumable = self._call_fut(restart) + resumable = self._call_fut(restart, request) self.assertEqual(list(resumable), list(FIRST + LAST)) - self.assertEqual( - restart.mock_calls, [mock.call(), mock.call(resume_token=RESUME_TOKEN)] - ) + self.assertEqual(len(restart.mock_calls), 2) + self.assertEqual(request.resume_token, RESUME_TOKEN) self.assertNoSpans() def test_iteration_w_raw_raising_retryable_internal_error(self): @@ -173,12 +182,12 @@ def test_iteration_w_raw_raising_retryable_internal_error(self): ) ) after = _MockIterator(*LAST) + request = mock.Mock(test="test", spec=['test', 'resume_token']) restart = mock.Mock(spec=[], side_effect=[before, after]) - resumable = self._call_fut(restart) + resumable = self._call_fut(restart, request) self.assertEqual(list(resumable), list(FIRST + LAST)) - self.assertEqual( - restart.mock_calls, [mock.call(), mock.call(resume_token=RESUME_TOKEN)] - ) + self.assertEqual(len(restart.mock_calls), 2) + self.assertEqual(request.resume_token, RESUME_TOKEN) self.assertNoSpans() def test_iteration_w_raw_raising_non_retryable_internal_error(self): @@ -191,11 +200,12 @@ def test_iteration_w_raw_raising_non_retryable_internal_error(self): *(FIRST + SECOND), fail_after=True, error=InternalServerError("testing") ) after = _MockIterator(*LAST) + request = mock.Mock(test="test", spec=['test', 'resume_token']) restart = mock.Mock(spec=[], side_effect=[before, after]) - resumable = self._call_fut(restart) + resumable = self._call_fut(restart, request) with self.assertRaises(InternalServerError): list(resumable) - self.assertEqual(restart.mock_calls, [mock.call()]) + restart.assert_called_once_with(request=request) self.assertNoSpans() def test_iteration_w_raw_raising_unavailable_after_token(self): @@ -207,12 +217,12 @@ def test_iteration_w_raw_raising_unavailable_after_token(self): *FIRST, fail_after=True, error=ServiceUnavailable("testing") ) after = _MockIterator(*SECOND) + request = mock.Mock(test="test", spec=['test', 'resume_token']) restart = mock.Mock(spec=[], side_effect=[before, after]) - resumable = self._call_fut(restart) + resumable = self._call_fut(restart, request) self.assertEqual(list(resumable), list(FIRST + SECOND)) - self.assertEqual( - restart.mock_calls, [mock.call(), mock.call(resume_token=RESUME_TOKEN)] - ) + self.assertEqual(len(restart.mock_calls), 2) + self.assertEqual(request.resume_token, RESUME_TOKEN) self.assertNoSpans() def test_iteration_w_raw_raising_retryable_internal_error_after_token(self): @@ -228,12 +238,12 @@ def test_iteration_w_raw_raising_retryable_internal_error_after_token(self): ) ) after = _MockIterator(*SECOND) + request = mock.Mock(test="test", spec=['test', 'resume_token']) restart = mock.Mock(spec=[], side_effect=[before, after]) - resumable = self._call_fut(restart) + resumable = self._call_fut(restart, request) self.assertEqual(list(resumable), list(FIRST + SECOND)) - self.assertEqual( - restart.mock_calls, [mock.call(), mock.call(resume_token=RESUME_TOKEN)] - ) + self.assertEqual(len(restart.mock_calls), 2) + self.assertEqual(request.resume_token, RESUME_TOKEN) self.assertNoSpans() def test_iteration_w_raw_raising_non_retryable_internal_error_after_token(self): @@ -245,19 +255,22 @@ def test_iteration_w_raw_raising_non_retryable_internal_error_after_token(self): *FIRST, fail_after=True, error=InternalServerError("testing") ) after = _MockIterator(*SECOND) + request = mock.Mock(test="test", spec=['test', 'resume_token']) restart = mock.Mock(spec=[], side_effect=[before, after]) - resumable = self._call_fut(restart) + resumable = self._call_fut(restart, request) with self.assertRaises(InternalServerError): list(resumable) - self.assertEqual(restart.mock_calls, [mock.call()]) + restart.assert_called_once_with(request=request) self.assertNoSpans() def test_iteration_w_span_creation(self): name = "TestSpan" extra_atts = {"test_att": 1} raw = _MockIterator() + request = mock.Mock(test="test", spec=['test', 'resume_token']) restart = mock.Mock(spec=[], return_value=raw) - resumable = self._call_fut(restart, name, _Session(_Database()), extra_atts) + resumable = self._call_fut( + restart, request, name, _Session(_Database()), extra_atts) self.assertEqual(list(resumable), []) self.assertSpanAttributes(name, attributes=dict(BASE_ATTRIBUTES, test_att=1)) @@ -272,9 +285,10 @@ def test_iteration_w_multiple_span_creation(self): *(FIRST + SECOND), fail_after=True, error=ServiceUnavailable("testing") ) after = _MockIterator(*LAST) + request = mock.Mock(test="test", spec=['test', 'resume_token']) restart = mock.Mock(spec=[], side_effect=[before, after]) name = "TestSpan" - resumable = self._call_fut(restart, name, _Session(_Database())) + resumable = self._call_fut(restart, request, name, _Session(_Database())) self.assertEqual(list(resumable), list(FIRST + LAST)) self.assertEqual( restart.mock_calls, [mock.call(), mock.call(resume_token=RESUME_TOKEN)] From 713bcf5d228cb191cd0fe32703d8e84c48091558 Mon Sep 17 00:00:00 2001 From: larkee Date: Thu, 22 Apr 2021 19:47:23 +1000 Subject: [PATCH 2/4] style: fix lint --- google/cloud/spanner_v1/snapshot.py | 6 +++-- tests/unit/test_snapshot.py | 39 ++++++++++++++++------------- 2 files changed, 25 insertions(+), 20 deletions(-) diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index 60311f5ff1..feb2dbcc45 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -41,7 +41,9 @@ ) -def _restart_on_unavailable(method, request, trace_name=None, session=None, attributes=None): +def _restart_on_unavailable( + method, request, trace_name=None, session=None, attributes=None +): """Restart iteration after :exc:`.ServiceUnavailable`. :type method: callable @@ -198,7 +200,7 @@ def read( request, "CloudSpanner.ReadOnlyTransaction", self._session, - trace_attributes + trace_attributes, ) self._read_request_count += 1 diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index da129e69c0..666caa94be 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -47,7 +47,9 @@ class Test_restart_on_unavailable(OpenTelemetryBase): - def _call_fut(self, restart, request, span_name=None, session=None, attributes=None): + def _call_fut( + self, restart, request, span_name=None, session=None, attributes=None + ): from google.cloud.spanner_v1.snapshot import _restart_on_unavailable return _restart_on_unavailable(restart, request, span_name, session, attributes) @@ -59,7 +61,7 @@ def _make_item(self, value, resume_token=b""): def test_iteration_w_empty_raw(self): raw = _MockIterator() - request = mock.Mock(test="test", spec=['test', 'resume_token']) + request = mock.Mock(test="test", spec=["test", "resume_token"]) restart = mock.Mock(spec=[], return_value=raw) resumable = self._call_fut(restart, request) self.assertEqual(list(resumable), []) @@ -69,7 +71,7 @@ def test_iteration_w_empty_raw(self): def test_iteration_w_non_empty_raw(self): ITEMS = (self._make_item(0), self._make_item(1)) raw = _MockIterator(*ITEMS) - request = mock.Mock(test="test", spec=['test', 'resume_token']) + request = mock.Mock(test="test", spec=["test", "resume_token"]) restart = mock.Mock(spec=[], return_value=raw) resumable = self._call_fut(restart, request) self.assertEqual(list(resumable), list(ITEMS)) @@ -84,7 +86,7 @@ def test_iteration_w_raw_w_resume_tken(self): self._make_item(3), ) raw = _MockIterator(*ITEMS) - request = mock.Mock(test="test", spec=['test', 'resume_token']) + request = mock.Mock(test="test", spec=["test", "resume_token"]) restart = mock.Mock(spec=[], return_value=raw) resumable = self._call_fut(restart, request) self.assertEqual(list(resumable), list(ITEMS)) @@ -101,12 +103,12 @@ def test_iteration_w_raw_raising_unavailable_no_token(self): ) before = _MockIterator(fail_after=True, error=ServiceUnavailable("testing")) after = _MockIterator(*ITEMS) - request = mock.Mock(test="test", spec=['test', 'resume_token']) + request = mock.Mock(test="test", spec=["test", "resume_token"]) restart = mock.Mock(spec=[], side_effect=[before, after]) resumable = self._call_fut(restart, request) self.assertEqual(list(resumable), list(ITEMS)) self.assertEqual(len(restart.mock_calls), 2) - self.assertEqual(request.resume_token, b'') + self.assertEqual(request.resume_token, b"") self.assertNoSpans() def test_iteration_w_raw_raising_retryable_internal_error_no_token(self): @@ -124,12 +126,12 @@ def test_iteration_w_raw_raising_retryable_internal_error_no_token(self): ), ) after = _MockIterator(*ITEMS) - request = mock.Mock(test="test", spec=['test', 'resume_token']) + request = mock.Mock(test="test", spec=["test", "resume_token"]) restart = mock.Mock(spec=[], side_effect=[before, after]) resumable = self._call_fut(restart, request) self.assertEqual(list(resumable), list(ITEMS)) self.assertEqual(len(restart.mock_calls), 2) - self.assertEqual(request.resume_token, b'') + self.assertEqual(request.resume_token, b"") self.assertNoSpans() def test_iteration_w_raw_raising_non_retryable_internal_error_no_token(self): @@ -142,7 +144,7 @@ def test_iteration_w_raw_raising_non_retryable_internal_error_no_token(self): ) before = _MockIterator(fail_after=True, error=InternalServerError("testing")) after = _MockIterator(*ITEMS) - request = mock.Mock(spec=['resume_token']) + request = mock.Mock(spec=["resume_token"]) restart = mock.Mock(spec=[], side_effect=[before, after]) resumable = self._call_fut(restart, request) with self.assertRaises(InternalServerError): @@ -160,7 +162,7 @@ def test_iteration_w_raw_raising_unavailable(self): *(FIRST + SECOND), fail_after=True, error=ServiceUnavailable("testing") ) after = _MockIterator(*LAST) - request = mock.Mock(test="test", spec=['test', 'resume_token']) + request = mock.Mock(test="test", spec=["test", "resume_token"]) restart = mock.Mock(spec=[], side_effect=[before, after]) resumable = self._call_fut(restart, request) self.assertEqual(list(resumable), list(FIRST + LAST)) @@ -182,7 +184,7 @@ def test_iteration_w_raw_raising_retryable_internal_error(self): ) ) after = _MockIterator(*LAST) - request = mock.Mock(test="test", spec=['test', 'resume_token']) + request = mock.Mock(test="test", spec=["test", "resume_token"]) restart = mock.Mock(spec=[], side_effect=[before, after]) resumable = self._call_fut(restart, request) self.assertEqual(list(resumable), list(FIRST + LAST)) @@ -200,7 +202,7 @@ def test_iteration_w_raw_raising_non_retryable_internal_error(self): *(FIRST + SECOND), fail_after=True, error=InternalServerError("testing") ) after = _MockIterator(*LAST) - request = mock.Mock(test="test", spec=['test', 'resume_token']) + request = mock.Mock(test="test", spec=["test", "resume_token"]) restart = mock.Mock(spec=[], side_effect=[before, after]) resumable = self._call_fut(restart, request) with self.assertRaises(InternalServerError): @@ -217,7 +219,7 @@ def test_iteration_w_raw_raising_unavailable_after_token(self): *FIRST, fail_after=True, error=ServiceUnavailable("testing") ) after = _MockIterator(*SECOND) - request = mock.Mock(test="test", spec=['test', 'resume_token']) + request = mock.Mock(test="test", spec=["test", "resume_token"]) restart = mock.Mock(spec=[], side_effect=[before, after]) resumable = self._call_fut(restart, request) self.assertEqual(list(resumable), list(FIRST + SECOND)) @@ -238,7 +240,7 @@ def test_iteration_w_raw_raising_retryable_internal_error_after_token(self): ) ) after = _MockIterator(*SECOND) - request = mock.Mock(test="test", spec=['test', 'resume_token']) + request = mock.Mock(test="test", spec=["test", "resume_token"]) restart = mock.Mock(spec=[], side_effect=[before, after]) resumable = self._call_fut(restart, request) self.assertEqual(list(resumable), list(FIRST + SECOND)) @@ -255,7 +257,7 @@ def test_iteration_w_raw_raising_non_retryable_internal_error_after_token(self): *FIRST, fail_after=True, error=InternalServerError("testing") ) after = _MockIterator(*SECOND) - request = mock.Mock(test="test", spec=['test', 'resume_token']) + request = mock.Mock(test="test", spec=["test", "resume_token"]) restart = mock.Mock(spec=[], side_effect=[before, after]) resumable = self._call_fut(restart, request) with self.assertRaises(InternalServerError): @@ -267,10 +269,11 @@ def test_iteration_w_span_creation(self): name = "TestSpan" extra_atts = {"test_att": 1} raw = _MockIterator() - request = mock.Mock(test="test", spec=['test', 'resume_token']) + request = mock.Mock(test="test", spec=["test", "resume_token"]) restart = mock.Mock(spec=[], return_value=raw) resumable = self._call_fut( - restart, request, name, _Session(_Database()), extra_atts) + restart, request, name, _Session(_Database()), extra_atts + ) self.assertEqual(list(resumable), []) self.assertSpanAttributes(name, attributes=dict(BASE_ATTRIBUTES, test_att=1)) @@ -285,7 +288,7 @@ def test_iteration_w_multiple_span_creation(self): *(FIRST + SECOND), fail_after=True, error=ServiceUnavailable("testing") ) after = _MockIterator(*LAST) - request = mock.Mock(test="test", spec=['test', 'resume_token']) + request = mock.Mock(test="test", spec=["test", "resume_token"]) restart = mock.Mock(spec=[], side_effect=[before, after]) name = "TestSpan" resumable = self._call_fut(restart, request, name, _Session(_Database())) From 6b2a3f79b816c96d65b567080803e557cda6b39b Mon Sep 17 00:00:00 2001 From: larkee Date: Thu, 22 Apr 2021 20:10:14 +1000 Subject: [PATCH 3/4] docs: update docstring --- google/cloud/spanner_v1/snapshot.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index feb2dbcc45..f926d7836d 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -47,10 +47,10 @@ def _restart_on_unavailable( """Restart iteration after :exc:`.ServiceUnavailable`. :type method: callable - :param method: curried function returning iterator + :param method: function returning iterator - :type request: callable - :param request: curried function returning iterator + :type request: proto + :param request: request proto to call the method with """ resume_token = b"" item_buffer = [] From d167c9698f2cf480c1abdf0aaafa18d7bd834218 Mon Sep 17 00:00:00 2001 From: larkee Date: Fri, 23 Apr 2021 11:47:26 +1000 Subject: [PATCH 4/4] test: fix assertion --- tests/unit/test_snapshot.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index 666caa94be..24f87a30fc 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -293,9 +293,8 @@ def test_iteration_w_multiple_span_creation(self): name = "TestSpan" resumable = self._call_fut(restart, request, name, _Session(_Database())) self.assertEqual(list(resumable), list(FIRST + LAST)) - self.assertEqual( - restart.mock_calls, [mock.call(), mock.call(resume_token=RESUME_TOKEN)] - ) + self.assertEqual(len(restart.mock_calls), 2) + self.assertEqual(request.resume_token, RESUME_TOKEN) span_list = self.memory_exporter.get_finished_spans() self.assertEqual(len(span_list), 2)