Numerical instability parameterization tricks #73

Open
kieferk opened this issue Sep 5, 2020 · 0 comments

kieferk commented Sep 5, 2020

First of all, I just want to say that your WTTE is really cool. Great blog post and paper. I've been using an adapted version of it for a time-to-event task and wanted to share a trick I've found useful for numerical instability issues, in case you or anyone else is interested.

A couple of things to note in my case:

  • I'm not using an RNN, since I have sufficient engineered features for the history at a point in time.
  • I rewrote it in PyTorch, so my code here is PyTorch.
  • My case uses the discrete likelihood. I haven't tested anything for the continuous case, but I don't see why it wouldn't work there too.

While testing it I had a lot of issues with NaN losses and numerical instability while fitting alpha and beta. I know from reading the other GitHub issues that you've worked a lot on this.

I've found that this parameterization for alpha and beta helps a lot:

import numpy as np
import torch as tt
import torch.nn as nn


class WTTE(nn.Module):
    
    def __init__(self, nnet_output_dim):
        super(WTTE, self).__init__()
        
        # this is the inner neural net whose outputs are then used to derive alpha and beta
        self.nnet = InnerNNET()

        self.softplus = nn.Softplus()
        self.tanh = nn.Tanh()

        self.alpha_scaling = nn.Linear(nnet_output_dim, 1)
        self.beta_scaling = nn.Linear(nnet_output_dim, 1)

        # offset and scale parameters
        alpha_offset_init, beta_offset_init = 1.0, 1.0
        alpha_scale_init, beta_scale_init = 1.0, 1.0

        self.alpha_offset = nn.Parameter(tt.from_numpy(np.array([alpha_offset_init])).float(), requires_grad=True)
        self.beta_offset = nn.Parameter(tt.from_numpy(np.array([beta_offset_init])).float(), requires_grad=True)
        
        self.alpha_scale = nn.Parameter(tt.from_numpy(np.array([alpha_scale_init])).float(), requires_grad=True)
        self.beta_scale = nn.Parameter(tt.from_numpy(np.array([beta_scale_init])).float(), requires_grad=True)
        
    
    def forward(self, x):
        
        x = self.nnet(x)
        
        # derive alpha and beta individual scaling factors
        a_scaler = self.alpha_scaling(x)
        b_scaler = self.beta_scaling(x)

        # enforce the scaling factors to be between -1 and 1
        a_scaler = self.tanh(a_scaler)
        b_scaler = self.tanh(b_scaler)
        
        # combine the global offsets and scale factors with individual ones
        alpha = self.alpha_offset + (self.alpha_scale * a_scaler)
        beta = self.beta_offset + (self.beta_scale * b_scaler)

        # put alpha on positive range with exp, beta with softplus
        alpha = tt.exp(alpha)
        beta = self.softplus(beta)

        return alpha, beta
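
For completeness, this is roughly the discrete log-likelihood I pair with those alpha/beta outputs. It's my adaptation of the discrete Weibull likelihood from the WTTE paper/repo; the function names and the eps terms are mine, so treat it as a sketch rather than a drop-in:

def discrete_weibull_loglik(t, u, alpha, beta, eps=1e-9):
    # t: elapsed discrete time steps, u: 1 if the event was observed, 0 if censored
    # (t and u should be float tensors broadcastable against alpha/beta)
    hazard0 = ((t + eps) / alpha) ** beta
    hazard1 = ((t + 1.0) / alpha) ** beta
    return u * tt.log(tt.exp(hazard1 - hazard0) - 1.0 + eps) - hazard1


def wtte_loss(t, u, alpha, beta):
    # minimize the negative mean log-likelihood
    return -discrete_weibull_loglik(t, u, alpha, beta).mean()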

Essentially, this helps because the tanh activation forces the per-observation scaling factors to always stay between -1 and 1, so you don't have to worry about excessively small or large outputs from your network. The alpha_scale and beta_scale parameters set the range that those -1 to 1 outputs get multiplied by, and the offsets act as an intercept or centering mechanism.
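
To make that concrete: because the per-observation scaler is bounded in [-1, 1], alpha can only land in [exp(offset - scale), exp(offset + scale)] and beta in [softplus(offset - scale), softplus(offset + scale)] (for positive scale parameters, which is what I see in practice), no matter what the inner network outputs. A quick sketch of a helper (the name is mine) to see what range the fitted offsets and scales imply:

def implied_ranges(model):
    # range of alpha/beta reachable through the tanh-bounded scaler in [-1, 1];
    # assumes the fitted scale parameters are positive
    a_off, a_sc = model.alpha_offset.detach(), model.alpha_scale.detach()
    b_off, b_sc = model.beta_offset.detach(), model.beta_scale.detach()
    alpha_range = tt.exp(tt.cat([a_off - a_sc, a_off + a_sc]))
    beta_range = nn.functional.softplus(tt.cat([b_off - b_sc, b_off + b_sc]))
    return alpha_range, beta_range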

If you set the initialization for the offsets and scaling factors to be low numbers (I start them at 1.0, for example), they will slowly creep up to their optimal values during fit. Here is some output from a recent fit of mine to show what I mean:

A off: 1.10000	A scale: 1.10000	B off: 0.90000	B scale: 1.10000	
A off: 1.23279	A scale: 1.03885	B off: 0.90022	B scale: 0.89804	
A off: 1.25786	A scale: 1.06547	B off: 0.90056	B scale: 0.89798	
A off: 1.28466	A scale: 1.09343	B off: 0.90163	B scale: 0.89878	
A off: 1.34988	A scale: 1.16266	B off: 0.90678	B scale: 0.90290	
A off: 1.44370	A scale: 1.25528	B off: 0.93324	B scale: 0.93015	
A off: 1.53040	A scale: 1.32979	B off: 0.98308	B scale: 0.98226		
...[many epochs later]...
A off: 1.92782	A scale: 1.57879	B off: 2.97086	B scale: 2.55308	
A off: 1.93340	A scale: 1.59249	B off: 3.01380	B scale: 2.59065	
A off: 1.94988	A scale: 1.59956	B off: 3.01739	B scale: 2.54407	
A off: 1.94464	A scale: 1.59733	B off: 3.03923	B scale: 2.55807	
A off: 1.95629	A scale: 1.60365	B off: 3.06733	B scale: 2.58807	
A off: 1.95865	A scale: 1.59092	B off: 3.09355	B scale: 2.60830
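
For reference, that printout comes from something like the following at the end of each epoch. The train_loader (yielding x, t, u batches), n_epochs, and the nnet_output_dim value are placeholders for my own setup:

model = WTTE(nnet_output_dim=32)  # 32 is a placeholder for my feature net's output dim
optimizer = tt.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(n_epochs):  # n_epochs: placeholder
    for x, t, u in train_loader:  # train_loader: placeholder for my own data pipeline
        optimizer.zero_grad()
        alpha, beta = model(x)
        loss = wtte_loss(t, u, alpha, beta)
        loss.backward()
        optimizer.step()

    # log the global offset/scale parameters, as in the output above
    print(f"A off: {model.alpha_offset.item():.5f}\t"
          f"A scale: {model.alpha_scale.item():.5f}\t"
          f"B off: {model.beta_offset.item():.5f}\t"
          f"B scale: {model.beta_scale.item():.5f}")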

You could also enforce maximums on alpha and beta easily if you wanted to by adding torch.clamp calls around the outputs, but I have not found this to be necessary.
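
For reference, that would just be something like this right before the return in forward, after the exp/softplus (the caps here are arbitrary example values, not something I've tuned):

        # optional hard caps on alpha and beta; the max values are arbitrary examples
        alpha = tt.clamp(alpha, max=1000.0)
        beta = tt.clamp(beta, max=50.0)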

I have only tested this on my own data, so I can't claim it will solve numerical instability issues for other people, but I figured it might help someone!
