Skip to content

Commit

Permalink
feat: add support for restore a database with CMEK
Browse files Browse the repository at this point in the history
  • Loading branch information
larkee committed Jan 5, 2021
1 parent 023b2e5 commit 514e675
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 35 deletions.
27 changes: 17 additions & 10 deletions google/cloud/spanner_v1/database.py
Expand Up @@ -47,6 +47,8 @@
)
from google.cloud.spanner_admin_database_v1 import CreateDatabaseRequest
from google.cloud.spanner_admin_database_v1 import EncryptionConfig
from google.cloud.spanner_admin_database_v1 import RestoreDatabaseEncryptionConfig
from google.cloud.spanner_admin_database_v1 import RestoreDatabaseRequest
from google.cloud.spanner_admin_database_v1 import UpdateDatabaseDdlRequest
from google.cloud.spanner_v1 import ExecuteSqlRequest
from google.cloud.spanner_v1 import (
Expand Down Expand Up @@ -102,8 +104,9 @@ class Database(object):
or :class:`dict`
:param encryption_config:
(Optional) Encryption information about the database.
If a dict is provided, it must be of the same form as the protobuf
message :class:`~google.cloud.spanner_admin_database_v1.types.EncryptionConfig`
If a dict is provided, it must be of the same form as either of the protobuf
messages :class:`~google.cloud.spanner_admin_database_v1.types.EncryptionConfig`
or :class:`~google.cloud.spanner_admin_database_v1.types.RestoreDatabaseEncryptionConfig`
"""

_spanner_api = None
Expand All @@ -123,11 +126,7 @@ def __init__(
self._state = None
self._create_time = None
self._restore_info = None

if type(encryption_config) == dict:
self._encryption_config = EncryptionConfig(**encryption_config)
else:
self._encryption_config = encryption_config
self._encryption_config = encryption_config

if pool is None:
pool = BurstyPool()
Expand Down Expand Up @@ -297,6 +296,8 @@ def create(self):
db_name = self.database_id
if "-" in db_name:
db_name = "`%s`" % (db_name,)
if type(self._encryption_config) == dict:
self._encryption_config = EncryptionConfig(**self._encryption_config)

request = CreateDatabaseRequest(
parent=self._instance.name,
Expand Down Expand Up @@ -560,8 +561,8 @@ def run_in_transaction(self, func, *args, **kw):
def restore(self, source):
"""Restore from a backup to this database.
:type backup: :class:`~google.cloud.spanner_v1.backup.Backup`
:param backup: the path of the backup being restored from.
:type source: :class:`~google.cloud.spanner_v1.backup.Backup`
:param source: the path of the source being restored from.
:rtype: :class:`~google.api_core.operation.Operation`
:returns: a future used to poll the status of the create request
Expand All @@ -575,10 +576,16 @@ def restore(self, source):
raise ValueError("Restore source not specified")
api = self._instance._client.database_admin_api
metadata = _metadata_with_prefix(self.name)
future = api.restore_database(
if type(self._encryption_config) == dict:
self._encryption_config = RestoreDatabaseEncryptionConfig(**self._encryption_config)
request = RestoreDatabaseRequest(
parent=self._instance.name,
database_id=self.database_id,
backup=source.name,
encryption_config=self._encryption_config
)
future = api.restore_database(
request=request,
metadata=metadata,
)
return future
Expand Down
16 changes: 12 additions & 4 deletions google/cloud/spanner_v1/instance.py
Expand Up @@ -373,12 +373,14 @@ def database(
:param pool: (Optional) session pool to be used by database.
:type encryption_config:
:class:`~google.cloud.spanner_admin_database_v1.types.EncryptionConfig`
:class:`~google.cloud.spanner_admin_database_v1.types.EncryptionConfig` or
:class:`~google.cloud.spanner_admin_database_v1.types.RestoreDatabaseEncryptionConfig`
or :class:`dict`
:param encryption_config:
(Optional) Encryption information about the database.
If a dict is provided, it must be of the same form as the protobuf
message :class:`~google.cloud.spanner_admin_database_v1.types.EncryptionConfig
If a dict is provided, it must be of the same form as either of the protobuf
messages :class:`~google.cloud.spanner_admin_database_v1.types.EncryptionConfig`
or :class:`~google.cloud.spanner_admin_database_v1.types.RestoreDatabaseEncryptionConfig`
:rtype: :class:`~google.cloud.spanner_v1.database.Database`
:returns: a database owned by this instance.
Expand Down Expand Up @@ -444,7 +446,13 @@ def backup(self, backup_id, database="", expire_time=None, encryption_config=Non
backup_id, self, database=database.name, expire_time=expire_time, encryption_config=encryption_config
)
except AttributeError:
return Backup(backup_id, self, database=database, expire_time=expire_time)
return Backup(
backup_id,
self,
database=database,
expire_time=expire_time,
encryption_config=encryption_config
)

def list_backups(self, filter_="", page_size=None):
"""List backups for the instance.
Expand Down
118 changes: 99 additions & 19 deletions tests/unit/test_database.py
Expand Up @@ -157,20 +157,6 @@ def test_ctor_w_encryption_config(self):
self.assertIs(database._instance, instance)
self.assertEqual(database._encryption_config, encryption_config)

def test_ctor_w_encryption_config_dict(self):
from google.cloud.spanner_admin_database_v1 import EncryptionConfig

instance = _Instance(self.INSTANCE_NAME)
encryption_config_dict = {"kms_key_name": "kms_key"}
encryption_config = EncryptionConfig(kms_key_name="kms_key")
database = self._make_one(
self.DATABASE_ID, instance, encryption_config=encryption_config_dict
)
self.assertEqual(database.database_id, self.DATABASE_ID)
self.assertIs(database._instance, instance)
self.assertEqual(database._encryption_config, encryption_config)


def test_from_pb_bad_database_name(self):
from google.cloud.spanner_admin_database_v1 import Database

Expand Down Expand Up @@ -487,15 +473,17 @@ def test_create_instance_not_found(self):
def test_create_success(self):
from tests._fixtures import DDL_STATEMENTS
from google.cloud.spanner_admin_database_v1 import CreateDatabaseRequest
from google.cloud.spanner_admin_database_v1 import EncryptionConfig

op_future = object()
client = _Client()
api = client.database_admin_api = self._make_database_admin_api()
api.create_database.return_value = op_future
instance = _Instance(self.INSTANCE_NAME, client=client)
pool = _Pool()
encryption_config = EncryptionConfig(kms_key_name="kms_key_name")
database = self._make_one(
self.DATABASE_ID, instance, ddl_statements=DDL_STATEMENTS, pool=pool
self.DATABASE_ID, instance, ddl_statements=DDL_STATEMENTS, pool=pool, encryption_config=encryption_config
)

future = database.create()
Expand All @@ -506,7 +494,40 @@ def test_create_success(self):
parent=self.INSTANCE_NAME,
create_statement="CREATE DATABASE {}".format(self.DATABASE_ID),
extra_statements=DDL_STATEMENTS,
encryption_config=None,
encryption_config=encryption_config,
)

api.create_database.assert_called_once_with(
request=expected_request,
metadata=[("google-cloud-resource-prefix", database.name)],
)

def test_create_success_w_encryption_config_dict(self):
from tests._fixtures import DDL_STATEMENTS
from google.cloud.spanner_admin_database_v1 import CreateDatabaseRequest
from google.cloud.spanner_admin_database_v1 import EncryptionConfig

op_future = object()
client = _Client()
api = client.database_admin_api = self._make_database_admin_api()
api.create_database.return_value = op_future
instance = _Instance(self.INSTANCE_NAME, client=client)
pool = _Pool()
encryption_config = {"kms_key_name": "kms_key_name"}
database = self._make_one(
self.DATABASE_ID, instance, ddl_statements=DDL_STATEMENTS, pool=pool, encryption_config=encryption_config
)

future = database.create()

self.assertIs(future, op_future)

expected_encryption_config = EncryptionConfig(**encryption_config)
expected_request = CreateDatabaseRequest(
parent=self.INSTANCE_NAME,
create_statement="CREATE DATABASE {}".format(self.DATABASE_ID),
extra_statements=DDL_STATEMENTS,
encryption_config=expected_encryption_config,
)

api.create_database.assert_called_once_with(
Expand Down Expand Up @@ -1123,6 +1144,7 @@ def test_restore_backup_unspecified(self):

def test_restore_grpc_error(self):
from google.api_core.exceptions import Unknown
from google.cloud.spanner_admin_database_v1 import RestoreDatabaseRequest

client = _Client()
api = client.database_admin_api = self._make_database_admin_api()
Expand All @@ -1135,15 +1157,20 @@ def test_restore_grpc_error(self):
with self.assertRaises(Unknown):
database.restore(backup)

api.restore_database.assert_called_once_with(
expected_request = RestoreDatabaseRequest(
parent=self.INSTANCE_NAME,
database_id=self.DATABASE_ID,
backup=self.BACKUP_NAME,
)

api.restore_database.assert_called_once_with(
request=expected_request,
metadata=[("google-cloud-resource-prefix", database.name)],
)

def test_restore_not_found(self):
from google.api_core.exceptions import NotFound
from google.cloud.spanner_admin_database_v1 import RestoreDatabaseRequest

client = _Client()
api = client.database_admin_api = self._make_database_admin_api()
Expand All @@ -1156,31 +1183,84 @@ def test_restore_not_found(self):
with self.assertRaises(NotFound):
database.restore(backup)

api.restore_database.assert_called_once_with(
expected_request = RestoreDatabaseRequest(
parent=self.INSTANCE_NAME,
database_id=self.DATABASE_ID,
backup=self.BACKUP_NAME,
)

api.restore_database.assert_called_once_with(
request=expected_request,
metadata=[("google-cloud-resource-prefix", database.name)],
)

def test_restore_success(self):
from google.cloud.spanner_admin_database_v1 import RestoreDatabaseEncryptionConfig
from google.cloud.spanner_admin_database_v1 import RestoreDatabaseRequest

op_future = object()
client = _Client()
api = client.database_admin_api = self._make_database_admin_api()
api.restore_database.return_value = op_future
instance = _Instance(self.INSTANCE_NAME, client=client)
pool = _Pool()
database = self._make_one(self.DATABASE_ID, instance, pool=pool)
encryption_config = RestoreDatabaseEncryptionConfig(
encryption_type=RestoreDatabaseEncryptionConfig.EncryptionType.CUSTOMER_MANAGED_ENCRYPTION,
kms_key_name="kms_key_name"
)
database = self._make_one(self.DATABASE_ID, instance, pool=pool, encryption_config=encryption_config)
backup = _Backup(self.BACKUP_NAME)

future = database.restore(backup)

self.assertIs(future, op_future)

expected_request = RestoreDatabaseRequest(
parent=self.INSTANCE_NAME,
database_id=self.DATABASE_ID,
backup=self.BACKUP_NAME,
encryption_config=encryption_config
)

api.restore_database.assert_called_once_with(
request=expected_request,
metadata=[("google-cloud-resource-prefix", database.name)],
)

def test_restore_success_w_encryption_config_dict(self):
from google.cloud.spanner_admin_database_v1 import RestoreDatabaseEncryptionConfig
from google.cloud.spanner_admin_database_v1 import RestoreDatabaseRequest

op_future = object()
client = _Client()
api = client.database_admin_api = self._make_database_admin_api()
api.restore_database.return_value = op_future
instance = _Instance(self.INSTANCE_NAME, client=client)
pool = _Pool()
encryption_config = {
'encryption_type': RestoreDatabaseEncryptionConfig.EncryptionType.CUSTOMER_MANAGED_ENCRYPTION,
'kms_key_name': 'kms_key_name'
}
database = self._make_one(self.DATABASE_ID, instance, pool=pool, encryption_config=encryption_config)
backup = _Backup(self.BACKUP_NAME)

future = database.restore(backup)

self.assertIs(future, op_future)

expected_encryption_config = RestoreDatabaseEncryptionConfig(
encryption_type=RestoreDatabaseEncryptionConfig.EncryptionType.CUSTOMER_MANAGED_ENCRYPTION,
kms_key_name="kms_key_name"
)
expected_request = RestoreDatabaseRequest(
parent=self.INSTANCE_NAME,
database_id=self.DATABASE_ID,
backup=self.BACKUP_NAME,
encryption_config=expected_encryption_config
)

api.restore_database.assert_called_once_with(
request=expected_request,
metadata=[("google-cloud-resource-prefix", database.name)],
)

Expand Down
3 changes: 1 addition & 2 deletions tests/unit/test_instance.py
Expand Up @@ -488,15 +488,14 @@ def test_database_factory_defaults(self):
self.assertIs(pool._database, database)

def test_database_factory_explicit(self):
from google.cloud.spanner_admin_database_v1 import EncryptionConfig
from google.cloud.spanner_v1.database import Database
from tests._fixtures import DDL_STATEMENTS

client = _Client(self.PROJECT)
instance = self._make_one(self.INSTANCE_ID, client, self.CONFIG_NAME)
DATABASE_ID = "database-id"
pool = _Pool()
encryption_config = EncryptionConfig(kms_key_name="kms_key")
encryption_config = {"kms_key_name": "kms_key_name"}

database = instance.database(
DATABASE_ID,
Expand Down

0 comments on commit 514e675

Please sign in to comment.