Skip to content

Commit

Permalink
fix: in rules from imported modules, exclude modified paths from modu…
Browse files Browse the repository at this point in the history
…le prefixing (#1494)
  • Loading branch information
dlaehnemann committed Mar 23, 2022
1 parent 9e92d63 commit 1e73db0
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 27 deletions.
50 changes: 35 additions & 15 deletions snakemake/rules.py
Expand Up @@ -118,7 +118,10 @@ def __init__(self, *args, lineno=None, snakefile=None, restart_times=0):
self.is_checkpoint = False
self.restart_times = 0
self.basedir = None
self.path_modifer = None
self.input_modifier = None
self.output_modifier = None
self.log_modifier = None
self.benchmark_modifier = None
self.ruleinfo = None
elif len(args) == 1:
other = args[0]
Expand Down Expand Up @@ -168,7 +171,10 @@ def __init__(self, *args, lineno=None, snakefile=None, restart_times=0):
self.is_checkpoint = other.is_checkpoint
self.restart_times = other.restart_times
self.basedir = other.basedir
self.path_modifier = other.path_modifier
self.input_modifier = other.input_modifier
self.output_modifier = other.output_modifier
self.log_modifier = other.log_modifier
self.benchmark_modifier = other.benchmark_modifier
self.ruleinfo = other.ruleinfo

def dynamic_branch(self, wildcards, input=True):
Expand Down Expand Up @@ -348,7 +354,9 @@ def benchmark(self, benchmark):
if isinstance(benchmark, Path):
benchmark = str(benchmark)
if not callable(benchmark):
benchmark = self.apply_path_modifier(benchmark, property="benchmark")
benchmark = self.apply_path_modifier(
benchmark, self.benchmark_modifier, property="benchmark"
)
benchmark = self._update_item_wildcard_constraints(benchmark)

self._benchmark = IOFile(benchmark, rule=self)
Expand Down Expand Up @@ -471,9 +479,9 @@ def check_output_duplicates(self):
)
seen[value] = name or idx

def apply_path_modifier(self, item, property=None):
assert self.path_modifier is not None
apply = partial(self.path_modifier.modify, property=property)
def apply_path_modifier(self, item, path_modifier, property=None):
assert path_modifier is not None
apply = partial(path_modifier.modify, property=property)

assert not callable(item)
if isinstance(item, dict):
Expand Down Expand Up @@ -528,9 +536,14 @@ def _set_inoutput_item(self, item, output=False, name=None):
if isinstance(item, _IOFile) and item.rule and item in item.rule.output:
rule_dependency = item.rule

item = self.apply_path_modifier(
item, property="output" if output else "input"
)
if output:
path_modifier = self.output_modifier
property = "output"
else:
path_modifier = self.input_modifier
property = "input"

item = self.apply_path_modifier(item, path_modifier, property=property)

# Check to see that all flags are valid
# Note that "remote", "dynamic", and "expand" are valid for both inputs and outputs.
Expand Down Expand Up @@ -685,7 +698,7 @@ def _set_log_item(self, item, name=None):
item = str(item)
if isinstance(item, str) or callable(item):
if not callable(item):
item = self.apply_path_modifier(item, property="log")
item = self.apply_path_modifier(item, self.log_modifier, property="log")
item = self._update_item_wildcard_constraints(item)

