Express matmul vjps in terms of matmul #366

Merged
merged 1 commit on Feb 24, 2018
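For context, the idea behind the title: the vector-Jacobian products of matmul can themselves be written as matmuls. For C = A @ B with upstream cotangent G (same shape as C), the adjoints are dA = G @ swapaxes(B, -1, -2) and dB = swapaxes(A, -1, -2) @ G when both operands have ndim >= 2; 1-D operands need the extra expand_dims/squeeze handling visible in the diff below. A minimal sanity check of the 2-D identity using autograd.grad (shapes and names here are illustrative, not taken from the PR):

# Sketch only, not part of the diff: check dA = G @ B.T and dB = A.T @ G
# by differentiating the scalar loss sum((A @ B) * G).
import autograd.numpy as np
from autograd import grad

A = np.random.randn(3, 4)
B = np.random.randn(4, 5)
G = np.random.randn(3, 5)   # arbitrary upstream cotangent, same shape as A @ B

loss = lambda A, B: np.sum(np.matmul(A, B) * G)

assert np.allclose(grad(loss, 0)(A, B), np.matmul(G, B.T))   # VJP w.r.t. A
assert np.allclose(grad(loss, 1)(A, B), np.matmul(A.T, G))   # VJP w.r.t. B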
55 changes: 42 additions & 13 deletions autograd/numpy/numpy_vjps.py
@@ -317,19 +317,48 @@ def grad_inner(argnum, ans, A, B):
         return lambda G: tensordot_adjoint_1(A, G, axes, A_ndim, B_ndim)
 defvjp(anp.inner, partial(grad_inner, 0), partial(grad_inner, 1))

-def grad_matmul(argnum, ans, A, B):
-    A_ndim, B_ndim = anp.ndim(A), anp.ndim(B)
-    if A_ndim == 0 or B_ndim == 0:
-        raise ValueError("Scalar operands are not allowed, use '*' instead")
-    elif A_ndim == 1 or B_ndim == 1 or (A_ndim == 2 and B_ndim == 2):
-        axes = ([A_ndim - 1], [max(0, B_ndim - 2)])
-        if argnum == 0:
-            return lambda G: match_complex(A, tensordot_adjoint_0(B, G, axes, A_ndim, B_ndim))
-        elif argnum == 1:
-            return lambda G: match_complex(B, tensordot_adjoint_1(A, G, axes, A_ndim, B_ndim))
-    else:
-        return grad_einsum(argnum + 1, ans, ("...ij,...jk->...ik", A, B), None)
-defvjp(anp.matmul, partial(grad_matmul, 0), partial(grad_matmul, 1))
+def matmul_adjoint_0(B, G, A_meta, B_ndim):
+    if anp.ndim(G) == 0:  # A_ndim == B_ndim == 1
+        return unbroadcast(G * B, A_meta)
+    _, A_ndim, _, _ = A_meta
+    if A_ndim == 1:
+        G = anp.expand_dims(G, anp.ndim(G) - 1)
+    if B_ndim == 1:  # The result we need is an outer product
+        B = anp.expand_dims(B, 0)
+        G = anp.expand_dims(G, anp.ndim(G))
+    else:  # We need to swap the last two axes of B
+        B = anp.swapaxes(B, B_ndim - 2, B_ndim - 1)
+    result = anp.matmul(G, B)
+    return unbroadcast(result, A_meta)
+
+def matmul_adjoint_1(A, G, A_ndim, B_meta):
+    if anp.ndim(G) == 0:  # A_ndim == B_ndim == 1
+        return unbroadcast(G * A, B_meta)
+    _, B_ndim, _, _ = B_meta
+    B_is_vec = (B_ndim == 1)
+    if B_is_vec:
+        G = anp.expand_dims(G, anp.ndim(G))
+    if A_ndim == 1:  # The result we need is an outer product
+        A = anp.expand_dims(A, 1)
+        G = anp.expand_dims(G, anp.ndim(G) - 1)
+    else:  # We need to swap the last two axes of A
+        A = anp.swapaxes(A, A_ndim - 2, A_ndim - 1)
+    result = anp.matmul(A, G)
+    if B_is_vec:
+        result = anp.squeeze(result, anp.ndim(G) - 1)
+    return unbroadcast(result, B_meta)
+
+def matmul_vjp_0(ans, A, B):
+    A_meta = anp.metadata(A)
+    B_ndim = anp.ndim(B)
+    return lambda g: matmul_adjoint_0(B, g, A_meta, B_ndim)
+
+def matmul_vjp_1(ans, A, B):
+    A_ndim = anp.ndim(A)
+    B_meta = anp.metadata(B)
+    return lambda g: matmul_adjoint_1(A, g, A_ndim, B_meta)
+
+defvjp(anp.matmul, matmul_vjp_0, matmul_vjp_1)

 @primitive
 def dot_adjoint_0(B, G, A_ndim, B_ndim):
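A usage sketch of the rewritten rule (assumed shapes, not from the PR): because the adjoints are now built from matmul itself, stacked (ndim > 2) and 1-D operands go through the same code path, and the gradients keep the operands' shapes.

# Sketch only: shape checks through autograd.grad after this change.
import autograd.numpy as np
from autograd import grad

A = np.random.randn(7, 3, 4)   # stack of matrices
B = np.random.randn(7, 4, 5)
f = lambda A, B: np.sum(np.matmul(A, B) ** 2)
assert grad(f, 0)(A, B).shape == A.shape
assert grad(f, 1)(A, B).shape == B.shape

x = np.random.randn(4)         # 1-D operand: matmul drops the contracted axis
g = lambda M, v: np.sum(np.matmul(M, v))
assert grad(g, 1)(A[0], x).shape == x.shape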