
Attacks in Jax only support cross entropy loss #1184

Open
ZuowenWang0000 opened this issue Dec 18, 2020 · 2 comments

@ZuowenWang0000

The fast_gradient_method in the JAX implementation currently uses cross-entropy loss by default for crafting adversarial examples:

loss = - np.sum(logsoftmax(pred) * label)

It is apparently not always correct to assume that people are using cross-entropy loss.

Describe the solution you'd like
The most straightforward solution would be to pass the loss function being used as an extra parameter to both the fgsm and pgd functions. This would also be consistent with the attacks implemented in other frameworks, such as the TensorFlow implementation; a sketch of the idea follows below.
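
A minimal sketch of what this could look like, assuming a simplified fast_gradient_method signature; the loss_fn parameter and its cross-entropy default are illustrative, not the existing CleverHans API:

import jax
import jax.numpy as jnp

def fast_gradient_method(model_fn, x, eps, norm, y, loss_fn=None, targeted=False):
    # Hypothetical extra parameter: loss_fn defaults to the current
    # cross-entropy behavior so that existing callers are unaffected.
    if loss_fn is None:
        def loss_fn(logits, labels):
            return -jnp.sum(jax.nn.log_softmax(logits) * labels)

    def adv_loss(inputs):
        loss = loss_fn(model_fn(inputs), y)
        # Targeted attacks minimize the loss toward the target label.
        return -loss if targeted else loss

    grads = jax.grad(adv_loss)(x)
    if norm == jnp.inf:
        perturbation = eps * jnp.sign(grads)
    elif norm == 2:
        perturbation = eps * grads / jnp.linalg.norm(grads)
    else:
        raise ValueError("only L-inf and L2 are handled in this sketch")
    return x + perturbation

The same loss_fn argument could then be threaded through projected_gradient_descent, which applies fast_gradient_method at every iteration.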

An alternative would be to pass, instead of the predict function, a model object that has both the predict function and the loss function registered, as in the sketch below.
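
A minimal sketch of the model-object idea; the class and attribute names here are hypothetical:

import jax

class AttackableModel:
    # Hypothetical wrapper pairing a predict function with the loss
    # that should be used when crafting adversarial examples.
    def __init__(self, predict_fn, loss_fn):
        self.predict = predict_fn
        self.loss = loss_fn

def adv_loss(model, x, y):
    # Attacks would differentiate this with jax.grad instead of
    # hard-coding cross-entropy internally.
    return model.loss(model.predict(x), y)

# Gradient of the registered loss with respect to the input x.
grads_fn = jax.grad(adv_loss, argnums=1)

The parameter approach keeps the current functional API, while the model object centralizes the loss so that every attack stays consistent with how the model was trained.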

@jonasguan
Collaborator

Thanks for the suggestion @ZuowenWang0000! If you can submit a PR with your proposed changes, we'd be glad to review and merge it.

@jonasguan jonasguan reopened this Jan 19, 2021
@jonasguan
Collaborator

Oops, did not mean to close this yet.
