Test and fix Backward/Deriv #3478

Closed

Conversation

@mrdaybird (Contributor) commented Apr 28, 2023

Background

The Backward/Deriv method has been misused or incorrectly implemented several times (see #3471, #3469). So, to make it clear:

The Backward method of a layer (except loss functions) expects the first argument to be the output of the Forward method. The same is true for activation functions, whose Deriv method expects the first argument to be the output of Fn.

To take an example of a correctly implemented function, let's look at LogisticFunction (logistic_function.hpp). Let's first look at the Fn method:

  template<typename InputVecType, typename OutputVecType>
  static void Fn(const InputVecType& x, OutputVecType& y)
  {
    y = (1.0 / (1 + arma::exp(-x)));
  }

Now look at the Deriv method:

  template<typename InputVecType, typename OutputVecType>
  static void Deriv(const InputVecType& y, OutputVecType& x)
  {
    x = y % (1.0 - y);
  }

Observe that the derivative in the Deriv method is written in terms of the output, not in terms of the input.

To put it mathematically,

If $y=f(x)$ is the given function, then the derivative should be a function of $y$, i.e.
$$\frac{dy}{dx} = g(y)$$

In the example, our function $y=f(x)$ is defined as follows,
$$f(x) = \frac{1}{1 + e^{-x} }$$

Now, if we simply differentiate it w.r.t. $x$,
$$f'(x) = \frac{e^{-x}}{(1+e^{-x})^2}$$

If we arrange the terms, we can write the derivative as,
$$f'(x) = f(x)(1-f(x))$$
Substituting $f(x)$ with $y$,
$$\frac{dy}{dx} = y(1-y)$$
which is our required form of derivative.
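
To make the contract concrete, here is a small stand-alone check in the spirit of what JacobianTest verifies for an elementwise activation (my own sketch using Armadillo directly, not mlpack's test code): it evaluates Fn, feeds the output into Deriv, and compares the result against a central finite difference of Fn with respect to the input.

  // Sketch only: check that Deriv(y), with y = Fn(x), matches dFn/dx.
  // Uses the logistic function from the example above; requires Armadillo.
  #include <armadillo>
  #include <iostream>

  int main()
  {
    arma::vec x = arma::linspace(-4.0, 4.0, 9);

    // Fn: y = 1 / (1 + exp(-x)).
    arma::vec y = 1.0 / (1.0 + arma::exp(-x));

    // Deriv written in terms of the output, as the layer API expects.
    arma::vec derivFromOutput = y % (1.0 - y);

    // Central finite difference of Fn with respect to the input.
    const double eps = 1e-5;
    arma::vec yPlus  = 1.0 / (1.0 + arma::exp(-(x + eps)));
    arma::vec yMinus = 1.0 / (1.0 + arma::exp(-(x - eps)));
    arma::vec derivNumeric = (yPlus - yMinus) / (2.0 * eps);

    // The two should agree to roughly 1e-10 if Deriv is correct.
    std::cout << "max abs error: "
              << arma::abs(derivFromOutput - derivNumeric).max() << std::endl;
    return 0;
  }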

Add tests and fix issues

Since JacobianTest is fixed (#3471), all layers need to be tested against JacobianTest (and/or its siblings). After adding JacobianTest for the activation functions, many of the tests failed. Other layers may potentially be buggy as well.

The aim of this PR is to:

  1. Add tests to the activation function, loss functions and the layers.
  2. Document known issues.
  3. Fix issues.

The following layers/functions fail the JacobianTest and need their Backward/Deriv method fixed (see the Softplus sketch after the list for what one such fix can look like):

(The list is WIP and may be updated)

  1. Softplus function
  2. Mish function
  3. LiSHT function
  4. GELU function
  5. Elliot function
  6. Elish function
  7. Inverse Quadratic function
  8. Quadratic function
  9. Multiquad function
  10. Poisson1 function
  11. Gaussian function
  12. Hard Swish function
  13. Tanh Exp
  14. SILU function
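
As a concrete example of what a fix can look like when the inverse does exist: Softplus from the list above is $y = \log(1 + e^x)$, which is invertible, so its derivative $\frac{1}{1 + e^{-x}}$ can be rewritten in terms of the output as $1 - e^{-y}$. A Deriv in the expected output-based form could then look like the sketch below (illustrative only, not a drop-in patch for softplus_function.hpp). As discussed later in this thread, most of the other entries are not invertible, so this trick does not generalize.

  template<typename InputVecType, typename OutputVecType>
  static void Deriv(const InputVecType& y, OutputVecType& x)
  {
    // y = log(1 + exp(x)), so exp(-y) = 1 / (1 + exp(x)) and the
    // derivative 1 / (1 + exp(-x)) equals 1 - exp(-y).
    x = 1.0 - arma::exp(-y);
  }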

@mlpack-bot (bot) commented May 29, 2023

This issue has been automatically marked as stale because it has not had any recent activity. It will be closed in 7 days if no further activity occurs. Thank you for your contributions! 👍

@mlpack-bot added the "s: stale" label May 29, 2023
@mlpack-bot closed this Jun 5, 2023
@rcurtin (Member) commented Jun 9, 2023

mlpack-bot closed this one for inactivity, but if this does eventually get finished (be it in this PR or another one) I think it would be a valuable addition.

@mrdaybird (Contributor, Author) commented

Sorry, I have been pretty busy the past month. Anyway, I have time now to complete this.

Here are my initial findings.
Most (almost all) of the functions mentioned cannot be written in terms of their output. Take one of the easier functions as an example, the Quadratic function ($y = x^2$). To write the derivative in terms of the output, the function has to be invertible, but this function is not bijective (since it is not injective), hence not invertible. That was one of the easier examples; if you look at something like the GELU function (below), it is even more complicated to find the inverse, if it exists at all.

  y = 0.5 * x % (1 + arma::tanh(std::sqrt(2 / M_PI) *
      (x + 0.044715 * arma::pow(x, 3))));
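
For what it's worth, the derivative of this approximation with respect to the input is still straightforward to write in closed form (my own differentiation of the expression above); the difficulty is purely that it cannot be re-expressed through $y$. Writing $u = \sqrt{2/\pi}\,(x + 0.044715\,x^3)$,

$$\frac{dy}{dx} = \frac{1}{2}\left(1 + \tanh(u)\right) + \frac{x}{2}\,\mathrm{sech}^2(u)\,\sqrt{\frac{2}{\pi}}\left(1 + 3 \cdot 0.044715\,x^2\right)$$

which is the kind of input-based derivative that a fix would have to evaluate inside Forward (see the options below).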

I think the best way to move forward with fixing the derivatives of these functions would be to:

  1. Change the BaseLayer (base_layer.hpp) class to calculate the derivative inside the Forward method using the input and store the result (look at celu_impl.hpp; a sketch follows this list).
  2. Rewrite the Deriv of the functions other than these to take the input as the argument.
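
A minimal sketch of what option 1 could look like, using a hypothetical CachedDerivLayer (class, method, and member names are illustrative, not mlpack's actual BaseLayer API): the layer evaluates the activation in Forward, computes the derivative from the input while it is still available, and caches it for Backward.

  #include <armadillo>

  // Hypothetical layer caching the derivative computed during Forward.
  // ActivationType is assumed to provide Fn(input, output) and an
  // input-based Deriv(input, derivative).
  template<typename ActivationType, typename MatType = arma::mat>
  class CachedDerivLayer
  {
   public:
    void Forward(const MatType& input, MatType& output)
    {
      ActivationType::Fn(input, output);
      // The input is only available here, so compute and store the
      // derivative now for later use in Backward.
      ActivationType::Deriv(input, derivative);
    }

    void Backward(const MatType& /* output */, const MatType& gy, MatType& g)
    {
      // Chain rule: elementwise product of the upstream gradient with the
      // derivative cached during Forward.
      g = gy % derivative;
    }

   private:
    MatType derivative;
  };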

To me this solution feels like going against the flow, but it results in a more coherent interface in terms of how the activation functions are implemented.

There is another solution, which would involve not changing BaseLayer but implementing another class like BaseLayer that calculates the derivative in the forward pass and stores it in the class.

The functions which are not invertible could then be implemented using this class instead of BaseLayer. This would result in activation layers being implemented using two kinds of layers: 1. layers whose derivative is calculated from the output, and 2. layers whose derivative is calculated from the input. The advantage of this solution is that we can still utilize the symmetry that some of the functions provide; for example, the sigmoid layer's derivative can be calculated very easily from its output.

Regardless of which solution is chosen, there will be a clear difference in how Backward and Deriv should be implemented in the future.

The solutions are non-trivial (another one of the reasons this PR got pushed back for so long), so I thought I should get your opinion on this. What do you think, @rcurtin?

@rcurtin (Member) commented Nov 1, 2023

@mrdaybird I'm sorry this sat for so long; things have been busy for me over the summer and fall. This also came up in #3551, so I want to say thank you so much for the clear explanation here, and for the effort of digging into each layer; it helped me get a complete handle on the issue quickly. I'm not sure yet what the best solution is, but if you're still interested and have time, let's discuss in #3551 and figure out how we want to solve it. 👍
