From e87a255705a5d04ade79f12c706dc842c0228118 Mon Sep 17 00:00:00 2001 From: Mark Date: Tue, 28 Sep 2021 14:58:05 -0700 Subject: [PATCH] fix: use the project id from BQ dataset instead of the default project id (#717) It is possible that the user specify a different project than the current project as the BQ output destination for batch prediction job. --- google/cloud/aiplatform/jobs.py | 5 +---- tests/unit/aiplatform/test_jobs.py | 5 +++-- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/google/cloud/aiplatform/jobs.py b/google/cloud/aiplatform/jobs.py index f696c4e975..aaadb6e4d7 100644 --- a/google/cloud/aiplatform/jobs.py +++ b/google/cloud/aiplatform/jobs.py @@ -795,9 +795,6 @@ def iter_outputs( if bq_dataset.startswith("bq://"): bq_dataset = bq_dataset[5:] - # # Split project ID and BQ dataset ID - _, bq_dataset_id = bq_dataset.split(".", 1) - # Build a BigQuery Client using the same credentials as JobServiceClient bq_client = bigquery.Client( project=self.project, @@ -805,7 +802,7 @@ def iter_outputs( ) row_iterator = bq_client.list_rows( - table=f"{bq_dataset_id}.{bq_table}", max_results=bq_max_results + table=f"{bq_dataset}.{bq_table}", max_results=bq_max_results ) return row_iterator diff --git a/tests/unit/aiplatform/test_jobs.py b/tests/unit/aiplatform/test_jobs.py index df55fc3864..462295c6be 100644 --- a/tests/unit/aiplatform/test_jobs.py +++ b/tests/unit/aiplatform/test_jobs.py @@ -48,13 +48,14 @@ _TEST_ID = "1028944691210842416" _TEST_ALT_ID = "8834795523125638878" _TEST_DISPLAY_NAME = "my_job_1234" +_TEST_BQ_PROJECT_ID = "projectId" _TEST_BQ_DATASET_ID = "bqDatasetId" _TEST_BQ_TABLE_NAME = "someBqTable" _TEST_BQ_JOB_ID = "123459876" _TEST_BQ_MAX_RESULTS = 100 _TEST_GCS_BUCKET_NAME = "my-bucket" -_TEST_BQ_PATH = f"bq://projectId.{_TEST_BQ_DATASET_ID}" +_TEST_BQ_PATH = f"bq://{_TEST_BQ_PROJECT_ID}.{_TEST_BQ_DATASET_ID}" _TEST_GCS_BUCKET_PATH = f"gs://{_TEST_GCS_BUCKET_NAME}" _TEST_GCS_JSONL_SOURCE_URI = f"{_TEST_GCS_BUCKET_PATH}/bp_input_config.jsonl" _TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" @@ -420,7 +421,7 @@ def test_batch_prediction_iter_dirs_bq(self, bq_list_rows_mock): bp.iter_outputs() bq_list_rows_mock.assert_called_once_with( - table=f"{_TEST_BQ_DATASET_ID}.{_TEST_BQ_TABLE_NAME}", + table=f"{_TEST_BQ_PROJECT_ID}.{_TEST_BQ_DATASET_ID}.{_TEST_BQ_TABLE_NAME}", max_results=_TEST_BQ_MAX_RESULTS, )