
[Bug]: rocblas_gemm_ex with m==1 fp16 inputs/outputs f32 compute slower than a quite naive gemv kernel on MI100 #1425

Open
Epliz opened this issue May 5, 2024 · 9 comments

@Epliz

Epliz commented May 5, 2024

Describe the bug

As described in the title, rocblas_gemm_ex seems quite suboptimal on MI100 when m == 1, the inputs/outputs are fp16, and compute is fp32.
A fairly naive kernel I implemented beats it.

This causes ROCm/pytorch#1408 in PyTorch.
It makes fp16 LLM inference on Mistral 7B slower than it could easily be.

To Reproduce

Here is a C++ reproducer:

#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <rocblas/rocblas.h>
#include <iostream>
#include <chrono>
#include <functional>


#define ROWS_PER_BLOCK 4
#define THREADS_PER_BLOCK 256

#define DIV_ROUND_UP(a, b) (((a) + (b) - 1) / (b))

#define FULL_MASK32 0xffffffff
#define FULL_MASK64 0xffffffffffffffff

#ifdef  __CUDA_ARCH__
#define __xx_shfl_down(mask, val, offset) __shfl_down_sync(mask, val, offset)
#elif defined(__HIP_PLATFORM_AMD__) // AMD
#define __xx_shfl_down(mask, val, offset) __shfl_down(val, offset)
#else
#error "Unsupported compiler"
#endif

__device__ float warpReduce(float val) {
  if (warpSize == 32) {
    for (int offset = 16; offset > 0; offset /= 2)
      val += __xx_shfl_down(FULL_MASK32, val, offset);
  }
  if (warpSize == 64) {
    for (int offset = 32; offset > 0; offset /= 2)
      val += __xx_shfl_down(FULL_MASK64, val, offset);

  }
  return val;
}

static inline void __device__ dot2(float& acc, const float2& a, const float2& b) {
  acc += a.x * b.x;
  acc += a.y * b.y;
}

template <typename T>
const T* __device__ addr(const T* p, unsigned index) {
  // helps the AMDGPU compiler understand it can use the sgpr pair + single vgpr addressing mode
  unsigned byte_offset = sizeof(T) * index;
  const uint8_t* p8 = (const uint8_t*)p;
  return (const T*) (p8 + byte_offset);
}

__global__ void muillm_gemv_kernel(
    const half* __restrict__ W, // weight matrix - size N x K
    const half* __restrict__ B, // optional bias - size N
    const half* __restrict__ X, // input = size K
    half* __restrict__ Y, // output - size N
    unsigned N,
    unsigned K
) {
  int warpCounts = THREADS_PER_BLOCK / warpSize;
  int warpId = threadIdx.x / warpSize;
  int laneId = threadIdx.x % warpSize;

#if ROWS_PER_BLOCK == 4
  int current_row = blockIdx.x * ROWS_PER_BLOCK + 0;
  if (current_row + 3 < N) {
    // can process ROWS_PER_BLOCK rows
    // shared state to do the reductions
    __shared__ float shared_accs[ROWS_PER_BLOCK];
    __shared__ int shared_reduction_counter;

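    // zero one accumulator per row (indexing by warpId would assume exactly ROWS_PER_BLOCK warps)
    if (threadIdx.x < ROWS_PER_BLOCK) {
      shared_accs[threadIdx.x] = 0.f;
    }
    if (threadIdx.x == 0) {
      shared_reduction_counter = 0;
    }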
    __syncthreads();

    // compute the t-th element of Y. by doing the dot product with the
    // t-th row of W
    const half* W0 = &W[(current_row + 0) * K];
    const half* W1 = &W[(current_row + 1) * K];
    const half* W2 = &W[(current_row + 2) * K];
    const half* W3 = &W[(current_row + 3) * K];

    float acc0 = 0.f;
    float acc1 = 0.f;
    float acc2 = 0.f;
    float acc3 = 0.f;

    // do the dot product
    {
      // each thread starts at element 2 * threadIdx.x and processes pairs of elements
      unsigned k;
      for (k = threadIdx.x * 2; k + 1 < K; k += (THREADS_PER_BLOCK * 2)) {
        // vectorized: load two halves at once and convert them to float2
        float2 x = __half22float2(*((const half2*)addr(X, k)));
        float2 w0 = __half22float2(*((const half2*)addr(W0, k)));
        float2 w1 = __half22float2(*((const half2*)addr(W1, k)));
        float2 w2 = __half22float2(*((const half2*)addr(W2, k)));
        float2 w3 = __half22float2(*((const half2*)addr(W3, k)));

        dot2(acc0, w0, x);
        dot2(acc1, w1, x);
        dot2(acc2, w2, x);
        dot2(acc3, w3, x);
      }
      for (; k < K; k += THREADS_PER_BLOCK) {
        // remainder
        float x = __half2float(*addr(X,k));
        float w0 = __half2float(*addr(W0,k));
        float w1 = __half2float(*addr(W1,k));
        float w2 = __half2float(*addr(W2,k));
        float w3 = __half2float(*addr(W3,k));
        acc0 += w0 * x;
        acc1 += w1 * x;
        acc2 += w2 * x;
        acc3 += w3 * x;
      }
    }

    // warp reduce
    acc0 = warpReduce(acc0);
    acc1 = warpReduce(acc1);
    acc2 = warpReduce(acc2);
    acc3 = warpReduce(acc3);

    // reduce across warps
    if (laneId == 0) {
      atomicAdd(&shared_accs[0], acc0);
      atomicAdd(&shared_accs[1], acc1);
      atomicAdd(&shared_accs[2], acc2);
      atomicAdd(&shared_accs[3], acc3);
      int old_count = atomicAdd(&shared_reduction_counter, 1);

      if (old_count == (warpCounts - 1)) {
        // we are the last warp to contribute
        // do the final write to memory

        acc0 = shared_accs[0]; // read the fully reduced value
        acc1 = shared_accs[1]; // read the fully reduced value
        acc2 = shared_accs[2]; // read the fully reduced value
        acc3 = shared_accs[3]; // read the fully reduced value
        if (B != nullptr) { // add the bias first if there is one
          acc0 += __half2float(B[current_row + 0]);
          acc1 += __half2float(B[current_row + 1]);
          acc2 += __half2float(B[current_row + 2]);
          acc3 += __half2float(B[current_row + 3]);
        }

        // write the output value
        Y[current_row + 0] = __float2half(acc0);
        Y[current_row + 1] = __float2half(acc1);
        Y[current_row + 2] = __float2half(acc2);
        Y[current_row + 3] = __float2half(acc3);
      }
    }
  } else
#endif
  { // case when either ROWS_PER_BLOCK != 4 or the block doesn't have ROWS_PER_BLOCK rows to process
    // shared state to do the reductions
    __shared__ float shared_accs[ROWS_PER_BLOCK];
    __shared__ int shared_reduction_counters[ROWS_PER_BLOCK];

    // zero one slot per row (indexing by warpId would assume exactly ROWS_PER_BLOCK warps)
    if (threadIdx.x < ROWS_PER_BLOCK) {
      shared_accs[threadIdx.x] = 0.f;
      shared_reduction_counters[threadIdx.x] = 0;
    }
    __syncthreads();

    for (int i = 0; i < ROWS_PER_BLOCK; i++) {
      // compute one element of Y by doing the dot product with the
      // corresponding row of W
      int current_row = blockIdx.x * ROWS_PER_BLOCK + i;
      if ((unsigned)current_row >= N) {
        // the last block may hold fewer than ROWS_PER_BLOCK rows; this condition is block-uniform
        break;
      }
      const half* W_ = &W[current_row * K];
    
      // do the dot product
      float acc = 0.f;
      for (int k = threadIdx.x; k < K; k += THREADS_PER_BLOCK) {
        float w = __half2float(W_[k]);
        acc += w * __half2float(X[k]);
      }

      // warp reduce
      acc = warpReduce(acc);

      // reduce across warps
      if (laneId == 0) {
        atomicAdd(&shared_accs[i], acc);
        int old_count = atomicAdd(&shared_reduction_counters[i], 1);

        if (old_count == (warpCounts - 1)) {
          // we are the last warp to contribute
          // do the final write to memory

          acc = shared_accs[i]; // read the fully reduced value
          if (B != nullptr) { // add the bias first if there is one
            acc += __half2float(B[current_row]);
          }

          // write the output value
          Y[current_row] = __float2half(acc);
        }
      }
    }
  }
}

void muillm_linear_forward_cuda(
    const half* __restrict__ W, // size N x K
    const half* __restrict__ B, // size N
    const half* __restrict__ X, // size K
    half* __restrict__ Y, // size N
    unsigned N,
    unsigned K) {

  const int threads_per_blocks = THREADS_PER_BLOCK;
  const int num_blocks = DIV_ROUND_UP(N, ROWS_PER_BLOCK);

  muillm_gemv_kernel<<<num_blocks, threads_per_blocks, 0, 0>>>(
    W,
    B,
    X,
    Y,
    N,
    K
  );
}

static inline void rocblas_sgemv(rocblas_handle handle,
    const half* __restrict__ W, // size N x K
    const half* __restrict__ X, // size K
    half* __restrict__ Y, // size N
    unsigned N,
    unsigned K) {
  float alpha = 1.0f;
  float beta = 0.f;

  // adapted for row major from https://stackoverflow.com/questions/56043539/cublassgemm-row-major-multiplication
  rocblas_gemm_ex(handle,
                  rocblas_operation_none /*transA*/,
                  rocblas_operation_none /*transB*/,
                  1 /*m*/,
                  N /*n*/,
                  K /*k*/,
                  &alpha,
                  X /*a*/,
                  rocblas_datatype_f16_r /*a_type*/,
                  1 /*lda*/,
                  W /*b*/,
                  rocblas_datatype_f16_r /*b_type*/,
                  K /*ldb*/,
                  &beta,
                  nullptr /*c*/,
                  rocblas_datatype_f16_r /*c_type*/,
                  1 /*ldc*/,
                  Y /*d*/,
                  rocblas_datatype_f16_r /*d_type*/,
                  1 /*ldd*/,
                  rocblas_datatype_f32_r /*compute_type*/,
                  rocblas_gemm_algo_standard /*algo*/,
                  0 /*solution_index*/,
                  0 /*flags*/);
}

size_t timeus_func(size_t count, std::function<void(int)> f) {
  std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now();
  f(count);
  hipDeviceSynchronize();

  std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();

  return std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count() / count;
}

int main(int argc, char** argv) {
  int in_features=4096, out_features=14336;
  int tot_features = in_features * out_features;

  // allocate matrices and vectors
  half* x_small = nullptr;
  half* x_big = nullptr;
  half* w_up = nullptr;
  half* w_down = nullptr;

  std::cout<<"Allocating memory..."<<std::endl;
  if (hipMalloc(&x_small, sizeof(half) * in_features) != hipSuccess) {
    return -1;
  }

  if (hipMalloc(&x_big, sizeof(half) * out_features) != hipSuccess) {
    return -1;
  }

  if (hipMalloc(&w_up, sizeof(half) * tot_features) != hipSuccess) {
    return -1;
  }

  if (hipMalloc(&w_down, sizeof(half) * tot_features) != hipSuccess) {
    return -1;
  }

  // set memory
  std::cout<<"Setting memory..."<<std::endl;
  if (hipMemsetD16(x_small, 0, in_features) != hipSuccess) {
    return -1;
  }
  if (hipMemsetD16(x_big, 0, out_features) != hipSuccess) {
    return -1;
  }
  if (hipMemsetD16(w_up, 0, tot_features) != hipSuccess) {
    return -1;
  }
  if (hipMemsetD16(w_down, 0, tot_features) != hipSuccess) {
    return -1;
  }

  //
  std::cout<<"Running..."<<std::endl;

  int count = 10000;

  {

    auto mui_prof = [=] (int count) {
      for (int i = 0; i < count; i++) {
        muillm_linear_forward_cuda(w_up, nullptr, x_small, x_big, out_features, in_features);
        muillm_linear_forward_cuda(w_down, nullptr, x_big, x_small, in_features, out_features);
      }
    };

    // warmup
    size_t discarded = timeus_func(
      10,
      mui_prof
    );

    // measurement
    size_t mui_time = timeus_func(
      count,
      mui_prof
    );

    std::cout<<"mui: "<<mui_time<<"us/loop"<<std::endl;
  }

  {// rocblas
    rocblas_initialize();
    rocblas_handle handle;
    if(rocblas_create_handle(&handle) != rocblas_status_success) return -3;

    auto rocblas_prof = [=] (int count) {
      for (int i = 0; i < count; i++) {
        rocblas_sgemv(handle, w_up, x_small, x_big, out_features, in_features);
        rocblas_sgemv(handle, w_down, x_big, x_small, in_features, out_features);
      }
    };

    // warmup
    size_t discarded = timeus_func(
      10,
      rocblas_prof
    );

    // measurement
    size_t rocblas_time = timeus_func(
      count,
      rocblas_prof
    );

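    std::cout<<"rocblas: "<<rocblas_time<<"us/loop"<<std::endl;

    rocblas_destroy_handle(handle);
  }

  // release device memory
  hipFree(x_small);
  hipFree(x_big);
  hipFree(w_up);
  hipFree(w_down);

  std::cout<<"DONE"<<std::endl;
  return 0;
}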
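The kernel splits each row's dot product across the THREADS_PER_BLOCK threads of a block: the vectorized loop gives each thread pairs of elements strided by 2 * THREADS_PER_BLOCK, and the remainder loop picks up the final element when K is odd. Below is a small CPU model of that partitioning. It is only a sketch: it assumes float data instead of half, replaces warpReduce and the shared-memory atomics with a plain sum, and kThreads, thread_partial and block_dot are hypothetical names used for illustration. It checks the index split, not performance.

```cpp
#include <cassert>
#include <vector>

// CPU model of the per-row work split in muillm_gemv_kernel (illustration only).
// Assumptions: float instead of half, and a plain sum in place of
// warpReduce() plus the shared-memory atomicAdd reduction.
constexpr unsigned kThreads = 256;  // mirrors THREADS_PER_BLOCK

// Partial dot product accumulated by the thread with threadIdx.x == tid.
float thread_partial(const std::vector<float>& w_row,
                     const std::vector<float>& x, unsigned tid) {
  assert(w_row.size() == x.size());
  const unsigned K = static_cast<unsigned>(x.size());
  float acc = 0.f;
  unsigned k;
  // "Vectorized" loop: pairs (k, k + 1), stepping by 2 * kThreads.
  for (k = tid * 2; k + 1 < K; k += kThreads * 2) {
    acc += w_row[k] * x[k];
    acc += w_row[k + 1] * x[k + 1];
  }
  // Remainder loop: when K is odd, exactly one thread arrives here with
  // k == K - 1; otherwise k >= K already and the loop body never runs.
  for (; k < K; k += kThreads) {
    acc += w_row[k] * x[k];
  }
  return acc;
}

// What the block-level reductions produce for one row: the sum of all partials.
float block_dot(const std::vector<float>& w_row, const std::vector<float>& x) {
  float total = 0.f;
  for (unsigned tid = 0; tid < kThreads; ++tid) {
    total += thread_partial(w_row, x, tid);
  }
  return total;
}
```

With all-ones inputs, block_dot returns exactly K for any K, odd or even, which confirms that every index is visited exactly once.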

Expected behavior

It should be at least as fast as my naive kernel.
But running the above, I get:

Allocating memory...
Setting memory...
Running...
mui: 230us/loop
rocblas: 386us/loop
hipblas: 386us/loop
DONE

Environment

Hardware:
CPU: AMD Ryzen 7 5800X3D 8-Core Processor
GPU: AMD Instinct MI100
Software versions:
rocm-core: v6.0.2.60002-115~22.04
rocblas: v4.0.0.60002-115~22.04

environment.txt

Additional context


EDIT: replaced the originally included kernel with a better one.

@IMbackK

IMbackK commented May 7, 2024

Yeah, this has been an issue for a while: #1238

@Epliz
Author

Epliz commented May 7, 2024

I updated the kernel in my reproducer; it saturates memory bandwidth (unlike rocBLAS).
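To put "saturates memory bandwidth" in numbers, the effective bandwidth of one reproducer loop can be estimated from the reported timings. This is a back-of-the-envelope sketch that assumes the two 4096 x 14336 fp16 weight matrices (w_up and w_down) are each read once per loop and ignores the much smaller x/y vectors; the helper names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Back-of-the-envelope bandwidth estimate for one reproducer loop (illustration only).
// Assumption: traffic is dominated by reading w_up and w_down once each;
// the x_small / x_big vectors (a few tens of KB) are ignored.
constexpr std::uint64_t kInFeatures = 4096;
constexpr std::uint64_t kOutFeatures = 14336;
constexpr std::uint64_t kBytesPerHalf = 2;

// Bytes of fp16 weights read per loop: two in_features x out_features matrices.
constexpr std::uint64_t weight_bytes_per_loop() {
  return 2 * kInFeatures * kOutFeatures * kBytesPerHalf;
}

// Effective bandwidth in GB/s (1 GB = 1e9 bytes) for a time per loop in microseconds.
inline double effective_gbps(double us_per_loop) {
  assert(us_per_loop > 0.0);
  // bytes / (us * 1e-6 s) / 1e9 == bytes / (us * 1e3)
  return static_cast<double>(weight_bytes_per_loop()) / (us_per_loop * 1e3);
}
```

With the timings from the original post, this gives about 1021 GB/s for the custom kernel (230 us/loop) and about 608 GB/s for rocBLAS (386 us/loop). MI100's peak HBM2 bandwidth is about 1.2 TB/s, so the custom kernel runs above 80% of peak while rocBLAS reaches roughly half.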

@Epliz
Author

Epliz commented May 7, 2024

I see that @daineAMD replied to the other issue, so I am mentioning it here as well in case that helps.
To give context again: rocblas_gemm_ex calls that are effectively gemv operations are a very common pattern in LLM inference at batch size 1, which gets benchmarked quite often, so improving them matters.
Given that a ~100-line kernel beats rocBLAS by 2x, I would recommend putting some effort into this. At least for the matrix shapes of popular LLMs, you could make sure it gets decent performance.

@IMbackK

IMbackK commented May 7, 2024

It's also pretty silly, since just using the gemv kernels in these cases should be trivial.

The suboptimality is of course also easily shown with rocBLAS's own tool:

rocblas-bench -f gemm_ex -m 1 -n 16192 -k 16192 --transposeA N --transposeB N -r s --compute_type s -i 50

transA,transB,M,N,K,alpha,lda,beta,ldb,ldc,ldd,batch_count,rocblas-Gflops,us
N,N,1,16192,16192,1,128,0,16192,128,128,1, 203.735, 2573.74

rocblas-bench -f gemv -r s -m 16192 -n 16192 --lda 16192 -i 50

transA,M,N,alpha,lda,incx,beta,incy,rocblas-Gflops,rocblas-GB/s,us
N,16192,16192,1,16192,1,0,1, 480.082, 960.224, 1092.3
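The rocblas-bench columns above can be reproduced with simple arithmetic, which makes the gap easier to read in bandwidth terms. The sketch below assumes 2 * rows * cols floating-point operations and rows * cols * 4 bytes of fp32 matrix traffic per call, ignoring vector traffic; gflops and matrix_gbps are hypothetical helper names, not rocblas-bench internals.

```cpp
#include <cassert>
#include <cstdint>

// Recompute rocblas-bench style metrics for a rows x cols fp32 matrix-vector
// product (illustration only).
// Assumptions: one multiply and one add per matrix element, and memory traffic
// dominated by reading the matrix once (4 bytes per fp32 element).
inline double gflops(std::uint64_t rows, std::uint64_t cols, double us) {
  assert(us > 0.0);
  // flops / (us * 1e-6 s) / 1e9 == flops / (us * 1e3)
  return 2.0 * static_cast<double>(rows) * static_cast<double>(cols) / (us * 1e3);
}

inline double matrix_gbps(std::uint64_t rows, std::uint64_t cols, double us) {
  assert(us > 0.0);
  return 4.0 * static_cast<double>(rows) * static_cast<double>(cols) / (us * 1e3);
}
```

For the 16192 x 16192 runs, this gives about 204 GFLOPS and 407 GB/s for gemm_ex (2573.74 us) versus about 480 GFLOPS and 960 GB/s for gemv (1092.3 us), closely matching the reported columns. Note that the gemm_ex row reports lda = 128 (no --lda was passed), so with transA = N its 1 x 16192 A operand is read with a stride of 128 elements, which may matter given the stride discussion later in the thread.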

@IMbackK
Copy link

IMbackK commented May 7, 2024

Also, rocblas_hgemv would be great, since there is an opportunity here to use dual-issue.

@daineAMD daineAMD self-assigned this May 7, 2024
@daineAMD
Contributor

daineAMD commented May 7, 2024

Hi @Epliz, thanks for bringing this up. Yes, the disparity between gemm with m == 1/n == 1 and gemv has been brought up in the past, as noted by @IMbackK. Back when it was originally raised, it wasn't clear whether the best approach would be to redirect the gemm call to gemv (which has source kernels in rocBLAS) or to keep it in gemm (which is handled within the Tensile library), since performance was something of a mixed bag, and handling this on a case-by-case basis seemed infeasible.

Regardless, it's good that this has been brought up again, and I'll discuss with the team on what the best approach is. If we can get gemv to outperform gemm in every case, then the changes to redirect to gemv would be straightforward, but most of the work would lie in ensuring that gemv is faster. I'll keep you updated with any progress here.

The request for rocblas_hgemv() has also been noted and I can discuss with the team about whether or not we plan on supporting this.

Thanks,
Daine

@IMbackK

IMbackK commented May 8, 2024

Hi @daineAMD

Thank you for the detailed comment on this matter and for:
The request for rocblas_hgemv() has also been noted and I can discuss with the team about whether or not we plan on supporting this.

Out of curiosity:
In initial experiments with rocblas-bench, I have been unable to find a configuration where gemm_ex beats gemv on gfx906, gfx908, or gfx1030. If you have notes on which configurations these could be, they would be interesting to me from a performance-optimization perspective in my code.

@Epliz
Author

Epliz commented May 14, 2024

Hi @daineAMD ,

Following up after a week.
Do you have any example of a configuration where gemv is slower than gemm?

If not, can you please proceed with making gemm call gemv for those cases?

If the rocBLAS team cannot tackle this task, could a pull request from my side potentially be merged? I can sign whatever contribution agreement you might need.

@daineAMD
Contributor

Hi @Epliz and @IMbackK, sorry for the delay.

Looking at my past notes, it looks like the areas of most concern were where the incx parameter is large (with various exceptions), specifically gemm cases where (transA == transB == T && ldb >> 1) and (transA == transB == N && lda >> 1).
For example, the following gemm and gemv calls are essentially the same operation:
./rocblas-bench -f gemm -r f32_r --transposeA N --transposeB N -m 1 -n 2048 -k 2048 --alpha 1 --lda 2048 --beta 0 --ldb 2048 --ldc 1
and
./rocblas-bench -f gemv -r f32_r --transposeA T -m 2048 -n 2048 --lda 2048 --incx 2048
Note the large incx here, which corresponds to the lda in the gemm call. You can try this out yourself, but on MI100 I'm getting better performance with gemm than with gemv here.
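The equivalence in this example can be written down as a parameter mapping. For a column-major gemm with transA = transB = N and m == 1, d(0, j) = sum over p of A(0, p) * B(p, j), which is a transposed gemv on B with x = A read at stride lda. The sketch below is a hypothetical model: GemmM1Args, GemvTArgs, gemm_m1_to_gemv and the ref_* helpers are illustrative names, not rocBLAS types or functions, and ldd stands for the output stride (ldc for plain gemm).

```cpp
#include <cassert>
#include <vector>

// Hypothetical argument bundles (NOT rocBLAS types); column-major storage.
struct GemmM1Args {  // transA = transB = N, m == 1: D(1 x n) = A(1 x k) * B(k x n)
  int n, k;
  int lda;  // stride between consecutive elements of the 1 x k operand A
  int ldb;  // leading dimension of B (must be >= k)
  int ldd;  // stride between consecutive elements of the 1 x n output
};

struct GemvTArgs {  // transA = T: y(n) = A(m x n)^T * x(m)
  int m, n;
  int lda;
  int incx, incy;
};

// The gemm's B becomes the gemv matrix, and the gemm's lda becomes the gemv's incx.
GemvTArgs gemm_m1_to_gemv(const GemmM1Args& g) {
  assert(g.ldb >= g.k);
  return GemvTArgs{g.k, g.n, g.ldb, g.lda, g.ldd};
}

// Naive column-major gemm reference for m == 1 (alpha = 1, beta = 0).
void ref_gemm_m1(const GemmM1Args& g, const std::vector<float>& a,
                 const std::vector<float>& b, std::vector<float>& d) {
  for (int j = 0; j < g.n; ++j) {
    float acc = 0.f;
    for (int p = 0; p < g.k; ++p) acc += a[p * g.lda] * b[p + j * g.ldb];
    d[j * g.ldd] = acc;
  }
}

// Naive column-major gemv reference with transA = T (alpha = 1, beta = 0).
void ref_gemv_t(const GemvTArgs& v, const std::vector<float>& A,
                const std::vector<float>& x, std::vector<float>& y) {
  for (int j = 0; j < v.n; ++j) {
    float acc = 0.f;
    for (int i = 0; i < v.m; ++i) acc += A[i + j * v.lda] * x[i * v.incx];
    y[j * v.incy] = acc;
  }
}
```

With the first example above (n = k = 2048, lda = ldb = 2048, output stride 1) this yields a gemv with m = n = 2048, lda = 2048 and incx = 2048, matching the gemv command. The reproducer's call in the original post (lda = 1, ldb = K) maps to a unit-stride gemv with incx = 1.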

Another case where I'm seeing gemm perform better than gemv is small sizes, e.g.:
./rocblas-bench -f gemm -r f32_r --transposeA N --transposeB N -m 1 -n 1024 -k 1024 --alpha 1 --lda 1 --beta 0 --ldb 1024 --ldc 1
and
./rocblas-bench -f gemv -r f32_r --transposeA T -m 1024 -n 1024 --lda 1024 --incx 1

I have a ticket to investigate further to see if we can call gemv from cases where it outperforms gemm and/or see what optimizations can be done for the current gemv to make this easier; I'll be looking at this in the coming weeks.

You are free to take a look yourself and open a PR (see the contributing guide if you're interested), but merging it will still take some time, as most of the work lies in ensuring there are no performance regressions.

Thanks again,
Daine
