Skip to content

Commit

Permalink
Merge pull request tensorflow#116 from ROCmSoftwarePlatform/rocblas-h…
Browse files Browse the repository at this point in the history
…gemm

Add rocBLAS fp16 GEMM take2
  • Loading branch information
whchung committed Aug 10, 2018
2 parents 2a25993 + 26f945b commit e5defdf
Showing 1 changed file with 42 additions and 2 deletions.
44 changes: 42 additions & 2 deletions tensorflow/stream_executor/rocm/rocm_blas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ limitations under the License.

#include "tensorflow/stream_executor/rocm/rocm_blas.h"

#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

#include <assert.h>
#include <complex>

Expand Down Expand Up @@ -189,6 +192,7 @@ namespace wrap {
__macro(rocblas_zhpr2) */ \
__macro(rocblas_sgemm) \
__macro(rocblas_dgemm) \
__macro(rocblas_hgemm) \
/* __macro(rocblas_cgemm) \
__macro(rocblas_zgemm) \
__macro(rocblas_ssyrk) \
Expand Down Expand Up @@ -1438,8 +1442,44 @@ bool ROCMBlas::DoBlasGemm(
float alpha, const DeviceMemory<Eigen::half> &a, int lda,
const DeviceMemory<Eigen::half> &b, int ldb, float beta,
DeviceMemory<Eigen::half> *c, int ldc) {
LOG(ERROR) << "fp16 sgemm is not implemented in this rocBLAS version";
return false;
VLOG(1) << port::Printf(
"doing rocBLAS SGEMM: at=%d bt=%d m=%llu n=%llu "
"k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
"c=%p ldc=%d",
static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc);
if (transa == blas::Transpose::kNoTranspose) {
if (lda < static_cast<int64>(m)) {
LOG(WARNING) << "GEMM lda was smaller than m (no transpose case); "
"precondition violation";
}
} else {
if (lda < static_cast<int64>(k)) {
LOG(WARNING) << "GEMM lda (" << lda << ") was smaller than k (" << k
<< ") (transpose case); precondition violation";
}
}
if (transb == blas::Transpose::kNoTranspose) {
if (ldb < static_cast<int64>(k)) {
LOG(WARNING) << "GEMM ldb (" << ldb << ") was smaller than k (" << k
<< ") (no transpose case); precondition violation";
}
} else {
if (ldb < static_cast<int64>(n)) {
LOG(WARNING) << "GEMM ldb was smaller than n (transpose case); "
"precondition violation";
}
}
const Eigen::half alpha_half(alpha);
const Eigen::half beta_half(beta);
return DoBlasInternal(
wrap::rocblas_hgemm, stream, true /* = pointer_mode_host */,
ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
reinterpret_cast<const rocblas_half*>(&alpha_half),
reinterpret_cast<const rocblas_half*>(ROCMMemory(a)), lda,
reinterpret_cast<const rocblas_half*>(ROCMMemory(b)), ldb,
reinterpret_cast<const rocblas_half*>(&beta_half),
reinterpret_cast<rocblas_half*>(ROCMMemoryMutable(c)), ldc);
}

bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
Expand Down

0 comments on commit e5defdf

Please sign in to comment.