
ValueError: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly #1539

Open
apachemycat opened this issue May 3, 2024 · 0 comments
Labels: bug (Something isn't working)

Environment

PyTorch 2.3. If use_reentrant is not passed explicitly, an exception is now raised.
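For reference, this is the call that PyTorch now rejects (per this report) when use_reentrant is omitted. A minimal sketch; the toy module and tensor shapes are illustrative, not taken from XTuner:

import torch
from torch.utils.checkpoint import checkpoint

def block(x):
    # stand-in for a transformer layer
    return torch.relu(x @ x.T)

x = torch.randn(4, 4, requires_grad=True)

# Omitting use_reentrant triggers the complaint below on this setup.
# Passing it explicitly (False is the variant PyTorch recommends) avoids it:
out = checkpoint(block, x, use_reentrant=False)
out.sum().backward()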


Reproduces the problem - code sample

tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    trust_remote_code=True,
    padding_side='right')

model = dict(
    type=SupervisedFinetune,
    use_varlen_attn=use_varlen_attn,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        quantization_config=dict(
            type=BitsAndBytesConfig,
            load_in_4bit=True,
            load_in_8bit=False,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4')),
    lora=dict(
        type=LoraConfig,
        r=64,
        lora_alpha=16,
        lora_dropout=0.1,
        bias='none',
        task_type='CAUSAL_LM'))

training_args = dict(
    type=TrainingArguments,
    gradient_checkpointing=True,  # reduces memory at a slight decrease in speed
    gradient_checkpointing_kwargs={"use_reentrant": False}  # see the note at the end of this config
)
#######################################################################
#                      PART 3  Dataset & Dataloader                   #
#######################################################################
alpaca_en = dict(
    type=process_hf_dataset,
    dataset=dict(type=load_dataset, path='json', data_files=data_files),
    tokenizer=tokenizer,
    max_length=max_length,
    dataset_map_fn=alpaca_map_fn,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    remove_unused_columns=True,
    shuffle_before_pack=True,
    pack_to_max_length=pack_to_max_length,
    use_varlen_attn=use_varlen_attn)

sampler = SequenceParallelSampler \
    if sequence_parallel_size > 1 else DefaultSampler

train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=alpaca_en,
    sampler=dict(type=sampler, shuffle=True),
    collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn))

#######################################################################
#                     PART 4  Scheduler & Optimizer                   #
#######################################################################

# optimizer

optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16')

# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501

auto_scale_lr = dict(base_batch_size=4, enable=True)
param_scheduler = [
    dict(
        type=LinearLR,
        start_factor=1e-5,
        by_epoch=True,
        begin=0,
        end=10,
        convert_to_iter_based=True),
    dict(
        type=CosineAnnealingLR,
        eta_min=0.0,
        by_epoch=True,
        convert_to_iter_based=True)
]

# train, val, test setting

train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)

#######################################################################
#                           PART 5  Runtime                           #
#######################################################################

# Log the dialogue periodically during the training process, optional
# custom_hooks = [
#     dict(type=DatasetInfoHook, tokenizer=tokenizer),
#     dict(
#         type=EvaluateChatHook,
#         tokenizer=tokenizer,
#         every_n_iters=evaluation_freq,
#         evaluation_inputs=evaluation_inputs,
#         system=SYSTEM,
#         prompt_template=prompt_template)
# ]
custom_hooks = []  # defined empty so the append below works with the list disabled

if use_varlen_attn:
    custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)]

# configure default hooks

default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print log every 10 iterations.
    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # save checkpoint per save_steps.
    checkpoint=dict(
        type=CheckpointHook,
        by_epoch=False,
        interval=save_steps,
        max_keep_ckpts=save_total_limit),
    # set sampler seed in distributed environment.
    sampler_seed=dict(type=DistSamplerSeedHook),
)

# configure environment

env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

# set visualizer

visualizer = dict(type='Visualizer', vis_backends=[dict(type='TensorboardVisBackend')])

# set log level

log_level = 'DEBUG'

# load from which checkpoint

load_from = None

# whether to resume training from the loaded checkpoint

resume = True

# Defaults to use random seed and disable deterministic

randomness = dict(seed=None, deterministic=False)

# set log processor

log_processor = dict(by_epoch=False)
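A note on the two gradient-checkpointing settings in the training_args block above: gradient_checkpointing_kwargs only takes effect where something actually forwards it to the model. The Hugging Face Trainer does this itself, but this config is executed by the XTuner/mmengine runner, so whether SupervisedFinetune forwards the kwargs is exactly what this issue is about. For reference, the direct transformers call (available since v4.35) looks like the sketch below; the llm variable is illustrative, since in the config above the model is built from the model dict:

from transformers import AutoModelForCausalLM

llm = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)
# gradient_checkpointing_enable forwards these kwargs to
# torch.utils.checkpoint.checkpoint for every checkpointed layer.
llm.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={'use_reentrant': False})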

Reproduces the problem - command or script

NCCL_DEBUG=INFO LOGLEVEL=DEBUG NPROC_PER_NODE=1 torchrun /usr/local/lib/python3.10/dist-packages/xtuner/tools/train.py --launcher pytorch --work-dir ${work_dir} config.py

Reproduces the problem - error message

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/xtuner/tools/train.py", line 360, in <module>
    main()
  File "/usr/local/lib/python3.10/dist-packages/xtuner/tools/train.py", line 356, in main
    runner.train()
  File "/usr/local/lib/python3.10/dist-packages/mmengine/runner/runner.py", line 1777, in train
    model = self.train_loop.run()  # type: ignore
  File "/usr/local/lib/python3.10/dist-packages/mmengine/runner/loops.py", line 287, in run
    self.run_iter(data_batch)
  File "/usr/local/lib/python3.10/dist-packages/mmengine/runner/loops.py", line 311, in run_iter
    outputs = self.runner.model.train_step(
  File "/usr/local/lib/python3.10/dist-packages/mmengine/model/wrappers/distributed.py", line 121, in train_step
    losses = self._run_forward(data, mode='loss')
  File "/usr/local/lib/python3.10/dist-packages/mmengine/model/wrappers/distributed.py", line 161, in _run_forward
    results = self(**data, mode=mode)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/distributed.py", line 1523, in forward
    else self._run_ddp_forward(*inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/distributed.py", line 1359, in _run_ddp_forward
    return self.module(*inputs, **kwargs)  # type: ignore[index]
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/xtuner/model/sft.py", line 228, in forward
    return self.compute_loss(data, data_samples)
  File "/usr/local/lib/python3.10/dist-packages/xtuner/model/sft.py", line 272, in compute_loss
    return self._compute_sequence_parallel_loss(data)
  File "/usr/local/lib/python3.10/dist-packages/xtuner/model/sft.py", line 262, in _compute_sequence_parallel_loss
    outputs = self.llm(**data)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/peft/peft_model.py", line 1395, in forward
    return self.base_model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/peft/tuners/tuners_utils.py", line 179, in forward
    return self.model.forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 1205, in forward
    outputs = self.model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 989, in forward
    layer_outputs = self._gradient_checkpointing_func(
  File "/usr/local/lib/python3.10/dist-packages/torch/_compile.py", line 24, in inner
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 417, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py", line 25, in inner
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py", line 460, in checkpoint
    raise ValueError(
ValueError: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.
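Until gradient_checkpointing_kwargs is plumbed through to the model, a blunt interim workaround is to pin the default at the torch level. This is a sketch, not an XTuner-endorsed fix: patching a library global is fragile, must run before the model enables gradient checkpointing, and should be dropped once the kwarg is forwarded properly.

import functools
import torch.utils.checkpoint

# Make every checkpoint() call explicit about use_reentrant, avoiding the
# ValueError above. Keyword arguments passed by callers still override the
# partial's default.
torch.utils.checkpoint.checkpoint = functools.partial(
    torch.utils.checkpoint.checkpoint, use_reentrant=False)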

Additional information

no error
