diff --git a/google/cloud/spanner_v1/client.py b/google/cloud/spanner_v1/client.py index d5ccf39546..4d5fc1b69a 100644 --- a/google/cloud/spanner_v1/client.py +++ b/google/cloud/spanner_v1/client.py @@ -49,7 +49,6 @@ from google.cloud.client import ClientWithProject from google.cloud.spanner_v1 import __version__ from google.cloud.spanner_v1._helpers import _merge_query_options, _metadata_with_prefix -from google.cloud.spanner_v1.instance import DEFAULT_NODE_COUNT from google.cloud.spanner_v1.instance import Instance from google.cloud.spanner_v1 import ExecuteSqlRequest from google.cloud.spanner_admin_instance_v1 import ListInstanceConfigsRequest @@ -294,8 +293,9 @@ def instance( instance_id, configuration_name=None, display_name=None, - node_count=DEFAULT_NODE_COUNT, + node_count=None, labels=None, + processing_units=None, ): """Factory to create a instance associated with this client. @@ -320,6 +320,10 @@ def instance( :param node_count: (Optional) The number of nodes in the instance's cluster; used to set up the instance's cluster. + :type processing_units: int + :param processing_units: (Optional) The number of processing units + allocated to this instance. + :type labels: dict (str -> str) or None :param labels: (Optional) User-assigned labels for this instance. @@ -334,6 +338,7 @@ def instance( display_name, self._emulator_host, labels, + processing_units, ) def list_instances(self, filter_="", page_size=None): diff --git a/google/cloud/spanner_v1/instance.py b/google/cloud/spanner_v1/instance.py index 5a9cf95f5a..7f5539acf8 100644 --- a/google/cloud/spanner_v1/instance.py +++ b/google/cloud/spanner_v1/instance.py @@ -15,6 +15,7 @@ """User friendly container for Cloud Spanner Instance.""" import google.api_core.operation +from google.api_core.exceptions import InvalidArgument import re from google.cloud.spanner_admin_instance_v1 import Instance as InstancePB @@ -41,6 +42,7 @@ ) DEFAULT_NODE_COUNT = 1 +PROCESSING_UNITS_PER_NODE = 1000 _OPERATION_METADATA_MESSAGES = ( backup.Backup, @@ -95,6 +97,10 @@ class Instance(object): :type node_count: int :param node_count: (Optional) Number of nodes allocated to the instance. + :type processing_units: int + :param processing_units: (Optional) The number of processing units + allocated to this instance. + :type display_name: str :param display_name: (Optional) The display name for the instance in the Cloud Console UI. (Must be between 4 and 30 @@ -110,15 +116,29 @@ def __init__( instance_id, client, configuration_name=None, - node_count=DEFAULT_NODE_COUNT, + node_count=None, display_name=None, emulator_host=None, labels=None, + processing_units=None, ): self.instance_id = instance_id self._client = client self.configuration_name = configuration_name - self.node_count = node_count + if node_count is not None and processing_units is not None: + if processing_units != node_count * PROCESSING_UNITS_PER_NODE: + raise InvalidArgument( + "Only one of node count and processing units can be set." + ) + if node_count is None and processing_units is None: + self._node_count = DEFAULT_NODE_COUNT + self._processing_units = DEFAULT_NODE_COUNT * PROCESSING_UNITS_PER_NODE + elif node_count is not None: + self._node_count = node_count + self._processing_units = node_count * PROCESSING_UNITS_PER_NODE + else: + self._processing_units = processing_units + self._node_count = processing_units // PROCESSING_UNITS_PER_NODE self.display_name = display_name or instance_id self.emulator_host = emulator_host if labels is None: @@ -134,7 +154,8 @@ def _update_from_pb(self, instance_pb): raise ValueError("Instance protobuf does not contain display_name") self.display_name = instance_pb.display_name self.configuration_name = instance_pb.config - self.node_count = instance_pb.node_count + self._node_count = instance_pb.node_count + self._processing_units = instance_pb.processing_units self.labels = instance_pb.labels @classmethod @@ -190,6 +211,44 @@ def name(self): """ return self._client.project_name + "/instances/" + self.instance_id + @property + def processing_units(self): + """Processing units used in requests. + + :rtype: int + :returns: The number of processing units allocated to this instance. + """ + return self._processing_units + + @processing_units.setter + def processing_units(self, value): + """Sets the processing units for requests. Affects node_count. + + :param value: The number of processing units allocated to this instance. + """ + self._processing_units = value + self._node_count = value // PROCESSING_UNITS_PER_NODE + + @property + def node_count(self): + """Node count used in requests. + + :rtype: int + :returns: + The number of nodes in the instance's cluster; + used to set up the instance's cluster. + """ + return self._node_count + + @node_count.setter + def node_count(self, value): + """Sets the node count for requests. Affects processing_units. + + :param value: The number of nodes in the instance's cluster. + """ + self._node_count = value + self._processing_units = value * PROCESSING_UNITS_PER_NODE + def __eq__(self, other): if not isinstance(other, self.__class__): return NotImplemented @@ -218,7 +277,8 @@ def copy(self): self.instance_id, new_client, self.configuration_name, - node_count=self.node_count, + node_count=self._node_count, + processing_units=self._processing_units, display_name=self.display_name, ) @@ -250,7 +310,7 @@ def create(self): name=self.name, config=self.configuration_name, display_name=self.display_name, - node_count=self.node_count, + processing_units=self._processing_units, labels=self.labels, ) metadata = _metadata_with_prefix(self.name) @@ -306,8 +366,8 @@ def update(self): .. note:: - Updates the ``display_name``, ``node_count`` and ``labels``. To change those - values before updating, set them via + Updates the ``display_name``, ``node_count``, ``processing_units`` + and ``labels``. To change those values before updating, set them via .. code:: python @@ -325,10 +385,15 @@ def update(self): name=self.name, config=self.configuration_name, display_name=self.display_name, - node_count=self.node_count, + node_count=self._node_count, + processing_units=self._processing_units, labels=self.labels, ) - field_mask = FieldMask(paths=["config", "display_name", "node_count", "labels"]) + + # Always update only processing_units, not nodes + field_mask = FieldMask( + paths=["config", "display_name", "processing_units", "labels"] + ) metadata = _metadata_with_prefix(self.name) future = api.update_instance( diff --git a/tests/system/test_system.py b/tests/system/test_system.py index 8471cfc4c2..ad2b8a9178 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -229,6 +229,35 @@ def test_create_instance(self): self.assertEqual(instance, instance_alt) self.assertEqual(instance.display_name, instance_alt.display_name) + @unittest.skipIf(USE_EMULATOR, "Skipping LCI tests") + @unittest.skipUnless(CREATE_INSTANCE, "Skipping instance creation") + def test_create_instance_with_processing_nodes(self): + ALT_INSTANCE_ID = "new" + unique_resource_id("-") + PROCESSING_UNITS = 5000 + instance = Config.CLIENT.instance( + instance_id=ALT_INSTANCE_ID, + configuration_name=Config.INSTANCE_CONFIG.name, + processing_units=PROCESSING_UNITS, + ) + operation = instance.create() + # Make sure this instance gets deleted after the test case. + self.instances_to_delete.append(instance) + + # We want to make sure the operation completes. + operation.result( + SPANNER_OPERATION_TIMEOUT_IN_SECONDS + ) # raises on failure / timeout. + + # Create a new instance instance and make sure it is the same. + instance_alt = Config.CLIENT.instance( + ALT_INSTANCE_ID, Config.INSTANCE_CONFIG.name + ) + instance_alt.reload() + + self.assertEqual(instance, instance_alt) + self.assertEqual(instance.display_name, instance_alt.display_name) + self.assertEqual(instance.processing_units, instance_alt.processing_units) + @unittest.skipIf(USE_EMULATOR, "Skipping updating instance") def test_update_instance(self): OLD_DISPLAY_NAME = Config.INSTANCE.display_name diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index d33d9cc08a..2777fbc9a0 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -37,6 +37,7 @@ class TestClient(unittest.TestCase): INSTANCE_NAME = "%s/instances/%s" % (PATH, INSTANCE_ID) DISPLAY_NAME = "display-name" NODE_COUNT = 5 + PROCESSING_UNITS = 5000 LABELS = {"test": "true"} TIMEOUT_SECONDS = 80 @@ -580,6 +581,7 @@ def test_list_instances(self): config=self.CONFIGURATION_NAME, display_name=self.DISPLAY_NAME, node_count=self.NODE_COUNT, + processing_units=self.PROCESSING_UNITS, ) ] ) @@ -597,6 +599,7 @@ def test_list_instances(self): self.assertEqual(instance.config, self.CONFIGURATION_NAME) self.assertEqual(instance.display_name, self.DISPLAY_NAME) self.assertEqual(instance.node_count, self.NODE_COUNT) + self.assertEqual(instance.processing_units, self.PROCESSING_UNITS) expected_metadata = ( ("google-cloud-resource-prefix", client.project_name), diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index 2ed777b25b..c715fb2ee1 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -27,6 +27,7 @@ class TestInstance(unittest.TestCase): LOCATION = "projects/" + PROJECT + "/locations/" + CONFIG_NAME DISPLAY_NAME = "display_name" NODE_COUNT = 5 + PROCESSING_UNITS = 5000 OP_ID = 8915 OP_NAME = "operations/projects/%s/instances/%soperations/%d" % ( PROJECT, @@ -39,6 +40,7 @@ class TestInstance(unittest.TestCase): DATABASE_ID = "database_id" DATABASE_NAME = "%s/databases/%s" % (INSTANCE_NAME, DATABASE_ID) LABELS = {"test": "true"} + FIELD_MASK = ["config", "display_name", "processing_units", "labels"] def _getTargetClass(self): from google.cloud.spanner_v1.instance import Instance @@ -230,7 +232,7 @@ def test_create_already_exists(self): self.assertEqual(instance.name, self.INSTANCE_NAME) self.assertEqual(instance.config, self.CONFIG_NAME) self.assertEqual(instance.display_name, self.INSTANCE_ID) - self.assertEqual(instance.node_count, 1) + self.assertEqual(instance.processing_units, 1000) self.assertEqual(metadata, [("google-cloud-resource-prefix", instance.name)]) def test_create_success(self): @@ -258,7 +260,36 @@ def test_create_success(self): self.assertEqual(instance.name, self.INSTANCE_NAME) self.assertEqual(instance.config, self.CONFIG_NAME) self.assertEqual(instance.display_name, self.DISPLAY_NAME) - self.assertEqual(instance.node_count, self.NODE_COUNT) + self.assertEqual(instance.processing_units, self.PROCESSING_UNITS) + self.assertEqual(instance.labels, self.LABELS) + self.assertEqual(metadata, [("google-cloud-resource-prefix", instance.name)]) + + def test_create_with_processing_units(self): + op_future = _FauxOperationFuture() + client = _Client(self.PROJECT) + api = client.instance_admin_api = _FauxInstanceAdminAPI( + _create_instance_response=op_future + ) + instance = self._make_one( + self.INSTANCE_ID, + client, + configuration_name=self.CONFIG_NAME, + display_name=self.DISPLAY_NAME, + processing_units=self.PROCESSING_UNITS, + labels=self.LABELS, + ) + + future = instance.create() + + self.assertIs(future, op_future) + + (parent, instance_id, instance, metadata) = api._created_instance + self.assertEqual(parent, self.PARENT) + self.assertEqual(instance_id, self.INSTANCE_ID) + self.assertEqual(instance.name, self.INSTANCE_NAME) + self.assertEqual(instance.config, self.CONFIG_NAME) + self.assertEqual(instance.display_name, self.DISPLAY_NAME) + self.assertEqual(instance.processing_units, self.PROCESSING_UNITS) self.assertEqual(instance.labels, self.LABELS) self.assertEqual(metadata, [("google-cloud-resource-prefix", instance.name)]) @@ -389,9 +420,7 @@ def test_update_not_found(self): instance.update() instance, field_mask, metadata = api._updated_instance - self.assertEqual( - field_mask.paths, ["config", "display_name", "node_count", "labels"] - ) + self.assertEqual(field_mask.paths, self.FIELD_MASK) self.assertEqual(instance.name, self.INSTANCE_NAME) self.assertEqual(instance.config, self.CONFIG_NAME) self.assertEqual(instance.display_name, self.INSTANCE_ID) @@ -417,14 +446,42 @@ def test_update_success(self): self.assertIs(future, op_future) + instance, field_mask, metadata = api._updated_instance + self.assertEqual(field_mask.paths, self.FIELD_MASK) + self.assertEqual(instance.name, self.INSTANCE_NAME) + self.assertEqual(instance.config, self.CONFIG_NAME) + self.assertEqual(instance.display_name, self.DISPLAY_NAME) + self.assertEqual(instance.node_count, self.NODE_COUNT) + self.assertEqual(instance.labels, self.LABELS) + self.assertEqual(metadata, [("google-cloud-resource-prefix", instance.name)]) + + def test_update_success_with_processing_units(self): + op_future = _FauxOperationFuture() + client = _Client(self.PROJECT) + api = client.instance_admin_api = _FauxInstanceAdminAPI( + _update_instance_response=op_future + ) + instance = self._make_one( + self.INSTANCE_ID, + client, + configuration_name=self.CONFIG_NAME, + processing_units=self.PROCESSING_UNITS, + display_name=self.DISPLAY_NAME, + labels=self.LABELS, + ) + + future = instance.update() + + self.assertIs(future, op_future) + instance, field_mask, metadata = api._updated_instance self.assertEqual( - field_mask.paths, ["config", "display_name", "node_count", "labels"] + field_mask.paths, ["config", "display_name", "processing_units", "labels"] ) self.assertEqual(instance.name, self.INSTANCE_NAME) self.assertEqual(instance.config, self.CONFIG_NAME) self.assertEqual(instance.display_name, self.DISPLAY_NAME) - self.assertEqual(instance.node_count, self.NODE_COUNT) + self.assertEqual(instance.processing_units, self.PROCESSING_UNITS) self.assertEqual(instance.labels, self.LABELS) self.assertEqual(metadata, [("google-cloud-resource-prefix", instance.name)])