Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add export model #353

Merged
merged 9 commits into from May 3, 2021
158 changes: 158 additions & 0 deletions google/cloud/aiplatform/models.py
Expand Up @@ -35,9 +35,11 @@
endpoint_v1 as gca_endpoint_v1,
endpoint_v1beta1 as gca_endpoint_v1beta1,
explanation_v1beta1 as gca_explanation_v1beta1,
io_v1 as gca_io_v1,
vinnysenthil marked this conversation as resolved.
Show resolved Hide resolved
machine_resources as gca_machine_resources_compat,
machine_resources_v1beta1 as gca_machine_resources_v1beta1,
model as gca_model_compat,
model_service as gca_model_service_v1,
model_v1beta1 as gca_model_v1beta1,
env_var as gca_env_var_compat,
env_var_v1beta1 as gca_env_var_v1beta1,
Expand Down Expand Up @@ -1221,6 +1223,26 @@ def description(self):
"""Description of the model."""
return self._gca_resource.description

@property
def supported_export_formats(
self,
) -> Dict[str, List[gca_model_compat.Model.ExportFormat.ExportableContent]]:
"""The formats and content types in which this Model may be exported.
If empty, this Model is not available for export.

For example, if this model can be exported as a Tensorflow SavedModel and
have the artifacts written to Cloud Storage, the expected value would be:

{'tf-saved-model': [<ExportableContent.ARTIFACT: 1>]}
"""
return {
export_format.id: [
gca_model_compat.Model.ExportFormat.ExportableContent(content)
for content in export_format.exportable_contents
]
for export_format in self._gca_resource.supported_export_formats
}

def __init__(
self,
model_name: str,
Expand Down Expand Up @@ -2036,3 +2058,139 @@ def list(
location=location,
credentials=credentials,
)

@base.optional_sync()
def _wait_on_export(self, operation_future, sync=True) -> None:
vinnysenthil marked this conversation as resolved.
Show resolved Hide resolved
operation_future.result()

def export_model(
self,
export_format_id: str,
artifact_destination: Optional[str] = None,
image_destination: Optional[str] = None,
sync: bool = True,
) -> None:
vinnysenthil marked this conversation as resolved.
Show resolved Hide resolved
"""Exports a trained, exportable Model to a location specified by the user.
A Model is considered to be exportable if it has at least one `supported_export_formats`.
Either `artifact_destination` or `image_destination` must be provided.

Usage:
my_model.export(
export_format_id='tf-saved-model'
artifact_destination='gs://my-bucket/models/'
)

or

my_model.export(
export_format_id='custom-model'
image_destination='us-central1-docker.pkg.dev/projectId/repo/image'
)

Args:
export_format_id (str):
sasha-gitg marked this conversation as resolved.
Show resolved Hide resolved
Required. The ID of the format in which the Model must be exported.
Each Model lists the export formats it supports, which can be
found using `Model.supported_export_formats`.
artifact_destination (str):
The Cloud Storage location where the Model artifact is to be
written to. Under the directory given as the destination a
new one with name
"``model-export-<model-display-name>-<timestamp-of-export-call>``",
where timestamp is in YYYY-MM-DDThh:mm:ss.sssZ ISO-8601
format, will be created. Inside, the Model and any of its
supporting files will be written.

This field should only be set when, in [Model.supported_export_formats],
the value for the key given in `export_format_id` contains ``ARTIFACT``.
image_destination (str):
The Google Container Registry or Artifact Registry URI where
the Model container image will be copied to. Accepted forms:

- Google Container Registry path. For example:
``gcr.io/projectId/imageName:tag``.

- Artifact Registry path. For example:
``us-central1-docker.pkg.dev/projectId/repoName/imageName:tag``.

This field should only be set when, in [Model.supported_export_formats],
the value for the key given in `export_format_id` contains ``IMAGE``.
sync (bool):
Whether to execute this export synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.

vinnysenthil marked this conversation as resolved.
Show resolved Hide resolved
Raises:
ValueError if model does not support exporting.

ValueError if invalid arguments or export formats are provided.
"""

# Model does not support exporting
if not self.supported_export_formats:
raise ValueError(f"The model `{self.resource_name}` is not exportable.")

# No destination provided
if not any((artifact_destination, image_destination)):
raise ValueError(
"Please provide an `artifact_destination` or `image_destination`."
)

export_format_id = export_format_id.lower()

# Unsupported export type
if export_format_id not in self.supported_export_formats:
raise ValueError(
f"'{export_format_id}' is not a supported export format for this model. "
f"Choose one of the following: {self.supported_export_formats}"
)

content_types = gca_model_compat.Model.ExportFormat.ExportableContent
supported_content_types = self.supported_export_formats[export_format_id]

if (
artifact_destination
and content_types.ARTIFACT not in supported_content_types
):
raise ValueError(
"This model can not be exported as an artifact in '{export_format_id}' format. "
"Try exporting as a container image by passing the `image_destination` argument."
)

if image_destination and content_types.IMAGE not in supported_content_types:
raise ValueError(
"This model can not be exported as a container image in '{export_format_id}' format. "
"Try exporting the model artifacts by passing a `artifact_destination` argument."
)

# Construct request payload
output_config = gca_model_service_v1.ExportModelRequest.OutputConfig(
vinnysenthil marked this conversation as resolved.
Show resolved Hide resolved
export_format_id=export_format_id
)

if artifact_destination:
output_config.artifact_destination = gca_io_v1.GcsDestination(
output_uri_prefix=artifact_destination
)

if image_destination:
output_config.image_destination = gca_io_v1.ContainerRegistryDestination(
output_uri=image_destination
)

_LOGGER.log_action_start_against_resource("Exporting", "model", self)

operation_future = self.api_client.export_model(
name=self.resource_name, output_config=output_config
)

_LOGGER.log_action_started_against_resource_with_lro(
"Export", "model", self.__class__, operation_future
)

# Block before returning
self._wait_on_export(operation_future=operation_future, sync=sync)

_LOGGER.log_action_completed_against_resource("model", "exported", self)
vinnysenthil marked this conversation as resolved.
Show resolved Hide resolved

return json_format.MessageToDict(operation_future.metadata.output_info._pb)
vinnysenthil marked this conversation as resolved.
Show resolved Hide resolved