
operating with apex? #3

Open
tongjinle123 opened this issue Feb 11, 2020 · 1 comment
Labels
enhancement New feature or request

Comments

@tongjinle123

I am trying to use this implementation with Apex half-precision training, but it fails, reporting that it needs float rather than half:


File "/data/asr_v3/src/model/transformer_transducer/lightning_model.py", line 41, in training_step
joint_out, rnnt_loss = self.forward(feature, feature_length, target, target_length, cal_rnnt_loss=True)
File "/opt/conda/lib/python3.7/site-packages/apex/amp/_initialize.py", line 197, in new_fwd
applier(kwargs, input_caster))
File "/data/asr_v3/src/model/transformer_transducer/lightning_model.py", line 36, in forward
joint_out, rnnt_loss = self.transducer.forward(feature, feature_length, target, target_length, cal_rnnt_loss)
File "/data/asr_v3/src/model/transformer_transducer/transformer_transducer.py", line 79, in forward
rnn_t_loss = self.cal_transducer_loss(joint, ori_token, feature_length, ori_token_length)
File "/data/asr_v3/src/model/transformer_transducer/transformer_transducer.py", line 108, in cal_transducer_loss
log_probs=log_prob, labels=target.int(), frames_lengths=frame_length.int(), labels_lengths=target_length.int(), reduction='mean')
File "/opt/conda/lib/python3.7/site-packages/warp_rnnt/__init__.py", line 80, in rnnt_loss
costs = RNNTLoss.apply(log_probs, labels, frames_lengths, labels_lengths, blank)
File "/opt/conda/lib/python3.7/site-packages/warp_rnnt/__init__.py", line 16, in forward
blank=blank,
RuntimeError: xs must be a Float tensor (rnnt_loss at binding.cpp:42)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x47 (0x7fa72c18c687 in /opt/conda/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: rnnt_loss(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, int) + 0xf79 (0x7fa707c87389 in /opt/conda/lib/python3.7/site-packages/warp_rnnt/_C.cpython-37m-x86_64-linux-gnu.so)
frame #2: + 0x22ea7 (0x7fa707c9aea7 in /opt/conda/lib/python3.7/site-packages/warp_rnnt/_C.cpython-37m-x86_64-linux-gnu.so)
frame #3: + 0x232ee (0x7fa707c9b2ee in /opt/conda/lib/python3.7/site-packages/warp_rnnt/_C.cpython-37m-x86_64-linux-gnu.so)
frame #4: + 0x1fd11 (0x7fa707c97d11 in /opt/conda/lib/python3.7/site-packages/warp_rnnt/_C.cpython-37m-x86_64-linux-gnu.so)

frame #10: THPFunction_apply(_object*, _object*) + 0x8d6 (0x7fa7601b9e96 in /opt/conda/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #63: __libc_start_main + 0xf0 (0x7fa76fc35830 in /lib/x86_64-linux-gnu/libc.so.6)

@1ytic
Owner

1ytic commented Feb 12, 2020

Yes, the loss function is implemented only for float values. I need to generalize the implementation to other types. For now, you can convert the logits to float explicitly.
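A minimal sketch of that workaround, assuming PyTorch and a loss callable with warp_rnnt's `rnnt_loss` signature (the wrapper name and the idea of wrapping it are hypothetical, not part of the library): cast the possibly half-precision log-probabilities up to float32 right before the loss call.

```python
import torch


def rnnt_loss_fp32(rnnt_loss_fn, log_probs, labels,
                   frames_lengths, labels_lengths, **kwargs):
    """Hypothetical helper for half-precision training: the warp_rnnt
    kernel only accepts Float tensors, so cast the (possibly fp16)
    log-probs to float32 before calling the loss. `rnnt_loss_fn` is
    assumed to follow warp_rnnt.rnnt_loss's argument order."""
    return rnnt_loss_fn(
        log_probs.float(),        # fp16 -> fp32; a no-op if already float32
        labels.int(),
        frames_lengths.int(),
        labels_lengths.int(),
        **kwargs,
    )
```

Under Apex mixed precision, `log_probs` typically arrives as `torch.half`; `.float()` is a differentiable cast, so autograd will route the gradient back through it to the fp16 tensor automatically.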

@1ytic 1ytic added the enhancement New feature or request label Feb 12, 2020
Development

No branches or pull requests

2 participants