nan in SRU output #185

Open
ksopyla opened this issue May 7, 2021 · 10 comments

@ksopyla

ksopyla commented May 7, 2021

I'm trying to train an RNN (seq2seq) with GRU and SRU cells. When training with GRU everything is fine: the loss decreases and accuracy steadily rises. But when I switch to SRU, after a few hours I get NaN in the loss, and the norm of the network parameters (hidden states, weight matrices) is NaN.

I use the https://github.com/asappresearch/sru/tree/3.0.0-dev branch and train on 2 GeForce 3090 GPUs with CUDA 11.1, PyTorch 1.8, and pytorch-lightning.

Could you point me in the right direction for diagnosing this? Is it a bug in SRU itself, something GPU related, or CUDA?

SRU is defined as:

        self.rnn = SRU(
            emb_dim,
            enc_hid_dim,
            num_layers=4,
            dropout=0.0,  # dropout applied between RNN layers
            bidirectional=True,  # bidirectional RNN
            layer_norm=True,  # apply layer normalization on the output of each layer
            normalize_after=True,  # if True use post layer norm, else use pre layer norm
            highway_bias=0,  # initial bias of highway gate (<= 0)
            rescale=True,  # whether to use scaling correction
        )

I have also noticed a lack of speedup relative to GRU; GRU is even faster (2.6 it/s vs 1.4 it/s).

@taoleicn
Contributor

hi @ksopyla

Thank you for trying SRU.
I suspect the training encountered gradient explosion. Here are some changes that can possibly make the training more stable (a sketch applying them follows the list):

  • switch to pre-layer norm, normalize_after=False
  • disable rescaling, rescale=False
  • add a small L2 regularization / weight decay to the training
  • for deeper models we usually use highway_bias=-2, but you use 4 layers, so I'm not sure if this would help.
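
Roughly, the constructor would then look like this (emb_dim / enc_hid_dim and the AdamW optimizer here are just placeholders for illustration, not tested):

import torch
from sru import SRU

emb_dim, enc_hid_dim = 256, 512   # placeholder sizes

rnn = SRU(
    emb_dim,
    enc_hid_dim,
    num_layers=4,
    dropout=0.0,
    bidirectional=True,
    layer_norm=True,
    normalize_after=False,   # pre-layer norm
    highway_bias=0,          # try -2 for deeper stacks
    rescale=False,           # disable the scaling correction
)

# small L2 regularization via decoupled weight decay on the optimizer
optimizer = torch.optim.AdamW(rnn.parameters(), lr=1e-3, weight_decay=1e-2)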

Re: speedup. It is hard to diagnose the speed without knowing what the task is and how the training is implemented. Usually SRU runs significantly faster when each forward() call processes multiple tokens and multiple sequences at once, instead of one token per sequence per forward() call; see the sketch below.
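
For illustration only (the sizes are made up, and the per-token loop is just to show the contrast):

import torch
from sru import SRU

rnn = SRU(256, 512, num_layers=4).cuda()

# fast: one forward() over all timesteps and sequences at once
x = torch.randn(128, 32, 256).cuda()   # (seq_len, batch, input_dim)
out, state = rnn(x)                    # out: (seq_len, batch, hidden_dim)

# slow: one token per call keeps the CUDA kernel under-utilized
state = None
for t in range(x.size(0)):
    out_t, state = rnn(x[t:t + 1], state)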

@nicolaspanel

Thanks @taoleicn for your answer and for the great work!
I'm facing the same issue as @ksopyla: training works fine for many hours before NaNs occur.
NOTE: the problem seems related to the forward pass, since at first only a single sample has a NaN loss.
I tried all 4 options but none of them worked.
I'm using torch==1.8.1+cu111 and git+https://github.com/asappresearch/sru.git@9ddc8da12f067125c2cfdd4f3b28a87c02889681

@nicolaspanel

FYI, I switched to sru==2.5.1 and it seems to work fine now

@taoleicn
Copy link
Contributor

taoleicn commented Jun 2, 2021

hi @nicolaspanel,
I want to troubleshoot the issue you're experiencing. Could you provide more details?
E.g. are you using SRU or SRUpp? Are you using mixed precision / fp16 training? Any observations or hints suggesting the problem occurs in the forward pass rather than the backward pass?

Thank you!

@nicolaspanel

nicolaspanel commented Jun 2, 2021

FYI, I switched to sru==2.5.1 and it seems to work fine now

After more investigation, NaNs also happen with sru==2.5.1.

hi @nicolaspanel,
I want to troubleshoot the issue you're experiencing. Could you provide more details?

Of course 👍

e.g. are you using SRU or SRUpp?

I'm using SRU with

  • dropout=0.1
  • layer_norm=True
  • rescale=True
  • amp_recurrence_fp16=True,

are you using mixed precision / fp16 training?

Yes, using the built-in amp.autocast()

any observation or hints suggesting the problem occurs in the forward pass instead of backward?

Thank you!

Since yesterday I have found a way to reproduce and investigate, by saving a model checkpoint as soon as the loss contains a NaN value and before calling loss.backward().
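
The detection logic is nothing fancy, roughly like this (net, criterion and the nan_repro.pt file name are specific to my setup):

import torch
from torch.cuda import amp

def training_step(net, criterion, x, target):
    with amp.autocast():
        y = net(x)
        loss = criterion(y, target)

    # dump everything needed to replay the batch *before* backward()
    # can push NaNs into the parameters
    if not torch.isfinite(loss):
        torch.save(
            {"model": net.state_dict(), "x": x.cpu(), "target": target.cpu()},
            "nan_repro.pt",
        )
        raise RuntimeError("non-finite loss, checkpoint saved to nan_repro.pt")

    return loss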

Point of interest 1: only one sample has loss=nan; the other batch samples have finite values.

Point of interest 2: checkpoint parameters are fine (no infinite/nan values)

When I execute the following code:

from torch.cuda import amp

with amp.autocast():
    y = net(x)

y contains nan values

Point of interest 3: y has nan values only for some timesteps (time dimension is in the vertical axis in the picture below)
[image: heatmap of the network output for a single sample; time on the vertical axis, features on the horizontal axis]

Point of interest 4: all outputs are finite if amp is disabled, which indicates a numerical instability issue

Point of interest 5: no nan values when normalization is disabled, using net.eval() (may not be relevant though)

Point of interest 6: hidden feature values include some very large entries (+100, +200, -95, etc…), especially in the last layer
=> probably the cause of the numerical instabilities
=> weird though, since layer_norm and weight_decay=1e-3 are used during training

Possible workaround: I will try to clip values so they stay in the range [-10, 10], roughly as in the sketch below.
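
Something like this wrapper, just to show where the clamp would go (ClampedSRU is only an illustrative name, untested):

import torch
import torch.nn as nn
from sru import SRU

class ClampedSRU(nn.Module):
    """SRU block whose outputs are clamped so fp16 activations cannot blow up."""

    def __init__(self, input_dim, hidden_dim, clip=10.0, **kwargs):
        super().__init__()
        self.rnn = SRU(input_dim, hidden_dim, **kwargs)
        self.clip = clip

    def forward(self, x, state=None):
        out, state = self.rnn(x, state)
        return torch.clamp(out, -self.clip, self.clip), state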

Hope this helps.
Best regards

@taoleicn
Contributor

taoleicn commented Jun 2, 2021

Thank you @nicolaspanel !

Would disabling rescale and/or amp_recurrence_fp16 help?

I've lately found rescale to cause instability issues, so we've made it False by default now. amp_recurrence_fp16 indicates whether the recurrence kernel should run in half precision or full precision. We found half precision to work well in the tasks we tried, but it may not in your case.

Re: point of interest 5. Do you mean training works fine if layer norm is not used? Do you have zero-vector inputs in x? Not sure if it is the zero vectors causing an issue for layer normalization; a quick check is sketched below.
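
Just to be concrete, a throwaway diagnostic for such inputs could look like this (zero_vector_fraction is not part of SRU, purely a helper):

import torch

def zero_vector_fraction(x):
    # x: (seq_len, batch, dim); a position counts as "zero" if its whole
    # feature vector is exactly 0 (e.g. padding)
    zero_rows = (x.abs().sum(dim=-1) == 0)
    return zero_rows.float().mean().item()

# example with fake data where the last timestep is all padding
x = torch.randn(128, 32, 256)
x[-1] = 0.0
print(zero_vector_fraction(x))   # -> 0.0078125 (1/128)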

@nicolaspanel

nicolaspanel commented Jun 3, 2021

Thank you @nicolaspanel !

👍

Would disabling rescale and/or amp_recurrence_fp16 help?

I ran experiments with rescale=False => same issue

I've lately found rescale to cause instability issues so we've made it False by default now. amp_recurrence_fp16 indicates if the recurrence kernel should run in half precision or full precision. We found using half precision to work well in the tasks we tried but it may not in your case.

I didn't try with amp_recurrence_fp16=False, but even if it may solve the issue during SRU's internal computations, having bigger and bigger numbers (in absolute value) as we go through the network layers will lead to numerical overflow anyway.

BTW I did some experiments clipping values to make sure they stay in range [-10, 10]. It works just fine (so far at least)

Re: point of interest 5. Do you mean training works fine if layer norm is not used?

NO, I tried disabling layer_norm during training => same issue

The fact that I do not have NaN values when the model is in inference mode (net.eval()) is probably a side effect of having no dropout, which hides the issue for the «reproduction sample» I used.

Do you have zero-vector inputs in x? Not sure if it is zero vectors causing an issue for layer normalization.

Yes, x has some zero-vector inputs.

@taoleicn
Contributor

taoleicn commented Jun 4, 2021

hi @nicolaspanel
Thank you.

weird though since layer_norm and weight_decay=1e-3 are used during training

Are you using the AdamW or RAdam optimizer, in which a weight_decay issue has been fixed compared to Adam? In recent language model experiments I used weight decay = 0.1 or 0.01. I haven't used value clipping before (gradient clipping has always been used; see the sketch below for the kind of setup I mean).
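
For reference (the toy model, learning rate and max_norm here are arbitrary placeholders):

import torch
import torch.nn as nn

# toy model and batch, just to make the snippet self-contained
model = nn.Linear(16, 4)
x, target = torch.randn(8, 16), torch.randint(0, 4, (8,))
criterion = nn.CrossEntropyLoss()

# decoupled weight decay (AdamW) plus gradient-norm clipping
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)

loss = criterion(model(x), target)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
optimizer.zero_grad()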

BTW I did some experiments clipping values to make sure they stay in range [-10, 10]. It works just fine (so far at least)

There is also a use_tanh option which applies a tanh activation to the hidden states. This may serve to bound the output values. It may hurt performance though; I didn't experiment with this option a lot.

Point of interest 3: y has nan values only for some timesteps (time dimension is in the vertical axis in the picture below)

Is the horizontal axis the index of the layer (in the picture you sent earlier)? Does it show that NaN values are generated even after the first layer?
Is this PyTorch issue related? pytorch/pytorch#41527

@nicolaspanel

hi @nicolaspanel
Thank you.

weird though since layer_norm and weight_decay=1e-3 are used during training

Are you using AdamW or RAdam optimizer in which a weight_decay issue has been fixed compared to Adam? In recent language model experiments I used weight decay = 0.1 or 0.01.

I use https://github.com/LiyuanLucasLiu/RAdam

I haven't used value clipping before (gradient clipping has always been used).

I hadn't used value clipping before either, but in this case it helped (no more NaNs since).

BTW I did some experiments clipping values to make sure they stay in range [-10, 10]. It works just fine (so far at least)

There is also a use_tanh option which applies a tanh activation to the hidden states. This may serve to bound the output values. It may hurt performance though; I didn't experiment with this option a lot.

I will stick to value clipping, because the only goal is to prevent overflows in a mixed precision setup, not to add an extra activation.

Point of interest 3: y has nan values only for some timesteps (time dimension is in the vertical axis in the picture below)

Is the horizontal axis the index of the layer (in the picture you sent earlier)? Does it show that NaN values are generated even after the first layer?

No. The horizontal axis is the «feature» axis. The picture displays the network's logits (the output of the last, fully connected, layer) for a single sample (i.e. batch size = 1).

Is this pytorch issue related? pytorch/pytorch#41527

No I don't think so

@ksopyla is it possible for you to check your network's intermediate representations to see if they contain very large absolute values (greater than 100 or lower than -100, for example)? A forward hook along the lines of the sketch below could do it.
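
Something like this ad-hoc helper (nothing SRU-specific, just PyTorch forward hooks; register_magnitude_hooks is an invented name) would print the sub-modules whose activations get large:

import torch

def register_magnitude_hooks(model, threshold=100.0):
    # report every sub-module whose output exceeds the threshold in magnitude
    def make_hook(name):
        def hook(module, inputs, output):
            out = output[0] if isinstance(output, tuple) else output
            max_abs = out.detach().abs().max().item()
            if max_abs > threshold:
                print(f"{name}: max |activation| = {max_abs:.1f}")
        return hook

    for name, module in model.named_modules():
        module.register_forward_hook(make_hook(name))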

Best regards

@v-nhandt21

v-nhandt21 commented Dec 3, 2021

sru==2.5.1

@nicolaspanel How did you manage to run SRU with CUDA 11.1?

I also have an RTX 3090 but I get an error (--generate-dependencies-with-compile). It seems sru_cuda_kernel.cu only works with CUDA 10.2. Do you have the same problem? Can you share how you fixed it?
