[NestedTensor] Graph breaks with SDPA + NT constructor #126472

Open
davidberard98 opened this issue May 16, 2024 · 1 comment
Labels
module: nestedtensor NestedTensor tag see issue #25032 · triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

davidberard98 (Contributor) commented May 16, 2024

🐛 Describe the bug

When we use SDPA, we need max_seqlen and min_seqlen. Getting max/min_seqlen normally requires a .item() call (which usually graph breaks, I think?).
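For context, a minimal sketch (assuming the usual eager-mode pattern of deriving seqlens from the offsets) of where that .item() call shows up:

import torch

offsets = torch.tensor([0, 1, 3, 6, 10])
lengths = offsets.diff()  # per-sequence lengths: tensor([1, 2, 3, 4])
# .item() reads a data-dependent scalar back to the host, which is the
# call that typically forces Dynamo to graph-break:
max_seqlen = lengths.max().item()  # 4
min_seqlen = lengths.min().item()  # 1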

So this issue focuses on removing graph breaks where:

  • we construct the NT in the graph,
  • we use SDPA, and
  • we pass in max/min_seqlen manually.

General repro - the approach is to call nested_view_from_values_offsets_lengths with max_seqlen and min_seqlen passed in:

import torch
from torch.nested._internal.nested_tensor import ViewNestedFromBuffer, nested_view_from_values_offsets_lengths
import torch._dynamo

# note: needed for testing with ViewNestedFromBuffer (the commented-out path below), which I wasn't able to get working
torch._dynamo.allow_in_graph(ViewNestedFromBuffer)


def convert_jagged_to_nested_tensor(
    values: torch.Tensor, offsets: torch.Tensor, max_length: int
) -> torch.Tensor:
    # metadata_cache = {"max_seqlen": max_length, "min_seqlen": 1}
    # nt = ViewNestedFromBuffer.apply(values, offsets, metadata_cache)
    nt = nested_view_from_values_offsets_lengths(values, offsets, lengths=None, min_seqlen=1, max_seqlen=max_length)
    return nt


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.q_linear = torch.nn.Linear(8, 32*4)
        self.k_linear = torch.nn.Linear(8, 32*4)
        self.v_linear = torch.nn.Linear(8, 32*4)

    def forward(self, values, offsets):
        nt = convert_jagged_to_nested_tensor(values, offsets, 5)
        q, k, v = [mod(nt) for mod in (self.q_linear, self.k_linear, self.v_linear)]
        q, k, v = [
            x.view(4, -1, 4, 32).transpose(1, 2)
            for x in (q, k, v)
        ]
        sdpa_out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=False)
        return sdpa_out.values()


values = torch.randn(10, 8, device='cuda')
offsets = torch.tensor([0, 1, 3, 6, 10], device='cuda')

# optionally, dynamic=False; but I wasn't able to get that to work either
torch.compile(MyModule().cuda(), fullgraph=True)(values, offsets)

Failure 1: With #122836 (rebased onto 7f1d5ab)

Traceback (most recent call last):
  File "/home/dberard/local/scripts/nt_2.py", line 38, in <module>
    torch.compile(MyModule().cuda(), fullgraph=True)(values, offsets)
  File "/home/dberard/local/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/eval_frame.py", line 420, in _fn
    return fn(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 986, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 381, in _convert_frame_assert
    return _compile(
  File "/home/dberard/local/pytorch/torch/_utils_internal.py", line 70, in wrapper_function
    return function(*args, **kwargs)
  File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 708, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/dberard/local/pytorch/torch/_dynamo/utils.py", line 273, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 543, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/dberard/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1167, in transform_code_object
    transformations(instructions, code_options)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 172, in _fn
    return fn(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 490, in transform
    tracer.run()
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2234, in run
    super().run()
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 884, in run
    while self.step():
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 799, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2423, in RETURN_VALUE
    self._return(inst)
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2408, in _return
    self.output.compile_subgraph(
  File "/home/dberard/local/pytorch/torch/_dynamo/output_graph.py", line 1083, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/dberard/local/pytorch/torch/_dynamo/output_graph.py", line 1300, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/home/dberard/local/pytorch/torch/_dynamo/utils.py", line 273, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/output_graph.py", line 1391, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/home/dberard/local/pytorch/torch/_dynamo/output_graph.py", line 1372, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/home/dberard/local/pytorch/torch/_dynamo/repro/after_dynamo.py", line 127, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/home/dberard/local/pytorch/torch/_dynamo/repro/after_dynamo.py", line 127, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/home/dberard/local/pytorch/torch/__init__.py", line 1747, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/dberard/local/pytorch/torch/_inductor/compile_fx.py", line 1478, in compile_fx
    return aot_autograd(
  File "/home/dberard/local/pytorch/torch/_dynamo/backends/common.py", line 65, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "/home/dberard/local/pytorch/torch/_functorch/aot_autograd.py", line 962, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
  File "/home/dberard/local/pytorch/torch/_dynamo/utils.py", line 273, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/_functorch/aot_autograd.py", line 554, in create_aot_dispatcher_function
    fw_metadata = run_functionalized_fw_and_collect_metadata(
  File "/home/dberard/local/pytorch/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 566, in inner
    dynamic_dims = {
  File "/home/dberard/local/pytorch/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 567, in <setcomp>
    i for i, s in enumerate(o.shape) if not is_concrete_int(s)
  File "/home/dberard/local/pytorch/torch/fx/experimental/symbolic_shapes.py", line 221, in is_concrete_int
    if isinstance(a.node.expr, sympy.core.numbers.Integer):
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AttributeError: 'torch._C._SymNode' object has no attribute 'expr'

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Failure 2: Based on that failure, I tried with @soulitzer's PR #124624 patched on top:

Traceback (most recent call last):
  File "/home/dberard/local/scripts/nt_2.py", line 38, in <module>
    torch.compile(MyModule().cuda(), fullgraph=True)(values, offsets)
  File "/home/dberard/local/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/eval_frame.py", line 420, in _fn
    return fn(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 986, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 381, in _convert_frame_assert
    return _compile(
  File "/home/dberard/local/pytorch/torch/_utils_internal.py", line 70, in wrapper_function
    return function(*args, **kwargs)
  File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 737, in _compile
    raise InternalTorchDynamoError(str(e)).with_traceback(
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 708, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/dberard/local/pytorch/torch/_dynamo/utils.py", line 273, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 543, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/dberard/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1167, in transform_code_object
    transformations(instructions, code_options)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 172, in _fn
    return fn(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 490, in transform
    tracer.run()
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2234, in run
    super().run()
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 884, in run
    while self.step():
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 799, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 494, in wrapper
    return inner_fn(self, inst)
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1253, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 737, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/dberard/local/pytorch/torch/_dynamo/variables/functions.py", line 293, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 743, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2447, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2563, in inline_call_
    tracer.run()
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 884, in run
    while self.step():
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 799, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 494, in wrapper
    return inner_fn(self, inst)
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1306, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 737, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/dberard/local/pytorch/torch/_dynamo/variables/torch.py", line 754, in call_function
    tensor_variable = wrap_fx_proxy(
  File "/home/dberard/local/pytorch/torch/_dynamo/variables/builder.py", line 1585, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/variables/builder.py", line 1708, in wrap_fx_proxy_cls
    set_example_value(proxy.node, example_value)
  File "/home/dberard/local/pytorch/torch/_dynamo/utils.py", line 1166, in set_example_value
    if symbol_to_path := torch.fx.experimental.symbolic_shapes.compute_unbacked_bindings(
  File "/home/dberard/local/pytorch/torch/fx/experimental/symbolic_shapes.py", line 604, in compute_unbacked_bindings
    symbol_to_path = free_unbacked_symbols_with_path(example_value, ())
  File "/home/dberard/local/pytorch/torch/fx/experimental/symbolic_shapes.py", line 541, in free_unbacked_symbols_with_path
    real=a.real_tensor.size() if a.real_tensor is not None else None
torch._dynamo.exc.InternalTorchDynamoError: 'NestedTensor' object has no attribute 'real_tensor'

from user code:
   File "/home/dberard/local/scripts/nt_2.py", line 25, in forward
    nt = convert_jagged_to_nested_tensor(values, offsets, 5)
  File "/home/dberard/local/scripts/nt_2.py", line 13, in convert_jagged_to_nested_tensor
    nt = nested_view_from_values_offsets_lengths(values, offsets, lengths=None, min_seqlen=1, max_seqlen=max_length)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Failure 3: Based on this, I tried a quick patch: this change
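The patch itself is only linked above, not inlined. As a hypothetical sketch (not the actual change; the helper name _maybe_real_size is made up for illustration), it presumably guards the real_tensor access that Failure 2 tripped over, roughly:

# Hypothetical sketch, not the actual linked patch: tolerate tensor
# subclasses like NestedTensor that don't define `real_tensor`.
def _maybe_real_size(a):
    real_tensor = getattr(a, "real_tensor", None)
    return real_tensor.size() if real_tensor is not None else None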

Traceback (most recent call last):
  File "/home/dberard/local/scripts/nt_2.py", line 38, in <module>
    torch.compile(MyModule().cuda(), fullgraph=True)(values, offsets)
  File "/home/dberard/local/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/eval_frame.py", line 420, in _fn
    return fn(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 986, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 381, in _convert_frame_assert
    return _compile(
  File "/home/dberard/local/pytorch/torch/_utils_internal.py", line 70, in wrapper_function
    return function(*args, **kwargs)
  File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 708, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/dberard/local/pytorch/torch/_dynamo/utils.py", line 273, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 543, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/dberard/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1167, in transform_code_object
    transformations(instructions, code_options)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 172, in _fn
    return fn(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 490, in transform
    tracer.run()
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2234, in run
    super().run()
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 884, in run
    while self.step():
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 799, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 494, in wrapper
    return inner_fn(self, inst)
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1253, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 737, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/dberard/local/pytorch/torch/_dynamo/variables/functions.py", line 293, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 743, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2447, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2563, in inline_call_
    tracer.run()
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 884, in run
    while self.step():
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 799, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 494, in wrapper
    return inner_fn(self, inst)
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1306, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 737, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/dberard/local/pytorch/torch/_dynamo/variables/torch.py", line 754, in call_function
    tensor_variable = wrap_fx_proxy(
  File "/home/dberard/local/pytorch/torch/_dynamo/variables/builder.py", line 1585, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/variables/builder.py", line 1708, in wrap_fx_proxy_cls
    set_example_value(proxy.node, example_value)
  File "/home/dberard/local/pytorch/torch/_dynamo/utils.py", line 1166, in set_example_value
    if symbol_to_path := torch.fx.experimental.symbolic_shapes.compute_unbacked_bindings(
  File "/home/dberard/local/pytorch/torch/fx/experimental/symbolic_shapes.py", line 605, in compute_unbacked_bindings
    assert not pending, (
AssertionError: pending {u0} not in NestedTensor(size=(4, u1, 8), offsets=FakeTensor(..., device='cuda:0', size=(5,), dtype=torch.int64), contiguous=True) ((8*u1, 8, 1), 0)

from user code:
   File "/home/dberard/local/scripts/nt_2.py", line 25, in forward
    nt = convert_jagged_to_nested_tensor(values, offsets, 5)
  File "/home/dberard/local/scripts/nt_2.py", line 13, in convert_jagged_to_nested_tensor
    nt = nested_view_from_values_offsets_lengths(values, offsets, lengths=None, min_seqlen=1, max_seqlen=max_length)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

I haven't gotten around to investigating this yet. Maybe #126198 is related (just based on unbacked symint <-> NT issues).

Failure 4: One other attempt - I figured I'd try #124803 to see if it would fix things without running into unbacked symint issues, but it hits a different problem where we get multiple NestedInts for the same dimension. (So we should probably just go with #124624 and figure out what the unbacked symint issue is about.)

Traceback (most recent call last):
  File "/home/dberard/local/scripts/nt_2.py", line 38, in <module>
    torch.compile(MyModule().cuda(), fullgraph=True)(values, offsets)
  File "/home/dberard/local/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/eval_frame.py", line 420, in _fn
    return fn(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 986, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 381, in _convert_frame_assert
    return _compile(
  File "/home/dberard/local/pytorch/torch/_utils_internal.py", line 70, in wrapper_function
    return function(*args, **kwargs)
  File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 708, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/dberard/local/pytorch/torch/_dynamo/utils.py", line 273, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 543, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/dberard/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1167, in transform_code_object
    transformations(instructions, code_options)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 172, in _fn
    return fn(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 490, in transform
    tracer.run()
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2234, in run
    super().run()
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 884, in run
    while self.step():
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 799, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2423, in RETURN_VALUE
    self._return(inst)
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2408, in _return
    self.output.compile_subgraph(
  File "/home/dberard/local/pytorch/torch/_dynamo/output_graph.py", line 1083, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/dberard/local/pytorch/torch/_dynamo/output_graph.py", line 1300, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/home/dberard/local/pytorch/torch/_dynamo/utils.py", line 273, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/output_graph.py", line 1391, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/home/dberard/local/pytorch/torch/_dynamo/output_graph.py", line 1372, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/home/dberard/local/pytorch/torch/_dynamo/repro/after_dynamo.py", line 127, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/home/dberard/local/pytorch/torch/_dynamo/repro/after_dynamo.py", line 127, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/home/dberard/local/pytorch/torch/__init__.py", line 1747, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/dberard/local/pytorch/torch/_inductor/compile_fx.py", line 1478, in compile_fx
    return aot_autograd(
  File "/home/dberard/local/pytorch/torch/_dynamo/backends/common.py", line 65, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "/home/dberard/local/pytorch/torch/_functorch/aot_autograd.py", line 962, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
  File "/home/dberard/local/pytorch/torch/_dynamo/utils.py", line 273, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/_functorch/aot_autograd.py", line 554, in create_aot_dispatcher_function
    fw_metadata = run_functionalized_fw_and_collect_metadata(
  File "/home/dberard/local/pytorch/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 692, in inner
    fw_graph_outs = pytree.tree_map(from_fun, f_fw_graph_outs)
  File "/home/dberard/local/pytorch/torch/utils/_pytree.py", line 943, in tree_map
    return treespec.unflatten(map(func, *flat_args))
  File "/home/dberard/local/pytorch/torch/utils/_pytree.py", line 782, in unflatten
    leaves = list(leaves)
  File "/home/dberard/local/pytorch/torch/_functorch/_aot_autograd/functional_utils.py", line 59, in from_fun
    out = transform_subclass(t, lambda _, inner_t: from_fun(inner_t))
  File "/home/dberard/local/pytorch/torch/utils/_python_dispatch.py", line 322, in transform_subclass
    assert sub.shape == outer_size, (
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AssertionError: Expected return value from <class 'torch.nested._internal.nested_tensor.NestedTensor'>__tensor_unflatten__() to have shape equal to torch.Size([4, j2, 8]), but got: torch.Size([4, j3, 8])

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Versions

Described above; these were all built on 7f1d5ab, for H100.

cc @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer

@davidberard98 davidberard98 added the module: nestedtensor NestedTensor tag see issue #25032 label May 17, 2024
jbschlosser (Contributor) commented May 17, 2024

> I haven't gotten around to investigating this yet. Maybe #126198 is related (just based on unbacked symint <-> NT issues).

I ran into a similar error, which prompted the fix in the linked PR.

FWIW I was able to get your repro working without graph breaks using a combination of:

@mikaylagawarecki mikaylagawarecki added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label May 20, 2024