From b315593bd3e473d96cc3033f5bbf0da7487e38eb Mon Sep 17 00:00:00 2001 From: larkee <31196561+larkee@users.noreply.github.com> Date: Wed, 19 Feb 2020 14:02:26 +1300 Subject: [PATCH] feat(spanner): add emulator support (#14) * emulator support implementation * facilitate running system test against an emulator * add tests * formatting * remove brittle error string checks * add skips for tests when emulator support is used * fix lint errors --- google/cloud/spanner_v1/client.py | 75 ++++++++++++++++++++++---- google/cloud/spanner_v1/database.py | 16 +++++- google/cloud/spanner_v1/instance.py | 2 + noxfile.py | 10 ++-- tests/system/test_system.py | 25 ++++----- tests/unit/test_client.py | 81 ++++++++++++++++++++++++++++- tests/unit/test_database.py | 26 ++++++++- 7 files changed, 202 insertions(+), 33 deletions(-) diff --git a/google/cloud/spanner_v1/client.py b/google/cloud/spanner_v1/client.py index 264731178e..c7b331adc0 100644 --- a/google/cloud/spanner_v1/client.py +++ b/google/cloud/spanner_v1/client.py @@ -23,11 +23,21 @@ * a :class:`~google.cloud.spanner_v1.instance.Instance` owns a :class:`~google.cloud.spanner_v1.database.Database` """ +import grpc +import os import warnings from google.api_core.gapic_v1 import client_info import google.api_core.client_options +from google.cloud.spanner_admin_instance_v1.gapic.transports import ( + instance_admin_grpc_transport, +) + +from google.cloud.spanner_admin_database_v1.gapic.transports import ( + database_admin_grpc_transport, +) + # pylint: disable=line-too-long from google.cloud.spanner_admin_database_v1.gapic.database_admin_client import ( # noqa DatabaseAdminClient, @@ -45,6 +55,12 @@ from google.cloud.spanner_v1.instance import Instance _CLIENT_INFO = client_info.ClientInfo(client_library_version=__version__) +EMULATOR_ENV_VAR = "SPANNER_EMULATOR_HOST" +_EMULATOR_HOST_HTTP_SCHEME = ( + "%s contains a http scheme. When used with a scheme it may cause gRPC's " + "DNS resolver to endlessly attempt to resolve. %s is intended to be used " + "without a scheme: ex %s=localhost:8080." +) % ((EMULATOR_ENV_VAR,) * 3) SPANNER_ADMIN_SCOPE = "https://www.googleapis.com/auth/spanner.admin" _USER_AGENT_DEPRECATED = ( "The 'user_agent' argument to 'Client' is deprecated / unused. " @@ -52,6 +68,10 @@ ) +def _get_spanner_emulator_host(): + return os.getenv(EMULATOR_ENV_VAR) + + class InstanceConfig(object): """Named configurations for Spanner instances. @@ -156,6 +176,12 @@ def __init__( warnings.warn(_USER_AGENT_DEPRECATED, DeprecationWarning, stacklevel=2) self.user_agent = user_agent + if _get_spanner_emulator_host() is not None and ( + "http://" in _get_spanner_emulator_host() + or "https://" in _get_spanner_emulator_host() + ): + warnings.warn(_EMULATOR_HOST_HTTP_SCHEME) + @property def credentials(self): """Getter for client's credentials. @@ -189,22 +215,42 @@ def project_name(self): def instance_admin_api(self): """Helper for session-related API calls.""" if self._instance_admin_api is None: - self._instance_admin_api = InstanceAdminClient( - credentials=self.credentials, - client_info=self._client_info, - client_options=self._client_options, - ) + if _get_spanner_emulator_host() is not None: + transport = instance_admin_grpc_transport.InstanceAdminGrpcTransport( + channel=grpc.insecure_channel(_get_spanner_emulator_host()) + ) + self._instance_admin_api = InstanceAdminClient( + client_info=self._client_info, + client_options=self._client_options, + transport=transport, + ) + else: + self._instance_admin_api = InstanceAdminClient( + credentials=self.credentials, + client_info=self._client_info, + client_options=self._client_options, + ) return self._instance_admin_api @property def database_admin_api(self): """Helper for session-related API calls.""" if self._database_admin_api is None: - self._database_admin_api = DatabaseAdminClient( - credentials=self.credentials, - client_info=self._client_info, - client_options=self._client_options, - ) + if _get_spanner_emulator_host() is not None: + transport = database_admin_grpc_transport.DatabaseAdminGrpcTransport( + channel=grpc.insecure_channel(_get_spanner_emulator_host()) + ) + self._database_admin_api = DatabaseAdminClient( + client_info=self._client_info, + client_options=self._client_options, + transport=transport, + ) + else: + self._database_admin_api = DatabaseAdminClient( + credentials=self.credentials, + client_info=self._client_info, + client_options=self._client_options, + ) return self._database_admin_api def copy(self): @@ -288,7 +334,14 @@ def instance( :rtype: :class:`~google.cloud.spanner_v1.instance.Instance` :returns: an instance owned by this client. """ - return Instance(instance_id, self, configuration_name, node_count, display_name) + return Instance( + instance_id, + self, + configuration_name, + node_count, + display_name, + _get_spanner_emulator_host(), + ) def list_instances(self, filter_="", page_size=None, page_token=None): """List instances for the client's project. diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 49abe919d5..f5ea3e46dd 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -16,6 +16,7 @@ import copy import functools +import grpc import os import re import threading @@ -33,6 +34,7 @@ from google.cloud.spanner_v1._helpers import _metadata_with_prefix from google.cloud.spanner_v1.batch import Batch from google.cloud.spanner_v1.gapic.spanner_client import SpannerClient +from google.cloud.spanner_v1.gapic.transports import spanner_grpc_transport from google.cloud.spanner_v1.keyset import KeySet from google.cloud.spanner_v1.pool import BurstyPool from google.cloud.spanner_v1.pool import SessionCheckout @@ -190,11 +192,21 @@ def ddl_statements(self): def spanner_api(self): """Helper for session-related API calls.""" if self._spanner_api is None: + client_info = self._instance._client._client_info + client_options = self._instance._client._client_options + if self._instance.emulator_host is not None: + transport = spanner_grpc_transport.SpannerGrpcTransport( + channel=grpc.insecure_channel(self._instance.emulator_host) + ) + self._spanner_api = SpannerClient( + client_info=client_info, + client_options=client_options, + transport=transport, + ) + return self._spanner_api credentials = self._instance._client.credentials if isinstance(credentials, google.auth.credentials.Scoped): credentials = credentials.with_scopes((SPANNER_DATA_SCOPE,)) - client_info = self._instance._client._client_info - client_options = self._instance._client._client_options if ( os.getenv("GOOGLE_CLOUD_SPANNER_ENABLE_RESOURCE_BASED_ROUTING") == "true" diff --git a/google/cloud/spanner_v1/instance.py b/google/cloud/spanner_v1/instance.py index 83a600bd10..05e596622c 100644 --- a/google/cloud/spanner_v1/instance.py +++ b/google/cloud/spanner_v1/instance.py @@ -76,12 +76,14 @@ def __init__( configuration_name=None, node_count=DEFAULT_NODE_COUNT, display_name=None, + emulator_host=None, ): self.instance_id = instance_id self._client = client self.configuration_name = configuration_name self.node_count = node_count self.display_name = display_name or instance_id + self.emulator_host = emulator_host def _update_from_pb(self, instance_pb): """Refresh self from the server-provided protobuf. diff --git a/noxfile.py b/noxfile.py index 200b68e04c..22f328c4af 100644 --- a/noxfile.py +++ b/noxfile.py @@ -94,9 +94,13 @@ def system(session): """Run the system test suite.""" system_test_path = os.path.join("tests", "system.py") system_test_folder_path = os.path.join("tests", "system") - # Sanity check: Only run tests if the environment variable is set. - if not os.environ.get("GOOGLE_APPLICATION_CREDENTIALS", ""): - session.skip("Credentials must be set via environment variable") + # Sanity check: Only run tests if either credentials or emulator host is set. + if not os.environ.get("GOOGLE_APPLICATION_CREDENTIALS", "") and not os.environ.get( + "SPANNER_EMULATOR_HOST", "" + ): + session.skip( + "Credentials or emulator host must be set via environment variable" + ) system_test_exists = os.path.exists(system_test_path) system_test_folder_exists = os.path.exists(system_test_folder_path) diff --git a/tests/system/test_system.py b/tests/system/test_system.py index ae688029b4..a8d349e677 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -56,6 +56,7 @@ CREATE_INSTANCE = os.getenv("GOOGLE_CLOUD_TESTS_CREATE_SPANNER_INSTANCE") is not None +USE_EMULATOR = os.getenv("SPANNER_EMULATOR_HOST") is not None USE_RESOURCE_ROUTING = ( os.getenv("GOOGLE_CLOUD_SPANNER_ENABLE_RESOURCE_BASED_ROUTING") == "true" ) @@ -105,10 +106,10 @@ def setUpModule(): EXISTING_INSTANCES[:] = instances if CREATE_INSTANCE: - - # Defend against back-end returning configs for regions we aren't - # actually allowed to use. - configs = [config for config in configs if "-us-" in config.name] + if not USE_EMULATOR: + # Defend against back-end returning configs for regions we aren't + # actually allowed to use. + configs = [config for config in configs if "-us-" in config.name] if not configs: raise ValueError("List instance configs failed in module set up.") @@ -185,6 +186,7 @@ def test_create_instance(self): self.assertEqual(instance, instance_alt) self.assertEqual(instance.display_name, instance_alt.display_name) + @unittest.skipIf(USE_EMULATOR, "Skipping updating instance") def test_update_instance(self): OLD_DISPLAY_NAME = Config.INSTANCE.display_name NEW_DISPLAY_NAME = "Foo Bar Baz" @@ -382,12 +384,9 @@ def test_table_not_found(self): temp_db_id, ddl_statements=[create_table, index] ) self.to_delete.append(temp_db) - with self.assertRaises(exceptions.NotFound) as exc_info: + with self.assertRaises(exceptions.NotFound): temp_db.create() - expected = "Table not found: {0}".format(incorrect_table) - self.assertEqual(exc_info.exception.args, (expected,)) - @pytest.mark.skip( reason=( "update_dataset_ddl() has a flaky timeout" @@ -993,6 +992,7 @@ def test_transaction_batch_update_wo_statements(self): with self.assertRaises(InvalidArgument): transaction.batch_update([]) + @unittest.skipIf(USE_EMULATOR, "Skipping partitioned DML") def test_execute_partitioned_dml(self): # [START spanner_test_dml_partioned_dml_update] retry = RetryInstanceState(_has_all_ddl) @@ -1625,6 +1625,7 @@ def test_read_with_range_keys_and_index_open_open(self): expected = [data[keyrow]] + data[start + 1 : end] self.assertEqual(rows, expected) + @unittest.skipIf(USE_EMULATOR, "Skipping partitioned reads") def test_partition_read_w_index(self): row_count = 10 columns = self.COLUMNS[1], self.COLUMNS[2] @@ -1724,16 +1725,11 @@ def test_invalid_type(self): batch.insert(table, columns, valid_input) invalid_input = ((0, ""),) - with self.assertRaises(exceptions.FailedPrecondition) as exc_info: + with self.assertRaises(exceptions.FailedPrecondition): with self._db.batch() as batch: batch.delete(table, self.ALL) batch.insert(table, columns, invalid_input) - error_msg = ( - "Invalid value for column value in table " "counters: Expected INT64." - ) - self.assertIn(error_msg, str(exc_info.exception)) - def test_execute_sql_select_1(self): self._db.snapshot(multi_use=True) @@ -2111,6 +2107,7 @@ def test_execute_sql_returning_transfinite_floats(self): # NaNs cannot be searched for by equality. self.assertTrue(math.isnan(float_array[2])) + @unittest.skipIf(USE_EMULATOR, "Skipping partitioned queries") def test_partition_query(self): row_count = 40 sql = "SELECT * FROM {}".format(self.TABLE) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 35e63bfd68..2e04537e02 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -98,6 +98,17 @@ def _constructor_test_helper( expected_client_options.api_endpoint, ) + @mock.patch("google.cloud.spanner_v1.client.os.getenv") + @mock.patch("warnings.warn") + def test_constructor_emulator_host_warning(self, mock_warn, mock_os): + from google.cloud.spanner_v1 import client as MUT + + expected_scopes = (MUT.SPANNER_ADMIN_SCOPE,) + creds = _make_credentials() + mock_os.return_value = "http://emulator.host.com" + self._constructor_test_helper(expected_scopes, creds) + mock_warn.assert_called_once_with(MUT._EMULATOR_HOST_HTTP_SCHEME) + def test_constructor_default_scopes(self): from google.cloud.spanner_v1 import client as MUT @@ -164,7 +175,8 @@ def test_constructor_custom_client_options_dict(self): expected_scopes, creds, client_options={"api_endpoint": "endpoint"} ) - def test_instance_admin_api(self): + @mock.patch("google.cloud.spanner_v1.client.os.getenv") + def test_instance_admin_api(self, mock_getenv): from google.cloud.spanner_v1.client import SPANNER_ADMIN_SCOPE credentials = _make_credentials() @@ -178,6 +190,7 @@ def test_instance_admin_api(self): ) expected_scopes = (SPANNER_ADMIN_SCOPE,) + mock_getenv.return_value = None inst_module = "google.cloud.spanner_v1.client.InstanceAdminClient" with mock.patch(inst_module) as instance_admin_client: api = client.instance_admin_api @@ -196,7 +209,39 @@ def test_instance_admin_api(self): credentials.with_scopes.assert_called_once_with(expected_scopes) - def test_database_admin_api(self): + @mock.patch("google.cloud.spanner_v1.client.os.getenv") + def test_instance_admin_api_emulator(self, mock_getenv): + credentials = _make_credentials() + client_info = mock.Mock() + client_options = mock.Mock() + client = self._make_one( + project=self.PROJECT, + credentials=credentials, + client_info=client_info, + client_options=client_options, + ) + + mock_getenv.return_value = "true" + inst_module = "google.cloud.spanner_v1.client.InstanceAdminClient" + with mock.patch(inst_module) as instance_admin_client: + api = client.instance_admin_api + + self.assertIs(api, instance_admin_client.return_value) + + # API instance is cached + again = client.instance_admin_api + self.assertIs(again, api) + + self.assertEqual(len(instance_admin_client.call_args_list), 1) + called_args, called_kw = instance_admin_client.call_args + self.assertEqual(called_args, ()) + self.assertEqual(called_kw["client_info"], client_info) + self.assertEqual(called_kw["client_options"], client_options) + self.assertIn("transport", called_kw) + self.assertNotIn("credentials", called_kw) + + @mock.patch("google.cloud.spanner_v1.client.os.getenv") + def test_database_admin_api(self, mock_getenv): from google.cloud.spanner_v1.client import SPANNER_ADMIN_SCOPE credentials = _make_credentials() @@ -210,6 +255,7 @@ def test_database_admin_api(self): ) expected_scopes = (SPANNER_ADMIN_SCOPE,) + mock_getenv.return_value = None db_module = "google.cloud.spanner_v1.client.DatabaseAdminClient" with mock.patch(db_module) as database_admin_client: api = client.database_admin_api @@ -228,6 +274,37 @@ def test_database_admin_api(self): credentials.with_scopes.assert_called_once_with(expected_scopes) + @mock.patch("google.cloud.spanner_v1.client.os.getenv") + def test_database_admin_api_emulator(self, mock_getenv): + credentials = _make_credentials() + client_info = mock.Mock() + client_options = mock.Mock() + client = self._make_one( + project=self.PROJECT, + credentials=credentials, + client_info=client_info, + client_options=client_options, + ) + + mock_getenv.return_value = "true" + db_module = "google.cloud.spanner_v1.client.DatabaseAdminClient" + with mock.patch(db_module) as database_admin_client: + api = client.database_admin_api + + self.assertIs(api, database_admin_client.return_value) + + # API instance is cached + again = client.database_admin_api + self.assertIs(again, api) + + self.assertEqual(len(database_admin_client.call_args_list), 1) + called_args, called_kw = database_admin_client.call_args + self.assertEqual(called_args, ()) + self.assertEqual(called_kw["client_info"], client_info) + self.assertEqual(called_kw["client_options"], client_options) + self.assertIn("transport", called_kw) + self.assertNotIn("credentials", called_kw) + def test_copy(self): credentials = _make_credentials() # Make sure it "already" is scoped. diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index 0f4071d868..7bf14de751 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -535,6 +535,27 @@ def test_spanner_api_resource_routing_error(self): client.instance_admin_api.get_instance.assert_called_once() + def test_spanner_api_w_emulator_host(self): + client = _Client() + instance = _Instance(self.INSTANCE_NAME, client=client, emulator_host="host") + pool = _Pool() + database = self._make_one(self.DATABASE_ID, instance, pool=pool) + + patch = mock.patch("google.cloud.spanner_v1.database.SpannerClient") + with patch as spanner_client: + api = database.spanner_api + + self.assertIs(api, spanner_client.return_value) + + # API instance is cached + again = database.spanner_api + self.assertIs(again, api) + + self.assertEqual(len(spanner_client.call_args_list), 1) + called_args, called_kw = spanner_client.call_args + self.assertEqual(called_args, ()) + self.assertIsNotNone(called_kw["transport"]) + def test___eq__(self): instance = _Instance(self.INSTANCE_NAME) pool1, pool2 = _Pool(), _Pool() @@ -1765,13 +1786,16 @@ def __init__(self, project=TestDatabase.PROJECT_ID): self.project_name = "projects/" + self.project self._endpoint_cache = {} self.instance_admin_api = _make_instance_api() + self._client_info = mock.Mock() + self._client_options = mock.Mock() class _Instance(object): - def __init__(self, name, client=None): + def __init__(self, name, client=None, emulator_host=None): self.name = name self.instance_id = name.rsplit("/", 1)[1] self._client = client + self.emulator_host = emulator_host class _Database(object):