Multi lingual (#20)
* support multi lingual nmt

* fix typo

Co-authored-by: wangxiaohui <wangxiaohui.neo@bytedance.com>
neopro12 committed Dec 24, 2020
1 parent 8070745 commit cd288df
Showing 15 changed files with 840 additions and 153 deletions.
2 changes: 0 additions & 2 deletions docs/build.md
@@ -5,8 +5,6 @@
 - protobuf >= 3.13
 - cmake >= 3.18
 
-There are submodules in this repository which you should clone with `--recurse-submodules`.
-
 To install cudatoolkit-dev, you can run `conda install -c conda-forge cudatoolkit-dev` or follow the [official guide](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#runfile); the runfile installation with the `--toolkit` arg is recommended.
 
 After installation, check the installation of `nvcc` and the static libraries (*.a) in `${CUDA_PATH}/lib64`.
24 changes: 15 additions & 9 deletions example/transformer_example.cc.cu
@@ -10,8 +10,13 @@ Example of how to run transformer inference using our implementation.
 */
 
 // Appoint precision.
-const lightseq::cuda::OperationType optype =
+#ifdef FP16_MODE
+const lightseq::cuda::OperationType OPTYPE =
     lightseq::cuda::OperationType::FP16;
+#else
+const lightseq::cuda::OperationType OPTYPE =
+    lightseq::cuda::OperationType::FP32;
+#endif
 
 int main(int argc, char *argv[]) {
   /* ---step1. init environment--- */
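
An aside on the hunk above: precision is now chosen by a compile-time macro instead of by editing the source. Below is a minimal, self-contained sketch of the same pattern; the stand-in TypeTraits template and its __half mapping are illustrative assumptions, not LightSeq's actual OperationTypeTraits header.

#include <cstdio>
#include <cuda_fp16.h>

enum class OperationType { FP32, FP16 };

// Stand-in for the library's traits: map the enum to a storage type.
template <OperationType OT> struct TypeTraits;
template <> struct TypeTraits<OperationType::FP32> { using DataType = float; };
template <> struct TypeTraits<OperationType::FP16> { using DataType = __half; };

#ifdef FP16_MODE
const OperationType OPTYPE = OperationType::FP16;
#else
const OperationType OPTYPE = OperationType::FP32;
#endif

int main() {
  // prints 2 when built with -DFP16_MODE, 4 otherwise
  std::printf("element size: %zu bytes\n", sizeof(TypeTraits<OPTYPE>::DataType));
  return 0;
}

The idea is that a build flag such as `nvcc -DFP16_MODE` flips every buffer typed through optraits::DataType to half precision without any source change.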
@@ -21,10 +26,10 @@ int main(int argc, char *argv[]) {
   cudaStreamCreate(&stream_);
   cublasCreate(&hd_);
   cublasSetStream(hd_, stream_);
-  typedef lightseq::cuda::OperationTypeTraits<optype> optraits;
+  typedef lightseq::cuda::OperationTypeTraits<OPTYPE> optraits;
 
   /* ---step2. load model weights into GPU memory--- */
-  lightseq::cuda::TransformerWeight<optype> tw_;
+  lightseq::cuda::TransformerWeight<OPTYPE> tw_;
   // saved in custom proto file
   std::string model_weights_path = argv[1];
   std::string res = tw_.initializing(model_weights_path);
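
Worth noting from the hunk above: `initializing()` reports status as a std::string that is empty on success, the same convention the encoder and decoder `check()` calls use below. A tiny sketch of that convention with hypothetical names (not LightSeq's API):

#include <iostream>
#include <string>

// Returns "" on success, a human-readable diagnostic otherwise
// (assumed convention, mirroring the example's error handling).
std::string load_weights(const std::string &path) {
  if (path.empty()) return "model path is empty";
  return "";
}

int main(int argc, char *argv[]) {
  std::string res = load_weights(argc > 1 ? argv[1] : "");
  if (!res.empty()) {
    std::cout << res << std::endl;
    return 1;
  }
  return 0;
}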
@@ -47,8 +52,8 @@ int main(int argc, char *argv[]) {
       std::vector<int>(max_batch_size * tw_._max_step * tw_._hidden_size, 0);
   thrust::device_vector<int> d_output_ =
       std::vector<int>(max_batch_size * tw_._max_step, 0);
-  std::shared_ptr<lightseq::cuda::Encoder<optype>> encoder_ =
-      std::make_shared<lightseq::cuda::Encoder<optype>>(
+  std::shared_ptr<lightseq::cuda::Encoder<OPTYPE>> encoder_ =
+      std::make_shared<lightseq::cuda::Encoder<OPTYPE>>(
           max_batch_size,
           reinterpret_cast<int *>(thrust::raw_pointer_cast(d_input_.data())),
           reinterpret_cast<int *>(
@@ -62,15 +67,16 @@ int main(int argc, char *argv[]) {
     return 1;
   }
   // instantiate decoder
-  std::shared_ptr<lightseq::cuda::Decoder<optype>> decoder_ =
-      std::make_shared<lightseq::cuda::Decoder<optype>>(
+  std::shared_ptr<lightseq::cuda::Decoder<OPTYPE>> decoder_ =
+      std::make_shared<lightseq::cuda::Decoder<OPTYPE>>(
           max_batch_size,
           reinterpret_cast<int *>(
               thrust::raw_pointer_cast(d_padding_mask_.data())),
           reinterpret_cast<optraits::DataType *>(
               thrust::raw_pointer_cast(d_encoder_output_.data())),
           reinterpret_cast<int *>(thrust::raw_pointer_cast(d_output_.data())),
-          tw_, stream_, hd_);
+          tw_, stream_, hd_, false,
+          reinterpret_cast<int *>(thrust::raw_pointer_cast(d_input_.data())));
   res = decoder_->check();
   if (!res.empty()) {
     std::cout << res << std::endl;
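
A note on the two new constructor arguments in the hunk above: judging from the commit title, the boolean presumably toggles multilingual decoding (the example passes `false` to preserve the original monolingual behavior), and the trailing pointer hands the decoder the source-token buffer `d_input_`, presumably so a multilingual model can read the target-language id out of the input. Both readings are inferred from this diff rather than from documented API semantics.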
@@ -104,7 +110,7 @@ int main(int argc, char *argv[]) {
                                            batch_seq_len, host_input);
 
   /* ---step5. infer and log--- */
-  for (int i = 0; i < 10; i++) {
+  for (int i = 0; i < 1; i++) {
     auto start = std::chrono::high_resolution_clock::now();
     // copy inputs from cpu memory to gpu memory
     cudaMemcpyAsync(
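
One caveat on the timing pattern in this hunk: `cudaMemcpyAsync` and the inference call only enqueue work on `stream_`, so the host-side clock is meaningful only if the stream is synchronized before the stop timestamp (the rest of the loop sits outside this hunk). A self-contained sketch of the event-based alternative; the names here are illustrative, not part of the example file:

#include <cstdio>
#include <cuda_runtime.h>

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);

  cudaEventRecord(start, stream);
  // ... enqueue cudaMemcpyAsync + inference kernels on `stream` here ...
  cudaEventRecord(stop, stream);

  cudaEventSynchronize(stop);  // block until the stream reaches `stop`
  float ms = 0.f;
  cudaEventElapsedTime(&ms, start, stop);
  std::printf("elapsed: %.3f ms\n", ms);

  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  cudaStreamDestroy(stream);
  return 0;
}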
2 changes: 1 addition & 1 deletion kernels/CMakeLists.txt
@@ -1,6 +1,6 @@
 cmake_minimum_required(VERSION 3.18)
 
-set(cuda_kernel_files gptKernels.cc.cu transformerKernels.cc.cu)
+set(cuda_kernel_files gptKernels.cc.cu transformerKernels.cc.cu multilgKernels.cc.cu)
 
 add_library(cuda_kernels STATIC ${cuda_kernel_files})
 target_include_directories(cuda_kernels INTERFACE ${CMAKE_CURRENT_SOURCE_DIR})
1 change: 1 addition & 0 deletions kernels/gptKernels.h
@@ -17,6 +17,7 @@ void ker_gpt_embedding_launcher(int batch_size, int batch_seq_len,
                                 int pos_offset);
 
 
+
 template <typename T>
 void ker_correlation_softmax_gpt_launcher(int batch_size, int batch_seq_len,
                                           int head_num, cudaStream_t stream,
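
For context, the declarations in this header follow the usual CUDA launcher pattern: a host-side wrapper templated on the element type that picks a launch configuration and dispatches the kernel on the caller's stream, with explicit instantiations keeping the definitions in the .cc.cu file. A generic sketch of the pattern (the kernel body is illustrative, not LightSeq's embedding or softmax code):

#include <cstdio>
#include <cuda_runtime.h>

template <typename T>
__global__ void ker_scale(T *data, T factor, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] *= factor;
}

// Host wrapper: hides grid/block math from callers and runs on their stream.
template <typename T>
void ker_scale_launcher(int n, cudaStream_t stream, T *data, T factor) {
  const int block = 256;
  const int grid = (n + block - 1) / block;
  ker_scale<T><<<grid, block, 0, stream>>>(data, factor, n);
}

int main() {
  const int n = 1024;
  float *d = nullptr;
  cudaMalloc(&d, n * sizeof(float));
  cudaMemset(d, 0, n * sizeof(float));
  ker_scale_launcher<float>(n, 0, d, 2.f);  // 0 = default stream
  cudaDeviceSynchronize();
  cudaFree(d);
  std::printf("launched ok\n");
  return 0;
}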
