feat!: Adds support for Airflow 2 Cloud Composer environment and operators #134

Merged
merged 8 commits on Aug 10, 2021
12 changes: 12 additions & 0 deletions samples/pipeline.yaml
@@ -12,6 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.


# ===== NOTE =====
# This YAML config template is used to write DAGs that use Airflow 1.10 operators.
# You can keep using this template when deploying to an environment that runs Airflow 2;
# such DAGs will keep working there thanks to backport provider compatibility.
#
# For tracking progress on the YAML config templates that use Airflow 2 operators, see
# https://github.com/GoogleCloudPlatform/public-datasets-pipelines/issues/137.

---
resources:
# A list of GCP resources that are unique and specific to your pipeline.
@@ -58,6 +67,9 @@ resources:
deletion_protection: true

dag:
# Specify the Airflow version of the operators used by the DAG. Defaults to Airflow 1 when unspecified.
airflow_version: 1

# The DAG acronym stands for directed acyclic graph. This block represents
# your data pipeline along with every property and configuration it needs to
# onboard your data.
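The note above says Airflow 1.10 templates remain deployable to Airflow 2 environments, and the new airflow_version field defaults to 1 when omitted. A minimal sketch of how that field can be resolved, mirroring the get_dag_airflow_version helper added to scripts/deploy_dag.py below (the dataset path in the usage line is hypothetical):

import pathlib

from ruamel import yaml

yaml = yaml.YAML(typ="safe")

DEFAULT_AIRFLOW_VERSION = 1


def dag_airflow_version(pipeline_yaml_path: pathlib.Path) -> int:
    # Load the pipeline config and fall back to Airflow 1 when `airflow_version` is absent.
    config = yaml.load(pipeline_yaml_path.read_text())
    return config["dag"].get("airflow_version", DEFAULT_AIRFLOW_VERSION)


# Hypothetical usage: a pipeline that opts into Airflow 2 operators would declare
# `airflow_version: 2` under its `dag:` block in pipeline.yaml.
print(dag_airflow_version(pathlib.Path("datasets/example/example_pipeline/pipeline.yaml")))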
28 changes: 27 additions & 1 deletion scripts/dag_imports.json
@@ -1,5 +1,5 @@
{
"1.10.15": {
"1": {
"BashOperator": {
"import": "from airflow.operators import bash_operator",
"class": "bash_operator.BashOperator"
@@ -28,5 +28,31 @@
"import": "from airflow.contrib.operators import gcs_delete_operator",
"class": "gcs_delete_operator.GoogleCloudStorageDeleteOperator"
}
},
"2": {
"BashOperator": {
"import": "from airflow.operators import bash",
"class": "bash.BashOperator"
},
"GoogleCloudStorageToBigQueryOperator": {
"import": "from airflow.providers.google.cloud.transfers import gcs_to_bigquery",
"class": "gcs_to_bigquery.GCSToBigQueryOperator"
},
"GoogleCloudStorageToGoogleCloudStorageOperator": {
"import": "from airflow.providers.google.cloud.transfers import gcs_to_gcs",
"class": "gcs_to_gcs.GCSToGCSOperator"
},
"GoogleCloudStorageDeleteOperator": {
"import": "from airflow.providers.google.cloud.operators import gcs",
"class": "gcs.GCSDeleteObjectsOperator"
},
"BigQueryInsertJobOperator": {
"import": "from airflow.providers.google.cloud.operators import bigquery",
"class": "bigquery.BigQueryInsertJobOperator"
},
"KubernetesPodOperator": {
"import": "from airflow.providers.cncf.kubernetes.operators import kubernetes_pod",
"class": "kubernetes_pod.KubernetesPodOperator"
}
}
}
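dag_imports.json now maps an Airflow major version ("1" or "2", stored as string keys) to each operator's import statement and namespaced class. A rough sketch of how this mapping can be consumed, following the generate_package_imports and generate_task_contents pattern in scripts/generate_dag.py (the task list here is a hypothetical example):

import json
import pathlib

# Version keys are strings because they come straight from the JSON file.
AIRFLOW_IMPORTS = json.load(open(pathlib.Path("scripts") / "dag_imports.json"))


def operator_import(airflow_version: str, operator: str) -> str:
    return AIRFLOW_IMPORTS[airflow_version][operator]["import"]


def operator_class(airflow_version: str, operator: str) -> str:
    return AIRFLOW_IMPORTS[airflow_version][operator]["class"]


# Hypothetical `dag.tasks` block taken from a pipeline.yaml:
tasks = [{"operator": "BashOperator"}, {"operator": "GoogleCloudStorageToBigQueryOperator"}]

# Deduplicate imports, then render the namespaced operator class for each task.
print("\n".join(sorted({operator_import("2", t["operator"]) for t in tasks})))
print([operator_class("2", t["operator"]) for t in tasks])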
76 changes: 74 additions & 2 deletions scripts/deploy_dag.py
@@ -14,14 +14,24 @@


import argparse
import json
import pathlib
import subprocess
import typing
import warnings

from ruamel import yaml

yaml = yaml.YAML(typ="safe")

CURRENT_PATH = pathlib.Path(__file__).resolve().parent
PROJECT_ROOT = CURRENT_PATH.parent
DATASETS_PATH = PROJECT_ROOT / "datasets"
DEFAULT_AIRFLOW_VERSION = 1


class IncompatibilityError(Exception):
pass


def main(
@@ -44,11 +54,18 @@ def main(

print("========== AIRFLOW DAGS ==========")
if pipeline:
pipelines = [env_path / "datasets" / pipeline]
pipelines = [env_path / "datasets" / dataset_id / pipeline]
else:
pipelines = list_subdirs(env_path / "datasets" / dataset_id)

# if local:
# runtime_airflow_version = local_airflow_version()
# else:
# runtime_airflow_version = composer_airflow_version(composer_env, composer_region)

for pipeline_path in pipelines:
# check_airflow_version_compatibility(pipeline_path, runtime_airflow_version)

copy_custom_callables_to_airflow_dags_folder(
local,
env_path,
@@ -127,6 +144,7 @@ def run_cloud_composer_vars_import(
subprocess.check_call(
[
"gcloud",
"beta",
"composer",
"environments",
"run",
@@ -135,7 +153,7 @@
str(composer_region),
"variables",
"--",
"--import",
"import",
str(airflow_path),
],
cwd=cwd,
@@ -266,6 +284,60 @@ def list_subdirs(path: pathlib.Path) -> typing.List[pathlib.Path]:
return subdirs


def local_airflow_version() -> int:
airflow_version = subprocess.run(
["airflow", "version"], stdout=subprocess.PIPE
).stdout.decode("utf-8")
return 2 if airflow_version.startswith("2") else 1


def composer_airflow_version(composer_env: str, composer_region: str) -> int:
composer_env = json.loads(
subprocess.run(
[
"gcloud",
"composer",
"environments",
"describe",
composer_env,
"--location",
composer_region,
"--format",
"json",
],
stdout=subprocess.PIPE,
).stdout.decode("utf-8")
)

# Example image version: composer-1.17.0-preview.8-airflow-2.1.1
image_version = composer_env["config"]["softwareConfig"]["imageVersion"]

airflow_version = image_version.split("-airflow-")[-1]
return 2 if airflow_version.startswith("2") else 1


def get_dag_airflow_version(config: dict) -> int:
return config["dag"].get("airflow_version", DEFAULT_AIRFLOW_VERSION)


def check_airflow_version_compatibility(
pipeline_path: pathlib.Path, runtime_airflow_version: int
) -> None:
"""If a DAG uses Airflow 2 operators but the runtime version uses Airflow 1,
raise a compatibility error. On the other hand, DAGs using Airflow 1.x operators
can still run in an Airflow 2 runtime environment via backport providers.
"""
dag_airflow_version = get_dag_airflow_version(
yaml.load((pipeline_path / "pipeline.yaml").read_text())
)

if dag_airflow_version > runtime_airflow_version:
raise IncompatibilityError(
f"The DAG {pipeline_path.name} uses Airflow 2, but"
" you are deploying to an Airflow 1.x environment."
)


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Deploy DAGs and variables to an Airflow environment"
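Putting the new deploy_dag.py helpers together: the Composer image version string is reduced to a major Airflow version and compared against the version declared in the DAG's pipeline.yaml, and deploying an Airflow 2 DAG to an Airflow 1 environment raises IncompatibilityError. A self-contained sketch of that flow (the image version string is the example from the code comment above; the DAG versions passed in are hypothetical):

class IncompatibilityError(Exception):
    pass


def airflow_major_version(image_version: str) -> int:
    # Example image version: composer-1.17.0-preview.8-airflow-2.1.1
    airflow_version = image_version.split("-airflow-")[-1]
    return 2 if airflow_version.startswith("2") else 1


def check_compatibility(dag_airflow_version: int, runtime_airflow_version: int) -> None:
    # Airflow 1 DAGs still run on Airflow 2 via backport providers, but not the reverse.
    if dag_airflow_version > runtime_airflow_version:
        raise IncompatibilityError(
            "The DAG uses Airflow 2, but you are deploying to an Airflow 1.x environment."
        )


runtime = airflow_major_version("composer-1.17.0-preview.8-airflow-2.1.1")  # -> 2
check_compatibility(dag_airflow_version=2, runtime_airflow_version=runtime)  # OK
try:
    check_compatibility(dag_airflow_version=2, runtime_airflow_version=1)
except IncompatibilityError as err:
    print(err)  # deploying an Airflow 2 DAG to an Airflow 1 environment fails fast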
67 changes: 23 additions & 44 deletions scripts/generate_dag.py
@@ -50,7 +50,7 @@
"default_args": AIRFLOW_TEMPLATES_PATH / "default_args.py.jinja2",
}

AIRFLOW_VERSION = "1.10.15"
DEFAULT_AIRFLOW_VERSION = 1
AIRFLOW_IMPORTS = json.load(open(CURRENT_PATH / "dag_imports.json"))


@@ -80,24 +80,18 @@ def generate_pipeline_dag(dataset_id: str, pipeline_id: str, env: str):
validate_dag_id_existence_and_format(config)
dag_contents = generate_dag(config, dataset_id)

target_path = pipeline_dir / f"{pipeline_id}_dag.py"
create_file_in_dot_and_project_dirs(
dataset_id,
pipeline_id,
dag_contents,
target_path.name,
PROJECT_ROOT / f".{env}",
)
write_to_file(dag_contents, target_path)
dag_path = pipeline_dir / f"{pipeline_id}_dag.py"
dag_path.touch()
write_to_file(dag_contents, dag_path)
format_python_code(dag_path)

copy_custom_callables_to_dot_dir(
copy_files_to_dot_dir(
dataset_id,
pipeline_id,
PROJECT_ROOT / f".{env}",
)

print_airflow_variables(dataset_id, dag_contents, env)
format_python_code(target_path)


def generate_dag(config: dict, dataset_id: str) -> str:
@@ -111,16 +105,18 @@ def generate_dag(config: dict, dataset_id: str) -> str:


def generate_package_imports(config: dict) -> str:
_airflow_version = airflow_version(config)
contents = {"from airflow import DAG"}
for task in config["dag"]["tasks"]:
contents.add(AIRFLOW_IMPORTS[AIRFLOW_VERSION][task["operator"]]["import"])
contents.add(AIRFLOW_IMPORTS[_airflow_version][task["operator"]]["import"])
return "\n".join(contents)


def generate_tasks(config: dict) -> list:
_airflow_version = airflow_version(config)
contents = []
for task in config["dag"]["tasks"]:
contents.append(generate_task_contents(task))
contents.append(generate_task_contents(task, _airflow_version))
return contents


@@ -138,11 +134,11 @@ def generate_dag_context(config: dict, dataset_id: str) -> str:
)


def generate_task_contents(task: dict) -> str:
def generate_task_contents(task: dict, airflow_version: str) -> str:
validate_task(task)
return jinja2.Template(TEMPLATE_PATHS["task"].read_text()).render(
**task,
namespaced_operator=AIRFLOW_IMPORTS[AIRFLOW_VERSION][task["operator"]]["class"],
namespaced_operator=AIRFLOW_IMPORTS[airflow_version][task["operator"]]["class"],
)


@@ -156,6 +152,10 @@ def dag_init(config: dict) -> dict:
return config["dag"].get("initialize") or config["dag"].get("init")


def airflow_version(config: dict) -> str:
return str(config["dag"].get("airflow_version", DEFAULT_AIRFLOW_VERSION))


def namespaced_dag_id(dag_id: str, dataset_id: str) -> str:
return f"{dataset_id}.{dag_id}"

@@ -219,34 +219,13 @@ def print_airflow_variables(dataset_id: str, dag_contents: str, env: str):
print()


def create_file_in_dot_and_project_dirs(
dataset_id: str,
pipeline_id: str,
contents: str,
filename: str,
env_dir: pathlib.Path,
):
print("\nCreated\n")
for prefix in (
env_dir / "datasets" / dataset_id / pipeline_id,
DATASETS_PATH / dataset_id / pipeline_id,
):
prefix.mkdir(parents=True, exist_ok=True)
target_path = prefix / filename
write_to_file(contents + "\n", target_path)
print(f" - {target_path.relative_to(PROJECT_ROOT)}")


def copy_custom_callables_to_dot_dir(
dataset_id: str, pipeline_id: str, env_dir: pathlib.Path
):
callables_dir = DATASETS_PATH / dataset_id / pipeline_id / "custom"
if callables_dir.exists():
target_dir = env_dir / "datasets" / dataset_id / pipeline_id
target_dir.mkdir(parents=True, exist_ok=True)
subprocess.check_call(
["cp", "-rf", str(callables_dir), str(target_dir)], cwd=PROJECT_ROOT
)
def copy_files_to_dot_dir(dataset_id: str, pipeline_id: str, env_dir: pathlib.Path):
source_dir = PROJECT_ROOT / "datasets" / dataset_id / pipeline_id
target_dir = env_dir / "datasets" / dataset_id
target_dir.mkdir(parents=True, exist_ok=True)
subprocess.check_call(
["cp", "-rf", str(source_dir), str(target_dir)], cwd=PROJECT_ROOT
)


def build_images(dataset_id: str, env: str):
39 changes: 39 additions & 0 deletions tests/scripts/test_deploy_dag.py
@@ -90,6 +90,13 @@ def copy_config_files_and_set_tmp_folder_names_as_ids(
)
)
generate_dag.write_to_file(pipeline_yaml_str, pipeline_path / "pipeline.yaml")
(ENV_DATASETS_PATH / dataset_path.name / pipeline_path.name).mkdir(
parents=True, exist_ok=True
)
shutil.copyfile(
pipeline_path / "pipeline.yaml",
ENV_DATASETS_PATH / dataset_path.name / pipeline_path.name / "pipeline.yaml",
)


def create_airflow_folders(airflow_home: pathlib.Path):
@@ -204,6 +211,7 @@ def test_script_can_deploy_without_variables_files(

mocker.patch("scripts.deploy_dag.run_gsutil_cmd")
mocker.patch("scripts.deploy_dag.run_cloud_composer_vars_import")
mocker.patch("scripts.deploy_dag.composer_airflow_version", return_value=1)

deploy_dag.main(
local=False,
@@ -217,6 +225,37 @@
)


def test_script_errors_out_when_deploying_airflow2_dag_to_airflow1_env(
dataset_path: pathlib.Path,
pipeline_path: pathlib.Path,
airflow_home: pathlib.Path,
env: str,
mocker,
):
setup_dag_and_variables(
dataset_path,
pipeline_path,
airflow_home,
env,
f"{dataset_path.name}_variables.json",
)

mocker.patch("scripts.deploy_dag.get_dag_airflow_version", return_value=2)
mocker.patch("scripts.deploy_dag.composer_airflow_version", return_value=1)

with pytest.raises(Exception):
deploy_dag.main(
local=False,
env_path=ENV_PATH,
dataset_id=dataset_path.name,
pipeline=pipeline_path.name,
airflow_home=airflow_home,
composer_env="test-env",
composer_bucket="test-bucket",
composer_region="test-region",
)


def test_script_with_pipeline_arg_deploys_only_that_pipeline(
dataset_path: pathlib.Path,
pipeline_path: pathlib.Path,
14 changes: 14 additions & 0 deletions tests/scripts/test_generate_dag.py
@@ -110,6 +110,20 @@ def test_main_generates_dag_files(
assert (path_prefix / f"{pipeline_path.name}_dag.py").exists()


def test_main_copies_pipeline_yaml_file(
dataset_path: pathlib.Path, pipeline_path: pathlib.Path, env: str
):
copy_config_files_and_set_tmp_folder_names_as_ids(dataset_path, pipeline_path)

generate_dag.main(dataset_path.name, pipeline_path.name, env)

for path_prefix in (
pipeline_path,
ENV_DATASETS_PATH / dataset_path.name / pipeline_path.name,
):
assert (path_prefix / "pipeline.yaml").exists()


def test_main_copies_custom_dir_if_it_exists(
dataset_path: pathlib.Path, pipeline_path: pathlib.Path, env: str
):