Skip to content

Commit

Permalink
Merge pull request #366 from j-towns/close-matmul
Browse files Browse the repository at this point in the history
Express matmul vjps in terms of matmul
  • Loading branch information
mattjj committed Feb 24, 2018
2 parents 10c7cc6 + e511c18 commit 13c8266
Showing 1 changed file with 42 additions and 13 deletions.
55 changes: 42 additions & 13 deletions autograd/numpy/numpy_vjps.py
Expand Up @@ -317,19 +317,48 @@ def grad_inner(argnum, ans, A, B):
return lambda G: tensordot_adjoint_1(A, G, axes, A_ndim, B_ndim)
# Register the VJPs of anp.inner for both of its arguments.
defvjp(anp.inner, partial(grad_inner, 0), partial(grad_inner, 1))

def grad_matmul(argnum, ans, A, B):
    """VJP maker for anp.matmul w.r.t. argument `argnum` (0 for A, 1 for B)."""
    A_ndim = anp.ndim(A)
    B_ndim = anp.ndim(B)
    if A_ndim == 0 or B_ndim == 0:
        raise ValueError("Scalar operands are not allowed, use '*' instead")
    simple_case = A_ndim == 1 or B_ndim == 1 or (A_ndim == 2 and B_ndim == 2)
    if not simple_case:
        # Stacked/broadcast matmul: fall back to the einsum gradient.
        return grad_einsum(argnum + 1, ans, ("...ij,...jk->...ik", A, B), None)
    # Vector/matrix cases reduce to a tensordot over the contracted axes.
    axes = ([A_ndim - 1], [max(0, B_ndim - 2)])
    if argnum == 0:
        return lambda G: match_complex(A, tensordot_adjoint_0(B, G, axes, A_ndim, B_ndim))
    if argnum == 1:
        return lambda G: match_complex(B, tensordot_adjoint_1(A, G, axes, A_ndim, B_ndim))
defvjp(anp.matmul, partial(grad_matmul, 0), partial(grad_matmul, 1))
def matmul_adjoint_0(B, G, A_meta, B_ndim):
    """Backward pass of anp.matmul(A, B) w.r.t. A, given output gradient G."""
    if anp.ndim(G) == 0:
        # Scalar gradient means both operands were 1-D vectors.
        return unbroadcast(G * B, A_meta)
    _, A_ndim, _, _ = A_meta
    if A_ndim == 1:
        # A was a vector: matmul dropped its row axis, restore it on G.
        G = anp.expand_dims(G, anp.ndim(G) - 1)
    if B_ndim == 1:
        # Vector B: the adjoint is an outer product of G with B.
        B, G = anp.expand_dims(B, 0), anp.expand_dims(G, anp.ndim(G))
    else:
        # Matrix (or stacked-matrix) B: contract against its transpose.
        B = anp.swapaxes(B, B_ndim - 2, B_ndim - 1)
    return unbroadcast(anp.matmul(G, B), A_meta)

def matmul_adjoint_1(A, G, A_ndim, B_meta):
    """Backward pass of anp.matmul(A, B) w.r.t. B, given output gradient G."""
    if anp.ndim(G) == 0:
        # Scalar gradient means both operands were 1-D vectors.
        return unbroadcast(G * A, B_meta)
    _, B_ndim, _, _ = B_meta
    B_was_vector = B_ndim == 1
    if B_was_vector:
        # Re-attach the column axis matmul stripped from the result.
        G = anp.expand_dims(G, anp.ndim(G))
    if A_ndim == 1:
        # Outer-product case: lift A to a column and G to a row.
        A = anp.expand_dims(A, 1)
        G = anp.expand_dims(G, anp.ndim(G) - 1)
    else:
        # Matrix (or stacked-matrix) A: contract against its transpose.
        A = anp.swapaxes(A, A_ndim - 2, A_ndim - 1)
    out = anp.matmul(A, G)
    if B_was_vector:
        # Drop the column axis added above so the shape matches B.
        out = anp.squeeze(out, anp.ndim(G) - 1)
    return unbroadcast(out, B_meta)

def matmul_vjp_0(ans, A, B):
    """Return the VJP of anp.matmul w.r.t. its first argument A."""
    # Capture only the metadata of A (not A itself) plus B for the closure.
    A_meta = anp.metadata(A)
    B_ndim = anp.ndim(B)
    def vjp(g):
        return matmul_adjoint_0(B, g, A_meta, B_ndim)
    return vjp

def matmul_vjp_1(ans, A, B):
    """Return the VJP of anp.matmul w.r.t. its second argument B."""
    # Capture A for the product and only the metadata of B for unbroadcasting.
    A_ndim = anp.ndim(A)
    B_meta = anp.metadata(B)
    def vjp(g):
        return matmul_adjoint_1(A, g, A_ndim, B_meta)
    return vjp

# Register the matmul VJPs for both arguments.
defvjp(anp.matmul, matmul_vjp_0, matmul_vjp_1)

@primitive
def dot_adjoint_0(B, G, A_ndim, B_ndim):
Expand Down

0 comments on commit 13c8266

Please sign in to comment.