Skip to content

Commit

Permalink
Fix tuple iterator issue (#99443)
Browse files Browse the repository at this point in the history
* Fix tuple iterator issue

* Lintrunner
  • Loading branch information
mlazos committed Apr 24, 2023
1 parent e4bdb86 commit 9e8bd61
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 11 deletions.
20 changes: 17 additions & 3 deletions test/dynamo/test_misc.py
Expand Up @@ -2294,7 +2294,6 @@ def foo(x):
self.assertIs(x_ref(), None)

def test_release_module_memory(self):

mod = torch.nn.Linear(10, 10)
x = torch.rand([10, 10])
mod_weight_ref = weakref.ref(mod.weight)
Expand Down Expand Up @@ -2640,7 +2639,6 @@ def __init__(self):
self.names = []

def forward(self, idx, targets=None):

b, t = idx.size()
assert (
t <= self.block_size
Expand Down Expand Up @@ -3763,7 +3761,6 @@ def fn(x, y):
self.assertTrue(same(ref, res))

def test_disable_flag(self):

cnt = torch._dynamo.testing.CompileCounter()

with patch.dict(os.environ, {"TORCH_COMPILE_DISABLE": "1"}):
Expand Down Expand Up @@ -4046,6 +4043,23 @@ def fn(x, y):
res = opt_fn(x, y)
self.assertTrue(same(ref, res))

def test_tuple_from_tuple_iter(self):
    """Regression test for #99443: calling ``tuple()`` on a tuple_iterator
    argument inside a compiled function must unpack correctly (and install a
    TUPLE_ITERATOR_LEN guard rather than LIST_LENGTH)."""

    def inner_fn(*args):
        # Accumulate every positional arg into a fresh ones tensor.
        acc = torch.ones(10, 10)
        for arg in args:
            acc.add_(arg)
        return acc

    def fn(inputs, params):
        # `params` arrives as a tuple_iterator; tuple() must consume it.
        y = tuple(inputs) + tuple(params)
        return inner_fn(*y)

    inputs = [torch.randn(10, 10) for _ in range(3)]

    # Eager reference first; each call gets its own iterator since an
    # iterator is consumed by a single pass through fn.
    ref = fn(inputs, iter(tuple(inputs)))
    opt_fn = torch._dynamo.optimize("eager")(fn)
    res = opt_fn(inputs, iter(tuple(inputs)))
    # The original test only checked "does not raise"; also pin the value.
    self.assertTrue(same(ref, res))

def test_torch_package_working_with_trace(self):
# from torch._dynamo.test_case import run_tests

Expand Down
4 changes: 2 additions & 2 deletions torch/_dynamo/variables/builder.py
Expand Up @@ -65,12 +65,12 @@
)
from .functions import UserFunctionVariable
from .lists import (
ListIteratorVariable,
ListVariable,
NamedTupleVariable,
RangeVariable,
SizeVariable,
SliceVariable,
TupleIteratorVariable,
TupleVariable,
)
from .misc import (
Expand Down Expand Up @@ -265,7 +265,7 @@ def _wrap(self, value):
)(tuple_iterator_getitem(value, i)).add_guards(guards)
for i in range(tuple_iterator_len(value))
]
return ListIteratorVariable(
return TupleIteratorVariable(
output, mutable_local=MutableLocal(), guards=guards
)
elif istype(value, (slice, range)):
Expand Down
13 changes: 7 additions & 6 deletions torch/_dynamo/variables/builtin.py
Expand Up @@ -26,7 +26,7 @@
from .base import MutableLocal, typestr, VariableTracker
from .constant import ConstantVariable
from .dicts import ConstDictVariable
from .lists import BaseListVariable, ListVariable, TupleVariable
from .lists import BaseListVariable, ListVariable, TupleIteratorVariable, TupleVariable
from .tensor import FakeItemVariable, SymNodeVariable, UnspecializedPythonVariable
from .user_defined import UserDefinedVariable

Expand Down Expand Up @@ -195,7 +195,7 @@ def _binop_handlers():

# Override table contains: op_fn -> [list of handlers]
op_handlers = {}
for (op, magic_method_names) in itertools.chain(
for op, magic_method_names in itertools.chain(
BuiltinVariable._inplace_binops().items(),
BuiltinVariable._reversible_binops().items(),
):
Expand Down Expand Up @@ -355,7 +355,7 @@ def _find_binop_handler(op, a, b):
return None

# Return first handler that matches the type checks
for ((type1, type2), handler) in handlers[op]:
for (type1, type2), handler in handlers[op]:
if isinstance(a, type1) and isinstance(b, type2):
return handler

Expand Down Expand Up @@ -646,7 +646,6 @@ def _call_min_max_binary(self, tx, a, b):
)
for i in [a, b]
):

if any([isinstance(val, FakeItemVariable) for val in [a, b]]):
return variables.FakeItemVariable.from_tensor_variable(result)

Expand Down Expand Up @@ -683,7 +682,6 @@ def _call_min_max_binary(self, tx, a, b):
)
return SymNodeVariable.create(tx, proxy, None)
else:

unimplemented(f"unsupported min / max over args {str(a)}, {str(b)}")

call_min = _call_min_max
Expand Down Expand Up @@ -739,7 +737,10 @@ def _call_iter_tuple_list(self, tx, obj=None, *args, **kwargs):
elif obj.has_unpack_var_sequence(tx):
guards = set()
if obj.source and not is_constant_source(obj.source):
guards.add(obj.source.make_guard(GuardBuilder.LIST_LENGTH))
if isinstance(obj, TupleIteratorVariable):
guards.add(obj.source.make_guard(GuardBuilder.TUPLE_ITERATOR_LEN))
else:
guards.add(obj.source.make_guard(GuardBuilder.LIST_LENGTH))
return cls(
list(obj.unpack_var_sequence(tx)),
mutable_local=MutableLocal(),
Expand Down
4 changes: 4 additions & 0 deletions torch/_dynamo/variables/lists.py
Expand Up @@ -534,3 +534,7 @@ def reconstruct(self, codegen):
create_instruction("BUILD_TUPLE", len(remaining_items)),
create_instruction("GET_ITER"),
]


class TupleIteratorVariable(ListIteratorVariable):
    """Marker subclass tracking an iterator obtained from a tuple.

    Behaviorally identical to ListIteratorVariable; it exists only as a
    distinct type so callers can dispatch on it via ``isinstance`` — e.g.
    guard construction installs a TUPLE_ITERATOR_LEN guard for these instead
    of the LIST_LENGTH guard used for other unpackable sequences.
    """

    pass

0 comments on commit 9e8bd61

Please sign in to comment.