Skip to content

Commit

Permalink
[CuBLAS] Add CuBLAS benchmarks
Browse files Browse the repository at this point in the history
Some CuBLAS benchmarking results on RTX2080 TI (all measurements are median latencies):

SECTION 1
FP32 Matrix Multiply: C (bs x m x n) = A (bs x m x k) @ B(bs x k x n)

Group 1 results with m = 512, n = 512, k = 512
bs = 1:
cublas_batched_gemm            69.0us
cublas_strided_gemm            41.0us
hidet.ops.matmul optimized     37.0us
PyTorch                        44.6us

bs = 2:
cublas_batched_gemm            111.7us
cublas_strided_gemm            75.8us
hidet.ops.matmul optimized     69.2us
PyTorch                        71.7us

bs = 4:
cublas_batched_gemm            124.9us
cublas_strided_gemm            97.2us
hidet.ops.matmul optimized     100.8us
PyTorch                        96.3us

bs = 8:
cublas_batched_gemm            190.5us
cublas_strided_gemm            191.1us
hidet.ops.matmul optimized     204.7us
PyTorch                        187.6us

Group 2 results with m = 1024, n = 1024, k = 2048
bs = 1:
cublas_batched_gemm            405.1us
cublas_strided_gemm            419.2us
hidet.ops.matmul optimized     370.7us
PyTorch                        405.1us

bs = 2:
cublas_batched_gemm            725.3us
cublas_strided_gemm            859.9us
hidet.ops.matmul optimized     800.8us
PyTorch                        719.2us

bs = 4:
cublas_batched_gemm            1442us
cublas_strided_gemm            1592us
hidet.ops.matmul optimized     1606us
PyTorch                        1466us

bs = 8:
cublas_batched_gemm            2658us
cublas_strided_gemm            2830us
hidet.ops.matmul optimized     3475us
PyTorch                        2753us

SECTION 2
FP16 Matrix Multiply: C (bs x m x n) = A (bs x m x k) @ B(bs x k x n)

Group 1 results with m = 512, n = 512, k = 512
bs = 1:
cublas_batched_gemm            63.5us
cublas_strided_gemm            34.0us
hidet.ops.matmul optimized     34.9us
PyTorch                        41.0us

bs = 2:
cublas_batched_gemm            66.0us
cublas_strided_gemm            30.2us
hidet.ops.matmul optimized     64.8us
PyTorch                        45.1us

bs = 4:
cublas_batched_gemm            72.7us
cublas_strided_gemm            32.4us
hidet.ops.matmul optimized     24.4us
PyTorch                        46.3us

bs = 8:
cublas_batched_gemm            81.2us
cublas_strided_gemm            36.2us
hidet.ops.matmul optimized     38.5us
PyTorch                        47.8us

Group 2 results with m = 1024, n = 1024, k = 2048
bs = 1:
cublas_batched_gemm            71.0us
cublas_strided_gemm            60.1us
hidet.ops.matmul optimized     65.5us
PyTorch                        90.6us

bs = 2:
cublas_batched_gemm            114.8us
cublas_strided_gemm            112.3us
hidet.ops.matmul optimized     123.1us
PyTorch                        160.5us

bs = 4:
cublas_batched_gemm            225.1us
cublas_strided_gemm            223.4us
hidet.ops.matmul optimized     245.6us
PyTorch                        319.8us

bs = 8:
cublas_batched_gemm            442.8us
cublas_strided_gemm            439.1us
hidet.ops.matmul optimized     733.2us
PyTorch                        634.8us
  • Loading branch information
Yudi Sun committed Apr 5, 2024
1 parent 531b8d3 commit 1035fcb
Showing 1 changed file with 96 additions and 0 deletions.
96 changes: 96 additions & 0 deletions python/hidet/cuda/cublas/benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import math
import torch
import numpy as np

import hidet
from hidet.cuda.cublas import cublasComputeType
from hidet.utils.benchmark import do_bench
from hidet import ops


