
Express matmul vjps in terms of matmul #366

Merged
merged 1 commit into from Feb 24, 2018

Conversation

j-towns
Collaborator

@j-towns j-towns commented Feb 23, 2018

No description provided.

@duvenaud
Contributor

Nice, is the motivation mainly to get rid of the call to einsum?

@j-towns
Collaborator Author

j-towns commented Feb 24, 2018

> Nice, is the motivation mainly to get rid of the call to einsum?

Yes. There should be a performance advantage to this version, particularly for small arrays, by avoiding all the parsing that einsum and its vjp have to do, though I haven't profiled it yet.
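The idea can be sketched as follows (a minimal illustration, not autograd's exact code — the function names here are hypothetical): for C = A @ B, both vector-Jacobian products are themselves matmuls, so no einsum string needs to be parsed.

```python
import numpy as np

def matmul_vjp_A(g, A, B):
    # dL/dA = g @ B^T; swapaxes (rather than .T) keeps any leading
    # "stacked" dimensions in place.
    return g @ np.swapaxes(B, -1, -2)

def matmul_vjp_B(g, A, B):
    # dL/dB = A^T @ g
    return np.swapaxes(A, -1, -2) @ g
```

For 2-D operands these agree with the einsum-based VJPs they replace, e.g. `np.einsum('ij,kj->ik', g, B)` for the first one.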

@mattjj mattjj merged commit 13c8266 into HIPS:master Feb 24, 2018
@mattjj
Copy link
Contributor

mattjj commented Feb 24, 2018

Thanks, Jamie!

I thought matmul might also be faster than einsum because it could more readily call BLAS routines, but after poking around it looks like matmul's implementation kinda sucks and just calls into einsum anyway for the stacked case. I think that's an upstream numpy issue, though, and I would guess this implementation is at least as fast as the previous one.
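The equivalence in the stacked case — the one numpy's matmul handled by dispatching to einsum internally at the time — is easy to state: matmul broadcasts over leading dimensions exactly like the corresponding einsum call. A quick check:

```python
import numpy as np

# Stacked (batched) matrix multiply: matmul broadcasts over the
# leading dimension, matching the equivalent einsum expression.
A = np.random.randn(10, 3, 4)
B = np.random.randn(10, 4, 5)

via_matmul = A @ B                                    # shape (10, 3, 5)
via_einsum = np.einsum('...ij,...jk->...ik', A, B)
assert np.allclose(via_matmul, via_einsum)
```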

@j-towns
Copy link
Collaborator Author

j-towns commented Feb 24, 2018

Oh that's a shame :(. Even dot sucks in some cases because BLAS doesn't (yet) support arbitrarily strided arrays. There was a good description of how numpy deals with different strides posted here:

> how do you perform a GEMM operation on two arrays with arbitrary strides? Currently, NumPy attempts to detect a number of special cases: if the strides in both arrays imply a column-major layout, then call BLAS directly; if one of them has strides corresponding to a row-major layout, then set the corresponding transA/transB argument, etc. – and if all else fails, either copy the data into a contiguous buffer, or else fall back on a naive triple-nested-loop GEMM implementation. (There's also a check where if we can determine through examining the arrays' data pointers and strides that they're actually transposes of each other, then we instead dispatch to SYRK.)
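The special cases described above can be observed directly through numpy's layout flags (a small sketch using only numpy's public API):

```python
import numpy as np

A = np.ones((128, 64))
assert A.flags['C_CONTIGUOUS']     # row-major: BLAS GEMM with transA/transB set
assert A.T.flags['F_CONTIGUOUS']   # a transpose view is column-major,
                                   # so BLAS can be called on it directly

# Arbitrary strides (neither layout): this is the "all else fails" case,
# where numpy must copy into a contiguous buffer or fall back to a
# slower GEMM path.
B = A[::2, ::2]
assert not B.flags['C_CONTIGUOUS'] and not B.flags['F_CONTIGUOUS']
```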
