fix: avoid incomplete remote files in case of errors and automatically retry download and upload (#1432)
johanneskoester committed Feb 26, 2022
1 parent 04f39a9 commit 8fc23ed
Showing 19 changed files with 157 additions and 108 deletions.
3 changes: 2 additions & 1 deletion setup.py
@@ -79,6 +79,7 @@
"tabulate",
"yte >=1.0,<2.0",
"jinja2 >=3.0,<4.0",
"retry",
],
extras_require={
"reports": ["jinja2", "networkx", "pygments", "pygraphviz"],
@@ -92,7 +93,7 @@
"pep": [
"peppy",
"eido",
]
],
},
classifiers=[
"Development Status :: 5 - Production/Stable",
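The new "retry" dependency is the PyPI retry package whose decorator is used in snakemake/remote/__init__.py further down to wrap downloads and uploads. The following is a minimal standalone sketch of how that decorator behaves with the parameters chosen in this commit; the flaky_download function and its counter are hypothetical, purely for illustration:

from retry import retry

attempts = {"count": 0}

# Same parameters as AbstractRemoteRetryObject below: up to 3 tries,
# 3 s initial delay, doubled after each failed attempt (3 s, then 6 s).
@retry(tries=3, delay=3, backoff=2)
def flaky_download():
    attempts["count"] += 1
    if attempts["count"] < 3:
        raise IOError("transient network error")
    return "ok"

print(flaky_download())  # fails twice, waits 3 s and 6 s, then succeeds on the third try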
25 changes: 13 additions & 12 deletions snakemake/dag.py
@@ -2247,18 +2247,19 @@ def get_outputs_with_changes(self, change_type):
changed.extend(list(job.outputs_older_than_script_or_notebook()))
return changed

def warn_about_changes(self):
for change_type in ["code", "input", "params"]:
changed = self.get_outputs_with_changes(change_type)
if changed:
rerun_trigger = ""
if not ON_WINDOWS:
rerun_trigger = f"\n To trigger a re-run, use 'snakemake -R $(snakemake --list-{change_type}-changes)'."
logger.warning(
f"The {change_type} used to generate one or several output files has changed:\n"
f" To inspect which output files have changes, run 'snakemake --list-{change_type}-changes'."
f"{rerun_trigger}"
)
def warn_about_changes(self, quiet=False):
if not quiet:
for change_type in ["code", "input", "params"]:
changed = self.get_outputs_with_changes(change_type)
if changed:
rerun_trigger = ""
if not ON_WINDOWS:
rerun_trigger = f"\n To trigger a re-run, use 'snakemake -R $(snakemake --list-{change_type}-changes)'."
logger.warning(
f"The {change_type} used to generate one or several output files has changed:\n"
f" To inspect which output files have changes, run 'snakemake --list-{change_type}-changes'."
f"{rerun_trigger}"
)

def __str__(self):
return self.dot()
12 changes: 8 additions & 4 deletions snakemake/remote/AzBlob.py
@@ -15,7 +15,11 @@

# module specific
from snakemake.exceptions import WorkflowError, AzureFileException
from snakemake.remote import AbstractRemoteObject, AbstractRemoteProvider
from snakemake.remote import (
AbstractRemoteObject,
AbstractRemoteProvider,
AbstractRemoteRetryObject,
)

# service provider support
try:
@@ -60,7 +64,7 @@ def available_protocols(self):
return ["ab://"]


class RemoteObject(AbstractRemoteObject):
class RemoteObject(AbstractRemoteRetryObject):
def __init__(self, *args, keep_local=False, provider=None, **kwargs):
super(RemoteObject, self).__init__(
*args, keep_local=keep_local, provider=provider, **kwargs
@@ -95,7 +99,7 @@ def size(self):
return self._as.blob_size(self.container_name, self.blob_name)
return self._iofile.size_local

def download(self):
def _download(self):
if self.exists():
os.makedirs(os.path.dirname(self.local_file()), exist_ok=True)
self._as.download_from_azure_storage(
@@ -105,7 +109,7 @@ def download(self):
return self.local_file()
return None

def upload(self):
def _upload(self):
self._as.upload_to_azure_storage(
container_name=self.container_name,
blob_name=self.blob_name,
15 changes: 10 additions & 5 deletions snakemake/remote/EGA.py
@@ -14,9 +14,14 @@
from requests.auth import HTTPBasicAuth


from snakemake.remote import AbstractRemoteObject, AbstractRemoteProvider
from snakemake.remote import (
AbstractRemoteObject,
AbstractRemoteProvider,
AbstractRemoteRetryObject,
check_deprecated_retry,
)
from snakemake.exceptions import WorkflowError
from snakemake.common import lazy_property
from snakemake.logging import logger


EGAFileInfo = namedtuple("EGAFileInfo", ["size", "status", "id", "checksum"])
@@ -30,7 +35,7 @@ def __init__(
keep_local=False,
stay_on_remote=False,
is_default=False,
retry=5,
retry=None,
**kwargs
):
super().__init__(
@@ -40,7 +45,7 @@ def __init__(
is_default=is_default,
**kwargs
)
self.retry = retry
check_deprecated_retry(retry)
self._token = None
self._expires = None
self._file_cache = dict()
@@ -193,7 +198,7 @@ def _credentials(cls, name):
)


class RemoteObject(AbstractRemoteObject):
class RemoteObject(AbstractRemoteRetryObject):
# === Implementations of abstract class members ===
def _stats(self):
return self.provider.get_files(self.parts.dataset)[self.parts.path]
@@ -209,7 +214,7 @@ def mtime(self):
# Hence, the files are always considered to be "ancient".
return 0

def download(self):
def _download(self):
stats = self._stats()

r = self.provider.api_request(
4 changes: 2 additions & 2 deletions snakemake/remote/FTP.py
@@ -179,7 +179,7 @@ def size(self):
else:
return self._iofile.size_local

def download(self, make_dest_dirs=True):
def _download(self, make_dest_dirs=True):
with self.connection_pool.item() as ftpc:
if self.exists():
# if the destination path does not exist
@@ -197,7 +197,7 @@ def download(self, make_dest_dirs=True):
"The file does not seem to exist remotely: %s" % self.local_file()
)

def upload(self):
def _upload(self):
with self.connection_pool.item() as ftpc:
ftpc.synchronize_times()
ftpc.upload(source=self.local_path, target=self.remote_path)
4 changes: 2 additions & 2 deletions snakemake/remote/GS.py
@@ -220,7 +220,7 @@ def size(self):
return self._iofile.size_local

@retry.Retry(predicate=google_cloud_retry_predicate, deadline=600)
def download(self):
def _download(self):
"""Download with maximum retry duration of 600 seconds (10 minutes)"""
if not self.exists():
return None
@@ -251,7 +251,7 @@ def _download_directory(self):
return self.local_file()

@retry.Retry(predicate=google_cloud_retry_predicate)
def upload(self):
def _upload(self):
try:
if not self.bucket.exists():
self.bucket.create()
4 changes: 2 additions & 2 deletions snakemake/remote/HTTP.py
@@ -209,7 +209,7 @@ def size(self):
else:
return self._iofile.size_local

def download(self, make_dest_dirs=True):
def _download(self, make_dest_dirs=True):
with self.httpr(stream=True) as httpr:
if self.exists():
# Find out if the source file is gzip compressed in order to keep
@@ -237,7 +237,7 @@ def download(self, make_dest_dirs=True):
"The file does not seem to exist remotely: %s" % self.remote_file()
)

def upload(self):
def _upload(self):
raise HTTPFileException(
"Upload is not permitted for the HTTP remote provider. Is an output set to HTTP.remote()?"
)
4 changes: 2 additions & 2 deletions snakemake/remote/NCBI.py
@@ -139,7 +139,7 @@ def size(self):
else:
return self._iofile.size_local

def download(self):
def _download(self):
if self.exists():
self._ncbi.fetch_from_ncbi(
[self.accession],
@@ -155,7 +155,7 @@ def download(self):
"The record does not seem to exist remotely: %s" % self.accession
)

def upload(self):
def _upload(self):
raise NCBIFileException(
"Upload is not permitted for the NCBI remote provider. Is an output set to NCBI.RemoteProvider.remote()?"
)
12 changes: 8 additions & 4 deletions snakemake/remote/S3.py
@@ -11,7 +11,11 @@
import concurrent.futures

# module-specific
from snakemake.remote import AbstractRemoteObject, AbstractRemoteProvider
from snakemake.remote import (
AbstractRemoteObject,
AbstractRemoteProvider,
AbstractRemoteRetryObject,
)
from snakemake.exceptions import WorkflowError, S3FileException
from snakemake.utils import os_sync

@@ -59,7 +63,7 @@ def available_protocols(self):
return ["s3://"]


class RemoteObject(AbstractRemoteObject):
class RemoteObject(AbstractRemoteRetryObject):
"""This is a class to interact with the AWS S3 object store."""

def __init__(self, *args, keep_local=False, provider=None, **kwargs):
@@ -97,11 +101,11 @@ def size(self):
else:
return self._iofile.size_local

def download(self):
def _download(self):
self._s3c.download_from_s3(self.s3_bucket, self.s3_key, self.local_file())
os_sync() # ensure flush to disk

def upload(self):
def _upload(self):
self._s3c.upload_to_s3(
self.s3_bucket,
self.local_file(),
4 changes: 2 additions & 2 deletions snakemake/remote/SFTP.py
@@ -120,7 +120,7 @@ def size(self):
else:
return self._iofile.size_local

def download(self, make_dest_dirs=True):
def _download(self, make_dest_dirs=True):
with self.connection_pool.item() as sftpc:
if self.exists():
# if the destination path does not exist
@@ -156,7 +156,7 @@ def mkdir_remote_path(self):
sftpc.mkdir(part)
sftpc.chdir(part)

def upload(self):
def _upload(self):
if self.provider.mkdir_remote:
self.mkdir_remote_path()

12 changes: 8 additions & 4 deletions snakemake/remote/XRootD.py
@@ -8,7 +8,11 @@
import re

from stat import S_ISREG
from snakemake.remote import AbstractRemoteObject, AbstractRemoteProvider
from snakemake.remote import (
AbstractRemoteObject,
AbstractRemoteProvider,
AbstractRemoteRetryObject,
)
from snakemake.exceptions import WorkflowError, XRootDFileException

try:
@@ -51,7 +55,7 @@ def available_protocols(self):
return ["root://", "roots://", "rootk://"]


class RemoteObject(AbstractRemoteObject):
class RemoteObject(AbstractRemoteRetryObject):
"""This is a class to interact with XRootD servers."""

def __init__(
@@ -89,11 +93,11 @@ def size(self):
else:
return self._iofile.size_local

def download(self):
def _download(self):
assert not self.stay_on_remote
self._xrd.copy(self.remote_file(), self.file())

def upload(self):
def _upload(self):
assert not self.stay_on_remote
self._xrd.copy(self.file(), self.remote_file())

46 changes: 42 additions & 4 deletions snakemake/remote/__init__.py
@@ -9,8 +9,11 @@
import re
from functools import partial
from abc import ABCMeta, abstractmethod
from wrapt import ObjectProxy
from contextlib import contextmanager
import shutil

from wrapt import ObjectProxy
from retry import retry

try:
from connection_pool import ConnectionPool
@@ -22,6 +25,7 @@
import collections

# module-specific
from snakemake.exceptions import WorkflowError
import snakemake.io
from snakemake.logging import logger
from snakemake.common import parse_uri
@@ -225,12 +229,29 @@ def mtime(self):
def size(self):
pass

def download(self):
try:
return self._download()
except Exception as e:
local_path = self.local_file()
# remove any incomplete local copy so a failed transfer cannot be
# mistaken for a finished download
if os.path.exists(local_path):
if os.path.isdir(local_path):
shutil.rmtree(local_path)
else:
os.remove(local_path)
raise WorkflowError(e)

def upload(self):
try:
self._upload()
except Exception as e:
raise WorkflowError(e)

@abstractmethod
def download(self, *args, **kwargs):
def _download(self, *args, **kwargs):
pass

@abstractmethod
def upload(self, *args, **kwargs):
def _upload(self, *args, **kwargs):
pass

@abstractmethod
@@ -253,7 +274,17 @@ def local_touch_or_create(self):
self._iofile.touch_or_create()


class DomainObject(AbstractRemoteObject):
class AbstractRemoteRetryObject(AbstractRemoteObject):
@retry(tries=3, delay=3, backoff=2, logger=logger)
def download(self):
return super().download()

@retry(tries=3, delay=3, backoff=2, logger=logger)
def upload(self):
super().upload()


class DomainObject(AbstractRemoteRetryObject):
"""This is a mixin related to parsing components
out of a location path specified as
(host|IP):port/remote/location
@@ -474,3 +505,10 @@ def remote(self, value, *args, provider_kws=None, **kwargs):


AUTO = AutoRemoteProvider()


def check_deprecated_retry(retry):
if retry:
logger.warning(
"Using deprecated retry argument. Snakemake now always uses 3 retries on download and upload."
)
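Taken together, the changes in this file mean a provider only needs to implement _download()/_upload(): the base class removes any partial local copy when a transfer fails and wraps the whole transfer in three retries. The following self-contained sketch mirrors that layering outside of Snakemake; CleanupOnError, RetryingRemote, and ExampleRemote are hypothetical names, and the sketch raises RuntimeError where Snakemake raises WorkflowError:

import os
import shutil

from retry import retry


class CleanupOnError:
    """Mirrors AbstractRemoteObject.download(): delegate to _download()
    and remove any incomplete local copy if it raises."""

    def __init__(self, local_path):
        self.local_path = local_path

    def _download(self):
        raise NotImplementedError

    def download(self):
        try:
            return self._download()
        except Exception as e:
            if os.path.exists(self.local_path):
                if os.path.isdir(self.local_path):
                    shutil.rmtree(self.local_path)
                else:
                    os.remove(self.local_path)
            raise RuntimeError(f"download failed: {e}")


class RetryingRemote(CleanupOnError):
    """Mirrors AbstractRemoteRetryObject: retry the cleaned-up transfer."""

    @retry(tries=3, delay=3, backoff=2)
    def download(self):
        return super().download()


class ExampleRemote(RetryingRemote):
    """Hypothetical provider: like the providers in this commit, it only
    implements _download(); cleanup and retries come from the base classes."""

    def _download(self):
        with open(self.local_path, "w") as f:
            f.write("partial content")  # simulate a half-written file
            raise IOError("connection reset")


# Each attempt writes a partial file, fails, and the partial file is removed;
# after three attempts the RuntimeError propagates to the caller:
# ExampleRemote("/tmp/example.txt").download()

With retries handled centrally in this way, per-provider retry arguments such as the EGA retry= parameter become unnecessary, which is why they are now routed through check_deprecated_retry above.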
