Skip to content

Commit

Permalink
feat: allow the prediction endpoint to be overridden
Browse files Browse the repository at this point in the history
  • Loading branch information
geraint0923 committed Jun 5, 2021
1 parent e7bf0d8 commit b48c0c3
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
1 change: 1 addition & 0 deletions google/cloud/aiplatform/constants.py
Expand Up @@ -33,6 +33,7 @@
}

API_BASE_PATH = "aiplatform.googleapis.com"
PREDICTION_API_BASE_PATH = API_BASE_PATH

# Batch Prediction
BATCH_PREDICTION_INPUT_STORAGE_FORMATS = (
Expand Down
13 changes: 10 additions & 3 deletions google/cloud/aiplatform/initializer.py
Expand Up @@ -194,7 +194,7 @@ def encryption_spec_key_name(self) -> Optional[str]:
return self._encryption_spec_key_name

def get_client_options(
self, location_override: Optional[str] = None
self, location_override: Optional[str] = None, prediction_client: bool = False
) -> client_options.ClientOptions:
"""Creates GAPIC client_options using location and type.
Expand All @@ -220,8 +220,14 @@ def get_client_options(

utils.validate_region(region)

service_base_path = (
constants.PREDICTION_API_BASE_PATH
if prediction_client
else constants.API_BASE_PATH
)

return client_options.ClientOptions(
api_endpoint=f"{region}-{constants.API_BASE_PATH}"
api_endpoint=f"{region}-{service_base_path}"
)

def common_location_path(
Expand Down Expand Up @@ -278,7 +284,8 @@ def create_client(
kwargs = {
"credentials": credentials or self.credentials,
"client_options": self.get_client_options(
location_override=location_override
location_override=location_override,
prediction_client=prediction_client,
),
"client_info": client_info,
}
Expand Down

0 comments on commit b48c0c3

Please sign in to comment.