From bdb75f828a3ae27ba97ea6cd5e71a34ac7b27eea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20K=C3=B6ster?= Date: Fri, 24 Sep 2021 12:53:27 +0200 Subject: [PATCH] perf: more extensive caching of source files, including wrappers. (#1182) * initial draft of more caching of source files * fixes * fixes * add retry behavior * fixes * fixes * current basedir joining * fixes * fix config schema handling * fix * fix path * fix * fix path simplification * fmt * fix lints * fix lints * handle different path types * fixes * fix arg name * handle local file URLs * fix * fixes * fixes * remove retry code * fix archive * handle is_local * exception handling fix * dbg * fixes * gitlab support, docs * minor --- docs/snakefiles/deployment.rst | 18 +- docs/snakefiles/modularization.rst | 39 ++- snakemake/cwl.py | 1 + snakemake/deployment/conda.py | 24 +- snakemake/exceptions.py | 5 + snakemake/executors/__init__.py | 22 +- snakemake/io.py | 4 +- snakemake/notebook.py | 11 +- snakemake/parser.py | 12 +- snakemake/rules.py | 2 +- snakemake/script.py | 142 +++++----- snakemake/sourcecache.py | 327 +++++++++++++++++++++--- snakemake/utils.py | 21 +- snakemake/workflow.py | 68 ++--- snakemake/wrapper.py | 40 ++- test-environment.yml | 2 +- tests/test_module_with_script/Snakefile | 4 +- 17 files changed, 538 insertions(+), 204 deletions(-) diff --git a/docs/snakefiles/deployment.rst b/docs/snakefiles/deployment.rst index 9fe0e3246..4e67cefed 100644 --- a/docs/snakefiles/deployment.rst +++ b/docs/snakefiles/deployment.rst @@ -69,14 +69,20 @@ Consider the following example: configfile: "config/config.yaml" module dna_seq: - snakefile: "https://github.com/snakemake-workflows/dna-seq-gatk-variant-calling/raw/v2.0.1/Snakefile" - config: config + snakefile: + # here, it is also possible to provide a plain raw URL like "https://github.com/snakemake-workflows/dna-seq-gatk-variant-calling/raw/v2.0.1/workflow/Snakefile" + github("snakemake-workflows/dna-seq-gatk-variant-calling", path="workflow/Snakefile" tag="v2.0.1") + config: + config use rule * from dna_seq First, we load a local configuration file. -Next, we define the module ``dna_seq`` to be loaded from the URL ``https://github.com/snakemake-workflows/dna-seq-gatk-variant-calling/blob/v2.0.1/Snakefile``, while using the contents of the local configuration file. +Next, we define the module ``dna_seq`` to be loaded from the URL ``https://github.com/snakemake-workflows/dna-seq-gatk-variant-calling/raw/v2.0.1/workflow/Snakefile``, while using the contents of the local configuration file. +Note that it is possible to either specify the full URL pointing to the raw Snakefile as a string or to use the github marker as done here. +With the latter, Snakemake can however cache the used source files persistently (if a tag is given), such that they don't have to be downloaded on each invocation. Finally we declare all rules of the dna_seq module to be used. + This kind of deployment is equivalent to just cloning the original repository and modifying the configuration in it. However, the advantage here is that we are (a) able to easily extend of modify the workflow, while making the changes transparent, and (b) we can store this workflow in a separate (e.g. private) git repository, along with for example configuration and meta data, without the need to duplicate the workflow code. Finally, we are always able to later combine another module into the current workflow, e.g. when further kinds of analyses are needed. 
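The module declaration above can be written out as one self-contained Snakefile. The following is a hedged restatement of the documented example (same repository, path and tag as in the dna-seq case; substitute your own), with the keyword arguments of ``github()`` spelled out and comma-separated:

.. code-block:: python

    configfile: "config/config.yaml"

    module dna_seq:
        snakefile:
            github(
                "snakemake-workflows/dna-seq-gatk-variant-calling",
                path="workflow/Snakefile",
                tag="v2.0.1",
            )
        config:
            config

    use rule * from dna_seq

Because a tag is pinned, Snakemake can place the downloaded Snakefile into its persistent source cache and reuse it on later invocations instead of downloading it again.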
@@ -92,7 +98,9 @@ For example, we can easily add another rule to extend the given workflow: configfile: "config/config.yaml" module dna_seq: - snakefile: "https://github.com/snakemake-workflows/dna-seq-gatk-variant-calling/raw/v2.0.1/Snakefile" + snakefile: + # here, it is also possible to provide a plain raw URL like "https://github.com/snakemake-workflows/dna-seq-gatk-variant-calling/raw/v2.0.1/workflow/Snakefile" + github("snakemake-workflows/dna-seq-gatk-variant-calling", path="workflow/Snakefile" tag="v2.0.1") config: config use rule * from dna_seq @@ -106,6 +114,8 @@ For example, we can easily add another rule to extend the given workflow: notebook: "notebooks/plot-vafs.py.ipynb" +Moreover, it is possible to further extend the workflow with other modules, thereby generating an integrative analysis. + ---------------------------------- Uploading workflows to WorkflowHub ---------------------------------- diff --git a/docs/snakefiles/modularization.rst b/docs/snakefiles/modularization.rst index e5a6a4c6a..128ee9e23 100644 --- a/docs/snakefiles/modularization.rst +++ b/docs/snakefiles/modularization.rst @@ -126,12 +126,14 @@ With Snakemake 6.0 and later, it is possible to define external workflows as mod min_version("6.0") module other_workflow: - snakefile: "other_workflow/Snakefile" + snakefile: + # here, plain paths, URLs and the special markers for code hosting providers (see below) are possible. + "other_workflow/Snakefile" use rule * from other_workflow as other_* The first statement registers the external workflow as a module, by defining the path to the main snakefile. -The snakefile property of the module can either take a local path or a HTTP/HTTPS url. +Here, plain paths, HTTP/HTTPS URLs and special markers for code hosting providers like Github or Gitlab are possible (see :ref:`snakefile-code-hosting-providers`). The second statement declares all rules of that module to be used in the current one. Thereby, the ``as other_*`` at the end renames all those rule with a common prefix. This can be handy to avoid rule name conflicts (note that rules from modules can otherwise overwrite rules from your current workflow or other modules). @@ -149,6 +151,7 @@ It is possible to overwrite the global config dictionary for the module, which i configfile: "config/config.yaml" module other_workflow: + # here, plain paths, URLs and the special markers for code hosting providers (see below) are possible. snakefile: "other_workflow/Snakefile" config: config["other-workflow"] @@ -167,6 +170,7 @@ This modification can be performed after a general import, and will overwrite an min_version("6.0") module other_workflow: + # here, plain paths, URLs and the special markers for code hosting providers (see below) are possible. snakefile: "other_workflow/Snakefile" config: config["other-workflow"] @@ -261,3 +265,34 @@ This function automatically determines the absolute path to the file (here ``../ When executing, snakemake first tries to create (or update, if necessary) ``test.txt`` (and all other possibly mentioned dependencies) by executing the subworkflow. Then the current workflow is executed. This can also happen recursively, since the subworkflow may have its own subworkflows as well. + + +.. _snakefile-code-hosting-providers: + +---------------------- +Code hosting providers +---------------------- + +To obtain the correct URL to an external source code resource (e.g. a snakefile, see :ref:`snakefiles-modules`), Snakemake provides markers for code hosting providers. +Currently, Github + +.. 
code-block:: python + + github("owner/repo", path="workflow/Snakefile", tag="v1.0.0") + + +and Gitlab are supported: + +.. code-block:: python + + gitlab("owner/repo", path="workflow/Snakefile", tag="v1.0.0") + +For the latter, it is also possible to specify an alternative host, e.g. + +.. code-block:: python + + gitlab("owner/repo", path="workflow/Snakefile", tag="v1.0.0", host="somecustomgitlab.org") + + +While specifying a tag is highly encouraged, it is alternatively possible to specify a `commit` or a `branch` via respective keyword arguments. +Note that only when specifying a tag or a commit, Snakemake is able to persistently cache the source, thereby avoiding to repeatedly query it in case of multiple executions. diff --git a/snakemake/cwl.py b/snakemake/cwl.py index 245b24d17..8bb5764e9 100644 --- a/snakemake/cwl.py +++ b/snakemake/cwl.py @@ -35,6 +35,7 @@ def cwl( use_singularity, bench_record, jobid, + runtime_sourcecache_path, ): """ Load cwl from the given basedir + path and execute it. diff --git a/snakemake/deployment/conda.py b/snakemake/deployment/conda.py index 399647e35..7c839d6a6 100644 --- a/snakemake/deployment/conda.py +++ b/snakemake/deployment/conda.py @@ -5,6 +5,7 @@ import os import re +from snakemake.sourcecache import LocalGitFile, LocalSourceFile, infer_source_file import subprocess import tempfile from urllib.request import urlopen @@ -45,7 +46,7 @@ class Env: def __init__( self, env_file, workflow, env_dir=None, container_img=None, cleanup=None ): - self.file = env_file + self.file = infer_source_file(env_file) self.frontend = workflow.conda_frontend self.workflow = workflow @@ -161,7 +162,9 @@ def create_archive(self): try: # Download logger.info( - "Downloading packages for conda environment {}...".format(self.file) + "Downloading packages for conda environment {}...".format( + self.file.get_path_or_uri() + ) ) os.makedirs(env_archive, exist_ok=True) try: @@ -216,11 +219,16 @@ def create(self, dryrun=False): env_file = self.file tmp_file = None - if not is_local_file(env_file) or env_file.startswith("git+file:/"): + if not isinstance(env_file, LocalSourceFile) or isinstance( + env_file, LocalGitFile + ): with tempfile.NamedTemporaryFile(delete=False, suffix=".yaml") as tmp: + # write to temp file such that conda can open it tmp.write(self.content) env_file = tmp.name tmp_file = tmp.name + else: + env_file = env_file.get_path_or_uri() env_hash = self.hash env_path = self.path @@ -258,13 +266,13 @@ def create(self, dryrun=False): if dryrun: logger.info( "Incomplete Conda environment {} will be recreated.".format( - utils.simplify_path(self.file) + self.file.simplify_path() ) ) else: logger.info( "Removing incomplete Conda environment {}...".format( - utils.simplify_path(self.file) + self.file.simplify_path() ) ) shutil.rmtree(env_path, ignore_errors=True) @@ -274,15 +282,13 @@ def create(self, dryrun=False): if dryrun: logger.info( "Conda environment {} will be created.".format( - utils.simplify_path(self.file) + self.file.simplify_path() ) ) return env_path conda = Conda(self._container_img) logger.info( - "Creating conda environment {}...".format( - utils.simplify_path(self.file) - ) + "Creating conda environment {}...".format(self.file.simplify_path()) ) # Check if env archive exists. Use that if present. 
env_archive = self.archive_file diff --git a/snakemake/exceptions.py b/snakemake/exceptions.py index ec76a7867..3e7080c08 100644 --- a/snakemake/exceptions.py +++ b/snakemake/exceptions.py @@ -164,6 +164,11 @@ def __init__(self, *args, lineno=None, snakefile=None, rule=None): self.rule = rule +class SourceFileError(WorkflowError): + def __init__(self, msg): + super().__init__("Error in source file definition: {}".format(msg)) + + class WildcardError(WorkflowError): pass diff --git a/snakemake/executors/__init__.py b/snakemake/executors/__init__.py index 745f24315..b8717e52c 100644 --- a/snakemake/executors/__init__.py +++ b/snakemake/executors/__init__.py @@ -515,6 +515,7 @@ def job_args_and_prepare(self, job): self.workflow.edit_notebook, self.workflow.conda_base_path, job.rule.basedir, + self.workflow.sourcecache.runtime_cache_path, ) def run_single_job(self, job): @@ -2296,6 +2297,7 @@ def run_wrapper( edit_notebook, conda_base_path, basedir, + runtime_sourcecache_path, ): """ Wrapper around the run method that handles exceptions and benchmarking. @@ -2376,6 +2378,7 @@ def run_wrapper( edit_notebook, conda_base_path, basedir, + runtime_sourcecache_path, ) else: # The benchmarking is started here as we have a run section @@ -2406,6 +2409,7 @@ def run_wrapper( edit_notebook, conda_base_path, basedir, + runtime_sourcecache_path, ) # Store benchmark record for this iteration bench_records.append(bench_record) @@ -2434,20 +2438,26 @@ def run_wrapper( edit_notebook, conda_base_path, basedir, + runtime_sourcecache_path, ) except (KeyboardInterrupt, SystemExit) as e: # Re-raise the keyboard interrupt in order to record an error in the # scheduler but ignore it raise e except (Exception, BaseException) as ex: - log_verbose_traceback(ex) # this ensures that exception can be re-raised in the parent thread - lineno, file = get_exception_origin(ex, linemaps) - raise RuleException( - format_error( - ex, lineno, linemaps=linemaps, snakefile=file, show_traceback=True + origin = get_exception_origin(ex, linemaps) + if origin is not None: + log_verbose_traceback(ex) + lineno, file = origin + raise RuleException( + format_error( + ex, lineno, linemaps=linemaps, snakefile=file, show_traceback=True + ) ) - ) + else: + # some internal bug, just reraise + raise ex if benchmark is not None: try: diff --git a/snakemake/io.py b/snakemake/io.py index ab4578611..b4d192aa6 100755 --- a/snakemake/io.py +++ b/snakemake/io.py @@ -1345,11 +1345,11 @@ def git_content(git_file): """ This function will extract a file from a git repository, one located on the filesystem. 
- Expected format is git+file:///path/to/your/repo/path_to_file@@version + Expected format is git+file:///path/to/your/repo/path_to_file@version Args: env_file (str): consist of path to repo, @, version and file information - Ex: git+file:////home/smeds/snakemake-wrappers/bio/fastqc/wrapper.py@0.19.3 + Ex: git+file:///home/smeds/snakemake-wrappers/bio/fastqc/wrapper.py@0.19.3 Returns: file content or None if the expected format isn't meet """ diff --git a/snakemake/notebook.py b/snakemake/notebook.py index f97729e51..5cf4a8143 100644 --- a/snakemake/notebook.py +++ b/snakemake/notebook.py @@ -10,6 +10,7 @@ from snakemake.logging import logger from snakemake.common import is_local_file from snakemake.common import ON_WINDOWS +from snakemake.sourcecache import SourceCache KERNEL_STARTED_RE = re.compile(r"Kernel started: (?P\S+)") KERNEL_SHUTDOWN_RE = re.compile(r"Kernel shutdown: (?P\S+)") @@ -152,6 +153,7 @@ def get_preamble(self): self.bench_iteration, self.cleanup_scripts, self.shadow_dir, + self.is_local, preamble_addendum=preamble_addendum, ) @@ -218,7 +220,8 @@ def notebook( bench_iteration, cleanup_scripts, shadow_dir, - edit=None, + edit, + runtime_sourcecache_path, ): """ Load a script from the given basedir + path and execute it. @@ -251,9 +254,12 @@ def notebook( ) if not draft: - path, source, language = get_source(path, basedir, wildcards, params) + path, source, language, is_local = get_source( + path, SourceCache(runtime_sourcecache_path), basedir, wildcards, params + ) else: source = None + is_local = True exec_class = get_exec_class(language) @@ -280,6 +286,7 @@ def notebook( bench_iteration, cleanup_scripts, shadow_dir, + is_local, ) if draft: diff --git a/snakemake/parser.py b/snakemake/parser.py index 4faa8de67..a5440e7f0 100644 --- a/snakemake/parser.py +++ b/snakemake/parser.py @@ -507,7 +507,7 @@ def start(self): "resources, log, version, rule, conda_env, container_img, " "singularity_args, use_singularity, env_modules, bench_record, jobid, " "is_shell, bench_iteration, cleanup_scripts, shadow_dir, edit_notebook, " - "conda_base_path, basedir):".format( + "conda_base_path, basedir, runtime_sourcecache_path):".format( rulename=self.rulename if self.rulename is not None else self.snakefile.rulecount @@ -608,7 +608,7 @@ def args(self): yield ( ", basedir, input, output, params, wildcards, threads, resources, log, " "config, rule, conda_env, conda_base_path, container_img, singularity_args, env_modules, " - "bench_record, jobid, bench_iteration, cleanup_scripts, shadow_dir" + "bench_record, jobid, bench_iteration, cleanup_scripts, shadow_dir, runtime_sourcecache_path" ) @@ -621,7 +621,7 @@ def args(self): ", basedir, input, output, params, wildcards, threads, resources, log, " "config, rule, conda_env, conda_base_path, container_img, singularity_args, env_modules, " "bench_record, jobid, bench_iteration, cleanup_scripts, shadow_dir, " - "edit_notebook" + "edit_notebook, runtime_sourcecache_path" ) @@ -634,7 +634,7 @@ def args(self): ", input, output, params, wildcards, threads, resources, log, " "config, rule, conda_env, conda_base_path, container_img, singularity_args, env_modules, " "bench_record, workflow.wrapper_prefix, jobid, bench_iteration, " - "cleanup_scripts, shadow_dir" + "cleanup_scripts, shadow_dir, runtime_sourcecache_path" ) @@ -645,7 +645,7 @@ class CWL(Script): def args(self): yield ( ", basedir, input, output, params, wildcards, threads, resources, log, " - "config, rule, use_singularity, bench_record, jobid" + "config, rule, use_singularity, 
bench_record, jobid, runtime_sourcecache_path" ) @@ -1157,7 +1157,7 @@ def python(self, token): class Snakefile: def __init__(self, path, workflow, rulecount=0): - self.path = path + self.path = path.get_path_or_uri() self.file = workflow.sourcecache.open(path) self.tokens = tokenize.generate_tokens(self.file.readline) self.rulecount = rulecount diff --git a/snakemake/rules.py b/snakemake/rules.py index d922ef012..064ec5e38 100644 --- a/snakemake/rules.py +++ b/snakemake/rules.py @@ -570,7 +570,7 @@ def _set_inoutput_item(self, item, output=False, name=None): report_obj = item.flags["report"] if report_obj.caption is not None: r = ReportObject( - os.path.join(self.workflow.current_basedir, report_obj.caption), + self.workflow.current_basedir.join(report_obj.caption), report_obj.category, report_obj.subcategory, report_obj.patterns, diff --git a/snakemake/script.py b/snakemake/script.py index 99b318796..b63f6c5c2 100644 --- a/snakemake/script.py +++ b/snakemake/script.py @@ -6,6 +6,13 @@ import inspect import itertools import os +from snakemake import sourcecache +from snakemake.sourcecache import ( + LocalSourceFile, + SourceCache, + SourceFile, + infer_source_file, +) import tempfile import textwrap import sys @@ -322,6 +329,7 @@ def __init__( bench_iteration, cleanup_scripts, shadow_dir, + is_local, ): self.path = path self.source = source @@ -346,6 +354,7 @@ def __init__( self.bench_iteration = bench_iteration self.cleanup_scripts = cleanup_scripts self.shadow_dir = shadow_dir + self.is_local = is_local def evaluate(self, edit=False): assert not edit or self.editable @@ -360,7 +369,7 @@ def evaluate(self, edit=False): os.makedirs(dir_, exist_ok=True) with tempfile.NamedTemporaryFile( - suffix="." + os.path.basename(self.path), dir=dir_, delete=False + suffix="." + self.path.get_filename(), dir=dir_, delete=False ) as fd: self.write_script(preamble, fd) @@ -437,9 +446,9 @@ def generate_preamble( bench_iteration, cleanup_scripts, shadow_dir, + is_local, preamble_addendum="", ): - wrapper_path = path[7:] if path.startswith("file://") else path snakemake = Snakemake( input_, output, @@ -451,7 +460,7 @@ def generate_preamble( config, rulename, bench_iteration, - os.path.dirname(wrapper_path), + path.get_basedir().get_path_or_uri(), ) snakemake = pickle.dumps(snakemake) # Obtain search path for current snakemake module. 
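# A hedged sketch of the behaviour the hunk above relies on: script and wrapper paths are now
# SourceFile objects, so the script directory handed to the Snakemake object comes from
# get_basedir() instead of slicing off "file://" prefixes by hand. The concrete paths and the
# URL below are purely illustrative.
from snakemake.sourcecache import infer_source_file

local = infer_source_file("file:///home/user/workflow/scripts/plot.py")
print(local.get_basedir().get_path_or_uri())
# -> /home/user/workflow/scripts  (the file:// prefix is stripped, a LocalSourceFile is returned)

remote = infer_source_file("https://example.org/workflow/scripts/plot.py")
print(remote.get_basedir().get_path_or_uri())
# -> https://example.org/workflow/scripts  (same call, no special-casing of URLs needed)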
@@ -462,8 +471,8 @@ def generate_preamble( searchpath = singularity.SNAKEMAKE_MOUNTPOINT searchpath = repr(searchpath) # For local scripts, add their location to the path in case they use path-based imports - if path.startswith("file://"): - searchpath += ", " + repr(os.path.dirname(path[7:])) + if is_local: + searchpath += ", " + repr(path.get_basedir().get_path_or_uri()) return textwrap.dedent( """ @@ -479,10 +488,14 @@ def generate_preamble( ) def get_preamble(self): - wrapper_path = self.path[7:] if self.path.startswith("file://") else self.path + + if isinstance(self.path, LocalSourceFile): + file_override = os.path.realpath(self.path.get_path_or_uri()) + else: + file_override = self.path.get_path_or_uri() preamble_addendum = ( "__real_file__ = __file__; __file__ = {file_override};".format( - file_override=repr(os.path.realpath(wrapper_path)) + file_override=repr(file_override) ) ) @@ -508,12 +521,13 @@ def get_preamble(self): self.bench_iteration, self.cleanup_scripts, self.shadow_dir, + self.is_local, preamble_addendum=preamble_addendum, ) def write_script(self, preamble, fd): fd.write(preamble.encode()) - fd.write(self.source) + fd.write(self.source.encode()) def _is_python_env(self): if self.conda_env is not None: @@ -651,11 +665,7 @@ def generate_preamble( REncoder.encode_dict(config), REncoder.encode_value(rulename), REncoder.encode_numeric(bench_iteration), - REncoder.encode_value( - os.path.dirname(path[7:]) - if path.startswith("file://") - else os.path.dirname(path) - ), + REncoder.encode_value(path.get_basedir().get_path_or_uri()), preamble_addendum=preamble_addendum, ) @@ -686,7 +696,7 @@ def get_preamble(self): def write_script(self, preamble, fd): fd.write(preamble.encode()) - fd.write(self.source) + fd.write(self.source.encode()) def execute_script(self, fname, edit=False): if self.conda_env is not None and "R_LIBS" in os.environ: @@ -764,16 +774,12 @@ def get_preamble(self): REncoder.encode_dict(self.config), REncoder.encode_value(self.rulename), REncoder.encode_numeric(self.bench_iteration), - REncoder.encode_value( - os.path.dirname(self.path[7:]) - if self.path.startswith("file://") - else os.path.dirname(self.path) - ), + REncoder.encode_value(self.path.get_basedir().get_path_or_uri()), ) def write_script(self, preamble, fd): # Insert Snakemake object after the RMarkdown header - code = self.source.decode() + code = self.source pos = next(itertools.islice(re.finditer(r"---\n", code), 1, 2)).start() + 3 fd.write(str.encode(code[:pos])) preamble = textwrap.dedent( @@ -785,7 +791,7 @@ def write_script(self, preamble, fd): % preamble ) fd.write(preamble.encode()) - fd.write(str.encode(code[pos:])) + fd.write(code[pos:].encode()) def execute_script(self, fname, edit=False): if len(self.output) != 1: @@ -852,11 +858,7 @@ def get_preamble(self): JuliaEncoder.encode_dict(self.config), JuliaEncoder.encode_value(self.rulename), JuliaEncoder.encode_value(self.bench_iteration), - JuliaEncoder.encode_value( - os.path.dirname(self.path[7:]) - if self.path.startswith("file://") - else os.path.dirname(self.path) - ), + JuliaEncoder.encode_value(self.path.get_basedir().get_path_or_uri()), ).replace( "'", '"' ) @@ -864,7 +866,7 @@ def get_preamble(self): def write_script(self, preamble, fd): fd.write(preamble.encode()) - fd.write(self.source) + fd.write(self.source.encode()) def execute_script(self, fname, edit=False): self._execute_cmd("julia {fname:q}", fname=fname) @@ -894,10 +896,9 @@ def generate_preamble( bench_iteration, cleanup_scripts, shadow_dir, + is_local, 
preamble_addendum="", ): - wrapper_path = path[7:] if path.startswith("file://") else path - # snakemake's namedlists will be encoded as a dict # which stores the not-named items at the key "positional" # and unpacks named items into the dict @@ -927,7 +928,7 @@ def encode_namedlist(values): config=encode_namedlist(config.items()), rulename=rulename, bench_iteration=bench_iteration, - scriptdir=os.path.dirname(wrapper_path), + scriptdir=path.get_basedir().get_path_or_uri(), ) import json @@ -941,8 +942,8 @@ def encode_namedlist(values): searchpath = singularity.SNAKEMAKE_MOUNTPOINT searchpath = repr(searchpath) # For local scripts, add their location to the path in case they use path-based imports - if path.startswith("file://"): - searchpath += ", " + repr(os.path.dirname(path[7:])) + if is_local: + searchpath += ", " + repr(path.get_basedir().get_path_or_uri()) return textwrap.dedent( """ @@ -1073,12 +1074,6 @@ def encode_namedlist(values): ) def get_preamble(self): - wrapper_path = self.path[7:] if self.path.startswith("file://") else self.path - # preamble_addendum = ( - # "__real_file__ = __file__; __file__ = {file_override};".format( - # file_override=repr(os.path.realpath(wrapper_path)) - # ) - # ) preamble_addendum = "" preamble = RustScript.generate_preamble( @@ -1103,6 +1098,7 @@ def get_preamble(self): self.bench_iteration, self.cleanup_scripts, self.shadow_dir, + self.is_local, preamble_addendum=preamble_addendum, ) return preamble @@ -1126,7 +1122,7 @@ def combine_preamble_and_source(self, preamble: str) -> str: Also, because rust-scipt relies on inner docs, there can't be an empty line between the manifest and preamble. """ - manifest, src = RustScript.extract_manifest(self.source.decode()) + manifest, src = RustScript.extract_manifest(self.source) return manifest + preamble.lstrip("\r\n") + src @staticmethod @@ -1235,53 +1231,50 @@ def strip_re(regex: Pattern, s: str) -> Tuple[str, str]: return head, tail -def get_source(path, basedir=".", wildcards=None, params=None): - source = None - if not path.startswith("http") and not path.startswith("git+file"): - if path.startswith("file://"): - path = path[7:] - elif path.startswith("file:"): - path = path[5:] - if not os.path.isabs(path): - path = smart_join(basedir, path, abspath=True) - if is_local_file(path): - path = "file://" + path +def get_source( + path, + sourcecache: sourcecache.SourceCache, + basedir=None, + wildcards=None, + params=None, +): if wildcards is not None and params is not None: + if isinstance(path, SourceFile): + path = path.get_path_or_uri() # Format path if wildcards are given. 
- path = format(path, wildcards=wildcards, params=params) - if path.startswith("file://"): - sourceurl = "file:" + pathname2url(path[7:]) - elif path.startswith("git+file"): - source = git_content(path).encode() - (root_path, file_path, version) = split_git_path(path) - path = path.rstrip("@" + version) - else: - sourceurl = path + path = infer_source_file(format(path, wildcards=wildcards, params=params)) + + if basedir is not None: + basedir = infer_source_file(basedir) - if source is None: - with urlopen(sourceurl) as source: - source = source.read() + source_file = infer_source_file(path, basedir) + with sourcecache.open(source_file) as f: + source = f.read() - language = get_language(path, source) + language = get_language(source_file, source) - return path, source, language + is_local = isinstance(source_file, LocalSourceFile) + return path, source, language, is_local -def get_language(path, source): + +def get_language(source_file, source): import nbformat + filename = source_file.get_filename() + language = None - if path.endswith(".py"): + if filename.endswith(".py"): language = "python" - elif path.endswith(".ipynb"): + elif filename.endswith(".ipynb"): language = "jupyter" - elif path.endswith(".R"): + elif filename.endswith(".R"): language = "r" - elif path.endswith(".Rmd"): + elif filename.endswith(".Rmd"): language = "rmarkdown" - elif path.endswith(".jl"): + elif filename.endswith(".jl"): language = "julia" - elif path.endswith(".rs"): + elif filename.endswith(".rs"): language = "rust" # detect kernel language for Jupyter Notebooks @@ -1322,11 +1315,15 @@ def script( bench_iteration, cleanup_scripts, shadow_dir, + runtime_sourcecache_path, ): """ Load a script from the given basedir + path and execute it. """ - path, source, language = get_source(path, basedir, wildcards, params) + + path, source, language, is_local = get_source( + path, SourceCache(runtime_sourcecache_path), basedir, wildcards, params + ) exec_class = { "python": PythonScript, @@ -1363,5 +1360,6 @@ def script( bench_iteration, cleanup_scripts, shadow_dir, + is_local, ) executor.evaluate() diff --git a/snakemake/sourcecache.py b/snakemake/sourcecache.py index bc2de4dab..f5fd4658e 100644 --- a/snakemake/sourcecache.py +++ b/snakemake/sourcecache.py @@ -7,15 +7,255 @@ from pathlib import Path import re import os +from snakemake import utils import tempfile import io +from abc import ABC, abstractmethod -from snakemake.common import is_local_file, get_appdirs -from snakemake.exceptions import WorkflowError -from snakemake.io import git_content +from snakemake.common import is_local_file, get_appdirs, parse_uri, smart_join +from snakemake.exceptions import WorkflowError, SourceFileError +from snakemake.io import git_content, split_git_path +from snakemake.logging import logger -# TODO also use sourcecache for script and wrapper code! + +def _check_git_args(tag: str = None, branch: str = None, commit: str = None): + n_refs = sum(1 for ref in (tag, branch, commit) if ref is not None) + if n_refs != 1: + raise SourceFileError( + "exactly one of tag, branch, or commit must be specified." + ) + + +class SourceFile(ABC): + @abstractmethod + def get_path_or_uri(self): + ... + + @abstractmethod + def is_persistently_cacheable(self): + ... 
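    # A hedged usage sketch of the reworked get_source() shown earlier in this file's diff
    # (the script path, basedir and variable names below are illustrative, not fixed API):
    #
    #     path, source, language, is_local = get_source(
    #         "scripts/plot.py",                       # plain path, URL, or github()/gitlab() marker
    #         SourceCache(runtime_sourcecache_path),   # cache rooted at the shared runtime cache path
    #         basedir=current_basedir,
    #         wildcards=wildcards,
    #         params=params,
    #     )
    #
    # Callers such as script() and notebook() then branch on is_local rather than
    # checking for "file://" prefixes themselves.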
+ + def get_uri_hash(self): + urihash = hashlib.sha256() + urihash.update(self.get_path_or_uri().encode()) + return urihash.hexdigest() + + def get_basedir(self): + path = os.path.dirname(self.get_path_or_uri()) + return self.__class__(path) + + @abstractmethod + def get_filename(self): + ... + + def join(self, path): + if isinstance(path, SourceFile): + path = path.get_path_or_uri() + return self.__class__(smart_join(self.get_path_or_uri(), path)) + + def __hash__(self): + return self.get_path_or_uri().__hash__() + + def __eq__(self, other): + if isinstance(other, SourceFile): + return self.get_path_or_uri() == other.get_path_or_uri() + return False + + def __str__(self): + return self.get_path_or_uri() + + def simplify_path(self): + return self + + +class GenericSourceFile(SourceFile): + def __init__(self, path_or_uri): + self.path_or_uri = path_or_uri + + def get_path_or_uri(self): + return self.path_or_uri + + def get_filename(self): + return os.path.basename(self.path_or_uri) + + def is_persistently_cacheable(self): + return False + + +class LocalSourceFile(SourceFile): + def __init__(self, path): + self.path = path + + def get_path_or_uri(self): + return self.path + + def is_persistently_cacheable(self): + return False + + def get_filename(self): + return os.path.basename(self.path) + + def abspath(self): + return LocalSourceFile(os.path.abspath(self.path)) + + def isabs(self): + return os.path.isabs(self.path) + + def simplify_path(self): + return utils.simplify_path(self.path) + + +class LocalGitFile(SourceFile): + def __init__( + self, repo_path, path: str, tag: str = None, ref: str = None, commit: str = None + ): + _check_git_args(tag, ref, commit) + self.tag = tag + self.commit = commit + self._ref = ref + self.repo_path = repo_path + self.path = path + + def get_path_or_uri(self): + return "git+{}/{}@{}".format(self.repo_path, self.path, self.ref) + + def join(self, path): + return LocalGitFile( + self.repo_path, + "/".join((self.path, path)), + tag=self.tag, + ref=self.ref, + commit=self.commit, + ) + + def is_persistently_cacheable(self): + return False + + def get_filename(self): + return os.path.basename(self.path) + + @property + def ref(self): + return self.tag or self.commit or self._ref + + +class HostingProviderFile(SourceFile): + """Marker for denoting github source files from releases.""" + + valid_repo = re.compile("^.+/.+$") + + def __init__( + self, + repo: str, + path: str, + tag: str = None, + branch: str = None, + commit: str = None, + ): + if not self.__class__.valid_repo.match(repo): + raise SourceFileError( + "repo {} is not a valid repo specification (must be given as owner/name)." 
+ ) + + _check_git_args(tag, branch, commit) + + if path is None: + raise SourceFileError("path must be given") + + if not all( + isinstance(item, str) + for item in (repo, path, tag, branch, commit) + if item is not None + ): + raise SourceFileError("arguments must be given as str.") + + self.repo = repo + self.tag = tag + self.commit = commit + self.branch = branch + self.path = path.strip("/") + + def is_persistently_cacheable(self): + return self.tag or self.commit + + def get_filename(self): + return os.path.basename(self.path) + + @property + def ref(self): + return self.tag or self.commit or self.branch + + def get_basedir(self): + return self.__class__( + repo=self.repo, + path=os.path.dirname(self.path), + tag=self.tag, + commit=self.commit, + branch=self.branch, + ) + + def join(self, path): + path = os.path.normpath("{}/{}".format(self.path, path)) + return self.__class__( + repo=self.repo, + path=path, + tag=self.tag, + commit=self.commit, + branch=self.branch, + ) + + +class GithubFile(HostingProviderFile): + def get_path_or_uri(self): + return "https://github.com/{}/raw/{}/{}".format(self.repo, self.ref, self.path) + + +class GitlabFile(HostingProviderFile): + def __init__( + self, + repo: str, + path: str, + tag: str = None, + branch: str = None, + commit: str = None, + host: str = None, + ): + super().__init__(repo, path, tag, branch, commit) + self.host = host + + def get_path_or_uri(self): + return "https://{}/{}/-/raw/{}/{}".format( + self.host or "gitlab.com", self.repo, self.ref, self.path + ) + + +def infer_source_file(path_or_uri, basedir: SourceFile = None): + if isinstance(path_or_uri, SourceFile): + if basedir is None: + return path_or_uri + else: + path_or_uri = path_or_uri.get_path_or_uri() + if isinstance(path_or_uri, Path): + path_or_uri = str(path_or_uri) + if not isinstance(path_or_uri, str): + raise SourceFileError( + "must be given as Python string or one of the predefined source file marker types (see docs)" + ) + if is_local_file(path_or_uri): + # either local file or relative to some remote basedir + for schema in ("file://", "file:"): + if path_or_uri.startswith(schema): + path_or_uri = path_or_uri[len(schema) :] + break + if not os.path.isabs(path_or_uri) and basedir is not None: + return basedir.join(path_or_uri) + return LocalSourceFile(path_or_uri) + if path_or_uri.startswith("git+file:"): + root_path, file_path, ref = split_git_path(path_or_uri) + return LocalGitFile(root_path, file_path, ref=ref) + # something else + return GenericSourceFile(path_or_uri) class SourceCache: @@ -23,63 +263,84 @@ class SourceCache: "https://raw.githubusercontent.com/snakemake/snakemake-wrappers/\d+\.\d+.\d+" ] # TODO add more prefixes for uris that are save to be cached - def __init__(self): + def __init__(self, runtime_cache_path=None): self.cache = Path( os.path.join(get_appdirs().user_cache_dir, "snakemake/source-cache") ) - self.runtime_cache = tempfile.TemporaryDirectory( - suffix="snakemake-runtime-source-cache" - ) + os.makedirs(self.cache, exist_ok=True) + if runtime_cache_path is None: + self.runtime_cache = tempfile.TemporaryDirectory( + suffix="snakemake-runtime-source-cache" + ) + self._runtime_cache_path = None + else: + self._runtime_cache_path = runtime_cache_path + self.runtime_cache = None self.cacheable_prefixes = re.compile("|".join(self.cache_whitelist)) + @property + def runtime_cache_path(self): + return self._runtime_cache_path or self.runtime_cache.name + def lock_cache(self, entry): from filelock import FileLock return 
FileLock(entry.with_suffix(".lock")) - def is_persistently_cacheable(self, path_or_uri): - # TODO remove special git url handling once included in smart_open - if path_or_uri.startswith("git+file:"): - return False - return is_local_file(path_or_uri) and self.cacheable_prefixes.match(path_or_uri) - - def open(self, path_or_uri, mode="r"): - cache_entry = self._cache(path_or_uri) + def open(self, source_file, mode="r"): + cache_entry = self._cache(source_file) return self._open(cache_entry, mode) - def get_path(self, path_or_uri, mode="r"): - cache_entry = self._cache(path_or_uri) + def exists(self, source_file): + try: + self._cache(source_file) + except Exception: + return False + return True + + def get_path(self, source_file, mode="r"): + cache_entry = self._cache(source_file) return cache_entry - def _cache_entry(self, path_or_uri): - urihash = hashlib.sha256() - urihash.update(path_or_uri.encode()) - urihash = urihash.hexdigest() + def _cache_entry(self, source_file): + urihash = source_file.get_uri_hash() # TODO add git support to smart_open! - if self.is_persistently_cacheable(path_or_uri): + if source_file.is_persistently_cacheable(): # check cache return self.cache / urihash else: # check runtime cache - return Path(self.runtime_cache.name) / urihash + return Path(self.runtime_cache_path) / urihash - def _cache(self, path_or_uri): - cache_entry = self._cache_entry(path_or_uri) + def _cache(self, source_file): + cache_entry = self._cache_entry(source_file) with self.lock_cache(cache_entry): if not cache_entry.exists(): - # open from origin - with self._open(path_or_uri, "rb") as source, open( - cache_entry, "wb" - ) as cache_source: - cache_source.write(source.read()) + self._do_cache(source_file, cache_entry) return cache_entry + def _do_cache(self, source_file, cache_entry): + # open from origin + with self._open(source_file.get_path_or_uri(), "rb") as source, open( + cache_entry, "wb" + ) as cache_source: + cache_source.write(source.read()) + def _open(self, path_or_uri, mode): from smart_open import open - if str(path_or_uri).startswith("git+file:"): - return io.BytesIO(git_content(path_or_uri).encode()) + if isinstance(path_or_uri, LocalGitFile): + import git + + return io.BytesIO( + git.Repo(path_or_uri.repo_path) + .git.show("{}:{}".format(path_or_uri.ref, path_or_uri.path)) + .encode() + ) + + if isinstance(path_or_uri, SourceFile): + path_or_uri = path_or_uri.get_path_or_uri() try: return open(path_or_uri, mode) diff --git a/snakemake/utils.py b/snakemake/utils.py index d2e3ed381..965dd1f53 100644 --- a/snakemake/utils.py +++ b/snakemake/utils.py @@ -7,6 +7,7 @@ import json import re import inspect +from snakemake.sourcecache import LocalSourceFile, infer_source_file import textwrap import platform from itertools import chain @@ -55,19 +56,21 @@ def validate(data, schema, set_default=True): "in order to use the validate directive." 
) - schemafile = schema + schemafile = infer_source_file(schema) - if not os.path.isabs(schemafile): - frame = inspect.currentframe().f_back + if isinstance(schemafile, LocalSourceFile) and not schemafile.isabs() and workflow: # if workflow object is not available this has not been started from a workflow - if workflow: - schemafile = smart_join(workflow.current_basedir, schemafile) + schemafile = workflow.current_basedir.join(schemafile) - source = workflow.sourcecache.open(schemafile) if workflow else schemafile + source = ( + workflow.sourcecache.open(schemafile) + if workflow + else schemafile.get_path_or_uri() + ) schema = _load_configfile(source, filetype="Schema") - if is_local_file(schemafile): + if isinstance(schemafile, LocalSourceFile): resolver = RefResolver( - urljoin("file:", schemafile), + urljoin("file:", schemafile.get_path_or_uri()), schema, handlers={ "file": lambda uri: _load_configfile(re.sub("^file://", "", uri)) @@ -75,7 +78,7 @@ def validate(data, schema, set_default=True): ) else: resolver = RefResolver( - schemafile, + schemafile.get_path_or_uri(), schema, ) diff --git a/snakemake/workflow.py b/snakemake/workflow.py index afe4a5519..6c5abf565 100644 --- a/snakemake/workflow.py +++ b/snakemake/workflow.py @@ -81,8 +81,15 @@ from snakemake.caching.remote import OutputFileCache as RemoteOutputFileCache from snakemake.modules import ModuleInfo, WorkflowModifier, get_name_modifier_func from snakemake.ruleinfo import RuleInfo -from snakemake.sourcecache import SourceCache +from snakemake.sourcecache import ( + GenericSourceFile, + LocalSourceFile, + SourceCache, + SourceFile, + infer_source_file, +) from snakemake.deployment.conda import Conda +from snakemake import sourcecache class Workflow: @@ -226,6 +233,8 @@ def __init__( _globals["checkpoints"] = Checkpoints() _globals["scatter"] = Scatter() _globals["gather"] = Gather() + _globals["github"] = sourcecache.GithubFile + _globals["gitlab"] = sourcecache.GitlabFile self.vanilla_globals = dict(_globals) self.modifier_stack = [WorkflowModifier(self, globals=_globals)] @@ -310,10 +319,10 @@ def get_sources(self): files = set() def local_path(f): - if is_local_file(f): - return parse_uri(f).uri_path - else: - return None + if not isinstance(f, SourceFile) and is_local_file(f): + return f + if isinstance(f, LocalSourceFile): + return f.get_path_or_uri() def norm_rule_relpath(f, rule): if not os.path.isabs(f): @@ -1088,9 +1097,10 @@ def files(items): def current_basedir(self): """Basedir of currently parsed Snakefile.""" assert self.included_stack - basedir = os.path.dirname(self.included_stack[-1]) - if is_local_file(basedir): - return os.path.abspath(basedir) + snakefile = self.included_stack[-1] + basedir = snakefile.get_basedir() + if isinstance(basedir, LocalSourceFile): + return basedir.abspath() else: return basedir @@ -1103,7 +1113,7 @@ def source_path(self, rel_path): calling_file = frame.f_code.co_filename calling_dir = os.path.dirname(calling_file) path = smart_join(calling_dir, rel_path) - return self.sourcecache.get_path(path) + return self.sourcecache.get_path(infer_source_file(path)) @property def snakefile(self): @@ -1141,16 +1151,8 @@ def include( """ Include a snakefile. 
""" - if isinstance(snakefile, Path): - snakefile = str(snakefile) - - # check if snakefile is a path to the filesystem - if is_local_file(snakefile): - if not os.path.isabs(snakefile) and self.included_stack: - snakefile = smart_join(self.current_basedir, snakefile) - # Could still be a url if relative import was used - if is_local_file(snakefile): - snakefile = os.path.abspath(snakefile) + basedir = self.current_basedir if self.included_stack else None + snakefile = infer_source_file(snakefile, basedir) if not self.modifier.allow_rule_overwrite and snakefile in self.included: logger.info("Multiple includes of {} ignored".format(snakefile)) @@ -1170,13 +1172,14 @@ def include( if print_compilation: print(code) - # insert the current directory into sys.path - # this allows to import modules from the workflow directory - sys.path.insert(0, os.path.dirname(snakefile)) + if isinstance(snakefile, LocalSourceFile): + # insert the current directory into sys.path + # this allows to import modules from the workflow directory + sys.path.insert(0, snakefile.get_basedir().get_path_or_uri()) self.linemaps[snakefile] = linemap - exec(compile(code, snakefile, "exec"), self.globals) + exec(compile(code, snakefile.get_path_or_uri(), "exec"), self.globals) if not overwrite_first_rule: self.first_rule = first_rule @@ -1258,14 +1261,14 @@ def pepschema(self, schema): if is_local_file(schema) and not os.path.isabs(schema): # schema is relative to current Snakefile - schema = os.path.join(self.current_basedir, schema) + schema = self.current_basedir.join(schema).get_path_or_uri() if self.pepfile is None: raise WorkflowError("Please specify a PEP with the pepfile directive.") eido.validate_project(project=pep, schema=schema, exclude_case=True) def report(self, path): """Define a global report description in .rst format.""" - self.report_text = os.path.join(self.current_basedir, path) + self.report_text = self.current_basedir.join(path) @property def config(self): @@ -1458,12 +1461,15 @@ def decorate(ruleinfo): "(not with run).", rule=rule, ) - if is_local_file(ruleinfo.conda_env) and not os.path.isabs( - ruleinfo.conda_env + + if ( + ruleinfo.conda_env is not None + and is_local_file(ruleinfo.conda_env) + and not os.path.isabs(ruleinfo.conda_env) ): - ruleinfo.conda_env = os.path.join( - self.current_basedir, ruleinfo.conda_env - ) + ruleinfo.conda_env = self.current_basedir.join( + ruleinfo.conda_env + ).get_path_or_uri() rule.conda_env = ruleinfo.conda_env invalid_rule = not ( @@ -1869,4 +1875,4 @@ def srcdir(path): """Return the absolute path, relative to the source directory of the current Snakefile.""" if not workflow.included_stack: return None - return os.path.join(os.path.dirname(workflow.included_stack[-1]), path) + return workflow.current_basedir.join(path).get_path_or_uri() diff --git a/snakemake/wrapper.py b/snakemake/wrapper.py index f530df6a9..91b1a04fd 100644 --- a/snakemake/wrapper.py +++ b/snakemake/wrapper.py @@ -12,6 +12,7 @@ from snakemake.exceptions import WorkflowError from snakemake.script import script +from snakemake.sourcecache import SourceCache, infer_source_file PREFIX = "https://github.com/snakemake/snakemake-wrappers/raw/" @@ -52,35 +53,24 @@ def is_git_path(path): return path.startswith("git+file:") -def find_extension(path, extensions=[".py", ".R", ".Rmd", ".jl"]): +def find_extension( + path, sourcecache: SourceCache, extensions=[".py", ".R", ".Rmd", ".jl"] +): for ext in extensions: if path.endswith("wrapper{}".format(ext)): return path + + path = infer_source_file(path) for 
ext in extensions: - script = "/wrapper{}".format(ext) - if is_local(path): - if path.startswith("file://"): - p = path[7:] - elif path.startswith("file:"): - p = path[5:] - if os.path.exists(p + script): - return path + script - else: - try: - urlopen(path + script) - return path + script - except URLError: - continue - if is_git_path(path): - path, version = path.split("@") - return os.path.join(path, "wrapper.py") + "@" + version - else: - return path + "/wrapper.py" # default case + script = path.join("wrapper{}".format(ext)) + + if sourcecache.exists(script): + return script -def get_script(path, prefix=None): +def get_script(path, sourcecache: SourceCache, prefix=None): path = get_path(path, prefix=prefix) - return find_extension(path) + return find_extension(path, sourcecache) def get_conda_env(path, prefix=None): @@ -116,12 +106,15 @@ def wrapper( bench_iteration, cleanup_scripts, shadow_dir, + runtime_sourcecache_path, ): """ Load a wrapper from https://github.com/snakemake/snakemake-wrappers under the given path + wrapper.(py|R|Rmd) and execute it. """ - path = get_script(path, prefix=prefix) + path = get_script( + path, SourceCache(runtime_cache_path=runtime_sourcecache_path), prefix=prefix + ) script( path, "", @@ -144,4 +137,5 @@ def wrapper( bench_iteration, cleanup_scripts, shadow_dir, + runtime_sourcecache_path, ) diff --git a/test-environment.yml b/test-environment.yml index 7dac49ea4..0474ee941 100644 --- a/test-environment.yml +++ b/test-environment.yml @@ -61,4 +61,4 @@ dependencies: - glpk - smart_open - filelock - - tabulate \ No newline at end of file + - tabulate diff --git a/tests/test_module_with_script/Snakefile b/tests/test_module_with_script/Snakefile index 988e47eaf..9b9f7feeb 100644 --- a/tests/test_module_with_script/Snakefile +++ b/tests/test_module_with_script/Snakefile @@ -4,11 +4,9 @@ min_version("6.1.1") configfile: "config.yaml" -remote_snakefile = "https://raw.githubusercontent.com/snakemake/snakemake/v6.1.1/tests/test_script_py/Snakefile" - module with_script: snakefile: - remote_snakefile + github("snakemake/snakemake", tag="v6.1.1", path="tests/test_script_py/Snakefile") config: config
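As a rough end-to-end sketch of how the new pieces fit together (the repository, path and tag simply mirror the test module above; everything else uses the classes introduced in snakemake/sourcecache.py, and the first run needs network access):

.. code-block:: python

    from snakemake.sourcecache import GithubFile, SourceCache

    f = GithubFile(
        "snakemake/snakemake", path="tests/test_script_py/Snakefile", tag="v6.1.1"
    )
    print(f.get_path_or_uri())
    # https://github.com/snakemake/snakemake/raw/v6.1.1/tests/test_script_py/Snakefile
    print(bool(f.is_persistently_cacheable()))  # True, because a tag is pinned

    cache = SourceCache()
    with cache.open(f) as fh:  # first call downloads, later calls reuse the persistent cache entry
        snakefile_text = fh.read()

A gitlab() marker behaves the same way, with an optional host argument for self-hosted instances; sources pinned only to a branch fall back to the per-run cache instead of the persistent one.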