The merge semantics for Arc<Mutex<OwnedTape<_, _>>> seem a bit unintuitive #841

Open
emchristiansen opened this issue Aug 2, 2023 · 2 comments

Comments

@emchristiansen

The merge semantics for Arc<Mutex<OwnedTape<_, _>>> seem a bit unintuitive.
In particular, one would hope that when these tapes are merged, they would essentially be replaced with one tape which is the union of the two input tapes.
But what actually happens is the left tape becomes the union and the right tape becomes the empty tape.
This leads to the same kind of problem as with the plain OwnedTape<_, _>: the gradient tape ends up partitioned across several objects, and it's up to the programmer to figure out which object holds which part of the gradients.

For reference, here's the merge code for Arc<Mutex<OwnedTape<_, _>>>:

impl<E, D: Storage<E>> Merge<Self> for Arc<Mutex<OwnedTape<E, D>>> {
    fn merge(self, other: Self) -> Self {
        if !Arc::ptr_eq(&self, &other) {
            let mut lhs = self.lock().unwrap();
            let mut rhs = other.lock().unwrap();
            lhs.gradients
                .gradient_by_id
                .append(&mut rhs.gradients.gradient_by_id);
            if let Some(leafs) = &mut rhs.gradients.leaf_ids {
                lhs.gradients
                    .leaf_ids
                    .get_or_insert_with(Default::default)
                    .append(leafs);
            }
            lhs.operations.append(&mut rhs.operations);
        }
        self
    }
}

If we invoke this code by writing let z = x + y.clone() and then write let zz = y * 2.0, we'll end up with gradients partitioned between z and zz, each of which has an Arc<Mutex<_>> pointer to a distinct OwnedTape<_, _> object.
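To make the partitioning concrete, here is a minimal std-only sketch. ToyTape, tape_with, merge, and partition_demo are hypothetical stand-ins invented for illustration, not dfdx APIs; the toy tape only records operation names, and the free function merge copies just the shape of the Merge impl above.

```rust
use std::sync::{Arc, Mutex};

// Hypothetical stand-in for OwnedTape: it only records operation names.
#[derive(Default)]
struct ToyTape {
    operations: Vec<String>,
}

type SharedTape = Arc<Mutex<ToyTape>>;

// Hypothetical helper: build a shared tape pre-filled with some operations.
fn tape_with(ops: &[&str]) -> SharedTape {
    Arc::new(Mutex::new(ToyTape {
        operations: ops.iter().map(|s| s.to_string()).collect(),
    }))
}

// Same shape as the merge above: drain `other` into `this`, return `this`.
fn merge(this: SharedTape, other: SharedTape) -> SharedTape {
    if !Arc::ptr_eq(&this, &other) {
        let mut lhs = this.lock().unwrap();
        let mut rhs = other.lock().unwrap();
        lhs.operations.append(&mut rhs.operations);
    }
    this
}

// Models `let z = x + y.clone(); let zz = y * 2.0;` at the tape level.
// Returns the operations visible through z's tape and through zz's tape.
fn partition_demo() -> (Vec<String>, Vec<String>) {
    let x = tape_with(&["x_op"]);
    let y = tape_with(&["y_op"]);

    // z = x + y.clone(): the clone shares y's Arc, so y's tape gets drained.
    let z = merge(x, y.clone());
    z.lock().unwrap().operations.push("add".to_string());

    // zz = y * 2.0: recorded on y's (now empty) tape, not on z's.
    let zz = y;
    zz.lock().unwrap().operations.push("mul".to_string());

    let z_ops = z.lock().unwrap().operations.clone();
    let zz_ops = zz.lock().unwrap().operations.clone();
    (z_ops, zz_ops)
}
```

Under these assumptions, z's tape holds x_op, y_op, and add, while zz's tape holds only mul, so neither handle sees the full history.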

I'd personally find it more natural if the merge operation mutated the tapes of both arguments to make them both point to the same underlying data, so that we only have one OwnedTape<_, _> at the end of the day.

Thoughts?

@emchristiansen
Author

FYI I hacked something together using Arc<Mutex<Arc<Mutex<OwnedTape<_, _>>>>> 🤯

The basic idea is to provide another layer of indirection so we can mutate the pointer to the actual OwnedTape<_, _>.
Here's the basic code:

impl<E, D: Storage<E>> Merge<NoneTape> for Arc<Mutex<Arc<Mutex<OwnedTape<E, D>>>>> {
    fn merge(self, _: NoneTape) -> Self {
        self
    }
}

impl<E, D: Storage<E>> Merge<Self> for Arc<Mutex<Arc<Mutex<OwnedTape<E, D>>>>> {
    fn merge(self, other: Self) -> Self {
        if !Arc::ptr_eq(&self, &other) {
            let pointer_lhs = self.lock().unwrap();
            let mut pointer_rhs = other.lock().unwrap();
            if !Arc::ptr_eq(&pointer_lhs, &pointer_rhs) {
                let mut lhs = pointer_lhs.lock().unwrap();
                let mut rhs = pointer_rhs.lock().unwrap();
                lhs.gradients
                    .gradient_by_id
                    .append(&mut rhs.gradients.gradient_by_id);
                if let Some(leafs) = &mut rhs.gradients.leaf_ids {
                    lhs.gradients
                        .leaf_ids
                        .get_or_insert_with(Default::default)
                        .append(leafs);
                }
                lhs.operations.append(&mut rhs.operations);
            }
            // Update the RHS so it points to the same underlying OwnedTape.
            *pointer_rhs = pointer_lhs.clone();
        }
        self
    }
}

impl<E, D: Storage<E>> Tape<E, D> for Arc<Mutex<Arc<Mutex<OwnedTape<E, D>>>>> {
    const OWNS_TAPE: bool = true;
    fn add_backward_op<F>(&mut self, operation: F)
    where
        F: 'static + FnOnce(&mut Gradients<E, D>) -> Result<(), D::Err>,
    {
        let mut tape = self.lock().unwrap();
        tape.add_backward_op(operation);
    }
}

TBH anything with Arc<Mutex<Arc<Mutex<_>>>> seems pretty sketchy to me, but so far it seems to work (currently testing it).
Let me know if you have a more elegant approach.
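For anyone who wants to poke at the indirection idea without pulling in dfdx, here is a std-only sketch of the same pattern. ToyTape, Redirectable, new_tape, record, merge, and shared_demo are hypothetical names for illustration only; the toy tape records operation names instead of gradients.

```rust
use std::sync::{Arc, Mutex};

// Hypothetical stand-in for OwnedTape: it only records operation names.
#[derive(Default)]
struct ToyTape {
    operations: Vec<String>,
}

// The outer Arc<Mutex<..>> is the handle that clones share; the inner
// Arc<Mutex<..>> is the tape that handle currently points to.
type Redirectable = Arc<Mutex<Arc<Mutex<ToyTape>>>>;

fn new_tape() -> Redirectable {
    Arc::new(Mutex::new(Arc::new(Mutex::new(ToyTape::default()))))
}

// Record an operation on whichever tape the handle currently points to.
fn record(tape: &Redirectable, op: &str) {
    let inner = tape.lock().unwrap();
    inner.lock().unwrap().operations.push(op.to_string());
}

// Drain `other`'s tape into `this`'s tape, then redirect `other` to it.
fn merge(this: Redirectable, other: Redirectable) -> Redirectable {
    if !Arc::ptr_eq(&this, &other) {
        let pointer_lhs = this.lock().unwrap();
        let mut pointer_rhs = other.lock().unwrap();
        if !Arc::ptr_eq(&*pointer_lhs, &*pointer_rhs) {
            let mut lhs = pointer_lhs.lock().unwrap();
            let mut rhs = pointer_rhs.lock().unwrap();
            lhs.operations.append(&mut rhs.operations);
        }
        // Both handles now refer to the same underlying tape.
        *pointer_rhs = Arc::clone(&*pointer_lhs);
    }
    this
}

// Models `let z = x + y.clone(); let zz = y * 2.0;` at the tape level.
// Returns the merged operations and whether z and zz share one tape.
fn shared_demo() -> (Vec<String>, bool) {
    let x = new_tape();
    let y = new_tape();
    record(&x, "x_op");
    record(&y, "y_op");

    let z = merge(x, y.clone());
    record(&z, "add");

    let zz = y;
    record(&zz, "mul");

    let same = Arc::ptr_eq(&*z.lock().unwrap(), &*zz.lock().unwrap());
    let ops = z.lock().unwrap().lock().unwrap().operations.clone();
    (ops, same)
}
```

Two caveats, as far as I can tell from this sketch. First, the redirect only updates the one outer handle passed as `other`: if a third tape was merged into `other` earlier, its handle still points at the old, now-empty inner tape, so the partitioning can come back transitively. Second, the locks are taken in argument order, so two threads merging the same pair of tapes in opposite orders at the same time could deadlock.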

@coreylowman
Owner

So you would want both tapes to be merged together and have the same gradients?
