Shouldn't Prototypical loss use torch.cdist instead of F.pairwise_distance? #169

Open
theolepage opened this issue Jan 26, 2023 · 2 comments

@theolepage

Hello,

I have a question about the following line used in the Prototypical loss computation:

```python
output = -1 * (F.pairwise_distance(out_positive.unsqueeze(-1), out_anchor.unsqueeze(-1).transpose(0, 2)) ** 2)
```

From my understanding, `output` should be analogous to the cosine similarity matrix used for the Angular Prototypical loss, but based on Euclidean distances instead.

Thus, the output tensor should have shape $(N, N)$ (with $N$ the number of samples in the mini-batch), and the value at $(i, j)$ should be the squared Euclidean distance between sample $i$ of `out_positive` and sample $j$ of `out_anchor`.

However, `F.pairwise_distance` computes the element-wise distance between corresponding vectors of `out_positive` and `out_anchor`, not the distance between every pair of row vectors from the two sets, as `torch.cdist` does.

[Image: visualization of the difference between F.pairwise_distance and F.cosine_similarity]

As a result, the output shape will be $(N, D)$ (with $D$ the output dimension of the model), and the subsequent loss computation is not coherent.
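A minimal sketch of the shape difference (not from the repository; the sizes are illustrative, and it assumes a recent PyTorch where `F.pairwise_distance` reduces along the last dimension):

```python
import torch
import torch.nn.functional as F

N, D = 4, 3  # hypothetical batch size and embedding dimension
out_positive = torch.randn(N, D)
out_anchor = torch.randn(N, D)

# Current code: (N, D, 1) broadcast against (1, D, N) gives (N, D, N),
# and pairwise_distance reduces the last dimension, leaving (N, D).
output = -1 * (F.pairwise_distance(
    out_positive.unsqueeze(-1),
    out_anchor.unsqueeze(-1).transpose(0, 2)) ** 2)
print(output.shape)  # torch.Size([4, 3])

# torch.cdist computes the distance between every pair of rows: shape (N, N).
dists = torch.cdist(out_positive, out_anchor, p=2)
print(dists.shape)  # torch.Size([4, 4])
```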

Thanks.

@lbjcom added the question label on Jan 31, 2023
@AlexGranger-scn

Same confusion here. Have you solved this problem? Thanks for your reply!

@deGennesMarc commented on Mar 15, 2024:

Same issue. In the spirit of @theolepage's suggestion, I replaced the line with:

```python
output = -torch.cdist(out_positive, out_anchor, p=2).pow(2)
```

but as of today it does not work for me.
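For context, a minimal self-contained sketch of how this cdist-based variant would feed the usual cross-entropy step; the `label` construction is an assumption about the surrounding training code, not taken from the repository:

```python
import torch
import torch.nn as nn

N, D = 4, 3  # hypothetical batch size and embedding dimension
out_anchor = torch.randn(N, D)
out_positive = torch.randn(N, D)

# Entry (i, j) scores positive i against anchor (prototype) j; higher means closer.
output = -torch.cdist(out_positive, out_anchor, p=2).pow(2)  # shape (N, N)

# Each positive should be closest to its own anchor, i.e. the diagonal entries.
label = torch.arange(N)
loss = nn.CrossEntropyLoss()(output, label)
print(loss.item())
```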

Also, it seems to me the definition of the prototypical loss in the "In defence of metric learning" paper is wrong, as there should be a minus sign in front of the distances $S_{j,k}$ in the softmax.
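Concretely, with $S_{j,k}$ the squared Euclidean distance between query $j$ and prototype $k$, the corrected softmax would read (my reconstruction of the point above, not a formula quoted from the paper):

$$L_p = -\frac{1}{N} \sum_{j=1}^{N} \log \frac{e^{-S_{j,j}}}{\sum_{k=1}^{N} e^{-S_{j,k}}}$$

which is consistent with the code negating the squared distances before the cross-entropy step.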
