From c2cf61288326cad28ab474064b887687bc649d76 Mon Sep 17 00:00:00 2001 From: Mark Date: Thu, 10 Jun 2021 11:00:11 -0700 Subject: [PATCH] feat: allow the prediction endpoint to be overridden (#461) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [ ] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/python-aiplatform/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [ ] Ensure the tests and linter pass - [ ] Code coverage does not decrease (if any source code was changed) - [ ] Appropriate docs were updated (if necessary) Fixes # 🦕 --- google/cloud/aiplatform/constants.py | 1 + google/cloud/aiplatform/initializer.py | 15 ++++++++++++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/google/cloud/aiplatform/constants.py b/google/cloud/aiplatform/constants.py index a7d81084cd..67d730b7b8 100644 --- a/google/cloud/aiplatform/constants.py +++ b/google/cloud/aiplatform/constants.py @@ -33,6 +33,7 @@ } API_BASE_PATH = "aiplatform.googleapis.com" +PREDICTION_API_BASE_PATH = API_BASE_PATH # Batch Prediction BATCH_PREDICTION_INPUT_STORAGE_FORMATS = ( diff --git a/google/cloud/aiplatform/initializer.py b/google/cloud/aiplatform/initializer.py index 9f0ad719f9..4f57115fe7 100644 --- a/google/cloud/aiplatform/initializer.py +++ b/google/cloud/aiplatform/initializer.py @@ -194,7 +194,7 @@ def encryption_spec_key_name(self) -> Optional[str]: return self._encryption_spec_key_name def get_client_options( - self, location_override: Optional[str] = None + self, location_override: Optional[str] = None, prediction_client: bool = False ) -> client_options.ClientOptions: """Creates GAPIC client_options using location and type. @@ -203,6 +203,8 @@ def get_client_options( Set this parameter to get client options for a location different from location set by initializer. Must be a GCP region supported by AI Platform (Unified). + prediction_client (str): Optional flag to use a prediction endpoint. + Returns: clients_options (google.api_core.client_options.ClientOptions): @@ -220,8 +222,14 @@ def get_client_options( utils.validate_region(region) + service_base_path = ( + constants.PREDICTION_API_BASE_PATH + if prediction_client + else constants.API_BASE_PATH + ) + return client_options.ClientOptions( - api_endpoint=f"{region}-{constants.API_BASE_PATH}" + api_endpoint=f"{region}-{service_base_path}" ) def common_location_path( @@ -278,7 +286,8 @@ def create_client( kwargs = { "credentials": credentials or self.credentials, "client_options": self.get_client_options( - location_override=location_override + location_override=location_override, + prediction_client=prediction_client, ), "client_info": client_info, }