AssertionError while finetuning RWKVv5 #216

Open
Ethan-Chen-plus opened this issue Jan 13, 2024 · 8 comments

@Ethan-Chen-plus

While finetuning RWKV, I used this script (with the demo dataset generated by make_data.py, and demo.bin and demo.idx placed in ./data):

#!/bin/bash

BASE_NAME="model/demo"
N_LAYER="12"
N_EMBD="768"
M_BSZ="16" # takes 16G VRAM (reduce this to save VRAM)
LR_INIT="6e-4"
LR_FINAL="6e-5"
GRAD_CP=0 # set to 1 to save VRAM (will be slower)
EPOCH_SAVE=10

# magic_prime = the largest 3n+2 prime smaller than datalen/ctxlen-1 (= 1498226207/512-1 = 2926222.06 in this case)
# use https://www.dcode.fr/prime-numbers-search

python train.py --load_model "../../models/RWKV-5-World-0.4B-v2-20231113-ctx4096.pth" --wandb "RWKV-5-Test" --proj_dir $BASE_NAME \
 --ctx_len 512 --my_pile_stage 3 --epoch_count 999999 --epoch_begin 0 \
 --data_file "data/demo" --my_exit_tokens 1498226207 --magic_prime 2926181 \
 --num_nodes 1 --micro_bsz $M_BSZ --n_layer $N_LAYER --n_embd $N_EMBD --pre_ffn 0 --head_qk 0 \
 --lr_init $LR_INIT --lr_final $LR_FINAL --warmup_steps 10 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 --my_pile_edecay 0 --data_type "binidx" --vocab_size 65536 \
 --weight_decay 0.001 --epoch_save $EPOCH_SAVE --head_size_a 64 \
 --accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2 --grad_cp $GRAD_CP --ds_bucket_mb 200
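
The second log further down fails because the checkpoint passed to --load_model does not have the shapes implied by N_LAYER=12 and N_EMBD=768. Below is a minimal sketch for reading the architecture out of a checkpoint before launching training. It assumes the .pth file is a plain state dict (the traceback shows it being passed straight to `model.load_state_dict`) and that parameter names follow the pattern in that log (`emb.weight`, `blocks.<i>.…`). `read_shapes` and `infer_dims` are hypothetical helpers, not part of RWKV-LM.

```python
import re
import sys
from typing import Dict, Tuple

Shapes = Dict[str, Tuple[int, ...]]


def read_shapes(path: str) -> Shapes:
    """Load a checkpoint on CPU and return {parameter name: shape}."""
    import torch  # imported lazily so infer_dims() works without torch

    state_dict = torch.load(path, map_location="cpu")
    return {name: tuple(t.shape) for name, t in state_dict.items()}


def infer_dims(shapes: Shapes) -> Dict[str, int]:
    """Infer RWKV-5 sizes from parameter shapes (hypothetical helper)."""
    vocab_size, n_embd = shapes["emb.weight"]
    # Layers are stored as blocks.0 ... blocks.(n_layer - 1)
    layer_ids = {int(m.group(1)) for name in shapes
                 if (m := re.match(r"blocks\.(\d+)\.", name))}
    # time_decay is [n_head, head_size] in the log above
    n_head, head_size = shapes["blocks.0.att.time_decay"]
    dim_ffn = shapes["blocks.0.ffn.key.weight"][0]
    return {
        "vocab_size": vocab_size,
        "n_embd": n_embd,
        "n_layer": max(layer_ids) + 1,
        "n_head": n_head,
        "head_size_a": head_size,
        "dim_ffn": dim_ffn,
    }


if __name__ == "__main__" and len(sys.argv) > 1:
    # Usage: python infer_dims.py path/to/checkpoint.pth
    print(infer_dims(read_shapes(sys.argv[1])))
```

Applied to the shapes in the second log (emb.weight is [65536, 1024] and the blocks run from 0 to 23), this would report n_layer 24 and n_embd 1024, so those are the values N_LAYER and N_EMBD would need for this checkpoint.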

I got this error:

(rwkv) ubuntu@ip-172-31-67-197:~/MedicalGPT/rwkv/RWKV-LM/RWKV-v5$ CUDA_VISIBLE_DEVICES=2 bash demo-training-run-demo.sh
INFO:torch.distributed.nn.jit.instantiator:Created a temporary directory at /tmp/tmpw45qi2d_
INFO:torch.distributed.nn.jit.instantiator:Writing /tmp/tmpw45qi2d_/_remote_module_non_scriptable.py
INFO:pytorch_lightning.utilities.rank_zero:########## work in progress ##########
[2024-01-13 12:22:27,924] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to cuda (auto detect)
INFO:pytorch_lightning.utilities.rank_zero:
############################################################################
#
# RWKV-5 BF16 on 1x1 GPU, bsz 1x1x16=16, deepspeed_stage_2 
#
# Data = data/demo (binidx), ProjDir = model/demo
#
# Epoch = 0 to 71 (will continue afterwards), save every 10 epoch
#
# Each "epoch" = 2520 steps, 40320 samples, 20643840 tokens
#
# Model = 12 n_layer, 768 n_embd, 512 ctx_len
#
# Adam = lr 0.0006 to 6e-05, warmup 10 steps, beta (0.9, 0.99), eps 1e-08
#
# Found torch 1.13.1+cu117, recommend 1.13.1+cu117 or newer
# Found deepspeed 0.12.6, recommend 0.7.0 (faster than newer versions)
# Found pytorch_lightning 1.9.5, recommend 1.9.5
#
############################################################################

INFO:pytorch_lightning.utilities.rank_zero:{'load_model': '../../models/RWKV-5-World-0.4B-v2-20231113-ctx4096.pth', 'wandb': 'RWKV-5-Test', 'proj_dir': 'model/demo', 'random_seed': -1, 'data_file': 'data/demo', 'data_type': 'binidx', 'vocab_size': 65536, 'ctx_len': 512, 'epoch_steps': 2520, 'epoch_count': 72, 'epoch_begin': 0, 'epoch_save': 10, 'micro_bsz': 16, 'n_layer': 12, 'n_embd': 768, 'dim_att': 768, 'dim_ffn': 2688, 'pre_ffn': 0, 'head_qk': 0, 'tiny_att_dim': 0, 'tiny_att_layer': -999, 'lr_init': 0.0006, 'lr_final': 6e-05, 'warmup_steps': 10, 'beta1': 0.9, 'beta2': 0.99, 'adam_eps': 1e-08, 'grad_cp': 0, 'dropout': 0, 'weight_decay': 0.001, 'weight_decay_final': -1, 'my_pile_version': 1, 'my_pile_stage': 3, 'my_pile_shift': 0, 'my_pile_edecay': 0, 'layerwise_lr': 1, 'ds_bucket_mb': 200, 'my_sample_len': 0, 'my_ffn_shift': 1, 'my_att_shift': 1, 'head_size_a': 64, 'head_size_divisor': 8, 'my_pos_emb': 0, 'load_partial': 0, 'magic_prime': 2926181, 'my_qa_mask': 0, 'my_random_steps': 0, 'my_testing': '', 'my_exit': 99999999, 'my_exit_tokens': 1498226207, 'logger': False, 'enable_checkpointing': False, 'default_root_dir': None, 'gradient_clip_val': 1.0, 'gradient_clip_algorithm': None, 'num_nodes': 1, 'num_processes': None, 'devices': '1', 'gpus': None, 'auto_select_gpus': None, 'tpu_cores': None, 'ipus': None, 'enable_progress_bar': True, 'overfit_batches': 0.0, 'track_grad_norm': -1, 'check_val_every_n_epoch': 100000000000000000000, 'fast_dev_run': False, 'accumulate_grad_batches': None, 'max_epochs': -1, 'min_epochs': None, 'max_steps': -1, 'min_steps': None, 'max_time': None, 'limit_train_batches': None, 'limit_val_batches': None, 'limit_test_batches': None, 'limit_predict_batches': None, 'val_check_interval': None, 'log_every_n_steps': 100000000000000000000, 'accelerator': 'gpu', 'strategy': 'deepspeed_stage_2', 'sync_batchnorm': False, 'precision': 'bf16', 'enable_model_summary': True, 'num_sanity_val_steps': 0, 'resume_from_checkpoint': None, 'profiler': 
None, 'benchmark': None, 'reload_dataloaders_every_n_epochs': 0, 'auto_lr_find': False, 'replace_sampler_ddp': False, 'detect_anomaly': False, 'auto_scale_batch_size': False, 'plugins': None, 'amp_backend': None, 'amp_level': None, 'move_metrics_to_cpu': False, 'multiple_trainloader_mode': 'max_size_cycle', 'inference_mode': True, 'my_timestamp': '2024-01-13-12-22-29', 'betas': (0.9, 0.99), 'real_bsz': 16, 'run_name': '65536 ctx512 L12 D768'}

INFO:pytorch_lightning.utilities.rank_zero:Current vocab size = 65536 (make sure it's correct)
INFO:pytorch_lightning.utilities.rank_zero:Data has 200499 tokens.
INFO:pytorch_lightning.utilities.rank_zero:########## Pile 20b-tokenized stage 3 ##########
Traceback (most recent call last):
  File "/home/ubuntu/MedicalGPT/rwkv/RWKV-LM/RWKV-v5/train.py", line 248, in <module>
    train_data = MyDataset(args)
  File "/home/ubuntu/MedicalGPT/rwkv/RWKV-LM/RWKV-v5/src/dataset.py", line 56, in __init__
    assert args.magic_prime / dataset_slot > 0.99 and args.magic_prime / dataset_slot <= 1
AssertionError
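
The assertion compares magic_prime with the number of ctx_len-sized slots in the data. Here is a sketch of the same ratio test, assuming `dataset_slot = data_size // ctx_len` (the traceback shows only the assert, not how dataset_slot is computed); `magic_prime_ok` is a hypothetical name. With the 200499-token demo data there are only 391 slots of 512 tokens, so the value 2926181 (computed in the script comment for a 1498226207-token dataset) overshoots by a factor of several thousand.

```python
def magic_prime_ok(magic_prime: int, data_size: int, ctx_len: int) -> bool:
    """Ratio test from the failing assert in dataset.py (line 56).

    Assumption: dataset_slot is the number of whole ctx_len windows in the data.
    """
    dataset_slot = data_size // ctx_len
    ratio = magic_prime / dataset_slot
    return 0.99 < ratio <= 1


if __name__ == "__main__":
    # 200499 tokens ("Data has 200499 tokens."), ctx_len 512 -> 391 slots
    for prime in (2926181, 389):
        print(prime, magic_prime_ok(prime, 200499, 512))
    # prints "2926181 False", then "389 True"
```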
@BlinkDL
Owner

BlinkDL commented Jan 14, 2024

Data has 200499 tokens

Therefore, set my_exit_tokens to 200499, and note:
magic_prime = the largest 3n+2 prime smaller than datalen/ctxlen-1 (= 200499/512-1 = 390.599609375 in this case)
use https://www.dcode.fr/prime-numbers-search

Therefore, set magic_prime = 389.
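
The rule above can be scripted instead of looked up by hand. Below is a minimal sketch, assuming trial division is fast enough (it is for bounds of a few million) and using the integer form of "p < datalen/ctxlen - 1", i.e. (p + 1) * ctxlen < datalen. `find_magic_prime` and `is_prime` are hypothetical helpers. It also prints magic_prime // 40320, on the assumption that stage-3 runs size an "epoch" at 40320 samples (as the log banner reports) and derive the epoch count this way. That is consistent with the epoch_count of 72 logged for magic_prime 2926181.

```python
import math


def is_prime(n: int) -> bool:
    """Trial division; adequate for the few-million-sized bounds used here."""
    if n < 2:
        return False
    if n % 2 == 0:
        return n == 2
    for d in range(3, math.isqrt(n) + 1, 2):
        if n % d == 0:
            return False
    return True


def find_magic_prime(datalen: int, ctxlen: int) -> int:
    """Largest prime p with p % 3 == 2 and p < datalen / ctxlen - 1."""
    # Largest integer p satisfying (p + 1) * ctxlen < datalen
    p = (datalen - ctxlen - 1) // ctxlen
    while p >= 2:
        if p % 3 == 2 and is_prime(p):
            return p
        p -= 1
    raise ValueError("data too small for this ctx_len")


if __name__ == "__main__":
    datalen, ctxlen = 200499, 512  # "Data has 200499 tokens", --ctx_len 512
    prime = find_magic_prime(datalen, ctxlen)
    # Assumption: epoch_count = magic_prime // 40320 samples per "epoch"
    print(prime, prime // 40320)  # prints "389 0"
```

For the demo data this yields 389 and an epoch count of 0: the whole 200499-token dataset is smaller than a single 40320-sample (20643840-token) epoch. This lines up with the "Epoch = 0 to -1" banner in the follow-up log.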

@Ethan-Chen-plus
Author

Thanks for answering. But another error still occurs:

CUDA_VISIBLE_DEVICES=2 bash demo-training-run-demo.sh

INFO:torch.distributed.nn.jit.instantiator:Created a temporary directory at /tmp/tmpyy254i6t
INFO:torch.distributed.nn.jit.instantiator:Writing /tmp/tmpyy254i6t/_remote_module_non_scriptable.py
INFO:pytorch_lightning.utilities.rank_zero:########## work in progress ##########
[2024-01-16 12:35:16,735] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to cuda (auto detect)
INFO:pytorch_lightning.utilities.rank_zero:
############################################################################
#
# RWKV-5 BF16 on 1x1 GPU, bsz 1x1x16=16, deepspeed_stage_2 
#
# Data = data/demo (binidx), ProjDir = model/demo
#
# Epoch = 0 to -1 (will continue afterwards), save every 10 epoch
#
# Each "epoch" = 2520 steps, 40320 samples, 20643840 tokens
#
# Model = 12 n_layer, 768 n_embd, 512 ctx_len
#
# Adam = lr 0.0006 to 6e-05, warmup 10 steps, beta (0.9, 0.99), eps 1e-08
#
# Found torch 1.13.1+cu117, recommend 1.13.1+cu117 or newer
# Found deepspeed 0.12.6, recommend 0.7.0 (faster than newer versions)
# Found pytorch_lightning 1.9.5, recommend 1.9.5
#
############################################################################

INFO:pytorch_lightning.utilities.rank_zero:{'load_model': '../../models/RWKV-5-World-0.4B-v2-20231113-ctx4096.pth', 'wandb': 'RWKV-5-Test', 'proj_dir': 'model/demo', 'random_seed': -1, 'data_file': 'data/demo', 'data_type': 'binidx', 'vocab_size': 65536, 'ctx_len': 512, 'epoch_steps': 2520, 'epoch_count': 0, 'epoch_begin': 0, 'epoch_save': 10, 'micro_bsz': 16, 'n_layer': 12, 'n_embd': 768, 'dim_att': 768, 'dim_ffn': 2688, 'pre_ffn': 0, 'head_qk': 0, 'tiny_att_dim': 0, 'tiny_att_layer': -999, 'lr_init': 0.0006, 'lr_final': 6e-05, 'warmup_steps': 10, 'beta1': 0.9, 'beta2': 0.99, 'adam_eps': 1e-08, 'grad_cp': 0, 'dropout': 0, 'weight_decay': 0.001, 'weight_decay_final': -1, 'my_pile_version': 1, 'my_pile_stage': 3, 'my_pile_shift': 0, 'my_pile_edecay': 0, 'layerwise_lr': 1, 'ds_bucket_mb': 200, 'my_sample_len': 0, 'my_ffn_shift': 1, 'my_att_shift': 1, 'head_size_a': 64, 'head_size_divisor': 8, 'my_pos_emb': 0, 'load_partial': 0, 'magic_prime': 389, 'my_qa_mask': 0, 'my_random_steps': 0, 'my_testing': '', 'my_exit': 99999999, 'my_exit_tokens': 1498226207, 'logger': False, 'enable_checkpointing': False, 'default_root_dir': None, 'gradient_clip_val': 1.0, 'gradient_clip_algorithm': None, 'num_nodes': 1, 'num_processes': None, 'devices': '1', 'gpus': None, 'auto_select_gpus': None, 'tpu_cores': None, 'ipus': None, 'enable_progress_bar': True, 'overfit_batches': 0.0, 'track_grad_norm': -1, 'check_val_every_n_epoch': 100000000000000000000, 'fast_dev_run': False, 'accumulate_grad_batches': None, 'max_epochs': -1, 'min_epochs': None, 'max_steps': -1, 'min_steps': None, 'max_time': None, 'limit_train_batches': None, 'limit_val_batches': None, 'limit_test_batches': None, 'limit_predict_batches': None, 'val_check_interval': None, 'log_every_n_steps': 100000000000000000000, 'accelerator': 'gpu', 'strategy': 'deepspeed_stage_2', 'sync_batchnorm': False, 'precision': 'bf16', 'enable_model_summary': True, 'num_sanity_val_steps': 0, 'resume_from_checkpoint': None, 'profiler': None, 
'benchmark': None, 'reload_dataloaders_every_n_epochs': 0, 'auto_lr_find': False, 'replace_sampler_ddp': False, 'detect_anomaly': False, 'auto_scale_batch_size': False, 'plugins': None, 'amp_backend': None, 'amp_level': None, 'move_metrics_to_cpu': False, 'multiple_trainloader_mode': 'max_size_cycle', 'inference_mode': True, 'my_timestamp': '2024-01-16-12-35-18', 'betas': (0.9, 0.99), 'real_bsz': 16, 'run_name': '65536 ctx512 L12 D768'}

INFO:pytorch_lightning.utilities.rank_zero:Current vocab size = 65536 (make sure it's correct)
INFO:pytorch_lightning.utilities.rank_zero:Data has 200499 tokens.
INFO:pytorch_lightning.utilities.rank_zero:########## Pile 20b-tokenized stage 3 ##########
RWKV_MY_TESTING 
Using /home/ubuntu/.cache/torch_extensions/py310_cu117 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/ubuntu/.cache/torch_extensions/py310_cu117/wkv5/build.ninja...
Building extension module wkv5...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
[1/2] /usr/bin/g++-10 -MMD -MF wkv5_op.o.d -DTORCH_EXTENSION_NAME=wkv5 -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/torch/include -isystem /home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/torch/include/TH -isystem /home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/torch/include/THC -isystem /home/ubuntu/micromamba/envs/rwkv/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++14 -c /home/ubuntu/MedicalGPT/rwkv/RWKV-LM/RWKV-v5/cuda/wkv5_op.cpp -o wkv5_op.o 
[2/2] /usr/bin/g++-10 wkv5_op.o wkv5_cuda.cuda.o -shared -L/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/torch/lib -lc10 -lc10_cuda -ltorch_cpu -ltorch_cuda_cu -ltorch_cuda_cpp -ltorch -ltorch_python -L/usr/lib64 -lcudart -o wkv5.so
Loading extension module wkv5...
INFO:pytorch_lightning.utilities.rank_zero:########## Loading ../../models/RWKV-5-World-0.4B-v2-20231113-ctx4096.pth... ##########
Traceback (most recent call last):
  File "/home/ubuntu/MedicalGPT/rwkv/RWKV-LM/RWKV-v5/train.py", line 284, in <module>
    model.load_state_dict(load_dict)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1671, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for RWKV:
        Unexpected key(s) in state_dict: "blocks.12.ln1.weight", "blocks.12.ln1.bias", "blocks.12.ln2.weight", "blocks.12.ln2.bias", "blocks.12.att.time_mix_k", "blocks.12.att.time_mix_v", "blocks.12.att.time_mix_r", "blocks.12.att.time_mix_g", "blocks.12.att.time_decay", "blocks.12.att.time_faaaa", "blocks.12.att.receptance.weight", "blocks.12.att.key.weight", "blocks.12.att.value.weight", "blocks.12.att.output.weight", "blocks.12.att.gate.weight", "blocks.12.att.ln_x.weight", "blocks.12.att.ln_x.bias", "blocks.12.ffn.time_mix_k", "blocks.12.ffn.time_mix_r", "blocks.12.ffn.key.weight", "blocks.12.ffn.receptance.weight", "blocks.12.ffn.value.weight", "blocks.13.ln1.weight", "blocks.13.ln1.bias", "blocks.13.ln2.weight", "blocks.13.ln2.bias", "blocks.13.att.time_mix_k", "blocks.13.att.time_mix_v", "blocks.13.att.time_mix_r", "blocks.13.att.time_mix_g", "blocks.13.att.time_decay", "blocks.13.att.time_faaaa", "blocks.13.att.receptance.weight", "blocks.13.att.key.weight", "blocks.13.att.value.weight", "blocks.13.att.output.weight", "blocks.13.att.gate.weight", "blocks.13.att.ln_x.weight", "blocks.13.att.ln_x.bias", "blocks.13.ffn.time_mix_k", "blocks.13.ffn.time_mix_r", "blocks.13.ffn.key.weight", "blocks.13.ffn.receptance.weight", "blocks.13.ffn.value.weight", "blocks.14.ln1.weight", "blocks.14.ln1.bias", "blocks.14.ln2.weight", "blocks.14.ln2.bias", "blocks.14.att.time_mix_k", "blocks.14.att.time_mix_v", "blocks.14.att.time_mix_r", "blocks.14.att.time_mix_g", "blocks.14.att.time_decay", "blocks.14.att.time_faaaa", "blocks.14.att.receptance.weight", "blocks.14.att.key.weight", "blocks.14.att.value.weight", "blocks.14.att.output.weight", "blocks.14.att.gate.weight", "blocks.14.att.ln_x.weight", "blocks.14.att.ln_x.bias", "blocks.14.ffn.time_mix_k", "blocks.14.ffn.time_mix_r", "blocks.14.ffn.key.weight", "blocks.14.ffn.receptance.weight", "blocks.14.ffn.value.weight", "blocks.15.ln1.weight", "blocks.15.ln1.bias", "blocks.15.ln2.weight", "blocks.15.ln2.bias", 
"blocks.15.att.time_mix_k", "blocks.15.att.time_mix_v", "blocks.15.att.time_mix_r", "blocks.15.att.time_mix_g", "blocks.15.att.time_decay", "blocks.15.att.time_faaaa", "blocks.15.att.receptance.weight", "blocks.15.att.key.weight", "blocks.15.att.value.weight", "blocks.15.att.output.weight", "blocks.15.att.gate.weight", "blocks.15.att.ln_x.weight", "blocks.15.att.ln_x.bias", "blocks.15.ffn.time_mix_k", "blocks.15.ffn.time_mix_r", "blocks.15.ffn.key.weight", "blocks.15.ffn.receptance.weight", "blocks.15.ffn.value.weight", "blocks.16.ln1.weight", "blocks.16.ln1.bias", "blocks.16.ln2.weight", "blocks.16.ln2.bias", "blocks.16.att.time_mix_k", "blocks.16.att.time_mix_v", "blocks.16.att.time_mix_r", "blocks.16.att.time_mix_g", "blocks.16.att.time_decay", "blocks.16.att.time_faaaa", "blocks.16.att.receptance.weight", "blocks.16.att.key.weight", "blocks.16.att.value.weight", "blocks.16.att.output.weight", "blocks.16.att.gate.weight", "blocks.16.att.ln_x.weight", "blocks.16.att.ln_x.bias", "blocks.16.ffn.time_mix_k", "blocks.16.ffn.time_mix_r", "blocks.16.ffn.key.weight", "blocks.16.ffn.receptance.weight", "blocks.16.ffn.value.weight", "blocks.17.ln1.weight", "blocks.17.ln1.bias", "blocks.17.ln2.weight", "blocks.17.ln2.bias", "blocks.17.att.time_mix_k", "blocks.17.att.time_mix_v", "blocks.17.att.time_mix_r", "blocks.17.att.time_mix_g", "blocks.17.att.time_decay", "blocks.17.att.time_faaaa", "blocks.17.att.receptance.weight", "blocks.17.att.key.weight", "blocks.17.att.value.weight", "blocks.17.att.output.weight", "blocks.17.att.gate.weight", "blocks.17.att.ln_x.weight", "blocks.17.att.ln_x.bias", "blocks.17.ffn.time_mix_k", "blocks.17.ffn.time_mix_r", "blocks.17.ffn.key.weight", "blocks.17.ffn.receptance.weight", "blocks.17.ffn.value.weight", "blocks.18.ln1.weight", "blocks.18.ln1.bias", "blocks.18.ln2.weight", "blocks.18.ln2.bias", "blocks.18.att.time_mix_k", "blocks.18.att.time_mix_v", "blocks.18.att.time_mix_r", "blocks.18.att.time_mix_g", "blocks.18.att.time_decay", 
"blocks.18.att.time_faaaa", "blocks.18.att.receptance.weight", "blocks.18.att.key.weight", "blocks.18.att.value.weight", "blocks.18.att.output.weight", "blocks.18.att.gate.weight", "blocks.18.att.ln_x.weight", "blocks.18.att.ln_x.bias", "blocks.18.ffn.time_mix_k", "blocks.18.ffn.time_mix_r", "blocks.18.ffn.key.weight", "blocks.18.ffn.receptance.weight", "blocks.18.ffn.value.weight", "blocks.19.ln1.weight", "blocks.19.ln1.bias", "blocks.19.ln2.weight", "blocks.19.ln2.bias", "blocks.19.att.time_mix_k", "blocks.19.att.time_mix_v", "blocks.19.att.time_mix_r", "blocks.19.att.time_mix_g", "blocks.19.att.time_decay", "blocks.19.att.time_faaaa", "blocks.19.att.receptance.weight", "blocks.19.att.key.weight", "blocks.19.att.value.weight", "blocks.19.att.output.weight", "blocks.19.att.gate.weight", "blocks.19.att.ln_x.weight", "blocks.19.att.ln_x.bias", "blocks.19.ffn.time_mix_k", "blocks.19.ffn.time_mix_r", "blocks.19.ffn.key.weight", "blocks.19.ffn.receptance.weight", "blocks.19.ffn.value.weight", "blocks.20.ln1.weight", "blocks.20.ln1.bias", "blocks.20.ln2.weight", "blocks.20.ln2.bias", "blocks.20.att.time_mix_k", "blocks.20.att.time_mix_v", "blocks.20.att.time_mix_r", "blocks.20.att.time_mix_g", "blocks.20.att.time_decay", "blocks.20.att.time_faaaa", "blocks.20.att.receptance.weight", "blocks.20.att.key.weight", "blocks.20.att.value.weight", "blocks.20.att.output.weight", "blocks.20.att.gate.weight", "blocks.20.att.ln_x.weight", "blocks.20.att.ln_x.bias", "blocks.20.ffn.time_mix_k", "blocks.20.ffn.time_mix_r", "blocks.20.ffn.key.weight", "blocks.20.ffn.receptance.weight", "blocks.20.ffn.value.weight", "blocks.21.ln1.weight", "blocks.21.ln1.bias", "blocks.21.ln2.weight", "blocks.21.ln2.bias", "blocks.21.att.time_mix_k", "blocks.21.att.time_mix_v", "blocks.21.att.time_mix_r", "blocks.21.att.time_mix_g", "blocks.21.att.time_decay", "blocks.21.att.time_faaaa", "blocks.21.att.receptance.weight", "blocks.21.att.key.weight", "blocks.21.att.value.weight", 
"blocks.21.att.output.weight", "blocks.21.att.gate.weight", "blocks.21.att.ln_x.weight", "blocks.21.att.ln_x.bias", "blocks.21.ffn.time_mix_k", "blocks.21.ffn.time_mix_r", "blocks.21.ffn.key.weight", "blocks.21.ffn.receptance.weight", "blocks.21.ffn.value.weight", "blocks.22.ln1.weight", "blocks.22.ln1.bias", "blocks.22.ln2.weight", "blocks.22.ln2.bias", "blocks.22.att.time_mix_k", "blocks.22.att.time_mix_v", "blocks.22.att.time_mix_r", "blocks.22.att.time_mix_g", "blocks.22.att.time_decay", "blocks.22.att.time_faaaa", "blocks.22.att.receptance.weight", "blocks.22.att.key.weight", "blocks.22.att.value.weight", "blocks.22.att.output.weight", "blocks.22.att.gate.weight", "blocks.22.att.ln_x.weight", "blocks.22.att.ln_x.bias", "blocks.22.ffn.time_mix_k", "blocks.22.ffn.time_mix_r", "blocks.22.ffn.key.weight", "blocks.22.ffn.receptance.weight", "blocks.22.ffn.value.weight", "blocks.23.ln1.weight", "blocks.23.ln1.bias", "blocks.23.ln2.weight", "blocks.23.ln2.bias", "blocks.23.att.time_mix_k", "blocks.23.att.time_mix_v", "blocks.23.att.time_mix_r", "blocks.23.att.time_mix_g", "blocks.23.att.time_decay", "blocks.23.att.time_faaaa", "blocks.23.att.receptance.weight", "blocks.23.att.key.weight", "blocks.23.att.value.weight", "blocks.23.att.output.weight", "blocks.23.att.gate.weight", "blocks.23.att.ln_x.weight", "blocks.23.att.ln_x.bias", "blocks.23.ffn.time_mix_k", "blocks.23.ffn.time_mix_r", "blocks.23.ffn.key.weight", "blocks.23.ffn.receptance.weight", "blocks.23.ffn.value.weight". 
        size mismatch for emb.weight: copying a param with shape torch.Size([65536, 1024]) from checkpoint, the shape in current model is torch.Size([65536, 768]).
        size mismatch for blocks.0.ln1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.0.ln1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.0.ln2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.0.ln2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.0.ln0.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.0.ln0.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.0.att.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.0.att.time_mix_v: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.0.att.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.0.att.time_mix_g: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.0.att.time_decay: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.0.att.time_faaaa: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.0.att.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.0.att.key.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.0.att.value.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.0.att.output.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.0.att.gate.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.0.att.ln_x.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.0.att.ln_x.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.0.ffn.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.0.ffn.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.0.ffn.key.weight: copying a param with shape torch.Size([3584, 1024]) from checkpoint, the shape in current model is torch.Size([2688, 768]).
        size mismatch for blocks.0.ffn.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.0.ffn.value.weight: copying a param with shape torch.Size([1024, 3584]) from checkpoint, the shape in current model is torch.Size([768, 2688]).
        size mismatch for blocks.1.ln1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.1.ln1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.1.ln2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.1.ln2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.1.att.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.1.att.time_mix_v: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.1.att.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.1.att.time_mix_g: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.1.att.time_decay: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.1.att.time_faaaa: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.1.att.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.1.att.key.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.1.att.value.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.1.att.output.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.1.att.gate.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.1.att.ln_x.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.1.att.ln_x.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.1.ffn.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.1.ffn.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.1.ffn.key.weight: copying a param with shape torch.Size([3584, 1024]) from checkpoint, the shape in current model is torch.Size([2688, 768]).
        size mismatch for blocks.1.ffn.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.1.ffn.value.weight: copying a param with shape torch.Size([1024, 3584]) from checkpoint, the shape in current model is torch.Size([768, 2688]).
        size mismatch for blocks.2.ln1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.2.ln1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.2.ln2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.2.ln2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.2.att.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.2.att.time_mix_v: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.2.att.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.2.att.time_mix_g: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.2.att.time_decay: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.2.att.time_faaaa: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.2.att.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.2.att.key.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.2.att.value.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.2.att.output.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.2.att.gate.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.2.att.ln_x.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.2.att.ln_x.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.2.ffn.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.2.ffn.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.2.ffn.key.weight: copying a param with shape torch.Size([3584, 1024]) from checkpoint, the shape in current model is torch.Size([2688, 768]).
        size mismatch for blocks.2.ffn.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.2.ffn.value.weight: copying a param with shape torch.Size([1024, 3584]) from checkpoint, the shape in current model is torch.Size([768, 2688]).
        size mismatch for blocks.3.ln1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.3.ln1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.3.ln2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.3.ln2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.3.att.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.3.att.time_mix_v: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.3.att.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.3.att.time_mix_g: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.3.att.time_decay: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.3.att.time_faaaa: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.3.att.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.3.att.key.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.3.att.value.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.3.att.output.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.3.att.gate.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.3.att.ln_x.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.3.att.ln_x.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.3.ffn.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.3.ffn.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.3.ffn.key.weight: copying a param with shape torch.Size([3584, 1024]) from checkpoint, the shape in current model is torch.Size([2688, 768]).
        size mismatch for blocks.3.ffn.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.3.ffn.value.weight: copying a param with shape torch.Size([1024, 3584]) from checkpoint, the shape in current model is torch.Size([768, 2688]).
        size mismatch for blocks.4.ln1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.4.ln1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.4.ln2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.4.ln2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.4.att.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.4.att.time_mix_v: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.4.att.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.4.att.time_mix_g: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.4.att.time_decay: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.4.att.time_faaaa: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.4.att.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.4.att.key.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.4.att.value.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.4.att.output.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.4.att.gate.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.4.att.ln_x.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.4.att.ln_x.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.4.ffn.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.4.ffn.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.4.ffn.key.weight: copying a param with shape torch.Size([3584, 1024]) from checkpoint, the shape in current model is torch.Size([2688, 768]).
        size mismatch for blocks.4.ffn.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.4.ffn.value.weight: copying a param with shape torch.Size([1024, 3584]) from checkpoint, the shape in current model is torch.Size([768, 2688]).
        size mismatch for blocks.5.ln1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.5.ln1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.5.ln2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.5.ln2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.5.att.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.5.att.time_mix_v: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.5.att.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.5.att.time_mix_g: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.5.att.time_decay: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.5.att.time_faaaa: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.5.att.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.5.att.key.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.5.att.value.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.5.att.output.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.5.att.gate.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.5.att.ln_x.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.5.att.ln_x.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.5.ffn.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.5.ffn.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.5.ffn.key.weight: copying a param with shape torch.Size([3584, 1024]) from checkpoint, the shape in current model is torch.Size([2688, 768]).
        size mismatch for blocks.5.ffn.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.5.ffn.value.weight: copying a param with shape torch.Size([1024, 3584]) from checkpoint, the shape in current model is torch.Size([768, 2688]).
        size mismatch for blocks.6.ln1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.6.ln1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.6.ln2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.6.ln2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.6.att.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.6.att.time_mix_v: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.6.att.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.6.att.time_mix_g: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.6.att.time_decay: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.6.att.time_faaaa: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.6.att.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.6.att.key.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.6.att.value.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.6.att.output.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.6.att.gate.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.6.att.ln_x.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.6.att.ln_x.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.6.ffn.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.6.ffn.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.6.ffn.key.weight: copying a param with shape torch.Size([3584, 1024]) from checkpoint, the shape in current model is torch.Size([2688, 768]).
        size mismatch for blocks.6.ffn.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.6.ffn.value.weight: copying a param with shape torch.Size([1024, 3584]) from checkpoint, the shape in current model is torch.Size([768, 2688]).
        size mismatch for blocks.7.ln1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.7.ln1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.7.ln2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.7.ln2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.7.att.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.7.att.time_mix_v: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.7.att.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.7.att.time_mix_g: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.7.att.time_decay: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.7.att.time_faaaa: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.7.att.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.7.att.key.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.7.att.value.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.7.att.output.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.7.att.gate.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.7.att.ln_x.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.7.att.ln_x.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.7.ffn.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.7.ffn.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.7.ffn.key.weight: copying a param with shape torch.Size([3584, 1024]) from checkpoint, the shape in current model is torch.Size([2688, 768]).
        size mismatch for blocks.7.ffn.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.7.ffn.value.weight: copying a param with shape torch.Size([1024, 3584]) from checkpoint, the shape in current model is torch.Size([768, 2688]).
        size mismatch for blocks.8.ln1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.8.ln1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.8.ln2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.8.ln2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.8.att.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.8.att.time_mix_v: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.8.att.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.8.att.time_mix_g: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.8.att.time_decay: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.8.att.time_faaaa: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.8.att.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.8.att.key.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.8.att.value.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.8.att.output.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.8.att.gate.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.8.att.ln_x.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.8.att.ln_x.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.8.ffn.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.8.ffn.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.8.ffn.key.weight: copying a param with shape torch.Size([3584, 1024]) from checkpoint, the shape in current model is torch.Size([2688, 768]).
        size mismatch for blocks.8.ffn.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.8.ffn.value.weight: copying a param with shape torch.Size([1024, 3584]) from checkpoint, the shape in current model is torch.Size([768, 2688]).
        size mismatch for blocks.9.ln1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.9.ln1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.9.ln2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.9.ln2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.9.att.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.9.att.time_mix_v: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.9.att.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.9.att.time_mix_g: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.9.att.time_decay: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.9.att.time_faaaa: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.9.att.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.9.att.key.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.9.att.value.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.9.att.output.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.9.att.gate.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.9.att.ln_x.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.9.att.ln_x.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.9.ffn.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.9.ffn.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.9.ffn.key.weight: copying a param with shape torch.Size([3584, 1024]) from checkpoint, the shape in current model is torch.Size([2688, 768]).
        size mismatch for blocks.9.ffn.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.9.ffn.value.weight: copying a param with shape torch.Size([1024, 3584]) from checkpoint, the shape in current model is torch.Size([768, 2688]).
        size mismatch for blocks.10.ln1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.10.ln1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.10.ln2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.10.ln2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.10.att.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.10.att.time_mix_v: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.10.att.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.10.att.time_mix_g: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.10.att.time_decay: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.10.att.time_faaaa: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.10.att.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.10.att.key.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.10.att.value.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.10.att.output.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.10.att.gate.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.10.att.ln_x.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.10.att.ln_x.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.10.ffn.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.10.ffn.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.10.ffn.key.weight: copying a param with shape torch.Size([3584, 1024]) from checkpoint, the shape in current model is torch.Size([2688, 768]).
        size mismatch for blocks.10.ffn.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.10.ffn.value.weight: copying a param with shape torch.Size([1024, 3584]) from checkpoint, the shape in current model is torch.Size([768, 2688]).
        size mismatch for blocks.11.ln1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.11.ln1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.11.ln2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.11.ln2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.11.att.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.11.att.time_mix_v: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.11.att.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.11.att.time_mix_g: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.11.att.time_decay: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.11.att.time_faaaa: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.11.att.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.11.att.key.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.11.att.value.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.11.att.output.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.11.att.gate.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.11.att.ln_x.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.11.att.ln_x.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.11.ffn.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.11.ffn.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.11.ffn.key.weight: copying a param with shape torch.Size([3584, 1024]) from checkpoint, the shape in current model is torch.Size([2688, 768]).
        size mismatch for blocks.11.ffn.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.11.ffn.value.weight: copying a param with shape torch.Size([1024, 3584]) from checkpoint, the shape in current model is torch.Size([768, 2688]).
        size mismatch for ln_out.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for ln_out.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for head.weight: copying a param with shape torch.Size([65536, 1024]) from checkpoint, the shape in current model is torch.Size([65536, 768]).
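Every mismatch above has the same cause. The checkpoint tensors are 1024 wide (16 heads of size 64, FFN width 3584), while the script builds a 768-wide model (12 heads, FFN width 2688). The hyperparameters can be read from the checkpoint's tensor shapes before launching. Below is a minimal sketch that assumes the key names seen in this log (`blocks.<i>.att.time_decay`, `blocks.<i>.ffn.key.weight`, `head.weight`). `infer_rwkv5_dims` is a hypothetical helper, not part of RWKV-LM:

```python
import re


def infer_rwkv5_dims(shapes):
    """Infer RWKV-5 hyperparameters from a {param_name: shape_tuple} mapping.

    Assumes the parameter names seen in the size-mismatch log
    (hypothetical helper, not an RWKV-LM API).
    """
    # n_layer = highest block index + 1
    layer_ids = set()
    for name in shapes:
        m = re.match(r"blocks\.(\d+)\.", name)
        if m:
            layer_ids.add(int(m.group(1)))
    n_layer = max(layer_ids) + 1

    vocab_size, n_embd = shapes["head.weight"]             # (vocab, n_embd)
    n_head, head_size = shapes["blocks.0.att.time_decay"]  # (n_head, head_size)
    dim_ffn = shapes["blocks.0.ffn.key.weight"][0]         # (dim_ffn, n_embd)
    return {
        "n_layer": n_layer,
        "n_embd": n_embd,
        "vocab_size": vocab_size,
        "n_head": n_head,
        "head_size_a": head_size,
        "dim_ffn": dim_ffn,
    }
```

To build `shapes` from a real checkpoint, assuming (as the `load_state_dict` errors suggest) that the `.pth` file is a plain state dict: `shapes = {k: tuple(v.shape) for k, v in torch.load(path, map_location="cpu").items()}`.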

@BlinkDL
Owner

BlinkDL commented Jan 16, 2024

For 0.4B fine-tuning, set:
N_LAYER="24"
N_EMBD="1024"
LR_INIT="2e-5"
LR_FINAL="2e-5"
GRAD_CP="1"

@Ethan-Chen-plus
Author

Thanks for the help! But why set LR_INIT equal to LR_FINAL?
Another question: if I set GRAD_CP=0, memory usage goes up and I get an OOM error:
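The thread does not say why the two values should be equal. One way to see the effect: suppose the schedule interpolates exponentially from LR_INIT to LR_FINAL after warmup. That is an assumption here, so check the trainer code for the exact formula. Under it, equal endpoints give a constant learning rate (2e-5 with the suggested settings) for the whole run. `lr_at` is a hypothetical helper:

```python
import math


def lr_at(progress, lr_init, lr_final):
    """Learning rate after warmup for training progress in [0, 1].

    Assumption: exponential interpolation from lr_init to lr_final.
    When lr_init == lr_final, log(1) == 0, so the rate never changes.
    """
    progress = min(1.0, max(0.0, progress))  # clamp to [0, 1]
    return lr_init * math.exp(math.log(lr_final / lr_init) * progress)
```

With the original script's values (6e-4 to 6e-5), the rate would pass through the geometric mean of the two at the halfway point. With 2e-5 to 2e-5, it stays flat.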

INFO:pytorch_lightning.strategies.deepspeed:initializing deepspeed distributed: GLOBAL_RANK: 0, MEMBER: 1/1
INFO:torch.distributed.distributed_c10d:Added key: store_based_barrier_key:1 to store for rank: 0
INFO:torch.distributed.distributed_c10d:Rank 0: Completed store-based barrier for key:store_based_barrier_key:1 with 1 nodes.
INFO:pytorch_lightning.utilities.rank_zero:Enabling DeepSpeed BF16.
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [2]
Using /home/ubuntu/.cache/torch_extensions/py310_cu117 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/ubuntu/.cache/torch_extensions/py310_cu117/fused_adam/build.ninja...
Building extension module fused_adam...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module fused_adam...
Time to load fused_adam op: 0.06461381912231445 seconds
INFO:torch.distributed.distributed_c10d:Added key: store_based_barrier_key:2 to store for rank: 0
INFO:torch.distributed.distributed_c10d:Rank 0: Completed store-based barrier for key:store_based_barrier_key:2 with 1 nodes.
INFO:pytorch_lightning.callbacks.model_summary:
  | Name   | Type       | Params
--------------------------------------
0 | emb    | Embedding  | 67.1 M
1 | blocks | ModuleList | 327 M 
2 | ln_out | LayerNorm  | 2.0 K 
3 | head   | Linear     | 67.1 M
--------------------------------------
461 M     Trainable params
0         Non-trainable params
461 M     Total params
1,846.886 Total estimated model params size (MB)
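The counts in this summary can be reproduced from the tensor shapes in the earlier size-mismatch log. The sketch below makes two assumptions. First, the FFN width follows `int(n_embd * 3.5) // 32 * 32`, which gives the 3584 and 2688 seen in the log. Second, block 0 alone has an extra `ln0` LayerNorm, which is needed for the total to match the reported 1,846.886 MB. `rwkv5_param_count` is a hypothetical helper:

```python
def rwkv5_param_count(n_layer, n_embd, vocab_size, head_size=64):
    """Count RWKV-5 parameters from the tensor shapes in the mismatch log."""
    n_head = n_embd // head_size
    dim_ffn = int(n_embd * 3.5) // 32 * 32  # assumption: 3584 for 1024, 2688 for 768
    ln = 2 * n_embd                          # LayerNorm/GroupNorm weight + bias
    att = (
        4 * n_embd                  # time_mix_k/v/r/g
        + 2 * n_head * head_size    # time_decay, time_faaaa
        + 5 * n_embd * n_embd       # receptance, key, value, output, gate
        + ln                        # ln_x
    )
    ffn = (
        2 * n_embd                  # time_mix_k/r
        + 2 * dim_ffn * n_embd      # key, value
        + n_embd * n_embd           # receptance
    )
    per_block = 2 * ln + att + ffn  # ln1 + ln2 + attention + FFN
    emb = vocab_size * n_embd
    blocks = n_layer * per_block + ln  # + ln0 (assumed to exist in block 0 only)
    head = vocab_size * n_embd         # Linear without bias
    return {"emb": emb, "blocks": blocks, "ln_out": ln, "head": head,
            "total": emb + blocks + ln + head}
```

`rwkv5_param_count(24, 1024, 65536)` gives 67,108,864 for `emb` and for `head`, 327,501,824 for `blocks`, and 461,721,600 in total. At 4 bytes per parameter, that total is the 1,846.886 MB shown above.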
Epoch 0:   0%|                                           | 0/2520 [00:00<?, ?it/s]
{'zero_allow_untested_optimizer': True, 'zero_optimization': {'stage': 2, 'contiguous_gradients': True, 'overlap_comm': True, 'allgather_partitions': True, 'reduce_scatter': True, 'allgather_bucket_size': 200000000, 'reduce_bucket_size': 200000000, 'sub_group_size': 1000000000000}, 'activation_checkpointing': {'partition_activations': False, 'cpu_checkpointing': False, 'contiguous_memory_optimization': False, 'synchronize_checkpoint_boundary': False}, 'aio': {'block_size': 1048576, 'queue_depth': 8, 'single_submit': False, 'overlap_events': True, 'thread_count': 1}, 'gradient_accumulation_steps': 1, 'train_micro_batch_size_per_gpu': 16, 'gradient_clipping': 1.0, 'bf16': {'enabled': True}}

Login to wandb...
wandb: Currently logged in as: keweichen (aicolab). Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.16.2
wandb: Run data is saved locally in /home/ubuntu/MedicalGPT/rwkv/RWKV-LM/RWKV-v5/wandb/run-20240116_132216-ck3u9wok
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run 65536 ctx512 L24 D1024 2024-01-16-13-22-02
wandb: ⭐️ View project at https://wandb.ai/aicolab/RWKV-5-Test
wandb: 🚀 View run at https://wandb.ai/aicolab/RWKV-5-Test/runs/ck3u9wok
Traceback (most recent call last):
  File "/home/ubuntu/MedicalGPT/rwkv/RWKV-LM/RWKV-v5/train.py", line 312, in <module>
    trainer.fit(model, data_loader)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 608, in fit
    call._call_and_handle_interrupt(
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 36, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 88, in launch
    return function(*args, **kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 650, in _fit_impl
    self._run(model, ckpt_path=self.ckpt_path)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1112, in _run
    results = self._run_stage()
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1191, in _run_stage
    self._run_train()
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1214, in _run_train
    self.fit_loop.run()
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 267, in advance
    self._outputs = self.epoch_loop.run(self._data_fetcher)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 213, in advance
    batch_output = self.batch_loop.run(kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 88, in advance
    outputs = self.optimizer_loop.run(optimizers, kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 202, in advance
    result = self._run_optimization(kwargs, self._optimizers[self.optim_progress.optimizer_position])
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 249, in _run_optimization
    self._optimizer_step(optimizer, opt_idx, kwargs.get("batch_idx", 0), closure)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 370, in _optimizer_step
    self.trainer._call_lightning_module_hook(
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1356, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/core/module.py", line 1754, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/core/optimizer.py", line 169, in step
    step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/strategies/ddp.py", line 280, in optimizer_step
    optimizer_output = super().optimizer_step(optimizer, opt_idx, closure, model, **kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 234, in optimizer_step
    return self.precision_plugin.optimizer_step(
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/deepspeed.py", line 132, in optimizer_step
    closure_result = closure()
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 149, in __call__
    self._result = self.closure(*args, **kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 144, in closure
    self._backward_fn(step_output.closure_loss)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 305, in backward_fn
    self.trainer._call_strategy_hook("backward", loss, optimizer, opt_idx)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1494, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 207, in backward
    self.precision_plugin.backward(closure_loss, self.lightning_module, optimizer, optimizer_idx, *args, **kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/deepspeed.py", line 118, in backward
    deepspeed_engine.backward(tensor, *args, **kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1955, in backward
    self.optimizer.backward(loss, retain_graph=retain_graph)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 2019, in backward
    self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/deepspeed/runtime/fp16/loss_scaler.py", line 63, in backward
    scaled_loss.backward(retain_graph=retain_graph)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/torch/_tensor.py", line 488, in backward
    torch.autograd.backward(
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/torch/autograd/__init__.py", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1024.00 MiB (GPU 0; 21.99 GiB total capacity; 20.70 GiB already allocated; 287.00 MiB free; 20.88 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
wandb: 🚀 View run 65536 ctx512 L24 D1024 2024-01-16-13-22-02 at: https://wandb.ai/aicolab/RWKV-5-Test/runs/ck3u9wok
wandb: ️⚡ View job at https://wandb.ai/aicolab/RWKV-5-Test/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEyOTkxNjc0MQ==/version_details/v1
wandb: Synced 5 W&B file(s), 0 media file(s), 2 artifact file(s) and 0 other file(s)
wandb: Find logs at: ./wandb/run-20240116_132216-ck3u9wok/logs
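The failed 1024.00 MiB request exactly matches the size of one bf16 tensor of shape micro_bsz × ctx_len × vocab_size for this run: 16 × 512 × 65536 × 2 bytes = 2^30 bytes. This suggests, though the traceback does not prove it, that the output logits or their gradient tipped the GPU over. It also means this term grows with both micro_bsz and ctx_len. A quick check, using the hypothetical helper `logits_mib`:

```python
def logits_mib(micro_bsz, ctx_len, vocab_size=65536, bytes_per_elem=2):
    """Size in MiB of one (micro_bsz, ctx_len, vocab_size) tensor (bf16 by default)."""
    return micro_bsz * ctx_len * vocab_size * bytes_per_elem / 2**20


print(logits_mib(16, 512))  # 1024.0, the size of the failed allocation
print(logits_mib(1, 4096))  # 512.0 per sequence at ctx_len 4096
```

At ctx_len 4096, a single sequence already needs 512 MiB per logits-sized tensor, so a longer context generally forces a smaller micro_bsz.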

I currently have four A10 cards (22 GB each). How can I make the best use of their compute and memory?

@BlinkDL
Owner

BlinkDL commented Jan 16, 2024

Set --devices 4 to use all 4 GPUs, and make all four visible with:

CUDA_VISIBLE_DEVICES=0,1,2,3
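With `--strategy deepspeed_stage_2`, ZeRO stage 2 partitions gradients and optimizer states across the data-parallel GPUs, so adding devices shrinks that part of per-GPU memory. Below is a back-of-the-envelope sketch using the mixed-precision Adam accounting from the ZeRO paper: 16-bit weights and gradients, plus 12 bytes/param of fp32 optimizer state. DeepSpeed's actual bf16 buffers may differ. `zero2_model_state_gib` is a hypothetical helper:

```python
def zero2_model_state_gib(n_params, n_gpus, k_optimizer=12):
    """Per-GPU GiB for weights, gradients and Adam state under ZeRO stage 2.

    Accounting follows the ZeRO paper (an approximation of DeepSpeed's buffers):
    - 2 bytes/param of 16-bit weights, replicated on every GPU
    - 2 bytes/param of 16-bit gradients, partitioned across GPUs
    - k_optimizer bytes/param of fp32 master weights, momentum and variance,
      partitioned across GPUs
    Activations are NOT included.
    """
    bytes_per_gpu = 2 * n_params + (2 + k_optimizer) * n_params / n_gpus
    return bytes_per_gpu / 2**30


n_params = 461_721_600  # "461 M Trainable params" from the model summary
print(zero2_model_state_gib(n_params, 1))  # about 6.9 GiB
print(zero2_model_state_gib(n_params, 4))  # about 2.4 GiB per GPU
```

Activations are not sharded this way. Each GPU still processes its own micro-batch, so ctx_len and micro_bsz dominate the remaining memory.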

@Ethan-Chen-plus
Author

Ethan-Chen-plus commented Jan 19, 2024

Thanks again @BlinkDL

I have another question. I'm currently doing full fine-tuning of a 0.4B-parameter RWKV-5 model with a context length (ctx_len) of 1024, and it almost maxes out the memory on all four of my A10 GPUs. However, Llama-2-7B can be fully fine-tuned on four A10 cards with a context length of 4096. Is there a way to fully train my v5 model at ctx_len 4096 using model parallelism across the four GPUs?

@PicoCreator
Contributor

Check your gradient checkpointing flag. Disabling it gives a speed boost at the cost of much higher VRAM usage (Llama training setups typically have it enabled).
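Here is a rough model of why the flag matters so much. Without checkpointing, every block keeps its intermediate activations for the backward pass. With checkpointing, only each block's input is kept, and one block's activations are rebuilt at a time during backward. Both amounts grow linearly with micro_bsz × ctx_len. The sketch assumes a hypothetical constant `tensors_per_layer`, meaning how many (micro_bsz, ctx_len, n_embd) tensors one block saves. The real number depends on the RWKV-5 kernels and is not given in this thread, and the logits term is ignored. `saved_activation_gib` is a hypothetical helper:

```python
def saved_activation_gib(n_layer, n_embd, micro_bsz, ctx_len,
                         tensors_per_layer, grad_cp, bytes_per_elem=2):
    """Rough size of activations kept for backward (logits not included).

    tensors_per_layer is a HYPOTHETICAL constant: the number of
    (micro_bsz, ctx_len, n_embd)-shaped tensors one block saves.
    """
    one_tensor = micro_bsz * ctx_len * n_embd * bytes_per_elem
    if grad_cp:
        # keep each block's input, plus one block's activations during recompute
        saved = n_layer * one_tensor + tensors_per_layer * one_tensor
    else:
        # every block keeps all of its intermediate tensors
        saved = n_layer * tensors_per_layer * one_tensor
    return saved / 2**30
```

With `tensors_per_layer=20` (an arbitrary illustrative value) and 24 layers, checkpointing cuts saved activations by roughly 24·20 / (24 + 20) ≈ 11×. Quadrupling ctx_len from 1024 to 4096 quadruples both figures.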

@BlinkDL
Owner

BlinkDL commented Jan 26, 2024

@Ethan-Chen-plus set GRAD_CP=1
