[BUG] Auto-nested tensordict bugs #106
Working to fix this issue.
I think there are essentially 3 things to consider. Here is an example of how to do it with `to_tensordict`:

```python
from typing import Dict, Optional, Tuple

from tensordict import TensorDict, TensorDictBase


def to_tensordict(self):
    """A version of to_tensordict that supports auto-nesting."""

    def _to_tensordict(
        tensordict,
        current_key: Optional[Tuple] = None,
        being_computed: Optional[Dict] = None,
    ):
        out_dict = {}
        if current_key is None:
            current_key = ()
        if being_computed is None:
            being_computed = {}
        # Mark this tensordict (by its id) as being copied under `current_key`.
        being_computed[current_key] = id(tensordict)
        for key, value in tensordict.items():
            if isinstance(value, TensorDictBase):
                nested_key = current_key + (key,)
                if id(value) in being_computed.values():
                    # Auto-nesting: `value` is already being copied further up.
                    # Skip it and remember where it must be written later.
                    being_computed[nested_key] = id(value)
                    continue
                new_value = _to_tensordict(
                    value, current_key=nested_key, being_computed=being_computed
                )
            else:
                new_value = value.clone()
            out_dict[key] = new_value
        # Use this tensordict's own device and batch size: a nested
        # tensordict may differ from the root (`self`).
        out = TensorDict(
            out_dict,
            device=tensordict.device,
            batch_size=tensordict.batch_size,
            _run_checks=False,
        )
        # The copy is complete: write the delayed self-references.
        for other_nested_key, other_value in being_computed.items():
            if (
                other_nested_key != current_key
                and other_value == id(tensordict)
                and other_nested_key[: len(current_key)] == current_key
            ):
                # Stored keys are absolute (from the root); make them relative.
                out[other_nested_key[len(current_key):]] = out
        return out

    return _to_tensordict(self)
```

Again, we keep track of what is being processed. If something is being processed, we just ignore it for now and delay writing it until the operation on the nested tensordict has completed. (3) Some methods do not return a tensordict with the same structure but other objects. To resolve this issue, we should approach each problem independently: first the keys, second the tensor-to-tensor methods, and lastly the others.
## Describe the bug
Auto-nesting may be a desirable feature (e.g. to build graphs), but it is currently broken for multiple functions.
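A toy illustration of the failure mode, using plain Python dicts as stand-ins for tensordicts (an assumption; this is not the tensordict API): a naive recursive key flattening on a structure that contains itself never reaches a leaf and ends in a `RecursionError`. The function name `naive_flatten_keys` is hypothetical.

```python
def naive_flatten_keys(d, prefix=()):
    """Naively collect every nested key of a dict-of-dicts, depth first."""
    keys = []
    for key, value in d.items():
        if isinstance(value, dict):
            # No cycle detection: an auto-nested dict is re-entered forever.
            keys.extend(naive_flatten_keys(value, prefix + (key,)))
        else:
            keys.append(prefix + (key,))
    return keys


td = {"a": 1.0}
td["self"] = td  # auto-nesting: the dict contains itself

try:
    naive_flatten_keys(td)
except RecursionError:
    print("RecursionError")
```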
## Consideration
This is something that should be included in the tests. We could design a special test case in `TestTensorDictsBase` with a nested self.

## Solution
IMO there is no single solution to this problem. For `repr`, we could find a way of representing a nested tensordict.
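One possible approach for `repr`, sketched on a hypothetical toy class `ToyTensorDict` rather than the real `TensorDict`: Python's standard `reprlib.recursive_repr` decorator returns a fill value instead of recursing when `__repr__` is re-entered for the same object in the same thread, so an auto-nested entry is printed as a placeholder.

```python
import reprlib


class ToyTensorDict:
    """Hypothetical stand-in for a tensordict: a plain key -> value mapping."""

    def __init__(self, **items):
        self._items = dict(items)

    def __setitem__(self, key, value):
        self._items[key] = value

    # If repr() re-enters for the same object (auto-nesting), the
    # decorator returns the fill value instead of recursing forever.
    @reprlib.recursive_repr(fillvalue="ToyTensorDict(<auto-nested>)")
    def __repr__(self):
        fields = ", ".join(f"{k}={v!r}" for k, v in self._items.items())
        return f"ToyTensorDict({fields})"


td = ToyTensorDict(a=1)
td["self"] = td
print(repr(td))  # ToyTensorDict(a=1, self=ToyTensorDict(<auto-nested>))
```

The placeholder does not say which key the cycle points back to; a real implementation might prefer to show the path, but this sketch only covers breaking the recursion.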
For `keys`, we could avoid returning a key if a key pointing to the same value has already been returned (same for `values` and `items`).
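A minimal sketch of that rule on plain nested dicts standing in for tensordicts (an assumption, not the tensordict API): walk the nested keys depth first and skip any key whose value, compared by `id`, has already been returned. Because the root counts as already returned, a dict that contains itself is never re-entered. The function name `unique_nested_keys` is hypothetical, and lists stand in for tensors.

```python
def unique_nested_keys(d, prefix=(), seen=None):
    """Yield nested keys, skipping keys whose value was already returned.

    Values are compared by identity (``id``), so a dict that contains
    itself is never re-entered.
    """
    if seen is None:
        seen = {id(d)}  # the root itself counts as already returned
    for key, value in d.items():
        if id(value) in seen:
            continue  # another key already points to this very object
        seen.add(id(value))
        nested_key = prefix + (key,)
        yield nested_key
        if isinstance(value, dict):
            yield from unique_nested_keys(value, nested_key, seen)


x = [0.0, 1.0]  # stands in for a tensor shared under two keys
td = {"a": x, "b": x, "sub": {"c": [2.0]}}
td["self"] = td
td["sub"]["parent"] = td
print(list(unique_nested_keys(td)))  # [('a',), ('sub',), ('sub', 'c')]
```

The same `seen` set could drive `values` and `items`; note that under this rule the shared key `"b"` and the self-references are hidden entirely.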
For `flatten_keys`, the operation should be prohibited on an auto-nested `TensorDict`. The options are (1) leave it as it is, since the maximum recursion depth already takes care of it, or (2) build a wrapper around `flatten_keys()` to detect whether the same method (i.e. the same call to the same method from the same class) is occurring twice.

There are probably other properties that I'm missing, but I'd expect them to be covered by the tests if we can design the dedicated test pipeline mentioned earlier.