Skip to content

Commit

Permalink
fix: Add retries when polling during monitoring runs (#786)
Browse files: browse the repository at this point in the history
  • Loading branch information
sasha-gitg committed Oct 27, 2021
1 parent 78879e2 commit 45401c0
Show file tree
Hide file tree
Showing 20 changed files with 288 additions and 111 deletions.
8 changes: 7 additions & 1 deletion google/cloud/aiplatform/base.py
Expand Up @@ -39,6 +39,7 @@

import proto

from google.api_core import retry
from google.api_core import operation
from google.auth import credentials as auth_credentials
from google.cloud.aiplatform import initializer
Expand All @@ -48,6 +49,9 @@

logging.basicConfig(level=logging.INFO, stream=sys.stdout)

# Default retry policy (a google.api_core Retry object) passed as `retry=` to resource get methods.
_DEFAULT_RETRY = retry.Retry()


class Logger:
"""Logging wrapper class with high level helper methods."""
Expand Down Expand Up @@ -532,7 +536,9 @@ def _get_gca_resource(self, resource_name: str) -> proto.Message:
location=self.location,
)

return getattr(self.api_client, self._getter_method)(name=resource_name)
return getattr(self.api_client, self._getter_method)(
name=resource_name, retry=_DEFAULT_RETRY
)

def _sync_gca_resource(self):
"""Sync GAPIC service representation of client class resource."""
Expand Down
2 changes: 1 addition & 1 deletion google/cloud/aiplatform/metadata/resource.py
Expand Up @@ -94,7 +94,7 @@ def __init__(
)

self._gca_resource = getattr(self.api_client, self._getter_method)(
name=full_resource_name
name=full_resource_name, retry=base._DEFAULT_RETRY
)

@property
Expand Down
Expand Up @@ -3,6 +3,7 @@
from unittest import mock

from google.cloud import aiplatform
from google.cloud.aiplatform import base
from google.cloud.aiplatform import datasets
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import schema
Expand Down Expand Up @@ -301,7 +302,9 @@ def test_run_call_pipeline_service_create(

assert job._gca_resource is mock_pipeline_service_get.return_value

mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME)
mock_model_service_get.assert_called_once_with(
name=_TEST_MODEL_NAME, retry=base._DEFAULT_RETRY
)

assert model_from_job._gca_resource is mock_model_service_get.return_value

Expand Down
6 changes: 4 additions & 2 deletions tests/unit/aiplatform/test_automl_image_training_jobs.py
Expand Up @@ -6,7 +6,7 @@
from google.protobuf import struct_pb2

from google.cloud import aiplatform

from google.cloud.aiplatform import base
from google.cloud.aiplatform import datasets
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import models
Expand Down Expand Up @@ -309,7 +309,9 @@ def test_run_call_pipeline_service_create(
training_pipeline=true_training_pipeline,
)

mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME)
mock_model_service_get.assert_called_once_with(
name=_TEST_MODEL_NAME, retry=base._DEFAULT_RETRY
)
assert job._gca_resource is mock_pipeline_service_get.return_value
assert model_from_job._gca_resource is mock_model_service_get.return_value
assert job.get_model()._gca_resource is mock_model_service_get.return_value
Expand Down
10 changes: 7 additions & 3 deletions tests/unit/aiplatform/test_automl_tabular_training_jobs.py
Expand Up @@ -3,7 +3,7 @@
from unittest import mock

from google.cloud import aiplatform

from google.cloud.aiplatform import base
from google.cloud.aiplatform import datasets
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import schema
Expand Down Expand Up @@ -367,7 +367,9 @@ def test_run_call_pipeline_service_create(

assert job._gca_resource is mock_pipeline_service_get.return_value

mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME)
mock_model_service_get.assert_called_once_with(
name=_TEST_MODEL_NAME, retry=base._DEFAULT_RETRY
)

assert model_from_job._gca_resource is mock_model_service_get.return_value

Expand Down Expand Up @@ -446,7 +448,9 @@ def test_run_call_pipeline_service_create_with_export_eval_data_items(

assert job._gca_resource is mock_pipeline_service_get.return_value

mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME)
mock_model_service_get.assert_called_once_with(
name=_TEST_MODEL_NAME, retry=base._DEFAULT_RETRY
)

assert model_from_job._gca_resource is mock_model_service_get.return_value

