torch.roll converts nn.Parameter into regular torch.Tensor #126524

Closed
thomassajot opened this issue May 17, 2024 · 3 comments
Labels
module: nn Related to torch.nn triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@thomassajot

thomassajot commented May 17, 2024

🐛 Describe the bug

An nn.Parameter attribute cannot be assigned the output of torch.roll.

import torch

class MyModule(torch.nn.Module):
    def __init__(self, cache: torch.Tensor):
        super().__init__()
        assert cache.ndim == 3
        self.cache = torch.nn.Parameter(cache, requires_grad=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n_tokens = x.size(1)
        self.cache = torch.roll(self.cache, -n_tokens, dims=1)
        self.cache[:, -n_tokens:, :] = x
        return self.cache

MyModule(torch.zeros(2, 3, 4))(torch.zeros(2, 2, 4))
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[1], line 15
     12         self.cache[:, -n_tokens:, :] = x
     13         return self.cache
---> 15 MyModule(torch.zeros(2, 3, 4))(torch.zeros(2, 2, 4))

File ~/si_env/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/si_env/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

Cell In[1], line 11
      9 def forward(self, x: torch.Tensor) -> torch.Tensor:
     10     n_tokens = x.size(1)
---> 11     self.cache = torch.roll(self.cache, -n_tokens, dims=1)
     12     self.cache[:, -n_tokens:, :] = x
     13     return self.cache

File ~/si_env/lib/python3.10/site-packages/torch/nn/modules/module.py:1715, in Module.__setattr__(self, name, value)
   1713 elif params is not None and name in params:
   1714     if value is not None:
-> 1715         raise TypeError(f"cannot assign '{torch.typename(value)}' as parameter '{name}' "
   1716                         "(torch.nn.Parameter or None expected)"
   1717                         )
   1718     self.register_parameter(name, value)
   1719 else:

TypeError: cannot assign 'torch.FloatTensor' as parameter 'cache' (torch.nn.Parameter or None expected)

Versions

cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki

@flishwang

Parameters generally should not be changed directly in the forward pass. As a best practice, you may use self.register_buffer instead.

That said, if you really want to change the data in a parameter, use the following code:

def forward(self, x: torch.Tensor) -> torch.Tensor:
    n_tokens = x.size(1)
    with torch.no_grad():
        self.cache.data = torch.roll(self.cache, -n_tokens, dims=1)
    self.cache[:, -n_tokens:, :] = x
    return self.cache

By the way, I'm not sure what you intend with MyModule(torch.zeros(2, 3, 4))(torch.zeros(2, 2, 4)); generally, the returned tensor/parameter cannot be called.
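
For reference, a minimal sketch of the register_buffer variant suggested above (illustrative only; the buffer name cache and the rolling logic simply mirror the original repro):

import torch

class MyModule(torch.nn.Module):
    def __init__(self, cache: torch.Tensor):
        super().__init__()
        assert cache.ndim == 3
        # A buffer is saved in the state_dict but is not a Parameter,
        # so Module.__setattr__ allows reassigning it with a plain Tensor.
        self.register_buffer("cache", cache)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n_tokens = x.size(1)
        # Reassigning a registered buffer with a Tensor is permitted,
        # unlike reassigning a Parameter.
        self.cache = torch.roll(self.cache, -n_tokens, dims=1)
        self.cache[:, -n_tokens:, :] = x
        return self.cache

MyModule(torch.zeros(2, 3, 4))(torch.zeros(2, 2, 4))  # runs without TypeError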

@mikaylagawarecki
Contributor

This seems to be the expected behavior: an out-of-place op on a Parameter (e.g. torch.roll) returns a regular Tensor, and attempting to reassign that Tensor to self.cache fails with the expected error message.
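
A standalone sketch illustrating that point (assumed minimal repro, not taken from the issue):

import torch

p = torch.nn.Parameter(torch.zeros(2, 3, 4), requires_grad=False)
out = torch.roll(p, -1, dims=1)

print(type(p))    # <class 'torch.nn.parameter.Parameter'>
print(type(out))  # <class 'torch.Tensor'>: out-of-place ops return plain Tensors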

@mikaylagawarecki mikaylagawarecki added module: nn Related to torch.nn triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels May 20, 2024
@thomassajot
Author

Thank you for your reply. As you mentioned, register_buffer might be the correct way to go.
The goal of the cache is to append historical elements alongside the current tensor.
