diff --git a/scripts/generate_dag.py b/scripts/generate_dag.py index 7c313dde8..f6194052f 100644 --- a/scripts/generate_dag.py +++ b/scripts/generate_dag.py @@ -138,8 +138,9 @@ def generate_shared_variables_file(env: str) -> None: shared_variables_file = pathlib.Path( PROJECT_ROOT / f".{env}" / "datasets" / "shared_variables.json" ) - shared_variables_file.touch() - shared_variables_file.write_text("{}", encoding="utf-8") + if not shared_variables_file.exists(): + shared_variables_file.touch() + shared_variables_file.write_text("{}", encoding="utf-8") def dag_init(config: dict) -> dict: diff --git a/tests/scripts/test_generate_dag.py b/tests/scripts/test_generate_dag.py index 9612f68bf..fe3bbcb68 100644 --- a/tests/scripts/test_generate_dag.py +++ b/tests/scripts/test_generate_dag.py @@ -55,6 +55,16 @@ def pipeline_path(dataset_path, suffix="_pipeline") -> typing.Iterator[pathlib.P yield pathlib.Path(dir_path) +@pytest.fixture(autouse=True) +def cleanup_shared_variables(): + shared_variables_file = ENV_DATASETS_PATH / "shared_variables.json" + if shared_variables_file.exists(): + shared_variables_file.unlink() + yield + if shared_variables_file.exists(): + shared_variables_file.unlink() + + def generate_image_files(dataset_path: pathlib.Path, num_containers: int = 1): for i in range(num_containers): target_dir = dataset_path / "_images" / f"test_image_{i+1}" @@ -145,12 +155,11 @@ def test_main_copies_custom_dir_if_it_exists( assert (path_prefix / "custom").is_dir() -def test_main_creates_shared_variables_file( +def test_main_creates_shared_variables_file_if_it_doesnt_exist( dataset_path: pathlib.Path, pipeline_path: pathlib.Path, env: str ): copy_config_files_and_set_tmp_folder_names_as_ids(dataset_path, pipeline_path) - custom_path = dataset_path / pipeline_path.name / "custom" - custom_path.mkdir(parents=True, exist_ok=True) + assert not (ENV_DATASETS_PATH / "shared_variables.json").exists() generate_dag.main(dataset_path.name, pipeline_path.name, env) @@ -158,6 +167,27 @@ def test_main_creates_shared_variables_file( assert not (ENV_DATASETS_PATH / "shared_variables.json").is_dir() +def test_main_does_not_modify_existing_shared_variables_file( + dataset_path: pathlib.Path, pipeline_path: pathlib.Path, env: str +): + copy_config_files_and_set_tmp_folder_names_as_ids(dataset_path, pipeline_path) + + # Create .test/datasets dir that'll contain the existing shared_variables.json file + ENV_DATASETS_PATH.mkdir(parents=True, exist_ok=True) + shared_variables_file = ENV_DATASETS_PATH / "shared_variables.json" + assert not shared_variables_file.exists() + + # Create a non-empty shared variables file + airflow_vars = {"key": "value"} + shared_variables_file.touch() + shared_variables_file.write_text(json.dumps(airflow_vars), encoding="utf-8") + + generate_dag.main(dataset_path.name, pipeline_path.name, env) + + assert shared_variables_file.exists() + assert json.loads(shared_variables_file.read_text()) == airflow_vars + + def test_main_raises_an_error_when_airflow_version_is_not_specified( dataset_path: pathlib.Path, pipeline_path: pathlib.Path, env: str ):