Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Lazy load Endpoint class #655

Merged
merged 16 commits into from Oct 7, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion google/cloud/aiplatform/base.py
Expand Up @@ -833,7 +833,7 @@ def _sync_object_with_future_result(
"_gca_resource",
"credentials",
]
optional_sync_attributes = ["_prediction_client"]
optional_sync_attributes = ["_prediction_client", "_endpoint_name"]

for attribute in sync_attributes:
setattr(self, attribute, getattr(result, attribute))
Expand Down
39 changes: 36 additions & 3 deletions google/cloud/aiplatform/models.py
Expand Up @@ -79,6 +79,7 @@ class Endpoint(base.VertexAiResourceNounWithFutureManager):

client_class = utils.EndpointClientWithOverride
_is_client_prediction_client = False
_skipped_getter_call = True # get_endpoint() has not been called
_resource_noun = "endpoints"
_getter_method = "get_endpoint"
_list_method = "list_endpoints"
Expand Down Expand Up @@ -115,12 +116,36 @@ def __init__(
credentials=credentials,
resource_name=endpoint_name,
)
self._gca_resource = self._get_gca_resource(resource_name=endpoint_name)

# Lazy load the Endpoint gca_resource until needed
self._endpoint_name = endpoint_name

self._prediction_client = self._instantiate_prediction_client(
location=self.location, credentials=credentials,
)

def _sync_gca_resource_if_skipped(self) -> None:
    """Load the Endpoint's GAPIC resource if it has not been loaded yet.

    When ``_skipped_getter_call`` is still set, the resource named by
    ``_endpoint_name`` is fetched through ``_get_gca_resource`` and the flag
    is cleared. Otherwise this method does nothing.
    """
    if not self._skipped_getter_call:
        return

    self._gca_resource = self._get_gca_resource(resource_name=self._endpoint_name)
    self._skipped_getter_call = False

def _sync_gca_resource(self) -> None:
    """Sync the GAPIC representation of this Endpoint.

    If the getter has already been called, the sync is delegated to the
    parent class. Otherwise the deferred first load is performed via
    ``_sync_gca_resource_if_skipped``.
    """
    if not self._skipped_getter_call:
        super()._sync_gca_resource()
        return

    self._sync_gca_resource_if_skipped()

def _assert_gca_resource_is_available(self) -> None:
"""Ensures Endpoint getter was called at least once before
asserting on gca_resource's availability."""
self._sync_gca_resource_if_skipped()
sasha-gitg marked this conversation as resolved.
Show resolved Hide resolved
super()._assert_gca_resource_is_available()
vinnysenthil marked this conversation as resolved.
Show resolved Hide resolved

@property
def traffic_split(self) -> Dict[str, int]:
"""A map from a DeployedModel's ID to the percentage of this Endpoint's
Expand Down Expand Up @@ -320,8 +345,8 @@ def _create(

_LOGGER.log_create_complete(cls, created_endpoint, "endpoint")

return cls(
endpoint_name=created_endpoint.name,
return cls._construct_sdk_resource_from_gapic(
gapic_resource=created_endpoint,
project=project,
location=location,
credentials=credentials,
Expand Down Expand Up @@ -360,6 +385,10 @@ def _construct_sdk_resource_from_gapic(
)

endpoint._gca_resource = gapic_resource
endpoint._endpoint_name = endpoint._gca_resource.name

# Building Endpoint from GAPIC object is equivalent to calling getter
endpoint._skipped_getter_call = False
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is creating a new instance attribute but _skipped_getter_call is defined as a class attribute above. Perhaps, _skipped_getter_call should be an instance attribute?


endpoint._prediction_client = cls._instantiate_prediction_client(
location=endpoint.location, credentials=credentials,
Expand Down Expand Up @@ -627,6 +656,7 @@ def deploy(
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
"""
self._sync_gca_resource_if_skipped()

self._validate_deploy_args(
min_replica_count,
Expand Down Expand Up @@ -977,6 +1007,8 @@ def undeploy(
Optional. Strings which should be sent along with the request as
metadata.
"""
self._sync_gca_resource_if_skipped()

if traffic_split is not None:
if deployed_model_id in traffic_split and traffic_split[deployed_model_id]:
raise ValueError("Model being undeployed should have 0 traffic.")
Expand Down Expand Up @@ -1022,6 +1054,7 @@ def _undeploy(
Optional. Strings which should be sent along with the request as
metadata.
"""
self._sync_gca_resource_if_skipped()
current_traffic_split = traffic_split or dict(self._gca_resource.traffic_split)

if deployed_model_id in current_traffic_split:
Expand Down
65 changes: 50 additions & 15 deletions tests/unit/aiplatform/test_endpoints.py
Expand Up @@ -183,6 +183,15 @@ def get_endpoint_mock():
yield get_endpoint_mock


@pytest.fixture
def get_empty_endpoint_mock():
    """Patch ``EndpointServiceClient.get_endpoint`` to return an empty Endpoint.

    Yields the patched mock so tests can inspect the calls made to it.
    """
    patcher = mock.patch.object(
        endpoint_service_client.EndpointServiceClient, "get_endpoint"
    )
    with patcher as patched_get_endpoint:
        patched_get_endpoint.return_value = gca_endpoint.Endpoint()
        yield patched_get_endpoint


@pytest.fixture
def get_endpoint_alt_location_mock():
with mock.patch.object(
Expand Down Expand Up @@ -227,7 +236,9 @@ def create_endpoint_mock():
) as create_endpoint_mock:
create_endpoint_lro_mock = mock.Mock(ga_operation.Operation)
create_endpoint_lro_mock.result.return_value = gca_endpoint.Endpoint(
name=_TEST_ENDPOINT_NAME, display_name=_TEST_DISPLAY_NAME
name=_TEST_ENDPOINT_NAME,
display_name=_TEST_DISPLAY_NAME,
encryption_spec=_TEST_ENCRYPTION_SPEC,
)
create_endpoint_mock.return_value = create_endpoint_lro_mock
yield create_endpoint_mock
Expand Down Expand Up @@ -392,19 +403,35 @@ def test_constructor(self, create_endpoint_client_mock):
]
)

