Skip to content

Commit

Permalink
feat(spanner): add emulator support (#14)
Browse files Browse the repository at this point in the history
* emulator support implementation

* facilitate running system test against an emulator

* add tests

* formatting

* remove brittle error string checks

* add skips for tests when emulator support is used

* fix lint errors
  • Loading branch information
larkee committed Feb 19, 2020
1 parent 250f19e commit b315593
Show file tree
Hide file tree
Showing 7 changed files with 202 additions and 33 deletions.
75 changes: 64 additions & 11 deletions google/cloud/spanner_v1/client.py
Expand Up @@ -23,11 +23,21 @@
* a :class:`~google.cloud.spanner_v1.instance.Instance` owns a
:class:`~google.cloud.spanner_v1.database.Database`
"""
import grpc
import os
import warnings

from google.api_core.gapic_v1 import client_info
import google.api_core.client_options

from google.cloud.spanner_admin_instance_v1.gapic.transports import (
instance_admin_grpc_transport,
)

from google.cloud.spanner_admin_database_v1.gapic.transports import (
database_admin_grpc_transport,
)

# pylint: disable=line-too-long
from google.cloud.spanner_admin_database_v1.gapic.database_admin_client import ( # noqa
DatabaseAdminClient,
Expand All @@ -45,13 +55,23 @@
from google.cloud.spanner_v1.instance import Instance

_CLIENT_INFO = client_info.ClientInfo(client_library_version=__version__)
EMULATOR_ENV_VAR = "SPANNER_EMULATOR_HOST"
_EMULATOR_HOST_HTTP_SCHEME = (
"%s contains a http scheme. When used with a scheme it may cause gRPC's "
"DNS resolver to endlessly attempt to resolve. %s is intended to be used "
"without a scheme: ex %s=localhost:8080."
) % ((EMULATOR_ENV_VAR,) * 3)
SPANNER_ADMIN_SCOPE = "https://www.googleapis.com/auth/spanner.admin"
_USER_AGENT_DEPRECATED = (
"The 'user_agent' argument to 'Client' is deprecated / unused. "
"Please pass an appropriate 'client_info' instead."
)


def _get_spanner_emulator_host():
return os.getenv(EMULATOR_ENV_VAR)


class InstanceConfig(object):
"""Named configurations for Spanner instances.
Expand Down Expand Up @@ -156,6 +176,12 @@ def __init__(
warnings.warn(_USER_AGENT_DEPRECATED, DeprecationWarning, stacklevel=2)
self.user_agent = user_agent

if _get_spanner_emulator_host() is not None and (
"http://" in _get_spanner_emulator_host()
or "https://" in _get_spanner_emulator_host()
):
warnings.warn(_EMULATOR_HOST_HTTP_SCHEME)

@property
def credentials(self):
"""Getter for client's credentials.
Expand Down Expand Up @@ -189,22 +215,42 @@ def project_name(self):
def instance_admin_api(self):
"""Helper for session-related API calls."""
if self._instance_admin_api is None:
self._instance_admin_api = InstanceAdminClient(
credentials=self.credentials,
client_info=self._client_info,
client_options=self._client_options,
)
if _get_spanner_emulator_host() is not None:
transport = instance_admin_grpc_transport.InstanceAdminGrpcTransport(
channel=grpc.insecure_channel(_get_spanner_emulator_host())
)
self._instance_admin_api = InstanceAdminClient(
client_info=self._client_info,
client_options=self._client_options,
transport=transport,
)
else:
self._instance_admin_api = InstanceAdminClient(
credentials=self.credentials,
client_info=self._client_info,
client_options=self._client_options,
)
return self._instance_admin_api

@property
def database_admin_api(self):
"""Helper for session-related API calls."""
if self._database_admin_api is None:
self._database_admin_api = DatabaseAdminClient(
credentials=self.credentials,
client_info=self._client_info,
client_options=self._client_options,
)
if _get_spanner_emulator_host() is not None:
transport = database_admin_grpc_transport.DatabaseAdminGrpcTransport(
channel=grpc.insecure_channel(_get_spanner_emulator_host())
)
self._database_admin_api = DatabaseAdminClient(
client_info=self._client_info,
client_options=self._client_options,
transport=transport,
)
else:
self._database_admin_api = DatabaseAdminClient(
credentials=self.credentials,
client_info=self._client_info,
client_options=self._client_options,
)
return self._database_admin_api

def copy(self):
Expand Down Expand Up @@ -288,7 +334,14 @@ def instance(
:rtype: :class:`~google.cloud.spanner_v1.instance.Instance`
:returns: an instance owned by this client.
"""
return Instance(instance_id, self, configuration_name, node_count, display_name)
return Instance(
instance_id,
self,
configuration_name,
node_count,
display_name,
_get_spanner_emulator_host(),
)

def list_instances(self, filter_="", page_size=None, page_token=None):
"""List instances for the client's project.
Expand Down
16 changes: 14 additions & 2 deletions google/cloud/spanner_v1/database.py
Expand Up @@ -16,6 +16,7 @@

import copy
import functools
import grpc
import os
import re
import threading
Expand All @@ -33,6 +34,7 @@
from google.cloud.spanner_v1._helpers import _metadata_with_prefix
from google.cloud.spanner_v1.batch import Batch
from google.cloud.spanner_v1.gapic.spanner_client import SpannerClient
from google.cloud.spanner_v1.gapic.transports import spanner_grpc_transport
from google.cloud.spanner_v1.keyset import KeySet
from google.cloud.spanner_v1.pool import BurstyPool
from google.cloud.spanner_v1.pool import SessionCheckout
Expand Down Expand Up @@ -190,11 +192,21 @@ def ddl_statements(self):
def spanner_api(self):
"""Helper for session-related API calls."""
if self._spanner_api is None:
client_info = self._instance._client._client_info
client_options = self._instance._client._client_options
if self._instance.emulator_host is not None:
transport = spanner_grpc_transport.SpannerGrpcTransport(
channel=grpc.insecure_channel(self._instance.emulator_host)
)
self._spanner_api = SpannerClient(
client_info=client_info,
client_options=client_options,
transport=transport,
)
return self._spanner_api
credentials = self._instance._client.credentials
if isinstance(credentials, google.auth.credentials.Scoped):
credentials = credentials.with_scopes((SPANNER_DATA_SCOPE,))
client_info = self._instance._client._client_info
client_options = self._instance._client._client_options
if (
os.getenv("GOOGLE_CLOUD_SPANNER_ENABLE_RESOURCE_BASED_ROUTING")
== "true"
Expand Down
2 changes: 2 additions & 0 deletions google/cloud/spanner_v1/instance.py
Expand Up @@ -76,12 +76,14 @@ def __init__(
configuration_name=None,
node_count=DEFAULT_NODE_COUNT,
display_name=None,
emulator_host=None,
):
self.instance_id = instance_id
self._client = client
self.configuration_name = configuration_name
self.node_count = node_count
self.display_name = display_name or instance_id
self.emulator_host = emulator_host

def _update_from_pb(self, instance_pb):
"""Refresh self from the server-provided protobuf.
Expand Down
10 changes: 7 additions & 3 deletions noxfile.py
Expand Up @@ -94,9 +94,13 @@ def system(session):
"""Run the system test suite."""
system_test_path = os.path.join("tests", "system.py")
system_test_folder_path = os.path.join("tests", "system")
# Sanity check: Only run tests if the environment variable is set.
if not os.environ.get("GOOGLE_APPLICATION_CREDENTIALS", ""):
session.skip("Credentials must be set via environment variable")
# Sanity check: Only run tests if either credentials or emulator host is set.
if not os.environ.get("GOOGLE_APPLICATION_CREDENTIALS", "") and not os.environ.get(
"SPANNER_EMULATOR_HOST", ""
):
session.skip(
"Credentials or emulator host must be set via environment variable"
)

system_test_exists = os.path.exists(system_test_path)
system_test_folder_exists = os.path.exists(system_test_folder_path)
Expand Down
25 changes: 11 additions & 14 deletions tests/system/test_system.py
Expand Up @@ -56,6 +56,7 @@


CREATE_INSTANCE = os.getenv("GOOGLE_CLOUD_TESTS_CREATE_SPANNER_INSTANCE") is not None
USE_EMULATOR = os.getenv("SPANNER_EMULATOR_HOST") is not None
USE_RESOURCE_ROUTING = (
os.getenv("GOOGLE_CLOUD_SPANNER_ENABLE_RESOURCE_BASED_ROUTING") == "true"
)
Expand Down Expand Up @@ -105,10 +106,10 @@ def setUpModule():
EXISTING_INSTANCES[:] = instances

if CREATE_INSTANCE:

# Defend against back-end returning configs for regions we aren't
# actually allowed to use.
configs = [config for config in configs if "-us-" in config.name]
if not USE_EMULATOR:
# Defend against back-end returning configs for regions we aren't
# actually allowed to use.
configs = [config for config in configs if "-us-" in config.name]

if not configs:
raise ValueError("List instance configs failed in module set up.")
Expand Down Expand Up @@ -185,6 +186,7 @@ def test_create_instance(self):
self.assertEqual(instance, instance_alt)
self.assertEqual(instance.display_name, instance_alt.display_name)

@unittest.skipIf(USE_EMULATOR, "Skipping updating instance")
def test_update_instance(self):
OLD_DISPLAY_NAME = Config.INSTANCE.display_name
NEW_DISPLAY_NAME = "Foo Bar Baz"
Expand Down Expand Up @@ -382,12 +384,9 @@ def test_table_not_found(self):
temp_db_id, ddl_statements=[create_table, index]
)
self.to_delete.append(temp_db)
with self.assertRaises(exceptions.NotFound) as exc_info:
with self.assertRaises(exceptions.NotFound):
temp_db.create()

expected = "Table not found: {0}".format(incorrect_table)
self.assertEqual(exc_info.exception.args, (expected,))

@pytest.mark.skip(
reason=(
"update_dataset_ddl() has a flaky timeout"
Expand Down Expand Up @@ -993,6 +992,7 @@ def test_transaction_batch_update_wo_statements(self):
with self.assertRaises(InvalidArgument):
transaction.batch_update([])

@unittest.skipIf(USE_EMULATOR, "Skipping partitioned DML")
def test_execute_partitioned_dml(self):
# [START spanner_test_dml_partioned_dml_update]
retry = RetryInstanceState(_has_all_ddl)
Expand Down Expand Up @@ -1625,6 +1625,7 @@ def test_read_with_range_keys_and_index_open_open(self):
expected = [data[keyrow]] + data[start + 1 : end]
self.assertEqual(rows, expected)

@unittest.skipIf(USE_EMULATOR, "Skipping partitioned reads")
def test_partition_read_w_index(self):
row_count = 10
columns = self.COLUMNS[1], self.COLUMNS[2]
Expand Down Expand Up @@ -1724,16 +1725,11 @@ def test_invalid_type(self):
batch.insert(table, columns, valid_input)

invalid_input = ((0, ""),)
with self.assertRaises(exceptions.FailedPrecondition) as exc_info:
with self.assertRaises(exceptions.FailedPrecondition):
with self._db.batch() as batch:
batch.delete(table, self.ALL)
batch.insert(table, columns, invalid_input)

error_msg = (
"Invalid value for column value in table " "counters: Expected INT64."
)
self.assertIn(error_msg, str(exc_info.exception))

def test_execute_sql_select_1(self):

self._db.snapshot(multi_use=True)
Expand Down Expand Up @@ -2111,6 +2107,7 @@ def test_execute_sql_returning_transfinite_floats(self):
# NaNs cannot be searched for by equality.
self.assertTrue(math.isnan(float_array[2]))

@unittest.skipIf(USE_EMULATOR, "Skipping partitioned queries")
def test_partition_query(self):
row_count = 40
sql = "SELECT * FROM {}".format(self.TABLE)
Expand Down

0 comments on commit b315593

Please sign in to comment.