diff --git a/google/cloud/spanner_v1/_helpers.py b/google/cloud/spanner_v1/_helpers.py index 3b7fd586c9..91e8c8d29c 100644 --- a/google/cloud/spanner_v1/_helpers.py +++ b/google/cloud/spanner_v1/_helpers.py @@ -26,6 +26,7 @@ from google.cloud._helpers import _date_from_iso8601_date from google.cloud._helpers import _datetime_to_rfc3339 from google.cloud.spanner_v1.proto import type_pb2 +from google.cloud.spanner_v1.proto.spanner_pb2 import ExecuteSqlRequest def _try_to_coerce_bytes(bytestring): @@ -47,6 +48,44 @@ def _try_to_coerce_bytes(bytestring): ) +def _merge_query_options(base, merge): + """Merge higher precedence QueryOptions with current QueryOptions. + + :type base: + :class:`google.cloud.spanner_v1.proto.ExecuteSqlRequest.QueryOptions` + or :class:`dict` or None + :param base: The current QueryOptions that is intended for use. + + :type merge: + :class:`google.cloud.spanner_v1.proto.ExecuteSqlRequest.QueryOptions` + or :class:`dict` or None + :param merge: + The QueryOptions that have a higher priority than base. These options + should overwrite the fields in base. + + :rtype: + :class:`google.cloud.spanner_v1.proto.ExecuteSqlRequest.QueryOptions` + or None + :returns: + QueryOptions object formed by merging the two given QueryOptions. + If the resultant object only has empty fields, returns None. + """ + combined = base or ExecuteSqlRequest.QueryOptions() + if type(combined) == dict: + combined = ExecuteSqlRequest.QueryOptions( + optimizer_version=combined.get("optimizer_version", "") + ) + merge = merge or ExecuteSqlRequest.QueryOptions() + if type(merge) == dict: + merge = ExecuteSqlRequest.QueryOptions( + optimizer_version=merge.get("optimizer_version", "") + ) + combined.MergeFrom(merge) + if not combined.optimizer_version: + return None + return combined + + # pylint: disable=too-many-return-statements,too-many-branches def _make_value_pb(value): """Helper for :func:`_make_list_value_pbs`. diff --git a/google/cloud/spanner_v1/client.py b/google/cloud/spanner_v1/client.py index c7b331adc0..01b3ddfabf 100644 --- a/google/cloud/spanner_v1/client.py +++ b/google/cloud/spanner_v1/client.py @@ -50,9 +50,10 @@ from google.cloud.client import ClientWithProject from google.cloud.spanner_v1 import __version__ -from google.cloud.spanner_v1._helpers import _metadata_with_prefix +from google.cloud.spanner_v1._helpers import _merge_query_options, _metadata_with_prefix from google.cloud.spanner_v1.instance import DEFAULT_NODE_COUNT from google.cloud.spanner_v1.instance import Instance +from google.cloud.spanner_v1.proto.spanner_pb2 import ExecuteSqlRequest _CLIENT_INFO = client_info.ClientInfo(client_library_version=__version__) EMULATOR_ENV_VAR = "SPANNER_EMULATOR_HOST" @@ -62,6 +63,7 @@ "without a scheme: ex %s=localhost:8080." ) % ((EMULATOR_ENV_VAR,) * 3) SPANNER_ADMIN_SCOPE = "https://www.googleapis.com/auth/spanner.admin" +OPTIMIZER_VERSION_ENV_VAR = "SPANNER_OPTIMIZER_VERSION" _USER_AGENT_DEPRECATED = ( "The 'user_agent' argument to 'Client' is deprecated / unused. " "Please pass an appropriate 'client_info' instead." @@ -72,6 +74,10 @@ def _get_spanner_emulator_host(): return os.getenv(EMULATOR_ENV_VAR) +def _get_spanner_optimizer_version(): + return os.getenv(OPTIMIZER_VERSION_ENV_VAR, "") + + class InstanceConfig(object): """Named configurations for Spanner instances. @@ -132,11 +138,20 @@ class Client(ClientWithProject): :param user_agent: (Deprecated) The user agent to be used with API request. Not used. + :type client_options: :class:`~google.api_core.client_options.ClientOptions` or :class:`dict` :param client_options: (Optional) Client options used to set user options on the client. API Endpoint should be set through client_options. + :type query_options: + :class:`google.cloud.spanner_v1.proto.ExecuteSqlRequest.QueryOptions` + or :class:`dict` + :param query_options: + (Optional) Query optimizer configuration to use for the given query. + If a dict is provided, it must be of the same form as the protobuf + message :class:`~google.cloud.spanner_v1.types.QueryOptions` + :raises: :class:`ValueError ` if both ``read_only`` and ``admin`` are :data:`True` """ @@ -157,6 +172,7 @@ def __init__( client_info=_CLIENT_INFO, user_agent=None, client_options=None, + query_options=None, ): # NOTE: This API has no use for the _http argument, but sending it # will have no impact since the _http() @property only lazily @@ -172,6 +188,13 @@ def __init__( else: self._client_options = client_options + env_query_options = ExecuteSqlRequest.QueryOptions( + optimizer_version=_get_spanner_optimizer_version() + ) + + # Environment flag config has higher precedence than application config. + self._query_options = _merge_query_options(query_options, env_query_options) + if user_agent is not None: warnings.warn(_USER_AGENT_DEPRECATED, DeprecationWarning, stacklevel=2) self.user_agent = user_agent diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index f5ea3e46dd..9ee046e094 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -30,8 +30,11 @@ import six # pylint: disable=ungrouped-imports -from google.cloud.spanner_v1._helpers import _make_value_pb -from google.cloud.spanner_v1._helpers import _metadata_with_prefix +from google.cloud.spanner_v1._helpers import ( + _make_value_pb, + _merge_query_options, + _metadata_with_prefix, +) from google.cloud.spanner_v1.batch import Batch from google.cloud.spanner_v1.gapic.spanner_client import SpannerClient from google.cloud.spanner_v1.gapic.transports import spanner_grpc_transport @@ -350,7 +353,9 @@ def drop(self): metadata = _metadata_with_prefix(self.name) api.drop_database(self.name, metadata=metadata) - def execute_partitioned_dml(self, dml, params=None, param_types=None): + def execute_partitioned_dml( + self, dml, params=None, param_types=None, query_options=None + ): """Execute a partitionable DML statement. :type dml: str @@ -365,9 +370,20 @@ def execute_partitioned_dml(self, dml, params=None, param_types=None): (Optional) maps explicit types for one or more param values; required if parameters are passed. + :type query_options: + :class:`google.cloud.spanner_v1.proto.ExecuteSqlRequest.QueryOptions` + or :class:`dict` + :param query_options: + (Optional) Query optimizer configuration to use for the given query. + If a dict is provided, it must be of the same form as the protobuf + message :class:`~google.cloud.spanner_v1.types.QueryOptions` + :rtype: int :returns: Count of rows affected by the DML statement. """ + query_options = _merge_query_options( + self._instance._client._query_options, query_options + ) if params is not None: if param_types is None: raise ValueError("Specify 'param_types' when passing 'params'.") @@ -398,6 +414,7 @@ def execute_partitioned_dml(self, dml, params=None, param_types=None): transaction=txn_selector, params=params_pb, param_types=param_types, + query_options=query_options, metadata=metadata, ) @@ -748,6 +765,7 @@ def generate_query_batches( param_types=None, partition_size_bytes=None, max_partitions=None, + query_options=None, ): """Start a partitioned query operation. @@ -783,6 +801,14 @@ def generate_query_batches( service uses this as a hint, the actual number of partitions may differ. + :type query_options: + :class:`google.cloud.spanner_v1.proto.ExecuteSqlRequest.QueryOptions` + or :class:`dict` + :param query_options: + (Optional) Query optimizer configuration to use for the given query. + If a dict is provided, it must be of the same form as the protobuf + message :class:`~google.cloud.spanner_v1.types.QueryOptions` + :rtype: iterable of dict :returns: mappings of information used peform actual partitioned reads via @@ -801,6 +827,13 @@ def generate_query_batches( query_info["params"] = params query_info["param_types"] = param_types + # Query-level options have higher precedence than client-level and + # environment-level options + default_query_options = self._database._instance._client._query_options + query_info["query_options"] = _merge_query_options( + default_query_options, query_options + ) + for partition in partitions: yield {"partition": partition, "query": query_info} diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index 863053d4ef..fc6bb028b7 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -202,6 +202,7 @@ def execute_sql( params=None, param_types=None, query_mode=None, + query_options=None, retry=google.api_core.gapic_v1.method.DEFAULT, timeout=google.api_core.gapic_v1.method.DEFAULT, ): @@ -225,11 +226,22 @@ def execute_sql( :param query_mode: Mode governing return of results / query plan. See https://cloud.google.com/spanner/reference/rpc/google.spanner.v1#google.spanner.v1.ExecuteSqlRequest.QueryMode1 + :type query_options: + :class:`google.cloud.spanner_v1.proto.ExecuteSqlRequest.QueryOptions` + or :class:`dict` + :param query_options: (Optional) Options that are provided for query plan stability. + :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet` :returns: a result set instance which can be used to consume rows. """ return self.snapshot().execute_sql( - sql, params, param_types, query_mode, retry=retry, timeout=timeout + sql, + params, + param_types, + query_mode, + query_options=query_options, + retry=retry, + timeout=timeout, ) def batch(self): diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index ec7008fb75..56b3b6a813 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -23,6 +23,7 @@ from google.api_core.exceptions import ServiceUnavailable import google.api_core.gapic_v1.method from google.cloud._helpers import _datetime_to_pb_timestamp +from google.cloud.spanner_v1._helpers import _merge_query_options from google.cloud._helpers import _timedelta_to_duration_pb from google.cloud.spanner_v1._helpers import _make_value_pb from google.cloud.spanner_v1._helpers import _metadata_with_prefix @@ -157,6 +158,7 @@ def execute_sql( params=None, param_types=None, query_mode=None, + query_options=None, partition=None, retry=google.api_core.gapic_v1.method.DEFAULT, timeout=google.api_core.gapic_v1.method.DEFAULT, @@ -180,6 +182,14 @@ def execute_sql( :param query_mode: Mode governing return of results / query plan. See https://cloud.google.com/spanner/reference/rpc/google.spanner.v1#google.spanner.v1.ExecuteSqlRequest.QueryMode1 + :type query_options: + :class:`google.cloud.spanner_v1.proto.ExecuteSqlRequest.QueryOptions` + or :class:`dict` + :param query_options: + (Optional) Query optimizer configuration to use for the given query. + If a dict is provided, it must be of the same form as the protobuf + message :class:`~google.cloud.spanner_v1.types.QueryOptions` + :type partition: bytes :param partition: (Optional) one of the partition tokens returned from :meth:`partition_query`. @@ -211,6 +221,11 @@ def execute_sql( transaction = self._make_txn_selector() api = database.spanner_api + # Query-level options have higher precedence than client-level and + # environment-level options + default_query_options = database._instance._client._query_options + query_options = _merge_query_options(default_query_options, query_options) + restart = functools.partial( api.execute_streaming_sql, self._session.name, @@ -221,6 +236,7 @@ def execute_sql( query_mode=query_mode, partition_token=partition, seqno=self._execute_sql_count, + query_options=query_options, metadata=metadata, retry=retry, timeout=timeout, diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index 55e2837df4..5a161fd8a6 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -17,8 +17,11 @@ from google.protobuf.struct_pb2 import Struct from google.cloud._helpers import _pb_timestamp_to_datetime -from google.cloud.spanner_v1._helpers import _make_value_pb -from google.cloud.spanner_v1._helpers import _metadata_with_prefix +from google.cloud.spanner_v1._helpers import ( + _make_value_pb, + _merge_query_options, + _metadata_with_prefix, +) from google.cloud.spanner_v1.proto.transaction_pb2 import TransactionSelector from google.cloud.spanner_v1.proto.transaction_pb2 import TransactionOptions from google.cloud.spanner_v1.snapshot import _SnapshotBase @@ -162,7 +165,9 @@ def _make_params_pb(params, param_types): return None - def execute_update(self, dml, params=None, param_types=None, query_mode=None): + def execute_update( + self, dml, params=None, param_types=None, query_mode=None, query_options=None + ): """Perform an ``ExecuteSql`` API request with DML. :type dml: str @@ -182,6 +187,11 @@ def execute_update(self, dml, params=None, param_types=None, query_mode=None): :param query_mode: Mode governing return of results / query plan. See https://cloud.google.com/spanner/reference/rpc/google.spanner.v1#google.spanner.v1.ExecuteSqlRequest.QueryMode1 + :type query_options: + :class:`google.cloud.spanner_v1.proto.ExecuteSqlRequest.QueryOptions` + or :class:`dict` + :param query_options: (Optional) Options that are provided for query plan stability. + :rtype: int :returns: Count of rows affected by the DML statement. """ @@ -191,6 +201,11 @@ def execute_update(self, dml, params=None, param_types=None, query_mode=None): transaction = self._make_txn_selector() api = database.spanner_api + # Query-level options have higher precedence than client-level and + # environment-level options + default_query_options = database._instance._client._query_options + query_options = _merge_query_options(default_query_options, query_options) + response = api.execute_sql( self._session.name, dml, @@ -198,6 +213,7 @@ def execute_update(self, dml, params=None, param_types=None, query_mode=None): params=params_pb, param_types=param_types, query_mode=query_mode, + query_options=query_options, seqno=self._execute_sql_count, metadata=metadata, ) diff --git a/tests/unit/test__helpers.py b/tests/unit/test__helpers.py index 86ce78727b..b2f2c7d5e7 100644 --- a/tests/unit/test__helpers.py +++ b/tests/unit/test__helpers.py @@ -16,6 +16,61 @@ import unittest +class Test_merge_query_options(unittest.TestCase): + def _callFUT(self, *args, **kw): + from google.cloud.spanner_v1._helpers import _merge_query_options + + return _merge_query_options(*args, **kw) + + def test_base_none_and_merge_none(self): + base = merge = None + result = self._callFUT(base, merge) + self.assertIsNone(result) + + def test_base_dict_and_merge_none(self): + from google.cloud.spanner_v1.proto.spanner_pb2 import ExecuteSqlRequest + + base = {"optimizer_version": "2"} + merge = None + expected = ExecuteSqlRequest.QueryOptions(optimizer_version="2") + result = self._callFUT(base, merge) + self.assertEqual(result, expected) + + def test_base_empty_and_merge_empty(self): + from google.cloud.spanner_v1.proto.spanner_pb2 import ExecuteSqlRequest + + base = ExecuteSqlRequest.QueryOptions() + merge = ExecuteSqlRequest.QueryOptions() + result = self._callFUT(base, merge) + self.assertIsNone(result) + + def test_base_none_merge_object(self): + from google.cloud.spanner_v1.proto.spanner_pb2 import ExecuteSqlRequest + + base = None + merge = ExecuteSqlRequest.QueryOptions(optimizer_version="3") + result = self._callFUT(base, merge) + self.assertEqual(result, merge) + + def test_base_none_merge_dict(self): + from google.cloud.spanner_v1.proto.spanner_pb2 import ExecuteSqlRequest + + base = None + merge = {"optimizer_version": "3"} + expected = ExecuteSqlRequest.QueryOptions(optimizer_version="3") + result = self._callFUT(base, merge) + self.assertEqual(result, expected) + + def test_base_object_merge_dict(self): + from google.cloud.spanner_v1.proto.spanner_pb2 import ExecuteSqlRequest + + base = ExecuteSqlRequest.QueryOptions(optimizer_version="1") + merge = {"optimizer_version": "3"} + expected = ExecuteSqlRequest.QueryOptions(optimizer_version="3") + result = self._callFUT(base, merge) + self.assertEqual(result, expected) + + class Test_make_value_pb(unittest.TestCase): def _callFUT(self, *args, **kw): from google.cloud.spanner_v1._helpers import _make_value_pb diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 2e04537e02..8308ed6e92 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -56,6 +56,8 @@ def _constructor_test_helper( client_info=None, user_agent=None, client_options=None, + query_options=None, + expected_query_options=None, ): import google.api_core.client_options from google.cloud.spanner_v1 import client as MUT @@ -76,7 +78,11 @@ def _constructor_test_helper( expected_client_options = client_options client = self._make_one( - project=self.PROJECT, credentials=creds, user_agent=user_agent, **kwargs + project=self.PROJECT, + credentials=creds, + user_agent=user_agent, + query_options=query_options, + **kwargs ) expected_creds = expected_creds or creds.with_scopes.return_value @@ -97,15 +103,17 @@ def _constructor_test_helper( client._client_options.api_endpoint, expected_client_options.api_endpoint, ) + if expected_query_options is not None: + self.assertEqual(client._query_options, expected_query_options) - @mock.patch("google.cloud.spanner_v1.client.os.getenv") + @mock.patch("google.cloud.spanner_v1.client._get_spanner_emulator_host") @mock.patch("warnings.warn") - def test_constructor_emulator_host_warning(self, mock_warn, mock_os): + def test_constructor_emulator_host_warning(self, mock_warn, mock_em): from google.cloud.spanner_v1 import client as MUT expected_scopes = (MUT.SPANNER_ADMIN_SCOPE,) creds = _make_credentials() - mock_os.return_value = "http://emulator.host.com" + mock_em.return_value = "http://emulator.host.com" self._constructor_test_helper(expected_scopes, creds) mock_warn.assert_called_once_with(MUT._EMULATOR_HOST_HTTP_SCHEME) @@ -175,8 +183,40 @@ def test_constructor_custom_client_options_dict(self): expected_scopes, creds, client_options={"api_endpoint": "endpoint"} ) - @mock.patch("google.cloud.spanner_v1.client.os.getenv") - def test_instance_admin_api(self, mock_getenv): + def test_constructor_custom_query_options_client_config(self): + from google.cloud.spanner_v1.proto.spanner_pb2 import ExecuteSqlRequest + from google.cloud.spanner_v1 import client as MUT + + expected_scopes = (MUT.SPANNER_ADMIN_SCOPE,) + creds = _make_credentials() + self._constructor_test_helper( + expected_scopes, + creds, + query_options=ExecuteSqlRequest.QueryOptions(optimizer_version="1"), + expected_query_options=ExecuteSqlRequest.QueryOptions( + optimizer_version="1" + ), + ) + + @mock.patch("google.cloud.spanner_v1.client._get_spanner_optimizer_version") + def test_constructor_custom_query_options_env_config(self, mock_ver): + from google.cloud.spanner_v1.proto.spanner_pb2 import ExecuteSqlRequest + from google.cloud.spanner_v1 import client as MUT + + expected_scopes = (MUT.SPANNER_ADMIN_SCOPE,) + creds = _make_credentials() + mock_ver.return_value = "2" + self._constructor_test_helper( + expected_scopes, + creds, + query_options=ExecuteSqlRequest.QueryOptions(optimizer_version="1"), + expected_query_options=ExecuteSqlRequest.QueryOptions( + optimizer_version="2" + ), + ) + + @mock.patch("google.cloud.spanner_v1.client._get_spanner_emulator_host") + def test_instance_admin_api(self, mock_em): from google.cloud.spanner_v1.client import SPANNER_ADMIN_SCOPE credentials = _make_credentials() @@ -190,7 +230,7 @@ def test_instance_admin_api(self, mock_getenv): ) expected_scopes = (SPANNER_ADMIN_SCOPE,) - mock_getenv.return_value = None + mock_em.return_value = None inst_module = "google.cloud.spanner_v1.client.InstanceAdminClient" with mock.patch(inst_module) as instance_admin_client: api = client.instance_admin_api @@ -209,8 +249,8 @@ def test_instance_admin_api(self, mock_getenv): credentials.with_scopes.assert_called_once_with(expected_scopes) - @mock.patch("google.cloud.spanner_v1.client.os.getenv") - def test_instance_admin_api_emulator(self, mock_getenv): + @mock.patch("google.cloud.spanner_v1.client._get_spanner_emulator_host") + def test_instance_admin_api_emulator(self, mock_em): credentials = _make_credentials() client_info = mock.Mock() client_options = mock.Mock() @@ -221,7 +261,7 @@ def test_instance_admin_api_emulator(self, mock_getenv): client_options=client_options, ) - mock_getenv.return_value = "true" + mock_em.return_value = "true" inst_module = "google.cloud.spanner_v1.client.InstanceAdminClient" with mock.patch(inst_module) as instance_admin_client: api = client.instance_admin_api @@ -240,8 +280,8 @@ def test_instance_admin_api_emulator(self, mock_getenv): self.assertIn("transport", called_kw) self.assertNotIn("credentials", called_kw) - @mock.patch("google.cloud.spanner_v1.client.os.getenv") - def test_database_admin_api(self, mock_getenv): + @mock.patch("google.cloud.spanner_v1.client._get_spanner_emulator_host") + def test_database_admin_api(self, mock_em): from google.cloud.spanner_v1.client import SPANNER_ADMIN_SCOPE credentials = _make_credentials() @@ -255,7 +295,7 @@ def test_database_admin_api(self, mock_getenv): ) expected_scopes = (SPANNER_ADMIN_SCOPE,) - mock_getenv.return_value = None + mock_em.return_value = None db_module = "google.cloud.spanner_v1.client.DatabaseAdminClient" with mock.patch(db_module) as database_admin_client: api = client.database_admin_api @@ -274,8 +314,8 @@ def test_database_admin_api(self, mock_getenv): credentials.with_scopes.assert_called_once_with(expected_scopes) - @mock.patch("google.cloud.spanner_v1.client.os.getenv") - def test_database_admin_api_emulator(self, mock_getenv): + @mock.patch("google.cloud.spanner_v1.client._get_spanner_emulator_host") + def test_database_admin_api_emulator(self, mock_em): credentials = _make_credentials() client_info = mock.Mock() client_options = mock.Mock() @@ -286,7 +326,7 @@ def test_database_admin_api_emulator(self, mock_getenv): client_options=client_options, ) - mock_getenv.return_value = "true" + mock_em.return_value = "host:port" db_module = "google.cloud.spanner_v1.client.DatabaseAdminClient" with mock.patch(db_module) as database_admin_client: api = client.database_admin_api diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index 7bf14de751..2d7e2e1888 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -924,7 +924,9 @@ def test_drop_success(self): metadata=[("google-cloud-resource-prefix", database.name)], ) - def _execute_partitioned_dml_helper(self, dml, params=None, param_types=None): + def _execute_partitioned_dml_helper( + self, dml, params=None, param_types=None, query_options=None + ): from google.protobuf.struct_pb2 import Struct from google.cloud.spanner_v1.proto.result_set_pb2 import ( PartialResultSet, @@ -935,7 +937,10 @@ def _execute_partitioned_dml_helper(self, dml, params=None, param_types=None): TransactionSelector, TransactionOptions, ) - from google.cloud.spanner_v1._helpers import _make_value_pb + from google.cloud.spanner_v1._helpers import ( + _make_value_pb, + _merge_query_options, + ) transaction_pb = TransactionPB(id=self.TRANSACTION_ID) @@ -953,7 +958,9 @@ def _execute_partitioned_dml_helper(self, dml, params=None, param_types=None): api.begin_transaction.return_value = transaction_pb api.execute_streaming_sql.return_value = iterator - row_count = database.execute_partitioned_dml(dml, params, param_types) + row_count = database.execute_partitioned_dml( + dml, params, param_types, query_options + ) self.assertEqual(row_count, 2) @@ -975,6 +982,11 @@ def _execute_partitioned_dml_helper(self, dml, params=None, param_types=None): expected_params = None expected_transaction = TransactionSelector(id=self.TRANSACTION_ID) + expected_query_options = client._query_options + if query_options: + expected_query_options = _merge_query_options( + expected_query_options, query_options + ) api.execute_streaming_sql.assert_called_once_with( self.SESSION_NAME, @@ -982,6 +994,7 @@ def _execute_partitioned_dml_helper(self, dml, params=None, param_types=None): transaction=expected_transaction, params=expected_params, param_types=param_types, + query_options=expected_query_options, metadata=[("google-cloud-resource-prefix", database.name)], ) @@ -997,6 +1010,14 @@ def test_execute_partitioned_dml_w_params_and_param_types(self): dml=DML_W_PARAM, params=PARAMS, param_types=PARAM_TYPES ) + def test_execute_partitioned_dml_w_query_options(self): + from google.cloud.spanner_v1.proto.spanner_pb2 import ExecuteSqlRequest + + self._execute_partitioned_dml_helper( + dml=DML_W_PARAM, + query_options=ExecuteSqlRequest.QueryOptions(optimizer_version="3"), + ) + def test_session_factory_defaults(self): from google.cloud.spanner_v1.session import Session @@ -1615,7 +1636,9 @@ def test_process_read_batch(self): def test_generate_query_batches_w_max_partitions(self): sql = "SELECT COUNT(*) FROM table_name" max_partitions = len(self.TOKENS) - database = self._make_database() + client = _Client(self.PROJECT_ID) + instance = _Instance(self.INSTANCE_NAME, client=client) + database = _Database(self.DATABASE_NAME, instance=instance) batch_txn = self._make_one(database) snapshot = batch_txn._snapshot = self._make_snapshot() snapshot.partition_query.return_value = self.TOKENS @@ -1624,7 +1647,7 @@ def test_generate_query_batches_w_max_partitions(self): batch_txn.generate_query_batches(sql, max_partitions=max_partitions) ) - expected_query = {"sql": sql} + expected_query = {"sql": sql, "query_options": client._query_options} self.assertEqual(len(batches), len(self.TOKENS)) for batch, token in zip(batches, self.TOKENS): self.assertEqual(batch["partition"], token) @@ -1645,7 +1668,9 @@ def test_generate_query_batches_w_params_w_partition_size_bytes(self): params = {"max_age": 30} param_types = {"max_age": "INT64"} size = 1 << 20 - database = self._make_database() + client = _Client(self.PROJECT_ID) + instance = _Instance(self.INSTANCE_NAME, client=client) + database = _Database(self.DATABASE_NAME, instance=instance) batch_txn = self._make_one(database) snapshot = batch_txn._snapshot = self._make_snapshot() snapshot.partition_query.return_value = self.TOKENS @@ -1656,7 +1681,12 @@ def test_generate_query_batches_w_params_w_partition_size_bytes(self): ) ) - expected_query = {"sql": sql, "params": params, "param_types": param_types} + expected_query = { + "sql": sql, + "params": params, + "param_types": param_types, + "query_options": client._query_options, + } self.assertEqual(len(batches), len(self.TOKENS)) for batch, token in zip(batches, self.TOKENS): self.assertEqual(batch["partition"], token) @@ -1782,12 +1812,15 @@ def _make_instance_api(): class _Client(object): def __init__(self, project=TestDatabase.PROJECT_ID): + from google.cloud.spanner_v1.proto.spanner_pb2 import ExecuteSqlRequest + self.project = project self.project_name = "projects/" + self.project self._endpoint_cache = {} self.instance_admin_api = _make_instance_api() self._client_info = mock.Mock() self._client_options = mock.Mock() + self._query_options = ExecuteSqlRequest.QueryOptions(optimizer_version="1") class _Instance(object): diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 1eff634af0..e2bf18c723 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -362,6 +362,7 @@ def test_execute_sql_defaults(self): None, None, None, + query_options=None, timeout=google.api_core.gapic_v1.method.DEFAULT, retry=google.api_core.gapic_v1.method.DEFAULT, ) @@ -386,7 +387,13 @@ def test_execute_sql_non_default_retry(self): self.assertIs(found, snapshot().execute_sql.return_value) snapshot().execute_sql.assert_called_once_with( - SQL, params, param_types, "PLAN", timeout=None, retry=None + SQL, + params, + param_types, + "PLAN", + query_options=None, + timeout=None, + retry=None, ) def test_execute_sql_explicit(self): @@ -411,6 +418,7 @@ def test_execute_sql_explicit(self): params, param_types, "PLAN", + query_options=None, timeout=google.api_core.gapic_v1.method.DEFAULT, retry=google.api_core.gapic_v1.method.DEFAULT, ) diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index 883ab73258..e29b19d5f1 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -326,6 +326,7 @@ def _execute_sql_helper( count=0, partition=None, sql_count=0, + query_options=None, timeout=google.api_core.gapic_v1.method.DEFAULT, retry=google.api_core.gapic_v1.method.DEFAULT, ): @@ -341,7 +342,10 @@ def _execute_sql_helper( ) from google.cloud.spanner_v1.proto.type_pb2 import Type, StructType from google.cloud.spanner_v1.proto.type_pb2 import STRING, INT64 - from google.cloud.spanner_v1._helpers import _make_value_pb + from google.cloud.spanner_v1._helpers import ( + _make_value_pb, + _merge_query_options, + ) VALUES = [[u"bharney", u"rhubbyl", 31], [u"phred", u"phlyntstone", 32]] VALUE_PBS = [[_make_value_pb(item) for item in row] for row in VALUES] @@ -378,6 +382,7 @@ def _execute_sql_helper( PARAMS, PARAM_TYPES, query_mode=MODE, + query_options=query_options, partition=partition, retry=retry, timeout=timeout, @@ -410,6 +415,12 @@ def _execute_sql_helper( fields={key: _make_value_pb(value) for (key, value) in PARAMS.items()} ) + expected_query_options = database._instance._client._query_options + if query_options: + expected_query_options = _merge_query_options( + expected_query_options, query_options + ) + api.execute_streaming_sql.assert_called_once_with( self.SESSION_NAME, SQL_QUERY_WITH_PARAM, @@ -417,6 +428,7 @@ def _execute_sql_helper( params=expected_params, param_types=PARAM_TYPES, query_mode=MODE, + query_options=expected_query_options, partition_token=partition, seqno=sql_count, metadata=[("google-cloud-resource-prefix", database.name)], @@ -452,6 +464,14 @@ def test_execute_sql_w_retry(self): def test_execute_sql_w_timeout(self): self._execute_sql_helper(multi_use=False, timeout=None) + def test_execute_sql_w_query_options(self): + from google.cloud.spanner_v1.proto.spanner_pb2 import ExecuteSqlRequest + + self._execute_sql_helper( + multi_use=False, + query_options=ExecuteSqlRequest.QueryOptions(optimizer_version="3"), + ) + def _partition_read_helper( self, multi_use, w_txn, size=None, max_partitions=None, index=None ): @@ -971,16 +991,30 @@ def test_begin_ok_exact_strong(self): ) +class _Client(object): + def __init__(self): + from google.cloud.spanner_v1.proto.spanner_pb2 import ExecuteSqlRequest + + self._query_options = ExecuteSqlRequest.QueryOptions(optimizer_version="1") + + +class _Instance(object): + def __init__(self): + self._client = _Client() + + +class _Database(object): + def __init__(self): + self.name = "testing" + self._instance = _Instance() + + class _Session(object): def __init__(self, database=None, name=TestSnapshot.SESSION_NAME): self._database = database self.name = name -class _Database(object): - name = "testing" - - class _MockIterator(object): def __init__(self, *values, **kw): self._iter_values = iter(values) diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index 9ef13c2ab6..dcb6cb95d3 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -350,14 +350,17 @@ def test_execute_update_w_params_wo_param_types(self): with self.assertRaises(ValueError): transaction.execute_update(DML_QUERY_WITH_PARAM, PARAMS) - def _execute_update_helper(self, count=0): + def _execute_update_helper(self, count=0, query_options=None): from google.protobuf.struct_pb2 import Struct from google.cloud.spanner_v1.proto.result_set_pb2 import ( ResultSet, ResultSetStats, ) from google.cloud.spanner_v1.proto.transaction_pb2 import TransactionSelector - from google.cloud.spanner_v1._helpers import _make_value_pb + from google.cloud.spanner_v1._helpers import ( + _make_value_pb, + _merge_query_options, + ) MODE = 2 # PROFILE stats_pb = ResultSetStats(row_count_exact=1) @@ -370,7 +373,11 @@ def _execute_update_helper(self, count=0): transaction._execute_sql_count = count row_count = transaction.execute_update( - DML_QUERY_WITH_PARAM, PARAMS, PARAM_TYPES, query_mode=MODE + DML_QUERY_WITH_PARAM, + PARAMS, + PARAM_TYPES, + query_mode=MODE, + query_options=query_options, ) self.assertEqual(row_count, 1) @@ -380,6 +387,12 @@ def _execute_update_helper(self, count=0): fields={key: _make_value_pb(value) for (key, value) in PARAMS.items()} ) + expected_query_options = database._instance._client._query_options + if query_options: + expected_query_options = _merge_query_options( + expected_query_options, query_options + ) + api.execute_sql.assert_called_once_with( self.SESSION_NAME, DML_QUERY_WITH_PARAM, @@ -387,6 +400,7 @@ def _execute_update_helper(self, count=0): params=expected_params, param_types=PARAM_TYPES, query_mode=MODE, + query_options=expected_query_options, seqno=count, metadata=[("google-cloud-resource-prefix", database.name)], ) @@ -399,6 +413,13 @@ def test_execute_update_new_transaction(self): def test_execute_update_w_count(self): self._execute_update_helper(count=1) + def test_execute_update_w_query_options(self): + from google.cloud.spanner_v1.proto.spanner_pb2 import ExecuteSqlRequest + + self._execute_update_helper( + query_options=ExecuteSqlRequest.QueryOptions(optimizer_version="3") + ) + def test_batch_update_other_error(self): database = _Database() database.spanner_api = self._make_spanner_api() @@ -557,8 +578,22 @@ def test_context_mgr_failure(self): self.assertEqual(metadata, [("google-cloud-resource-prefix", database.name)]) +class _Client(object): + def __init__(self): + from google.cloud.spanner_v1.proto.spanner_pb2 import ExecuteSqlRequest + + self._query_options = ExecuteSqlRequest.QueryOptions(optimizer_version="1") + + +class _Instance(object): + def __init__(self): + self._client = _Client() + + class _Database(object): - name = "testing" + def __init__(self): + self.name = "testing" + self._instance = _Instance() class _Session(object):