Add Support for Recurrent Neural Networks (RNNs) in KFACOptimizer Class #3683

Open
neuronphysics opened this issue Nov 21, 2023 · 8 comments

Comments

@neuronphysics

🚀 Feature

I would like to request the addition of support for Recurrent Neural Networks (RNNs) in the KFACOptimizer class. Currently, the KFACOptimizer class works for linear and 2D convolution layers, but it does not support RNN layers. RNNs are widely used in various applications, and adding support for them in the KFACOptimizer class would be highly beneficial.

Proposal

I propose adding an RNN module to the KFACOptimizer class to handle RNN layers. This would involve modifying the KFACOptimizer class to compute the necessary statistics and whitened tensors for RNN layers, similar to how it currently handles linear and convolution layers.
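
To make the proposal concrete, here is a minimal standalone sketch (hypothetical shapes and names, independent of the implementation further below) of the two Kronecker-factor statistics an RNN-aware KFAC would track for an LSTM layer, folding the time dimension into the batch dimension:

import torch

# Standalone sketch, not DeepChem code: estimate the input-side and
# output-gradient-side Kronecker factors for a single LSTM layer.
torch.manual_seed(0)
seq_len, batch, in_dim, hidden = 5, 8, 4, 6
lstm = torch.nn.LSTM(in_dim, hidden)

x = torch.randn(seq_len, batch, in_dim)
out, _ = lstm(x)          # (seq_len, batch, hidden)
out.retain_grad()
out.sum().backward()      # stand-in for a real loss

# Input-side factor A ~ E[a a^T], folding time into the batch dimension.
a = x.reshape(-1, in_dim)
cov_a = a.t() @ a / a.shape[0]

# Output-gradient-side factor G ~ E[g g^T].
g = out.grad.reshape(-1, hidden)
cov_g = g.t() @ g / g.shape[0]
print(cov_a.shape, cov_g.shape)  # torch.Size([4, 4]) torch.Size([6, 6])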

Motivation

RNNs are a fundamental component of many deep learning models, especially in tasks involving sequential data, natural language processing, and time-series analysis. Efficient optimization techniques, such as KFAC, can significantly improve the training of RNN-based models by providing better approximations of the Fisher information matrix.

Adding support for RNNs in the KFACOptimizer class would enable researchers and practitioners to apply the KFAC optimization technique to a wider range of models, resulting in improved training efficiency and convergence. This enhancement would benefit the deep learning community by making state-of-the-art optimization techniques more accessible and effective for RNN-based applications.

Additional Context

I have provided an initial implementation attempt below, which includes two helper functions for the RNN-specific statistics and modifications to the KFACOptimizer class to handle RNN layers. I kindly request the repository maintainers to evaluate this implementation, provide feedback, and consider including it in the official codebase if deemed appropriate.

Additionally, I would like to refer the maintainers to the literature on extending KFAC to recurrent layers for more details on the topic:

import math
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from torch import optim


def compute_rnn_whiten_tensor_optimized(cov_f, xcov_f, cov_b, xcov_b, damping_coeff):
    """
    Optimized function to compute whitened tensors for RNNs.

    Parameters:
    cov_f (torch.Tensor): Covariance matrix for forward pass.
    xcov_f (torch.Tensor): Cross-covariance matrix for forward pass.
    cov_b (torch.Tensor): Covariance matrix for backward pass.
    xcov_b (torch.Tensor): Cross-covariance matrix for backward pass.
    damping_coeff (float): Damping coefficient for stability.

    Returns:
    Tuple[torch.Tensor, torch.Tensor]: Whitened tensors for forward and backward passes.
    """

    def compute_whiten_tensor(cov, xcov, damping_correction):
        """ Internal function to compute whitened tensor for one pass. """
        I = torch.eye(cov.shape[0], device=cov.device)
        cov_damped = cov + I * damping_correction * math.sqrt(damping_coeff)
        inv_cov = torch.linalg.inv(cov_damped)
        B = torch.matmul(xcov, inv_cov)
        Btilde_inv = torch.linalg.inv(I - B)
        cov_tilde = cov_damped
        whitened_tensor = torch.matmul(torch.matmul(Btilde_inv, cov_tilde), Btilde_inv.T)
        return whitened_tensor

    # Compute damping corrections
    damping_correction_f = torch.sqrt(torch.trace(cov_f) / torch.trace(cov_b))
    damping_correction_b = 1.0 / damping_correction_f

    # Compute whitened tensors for forward and backward passes
    whitened_f = compute_whiten_tensor(cov_f, xcov_f, damping_correction_f)
    whitened_b = compute_whiten_tensor(cov_b, xcov_b, damping_correction_b)

    return whitened_f, whitened_b
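
A quick smoke test of the helper above could look like the following; the covariances here are synthetic symmetric positive-definite matrices, so the shapes and values are purely illustrative:

# Hypothetical smoke test for compute_rnn_whiten_tensor_optimized above.
torch.manual_seed(0)
n = 4
M_f, M_b = torch.randn(n, n), torch.randn(n, n)
cov_f = M_f @ M_f.t() + torch.eye(n)   # synthetic SPD forward covariance
cov_b = M_b @ M_b.t() + torch.eye(n)   # synthetic SPD backward covariance
xcov_f = 0.1 * torch.randn(n, n)       # synthetic cross-covariances
xcov_b = 0.1 * torch.randn(n, n)

whitened_f, whitened_b = compute_rnn_whiten_tensor_optimized(
    cov_f, xcov_f, cov_b, xcov_b, damping_coeff=1e-3)
print(whitened_f.shape, whitened_b.shape)  # torch.Size([4, 4]) torch.Size([4, 4])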

def compute_rnn_cross_covariance(h):
    """
    Computes the cross-covariance of RNN hidden states between adjacent time steps.

    Parameters:
    h (torch.Tensor): Hidden states of the RNN layer, shape (seq_len, batch, num_directions * hidden_size).

    Returns:
    torch.Tensor: The cross-covariance matrix between hidden states at adjacent
        time steps, shape (num_directions * hidden_size, num_directions * hidden_size).
    """
    seq_len, batch_size, hidden_size = h.shape

    # Accumulate E[h_t h_{t+1}^T] over adjacent time steps, averaged over the
    # batch and the (seq_len - 1) time-step pairs.
    cross_cov = torch.zeros(hidden_size, hidden_size, device=h.device)
    for t in range(seq_len - 1):
        cross_cov += h[t].t() @ h[t + 1]
    cross_cov /= (seq_len - 1) * batch_size

    return cross_cov
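
As an illustrative check (again with made-up shapes), the adjacent-time-step cross-covariance of random hidden states comes out as a square matrix of size num_directions * hidden_size:

# Illustrative usage of compute_rnn_cross_covariance above.
h = torch.randn(5, 8, 6)  # (seq_len, batch, num_directions * hidden_size)
xcov = compute_rnn_cross_covariance(h)
print(xcov.shape)  # torch.Size([6, 6])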

class KFACOptimizer(optim.Optimizer):
    """"
    This class implement the second order optimizer - KFAC, which uses Kronecker factor products of inputs and the gradients to
    get the approximate inverse fisher matrix, which is used to update the model parameters. Presently this optimizer works only
    on liner and 2D convolution layers. If you want to know more details about KFAC, please check the paper [1]_ and [2]_.

    References:
    -----------
    [1] Martens, James, and Roger Grosse. Optimizing Neural Networks with Kronecker-Factored Approximate Curvature.
    arXiv:1503.05671, arXiv, 7 June 2020. arXiv.org, http://arxiv.org/abs/1503.05671.
    [2] Grosse, Roger, and James Martens. A Kronecker-Factored Approximate Fisher Matrix for Convolution Layers.
    arXiv:1602.01407, arXiv, 23 May 2016. arXiv.org, http://arxiv.org/abs/1602.01407.
    """

    def __init__(self,
                 model: torch.nn.Module,
                 lr: float = 0.001,
                 momentum: float = 0.9,
                 stat_decay: float = 0.95,
                 damping: float = 0.001,
                 kl_clip: float = 0.001,
                 weight_decay: float = 0,
                 TCov: int = 10,
                 TInv: int = 100,
                 batch_averaged: bool = True,
                 mean: bool = False):
        """
        Parameters:
        -----------
        model: torch.nn.Module
            The model to be optimized.
        lr: float (default: 0.001)
            Learning rate for the optimizer.
        momentum: float (default: 0.9)
            Momentum for the optimizer.
        stat_decay: float (default: 0.95)
            Decay rate for the update of covariance matrix with mean.
        damping: float (default: 0.001)
            Damping factor added when inverting the covariance matrices.
        kl_clip: float (default: 0.001)
            Clipping value used for the KL-based rescaling of the parameter update.
        weight_decay: float (default: 0)
            Weight decay for the optimizer.
        TCov: int (default: 10)
            The interval (in steps) between updates of the covariance matrices.
        TInv: int (default: 100)
            The interval (in steps) between recomputations of the inverses of the covariance matrices.
        batch_averaged: bool (default: True)
            States whether to use batch averaged covariance matrix.
        mean: bool (default: False)
            States whether to use mean centered covariance matrix.
        """

        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError(
                "Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr,
                        momentum=momentum,
                        damping=damping,
                        weight_decay=weight_decay)
        super(KFACOptimizer, self).__init__(model.parameters(), defaults)
        self.batch_averaged = batch_averaged


        self.known_modules = {'Linear', 'Conv2d', 'LSTM', 'GRU'}
        self.modules: List[torch.nn.Module] = []

        self.model = model
        self._prepare_model()

        self.steps = 0

        self.m_aa: Dict[torch.nn.Module, torch.Tensor] = {}
        self.m_gg: Dict[torch.nn.Module, torch.Tensor] = {}
        self.Q_a: Dict[torch.nn.Module, torch.Tensor] = {}
        self.Q_g: Dict[torch.nn.Module, torch.Tensor] = {}
        self.d_a: Dict[torch.nn.Module, torch.Tensor] = {}
        self.d_g: Dict[torch.nn.Module, torch.Tensor] = {}
        # Buffers holding the most recent RNN hidden states and output gradients,
        # consumed by _save_input, _save_grad_output and _update_inv below.
        self.rnn_hidden_states: Dict[torch.nn.Module, torch.Tensor] = {}
        self.rnn_gradients: Dict[torch.nn.Module, torch.Tensor] = {}
        self.stat_decay = stat_decay

        self.kl_clip = kl_clip
        self.TCov = TCov
        self.TInv = TInv

        self.mean = mean

    def try_contiguous(self, x: torch.Tensor) -> torch.Tensor:
        """
        Checks the memory layout of the input tensor and changes it to contiguous type.

        Parameters:
        -----------
        x: torch.Tensor
            The input tensor to be made contiguous in memory, if it is not so.

        Return:
        -------
        torch.Tensor
            Tensor with contiguous memory
        """
        if not x.is_contiguous():
            x = x.contiguous()

        return x

    def _extract_patches(
            self, x: torch.Tensor, kernel_size: Tuple[int,
                                                      ...], stride: Tuple[int,
                                                                          ...],
            padding: Union[int, str, Tuple[int, ...]]) -> torch.Tensor:
        """
        Extract patches of a given size from the input tensor given. Used in calculating
        the matrices for the kronecker product in the case of 2d Convolutions.

        Parameters:
        -----------
        x: torch.Tensor
            The input feature maps. with the size of (batch_size, in_c, h, w)
        kernel_size: Tuple[int, ...]
            the kernel size of the conv filter.
        stride: Tuple[int, ...]
            the stride of conv operation.
        padding: Union[int, str, Tuple[int, ...]]
            number of paddings. be a tuple of two elements

        Return:
        -------
        torch.Tensor:
            Extracted patches with shape (batch_size, out_h, out_w, in_c*kh*kw)
        """
        if isinstance(padding, tuple):
            if padding[0] + padding[1] > 0:
                x = torch.nn.functional.pad(
                    x, (padding[1], padding[1], padding[0],
                        padding[0])).data  # Actually check dims
        elif isinstance(padding, int):
            if padding > 0:
                x = torch.nn.functional.pad(
                    x, (padding, padding, padding, padding)).data
        elif isinstance(padding, str):
            if padding == 'VALID':
                pad = int((kernel_size[0] - 1) / 2)
                x = torch.nn.functional.pad(x, (pad, pad, pad, pad)).data

        x = x.unfold(2, kernel_size[0], stride[0])
        x = x.unfold(3, kernel_size[1], stride[1])
        x = x.transpose_(1, 2).transpose_(2, 3).contiguous()
        x = x.view(x.size(0), x.size(1), x.size(2),
                   x.size(3) * x.size(4) * x.size(5))
        return x

    def compute_cov_a(self, a: torch.Tensor,
                      layer: torch.nn.Module) -> torch.Tensor:
        """
        Compute the covariance matrix of the A matrix (the output of each layer).
        The covariance of the activations `aaT`

        Parameters:
        -----------
        a: torch.Tensor
            It is the output of the layer for which the covariance matrix should be calculated.
        layer: torch.nn.Module
            It specifies the type of layer from which the output of the layer is taken.

        Returns:
        --------
        torch.Tensor
            The covariance matrix of the A matrix.
        """
        if isinstance(layer, torch.nn.Linear):
            batch_size = a.size(0)
            if layer.bias is not None:
                a = torch.cat((a, a.new(a.size(0), 1).fill_(1)), 1)

        elif isinstance(layer, torch.nn.Conv2d):
            batch_size = a.size(0)
            a = self._extract_patches(a, layer.kernel_size, layer.stride,
                                      layer.padding)
            spatial_size = a.size(1) * a.size(2)
            a = a.view(-1, a.size(-1))
            if layer.bias is not None:
                a = torch.cat((a, a.new(a.size(0), 1).fill_(1)), 1)
            a = a / spatial_size

        elif isinstance(layer, (torch.nn.LSTM, torch.nn.GRU)):
            # New code for RNN layers
            # Assuming 'a' is the output tensor from the RNN layer
            # You might need additional tensors, depending on your RNN configuration
            if layer.num_layers > 1:
                # Assuming the last dimension is num_layers * num_directions * hidden_size
                a = a.view(a.shape[0], a.shape[1], -1)

            if layer.bidirectional:
                # Split the tensor into two parts, one for each direction
                forward, backward = a.split(a.shape[-1] // 2, dim=-1)
                a = torch.cat([forward, backward], dim=0)
            # Flatten the tensor except for the batch dimension
            a = a.transpose(0, 1).reshape(a.shape[1], -1)
            if layer.bias is not None:
                # Augmenting the activations with a column of ones for biases
                a = torch.cat((a, a.new(a.size(0), 1).fill_(1)), 1)

            batch_size = a.shape[0]

        return a.t() @ (a / batch_size)

    def compute_cov_g(self, g: torch.Tensor,
                      layer: torch.nn.Module) -> torch.Tensor:
        """
        Compute the covariance matrix of the G matrix (the gradient of the layer).
        calculating the covariance of gradients `dLdLT`
        Parameters:
        -----------
        g: torch.Tensor
            It is the gradient of the layer for which the covariance matrix should be calculated.
        layer: torch.nn.Module
            It specifies the type of layer from which the output of the layer is taken.

        Returns:
        --------
        torch.Tensor
            The covariance matrix of the G matrix.
        """
        if isinstance(layer, torch.nn.Linear):
            batch_size = g.size(0)
            if self.batch_averaged:
                cov_g = g.t() @ (g * batch_size)
            else:
                cov_g = g.t() @ (g / batch_size)

        elif isinstance(layer, torch.nn.Conv2d):
            spatial_size = g.size(2) * g.size(3)
            batch_size = g.shape[0]
            g = g.transpose(1, 2).transpose(2, 3)
            g = self.try_contiguous(g)
            g = g.view(-1, g.size(-1))
            if self.batch_averaged:
                g = g * batch_size
            g = g * spatial_size
            cov_g = g.t() @ (g / g.size(0))

        elif isinstance(layer, (torch.nn.LSTM, torch.nn.GRU)):
            # For RNN layers
            # Assuming 'g' is the gradient of the output from the RNN layer
            # g should have a shape similar to the output tensor of the RNN layer

            if layer.num_layers > 1:
                # Reshape gradient tensor if the RNN has multiple layers
                g = g.view(g.shape[0], g.shape[1], -1)

            if layer.bidirectional:
                # Split and process separately for bidirectional RNNs
                forward, backward = g.split(g.shape[-1] // 2, dim=-1)
                g = torch.cat([forward, backward], dim=0)

            # Flatten the tensor except for the batch dimension
            g = g.transpose(0, 1).reshape(g.shape[1], -1)
            batch_size = g.shape[0]
            if self.batch_averaged:
                cov_g = g.t() @ (g * batch_size)
            else:
                cov_g = g.t() @ (g / batch_size)
        return cov_g

    def _save_input(self, module: torch.nn.Module, input: torch.Tensor):
        """
        Updates the input of the layer using exponentially weighted averages of the layer input.

        Parameters:
        -----------
        module: torch.nn.Module
            specifies the layer for which the input should be taken
        input: torch.Tensor
            the input matrix which should get updated
        """
        if isinstance(module, (torch.nn.LSTM, torch.nn.GRU)):
            self.rnn_hidden_states[module] = input[0].data
        if self.steps % self.TCov == 0:
            aa = self.compute_cov_a(input[0].data, module)
            # Initialize buffers
            if self.steps == 0:
                self.m_aa[module] = torch.diag(aa.new(aa.size(0)).fill_(1))
            self.m_aa[module] = self.m_aa[module] * self.stat_decay + aa * (1 - self.stat_decay)

    def _save_grad_output(self, module: torch.nn.Module,
                          grad_input: torch.Tensor, grad_output: torch.Tensor):
        """
        Updates the backward gradient of the layer using exponentially weighted averages of the layer input.

        Parameters:
        -----------
        module: torch.nn.Module
            specifies the layer for which the gradient should be taken
        grad_input: torch.Tensor
            the gradients of the loss with respect to the layer inputs
        grad_output: torch.Tensor
            the gradients of the loss with respect to the layer outputs, used to update the statistics
        """
        if isinstance(module, (torch.nn.LSTM, torch.nn.GRU)):
            self.rnn_gradients[module] = grad_output[0].data  # grad_output contains gradients wrt output

        # Accumulate statistics for Fisher matrices
        if self.steps % self.TCov == 0:
            gg = self.compute_cov_g(grad_output[0].data, module)
            # Initialize buffers
            if self.steps == 0:
                self.m_gg[module] = torch.diag(gg.new(gg.size(0)).fill_(1))
            self.m_gg[module] = self.m_gg[module] * self.stat_decay + gg * (1 - self.stat_decay)

    def _prepare_model(self):
        """"
        Attaches hooks(saving the ouptut and grad according to the update function) to the model for
        to calculate gradients at every step.
        """
        count = 0
        for module in self.model.modules():
            classname = module.__class__.__name__
            if classname in self.known_modules:
                self.modules.append(module)
                module.register_forward_pre_hook(self._save_input)
                module.register_backward_hook(self._save_grad_output)
                count += 1

    def _update_inv(self, m: torch.nn.Module):
        """
        Does eigen decomposition of the input(A) and gradient(G) matrix for computing inverse of the ~ fisher.

        Parameter:
        ----------
        m: torch.nn.Module
            This is the layer for which the eigen decomposition should be done on.
        """
        eps = 1e-10  # for numerical stability
        if isinstance(m, (torch.nn.LSTM, torch.nn.GRU)):
            # Use the optimized function for RNN layers
            forward_states = self.rnn_hidden_states.get(m, None)
            backward_grads = self.rnn_gradients.get(m, None)
            if forward_states is None or backward_grads is None:
                raise RuntimeError(
                    "No hidden states/gradients have been recorded for this RNN layer yet")
            # Cross-covariances between adjacent time steps for the forward and backward passes
            xcov_f = compute_rnn_cross_covariance(forward_states)
            xcov_b = compute_rnn_cross_covariance(backward_grads)

            cov_f = self.m_aa[m]  # Example, replace with actual computation
            cov_b = self.m_gg[m]  # Example, replace with actual computation
            damping_coeff = self.param_groups[0]['damping']  # damping is stored in the param group defaults

            whitened_f, whitened_b = compute_rnn_whiten_tensor_optimized(
                cov_f, xcov_f, cov_b, xcov_b, damping_coeff
            )

            # Store the results in the appropriate attributes
            self.Q_a[m] = whitened_f  # Example, adjust as needed
            self.Q_g[m] = whitened_b  # Example, adjust as needed
            # Compute d_a and d_g from the eigenvalues of the whitened tensors
            d_a = torch.symeig(whitened_f, eigenvectors=False).eigenvalues
            d_g = torch.symeig(whitened_b, eigenvectors=False).eigenvalues
            # Compute whitened tensors using ideas from compute_rnn_whiten_tensor_option2
            damping_correction_f = torch.sqrt(torch.trace(cov_f) / torch.trace(cov_b))
            damping_correction_b = 1.0 / damping_correction_f

            # torch has no Tensor.is_diagonal(); check by comparing the matrix
            # against its own diagonal
            is_diagonal = torch.allclose(cov_f, torch.diag(torch.diagonal(cov_f)))
            if not is_diagonal:
                I = torch.eye(cov_f.shape[0], device=cov_f.device)
                # avoid in-place ops so the running averages in m_aa/m_gg are not modified
                cov_f = cov_f + I * damping_correction_f * math.sqrt(damping_coeff)
                cov_b = cov_b + I * damping_correction_b * math.sqrt(damping_coeff)

                cov_f_damped = cov_f + I
                cov_b_damped = cov_b + I

                inv_cov_f = torch.inverse(cov_f_damped)
                inv_cov_b = torch.inverse(cov_b_damped)

                B_f = torch.matmul(xcov_f, inv_cov_f)
                B_b = torch.matmul(xcov_b, inv_cov_b)

                Btilde_inv_f = torch.inverse(I - B_f)
                Btilde_inv_b = torch.inverse(I - B_b)

                cov_tilde_f = cov_f_damped
                cov_tilde_b = cov_b_damped

                a_f = torch.matmul(torch.matmul(Btilde_inv_f, cov_tilde_f), Btilde_inv_f.T)
                a_b = torch.matmul(torch.matmul(Btilde_inv_b, cov_tilde_b), Btilde_inv_b.T)

                c_f = torch.matmul(torch.matmul(B_f, cov_tilde_f), B_f.T)
                c_b = torch.matmul(torch.matmul(B_b, cov_tilde_b), B_b.T)

                c_f = torch.matmul(torch.matmul(Btilde_inv_f, c_f), Btilde_inv_f.T)
                c_b = torch.matmul(torch.matmul(Btilde_inv_b, c_b), Btilde_inv_b.T)
            else:
                # Handle diagonal stats
                I = torch.ones(cov_f.shape[0], device=cov_f.device)
                cov_f = cov_f + I * damping_correction_f * math.sqrt(damping_coeff)
                cov_b = cov_b + I * damping_correction_b * math.sqrt(damping_coeff)

                cov_f_damped = cov_f + I
                cov_b_damped = cov_b + I

                inv_cov_f = 1.0 / cov_f_damped
                inv_cov_b = 1.0 / cov_b_damped

                B_f = xcov_f * inv_cov_f
                B_b = xcov_b * inv_cov_b

                Btilde_inv_f = 1.0 / (I - B_f)
                Btilde_inv_b = 1.0 / (I - B_b)

                cov_tilde_f = cov_f_damped
                cov_tilde_b = cov_b_damped

                a_f = Btilde_inv_f * cov_tilde_f * Btilde_inv_f
                a_b = Btilde_inv_b * cov_tilde_b * Btilde_inv_b

                c_f = a_f * B_f * B_f
                c_b = a_b * B_b * B_b

            # Store the results in the appropriate attributes
            # a_f and a_b are symmetric, so use symeig (torch.eig is deprecated) to
            # extract their eigenvalues, matching the rest of this class
            self.d_a[m] = torch.symeig(a_f, eigenvectors=False).eigenvalues
            self.d_g[m] = torch.symeig(a_b, eigenvectors=False).eigenvalues
            # Specify a tolerance for equivalence
            tolerance = 1e-6  # Adjust this value as needed based on your problem

            # Check if the two values are equivalent within the specified tolerance
            equivalent = torch.allclose(self.d_a[m], d_a, atol=tolerance)

            # Print the result
            if equivalent:
               print("Values of eigen decompositions are equivalent within the specified tolerance.")
            else:
               print("Values of eigen decompositions are not equivalent within the specified tolerance.")
        else:
            if self.mean:
                self.d_a[m], self.Q_a[m] = torch.symeig(self.m_aa[m] -
                                                        torch.mean(self.m_aa[m]),
                                                        eigenvectors=True)
                self.d_g[m], self.Q_g[m] = torch.symeig(self.m_gg[m] -
                                                        torch.mean(self.m_gg[m]),
                                                        eigenvectors=True)
            else:
                self.d_a[m], self.Q_a[m] = torch.symeig(self.m_aa[m],
                                                        eigenvectors=True)
                self.d_g[m], self.Q_g[m] = torch.symeig(self.m_gg[m],
                                                        eigenvectors=True)

        self.d_a[m].mul_((self.d_a[m] > eps).float())
        self.d_g[m].mul_((self.d_g[m] > eps).float())

    @staticmethod
    def _get_matrix_form_grad(m: torch.nn.Module):
        """
        Returns the gradient of the layer in a matrix form

        Parameter:
        ----------
        m: torch.nn.Module
            the layer for which the gradient must be calculated

        Return:
        -------
        torch.tensor
            a matrix form of the gradient. it should be a [output_dim, input_dim] matrix.
        """
        if isinstance(m, torch.nn.Conv2d):
            assert isinstance(m.weight.grad, torch.Tensor)
            p_grad_mat = m.weight.grad.data.view(
                m.weight.grad.data.size(0), -1)  # n_filters * (in_c * kw * kh)
        elif isinstance(m, torch.nn.Linear):
            assert isinstance(m.weight.grad, torch.Tensor)
            p_grad_mat = m.weight.grad.data
        else:
            raise NotImplementedError(
                "KFAC optimizer currently supports only Linear and Conv2d layers"
            )

        if m.bias is not None:
            if isinstance(m.bias.grad.data, torch.Tensor):
                p_grad_mat = torch.cat(
                    [p_grad_mat, m.bias.grad.data.view(-1, 1)], 1)
            else:
                raise TypeError("bias.grad.data should be a Tensor")
        return p_grad_mat

    def _get_natural_grad(self, m: torch.nn.Module, p_grad_mat: torch.Tensor,
                          damping: float) -> List[torch.Tensor]:
        """
        This function returns the product of inverse of the fisher matrix and the weights gradient.

        Parameters:
        -----------
        m: torch.nn.Module
            Specifies the layer for which the calculation must be done on.
        p_grad_mat: torch.Tensor
            the gradients of the layer weights in matrix form
        damping: float
            the damping factor for the calculation

        Return:
        -------
        torch.Tensor
            the product of inverse of the fisher matrix and the weights gradient.
        """
        # p_grad_mat is of output_dim * input_dim
        # inv((ss')) p_grad_mat inv(aa') = [ Q_g (1/R_g) Q_g^T ] @ p_grad_mat @ [Q_a (1/R_a) Q_a^T]
        v1 = self.Q_g[m].t() @ p_grad_mat @ self.Q_a[m]
        v2 = v1 / (self.d_g[m].unsqueeze(1) * self.d_a[m].unsqueeze(0) +
                   damping)
        a = self.Q_g[m] @ v2 @ self.Q_a[m].t()
        if m.bias is not None:
            # we always put gradient w.r.t weight in [0]
            # and w.r.t bias in [1]
            if isinstance(m.weight.grad.data, torch.Tensor) and isinstance(
                    m.bias.grad.data, torch.Tensor):
                v = [a[:, :-1], a[:, -1:]]
                v[0] = v[0].view(m.weight.grad.data.size())
                v[1] = v[1].view(m.bias.grad.data.size())
            else:
                raise TypeError(
                    "weight.grad.data and bias.grad.data should be a Tensor")
        else:
            v = [a.view(m.weight.grad.data.size())]

        return v

    def _kl_clip_and_update_grad(self, updates: Dict[torch.nn.Module,
                                                     List[torch.Tensor]],
                                 lr: float):
        """
        Performs KL clipping on the update matrices if the values are large, then writes the final values into the gradients of the parameters.

        Parameters:
        -----------
        updates: Dict[torch.nn.Module,List[torch.Tensor]]
            A dictionary containing the product of the gradient and the Fisher inverse of each layer.
        lr: float
            learning rate of the optimizer
        """
        # do kl clip
        vg_sum = 0.0
        for m in self.modules:
            v = updates[m]
            vg_sum += (v[0] * m.weight.grad.data * lr**2).sum().item()
            if m.bias is not None:
                vg_sum += (v[1] * m.bias.grad.data * lr**2).sum().item()
        nu = min(1.0, math.sqrt(self.kl_clip / vg_sum))

        for m in self.modules:
            v = updates[m]
            if isinstance(m.weight.grad.data, torch.Tensor):
                m.weight.grad.data.copy_(v[0])
                m.weight.grad.data.mul_(nu)
            else:
                raise TypeError("weight.grad.data should be a Tensor")
            if m.bias is not None:
                if isinstance(m.bias.grad.data, torch.Tensor):
                    m.bias.grad.data.copy_(v[1])
                    m.bias.grad.data.mul_(nu)
                else:
                    raise TypeError("bias.grad.data should be a Tensor")

    def _step(self, closure: Optional[Callable] = None):
        """
        Called in every step of the optimizer, updating the model parameters from the gradient by the KFAC equation.
        Also, performs weight decay and adds momentum if any.

        Parameters:
        -----------
        closure: Callable, optional(default: None)
            an optional customizable function to be passed which can be used to clear the gradients and other compute loss for every step.
        """
        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            lr = group['lr']

            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                if weight_decay != 0 and self.steps >= 20 * self.TCov:
                    d_p.add_(p.data, alpha=weight_decay)
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.zeros_like(
                            p.data)
                        buf.mul_(momentum).add_(d_p)
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p)
                    d_p = buf

                p.data.add_(d_p, alpha=-lr)

    def step(self, closure: Optional[Callable] = None):
        """
        This is the function that gets called in each step of the optimizer to update the weights and biases of the model.

        Parameters:
        -----------
        closure: Callable, optional(default: None)
            an optional customizable function to be passed which can be used to clear the gradients and other compute loss for every step.
        """
        group = self.param_groups[0]
        lr = group['lr']
        damping = group['damping']
        updates = {}
        for m in self.modules:
            if self.steps % self.TInv == 0:
                self._update_inv(m)
            p_grad_mat = self._get_matrix_form_grad(m)
            v = self._get_natural_grad(m, p_grad_mat, damping)
            updates[m] = v
        self._kl_clip_and_update_grad(updates, lr)

        self._step(closure)
        self.steps += 1
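
For reference, the eigen-basis product computed in _get_natural_grad above, i.e. the damped Kronecker-factored inverse applied to the weight gradient, can be illustrated with a standalone sketch (hypothetical shapes, using torch.linalg.eigh rather than the deprecated torch.symeig):

import torch

# Standalone illustration, not DeepChem code: apply the damped Kronecker-factored
# inverse to a weight gradient in the eigenbasis of the two factors.
torch.manual_seed(0)
out_dim, in_dim, damping = 3, 5, 1e-3

A = torch.randn(in_dim, in_dim)
A = A @ A.t() / in_dim          # activation covariance factor (aa^T)
G = torch.randn(out_dim, out_dim)
G = G @ G.t() / out_dim         # gradient covariance factor (gg^T)
grad = torch.randn(out_dim, in_dim)

d_a, Q_a = torch.linalg.eigh(A)
d_g, Q_g = torch.linalg.eigh(G)

# Rotate into the eigenbasis, divide by damped eigenvalue products, rotate back.
v1 = Q_g.t() @ grad @ Q_a
v2 = v1 / (d_g.unsqueeze(1) * d_a.unsqueeze(0) + damping)
natural_grad = Q_g @ v2 @ Q_a.t()
print(natural_grad.shape)  # torch.Size([3, 5])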

I believe that adding RNN support to the KFACOptimizer class would be a valuable addition to the repository and benefit the deep learning community as a whole. Thank you for considering this feature request.

@rbharath
Member

This sounds like an interesting addition that we would tentatively consider adding. Could you please join our office hours to discuss it with us (MWF at 9am PST)? https://forum.deepchem.io/t/announcing-the-deepchem-office-hours/293

@neuronphysics
Author

neuronphysics commented Nov 23, 2023

Hi, thanks for the reply. I'm interested in joining the office hours to discuss this feature request. Could you please specify which days of the week the office hours occur?

@shreyasvinaya
Member

Hi @neuronphysics the office hours happen at 9am PST on Mondays, Wednesdays and Fridays

@neuronphysics
Author

Hi, I joined the Google Meet session on Friday for the office hour, but I wasn't admitted.

@shreyasvinaya
Member

Hi @neuronphysics, there were no office hours on Friday due to Thanksgiving; you can join on Monday (27th Nov 2023).

@neuronphysics
Author

Hi, I am trying to attend your office hours. Should I use the old Google Meet link?

@shreyasvinaya
Member

Hi @neuronphysics, the office hours are delayed by 15 min today. Please wait; you will be admitted as soon as the office hours start.

@prachi237

prachi237 commented Feb 22, 2024

@rbharath sir, @shreyasvinaya sir, can I work on this issue?
