
Add torch.nn.Dropout recomputation support during the backward pass to Thunder #114

Open
kevinstephano opened this issue Apr 2, 2024 · 3 comments · Fixed by #237
Labels: enhancement (New feature or request), high priority

Comments

kevinstephano (Collaborator) commented Apr 2, 2024

🚀 Feature

I would like to have Thunder save the seed and offset from random number generation to allow for the recomputation of Dropout in the backward pass.

There are two pieces needed to make it work:

  • Support a stateless (deterministic) PRNG. This is done with thunder.prims.uniform_philox.
  • A trace transform that queries PyTorch's PRNG state before each uniform call, replaces uniform with uniform_philox, and increments the PRNG state properly (see the sketch after this list). This is not implemented.
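
For context, here is a minimal sketch of the save-state/recompute idea in plain PyTorch (not Thunder code): the forward pass stores only the generator state it drew from, and the backward pass replays the same draw to rebuild the mask instead of keeping it alive. The helper names, the explicit generator argument, and the keep-mask convention are illustrative assumptions.

import torch

def dropout_fwd(a, p, gen):
    # Capture the generator state *before* drawing so the mask can be replayed later.
    state = gen.get_state()
    u = torch.rand(a.shape, device=a.device, generator=gen)
    keep = (u < (1.0 - p)).to(a.dtype)
    return a * keep * (1.0 / (1.0 - p)), state  # save `state` (tiny) instead of `keep`

def dropout_bwd(grad, p, state, device):
    # Rebuild the identical mask from the saved state; nothing mask-sized was stored.
    gen = torch.Generator(device=device)
    gen.set_state(state)
    u = torch.rand(grad.shape, device=device, generator=gen)
    keep = (u < (1.0 - p)).to(grad.dtype)
    return grad * keep * (1.0 / (1.0 - p))

gen = torch.Generator().manual_seed(0)
x = torch.randn(4, 4)
y, state = dropout_fwd(x, 0.5, gen)
gx = dropout_bwd(torch.ones_like(y), 0.5, state, "cpu")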

Motivation

Multi-head attention modules in LLMs often apply dropout to the attention matrix, so the memory needed to save the dropout mask grows with the square of the sequence length (an illustrative calculation follows).
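
As a rough, illustrative calculation (numbers chosen here, not taken from the issue), the boolean dropout mask over one attention matrix occupies batch × heads × seq × seq bytes when it is saved for the backward pass:

# Illustrative only: memory for one attention-dropout mask kept for backward.
batch, heads, seq = 8, 16, 4096
mask_bytes = batch * heads * seq * seq   # bool mask: 1 byte per element
print(f"{mask_bytes / 2**30:.1f} GiB")   # 2.0 GiB per attention layer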

cc @apaz-cli

kevinstephano added the enhancement (New feature or request) and help wanted (Extra attention is needed) labels on Apr 2, 2024
IvanYashchuk self-assigned this on Apr 2, 2024
mruberry (Collaborator) commented Apr 3, 2024

fyi @ptrblck

IvanYashchuk (Collaborator) commented:

Here's a simple test case:

import torch
import thunder

def func(a):
    t1 = torch.nn.functional.dropout(a, p=0.5)
    return t1 @ t1

a = torch.randn(2, 2, device="cuda", requires_grad=True)

jfunc = thunder.jit(func)
out = jfunc(a)

Forward trace shows that the dropout mask is saved for backward (t1):

print(thunder.last_traces(jfunc)[-1])

def augmented_forward_fn(a):
  # a: "cuda:0 f32[2, 2]"
  [t1, t4] = nvFusion0(a)
    # t0 = prims.uniform((2, 2), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t0: "cuda:0 f32[2, 2]"
    # t1 = prims.lt(t0, 0.5)  # t1: "cuda:0 b8[2, 2]"
    # t2 = prims.convert_element_type(t1, dtypes.float32)  # t2: "cuda:0 f32[2, 2]"
    # t3 = prims.mul(a, t2)  # t3: "cuda:0 f32[2, 2]"
    # t4 = prims.mul(t3, 2.0)  # t4: "cuda:0 f32[2, 2]"
  t5 = torch.matmul(t4, t4)  # t5: "cuda:0 f32[2, 2]"
    # t5 = ltorch.matmul(t4, t4)  # t5: "cuda:0 f32[2, 2]"
      # t5 = prims.matmul(t4, t4)  # t5: "cuda:0 f32[2, 2]"
  return {'output': t5, 'flat_args': [a], 'flat_output': (t5,)}, ((t1, t4), (2.0,))

Backward trace:

print(thunder.last_backward_traces(jfunc)[-1])

def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, C1, = saved_for_backward
  clear_collection(saved_for_backward)
  del saved_for_backward
  t6, = cotangents
  clear_collection(cotangents)
  del cotangents
  t1, t4, = C0
  clear_collection(C0)
  del C0
  f7, = C1
  clear_collection(C1)
  del C1
  t35 = torch.permute(t4, (1, 0))  # t35: "cuda:0 f32[2, 2]"
    # t35 = ltorch.permute(t4, (1, 0))  # t35: "cuda:0 f32[2, 2]"
      # t35 = prims.transpose(t4, (1, 0))  # t35: "cuda:0 f32[2, 2]"
  del t4
  t36 = torch.matmul(t35, t6)  # t36: "cuda:0 f32[2, 2]"
    # t36 = ltorch.matmul(t35, t6)  # t36: "cuda:0 f32[2, 2]"
      # t36 = prims.matmul(t35, t6)  # t36: "cuda:0 f32[2, 2]"
  t34 = torch.matmul(t6, t35)  # t34: "cuda:0 f32[2, 2]"
    # t34 = ltorch.matmul(t6, t33)  # t34: "cuda:0 f32[2, 2]"
      # t34 = prims.matmul(t6, t33)  # t34: "cuda:0 f32[2, 2]"
  del t6, t35
  [t39] = nvFusion0(f7, t1, t34, t36)
    # t2 = prims.convert_element_type(t1, dtypes.float32)  # t2: "cuda:0 f32[2, 2]"
    # t37 = prims.add(t34, t36)  # t37: "cuda:0 f32[2, 2]"
    # t38 = prims.mul(f7, t37)  # t38: "cuda:0 f32[2, 2]"
    # t39 = prims.mul(t2, t38)  # t39: "cuda:0 f32[2, 2]"
  del f7, t1, t34, t36
  return (t39,)

We should implement a trace transformation that replaces prims.uniform with prims.uniform_philox and augments the trace with appropriate explicit calls to get_rng_state/set_rng_state to manipulate PyTorch's RNG state.
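
A very rough schematic of such a rewrite, written against a toy op list rather than Thunder's real trace API (the op names, the offset bookkeeping, and the uniform_philox argument list here are all assumptions):

from dataclasses import dataclass, field

@dataclass
class ToyOp:  # stand-in for a trace symbol; not Thunder's actual representation
    name: str
    args: tuple = ()
    kwargs: dict = field(default_factory=dict)

def make_rng_stateless(ops):
    """Replace stateful uniform draws with seeded, deterministic Philox draws."""
    out = [ToyOp("get_rng_state", kwargs={"returns": ("seed", "offset")})]
    draws = 0
    for op in ops:
        if op.name == "uniform":
            # Same (seed, offset) always yields the same tensor, so the backward
            # trace can re-issue this call instead of saving its result.
            out.append(ToyOp("uniform_philox", op.args,
                             {**op.kwargs, "seed": "seed", "offset": f"offset + {draws}"}))
            draws += 1
        else:
            out.append(op)
    if draws:
        # Advance PyTorch's generator so eager-mode code after the traced region
        # sees the same stream position it would have seen without the rewrite.
        out.append(ToyOp("set_rng_state", kwargs={"advance_draws": draws}))
    return out

fwd = [ToyOp("uniform", ((2, 2), 0.0, 1.0)), ToyOp("lt"), ToyOp("mul")]
for op in make_rng_stateless(fwd):
    print(op.name, op.args, op.kwargs)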

mruberry removed the help wanted (Extra attention is needed) label on Apr 8, 2024
mruberry (Collaborator) commented Apr 8, 2024

Triage review:

  • it would be nice to support different kinds of RNGs in the future
  • different executors may produce different tensors from the same random generator state
  • we need to think about how to ensure the mask is generated consistently (e.g., require the same executor for the forward and backward passes)
  • can we provide a mechanism for practitioners who want to write everything explicitly to update PyTorch's random state properly after calling a uniform_philox prim? (see the snippet after this list)
  • could such an implementation also be used to update PyTorch's random state properly?
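
For reference, PyTorch already exposes explicit RNG-state handles that such a mechanism could build on; a minimal CPU-only illustration (the CUDA analogues live under torch.cuda):

import torch

# Capture the CPU generator state, draw, then restore the state and replay the draw.
state = torch.get_rng_state()
first = torch.rand(4)
torch.set_rng_state(state)
again = torch.rand(4)
assert torch.equal(first, again)  # restoring the state reproduces the random stream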

kiya00 added a commit that referenced this issue Apr 16, 2024
kiya00 added a commit that referenced this issue Apr 18, 2024
kiya00 added a commit that referenced this issue Apr 19, 2024