Skip to content

Commit

Permalink
feat: allow the prediction endpoint to be overridden (#461)
Browse files Browse the repository at this point in the history
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly:
- [ ] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/python-aiplatform/issues/new/choose) before writing your code!  That way we can discuss the change, evaluate designs, and agree on the general idea
- [ ] Ensure the tests and linter pass
- [ ] Code coverage does not decrease (if any source code was changed)
- [ ] Appropriate docs were updated (if necessary)

Fixes #<issue_number_goes_here> 🦕
  • Loading branch information
geraint0923 committed Jun 10, 2021
1 parent 8cfd611 commit c2cf612
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
1 change: 1 addition & 0 deletions google/cloud/aiplatform/constants.py
Expand Up @@ -33,6 +33,7 @@
}

API_BASE_PATH = "aiplatform.googleapis.com"
PREDICTION_API_BASE_PATH = API_BASE_PATH

# Batch Prediction
BATCH_PREDICTION_INPUT_STORAGE_FORMATS = (
Expand Down
15 changes: 12 additions & 3 deletions google/cloud/aiplatform/initializer.py
Expand Up @@ -194,7 +194,7 @@ def encryption_spec_key_name(self) -> Optional[str]:
return self._encryption_spec_key_name

def get_client_options(
self, location_override: Optional[str] = None
self, location_override: Optional[str] = None, prediction_client: bool = False
) -> client_options.ClientOptions:
"""Creates GAPIC client_options using location and type.
Expand All @@ -203,6 +203,8 @@ def get_client_options(
Set this parameter to get client options for a location different from
location set by initializer. Must be a GCP region supported by AI
Platform (Unified).
prediction_client (str): Optional flag to use a prediction endpoint.
Returns:
clients_options (google.api_core.client_options.ClientOptions):
Expand All @@ -220,8 +222,14 @@ def get_client_options(

utils.validate_region(region)

service_base_path = (
constants.PREDICTION_API_BASE_PATH
if prediction_client
else constants.API_BASE_PATH
)

return client_options.ClientOptions(
api_endpoint=f"{region}-{constants.API_BASE_PATH}"
api_endpoint=f"{region}-{service_base_path}"
)

def common_location_path(
Expand Down Expand Up @@ -278,7 +286,8 @@ def create_client(
kwargs = {
"credentials": credentials or self.credentials,
"client_options": self.get_client_options(
location_override=location_override
location_override=location_override,
prediction_client=prediction_client,
),
"client_info": client_info,
}
Expand Down

0 comments on commit c2cf612

Please sign in to comment.