-
-
Notifications
You must be signed in to change notification settings - Fork 1.7k
/
chunk.py
472 lines (369 loc) · 13.3 KB
/
chunk.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
""" A set of NumPy functions to apply per chunk """
from __future__ import annotations
import contextlib
from collections.abc import Container, Iterable, Sequence
from functools import wraps
from numbers import Integral
import numpy as np
from tlz import concat
from dask.core import flatten
def keepdims_wrapper(a_callable):
    """
    A wrapper for functions that don't provide keepdims to ensure that they do.
    """

    @wraps(a_callable)
    def keepdims_wrapped_callable(x, axis=None, keepdims=None, *args, **kwargs):
        # Run the underlying reduction first; only reshape if asked to.
        result = a_callable(x, *args, axis=axis, **kwargs)
        if not keepdims:
            return result

        # Normalize the axis spec: None means "all axes"; a bare scalar
        # becomes a one-element list so membership tests work below.
        reduced = range(x.ndim) if axis is None else axis
        if not isinstance(reduced, (Container, Iterable, Sequence)):
            reduced = [reduced]

        # Reinsert a length-1 dimension (via None/newaxis) for every
        # reduced axis, leaving the surviving axes untouched.
        selector = tuple(
            None if ax in reduced else slice(None) for ax in range(x.ndim)
        )
        return result[selector]

    return keepdims_wrapped_callable
# Re-export NumPy reductions for use as per-chunk kernels.  Most are
# aliased directly; the arg-reductions are passed through
# ``keepdims_wrapper`` because ``np.argmin``/``np.argmax`` (and their
# nan-variants) historically do not accept a ``keepdims`` keyword.
sum = np.sum
prod = np.prod
min = np.min
max = np.max
argmin = keepdims_wrapper(np.argmin)
nanargmin = keepdims_wrapper(np.nanargmin)
argmax = keepdims_wrapper(np.argmax)
nanargmax = keepdims_wrapper(np.nanargmax)
any = np.any
all = np.all
nansum = np.nansum
nanprod = np.nanprod
nancumprod = np.nancumprod
nancumsum = np.nancumsum
nanmin = np.nanmin
nanmax = np.nanmax
mean = np.mean
# The nan-aware moments may be absent on some NumPy builds; skip the
# alias instead of failing at import time.
with contextlib.suppress(AttributeError):
    nanmean = np.nanmean
var = np.var
with contextlib.suppress(AttributeError):
    nanvar = np.nanvar
std = np.std
with contextlib.suppress(AttributeError):
    nanstd = np.nanstd
