fix: Deterministic foreach task id's for Argo Workflows (#1704)
* change to more deterministic task ids for argo workflows

* cleanup comments

* wip: rework task id generation for nested foreaches

* wip: stash

* wip: rework

* wip: possibly finally working nested foreach joins

* cleanup

* cleanup and fix non-nested foreach input params

* cleanup

* one more fix for foreach join cases.

* add more thorough comment on foreach step task id generation

* rename max-split to split-cardinality

* more comments on task id generation

* cleanup generate_input_paths

* comment updates

* changes

---------

Co-authored-by: savin <savingoyal@gmail.com>
saikonen and savingoyal committed May 6, 2024
1 parent 9989bc6 commit cb200c5
Showing 4 changed files with 188 additions and 39 deletions.
182 changes: 162 additions & 20 deletions metaflow/plugins/argo/argo_workflows.py
@@ -835,7 +835,9 @@ def _compile_workflow_template(self):

# Visit every node and yield the uber DAGTemplate(s).
def _dag_templates(self):
def _visit(node, exit_node=None, templates=None, dag_tasks=None):
def _visit(
node, exit_node=None, templates=None, dag_tasks=None, parent_foreach=None
):
# Every for-each node results in a separate subDAG and an equivalent
# DAGTemplate rooted at the child of the for-each node. Each DAGTemplate
# has a unique name - the top-level DAGTemplate is named as the name of
@@ -883,6 +885,37 @@ def _visit(node, exit_node=None, templates=None, dag_tasks=None):
)
)
]
# NOTE: Due to limitations with Argo Workflows Parameter size we
# cannot pass arbitrarily large lists of task ids to join tasks.
# Instead we ensure that task ids for foreach tasks can be
# deduced deterministically and pass the relevant information to
# the join task.
#
# We need to add the split-index and root-input-path for the last
# step in any foreach scope and use these to generate the task id,
# as the join step uses the root and the cardinality of the
# foreach scope to generate the required ids.
if (
node.is_inside_foreach
and self.graph[node.out_funcs[0]].type == "join"
):
if any(
self.graph[parent].matching_join
== self.graph[node.out_funcs[0]].name
and self.graph[parent].type == "foreach"
for parent in self.graph[node.out_funcs[0]].split_parents
):
parameters.extend(
[
Parameter("split-index").value(
"{{inputs.parameters.split-index}}"
),
Parameter("root-input-path").value(
"{{inputs.parameters.input-paths}}"
),
]
)

dag_task = (
DAGTask(self._sanitize(node.name))
.dependencies(
@@ -903,9 +936,19 @@ def _visit(node, exit_node=None, templates=None, dag_tasks=None):
# For split nodes traverse all the children
if node.type == "split":
for n in node.out_funcs:
_visit(self.graph[n], node.matching_join, templates, dag_tasks)
_visit(
self.graph[n],
node.matching_join,
templates,
dag_tasks,
parent_foreach,
)
return _visit(
self.graph[node.matching_join], exit_node, templates, dag_tasks
self.graph[node.matching_join],
exit_node,
templates,
dag_tasks,
parent_foreach,
)
# For foreach nodes generate a new sub DAGTemplate
elif node.type == "foreach":
@@ -929,6 +972,16 @@ def _visit(node, exit_node=None, templates=None, dag_tasks=None):
),
Parameter("split-index").value("{{item}}"),
]
+ (
[
Parameter("root-input-path").value(
"argo-{{workflow.name}}/%s/{{tasks.%s.outputs.parameters.task-id}}"
% (node.name, self._sanitize(node.name))
),
]
if parent_foreach
else []
)
)
)
.with_param(
@@ -938,13 +991,18 @@ def _visit(node, exit_node=None, templates=None, dag_tasks=None):
)
dag_tasks.append(foreach_task)
templates, dag_tasks_1 = _visit(
self.graph[node.out_funcs[0]], node.matching_join, templates, []
self.graph[node.out_funcs[0]],
node.matching_join,
templates,
[],
node.name,
)
templates.append(
Template(foreach_template_name)
.inputs(
Inputs().parameters(
[Parameter("input-paths"), Parameter("split-index")]
+ ([Parameter("root-input-path")] if parent_foreach else [])
)
)
.outputs(
@@ -971,13 +1029,26 @@ def _visit(node, exit_node=None, templates=None, dag_tasks=None):
Arguments().parameters(
[
Parameter("input-paths").value(
"argo-{{workflow.name}}/%s/{{tasks.%s.outputs.parameters}}"
% (
self.graph[node.matching_join].in_funcs[-1],
foreach_template_name,
)
)
"argo-{{workflow.name}}/%s/{{tasks.%s.outputs.parameters.task-id}}"
% (node.name, self._sanitize(node.name))
),
Parameter("split-cardinality").value(
"{{tasks.%s.outputs.parameters.split-cardinality}}"
% self._sanitize(node.name)
),
]
+ (
[
Parameter("split-index").value(
"{{inputs.parameters.split-index}}"
),
Parameter("root-input-path").value(
"{{inputs.parameters.input-paths}}"
),
]
if parent_foreach
else []
)
)
)
)
@@ -987,11 +1058,16 @@ def _visit(node, exit_node=None, templates=None, dag_tasks=None):
exit_node,
templates,
dag_tasks,
parent_foreach,
)
# For linear nodes continue traversing to the next node
if node.type in ("linear", "join", "start"):
return _visit(
self.graph[node.out_funcs[0]], exit_node, templates, dag_tasks
self.graph[node.out_funcs[0]],
exit_node,
templates,
dag_tasks,
parent_foreach,
)
else:
raise ArgoWorkflowsException(
@@ -1034,11 +1110,43 @@ def _container_templates(self):
# Ideally, we would like these task ids to be the same as node name
# (modulo retry suffix) on Argo Workflows but that doesn't seem feasible
# right now.
task_str = node.name + "-{{workflow.creationTimestamp}}"

task_idx = ""
input_paths = ""
root_input = None
# export input_paths as it is used multiple times in the container script
# and we do not want to repeat the values.
input_paths_expr = "export INPUT_PATHS=''"
if node.name != "start":
task_str += "-{{inputs.parameters.input-paths}}"
input_paths_expr = (
"export INPUT_PATHS={{inputs.parameters.input-paths}}"
)
input_paths = "$(echo $INPUT_PATHS)"
if any(self.graph[n].type == "foreach" for n in node.in_funcs):
task_str += "-{{inputs.parameters.split-index}}"
task_idx = "{{inputs.parameters.split-index}}"
if node.is_inside_foreach and self.graph[node.out_funcs[0]].type == "join":
if any(
self.graph[parent].matching_join
== self.graph[node.out_funcs[0]].name
for parent in self.graph[node.out_funcs[0]].split_parents
if self.graph[parent].type == "foreach"
) and any(not self.graph[f].type == "foreach" for f in node.in_funcs):
# We need to propagate the split-index and root-input-path info for
# the last step inside a foreach for correctly joining nested
# foreaches.
task_idx = "{{inputs.parameters.split-index}}"
root_input = "{{inputs.parameters.root-input-path}}"

# Task string to be hashed into an ID
task_str = "-".join(
[
node.name,
"{{workflow.creationTimestamp}}",
root_input or input_paths,
task_idx,
]
)

# Generated task_ids need to be non-numeric - see register_task_id in
# service.py. We do so by prefixing `t-`
task_id_expr = (
@@ -1087,6 +1195,7 @@ def _container_templates(self):
# env var.
'${METAFLOW_INIT_SCRIPT:+eval \\"${METAFLOW_INIT_SCRIPT}\\"}',
"mkdir -p $PWD/.logs",
input_paths_expr,
task_id_expr,
mflog_expr,
]
@@ -1098,8 +1207,6 @@ def _container_templates(self):
node.name, self.flow_datastore.TYPE
)

input_paths = "{{inputs.parameters.input-paths}}"

top_opts_dict = {
"with": [
decorator.make_decorator_spec()
@@ -1168,10 +1275,16 @@ def _container_templates(self):
node.type == "join"
and self.graph[node.split_parents[-1]].type == "foreach"
):
# Set aggregated input-paths for a foreach-join
# Set aggregated input-paths for a for-each join
foreach_step = next(
n for n in node.in_funcs if self.graph[n].is_inside_foreach
)
input_paths = (
"$(python -m metaflow.plugins.argo.process_input_paths %s)"
% input_paths
"$(python -m metaflow.plugins.argo.generate_input_paths %s {{workflow.creationTimestamp}} %s {{inputs.parameters.split-cardinality}})"
% (
foreach_step,
input_paths,
)
)
step = [
"step",
@@ -1337,13 +1450,37 @@ def _container_templates(self):
# input. Analogously, if the node under consideration is a foreach
# node, then we emit split cardinality as an extra output. I would like
# to thank the designers of Argo Workflows for making this so
# straightforward!
# straightforward! Things become a bit more complicated to support very
# wide foreaches, where we have to resort to passing a root-input-path
# so that we can compute the task ids for each parent task of a foreach
# join deterministically inside the join task, instead of passing a
# rather long list of (albeit compressed) task ids.
inputs = []
if node.name != "start":
inputs.append(Parameter("input-paths"))
if any(self.graph[n].type == "foreach" for n in node.in_funcs):
# Fetch split-index from parent
inputs.append(Parameter("split-index"))
if (
node.type == "join"
and self.graph[node.split_parents[-1]].type == "foreach"
):
# append this only for joins of foreaches, not static splits
inputs.append(Parameter("split-cardinality"))
if node.is_inside_foreach and self.graph[node.out_funcs[0]].type == "join":
if any(
self.graph[parent].matching_join
== self.graph[node.out_funcs[0]].name
for parent in self.graph[node.out_funcs[0]].split_parents
if self.graph[parent].type == "foreach"
) and any(not self.graph[f].type == "foreach" for f in node.in_funcs):
# We need to propagate the split-index and root-input-path info for
# the last step inside a foreach for correctly joining nested
# foreaches.
if not any(self.graph[n].type == "foreach" for n in node.in_funcs):
# Don't add duplicate split index parameters.
inputs.append(Parameter("split-index"))
inputs.append(Parameter("root-input-path"))

outputs = []
if node.name != "end":
@@ -1353,6 +1490,11 @@ def _container_templates(self):
outputs.append(
Parameter("num-splits").valueFrom({"path": "/mnt/out/splits"})
)
outputs.append(
Parameter("split-cardinality").valueFrom(
{"path": "/mnt/out/split_cardinality"}
)
)

# It makes no sense to set env vars to None (shows up as "None" string)
# Also we skip some env vars (e.g. in case we want to pull them from KUBERNETES_SECRETS)
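The heart of the change: a foreach child's task id is an md5 hash of a stable base string (step name, workflow creation timestamp, input paths) plus the child's split index, so a downstream join can re-derive every parent id from the base and the split cardinality alone, without passing a list of ids through Argo parameters. Below is a minimal sketch of the scheme, mirroring the "t-" prefix and last-8-hex-chars md5 convention from the new generate_input_paths.py further down; the base string here is hypothetical.

from hashlib import md5


def deterministic_task_id(base, idx):
    # The trailing newline mirrors the shell-side `echo` that each running
    # task uses to compute its own id inside the container template.
    task_str = "%s-%s\n" % (base, idx)
    return "t-" + md5(task_str.encode("utf-8")).hexdigest()[-8:]


# Hypothetical base: <step name>-<workflow creationTimestamp>-<input-paths>
base = "process-2024-05-06T10:00:00Z-argo-wf-abc/start/t-12345678"

# A join only needs the base and the split cardinality (here 3) to
# reconstruct all of its parents' task ids.
print([deterministic_task_id(base, i) for i in range(3)])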
3 changes: 3 additions & 0 deletions metaflow/plugins/argo/argo_workflows_decorator.py
@@ -102,6 +102,9 @@ def task_finished(
if graph[step_name].type == "foreach":
with open("/mnt/out/splits", "w") as file:
json.dump(list(range(flow._foreach_num_splits)), file)
with open("/mnt/out/split_cardinality", "w") as file:
json.dump(flow._foreach_num_splits, file)

# Unfortunately, we can't always use pod names as task-ids since the pod names
# are not static across retries. We write the task-id to a file that is read
# by the next task here.
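For illustration, a sketch (with a hypothetical three-way split, i.e. flow._foreach_num_splits == 3) of what the two output files written above would contain:

import json

num_splits = 3  # stand-in for flow._foreach_num_splits
# /mnt/out/splits drives Argo's withParam fan-out:
print(json.dumps(list(range(num_splits))))  # [0, 1, 2]
# /mnt/out/split_cardinality lets the foreach join re-derive parent task ids:
print(json.dumps(num_splits))  # 3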
23 changes: 23 additions & 0 deletions metaflow/plugins/argo/generate_input_paths.py
@@ -0,0 +1,23 @@
import sys
from hashlib import md5


def generate_input_paths(step_name, timestamp, input_paths, split_cardinality):
# => run_id/step/:foo,bar
run_id = input_paths.split("/")[0]
foreach_base_id = "{}-{}-{}".format(step_name, timestamp, input_paths)

ids = [_generate_task_id(foreach_base_id, i) for i in range(int(split_cardinality))]
return "{}/{}/:{}".format(run_id, step_name, ",".join(ids))


def _generate_task_id(base, idx):
# For foreach splits generate the expected input-paths based on split_cardinality and base_id.
# Newline required at the end because 'echo' appends one during the shell-side task_id creation.
task_str = "%s-%s\n" % (base, idx)
hash = md5(task_str.encode("utf-8")).hexdigest()[-8:]
return "t-" + hash


if __name__ == "__main__":
print(generate_input_paths(sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4]))
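A usage sketch for the new module, mirroring how the container template invokes it as $(python -m metaflow.plugins.argo.generate_input_paths <step> {{workflow.creationTimestamp}} <input-paths> {{inputs.parameters.split-cardinality}}); the concrete values below are hypothetical:

from metaflow.plugins.argo.generate_input_paths import generate_input_paths

paths = generate_input_paths(
    "process",                        # name of the last step inside the foreach
    "2024-05-06T10:00:00Z",           # {{workflow.creationTimestamp}}
    "argo-wf-abc/start/t-12345678",   # input-paths of the foreach task
    "3",                              # {{inputs.parameters.split-cardinality}}
)
# => "argo-wf-abc/process/:t-<hash0>,t-<hash1>,t-<hash2>"
print(paths)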
19 changes: 0 additions & 19 deletions metaflow/plugins/argo/process_input_paths.py

This file was deleted.
