From 6d597a1e5be05c993c9f86beca4c1486342caf94 Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Thu, 12 Nov 2020 15:58:51 -0500 Subject: [PATCH] feat: add 'timeout' arg to 'Table.mutate_rows' (#157) Also, call data client's 'mutate_rows' directly -- do *not* scribble on its internal API wrappers. See: https://github.com/googleapis/python-bigtable/issues/7#issuecomment-715538708 Closes #7 --- google/cloud/bigtable/table.py | 67 ++++++----- tests/unit/test_table.py | 196 +++++++++++++++++++++++---------- 2 files changed, 172 insertions(+), 91 deletions(-) diff --git a/google/cloud/bigtable/table.py b/google/cloud/bigtable/table.py index 950a8c3fe..35ca43d24 100644 --- a/google/cloud/bigtable/table.py +++ b/google/cloud/bigtable/table.py @@ -20,9 +20,9 @@ from google.api_core.exceptions import NotFound from google.api_core.exceptions import RetryError from google.api_core.exceptions import ServiceUnavailable +from google.api_core.gapic_v1.method import DEFAULT from google.api_core.retry import if_exception_type from google.api_core.retry import Retry -from google.api_core.gapic_v1.method import wrap_method from google.cloud._helpers import _to_bytes from google.cloud.bigtable.backup import Backup from google.cloud.bigtable.column_family import _gc_rule_from_pb @@ -625,7 +625,7 @@ def yield_rows(self, **kwargs): ) return self.read_rows(**kwargs) - def mutate_rows(self, rows, retry=DEFAULT_RETRY): + def mutate_rows(self, rows, retry=DEFAULT_RETRY, timeout=DEFAULT): """Mutates multiple rows in bulk. For example: @@ -656,17 +656,23 @@ def mutate_rows(self, rows, retry=DEFAULT_RETRY): the :meth:`~google.api_core.retry.Retry.with_delay` method or the :meth:`~google.api_core.retry.Retry.with_deadline` method. + :type timeout: float + :param timeout: number of seconds bounding retries for the call + :rtype: list :returns: A list of response statuses (`google.rpc.status_pb2.Status`) corresponding to success or failure of each row mutation sent. These will be in the same order as the `rows`. """ + if timeout is DEFAULT: + timeout = self.mutation_timeout + retryable_mutate_rows = _RetryableMutateRowsWorker( self._instance._client, self.name, rows, app_profile_id=self._app_profile_id, - timeout=self.mutation_timeout, + timeout=timeout, ) return retryable_mutate_rows(retry=retry) @@ -1058,27 +1064,20 @@ def _do_mutate_retryable_rows(self): # All mutations are either successful or non-retryable now. return self.responses_statuses - mutate_rows_request = _mutate_rows_request( - self.table_name, retryable_rows, app_profile_id=self.app_profile_id - ) + entries = _compile_mutation_entries(self.table_name, retryable_rows) data_client = self.client.table_data_client - inner_api_calls = data_client._inner_api_calls - if "mutate_rows" not in inner_api_calls: - default_retry = (data_client._method_configs["MutateRows"].retry,) - if self.timeout is None: - default_timeout = data_client._method_configs["MutateRows"].timeout - else: - default_timeout = timeout.ExponentialTimeout(deadline=self.timeout) - data_client._inner_api_calls["mutate_rows"] = wrap_method( - data_client.transport.mutate_rows, - default_retry=default_retry, - default_timeout=default_timeout, - client_info=data_client._client_info, - ) + + kwargs = {} + if self.timeout is not None: + kwargs["timeout"] = timeout.ExponentialTimeout(deadline=self.timeout) try: - responses = data_client._inner_api_calls["mutate_rows"]( - mutate_rows_request, retry=None + responses = data_client.mutate_rows( + self.table_name, + entries, + app_profile_id=self.app_profile_id, + retry=None, + **kwargs ) except (ServiceUnavailable, DeadlineExceeded, Aborted): # If an exception, considered retryable by `RETRY_CODES`, is @@ -1260,8 +1259,8 @@ def _create_row_request( return message -def _mutate_rows_request(table_name, rows, app_profile_id=None): - """Creates a request to mutate rows in a table. +def _compile_mutation_entries(table_name, rows): + """Create list of mutation entries :type table_name: str :param table_name: The name of the table to write to. @@ -1269,29 +1268,29 @@ def _mutate_rows_request(table_name, rows, app_profile_id=None): :type rows: list :param rows: List or other iterable of :class:`.DirectRow` instances. - :type: app_profile_id: str - :param app_profile_id: (Optional) The unique name of the AppProfile. - - :rtype: :class:`data_messages_v2_pb2.MutateRowsRequest` - :returns: The ``MutateRowsRequest`` protobuf corresponding to the inputs. + :rtype: List[:class:`data_messages_v2_pb2.MutateRowsRequest.Entry`] + :returns: entries corresponding to the inputs. :raises: :exc:`~.table.TooManyMutationsError` if the number of mutations is - greater than 100,000 - """ - request_pb = data_messages_v2_pb2.MutateRowsRequest( - table_name=table_name, app_profile_id=app_profile_id + greater than the max ({}) + """.format( + _MAX_BULK_MUTATIONS ) + entries = [] mutations_count = 0 + entry_klass = data_messages_v2_pb2.MutateRowsRequest.Entry + for row in rows: _check_row_table_name(table_name, row) _check_row_type(row) mutations = row._get_mutations() - request_pb.entries.add(row_key=row.row_key, mutations=mutations) + entries.append(entry_klass(row_key=row.row_key, mutations=mutations)) mutations_count += len(mutations) + if mutations_count > _MAX_BULK_MUTATIONS: raise TooManyMutationsError( "Maximum number of mutations is %s" % (_MAX_BULK_MUTATIONS,) ) - return request_pb + return entries def _check_row_table_name(table_name, row): diff --git a/tests/unit/test_table.py b/tests/unit/test_table.py index c99cd6591..4469846b1 100644 --- a/tests/unit/test_table.py +++ b/tests/unit/test_table.py @@ -20,14 +20,14 @@ from google.api_core.exceptions import DeadlineExceeded -class Test___mutate_rows_request(unittest.TestCase): +class Test__compile_mutation_entries(unittest.TestCase): def _call_fut(self, table_name, rows): - from google.cloud.bigtable.table import _mutate_rows_request + from google.cloud.bigtable.table import _compile_mutation_entries - return _mutate_rows_request(table_name, rows) + return _compile_mutation_entries(table_name, rows) @mock.patch("google.cloud.bigtable.table._MAX_BULK_MUTATIONS", new=3) - def test__mutate_rows_too_many_mutations(self): + def test_w_too_many_mutations(self): from google.cloud.bigtable.row import DirectRow from google.cloud.bigtable.table import TooManyMutationsError @@ -41,13 +41,15 @@ def test__mutate_rows_too_many_mutations(self): rows[0].set_cell("cf1", b"c1", 2) rows[1].set_cell("cf1", b"c1", 3) rows[1].set_cell("cf1", b"c1", 4) + with self.assertRaises(TooManyMutationsError): self._call_fut("table", rows) - def test__mutate_rows_request(self): + def test_normal(self): from google.cloud.bigtable.row import DirectRow + from google.cloud.bigtable_v2.proto import bigtable_pb2 - table = mock.Mock(name="table", spec=["name"]) + table = mock.Mock(spec=["name"]) table.name = "table" rows = [ DirectRow(row_key=b"row_key", table=table), @@ -55,25 +57,26 @@ def test__mutate_rows_request(self): ] rows[0].set_cell("cf1", b"c1", b"1") rows[1].set_cell("cf1", b"c1", b"2") + result = self._call_fut("table", rows) - expected_result = _mutate_rows_request_pb(table_name="table") - entry1 = expected_result.entries.add() - entry1.row_key = b"row_key" - mutations1 = entry1.mutations.add() - mutations1.set_cell.family_name = "cf1" - mutations1.set_cell.column_qualifier = b"c1" - mutations1.set_cell.timestamp_micros = -1 - mutations1.set_cell.value = b"1" - entry2 = expected_result.entries.add() - entry2.row_key = b"row_key_2" - mutations2 = entry2.mutations.add() - mutations2.set_cell.family_name = "cf1" - mutations2.set_cell.column_qualifier = b"c1" - mutations2.set_cell.timestamp_micros = -1 - mutations2.set_cell.value = b"2" + Entry = bigtable_pb2.MutateRowsRequest.Entry - self.assertEqual(result, expected_result) + entry_1 = Entry(row_key=b"row_key") + mutations_1 = entry_1.mutations.add() + mutations_1.set_cell.family_name = "cf1" + mutations_1.set_cell.column_qualifier = b"c1" + mutations_1.set_cell.timestamp_micros = -1 + mutations_1.set_cell.value = b"1" + + entry_2 = Entry(row_key=b"row_key_2") + mutations_2 = entry_2.mutations.add() + mutations_2.set_cell.family_name = "cf1" + mutations_2.set_cell.column_qualifier = b"c1" + mutations_2.set_cell.timestamp_micros = -1 + mutations_2.set_cell.value = b"2" + + self.assertEqual(result, [entry_1, entry_2]) class Test__check_row_table_name(unittest.TestCase): @@ -162,27 +165,49 @@ def _get_target_client_class(): def _make_client(self, *args, **kwargs): return self._get_target_client_class()(*args, **kwargs) - def test_constructor_w_admin(self): - credentials = _make_credentials() - client = self._make_client( - project=self.PROJECT_ID, credentials=credentials, admin=True - ) - instance = client.instance(instance_id=self.INSTANCE_ID) + def test_constructor_defaults(self): + instance = mock.Mock(spec=[]) + table = self._make_one(self.TABLE_ID, instance) + + self.assertEqual(table.table_id, self.TABLE_ID) + self.assertIs(table._instance, instance) + self.assertIsNone(table.mutation_timeout) + self.assertIsNone(table._app_profile_id) + + def test_constructor_explicit(self): + instance = mock.Mock(spec=[]) + mutation_timeout = 123 + app_profile_id = "profile-123" + + table = self._make_one( + self.TABLE_ID, + instance, + mutation_timeout=mutation_timeout, + app_profile_id=app_profile_id, + ) + self.assertEqual(table.table_id, self.TABLE_ID) - self.assertIs(table._instance._client, client) - self.assertEqual(table.name, self.TABLE_NAME) + self.assertIs(table._instance, instance) + self.assertEqual(table.mutation_timeout, mutation_timeout) + self.assertEqual(table._app_profile_id, app_profile_id) - def test_constructor_wo_admin(self): - credentials = _make_credentials() - client = self._make_client( - project=self.PROJECT_ID, credentials=credentials, admin=False + def test_name(self): + table_data_client = mock.Mock(spec=["table_path"]) + client = mock.Mock( + project=self.PROJECT_ID, + table_data_client=table_data_client, + spec=["project", "table_data_client"], ) - instance = client.instance(instance_id=self.INSTANCE_ID) + instance = mock.Mock( + _client=client, + instance_id=self.INSTANCE_ID, + spec=["_client", "instance_id"], + ) + table = self._make_one(self.TABLE_ID, instance) - self.assertEqual(table.table_id, self.TABLE_ID) - self.assertIs(table._instance._client, client) - self.assertEqual(table.name, self.TABLE_NAME) + + self.assertEqual(table.name, table_data_client.table_path.return_value) def _row_methods_helper(self): client = self._make_client( @@ -620,8 +645,11 @@ def test_read_row_still_partial(self): with self.assertRaises(ValueError): self._read_row_helper(chunks, None) - def test_mutate_rows(self): + def _mutate_rows_helper( + self, mutation_timeout=None, app_profile_id=None, retry=None, timeout=None + ): from google.rpc.status_pb2 import Status + from google.cloud.bigtable.table import DEFAULT_RETRY from google.cloud.bigtable_admin_v2.gapic import bigtable_table_admin_client table_api = mock.create_autospec( @@ -633,21 +661,78 @@ def test_mutate_rows(self): ) instance = client.instance(instance_id=self.INSTANCE_ID) client._table_admin_client = table_api - table = self._make_one(self.TABLE_ID, instance) + ctor_kwargs = {} - response = [Status(code=0), Status(code=1)] + if mutation_timeout is not None: + ctor_kwargs["mutation_timeout"] = mutation_timeout + + if app_profile_id is not None: + ctor_kwargs["app_profile_id"] = app_profile_id - mock_worker = mock.Mock(return_value=response) - with mock.patch( + table = self._make_one(self.TABLE_ID, instance, **ctor_kwargs) + + rows = [mock.MagicMock(), mock.MagicMock()] + response = [Status(code=0), Status(code=1)] + instance_mock = mock.Mock(return_value=response) + klass_mock = mock.patch( "google.cloud.bigtable.table._RetryableMutateRowsWorker", - new=mock.MagicMock(return_value=mock_worker), - ): - statuses = table.mutate_rows([mock.MagicMock(), mock.MagicMock()]) + new=mock.MagicMock(return_value=instance_mock), + ) + + call_kwargs = {} + + if retry is not None: + call_kwargs["retry"] = retry + + if timeout is not None: + expected_timeout = call_kwargs["timeout"] = timeout + else: + expected_timeout = mutation_timeout + + with klass_mock: + statuses = table.mutate_rows(rows, **call_kwargs) + result = [status.code for status in statuses] expected_result = [0, 1] - self.assertEqual(result, expected_result) + klass_mock.new.assert_called_once_with( + client, + self.TABLE_NAME, + rows, + app_profile_id=app_profile_id, + timeout=expected_timeout, + ) + + if retry is not None: + instance_mock.assert_called_once_with(retry=retry) + else: + instance_mock.assert_called_once_with(retry=DEFAULT_RETRY) + + def test_mutate_rows_w_default_mutation_timeout_app_profile_id(self): + self._mutate_rows_helper() + + def test_mutate_rows_w_mutation_timeout(self): + mutation_timeout = 123 + self._mutate_rows_helper(mutation_timeout=mutation_timeout) + + def test_mutate_rows_w_app_profile_id(self): + app_profile_id = "profile-123" + self._mutate_rows_helper(app_profile_id=app_profile_id) + + def test_mutate_rows_w_retry(self): + retry = mock.Mock() + self._mutate_rows_helper(retry=retry) + + def test_mutate_rows_w_timeout_arg(self): + timeout = 123 + self._mutate_rows_helper(timeout=timeout) + + def test_mutate_rows_w_mutation_timeout_and_timeout_arg(self): + mutation_timeout = 123 + timeout = 456 + self._mutate_rows_helper(mutation_timeout=mutation_timeout, timeout=timeout) + def test_read_rows(self): from google.cloud._testing import _Monkey from google.cloud.bigtable.row_data import PartialRowsData @@ -1424,21 +1509,18 @@ def test_callable_no_retry_strategy(self): row_3 = DirectRow(row_key=b"row_key_3", table=table) row_3.set_cell("cf", b"col", b"value3") - response = self._make_responses( - [self.SUCCESS, self.RETRYABLE_1, self.NON_RETRYABLE] - ) + worker = self._make_worker(client, table.name, [row_1, row_2, row_3]) - with mock.patch("google.cloud.bigtable.table.wrap_method") as patched: - patched.return_value = mock.Mock(return_value=[response]) + response_codes = [self.SUCCESS, self.RETRYABLE_1, self.NON_RETRYABLE] + response = self._make_responses(response_codes) + data_api.mutate_rows = mock.MagicMock(return_value=[response]) - worker = self._make_worker(client, table.name, [row_1, row_2, row_3]) - statuses = worker(retry=None) + statuses = worker(retry=None) result = [status.code for status in statuses] - expected_result = [self.SUCCESS, self.RETRYABLE_1, self.NON_RETRYABLE] + self.assertEqual(result, response_codes) - client._table_data_client._inner_api_calls["mutate_rows"].assert_called_once() - self.assertEqual(result, expected_result) + data_api.mutate_rows.assert_called_once() def test_callable_retry(self): from google.cloud.bigtable.row import DirectRow