Skip to content

Commit

Permalink
fix: add support for API base path overriding (#908)
Browse files Browse the repository at this point in the history
  • Loading branch information
sasha-gitg committed Dec 14, 2021
1 parent 48c2bf1 commit 45c4086
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 16 deletions.
29 changes: 17 additions & 12 deletions google/cloud/aiplatform/initializer.py
Expand Up @@ -194,18 +194,20 @@ def encryption_spec_key_name(self) -> Optional[str]:
return self._encryption_spec_key_name

def get_client_options(
self, location_override: Optional[str] = None, prediction_client: bool = False
self,
location_override: Optional[str] = None,
prediction_client: bool = False,
api_base_path_override: Optional[str] = None,
) -> client_options.ClientOptions:
"""Creates GAPIC client_options using location and type.
Args:
location_override (str):
Set this parameter to get client options for a location different from
location set by initializer. Must be a GCP region supported by AI
Platform (Unified).
prediction_client (str): Optional flag to use a prediction endpoint.
Optional. Set this parameter to get client options for a location different
from location set by initializer. Must be a GCP region supported by
Vertex AI.
prediction_client (str): Optional. flag to use a prediction endpoint.
api_base_path_override (str): Optional. Override default API base path.
Returns:
clients_options (google.api_core.client_options.ClientOptions):
A ClientOptions object set with regionalized API endpoint, i.e.
Expand All @@ -222,7 +224,7 @@ def get_client_options(

utils.validate_region(region)

service_base_path = (
service_base_path = api_base_path_override or (
constants.PREDICTION_API_BASE_PATH
if prediction_client
else constants.API_BASE_PATH
Expand Down Expand Up @@ -261,17 +263,19 @@ def create_client(
credentials: Optional[auth_credentials.Credentials] = None,
location_override: Optional[str] = None,
prediction_client: bool = False,
api_base_path_override: Optional[str] = None,
) -> utils.VertexAiServiceClientWithOverride:
"""Instantiates a given VertexAiServiceClient with optional
overrides.
Args:
client_class (utils.VertexAiServiceClientWithOverride):
(Required) A Vertex AI Service Client with optional overrides.
Required. A Vertex AI Service Client with optional overrides.
credentials (auth_credentials.Credentials):
Custom auth credentials. If not provided will use the current config.
location_override (str): Optional location override.
prediction_client (str): Optional flag to use a prediction endpoint.
Optional. Custom auth credentials. If not provided will use the current config.
location_override (str): Optional. location override.
prediction_client (str): Optional. flag to use a prediction endpoint.
api_base_path_override (str): Optional. Override default api base path.
Returns:
client: Instantiated Vertex AI Service client with optional overrides
"""
Expand All @@ -288,6 +292,7 @@ def create_client(
"client_options": self.get_client_options(
location_override=location_override,
prediction_client=prediction_client,
api_base_path_override=api_base_path_override,
),
"client_info": client_info,
}
Expand Down
Expand Up @@ -29,7 +29,6 @@
from google.api_core import exceptions
from google.cloud import aiplatform
from google.cloud import storage
from google.cloud.aiplatform.constants import base as constants
from google.cloud.aiplatform.utils import TensorboardClientWithOverride
from google.cloud.aiplatform.tensorboard import uploader_utils
from google.cloud.aiplatform.compat.types import tensorboard_experiment
Expand All @@ -41,16 +40,16 @@

def _get_api_client() -> TensorboardClientWithOverride:
"""Creates an Tensorboard API client."""
constants.API_BASE_PATH = training_utils.environment_variables.tensorboard_api_uri

m = re.match(
"projects/.*/locations/(.*)/tensorboards/.*",
training_utils.environment_variables.tensorboard_resource_name,
)
region = m[1]

api_client = aiplatform.initializer.global_config.create_client(
client_class=TensorboardClientWithOverride, location_override=region,
client_class=TensorboardClientWithOverride,
location_override=region,
api_base_path_override=training_utils.environment_variables.tensorboard_api_uri,
)

return api_client
Expand Down
9 changes: 9 additions & 0 deletions tests/unit/aiplatform/test_initializer.py
Expand Up @@ -181,6 +181,15 @@ def test_get_client_options(
== expected_endpoint
)

def test_get_client_options_with_api_override(self):
initializer.global_config.init(location="asia-east1")

client_options = initializer.global_config.get_client_options(
api_base_path_override="override.googleapis.com"
)

assert client_options.api_endpoint == "asia-east1-override.googleapis.com"


class TestThreadPool:
def teardown_method(self):
Expand Down

0 comments on commit 45c4086

Please sign in to comment.