Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: in rules from imported modules, exclude modified paths from module prefixing #1494

Merged
merged 11 commits into from Mar 23, 2022
50 changes: 35 additions & 15 deletions snakemake/rules.py
Expand Up @@ -118,7 +118,10 @@ def __init__(self, *args, lineno=None, snakefile=None, restart_times=0):
self.is_checkpoint = False
self.restart_times = 0
self.basedir = None
self.path_modifer = None
self.input_modifier = None
self.output_modifier = None
self.log_modifier = None
self.benchmark_modifier = None
self.ruleinfo = None
elif len(args) == 1:
other = args[0]
Expand Down Expand Up @@ -168,7 +171,10 @@ def __init__(self, *args, lineno=None, snakefile=None, restart_times=0):
self.is_checkpoint = other.is_checkpoint
self.restart_times = other.restart_times
self.basedir = other.basedir
self.path_modifier = other.path_modifier
self.input_modifier = other.input_modifier
self.output_modifier = other.output_modifier
self.log_modifier = other.log_modifier
self.benchmark_modifier = other.benchmark_modifier
self.ruleinfo = other.ruleinfo

def dynamic_branch(self, wildcards, input=True):
Expand Down Expand Up @@ -348,7 +354,9 @@ def benchmark(self, benchmark):
if isinstance(benchmark, Path):
benchmark = str(benchmark)
if not callable(benchmark):
benchmark = self.apply_path_modifier(benchmark, property="benchmark")
benchmark = self.apply_path_modifier(
benchmark, self.benchmark_modifier, property="benchmark"
)
benchmark = self._update_item_wildcard_constraints(benchmark)

self._benchmark = IOFile(benchmark, rule=self)
Expand Down Expand Up @@ -471,9 +479,9 @@ def check_output_duplicates(self):
)
seen[value] = name or idx

def apply_path_modifier(self, item, property=None):
assert self.path_modifier is not None
apply = partial(self.path_modifier.modify, property=property)
def apply_path_modifier(self, item, path_modifier, property=None):
assert path_modifier is not None
apply = partial(path_modifier.modify, property=property)

assert not callable(item)
if isinstance(item, dict):
Expand Down Expand Up @@ -528,9 +536,14 @@ def _set_inoutput_item(self, item, output=False, name=None):
if isinstance(item, _IOFile) and item.rule and item in item.rule.output:
rule_dependency = item.rule

item = self.apply_path_modifier(
item, property="output" if output else "input"
)
if output:
path_modifier = self.output_modifier
property = "output"
else:
path_modifier = self.input_modifier
property = "input"

item = self.apply_path_modifier(item, path_modifier, property=property)

# Check to see that all flags are valid
# Note that "remote", "dynamic", and "expand" are valid for both inputs and outputs.
Expand Down Expand Up @@ -685,7 +698,7 @@ def _set_log_item(self, item, name=None):
item = str(item)
if isinstance(item, str) or callable(item):
if not callable(item):
item = self.apply_path_modifier(item, property="log")
item = self.apply_path_modifier(item, self.log_modifier, property="log")
item = self._update_item_wildcard_constraints(item)

