diff --git a/google/cloud/aiplatform/initializer.py b/google/cloud/aiplatform/initializer.py index 00f6b19b40..2aa98b1600 100644 --- a/google/cloud/aiplatform/initializer.py +++ b/google/cloud/aiplatform/initializer.py @@ -194,18 +194,20 @@ def encryption_spec_key_name(self) -> Optional[str]: return self._encryption_spec_key_name def get_client_options( - self, location_override: Optional[str] = None, prediction_client: bool = False + self, + location_override: Optional[str] = None, + prediction_client: bool = False, + api_base_path_override: Optional[str] = None, ) -> client_options.ClientOptions: """Creates GAPIC client_options using location and type. Args: location_override (str): - Set this parameter to get client options for a location different from - location set by initializer. Must be a GCP region supported by AI - Platform (Unified). - prediction_client (str): Optional flag to use a prediction endpoint. - - + Optional. Set this parameter to get client options for a location different + from location set by initializer. Must be a GCP region supported by + Vertex AI. + prediction_client (str): Optional. flag to use a prediction endpoint. + api_base_path_override (str): Optional. Override default API base path. Returns: clients_options (google.api_core.client_options.ClientOptions): A ClientOptions object set with regionalized API endpoint, i.e. @@ -222,7 +224,7 @@ def get_client_options( utils.validate_region(region) - service_base_path = ( + service_base_path = api_base_path_override or ( constants.PREDICTION_API_BASE_PATH if prediction_client else constants.API_BASE_PATH @@ -261,17 +263,19 @@ def create_client( credentials: Optional[auth_credentials.Credentials] = None, location_override: Optional[str] = None, prediction_client: bool = False, + api_base_path_override: Optional[str] = None, ) -> utils.VertexAiServiceClientWithOverride: """Instantiates a given VertexAiServiceClient with optional overrides. Args: client_class (utils.VertexAiServiceClientWithOverride): - (Required) A Vertex AI Service Client with optional overrides. + Required. A Vertex AI Service Client with optional overrides. credentials (auth_credentials.Credentials): - Custom auth credentials. If not provided will use the current config. - location_override (str): Optional location override. - prediction_client (str): Optional flag to use a prediction endpoint. + Optional. Custom auth credentials. If not provided will use the current config. + location_override (str): Optional. location override. + prediction_client (str): Optional. flag to use a prediction endpoint. + api_base_path_override (str): Optional. Override default api base path. Returns: client: Instantiated Vertex AI Service client with optional overrides """ @@ -288,6 +292,7 @@ def create_client( "client_options": self.get_client_options( location_override=location_override, prediction_client=prediction_client, + api_base_path_override=api_base_path_override, ), "client_info": client_info, } diff --git a/google/cloud/aiplatform/training_utils/cloud_profiler/plugins/tensorflow/tensorboard_api.py b/google/cloud/aiplatform/training_utils/cloud_profiler/plugins/tensorflow/tensorboard_api.py index fcb881a75c..79b3285a1a 100644 --- a/google/cloud/aiplatform/training_utils/cloud_profiler/plugins/tensorflow/tensorboard_api.py +++ b/google/cloud/aiplatform/training_utils/cloud_profiler/plugins/tensorflow/tensorboard_api.py @@ -29,7 +29,6 @@ from google.api_core import exceptions from google.cloud import aiplatform from google.cloud import storage -from google.cloud.aiplatform.constants import base as constants from google.cloud.aiplatform.utils import TensorboardClientWithOverride from google.cloud.aiplatform.tensorboard import uploader_utils from google.cloud.aiplatform.compat.types import tensorboard_experiment @@ -41,8 +40,6 @@ def _get_api_client() -> TensorboardClientWithOverride: """Creates an Tensorboard API client.""" - constants.API_BASE_PATH = training_utils.environment_variables.tensorboard_api_uri - m = re.match( "projects/.*/locations/(.*)/tensorboards/.*", training_utils.environment_variables.tensorboard_resource_name, @@ -50,7 +47,9 @@ def _get_api_client() -> TensorboardClientWithOverride: region = m[1] api_client = aiplatform.initializer.global_config.create_client( - client_class=TensorboardClientWithOverride, location_override=region, + client_class=TensorboardClientWithOverride, + location_override=region, + api_base_path_override=training_utils.environment_variables.tensorboard_api_uri, ) return api_client diff --git a/tests/unit/aiplatform/test_initializer.py b/tests/unit/aiplatform/test_initializer.py index f4043a5eba..e52dfef3aa 100644 --- a/tests/unit/aiplatform/test_initializer.py +++ b/tests/unit/aiplatform/test_initializer.py @@ -181,6 +181,15 @@ def test_get_client_options( == expected_endpoint ) + def test_get_client_options_with_api_override(self): + initializer.global_config.init(location="asia-east1") + + client_options = initializer.global_config.get_client_options( + api_base_path_override="override.googleapis.com" + ) + + assert client_options.api_endpoint == "asia-east1-override.googleapis.com" + class TestThreadPool: def teardown_method(self):