Simulating RMSNorm with BrainCog: the final output is nan #179

Open
LumenScope opened this issue Feb 27, 2024 · 0 comments
Comments

@LumenScope

class SNN_RMSNorm(nn.Module):
    def __init__(self, max_length=128, hidden_size=4096, node=LIAFNode, threshold=0.5, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.rms_neuron = node(act_fun='LeakyReLU', threshold=threshold)
        self.weight_neuron = node(act_fun='ReLU', threshold=threshold)
        self.weight = nn.Parameter(torch.ones(hidden_size, hidden_size))
        self.rms_connection = CustomLinear(torch.ones(1, hidden_size))
        self.weight_connection = CustomLinear(self.weight)

    def forward(self, x):
        x_sqr = x ** 2
        x_rms = x_sqr.mean(-1, keepdim=True)
        s_rms = self.rms_neuron(self.rms_connection(x_rms))
        rms_out = torch.rsqrt(s_rms + self.eps)
        s_scale = self.weight_neuron(self.weight_connection(rms_out))
        return s_scale

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.

        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """
        Apply the RMSNorm normalization to the input tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized tensor.

        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.

        """
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

The above shows the SNN-ified RMSNorm I defined and the original RMSNorm. Below is the output of SNN_RMSNorm's forward pass; although I have processed the dimensions so that the shape matches, the output is:

tensor([[[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]]], device='cuda:0',
       grad_fn=<StackBackward0>)
torch.Size([2, 128, 4096])
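
For reference, torch.rsqrt returns nan for any negative input, and the LIAFNode defined below, when configured with act_fun='LeakyReLU', can emit negative analog values whenever its membrane potential is below the threshold, so the rsqrt inside SNN_RMSNorm.forward can receive negative inputs. A minimal standalone sketch of that interaction (PyTorch only; the numbers are illustrative):

import torch

eps = 1e-6
s_rms = torch.tensor([-0.01, 0.0, 1.0])  # a LeakyReLU output below threshold is negative
print(torch.rsqrt(s_rms + eps))          # -> nan, 1000.0, ~1.0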

The full code is below:

from torchvision import transforms
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_, DropPath
from timm.models.registry import register_model
from timm.models.vision_transformer import _cfg
from spikingjelly.clock_driven.neuron import MultiStepLIFNode, MultiStepParametricLIFNode
from transformers import CLIPProcessor, CLIPModel
from accelerate import Accelerator
from dataclasses import dataclass
from typing import Optional, Tuple

import fairscale.nn.model_parallel.initialize as fs_init
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, ParallelEmbedding, RowParallelLinear

import numpy as np
import os
import sys
from torch.nn import Parameter
import abc
from abc import ABC
from einops import rearrange, repeat

accelerator = Accelerator()


@dataclass
class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None
    vocab_size: int = -1  # defined later by tokenizer
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5

    max_batch_size: int = 32
    max_seq_len: int = 2048


class CustomLinear(nn.Module):
    """
    User-defined connection, typically used for STDP-style weight updates.
    """

    def __init__(self, weight, mask=None):
        super().__init__()

        self.weight = nn.Parameter(weight, requires_grad=True)
        self.mask = mask

    def forward(self, x: torch.Tensor):
        """
        :param x: input, x.shape = [N]
        """
        # ret = x.matmul(self.weight), ret.shape = [C]

        return x.matmul(self.weight)

    def update(self, dw):
        """
        :param dw: weight update increment
        """
        with torch.no_grad():
            if self.mask is not None:
                dw *= self.mask
            self.weight.data += dw


class STDP(nn.Module):
    """
    STDP learning rule
    """

    def __init__(self, node, connection, decay=0.99):
        """
        :param node: neuron instance, e.g. IFNode or LIFNode
        :param connection: connection instance; it may contain only a single operation
        """
        super().__init__()

        self.node = node
        self.connection = connection
        self.trace = None
        self.decay = decay

    def forward(self, x):
        """
        Compute the forward pass.
        :return: s is the spike output, dw is the weight update
        """
        x = x.clone().detach()
        i = self.connection(x)
        with torch.no_grad():
            s = self.node(i)

            i.data += s - i.data
            trace = self.cal_trace(x)
            x.data += trace - x.data

        dw = torch.autograd.grad(
            outputs=i, inputs=self.connection.weight, grad_outputs=i)

        return s, dw

    def cal_trace(self, x):
        """
        Compute the eligibility trace.
        """
        if self.trace is None:
            self.trace = Parameter(x.clone().detach(), requires_grad=False)
        else:
            self.trace *= self.decay
            self.trace += x
        return self.trace.detach()

    def reset(self):
        """
        Reset the trace.
        """
        self.trace = None
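
    # Typical usage (comment-only sketch; the sizes and learning rate are illustrative):
    #   learner = STDP(LIFNode(), CustomLinear(torch.rand(16, 32)))
    #   spikes, dw = learner(pre_spikes)          # dw is the 1-tuple returned by torch.autograd.grad
    #   learner.connection.update(0.01 * dw[0])   # apply a scaled weight update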


def heaviside(x):
    return (x >= 0.).to(x.dtype)


class quadratic_gate(torch.autograd.Function):
    """
    Use quadratic_gate as the surrogate gradient function.
    The corresponding primitive function is:

    .. math::
        g(x) =
        \\begin{cases}
        0, & x < -\\frac{1}{\\alpha} \\\\
        -\\frac{1}{2}\\alpha^2|x|x + \\alpha x + \\frac{1}{2}, & |x| \\leq \\frac{1}{\\alpha}  \\\\
        1, & x > \\frac{1}{\\alpha} \\\\
        \\end{cases}

    The gradient used in the backward pass is:

    .. math::
        g'(x) =
        \\begin{cases}
        0, & |x| > \\frac{1}{\\alpha} \\\\
        -\\alpha^2|x|+\\alpha, & |x| \\leq \\frac{1}{\\alpha}
        \\end{cases}

    """

    @staticmethod
    def forward(ctx, x, alpha):
        if x.requires_grad:
            mask_zero = (x.abs() > 1 / alpha)
            grad_x = -alpha * alpha * x.abs() + alpha
            grad_x.masked_fill_(mask_zero, 0)
            ctx.save_for_backward(grad_x)
        return x.gt(0.).float()

    @staticmethod
    def backward(ctx, grad_output):
        grad_x = None
        if ctx.needs_input_grad[0]:
            grad_x = grad_output * ctx.saved_tensors[0]
        return grad_x, None


class SurrogateFunctionBase(nn.Module):
    """
    Base class for surrogate functions.
    :param alpha: parameter that controls the shape of some surrogate functions.
    :param requires_grad: whether ``alpha`` requires a gradient; subclasses default this to ``False``
    """

    def __init__(self, alpha, requires_grad=True):
        super().__init__()
        self.alpha = nn.Parameter(
            torch.tensor(alpha, dtype=torch.float),
            requires_grad=requires_grad)

    @staticmethod
    def act_fun(x, alpha):
        """
        :param x: membrane-potential input
        :param alpha: variable controlling the shape of the surrogate gradient; may be ``None``
        :return: spike after firing, valued in ``[0, 1]``
        """
        raise NotImplementedError

    def forward(self, x):
        """
        :param x: membrane-potential input
        :return: spike after firing
        """
        return self.act_fun(x, self.alpha)


'''
    quadratic gate surrogate function.
'''


class QGateGrad(SurrogateFunctionBase):
    def __init__(self, alpha=2., requires_grad=False):
        super().__init__(alpha, requires_grad)

    @staticmethod
    def act_fun(x, alpha):
        return quadratic_gate.apply(x, alpha)
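
    # Illustrative behaviour (comment-only sketch), with alpha = 2:
    #   forward: quadratic_gate.apply(x, alpha) outputs 1.0 for x > 0 and 0.0 otherwise;
    #   backward: the surrogate gradient -alpha^2 * |x| + alpha is used for |x| <= 1/alpha
    #   (here |x| <= 0.5) and is 0 elsewhere, so inputs far from the threshold get no gradient.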


class relu_like(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha):
        if x.requires_grad:
            ctx.save_for_backward(x, alpha)
        return heaviside(x)

    @staticmethod
    def backward(ctx, grad_output):
        grad_x, grad_alpha = None, None
        x, alpha = ctx.saved_tensors
        if ctx.needs_input_grad[0]:
            grad_x = grad_output * x.gt(0.).float() * alpha
        if ctx.needs_input_grad[1]:
            grad_alpha = (grad_output * F.relu(x)).sum()
        return grad_x, grad_alpha


class RoundGrad(nn.Module):
    def __init__(self, **kwargs):
        super(RoundGrad, self).__init__()
        self.act = nn.Hardtanh(-.5, 4.5)

    def forward(self, x):
        x = self.act(x)
        return x.ceil() + x - x.detach()


class backeigate(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.gt(0.).float()

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        temp = abs(input) < 0.5
        return grad_input * temp.float()


class BackEIGateGrad(SurrogateFunctionBase):
    def __init__(self, alpha=2., requires_grad=False):
        super().__init__(alpha, requires_grad)

    @staticmethod
    def act_fun(x, alpha):
        return backeigate.apply(x)


class ei(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return torch.sign(input).float()

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        temp = abs(input) < 0.5
        return grad_input * temp.float()


class BaseNode(nn.Module, abc.ABC):
    """
    Base class for neuron models.
    :param threshold: firing threshold of the neuron
    :param v_reset: resting potential
    :param dt: time step
    :param step: number of simulation steps
    :param requires_thres_grad: whether to compute gradients w.r.t. the threshold, default ``False``
    :param sigmoid_thres: whether to constrain the threshold to [0, 1] with a sigmoid, default ``False``
    :param requires_fp: whether to store feature maps during inference (extra memory and time), default ``False``
    :param layer_by_layer: whether to compute all steps at once; for large models this usually shortens a single inference, default ``False``
    :param n_groups: whether to use different weights at different time steps, default ``1`` (no grouping)
    :param mem_detach: whether to detach the previous membrane potential from the computation graph
    :param args: other positional arguments
    :param kwargs: other keyword arguments
    """

    def __init__(self,
                 threshold=.5,
                 v_reset=0.,
                 dt=1.,
                 step=8,
                 requires_thres_grad=False,
                 sigmoid_thres=False,
                 requires_fp=False,
                 layer_by_layer=False,
                 n_groups=1,
                 *args,
                 **kwargs):

        super(BaseNode, self).__init__()
        self.threshold = Parameter(torch.tensor(
            threshold), requires_grad=requires_thres_grad)
        self.sigmoid_thres = sigmoid_thres
        self.mem = 0.
        self.spike = 0.
        self.dt = dt
        self.feature_map = []
        self.mem_collect = []
        self.requires_fp = requires_fp
        self.v_reset = v_reset
        self.step = step
        self.layer_by_layer = layer_by_layer
        self.groups = n_groups
        self.mem_detach = kwargs['mem_detach'] if 'mem_detach' in kwargs else False
        self.requires_mem = kwargs['requires_mem'] if 'requires_mem' in kwargs else False

    @abc.abstractmethod
    def calc_spike(self):
        """
        Decide from the current membrane potential whether to fire a spike, then reset.
        :return: None
        """

        pass

    def integral(self, inputs):
        """
        Accumulate the contribution of the current inputs to the membrane potential.
        :param inputs: current synaptic input current
        :type inputs: torch.tensor
        :return: None
        """

        pass

    def get_thres(self):
        return self.threshold if not self.sigmoid_thres else self.threshold.sigmoid()

    def rearrange2node(self, inputs):
        if self.groups != 1:
            if len(inputs.shape) == 4:
                outputs = rearrange(
                    inputs, 'b (c t) w h -> t b c w h', t=self.step)
            elif len(inputs.shape) == 2:
                outputs = rearrange(inputs, 'b (c t) -> t b c', t=self.step)
            else:
                raise NotImplementedError

        elif self.layer_by_layer:
            if len(inputs.shape) == 4:
                outputs = rearrange(
                    inputs, '(t b) c w h -> t b c w h', t=self.step)
            elif len(inputs.shape) == 3:
                outputs = rearrange(
                    inputs, '(t b) n c -> t b n c', t=self.step)
            elif len(inputs.shape) == 2:
                outputs = rearrange(inputs, '(t b) c -> t b c', t=self.step)
            else:
                raise NotImplementedError

        else:
            outputs = inputs

        return outputs

    def rearrange2op(self, inputs):
        if self.groups != 1:
            if len(inputs.shape) == 5:
                outputs = rearrange(inputs, 't b c w h -> b (c t) w h')
            elif len(inputs.shape) == 3:
                outputs = rearrange(inputs, ' t b c -> b (c t)')
            else:
                raise NotImplementedError
        elif self.layer_by_layer:
            if len(inputs.shape) == 5:
                outputs = rearrange(inputs, 't b c w h -> (t b) c w h')
            elif len(inputs.shape) == 4:
                outputs = rearrange(inputs, ' t b n c -> (t b) n c')
            elif len(inputs.shape) == 3:
                outputs = rearrange(inputs, ' t b c -> (t b) c')
            else:
                raise NotImplementedError

        else:
            outputs = inputs

        return outputs

    def forward(self, inputs):
        """
        Default torch.nn.Module entry point: integrates the input into the membrane potential and emits spikes.
        When ``self.requires_fp is True``, ``self.feature_map`` records the spike trace.
        :param inputs: current input to the membrane
        :return: output spikes
        """

        if self.layer_by_layer or self.groups != 1:
            inputs = self.rearrange2node(inputs)

            outputs = []
            for i in range(self.step):

                if self.mem_detach and hasattr(self.mem, 'detach'):
                    self.mem = self.mem.detach()
                    self.spike = self.spike.detach()
                self.integral(inputs[i])

                self.calc_spike()

                if self.requires_fp is True:
                    self.feature_map.append(self.spike)
                if self.requires_mem is True:
                    self.mem_collect.append(self.mem)
                outputs.append(self.spike)
            outputs = torch.stack(outputs)

            outputs = self.rearrange2op(outputs)
            return outputs
        else:
            if self.mem_detach and hasattr(self.mem, 'detach'):
                self.mem = self.mem.detach()
                self.spike = self.spike.detach()
            self.integral(inputs)
            self.calc_spike()
            if self.requires_fp is True:
                self.feature_map.append(self.spike)
            if self.requires_mem is True:
                self.mem_collect.append(self.mem)
            return self.spike

    def n_reset(self):
        """
        Reset all neuron state, used between two unrelated inputs to the model.
        :return: None
        """
        self.mem = self.v_reset
        self.spike = 0.
        self.feature_map = []
        self.mem_collect = []

    def get_n_attr(self, attr):

        if hasattr(self, attr):
            return getattr(self, attr)
        else:
            return None

    def set_n_warm_up(self, flag):
        """
        Some training strategies treat the neuron as an ANN activation function during the first epochs; this sets whether to train that way.
        :param flag: True: the neuron acts as an activation function; False: unchanged
        :return: None
        """
        self.warm_up = flag

    def set_n_threshold(self, thresh):
        """
        Dynamically set the neuron's threshold.
        :param thresh: threshold value
        :return:
        """
        self.threshold = Parameter(torch.tensor(
            thresh, dtype=torch.float), requires_grad=False)

    def set_n_tau(self, tau):
        """
        Dynamically set the decay constant, used by leaky neurons.
        :param tau: decay constant
        :return:
        """
        if hasattr(self, 'tau'):
            self.tau = Parameter(torch.tensor(
                tau, dtype=torch.float), requires_grad=False)
        else:
            raise NotImplementedError


class LIFNode(BaseNode):
    """
    Leaky Integrate and Fire
    :param threshold: firing threshold of the neuron
    :param v_reset: resting potential
    :param dt: time step
    :param step: number of simulation steps
    :param tau: membrane time constant controlling the decay of the membrane potential
    :param act_fun: surrogate gradient used to approximate the firing gradient, default ``QGateGrad``
    :param requires_thres_grad: whether to compute gradients w.r.t. the threshold, default ``False``
    :param sigmoid_thres: whether to constrain the threshold to [0, 1] with a sigmoid, default ``False``
    :param requires_fp: whether to store feature maps during inference (extra memory and time), default ``False``
    :param layer_by_layer: whether to compute all steps at once; for large models this usually shortens a single inference, default ``False``
    :param n_groups: whether to use different weights at different time steps, default ``1`` (no grouping)
    :param args: other positional arguments
    :param kwargs: other keyword arguments
    """

    def __init__(self, threshold=0.5, tau=2., act_fun=QGateGrad, *args, **kwargs):
        super().__init__(threshold, *args, **kwargs)
        self.tau = tau
        if isinstance(act_fun, str):
            act_fun = eval(act_fun)
        self.act_fun = act_fun(alpha=2., requires_grad=False)
        # self.threshold = threshold
        # print(threshold)
        # print(tau)

    def integral(self, inputs):
        self.mem = self.mem + (inputs - self.mem) / self.tau

    def calc_spike(self):
        self.spike = self.act_fun(self.mem - self.threshold)
        self.mem = self.mem * (1 - self.spike.detach())


class LIAFNode(BaseNode):
    """
    Leaky Integrate and Analog Fire (LIAF), Reference: https://ieeexplore.ieee.org/abstract/document/9429228
    Same as LIF, but the value passed forward is derived from the membrane potential; the threshold and membrane updates follow LIF.
    :param act_fun: activation used in the forward pass [ReLU, SELU, LeakyReLU]
    :param threshold_related: threshold-dependent mode; if ``True``, self.spike = act_fun(mem - threshold)
    :note: BaseNode returns self.spike, and here self.spike is an analog value.
    """

    def __init__(self, spike_act=BackEIGateGrad(), act_fun="SELU", threshold=0.5, tau=2., threshold_related=True, *args, **kwargs):
        super().__init__(threshold, *args, **kwargs)
        if isinstance(act_fun, str):
            act_fun = eval("nn." + act_fun + "()")
        self.tau = tau
        self.act_fun = act_fun
        self.spike_act = spike_act
        self.threshold_related = threshold_related

    def integral(self, inputs):
        self.mem = self.mem + (inputs - self.mem) / self.tau

    def calc_spike(self):
        if self.threshold_related:
            spike_tmp = self.act_fun(self.mem - self.threshold)
        else:
            spike_tmp = self.act_fun(self.mem)
        self.spike = self.spike_act(self.mem - self.threshold)
        self.mem = self.mem * (1 - self.spike)
        self.spike = spike_tmp
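
    # Note (comment-only sketch of the behaviour above): forward() returns spike_tmp,
    # i.e. act_fun(mem - threshold) in threshold-related mode. With act_fun='LeakyReLU'
    # this analog output is negative whenever mem < threshold (e.g. LeakyReLU(-0.1) = -0.001),
    # while the binary spike_act output is only used for the membrane reset.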


class SNN_RMSNorm(nn.Module):
    def __init__(self, max_length=128, hidden_size=4096, node=LIAFNode, threshold=0.5, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.rms_neuron = node(act_fun='LeakyReLU', threshold=threshold)
        self.weight_neuron = node(act_fun='ReLU', threshold=threshold)
        self.weight = nn.Parameter(torch.ones(hidden_size, hidden_size))
        self.rms_connection = CustomLinear(torch.ones(1, hidden_size))
        self.weight_connection = CustomLinear(self.weight)

    def forward(self, x):
        x_sqr = x ** 2
        x_rms = x_sqr.mean(-1, keepdim=True)
        s_rms = self.rms_neuron(self.rms_connection(x_rms))
        rms_out = torch.rsqrt(s_rms + self.eps)
        s_scale = self.weight_neuron(self.weight_connection(rms_out))
        return s_scale
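
A minimal way to exercise the module and see where non-finite values first enter (a sketch using small illustrative sizes; a fresh instance is used for the intermediate check because the neurons keep membrane state between calls):

if __name__ == "__main__":
    torch.manual_seed(0)
    norm = SNN_RMSNorm(max_length=8, hidden_size=64)
    x = torch.randn(2, 8, 64)
    out = norm(x)
    print(out.shape, "any nan:", torch.isnan(out).any().item())

    # Inspect the intermediate that feeds torch.rsqrt, on a fresh instance.
    probe = SNN_RMSNorm(max_length=8, hidden_size=64)
    s_rms = probe.rms_neuron(probe.rms_connection((x ** 2).mean(-1, keepdim=True)))
    print("values < 0 entering rsqrt:", (s_rms < 0).sum().item())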