From 37ee0a1dc6e0105e19aca18f44995a352bfc40cb Mon Sep 17 00:00:00 2001 From: Andrew Ferlitsch Date: Tue, 7 Dec 2021 15:59:07 -0800 Subject: [PATCH] fix: add clarity to parameters per user feedback (#886) * fix: add clarity per user feedback * fix: review * fix: review * fix: TW review * lint: fix wsp * fix: param name * fix: code review Co-authored-by: Karl Weinmeister <11586922+kweinmeister@users.noreply.github.com> --- .../predict_tabular_classification_sample.py | 15 +++++++++++++-- .../predict_tabular_classification_sample_test.py | 2 +- .../predict_tabular_regression_sample.py | 7 +++++-- .../predict_tabular_regression_sample_test.py | 2 +- 4 files changed, 20 insertions(+), 6 deletions(-) diff --git a/samples/model-builder/predict_tabular_classification_sample.py b/samples/model-builder/predict_tabular_classification_sample.py index e5b1a0283d..0a0a1da8bd 100644 --- a/samples/model-builder/predict_tabular_classification_sample.py +++ b/samples/model-builder/predict_tabular_classification_sample.py @@ -20,11 +20,22 @@ # [START aiplatform_sdk_predict_tabular_classification_sample] def predict_tabular_classification_sample( - project: str, location: str, endpoint: str, instances: List[Dict], + project: str, + location: str, + endpoint_name: str, + instances: List[Dict], ): + ''' + Args + project: Your project ID or project number. + location: Region where Endpoint is located. For example, 'us-central1'. + endpoint_name: A fully qualified endpoint name or endpoint ID. Example: "projects/123/locations/us-central1/endpoints/456" or + "456" when project and location are initialized or passed. + instances: A list of one or more instances (examples) to return a prediction for. + ''' aiplatform.init(project=project, location=location) - endpoint = aiplatform.Endpoint(endpoint) + endpoint = aiplatform.Endpoint(endpoint_name) response = endpoint.predict(instances=instances) diff --git a/samples/model-builder/predict_tabular_classification_sample_test.py b/samples/model-builder/predict_tabular_classification_sample_test.py index 49a701115b..66f2976803 100644 --- a/samples/model-builder/predict_tabular_classification_sample_test.py +++ b/samples/model-builder/predict_tabular_classification_sample_test.py @@ -22,7 +22,7 @@ def test_predict_tabular_classification_sample(mock_sdk_init, mock_get_endpoint) predict_tabular_classification_sample.predict_tabular_classification_sample( project=constants.PROJECT, location=constants.LOCATION, - endpoint=constants.ENDPOINT_NAME, + endpoint_name=constants.ENDPOINT_NAME, instances=constants.PREDICTION_TABULAR_CLASSIFICATION_INSTANCE, ) diff --git a/samples/model-builder/predict_tabular_regression_sample.py b/samples/model-builder/predict_tabular_regression_sample.py index fee4d34e38..b7bf575d44 100644 --- a/samples/model-builder/predict_tabular_regression_sample.py +++ b/samples/model-builder/predict_tabular_regression_sample.py @@ -19,11 +19,14 @@ # [START aiplatform_sdk_predict_tabular_regression_sample] def predict_tabular_regression_sample( - project: str, location: str, endpoint: str, instances: List[Dict], + project: str, + location: str, + endpoint_name: str, + instances: List[Dict], ): aiplatform.init(project=project, location=location) - endpoint = aiplatform.Endpoint(endpoint) + endpoint = aiplatform.Endpoint(endpoint_name) response = endpoint.predict(instances=instances) diff --git a/samples/model-builder/predict_tabular_regression_sample_test.py b/samples/model-builder/predict_tabular_regression_sample_test.py index 7491d7c1d5..abda65b3c4 100644 --- a/samples/model-builder/predict_tabular_regression_sample_test.py +++ b/samples/model-builder/predict_tabular_regression_sample_test.py @@ -22,7 +22,7 @@ def test_predict_tabular_regression_sample(mock_sdk_init, mock_get_endpoint): predict_tabular_regression_sample.predict_tabular_regression_sample( project=constants.PROJECT, location=constants.LOCATION, - endpoint=constants.ENDPOINT_NAME, + endpoint_name=constants.ENDPOINT_NAME, instances=constants.PREDICTION_TABULAR_REGRESSOIN_INSTANCE, )