
GPU memory doubles after switching back to training with Baichuan2-7B-Base #387

Open
Mr-KenLee opened this issue Mar 7, 2024 · 1 comment

Comments

Mr-KenLee commented Mar 7, 2024

I am LoRA-finetuning Baichuan2-7B-Base with Seq2SeqTrainer, and oddly I hit an OOM right after the first evaluation pass, while Baichuan2-7B-Chat does not have this problem.
I also found that the Base model's OOM happens when returning from evaluation to training: the model seems to be loaded a second time, doubling GPU memory usage and causing the OOM.
Comparing the two models' modeling.py files, the main difference is the following code in Base:

class NormHead(nn.Module):
    def __init__(self, hidden_size, vocab_size, bias=False):
        super().__init__()
        self.weight = nn.Parameter(torch.empty((vocab_size, hidden_size)))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        self.first_flag = True

    def forward(self, hidden_states):
        if self.training:
            norm_weight = nn.functional.normalize(self.weight)
        elif self.first_flag:
            self.first_flag = False
            # rebinds self.weight to a brand-new Parameter object
            self.weight = nn.Parameter(nn.functional.normalize(self.weight))
            norm_weight = self.weight
        else:
            norm_weight = self.weight
        return nn.functional.linear(hidden_states, norm_weight)

Whereas in Chat it is:

class NormHead(nn.Module):
    def __init__(self, hidden_size, vocab_size, bias=False):
        super().__init__()
        self.weight = nn.Parameter(torch.empty((vocab_size, hidden_size)))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        self.first_flag = True

    def forward(self, hidden_states):
        if self.training:
            norm_weight = nn.functional.normalize(self.weight)
            self.first_flag = True  # flag is reset whenever training resumes
        elif self.first_flag:
            self.first_flag = False
            # replaces the underlying data but keeps the same Parameter object
            self.weight.data = nn.functional.normalize(self.weight)
            norm_weight = self.weight
        else:
            norm_weight = self.weight
        return nn.functional.linear(hidden_states, norm_weight)

After replacing the NormHead in Base with the one from Chat, the problem is solved. Could someone explain why this happens? And can the two modeling files be used interchangeably?
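Not confirmed by the maintainers, but the memory doubling is consistent with the difference between the two branches: `self.weight = nn.Parameter(...)` rebinds the attribute to a new Parameter object, while `self.weight.data = ...` keeps the Parameter's identity. Any external holder of the old Parameter (the optimizer, a LoRA wrapper) keeps the old tensor alive and stops updating the new one. A minimal sketch (the `Head` class and variable names are made up for illustration):

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for NormHead, reduced to a single weight.
class Head(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(4, 4))

# Base-style: rebinding self.weight creates a brand-new Parameter object.
base = Head()
opt_base = torch.optim.SGD(base.parameters(), lr=0.1)
old_param = base.weight
with torch.no_grad():
    base.weight = nn.Parameter(nn.functional.normalize(base.weight))
assert base.weight is not old_param
# The optimizer still holds the OLD Parameter: its updates no longer
# reach base.weight, and the old tensor's storage stays referenced.
assert opt_base.param_groups[0]["params"][0] is old_param

# Chat-style: assigning to .data keeps the Parameter object's identity,
# so the optimizer (or a LoRA wrapper) still points at the live weight.
chat = Head()
opt_chat = torch.optim.SGD(chat.parameters(), lr=0.1)
with torch.no_grad():
    chat.weight.data = nn.functional.normalize(chat.weight)
assert opt_chat.param_groups[0]["params"][0] is chat.weight
```

With the Base-style rebinding, both the old and the new weight tensor are reachable at once, which matches the observed jump in memory when switching between eval and training.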

Author

Mr-KenLee commented Mar 7, 2024

Is it mainly caused by the missing `self.first_flag = True`? Without it, when Base switches from prediction back to training and then to prediction again, it can never re-enter the target branch?
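That reading matches the branch structure: only the Chat version resets `first_flag` while training, so only Chat re-normalizes the stored weight on the next eval pass. A minimal pure-Python simulation of just the flag bookkeeping (no tensors involved; `run` is a made-up helper):

```python
def run(reset_in_training):
    """Count how often the eval-time normalization branch fires
    across a train -> eval -> train -> eval cycle."""
    first_flag = True
    normalizations = 0
    for mode in ["train", "eval", "train", "eval"]:
        if mode == "train":
            if reset_in_training:  # Chat has this line, Base does not
                first_flag = True
        elif first_flag:
            first_flag = False
            normalizations += 1  # weight would be re-normalized here
    return normalizations

assert run(reset_in_training=True) == 2   # Chat: re-normalizes after each training phase
assert run(reset_in_training=False) == 1  # Base: normalizes once, then the branch is dead
```

So the missing reset explains the stale branch, while the `nn.Parameter(...)` rebinding in Base is what actually duplicates the weight in memory.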
