Skip to content

Commit

Permalink
fix: Set prediction client when listing Endpoints (#512)
Browse files Browse the repository at this point in the history
* fix: Set prediction client when listing Endpoints

* Address reviewer comments

* Remove redundant init() in TestEndpoints

* Update location passed to _instantiate_prediction_client()
  • Loading branch information
vinnysenthil committed Jun 30, 2021
1 parent b95e040 commit 95639ee
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 49 deletions.
11 changes: 6 additions & 5 deletions google/cloud/aiplatform/base.py
Expand Up @@ -819,8 +819,9 @@ def _sync_object_with_future_result(
if value:
setattr(self, attribute, value)

@classmethod
def _construct_sdk_resource_from_gapic(
self,
cls,
gapic_resource: proto.Message,
project: Optional[str] = None,
location: Optional[str] = None,
Expand All @@ -846,7 +847,7 @@ def _construct_sdk_resource_from_gapic(
VertexAiResourceNoun:
An initialized SDK object that represents GAPIC type.
"""
sdk_resource = self._empty_constructor(
sdk_resource = cls._empty_constructor(
project=project, location=location, credentials=credentials
)
sdk_resource._gca_resource = gapic_resource
Expand Down Expand Up @@ -894,14 +895,14 @@ def _list(
Returns:
List[VertexAiResourceNoun] - A list of SDK resource objects
"""
self = cls._empty_constructor(
resource = cls._empty_constructor(
project=project, location=location, credentials=credentials
)

# Fetch credentials once and re-use for all `_empty_constructor()` calls
creds = initializer.global_config.credentials

resource_list_method = getattr(self.api_client, self._list_method)
resource_list_method = getattr(resource.api_client, resource._list_method)

list_request = {
"parent": initializer.global_config.common_location_path(
Expand All @@ -916,7 +917,7 @@ def _list(
resource_list = resource_list_method(request=list_request) or []

return [
self._construct_sdk_resource_from_gapic(
cls._construct_sdk_resource_from_gapic(
gapic_resource, project=project, location=location, credentials=creds
)
for gapic_resource in resource_list
Expand Down
44 changes: 42 additions & 2 deletions google/cloud/aiplatform/models.py
Expand Up @@ -116,9 +116,9 @@ def __init__(
resource_name=endpoint_name,
)
self._gca_resource = self._get_gca_resource(resource_name=endpoint_name)

self._prediction_client = self._instantiate_prediction_client(
location=location or initializer.global_config.location,
credentials=credentials,
location=self.location, credentials=credentials,
)

@property
Expand Down Expand Up @@ -324,6 +324,46 @@ def _create(
credentials=credentials,
)

@classmethod
def _construct_sdk_resource_from_gapic(
cls,
gapic_resource: proto.Message,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> "Endpoint":
"""Given a GAPIC Endpoint object, return the SDK representation.
Args:
gapic_resource (proto.Message):
A GAPIC representation of a Endpoint resource, usually
retrieved by a get_* or in a list_* API call.
project (str):
Optional. Project to construct Endpoint object from. If not set,
project set in aiplatform.init will be used.
location (str):
Optional. Location to construct Endpoint object from. If not set,
location set in aiplatform.init will be used.
credentials (auth_credentials.Credentials):
Optional. Custom credentials to use to construct Endpoint.
Overrides credentials set in aiplatform.init.
Returns:
Endpoint:
An initialized Endpoint resource.
"""
endpoint = cls._empty_constructor(
project=project, location=location, credentials=credentials
)

endpoint._gca_resource = gapic_resource

endpoint._prediction_client = cls._instantiate_prediction_client(
location=endpoint.location, credentials=credentials,
)

return endpoint

@staticmethod
def _allocate_traffic(
traffic_split: Dict[str, int], traffic_percentage: int,
Expand Down

0 comments on commit 95639ee

Please sign in to comment.