From 4816a58653e466ca94b1482a1d947a856f5381b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20K=C3=B6ster?= Date: Tue, 2 Nov 2021 18:15:53 +0100 Subject: [PATCH] fix: only consider context of shell command for technical switches if called from snakemake rules. (#1213) --- snakemake/common/__init__.py | 3 +++ snakemake/parser.py | 6 ++++-- snakemake/shell.py | 24 +++++++++++++++++------- 3 files changed, 24 insertions(+), 9 deletions(-) diff --git a/snakemake/common/__init__.py b/snakemake/common/__init__.py index 1ade8f179..abb31b5ab 100644 --- a/snakemake/common/__init__.py +++ b/snakemake/common/__init__.py @@ -69,6 +69,9 @@ def __new__(cls, _=None): APPDIRS = None +RULEFUNC_CONTEXT_MARKER = "__is_snakemake_rule_func" + + def get_appdirs(): global APPDIRS if APPDIRS is None: diff --git a/snakemake/parser.py b/snakemake/parser.py index a5440e7f0..d0546769f 100644 --- a/snakemake/parser.py +++ b/snakemake/parser.py @@ -11,6 +11,7 @@ from io import TextIOWrapper from snakemake.exceptions import WorkflowError +from snakemake import common dd = textwrap.dedent @@ -507,10 +508,11 @@ def start(self): "resources, log, version, rule, conda_env, container_img, " "singularity_args, use_singularity, env_modules, bench_record, jobid, " "is_shell, bench_iteration, cleanup_scripts, shadow_dir, edit_notebook, " - "conda_base_path, basedir, runtime_sourcecache_path):".format( + "conda_base_path, basedir, runtime_sourcecache_path, {rule_func_marker}=True):".format( rulename=self.rulename if self.rulename is not None - else self.snakefile.rulecount + else self.snakefile.rulecount, + rule_func_marker=common.RULEFUNC_CONTEXT_MARKER, ) ) diff --git a/snakemake/shell.py b/snakemake/shell.py index 2c9f543dc..ea75cd411 100644 --- a/snakemake/shell.py +++ b/snakemake/shell.py @@ -14,7 +14,7 @@ import threading from snakemake.utils import format, argvquote, cmd_exe_quote, find_bash_on_windows -from snakemake.common import ON_WINDOWS +from snakemake.common import ON_WINDOWS, RULEFUNC_CONTEXT_MARKER from snakemake.logging import logger from snakemake.deployment import singularity from snakemake.deployment.conda import Conda @@ -130,14 +130,23 @@ def __new__( kwargs["quote_func"] = cmd_exe_quote cmd = format(cmd, *args, stepout=2, **kwargs) - context = inspect.currentframe().f_back.f_locals - # add kwargs to context (overwriting the locals of the caller) - context.update(kwargs) stdout = sp.PIPE if iterable or read else STDOUT close_fds = sys.platform != "win32" + func_context = inspect.currentframe().f_back.f_locals + + if func_context.get(RULEFUNC_CONTEXT_MARKER): + # If this comes from a rule, we expect certain information to be passed + # implicitly via the rule func context, which is added here. + context = func_context + else: + # Otherwise, context is just filled via kwargs. + context = dict() + # add kwargs to context (overwriting the locals of the caller) + context.update(kwargs) + jobid = context.get("jobid") if not context.get("is_shell"): logger.shellcmd(cmd) @@ -148,6 +157,8 @@ def __new__( env_modules = context.get("env_modules", None) shadow_dir = context.get("shadow_dir", None) resources = context.get("resources", {}) + singularity_args = context.get("singularity_args", "") + threads = context.get("threads", 1) cmd = " ".join((cls._process_prefix, cmd, cls._process_suffix)).strip() @@ -176,11 +187,10 @@ def __new__( cmd = '"{}" "{}"'.format(cls.get_executable() or "/bin/sh", script) if container_img: - args = context.get("singularity_args", "") cmd = singularity.shellcmd( container_img, cmd, - args, + singularity_args, envvars=None, shell_executable=cls._process_args["executable"], container_workdir=shadow_dir, @@ -190,12 +200,12 @@ def __new__( if conda_env: logger.info("Activating conda environment: {}".format(conda_env)) - threads = str(context.get("threads", 1)) tmpdir_resource = resources.get("tmpdir", None) # environment variable lists for linear algebra libraries taken from: # https://stackoverflow.com/a/53224849/2352071 # https://github.com/xianyi/OpenBLAS/tree/59243d49ab8e958bb3872f16a7c0ef8c04067c0a#setting-the-number-of-threads-using-environment-variables envvars = dict(os.environ) + threads = str(threads) envvars["OMP_NUM_THREADS"] = threads envvars["GOTO_NUM_THREADS"] = threads envvars["OPENBLAS_NUM_THREADS"] = threads