The loss value of softDTW is inf when choosing bandwidth values #18

Open
educationunion opened this issue Dec 14, 2021 · 1 comment

educationunion commented Dec 14, 2021

Hello @Maghoumi ,

I tried the example code you provided (shown below) and noticed that with bandwidth = 1 or 2, the loss values are all inf. Could you please help me solve this issue?

Sample code:

import torch
from soft_dtw_cuda import SoftDTW
batch_size, len_x, len_y, dims = 8, 15, 12, 5
x = torch.rand((batch_size, len_x, dims), requires_grad=True)
y = torch.rand((batch_size, len_y, dims))

sdtw = SoftDTW(use_cuda=False, gamma=0.1, bandwidth=2)
loss = sdtw(x, y)
loss
--- OUTPUT:
tensor([inf, inf, inf, inf, inf, inf, inf, inf], grad_fn=<_SoftDTWBackward>)

sdtw = SoftDTW(use_cuda=False, gamma=0.1, bandwidth=1)
loss = sdtw(x, y)
loss
--- OUTPUT:
tensor([inf, inf, inf, inf, inf, inf, inf, inf], grad_fn=<_SoftDTWBackward>)

Thank you very much.

Regards,

@Maghoumi
Owner

Thanks for posting an example @educationunion! This issue was previously reported in #8, but I didn't have a good minimal example to try. I had a quick look, and I think it is caused by the condition check around here, which fails to work as expected when the two sequences being compared have different lengths. This is perhaps the result of adding support for such sequences in (this commit).

I think the condition has to take the length of each sequence into account before skipping the calculation.
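
To illustrate the suspected cause, here is a minimal NumPy sketch of a banded soft-DTW recursion. It is not the repository's code: the function `soft_dtw`, the `length_aware` flag, and the rescaled-diagonal check are hypothetical names and choices made for this example, and the length-aware variant is only one possible way to account for unequal lengths. With the sizes from the report (15 and 12), a check based on `abs(i - j)` skips the final cell whenever the bandwidth is smaller than the length difference of 3, so the result stays inf.

```python
import numpy as np


def soft_dtw(D, gamma=0.1, bandwidth=0, length_aware=False):
    """Soft-DTW value for a cost matrix D of shape (N, M).

    bandwidth == 0 disables pruning. length_aware is a hypothetical
    switch for this sketch, not an option of the actual library.
    """
    N, M = D.shape
    R = np.full((N + 2, M + 2), np.inf)
    R[0, 0] = 0.0
    for j in range(1, M + 1):
        for i in range(1, N + 1):
            if length_aware:
                # Distance from the diagonal joining (1, 1) and (N, M).
                offset = abs((i - 1) * (M - 1) / max(N - 1, 1) - (j - 1))
            else:
                # Plain check: assumes both sequences have the same length.
                offset = abs(i - j)
            if 0 < bandwidth < offset:
                continue  # cell pruned, R[i, j] stays inf
            r = -np.array([R[i - 1, j - 1], R[i - 1, j], R[i, j - 1]]) / gamma
            rmax = r.max()
            if not np.isfinite(rmax):
                continue  # no reachable predecessor
            softmin = -gamma * (np.log(np.exp(r - rmax).sum()) + rmax)
            R[i, j] = D[i - 1, j - 1] + softmin
    return R[N, M]


rng = np.random.default_rng(0)
x = rng.random((15, 5))  # len_x = 15, dims = 5
y = rng.random((12, 5))  # len_y = 12
D = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)  # squared Euclidean costs

for bw in (1, 2, 3):
    print(bw, soft_dtw(D, bandwidth=bw), soft_dtw(D, bandwidth=bw, length_aware=True))
```

In this sketch the plain check returns inf for bandwidth 1 and 2 and a finite value for bandwidth 3, while the length-aware check returns a finite value in all three cases because the cell (N, M) always lies on the rescaled diagonal.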
