
Enable dict inputs for torch.autograd.grad and torch.autograd.backward (usability for torch.func.functional_call) #126650

Open
XuehaiPan opened this issue May 19, 2024 · 1 comment
Assignees
XuehaiPan

Labels
module: autograd (Related to torch.autograd, and the autograd engine in general)
module: functorch (Pertaining to torch.func or pytorch/functorch)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

XuehaiPan (Collaborator) commented May 19, 2024

🚀 The feature, motivation and pitch

functorch is deprecated in favor of torch.func in PyTorch 2.0. The API functorch.make_functional is replaced by torch.func.functional_call.

functorch.make_functional

torch.func.functional_call() is the replacement for functorch.make_functional and functorch.make_functional_with_buffers. However, it is not a drop-in replacement.

The torch.func.functional_call() API takes dict[str, Tensor] inputs for parameters and buffers, while torch.autograd.{grad,backward} only support a tensor or a tuple of tensors (Tensor | tuple[Tensor, ...]) as inputs. Users need to convert between tuple and dict manually, which is very inconvenient.

This issue requests support for dict[str, Tensor] inputs in torch.autograd.grad and torch.autograd.backward.


Code snippets for an example case:

With fmodel, params = functorch.make_functional(model), params is tuple[nn.Parameter, ...]:

import functorch
import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(4, 2)                         # example model (placeholder for illustration)
batch = (torch.randn(8, 4), torch.randn(8, 2))  # example (inputs, labels) batch (placeholder)

fmodel: nn.Module
params: tuple[torch.Tensor, ...]
fmodel, params = functorch.make_functional(model)

def calculate_loss(fmodel, params, batch):
    inputs, labels = batch
    outputs = fmodel(params, inputs)
    return F.mse_loss(outputs, labels)

loss: torch.Tensor = calculate_loss(fmodel, params, batch)
grads: tuple[torch.Tensor, ...] = torch.autograd.grad(loss, params)

With torch.func.functional_call(model, params), params is dict[str, nn.Parameter]:

import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(4, 2)                         # example model (placeholder for illustration)
batch = (torch.randn(8, 4), torch.randn(8, 2))  # example (inputs, labels) batch (placeholder)

params: dict[str, torch.Tensor] = dict(model.named_parameters())

def calculate_loss(model, params, batch):
    inputs, labels = batch
    outputs = torch.func.functional_call(model, params, inputs)
    return F.mse_loss(outputs, labels)

loss: torch.Tensor = calculate_loss(model, params, batch)
grads: tuple[torch.Tensor, ...] = torch.autograd.grad(
    loss,
    tuple(params.values()),  # need to convert `dict` to `tuple` manually
)
grads: dict[str, torch.Tensor] = dict(zip(params.keys(), grads))  # need to convert `tuple` to `dict` manually
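
With the requested feature, both manual conversions above would disappear. A sketch of the desired usage (hypothetical; torch.autograd.grad does not currently accept dict inputs, and the dict return type is inferred from the conversion the snippet above performs by hand):

# Hypothetical: pass the parameter dict directly and receive
# gradients in a dict keyed like `params`.
grads: dict[str, torch.Tensor] = torch.autograd.grad(loss, params)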

Alternatives

No response

Additional context

No response

cc @ezyang @albanD @gqchen @pearu @nikitaved @soulitzer @lezcano @Varal7 @zou3519 @Chillee @samdow @kshitij12345 @janeyx99

@XuehaiPan added the module: functorch label May 19, 2024
@XuehaiPan self-assigned this May 19, 2024
@drisspg added the module: autograd and triaged labels May 20, 2024
zou3519 (Contributor) commented May 20, 2024

Leaving this up to @soulitzer (as maintainer of autograd), who will be back next week. It depends on whether we (PyTorch) are willing to support this (dicts, and even general pytree inputs, for torch.autograd.grad) going forward.

Let's check in on this in one week.
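
For illustration, a user-space sketch of the pytree idea, assuming the private torch.utils._pytree module (an internal API, subject to change) and a hypothetical helper name grad_pytree:

import torch
import torch.utils._pytree as pytree

def grad_pytree(outputs, inputs, **kwargs):
    # Hypothetical helper, not a PyTorch API: accept any pytree of
    # tensors (e.g. a dict[str, Tensor]) as `inputs`.
    leaves, spec = pytree.tree_flatten(inputs)   # dict -> flat list of tensors
    grads = torch.autograd.grad(outputs, leaves, **kwargs)
    return pytree.tree_unflatten(list(grads), spec)  # flat tuple -> original structure

With this helper, grad_pytree(loss, params) would return a dict[str, torch.Tensor] keyed like params, with no manual conversion at the call site.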