self.log.append(IOFile(item, rule=self) if isinstance(item, str) else item)
Expand Down Expand Up @@ -769,7 +782,7 @@ def _apply_wildcards(
mapping=None,
no_flattening=False,
aux_params=None,
apply_path_modifier=True,
path_modifier=None,
property=None,
incomplete_checkpoint_func=lambda e: None,
allow_unpack=True,
Expand Down Expand Up @@ -835,8 +848,10 @@ def _apply_wildcards(
"Function did not return str or list of str.", rule=self
)

if from_callable and apply_path_modifier and not incomplete:
item_ = self.apply_path_modifier(item_, property=property)
if from_callable and path_modifier is not None and not incomplete:
item_ = self.apply_path_modifier(
item_, path_modifier, property=property
)

concrete = concretize(item_, wildcards, _is_callable)
newitems.append(concrete)
Expand Down Expand Up @@ -882,6 +897,7 @@ def handle_incomplete_checkpoint(exception):
concretize=concretize_iofile,
mapping=mapping,
incomplete_checkpoint_func=handle_incomplete_checkpoint,
path_modifier=self.input_modifier,
property="input",
groupid=groupid,
)
Expand Down Expand Up @@ -944,7 +960,6 @@ def handle_incomplete_checkpoint(exception):
omit_callable=omit_callable,
allow_unpack=False,
no_flattening=True,
apply_path_modifier=False,
property="params",
aux_params={
"input": input._plainstrings(),
Expand Down Expand Up @@ -993,7 +1008,12 @@ def concretize_logfile(f, wildcards, is_from_callable):

try:
self._apply_wildcards(
log, self.log, wildcards, concretize=concretize_logfile, property="log"
log,
self.log,
wildcards,
concretize=concretize_logfile,
path_modifier=self.log_modifier,
property="log",
)
except WildcardError as e:
raise WildcardError(
Expand Down
35 changes: 21 additions & 14 deletions snakemake/workflow.py
Expand Up @@ -1249,7 +1249,7 @@ def func(*args, **wildcards):
return expand(
*args,
scatteritem=map("{{}}-of-{}".format(n).format, range(1, n + 1)),
**wildcards
**wildcards,
)

for key in content:
Expand Down Expand Up @@ -1369,18 +1369,21 @@ def decorate(ruleinfo):
if ruleinfo.wildcard_constraints:
rule.set_wildcard_constraints(
*ruleinfo.wildcard_constraints[0],
**ruleinfo.wildcard_constraints[1]
**ruleinfo.wildcard_constraints[1],
)
if ruleinfo.name:
rule.name = ruleinfo.name
del self._rules[name]
self._rules[ruleinfo.name] = rule
name = rule.name
rule.path_modifier = ruleinfo.path_modifier
if ruleinfo.input:
rule.set_input(*ruleinfo.input[0], **ruleinfo.input[1])
pos_files, keyword_files, modifier = ruleinfo.input
rule.input_modifier = modifier
rule.set_input(*pos_files, **keyword_files)
if ruleinfo.output:
rule.set_output(*ruleinfo.output[0], **ruleinfo.output[1])
pos_files, keyword_files, modifier = ruleinfo.output
rule.output_modifier = modifier
rule.set_output(*pos_files, **keyword_files)
if ruleinfo.params:
rule.set_params(*ruleinfo.params[0], **ruleinfo.params[1])
# handle default resources
Expand Down Expand Up @@ -1418,9 +1421,9 @@ def decorate(ruleinfo):
if ruleinfo.shadow_depth is True:
rule.shadow_depth = "full"
logger.warning(
"Shadow is set to True in rule {} (equivalent to 'full'). It's encouraged to use the more explicit options 'minimal|copy-minimal|shallow|full' instead.".format(
rule
)
f"Shadow is set to True in rule {rule} (equivalent to 'full'). "
"It's encouraged to use the more explicit options "
"'minimal|copy-minimal|shallow|full' instead."
)
else:
rule.shadow_depth = ruleinfo.shadow_depth
Expand Down Expand Up @@ -1455,11 +1458,15 @@ def decorate(ruleinfo):
if ruleinfo.version:
rule.version = ruleinfo.version
if ruleinfo.log:
rule.set_log(*ruleinfo.log[0], **ruleinfo.log[1])
pos_files, keyword_files, modifier = ruleinfo.log
rule.log_modifier = modifier
rule.set_log(*pos_files, **keyword_files)
if ruleinfo.message:
rule.message = ruleinfo.message
if ruleinfo.benchmark:
rule.benchmark = ruleinfo.benchmark
benchmark, modifier = ruleinfo.benchmark
rule.benchmark_modifier = modifier
rule.benchmark = benchmark
if not self.run_local:
group = self.overwrite_groups.get(name) or ruleinfo.group
if group is not None:
Expand Down Expand Up @@ -1627,14 +1634,14 @@ def decorate(ruleinfo):

def input(self, *paths, **kwpaths):
def decorate(ruleinfo):
ruleinfo.input = (paths, kwpaths)
ruleinfo.input = (paths, kwpaths, self.modifier.path_modifier)
return ruleinfo

return decorate

def output(self, *paths, **kwpaths):
def decorate(ruleinfo):
ruleinfo.output = (paths, kwpaths)
ruleinfo.output = (paths, kwpaths, self.modifier.path_modifier)
return ruleinfo

return decorate
Expand Down Expand Up @@ -1679,7 +1686,7 @@ def decorate(ruleinfo):

def benchmark(self, benchmark):
def decorate(ruleinfo):
ruleinfo.benchmark = benchmark
ruleinfo.benchmark = (benchmark, self.modifier.path_modifier)
return ruleinfo

return decorate
Expand Down Expand Up @@ -1770,7 +1777,7 @@ def decorate(ruleinfo):

def log(self, *logs, **kwlogs):
def decorate(ruleinfo):
ruleinfo.log = (logs, kwlogs)
ruleinfo.log = (logs, kwlogs, self.modifier.path_modifier)
return ruleinfo

return decorate
Expand Down
27 changes: 27 additions & 0 deletions tests/test_module_no_prefixing_modified_paths/Snakefile
@@ -0,0 +1,27 @@
module module1:
snakefile: "module1/Snakefile"
config: config


use rule * from module1 as module1_*


# provide a prefix for all paths in module2
module module2:
snakefile: "module2/Snakefile"
config: config
prefix: "module2"


use rule * from module2 as module2_*

# overwrite the input to remove the module2 prefix specified above
use rule c from module2 as module2_c with:
input:
"test.txt"


rule joint_all:
input:
"module2/test_final.txt",
default_target: True
@@ -0,0 +1,2 @@
test_a
test_c
@@ -0,0 +1,5 @@
rule a:
output:
"test.txt"
shell:
"echo test_a > {output}"
17 changes: 17 additions & 0 deletions tests/test_module_no_prefixing_modified_paths/module2/Snakefile
@@ -0,0 +1,17 @@
rule b:
output:
"test.txt"
shell:
"echo test_b > {output}"


rule c:
input:
"test.txt"
output:
"test_final.txt"
shell:
"""
cp {input} {output};
echo test_c >> {output}
"""
8 changes: 8 additions & 0 deletions tests/tests.py
Expand Up @@ -1390,6 +1390,14 @@ def test_module_complex2():
run(dpath("test_module_complex2"), dryrun=True)


@skip_on_windows
def test_module_no_prefixing_modified_paths():
run(
dpath("test_module_no_prefixing_modified_paths"),
targets=["module2/test_final.txt"],
)


def test_module_with_script():
run(dpath("test_module_with_script"))

Expand Down