diff --git a/snakemake/deployment/conda.py b/snakemake/deployment/conda.py index a9c23e990..6f1a1b7a8 100644 --- a/snakemake/deployment/conda.py +++ b/snakemake/deployment/conda.py @@ -6,7 +6,12 @@ import os from pathlib import Path import re -from snakemake.sourcecache import LocalGitFile, LocalSourceFile, infer_source_file +from snakemake.sourcecache import ( + LocalGitFile, + LocalSourceFile, + SourceFile, + infer_source_file, +) import subprocess import tempfile from urllib.request import urlopen @@ -498,9 +503,7 @@ def create(self, dryrun=False): logger.debug(out) logger.info( - "Environment for {} created (location: {})".format( - os.path.relpath(env_file), os.path.relpath(env_path) - ) + f"Environment for {self.file.get_path_or_uri()} created (location: {os.path.relpath(env_path)})" ) except subprocess.CalledProcessError as e: # remove potential partially installed environment @@ -706,8 +709,10 @@ def __eq__(self, other): class CondaEnvFileSpec(CondaEnvSpec): - def __init__(self, filepath: str, rule=None): - if isinstance(filepath, _IOFile): + def __init__(self, filepath, rule=None): + if isinstance(filepath, SourceFile): + self.file = IOFile(str(filepath.get_path_or_uri()), rule=rule) + elif isinstance(filepath, _IOFile): self.file = filepath else: self.file = IOFile(filepath, rule=rule) @@ -777,5 +782,8 @@ def __eq__(self, other): return self.name == other.name -def is_conda_env_file(spec: str): +def is_conda_env_file(spec): + if isinstance(spec, SourceFile): + spec = spec.get_filename() + return spec.endswith(".yaml") or spec.endswith(".yml") diff --git a/snakemake/logging.py b/snakemake/logging.py index a32301568..45847f7fd 100644 --- a/snakemake/logging.py +++ b/snakemake/logging.py @@ -349,7 +349,7 @@ def set_level(self, level): def logfile_hint(self): if self.mode == Mode.default: logfile = self.get_logfile() - self.info("Complete log: {}".format(logfile)) + self.info("Complete log: {}".format(os.path.relpath(logfile))) def location(self, msg): callerframerecord = inspect.stack()[1] diff --git a/snakemake/shell.py b/snakemake/shell.py index d5d981216..8c1744394 100644 --- a/snakemake/shell.py +++ b/snakemake/shell.py @@ -205,7 +205,9 @@ def __new__( ) logger.info("Activating singularity image {}".format(container_img)) if conda_env: - logger.info("Activating conda environment: {}".format(conda_env)) + logger.info( + "Activating conda environment: {}".format(os.path.relpath(conda_env)) + ) tmpdir_resource = resources.get("tmpdir", None) # environment variable lists for linear algebra libraries taken from: diff --git a/snakemake/sourcecache.py b/snakemake/sourcecache.py index dc768726b..c61805fea 100644 --- a/snakemake/sourcecache.py +++ b/snakemake/sourcecache.py @@ -5,6 +5,7 @@ import hashlib from pathlib import Path +import posixpath import re import os import shutil @@ -14,7 +15,6 @@ from abc import ABC, abstractmethod from datetime import datetime - from snakemake.common import ( ON_WINDOWS, is_local_file, @@ -66,6 +66,11 @@ def mtime(self): """If possible, return mtime of the file. Otherwise, return None.""" return None + @property + @abstractmethod + def is_local(self): + ... + def __hash__(self): return self.get_path_or_uri().__hash__() @@ -94,6 +99,10 @@ def get_filename(self): def is_persistently_cacheable(self): return False + @property + def is_local(self): + return False + class LocalSourceFile(SourceFile): def __init__(self, path): @@ -123,6 +132,10 @@ def mtime(self): def __fspath__(self): return self.path + @property + def is_local(self): + return True + class LocalGitFile(SourceFile): def __init__( @@ -136,7 +149,7 @@ def __init__( self.path = path def get_path_or_uri(self): - return "git+{}/{}@{}".format(self.repo_path, self.path, self.ref) + return "git+file://{}/{}@{}".format(self.repo_path, self.path, self.ref) def join(self, path): return LocalGitFile( @@ -147,16 +160,29 @@ def join(self, path): commit=self.commit, ) + def get_basedir(self): + return self.__class__( + repo_path=self.repo_path, + path=os.path.dirname(self.path), + tag=self.tag, + commit=self.commit, + ref=self.ref, + ) + def is_persistently_cacheable(self): return False def get_filename(self): - return os.path.basename(self.path) + return posixpath.basename(self.path) @property def ref(self): return self.tag or self.commit or self._ref + @property + def is_local(self): + return True + class HostingProviderFile(SourceFile): """Marker for denoting github source files from releases.""" @@ -229,6 +255,10 @@ def join(self, path): branch=self.branch, ) + @property + def is_local(self): + return False + class GithubFile(HostingProviderFile): def get_path_or_uri(self): @@ -276,7 +306,12 @@ def infer_source_file(path_or_uri, basedir: SourceFile = None): return basedir.join(path_or_uri) return LocalSourceFile(path_or_uri) if path_or_uri.startswith("git+file:"): - root_path, file_path, ref = split_git_path(path_or_uri) + try: + root_path, file_path, ref = split_git_path(path_or_uri) + except Exception as e: + raise WorkflowError( + f"Failed to read source {path_or_uri} from git repo.", e + ) return LocalGitFile(root_path, file_path, ref=ref) # something else return GenericSourceFile(path_or_uri) @@ -311,7 +346,7 @@ def runtime_cache_path(self): def open(self, source_file, mode="r"): cache_entry = self._cache(source_file) - return self._open(cache_entry, mode) + return self._open(LocalSourceFile(cache_entry), mode) def exists(self, source_file): try: @@ -343,7 +378,7 @@ def _cache(self, source_file): def _do_cache(self, source_file, cache_entry): # open from origin - with self._open(source_file.get_path_or_uri(), "rb") as source: + with self._open(source_file, "rb") as source: tmp_source = tempfile.NamedTemporaryFile( prefix=str(cache_entry), delete=False, # no need to delete since we move it below @@ -362,20 +397,34 @@ def _do_cache(self, source_file, cache_entry): # as mtime. os.utime(cache_entry, times=(mtime, mtime)) - def _open(self, path_or_uri, mode): + def _open_local_or_remote(self, source_file, mode): + from retry import retry_call + + if source_file.is_local: + return self._open(source_file, mode) + else: + return retry_call( + self._open, + [source_file, mode], + tries=3, + delay=3, + backoff=2, + logger=logger, + ) + + def _open(self, source_file, mode): from smart_open import open - if isinstance(path_or_uri, LocalGitFile): + if isinstance(source_file, LocalGitFile): import git return io.BytesIO( - git.Repo(path_or_uri.repo_path) - .git.show("{}:{}".format(path_or_uri.ref, path_or_uri.path)) + git.Repo(source_file.repo_path) + .git.show("{}:{}".format(source_file.ref, source_file.path)) .encode() ) - if isinstance(path_or_uri, SourceFile): - path_or_uri = path_or_uri.get_path_or_uri() + path_or_uri = source_file.get_path_or_uri() try: return open(path_or_uri, mode) diff --git a/snakemake/wrapper.py b/snakemake/wrapper.py index f3c5ef9dd..c872d1de1 100644 --- a/snakemake/wrapper.py +++ b/snakemake/wrapper.py @@ -12,17 +12,20 @@ from snakemake.exceptions import WorkflowError from snakemake.script import script -from snakemake.sourcecache import SourceCache, infer_source_file +from snakemake.sourcecache import LocalGitFile, SourceCache, infer_source_file PREFIX = "https://github.com/snakemake/snakemake-wrappers/raw/" +EXTENSIONS = [".py", ".R", ".Rmd", ".jl"] -def is_script(path): + +def is_script(source_file): + filename = source_file.get_filename() return ( - path.endswith("wrapper.py") - or path.endswith("wrapper.R") - or path.endswith("wrapper.jl") + filename.endswith("wrapper.py") + or filename.endswith("wrapper.R") + or filename.endswith("wrapper.jl") ) @@ -34,7 +37,7 @@ def get_path(path, prefix=None): parts = path.split("/") path = "/" + "/".join(parts[1:]) + "@" + parts[0] path = prefix + path - return path + return infer_source_file(path) def is_url(path): @@ -45,24 +48,13 @@ def is_url(path): ) -def is_local(path): - return path.startswith("file:") - - -def is_git_path(path): - return path.startswith("git+file:") - - -def find_extension( - path, sourcecache: SourceCache, extensions=[".py", ".R", ".Rmd", ".jl"] -): - for ext in extensions: - if path.endswith("wrapper{}".format(ext)): - return path +def find_extension(source_file, sourcecache: SourceCache): + for ext in EXTENSIONS: + if source_file.get_filename().endswith("wrapper{}".format(ext)): + return source_file - path = infer_source_file(path) - for ext in extensions: - script = path.join("wrapper{}".format(ext)) + for ext in EXTENSIONS: + script = source_file.join("wrapper{}".format(ext)) if sourcecache.exists(script): return script @@ -77,11 +69,8 @@ def get_conda_env(path, prefix=None): path = get_path(path, prefix=prefix) if is_script(path): # URLs and posixpaths share the same separator. Hence use posixpath here. - path = posixpath.dirname(path) - if is_git_path(path): - path, version = path.split("@") - return os.path.join(path, "environment.yaml") + "@" + version - return path + "/environment.yaml" + path = path.get_basedir() + return path.join("environment.yaml") def wrapper( @@ -112,11 +101,17 @@ def wrapper( Load a wrapper from https://github.com/snakemake/snakemake-wrappers under the given path + wrapper.(py|R|Rmd) and execute it. """ - path = get_script( + assert path is not None + script_source = get_script( path, SourceCache(runtime_cache_path=runtime_sourcecache_path), prefix=prefix ) + if script_source is None: + raise WorkflowError( + f"Unable to locate wrapper script for wrapper {path}. " + "This can be a network issue or a mistake in the wrapper URL." + ) script( - path, + script_source.get_path_or_uri(), "", input, output, diff --git a/tests/tests.py b/tests/tests.py index 329162eb2..da4a7d2ef 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -526,11 +526,29 @@ def test_conda_cmd_exe(): run(dpath("test_conda_cmd_exe"), use_conda=True) -@skip_on_windows # Conda support is partly broken on Win +@skip_on_windows # wrappers are for linux and macos only def test_wrapper(): run(dpath("test_wrapper"), use_conda=True) +@skip_on_windows # wrappers are for linux and macos only +def test_wrapper_local_git_prefix(): + import git + + with tempfile.TemporaryDirectory() as tmpdir: + print("Cloning wrapper repo...") + repo = git.Repo.clone_from( + "https://github.com/snakemake/snakemake-wrappers", tmpdir + ) + print("Cloning complete.") + + run( + dpath("test_wrapper"), + use_conda=True, + wrapper_prefix=f"git+file://{tmpdir}", + ) + + def test_get_log_none(): run(dpath("test_get_log_none"))