Inconsistent init methods of pythia-6.9b model #135

Open

mqyqlx opened this issue Nov 17, 2023 · 2 comments

Comments


mqyqlx commented Nov 17, 2023

Hi, I found that the initialization of the parameters in the pythia-6.9B model is inconsistent with the standard deviations of the step0 checkpoint. Table 6 in the paper shows that init-method is small-init and output-layer-init-method is wang-init, but I get different std values from the step0 model.

Inconsistent std values:

input_layer_std: 0.009882117688026186 (small_init) vs. 0.02 (std computed from step0 model parameters)
output_layer_std: 0.0009765625 (wang_init) vs. 0.0025 (std computed from step0 model parameters)
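
For reference, with hidden size d = 4096 and L = 32 layers, the formulas for the two init methods (see the script below) give small_init std = sqrt(2 / (5 * 4096)) ≈ 0.00988 and wang_init std = 2 / (32 * sqrt(4096)) = 2 / 2048 ≈ 0.000977, so the checkpoint values are roughly 2x and 2.5x larger than expected.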

Could you provide the real init method? Thanks!

Config (Table 6 of the paper):
[image: Table 6 hyperparameter configuration, listing init-method = small-init and output-layer-init-method = wang-init]

Here is a script to reproduce this, followed by the results.

import math

from transformers import GPTNeoXForCausalLM

model = GPTNeoXForCausalLM.from_pretrained(
    "EleutherAI/pythia-6.9b",
    revision="step0",
)

model_dim = 4096   # hidden size of Pythia-6.9b
num_layers = 32    # Pythia-6.9b has 32 transformer layers

# Compute the expected std values of the two init methods.
# Reference: https://github.com/EleutherAI/gpt-neox/blob/v1.0/megatron/model/init_functions.py#L101-L118
small_init_std = (2 / (5 * model_dim)) ** 0.5            # small_init: sqrt(2 / (5 * d))
wang_init_std = 2 / (num_layers * math.sqrt(model_dim))  # wang_init: 2 / (L * sqrt(d))
print('small_init_std:', small_init_std)
print('wang_init_std:', wang_init_std)

# Empirical std of every parameter in the step0 checkpoint.
for n, p in model.named_parameters():
    print(n, p.shape, p.std().item())

Results:

small_init_std: 0.009882117688026186
wang_init_std: 0.0009765625

gpt_neox.embed_in.weight torch.Size([50432, 4096]) 0.019999271258711815
gpt_neox.layers.0.input_layernorm.weight torch.Size([4096]) 0.0
gpt_neox.layers.0.input_layernorm.bias torch.Size([4096]) 0.0
gpt_neox.layers.0.post_attention_layernorm.weight torch.Size([4096]) 0.0
gpt_neox.layers.0.post_attention_layernorm.bias torch.Size([4096]) 0.0
gpt_neox.layers.0.attention.query_key_value.weight torch.Size([12288, 4096]) 0.019999688491225243
gpt_neox.layers.0.attention.query_key_value.bias torch.Size([12288]) 0.0
gpt_neox.layers.0.attention.dense.weight torch.Size([4096, 4096]) 0.002499272581189871
gpt_neox.layers.0.attention.dense.bias torch.Size([4096]) 0.0
gpt_neox.layers.0.mlp.dense_h_to_4h.weight torch.Size([16384, 4096]) 0.019998779520392418
gpt_neox.layers.0.mlp.dense_h_to_4h.bias torch.Size([16384]) 0.0
gpt_neox.layers.0.mlp.dense_4h_to_h.weight torch.Size([4096, 16384]) 0.0024998513981699944
gpt_neox.layers.0.mlp.dense_4h_to_h.bias torch.Size([4096]) 0.0
gpt_neox.layers.1.input_layernorm.weight torch.Size([4096]) 0.0
gpt_neox.layers.1.input_layernorm.bias torch.Size([4096]) 0.0
gpt_neox.layers.1.post_attention_layernorm.weight torch.Size([4096]) 0.0
gpt_neox.layers.1.post_attention_layernorm.bias torch.Size([4096]) 0.0
gpt_neox.layers.1.attention.query_key_value.weight torch.Size([12288, 4096]) 0.01999974064528942
gpt_neox.layers.1.attention.query_key_value.bias torch.Size([12288]) 0.0
gpt_neox.layers.1.attention.dense.weight torch.Size([4096, 4096]) 0.0025000576861202717
gpt_neox.layers.1.attention.dense.bias torch.Size([4096]) 0.0
gpt_neox.layers.1.mlp.dense_h_to_4h.weight torch.Size([16384, 4096]) 0.02000279724597931
gpt_neox.layers.1.mlp.dense_h_to_4h.bias torch.Size([16384]) 0.0
gpt_neox.layers.1.mlp.dense_4h_to_h.weight torch.Size([4096, 16384]) 0.002499587135389447
gpt_neox.layers.1.mlp.dense_4h_to_h.bias torch.Size([4096]) 0.0
...
@StellaAthena (Member) commented:

This is very weird. Have you been able to form any tentative hypotheses about it?


mqyqlx commented Mar 19, 2024

> This is very weird. Have you been able to form any tentative hypotheses about it?

Not yet. My guess is that these two standard deviations used in Pythia-6.9B were set empirically rather than calculated from a formula.
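
One pattern worth noting (just a guess on my side, not confirmed by any released config): the step0 stds happen to match a plain normal init with a base std of 0.02 combined with a Megatron-style scaled output init of sigma / sqrt(2 * num_layers):

import math

# Hypothetical values, assumed for this check rather than taken from a Pythia config:
sigma = 0.02      # base init std
num_layers = 32   # Pythia-6.9b

# Megatron-style scaled residual-output init: sigma / sqrt(2 * L)
scaled_output_std = sigma / math.sqrt(2 * num_layers)

print('base std:', sigma)                       # 0.02   ~ embed_in / qkv / h_to_4h stds
print('scaled output std:', scaled_output_std)  # 0.0025 ~ attention.dense / 4h_to_h stds

If that is what happened, the formulas in Table 6 would simply not describe the init code path actually used for this model.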
