Skip to content

Commit

Permalink
Added more args to samples
Browse files Browse the repository at this point in the history
  • Loading branch information
ivanmkc committed Apr 21, 2021
1 parent 75505d1 commit fec3e6c
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 11 deletions.
Expand Up @@ -19,13 +19,17 @@
# [START aiplatform_sdk_create_training_pipeline_custom_job_sample]
def create_training_pipeline_custom_job_sample(
project: str,
location: str,
display_name: str,
script_path: str,
container_uri: str,
model_serving_container_image_uri: str,
args: Optional[List[Union[str, float, int]]] = None,
location: str = "us-central1",
model_serving_container_image_uri: str,
model_display_name: Optional[str] = None,
args: Optional[List[Union[str, float, int]]] = None,
replica_count: int = 0,
machine_type: str = "n1-standard-4",
accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED",
accelerator_count: int = 0,
training_fraction_split: float = 0.8,
validation_fraction_split: float = 0.1,
test_fraction_split: float = 0.1,
Expand All @@ -41,6 +45,10 @@ def create_training_pipeline_custom_job_sample(
model = job.run(
model_display_name=model_display_name,
args=args,
replica_count=replica_count,
machine_type=machine_type,
accelerator_type=accelerator_type,
accelerator_count=accelerator_count,
training_fraction_split=training_fraction_split,
validation_fraction_split=validation_fraction_split,
test_fraction_split=test_fraction_split,
Expand Down
Expand Up @@ -25,12 +25,17 @@ def test_create_training_pipeline_custom_job_sample(

create_training_pipeline_custom_job_sample.create_training_pipeline_custom_job_sample(
project=constants.PROJECT,
location=constants.LOCATION,
display_name=constants.DISPLAY_NAME,
args=constants.ARGS,
script_path=constants.SCRIPT_PATH,
container_uri=constants.CONTAINER_URI,
model_serving_container_image_uri=constants.CONTAINER_URI,
model_display_name=constants.DISPLAY_NAME_2,
replica_count=constants.REPLICA_COUNT,
machine_type=constants.MACHINE_TYPE,
accelerator_type=constants.ACCELERATOR_TYPE,
accelerator_count=constants.ACCELERATOR_COUNT,
training_fraction_split=constants.TRAINING_FRACTION_SPLIT,
validation_fraction_split=constants.VALIDATION_FRACTION_SPLIT,
test_fraction_split=constants.TEST_FRACTION_SPLIT,
Expand All @@ -47,6 +52,10 @@ def test_create_training_pipeline_custom_job_sample(
)
mock_run_custom_training_job.assert_called_once_with(
model_display_name=constants.DISPLAY_NAME_2,
replica_count=constants.REPLICA_COUNT,
machine_type=constants.MACHINE_TYPE,
accelerator_type=constants.ACCELERATOR_TYPE,
accelerator_count=constants.ACCELERATOR_COUNT,
args=constants.ARGS,
training_fraction_split=constants.TRAINING_FRACTION_SPLIT,
validation_fraction_split=constants.VALIDATION_FRACTION_SPLIT,
Expand Down
Expand Up @@ -18,14 +18,18 @@
# [START aiplatform_sdk_create_training_pipeline_custom_job_sample]
def create_training_pipeline_custom_training_managed_dataset_sample(
project: str,
location: str,
display_name: str,
script_path: str,
container_uri: str,
model_serving_container_image_uri: str,
dataset_id: int,
args: Optional[List[Union[str, float, int]]] = None,
location: str = "us-central1",
dataset_id: int,
model_display_name: Optional[str] = None,
args: Optional[List[Union[str, float, int]]] = None,
replica_count: int = 0,
machine_type: str = "n1-standard-4",
accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED",
accelerator_count: int = 0,
training_fraction_split: float = 0.8,
validation_fraction_split: float = 0.1,
test_fraction_split: float = 0.1,
Expand All @@ -43,10 +47,14 @@ def create_training_pipeline_custom_training_managed_dataset_sample(
model = job.run(
dataset=my_image_ds,
model_display_name=model_display_name,
args=args,
replica_count=replica_count,
machine_type=machine_type,
accelerator_type=accelerator_type,
accelerator_count=accelerator_count,
training_fraction_split=training_fraction_split,
validation_fraction_split=validation_fraction_split,
test_fraction_split=test_fraction_split,
args=args,
test_fraction_split=test_fraction_split,
sync=sync,
)

Expand Down
Expand Up @@ -27,13 +27,18 @@ def test_create_training_pipeline_custom_job_sample(

create_training_pipeline_custom_training_managed_dataset_sample.create_training_pipeline_custom_training_managed_dataset_sample(
project=constants.PROJECT,
location=constants.LOCATION,
display_name=constants.DISPLAY_NAME,
args=constants.ARGS,
script_path=constants.SCRIPT_PATH,
container_uri=constants.CONTAINER_URI,
model_serving_container_image_uri=constants.CONTAINER_URI,
dataset_id=constants.RESOURCE_ID,
model_display_name=constants.DISPLAY_NAME_2,
replica_count=constants.REPLICA_COUNT,
machine_type=constants.MACHINE_TYPE,
accelerator_type=constants.ACCELERATOR_TYPE,
accelerator_count=constants.ACCELERATOR_COUNT,
training_fraction_split=constants.TRAINING_FRACTION_SPLIT,
validation_fraction_split=constants.VALIDATION_FRACTION_SPLIT,
test_fraction_split=constants.TEST_FRACTION_SPLIT,
Expand All @@ -54,6 +59,10 @@ def test_create_training_pipeline_custom_job_sample(
dataset=mock_image_dataset,
model_display_name=constants.DISPLAY_NAME_2,
args=constants.ARGS,
replica_count=constants.REPLICA_COUNT,
machine_type=constants.MACHINE_TYPE,
accelerator_type=constants.ACCELERATOR_TYPE,
accelerator_count=constants.ACCELERATOR_COUNT,
training_fraction_split=constants.TRAINING_FRACTION_SPLIT,
validation_fraction_split=constants.VALIDATION_FRACTION_SPLIT,
test_fraction_split=constants.TEST_FRACTION_SPLIT,
Expand Down
Expand Up @@ -18,9 +18,9 @@
# [START aiplatform_sdk_create_training_pipeline_image_classification_sample]
def create_training_pipeline_image_classification_sample(
project: str,
location: str,
display_name: str,
dataset_id: int,
location: str = "us-central1",
dataset_id: int,
model_display_name: Optional[str] = None,
training_fraction_split: float = 0.8,
validation_fraction_split: float = 0.1,
Expand Down
Expand Up @@ -27,6 +27,7 @@ def test_create_training_pipeline_image_classification_sample(

create_training_pipeline_image_classification_sample.create_training_pipeline_image_classification_sample(
project=constants.PROJECT,
location=constants.LOCATION,
display_name=constants.DISPLAY_NAME,
dataset_id=constants.RESOURCE_ID,
model_display_name=constants.DISPLAY_NAME_2,
Expand Down
6 changes: 5 additions & 1 deletion samples/model-builder/test_constants.py
Expand Up @@ -55,4 +55,8 @@

SCRIPT_PATH = "task.py"
CONTAINER_URI = "gcr.io/my_project/my_image:latest"
ARGS = ["--tfds", "tf_flowers:3.*.*"]
ARGS = ["--tfds", "tf_flowers:3.*.*"]
REPLICA_COUNT = 0
MACHINE_TYPE = "n1-standard-4"
ACCELERATOR_TYPE = "ACCELERATOR_TYPE_UNSPECIFIED"
ACCELERATOR_COUNT = 0

0 comments on commit fec3e6c

Please sign in to comment.