def benchmark_cublas_batched_gemm(bs, m, n, k, dtype, compute_type):
a, b, c = [], [], []
for i in range(bs):
a.append(hidet.randn((m, k), device='cuda', dtype=dtype) / math.sqrt(k))
b.append(hidet.randn((k, n), device='cuda', dtype=dtype) / math.sqrt(k))
c.append(hidet.empty((m, n), device='cuda', dtype=dtype))

latencies = do_bench(
lambda: hidet.cuda.cublas.batched_gemm(
bs, m, n, k, a[0].dtype, b[0].dtype, c[0].dtype, a, b, c, False, False, compute_type
),
warmup=10,
rep=100,
)

print(f"cublas_batched_gemm Results for Configuration: dtype = {dtype}, input shape = {[bs, m, n, k]}, ")
print("Median Latency Is: " + str(latencies[1]) + " milliseconds")
print("-------------------------------------------------")


def benchmark_cublas_strided_gemm(bs, m, n, k, dtype, compute_type):
a = hidet.randn((bs, m, k), device='cuda', dtype=dtype) / math.sqrt(k)
b = hidet.randn((bs, k, n), device='cuda', dtype=dtype) / math.sqrt(k)
c = hidet.empty((bs, m, n), device='cuda', dtype=dtype)

latencies = do_bench(
lambda: hidet.cuda.cublas.strided_gemm(
bs, m, n, k, a.dtype, b.dtype, c.dtype, a, b, c, m * k, k * n, m * n, False, False, compute_type
),
warmup=10,
rep=100,
)

print(f"cublas_strided_gemm Results for Configuration: dtype = {dtype}, input shape = {[bs, m, n, k]}, ")
print("Median Latency Is: " + str(latencies[1]) + " milliseconds")
print("-------------------------------------------------")


def benchmark_torch_batched_matmul(bs, m, n, k, dtype, compute_type):
a = torch.from_numpy(np.array(np.random.randn(bs, m, k)).astype(dtype)).cuda()
b = torch.from_numpy(np.array(np.random.randn(bs, k, n)).astype(dtype)).cuda()

latencies = do_bench(lambda: a @ b, warmup=10, rep=100)

print(f"torch_batched_matmul Results for Configuration: dtype = {dtype}, input shape = {[bs, m, n, k]}, ")
print("Median Latency Is: " + str(latencies[1]) + " milliseconds")
print("-------------------------------------------------")


def benchmark_hidet_batched_matmul(bs, m, n, k, dtype, compute_type):
a = hidet.symbol((bs, m, k), device='cuda', dtype=dtype)
b = hidet.symbol((bs, k, n), device='cuda', dtype=dtype)
c = ops.matmul(a, b)
hidet.option.search_space(2)
graph = hidet.trace_from(c, inputs=[a, b])
graph = hidet.graph.optimize(graph)
graph = graph.cuda_graph()

latencies = do_bench(lambda: graph.run_async(), warmup=10, rep=100)

print(f"hidet_batched_matmul Results for Configuration: dtype = {dtype}, input shape = {[bs, m, n, k]}, ")
print("Median Latency Is: " + str(latencies[1]) + " milliseconds")
print("-------------------------------------------------")


if __name__ == '__main__':
sizes = [
# # Group 1
[1, 512, 512, 512],
[2, 512, 512, 512],
[4, 512, 512, 512],
[8, 512, 512, 512],
# Group 2
[1, 1024, 1024, 2048],
[2, 1024, 1024, 2048],
[4, 1024, 1024, 2048],
[8, 1024, 1024, 2048],
]
dtypes = [['float32', cublasComputeType.CUBLAS_COMPUTE_32F], ['float16', cublasComputeType.CUBLAS_COMPUTE_16F]]

for data_type in dtypes:
for size in sizes:
# benchmark_cublas_batched_gemm(*(size + data_type))
benchmark_cublas_strided_gemm(*(size + data_type))
# benchmark_torch_batched_matmul(*(size + data_type))
# benchmark_hidet_batched_matmul(*(size + data_type))

0 comments on commit 1035fcb

Please sign in to comment.