Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation - Find and Fill Method for Dropout Layer #3684

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/mlpack/methods/ann/layer/dropout_impl.hpp
Expand Up @@ -87,7 +87,9 @@ void DropoutType<MatType>::Forward(const MatType& input, MatType& output)
// Scale with input / (1 - ratio) and set values to zero with probability
// 'ratio'.
mask.randu(input.n_rows, input.n_cols);
mask.transform([&](double val) { return (val > ratio); });
arma::uvec indices = arma::find(mask > ratio);
mask.zeros();
mask.elem(indices).fill(1);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
mask.elem(indices).fill(1);
mask.elem(find(mask > ratio)).fill(1);

Would this work? It avoids the use of arma:: here, and would make this function fully general. But I think then we would need to find a way to get rid of the .zeros() call... maybe can we just do mask = (mask > ratio)? Or something along those lines? (It would be interesting to see the speed of that approach.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rcurtin You can take a look at my implementation in the latest commit. It should be a bit faster and is fully functional. I avoided using arma:: to keep it more general.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, it looks good to me! Thanks for doing that.

output = input % mask * scale;
}
}
Expand Down