From de91d2ccf53bd844b4dbf4f64dd087f4ee935be5 Mon Sep 17 00:00:00 2001 From: Arya Massarat <23412689+aryarm@users.noreply.github.com> Date: Fri, 24 Sep 2021 06:48:11 -0700 Subject: [PATCH] fix: merging of pipe groups when multiple rules are chained together via pipes (#1173) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * when handling pipes, process jobs in BFS order handle_pipes() depends on the jobs being in BFS order otherwise, it won't merge groups together properly see snakemake/snakemake#975 * create test for #975 - multiple piped rules * create output for multiple pipes test * fix formatting and filter finished jobs * register test_pipes_multiple with nosetests in tests.py * perf: use mergeable CandidateGroup objects instead of topological sorting of the jobs (faster). Co-authored-by: Johannes Köster --- snakemake/dag.py | 41 +++++++++++++++---- snakemake/jobs.py | 1 + tests/test_pipes_multiple/Snakefile | 27 ++++++++++++ .../expected-results/test.out | 2 + tests/tests.py | 6 +++ 5 files changed, 69 insertions(+), 8 deletions(-) create mode 100644 tests/test_pipes_multiple/Snakefile create mode 100644 tests/test_pipes_multiple/expected-results/test.out diff --git a/snakemake/dag.py b/snakemake/dag.py index dbd6081fe..f6ef9b69d 100755 --- a/snakemake/dag.py +++ b/snakemake/dag.py @@ -1224,6 +1224,8 @@ def postprocess(self, update_needrun=True): def handle_pipes(self): """Use pipes to determine job groups. Check if every pipe has exactly one consumer""" + + visited = set() for job in self.needrun_jobs: candidate_groups = set() if job.group is not None: @@ -1283,22 +1285,31 @@ def handle_pipes(self): continue if len(candidate_groups) > 1: - raise WorkflowError( - "An output file is marked as " - "pipe, but consuming jobs " - "are part of conflicting " - "groups.", - rule=job.rule, - ) + if all(isinstance(group, CandidateGroup) for group in candidate_groups): + for g in candidate_groups: + g.merge(group) + else: + raise WorkflowError( + "An output file is marked as " + "pipe, but consuming jobs " + "are part of conflicting " + "groups.", + rule=job.rule, + ) elif candidate_groups: # extend the candidate group to all involved jobs group = candidate_groups.pop() else: # generate a random unique group name - group = str(uuid.uuid4()) + group = CandidateGroup() # str(uuid.uuid4()) job.group = group + visited.add(job) for j in all_depending: j.group = group + visited.add(j) + + for job in visited: + job.group = group.id if isinstance(group, CandidateGroup) else group def _ready(self, job): """Return whether the given job is ready to execute.""" @@ -2181,3 +2192,17 @@ def __str__(self): def __len__(self): return self._len + + +class CandidateGroup: + def __init__(self): + self.id = str(uuid.uuid4()) + + def __eq__(self, other): + return self.id == other.id + + def __hash__(self): + return hash(self.id) + + def merge(self, other): + self.id = other.id diff --git a/snakemake/jobs.py b/snakemake/jobs.py index eadd19787..a1546ebfe 100644 --- a/snakemake/jobs.py +++ b/snakemake/jobs.py @@ -289,6 +289,7 @@ def group(self): @group.setter def group(self, group): + print(group, type(group)) self._group = group @property diff --git a/tests/test_pipes_multiple/Snakefile b/tests/test_pipes_multiple/Snakefile new file mode 100644 index 000000000..ae6e8bbd4 --- /dev/null +++ b/tests/test_pipes_multiple/Snakefile @@ -0,0 +1,27 @@ +shell.executable("bash") + +rule all: + input: + "test.out" + +rule a: + output: + pipe("testa.{i}.txt") + shell: + "echo {wildcards.i} > {output}" + +rule b: + input: + rules.a.output + output: + pipe("testb.{i}.txt") + shell: + "cat {input} > {output}" + +rule c: + input: + expand(rules.b.output, i=range(2)) + output: + "test.out" + shell: + "cat {input} > {output}" diff --git a/tests/test_pipes_multiple/expected-results/test.out b/tests/test_pipes_multiple/expected-results/test.out new file mode 100644 index 000000000..0d66ea1ae --- /dev/null +++ b/tests/test_pipes_multiple/expected-results/test.out @@ -0,0 +1,2 @@ +0 +1 diff --git a/tests/tests.py b/tests/tests.py index 0b887f4aa..9511a63d2 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -832,6 +832,12 @@ def test_pipes(): run(dpath("test_pipes")) +@skip_on_windows +def test_pipes_multiple(): + # see github issue #975 + run(dpath("test_pipes_multiple")) + + def test_pipes_fail(): run(dpath("test_pipes_fail"), shouldfail=True)