Add torch.nn.Dropout recomputation support during the backward pass to Thunder #114
Labels: enhancement (New feature or request), help wanted (Extra attention is needed)

Comments

kevinstephano added the enhancement and help wanted labels on Apr 2, 2024
fyi @ptrblck
Here's a simple test case:

import torch
import thunder

def func(a):
    t1 = torch.nn.functional.dropout(a, p=0.5)
    return t1 @ t1

a = torch.randn(2, 2, device="cuda", requires_grad=True)
jfunc = thunder.jit(func)
out = jfunc(a)

The forward trace shows that the dropout mask is saved for backward (print(thunder.last_traces(jfunc)[-1])):
def augmented_forward_fn(a):
  # a: "cuda:0 f32[2, 2]"
  [t1, t4] = nvFusion0(a)
    # t0 = prims.uniform((2, 2), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32) # t0: "cuda:0 f32[2, 2]"
    # t1 = prims.lt(t0, 0.5) # t1: "cuda:0 b8[2, 2]"
    # t2 = prims.convert_element_type(t1, dtypes.float32) # t2: "cuda:0 f32[2, 2]"
    # t3 = prims.mul(a, t2) # t3: "cuda:0 f32[2, 2]"
    # t4 = prims.mul(t3, 2.0) # t4: "cuda:0 f32[2, 2]"
  t5 = torch.matmul(t4, t4) # t5: "cuda:0 f32[2, 2]"
    # t5 = ltorch.matmul(t4, t4) # t5: "cuda:0 f32[2, 2]"
    # t5 = prims.matmul(t4, t4) # t5: "cuda:0 f32[2, 2]"
  return {'output': t5, 'flat_args': [a], 'flat_output': (t5,)}, ((t1, t4), (2.0,))

Backward trace (print(thunder.last_backward_traces(jfunc)[-1])):
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, C1, = saved_for_backward
  clear_collection(saved_for_backward)
  del saved_for_backward
  t6, = cotangents
  clear_collection(cotangents)
  del cotangents
  t1, t4, = C0
  clear_collection(C0)
  del C0
  f7, = C1
  clear_collection(C1)
  del C1
  t35 = torch.permute(t4, (1, 0)) # t35: "cuda:0 f32[2, 2]"
    # t35 = ltorch.permute(t4, (1, 0)) # t35: "cuda:0 f32[2, 2]"
    # t35 = prims.transpose(t4, (1, 0)) # t35: "cuda:0 f32[2, 2]"
  del t4
  t36 = torch.matmul(t35, t6) # t36: "cuda:0 f32[2, 2]"
    # t36 = ltorch.matmul(t35, t6) # t36: "cuda:0 f32[2, 2]"
    # t36 = prims.matmul(t35, t6) # t36: "cuda:0 f32[2, 2]"
  t34 = torch.matmul(t6, t35) # t34: "cuda:0 f32[2, 2]"
    # t34 = ltorch.matmul(t6, t33) # t34: "cuda:0 f32[2, 2]"
    # t34 = prims.matmul(t6, t33) # t34: "cuda:0 f32[2, 2]"
  del t6, t35
  [t39] = nvFusion0(f7, t1, t34, t36)
    # t2 = prims.convert_element_type(t1, dtypes.float32) # t2: "cuda:0 f32[2, 2]"
    # t37 = prims.add(t34, t36) # t37: "cuda:0 f32[2, 2]"
    # t38 = prims.mul(f7, t37) # t38: "cuda:0 f32[2, 2]"
    # t39 = prims.mul(t2, t38) # t39: "cuda:0 f32[2, 2]"
  del f7, t1, t34, t36
  return (t39,)

We should implement a trace transformation that replaces the uniform call with uniform_philox, so that only the RNG seed and offset need to be saved for the backward pass instead of the mask.
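As a rough illustration of the idea (plain PyTorch, not Thunder's trace machinery; RecomputedDropout and its RNG-state handling are hypothetical), a dropout whose backward regenerates the mask from a small saved RNG state instead of storing the mask could look like this:

import torch

class RecomputedDropout(torch.autograd.Function):
    """Dropout that saves only a small RNG state and rebuilds the mask in backward."""

    @staticmethod
    def forward(ctx, x, p):
        gen = torch.Generator(device=x.device)
        gen.seed()                               # pick a fresh seed for this call
        ctx.rng_state = gen.get_state()          # tiny state tensor instead of the full mask
        ctx.p = p
        mask = torch.rand(x.shape, generator=gen, device=x.device) >= p
        return x * mask / (1.0 - p)

    @staticmethod
    def backward(ctx, grad_out):
        gen = torch.Generator(device=grad_out.device)
        gen.set_state(ctx.rng_state)             # rewind the generator to the saved state
        mask = torch.rand(grad_out.shape, generator=gen, device=grad_out.device) >= ctx.p
        return grad_out * mask / (1.0 - ctx.p), None

Calling RecomputedDropout.apply(a, 0.5) then behaves like functional dropout while keeping only the generator state alive between forward and backward.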
Triage review:
kiya00 added commits that referenced this issue between Apr 16 and May 29, 2024, including "…_philox) and RNG state query/updating (#114)".
🚀 Feature
I would like to have Thunder save the seed and offset from random number generation to allow for the recomputation of Dropout in the backward pass.
There are two pieces needed to make it work:
1. thunder.prims.uniform_philox, a stateless counterpart of uniform that takes the RNG seed and offset explicitly.
2. A trace transformation that finds the uniform call, replacing uniform with uniform_philox, and incrementing the PRNG state properly. This is not implemented.
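To see the property this relies on, here is a minimal plain-PyTorch sketch (not Thunder's API): regenerating from the same saved RNG state yields bit-identical random numbers, so the dropout mask itself never has to be stored.

import torch

gen = torch.Generator()
gen.manual_seed(1234)
state = gen.get_state()            # stands in for the Philox (seed, offset) pair

first = torch.rand(2, 2, generator=gen)

gen.set_state(state)               # rewind to the saved state
second = torch.rand(2, 2, generator=gen)

assert torch.equal(first, second)  # same numbers, hence the same dropout mask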
Motivation
Multihead Attention modules in LLMs often apply dropout to the attention matrix, so the memory used by the saved mask grows with the square of the sequence length.
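For a sense of scale (illustrative shapes I chose, not numbers from this issue), a boolean mask over the attention matrix has shape [batch, heads, seq, seq]:

# back-of-the-envelope calculation with assumed shapes
batch, heads, seq = 8, 32, 4096
mask_bytes = batch * heads * seq * seq      # one byte per bool element
print(mask_bytes / 2**30, "GiB")            # -> 4.0 GiB for a single dropout mask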
cc @apaz-cli