Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Lazy load Endpoint class #655

Merged
merged 16 commits into from Oct 7, 2021
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
47 changes: 43 additions & 4 deletions google/cloud/aiplatform/models.py
Expand Up @@ -115,12 +115,44 @@ def __init__(
credentials=credentials,
resource_name=endpoint_name,
)
self._gca_resource = self._get_gca_resource(resource_name=endpoint_name)

endpoint_name = utils.full_resource_name(
resource_name=endpoint_name,
resource_noun="endpoints",
project=project,
location=location,
)

# Lazy load the Endpoint gca_resource until needed
self._gca_resource = gca_endpoint_compat.Endpoint(name=endpoint_name)
self._skipped_getter_call = True

self._prediction_client = self._instantiate_prediction_client(
location=self.location, credentials=credentials,
)

def _sync_gca_resource_if_skipped(self) -> None:
"""Sync GAPIC service representation of Endpoint class resource only if
get_endpoint() was never called."""
if self._skipped_getter_call:
self._gca_resource = self._get_gca_resource(
resource_name=self._gca_resource.name
)
self._skipped_getter_call = False

def _sync_gca_resource(self) -> None:
"""Sync GAPIC service representation of Endpoint resource once."""
if self._skipped_getter_call:
self._sync_gca_resource_if_skipped()
else:
super()._sync_gca_resource()

def _assert_gca_resource_is_available(self) -> None:
"""Ensures Endpoint getter was called at least once before
asserting on gca_resource's availability."""
self._sync_gca_resource_if_skipped()
sasha-gitg marked this conversation as resolved.
Show resolved Hide resolved
super()._assert_gca_resource_is_available()
vinnysenthil marked this conversation as resolved.
Show resolved Hide resolved

@property
def traffic_split(self) -> Dict[str, int]:
"""A map from a DeployedModel's ID to the percentage of this Endpoint's
Expand Down Expand Up @@ -320,8 +352,8 @@ def _create(

_LOGGER.log_create_complete(cls, created_endpoint, "endpoint")

return cls(
endpoint_name=created_endpoint.name,
return cls._construct_sdk_resource_from_gapic(
gapic_resource=created_endpoint,
project=project,
location=location,
credentials=credentials,
Expand Down Expand Up @@ -361,6 +393,9 @@ def _construct_sdk_resource_from_gapic(

endpoint._gca_resource = gapic_resource

# Building Endpoint from GAPIC object is equivalent to calling getter
endpoint._skipped_getter_call = False
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is creating a new instance attribute but _skipped_getter_call is defined as a class attribute above. Perhaps, _skipped_getter_call should be an instance attribute?


endpoint._prediction_client = cls._instantiate_prediction_client(
location=endpoint.location, credentials=credentials,
)
Expand Down Expand Up @@ -627,6 +662,7 @@ def deploy(
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
"""
self._sync_gca_resource_if_skipped()

