diff --git a/google/cloud/automl_v1beta1/services/tables/tables_client.py b/google/cloud/automl_v1beta1/services/tables/tables_client.py index 21028a36..f6e7889f 100644 --- a/google/cloud/automl_v1beta1/services/tables/tables_client.py +++ b/google/cloud/automl_v1beta1/services/tables/tables_client.py @@ -2999,7 +2999,10 @@ def batch_predict( ) req = google.cloud.automl_v1beta1.BatchPredictRequest( - name=model_name, input_config=input_request, output_config=output_request, + name=model_name, + input_config=input_request, + output_config=output_request, + params=params, ) method_kwargs = self.__process_request_kwargs(req, **kwargs) diff --git a/tests/unit/test_tables_client_v1beta1.py b/tests/unit/test_tables_client_v1beta1.py index 1d5b168c..4df06d48 100644 --- a/tests/unit/test_tables_client_v1beta1.py +++ b/tests/unit/test_tables_client_v1beta1.py @@ -1599,6 +1599,24 @@ def test_batch_predict_bigquery(self): ) ) + def test_batch_predict_bigquery_with_params(self): + client = self.tables_client({}, {}) + client.batch_predict( + model_name="my_model", + bigquery_input_uri="bq://input", + bigquery_output_uri="bq://output", + params={"feature_importance": "true"}, + ) + + client.prediction_client.batch_predict.assert_called_with( + request=automl_v1beta1.BatchPredictRequest( + name="my_model", + input_config={"bigquery_source": {"input_uri": "bq://input"}}, + output_config={"bigquery_destination": {"output_uri": "bq://output"}}, + params={"feature_importance": "true"}, + ) + ) + def test_batch_predict_mixed(self): client = self.tables_client({}, {}) client.batch_predict(