Skip to content

Commit

Permalink
Updated libxsmm kernels (#10112)
Browse files Browse the repository at this point in the history
* Fixed AVX-512 intrinsic layer (sparse_matmul_op.h). Incorporated LIBXSMM_DNN_CONV_OPTION_OVERWRITE. (#26)

* Fixed AVX-512 intrinsic implementation.

* OR'ed LIBXSMM_DNN_CONV_OPTION_OVERWRITE into convolution options, which folds zeroing the input buffer on first use. This removes the call to libxsmm_dnn_zero_buffer in case of LIBXSMM_DNN_COMPUTE_KIND_FWD.

* Made xsmm_conv2d.cc up-to-date with TF/master, avoid double-free in case of LIBXSMM_DNN_WARN_FALLBACK, use libxsmm_hash instead of std::hash, code cleanup (#27)

* Fixed AVX-512 intrinsic implementation.

* OR'ed LIBXSMM_DNN_CONV_OPTION_OVERWRITE into convolution options, which folds zeroing the input buffer on first use. This removes the call to libxsmm_dnn_zero_buffer in case of LIBXSMM_DNN_COMPUTE_KIND_FWD.

* Rely on libxsmm_hash rather than std::hash. Brought xsmm_conv2d.cc up-to-date with TF/master.

* Code cleanup: use LIBXSMM_DNN_CONV_OPTION_WU_EXT_FILTER_REDUCE_OVERWRITE rather than assembling the option from separate flags.

* Avoid destroying the handle in case of LIBXSMM_DNN_WARN_FALLBACK, since the next iteration may double-delete the same handle. One would need to update the handle-cache to allow destruction at this place. However, all handles are destroyed when TF terminates (cache cleanup).

* Configure LIBXSMM with default arguments (#28)

* Fixed AVX-512 intrinsic implementation.

* OR'ed LIBXSMM_DNN_CONV_OPTION_OVERWRITE into convolution options, which folds zeroing the input buffer on first use. This removes the call to libxsmm_dnn_zero_buffer in case of LIBXSMM_DNN_COMPUTE_KIND_FWD.

* Rely on libxsmm_hash rather than std::hash. Brought xsmm_conv2d.cc up-to-date with TF/master.

* Code cleanup: use LIBXSMM_DNN_CONV_OPTION_WU_EXT_FILTER_REDUCE_OVERWRITE rather than assembling the option from separate flags.

* Avoid destroying the handle in case of LIBXSMM_DNN_WARN_FALLBACK, since the next iteration may double-delete the same handle. One would need to update the handle-cache to allow destruction at this place. However, all handles are destroyed when TF terminates (cache cleanup).

* Rely on default configuration arguments, and thereby lower the dependence on LIBXSMM internals.
  • Loading branch information
benoitsteiner committed May 22, 2017
1 parent 88a8bb8 commit 23caaa5
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 70 deletions.
2 changes: 1 addition & 1 deletion tensorflow/core/kernels/conv_grad_input_ops.cc
Expand Up @@ -176,7 +176,7 @@ struct LaunchXsmmBackwardInputConvolution<CPUDevice, float> {
desc.filter_format =
LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM; // LIBXSMM_DNN_TENSOR_FORMAT_RSCK;
desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
desc.options = LIBXSMM_DNN_CONV_OPTION_WU_EXT_FILTER_REDUCE;
desc.options = LIBXSMM_DNN_CONV_OPTION_WU_EXT_FILTER_REDUCE_OVERWRITE;
desc.datatype = LIBXSMM_DNN_DATATYPE_F32;

auto input_ptr = input_backward.data();
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/core/kernels/conv_ops.cc
Expand Up @@ -228,7 +228,7 @@ class LaunchXsmmConvOp<CPUDevice, float> {
desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC;
desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;
desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
desc.options = LIBXSMM_DNN_CONV_OPTION_WU_EXT_FILTER_REDUCE;
desc.options = LIBXSMM_DNN_CONV_OPTION_WU_EXT_FILTER_REDUCE_OVERWRITE;
desc.datatype = LIBXSMM_DNN_DATATYPE_F32;

if (!CanUseXsmmConv2D(desc, data_format)) {
Expand Down
56 changes: 31 additions & 25 deletions tensorflow/core/kernels/sparse_matmul_op.h
Expand Up @@ -31,11 +31,11 @@ namespace internal {
// in the lower 16-bits of input
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pexpand_bf16_l(const Packet& from) {
tensorflow::uint32 tmp;
tensorflow::uint32 tmp;
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
tmp = (reinterpret_cast<const tensorflow::uint32&>(from) ) & 0xffff0000;
#else
tmp = (reinterpret_cast<const tensorflow::uint32&>(from) << 16) & 0xffff0000;
tmp = (reinterpret_cast<const tensorflow::uint32&>(from)) & 0xffff0000;
#else
tmp = (reinterpret_cast<const tensorflow::uint32&>(from) << 16) & 0xffff0000;
#endif
return reinterpret_cast<const float&>(tmp);
}
Expand All @@ -44,12 +44,12 @@ EIGEN_DEVICE_FUNC inline Packet pexpand_bf16_l(const Packet& from) {
// in the upper 16-bits of input
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pexpand_bf16_u(const Packet& from) {
tensorflow::uint32 tmp;
tensorflow::uint32 tmp;
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
tmp = (reinterpret_cast<const tensorflow::uint32&>(from) << 16 ) & 0xffff0000;
tmp = (reinterpret_cast<const tensorflow::uint32&>(from) << 16) & 0xffff0000;
#else
tmp = (reinterpret_cast<const tensorflow::uint32&>(from)) & 0xffff0000;
#endif
tmp = (reinterpret_cast<const tensorflow::uint32&>(from)) & 0xffff0000;
#endif
return reinterpret_cast<const float&>(tmp);
}

Expand All @@ -61,25 +61,25 @@ EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_l(const Packet4f& from) {
float r[4];
tensorflow::uint32 p[4];
pstoreu(r, from);
tensorflow::uint32 * ir = reinterpret_cast<tensorflow::uint32 *>(r);
tensorflow::uint32* ir = reinterpret_cast<tensorflow::uint32*>(r);
p[0] = (ir[0] << 16) & 0xffff0000;
p[1] = ir[0]& 0xffff0000;
p[1] = ir[0] & 0xffff0000;
p[2] = (ir[1] << 16) & 0xffff0000;
p[3] = ir[1] & 0xffff0000;
return ploadu<Packet4f>(reinterpret_cast<float *>(p));
return ploadu<Packet4f>(reinterpret_cast<float*>(p));
}

template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_u(const Packet4f& from) {
float r[4];
tensorflow::uint32 p[4];
pstoreu(r, from);
tensorflow::uint32 * ir = reinterpret_cast<tensorflow::uint32 *>(r);
tensorflow::uint32* ir = reinterpret_cast<tensorflow::uint32*>(r);
p[0] = (ir[2] << 16) & 0xffff0000;
p[1] = ir[2] & 0xffff0000;
p[2] = (ir[3] << 16) & 0xffff0000;
p[3] = ir[3] & 0xffff0000;
return ploadu<Packet4f>(reinterpret_cast<float *>(p));
return ploadu<Packet4f>(reinterpret_cast<float*>(p));
}
#endif

Expand Down Expand Up @@ -131,23 +131,25 @@ EIGEN_DEVICE_FUNC inline Packet pload2bf16(
template <>
EIGEN_STRONG_INLINE Packet4f pload4bf16<Packet4f>(const float* from) {
tensorflow::uint32 p[4];
const tensorflow::uint32* ir = reinterpret_cast<const tensorflow::uint32 *>(from);
const tensorflow::uint32* ir =
reinterpret_cast<const tensorflow::uint32*>(from);
p[0] = (ir[0] << 16) & 0xffff0000;
p[1] = ir[0]& 0xffff0000;
p[1] = ir[0] & 0xffff0000;
p[2] = (ir[1] << 16) & 0xffff0000;
p[3] = ir[1] & 0xffff0000;
return ploadu<Packet4f>(reinterpret_cast<float *>(p));
return ploadu<Packet4f>(reinterpret_cast<float*>(p));
}

template <>
EIGEN_STRONG_INLINE Packet4f pload2bf16<Packet4f>(const float* from) {
tensorflow::uint32 p[4];
const tensorflow::uint32* ir = reinterpret_cast<const tensorflow::uint32 *>(from);
const tensorflow::uint32* ir =
reinterpret_cast<const tensorflow::uint32*>(from);
p[0] = (ir[0] << 16) & 0xffff0000;
p[1] = ir[0]& 0xffff0000;
p[1] = ir[0] & 0xffff0000;
p[2] = (ir[0] << 16) & 0xffff0000;
p[3] = ir[0] & 0xffff0000;
return ploadu<Packet4f>(reinterpret_cast<float *>(p));
return ploadu<Packet4f>(reinterpret_cast<float*>(p));
}
#endif

Expand Down Expand Up @@ -255,12 +257,13 @@ EIGEN_STRONG_INLINE Packet8d pbroadcast_second<Packet8d>(const Packet8d& a_in) {
}
template <>
EIGEN_STRONG_INLINE Packet8d pbroadcast_third<Packet8d>(const Packet8d& a_in) {
Packet2d a = _mm512_extractf32x4_ps(a_in, 1);
Packet2d a = _mm256_extractf128_pd(_mm512_castpd512_pd256(a_in), 1);
return _mm512_broadcastsd_pd(a);
}
template <>
EIGEN_STRONG_INLINE Packet8d pbroadcast_fourth<Packet8d>(const Packet8d& a_in) {
Packet2d a = _mm_permute_pd(_mm512_extractf32x4_ps(a_in, 1), 3);
Packet2d a =
_mm_permute_pd(_mm256_extractf128_pd(_mm512_castpd512_pd256(a_in), 1), 3);
return _mm512_broadcastsd_pd(a);
}
template <>
Expand Down Expand Up @@ -417,14 +420,17 @@ EIGEN_STRONG_INLINE Packet8f pbroadcast_fourth<Packet8f>(const Packet8f& a) {

template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet16f pexpand_bf16_l(const Packet16f& from) {
return _mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm512_castsi512_si256(from)),
16);
return _mm512_castsi512_ps(_mm512_slli_epi32(
_mm512_cvtepu16_epi32(_mm512_castsi512_si256(_mm512_castps_si512(from))),
16));
}

template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet16f pexpand_bf16_u(const Packet16f& from) {
return _mm512_slli_epi32(
_mm512_cvtepu16_epi32(_mm512_extractf64x4_pd(from, 1)), 16);
Packet16i tmp = _mm512_castps_si512(from);
Packet16i tmp2 = _mm512_alignr_epi32(tmp, tmp, 8);
return _mm512_castsi512_ps(_mm512_slli_epi32(
_mm512_cvtepu16_epi32(_mm512_castsi512_si256(tmp2)), 16));
}

#endif
Expand Down
31 changes: 1 addition & 30 deletions tensorflow/core/kernels/xsmm_conv2d.cc
Expand Up @@ -131,32 +131,7 @@ class libxsmm_dnn_conv_desc_wrap {

struct HashFunction {
std::size_t operator()(const libxsmm_dnn_conv_desc_wrap& w) const {
// unsigned char ptr[sizeof(&w.d)];

// memcpy(ptr, (unsigned char *)&w.d, sizeof(&w.d))

//
/*
std::ostringstream N,C,H,W,K,R,S,u,v,padh,padw;
N << w.d.N; C << w.d.C;
H << w.d.H; W << w.d.W;
K << w.d.K; R << w.d.R;
S << w.d.S; u << w.d.u;
v << w.d.v; padh << w.d.pad_h_in;
padw << w.d.pad_w_in;
std::string out_ = N.str() + C.str()\
+ H.str() + W.str()\
+ K.str() + R.str()\
+ S.str() + u.str()\
+ v.str() + padh.str()\
+ padw.str();
//
//
*/
return (std::hash<unsigned long long>()((unsigned long long)&(w.d)));
return libxsmm_hash(&w.d, sizeof(w.d), 25071975);
}
};

Expand Down Expand Up @@ -221,8 +196,6 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,

status = libxsmm_dnn_get_codegen_success(libxsmm_handle, kind);
if (status == LIBXSMM_DNN_WARN_FALLBACK) {
chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(libxsmm_handle),
"Destroy handle");
return false; // Use non-libxsmm code
}
chk_libxsmm_err(status, "Check codegen status");
Expand Down Expand Up @@ -324,8 +297,6 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
chk_libxsmm_err(status, "Link filter");
}
if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) {
chk_libxsmm_err(libxsmm_dnn_zero_buffer(libxsmm_output), "Zero output");

chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input,
LIBXSMM_DNN_REGULAR_INPUT),
"Bind input forward");
Expand Down
15 changes: 2 additions & 13 deletions third_party/libxsmm.BUILD
Expand Up @@ -11,19 +11,8 @@ exports_files(["LICENSE"])
libxsmm_interface_arguments = "0 1"

# Arguments to ./scripts/libxsmm_config.py, see that file for detailed description.
# ilp64: no
# big: no
# offload: no
# alignment [b]
# prefetch: 1 (auto)
# threshold: fallback to BLAS if n*m*k above this
# synchronize: yes
# jit: yes
# flags
# alpha = 1
# beta = 1
# gemm = 2
libxsmm_config_arguments = "0 0 0 64 1 0 1 1 0 1 1 2"
# rely on default arguments
libxsmm_config_arguments = ""

# Arguments to ./scripts/libxsmm_dispatch.py, see that file for detailed description.
# (dummy argument)
Expand Down

0 comments on commit 23caaa5

Please sign in to comment.