Fused llama kernel #10266

Open · wants to merge 30 commits into master

Conversation

@ofhwei (Contributor) commented May 15, 2023:

Parallel-inference optimization for the llama model: all of the CUDA kernels of each LlamaDecoderLayer are placed inside one large op, to reduce as much as possible the latency of dispatching instructions from the Python side.

Comment on lines 846 to 847
std::map<int, std::shared_ptr<OpExpr>> ops_;
std::map<int, std::shared_ptr<OpExpr>> ops_with_past_key_value_;
Contributor:

Using a hash map would be a bit faster.
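
For illustration, a minimal sketch of that suggestion, assuming the keys are plain integer layer indices as in the snippet above (the enclosing class name here is hypothetical, not the PR's actual class):

```cpp
#include <memory>
#include <unordered_map>

namespace oneflow {

class OpExpr;  // forward declaration, as in the surrounding code

// Hypothetical holder class for illustration only. std::unordered_map gives
// O(1) average lookup by layer index instead of std::map's O(log n) tree
// walk; std::map's sorted iteration order is not needed for a cache that is
// only queried by key.
class FusedOpCache {
 public:
  std::unordered_map<int, std::shared_ptr<OpExpr>> ops_;
  std::unordered_map<int, std::shared_ptr<OpExpr>> ops_with_past_key_value_;
};

}  // namespace oneflow
```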

const TensorTuple& past_values, const int64_t head_size) const {
int64_t num_layers = input_norm_weights.size();
auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("head_size", "num_layers", "parallel_conf");
auto conf = PbMessage2TxtString(JUST(hidden_states->parallel_desc())->parallel_conf());
Contributor:

This could be cached; serializing the proto object on every call is fairly expensive.
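
A minimal sketch of the suggested caching, assuming the parallel_conf of a given placement does not change between calls; the helper name and the use of a symbol id as the cache key are assumptions for illustration, not the PR's actual code:

```cpp
#include <cstdint>
#include <functional>
#include <string>
#include <unordered_map>

// Serialize a parallel_conf to text once per placement and reuse the string,
// instead of paying the PbMessage2TxtString cost on every forward call.
const std::string& CachedParallelConfTxt(
    int64_t placement_symbol_id, const std::function<std::string()>& serialize) {
  static thread_local std::unordered_map<int64_t, std::string> cache;
  auto it = cache.find(placement_symbol_id);
  if (it == cache.end()) {
    it = cache.emplace(placement_symbol_id, serialize()).first;
  }
  return it->second;
}
```

At the call site above, the lambda passed as serialize would wrap the existing PbMessage2TxtString(...) expression, so the serialization runs at most once per placement.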

Comment on lines 756 to 770
.Output("rms_norm_out")
.Output("inv_rms")
.Output("query")
.Output("key")
.Output("value")
.Output("rotary_query")
.Output("rotary_key")
.Output("concat_keys", num_layers)
.Output("concat_values", num_layers)
.Output("attn_out")
.Output("out")
.Output("post_norm_out")
.Output("gate_out")
.Output("glu_out")
.Output("decoder_out")
@clackhan commented May 16, 2023:

Would it be enough to output only "concat_keys", "concat_values", and "decoder_out"? If the rest are intermediate variables, they could use a temp buffer.
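
A sketch of what the trimmed-down output list might look like if the suggestion is taken, keeping only the tensors consumed outside the op (the fragment mirrors the builder chain above; whether the remaining tensors can really live in a kernel-managed temporary buffer depends on the kernel implementation):

```cpp
// Only the per-layer KV cache and the final layer output cross the op
// boundary; rms_norm_out, query/key/value, attn_out, gate_out, etc. would
// become kernel-internal temporaries sized by the op's temp-buffer inference.
    .Output("concat_keys", num_layers)
    .Output("concat_values", num_layers)
    .Output("decoder_out")
```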

@yuanms2 (Contributor) commented May 16, 2023:

Is this how FasterTransformer does it?

Does the llama Python implementation need to be changed by hand, or is the fusion applied automatically through pattern matching?

@clackhan (Contributor) commented May 16, 2023:

> Is this how FasterTransformer does it?

FasterTransformer is a pure C++ implementation and can be regarded as a specialized one. The code implements a Llama class that is compiled into an executable binary; at runtime a Llama instance is created. The constructor allocates all the memory needed for computation up front, and the destructor releases it all at once; because the computation is pure C++ and no memory allocation happens along the way, the whole kernel-launch path is very fast. The Llama model is currently still a third-party PR and has no Python implementation.

The more mature implementations in the main FasterTransformer repo, such as GPT, basically follow the same pattern: their PyTorch and TensorFlow frontends simply wrap the C++ class GptOp and export it to Python.

> Does the llama Python implementation need to be changed by hand, or is the fusion applied automatically through pattern matching?

Using the fused op requires manually changing the code.

@strint (Contributor) commented May 16, 2023:

> the constructor allocates all the memory needed for computation up front, and the destructor releases it all at once; because the computation is pure C++ and no memory allocation happens along the way

It was mentioned earlier that inference has a dynamic-shape problem; does it allocate memory for the maximum shape?

@clackhan (Contributor) commented:

> the constructor allocates all the memory needed for computation up front, and the destructor releases it all at once; because the computation is pure C++ and no memory allocation happens along the way
>
> It was mentioned earlier that inference has a dynamic-shape problem; does it allocate memory for the maximum shape?

Yes, it allocates the maximum required memory.

@ofhwei (Contributor, Author) commented May 16, 2023:

> the constructor allocates all the memory needed for computation up front, and the destructor releases it all at once; because the computation is pure C++ and no memory allocation happens along the way

FT's kv_cache allocates GPU memory for the maximum required length max_cache_seq_len; see https://github.com/void-main/FasterTransformer/blob/main/src/fastertransformer/models/llama/Llama.cc#L102
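
For illustration, a minimal sketch of the allocation pattern being described, where the KV cache is sized for the worst case at construction and freed in the destructor so no allocation happens on the per-token decode path; the struct name and cache layout are illustrative, not FasterTransformer's actual code:

```cpp
#include <cstddef>
#include <cuda_runtime.h>

// Illustrative only: preallocate the KV cache for max_cache_seq_len so that
// per-step decoding never calls cudaMalloc.
struct KvCache {
  void* key_cache = nullptr;
  void* value_cache = nullptr;

  KvCache(size_t max_batch, size_t num_layers, size_t num_heads,
          size_t max_cache_seq_len, size_t head_size, size_t elem_size) {
    const size_t bytes = max_batch * num_layers * num_heads *
                         max_cache_seq_len * head_size * elem_size;
    cudaMalloc(&key_cache, bytes);    // worst-case allocation, done once
    cudaMalloc(&value_cache, bytes);
  }

  ~KvCache() {
    cudaFree(key_cache);              // released together at teardown
    cudaFree(value_cache);
  }
};
```

During decoding, the current sequence length only indexes into this buffer; the unused tail capacity is the price paid for never allocating inside the hot loop.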

@github-actions commented:

Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally.


namespace oneflow {
namespace cuda {
namespace rms_norm_output_norm_arg {
Contributor:

A comment could be added here explaining that there are two outputs and what each one refers to.
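
A sketch of the kind of comment being asked for; the description of the two outputs below is only a guess from the namespace name and the op's output list, and should be adjusted to match what the kernel actually writes:

```cpp
namespace oneflow {
namespace cuda {
// This RMS-norm variant produces two outputs (description is an assumption):
//   (1) the usual weighted result, roughly x / rms(x) * weight, and
//   (2) the normalized argument / per-row statistic that later stages of the
//       fused decoder layer reuse without recomputing the normalization.
namespace rms_norm_output_norm_arg {
```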
