From e072d5dd04d58fff7f62ce19ce42e906dfd11012 Mon Sep 17 00:00:00 2001 From: larkee <31196561+larkee@users.noreply.github.com> Date: Fri, 31 Jan 2020 13:15:45 +1100 Subject: [PATCH] feat(spanner): add resource based routing implementation (#10183) * feat(spanner): implement resource routing * corrected warning message as per the PR comment * Update spanner/google/cloud/spanner_v1/database.py Add comma to warning message Co-Authored-By: skuruppu Co-authored-by: skuruppu --- google/cloud/spanner_v1/client.py | 9 +- google/cloud/spanner_v1/database.py | 47 +++++ tests/system/test_system.py | 58 +++++++ tests/unit/test_client.py | 39 ++++- tests/unit/test_database.py | 255 +++++++++++++++++++++++++++- 5 files changed, 401 insertions(+), 7 deletions(-) diff --git a/google/cloud/spanner_v1/client.py b/google/cloud/spanner_v1/client.py index b35bf19f07..264731178e 100644 --- a/google/cloud/spanner_v1/client.py +++ b/google/cloud/spanner_v1/client.py @@ -26,6 +26,7 @@ import warnings from google.api_core.gapic_v1 import client_info +import google.api_core.client_options # pylint: disable=line-too-long from google.cloud.spanner_admin_database_v1.gapic.database_admin_client import ( # noqa @@ -122,6 +123,7 @@ class Client(ClientWithProject): _instance_admin_api = None _database_admin_api = None + _endpoint_cache = {} user_agent = None _SET_PROJECT = True # Used by from_service_account_json() @@ -143,7 +145,12 @@ def __init__( project=project, credentials=credentials, _http=None ) self._client_info = client_info - self._client_options = client_options + if client_options and type(client_options) == dict: + self._client_options = google.api_core.client_options.from_dict( + client_options + ) + else: + self._client_options = client_options if user_agent is not None: warnings.warn(_USER_AGENT_DEPRECATED, DeprecationWarning, stacklevel=2) diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index f561ecd4fa..49abe919d5 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -16,12 +16,16 @@ import copy import functools +import os import re import threading +import warnings +from google.api_core.client_options import ClientOptions import google.auth.credentials from google.protobuf.struct_pb2 import Struct from google.cloud.exceptions import NotFound +from google.api_core.exceptions import PermissionDenied import six # pylint: disable=ungrouped-imports @@ -54,6 +58,19 @@ ) +_RESOURCE_ROUTING_PERMISSIONS_WARNING = ( + "The client library attempted to connect to an endpoint closer to your Cloud Spanner data " + "but was unable to do so. The client library will fall back and route requests to the endpoint " + "given in the client options, which may result in increased latency. " + "We recommend including the scope https://www.googleapis.com/auth/spanner.admin so that the " + "client library can get an instance-specific endpoint and efficiently route requests." +) + + +class ResourceRoutingPermissionsWarning(Warning): + pass + + class Database(object): """Representation of a Cloud Spanner Database. @@ -178,6 +195,36 @@ def spanner_api(self): credentials = credentials.with_scopes((SPANNER_DATA_SCOPE,)) client_info = self._instance._client._client_info client_options = self._instance._client._client_options + if ( + os.getenv("GOOGLE_CLOUD_SPANNER_ENABLE_RESOURCE_BASED_ROUTING") + == "true" + ): + endpoint_cache = self._instance._client._endpoint_cache + if self._instance.name in endpoint_cache: + client_options = ClientOptions( + api_endpoint=endpoint_cache[self._instance.name] + ) + else: + try: + api = self._instance._client.instance_admin_api + resp = api.get_instance( + self._instance.name, + field_mask={"paths": ["endpoint_uris"]}, + metadata=_metadata_with_prefix(self.name), + ) + endpoints = resp.endpoint_uris + if endpoints: + endpoint_cache[self._instance.name] = list(endpoints)[0] + client_options = ClientOptions( + api_endpoint=endpoint_cache[self._instance.name] + ) + # If there are no endpoints, use default endpoint. + except PermissionDenied: + warnings.warn( + _RESOURCE_ROUTING_PERMISSIONS_WARNING, + ResourceRoutingPermissionsWarning, + stacklevel=2, + ) self._spanner_api = SpannerClient( credentials=credentials, client_info=client_info, diff --git a/tests/system/test_system.py b/tests/system/test_system.py index abfd1297d7..ae688029b4 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -56,6 +56,9 @@ CREATE_INSTANCE = os.getenv("GOOGLE_CLOUD_TESTS_CREATE_SPANNER_INSTANCE") is not None +USE_RESOURCE_ROUTING = ( + os.getenv("GOOGLE_CLOUD_SPANNER_ENABLE_RESOURCE_BASED_ROUTING") == "true" +) if CREATE_INSTANCE: INSTANCE_ID = "google-cloud" + unique_resource_id("-") @@ -282,6 +285,61 @@ def tearDown(self): for doomed in self.to_delete: doomed.drop() + @unittest.skipUnless(USE_RESOURCE_ROUTING, "requires enabling resource routing") + def test_spanner_api_use_user_specified_endpoint(self): + # Clear cache. + Client._endpoint_cache = {} + api = Config.CLIENT.instance_admin_api + resp = api.get_instance( + Config.INSTANCE.name, field_mask={"paths": ["endpoint_uris"]} + ) + if not resp or not resp.endpoint_uris: + return # no resolved endpoint. + resolved_endpoint = resp.endpoint_uris[0] + + client = Client(client_options={"api_endpoint": resolved_endpoint}) + + instance = client.instance(Config.INSTANCE.instance_id) + temp_db_id = "temp_db" + unique_resource_id("_") + temp_db = instance.database(temp_db_id) + temp_db.spanner_api + + # No endpoint cache - Default endpoint used. + self.assertEqual(client._endpoint_cache, {}) + + @unittest.skipUnless(USE_RESOURCE_ROUTING, "requires enabling resource routing") + def test_spanner_api_use_resolved_endpoint(self): + # Clear cache. + Client._endpoint_cache = {} + api = Config.CLIENT.instance_admin_api + resp = api.get_instance( + Config.INSTANCE.name, field_mask={"paths": ["endpoint_uris"]} + ) + if not resp or not resp.endpoint_uris: + return # no resolved endpoint. + resolved_endpoint = resp.endpoint_uris[0] + + client = Client( + client_options=Config.CLIENT._client_options + ) # Use same endpoint as main client. + + instance = client.instance(Config.INSTANCE.instance_id) + temp_db_id = "temp_db" + unique_resource_id("_") + temp_db = instance.database(temp_db_id) + temp_db.spanner_api + + # Endpoint is cached - resolved endpoint used. + self.assertIn(Config.INSTANCE.name, client._endpoint_cache) + self.assertEqual( + client._endpoint_cache[Config.INSTANCE.name], resolved_endpoint + ) + + # Endpoint is cached at a class level. + self.assertIn(Config.INSTANCE.name, Config.CLIENT._endpoint_cache) + self.assertEqual( + Config.CLIENT._endpoint_cache[Config.INSTANCE.name], resolved_endpoint + ) + def test_list_databases(self): # Since `Config.INSTANCE` is newly created in `setUpModule`, the # database created in `setUpClass` here will be the only one. diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index e42031cea4..35e63bfd68 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -57,6 +57,7 @@ def _constructor_test_helper( user_agent=None, client_options=None, ): + import google.api_core.client_options from google.cloud.spanner_v1 import client as MUT kwargs = {} @@ -66,6 +67,14 @@ def _constructor_test_helper( else: expected_client_info = MUT._CLIENT_INFO + kwargs["client_options"] = client_options + if type(client_options) == dict: + expected_client_options = google.api_core.client_options.from_dict( + client_options + ) + else: + expected_client_options = client_options + client = self._make_one( project=self.PROJECT, credentials=creds, user_agent=user_agent, **kwargs ) @@ -80,7 +89,14 @@ def _constructor_test_helper( self.assertEqual(client.project, self.PROJECT) self.assertIs(client._client_info, expected_client_info) self.assertEqual(client.user_agent, user_agent) - self.assertEqual(client._client_options, client_options) + if expected_client_options is not None: + self.assertIsInstance( + client._client_options, google.api_core.client_options.ClientOptions + ) + self.assertEqual( + client._client_options.api_endpoint, + expected_client_options.api_endpoint, + ) def test_constructor_default_scopes(self): from google.cloud.spanner_v1 import client as MUT @@ -127,6 +143,27 @@ def test_constructor_credentials_wo_create_scoped(self): expected_scopes = None self._constructor_test_helper(expected_scopes, creds) + def test_constructor_custom_client_options_obj(self): + from google.api_core.client_options import ClientOptions + from google.cloud.spanner_v1 import client as MUT + + expected_scopes = (MUT.SPANNER_ADMIN_SCOPE,) + creds = _make_credentials() + self._constructor_test_helper( + expected_scopes, + creds, + client_options=ClientOptions(api_endpoint="endpoint"), + ) + + def test_constructor_custom_client_options_dict(self): + from google.cloud.spanner_v1 import client as MUT + + expected_scopes = (MUT.SPANNER_ADMIN_SCOPE,) + creds = _make_credentials() + self._constructor_test_helper( + expected_scopes, creds, client_options={"api_endpoint": "endpoint"} + ) + def test_instance_admin_api(self): from google.cloud.spanner_v1.client import SPANNER_ADMIN_SCOPE diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index 163036f030..0f4071d868 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -231,7 +231,14 @@ def test_name_property(self): self.assertEqual(database.name, expected_name) def test_spanner_api_property_w_scopeless_creds(self): + from google.cloud.spanner_admin_instance_v1.proto import ( + spanner_instance_admin_pb2 as admin_v1_pb2, + ) + client = _Client() + client.instance_admin_api.get_instance.return_value = admin_v1_pb2.Instance( + endpoint_uris=[] + ) client_info = client._client_info = mock.Mock() client_options = client._client_options = mock.Mock() credentials = client.credentials = object() @@ -241,8 +248,10 @@ def test_spanner_api_property_w_scopeless_creds(self): patch = mock.patch("google.cloud.spanner_v1.database.SpannerClient") - with patch as spanner_client: - api = database.spanner_api + with mock.patch("os.getenv") as getenv: + getenv.return_value = "true" + with patch as spanner_client: + api = database.spanner_api self.assertIs(api, spanner_client.return_value) @@ -250,6 +259,7 @@ def test_spanner_api_property_w_scopeless_creds(self): again = database.spanner_api self.assertIs(again, api) + client.instance_admin_api.get_instance.assert_called_once() spanner_client.assert_called_once_with( credentials=credentials, client_info=client_info, @@ -258,6 +268,9 @@ def test_spanner_api_property_w_scopeless_creds(self): def test_spanner_api_w_scoped_creds(self): import google.auth.credentials + from google.cloud.spanner_admin_instance_v1.proto import ( + spanner_instance_admin_pb2 as admin_v1_pb2, + ) from google.cloud.spanner_v1.database import SPANNER_DATA_SCOPE class _CredentialsWithScopes(google.auth.credentials.Scoped): @@ -281,16 +294,22 @@ def with_scopes(self, scopes): database = self._make_one(self.DATABASE_ID, instance, pool=pool) patch = mock.patch("google.cloud.spanner_v1.database.SpannerClient") + client.instance_admin_api.get_instance.return_value = admin_v1_pb2.Instance( + endpoint_uris=[] + ) - with patch as spanner_client: - api = database.spanner_api + with mock.patch("os.getenv") as getenv: + getenv.return_value = "true" + with patch as spanner_client: + api = database.spanner_api - self.assertIs(api, spanner_client.return_value) + self.assertNotIn(instance.name, client._endpoint_cache) # API instance is cached again = database.spanner_api self.assertIs(again, api) + client.instance_admin_api.get_instance.assert_called_once() self.assertEqual(len(spanner_client.call_args_list), 1) called_args, called_kw = spanner_client.call_args self.assertEqual(called_args, ()) @@ -300,6 +319,222 @@ def with_scopes(self, scopes): self.assertEqual(scoped._scopes, expected_scopes) self.assertIs(scoped._source, credentials) + def test_spanner_api_property_w_scopeless_creds_and_new_endpoint(self): + from google.cloud.spanner_admin_instance_v1.proto import ( + spanner_instance_admin_pb2 as admin_v1_pb2, + ) + + client = _Client() + client.instance_admin_api.get_instance.return_value = admin_v1_pb2.Instance( + endpoint_uris=["test1", "test2"] + ) + client_info = client._client_info = mock.Mock() + client._client_options = mock.Mock() + credentials = client.credentials = object() + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + database = self._make_one(self.DATABASE_ID, instance, pool=pool) + + client_patch = mock.patch("google.cloud.spanner_v1.database.SpannerClient") + options_patch = mock.patch("google.cloud.spanner_v1.database.ClientOptions") + + with mock.patch("os.getenv") as getenv: + getenv.return_value = "true" + with options_patch as options: + with client_patch as spanner_client: + api = database.spanner_api + + self.assertIs(api, spanner_client.return_value) + self.assertIn(instance.name, client._endpoint_cache) + + # API instance is cached + again = database.spanner_api + self.assertIs(again, api) + + self.assertEqual(len(spanner_client.call_args_list), 1) + called_args, called_kw = spanner_client.call_args + self.assertEqual(called_args, ()) + self.assertEqual(called_kw["client_info"], client_info) + self.assertEqual(called_kw["credentials"], credentials) + options.assert_called_with(api_endpoint="test1") + + def test_spanner_api_w_scoped_creds_and_new_endpoint(self): + import google.auth.credentials + from google.cloud.spanner_admin_instance_v1.proto import ( + spanner_instance_admin_pb2 as admin_v1_pb2, + ) + from google.cloud.spanner_v1.database import SPANNER_DATA_SCOPE + + class _CredentialsWithScopes(google.auth.credentials.Scoped): + def __init__(self, scopes=(), source=None): + self._scopes = scopes + self._source = source + + def requires_scopes(self): # pragma: NO COVER + return True + + def with_scopes(self, scopes): + return self.__class__(scopes, self) + + expected_scopes = (SPANNER_DATA_SCOPE,) + client = _Client() + client_info = client._client_info = mock.Mock() + client._client_options = mock.Mock() + credentials = client.credentials = _CredentialsWithScopes() + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + database = self._make_one(self.DATABASE_ID, instance, pool=pool) + + client_patch = mock.patch("google.cloud.spanner_v1.database.SpannerClient") + options_patch = mock.patch("google.cloud.spanner_v1.database.ClientOptions") + client.instance_admin_api.get_instance.return_value = admin_v1_pb2.Instance( + endpoint_uris=["test1", "test2"] + ) + + with mock.patch("os.getenv") as getenv: + getenv.return_value = "true" + with options_patch as options: + with client_patch as spanner_client: + api = database.spanner_api + + self.assertIs(api, spanner_client.return_value) + self.assertIn(instance.name, client._endpoint_cache) + + # API instance is cached + again = database.spanner_api + self.assertIs(again, api) + + self.assertEqual(len(spanner_client.call_args_list), 1) + called_args, called_kw = spanner_client.call_args + self.assertEqual(called_args, ()) + self.assertEqual(called_kw["client_info"], client_info) + scoped = called_kw["credentials"] + self.assertEqual(scoped._scopes, expected_scopes) + self.assertIs(scoped._source, credentials) + options.assert_called_with(api_endpoint="test1") + + def test_spanner_api_resource_routing_permissions_error(self): + from google.api_core.exceptions import PermissionDenied + + client = _Client() + client_info = client._client_info = mock.Mock() + client_options = client._client_options = mock.Mock() + client._endpoint_cache = {} + credentials = client.credentials = mock.Mock() + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + database = self._make_one(self.DATABASE_ID, instance, pool=pool) + + patch = mock.patch("google.cloud.spanner_v1.database.SpannerClient") + client.instance_admin_api.get_instance.side_effect = PermissionDenied("test") + + with mock.patch("os.getenv") as getenv: + getenv.return_value = "true" + with patch as spanner_client: + api = database.spanner_api + + self.assertIs(api, spanner_client.return_value) + + # API instance is cached + again = database.spanner_api + self.assertIs(again, api) + + client.instance_admin_api.get_instance.assert_called_once() + spanner_client.assert_called_once_with( + credentials=credentials, + client_info=client_info, + client_options=client_options, + ) + + def test_spanner_api_disable_resource_routing(self): + client = _Client() + client_info = client._client_info = mock.Mock() + client_options = client._client_options = mock.Mock() + client._endpoint_cache = {} + credentials = client.credentials = mock.Mock() + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + database = self._make_one(self.DATABASE_ID, instance, pool=pool) + + patch = mock.patch("google.cloud.spanner_v1.database.SpannerClient") + + with mock.patch("os.getenv") as getenv: + getenv.return_value = "false" + with patch as spanner_client: + api = database.spanner_api + + self.assertIs(api, spanner_client.return_value) + + # API instance is cached + again = database.spanner_api + self.assertIs(again, api) + + client.instance_admin_api.get_instance.assert_not_called() + spanner_client.assert_called_once_with( + credentials=credentials, + client_info=client_info, + client_options=client_options, + ) + + def test_spanner_api_cached_endpoint(self): + from google.cloud.spanner_admin_instance_v1.proto import ( + spanner_instance_admin_pb2 as admin_v1_pb2, + ) + + client = _Client() + client_info = client._client_info = mock.Mock() + client._client_options = mock.Mock() + client._endpoint_cache = {self.INSTANCE_NAME: "cached"} + credentials = client.credentials = mock.Mock() + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + database = self._make_one(self.DATABASE_ID, instance, pool=pool) + + client_patch = mock.patch("google.cloud.spanner_v1.database.SpannerClient") + options_patch = mock.patch("google.cloud.spanner_v1.database.ClientOptions") + client.instance_admin_api.get_instance.return_value = admin_v1_pb2.Instance( + endpoint_uris=["test1", "test2"] + ) + + with mock.patch("os.getenv") as getenv: + getenv.return_value = "true" + with options_patch as options: + with client_patch as spanner_client: + api = database.spanner_api + + self.assertIs(api, spanner_client.return_value) + + # API instance is cached + again = database.spanner_api + self.assertIs(again, api) + + self.assertEqual(len(spanner_client.call_args_list), 1) + called_args, called_kw = spanner_client.call_args + self.assertEqual(called_args, ()) + self.assertEqual(called_kw["client_info"], client_info) + self.assertEqual(called_kw["credentials"], credentials) + options.assert_called_with(api_endpoint="cached") + + def test_spanner_api_resource_routing_error(self): + from google.api_core.exceptions import GoogleAPIError + + client = _Client() + client._client_info = mock.Mock() + client._client_options = mock.Mock() + client.credentials = mock.Mock() + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + database = self._make_one(self.DATABASE_ID, instance, pool=pool) + + client.instance_admin_api.get_instance.side_effect = GoogleAPIError("test") + + with mock.patch("os.getenv") as getenv: + getenv.return_value = "true" + with self.assertRaises(GoogleAPIError): + database.spanner_api + + client.instance_admin_api.get_instance.assert_called_once() + def test___eq__(self): instance = _Instance(self.INSTANCE_NAME) pool1, pool2 = _Pool(), _Pool() @@ -1516,10 +1751,20 @@ def test_process_w_query_batch(self): ) +def _make_instance_api(): + from google.cloud.spanner_admin_instance_v1.gapic.instance_admin_client import ( + InstanceAdminClient, + ) + + return mock.create_autospec(InstanceAdminClient) + + class _Client(object): def __init__(self, project=TestDatabase.PROJECT_ID): self.project = project self.project_name = "projects/" + self.project + self._endpoint_cache = {} + self.instance_admin_api = _make_instance_api() class _Instance(object):