How can I pretrain llama3 with a 128k sequence length on 8*A100? #683

Open
1518630367 opened this issue May 13, 2024 · 2 comments

@1518630367

According to the chart in the README this should be trainable, but I keep hitting OOM.

@1518630367 (Author) commented May 13, 2024

import torch
from datasets import load_dataset
from mmengine.config import read_base
from mmengine.dataset import DefaultSampler
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR
from peft import LoraConfig
from torch.optim import AdamW
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig)

from xtuner.dataset import process_hf_dataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.engine.hooks import DatasetInfoHook
from xtuner.engine.runner import TrainLoop
from xtuner.model import SupervisedFinetune

with read_base():
    from .map_fn import pretrain_map_fn as dataset_map_fn

pretrained_model_name_or_path = '/opt/218/models/Meta-Llama-3-8B-Instruct-continue_pre'

data_path = './train_128k_1000.jsonl'
max_length = 128000
pack_to_max_length = True

batch_size = 1 # per_device
accumulative_counts = 1
dataloader_num_workers = 10
max_epochs = 3
optim_type = AdamW
lr = 2e-4
betas = (0.9, 0.999)
weight_decay = 1
max_norm = 1 # grad clip

save_steps = 100
save_total_limit = 3 # Maximum checkpoints to keep (-1 means unlimited)

# Evaluate the generation performance during the training

evaluation_freq = 1000
SYSTEM = ''
evaluation_inputs = [
    '请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai'
]

tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    trust_remote_code=True,
    padding_side='right')

model = dict(
    type=SupervisedFinetune,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        quantization_config=dict(
            type=BitsAndBytesConfig,
            load_in_4bit=True,
            load_in_8bit=False,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4')),
    lora=dict(
        type=LoraConfig,
        r=64,
        lora_alpha=16,
        lora_dropout=0.1,
        bias='none',
        task_type='CAUSAL_LM'))

train_dataset = dict(
    type=process_hf_dataset,
    dataset=dict(
        type=load_dataset, path='json', data_files=dict(train=data_path)),
    tokenizer=tokenizer,
    max_length=max_length,
    dataset_map_fn=dataset_map_fn,
    template_map_fn=None,
    remove_unused_columns=True,
    shuffle_before_pack=True,
    pack_to_max_length=pack_to_max_length)

train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=train_dataset,
    sampler=dict(type=DefaultSampler, shuffle=True),
    collate_fn=dict(type=default_collate_fn))

optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16')
param_scheduler = dict(
    type=CosineAnnealingLR,
    eta_min=0.0,
    by_epoch=True,
    end=max_epochs,
    convert_to_iter_based=True)

train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)

custom_hooks = [dict(type=DatasetInfoHook, tokenizer=tokenizer)]

default_hooks = dict(
    timer=dict(type=IterTimerHook),
    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
    param_scheduler=dict(type=ParamSchedulerHook),
    checkpoint=dict(
        type=CheckpointHook,
        by_epoch=False,
        interval=save_steps,
        max_keep_ckpts=save_total_limit),
    sampler_seed=dict(type=DistSamplerSeedHook),
)

env_cfg = dict(
    cudnn_benchmark=False,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl'),
)

visualizer = None

log_level = 'INFO'

load_from = None

resume = False

randomness = dict(seed=None, deterministic=False)

log_processor = dict(by_epoch=False)

@pppppM (Collaborator) commented May 14, 2024

Training extremely long sequences requires sequence parallelism; see:

https://xtuner.readthedocs.io/zh-cn/docs/acceleration/train_extreme_long_sequence.html#id7
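
For reference, here is a minimal sketch of the config changes the linked page describes, assuming an xtuner release that provides xtuner.parallel.sequence.SequenceParallelSampler (names and defaults may differ between versions, so verify against the doc for your installation):

# Sequence parallelism (sketch only; verify against the linked doc)
from xtuner.parallel.sequence import SequenceParallelSampler  # assumed import path

# Shard each 128k-token sequence across GPUs; the value must evenly divide
# the number of GPUs (8 x A100 here).
sequence_parallel_size = 8

# To keep the effective global batch size unchanged, gradient accumulation is
# typically scaled by the sequence-parallel degree (1 * 8 in this example).
accumulative_counts = 1 * sequence_parallel_size

# Swap DefaultSampler for the sequence-parallel sampler so that all ranks in a
# sequence-parallel group receive the same data shard (arguments assumed to
# mirror DefaultSampler).
train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=train_dataset,
    sampler=dict(type=SequenceParallelSampler, shuffle=True),
    collate_fn=dict(type=default_collate_fn))

With this, the 128k tokens and their attention activations are split across the 8 GPUs instead of being held in full on every rank.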
