diff --git a/google/cloud/bigtable/table.py b/google/cloud/bigtable/table.py index 950a8c3fe..5d82312cd 100644 --- a/google/cloud/bigtable/table.py +++ b/google/cloud/bigtable/table.py @@ -22,7 +22,6 @@ from google.api_core.exceptions import ServiceUnavailable from google.api_core.retry import if_exception_type from google.api_core.retry import Retry -from google.api_core.gapic_v1.method import wrap_method from google.cloud._helpers import _to_bytes from google.cloud.bigtable.backup import Backup from google.cloud.bigtable.column_family import _gc_rule_from_pb @@ -1058,27 +1057,20 @@ def _do_mutate_retryable_rows(self): # All mutations are either successful or non-retryable now. return self.responses_statuses - mutate_rows_request = _mutate_rows_request( - self.table_name, retryable_rows, app_profile_id=self.app_profile_id - ) + entries = _compile_mutation_entries(self.table_name, retryable_rows) data_client = self.client.table_data_client - inner_api_calls = data_client._inner_api_calls - if "mutate_rows" not in inner_api_calls: - default_retry = (data_client._method_configs["MutateRows"].retry,) - if self.timeout is None: - default_timeout = data_client._method_configs["MutateRows"].timeout - else: - default_timeout = timeout.ExponentialTimeout(deadline=self.timeout) - data_client._inner_api_calls["mutate_rows"] = wrap_method( - data_client.transport.mutate_rows, - default_retry=default_retry, - default_timeout=default_timeout, - client_info=data_client._client_info, - ) + + kwargs = {} + if self.timeout is not None: + kwargs["timeout"] = timeout.ExponentialTimeout(deadline=self.timeout) try: - responses = data_client._inner_api_calls["mutate_rows"]( - mutate_rows_request, retry=None + responses = data_client.mutate_rows( + self.table_name, + entries, + app_profile_id=self.app_profile_id, + retry=None, + **kwargs ) except (ServiceUnavailable, DeadlineExceeded, Aborted): # If an exception, considered retryable by `RETRY_CODES`, is @@ -1260,8 +1252,8 @@ def _create_row_request( return message -def _mutate_rows_request(table_name, rows, app_profile_id=None): - """Creates a request to mutate rows in a table. +def _compile_mutation_entries(table_name, rows): + """Create list of mutation entries :type table_name: str :param table_name: The name of the table to write to. @@ -1269,29 +1261,29 @@ def _mutate_rows_request(table_name, rows, app_profile_id=None): :type rows: list :param rows: List or other iterable of :class:`.DirectRow` instances. - :type: app_profile_id: str - :param app_profile_id: (Optional) The unique name of the AppProfile. - - :rtype: :class:`data_messages_v2_pb2.MutateRowsRequest` - :returns: The ``MutateRowsRequest`` protobuf corresponding to the inputs. + :rtype: List[:class:`data_messages_v2_pb2.MutateRowsRequest.Entry`] + :returns: entries corresponding to the inputs. :raises: :exc:`~.table.TooManyMutationsError` if the number of mutations is - greater than 100,000 - """ - request_pb = data_messages_v2_pb2.MutateRowsRequest( - table_name=table_name, app_profile_id=app_profile_id + greater than the max ({}) + """.format( + _MAX_BULK_MUTATIONS ) + entries = [] mutations_count = 0 + entry_klass = data_messages_v2_pb2.MutateRowsRequest.Entry + for row in rows: _check_row_table_name(table_name, row) _check_row_type(row) mutations = row._get_mutations() - request_pb.entries.add(row_key=row.row_key, mutations=mutations) + entries.append(entry_klass(row_key=row.row_key, mutations=mutations)) mutations_count += len(mutations) + if mutations_count > _MAX_BULK_MUTATIONS: raise TooManyMutationsError( "Maximum number of mutations is %s" % (_MAX_BULK_MUTATIONS,) ) - return request_pb + return entries def _check_row_table_name(table_name, row): diff --git a/tests/unit/test_table.py b/tests/unit/test_table.py index 9e018c0a9..b0c0c4305 100644 --- a/tests/unit/test_table.py +++ b/tests/unit/test_table.py @@ -20,14 +20,14 @@ from google.api_core.exceptions import DeadlineExceeded -class Test___mutate_rows_request(unittest.TestCase): +class Test__compile_mutation_entries(unittest.TestCase): def _call_fut(self, table_name, rows): - from google.cloud.bigtable.table import _mutate_rows_request + from google.cloud.bigtable.table import _compile_mutation_entries - return _mutate_rows_request(table_name, rows) + return _compile_mutation_entries(table_name, rows) @mock.patch("google.cloud.bigtable.table._MAX_BULK_MUTATIONS", new=3) - def test__mutate_rows_too_many_mutations(self): + def test_w_too_many_mutations(self): from google.cloud.bigtable.row import DirectRow from google.cloud.bigtable.table import TooManyMutationsError @@ -41,13 +41,15 @@ def test__mutate_rows_too_many_mutations(self): rows[0].set_cell("cf1", b"c1", 2) rows[1].set_cell("cf1", b"c1", 3) rows[1].set_cell("cf1", b"c1", 4) + with self.assertRaises(TooManyMutationsError): self._call_fut("table", rows) - def test__mutate_rows_request(self): + def test_normal(self): from google.cloud.bigtable.row import DirectRow + from google.cloud.bigtable_v2.proto import bigtable_pb2 - table = mock.Mock(name="table", spec=["name"]) + table = mock.Mock(spec=["name"]) table.name = "table" rows = [ DirectRow(row_key=b"row_key", table=table), @@ -55,25 +57,26 @@ def test__mutate_rows_request(self): ] rows[0].set_cell("cf1", b"c1", b"1") rows[1].set_cell("cf1", b"c1", b"2") + result = self._call_fut("table", rows) - expected_result = _mutate_rows_request_pb(table_name="table") - entry1 = expected_result.entries.add() - entry1.row_key = b"row_key" - mutations1 = entry1.mutations.add() - mutations1.set_cell.family_name = "cf1" - mutations1.set_cell.column_qualifier = b"c1" - mutations1.set_cell.timestamp_micros = -1 - mutations1.set_cell.value = b"1" - entry2 = expected_result.entries.add() - entry2.row_key = b"row_key_2" - mutations2 = entry2.mutations.add() - mutations2.set_cell.family_name = "cf1" - mutations2.set_cell.column_qualifier = b"c1" - mutations2.set_cell.timestamp_micros = -1 - mutations2.set_cell.value = b"2" + Entry = bigtable_pb2.MutateRowsRequest.Entry - self.assertEqual(result, expected_result) + entry_1 = Entry(row_key=b"row_key") + mutations_1 = entry_1.mutations.add() + mutations_1.set_cell.family_name = "cf1" + mutations_1.set_cell.column_qualifier = b"c1" + mutations_1.set_cell.timestamp_micros = -1 + mutations_1.set_cell.value = b"1" + + entry_2 = Entry(row_key=b"row_key_2") + mutations_2 = entry_2.mutations.add() + mutations_2.set_cell.family_name = "cf1" + mutations_2.set_cell.column_qualifier = b"c1" + mutations_2.set_cell.timestamp_micros = -1 + mutations_2.set_cell.value = b"2" + + self.assertEqual(result, [entry_1, entry_2]) class Test__check_row_table_name(unittest.TestCase): @@ -175,7 +178,7 @@ def test_constructor_defaults(self): def test_constructor_explicit(self): instance = mock.Mock(spec=[]) mutation_timeout = 123 - app_profile_id = 'profile-123' + app_profile_id = "profile-123" table = self._make_one( self.TABLE_ID, @@ -194,7 +197,7 @@ def test_name(self): client = mock.Mock( project=self.PROJECT_ID, table_data_client=table_data_client, - spec=["project", "table_data_client"] + spec=["project", "table_data_client"], ) instance = mock.Mock( _client=client, @@ -642,7 +645,9 @@ def test_read_row_still_partial(self): with self.assertRaises(ValueError): self._read_row_helper(chunks, None) - def _mutate_rows_helper(self, mutation_timeout=None, app_profile_id=None, retry=None): + def _mutate_rows_helper( + self, mutation_timeout=None, app_profile_id=None, retry=None + ): from google.rpc.status_pb2 import Status from google.cloud.bigtable.table import DEFAULT_RETRY from google.cloud.bigtable_admin_v2.gapic import bigtable_table_admin_client @@ -705,7 +710,7 @@ def test_mutate_rows_w_mutation_timeout(self): self._mutate_rows_helper(mutation_timeout=mutation_timeout) def test_mutate_rows_w_app_profile_id(self): - app_profile_id = 'profile-123' + app_profile_id = "profile-123" self._mutate_rows_helper(app_profile_id=app_profile_id) def test_mutate_rows_w_retry(self): @@ -1488,21 +1493,18 @@ def test_callable_no_retry_strategy(self): row_3 = DirectRow(row_key=b"row_key_3", table=table) row_3.set_cell("cf", b"col", b"value3") - response = self._make_responses( - [self.SUCCESS, self.RETRYABLE_1, self.NON_RETRYABLE] - ) + worker = self._make_worker(client, table.name, [row_1, row_2, row_3]) - with mock.patch("google.cloud.bigtable.table.wrap_method") as patched: - patched.return_value = mock.Mock(return_value=[response]) + response_codes = [self.SUCCESS, self.RETRYABLE_1, self.NON_RETRYABLE] + response = self._make_responses(response_codes) + data_api.mutate_rows = mock.MagicMock(return_value=[response]) - worker = self._make_worker(client, table.name, [row_1, row_2, row_3]) - statuses = worker(retry=None) + statuses = worker(retry=None) result = [status.code for status in statuses] - expected_result = [self.SUCCESS, self.RETRYABLE_1, self.NON_RETRYABLE] + self.assertEqual(result, response_codes) - client._table_data_client._inner_api_calls["mutate_rows"].assert_called_once() - self.assertEqual(result, expected_result) + data_api.mutate_rows.assert_called_once() def test_callable_retry(self): from google.cloud.bigtable.row import DirectRow