From ca7f6d64ea75349a841b53fe6ef6547942439e35 Mon Sep 17 00:00:00 2001 From: Vinny Senthil Date: Wed, 31 Mar 2021 09:36:03 -0600 Subject: [PATCH] feat: Add Custom Container Prediction support, move to single API endpoint (#277) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary of Changes - Make `artifact_uri` optional in `Model.upload()` to allow for custom containers that contain model artifacts & server - Remove reference to specific prediction endpoint (`*-prediction-aiplatform.googleapis.com`) - [link to commit](https://github.com/googleapis/python-aiplatform/commit/bbd92185237673d28504e227e2b254d002b041a9) - Update relevant tests Fixes [b/184049758](http://b/184049758), [b/180036930](http://b/180036930) 🦕 --- google/cloud/aiplatform/initializer.py | 17 +++++-------- google/cloud/aiplatform/models.py | 13 ++++++---- tests/unit/aiplatform/test_initializer.py | 29 ++++++----------------- tests/unit/aiplatform/test_models.py | 6 ++--- 4 files changed, 23 insertions(+), 42 deletions(-) diff --git a/google/cloud/aiplatform/initializer.py b/google/cloud/aiplatform/initializer.py index 9baf9b6106..f544df2a7a 100644 --- a/google/cloud/aiplatform/initializer.py +++ b/google/cloud/aiplatform/initializer.py @@ -163,7 +163,7 @@ def encryption_spec_key_name(self) -> Optional[str]: return self._encryption_spec_key_name def get_client_options( - self, location_override: Optional[str] = None, prediction_client: bool = False, + self, location_override: Optional[str] = None ) -> client_options.ClientOptions: """Creates GAPIC client_options using location and type. @@ -173,15 +173,11 @@ def get_client_options( location set by initializer. Must be a GCP region supported by AI Platform (Unified). - prediction_client (bool): - True if service client is a PredictionServiceClient, otherwise defaults - to False. This is used to provide a prediction-specific API endpoint. - Returns: - clients_options (dict): - A dictionary containing client_options with one key, for example + clients_options (google.api_core.client_options.ClientOptions): + A ClientOptions object set with regionalized API endpoint, i.e. { "api_endpoint": "us-central1-aiplatform.googleapis.com" } or - { "api_endpoint": "asia-east1-prediction-aiplatform.googleapis.com" } + { "api_endpoint": "asia-east1-aiplatform.googleapis.com" } """ if not (self.location or location_override): raise ValueError( @@ -190,12 +186,11 @@ def get_client_options( region = location_override or self.location region = region.lower() - prediction = "prediction-" if prediction_client else "" utils.validate_region(region) return client_options.ClientOptions( - api_endpoint=f"{region}-{prediction}{constants.API_BASE_PATH}" + api_endpoint=f"{region}-{constants.API_BASE_PATH}" ) def common_location_path( @@ -250,7 +245,7 @@ def create_client( kwargs = { "credentials": credentials or self.credentials, "client_options": self.get_client_options( - location_override=location_override, prediction_client=prediction_client + location_override=location_override ), "client_info": client_info, } diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py index 1c0de15778..1ce671ba9f 100644 --- a/google/cloud/aiplatform/models.py +++ b/google/cloud/aiplatform/models.py @@ -1130,9 +1130,9 @@ def __init__( def upload( cls, display_name: str, - artifact_uri: str, serving_container_image_uri: str, *, + artifact_uri: Optional[str] = None, serving_container_predict_route: Optional[str] = None, serving_container_health_route: Optional[str] = None, description: Optional[str] = None, @@ -1167,11 +1167,12 @@ def upload( display_name (str): Required. The display name of the Model. The name can be up to 128 characters long and can be consist of any UTF-8 characters. - artifact_uri (str): - Required. The path to the directory containing the Model artifact and - any of its supporting files. Not present for AutoML Models. serving_container_image_uri (str): Required. The URI of the Model serving container. + artifact_uri (str): + Optional. The path to the directory containing the Model artifact and + any of its supporting files. Leave blank for custom container prediction. + Not present for AutoML Models. serving_container_predict_route (str): Optional. An HTTP path to send prediction requests to the container, and which must be supported by it. If not specified a default HTTP path will @@ -1335,12 +1336,14 @@ def upload( managed_model = gca_model.Model( display_name=display_name, description=description, - artifact_uri=artifact_uri, container_spec=container_spec, predict_schemata=model_predict_schemata, encryption_spec=encryption_spec, ) + if artifact_uri: + managed_model.artifact_uri = artifact_uri + # Override explanation_spec if both required fields are provided if explanation_metadata and explanation_parameters: explanation_spec = gca_endpoint.explanation.ExplanationSpec() diff --git a/tests/unit/aiplatform/test_initializer.py b/tests/unit/aiplatform/test_initializer.py index 905b2d7823..041a498e38 100644 --- a/tests/unit/aiplatform/test_initializer.py +++ b/tests/unit/aiplatform/test_initializer.py @@ -117,7 +117,7 @@ def test_create_client_overrides(self): assert isinstance(client, model_service_client.ModelServiceClient) assert ( client._transport._host - == f"{_TEST_LOCATION_2}-prediction-{constants.API_BASE_PATH}:443" + == f"{_TEST_LOCATION_2}-{constants.API_BASE_PATH}:443" ) assert client._transport._credentials == creds @@ -134,36 +134,21 @@ def test_create_client_user_agent(self): assert user_agent.startswith("model-builder/") @pytest.mark.parametrize( - "init_location, location_override, prediction, expected_endpoint", + "init_location, location_override, expected_endpoint", [ - ("us-central1", None, False, "us-central1-aiplatform.googleapis.com"), - ( - "us-central1", - "europe-west4", - False, - "europe-west4-aiplatform.googleapis.com", - ), - ("asia-east1", None, False, "asia-east1-aiplatform.googleapis.com"), - ( - "asia-east1", - None, - True, - "asia-east1-prediction-aiplatform.googleapis.com", - ), + ("us-central1", None, "us-central1-aiplatform.googleapis.com"), + ("us-central1", "europe-west4", "europe-west4-aiplatform.googleapis.com",), + ("asia-east1", None, "asia-east1-aiplatform.googleapis.com"), ], ) def test_get_client_options( - self, - init_location: str, - location_override: str, - prediction: bool, - expected_endpoint: str, + self, init_location: str, location_override: str, expected_endpoint: str, ): initializer.global_config.init(location=init_location) assert ( initializer.global_config.get_client_options( - location_override=location_override, prediction_client=prediction + location_override=location_override ).api_endpoint == expected_endpoint ) diff --git a/tests/unit/aiplatform/test_models.py b/tests/unit/aiplatform/test_models.py index 5f2f54b9d6..8d32bbe2c6 100644 --- a/tests/unit/aiplatform/test_models.py +++ b/tests/unit/aiplatform/test_models.py @@ -315,9 +315,9 @@ def test_upload_uploads_and_gets_model(self, sync): api_client_mock.upload_model.return_value = mock_lro create_client_mock.return_value = api_client_mock + # Custom Container workflow, does not pass `artifact_uri` my_model = models.Model.upload( display_name=_TEST_MODEL_NAME, - artifact_uri=_TEST_ARTIFACT_URI, serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, @@ -334,9 +334,7 @@ def test_upload_uploads_and_gets_model(self, sync): ) managed_model = gca_model.Model( - display_name=_TEST_MODEL_NAME, - artifact_uri=_TEST_ARTIFACT_URI, - container_spec=container_spec, + display_name=_TEST_MODEL_NAME, container_spec=container_spec, ) api_client_mock.upload_model.assert_called_once_with(