Skip to content

Commit

Permalink
fix: Update BatchPredictionJob.iter_outputs() and BQ docstrings (#631)
Browse files Browse the repository at this point in the history
* fix: Have iter_outputs use BQ output table field

* fix: Update arg docstring to reflect bq:// prefix
  • Loading branch information
vinnysenthil committed Aug 17, 2021
1 parent 9dcf6fb commit 28f32fd
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 11 deletions.
25 changes: 16 additions & 9 deletions google/cloud/aiplatform/jobs.py
Expand Up @@ -406,7 +406,7 @@ def create(
https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames.
bigquery_source (Optional[str]):
BigQuery URI to a table, up to 2000 characters long. For example:
`projectId.bqDatasetId.bqTableId`
`bq://projectId.bqDatasetId.bqTableId`
gcs_destination_prefix (Optional[str]):
The Google Cloud Storage location of the directory where the
output is to be written to. In the given directory a new
Expand Down Expand Up @@ -808,23 +808,30 @@ def iter_outputs(
# BigQuery Destination, return RowIterator
elif output_info.bigquery_output_dataset:

# Build a BigQuery Client using the same credentials as JobServiceClient
bq_client = bigquery.Client(
project=self.project,
credentials=self.api_client._transport._credentials,
)

# Format from service is `bq://projectId.bqDatasetId`
# Format of `bigquery_output_dataset` from service is `bq://projectId.bqDatasetId`
bq_dataset = output_info.bigquery_output_dataset
bq_table = output_info.bigquery_output_table

if not bq_table:
raise RuntimeError(
"A BigQuery table with predictions was not found, this "
f"might be due to errors. Visit {self._dashboard_uri()} for details."
)

if bq_dataset.startswith("bq://"):
bq_dataset = bq_dataset[5:]

# # Split project ID and BQ dataset ID
_, bq_dataset_id = bq_dataset.split(".", 1)

# Build a BigQuery Client using the same credentials as JobServiceClient
bq_client = bigquery.Client(
project=self.project,
credentials=self.api_client._transport._credentials,
)

row_iterator = bq_client.list_rows(
table=f"{bq_dataset_id}.predictions", max_results=bq_max_results
table=f"{bq_dataset_id}.{bq_table}", max_results=bq_max_results
)

return row_iterator
Expand Down
2 changes: 1 addition & 1 deletion google/cloud/aiplatform/models.py
Expand Up @@ -2038,7 +2038,7 @@ def batch_predict(
https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames.
bigquery_source: Optional[str] = None
BigQuery URI to a table, up to 2000 characters long. For example:
`projectId.bqDatasetId.bqTableId`
`bq://projectId.bqDatasetId.bqTableId`
instances_format: str = "jsonl"
Required. The format in which instances are given, must be one
of "jsonl", "csv", "bigquery", "tf-record", "tf-record-gzip",
Expand Down
38 changes: 37 additions & 1 deletion tests/unit/aiplatform/test_jobs.py
Expand Up @@ -56,6 +56,7 @@
_TEST_ALT_ID = "8834795523125638878"
_TEST_DISPLAY_NAME = "my_job_1234"
_TEST_BQ_DATASET_ID = "bqDatasetId"
_TEST_BQ_TABLE_NAME = "someBqTable"
_TEST_BQ_JOB_ID = "123459876"
_TEST_BQ_MAX_RESULTS = 100
_TEST_GCS_BUCKET_NAME = "my-bucket"
Expand Down Expand Up @@ -108,6 +109,9 @@
gcs_output_directory=_TEST_GCS_BUCKET_NAME
)
_TEST_BQ_OUTPUT_INFO = gca_batch_prediction_job.BatchPredictionJob.OutputInfo(
bigquery_output_dataset=_TEST_BQ_PATH, bigquery_output_table=_TEST_BQ_TABLE_NAME
)
_TEST_BQ_OUTPUT_INFO_INCOMPLETE = gca_batch_prediction_job.BatchPredictionJob.OutputInfo(
bigquery_output_dataset=_TEST_BQ_PATH
)

Expand Down Expand Up @@ -296,6 +300,23 @@ def get_batch_prediction_job_bq_output_mock():
yield get_batch_prediction_job_mock


@pytest.fixture
def get_batch_prediction_job_incomplete_bq_output_mock():
    """Patch ``JobServiceClient.get_batch_prediction_job`` for the no-table case.

    The mocked call returns a succeeded BatchPredictionJob whose OutputInfo
    carries a BigQuery output dataset but no ``bigquery_output_table``
    (``_TEST_BQ_OUTPUT_INFO_INCOMPLETE``), which is the condition that makes
    ``BatchPredictionJob.iter_outputs()`` raise a RuntimeError.
    """
    patcher = patch.object(
        job_service_client.JobServiceClient, "get_batch_prediction_job"
    )
    with patcher as mock_get_job:
        incomplete_bq_job = gca_batch_prediction_job.BatchPredictionJob(
            name=_TEST_BATCH_PREDICTION_JOB_NAME,
            display_name=_TEST_DISPLAY_NAME,
            model=_TEST_MODEL_NAME,
            input_config=_TEST_GCS_INPUT_CONFIG,
            output_config=_TEST_BQ_OUTPUT_CONFIG,
            output_info=_TEST_BQ_OUTPUT_INFO_INCOMPLETE,
            state=_TEST_JOB_STATE_SUCCESS,
        )
        mock_get_job.return_value = incomplete_bq_job
        yield mock_get_job


@pytest.fixture
def get_batch_prediction_job_empty_output_mock():
with patch.object(
Expand Down Expand Up @@ -397,7 +418,22 @@ def test_batch_prediction_iter_dirs_bq(self, bq_list_rows_mock):
bp.iter_outputs()

bq_list_rows_mock.assert_called_once_with(
table=f"{_TEST_BQ_DATASET_ID}.predictions", max_results=_TEST_BQ_MAX_RESULTS
table=f"{_TEST_BQ_DATASET_ID}.{_TEST_BQ_TABLE_NAME}",
max_results=_TEST_BQ_MAX_RESULTS,
)

@pytest.mark.usefixtures("get_batch_prediction_job_incomplete_bq_output_mock")
def test_batch_prediction_iter_dirs_bq_raises_on_empty(self, bq_list_rows_mock):
    # The fixture serves a succeeded job whose OutputInfo has a BigQuery
    # output dataset but no `bigquery_output_table`; iter_outputs() must
    # raise instead of querying a table that does not exist.
    bp = jobs.BatchPredictionJob(
        batch_prediction_job_name=_TEST_BATCH_PREDICTION_JOB_NAME
    )
    with pytest.raises(RuntimeError) as e:
        bp.iter_outputs()
    # ExceptionInfo.match() re-raises on mismatch, so this pins the message
    # prefix up to (but not including) the job-specific dashboard URI.
    assert e.match(
        regexp=(
            "A BigQuery table with predictions was not found,"
            " this might be due to errors. Visit http"
        )
    )

@pytest.mark.usefixtures("get_batch_prediction_job_running_bq_output_mock")
Expand Down

0 comments on commit 28f32fd

Please sign in to comment.