[Fixbug] Fix dynamic memcpy bug (#427)
Minimal failure case:

```
resize_inputs: Tensor = symbol([1, 3, "h", "w"], dtype="int32", device="cpu")
resize_outputs = self.resize(resize_inputs.to(self.dtype, self.device))  # (float32, cuda)
resize_graph: FlowGraph = trace_from(resize_outputs, resize_inputs)

resize_graph.build()
```
This compiles to the following launch function, in which the symbols `h` and `w` are undefined:

```
DLL void hidet_launch_0(float * __restrict__ x, float * __restrict__ y) {
  cudaMemcpyAsync(y, x, (4 * ((3 * h) * w)), cudaMemcpyHostToDevice, (cudaStream_t)get_cuda_stream());
}
```

The fix is to pass the exprs to BlackBoxStmt as separate arguments instead of
formatting them into the template string, so that the symbols used in those
exprs can be visited during codegen.
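To illustrate why this matters, here is a minimal, self-contained sketch. It does not use hidet; `Sym`, `Mul`, `BlackBox`, `render`, and `free_symbols` are hypothetical stand-ins for an IR, a black-box statement, a printer, and a symbol-collecting visitor. It assumes the visitor can only discover symbols through expression nodes attached to the statement, never by parsing the template text.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class Sym:
    """Hypothetical symbolic dimension such as `h` or `w`."""
    name: str


@dataclass(frozen=True)
class Mul:
    """Hypothetical product node."""
    a: "Node"
    b: "Node"


Node = Union[Sym, Mul, int]


@dataclass(frozen=True)
class BlackBox:
    """Hypothetical opaque statement: template text plus attached expressions."""
    template: str
    exprs: Tuple[Node, ...] = ()


def render(node: Node) -> str:
    """Print an expression as C-like source text."""
    if isinstance(node, Sym):
        return node.name
    if isinstance(node, Mul):
        return f"({render(node.a)} * {render(node.b)})"
    return str(node)


def free_symbols(node) -> set:
    """Collect symbol names reachable through expression nodes only."""
    if isinstance(node, Sym):
        return {node.name}
    if isinstance(node, Mul):
        return free_symbols(node.a) | free_symbols(node.b)
    if isinstance(node, BlackBox):
        found = set()
        for expr in node.exprs:
            found |= free_symbols(expr)
        return found
    return set()  # plain ints (and raw template text) carry no symbols


count = Mul(4, Mul(Mul(3, Sym("h")), Sym("w")))

# Before: the expression is flattened into text, so the visitor sees nothing.
before = BlackBox("cudaMemcpyAsync(y, x, {}, ...);".format(render(count)))

# After: the expression stays attached, so `h` and `w` remain discoverable.
after = BlackBox("cudaMemcpyAsync(y, x, {}, ...);", (count,))

print(free_symbols(before))
print(sorted(free_symbols(after)))
```

In this sketch the "before" statement reports no symbols even though its text mentions `h` and `w`, which mirrors the undefined-symbol launch function above.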
KTong821 committed Feb 21, 2024
1 parent 5f76caf commit f3ccb87
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion python/hidet/ir/primitives/cuda/memcpy.py
@@ -27,5 +27,5 @@ def memcpy_async(dst: Expr, src: Expr, count: Expr, kind: str):
         raise RuntimeError(f'Unsupported transfer from {src} to {dst}, candidate kinds are {list(kind_map.keys())}')

     return BlackBoxStmt(
-        'cudaMemcpyAsync({}, {}, {}, {}, (cudaStream_t){});'.format(dst, src, count, kind_map[kind], get_cuda_stream())
+        f'cudaMemcpyAsync({{}}, {{}}, {{}}, {kind_map[kind]}, (cudaStream_t){{}});', dst, src, count, get_cuda_stream()
     )
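The doubled braces in the new line are standard Python f-string escaping: `{{}}` survives as a literal `{}` placeholder, while `{kind_map[kind]}` is substituted immediately. A small sketch of that behavior follows; the `kind_map` contents and the `host_to_device` key are assumptions made for illustration, not the actual mapping in `memcpy.py`.

```python
# Hypothetical mapping; the real keys in memcpy.py may differ.
kind_map = {"host_to_device": "cudaMemcpyHostToDevice"}
kind = "host_to_device"

# `{{}}` is an escaped pair of braces and stays as `{}`;
# `{kind_map[kind]}` is evaluated when the f-string is built.
template = f'cudaMemcpyAsync({{}}, {{}}, {{}}, {kind_map[kind]}, (cudaStream_t){{}});'
print(template)

# The four remaining placeholders are filled later, one per expression.
filled = template.format("y", "x", "(4 * ((3 * h) * w))", "get_cuda_stream()")
print(filled)
```

Only the transfer kind, a plain string, is baked into the template; the four expression slots stay open so they can be filled from the separately passed arguments.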
