Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Update BatchPredictionJob.iter_outputs() and BQ docstrings #631

Merged
merged 2 commits into from Aug 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
25 changes: 16 additions & 9 deletions google/cloud/aiplatform/jobs.py
Expand Up @@ -406,7 +406,7 @@ def create(
https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames.
bigquery_source (Optional[str]):
BigQuery URI to a table, up to 2000 characters long. For example:
`projectId.bqDatasetId.bqTableId`
`bq://projectId.bqDatasetId.bqTableId`
gcs_destination_prefix (Optional[str]):
The Google Cloud Storage location of the directory where the
output is to be written to. In the given directory a new
Expand Down Expand Up @@ -808,23 +808,30 @@ def iter_outputs(
# BigQuery Destination, return RowIterator
elif output_info.bigquery_output_dataset:

# Build a BigQuery Client using the same credentials as JobServiceClient
bq_client = bigquery.Client(
project=self.project,
credentials=self.api_client._transport._credentials,
)

# Format from service is `bq://projectId.bqDatasetId`
# Format of `bigquery_output_dataset` from service is `bq://projectId.bqDatasetId`
bq_dataset = output_info.bigquery_output_dataset
bq_table = output_info.bigquery_output_table

if not bq_table:
raise RuntimeError(
"A BigQuery table with predictions was not found, this "
f"might be due to errors. Visit {self._dashboard_uri()} for details."
)

if bq_dataset.startswith("bq://"):
bq_dataset = bq_dataset[5:]

# # Split project ID and BQ dataset ID
_, bq_dataset_id = bq_dataset.split(".", 1)

# Build a BigQuery Client using the same credentials as JobServiceClient
bq_client = bigquery.Client(
project=self.project,
credentials=self.api_client._transport._credentials,
)

row_iterator = bq_client.list_rows(
table=f"{bq_dataset_id}.predictions", max_results=bq_max_results
table=f"{bq_dataset_id}.{bq_table}", max_results=bq_max_results
)

return row_iterator
Expand Down
2 changes: 1 addition & 1 deletion google/cloud/aiplatform/models.py
Expand Up @@ -2038,7 +2038,7 @@ def batch_predict(
https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames.
bigquery_source: Optional[str] = None
BigQuery URI to a table, up to 2000 characters long. For example:
`projectId.bqDatasetId.bqTableId`
`bq://projectId.bqDatasetId.bqTableId`
instances_format: str = "jsonl"
Required. The format in which instances are given, must be one
of "jsonl", "csv", "bigquery", "tf-record", "tf-record-gzip",
Expand Down
38 changes: 37 additions & 1 deletion tests/unit/aiplatform/test_jobs.py
Expand Up @@ -56,6 +56,7 @@
_TEST_ALT_ID = "8834795523125638878"
_TEST_DISPLAY_NAME = "my_job_1234"
_TEST_BQ_DATASET_ID = "bqDatasetId"
_TEST_BQ_TABLE_NAME = "someBqTable"
_TEST_BQ_JOB_ID = "123459876"
_TEST_BQ_MAX_RESULTS = 100
_TEST_GCS_BUCKET_NAME = "my-bucket"
Expand Down Expand Up @@ -108,6 +109,9 @@
gcs_output_directory=_TEST_GCS_BUCKET_NAME
)
_TEST_BQ_OUTPUT_INFO = gca_batch_prediction_job.BatchPredictionJob.OutputInfo(
bigquery_output_dataset=_TEST_BQ_PATH, bigquery_output_table=_TEST_BQ_TABLE_NAME
)
_TEST_BQ_OUTPUT_INFO_INCOMPLETE = gca_batch_prediction_job.BatchPredictionJob.OutputInfo(
bigquery_output_dataset=_TEST_BQ_PATH
)

Expand Down Expand Up @@ -296,6 +300,23 @@ def get_batch_prediction_job_bq_output_mock():
yield get_batch_prediction_job_mock


@pytest.fixture
def get_batch_prediction_job_incomplete_bq_output_mock():
with patch.object(
job_service_client.JobServiceClient, "get_batch_prediction_job"
) as get_batch_prediction_job_mock:
get_batch_prediction_job_mock.return_value = gca_batch_prediction_job.BatchPredictionJob(
name=_TEST_BATCH_PREDICTION_JOB_NAME,
display_name=_TEST_DISPLAY_NAME,
model=_TEST_MODEL_NAME,
input_config=_TEST_GCS_INPUT_CONFIG,
output_config=_TEST_BQ_OUTPUT_CONFIG,
output_info=_TEST_BQ_OUTPUT_INFO_INCOMPLETE,
state=_TEST_JOB_STATE_SUCCESS,
)
yield get_batch_prediction_job_mock


@pytest.fixture
def get_batch_prediction_job_empty_output_mock():
with patch.object(
Expand Down Expand Up @@ -397,7 +418,22 @@ def test_batch_prediction_iter_dirs_bq(self, bq_list_rows_mock):
bp.iter_outputs()

bq_list_rows_mock.assert_called_once_with(
table=f"{_TEST_BQ_DATASET_ID}.predictions", max_results=_TEST_BQ_MAX_RESULTS
table=f"{_TEST_BQ_DATASET_ID}.{_TEST_BQ_TABLE_NAME}",
max_results=_TEST_BQ_MAX_RESULTS,
)

@pytest.mark.usefixtures("get_batch_prediction_job_incomplete_bq_output_mock")
def test_batch_prediction_iter_dirs_bq_raises_on_empty(self, bq_list_rows_mock):
bp = jobs.BatchPredictionJob(
batch_prediction_job_name=_TEST_BATCH_PREDICTION_JOB_NAME
)
with pytest.raises(RuntimeError) as e:
bp.iter_outputs()
assert e.match(
regexp=(
"A BigQuery table with predictions was not found,"
" this might be due to errors. Visit http"
)
)

@pytest.mark.usefixtures("get_batch_prediction_job_running_bq_output_mock")
Expand Down