Add cudnn conv2d #435

Open: wants to merge 2 commits into base: main


yudi0201
Collaborator

@yudi0201 yudi0201 commented Mar 7, 2024

No description provided.

Member

@yaoyaoding yaoyaoding left a comment

Thanks @yudi0201 !

Overall looks good to me. After merging this PR, we can add a primitive function to call the conv2d_cudnn in our runtime library and have an operator like hidet.ops.conv2d_cudnn.


void *dev_ptrs[3] = {ptr_x, ptr_w, ptr_y}; // device pointers
int64_t uids[3] = {'x', 'w', 'y'};
void *workspace = hidet_cuda_malloc_async(workspaceSize, cur_stream);
Member

It might be better to use the workspace shared by all hidet operators (i.e., https://github.com/hidet-org/hidet/blob/main/include/hidet/runtime/cuda/context.h#L46).

When we run the operator a second time, there will not be any memory allocation, so it can also be used in a CUDA graph.
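The shared-workspace idea can be sketched as a grow-only buffer that is reallocated only when a request exceeds its current capacity. This is a minimal illustration, not hidet's actual runtime API: the `Workspace` class and its methods are hypothetical names, and host `std::malloc` stands in for the device allocation so that the sketch compiles without CUDA.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>

// Hypothetical grow-only workspace (not hidet's real API).
// Host malloc/free stand in for the device allocation calls.
class Workspace {
public:
    Workspace() = default;
    Workspace(const Workspace &) = delete;
    Workspace &operator=(const Workspace &) = delete;
    ~Workspace() { std::free(buffer_); }

    // Return a buffer of at least `nbytes`. A new allocation happens only
    // when the request is larger than the current capacity.
    void *request(std::size_t nbytes) {
        if (nbytes > capacity_) {
            void *fresh = std::malloc(nbytes);
            if (fresh == nullptr) {
                return nullptr;  // allocation failed; the old buffer is kept
            }
            std::free(buffer_);
            buffer_ = fresh;
            capacity_ = nbytes;
            ++num_allocations_;
        }
        return buffer_;
    }

    // Number of allocations performed so far (useful for checking reuse).
    std::size_t num_allocations() const { return num_allocations_; }

private:
    void *buffer_ = nullptr;
    std::size_t capacity_ = 0;
    std::size_t num_allocations_ = 0;
};
```

With a buffer like this, a second call with the same or a smaller `workspaceSize` returns the existing pointer without allocating, which is the property the comment relies on for CUDA graph use.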

Comment on lines +605 to +613
CHECK_CUDNN(cudnnBackendDestroyDescriptor(xDesc));
CHECK_CUDNN(cudnnBackendDestroyDescriptor(wDesc));
CHECK_CUDNN(cudnnBackendDestroyDescriptor(yDesc));
CHECK_CUDNN(cudnnBackendDestroyDescriptor(cDesc));
CHECK_CUDNN(cudnnBackendDestroyDescriptor(fprop));
CHECK_CUDNN(cudnnBackendDestroyDescriptor(op_graph));
CHECK_CUDNN(cudnnBackendDestroyDescriptor(engine));
CHECK_CUDNN(cudnnBackendDestroyDescriptor(engcfg));
CHECK_CUDNN(cudnnBackendDestroyDescriptor(plan));
Member

It might be good to benchmark our implementation against PyTorch's conv2d. I am not sure whether the overhead of creating and destroying the descriptors is large enough to affect performance.
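If such a benchmark did show that descriptor setup is significant, one possible mitigation (an assumption here, not something this PR implements) would be to build the execution plan once per problem configuration and reuse it. In the sketch below, `ConvKey`, `Plan`, and `PlanCache` are hypothetical names, and `Plan` is only a stand-in for a finalized cuDNN backend descriptor.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <tuple>

// Hypothetical cache key: (n, c, h, w, k, r, s) of a conv2d problem.
using ConvKey = std::tuple<int, int, int, int, int, int, int>;

// Stand-in for a finalized execution plan; a real version would own the
// cuDNN backend descriptors and destroy them when the cache is cleared.
struct Plan {
    int id;
};

class PlanCache {
public:
    // Return the cached plan for `key`, calling `build(key)` only on a miss.
    template <typename Builder>
    const Plan &get_or_build(const ConvKey &key, Builder build) {
        auto it = plans_.find(key);
        if (it == plans_.end()) {
            it = plans_.emplace(key, build(key)).first;
            ++num_builds_;
        }
        return it->second;
    }

    // Number of times a plan had to be built (cache misses).
    std::size_t num_builds() const { return num_builds_; }

private:
    std::map<ConvKey, Plan> plans_;
    std::size_t num_builds_ = 0;
};
```

Repeated calls with the same key then skip the build step entirely, so the create/destroy cost would be paid once per distinct problem shape rather than once per call.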

@vadiklyutiy
Collaborator

@yaoyaoding
A slightly different but related question.
Are we planning to include conv2d_cudnn in the search space? I mean, we could search over our current space plus the cudnn implementation. Is it doable at all (without a huge redesign)?

@yaoyaoding
Member

It's doable, similar to cublas gemm: 072a606

@vadiklyutiy
Collaborator

vadiklyutiy commented Mar 7, 2024

What about adding cudnn*, cublas* etc to search space?

@yaoyaoding
Member

> What about adding cudnn*, cublas* etc to search space?

That's exactly what the commit I mentioned before does.
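As a rough illustration of what putting a vendor library call into the search space means, the sketch below treats the cuDNN call as one more candidate next to generated schedules and keeps whichever measures fastest. It does not reflect how the referenced commit is implemented: `Candidate`, `pick_fastest`, and the candidate names used with it are hypothetical, and each measurement callback is assumed to return a latency in milliseconds.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <limits>
#include <string>
#include <vector>

// Hypothetical tuning candidate: a name plus a callback that runs the
// implementation and returns its measured latency in milliseconds.
struct Candidate {
    std::string name;
    std::function<double()> measure_ms;
};

// Measure every candidate once and return the index of the fastest one.
// Returns candidates.size() when the list is empty.
std::size_t pick_fastest(const std::vector<Candidate> &candidates) {
    std::size_t best = candidates.size();
    double best_ms = std::numeric_limits<double>::infinity();
    for (std::size_t i = 0; i < candidates.size(); ++i) {
        const double ms = candidates[i].measure_ms();
        if (ms < best_ms) {
            best_ms = ms;
            best = i;
        }
    }
    return best;
}
```

In this framing, a library-backed implementation such as conv2d_cudnn needs no special handling by the tuner; it only has to expose the same measurement interface as the generated candidates.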

3 participants