self._validate_deploy_args(
min_replica_count,
Expand Down Expand Up @@ -977,6 +1013,8 @@ def undeploy(
Optional. Strings which should be sent along with the request as
metadata.
"""
self._sync_gca_resource_if_skipped()

if traffic_split is not None:
if deployed_model_id in traffic_split and traffic_split[deployed_model_id]:
raise ValueError("Model being undeployed should have 0 traffic.")
Expand Down Expand Up @@ -1021,6 +1059,7 @@ def _undeploy(
Optional. Strings which should be sent along with the request as
metadata.
"""
self._sync_gca_resource_if_skipped()
current_traffic_split = traffic_split or dict(self._gca_resource.traffic_split)

if deployed_model_id in current_traffic_split:
Expand Down Expand Up @@ -1105,7 +1144,7 @@ def predict(self, instances: List, parameters: Optional[Dict] = None) -> Predict
self.wait()

prediction_response = self._prediction_client.predict(
endpoint=self.resource_name, instances=instances, parameters=parameters
endpoint=self._gca_resource.name, instances=instances, parameters=parameters
)

return Prediction(
Expand Down
43 changes: 17 additions & 26 deletions tests/system/aiplatform/test_e2e_tabular.py
Expand Up @@ -34,6 +34,16 @@
_LOCAL_TRAINING_SCRIPT_PATH = os.path.join(
_DIR_NAME, "test_resources/california_housing_training_script.py"
)
_INSTANCE = {
"longitude": -124.35,
"latitude": 40.54,
"housing_median_age": 52.0,
"total_rooms": 1820.0,
"total_bedrooms": 300.0,
"population": 806,
"households": 270.0,
"median_income": 3.014700,
}


@pytest.mark.usefixtures("prepare_staging_bucket", "delete_staging_bucket", "teardown")
Expand Down Expand Up @@ -136,39 +146,20 @@ def test_end_to_end_tabular(self, shared_state):
# Send online prediction with same instance to both deployed models
# This sample is taken from an observation where median_house_value = 94600
custom_endpoint.wait()
custom_prediction = custom_endpoint.predict(
[
{
"longitude": -124.35,
"latitude": 40.54,
"housing_median_age": 52.0,
"total_rooms": 1820.0,
"total_bedrooms": 300.0,
"population": 806,
"households": 270.0,
"median_income": 3.014700,
},
]
)
custom_prediction = custom_endpoint.predict([_INSTANCE])

custom_batch_prediction_job.wait()

automl_endpoint.wait()
automl_prediction = automl_endpoint.predict(
[
{
"longitude": "-124.35",
"latitude": "40.54",
"housing_median_age": "52.0",
"total_rooms": "1820.0",
"total_bedrooms": "300.0",
"population": "806",
"households": "270.0",
"median_income": "3.014700",
},
]
[{k: str(v) for k, v in _INSTANCE.items()}] # Cast int values to strings
)

# Test lazy loading of Endpoint, check getter was never called after predict()
custom_endpoint = aiplatform.Endpoint(custom_endpoint.resource_name)
custom_endpoint.predict([_INSTANCE])
assert custom_endpoint._skipped_getter_call

assert (
custom_job.state
== gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
Expand Down
66 changes: 51 additions & 15 deletions tests/unit/aiplatform/test_endpoints.py
Expand Up @@ -183,6 +183,15 @@ def get_endpoint_mock():
yield get_endpoint_mock


@pytest.fixture
def get_empty_endpoint_mock():
with mock.patch.object(
endpoint_service_client.EndpointServiceClient, "get_endpoint"
) as get_endpoint_mock:
get_endpoint_mock.return_value = gca_endpoint.Endpoint(name=_TEST_ENDPOINT_NAME)
yield get_endpoint_mock


@pytest.fixture
def get_endpoint_alt_location_mock():
with mock.patch.object(
Expand Down Expand Up @@ -227,7 +236,9 @@ def create_endpoint_mock():
) as create_endpoint_mock:
create_endpoint_lro_mock = mock.Mock(ga_operation.Operation)
create_endpoint_lro_mock.result.return_value = gca_endpoint.Endpoint(
name=_TEST_ENDPOINT_NAME, display_name=_TEST_DISPLAY_NAME
name=_TEST_ENDPOINT_NAME,
display_name=_TEST_DISPLAY_NAME,
encryption_spec=_TEST_ENCRYPTION_SPEC,
)
create_endpoint_mock.return_value = create_endpoint_lro_mock
yield create_endpoint_mock
Expand Down Expand Up @@ -392,19 +403,35 @@ def test_constructor(self, create_endpoint_client_mock):
]
)

def test_constructor_with_endpoint_id(self, get_endpoint_mock):
models.Endpoint(_TEST_ID)
get_endpoint_mock.assert_called_with(name=_TEST_ENDPOINT_NAME)
def test_lazy_constructor_with_endpoint_id(self, get_endpoint_mock):
ep = models.Endpoint(_TEST_ID)
assert ep._gca_resource.name == _TEST_ENDPOINT_NAME
assert ep._skipped_getter_call
assert not get_endpoint_mock.called

def test_constructor_with_endpoint_name(self, get_endpoint_mock):
models.Endpoint(_TEST_ENDPOINT_NAME)
def test_lazy_constructor_with_endpoint_name(self, get_endpoint_mock):
ep = models.Endpoint(_TEST_ENDPOINT_NAME)
assert ep._gca_resource.name == _TEST_ENDPOINT_NAME
assert ep._skipped_getter_call
assert not get_endpoint_mock.called

def test_lazy_constructor_calls_get_on_property_access(self, get_endpoint_mock):
ep = models.Endpoint(_TEST_ENDPOINT_NAME)
assert ep._gca_resource.name == _TEST_ENDPOINT_NAME
assert ep._skipped_getter_call
assert not get_endpoint_mock.called

ep.display_name # Retrieve a property that requires a call to Endpoint getter
get_endpoint_mock.assert_called_with(name=_TEST_ENDPOINT_NAME)

def test_constructor_with_custom_project(self, get_endpoint_mock):
models.Endpoint(endpoint_name=_TEST_ID, project=_TEST_PROJECT_2)
def test_lazy_constructor_with_custom_project(self, get_endpoint_mock):
ep = models.Endpoint(endpoint_name=_TEST_ID, project=_TEST_PROJECT_2)
test_endpoint_resource_name = endpoint_service_client.EndpointServiceClient.endpoint_path(
_TEST_PROJECT_2, _TEST_LOCATION, _TEST_ID
)
assert not get_endpoint_mock.called

ep.name # Retrieve a property that requires a call to Endpoint getter
get_endpoint_mock.assert_called_with(name=test_endpoint_resource_name)

@pytest.mark.usefixtures("get_endpoint_mock")
Expand All @@ -420,11 +447,19 @@ def test_constructor_with_conflicting_location(self):
regexp=r"is provided, but different from the resource location"
)

def test_constructor_with_custom_location(self, get_endpoint_alt_location_mock):
models.Endpoint(endpoint_name=_TEST_ID, location=_TEST_LOCATION_2)
def test_lazy_constructor_with_custom_location(
self, get_endpoint_alt_location_mock
):
ep = models.Endpoint(endpoint_name=_TEST_ID, location=_TEST_LOCATION_2)
test_endpoint_resource_name = endpoint_service_client.EndpointServiceClient.endpoint_path(
_TEST_PROJECT, _TEST_LOCATION_2, _TEST_ID
)

# Get Endpoint not called due to lazy loading
assert not get_endpoint_alt_location_mock.called

ep.network # Accessing a property that requires calling getter

get_endpoint_alt_location_mock.assert_called_with(
name=test_endpoint_resource_name
)
Expand Down Expand Up @@ -495,15 +530,16 @@ def test_create(self, create_endpoint_mock, sync):
)

expected_endpoint.name = _TEST_ENDPOINT_NAME
assert my_endpoint.gca_resource == expected_endpoint
assert my_endpoint.network is None
assert my_endpoint._gca_resource == expected_endpoint

@pytest.mark.usefixtures("get_endpoint_mock")
@pytest.mark.usefixtures("get_empty_endpoint_mock")
def test_accessing_properties_with_no_resource_raises(self,):
"""Ensure a descriptive RuntimeError is raised when the
GAPIC object has not been populated"""

my_endpoint = aiplatform.Endpoint(_TEST_ENDPOINT_NAME)

my_endpoint._gca_resource = None
my_endpoint._skipped_getter_call = False

with pytest.raises(RuntimeError) as e:
my_endpoint.gca_resource
Expand Down Expand Up @@ -923,7 +959,7 @@ def test_undeploy(self, undeploy_model_mock, sync):
traffic_split={"model1": 100},
)
test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
assert dict(test_endpoint._gca_resource.traffic_split) == {"model1": 100}
assert dict(test_endpoint.traffic_split) == {"model1": 100}
test_endpoint.undeploy("model1", sync=sync)
if not sync:
test_endpoint.wait()
Expand Down