Skip to content

Commit

Permalink
feat: Expose additional attributes into Vertex SDK to close gap with …
Browse files Browse the repository at this point in the history
…GAPIC (#477)

* Add most missing fields

* Add tests for get trainingjob subclass

* Drop Dataset len, add more attrs, update docstrings

* flake8 lint

* Address reviewer comments

* Switch 'an' to 'a' when referencing Vertex AI

* Address comments, move base attrs to subclasses

* Drop unused import

* Add test to ensure supported training schemas are always unique

* Address reviewer comments
  • Loading branch information
vinnysenthil committed Jun 17, 2021
1 parent e3cbdd8 commit 572a27c
Show file tree
Hide file tree
Showing 7 changed files with 428 additions and 40 deletions.
21 changes: 19 additions & 2 deletions google/cloud/aiplatform/base.py
Expand Up @@ -42,7 +42,7 @@
from google.auth import credentials as auth_credentials
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import utils

from google.cloud.aiplatform.compat.types import encryption_spec as gca_encryption_spec

logging.basicConfig(level=logging.INFO, stream=sys.stdout)

Expand Down Expand Up @@ -563,6 +563,23 @@ def update_time(self) -> datetime.datetime:
self._sync_gca_resource()
return self._gca_resource.update_time

@property
def encryption_spec(self) -> Optional[gca_encryption_spec.EncryptionSpec]:
"""Customer-managed encryption key options for this Vertex AI resource.
If this is set, then all resources created by this Vertex AI resource will
be encrypted with the provided encryption key.
"""
return getattr(self._gca_resource, "encryption_spec")

@property
def labels(self) -> Dict[str, str]:
"""User-defined labels containing metadata about this resource.
Read more about labels at https://goo.gl/xmQnxf
"""
return self._gca_resource.labels

@property
def gca_resource(self) -> proto.Message:
"""The underlying resource proto represenation."""
Expand Down Expand Up @@ -813,7 +830,7 @@ def _construct_sdk_resource_from_gapic(
Args:
gapic_resource (proto.Message):
A GAPIC representation of an Vertex AI resource, usually
A GAPIC representation of a Vertex AI resource, usually
retrieved by a get_* or in a list_* API call.
project (str):
Optional. Project to construct SDK object from. If not set,
Expand Down
2 changes: 1 addition & 1 deletion google/cloud/aiplatform/initializer.py
Expand Up @@ -267,7 +267,7 @@ def create_client(
Args:
client_class (utils.VertexAiServiceClientWithOverride):
(Required) An Vertex AI Service Client with optional overrides.
(Required) A Vertex AI Service Client with optional overrides.
credentials (auth_credentials.Credentials):
Custom auth credentials. If not provided will use the current config.
location_override (str): Optional location override.
Expand Down
83 changes: 78 additions & 5 deletions google/cloud/aiplatform/jobs.py
Expand Up @@ -19,6 +19,7 @@

import abc
import copy
import datetime
import sys
import time
import logging
Expand All @@ -28,6 +29,7 @@

from google.auth import credentials as auth_credentials
from google.protobuf import duration_pb2 # type: ignore
from google.rpc import status_pb2

from google.cloud import aiplatform
from google.cloud.aiplatform import base
Expand All @@ -45,6 +47,7 @@
batch_prediction_job as gca_bp_job_compat,
batch_prediction_job_v1 as gca_bp_job_v1,
batch_prediction_job_v1beta1 as gca_bp_job_v1beta1,
completion_stats as gca_completion_stats,
custom_job as gca_custom_job_compat,
custom_job_v1beta1 as gca_custom_job_v1beta1,
explanation_v1beta1 as gca_explanation_v1beta1,
Expand Down Expand Up @@ -139,6 +142,27 @@ def state(self) -> gca_job_state.JobState:

return self._gca_resource.state

@property
def start_time(self) -> Optional[datetime.datetime]:
"""Time when the Job resource entered the `JOB_STATE_RUNNING` for the
first time."""
self._sync_gca_resource()
return getattr(self._gca_resource, "start_time")

@property
def end_time(self) -> Optional[datetime.datetime]:
"""Time when the Job resource entered the `JOB_STATE_SUCCEEDED`,
`JOB_STATE_FAILED`, or `JOB_STATE_CANCELLED` state."""
self._sync_gca_resource()
return getattr(self._gca_resource, "end_time")

@property
def error(self) -> Optional[status_pb2.Status]:
"""Detailed error info for this Job resource. Only populated when the
Job's state is `JOB_STATE_FAILED` or `JOB_STATE_CANCELLED`."""
self._sync_gca_resource()
return getattr(self._gca_resource, "error")

@property
@abc.abstractmethod
def _job_type(cls) -> str:
Expand Down Expand Up @@ -302,6 +326,27 @@ def __init__(
credentials=credentials,
)

@property
def output_info(self,) -> Optional[aiplatform.gapic.BatchPredictionJob.OutputInfo]:
"""Information describing the output of this job, including output location
into which prediction output is written.
This is only available for batch predicition jobs that have run successfully.
"""
return self._gca_resource.output_info

@property
def partial_failures(self) -> Optional[Sequence[status_pb2.Status]]:
"""Partial failures encountered. For example, single files that can't be read.
This field never exceeds 20 entries. Status details fields contain standard
GCP error details."""
return getattr(self._gca_resource, "partial_failures")

@property
def completion_stats(self) -> Optional[gca_completion_stats.CompletionStats]:
"""Statistics on completed and failed prediction instances."""
return getattr(self._gca_resource, "completion_stats")

@classmethod
def create(
cls,
Expand Down Expand Up @@ -842,7 +887,7 @@ def get(
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> "_RunnableJob":
"""Get an Vertex AI Job for the given resource_name.
"""Get a Vertex AI Job for the given resource_name.
Args:
resource_name (str):
Expand All @@ -858,7 +903,7 @@ def get(
credentials set in aiplatform.init.
Returns:
An Vertex AI Job.
A Vertex AI Job.
"""
self = cls._empty_constructor(
project=project,
Expand Down Expand Up @@ -887,7 +932,7 @@ class CustomJob(_RunnableJob):

_resource_noun = "customJobs"
_getter_method = "get_custom_job"
_list_method = "list_custom_job"
_list_method = "list_custom_jobs"
_cancel_method = "cancel_custom_job"
_delete_method = "delete_custom_job"
_job_type = "training"
Expand Down Expand Up @@ -987,6 +1032,20 @@ def __init__(
),
)

@property
def network(self) -> Optional[str]:
"""The full name of the Google Compute Engine
[network](https://cloud.google.com/vpc/docs/vpc#networks) to which this
CustomJob should be peered.
Takes the format `projects/{project}/global/networks/{network}`. Where
{project} is a project number, as in `12345`, and {network} is a network name.
Private services access must already be configured for the network. If left
unspecified, the CustomJob is not peered with any network.
"""
return getattr(self._gca_resource, "network")

@classmethod
def from_local_script(
cls,
Expand Down Expand Up @@ -1157,7 +1216,7 @@ def run(
distributed training jobs that are not resilient
to workers leaving and joining a job.
tensorboard (str):
Optional. The name of an Vertex AI
Optional. The name of a Vertex AI
[Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard]
resource to which this CustomJob will upload Tensorboard
logs. Format:
Expand Down Expand Up @@ -1444,6 +1503,20 @@ def __init__(
),
)

@property
def network(self) -> Optional[str]:
"""The full name of the Google Compute Engine
[network](https://cloud.google.com/vpc/docs/vpc#networks) to which this
HyperparameterTuningJob should be peered.
Takes the format `projects/{project}/global/networks/{network}`. Where
{project} is a project number, as in `12345`, and {network} is a network name.
Private services access must already be configured for the network. If left
unspecified, the HyperparameterTuningJob is not peered with any network.
"""
return getattr(self._gca_resource.trial_job_spec, "network")

@base.optional_sync()
def run(
self,
Expand Down Expand Up @@ -1473,7 +1546,7 @@ def run(
distributed training jobs that are not resilient
to workers leaving and joining a job.
tensorboard (str):
Optional. The name of an Vertex AI
Optional. The name of a Vertex AI
[Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard]
resource to which this CustomJob will upload Tensorboard
logs. Format:
Expand Down
130 changes: 126 additions & 4 deletions google/cloud/aiplatform/models.py
Expand Up @@ -18,8 +18,10 @@
from typing import Dict, List, NamedTuple, Optional, Sequence, Tuple, Union

from google.api_core import operation
from google.api_core import exceptions as api_exceptions
from google.auth import credentials as auth_credentials

from google.cloud import aiplatform
from google.cloud.aiplatform import base
from google.cloud.aiplatform import compat
from google.cloud.aiplatform import explain
Expand Down Expand Up @@ -119,6 +121,33 @@ def __init__(
credentials=credentials,
)

@property
def traffic_split(self) -> Dict[str, int]:
"""A map from a DeployedModel's ID to the percentage of this Endpoint's
traffic that should be forwarded to that DeployedModel.
If a DeployedModel's ID is not listed in this map, then it receives no traffic.
The traffic percentage values must add up to 100, or map must be empty if
the Endpoint is to not accept any traffic at a moment.
"""
self._sync_gca_resource()
return dict(self._gca_resource.traffic_split)

@property
def network(self) -> Optional[str]:
"""The full name of the Google Compute Engine
[network](https://cloud.google.com/vpc/docs/vpc#networks) to which this
Endpoint should be peered.
Takes the format `projects/{project}/global/networks/{network}`. Where
{project} is a project number, as in `12345`, and {network} is a network name.
Private services access must already be configured for the network. If left
unspecified, the Endpoint is not peered with any network.
"""
return getattr(self._gca_resource, "network")

@classmethod
def create(
cls,
Expand Down Expand Up @@ -1211,12 +1240,13 @@ class Model(base.VertexAiResourceNounWithFutureManager):
_delete_method = "delete_model"

@property
def uri(self):
"""Uri of the model."""
return self._gca_resource.artifact_uri
def uri(self) -> Optional[str]:
"""Path to the directory containing the Model artifact and any of its
supporting files. Not present for AutoML Models."""
return self._gca_resource.artifact_uri or None

@property
def description(self):
def description(self) -> str:
"""Description of the model."""
return self._gca_resource.description

Expand All @@ -1240,6 +1270,98 @@ def supported_export_formats(
for export_format in self._gca_resource.supported_export_formats
}

@property
def supported_deployment_resources_types(
self,
) -> List[aiplatform.gapic.Model.DeploymentResourcesType]:
"""List of deployment resource types accepted for this Model.
When this Model is deployed, its prediction resources are described by
the `prediction_resources` field of the objects returned by
`Endpoint.list_models()`. Because not all Models support all resource
configuration types, the configuration types this Model supports are
listed here.
If no configuration types are listed, the Model cannot be
deployed to an `Endpoint` and does not support online predictions
(`Endpoint.predict()` or `Endpoint.explain()`). Such a Model can serve
predictions by using a `BatchPredictionJob`, if it has at least one entry
each in `Model.supported_input_storage_formats` and
`Model.supported_output_storage_formats`."""
return list(self._gca_resource.supported_deployment_resources_types)

@property
def supported_input_storage_formats(self) -> List[str]:
"""The formats this Model supports in the `input_config` field of a
`BatchPredictionJob`. If `Model.predict_schemata.instance_schema_uri`
exists, the instances should be given as per that schema.
[Read the docs for more on batch prediction formats](https://cloud.google.com/vertex-ai/docs/predictions/batch-predictions#batch_request_input)
If this Model doesn't support any of these formats it means it cannot be
used with a `BatchPredictionJob`. However, if it has
`supported_deployment_resources_types`, it could serve online predictions
by using `Endpoint.predict()` or `Endpoint.explain()`.
"""
return list(self._gca_resource.supported_input_storage_formats)

@property
def supported_output_storage_formats(self) -> List[str]:
"""The formats this Model supports in the `output_config` field of a
`BatchPredictionJob`.
If both `Model.predict_schemata.instance_schema_uri` and
`Model.predict_schemata.prediction_schema_uri` exist, the predictions
are returned together with their instances. In other words, the
prediction has the original instance data first, followed by the actual
prediction content (as per the schema).
[Read the docs for more on batch prediction formats](https://cloud.google.com/vertex-ai/docs/predictions/batch-predictions)
If this Model doesn't support any of these formats it means it cannot be
used with a `BatchPredictionJob`. However, if it has
`supported_deployment_resources_types`, it could serve online predictions
by using `Endpoint.predict()` or `Endpoint.explain()`.
"""
return list(self._gca_resource.supported_output_storage_formats)

@property
def predict_schemata(self) -> Optional[aiplatform.gapic.PredictSchemata]:
"""The schemata that describe formats of the Model's predictions and
explanations, if available."""
return getattr(self._gca_resource, "predict_schemata")

@property
def training_job(self) -> Optional["aiplatform.training_jobs._TrainingJob"]:
"""The TrainingJob that uploaded this Model, if any.
Raises:
api_core.exceptions.NotFound: If the Model's training job resource
cannot be found on the Vertex service.
"""
job_name = getattr(self._gca_resource, "training_pipeline")

if not job_name:
return None

try:
return aiplatform.training_jobs._TrainingJob._get_and_return_subclass(
resource_name=job_name,
project=self.project,
location=self.location,
credentials=self.credentials,
)
except api_exceptions.NotFound:
raise api_exceptions.NotFound(
f"The training job used to create this model could not be found: {job_name}"
)

@property
def container_spec(self) -> Optional[aiplatform.gapic.ModelContainerSpec]:
"""The specification of the container that is to be used when deploying
this Model. Not present for AutoML Models."""
return getattr(self._gca_resource, "container_spec")

def __init__(
self,
model_name: str,
Expand Down

0 comments on commit 572a27c

Please sign in to comment.