diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py index dd9d1e090d..2ce48adc53 100644 --- a/google/cloud/aiplatform/models.py +++ b/google/cloud/aiplatform/models.py @@ -110,12 +110,43 @@ def __init__( credentials=credentials, resource_name=endpoint_name, ) - self._gca_resource = self._get_gca_resource(resource_name=endpoint_name) + + endpoint_name = utils.full_resource_name( + resource_name=endpoint_name, + resource_noun="endpoints", + project=project, + location=location, + ) + + # Lazy load the Endpoint gca_resource until needed + self._gca_resource = gca_endpoint_compat.Endpoint(name=endpoint_name) self._prediction_client = self._instantiate_prediction_client( location=self.location, credentials=credentials, ) + def _skipped_getter_call(self) -> bool: + """Check if GAPIC resource was populated by call to get/list API methods + + Returns False if `_gca_resource` is None or fully populated. Returns True + if `_gca_resource` is partially populated + """ + return self._gca_resource and not self._gca_resource.create_time + + def _sync_gca_resource_if_skipped(self) -> None: + """Sync GAPIC service representation of Endpoint class resource only if + get_endpoint() was never called.""" + if self._skipped_getter_call(): + self._gca_resource = self._get_gca_resource( + resource_name=self._gca_resource.name + ) + + def _assert_gca_resource_is_available(self) -> None: + """Ensures Endpoint getter was called at least once before + asserting on gca_resource's availability.""" + super()._assert_gca_resource_is_available() + self._sync_gca_resource_if_skipped() + @property def traffic_split(self) -> Dict[str, int]: """A map from a DeployedModel's ID to the percentage of this Endpoint's @@ -315,8 +346,8 @@ def _create( _LOGGER.log_create_complete(cls, created_endpoint, "endpoint") - return cls( - endpoint_name=created_endpoint.name, + return cls._construct_sdk_resource_from_gapic( + gapic_resource=created_endpoint, project=project, location=location, credentials=credentials, @@ -622,6 +653,7 @@ def deploy( will be executed in concurrent Future and any downstream object will be immediately returned and synced when the Future has completed. """ + self._sync_gca_resource_if_skipped() self._validate_deploy_args( min_replica_count, @@ -967,6 +999,8 @@ def undeploy( Optional. Strings which should be sent along with the request as metadata. """ + self._sync_gca_resource_if_skipped() + if traffic_split is not None: if deployed_model_id in traffic_split and traffic_split[deployed_model_id]: raise ValueError("Model being undeployed should have 0 traffic.") @@ -1011,6 +1045,7 @@ def _undeploy( Optional. Strings which should be sent along with the request as metadata. """ + self._sync_gca_resource_if_skipped() current_traffic_split = traffic_split or dict(self._gca_resource.traffic_split) if deployed_model_id in current_traffic_split: @@ -1095,7 +1130,7 @@ def predict(self, instances: List, parameters: Optional[Dict] = None) -> Predict self.wait() prediction_response = self._prediction_client.predict( - endpoint=self.resource_name, instances=instances, parameters=parameters + endpoint=self._gca_resource.name, instances=instances, parameters=parameters ) return Prediction( diff --git a/tests/system/aiplatform/test_e2e_tabular.py b/tests/system/aiplatform/test_e2e_tabular.py index a55ea237e4..651c737555 100644 --- a/tests/system/aiplatform/test_e2e_tabular.py +++ b/tests/system/aiplatform/test_e2e_tabular.py @@ -34,6 +34,16 @@ _LOCAL_TRAINING_SCRIPT_PATH = os.path.join( _DIR_NAME, "test_resources/california_housing_training_script.py" ) +_INSTANCE = { + "longitude": -124.35, + "latitude": 40.54, + "housing_median_age": 52.0, + "total_rooms": 1820.0, + "total_bedrooms": 300.0, + "population": 806, + "households": 270.0, + "median_income": 3.014700, +} @pytest.mark.usefixtures("prepare_staging_bucket", "delete_staging_bucket", "teardown") @@ -136,39 +146,20 @@ def test_end_to_end_tabular(self, shared_state): # Send online prediction with same instance to both deployed models # This sample is taken from an observation where median_house_value = 94600 custom_endpoint.wait() - custom_prediction = custom_endpoint.predict( - [ - { - "longitude": -124.35, - "latitude": 40.54, - "housing_median_age": 52.0, - "total_rooms": 1820.0, - "total_bedrooms": 300.0, - "population": 806, - "households": 270.0, - "median_income": 3.014700, - }, - ] - ) + custom_prediction = custom_endpoint.predict([_INSTANCE]) custom_batch_prediction_job.wait() automl_endpoint.wait() automl_prediction = automl_endpoint.predict( - [ - { - "longitude": "-124.35", - "latitude": "40.54", - "housing_median_age": "52.0", - "total_rooms": "1820.0", - "total_bedrooms": "300.0", - "population": "806", - "households": "270.0", - "median_income": "3.014700", - }, - ] + [{k: str(v) for k, v in _INSTANCE.items()}] # Cast int values to strings ) + # Test lazy loading of Endpoint, check getter was never called after predict() + custom_endpoint = aiplatform.Endpoint(custom_endpoint.resource_name) + custom_endpoint.predict([_INSTANCE]) + assert custom_endpoint._skipped_getter_call() + assert ( custom_job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED diff --git a/tests/unit/aiplatform/test_end_to_end.py b/tests/unit/aiplatform/test_end_to_end.py index d9e0788f39..10ba0c3b0c 100644 --- a/tests/unit/aiplatform/test_end_to_end.py +++ b/tests/unit/aiplatform/test_end_to_end.py @@ -151,6 +151,12 @@ def test_dataset_create_to_model_predict( assert endpoint_deploy_return is None if not sync: + # Accessing attribute in Endpoint that has not been created raises informatively + with pytest.raises( + RuntimeError, match=r"Endpoint resource has not been created." + ): + my_endpoint.network + my_endpoint.wait() created_endpoint.wait() diff --git a/tests/unit/aiplatform/test_endpoints.py b/tests/unit/aiplatform/test_endpoints.py index a6e8488af8..00fe5093cf 100644 --- a/tests/unit/aiplatform/test_endpoints.py +++ b/tests/unit/aiplatform/test_endpoints.py @@ -169,6 +169,15 @@ def get_endpoint_mock(): yield get_endpoint_mock +@pytest.fixture +def get_empty_endpoint_mock(): + with mock.patch.object( + endpoint_service_client.EndpointServiceClient, "get_endpoint" + ) as get_endpoint_mock: + get_endpoint_mock.return_value = gca_endpoint.Endpoint(name=_TEST_ENDPOINT_NAME) + yield get_endpoint_mock + + @pytest.fixture def get_endpoint_alt_location_mock(): with mock.patch.object( @@ -213,7 +222,9 @@ def create_endpoint_mock(): ) as create_endpoint_mock: create_endpoint_lro_mock = mock.Mock(ga_operation.Operation) create_endpoint_lro_mock.result.return_value = gca_endpoint.Endpoint( - name=_TEST_ENDPOINT_NAME, display_name=_TEST_DISPLAY_NAME + name=_TEST_ENDPOINT_NAME, + display_name=_TEST_DISPLAY_NAME, + encryption_spec=_TEST_ENCRYPTION_SPEC, ) create_endpoint_mock.return_value = create_endpoint_lro_mock yield create_endpoint_mock @@ -378,19 +389,35 @@ def test_constructor(self, create_endpoint_client_mock): ] ) - def test_constructor_with_endpoint_id(self, get_endpoint_mock): - models.Endpoint(_TEST_ID) - get_endpoint_mock.assert_called_with(name=_TEST_ENDPOINT_NAME) + def test_lazy_constructor_with_endpoint_id(self, get_endpoint_mock): + ep = models.Endpoint(_TEST_ID) + assert ep._gca_resource.name == _TEST_ENDPOINT_NAME + assert ep._skipped_getter_call() + assert not get_endpoint_mock.called - def test_constructor_with_endpoint_name(self, get_endpoint_mock): - models.Endpoint(_TEST_ENDPOINT_NAME) + def test_lazy_constructor_with_endpoint_name(self, get_endpoint_mock): + ep = models.Endpoint(_TEST_ENDPOINT_NAME) + assert ep._gca_resource.name == _TEST_ENDPOINT_NAME + assert ep._skipped_getter_call() + assert not get_endpoint_mock.called + + def test_lazy_constructor_calls_get_on_property_access(self, get_endpoint_mock): + ep = models.Endpoint(_TEST_ENDPOINT_NAME) + assert ep._gca_resource.name == _TEST_ENDPOINT_NAME + assert ep._skipped_getter_call() + assert not get_endpoint_mock.called + + ep.display_name # Retrieve a property that requires a call to Endpoint getter get_endpoint_mock.assert_called_with(name=_TEST_ENDPOINT_NAME) - def test_constructor_with_custom_project(self, get_endpoint_mock): - models.Endpoint(endpoint_name=_TEST_ID, project=_TEST_PROJECT_2) + def test_lazy_constructor_with_custom_project(self, get_endpoint_mock): + ep = models.Endpoint(endpoint_name=_TEST_ID, project=_TEST_PROJECT_2) test_endpoint_resource_name = endpoint_service_client.EndpointServiceClient.endpoint_path( _TEST_PROJECT_2, _TEST_LOCATION, _TEST_ID ) + assert not get_endpoint_mock.called + + ep.name # Retrieve a property that requires a call to Endpoint getter get_endpoint_mock.assert_called_with(name=test_endpoint_resource_name) @pytest.mark.usefixtures("get_endpoint_mock") @@ -406,11 +433,19 @@ def test_constructor_with_conflicting_location(self): regexp=r"is provided, but different from the resource location" ) - def test_constructor_with_custom_location(self, get_endpoint_alt_location_mock): - models.Endpoint(endpoint_name=_TEST_ID, location=_TEST_LOCATION_2) + def test_lazy_constructor_with_custom_location( + self, get_endpoint_alt_location_mock + ): + ep = models.Endpoint(endpoint_name=_TEST_ID, location=_TEST_LOCATION_2) test_endpoint_resource_name = endpoint_service_client.EndpointServiceClient.endpoint_path( _TEST_PROJECT, _TEST_LOCATION_2, _TEST_ID ) + + # Get Endpoint not called due to lazy loading + assert not get_endpoint_alt_location_mock.called + + ep.network # Accessing a property that requires calling getter + get_endpoint_alt_location_mock.assert_called_with( name=test_endpoint_resource_name ) @@ -481,15 +516,17 @@ def test_create(self, create_endpoint_mock, sync): ) expected_endpoint.name = _TEST_ENDPOINT_NAME - assert my_endpoint.gca_resource == expected_endpoint - assert my_endpoint.network is None + assert my_endpoint._gca_resource == expected_endpoint - @pytest.mark.usefixtures("get_endpoint_mock") + @pytest.mark.usefixtures("get_empty_endpoint_mock") def test_accessing_properties_with_no_resource_raises(self,): + """Ensure a descriptive RuntimeError is raised when the + GAPIC object has not been populated""" my_endpoint = aiplatform.Endpoint(_TEST_ENDPOINT_NAME) - my_endpoint._gca_resource = None + # Create a gca_resource without `name` being populated + my_endpoint._gca_resource = gca_endpoint.Endpoint(create_time=datetime.now()) with pytest.raises(RuntimeError) as e: my_endpoint.gca_resource @@ -909,7 +946,7 @@ def test_undeploy(self, undeploy_model_mock, sync): traffic_split={"model1": 100}, ) test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) - assert dict(test_endpoint._gca_resource.traffic_split) == {"model1": 100} + assert dict(test_endpoint.traffic_split) == {"model1": 100} test_endpoint.undeploy("model1", sync=sync) if not sync: test_endpoint.wait()