Expand Down
14 changes: 10 additions & 4 deletions tests/unit/aiplatform/test_automl_text_training_jobs.py
Expand Up @@ -3,7 +3,7 @@
from unittest import mock

from google.cloud import aiplatform

from google.cloud.aiplatform import base
from google.cloud.aiplatform import datasets
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import models
Expand Down Expand Up @@ -370,7 +370,9 @@ def test_run_call_pipeline_service_create_classification(
training_pipeline=true_training_pipeline,
)

mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME)
mock_model_service_get.assert_called_once_with(
name=_TEST_MODEL_NAME, retry=base._DEFAULT_RETRY
)
assert job._gca_resource is mock_pipeline_service_get.return_value
assert model_from_job._gca_resource is mock_model_service_get.return_value
assert job.get_model()._gca_resource is mock_model_service_get.return_value
Expand Down Expand Up @@ -437,7 +439,9 @@ def test_run_call_pipeline_service_create_extraction(
training_pipeline=true_training_pipeline,
)

mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME)
mock_model_service_get.assert_called_once_with(
name=_TEST_MODEL_NAME, retry=base._DEFAULT_RETRY
)
assert job._gca_resource is mock_pipeline_service_get.return_value
assert model_from_job._gca_resource is mock_model_service_get.return_value
assert job.get_model()._gca_resource is mock_model_service_get.return_value
Expand Down Expand Up @@ -505,7 +509,9 @@ def test_run_call_pipeline_service_create_sentiment(
training_pipeline=true_training_pipeline,
)

mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME)
mock_model_service_get.assert_called_once_with(
name=_TEST_MODEL_NAME, retry=base._DEFAULT_RETRY
)
assert job._gca_resource is mock_pipeline_service_get.return_value
assert model_from_job._gca_resource is mock_model_service_get.return_value
assert job.get_model()._gca_resource is mock_model_service_get.return_value
Expand Down
10 changes: 7 additions & 3 deletions tests/unit/aiplatform/test_automl_video_training_jobs.py
Expand Up @@ -6,7 +6,7 @@
from google.protobuf import struct_pb2

from google.cloud import aiplatform

from google.cloud.aiplatform import base
from google.cloud.aiplatform import datasets
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import models
Expand Down Expand Up @@ -271,7 +271,9 @@ def test_init_aiplatform_with_encryption_key_name_and_create_training_job(
training_pipeline=true_training_pipeline,
)

mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME)
mock_model_service_get.assert_called_once_with(
name=_TEST_MODEL_NAME, retry=base._DEFAULT_RETRY
)
assert job._gca_resource is mock_pipeline_service_get.return_value
assert model_from_job._gca_resource is mock_model_service_get.return_value
assert job.get_model()._gca_resource is mock_model_service_get.return_value
Expand Down Expand Up @@ -538,7 +540,9 @@ def test_run_call_pipeline_service_create(
training_pipeline=true_training_pipeline,
)

mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME)
mock_model_service_get.assert_called_once_with(
name=_TEST_MODEL_NAME, retry=base._DEFAULT_RETRY
)
assert job._gca_resource is mock_pipeline_service_get.return_value
assert model_from_job._gca_resource is mock_model_service_get.return_value
assert job.get_model()._gca_resource is mock_model_service_get.return_value
Expand Down
5 changes: 4 additions & 1 deletion tests/unit/aiplatform/test_custom_job.py
Expand Up @@ -29,6 +29,7 @@
from test_training_jobs import mock_python_package_to_gcs # noqa: F401

