xarray.dot() dask problems #2074

Closed
crusaderky opened this issue Apr 22, 2018 · 10 comments

crusaderky commented Apr 22, 2018

xarray.dot() has performance comparable to numpy.einsum.
However, when it uses a dask backend, it's much slower than the new dask.array.einsum function (dask/dask#3412).
The performance gap widens when the dimension being reduced is chunked.

Also, for some reason dot(a<s, t>, b<t>, dims=[t]) and dot(a<s, t>, a<s, t>, dims=[s, t]) do work (very slowly) when s and t are chunked, while dot(a<s, t>, a<s, t>, dims=[t]) crashes, complaining that it can't operate on a chunked core dim (related discussion: #1995).
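
(For reference, the workaround suggested by the error message is to collapse the chunked core dim into a single chunk before calling dot. A minimal sketch - the shapes and chunk sizes are just illustrative, and .chunk() is xarray's counterpart of dask's .rechunk():)

import dask.array
import xarray

a = xarray.DataArray(
    dask.array.random.random((500000, 100), chunks=(50000, 20)),
    dims=['s', 't'])

# Collapse 't' into a single chunk so it is acceptable as a core dim;
# beware that this may significantly increase per-chunk memory usage.
a1 = a.chunk({'t': -1})
xarray.dot(a1, a1, dims=['t'])  # no longer raises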

The proposed solution is to simply wait for dask/dask#3412 to reach the next release and then reimplement xarray.dot on top of dask.array.einsum. This means that dask users would lose the ability to use xarray.dot if they upgrade their xarray version but not their dask version, but I believe that shouldn't be a big problem for most?

import numpy
import dask.array
import xarray

def bench(tchunk, a_by_a, dims, iis):
    # tchunk: chunk size along 't'; a_by_a: contract a with itself instead of
    # with a 1-D b; dims: dims to reduce over; iis: equivalent einsum subscripts.
    print(f"\nbench({tchunk}, {a_by_a}, {dims}, {iis})")

    a = xarray.DataArray(
        dask.array.random.random((500000, 100), chunks=(50000, tchunk)),
        dims=['s', 't'])
    if a_by_a:
        b = a
    else:
        b = xarray.DataArray(
            dask.array.random.random((100, ), chunks=tchunk),
            dims=['t'])

    print("xarray.dot(numpy backend):")
    %timeit xarray.dot(a.compute(), b.compute(), dims=dims)
    print("numpy.einsum:")
    %timeit numpy.einsum(iis, a, b)
    print("xarray.dot(dask backend):")
    try:
        %timeit xarray.dot(a, b, dims=dims).compute()
    except ValueError as e:
        print(e)
    print("dask.array.einsum:")
    %timeit dask.array.einsum(iis, a, b).compute()


bench(100, False, ['t'], '...i,...i')
bench( 20, False, ['t'], '...i,...i')
bench(100, True, ['t'], '...i,...i')
bench( 20, True, ['t'], '...i,...i')
bench(100, True, ['s', 't'], '...ij,...ij')
bench( 20, True, ['s', 't'], '...ij,...ij')

Output:

bench(100, False, ['t'], ...i,...i)
xarray.dot(numpy backend):
195 ms ± 3.3 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
numpy.einsum:
205 ms ± 2.47 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
xarray.dot(dask backend):
356 ms ± 44.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
dask.array.einsum:
244 ms ± 10.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

bench(20, False, ['t'], ...i,...i)
xarray.dot(numpy backend):
297 ms ± 16.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
numpy.einsum:
254 ms ± 15.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
xarray.dot(dask backend):
732 ms ± 74.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
dask.array.einsum:
274 ms ± 12.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

bench(100, True, ['t'], ...i,...i)
xarray.dot(numpy backend):
438 ms ± 43.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
numpy.einsum:
415 ms ± 17.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
xarray.dot(dask backend):
633 ms ± 31.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
dask.array.einsum:
431 ms ± 17 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

bench(20, True, ['t'], ...i,...i)
xarray.dot(numpy backend):
457 ms ± 17.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
numpy.einsum:
463 ms ± 24.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
xarray.dot(dask backend):
dimension 't' on 0th function argument to apply_ufunc with dask='parallelized' consists of multiple chunks, but is also a core dimension. To fix, rechunk into a single dask array chunk along this dimension, i.e., ``.rechunk({'t': -1})``, but beware that this may significantly increase memory usage.
dask.array.einsum:
485 ms ± 15.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

bench(100, True, ['s', 't'], ...ij,...ij)
xarray.dot(numpy backend):
418 ms ± 14.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
numpy.einsum:
444 ms ± 43.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
xarray.dot(dask backend):
384 ms ± 57.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
dask.array.einsum:
415 ms ± 19.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

bench(20, True, ['s', 't'], ...ij,...ij)
xarray.dot(numpy backend):
489 ms ± 2.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
numpy.einsum:
443 ms ± 3.35 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
xarray.dot(dask backend):
585 ms ± 64.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
dask.array.einsum:
455 ms ± 13.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

crusaderky changed the title from "xarray.dot() problems with chunked core dims" to "xarray.dot() dask problems" on Apr 22, 2018

fujiisoup commented Apr 22, 2018

xr.dot was implemented before dask/dask#3412 was merged, and thus it is currently not very efficient with dask.

The proposed solution is to simply wait for dask/dask#3412 to reach the next release and then reimplement xarray.dot to use dask.array.einsum.

Agreed.
I think the reimplementation would be easy:

result = apply_ufunc(func, *arrays,
                     input_core_dims=input_core_dims,
                     output_core_dims=output_core_dims,
                     dask='parallelized', output_dtypes=[out_dtype])
return result.transpose(*[d for d in all_dims if d in result.dims])

with dask='parallelized' changed to dask='allowed'.
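
For concreteness, a rough sketch of what an einsum-based path could look like - the subscript construction and the dask/numpy dispatch below are my own simplification, not xarray's actual code:

import dask.array
import numpy as np
from xarray import apply_ufunc

def dot_sketch(a, b, dims):
    # Assign one einsum letter per dimension name (hypothetical helper logic).
    all_dims = list(dict.fromkeys(list(a.dims) + list(b.dims)))
    letters = {d: chr(ord('a') + i) for i, d in enumerate(all_dims)}
    out_dims = [d for d in all_dims if d not in dims]
    subscripts = '{},{}->{}'.format(
        ''.join(letters[d] for d in a.dims),
        ''.join(letters[d] for d in b.dims),
        ''.join(letters[d] for d in out_dims))

    def func(x, y):
        # Dispatch to dask.array.einsum for dask inputs, np.einsum otherwise.
        if isinstance(x, dask.array.Array) or isinstance(y, dask.array.Array):
            return dask.array.einsum(subscripts, x, y)
        return np.einsum(subscripts, x, y)

    # dask='allowed' because func already handles dask arrays natively.
    return apply_ufunc(
        func, a, b,
        input_core_dims=[list(a.dims), list(b.dims)],
        output_core_dims=[out_dims],
        dask='allowed')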

@jakirkham

Might be worth revisiting how da.dot is implemented as well. That would be the least amount of rewriting for you and would generally be nice for Dask users. If you have not already, @crusaderky, it would be nice to raise an issue over at Dask with a straight Dask benchmark comparing Dask Array's dot and einsum.
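
A straight-dask comparison could look something like this (sketch only; the sizes mirror the xarray benchmark above, and it assumes a dask version that already ships da.einsum):

import dask.array as da

x = da.random.random((500000, 100), chunks=(50000, 20))
y = da.random.random((100,), chunks=20)

%timeit da.dot(x, y).compute()                # tensordot/BLAS-style path
%timeit da.einsum('ij,j->i', x, y).compute()  # new einsum-based path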

cc @mrocklin

@mrocklin

See also dask/dask#2225

@crusaderky

@jakirkham from what I understand da.dot implements... a limited special case of da.einsum?

Ok, this is funny. I ran a few more benchmarks, and apparently xarray.dot on a dask backend is situationally faster than all other implementations when you are not reducing over any dimensions - which, as I understand it, is really the same as (a * b), except that it's faster than (a * b)?!?

def bench(...):
    ...
    if not dims:
        print("a * b (numpy backend):")
        %timeit a.compute() * b.compute()
        print("a * b (dask backend):")
        %timeit (a * b).compute()

bench(100, False, [], '...i,...i->...i')
bench( 20, False, [], '...i,...i->...i')
bench(100, True,  [], '...i,...i->...i')
bench( 20, True,  [], '...i,...i->...i')

Output:


bench(100, False, [], ...i,...i->...i)
xarray.dot(numpy backend):
291 ms ± 5.15 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
numpy.einsum:
296 ms ± 10 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
xarray.dot(dask backend):
dimension 's' on 0th function argument to apply_ufunc with dask='parallelized' consists of multiple chunks, but is also a core dimension. To fix, rechunk into a single dask array chunk along this dimension, i.e., ``.rechunk({'s': -1})``, but beware that this may significantly increase memory usage.
dask.array.einsum:
296 ms ± 21.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
a * b (numpy backend)
279 ms ± 9.51 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
a * b (dask backend)
241 ms ± 8.75 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

bench(20, False, [], ...i,...i->...i)
xarray.dot(numpy backend):
345 ms ± 6.02 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
numpy.einsum:
342 ms ± 4.96 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
xarray.dot(dask backend):
dimension 's' on 0th function argument to apply_ufunc with dask='parallelized' consists of multiple chunks, but is also a core dimension. To fix, rechunk into a single dask array chunk along this dimension, i.e., ``.rechunk({'s': -1})``, but beware that this may significantly increase memory usage.
dask.array.einsum:
347 ms ± 6.45 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
a * b (numpy backend)
319 ms ± 2.53 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
a * b (dask backend)
247 ms ± 5.37 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

bench(100, True, [], ...i,...i->...i)
xarray.dot(numpy backend):
477 ms ± 8.29 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
numpy.einsum:
514 ms ± 35.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
xarray.dot(dask backend):
241 ms ± 8.47 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
dask.array.einsum:
497 ms ± 21.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
a * b (numpy backend)
439 ms ± 27.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
a * b (dask backend)
517 ms ± 41.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

bench(20, True, [], ...i,...i->...i)
xarray.dot(numpy backend):
572 ms ± 13.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
numpy.einsum:
563 ms ± 10.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
xarray.dot(dask backend):
268 ms ± 14.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
dask.array.einsum:
563 ms ± 5.11 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
a * b (numpy backend)
501 ms ± 5.46 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
a * b (dask backend)
922 ms ± 93.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

This particular bit is shocking and I can't wrap my head around it?!?

bench(100, True, [], ...i,...i->...i)
xarray.dot(dask backend):
241 ms ± 8.47 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
a * b (dask backend)
517 ms ± 41.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

bench(20, True, [], ...i,...i->...i)
xarray.dot(dask backend):
268 ms ± 14.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
a * b (dask backend)
922 ms ± 93.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

@jakirkham

from what I understand da.dot implements... a limited special case of da.einsum?

Basically dot is an inner product. Certainly inner products can be formulated using Einstein notation (i.e. computed by calling einsum).

The question is whether the performance keeps up with that formulation. Currently it sounds like chunking causes some problems, IIUC. However, things like dot and tensordot dispatch through optimized BLAS routines. In theory einsum should do the same (numpy/numpy#9425), but the experimental data still shows a few warts. For example, matmul is implemented with einsum, but is slower than dot (numpy/numpy#7569, numpy/numpy#8957). Pure einsum implementations seem to perform similarly.
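
To illustrate the BLAS-vs-einsum point with plain numpy, here is a standalone, hypothetical comparison (timings will vary with the BLAS build):

import timeit
import numpy as np

x = np.random.random((2000, 2000))
y = np.random.random((2000, 2000))

# np.dot dispatches to the BLAS matrix-matrix kernel.
print("dot:            ", timeit.timeit(lambda: x.dot(y), number=3))
# Plain np.einsum runs this contraction through its own C loops...
print("einsum:         ", timeit.timeit(lambda: np.einsum('ij,jk->ik', x, y), number=3))
# ...unless optimize=True lets it route the contraction through tensordot/BLAS.
print("einsum optimize:", timeit.timeit(lambda: np.einsum('ij,jk->ik', x, y, optimize=True), number=3))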

I ran a few more benchmarks...

What are the arrays used as input for this case?

...apparently xarray.dot on a dask backend is situationally faster than all other implementations when you are not reducing on any dimensions...

Having a little trouble following this. dot reduces one dimension from each input. The exception is when one of the inputs is 0-D (i.e. a scalar); then it is just multiplying a single scalar through an array. Is that what you are referring to?


crusaderky commented Apr 23, 2018

What are the arrays used as input for this case?

See the code block in the opening post.

dot reduces one dimension from each input

To my understanding, xarray.dot(a, b, dims=[]) is functionally identical to a * b, yet it is faster in some edge cases - which I can't make any sense of.
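
To make that concrete, a quick (hypothetical) check of the equivalence, using the same array shape as in the benchmark:

import dask.array
import numpy as np
import xarray

a = xarray.DataArray(
    dask.array.random.random((500000, 100), chunks=(50000, 100)),
    dims=['s', 't'])

# With no dimensions to reduce over, dot degenerates into an elementwise product.
left = xarray.dot(a, a, dims=[]).transpose('s', 't')
right = (a * a).transpose('s', 't')
np.testing.assert_allclose(left.values, right.values)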


shoyer commented Apr 23, 2018

+1 for using dask.array.einsum in xarray.dot.

@fujiisoup

@crusaderky, thanks for the detailed benchmarking.
A further note:

  • xr.dot uses tensordot if possible, because dask did not have einsum when I implemented it.
    In the other cases, we use dask.atop with np.einsum.

In your example, bench(100, False, ['t'], '...i,...i') uses dask.tensordot, while
bench(100, True, ['t'], '...i,...i') uses np.einsum.

bench(100, True, [], ...i,...i->...i) also uses np.einsum.
But I have no idea yet why dot(a, b, dims=[]) is faster than a * b.
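
For readers unfamiliar with that code path, a rough sketch of the dask.atop + np.einsum pattern (da.atop was renamed da.blockwise in later dask versions; the subscripts and chunking here are illustrative, not xarray's actual code):

import dask.array as da
import numpy as np

x = da.random.random((500000, 100), chunks=(50000, 20))
y = da.random.random((500000, 100), chunks=(50000, 20))

# Index 'j' appears in the inputs but not in the output, so it is contracted;
# concatenate=True makes atop concatenate the chunks along 'j' before the
# per-block np.einsum call.
out = da.atop(
    lambda a, b: np.einsum('ij,ij->i', a, b),
    'i', x, 'ij', y, 'ij',
    dtype=x.dtype, concatenate=True)
out.compute()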


mrocklin commented Apr 24, 2018 via email

@crusaderky

Done the work - but we'll need to wait for dask 0.17.3 to integrate it
