From cc1a7084f7715c94657d5a3b3374c0fc9a86a299 Mon Sep 17 00:00:00 2001 From: Andrew Ferlitsch Date: Wed, 12 May 2021 14:08:19 -0700 Subject: [PATCH] feat: Add AutoML vision, Custom training job, and generic prediction samples (#300) * debug mock issue * new mock * more samples * more samples * add next sample/test * add sample/test * run black * Add new Dataset import mocks, fix MBSDK sample tests * Add license headers, update Endpoint mocks/usage * type updates * sasha comment fixes * fix test errors after review update * fix: type for instances * Lint SDK samples * Fix flake8 import order nits Co-authored-by: Vinny Senthil Co-authored-by: sasha-gitg <44654632+sasha-gitg@users.noreply.github.com> --- ...mage_classification_training_job_sample.py | 54 ++++++++++++++++++ ...classification_training_job_sample_test.py | 56 +++++++++++++++++++ samples/model-builder/conftest.py | 19 +++++++ .../custom_training_job_sample.py | 49 ++++++++++++++++ .../custom_training_job_sample_test.py | 50 +++++++++++++++++ .../model-builder/endpoint_predict_sample.py | 32 +++++++++++ .../endpoint_predict_sample_test.py | 37 ++++++++++++ ...ge_dataset_create_classification_sample.py | 38 +++++++++++++ ...taset_create_classification_sample_test.py | 40 +++++++++++++ ..._dataset_create_object_detection_sample.py | 38 +++++++++++++ ...set_create_object_detection_sample_test.py | 41 ++++++++++++++ .../image_dataset_create_sample.py | 31 ++++++++++ .../image_dataset_create_sample_test.py | 32 +++++++++++ .../image_dataset_import_data_sample.py | 37 ++++++++++++ .../image_dataset_import_data_sample_test.py | 40 +++++++++++++ samples/model-builder/test_constants.py | 5 ++ 16 files changed, 599 insertions(+) create mode 100644 samples/model-builder/automl_image_classification_training_job_sample.py create mode 100644 samples/model-builder/automl_image_classification_training_job_sample_test.py create mode 100644 samples/model-builder/custom_training_job_sample.py create mode 100644 samples/model-builder/custom_training_job_sample_test.py create mode 100644 samples/model-builder/endpoint_predict_sample.py create mode 100644 samples/model-builder/endpoint_predict_sample_test.py create mode 100644 samples/model-builder/image_dataset_create_classification_sample.py create mode 100644 samples/model-builder/image_dataset_create_classification_sample_test.py create mode 100644 samples/model-builder/image_dataset_create_object_detection_sample.py create mode 100644 samples/model-builder/image_dataset_create_object_detection_sample_test.py create mode 100644 samples/model-builder/image_dataset_create_sample.py create mode 100644 samples/model-builder/image_dataset_create_sample_test.py create mode 100644 samples/model-builder/image_dataset_import_data_sample.py create mode 100644 samples/model-builder/image_dataset_import_data_sample_test.py diff --git a/samples/model-builder/automl_image_classification_training_job_sample.py b/samples/model-builder/automl_image_classification_training_job_sample.py new file mode 100644 index 0000000000..502caf008d --- /dev/null +++ b/samples/model-builder/automl_image_classification_training_job_sample.py @@ -0,0 +1,54 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_automl_image_classification_training_job_sample] +def automl_image_classification_training_job_sample( + project: str, location: str, dataset_id: str, display_name: str, +): + aiplatform.init(project=project, location=location) + + dataset = aiplatform.ImageDataset(dataset_id) + + job = aiplatform.AutoMLImageTrainingJob( + display_name=display_name, + prediction_type="classification", + multi_label=False, + model_type="CLOUD", + base_model=None, + ) + + model = job.run( + dataset=dataset, + model_display_name=display_name, + training_fraction_split=0.6, + validation_fraction_split=0.2, + test_fraction_split=0.2, + budget_milli_node_hours=8000, + disable_early_stopping=False, + ) + + print(model.display_name) + print(model.name) + print(model.resource_name) + print(model.description) + print(model.uri) + + return model + + +# [END aiplatform_sdk_automl_image_classification_training_job_sample] diff --git a/samples/model-builder/automl_image_classification_training_job_sample_test.py b/samples/model-builder/automl_image_classification_training_job_sample_test.py new file mode 100644 index 0000000000..a402340f77 --- /dev/null +++ b/samples/model-builder/automl_image_classification_training_job_sample_test.py @@ -0,0 +1,56 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import automl_image_classification_training_job_sample +import test_constants as constants + + +def test_automl_image_classification_training_job_sample( + mock_sdk_init, + mock_image_dataset, + mock_get_image_dataset, + mock_get_automl_image_training_job, + mock_run_automl_image_training_job, +): + automl_image_classification_training_job_sample.automl_image_classification_training_job_sample( + project=constants.PROJECT, + location=constants.LOCATION, + dataset_id=constants.DATASET_NAME, + display_name=constants.DISPLAY_NAME, + ) + + mock_get_image_dataset.assert_called_once_with(constants.DATASET_NAME) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + mock_get_automl_image_training_job.assert_called_once_with( + display_name=constants.DISPLAY_NAME, + base_model=None, + model_type="CLOUD", + multi_label=False, + prediction_type="classification", + ) + + mock_run_automl_image_training_job.assert_called_once_with( + budget_milli_node_hours=8000, + disable_early_stopping=False, + test_fraction_split=0.2, + training_fraction_split=0.6, + validation_fraction_split=0.2, + model_display_name=constants.DISPLAY_NAME, + dataset=mock_image_dataset, + ) diff --git a/samples/model-builder/conftest.py b/samples/model-builder/conftest.py index 01756f668b..d8c2ed239d 100644 --- a/samples/model-builder/conftest.py +++ b/samples/model-builder/conftest.py @@ -124,6 +124,18 @@ def mock_create_video_dataset(mock_video_dataset): """Mocks for SomeDataset.import_data() """ +@pytest.fixture +def mock_import_image_dataset(mock_image_dataset): + with patch.object(mock_image_dataset, "import_data") as mock: + yield mock + + +@pytest.fixture +def mock_import_tabular_dataset(mock_tabular_dataset): + with patch.object(mock_tabular_dataset, "import_data") as mock: + yield mock + + @pytest.fixture def mock_import_text_dataset(mock_text_dataset): with patch.object(mock_text_dataset, "import_data") as mock: @@ -327,6 +339,13 @@ def mock_get_endpoint(mock_endpoint): yield mock_get_endpoint +@pytest.fixture +def mock_endpoint_predict(mock_endpoint): + with patch.object(mock_endpoint, "predict") as mock: + mock.return_value = [] + yield mock + + @pytest.fixture def mock_endpoint_explain(mock_endpoint): with patch.object(mock_endpoint, "explain") as mock_endpoint_explain: diff --git a/samples/model-builder/custom_training_job_sample.py b/samples/model-builder/custom_training_job_sample.py new file mode 100644 index 0000000000..14c874e3a5 --- /dev/null +++ b/samples/model-builder/custom_training_job_sample.py @@ -0,0 +1,49 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_custom_training_job_sample] +def custom_training_job_sample( + project: str, + location: str, + bucket: str, + display_name: str, + script_path: str, + script_args: str, + container_uri: str, + model_serving_container_image_uri: str, + requirements: str, + replica_count: int, +): + aiplatform.init(project=project, location=location, staging_bucket=bucket) + + job = aiplatform.CustomTrainingJob( + display_name=display_name, + script_path=script_path, + container_uri=container_uri, + requirements=requirements, + model_serving_container_image_uri=model_serving_container_image_uri, + ) + + model = job.run( + args=script_args, replica_count=replica_count, model_display_name=display_name + ) + + return model + + +# [END aiplatform_sdk_custom_training_job_sample] diff --git a/samples/model-builder/custom_training_job_sample_test.py b/samples/model-builder/custom_training_job_sample_test.py new file mode 100644 index 0000000000..40d12fb332 --- /dev/null +++ b/samples/model-builder/custom_training_job_sample_test.py @@ -0,0 +1,50 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import custom_training_job_sample +import test_constants as constants + + +def test_custom_training_job_sample( + mock_sdk_init, mock_get_custom_training_job, mock_run_custom_training_job +): + custom_training_job_sample.custom_training_job_sample( + project=constants.PROJECT, + location=constants.LOCATION, + bucket=constants.STAGING_BUCKET, + display_name=constants.DISPLAY_NAME, + script_path=constants.PYTHON_PACKAGE, + script_args=constants.PYTHON_PACKAGE_CMDARGS, + container_uri=constants.TRAIN_IMAGE, + model_serving_container_image_uri=constants.DEPLOY_IMAGE, + requirements=[], + replica_count=1, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, + location=constants.LOCATION, + staging_bucket=constants.STAGING_BUCKET, + ) + + mock_get_custom_training_job.assert_called_once_with( + display_name=constants.DISPLAY_NAME, + container_uri=constants.TRAIN_IMAGE, + model_serving_container_image_uri=constants.DEPLOY_IMAGE, + requirements=[], + script_path=constants.PYTHON_PACKAGE, + ) + + mock_run_custom_training_job.assert_called_once() diff --git a/samples/model-builder/endpoint_predict_sample.py b/samples/model-builder/endpoint_predict_sample.py new file mode 100644 index 0000000000..98b7450c51 --- /dev/null +++ b/samples/model-builder/endpoint_predict_sample.py @@ -0,0 +1,32 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_endpoint_predict_sample] +def endpoint_predict_sample( + project: str, location: str, instances: list, endpoint: str +): + aiplatform.init(project=project, location=location) + + endpoint = aiplatform.Endpoint(endpoint) + + prediction = endpoint.predict(instances=instances) + print(prediction) + return prediction + + +# [END aiplatform_sdk_endpoint_predict_sample] diff --git a/samples/model-builder/endpoint_predict_sample_test.py b/samples/model-builder/endpoint_predict_sample_test.py new file mode 100644 index 0000000000..8c2d4e8e10 --- /dev/null +++ b/samples/model-builder/endpoint_predict_sample_test.py @@ -0,0 +1,37 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import endpoint_predict_sample +import test_constants as constants + + +def test_endpoint_predict_sample( + mock_sdk_init, mock_endpoint_predict, mock_get_endpoint +): + + endpoint_predict_sample.endpoint_predict_sample( + project=constants.PROJECT, + location=constants.LOCATION, + instances=[], + endpoint=constants.ENDPOINT_NAME, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + mock_get_endpoint.assert_called_once_with(constants.ENDPOINT_NAME) + + mock_endpoint_predict.assert_called_once_with(instances=[]) diff --git a/samples/model-builder/image_dataset_create_classification_sample.py b/samples/model-builder/image_dataset_create_classification_sample.py new file mode 100644 index 0000000000..ca53cfb7d2 --- /dev/null +++ b/samples/model-builder/image_dataset_create_classification_sample.py @@ -0,0 +1,38 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_image_dataset_create_classification_sample] +def image_dataset_create_classification_sample( + project: str, location: str, display_name: str, src_uris: list +): + aiplatform.init(project=project, location=location) + + ds = aiplatform.ImageDataset.create( + display_name=display_name, + gcs_source=src_uris, + import_schema_uri=aiplatform.schema.dataset.ioformat.image.single_label_classification, + ) + + print(ds.display_name) + print(ds.name) + print(ds.resource_name) + print(ds.metadata_schema_uri) + return ds + + +# [END aiplatform_sdk_image_dataset_create_classification_sample] diff --git a/samples/model-builder/image_dataset_create_classification_sample_test.py b/samples/model-builder/image_dataset_create_classification_sample_test.py new file mode 100644 index 0000000000..0627d26339 --- /dev/null +++ b/samples/model-builder/image_dataset_create_classification_sample_test.py @@ -0,0 +1,40 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from google.cloud.aiplatform import schema + +import image_dataset_create_classification_sample + +import test_constants as constants + + +def test_image_dataset_create_classification_sample( + mock_sdk_init, mock_create_image_dataset +): + image_dataset_create_classification_sample.image_dataset_create_classification_sample( + project=constants.PROJECT, + location=constants.LOCATION, + src_uris=constants.GCS_SOURCES, + display_name=constants.DISPLAY_NAME, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + mock_create_image_dataset.assert_called_once_with( + display_name=constants.DISPLAY_NAME, + gcs_source=constants.GCS_SOURCES, + import_schema_uri=schema.dataset.ioformat.image.single_label_classification, + ) diff --git a/samples/model-builder/image_dataset_create_object_detection_sample.py b/samples/model-builder/image_dataset_create_object_detection_sample.py new file mode 100644 index 0000000000..cdcdca009e --- /dev/null +++ b/samples/model-builder/image_dataset_create_object_detection_sample.py @@ -0,0 +1,38 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_image_dataset_create_object_detection_sample] +def image_dataset_create_object_detection_sample( + project: str, location: str, display_name: str, src_uris: list +): + aiplatform.init(project=project, location=location) + + ds = aiplatform.ImageDataset.create( + display_name=display_name, + gcs_source=src_uris, + import_schema_uri=aiplatform.schema.dataset.ioformat.image.bounding_box, + ) + + print(ds.display_name) + print(ds.name) + print(ds.resource_name) + print(ds.metadata_schema_uri) + return ds + + +# [END aiplatform_sdk_image_dataset_create_object_detection_sample] diff --git a/samples/model-builder/image_dataset_create_object_detection_sample_test.py b/samples/model-builder/image_dataset_create_object_detection_sample_test.py new file mode 100644 index 0000000000..722a0e2a20 --- /dev/null +++ b/samples/model-builder/image_dataset_create_object_detection_sample_test.py @@ -0,0 +1,41 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from google.cloud.aiplatform import schema + +import image_dataset_create_object_detection_sample + +import test_constants as constants + + +def test_image_dataset_create_object_detection_sample( + mock_sdk_init, mock_create_image_dataset +): + image_dataset_create_object_detection_sample.image_dataset_create_object_detection_sample( + project=constants.PROJECT, + location=constants.LOCATION, + src_uris=constants.GCS_SOURCES, + display_name=constants.DISPLAY_NAME, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + mock_create_image_dataset.assert_called_once_with( + display_name=constants.DISPLAY_NAME, + gcs_source=constants.GCS_SOURCES, + import_schema_uri=schema.dataset.ioformat.image.bounding_box, + ) diff --git a/samples/model-builder/image_dataset_create_sample.py b/samples/model-builder/image_dataset_create_sample.py new file mode 100644 index 0000000000..d5821ff7da --- /dev/null +++ b/samples/model-builder/image_dataset_create_sample.py @@ -0,0 +1,31 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_image_dataset_create_sample] +def image_dataset_create_sample(project: str, location: str, display_name: str): + aiplatform.init(project=project, location=location) + + ds = aiplatform.ImageDataset.create(display_name=display_name) + + print(ds.display_name) + print(ds.name) + print(ds.resource_name) + return ds + + +# [END aiplatform_sdk_image_dataset_create_sample] diff --git a/samples/model-builder/image_dataset_create_sample_test.py b/samples/model-builder/image_dataset_create_sample_test.py new file mode 100644 index 0000000000..9d04536184 --- /dev/null +++ b/samples/model-builder/image_dataset_create_sample_test.py @@ -0,0 +1,32 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import image_dataset_create_sample +import test_constants as constants + + +def test_image_dataset_create_sample(mock_sdk_init, mock_create_image_dataset): + image_dataset_create_sample.image_dataset_create_sample( + project=constants.PROJECT, + location=constants.LOCATION, + display_name=constants.DISPLAY_NAME, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + mock_create_image_dataset.assert_called_once_with( + display_name=constants.DISPLAY_NAME + ) diff --git a/samples/model-builder/image_dataset_import_data_sample.py b/samples/model-builder/image_dataset_import_data_sample.py new file mode 100644 index 0000000000..40ca2c75a8 --- /dev/null +++ b/samples/model-builder/image_dataset_import_data_sample.py @@ -0,0 +1,37 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_image_dataset_import_data_sample] +def image_dataset_import_data_sample( + project: str, location: str, src_uris: list, import_schema_uri: str, dataset_id: str +): + aiplatform.init(project=project, location=location) + + ds = aiplatform.ImageDataset(dataset_id) + + ds = ds.import_data( + gcs_source=src_uris, import_schema_uri=import_schema_uri, sync=True + ) + + print(ds.display_name) + print(ds.name) + print(ds.resource_name) + return ds + + +# [END aiplatform_sdk_image_dataset_import_data_sample] diff --git a/samples/model-builder/image_dataset_import_data_sample_test.py b/samples/model-builder/image_dataset_import_data_sample_test.py new file mode 100644 index 0000000000..e237b115f3 --- /dev/null +++ b/samples/model-builder/image_dataset_import_data_sample_test.py @@ -0,0 +1,40 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import image_dataset_import_data_sample + +import test_constants as constants + + +def test_image_dataset_import_data_sample( + mock_sdk_init, mock_import_image_dataset, mock_get_image_dataset +): + + image_dataset_import_data_sample.image_dataset_import_data_sample( + project=constants.PROJECT, + location=constants.LOCATION, + src_uris=constants.GCS_SOURCES, + import_schema_uri=None, + dataset_id=constants.DATASET_NAME, + ) + + mock_get_image_dataset.assert_called_once_with(constants.DATASET_NAME) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + mock_import_image_dataset.assert_called_once_with( + gcs_source=constants.GCS_SOURCES, import_schema_uri=None, sync=True + ) diff --git a/samples/model-builder/test_constants.py b/samples/model-builder/test_constants.py index 69da01dbd8..aa92434b95 100644 --- a/samples/model-builder/test_constants.py +++ b/samples/model-builder/test_constants.py @@ -54,6 +54,11 @@ ENCRYPTION_SPEC_KEY_NAME = f"{PARENT}/keyRings/{RESOURCE_ID}/cryptoKeys/{RESOURCE_ID_2}" +PYTHON_PACKAGE = "gs://my-packages/training.tar.gz" +PYTHON_PACKAGE_CMDARGS = f"--model-dir={GCS_DESTINATION}" +TRAIN_IMAGE = "gcr.io/train_image:latest" +DEPLOY_IMAGE = "gcr.io/deploy_image:latest" + PREDICTION_TEXT_INSTANCE = "This is some text for testing NLP prediction output" PREDICTION_TABULAR_CLASSIFICATION_INSTANCE = [