Skip to content

Commit

Permalink
feat: Add wait_for_resource_creation to BatchPredictionJob and unbloc…
Browse files Browse the repository at this point in the history
…k async creation when model is pending creation. (#660)
  • Loading branch information
sasha-gitg committed Aug 29, 2021
1 parent 4ad67dc commit db580ad
Show file tree
Hide file tree
Showing 8 changed files with 177 additions and 86 deletions.
33 changes: 33 additions & 0 deletions README.rst
Expand Up @@ -274,6 +274,39 @@ Please visit `Importing models to Vertex AI`_ for a detailed overview:
.. _Importing models to Vertex AI: https://cloud.google.com/vertex-ai/docs/general/import-model


Batch Prediction
----------------

To create a batch prediction job:

.. code-block:: Python

  model = aiplatform.Model('/projects/my-project/locations/us-central1/models/{MODEL_ID}')

  batch_prediction_job = model.batch_predict(
    job_display_name='my-batch-prediction-job',
    instances_format='csv',
    machine_type='n1-standard-4',
    gcs_source=['gs://path/to/my/file.csv'],
    gcs_destination_prefix='gs://path/to/my/batch_prediction/results/'
  )

You can also create a batch prediction job asynchronously by including the ``sync=False`` argument:

.. code-block:: Python

  batch_prediction_job = model.batch_predict(..., sync=False)

  # wait for resource to be created
  batch_prediction_job.wait_for_resource_creation()

  # get the state
  batch_prediction_job.state

  # block until job is complete
  batch_prediction_job.wait()

Endpoints
---------

Expand Down
16 changes: 10 additions & 6 deletions google/cloud/aiplatform/base.py
Expand Up @@ -680,17 +680,21 @@ def wrapper(*args, **kwargs):
inspect.getfullargspec(method).annotations["return"]
)

# object produced by the method
returned_object = bound_args.arguments.get(return_input_arg)

# is a classmethod that creates the object and returns it
if args and inspect.isclass(args[0]):
# assumes classmethod is our resource noun
returned_object = args[0]._empty_constructor()

# assumes class in classmethod is the resource noun
returned_object = (
args[0]._empty_constructor()
if not returned_object
else returned_object
)
self = returned_object

else: # instance method

# object produced by the method
returned_object = bound_args.arguments.get(return_input_arg)

# if we're returning an input object
if returned_object and returned_object is not self:

Expand Down
121 changes: 60 additions & 61 deletions google/cloud/aiplatform/jobs.py
Expand Up @@ -32,15 +32,6 @@
from google.cloud import aiplatform
from google.cloud.aiplatform import base
from google.cloud.aiplatform import compat
from google.cloud.aiplatform import constants
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import hyperparameter_tuning
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.utils import console_utils
from google.cloud.aiplatform.utils import source_utils
from google.cloud.aiplatform.utils import worker_spec_utils

from google.cloud.aiplatform.compat.services import job_service_client
from google.cloud.aiplatform.compat.types import (
batch_prediction_job as gca_bp_job_compat,
batch_prediction_job_v1 as gca_bp_job_v1,
Expand All @@ -58,6 +49,13 @@
machine_resources_v1beta1 as gca_machine_resources_v1beta1,
study as gca_study_compat,
)
from google.cloud.aiplatform import constants
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import hyperparameter_tuning
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.utils import console_utils
from google.cloud.aiplatform.utils import source_utils
from google.cloud.aiplatform.utils import worker_spec_utils


