From 1e87971b776f73ec32311fc72dca368869695a4e Mon Sep 17 00:00:00 2001
From: Sasha Sobran
Date: Wed, 25 Aug 2021 17:09:30 -0400
Subject: [PATCH 1/7] feat: Unblock async batch prediction jobs and add
 wait_for_resource_creation.

---
 google/cloud/aiplatform/base.py   | 13 +++----
 google/cloud/aiplatform/jobs.py   | 60 +++++++++++++++++++------------
 google/cloud/aiplatform/models.py |  4 +--
 3 files changed, 46 insertions(+), 31 deletions(-)

diff --git a/google/cloud/aiplatform/base.py b/google/cloud/aiplatform/base.py
index 20f9aa07ad..8bee523487 100644
--- a/google/cloud/aiplatform/base.py
+++ b/google/cloud/aiplatform/base.py
@@ -680,17 +680,18 @@ def wrapper(*args, **kwargs):
                 inspect.getfullargspec(method).annotations["return"]
             )
 
+            # object produced by the method
+            returned_object = bound_args.arguments.get(return_input_arg)
+
             # is a classmethod that creates the object and returns it
             if args and inspect.isclass(args[0]):
-                # assumes classmethod is our resource noun
-                returned_object = args[0]._empty_constructor()
+
+                # assumes class in classmethod is the resource noun
+                returned_object = args[0]._empty_constructor() if not returned_object else returned_object
                 self = returned_object
 
             else:  # instance method
-
-                # object produced by the method
-                returned_object = bound_args.arguments.get(return_input_arg)
-
                 # if we're returning an input object
                 if returned_object and returned_object is not self:
diff --git a/google/cloud/aiplatform/jobs.py b/google/cloud/aiplatform/jobs.py
index 6a5eb8ffee..fe317a9166 100644
--- a/google/cloud/aiplatform/jobs.py
+++ b/google/cloud/aiplatform/jobs.py
@@ -352,7 +352,7 @@ def completion_stats(self) -> Optional[gca_completion_stats.CompletionStats]:
     def create(
         cls,
         job_display_name: str,
-        model_name: str,
+        model_name: Union[str, 'aiplatform.Model'],
         instances_format: str = "jsonl",
         predictions_format: str = "jsonl",
         gcs_source: Optional[Union[str, Sequence[str]]] = None,
@@ -388,6 +388,8 @@ def create(
                 Required. A fully-qualified model resource name or model ID.
                 Example: "projects/123/locations/us-central1/models/456" or
                 "456" when project and location are initialized or passed.
+
+                Or an instance ot aiplatform.Model.
             instances_format (str):
                 Required. The format in which instances are given, must be one
                 of "jsonl", "csv", "bigquery", "tf-record", "tf-record-gzip",
@@ -531,17 +533,19 @@ def create(
             (jobs.BatchPredictionJob):
                 Instantiated representation of the created batch prediction job.
         """
-
+        
         utils.validate_display_name(job_display_name)
+        
         if labels:
             utils.validate_labels(labels)
 
-        model_name = utils.full_resource_name(
-            resource_name=model_name,
-            resource_noun="models",
-            project=project,
-            location=location,
-        )
+        if isinstance(model_name, str):
+            model_name = utils.full_resource_name(
+                resource_name=model_name,
+                resource_noun="models",
+                project=project,
+                location=location,
+            )
 
         # Raise error if both or neither source URIs are provided
         if bool(gcs_source) == bool(bigquery_source):
@@ -570,6 +574,7 @@ def create(
                 f"{predictions_format} is not an accepted prediction format "
                 f"type. Please choose from: {constants.BATCH_PREDICTION_OUTPUT_STORAGE_FORMATS}"
             )
+        
         gca_bp_job = gca_bp_job_compat
         gca_io = gca_io_compat
         gca_machine_resources = gca_machine_resources_compat
@@ -584,7 +589,6 @@ def create(
 
         # Required Fields
         gapic_batch_prediction_job.display_name = job_display_name
-        gapic_batch_prediction_job.model = model_name
 
         input_config = gca_bp_job.BatchPredictionJob.InputConfig()
         output_config = gca_bp_job.BatchPredictionJob.OutputConfig()
@@ -657,12 +661,15 @@ def create(
                 metadata=explanation_metadata, parameters=explanation_parameters
             )
 
-        # TODO (b/174502913): Support private feature once released
-
-        api_client = cls._instantiate_client(location=location, credentials=credentials)
+        empty_batch_prediction_job = cls._empty_constructor(
+            project=project,
+            location=location,
+            credentials=credentials,
+        )
 
         return cls._create(
-            api_client=api_client,
+            empty_batch_prediction_job=empty_batch_prediction_job,
+            model_or_model_name=model_name,
             parent=initializer.global_config.common_location_path(
                 project=project, location=location
             ),
@@ -673,12 +680,13 @@ def create(
             credentials=credentials or initializer.global_config.credentials,
             sync=sync,
         )
-
+    
     @classmethod
-    @base.optional_sync()
+    @base.optional_sync(return_input_arg='empty_batch_prediction_job')
     def _create(
         cls,
-        api_client: job_service_client.JobServiceClient,
+        empty_batch_prediction_job: 'BatchPredictionJob',
+        model_or_model_name: Union[str, 'aiplatform.Model'],
         parent: str,
         batch_prediction_job: Union[
             gca_bp_job_v1beta1.BatchPredictionJob, gca_bp_job_v1.BatchPredictionJob
@@ -725,6 +733,13 @@ def _create(
             by Vertex AI.
         """
         # select v1beta1 if explain else use default v1
+
+
+        model = model_or_model_name if isinstance(model_or_model_name, str) else model_or_model_name.resource_name
+        batch_prediction_job.model = model
+
+        api_client = empty_batch_prediction_job.api_client
+
         if generate_explanation:
             api_client = api_client.select_version(compat.V1BETA1)
 
@@ -734,12 +749,9 @@ def _create(
             parent=parent, batch_prediction_job=batch_prediction_job
         )
 
-        batch_prediction_job = cls(
-            batch_prediction_job_name=gca_batch_prediction_job.name,
-            project=project,
-            location=location,
-            credentials=credentials,
-        )
+        empty_batch_prediction_job._gca_resource = gca_batch_prediction_job
+
+        batch_prediction_job = empty_batch_prediction_job
 
         _LOGGER.log_create_complete(cls, batch_prediction_job._gca_resource, "bpj")
 
@@ -843,6 +855,10 @@ def iter_outputs(
             f"on your prediction output:\n{output_info}"
         )
 
+    def wait_for_resource_creation(self) -> None:
+        """Waits until resource has been created."""
+        self._wait_for_resource_creation()
+
 
 class _RunnableJob(_Job):
     """ABC to interface job as a runnable training class."""
diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py
index 4af337b3e8..ca1aefd9f3 100644
--- a/google/cloud/aiplatform/models.py
+++ b/google/cloud/aiplatform/models.py
@@ -981,7 +981,6 @@ def undeploy(
         if deployed_model_id in traffic_split and traffic_split[deployed_model_id]:
             raise ValueError("Model being undeployed should have 0 traffic.")
         if sum(traffic_split.values()) != 100:
-            # TODO(b/172678233) verify every referenced deployed model exists
             raise ValueError(
                 "Sum of all traffic within traffic split needs to be 100."
             )
@@ -2167,11 +2166,10 @@ def batch_predict(
             (jobs.BatchPredictionJob):
                 Instantiated representation of the created batch prediction job.
""" - self.wait() return jobs.BatchPredictionJob.create( job_display_name=job_display_name, - model_name=self.resource_name, + model_name=self, instances_format=instances_format, predictions_format=predictions_format, gcs_source=gcs_source, From 3ae693eb09d73c97f9a71bdd11ddccba627c111a Mon Sep 17 00:00:00 2001 From: Sasha Sobran Date: Thu, 26 Aug 2021 09:25:24 -0400 Subject: [PATCH 2/7] test: update tests --- tests/system/aiplatform/test_e2e_tabular.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/system/aiplatform/test_e2e_tabular.py b/tests/system/aiplatform/test_e2e_tabular.py index 9a330f34cf..17894f061e 100644 --- a/tests/system/aiplatform/test_e2e_tabular.py +++ b/tests/system/aiplatform/test_e2e_tabular.py @@ -64,9 +64,11 @@ def test_end_to_end_tabular(self, shared_state): # Create and import to single managed dataset for both training jobs + dataset_gcs_source = f'gs://{shared_state["staging_bucket_name"]}/{_BLOB_PATH}' + ds = aiplatform.TabularDataset.create( display_name=f"{self._temp_prefix}-dataset-{uuid.uuid4()}", - gcs_source=[f'gs://{shared_state["staging_bucket_name"]}/{_BLOB_PATH}'], + gcs_source=[dataset_gcs_source], sync=False, ) From 43181de0a2fb5c5877f59779da0ae802710b0561 Mon Sep 17 00:00:00 2001 From: Sasha Sobran Date: Thu, 26 Aug 2021 12:39:04 -0400 Subject: [PATCH 3/7] test: update batch prediction tests --- tests/unit/aiplatform/test_jobs.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/tests/unit/aiplatform/test_jobs.py b/tests/unit/aiplatform/test_jobs.py index d10eb0335d..12bc79a8d7 100644 --- a/tests/unit/aiplatform/test_jobs.py +++ b/tests/unit/aiplatform/test_jobs.py @@ -212,6 +212,11 @@ def get_batch_prediction_job_mock(): job_service_client.JobServiceClient, "get_batch_prediction_job" ) as get_batch_prediction_job_mock: get_batch_prediction_job_mock.side_effect = [ + gca_batch_prediction_job.BatchPredictionJob( + name=_TEST_BATCH_PREDICTION_JOB_NAME, + display_name=_TEST_DISPLAY_NAME, + state=_TEST_JOB_STATE_PENDING, + ), gca_batch_prediction_job.BatchPredictionJob( name=_TEST_BATCH_PREDICTION_JOB_NAME, display_name=_TEST_DISPLAY_NAME, @@ -475,8 +480,9 @@ def test_batch_predict_gcs_source_and_dest( sync=sync, ) - if not sync: - batch_prediction_job.wait() + batch_prediction_job.wait_for_resource_creation() + + batch_prediction_job.wait() # Construct expected request expected_gapic_batch_prediction_job = gca_batch_prediction_job.BatchPredictionJob( @@ -514,8 +520,9 @@ def test_batch_predict_gcs_source_bq_dest( sync=sync, ) - if not sync: - batch_prediction_job.wait() + batch_prediction_job.wait_for_resource_creation() + + batch_prediction_job.wait() assert ( batch_prediction_job.output_info @@ -571,8 +578,9 @@ def test_batch_predict_with_all_args( sync=sync, ) - if not sync: - batch_prediction_job.wait() + batch_prediction_job.wait_for_resource_creation() + + batch_prediction_job.wait() # Construct expected request expected_gapic_batch_prediction_job = gca_batch_prediction_job_v1beta1.BatchPredictionJob( From ad6018628f7289519f06e3cf3d7e6b747ccc8a0b Mon Sep 17 00:00:00 2001 From: Sasha Sobran Date: Thu, 26 Aug 2021 16:20:44 -0400 Subject: [PATCH 4/7] test: Add batch prediction to integration tests --- google/cloud/aiplatform/base.py | 11 +++-- google/cloud/aiplatform/jobs.py | 45 ++++++++++---------- tests/system/aiplatform/e2e_base.py | 11 +++++ tests/system/aiplatform/test_e2e_tabular.py | 46 ++++++++++++++++++--- tests/unit/aiplatform/test_jobs.py | 2 +- 5 files changed, 81 
insertions(+), 34 deletions(-) diff --git a/google/cloud/aiplatform/base.py b/google/cloud/aiplatform/base.py index 8bee523487..d7d0e6317b 100644 --- a/google/cloud/aiplatform/base.py +++ b/google/cloud/aiplatform/base.py @@ -680,15 +680,18 @@ def wrapper(*args, **kwargs): inspect.getfullargspec(method).annotations["return"] ) - # object produced by the method - returned_object = bound_args.arguments.get(return_input_arg) + returned_object = bound_args.arguments.get(return_input_arg) # is a classmethod that creates the object and returns it if args and inspect.isclass(args[0]): - + # assumes class in classmethod is the resource noun - returned_object = args[0]._empty_constructor() if not returned_object else returned_object + returned_object = ( + args[0]._empty_constructor() + if not returned_object + else returned_object + ) self = returned_object else: # instance method diff --git a/google/cloud/aiplatform/jobs.py b/google/cloud/aiplatform/jobs.py index fe317a9166..46c10b40e9 100644 --- a/google/cloud/aiplatform/jobs.py +++ b/google/cloud/aiplatform/jobs.py @@ -32,15 +32,6 @@ from google.cloud import aiplatform from google.cloud.aiplatform import base from google.cloud.aiplatform import compat -from google.cloud.aiplatform import constants -from google.cloud.aiplatform import initializer -from google.cloud.aiplatform import hyperparameter_tuning -from google.cloud.aiplatform import utils -from google.cloud.aiplatform.utils import console_utils -from google.cloud.aiplatform.utils import source_utils -from google.cloud.aiplatform.utils import worker_spec_utils - -from google.cloud.aiplatform.compat.services import job_service_client from google.cloud.aiplatform.compat.types import ( batch_prediction_job as gca_bp_job_compat, batch_prediction_job_v1 as gca_bp_job_v1, @@ -58,6 +49,13 @@ machine_resources_v1beta1 as gca_machine_resources_v1beta1, study as gca_study_compat, ) +from google.cloud.aiplatform import constants +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import hyperparameter_tuning +from google.cloud.aiplatform import utils +from google.cloud.aiplatform.utils import console_utils +from google.cloud.aiplatform.utils import source_utils +from google.cloud.aiplatform.utils import worker_spec_utils _LOGGER = base.Logger(__name__) @@ -352,7 +350,7 @@ def completion_stats(self) -> Optional[gca_completion_stats.CompletionStats]: def create( cls, job_display_name: str, - model_name: Union[str, 'aiplatform.Model'], + model_name: Union[str, "aiplatform.Model"], instances_format: str = "jsonl", predictions_format: str = "jsonl", gcs_source: Optional[Union[str, Sequence[str]]] = None, @@ -533,9 +531,9 @@ def create( (jobs.BatchPredictionJob): Instantiated representation of the created batch prediction job. """ - + utils.validate_display_name(job_display_name) - + if labels: utils.validate_labels(labels) @@ -574,7 +572,7 @@ def create( f"{predictions_format} is not an accepted prediction format " f"type. 
Please choose from: {constants.BATCH_PREDICTION_OUTPUT_STORAGE_FORMATS}" ) - + gca_bp_job = gca_bp_job_compat gca_io = gca_io_compat gca_machine_resources = gca_machine_resources_compat @@ -662,10 +660,8 @@ def create( ) empty_batch_prediction_job = cls._empty_constructor( - project=project, - location=location, - credentials=credentials, - ) + project=project, location=location, credentials=credentials, + ) return cls._create( empty_batch_prediction_job=empty_batch_prediction_job, @@ -680,13 +676,13 @@ def create( credentials=credentials or initializer.global_config.credentials, sync=sync, ) - + @classmethod - @base.optional_sync(return_input_arg='empty_batch_prediction_job') + @base.optional_sync(return_input_arg="empty_batch_prediction_job") def _create( cls, - empty_batch_prediction_job: 'BatchPredictionJob', - model_or_model_name: Union[str, 'aiplatform.Model'], + empty_batch_prediction_job: "BatchPredictionJob", + model_or_model_name: Union[str, "aiplatform.Model"], parent: str, batch_prediction_job: Union[ gca_bp_job_v1beta1.BatchPredictionJob, gca_bp_job_v1.BatchPredictionJob @@ -734,8 +730,11 @@ def _create( """ # select v1beta1 if explain else use default v1 - - model = model_or_model_name if isinstance(model_or_model_name, str) else model_or_model_name.resource_name + model = ( + model_or_model_name + if isinstance(model_or_model_name, str) + else model_or_model_name.resource_name + ) batch_prediction_job.model = model api_client = empty_batch_prediction_job.api_client diff --git a/tests/system/aiplatform/e2e_base.py b/tests/system/aiplatform/e2e_base.py index c0843133dd..de91c1249a 100644 --- a/tests/system/aiplatform/e2e_base.py +++ b/tests/system/aiplatform/e2e_base.py @@ -43,6 +43,17 @@ def _temp_prefix(cls) -> str: """ pass + @classmethod + def _make_display_name(cls, key: str) -> str: + """Helper method to make unique display_names. + + Args: + key (str): Required. Identifier for the display name. + Returns: + Unique display name. 
+ """ + return f"{cls._temp_prefix}-{key}-{uuid.uuid4()}" + def setup_method(self): importlib.reload(initializer) importlib.reload(aiplatform) diff --git a/tests/system/aiplatform/test_e2e_tabular.py b/tests/system/aiplatform/test_e2e_tabular.py index 17894f061e..a55ea237e4 100644 --- a/tests/system/aiplatform/test_e2e_tabular.py +++ b/tests/system/aiplatform/test_e2e_tabular.py @@ -16,12 +16,15 @@ # import os -import uuid from urllib import request import pytest from google.cloud import aiplatform +from google.cloud.aiplatform.compat.types import ( + job_state as gca_job_state, + pipeline_state as gca_pipeline_state, +) from tests.system.aiplatform import e2e_base @@ -67,7 +70,7 @@ def test_end_to_end_tabular(self, shared_state): dataset_gcs_source = f'gs://{shared_state["staging_bucket_name"]}/{_BLOB_PATH}' ds = aiplatform.TabularDataset.create( - display_name=f"{self._temp_prefix}-dataset-{uuid.uuid4()}", + display_name=self._make_display_name("dataset"), gcs_source=[dataset_gcs_source], sync=False, ) @@ -77,7 +80,7 @@ def test_end_to_end_tabular(self, shared_state): # Define both training jobs custom_job = aiplatform.CustomTrainingJob( - display_name=f"{self._temp_prefix}-train-housing-custom-{uuid.uuid4()}", + display_name=self._make_display_name("train-housing-custom"), script_path=_LOCAL_TRAINING_SCRIPT_PATH, container_uri="gcr.io/cloud-aiplatform/training/tf-cpu.2-2:latest", requirements=["gcsfs==0.7.1"], @@ -85,7 +88,7 @@ def test_end_to_end_tabular(self, shared_state): ) automl_job = aiplatform.AutoMLTabularTrainingJob( - display_name=f"{self._temp_prefix}-train-housing-automl-{uuid.uuid4()}", + display_name=self._make_display_name("train-housing-automl"), optimization_prediction_type="regression", optimization_objective="minimize-rmse", ) @@ -95,14 +98,14 @@ def test_end_to_end_tabular(self, shared_state): custom_model = custom_job.run( ds, replica_count=1, - model_display_name=f"{self._temp_prefix}-custom-housing-model-{uuid.uuid4()}", + model_display_name=self._make_display_name("custom-housing-model"), sync=False, ) automl_model = automl_job.run( dataset=ds, target_column="median_house_value", - model_display_name=f"{self._temp_prefix}-automl-housing-model-{uuid.uuid4()}", + model_display_name=self._make_display_name("automl-housing-model"), sync=False, ) @@ -115,6 +118,21 @@ def test_end_to_end_tabular(self, shared_state): automl_endpoint = automl_model.deploy(machine_type="n1-standard-4", sync=False) shared_state["resources"].extend([automl_endpoint, custom_endpoint]) + custom_batch_prediction_job = custom_model.batch_predict( + job_display_name=self._make_display_name("automl-housing-model"), + instances_format="csv", + machine_type="n1-standard-4", + gcs_source=dataset_gcs_source, + gcs_destination_prefix=f'gs://{shared_state["staging_bucket_name"]}/bp_results/', + sync=False, + ) + + shared_state["resources"].append(custom_batch_prediction_job) + + custom_job.wait_for_resource_creation() + automl_job.wait_for_resource_creation() + custom_batch_prediction_job.wait_for_resource_creation() + # Send online prediction with same instance to both deployed models # This sample is taken from an observation where median_house_value = 94600 custom_endpoint.wait() @@ -132,6 +150,9 @@ def test_end_to_end_tabular(self, shared_state): }, ] ) + + custom_batch_prediction_job.wait() + automl_endpoint.wait() automl_prediction = automl_endpoint.predict( [ @@ -148,6 +169,19 @@ def test_end_to_end_tabular(self, shared_state): ] ) + assert ( + custom_job.state + == 
gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + ) + assert ( + automl_job.state + == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + ) + assert ( + custom_batch_prediction_job.state + == gca_job_state.JobState.JOB_STATE_SUCCEEDED + ) + # Ensure a single prediction was returned assert len(custom_prediction.predictions) == 1 assert len(automl_prediction.predictions) == 1 diff --git a/tests/unit/aiplatform/test_jobs.py b/tests/unit/aiplatform/test_jobs.py index 12bc79a8d7..f14eea99bc 100644 --- a/tests/unit/aiplatform/test_jobs.py +++ b/tests/unit/aiplatform/test_jobs.py @@ -395,7 +395,7 @@ def test_batch_prediction_job_status(self, get_batch_prediction_job_mock): bp_job_state = bp.state assert get_batch_prediction_job_mock.call_count == 2 - assert bp_job_state == _TEST_JOB_STATE_SUCCESS + assert bp_job_state == _TEST_JOB_STATE_RUNNING get_batch_prediction_job_mock.assert_called_with( name=_TEST_BATCH_PREDICTION_JOB_NAME From b7ed326060112c2f6b93c3e41efd3e44a57d991a Mon Sep 17 00:00:00 2001 From: Sasha Sobran Date: Fri, 27 Aug 2021 10:11:50 -0400 Subject: [PATCH 5/7] chore: lint --- google/cloud/aiplatform/jobs.py | 54 ++++++++++++--------------------- 1 file changed, 19 insertions(+), 35 deletions(-) diff --git a/google/cloud/aiplatform/jobs.py b/google/cloud/aiplatform/jobs.py index 46c10b40e9..ed59996310 100644 --- a/google/cloud/aiplatform/jobs.py +++ b/google/cloud/aiplatform/jobs.py @@ -382,12 +382,12 @@ def create( Required. The user-defined name of the BatchPredictionJob. The name can be up to 128 characters long and can be consist of any UTF-8 characters. - model_name (str): + model_name (Union[str, aiplatform.Model]): Required. A fully-qualified model resource name or model ID. Example: "projects/123/locations/us-central1/models/456" or "456" when project and location are initialized or passed. - Or an instance ot aiplatform.Model. + Or an instance of aiplatform.Model. instances_format (str): Required. The format in which instances are given, must be one of "jsonl", "csv", "bigquery", "tf-record", "tf-record-gzip", @@ -666,14 +666,8 @@ def create( return cls._create( empty_batch_prediction_job=empty_batch_prediction_job, model_or_model_name=model_name, - parent=initializer.global_config.common_location_path( - project=project, location=location - ), - batch_prediction_job=gapic_batch_prediction_job, + gca_batch_prediction_job=gapic_batch_prediction_job, generate_explanation=generate_explanation, - project=project or initializer.global_config.project, - location=location or initializer.global_config.location, - credentials=credentials or initializer.global_config.credentials, sync=sync, ) @@ -683,41 +677,25 @@ def _create( cls, empty_batch_prediction_job: "BatchPredictionJob", model_or_model_name: Union[str, "aiplatform.Model"], - parent: str, - batch_prediction_job: Union[ + gca_batch_prediction_job: Union[ gca_bp_job_v1beta1.BatchPredictionJob, gca_bp_job_v1.BatchPredictionJob ], generate_explanation: bool, - project: str, - location: str, - credentials: Optional[auth_credentials.Credentials], sync: bool = True, ) -> "BatchPredictionJob": """Create a batch prediction job. Args: - api_client (dataset_service_client.DatasetServiceClient): - Required. An instance of DatasetServiceClient with the correct api_endpoint - already set based on user's preferences. - batch_prediction_job (gca_bp_job.BatchPredictionJob): + empty_batch_prediction_job (BatchPredictionJob): + Required. BatchPredictionJob without _gca_resource populated. 
+            model_or_model_name (Union[str, aiplatform.Model]):
+                Required. A fully-qualified model resource name or
+                an instance of aiplatform.Model.
+            gca_batch_prediction_job (gca_bp_job.BatchPredictionJob):
                 Required. a batch prediction job proto for creating a batch prediction job on Vertex AI.
             generate_explanation (bool):
                 Required. Generate explanation along with the batch prediction results.
-            parent (str):
-                Required. Also known as common location path, that usually contains the
-                project and location that the user provided to the upstream method.
-                Example: "projects/my-prj/locations/us-central1"
-            project (str):
-                Required. Project to upload this model to. Overrides project set in
-                aiplatform.init.
-            location (str):
-                Required. Location to upload this model to. Overrides location set in
-                aiplatform.init.
-            credentials (Optional[auth_credentials.Credentials]):
-                Custom credentials to use to upload this model. Overrides
-                credentials set in aiplatform.init.
-
         Returns:
             (jobs.BatchPredictionJob):
                 Instantiated representation of the created batch prediction job.
         """
         # select v1beta1 if explain else use default v1
 
+        parent = initializer.global_config.common_location_path(
+            project=empty_batch_prediction_job.project,
+            location=empty_batch_prediction_job.location,
+        )
+
-        model = (
+        model_resource_name = (
             model_or_model_name
             if isinstance(model_or_model_name, str)
             else model_or_model_name.resource_name
         )
-        batch_prediction_job.model = model
+
+        gca_batch_prediction_job.model = model_resource_name
 
         api_client = empty_batch_prediction_job.api_client
@@ -745,7 +729,7 @@ def _create(
         _LOGGER.log_create_with_lro(cls)
 
         gca_batch_prediction_job = api_client.create_batch_prediction_job(
-            parent=parent, batch_prediction_job=batch_prediction_job
+            parent=parent, batch_prediction_job=gca_batch_prediction_job
         )
 
         empty_batch_prediction_job._gca_resource = gca_batch_prediction_job

From d905ddee50a7276dfb3602c2cbafdae6f8e0f68d Mon Sep 17 00:00:00 2001
From: Sasha Sobran
Date: Fri, 27 Aug 2021 10:28:22 -0400
Subject: [PATCH 6/7] docs: update README

---
 README.rst | 33 +++++++++++++++++++++++++++++++++
 1 file changed, 33 insertions(+)

diff --git a/README.rst b/README.rst
index e8fc200700..1f4462722c 100644
--- a/README.rst
+++ b/README.rst
@@ -274,6 +274,39 @@ Please visit `Importing models to Vertex AI`_ for a detailed overview:
 
 .. _Importing models to Vertex AI: https://cloud.google.com/vertex-ai/docs/general/import-model
 
+Batch Prediction
+----------------
+
+To create a batch prediction job:
+
+.. code-block:: Python
+
+    model = aiplatform.Model('projects/my-project/locations/us-central1/models/{MODEL_ID}')
+
+    batch_prediction_job = model.batch_predict(
+        job_display_name='my-batch-prediction-job',
+        instances_format='csv',
+        machine_type='n1-standard-4',
+        gcs_source=['gs://path/to/my/file.csv'],
+        gcs_destination_prefix='gs://path/to/my/batch_prediction/results/'
+    )
+
+You can also create a batch prediction job asynchronously by including the ``sync=False`` argument:
+
+.. code-block:: Python
+
+    batch_prediction_job = model.batch_predict(..., sync=False)
+
+    # wait for resource to be created
+    batch_prediction_job.wait_for_resource_creation()
+
+    # get the state
+    batch_prediction_job.state
+
+    # block until job is complete
+    batch_prediction_job.wait()
+
 
 Endpoints
 ---------

From 7ffb1f1a20dc2de6c77564ec1959fb3ff45fa364 Mon Sep 17 00:00:00 2001
From: Sasha Sobran
Date: Fri, 27 Aug 2021 11:59:38 -0400
Subject: [PATCH 7/7] test: use unique filenames to avoid conflicts during
 parallel tests

---
 tests/unit/aiplatform/test_training_jobs.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/tests/unit/aiplatform/test_training_jobs.py b/tests/unit/aiplatform/test_training_jobs.py
index 0fd781b380..1a919f1635 100644
--- a/tests/unit/aiplatform/test_training_jobs.py
+++ b/tests/unit/aiplatform/test_training_jobs.py
@@ -25,6 +25,7 @@
 import sys
 import tarfile
 import tempfile
+import uuid
 from unittest import mock
 from unittest.mock import patch
 
@@ -614,11 +615,12 @@ class TestCustomTrainingJob:
     def setup_method(self):
         importlib.reload(initializer)
         importlib.reload(aiplatform)
-        with open(_TEST_LOCAL_SCRIPT_FILE_NAME, "w") as fp:
+        self._local_script_file_name = f"{uuid.uuid4()}-{_TEST_LOCAL_SCRIPT_FILE_NAME}"
+        with open(self._local_script_file_name, "w") as fp:
             fp.write(_TEST_PYTHON_SOURCE)
 
     def teardown_method(self):
-        pathlib.Path(_TEST_LOCAL_SCRIPT_FILE_NAME).unlink()
+        pathlib.Path(self._local_script_file_name).unlink()
         initializer.global_pool.shutdown(wait=True)
 
     @pytest.mark.parametrize("sync", [True, False])