def coarsen(reduction, x, axes, trim_excess=False, **kwargs):
    """Coarsen array by applying reduction to fixed size neighborhoods

    Parameters
    ----------
    reduction: function
        Function like np.sum, np.mean, etc...
    x: np.ndarray
        Array to be coarsened
    axes: dict
        Mapping of axis to coarsening factor
    trim_excess: bool, optional
        If True, trim trailing elements along each axis so its length is
        evenly divisible by the coarsening factor.  If False (default),
        every dimension must already be divisible by its factor.

    Examples
    --------
    >>> x = np.array([1, 2, 3, 4, 5, 6])
    >>> coarsen(np.sum, x, {0: 2})
    array([ 3,  7, 11])
    >>> coarsen(np.max, x, {0: 3})
    array([3, 6])

    Provide dictionary of scale per dimension

    >>> x = np.arange(24).reshape((4, 6))
    >>> x
    array([[ 0,  1,  2,  3,  4,  5],
           [ 6,  7,  8,  9, 10, 11],
           [12, 13, 14, 15, 16, 17],
           [18, 19, 20, 21, 22, 23]])
    >>> coarsen(np.min, x, {0: 2, 1: 3})
    array([[ 0,  3],
           [12, 15]])

    You must avoid excess elements explicitly

    >>> x = np.array([1, 2, 3, 4, 5, 6, 7, 8])
    >>> coarsen(np.min, x, {0: 3}, trim_excess=True)
    array([1, 4])
    """
    # Build a complete factor-per-axis mapping on a fresh dict so the
    # caller's ``axes`` argument is never mutated (the previous code
    # inserted singleton factors into the caller's dict in place).
    axes = {i: axes.get(i, 1) for i in range(x.ndim)}

    if trim_excess:
        # Drop the trailing remainder along any axis that is not evenly
        # divisible by its factor.
        ind = tuple(
            slice(0, -(d % axes[i])) if d % axes[i] else slice(None, None)
            for i, d in enumerate(x.shape)
        )
        x = x[ind]

    # Split each axis into (blocks, factor) pairs, e.g. (10, 10) with
    # factor 2 on both axes -> (5, 2, 5, 2), then reduce over the factor
    # axes (the odd positions 1, 3, 5, ...).
    newshape = tuple(
        s for i in range(x.ndim) for s in (x.shape[i] // axes[i], axes[i])
    )
    return reduction(x.reshape(newshape), axis=tuple(range(1, x.ndim * 2, 2)), **kwargs)
def trim(x, axes=None):
    """Trim boundaries off of array

    >>> x = np.arange(24).reshape((4, 6))
    >>> trim(x, axes={0: 0, 1: 1})
    array([[ 1,  2,  3,  4],
           [ 7,  8,  9, 10],
           [13, 14, 15, 16],
           [19, 20, 21, 22]])

    >>> trim(x, axes={0: 1, 1: 1})
    array([[ 7,  8,  9, 10],
           [13, 14, 15, 16]])
    """
    # Normalize ``axes`` to a per-dimension list of trim widths: a single
    # integer applies to every axis, a dict defaults missing axes to 0.
    if isinstance(axes, Integral):
        axes = [axes] * x.ndim
    if isinstance(axes, dict):
        axes = [axes.get(i, 0) for i in range(x.ndim)]

    # A width of 0 must keep the whole axis (slice(0, 0) would be empty).
    slices = []
    for width in axes:
        slices.append(slice(width, -width if width else None))
    return x[tuple(slices)]
def topk(a, k, axis, keepdims):
    """Chunk and combine function of topk

    Extract the k largest elements from a on the given axis.
    If k is negative, extract the -k smallest elements instead.
    Note that, unlike in the parent function, the returned elements
    are not sorted internally.
    """
    assert keepdims is True
    axis = axis[0]

    # With |k| or fewer elements there is nothing to discard.
    if abs(k) >= a.shape[axis]:
        return a

    # Partition so the k largest values occupy the tail of the axis
    # (for negative k the -k smallest occupy the head).
    partitioned = np.partition(a, -k, axis=axis)
    wanted = slice(-k, None) if k > 0 else slice(-k)
    selector = tuple(
        wanted if d == axis else slice(None) for d in range(partitioned.ndim)
    )
    return partitioned[selector]
def topk_aggregate(a, k, axis, keepdims):
    """Final aggregation function of topk

    Invoke topk one final time and then sort the results internally.
    """
    assert keepdims is True
    a = np.sort(topk(a, k, axis, keepdims), axis=axis[0])
    if k < 0:
        # Smallest-k: ascending sort is already the desired order.
        return a
    # Largest-k: flip along the target axis to get descending order.
    reverser = tuple(
        slice(None, None, -1) if d == axis[0] else slice(None)
        for d in range(a.ndim)
    )
    return a[reverser]
def argtopk_preprocess(a, idx):
    """Preparatory step for argtopk

    Pair the data with its original indices so later steps can track
    where each element came from.
    """
    return (a, idx)
def argtopk(a_plus_idx, k, axis, keepdims):
    """Chunk and combine function of argtopk

    Extract the indices of the k largest elements from a on the given axis.
    If k is negative, extract the indices of the -k smallest elements instead.
    Note that, unlike in the parent function, the returned elements
    are not sorted internally.
    """
    assert keepdims is True
    axis = axis[0]

    if isinstance(a_plus_idx, list):
        # Combine step: stitch several (data, indices) pairs together
        # along the target axis before selecting.
        pairs = list(flatten(a_plus_idx))
        a = np.concatenate([data for data, _ in pairs], axis)
        idx = np.concatenate(
            [np.broadcast_to(indices, data.shape) for data, indices in pairs], axis
        )
    else:
        a, idx = a_plus_idx

    # With |k| or fewer elements everything is kept, untouched.
    if abs(k) >= a.shape[axis]:
        return a_plus_idx

    # Positions of the k largest (tail) or -k smallest (head) values.
    sel = np.argpartition(a, -k, axis=axis)
    wanted = slice(-k, None) if k > 0 else slice(-k)
    sel = sel[tuple(wanted if d == axis else slice(None) for d in range(a.ndim))]
    return np.take_along_axis(a, sel, axis), np.take_along_axis(idx, sel, axis)
def argtopk_aggregate(a_plus_idx, k, axis, keepdims):
    """Final aggregation function of argtopk

    Invoke argtopk one final time, sort the results internally, drop the data
    and return the index only.
    """
    assert keepdims is True
    if len(a_plus_idx) == 1:
        a_plus_idx = a_plus_idx[0]
    a, idx = argtopk(a_plus_idx, k, axis, keepdims)
    axis = axis[0]

    # Reorder the original indices by ascending data value.
    order = np.argsort(a, axis=axis)
    idx = np.take_along_axis(idx, order, axis)
    if k < 0:
        # Smallest-k: ascending order is what the caller expects.
        return idx
    # Largest-k: flip along the axis to report in descending value order.
    reverser = tuple(
        slice(None, None, -1) if d == axis else slice(None)
        for d in range(idx.ndim)
    )
    return idx[reverser]
def arange(start, stop, step, length, dtype, like=None):
    """Chunk function for ``da.arange``: build one chunk of the range."""
    from dask.array.utils import arange_safe

    res = arange_safe(start, stop, step, dtype, like=like)
    # NOTE(review): presumably guards against one extra element produced
    # by floating-point step rounding — trim back to the expected length.
    if len(res) > length:
        return res[:-1]
    return res
def linspace(start, stop, num, endpoint=True, dtype=None):
    """Chunk function for ``da.linspace``: materialize lazy endpoints, then
    delegate to :func:`numpy.linspace`."""
    from dask.array.core import Array

    # Either endpoint may itself be a dask Array; compute it eagerly.
    start, stop = (
        v.compute() if isinstance(v, Array) else v for v in (start, stop)
    )
    return np.linspace(start, stop, num, endpoint=endpoint, dtype=dtype)
def astype(x, astype_dtype=None, **kwargs):
    """Cast ``x`` to ``astype_dtype``, forwarding extra keyword arguments
    to ``x.astype``."""
    return x.astype(astype_dtype, **kwargs)
def view(x, dtype, order="C"):
    """Reinterpret the bytes of ``x`` as ``dtype``.

    For ``order="C"`` the view is taken over a C-contiguous layout; any
    other value takes the view over a Fortran-contiguous layout via a
    double transpose.
    """
    if order == "C":
        # ``like=x`` keeps array-like backends intact; fall back for
        # implementations that don't accept that keyword.
        try:
            contiguous = np.ascontiguousarray(x, like=x)
        except TypeError:
            contiguous = np.ascontiguousarray(x)
        return contiguous.view(dtype)

    try:
        fortran = np.asfortranarray(x, like=x)
    except TypeError:
        fortran = np.asfortranarray(x)
    # Transpose, view, transpose back: reinterprets along the F-order axis.
    return fortran.T.view(dtype).T
def slice_with_int_dask_array(x, idx, offset, x_size, axis):
    """Chunk function of `slice_with_int_dask_array_on_axis`.

    Slice one chunk of x by one chunk of idx.

    Parameters
    ----------
    x: ndarray, any dtype, any shape
        i-th chunk of x
    idx: ndarray, ndim=1, dtype=any integer
        j-th chunk of idx (cartesian product with the chunks of x)
    offset: ndarray, shape=(1, ), dtype=int64
        Index of the first element along axis of the current chunk of x
    x_size: int
        Total size of the x da.Array along axis
    axis: int
        normalized axis to take elements from (0 <= axis < x.ndim)

    Returns
    -------
    x sliced along axis, using only the elements of idx that fall inside the
    current chunk.
    """
    from dask.array.utils import asarray_safe, meta_from_array

    idx = asarray_safe(idx, like=meta_from_array(x))

    # Cast to a signed type first: the arithmetic below would wrap for
    # unsigned indices.
    idx = idx.astype(np.int64)

    # Map negative indices onto their positive equivalents.
    idx = np.where(idx < 0, idx + x_size, idx)

    # ``offset`` (shape (1,)) is the global position along ``axis`` of this
    # chunk's first element; shift the indices to chunk-local coordinates.
    idx = idx - offset

    # Keep only the indices that actually land inside this chunk.
    inside = (idx >= 0) & (idx < x.shape[axis])
    idx = idx[inside]

    # Fancy-index along ``axis`` (np.take does not support slice indices).
    selector = tuple(idx if d == axis else slice(None) for d in range(x.ndim))
    return x[selector]
def slice_with_int_dask_array_aggregate(idx, chunk_outputs, x_chunks, axis):
    """Final aggregation function of `slice_with_int_dask_array_on_axis`.

    Aggregate all chunks of x by one chunk of idx, reordering the output of
    `slice_with_int_dask_array`.

    Note that there is no combine function, as a recursive aggregation (e.g.
    with split_every) would not give any benefit.

    Parameters
    ----------
    idx: ndarray, ndim=1, dtype=any integer
        j-th chunk of idx
    chunk_outputs: ndarray
        concatenation along axis of the outputs of `slice_with_int_dask_array`
        for all chunks of x and the j-th chunk of idx
    x_chunks: tuple
        dask chunks of the x da.Array along axis, e.g. ``(3, 3, 2)``
    axis: int
        normalized axis to take elements from (0 <= axis < x.ndim)

    Returns
    -------
    Selection from all chunks of x for the j-th chunk of idx, in the correct
    order
    """
    # Cast to a signed type first (unsigned indices would wrap), then map
    # negative indices onto their positive equivalents.
    idx = idx.astype(np.int64)
    idx = np.where(idx < 0, idx + sum(x_chunks), idx)

    x_chunk_offset = 0
    chunk_output_offset = 0

    # Build, one layer per chunk of x, the index that reorders
    # ``chunk_outputs`` back into the order requested by ``idx``.
    # FIXME: this could probably be reimplemented with a faster search-based
    # algorithm
    idx_final = np.zeros_like(idx)
    for x_chunk in x_chunks:
        inside = (idx >= x_chunk_offset) & (idx < x_chunk_offset + x_chunk)
        # Running count gives each in-chunk index its position within the
        # slice this chunk contributed to ``chunk_outputs``.
        within = np.cumsum(inside)
        idx_final += np.where(inside, within - 1 + chunk_output_offset, 0)
        x_chunk_offset += x_chunk
        if within.size > 0:
            chunk_output_offset += within[-1]

    # np.take does not support slice indices
    selector = tuple(
        idx_final if d == axis else slice(None) for d in range(chunk_outputs.ndim)
    )
    return chunk_outputs[selector]
def getitem(obj, index):
    """Getitem function

    This function creates a copy of the desired selection for array-like
    inputs when the selection is smaller than half of the original array. This
    avoids excess memory usage when extracting a small portion from a large array.
    For more information, see
    https://numpy.org/doc/stable/reference/arrays.indexing.html#basic-slicing-and-indexing.

    Parameters
    ----------
    obj: ndarray, string, tuple, list
        Object to get item from.
    index: int, list[int], slice()
        Desired selection to extract from obj.

    Returns
    -------
    Selection obj[index]
    """
    try:
        selection = obj[index]
    except IndexError as e:
        raise ValueError(
            "Array chunk size or shape is unknown. "
            "Possible solution with x.compute_chunk_sizes()"
        ) from e

    # A small view pins the whole parent buffer in memory; copy it so the
    # parent can be freed.  Objects without ``.flags``/``.size`` (strings,
    # lists, tuples) fall through unchanged.
    try:
        if not selection.flags.owndata and obj.size >= 2 * selection.size:
            selection = selection.copy()
    except AttributeError:
        pass
    return selection
def take_along_axis_chunk(
    arr: np.ndarray, indices: np.ndarray, offset: np.ndarray, arr_size: int, axis: int
):
    """Slice an ndarray according to ndarray indices along an axis.

    Parameters
    ----------
    arr: np.ndarray, dtype=Any
        The data array.
    indices: np.ndarray, dtype=int64
        The indices of interest.
    offset: np.ndarray, shape=(1, ), dtype=int64
        Index of the first element along axis of the current chunk of arr
    arr_size: int
        Total size of the arr da.Array along axis
    axis: int
        The axis along which the indices are from.

    Returns
    -------
    out: np.ndarray
        The indexed arr.
    """
    # Signed working copy so the shifts below cannot wrap unsigned values.
    indices = indices.astype(np.int64)
    # Map negative indices onto their positive equivalents.
    indices = np.where(indices < 0, indices + arr_size, indices)
    # ``offset`` (shape (1,)) is the global position along ``axis`` of this
    # chunk's first element; shift to chunk-local coordinates.
    indices = indices - offset
    # Positions outside this chunk are looked up at index 0 and then zeroed
    # out, so every chunk yields an output of identical shape.
    inside = (indices >= 0) & (indices < arr.shape[axis])
    indices[~inside] = 0
    result = np.take_along_axis(arr, indices, axis=axis)
    result[~inside] = 0
    # Prepend a length-1 dimension at ``axis``.
    return np.expand_dims(result, axis)