self.log.append(IOFile(item, rule=self) if isinstance(item, str) else item)
Expand Down Expand Up @@ -769,7 +782,7 @@ def _apply_wildcards(
mapping=None,
no_flattening=False,
aux_params=None,
apply_path_modifier=True,
path_modifier=None,
property=None,
incomplete_checkpoint_func=lambda e: None,
allow_unpack=True,
Expand Down Expand Up @@ -835,8 +848,10 @@ def _apply_wildcards(
"Function did not return str or list of str.", rule=self
)

if from_callable and apply_path_modifier and not incomplete:
item_ = self.apply_path_modifier(item_, property=property)
if from_callable and path_modifier is not None and not incomplete:
item_ = self.apply_path_modifier(
item_, path_modifier, property=property
)

concrete = concretize(item_, wildcards, _is_callable)
newitems.append(concrete)
Expand Down Expand Up @@ -882,6 +897,7 @@ def handle_incomplete_checkpoint(exception):
concretize=concretize_iofile,
mapping=mapping,
incomplete_checkpoint_func=handle_incomplete_checkpoint,
path_modifier=self.input_modifier,
property="input",
groupid=groupid,
)
Expand Down Expand Up @@ -944,7 +960,6 @@ def handle_incomplete_checkpoint(exception):
omit_callable=omit_callable,
allow_unpack=False,
no_flattening=True,
apply_path_modifier=False,
property="params",
aux_params={
"input": input._plainstrings(),
Expand Down Expand Up @@ -993,7 +1008,12 @@ def concretize_logfile(f, wildcards, is_from_callable):

try:
self._apply_wildcards(
log, self.log, wildcards, concretize=concretize_logfile, property="log"
log,
self.log,
wildcards,
concretize=concretize_logfile,
path_modifier=self.log_modifier,
property="log",
)
except WildcardError as e:
raise WildcardError(
Expand Down
31 changes: 19 additions & 12 deletions snakemake/workflow.py
Expand Up @@ -1386,11 +1386,14 @@ def decorate(ruleinfo):
del self._rules[name]
self._rules[ruleinfo.name] = rule
name = rule.name
rule.path_modifier = ruleinfo.path_modifier
if ruleinfo.input:
rule.set_input(*ruleinfo.input[0], **ruleinfo.input[1])
pos_files, keyword_files, modifier = ruleinfo.input
rule.input_modifier = modifier
rule.set_input(*pos_files, **keyword_files)
if ruleinfo.output:
rule.set_output(*ruleinfo.output[0], **ruleinfo.output[1])
pos_files, keyword_files, modifier = ruleinfo.output
rule.output_modifier = modifier
rule.set_output(*pos_files, **keyword_files)
if ruleinfo.params:
rule.set_params(*ruleinfo.params[0], **ruleinfo.params[1])
# handle default resources
Expand Down Expand Up @@ -1428,9 +1431,9 @@ def decorate(ruleinfo):
if ruleinfo.shadow_depth is True:
rule.shadow_depth = "full"
logger.warning(
"Shadow is set to True in rule {} (equivalent to 'full'). It's encouraged to use the more explicit options 'minimal|copy-minimal|shallow|full' instead.".format(
rule
)
f"Shadow is set to True in rule {rule} (equivalent to 'full'). "
"It's encouraged to use the more explicit options "
"'minimal|copy-minimal|shallow|full' instead."
)
else:
rule.shadow_depth = ruleinfo.shadow_depth
Expand Down Expand Up @@ -1465,11 +1468,15 @@ def decorate(ruleinfo):
if ruleinfo.version:
rule.version = ruleinfo.version
if ruleinfo.log:
rule.set_log(*ruleinfo.log[0], **ruleinfo.log[1])
pos_files, keyword_files, modifier = ruleinfo.log
rule.log_modifier = modifier
rule.set_log(*pos_files, **keyword_files)
if ruleinfo.message:
rule.message = ruleinfo.message
if ruleinfo.benchmark:
rule.benchmark = ruleinfo.benchmark
benchmark, modifier = ruleinfo.benchmark
rule.benchmark_modifier = modifier
rule.benchmark = benchmark
if not self.run_local:
group = self.overwrite_groups.get(name) or ruleinfo.group
if group is not None:
Expand Down Expand Up @@ -1637,14 +1644,14 @@ def decorate(ruleinfo):

def input(self, *paths, **kwpaths):
def decorate(ruleinfo):
ruleinfo.input = (paths, kwpaths)
ruleinfo.input = (paths, kwpaths, self.modifier.path_modifier)
return ruleinfo

return decorate

def output(self, *paths, **kwpaths):
def decorate(ruleinfo):
ruleinfo.output = (paths, kwpaths)
ruleinfo.output = (paths, kwpaths, self.modifier.path_modifier)
return ruleinfo

return decorate
Expand Down Expand Up @@ -1689,7 +1696,7 @@ def decorate(ruleinfo):

def benchmark(self, benchmark):
def decorate(ruleinfo):
ruleinfo.benchmark = benchmark
ruleinfo.benchmark = (benchmark, self.modifier.path_modifier)
return ruleinfo

return decorate
Expand Down Expand Up @@ -1780,7 +1787,7 @@ def decorate(ruleinfo):

def log(self, *logs, **kwlogs):
def decorate(ruleinfo):
ruleinfo.log = (logs, kwlogs)
ruleinfo.log = (logs, kwlogs, self.modifier.path_modifier)
return ruleinfo

return decorate
Expand Down
27 changes: 27 additions & 0 deletions tests/test_module_no_prefixing_modified_paths/Snakefile
@@ -0,0 +1,27 @@
module module1:
snakefile: "module1/Snakefile"
config: config


use rule * from module1 as module1_*


# provide a prefix for all paths in module2
module module2:
snakefile: "module2/Snakefile"
config: config
prefix: "module2"


use rule * from module2 as module2_*

# overwrite the input to remove the module2 prefix specified above
use rule c from module2 as module2_c with:
input:
"test.txt"


rule joint_all:
input:
"module2/test_final.txt",
default_target: True
@@ -0,0 +1,2 @@
test_a
test_c
@@ -0,0 +1,5 @@
rule a:
output:
"test.txt"
shell:
"echo test_a > {output}"
17 changes: 17 additions & 0 deletions tests/test_module_no_prefixing_modified_paths/module2/Snakefile
@@ -0,0 +1,17 @@
rule b:
output:
"test.txt"
shell:
"echo test_b > {output}"


rule c:
input:
"test.txt"
output:
"test_final.txt"
shell:
"""
cp {input} {output};
echo test_c >> {output}
"""
8 changes: 8 additions & 0 deletions tests/tests.py
Expand Up @@ -1409,6 +1409,14 @@ def test_module_complex2():
run(dpath("test_module_complex2"), dryrun=True)


@skip_on_windows
def test_module_no_prefixing_modified_paths():
run(
dpath("test_module_no_prefixing_modified_paths"),
targets=["module2/test_final.txt"],
)


def test_module_with_script():
run(dpath("test_module_with_script"))

Expand Down

0 comments on commit 1e73db0

Please sign in to comment.