Skip to content
This repository has been archived by the owner on Dec 31, 2023. It is now read-only.

Commit

Permalink
Browse files Browse the repository at this point in the history
fix: Pass the 'params' parameter to the underlying 'BatchPredictReque…
…st' object in 'batch_predict()' method (#110)
  • Loading branch information
elbernante committed Dec 16, 2020
1 parent df22fd5 commit b89fb00
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
5 changes: 4 additions & 1 deletion google/cloud/automl_v1beta1/services/tables/tables_client.py
Expand Up @@ -2999,7 +2999,10 @@ def batch_predict(
)

req = google.cloud.automl_v1beta1.BatchPredictRequest(
name=model_name, input_config=input_request, output_config=output_request,
name=model_name,
input_config=input_request,
output_config=output_request,
params=params,
)

method_kwargs = self.__process_request_kwargs(req, **kwargs)
Expand Down
18 changes: 18 additions & 0 deletions tests/unit/test_tables_client_v1beta1.py
Expand Up @@ -1599,6 +1599,24 @@ def test_batch_predict_bigquery(self):
)
)

def test_batch_predict_bigquery_with_params(self):
client = self.tables_client({}, {})
client.batch_predict(
model_name="my_model",
bigquery_input_uri="bq://input",
bigquery_output_uri="bq://output",
params={"feature_importance": "true"},
)

client.prediction_client.batch_predict.assert_called_with(
request=automl_v1beta1.BatchPredictRequest(
name="my_model",
input_config={"bigquery_source": {"input_uri": "bq://input"}},
output_config={"bigquery_destination": {"output_uri": "bq://output"}},
params={"feature_importance": "true"},
)
)

def test_batch_predict_mixed(self):
client = self.tables_client({}, {})
client.batch_predict(
Expand Down

0 comments on commit b89fb00

Please sign in to comment.