This repository has been archived by the owner on Apr 18, 2023. It is now read-only.

Custom sensitivities for strided matmul are never hit, and I think they are wrong #192

Open
oxinabox opened this issue Oct 2, 2020 · 1 comment

Comments

@oxinabox
Member

oxinabox commented Oct 2, 2020

This file:
https://github.com/invenia/Nabla.jl/blob/4cadc87677fb1187354999dcf93bd528f45f85d0/src/sensitivities/linalg/strided.jl

It says:

const RS = StridedMatrix{<:∇Scalar}
const RST = Transpose{<:∇Scalar, RS}
const RSA = Adjoint{<:∇Scalar, RS}

But it should say:

const RS = StridedMatrix{<:∇Scalar}
const RST = Transpose{<:∇Scalar, <:RS}
const RSA = Adjoint{<:∇Scalar, <:RS}

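Otherwise, because Julia's type parameters are invariant, RST only matches a Transpose whose parent type parameter is exactly the abstract union StridedMatrix{<:∇Scalar}, i.e. Transpose{<:∇Scalar, Union{DenseArray, ...}} (and similarly for RSA).
That will never occur in real code unless someone constructs such a Transpose manually.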
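A minimal sketch of the dispatch problem, assuming Real as a stand-in for Nabla's ∇Scalar. The names RST_old, RST_new, rule_old and rule_new are hypothetical and used only for illustration: the alias without <: never matches an ordinary transpose of a Matrix, so a method written against it silently falls through to a generic fallback.

```julia
using LinearAlgebra

# Assumption: Real stands in for Nabla's ∇Scalar in this sketch.
const Scalar = Real

const RS = StridedMatrix{<:Scalar}
const RST_old = Transpose{<:Scalar, RS}    # parent type parameter must be exactly RS
const RST_new = Transpose{<:Scalar, <:RS}  # parent type parameter may be any subtype of RS

# Hypothetical stand-ins for a custom strided rule plus a generic fallback.
rule_old(::RST_old) = :strided_rule
rule_old(::AbstractMatrix) = :generic_fallback
rule_new(::RST_new) = :strided_rule
rule_new(::AbstractMatrix) = :generic_fallback

A = rand(3, 3)
At = transpose(A)  # Transpose{Float64, Matrix{Float64}}

println(At isa RST_old)  # false: Matrix{Float64} is not the same type as RS
println(At isa RST_new)  # true: Matrix{Float64} <: RS
println(rule_old(At))    # generic_fallback
println(rule_new(At))    # strided_rule
```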

So I think the only strided rule that is actually hit is:
(RS, RS, 'N', 'C', :Ȳ, :B, 'C', 'N', :A, :Ȳ)
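For reference, a sketch of what that one live rule appears to compute for Y = A * B, assuming the tuple reads as (type of A, type of B, then the GEMM flags and arguments for Ā, then those for B̄). This reading is an assumption, checked here only against plain matrix products; Ybar, Abar and Bbar stand in for Ȳ, Ā and B̄.

```julia
using LinearAlgebra

# Forward pass Y = A * B (sizes chosen arbitrarily for the check).
A = rand(4, 3)
B = rand(3, 5)
Ybar = rand(4, 5)  # incoming sensitivity Ȳ, same size as Y

# Assumed reading of the tuple: gemm('N', 'C', Ȳ, B) for Ā and gemm('C', 'N', A, Ȳ) for B̄.
Abar = BLAS.gemm('N', 'C', Ybar, B)  # Ȳ * B', same size as A
Bbar = BLAS.gemm('C', 'N', A, Ybar)  # A' * Ȳ, same size as B

println(Abar ≈ Ybar * B')  # true
println(Bbar ≈ A' * Ybar)  # true
```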

I also think the other rules are wrong: when I change the aliases to the corrected form, I get errors saying that GEMM is being called incorrectly.
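One possible way for the wrapper rules to call GEMM correctly, offered only as a hedged sketch and not as what Nabla does: unwrap the Transpose or Adjoint with parent and pass the matching 'T' or 'C' flag, rather than handing the wrapper itself to BLAS. The helpers blas_arg and gemm_wrapped are hypothetical.

```julia
using LinearAlgebra

# Hypothetical helper: split a matrix into a plain strided array plus the BLAS
# transpose flag that reproduces the wrapper's meaning.
blas_arg(X::StridedMatrix) = (X, 'N')
blas_arg(X::Transpose{<:Any, <:StridedMatrix}) = (parent(X), 'T')
blas_arg(X::Adjoint{<:Any, <:StridedMatrix}) = (parent(X), 'C')

# Hypothetical wrapper-aware product: BLAS only ever sees unwrapped arrays.
function gemm_wrapped(X, Y)
    Xp, tX = blas_arg(X)
    Yp, tY = blas_arg(Y)
    return BLAS.gemm(tX, tY, Xp, Yp)
end

A = rand(4, 3)
B = rand(4, 2)
C = gemm_wrapped(transpose(A), B)  # computes transpose(A) * B, a 3×2 matrix

println(C ≈ transpose(A) * B)  # true
```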

@oxinabox
Member Author

oxinabox commented Jul 2, 2021

This will be closed by #189.
