Skip to content

Commit

Permalink
Fixed mocks
Browse files Browse the repository at this point in the history
  • Loading branch information
ivanmkc committed May 4, 2021
1 parent 83b3268 commit 8b5a3f4
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 31 deletions.
48 changes: 27 additions & 21 deletions samples/model-builder/conftest.py
Expand Up @@ -143,6 +143,18 @@ def mock_custom_training_job():
yield mock


@pytest.fixture
def mock_custom_container_training_job():
mock = MagicMock(aiplatform.training_jobs.CustomContainerTrainingJob)
yield mock


@pytest.fixture
def mock_custom_package_training_job():
mock = MagicMock(aiplatform.training_jobs.CustomPythonPackageTrainingJob)
yield mock


@pytest.fixture
def mock_image_training_job():
mock = MagicMock(aiplatform.training_jobs.AutoMLImageTrainingJob)
Expand Down Expand Up @@ -194,47 +206,41 @@ def mock_run_automl_image_training_job(mock_image_training_job):


@pytest.fixture
def mock_init_custom_training_job():
with patch.object(aiplatform.CustomTrainingJob, "__init__") as mock:
mock.return_value = None
def mock_get_custom_training_job(mock_custom_training_job):
with patch.object(aiplatform, "CustomTrainingJob") as mock:
mock.return_value = mock_custom_training_job
yield mock


@pytest.fixture
def mock_run_custom_training_job():
with patch.object(aiplatform.CustomTrainingJob, "run") as mock:
def mock_get_custom_container_training_job(mock_custom_container_training_job):
with patch.object(aiplatform, "CustomContainerTrainingJob") as mock:
mock.return_value = mock_custom_container_training_job
yield mock


@pytest.fixture
def mock_init_custom_container_training_job():
with patch.object(
aiplatform.training_jobs.CustomContainerTrainingJob, "__init__"
) as mock:
mock.return_value = None
def mock_get_custom_package_training_job(mock_custom_package_training_job):
with patch.object(aiplatform, "CustomPythonPackageTrainingJob") as mock:
mock.return_value = mock_custom_package_training_job
yield mock


@pytest.fixture
def mock_run_custom_container_training_job():
with patch.object(aiplatform.CustomContainerTrainingJob, "run") as mock:
def mock_run_custom_training_job(mock_custom_training_job):
with patch.object(mock_custom_training_job, "run") as mock:
yield mock


@pytest.fixture
def mock_init_custom_package_training_job():
with patch.object(
aiplatform.training_jobs.CustomPythonPackageTrainingJob, "__init__"
) as mock:
mock.return_value = None
def mock_run_custom_container_training_job(mock_custom_container_training_job):
with patch.object(mock_custom_container_training_job, "run") as mock:
yield mock


@pytest.fixture
def mock_run_custom_package_training_job():
with patch.object(
aiplatform.training_jobs.CustomPythonPackageTrainingJob, "run"
) as mock:
def mock_run_custom_package_training_job(mock_custom_package_training_job):
with patch.object(mock_custom_package_training_job, "run") as mock:
yield mock


Expand Down
Expand Up @@ -21,7 +21,7 @@ def test_create_training_pipeline_custom_container_job_sample(
mock_sdk_init,
mock_image_dataset,
mock_get_image_dataset,
mock_init_custom_container_training_job,
mock_get_custom_container_training_job,
mock_run_custom_container_training_job,
):

Expand Down Expand Up @@ -50,7 +50,7 @@ def test_create_training_pipeline_custom_container_job_sample(
staging_bucket=constants.STAGING_BUCKET,
)

mock_init_custom_container_training_job.assert_called_once_with(
mock_get_custom_container_training_job.assert_called_once_with(
display_name=constants.DISPLAY_NAME,
container_uri=constants.CONTAINER_URI,
model_serving_container_image_uri=constants.CONTAINER_URI,
Expand Down
Expand Up @@ -21,6 +21,7 @@
def create_training_pipeline_custom_job_sample(
project: str,
location: str,
staging_bucket: str,
display_name: str,
script_path: str,
container_uri: str,
Expand All @@ -37,7 +38,7 @@ def create_training_pipeline_custom_job_sample(
test_fraction_split: float = 0.1,
sync: bool = True,
):
aiplatform.init(project=project, location=location)
aiplatform.init(project=project, location=location, staging_bucket=staging_bucket)

job = aiplatform.CustomTrainingJob(
display_name=display_name,
Expand Down
Expand Up @@ -21,13 +21,14 @@ def test_create_training_pipeline_custom_job_sample(
mock_sdk_init,
mock_image_dataset,
mock_get_image_dataset,
mock_init_custom_training_job,
mock_get_custom_training_job,
mock_run_custom_training_job,
):

create_training_pipeline_custom_job_sample.create_training_pipeline_custom_job_sample(
project=constants.PROJECT,
location=constants.LOCATION,
staging_bucket=constants.STAGING_BUCKET,
display_name=constants.DISPLAY_NAME,
args=constants.ARGS,
script_path=constants.SCRIPT_PATH,
Expand All @@ -45,9 +46,11 @@ def test_create_training_pipeline_custom_job_sample(
)

mock_sdk_init.assert_called_once_with(
project=constants.PROJECT, location=constants.LOCATION
project=constants.PROJECT,
location=constants.LOCATION,
staging_bucket=constants.STAGING_BUCKET,
)
mock_init_custom_training_job.assert_called_once_with(
mock_get_custom_training_job.assert_called_once_with(
display_name=constants.DISPLAY_NAME,
script_path=constants.SCRIPT_PATH,
container_uri=constants.CONTAINER_URI,
Expand Down
Expand Up @@ -21,7 +21,7 @@ def test_create_training_pipeline_custom_package_job_sample(
mock_sdk_init,
mock_image_dataset,
mock_get_image_dataset,
mock_init_custom_package_training_job,
mock_get_custom_package_training_job,
mock_run_custom_package_training_job,
):

Expand Down Expand Up @@ -52,7 +52,7 @@ def test_create_training_pipeline_custom_package_job_sample(
staging_bucket=constants.STAGING_BUCKET,
)

mock_init_custom_package_training_job.assert_called_once_with(
mock_get_custom_package_training_job.assert_called_once_with(
display_name=constants.DISPLAY_NAME,
python_package_gcs_uri=constants.PYTHON_PACKAGE_GCS_URI,
python_module_name=constants.PYTHON_MODULE_NAME,
Expand Down
2 changes: 0 additions & 2 deletions samples/model-builder/test_constants.py
Expand Up @@ -159,8 +159,6 @@
"--file_system_poll_wait_seconds=31540000",
],
)
MODEL_SERVING_CONTAINER_PREDICT_ROUTE = (f"/v1/models/{MODEL_NAME}:predict",)
MODEL_SERVING_CONTAINER_HEALTH_ROUTE = f"/v1/models/{MODEL_NAME}"
PYTHON_PACKAGE_GCS_URI = (
"gs://bucket3/custom-training-python-package/my_app/trainer-0.1.tar.gz"
)
Expand Down

0 comments on commit 8b5a3f4

Please sign in to comment.