Yes, the loss function is implemented only for float values. I have to generalize the implementation to other types. For now, you can explicitly convert the logits to float.
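Here is a minimal sketch of that workaround. It assumes the joint network returns raw, unnormalized logits in half precision. The helper name `rnnt_loss_fp32` is hypothetical. Its keyword arguments (`log_probs`, `labels`, `frames_lengths`, `labels_lengths`, `reduction`) mirror the `warp_rnnt.rnnt_loss` call shown in the traceback in this thread. The loss function is passed in as a parameter, so the cast can be checked without the CUDA extension installed.

```python
import torch
import torch.nn.functional as F


def rnnt_loss_fp32(loss_fn, logits, labels, frames_lengths, labels_lengths, **kwargs):
    """Hypothetical wrapper: upcast half-precision logits before calling an RNN-T loss.

    Assumes `logits` are raw joint-network outputs; if your tensor is already
    log-softmaxed, call `.float()` on it directly instead of re-normalizing.
    """
    # Upcast to float32 first so log_softmax and the loss kernel both run in fp32.
    # The cast is differentiable, so gradients still flow back to the fp16 tensor.
    log_probs = F.log_softmax(logits.float(), dim=-1)
    return loss_fn(
        log_probs=log_probs,
        # Cast labels and lengths to int32, matching the call in the traceback.
        labels=labels.int(),
        frames_lengths=frames_lengths.int(),
        labels_lengths=labels_lengths.int(),
        **kwargs,
    )
```

With warp_rnnt installed, the call site would look roughly like `rnnt_loss_fp32(warp_rnnt.rnnt_loss, joint, target, frame_length, target_length, reduction='mean')`. Only the loss computation runs in float32, and the rest of the model stays in half precision.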
I am trying to use this implementation with apex half-precision training, but it fails.
The error shows that it needs float rather than half:
File "/data/asr_v3/src/model/transformer_transducer/lightning_model.py", line 41, in training_step
joint_out, rnnt_loss = self.forward(feature, feature_length, target, target_length, cal_rnnt_loss=True)
File "/opt/conda/lib/python3.7/site-packages/apex/amp/_initialize.py", line 197, in new_fwd
applier(kwargs, input_caster))
File "/data/asr_v3/src/model/transformer_transducer/lightning_model.py", line 36, in forward
joint_out, rnnt_loss = self.transducer.forward(feature, feature_length, target, target_length, cal_rnnt_loss)
File "/data/asr_v3/src/model/transformer_transducer/transformer_transducer.py", line 79, in forward
rnn_t_loss = self.cal_transducer_loss(joint, ori_token, feature_length, ori_token_length)
File "/data/asr_v3/src/model/transformer_transducer/transformer_transducer.py", line 108, in cal_transducer_loss
log_probs=log_prob, labels=target.int(), frames_lengths=frame_length.int(), labels_lengths=target_length.int(), reduction='mean')
File "/opt/conda/lib/python3.7/site-packages/warp_rnnt/__init__.py", line 80, in rnnt_loss
costs = RNNTLoss.apply(log_probs, labels, frames_lengths, labels_lengths, blank)
File "/opt/conda/lib/python3.7/site-packages/warp_rnnt/__init__.py", line 16, in forward
blank=blank,
RuntimeError: xs must be a Float tensor (rnnt_loss at binding.cpp:42)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x47 (0x7fa72c18c687 in /opt/conda/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: rnnt_loss(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, int) + 0xf79 (0x7fa707c87389 in /opt/conda/lib/python3.7/site-packages/warp_rnnt/_C.cpython-37m-x86_64-linux-gnu.so)
frame #2: + 0x22ea7 (0x7fa707c9aea7 in /opt/conda/lib/python3.7/site-packages/warp_rnnt/_C.cpython-37m-x86_64-linux-gnu.so)
frame #3: + 0x232ee (0x7fa707c9b2ee in /opt/conda/lib/python3.7/site-packages/warp_rnnt/_C.cpython-37m-x86_64-linux-gnu.so)
frame #4: + 0x1fd11 (0x7fa707c97d11 in /opt/conda/lib/python3.7/site-packages/warp_rnnt/_C.cpython-37m-x86_64-linux-gnu.so)
frame #10: THPFunction_apply(_object, _object) + 0x8d6 (0x7fa7601b9e96 in /opt/conda/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #63: __libc_start_main + 0xf0 (0x7fa76fc35830 in /lib/x86_64-linux-gnu/libc.so.6)