support for Jax-like custom forward pass definition? #585

Open
tylerflex opened this issue Oct 24, 2022 · 1 comment

@tylerflex

Is there a way to define a custom forward pass, like in JAX, where one can output a residual that may be used by the backward pass?

For example, is the following snippet (from the JAX docs) implementable in autograd?

from jax import custom_vjp
import jax.numpy as jnp

@custom_vjp
def f(x, y):
  return jnp.sin(x) * y

def f_fwd(x, y):
  # Returns primal output and residuals to be used in backward pass by f_bwd.
  return f(x, y), (jnp.cos(x), jnp.sin(x), y)

def f_bwd(res, g):
  cos_x, sin_x, y = res # Gets residuals computed in f_fwd
  return (cos_x * g * y, sin_x * g)

f.defvjp(f_fwd, f_bwd)
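
For comparison, a minimal sketch of the same function using autograd's extend API (assuming the primitive/defvjp interface from autograd.extend behaves as in autograd's tutorial); here the primal output ans is passed to each VJP maker and plays the role of JAX's residuals:

import autograd.numpy as np
from autograd import grad
from autograd.extend import primitive, defvjp

@primitive
def f(x, y):
    return np.sin(x) * y

# One VJP maker per positional argument; `ans` is the primal output and the
# original inputs are available as residuals for the backward pass.
defvjp(f,
       lambda ans, x, y: lambda g: g * np.cos(x) * y,  # gradient w.r.t. x
       lambda ans, x, y: lambda g: g * np.sin(x))      # gradient w.r.t. y

print(grad(f, 0)(1.0, 2.0))  # ~1.0806
print(grad(f, 1)(1.0, 2.0))  # ~0.8415
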
@pat749 commented Mar 26, 2023

In PyTorch, you can define a custom forward pass by subclassing torch.autograd.Function. This lets you specify both the forward pass and the backward pass (i.e. the gradient computation) of your custom function.

For example, you could implement the Jax f function as follows in PyTorch:

import torch

class f(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
        # Save the inputs needed by the backward pass (PyTorch's analogue of JAX residuals).
        ctx.save_for_backward(x, y)
        return torch.sin(x) * y

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        cos_x = torch.cos(x)
        sin_x = torch.sin(x)
        # Chain rule: d/dx [sin(x) * y] = cos(x) * y and d/dy [sin(x) * y] = sin(x).
        grad_x = grad_output * cos_x * y
        grad_y = grad_output * sin_x
        return grad_x, grad_y

x = torch.tensor(1.0, requires_grad=True)
y = torch.tensor(2.0, requires_grad=True)

output = f.apply(x, y)
output.backward()

print(x.grad)  # tensor(1.0806), i.e. cos(1.0) * 2.0
print(y.grad)  # tensor(0.8415), i.e. sin(1.0)

Here, ctx.save_for_backward saves x and y for use in the backward pass. The backward method then computes the gradients with respect to x and y from the saved tensors via the chain rule. Note that a custom autograd.Function is invoked through its apply method, as in f.apply(x, y), rather than called directly.
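
As a sanity check, the hand-written backward can be compared against numerical gradients with torch.autograd.gradcheck; a short sketch, assuming the class f defined above (gradcheck expects double-precision inputs):

import torch

# gradcheck compares the analytic backward against finite differences.
x = torch.tensor(1.0, dtype=torch.double, requires_grad=True)
y = torch.tensor(2.0, dtype=torch.double, requires_grad=True)
print(torch.autograd.gradcheck(f.apply, (x, y)))  # prints True if the gradients agree
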
