
SincNet vs. STFT #74

Closed · snakers4 opened this issue Nov 15, 2019 · 26 comments

@snakers4
Hi,

We are doing STT / TTS for the Russian language.

We have mostly used the STFT, due to our ignorance of DSP and our understanding that the MFCC filters used by everyone may be a bit over-engineered (I have seen no papers properly comparing them vs. the STFT and similar features for morphologically rich languages).

So, my question is as follows. I understand that your filters contain an order of magnitude fewer parameters than ordinary CNN layers and are, in essence, just DSP-inspired frequency filters.

In our experience we have tried a lot of separable convolutions, and we mostly agree with this paper (i.e. the success of mobile networks shows that convolutions are overrated; a mix layer + shift layer can do the job).

Here the STFT is implemented as a convolution (they inherit their kernels from numpy, but I guess they are similar in essence to the triangular filters from MFCC). So, my questions are:

  • Have you tried using STFT, but actually training the filters?
  • Have you tried using separable convolutions?

Best,
Alex

@mravanelli
Owner

Hi Alex,
thank you very much for your questions. As you have pointed out, in SincNet we are not using the STFT, but we are doing something related to it. In fact, you can see SincNet as a way to learn the STFT itself, where instead of using "fixed FFT points" we use learnable channels. If you reorder the channels according to their central frequency, you obtain a representation similar to the STFT (note that in SincNet we are not taking the magnitude of the STFT, but we have the full representation that includes the phase contribution). You can for sure learn some filters on top of the STFT (as done in some papers cited in my work), but in SincNet we decided to avoid the step of pre-computing the STFT and start from the lowest possible representation (i.e., the raw waveform). We haven't tried separable convolutions, but they could be worth trying.
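For intuition, here is a minimal sketch of such a learnable band-pass filter (a sketch of the idea, not the actual SincNet code, which also applies a Hamming window to the kernel):

import math
import torch

def sinc_bandpass(f1, f2, kernel_size, sample_rate=16000):
    # Band-pass kernel as the difference of two low-pass sinc filters:
    # g(t) = 2*f2*sinc(2*f2*t) - 2*f1*sinc(2*f1*t), with f1 < f2 in Hz.
    t = (torch.arange(kernel_size) - (kernel_size - 1) / 2) / sample_rate

    def lowpass(fc):
        x = 2 * math.pi * fc * t
        return 2 * fc * torch.where(x == 0, torch.ones_like(x), torch.sin(x) / x)

    return lowpass(f2) - lowpass(f1)

Learning just f1 and f2 per channel (2 values) instead of kernel_size free weights is where the parameter saving comes from.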

@snakers4
Author

snakers4 commented Nov 16, 2019

I guess at some point in the future I will report 3 experiments:

  • baseline stft
  • stft, but filters can be tuned
  • sincnet

Would be cool to compare

@antimora

@snakers4 I would be very interested to see your results. In particular, I was recently thinking about how one could adapt SincNet to apply the sinc operation on a spectrogram instead of directly on the waveform. This would be similar to the melspectrogram function of pytorch (https://pytorch.org/audio/transforms.html#melspectrogram), but with filterbanks that change depending on tunable params; see the sketch below.
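Something like this hypothetical sketch (the names and sizes are made up, and initialization from triangular mel filters is left out):

import torch
import torch.nn as nn

class LearnableMelLikeFB(nn.Module):
    # A mel-like filterbank with trainable weights, applied to a
    # precomputed magnitude spectrogram instead of the raw waveform.
    def __init__(self, n_freqs=161, n_mels=64):
        super().__init__()
        # Could be initialized from triangular mel filters instead of randomly.
        self.fb = nn.Parameter(torch.rand(n_mels, n_freqs) / n_freqs)

    def forward(self, spec):                # spec: (batch, n_freqs, time)
        return torch.matmul(self.fb, spec)  # -> (batch, n_mels, time)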

My current main issue is scaling SincNet on larger inputs (15 seconds of waveform). It takes way too much memory even to get to the max pool layer. I asked for suggestions on the PyTorch forum recently (still waiting for a reply): https://discuss.pytorch.org/t/attempting-to-reduce-memory-consumption-by-fusing-conv1d-with-maxpool1d/61448

@mravanelli
Owner

mravanelli commented Nov 19, 2019 via email

@mpariente

Hi,
If I may, I suggest reading this paper [disclaimer: I wrote it], which compares learnable, parametrized (such as the base layer of SincNet) and fixed filterbanks (such as the STFT) for speech separation.
The code for the three filterbanks is available in asteroid.filterbanks [disclaimer: I wrote it]. The SincNet filterbank is also extended to be analytic.

Have you tried using STFT, but actually training the filters?

This is exactly what the FreeFB and AnalyticFreeFB are doing in the code I linked to.

My current main issue is scaling SincNet on larger inputs (15 seconds of waveform). It takes way too much memory even to get to the max pool layer.

This is because the stride is very small. I can suggest using analytic filterbanks (from which a shift-invariant representation can be computed) with a larger stride to solve this problem.
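Rough arithmetic (my numbers, assuming 512 filters, float32, batch size 1) shows why the stride dominates the activation memory:

num_samples = 15 * 16000                 # 15 s at 16 kHz = 240000 samples
for stride in (1, 16, 128):
    frames = num_samples // stride
    mb = 512 * frames * 4 / 1e6          # one conv output, float32
    print(stride, frames, round(mb), 'MB')
# stride=1 -> ~492 MB per utterance; stride=128 -> ~4 MB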

@antimora

@mpariente, thank you for your comment and the links. Do you have any examples of how one can convert a waveform speech signal to features? I was trying to check out your latest project code but could not figure out how to use AnalyticFreeFB. Is the code complete?

@mpariente

@antimora Yes, the code for the filterbanks is complete, but I should write how to use it somewhere, sorry.

import torch
from asteroid.filterbanks import AnalyticFreeFB, Encoder
encoder = Encoder(AnalyticFreeFB(n_filters=512,
                                 kernel_size=256,
                                 stride=128))
waveform = torch.randn(8, 1, 32000)
features = encoder(waveform)  # (batch, 512, time)

I'll soon work on the docs but don't hesitate to open an issue there.

@snakers4
Author

We are finishing the current set of experiments, and dedicating a couple of weeks to comparing the "first convolutions" seems alluring, because I have not done it at all yet =)

Now we are just using this STFT implementation (for some reason it works faster on CPU than librosa; it also makes it easier for our models to work with TTS voices, because you can just train Tacotron to output the STFT and omit the vocoder for now =) ) with these params:

@attr.s(kw_only=True)
class AudioConfig(object):
    window_size = attr.ib(default=0.02)
    window_stride = attr.ib(default=0.01)
    sample_rate = attr.ib(default=16000)
    ....

self.n_fft = int(audio_config.sample_rate * (audio_config.window_size + 1e-8))
self.hop_length = int(audio_config.sample_rate * (audio_config.window_stride + 1e-8))

self.stft = STFT(self.n_fft,      # 320
                 self.hop_length, # 160
                 self.n_fft)      # 320

This produces 161-sized feature maps (n_fft // 2 + 1 = 161 frequency bins).

@mpariente

Many thanks for your code and paper.
I will try to understand the deeper logic behind it.

Meanwhile a couple of basic dumb questions to do some quick and dirty experiments.

but I should write how to use it somewhere, sorry.

Or I guess I can also use this constructor you provided here, right?
Also, I just pass params like n_filters / kernel_size / stride as kwargs, right?

your filters contain an order of magnitude fewer parameters
This is because the stride is very small. I can suggest using analytic filterbanks (from which a shift-invariant representation can be computed) with a larger stride to solve this problem.

Could you please give a very rough and dirty memory and compute comparison of the filter-banks available in your repo?

I.e., how do I make this table correct?

| Filterbank      | params | compute | memory |
|-----------------|--------|---------|--------|
| FreeFB / Conv1D | 1      | 1       | 1      |
| STFTFB          | 1      | 1       | 1      |
| AnalyticFreeFB  | ?      | ?       | ?      |
| ParamSincFB     | 0.1    | ?       | N      |

Also, could you maybe provide some guidance on the default params I should start with for each filter-bank, given our params above?

This is because the stride is very small. I can suggest using analytic filterbanks (from which a shift-invariant representation can be computed) with a larger stride to solve this problem.

Are you referring to the stride used in SincNet by default?

Have you tried using STFT, but actually training the filters?
This is exactly what the FreeFB and AnalyticFreeFB are doing in the code I linked to.

Judging from the code, FreeFB is just a Conv1D and STFTFB is just a Conv1D initialized with the STFT.
But what is AnalyticFreeFB?

@snakers4
Author

@mpariente

Also forgot to ask.
If I did not miss anything in your code, even if I use STFTFB, it will be trainable?
So to do a proper ablation test, I need to freeze it separately, right?

@mpariente

Now we are just using this STFT implementation (for some reason it works faster on CPU than librosa; it also makes it easier for our models to work with TTS voices, because you can just train Tacotron to output the STFT and omit the vocoder for now =) ) with these params:

@attr.s(kw_only=True)
class AudioConfig(object):
    window_size = attr.ib(default=0.02)
    window_stride = attr.ib(default=0.01)
    sample_rate = attr.ib(default=16000)
    ....

self.n_fft = int(audio_config.sample_rate * (audio_config.window_size + 1e-8))
self.hop_length = int(audio_config.sample_rate * (audio_config.window_stride + 1e-8))

self.stft = STFT(self.n_fft,      # 320
                 self.hop_length, # 160
                 self.n_fft)      # 320

This produces 161-sized feature maps.

The STFT implementation in asteroid is roughly equivalent. A few points where it differs:

  • The forward and backward bases are the same (this saves a little bit of space).
  • The output of the encoding is the concatenation of the real and imaginary parts instead of the magnitude and the phase. That's just a choice, and the two are equivalent.
  • The main difference is the following: I never divide by the window's envelope. I should run some proper experiments on that and write them up somewhere, because dividing by the window's envelope is done everywhere. Even though it enables unit-test-level perfect reconstruction, in my experience it hurts performance for speech enhancement and source separation applications by introducing amplitude-modulation artifacts into the resynthesis.
  • np.sqrt(hanning) is used by default; it has better reconstruction than hanning (if we don't divide by the window's envelope).
  • I truncate the STFT filters instead of padding them, so that I don't convolve with a lot of zeros in the case of an over-complete STFT (it should be a bit faster).

Overall, it is very similar.

If I did not miss anything in your code, even if I use STFTFB, it will be trainable?
So to do a proper ablation test, I need to freeze it separately, right?

It is not trainable: register_buffer is used, not register_parameter. A buffer is just an object that will be saved in the state_dict and transferred to CPU or GPU along with the model, but it doesn't appear in model.parameters() and is not updated by the optimizer.
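A quick way to see the difference:

import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('buf', torch.zeros(3))  # saved & moved with the model, not trained
        self.par = nn.Parameter(torch.zeros(3))      # updated by the optimizer

m = Demo()
print([name for name, _ in m.named_parameters()])  # ['par']
print(list(m.state_dict().keys()))                 # ['par', 'buf']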

@mpariente

Many thanks for your code and paper.
I will try to understand the deeper logic behind it.

Meanwhile a couple of basic dumb questions to do some quick and dirty experiments.

but I should write how to use it somewhere, sorry.

Or I guess I can also use this constructor you provided here, right?
Also, I just pass params like n_filters / kernel_size / stride as kwargs, right?

Yes, exactly. make_enc_dec will provide an encoder and a decoder. To be equivalent to your piece of code, you can use:

self.stft = make_enc_dec('stft', n_filters=self.n_fft,
                         kernel_size=self.n_fft,
                         stride=self.hop_length)[0]

This will produce 322 features (the output shape is in self.stft.n_feats_out); you can take the magnitude from this yourself, or use self.stft.post_process_inputs if you specified inp_mode='mag' in the encoder. Either way, it will produce the same 161-dim features.
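If you take the magnitude yourself, something along these lines should work (assuming the first 161 channels hold the real parts and the last 161 the imaginary parts):

feats = self.stft(x)                    # (batch, 322, time)
re, im = feats[:, :161], feats[:, 161:]
mag = torch.sqrt(re ** 2 + im ** 2)     # (batch, 161, time)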

Also maybe could you provide some guidance on the default params I should start with for each filter-bank given our params above?

You can use exactly the same parameters for all filterbanks, and it makes sense to do so; that's the nice thing about it. This way you can have a fair comparison between these filterbanks.

This is because the stride is very small. I can suggest using analytic filterbanks (from which a shift-invariant representation can be computed) with a larger stride to solve this problem.

Are you referring to the stride used in SincNet by default?

Yes.

Have you tried using STFT, but actually training the filters?
This is exactly what the FreeFB and AnalyticFreeFB are doing in the code I linked to.

Judging from the code, FreeFB is just a Conv1D and STFTFB is just a Conv1D initialized with the STFT.
But what is AnalyticFreeFB?

FreeFB will be a regular Conv1D without bias, that's right.
STFTFB is also a Conv1D, but the filters are the ones from the STFT (they are fixed, not learnable).
AnalyticFreeFB is a learned analytic filterbank. It also ends up being a Conv1D, but with analyticity constraints.
ParamSincFB is the analytic extension of the original SincNet filters; it is also a Conv1D.

Could you please give a very rough and dirty memory and compute comparison of the filter-banks available in your repo?

Regarding the table: if the same values are used for n_filters, kernel_size and stride, the compute and memory are the same for all of them; only the number of parameters per filter changes.

| Filterbank      | params/filter   |
|-----------------|-----------------|
| FreeFB / Conv1D | kernel_size     |
| STFTFB          | 0               |
| AnalyticFreeFB  | kernel_size / 2 |
| ParamSincFB     | 1               |
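Back-of-the-envelope, per the table above, for your n_fft = 320 settings (total learnable weights in the first layer; my arithmetic, not measured):

n_filters, kernel_size = 320, 320
print('FreeFB:        ', n_filters * kernel_size)       # 102400
print('STFTFB:        ', 0)                             # fixed buffers
print('AnalyticFreeFB:', n_filters * kernel_size // 2)  # 51200
print('ParamSincFB:   ', n_filters * 1)                 # 320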

@snakers4
Copy link
Author

snakers4 commented Nov 23, 2019

It is not trainable: register_buffer is used, not register_parameter. A buffer is just an object that will be saved in the state_dict and transferred to CPU or GPU along with the model, but it doesn't appear in model.parameters() and is not updated by the optimizer.

Following this logic, if I take the nvidia implementation or your implementation and substitute register_buffer with register_parameter, i.e.:

        if is_trainable:
            self._filters = nn.Parameter(filters)
        else:
            self.register_buffer('_filters', filters)

STFTFB | 0

This is zero because of register_buffer?

This will produce 322 features (the output shape is in self.stft.n_feats_out); you can take the magnitude from this yourself, or use self.stft.post_process_inputs if you specified inp_mode='mag' in the encoder. Either way, it will produce the same 161-dim features.

Our best practical models have a constant width of 512, so it would hardly make a difference in terms of complexity.

But for some reason, in all the STT implementations I have seen, the phases were always discarded. I wonder why?

@mpariente

Following this logic, if I take the nvidia implementation or your implementation and substitute register_buffer with register_parameter, i.e.:

        if is_trainable:
            self._filters = nn.Parameter(filters)
        else:
            self.register_buffer('_filters', filters)

Correct.

STFTFB | 0

This is zero because of register_buffer?

Correct. Note that the nvidia implementation will have weird behavior in this case for an over-complete STFT (i.e. n_filters > kernel_size), because the filters are padded with a lot of zeros, and these zeros will become parameters, so your effective kernel_size will be much bigger. If you always use kernel_size = n_filters, it will be fine.

This will produce 322 features (the output shape is in self.stft.n_feats_out); you can take the magnitude from this yourself, or use self.stft.post_process_inputs if you specified inp_mode='mag' in the encoder. Either way, it will produce the same 161-dim features.

Our best practical models have a constant width of 512, so it would hardly make a difference in terms of complexity.

But for some reason, in all the STT implementations I have seen, the phases were always discarded. I wonder why?

Because the phase has some non-local properties: it is only defined up to a global shift and modulo 2π, which is difficult to model with a DNN. The paper I linked to shows that modern DNNs can benefit from using the complex representation (depending on the size of the window) because the modelling capabilities are higher now.

@snakers4
Author

Many thanks for your explanations.

After we complete the current batch of experiments, I guess I will add the following to the queue:

  • (2 more long experiments now in the queue)
  • (current baseline) nvidia stft (CPU)
  • (new) STFTFB fixed (CPU or GPU)
  • (new) STFTFB learnable (GPU)
  • (new) FreeFB learnable (GPU)
  • (new) AnalyticFreeFB learnable (GPU)
  • (new) ParamSincFB learnable (GPU)
  • Also, for the best option I will try to fiddle with the number of channels. What if, for example, with SincNet we could afford 5x the filters (161 * 4-5 also looks kind of similar to the networks' hidden size)?

Each experiment works best when you run it for at least ~24 hours, so it will take between 1 and 2 weeks to test all of this. It will also be very interesting to see how transferring the first convolution from CPU to GPU affects speed and I/O: right now my networks are a bit slower than my I/O, but I guess there will be some trade-offs here.

@mravanelli
Owner

@snakers4, if you want, I can also share with you the tunable filter-banks that I'm designing for the SpeechBrain project. They are extremely compact, and these days I'm running several tests to check their performance.

@snakers4
Author

@mravanelli

Yeah, why not, since I will be doing a grid search anyway.
A friend of mine will probably also contribute something wavelet-related @pollytur

@mravanelli
Owner

mravanelli commented Nov 25, 2019 via email

@snakers4
Author

snakers4 commented Nov 26, 2019

@mpariente

Just to confirm that I am using your fbanks correctly:

from filterbanks import make_enc_dec
from filterbanks.pytorch_stft import STFT  ## nvidia stft

# in the init of the pytorch model
        if self._fbank_type != 'none':
            n_fft = int(sample_rate * (window_size + 1e-8))
            hop_length = int(sample_rate * (window_stride + 1e-8))
            fb_kwargs = {'inp_mode': 'mag',
                         'n_filters': n_fft,
                         'kernel_size': n_fft,
                         'stride': hop_length}
            if self._fbank_type == 'nvidia_stft_fixed':
                self.fbank = STFT(n_fft,
                                  hop_length,
                                  n_fft,
                                  is_trainable=False)
            elif self._fbank_type == 'nvidia_stft_trainable':
                self.fbank = STFT(n_fft,
                                  hop_length,
                                  n_fft,
                                  is_trainable=True)                                  
            elif self._fbank_type == 'stftfb_fixed':
                self.fbank, _ = make_enc_dec('stft',
                                             **{**fb_kwargs,
                                                'is_trainable': False})
            elif self._fbank_type == 'stftfb_trainable':
                self.fbank, _ = make_enc_dec('stft',
                                             **{**fb_kwargs,
                                                'is_trainable': True})
            elif self._fbank_type in ['free', 'analytic_free', 'param_sinc']:
                self.fbank, _ = make_enc_dec('free',
                                             **fb_kwargs)
            else:
                raise NotImplementedError()

# in forward
# x is (batch, num_samples)
        if self._fbank_type != 'none':
            if self._fbank_type in ['nvidia_stft_fixed', 'nvidia_stft_trainable']:
                x, phases = self.fbank.transform(x)  # the nvidia STFT returns (magnitudes, phases)
            elif self._fbank_type in ['stftfb_fixed', 'stftfb_trainable',
                                      'free', 'analytic_free', 'param_sinc']:
                batch_size = x.size(0)
                num_samples = x.size(1)                                      
                x = x.view(batch_size, 1, num_samples)
                x = self.fbank.post_process_inputs(self.fbank(x))
            else:
                raise NotImplementedError()

Also I have 2 types of augmentations for spectrograms:
(0) an augmentation mimicking phone calls, where I just discard 50% of my frequencies;
(1) something similar to SpecAugment (I just cut out random bands of frequencies / strips of time).

Do I understand correctly that I can use (1) with any type of your fbanks, but I can use (0) only with the STFT?
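For reference, (1) is roughly along these lines (a rough sketch, not our exact code):

import torch

def cutout_augment(spec, max_f=20, max_t=50):
    # Zero out one random frequency band and one random time strip
    # of a (batch, freq, time) spectrogram, SpecAugment-style.
    _, n_freq, n_time = spec.shape
    f = int(torch.randint(0, max_f + 1, (1,)))
    f0 = int(torch.randint(0, n_freq - f + 1, (1,)))
    t = int(torch.randint(0, max_t + 1, (1,)))
    t0 = int(torch.randint(0, n_time - t + 1, (1,)))
    spec = spec.clone()
    spec[:, f0:f0 + f, :] = 0
    spec[:, :, t0:t0 + t] = 0
    return spec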

@mpariente

mpariente commented Nov 26, 2019

Most of it looks fine, the last filterbank definition should be

    elif self._fbank_type in ['free', 'analytic_free', 'param_sinc']:
        self.fbank, _ = make_enc_dec(self._fbank_type,
                                     **fb_kwargs)

to take into account the filterbank type, right?
For the data augmentation, (1) does make sense and (0) doesn't, in the sense that discarding half the filters just amounts to dividing the number of filters by two, but all the frequencies will still be covered (this also holds for the learnable STFT).

@snakers4
Author

to take into account the filterbank type, right?

correct, just a typo

discarding half the filters just amounts to dividing the number of filters by two but all the frequencies will still be covered (also holds for the STFT).

Hm, correct me here, but if I always discard the "upper" half, i.e. stft[81:] = 0, isn't it equivalent to changing the effective sampling rate from 16 kHz (most STT applications) to 8 kHz (phone calls)? Bins 81-160 cover 4-8 kHz, so zeroing them keeps only the 0-4 kHz band, which is exactly the bandwidth of an 8 kHz phone signal.
This was the idea behind this

@mpariente

Yes, you're right. I meant the learnable STFT, not the fixed one, sorry. I edited the answer above to correct this.

@snakers4
Author

snakers4 commented Dec 10, 2019

I have completed my tests and started the learnable frontend experiments. It is too early to draw any firm conclusions yet, but:

  • Nvidia STFT fixed + GPU. It looks like running the STFT on the GPU provides some speed boost in my case, but I need to verify by running longer (at least ~24 hours per experiment). It looks like I am now bordering on being I/O bound, whereas in the past I was bound by my network (which is good).

  • Nvidia STFT trainable + GPU. It looks like when I start training my network from scratch with a trainable pre-initialized STFT from nvidia (I just made these changes), with all of the other params being equal, the network explodes after a couple of batches.

    • I have tried some things that have helped me in the past, so far without success:
        • lowering the LR for the fbank
        • not applying gradient clipping
        • decreasing the gradient clipping norm
        self.is_trainable = is_trainable
        if self.is_trainable:
            self.forward_basis = torch.nn.Parameter(forward_basis.float())
            self.inverse_basis = torch.nn.Parameter(inverse_basis.float())
        else:
            self.register_buffer('forward_basis', forward_basis.float())
            self.register_buffer('inverse_basis', inverse_basis.float())

...
 
        if self.is_trainable:
            forward_transform = F.conv1d(
                input_data,
                self.forward_basis,
                stride=self.hop_length,
                padding=0)
        else:   # looks like this was written for PyTorch 0.4
            forward_transform = F.conv1d(
                input_data,
                Variable(self.forward_basis, requires_grad=False),
                stride=self.hop_length,
                padding=0)
fbank cuda:1 55.69 0.0 0.23 1.14
after fbank cuda:1 13.99 -5.26 0.00 5.74
fbank cuda:0 55. 0.0 0.23 1.16
after fbank cuda:0 13.90 -5.30 -0.00 5.71

after one batch update
GPU-0 Epoch 1 [1/18082] Time 5.22 (5.22)        Data 2.75 (2.75)        Loss 37.69 (37.69)

fbank cuda:0 inf inf inf nan
after fbank cuda:0 nan nan nan nan
fbank cuda:1 inf inf inf nan
after fbank cuda:1 nan nan nan nan
WARNING: Working around NaNs in data
  • STFT FB fixed + GPU. Seems to work as well.

  • STFT FB trainable + GPU. It looks like when I use @mpariente's fbanks but set the STFT to be trainable, it does not explode / misbehave outright; I need more time to confirm. Also, if I do not set the fbank LR to be 10-100x lower than the main network LR, it explodes as well.

@mpariente
I think I will start by just comparing my current pipeline vs. the nvidia STFT on GPU vs. your STFT on GPU.

Also, there is a small issue (?) with your implementation.
I am running a multi-GPU setup with DataParallel / DistributedDataParallel,
and unless I apply this hack to your code, it fails to use the buffers properly (weights are not on GPU, but tensors are).
I guess it may be a problem for SincNet.

class _EncDec(SubModule):
    """ Base private class for Encoder and Decoder. """

    def get_filters(self):
        """ Returns filters or pinv filters depending on `is_pinv` attribute """
        if self.is_pinv:
            # return self.filters
            return self.compute_filter_pinv(self.filters)  # I changed this line
        else:
            return self.filterbank._filters

@mpariente

@snakers4 Thanks for the update!
Did you use any pseudo-inverse in your code?

I think it is related to this issue which was fixed. The code is a bit different now. Could you try with the current version and tell me if there is still a problem, please?

@snakers4
Author

snakers4 commented Dec 11, 2019

Did you use any pseudo-inverse in your code?

@mpariente

No, we are just doing STT, we just need spectrograms

I think it is related to this issue which was fixed.

Many thanks for the pointer, I will use this fix when running SincNet

First results

As for the first results, it looks like running the plain STFT on CPU vs. GPU does not make a real difference, while making the STFT trainable poses the challenges noted above.

I will tune for a bit longer to confirm; then I guess I will try stft_fb / free filterbank / sincnet (unless they also explode), in this order.

Val set 1

(actually 3 different val sets mashed together)
[chart omitted]

Val set 2

[chart omitted]

@snakers4
Author

I ran the above experiment for ~5-10 additional hours, but it did not change much.
Since I am using very random sampling for curricula, the charts above are hardly distinguishable
(when models are genuinely different, it is obvious even after 20-30 hours of training).

I also tried a few other things:

  • free_fb
  • param_sinc

With / without this normalization (which works well with the STFT):

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class AdaptiveAudioNormalization(nn.Module):
    def __init__(self, sigma=20, truncate=4.0):
        super(AdaptiveAudioNormalization, self).__init__()

        filter_ = self.get_gaus_filter1d(sigma, truncate)
        self.register_buffer('filter_', filter_)
        self.reflect = torch.nn.ReflectionPad1d(sigma * int(truncate))

    def forward(self, spect):
        spect = torch.log1p(spect * 1048576)
        mean = spect.mean(dim=1, keepdim=True)
        mean = self.reflect(mean)
        mean = F.conv1d(mean, self.filter_)
        mean_mean = mean.mean(dim=-1, keepdim=True)
        spect = spect.add(-mean_mean)
        return spect

    @staticmethod
    def get_gaus_filter1d(sigma, truncate=4.0):
        sd = float(sigma)
        lw = int(truncate * sd + 0.5)
        sigma2 = sigma * sigma
        x = np.arange(-lw, lw+1)
        phi_x = np.exp(-0.5 / sigma2 * x ** 2)
        phi_x = phi_x / phi_x.sum()
        return torch.FloatTensor(phi_x.reshape(1, 1, -1))
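Usage, assuming a (batch, freq, time) magnitude spectrogram:

norm = AdaptiveAudioNormalization()
spect = torch.rand(8, 161, 400)  # (batch, freq, time)
out = norm(spect)                # log-compressed and mean-normalized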

With varying LR / optimizer

  • SGD with LR 1e-3 / 1e-4 / 1e-5
  • ADAM with LR 1e-4 / 1e-5

For all combinations, the network exploded either outright or after several hundred batches.

I mostly did not change the rest of the pipeline, so maybe the current pipeline is over-optimized for the STFT and I need to start with some simpler experiments.

I have no direction other than to try all of these experiments on some very easy dataset (e.g. some small validation subset).

@mpariente

Thanks a lot for the detailed experiments and results.
I wonder, did you check what caused the models to explode? Do you get NaN losses or simply garbage outputs? If NaN losses, does it start with a NaN loss or a NaN gradient?
For the analytic free filterbank, when I worked at 16kHz with noisy inputs, I had to catch NaN gradients for it to work, and it ended up converging fine.
Also, for the free filterbank, it doesn't really make sense to take the magnitude, as it is not complex. But I'm not sure it's going to change anything.

@snakers4
Author

Do you get NaN losses or simply garbage outputs? If NaN losses, does it start with a NaN loss or a NaN gradient?

The obvious symptom is NaN losses, but I have not had time to check what exactly causes them.
For some setups (tunable STFT) I did check: the first convolution itself exploded.

fbank cuda:0 inf inf inf nan (max min mean std)
after fbank cuda:0 nan nan nan nan  (max min mean std)
WARNING: Working around NaNs in data (working around NaNs in losses)

For the analytic free filterbank, when I worked at 16kHz with noisy inputs, I had to catch NaN gradients for it to work, and it ended up converging fine.

There is gradient clipping in my pipeline; I set it to some high values (100-200), otherwise the network converges slowly (I arrived at this value after grid-searching). I work around NaNs in losses, but not explicitly in gradients.

Could you maybe share a snippet on how you do it for the gradients? Many thanks!
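In the meantime, I guess something like this in the training loop would do it (a rough sketch)?

loss.backward()
for p in model.parameters():
    if p.grad is not None:
        # zero out NaN / Inf gradient entries before the optimizer step
        p.grad.data[~torch.isfinite(p.grad.data)] = 0.0
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=100.0)
optimizer.step()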

Also, for the free filterbank, it doesn't really make sense to take the magnitude, as it is not complex. But I'm not sure it's going to change anything.

Yeah, basically we are taking 1/2 of the convolution output, but I believe the root cause is elsewhere.

with noisy inputs

Our data obviously is "in the wild" and noisy.
Maybe when we refocus on improving our models, I will run some experiments training on small and clean datasets.
