[Draft] Coding trees with TensorDict #146

Draft · wants to merge 9 commits into main
Conversation

@tcbegley (Contributor) commented Jan 5, 2023

Description

This PR contains a draft implementation of support for trees using TensorDict.

We allow the user to create a tree structure in which every node is a tensordict with the same keys. The entire contents of the tree are backed by a single source tensordict with pre-allocated memory; as nodes are added to the tree, their data is written into the source. The primary advantage of this setup is that collecting data from multiple nodes reduces to a single index into the source, which is far more efficient than the naive alternative of iterating over the nodes of interest and stacking the results.
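
To make the trade-off concrete, here is a minimal sketch using plain torch tensors rather than the prototype API: it contrasts stacking per-node tensors on every query with writing each node into a pre-allocated source once and gathering with a single index.

import torch

# Naive gather: keep each node's data separately and stack on every query.
node_data = {"root": torch.rand(2, 3, 4), "left": torch.rand(2, 3, 4)}
stacked = torch.stack([node_data["root"], node_data["left"]])  # allocates and copies per query

# Source-backed gather: write each node into a pre-allocated buffer once at insertion,
# after which any subset of nodes is a single advanced-indexing op on the source.
source = torch.empty(100, 2, 3, 4)  # room for 100 nodes
slots = {"root": 0, "left": 1}
for name, i in slots.items():
    source[i] = node_data[name]

gathered = source[[slots["root"], slots["left"]]]  # one vectorised index
assert torch.equal(stacked, gathered)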

The interface is very preliminary and up for debate. The tests should give an idea of usage; here's a very basic example.

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.prototype import make_tree

>>> # create the root node from a tensordict
>>> # preallocate memory for 100 nodes in the source
>>> root = make_tree(TensorDict({"data": torch.rand(2, 3, 4)}, [2, 3]), n_nodes=100)

>>> # adding a new tensordict as a child creates a new node in the tree
>>> root["left"] = TensorDict({"data": torch.rand(2, 3, 4)}, [2, 3])

>>> # we can now gather the data for the two nodes in the tree
>>> root.get_multiple_items("data", ("left", "data"))
tensor(...)  # shape torch.Size([2, 2, 3, 4])

>>> # alternatively we can create a tree from a nested tensordict
>>> root = make_tree(
...     TensorDict(
...         {"data": torch.ones(2, 3, 4), "left": TensorDict({"data": torch.zeros(2, 3, 4)}, [2, 3])},
...         [2, 3]
...     ),
...     n_nodes=100,
... )
>>> root.get_multiple_items("data", ("left", "data"))
tensor(...)  # shape torch.Size([2, 2, 3, 4])

@tcbegley marked this pull request as draft January 5, 2023 14:32
@vmoens (Contributor) left a comment

This is super cool!
Can we add a file in the benchmark directory to test the tree against a plain tensordict implementation (if that makes sense)?
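
A benchmark file along these lines might look roughly as follows. This is only a sketch: it assumes the make_tree / get_multiple_items interface from this PR, uses an arbitrary repetition count, and compares the tree against an equivalent plain nested TensorDict.

import timeit

import torch
from tensordict import TensorDict
from tensordict.prototype import make_tree

def nested_td():
    # a root node and a "left" child, each holding a [2, 3, 4] tensor
    return TensorDict(
        {"data": torch.rand(2, 3, 4), "left": TensorDict({"data": torch.rand(2, 3, 4)}, [2, 3])},
        [2, 3],
    )

# baseline: plain nested tensordict, gather by stacking individual lookups
td = nested_td()

def plain_gather():
    return torch.stack([td["data"], td["left", "data"]])

# tree: nodes backed by a pre-allocated source, gather with a single call
root = make_tree(nested_td(), n_nodes=100)

def tree_gather():
    return root.get_multiple_items("data", ("left", "data"))

print("tensordict:", timeit.timeit(plain_gather, number=100_000))
print("tree:      ", timeit.timeit(tree_gather, number=100_000))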

@vmoens added the enhancement label Jan 5, 2023
@tcbegley (Contributor, Author) commented Jan 5, 2023

Added some simple benchmarks. Let me know what else you think we should benchmark.

get_multiple_items
tensordict: 2.145204999993439
tree: 0.9129869579919614

get_multiple_items_deep
tensordict: 2.4592998750013066
tree: 0.8956774580001365

@vmoens (Contributor) commented Jan 5, 2023

impressive!
