diff --git a/samples/model-builder/create_training_pipeline_image_classification_sample.py b/samples/model-builder/create_training_pipeline_image_classification_sample.py index 3786894a05..615e468485 100644 --- a/samples/model-builder/create_training_pipeline_image_classification_sample.py +++ b/samples/model-builder/create_training_pipeline_image_classification_sample.py @@ -24,6 +24,7 @@ def create_training_pipeline_image_classification_sample( display_name: str, dataset_id: int, model_display_name: Optional[str] = None, + multi_label: bool = False, training_fraction_split: float = 0.8, validation_fraction_split: float = 0.1, test_fraction_split: float = 0.1, @@ -33,7 +34,11 @@ def create_training_pipeline_image_classification_sample( ): aiplatform.init(project=project, location=location) - job = aiplatform.AutoMLImageTrainingJob(display_name=display_name) + job = aiplatform.AutoMLImageTrainingJob( + display_name=display_name, + prediction_type='classification', + multi_label=multi_label + ) my_image_ds = aiplatform.ImageDataset(dataset_id) diff --git a/samples/model-builder/create_training_pipeline_image_classification_sample_test.py b/samples/model-builder/create_training_pipeline_image_classification_sample_test.py index 1c7080e7a1..c5d7e14beb 100644 --- a/samples/model-builder/create_training_pipeline_image_classification_sample_test.py +++ b/samples/model-builder/create_training_pipeline_image_classification_sample_test.py @@ -44,7 +44,9 @@ def test_create_training_pipeline_image_classification_sample( project=constants.PROJECT, location=constants.LOCATION ) mock_get_automl_image_training_job.assert_called_once_with( - display_name=constants.DISPLAY_NAME + display_name=constants.DISPLAY_NAME, + multi_label=False, + prediction_type='classification' ) mock_run_automl_image_training_job.assert_called_once_with( dataset=mock_image_dataset,