Skip to content

Commit

Permalink
fix: improve robustness when retrieving remote source files, fixed usage of local git repos as wrapper prefixes (in collaboration with @cokelaer and @Smeds) (#1495)
Browse files Browse the repository at this point in the history

* fix: fixed usage of local git repos as wrapper prefixes

* implement retry mechanism for non-local source file access

* logging fixes

* avoid retry import
  • Loading branch information
johanneskoester committed Mar 18, 2022
1 parent 5cf275a commit e16531d
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 52 deletions.
22 changes: 15 additions & 7 deletions snakemake/deployment/conda.py
Expand Up @@ -6,7 +6,12 @@
import os
from pathlib import Path
import re
from snakemake.sourcecache import LocalGitFile, LocalSourceFile, infer_source_file
from snakemake.sourcecache import (
LocalGitFile,
LocalSourceFile,
SourceFile,
infer_source_file,
)
import subprocess
import tempfile
from urllib.request import urlopen
Expand Down Expand Up @@ -498,9 +503,7 @@ def create(self, dryrun=False):

logger.debug(out)
logger.info(
"Environment for {} created (location: {})".format(
os.path.relpath(env_file), os.path.relpath(env_path)
)
f"Environment for {self.file.get_path_or_uri()} created (location: {os.path.relpath(env_path)})"
)
except subprocess.CalledProcessError as e:
# remove potential partially installed environment
Expand Down Expand Up @@ -706,8 +709,10 @@ def __eq__(self, other):


class CondaEnvFileSpec(CondaEnvSpec):
def __init__(self, filepath: str, rule=None):
if isinstance(filepath, _IOFile):
def __init__(self, filepath, rule=None):
if isinstance(filepath, SourceFile):
self.file = IOFile(str(filepath.get_path_or_uri()), rule=rule)
elif isinstance(filepath, _IOFile):
self.file = filepath
else:
self.file = IOFile(filepath, rule=rule)
Expand Down Expand Up @@ -777,5 +782,8 @@ def __eq__(self, other):
return self.name == other.name


def is_conda_env_file(spec: str):
def is_conda_env_file(spec):
if isinstance(spec, SourceFile):
spec = spec.get_filename()

return spec.endswith(".yaml") or spec.endswith(".yml")
2 changes: 1 addition & 1 deletion snakemake/logging.py
Expand Up @@ -349,7 +349,7 @@ def set_level(self, level):
def logfile_hint(self):
if self.mode == Mode.default:
logfile = self.get_logfile()
self.info("Complete log: {}".format(logfile))
self.info("Complete log: {}".format(os.path.relpath(logfile)))

def location(self, msg):
callerframerecord = inspect.stack()[1]
Expand Down
4 changes: 3 additions & 1 deletion snakemake/shell.py
Expand Up @@ -205,7 +205,9 @@ def __new__(
)
logger.info("Activating singularity image {}".format(container_img))
if conda_env:
logger.info("Activating conda environment: {}".format(conda_env))
logger.info(
"Activating conda environment: {}".format(os.path.relpath(conda_env))
)

tmpdir_resource = resources.get("tmpdir", None)
# environment variable lists for linear algebra libraries taken from:
Expand Down
73 changes: 61 additions & 12 deletions snakemake/sourcecache.py
Expand Up @@ -5,6 +5,7 @@

import hashlib
from pathlib import Path
import posixpath
import re
import os
import shutil
Expand All @@ -14,7 +15,6 @@
from abc import ABC, abstractmethod
from datetime import datetime


from snakemake.common import (
ON_WINDOWS,
is_local_file,
Expand Down Expand Up @@ -66,6 +66,11 @@ def mtime(self):
"""If possible, return mtime of the file. Otherwise, return None."""
return None

@property
@abstractmethod
def is_local(self):
...

def __hash__(self):
return self.get_path_or_uri().__hash__()

Expand Down Expand Up @@ -94,6 +99,10 @@ def get_filename(self):
def is_persistently_cacheable(self):
return False

@property
def is_local(self):
return False


class LocalSourceFile(SourceFile):
def __init__(self, path):
Expand Down Expand Up @@ -123,6 +132,10 @@ def mtime(self):
def __fspath__(self):
return self.path

@property
def is_local(self):
return True


class LocalGitFile(SourceFile):
def __init__(
Expand All @@ -136,7 +149,7 @@ def __init__(
self.path = path

def get_path_or_uri(self):
return "git+{}/{}@{}".format(self.repo_path, self.path, self.ref)
return "git+file://{}/{}@{}".format(self.repo_path, self.path, self.ref)

def join(self, path):
return LocalGitFile(
Expand All @@ -147,16 +160,29 @@ def join(self, path):
commit=self.commit,
)

def get_basedir(self):
return self.__class__(
repo_path=self.repo_path,
path=os.path.dirname(self.path),
tag=self.tag,
commit=self.commit,
ref=self.ref,
)

def is_persistently_cacheable(self):
return False

def get_filename(self):
return os.path.basename(self.path)
return posixpath.basename(self.path)

@property
def ref(self):
return self.tag or self.commit or self._ref

@property
def is_local(self):
return True


class HostingProviderFile(SourceFile):
"""Marker for denoting github source files from releases."""
Expand Down Expand Up @@ -229,6 +255,10 @@ def join(self, path):
branch=self.branch,
)

@property
def is_local(self):
return False


class GithubFile(HostingProviderFile):
def get_path_or_uri(self):
Expand Down Expand Up @@ -276,7 +306,12 @@ def infer_source_file(path_or_uri, basedir: SourceFile = None):
return basedir.join(path_or_uri)
return LocalSourceFile(path_or_uri)
if path_or_uri.startswith("git+file:"):
root_path, file_path, ref = split_git_path(path_or_uri)
try:
root_path, file_path, ref = split_git_path(path_or_uri)
except Exception as e:
raise WorkflowError(
f"Failed to read source {path_or_uri} from git repo.", e
)
return LocalGitFile(root_path, file_path, ref=ref)
# something else
return GenericSourceFile(path_or_uri)
Expand Down Expand Up @@ -311,7 +346,7 @@ def runtime_cache_path(self):

def open(self, source_file, mode="r"):
cache_entry = self._cache(source_file)
return self._open(cache_entry, mode)
return self._open(LocalSourceFile(cache_entry), mode)

def exists(self, source_file):
try:
Expand Down Expand Up @@ -343,7 +378,7 @@ def _cache(self, source_file):

def _do_cache(self, source_file, cache_entry):
# open from origin
with self._open(source_file.get_path_or_uri(), "rb") as source:
with self._open(source_file, "rb") as source:
tmp_source = tempfile.NamedTemporaryFile(
prefix=str(cache_entry),
delete=False, # no need to delete since we move it below
Expand All @@ -362,20 +397,34 @@ def _do_cache(self, source_file, cache_entry):
# as mtime.
os.utime(cache_entry, times=(mtime, mtime))

def _open(self, path_or_uri, mode):
def _open_local_or_remote(self, source_file, mode):
from retry import retry_call

if source_file.is_local:
return self._open(source_file, mode)
else:
return retry_call(
self._open,
[source_file, mode],
tries=3,
delay=3,
backoff=2,
logger=logger,
)

def _open(self, source_file, mode):
from smart_open import open

if isinstance(path_or_uri, LocalGitFile):
if isinstance(source_file, LocalGitFile):
import git

return io.BytesIO(
git.Repo(path_or_uri.repo_path)
.git.show("{}:{}".format(path_or_uri.ref, path_or_uri.path))
git.Repo(source_file.repo_path)
.git.show("{}:{}".format(source_file.ref, source_file.path))
.encode()
)

if isinstance(path_or_uri, SourceFile):
path_or_uri = path_or_uri.get_path_or_uri()
path_or_uri = source_file.get_path_or_uri()

try:
return open(path_or_uri, mode)
Expand Down
55 changes: 25 additions & 30 deletions snakemake/wrapper.py
Expand Up @@ -12,17 +12,20 @@

from snakemake.exceptions import WorkflowError
from snakemake.script import script
from snakemake.sourcecache import SourceCache, infer_source_file
from snakemake.sourcecache import LocalGitFile, SourceCache, infer_source_file


PREFIX = "https://github.com/snakemake/snakemake-wrappers/raw/"

EXTENSIONS = [".py", ".R", ".Rmd", ".jl"]

def is_script(path):

def is_script(source_file):
filename = source_file.get_filename()
return (
path.endswith("wrapper.py")
or path.endswith("wrapper.R")
or path.endswith("wrapper.jl")
filename.endswith("wrapper.py")
or filename.endswith("wrapper.R")
or filename.endswith("wrapper.jl")
)


Expand All @@ -34,7 +37,7 @@ def get_path(path, prefix=None):
parts = path.split("/")
path = "/" + "/".join(parts[1:]) + "@" + parts[0]
path = prefix + path
return path
return infer_source_file(path)


def is_url(path):
Expand All @@ -45,24 +48,13 @@ def is_url(path):
)


def is_local(path):
return path.startswith("file:")


def is_git_path(path):
return path.startswith("git+file:")


def find_extension(
path, sourcecache: SourceCache, extensions=[".py", ".R", ".Rmd", ".jl"]
):
for ext in extensions:
if path.endswith("wrapper{}".format(ext)):
return path
def find_extension(source_file, sourcecache: SourceCache):
for ext in EXTENSIONS:
if source_file.get_filename().endswith("wrapper{}".format(ext)):
return source_file

path = infer_source_file(path)
for ext in extensions:
script = path.join("wrapper{}".format(ext))
for ext in EXTENSIONS:
script = source_file.join("wrapper{}".format(ext))

if sourcecache.exists(script):
return script
Expand All @@ -77,11 +69,8 @@ def get_conda_env(path, prefix=None):
path = get_path(path, prefix=prefix)
if is_script(path):
# URLs and posixpaths share the same separator. Hence use posixpath here.
path = posixpath.dirname(path)
if is_git_path(path):
path, version = path.split("@")
return os.path.join(path, "environment.yaml") + "@" + version
return path + "/environment.yaml"
path = path.get_basedir()
return path.join("environment.yaml")


def wrapper(
Expand Down Expand Up @@ -112,11 +101,17 @@ def wrapper(
Load a wrapper from https://github.com/snakemake/snakemake-wrappers under
the given path + wrapper.(py|R|Rmd) and execute it.
"""
path = get_script(
assert path is not None
script_source = get_script(
path, SourceCache(runtime_cache_path=runtime_sourcecache_path), prefix=prefix
)
if script_source is None:
raise WorkflowError(
f"Unable to locate wrapper script for wrapper {path}. "
"This can be a network issue or a mistake in the wrapper URL."
)
script(
path,
script_source.get_path_or_uri(),
"",
input,
output,
Expand Down
20 changes: 19 additions & 1 deletion tests/tests.py
Expand Up @@ -526,11 +526,29 @@ def test_conda_cmd_exe():
run(dpath("test_conda_cmd_exe"), use_conda=True)


@skip_on_windows # Conda support is partly broken on Win
@skip_on_windows # wrappers are for linux and macos only
def test_wrapper():
run(dpath("test_wrapper"), use_conda=True)


@skip_on_windows # wrappers are for linux and macos only
def test_wrapper_local_git_prefix():
import git

with tempfile.TemporaryDirectory() as tmpdir:
print("Cloning wrapper repo...")
repo = git.Repo.clone_from(
"https://github.com/snakemake/snakemake-wrappers", tmpdir
)
print("Cloning complete.")

run(
dpath("test_wrapper"),
use_conda=True,
wrapper_prefix=f"git+file://{tmpdir}",
)


def test_get_log_none():
run(dpath("test_get_log_none"))

Expand Down

0 comments on commit e16531d

Please sign in to comment.