diff --git a/google/cloud/aiplatform/jobs.py b/google/cloud/aiplatform/jobs.py index 20d8141a22..6a5eb8ffee 100644 --- a/google/cloud/aiplatform/jobs.py +++ b/google/cloud/aiplatform/jobs.py @@ -406,7 +406,7 @@ def create( https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames. bigquery_source (Optional[str]): BigQuery URI to a table, up to 2000 characters long. For example: - `projectId.bqDatasetId.bqTableId` + `bq://projectId.bqDatasetId.bqTableId` gcs_destination_prefix (Optional[str]): The Google Cloud Storage location of the directory where the output is to be written to. In the given directory a new @@ -808,14 +808,15 @@ def iter_outputs( # BigQuery Destination, return RowIterator elif output_info.bigquery_output_dataset: - # Build a BigQuery Client using the same credentials as JobServiceClient - bq_client = bigquery.Client( - project=self.project, - credentials=self.api_client._transport._credentials, - ) - - # Format from service is `bq://projectId.bqDatasetId` + # Format of `bigquery_output_dataset` from service is `bq://projectId.bqDatasetId` bq_dataset = output_info.bigquery_output_dataset + bq_table = output_info.bigquery_output_table + + if not bq_table: + raise RuntimeError( + "A BigQuery table with predictions was not found, this " + f"might be due to errors. Visit {self._dashboard_uri()} for details." + ) if bq_dataset.startswith("bq://"): bq_dataset = bq_dataset[5:] @@ -823,8 +824,14 @@ def iter_outputs( # # Split project ID and BQ dataset ID _, bq_dataset_id = bq_dataset.split(".", 1) + # Build a BigQuery Client using the same credentials as JobServiceClient + bq_client = bigquery.Client( + project=self.project, + credentials=self.api_client._transport._credentials, + ) + row_iterator = bq_client.list_rows( - table=f"{bq_dataset_id}.predictions", max_results=bq_max_results + table=f"{bq_dataset_id}.{bq_table}", max_results=bq_max_results ) return row_iterator diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py index c1518ce89d..4af337b3e8 100644 --- a/google/cloud/aiplatform/models.py +++ b/google/cloud/aiplatform/models.py @@ -2038,7 +2038,7 @@ def batch_predict( https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames. bigquery_source: Optional[str] = None BigQuery URI to a table, up to 2000 characters long. For example: - `projectId.bqDatasetId.bqTableId` + `bq://projectId.bqDatasetId.bqTableId` instances_format: str = "jsonl" Required. The format in which instances are given, must be one of "jsonl", "csv", "bigquery", "tf-record", "tf-record-gzip", diff --git a/tests/unit/aiplatform/test_jobs.py b/tests/unit/aiplatform/test_jobs.py index 76584cd0c4..d10eb0335d 100644 --- a/tests/unit/aiplatform/test_jobs.py +++ b/tests/unit/aiplatform/test_jobs.py @@ -56,6 +56,7 @@ _TEST_ALT_ID = "8834795523125638878" _TEST_DISPLAY_NAME = "my_job_1234" _TEST_BQ_DATASET_ID = "bqDatasetId" +_TEST_BQ_TABLE_NAME = "someBqTable" _TEST_BQ_JOB_ID = "123459876" _TEST_BQ_MAX_RESULTS = 100 _TEST_GCS_BUCKET_NAME = "my-bucket" @@ -108,6 +109,9 @@ gcs_output_directory=_TEST_GCS_BUCKET_NAME ) _TEST_BQ_OUTPUT_INFO = gca_batch_prediction_job.BatchPredictionJob.OutputInfo( + bigquery_output_dataset=_TEST_BQ_PATH, bigquery_output_table=_TEST_BQ_TABLE_NAME +) +_TEST_BQ_OUTPUT_INFO_INCOMPLETE = gca_batch_prediction_job.BatchPredictionJob.OutputInfo( bigquery_output_dataset=_TEST_BQ_PATH ) @@ -296,6 +300,23 @@ def get_batch_prediction_job_bq_output_mock(): yield get_batch_prediction_job_mock +@pytest.fixture +def get_batch_prediction_job_incomplete_bq_output_mock(): + with patch.object( + job_service_client.JobServiceClient, "get_batch_prediction_job" + ) as get_batch_prediction_job_mock: + get_batch_prediction_job_mock.return_value = gca_batch_prediction_job.BatchPredictionJob( + name=_TEST_BATCH_PREDICTION_JOB_NAME, + display_name=_TEST_DISPLAY_NAME, + model=_TEST_MODEL_NAME, + input_config=_TEST_GCS_INPUT_CONFIG, + output_config=_TEST_BQ_OUTPUT_CONFIG, + output_info=_TEST_BQ_OUTPUT_INFO_INCOMPLETE, + state=_TEST_JOB_STATE_SUCCESS, + ) + yield get_batch_prediction_job_mock + + @pytest.fixture def get_batch_prediction_job_empty_output_mock(): with patch.object( @@ -397,7 +418,22 @@ def test_batch_prediction_iter_dirs_bq(self, bq_list_rows_mock): bp.iter_outputs() bq_list_rows_mock.assert_called_once_with( - table=f"{_TEST_BQ_DATASET_ID}.predictions", max_results=_TEST_BQ_MAX_RESULTS + table=f"{_TEST_BQ_DATASET_ID}.{_TEST_BQ_TABLE_NAME}", + max_results=_TEST_BQ_MAX_RESULTS, + ) + + @pytest.mark.usefixtures("get_batch_prediction_job_incomplete_bq_output_mock") + def test_batch_prediction_iter_dirs_bq_raises_on_empty(self, bq_list_rows_mock): + bp = jobs.BatchPredictionJob( + batch_prediction_job_name=_TEST_BATCH_PREDICTION_JOB_NAME + ) + with pytest.raises(RuntimeError) as e: + bp.iter_outputs() + assert e.match( + regexp=( + "A BigQuery table with predictions was not found," + " this might be due to errors. Visit http" + ) ) @pytest.mark.usefixtures("get_batch_prediction_job_running_bq_output_mock")