Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add argument types to be able to use torch JIT #54

Open
wants to merge 3 commits into
base: master
Choose a base branch
from

Conversation

AdrianEddy
Copy link

This PR adds types to function signatures to be able to use torch.jit or torch.onnx.export(). I also had to convert some functions to modules

The code should be equivalent to the previous one, I verified that with inference (I didn't test training though)

It's easiest to review this without whitespace diff

Related to #29

@ylab604
Copy link

ylab604 commented May 8, 2024

@AdrianEddy Thank you for great works!
I try to inference with your code. but ,

Traceback (most recent call last):
File "main_stereo.py", line 612, in
main(args)
File "main_stereo.py", line 331, in main
inference_stereo(model_without_ddp,
File "anaconda3/envs/torch/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "unimatch/evaluate_stereo.py", line 799, in inference_stereo
pred_disp = model(left, right,
File "anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "unimatch/unimatch/unimatch.py", line 190, in forward
feature0, feature1 = self.transformer(feature0, feature1,
File "anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "unimatch/unimatch/transformer.py", line 272, in forward
shifted_window_attn_mask_1d = self.generate_shift_window_attn_mask_1d(
File "anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "unimatch/unimatch/utils.py", line 207, in forward
mask_windows = window_partition_1d(img_mask, window_size_w) # nW, window_size, 1
File "unimatch/unimatch/utils.py", line 193, in window_partition_1d
B, W, C = x.shape
ValueError: too many values to unpack (expected 3)

@AdrianEddy
Copy link
Author

@ylab604 Please check now

@ylab604
Copy link

ylab604 commented May 10, 2024

I see i did change the mask function when i check your code(yesterday).
But, important thing is that onnx graph(netron)is not normal campare with pinto0309

@ylab604
Copy link

ylab604 commented May 10, 2024

Anyway thank you for your kindness. And i will also update the result of excution

@AdrianEddy
Copy link
Author

What do you mean it's not normal? What's weird about it?

@ylab604
Copy link

ylab604 commented May 10, 2024

What do you mean it's not normal? What's weird about it?

This means that if converted to onnx or jit, the inference output will be different from the original torch model.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants