Skip to content

Commit

Permalink
Added create_training_pipeline_custom_training_managed_dataset_sample…
Browse files Browse the repository at this point in the history
… and fixed unmanaged sample
  • Loading branch information
ivanmkc committed Apr 21, 2021
1 parent 94fd07b commit 4ea56bf
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 11 deletions.
Expand Up @@ -13,15 +13,16 @@
# limitations under the License.

from google.cloud import aiplatform
from typing import List


# [START aiplatform_sdk_create_training_pipeline_custom_job_sample]
def create_training_pipeline_custom_job_sample(
project: str,
display_name: str,
args: List[str],
script_path: str,
container_uri: str,
dataset_id: int,
location: str = "us-central1",
model_display_name: str = None,
training_fraction_split: float = 0.8,
Expand All @@ -35,11 +36,9 @@ def create_training_pipeline_custom_job_sample(
script_path=script_path,
container_uri=container_uri)

my_image_ds = aiplatform.ImageDataset(dataset_id)

model = job.run(
dataset=my_image_ds,
model_display_name=model_display_name,
args=args,
training_fraction_split=training_fraction_split,
validation_fraction_split=validation_fraction_split,
test_fraction_split=test_fraction_split,
Expand Down
Expand Up @@ -19,26 +19,22 @@

def test_create_training_pipeline_custom_job_sample(
mock_sdk_init,
mock_image_dataset,
mock_init_custom_training_job,
mock_run_custom_training_job,
mock_get_image_dataset,
):

create_training_pipeline_custom_job_sample.create_training_pipeline_custom_job_sample(
project=constants.PROJECT,
display_name=constants.DISPLAY_NAME,
args=constants.ARGS,
script_path=constants.SCRIPT_PATH,
container_uri=constants.CONTAINER_URI,
dataset_id=constants.RESOURCE_ID,
model_display_name=constants.DISPLAY_NAME_2,
training_fraction_split=constants.TRAINING_FRACTION_SPLIT,
validation_fraction_split=constants.VALIDATION_FRACTION_SPLIT,
test_fraction_split=constants.TEST_FRACTION_SPLIT,
)

mock_get_image_dataset.assert_called_once_with(constants.RESOURCE_ID)

mock_sdk_init.assert_called_once_with(
project=constants.PROJECT, location=constants.LOCATION
)
Expand All @@ -48,8 +44,8 @@ def test_create_training_pipeline_custom_job_sample(
container_uri=constants.CONTAINER_URI,
)
mock_run_custom_training_job.assert_called_once_with(
dataset=mock_image_dataset,
model_display_name=constants.DISPLAY_NAME_2,
args=constants.ARGS,
training_fraction_split=constants.TRAINING_FRACTION_SPLIT,
validation_fraction_split=constants.VALIDATION_FRACTION_SPLIT,
test_fraction_split=constants.TEST_FRACTION_SPLIT,
Expand Down
3 changes: 2 additions & 1 deletion samples/model-builder/test_constants.py
Expand Up @@ -54,4 +54,5 @@
PREDICTION_TEXT_INSTANCE = "This is some text for testing NLP prediction output"

SCRIPT_PATH = "task.py"
CONTAINER_URI = "gcr.io/my_project/my_image:latest"
CONTAINER_URI = "gcr.io/my_project/my_image:latest"
ARGS = ["--tfds", "tf_flowers:3.*.*"]

0 comments on commit 4ea56bf

Please sign in to comment.