
[Bug] error when training vit model with torch.compile #1852

Open
HaoLiuHust opened this issue Dec 21, 2023 · 0 comments
HaoLiuHust commented Dec 21, 2023

Branch

main branch (mmpretrain version)

Describe the bug

I am training a ViT model (MViTv2 from timm). When I train it with torch.compile enabled, the following error occurs:

run_epoch()
  File "/opt/conda/lib/python3.10/site-packages/mmengine/runner/loops.py", line 112, in run_epoch
    self.run_iter(idx, data_batch)
  File "/opt/conda/lib/python3.10/site-packages/mmengine/runner/loops.py", line 128, in run_iter
    outputs = self.runner.model.train_step(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/mmengine/model/wrappers/distributed.py", line 119, in train_step
    with optim_wrapper.optim_context(self):
  File "/opt/conda/lib/python3.10/site-packages/mmengine/model/wrappers/distributed.py", line 120, in <resume in train_step>
    data = self.module.data_preprocessor(data, training=True)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 490, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 641, in _convert_frame
    result = inner_convert(frame, cache_size, hooks, frame_state)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 133, in _fn
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 389, in _convert_frame_assert
    return _compile(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 569, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 491, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 458, in transform
    tracer.run()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2069, in run
    super().run()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 719, in run
    and self.step()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 683, in step
    getattr(self, inst.opname)(inst)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1110, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 557, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 261, in call_function
    return super().call_function(tx, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 593, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2174, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2281, in inline_call_
    tracer.run()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 719, in run
    and self.step()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 683, in step
    getattr(self, inst.opname)(inst)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1110, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 557, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/torch.py", line 729, in call_function
    tensor_variable = wrap_fx_proxy(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 1191, in wrap_fx_proxy
    return wrap_fx_proxy_cls(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 1278, in wrap_fx_proxy_cls
    example_value = get_fake_value(proxy.node, tx)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1376, in get_fake_value
    raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1337, in get_fake_value
    return wrap_fake_exception(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 916, in wrap_fake_exception
    return fn()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1338, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1410, in run_node
    raise RuntimeError(fn_str + str(e)).with_traceback(e.__traceback__) from e
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1397, in run_node
    return node.target(*args, **kwargs)
torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function one_hot>(*(FakeTensor(..., device='cuda:1', size=(s0 + s1 + 118,), dtype=torch.int64), 2), **{}):
Cannot call sizes() on tensor with symbolic sizes/strides

from user code:
   File "/mnt/pai-storage-8/jieshen/code/mmlab/mmpretrain/mmpretrain/models/utils/data_preprocessor.py", line 177, in forward
    batch_score = batch_label_to_onehot(
  File "/mnt/pai-storage-8/jieshen/code/mmlab/mmpretrain/mmpretrain/structures/utils.py", line 123, in batch_label_to_onehot
    sparse_onehot_list = F.one_hot(batch_label, num_classes)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
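The traceback shows the failure originates in `F.one_hot`, traced with a tensor whose batch dimension is symbolic (`size=(s0 + s1 + 118,)`). As a hedged sketch (not part of mmpretrain; `one_hot_compile_friendly` is a hypothetical name), the same encoding can be computed with a broadcast comparison, which avoids the size query that appears to trip dynamo:

```python
import torch


def one_hot_compile_friendly(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Equivalent of F.one_hot built from a broadcast comparison.

    Each label is compared against arange(num_classes), so no call ever
    needs to materialize the symbolic batch size the way F.one_hot does.
    """
    classes = torch.arange(num_classes, device=labels.device)
    return (labels.unsqueeze(-1) == classes).long()
```

Swapping a helper like this into `batch_label_to_onehot` in mmpretrain/structures/utils.py would be one way to check whether the graph then compiles; this is untested against the actual failure.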

If I turn off the torch.compile flag, training works fine.
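Short of turning off torch.compile entirely, one possible middle ground is to exclude just the one-hot step from compilation with `torch._dynamo.disable`, so dynamo never traces `F.one_hot` with symbolic sizes. A minimal toy sketch (the function names mirror the traceback, but this is not the actual mmpretrain code):

```python
import torch
import torch._dynamo
import torch.nn.functional as F


# Hypothetical sketch: keep the label-to-one-hot step out of the compiled
# graph; the decorated function always runs eagerly.
@torch._dynamo.disable
def batch_label_to_onehot(batch_label: torch.Tensor, num_classes: int) -> torch.Tensor:
    return F.one_hot(batch_label, num_classes)


def train_step(batch_label: torch.Tensor) -> torch.Tensor:
    onehot = batch_label_to_onehot(batch_label, 2)  # eager, outside the graph
    return onehot.float().sum()


# backend="eager" exercises dynamo tracing without requiring inductor
compiled_step = torch.compile(train_step, backend="eager")
```

Whether this keeps the rest of the model compiled with acceptable graph breaks would need to be verified in the actual training setup.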

Environment

{'sys.platform': 'linux',
 'Python': '3.10.11 (main, Apr 20 2023, 19:02:41) [GCC 11.2.0]',
 'CUDA available': True,
 'numpy_random_seed': 2147483648,
 'GPU 0': 'NVIDIA GeForce RTX 4090',
 'CUDA_HOME': '/usr/local/cuda',
 'NVCC': 'Cuda compilation tools, release 11.8',
 'GCC': 'gcc (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0',
 'PyTorch': '2.1.2',
 'TorchVision': '0.15.2',
 'OpenCV': '4.8.0',
 'MMEngine': '0.10.1',
 'MMCV': '2.0.1',
 'MMPreTrain': '1.1.1+'}

Other information

No response
