Apply max line length to 80 #178

Merged 3 commits on Mar 5, 2024. Changes shown are from 1 commit.

@@ -1,4 +1,4 @@
-name: Pre-commit
+name: Formatter

on:
  pull_request:
@@ -16,5 +16,5 @@ jobs:
      with:
        # Note: this should match Cloud Composer
        # https://cloud.google.com/composer/docs/concepts/versioning/composer-versions
-        python-version: '3.8'
+        python-version: '3.10'
    - uses: pre-commit/action@v3.0.0
4 changes: 2 additions & 2 deletions .github/workflows/pylint-check.yml
@@ -1,4 +1,4 @@
-name: Pylint
+name: Linter

on:
  pull_request:
@@ -9,7 +9,7 @@ on:
    branches: [master]

jobs:
-  build:
+  linting_check:
    runs-on: ubuntu-latest
    strategy:
      matrix:
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -3,6 +3,6 @@ repos:
    rev: 23.10.0
    hooks:
      - id: pyink
-        args: [--pyink-indentation=2, --pyink-use-majority-quotes]
-        language_version: python3.8
+        args: [--pyink-indentation=2, --pyink-use-majority-quotes, --line-length=80]
+        language_version: python3.10
        verbose: true
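
Note: with the extra --line-length=80 argument in place, the hook can be exercised locally before pushing. A minimal sketch, assuming pre-commit is installed and using the pyink hook id configured above:

    pre-commit run pyink --all-files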
32 changes: 17 additions & 15 deletions dags/examples/maxtext_sweep_gce_example_dag.py
@@ -44,21 +44,23 @@
 ]

 # Get list of MaxText GCE QueuedResource jobs
-maxtext_sweep_gce_test = maxtext_sweep_gce_config.get_maxtext_sweep_gce_config(
-    test_owner=test_owner.RAYMOND_Z,
-    project_name=Project.TPU_PROD_ENV_MULTIPOD.value,
-    tpu_zone=Zone.US_CENTRAL2_B.value,
-    time_out_in_min=60,
-    is_tpu_reserved=False,
-    tpu_version=TpuVersion.V4,
-    tpu_cores=8,
-    runtime_version=RuntimeVersion.TPU_UBUNTU2204_BASE.value,
-    base_output_directory=base_output_directory,
-    num_slices=[1],
-    run_name_prefix="maxtext-1b",
-    base_set_up_cmds=base_set_up_cmds,
-    base_run_model_cmds=base_run_model_cmds,
-    sweep_params={"M_PER_DEVICE_BATCH_SIZE": [1, 2, 4]},
+maxtext_sweep_gce_test = (
+    maxtext_sweep_gce_config.get_maxtext_sweep_gce_config(
+        test_owner=test_owner.RAYMOND_Z,
+        project_name=Project.TPU_PROD_ENV_MULTIPOD.value,
+        tpu_zone=Zone.US_CENTRAL2_B.value,
+        time_out_in_min=60,
+        is_tpu_reserved=False,
+        tpu_version=TpuVersion.V4,
+        tpu_cores=8,
+        runtime_version=RuntimeVersion.TPU_UBUNTU2204_BASE.value,
+        base_output_directory=base_output_directory,
+        num_slices=[1],
+        run_name_prefix="maxtext-1b",
+        base_set_up_cmds=base_set_up_cmds,
+        base_run_model_cmds=base_run_model_cmds,
+        sweep_params={"M_PER_DEVICE_BATCH_SIZE": [1, 2, 4]},
+    )
+)

 # Run jobs
30 changes: 16 additions & 14 deletions dags/examples/maxtext_sweep_gke_example_dag.py
@@ -41,20 +41,22 @@
 ]

 # Get list of MaxText GKE XPK jobs
-maxtext_sweep_gke_test = maxtext_sweep_gke_config.get_maxtext_sweep_gke_config(
-    test_owner=test_owner.RAYMOND_Z,
-    project_name=Project.TPU_PROD_ENV_MULTIPOD.value,
-    cluster_name=ClusterName.V4_128_MULTISLICE_CLUSTER.value,
-    tpu_zone=Zone.US_CENTRAL2_B.value,
-    time_out_in_min=60,
-    base_output_directory=base_output_directory,
-    tpu_version=TpuVersion.V4,
-    tpu_cores=128,
-    num_slices=[1],
-    docker_image=DockerImage.XPK_MAXTEXT_TEST.value,
-    run_name_prefix="maxtext-16b",
-    base_run_model_cmds=base_run_model_cmds,
-    sweep_params={"M_PER_DEVICE_BATCH_SIZE": [2, 4, 8]},
+maxtext_sweep_gke_test = (
+    maxtext_sweep_gke_config.get_maxtext_sweep_gke_config(
+        test_owner=test_owner.RAYMOND_Z,
+        project_name=Project.TPU_PROD_ENV_MULTIPOD.value,
+        cluster_name=ClusterName.V4_128_MULTISLICE_CLUSTER.value,
+        tpu_zone=Zone.US_CENTRAL2_B.value,
+        time_out_in_min=60,
+        base_output_directory=base_output_directory,
+        tpu_version=TpuVersion.V4,
+        tpu_cores=128,
+        num_slices=[1],
+        docker_image=DockerImage.XPK_MAXTEXT_TEST.value,
+        run_name_prefix="maxtext-16b",
+        base_run_model_cmds=base_run_model_cmds,
+        sweep_params={"M_PER_DEVICE_BATCH_SIZE": [2, 4, 8]},
+    )
+)

 # Run jobs
4 changes: 1 addition & 3 deletions dags/multipod/configs/maxtext_gce_config.py
@@ -50,9 +50,7 @@ def get_maxtext_nightly_config(
  current_datetime = current_time.strftime("%Y-%m-%d-%H-%M-%S")

  trigger = "automated" if automated_test else "manual"
-  base_output_directory = (
-      f"{gcs_bucket.XLML_OUTPUT_DIR}/maxtext/{test_mode.value}/{trigger}/{current_date}"
-  )
+  base_output_directory = f"{gcs_bucket.XLML_OUTPUT_DIR}/maxtext/{test_mode.value}/{trigger}/{current_date}"

  run_name = f"{num_slices}slice-V{tpu_version.value}_{tpu_cores}-maxtext-{test_mode.value}-{current_datetime}"
4 changes: 3 additions & 1 deletion dags/multipod/configs/maxtext_sweep_gce_config.py
@@ -66,7 +66,9 @@ def get_maxtext_sweep_gce_config(
  del config_dict["NUM_SLICES"]

  # Export sweep params as env variables for MaxText to read
-  run_model_cmds = [f"export {key}={value}" for (key, value) in config_dict.items()]
+  run_model_cmds = [
+      f"export {key}={value}" for (key, value) in config_dict.items()
+  ]
  for cmd in base_run_model_cmds:
    run_model_cmds.append(cmd)
4 changes: 3 additions & 1 deletion dags/multipod/configs/maxtext_sweep_gke_config.py
@@ -64,7 +64,9 @@ def get_maxtext_sweep_gke_config(
  del config_dict["NUM_SLICES"]

  # Export sweep params as env variables for MaxText to read
-  run_model_cmds = [f"export {key}={value}" for (key, value) in config_dict.items()]
+  run_model_cmds = [
+      f"export {key}={value}" for (key, value) in config_dict.items()
+  ]
  for cmd in base_run_model_cmds:
    run_model_cmds.append(cmd)
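
Note: the reflowed comprehension in both sweep configs is behavior-preserving; it only rewraps the line to fit 80 columns. A small illustrative sketch (the sample config_dict value is hypothetical):

    config_dict = {"M_PER_DEVICE_BATCH_SIZE": 4}
    run_model_cmds = [
        f"export {key}={value}" for (key, value) in config_dict.items()
    ]
    # run_model_cmds == ["export M_PER_DEVICE_BATCH_SIZE=4"]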
12 changes: 9 additions & 3 deletions dags/pytorch_xla/configs/pytorchxla_torchbench_config.py
@@ -115,7 +115,9 @@ def get_torchbench_tpu_config(
      task_owner=test_owner.PEI_Z,
  )

-  job_metric_config = metric_config.MetricConfig(use_runtime_generated_filename=True)
+  job_metric_config = metric_config.MetricConfig(
+      use_runtime_generated_filename=True
+  )

  return task.TpuQueuedResourceTask(
      task_test_config=job_test_config,
@@ -125,7 +127,9 @@ def get_torchbench_tpu_config(


 # Below is the setup for torchbench GPU run.
-def set_up_torchbench_gpu(model_name: str, nvidia_driver_version: str) -> Tuple[str]:
+def set_up_torchbench_gpu(
+    model_name: str, nvidia_driver_version: str
+) -> Tuple[str]:
  """Common set up for TorchBench."""

  def model_install_cmds(output_file=None) -> str:
@@ -259,7 +263,9 @@ def get_torchbench_gpu_config(
      task_owner=test_owner.PEI_Z,
  )

-  job_metric_config = metric_config.MetricConfig(use_runtime_generated_filename=True)
+  job_metric_config = metric_config.MetricConfig(
+      use_runtime_generated_filename=True
+  )

  return task.GpuCreateResourceTask(
      image_project.value,
12 changes: 9 additions & 3 deletions dags/pytorch_xla/nightly.py
@@ -50,7 +50,9 @@
 @task_group(prefix_group_id=False)
 def torchvision():
  mnist_v2_8 = task.TpuQueuedResourceTask(
-      test_config.JSonnetTpuVmTest.from_pytorch("pt-nightly-mnist-pjrt-func-v2-8-1vm"),
+      test_config.JSonnetTpuVmTest.from_pytorch(
+          "pt-nightly-mnist-pjrt-func-v2-8-1vm"
+      ),
      US_CENTRAL1_C,
  ).run()
  resnet_v2_8 = task.TpuQueuedResourceTask(
@@ -102,7 +104,9 @@ def torchvision():
  resnet_v2_8 >> resnet_v3_8_tests

  resnet_v100_2x2 = task.GpuGkeTask(
-      test_config.JSonnetGpuTest.from_pytorch("pt-nightly-resnet50-mp-fake-v100-x2x2"),
+      test_config.JSonnetGpuTest.from_pytorch(
+          "pt-nightly-resnet50-mp-fake-v100-x2x2"
+      ),
      US_CENTRAL1,
      "gpu-uc1",
  ).run()
@@ -145,7 +149,9 @@ def huggingface():
      US_CENTRAL1_C,
  ).run()
  accelerate_v4_8 = task.TpuQueuedResourceTask(
-      test_config.JSonnetTpuVmTest.from_pytorch("pt-nightly-accelerate-smoke-v4-8-1vm"),
+      test_config.JSonnetTpuVmTest.from_pytorch(
+          "pt-nightly-accelerate-smoke-v4-8-1vm"
+      ),
      US_CENTRAL2_B,
  ).run()
  diffusers_v4_8 = task.TpuQueuedResourceTask(
@@ -513,7 +513,9 @@ def get_flax_bert_config(
  )

  set_up_cmds = get_flax_bert_setup_cmds()
-  run_model_cmds = get_flax_bert_run_model_cmds(task_name, num_train_epochs, extraFlags)
+  run_model_cmds = get_flax_bert_run_model_cmds(
+      task_name, num_train_epochs, extraFlags
+  )

  job_test_config = test_config.TpuVmTest(
      test_config.Tpu(
@@ -37,12 +37,16 @@ def get_setup_cmds(
 ) -> Tuple[str]:
  if pax_version is PaxVersion.STABLE:
    logging.info("Running the latest stable Pax version.")
-    ckp_cmds = f"gsutil -m cp -r {ckp_path} {job_log_dir}" if ckp_path else "echo"
+    ckp_cmds = (
+        f"gsutil -m cp -r {ckp_path} {job_log_dir}" if ckp_path else "echo"
+    )
    return common.set_up_google_pax() + (ckp_cmds,)
  elif pax_version is PaxVersion.NIGHTLY:
    logging.info("Running nightly Pax version.")
    build_date = datetime.today().strftime("%Y%m%d")
-    ckp_cmds = f"gsutil -m cp -r {ckp_path} {job_log_dir}" if ckp_path else "echo"
+    ckp_cmds = (
+        f"gsutil -m cp -r {ckp_path} {job_log_dir}" if ckp_path else "echo"
+    )
    return (
        ckp_cmds,
        (
@@ -110,7 +110,9 @@ def get_tf_resnet_config(
      dataset_name=metric_config.DatasetOption.XLML_DATASET,
  )

-  set_up_cmds = common.install_tf_2_16() + common.set_up_google_tensorflow_2_16_models()
+  set_up_cmds = (
+      common.install_tf_2_16() + common.set_up_google_tensorflow_2_16_models()
+  )
  if not is_pjrt and is_pod:
    set_up_cmds += common.set_up_se_nightly()

@@ -199,7 +201,9 @@ def get_tf_dlrm_config(
      dataset_name=metric_config.DatasetOption.XLML_DATASET,
  )

-  set_up_cmds = common.install_tf_2_16() + common.set_up_google_tensorflow_2_16_models()
+  set_up_cmds = (
+      common.install_tf_2_16() + common.set_up_google_tensorflow_2_16_models()
+  )
  if not is_pjrt and is_pod:
    set_up_cmds += common.set_up_se_nightly()

@@ -311,7 +315,10 @@ def get_tf_dlrm_config(

 def export_env_variable(is_pod: bool, is_pjrt: bool) -> str:
  """Export environment variables for training if any."""
-  stmts = ["export WRAPT_DISABLE_EXTENSIONS=true", "export TF_USE_LEGACY_KERAS=1"]
+  stmts = [
+      "export WRAPT_DISABLE_EXTENSIONS=true",
+      "export TF_USE_LEGACY_KERAS=1",
+  ]
  if is_pod:
    stmts.append("export TPU_LOAD_LIBRARY=0")
  elif is_pjrt:
@@ -110,7 +110,9 @@ def get_tf_resnet_config(
      dataset_name=metric_config.DatasetOption.XLML_DATASET,
  )

-  set_up_cmds = common.install_tf_nightly() + common.set_up_google_tensorflow_models()
+  set_up_cmds = (
+      common.install_tf_nightly() + common.set_up_google_tensorflow_models()
+  )
  if not is_pjrt and is_pod:
    set_up_cmds += common.set_up_se_nightly()

@@ -199,7 +201,9 @@ def get_tf_dlrm_config(
      dataset_name=metric_config.DatasetOption.XLML_DATASET,
  )

-  set_up_cmds = common.install_tf_nightly() + common.set_up_google_tensorflow_models()
+  set_up_cmds = (
+      common.install_tf_nightly() + common.set_up_google_tensorflow_models()
+  )
  if not is_pjrt and is_pod:
    set_up_cmds += common.set_up_se_nightly()
16 changes: 12 additions & 4 deletions dags/solutions_team/solutionsteam_flax_latest_supported.py
@@ -271,9 +271,13 @@
      "--max_seq_length 512",
      "--eval_steps 1000",
  ]
-  jax_bert_v4_mnli_extra_flags = jax_bert_mnli_extra_flags + jax_bert_v4_batch_size
+  jax_bert_v4_mnli_extra_flags = (
+      jax_bert_mnli_extra_flags + jax_bert_v4_batch_size
+  )
  jax_bert_v4_mnli_conv_extra_flags = (
-      jax_bert_mnli_extra_flags + jax_bert_v4_batch_size + jax_bert_conv_extra_flags
+      jax_bert_mnli_extra_flags
+      + jax_bert_v4_batch_size
+      + jax_bert_conv_extra_flags
  )

  jax_bert_mnli_v4_8 = flax_config.get_flax_bert_config(
@@ -300,9 +304,13 @@
      "--max_seq_length 128",
      "--eval_steps 100",
  ]
-  jax_bert_v4_mrpc_extra_flags = jax_bert_mrpc_extra_flags + jax_bert_v4_batch_size
+  jax_bert_v4_mrpc_extra_flags = (
+      jax_bert_mrpc_extra_flags + jax_bert_v4_batch_size
+  )
  jax_bert_v4_mrpc_conv_extra_flags = (
-      jax_bert_mrpc_extra_flags + jax_bert_v4_batch_size + jax_bert_conv_extra_flags
+      jax_bert_mrpc_extra_flags
+      + jax_bert_v4_batch_size
+      + jax_bert_conv_extra_flags
  )

  jax_bert_mrpc_v4_8 = flax_config.get_flax_bert_config(
4 changes: 3 additions & 1 deletion dags/solutions_team/solutionsteam_jax_integration.py
@@ -41,7 +41,9 @@
    catchup=False,
 ):
  compilation_cache = task.TpuQueuedResourceTask(
-      test_config.JSonnetTpuVmTest.from_jax("jax-compilation-cache-test-func-v2-8-1vm"),
+      test_config.JSonnetTpuVmTest.from_jax(
+          "jax-compilation-cache-test-func-v2-8-1vm"
+      ),
      US_CENTRAL1_C,
  ).run()
@@ -35,7 +35,9 @@
 log_dir_prefix = f"{gcs_bucket.XLML_OUTPUT_DIR}/pax/nightly"

 # GPT-3 config with 1B params on c4 dataset with SPMD and Adam
-c4spmd1b_pretraining_exp_path = "tasks.lm.params.c4.C4Spmd1BAdam4ReplicasLimitSteps"
+c4spmd1b_pretraining_exp_path = (
+    "tasks.lm.params.c4.C4Spmd1BAdam4ReplicasLimitSteps"
+)
 pax_nightly_c4spmd1b_pretraining_v4_8 = pax_config.get_pax_lm_config(
    tpu_version=TpuVersion.V4,
    tpu_cores=8,
2 changes: 1 addition & 1 deletion scripts/.pylintrc
@@ -238,7 +238,7 @@ generated-members=
 [FORMAT]

 # Maximum number of characters on a single line.
-max-line-length=125
+max-line-length=80

 # TODO(https://github.com/pylint-dev/pylint/issues/3352): Direct pylint to exempt
 # lines made too long by directives to pytype.
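
Note: once max-line-length drops to 80, pylint reports any longer line as C0301 (line-too-long). A minimal sketch of a local check against one of the formatted folders (the dags target is illustrative):

    pylint --rcfile=scripts/.pylintrc dags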
2 changes: 1 addition & 1 deletion scripts/code-style.sh
@@ -21,7 +21,7 @@ FOLDERS_TO_FORMAT=("dags" "xlml")

 for folder in "${FOLDERS_TO_FORMAT[@]}"
 do
-  pyink "$folder" --pyink-indentation=2 --pyink-use-majority-quotes
+  pyink "$folder" --pyink-indentation=2 --pyink-use-majority-quotes --line-length=80
 done

 for folder in "${FOLDERS_TO_FORMAT[@]}"
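
Note: to verify formatting without rewriting files, pyink accepts black's --check and --diff flags (pyink is a fork of black, so this assumes its CLI mirrors black's). A minimal sketch mirroring the script's options:

    pyink --check --diff dags --pyink-indentation=2 --pyink-use-majority-quotes --line-length=80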
2 changes: 1 addition & 1 deletion xlml/apis/gcp_config.py
@@ -24,7 +24,7 @@ class GCPConfig:
  """This is a class to set up configs of GCP.

  Attributes:
-    project_name: The name of a project to provision resource and run a test job.
+    project_name: Name of a project to provision resource and run a test job.
    zone: The zone to provision resource and run a test job.
    dataset_name: The option of dataset for metrics.
    dataset_project: The name of a project that hosts the dataset.
3 changes: 2 additions & 1 deletion xlml/apis/metric_config.py
@@ -90,7 +90,8 @@ class ProfileConfig:

 @dataclasses.dataclass
 class MetricConfig:
-  """This is a class to set up config of Benchmark metric, dimension, and profile.
+  """This is a class to set up config of Benchmark metric,
+  dimension, and profile.

  Attributes:
    json_lines: The config for JSON Lines input.