
Learning end2end with a neural network #17

Open
jonnor opened this issue Aug 20, 2021 · 3 comments

Comments

jonnor commented Aug 20, 2021

Hi, thank you for this nice project.

Could one connect a neural network to the NMF module and learn them at the same time? Any example code or tips on how to do that?
I am interested in using a convolutional neural network frontend on spectrogram data, to capture somewhat more complex activations than single stationary spectrogram frames.

yoyololicon (Owner) commented Aug 21, 2021

Hi @jonnor,

Could one connect a neural network to the NMF module and learn them at the same time? Any example code or tips on how to do that?
I am interested in using a convolutional neural network frontend on spectrogram data, to capture somewhat more complex activations than single stationary spectrogram frames.

I do plan to add some examples as Jupyter notebooks, but I'm currently busy with other projects.
Your application sounds totally doable to me, but you have to make sure that all the gradients that flow from the loss to the NMF parameters are always non-negative.

For example, if you want to train a model that predicts the activations while jointly learning a shared non-negative template, you can do something like this:

import torch
from torch import nn
from torch import optim
from torchnmf.trainer import BetaMu
from torchnmf import NMF


# pick an activation function so the output of the neural net is non-negative
H = nn.Sequential(AnotherModel(), nn.Softplus())
W = NMF(W=(out_channels, in_channels))

optimizer = optim.Adam(H.parameters())
trainer = BetaMu(W.parameters())

for x, y in dataloader:
    # optimize the NMF template with multiplicative updates
    def closure():
        trainer.zero_grad()
        with torch.no_grad():
            h = H(x)
        return y, W(H=h)
    trainer.step(closure)

    # optimize the neural net
    h = H(x)
    predict = W(H=h)
    loss = ...  # you can use other types of loss here
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
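
As an illustration of the AnotherModel placeholder above, here is a minimal sketch of a convolutional frontend for spectrogram input; the SpectrogramFrontend name, layer sizes, and input/output shapes are only assumptions, and the output shape has to match whatever the NMF module expects:

import torch
from torch import nn

class SpectrogramFrontend(nn.Module):
    # hypothetical CNN mapping a spectrogram to activations over NMF components
    def __init__(self, n_freq_bins, n_components, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            # treat frequency bins as channels and convolve over time
            nn.Conv1d(n_freq_bins, hidden, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden, n_components, kernel_size=5, padding=2),
        )

    def forward(self, spec):
        # spec: (batch, n_freq_bins, n_frames) -> (batch, n_components, n_frames)
        return self.net(spec)

# wrap with Softplus, as in the snippet above, so the activations stay non-negative
H = nn.Sequential(SpectrogramFrontend(n_freq_bins=513, n_components=16), nn.Softplus())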

jonnor (Author) commented Aug 31, 2021

Hi @yoyololicon - thank you for the response and example code! In this framework, would the loss be something that compares the output of the NMF (decomposed and re-composed) against the target? Like RMS as a simple case, or a perceptual metric for something more advanced?

yoyololicon (Owner) commented Aug 31, 2021

Hi @yoyololicon - thank you for the response and example code! In this framework, would the loss be something that compares the output of the NMF (decomposed and re-composed) against the target? Like RMS as a simple case, or a perceptual metric for something more advanced?

@jonnor
Yes, in the above code you are free to use those kinds of loss functions, not only the beta divergence.
The NMF part is still trained with the beta divergence, though.
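
For instance, a minimal sketch of the neural-net update with a plain MSE reconstruction loss, reusing H, W, x, y, and optimizer from the snippet above (MSE is just an illustrative choice; any differentiable loss between the target and the reconstruction works here):

import torch.nn.functional as F

h = H(x)                         # activations predicted by the neural network
predict = W(H=h)                 # reconstruction from the shared NMF template
loss = F.mse_loss(predict, y)    # compare reconstruction against the target
loss.backward()
optimizer.step()
optimizer.zero_grad()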
