Skip to content

Commit

Permalink
feat: Add Custom Container Prediction support, move to single API end…
Browse files Browse the repository at this point in the history
…point (#277)

## Summary of Changes
- Make `artifact_uri` optional in `Model.upload()` to allow for custom containers that contain model artifacts & server
- Remove reference to specific prediction endpoint (`*-prediction-aiplatform.googleapis.com`) - [link to commit](bbd9218)
- Update relevant tests

Fixes [b/184049758](http://b/184049758), [b/180036930](http://b/180036930) 🦕
  • Loading branch information
vinnysenthil committed Mar 31, 2021
1 parent 1230dc6 commit ca7f6d6
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 42 deletions.
17 changes: 6 additions & 11 deletions google/cloud/aiplatform/initializer.py
Expand Up @@ -163,7 +163,7 @@ def encryption_spec_key_name(self) -> Optional[str]:
return self._encryption_spec_key_name

def get_client_options(
self, location_override: Optional[str] = None, prediction_client: bool = False,
self, location_override: Optional[str] = None
) -> client_options.ClientOptions:
"""Creates GAPIC client_options using location and type.
Expand All @@ -173,15 +173,11 @@ def get_client_options(
location set by initializer. Must be a GCP region supported by AI
Platform (Unified).
prediction_client (bool):
True if service client is a PredictionServiceClient, otherwise defaults
to False. This is used to provide a prediction-specific API endpoint.
Returns:
clients_options (dict):
A dictionary containing client_options with one key, for example
clients_options (google.api_core.client_options.ClientOptions):
A ClientOptions object set with regionalized API endpoint, i.e.
{ "api_endpoint": "us-central1-aiplatform.googleapis.com" } or
{ "api_endpoint": "asia-east1-prediction-aiplatform.googleapis.com" }
{ "api_endpoint": "asia-east1-aiplatform.googleapis.com" }
"""
if not (self.location or location_override):
raise ValueError(
Expand All @@ -190,12 +186,11 @@ def get_client_options(

region = location_override or self.location
region = region.lower()
prediction = "prediction-" if prediction_client else ""

utils.validate_region(region)

return client_options.ClientOptions(
api_endpoint=f"{region}-{prediction}{constants.API_BASE_PATH}"
api_endpoint=f"{region}-{constants.API_BASE_PATH}"
)

def common_location_path(
Expand Down Expand Up @@ -250,7 +245,7 @@ def create_client(
kwargs = {
"credentials": credentials or self.credentials,
"client_options": self.get_client_options(
location_override=location_override, prediction_client=prediction_client
location_override=location_override
),
"client_info": client_info,
}
Expand Down
13 changes: 8 additions & 5 deletions google/cloud/aiplatform/models.py
Expand Up @@ -1130,9 +1130,9 @@ def __init__(
def upload(
cls,
display_name: str,
artifact_uri: str,
serving_container_image_uri: str,
*,
artifact_uri: Optional[str] = None,
serving_container_predict_route: Optional[str] = None,
serving_container_health_route: Optional[str] = None,
description: Optional[str] = None,
Expand Down Expand Up @@ -1167,11 +1167,12 @@ def upload(
display_name (str):
Required. The display name of the Model. The name can be up to 128
characters long and can be consist of any UTF-8 characters.
artifact_uri (str):
Required. The path to the directory containing the Model artifact and
any of its supporting files. Not present for AutoML Models.
serving_container_image_uri (str):
Required. The URI of the Model serving container.
artifact_uri (str):
Optional. The path to the directory containing the Model artifact and
any of its supporting files. Leave blank for custom container prediction.
Not present for AutoML Models.
serving_container_predict_route (str):
Optional. An HTTP path to send prediction requests to the container, and
which must be supported by it. If not specified a default HTTP path will
Expand Down Expand Up @@ -1335,12 +1336,14 @@ def upload(
managed_model = gca_model.Model(
display_name=display_name,
description=description,
artifact_uri=artifact_uri,
container_spec=container_spec,
predict_schemata=model_predict_schemata,
encryption_spec=encryption_spec,
)

if artifact_uri:
managed_model.artifact_uri = artifact_uri

# Override explanation_spec if both required fields are provided
if explanation_metadata and explanation_parameters:
explanation_spec = gca_endpoint.explanation.ExplanationSpec()
Expand Down
29 changes: 7 additions & 22 deletions tests/unit/aiplatform/test_initializer.py
Expand Up @@ -117,7 +117,7 @@ def test_create_client_overrides(self):
assert isinstance(client, model_service_client.ModelServiceClient)
assert (
client._transport._host
== f"{_TEST_LOCATION_2}-prediction-{constants.API_BASE_PATH}:443"
== f"{_TEST_LOCATION_2}-{constants.API_BASE_PATH}:443"
)
assert client._transport._credentials == creds

Expand All @@ -134,36 +134,21 @@ def test_create_client_user_agent(self):
assert user_agent.startswith("model-builder/")

@pytest.mark.parametrize(
"init_location, location_override, prediction, expected_endpoint",
"init_location, location_override, expected_endpoint",
[
("us-central1", None, False, "us-central1-aiplatform.googleapis.com"),
(
"us-central1",
"europe-west4",
False,
"europe-west4-aiplatform.googleapis.com",
),
("asia-east1", None, False, "asia-east1-aiplatform.googleapis.com"),
(
"asia-east1",
None,
True,
"asia-east1-prediction-aiplatform.googleapis.com",
),
("us-central1", None, "us-central1-aiplatform.googleapis.com"),
("us-central1", "europe-west4", "europe-west4-aiplatform.googleapis.com",),
("asia-east1", None, "asia-east1-aiplatform.googleapis.com"),
],
)
def test_get_client_options(
self,
init_location: str,
location_override: str,
prediction: bool,
expected_endpoint: str,
self, init_location: str, location_override: str, expected_endpoint: str,
):
initializer.global_config.init(location=init_location)

assert (
initializer.global_config.get_client_options(
location_override=location_override, prediction_client=prediction
location_override=location_override
).api_endpoint
== expected_endpoint
)
Expand Down
6 changes: 2 additions & 4 deletions tests/unit/aiplatform/test_models.py
Expand Up @@ -315,9 +315,9 @@ def test_upload_uploads_and_gets_model(self, sync):
api_client_mock.upload_model.return_value = mock_lro
create_client_mock.return_value = api_client_mock

# Custom Container workflow, does not pass `artifact_uri`
my_model = models.Model.upload(
display_name=_TEST_MODEL_NAME,
artifact_uri=_TEST_ARTIFACT_URI,
serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE,
serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE,
serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE,
Expand All @@ -334,9 +334,7 @@ def test_upload_uploads_and_gets_model(self, sync):
)

managed_model = gca_model.Model(
display_name=_TEST_MODEL_NAME,
artifact_uri=_TEST_ARTIFACT_URI,
container_spec=container_spec,
display_name=_TEST_MODEL_NAME, container_spec=container_spec,
)

api_client_mock.upload_model.assert_called_once_with(
Expand Down

0 comments on commit ca7f6d6

Please sign in to comment.