Skip to content

Commit

Permalink
Fix pre-mature evaluation of tasks in mapped task group (#34337)
Browse files Browse the repository at this point in the history
* Fix pre-mature evaluation of tasks in mapped task group

Getting the relevant upstream indexes of a task instance in a mapped task group
should only be done when the task has expanded. If the task has not expanded yet,
we should return None so that the task can wait for the upstreams before trying
to run.
This issue is more noticeable when the trigger rule is ONE_FAILED because then,
the task instance is marked as SKIPPED.
This commit fixes this issue.
closes: #34023

* fixup! Fix pre-mature evaluation of tasks in mapped task group

* fixup! fixup! Fix pre-mature evaluation of tasks in mapped task group

* fixup! fixup! fixup! Fix pre-mature evaluation of tasks in mapped task group

* Fix tests

(cherry picked from commit 69938fd)
  • Loading branch information
ephraimbuddy committed Nov 1, 2023
1 parent 836ef42 commit e6662d0
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 8 deletions.
18 changes: 18 additions & 0 deletions airflow/ti_deps/deps/trigger_rule_dep.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep, TIDepStatus
from airflow.utils.state import TaskInstanceState
from airflow.utils.task_group import MappedTaskGroup
from airflow.utils.trigger_rule import TriggerRule as TR

if TYPE_CHECKING:
Expand Down Expand Up @@ -131,6 +132,20 @@ def _get_expanded_ti_count() -> int:
"""
return ti.task.get_mapped_ti_count(ti.run_id, session=session)

def _iter_expansion_dependencies() -> Iterator[str]:
from airflow.models.mappedoperator import MappedOperator

if isinstance(ti.task, MappedOperator):
for op in ti.task.iter_mapped_dependencies():
yield op.task_id
task_group = ti.task.task_group
if task_group and task_group.iter_mapped_task_groups():
yield from (
op.task_id
for tg in task_group.iter_mapped_task_groups()
for op in tg.iter_mapped_dependencies()
)

@functools.lru_cache
def _get_relevant_upstream_map_indexes(upstream_id: str) -> int | range | None:
"""Get the given task's map indexes relevant to the current ti.
Expand All @@ -141,6 +156,9 @@ def _get_relevant_upstream_map_indexes(upstream_id: str) -> int | range | None:
"""
if TYPE_CHECKING:
assert isinstance(ti.task.dag, DAG)
if isinstance(ti.task.task_group, MappedTaskGroup):
if upstream_id not in set(_iter_expansion_dependencies()):
return None
try:
expanded_ti_count = _get_expanded_ti_count()
except (NotFullyPopulated, NotMapped):
Expand Down
4 changes: 2 additions & 2 deletions tests/models/test_mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1303,8 +1303,8 @@ def file_transforms(filename):
states = self.get_states(dr)
expected = {
"file_transforms.my_setup": {0: "success", 1: "failed", 2: "skipped"},
"file_transforms.my_work": {0: "success", 1: "upstream_failed", 2: "skipped"},
"file_transforms.my_teardown": {0: "success", 1: "upstream_failed", 2: "skipped"},
"file_transforms.my_work": {2: "upstream_failed", 1: "upstream_failed", 0: "upstream_failed"},
"file_transforms.my_teardown": {2: "success", 1: "success", 0: "success"},
}

assert states == expected
Expand Down
47 changes: 41 additions & 6 deletions tests/ti_deps/deps/test_trigger_rule_dep.py
Original file line number Diff line number Diff line change
Expand Up @@ -1164,19 +1164,23 @@ def _one_scheduling_decision_iteration() -> dict[tuple[str, int], TaskInstance]:
tis = _one_scheduling_decision_iteration()
assert sorted(tis) == [("tg.t1", 0), ("tg.t1", 1), ("tg.t1", 2)]

# After running the first t1, the first t2 becomes immediately available.
# After running the first t1, the remaining t1 must be run before t2 is available.
tis["tg.t1", 0].run()
tis = _one_scheduling_decision_iteration()
assert sorted(tis) == [("tg.t1", 1), ("tg.t1", 2), ("tg.t2", 0)]
assert sorted(tis) == [("tg.t1", 1), ("tg.t1", 2)]

# Similarly for the subsequent t2 instances.
# After running all t1, t2 is available.
tis["tg.t1", 1].run()
tis["tg.t1", 2].run()
tis = _one_scheduling_decision_iteration()
assert sorted(tis) == [("tg.t1", 1), ("tg.t2", 0), ("tg.t2", 2)]
assert sorted(tis) == [("tg.t2", 0), ("tg.t2", 1), ("tg.t2", 2)]

# But running t2 partially does not make t3 available.
tis["tg.t1", 1].run()
# Similarly for t2 instances. They both have to complete before t3 is available
tis["tg.t2", 0].run()
tis = _one_scheduling_decision_iteration()
assert sorted(tis) == [("tg.t2", 1), ("tg.t2", 2)]

# But running t2 partially does not make t3 available.
tis["tg.t2", 2].run()
tis = _one_scheduling_decision_iteration()
assert sorted(tis) == [("tg.t2", 1)]
Expand Down Expand Up @@ -1406,3 +1410,34 @@ def w2():
(status,) = self.get_dep_statuses(dr, "w2", flag_upstream_failed=True, session=session)
assert status.reason.startswith("All setup tasks must complete successfully")
assert self.get_ti(dr, "w2").state == expected


def test_mapped_tasks_in_mapped_task_group_waits_for_upstreams_to_complete(dag_maker, session):
"""Test that one failed trigger rule works well in mapped task group"""
with dag_maker() as dag:

@dag.task
def t1():
return [1, 2, 3]

@task_group("tg1")
def tg1(a):
@dag.task()
def t2(a):
return a

@dag.task(trigger_rule=TriggerRule.ONE_FAILED)
def t3(a):
return a

t2(a) >> t3(a)

t = t1()
tg1.expand(a=t)

dr = dag_maker.create_dagrun()
ti = dr.get_task_instance(task_id="t1")
ti.run()
dr.task_instance_scheduling_decisions()
ti3 = dr.get_task_instance(task_id="tg1.t3")
assert not ti3.state

0 comments on commit e6662d0

Please sign in to comment.