Add retry_from_failure parameter to DbtCloudRunJobOperator #38868

Open · wants to merge 8 commits into main
23 changes: 23 additions & 0 deletions airflow/providers/dbt/cloud/hooks/dbt.py
@@ -401,6 +401,7 @@ def trigger_job_run(
account_id: int | None = None,
steps_override: list[str] | None = None,
schema_override: str | None = None,
retry_from_failure: bool = False,
additional_run_config: dict[str, Any] | None = None,
) -> Response:
"""
@@ -413,6 +414,9 @@ def trigger_job_run(
instead of those configured in dbt Cloud.
:param schema_override: Optional. Override the destination schema in the configured target for this
job.
:param retry_from_failure: Optional. If set to True, the job will be triggered using the "rerun"
endpoint. This parameter cannot be used alongside steps_override, schema_override, or
additional_run_config.
:param additional_run_config: Optional. Any additional parameters that should be included in the API
request when triggering the job.
:return: The request response.
@@ -427,6 +431,14 @@
}
payload.update(additional_run_config)

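# The rerun endpoint replays the failed run with its original configuration, so
# run-level overrides cannot apply and are rejected up front rather than silently ignored.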
if retry_from_failure:
if steps_override is not None or schema_override is not None or additional_run_config != {}:
raise ValueError(
"steps_override, schema_override, or additional_run_config"
" cannot be used when retry_from_failure is True."
)
return self.retry_failed_job_run(job_id, account_id)

return self._run_and_get_response(
method="POST",
endpoint=f"{account_id}/jobs/{job_id}/run/",
@@ -650,6 +662,17 @@ async def get_job_run_artifacts_concurrently(
results = await asyncio.gather(*tasks.values())
return {filename: result.json() for filename, result in zip(tasks.keys(), results)}

@fallback_to_default_account
def retry_failed_job_run(self, job_id: int, account_id: int | None = None) -> Response:
"""
Retry a failed run for a job from the point of failure, if the run failed; otherwise, trigger a new run.

:param job_id: The ID of a dbt Cloud job.
:param account_id: Optional. The ID of a dbt Cloud account.
:return: The request response.
"""
return self._run_and_get_response(method="POST", endpoint=f"{account_id}/jobs/{job_id}/rerun/")

def test_connection(self) -> tuple[bool, str]:
"""Test dbt Cloud connection."""
try:
7 changes: 7 additions & 0 deletions airflow/providers/dbt/cloud/operators/dbt.py
@@ -75,6 +75,10 @@ class DbtCloudRunJobOperator(BaseOperator):
request when triggering the job.
:param reuse_existing_run: Flag to determine whether to reuse an existing non-terminal job run. If set
to true and a non-terminal job run is found, the latest run is used without triggering a new job run.
:param retry_from_failure: Flag to determine whether to retry the job run from the point of failure. If
set to true and the task's try number is greater than 1, the job run is retried from the failure point.
For more information on the retry logic, see:
https://docs.getdbt.com/dbt-cloud/api-v2#/operations/Retry%20Failed%20Job
:param deferrable: Run operator in the deferrable mode
:return: The ID of the triggered dbt Cloud job run.
"""
@@ -105,6 +109,7 @@ def __init__(
check_interval: int = 60,
additional_run_config: dict[str, Any] | None = None,
reuse_existing_run: bool = False,
retry_from_failure: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
) -> None:
@@ -121,6 +126,7 @@
self.additional_run_config = additional_run_config or {}
self.run_id: int | None = None
self.reuse_existing_run = reuse_existing_run
self.retry_from_failure = retry_from_failure
self.deferrable = deferrable

def execute(self, context: Context):
@@ -150,6 +156,7 @@ def execute(self, context: Context):
cause=self.trigger_reason,
steps_override=self.steps_override,
schema_override=self.schema_override,
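# Only use the rerun endpoint on genuine task retries (try_number > 1); the
# first attempt always triggers a normal run.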
retry_from_failure=self.retry_from_failure and context["ti"].try_number > 1,
additional_run_config=self.additional_run_config,
)
self.run_id = trigger_job_response.json()["data"]["id"]
4 changes: 4 additions & 0 deletions docs/apache-airflow-providers-dbt-cloud/operators.rst
@@ -51,6 +51,10 @@ resource utilization while the job is running.
When ``wait_for_termination`` is False and ``deferrable`` is False, we just submit the job and can only
track the job status with the :class:`~airflow.providers.dbt.cloud.sensors.dbt.DbtCloudJobRunSensor`.

When ``retry_from_failure`` is True and the Task Instance ``try_number`` is greater than 1, the failed
run for the job is retried from its point of failure if the previous run failed; otherwise a new run is
triggered. For more information on the retry logic, reference the
`API documentation <https://docs.getdbt.com/dbt-cloud/api-v2#/operations/Retry%20Failed%20Job>`__.
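
A minimal sketch of enabling this behavior, assuming a connection named ``dbt_cloud_default`` and a
hypothetical job ID; task-level ``retries`` is what produces ``try_number`` values greater than 1:

.. code-block:: python

    from datetime import timedelta

    from airflow.providers.dbt.cloud.operators.dbt import DbtCloudRunJobOperator

    run_dbt_job = DbtCloudRunJobOperator(
        task_id="run_dbt_job",
        dbt_cloud_conn_id="dbt_cloud_default",  # assumed connection ID
        job_id=48617,  # hypothetical dbt Cloud job ID
        retry_from_failure=True,
        retries=2,  # retried attempts have try_number > 1, so the rerun endpoint is used
        retry_delay=timedelta(minutes=5),
    )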

While ``schema_override`` and ``steps_override`` are explicit, optional parameters for the
``DbtCloudRunJobOperator``, custom run configurations can also be passed to the operator using the
77 changes: 77 additions & 0 deletions tests/providers/dbt/cloud/hooks/test_dbt.py
@@ -394,6 +394,83 @@ def test_trigger_job_run_with_additional_run_configs(
)
hook._paginate.assert_not_called()

@pytest.mark.parametrize(
argnames="conn_id, account_id",
argvalues=[(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)],
ids=["default_account", "explicit_account"],
)
@patch.object(DbtCloudHook, "run")
@patch.object(DbtCloudHook, "_paginate")
def test_trigger_job_run_with_retry_from_failure(
self,
mock_http_run,
mock_paginate,
conn_id,
account_id,
):
hook = DbtCloudHook(conn_id)
cause = ""
retry_from_failure = True
hook.trigger_job_run(
job_id=JOB_ID, cause=cause, account_id=account_id, retry_from_failure=retry_from_failure
)

assert hook.method == "POST"

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_called_once_with(
endpoint=f"api/v2/accounts/{_account_id}/jobs/{JOB_ID}/rerun/", data=None
)
hook._paginate.assert_not_called()

@pytest.mark.parametrize(
argnames="conn_id, account_id",
argvalues=[(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)],
ids=["default_account", "explicit_account"],
)
@pytest.mark.parametrize(
argnames="steps_override, schema_override, additional_run_config",
argvalues=[
(["dbt test", "dbt run"], None, None),
(None, "other_schema", None),
(None, None, {"threads_override": 8, "generate_docs_override": False}),
],
)
@patch.object(DbtCloudHook, "run")
@patch.object(DbtCloudHook, "_paginate")
def test_failed_trigger_job_run_with_retry_from_failure(
self,
mock_http_run,
mock_paginate,
conn_id,
account_id,
steps_override,
schema_override,
additional_run_config,
):
hook = DbtCloudHook(conn_id)
cause = ""
retry_from_failure = True
error_match = (
"steps_override, schema_override, or additional_run_config"
" cannot be used when retry_from_failure is True"
)

with pytest.raises(ValueError, match=error_match):
hook.trigger_job_run(
job_id=JOB_ID,
cause=cause,
account_id=account_id,
steps_override=steps_override,
schema_override=schema_override,
additional_run_config=additional_run_config,
retry_from_failure=retry_from_failure,
)

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_not_called()
hook._paginate.assert_not_called()

@pytest.mark.parametrize(
argnames="conn_id, account_id",
argvalues=[(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)],
53 changes: 53 additions & 0 deletions tests/providers/dbt/cloud/operators/test_dbt.py
@@ -251,6 +251,7 @@ def test_execute_wait_for_termination(
cause=f"Triggered via Apache Airflow by task {TASK_ID!r} in the {self.dag.dag_id} DAG.",
steps_override=self.config["steps_override"],
schema_override=self.config["schema_override"],
retry_from_failure=False,
additional_run_config=self.config["additional_run_config"],
)

@@ -299,6 +300,7 @@ def test_execute_no_wait_for_termination(self, mock_run_job, conn_id, account_id
cause=f"Triggered via Apache Airflow by task {TASK_ID!r} in the {self.dag.dag_id} DAG.",
steps_override=self.config["steps_override"],
schema_override=self.config["schema_override"],
retry_from_failure=False,
additional_run_config=self.config["additional_run_config"],
)

@@ -366,6 +368,56 @@ def test_execute_no_wait_for_termination_and_reuse_existing_run(
},
)

@patch.object(DbtCloudHook, "trigger_job_run")
@pytest.mark.parametrize(
"try_number, expected_retry_from_failure",
[
(1, False),
(2, True),
(3, True),
],
)
@pytest.mark.parametrize(
"conn_id, account_id",
[(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)],
ids=["default_account", "explicit_account"],
)
def test_execute_retry_from_failure(
self, mock_run_job, try_number, expected_retry_from_failure, conn_id, account_id
):
operator = DbtCloudRunJobOperator(
task_id=TASK_ID,
dbt_cloud_conn_id=conn_id,
account_id=account_id,
trigger_reason=None,
dag=self.dag,
retry_from_failure=True,
**self.config,
)

assert operator.dbt_cloud_conn_id == conn_id
assert operator.job_id == self.config["job_id"]
assert operator.account_id == account_id
assert operator.check_interval == self.config["check_interval"]
assert operator.timeout == self.config["timeout"]
assert operator.retry_from_failure
assert operator.steps_override == self.config["steps_override"]
assert operator.schema_override == self.config["schema_override"]
assert operator.additional_run_config == self.config["additional_run_config"]

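# Simulate the task instance's Nth attempt; retry_from_failure should only be
# forwarded as True when try_number > 1.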
self.mock_ti.try_number = try_number
operator.execute(context={"ti": self.mock_ti})

mock_run_job.assert_called_once_with(
account_id=account_id,
job_id=JOB_ID,
cause=f"Triggered via Apache Airflow by task {TASK_ID!r} in the {self.dag.dag_id} DAG.",
steps_override=self.config["steps_override"],
schema_override=self.config["schema_override"],
retry_from_failure=expected_retry_from_failure,
additional_run_config=self.config["additional_run_config"],
)

@patch.object(DbtCloudHook, "trigger_job_run")
@pytest.mark.parametrize(
"conn_id, account_id",
@@ -398,6 +450,7 @@ def test_custom_trigger_reason(self, mock_run_job, conn_id, account_id):
cause=custom_trigger_reason,
steps_override=self.config["steps_override"],
schema_override=self.config["schema_override"],
retry_from_failure=False,
additional_run_config=self.config["additional_run_config"],
)