from google.cloud import aiplatform
from google.cloud.aiplatform import base
from google.cloud.aiplatform.compat.types import custom_job as gca_custom_job_compat
from google.cloud.aiplatform.compat.types import (
custom_job_v1beta1 as gca_custom_job_v1beta1,
Expand Down Expand Up @@ -447,7 +448,9 @@ def test_get_custom_job(self, get_custom_job_mock):

job = aiplatform.CustomJob.get(_TEST_CUSTOM_JOB_NAME)

get_custom_job_mock.assert_called_once_with(name=_TEST_CUSTOM_JOB_NAME)
get_custom_job_mock.assert_called_once_with(
name=_TEST_CUSTOM_JOB_NAME, retry=base._DEFAULT_RETRY
)
assert (
job._gca_resource.state == gca_job_state_compat.JobState.JOB_STATE_PENDING
)
Expand Down
66 changes: 48 additions & 18 deletions tests/unit/aiplatform/test_datasets.py
Expand Up @@ -28,13 +28,13 @@
from google.auth import credentials as auth_credentials

from google.cloud import aiplatform
from google.cloud import bigquery
from google.cloud import storage

from google.cloud.aiplatform import base
from google.cloud.aiplatform import compat
from google.cloud.aiplatform import datasets
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import schema
from google.cloud import bigquery
from google.cloud import storage

from google.cloud.aiplatform_v1.services.dataset_service import (
client as dataset_service_client,
Expand Down Expand Up @@ -474,7 +474,9 @@ def teardown_method(self):
def test_init_dataset(self, get_dataset_mock):
aiplatform.init(project=_TEST_PROJECT)
datasets._Dataset(dataset_name=_TEST_NAME)
get_dataset_mock.assert_called_once_with(name=_TEST_NAME)
get_dataset_mock.assert_called_once_with(
name=_TEST_NAME, retry=base._DEFAULT_RETRY
)

def test_init_dataset_with_id_only_with_project_and_location(
self, get_dataset_mock
Expand All @@ -483,21 +485,27 @@ def test_init_dataset_with_id_only_with_project_and_location(
datasets._Dataset(
dataset_name=_TEST_ID, project=_TEST_PROJECT, location=_TEST_LOCATION
)
get_dataset_mock.assert_called_once_with(name=_TEST_NAME)
get_dataset_mock.assert_called_once_with(
name=_TEST_NAME, retry=base._DEFAULT_RETRY
)

def test_init_dataset_with_project_and_location(self, get_dataset_mock):
aiplatform.init(project=_TEST_PROJECT)
datasets._Dataset(
dataset_name=_TEST_NAME, project=_TEST_PROJECT, location=_TEST_LOCATION
)
get_dataset_mock.assert_called_once_with(name=_TEST_NAME)
get_dataset_mock.assert_called_once_with(
name=_TEST_NAME, retry=base._DEFAULT_RETRY
)

def test_init_dataset_with_alt_project_and_location(self, get_dataset_mock):
aiplatform.init(project=_TEST_PROJECT)
datasets._Dataset(
dataset_name=_TEST_NAME, project=_TEST_ALT_PROJECT, location=_TEST_LOCATION
)
get_dataset_mock.assert_called_once_with(name=_TEST_NAME)
get_dataset_mock.assert_called_once_with(
name=_TEST_NAME, retry=base._DEFAULT_RETRY
)

def test_init_dataset_with_alt_location(self, get_dataset_tabular_gcs_mock):
aiplatform.init(project=_TEST_PROJECT, location=_TEST_ALT_LOCATION)
Expand All @@ -511,7 +519,9 @@ def test_init_dataset_with_alt_location(self, get_dataset_tabular_gcs_mock):

assert _TEST_ALT_LOCATION != _TEST_LOCATION

get_dataset_tabular_gcs_mock.assert_called_once_with(name=_TEST_NAME)
get_dataset_tabular_gcs_mock.assert_called_once_with(
name=_TEST_NAME, retry=base._DEFAULT_RETRY
)

def test_init_dataset_with_project_and_alt_location(self):
aiplatform.init(project=_TEST_PROJECT)
Expand All @@ -525,7 +535,9 @@ def test_init_dataset_with_project_and_alt_location(self):
def test_init_dataset_with_id_only(self, get_dataset_mock):
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
datasets._Dataset(dataset_name=_TEST_ID)
get_dataset_mock.assert_called_once_with(name=_TEST_NAME)
get_dataset_mock.assert_called_once_with(
name=_TEST_NAME, retry=base._DEFAULT_RETRY
)

@pytest.mark.usefixtures("get_dataset_without_name_mock")
@patch.dict(
Expand All @@ -541,7 +553,9 @@ def test_init_dataset_with_id_only_without_project_or_location(self):
def test_init_dataset_with_location_override(self, get_dataset_mock):
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
datasets._Dataset(dataset_name=_TEST_ID, location=_TEST_ALT_LOCATION)
get_dataset_mock.assert_called_once_with(name=_TEST_ALT_NAME)
get_dataset_mock.assert_called_once_with(
name=_TEST_ALT_NAME, retry=base._DEFAULT_RETRY
)

@pytest.mark.usefixtures("get_dataset_mock")
def test_init_dataset_with_invalid_name(self):
Expand Down Expand Up @@ -764,7 +778,9 @@ def test_create_then_import(
metadata=_TEST_REQUEST_METADATA,
)

get_dataset_mock.assert_called_once_with(name=_TEST_NAME)
get_dataset_mock.assert_called_once_with(
name=_TEST_NAME, retry=base._DEFAULT_RETRY
)

import_data_mock.assert_called_once_with(
name=_TEST_NAME, import_configs=[expected_import_config]
Expand Down Expand Up @@ -798,7 +814,9 @@ def teardown_method(self):
def test_init_dataset_image(self, get_dataset_image_mock):
aiplatform.init(project=_TEST_PROJECT)
datasets.ImageDataset(dataset_name=_TEST_NAME)
get_dataset_image_mock.assert_called_once_with(name=_TEST_NAME)
get_dataset_image_mock.assert_called_once_with(
name=_TEST_NAME, retry=base._DEFAULT_RETRY
)

@pytest.mark.usefixtures("get_dataset_tabular_bq_mock")
def test_init_dataset_non_image(self):
Expand Down Expand Up @@ -934,7 +952,9 @@ def test_create_then_import(
metadata=_TEST_REQUEST_METADATA,
)

get_dataset_image_mock.assert_called_once_with(name=_TEST_NAME)
get_dataset_image_mock.assert_called_once_with(
name=_TEST_NAME, retry=base._DEFAULT_RETRY
)

expected_import_config = gca_dataset.ImportDataConfig(
gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]),
Expand Down Expand Up @@ -989,7 +1009,9 @@ def teardown_method(self):
def test_init_dataset_tabular(self, get_dataset_tabular_bq_mock):

datasets.TabularDataset(dataset_name=_TEST_NAME)
get_dataset_tabular_bq_mock.assert_called_once_with(name=_TEST_NAME)
get_dataset_tabular_bq_mock.assert_called_once_with(
name=_TEST_NAME, retry=base._DEFAULT_RETRY
)

@pytest.mark.usefixtures("get_dataset_image_mock")
def test_init_dataset_non_tabular(self):
Expand Down Expand Up @@ -1236,7 +1258,9 @@ def teardown_method(self):
def test_init_dataset_text(self, get_dataset_text_mock):
aiplatform.init(project=_TEST_PROJECT)
datasets.TextDataset(dataset_name=_TEST_NAME)
get_dataset_text_mock.assert_called_once_with(name=_TEST_NAME)
get_dataset_text_mock.assert_called_once_with(
name=_TEST_NAME, retry=base._DEFAULT_RETRY
)

@pytest.mark.usefixtures("get_dataset_image_mock")
def test_init_dataset_non_text(self):
Expand Down Expand Up @@ -1409,7 +1433,9 @@ def test_create_then_import(
metadata=_TEST_REQUEST_METADATA,
)

get_dataset_text_mock.assert_called_once_with(name=_TEST_NAME)
get_dataset_text_mock.assert_called_once_with(
name=_TEST_NAME, retry=base._DEFAULT_RETRY
)

expected_import_config = gca_dataset.ImportDataConfig(
gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]),
Expand Down Expand Up @@ -1463,7 +1489,9 @@ def teardown_method(self):
def test_init_dataset_video(self, get_dataset_video_mock):
aiplatform.init(project=_TEST_PROJECT)
datasets.VideoDataset(dataset_name=_TEST_NAME)
get_dataset_video_mock.assert_called_once_with(name=_TEST_NAME)
get_dataset_video_mock.assert_called_once_with(
name=_TEST_NAME, retry=base._DEFAULT_RETRY
)

@pytest.mark.usefixtures("get_dataset_tabular_bq_mock")
def test_init_dataset_non_video(self):
Expand Down Expand Up @@ -1599,7 +1627,9 @@ def test_create_then_import(
metadata=_TEST_REQUEST_METADATA,
)

get_dataset_video_mock.assert_called_once_with(name=_TEST_NAME)
get_dataset_video_mock.assert_called_once_with(
name=_TEST_NAME, retry=base._DEFAULT_RETRY
)

expected_import_config = gca_dataset.ImportDataConfig(
gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]),
Expand Down

0 comments on commit 45401c0

Please sign in to comment.