Add cudnn conv2d #435
base: main
Conversation
Thanks @yudi0201 !
Overall looks good to me. After merging this PR, we can add a primitive function that calls conv2d_cudnn in our runtime library and expose an operator like hidet.ops.conv2d_cudnn.
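For reference, the semantics that such an operator would need to match can be sketched in plain Python. This is only an illustrative NCHW convolution (no padding or dilation, nested lists instead of tensors); the name conv2d_ref and its signature are hypothetical, not hidet's API:

```python
def conv2d_ref(x, w, stride=1):
    # x: input  [N][C][H][W] as nested lists
    # w: filter [K][C][R][S] as nested lists
    n, c, h, wd = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    k, _, r, s = len(w), len(w[0]), len(w[0][0]), len(w[0][0][0])
    oh = (h - r) // stride + 1   # output height, valid convolution
    ow = (wd - s) // stride + 1  # output width
    y = [[[[0.0] * ow for _ in range(oh)] for _ in range(k)] for _ in range(n)]
    for ni in range(n):
        for ki in range(k):
            for i in range(oh):
                for j in range(ow):
                    acc = 0.0
                    for ci in range(c):
                        for ri in range(r):
                            for si in range(s):
                                acc += (x[ni][ci][i * stride + ri][j * stride + si]
                                        * w[ki][ci][ri][si])
                    y[ni][ki][i][j] = acc
    return y
```

A cuDNN-backed operator would compute the same result, just on the GPU via the backend descriptors this PR sets up.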
src/hidet/runtime/cuda/cudnn.cpp
Outdated
void *dev_ptrs[3] = {ptr_x, ptr_w, ptr_y};  // device pointers
int64_t uids[3] = {'x', 'w', 'y'};
void *workspace = hidet_cuda_malloc_async(workspaceSize, cur_stream);
It might be better to use the workspace shared by all hidet operators (i.e., https://github.com/hidet-org/hidet/blob/main/include/hidet/runtime/cuda/context.h#L46).
When we run the operator a second time, there will be no memory allocation, so it can also be used with CUDA graphs.
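The idea behind the shared workspace can be illustrated with a minimal grow-only buffer in Python. This is only a sketch of the pattern, not hidet's actual implementation (which lives in the C++ runtime behind context.h); the class and method names here are hypothetical:

```python
class SharedWorkspace:
    """Grow-only scratch buffer shared across operators.

    The buffer only reallocates when a request exceeds the largest size
    seen so far; once workloads are warmed up, every call returns the
    cached buffer with no allocation, which is what makes the pattern
    compatible with CUDA graph capture (capture forbids allocations).
    """

    def __init__(self):
        self.size = 0
        self.buf = None

    def request(self, nbytes):
        if nbytes > self.size:
            # Grow to the largest request seen so far.
            # (bytearray stands in for a cudaMallocAsync'd device buffer.)
            self.size = nbytes
            self.buf = bytearray(nbytes)
        return self.buf
```

With this pattern, a second run of the same operator hits the cached buffer and performs no allocation, matching the behavior described above.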
CHECK_CUDNN(cudnnBackendDestroyDescriptor(xDesc));
CHECK_CUDNN(cudnnBackendDestroyDescriptor(wDesc));
CHECK_CUDNN(cudnnBackendDestroyDescriptor(yDesc));
CHECK_CUDNN(cudnnBackendDestroyDescriptor(cDesc));
CHECK_CUDNN(cudnnBackendDestroyDescriptor(fprop));
CHECK_CUDNN(cudnnBackendDestroyDescriptor(op_graph));
CHECK_CUDNN(cudnnBackendDestroyDescriptor(engine));
CHECK_CUDNN(cudnnBackendDestroyDescriptor(engcfg));
CHECK_CUDNN(cudnnBackendDestroyDescriptor(plan));
Might be good to benchmark the performance of our implementation vs. PyTorch's conv2d. I am not sure whether the overhead of creating/destroying the descriptors on every call is large enough to affect performance.
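A simple way to measure this kind of per-call overhead is a timing harness with warm-up iterations, so that one-time costs (descriptor creation, plan building, autotuning) are separated from steady-state cost. The sketch below is a generic helper, not part of hidet; the name bench is hypothetical:

```python
import time

def bench(fn, warmup=3, repeats=10):
    """Return the average wall-clock time of fn() in milliseconds.

    Warm-up runs absorb one-time setup costs before timing starts.
    (For GPU kernels, a real benchmark would also need to synchronize
    the device around the timed region.)
    """
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    return (time.perf_counter() - start) / repeats * 1e3
```

Comparing the timings with warmup=0 against warmup>0 would show how much of the cost comes from descriptor creation/destruction versus the convolution itself.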
@yaoyaoding
It's doable, similar to cublas gemm: 072a606
What about adding cudnn*, cublas* etc to the search space?
That's exactly what the commit I mentioned above does.
No description provided.