Skip to content

Commit

Permalink
feat: Add export model (#353)
Browse files Browse the repository at this point in the history
  • Loading branch information
vinnysenthil committed May 3, 2021
1 parent 857f63d commit 12c5be4
Show file tree
Hide file tree
Showing 2 changed files with 430 additions and 15 deletions.
162 changes: 162 additions & 0 deletions google/cloud/aiplatform/models.py
Expand Up @@ -17,6 +17,7 @@
import proto
from typing import Dict, List, NamedTuple, Optional, Sequence, Tuple, Union

from google.api_core import operation
from google.auth import credentials as auth_credentials

from google.cloud.aiplatform import base
Expand All @@ -35,9 +36,11 @@
endpoint_v1 as gca_endpoint_v1,
endpoint_v1beta1 as gca_endpoint_v1beta1,
explanation_v1beta1 as gca_explanation_v1beta1,
io as gca_io_compat,
machine_resources as gca_machine_resources_compat,
machine_resources_v1beta1 as gca_machine_resources_v1beta1,
model as gca_model_compat,
model_service as gca_model_service_compat,
model_v1beta1 as gca_model_v1beta1,
env_var as gca_env_var_compat,
env_var_v1beta1 as gca_env_var_v1beta1,
Expand Down Expand Up @@ -1217,6 +1220,26 @@ def description(self):
"""Description of the model."""
return self._gca_resource.description

@property
def supported_export_formats(
self,
) -> Dict[str, List[gca_model_compat.Model.ExportFormat.ExportableContent]]:
"""The formats and content types in which this Model may be exported.
If empty, this Model is not available for export.
For example, if this model can be exported as a Tensorflow SavedModel and
have the artifacts written to Cloud Storage, the expected value would be:
{'tf-saved-model': [<ExportableContent.ARTIFACT: 1>]}
"""
return {
export_format.id: [
gca_model_compat.Model.ExportFormat.ExportableContent(content)
for content in export_format.exportable_contents
]
for export_format in self._gca_resource.supported_export_formats
}

def __init__(
self,
model_name: str,
Expand Down Expand Up @@ -2030,3 +2053,142 @@ def list(
location=location,
credentials=credentials,
)

@base.optional_sync()
def _wait_on_export(self, operation_future: operation.Operation, sync=True) -> None:
operation_future.result()

def export_model(
self,
export_format_id: str,
artifact_destination: Optional[str] = None,
image_destination: Optional[str] = None,
sync: bool = True,
) -> Dict[str, str]:
"""Exports a trained, exportable Model to a location specified by the user.
A Model is considered to be exportable if it has at least one `supported_export_formats`.
Either `artifact_destination` or `image_destination` must be provided.
Usage:
my_model.export(
export_format_id='tf-saved-model'
artifact_destination='gs://my-bucket/models/'
)
or
my_model.export(
export_format_id='custom-model'
image_destination='us-central1-docker.pkg.dev/projectId/repo/image'
)
Args:
export_format_id (str):
Required. The ID of the format in which the Model must be exported.
The list of export formats that this Model supports can be found
by calling `Model.supported_export_formats`.
artifact_destination (str):
The Cloud Storage location where the Model artifact is to be
written to. Under the directory given as the destination a
new one with name
"``model-export-<model-display-name>-<timestamp-of-export-call>``",
where timestamp is in YYYY-MM-DDThh:mm:ss.sssZ ISO-8601
format, will be created. Inside, the Model and any of its
supporting files will be written.
This field should only be set when, in [Model.supported_export_formats],
the value for the key given in `export_format_id` contains ``ARTIFACT``.
image_destination (str):
The Google Container Registry or Artifact Registry URI where
the Model container image will be copied to. Accepted forms:
- Google Container Registry path. For example:
``gcr.io/projectId/imageName:tag``.
- Artifact Registry path. For example:
``us-central1-docker.pkg.dev/projectId/repoName/imageName:tag``.
This field should only be set when, in [Model.supported_export_formats],
the value for the key given in `export_format_id` contains ``IMAGE``.
sync (bool):
Whether to execute this export synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
Returns:
output_info (Dict[str, str]):
Details of the completed export with output destination paths to
the artifacts or container image.
Raises:
ValueError if model does not support exporting.
ValueError if invalid arguments or export formats are provided.
"""

# Model does not support exporting
if not self.supported_export_formats:
raise ValueError(f"The model `{self.resource_name}` is not exportable.")

# No destination provided
if not any((artifact_destination, image_destination)):
raise ValueError(
"Please provide an `artifact_destination` or `image_destination`."
)

export_format_id = export_format_id.lower()

# Unsupported export type
if export_format_id not in self.supported_export_formats:
raise ValueError(
f"'{export_format_id}' is not a supported export format for this model. "
f"Choose one of the following: {self.supported_export_formats}"
)

content_types = gca_model_compat.Model.ExportFormat.ExportableContent
supported_content_types = self.supported_export_formats[export_format_id]

if (
artifact_destination
and content_types.ARTIFACT not in supported_content_types
):
raise ValueError(
"This model can not be exported as an artifact in '{export_format_id}' format. "
"Try exporting as a container image by passing the `image_destination` argument."
)

if image_destination and content_types.IMAGE not in supported_content_types:
raise ValueError(
"This model can not be exported as a container image in '{export_format_id}' format. "
"Try exporting the model artifacts by passing a `artifact_destination` argument."
)

# Construct request payload
output_config = gca_model_service_compat.ExportModelRequest.OutputConfig(
export_format_id=export_format_id
)

if artifact_destination:
output_config.artifact_destination = gca_io_compat.GcsDestination(
output_uri_prefix=artifact_destination
)

if image_destination:
output_config.image_destination = gca_io_compat.ContainerRegistryDestination(
output_uri=image_destination
)

_LOGGER.log_action_start_against_resource("Exporting", "model", self)

operation_future = self.api_client.export_model(
name=self.resource_name, output_config=output_config
)

_LOGGER.log_action_started_against_resource_with_lro(
"Export", "model", self.__class__, operation_future
)

# Block before returning
self._wait_on_export(operation_future=operation_future, sync=sync)

_LOGGER.log_action_completed_against_resource("model", "exported", self)

return json_format.MessageToDict(operation_future.metadata.output_info._pb)

0 comments on commit 12c5be4

Please sign in to comment.