Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fixed permission issues when using zenodo remote provider to access restricted depositions #1634

Merged
3 commits merged on May 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 3 additions & 1 deletion snakemake/logging.py
Expand Up @@ -362,7 +362,9 @@ def location(self, msg):
def info(self, msg, indent=False):
self.handler(dict(level="info", msg=msg, indent=indent))

def warning(self, msg):
def warning(self, msg, *fmt_items):
if fmt_items:
msg = msg % fmt_items
self.handler(dict(level="warning", msg=msg))

def debug(self, msg):
Expand Down
88 changes: 67 additions & 21 deletions snakemake/remote/zenodo.py
Expand Up @@ -17,7 +17,9 @@
from snakemake.common import lazy_property


ZenFileInfo = namedtuple("ZenFileInfo", ["checksum", "filesize", "id", "download"])
ZenFileInfo = namedtuple(
"ZenFileInfo", ["filename", "checksum", "filesize", "download"]
)


class RemoteProvider(AbstractRemoteProvider):
Expand Down Expand Up @@ -89,7 +91,7 @@ def _download(self):

if local_md5 != stats.checksum:
raise ZenodoFileException(
"File checksums do not match for remote file id: {}".format(stats.id)
"File checksums do not match for remote file: {}".format(stats.filename)
)

def _upload(self):
Expand Down Expand Up @@ -137,15 +139,24 @@ def __init__(self, *args, **kwargs):
else:
self._baseurl = "https://zenodo.org"

if "deposition" in kwargs:
self.deposition = kwargs.pop("deposition")
self.bucket = self.get_bucket()
else:
self.is_new_deposition = "deposition" not in kwargs
if self.is_new_deposition:
# Creating a new deposition, as deposition id was not supplied.
self.deposition, self.bucket = self.create_deposition().values()
self.create_deposition()
else:
self.deposition = kwargs.pop("deposition")
self._bucket = None
self.is_new_deposition = False

def _api_request(
self, url, method="GET", data=None, headers={}, files=None, json=False
self,
url,
method="GET",
data=None,
headers={},
files=None,
json=False,
restricted_access=True,
):

# Create a session with a hook to raise error on bad request.
Expand All @@ -154,7 +165,7 @@ def _api_request(
session.headers["Authorization"] = "Bearer {}".format(self._access_token)
session.headers.update(headers)

cookies = self.restricted_access_cookies
cookies = self.restricted_access_cookies if restricted_access else None

# Run query.
try:
Expand All @@ -177,25 +188,60 @@ def create_deposition(self):
data="{}",
json=True,
)
return {"id": resp["id"], "bucket": resp["links"]["bucket"]}
self.deposition = resp["id"]
self._bucket = resp["links"]["bucket"]

def get_bucket(self):
resp = self._api_request(
self._baseurl + "/api/deposit/depositions/{}".format(self.deposition),
headers={"Content-Type": "application/json"},
json=True,
)
return resp["links"]["bucket"]
@property
def bucket(self):
if self._bucket is None:
resp = self._api_request(
self._baseurl + "/api/deposit/depositions/{}".format(self.deposition),
headers={"Content-Type": "application/json"},
json=True,
)
self._bucket = resp["links"]["bucket"]
return self._bucket

def get_files(self):
if self.is_new_deposition:
return self.get_files_own_deposition()
else:
return self.get_files_record()

def get_files_own_deposition(self):
files = self._api_request(
self._baseurl + "/api/deposit/depositions/{}/files".format(self.deposition),
headers={"Content-Type": "application/json"},
json=True,
)
return {
os.path.basename(f["filename"]): ZenFileInfo(
f["checksum"], int(f["filesize"]), f["id"], f["links"]["download"]
f["filename"], f["checksum"], int(f["filesize"]), f["links"]["download"]
)
for f in files
}

def get_files_record(self):
resp = self._api_request(
self._baseurl + "/api/records/{}".format(self.deposition),
headers={"Content-Type": "application/json"},
json=True,
)
files = resp["files"]

def get_checksum(f):
checksum = f["checksum"]
if checksum.startswith("md5:"):
return checksum[4:]
else:
raise ZenodoFileException(
"Unsupported checksum (currently only md5 support is "
f"implemented for Zenodo): {checksum}"
)

return {
os.path.basename(f["key"]): ZenFileInfo(
f["key"], get_checksum(f), int(f["size"]), f["links"]["self"]
)
for f in files
}
Expand All @@ -211,9 +257,9 @@ def restricted_access_cookies(self):
self._baseurl
+ f"/record/{self.deposition}?token={self.restricted_access_token}"
)
resp = self._api_request(url)
if "session" in resp["cookies"]:
self._restricted_access_cookies = resp["cookies"]
resp = self._api_request(url, restricted_access=False)
if "session" in resp.cookies:
self._restricted_access_cookies = resp.cookies
else:
raise WorkflowError(
"Failure to retrieve session cookie with given restricted access token. "
Expand Down