S4 Module incompatible with PyTorch 2.0's torch.compile
#91
Thanks a lot for the report! Because of potential issues like these, I'm waiting until PyTorch 2.x is more stable before upgrading the entire repo. If the underlying issue is complex numbers with torch.compile (can you reproduce the error with a minimal example?), there's not much I can do, and it would be great to file an issue directly with PyTorch.
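A minimal repro along these lines might isolate the complex-number path from the full S4 layer. This is only a sketch under assumptions: `ComplexKernel` and `try_compile` are hypothetical stand-ins, not code from this repo, and whether the compiled call fails depends on the installed PyTorch version and the chosen backend.

```python
import torch
import torch.nn as nn


class ComplexKernel(nn.Module):
    """Hypothetical toy module: builds a complex-valued kernel and returns its real part."""

    def __init__(self, n: int = 8):
        super().__init__()
        # Real parameters holding the real and imaginary parts of a diagonal state matrix A.
        self.a_re = nn.Parameter(-0.5 * torch.ones(n))
        self.a_im = nn.Parameter(torch.arange(n, dtype=torch.float32))

    def forward(self, dt: torch.Tensor) -> torch.Tensor:
        A = torch.complex(self.a_re, self.a_im)  # complex64, shape (n,)
        dtA = dt.unsqueeze(-1) * A               # broadcast to (L, n), complex64
        return torch.exp(dtA).real.sum(-1)       # back to a real tensor, shape (L,)


def try_compile(model: nn.Module, x: torch.Tensor, backend: str = "inductor"):
    """Compile `model` with `backend`, run it once, and return (ok, error summary)."""
    try:
        compiled = torch.compile(model, backend=backend)
        compiled(x)  # compilation happens lazily, on the first call
        return True, ""
    except Exception as exc:
        return False, f"{type(exc).__name__}: {exc}"


if __name__ == "__main__":
    model = ComplexKernel()
    dt = torch.linspace(0.0, 1.0, 16)
    print("eager output shape:", tuple(model(dt).shape))
    ok, err = try_compile(model, dt)
    print("compiled OK" if ok else f"compile failed -> {err}")
```

If the eager call succeeds and the compiled call fails, that points at the compiler's handling of complex tensors rather than at the S4 code itself.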
Thanks for the reply.
It should be fairly easy to reproduce: just run a simple forward pass on the standalone sashimi model after compiling it. I'll try to provide a repro within the next 24 hours if I get to it.
Right, although it would help to see whether it also fails with a model more minimal than the S4 layer, if the line you pointed out is indeed the problem. If you don't get to it, I'll keep this in mind when I get around to upgrading the library versions.
Ah, I see. I'll see what I can do.
@albertfgu Here you go: https://gist.github.com/RoiEXLab/5cc1630aca71b603528a574b2a2e3326 It turns out the
So the issue does indeed seem to be with complex numbers. There are a lot of open issues about complex numbers in the PyTorch repo, but I'm not quite sure how well they apply to this exact issue. I also tried using different backends for the compilation (see
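Comparing backends can help narrow down which compiler stage rejects the complex ops. The sketch below assumes the backend names "eager", "aot_eager", and "inductor" that PyTorch 2.x registers; `complex_step` is a hypothetical stand-in for the failing S4 code, and the per-backend results will vary between PyTorch versions.

```python
import torch
import torch._dynamo


def complex_step(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for a complex-valued S4 computation."""
    z = torch.polar(torch.ones_like(x), x)  # unit-magnitude complex numbers e^{i x}
    return (z * z.conj()).real + z.imag     # |z|^2 + Im(z) = 1 + sin(x)


def sweep_backends(fn, x, backends=("eager", "aot_eager", "inductor")):
    """Return {backend: "ok" or exception class name} after one compiled call per backend."""
    results = {}
    for name in backends:
        torch._dynamo.reset()  # drop cached graphs so each backend compiles from scratch
        try:
            torch.compile(fn, backend=name)(x)
            results[name] = "ok"
        except Exception as exc:
            results[name] = type(exc).__name__
    return results
```

Roughly, "eager" only exercises graph capture, "aot_eager" adds the AOTAutograd tracing step, and "inductor" adds code generation (Triton on GPU), so the first backend that fails hints at where the complex-dtype support is missing.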
Same issue here, I import the
Unfortunately, this is missing functionality on PyTorch's end (stemming in turn from a lack of support in Triton): pytorch/pytorch#98161. The PyTorch team is aware of this and may support it eventually, but it's unclear how long that would take. I don't think that the core state space kernels (
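Until complex dtypes are supported, one possible workaround (not something confirmed in this thread or the repo) is to carry real and imaginary parts as separate real tensors, so the compiler only ever sees real-valued ops. The sketch assumes a simplified diagonal-SSM kernel K[l] = Re(sum_n C_n exp(l dt A_n)); the function names and shapes are hypothetical. With A = a + ib and C = c + id, the real part expands to e^{l dt a} (c cos(l dt b) - d sin(l dt b)).

```python
import torch


def vandermonde_kernel_real(a_re, a_im, c_re, c_im, dt: float, L: int):
    """K[l] = Re(sum_n C_n * exp(l * dt * A_n)) computed with real tensors only.

    A_n = a_re + i*a_im and C_n = c_re + i*c_im are passed as real tensors of shape (N,).
    """
    l = torch.arange(L, dtype=a_re.dtype, device=a_re.device).unsqueeze(-1)  # (L, 1)
    mag = torch.exp(l * dt * a_re)  # e^{l dt Re(A)}, shape (L, N)
    phase = l * dt * a_im           # l dt Im(A), shape (L, N)
    # Re[(c + i d)(cos + i sin)] = c cos - d sin, summed over the state dimension N
    return (mag * (c_re * torch.cos(phase) - c_im * torch.sin(phase))).sum(-1)  # (L,)


def vandermonde_kernel_complex(A, C, dt: float, L: int):
    """Reference version using complex tensors (the form that trips up torch.compile)."""
    l = torch.arange(L, dtype=A.real.dtype, device=A.device).unsqueeze(-1)  # (L, 1)
    return (C * torch.exp(l * dt * A)).sum(-1).real  # (L,)
```

The two functions should agree numerically; only the real-valued one avoids complex dtypes entirely. Whether this pattern extends cleanly to the full S4 kernels has not been checked here.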
Hi,
I've been using the sashimi model on my own dataset with reasonable success for a while now, and I wanted to see whether I could use the recently released torch.compile function on the sashimi model to speed up training for my experiments. Unfortunately, it doesn't work. The following line seems to fail (for reasons I don't understand): https://github.com/HazyResearch/state-spaces/blob/06dbbdfd0876501a7f12bf3262121badbc7658af/src/models/s4/s4.py#L703
The PyTorch website has some information on how to deal with such issues, so I hope the code can be extended in the future so that it runs noticeably faster.
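One technique along those lines is to exclude the problematic code from compilation while compiling everything around it. This sketch assumes `torch._dynamo.disable` (present in PyTorch 2.0; newer releases also expose `torch.compiler.disable`); `HybridBlock` and `complex_kernel` are hypothetical names, not parts of the sashimi model.

```python
import torch
import torch._dynamo
import torch.nn as nn


@torch._dynamo.disable
def complex_kernel(dt: torch.Tensor, a_re: torch.Tensor, a_im: torch.Tensor) -> torch.Tensor:
    """Complex-valued part, never traced: dynamo inserts a graph break and runs it eagerly."""
    A = torch.complex(a_re, a_im)                      # (n,) complex64
    return torch.exp(dt.unsqueeze(-1) * A).real.sum(-1)  # (d,) real


class HybridBlock(nn.Module):
    """Hypothetical block: real-valued layers (compilable) around an eager complex kernel."""

    def __init__(self, d: int = 16, n: int = 8):
        super().__init__()
        self.proj_in = nn.Linear(d, d)
        self.proj_out = nn.Linear(d, d)
        # Parameters stay real; the complex tensor only exists inside complex_kernel.
        self.a_re = nn.Parameter(-0.5 * torch.ones(n))
        self.a_im = nn.Parameter(torch.arange(n, dtype=torch.float32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (batch, d)
        h = torch.relu(self.proj_in(x))
        gate = complex_kernel(h.mean(0).abs(), self.a_re, self.a_im)  # (d,)
        return self.proj_out(h * gate)                              # (batch, d)
```

Usage would be `model = torch.compile(HybridBlock())`; each call then runs the compiled projections around an eager call to `complex_kernel`. The trade-off is a graph break per call, so any speedup depends on how much of the model lies outside the disabled region.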
Thanks in advance.