Add retry_from_failure parameter to DbtCloudRunJobOperator #38868

Open · wants to merge 8 commits into main
27 changes: 22 additions & 5 deletions airflow/providers/dbt/cloud/hooks/dbt.py
@@ -398,6 +398,7 @@ def trigger_job_run(
account_id: int | None = None,
steps_override: list[str] | None = None,
schema_override: str | None = None,
retry_from_failure: bool = False,
additional_run_config: dict[str, Any] | None = None,
) -> Response:
"""
@@ -410,6 +411,8 @@ def trigger_job_run(
instead of those configured in dbt Cloud.
:param schema_override: Optional. Override the destination schema in the configured target for this
job.
:param retry_from_failure: Optional. If set to True, the job will be triggered using the "rerun"
endpoint.
:param additional_run_config: Optional. Any additional parameters that should be included in the API
request when triggering the job.
:return: The request response.
@@ -424,11 +427,14 @@
}
payload.update(additional_run_config)

return self._run_and_get_response(
method="POST",
endpoint=f"{account_id}/jobs/{job_id}/run/",
payload=json.dumps(payload),
)
if retry_from_failure:
return self.retry_failed_job_run(job_id, account_id)
else:
return self._run_and_get_response(
method="POST",
endpoint=f"{account_id}/jobs/{job_id}/run/",
payload=json.dumps(payload),
)

@fallback_to_default_account
def list_job_runs(
@@ -647,6 +653,17 @@ async def get_job_run_artifacts_concurrently(
results = await asyncio.gather(*tasks.values())
return {filename: result.json() for filename, result in zip(tasks.keys(), results)}

@fallback_to_default_account
def retry_failed_job_run(self, job_id: int, account_id: int | None = None) -> Response:
"""
Retry a run for a job from the point of failure if the previous run failed; otherwise, trigger a new run.

:param job_id: The ID of a dbt Cloud job.
:param account_id: Optional. The ID of a dbt Cloud account.
:return: The request response.
"""
return self._run_and_get_response(method="POST", endpoint=f"{account_id}/jobs/{job_id}/rerun/")

def test_connection(self) -> tuple[bool, str]:
"""Test dbt Cloud connection."""
try:
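A minimal sketch of how the new hook behavior could be exercised directly, assuming a configured dbt Cloud connection; the connection ID, job ID, and cause below are placeholders, not values from this PR:

```python
from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudHook

# Placeholder connection ID and job ID, for illustration only.
hook = DbtCloudHook(dbt_cloud_conn_id="dbt_cloud_default")

# With retry_from_failure=True, trigger_job_run delegates to retry_failed_job_run,
# which POSTs to <account_id>/jobs/<job_id>/rerun/ instead of the usual .../run/ endpoint.
response = hook.trigger_job_run(
    job_id=12345,
    cause="Rerun after upstream fix",
    retry_from_failure=True,
)
run_id = response.json()["data"]["id"]
print(f"Triggered dbt Cloud run {run_id}")
```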
3 changes: 3 additions & 0 deletions airflow/providers/dbt/cloud/operators/dbt.py
@@ -105,6 +105,7 @@ def __init__(
check_interval: int = 60,
additional_run_config: dict[str, Any] | None = None,
reuse_existing_run: bool = False,
retry_from_failure: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
) -> None:
@@ -121,6 +122,7 @@ def __init__(
self.additional_run_config = additional_run_config or {}
self.run_id: int | None = None
self.reuse_existing_run = reuse_existing_run
self.retry_from_failure = retry_from_failure
self.deferrable = deferrable

def execute(self, context: Context):
@@ -150,6 +152,7 @@ def execute(self, context: Context):
cause=self.trigger_reason,
steps_override=self.steps_override,
schema_override=self.schema_override,
retry_from_failure=self.retry_from_failure,
additional_run_config=self.additional_run_config,
)
self.run_id = trigger_job_response.json()["data"]["id"]
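At the operator level, the new flag would be used in a DAG roughly like the sketch below; the DAG ID, connection ID, job ID, and timing values are illustrative assumptions, not part of the diff:

```python
from datetime import datetime

from airflow import DAG
from airflow.providers.dbt.cloud.operators.dbt import DbtCloudRunJobOperator

with DAG(
    dag_id="dbt_cloud_retry_example",  # hypothetical DAG
    start_date=datetime(2024, 1, 1),
    schedule=None,
    catchup=False,
):
    retry_dbt_job = DbtCloudRunJobOperator(
        task_id="retry_dbt_job",
        dbt_cloud_conn_id="dbt_cloud_default",  # placeholder connection ID
        job_id=12345,  # placeholder dbt Cloud job ID
        retry_from_failure=True,  # new in this PR: trigger via the job's "rerun" endpoint
        check_interval=60,
        timeout=3600,
    )
```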
21 changes: 21 additions & 0 deletions tests/providers/dbt/cloud/hooks/test_dbt.py
@@ -389,6 +389,27 @@ def test_trigger_job_run_with_additional_run_configs(
)
hook._paginate.assert_not_called()

@pytest.mark.parametrize(
argnames="conn_id, account_id",
argvalues=[(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)],
ids=["default_account", "explicit_account"],
)
@patch.object(DbtCloudHook, "run")
@patch.object(DbtCloudHook, "_paginate")
def test_trigger_job_run_with_retry_from_failure(self, mock_http_run, mock_paginate, conn_id, account_id):
hook = DbtCloudHook(conn_id)
cause = ""
retry_from_failure = True
hook.trigger_job_run(
job_id=JOB_ID, cause=cause, account_id=account_id, retry_from_failure=retry_from_failure
)

assert hook.method == "POST"

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_called_once_with(endpoint=f"{_account_id}/jobs/{JOB_ID}/rerun/", data=None)
hook._paginate.assert_not_called()

@pytest.mark.parametrize(
argnames="conn_id, account_id",
argvalues=[(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)],
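For reference, the rerun endpoint targeted by retry_failed_job_run corresponds to roughly the following raw dbt Cloud API v2 request; the host, token, and IDs are placeholders, and in practice the hook assembles this call internally from the Airflow connection:

```python
import requests

# Placeholder values; the hook resolves these from the dbt Cloud connection.
ACCOUNT_ID = 1
JOB_ID = 12345
TOKEN = "dbt-cloud-api-token"

response = requests.post(
    f"https://cloud.getdbt.com/api/v2/accounts/{ACCOUNT_ID}/jobs/{JOB_ID}/rerun/",
    headers={"Authorization": f"Token {TOKEN}"},
)
response.raise_for_status()
print(response.json()["data"]["id"])
```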