TapeGlobal and thoughts on the tape variants #843

Open
emchristiansen opened this issue Aug 3, 2023 · 3 comments

@emchristiansen

TapeGlobal

Here's another tape-tracking API to consider, as implemented in client code:

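// Assumed context (not shown in the original snippet): `use std::sync::Mutex;`
// plus the dfdx items referenced below (Cpu, OwnedTape, NoneTape, Tape, Merge,
// Gradients, Tensor, Rank0, HasErr, Backward) in scope.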
type FloatInner = f32;
type DfdxDevice = Cpu;

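// Per-thread storage for the single shared tape; None until init() is called.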
thread_local! {
  static TAPE_GLOBAL: once_cell::sync::Lazy<
    Mutex<Option<OwnedTape<FloatInner, DfdxDevice>>>,
  > = once_cell::sync::Lazy::new(|| Default::default());
}

#[derive(Debug, Clone, Default)]
pub struct TapeGlobal;

impl TapeGlobal
{
  pub fn init()
  {
    TAPE_GLOBAL.with(|tape| {
      *tape
        .lock()
        .unwrap() = Some(OwnedTape::default());
    });
  }

  pub fn reset()
  {
    TAPE_GLOBAL.with(|tape| {
      *tape
        .lock()
        .unwrap() = None;
    });
  }

  pub fn get() -> OwnedTape<FloatInner, DfdxDevice>
  {
    TAPE_GLOBAL.with(|tape| {
      let mut locked_tape = tape
        .lock()
        .unwrap();
      let out = locked_tape
        .take()
        .expect("Tape must be initialized before calling get");
      out
    })
  }

  pub fn set(value: OwnedTape<FloatInner, DfdxDevice>)
  {
    TAPE_GLOBAL.with(|tape| {
      let mut locked_tape = tape
        .lock()
        .unwrap();
      assert!(
        locked_tape.is_none(),
        "Tape must be None before calling set"
      );
      *locked_tape = Some(value);
    });
  }
}

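// Merging is a no-op: TapeGlobal is a zero-sized marker and the actual tape
// state lives in the thread-local global, so there is nothing to combine.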
impl Merge<NoneTape> for TapeGlobal
{
  fn merge(
    self,
    _: NoneTape,
  ) -> Self
  {
    self
  }
}

impl Merge<TapeGlobal> for TapeGlobal
{
  fn merge(
    self,
    _other: Self,
  ) -> Self
  {
    self
  }
}

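// Every recorded op is routed through the global tape: take it out of the
// thread-local, push the backward op, then put it back.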
impl Tape<FloatInner, DfdxDevice> for TapeGlobal
{
  const OWNS_TAPE: bool = true;

  fn add_backward_op<F>(
    &mut self,
    operation: F,
  ) where
    F: 'static
      + FnOnce(
        &mut Gradients<FloatInner, DfdxDevice>,
      ) -> Result<(), <DfdxDevice as HasErr>::Err>,
  {
    let mut tape = TapeGlobal::get();
    tape.add_backward_op(operation);
    TapeGlobal::set(tape);
  }
}

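// Newtype wrapper so Backward can be implemented: at backward time the
// accumulated global tape is reattached to the tensor before try_backward.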
pub struct TapeGlobalTensor(
  pub Tensor<Rank0, FloatInner, DfdxDevice, TapeGlobal>,
);

impl HasErr for TapeGlobalTensor
{
  type Err = <Tensor<Rank0, FloatInner, DfdxDevice, TapeGlobal> as HasErr>::Err;
}

impl Backward<FloatInner, DfdxDevice> for TapeGlobalTensor
{
  fn try_backward(self)
    -> Result<Gradients<FloatInner, DfdxDevice>, Self::Err>
  {
    let (t, _) = self
      .0
      .split_tape();
    let tape = TapeGlobal::get();
    t.put_tape(tape)
      .try_backward()
  }
}

TapeGlobal does what it says: it maintains a global, thread-local tape, thus avoiding the partitioned-gradients problem*.
Here's how you use it:

TapeGlobal::init();
let x = dev
  .tensor(1.0)
  .put_tape(TapeGlobal);
let y = x.clone() * x;
dbg!(TapeGlobalTensor(y).backward());

// If calling in a loop, you'd call TapeGlobal::reset() then
// TapeGlobal::init() each time to clear the gradients.
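
For example, a minimal sketch of the per-iteration pattern (num_steps and the loss computation are placeholders; dev is the device from above):

for _step in 0..num_steps {
  TapeGlobal::reset();
  TapeGlobal::init();
  let x = dev
    .tensor(1.0)
    .put_tape(TapeGlobal);
  let loss = x.clone() * x;
  let _grads = TapeGlobalTensor(loss).backward();
  // ... feed _grads to an optimizer here ...
}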

Let me know if you'd like a PR with something like this.

Thoughts on the tape variants

As you know, I have very non-standard model inputs and outputs, and I've been playing with different tape-tracking APIs in the hope of finding one that's easy to use and not error-prone, specifically: OwnedTape<_, _>, Arc<Mutex<OwnedTape<_, _>>>, Arc<Mutex<Arc<Mutex<OwnedTape<_, _>>>>>, and TapeGlobal.

Here are my thoughts:

  1. OwnedTape<_, _>: Simple and avoids Arc, but suffers from the partitioned-gradients problem, so the programmer must either (a) mentally track which tape has which gradients, or (b) do a final accumulate across all the model outputs to merge all the tapes into one.
  2. Arc<Mutex<OwnedTape<_, _>>>: Uses Arc (sorta bad) and only partially solves the partitioned-gradients problem.
  3. Arc<Mutex<Arc<Mutex<OwnedTape<_, _>>>>>: Uses nested Arc<Mutex<_>> (bad code smell) and still does not fully solve the partitioned-gradients problem, e.g. when you have dependency paths that originate from model parameters and never interact with the main tape (this can occur in a conditional-computation setting). This behavior isn't terrible, because these gradients will be zero, but you will get an unwrap error if you expect them to be defined.
  4. TapeGlobal: This does fully solve the partitioned-gradients problem, at the expense of maintaining an ugly global variable and losing type parameterization.

I'm personally still not sure which API I prefer between 1, 3, and 4 (I don't love any of them).
But I think 2 is essentially useless and rather dangerous, as you can accidentally omit gradients from your tape, and the shared state makes the tape accumulation difficult to reason about.

*Partitioned-gradients problem: When the gradients for your network are partitioned across several tapes and you as a programmer have to worry about which tape has which gradients.
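
For concreteness, here's a minimal sketch of the problem with plain OwnedTape (reusing dev, FloatInner, and DfdxDevice from above; the names and values are arbitrary):

let w1 = dev
  .tensor(1.0)
  .put_tape(OwnedTape::<FloatInner, DfdxDevice>::default());
let w2 = dev
  .tensor(2.0)
  .put_tape(OwnedTape::<FloatInner, DfdxDevice>::default());
// Each product records onto its own tape.
let y1 = w1 * dev.tensor(3.0);
let y2 = w2 * dev.tensor(4.0);
// y1.backward() only sees the ops on y1's tape, so it yields a gradient for
// w1 but not w2; the programmer has to remember which tape holds which
// gradients, or merge the outputs (e.g. via y1 + y2) before calling backward.
let _grads1 = y1.backward();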

@coreylowman
Owner

I don't see a way to do global/shared tapes without using some form of Arc/Mutex.

FWIW, the optimizer's update returns an error that indicates whether some tensors that were supposed to have gradients did not, so that was intended to capture this case.

@emchristiansen
Author

Arc<Mutex<_>>

FYI you can define the TapeGlobal state like this:

thread_local! {
  static TAPE_GLOBAL: RefCell<Option<OwnedTape<FloatInner, DfdxDevice>>> =
    RefCell::new(None);
}

IIUC, thread_local! ensures there are no threading concurrency issues, and the lack of async ensures there are no cooperative futures concurrency issues.
So, you don't need anything like an Arc<Mutex<_>>; a RefCell<_> suffices.
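
The accessors then use borrow_mut in place of lock; a sketch mirroring the earlier impl:

impl TapeGlobal
{
  pub fn init()
  {
    TAPE_GLOBAL.with(|tape| *tape.borrow_mut() = Some(OwnedTape::default()));
  }

  pub fn reset()
  {
    TAPE_GLOBAL.with(|tape| *tape.borrow_mut() = None);
  }

  pub fn get() -> OwnedTape<FloatInner, DfdxDevice>
  {
    TAPE_GLOBAL.with(|tape| {
      tape
        .borrow_mut()
        .take()
        .expect("Tape must be initialized before calling get")
    })
  }

  pub fn set(value: OwnedTape<FloatInner, DfdxDevice>)
  {
    TAPE_GLOBAL.with(|tape| {
      let mut tape = tape.borrow_mut();
      assert!(tape.is_none(), "Tape must be None before calling set");
      *tape = Some(value);
    });
  }
}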

FWIW I don't love TapeGlobal but I prefer it to the other approaches.

Missing gradients

I think relying on that optimizer.update behavior becomes problematic when you have conditional computation (as I do, of course!).
In this case you expect only a fraction of the parameters to have gradients for any given pass.

@mileswatson

mileswatson commented Jan 20, 2024

Are these variants the only way to reuse a taped tensor more than once during a forward pass? E.g., when x is used twice in:

TapeGlobal::init();
let x = dev
  .tensor(1.0)
  .put_tape(TapeGlobal);
let y = x.clone() * x;
dbg!(TapeGlobalTensor(y).backward());

What's the difference between this and retaping x (instead of cloning)? I'm trying to continue work on #437, but am suspicious that some of the retaping is incorrect (particularly lines 105 and 115 in rl-ppo-continuous.rs).
