__author__ = "Johannes Köster"
__copyright__ = "Copyright 2022, Johannes Köster"
__email__ = "johannes.koester@tu-dortmund.de"
__license__ = "MIT"

import base64
import os
import re
import struct
import time

from snakemake.remote import AbstractRemoteObject, AbstractRemoteProvider
from snakemake.exceptions import WorkflowError, CheckSumMismatchException
from snakemake.common import lazy_property
import snakemake.io
from snakemake.utils import os_sync

try:
    import google.cloud
    from google.cloud import storage
    from google.api_core import retry
    from google_crc32c import Checksum
except ImportError as e:
    raise WorkflowError(
        "The Python 3 packages 'google-cloud-storage' and 'google-crc32c' "
        "need to be installed to use GS remote() file functionality. %s" % e.msg
    )


def google_cloud_retry_predicate(ex):
    """Given an exception from Google Cloud, determine whether it is a transient
    error (as judged by google.api_core.retry.if_transient_error), a read
    timeout, or a checksum mismatch caused by a bad download.

    This function returns a boolean indicating whether a retry should be done,
    and is typically used as the predicate of a google.api_core.retry.Retry
    decorator.

    Arguments:
      ex (Exception) : the exception passed from the decorated function

    Returns: boolean indicating whether to retry (True) or not (False)
    """
    from requests.exceptions import ReadTimeout

    # Most likely case is a Google API transient error.
    if retry.if_transient_error(ex):
        return True
    # Timeouts should be considered for retry as well.
    if isinstance(ex, ReadTimeout):
        return True
    # Could also be a checksum mismatch of the download.
    if isinstance(ex, CheckSumMismatchException):
        return True
    return False


@retry.Retry(predicate=google_cloud_retry_predicate)
def download_blob(blob, filename):
    """A helper function to download a storage Blob to a local file (filename)
    and validate it using the Crc32cCalculator.

    Arguments:
      blob (storage.Blob) : the Google storage blob object
      filename (str)      : the file path to download to

    Returns: the filename, once downloaded and validated
    """
    # create parent directories if necessary
    os.makedirs(os.path.dirname(filename), exist_ok=True)

    # ideally we could calculate the hash while streaming to file with a provided
    # function; see https://github.com/googleapis/python-storage/issues/29
    with open(filename, "wb") as blob_file:
        parser = Crc32cCalculator(blob_file)
        blob.download_to_file(parser)
    os_sync()

    # **Important** the hash can be incorrect or missing if not refreshed
    blob.reload()

    # Compute the local hash and verify that it is correct
    if parser.hexdigest() != blob.crc32c:
        os.remove(filename)
        raise CheckSumMismatchException("The checksum of %s does not match." % filename)
    return filename
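

# A minimal sketch of calling download_blob directly (normally it is invoked via
# RemoteObject._download below; the client, bucket, and object names here are
# placeholders):
#
#     client = storage.Client()
#     blob = client.bucket("my-bucket").blob("data/sample.txt")
#     download_blob(blob, "local/data/sample.txt")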


class Crc32cCalculator:
    """The Google Python client doesn't provide a way to stream a file being
    written, so we can wrap the file object in an additional class to
    do custom handling. This is so we don't need to download the file
    and then stream-read it again to calculate the hash.
    """

    def __init__(self, fileobj):
        self._fileobj = fileobj
        self.checksum = Checksum()

    def write(self, chunk):
        self._fileobj.write(chunk)
        self._update(chunk)

    def _update(self, chunk):
        """Given a chunk written to the file, update the running checksum."""
        self.checksum.update(chunk)

    def hexdigest(self):
        """Return the digest of the hasher, Base64-encoded.

        The Base64-encoded CRC32c is in big-endian byte order.
        See https://cloud.google.com/storage/docs/hashes-etags
        """
        return base64.b64encode(self.checksum.digest()).decode("utf-8")
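

# A minimal usage sketch, mirroring download_blob above (here `blob` is assumed
# to be a google.cloud.storage Blob and "out.dat" is a placeholder path):
#
#     with open("out.dat", "wb") as fh:
#         calculator = Crc32cCalculator(fh)
#         blob.download_to_file(calculator)
#     blob.reload()
#     assert calculator.hexdigest() == blob.crc32c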


class RemoteProvider(AbstractRemoteProvider):
    supports_default = True

    def __init__(
        self, *args, keep_local=False, stay_on_remote=False, is_default=False, **kwargs
    ):
        super(RemoteProvider, self).__init__(
            *args,
            keep_local=keep_local,
            stay_on_remote=stay_on_remote,
            is_default=is_default,
            **kwargs
        )

        self.client = storage.Client(*args, **kwargs)

    def remote_interface(self):
        return self.client

    @property
    def default_protocol(self):
        """The protocol that is prepended to the path when no protocol is specified."""
        return "gs://"

    @property
    def available_protocols(self):
        """List of valid protocols for this remote provider."""
        return ["gs://"]


class RemoteObject(AbstractRemoteObject):
    def __init__(
        self, *args, keep_local=False, provider=None, user_project=None, **kwargs
    ):
        super(RemoteObject, self).__init__(
            *args, keep_local=keep_local, provider=provider, **kwargs
        )

        if provider:
            self.client = provider.remote_interface()
        else:
            self.client = storage.Client(*args, **kwargs)

        # keep user_project available for when the bucket is initialized
        self._user_project = user_project

        self._key = None
        self._bucket_name = None
        self._bucket = None
        self._blob = None

    async def inventory(self, cache: snakemake.io.IOCache):
        """Using client.list_blobs(), iterate over the objects in the "folder"
        of a bucket and store information about the IOFiles in the provided
        cache (snakemake.io.IOCache), indexed by bucket/blob name.

        This is called on the first mention of a remote object, and iterates
        over the corresponding bucket prefix once (so it does not need to be
        repeated). The cached information includes:
        - cache.exists_remote
        - cache.mtime
        - cache.size
        """
        if cache.remaining_wait_time <= 0:
            # No more time to create the inventory.
            return

        start_time = time.time()
        subfolder = os.path.dirname(self.blob.name)
        for blob in self.client.list_blobs(self.bucket_name, prefix=subfolder):
            # By way of being listed, the blob exists. mtime is a datetime object.
            name = "{}/{}".format(blob.bucket.name, blob.name)
            cache.exists_remote[name] = True
            cache.mtime[name] = snakemake.io.Mtime(remote=blob.updated.timestamp())
            cache.size[name] = blob.size
        cache.remaining_wait_time -= time.time() - start_time

        # Mark bucket and prefix as having an inventory, such that this method is
        # only called once for the subfolder in the bucket.
        cache.exists_remote.has_inventory.add("%s/%s" % (self.bucket_name, subfolder))

    # === Implementations of abstract class members ===

    def get_inventory_parent(self):
        return self.bucket_name

    @retry.Retry(predicate=google_cloud_retry_predicate)
    def exists(self):
        return self.blob.exists()

    def mtime(self):
        if self.exists():
            self.update_blob()
            t = self.blob.updated
            return t.timestamp()
        else:
            raise WorkflowError(
                "The file does not seem to exist remotely: %s" % self.local_file()
            )

    def size(self):
        if self.exists():
            self.update_blob()
            return self.blob.size // 1024
        else:
            return self._iofile.size_local

    @retry.Retry(predicate=google_cloud_retry_predicate, deadline=600)
    def _download(self):
        """Download with maximum retry duration of 600 seconds (10 minutes)."""
        if not self.exists():
            return None

        # Create just a directory, or a file itself
        if snakemake.io.is_flagged(self.local_file(), "directory"):
            return self._download_directory()
        return download_blob(self.blob, self.local_file())

    @retry.Retry(predicate=google_cloud_retry_predicate)
    def _download_directory(self):
        """A 'private' function to handle download of a storage folder, which
        includes the content found inside.
        """
        # Create the directory locally
        os.makedirs(self.local_file(), exist_ok=True)

        for blob in self.client.list_blobs(self.bucket_name, prefix=self.key):
            local_name = "{}/{}".format(blob.bucket.name, blob.name)

            # Don't try to create a "directory blob"
            if os.path.exists(local_name) and os.path.isdir(local_name):
                continue

            download_blob(blob, local_name)

        # Return the root directory
        return self.local_file()

    @retry.Retry(predicate=google_cloud_retry_predicate)
    def _upload(self):
        try:
            if not self.bucket.exists():
                self.bucket.create()
                self.update_blob()

            # Distinguish between single file, and folder
            f = self.local_file()
            if os.path.isdir(f):
                # Ensure the "directory" exists
                self.blob.upload_from_string(
                    "", content_type="application/x-www-form-urlencoded;charset=UTF-8"
                )
                for root, _, files in os.walk(f):
                    for filename in files:
                        filename = os.path.join(root, filename)
                        bucket_path = filename.lstrip(self.bucket.name).lstrip("/")
                        blob = self.bucket.blob(bucket_path)
                        blob.upload_from_filename(filename)
            else:
                self.blob.upload_from_filename(f)
        except google.cloud.exceptions.Forbidden as e:
            raise WorkflowError(
                e,
                "When running locally, make sure that you are authenticated "
                "via gcloud (see Snakemake documentation). When running in a "
                "kubernetes cluster, make sure that storage-rw is added to "
                "--scopes (see Snakemake documentation).",
            )

    @property
    def name(self):
        return self.key

    @property
    def list(self):
        return [k.name for k in self.bucket.list_blobs()]

    # ========= Helpers ===============

    @retry.Retry(predicate=google_cloud_retry_predicate)
    def update_blob(self):
        self._blob = self.bucket.get_blob(self.key)

    @lazy_property
    def bucket(self):
        return self.client.bucket(self.bucket_name, user_project=self._user_project)

    @lazy_property
    def blob(self):
        return self.bucket.blob(self.key)

    @lazy_property
    def bucket_name(self):
        return self.parse().group("bucket")

    @property
    def key(self):
        key = self.parse().group("key")
        f = self.local_file()
        if snakemake.io.is_flagged(f, "directory"):
            key = key if f.endswith("/") else key + "/"
        return key

    def parse(self):
        m = re.search("(?P<bucket>[^/]*)/(?P<key>.*)", self.local_file())
        if len(m.groups()) != 2:
            raise WorkflowError(
                "GS remote file {} does not have the form "
                "<bucket>/<key>.".format(self.local_file())
            )
        return m
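

# For example, a remote file "my-bucket/data/sample.txt" parses into
# bucket="my-bucket" and key="data/sample.txt"; for directory-flagged files the
# `key` property above additionally appends a trailing "/".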