Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add Custom Container Prediction support, move to single API endpoint #277

Merged
merged 4 commits into from Mar 31, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
17 changes: 6 additions & 11 deletions google/cloud/aiplatform/initializer.py
Expand Up @@ -163,7 +163,7 @@ def encryption_spec_key_name(self) -> Optional[str]:
return self._encryption_spec_key_name

def get_client_options(
self, location_override: Optional[str] = None, prediction_client: bool = False,
self, location_override: Optional[str] = None
) -> client_options.ClientOptions:
"""Creates GAPIC client_options using location and type.

Expand All @@ -173,15 +173,11 @@ def get_client_options(
location set by initializer. Must be a GCP region supported by AI
Platform (Unified).

prediction_client (bool):
True if service client is a PredictionServiceClient, otherwise defaults
to False. This is used to provide a prediction-specific API endpoint.

Returns:
clients_options (dict):
A dictionary containing client_options with one key, for example
clients_options (google.api_core.client_options.ClientOptions):
A ClientOptions object set with regionalized API endpoint, i.e.
{ "api_endpoint": "us-central1-aiplatform.googleapis.com" } or
{ "api_endpoint": "asia-east1-prediction-aiplatform.googleapis.com" }
{ "api_endpoint": "asia-east1-aiplatform.googleapis.com" }
"""
if not (self.location or location_override):
raise ValueError(
Expand All @@ -190,12 +186,11 @@ def get_client_options(

region = location_override or self.location
region = region.lower()
prediction = "prediction-" if prediction_client else ""

utils.validate_region(region)

return client_options.ClientOptions(
api_endpoint=f"{region}-{prediction}{constants.API_BASE_PATH}"
api_endpoint=f"{region}-{constants.API_BASE_PATH}"
)

def common_location_path(
Expand Down Expand Up @@ -250,7 +245,7 @@ def create_client(
kwargs = {
"credentials": credentials or self.credentials,
"client_options": self.get_client_options(
location_override=location_override, prediction_client=prediction_client
location_override=location_override
),
"client_info": client_info,
}
Expand Down
13 changes: 8 additions & 5 deletions google/cloud/aiplatform/models.py
Expand Up @@ -1130,9 +1130,9 @@ def __init__(
def upload(
cls,
display_name: str,
artifact_uri: str,
serving_container_image_uri: str,
*,
artifact_uri: Optional[str] = None,
serving_container_predict_route: Optional[str] = None,
serving_container_health_route: Optional[str] = None,
description: Optional[str] = None,
Expand Down Expand Up @@ -1167,11 +1167,12 @@ def upload(
display_name (str):
Required. The display name of the Model. The name can be up to 128
characters long and can be consist of any UTF-8 characters.
artifact_uri (str):
Required. The path to the directory containing the Model artifact and
any of its supporting files. Not present for AutoML Models.
serving_container_image_uri (str):
Required. The URI of the Model serving container.
artifact_uri (str):
Optional. The path to the directory containing the Model artifact and
any of its supporting files. Leave blank for custom container prediction.
Not present for AutoML Models.
serving_container_predict_route (str):
Optional. An HTTP path to send prediction requests to the container, and
which must be supported by it. If not specified a default HTTP path will
Expand Down Expand Up @@ -1335,12 +1336,14 @@ def upload(
managed_model = gca_model.Model(
display_name=display_name,
description=description,
artifact_uri=artifact_uri,
container_spec=container_spec,
predict_schemata=model_predict_schemata,
encryption_spec=encryption_spec,
)

if artifact_uri:
managed_model.artifact_uri = artifact_uri

# Override explanation_spec if both required fields are provided
if explanation_metadata and explanation_parameters:
explanation_spec = gca_endpoint.explanation.ExplanationSpec()
Expand Down
29 changes: 7 additions & 22 deletions tests/unit/aiplatform/test_initializer.py
Expand Up @@ -117,7 +117,7 @@ def test_create_client_overrides(self):
assert isinstance(client, model_service_client.ModelServiceClient)
assert (
client._transport._host
== f"{_TEST_LOCATION_2}-prediction-{constants.API_BASE_PATH}:443"
== f"{_TEST_LOCATION_2}-{constants.API_BASE_PATH}:443"
)
assert client._transport._credentials == creds

Expand All @@ -134,36 +134,21 @@ def test_create_client_user_agent(self):
assert user_agent.startswith("model-builder/")

@pytest.mark.parametrize(
"init_location, location_override, prediction, expected_endpoint",
"init_location, location_override, expected_endpoint",
[
("us-central1", None, False, "us-central1-aiplatform.googleapis.com"),
(
"us-central1",
"europe-west4",
False,
"europe-west4-aiplatform.googleapis.com",
),
("asia-east1", None, False, "asia-east1-aiplatform.googleapis.com"),
(
"asia-east1",
None,
True,
"asia-east1-prediction-aiplatform.googleapis.com",
),
("us-central1", None, "us-central1-aiplatform.googleapis.com"),
("us-central1", "europe-west4", "europe-west4-aiplatform.googleapis.com",),
("asia-east1", None, "asia-east1-aiplatform.googleapis.com"),
],
)
def test_get_client_options(
self,
init_location: str,
location_override: str,
prediction: bool,
expected_endpoint: str,
self, init_location: str, location_override: str, expected_endpoint: str,
):
initializer.global_config.init(location=init_location)

assert (
initializer.global_config.get_client_options(
location_override=location_override, prediction_client=prediction
location_override=location_override
).api_endpoint
== expected_endpoint
)
Expand Down
6 changes: 2 additions & 4 deletions tests/unit/aiplatform/test_models.py
Expand Up @@ -315,9 +315,9 @@ def test_upload_uploads_and_gets_model(self, sync):
api_client_mock.upload_model.return_value = mock_lro
create_client_mock.return_value = api_client_mock

# Custom Container workflow, does not pass `artifact_uri`
my_model = models.Model.upload(
display_name=_TEST_MODEL_NAME,
artifact_uri=_TEST_ARTIFACT_URI,
serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE,
serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE,
serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE,
Expand All @@ -334,9 +334,7 @@ def test_upload_uploads_and_gets_model(self, sync):
)

managed_model = gca_model.Model(
display_name=_TEST_MODEL_NAME,
artifact_uri=_TEST_ARTIFACT_URI,
container_spec=container_spec,
display_name=_TEST_MODEL_NAME, container_spec=container_spec,
)

api_client_mock.upload_model.assert_called_once_with(
Expand Down