def test_constructor_with_endpoint_id(self, get_endpoint_mock):
models.Endpoint(_TEST_ID)
get_endpoint_mock.assert_called_with(name=_TEST_ENDPOINT_NAME)
def test_lazy_constructor_with_endpoint_id(self, get_endpoint_mock):
    """Constructing from an endpoint ID must not call get_endpoint()."""
    endpoint = models.Endpoint(_TEST_ID)

    # The ID is stored as given and the getter call is still pending.
    assert endpoint._endpoint_name == _TEST_ID
    assert endpoint._skipped_getter_call
    assert not get_endpoint_mock.called

def test_constructor_with_endpoint_name(self, get_endpoint_mock):
models.Endpoint(_TEST_ENDPOINT_NAME)
def test_lazy_constructor_with_endpoint_name(self, get_endpoint_mock):
    """Constructing from a full resource name must not call get_endpoint()."""
    endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)

    # The name is stored as given and the getter call is still pending.
    assert endpoint._endpoint_name == _TEST_ENDPOINT_NAME
    assert endpoint._skipped_getter_call
    assert not get_endpoint_mock.called

def test_lazy_constructor_calls_get_on_property_access(self, get_endpoint_mock):
    """Reading a property on a lazily built Endpoint triggers get_endpoint()."""
    endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)

    # Construction alone must leave the getter uncalled.
    assert endpoint._endpoint_name == _TEST_ENDPOINT_NAME
    assert endpoint._skipped_getter_call
    assert not get_endpoint_mock.called

    endpoint.display_name  # Retrieve a property that requires a call to Endpoint getter
    get_endpoint_mock.assert_called_with(name=_TEST_ENDPOINT_NAME)

def test_constructor_with_custom_project(self, get_endpoint_mock):
models.Endpoint(endpoint_name=_TEST_ID, project=_TEST_PROJECT_2)
def test_lazy_constructor_with_custom_project(self, get_endpoint_mock):
    """A custom project is used in the resource name passed to get_endpoint()."""
    endpoint = models.Endpoint(endpoint_name=_TEST_ID, project=_TEST_PROJECT_2)
    expected_resource_name = endpoint_service_client.EndpointServiceClient.endpoint_path(
        _TEST_PROJECT_2, _TEST_LOCATION, _TEST_ID
    )

    # Construction alone must leave the getter uncalled.
    assert not get_endpoint_mock.called

    endpoint.name  # Retrieve a property that requires a call to Endpoint getter
    get_endpoint_mock.assert_called_with(name=expected_resource_name)

@pytest.mark.usefixtures("get_endpoint_mock")
Expand All @@ -420,11 +447,19 @@ def test_constructor_with_conflicting_location(self):
regexp=r"is provided, but different from the resource location"
)

def test_constructor_with_custom_location(self, get_endpoint_alt_location_mock):
models.Endpoint(endpoint_name=_TEST_ID, location=_TEST_LOCATION_2)
def test_lazy_constructor_with_custom_location(
self, get_endpoint_alt_location_mock
):
ep = models.Endpoint(endpoint_name=_TEST_ID, location=_TEST_LOCATION_2)
test_endpoint_resource_name = endpoint_service_client.EndpointServiceClient.endpoint_path(
_TEST_PROJECT, _TEST_LOCATION_2, _TEST_ID
)

# Get Endpoint not called due to lazy loading
assert not get_endpoint_alt_location_mock.called

ep.network # Accessing a property that requires calling getter

get_endpoint_alt_location_mock.assert_called_with(
name=test_endpoint_resource_name
)
Expand Down Expand Up @@ -495,14 +530,14 @@ def test_create(self, create_endpoint_mock, sync):
)

expected_endpoint.name = _TEST_ENDPOINT_NAME
assert my_endpoint.gca_resource == expected_endpoint
assert my_endpoint.network is None
assert my_endpoint._gca_resource == expected_endpoint

@pytest.mark.usefixtures("get_endpoint_mock")
@pytest.mark.usefixtures("get_empty_endpoint_mock")
def test_accessing_properties_with_no_resource_raises(self,):
"""Ensure a descriptive RuntimeError is raised when the
GAPIC object has not been populated"""

my_endpoint = aiplatform.Endpoint(_TEST_ENDPOINT_NAME)

my_endpoint._gca_resource = None

with pytest.raises(RuntimeError) as e:
Expand Down Expand Up @@ -923,7 +958,7 @@ def test_undeploy(self, undeploy_model_mock, sync):
traffic_split={"model1": 100},
)
test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
assert dict(test_endpoint._gca_resource.traffic_split) == {"model1": 100}
assert dict(test_endpoint.traffic_split) == {"model1": 100}
test_endpoint.undeploy("model1", sync=sync)
if not sync:
test_endpoint.wait()
Expand Down