
Steps to convert mmaction's video-swin-transformer to ONNX successfully #89

gigasurgeon opened this issue Dec 9, 2023 · 3 comments


Hello all. I have been trying to export mmaction's video Swin Transformer model to ONNX. However, the script tools/deployment/pytorch2onnx.py provided in this repo was giving me the following errors:

error 1) Floating point exception (core dumped)
error 2) RuntimeError: input_shape_value == reshape_value || input_shape_value == 1 || reshape_value == 1INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/onnx/shape_type_inference.cpp":513, please report a bug to PyTorch. ONNX Expand input shape constraint not satisfied.
error 3) other issues

I also tried another repo's reimplementation, i.e. https://github.com/haofanwang/video-swin-transformer-pytorch , but hit the same set of issues.
No luck. The model itself ran inference fine, but the ONNX export kept failing, and even torch.jit.script() failed.

So, after days of effort, I was able to come up with a way to successfully export my trained video-swin model to ONNX. Here's the code. I hope this will help.

from collections import OrderedDict

import torch
from torchvision.models.video.swin_transformer import SwinTransformer3d

# Build torchvision's video Swin with the same hyperparameters as the trained
# mmaction model (Swin-B depths/embed_dim here; set num_classes to your dataset).
torchvision_model = SwinTransformer3d(
    patch_size=[2, 4, 4],
    embed_dim=128,
    depths=[2, 2, 18, 2],
    num_heads=[4, 8, 16, 32],
    window_size=[16, 7, 7],
    mlp_ratio=4.0,
    dropout=0.0,
    attention_dropout=0.0,
    stochastic_depth_prob=0.1,
    num_classes=5)


# Load the checkpoint trained with mmaction.
mmaction_weights = torch.load('../dl_model_ckpt_swin/frames/swin_last.pth')

assert len(torchvision_model.state_dict()) == len(mmaction_weights['state_dict']), \
    "mmaction video-swin checkpoint doesn't have the same number of entries as the torchvision video-swin state_dict"

# print(torchvision_model)

############################
# Optional: print both state_dicts to compare key names and shapes.
#
# for key in torchvision_model.state_dict():
#     print(key, torchvision_model.state_dict()[key].shape)
#
# print('*' * 50)
#
# for key in mmaction_weights['state_dict']:
#     print(key, mmaction_weights['state_dict'][key].shape)
############################


########## checking that the two state_dicts line up shape-for-shape

torchvision_model_keys = list(torchvision_model.state_dict())
mmaction_weight_keys = list(mmaction_weights['state_dict'])

for tv_key, mm_key in zip(torchvision_model_keys, mmaction_weight_keys):
    shape_1 = torchvision_model.state_dict()[tv_key].shape
    shape_2 = mmaction_weights['state_dict'][mm_key].shape

    if shape_1 != shape_2:
        print('shapes not matching:', tv_key, mm_key, shape_1, shape_2)
        break
else:
    print('all shapes match')


############################ copying the actual weight values into the torchvision swin
# This relies on both state_dicts listing their parameters in the same order.

new_torchvision_state_dict = OrderedDict()

for tv_key, mm_key in zip(torchvision_model_keys, mmaction_weight_keys):
    new_torchvision_state_dict[tv_key] = mmaction_weights['state_dict'][mm_key]

torchvision_model.load_state_dict(new_torchvision_state_dict)


# (Optional debug: spot-check a few tensors in new_torchvision_state_dict against
# the mmaction checkpoint to confirm the values really were copied over.)

print('done')

# Dummy clip: (batch, channels, frames, height, width)
input_shape = [1, 3, 8, 224, 224]
input_tensor = torch.randn(input_shape)

torchvision_model.eval()  # disable dropout / stochastic depth before tracing
a = torchvision_model(input_tensor)  # sanity-check a forward pass
# torch.jit.script(torchvision_model, (input_tensor))
# torchvision_model = torch.compile(torchvision_model)

torch.onnx.export(
    torchvision_model,
    input_tensor,
    'video_swin.onnx',
    export_params=True,
    keep_initializers_as_inputs=True,
    verbose=True,
    opset_version=15)
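
One way to sanity-check the export (a minimal sketch, assuming onnxruntime is installed) is to run the exported graph on the same dummy input and compare it with the PyTorch output:

import numpy as np
import onnxruntime as ort

# Run the exported model and compare against the eager PyTorch output.
sess = ort.InferenceSession('video_swin.onnx', providers=['CPUExecutionProvider'])
input_name = sess.get_inputs()[0].name
onnx_out = sess.run(None, {input_name: input_tensor.numpy()})[0]

with torch.no_grad():
    torch_out = torchvision_model(input_tensor).numpy()

print('max abs diff:', np.abs(onnx_out - torch_out).max())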
adeljalalyousif commented Mar 25, 2024

I have tried the following method, but it does not work with the Video Swin Transformer; it only works with CNN models from torchvision. Can you help me with this method:


import torch
import torch.onnx
from mmcv import Config
from mmcv.runner import load_checkpoint
from vst.mmaction.models import build_model  # vst is the folder containing the cloned Video Swin Transformer repo

config = './vst/configs/recognition/swin/swin_tiny_patch244_window877_kinetics400_1k.py'
checkpoint = './vst/checkpoints/swin_tiny_patch244_window877_kinetics400_1k.pth'

cfg = Config.fromfile(config)
model = build_model(cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg'))

model.eval()
model.cuda()

# Load the checkpoint onto the GPU
checkpoint = load_checkpoint(model, checkpoint, map_location='cuda')

BATCH_SIZE = 1
T = 16

dummy_input = torch.randn(BATCH_SIZE, 3, T, 224, 224)

# export the model to ONNX
torch.onnx.export(model, dummy_input, "siwn_T.onnx", verbose=False)

I got this error:

Traceback (most recent call last):

File "C:\Users\MSI\miniconda3\envs\tr_121\lib\site-packages\spyder_kernels\py3compat.py", line 356, in compat_exec
exec(code, globals, locals)

File "c:\users\msi\untitled2.py", line 30, in
torch.onnx.export(model, dummy_input, "siwn_T.onnx", verbose=False)

File "C:\Users\MSI\miniconda3\envs\tr_121\lib\site-packages\torch\onnx_init_.py", line 350, in export
return utils.export(

File "C:\Users\MSI\miniconda3\envs\tr_121\lib\site-packages\torch\onnx\utils.py", line 163, in export
_export(

File "C:\Users\MSI\miniconda3\envs\tr_121\lib\site-packages\torch\onnx\utils.py", line 1074, in _export
graph, params_dict, torch_out = _model_to_graph(

File "C:\Users\MSI\miniconda3\envs\tr_121\lib\site-packages\torch\onnx\utils.py", line 727, in _model_to_graph
graph, params, torch_out, module = _create_jit_graph(model, args)

File "C:\Users\MSI\miniconda3\envs\tr_121\lib\site-packages\torch\onnx\utils.py", line 602, in _create_jit_graph
graph, torch_out = _trace_and_get_graph_from_model(model, args)

File "C:\Users\MSI\miniconda3\envs\tr_121\lib\site-packages\torch\onnx\utils.py", line 517, in _trace_and_get_graph_from_model
trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(

File "C:\Users\MSI\miniconda3\envs\tr_121\lib\site-packages\torch\jit_trace.py", line 1175, in _get_trace_graph
outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)

File "C:\Users\MSI\miniconda3\envs\tr_121\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)

File "C:\Users\MSI\miniconda3\envs\tr_121\lib\site-packages\torch\jit_trace.py", line 127, in forward
graph, out = torch._C._create_graph_by_tracing(

File "C:\Users\MSI\miniconda3\envs\tr_121\lib\site-packages\torch\jit_trace.py", line 118, in wrapper
outs.append(self.inner(*trace_inputs))

File "C:\Users\MSI\miniconda3\envs\tr_121\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)

File "C:\Users\MSI\miniconda3\envs\tr_121\lib\site-packages\torch\nn\modules\module.py", line 1118, in _slow_forward
result = self.forward(*input, **kwargs)

File "C:\Users\MSI\vst\mmaction\models\recognizers\base.py", line 253, in forward
raise ValueError('Label should not be None.')

ValueError: Label should not be None.

The block of code in vst that causes the error is:

def forward(self, imgs, label=None, return_loss=True, **kwargs):
    """Define the computation performed at every call."""
    if kwargs.get('gradcam', False):
        del kwargs['gradcam']
        return self.forward_gradcam(imgs, **kwargs)
    if return_loss:
        if label is None:
            raise ValueError('Label should not be None.')
        if self.blending is not None:
            imgs, label = self.blending(imgs, label)
        return self.forward_train(imgs, label, **kwargs)

innat-asj commented Mar 29, 2024


adeljalalyousif commented May 22, 2024

Thanks a lot, but when I try to convert the ONNX models provided in the Kaggle link to a TensorRT engine using

trtexec --onnx=VideoSwinB_K400_IN1K_P244_W877_32x224.onnx --saveEngine=RT_engine_pytorch.trt --explicitBatch

it fails.
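
Before digging into the trtexec failure, it may help to confirm the exported graph is structurally valid and to list the op types it contains, since TensorRT op support depends on the TensorRT version (a minimal sketch using the onnx Python package; the filename is the one from the command above):

import onnx
from collections import Counter

onnx_model = onnx.load('VideoSwinB_K400_IN1K_P244_W877_32x224.onnx')
onnx.checker.check_model(onnx_model)  # raises if the graph itself is malformed

# Counting op types makes it easier to match unsupported ops against the trtexec log.
print(Counter(node.op_type for node in onnx_model.graph.node))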
