Strange behavior using PyTorch DDP #32

Open
snakers4 opened this issue Jan 13, 2022 · 7 comments

@snakers4

@1ytic
Hi,

So far I have been able to use the loss with DDP on a single GPU, and it behaves more or less as expected.

But when I use more than 1 device, the following happens:

  • On GPU-0 loss is calculated properly
  • On GPU-1 loss is close to zero for each batch

I checked the input tensors, devices, tensor values, etc. So far everything seems to be identical on GPU-0 and the other GPUs.
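
A minimal diagnostic sketch for the per-rank check described above, assuming the symptom is related to tensor placement (which this thread does not confirm). The helper names `devices_by_name` and `find_device_mismatches` are hypothetical and not part of warp-rnnt or PyTorch; they rely only on each argument having a `.device` attribute, as `torch.Tensor` does.

```python
from typing import Any, Dict


def devices_by_name(**tensors: Any) -> Dict[str, str]:
    """Map each keyword argument name to the string form of its `.device`."""
    return {name: str(t.device) for name, t in tensors.items()}


def find_device_mismatches(expected: str, **tensors: Any) -> Dict[str, str]:
    """Return only the arguments whose device differs from `expected`."""
    return {
        name: device
        for name, device in devices_by_name(**tensors).items()
        if device != expected
    }
```

Inside each DDP process this could be called right before the loss, for example `find_device_mismatches(f"cuda:{local_rank}", log_probs=log_probs, labels=labels, frames_lengths=frames_lengths, labels_lengths=labels_lengths)`, where `local_rank` and the tensor names are placeholders for whatever the training script uses. An empty result only shows that all inputs sit on the expected device. It does not rule out other causes.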

@snakers4
Author

@burchim
By the way, since you have used this loss, did you encounter anything like this in your work?

@burchim

burchim commented Jan 13, 2022

Hi @snakers4!
Yes, I had a similar problem with 4 GPU devices, where the RNN-T loss was computed properly on the first device but was 0 on the others. I don't remember the exact cause, but it had something to do with tensor devices, maybe the frame / label lengths.

I also recently experimented with replacing it with the official torchaudio.transforms.RNNTLoss from torchaudio 0.10.0.
It was working very well, but I didn't try a full training run with it.

@snakers4
Author

Thanks for the heads up about the torchaudio loss!
I remember seeing it some time ago, but I totally forgot about it.

@snakers4
Author

@burchim
By the way, did you get RuntimeError: input length mismatch when migrating from warp-rnnt to torchaudio?

@burchim

burchim commented Jan 13, 2022

Yes, this means that the logit / target lengths tensors do not match the logits / targets tensors,
for instance if you have logit lengths that are longer than your logits tensor.
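
The mismatch described above can be sketched as a pre-flight check. This is an illustration, not torchaudio's own validation: the helper `check_rnnt_lengths` is hypothetical, and it assumes the documented torchaudio layout of logits as (batch, time, max target length + 1, classes) and targets as (batch, max target length), with each batch padded exactly to its longest item.

```python
from typing import List, Sequence


def check_rnnt_lengths(
    logits_shape: Sequence[int],
    targets_shape: Sequence[int],
    logit_lengths: Sequence[int],
    target_lengths: Sequence[int],
) -> List[str]:
    """Return human-readable problems; an empty list means the inputs agree."""
    batch, max_t, max_u_plus_1 = logits_shape[0], logits_shape[1], logits_shape[2]
    problems: List[str] = []

    # Each length sequence needs exactly one entry per batch item.
    if len(logit_lengths) != batch:
        problems.append(f"logit_lengths has {len(logit_lengths)} entries, batch is {batch}")
    if len(target_lengths) != batch:
        problems.append(f"target_lengths has {len(target_lengths)} entries, batch is {batch}")

    # Assumption: the longest logit length equals the padded time dimension.
    longest_t = max(logit_lengths, default=0)
    if longest_t != max_t:
        problems.append(f"max(logit_lengths)={longest_t}, logits time dimension is {max_t}")

    # Assumption: the longest target length + 1 equals the third logits dimension.
    longest_u = max(target_lengths, default=0)
    if longest_u + 1 != max_u_plus_1:
        problems.append(f"max(target_lengths)+1={longest_u + 1}, logits target dimension is {max_u_plus_1}")

    # No target length may exceed the width of the padded targets tensor.
    if longest_u > targets_shape[1]:
        problems.append(f"max(target_lengths)={longest_u}, targets width is {targets_shape[1]}")

    return problems
```

With real tensors the arguments would be something like `tuple(logits.shape)`, `tuple(targets.shape)`, `logit_lengths.tolist()` and `target_lengths.tolist()`. Passing the target lengths in place of the logit lengths shows up as a `logit_lengths` problem under these assumptions.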

@burchim

burchim commented Jan 13, 2022

In my case it happened because I used the target lengths instead of the logit lengths. Stupid error.

@csukuangfj

Thanks for the heads up about the torchaudio loss!

@snakers4
You may find https://github.com/danpovey/fast_rnnt useful.
