Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NOMERG, WIP, POC] Auto-nested TensorDict #201

Draft
wants to merge 27 commits into
base: main
Choose a base branch
from

Conversation

tcbegley
Copy link
Contributor

@tcbegley tcbegley commented Feb 6, 2023

Description

This PR adds support for auto-nesting inside TensorDict. This is a proof-of-concept with missing features. Supporting auto-nested values is challenging because of the large number of methods in the TensorDict class and its children which employ recursion. Checking for cycles during iteration also inevitably introduces some overhead. These trade-offs still need to be varefully benchmarked and evaluated.

Here's a summary of the state of this branch and outstanding issues. We have implemented the following:

  • New functionality in _TensorDictKeysView that can detect a cycle ad raise an error or continue (internal usage only)
  • A new function _apply_safe which can safely map any function onto all entries of the TensorDict, preserving auto-nesting if detected.

The updated keys view is useful for iterating over all values in the TensorDict and applying some in-place operation, or aggregating some computed quantities. For example, zeroing all values in the TensorDict

for key in _TensorDictKeysView(
    self, include_nested=True, leaves_only=True, error_on_loop=False
):
    value = self.get(key)
    value.zero_()

or alternatively, in the implementation of any

any(
    self.get(key).any()
    for key in _TensorDictKeysView(
        self, include_nested=True, leaves_only=True, error_on_loop=False
    )
)

On the other hand, _apply_safe can be used to reimplement any function which returns a TensorDict of the same structure as the input. For example, implementing to_tensordict is as simple as

_apply_safe(lambda _, value: value.clone(), self)

Fixed so far

  • apply_: implemented with _apply_safe
  • expand: implemented with _apply_safe
  • __eq__: implemented with _apply_safe
  • __ne__: implemented with _apply_safe
  • to_tensordict: implemented with _apply_safe
  • zero_: implemented with _TensorDictKeysView
  • clone: implemented with _apply_safe
  • __repr__: fixed manually, without either paradigm
  • all: implemented with _TensorDictKeysView when dim is not specified, and _apply_safe when it is
  • any: implemented with _TensorDictKeysView when dim is not specified, and _apply_safe when it is
  • lock: implemented with _TensorDictKeysView
  • unlock: implemented with _TensorDictKeysView
  • _index_tensordict: implemented with _apply_safe
  • masked_fill_: implemented with _TensorDictKeysView

Outstanding bugs

  • split: doesn't neatly fit paradigm of _apply_safe since we return list of TensorDicts. Could possibly use _apply_safe inside a list comprehension, but we would not be able to use torch.split, we'd have to manually compute indexes and slice which risks both being slow and also deviating from torch.split behaviour if not carefully tested.
  • select: KeyError uses set(self.keys(include_nested=True)) in the error message which fails on auto-nested
  • to_dict: could be implemented with _TensorDictKeysView, but we need a convenience function for setting nested entries of a Python dict.
  • zero_: fails for TensorDict variants with lazily computed values, as the in-place update applies to a lazily computed value and doesn't persist. Suggest replacing value.zero_ we have currently with self.set_(key, 0, no_check=True)
  • apply: needs to be updated for auto-nested case. Unclear if we can use _apply_safe.
  • _TensorDictKeysView._items: need to handle the SubTensorDict case.
  • _TensorDictKeysView: fails when instantiated with lazy variants of TensorDict
  • assert_allclose_td: fails in the auto-nested case. Shouldn't be hard to fix the recursion error, but ideally if there is auto-nesting we would check that the auto-nesting structure exists in both tensordicts.
  • masked_select: needs to be updated, could probably be done with _apply_safe
  • _index_tensordict: preserve _is_memmap etc. in nested values (really an _apply_safe bug)
  • unbind: may work automatically once apply is fixed
  • pad: recursive implementation. Could potentially use _apply_safe.
  • is_contiguous: recursive implementation needs to be updated
  • __setitem__: gets stuck in a recursive loop in the auto-nested case.
  • stack: fails in auto-nested case.
  • cat: fails in auto-nested case.
  • flatten_keys: doesn't make sense in auto-nested case, raise informative error and test for it
  • memmap_: fails in recursive case. Probably needs some care...
  • Instantiating a TensorDict from a dict with auto-nested values causes a recursion error.

Tests to refactor

  • test_lock_write: replace calls to items with include_nested=True with _TensorDictKeysView.
  • test_apply: replace calls to keys with _TensorDictKeysView.
  • test_apply_other: replace calls to keys with _TensorDictKeysView.
  • test_masking_set: helper function zeros_like is recursive and fails on auto-nested tensordicts. Potential to use _apply_safe.
  • test_entry_type: replace call to keys with _TensorDictKeysView
  • test_update: key comparison is currently calling set(td.keys(True)). Need to find alternative way to check that key structure is preserved.
  • test_getitem_range: failing on all test cases, seems that _index_tensordict is doing something incorrect when we pass range. Original issue is fixed, assert_allclose_td is causing failure now
  • test_to_dict_nested: has a recursive checker algorithm that fails on auto-nested case
  • test_unflatten_keys: doesn't really make sense in the auto-nested case.
  • test_batchsize_reset: some issue with _index_tensordict it seems
  • test_shared_inheritance: unbind doesn't preserve is_shared.

Open questions

  • select: what should happen when we select keys from an auto-nested tensordict.
  • detect_loop: we don't actually use this in any of our implementations, should we have such a public method?

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Feb 6, 2023
@ruleva1983
Copy link

ruleva1983 commented Feb 10, 2023

test_chunk: fails for auto-nested due to usage of cat function. Tentatively disable it for that specific case
test_lock_write: corrected as suggested
test_apply: using keys view and also calling the apply_ function instead which in turn calls apply_safe

ruleva1983 and others added 7 commits February 13, 2023 11:10
Co-authored-by: Tom Begley <tomcbegley@gmail.com>
Co-authored-by: Ruggero Vasile <ruleva1983@meta.com>
…219)

Co-authored-by: Tom Begley <tomcbegley@gmail.com>
Co-authored-by: Ruggero Vasile <ruleva1983@meta.com>
@tcbegley tcbegley changed the title [WIP] Auto-nested TensorDict [WIP, POC] Auto-nested TensorDict Feb 14, 2023
@vmoens
Copy link
Contributor

vmoens commented Feb 15, 2023

We're putting this PR on hold for now.
The changes are extensive and substantially reduce code readability.
The plan is either to finish this PR at some point in the future, or adopt another strategy (e.g. a specialized class for auto-nesting).

@vmoens vmoens changed the title [WIP, POC] Auto-nested TensorDict [NOMERG, WIP, POC] Auto-nested TensorDict Feb 15, 2023
@tcbegley
Copy link
Contributor Author

For the benefit of anyone who picks this up in future, copying my comment from #220 about the reasons for some outstanding test failures:

The following tests fail because assert_allclose_td does not support auto-nested values

  • test_from_empty
  • test_masking
  • test_getitem_ellipsis
  • test_getitem_range

The following tests fail because we can't instantiate a TensorDict from a Python dict with auto-nested values

  • test_broadcast
  • test_equal_dict
  • test_nested_dict_init

Finally test_nestedtensor_stack is failing because LazyStackedTensorDict.contiguous is broken. I think a few other methods for LazyStackedTensorDict could be broken but not caught by the tests. The issues here largely stem from the fact that values are computed lazily, and hence use of id to check for repeated values is brittle.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants