[Kernel] Add GPU kernels. #372
base: main
Conversation
    if (nbytes == 0) { return nullptr; }

    void *data;

#ifdef GPU
    if (device != nullptr) {
        data = sycl::malloc_device<char>(nbytes, *static_cast<sycl::queue *>(device));
The allocation may fail; shouldn't the failure case be handled?
src/layers/attention.h
Outdated
#ifdef GPU
        sycl::queue *q = static_cast<sycl::queue *>(ctx->device);
        int64_t size = ctx->batchSize * ctx->inputSeqLen * qkvCols * sizeof(float);
        q->memcpy(qkvMatMul.Data(), query.Data(), size).wait();
Why do we need a copy here?
src/models/model_factory.h
Outdated
There is no need to change this file.
src/layers/mlp_llama.h
Outdated
@@ -275,8 +275,7 @@ class LlamaMLP : public SingletonBase<LlamaMLP<WeiT>> {
         }
     }

-    template <typename T1, typename T2>
-    void catGateUpProj(DecoderContext *ctx, hpj::Matrix<T1> &input, hpj::Matrix<T2> &output, hpj::Matrix<T2> &siluBuf) {
+    void catGateUpProj(DecoderContext *ctx, hpj::Matrix<InT> &input, hpj::Matrix<ImT> &output, hpj::Matrix<ImT> &siluBuf) {
Was this changed because of a compiler error or warning? I'd suggest using OutT for the output.
@@ -349,8 +349,11 @@ class MMHelper {
         // W8A8
         else if constexpr (std::is_same_v<WeiT, w8a8_t>) {
             using dt = dnnl::memory::data_type;
+            dnnl::engine eng(dnnl::engine::kind::cpu, 0);
+            dnnl::stream stm(eng);
Why do we now need to create an engine and a stream on every call into this function? That may hurt performance.