_LOGGER = base.Logger(__name__)
Expand Down Expand Up @@ -352,7 +350,7 @@ def completion_stats(self) -> Optional[gca_completion_stats.CompletionStats]:
def create(
cls,
job_display_name: str,
model_name: str,
model_name: Union[str, "aiplatform.Model"],
instances_format: str = "jsonl",
predictions_format: str = "jsonl",
gcs_source: Optional[Union[str, Sequence[str]]] = None,
Expand Down Expand Up @@ -384,10 +382,12 @@ def create(
Required. The user-defined name of the BatchPredictionJob.
The name can be up to 128 characters long and can consist
of any UTF-8 characters.
model_name (str):
model_name (Union[str, aiplatform.Model]):
Required. A fully-qualified model resource name or model ID.
Example: "projects/123/locations/us-central1/models/456" or
"456" when project and location are initialized or passed.
Or an instance of aiplatform.Model.
instances_format (str):
Required. The format in which instances are given, must be one
of "jsonl", "csv", "bigquery", "tf-record", "tf-record-gzip",
Expand Down Expand Up @@ -533,15 +533,17 @@ def create(
"""

utils.validate_display_name(job_display_name)

if labels:
utils.validate_labels(labels)

model_name = utils.full_resource_name(
resource_name=model_name,
resource_noun="models",
project=project,
location=location,
)
if isinstance(model_name, str):
model_name = utils.full_resource_name(
resource_name=model_name,
resource_noun="models",
project=project,
location=location,
)

# Raise error if both or neither source URIs are provided
if bool(gcs_source) == bool(bigquery_source):
Expand Down Expand Up @@ -570,6 +572,7 @@ def create(
f"{predictions_format} is not an accepted prediction format "
f"type. Please choose from: {constants.BATCH_PREDICTION_OUTPUT_STORAGE_FORMATS}"
)

gca_bp_job = gca_bp_job_compat
gca_io = gca_io_compat
gca_machine_resources = gca_machine_resources_compat
Expand All @@ -584,7 +587,6 @@ def create(

# Required Fields
gapic_batch_prediction_job.display_name = job_display_name
gapic_batch_prediction_job.model = model_name

input_config = gca_bp_job.BatchPredictionJob.InputConfig()
output_config = gca_bp_job.BatchPredictionJob.OutputConfig()
Expand Down Expand Up @@ -657,63 +659,43 @@ def create(
metadata=explanation_metadata, parameters=explanation_parameters
)

# TODO (b/174502913): Support private feature once released

api_client = cls._instantiate_client(location=location, credentials=credentials)
empty_batch_prediction_job = cls._empty_constructor(
project=project, location=location, credentials=credentials,
)

return cls._create(
api_client=api_client,
parent=initializer.global_config.common_location_path(
project=project, location=location
),
batch_prediction_job=gapic_batch_prediction_job,
empty_batch_prediction_job=empty_batch_prediction_job,
model_or_model_name=model_name,
gca_batch_prediction_job=gapic_batch_prediction_job,
generate_explanation=generate_explanation,
project=project or initializer.global_config.project,
location=location or initializer.global_config.location,
credentials=credentials or initializer.global_config.credentials,
sync=sync,
)

@classmethod
@base.optional_sync()
@base.optional_sync(return_input_arg="empty_batch_prediction_job")
def _create(
cls,
api_client: job_service_client.JobServiceClient,
parent: str,
batch_prediction_job: Union[
empty_batch_prediction_job: "BatchPredictionJob",
model_or_model_name: Union[str, "aiplatform.Model"],
gca_batch_prediction_job: Union[
gca_bp_job_v1beta1.BatchPredictionJob, gca_bp_job_v1.BatchPredictionJob
],
generate_explanation: bool,
project: str,
location: str,
credentials: Optional[auth_credentials.Credentials],
sync: bool = True,
) -> "BatchPredictionJob":
"""Create a batch prediction job.
Args:
api_client (dataset_service_client.DatasetServiceClient):
Required. An instance of DatasetServiceClient with the correct api_endpoint
already set based on user's preferences.
batch_prediction_job (gca_bp_job.BatchPredictionJob):
empty_batch_prediction_job (BatchPredictionJob):
Required. BatchPredictionJob without _gca_resource populated.
model_or_model_name (Union[str, aiplatform.Model]):
Required. A fully-qualified model resource name or
an instance of aiplatform.Model.
gca_batch_prediction_job (gca_bp_job.BatchPredictionJob):
Required. A batch prediction job proto for creating a batch prediction job on Vertex AI.
generate_explanation (bool):
Required. Generate explanation along with the batch prediction
results.
parent (str):
Required. Also known as common location path, that usually contains the
project and location that the user provided to the upstream method.
Example: "projects/my-prj/locations/us-central1"
project (str):
Required. Project to upload this model to. Overrides project set in
aiplatform.init.
location (str):
Required. Location to upload this model to. Overrides location set in
aiplatform.init.
credentials (Optional[auth_credentials.Credentials]):
Custom credentials to use to upload this model. Overrides
credentials set in aiplatform.init.
Returns:
(jobs.BatchPredictionJob):
Instantiated representation of the created batch prediction job.
Expand All @@ -725,21 +707,34 @@ def _create(
by Vertex AI.
"""
# select v1beta1 if explain else use default v1

parent = initializer.global_config.common_location_path(
project=empty_batch_prediction_job.project,
location=empty_batch_prediction_job.location,
)

model_resource_name = (
model_or_model_name
if isinstance(model_or_model_name, str)
else model_or_model_name.resource_name
)

gca_batch_prediction_job.model = model_resource_name

api_client = empty_batch_prediction_job.api_client

if generate_explanation:
api_client = api_client.select_version(compat.V1BETA1)

_LOGGER.log_create_with_lro(cls)

gca_batch_prediction_job = api_client.create_batch_prediction_job(
parent=parent, batch_prediction_job=batch_prediction_job
parent=parent, batch_prediction_job=gca_batch_prediction_job
)

batch_prediction_job = cls(
batch_prediction_job_name=gca_batch_prediction_job.name,
project=project,
location=location,
credentials=credentials,
)
empty_batch_prediction_job._gca_resource = gca_batch_prediction_job

batch_prediction_job = empty_batch_prediction_job

_LOGGER.log_create_complete(cls, batch_prediction_job._gca_resource, "bpj")

Expand Down Expand Up @@ -843,6 +838,10 @@ def iter_outputs(
f"on your prediction output:\n{output_info}"
)

def wait_for_resource_creation(self) -> None:
    """Block until this job's backing resource has been created.

    Useful after an asynchronous creation call (``sync=False``), when the
    caller needs the resource to exist before reading properties such as
    ``state``. Delegates to the ``_wait_for_resource_creation`` helper
    available on this instance.
    """
    self._wait_for_resource_creation()


class _RunnableJob(_Job):
"""ABC to interface job as a runnable training class."""
Expand Down
4 changes: 1 addition & 3 deletions google/cloud/aiplatform/models.py
Expand Up @@ -981,7 +981,6 @@ def undeploy(
if deployed_model_id in traffic_split and traffic_split[deployed_model_id]:
raise ValueError("Model being undeployed should have 0 traffic.")
if sum(traffic_split.values()) != 100:
# TODO(b/172678233) verify every referenced deployed model exists
raise ValueError(
"Sum of all traffic within traffic split needs to be 100."
)
Expand Down Expand Up @@ -2167,11 +2166,10 @@ def batch_predict(
(jobs.BatchPredictionJob):
Instantiated representation of the created batch prediction job.
"""
self.wait()

return jobs.BatchPredictionJob.create(
job_display_name=job_display_name,
model_name=self.resource_name,
model_name=self,
instances_format=instances_format,
predictions_format=predictions_format,
gcs_source=gcs_source,
Expand Down
11 changes: 11 additions & 0 deletions tests/system/aiplatform/e2e_base.py
Expand Up @@ -43,6 +43,17 @@ def _temp_prefix(cls) -> str:
"""
pass

@classmethod
def _make_display_name(cls, key: str) -> str:
    """Build a unique display name for a test resource.

    Args:
        key (str): Required. Identifier embedded in the display name.

    Returns:
        str: A name of the form ``<cls._temp_prefix>-<key>-<uuid4>``.
    """
    prefix = cls._temp_prefix
    # A fresh random UUID per call keeps names from colliding.
    unique_token = uuid.uuid4()
    return f"{prefix}-{key}-{unique_token}"

def setup_method(self):
importlib.reload(initializer)
importlib.reload(aiplatform)
Expand Down

0 comments on commit db580ad

Please sign in to comment.