
How can I do performance profiling when Paddle distributed training is enabled? #63858

Open
Sakurafwsv opened this issue Apr 25, 2024 · 4 comments

@Sakurafwsv

Please ask your question

When Paddle distributed training is enabled, e.g. with both tp and pp set to 2, I want to do performance profiling. Checking the card occupancy I can indeed see that all 4 cards are in use, but when I print the device it is always card 0. How can I view and analyze the information for all cards?

@lijialin03
Contributor

Hi, could you provide the relevant commands/screenshots/error messages so that the developers can reproduce the issue? Thanks~

@Sakurafwsv
Author

> Hi, could you provide the relevant commands/screenshots/error messages so that the developers can reproduce the issue? Thanks~

I compiled and installed the mlu backend from PaddleCustomDevice, and added `std::cout << "-----------device id:" << device->id << std::endl;` to the SetDevice interface in runtime.cc to print device information. Then I ran PaddleNLP's llama pretrain script with tp and pp training enabled; the device id printed in the terminal is always 0.
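One possible explanation (an assumption on my part, not confirmed in this thread): multi-process launchers commonly narrow each worker's `*_VISIBLE_DEVICES` to a single card, so each worker addresses its physical card as local device 0 even though the cards differ. A minimal pure-Python illustration of that mapping idea (the function name and device list are hypothetical, not Paddle code):

```python
# Hypothetical illustration of per-worker device remapping; not Paddle code.
# Each worker sees exactly one physical card, exposed at local index 0.
def visible_device_for(rank, all_devices="0,1,2,3"):
    """Return (physical card id, local device index) for a worker rank."""
    physical = all_devices.split(",")[rank]
    return physical, 0  # inside the worker, its card is always device 0

for rank in range(4):
    phys, local = visible_device_for(rank)
    print(f"rank {rank}: physical card {phys}, local device id {local}")
```

If the launcher behaves this way, a device id of 0 inside SetDevice would be expected in every process, and the per-rank log files under the `--log_dir` directory would be the place to tell the workers apart.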

@Sakurafwsv
Author

> Hi, could you provide the relevant commands/screenshots/error messages so that the developers can reproduce the issue? Thanks~

The llama launch script is as follows:
set -x
unset CUDA_VISIBLE_DEVICES
export MLU_VISIBLE_DEVICES="0,1,2,3"
task_name="llama_hybrid"
rm -rf output/$task_name/
rm -rf "output/$task_name""_log"
export PADDLE_DISTRI_BACKEND="xccl"
export PADDLE_XCCL_BACKEND=mlu

PYTHONPATH=../:$PYTHONPATH \
python -u -m paddle.distributed.launch \
    --devices "0,1,2,3" \
    --log_dir "output/$task_name""_log" \
    run_pretrain.py \
    --model_name_or_path "facebook/llama-7b" \
    --tokenizer_name_or_path "facebook/llama-7b" \
    --input_dir "/workspace/llm-pretrain-data" \
    --output_dir "output/$task_name" \
    --split 949,50,1 \
    --tensor_parallel_degree 2 \
    --pipeline_parallel_degree 2 \
    --max_seq_length 1024 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --use_flash_attention 0 \
    --use_fused_rms_norm 0 \
    --fp16 false \
    --fp16_opt_level "O0" \
    --scale_loss 1024 \
    --learning_rate 0.0001 \
    --min_learning_rate 0.00001 \
    --max_steps 200 \
    --save_steps 5000 \
    --weight_decay 0.01 \
    --warmup_ratio 0.01 \
    --max_grad_norm 1.0 \
    --logging_steps 5 \
    --dataloader_num_workers 1 \
    --eval_steps 1000 \
    --report_to "visualdl" \
    --disable_tqdm true \
    --continue_training 0 \
    --recompute 1 \
    --do_train \
    --do_eval \
    --device "mlu" \
    --data_impl "mmap" \
    --gradient_accumulation_steps 16 \
    --sequence_parallel 1 \
    --pipeline_parallel_config disable_partial_send_recv

@paddle-bot paddle-bot bot added status/following-up 跟进中 and removed status/new-issue 新建 labels Apr 26, 2024
@qili93
Contributor

qili93 commented Apr 26, 2024

Hi, while the job is running, could you run the cnmon command to check whether cards 0-3 are actually active?

If they are indeed occupied, you can use Paddle's native Profiler for performance analysis. When adapting Paddle, the MLU backend supports the native Profiler through the profiler interface defined at https://github.com/PaddlePaddle/PaddleCustomDevice/blob/develop/backends/mlu/runtime/runtime.cc#L982.

For how to use the native Profiler, see the documentation at https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/profiler/Profiler_cn.html#profiler. For the code changes, you can refer to the sample code in PR PaddlePaddle/PaddleCustomDevice#785 and modify it as follows:

prof = profiler.Profiler(targets=[profiler.ProfilerTarget.CUSTOM_DEVICE], custom_device_types=['mlu'])
