Skip to content

Commit

Permalink
Add missing declarations for explicit instantiations in concat_lib an…
Browse files Browse the repository at this point in the history
…d split_lib, and add new headers concat_lib_gpu.h and split_lib_gpu.h to contain them (and the declaration of the primary templates).

The current behaviour (using externally defined instantiations without having seen a declaration of those external instantiations) is undesirable and effectively deprecated, and is warned about by -Wundefined-func-template.

PiperOrigin-RevId: 237158146
  • Loading branch information
tensorflower-gardener committed Mar 7, 2019
1 parent 58e052b commit 87cd62e
Show file tree
Hide file tree
Showing 9 changed files with 212 additions and 78 deletions.
6 changes: 3 additions & 3 deletions tensorflow/core/kernels/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ tf_kernel_library(
gpu_srcs = [
"concat_lib_gpu_impl.cu.cc",
"concat_lib.h",
"concat_lib_gpu.h",
"cuda_device_array.h",
"cuda_device_array_gpu.h",
],
Expand Down Expand Up @@ -607,6 +608,7 @@ tf_kernel_library(
gpu_srcs = [
"split_lib_gpu.cu.cc",
"split_lib.h",
"split_lib_gpu.h",
],
deps = [
":cuda_device_array",
Expand All @@ -618,9 +620,7 @@ tf_kernel_library(

cc_library(
name = "split_lib_hdrs",
hdrs = [
"split_lib.h",
],
hdrs = ["split_lib.h"],
deps = [
"//tensorflow/core:framework_lite",
"//third_party/eigen3",
Expand Down
18 changes: 18 additions & 0 deletions tensorflow/core/kernels/concat_lib.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,24 @@ void ConcatGPU(
inputs_flat,
Tensor* output, typename TTypes<T, 2>::Tensor* output_flat);

// Explicit instantiations in concat_lib_gpu.cc.
#define REGISTER(T) \
extern template void ConcatGPU<T>( \
OpKernelContext * c, \
const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& \
inputs_flat, \
Tensor* output, typename TTypes<T, 2>::Tensor* output_flat);

TF_CALL_GPU_NUMBER_TYPES(REGISTER);
TF_CALL_complex64(REGISTER);
TF_CALL_complex128(REGISTER);
TF_CALL_int32(REGISTER); // Needed for TensorLists.
TF_CALL_int64(REGISTER);
TF_CALL_int16(REGISTER);
TF_CALL_bfloat16(REGISTER);
TF_CALL_bool(REGISTER);
TF_CALL_uint8(REGISTER);
#undef REGISTER
#endif // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
Expand Down
16 changes: 1 addition & 15 deletions tensorflow/core/kernels/concat_lib_gpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,24 +26,10 @@ limitations under the License.

#if GOOGLE_CUDA

#include "tensorflow/core/kernels/concat_lib_gpu.h"
#include "tensorflow/core/kernels/cuda_device_array.h"

namespace tensorflow {

template <typename T, typename IntType>
void ConcatGPUSlice(
const Eigen::GpuDevice& gpu_device,
const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&
inputs_flat,
typename TTypes<T, 2>::Matrix* output);

template <typename T, typename IntType>
void ConcatGPUImpl(const Eigen::GpuDevice& d,
const CudaDeviceArrayStruct<const T*>& input_ptrs,
const CudaDeviceArrayStruct<IntType>& ptr_offsets,
bool same_size, int slice_size,
typename TTypes<T, 2>::Matrix* output);

namespace {

template <typename T, typename IntType>
Expand Down
82 changes: 82 additions & 0 deletions tensorflow/core/kernels/concat_lib_gpu.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_CONCAT_LIB_GPU_H_
#define TENSORFLOW_CORE_KERNELS_CONCAT_LIB_GPU_H_

#define EIGEN_USE_THREADS
#define EIGEN_USE_GPU

#include <memory>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/kernels/cuda_device_array_gpu.h"

namespace tensorflow {

template <typename T, typename IntType>
void ConcatGPUSlice(
const Eigen::GpuDevice& gpu_device,
const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&
inputs_flat,
typename TTypes<T, 2>::Matrix* output);

template <typename T, typename IntType>
void ConcatGPUImpl(const Eigen::GpuDevice& d,
const CudaDeviceArrayStruct<const T*>& input_ptrs,
const CudaDeviceArrayStruct<IntType>& ptr_offsets,
bool same_size, int slice_size,
typename TTypes<T, 2>::Matrix* output);

// Explicit instantiations in concat_lib_gpu_impl.cu.cc.
#define REGISTER(T) \
extern template void ConcatGPUSlice<T, int32>( \
const Eigen::GpuDevice& gpu_device, \
const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& \
inputs_flat, \
typename TTypes<T, 2>::Matrix* output); \
extern template void ConcatGPUSlice<T, int64>( \
const Eigen::GpuDevice& gpu_device, \
const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& \
inputs_flat, \
typename TTypes<T, 2>::Matrix* output); \
extern template void ConcatGPUImpl<T, int32>( \
const Eigen::GpuDevice& d, \
const CudaDeviceArrayStruct<const T*>& input_ptrs, \
const CudaDeviceArrayStruct<int32>& ptr_offsets, bool fixed_size, \
int split_size, typename TTypes<T, 2>::Matrix* output); \
extern template void ConcatGPUImpl<T, int64>( \
const Eigen::GpuDevice& d, \
const CudaDeviceArrayStruct<const T*>& input_ptrs, \
const CudaDeviceArrayStruct<int64>& ptr_offsets, bool fixed_size, \
int split_size, typename TTypes<T, 2>::Matrix* output);

TF_CALL_GPU_NUMBER_TYPES(REGISTER);
TF_CALL_complex64(REGISTER);
TF_CALL_complex128(REGISTER);
TF_CALL_int32(REGISTER); // Needed for TensorLists.
TF_CALL_int64(REGISTER);
TF_CALL_int16(REGISTER);
TF_CALL_bfloat16(REGISTER);
TF_CALL_bool(REGISTER);
TF_CALL_uint8(REGISTER);
#undef REGISTER

} // namespace tensorflow

#endif // TENSORFLOW_CORE_KERNELS_CONCAT_LIB_GPU_H_
1 change: 1 addition & 0 deletions tensorflow/core/kernels/concat_lib_gpu_impl.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/bfloat16.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/concat_lib_gpu.h"
#include "tensorflow/core/kernels/cuda_device_array_gpu.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

Expand Down
89 changes: 44 additions & 45 deletions tensorflow/core/kernels/split_lib_gpu.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/cuda_device_array_gpu.h"
#include "tensorflow/core/kernels/split_lib_gpu.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace tensorflow {
Expand Down Expand Up @@ -192,54 +193,52 @@ __global__ void SplitVOpKernel_fixed(
}

template <typename T>
struct SplitOpGPULaunch {
void Run(const Eigen::GpuDevice& d, const T* input, int32 prefix_dim_size,
int32 split_dim_size, int32 suffix_dim_size,
const CudaDeviceArrayStruct<T*>& output_ptr_data) {
CudaLaunchConfig config = GetCudaLaunchConfig(
prefix_dim_size * split_dim_size * suffix_dim_size, d);

TF_CHECK_OK(CudaLaunchKernel(SplitOpKernel<T>, config.block_count,
config.thread_per_block, 0, d.stream(), input,
prefix_dim_size, split_dim_size,
suffix_dim_size, output_ptr_data));
}
};
void SplitOpGPULaunch<T>::Run(
const Eigen::GpuDevice& d, const T* input, int32 prefix_dim_size,
int32 split_dim_size, int32 suffix_dim_size,
const CudaDeviceArrayStruct<T*>& output_ptr_data) {
CudaLaunchConfig config = GetCudaLaunchConfig(
prefix_dim_size * split_dim_size * suffix_dim_size, d);

TF_CHECK_OK(CudaLaunchKernel(SplitOpKernel<T>, config.block_count,
config.thread_per_block, 0, d.stream(), input,
prefix_dim_size, split_dim_size, suffix_dim_size,
output_ptr_data));
}

template <typename T, typename IntType>
struct SplitVOpGPULaunch {
void Run(const Eigen::GpuDevice& gpu_device, bool fixed_size,
const T* input_ptr, int total_rows, int total_cols,
const CudaDeviceArrayStruct<IntType>& output_scan,
const CudaDeviceArrayStruct<T*>& output_ptr_data) {
if (fixed_size) {
CudaLaunchConfig config =
GetCudaLaunchConfig(total_rows * total_cols, gpu_device);

SplitVOpKernel_fixed<T><<<config.block_count, config.thread_per_block, 0,
gpu_device.stream()>>>(
input_ptr, total_rows, total_cols, output_ptr_data);
} else {
auto config = GetCuda2DLaunchConfig(total_cols, total_rows, gpu_device);
IntType smem_max = gpu_device.sharedMemPerBlock();
IntType smem_usage = output_scan.size * sizeof(IntType);
// performance crossover is less than using maximum available shared
// memory on most processors possibly due to decreasing occupancy
// 4096 inputs is a lot, most code will take the smem path
const int32 kMaxSmemBytesPerformance = 16384;
if (smem_usage < smem_max && smem_usage < kMaxSmemBytesPerformance)
split_v_kernel<T, IntType, true>
<<<config.block_count, config.thread_per_block, smem_usage,
gpu_device.stream()>>>(input_ptr, output_scan, total_rows,
total_cols, output_ptr_data);
else
split_v_kernel<T, IntType, false>
<<<config.block_count, config.thread_per_block, 0,
gpu_device.stream()>>>(input_ptr, output_scan, total_rows,
total_cols, output_ptr_data);
}
void SplitVOpGPULaunch<T, IntType>::Run(
const Eigen::GpuDevice& gpu_device, bool fixed_size, const T* input_ptr,
int total_rows, int total_cols,
const CudaDeviceArrayStruct<IntType>& output_scan,
const CudaDeviceArrayStruct<T*>& output_ptr_data) {
if (fixed_size) {
CudaLaunchConfig config =
GetCudaLaunchConfig(total_rows * total_cols, gpu_device);

SplitVOpKernel_fixed<T><<<config.block_count, config.thread_per_block, 0,
gpu_device.stream()>>>(
input_ptr, total_rows, total_cols, output_ptr_data);
} else {
auto config = GetCuda2DLaunchConfig(total_cols, total_rows, gpu_device);
IntType smem_max = gpu_device.sharedMemPerBlock();
IntType smem_usage = output_scan.size * sizeof(IntType);
// performance crossover is less than using maximum available shared
// memory on most processors possibly due to decreasing occupancy
// 4096 inputs is a lot, most code will take the smem path
const int32 kMaxSmemBytesPerformance = 16384;
if (smem_usage < smem_max && smem_usage < kMaxSmemBytesPerformance)
split_v_kernel<T, IntType, true>
<<<config.block_count, config.thread_per_block, smem_usage,
gpu_device.stream()>>>(input_ptr, output_scan, total_rows,
total_cols, output_ptr_data);
else
split_v_kernel<T, IntType, false>
<<<config.block_count, config.thread_per_block, 0,
gpu_device.stream()>>>(input_ptr, output_scan, total_rows,
total_cols, output_ptr_data);
}
};
}

#define REGISTER_GPU_KERNEL(T) template struct SplitOpGPULaunch<T>;

Expand Down
61 changes: 61 additions & 0 deletions tensorflow/core/kernels/split_lib_gpu.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_SPLIT_LIB_GPU_H_
#define TENSORFLOW_CORE_KERNELS_SPLIT_LIB_GPU_H_

#define EIGEN_USE_THREADS
#define EIGEN_USE_GPU

#include <memory>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/cuda_device_array_gpu.h"
#include "tensorflow/core/kernels/split_lib.h"

namespace tensorflow {

template <typename T>
struct SplitOpGPULaunch {
void Run(const Eigen::GpuDevice& d, const T* input, int32 prefix_dim_size,
int32 split_dim_size, int32 suffix_dim_size,
const CudaDeviceArrayStruct<T*>& output_ptr_data);
};

template <typename T, typename IntType>
struct SplitVOpGPULaunch {
void Run(const Eigen::GpuDevice& d, bool fixed, const T* input,
int total_cols, int total_rows,
const CudaDeviceArrayStruct<IntType>& output_scan,
const CudaDeviceArrayStruct<T*>& output_ptr_data);
};

// Explicit instantiations in split_lib_gpu.cu.cc.
#define REGISTER_GPU_KERNEL(T) \
extern template struct SplitOpGPULaunch<T>; \
extern template struct SplitVOpGPULaunch<T, int32>; \
extern template struct SplitVOpGPULaunch<T, int64>;

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
TF_CALL_complex64(REGISTER_GPU_KERNEL);
TF_CALL_complex128(REGISTER_GPU_KERNEL);
TF_CALL_bfloat16(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

} // namespace tensorflow

#endif // TENSORFLOW_CORE_KERNELS_SPLIT_LIB_GPU_H_
8 changes: 1 addition & 7 deletions tensorflow/core/kernels/split_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ limitations under the License.
#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/kernels/cuda_device_array.h"
#include "tensorflow/core/kernels/split_lib_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA

Expand Down Expand Up @@ -267,13 +268,6 @@ class SplitOpCPU : public SplitOpBase<CPUDevice, T> {

#if GOOGLE_CUDA

template <typename T>
struct SplitOpGPULaunch {
void Run(const Eigen::GpuDevice& d, const T* input, int32 prefix_dim_size,
int32 split_dim_size, int32 suffix_dim_size,
const CudaDeviceArrayStruct<T*>& output_ptr_data);
};

// Partial specialization for GPU
template <typename T>
class SplitOpGPU : public SplitOpBase<GPUDevice, T> {
Expand Down
9 changes: 1 addition & 8 deletions tensorflow/core/kernels/split_v_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ limitations under the License.
#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/kernels/cuda_device_array.h"
#include "tensorflow/core/kernels/split_lib_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA

Expand Down Expand Up @@ -329,14 +330,6 @@ class SplitVOpCPU : public SplitVOpBase<CPUDevice, T, Tlen> {

#if GOOGLE_CUDA

template <typename T, typename IntType>
struct SplitVOpGPULaunch {
void Run(const Eigen::GpuDevice& d, bool fixed, const T* input,
int total_cols, int total_rows,
const CudaDeviceArrayStruct<IntType>& output_scan,
const CudaDeviceArrayStruct<T*>& output_ptr_data);
};

// Partial specialization for GPU
template <typename T, typename Tlen>
class SplitVOpGPU : public SplitVOpBase<GPUDevice, T, Tlen> {
Expand Down

0 comments on commit 87cd62e

Please sign in to comment.