Create vae_mnist_new_architecture.jl #487

Open: NikonPic wants to merge 3 commits into master
Conversation

NikonPic

Proposal for using the VAE MNIST example with the newer API from Knet. Type definitions allow increased performance (~20% faster due to lower GC time) and better readability of the network architecture.
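For context, the "newer API" refers to Knet's callable-struct style of defining models. A minimal sketch of that style is shown below; it follows Knet's own tutorial layers, and the names Dense and Chain are illustrative rather than taken from the PR's file:

```julia
using Knet

# Illustrative sketch of Knet's callable-struct ("newer API") style.
# A layer is a struct holding its parameters; calling the struct runs it.
struct Dense; w; b; f; end
Dense(i::Int, o::Int, f=relu) = Dense(param(o, i), param0(o), f)
(d::Dense)(x) = d.f.(d.w * mat(x) .+ d.b)

# A Chain simply applies its layers in sequence.
struct Chain; layers; Chain(layers...) = new(layers); end
(c::Chain)(x) = (for l in c.layers; x = l(x); end; x)
```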
```julia
function train(ae, dtrn, iters)
    img = convert(Atype, reshape(dtrn.x[:, 1], (28, 28, 1, 1)))
    for epoch = 1:iters
        @time adam!(ae, dtrn)
```
Collaborator
If I'm not wrong, this is not the correct way to iterate over epochs: each call to adam! creates a new Adam struct, so information from previous epochs (e.g. the accumulated moments) is lost.
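For reference, a minimal sketch of a loop that keeps a single optimizer state across all epochs, using Knet's iterator-style adam/progress! API; this is an illustration of the point above, not the fix that was actually committed:

```julia
using Knet

# Sketch only: the Adam iterator is constructed once over all epochs,
# so the accumulated first/second moments persist between epochs.
function train(ae, dtrn, iters)
    progress!(adam(ae, repeat(dtrn, iters)))
end
```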

Author

Oh yes, completely correct! Thanks for the advice, I have adapted my example proposal accordingly.

Added the proper training setup, a more detailed callback, and improved type definitions.

```julia
BCE = F(0)

for s = 1:samples
```
Collaborator

This is probably not efficient; you can run all samples at once with an additional "sample" batching dimension. First reshape μ to (nz, B, 1), then sample from randn with size (nz, B, Nsample) and broadcast μ onto it. Then change binary_cross_entropy to handle (nz, B, Nsample) input.
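A minimal sketch of that batched reparameterization, assuming μ and logσ of size (nz, B) coming from the encoder and the Atype array type used elsewhere in the example; the names Nsample and logσ are illustrative:

```julia
nz, B   = size(μ)
Nsample = 10                                    # illustrative sample count
μ3 = reshape(μ, nz, B, 1)                       # (nz, B, 1)
σ3 = reshape(exp.(logσ), nz, B, 1)              # (nz, B, 1)
ε  = convert(Atype, randn(Float32, nz, B, Nsample))
z  = μ3 .+ σ3 .* ε                              # broadcasts to (nz, B, Nsample)
# binary_cross_entropy (and the decoder) would then have to accept the extra
# sample dimension, e.g. by reshaping z to (nz, B * Nsample) before decoding
# and averaging the loss over the sample dimension afterwards.
```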

Author

Correct, the form was not efficient for sampling multiple times within one batch. The suggestion to broadcast is more efficient and much faster. However, I was not able to broadcast this efficiently through the decoder network. As the sampling doesn't increase performance as far as I can tell, and the majority of implementations I found do not use it, I have abandoned it for this example.

Multiple sampling has been removed as it is also not used in the original VAE approach.