diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml deleted file mode 100644 index 1dd8e6429..000000000 --- a/.gitlab-ci.yml +++ /dev/null @@ -1,256 +0,0 @@ -stages: - - build - - test - -build_pyt_release: - image: nvcr.io/nvidia/pytorch:21.02-py3 - tags: - - fastertransformer - stage: build - only: - - main - - merge_requests - artifacts: - paths: - - ${CI_PROJECT_DIR}/build/ - expire_in: 1 week - script: - - cd ${CI_PROJECT_DIR} && mkdir build && cd build - - git submodule init && git submodule update - - cmake -DSM=xx -DCMAKE_BUILD_TYPE=Release -DBUILD_PYT=ON -DBUILD_GPT=ON .. - - make -j12 - -build_pyt_release_sparse: - image: nvcr.io/nvidia/pytorch:21.02-py3 - tags: - - fastertransformer - stage: build - only: - - main - - merge_requests - artifacts: - paths: - - ${CI_PROJECT_DIR}/build/ - expire_in: 1 week - script: - - cd ${CI_PROJECT_DIR} && mkdir build && cd build - - git submodule init && git submodule update - - wget https://developer.download.nvidia.com/compute/libcusparse-lt/0.1.0/local_installers/libcusparse_lt-linux-x86_64-0.1.0.2.tar.gz - - tar -xzvf libcusparse_lt-linux-x86_64-0.1.0.2.tar.gz - - cmake -DSM=xx -DCMAKE_BUILD_TYPE=Release -DBUILD_PYT=ON -DSPARSITY_SUPPORT=ON -DCUSPARSELT_PATH=${CI_PROJECT_DIR}/build/libcusparse_lt/ .. - - make -j12 - -build_tf_release: - image: nvcr.io/nvidia/tensorflow:21.02-tf1-py3 - tags: - - fastertransformer - stage: build - only: - - main - - merge_requests - artifacts: - paths: - - ${CI_PROJECT_DIR}/build/ - expire_in: 1 week - script: - - cd ${CI_PROJECT_DIR} && mkdir build && cd build - - git submodule init && git submodule update - - cmake -DSM=xx -DCMAKE_BUILD_TYPE=Release -DBUILD_TF=ON -DTF_PATH=/usr/local/lib/python3.8/dist-packages/tensorflow_core/ -DBUILD_GPT=ON .. - - make -j12 - - apt-get update && apt-get install bc - -# 1. Get accuracy on LAMBADA dataset -# 2. Run pytorch gpt op as basline -# 3. Run pytorch piepline parallel and compare difference with baseline -# 4. 
Run pytorch tensor parallel and compare difference with baseline -pyt_gpt_test: - image: nvcr.io/nvidia/pytorch:21.02-py3 - tags: - - fastertransformer - stage: test - only: - - main - - merge_requests - needs: - - job: build_pyt_release - artifacts: true - script: - - cd ${CI_PROJECT_DIR}/build/ - - git submodule init && git submodule update - - export PYTHONPATH="${CI_PROJECT_DIR}/:$PYTHONPATH" - - export NVIDIA_TF32_OVERRIDE=0 # Disable the TF32 - - export CUDA_VISIBLE_DEVICES=0 - - wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json -P ../models - - wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt -P ../models - - wget --content-disposition https://api.ngc.nvidia.com/v2/models/nvidia/megatron_lm_345m/versions/v0.0/zip -O megatron_lm_345m_v0.0.zip - - wget https://github.com/cybertronai/bflm/raw/master/lambada_test.jsonl -P ../models/megatron-models - - unzip megatron_lm_345m_v0.0.zip -d ../models/megatron-models/345m - - python ../examples/pytorch/gpt/utils/megatron_ckpt_convert.py -head_num 16 -i ../models/megatron-models/345m/release/ -o ../models/megatron-models/c-model/345m/ -t_g 1 -i_g 1 - - bash ../examples/pytorch/gpt/scripts/evaluate_zeroshot_gpt.sh - - python ../examples/pytorch/gpt/gpt_example.py --ckpt_path=../models/megatron-models/c-model/345m/1-gpu/ --top_p 0.5 --sample_output_file single-gpu-out.txt - - export CUDA_VISIBLE_DEVICES=0,1 - - mpirun -n 2 --allow-run-as-root python ../examples/pytorch/gpt/multi_gpu_gpt_example.py --tensor_para_size=1 --pipeline_para_size=2 --ckpt_path=../models/megatron-models/c-model/345m/1-gpu/ --top_p 0.5 --sample_output_file pipeline-parallel-2-gpu-out.txt - - diff single-gpu-out.txt pipeline-parallel-2-gpu-out.txt - - python ../examples/pytorch/gpt/utils/megatron_ckpt_convert.py -head_num 16 -i ../models/megatron-models/345m/release/ -o ../models/megatron-models/c-model/345m/ -t_g 1 -i_g 2 - - mpirun -n 2 --allow-run-as-root python ../examples/pytorch/gpt/multi_gpu_gpt_example.py --tensor_para_size=2 --pipeline_para_size=1 --ckpt_path=../models/megatron-models/c-model/345m/2-gpu/ --top_p 0.5 --sample_output_file tensor-parallel-2-gpu-out.txt - - diff single-gpu-out.txt tensor-parallel-2-gpu-out.txt - timeout: 4h 30m - -tf_test: - image: nvcr.io/nvidia/tensorflow:21.02-tf1-py3 - tags: - - fastertransformer - stage: test - only: - - main - - merge_requests - needs: - - job: build_tf_release - artifacts: true - script: - - cd ${CI_PROJECT_DIR}/build/ - - apt-get update && apt-get install bc - - export PYTHONPATH="${CI_PROJECT_DIR}/:$PYTHONPATH" - - export NVIDIA_TF32_OVERRIDE=0 # Disable the TF32 - - export CUDA_VISIBLE_DEVICES=0 - - bash ${CI_PROJECT_DIR}/examples/tensorflow/decoding/utils/translation/download_model_data.sh - - mkdir -p ${CI_PROJECT_DIR}/translation/ckpt_fp16 - - python ${CI_PROJECT_DIR}/tests/bert/tf_bert_unit_test.py - - python ${CI_PROJECT_DIR}/tests/bert/tf_encoder_unit_test.py - - python ${CI_PROJECT_DIR}/examples/tensorflow/ckpt_type_convert.py --init_checkpoint=${CI_PROJECT_DIR}/translation/ckpt/model.ckpt-500000 --fp16_checkpoint=${CI_PROJECT_DIR}/translation/ckpt_fp16/model.ckpt-500000 - - python ${CI_PROJECT_DIR}/tests/decoding/tf_decoding_unit_test.py - timeout: 4h 30m - -tf_xlnet_test: - image: nvcr.io/nvidia/tensorflow:21.02-tf1-py3 - tags: - - fastertransformer - stage: test - only: - - master - - v4.1 - - main - - merge_requests - needs: - - job: build_tf_release - artifacts: true - script: - - cd ${CI_PROJECT_DIR}/examples/tensorflow/xlnet - - bash 
downloadModel.sh - - bash verifyCorrectness.sh # For FP32 model - -pyt_sp_test: - image: nvcr.io/nvidia/pytorch:21.02-py3 - tags: - - fastertransformer - stage: test - only: - - main - - merge_requests - needs: - - job: build_pyt_release_sparse - artifacts: true - script: - - cd ${CI_PROJECT_DIR}/build/ - - export PYTHONPATH="${CI_PROJECT_DIR}/:$PYTHONPATH" - - export NVIDIA_TF32_OVERRIDE=0 # Disable the TF32 - - export CUDA_VISIBLE_DEVICES=0 - - pip install transformers==2.5.1 - # GOS has no Ampere GPU, so no sparse tests can be done. only test some dense cases - - ${CI_PROJECT_DIR}/build/bin/bert_gemm 32 64 12 64 1 0 - - python ${CI_PROJECT_DIR}/examples/pytorch/bert/bert_example.py 32 12 64 12 64 --fp16 - - ${CI_PROJECT_DIR}/build/bin/bert_gemm 32 64 12 64 1 1 - - python ${CI_PROJECT_DIR}/examples/pytorch/bert/bert_example.py 32 12 64 12 64 --fp16 --int8_mode 1 - - python ${CI_PROJECT_DIR}/examples/pytorch/bert/bert_example.py 32 12 64 12 64 --fp16 --int8_mode 2 - - python ${CI_PROJECT_DIR}/examples/pytorch/bert/bert_example.py 32 12 64 12 64 --fp16 --int8_mode 3 - -pyt_longformer_test: - image: nvcr.io/nvidia/pytorch:21.02-py3 - tags: - - fastertransformer - stage: test - only: - - main - - merge_requests - needs: - - job: build_pyt_release - artifacts: true - script: - - cd ${CI_PROJECT_DIR}/examples/pytorch/longformer - - apt-get update && apt-get install git-lfs - - git lfs install - - git config lfs.fetchinclude "pytorch_model.bin,config.json" - - git clone https://huggingface.co/allenai/longformer-large-4096-finetuned-triviaqa - - cd ${CI_PROJECT_DIR} - - export PYTHONPATH="${CI_PROJECT_DIR}/:$PYTHONPATH" - - export NVIDIA_TF32_OVERRIDE=0 # Disable the TF32 - - export CUDA_VISIBLE_DEVICES=0 - - pip install transformers==4.8.2 - - python3 tests/longformer/py_longformer_unit_test.py - -pyt_decoding_test: - image: nvcr.io/nvidia/pytorch:21.02-py3 - tags: - - fastertransformer - stage: test - only: - - main - - merge_requests - needs: - - job: build_pyt_release - artifacts: true - script: - - cd ${CI_PROJECT_DIR}/build/ - - export PYTHONPATH="${CI_PROJECT_DIR}/:$PYTHONPATH" - - export NVIDIA_TF32_OVERRIDE=0 # Disable the TF32 - - export CUDA_VISIBLE_DEVICES=0 - - apt-get update && apt-get install bc - - pip install sacrebleu - - pip install opennmt-py==1.1.1 - - bash ../examples/pytorch/decoding/utils/download_model.sh - - mkdir pytorch/translation/data -p - - cp ../examples/tensorflow/decoding/utils/translation/test* pytorch/translation/data - - python ../examples/pytorch/decoding/utils/recover_bpe.py pytorch/translation/data/test.de debpe_ref.txt - - echo "Run decoding fp32" # decoding fp32 testing - - python ../examples/pytorch/decoding/translate_example.py --batch_size 128 --beam_size 4 --model_type decoding_ext --decoding_ths_path ./lib/libth_decoding.so --data_type fp32 --output_file output.txt - - python ../examples/pytorch/decoding/utils/recover_bpe.py output.txt debpe_output.txt - - cat debpe_output.txt | sacrebleu debpe_ref.txt - - echo "Run decoder fp32" # decoder fp32 testing - - python ../examples/pytorch/decoding/translate_example.py --batch_size 128 --beam_size 4 --model_type torch_decoding_with_decoder_ext --decoder_ths_path ./lib/libth_decoder.so --data_type fp32 --output_file output.txt - - python ../examples/pytorch/decoding/utils/recover_bpe.py output.txt debpe_output.txt - - cat debpe_output.txt | sacrebleu debpe_ref.txt - - echo "Run decoding fp16" # decoding fp16 testing - - python ../examples/pytorch/decoding/translate_example.py --batch_size 128 --beam_size 4 
--model_type decoding_ext --decoding_ths_path ./lib/libth_decoding.so --data_type fp16 --output_file output.txt - - python ../examples/pytorch/decoding/utils/recover_bpe.py output.txt debpe_output.txt - - cat debpe_output.txt | sacrebleu debpe_ref.txt - - echo "Run decoder fp16" # decoder fp16 testing - - python ../examples/pytorch/decoding/translate_example.py --batch_size 128 --beam_size 4 --model_type torch_decoding_with_decoder_ext --decoder_ths_path ./lib/libth_decoder.so --data_type fp16 --output_file output.txt - - python ../examples/pytorch/decoding/utils/recover_bpe.py output.txt debpe_output.txt - - cat debpe_output.txt | sacrebleu debpe_ref.txt - timeout: 4h - -t5_test: - image: nvcr.io/nvidia/pytorch:21.02-py3 - tags: - - fastertransformer - stage: test - only: - - main - - merge_requests - needs: - - job: build_pyt_release - artifacts: true - script: - - cd ${CI_PROJECT_DIR}/build/ - - export PYTHONPATH="${CI_PROJECT_DIR}/:$PYTHONPATH" - - export NVIDIA_TF32_OVERRIDE=0 # Disable the TF32 - - export CUDA_VISIBLE_DEVICES=0 - - apt-get update && apt-get install bc - - pip install transformers huggingface_hub tokenizers sacrebleu SentencePiece - - python ../examples/pytorch/t5/translate_example.py -batch 32 -time 0123 - - python ../examples/pytorch/t5/translate_example.py -batch 32 -time 0123 -d fp16 - - python ../examples/pytorch/t5/translate_example.py -batch 4 -time 0123 -d fp16 --model t5-3b - - export CUDA_VISIBLE_DEVICES=0,2 - - mpirun -n 2 --allow-run-as-root python ../examples/pytorch/t5/translate_example.py -batch 4 -time 13 -d fp16 --model t5-3b --tensor_para_size 2 - - mpirun -n 2 --allow-run-as-root python ../examples/pytorch/t5/translate_example.py -batch 4 -time 13 -d fp16 --model t5-3b --pipeline_para_size 2 - timeout: 4h diff --git a/.gitlab/issue_templates/bug.md b/.gitlab/issue_templates/bug.md deleted file mode 100644 index 5a9897a27..000000000 --- a/.gitlab/issue_templates/bug.md +++ /dev/null @@ -1,6 +0,0 @@ -Bug title: - -Description: - -Assign: -/assign diff --git a/.gitlab/issue_templates/feature.md b/.gitlab/issue_templates/feature.md deleted file mode 100644 index d876df20e..000000000 --- a/.gitlab/issue_templates/feature.md +++ /dev/null @@ -1,6 +0,0 @@ -Feature title: - -Description: - -Assign: -/assign @bhsueh @juney diff --git a/.gitlab/merge_request_templates/merge.md b/.gitlab/merge_request_templates/merge.md deleted file mode 100644 index 28f9a0b34..000000000 --- a/.gitlab/merge_request_templates/merge.md +++ /dev/null @@ -1,7 +0,0 @@ -Related Issue: Closed # - -Assign: -/assign @bhsueh - -Reviewers: -/assign_reviewer @yudong @jkosek @pziecina @juney diff --git a/3rdparty/INIReader.h b/3rdparty/INIReader.h index 8edc6dfdd..7d40f0638 100644 --- a/3rdparty/INIReader.h +++ b/3rdparty/INIReader.h @@ -316,7 +316,7 @@ class INIReader // Construct INIReader and parse given file. See ini.h for more info // about the parsing. INIReader(FILE *file); - + ~INIReader(); // Return the result of ini_parse(), i.e., 0 on success, line number of // first error on parse error, or -1 on file open error. 
int ParseError() const; @@ -384,6 +384,8 @@ inline int INIReader::ParseError() const return _error; } +inline INIReader::~INIReader() { } + inline const std::set& INIReader::Sections() const { return _sections; diff --git a/3rdparty/trt_fused_multihead_attention/fused_multihead_attention.h b/3rdparty/trt_fused_multihead_attention/fused_multihead_attention.h index 4daf9de93..b058e1e2e 100644 --- a/3rdparty/trt_fused_multihead_attention/fused_multihead_attention.h +++ b/3rdparty/trt_fused_multihead_attention/fused_multihead_attention.h @@ -313,8 +313,15 @@ class TFusedMHAKernelFactory static TFusedMHAKernelFactory& Get() { - static TFusedMHAKernelFactory s_factory; - return s_factory; + int device_id; + cudaGetDevice(&device_id); + static std::unique_ptr> s_factory[32] = {nullptr}; + if (s_factory[device_id] == nullptr) { + assert(device_id <= 32); + s_factory[device_id] = std::make_unique>(TFusedMHAKernelFactory()); + } + + return *(s_factory[device_id]); } private: diff --git a/3rdparty/trt_fused_multihead_attention/qkvToContext.cu b/3rdparty/trt_fused_multihead_attention/qkvToContext.cu index 6b5c448f4..b5d86df09 100644 --- a/3rdparty/trt_fused_multihead_attention/qkvToContext.cu +++ b/3rdparty/trt_fused_multihead_attention/qkvToContext.cu @@ -24,16 +24,28 @@ namespace fastertransformer { +union __half2_uint32_t_union { + half2 fp162; + uint32_t u32; +}; +union __float_uint32_t_union { + float fp32; + uint32_t u32; +}; + static inline void set_alpha(uint32_t& alpha, float norm, Data_type dtype) { if (dtype == DATA_TYPE_FP16) { - half2 h2 = __float2half2_rn(norm); - alpha = reinterpret_cast(h2); + __half2_uint32_t_union temp; + temp.fp162 = __float2half2_rn(norm); + alpha = temp.u32; } else if (dtype == DATA_TYPE_FP32) { - alpha = reinterpret_cast(norm); + __float_uint32_t_union temp; + temp.fp32 = norm; + alpha = temp.u32; } else if (dtype == DATA_TYPE_INT32) { @@ -365,7 +377,7 @@ public: void setup(const int S, const int B, const int window_num) { - size_t warps_m, warps_n, warps_k = 1; + size_t warps_m = 1, warps_n = 1, warps_k = 1; if (S == 64) { warps_m = 2; warps_n = 2; @@ -413,9 +425,14 @@ public: float scaleBmm2 = scaleCtx; float scaleSoftmax = interface->mDqProbs; - params.scale_bmm1 = reinterpret_cast(scaleBmm1); - params.scale_bmm2 = reinterpret_cast(scaleBmm2); - params.scale_softmax = reinterpret_cast(scaleSoftmax); + __float_uint32_t_union temp; + + temp.fp32 = scaleBmm1; + params.scale_bmm1 = temp.u32; + temp.fp32 = scaleBmm2; + params.scale_bmm2 = temp.u32; + temp.fp32 = scaleSoftmax; + params.scale_softmax = temp.u32; params.enable_i2f_trick = -double(1 << 22) * double(scaleBmm2) <= -128.f && double(1 << 22) * double(scaleBmm2) >= 127.f; @@ -446,7 +463,7 @@ public: void setup(const int S, const int B) { - size_t warps_m, warps_n, warps_k = 1; + size_t warps_m = 1, warps_n = 1, warps_k = 1; if ((sm == 75 || sm == 80) && S == 64) { warps_m = 2; @@ -495,9 +512,14 @@ public: float scaleBmm2 = interface->mDqProbs * scaleQkv / scaleCtx; float scaleSoftmax = 1.f / interface->mDqProbs; - params.scale_bmm1 = reinterpret_cast(scaleBmm1); - params.scale_bmm2 = reinterpret_cast(scaleBmm2); - params.scale_softmax = reinterpret_cast(scaleSoftmax); + __float_uint32_t_union temp; + + temp.fp32 = scaleBmm1; + params.scale_bmm1 = temp.u32; + temp.fp32 = scaleBmm2; + params.scale_bmm2 = temp.u32; + temp.fp32 = scaleSoftmax; + params.scale_softmax = temp.u32; params.enable_i2f_trick = -double(1 << 22) * double(scaleBmm2) <= -128.f && double(1 << 22) * double(scaleBmm2) >= 127.f; diff --git 
a/CMakeLists.txt b/CMakeLists.txt index 43f64f4d4..76a5e0655 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -33,6 +33,8 @@ endif() option(SPARSITY_SUPPORT "Build project with Ampere sparsity feature support" OFF) +option(BUILD_FAST_MATH "Build in fast math mode" ON) + if(BUILD_MULTI_GPU) message(STATUS "Add DBUILD_MULTI_GPU, requires MPI and NCCL") add_definitions("-DBUILD_MULTI_GPU") @@ -157,6 +159,10 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --std=c++${CXX_STD}") set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3") # set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler -O3 --ptxas-options=--verbose") set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler -O3") +if(BUILD_FAST_MATH) +set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} --use_fast_math") +message("CMAKE_CUDA_FLAGS_RELEASE: ${CMAKE_CUDA_FLAGS_RELEASE}") +endif() set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) @@ -227,7 +233,10 @@ print(torch._C._GLIBCXX_USE_CXX11_ABI,end='');" endif() endif() -list(APPEND COMMON_HEADER_DIRS ${MPI_INCLUDE_PATH}) +if (BUILD_MULTI_GPU) + list(APPEND COMMON_HEADER_DIRS ${MPI_INCLUDE_PATH}) + list(APPEND COMMON_LIB_DIRS /usr/local/mpi/lib) +endif() if(USE_TRITONSERVER_DATATYPE) list(APPEND COMMON_HEADER_DIRS ${PROJECT_SOURCE_DIR}/../repo-core-src/include) @@ -237,9 +246,6 @@ include_directories( ${COMMON_HEADER_DIRS} ) -# set path of mpi -list(APPEND COMMON_LIB_DIRS /usr/local/mpi/lib) - link_directories( ${COMMON_LIB_DIRS} ) @@ -249,9 +255,16 @@ add_subdirectory(src) add_subdirectory(examples) add_subdirectory(tests) +# # Mesaure the compile time +option(MEASURE_BUILD_TIME "Measure the build time of each module" OFF) +if (MEASURE_BUILD_TIME) + set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "${CMAKE_COMMAND} -E time") + set_property(GLOBAL PROPERTY RULE_LAUNCH_CUSTOM "${CMAKE_COMMAND} -E time") + set_property(GLOBAL PROPERTY RULE_LAUNCH_LINK "${CMAKE_COMMAND} -E time") +endif() + ######################################## -if(BUILD_MULTI_GPU) # Following feature requires cmake 3.15 # TODO Remove this part or modify such that we can run it under cmake 3.10 cmake_minimum_required(VERSION 3.15 FATAL_ERROR) @@ -259,6 +272,8 @@ add_library(transformer-static STATIC $ $ $ + $ + $ $ $ $ @@ -271,6 +286,12 @@ add_library(transformer-static STATIC $ $ $ + $ + $ + $ + $ + $ + $ $ $ $ @@ -285,55 +306,65 @@ add_library(transformer-static STATIC $ $ $ + $ $ $ $ + $ $ $ $ + $ $ $ $ $ $ $ - $ $ $ $ $ $ $ + $ + $ $ $ $ $ $ $ + $ $ $ $ $ + $ $ - $ $ $ $ $ $ $ + $ + $ $ $ $ - $) + $ +) set_property(TARGET transformer-static PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET transformer-static PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -target_link_libraries(transformer-static PUBLIC -lcudart -lnccl -lmpi -lcublas -lcublasLt -lcurand) +target_link_libraries(transformer-static PUBLIC -lcudart -lcublas -lcublasLt -lcurand) add_library(transformer-shared SHARED $ $ $ + $ + $ $ $ $ @@ -346,6 +377,12 @@ add_library(transformer-shared SHARED $ $ $ + $ + $ + $ + $ + $ + $ $ $ $ @@ -360,53 +397,59 @@ add_library(transformer-shared SHARED $ $ $ + $ $ $ $ + $ $ $ $ + $ $ $ $ $ $ $ - $ $ $ $ $ $ $ + $ + $ $ $ $ $ $ $ + $ $ $ $ $ + $ $ - $ - $ - $ $ $ $ $ $ $ + $ + $ $ $ $ - $) + $ +) set_target_properties(transformer-shared PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(transformer-shared PROPERTIES CUDA_RESOLVE_DEVICE_SYMBOLS ON) set_target_properties(transformer-shared PROPERTIES 
LINKER_LANGUAGE CXX) -target_link_libraries(transformer-shared PUBLIC -lcudart -lnccl -lmpi -lcublas -lcublasLt -lcurand) +target_link_libraries(transformer-shared PUBLIC -lcudart -lcublas -lcublasLt -lcurand) include(GNUInstallDirs) set(INSTALL_CONFIGDIR ${CMAKE_INSTALL_LIBDIR}/cmake/FasterTransformer) @@ -465,5 +508,3 @@ export( ) export(PACKAGE FasterTransformer) - -endif() # BUILD_MULTI_GPU diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index d511a88b5..9f9403bf0 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -47,4 +47,4 @@ (c) The contribution was provided directly to me by some other person who certified (a), (b) or (c) and I have not modified it. (d) I understand and agree that this project and the contribution are public and that a record of the contribution (including all personal information I submit with it, including my sign-off) is maintained indefinitely and may be redistributed consistent with this project or the open source license(s) involved. - ``` + ``` \ No newline at end of file diff --git a/README.md b/README.md index 45b5374b7..fe94e8f8d 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,8 @@ FasterTransformer is built on top of CUDA, cuBLAS, cuBLASLt and C++. We provide | Models | Framework | FP16 | INT8 (after Turing) | Sparsity (after Ampere) | Tensor parallel | Pipeline parallel | | ---------------- | -------------- | ---- | ------------------- | ----------------------- | --------------- | ----------------- | | BERT | TensorFlow | Yes | Yes | - | - | - | -| BERT | PyTorch | Yes | Yes | Yes | - | - | +| BERT | PyTorch | Yes | Yes | Yes | Yes | Yes | +| BERT | Triton backend | Yes | - | - | Yes | Yes | | XLNet | C++ | Yes | - | - | - | - | | Encoder | TensorFlow | Yes | Yes | - | - | - | | Encoder | PyTorch | Yes | Yes | Yes | - | - | @@ -52,7 +53,8 @@ FasterTransformer is built on top of CUDA, cuBLAS, cuBLASLt and C++. We provide | Swin Transformer | PyTorch | Yes | Yes | - | - | - | | Swin Transformer | TensorRT | Yes | Yes | - | - | - | | ViT | PyTorch | Yes | Yes | - | - | - | -| ViT | TensorRT | Yes | - | - | - | - | +| ViT | TensorRT | Yes | Yes | - | - | - | +| GPT-NeoX | Triton backend | Yes | - | - | Yes | Yes | * Note that the FasterTransformer supports the models above on C++ because all source codes are built on C++. @@ -65,7 +67,7 @@ The following code lists the directory structure of FasterTransformer: ```bash /src/fastertransformer: source code of FasterTransformer |--/models: Implementation of different models, like BERT, GPT. - |--/layers: Implementation of layer modeuls, like attention layer, ffn layer. + |--/layers: Implementation of layer modules, like attention layer, ffn layer. |--/kernels: CUDA kernels for different models/layers and operations, like addBiasResiual. |--/tensorrt_plugin: encapluate FasterTransformer into TensorRT plugin. |--/tf_op: custom Tensorflow OP implementation @@ -181,7 +183,7 @@ In the experiments of decoding, we updated the following parameters: * top_p = 0.9 * tensor parallel size = 8 * input sequence length = 512 -* ouptut sequence length = 32 +* output sequence length = 32
@@ -189,17 +191,45 @@ In the experiments of decoding, we updated the following parameters: ### Changelog +Aug 2022 +- **Release the FasterTransformer 5.1** +- Support for interactive generation +- Support for attention time-limited memory +- Support mt5 and t5-v1.1 + +July 2022 +- Support UL2 huggingface ckpt. ([link](https://huggingface.co/google/ul2)) + - Fix bug of T5 under bfloat16. +- Add ViT INT8 TensorRT Plugin +- Support batch sampling +- Support shared context optimization in GPT model + +June 2022 +- Support streaming generation for triton backend. +- Support OPT. +- Support multi-node multi-GPU BERT under FP32, FP16 and BF16. + +May 2022 +- Support bfloat16 on most models. +- Support [prefix-prompt](https://arxiv.org/pdf/2101.00190.pdf) for GPT-J. +- Support GPT-NeoX. + - epsilon value used in layernorm is now a parameter + - rotary embedding GPT-NeoX style (only GPT-J was implemented) + - load per-GPU layernorm and bias parameters + - weight conversion from EleutherAI checkpoint + April 2022 -- Change the default accumulation type of all gemm to FP32. -- Support bfloat16 inference in GPT model. -- Support Nemo Megatron T5 and Megatron-LM T5 model. -- Support ViT. +- **Release the FasterTransformer 5.0** + - Change the default accumulation type of all gemm to FP32. + - Support bfloat16 inference in GPT model. + - Support Nemo Megatron T5 and Megatron-LM T5 model. + - Support ViT. March 2022 - Support `stop_ids` and `ban_bad_ids` in GPT-J. - Support dynamice `start_id` and `end_id` in GPT-J, GPT, T5 and Decoding. -Febuary 2022 +February 2022 - Support Swin Transformer. - Optimize the k/v cache update of beam search by in-direction buffer. - Support runtime input for GPT-J, T5 and GPT. @@ -317,7 +347,7 @@ March 2020 - Add a layer normalization layer after decoder. - Add a normalization for inputs of decoder -Febuary 2020 +February 2020 - **Release the FasterTransformer 2.0** - Provide a highly optimized OpenNMT-tf based decoder and decoding, including C++ API and TensorFlow op. - Refine the sample codes of encoder. 
diff --git a/benchmarks/bert/pyt_benchmark.sh b/benchmarks/bert/pyt_benchmark.sh index a8c866828..b8d56de57 100644 --- a/benchmarks/bert/pyt_benchmark.sh +++ b/benchmarks/bert/pyt_benchmark.sh @@ -54,7 +54,7 @@ if [ -f "gemm_config.in" ] ; then tmp_log_ths=${logdir}/batchsize-${batch_size}-seq-${seq_len}-${precision}-ths-log.log if [ "$precision" = "fp16" ]; then - python ../examples/pytorch/bert/bert_example.py ${batch_size} 12 ${seq_len} 12 64 --fp16 --time 2>&1 | tee $tmp_log_ths + python ../examples/pytorch/bert/bert_example.py ${batch_size} 12 ${seq_len} 12 64 --data_type fp16 --time 2>&1 | tee $tmp_log_ths else python ../examples/pytorch/bert/bert_example.py ${batch_size} 12 ${seq_len} 12 64 --time 2>&1 | tee $tmp_log_ths fi diff --git a/benchmarks/bert/pyt_int8_benchmark.sh b/benchmarks/bert/pyt_int8_benchmark.sh index 97dede8b8..8aaf0be89 100644 --- a/benchmarks/bert/pyt_int8_benchmark.sh +++ b/benchmarks/bert/pyt_int8_benchmark.sh @@ -43,7 +43,7 @@ do ../build/bin/bert_gemm ${batch_size} ${seq_len} 12 64 1 ${int8_mode} tmp_log_ths=${logdir}/batchsize-${batch_size}-seq-${seq_len}-fp16-ths-log.log - python ../examples/pytorch/bert/bert_example.py ${batch_size} 12 ${seq_len} 12 64 --fp16 --time --int8_mode ${int8_mode} 2>&1 | tee $tmp_log_ths + python ../examples/pytorch/bert/bert_example.py ${batch_size} 12 ${seq_len} 12 64 --data_type fp16 --time --int8_mode ${int8_mode} 2>&1 | tee $tmp_log_ths ths_time=`tail -n 3 ${tmp_log_ths} | head -n 1 | awk '{print $5}'` ft_time=`tail -n 2 ${tmp_log_ths} | head -n 1 | awk '{print $5}'` diff --git a/benchmarks/bert/pyt_sp_fp16_benchmark.sh b/benchmarks/bert/pyt_sp_fp16_benchmark.sh index 8d477724f..2dd00349b 100644 --- a/benchmarks/bert/pyt_sp_fp16_benchmark.sh +++ b/benchmarks/bert/pyt_sp_fp16_benchmark.sh @@ -56,9 +56,9 @@ do tmp_log_pt=${logdir}/batchsize-${batch_size}-seq-${seq_len}-log.log tmp_log_pt_sp=${logdir}/batchsize-${batch_size}-seq-${seq_len}-sp-log.log - python ../examples/pytorch/bert/bert_example.py ${batch_size} ${layer_num} ${seq_len} ${head_num} ${head_size} --fp16 --time 2>&1 | tee $tmp_log_pt + python ../examples/pytorch/bert/bert_example.py ${batch_size} ${layer_num} ${seq_len} ${head_num} ${head_size} --data_type fp16 --time 2>&1 | tee $tmp_log_pt sleep 5s - python ../examples/pytorch/bert/bert_example.py ${batch_size} ${layer_num} ${seq_len} ${head_num} ${head_size} --fp16 --sparse --time 2>&1 | tee $tmp_log_pt_sp + python ../examples/pytorch/bert/bert_example.py ${batch_size} ${layer_num} ${seq_len} ${head_num} ${head_size} --data_type fp16 --sparse --time 2>&1 | tee $tmp_log_pt_sp sleep 5s ft_o_time=`tail -n 2 ${tmp_log_pt} | head -n 1 | awk '{print $5}'` diff --git a/benchmarks/bert/pyt_sp_int8_mode2_benchmark.sh b/benchmarks/bert/pyt_sp_int8_mode2_benchmark.sh index 4665180c7..c793c6a5e 100644 --- a/benchmarks/bert/pyt_sp_int8_mode2_benchmark.sh +++ b/benchmarks/bert/pyt_sp_int8_mode2_benchmark.sh @@ -56,9 +56,9 @@ do tmp_log_pt=${logdir}/batchsize-${batch_size}-seq-${seq_len}-log.log tmp_log_pt_sp=${logdir}/batchsize-${batch_size}-seq-${seq_len}-sp-log.log - python ../examples/pytorch/bert/bert_example.py ${batch_size} ${layer_num} ${seq_len} ${head_num} ${head_size} --fp16 --int8_mode 2 --time 2>&1 | tee $tmp_log_pt + python ../examples/pytorch/bert/bert_example.py ${batch_size} ${layer_num} ${seq_len} ${head_num} ${head_size} --data_type fp16 --int8_mode 2 --time 2>&1 | tee $tmp_log_pt sleep 5s - python ../examples/pytorch/bert/bert_example.py ${batch_size} ${layer_num} ${seq_len} ${head_num} ${head_size} --fp16 
--int8_mode 2 --sparse --time 2>&1 | tee $tmp_log_pt_sp + python ../examples/pytorch/bert/bert_example.py ${batch_size} ${layer_num} ${seq_len} ${head_num} ${head_size} --data_type fp16 --int8_mode 2 --sparse --time 2>&1 | tee $tmp_log_pt_sp sleep 5s ft_o_time=`tail -n 2 ${tmp_log_pt} | head -n 1 | awk '{print $5}'` diff --git a/benchmarks/bert/pyt_tp_benchmark.sh b/benchmarks/bert/pyt_tp_benchmark.sh new file mode 100644 index 000000000..bc8df3162 --- /dev/null +++ b/benchmarks/bert/pyt_tp_benchmark.sh @@ -0,0 +1,126 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# apt-get update +# apt-get install bc +set -x +export NVIDIA_TF32_OVERRIDE=0 + +MODEL_LAYER=32 +HEAD_NUM=32 +SIZE_PER_HEAD=128 +HIDDEN_SIZE=$(echo "${HEAD_NUM} * ${SIZE_PER_HEAD}" | bc) +INTER_SIZE=$(echo "${HIDDEN_SIZE} * 4" | bc) +for precision in fp16; +do + +if [ "$precision" = "fp16" ]; then + echo "Using fp16." + precision_num=1 + precision_larger="FP16" +else + echo "Using fp32" + precision_num=0 + precision_larger="FP32" +fi + +logdir="bert-6B-log-${precision}-triton" +if [ ! -f ${logdir} ] ; then + mkdir ${logdir} -p +fi +all_log="${logdir}/all-log.log" +echo -e "| Batch_size | Seq_len | Precision | TP1, PP1
Latency (ms) | TP2, PP1 <br/>
Latency (ms) | TP4, PP1 <br/>
Latency (ms) | TP1, PP2 <br/>
Latency (ms) | TP1, PP4 <br/>
Latency (ms) | " > $all_log +echo -e "|:----------:|:-------:|:---------:|:---------------------------:|:---------------------------:|:---------------------------:|:---------------------------:|:---------------------------:| " >> $all_log + +cat /proc/cpuinfo > ${logdir}/cpuinfo.txt +nvidia-smi > ${logdir}/gpuinfo.txt + +echo "[bert] + model_name = bert + position_embedding_type = absolute + hidden_size = ${HIDDEN_SIZE} + num_layer = ${MODEL_LAYER} + head_num = ${HEAD_NUM} + size_per_head = ${SIZE_PER_HEAD} + activation_type = gelu + inter_size = ${INTER_SIZE} + max_position_embeddings = 512 + layer_norm_eps = 1e-12 + weight_data_type = fp32 + tensor_para_size = 1" > config.ini + +for batch_size in 1 4 32 128 ; +do +for seq_len in 32 128 384 1024 ; +do + # tp 1, pp 1 + python ../examples/pytorch/bert/utils/update_bert_config.py \ + --model-dir ./ \ + --config-ini-path config.ini \ + --pipeline-para-size 1 \ + --tensor-para-size 1 \ + --data-type fp16 \ + --request-batch-size ${batch_size} \ + --request-seq-len ${seq_len} + tmp_log=${logdir}/batchsize-${batch_size}-seq-${seq_len}-${precision}-tp-1-pp-1.log + CUDA_VISIBLE_DEVICES=4 ./bin/bert_triton_example config.ini 2>&1 | tee ${tmp_log} + ft_tp1_pp1_time=`tail -n 1 ${tmp_log} | head -n 1 | awk '{print $7}'` + sleep 5 + + # tp 2, pp 1 + python ../examples/pytorch/bert/utils/update_bert_config.py \ + --model-dir ./ \ + --config-ini-path config.ini \ + --pipeline-para-size 1 \ + --tensor-para-size 2 \ + --data-type fp16 \ + --request-batch-size ${batch_size} \ + --request-seq-len ${seq_len} + tmp_log=${logdir}/batchsize-${batch_size}-seq-${seq_len}-${precision}-tp-2-pp-1.log + CUDA_VISIBLE_DEVICES=4,5 ./bin/bert_triton_example config.ini 2>&1 | tee ${tmp_log} + ft_tp2_pp1_time=`tail -n 1 ${tmp_log} | head -n 1 | awk '{print $7}'` + sleep 5 + + # tp 4, pp 1 + python ../examples/pytorch/bert/utils/update_bert_config.py \ + --model-dir ./ \ + --config-ini-path config.ini \ + --pipeline-para-size 1 \ + --tensor-para-size 4 \ + --data-type fp16 \ + --request-batch-size ${batch_size} \ + --request-seq-len ${seq_len} + tmp_log=${logdir}/batchsize-${batch_size}-seq-${seq_len}-${precision}-tp-4-pp-1.log + CUDA_VISIBLE_DEVICES=4,5,6,7 ./bin/bert_triton_example config.ini 2>&1 | tee ${tmp_log} + ft_tp4_pp1_time=`tail -n 1 ${tmp_log} | head -n 1 | awk '{print $7}'` + sleep 5 + + # tp 1, pp 2 + python ../examples/pytorch/bert/utils/update_bert_config.py \ + --model-dir ./ \ + --config-ini-path config.ini \ + --pipeline-para-size 2 \ + --tensor-para-size 1 \ + --data-type fp16 \ + --request-batch-size ${batch_size} \ + --request-seq-len ${seq_len} + tmp_log=${logdir}/batchsize-${batch_size}-seq-${seq_len}-${precision}-tp-1-pp-2.log + CUDA_VISIBLE_DEVICES=4,5 ./bin/bert_triton_example config.ini 2>&1 | tee ${tmp_log} + ft_tp1_pp2_time=`tail -n 1 ${tmp_log} | head -n 1 | awk '{print $7}'` + sleep 5 + + echo "| ${batch_size} | ${seq_len} | fp16 | ${ft_tp1_pp1_time} | ${ft_tp2_pp1_time} | ${ft_tp4_pp1_time} | ${ft_tp1_pp2_time} | " >> ${all_log} +done +done +done \ No newline at end of file diff --git a/benchmarks/gpt/cpp_benchmark.sh b/benchmarks/gpt/cpp_benchmark.sh new file mode 100644 index 000000000..824beada6 --- /dev/null +++ b/benchmarks/gpt/cpp_benchmark.sh @@ -0,0 +1,120 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# $1: TP size +# $2: PP size + +export NVIDIA_TF32_OVERRIDE=0 +tensor_para_size=$1 +pipeline_para_size=$2 +total_gpu_count=$(echo "scale=2; ${tensor_para_size} * ${pipeline_para_size} " | bc) + +vocab_size=51200 + +logdir="gpt-TP${tensor_para_size}-PP${pipeline_para_size}-log" +if [ ! -f ${logdir} ]; then + mkdir ${logdir} -p +fi + +all_log="${logdir}/all-log.log" + +echo -e "| model size | Batch Size | Input length | Output length | Decode value | Precision | FT latency (ms) |" > $all_log +echo -e "|:----------:|:----------:|:------------:|:-------------:|:------------:|:---------:|:---------------:|" >> $all_log + +cat /proc/cpuinfo > ${logdir}/cpuinfo.txt +nvidia-smi > ${logdir}/gpuinfo.txt + +for model_size in "345m" "5b"; +do + if [ "$model_size" = "345m" ]; then + head_num=16 + size_per_head=64 + inter_size=$(echo "scale=2; $head_num * ${size_per_head} * 4 " | bc) + num_layer=24 + elif [ "$model_size" = "5b" ]; then + head_num=32 + size_per_head=128 + inter_size=$(echo "scale=2; $head_num * ${size_per_head} * 4 " | bc) + num_layer=24 + fi + +for decode_type in "beamsearch" "sampling"; +do + + if [ "$decode_type" = "beamsearch" ]; then + decode_values=(4) + elif [ "$decode_type" = "sampling" ]; then + decode_values=(4 0.5) + fi + +for request_batch_size in 1 4 16; +do +for input_length in 60; +do +for request_output_len in 80; +do +for decode_value in ${decode_values[@]}; +do + +if [ "$decode_type" = "beamsearch" ]; then + beam_width=$decode_value + topk=0 + topp=0.0 +elif [ "$decode_type" = "sampling" ]; then + beam_width=1 + if [[ $decode_value == +([[:digit:]]) ]]; then + topk=$decode_value + topp=0.0 + else + topk=0 + topp=$decode_value + fi +fi + +tmp_log=${logdir}/batchsize-${request_batch_size}-decode_value-${decode_value}-${input_length}-${request_output_len}-${decode_type}-${decode_value}.log + +python ../examples/pytorch/gpt/utils/generate_start_ids.py --max_batch_size ${request_batch_size} --max_input_length ${input_length} +./bin/gpt_gemm ${request_batch_size} ${beam_width} ${input_length} ${head_num} ${size_per_head} ${inter_size} ${vocab_size} 1 ${tensor_para_size} +python ../examples/pytorch/gpt/utils/generate_gpt_config.py \ + --max_seq_len 1024 \ + --beam_width ${beam_width} \ + --head_num ${head_num} \ + --size_per_head ${size_per_head} \ + --inter_size ${inter_size} \ + --num_layer ${num_layer} \ + -v 51200 \ + -d fp16 \ + -topk ${topk} \ + -topp ${topp} \ + --tensor_para_size ${tensor_para_size} \ + --pipeline_para_size ${pipeline_para_size} \ + -request_batch_size ${request_batch_size} \ + --request_output_len ${request_output_len} +mpirun -n ${total_gpu_count} --allow-run-as-root ./bin/multi_gpu_gpt_example .tmp.config.ini 2>&1 | tee ${tmp_log} +ft_latency=`tail -n 1 ${tmp_log} | head -n 1 | awk '{print $17}'` +echo "" | awk -v ft_latency=$ft_latency \ + -v batch_size=$request_batch_size \ + -v input_length=${input_length} -v request_output_len="$request_output_len" \ + -v model_size=${model_size} -v decode_value="$decode_value" -v decode_type="$decode_type" \ + '{printf "| %5s | %3d | %4d | %4d | %10s %5s | FP16 | %7.2f |\n", model_size, batch_size, 
input_length, request_output_len, + decode_type, decode_value, ft_latency}' >> $all_log + +rm .tmp.config.ini + +done # decode_values +done # request_output_len +done # input_length +done # batch_size +done # decode_type +done # model_size \ No newline at end of file diff --git a/benchmarks/t5/pyt_benchmark.sh b/benchmarks/t5/pyt_benchmark.sh index 49c2ad016..3157bb49f 100644 --- a/benchmarks/t5/pyt_benchmark.sh +++ b/benchmarks/t5/pyt_benchmark.sh @@ -12,6 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +# $1: RUNNING the Huggingface in the benchmark or not +# $2: model_size like t5-base, t5-3b. Useless when $3 is given +# $3: model_path for t5 model. We can use the downloaded model directly, +# preventing wasting time to download model. + +if [ "$1" = 0 ]; then + IS_RUN_HF=0 +else + IS_RUN_HF=1 +fi + if [ $FT_REPO_PATH ];then echo "FT_REPO_PATH = $FT_REPO_PATH" else @@ -21,52 +32,13 @@ fi export NVIDIA_TF32_OVERRIDE=0 -for model_size in "t5-base"; -do -if [ "$model_size" = "t5-small" ]; then - encoder_head_num=8 - encoder_size_per_head=64 - encoder_d_model=512 - encoder_num_layer=6 - encoder_inter_size=2048 - - decoder_head_num=8 - decoder_size_per_head=64 - decoder_d_model=512 - decoder_num_layer=6 - decoder_inter_size=2048 - decoder_vocab_size=32128 -elif [ "$model_size" = "t5-base" ]; then - encoder_head_num=12 - encoder_size_per_head=64 - encoder_d_model=768 - encoder_num_layer=12 - encoder_inter_size=3072 - - decoder_head_num=12 - decoder_size_per_head=64 - decoder_d_model=768 - decoder_num_layer=12 - decoder_inter_size=3072 - decoder_vocab_size=32128 -elif [ "$model_size" = "t5-3b" ]; then - encoder_head_num=32 - encoder_size_per_head=128 - encoder_d_model=1024 - encoder_num_layer=24 - encoder_inter_size=16384 - - decoder_head_num=32 - decoder_size_per_head=128 - decoder_d_model=1024 - decoder_num_layer=24 - decoder_inter_size=16384 - decoder_vocab_size=32128 +if [ "$2" != "" ]; then + model_size=$2 else - echo "[ERROR] no model_size $model_size" + model_size="t5-base" fi -for precision in fp32; +for precision in fp16; do if [ "$precision" = "fp16" ]; then @@ -94,8 +66,14 @@ if [ ! -f ${logdir} ] ; then fi all_log="${logdir}/all-log.log" -echo -e "| Batch Size | ${decode_type} | Precision | Huggingface
Throughput (token/sec) | FT Decoding <br/>
Throughput (token/sec) | FT Decoding <br/>
Speedup |" > $all_log -echo -e "|:----------:|:--------------:|:---------:|:----------------------------------------:|:----------------------------------------:|:-------------------------:|" >> $all_log + +if [ "$IS_RUN_HF" == "1" ]; then + echo -e "| Batch Size | ${decode_type} | Precision | Huggingface <br/>
Throughput (token/sec) | FT Decoding <br/>
Throughput (token/sec) | FT Decoding <br/>
Speedup |" > $all_log + echo -e "|:----------:|:--------------:|:---------:|:----------------------------------------:|:----------------------------------------:|:-------------------------:|" >> $all_log +else + echo -e "| Batch Size | ${decode_type} | Precision | FT Decoding <br/>
Throughput (token/sec) |" > $all_log + echo -e "|:----------:|:--------------:|:---------:|:----------------------------------------:|" >> $all_log +fi cat /proc/cpuinfo > ${logdir}/cpuinfo.txt nvidia-smi > ${logdir}/gpuinfo.txt @@ -108,7 +86,11 @@ do beam_width=$decode_value topk=0 topp=0.0 - test_time="01" + if [ "$IS_RUN_HF" == "1" ]; then + test_time="01" + else + test_time="1" + fi elif [ "$decode_type" = "sampling" ]; then beam_width=1 if [[ $decode_value == +([[:digit:]]) ]]; then @@ -118,7 +100,11 @@ do topk=0 topp=$decode_value fi - test_time="23" + if [ "$IS_RUN_HF" == "1" ]; then + test_time="23" + else + test_time="3" + fi fi if [ -f "gemm_config.in" ] ; then @@ -126,35 +112,57 @@ do fi tmp_log_th=${logdir}/batchsize-${batch_size}-beamwidth-${beam_width}-seq-128-${precision}-${decode_type}-${decode_value}-th-log.log - ./bin/t5_gemm ${batch_size} ${beam_width} 128 \ - ${encoder_d_model} ${encoder_head_num} ${encoder_size_per_head} ${encoder_inter_size} \ - ${decoder_d_model} ${decoder_head_num} ${decoder_size_per_head} ${decoder_inter_size} \ - ${decoder_vocab_size} ${precision_num} > ${logdir}/batchsize-${batch_size}-beamwidth-${beam_width}-seq-128-${precision}-th-log.gemm.log - - python ${FT_REPO_PATH}/examples/pytorch/t5/translate_example.py \ + + if [ "$3" != "" ]; then + python ${FT_REPO_PATH}/examples/pytorch/t5/translate_example.py \ --batch_size ${batch_size} \ --beam_width ${beam_width} \ --max_seq_len 128 \ --data_type ${precision} \ --beam_search_diversity_rate 0.0 \ - --model ${model_size} \ + --model_path $3 \ + --model "t5-base" \ --sampling_topk ${topk} \ --sampling_topp ${topp} \ + --max_iteration 200 \ --test_time ${test_time} 2>&1 | tee ${tmp_log_th} - ft_decoding_throughput=`tail -n 1 ${tmp_log_th} | awk '{print $16}'` - th_throughput=`tail -n 2 ${tmp_log_th} | head -n 1 | awk '{print $16}'` - ft_decoding_speedup=$(echo "scale=2; $ft_decoding_throughput / $th_throughput " | bc) - - echo "" | awk -v th_throughput=$th_throughput \ - -v ft_decoding_throughput=$ft_decoding_throughput \ - -v ft_decoding_speedup=$ft_decoding_speedup -v batch_size=$batch_size -v decode_value="$decode_value" \ - -v precision_large=$precision_large \ - '{printf "| %3d | %4s | %s | %5d | %5d | %5.2f |\n", batch_size, decode_value, - precision_large, th_throughput, ft_decoding_throughput, - ft_decoding_speedup }' >> $all_log + else + python ${FT_REPO_PATH}/examples/pytorch/t5/translate_example.py \ + --batch_size ${batch_size} \ + --beam_width ${beam_width} \ + --max_seq_len 128 \ + --data_type ${precision} \ + --beam_search_diversity_rate 0.0 \ + --model ${model_size} \ + --sampling_topk ${topk} \ + --sampling_topp ${topp} \ + --max_iteration 200 \ + --test_time ${test_time} 2>&1 | tee ${tmp_log_th} + fi + + if [ "$IS_RUN_HF" == "1" ]; then + ft_decoding_throughput=`tail -n 1 ${tmp_log_th} | awk '{print $19}'` + th_throughput=`tail -n 2 ${tmp_log_th} | head -n 1 | awk '{print $19}'` + ft_decoding_speedup=$(echo "scale=2; $ft_decoding_throughput / $th_throughput " | bc) + + echo "" | awk -v th_throughput=$th_throughput \ + -v ft_decoding_throughput=$ft_decoding_throughput \ + -v ft_decoding_speedup=$ft_decoding_speedup -v batch_size=$batch_size -v decode_value="$decode_value" \ + -v precision_large=$precision_large \ + '{printf "| %3d | %4s | %s | %5d | %5d | %5.2f |\n", batch_size, decode_value, + precision_large, th_throughput, ft_decoding_throughput, + ft_decoding_speedup }' >> $all_log + else + ft_decoding_throughput=`tail -n 1 ${tmp_log_th} | awk '{print $19}'` + + echo "" | awk -v 
ft_decoding_throughput=$ft_decoding_throughput \ + -v batch_size=$batch_size -v decode_value="$decode_value" \ + -v precision_large=$precision_large \ + '{printf "| %3d | %4s | %s | %5d |\n", batch_size, decode_value, + precision_large, ft_decoding_throughput}' >> $all_log + fi done # decode_value done # batch_size done # decode_type -done # for precision -done # for model_size \ No newline at end of file +done # for precision \ No newline at end of file diff --git a/docker/Dockerfile.tf b/docker/Dockerfile.tf new file mode 100644 index 000000000..4e2ce0281 --- /dev/null +++ b/docker/Dockerfile.tf @@ -0,0 +1,41 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# -------------------------------------------------- # +# This is a Docker image dedicated to develop +# FasterTransformer. +# -------------------------------------------------- # + +ARG DOCKER_VERSION=22.07 +ARG BASE_IMAGE=nvcr.io/nvidia/tensorflow:${DOCKER_VERSION}-tf1-py3 +FROM ${BASE_IMAGE} + +RUN apt-get update && \ + apt-get install -y --no-install-recommends bc && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +# backend build +WORKDIR /workspace/FasterTransformer +ADD . /workspace/FasterTransformer + +RUN git submodule update --init --recursive + +ARG SM=80 +ARG FORCE_BACKEND_REBUILD=0 +RUN mkdir /var/run/sshd -p && \ + mkdir build -p && cd build && \ + cmake -DSM=${SM} -DCMAKE_BUILD_TYPE=Release -DBUILD_TF=ON -DTF_PATH=/usr/local/lib/python3.8/dist-packages/tensorflow_core/ -DBUILD_MULTI_GPU=ON .. && \ + make -j"$(grep -c ^processor /proc/cpuinfo)" diff --git a/docker/Dockerfile.torch b/docker/Dockerfile.torch new file mode 100644 index 000000000..dea54b6b6 --- /dev/null +++ b/docker/Dockerfile.torch @@ -0,0 +1,59 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# -------------------------------------------------- # +# This is a Docker image dedicated to develop +# FasterTransformer. +# -------------------------------------------------- # + +ARG DOCKER_VERSION=22.07 +ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:${DOCKER_VERSION}-py3 +FROM ${BASE_IMAGE} + +RUN apt-get update && \ + apt-get install -y --no-install-recommends bc && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +# backend build +WORKDIR /workspace/FasterTransformer +ADD . 
/workspace/FasterTransformer + +RUN git submodule update --init --recursive && \ + git clone https://github.com/NVIDIA/NeMo /workspace/FasterTransformer/3rdparty/NeMo && \ + cd /workspace/FasterTransformer/3rdparty/NeMo && \ + git checkout 66c7677cd4a68d78965d4905dd1febbf5385dff3 && \ + cd - + +# Originally, we need to re-install the apex package for NeMo. +# However, we don't really need apex in tests about NeMo because we +# only use NeMo to do tokenization, dataset loading and model conversion. +# So, remove the re-installation because it is time-consuming. +# RUN pip uninstall -y apex && \ +# pip install git+https://github.com/NVIDIA/apex --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_layer_norm" && \ +# pip install 3rdparty/NeMo[nlp] +RUN pip install 3rdparty/NeMo[nlp] + +ARG SM=80 +ARG FORCE_BACKEND_REBUILD=0 +ARG SPARSITY_SUPPORT=OFF +ARG BUILD_MULTI_GPU=ON +RUN mkdir /var/run/sshd -p && \ + mkdir build -p && cd build && \ + wget https://developer.download.nvidia.com/compute/libcusparse-lt/0.1.0/local_installers/libcusparse_lt-linux-x86_64-0.1.0.2.tar.gz && \ + tar -xzvf libcusparse_lt-linux-x86_64-0.1.0.2.tar.gz && \ + cmake -DSM=${SM} -DCMAKE_BUILD_TYPE=Release -DBUILD_PYT=ON -DSPARSITY_SUPPORT=${SPARSITY_SUPPORT} -DMEASURE_BUILD_TIME=ON \ + -DCUSPARSELT_PATH=/workspace/FasterTransformer/build/libcusparse_lt/ -DBUILD_MULTI_GPU=${BUILD_MULTI_GPU} -DBUILD_TRT=ON .. && \ + make -j"$(grep -c ^processor /proc/cpuinfo)" diff --git a/docs/bert_guide.md b/docs/bert_guide.md index a618a9b4f..f4ae49258 100644 --- a/docs/bert_guide.md +++ b/docs/bert_guide.md @@ -14,15 +14,20 @@ The FasterTransformer BERT contains the optimized BERT model, Effective FasterTr - [Prepare](#prepare) - [Build the project](#build-the-project) - [How to use](#how-to-use) - - [BERT process](#bert-process) + - [Run FasterTransformer BERT on C++](#run-fastertransformer-bert-on-c) + - [Run FasterTransformer BERT on TensorFlow](#run-fastertransformer-bert-on-tensorflow) + - [Run FasterTransformer BERT on PyTorch](#run-fastertransformer-bert-on-pytorch) + - [Run the PyTorch BERT sample with multi-GPU:](#run-the-pytorch-bert-sample-with-multi-gpu) - [Performance](#performance) - - [BERT performance](#bert-performance) + - [Multi-GPU BERT-6B performance on A100 and triton example](#multi-gpu-bert-6b-performance-on-a100-and-triton-example) + - [Single GPU BERT performance](#single-gpu-bert-performance) - [BERT performance on A100 and TensorFlow](#bert-performance-on-a100-and-tensorflow) - [BERT performance on T4 and TensorFlow](#bert-performance-on-t4-and-tensorflow) - [BERT performance on V100 and TensorFlow](#bert-performance-on-v100-and-tensorflow) - [BERT performance comparison between T4, V100, A100 and A100 with MIG mode on TensorFlow](#bert-performance-comparison-between-t4-v100-a100-and-a100-with-mig-mode-on-tensorflow) - [BERT performance comparison between different features on T4 and TensorFlow](#bert-performance-comparison-between-different-features-on-t4-and-tensorflow) - [BERT performance on A100 and PyTorch](#bert-performance-on-a100-and-pytorch) + - [BERT performance on A10 and PyTorch](#bert-performance-on-a10-and-pytorch) - [BERT performance on T4 and PyTorch](#bert-performance-on-t4-and-pytorch) - [BERT performance on V100 and PyTorch](#bert-performance-on-v100-and-pytorch) - [Performance on BERT Applications: SQuAD MRPC](#performance-on-bert-applications-squad-mrpc) @@ -38,13 +43,14 @@ The following configurations are supported in the FasterTransformer encoder. 
- Sequence length (S): smaller or equal to 4096. For INT8 mode=1, S should be a multiple of 32 when S > 384. - Size per head (N): Even number and smaller than 128. - Head number (H): Any number satisfying that H * N <= 1024 under FP32, or H * N <= 2048 under FP16. -- Data type: FP32, FP16 and INT8 +- Data type: FP32, FP16, BF16 and INT8 - Any number layer (N1) if the memory is enough -In the FasterTransformer v1.0, we provide a highly optimized BERT-equivalent encoder model. Next, based on the idea of [Effective Transformer](https://github.com/bytedance/effective_transformer), we further optimize BERT inference by removing the useless padding in FasterTransformer v2.1 and provide the Effective FasterTransformer. In FasterTransformer v3.0, we provide INT8 quantization inference to get better performance. In FasterTransformer v3.1, we optimize the INT8 kernels to improve the performance of INT8 inference and integrate the multi-head attention of TensorRT plugin into FasterTransformer. In FasterTransformer v4.0, we add the multi-head attention kernel to support FP16 on V100 and INT8 on T4, A100. The following graph demonstrates the flow chart of these optimization, except INT8. In FasterTransformer v5.0, we refactor the codes, encapsulating the mask building and padding removing into the Bert forward function, and add the sparsity feature of Ampere GPU to accelerate the GEMM. +In the FasterTransformer v1.0, we provide a highly optimized BERT-equivalent encoder model. Next, based on the idea of [Effective Transformer](https://github.com/bytedance/effective_transformer), we further optimize BERT inference by removing the useless padding in FasterTransformer v2.1 and provide the Effective FasterTransformer. In FasterTransformer v3.0, we provide INT8 quantization inference to get better performance. In FasterTransformer v3.1, we optimize the INT8 kernels to improve the performance of INT8 inference and integrate the multi-head attention of TensorRT plugin into FasterTransformer. In FasterTransformer v4.0, we add the multi-head attention kernel to support FP16 on V100 and INT8 on T4, A100. The following graph demonstrates the flow chart of these optimization, except INT8. In FasterTransformer v5.0, we refactor the codes, encapsulating the mask building and padding removing into the Bert forward function, and add the sparsity feature of Ampere GPU to accelerate the GEMM. In FasterTransformer v5.1, we support multi-node multi-GPU inference on Bert FP16.
Fig. 1 Flowchart of encoder.
The BERT model is proposed by google in 2018. The encoder of FasterTransformer is equivalent to BERT model, but do lots of optimization. The leftmost flow of Fig. 1 shows the optimization in FasterTransformer. After optimization, FasterTransformer only uses 8 or 6 gemms (blue blocks) and 6 custom CUDA kernels (green blocks) to implement one transformer block. @@ -54,40 +60,57 @@ To further improve the performance of multi head attention, we integrate the mul
Fig. 2 Effective Transformer.
Besides, we find that the padding would affect the accuracy for some tasks although they should be useless. So, we recommend removing the padding in the final outputs of downstream tasks. The arguments, inputs, and outputs of encoder: -* Arguments: - 1. Maximum batch size - 2. Maximum sequence length - 3. Head number - 4. Size per head - 5. Intermediate size. The inter size of feed forward network. It is often set to 4 * head_num * size_per_head. - 6. Number of decoder layers - 7. SM version of GPU device. Some kernel chosen depend on the GPU device. - 8. Query scaling. It is used to scale the query before the batch multiplication of query and key. - 9. CUDA stream. - 10. Pointer of cuBLAS wrapper, which is declared in `src/fastertransformer/utils/cublasMMWrapper.h`. - 11. Pointer of memory allocator, which is declared in `src/fastertransformer/utils/allocator.h` - 12. “is_free_buffer_after_forward” flag. If setting to be true, FasterTransformer will allocate buffer before forward, and free buffer after forward. If the memory is controlled by memory pool and the cost of allocating/releasing memory is small, setting the flag to be true can save some memory. - 13. Attention type. There are four different types. Users can use `getAttentionType(size_per_head, sm, remove_padding, seq_len)` to determine the attention type automatically. - 14. The flag of sparsity. This feature requires Ampere GPU and a sparse model. If setting to true, then FT will use sparse gemm to accelerate the gemm computation. - 15. Activation type. There are two options, GeLU and ReLU now. - 16. LayerNorm type. There are two options, post layernorm and pre layernorm. -* Inputs: - 1. Bert input feature. This feature should be after the embedding lookup and position embedding. The shape is \[ request batch size, maximum sequence length, hidden dimension \]. - 2. Sequence length. The shape is \[ request batch size \]. -* Outputs: - 1. Bert output feature. The shape is \[ request batch size, maximum sequence length, hidden dimension \]. +* Constructor of BERT + +| Classification | Name | Data Type | Description | +| :------------: | :--------------------------: | :----------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| [0] | max_batch_size | int | **Deprecated, move to input** | +| [1] | max_seq_len | int | **Deprecated, move to input** | +| [2] | head_num | int | Head number for model configuration | +| [3] | size_per_head | int | Size per head for model configuration | +| [4] | inter_size | int | The inter size of feed forward network. It is often set to 4 * head_num * size_per_head. | +| [5] | num_layer | int | Number of transformer layers for model configuration | +| [6] | sm | int | The compute capacity of GPU | +| [7] | q_scaling | float | It is used to scale the query before the batch multiplication of query and key | +| [8] | stream | cudaStream_t | CUDA stream | +| [9] | cublas_wrapper | cublasMMWrapper* | Pointer of cuBLAS wrapper, which is declared in `src/fastertransformer/utils/cublasMMWrapper.h` | +| [10] | allocator | IAllocator* | Pointer of memory allocator, which is declared in `src/fastertransformer/utils/allocator.h` | +| [11] | is_free_buffer_after_forward | bool | If setting to be `true`, FasterTransformer will allocate buffer before forward, and free buffer after forward. 
When the allocator is based on memory pool, setting to `true` may help reducing the memory usage during inference. | +| [12] | attention_type | AttentionType | Determine fusing the attention or not, remove padding or not, which is declared in `src/fastertransformer/layers/attention_layers/BaseAttentionLayer.h` | +| [13] | sparse | bool | Is using sparsity. **Experimental feature** | +| [14] | activation_type | ActivationType | Determine the activation in FFN, which is declared in `src/fastertransformer/layers/attention_layers/FfnLayer.h` | +| [15] | layernorm_type | LayerNormType | Determine using pre-layernorm or post-layernorm, which is declared in `src/fastertransformer/kernels/layernorm_kernels.h` | +| [16] | tensor_para | NcclParam | Tensor Parallel information, which is declared in `src/fastertransformer/utils/nccl_utils.h` | +| [17] | pipeline_para | NcclParam | Pipeline Parallel information, which is declared in `src/fastertransformer/utils/nccl_utils.h` | +| [18] | custom_all_reduce_comm | AbstractCustomComm | Custom all reduction communication for custom all reduction in model parallelism. It is only supported in 8-way tensor parallelism | +| [19] | enable_custom_all_reduce | int | Flag of enabling custom all reduction or not | + +* Input of BERT + +| Name | Tensor/Parameter Shape | Location | Data Type | Description | +| :----------------: | :-----------------------------------------------------: | :------: | :------------: | :-------------------------------: | +| input_hidden_state | [batch_size, sequence_length, head_num * size_per_head] | GPU | fp32/fp16/bf16 | The input of transformer layer | +| input_lengths | [batch_size] | GPU | int | The lengths of input_hidden_state | + +* Output of BERT + +| Name | Tensor/Parameter Shape | Location | Data Type | Description | +| :-----------------: | :-----------------------------------------------------: | :------: | :------------: | :-----------------------------: | +| output_hidden_state | [batch_size, sequence_length, head_num * size_per_head] | GPU | fp32/fp16/bf16 | The output of transformer layer | Besides, notice that the multi-head attention kernel from TensorRT is powerful but have some limitation. First, this kernel requires Turing or new GPU and the size per head must be 64. When the conditions are not satisfied, we use original multi-head attention implementation of FasterTransformer. Second, it requires an additional sequence length offset like fig 2 shows. More details are in [link](https://github.com/NVIDIA/TensorRT/tree/release/7.2/plugin/embLayerNormPlugin). When the input has padding, the shape of the sequence length offset is \[2 x B1 + 1 \]. Assume there are three sentences with sequence length s1, s2 and s3, and the sequence length after padding is S. Then the sequence length offset is \[0, s1, S, s2 + S, 2 x S, 2 x S + s3, 3 x S\]. On the other hand, when we remove the padding, the shape of the sequence length offset is \[B1 + 1\], and the sequence length offset is \[0, s1, s1 + s2, s1 + s2 + s3 \]. Namely, the sequence length offset records the sequence length for each sentence. When we have padding, we view the padding as some independent sentences. -In FasterTransformer v4.0, we implement two pipelines of INT8 inference, as shown in Fig. 3.. For int8_mode == 1 (int8v1), we don't quantize residual connection, use int32 as the output of int8 gemms and use per-channel quantization for weights. 
For int8_mode == 2 (int8v2), we quantize residual connection, use int8 as the output of int8 gemms and use per-tensor quantization for weights. Generally speaking, int8_mode == 1 will have higher accuracy while int8_mode == 2 will have better performance. +In FasterTransformer v4.0, we implement two pipelines of INT8 inference, as shown in Fig. 3. For int8_mode == 1 (int8v1), we don't quantize residual connection, use int32 as the output of int8 gemms and use per-channel quantization for weights. For int8_mode == 2 (int8v2), we quantize residual connection, use int8 as the output of int8 gemms and use per-tensor quantization for weights. Generally speaking, int8_mode == 1 will have higher accuracy while int8_mode == 2 will have better performance.
Fig. 3 Workflow of int8 inference.
+

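To make the difference between the two INT8 pipelines above more concrete, the following NumPy sketch contrasts per-channel weight quantization (as used by int8_mode == 1) with per-tensor weight quantization (as used by int8_mode == 2). The symmetric scale convention here is an assumption for illustration and is not the exact kernel implementation:

```python
import numpy as np

def quantize_per_tensor(w):
    # One scale for the whole weight matrix (int8_mode == 2 style).
    scale = np.abs(w).max() / 127.0
    return np.clip(np.round(w / scale), -127, 127).astype(np.int8), scale

def quantize_per_channel(w):
    # One scale per output channel (int8_mode == 1 style); w is [out_dim, in_dim].
    scales = np.abs(w).max(axis=1, keepdims=True) / 127.0
    return np.clip(np.round(w / scales), -127, 127).astype(np.int8), scales

w = np.random.randn(4, 8).astype(np.float32)
q_tensor, scale_tensor = quantize_per_tensor(w)
q_channel, scale_channel = quantize_per_channel(w)
# The per-channel scales track each row's dynamic range separately, which is
# why int8_mode == 1 generally preserves accuracy better, while the single
# per-tensor scale of int8_mode == 2 allows simpler, faster int8 GEMM epilogues.
```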
| feature | int8_mode == 1 | int8_mode == 2 | | :---------------------------------: | :------------: | :------------: | @@ -99,6 +122,8 @@ For INT8 inference, quantized model is needed. We provide TensorFlow quantizatio In FasterTransformer v5.0, we support the sparsity gemm to leverage the sparsity feature of Ampere GPU. We also provide an example on PyTorch. +In FasterTransformer v5.1, we support the multi-GPU multi-node inference for BERT model. + ## Setup The following section lists the requirements to use FasterTransformer BERT. @@ -107,7 +132,7 @@ The following section lists the requirements to use FasterTransformer BERT. - CMake >= 3.8 for Tensorflow, CMake >= 3.13 for PyTorch - CUDA 11.0 or newer version -- Python 3 is recommended because some features are not supported in python 2 +- Python: Only verify on python 3 - Tensorflow: Verify on 1.15, 1.13 and 1.14 should work. - PyTorch: Verify on 1.8.0, >= 1.5.0 should work. @@ -138,10 +163,20 @@ For those unable to use the NGC container, to set up the required environment or - `nvcr.io/nvidia/pytorch:20.07-py3` contains the PyTorch 1.6.0 and python 3.6 - `nvcr.io/nvidia/pytorch:20.12-py3` contains the PyTorch 1.8.0 and python 3.8 - To achieve best performance, we recommend to use the latest image. For example, running image `nvcr.io/nvidia/tensorflow:20.12-tf1-py3` by + To achieve best performance, we recommend to use the latest image. For example, running image `nvcr.io/nvidia/tensorflow:22.04-tf1-py3` by + + ```bash + nvidia-docker run -ti --rm nvcr.io/nvidia/tensorflow:22.04-tf1-py3 bash + git clone https://github.com/NVIDIA/FasterTransformer.git + mkdir -p FasterTransformer/build + cd FasterTransformer/build + git submodule init && git submodule update + ``` + + For pytorch, it is similar ```bash - nvidia-docker run -ti --rm nvcr.io/nvidia/tensorflow:20.12-tf1-py3 bash + nvidia-docker run -ti --rm nvcr.io/nvidia/pytorch:22.04-py3 bash git clone https://github.com/NVIDIA/FasterTransformer.git mkdir -p FasterTransformer/build cd FasterTransformer/build @@ -150,7 +185,19 @@ For those unable to use the NGC container, to set up the required environment or #### Build the project -* Note: the `xx` of `-DSM=xx` in following scripts means the compute capability of your GPU. For example, 60 (P40) or 61 (P4) or 70 (V100) or 75(T4) or 80 (A100). Default setting is including 70, 75, 80 and 86. +* Note: the `xx` of `-DSM=xx` in following scripts means the compute capability of your GPU. The following table shows the compute capability of common GPUs. + +| GPU | compute capacity | +| :---: | :--------------: | +| P40 | 60 | +| P4 | 61 | +| V100 | 70 | +| T4 | 75 | +| A100 | 80 | +| A30 | 80 | +| A10 | 86 | + +By default, `-DSM` is set by 70, 75, 80 and 86. When users set more kinds of `-DSM`, it requires longer time to compile. So, we suggest setting the `-DSM` for the device you use only. Here, we use `xx` as an example due to convenience. 1. build with C++ @@ -161,7 +208,7 @@ For those unable to use the NGC container, to set up the required environment or 2. build with TensorFlow - Uses need to set the path of TensorFlow. For example, if we use `nvcr.io/nvidia/tensorflow:20.12-tf1-py3`, then + Uses need to set the path of TensorFlow. For example, if we use `nvcr.io/nvidia/tensorflow:22.04-tf1-py3`, then ```bash cmake -DSM=xx -DCMAKE_BUILD_TYPE=Release -DBUILD_TF=ON -DTF_PATH=/usr/local/lib/python3.8/dist-packages/tensorflow_core/ .. @@ -171,7 +218,7 @@ For those unable to use the NGC container, to set up the required environment or 3. 
build with PyTorch ```bash - cmake -DSM=xx -DCMAKE_BUILD_TYPE=Release -DBUILD_PYT=ON .. + cmake -DSM=xx -DCMAKE_BUILD_TYPE=Release -DBUILD_PYT=ON -DBUILD_MULTI_GPU=ON .. make ``` @@ -188,28 +235,22 @@ For those unable to use the NGC container, to set up the required environment or ## How to use -### BERT process - -1. Run FasterTransformer BERT on C++ +### Run FasterTransformer BERT on C++ - 1.1 Generate the `gemm_config.in` file for FP32/FP16 and the `igemm_config.in` file for INT8 + 1. Generate the `gemm_config.in` file for FP32/FP16/BF16 and the `igemm_config.in` file for INT8 There are two methods to generate the best GEMM configuration - 1.1.1 Using `./bin/bert_gemm` to generate the best GEMM configuration. + 1.1 Using `./bin/bert_gemm` to generate the best GEMM configuration. - ```bash - ./bin/bert_gemm - ``` - - 1.1.2 Generate the best GEMM configuration when running BERT. - - When == 1, it will first check if the corresponding GEMM configuration exists, if not, it will automatically generate the best GEMM configuration. + Data Type = 0 (FP32) or 1 (FP16) or 2 (BF16) ```bash - ./bin/bert_example + ./bin/bert_gemm ``` + 1.2 Generate the best GEMM configuration when running BERT. + The generation of best GEMM configuration is recommended no matter what platform we use when we use FasterTransformer. If we do not generate the configure file, the FasterTransformer will use the default configuration and the inference speed may be slower. Assume the settings of the BERT are as follows: @@ -230,13 +271,14 @@ For those unable to use the NGC container, to set up the required environment or In the following subsection, we use the same settings and 12 transformer layers unless specified. - 1.2 Run FasterTransformer BERT under FP32 on C++ + 2 Run FasterTransformer BERT under FP32 on C++ `./bin/bert_example` runs the BERT in the `C++`. The arguments of `bert_example` is: ```bash - ./bin/bert_example + ./bin/bert_example ``` + Data Type = 0 (FP32) or 1 (FP16) or 2 (BF16) Then the following scripts can run the BERT under the above settings. @@ -253,18 +295,20 @@ For those unable to use the NGC container, to set up the required environment or [INFO] batch_size 32 seq_len 32 layer 12 FT-CPP-time 16.51 ms (100 iterations) ``` - 1.3 Run FasterTransformer BERT under FP16 on C++ + 3 Run FasterTransformer BERT under FP16/BF16 on C++ - So far, we use the FP32 to run the FasterTransformer. If we use the volta or newer NVIDIA GPU, we can use tensor core when we use the FP16. + So far, we use the FP32 to run the FasterTransformer. If we use the volta or newer NVIDIA GPU, we can use tensor core when we use the FP16. BF16 is only supported after Ampere NVIDIA GPU (SM 80). - To use the FP16, we only need to set the `` flag to 1 like following: + To use the FP16, we only need to set the `` flag to 1 like following: ```bash ./bin/bert_gemm 32 32 12 64 1 0 ./bin/bert_example 32 12 32 12 64 1 0 0 ``` - Note that the configuration of FP32 and FP16 are different, so we need to generate the configuration again. + To use the BF16, we only need to set the `` flag to 2. + + Note that the configuration of FP32 and FP16/BF16 are different, so we need to generate the configuration again. 
The outputs should be like to the following: @@ -275,7 +319,7 @@ For those unable to use the NGC container, to set up the required environment or [INFO] batch_size 32 seq_len 32 layer 12 FT-CPP-time 4.00 ms ``` - 1.4 Run FasterTransformer BERT under INT8 on C++ + 4 Run FasterTransformer BERT under INT8 on C++ If we use the Turing or newer NVIDIA GPUs, we can use tensor core when we use the INT8. @@ -311,7 +355,7 @@ For those unable to use the NGC container, to set up the required environment or [INFO] batch_size 32 seq_len 32 layer 12 FT-CPP-time 4.79 ms ( 50 iterations) ``` - 1.5 Run Effective FasterTransformer under FP32 on C++ + 5 Run Effective FasterTransformer under FP32 on C++ To use the Effective FasterTransformer, we only need to set the `` flag to 1 like following: @@ -322,14 +366,14 @@ For those unable to use the NGC container, to set up the required environment or The outputs should be like to the following: - ```bash + ```bash Device Tesla V100-PCIE-32GB before allocate free 29.46 GB total 31.75 GB After allocate free 29.40 GB used 2.35 GB total 31.75 GB [INFO] batch_size 32 seq_len 32 layer 12 FT-CPP-time 9.77 ms ``` - 1.6 Run Effective FasterTransformer under INT8 on C++ + 6 Run Effective FasterTransformer under INT8 on C++ To use the Effective FasterTransformer under INT8, we need to set the `` flag to 1 and `` flag to 1 or 2 like following. Since the sequence length in INT8 should be a multiple of 32, Effective FasterTransformer could be a good choice for INT8. @@ -345,7 +389,7 @@ For those unable to use the NGC container, to set up the required environment or The outputs should be like to the following: - ```bash + ```bash #For int8_mode == 1 Device Tesla T4 Device Tesla T4 @@ -362,9 +406,9 @@ For those unable to use the NGC container, to set up the required environment or [INFO] batch_size 32 seq_len 32 layer 12 FT-CPP-time 2.69 ms ( 50 iterations) ``` -2. Run FasterTransformer on TensorFlow (on T4 GPU) +### Run FasterTransformer BERT on TensorFlow - 2.1 Run FasterTransformer encoder under FP32 on TensorFlow + 1 Run FasterTransformer encoder under FP32 on TensorFlow ```bash ./bin/bert_gemm 32 32 12 64 0 0 @@ -395,7 +439,7 @@ For those unable to use the NGC container, to set up the required environment or Note: We can also generate the best GEMM configuration when running encoder by setting `--allow_gemm_test True`. Note: If users use Ampere GPUs, then TensorFlow will uses TF32 by default, and hence TensorFlow will be faster than FasterTransformer, and have large value differences. - 2.2 Run FasterTransformer BERT under FP16 on TensorFlow + 2 Run FasterTransformer BERT under FP16 on TensorFlow To use the FP16 in TensorFlow, we only need to set the `--data_type fp16` like following: @@ -425,7 +469,7 @@ For those unable to use the NGC container, to set up the required environment or [INFO] batch_size 32 max_seq_len 32 precision FP16 12 layer EFF-OP-while-time 4.98 ms ( 50 iterations) ``` - 2.3 Run FasterTransformer and Effective FasterTransformer encoder under INT8 on TensorFlow + 3 Run FasterTransformer and Effective FasterTransformer encoder under INT8 on TensorFlow To use the INT8 in TensorFlow, we only need to set the `--int8_mode 1` or `--int8_mode 2` like following: @@ -483,11 +527,11 @@ For those unable to use the NGC container, to set up the required environment or Note: since we do not use the correct scales for quantization in this test, the Cross Check between TF and FT should fail. 
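Conceptually, the cross check mentioned in the note above is an element-wise comparison between the TensorFlow output and the FasterTransformer output. A minimal sketch follows; the threshold value is an arbitrary assumption, not the one used by the test scripts:

```python
import numpy as np

def cross_check(ref_output, ft_output, threshold=0.01):
    # Compare a reference output (e.g. TensorFlow) against the
    # FasterTransformer output and report the differences.
    diff = np.abs(ref_output - ft_output)
    print(f"Max diff:  {diff.max()}")
    print(f"Mean diff: {diff.mean()}")
    return bool(diff.max() < threshold)
```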
- 2.5 Run FasterTransformer for GLUE dataset + 4 Run FasterTransformer for GLUE dataset This subsection demonstrates how to integrate the FasterTransformer in TensorFlow and evaluate the accuracy of FasterTransformer on GLUE dataset. To evaluate on GLUE dataset, it requires the repo of [BERT](https://github.com/google-research/bert). - 2.5.1 Prepare the BERT codes, Download the BERT pretrained model. + 4.1 Prepare the BERT codes, Download the BERT pretrained model. ```bash git clone https://github.com/google-research/bert.git tensorflow/tensorflow_bert/bert @@ -495,7 +539,7 @@ For those unable to use the NGC container, to set up the required environment or unzip uncased_L-12_H-768_A-12.zip ``` - 2.5.2 Download the GLUE MRPC dataset. Note that the file `download_glue_data.py` can only executed under python3. + 4.2 Download the GLUE MRPC dataset. Note that the file `download_glue_data.py` can only executed under python3. ```bash wget https://gist.githubusercontent.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e/raw/1502038877f6a88c225a34450793fbc3ea87eaba/download_glue_data.py @@ -504,7 +548,7 @@ For those unable to use the NGC container, to set up the required environment or Note: If the `download_glue_data.py` has some issues, try to use [this](https://gist.github.com/vlasenkoalexey/fef1601580f269eca73bf26a198595f3). - 2.5.3 Finetune the pretrained model on MRPC datasets. This takes some minutes. + 4.3 Finetune the pretrained model on MRPC datasets. This takes some minutes. ```bash export BERT_BASE_DIR=${PWD}/uncased_L-12_H-768_A-12 @@ -541,11 +585,11 @@ For those unable to use the NGC container, to set up the required environment or ``` - 2.5.4 Evaluate the accuracy of FasterTransformer under FP32 + 4.4 Evaluate the accuracy of FasterTransformer under FP32 To evaluate the accuracy of FasterTransformer, we can use `tensorflow/tensorflow_bert/run_classifier_wrap.py`. This file uses `run_classifier.py` of BERT repo, replacing the transformer model by FasterTransformer and add some additional arguments like `--floatx`. - ```bash + ```bash ./bin/bert_gemm 8 128 12 64 0 0 python ../examples/tensorflow/bert/tensorflow_bert/run_classifier_wrap.py \ --floatx=float32 \ @@ -575,7 +619,7 @@ For those unable to use the NGC container, to set up the required environment or I1204 01:25:59.203072 140314814412608 run_classifier.py:925] loss = 0.50261945 ``` - 2.5.5 Convert the finetuned checkpoint to FP16 and evaluate the accuracy of FasterTransformer under FP16. + 4.5 Convert the finetuned checkpoint to FP16 and evaluate the accuracy of FasterTransformer under FP16. To convert the checkpoint from FP32 to FP16, we can use `../examples/tensorflow/bert/tensorflow_bert/ckpt_type_convert.py` to convert the checkpoint. This file requires two arguments, the location of FP32 checkpoint, and the location putting the FP16 checkpoint. @@ -612,7 +656,7 @@ For those unable to use the NGC container, to set up the required environment or I1204 01:27:49.962604 139736813242176 run_classifier.py:925] loss = 0.5103358 ``` - 2.5.6 Compare the speed of BERT of TensorFlow and FasterTransformer under both FP32 and FP16. + 4.6 Compare the speed of BERT of TensorFlow and FasterTransformer under both FP32 and FP16. To compare the speed of TensorFlow and FasterTransformer on BERT model directly, we can use `../examples/tensorflow/bert/tensorflow_bert/profile_transformer_inferece.py`. 
@@ -649,11 +693,11 @@ For those unable to use the NGC container, to set up the required environment or average time (seconds) elapsed fast transformer: 0.009342837333679199 sec ``` - 2.7 Run FasterTransformer for SQuAD 1.1 dataset + 5 Run FasterTransformer for SQuAD 1.1 dataset This subsection demonstrates how to integrate the FasterTransformer in TensorFlow and evaluates the accuracy of FasterTransformer on SQuAD 1.1 dataset. To evaluate on SQuAD 1.1 dataset, it requires the repo of [BERT](https://github.com/google-research/bert). - 2.7.1 Prepare the BERT codes and download the fine-tuned model of SQuAD 1.1 from NGC + 5.1 Prepare the BERT codes and download the fine-tuned model of SQuAD 1.1 from NGC Because the training time of SQuAD is longer, and the NVIDIA NGC has provided the fine-tuned BERT model, we download the fine-tuned model directly. @@ -663,7 +707,7 @@ For those unable to use the NGC container, to set up the required environment or unzip bert_tf_ckpt_base_qa_squad11_amp_128_19.03.1.zip -d squad_model ``` - 2.7.2 Download the SQuAD dataset. + 5.2 Download the SQuAD dataset. ```bash mkdir squad_data @@ -671,7 +715,7 @@ For those unable to use the NGC container, to set up the required environment or wget -P squad_data https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json ``` - 2.7.3 Evaluate the accuracy of TensorFlow under FP32 + 5.3 Evaluate the accuracy of TensorFlow under FP32 ```bash python ../examples/tensorflow/bert/tensorflow_bert/bert/run_squad.py \ @@ -694,7 +738,7 @@ For those unable to use the NGC container, to set up the required environment or {"exact_match": 78.9120151371807, "f1": 86.22012390507868} ``` - 2.7.4 Evaluate the accuracy of FasterTransformer under FP32 + 5.4 Evaluate the accuracy of FasterTransformer under FP32 To evaluate the accuracy of FasterTransformer, we can use `../examples/tensorflow/bert/tensorflow_bert/run_squad_wrap.py`. This file uses `run_squad.py` of BERT repo, replacing the transformer model by FasterTransformer, and add some additional arguments like `--floatx`. @@ -722,7 +766,7 @@ For those unable to use the NGC container, to set up the required environment or {"exact_match": 78.9120151371807, "f1": 86.22012390507868} ``` - 2.7.5 Convert the checkpoint to FP16 and evaluate the accuracy of TensorFlow and FasterTransformer under FP16 + 5.5 Convert the checkpoint to FP16 and evaluate the accuracy of TensorFlow and FasterTransformer under FP16 To convert the checkpoint from FP32 to FP16, we can use `tensorflow/tensorflow_bert/ckpt_type_convert.py` to convert the checkpoint. This file requires two arguments, the location of FP32 checkpoint, and the location putting the FP16 checkpoint. @@ -752,216 +796,230 @@ For those unable to use the NGC container, to set up the required environment or {"exact_match": 79.03500473036897, "f1": 86.23027825772257} ``` - 2.7.6 Evaluate the accuracy of Effective FasterTransformer under FP16 - - Since the total sequence length is not fixed, we recommend using the default gemm configuration directly for Effective FasterTransformer. 
- - ```bash - python ../examples/tensorflow/bert/tensorflow_bert/run_squad_wrap.py \ - --floatx=float16 \ - --predict_batch_size=8 \ - --vocab_file=squad_model/vocab.txt \ - --bert_config_file=squad_model/bert_config.json \ - --init_checkpoint=squad_fp16_model/model.ckpt \ - --train_file=squad_data/train-v1.1.json \ - --do_predict=True \ - --predict_file=squad_data/dev-v1.1.json \ - --max_seq_length=384 \ - --output_dir=./squad_ft_output/fp_16/ \ - --allow_gemm_test=False \ - --remove_padding=True - - python ../examples/tensorflow/bert/tensorflow_bert/squad_evaluate_v1_1.py squad_data/dev-v1.1.json squad_ft_output/fp_16/predictions.json - ``` - - The results of TensorFlow would be like: - - ```bash - {"exact_match": 79.04446546830653, "f1": 86.23343183703513} - ``` - - 2.7.7 Evaluate the accuracy of FasterTransformer under INT8 - - Please refer to the directory `examples/tensorflow/bert/bert-quantization` first for how to get a quantized model. In `section 2.7.7` and `section 2.7.8`, to keep consistent with the procedures described in `examples/tensorflow/bert/bert-quantization`, we use `https://storage.googleapis.com/bert_models/2020_02_20/uncased_L-12_H-768_A-12.zip` as initial checkpoint with finetuned accuracy of == <89.57, 82.44>. - - In `bert-tf-quantization`, we give detailed procedure of Post Training Quantization (PTQ), Quantization Aware Training (QAT) and QAT with Knowledge-distillation. Since they have the same inference procedure, we use QAT-KD checkpoint to show how to evaluate the accuracy of FasterTransformer under INT8. - - Suppose we already fine-tuned a FP32 checkpoint using QAT-KD with int8_mode == 2 as described in `bert-quantization`. The path to checkpoint is `squad_model/QAT_KD_mode_2/`. - - We first convert the checkpoint from FP32 to FP16 (this step is not necessary, but it will give us a better performance) and then quantize the FP16 checkpoint using `../examples/tensorflow/bert/tensorflow_bert/ckpt_quantization.py`. This file requires three arguments, the location of initial checkpoint, the location putting the quantized checkpoint and the int8_mode. 
- - ```bash - python ../examples/tensorflow/bert/tensorflow_bert/ckpt_type_convert.py --init_checkpoint=squad_model/QAT_KD_mode_2/model.ckpt-27374 --fp16_checkpoint=squad_model/QAT_KD_mode_2_fp16/model.ckpt - - python ../examples/tensorflow/bert/tensorflow_bert/ckpt_quantization.py --init_checkpoint=squad_model/QAT_KD_mode_2_fp16/model.ckpt --quantized_checkpoint=squad_model/QAT_KD_mode_2_fp16_quantized/model.ckpt --int8_mode=2 - - ./bin/bert_gemm 8 384 12 64 1 1 - python ../examples/tensorflow/bert/tensorflow_bert/run_squad_wrap.py \ - --floatx=float16 \ - --predict_batch_size=8 \ - --vocab_file=squad_model/vocab.txt \ - --bert_config_file=squad_model/bert_config.json \ - --init_checkpoint=squad_model/QAT_KD_mode_2_fp16_quantized/model.ckpt \ - --train_file=squad_data/train-v1.1.json \ - --do_predict=True \ - --predict_file=squad_data/dev-v1.1.json \ - --max_seq_length=384 \ - --output_dir=./squad_ft_output/int8_mode_2/ \ - --int8_mode=2 \ - --allow_gemm_test=False - - python ../examples/tensorflow/bert/tensorflow_bert/squad_evaluate_v1_1.py squad_data/dev-v1.1.json squad_ft_output/int8_mode_2/predictions.json - ``` - - The results of TensorFlow would be like: - - ```bash - {"exact_match": 83.85052034058657, "f1": 90.46351799300075} - ``` - - 2.7.8 Evaluate the accuracy of Effective FasterTransformer under INT8 - - To evaluate the accuracy of Effective FasterTransformer under INT8, we follow the steps described in above section to get the correct checkpoint, and then run `../examples/tensorflow/bert/tensorflow_bert/run_squad_wrap.py` with `--remove_padding True` flag. - - ```bash - ./bin/bert_gemm 8 384 12 64 1 1 - python ../examples/tensorflow/bert/tensorflow_bert/run_squad_wrap.py \ - --floatx=float16 \ - --predict_batch_size=8 \ - --vocab_file=squad_model/vocab.txt \ - --bert_config_file=squad_model/bert_config.json \ - --init_checkpoint=squad_model/QAT_KD_mode_2_fp16_quantized/model.ckpt \ - --train_file=squad_data/train-v1.1.json \ - --do_predict=True \ - --predict_file=squad_data/dev-v1.1.json \ - --max_seq_length=384 \ - --output_dir=./squad_ft_output/int8_mode_2_effectiveTransformer/ \ - --remove_padding=True \ - --int8_mode=2 \ - --allow_gemm_test=False - - python ../examples/tensorflow/bert/tensorflow_bert/squad_evaluate_v1_1.py squad_data/dev-v1.1.json squad_ft_output/int8_mode_2_effectiveTransformer/predictions.json - ``` - - The results of TensorFlow would be like: - - ```bash - {"exact_match": 83.85052034058657, "f1": 90.46351799300075} - ``` - -3. Run FasterTransformer on PyTorch - - Please install HuggingFace's `transformers` first before running the demos by - ```bash - pip install transformers==2.5.1 - ``` - - 3.1 Generate the `gemm_config.in` file: - - ```bash - ./bin/bert_gemm - ./bin/bert_gemm 1 32 12 64 1 0 - ``` - If you want to use the library in other directory, please generate this file according to your setting and copy it to your working directory. - - 3.2 Run the PyTorch BERT sample: - - ```bash - python ../examples/pytorch/bert/bert_example.py <--fp16> <--int8_mode 0/1/2/3> <--sparse> <--time> - python ../examples/pytorch/bert/bert_example.py 1 12 32 12 64 --fp16 --time - ``` - - Remove `--fp16` for fp32 mode. `--int8_mode 1` or `--int8_mode 2` or `--int8_mode 3` will use int8_mode 1 or 2 or 3 in FasterTransformer. `--sparse` will use Ampere sparsity feature if FasterTransformer is built with sparsity support. 
- - The outputs should be like to the following: - - ```bash - FT Mean diff: 0.0004119873046875 - FT Max diff: 0.00830078125 - FT Min diff: 0.0 - EFF-FT Mean diff: 0.0004119873046875 - EFF-FT Max diff: 0.00830078125 - EFF-FT Min diff: 0.0 - [INFO] HuggingFaceEnocder time costs: 5.77 ms - [INFO] FasterTransformer time costs: 1.12 ms - [INFO] EFF-FasterTransformer time costs: 1.33 ms - ``` - - 3.3 Run the BERT MRPC sample code: - - ```bash - bash ../examples/pytorch/bert/scripts/run_mrpc.sh - ``` - the `` can be: - - `ori`: original HuggingFace's BERT encoder - - `ths`: original HuggingFace's BERT encoder in TorchScript mode - - `thsext`: our TorchScript custom class - - the `` can be `fp32` or `fp16` - - For example, run HuggingFace's BERT under FP32 by following scripts: - - ```bash - bash ../examples/pytorch/bert/scripts/run_mrpc.sh ori fp32 - ``` - - The outputs should be like to the following: - - ```bash - 06/28/2020 07:29:59 - INFO - __main__ - Evaluation for mrpc done in total 4.646116 secs (0.011388 sec per example) - 06/28/2020 07:29:59 - INFO - __main__ - ***** Eval results ***** - 06/28/2020 07:29:59 - INFO - __main__ - acc = 0.8284313725490197 - 06/28/2020 07:29:59 - INFO - __main__ - acc_and_f1 = 0.8556872581808643 - 06/28/2020 07:29:59 - INFO - __main__ - f1 = 0.8829431438127091 - ``` - - For example, run our PyTorch custom op under FP16 by following scripts: - - ```bash - bash ../examples/pytorch/bert/scripts/run_mrpc.sh thsext fp16 - ``` - - The outputs should be like to the following: - - ```bash - 06/28/2020 07:30:19 - INFO - __main__ - Evaluation for mrpc done in total 1.725153 secs (0.004228 sec per example) - 06/28/2020 07:30:19 - INFO - __main__ - ***** Eval results ***** - 06/28/2020 07:30:19 - INFO - __main__ - acc = 0.8284313725490197 - 06/28/2020 07:30:19 - INFO - __main__ - acc_and_f1 = 0.8556872581808643 - 06/28/2020 07:30:19 - INFO - __main__ - f1 = 0.8829431438127091 - ``` - - 3.4 Run the BERT SQuAD sample code: - - ```bash - bash ../examples/pytorch/bert/scripts/run_squad.sh --mtype --dtype --path --head_num --head_size --bs --seqlen --sparse --remove_padding - ``` - - the `` can be: - - `ori`: original HuggingFace's BERT encoder - - `ths`: original HuggingFace's BERT encoder in TorchScript mode - - `thsext`: our TorchScript custom class - - the `` can be `fp32` or `fp16` or `int8_1` or `int8_2` or `int8_3` - - `` is the directory containing the checkpoint - - ``, ``, ``, and `` are the model and running parameters - - `` can be `true` or `false` - - `` can be `true` or `false` - - For example, run our PyTorch custom op under FP16 by following scripts (using HuggingFace's checkpoint): - - ```bash - bash ../examples/pytorch/bert/scripts/run_squad.sh --mtype thsext --dtype fp16 - ``` - - If we want to do INT8 or sparse tests, please refer to the directory `examples/pytorch/bert/bert-quantization-sparsity` first for how to get a trained model. After we obtained a model checkpoint, we can run it by specify the `` and the related parameters. For example, run a BERT base model with INT8 mode 2 and sparse mode (need to do sparse training and QAT simultaneously for the checkpoint): - ```bash - bash ../examples/pytorch/bert/scripts/run_squad.sh \ - --mtype thsext \ - --dtype int8_2 \ - --path \ - --head_num 12 \ - -head_size 64 \ - --sparse true - ``` + 5.6 Evaluate the accuracy of Effective FasterTransformer under FP16 + + Since the total sequence length is not fixed, we recommend using the default gemm configuration directly for Effective FasterTransformer. 
+ + ```bash + python ../examples/tensorflow/bert/tensorflow_bert/run_squad_wrap.py \ + --floatx=float16 \ + --predict_batch_size=8 \ + --vocab_file=squad_model/vocab.txt \ + --bert_config_file=squad_model/bert_config.json \ + --init_checkpoint=squad_fp16_model/model.ckpt \ + --train_file=squad_data/train-v1.1.json \ + --do_predict=True \ + --predict_file=squad_data/dev-v1.1.json \ + --max_seq_length=384 \ + --output_dir=./squad_ft_output/fp_16/ \ + --allow_gemm_test=False \ + --remove_padding=True + + python ../examples/tensorflow/bert/tensorflow_bert/squad_evaluate_v1_1.py squad_data/dev-v1.1.json squad_ft_output/fp_16/predictions.json + ``` + + The results of TensorFlow would be like: + + ```bash + {"exact_match": 79.04446546830653, "f1": 86.23343183703513} + ``` + + 5.7 Evaluate the accuracy of FasterTransformer under INT8 + + Please refer to the directory `examples/tensorflow/bert/bert-quantization` first for how to get a quantized model. In `section 5.7` and `section 5.8`, to keep consistent with the procedures described in `examples/tensorflow/bert/bert-quantization`, we use `https://storage.googleapis.com/bert_models/2020_02_20/uncased_L-12_H-768_A-12.zip` as initial checkpoint with finetuned accuracy of == <89.57, 82.44>. + + In `bert-tf-quantization`, we give detailed procedure of Post Training Quantization (PTQ), Quantization Aware Training (QAT) and QAT with Knowledge-distillation. Since they have the same inference procedure, we use QAT-KD checkpoint to show how to evaluate the accuracy of FasterTransformer under INT8. + + Suppose we already fine-tuned a FP32 checkpoint using QAT-KD with int8_mode == 2 as described in `bert-quantization`. The path to checkpoint is `squad_model/QAT_KD_mode_2/`. + + We first convert the checkpoint from FP32 to FP16 (this step is not necessary, but it will give us a better performance) and then quantize the FP16 checkpoint using `../examples/tensorflow/bert/tensorflow_bert/ckpt_quantization.py`. This file requires three arguments, the location of initial checkpoint, the location putting the quantized checkpoint and the int8_mode. 
+ + ```bash + python ../examples/tensorflow/bert/tensorflow_bert/ckpt_type_convert.py --init_checkpoint=squad_model/QAT_KD_mode_2/model.ckpt-27374 --fp16_checkpoint=squad_model/QAT_KD_mode_2_fp16/model.ckpt + + python ../examples/tensorflow/bert/tensorflow_bert/ckpt_quantization.py --init_checkpoint=squad_model/QAT_KD_mode_2_fp16/model.ckpt --quantized_checkpoint=squad_model/QAT_KD_mode_2_fp16_quantized/model.ckpt --int8_mode=2 + + ./bin/bert_gemm 8 384 12 64 1 1 + python ../examples/tensorflow/bert/tensorflow_bert/run_squad_wrap.py \ + --floatx=float16 \ + --predict_batch_size=8 \ + --vocab_file=squad_model/vocab.txt \ + --bert_config_file=squad_model/bert_config.json \ + --init_checkpoint=squad_model/QAT_KD_mode_2_fp16_quantized/model.ckpt \ + --train_file=squad_data/train-v1.1.json \ + --do_predict=True \ + --predict_file=squad_data/dev-v1.1.json \ + --max_seq_length=384 \ + --output_dir=./squad_ft_output/int8_mode_2/ \ + --int8_mode=2 \ + --allow_gemm_test=False + + python ../examples/tensorflow/bert/tensorflow_bert/squad_evaluate_v1_1.py squad_data/dev-v1.1.json squad_ft_output/int8_mode_2/predictions.json + ``` + + The results of TensorFlow would be like: + + ```bash + {"exact_match": 83.85052034058657, "f1": 90.46351799300075} + ``` + + 5.8 Evaluate the accuracy of Effective FasterTransformer under INT8 + + To evaluate the accuracy of Effective FasterTransformer under INT8, we follow the steps described in above section to get the correct checkpoint, and then run `../examples/tensorflow/bert/tensorflow_bert/run_squad_wrap.py` with `--remove_padding True` flag. + + ```bash + ./bin/bert_gemm 8 384 12 64 1 1 + python ../examples/tensorflow/bert/tensorflow_bert/run_squad_wrap.py \ + --floatx=float16 \ + --predict_batch_size=8 \ + --vocab_file=squad_model/vocab.txt \ + --bert_config_file=squad_model/bert_config.json \ + --init_checkpoint=squad_model/QAT_KD_mode_2_fp16_quantized/model.ckpt \ + --train_file=squad_data/train-v1.1.json \ + --do_predict=True \ + --predict_file=squad_data/dev-v1.1.json \ + --max_seq_length=384 \ + --output_dir=./squad_ft_output/int8_mode_2_effectiveTransformer/ \ + --remove_padding=True \ + --int8_mode=2 \ + --allow_gemm_test=False + + python ../examples/tensorflow/bert/tensorflow_bert/squad_evaluate_v1_1.py squad_data/dev-v1.1.json squad_ft_output/int8_mode_2_effectiveTransformer/predictions.json + ``` + + The results of TensorFlow would be like: + + ```bash + {"exact_match": 83.85052034058657, "f1": 90.46351799300075} + ``` + +### Run FasterTransformer BERT on PyTorch + + Please install HuggingFace's `transformers` first before running the demos by + + ```bash + pip install transformers==2.5.1 + ``` + + 1 Generate the `gemm_config.in` file: + + ```bash + ./bin/bert_gemm + ./bin/bert_gemm 1 32 12 64 1 0 + ``` + Data Type = 0 (FP32) or 1 (FP16) or 2 (BF16) + + If you want to use the library in other directory, please generate this file according to your setting and copy it to your working directory. + + 2 Run the PyTorch BERT sample: + + ```bash + python ../examples/pytorch/bert/bert_example.py <--data_type fp32/fp16/bf16> <--int8_mode 0/1/2/3> <--sparse> <--time> + python ../examples/pytorch/bert/bert_example.py 1 12 32 12 64 --data_type fp16 --time + ``` + + Set `--data_type fp32` for fp32 mode, and set `--data_type bf16` for bf16 mode. `--int8_mode 1` or `--int8_mode 2` or `--int8_mode 3` will use int8_mode 1 or 2 or 3 in FasterTransformer. `--sparse` will use Ampere sparsity feature if FasterTransformer is built with sparsity support. 
+ + The outputs should be like to the following: + + ```bash + FT Mean diff: 0.0004119873046875 + FT Max diff: 0.00830078125 + FT Min diff: 0.0 + EFF-FT Mean diff: 0.0004119873046875 + EFF-FT Max diff: 0.00830078125 + EFF-FT Min diff: 0.0 + [INFO] HuggingFaceEnocder time costs: 5.77 ms + [INFO] FasterTransformer time costs: 1.12 ms + [INFO] EFF-FasterTransformer time costs: 1.33 ms + ``` + + 3 Run the BERT MRPC sample code: + + ```bash + bash ../examples/pytorch/bert/scripts/run_mrpc.sh + ``` + the `` can be: + - `ori`: original HuggingFace's BERT encoder + - `ths`: original HuggingFace's BERT encoder in TorchScript mode + - `thsext`: our TorchScript custom class + + the `` can be `fp32` or `fp16` or `bf16` + + For example, run HuggingFace's BERT under FP32 by following scripts: + + ```bash + bash ../examples/pytorch/bert/scripts/run_mrpc.sh ori fp32 + ``` + + The outputs should be like to the following: + + ```bash + 06/28/2020 07:29:59 - INFO - __main__ - Evaluation for mrpc done in total 4.646116 secs (0.011388 sec per example) + 06/28/2020 07:29:59 - INFO - __main__ - ***** Eval results ***** + 06/28/2020 07:29:59 - INFO - __main__ - acc = 0.8284313725490197 + 06/28/2020 07:29:59 - INFO - __main__ - acc_and_f1 = 0.8556872581808643 + 06/28/2020 07:29:59 - INFO - __main__ - f1 = 0.8829431438127091 + ``` + + For example, run our PyTorch custom op under FP16 by following scripts: + + ```bash + bash ../examples/pytorch/bert/scripts/run_mrpc.sh thsext fp16 + ``` + + The outputs should be like to the following: + + ```bash + 06/28/2020 07:30:19 - INFO - __main__ - Evaluation for mrpc done in total 1.725153 secs (0.004228 sec per example) + 06/28/2020 07:30:19 - INFO - __main__ - ***** Eval results ***** + 06/28/2020 07:30:19 - INFO - __main__ - acc = 0.8284313725490197 + 06/28/2020 07:30:19 - INFO - __main__ - acc_and_f1 = 0.8556872581808643 + 06/28/2020 07:30:19 - INFO - __main__ - f1 = 0.8829431438127091 + ``` + + 4 Run the BERT SQuAD sample code: + + ```bash + bash ../examples/pytorch/bert/scripts/run_squad.sh --mtype --dtype --path --head_num --head_size --bs --seqlen --sparse --remove_padding + ``` + - the `` can be: + - `ori`: original HuggingFace's BERT encoder + - `ths`: original HuggingFace's BERT encoder in TorchScript mode + - `thsext`: our TorchScript custom class + - the `` can be `fp32` or `fp16` or `bf16` or `int8_1` or `int8_2` or `int8_3` + - `` is the directory containing the checkpoint + - ``, ``, ``, and `` are the model and running parameters + - `` can be `true` or `false` + - `` can be `true` or `false` + + For example, run our PyTorch custom op under FP16 by following scripts (using HuggingFace's checkpoint): + + ```bash + bash ../examples/pytorch/bert/scripts/run_squad.sh --mtype thsext --dtype fp16 + ``` + + If we want to do INT8 or sparse tests, please refer to the directory `examples/pytorch/bert/bert-quantization-sparsity` first for how to get a trained model. After we obtained a model checkpoint, we can run it by specify the `` and the related parameters. For example, run a BERT base model with INT8 mode 2 and sparse mode (need to do sparse training and QAT simultaneously for the checkpoint): + ```bash + bash ../examples/pytorch/bert/scripts/run_squad.sh \ + --mtype thsext \ + --dtype int8_2 \ + --path \ + --head_num 12 \ + -head_size 64 \ + --sparse true + ``` + +### Run the PyTorch BERT sample with multi-GPU: + + Since v5.1, FasterTransformer supports multi-node multi-GPU inference on BERT model under FP32, FP16 and BF16. 
Users can use `--tensor_para_size` and `pipeline_para_size` to control the tensor parallelism and pipeline parallelism. + + ```bash + mpirun -n 4 python3 ../examples/pytorch/bert/bert_example.py 32 12 32 12 64 --data_type fp16 --tensor_para_size 2 --pipeline_para_size 2 + ``` + + * Note that multi-node inference is also supported. The usage is same to multi-GPU under MPI. + * Note that the performances under model parallelism are affected by the hardware and network significantly. If your GPUs are connected by PCIe, using model parallelism may bring few speedup or even worse performance. ## Performance @@ -969,6 +1027,7 @@ Hardware settings: * A100 (with mclk 1593MHz, pclk 1410MHz) with AMD EPYC 7742 64-Core Processor * T4 (with mclk 5000MHz, pclk 1590MHz) with Intel(R) Xeon(R) CPU E5-2670 0 @ 2.60GHz * V100 (with mclk 877MHz, pclk 1380MHz) with Intel(R) Xeon(R) CPU E5-2698 v4 @ 2.20GHz (dgx-1 server) +* A10 (with mclk 6251, pclk 1695) with AMD EPYC 7232P 8-Core Processor Note that the CPU may affect the performance when the batch size and sequence length are small. @@ -978,7 +1037,58 @@ To run the following benchmark, we need to install the unix computing tool "bc" apt-get install bc ``` -### BERT performance +### Multi-GPU BERT-6B performance on A100 and triton example + +These benchmarks are ran by `bert_triton_example.cc`. These benchmark compare the performance of different model parallelism size. The GPUs are connected by NVLink. Note that for model parallelism, there are large effects on hardware and network. + +The model configuration are: +* head_num = 32 +* size_per_head = 128 +* num_layers = 32 + +* Compare tensor parallelism (TP) and pipeline parallelism (PP) + +| Batch_size | Seq_len | Precision | TP1, PP1
Latency (ms) | TP2, PP1
Latency (ms) | TP1, PP2
Latency (ms) | TP2, PP 1
Speedup | TP1, PP 2
Speedup | +| :--------: | :-----: | :-------: | :-------------------------: | :-------------------------: | :-------------------------: | :---------------------: | :---------------------: | +| 1 | 32 | fp16 | 10.58 | 8.35 | 10.81 | 1.27 | 0.98 | +| 1 | 128 | fp16 | 10.87 | 9.39 | 11.14 | 1.16 | 0.98 | +| 1 | 384 | fp16 | 26.07 | 19.48 | 25.42 | 1.34 | 1.03 | +| 1 | 1024 | fp16 | 47.02 | 32.13 | 46.56 | 1.46 | 1.01 | +| 4 | 32 | fp16 | 12.44 | 10.91 | 16.4 | 1.14 | 0.76 | +| 4 | 128 | fp16 | 24.98 | 18.34 | 26.86 | 1.36 | 0.93 | +| 4 | 384 | fp16 | 61.99 | 38.19 | 50.46 | 1.62 | 1.23 | +| 32 | 32 | fp16 | 26.36 | 19.62 | 27.01 | 1.34 | 0.98 | +| 32 | 32 | fp16 | 26.5 | 19.67 | 27.51 | 1.35 | 0.96 | +| 32 | 128 | fp16 | 132 | 81.42 | 89.9 | 1.62 | 1.47 | +| 32 | 384 | fp16 | 409.41 | 233.69 | 282.46 | 1.75 | 1.45 | +| 32 | 1024 | fp16 | 1145.01 | 613.86 | 760.44 | 1.87 | 1.51 | +| 128 | 32 | fp16 | 112.37 | 70.65 | 87.77 | 1.59 | 1.28 | +| 128 | 128 | fp16 | 469.51 | 264.02 | 303.99 | 1.78 | 1.54 | +| 128 | 384 | fp16 | 1477.78 | 804.84 | 1020.95 | 1.84 | 1.45 | +| 128 | 1024 | fp16 | 4975.6 | 2629.93 | 3136.76 | 1.89 | 1.59 | + +* Compare scaling of different tensor parallelism (TP) + +| Batch_size | Seq_len | Precision | TP1, PP1
Latency (ms) | TP2, PP1
Latency (ms) | TP4, PP1
Latency (ms) | TP2, PP 1
Speedup | TP4, PP 1
Speedup | +| :--------: | :-----: | :-------: | :-------------------------: | :-------------------------: | :-------------------------: | :---------------------: | :---------------------: | +| 1 | 32 | fp16 | 10.58 | 8.35 | 6.21 | 1.27 | 1.70 | +| 1 | 128 | fp16 | 10.87 | 9.39 | 6.98 | 1.16 | 1.56 | +| 1 | 384 | fp16 | 26.07 | 19.48 | 14.65 | 1.34 | 1.78 | +| 1 | 1024 | fp16 | 47.02 | 32.13 | 21.29 | 1.46 | 2.21 | +| 4 | 32 | fp16 | 12.44 | 10.91 | 8.67 | 1.14 | 1.43 | +| 4 | 128 | fp16 | 24.98 | 18.34 | 14.12 | 1.36 | 1.77 | +| 4 | 384 | fp16 | 61.99 | 38.19 | 26.56 | 1.62 | 2.33 | +| 32 | 32 | fp16 | 26.36 | 19.62 | 14.55 | 1.34 | 1.81 | +| 32 | 32 | fp16 | 26.5 | 19.67 | 14.53 | 1.35 | 1.82 | +| 32 | 128 | fp16 | 132 | 81.42 | 48.93 | 1.62 | 2.70 | +| 32 | 384 | fp16 | 409.41 | 233.69 | 138.12 | 1.75 | 2.96 | +| 32 | 1024 | fp16 | 1145.01 | 613.86 | 364.4 | 1.87 | 3.14 | +| 128 | 32 | fp16 | 112.37 | 70.65 | 44.84 | 1.59 | 2.51 | +| 128 | 128 | fp16 | 469.51 | 264.02 | 161.36 | 1.78 | 2.91 | +| 128 | 384 | fp16 | 1477.78 | 804.84 | 477.71 | 1.84 | 3.09 | +| 128 | 1024 | fp16 | 4975.6 | 2629.93 | 1529.07 | 1.89 | 3.25 | + +### Single GPU BERT performance We demonstrate the inference time of FasterTransformer in C++, TensorFlow and PyTorch, and compared to the performance on A100, T4 and V100. @@ -999,6 +1109,7 @@ In the experiments of encoder, we updated the following parameters: * head_num = 12 * size_per_head = 64 * num_layers = 12 +* EFF-FT: Use remove padding. The avereage sequence length is set to half of `Seq_len`. For example, when `Seq_len` is 128, the real sequence length is 64, and other 64 tokens are paddings. #### BERT performance on A100 and TensorFlow @@ -1298,6 +1409,64 @@ User can use `export NVIDIA_TF32_OVERRIDE=1` to enforce the program run under TF | <32, 128> | 3.62 | 2.94 | 2.28 | 2.04 | 1.23 | 1.12 | | <32, 384> | 10.10 | 8.60 | 5.77 | 5.38 | 1.17 | 1.07 | +#### BERT performance on A10 and PyTorch + +* Performance on FP32 + +| Batch_size | Seq_len | Precision | TorchScript
Latency (ms) | FT
Latency (ms) | EFF-FT
Latency (ms) | FT
Speedup | EFF-FT
Speedup | +| :--------: | :-----: | :-------: | :----------------------------: | :-------------------: | :-----------------------: | :--------------: | :------------------: | +| 1 | 32 | FP32 | 3.98 | 2.08 | 2.31 | 1.91 | 1.72 | +| 1 | 128 | FP32 | 3.98 | 3.58 | 2.70 | 1.11 | 1.47 | +| 1 | 384 | FP32 | 9.38 | 7.95 | 3.83 | 1.17 | 2.44 | +| 8 | 32 | FP32 | 6.91 | 5.83 | 3.31 | 1.18 | 2.08 | +| 8 | 128 | FP32 | 20.30 | 16.82 | 12.94 | 1.20 | 1.56 | +| 8 | 384 | FP32 | 65.68 | 54.49 | 28.36 | 1.20 | 2.31 | +| 32 | 32 | FP32 | 20.48 | 17.13 | 10.51 | 1.19 | 1.94 | +| 32 | 128 | FP32 | 80.18 | 68.22 | 41.48 | 1.17 | 1.93 | +| 32 | 384 | FP32 | 264.69 | 213.87 | 141.08 | 1.23 | 1.87 | + +* Performance on FP16 + +| Batch_size | Seq_len | Precision | TorchScript
Latency (ms) | FT
Latency (ms) | EFF-FT
Latency (ms) | FT
Speedup | EFF-FT
Speedup | +| :--------: | :-----: | :-------: | :----------------------------: | :-------------------: | :-----------------------: | :--------------: | :------------------: | +| 1 | 32 | FP16 | 4.20 | 1.06 | 1.76 | 3.96 | 2.38 | +| 1 | 128 | FP16 | 4.38 | 1.02 | 1.83 | 4.29 | 2.39 | +| 1 | 384 | FP16 | 4.09 | 2.30 | 1.81 | 1.77 | 2.25 | +| 8 | 32 | FP16 | 3.76 | 1.54 | 1.92 | 2.44 | 1.95 | +| 8 | 128 | FP16 | 5.78 | 4.02 | 2.81 | 1.43 | 2.05 | +| 8 | 384 | FP16 | 18.43 | 10.91 | 4.52 | 1.68 | 4.07 | +| 32 | 32 | FP16 | 5.36 | 3.92 | 2.41 | 1.36 | 2.22 | +| 32 | 128 | FP16 | 21.59 | 15.30 | 8.13 | 1.41 | 2.65 | +| 32 | 384 | FP16 | 73.50 | 43.74 | 25.22 | 1.68 | 2.91 | + +* Performance on INT8-v1 + +| Batch_size | Seq_len | TorchScript-FP16
Latency (ms) | FT-INT8-v1
Latency (ms) | EFF-FT-INT8-v1
Latency (ms) | FT-INT8-v1
Speedup | EFF-FT-INT8-v1
Speedup | +| :--------: | :-----: | :---------------------------------: | :---------------------------: | :-------------------------------: | :----------------------: | :--------------------------: | +| 1 | 32 | 4.00 | 1.34 | 1.37 | 2.98 | 2.91 | +| 1 | 128 | 4.43 | 1.56 | 1.42 | 2.83 | 3.11 | +| 1 | 384 | 3.93 | 2.01 | 1.48 | 1.95 | 2.65 | +| 8 | 32 | 3.60 | 1.79 | 1.59 | 2.01 | 2.26 | +| 8 | 128 | 5.82 | 3.48 | 2.49 | 1.67 | 2.33 | +| 8 | 384 | 18.41 | 9.55 | 4.47 | 1.92 | 4.11 | +| 32 | 32 | 5.42 | 3.77 | 2.28 | 1.43 | 2.37 | +| 32 | 128 | 21.51 | 12.08 | 6.63 | 1.78 | 3.24 | +| 32 | 384 | 73.67 | 35.65 | 20.96 | 2.06 | 3.51 | + +* Performance on INT8-v2 + +| Batch_size | Seq_len | TorchScript-FP16
Latency (ms) | FT-INT8-v2
Latency (ms) | EFF-FT-INT8-v2
Latency (ms) | FT-INT8-v2
Speedup | EFF-FT-INT8-v2
Speedup | +| :--------: | :-----: | :---------------------------------: | :---------------------------: | :-------------------------------: | :----------------------: | :--------------------------: | +| 1 | 32 | 4.04 | 1.37 | 1.40 | 2.94 | 2.88 | +| 1 | 128 | 4.41 | 1.42 | 1.42 | 3.10 | 3.10 | +| 1 | 384 | 3.88 | 1.73 | 1.49 | 2.24 | 2.60 | +| 8 | 32 | 3.69 | 1.57 | 1.51 | 2.35 | 2.44 | +| 8 | 128 | 5.87 | 2.37 | 1.89 | 2.47 | 3.10 | +| 8 | 384 | 18.43 | 6.42 | 3.21 | 2.87 | 5.74 | +| 32 | 32 | 5.41 | 2.64 | 1.91 | 2.04 | 2.83 | +| 32 | 128 | 21.50 | 8.08 | 4.64 | 2.66 | 4.63 | +| 32 | 384 | 73.65 | 25.42 | 14.04 | 2.89 | 5.24 | + #### BERT performance on T4 and PyTorch * Performance on FP32 diff --git a/docs/decoder_guide.md b/docs/decoder_guide.md index 24b778913..ac3e7c2ce 100644 --- a/docs/decoder_guide.md +++ b/docs/decoder_guide.md @@ -245,8 +245,9 @@ For those unable to use the NGC container, to set up the required environment or `./bin/decoding_example` runs the decoding with beam search or sampling in the `C++`. The arguments of `decoding_example` is: ```bash - ./bin/decoding_example + ./bin/decoding_example ``` + Data Type = 0 (FP32) or 1 (FP16) or 2 (BF16) Then the following scripts can run the decoding with beam search under the above settings. @@ -277,17 +278,19 @@ For those unable to use the NGC container, to set up the required environment or [INFO] batch_size 32 beam_width 1 head_num 8 size_per_head 64 max_seq_len 32 num_layers 6 vocab_size 30000, top_k 0, top_p 0.500, FT-CPP-decoding-time 75.91 ms ``` - 1.3 Run decoding under FP16 on C++ + 1.3 Run decoding under FP16/BF16 on C++ - So far, we use the FP32 to run the FasterTransformer. If we use the volta or newer NVIDIA GPU, we can use tensor core to accelerate when we use the FP16. + So far, we use the FP32 to run the FasterTransformer. If we use the volta or newer NVIDIA GPU, we can use tensor core to accelerate when we use the FP16. BF16 is only supported after Ampere NVIDIA GPU (SM 80). - To use the FP16, we only need to set the `` flag to 1 like following: + To use the FP16, we only need to set the `` flag to 1 like following: ```bash ./bin/decoding_gemm 32 4 8 64 2048 30000 32 512 1 ./bin/decoding_example 32 4 8 64 2048 30000 6 32 32 512 0 0.0 1 ``` + To use the BF16, we only need to set the `` flag to 2. + Note that the configuration of FP32 and FP16 are different, so we need to generate the configuration again. The outputs should be like to the following: @@ -565,18 +568,20 @@ For those unable to use the NGC container, to set up the required environment or 3.1 Generate the `gemm_config.in` file: ```bash - ./bin/decoding_gemm + ./bin/decoding_gemm ./bin/decoding_gemm 8 4 8 64 2048 31538 32 512 1 ``` + + Data Type = 0 (FP32) or 1 (FP16) or 2 (BF16) + If you want to use the library in other directory, please generate this file according to your setting and copy it to your working directory. 3.2 Run the PyTorch decoder sample: ```bash - python ../examples/pytorch/decoder/decoder_example.py <--fp16> <--time> - python ../examples/pytorch/decoder/decoder_example.py 8 6 32 8 64 --fp16 --time + python ../examples/pytorch/decoder/decoder_example.py <--data_type fp32/fp16/bf16> <--time> + python ../examples/pytorch/decoder/decoder_example.py 8 6 32 8 64 --data_type fp16 --time ``` - Remove `--fp16` for fp32 mode. 
The outputs should be like to the following: @@ -592,10 +597,9 @@ For those unable to use the NGC container, to set up the required environment or 3.3 Run the PyTorch decoding sample: ```bash - python pytorch/decoding_sample.py <--fp16> <--time> - python ../examples/pytorch/decoding/decoding_example.py 8 6 32 8 64 4 31538 --fp16 --time + python pytorch/decoding_sample.py <--data_type fp32/fp16/bf16> <--time> + python ../examples/pytorch/decoding/decoding_example.py 8 6 32 8 64 4 31538 --data_type fp16 --time ``` - Remove `--fp16` for fp32 mode. The outputs should be like to the following: @@ -698,7 +702,7 @@ For those unable to use the NGC container, to set up the required environment or - `torch_decoding`: PyTorch version decoding with the method FasterTransformer decoding uses - `torch_decoding_with_decoder_ext`: PyTorch version decoding with the method FasterTransformer decoding uses but replace the decoder with the FasterTransformer decoder - the `` can be `fp32` or `fp16` + the `` can be `fp32` or `fp16` or `bf16` If you do not specify the output file, it only print to the standard output. diff --git a/docs/gpt_guide.md b/docs/gpt_guide.md index 71c4fab2d..7422a83c3 100644 --- a/docs/gpt_guide.md +++ b/docs/gpt_guide.md @@ -16,8 +16,17 @@ - [Build the project](#build-the-project) - [How to use](#how-to-use) - [Prepare](#prepare-1) + - [Download openai-gpt model and convert](#download-openai-gpt-model-and-convert) + - [Download megatron model and convert](#download-megatron-model-and-convert) + - [Download onnx model and convert](#download-onnx-model-and-convert) + - [Download huggingface gpt model and convert](#download-huggingface-gpt-model-and-convert) - [Run GPT](#run-gpt) + - [Run GPT with prompts](#run-gpt-with-prompts) + - [Run Meta OPT](#run-meta-opt) - [gpt with triton backend](#gpt-with-triton-backend) + - [Advanced features](#advanced-features) + - [generate different sentences and enable shared context](#generate-different-sentences-and-enable-shared-context) + - [Interactive generation](#interactive-generation) - [Performance](#performance) - [Large model inference with model parallel](#large-model-inference-with-model-parallel) - [Performance of Megatron-530B](#performance-of-megatron-530b) @@ -27,7 +36,6 @@ - [Performance of GPT-6.7B](#performance-of-gpt-67b) - [Performance of GPT-1.3B](#performance-of-gpt-13b) - [Performance of GPT-350M](#performance-of-gpt-350m) - - [TODO](#todo) ## Introduction @@ -45,6 +53,7 @@ GPT is a variant of Decoding model, which does not have the encoder module, cros * Data type * FP32 * FP16 + * BF16 * INT8 weight only PTQ for bs 1 and 2 * Feature * Multi-GPU multi-node inference @@ -64,6 +73,7 @@ GPT is a variant of Decoding model, which does not have the encoder module, cros
Fig 1. Workflow of GPT model.
+

Fig 1 demonstrates the workflow of FasterTransformer GPT. Different from BERT and the encoder-decoder structure, GPT receives some input ids as context and generates the respective output ids as the response. In this workflow, the major bottleneck is the GptDecoderLayer (transformer block) because the time increases linearly with the number of layers. In GPT-3, the GptDecoderLayer takes about 95% of the total time. @@ -77,6 +87,7 @@ FasterTransformer splits the whole workflow into 2 parts. The first one is “co
Fig 2. Comparison between different self attention.             Fig 3. Workflow of GPT with tensor parallelism. +

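The two-part workflow mentioned above can be sketched as a simple greedy generation loop: a context phase that processes the whole prompt once, followed by a per-token generation phase dominated by GptDecoderLayer. The `model.context_forward` and `model.decode_step` methods below are placeholders for illustration, not the FasterTransformer API:

```python
def generate(model, input_ids, output_seq_len, end_id):
    # Context phase: run the whole prompt once and build the k/v cache.
    logits, kv_cache = model.context_forward(input_ids)       # placeholder API
    tokens = list(input_ids)
    next_token = int(logits[-1].argmax())                      # greedy decoding for simplicity
    # Generation phase: one decoder pass per new token, reusing the cache.
    # output_seq_len counts the input tokens as well, as described above.
    while len(tokens) < output_seq_len and next_token != end_id:
        tokens.append(next_token)
        logits, kv_cache = model.decode_step(next_token, kv_cache)  # placeholder API
        next_token = int(logits.argmax())
    return tokens
```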
The following examples demonstrating how to run multi-GPU and multi-node GPT model. 1. `examples/cpp/multi_gpu_gpt_example.cc`: It uses MPI to organize all GPUs. @@ -94,62 +105,82 @@ In c++ example codes, we skip the step 4 and step 6, loading the request by `exa The source codes are put in `src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc`. The arguments, input tensors and output tensors of GPT: -* Arguments: - 1. Maximum batch size (Deprecated, move to input) - 2. Maximum sequence length (Deprecated, move to input) - 3. Maximum input sequence length (Deprecated, move to input) - 4. beam width for beam search. If setting b to be 1, then we don’t use beam search but use sampling. (Deprecated, move to input) - 5. Head number - 6. Size per head - 7. Intermediate size. The inter size of feed forward network. It is often set to 4 * head_num * size_per_head. - 8. Number of decoder layers - 9. Vocab size - 10. Start id of the vocabulary - 11. End id of the vocabulary - 12. Diversity rate of beam search. A hyper hyper-parameter for [simple diverse decoding](https://arxiv.org/pdf/1611.08562.pdf). (Deprecated, move to input) - 13. top_k value for top k sampling. (Deprecated, move to input) - 14. top_p value for top p sampling. (Deprecated, move to input) - 15. Random seed for sampling. (Deprecated, move to input) - 16. Temperature for logit. Setting to be 1.0 if you don’t want to apply the temperature. (Deprecated, move to input) - 17. Length penalty for logit. Setting to be 1.0 if you don’t want to apply the length penalty. (Deprecated, move to input) - 18. Repetition penalty for logit. Setting to be 1.0 if you don’t want to apply the repetition penalty. (Deprecated, move to input) - 19. Tensor Parallel information, which is declared in `src/fastertransformer/utils/nccl_utils.h`. - 20. Pipeline Parallel information, which is declared in `src/fastertransformer/utils/nccl_utils.h`. - 21. CUDA stream. - 22. Pointer of cuBLAS wrapper, which is declared in `src/fastertransformer/utils/cublasMMWrapper.h`. - 23. Pointer of memory allocator, which is declared in `src/fastertransformer/utils/allocator.h` - 24. “is_free_buffer_after_forward” flag. If setting to be true, FasterTransformer will allocate buffer before forward, and free buffer after forward. If the memory is controlled by memory pool and the cost of allocating/releasing memory is small, setting the flag to be true can save some memory. - 25. Pointer of CUDA device properties, which is used to get the properties of hardware like size of shared memory. - 26. Is using sparsity. - 27. Int8 mode. - 28. Custom all reduction communication for custom all reduction in model parallelism. It is only supported in 8-way tensor parallelism. - 29. Flag of enable custom all reduction or not. -* Input tensors: - 1. Input ids (context). The shape is \[ request batch size * beam width, request maximum input length \]. - 2. Input lengths. The shape is \[ request batch size * beam width \]. - 3. Maximum output sequence length. An integer to describe the largest number of tokens you hope for results. Note that it includes the input ids. - 4. Stop word list. When FT generates words in this list, it will stop the generation. An extension of stop id, optional. - 5. Start id in runtime. The shape is \[batch_size\] on cpu, optional. If FT receives this input, FT will replace default start id by it, optional. - 6. End id in runtime. The shape is \[batch_size\] on cpu, optional. If FT receives this input, FT will replace default end id by it, optional. - 7. 
top_k value for top k sampling. The shape is \[1\] or \[batch_size, 1\] on cpu, optional. - 8. top_p value for top p sampling. The shape is \[1\] or \[batch_size, 1\] on cpu, optional. - 9. Diversity rate of beam search (beam_search_diversity_rate). A hyper hyper-parameter for [simple diverse decoding](https://arxiv.org/pdf/1611.08562.pdf). [1] or \[batch_size, 1\] on cpu, optional. - 10. Temperature for logit (temperature). The sahpe \[1\] or \[batch_size, 1\] on cpu, optional. - 11. Length penalty for logit (len_penalty). The shape is \[1\] or \[batch_size, 1\] on cpu, optional - 12. Repetition penalty for logit (repetition_penalty). The shape is \[1\] or \[batch_size, 1\] on cpu, optional - 13. Random_seed \[1\] or \[batch_size, 1\] on cpu, optional - 14. Length of prefix soft prompt embedding. This describes how many tokens of soft prompt embedding in each sentence. The shape is \[batch_size\], optional. - 15. Prefix soft prompt embedding. FT will concat them with results of embedding lookup kernel. The shape is \[batch_size, max_prefix_soft_prompt_length, hidden_units\], optional. -* Output tensors: - 1. Output ids. The shape is \[batch size, beam width, maximum output sequence length \]. - 2. Parent ids. It is used to find the best path in beam search. It is deprecated now. - 3. Sequence lengths. The shape is \[batch size * beam width\]. It records the final sequence lengths of all sentences. - 4. Log probability for sampling. The shape is \[requested token number, batch size, beam \]. It records the log probability of logits at each step. Optional outputs in FP32. - 5. Cumulative log probability of generated senteces. The shape is \[batch size, beam\]. Optional outputs in FP32. - -The `beam_width` value is set by the output shape directly. When the `beam_width` is larger than 1, FT will use beam search to generate tokens; otherwise, FT will use topk or topp sampling. - -We also provide the module `Gpt` in `src/fastertransformer/models/gpt/Gpt.cc`, which is a GPT model without model parallelism. It does not need the arguments 19 and 20, while others are same. +* Constructor of GPT + +| Classification | Name | Data Type | Description | +| :------------: | :--------------------------: | :----------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| [0] | max_batch_size | size_t | **Deprecated, move to input** | +| [1] | max_seq_len | size_t | **Deprecated, move to input** | +| [2] | max_input_len | size_t | **Deprecated, move to input** | +| [3] | beam_width | size_t | **Deprecated, move to input** | +| [4] | head_num | size_t | Head number for model configuration | +| [5] | size_per_head | size_t | Size per head for model configuration | +| [6] | inter_size | size_t | The inter size of feed forward network. It is often set to 4 * head_num * size_per_head. | +| [7] | num_layer | size_t | Number of transformer layers for model configuration | +| [8] | vocab_size | int | Vocabulary size for model configuration | +| [9] | start_id | int | Start id for vocabulary | +| [10] | end_id | int | End id for vocabulary | +| [11] | prompt_learning_start_id | int | The start id of virtual token in p/prompt-tuning | +| [12] | prompt_learning_type | PromptLearningType | The type of prompt learning when we load the prompt embedding in constructor. 
FT supports `no_prompt`, `soft_prompt`, `prefix_prompt`, `p_prompt_tuning` now | +| [13] | gpt_variant_params | gptVariantParams | This structure defines some hyper-parameters of gpt layers, including type of layernorm and activation | +| [14] | beam_search_diversity_rate | float | **Deprecated, move to input** | +| [15] | top_k | size_t | **Deprecated, move to input** | +| [16] | top_p | float | **Deprecated, move to input** | +| [17] | random_seed | unsigned long long | **Deprecated, move to input** | +| [18] | temperature | float | **Deprecated, move to input** | +| [19] | len_penalty | float | **Deprecated, move to input** | +| [20] | repetition_penalty | float | **Deprecated, move to input** | +| [21] | tensor_para | NcclParam | Tensor Parallel information, which is declared in `src/fastertransformer/utils/nccl_utils.h` | +| [22] | pipeline_para | NcclParam | Pipeline Parallel information, which is declared in `src/fastertransformer/utils/nccl_utils.h` | +| [23] | stream | cudaStream_t | CUDA stream | +| [24] | cublas_wrapper | cublasMMWrapper* | Pointer of cuBLAS wrapper, which is declared in `src/fastertransformer/utils/cublasMMWrapper.h` | +| [25] | allocator | IAllocator* | Pointer of memory allocator, which is declared in `src/fastertransformer/utils/allocator.h` | +| [26] | is_free_buffer_after_forward | bool | If setting to be `true`, FasterTransformer will allocate buffer before forward, and free buffer after forward. When the allocator is based on memory pool, setting to `true` may help reducing the memory usage during inference. | +| [27] | cuda_device_prop | cudaDeviceProp* | Pointer of CUDA device properties, which is used to get the properties of hardware like size of shared memory | +| [28] | sparse | bool | Is using sparsity. **Experimental feature** | +| [29] | int8_mode | int | Using int8 weight only quantization or not. **Experimental feature** | +| [30] | custom_all_reduce_comm | AbstractCustomComm | Custom all reduction communication for custom all reduction in model parallelism. It is only supported in 8-way tensor parallelism | +| [31] | enable_custom_all_reduce | int | Flag of enabling custom all reduction or not | +| [32] | remove_padding | bool | Remove the padding of input ids or not in context phase. | +| [33] | shared_contexts_ratio | float | Ratio that controls the use of the shared contexts optimization. If the compact size (that accounts only for unique prompts) is less than ratio * batch size, use the optimized implementation. Setting shared_contexts_ratio=0 deactivate the optimization. | + +* Input of GPT + +| Name | Tensor/Parameter Shape | Location | Data Type | Description | +| :-----------------------------: | :-------------------------------------------: | :------: | :--------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| input_ids | [batch_size, max_input_length] | GPU | int | The input ids (context) | +| input_lengths | [batch_size] | GPU | int | The lengths of input ids | +| prompt_learning_task_name_ids | [batch_size] | CPU | int | **Optional**. Task name ids for prompt learning. | +| output_seq_len | [batch_size] | CPU | uint32_t | The largest number of tokens you hope for results. Note that it contains the input length | +| stop_words_list | [batch_size, 2, stop_words_length] | GPU | int | **Optional**. When FT generates words in this list, it will stop the generation. 
An extension of stop id | +| bad_words_list | [batch_size, 2, bad_words_length] | GPU | int | **Optional**. The words in the list will be When FT generates words in this list, it will stop the generation. An extension of stop id | +| start_id | [batch_size] | CPU | int | **Optional**. If FT receives this input, FT will replace default start id by it | +| end_id | [batch_size] | CPU | int | **Optional**. If FT receives this input, FT will replace default end id by it | +| runtime_top_k | [1] or [batch_size] | CPU | uint | **Optional**. top_k value for top k sampling | +| runtime_top_p | [1] or [batch_size] | CPU | float | **Optional**. top_p value for top p sampling | +| beam_search_diversity_rate | [1] or [batch_size] | CPU | float | **Optional**. A hyper hyper-parameter for [simple diverse decoding](https://arxiv.org/pdf/1611.08562.pdf) | +| temperature | [1] or [batch_size] | CPU | float | **Optional**. Temperature applied to logits for both beam search and sampling | +| len_penalty | [1] or [batch_size] | CPU | float | **Optional**. Length penalty applied to logits for only beam search | +| repetition_penalty | [1] or [batch_size] | CPU | float | **Optional**. Repetition penalty applied to logits for both beam search and sampling | +| random_seed | [1] or [batch_size] | CPU | unsigned long long int | **Optional**. Random seed to initialize the random table in sampling. | +| request_prompt_lengths | [batch_size], | CPU | int | **Optional**. Length of prefix soft prompt embedding. This describes how many tokens of soft prompt embedding in each sentence. | +| request_prompt_embedding | [batch_size, max_prompt_length, hidden_units] | GPU | float/half/bfloat16 | **Optional**. FT will concat them with results of embedding lookup kernel. For prefix soft prompt embedding, the type must be float; for p/prompt tuning, the type is same to weight. | +| request_prompt_type | [batch_size] | CPU | int | **Optional**. Prompt type of request. This is necessary when user pass the prompt embedding by input | +| is_return_context_cum_log_probs | [1] | CPU | bool | **Optional**. Return the cumulative log probability of context or not | +| session_len | [1] | CPU | uint32 | **Optional**. The maximum time length allowed during the whole interactive generation. Only used for interactive generation feature | +| continue_gen | [1] | CPU | bool | **Optional**. A flag to tell FasterTransformer to not discard previous tokens and continue producing token based on previous generations. Only used for interactive generation feature | +| memory_len | [1] | CPU | uint32 | **Optional**. The maximum time memory used in attention modules. Reduces the memory footprint but quality of generation might degrades. | + +* Output of GPT + +| Name | Tensor/Parameter Shape | Location | Data Type | Description | +| :--------------: | :----------------------------------------------: | :------: | :-------: | :-------------------------------------------------------------------------------: | +| output_ids | [batch_size, beam_width, max_output_seq_len] | GPU | int | The output ids. It contains the input_ids and generated ids | +| sequence_length | [batch_size, beam_width] | GPU | int | The lengths of output ids | +| output_log_probs | [batch_size, beam_width, request_output_seq_len] | GPU | float | **Optional**. It records the log probability of logits at each step for sampling. | +| cum_log_probs | [batch_size, beam_width] | GPU | float | **Optional**. 
Cumulative log probability of generated sentences | + +The `beam_width` value is set by the output shape directly. When the `beam_width` of `output_ids` is larger than 1, FT will use beam search to generate tokens; otherwise, FT will use topk or topp sampling. When the inputs of beam search and sampling is invalid, like beam width 1, top k 0, top p 0.0, FT will run greedy search automatically. ### Optimization @@ -167,11 +198,11 @@ The following guide demonstrates how to run the examples of c++, PyTorch and Tri - CMake >= 3.8 for Tensorflow, CMake >= 3.13 for PyTorch - CUDA 11.0 or newer version - NCCL 2.10 or newer version -- Python 3 is recommended because some features are not supported in python 2 +- Python: Only verify on python 3 - Tensorflow: Verify on 1.15, 1.13 and 1.14 should work. - PyTorch: Verify on 1.8.0, >= 1.5.0 should work. -Recommend use nvcr image like `nvcr.io/nvidia/tensorflow:21.11-tf1-py3` or `nvcr.io/nvidia/pytorch:21.11-py3`. +Recommend use nvcr image like `nvcr.io/nvidia/tensorflow:22.07-tf1-py3` or `nvcr.io/nvidia/pytorch:22.07-py3`. These components are readily available within the NGC TensorFlow Docker image below. @@ -195,10 +226,10 @@ For those unable to use the NGC container, to set up the required environment or You can choose the tensorflow version and python version you want. Here, we list some possible images: - To achieve best performance, we recommand to use the latest image. For example, running image `nvcr.io/nvidia/tensorflow:21.11-tf1-py3` by + To achieve best performance, we recommend to use the latest image. For example, running image `nvcr.io/nvidia/tensorflow:22.07-tf1-py3` by ```bash - nvidia-docker run -ti --rm nvcr.io/nvidia/tensorflow:21.11-tf1-py3 bash + nvidia-docker run -ti --rm nvcr.io/nvidia/tensorflow:22.07-tf1-py3 bash git clone https://github.com/NVIDIA/FasterTransformer.git mkdir -p FasterTransformer/build cd FasterTransformer/build @@ -207,7 +238,19 @@ For those unable to use the NGC container, to set up the required environment or #### Build the project -* Note: the `xx` of `-DSM=xx` in following scripts means the compute capability of your GPU. For example, 60 (P40) or 61 (P4) or 70 (V100) or 75(T4) or 80 (A100). Default setting is including 70, 75, 80 and 86. +* Note: the `xx` of `-DSM=xx` in following scripts means the compute capability of your GPU. The following table shows the compute capability of common GPUs. + +| GPU | compute capacity | +| :---: | :--------------: | +| P40 | 60 | +| P4 | 61 | +| V100 | 70 | +| T4 | 75 | +| A100 | 80 | +| A30 | 80 | +| A10 | 86 | + +By default, `-DSM` is set by 70, 75, 80 and 86. When users set more kinds of `-DSM`, it requires longer time to compile. So, we suggest setting the `-DSM` for the device you use only. Here, we use `xx` as an example due to convenience. 1. build with C++ @@ -218,7 +261,7 @@ For those unable to use the NGC container, to set up the required environment or 2. build with TensorFlow - Uses need to set the path of TensorFlow. For example, if we use `nvcr.io/nvidia/tensorflow:21.11-tf1-py3`, then + Uses need to set the path of TensorFlow. For example, if we use `nvcr.io/nvidia/tensorflow:22.07-tf1-py3`, then ```bash cmake -DSM=xx -DCMAKE_BUILD_TYPE=Release -DBUILD_TF=ON -DTF_PATH=/usr/local/lib/python3.8/dist-packages/tensorflow_core/ -DBUILD_MULTI_GPU=ON .. 
@@ -241,7 +284,7 @@ For those unable to use the NGC container, to set up the required environment or * Install required tools ```bash -pip install -r ../examples/pytorch/requirement.txt +pip install -r ../examples/pytorch/gpt/requirement.txt ``` To run the GPT on c, users need to convert the checkpoint of TensorFlow or PyTorch to binary files, and then load by FasterTransformer c api. Unfortunately, there is no published large model. So, users are only able to verify the correctness by smaller model. Currently, FasterTransformer provides two kinds of samples. First one is using the checkpoint of [OpenAI GPT-2 model](https://github.com/openai/gpt-2) (which is trained by TensorFlow); Another choice is using the checkpoint of [Megatron](https://github.com/NVIDIA/Megatron-LM) (which is trained by pytorch). @@ -255,9 +298,9 @@ wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json -P ../m wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt -P ../models ``` -* Downlaod openai-gpt model and convert +#### Download openai-gpt model and convert -To convert the OpenAI GPT model to binary, FasterTransformer provides a tool `sample/tensorflow/utils/openai_gpt_ckpt_convert.py` to convert the checkpoint. The converter requires the following arguemtns: +To convert the OpenAI GPT model to binary, FasterTransformer provides a tool `sample/tensorflow/utils/openai_gpt_ckpt_convert.py` to convert the checkpoint. The converter requires the following arguments: 1. `-i`: The path of megatron model 2. `-o`: The output path of converted model @@ -276,7 +319,7 @@ python ../examples/tensorflow/gpt/utils/openai_gpt_ckpt_converter.py -o ../model In the repo of OpenAI, they provide many models, including `124M`, `355M`, `774M` and `1558M` -* Download megatron model and convert +#### Download megatron model and convert To convert the Megatron GPT model to binary, FasterTransformer provides a tool `examples/pytorch/utils/megatron_ckpt_convert.py` to convert the checkpoint. @@ -284,47 +327,54 @@ To convert the Megatron GPT model to binary, FasterTransformer provides a tool ` wget --content-disposition https://api.ngc.nvidia.com/v2/models/nvidia/megatron_lm_345m/versions/v0.0/zip -O megatron_lm_345m_v0.0.zip mkdir -p ../models/megatron-models/345m unzip megatron_lm_345m_v0.0.zip -d ../models/megatron-models/345m -python ../examples/pytorch/gpt/utils/megatron_ckpt_convert.py -head_num 16 -i ../models/megatron-models/345m/release/ -o ../models/megatron-models/c-model/345m/ -t_g 1 -i_g 1 -python ../examples/pytorch/gpt/utils/megatron_ckpt_convert.py -head_num 16 -i ../models/megatron-models/345m/release/ -o ../models/megatron-models/c-model/345m/ -t_g 1 -i_g 8 +export PYTHONPATH=$PWD/..:${PYTHONPATH} +python ../examples/pytorch/gpt/utils/megatron_ckpt_convert.py \ + -head_num 16 \ + -i ../models/megatron-models/345m/release/ \ + -o ../models/megatron-models/c-model/345m/ \ + -t_g 1 \ + -i_g 1 \ + --vocab-path ../models/gpt2-vocab.json \ + --merges-path ../models/gpt2-merges.txt +python ../examples/pytorch/gpt/utils/megatron_ckpt_convert.py \ + -head_num 16 \ + -i ../models/megatron-models/345m/release/ \ + -o ../models/megatron-models/c-model/345m/ \ + -t_g 1 \ + -i_g 8 \ + --vocab-path ../models/gpt2-vocab.json \ + --merges-path ../models/gpt2-merges.txt ``` where `t_g` means the number GPUs of TP during training, and `i_g` means the number of GPUs for TP during inference. -Note that there are different checkpoint version of Megatron. The version of the checkpoint above is 0. 
If users have trained a model by themselves, the default version of latest Megatron is 3. To convert the checkpoint with version 3, please add `-checkpoint_version 3`. +Note that there are different checkpoint version of Megatron. The version of the checkpoint above is 0. -For model trained by pipeline parallelism, please use new checkpoint converter `megatron_ckpt_convert_2.py`. This converter is only able to convert the newer version of checkpoint. +For model trained by pipeline parallelism or the checkpoint version is 3, you don't need to specify head_num or checkpoint_version as it can retrieve from model_args. ```bash -python ../examples/pytorch/gpt/utils/megatron_ckpt_convert_2.py -i ../models/megatron-models/345m/release/ -o ../models/megatron-models/c-model/345m/ -i_g 1 +python ../examples/pytorch/gpt/utils/megatron_ckpt_convert.py -i ../models/megatron-models/345m/release/ -o ../models/megatron-models/c-model/345m/ -i_g 1 ``` -* How to use `checkpoint_saver_fastertransformer.py` to convert the megatron model. Note that this tool is only available for newer checkpoint. Need to get more details from ADLR team. - -```bash -git clone -b checkpoint_util https://gitlab-master.nvidia.com/ADLR/megatron-lm.git # This is still an internal tool. -cp ../examples/pytorch/gpt/utils/checkpoint_saver_fastertransformer.py megatron-lm/tools -cd megatron-lm -python tools/checkpoint_util.py --model-type GPT --loader megatron --saver fastertransformer --input ../megatron_new_ckpt/357m-pipeline-2-tensor-2/ --output ../tmp --target-tensor-parallel-size 1 -``` +#### Download onnx model and convert -* Download onnx model and convert +Note that the original `gpt2-10.onnx` model at `https://github.com/onnx/models/raw/master/text/machine_comprehension/gpt-2/model/gpt2-10.onnx` is removed. And new link `https://github.com/onnx/models/blob/main/text/machine_comprehension/gpt-2/model/gpt2-10.onnx` cannot be loaded by onnx successfully. -To convert the Megatron GPT model to binary, FasterTransformer provides a tool `examples/onnx/multi_gpu_gpt/onnx_ckpt_convert.py` to convert the checkpoint. +To convert the ONNX GPT model to binary, FasterTransformer provides a tool `examples/onnx/multi_gpu_gpt/onnx_ckpt_convert.py` to convert the checkpoint. ```bash -wget https://github.com/onnx/models/raw/main/text/machine_comprehension/gpt-2/model/gpt2-10.onnx +wget https://github.com/onnx/models/blob/main/text/machine_comprehension/gpt-2/model/gpt2-10.onnx python ../examples/onnx/multi_gpu_gpt/onnx_ckpt_convert.py -i gpt2-10.onnx -o ../models/onnx-models/c-model/124m/ -i_g 1 python ../examples/onnx/multi_gpu_gpt/onnx_ckpt_convert.py -i gpt2-10.onnx -o ../models/onnx-models/c-model/124m/ -i_g 4 ``` -* Downlaod huggingface gpt model and convert +#### Download huggingface gpt model and convert ```bash git clone https://huggingface.co/gpt2-xl python ../examples/pytorch/gpt/utils/huggingface_gpt_convert.py -i gpt2-xl/ -o ../models/huggingface-models/c-model/gpt2-xl -i_g 1 ``` - ### Run GPT 1. Run GPT under on C++ with multiple gpu @@ -353,7 +403,7 @@ python ../examples/pytorch/gpt/utils/huggingface_gpt_convert.py -i gpt2-xl/ -o . python ../examples/pytorch/gpt/utils/gpt_token_converter.py --vocab_file=../models/gpt2-vocab.json --bpe_file=../models/gpt2-merges.txt ``` - By setting the `is_half` of `gpt_config.ini` to 1, users can run gpt model under fp16. + By setting the `data_type` of `gpt_config.ini` to `fp16` or `bf16`, users can run gpt model under fp16 or bf16. 
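   A minimal sketch of switching the C++ example to FP16 is shown below. It assumes the configuration file is `../examples/cpp/multi_gpu_gpt/gpt_config.ini`, that its precision key is literally named `data_type`, and that the example binary built earlier is `./bin/multi_gpu_gpt_example`; adjust these names to match your build.

   ```bash
   # Set the precision key to fp16 (the key name `data_type` is assumed).
   sed -i 's/^data_type[[:space:]]*=.*/data_type=fp16/' ../examples/cpp/multi_gpu_gpt/gpt_config.ini

   # Re-run the C++ GPT example on a single GPU with the updated configuration.
   mpirun -n 1 --allow-run-as-root ./bin/multi_gpu_gpt_example
   ```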
1.3 Run with tensor parallelism (TP), pipeline parallelism (PP) @@ -370,9 +420,9 @@ python ../examples/pytorch/gpt/utils/huggingface_gpt_convert.py -i gpt2-xl/ -o . ```bash srun -N2 -n2 -t 600 --pty bash # Assume we get 2 nodes: prm-dgx-09 and prm-dgx-10 - srun -N2 -n2 docker pull nvcr.io/nvidia/tensorflow:20.07-tf1-py3 + srun -N2 -n2 docker pull nvcr.io/nvidia/tensorflow:22.07-tf1-py3 - srun -N2 -n2 nvidia-docker run -itd --rm --privileged --network=host --pid=host --cap-add=IPC_LOCK --device=/dev/infiniband -v $PWD:$PWD -w $PWD --name ft-test nvcr.io/nvidia/tensorflow:21.11-tf1-py3 /bin/bash + srun -N2 -n2 nvidia-docker run -itd --rm --privileged --network=host --pid=host --cap-add=IPC_LOCK --device=/dev/infiniband -v $PWD:$PWD -w $PWD --name ft-test nvcr.io/nvidia/tensorflow:22.07-tf1-py3 /bin/bash srun -N2 -n2 nvidia-docker exec -i --env SLURM_NTASKS --env SLURM_NODEID --env SLURM_PROCID --env SLURM_STEP_NODELIST --env SLURMD_NODENAME --privileged ft-test bash -c "mkdir /root/.ssh && cp $PWD/ssh/* /root/.ssh && chmod 700 /root/.ssh && chmod 640 /root/.ssh/authorized_keys2 && chmod 400 /root/.ssh/id_rsa && apt-get update && apt-get install ssh -y && mkdir /run/sshd/ && /usr/sbin/sshd -p 11068 && nvidia-smi -lgc 1530" @@ -386,7 +436,7 @@ python ../examples/pytorch/gpt/utils/huggingface_gpt_convert.py -i gpt2-xl/ -o . Basically, `gpt_example.py` includes the example how to declare a model, load a ckeckpoint, and forward context inputs and get generated outputs in Pytorch. - For generating outputs based on context inputs, create a text file including the context inputs (line by line) and set `--sample_file_input` to the text file path. (By default, the script will generate outputs without context inputs.) Set `--sample_file_output` to write the outputs to a file. Use `--fp_16` to run in FP16. + For generating outputs based on context inputs, create a text file including the context inputs (line by line) and set `--sample_file_input` to the text file path. (By default, the script will generate outputs without context inputs.) Set `--sample_file_output` to write the outputs to a file. Use `--data_type fp16/bf16` to run in FP16 or BF16. Run with `-h` to see more settings. ```bash @@ -412,7 +462,7 @@ python ../examples/pytorch/gpt/utils/huggingface_gpt_convert.py -i gpt2-xl/ -o . #### Set up in interactive mode ```bash - srun -A devtech -J devtech-gpt:gpt -p luna -N1 --mpi=pmix --ntasks-per-node=8 --container-image nvcr.io/nvidia/pytorch:20.12-py3 --container-mounts /lustre/fsw/devtech/hpc-devtech/dahn/FasterTransformer:/workspace/fastertransformer --container-workdir /workspace/fastertransformer --pty bash + srun -A devtech -J devtech-gpt:gpt -p luna -N1 --mpi=pmix --ntasks-per-node=8 --container-image nvcr.io/nvidia/pytorch:22.07-py3 --container-mounts /lustre/fsw/devtech/hpc-devtech/dahn/FasterTransformer:/workspace/fastertransformer --container-workdir /workspace/fastertransformer --pty bash mkdir build && cd build cmake -DSM=80 -DCMAKE_BUILD_TYPE=Release -DBUILD_PYT=ON .. && make -j12 @@ -422,19 +472,43 @@ python ../examples/pytorch/gpt/utils/huggingface_gpt_convert.py -i gpt2-xl/ -o . 
* tensor_para_size=8, pipeline_para_size=1 ```bash - srun -A devtech -p luna -N1 --mpi=pmix --ntasks-per-node=8 --container-image nvcr.io/nvidia/pytorch:20.12-py3 --container-mounts /lustre/fsw/devtech/hpc-devtech/dahn/FasterTransformer:/workspace/fastertransformer --container-workdir /workspace/fastertransformer/build python ../examples/pytorch/gpt/multi_gpu_gpt_example.py --tensor_para_size=8 --pipeline_para_size=1 --ckpt_path="/workspace/fastertransformer/models/megatron-models/c-model/345m/8-gpu" + srun -A devtech -p luna -N1 --mpi=pmix --ntasks-per-node=8 --container-image nvcr.io/nvidia/pytorch:22.07-py3 --container-mounts /lustre/fsw/devtech/hpc-devtech/dahn/FasterTransformer:/workspace/fastertransformer --container-workdir /workspace/fastertransformer/build python ../examples/pytorch/gpt/multi_gpu_gpt_example.py --tensor_para_size=8 --pipeline_para_size=1 --ckpt_path="/workspace/fastertransformer/models/megatron-models/c-model/345m/8-gpu" ``` #### Run on multi-node * tensor_para_size=8, pipeline_para_size=2 ```bash - srun -A devtech -p luna -N2 --mpi=pmix --ntasks-per-node=8 --container-image nvcr.io/nvidia/pytorch:20.12-py3 --container-mounts /lustre/fsw/devtech/hpc-devtech/dahn/FasterTransformer:/workspace/fastertransformer --container-workdir /workspace/fastertransformer/build python ../examples/pytorch/gpt/multi_gpu_gpt_example.py --tensor_para_size=8 --pipeline_para_size=2 --ckpt_path="/workspace/fastertransformer/models/megatron-models/c-model/345m/8-gpu" + srun -A devtech -p luna -N2 --mpi=pmix --ntasks-per-node=8 --container-image nvcr.io/nvidia/pytorch:22.07-py3 --container-mounts /lustre/fsw/devtech/hpc-devtech/dahn/FasterTransformer:/workspace/fastertransformer --container-workdir /workspace/fastertransformer/build python ../examples/pytorch/gpt/multi_gpu_gpt_example.py --tensor_para_size=8 --pipeline_para_size=2 --ckpt_path="/workspace/fastertransformer/models/megatron-models/c-model/345m/8-gpu" + ``` + + 2.2 Run LAMBADA test on PyTorch + + download data set: + + ```bash + wget https://github.com/cybertronai/bflm/raw/master/lambada_test.jsonl -P ../models/megatron-models + export PYTHONPATH=$PWD/../:$PYTHONPATH + python ../examples/pytorch/gpt/utils/update_gpt_config.py \ + --model-dir ../models/megatron-models/c-model/345m/1-gpu/ \ + --config-ini-path ../models/megatron-models/c-model/345m/1-gpu/config.ini \ + --pipeline-para-size 1 \ + --tensor-para-size 1 \ + --max-seq-len 512 \ + --beam-width 1 \ + --sampling-top-k 1 \ + --sampling-top-p 0 \ + --data-type fp16 + python ../examples/pytorch/gpt/lambada_task_example.py \ + --batch-size 64 \ + --checkpoint-path ../models/megatron-models/c-model/345m/1-gpu/ \ + --lib-path lib/libth_parallel_gpt.so \ + --lambada-path ../models/megatron-models/lambada_test.jsonl ``` 3. Run GPT on tensorflow - Note that the tensorflow op only supports single gpu. + Follow [Download openai-gpt model and convert](#download-openai-gpt-model-and-convert) to prepare the model. Assume the TF model is put in `../models/openai-gpt-models/`. ```bash ./bin/gpt_gemm 4 1 32 12 64 3072 50257 1 1 @@ -442,22 +516,328 @@ python ../examples/pytorch/gpt/utils/huggingface_gpt_convert.py -i gpt2-xl/ -o . --length=32 \ --top_k=4 \ --top_p=0.6 \ - --data_type=fp16 + --data_type=fp16 \ + --models_dir=../models/openai-gpt-models/ ``` -4. Run LAMBADA test + Note that the tensorflow op only supports single gpu. - download data set: +### Run GPT with prompts + +GPT now supports p/prompt-tuning. 
It works with [nemo checkpoint and prompt learning](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/nlp/prompt_learning.html). + +1. Convert the prompt weights + + Use the `examples/pytorch/gpt/utils/nemo_ckpt_convert.py` to convert the NeMo Megatron Prompt Weights. + It will automatically generate configuration needed for triton backend inference. + + Note that you need to specify `start_id`, `end_id` by yourself in order to make sure that it is consistent with the tokenizer. + +2. Run GPT with C++ example + + You need to specify the example gpt_config.ini like below to enable the p/prompt_tuning feature. + + ```ini + [gptj_6B] + head_num=16 + size_per_head=256 + vocab_size=50400 + decoder_layers=28 + rotary_embedding=64 + start_id=50256 + end_id=50256 + inter_size=16384 + num_tasks=2 + prompt_learning_type=2 + + ;prompt learning example (soft prompt doesn't need it) + [gptj_6B_task_0] + task_name=task_0 + prompt_length=5 + + [gptj_6B_task_1] + task_name=task_1 + prompt_length=10 + ``` + + `task_name` and `prompt_length` are specified for loading prompt weights. + `prompt_learning_start_id` is needed for checking whether ids are prompts or normal input ids. + + **prompt_learning_type**: + + - no prompt: 0 + - soft_prompt: 1 + - prefix_prompt: 2 + - p/prompt_tuning: 3 + +### Run Meta OPT + +Meta OPT and OpenAI GPT do not have big differences in terms of structures, so they are sharing the same model and triton backend classes. \ +You need to convert the Huggingface Meta Opt models to fastertransformer format by `examples/pytorch/gpt/utils/huggingface_opt_convert.py`. + +1. Run OPT under on C++ with multiple gpu + + Users can see the details of arguments in `examples/cpp/multi_gpu_gpt/gpt_config.ini`. It controls the model path, model size, tensor parallelism size, and some hyper-parameters.\ + In order to run with Meta Opt models, you need to add additional configuraitons: `model_variant`, which controls the `layernorm_eps, layernorm_type, activation_type, has_post_decoder_layernorm`. + + For example, the opt 125m model configuraitons would be like: + ```ini + [opt_125M] + head_num=12 + size_per_head=64 + vocab_size=50272 + decoder_layers=12 + start_id=2 + end_id=2 + inter_size=3072 + model_variant=opt-pre ;define variant structure + ``` + There are two model types: opt-pre = [pre_layernorm](https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py#L332), opt_post = [post_layernorm](https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py#L323)\ + **Note that:** [the model has post decoder layernorm when layernorm_type is pre_layernorm](https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py#L498). + +2. Run OPT on PyTorch + + We can run summarization task examples of meta opt models. See `examples/pytorch/gpt/opt_summarization.py`. + + Note that the summarization test are ran by topk = 2, so the rouge score of HF and FT are often different. 
+ + * Run on opt-125m model ```bash - wget https://github.com/cybertronai/bflm/raw/master/lambada_test.jsonl -P ../models/megatron-models - bash ../examples/pytorch/gpt/scripts/evaluate_zeroshot_gpt.sh + git lfs clone https://huggingface.co/facebook/opt-125m + python ../examples/pytorch/gpt/utils/huggingface_opt_convert.py \ + -i opt-125m/ \ + -o opt-125m/c-model/ \ + -i_g 1 + python3 ../examples/pytorch/gpt/opt_summarization.py \ + --summarize \ + --test_hf \ + --max_ite 20 \ + --ft_model_location opt-125m/c-model \ + --hf_model_name opt-125m + ``` + + The results are similar to: + + ``` + Hugging Face (total latency: 9.258284 sec) + rouge1 : 20.36984889475218 + rouge2 : 4.854345624891912 + rougeL : 14.82866480289381 + rougeLsum : 18.23638863809613 + Faster Transformers (total latency: 3.9376330000000004 sec) + rouge1 : 26.676168312282357 + rouge2 : 10.004052949342602 + rougeL : 19.20934213532261 + rougeLsum : 24.243496576656323 + ``` + + * Run on opt-350m model + + ```bash + git lfs clone https://huggingface.co/facebook/opt-350m + python ../examples/pytorch/gpt/utils/huggingface_opt_convert.py \ + -i opt-350m/ \ + -o opt-350m/c-model/ \ + -i_g 1 + python3 ../examples/pytorch/gpt/opt_summarization.py \ + --summarize \ + --test_hf \ + --max_ite 20 \ + --ft_model_location opt-350m/c-model \ + --hf_model_name opt-350m \ + --data_type fp16 + ``` + + The results are similar to: + + ``` + Hugging Face (total latency: 21.961627 sec) + rouge1 : 28.939621379501467 + rouge2 : 9.858278077813752 + rougeL : 19.159853526952528 + rougeLsum : 26.120654334830885 + Faster Transformers (total latency: 6.293255999999998 sec) + rouge1 : 26.80687566772978 + rouge2 : 8.639787737378661 + rougeL : 18.90520115636779 + rougeLsum : 24.372302912676407 + ``` + +3. Run OPT with Triton Backends + + Model configurations have been automatically generated when converting the [meta opt models](https://huggingface.co/docs/transformers/model_doc/opt).\ + Then, you can use the converted weights and configuration file to serve the model by triton servers. + Example of the `config.ini` when converting the model: + ```ini + [gpt] + model_name = opt-350m/ + head_num = 16 + size_per_head = 64 + inter_size = 4096 + max_pos_seq_len = 2048 + num_layer = 24 + layernorm_eps = 1e-5 + layernorm_type = post_layernorm + activation_type = Relu + has_post_decoder_layernorm = 0 + vocab_size = 50272 + start_id = 2 + end_id = 2 + weight_data_type = fp32 ``` ### gpt with triton backend Details are in [transformer_backend](https://github.com/triton-inference-server/fastertransformer_backend) +### Advanced features + +#### generate different sentences and enable shared context + +The model downloading and conversion are described in [Download megatron model and convert](#download-megatron-model-and-convert). + +A common request is, we have single input request, and hope to reply multiple results with different random seed. To achieve this target, we can mulpitle the inputs by several times, and set different random seed for different sentences in a batch. You can enable it by adding `--enable_random_seed`. Otherwise, all random seed would be set to 0 by default. + +For example, we prepare a input with batch size 4, and the sentences are all same. + +```bash +for i in {1..4} ; do echo " Article : (CNN)James Best, best known for his portrayal of bumbling sheriff Rosco P. Coltrane on TV's \"The Dukes of Hazzard,\" died Monday after a brief illness. He was 88. 
Best died in hospice in Hickory, North Carolina, of complications from pneumonia, said Steve Latshaw, a longtime friend and Hollywood colleague. Although he'd been a busy actor for decades in theater and in Hollywood, Best didn't become famous until 1979, when \"The Dukes of Hazzard's\" cornpone charms began beaming into millions of American homes almost every Friday night. For seven seasons, Best's Rosco P. Coltrane chased the moonshine-running Duke boys back and forth across the back roads of fictitious Hazzard County, Georgia, although his \"hot pursuit\" usually ended with him crashing his patrol car. Although Rosco was slow-witted and corrupt, Best gave him a childlike enthusiasm that got laughs and made him endearing. His character became known for his distinctive \"kew-kew-kew\" chuckle and for goofy catchphrases such as \"cuff 'em and stuff 'em! \" upon making an arrest. Among the most popular shows on TV in the early '80s, \"The Dukes of Hazzard\" ran until 1985 and spawned TV movies, an animated series and video games. Several of Best's \"Hazzard\" co-stars paid tribute to the late actor on social media. \"I laughed and learned more from Jimmie in one hour than from anyone else in a whole year,\" co-star John Schneider, who played Bo Duke, said on Twitter. \"Give Uncle Jesse my love when you see him dear friend.\" \"Jimmy Best was the most constantly creative person I have ever known,\" said Ben Jones, who played mechanic Cooter on the show, in a Facebook post. \"Every minute of his long life was spent acting, writing, producing, painting, teaching, fishing, or involved in another of his life's many passions.\" Born Jewel Guy on July 26, 1926, in Powderly, Kentucky, Best was orphaned at 3 and adopted by Armen and Essa Best, who renamed him James and raised him in rural Indiana. Best served in the Army during World War II before launching his acting career. TL;DR: " >> sample_input.txt ; done +``` + +Then, we run the `multi_gpu_gpt_example.py` with `--enable_random_seed`: + +```bash +python3 ../examples/pytorch/gpt/multi_gpu_gpt_example.py \ + --ckpt_path ../models/megatron-models/c-model/345m/1-gpu/ \ + --vocab_file ../models/gpt2-vocab.json \ + --merges_file ../models/gpt2-merges.txt \ + --sample_input_file sample_input.txt \ + --max_batch_size 4 \ + --time \ + --top_p 0.9 \ + --top_k 0 \ + --shared_contexts_ratio 0.0 \ + --enable_random_seed \ + --output_len 8 +``` + +You can see the results are little different, and the program will show the time cost like: + +```bash +[INFO] GPT time costs: 64.25 ms +``` + +Although this method can achieve our target, but computing same duplicated inputs is waste. So, we can set `--shared_contexts_ratio` to compute the duplicated inputs once in context phase: + +```bash +python3 ../examples/pytorch/gpt/multi_gpu_gpt_example.py \ + --ckpt_path ../models/megatron-models/c-model/345m/1-gpu/ \ + --vocab_file ../models/gpt2-vocab.json \ + --merges_file ../models/gpt2-merges.txt \ + --sample_input_file sample_input.txt \ + --max_batch_size 4 \ + --time \ + --top_p 0.9 \ + --top_k 0 \ + --shared_contexts_ratio 1.0 \ + --enable_random_seed \ + --output_len 8 +``` + +You can see the inference is faster than original one like: + +```bash +[INFO] GPT time costs: 41.69 ms +``` + +Notes: +1. The results of enabling `shared_context` and disabling `shared_context` may be different because the shape of GEMM are changed. But it does not affect the qualities of generation. +2. 
We use a short `output_len` in this example to demonstrate the benefit of `shared_context`. In real applications, the more duplicated inputs there are, and the longer the input length is relative to the output length, the more speedup `shared_context` brings. +3. Since the additional overhead of enabling `shared_context` is negligible, we enable it by default. + +#### Interactive generation +
+Fig 4. GPT generates some outputs from some inputs.
+
+Fig 5. New inputs with previous texts and some additional new input ids.
+
+ +In some scenarios (like chatting), the new requests are related to the previous requests. Currently, users can pass all of the previous inputs and outputs as new inputs into FT so that FT generates a new reply from these previous texts, as shown in Fig 4 and Fig 5. However, this means that FT needs to re-compute the k/v cache of all previous inputs and outputs again, which wastes time when the context is very long. +
+Fig 6. The workflow of interactive generation.
+
+ +To achieve better performance and prevent useless computing, we add a new flag `continue_gen` into GPT. When this flag is on, FT keeps all results during generation and assume the users will provide some more texts. And FT would not compute the k/v cache of the results it already has, but only compute the k/v cache of new ids. The workflow would become what we demonstrate in Fig 6. To prevent allocate the memory buffer again, users also need to set the `session_len` to be the maximum sequence length of the final sentence, but not only for intermediate sentence. + +We will use `multi_gpu_gpt_interactive_example` to demonstarte how to use this feature. In this example, we load the `examples/cpp/multi_gpu_gpt/start_ids.csv` first (the input length are all 8): + +``` +818, 262, 938, 3155, 286, 1528, 11, 257 +198, 464, 968, 8221, 2732, 286, 15198, 318 +464, 968, 1971, 12056, 423, 257, 649, 1182 +464, 968, 1971, 3782, 468, 3199, 663, 5079 +818, 257, 1445, 326, 481, 1884, 787, 340 +464, 968, 1971, 12056, 6, 5859, 41683, 423 +198, 198, 464, 5398, 4332, 628, 628, 198 +464, 717, 640, 314, 2497, 262, 3807, 11 +``` + +then generates 32 tokens with setting `continue_gen=true` to get an intermediate results (the results are saved in `out.interm`): + +``` +818 262 938 3155 286 1528 11 257 1256 286 661 423 587 4737 502 546 262 649 1492 11 290 314 1053 587 2111 284 3280 617 286 262 2683 326 661 423 587 4737 502 13 198 198 +198 464 968 8221 2732 286 15198 318 1762 351 262 1181 338 9358 5011 284 5004 262 1266 835 284 1445 262 4979 13 198 1 1135 821 1016 284 307 2045 379 262 1266 835 284 1445 262 +464 968 1971 12056 423 257 649 1182 3985 11 290 339 338 257 3516 508 338 587 1088 262 4652 329 257 890 640 13 679 338 257 3516 508 338 587 1088 262 4652 329 257 890 640 +464 968 1971 3782 468 3199 663 5079 1351 286 262 995 338 749 14212 661 13 198 464 1351 11 543 373 14102 416 262 968 1971 3782 11 318 1912 319 257 5526 286 517 621 352 11 +818 257 1445 326 481 1884 787 340 4577 329 262 1664 284 3677 663 7303 11 262 1664 468 4987 284 3677 663 10171 287 262 1664 284 257 1448 286 7713 2957 416 262 2839 13598 4081 309 +464 968 1971 12056 6 5859 41683 423 587 257 1263 636 286 262 1074 338 1943 428 1622 13 198 464 12056 423 587 1498 284 1057 262 2613 6840 11 290 484 423 587 1498 284 1057 262 +198 198 464 5398 4332 628 628 198 198 464 5398 4332 628 628 198 198 464 5398 4332 628 628 198 198 464 5398 4332 628 628 198 198 464 5398 4332 628 628 198 198 464 5398 4332 +464 717 640 314 2497 262 3807 11 314 373 588 11 705 5812 616 1793 11 428 318 523 3608 2637 314 373 588 11 705 40 765 284 307 287 428 3807 2637 314 373 588 11 705 +``` + +Next, we load another inputs from `examples/cpp/multi_gpu+gpt/interactive_inputs_ids` (the input length are all 8 again): + +``` +5962, 11, 314, 561, 588, 284, 910, 326 +11125, 286, 2844, 291, 5028, 422, 262, 7627 +392, 257, 1913, 1998, 351, 1353, 12, 28282 +830, 34643, 11, 7602, 11, 4708, 6332, 1938 +5, 38328, 763, 13, 1119, 481, 2148, 257 +3245, 355, 257, 22080, 1074, 13, 4042, 286 +14150, 26443, 262, 1230, 338, 1410, 284, 3958 +5195, 4398, 470, 314, 7342, 340, 2961, 30 +``` + +and pass into FT again (note that we only need to pass new ids because FT already records all previous ids). 
Then FT will concatenate these new ids into output ids, compute k/v caches for only these new ids, and then generate another 32 tokens as a new response (the results are saved in `out`): + +``` +818 262 938 3155 286 1528 11 257 1256 286 661 423 587 4737 502 546 262 649 1492 11 290 314 1053 587 2111 284 3280 617 286 262 2683 326 661 423 587 4737 502 13 198 198 5962 11 314 561 588 284 910 326 314 1101 407 257 4336 286 262 1492 13 314 892 340 338 257 1310 1165 881 286 257 366 10919 611 1 1492 13 314 892 340 338 257 1310 1165 +198 464 968 8221 2732 286 15198 318 1762 351 262 1181 338 9358 5011 284 5004 262 1266 835 284 1445 262 4979 13 198 1 1135 821 1016 284 307 2045 379 262 1266 835 284 1445 262 11125 286 2844 291 5028 422 262 7627 7784 15296 284 262 7421 7784 15296 553 531 42743 6523 3899 1024 33246 271 13 198 464 42743 318 635 2045 379 262 5885 286 3867 262 4979 422 262 7421 +464 968 1971 12056 423 257 649 1182 3985 11 290 339 338 257 3516 508 338 587 1088 262 4652 329 257 890 640 13 679 338 257 3516 508 338 587 1088 262 4652 329 257 890 640 392 257 1913 1998 351 1353 12 28282 18370 13 679 338 257 3516 508 338 587 1088 262 4652 329 257 890 640 13 679 338 257 3516 508 338 587 1088 262 4652 329 257 890 640 13 +464 968 1971 3782 468 3199 663 5079 1351 286 262 995 338 749 14212 661 13 198 464 1351 11 543 373 14102 416 262 968 1971 3782 11 318 1912 319 257 5526 286 517 621 352 11 830 34643 11 7602 11 4708 6332 1938 290 584 14212 661 13 198 464 1351 318 14102 416 262 968 1971 3782 290 318 3199 319 262 3052 286 262 7533 13 198 464 1351 318 20633 416 262 +818 257 1445 326 481 1884 787 340 4577 329 262 1664 284 3677 663 7303 11 262 1664 468 4987 284 3677 663 10171 287 262 1664 284 257 1448 286 7713 2957 416 262 2839 13598 4081 309 5 38328 763 13 1119 481 2148 257 2472 286 720 16 13 20 2997 287 5003 290 4283 13 198 464 1730 318 2938 284 1969 287 262 1218 2063 286 428 614 13 198 464 1664 531 340 +464 968 1971 12056 6 5859 41683 423 587 257 1263 636 286 262 1074 338 1943 428 1622 13 198 464 12056 423 587 1498 284 1057 262 2613 6840 11 290 484 423 587 1498 284 1057 262 3245 355 257 22080 1074 13 4042 286 262 640 11 262 12056 423 587 1498 284 1057 262 2613 6840 11 290 484 423 587 1498 284 1057 262 3245 355 257 22080 1074 13 198 464 12056 423 +198 198 464 5398 4332 628 628 198 198 464 5398 4332 628 628 198 198 464 5398 4332 628 628 198 198 464 5398 4332 628 628 198 198 464 5398 4332 628 628 198 198 464 5398 4332 14150 26443 262 1230 338 1410 284 3958 262 779 286 262 1573 366 16991 1 287 262 1499 338 1743 3303 13 198 198 464 1230 338 1410 284 3958 262 779 286 262 1573 366 16991 1 287 +464 717 640 314 2497 262 3807 11 314 373 588 11 705 5812 616 1793 11 428 318 523 3608 2637 314 373 588 11 705 40 765 284 307 287 428 3807 2637 314 373 588 11 705 5195 4398 470 314 7342 340 2961 30 4162 4398 470 314 1775 340 878 8348 314 373 588 11 705 40 765 284 307 287 428 3807 2637 314 373 588 11 705 40 765 284 307 287 428 +``` + ## Performance Hardware settings (A100 SuperPod architecture): @@ -481,17 +861,21 @@ We demonstrate the inference time of Megatron and FasterTransformer on Triton, a TP means tensor parallelism, PP means pipeline parallelism.
-
Fig 7. Latency on input length 60, output length 20. TP means tensor parallelism and PP means pipeline parallelism.

Fig 8. Throughput per GPU on input length 60, output length 20. TP means tensor parallelism and PP means pipeline parallelism.

Fig 9. Latency with fixed output length 20, 16-way tensor parallelism, and different input lengths and batch sizes.

Fig 10. Latency with fixed input length 128, 16-way tensor parallelism, and different output lengths and batch sizes.
+3 | Batch Size | Input Length | Output Length | Latency of TP-16, PP-1 (ms) | Latency of TP-32, PP-1 (ms) | Latency of TP-8, PP-3 (ms) | | :--------: | :----------: | :-----------: | :-------------------------: | :-------------------------: | :------------------------: | | 1 | 20 | 8 | 565 | 431 | 842 | @@ -724,5 +1108,3 @@ TP means tensor parallelism | 16 | 512 | 32 | 189.91 | 4.80 | | 32 | 512 | 32 | 296.15 | 6.09 | | 64 | 512 | 32 | 529.18 | 8.67 | - -## TODO diff --git a/docs/gptj_guide.md b/docs/gptj_guide.md index c757db8c9..0fa67225d 100644 --- a/docs/gptj_guide.md +++ b/docs/gptj_guide.md @@ -9,27 +9,101 @@ - [Setup](#setup) - [Requirements](#requirements) - [Docker image](#docker-image) - - [Setup](#setup-1) - - [Build](#build) + - [Build project](#build-project) - [Download the model](#download-the-model) + - [Download tables](#download-tables) - [Run GPT-J](#run-gpt-j) + - [Run GPTJ with prompts](#run-gptj-with-prompts) - [Compare with reference implementation](#compare-with-reference-implementation) - [gpt-j with triton backend](#gpt-j-with-triton-backend) - ## Introduction This document describes the step to run the GPT-J model on FasterTransformer. GPT-J was developed by EleutherAI and trained on The Pile, a 825GB dataset from curated sources (e.g. Wikipedia, arXiv, GitHub, StackExchange, PubMed, ...). With 6 billion parameters, GPT-J is one of the largest GPT-like publicly released models as of 2021. +Optimization in GPT-j are similar to optimization in GPT, describing in the [gpt_guide.md](gpt_guide.md#optimization). + +* Constructor of GPT-j + +| Classification | Name | Data Type | Description | +| :------------: | :--------------------------: | :----------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| [0] | max_batch_size | size_t | **Deprecated, move to input** | +| [1] | max_seq_len | size_t | **Deprecated, move to input** | +| [2] | max_input_len | size_t | **Deprecated, move to input** | +| [3] | beam_width | size_t | **Deprecated, move to input** | +| [4] | head_num | size_t | Head number for model configuration | +| [5] | size_per_head | size_t | Size per head for model configuration | +| [6] | inter_size | size_t | The inter size of feed forward network. It is often set to 4 * head_num * size_per_head. | +| [7] | num_layer | size_t | Number of transformer layers for model configuration | +| [8] | vocab_size | size_t | Vocabulary size for model configuration | +| [9] | rotary_embeeding_dim | size_t | Rotary embedding dimension of rotary position embedding for model configuration | +| [10] | start_id | int | Start id for vocabulary | +| [11] | end_id | int | End id for vocabulary | +| [12] | prompt_learning_start_id | int | The start id of virtual token in p/prompt-tuning | +| [13] | prompt_learning_type | PromptLearningType | The type of prompt learning when we load the prompt embedding in constructor. 
FT supports `no_prompt`, `soft_prompt`, `prefix_prompt`, `p_prompt_tuning` now | +| [14] | beam_search_diversity_rate | float | **Deprecated, move to input** | +| [15] | top_k | size_t | **Deprecated, move to input** | +| [16] | top_p | float | **Deprecated, move to input** | +| [17] | random_seed | unsigned long long | **Deprecated, move to input** | +| [18] | temperature | float | **Deprecated, move to input** | +| [19] | len_penalty | float | **Deprecated, move to input** | +| [20] | repetition_penalty | float | **Deprecated, move to input** | +| [21] | tensor_para | NcclParam | Tensor Parallel information, which is declared in `src/fastertransformer/utils/nccl_utils.h` | +| [22] | pipeline_para | NcclParam | Pipeline Parallel information, which is declared in `src/fastertransformer/utils/nccl_utils.h` | +| [23] | stream | cudaStream_t | CUDA stream | +| [24] | cublas_wrapper | cublasMMWrapper* | Pointer of cuBLAS wrapper, which is declared in `src/fastertransformer/utils/cublasMMWrapper.h` | +| [25] | allocator | IAllocator* | Pointer of memory allocator, which is declared in `src/fastertransformer/utils/allocator.h` | +| [26] | is_free_buffer_after_forward | bool | If setting to be `true`, FasterTransformer will allocate buffer before forward, and free buffer after forward. When the allocator is based on memory pool, setting to `true` may help reducing the memory usage during inference. | +| [27] | cuda_device_prop | cudaDeviceProp* | Pointer of CUDA device properties, which is used to get the properties of hardware like size of shared memory | +| [28] | custom_all_reduce_comm | AbstractCustomComm | Custom all reduction communication for custom all reduction in model parallelism. It is only supported in 8-way tensor parallelism | +| [29] | enable_custom_all_reduce | int | Flag of enabling custom all reduction or not | + +* Input of GPT-j + +| Name | Tensor/Parameter Shape | Location | Data Type | Description | +| :---------------------------: | :-------------------------------------------: | :------: | :--------------------: | :-------------------------------------------------------------------------------------------------------------------------------------: | +| input_ids | [batch_size, max_input_length] | GPU | int | The input ids (context) | +| input_lengths | [batch_size] | GPU | int | The lengths of input ids | +| prompt_learning_task_name_ids | [batch_size] | CPU | int | **Optional**. Task name ids for prompt learning. | +| output_seq_len | [batch_size] | CPU | uint32_t | The largest number of tokens you hope for results. Note that it contains the input length | +| start_id | [batch_size] | CPU | int | **Optional**. If FT receives this input, FT will replace default start id by it | +| end_id | [batch_size] | CPU | int | **Optional**. If FT receives this input, FT will replace default end id by it | +| stop_words_list | [batch_size, 2, stop_words_length] | GPU | int | **Optional**. FT would not generate the tokens in the list. | +| bad_words_list | [batch_size, 2, bad_words_length] | GPU | int | **Optional**. The words in the list will be When FT generates words in this list, it will stop the generation. An extension of stop id | +| runtime_top_k | [1] or [batch_size] | CPU | uint | **Optional**. top_k value for top k sampling | +| runtime_top_p | [1] or [batch_size] | CPU | float | **Optional**. top_p value for top p sampling | +| beam_search_diversity_rate | [1] or [batch_size] | CPU | float | **Optional**. 
A hyper hyper-parameter for [simple diverse decoding](https://arxiv.org/pdf/1611.08562.pdf) | +| temperature | [1] or [batch_size] | CPU | float | **Optional**. Temperature applied to logits for both beam search and sampling | +| len_penalty | [1] or [batch_size] | CPU | float | **Optional**. Length penalty applied to logits for only beam search | +| repetition_penalty | [1] or [batch_size] | CPU | float | **Optional**. Repetition penalty applied to logits for both beam search and sampling | +| random_seed | [1] or [batch_size] | CPU | unsigned long long int | **Optional**. Random seed to initialize the random table in sampling. | +| request_prompt_lengths | [batch_size], | CPU | int | **Optional**. Length of prefix soft prompt embedding. This describes how many tokens of soft prompt embedding in each sentence. | +| request_prompt_embedding | [batch_size, max_prompt_length, hidden_units] | GPU | float | **Optional**. Prefix soft prompt embedding. FT will concat them with results of embedding lookup kernel | +| request_prompt_type | [batch_size] | CPU | int | **Optional**. Prompt type of request. This is necessary when user pass the prompt embedding by input | +| memory_len | [1] | CPU | uint32 | **Optional**. The maximum time memory used in attention modules. Reduces the memory footprint but quality of generation might degrades. | + +* Output of GPT-j + +| Name | Tensor/Parameter Shape | Location | Data Type | Description | +| :--------------: | :----------------------------------------------: | :------: | :-------: | :-------------------------------------------------------------------------------: | +| output_ids | [batch_size, beam_width, max_output_seq_len] | GPU | int | The output ids. It contains the input_ids and generated ids | +| sequence_length | [batch_size, beam_width] | GPU | int | The lengths of output ids | +| output_log_probs | [batch_size, beam_width, request_output_seq_len] | GPU | float | **Optional**. It records the log probability of logits at each step for sampling. | +| cum_log_probs | [batch_size, beam_width] | GPU | float | **Optional**. Cumulative log probability of generated sentences | + +The `beam_width` value is set by the output shape directly. When the `beam_width` of `output_ids` is larger than 1, FT will use beam search to generate tokens; otherwise, FT will use topk or topp sampling. When the inputs of beam search and sampling is invalid, like beam width 1, top k 0, top p 0.0, FT will run greedy search automatically. + ### Supported features * Checkpoint converter * EleutherAI + * Huggingface * Data type * FP32 * FP16 + * BF16 * Feature * Multi-GPU multi-node inference * Dynamic random seed @@ -39,61 +113,6 @@ With 6 billion parameters, GPT-J is one of the largest GPT-like publicly release * Frameworks * Triton backend -Optimization in GPT-j are similar to optimization in GPT, describing in the [gpt_guide.md](gpt_guide.md#optimization). - -* Arguments: - 1. Maximum batch size (Deprecated, move to input) - 2. Maximum sequence length (Deprecated, move to input) - 3. Maximum input sequence length (Deprecated, move to input) - 4. beam width for beam search. If setting b to be 1, then we don’t use beam search but use sampling. (Deprecated, move to input) - 5. Head number - 6. Size per head - 7. Intermediate size. The inter size of feed forward network. It is often set to 4 * head_num * size_per_head. - 8. Number of decoder layers. - 9. Vocab size. - 10. Rotary embedding for attetnion. - 11. Start id of the vocabulary. - 12. End id of the vocabulary. - 13. 
Diversity rate of beam search. A hyper hyper-parameter for [simple diverse decoding](https://arxiv.org/pdf/1611.08562.pdf). (Deprecated, move to input) - 14. top_k value for top k sampling. (Deprecated, move to input) - 15. top_p value for top p sampling. (Deprecated, move to input) - 16. Random seed for sampling. (Deprecated, move to input) - 17. Temperature for logit. Setting to be 1.0 if you don’t want to apply the temperature. (Deprecated, move to input) - 18. Length penalty for logit. Setting to be 1.0 if you don’t want to apply the length penalty. (Deprecated, move to input) - 19. Repetition penalty for logit. Setting to be 1.0 if you don’t want to apply the repetition penalty. (Deprecated, move to input) - 20. Tensor Parallel information, which is declared in `src/fastertransformer/utils/nccl_utils.h`. - 21. Pipeline Parallel information, which is declared in `src/fastertransformer/utils/nccl_utils.h`. - 22. CUDA stream. - 23. Pointer of cuBLAS wrapper, which is declared in `src/fastertransformer/utils/cublasMMWrapper.h`. - 24. Pointer of memory allocator, which is declared in `src/fastertransformer/utils/allocator.h` - 25. “is_free_buffer_after_forward” flag. If setting to be true, FasterTransformer will allocate buffer before forward, and free buffer after forward. If the memory is controlled by memory pool and the cost of allocating/releasing memory is small, setting the flag to be true can save some memory. - 26. Pointer of CUDA device properties, which is used to get the properties of hardware like size of shared memory. - 27. Custom all reduction communication for custom all reduction in model parallelism. It is only supported in 8-way tensor parallelism. - 28. Flag of enable custom all reduction or not. -* Input tensors: - 1. Input ids (context). The shape is \[ request batch size * beam width, request maximum input length \]. - 2. Input lengths. The shape is \[ request batch size * beam width \]. - 3. Maximum output sequence length. An integer to describe the largest number of tokens you hope for results. Note that it includes the input ids. - 4. Start id in runtime. The shape is \[batch_size\] on cpu, optional. If FT receives this input, FT will replace default start id by it, optional. - 5. End id in runtime. The shape is \[batch_size\] on cpu, optional. If FT receives this input, FT will replace default end id by it, optional. - 6. Stop word list. When FT generates words in this list, it will stop the generation. An extension of stop id, optional. - 7. Bad word list. FT won't generates words in this list, optional. - 8. top_k value for top k sampling. The shape is \[1\] or \[batch_size, 1\] on cpu, optional. - 9. top_p value for top p sampling. The shape is \[1\] or \[batch_size, 1\] on cpu, optional. - 10. Diversity rate of beam search (beam_search_diversity_rate). A hyper hyper-parameter for [simple diverse decoding](https://arxiv.org/pdf/1611.08562.pdf). [1] or \[batch_size, 1\] on cpu, optional. - 11. Temperature for logit (temperature). The sahpe \[1\] or \[batch_size, 1\] on cpu, optional. - 12. Length penalty for logit (len_penalty). The shape is \[1\] or \[batch_size, 1\] on cpu, optional - 13. Repetition penalty for logit (repetition_penalty). The shape is \[1\] or \[batch_size, 1\] on cpu, optional - 14. Random_seed \[1\] or \[batch_size, 1\] on cpu, optional - 15. Length of prefix soft prompt embedding. This describes how many tokens of soft prompt embedding in each sentence. The shape is \[batch_size\], optional. - 16. Prefix soft prompt embedding. 
FT will concat them with results of embedding lookup kernel. The shape is \[batch_size, max_prefix_soft_prompt_length, hidden_units\], optional. -* Output tensors: - 1. Output ids. The shape is \[batch size, beam width, maximum output sequence length \]. - 2. Sequence lengths. The shape is \[batch size * beam width\]. It records the final sequence lengths of all sentences. - 3. Log probability for sampling. The shape is \[requested token number, batch size, beam \]. It records the log probability of logits at each step. Optional outputs in FP32. - -The beam_width value is set by the output shape directly. When the beam_width is larger than 1, FT will use beam search to generate tokens; otherwise, FT will use topk or topp sampling. - ## Setup ### Requirements @@ -101,10 +120,10 @@ The beam_width value is set by the output shape directly. When the beam_width is - CMake >= 3.13 for PyTorch - CUDA 11.0 or newer version - NCCL 2.10 or newer version -- Python 3 is recommended because some features are not supported in python 2 +- Python: Only verify on python 3 - PyTorch: Verify on 1.8.0, >= 1.5.0 should work. -Recommend use nvcr image like `nvcr.io/nvidia/pytorch:21.11-py3`. +Recommend use nvcr image like `nvcr.io/nvidia/pytorch:22.07-py3`. These components are readily available within the NGC Docker image below. @@ -122,13 +141,13 @@ For those unable to use the NGC container, to set up the required environment or ### Docker image -* The model was built and tested with the use nvcr image `nvcr.io/nvidia/pytorch:21.07-py3`. e.g. +* The model was built and tested with the use nvcr image `nvcr.io/nvidia/pytorch:22.07-py3`. e.g. ```bash - nvidia-docker run -ti --rm nvcr.io/nvidia/pytorch:21.07-py3 bash + nvidia-docker run -ti --rm nvcr.io/nvidia/pytorch:22.07-py3 bash ``` -### Setup +### Build project * Get the code and install all dependencies: @@ -140,10 +159,19 @@ For those unable to use the NGC container, to set up the required environment or pip3 install fire jax jaxlib ``` -### Build +* Note: the `xx` of `-DSM=xx` in following scripts means the compute capability of your GPU. The following table shows the compute capability of common GPUs. -* Note: the `xx` of `-DSM=xx` in following scripts means the compute capability of your GPU. For example, 60 (P40) or 61 (P4) or 70 (V100) or 75(T4) or 80 (A100). Default setting is including 70, 75, 80 and 86. +| GPU | compute capacity | +| :---: | :--------------: | +| P40 | 60 | +| P4 | 61 | +| V100 | 70 | +| T4 | 75 | +| A100 | 80 | +| A30 | 80 | +| A10 | 86 | +By default, `-DSM` is set by 70, 75, 80 and 86. When users set more kinds of `-DSM`, it requires longer time to compile. So, we suggest setting the `-DSM` for the device you use only. Here, we use `xx` as an example due to convenience. ```bash cmake -DSM=xx -DCMAKE_BUILD_TYPE=Release -DBUILD_MULTI_GPU=ON .. @@ -152,11 +180,12 @@ For those unable to use the NGC container, to set up the required environment or ### Download the model -* Download the public model and convert +* Download the mystic public model and convert ```bash wget https://mystic.the-eye.eu/public/AI/GPT-J-6B/step_383500_slim.tar.zstd - tar -axf step_383500_slim.tar.gz + unzstd step_383500_slim.tar.zstd + tar -axf step_383500_slim.tar python3 ../examples/pytorch/gptj/utils/gptj_ckpt_convert.py --output-dir ../models/j6b_ckpt --ckpt-dir ./step_383500/ ``` @@ -165,6 +194,21 @@ The script accepts the following arguments: 2. `--ckpt-dir` is the path to the extracted checkpoint. 
If `--ckpt-dir` terminates with `.pt` then the script reads the Pytorch model file instead than the public checkpoint, which is faster. 3. `--n-inference-gpus` number of GPUs used for inference, defaults to 1. The binary model parameters are saved to `${output-dir}/${n-inference-gpus}-gpu/` +* Download the huggingface gptj model and convert + + ```bash + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/EleutherAI/gpt-j-6B + python3 ../examples/pytorch/gptj/utils/huggingface_gptj_ckpt_convert.py --ckpt-dir gpt-j-6B/ --output-dir gpt-j-6B/c-models/ --n-inference-gpus 1 + ``` + +The script accepts the following arguments: +1. `--output-dir` is the path of the base directory where the weight binary files will be saved. +2. `--ckpt-dir` is the path to the extracted checkpoint. +3. `--n-inference-gpus` number of GPUs used for inference, defaults to 1. The binary model parameters are saved to `${output-dir}/${n-inference-gpus}-gpu/` + +### Download tables * The vocabolary and merge tables are the same as for GPT @@ -179,7 +223,7 @@ The script accepts the following arguments: Data Type = 0 (FP32) or 1 (FP16) or 2 (BF16) ```bash ./bin/gpt_gemm - E.g., ./bin/gpt_gemm 8 1 32 16 128 16384 50400 1 1 + E.g., ./bin/gpt_gemm 8 1 32 16 256 16384 50400 1 1 ``` * Run GPT on C++ @@ -190,7 +234,7 @@ The script accepts the following arguments: mpirun -n 1 --allow-run-as-root ./bin/gptj_example ``` -E.g. by setting the `is_half` of `gpt_config.ini` to 1, users can run gpt model under fp16. +E.g. by setting the `data_type` of `gpt_config.ini` to `fp16` or `bf16`, users can run gpt model under fp16/bf16. * Convert the token ids to sentence. @@ -208,8 +252,59 @@ E.g. by setting the `is_half` of `gpt_config.ini` to 1, users can run gpt model export CUDA_VISIBLE_DEVICES=0 mpirun -n 1 --allow-run-as-root ./bin/gptj_triton_example ``` + To run with tensor and/or pipeline parallelism, make more GPUs visible, edit the `../examples/cpp/gptj/gptj_config.ini` and generate the parameter files with `gptj_ckpt_convert.py` accordingly. + +### Run GPTJ with prompts + +GPTJ now supports prefix_prompt. + +1. Convert the prompt weights + + You need to transpose the prefix prompt weights to the shape [num_layers, 2, num_heads, perfix_seq_len, size_per_head], and save it by numpy. The naming style is like ` model.prefix_prompt..weights..bin`. + + Note that you need to specify `start_id`, `end_id` by yourself in order to make sure that it is consistent with the tokenizer. + +2. Run GPT with C++ example + + You need to specify the example gpt_config.ini like below to enable the p/prompt_tuning feature. + + ```ini + [gpt_124M] + head_num=12 + size_per_head=64 + vocab_size=50257 + decoder_layers=12 + start_id=50256 + end_id=50256 + inter_size=3072 + num_tasks=3 + prompt_learning_type=3 + + [gpt_124M_task_0] + task_name = squad + prompt_length = 10 + + [gpt_124M_task_1] + task_name = sentiment + prompt_length = 10 + + [gpt_124M_task_2] + task_name = intent_and_slot + prompt_length = 10 + ``` + + `task_name` and `prompt_length` are specified for loading prompt weights. 
+ + **prompt_learning_type**: + + - no prompt: 0 + - soft_prompt: 1 + - prefix_prompt: 2 + - p/prompt_tuning: 3 + + ### Compare with reference implementation * Install the reference implementation from finetuneanon: diff --git a/docs/gptneox_guide.md b/docs/gptneox_guide.md new file mode 100644 index 000000000..61f078770 --- /dev/null +++ b/docs/gptneox_guide.md @@ -0,0 +1,135 @@ +# GPT-NeoX + +## Table Of Contents + +- [GPT-NeoX](#gpt-neox) + - [Table Of Contents](#table-of-contents) + - [Introduction](#introduction) + - [Supported features](#supported-features) + - [Setup](#setup) + - [Requirements](#requirements) + - [Download the model](#download-the-model) + - [Tokenizer](#tokenizer) + - [Run GPT-NeoX](#run-gpt-neox) + +## Introduction + +This document describes the steps to run the GPT-NeoX model on FasterTransformer. +GPT-NeoX is a model developed by EleutherAI, available publicly on their GitHub [repository](https://github.com/EleutherAI/gpt-neox). +For the time being, only the 20B parameter version has been tested. + +More details are listed in [gptj_guide.md](gptj_guide.md#introduction). + +Optimization in gpt-neox are similar to optimization in GPT, describing in the [gpt_guide.md](gpt_guide.md#optimization). + +### Supported features + +* Checkpoint converter + * EleutherAI +* Data type + * FP32 + * FP16 +* Feature + * Multi-GPU multi-node inference + * Dynamic random seed + * Stop tokens + * Bad words list + * Beam search and sampling are both supported + +## Setup + +### Requirements + +See common requirements such as in [gptj_guide.md](gptj_guide.md#requirements). + +### Download the model + +First download a pytorch checkpoint, as provided by [EleutherAI](https://github.com/EleutherAI/gpt-neox#download-links): + +```bash +wget --cut-dirs=5 -nH -r --no-parent --reject "index.html*" https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/ -P 20B_checkpoints +``` + +Then use the script provided by FasterTransformer to convert the checkpoint to raw weights, understood by FT. + +```bash +python ../examples/pytorch/gptneox/utils/eleutherai_gpt_neox_convert.py 20B_checkpoints ../models/gptneox -t 2 +``` + +### Tokenizer + +You may download the tokenizer config [here](https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/20B_tokenizer.json). + +To tokenize/detokenize files, use the script found in `examples/pytorch/gptneox/utils/hftokenizer.py`. You may need to pass the path to the tokenizer config with the `--tokenizer` flag. + +### Run GPT-NeoX + +* Generate the `gemm_config.in` file.\ + Data Type = 0 (FP32) or 1 (FP16) or 2 (BF16) + ```bash + ./bin/gpt_gemm + E.g., ./bin/gpt_gemm 8 1 32 64 96 24576 50432 1 2 + ``` + +* Run GPT on C++ + + Users can see the details of arguments in `examples/cpp/gptneox/gptneox_config.ini`. It controls the model path, model size, tensor parallelism size, and some hyper-parameters. + + ```bash + mpirun -n 2 --allow-run-as-root ./bin/gptneox_example + ``` + +E.g. by setting the `data_type` of `gptneox_config.ini` to `fp16`, users can run gpt model under fp16. 
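The `data_type` entry is just a key in `gptneox_config.ini`, so it can also be switched programmatically before launching `gptneox_example`. The following is a minimal Python sketch using the standard `configparser` module; the exact section names of the file are not listed here, so the sketch simply rewrites every section that defines `data_type` (note that `configparser` drops comments when it writes the file back).

```python
import configparser

# Assumed relative path from the build directory; adjust if your layout differs.
CONFIG_PATH = "../examples/cpp/gptneox/gptneox_config.ini"

config = configparser.ConfigParser()
config.read(CONFIG_PATH)

# Rewrite data_type wherever it is defined (e.g. fp32 -> fp16).
for section in config.sections():
    if "data_type" in config[section]:
        config[section]["data_type"] = "fp16"

with open(CONFIG_PATH, "w") as f:
    config.write(f)
```

After rewriting the file, run `mpirun -n 2 --allow-run-as-root ./bin/gptneox_example` as shown above.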
+ +You can then decode the `out` file with the tokenizer: + + ```bash + wget https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/20B_tokenizer.json + ../examples/pytorch/gptneox/utils/hftokenizer.py out --tokenizer 20B_tokenizer.json + ``` + + diff --git a/docs/images/encoder_flowchart.png b/docs/images/encoder_flowchart.png index ce5086f53..0eb5fe906 100644 Binary files a/docs/images/encoder_flowchart.png and b/docs/images/encoder_flowchart.png differ diff --git a/docs/images/gpt/gpt_interactive_generation.0.png b/docs/images/gpt/gpt_interactive_generation.0.png new file mode 100644 index 000000000..0b4c678f6 Binary files /dev/null and b/docs/images/gpt/gpt_interactive_generation.0.png differ diff --git a/docs/images/gpt/gpt_interactive_generation.1.png b/docs/images/gpt/gpt_interactive_generation.1.png new file mode 100644 index 000000000..c626d35cf Binary files /dev/null and b/docs/images/gpt/gpt_interactive_generation.1.png differ diff --git a/docs/images/gpt/gpt_interactive_generation.2.png b/docs/images/gpt/gpt_interactive_generation.2.png new file mode 100644 index 000000000..ec6339503 Binary files /dev/null and b/docs/images/gpt/gpt_interactive_generation.2.png differ diff --git a/docs/longformer_guide.md b/docs/longformer_guide.md index 3150a9c50..fb236a4ae 100644 --- a/docs/longformer_guide.md +++ b/docs/longformer_guide.md @@ -47,7 +47,7 @@ In this demo, you can run the longformer as a Pytorch OP. - CMake >= 3.13 for PyTorch - CUDA 11.0 or newer version -- Python 3 is recommended because some features are not supported in python 2 +- Python: Only verify on python 3 - PyTorch: Verify on 1.8.0, >= 1.5.0 should work. These components are readily available within the NGC PyTorch Docker image below. @@ -112,7 +112,7 @@ The script will first compare the performance between HuggingFace Longformer enc #### Args -1. Use `--fp16` to run in FP16 mode, that's too say use FP16 input and yield FP16 output for FT longformer. Also HuggingFace's longformer will use FP16 too. +1. Use `--data_type fp16/bf16` to run in FP16 or BF16 mode, that's too say use FP16/BF16 input and yield FP16/BF16 output for FT longformer. Also HuggingFace's longformer will use FP16/BF16 too. 2. Use `--sequence-length` to select sequence length. `sequence_length` should >= 2 * `local_attention_window_size` and `sequence-length` % `local_attention_window_size` = 0 3. Use `-max-global-attention-num` to choose maximum of global token nums. That's too say your global token nums should not exceed this limit. And note that FT's longformer only support to place global tokens at the beginning of the sequence. In QA example, all the question tokens will be placed at the beginning of the sequence and marked as global tokens. 4. Use `--batch-size` to select batch size. Note in QA example, it will just duplicate the same question and passage sequence `batch_size` times and stack them together. 
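As a quick sanity check for the constraints described in items 2 and 3 above, the sketch below validates a candidate configuration before launching the QA example. It is a standalone illustration; the variable names are ours and are not taken from the example script.

```python
def validate_longformer_args(sequence_length: int,
                             local_attention_window_size: int,
                             max_global_attention_num: int,
                             num_question_tokens: int) -> None:
    """Check the documented constraints for the FT Longformer demo."""
    # sequence_length must cover at least two local attention windows
    assert sequence_length >= 2 * local_attention_window_size, \
        "sequence_length should be >= 2 * local_attention_window_size"
    # sequence_length must be a multiple of the local attention window
    assert sequence_length % local_attention_window_size == 0, \
        "sequence_length should be divisible by local_attention_window_size"
    # global tokens (the question tokens in the QA example) sit at the
    # beginning of the sequence and may not exceed the configured maximum
    assert num_question_tokens <= max_global_attention_num, \
        "global token count exceeds max_global_attention_num"

validate_longformer_args(sequence_length=1024,
                         local_attention_window_size=512,
                         max_global_attention_num=128,
                         num_question_tokens=32)
```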
@@ -124,7 +124,7 @@ The script will first compare the performance between HuggingFace Longformer enc #### Building the FT Longformer encoder ```python -# Pass the neccessary config and args +# Pass the necessary config and args weights_file = os.path.join(hf_model_dir, 'pytorch_model.bin') ft_encoder = FTLongformerEncoder(weights_file, layer_num, head_num, size_per_head, intermediate_size, local_attn_window_size, diff --git a/docs/swin_guide.md b/docs/swin_guide.md index 05b781171..f5a3eb92d 100644 --- a/docs/swin_guide.md +++ b/docs/swin_guide.md @@ -66,7 +66,7 @@ In this demo, you can run Faster Swin-Transformer as a C++ program. - Python 3 is recommended because some features are not supported in python 2 - PyTorch: Verify on 1.10.0, >= 1.5.0 should work. -Recommand to use image `nvcr.io/nvidia/pytorch:21.07-py3`. +Recommend to use image `nvcr.io/nvidia/pytorch:21.07-py3`. > docker run -ti --gpus all --rm nvcr.io/nvidia/pytorch:21.07-py3 bash @@ -113,7 +113,7 @@ Data Type = 0 (FP32) or 1 (FP16) or 2 (BF16) # 4: swin-LARGE with window size 7x7 # 5: swin-LARGE with window size 12x12 ./bin/swin_gemm -./bin/swin_example +./bin/swin_example ./bin/swin_int8_example ``` Take swin-TINY with batch=32 as an example: @@ -126,6 +126,10 @@ Take swin-TINY with batch=32 as an example: ./bin/swin_gemm 32 224 7 3 32 1 0 ./bin/swin_example 1 0 32 +# Run Swin-Transformer(TINY) under BF16 on C++ +./bin/swin_gemm 32 224 7 3 32 2 0 +./bin/swin_example 2 0 32 + # Run Swin-Transformer(TINY) under INT8 on C++ ./bin/swin_gemm 32 224 7 3 32 0 1 ./bin/swin_int8_example 0 32 @@ -170,7 +174,7 @@ python -m torch.distributed.launch --nproc_per_node 1 \ (When `int8_mode=1`, all GEMMs are INT8-in-INT8-out. When `int8_mode=2`, GEMM of all `fc2` layers and `patchMerge` are relaxed to INT8-in-INT32-out, while other GEMMs keep INT8-I/O. ) -If you want to insists on using `--int8-mode 1` for LARGE model (because speed of mode=1 is much faster), we recommend using QAT to finetune paramters of LARGE checkpoint. +If you want to insists on using `--int8-mode 1` for LARGE model (because speed of mode=1 is much faster), we recommend using QAT to finetune parameters of LARGE checkpoint. 2. Run test ```bash @@ -232,7 +236,7 @@ We here compared the performance between Swin-Transformer and FT Swin-Transforme * num_of_blocks = {2,2,6,2} ### Swin Performance on T4 -Here, `torch.jit.trace` means using tracing to convert Torch model to TorchScript model and then profile its performace. +Here, `torch.jit.trace` means using tracing to convert Torch model to TorchScript model and then profile its performance. #### FP32 | Batch_size | torch.jit.trace | cpp | speedup | trt plugin | speedup | torch op | speedup | | :--------: | :-------------: | :----: | :-----: | :--------: | :-----: | :------: | :-----: | @@ -266,7 +270,7 @@ INT8 vs. FP16 speedup on Swin TINY/SMALL/BASE/LARGE: | 32 | 26.10 | 19.19 | 1.36 | 44.04 | 30.66 | 1.44 | 64.53 | 42.84 | 1.51 | 121.06 | 76.34 | 1.59 | ### Swin Performance on A100 -Here, `torch.jit.trace` means using tracing to convert Torch model to TorchScript model and then profile its performace. +Here, `torch.jit.trace` means using tracing to convert Torch model to TorchScript model and then profile its performance. #### TF32 On chips with Ampere architectures (like A30, A100), user can use `export NVIDIA_TF32_OVERRIDE=1` to enforce the program run under TF32, otherwise FP32 GEMM is used by default, which is much slower. 
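For reference, a minimal sketch of this TF32 toggle from Python; the environment variable must be set before the CUDA context is created (i.e. before the first CUDA call), and the `torch.backends` flags are the PyTorch-side switches for the same behavior.

```python
import os

# Must be set before CUDA is initialized; "0" disables TF32 and falls back to FP32 GEMM.
os.environ["NVIDIA_TF32_OVERRIDE"] = "1"

import torch

# PyTorch-side switches for TF32 in matmul and cuDNN convolutions.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
```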
| Batch_size | torch.jit.trace | cpp | speedup | trt plugin | speedup | torch op | speedup | diff --git a/docs/t5_guide.md b/docs/t5_guide.md index c5ac716ce..a0885635c 100644 --- a/docs/t5_guide.md +++ b/docs/t5_guide.md @@ -18,8 +18,12 @@ The FasterTransformer T5 implements the huggingface t5 model (https://huggingfac - [Build the project](#build-the-project) - [How to use](#how-to-use) - [Translation process](#translation-process) + - [Running UL2 on FasterTransformer Pytorch op](#running-ul2-on-fastertransformer-pytorch-op) + - [Running t5-v1.1](#running-t5-v11) + - [Running mt5](#running-mt5) - [Performance](#performance) - [End-to-end translation performance on PyTorch](#end-to-end-translation-performance-on-pytorch) + - [T5-3B on A100-80GB](#t5-3b-on-a100-80gb) - [T5-base on A100-40GB](#t5-base-on-a100-40gb) - [T5-base on V100-16GB](#t5-base-on-v100-16gb) - [T5-small on V100-16GB](#t5-small-on-v100-16gb) @@ -37,6 +41,7 @@ This document describes what FasterTransformer provides for the `T5` model, expl * Data type * FP32 * FP16 + * BF16 * Feature * Multi-GPU multi-node inference * Dynamic random seed @@ -50,18 +55,116 @@ This document describes what FasterTransformer provides for the `T5` model, expl ## Model architecture ### Workflow - The source codes are put in `src/fastertransformer/models/t5`. +* Constructor of T5 Encoder + +| Classification | Name | Data Type | Description | +| :------------: | :--------------------------: | :----------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| [0] | max_batch_size | size_t | **Deprecated, move to input** | +| [1] | max_seq_len | size_t | **Deprecated, move to input** | +| [2] | head_num | size_t | Head number for model configuration | +| [3] | size_per_head | size_t | Size per head for model configuration | +| [4] | inter_size | size_t | The inter size of feed forward network. It is often set to 4 * head_num * size_per_head. | +| [5] | d_model | size_t | The dimension of embedding of transformer input. | +| [6] | num_layer | size_t | Number of transformer layers for model configuration | +| [7] | num_bucket_or_max_seq_len | size_t | Number of bucket in relative position embedding, or max sequence length for absolute position embedding | +| [8] | max_distance | size_t | Max distance for relative position embedding | +| [9] | sm | int | The compute capacity of GPU | +| [10] | q_scaling | float | It is used to scale the query before the batch multiplication of query and key | +| [11] | stream | cudaStream_t | CUDA stream | +| [12] | cublas_wrapper | cublasMMWrapper* | Pointer of cuBLAS wrapper, which is declared in `src/fastertransformer/utils/cublasMMWrapper.h` | +| [13] | allocator | IAllocator* | Pointer of memory allocator, which is declared in `src/fastertransformer/utils/allocator.h` | +| [14] | is_free_buffer_after_forward | bool | If setting to be `true`, FasterTransformer will allocate buffer before forward, and free buffer after forward. When the allocator is based on memory pool, setting to `true` may help reducing the memory usage during inference. | +| [15] | attention_type | AttentionType | Determine fusing the attention or not, remove padding or not, which is declared in `src/fastertransformer/layers/attention_layers/BaseAttentionLayer.h` | +| [16] | sparse | bool | Is using sparsity. 
**Experimental feature** | +| [17] | activation_type | ActivationType | Determine the activation in FFN, which is declared in `src/fastertransformer/layers/attention_layers/FfnLayer.h` | +| [18] | layernorm_type | LayerNormType | Determine using pre-layernorm or post-layernorm, which is declared in `src/fastertransformer/kernels/layernorm_kernels.h` | +| [19] | tensor_para | NcclParam | Tensor Parallel information, which is declared in `src/fastertransformer/utils/nccl_utils.h` | +| [20] | pipeline_para | NcclParam | Pipeline Parallel information, which is declared in `src/fastertransformer/utils/nccl_utils.h` | +| [21] | custom_all_reduce_comm | AbstractCustomComm | Custom all reduction communication for custom all reduction in model parallelism. It is only supported in 8-way tensor parallelism | +| [22] | enable_custom_all_reduce | int | Flag of enabling custom all reduction or not | + +* Input of T5 Encoder + +| Name | Tensor/Parameter Shape | Location | Data Type | Description | +| :-------------: | :----------------------------: | :------: | :------------: | :-------------------------------------------------------------------------------------------------------------------------: | +| input_ids | [batch_size, seq_len] | GPU | int | The input ids | +| sequence_length | [batch_size] | GPU | int | The lengths of input ids | +| inputs_embeds | [batch_size, seq_len, d_model] | GPU | fp32/fp16/bf16 | **Optional**. The embedding after embedding lookup. If this input is not null, using this embedding as input of transformer | + +* Output of T5 Encoder + +| Name | Tensor/Parameter Shape | Location | Data Type | Description | +| :-----------------: | :-------------------------------------: | :------: | :------------: | :-----------------------------: | +| output_hidden_state | [batch_size, sequence_length, d_model_] | GPU | fp32/fp16/bf16 | The output of transformer layer | + +* Constructor of T5 Decoding + +| Classification | Name | Data Type | Description | +| :------------: | :--------------------------: | :----------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| [0] | max_batch_size | size_t | **Deprecated, move to input** | +| [1] | max_seq_len | size_t | **Deprecated, move to input** | +| [2] | mem_max_seq_len | size_t | **Deprecated, move to input** | +| [3] | beam_width | size_t | **Deprecated, move to input** | +| [4] | head_num | size_t | Head number for model configuration | +| [5] | size_per_head | size_t | Size per head for model configuration | +| [6] | inter_size | size_t | The inter size of feed forward network. It is often set to 4 * head_num * size_per_head. | +| [7] | d_model | size_t | The dimension of embedding of transformer input. 
| +| [8] | num_layer | size_t | Number of transformer layers for model configuration | +| [9] | vocab_size | size_t | Vocabulary size for model configuration | +| [10] | num_bucket | size_t | Number of bucket in relative position embedding, or max sequence length for absolute position embedding | +| [11] | max_distance | size_t | Max distance for relative position embedding | +| [12] | q_scaling | float | It is used to scale the query before the batch multiplication of query and key | +| [13] | start_id | int | Start id for vocabulary | +| [14] | end_id | int | End id for vocabulary | +| [15] | beam_search_diversity_rate | float | **Deprecated, move to input** | +| [16] | top_k | size_t | **Deprecated, move to input** | +| [17] | top_p | float | **Deprecated, move to input** | +| [18] | temperature | float | **Deprecated, move to input** | +| [19] | len_penalty | float | **Deprecated, move to input** | +| [20] | repetition_penalty | float | **Deprecated, move to input** | +| [21] | stream | cudaStream_t | CUDA stream | +| [22] | cublas_wrapper | cublasMMWrapper* | Pointer of cuBLAS wrapper, which is declared in `src/fastertransformer/utils/cublasMMWrapper.h` | +| [23] | allocator | IAllocator* | Pointer of memory allocator, which is declared in `src/fastertransformer/utils/allocator.h` | +| [24] | is_free_buffer_after_forward | bool | If setting to be `true`, FasterTransformer will allocate buffer before forward, and free buffer after forward. When the allocator is based on memory pool, setting to `true` may help reducing the memory usage during inference. | +| [25] | cuda_device_prop | cudaDeviceProp* | Pointer of CUDA device properties, which is used to get the properties of hardware like size of shared memory | +| [26] | tensor_para | NcclParam | Tensor Parallel information, which is declared in `src/fastertransformer/utils/nccl_utils.h` | +| [27] | pipeline_para | NcclParam | Pipeline Parallel information, which is declared in `src/fastertransformer/utils/nccl_utils.h` | +| [28] | activation_type | ActivationType | Determine the activation in FFN, which is declared in `src/fastertransformer/layers/attention_layers/FfnLayer.h` | +| [29] | tie_word_embeddings | bool | A flag controlling the scale of transformer output | +| [30] | custom_all_reduce_comm | AbstractCustomComm | Custom all reduction communication for custom all reduction in model parallelism. It is only supported in 8-way tensor parallelism | +| [31] | enable_custom_all_reduce | int | Flag of enabling custom all reduction or not | + +* Input of T5 Decoding + +| Name | Tensor/Parameter Shape | Location | Data Type | Description | +| :------------------------: | :-------------------------------------------: | :------: | :--------------------: | :------------------------------------------------------------------------------------------------------------------------------------: | +| encoder_output | [batch_size, mem_max_seq_len, memory_d_model] | GPU | fp32/fp16/bf16 | The output of T5 Encoder | +| encoder_sequence_length | [batch_size] | GPU | int | The sequence length of encoder input/output | +| stop_words_list | [batch_size, 2, stop_words_length] | GPU | int | **Optional**. When FT generates words in this list, it will stop the generation. An extension of stop id | +| bad_words_list | [batch_size, 2, bad_words_length] | GPU | int | **Optional**. The words in the list will be When FT generates words in this list, it will stop the generation. An extension of stop id | +| start_id | [batch_size] | CPU | int | **Optional**. 
If FT receives this input, FT will replace default start id by it | +| end_id | [batch_size] | CPU | int | **Optional**. If FT receives this input, FT will replace default end id by it | +| runtime_top_k | [1] or [batch_size] | CPU | uint | **Optional**. top_k value for top k sampling | +| runtime_top_p | [1] or [batch_size] | CPU | float | **Optional**. top_p value for top p sampling | +| beam_search_diversity_rate | [1] or [batch_size] | CPU | float | **Optional**. A hyper hyper-parameter for [simple diverse decoding](https://arxiv.org/pdf/1611.08562.pdf) | +| temperature | [1] or [batch_size] | CPU | float | **Optional**. Temperature applied to logits for both beam search and sampling | +| len_penalty | [1] or [batch_size] | CPU | float | **Optional**. Length penalty applied to logits for only beam search | +| repetition_penalty | [1] or [batch_size] | CPU | float | **Optional**. Repetition penalty applied to logits for both beam search and sampling | +| random_seed | [1] or [batch_size] | CPU | unsigned long long int | **Optional**. Random seed to initialize the random table in sampling. | + +* Output of T5 Decoding + +| Name | Tensor/Parameter Shape | Location | Data Type | Description | +| :--------------: | :-----------------------------------------------------------------------------------------------------------------: | :------: | :-------: | :-------------------------------------------------------------------------------: | +| output_ids | [batch_size, beam_width, max_output_seq_len] | GPU | int | The output ids. It contains the input_ids and generated ids | +| sequence_length | [batch_size, beam_width] | GPU | int | The lengths of output ids | +| output_log_probs | [batch_size, beam_width, request_output_seq_len] | GPU | float | **Optional**. It records the log probability of logits at each step for sampling. | +| cum_log_probs | [batch_size, beam_width] | GPU | float | **Optional**. Cumulative log probability of generated sentences | +| cross_attentions | [num_layer / pipeline_para_size, batch_size, beam_width, head_num / tensor_para_size, max_seq_len, mem_max_seq_len] | GPU | float | **Optional**. The attention scores of cross attention | + ### Optimization 1. Kernel optimization: First, since the sequence length of query in `SelfAttention` and `CrossAttention` is always 1, we use customed fused multi-head attention kernel to optimize. Second, we fuse many small operations into one kernel. For example, `AddBiasResidualLayerNorm` combines the adding bias, adding residual of previous block and the computation of layer normalization into 1 kernel. Third, we optimize top k operation and sampling to accelerate the beam search and sampling. Finally, to prevent from recomputing the previous keys and values, we allocate a buffer to store them at each step. Although it takes some additional memory usage, we can save the cost of recomputing, allocating buffer at each step, and the cost of concatenation. @@ -75,10 +178,10 @@ The following section lists the requirements to use FasterTransformer. - CMake >= 3.13 for PyTorch - CUDA 11.0 or newer version - NCCL 2.10 or newer version -- Python 3 is recommended because some features are not supported in python 2 +- Python: Only verify on Python 3. - PyTorch: Verify on 1.10.0, >= 1.5.0 should work. -Recommend use nvcr image like `nvcr.io/nvidia/pytorch:21.11-py3`. +Recommend use nvcr image like `nvcr.io/nvidia/pytorch:22.07-py3`. 
Ensure you have the following components: - [NVIDIA Docker](https://github.com/NVIDIA/nvidia-docker) and NGC container are recommended @@ -96,10 +199,10 @@ For those unable to use the NGC container, to set up the required environment or #### Prepare -You can choose the pytorch version and python version you want. Here, we suggest image `nvcr.io/nvidia/pytorch:21.11-py3`, which contains the PyTorch 1.8.0 and python 3.8. +You can choose the pytorch version and python version you want. Here, we suggest image `nvcr.io/nvidia/pytorch:22.07-py3`, which contains the PyTorch 1.8.0 and python 3.8. ```bash - nvidia-docker run -ti --rm nvcr.io/nvidia/pytorch:21.11-py3 bash + nvidia-docker run -ti --rm nvcr.io/nvidia/pytorch:22.07-py3 bash git clone https://github.com/NVIDIA/FasterTransformer.git mkdir -p FasterTransformer/build cd FasterTransformer/build @@ -108,7 +211,19 @@ You can choose the pytorch version and python version you want. Here, we suggest #### Build the project -* Note: the `xx` of `-DSM=xx` in following scripts means the compute capability of your GPU. For example, 60 (P40) or 61 (P4) or 70 (V100) or 75(T4) or 80 (A100). Default setting is including 70, 75, 80 and 86. +* Note: the `xx` of `-DSM=xx` in following scripts means the compute capability of your GPU. The following table shows the compute capability of common GPUs. + +| GPU | compute capacity | +| :---: | :--------------: | +| P40 | 60 | +| P4 | 61 | +| V100 | 70 | +| T4 | 75 | +| A100 | 80 | +| A30 | 80 | +| A10 | 86 | + +By default, `-DSM` is set by 70, 75, 80 and 86. When users set more kinds of `-DSM`, it requires longer time to compile. So, we suggest setting the `-DSM` for the device you use only. Here, we use `xx` as an example due to convenience. 1. build with PyTorch @@ -121,7 +236,7 @@ You can choose the pytorch version and python version you want. Here, we suggest 2. build with TensorRT - Can use `nvcr.io/nvidia/pytorch:21.11-py3` docker image, too. + Can use `nvcr.io/nvidia/pytorch:22.07-py3` docker image, too. ```bash cmake -DSM=xx -DCMAKE_BUILD_TYPE=Release -DBUILD_TRT=ON -DBUILD_MULTI_GPU=ON .. @@ -134,7 +249,8 @@ You can choose the pytorch version and python version you want. Here, we suggest 1. Run FasterTransformer T5 on PyTorch - Please install transformers first before running the demos by + Please install utils first before running the demos by + ```bash pip install -r ../examples/pytorch/t5/requirement.txt ``` @@ -179,6 +295,8 @@ You can choose the pytorch version and python version you want. Here, we suggest --model t5-small ``` + Data Type can be `fp32`, `fp16` and `bf16` + The outputs should be like to the following: ```bash @@ -213,16 +331,18 @@ You can choose the pytorch version and python version you want. 
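If you are not sure which compute capability your device has, it can be queried at runtime. A minimal sketch with PyTorch (available in the recommended container):

```python
import torch

# Print the -DSM value for every visible GPU, e.g. capability (8, 0) -> -DSM=80 for A100.
for i in range(torch.cuda.device_count()):
    major, minor = torch.cuda.get_device_capability(i)
    print(f"GPU {i}: {torch.cuda.get_device_name(i)} -> -DSM={major}{minor}")
```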
Here, we suggest ``` ```bash - python ../examples/tensorrt/t5/createT5TestData.py # get T5PluginTestIO.npz for test - - python ../examples/tensorrt/t5/extractT5ModelToBIN.py # get T5Model weight for test (need Internet) + # get T5Model weight for test (need Internet or pre-downloaded model) + # Note that the model is saved in ./ft_t5_small/1-gpu, but not ./ft_t5_small + python ../examples/tensorrt/t5/extractT5ModelToBIN.py \ + -in_file t5-small \ + -saved_dir ./ft_t5_small python ../examples/tensorrt/t5/testT5Plugin.py \ --batch_size 32 \ --beam_width 4 \ --max_seq_len 128 \ - --data_type fp32 \ - --sampling_topk 4 + --data_type fp16 \ + --ckpt_path ./ft_t5_small/1-gpu ``` * Input/Output Tensor/Parameter of T5Encoder Plugin @@ -235,51 +355,189 @@ You can choose the pytorch version and python version you want. Here, we suggest | [0] | [] | int32 | max_batch_size | | [1] | [] | int32 | max_seq_len | | [2] | [] | int32 | beam_width (keep the same as decoding) | -| [3] | [] | int32 | head_num | -| [4] | [] | int32 | size_per_head | -| [5] | [] | int32 | inter_size | -| [6] | [] | int32 | d_model | -| [7] | [] | int32 | num_layer | -| [8] | [] | int32 | num_bucket | -| [9] | [] | int32 | max_distance | -| [10] | [] | int32 | sm | -| [11] | [] | float32 | q_scaling | -| [12] | [] | int32 | useFP16 | +| [3] | [] | int32 | sm | +| [4] | [] | int32 | useFP16 | +| [5] | [] | string | checkpoint path of converted FT model | | output tensor | | | | | [0] | [batch_size,max_seq_len,d_model] | foat32/float16 | encoder output | * Input/Output Tensor/Parameter of T5Decoding Plugin -| Classification | Tensor/Parameter Shape | Data Type | Description | -| :-------------: | :---------------------------------: | :-------------: | :---------------------------------: | -| input tensor | | | | -| [0] | [batch_size,max_seq_len,d_model] | foat32/float16 | encoder output | -| [1] | [batch_size] | int32 | real sequence length of each input | -| input parameter | | | | -| [0] | [] | int32 | max_batch_size | -| [1] | [] | int32 | max_seq_len | -| [2] | [] | int32 | mem_max_seq_len | -| [3] | [] | int32 | beam_width | -| [4] | [] | int32 | head_num | -| [5] | [] | int32 | size_per_head | -| [6] | [] | int32 | inter_size | -| [7] | [] | int32 | d_model | -| [8] | [] | int32 | num_layer | -| [9] | [] | int32 | vocab_size | -| [10] | [] | int32 | num_bucket | -| [11] | [] | int32 | max_distance | -| [12] | [] | int32 | start_id | -| [13] | [] | int32 | end_id | -| [14] | [] | float32 | beam_search_diversity_rate | -| [15] | [] | int32 | top_k | -| [16] | [] | float32 | top_p | -| [17] | [] | float32 | temperature | -| [18] | [] | float32 | len_penalty | -| [19] | [] | float32 | repetition_penalty | -| [20] | [] | int32 | usaeFP16 | -| output tensor | | | | -| [0] | [batch_size,beam_width,max_seq_len] | float32/float16 | decoding output | -| [1] | [batch_size,beam_width] | float32/float16 | real sequence length of each output | +| Classification | Tensor/Parameter Shape | Data Type | Description | +| :-------------: | :---------------------------------: | :-------------: | :-----------------------------------: | +| input tensor | | | | +| [0] | [batch_size,max_seq_len,d_model] | foat32/float16 | encoder output | +| [1] | [batch_size] | int32 | real sequence length of each input | +| [2] | [1] or [batch_size] | int32 | top_k | +| [3] | [1] or [batch_size] | float32 | top_p | +| [4] | [1] or [batch_size] | float32 | beam_search_diversity_rate | +| [5] | [1] or [batch_size] | float32 | temperature | +| [6] | [1] or 
[batch_size] | float32 | len_penalty | +| [7] | [1] or [batch_size] | float32 | repetition_penalty | +| input parameter | | | | +| [0] | [] | int32 | max_batch_size | +| [1] | [] | int32 | max_seq_len | +| [2] | [] | int32 | mem_max_seq_len | +| [3] | [] | int32 | beam_width | +| [4] | [] | int32 | usaeFP16 | +| [5] | [] | string | checkpoint path of converted FT model | +| output tensor | | | | +| [0] | [batch_size,beam_width,max_seq_len] | float32/float16 | decoding output | +| [1] | [batch_size,beam_width] | float32/float16 | real sequence length of each output | + +The model configuration are stored in `config.ini` of checkpoint path. For example, after running, + +``` +python ../examples/tensorrt/t5/extractT5ModelToBIN.py \ + -in_file t5-small \ + -saved_dir ./ft_t5_small` +``` + +users can see the model configuration in `./ft_t5_small/1-gpu/config.ini` + +### Running UL2 on FasterTransformer Pytorch op + +[UL2](https://arxiv.org/pdf/2205.05131v1.pdf) (Unifying Language Learning Paradigms) is published by Google. The following is its introduction: + +> UL2 is a unified framework for pretraining models that are universally effective across datasets and setups. UL2 uses Mixture-of-Denoisers (MoD), apre-training objective that combines diverse pre-training paradigms together. UL2 introduces a notion of mode switching, wherein downstream fine-tuning is associated with specific pre-training schemes. + +We show how to sever UL2 by FasterTransformer PyTorch op on huggingface's model in this section. + + 3.1 Download model (It requires some time because the model size is about 40GBs) + + ``` + sudo apt-get install git-lfs + git lfs install + git lfs clone https://huggingface.co/google/ul2 + ``` + + 3.2 Convert the checkpoint to FT + + Because loading UL2 model on pytorch and do prprocessing takes long time, and `summarization.py` only supports loading FT's model from binary files, we convert the pytorch checkpoint to FasterTransformer by converter `huggingface_t5_ckpt_convert.py`. 
+ + ``` + python3 ../examples/pytorch/t5/utils/huggingface_t5_ckpt_convert.py \ + -saved_dir ul2/c-models \ + -in_file ul2/ \ + -inference_tensor_para_size 2 \ + -weight_data_type fp32 + ``` + + 3.3 Run UL2 on summarization task + + ``` + mpirun -n 2 python3 ../examples/pytorch/t5/summarization.py \ + --ft_model_location ul2/c-models/ \ + --hf_model_location ul2/ \ + --test_ft \ + --data_type bf16 \ + --tensor_para_size 2 + ``` + + The results would be like + + ``` + rouge1 : 23.673944166014593 + rouge2 : 5.946485383012474 + rougeL : 14.749827731626247 + rougeLsum : 20.217932008044144 + ``` + +### Running t5-v1.1 + + 3.1 Download model (It requires some time because the model size is about 40GBs) + + ``` + sudo apt-get install git-lfs + git lfs install + git lfs clone https://huggingface.co/google/t5-v1_1-base + ``` + + + 3.2 Convert the checkpoint to FT + + ``` + python3 ../examples/pytorch/t5/utils/huggingface_t5_ckpt_convert.py \ + -saved_dir t5-v1_1-base/c-models \ + -in_file t5-v1_1-base/ \ + -inference_tensor_para_size 1 \ + -weight_data_type fp32 + ``` + + 3.3 Run t5-v1.1 on summarization task + + ``` + python3 ../examples/pytorch/t5/summarization.py \ + --ft_model_location t5-v1_1-base/c-models/ \ + --hf_model_location t5-v1_1-base/ \ + --test_ft \ + --test_hf + ``` + + The results would be like + + ``` + Hugging Face (total latency: 21.826529 sec) + rouge1 : 10.786476875527406 + rouge2 : 1.8231246974441166 + rougeL : 8.652689713627165 + rougeLsum : 10.326607305635523 + Faster Transformers (total latency: 7.036808000000001 sec) + rouge1 : 10.91735083630513 + rouge2 : 1.8454654301092783 + rougeL : 8.76872604148143 + rougeLsum : 10.453229536094794 + ``` + + * Note that these models are not fine-tuned, so running with FP16 or setting topk > 1 may lead to unstable results. + +### Running mt5 + + 3.1 Download model (It requires some time because the model size is about 40GBs) + + ``` + sudo apt-get install git-lfs + git lfs install + git lfs clone https://huggingface.co/google/mt5-base + ``` + + + 3.2 Convert the checkpoint to FT + + ``` + python3 ../examples/pytorch/t5/utils/huggingface_t5_ckpt_convert.py \ + -saved_dir mt5-base/c-models \ + -in_file mt5-base/ \ + -inference_tensor_para_size 1 \ + -weight_data_type fp32 + ``` + + 3.3 Run mt5 on summarization task + + ``` + python3 ../examples/pytorch/t5/summarization.py \ + --ft_model_location mt5-base/c-models/ \ + --hf_model_location mt5-base/ \ + --test_ft \ + --test_hf + ``` + + The results would be like + + ``` + Hugging Face (total latency: 3.143815 sec) + rouge1 : 4.636193727758547 + rouge2 : 0.20661157024793395 + rougeL : 3.7990194456844026 + rougeLsum : 4.274724726798723 + Faster Transformers (total latency: 1.3952859999999998 sec) + rouge1 : 4.726148174547172 + rouge2 : 0.20818875780707846 + rougeL : 3.8698557495145516 + rougeLsum : 4.3507453221528 + ``` + + * Note that these models are not fine-tuned, so running with FP16 or setting topk > 1 may lead to unstable results. ## Performance @@ -287,6 +545,7 @@ Hardware settings: * CPU: Intel(R) Xeon(R) Gold 6132 CPU @ 2.60GHz * V100-16GB (with mclk 877MHz, pclk 1380MHz) with Intel(R) Xeon(R) CPU E5-2698 v4 @ 2.20GHz (dgx-1 server) * A100-40GB +* A100-80GB (with mclk 1593, pclk 1410) with AMD EPYC 7742 64-Core Processor To run the following benchmark, we need to install the unix computing tool "bc" by @@ -296,10 +555,40 @@ apt-get install bc ### End-to-end translation performance on PyTorch -We demonstrate the throughput of huggingface and FT for end-to-end translation on V100. 
We also skip the BLEU score because the score of PyTorch, FT Decoder and FT Decoding are close. +We demonstrate the throughput of huggingface and FT for end-to-end translation on V100 and A100. We also skip the BLEU score because the score of PyTorch, FT Decoder and FT Decoding are close. Although the BLEU scores of all methods are close, the generated results may differ slightly and the number of generated tokens may also differ. Therefore, we use throughput rather than latency to show the performance in this benchmark. +#### T5-3B on A100-80GB + +* T5-3B on FP16 with beamsearch + +| Batch Size | beamsearch | Precision | FT Decoding
Throughput (token/sec) | +| :--------: | :--------: | :-------: | :--------------------------------------: | +| 1 | 4 | FP16 | 192 | +| 1 | 32 | FP16 | 140 | +| 8 | 4 | FP16 | 787 | +| 8 | 32 | FP16 | 271 | +| 32 | 4 | FP16 | 1540 | +| 32 | 32 | FP16 | OOM | +| 128 | 4 | FP16 | 1907 | +| 128 | 32 | FP16 | OOM | + +When the batch size is 32 and the beam width is 32, the k/v caches require about 90 GB and lead to OOM. + +* T5-3B on FP16 with sampling + +| Batch Size | sampling | Precision | FT Decoding
Throughput (token/sec) | +| :--------: | :------: | :-------: | :--------------------------------------: | +| 1 | 4 | FP16 | 218 | +| 1 | 0.5 | FP16 | 217 | +| 8 | 4 | FP16 | 932 | +| 8 | 0.5 | FP16 | 908 | +| 32 | 4 | FP16 | 2416 | +| 32 | 0.5 | FP16 | 2344 | +| 128 | 4 | FP16 | 5004 | +| 128 | 0.5 | FP16 | 4891 | + #### T5-base on A100-40GB * T5-base on FP32 with beamsearch diff --git a/docs/vit_guide.md b/docs/vit_guide.md index 54ffaf811..42fe351d9 100644 --- a/docs/vit_guide.md +++ b/docs/vit_guide.md @@ -56,7 +56,7 @@ In this demo, you can run Faster ViT as a C++ program. - Python 3 is recommended because some features are not supported in python 2 - PyTorch: Verify on 1.10.0, >= 1.5.0 should work. -Recommand to use image `nvcr.io/nvidia/pytorch:21.07-py3`. +Recommend to use image `nvcr.io/nvidia/pytorch:21.07-py3`. > docker run -ti --gpus all --rm nvcr.io/nvidia/pytorch:21.07-py3 bash @@ -153,12 +153,11 @@ Refer to [Guide of ViT Quantization Toolkit](../examples/pytorch/vit/ViT-quantiz ```bash cd $WORKSPACE/examples/pytorch/vit/ViT-quantization export DATA_DIR=Path to the dataset -export CKPT_DIR=Path to the ViT checkpoints python -m torch.distributed.launch --nproc_per_node 1 \ --master_port 12345 main.py \ --calib \ --name vit \ - --pretrained_dir $CKPT_DIR/ViT-B_16.npz \ + --pretrained_dir ViT-B_16.npz \ --data-path $DATA_DIR \ --model_type ViT-B_16 \ --img_size 384 \ @@ -167,7 +166,7 @@ python -m torch.distributed.launch --nproc_per_node 1 \ --quant-mode ft2 \ --calibrator percentile \ --percentile 99.99 \ - --calib-output-path $CKPT_DIR + --calib-output-path . ``` **NOTE: Difference between `--quant-mode ft1` and `--quant-mode ft2`**: @@ -187,7 +186,7 @@ cd $WORKSPACE/examples/pytorch/vit python infer_visiontransformer_int8_op.py \ --model_type=ViT-B_16 \ --img_size 384 \ - --calibrated_dir $CKPT_DIR/ViT-B_16_calib.pth \ + --calibrated_dir ViT-B_16_calib.pth \ --batch-size=32 \ --th-path=$WORKSPACE/build/lib/libpyt_vit.so \ --quant-mode ft2 @@ -216,6 +215,19 @@ python infer_visiontransformer_plugin.py \ ``` +**INT8 TensorRT plugin** +```bash +cd $WORKSPACE/examples/tensorrt/vit +#INT8 engine build & infer +python infer_visiontransformer_int8_plugin.py \ + --model_type=ViT-B_16 \ + --img_size=384 \ + --pretrained_dir=$WORKSPACE/examples/pytorch/vit/ViT-quantization/ViT-B_16_calib.pth \ + --plugin_path=../../../build/lib/libvit_plugin.so \ + --batch-size=32 + +``` + ## Performance Hardware settings: diff --git a/docs/xlnet_guide.md b/docs/xlnet_guide.md index 360f21136..9daaa6d75 100644 --- a/docs/xlnet_guide.md +++ b/docs/xlnet_guide.md @@ -33,11 +33,10 @@ In this demo, you can run the XLNet as a C++ program. - CMake >= 3.8 - CUDA 11.0 or newer version -- NCCL 2.10 or newer version - Python 3 is recommended because some features are not supported in python 2 - Tensorflow: Verify on 1.15, 1.13 and 1.14 should work. -Recommand to use image `nvcr.io/nvidia/tensorflow:20.12-tf1-py3`. +Recommend to use image `nvcr.io/nvidia/tensorflow:20.12-tf1-py3`. 
```bash docker run -ti --gpus all --rm nvcr.io/nvidia/tensorflow:20.12-tf1-py3 bash @@ -76,7 +75,7 @@ cd /workspace/FasterTransformer/build ```bash -./bin/xlnet_gemm +./bin/xlnet_gemm ./bin/xlnet_example ``` Data Type = 0 (FP32) or 1 (FP16) or 2 (BF16) @@ -92,7 +91,7 @@ Data Type = 0 (FP32) or 1 (FP16) or 2 (BF16) ./bin/xlnet_example 8 12 128 12 64 1 ``` -- Run XLNet under FP16 on C++ +- Run XLNet under BF16 on C++ ```bash ./bin/xlnet_gemm 8 128 12 64 2 ./bin/xlnet_example 8 12 128 12 64 2 @@ -101,7 +100,7 @@ Data Type = 0 (FP32) or 1 (FP16) or 2 (BF16) #### Verify the correctness ```bash cd examples/tensorflow/xlnet -bash downloadModel.sh #Dowload the input and model data +bash downloadModel.sh #Download the input and model data bash verifyCorrectness.sh # For FP32 model bash verifyCorrectness.sh -f 1 #For FP16 model ``` diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/cpp/CMakeLists.txt b/examples/cpp/CMakeLists.txt index b67cd01f2..9c5862afc 100644 --- a/examples/cpp/CMakeLists.txt +++ b/examples/cpp/CMakeLists.txt @@ -22,7 +22,6 @@ add_subdirectory(swin_int8) add_subdirectory(vit) add_subdirectory(vit_int8) -if(BUILD_MULTI_GPU) - add_subdirectory(gptj) - add_subdirectory(multi_gpu_gpt) -endif() +add_subdirectory(gptj) +add_subdirectory(gptneox) +add_subdirectory(multi_gpu_gpt) diff --git a/examples/cpp/bert/CMakeLists.txt b/examples/cpp/bert/CMakeLists.txt index e1672d85b..6440fe50b 100644 --- a/examples/cpp/bert/CMakeLists.txt +++ b/examples/cpp/bert/CMakeLists.txt @@ -18,3 +18,9 @@ target_link_libraries(bert_example PUBLIC -lcublas -lcublasLt -lcudart -lcuspars else() target_link_libraries(bert_example PUBLIC -lcublas -lcublasLt -lcudart Bert) endif() + +if(BUILD_MULTI_GPU) +add_executable(bert_triton_example bert_triton_example.cc) +target_link_libraries(bert_triton_example PUBLIC -lcublas -lcublasLt -lcudart -lpthread + BertTritonBackend TransformerTritonBackend mpi_utils nccl_utils) +endif() diff --git a/examples/cpp/bert/bert_config.ini b/examples/cpp/bert/bert_config.ini new file mode 100644 index 000000000..134d0c982 --- /dev/null +++ b/examples/cpp/bert/bert_config.ini @@ -0,0 +1,15 @@ +[ft_instance_hyperparameter] +tensor_para_size=2 +pipeline_para_size=1 +data_type=fp16 +is_sparse=0 +is_remove_padding=0 +int8_mode=0 +enable_custom_all_reduce=0 + +model_name=bert_base +model_dir=/models/Bert/HF/bert-base/c-models/2-gpu/ + +[request] +request_batch_size=8 ; determine by the request +request_seq_len=32 ; determine by the request diff --git a/examples/cpp/bert/bert_example.cc b/examples/cpp/bert/bert_example.cc index 28b4d7a5f..a33554120 100644 --- a/examples/cpp/bert/bert_example.cc +++ b/examples/cpp/bert/bert_example.cc @@ -25,13 +25,13 @@ int bertExample(size_t batch_size, size_t seq_len, size_t head_num, size_t size_per_head, - bool is_remove_padding, - bool allow_gemm_test = false); + bool is_remove_padding, + bool allow_gemm_test = false); int main(int argc, char** argv) { if (argc != 8 && argc != 9) { - FT_LOG_ERROR("bert_example batch_size num_layers seq_len head_num size_per_head is_fp16 is_remove_padding"); + FT_LOG_ERROR("bert_example batch_size num_layers seq_len head_num size_per_head data_type is_remove_padding"); FT_LOG_ERROR("e.g., ./bin/bert_example 32 12 32 12 64 0 0"); return 0; } @@ -40,21 +40,28 @@ int main(int argc, char** argv) allow_gemm_test = (atoi(argv[8]) == 1) ? 
true : false; } - int batch_size = atoi(argv[1]); - int num_layers = atoi(argv[2]); - int seq_len = atoi(argv[3]); - int head_num = atoi(argv[4]); - int size_per_head = atoi(argv[5]); - bool is_remove_padding = static_cast(atoi(argv[7])); + int batch_size = atoi(argv[1]); + int num_layers = atoi(argv[2]); + int seq_len = atoi(argv[3]); + int head_num = atoi(argv[4]); + int size_per_head = atoi(argv[5]); + bool is_remove_padding = static_cast(atoi(argv[7])); + const CublasDataType data_type = static_cast(atoi(argv[6])); // 0 FP32, 1 FP16, 2 BF 16 - if (atoi(argv[6]) == 0) { + if (data_type == FLOAT_DATATYPE) { return bertExample( batch_size, num_layers, seq_len, head_num, size_per_head, is_remove_padding, allow_gemm_test); } - else if (atoi(argv[6]) == 1) { + else if (data_type == HALF_DATATYPE) { return bertExample( batch_size, num_layers, seq_len, head_num, size_per_head, is_remove_padding, allow_gemm_test); } +#ifdef ENABLE_BF16 + else if (data_type == BFLOAT16_DATATYPE) { + return bertExample<__nv_bfloat16>( + batch_size, num_layers, seq_len, head_num, size_per_head, is_remove_padding, allow_gemm_test); + } +#endif else { throw std::runtime_error(std::string("[FT][ERROR] is_fp16 should be 0 (use float)" "or 1 (use half). \n ")); @@ -67,16 +74,16 @@ int bertExample(size_t batch_size, size_t seq_len, size_t head_num, size_t size_per_head, - bool is_remove_padding, - bool allow_gemm_test) + bool is_remove_padding, + bool allow_gemm_test) { printf("[INFO] Device: %s \n", getDeviceName().c_str()); - + print_mem_usage("Before loading model"); const size_t hidden_units = head_num * size_per_head; - const size_t inter_size = 4 * hidden_units; + const size_t inter_size = 4 * hidden_units; - cudaStream_t stream; - cublasHandle_t cublas_handle; + cudaStream_t stream; + cublasHandle_t cublas_handle; cublasLtHandle_t cublaslt_handle; cudaStreamCreate(&stream); cublasCreate(&cublas_handle); @@ -101,6 +108,11 @@ int bertExample(size_t batch_size, if (std::is_same::value) { cublas_wrapper.setFP16GemmConfig(); } +#ifdef ENABLE_BF16 + else if (std::is_same::value) { + cublas_wrapper.setBF16GemmConfig(); + } +#endif else if (std::is_same::value) { cublas_wrapper.setFP32GemmConfig(); } @@ -109,8 +121,8 @@ int bertExample(size_t batch_size, AttentionType attention_type = getAttentionType(size_per_head, getSMVersion(), is_remove_padding, seq_len); - Bert bert = Bert(batch_size, - seq_len, + Bert bert = Bert(0, // max_batch_size_, deprecated + 0, // max_seq_len_, deprecated head_num, size_per_head, inter_size, @@ -131,8 +143,8 @@ int bertExample(size_t batch_size, deviceMalloc(&out_tensor, batch_size * seq_len * head_num * size_per_head, false); deviceMalloc(&from_tensor, batch_size * seq_len * head_num * size_per_head, false); - int* h_sequence_lengths = new int[batch_size]; - unsigned int seed = 0; + int* h_sequence_lengths = new int[batch_size]; + unsigned int seed = 0; for (uint i = 0; i < batch_size; i++) { h_sequence_lengths[i] = rand_r(&seed) % seq_len; } @@ -153,11 +165,13 @@ int bertExample(size_t batch_size, getTensorType(), std::vector{batch_size, seq_len, (size_t)(head_num * size_per_head)}, out_tensor}}; + print_mem_usage("After loading model"); // warmup for (int i = 0; i < 10; i++) { bert.forward(&output_tensors, &input_tensors, &bert_weights); } + print_mem_usage("After inference"); // profile time const int ite = 10; @@ -179,6 +193,9 @@ int bertExample(size_t batch_size, #ifdef SPARSITY_ENABLED cusparseLtDestroy(&cusparselt_handle); #endif + deviceFree(d_sequence_lengths); + 
deviceFree(from_tensor); + deviceFree(out_tensor); delete cublas_algo_map; delete cublas_wrapper_mutex; return 0; diff --git a/examples/cpp/bert/bert_triton_example.cc b/examples/cpp/bert/bert_triton_example.cc new file mode 100644 index 000000000..e80b9759f --- /dev/null +++ b/examples/cpp/bert/bert_triton_example.cc @@ -0,0 +1,301 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "3rdparty/INIReader.h" +#include "src/fastertransformer/triton_backend/bert/BertTritonModel.h" +#include "src/fastertransformer/utils/mpi_utils.h" +#include "src/fastertransformer/utils/nccl_utils.h" + +namespace ft = fastertransformer; + +template +std::vector>> +broadcastRequest(const std::vector& h_input_hidden_state, + const std::vector& h_input_seq_len, + const size_t request_batch_size, + const size_t request_seq_len, + const size_t head_num, + const size_t size_per_head, + const int node_id, + const int gpu_count, + std::vector* pointer_record) +{ + std::vector>> request_list; + for (int device_id = 0; device_id < gpu_count; device_id++) { + ft::check_cuda_error(cudaSetDevice(device_id)); + + T* d_input_hidden_state; + int* d_input_seq_len; + + ft::deviceMalloc(&d_input_hidden_state, h_input_hidden_state.size(), false); + ft::deviceMalloc(&d_input_seq_len, h_input_seq_len.size(), false); + ft::cudaH2Dcpy(d_input_hidden_state, h_input_hidden_state.data(), h_input_hidden_state.size()); + ft::cudaH2Dcpy(d_input_seq_len, h_input_seq_len.data(), h_input_seq_len.size()); + + request_list.push_back(std::shared_ptr>( + new std::unordered_map(std::unordered_map{ + {"input_hidden_state", + triton::Tensor{triton::MEMORY_GPU, + std::is_same::value ? 
triton::TYPE_FP32 : triton::TYPE_FP16, + std::vector{request_batch_size, request_seq_len, head_num * size_per_head}, + d_input_hidden_state}}, + {"sequence_lengths", + triton::Tensor{triton::MEMORY_GPU, + triton::TYPE_INT32, + std::vector{request_batch_size}, + d_input_seq_len}}}))); + + pointer_record->push_back(d_input_hidden_state); + pointer_record->push_back(d_input_seq_len); + } + + return request_list; +} + +template +std::vector>> +prepareRequest(std::string ini_name, const int node_id, const int gpu_count, std::vector* pointer_record) +{ + INIReader reader = INIReader(ini_name); + if (reader.ParseError() < 0) { + std::cout << "[ERROR] Can't load '" << ini_name << "'\n"; + ft::FT_CHECK(false); + } + + const size_t request_batch_size = reader.GetInteger("request", "request_batch_size"); + const size_t request_seq_len = reader.GetInteger("request", "request_seq_len"); + const std::string model_name = reader.Get("ft_instance_hyperparameter", "model_name"); + const std::string model_dir = reader.Get("ft_instance_hyperparameter", "model_dir"); + + INIReader model_config_reader = INIReader(model_dir + "/config.ini"); + if (model_config_reader.ParseError() < 0) { + std::cout << "[ERROR] Can't load '" << model_dir << "/config.ini" + << "'\n"; + ft::FT_CHECK(false); + } + + const size_t head_num = model_config_reader.GetInteger("bert", "head_num"); + const size_t size_per_head = model_config_reader.GetInteger("bert", "size_per_head"); + + std::vector h_input_hidden_state; + std::vector h_input_seq_len; + srand(0); + for (size_t i = 0; i < request_batch_size * request_seq_len * head_num * size_per_head; ++i) { + T random_num = (T)((random() % 1000) / 1000.f - 0.5f); + h_input_hidden_state.push_back(random_num); + } + for (uint i = 0; i < request_batch_size; i++) { + h_input_seq_len.push_back(random() % request_seq_len); + } + + auto request_list = broadcastRequest(h_input_hidden_state, + h_input_seq_len, + request_batch_size, + request_seq_len, + head_num, + size_per_head, + node_id, + gpu_count, + pointer_record); + return request_list; +} + +int threadCreateModelInstances(std::shared_ptr model, + std::vector>* model_instances, + const int device_id, + const int rank, + std::pair, std::vector> nccl_params, + std::shared_ptr custom_all_reduce_comm = nullptr) +{ + ft::check_cuda_error(cudaSetDevice(device_id)); + cudaStream_t stream; + ft::check_cuda_error(cudaStreamCreate(&stream)); + model->createSharedWeights(device_id, rank); + auto model_instance = model->createModelInstance(device_id, rank, stream, nccl_params, custom_all_reduce_comm); + model_instances->at(device_id) = std::move(model_instance); + FT_LOG_INFO("model instance %d is created", device_id); + ft::print_mem_usage(); + return 0; +} + +int threadForward(std::unique_ptr* model_instance, + std::shared_ptr> request, + std::shared_ptr>* output_tensors, + const int device_id) +{ + ft::check_cuda_error(cudaSetDevice(device_id)); + *output_tensors = (*model_instance)->forward(request); + return 0; +} + +template +int bert_triton_example(int argc, char* argv[]) +{ + /* + Prepare the nccl ids, node id, device id and world size + by MPI or triton + */ + + ft::mpi::initialize(&argc, &argv); + int node_id = ft::mpi::getCommWorldRank(); + int node_num = ft::mpi::getCommWorldSize(); + + // Note: Only supports that all nodes have same gpu count + const int gpu_count = ft::getDeviceCount(); + const int world_size = node_num * gpu_count; + std::string ini_name = argc >= 2 ? 
std::string(argv[1]) : "../examples/cpp/bert/bert_config.ini"; + + // step 1: Create model + INIReader reader = INIReader(ini_name); + std::shared_ptr model = std::make_shared>( + reader.GetInteger("ft_instance_hyperparameter", "tensor_para_size"), + reader.GetInteger("ft_instance_hyperparameter", "pipeline_para_size"), + reader.GetInteger("ft_instance_hyperparameter", "enable_custom_all_reduce", 0), + reader.Get("ft_instance_hyperparameter", "model_dir"), + reader.GetInteger("ft_instance_hyperparameter", "int8_mode"), + reader.GetInteger("ft_instance_hyperparameter", "is_sparse"), + reader.GetInteger("ft_instance_hyperparameter", "is_remove_padding")); + std::cout << model->toString(); + int tensor_para_size = model->getTensorParaSize(); + int pipeline_para_size = model->getPipelineParaSize(); + ft::FT_CHECK_WITH_INFO(world_size == (tensor_para_size * pipeline_para_size), + fmtstr("World Size(%d) != Tensor Parallel Size (%d) * Pipeline Parallel Size (%d) !", + world_size, + tensor_para_size, + pipeline_para_size)); + + // step 2: Initialize the NCCL + std::pair, std::vector> nccl_params = model->createNcclParams(node_id); + cudaDeviceSynchronize(); + + // Optional Step: create custom all reduce comm + std::vector> custom_all_reduce_comms; + model->createCustomComms(&custom_all_reduce_comms, world_size); + + // step 3: Create model instances + std::vector> model_instances((size_t)gpu_count); + std::vector threads; + + threads.clear(); + + for (int device_id = 0; device_id < gpu_count; device_id++) { + const int rank = node_id * gpu_count + device_id; + threads.push_back(std::thread(threadCreateModelInstances, + model, + &model_instances, + device_id, + rank, + nccl_params, + custom_all_reduce_comms[rank])); + } + for (auto& t : threads) { + t.join(); + } + + // step 4: prepare request + std::vector pointer_record; // Used to prevent the pointers are release after leaving functions + std::vector>> request_list = + prepareRequest(ini_name, node_id, gpu_count, &pointer_record); + FT_LOG_INFO("request is created"); + + // step 5: Forward + std::vector>> output_tensors_lists( + (size_t)gpu_count); + for (int i = 0; i < 2; i++) { + threads.clear(); + for (int device_id = 0; device_id < gpu_count; device_id++) { + threads.push_back(std::thread(threadForward, + &model_instances[device_id], + request_list[device_id], + &output_tensors_lists[device_id], + device_id)); + } + for (auto& t : threads) { + t.join(); + } + } + FT_LOG_INFO("forward is completed."); + const size_t request_batch_size = output_tensors_lists[0].get()->at("output_hidden_state").shape[0]; + const size_t request_seq_len = output_tensors_lists[0].get()->at("output_hidden_state").shape[1]; + const size_t hidden_dim = output_tensors_lists[0].get()->at("output_hidden_state").shape[1]; + + if (node_id == 0) { + ft::print_abs_mean((T*)output_tensors_lists[0].get()->at("output_hidden_state").data, + request_batch_size * request_seq_len * hidden_dim, + (cudaStream_t)0, + "output_tensors_lists[0].at(\"output_hidden_state\").data"); + } + + // test time + struct timeval start, end; + ft::mpi::barrier(); + cudaDeviceSynchronize(); + gettimeofday(&start, NULL); + + const int ite = 20; + for (int i = 0; i < ite; i++) { + threads.clear(); + for (int device_id = 0; device_id < gpu_count; device_id++) { + threads.push_back(std::thread(threadForward, + &model_instances[device_id], + request_list[device_id], + &output_tensors_lists[device_id], + device_id)); + } + for (auto& t : threads) { + t.join(); + } + } + + cudaDeviceSynchronize(); + 
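Steps 3 and 5 of the new bert_triton_example.cc above rely on the same fan-out idiom: spawn one host thread per visible GPU, bind each thread to its device with cudaSetDevice() before any other CUDA call, and join every thread before timing or reading outputs. A minimal sketch of just that idiom, assuming only the CUDA runtime and the standard library; doPerDeviceWork is a hypothetical stand-in for createModelInstance()/forward():

    // One worker thread per GPU; the CUDA device is per-thread state, so each
    // worker must call cudaSetDevice() itself before touching any CUDA resource.
    #include <cuda_runtime.h>
    #include <cstdio>
    #include <thread>
    #include <vector>

    static void doPerDeviceWork(int device_id)
    {
        cudaSetDevice(device_id);  // bind this thread to its GPU
        cudaStream_t stream;
        cudaStreamCreate(&stream);
        // ... per-device work (weight loading, a forward pass, ...) goes here ...
        cudaStreamDestroy(stream);
        printf("device %d done\n", device_id);
    }

    int main()
    {
        int gpu_count = 0;
        cudaGetDeviceCount(&gpu_count);

        std::vector<std::thread> threads;
        for (int device_id = 0; device_id < gpu_count; device_id++) {
            threads.emplace_back(doPerDeviceWork, device_id);
        }
        for (auto& t : threads) {
            t.join();  // wait for every device before measuring time or reading results
        }
        return 0;
    }

Joining before the barrier/gettimeofday pair that follows keeps partially finished devices from skewing the reported latency.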
ft::mpi::barrier(); + + gettimeofday(&end, NULL); + + FT_LOG_INFO("request_batch_size %d request_seq_len %d" + " FT-CPP-BERT-Triton-time %.2f ms", + request_batch_size, + request_seq_len, + ((end.tv_sec - start.tv_sec) * 1000 + (end.tv_usec - start.tv_usec) * 0.001) / ite); + + ft::mpi::finalize(); + return 0; +} + +template int bert_triton_example(int argc, char* argv[]); +template int bert_triton_example(int argc, char* argv[]); + +int main(int argc, char* argv[]) +{ + std::string ini_name = argc >= 2 ? std::string(argv[1]) : "../examples/cpp/bert/bert_config.ini"; + INIReader reader = INIReader(ini_name); + if (reader.ParseError() < 0) { + std::cout << "[ERROR] Can't load '" << ini_name << "'\n"; + ft::FT_CHECK(false); + } + + const std::string data_type = reader.Get("ft_instance_hyperparameter", "data_type"); + if (data_type == "fp32") { + bert_triton_example(argc, argv); + } + else if (data_type == "fp16") { + bert_triton_example(argc, argv); + } + + return 0; +} diff --git a/examples/cpp/bert_int8/CMakeLists.txt b/examples/cpp/bert_int8/CMakeLists.txt index 1d04c97d6..e67a624c9 100644 --- a/examples/cpp/bert_int8/CMakeLists.txt +++ b/examples/cpp/bert_int8/CMakeLists.txt @@ -13,4 +13,4 @@ # limitations under the License. add_executable(bert_int8_example bert_int8_example.cc) -target_link_libraries(bert_int8_example PUBLIC -lcublas -lcublasLt -lcudart BertINT8) +target_link_libraries(bert_int8_example PUBLIC -lcublas -lcublasLt -lcudart BertINT8 nvtx_utils) diff --git a/examples/cpp/bert_int8/bert_int8_example.cc b/examples/cpp/bert_int8/bert_int8_example.cc index f1534ebe0..a84e6395b 100644 --- a/examples/cpp/bert_int8/bert_int8_example.cc +++ b/examples/cpp/bert_int8/bert_int8_example.cc @@ -15,11 +15,17 @@ */ #include "src/fastertransformer/models/bert_int8/BertINT8.h" +#include "src/fastertransformer/utils/nvtx_utils.h" +#include #ifndef CUDART_VERSION #error CUDART_VERSION Undefined! #endif +#ifdef USE_NVTX +bool NVTX_ON = true; +#endif + using namespace fastertransformer; template @@ -28,9 +34,9 @@ int bertINT8Example(size_t batch_size, size_t seq_len, size_t head_num, size_t size_per_head, - int int8_mode, - bool is_remove_padding, - bool allow_gemm_test = false); + int int8_mode, + bool is_remove_padding, + bool allow_gemm_test = false); int main(int argc, char** argv) { @@ -45,12 +51,12 @@ int main(int argc, char** argv) allow_gemm_test = (atoi(argv[9]) == 1) ? 
true : false; } - int batch_size = atoi(argv[1]); - int num_layers = atoi(argv[2]); - int seq_len = atoi(argv[3]); - int head_num = atoi(argv[4]); - int size_per_head = atoi(argv[5]); - int int8_mode = atoi(argv[8]); + int batch_size = atoi(argv[1]); + int num_layers = atoi(argv[2]); + int seq_len = atoi(argv[3]); + int head_num = atoi(argv[4]); + int size_per_head = atoi(argv[5]); + int int8_mode = atoi(argv[8]); bool is_remove_padding = static_cast(atoi(argv[7])); if (atoi(argv[6]) == 0) { @@ -73,16 +79,16 @@ int bertINT8Example(size_t batch_size, size_t seq_len, size_t head_num, size_t size_per_head, - int int8_mode, - bool is_remove_padding, - bool allow_gemm_test) + int int8_mode, + bool is_remove_padding, + bool allow_gemm_test) { printf("[INFO] Device: %s \n", getDeviceName().c_str()); const size_t hidden_units = head_num * size_per_head; - const size_t inter_size = 4 * hidden_units; + const size_t inter_size = 4 * hidden_units; - cudaStream_t stream; + cudaStream_t stream; cublasLtHandle_t cublaslt_handle; cudaStreamCreate(&stream); cublasLtCreate(&cublaslt_handle); @@ -138,8 +144,8 @@ int bertINT8Example(size_t batch_size, deviceMalloc(&out_tensor, batch_size * seq_len * head_num * size_per_head, false); deviceMalloc(&from_tensor, batch_size * seq_len * head_num * size_per_head, false); - int* h_sequence_lengths = new int[batch_size]; - unsigned int seed = 0; + int* h_sequence_lengths = new int[batch_size]; + unsigned int seed = 0; for (uint i = 0; i < batch_size; i++) { h_sequence_lengths[i] = rand_r(&seed) % seq_len; } @@ -168,12 +174,16 @@ int bertINT8Example(size_t batch_size, // profile time const int ite = 100; + cudaProfilerStart(); CudaTimer cuda_timer(stream); + nvtx::resetScope(); + nvtx::addScope("BertInt8"); cuda_timer.start(); for (int i = 0; i < ite; i++) { bert_int8.forward(&output_tensors, &input_tensors, &bert_layer_weights); } float total_time = cuda_timer.stop(); + cudaProfilerStop(); printf("[INFO] batch_size %ld seq_len %ld layer %ld " "FT-CPP-time %.2f ms (%d iterations) \n", diff --git a/examples/cpp/decoding/decoding_example.cc b/examples/cpp/decoding/decoding_example.cc index f14983b0f..486ab6c67 100644 --- a/examples/cpp/decoding/decoding_example.cc +++ b/examples/cpp/decoding/decoding_example.cc @@ -32,32 +32,33 @@ int decodingExample(const size_t batch_size, const size_t max_seq_len, const size_t memory_max_seq_len, const size_t memory_hidden_units, - const int top_k, - const float top_p); + const int top_k, + const float top_p); int main(int argc, char** argv) { if (argc != 14) { printf("[ERROR] decoding_example batch_size beam_width head_num size_per_head inter_size vocab_size" - " num_layers max_seq_len memory_max_seq_len memory_hidden_units top_k top_p is_fp16\n"); + " num_layers max_seq_len memory_max_seq_len memory_hidden_units top_k top_p data_type\n"); printf("e.g., ./bin/decoding_example 4 1 8 64 2048 30000 6 32 32 512 0 0.6 1\n"); return 0; } - int batch_size = atoi(argv[1]); - int beam_width = atoi(argv[2]); - int head_num = atoi(argv[3]); - int size_per_head = atoi(argv[4]); - int inter_size = atoi(argv[5]); - int vocab_size = atoi(argv[6]); - int num_layers = atoi(argv[7]); - int max_seq_len = atoi(argv[8]); - int memory_max_seq_len = atoi(argv[9]); - int memory_hidden_units = atoi(argv[10]); - int top_k = atoi(argv[11]); - float top_p = atof(argv[12]); - - if (atoi(argv[13]) == 0) { + int batch_size = atoi(argv[1]); + int beam_width = atoi(argv[2]); + int head_num = atoi(argv[3]); + int size_per_head = atoi(argv[4]); + int inter_size = 
atoi(argv[5]); + int vocab_size = atoi(argv[6]); + int num_layers = atoi(argv[7]); + int max_seq_len = atoi(argv[8]); + int memory_max_seq_len = atoi(argv[9]); + int memory_hidden_units = atoi(argv[10]); + int top_k = atoi(argv[11]); + float top_p = atof(argv[12]); + const CublasDataType data_type = static_cast(atoi(argv[13])); // 0 FP32, 1 FP16, 2 BF 16 + + if (data_type == FLOAT_DATATYPE) { return decodingExample(batch_size, beam_width, head_num, @@ -71,7 +72,7 @@ int main(int argc, char** argv) top_k, top_p); } - else if (atoi(argv[13]) == 1) { + else if (data_type == HALF_DATATYPE) { return decodingExample(batch_size, beam_width, head_num, @@ -85,6 +86,22 @@ int main(int argc, char** argv) top_k, top_p); } +#ifdef ENABLE_BF16 + else if (data_type == BFLOAT16_DATATYPE) { + return decodingExample<__nv_bfloat16>(batch_size, + beam_width, + head_num, + size_per_head, + inter_size, + vocab_size, + num_layers, + max_seq_len, + memory_max_seq_len, + memory_hidden_units, + top_k, + top_p); + } +#endif else { throw std::runtime_error(std::string("[FT][ERROR] is_fp16 should be 0 (use float)" "or 1 (use half). \n ")); @@ -102,15 +119,15 @@ int decodingExample(const size_t batch_size, const size_t max_seq_len, const size_t memory_max_seq_len, const size_t memory_hidden_units, - const int top_k, - const float top_p) + const int top_k, + const float top_p) { const size_t hidden_units = head_num * size_per_head; - const int start_id = 50256; - const int end_id = 50256; + const int start_id = 0; + const int end_id = 1; - cudaStream_t stream; - cublasHandle_t cublas_handle; + cudaStream_t stream; + cublasHandle_t cublas_handle; cublasLtHandle_t cublaslt_handle; cudaStreamCreate(&stream); cublasCreate(&cublas_handle); @@ -120,12 +137,17 @@ int decodingExample(const size_t batch_size, Allocator allocator(getDevice()); - std::mutex* cublas_wrapper_mutex = new std::mutex(); + std::mutex* cublas_wrapper_mutex = new std::mutex(); cublasMMWrapper cublas_wrapper = cublasMMWrapper(cublas_handle, cublaslt_handle, stream, cublas_algo_map, cublas_wrapper_mutex, &allocator); if (std::is_same::value) { cublas_wrapper.setFP16GemmConfig(); } +#ifdef ENABLE_BF16 + else if (std::is_same::value) { + cublas_wrapper.setBF16GemmConfig(); + } +#endif else if (std::is_same::value) { cublas_wrapper.setFP32GemmConfig(); } @@ -157,7 +179,7 @@ int decodingExample(const size_t batch_size, false, &prop); - T* d_memory_tensor; + T* d_memory_tensor; int* d_memory_sequence_lengths; deviceMalloc(&d_memory_tensor, memory_hidden_units * memory_max_seq_len * batch_size * beam_width); deviceMalloc(&d_memory_sequence_lengths, batch_size * beam_width); diff --git a/examples/cpp/decoding/layernorm_test.cc b/examples/cpp/decoding/layernorm_test.cc index b788055ac..d9fcff5e4 100644 --- a/examples/cpp/decoding/layernorm_test.cc +++ b/examples/cpp/decoding/layernorm_test.cc @@ -37,13 +37,13 @@ int checkNonZero(T* A, int size) } template -void checkMat(TA* A, TB* B, int size, char* mark) +void checkMat(TA* A, TB* B, int size, const char* mark, float threshold) { float max_diff = -10000.0f; float max_diff_a, max_diff_b; - TA* matA = (TA*)malloc(sizeof(TA) * size); - TB* matB = (TB*)malloc(sizeof(TB) * size); - int not_passed = 0; + TA* matA = (TA*)malloc(sizeof(TA) * size); + TB* matB = (TB*)malloc(sizeof(TB) * size); + int not_passed = 0; cudaMemcpy(matA, A, sizeof(TA) * size, cudaMemcpyDeviceToHost); cudaMemcpy(matB, B, sizeof(TB) * size, cudaMemcpyDeviceToHost); float A_nonZero_ratio = float(checkNonZero(A, size)) / float(size); @@ -54,23 +54,24 
@@ void checkMat(TA* A, TB* B, int size, char* mark) for (int jjj = 0; jjj < size; jjj++) { float diff = fabs(float(matA[jjj]) - float(matB[jjj])); if (diff > max_diff) { - max_diff = diff; + max_diff = diff; max_diff_a = float(matA[jjj]); max_diff_b = float(matB[jjj]); } - if (fabs(float(matA[jjj]) - float(matB[jjj])) > 0.001) { + if (fabs(float(matA[jjj]) - float(matB[jjj])) > threshold) { not_passed += 1; if (not_passed < 1000) { printf("%d %f %f %f\n", jjj, float(matA[jjj]), float(matB[jjj]), float(matA[jjj]) - float(matB[jjj])); } } } - printf("[%s] max diff : %f ; a : %f ; b : %f\n", mark, max_diff, max_diff_a, max_diff_b); + FT_LOG_INFO("[%s] max diff : %f ; a : %f ; b : %f", mark, max_diff, max_diff_a, max_diff_b); if (not_passed != 0) { - printf("[%s] different elements : %d \n", mark, not_passed); + FT_LOG_ERROR("[%s] different elements : %d ", mark, not_passed); + FT_CHECK(false); } else { - printf("[%s] check pass!\n", mark); + FT_LOG_INFO("[%s] check pass!", mark); } free(matA); free(matB); @@ -85,23 +86,32 @@ void add_bias_residual_layernorm_test(const int m, const int n); int main(int argc, char** argv) { if (argc != 4) { - printf("[ERROR] layernorm_test max_m max_n is_fp16\n"); + printf("[ERROR] layernorm_test max_m max_n data_type\n"); printf("e.g., ./bin/layernorm_test 1 1024 1\n"); return 0; } - int max_m = atoi(argv[1]); - int max_n = atoi(argv[2]); - int is_fp16 = atoi(argv[3]); + int max_m = atoi(argv[1]); + int max_n = atoi(argv[2]); + const FtCudaDataType data_type = static_cast(atoi(argv[3])); // 0 FP32, 1 FP16, 2 BF 16 for (int m = 1; m <= max_m; m *= 2) { for (int n = 128; n <= max_n; n *= 2) { - if (is_fp16) { + if (data_type == FP16) { add_bias_residual_layernorm_test(m, n); } - else { +#ifdef ENABLE_BF16 + else if (data_type == BF16) { + add_bias_residual_layernorm_test<__nv_bfloat16>(m, n); + } +#endif + else if (data_type == FP32) { add_bias_residual_layernorm_test(m, n); } + else { + FT_LOG_ERROR("data_type should be fp32, fp16 or bf16!"); + exit(-1); + } } } return 0; @@ -114,7 +124,8 @@ void layernorm_test(const int m, const int n) check_cuda_error(cudaGetDeviceProperties(&prop, 0)); printf("Device %s\n", prop.name); - T *input, *output_opt, *output_baseline, *gamma, *beta; + const float layernorm_eps = 1e-4f; + T * input, *output_opt, *output_baseline, *gamma, *beta; deviceMalloc(&input, m * n); deviceMalloc(&output_baseline, m * n); deviceMalloc(&output_opt, m * n); @@ -126,15 +137,15 @@ void layernorm_test(const int m, const int n) // warmup for (int i = 0; i < 1000; i++) { - invokeGeneralLayerNorm(output_baseline, input, gamma, beta, m, n, stream); - invokeGeneralLayerNorm(output_opt, input, gamma, beta, m, n, stream, true); + invokeGeneralLayerNorm(output_baseline, input, gamma, beta, layernorm_eps, m, n, stream); + invokeGeneralLayerNorm(output_opt, input, gamma, beta, layernorm_eps, m, n, stream, true); } struct timeval start, end; cudaDeviceSynchronize(); gettimeofday(&start, NULL); for (int i = 0; i < ite; i++) { - invokeGeneralLayerNorm(output_baseline, input, gamma, beta, m, n, stream); + invokeGeneralLayerNorm(output_baseline, input, gamma, beta, layernorm_eps, m, n, stream); } cudaDeviceSynchronize(); gettimeofday(&end, NULL); @@ -144,7 +155,7 @@ void layernorm_test(const int m, const int n) cudaDeviceSynchronize(); gettimeofday(&start_2, NULL); for (int i = 0; i < ite; i++) { - invokeGeneralLayerNorm(output_opt, input, gamma, beta, m, n, stream, true); + invokeGeneralLayerNorm(output_opt, input, gamma, beta, layernorm_eps, m, n, stream, 
true); } cudaDeviceSynchronize(); gettimeofday(&end_2, NULL); @@ -166,9 +177,10 @@ void add_bias_residual_layernorm_test(const int m, const int n) check_cuda_error(cudaGetDeviceProperties(&prop, 0)); printf("Device %s\n", prop.name); - int opt_version = 2; - T *input, *output_opt, *output_baseline, *gamma, *beta, *bias; - T *normed_output_opt, *normed_output_baseline; + const float layernorm_eps = 1e-4f; + int opt_version = 2; + T * input, *output_opt, *output_baseline, *gamma, *beta, *bias; + T * normed_output_opt, *normed_output_baseline; deviceMalloc(&input, m * n); deviceMalloc(&output_baseline, m * n); deviceMalloc(&output_opt, m * n); @@ -181,22 +193,39 @@ void add_bias_residual_layernorm_test(const int m, const int n) cudaStream_t stream; cudaStreamCreate(&stream); const int warmup_ite = 1000; // 1000; - const int ite = 5000; // 5000; + const int ite = 5000; // 5000; // verify correctness invokeGeneralAddBiasResidualPreLayerNorm( - output_baseline, normed_output_baseline, input, gamma, beta, bias, m, n, stream, 0); + output_baseline, normed_output_baseline, input, gamma, beta, bias, layernorm_eps, m, n, stream, 0); invokeGeneralAddBiasResidualPreLayerNorm( - output_opt, normed_output_opt, input, gamma, beta, bias, m, n, stream, opt_version); - checkMat(output_baseline, output_opt, m * n, "output_baseline vs output_opt"); - checkMat(normed_output_baseline, normed_output_opt, m * n, "normed_output_baseline vs normed_output_opt"); + output_opt, normed_output_opt, input, gamma, beta, bias, layernorm_eps, m, n, stream, opt_version); + float threshold = 0.0f; + if (std::is_same::value) { + threshold = 1e-6f; + } + else if (std::is_same::value) { + threshold = 1e-3; + } +#ifdef ENABLE_BF16 + else if (std::is_same::value) { + threshold = 5e-2; + } +#endif + else { + FT_LOG_ERROR("data_type should be fp32, fp16 or bf16!"); + exit(-1); + } + checkMat(output_baseline, output_opt, m * n, "output_baseline vs output_opt", threshold); + checkMat( + normed_output_baseline, normed_output_opt, m * n, "normed_output_baseline vs normed_output_opt", threshold); // warmup for (int i = 0; i < warmup_ite; i++) { invokeGeneralAddBiasResidualPreLayerNorm( - output_baseline, normed_output_baseline, input, gamma, beta, bias, m, n, stream, 0); + output_baseline, normed_output_baseline, input, gamma, beta, bias, layernorm_eps, m, n, stream, 0); invokeGeneralAddBiasResidualPreLayerNorm( - output_opt, normed_output_opt, input, gamma, beta, bias, m, n, stream, opt_version); + output_opt, normed_output_opt, input, gamma, beta, bias, layernorm_eps, m, n, stream, opt_version); } struct timeval start, end; @@ -204,7 +233,7 @@ void add_bias_residual_layernorm_test(const int m, const int n) gettimeofday(&start, NULL); for (int i = 0; i < ite; i++) { invokeGeneralAddBiasResidualPreLayerNorm( - output_baseline, normed_output_baseline, input, gamma, beta, bias, m, n, stream, 0); + output_baseline, normed_output_baseline, input, gamma, beta, bias, layernorm_eps, m, n, stream, 0); } cudaDeviceSynchronize(); gettimeofday(&end, NULL); @@ -215,7 +244,7 @@ void add_bias_residual_layernorm_test(const int m, const int n) gettimeofday(&start_2, NULL); for (int i = 0; i < ite; i++) { invokeGeneralAddBiasResidualPreLayerNorm( - output_opt, normed_output_opt, input, gamma, beta, bias, m, n, stream, opt_version); + output_opt, normed_output_opt, input, gamma, beta, bias, layernorm_eps, m, n, stream, opt_version); } cudaDeviceSynchronize(); gettimeofday(&end_2, NULL); diff --git a/examples/cpp/gpt/gpt_config.ini 
b/examples/cpp/gpt/gpt_config.ini index f40ee0135..7a87be765 100644 --- a/examples/cpp/gpt/gpt_config.ini +++ b/examples/cpp/gpt/gpt_config.ini @@ -6,6 +6,8 @@ top_k=0 ; k value for top k sampling top_p=0.5 ; p value for top p sampling temperature=1.0 ; Use for sampling repetition_penalty=2.0 ; Use for sampling +len_penalty=0.0 +beam_search_diversity_rate=0.0 data_type=fp16 sparse=0 model_name=gpt_124M @@ -15,6 +17,7 @@ model_name=gpt_124M ; model_name=self_defined ; model_dir=./models/megatron-models/c-model/6.7b/ model_dir=models/openai-gpt-models/c-model/124m/1-gpu/ +shared_contexts_ratio=1.0 [request] request_batch_size=8 ; determine by the request diff --git a/examples/cpp/gpt/gpt_example.cc b/examples/cpp/gpt/gpt_example.cc index cacb09e90..94adb065d 100644 --- a/examples/cpp/gpt/gpt_example.cc +++ b/examples/cpp/gpt/gpt_example.cc @@ -74,26 +74,26 @@ int main(int argc, char* argv[]) return 0; } -int read_start_ids(int batch_size, +int read_start_ids(int batch_size, std::vector* v_start_lengths, std::vector* v_start_ids, - int& max_input_len, - const int end_id, - const int beam_width) + int& max_input_len, + const int end_id, + const int beam_width) { std::vector> tmp_start_ids; - std::vector tmp_start_lengths; + std::vector tmp_start_lengths; - std::string file_name = "../examples/cpp/gpt/start_ids.csv"; + std::string file_name = "../examples/cpp/gpt/start_ids.csv"; std::ifstream start_id_file(file_name, std::ios::in); if (start_id_file.is_open()) { std::string line; - int i0 = 0; + int i0 = 0; while (std::getline(start_id_file, line)) { std::stringstream lineStream(line); - std::string vals; - int i1 = 0; - std::vector tmp_vec; + std::string vals; + int i1 = 0; + std::vector tmp_vec; while (std::getline(lineStream, vals, ',')) { tmp_vec.push_back(std::stoi(vals)); i1++; @@ -145,26 +145,28 @@ int read_start_ids(int batch_size, template void gpt_example(const INIReader reader) { - const std::string model_name = reader.Get("ft_instance_hyperparameter", "model_name"); - const size_t max_batch_size = reader.GetInteger("ft_instance_hyperparameter", "max_batch_size"); - const size_t max_seq_len = reader.GetInteger("ft_instance_hyperparameter", "max_seq_len"); - const size_t beam_width = reader.GetInteger("ft_instance_hyperparameter", "beam_width"); - const int top_k = reader.GetInteger("ft_instance_hyperparameter", "top_k"); - const float top_p = reader.GetFloat("ft_instance_hyperparameter", "top_p"); - const float temperature = reader.GetFloat("ft_instance_hyperparameter", "temperature"); - const float repetition_penalty = reader.GetFloat("ft_instance_hyperparameter", "repetition_penalty"); - const std::string model_dir = std::string(reader.Get("ft_instance_hyperparameter", "model_dir")); - const bool sparse = static_cast(reader.GetInteger("ft_instance_hyperparameter", "sparse")); - const float len_penalty = 1.0f; - const float beam_search_diversity_rate = 0.0f; + const std::string model_name = reader.Get("ft_instance_hyperparameter", "model_name"); + const size_t max_batch_size = reader.GetInteger("ft_instance_hyperparameter", "max_batch_size"); + const size_t max_seq_len = reader.GetInteger("ft_instance_hyperparameter", "max_seq_len"); + const size_t beam_width = reader.GetInteger("ft_instance_hyperparameter", "beam_width"); + const uint top_k = (uint)reader.GetInteger("ft_instance_hyperparameter", "top_k"); + const float top_p = reader.GetFloat("ft_instance_hyperparameter", "top_p"); + const float temperature = reader.GetFloat("ft_instance_hyperparameter", "temperature"); + const 
float repetition_penalty = reader.GetFloat("ft_instance_hyperparameter", "repetition_penalty"); + const std::string model_dir = std::string(reader.Get("ft_instance_hyperparameter", "model_dir")); + const bool sparse = static_cast(reader.GetInteger("ft_instance_hyperparameter", "sparse")); + const float shared_contexts_ratio = reader.GetFloat("ft_instance_hyperparameter", "shared_contexts_ratio", 1.0f); + const float len_penalty = reader.GetFloat("ft_instance_hyperparameter", "len_penalty"); + const float beam_search_diversity_rate = + reader.GetFloat("ft_instance_hyperparameter", "beam_search_diversity_rate"); const unsigned long long int random_seed = 0; - const size_t head_num = reader.GetInteger(model_name, "head_num"); - const size_t size_per_head = reader.GetInteger(model_name, "size_per_head"); - const size_t vocab_size = reader.GetInteger(model_name, "vocab_size"); + const size_t head_num = reader.GetInteger(model_name, "head_num"); + const size_t size_per_head = reader.GetInteger(model_name, "size_per_head"); + const size_t vocab_size = reader.GetInteger(model_name, "vocab_size"); const size_t decoder_layers = reader.GetInteger(model_name, "decoder_layers"); - const size_t hidden_units = head_num * size_per_head; - const size_t inter_size = 4 * hidden_units; + const size_t hidden_units = head_num * size_per_head; + const size_t inter_size = 4 * hidden_units; const size_t request_batch_size = reader.GetInteger("request", "request_batch_size"); // The length of tokens we hope this model to generate @@ -178,12 +180,12 @@ void gpt_example(const INIReader reader) } const int start_id = 50256; - const int end_id = 50256; + const int end_id = 50256; const int rank = 0; // Read ids of request from file. - int max_input_len = -1; + int max_input_len = -1; std::vector v_start_lengths; std::vector v_start_ids; read_start_ids(request_batch_size, &v_start_lengths, &v_start_ids, max_input_len, end_id, 1); @@ -192,7 +194,7 @@ void gpt_example(const INIReader reader) int* d_input_lengths; if (max_input_len == 0) { // unconditional case, no input ids, so do nothing. 
- d_input_ids = nullptr; + d_input_ids = nullptr; d_input_lengths = nullptr; } else { @@ -209,8 +211,8 @@ void gpt_example(const INIReader reader) exit(-1); } - cudaStream_t stream; - cublasHandle_t cublas_handle; + cudaStream_t stream; + cublasHandle_t cublas_handle; cublasLtHandle_t cublaslt_handle; cudaStreamCreate(&stream); cublasCreate(&cublas_handle); @@ -221,7 +223,7 @@ void gpt_example(const INIReader reader) CHECK_CUSPARSE(cusparseLtInit(&cusparselt_handle)); cublasAlgoMap* cublas_algo_map = new cublasAlgoMap(GEMM_CONFIG, SPGEMM_CONFIG); #else - cublasAlgoMap* cublas_algo_map = new cublasAlgoMap(GEMM_CONFIG); + cublasAlgoMap* cublas_algo_map = new cublasAlgoMap(GEMM_CONFIG); #endif Allocator allocator(getDevice()); @@ -249,8 +251,7 @@ void gpt_example(const INIReader reader) struct cudaDeviceProp prop; check_cuda_error(cudaGetDeviceProperties(&prop, 0)); - fastertransformer::ParallelGptWeight gpt_weights( - hidden_units, inter_size, vocab_size, decoder_layers, max_seq_len, 1, 0, 1, 0, 0); + ParallelGptWeight gpt_weights(hidden_units, inter_size, vocab_size, decoder_layers, max_seq_len, 1, 0, 1, 0, 0); gpt_weights.loadModel(model_dir); @@ -276,12 +277,15 @@ void gpt_example(const INIReader reader) vocab_size, start_id, end_id, + end_id + 1, // p_prompt_tuning token start id + PromptLearningType::no_prompt, + gptVariantParams{}, 0.0f, // beam_search_diversity_rate, 0, // top_k, 0.0, // top_p, 0, // random_seed, 1.0f, // temperature, - 1.0f, // len_penalty, + 0.0f, // len_penalty, 1.0f, // repetition_penalty, tensor_para, pipeline_para, @@ -291,20 +295,24 @@ void gpt_example(const INIReader reader) false, &prop, sparse, - 0); + 0, + nullptr, + 0, + true, + shared_contexts_ratio); int* d_output_ids; - int* d_parent_ids; int* d_sequence_lengths; deviceMalloc(&d_output_ids, request_batch_size * beam_width * total_output_len, false); - deviceMalloc(&d_parent_ids, request_batch_size * beam_width * total_output_len, false); deviceMalloc(&d_sequence_lengths, request_batch_size * beam_width, false); + std::vector output_seq_len(request_batch_size, total_output_len); std::unordered_map input_tensors = std::unordered_map{ {"input_ids", Tensor{MEMORY_GPU, TYPE_INT32, std::vector{request_batch_size, (size_t)max_input_len}, d_input_ids}}, {"input_lengths", Tensor{MEMORY_GPU, TYPE_INT32, std::vector{request_batch_size}, d_input_lengths}}, - {"max_output_seq_len", Tensor{MEMORY_CPU, TYPE_INT32, std::vector{1}, &total_output_len}}, + {"output_seq_len", + Tensor{MEMORY_CPU, TYPE_UINT32, std::vector{request_batch_size}, output_seq_len.data()}}, {"temperature", Tensor{MEMORY_CPU, TYPE_FP32, std::vector{1}, &temperature}}, {"len_penalty", Tensor{MEMORY_CPU, TYPE_FP32, std::vector{1}, &len_penalty}}, {"repetition_penalty", Tensor{MEMORY_CPU, TYPE_FP32, std::vector{1}, &repetition_penalty}}}; @@ -319,7 +327,7 @@ void gpt_example(const INIReader reader) input_tensors.insert({"runtime_top_p", Tensor{MEMORY_CPU, TYPE_FP32, std::vector{1}, &top_p}}); } if (top_k != 0) { - input_tensors.insert({"runtime_top_k", Tensor{MEMORY_CPU, TYPE_INT32, std::vector{1}, &top_k}}); + input_tensors.insert({"runtime_top_k", Tensor{MEMORY_CPU, TYPE_UINT32, std::vector{1}, &top_k}}); } } @@ -329,16 +337,11 @@ void gpt_example(const INIReader reader) TYPE_INT32, std::vector{request_batch_size, beam_width, (size_t)total_output_len}, d_output_ids}}, - {"parent_ids", - Tensor{MEMORY_GPU, - TYPE_INT32, - std::vector{(size_t)total_output_len, request_batch_size, beam_width}, - d_parent_ids}}, {"sequence_length", Tensor{MEMORY_GPU, 
TYPE_INT32, std::vector{request_batch_size, beam_width}, d_sequence_lengths}}}; float* output_log_probs = nullptr; - float* d_cum_log_probs = nullptr; + float* d_cum_log_probs = nullptr; if (is_return_log_probs) { deviceMalloc(&output_log_probs, request_batch_size * beam_width * request_output_len); output_tensors.insert({"output_log_probs", @@ -400,14 +403,14 @@ void gpt_example(const INIReader reader) if (rank == 0) { - std::string fName = "out"; - auto outFile = std::ofstream(fName, std::ios::out); + std::string fName = "out"; + auto outFile = std::ofstream(fName, std::ios::out); if (!outFile.is_open()) { printf("[WARNING] Cannot write results into output file %s \n", fName.c_str()); } else { size_t outCount = total_output_len * request_batch_size * beam_width; - int* hBuf = new int[outCount]; + int* hBuf = new int[outCount]; cudaD2Hcpy(hBuf, d_output_ids, outCount); { @@ -436,8 +439,8 @@ void gpt_example(const INIReader reader) outFile.close(); if (d_cum_log_probs != nullptr) { - std::string logprob_fname = "logprob.out"; - std::ofstream logprob_file = std::ofstream("logprob.out", std::ios::out); + std::string logprob_fname = "logprob.out"; + std::ofstream logprob_file = std::ofstream("logprob.out", std::ios::out); if (!logprob_file.is_open()) { printf("[WARNING] Cannot write results into output file %s \n", logprob_fname.c_str()); } diff --git a/examples/cpp/gptj/CMakeLists.txt b/examples/cpp/gptj/CMakeLists.txt index 3c90b1b19..05e6425c8 100644 --- a/examples/cpp/gptj/CMakeLists.txt +++ b/examples/cpp/gptj/CMakeLists.txt @@ -14,8 +14,9 @@ add_executable(gptj_example gptj_example.cc) target_link_libraries(gptj_example PUBLIC -lcublas -lcublasLt -lcudart - GptJ nvtx_utils -lmpi gpt_example_utils word_list) + GptJ nvtx_utils gpt_example_utils word_list mpi_utils nccl_utils) add_executable(gptj_triton_example gptj_triton_example.cc) -target_link_libraries(gptj_triton_example PUBLIC -lcublas -lcublasLt -lcudart - GptJTritonBackend custom_ar_comm -lmpi gpt_example_utils word_list -lpthread) +target_link_libraries(gptj_triton_example PUBLIC -lcublas -lcublasLt -lcudart -lpthread + GptJTritonBackend TransformerTritonBackend custom_ar_comm + gpt_example_utils word_list mpi_utils nccl_utils) diff --git a/examples/cpp/gptj/gptj_config.ini b/examples/cpp/gptj/gptj_config.ini index 4957072e3..d62ec2e06 100644 --- a/examples/cpp/gptj/gptj_config.ini +++ b/examples/cpp/gptj/gptj_config.ini @@ -6,12 +6,12 @@ top_k=0 ; k value for top k sampling top_p=0.5 ; p value for top p sampling temperature=1.0 ; Use for sampling repetition_penalty=2.0 ; Use for sampling -len_penalty=1.0 +len_penalty=0.0 beam_search_diversity_rate=0.0 -is_half=0 +data_type=fp16 enable_custom_all_reduce=0 -tensor_para_size=8 +tensor_para_size=1 pipeline_para_size=1 model_name=gptj_6B @@ -30,3 +30,14 @@ rotary_embedding=64 start_id=50256 end_id=50256 inter_size=16384 +num_tasks=2 ;optional +prompt_learning_type=2 ;optional --> 0: no prompt, 1: soft_prompt, 2: prefix_prompt, 3: p/prompt_tuning + +;prompt learning example (soft prompt doesn't need it) +[gptj_6B_task_0] ; task_name_id = 0 +task_name=task_0 +prompt_length=5 +;optional +[gptj_6B_task_1] ; task_name_id = 1 +task_name=task_1 +prompt_length=10 diff --git a/examples/cpp/gptj/gptj_example.cc b/examples/cpp/gptj/gptj_example.cc index a3d400b6f..e2cd6d2d0 100644 --- a/examples/cpp/gptj/gptj_example.cc +++ b/examples/cpp/gptj/gptj_example.cc @@ -17,7 +17,7 @@ #include "3rdparty/INIReader.h" #include "examples/cpp/multi_gpu_gpt/gpt_example_utils.h" #include 
"src/fastertransformer/models/gptj/GptJ.h" -#include "src/fastertransformer/utils/mpi_utils.h" +#include "src/fastertransformer/utils/nccl_utils.h" #include "src/fastertransformer/utils/nvtx_utils.h" #include "src/fastertransformer/utils/word_list.h" @@ -39,7 +39,7 @@ void gptj_example(const INIReader reader); int main(int argc, char* argv[]) { - MPICHECK(MPI_Init(&argc, &argv)); + mpi::initialize(&argc, &argv); srand(0); std::string ini_name; @@ -55,65 +55,72 @@ int main(int argc, char* argv[]) std::cout << "[ERROR] Can't load '" << ini_name << "'\n"; return -1; } - const int is_half = reader.GetInteger("ft_instance_hyperparameter", "is_half"); + const std::string data_type = reader.Get("ft_instance_hyperparameter", "data_type"); - if (is_half == 0) { + if (data_type == "fp32") { gptj_example(reader); } - else if (is_half == 1) { + else if (data_type == "fp16") { gptj_example(reader); } +#ifdef ENABLE_BF16 + else if (data_type == "bf16") { + gptj_example<__nv_bfloat16>(reader); + } +#endif else { - printf("[ERROR] is_fp16 should be 0 (use float) or 1 (use half). \n"); + FT_LOG_ERROR("data_type should be fp32, fp16 or bf16!"); return -1; } - MPI_Finalize(); + mpi::finalize(); return 0; } template void gptj_example(const INIReader reader) { - const std::string model_name = reader.Get("ft_instance_hyperparameter", "model_name"); - const size_t max_seq_len = reader.GetInteger("ft_instance_hyperparameter", "max_seq_len"); - const size_t beam_width = reader.GetInteger("ft_instance_hyperparameter", "beam_width"); - const int top_k = reader.GetInteger("ft_instance_hyperparameter", "top_k"); - const float top_p = reader.GetFloat("ft_instance_hyperparameter", "top_p"); - const float temperature = reader.GetFloat("ft_instance_hyperparameter", "temperature"); - const float repetition_penalty = reader.GetFloat("ft_instance_hyperparameter", "repetition_penalty"); - const float len_penalty = reader.GetFloat("ft_instance_hyperparameter", "len_penalty"); - const float beam_search_diversity_rate = + print_mem_usage("Before loading model"); + const std::string model_name = reader.Get("ft_instance_hyperparameter", "model_name"); + const size_t max_seq_len = reader.GetInteger("ft_instance_hyperparameter", "max_seq_len"); + const size_t beam_width = reader.GetInteger("ft_instance_hyperparameter", "beam_width"); + const uint top_k = (uint)reader.GetInteger("ft_instance_hyperparameter", "top_k"); + const float top_p = reader.GetFloat("ft_instance_hyperparameter", "top_p"); + const float temperature = reader.GetFloat("ft_instance_hyperparameter", "temperature"); + const float repetition_penalty = reader.GetFloat("ft_instance_hyperparameter", "repetition_penalty"); + const float len_penalty = reader.GetFloat("ft_instance_hyperparameter", "len_penalty"); + const float beam_search_diversity_rate = reader.GetFloat("ft_instance_hyperparameter", "beam_search_diversity_rate"); std::string model_dir = std::string(reader.Get("ft_instance_hyperparameter", "model_dir")); - int tensor_para_size = reader.GetInteger("ft_instance_hyperparameter", "tensor_para_size"); + int tensor_para_size = reader.GetInteger("ft_instance_hyperparameter", "tensor_para_size"); int pipeline_para_size = reader.GetInteger("ft_instance_hyperparameter", "pipeline_para_size"); - const size_t head_num = reader.GetInteger(model_name, "head_num"); - const size_t size_per_head = reader.GetInteger(model_name, "size_per_head"); - const size_t vocab_size = reader.GetInteger(model_name, "vocab_size"); - const size_t decoder_layers = 
reader.GetInteger(model_name, "decoder_layers"); + const size_t head_num = reader.GetInteger(model_name, "head_num"); + const size_t size_per_head = reader.GetInteger(model_name, "size_per_head"); + const size_t vocab_size = reader.GetInteger(model_name, "vocab_size"); + const size_t decoder_layers = reader.GetInteger(model_name, "decoder_layers"); const size_t rotary_embedding_dim = reader.GetInteger(model_name, "rotary_embedding"); - const int start_id = reader.GetInteger(model_name, "start_id"); - const int end_id = reader.GetInteger(model_name, "end_id"); + const int start_id = reader.GetInteger(model_name, "start_id"); + const int end_id = reader.GetInteger(model_name, "end_id"); const size_t hidden_units = head_num * size_per_head; - const size_t inter_size = 4 * hidden_units; + const size_t inter_size = 4 * hidden_units; const size_t request_batch_size = reader.GetInteger("request", "request_batch_size"); // The length of tokens we hope this model to generate - const int request_output_len = reader.GetInteger("request", "request_output_len"); + const int request_output_len = reader.GetInteger("request", "request_output_len"); + const uint32_t memory_len = reader.GetInteger("request", "memory_len", 0); FT_CHECK(head_num % tensor_para_size == 0); FT_CHECK(decoder_layers % pipeline_para_size == 0); // Prepare the parallelism parameters - int rank, world_size, device, device_count; - MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &rank)); - MPICHECK(MPI_Comm_size(MPI_COMM_WORLD, &world_size)); + int rank = mpi::getCommWorldRank(); + int world_size = mpi::getCommWorldSize(); if (rank == 0) { printf("Total ranks: %d.\n", world_size); } + int device, device_count; check_cuda_error(cudaGetDeviceCount(&device_count)); check_cuda_error(cudaSetDevice(rank % device_count)); check_cuda_error(cudaGetDevice(&device)); @@ -122,7 +129,7 @@ void gptj_example(const INIReader reader) check_cuda_error(cudaGetDeviceProperties(&prop, device)); printf("Device %s\n", prop.name); - printf("P%d is runing with %d GPU.\n", rank, device); + printf("P%d is running with %d GPU.\n", rank, device); if (tensor_para_size * pipeline_para_size != world_size) { if (world_size % pipeline_para_size) { @@ -133,8 +140,6 @@ void gptj_example(const INIReader reader) printf("[INFO] Setting tensor_para_size to %d \n", tensor_para_size); } - const int tensor_para_rank = rank % tensor_para_size; - const int pipeline_para_rank = rank / tensor_para_size; const int layers_per_group = decoder_layers / pipeline_para_size; if (layers_per_group * pipeline_para_size != (int)decoder_layers) { printf("[ERROR] layers_per_group (%d) * pipeline_para_size (%d) should equal to decoder_layers (%ld) \n", @@ -147,46 +152,9 @@ void gptj_example(const INIReader reader) // assume gpu_num = k * n, // tensor parallelism group size is n // pipeline parallelism group size is k - - // convert WORLD communicator into 2D grid (k * n) communicator - // comms of the same row means they are in the same tensor parallel group - // comms of the same col means they are in the same pipeline parallel group - MPI_Comm grid_comm; - int dims[2] = {pipeline_para_size, tensor_para_size}; - int periods[2] = {0, 0}; - MPI_Cart_create(MPI_COMM_WORLD, 2, dims, periods, 0, &grid_comm); - - MPI_Comm comm_tensor_parallel, comm_pipeline_parallel; - - int remain_dims_tensor_parallel[2] = {false, true}; - int remain_dims_pipeline_parallel[2] = {true, false}; - // split 2D communicator into rows and cols, each row = one tensor parallel group, each col = one pipeline parallel - // group - 
MPI_Cart_sub(grid_comm, remain_dims_tensor_parallel, &comm_tensor_parallel); - MPI_Cart_sub(grid_comm, remain_dims_pipeline_parallel, &comm_pipeline_parallel); - - int rank_tensor_parallel, rank_pipeline_parallel; - MPI_Comm_rank(comm_tensor_parallel, &rank_tensor_parallel); - MPI_Comm_rank(comm_pipeline_parallel, &rank_pipeline_parallel); - - ncclUniqueId tensor_para_nccl_uid; - ncclUniqueId pipeline_para_nccl_uid; - // root of tensor parallel group and pipeline parallel group creates the nccl uid - if (rank_tensor_parallel == 0) { - NCCLCHECK(ncclGetUniqueId(&tensor_para_nccl_uid)); - } - - if (rank_pipeline_parallel == 0) { - NCCLCHECK(ncclGetUniqueId(&pipeline_para_nccl_uid)); - } - // broadcast nccl uid to the comms in the same tensor parallel group or pipeline parallel group - MPI_Bcast(&tensor_para_nccl_uid, sizeof(tensor_para_nccl_uid), MPI_BYTE, 0, comm_tensor_parallel); - MPI_Bcast(&pipeline_para_nccl_uid, sizeof(pipeline_para_nccl_uid), MPI_BYTE, 0, comm_pipeline_parallel); - - ncclComm_t tensor_para_nccl_comm, pipeline_para_nccl_comm; - NCCLCHECK(ncclCommInitRank(&tensor_para_nccl_comm, tensor_para_size, tensor_para_nccl_uid, tensor_para_rank)); - NCCLCHECK( - ncclCommInitRank(&pipeline_para_nccl_comm, pipeline_para_size, pipeline_para_nccl_uid, pipeline_para_rank)); + NcclParam tensor_para; + NcclParam pipeline_para; + ftNcclInitialize(tensor_para, pipeline_para, tensor_para_size, pipeline_para_size); // Handle bad_words dictionary std::vector bad_words; @@ -212,7 +180,7 @@ void gptj_example(const INIReader reader) cudaH2Dcpy(d_stop_words, tiled_stop_words.data(), tiled_stop_words.size()); // Read ids of request from file. - int max_input_len = -1; + int max_input_len = -1; std::vector v_start_lengths; std::vector v_start_ids; read_start_ids(request_batch_size, @@ -227,7 +195,7 @@ void gptj_example(const INIReader reader) int* d_input_lengths; if (max_input_len == 0) { // unconditional case, no input ids, so do nothing. - d_input_ids = nullptr; + d_input_ids = nullptr; d_input_lengths = nullptr; } else { @@ -237,17 +205,49 @@ void gptj_example(const INIReader reader) cudaH2Dcpy(d_input_ids, v_start_ids.data(), request_batch_size * max_input_len); cudaH2Dcpy(d_input_lengths, v_start_lengths.data(), request_batch_size); } + std::vector start_ids(request_batch_size, start_id); std::vector end_ids(request_batch_size, end_id); + // Prompt Learning Configurations + // NOTE: if you don't need prefix prompts, remember to set max_prefix_len to 0 and others to nullptr + int prompt_learning_start_id = reader.GetInteger(model_name, "prompt_learning_start_id", end_id + 1); + fastertransformer::PromptLearningType prompt_learning_type = + static_cast(reader.GetInteger(model_name, "prompt_learning_type", 0)); + + // NOTE: specify task names, take name id, prompt length in order to load those prompt learning tables. 
+ // NOTE: Please make sure task ids are continuous and start from 0 + // for example: + // std::map> prefix_prompt_table_pair{{"no_prompt", {0, 0}}, + // {"prompt_1", {1, 1}}, + // {"prompt_2", {2, 2}}, + // {"prompt_3", {3, 3}}, + // {"prompt_4", {4, 4}}, + // {"prompt_5", {5, 5}}}; + + std::map> prefix_prompt_table_pair; + + // NOTE: get prompt table pairs from configuration files + const int num_tasks = reader.GetInteger(model_name, "num_tasks", 0); + for (int task_name_id = 0; task_name_id < num_tasks; task_name_id++) { + std::string config_task_name = model_name + "_task_" + std::to_string(task_name_id); + std::string task_name = reader.Get(config_task_name, "task_name"); + const int prompt_length = reader.GetInteger(config_task_name, "prompt_length", 0); + prefix_prompt_table_pair.insert({task_name, {task_name_id, prompt_length}}); + } + + // NOTE: task_name_ids for each sequence in one batch + // Each sequence can have different prompt learning task ids + std::vector prefix_prompt_task_ids{}; + const int total_output_len = max_input_len + request_output_len; if (total_output_len > (int)max_seq_len) { printf("[ERROR] total_output_len (%d) should be <= max_seq_len (%ld). \n", total_output_len, max_seq_len); exit(-1); } - cudaStream_t stream; - cublasHandle_t cublas_handle; + cudaStream_t stream; + cublasHandle_t cublas_handle; cublasLtHandle_t cublaslt_handle; cudaStreamCreate(&stream); cublasCreate(&cublas_handle); @@ -257,39 +257,44 @@ void gptj_example(const INIReader reader) Allocator allocator(getDevice()); - std::mutex* cublas_wrapper_mutex = new std::mutex(); + std::mutex* cublas_wrapper_mutex = new std::mutex(); cublasMMWrapper cublas_wrapper = cublasMMWrapper(cublas_handle, cublaslt_handle, stream, cublas_algo_map, cublas_wrapper_mutex, &allocator); if (std::is_same::value) { cublas_wrapper.setGemmConfig(CUDA_R_16F, CUDA_R_16F, CUDA_R_16F, CUDA_R_32F); } +#ifdef ENABLE_BF16 + else if (std::is_same::value) { + cublas_wrapper.setBF16GemmConfig(); + } +#endif else if (std::is_same::value) { cublas_wrapper.setFP32GemmConfig(); } - fastertransformer::GptJWeight gpt_weights(hidden_units, - inter_size, - vocab_size, - decoder_layers, - max_seq_len, - tensor_para_size, - tensor_para_rank, - pipeline_para_size, - pipeline_para_rank); - - model_dir = model_dir + "/" + std::to_string(tensor_para_size) + "-gpu/"; + fastertransformer::GptJWeight gpt_weights( + hidden_units, + inter_size, + vocab_size, + decoder_layers, + max_seq_len, + tensor_para.world_size_, + tensor_para.rank_, + pipeline_para.world_size_, + pipeline_para.rank_, + prompt_learning_type, + prefix_prompt_table_pair); // optional if you don't need prefix prompts + + model_dir = model_dir + "/" + std::to_string(tensor_para.world_size_) + "-gpu/"; gpt_weights.loadModel(model_dir); unsigned long long random_seed; if (rank == 0) { random_seed = (unsigned long long)(0); } if (world_size > 1) { - MPICHECK(MPI_Bcast(&random_seed, 1, MPI_UNSIGNED_LONG_LONG, 0, MPI_COMM_WORLD)); + mpi::bcast(&random_seed, 1, mpi::MPI_TYPE_UNSIGNED_LONG_LONG, 0, mpi::COMM_WORLD); } - NcclParam tensor_para(tensor_para_rank, tensor_para_size, tensor_para_nccl_comm); - NcclParam pipeline_para(pipeline_para_rank, pipeline_para_size, pipeline_para_nccl_comm); - GptJ gpt = GptJ(0, // max_batch_size, FT will adjust the buffer automatically. 0, // max_seq_len, FT will adjust the buffer automatically. 0, // max_input_len, FT will adjust the buffer automatically. 
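The prompt-learning hunks above boil down to a small config convention: gptj_config.ini declares num_tasks and prompt_learning_type under the model section, plus one [gptj_6B_task_N] section per task with task_name and prompt_length, and gptj_example.cc folds them into a task_name -> {task_name_id, prompt_length} map. A standalone sketch of that parsing, assuming the INIReader bundled under 3rdparty/ and the same accessors the hunk uses:

    #include <iostream>
    #include <map>
    #include <string>
    #include <utility>

    #include "3rdparty/INIReader.h"

    int main()
    {
        INIReader reader("../examples/cpp/gptj/gptj_config.ini");
        if (reader.ParseError() < 0) {
            std::cout << "[ERROR] Can't load gptj_config.ini\n";
            return -1;
        }

        const std::string model_name = reader.Get("ft_instance_hyperparameter", "model_name");

        // task_name -> {task_name_id, prompt_length}; ids must be contiguous and start at 0.
        std::map<std::string, std::pair<int, int>> prefix_prompt_table_pair;

        const int num_tasks = reader.GetInteger(model_name, "num_tasks", 0);
        for (int task_name_id = 0; task_name_id < num_tasks; task_name_id++) {
            const std::string section       = model_name + "_task_" + std::to_string(task_name_id);
            const std::string task_name     = reader.Get(section, "task_name");
            const int         prompt_length = reader.GetInteger(section, "prompt_length", 0);
            prefix_prompt_table_pair.insert({task_name, {task_name_id, prompt_length}});
        }

        for (const auto& kv : prefix_prompt_table_pair) {
            std::cout << kv.first << ": id=" << kv.second.first
                      << " prompt_length=" << kv.second.second << "\n";
        }
        return 0;
    }

With the example gptj_config.ini above this would report task_0 (id 0, prompt length 5) and task_1 (id 1, prompt length 10); soft prompts skip these sections entirely, as the config comment notes.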
@@ -302,6 +307,8 @@ void gptj_example(const INIReader reader) rotary_embedding_dim, start_id, end_id, + prompt_learning_start_id, + prompt_learning_type, 0.0f, top_k, top_p, @@ -321,11 +328,16 @@ void gptj_example(const INIReader reader) int* d_sequence_lengths; deviceMalloc(&d_output_ids, request_batch_size * beam_width * total_output_len, false); deviceMalloc(&d_sequence_lengths, request_batch_size * beam_width, false); + std::vector output_seq_len(request_batch_size, total_output_len); std::unordered_map input_tensors = std::unordered_map{ {"input_ids", Tensor{MEMORY_GPU, TYPE_INT32, std::vector{request_batch_size, (size_t)max_input_len}, d_input_ids}}, {"input_lengths", Tensor{MEMORY_GPU, TYPE_INT32, std::vector{request_batch_size}, d_input_lengths}}, - {"max_output_seq_len", Tensor{MEMORY_CPU, TYPE_INT32, std::vector{1}, &total_output_len}}, + // NOTE: if you need prefix prompts, remember to add prefix_prompt_task_ids here + // {"prompt_learning_task_name_ids", Tensor{MEMORY_CPU, TYPE_INT32, std::vector{request_batch_size}, + // prefix_prompt_task_ids.data()}}, + {"output_seq_len", + Tensor{MEMORY_CPU, TYPE_UINT32, std::vector{request_batch_size}, output_seq_len.data()}}, {"bad_words_list", Tensor{MEMORY_GPU, TYPE_INT32, {2, bad_words.size() / 2}, d_bad_words}}, {"stop_words_list", Tensor{MEMORY_GPU, TYPE_INT32, {request_batch_size, 2, stop_words_len}, d_stop_words}}, {"temperature", Tensor{MEMORY_CPU, TYPE_FP32, std::vector{1}, &temperature}}, @@ -344,9 +356,12 @@ void gptj_example(const INIReader reader) input_tensors.insert({"runtime_top_p", Tensor{MEMORY_CPU, TYPE_FP32, std::vector{1}, &top_p}}); } if (top_k != 0) { - input_tensors.insert({"runtime_top_k", Tensor{MEMORY_CPU, TYPE_INT32, std::vector{1}, &top_k}}); + input_tensors.insert({"runtime_top_k", Tensor{MEMORY_CPU, TYPE_UINT32, std::vector{1}, &top_k}}); } } + if (memory_len > 0) { + input_tensors.insert({"memory_len", {MEMORY_CPU, TYPE_UINT32, {1}, &memory_len}}); + } std::unordered_map output_tensors = std::unordered_map{ {"output_ids", @@ -362,11 +377,11 @@ void gptj_example(const INIReader reader) std::vector{(size_t)request_output_len, request_batch_size, beam_width}, nullptr}}}; - print_mem_usage(); + print_mem_usage("After loading model"); int ite = 1; cudaDeviceSynchronize(); - MPI_Barrier(MPI_COMM_WORLD); + mpi::barrier(); cudaProfilerStart(); // warm up @@ -376,22 +391,23 @@ void gptj_example(const INIReader reader) for (int i = 0; i < ite; ++i) { gpt.forward(&output_tensors, &input_tensors, &gpt_weights); } + print_mem_usage("After forward"); cudaDeviceSynchronize(); - MPI_Barrier(MPI_COMM_WORLD); + mpi::barrier(); POP_RANGE; nvtx::resetScope(); if (rank == 0) { - std::string fName = "out"; - auto outFile = std::ofstream(fName, std::ios::out); + std::string fName = "out"; + auto outFile = std::ofstream(fName, std::ios::out); if (!outFile.is_open()) { printf("[WARNING] Cannot write results into output file %s \n", fName.c_str()); } else { size_t outCount = total_output_len * request_batch_size * beam_width; - int* hBuf = new int[outCount]; + int* hBuf = new int[outCount]; cudaD2Hcpy(hBuf, d_output_ids, outCount); { @@ -421,7 +437,7 @@ void gptj_example(const INIReader reader) // test time struct timeval start, end; - MPI_Barrier(MPI_COMM_WORLD); + mpi::barrier(); cudaDeviceSynchronize(); gettimeofday(&start, NULL); @@ -432,7 +448,7 @@ void gptj_example(const INIReader reader) } cudaDeviceSynchronize(); - MPI_Barrier(MPI_COMM_WORLD); + mpi::barrier(); POP_RANGE; nvtx::resetScope(); @@ -451,8 +467,8 @@ void 
gptj_example(const INIReader reader) vocab_size, ((end.tv_sec - start.tv_sec) * 1000 + (end.tv_usec - start.tv_usec) * 0.001) / ite); - ncclCommDestroy(tensor_para_nccl_comm); - ncclCommDestroy(pipeline_para_nccl_comm); + ftNcclParamDestroy(tensor_para); + ftNcclParamDestroy(pipeline_para); delete cublas_algo_map; delete cublas_wrapper_mutex; @@ -465,6 +481,12 @@ void gptj_example(const INIReader reader) if (d_input_lengths != nullptr) { cudaFree(d_input_lengths); } + if (d_output_ids != nullptr) { + cudaFree(d_output_ids); + } + if (d_sequence_lengths != nullptr) { + cudaFree(d_sequence_lengths); + } return; } diff --git a/examples/cpp/gptj/gptj_triton_example.cc b/examples/cpp/gptj/gptj_triton_example.cc index 0d55ad6e3..b9e283de7 100644 --- a/examples/cpp/gptj/gptj_triton_example.cc +++ b/examples/cpp/gptj/gptj_triton_example.cc @@ -20,6 +20,7 @@ #include "src/fastertransformer/triton_backend/gptj/GptJTritonModelInstance.h" #include "src/fastertransformer/utils/custom_ar_comm.h" #include "src/fastertransformer/utils/mpi_utils.h" +#include "src/fastertransformer/utils/nccl_utils.h" #include "src/fastertransformer/utils/word_list.h" #include @@ -28,35 +29,35 @@ namespace ft = fastertransformer; struct RequestParam { - int beam_width; - int request_output_len; - float beam_search_diversity_rate; - int runtime_top_k; - float runtime_top_p; - float temperature; - float len_penalty; - float repetition_penalty; + int beam_width; + int request_output_len; + float beam_search_diversity_rate; + uint runtime_top_k; + float runtime_top_p; + float temperature; + float len_penalty; + float repetition_penalty; unsigned long long int random_seed; - int start_id; - int end_id; + int start_id; + int end_id; }; std::vector>> broadCastRequest(const std::vector& v_start_ids, const std::vector& v_start_lengths, const std::vector& v_bad_words, - const int node_id, - const int gpu_count, - const RequestParam param, - std::vector* pointer_record) + const int node_id, + const int gpu_count, + const RequestParam param, + std::vector* pointer_record) { // broadcast the request to all nodes, and copy "gpu_count" copies on different gpu - int size_1 = v_start_ids.size(); - int size_2 = v_start_lengths.size(); + int size_1 = v_start_ids.size(); + int size_2 = v_start_lengths.size(); int size_bad_words = v_bad_words.size(); - MPICHECK(MPI_Bcast(&size_1, 1, MPI_INT, 0, MPI_COMM_WORLD)); - MPICHECK(MPI_Bcast(&size_2, 1, MPI_INT, 0, MPI_COMM_WORLD)); - MPICHECK(MPI_Bcast(&size_bad_words, 1, MPI_INT, 0, MPI_COMM_WORLD)); + ft::mpi::bcast(&size_1, 1, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD); + ft::mpi::bcast(&size_2, 1, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD); + ft::mpi::bcast(&size_bad_words, 1, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD); std::vector v_input_ids(size_1); std::vector v_input_lengths(size_2); @@ -67,14 +68,14 @@ broadCastRequest(const std::vector& v_start_ids, memcpy(v_input_lengths.data(), v_start_lengths.data(), size_2 * sizeof(int)); memcpy(v_input_bad_words.data(), v_bad_words.data(), size_bad_words * sizeof(int)); } - MPI_Barrier(MPI_COMM_WORLD); + ft::mpi::barrier(); int request_batch_size = size_2; - int max_input_len = size_1 / size_2; + int max_input_len = size_1 / size_2; - MPICHECK(MPI_Bcast(v_input_ids.data(), size_1, MPI_INT, 0, MPI_COMM_WORLD)); - MPICHECK(MPI_Bcast(v_input_lengths.data(), size_2, MPI_INT, 0, MPI_COMM_WORLD)); - MPICHECK(MPI_Bcast(v_input_bad_words.data(), size_bad_words, MPI_INT, 0, MPI_COMM_WORLD)); + ft::mpi::bcast(v_input_ids.data(), size_1, 
ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD); + ft::mpi::bcast(v_input_lengths.data(), size_2, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD); + ft::mpi::bcast(v_input_bad_words.data(), size_bad_words, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD); std::vector>> request_list; for (int device_id = 0; device_id < gpu_count; device_id++) { @@ -86,9 +87,9 @@ broadCastRequest(const std::vector& v_start_ids, if (max_input_len == 0) { // unconditional case, no input ids, so do nothing. - d_input_ids = nullptr; + d_input_ids = nullptr; d_input_lengths = nullptr; - max_input_len = 0; + max_input_len = 0; } else { // conditional case. @@ -100,13 +101,16 @@ broadCastRequest(const std::vector& v_start_ids, ft::deviceMalloc(&d_input_bad_words, size_bad_words, false); ft::cudaH2Dcpy(d_input_bad_words, v_input_bad_words.data(), size_bad_words); - int* request_output_len_ptr = new int((int)(param.request_output_len)); + uint32_t* request_output_len_ptr = (uint32_t*)malloc(request_batch_size * sizeof(uint32_t)); + for (int i = 0; i < request_batch_size; i++) { + request_output_len_ptr[i] = param.request_output_len; + } int* start_ids_ptr = (int*)malloc(request_batch_size * sizeof(int)); - int* end_ids_ptr = (int*)malloc(request_batch_size * sizeof(int)); + int* end_ids_ptr = (int*)malloc(request_batch_size * sizeof(int)); for (int i = 0; i < request_batch_size; i++) { start_ids_ptr[i] = param.start_id; - end_ids_ptr[i] = param.end_id; + end_ids_ptr[i] = param.end_id; } pointer_record->push_back(start_ids_ptr); pointer_record->push_back(end_ids_ptr); @@ -123,9 +127,14 @@ broadCastRequest(const std::vector& v_start_ids, triton::TYPE_INT32, std::vector{(size_t)request_batch_size}, d_input_lengths}}, + // NOTE: add prefix prompt task ids here if you need + // {"prefix_prompt_task_ids", triton::Tensor{triton::MEMORY_CPU, triton::TYPE_INT32, + // std::vector{request_batch_size}, task_name_ids}}, {"request_output_len", - triton::Tensor{ - triton::MEMORY_CPU, triton::TYPE_INT32, std::vector{(size_t)1}, request_output_len_ptr}}, + triton::Tensor{triton::MEMORY_CPU, + triton::TYPE_UINT32, + std::vector{(size_t)request_batch_size}, + request_output_len_ptr}}, {"bad_words_list", triton::Tensor{ triton::MEMORY_GPU, triton::TYPE_INT32, {2, v_input_bad_words.size() / 2}, d_input_bad_words}}, @@ -156,12 +165,12 @@ broadCastRequest(const std::vector& v_start_ids, triton::Tensor{triton::MEMORY_CPU, triton::TYPE_FP32, std::vector{1}, runtime_top_p_ptr}}); } if (param.runtime_top_k != 0) { - int* runtime_top_k_ptr = new int(param.runtime_top_k); + uint* runtime_top_k_ptr = new uint(param.runtime_top_k); pointer_record->push_back(runtime_top_k_ptr); request_list[device_id]->insert( {"runtime_top_k", triton::Tensor{ - triton::MEMORY_CPU, triton::TYPE_INT32, std::vector{1}, runtime_top_k_ptr}}); + triton::MEMORY_CPU, triton::TYPE_UINT32, std::vector{1}, runtime_top_k_ptr}}); } } float* temperature_ptr = new float(param.temperature); @@ -203,10 +212,10 @@ prepareRequest(std::string ini_name, const int node_id, const int gpu_count, std ft::FT_CHECK(false); } - const size_t request_batch_size = reader.GetInteger("request", "request_batch_size"); - - const int start_id = reader.GetInteger("gptj_6B", "start_id"); - const int end_id = reader.GetInteger("gptj_6B", "end_id"); + const size_t request_batch_size = reader.GetInteger("request", "request_batch_size"); + const std::string model_name = reader.Get("ft_instance_hyperparameter", "model_name"); + const int start_id = reader.GetInteger(model_name, "start_id"); + const int end_id 
= reader.GetInteger(model_name, "end_id"); std::vector v_start_ids; std::vector v_start_lengths; @@ -224,45 +233,46 @@ prepareRequest(std::string ini_name, const int node_id, const int gpu_count, std ft::read_word_list("../examples/cpp/gptj/bad_words.csv", v_bad_words); RequestParam param; - param.beam_width = reader.GetInteger("ft_instance_hyperparameter", "beam_width"); - param.request_output_len = reader.GetInteger("request", "request_output_len"); + param.beam_width = reader.GetInteger("ft_instance_hyperparameter", "beam_width"); + param.request_output_len = reader.GetInteger("request", "request_output_len"); param.beam_search_diversity_rate = reader.GetFloat("ft_instance_hyperparameter", "beam_search_diversity_rate"); - param.runtime_top_k = reader.GetInteger("ft_instance_hyperparameter", "top_k"); - param.runtime_top_p = reader.GetFloat("ft_instance_hyperparameter", "top_p"); - param.temperature = reader.GetFloat("ft_instance_hyperparameter", "temperature"); - param.len_penalty = reader.GetFloat("ft_instance_hyperparameter", "len_penalty"); - param.repetition_penalty = reader.GetFloat("ft_instance_hyperparameter", "repetition_penalty"); - param.random_seed = (unsigned long long int)0; - param.start_id = start_id; - param.end_id = end_id; + param.runtime_top_k = (uint)reader.GetInteger("ft_instance_hyperparameter", "top_k"); + param.runtime_top_p = reader.GetFloat("ft_instance_hyperparameter", "top_p"); + param.temperature = reader.GetFloat("ft_instance_hyperparameter", "temperature"); + param.len_penalty = reader.GetFloat("ft_instance_hyperparameter", "len_penalty"); + param.repetition_penalty = reader.GetFloat("ft_instance_hyperparameter", "repetition_penalty"); + param.random_seed = (unsigned long long int)0; + param.start_id = start_id; + param.end_id = end_id; auto request_list = broadCastRequest(v_start_ids, v_start_lengths, v_bad_words, node_id, gpu_count, param, pointer_record); return request_list; } -int threadCreateModelInstances(std::shared_ptr model, - std::vector>* model_instances, - const int device_id, - const int rank, - std::pair, std::vector> nccl_comms, +int threadCreateModelInstances(std::shared_ptr model, + std::vector>* model_instances, + const int device_id, + const int rank, + std::pair, std::vector> nccl_params, std::shared_ptr custom_all_reduce_comm = nullptr) { printf("[INFO] rank = %d \n", rank); ft::check_cuda_error(cudaSetDevice(device_id)); cudaStream_t stream; ft::check_cuda_error(cudaStreamCreate(&stream)); - auto model_instance = model->createModelInstance(device_id, rank, stream, nccl_comms, custom_all_reduce_comm); + model->createSharedWeights(device_id, rank); + auto model_instance = model->createModelInstance(device_id, rank, stream, nccl_params, custom_all_reduce_comm); model_instances->at(device_id) = std::move(model_instance); printf("model instance %d is created \n", device_id); ft::print_mem_usage(); return 0; } -int threadForward(std::unique_ptr* model_instance, - std::shared_ptr> request, +int threadForward(std::unique_ptr* model_instance, + std::shared_ptr> request, std::shared_ptr>* output_tensors, - const int device_id) + const int device_id) { ft::check_cuda_error(cudaSetDevice(device_id)); *output_tensors = (*model_instance)->forward(request); @@ -276,38 +286,26 @@ int main(int argc, char* argv[]) by MPI or triton */ - MPICHECK(MPI_Init(&argc, &argv)); - int node_id; - int node_num; - MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &node_id)); - MPICHECK(MPI_Comm_size(MPI_COMM_WORLD, &node_num)); + ft::mpi::initialize(&argc, &argv); + int 
node_id = ft::mpi::getCommWorldRank(); + int node_num = ft::mpi::getCommWorldSize(); // Note: Only supports that all nodes have same gpu count - const int gpu_count = ft::getDeviceCount(); - const int world_size = node_num * gpu_count; - std::string ini_name = argc >= 2 ? std::string(argv[1]) : "../examples/cpp/gptj/gptj_config.ini"; + const int gpu_count = ft::getDeviceCount(); + const int world_size = node_num * gpu_count; + std::string ini_name = argc >= 2 ? std::string(argv[1]) : "../examples/cpp/gptj/gptj_config.ini"; // step 1: Create model - std::shared_ptr model = AbstractTransformerModel::createGptJModel(ini_name); + std::shared_ptr model = AbstractTransformerModel::createGptJModel(ini_name); + int tensor_para_size = model->getTensorParaSize(); + int pipeline_para_size = model->getPipelineParaSize(); + ft::FT_CHECK_WITH_INFO(world_size == (tensor_para_size * pipeline_para_size), + "World Size != Tensor Parallel Size * Pipeline Parallel Size !"); + std::cout << model->toString(); // step 2: Initialize the NCCL - std::vector nccl_ids; - if (node_id == 0) { - nccl_ids = model->createNcclIds(world_size); - } - int nccl_size = nccl_ids.size(); - MPI_Barrier(MPI_COMM_WORLD); - MPICHECK(MPI_Bcast(&nccl_size, 1, MPI_INT, 0, MPI_COMM_WORLD)); - if (node_id != 0) { - nccl_ids.resize(nccl_size); - } - MPI_Barrier(MPI_COMM_WORLD); - for (size_t i = 0; i < nccl_ids.size(); i++) { - MPICHECK(MPI_Bcast(&nccl_ids[i], sizeof(nccl_ids[i]), MPI_BYTE, 0, MPI_COMM_WORLD)); - } - MPI_Barrier(MPI_COMM_WORLD); - std::pair, std::vector> nccl_comms = model->createNcclComms(nccl_ids, node_id); + std::pair, std::vector> nccl_params = model->createNcclParams(node_id); cudaDeviceSynchronize(); // Optional Step: create custom all reduce comm @@ -316,7 +314,7 @@ int main(int argc, char* argv[]) // step 3: Create model instances std::vector> model_instances((size_t)gpu_count); - std::vector threads; + std::vector threads; for (int device_id = 0; device_id < gpu_count; device_id++) { const int rank = node_id * gpu_count + device_id; threads.push_back(std::thread(threadCreateModelInstances, @@ -324,7 +322,7 @@ int main(int argc, char* argv[]) &model_instances, device_id, rank, - nccl_comms, + nccl_params, custom_all_reduce_comms[rank])); } for (auto& t : threads) { @@ -356,20 +354,20 @@ int main(int argc, char* argv[]) printf("[INFO] forward is completed. 
\n"); const int* d_output_ids = (const int*)output_tensors_lists[0].get()->at("output_ids").data; - const int batch_size = output_tensors_lists[0].get()->at("output_ids").shape[0]; - const int beam_width = output_tensors_lists[0].get()->at("output_ids").shape[1]; - const int seq_len = output_tensors_lists[0].get()->at("output_ids").shape[2]; + const int batch_size = output_tensors_lists[0].get()->at("output_ids").shape[0]; + const int beam_width = output_tensors_lists[0].get()->at("output_ids").shape[1]; + const int seq_len = output_tensors_lists[0].get()->at("output_ids").shape[2]; // step 6: check results if (node_id == 0) { - std::string fName = "out"; - auto outFile = std::ofstream(fName, std::ios::out); + std::string fName = "out"; + auto outFile = std::ofstream(fName, std::ios::out); if (!outFile.is_open()) { printf("[WARNING] Cannot write results into output file %s \n", fName.c_str()); } else { size_t outCount = batch_size * beam_width * seq_len; - int* hBuf = new int[outCount]; + int* hBuf = new int[outCount]; ft::cudaD2Hcpy(hBuf, d_output_ids, outCount); { @@ -399,7 +397,7 @@ int main(int argc, char* argv[]) // test time struct timeval start, end; - MPI_Barrier(MPI_COMM_WORLD); + ft::mpi::barrier(); cudaDeviceSynchronize(); gettimeofday(&start, NULL); @@ -419,7 +417,7 @@ int main(int argc, char* argv[]) } cudaDeviceSynchronize(); - MPI_Barrier(MPI_COMM_WORLD); + ft::mpi::barrier(); gettimeofday(&end, NULL); @@ -430,6 +428,6 @@ int main(int argc, char* argv[]) seq_len, ((end.tv_sec - start.tv_sec) * 1000 + (end.tv_usec - start.tv_usec) * 0.001) / ite); - MPICHECK(MPI_Finalize()); + ft::mpi::finalize(); return 0; } diff --git a/examples/cpp/gptneox/CMakeLists.txt b/examples/cpp/gptneox/CMakeLists.txt new file mode 100644 index 000000000..7177f1c80 --- /dev/null +++ b/examples/cpp/gptneox/CMakeLists.txt @@ -0,0 +1,22 @@ +# Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +add_executable(gptneox_example gptneox_example.cc) +target_link_libraries(gptneox_example PUBLIC -lcublas -lcublasLt -lcudart + GptNeoX mpi_utils nccl_utils nvtx_utils gpt_example_utils word_list) + +add_executable(gptneox_triton_example gptneox_triton_example.cc) +target_link_libraries(gptneox_triton_example PUBLIC -lcublas -lcublasLt -lcudart + GptNeoXTritonBackend TransformerTritonBackend custom_ar_comm + mpi_utils nccl_utils gpt_example_utils word_list -lpthread) diff --git a/examples/cpp/gptneox/bad_words.csv b/examples/cpp/gptneox/bad_words.csv new file mode 100644 index 000000000..6a1126ebd --- /dev/null +++ b/examples/cpp/gptneox/bad_words.csv @@ -0,0 +1,2 @@ +7768,3908 +1,2 diff --git a/examples/cpp/gptneox/gptneox_config.ini b/examples/cpp/gptneox/gptneox_config.ini new file mode 100644 index 000000000..528b51dd2 --- /dev/null +++ b/examples/cpp/gptneox/gptneox_config.ini @@ -0,0 +1,31 @@ +[ft_instance_hyperparameter] +data_type=fp16 +enable_custom_all_reduce=0 + +tensor_para_size=2 +pipeline_para_size=1 + +model_name=gptneox_20B +model_dir=../models/gptneox + +[request] +beam_width=1 # beam width for beam search +top_k=1 ; k value for top k sampling +top_p=0.0 ; p value for top p sampling +temperature=1.0 ; Use for sampling +repetition_penalty=1.0 ; Use for sampling +len_penalty=0.0 +beam_search_diversity_rate=0.0 +request_batch_size=8 # determine by the request +request_output_len=32 # determine by the request + +[gptneox_20B] +head_num=64 +size_per_head=96 +vocab_size=50432 +decoder_layers=44 +rotary_embedding=24 +start_id=0 +end_id=2 +inter_size=24576 +use_gptj_residual=1 diff --git a/examples/cpp/gptneox/gptneox_example.cc b/examples/cpp/gptneox/gptneox_example.cc new file mode 100644 index 000000000..ecfcf0437 --- /dev/null +++ b/examples/cpp/gptneox/gptneox_example.cc @@ -0,0 +1,481 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "3rdparty/INIReader.h" +#include "examples/cpp/multi_gpu_gpt/gpt_example_utils.h" +#include "src/fastertransformer/models/gptneox/GptNeoX.h" +#include "src/fastertransformer/utils/mpi_utils.h" +#include "src/fastertransformer/utils/nccl_utils.h" +#include "src/fastertransformer/utils/nvtx_utils.h" +#include "src/fastertransformer/utils/word_list.h" + +#include +#include +#include +#include +#include +#include + +#ifdef USE_NVTX +bool NVTX_ON = true; +#endif + +using namespace fastertransformer; + +template +void gptneox_example(const INIReader reader); + +int main(int argc, char* argv[]) +{ + mpi::initialize(&argc, &argv); + srand(0); + + std::string ini_name; + if (argc == 2) { + ini_name = std::string(argv[1]); + } + else { + ini_name = "../examples/cpp/gptneox/gptneox_config.ini"; + } + + INIReader reader = INIReader(ini_name); + if (reader.ParseError() < 0) { + std::cout << "[ERROR] Can't load '" << ini_name << "'\n"; + return -1; + } + const std::string data_type = reader.Get("ft_instance_hyperparameter", "data_type"); + + if (data_type == "fp32") { + gptneox_example(reader); + } + else if (data_type == "fp16") { + gptneox_example(reader); + } + else { + FT_LOG_ERROR("is_fp16 should be 0 (use float) or 1 (use half)."); + return -1; + } + mpi::finalize(); + return 0; +} + +template +void gptneox_example(const INIReader reader) +{ + const std::string model_name = reader.Get("ft_instance_hyperparameter", "model_name"); + std::string model_dir = std::string(reader.Get("ft_instance_hyperparameter", "model_dir")); + + int tensor_para_size = reader.GetInteger("ft_instance_hyperparameter", "tensor_para_size"); + int pipeline_para_size = reader.GetInteger("ft_instance_hyperparameter", "pipeline_para_size"); + + const size_t head_num = reader.GetInteger(model_name, "head_num"); + const size_t size_per_head = reader.GetInteger(model_name, "size_per_head"); + const size_t vocab_size = reader.GetInteger(model_name, "vocab_size"); + const size_t decoder_layers = reader.GetInteger(model_name, "decoder_layers"); + const size_t rotary_embedding_dim = reader.GetInteger(model_name, "rotary_embedding"); + const int start_id = reader.GetInteger(model_name, "start_id"); + const int end_id = reader.GetInteger(model_name, "end_id"); + + const size_t hidden_units = head_num * size_per_head; + const size_t inter_size = 4 * hidden_units; + + const size_t beam_width = reader.GetInteger("request", "beam_width"); + const uint top_k = (uint)reader.GetInteger("request", "top_k"); + const float top_p = reader.GetFloat("request", "top_p"); + const float temperature = reader.GetFloat("request", "temperature"); + const float repetition_penalty = reader.GetFloat("request", "repetition_penalty"); + const float len_penalty = reader.GetFloat("request", "len_penalty"); + const float beam_search_diversity_rate = reader.GetFloat("request", "beam_search_diversity_rate"); + const size_t request_batch_size = reader.GetInteger("request", "request_batch_size"); + // The length of tokens we hope this model to generate + const int request_output_len = reader.GetInteger("request", "request_output_len"); + + FT_CHECK(head_num % tensor_para_size == 0); + FT_CHECK(decoder_layers % pipeline_para_size == 0); + + // Prepare the parallelism parameters + int rank = mpi::getCommWorldRank(); + int world_size = mpi::getCommWorldSize(); + if (rank == 0) { + printf("Total ranks: %d.\n", world_size); + } + int device, device_count; + check_cuda_error(cudaGetDeviceCount(&device_count)); + check_cuda_error(cudaSetDevice(rank % 
device_count)); + check_cuda_error(cudaGetDevice(&device)); + + struct cudaDeviceProp prop; + check_cuda_error(cudaGetDeviceProperties(&prop, device)); + printf("Device %s\n", prop.name); + + printf("P%d is running with GPU #%d.\n", rank, device); + if (tensor_para_size * pipeline_para_size != world_size) { + if (world_size % pipeline_para_size) { + printf("[ERROR] tensor_para_size * pipeline_para_size should equal to world_size \n"); + exit(-1); + } + tensor_para_size = world_size / pipeline_para_size; + printf("[INFO] Setting tensor_para_size to %d \n", tensor_para_size); + } + + const int layers_per_group = decoder_layers / pipeline_para_size; + if (layers_per_group * pipeline_para_size != (int)decoder_layers) { + printf("[ERROR] layers_per_group (%d) * pipeline_para_size (%d) should equal to decoder_layers (%ld) \n", + layers_per_group, + pipeline_para_size, + decoder_layers); + exit(-1); + } + + // assume gpu_num = k * n, + // tensor parallelism group size is n + // pipeline parallelism group size is k + NcclParam tensor_para; + NcclParam pipeline_para; + ftNcclInitialize(tensor_para, pipeline_para, tensor_para_size, pipeline_para_size); + + // Handle bad_words dictionary + std::vector bad_words; + read_word_list("../examples/cpp/gptneox/bad_words.csv", bad_words); + + int* d_bad_words = nullptr; + deviceMalloc(&d_bad_words, bad_words.size(), false); + cudaH2Dcpy(d_bad_words, bad_words.data(), bad_words.size()); + + // Handle stop_words dictionary + std::vector stop_words; + read_word_list("../examples/cpp/gptneox/stop_words.csv", stop_words); + + const size_t stop_words_len = stop_words.size() / 2; + // Tile with same dict for each element + std::vector tiled_stop_words; + for (int i = 0; i < request_batch_size; i++) { + tiled_stop_words.insert(tiled_stop_words.end(), stop_words.begin(), stop_words.end()); + } + + int* d_stop_words = nullptr; + deviceMalloc(&d_stop_words, tiled_stop_words.size(), false); + cudaH2Dcpy(d_stop_words, tiled_stop_words.data(), tiled_stop_words.size()); + + // Read ids of request from file. + int max_input_len = -1; + std::vector v_start_lengths; + std::vector v_start_ids; + read_start_ids(request_batch_size, + &v_start_lengths, + &v_start_ids, + max_input_len, + end_id, + 1, + "../examples/cpp/gptneox/start_ids.csv"); + + int* d_input_ids; + int* d_input_lengths; + if (max_input_len == 0) { + // unconditional case, no input ids, so do nothing. + d_input_ids = nullptr; + d_input_lengths = nullptr; + } + else { + // conditional case. + deviceMalloc(&d_input_ids, request_batch_size * max_input_len, false); + deviceMalloc(&d_input_lengths, request_batch_size, false); + cudaH2Dcpy(d_input_ids, v_start_ids.data(), request_batch_size * max_input_len); + cudaH2Dcpy(d_input_lengths, v_start_lengths.data(), request_batch_size); + } + std::vector start_ids(request_batch_size, start_id); + std::vector end_ids(request_batch_size, end_id); + + // Prompt Learning Configurations + // NOTE: if you don't need prefix prompts, remember to set max_prefix_len to 0 and others to nullptr + int prompt_learning_start_id = reader.GetInteger(model_name, "prompt_learning_start_id", end_id + 1); + fastertransformer::PromptLearningType prompt_learning_type = + static_cast(reader.GetInteger(model_name, "prompt_learning_type", 0)); + + // NOTE: specify task names, take name id, prompt length in order to load those prompt learning tables. 
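+    // NOTE: each entry of prefix_prompt_table_pair maps a task name to {task_name_id, prompt_length};
+    // the corresponding prompt weight tables are expected to be loaded together with the model weights.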
+ // NOTE: Please make sure task ids are continuous and start from 0 + // for example: + // std::map> prefix_prompt_table_pair{{"no_prompt", {0, 0}}, + // {"prompt_1", {1, 1}}, + // {"prompt_2", {2, 2}}, + // {"prompt_3", {3, 3}}, + // {"prompt_4", {4, 4}}, + // {"prompt_5", {5, 5}}}; + + std::map> prefix_prompt_table_pair; + + // NOTE: get prompt table pairs from configuration files + const int num_tasks = reader.GetInteger(model_name, "num_tasks", 0); + for (int task_name_id = 0; task_name_id < num_tasks; task_name_id++) { + std::string config_task_name = model_name + "_task_" + std::to_string(task_name_id); + std::string task_name = reader.Get(config_task_name, "task_name"); + const int prompt_length = reader.GetInteger(config_task_name, "prompt_length", 0); + prefix_prompt_table_pair.insert({task_name, {task_name_id, prompt_length}}); + } + + // NOTE: task_name_ids for each sequence in one batch + // Each sequence can have different prompt learning task ids + std::vector prefix_prompt_task_ids(request_batch_size, 0); + + // Set different task ids + for (int i = 0; i < request_batch_size; i++) { + prefix_prompt_task_ids[i] = (num_tasks > 0) ? i % num_tasks : 0; + } + + const int total_output_len = max_input_len + request_output_len; + + cudaStream_t stream; + cublasHandle_t cublas_handle; + cublasLtHandle_t cublaslt_handle; + cudaStreamCreate(&stream); + cublasCreate(&cublas_handle); + cublasLtCreate(&cublaslt_handle); + cublasSetStream(cublas_handle, stream); + cublasAlgoMap* cublas_algo_map = new cublasAlgoMap("gemm_config.in"); + + Allocator allocator(getDevice()); + + std::mutex* cublas_wrapper_mutex = new std::mutex(); + cublasMMWrapper cublas_wrapper = + cublasMMWrapper(cublas_handle, cublaslt_handle, stream, cublas_algo_map, cublas_wrapper_mutex, &allocator); + if (std::is_same::value) { + cublas_wrapper.setGemmConfig(CUDA_R_16F, CUDA_R_16F, CUDA_R_16F, CUDA_R_32F); + } + else if (std::is_same::value) { + cublas_wrapper.setFP32GemmConfig(); + } + + // GPT NeoX Residual Type + const bool use_gptj_residual = (bool)reader.GetInteger(model_name, "use_gptj_residual", 1); + fastertransformer::GptNeoXWeight gpt_weights(hidden_units, + inter_size, + vocab_size, + decoder_layers, + 0, // max_seq_len, deprecated + tensor_para.world_size_, + tensor_para.rank_, + pipeline_para.world_size_, + pipeline_para.rank_, + use_gptj_residual, + prompt_learning_type, + prefix_prompt_table_pair); + + model_dir = model_dir + "/" + std::to_string(tensor_para.world_size_) + "-gpu"; + gpt_weights.loadModel(model_dir); + unsigned long long random_seed; + if (rank == 0) { + random_seed = (unsigned long long)(0); + } + if (world_size > 1) { + mpi::bcast(&random_seed, 1, mpi::MPI_TYPE_UNSIGNED_LONG_LONG, 0, mpi::COMM_WORLD); + } + + GptNeoX gpt = GptNeoX(head_num, + size_per_head, + inter_size, + decoder_layers, + vocab_size, + rotary_embedding_dim, + start_id, + end_id, + prompt_learning_start_id, + prompt_learning_type, + use_gptj_residual, + 0.0f, + top_k, + top_p, + random_seed, + temperature, + len_penalty, + repetition_penalty, + tensor_para, + pipeline_para, + stream, + &cublas_wrapper, + &allocator, + false, + &prop); + + int* d_output_ids; + int* d_sequence_lengths; + deviceMalloc(&d_output_ids, request_batch_size * beam_width * total_output_len, false); + deviceMalloc(&d_sequence_lengths, request_batch_size * beam_width, false); + std::vector output_seq_len(request_batch_size, total_output_len); + std::unordered_map input_tensors = std::unordered_map{ + {"input_ids", + Tensor{MEMORY_GPU, TYPE_INT32, 
std::vector{request_batch_size, (size_t)max_input_len}, d_input_ids}}, + {"input_lengths", Tensor{MEMORY_GPU, TYPE_INT32, std::vector{request_batch_size}, d_input_lengths}}, + // NOTE: if you need prefix prompts, remember to add prefix_prompt_task_ids here + // {"prompt_learning_task_name_ids", Tensor{MEMORY_CPU, TYPE_INT32, std::vector{request_batch_size}, + // prefix_prompt_task_ids.data()}}, + {"output_seq_len", + Tensor{MEMORY_CPU, TYPE_UINT32, std::vector{request_batch_size}, output_seq_len.data()}}, + {"bad_words_list", Tensor{MEMORY_GPU, TYPE_INT32, {2, bad_words.size() / 2}, d_bad_words}}, + {"stop_words_list", Tensor{MEMORY_GPU, TYPE_INT32, {request_batch_size, 2, stop_words_len}, d_stop_words}}, + {"temperature", Tensor{MEMORY_CPU, TYPE_FP32, std::vector{1}, &temperature}}, + {"len_penalty", Tensor{MEMORY_CPU, TYPE_FP32, std::vector{1}, &len_penalty}}, + {"repetition_penalty", Tensor{MEMORY_CPU, TYPE_FP32, std::vector{1}, &repetition_penalty}}, + {"start_id", Tensor{MEMORY_CPU, TYPE_INT32, std::vector{request_batch_size}, start_ids.data()}}, + {"end_id", Tensor{MEMORY_CPU, TYPE_INT32, std::vector{request_batch_size}, end_ids.data()}}}; + + if (num_tasks > 0) { + // Prefix Prompt Task Name Ids here + input_tensors.insert( + {"prompt_learning_task_name_ids", + Tensor{MEMORY_CPU, TYPE_INT32, std::vector{request_batch_size}, prefix_prompt_task_ids.data()}}); + } + + if (top_k == 0 && top_p == 0.0f) { + FT_CHECK(beam_width > 1); + input_tensors.insert({"beam_search_diversity_rate", + Tensor{MEMORY_CPU, TYPE_FP32, std::vector{1}, &beam_search_diversity_rate}}); + } + else { + input_tensors.insert({"random_seed", Tensor{MEMORY_CPU, TYPE_UINT64, std::vector{1}, &random_seed}}); + if (top_p != 0.0f) { + input_tensors.insert({"runtime_top_p", Tensor{MEMORY_CPU, TYPE_FP32, std::vector{1}, &top_p}}); + } + if (top_k != 0) { + input_tensors.insert({"runtime_top_k", Tensor{MEMORY_CPU, TYPE_UINT32, std::vector{1}, &top_k}}); + } + } + + std::unordered_map output_tensors = std::unordered_map{ + {"output_ids", + Tensor{MEMORY_GPU, + TYPE_INT32, + std::vector{request_batch_size, beam_width, (size_t)total_output_len}, + d_output_ids}}, + {"sequence_length", + Tensor{MEMORY_GPU, TYPE_INT32, std::vector{request_batch_size, beam_width}, d_sequence_lengths}}, + {"output_log_probs", + Tensor{MEMORY_GPU, + TYPE_FP32, + std::vector{(size_t)request_output_len, request_batch_size, beam_width}, + nullptr}}}; + + print_mem_usage(); + + int ite = 1; + cudaDeviceSynchronize(); + mpi::barrier(); + + cudaProfilerStart(); + // warm up + ite = 1; + nvtx::setScope("warmup_time"); + PUSH_RANGE("warmup time") + for (int i = 0; i < ite; ++i) { + gpt.forward(&output_tensors, &input_tensors, &gpt_weights); + } + cudaDeviceSynchronize(); + mpi::barrier(); + + POP_RANGE; + nvtx::resetScope(); + + if (rank == 0) { + + std::string fName = "out"; + auto outFile = std::ofstream(fName, std::ios::out); + if (!outFile.is_open()) { + printf("[WARNING] Cannot write results into output file %s \n", fName.c_str()); + } + else { + size_t outCount = total_output_len * request_batch_size * beam_width; + int* hBuf = new int[outCount]; + cudaD2Hcpy(hBuf, d_output_ids, outCount); + + { + std::cout << "Writing " << outCount << " elements\n"; + int zeroCount = 0; + for (size_t i = 0; i < outCount; i++) { + if (hBuf[i] == int(0)) { + zeroCount++; + } + outFile << hBuf[i] << " "; + if ((i + 1) % (total_output_len) == 0) { + outFile << std::endl; + } + + if (i < 10) { + printf("%5d ", hBuf[i]); + } + if ((i + 1) % (total_output_len) == 0 && i < 
10) { + std::cout << std::endl; + } + } + std::cout << std::endl << "zeroCount = " << zeroCount << std::endl; + } + delete[] hBuf; + } + } + + // test time + struct timeval start, end; + mpi::barrier(); + cudaDeviceSynchronize(); + gettimeofday(&start, NULL); + + nvtx::setScope("total_time"); + PUSH_RANGE("total time") + for (int i = 0; i < ite; ++i) { + gpt.forward(&output_tensors, &input_tensors, &gpt_weights); + } + + cudaDeviceSynchronize(); + mpi::barrier(); + + POP_RANGE; + nvtx::resetScope(); + gettimeofday(&end, NULL); + + cudaProfilerStop(); + + printf("[INFO] request_batch_size %ld beam_width %ld head_num %ld size_per_head %ld total_output_len %d" + " decoder_layers %ld vocab_size %ld FT-CPP-decoding-beamsearch-time %.2f ms\n", + request_batch_size, + beam_width, + head_num, + size_per_head, + total_output_len, + decoder_layers, + vocab_size, + ((end.tv_sec - start.tv_sec) * 1000 + (end.tv_usec - start.tv_usec) * 0.001) / ite); + + ftNcclParamDestroy(tensor_para); + ftNcclParamDestroy(pipeline_para); + + delete cublas_algo_map; + delete cublas_wrapper_mutex; + + cudaFree(d_bad_words); + cudaFree(d_stop_words); + if (d_input_ids != nullptr) { + cudaFree(d_input_ids); + } + if (d_input_lengths != nullptr) { + cudaFree(d_input_lengths); + } + if (d_output_ids != nullptr) { + deviceFree(d_output_ids); + } + if (d_sequence_lengths != nullptr) { + deviceFree(d_sequence_lengths); + } + + return; +} diff --git a/examples/cpp/gptneox/gptneox_triton_example.cc b/examples/cpp/gptneox/gptneox_triton_example.cc new file mode 100644 index 000000000..3e83f37a5 --- /dev/null +++ b/examples/cpp/gptneox/gptneox_triton_example.cc @@ -0,0 +1,427 @@ +/* + * Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "3rdparty/INIReader.h" +#include "examples/cpp/multi_gpu_gpt/gpt_example_utils.h" +#include "src/fastertransformer/triton_backend/gptneox/GptNeoXTritonModel.h" +#include "src/fastertransformer/triton_backend/gptneox/GptNeoXTritonModelInstance.h" +#include "src/fastertransformer/utils/custom_ar_comm.h" +#include "src/fastertransformer/utils/mpi_utils.h" +#include "src/fastertransformer/utils/nccl_utils.h" +#include "src/fastertransformer/utils/word_list.h" + +#include +#include + +namespace ft = fastertransformer; + +struct RequestParam { + int beam_width; + int request_output_len; + float beam_search_diversity_rate; + uint runtime_top_k; + float runtime_top_p; + float temperature; + float len_penalty; + float repetition_penalty; + unsigned long long int random_seed; + int start_id; + int end_id; +}; + +std::vector>> +broadCastRequest(const std::vector& v_start_ids, + const std::vector& v_start_lengths, + const std::vector& v_bad_words, + const int node_id, + const int gpu_count, + const RequestParam param, + std::vector* pointer_record) +{ + // broadcast the request to all nodes, and copy "gpu_count" copies on different gpu + int size_1 = v_start_ids.size(); + int size_2 = v_start_lengths.size(); + int size_bad_words = v_bad_words.size(); + ft::mpi::bcast(&size_1, 1, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD); + ft::mpi::bcast(&size_2, 1, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD); + ft::mpi::bcast(&size_bad_words, 1, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD); + + std::vector v_input_ids(size_1); + std::vector v_input_lengths(size_2); + std::vector v_input_bad_words(size_bad_words); + + if (node_id == 0) { + memcpy(v_input_ids.data(), v_start_ids.data(), size_1 * sizeof(int)); + memcpy(v_input_lengths.data(), v_start_lengths.data(), size_2 * sizeof(int)); + memcpy(v_input_bad_words.data(), v_bad_words.data(), size_bad_words * sizeof(int)); + } + ft::mpi::barrier(); + + int request_batch_size = size_2; + int max_input_len = size_1 / size_2; + + ft::mpi::bcast(v_input_ids.data(), size_1, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD); + ft::mpi::bcast(v_input_lengths.data(), size_2, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD); + ft::mpi::bcast(v_input_bad_words.data(), size_bad_words, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD); + + std::vector>> request_list; + for (int device_id = 0; device_id < gpu_count; device_id++) { + ft::check_cuda_error(cudaSetDevice(device_id)); + + int* d_input_ids; + int* d_input_lengths; + int* d_input_bad_words; + + if (max_input_len == 0) { + // unconditional case, no input ids, so do nothing. + d_input_ids = nullptr; + d_input_lengths = nullptr; + max_input_len = 0; + } + else { + // conditional case. 
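+            // d_input_ids holds the flattened [request_batch_size, max_input_len] token ids and
+            // d_input_lengths the per-sequence input lengths; each GPU gets its own device copy
+            // so that every model instance receives an independent request.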
+ ft::deviceMalloc(&d_input_ids, size_1, false); + ft::deviceMalloc(&d_input_lengths, size_2, false); + ft::cudaH2Dcpy(d_input_ids, v_input_ids.data(), size_1); + ft::cudaH2Dcpy(d_input_lengths, v_input_lengths.data(), size_2); + } + ft::deviceMalloc(&d_input_bad_words, size_bad_words, false); + ft::cudaH2Dcpy(d_input_bad_words, v_input_bad_words.data(), size_bad_words); + + uint32_t* request_output_len_ptr = (uint32_t*)malloc(request_batch_size * sizeof(uint32_t)); + for (int i = 0; i < request_batch_size; i++) { + request_output_len_ptr[i] = param.request_output_len; + } + + int* start_ids_ptr = (int*)malloc(request_batch_size * sizeof(int)); + int* end_ids_ptr = (int*)malloc(request_batch_size * sizeof(int)); + for (int i = 0; i < request_batch_size; i++) { + start_ids_ptr[i] = param.start_id; + end_ids_ptr[i] = param.end_id; + } + pointer_record->push_back(start_ids_ptr); + pointer_record->push_back(end_ids_ptr); + + request_list.push_back(std::shared_ptr>( + new std::unordered_map{ + {"input_ids", + triton::Tensor{triton::MEMORY_GPU, + triton::TYPE_INT32, + std::vector{(size_t)request_batch_size, (size_t)max_input_len}, + d_input_ids}}, + {"input_lengths", + triton::Tensor{triton::MEMORY_GPU, + triton::TYPE_INT32, + std::vector{(size_t)request_batch_size}, + d_input_lengths}}, + {"request_output_len", + triton::Tensor{triton::MEMORY_CPU, + triton::TYPE_INT32, + std::vector{(size_t)request_batch_size}, + request_output_len_ptr}}, + {"bad_words_list", + triton::Tensor{ + triton::MEMORY_GPU, triton::TYPE_INT32, {2, v_input_bad_words.size() / 2}, d_input_bad_words}}, + {"start_id", + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_INT32, {(size_t)request_batch_size}, start_ids_ptr}}, + {"end_id", + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_INT32, {(size_t)request_batch_size}, end_ids_ptr}}})); + + int* beam_width_ptr = new int(param.beam_width); + pointer_record->push_back(beam_width_ptr); + request_list[device_id]->insert( + {"beam_width", + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_INT32, std::vector{1}, beam_width_ptr}}); + if (param.beam_width > 1) { + float* beam_search_diversity_rate_ptr = new float(param.beam_search_diversity_rate); + pointer_record->push_back(beam_search_diversity_rate_ptr); + request_list[device_id]->insert( + {"beam_search_diversity_rate", + triton::Tensor{ + triton::MEMORY_CPU, triton::TYPE_FP32, std::vector{1}, beam_search_diversity_rate_ptr}}); + } + else { + if (param.runtime_top_p != 0.0f) { + float* runtime_top_p_ptr = new float(param.runtime_top_p); + pointer_record->push_back(runtime_top_p_ptr); + request_list[device_id]->insert( + {"runtime_top_p", + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_FP32, std::vector{1}, runtime_top_p_ptr}}); + } + if (param.runtime_top_k != 0) { + uint* runtime_top_k_ptr = new uint(param.runtime_top_k); + pointer_record->push_back(runtime_top_k_ptr); + request_list[device_id]->insert( + {"runtime_top_k", + triton::Tensor{ + triton::MEMORY_CPU, triton::TYPE_UINT32, std::vector{1}, runtime_top_k_ptr}}); + } + } + float* temperature_ptr = new float(param.temperature); + pointer_record->push_back(temperature_ptr); + request_list[device_id]->insert( + {"temperature", + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_FP32, std::vector{1}, temperature_ptr}}); + float* len_penalty_ptr = new float(param.len_penalty); + pointer_record->push_back(len_penalty_ptr); + request_list[device_id]->insert( + {"len_penalty", + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_FP32, std::vector{1}, len_penalty_ptr}}); + float* 
repetition_penalty_ptr = new float(param.repetition_penalty); + pointer_record->push_back(repetition_penalty_ptr); + request_list[device_id]->insert( + {"repetition_penalty", + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_FP32, std::vector{1}, repetition_penalty_ptr}}); + unsigned long long int* random_seed_ptr = new unsigned long long int(param.random_seed); + pointer_record->push_back(random_seed_ptr); + request_list[device_id]->insert( + {"random_seed", + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_UINT64, std::vector{1}, random_seed_ptr}}); + + pointer_record->push_back(d_input_ids); + pointer_record->push_back(d_input_lengths); + pointer_record->push_back(d_input_bad_words); + pointer_record->push_back(request_output_len_ptr); + } + + return request_list; +} + +std::vector>> +prepareRequest(std::string ini_name, const int node_id, const int gpu_count, std::vector* pointer_record) +{ + INIReader reader = INIReader(ini_name); + if (reader.ParseError() < 0) { + std::cout << "[ERROR] Can't load '" << ini_name << "'\n"; + ft::FT_CHECK(false); + } + + const size_t request_batch_size = reader.GetInteger("request", "request_batch_size"); + + const int start_id = reader.GetInteger("gptneox_20B", "start_id"); + const int end_id = reader.GetInteger("gptneox_20B", "end_id"); + + std::vector v_start_ids; + std::vector v_start_lengths; + + int max_input_len = 0; + ft::read_start_ids(request_batch_size, + &v_start_lengths, + &v_start_ids, + max_input_len, + end_id, + 1, + "../examples/cpp/gptneox/start_ids.csv"); + + std::vector v_bad_words; + ft::read_word_list("../examples/cpp/gptneox/bad_words.csv", v_bad_words); + + RequestParam param; + param.beam_width = reader.GetInteger("request", "beam_width"); + param.request_output_len = reader.GetInteger("request", "request_output_len"); + param.beam_search_diversity_rate = reader.GetFloat("request", "beam_search_diversity_rate"); + param.runtime_top_k = reader.GetInteger("request", "top_k"); + param.runtime_top_p = reader.GetFloat("request", "top_p"); + param.temperature = reader.GetFloat("request", "temperature"); + param.len_penalty = reader.GetFloat("request", "len_penalty"); + param.repetition_penalty = reader.GetFloat("request", "repetition_penalty"); + param.random_seed = (unsigned long long int)0; + param.start_id = start_id; + param.end_id = end_id; + + auto request_list = + broadCastRequest(v_start_ids, v_start_lengths, v_bad_words, node_id, gpu_count, param, pointer_record); + return request_list; +} + +int threadCreateModelInstances(std::shared_ptr model, + std::vector>* model_instances, + const int device_id, + const int rank, + std::pair, std::vector> nccl_params, + std::shared_ptr custom_all_reduce_comm = nullptr) +{ + printf("[INFO] rank = %d \n", rank); + ft::check_cuda_error(cudaSetDevice(device_id)); + cudaStream_t stream; + ft::check_cuda_error(cudaStreamCreate(&stream)); + model->createSharedWeights(device_id, rank); + auto model_instance = model->createModelInstance(device_id, rank, stream, nccl_params, custom_all_reduce_comm); + model_instances->at(device_id) = std::move(model_instance); + printf("model instance %d is created \n", device_id); + ft::print_mem_usage(); + return 0; +} + +int threadForward(std::unique_ptr* model_instance, + std::shared_ptr> request, + std::shared_ptr>* output_tensors, + const int device_id) +{ + ft::check_cuda_error(cudaSetDevice(device_id)); + *output_tensors = (*model_instance)->forward(request); + return 0; +} + +int main(int argc, char* argv[]) +{ + /* + Prepare the nccl ids, node id, device 
id and world size + by MPI or triton + */ + + MPICHECK(MPI_Init(&argc, &argv)); + ft::mpi::initialize(&argc, &argv); + int node_id = ft::mpi::getCommWorldRank(); + int node_num = ft::mpi::getCommWorldSize(); + + // Note: Only supports that all nodes have same gpu count + const int gpu_count = ft::getDeviceCount(); + const int world_size = node_num * gpu_count; + std::string ini_name = argc >= 2 ? std::string(argv[1]) : "../examples/cpp/gptneox/gptneox_config.ini"; + + // step 1: Create model + std::shared_ptr model = AbstractTransformerModel::createGptNeoXModel(ini_name); + int tensor_para_size = model->getTensorParaSize(); + int pipeline_para_size = model->getPipelineParaSize(); + ft::FT_CHECK_WITH_INFO(world_size == (tensor_para_size * pipeline_para_size), + "World Size != Tensor Parallel Size * Pipeline Parallel Size !"); + + std::cout << model->toString(); + + // step 2: Initialize the NCCL + std::pair, std::vector> nccl_comms = model->createNcclParams(node_id); + cudaDeviceSynchronize(); + + // Optional Step: create custom all reduce comm + std::vector> custom_all_reduce_comms; + model->createCustomComms(&custom_all_reduce_comms, world_size); + + // step 3: Create model instances + std::vector> model_instances((size_t)gpu_count); + std::vector threads; + for (int device_id = 0; device_id < gpu_count; device_id++) { + const int rank = node_id * gpu_count + device_id; + threads.push_back(std::thread(threadCreateModelInstances, + model, + &model_instances, + device_id, + rank, + nccl_comms, + custom_all_reduce_comms[rank])); + } + for (auto& t : threads) { + t.join(); + } + + // step 4: prepare request + std::vector pointer_record; // Used to prevent the pointers are release after leaving functions + std::vector>> request_list = + prepareRequest(ini_name, node_id, gpu_count, &pointer_record); + printf("[INFO] request is created \n"); + + // step 5: Forward + std::vector>> output_tensors_lists( + (size_t)gpu_count); + for (int i = 0; i < 2; i++) { + threads.clear(); + for (int device_id = 0; device_id < gpu_count; device_id++) { + threads.push_back(std::thread(threadForward, + &model_instances[device_id], + request_list[device_id], + &output_tensors_lists[device_id], + device_id)); + } + for (auto& t : threads) { + t.join(); + } + } + printf("[INFO] forward is completed. 
\n"); + + const int* d_output_ids = (const int*)output_tensors_lists[0].get()->at("output_ids").data; + const int batch_size = output_tensors_lists[0].get()->at("output_ids").shape[0]; + const int beam_width = output_tensors_lists[0].get()->at("output_ids").shape[1]; + const int seq_len = output_tensors_lists[0].get()->at("output_ids").shape[2]; + // step 6: check results + if (node_id == 0) { + + std::string fName = "out"; + auto outFile = std::ofstream(fName, std::ios::out); + if (!outFile.is_open()) { + printf("[WARNING] Cannot write results into output file %s \n", fName.c_str()); + } + else { + size_t outCount = batch_size * beam_width * seq_len; + int* hBuf = new int[outCount]; + ft::cudaD2Hcpy(hBuf, d_output_ids, outCount); + + { + std::cout << "Writing " << outCount << " elements\n"; + int zeroCount = 0; + for (size_t i = 0; i < outCount; i++) { + if (hBuf[i] == int(0)) + zeroCount++; + outFile << hBuf[i] << " "; + if ((i + 1) % (seq_len) == 0) + outFile << std::endl; + + if (i < 10) + printf("%5d ", hBuf[i]); + if ((i + 1) % (seq_len) == 0 && i < 10) + std::cout << std::endl; + } + std::cout << std::endl << "zeroCount = " << zeroCount << std::endl; + } + delete[] hBuf; + } + } + + // test time + struct timeval start, end; + ft::mpi::barrier(); + cudaDeviceSynchronize(); + gettimeofday(&start, NULL); + + const int ite = 1; + for (int i = 0; i < ite; i++) { + threads.clear(); + for (int device_id = 0; device_id < gpu_count; device_id++) { + threads.push_back(std::thread(threadForward, + &model_instances[device_id], + request_list[device_id], + &output_tensors_lists[device_id], + device_id)); + } + for (auto& t : threads) { + t.join(); + } + } + + cudaDeviceSynchronize(); + ft::mpi::barrier(); + + gettimeofday(&end, NULL); + + printf("[INFO] batch_size %d beam_width %d seq_len %d" + " FT-CPP-GPT-Triton-time %.2f ms\n", + batch_size, + beam_width, + seq_len, + ((end.tv_sec - start.tv_sec) * 1000 + (end.tv_usec - start.tv_usec) * 0.001) / ite); + + ft::mpi::finalize(); + return 0; +} diff --git a/examples/cpp/gptneox/start_ids.csv b/examples/cpp/gptneox/start_ids.csv new file mode 100644 index 000000000..88e742f39 --- /dev/null +++ b/examples/cpp/gptneox/start_ids.csv @@ -0,0 +1,8 @@ +688, 253, 1390, 4564, 273, 1897, 13, 247 +510, 1457, 8911, 4487, 273, 26593, 310, 6600 +510, 1457, 2816, 28260, 452, 247, 747, 1481 +510, 1457, 2816, 7717, 556, 3863, 697, 7970 +688, 247, 2118, 326, 588, 2779, 1056, 352 +510, 1457, 2816, 28260, 8, 13413, 19169, 14745 +510, 9462, 5687, 556, 38350, 26212, 253, 747 +510, 806, 673, 309, 3047, 253, 6440, 13 \ No newline at end of file diff --git a/examples/cpp/gptneox/stop_words.csv b/examples/cpp/gptneox/stop_words.csv new file mode 100644 index 000000000..9b9b09eba --- /dev/null +++ b/examples/cpp/gptneox/stop_words.csv @@ -0,0 +1,2 @@ +287, 4346, 12 +3, -1, -1 diff --git a/examples/cpp/multi_gpu_gpt/CMakeLists.txt b/examples/cpp/multi_gpu_gpt/CMakeLists.txt index b535c2da0..b7d26b5e1 100644 --- a/examples/cpp/multi_gpu_gpt/CMakeLists.txt +++ b/examples/cpp/multi_gpu_gpt/CMakeLists.txt @@ -17,12 +17,18 @@ target_link_libraries(gpt_example_utils PUBLIC -lcudart) add_executable(multi_gpu_gpt_example multi_gpu_gpt_example.cc) target_link_libraries(multi_gpu_gpt_example PUBLIC -lcublas -lcublasLt -lcudart - ParallelGpt nvtx_utils -lmpi gpt_example_utils) + ParallelGpt nvtx_utils mpi_utils nccl_utils gpt_example_utils) add_executable(multi_gpu_gpt_async_example multi_gpu_gpt_async_example.cc) target_link_libraries(multi_gpu_gpt_async_example PUBLIC -lcublas 
-lcublasLt -lcudart - ParallelGpt nvtx_utils -lmpi gpt_example_utils) + ParallelGpt nvtx_utils mpi_utils nccl_utils gpt_example_utils) add_executable(multi_gpu_gpt_triton_example multi_gpu_gpt_triton_example.cc) target_link_libraries(multi_gpu_gpt_triton_example PUBLIC -lcublas -lcublasLt -lcudart - ParallelGptTritonBackend memory_utils custom_ar_comm -lmpi gpt_example_utils -lpthread) + ParallelGptTritonBackend TransformerTritonBackend memory_utils + custom_ar_comm mpi_utils nccl_utils gpt_example_utils -lpthread) + +add_executable(multi_gpu_gpt_interactive_example multi_gpu_gpt_interactive_example.cc) +target_link_libraries(multi_gpu_gpt_interactive_example PUBLIC -lcublas -lcublasLt -lcudart + ParallelGpt nvtx_utils mpi_utils nccl_utils gpt_example_utils) + diff --git a/examples/cpp/multi_gpu_gpt/concat_interactive_ids.csv b/examples/cpp/multi_gpu_gpt/concat_interactive_ids.csv new file mode 100644 index 000000000..990c3f9e4 --- /dev/null +++ b/examples/cpp/multi_gpu_gpt/concat_interactive_ids.csv @@ -0,0 +1,8 @@ +818, 262, 938, 3155, 286, 1528, 11, 257, 1256, 286, 661, 423, 587, 4737, 502, 546, 262, 649, 1492, 11, 290, 314, 1053, 587, 2111, 284, 3280, 617, 286, 262, 2683, 326, 661, 423, 587, 4737, 502, 13, 198, 198, 5962, 11, 314, 561, 588, 284, 910, 326 +198, 464, 968, 8221, 2732, 286, 15198, 318, 1762, 351, 262, 1181, 338, 9358, 5011, 284, 5004, 262, 1266, 835, 284, 1445, 262, 4979, 13, 198, 1, 1135, 821, 1016, 284, 307, 2045, 379, 262, 1266, 835, 284, 1445, 262, 11125, 286, 2844, 291, 5028, 422, 262, 7627 +464, 968, 1971, 12056, 423, 257, 649, 1182, 3985, 11, 290, 339, 338, 257, 3516, 508, 338, 587, 1088, 262, 4652, 329, 257, 890, 640, 13, 679, 338, 257, 3516, 508, 338, 587, 1088, 262, 4652, 329, 257, 890, 640, 392, 257, 1913, 1998, 351, 1353, 12, 28282 +464, 968, 1971, 3782, 468, 3199, 663, 5079, 1351, 286, 262, 995, 338, 749, 14212, 661, 13, 198, 464, 1351, 11, 543, 373, 14102, 416, 262, 968, 1971, 3782, 11, 318, 1912, 319, 257, 5526, 286, 517, 621, 352, 11, 830, 34643, 11, 7602, 11, 4708, 6332, 1938 +818, 257, 1445, 326, 481, 1884, 787, 340, 4577, 329, 262, 1664, 284, 3677, 663, 7303, 11, 262, 1664, 468, 4987, 284, 3677, 663, 10171, 287, 262, 1664, 284, 257, 1448, 286, 7713, 2957, 416, 262, 2839, 13598, 4081, 309, 5, 38328, 763, 13, 1119, 481, 2148, 257 +464, 968, 1971, 12056, 6, 5859, 41683, 423, 587, 257, 1263, 636, 286, 262, 1074, 338, 1943, 428, 1622, 13, 198, 464, 12056, 423, 587, 1498, 284, 1057, 262, 2613, 6840, 11, 290, 484, 423, 587, 1498, 284, 1057, 262, 3245, 355, 257, 22080, 1074, 13, 4042, 286 +198, 198, 464, 5398, 4332, 628, 628, 198, 198, 464, 5398, 4332, 628, 628, 198, 198, 464, 5398, 4332, 628, 628, 198, 198, 464, 5398, 4332, 628, 628, 198, 198, 464, 5398, 4332, 628, 628, 198, 198, 464, 5398, 4332, 14150, 26443, 262, 1230, 338, 1410, 284, 3958 +464, 717, 640, 314, 2497, 262, 3807, 11, 314, 373, 588, 11, 705, 5812, 616, 1793, 11, 428, 318, 523, 3608, 2637, 314, 373, 588, 11, 705, 40, 765, 284, 307, 287, 428, 3807, 2637, 314, 373, 588, 11, 705, 5195, 4398, 470, 314, 7342, 340, 2961, 30 diff --git a/examples/cpp/multi_gpu_gpt/gpt_config.ini b/examples/cpp/multi_gpu_gpt/gpt_config.ini index 5eefa05d6..72cfc0756 100644 --- a/examples/cpp/multi_gpu_gpt/gpt_config.ini +++ b/examples/cpp/multi_gpu_gpt/gpt_config.ini @@ -14,19 +14,25 @@ int8_mode=0 enable_custom_all_reduce=0 ; model_name=gpt_124M model_name=megatron_345M +; model_name=megatron_1.3B_adapter ; model_name=megatron_6.7B +; model_name=megatron_20B ; model_name=gpt_175B +; model_name=opt_125M +; 
model_name=opt_350M ; model_name=self_defined ; model_dir=./models/megatron-models/c-model/6.7b/ -model_dir=../models/megatron-models/c-model/345m/8-gpu/ -len_penalty=1.0 +model_dir=../models/megatron-models/c-model/345m/1-gpu/ +len_penalty=0.0 beam_search_diversity_rate=0.0 +shared_contexts_ratio=1.0 [request] request_batch_size=8 ; determine by the request request_output_len=32 ; determine by the request return_log_probs=false ; return the output log probs and cumulative log probs. context_log_probs=false ; include input contexts in the cumulative log probability computation. +remove_padding=true [gpt_124M] head_num=12 @@ -36,6 +42,22 @@ decoder_layers=12 start_id=50256 end_id=50256 inter_size=3072 +num_tasks=3 ;optional +prompt_learning_start_id=50257 ;optional +prompt_learning_type=3 ;optional + +;prompt learning example (optional) +[gpt_124M_task_0] ; task_name_id = 0 +task_name=sentiment +prompt_length=10 +;optional +[gpt_124M_task_1] ; task_name_id = 1 +task_name=intent_and_slot +prompt_length=10 +;optional +[gpt_124M_task_2] ; task_name_id = 2 +task_name=squad +prompt_length=16 [megatron_345M] head_num=16 @@ -46,6 +68,18 @@ start_id=50256 end_id=50256 inter_size=4096 +[megatron_1.3B_adapter] +head_num=32 +size_per_head=64 +vocab_size=50304 +decoder_layers=24 +start_id=50256 +end_id=50256 +inter_size=8192 +layernorm_eps=1e-5 +adapter_inter_size=1024 +has_adapters=true + [megatron_6.7B] head_num=32 size_per_head=128 @@ -55,6 +89,15 @@ start_id=50256 end_id=50256 inter_size=16384 +[megatron_20B] +head_num=48 +size_per_head=128 +vocab_size=51200 +decoder_layers=44 +start_id=50256 +end_id=50256 +inter_size=24576 + [gpt_175B] head_num=96 size_per_head=128 @@ -64,6 +107,26 @@ start_id=50256 end_id=50256 inter_size=49152 +[opt_125M] +head_num=12 +size_per_head=64 +vocab_size=50272 +decoder_layers=12 +start_id=2 +end_id=2 +inter_size=3072 +model_variant=opt-pre ;define variant structure + +[opt_350M] +head_num=16 +size_per_head=64 +vocab_size=50272 +decoder_layers=24 +start_id=2 +end_id=2 +inter_size=4096 +model_variant=opt-post + [self_defined] head_num=16 size_per_head=64 diff --git a/examples/cpp/multi_gpu_gpt/gpt_example_utils.cc b/examples/cpp/multi_gpu_gpt/gpt_example_utils.cc index 5a008b60a..497478566 100644 --- a/examples/cpp/multi_gpu_gpt/gpt_example_utils.cc +++ b/examples/cpp/multi_gpu_gpt/gpt_example_utils.cc @@ -20,26 +20,26 @@ namespace fastertransformer { -int read_start_ids(int batch_size, +int read_start_ids(int batch_size, std::vector* v_start_lengths, std::vector* v_start_ids, - int& max_input_len, - const int end_id, - const int beam_width, - std::string file_name) + int& max_input_len, + const int end_id, + const int beam_width, + std::string file_name) { std::vector> tmp_start_ids; - std::vector tmp_start_lengths; + std::vector tmp_start_lengths; std::ifstream start_id_file(file_name, std::ios::in); if (start_id_file.is_open()) { std::string line; - int i0 = 0; + int i0 = 0; while (std::getline(start_id_file, line)) { std::stringstream lineStream(line); - std::string vals; - int i1 = 0; - std::vector tmp_vec; + std::string vals; + int i1 = 0; + std::vector tmp_vec; while (std::getline(lineStream, vals, ',')) { tmp_vec.push_back(std::stoi(vals)); i1++; diff --git a/examples/cpp/multi_gpu_gpt/gpt_example_utils.h b/examples/cpp/multi_gpu_gpt/gpt_example_utils.h index 2cfc411fa..80ee9a534 100644 --- a/examples/cpp/multi_gpu_gpt/gpt_example_utils.h +++ b/examples/cpp/multi_gpu_gpt/gpt_example_utils.h @@ -19,12 +19,12 @@ namespace fastertransformer { -int read_start_ids(int 
batch_size, +int read_start_ids(int batch_size, std::vector* v_start_lengths, std::vector* v_start_ids, - int& max_input_len, - const int end_id, - const int beam_width, - std::string file_name); + int& max_input_len, + const int end_id, + const int beam_width, + std::string file_name); } // namespace fastertransformer diff --git a/examples/cpp/multi_gpu_gpt/interactive_inputs_ids.csv b/examples/cpp/multi_gpu_gpt/interactive_inputs_ids.csv new file mode 100644 index 000000000..f04028f49 --- /dev/null +++ b/examples/cpp/multi_gpu_gpt/interactive_inputs_ids.csv @@ -0,0 +1,8 @@ +5962, 11, 314, 561, 588, 284, 910, 326 +11125, 286, 2844, 291, 5028, 422, 262, 7627 +392, 257, 1913, 1998, 351, 1353, 12, 28282 +830, 34643, 11, 7602, 11, 4708, 6332, 1938 +5, 38328, 763, 13, 1119, 481, 2148, 257 +3245, 355, 257, 22080, 1074, 13, 4042, 286 +14150, 26443, 262, 1230, 338, 1410, 284, 3958 +5195, 4398, 470, 314, 7342, 340, 2961, 30 diff --git a/examples/cpp/multi_gpu_gpt/multi_gpu_gpt_async_example.cc b/examples/cpp/multi_gpu_gpt/multi_gpu_gpt_async_example.cc index 81286c591..226bdd096 100644 --- a/examples/cpp/multi_gpu_gpt/multi_gpu_gpt_async_example.cc +++ b/examples/cpp/multi_gpu_gpt/multi_gpu_gpt_async_example.cc @@ -32,11 +32,12 @@ #include "3rdparty/INIReader.h" #include "src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.h" #include "src/fastertransformer/utils/mpi_utils.h" +#include "src/fastertransformer/utils/nccl_utils.h" #include "src/fastertransformer/utils/nvtx_utils.h" -static bool USE_ASYNC = true; -const int START_TOKEN_ID = 50256; -const int END_TOKEN_ID = 50256; +static bool USE_ASYNC = true; +const int START_TOKEN_ID = 50256; +const int END_TOKEN_ID = 50256; #ifdef USE_NVTX bool NVTX_ON = true; @@ -62,8 +63,8 @@ std::string join(const T* arr, const int length, const std::string sep = " ") template std::string toString(const T* arr, const int length, const bool is_device_array = true) { - size_t size = sizeof(T) * length; - T* h_arr; + size_t size = sizeof(T) * length; + T* h_arr; std::string token_ids_str; if (is_device_array) { h_arr = (T*)malloc(size); @@ -89,20 +90,20 @@ class GptStreamer { const int max_output_length; // including both input and generated tokens. 
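+    // GptStreamer polls the decoder step by step: it keeps host copies of output_ids,
+    // sequence_lengths and finished, and reports newly generated tokens through streamHook().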
// results - int* output_ids; - int* sequence_lengths; + int* output_ids; + int* sequence_lengths; bool* finished; // decoder settings - ParallelGpt* gpt_; - std::unordered_map* output_tensors_; + ParallelGpt* gpt_; + std::unordered_map* output_tensors_; const std::unordered_map* input_tensors_; - const ParallelGptWeight* gpt_weights_; + const ParallelGptWeight* gpt_weights_; // streamer status and internal buffers - int prev_step = 0; - int curr_step = 0; - bool is_generation_done = false; + int prev_step = 0; + int curr_step = 0; + bool is_generation_done = false; cudaStream_t stream; /** @@ -194,17 +195,17 @@ class GptStreamer { void streamDecoding() { - int input_len = input_tensors_->at("input_ids").shape[1]; + int input_len = input_tensors_->at("input_ids").shape[1]; int max_output_len = output_tensors_->at("output_ids").shape[0]; - int batch_size = output_tensors_->at("output_ids").shape[0]; + int batch_size = output_tensors_->at("output_ids").shape[0]; // initialization - is_generation_done = false; - prev_step = 0; - curr_step = input_len; - int* seqlen_buf_ = new int[batch_size]; + is_generation_done = false; + prev_step = 0; + curr_step = input_len; + int* seqlen_buf_ = new int[batch_size]; bool* decoding_finished_buf_ = new bool[batch_size]; - bool* finished = new bool[batch_size]; + bool* finished = new bool[batch_size]; std::fill(seqlen_buf_, seqlen_buf_ + batch_size, input_len); std::fill(decoding_finished_buf_, decoding_finished_buf_ + batch_size, false); std::fill(finished, finished + batch_size, false); @@ -287,9 +288,9 @@ class GptStreamer { GptStreamer(int max_batch_size, int max_output_length): max_batch_size(max_batch_size), max_output_length(max_output_length) { - output_ids = new int[max_output_length * max_batch_size]; + output_ids = new int[max_output_length * max_batch_size]; sequence_lengths = new int[max_batch_size]; - finished = new bool[max_batch_size]; + finished = new bool[max_batch_size]; cudaStreamCreate(&stream); } @@ -301,23 +302,23 @@ class GptStreamer { cudaStreamDestroy(stream); } - void initialize(ParallelGpt* gpt, - std::unordered_map* output_tensors, + void initialize(ParallelGpt* gpt, + std::unordered_map* output_tensors, const std::unordered_map* input_tensors, - const ParallelGptWeight* gpt_weights) + const ParallelGptWeight* gpt_weights) { - gpt_ = gpt; + gpt_ = gpt; output_tensors_ = output_tensors; - input_tensors_ = input_tensors; - gpt_weights_ = gpt_weights; + input_tensors_ = input_tensors; + gpt_weights_ = gpt_weights; int total_output_tokens = max_output_length * max_batch_size; std::fill(output_ids, output_ids + total_output_tokens, 0); std::fill(sequence_lengths, sequence_lengths + max_batch_size, output_tensors_->at("sequence_length").shape[0]); std::fill(finished, finished + max_batch_size, false); - prev_step = 0; - curr_step = 0; + prev_step = 0; + curr_step = 0; is_generation_done = false; } @@ -325,7 +326,7 @@ class GptStreamer { * \brief Forward a model and asynchronously check whether to stop. * * The device having the last rank of a pipeline parallel group checks and - * broadcasts to the other devices. So only the last rank runs asychronously + * broadcasts to the other devices. So only the last rank runs asynchronously * and monitor whether to terminate by given stop criteria. * * For now, we provide a streaming function as a separated example with @@ -333,15 +334,15 @@ class GptStreamer { * * \param gpt An ParallelGpt pointer to generate tokens. 
* \param output_tensors A vector of tensors, containing the output tensors of gpt, including - * output_ids, parent_ids, sequence_lengths and cum_log_probs + * output_ids, sequence_lengths and cum_log_probs * \param input_tensors A vector of tensors, containing the input tensors of gpt, including * input_ids, input_lengths and request_output_len * \param gpt_weights A ParallelGptWeight pointer, which continas the weights of gpt model */ - void run(ParallelGpt* gpt, - std::unordered_map* output_tensors, + void run(ParallelGpt* gpt, + std::unordered_map* output_tensors, const std::unordered_map* input_tensors, - const ParallelGptWeight* gpt_weights) + const ParallelGptWeight* gpt_weights) { initialize(gpt, output_tensors, input_tensors, gpt_weights); // Only the last rank of pipeline parallel will run asynchronously @@ -366,7 +367,7 @@ class GptFileStreamer: public GptStreamer { protected: const std::string output_file; - std::ofstream ofs; + std::ofstream ofs; void streamHook(const int prev_step, const int curr_step, const int* output_ids) override { @@ -405,7 +406,7 @@ void multi_gpu_gpt_example(const INIReader reader); int main(int argc, char* argv[]) { - MPICHECK(MPI_Init(&argc, &argv)); + mpi::initialize(&argc, &argv); srand(0); std::string ini_name; @@ -449,30 +450,30 @@ int main(int argc, char* argv[]) printf("[ERROR] data_type should be fp32, fp16 or bf16 ! \n"); return -1; } - MPI_Finalize(); + mpi::finalize(); return 0; } -int read_start_ids(int batch_size, +int read_start_ids(int batch_size, std::vector* v_start_lengths, std::vector* v_start_ids, - int& max_input_len, - const int end_id, - const int beam_width) + int& max_input_len, + const int end_id, + const int beam_width) { std::vector> tmp_start_ids; - std::vector tmp_start_lengths; + std::vector tmp_start_lengths; - std::string file_name = "../examples/cpp/multi_gpu_gpt/start_ids.csv"; + std::string file_name = "../examples/cpp/multi_gpu_gpt/start_ids.csv"; std::ifstream start_id_file(file_name, std::ios::in); if (start_id_file.is_open()) { std::string line; - int i0 = 0; + int i0 = 0; while (std::getline(start_id_file, line)) { std::stringstream lineStream(line); - std::string vals; - int i1 = 0; - std::vector tmp_vec; + std::string vals; + int i1 = 0; + std::vector tmp_vec; while (std::getline(lineStream, vals, ',')) { tmp_vec.push_back(std::stoi(vals)); i1++; @@ -523,37 +524,38 @@ int read_start_ids(int batch_size, template void multi_gpu_gpt_example(const INIReader reader) { - const std::string model_name = reader.Get("ft_instance_hyperparameter", "model_name"); - const size_t max_batch_size = reader.GetInteger("ft_instance_hyperparameter", "max_batch_size"); - const size_t max_seq_len = reader.GetInteger("ft_instance_hyperparameter", "max_seq_len"); - const size_t beam_width = reader.GetInteger("ft_instance_hyperparameter", "beam_width"); - const int top_k = reader.GetInteger("ft_instance_hyperparameter", "top_k"); - const float top_p = reader.GetFloat("ft_instance_hyperparameter", "top_p"); - const float temperature = reader.GetFloat("ft_instance_hyperparameter", "temperature"); - const float repetition_penalty = reader.GetFloat("ft_instance_hyperparameter", "repetition_penalty"); - const std::string model_dir = std::string(reader.Get("ft_instance_hyperparameter", "model_dir")); - const bool sparse = static_cast(reader.GetInteger("ft_instance_hyperparameter", "sparse")); - const float len_penalty = 1.0f; - const float beam_search_diversity_rate = 0.0f; - - const int tensor_para_size = 
reader.GetInteger("ft_instance_hyperparameter", "tensor_para_size"); + const std::string model_name = reader.Get("ft_instance_hyperparameter", "model_name"); + const size_t max_batch_size = reader.GetInteger("ft_instance_hyperparameter", "max_batch_size"); + const size_t max_seq_len = reader.GetInteger("ft_instance_hyperparameter", "max_seq_len"); + const size_t beam_width = reader.GetInteger("ft_instance_hyperparameter", "beam_width"); + const uint top_k = (uint)reader.GetInteger("ft_instance_hyperparameter", "top_k"); + const float top_p = reader.GetFloat("ft_instance_hyperparameter", "top_p"); + const float temperature = reader.GetFloat("ft_instance_hyperparameter", "temperature"); + const float repetition_penalty = reader.GetFloat("ft_instance_hyperparameter", "repetition_penalty"); + const std::string model_dir = std::string(reader.Get("ft_instance_hyperparameter", "model_dir")); + const bool sparse = static_cast(reader.GetInteger("ft_instance_hyperparameter", "sparse")); + const float len_penalty = reader.GetFloat("ft_instance_hyperparameter", "len_penalty"); + const float beam_search_diversity_rate = + reader.GetFloat("ft_instance_hyperparameter", "beam_search_diversity_rate"); + + const int tensor_para_size = reader.GetInteger("ft_instance_hyperparameter", "tensor_para_size"); const int pipeline_para_size = reader.GetInteger("ft_instance_hyperparameter", "pipeline_para_size"); const int int8_mode = reader.GetInteger("ft_instance_hyperparameter", "int8_mode"); - const size_t head_num = reader.GetInteger(model_name, "head_num"); - const size_t size_per_head = reader.GetInteger(model_name, "size_per_head"); - const size_t vocab_size = reader.GetInteger(model_name, "vocab_size"); + const size_t head_num = reader.GetInteger(model_name, "head_num"); + const size_t size_per_head = reader.GetInteger(model_name, "size_per_head"); + const size_t vocab_size = reader.GetInteger(model_name, "vocab_size"); const size_t decoder_layers = reader.GetInteger(model_name, "decoder_layers"); - const size_t hidden_units = head_num * size_per_head; - const size_t inter_size = 4 * hidden_units; + const size_t hidden_units = head_num * size_per_head; + const size_t inter_size = 4 * hidden_units; const size_t request_batch_size = reader.GetInteger("request", "request_batch_size"); // The length of tokens we hope this model to generate const int request_output_len = reader.GetInteger("request", "request_output_len"); const int start_id = 50256; - const int end_id = 50256; + const int end_id = 50256; if (USE_ASYNC) { FT_CHECK(beam_width == 1); // async forward does not support beam search @@ -563,12 +565,12 @@ void multi_gpu_gpt_example(const INIReader reader) FT_CHECK(decoder_layers % pipeline_para_size == 0); // Prepare the parallelism parameters - int rank, world_size, device, device_count; - MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &rank)); - MPICHECK(MPI_Comm_size(MPI_COMM_WORLD, &world_size)); + int rank = mpi::getCommWorldRank(); + int world_size = mpi::getCommWorldSize(); if (rank == 0) { printf("Total ranks: %d.\n", world_size); } + int device, device_count; check_cuda_error(cudaGetDeviceCount(&device_count)); check_cuda_error(cudaSetDevice(rank % device_count)); check_cuda_error(cudaGetDevice(&device)); @@ -577,16 +579,16 @@ void multi_gpu_gpt_example(const INIReader reader) check_cuda_error(cudaGetDeviceProperties(&prop, device)); printf("Device %s\n", prop.name); - printf("P%d is runing with %d GPU.\n", rank, device); + printf("P%d is running with %d GPU.\n", rank, device); if (tensor_para_size * 
pipeline_para_size != world_size) { printf("[ERROR] tensor_para_size * pipeline_para_size should equal to world_size \n"); exit(-1); } - const int tensor_para_rank = rank % tensor_para_size; + const int tensor_para_rank = rank % tensor_para_size; const int pipeline_para_rank = rank / tensor_para_size; - const int layers_per_group = decoder_layers / pipeline_para_size; + const int layers_per_group = decoder_layers / pipeline_para_size; if (layers_per_group * pipeline_para_size != (int)decoder_layers) { printf("[ERROR] layers_per_group (%d) * pipeline_para_size (%d) should equal to decoder_layers (%ld) \n", layers_per_group, @@ -594,70 +596,17 @@ void multi_gpu_gpt_example(const INIReader reader) decoder_layers); exit(-1); } - ncclUniqueId tensor_para_nccl_uid; - ncclUniqueId pipeline_para_nccl_uid; // assume gpu_num = n * k, // tensor parallelism group size is n // pipeline parallelism group size is k - if (tensor_para_rank == 0) { - // get the uid of each tensor parallelism group - // here, 0, 1, ..., n-1 are in group 0, - // n, ..., 2n - 1 are in group 1. - NCCLCHECK(ncclGetUniqueId(&tensor_para_nccl_uid)); - for (int i = 1; i < tensor_para_size; i++) { - printf("[INFO] rank %d sends tensor_para_nccl_uid to rank %d \n", rank, rank + i); - MPICHECK( - MPI_Send(&tensor_para_nccl_uid, sizeof(tensor_para_nccl_uid), MPI_BYTE, rank + i, 0, MPI_COMM_WORLD)); - } - } - else { - MPI_Status status; - printf("[INFO] rank %d receives tensor_para_nccl_uid from rank %d \n", rank, rank - tensor_para_rank); - MPICHECK(MPI_Recv(&tensor_para_nccl_uid, - sizeof(tensor_para_nccl_uid), - MPI_BYTE, - rank - tensor_para_rank, - 0, - MPI_COMM_WORLD, - &status)); - } - - if (pipeline_para_rank == 0) { - // get the uid of each pipeline parallelism group - // 0, k, 2k, are in group 0 - // 1, k+1, 2k+1 are in group 1 - NCCLCHECK(ncclGetUniqueId(&pipeline_para_nccl_uid)); - for (int i = 1; i < pipeline_para_size; i++) { - printf("[INFO] rank %d sends pipeline_para_nccl_uid to rank %d \n", rank, rank + i * tensor_para_size); - MPICHECK(MPI_Send(&pipeline_para_nccl_uid, - sizeof(pipeline_para_nccl_uid), - MPI_BYTE, - rank + i * tensor_para_size, - 0, - MPI_COMM_WORLD)); - } - } - else { - MPI_Status status; - printf("[INFO] rank %d receives pipeline_para_nccl_uid from rank %d \n", rank, rank % tensor_para_size); - MPICHECK(MPI_Recv(&pipeline_para_nccl_uid, - sizeof(pipeline_para_nccl_uid), - MPI_BYTE, - rank % tensor_para_size, - 0, - MPI_COMM_WORLD, - &status)); - } - - ncclComm_t tensor_para_nccl_comm, pipeline_para_nccl_comm; - NCCLCHECK(ncclCommInitRank(&tensor_para_nccl_comm, tensor_para_size, tensor_para_nccl_uid, tensor_para_rank)); - NCCLCHECK( - ncclCommInitRank(&pipeline_para_nccl_comm, pipeline_para_size, pipeline_para_nccl_uid, pipeline_para_rank)); + NcclParam tensor_para; + NcclParam pipeline_para; + ftNcclInitialize(tensor_para, pipeline_para, tensor_para_size, pipeline_para_size); // Read ids of request from file. - int max_input_len = -1; + int max_input_len = -1; std::vector v_start_lengths; std::vector v_start_ids; read_start_ids(request_batch_size, &v_start_lengths, &v_start_ids, max_input_len, end_id, beam_width); @@ -666,9 +615,9 @@ void multi_gpu_gpt_example(const INIReader reader) int* d_input_lengths; if (max_input_len == 0) { // unconditional case, no input ids, so do nothing. - d_input_ids = nullptr; + d_input_ids = nullptr; d_input_lengths = nullptr; - max_input_len = 0; + max_input_len = 0; } else { // conditional case. 
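Note: the hand-rolled NCCL bootstrap removed above and the ftNcclInitialize() call that replaces it rely on the same rank layout, gpu_num = k * n with n-way tensor parallelism and k-way pipeline parallelism. A minimal sketch of that mapping, using only the formulas visible in the example code (the helper below is illustrative, not part of FasterTransformer):

    // Sketch: map a world rank onto (tensor, pipeline) parallel coordinates,
    // assuming world_size == tensor_para_size * pipeline_para_size.
    // Ranks 0..n-1 form tensor-parallel group 0, ranks n..2n-1 form group 1, and so on;
    // ranks {r, r + n, r + 2n, ...} form one pipeline-parallel group.
    struct ParallelCoord {
        int tensor_para_rank;
        int pipeline_para_rank;
    };

    ParallelCoord mapRank(int rank, int tensor_para_size)
    {
        ParallelCoord c;
        c.tensor_para_rank   = rank % tensor_para_size;  // position inside its tensor-parallel group
        c.pipeline_para_rank = rank / tensor_para_size;  // which pipeline stage this rank serves
        return c;
    }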
@@ -684,8 +633,8 @@ void multi_gpu_gpt_example(const INIReader reader) exit(-1); } - cudaStream_t stream; - cublasHandle_t cublas_handle; + cudaStream_t stream; + cublasHandle_t cublas_handle; cublasLtHandle_t cublaslt_handle; cudaStreamCreate(&stream); cublasCreate(&cublas_handle); @@ -695,7 +644,7 @@ void multi_gpu_gpt_example(const INIReader reader) Allocator allocator(getDevice()); - std::mutex* cublas_wrapper_mutex = new std::mutex(); + std::mutex* cublas_wrapper_mutex = new std::mutex(); cublasMMWrapper cublas_wrapper = cublasMMWrapper(cublas_handle, cublaslt_handle, stream, cublas_algo_map, cublas_wrapper_mutex, &allocator); if (std::is_same::value) { @@ -710,19 +659,18 @@ void multi_gpu_gpt_example(const INIReader reader) cublas_wrapper.setFP32GemmConfig(); } - fastertransformer::ParallelGptWeight gpt_weights(hidden_units, - inter_size, - vocab_size, - decoder_layers, - max_seq_len, - tensor_para_size, - tensor_para_rank, - pipeline_para_size, - pipeline_para_rank, - int8_mode); + ParallelGptWeight gpt_weights(hidden_units, + inter_size, + vocab_size, + decoder_layers, + max_seq_len, + tensor_para_size, + tensor_para_rank, + pipeline_para_size, + pipeline_para_rank, + int8_mode); gpt_weights.loadModel(model_dir); - NcclParam tensor_para(tensor_para_rank, tensor_para_size, tensor_para_nccl_comm); - NcclParam pipeline_para(pipeline_para_rank, pipeline_para_size, pipeline_para_nccl_comm); + unsigned long long int random_seed = 0; ParallelGpt gpt = ParallelGpt(0, // max_batch_size, FT will adjust the buffer automatically. @@ -736,12 +684,15 @@ void multi_gpu_gpt_example(const INIReader reader) vocab_size, start_id, end_id, + end_id + 1, // p_prompt_tuning token start id + PromptLearningType::no_prompt, + gptVariantParams{}, 0.0f, top_k, top_p, random_seed, temperature, - 1.0f, // len_penalty, + 0.0f, // len_penalty, repetition_penalty, tensor_para, pipeline_para, @@ -754,11 +705,10 @@ void multi_gpu_gpt_example(const INIReader reader) int8_mode); int* d_output_ids; - int* d_parent_ids; int* d_sequence_lengths; deviceMalloc(&d_output_ids, request_batch_size * beam_width * total_output_len, false); - deviceMalloc(&d_parent_ids, request_batch_size * beam_width * total_output_len, false); deviceMalloc(&d_sequence_lengths, request_batch_size * beam_width, false); + std::vector output_seq_len(request_batch_size, total_output_len); std::unordered_map input_tensors = std::unordered_map{ {"input_ids", @@ -768,7 +718,8 @@ void multi_gpu_gpt_example(const INIReader reader) d_input_ids}}, {"input_lengths", Tensor{MEMORY_GPU, TYPE_INT32, std::vector{request_batch_size * beam_width}, d_input_lengths}}, - {"max_output_seq_len", Tensor{MEMORY_CPU, TYPE_INT32, std::vector{1}, &total_output_len}}}; + {"output_seq_len", + Tensor{MEMORY_CPU, TYPE_UINT32, std::vector{request_batch_size}, output_seq_len.data()}}}; if (top_k == 0 && top_p == 0.0f) { FT_CHECK(beam_width > 1); input_tensors.insert({"beam_search_diversity_rate", @@ -779,7 +730,7 @@ void multi_gpu_gpt_example(const INIReader reader) input_tensors.insert({"runtime_top_p", Tensor{MEMORY_CPU, TYPE_FP32, std::vector{1}, &top_p}}); } if (top_k != 0) { - input_tensors.insert({"runtime_top_k", Tensor{MEMORY_CPU, TYPE_INT32, std::vector{1}, &top_k}}); + input_tensors.insert({"runtime_top_k", Tensor{MEMORY_CPU, TYPE_UINT32, std::vector{1}, &top_k}}); } } input_tensors.insert({"temperature", Tensor{MEMORY_CPU, TYPE_FP32, std::vector{1}, &temperature}}); @@ -794,11 +745,6 @@ void multi_gpu_gpt_example(const INIReader reader) TYPE_INT32, 
std::vector{request_batch_size, beam_width, (size_t)total_output_len}, d_output_ids}}, - {"parent_ids", - Tensor{MEMORY_GPU, - TYPE_INT32, - std::vector{(size_t)total_output_len, request_batch_size, beam_width}, - d_parent_ids}}, {"sequence_length", Tensor{MEMORY_GPU, TYPE_INT32, std::vector{request_batch_size, beam_width}, d_sequence_lengths}}, {"output_log_probs", @@ -810,10 +756,10 @@ void multi_gpu_gpt_example(const INIReader reader) print_mem_usage(); cudaDeviceSynchronize(); - MPI_Barrier(MPI_COMM_WORLD); + mpi::barrier(); - int total_output_ids = total_output_len * request_batch_size; - int* h_output_ids = new int[total_output_ids]; + int total_output_ids = total_output_len * request_batch_size; + int* h_output_ids = new int[total_output_ids]; int* h_sequence_lengths = new int[request_batch_size * beam_width]; if (rank == 0) { @@ -824,7 +770,7 @@ void multi_gpu_gpt_example(const INIReader reader) std::fill(h_output_ids, h_output_ids + total_output_ids, 0); std::fill(h_sequence_lengths, h_sequence_lengths + request_batch_size, 0); - std::string stream_file = pipeline_para_rank == (pipeline_para_size - 1) ? "out.stream" : ""; + std::string stream_file = pipeline_para_rank == (pipeline_para_size - 1) ? "out.stream" : ""; GptFileStreamer gpt_streamer(request_batch_size * beam_width, total_output_len, stream_file); cudaProfilerStart(); @@ -838,7 +784,7 @@ void multi_gpu_gpt_example(const INIReader reader) gpt.forward(&output_tensors, &input_tensors, &gpt_weights); } cudaDeviceSynchronize(); - MPI_Barrier(MPI_COMM_WORLD); + mpi::barrier(); POP_RANGE; nvtx::resetScope(); @@ -852,7 +798,7 @@ void multi_gpu_gpt_example(const INIReader reader) struct timeval start, end; cudaDeviceSynchronize(); - MPI_Barrier(MPI_COMM_WORLD); + mpi::barrier(); gettimeofday(&start, NULL); nvtx::setScope("total_time"); @@ -868,7 +814,7 @@ void multi_gpu_gpt_example(const INIReader reader) } cudaDeviceSynchronize(); - MPI_Barrier(MPI_COMM_WORLD); + mpi::barrier(); POP_RANGE; nvtx::resetScope(); gettimeofday(&end, NULL); @@ -887,14 +833,14 @@ void multi_gpu_gpt_example(const INIReader reader) vocab_size, ((end.tv_sec - start.tv_sec) * 1000 + (end.tv_usec - start.tv_usec) * 0.001) / ite); - std::string fName = USE_ASYNC ? "out.async" : "out.sync"; - auto outFile = std::ofstream(fName, std::ios::out); + std::string fName = USE_ASYNC ? 
"out.async" : "out.sync"; + auto outFile = std::ofstream(fName, std::ios::out); if (!outFile.is_open()) { printf("[WARNING] Cannot write results into output file %s \n", fName.c_str()); } else { size_t outCount = total_output_len * request_batch_size; - int* hBuf = new int[outCount]; + int* hBuf = new int[outCount]; cudaDeviceSynchronize(); cudaMemcpyAsync(hBuf, d_output_ids, outCount * sizeof(int), cudaMemcpyDeviceToHost, stream); cudaDeviceSynchronize(); @@ -927,8 +873,8 @@ void multi_gpu_gpt_example(const INIReader reader) } } - ncclCommDestroy(tensor_para_nccl_comm); - ncclCommDestroy(pipeline_para_nccl_comm); + ftNcclParamDestroy(tensor_para); + ftNcclParamDestroy(pipeline_para); delete[] h_output_ids; delete[] h_sequence_lengths; diff --git a/examples/cpp/multi_gpu_gpt/multi_gpu_gpt_example.cc b/examples/cpp/multi_gpu_gpt/multi_gpu_gpt_example.cc index 582dea3c9..acd75e70c 100644 --- a/examples/cpp/multi_gpu_gpt/multi_gpu_gpt_example.cc +++ b/examples/cpp/multi_gpu_gpt/multi_gpu_gpt_example.cc @@ -18,6 +18,7 @@ #include "examples/cpp/multi_gpu_gpt/gpt_example_utils.h" #include "src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.h" #include "src/fastertransformer/utils/mpi_utils.h" +#include "src/fastertransformer/utils/nccl_utils.h" #include "src/fastertransformer/utils/nvtx_utils.h" #include @@ -34,21 +35,29 @@ bool NVTX_ON = true; using namespace fastertransformer; template -void multi_gpu_gpt_example(const INIReader reader); +void multi_gpu_gpt_example(const INIReader reader, std::string in_csv); int main(int argc, char* argv[]) { - MPICHECK(MPI_Init(&argc, &argv)); + mpi::initialize(&argc, &argv); srand(0); std::string ini_name; - if (argc == 2) { + if (argc >= 2) { ini_name = std::string(argv[1]); } else { ini_name = "../examples/cpp/multi_gpu_gpt/gpt_config.ini"; } + std::string in_csv; + if (argc == 3) { + in_csv = std::string(argv[2]); + } + else { + in_csv = "../examples/cpp/multi_gpu_gpt/start_ids.csv"; + } + INIReader reader = INIReader(ini_name); if (reader.ParseError() < 0) { std::cout << "[ERROR] Can't load '" << ini_name << "'\n"; @@ -57,74 +66,79 @@ int main(int argc, char* argv[]) const std::string data_type = reader.Get("ft_instance_hyperparameter", "data_type"); if (data_type == "fp32") { - multi_gpu_gpt_example(reader); + multi_gpu_gpt_example(reader, in_csv); } else if (data_type == "fp16") { - multi_gpu_gpt_example(reader); + multi_gpu_gpt_example(reader, in_csv); } #ifdef ENABLE_BF16 else if (data_type == "bf16") { - multi_gpu_gpt_example<__nv_bfloat16>(reader); + multi_gpu_gpt_example<__nv_bfloat16>(reader, in_csv); } #endif else { printf("[ERROR] data_type should be fp32, fp16 or bf16 ! 
\n"); return -1; } - MPI_Finalize(); + mpi::finalize(); return 0; } template -void multi_gpu_gpt_example(const INIReader reader) +void multi_gpu_gpt_example(const INIReader reader, std::string in_csv) { - const std::string model_name = reader.Get("ft_instance_hyperparameter", "model_name"); - const size_t max_batch_size = (size_t)reader.GetInteger("ft_instance_hyperparameter", "max_batch_size"); - const size_t max_seq_len = (size_t)reader.GetInteger("ft_instance_hyperparameter", "max_seq_len"); - const size_t beam_width = (size_t)reader.GetInteger("ft_instance_hyperparameter", "beam_width"); - const int top_k = reader.GetInteger("ft_instance_hyperparameter", "top_k"); - const float top_p = reader.GetFloat("ft_instance_hyperparameter", "top_p"); - const float temperature = reader.GetFloat("ft_instance_hyperparameter", "temperature"); - const float repetition_penalty = reader.GetFloat("ft_instance_hyperparameter", "repetition_penalty"); - const std::string model_dir = std::string(reader.Get("ft_instance_hyperparameter", "model_dir")); - const bool sparse = static_cast(reader.GetInteger("ft_instance_hyperparameter", "sparse")); - const int int8_mode = reader.GetInteger("ft_instance_hyperparameter", "int8_mode"); - const float len_penalty = 1.0f; - const float beam_search_diversity_rate = 0.0f; - - const int tensor_para_size = reader.GetInteger("ft_instance_hyperparameter", "tensor_para_size"); + const std::string model_name = reader.Get("ft_instance_hyperparameter", "model_name"); + const size_t max_batch_size = (size_t)reader.GetInteger("ft_instance_hyperparameter", "max_batch_size"); + const size_t max_seq_len = (size_t)reader.GetInteger("ft_instance_hyperparameter", "max_seq_len"); + const size_t beam_width = (size_t)reader.GetInteger("ft_instance_hyperparameter", "beam_width"); + const uint top_k = (uint)reader.GetInteger("ft_instance_hyperparameter", "top_k"); + const float top_p = reader.GetFloat("ft_instance_hyperparameter", "top_p"); + const float temperature = reader.GetFloat("ft_instance_hyperparameter", "temperature"); + const float repetition_penalty = reader.GetFloat("ft_instance_hyperparameter", "repetition_penalty"); + const std::string model_dir = std::string(reader.Get("ft_instance_hyperparameter", "model_dir")); + const bool sparse = static_cast(reader.GetInteger("ft_instance_hyperparameter", "sparse")); + const int int8_mode = reader.GetInteger("ft_instance_hyperparameter", "int8_mode"); + const float len_penalty = reader.GetFloat("ft_instance_hyperparameter", "len_penalty"); + const float beam_search_diversity_rate = + reader.GetFloat("ft_instance_hyperparameter", "beam_search_diversity_rate"); + const float shared_contexts_ratio = reader.GetFloat("ft_instance_hyperparameter", "shared_contexts_ratio", true); + + const int tensor_para_size = reader.GetInteger("ft_instance_hyperparameter", "tensor_para_size"); const int pipeline_para_size = reader.GetInteger("ft_instance_hyperparameter", "pipeline_para_size"); - const size_t head_num = (size_t)reader.GetInteger(model_name, "head_num"); - const size_t size_per_head = (size_t)reader.GetInteger(model_name, "size_per_head"); - const size_t vocab_size = (size_t)reader.GetInteger(model_name, "vocab_size"); - const size_t decoder_layers = (size_t)reader.GetInteger(model_name, "decoder_layers"); - const size_t hidden_units = head_num * size_per_head; - const size_t inter_size = 4 * hidden_units; + const size_t head_num = (size_t)reader.GetInteger(model_name, "head_num"); + const size_t size_per_head = 
(size_t)reader.GetInteger(model_name, "size_per_head"); + const size_t vocab_size = (size_t)reader.GetInteger(model_name, "vocab_size"); + const size_t decoder_layers = (size_t)reader.GetInteger(model_name, "decoder_layers"); + const size_t hidden_units = head_num * size_per_head; + const size_t inter_size = 4 * hidden_units; + const std::string model_variant = std::string(reader.Get(model_name, "model_variant", "gpt")); const size_t request_batch_size = reader.GetInteger("request", "request_batch_size"); // The length of tokens we hope this model to generate - const int request_output_len = reader.GetInteger("request", "request_output_len"); + const int request_output_len = reader.GetInteger("request", "request_output_len"); const bool is_return_log_probs = reader.GetBoolean("request", "return_log_probs", false); // Whether to include input contexts in computing the cumulative log probabilities. const bool is_return_context_cum_log_probs = reader.GetBoolean("request", "context_log_probs", false); if (is_return_log_probs && !is_return_context_cum_log_probs) { FT_LOG_WARNING("context_log_probs will be ignored since return_log_probs is disabled."); } + const bool remove_padding = reader.GetBoolean("request", "remove_padding", false); + const uint32_t memory_len = reader.GetInteger("request", "memory_len", 0); const int start_id = 50256; - const int end_id = 50256; + const int end_id = 50256; FT_CHECK(head_num % tensor_para_size == 0); FT_CHECK(decoder_layers % pipeline_para_size == 0); // Prepare the parallelism parameters - int rank, world_size, device, device_count; - MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &rank)); - MPICHECK(MPI_Comm_size(MPI_COMM_WORLD, &world_size)); + int rank = mpi::getCommWorldRank(); + int world_size = mpi::getCommWorldSize(); if (rank == 0) { printf("Total ranks: %d.\n", world_size); } + int device, device_count; check_cuda_error(cudaGetDeviceCount(&device_count)); check_cuda_error(cudaSetDevice(rank % device_count)); check_cuda_error(cudaGetDevice(&device)); @@ -133,15 +147,13 @@ void multi_gpu_gpt_example(const INIReader reader) check_cuda_error(cudaGetDeviceProperties(&prop, device)); printf("Device %s\n", prop.name); - printf("P%d is runing with %d GPU.\n", rank, device); + printf("P%d is running with %d GPU.\n", rank, device); if (tensor_para_size * pipeline_para_size != world_size) { printf("[ERROR] tensor_para_size * pipeline_para_size should equal to world_size \n"); exit(-1); } - const int tensor_para_rank = rank % tensor_para_size; - const int pipeline_para_rank = rank / tensor_para_size; const int layers_per_group = decoder_layers / pipeline_para_size; if (layers_per_group * pipeline_para_size != (int)decoder_layers) { printf("[ERROR] layers_per_group (%d) * pipeline_para_size (%d) should equal to decoder_layers (%ld) \n", @@ -154,66 +166,23 @@ void multi_gpu_gpt_example(const INIReader reader) // assume gpu_num = k * n, // tensor parallelism group size is n // pipeline parallelism group size is k - - // convert WORLD communicator into 2D grid (k * n) communicator - // comms of the same row means they are in the same tensor parallel group - // comms of the same col means they are in the same pipeline parallel group - MPI_Comm grid_comm; - int dims[2] = {pipeline_para_size, tensor_para_size}; - int periods[2] = {0, 0}; - MPI_Cart_create(MPI_COMM_WORLD, 2, dims, periods, 0, &grid_comm); - - MPI_Comm comm_tensor_parallel, comm_pipeline_parallel; - - int remain_dims_tensor_parallel[2] = {false, true}; - int remain_dims_pipeline_parallel[2] = {true, false}; - 
// split 2D communicator into rows and cols, each row = one tensor parallel group, each col = one pipeline parallel - // group - MPI_Cart_sub(grid_comm, remain_dims_tensor_parallel, &comm_tensor_parallel); - MPI_Cart_sub(grid_comm, remain_dims_pipeline_parallel, &comm_pipeline_parallel); - - int rank_tensor_parallel, rank_pipeline_parallel; - MPI_Comm_rank(comm_tensor_parallel, &rank_tensor_parallel); - MPI_Comm_rank(comm_pipeline_parallel, &rank_pipeline_parallel); - - ncclUniqueId tensor_para_nccl_uid; - ncclUniqueId pipeline_para_nccl_uid; - // root of tensor parallel group and pipeline parallel group creates the nccl uid - if (rank_tensor_parallel == 0) { - NCCLCHECK(ncclGetUniqueId(&tensor_para_nccl_uid)); - } - - if (rank_pipeline_parallel == 0) { - NCCLCHECK(ncclGetUniqueId(&pipeline_para_nccl_uid)); - } - // broadcast nccl uid to the comms in the same tensor parallel group or pipeline parallel group - MPI_Bcast(&tensor_para_nccl_uid, sizeof(tensor_para_nccl_uid), MPI_BYTE, 0, comm_tensor_parallel); - MPI_Bcast(&pipeline_para_nccl_uid, sizeof(pipeline_para_nccl_uid), MPI_BYTE, 0, comm_pipeline_parallel); - - ncclComm_t tensor_para_nccl_comm, pipeline_para_nccl_comm; - NCCLCHECK(ncclCommInitRank(&tensor_para_nccl_comm, tensor_para_size, tensor_para_nccl_uid, tensor_para_rank)); - NCCLCHECK( - ncclCommInitRank(&pipeline_para_nccl_comm, pipeline_para_size, pipeline_para_nccl_uid, pipeline_para_rank)); + NcclParam tensor_para; + NcclParam pipeline_para; + ftNcclInitialize(tensor_para, pipeline_para, tensor_para_size, pipeline_para_size); // Read ids of request from file. - int max_input_len = -1; + int max_input_len = -1; std::vector v_start_lengths; std::vector v_start_ids; - read_start_ids(request_batch_size, - &v_start_lengths, - &v_start_ids, - max_input_len, - end_id, - 1, - "../examples/cpp/multi_gpu_gpt/start_ids.csv"); + read_start_ids(request_batch_size, &v_start_lengths, &v_start_ids, max_input_len, end_id, 1, in_csv); int* d_input_ids; int* d_input_lengths; if (max_input_len == 0) { // unconditional case, no input ids, so do nothing. - d_input_ids = nullptr; + d_input_ids = nullptr; d_input_lengths = nullptr; - max_input_len = 0; + max_input_len = 0; } else { // conditional case. 
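The read_start_ids() call above consumes a CSV file (start_ids.csv by default, or the in_csv argument) in which every line holds the comma-separated token ids of one request. A self-contained parsing sketch follows; the per-line split mirrors the loop shown earlier in this patch, while padding short rows with end_id up to the longest row is an assumption about the helper's behaviour, not something shown in the hunks above:

    // Sketch of a start_ids.csv-style reader: one request per line, token ids
    // separated by commas, batch padded to a rectangle (padding value assumed).
    #include <algorithm>
    #include <fstream>
    #include <sstream>
    #include <string>
    #include <vector>

    std::vector<std::vector<int>> readStartIds(const std::string& path, int end_id)
    {
        std::vector<std::vector<int>> rows;
        std::ifstream in(path);
        std::string   line;
        while (std::getline(in, line)) {
            std::vector<int>  row;
            std::stringstream ss(line);
            std::string       val;
            while (std::getline(ss, val, ',')) {
                row.push_back(std::stoi(val));
            }
            rows.push_back(row);
        }
        size_t max_len = 0;
        for (const auto& r : rows) {
            max_len = std::max(max_len, r.size());
        }
        for (auto& r : rows) {
            r.resize(max_len, end_id);  // pad every request to the same length
        }
        return rows;
    }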
@@ -229,8 +198,8 @@ void multi_gpu_gpt_example(const INIReader reader) exit(-1); } - cudaStream_t stream; - cublasHandle_t cublas_handle; + cudaStream_t stream; + cublasHandle_t cublas_handle; cublasLtHandle_t cublaslt_handle; cudaStreamCreate(&stream); cublasCreate(&cublas_handle); @@ -241,7 +210,7 @@ void multi_gpu_gpt_example(const INIReader reader) CHECK_CUSPARSE(cusparseLtInit(&cusparselt_handle)); cublasAlgoMap* cublas_algo_map = new cublasAlgoMap(GEMM_CONFIG, SPGEMM_CONFIG); #else - cublasAlgoMap* cublas_algo_map = new cublasAlgoMap(GEMM_CONFIG); + cublasAlgoMap* cublas_algo_map = new cublasAlgoMap(GEMM_CONFIG); #endif Allocator allocator(getDevice()); @@ -267,16 +236,63 @@ void multi_gpu_gpt_example(const INIReader reader) cublas_wrapper.setFP32GemmConfig(); } - fastertransformer::ParallelGptWeight gpt_weights(hidden_units, - inter_size, - vocab_size, - decoder_layers, - max_seq_len, - tensor_para_size, - tensor_para_rank, - pipeline_para_size, - pipeline_para_rank, - int8_mode); + // Prompt Learning Configurations + int prompt_learning_start_id = reader.GetInteger(model_name, "prompt_learning_start_id", end_id + 1); + PromptLearningType prompt_learning_type = + static_cast(reader.GetInteger(model_name, "prompt_learning_type", 0)); + + // NOTE:specify task names, take name id, prompt length in order to load those prompt learning tables. + // for example: + // std::map> p_prompt_tuning_table_pair_{{"sentiment", {0, 10}}, + // {"intent_and_slot", {1, 10}}, + // {"squad", {2, 16}}}; + + std::map> p_prompt_tuning_table_pair_; + + // NOTE: get prompt table pairs from configuration files + const int num_tasks = reader.GetInteger(model_name, "num_tasks", 0); + for (int task_name_id = 0; task_name_id < num_tasks; task_name_id++) { + std::string config_task_name = model_name + "_task_" + std::to_string(task_name_id); + std::string task_name = reader.Get(config_task_name, "task_name"); + const int prompt_length = reader.GetInteger(config_task_name, "prompt_length", 0); + p_prompt_tuning_table_pair_.insert({task_name, {task_name_id, prompt_length}}); + } + + // NOTE: task_name_ids for each sequence in one batch + // Each sequence can have different prompt learning task ids + std::vector p_prompt_tuning_task_name_ids(request_batch_size, 0); + + // NOTE: gpt variants parameters --> meta opt as an example here + gptVariantParams gpt_variant_params = {}; // default is gpt + if (model_variant == "opt-pre") { + gpt_variant_params.layernorm_eps = 1e-5f; + gpt_variant_params.layernorm_type = LayerNormType::pre_layernorm; + gpt_variant_params.activation_type = ActivationType::Relu; + gpt_variant_params.has_post_decoder_layernorm = true; + } + else if (model_variant == "opt-post") { + gpt_variant_params.layernorm_eps = 1e-5f; + gpt_variant_params.layernorm_type = LayerNormType::post_layernorm; + gpt_variant_params.activation_type = ActivationType::Relu; + gpt_variant_params.has_post_decoder_layernorm = false; + } + gpt_variant_params.has_adapters = reader.GetBoolean(model_name, "has_adapters", false); + gpt_variant_params.adapter_inter_size = reader.GetInteger(model_name, "adapter_inter_size", inter_size); + gpt_variant_params.layernorm_eps = reader.GetInteger(model_name, "layernorm_eps", 1e-6f); + + ParallelGptWeight gpt_weights(hidden_units, + inter_size, + vocab_size, + decoder_layers, + max_seq_len, + tensor_para.world_size_, + tensor_para.rank_, + pipeline_para.world_size_, + pipeline_para.rank_, + int8_mode, + prompt_learning_type, + p_prompt_tuning_table_pair_, + gpt_variant_params); 
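The prompt-learning and variant settings above are all pulled from the model section of the INI file given to the example. A hypothetical fragment is sketched below: the key and sub-section names are exactly the ones queried through reader.Get/GetInteger/GetBoolean in this patch, while the section title and every value are illustrative only:

    [gpt_model]                        ; placeholder for whatever model_name resolves to
    model_variant = opt-pre            ; "gpt" (default), "opt-pre" or "opt-post"
    has_adapters = false
    adapter_inter_size = 4096          ; defaults to inter_size when absent
    prompt_learning_start_id = 50257   ; defaults to end_id + 1
    prompt_learning_type = 2           ; integer cast to PromptLearningType
    num_tasks = 2

    [gpt_model_task_0]
    task_name = sentiment
    prompt_length = 10

    [gpt_model_task_1]
    task_name = squad
    prompt_length = 16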
gpt_weights.loadModel(model_dir); #ifdef SPARSITY_ENABLED if (sparse) { @@ -290,12 +306,9 @@ void multi_gpu_gpt_example(const INIReader reader) random_seed = (unsigned long long)(0); } if (world_size > 1) { - MPICHECK(MPI_Bcast(&random_seed, 1, MPI_UNSIGNED_LONG_LONG, 0, MPI_COMM_WORLD)); + mpi::bcast(&random_seed, 1, mpi::MPI_TYPE_UNSIGNED_LONG_LONG, 0, mpi::COMM_WORLD); } - NcclParam tensor_para(tensor_para_rank, tensor_para_size, tensor_para_nccl_comm); - NcclParam pipeline_para(pipeline_para_rank, pipeline_para_size, pipeline_para_nccl_comm); - ParallelGpt gpt = ParallelGpt(0, // max_batch_size, FT will adjust the buffer automatically. 0, // max_seq_len, FT will adjust the buffer automatically. 0, // max_input_len, FT will adjust the buffer automatically. @@ -307,12 +320,15 @@ void multi_gpu_gpt_example(const INIReader reader) vocab_size, start_id, end_id, + prompt_learning_start_id, // p/prompt tuning virtual token start id + prompt_learning_type, + gpt_variant_params, 0.0f, // beam_search_diversity_rate, 0, // top_k, 0.0, // top_p, 0, // random_seed, 1.0f, // temperature, - 1.0f, // len_penalty, + 0.0f, // len_penalty, 1.0f, // repetition_penalty, tensor_para, pipeline_para, @@ -322,20 +338,24 @@ void multi_gpu_gpt_example(const INIReader reader) false, &prop, sparse, - int8_mode); + int8_mode, + nullptr, + 0, + remove_padding, + shared_contexts_ratio); int* d_output_ids; - int* d_parent_ids; int* d_sequence_lengths; deviceMalloc(&d_output_ids, request_batch_size * beam_width * total_output_len, false); - deviceMalloc(&d_parent_ids, request_batch_size * beam_width * total_output_len, false); deviceMalloc(&d_sequence_lengths, request_batch_size * beam_width, false); + std::vector output_seq_len(request_batch_size, total_output_len); std::unordered_map input_tensors = std::unordered_map{ {"input_ids", Tensor{MEMORY_GPU, TYPE_INT32, std::vector{request_batch_size, (size_t)max_input_len}, d_input_ids}}, {"input_lengths", Tensor{MEMORY_GPU, TYPE_INT32, std::vector{request_batch_size}, d_input_lengths}}, - {"max_output_seq_len", Tensor{MEMORY_CPU, TYPE_INT32, std::vector{1}, &total_output_len}}}; + {"output_seq_len", + Tensor{MEMORY_CPU, TYPE_UINT32, std::vector{request_batch_size}, output_seq_len.data()}}}; if (top_k == 0 && top_p == 0.0f) { FT_CHECK(beam_width > 1); input_tensors.insert({"beam_search_diversity_rate", @@ -346,14 +366,24 @@ void multi_gpu_gpt_example(const INIReader reader) input_tensors.insert({"runtime_top_p", Tensor{MEMORY_CPU, TYPE_FP32, std::vector{1}, &top_p}}); } if (top_k != 0) { - input_tensors.insert({"runtime_top_k", Tensor{MEMORY_CPU, TYPE_INT32, std::vector{1}, &top_k}}); + input_tensors.insert({"runtime_top_k", Tensor{MEMORY_CPU, TYPE_UINT32, std::vector{1}, &top_k}}); } } + if (num_tasks > 0) { + input_tensors.insert({"prompt_learning_task_name_ids", + Tensor{MEMORY_CPU, + TYPE_INT32, + std::vector{request_batch_size}, + p_prompt_tuning_task_name_ids.data()}}); + } input_tensors.insert({"temperature", Tensor{MEMORY_CPU, TYPE_FP32, std::vector{1}, &temperature}}); input_tensors.insert({"len_penalty", Tensor{MEMORY_CPU, TYPE_FP32, std::vector{1}, &len_penalty}}); input_tensors.insert( {"repetition_penalty", Tensor{MEMORY_CPU, TYPE_FP32, std::vector{1}, &repetition_penalty}}); input_tensors.insert({"random_seed", Tensor{MEMORY_CPU, TYPE_UINT64, std::vector{1}, &random_seed}}); + if (memory_len > 0) { + input_tensors.insert({"memory_len", {MEMORY_CPU, TYPE_UINT32, {1}, &memory_len}}); + } std::unordered_map output_tensors = std::unordered_map{ {"output_ids", @@ 
-361,16 +391,11 @@ void multi_gpu_gpt_example(const INIReader reader) TYPE_INT32, std::vector{request_batch_size, beam_width, (size_t)total_output_len}, d_output_ids}}, - {"parent_ids", - Tensor{MEMORY_GPU, - TYPE_INT32, - std::vector{(size_t)total_output_len, request_batch_size, beam_width}, - d_parent_ids}}, {"sequence_length", Tensor{MEMORY_GPU, TYPE_INT32, std::vector{request_batch_size, beam_width}, d_sequence_lengths}}}; float* output_log_probs = nullptr; - float* d_cum_log_probs = nullptr; + float* d_cum_log_probs = nullptr; if (is_return_log_probs) { deviceMalloc(&output_log_probs, request_batch_size * beam_width * request_output_len); output_tensors.insert({"output_log_probs", @@ -390,7 +415,7 @@ void multi_gpu_gpt_example(const INIReader reader) int ite = 1; cudaDeviceSynchronize(); - MPI_Barrier(MPI_COMM_WORLD); + mpi::barrier(); cudaProfilerStart(); // warm up @@ -401,20 +426,20 @@ void multi_gpu_gpt_example(const INIReader reader) gpt.forward(&output_tensors, &input_tensors, &gpt_weights); } cudaDeviceSynchronize(); - MPI_Barrier(MPI_COMM_WORLD); + mpi::barrier(); POP_RANGE; nvtx::resetScope(); if (rank == 0) { - std::string fName = "out"; - auto outFile = std::ofstream(fName, std::ios::out); + std::string fName = "out"; + auto outFile = std::ofstream(fName, std::ios::out); if (!outFile.is_open()) { printf("[WARNING] Cannot write results into output file %s \n", fName.c_str()); } else { size_t outCount = total_output_len * request_batch_size * beam_width; - int* hBuf = new int[outCount]; + int* hBuf = new int[outCount]; cudaD2Hcpy(hBuf, d_output_ids, outCount); { @@ -443,8 +468,8 @@ void multi_gpu_gpt_example(const INIReader reader) outFile.close(); if (d_cum_log_probs != nullptr) { - std::string logprob_fname = "logprob.out"; - std::ofstream logprob_file = std::ofstream("logprob.out", std::ios::out); + std::string logprob_fname = "logprob.out"; + std::ofstream logprob_file = std::ofstream("logprob.out", std::ios::out); if (!logprob_file.is_open()) { printf("[WARNING] Cannot write results into output file %s \n", logprob_fname.c_str()); } @@ -467,10 +492,12 @@ void multi_gpu_gpt_example(const INIReader reader) // test time struct timeval start, end; - MPI_Barrier(MPI_COMM_WORLD); + mpi::barrier(); cudaDeviceSynchronize(); gettimeofday(&start, NULL); + ite = 10; + nvtx::setScope("total_time"); PUSH_RANGE("total time") for (int i = 0; i < ite; ++i) { @@ -478,7 +505,7 @@ void multi_gpu_gpt_example(const INIReader reader) } cudaDeviceSynchronize(); - MPI_Barrier(MPI_COMM_WORLD); + mpi::barrier(); POP_RANGE; nvtx::resetScope(); @@ -497,8 +524,8 @@ void multi_gpu_gpt_example(const INIReader reader) vocab_size, ((end.tv_sec - start.tv_sec) * 1000 + (end.tv_usec - start.tv_usec) * 0.001) / ite); - ncclCommDestroy(tensor_para_nccl_comm); - ncclCommDestroy(pipeline_para_nccl_comm); + ftNcclParamDestroy(tensor_para); + ftNcclParamDestroy(pipeline_para); #ifdef SPARSITY_ENABLED cusparseLtDestroy(&cusparselt_handle); diff --git a/examples/cpp/multi_gpu_gpt/multi_gpu_gpt_interactive_example.cc b/examples/cpp/multi_gpu_gpt/multi_gpu_gpt_interactive_example.cc new file mode 100644 index 000000000..06ef96fd7 --- /dev/null +++ b/examples/cpp/multi_gpu_gpt/multi_gpu_gpt_interactive_example.cc @@ -0,0 +1,622 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "3rdparty/INIReader.h" +#include "examples/cpp/multi_gpu_gpt/gpt_example_utils.h" +#include "src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.h" +#include "src/fastertransformer/utils/mpi_utils.h" +#include "src/fastertransformer/utils/nccl_utils.h" +#include "src/fastertransformer/utils/nvtx_utils.h" + +#include +#include +#include +#include +#include +#include + +#ifdef USE_NVTX +bool NVTX_ON = true; +#endif + +using namespace fastertransformer; + +template +void multi_gpu_gpt_interactive_example(const INIReader reader, std::string in_csv, std::string in_csv_final); +void writeOutputIds(const std::string& fName, size_t outCount, size_t total_output_len, int* d_output_ids); + +int main(int argc, char* argv[]) +{ + mpi::initialize(&argc, &argv); + srand(0); + + std::string ini_name; + if (argc >= 2) { + ini_name = std::string(argv[1]); + } + else { + ini_name = "../examples/cpp/multi_gpu_gpt/gpt_config.ini"; + } + + std::string in_csv; + if (argc >= 3) { + in_csv = std::string(argv[2]); + } + else { + in_csv = "../examples/cpp/multi_gpu_gpt/start_ids.csv"; + } + + std::string in_csv_final; + if (argc >= 4) { + in_csv_final = std::string(argv[3]); + } + else { + in_csv_final = "../examples/cpp/multi_gpu_gpt/interactive_inputs_ids.csv"; + } + + INIReader reader = INIReader(ini_name); + if (reader.ParseError() < 0) { + std::cout << "[ERROR] Can't load '" << ini_name << "'\n"; + return -1; + } + const std::string data_type = reader.Get("ft_instance_hyperparameter", "data_type"); + + if (data_type == "fp32") { + multi_gpu_gpt_interactive_example(reader, in_csv, in_csv_final); + } + else if (data_type == "fp16") { + multi_gpu_gpt_interactive_example(reader, in_csv, in_csv_final); + } +#ifdef ENABLE_BF16 + else if (data_type == "bf16") { + multi_gpu_gpt_interactive_example<__nv_bfloat16>(reader, in_csv, in_csv_final); + } +#endif + else { + printf("[ERROR] data_type should be fp32, fp16 or bf16 ! 
\n"); + return -1; + } + mpi::finalize(); + return 0; +} + +template +void multi_gpu_gpt_interactive_example(const INIReader reader, std::string in_csv, std::string in_csv_final) +{ + const std::string model_name = reader.Get("ft_instance_hyperparameter", "model_name"); + const size_t max_batch_size = (size_t)reader.GetInteger("ft_instance_hyperparameter", "max_batch_size"); + const size_t max_seq_len = (size_t)reader.GetInteger("ft_instance_hyperparameter", "max_seq_len"); + const size_t beam_width = (size_t)reader.GetInteger("ft_instance_hyperparameter", "beam_width"); + const uint top_k = (uint)reader.GetInteger("ft_instance_hyperparameter", "top_k"); + const float top_p = reader.GetFloat("ft_instance_hyperparameter", "top_p"); + const float temperature = reader.GetFloat("ft_instance_hyperparameter", "temperature"); + const float repetition_penalty = reader.GetFloat("ft_instance_hyperparameter", "repetition_penalty"); + const std::string model_dir = std::string(reader.Get("ft_instance_hyperparameter", "model_dir")); + const bool sparse = static_cast(reader.GetInteger("ft_instance_hyperparameter", "sparse")); + const int int8_mode = reader.GetInteger("ft_instance_hyperparameter", "int8_mode"); + const float len_penalty = reader.GetFloat("ft_instance_hyperparameter", "len_penalty"); + const float beam_search_diversity_rate = + reader.GetFloat("ft_instance_hyperparameter", "beam_search_diversity_rate"); + const float shared_contexts_ratio = reader.GetFloat("ft_instance_hyperparameter", "shared_contexts_ratio", true); + + const int tensor_para_size = reader.GetInteger("ft_instance_hyperparameter", "tensor_para_size"); + const int pipeline_para_size = reader.GetInteger("ft_instance_hyperparameter", "pipeline_para_size"); + + const size_t head_num = (size_t)reader.GetInteger(model_name, "head_num"); + const size_t size_per_head = (size_t)reader.GetInteger(model_name, "size_per_head"); + const size_t vocab_size = (size_t)reader.GetInteger(model_name, "vocab_size"); + const size_t decoder_layers = (size_t)reader.GetInteger(model_name, "decoder_layers"); + const size_t hidden_units = head_num * size_per_head; + const size_t inter_size = 4 * hidden_units; + const std::string model_variant = std::string(reader.Get(model_name, "model_variant", "gpt")); + + const size_t request_batch_size = reader.GetInteger("request", "request_batch_size"); + // The length of tokens we hope this model to generate + const int request_output_len = reader.GetInteger("request", "request_output_len"); + const bool is_return_log_probs = reader.GetBoolean("request", "return_log_probs", false); + // Whether to include input contexts in computing the cumulative log probabilities. 
+ const bool is_return_context_cum_log_probs = reader.GetBoolean("request", "context_log_probs", false); + if (is_return_log_probs && !is_return_context_cum_log_probs) { + FT_LOG_WARNING("context_log_probs will be ignored since return_log_probs is disabled."); + } + const bool remove_padding = reader.GetBoolean("request", "remove_padding", false); + + const int start_id = 50256; + const int end_id = 50256; + + FT_CHECK(head_num % tensor_para_size == 0); + FT_CHECK(decoder_layers % pipeline_para_size == 0); + + // Prepare the parallelism parameters + int rank = mpi::getCommWorldRank(); + int world_size = mpi::getCommWorldSize(); + if (rank == 0) { + printf("Total ranks: %d.\n", world_size); + } + int device, device_count; + check_cuda_error(cudaGetDeviceCount(&device_count)); + check_cuda_error(cudaSetDevice(rank % device_count)); + check_cuda_error(cudaGetDevice(&device)); + + struct cudaDeviceProp prop; + check_cuda_error(cudaGetDeviceProperties(&prop, device)); + printf("Device %s\n", prop.name); + + printf("P%d is running with %d GPU.\n", rank, device); + + if (tensor_para_size * pipeline_para_size != world_size) { + printf("[ERROR] tensor_para_size * pipeline_para_size should equal to world_size \n"); + exit(-1); + } + + const int layers_per_group = decoder_layers / pipeline_para_size; + if (layers_per_group * pipeline_para_size != (int)decoder_layers) { + printf("[ERROR] layers_per_group (%d) * pipeline_para_size (%d) should equal to decoder_layers (%ld) \n", + layers_per_group, + pipeline_para_size, + decoder_layers); + exit(-1); + } + + // assume gpu_num = k * n, + // tensor parallelism group size is n + // pipeline parallelism group size is k + + NcclParam tensor_para; + NcclParam pipeline_para; + ftNcclInitialize(tensor_para, pipeline_para, tensor_para_size, pipeline_para_size); + + // Read ids of request from file. + int max_input_len = -1; + std::vector v_start_lengths; + std::vector v_start_ids; + read_start_ids(request_batch_size, &v_start_lengths, &v_start_ids, max_input_len, end_id, 1, in_csv); + + int* d_input_ids; + int* d_input_lengths; + if (max_input_len == 0) { + // unconditional case, no input ids, so do nothing. + d_input_ids = nullptr; + d_input_lengths = nullptr; + max_input_len = 0; + } + else { + // conditional case. + deviceMalloc(&d_input_ids, request_batch_size * max_input_len, false); + deviceMalloc(&d_input_lengths, request_batch_size, false); + cudaH2Dcpy(d_input_ids, v_start_ids.data(), request_batch_size * max_input_len); + cudaH2Dcpy(d_input_lengths, v_start_lengths.data(), request_batch_size); + } + + const uint32_t session_len = (uint32_t)max_seq_len; + const int first_output_len = max_input_len + request_output_len; + + int max_input_len_final = -1; + std::vector v_lengths_final; + std::vector v_ids_final; + read_start_ids(request_batch_size, &v_lengths_final, &v_ids_final, max_input_len_final, end_id, 1, in_csv_final); + + int* d_input_ids_final; + int* d_input_lengths_final; + if (max_input_len_final == 0) { + // unconditional case, no input ids, so do nothing. + d_input_ids_final = nullptr; + d_input_lengths_final = nullptr; + max_input_len_final = 0; + } + else { + // conditional case. 
+ deviceMalloc(&d_input_ids_final, request_batch_size * max_input_len_final, false); + deviceMalloc(&d_input_lengths_final, request_batch_size, false); + cudaH2Dcpy(d_input_ids_final, v_ids_final.data(), request_batch_size * max_input_len_final); + cudaH2Dcpy(d_input_lengths_final, v_lengths_final.data(), request_batch_size); + } + const size_t total_output_len = first_output_len + max_input_len_final + request_output_len; + if (total_output_len > (int)max_seq_len) { + FT_LOG_ERROR("first_output_len (%d) should be <= max_seq_len (%lu). \n", first_output_len, max_seq_len); + exit(-1); + } + if (total_output_len > (int)session_len) { + FT_LOG_ERROR("first_output_len (%d) should be <= session_len (%u). \n", first_output_len, session_len); + exit(-1); + } + + cudaStream_t stream; + cublasHandle_t cublas_handle; + cublasLtHandle_t cublaslt_handle; + cudaStreamCreate(&stream); + cublasCreate(&cublas_handle); + cublasLtCreate(&cublaslt_handle); + cublasSetStream(cublas_handle, stream); +#ifdef SPARSITY_ENABLED + cusparseLtHandle_t cusparselt_handle; + CHECK_CUSPARSE(cusparseLtInit(&cusparselt_handle)); + cublasAlgoMap* cublas_algo_map = new cublasAlgoMap(GEMM_CONFIG, SPGEMM_CONFIG); +#else + cublasAlgoMap* cublas_algo_map = new cublasAlgoMap(GEMM_CONFIG); +#endif + + Allocator allocator(getDevice()); + + std::mutex* cublas_wrapper_mutex = new std::mutex(); +#ifdef SPARSITY_ENABLED + cublasMMWrapper cublas_wrapper = cublasMMWrapper( + cublas_handle, cublaslt_handle, cusparselt_handle, stream, cublas_algo_map, cublas_wrapper_mutex, &allocator); +#else + cublasMMWrapper cublas_wrapper = + cublasMMWrapper(cublas_handle, cublaslt_handle, stream, cublas_algo_map, cublas_wrapper_mutex, &allocator); +#endif + + if (std::is_same::value) { + cublas_wrapper.setGemmConfig(CUDA_R_16F, CUDA_R_16F, CUDA_R_16F, CUDA_R_32F); + } +#ifdef ENABLE_BF16 + else if (std::is_same::value) { + cublas_wrapper.setBF16GemmConfig(); + } +#endif + else if (std::is_same::value) { + cublas_wrapper.setFP32GemmConfig(); + } + + // Prompt Learning Configurations + int prompt_learning_start_id = reader.GetInteger(model_name, "prompt_learning_start_id", end_id + 1); + PromptLearningType prompt_learning_type = + static_cast(reader.GetInteger(model_name, "prompt_learning_type", 0)); + + // NOTE:specify task names, take name id, prompt length in order to load those prompt learning tables. 
+ // for example: + // std::map> p_prompt_tuning_table_pair_{{"sentiment", {0, 10}}, + // {"intent_and_slot", {1, 10}}, + // {"squad", {2, 16}}}; + + std::map> p_prompt_tuning_table_pair_; + + // NOTE: get prompt table pairs from configuration files + const int num_tasks = reader.GetInteger(model_name, "num_tasks", 0); + for (int task_name_id = 0; task_name_id < num_tasks; task_name_id++) { + std::string config_task_name = model_name + "_task_" + std::to_string(task_name_id); + std::string task_name = reader.Get(config_task_name, "task_name"); + const int prompt_length = reader.GetInteger(config_task_name, "prompt_length", 0); + p_prompt_tuning_table_pair_.insert({task_name, {task_name_id, prompt_length}}); + } + + // NOTE: task_name_ids for each sequence in one batch + // Each sequence can have different prompt learning task ids + std::vector p_prompt_tuning_task_name_ids(request_batch_size, 0); + + // NOTE: gpt variants parameters --> meta opt as an example here + gptVariantParams gpt_variant_params = {}; // default is gpt + if (model_variant == "opt-pre") { + gpt_variant_params.layernorm_eps = 1e-5f; + gpt_variant_params.layernorm_type = LayerNormType::pre_layernorm; + gpt_variant_params.activation_type = ActivationType::Relu; + gpt_variant_params.has_post_decoder_layernorm = true; + } + else if (model_variant == "opt-post") { + gpt_variant_params.layernorm_eps = 1e-5f; + gpt_variant_params.layernorm_type = LayerNormType::post_layernorm; + gpt_variant_params.activation_type = ActivationType::Relu; + gpt_variant_params.has_post_decoder_layernorm = false; + } + + ParallelGptWeight gpt_weights(hidden_units, + inter_size, + vocab_size, + decoder_layers, + max_seq_len, + tensor_para.world_size_, + tensor_para.rank_, + pipeline_para.world_size_, + pipeline_para.rank_, + int8_mode, + prompt_learning_type, + p_prompt_tuning_table_pair_, + gpt_variant_params); + gpt_weights.loadModel(model_dir); +#ifdef SPARSITY_ENABLED + if (sparse) { + printf("[INFO] Compress weights for sparse inference\n"); + gpt_weights.compress_weights(cublas_wrapper); + } +#endif + + unsigned long long random_seed; + if (rank == 0) { + random_seed = (unsigned long long)(0); + } + if (world_size > 1) { + mpi::bcast(&random_seed, 1, mpi::MPI_TYPE_UNSIGNED_LONG_LONG, 0, mpi::COMM_WORLD); + } + + ParallelGpt gpt = ParallelGpt(0, // max_batch_size, FT will adjust the buffer automatically. + 0, // max_seq_len, FT will adjust the buffer automatically. + 0, // max_input_len, FT will adjust the buffer automatically. 
+ beam_width, + head_num, + size_per_head, + inter_size, + decoder_layers, + vocab_size, + start_id, + end_id, + prompt_learning_start_id, // p/prompt tuning virtual token start id + prompt_learning_type, + gpt_variant_params, + 0.0f, // beam_search_diversity_rate, + 0, // top_k, + 0.0, // top_p, + 0, // random_seed, + 1.0f, // temperature, + 0.0f, // len_penalty, + 1.0f, // repetition_penalty, + tensor_para, + pipeline_para, + stream, + &cublas_wrapper, + &allocator, + false, + &prop, + sparse, + int8_mode, + nullptr, + 0, + remove_padding); + /* shared_contexts_ratio); */ + + int* d_output_ids; + int* d_sequence_lengths; + deviceMalloc(&d_output_ids, request_batch_size * beam_width * first_output_len, false); + deviceMalloc(&d_sequence_lengths, request_batch_size * beam_width, false); + std::vector output_seq_len(request_batch_size, first_output_len); + + std::unordered_map input_tensors = std::unordered_map{ + {"input_ids", + Tensor{MEMORY_GPU, TYPE_INT32, std::vector{request_batch_size, (size_t)max_input_len}, d_input_ids}}, + {"input_lengths", Tensor{MEMORY_GPU, TYPE_INT32, std::vector{request_batch_size}, d_input_lengths}}, + {"output_seq_len", + Tensor{MEMORY_CPU, TYPE_UINT32, std::vector{request_batch_size}, output_seq_len.data()}}}; + if (top_k == 0 && top_p == 0.0f) { + FT_CHECK(beam_width > 1); + input_tensors.insert({"beam_search_diversity_rate", + Tensor{MEMORY_CPU, TYPE_FP32, std::vector{1}, &beam_search_diversity_rate}}); + } + else { + if (top_p != 0.0f) { + input_tensors.insert({"runtime_top_p", Tensor{MEMORY_CPU, TYPE_FP32, std::vector{1}, &top_p}}); + } + if (top_k != 0) { + input_tensors.insert({"runtime_top_k", Tensor{MEMORY_CPU, TYPE_UINT32, std::vector{1}, &top_k}}); + } + } + if (num_tasks > 0) { + input_tensors.insert({"prompt_learning_task_name_ids", + Tensor{MEMORY_CPU, + TYPE_INT32, + std::vector{request_batch_size}, + p_prompt_tuning_task_name_ids.data()}}); + } + input_tensors.insert({"session_len", Tensor{MEMORY_CPU, TYPE_UINT32, std::vector{1}, &session_len}}); + input_tensors.insert({"temperature", Tensor{MEMORY_CPU, TYPE_FP32, std::vector{1}, &temperature}}); + input_tensors.insert({"len_penalty", Tensor{MEMORY_CPU, TYPE_FP32, std::vector{1}, &len_penalty}}); + input_tensors.insert( + {"repetition_penalty", Tensor{MEMORY_CPU, TYPE_FP32, std::vector{1}, &repetition_penalty}}); + input_tensors.insert({"random_seed", Tensor{MEMORY_CPU, TYPE_UINT64, std::vector{1}, &random_seed}}); + + std::unordered_map output_tensors = std::unordered_map{ + {"output_ids", + Tensor{MEMORY_GPU, + TYPE_INT32, + std::vector{request_batch_size, beam_width, (size_t)first_output_len}, + d_output_ids}}, + {"sequence_length", + Tensor{MEMORY_GPU, TYPE_INT32, std::vector{request_batch_size, beam_width}, d_sequence_lengths}}}; + + float* output_log_probs = nullptr; + float* d_cum_log_probs = nullptr; + if (is_return_log_probs) { + deviceMalloc(&output_log_probs, request_batch_size * beam_width * request_output_len); + output_tensors.insert({"output_log_probs", + Tensor{MEMORY_GPU, + TYPE_FP32, + std::vector{request_batch_size, beam_width, (size_t)request_output_len}, + output_log_probs}}); + deviceMalloc(&d_cum_log_probs, request_batch_size * beam_width); + output_tensors.insert( + {"cum_log_probs", + Tensor{MEMORY_GPU, TYPE_FP32, std::vector{request_batch_size, beam_width}, d_cum_log_probs}}); + input_tensors.insert({"is_return_context_cum_log_probs", + Tensor{MEMORY_CPU, TYPE_BOOL, std::vector{1}, &is_return_context_cum_log_probs}}); + } + + int* d_output_ids_final; + int* 
d_sequence_lengths_final; + deviceMalloc(&d_output_ids_final, request_batch_size * beam_width * total_output_len, false); + deviceMalloc(&d_sequence_lengths_final, request_batch_size * beam_width, false); + std::vector output_seq_len_final(request_batch_size, total_output_len); + + std::unordered_map input_tensors_final = input_tensors; + for (auto it = input_tensors_final.begin(); it != input_tensors_final.end();) { + if (it->first == "input_ids" || it->first == "input_lengths" || it->first == "output_seq_len" + || it->first == "session_len") { + it = input_tensors_final.erase(it); + } + else { + it++; + } + } + input_tensors_final.insert( + {"input_ids", {MEMORY_GPU, TYPE_INT32, {request_batch_size, (size_t)max_input_len_final}, d_input_ids_final}}); + input_tensors_final.insert( + {"input_lengths", {MEMORY_GPU, TYPE_INT32, {request_batch_size}, d_input_lengths_final}}); + input_tensors_final.insert( + {"output_seq_len", {MEMORY_CPU, TYPE_UINT32, {request_batch_size}, output_seq_len_final.data()}}); + bool continue_gen = true; + input_tensors_final.insert({"continue_gen", {MEMORY_CPU, TYPE_BOOL, {1}, &continue_gen}}); + + std::unordered_map output_tensors_final{ + {"output_ids", + {MEMORY_GPU, TYPE_INT32, {request_batch_size, beam_width, total_output_len}, d_output_ids_final}}, + {"sequence_length", {MEMORY_GPU, TYPE_INT32, {request_batch_size, beam_width}, d_sequence_lengths_final}}}; + float* output_log_probs_final = nullptr; + float* d_cum_log_probs_final = nullptr; + if (is_return_log_probs) { + deviceMalloc(&output_log_probs_final, request_batch_size * beam_width * request_output_len); + output_tensors_final.insert( + {"output_log_probs", + Tensor{MEMORY_GPU, + TYPE_FP32, + std::vector{request_batch_size, beam_width, (size_t)request_output_len}, + output_log_probs_final}}); + deviceMalloc(&d_cum_log_probs_final, request_batch_size * beam_width); + output_tensors_final.insert( + {"cum_log_probs", + Tensor{ + MEMORY_GPU, TYPE_FP32, std::vector{request_batch_size, beam_width}, d_cum_log_probs_final}}); + } + + print_mem_usage(); + + int ite = 1; + cudaDeviceSynchronize(); + mpi::barrier(); + + cudaProfilerStart(); + // warm up + ite = 1; + nvtx::setScope("warmup_time"); + PUSH_RANGE("warmup time") + for (int i = 0; i < ite; ++i) { + gpt.forward(&output_tensors, &input_tensors, &gpt_weights); + gpt.forward(&output_tensors_final, &input_tensors_final, &gpt_weights); + } + cudaDeviceSynchronize(); + mpi::barrier(); + + POP_RANGE; + nvtx::resetScope(); + + if (rank == 0) { + size_t outCount = first_output_len * request_batch_size * beam_width; + writeOutputIds("out.interm", outCount, first_output_len, d_output_ids); + + outCount = total_output_len * request_batch_size * beam_width; + writeOutputIds("out", outCount, total_output_len, d_output_ids_final); + + if (d_cum_log_probs != nullptr) { + std::string logprob_fname = "logprob.out"; + std::ofstream logprob_file = std::ofstream("logprob.out", std::ios::out); + if (!logprob_file.is_open()) { + printf("[WARNING] Cannot write results into output file %s \n", logprob_fname.c_str()); + } + else { + size_t cum_log_probs_size = request_batch_size * beam_width; + printf("[INFO] Writing %ld elements (log probs)\n", cum_log_probs_size); + float* h_buf = new float[cum_log_probs_size]; + cudaD2Hcpy(h_buf, d_cum_log_probs, cum_log_probs_size); + for (size_t i = 0; i < cum_log_probs_size; i++) { + logprob_file << h_buf[i] << std::endl; + if (i < 10) { + printf(" %10.6f\n", h_buf[i]); + } + } + delete[] h_buf; + } + logprob_file.close(); + } + } + + // 
test time + struct timeval start, end; + mpi::barrier(); + cudaDeviceSynchronize(); + gettimeofday(&start, NULL); + + ite = 10; + + nvtx::setScope("total_time"); + PUSH_RANGE("total time") + for (int i = 0; i < ite; ++i) { + gpt.forward(&output_tensors, &input_tensors, &gpt_weights); + gpt.forward(&output_tensors_final, &input_tensors_final, &gpt_weights); + } + + cudaDeviceSynchronize(); + mpi::barrier(); + + POP_RANGE; + nvtx::resetScope(); + gettimeofday(&end, NULL); + + cudaProfilerStop(); + + printf("[INFO] request_batch_size %ld beam_width %ld head_num %ld size_per_head %ld total_output_len %ld" + " decoder_layers %ld vocab_size %ld FT-CPP-decoding-beamsearch-time %.2f ms\n", + request_batch_size, + beam_width, + head_num, + size_per_head, + total_output_len, + decoder_layers, + vocab_size, + ((end.tv_sec - start.tv_sec) * 1000 + (end.tv_usec - start.tv_usec) * 0.001) / ite); + + ftNcclParamDestroy(tensor_para); + ftNcclParamDestroy(pipeline_para); + +#ifdef SPARSITY_ENABLED + cusparseLtDestroy(&cusparselt_handle); +#endif + delete cublas_algo_map; + delete cublas_wrapper_mutex; + return; +} + +void writeOutputIds(const std::string& fName, size_t outCount, size_t total_output_len, int* d_output_ids) +{ + auto outFile = std::ofstream(fName, std::ios::out); + if (!outFile.is_open()) { + printf("[WARNING] Cannot write results into output file %s \n", fName.c_str()); + } + else { + int* hBuf = new int[outCount]; + cudaD2Hcpy(hBuf, d_output_ids, outCount); + + { + std::cout << "Writing " << outCount << " elements\n"; + int zeroCount = 0; + for (size_t i = 0; i < outCount; i++) { + if (hBuf[i] == int(0)) { + zeroCount++; + } + outFile << hBuf[i] << " "; + if ((i + 1) % (total_output_len) == 0) { + outFile << std::endl; + } + + if (i < 10) { + printf("%5d ", hBuf[i]); + } + if ((i + 1) % (total_output_len) == 0 && i < 10) { + std::cout << std::endl; + } + } + std::cout << std::endl << "zeroCount = " << zeroCount << std::endl; + } + delete[] hBuf; + } + outFile.close(); +} diff --git a/examples/cpp/multi_gpu_gpt/multi_gpu_gpt_triton_example.cc b/examples/cpp/multi_gpu_gpt/multi_gpu_gpt_triton_example.cc index 9e22ae87e..169ccace5 100644 --- a/examples/cpp/multi_gpu_gpt/multi_gpu_gpt_triton_example.cc +++ b/examples/cpp/multi_gpu_gpt/multi_gpu_gpt_triton_example.cc @@ -23,32 +23,33 @@ #include "src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModelInstance.h" #include "src/fastertransformer/utils/custom_ar_comm.h" #include "src/fastertransformer/utils/mpi_utils.h" +#include "src/fastertransformer/utils/nccl_utils.h" namespace ft = fastertransformer; std::vector>> -broadcastRequest(const std::vector& v_start_ids, - const std::vector& v_start_lengths, - const int node_id, - const int gpu_count, - const int beam_width, - const int request_output_len, - const float beam_search_diversity_rate, - const int runtime_top_k, - const float runtime_top_p, - const float temperature, - const float len_penalty, - const float repetition_penalty, +broadcastRequest(const std::vector& v_start_ids, + const std::vector& v_start_lengths, + const int node_id, + const int gpu_count, + const int beam_width, + const int request_output_len, + const float beam_search_diversity_rate, + const uint runtime_top_k, + const float runtime_top_p, + const float temperature, + const float len_penalty, + const float repetition_penalty, const unsigned long long int random_seed, - const bool is_return_log_probs, - const bool is_return_context_cum_log_probs, - std::vector* pointer_record) + const bool 
is_return_log_probs, + const bool is_return_context_cum_log_probs, + std::vector* pointer_record) { // broadcast the request to all nodes, and copy "gpu_count" copies on different gpu int size_1 = v_start_ids.size(); int size_2 = v_start_lengths.size(); - MPICHECK(MPI_Bcast(&size_1, 1, MPI_INT, 0, MPI_COMM_WORLD)); - MPICHECK(MPI_Bcast(&size_2, 1, MPI_INT, 0, MPI_COMM_WORLD)); + ft::mpi::bcast(&size_1, 1, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD); + ft::mpi::bcast(&size_2, 1, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD); std::vector v_input_ids(size_1); std::vector v_input_lengths(size_2); @@ -57,13 +58,13 @@ broadcastRequest(const std::vector& v_start_ids, memcpy(v_input_ids.data(), v_start_ids.data(), size_1 * sizeof(int)); memcpy(v_input_lengths.data(), v_start_lengths.data(), size_2 * sizeof(int)); } - MPI_Barrier(MPI_COMM_WORLD); + ft::mpi::barrier(); int request_batch_size = size_2; - int max_input_len = size_1 / size_2; + int max_input_len = size_1 / size_2; - MPICHECK(MPI_Bcast(v_input_ids.data(), size_1, MPI_INT, 0, MPI_COMM_WORLD)); - MPICHECK(MPI_Bcast(v_input_lengths.data(), size_2, MPI_INT, 0, MPI_COMM_WORLD)); + ft::mpi::bcast(v_input_ids.data(), size_1, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD); + ft::mpi::bcast(v_input_lengths.data(), size_2, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD); std::vector>> request_list; for (int device_id = 0; device_id < gpu_count; device_id++) { @@ -74,9 +75,9 @@ broadcastRequest(const std::vector& v_start_ids, if (max_input_len == 0) { // unconditional case, no input ids, so do nothing. - d_input_ids = nullptr; + d_input_ids = nullptr; d_input_lengths = nullptr; - max_input_len = 0; + max_input_len = 0; } else { // conditional case. @@ -86,7 +87,10 @@ broadcastRequest(const std::vector& v_start_ids, ft::cudaH2Dcpy(d_input_lengths, v_input_lengths.data(), size_2); } - int* request_output_len_ptr = new int((int)(request_output_len)); + uint32_t* request_output_len_ptr = (uint32_t*)malloc(request_batch_size * sizeof(uint32_t)); + for (int i = 0; i < request_batch_size; i++) { + request_output_len_ptr[i] = request_output_len; + } request_list.push_back(std::shared_ptr>( new std::unordered_map{ @@ -103,7 +107,7 @@ broadcastRequest(const std::vector& v_start_ids, {"request_output_len", triton::Tensor{triton::MEMORY_CPU, triton::TYPE_INT32, - std::vector{(size_t)1}, + std::vector{(size_t)request_batch_size}, request_output_len_ptr}}})); int* beam_width_ptr = new int(beam_width); pointer_record->push_back(beam_width_ptr); @@ -127,12 +131,12 @@ broadcastRequest(const std::vector& v_start_ids, triton::Tensor{triton::MEMORY_CPU, triton::TYPE_FP32, std::vector{1}, runtime_top_p_ptr}}); } if (runtime_top_k != 0) { - int* runtime_top_k_ptr = new int(runtime_top_k); + uint* runtime_top_k_ptr = new uint(runtime_top_k); pointer_record->push_back(runtime_top_k_ptr); request_list[device_id]->insert( {"runtime_top_k", triton::Tensor{ - triton::MEMORY_CPU, triton::TYPE_INT32, std::vector{1}, runtime_top_k_ptr}}); + triton::MEMORY_CPU, triton::TYPE_UINT32, std::vector{1}, runtime_top_k_ptr}}); } } float* temperature_ptr = new float(temperature); @@ -154,18 +158,18 @@ broadcastRequest(const std::vector& v_start_ids, pointer_record->push_back(random_seed_ptr); request_list[device_id]->insert( {"random_seed", - triton::Tensor{triton::MEMORY_CPU, triton::TYPE_FP32, std::vector{1}, random_seed_ptr}}); + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_UINT64, std::vector{1}, random_seed_ptr}}); bool* is_return_log_probs_ptr = new bool(is_return_log_probs); 
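Every scalar sampling parameter in this request builder follows the same ownership pattern: allocate the value on the heap, wrap it as a one-element CPU tensor in the per-GPU request map, and push the raw pointer into `pointer_record` so it can be released once the request has been consumed. The rename to `is_return_context_cum_log_probs_ptr` just below also fixes a shadowing bug: the old local reused the parameter's name, so the heap bool was initialized from the still-indeterminate pointer rather than from the flag. A minimal self-contained sketch of the pattern, using stand-in types rather than the real `triton::Tensor` (the `ScalarTensor` struct here is illustrative only):

```cpp
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

// Illustrative stand-in for the one-element CPU tensors built above;
// the real example uses triton::Tensor with MEMORY_CPU / TYPE_* tags.
struct ScalarTensor {
    std::string type;  // e.g. "TYPE_UINT32", "TYPE_BOOL"
    void*       data;  // points at a single heap-allocated value
};

int main()
{
    std::unordered_map<std::string, ScalarTensor> request;
    std::vector<void*> pointer_record;  // keeps the scalars alive until the request is consumed

    uint32_t* runtime_top_k_ptr = new uint32_t(1);
    pointer_record.push_back(runtime_top_k_ptr);
    request.insert({"runtime_top_k", ScalarTensor{"TYPE_UINT32", runtime_top_k_ptr}});

    bool* is_return_log_probs_ptr = new bool(true);
    pointer_record.push_back(is_return_log_probs_ptr);
    request.insert({"is_return_log_probs", ScalarTensor{"TYPE_BOOL", is_return_log_probs_ptr}});

    // ... hand `request` to the model instance, run forward(), then free the
    // recorded scalars (shown with the typed pointers here for brevity).
    delete runtime_top_k_ptr;
    delete is_return_log_probs_ptr;
    return 0;
}
```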
pointer_record->push_back(is_return_log_probs_ptr); request_list[device_id]->insert( {"is_return_log_probs", triton::Tensor{triton::MEMORY_CPU, triton::TYPE_BOOL, std::vector{1}, is_return_log_probs_ptr}}); - bool* is_return_context_cum_log_probs = new bool(is_return_context_cum_log_probs); - pointer_record->push_back(is_return_context_cum_log_probs); + bool* is_return_context_cum_log_probs_ptr = new bool(is_return_context_cum_log_probs); + pointer_record->push_back(is_return_context_cum_log_probs_ptr); request_list[device_id]->insert( {"is_return_context_cum_log_probs", triton::Tensor{ - triton::MEMORY_CPU, triton::TYPE_BOOL, std::vector{1}, is_return_context_cum_log_probs}}); + triton::MEMORY_CPU, triton::TYPE_BOOL, std::vector{1}, is_return_context_cum_log_probs_ptr}}); pointer_record->push_back(d_input_ids); pointer_record->push_back(d_input_lengths); pointer_record->push_back(request_output_len_ptr); @@ -184,23 +188,24 @@ prepareRequest(std::string ini_name, const int node_id, const int gpu_count, std } const size_t request_batch_size = reader.GetInteger("request", "request_batch_size"); - const size_t beam_width = reader.GetInteger("ft_instance_hyperparameter", "beam_width"); + const size_t beam_width = reader.GetInteger("ft_instance_hyperparameter", "beam_width"); const size_t request_output_len = reader.GetInteger("request", "request_output_len"); - const int top_k = reader.GetInteger("ft_instance_hyperparameter", "top_k"); - const float top_p = reader.GetFloat("ft_instance_hyperparameter", "top_p"); - const float temperature = reader.GetFloat("ft_instance_hyperparameter", "temperature"); - const float repetition_penalty = reader.GetFloat("ft_instance_hyperparameter", "repetition_penalty"); - const float len_penalty = 1.0f; - const float beam_search_diversity_rate = 0.0f; - const unsigned long long int random_seed = 0; - const bool is_return_log_probs = reader.GetBoolean("request", "return_log_probs", false); + const int top_k = reader.GetInteger("ft_instance_hyperparameter", "top_k"); + const float top_p = reader.GetFloat("ft_instance_hyperparameter", "top_p"); + const float temperature = reader.GetFloat("ft_instance_hyperparameter", "temperature"); + const float repetition_penalty = reader.GetFloat("ft_instance_hyperparameter", "repetition_penalty"); + const float len_penalty = reader.GetFloat("ft_instance_hyperparameter", "len_penalty"); + const float beam_search_diversity_rate = + reader.GetFloat("ft_instance_hyperparameter", "beam_search_diversity_rate"); + const unsigned long long int random_seed = 0; + const bool is_return_log_probs = reader.GetBoolean("request", "return_log_probs", false); // Whether to include input contexts in computing the cumulative log probabilities. 
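For readers less familiar with these two flags (a hedged summary, not taken from the repository's documentation): `return_log_probs` asks the backend to return per-token log probabilities together with their running sum per beam, and `context_log_probs` additionally folds the prompt tokens into that sum. Conceptually, the value reported in `cum_log_probs` is just an accumulated sum of token-level log probabilities, as in this toy sketch with made-up probabilities:

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    // Hypothetical per-token probabilities p(y_t | y_<t, x) along one beam.
    std::vector<float> token_probs = {0.42f, 0.13f, 0.88f};

    float cum_log_prob = 0.f;
    for (float p : token_probs) {
        cum_log_prob += std::log(p);  // accumulate log-probabilities
    }
    std::printf("cum_log_prob = %f\n", cum_log_prob);  // shape {batch, beam} in the real output
    return 0;
}
```

The flag itself is read from the `request` section of the config file on the next line.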
const bool is_return_context_cum_log_probs = reader.GetBoolean("request", "context_log_probs", false); - if (is_return_log_probs && !is_return_context_cum_log_probs) { + if (!is_return_log_probs && !is_return_context_cum_log_probs) { FT_LOG_WARNING("context_log_probs will be ignored since return_log_probs is disabled."); } - const int end_id = 50256; + const int end_id = 50256; std::vector v_start_ids; std::vector v_start_lengths; @@ -232,17 +237,18 @@ prepareRequest(std::string ini_name, const int node_id, const int gpu_count, std return request_list; } -int threadCreateModelInstances(std::shared_ptr model, - std::vector>* model_instances, - const int device_id, - const int rank, - std::pair, std::vector> nccl_comms, +int threadCreateModelInstances(std::shared_ptr model, + std::vector>* model_instances, + const int device_id, + const int rank, + std::pair, std::vector> nccl_comms, std::shared_ptr custom_all_reduce_comm = nullptr) { FT_LOG_INFO("rank = %d", rank); ft::check_cuda_error(cudaSetDevice(device_id)); cudaStream_t stream; ft::check_cuda_error(cudaStreamCreate(&stream)); + model->createSharedWeights(device_id, rank); auto model_instance = model->createModelInstance(device_id, rank, stream, nccl_comms, custom_all_reduce_comm); model_instances->at(device_id) = std::move(model_instance); FT_LOG_INFO("model instance %d is created", device_id); @@ -250,10 +256,10 @@ int threadCreateModelInstances(std::shared_ptr model, return 0; } -int threadForward(std::unique_ptr* model_instance, - std::shared_ptr> request, +int threadForward(std::unique_ptr* model_instance, + std::shared_ptr> request, std::shared_ptr>* output_tensors, - const int device_id) + const int device_id) { ft::check_cuda_error(cudaSetDevice(device_id)); *output_tensors = (*model_instance)->forward(request); @@ -267,38 +273,26 @@ int main(int argc, char* argv[]) by MPI or triton */ - MPICHECK(MPI_Init(&argc, &argv)); - int node_id; - int node_num; - MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &node_id)); - MPICHECK(MPI_Comm_size(MPI_COMM_WORLD, &node_num)); + ft::mpi::initialize(&argc, &argv); + int node_id = ft::mpi::getCommWorldRank(); + int node_num = ft::mpi::getCommWorldSize(); // Note: Only supports that all nodes have same gpu count - const int gpu_count = ft::getDeviceCount(); - const int world_size = node_num * gpu_count; - std::string ini_name = argc >= 2 ? std::string(argv[1]) : "../examples/cpp/multi_gpu_gpt/gpt_config.ini"; + const int gpu_count = ft::getDeviceCount(); + const int world_size = node_num * gpu_count; + std::string ini_name = argc >= 2 ? 
std::string(argv[1]) : "../examples/cpp/multi_gpu_gpt/gpt_config.ini"; // step 1: Create model - std::shared_ptr model = AbstractTransformerModel::createGptModel(ini_name); + std::shared_ptr model = AbstractTransformerModel::createGptModel(ini_name); + int tensor_para_size = model->getTensorParaSize(); + int pipeline_para_size = model->getPipelineParaSize(); + ft::FT_CHECK_WITH_INFO(world_size == (tensor_para_size * pipeline_para_size), + "World Size != Tensor Parallel Size * Pipeline Parallel Size !"); + std::cout << model->toString(); // step 2: Initialize the NCCL - std::vector nccl_ids; - if (node_id == 0) { - nccl_ids = model->createNcclIds(world_size); - } - int nccl_size = nccl_ids.size(); - MPI_Barrier(MPI_COMM_WORLD); - MPICHECK(MPI_Bcast(&nccl_size, 1, MPI_INT, 0, MPI_COMM_WORLD)); - if (node_id != 0) { - nccl_ids.resize(nccl_size); - } - MPI_Barrier(MPI_COMM_WORLD); - for (size_t i = 0; i < nccl_ids.size(); i++) { - MPICHECK(MPI_Bcast(&nccl_ids[i], sizeof(nccl_ids[i]), MPI_BYTE, 0, MPI_COMM_WORLD)); - } - MPI_Barrier(MPI_COMM_WORLD); - std::pair, std::vector> nccl_comms = model->createNcclComms(nccl_ids, node_id); + std::pair, std::vector> nccl_comms = model->createNcclParams(node_id); cudaDeviceSynchronize(); // Optional Step: create custom all reduce comm @@ -307,7 +301,7 @@ int main(int argc, char* argv[]) // step 3: Create model instances std::vector> model_instances((size_t)gpu_count); - std::vector threads; + std::vector threads; threads.clear(); @@ -350,20 +344,20 @@ int main(int argc, char* argv[]) FT_LOG_INFO("forward is completed."); const int* d_output_ids = (const int*)output_tensors_lists[0].get()->at("output_ids").data; - const int batch_size = output_tensors_lists[0].get()->at("output_ids").shape[0]; - const int beam_width = output_tensors_lists[0].get()->at("output_ids").shape[1]; - const int seq_len = output_tensors_lists[0].get()->at("output_ids").shape[2]; + const int batch_size = output_tensors_lists[0].get()->at("output_ids").shape[0]; + const int beam_width = output_tensors_lists[0].get()->at("output_ids").shape[1]; + const int seq_len = output_tensors_lists[0].get()->at("output_ids").shape[2]; // step 6: check results if (node_id == 0) { - std::string fName = "out"; - auto outFile = std::ofstream(fName, std::ios::out); + std::string fName = "out"; + auto outFile = std::ofstream(fName, std::ios::out); if (!outFile.is_open()) { FT_LOG_WARNING("Cannot write results into output file %s", fName.c_str()); } else { size_t outCount = batch_size * beam_width * seq_len; - int* hBuf = new int[outCount]; + int* hBuf = new int[outCount]; ft::cudaD2Hcpy(hBuf, d_output_ids, outCount); { @@ -393,7 +387,7 @@ int main(int argc, char* argv[]) // test time struct timeval start, end; - MPI_Barrier(MPI_COMM_WORLD); + ft::mpi::barrier(); cudaDeviceSynchronize(); gettimeofday(&start, NULL); @@ -413,7 +407,7 @@ int main(int argc, char* argv[]) } cudaDeviceSynchronize(); - MPI_Barrier(MPI_COMM_WORLD); + ft::mpi::barrier(); gettimeofday(&end, NULL); @@ -424,6 +418,6 @@ int main(int argc, char* argv[]) seq_len, ((end.tv_sec - start.tv_sec) * 1000 + (end.tv_usec - start.tv_usec) * 0.001) / ite); - MPICHECK(MPI_Finalize()); + ft::mpi::finalize(); return 0; } diff --git a/examples/cpp/swin/CMakeLists.txt b/examples/cpp/swin/CMakeLists.txt index bb95c98e6..4aaa82c22 100644 --- a/examples/cpp/swin/CMakeLists.txt +++ b/examples/cpp/swin/CMakeLists.txt @@ -18,4 +18,4 @@ set(swin_transformer_nv_files ) add_executable(swin_example ${swin_transformer_nv_files}) 
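The GPT examples above time the forward pass with `gettimeofday` around a fixed number of iterations and report the mean in milliseconds (the Swin/ViT examples that follow use a stream-based `CudaTimer` instead). A standalone version of that measurement, with `work()` as a placeholder for the real `forward()` calls; the CUDA examples additionally synchronize the device, and the multi-GPU ones add an MPI barrier, around this region:

```cpp
#include <cstdio>
#include <sys/time.h>

static void work() { /* placeholder for gpt.forward(...) */ }

int main()
{
    struct timeval start, end;
    const int ite = 10;  // the GPT example above runs 10 timed iterations

    gettimeofday(&start, NULL);
    for (int i = 0; i < ite; ++i) {
        work();
    }
    gettimeofday(&end, NULL);

    // Same formula as the examples: seconds -> ms plus microseconds -> ms,
    // averaged over the iteration count.
    double ms = ((end.tv_sec - start.tv_sec) * 1000.0
                 + (end.tv_usec - start.tv_usec) * 0.001) / ite;
    std::printf("mean latency: %.2f ms\n", ms);
    return 0;
}
```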
target_link_libraries(swin_example PUBLIC trt_fused_multi_head_attention Swin - cublasMMWrapper memory_utils -lcublas -lcublasLt -lcudart -lcudnn) + cublasMMWrapper memory_utils -lcublas -lcublasLt -lcudart -lcudnn tensor) diff --git a/examples/cpp/swin/functions.h b/examples/cpp/swin/functions.h index b42d52e48..9cd6927a5 100644 --- a/examples/cpp/swin/functions.h +++ b/examples/cpp/swin/functions.h @@ -41,15 +41,15 @@ static int getWeightNum(const int layer_num, const int* depths) } static void generateWeightSize(std::vector& weight_size, - const int layer_num, - const int embed_dim, - const float mlp_ratio, - const int window_size, - const int img_size, - const int patch_size, - const int in_chans, - const int* depths, - const int* num_heads) + const int layer_num, + const int embed_dim, + const float mlp_ratio, + const int window_size, + const int img_size, + const int patch_size, + const int in_chans, + const int* depths, + const int* num_heads) { size_t l_pow2 = 1; for (int l = 0; l < layer_num; l++) { diff --git a/examples/cpp/swin/swin_example.cc b/examples/cpp/swin/swin_example.cc index ae7db0bce..372acdd43 100644 --- a/examples/cpp/swin/swin_example.cc +++ b/examples/cpp/swin/swin_example.cc @@ -29,10 +29,10 @@ using namespace std; template void test(int model_type, int batch) { - cudnnHandle_t cudnn_handle; - cublasHandle_t cublas_handle; + cudnnHandle_t cudnn_handle; + cublasHandle_t cublas_handle; cublasLtHandle_t cublaslt_handle; - cudaStream_t stream = 0; + cudaStream_t stream = 0; checkCUDNN(cudnnCreate(&cudnn_handle)); checkCUDNN(cudnnSetStream(cudnn_handle, stream)); check_cuda_error(cublasCreate(&cublas_handle)); @@ -49,31 +49,36 @@ void test(int model_type, int batch) if (std::is_same::value) { cublas_wrapper->setFP16GemmConfig(); } +#ifdef ENABLE_BF16 + else if (std::is_same::value) { + cublas_wrapper->setBF16GemmConfig(); + } +#endif else if (std::is_same::value) { cublas_wrapper->setFP32GemmConfig(); } bool is_tiny = true; - int embed_dim = is_tiny ? 96 : 192; + int embed_dim = is_tiny ? 96 : 192; int window_size = is_tiny ? 7 : 12; - int img_size = is_tiny ? 224 : 384; - int shift_size = window_size / 2; + int img_size = is_tiny ? 
224 : 384; + int shift_size = window_size / 2; int depths[4], num_heads[4]; if (is_tiny) { - depths[0] = 2; - depths[1] = 2; - depths[2] = 6; - depths[3] = 2; + depths[0] = 2; + depths[1] = 2; + depths[2] = 6; + depths[3] = 2; num_heads[0] = 3; num_heads[1] = 6; num_heads[2] = 12; num_heads[3] = 24; } else { - depths[0] = 2; - depths[1] = 2; - depths[2] = 18; - depths[3] = 2; + depths[0] = 2; + depths[1] = 2; + depths[2] = 18; + depths[3] = 2; num_heads[0] = 6; num_heads[1] = 12; num_heads[2] = 24; @@ -81,103 +86,103 @@ void test(int model_type, int batch) } if (model_type == 0) { - embed_dim = 96; - window_size = 7; - img_size = 224; - shift_size = window_size / 2; - depths[0] = 2; - depths[1] = 2; - depths[2] = 6; - depths[3] = 2; + embed_dim = 96; + window_size = 7; + img_size = 224; + shift_size = window_size / 2; + depths[0] = 2; + depths[1] = 2; + depths[2] = 6; + depths[3] = 2; num_heads[0] = 3; num_heads[1] = 6; num_heads[2] = 12; num_heads[3] = 24; } else if (model_type == 1) { - embed_dim = 96; - window_size = 7; - img_size = 224; - shift_size = window_size / 2; - depths[0] = 2; - depths[1] = 2; - depths[2] = 18; - depths[3] = 2; + embed_dim = 96; + window_size = 7; + img_size = 224; + shift_size = window_size / 2; + depths[0] = 2; + depths[1] = 2; + depths[2] = 18; + depths[3] = 2; num_heads[0] = 3; num_heads[1] = 6; num_heads[2] = 12; num_heads[3] = 24; } else if (model_type == 2) { - embed_dim = 128; - window_size = 7; - img_size = 224; - shift_size = window_size / 2; - depths[0] = 2; - depths[1] = 2; - depths[2] = 18; - depths[3] = 2; + embed_dim = 128; + window_size = 7; + img_size = 224; + shift_size = window_size / 2; + depths[0] = 2; + depths[1] = 2; + depths[2] = 18; + depths[3] = 2; num_heads[0] = 4; num_heads[1] = 8; num_heads[2] = 16; num_heads[3] = 32; } else if (model_type == 3) { - embed_dim = 128; - window_size = 12; - img_size = 384; - shift_size = window_size / 2; - depths[0] = 2; - depths[1] = 2; - depths[2] = 18; - depths[3] = 2; + embed_dim = 128; + window_size = 12; + img_size = 384; + shift_size = window_size / 2; + depths[0] = 2; + depths[1] = 2; + depths[2] = 18; + depths[3] = 2; num_heads[0] = 4; num_heads[1] = 8; num_heads[2] = 16; num_heads[3] = 32; } else if (model_type == 4) { - embed_dim = 192; - window_size = 7; - img_size = 224; - shift_size = window_size / 2; - depths[0] = 2; - depths[1] = 2; - depths[2] = 18; - depths[3] = 2; + embed_dim = 192; + window_size = 7; + img_size = 224; + shift_size = window_size / 2; + depths[0] = 2; + depths[1] = 2; + depths[2] = 18; + depths[3] = 2; num_heads[0] = 6; num_heads[1] = 12; num_heads[2] = 24; num_heads[3] = 48; } else if (model_type == 5) { - embed_dim = 192; - window_size = 12; - img_size = 384; - shift_size = window_size / 2; - depths[0] = 2; - depths[1] = 2; - depths[2] = 18; - depths[3] = 2; + embed_dim = 192; + window_size = 12; + img_size = 384; + shift_size = window_size / 2; + depths[0] = 2; + depths[1] = 2; + depths[2] = 18; + depths[3] = 2; num_heads[0] = 6; num_heads[1] = 12; num_heads[2] = 24; num_heads[3] = 48; } - int in_chans = 3; - bool ape = false; - bool patch_norm = true; - float mlp_ratio = 4.0f; - bool qkv_bias = true; - float qk_scale = 1.0f; - int layer_num = 4; - int patch_size = 4; + int in_chans = 3; + bool ape = false; + bool patch_norm = true; + float mlp_ratio = 4.0f; + bool qkv_bias = true; + float qk_scale = 1.0f; + int layer_num = 4; + int patch_size = 4; int output_dim = int(pow(2, layer_num - 1)) * embed_dim; int weight_num = getWeightNum(layer_num, depths); // 
calculate the size of each weight std::vector weight_size; - std::vector weight; + std::vector weight; generateWeightSize( weight_size, layer_num, embed_dim, mlp_ratio, window_size, img_size, patch_size, in_chans, depths, num_heads); for (int i = 0; i < weight_size.size(); i++) { @@ -187,48 +192,48 @@ void test(int model_type, int batch) } SwinTransformerWeight params; - int weight_idx = 0; + int weight_idx = 0; for (int l = 0; l < layer_num; l++) { SwinTransformerBasicLayerWeight bl; for (int di = 0; di < depths[l]; di++) { SwinTransformerBlockWeight p; - p.attention_weights.query_weight.kernel = weight[weight_idx++]; - p.attention_weights.query_weight.bias = weight[weight_idx++]; + p.attention_weights.query_weight.kernel = weight[weight_idx++]; + p.attention_weights.query_weight.bias = weight[weight_idx++]; p.attention_weights.attention_output_weight.kernel = weight[weight_idx++]; - p.attention_weights.attention_output_weight.bias = weight[weight_idx++]; - p.ffn_weights.intermediate_weight.kernel = weight[weight_idx++]; - p.ffn_weights.intermediate_weight.bias = weight[weight_idx++]; - p.ffn_weights.output_weight.kernel = weight[weight_idx++]; - p.ffn_weights.output_weight.bias = weight[weight_idx++]; - p.attn_layernorm_weights.gamma = weight[weight_idx++]; - p.attn_layernorm_weights.beta = weight[weight_idx++]; - p.ffn_layernorm_weights.gamma = weight[weight_idx++]; - p.ffn_layernorm_weights.beta = weight[weight_idx++]; + p.attention_weights.attention_output_weight.bias = weight[weight_idx++]; + p.ffn_weights.intermediate_weight.kernel = weight[weight_idx++]; + p.ffn_weights.intermediate_weight.bias = weight[weight_idx++]; + p.ffn_weights.output_weight.kernel = weight[weight_idx++]; + p.ffn_weights.output_weight.bias = weight[weight_idx++]; + p.attn_layernorm_weights.gamma = weight[weight_idx++]; + p.attn_layernorm_weights.beta = weight[weight_idx++]; + p.ffn_layernorm_weights.gamma = weight[weight_idx++]; + p.ffn_layernorm_weights.beta = weight[weight_idx++]; // Please use invokeGenRelativePosBias to get attention_relative_pos_bias from // attention_relative_pos_bias_table; p.attention_relative_pos_bias = weight[weight_idx++]; bl.block_weight_list.push_back(p); } bl.merge_layernorm_weights.gamma = weight[weight_idx++]; - bl.merge_layernorm_weights.beta = weight[weight_idx++]; - bl.merge_linear_weights.kernel = weight[weight_idx++]; - bl.attn_mask = weight[weight_idx++]; + bl.merge_layernorm_weights.beta = weight[weight_idx++]; + bl.merge_linear_weights.kernel = weight[weight_idx++]; + bl.attn_mask = weight[weight_idx++]; params.basic_layer_weight_list.push_back(bl); } params.patchEmbed_linear_weights.kernel = weight[weight_idx++]; - params.patchEmbed_linear_weights.bias = weight[weight_idx++]; - params.patchEmbed_norm_weights.gamma = weight[weight_idx++]; - params.patchEmbed_norm_weights.beta = weight[weight_idx++]; - params.norm_weights.gamma = weight[weight_idx++]; - params.norm_weights.beta = weight[weight_idx++]; + params.patchEmbed_linear_weights.bias = weight[weight_idx++]; + params.patchEmbed_norm_weights.gamma = weight[weight_idx++]; + params.patchEmbed_norm_weights.beta = weight[weight_idx++]; + params.norm_weights.gamma = weight[weight_idx++]; + params.norm_weights.beta = weight[weight_idx++]; T *input_d, *output_d; deviceMalloc(&input_d, batch * img_size * img_size * in_chans, false); deviceMalloc(&output_d, batch * output_dim, false); fastertransformer::Allocator allocator(0); - int max_batch = batch; - SwinTransformer sw(max_batch, + int max_batch = batch; + 
SwinTransformer sw(max_batch, img_size, patch_size, in_chans, @@ -250,8 +255,8 @@ void test(int model_type, int batch) // sw.allocateBuffer(&allocator); - const int sm = getSMVersion(); - int sm_ptr[1] = {sm}; + const int sm = getSMVersion(); + int sm_ptr[1] = {sm}; std::vector input_tensors = std::vector{Tensor{MEMORY_GPU, getTensorType(), @@ -269,7 +274,7 @@ void test(int model_type, int batch) for (int i = 0; i < 10; i++) sw.forward(&output_tensors, &input_tensors, params); - int ite = 100; + int ite = 100; CudaTimer cuda_timer(stream); cuda_timer.start(); for (int i = 0; i < ite; i++) @@ -305,7 +310,7 @@ int main(int argc, char* argv[]) struct cudaDeviceProp prop; check_cuda_error(cudaGetDeviceProperties(&prop, 0)); if (argc != 4) { - printf("[ERROR] swin_example is_fp16(0/1) model_type(0-5) batch_size\n"); + printf("[ERROR] swin_example data_type(0/1/2, fp32/fp16/bf16) model_type(0-5) batch_size\n"); printf("model_type:\n"); printf("0: tiny\t7x7\n"); printf("1: small\t7x7\n"); @@ -318,14 +323,23 @@ int main(int argc, char* argv[]) } printf("Device %s\n", prop.name); - bool use_fp16 = (atoi(argv[1]) == 1) ? true : false; - int model_type = atoi(argv[2]); - int batch = atoi(argv[3]); + FtCudaDataType data_type = static_cast(atoi(argv[1])); // 0: fp32, 1: fp16, 2: bf16 + int model_type = atoi(argv[2]); + int batch = atoi(argv[3]); - if (use_fp16) { + if (data_type == FP16) { test(model_type, batch); } - else { +#ifdef ENABLE_BF16 + else if (data_type == BF16) { + test<__nv_bfloat16>(model_type, batch); + } +#endif + else if (data_type == FP32) { test(model_type, batch); } + else { + printf("[ERROR] data_type is not supported\n"); + return 0; + } } diff --git a/examples/cpp/swin_int8/swin_int8_example.cc b/examples/cpp/swin_int8/swin_int8_example.cc index a77c09a54..176ce13eb 100644 --- a/examples/cpp/swin_int8/swin_int8_example.cc +++ b/examples/cpp/swin_int8/swin_int8_example.cc @@ -30,10 +30,10 @@ using namespace std; template void test(int model_type, int batch) { - cudnnHandle_t cudnn_handle; - cublasHandle_t cublas_handle; + cudnnHandle_t cudnn_handle; + cublasHandle_t cublas_handle; cublasLtHandle_t cublaslt_handle; - cudaStream_t stream; + cudaStream_t stream; cudaStreamCreate(&stream); checkCUDNN(cudnnCreate(&cudnn_handle)); checkCUDNN(cudnnSetStream(cudnn_handle, stream)); @@ -65,27 +65,27 @@ void test(int model_type, int batch) bool is_tiny = true; - int embed_dim = is_tiny ? 96 : 192; + int embed_dim = is_tiny ? 96 : 192; int window_size = is_tiny ? 7 : 12; - int int8_mode = is_tiny ? 2 : 4; - int img_size = is_tiny ? 224 : 384; - int shift_size = window_size / 2; + int int8_mode = is_tiny ? 2 : 4; + int img_size = is_tiny ? 
224 : 384; + int shift_size = window_size / 2; int depths[4], num_heads[4]; if (is_tiny) { - depths[0] = 2; - depths[1] = 2; - depths[2] = 6; - depths[3] = 2; + depths[0] = 2; + depths[1] = 2; + depths[2] = 6; + depths[3] = 2; num_heads[0] = 3; num_heads[1] = 6; num_heads[2] = 12; num_heads[3] = 24; } else { - depths[0] = 2; - depths[1] = 2; - depths[2] = 18; - depths[3] = 2; + depths[0] = 2; + depths[1] = 2; + depths[2] = 18; + depths[3] = 2; num_heads[0] = 6; num_heads[1] = 12; num_heads[2] = 24; @@ -93,109 +93,109 @@ void test(int model_type, int batch) } if (model_type == 0) { - int8_mode = 1; - embed_dim = 96; - window_size = 7; - img_size = 224; - shift_size = window_size / 2; - depths[0] = 2; - depths[1] = 2; - depths[2] = 6; - depths[3] = 2; + int8_mode = 1; + embed_dim = 96; + window_size = 7; + img_size = 224; + shift_size = window_size / 2; + depths[0] = 2; + depths[1] = 2; + depths[2] = 6; + depths[3] = 2; num_heads[0] = 3; num_heads[1] = 6; num_heads[2] = 12; num_heads[3] = 24; } else if (model_type == 1) { - int8_mode = 1; - embed_dim = 96; - window_size = 7; - img_size = 224; - shift_size = window_size / 2; - depths[0] = 2; - depths[1] = 2; - depths[2] = 18; - depths[3] = 2; + int8_mode = 1; + embed_dim = 96; + window_size = 7; + img_size = 224; + shift_size = window_size / 2; + depths[0] = 2; + depths[1] = 2; + depths[2] = 18; + depths[3] = 2; num_heads[0] = 3; num_heads[1] = 6; num_heads[2] = 12; num_heads[3] = 24; } else if (model_type == 2) { - int8_mode = 1; - embed_dim = 128; - window_size = 7; - img_size = 224; - shift_size = window_size / 2; - depths[0] = 2; - depths[1] = 2; - depths[2] = 18; - depths[3] = 2; + int8_mode = 1; + embed_dim = 128; + window_size = 7; + img_size = 224; + shift_size = window_size / 2; + depths[0] = 2; + depths[1] = 2; + depths[2] = 18; + depths[3] = 2; num_heads[0] = 4; num_heads[1] = 8; num_heads[2] = 16; num_heads[3] = 32; } else if (model_type == 3) { - int8_mode = 1; - embed_dim = 128; - window_size = 12; - img_size = 384; - shift_size = window_size / 2; - depths[0] = 2; - depths[1] = 2; - depths[2] = 18; - depths[3] = 2; + int8_mode = 1; + embed_dim = 128; + window_size = 12; + img_size = 384; + shift_size = window_size / 2; + depths[0] = 2; + depths[1] = 2; + depths[2] = 18; + depths[3] = 2; num_heads[0] = 4; num_heads[1] = 8; num_heads[2] = 16; num_heads[3] = 32; } else if (model_type == 4) { - int8_mode = 1; - embed_dim = 192; - window_size = 7; - img_size = 224; - shift_size = window_size / 2; - depths[0] = 2; - depths[1] = 2; - depths[2] = 18; - depths[3] = 2; + int8_mode = 1; + embed_dim = 192; + window_size = 7; + img_size = 224; + shift_size = window_size / 2; + depths[0] = 2; + depths[1] = 2; + depths[2] = 18; + depths[3] = 2; num_heads[0] = 6; num_heads[1] = 12; num_heads[2] = 24; num_heads[3] = 48; } else if (model_type == 5) { - int8_mode = 1; - embed_dim = 192; - window_size = 12; - img_size = 384; - shift_size = window_size / 2; - depths[0] = 2; - depths[1] = 2; - depths[2] = 18; - depths[3] = 2; + int8_mode = 1; + embed_dim = 192; + window_size = 12; + img_size = 384; + shift_size = window_size / 2; + depths[0] = 2; + depths[1] = 2; + depths[2] = 18; + depths[3] = 2; num_heads[0] = 6; num_heads[1] = 12; num_heads[2] = 24; num_heads[3] = 48; } - int in_chans = 3; - bool ape = false; - bool patch_norm = true; - float mlp_ratio = 4.0f; - bool qkv_bias = true; - float qk_scale = 1.0f; - int layer_num = 4; - int patch_size = 4; + int in_chans = 3; + bool ape = false; + bool patch_norm = true; + float mlp_ratio = 4.0f; + bool 
qkv_bias = true; + float qk_scale = 1.0f; + int layer_num = 4; + int patch_size = 4; int output_dim = int(pow(2, layer_num - 1)) * embed_dim; int weight_num = getWeightNum(layer_num, depths); // calculate the size of each weight std::vector weight_size; - std::vector weight; + std::vector weight; generateWeightSize( weight_size, layer_num, embed_dim, mlp_ratio, window_size, img_size, patch_size, in_chans, depths, num_heads); for (int i = 0; i < weight_size.size(); i++) { @@ -205,25 +205,25 @@ void test(int model_type, int batch) } SwinTransformerINT8Weight params; - int weight_idx = 0; - int hidden_dim = embed_dim; + int weight_idx = 0; + int hidden_dim = embed_dim; for (int l = 0; l < layer_num; l++) { SwinTransformerINT8BasicLayerWeight bl; for (int di = 0; di < depths[l]; di++) { SwinTransformerINT8BlockWeight p; - p.attention_weights.query_weight.kernel = weight[weight_idx++]; - p.attention_weights.query_weight.bias = weight[weight_idx++]; + p.attention_weights.query_weight.kernel = weight[weight_idx++]; + p.attention_weights.query_weight.bias = weight[weight_idx++]; p.attention_weights.attention_output_weight.kernel = weight[weight_idx++]; - p.attention_weights.attention_output_weight.bias = weight[weight_idx++]; - p.ffn_weights.intermediate_weight.kernel = weight[weight_idx++]; - p.ffn_weights.intermediate_weight.bias = weight[weight_idx++]; - p.ffn_weights.output_weight.kernel = weight[weight_idx++]; - p.ffn_weights.output_weight.bias = weight[weight_idx++]; - p.attn_layernorm_weights.gamma = weight[weight_idx++]; - p.attn_layernorm_weights.beta = weight[weight_idx++]; - p.ffn_layernorm_weights.gamma = weight[weight_idx++]; - p.ffn_layernorm_weights.beta = weight[weight_idx++]; - p.scalelist.size_ = ACTIVATION_AMAX_NUM + 5 + INT8O_GEMM_NUM + TRT_AMAX_NUM; + p.attention_weights.attention_output_weight.bias = weight[weight_idx++]; + p.ffn_weights.intermediate_weight.kernel = weight[weight_idx++]; + p.ffn_weights.intermediate_weight.bias = weight[weight_idx++]; + p.ffn_weights.output_weight.kernel = weight[weight_idx++]; + p.ffn_weights.output_weight.bias = weight[weight_idx++]; + p.attn_layernorm_weights.gamma = weight[weight_idx++]; + p.attn_layernorm_weights.beta = weight[weight_idx++]; + p.ffn_layernorm_weights.gamma = weight[weight_idx++]; + p.ffn_layernorm_weights.beta = weight[weight_idx++]; + p.scalelist.size_ = ACTIVATION_AMAX_NUM + 5 + INT8O_GEMM_NUM + TRT_AMAX_NUM; p.scalelist.p2_offset_ = ACTIVATION_AMAX_NUM; p.scalelist.p3_offset_ = ACTIVATION_AMAX_NUM + 5; p.scalelist.p4_offset_ = ACTIVATION_AMAX_NUM + 5 + INT8O_GEMM_NUM; @@ -238,26 +238,26 @@ void test(int model_type, int batch) bl.block_weight_list.push_back(p); } bl.merge_layernorm_weights.gamma = weight[weight_idx++]; - bl.merge_layernorm_weights.beta = weight[weight_idx++]; - bl.merge_linear_weights.kernel = weight[weight_idx++]; - bl.attn_mask = weight[weight_idx++]; + bl.merge_layernorm_weights.beta = weight[weight_idx++]; + bl.merge_linear_weights.kernel = weight[weight_idx++]; + bl.attn_mask = weight[weight_idx++]; params.basic_layer_weight_list.push_back(bl); hidden_dim *= 2; } params.patchEmbed_linear_weights.kernel = weight[weight_idx++]; - params.patchEmbed_linear_weights.bias = weight[weight_idx++]; - params.patchEmbed_norm_weights.gamma = weight[weight_idx++]; - params.patchEmbed_norm_weights.beta = weight[weight_idx++]; - params.norm_weights.gamma = weight[weight_idx++]; - params.norm_weights.beta = weight[weight_idx++]; + params.patchEmbed_linear_weights.bias = weight[weight_idx++]; + 
params.patchEmbed_norm_weights.gamma = weight[weight_idx++]; + params.patchEmbed_norm_weights.beta = weight[weight_idx++]; + params.norm_weights.gamma = weight[weight_idx++]; + params.norm_weights.beta = weight[weight_idx++]; T *input_d, *output_d; deviceMalloc(&input_d, batch * img_size * img_size * in_chans, false); deviceMalloc(&output_d, batch * output_dim, false); fastertransformer::Allocator allocator(0); - int max_batch = batch; - SwinTransformerINT8 sw(int8_mode, + int max_batch = batch; + SwinTransformerINT8 sw(int8_mode, max_batch, img_size, patch_size, @@ -277,8 +277,8 @@ void test(int model_type, int batch) false, qkv_bias, qk_scale); - int sm_ptr[1] = {sm}; - std::vector input_tensors = + int sm_ptr[1] = {sm}; + std::vector input_tensors = std::vector{Tensor{MEMORY_GPU, getTensorType(), std::vector{(size_t)batch, (size_t)in_chans, (size_t)img_size * img_size}, @@ -296,7 +296,7 @@ void test(int model_type, int batch) sw.forward(&output_tensors, &input_tensors, params); } - int ite = 100; + int ite = 100; CudaTimer cuda_timer(stream); cuda_timer.start(); for (int i = 0; i < ite; i++) { @@ -346,7 +346,7 @@ int main(int argc, char* argv[]) printf("Device %s\n", prop.name); int model_type = atoi(argv[1]); - int batch = atoi(argv[2]); + int batch = atoi(argv[2]); test(model_type, batch); } diff --git a/examples/cpp/vit/vit_example.cc b/examples/cpp/vit/vit_example.cc index 95fa8fa98..4a336dcb2 100644 --- a/examples/cpp/vit/vit_example.cc +++ b/examples/cpp/vit/vit_example.cc @@ -28,10 +28,10 @@ template void test( int batch_size, int img_size, int patch_size, int embed_dim, int head_num, int layer_num, int token_classifier) { - cudnnHandle_t cudnn_handle; - cublasHandle_t cublas_handle; + cudnnHandle_t cudnn_handle; + cublasHandle_t cublas_handle; cublasLtHandle_t cublaslt_handle; - cudaStream_t stream = 0; + cudaStream_t stream = 0; checkCUDNN(cudnnCreate(&cudnn_handle)); checkCUDNN(cudnnSetStream(cudnn_handle, stream)); check_cuda_error(cublasCreate(&cublas_handle)); @@ -51,11 +51,11 @@ void test( else if (std::is_same::value) { cublas_wrapper->setFP32GemmConfig(); } - const int in_chans = 3; + const int in_chans = 3; const bool with_cls_token = token_classifier > 0; - const int inter_size = embed_dim * 4; - const int head_dim = embed_dim / head_num; - const int seq_len = (img_size / patch_size) * (img_size / patch_size) + (with_cls_token ? 1 : 0); + const int inter_size = embed_dim * 4; + const int head_dim = embed_dim / head_num; + const int seq_len = (img_size / patch_size) * (img_size / patch_size) + (with_cls_token ? 
1 : 0); ViTWeight params = ViTWeight(embed_dim, inter_size, layer_num, img_size, patch_size, in_chans, with_cls_token); @@ -78,8 +78,8 @@ void test( AttentionType attention_type = getAttentionType(head_dim, getSMVersion(), true, seq_len); printf("Attention Type: %d\n", int(attention_type)); fastertransformer::Allocator allocator(0); - int max_batch = batch_size; - ViTTransformer* vit = new ViTTransformer(max_batch, + int max_batch = batch_size; + ViTTransformer* vit = new ViTTransformer(max_batch, img_size, in_chans, patch_size, @@ -118,7 +118,7 @@ void test( vit->forward(&output_tensors, &input_tensors, ¶ms); } - int ite = 100; + int ite = 100; CudaTimer cuda_timer(stream); cuda_timer.start(); for (int i = 0; i < ite; i++) { @@ -170,14 +170,14 @@ int main(int argc, char* argv[]) return 0; } - const int batch_size = atoi(argv[1]); - const int img_size = atoi(argv[2]); - const int patch_size = atoi(argv[3]); - const int embed_dim = atoi(argv[4]); - const int head_num = atoi(argv[5]); - const int layer_num = atoi(argv[6]); + const int batch_size = atoi(argv[1]); + const int img_size = atoi(argv[2]); + const int patch_size = atoi(argv[3]); + const int embed_dim = atoi(argv[4]); + const int head_num = atoi(argv[5]); + const int layer_num = atoi(argv[6]); const int token_classifier = atoi(argv[7]); - const int is_fp16 = atoi(argv[8]); + const int is_fp16 = atoi(argv[8]); if (is_fp16) { test(batch_size, img_size, patch_size, embed_dim, head_num, layer_num, token_classifier); diff --git a/examples/cpp/vit_int8/CMakeLists.txt b/examples/cpp/vit_int8/CMakeLists.txt index 186144dcd..1e1ca5aa5 100644 --- a/examples/cpp/vit_int8/CMakeLists.txt +++ b/examples/cpp/vit_int8/CMakeLists.txt @@ -15,4 +15,4 @@ cmake_minimum_required(VERSION 3.8) add_executable(vit_int8_example vit_int8_example.cc) target_link_libraries(vit_int8_example PUBLIC ViTINT8 trt_fused_multi_head_attention vit_kernels - cublasMMWrapper -lcublas -lcublasLt -lcudart -lcudnn -lm) + cublasMMWrapper nvtx_utils -lcublas -lcublasLt -lcudart -lcudnn -lm) diff --git a/examples/cpp/vit_int8/vit_int8_example.cc b/examples/cpp/vit_int8/vit_int8_example.cc index 21b11cf86..ce89aa8da 100644 --- a/examples/cpp/vit_int8/vit_int8_example.cc +++ b/examples/cpp/vit_int8/vit_int8_example.cc @@ -15,12 +15,17 @@ */ #include "src/fastertransformer/models/vit_int8/ViTINT8.h" +#include "src/fastertransformer/utils/nvtx_utils.h" #include "stdio.h" #include "stdlib.h" #include #include #include +#ifdef USE_NVTX +bool NVTX_ON = true; +#endif + using namespace fastertransformer; using namespace std; @@ -34,10 +39,10 @@ void test(int batch_size, int int8_mode, int has_cls_token) { - cudnnHandle_t cudnn_handle; - cublasHandle_t cublas_handle; + cudnnHandle_t cudnn_handle; + cublasHandle_t cublas_handle; cublasLtHandle_t cublaslt_handle; - cudaStream_t stream = 0; + cudaStream_t stream = 0; checkCUDNN(cudnnCreate(&cudnn_handle)); checkCUDNN(cudnnSetStream(cudnn_handle, stream)); check_cuda_error(cublasCreate(&cublas_handle)); @@ -66,11 +71,11 @@ void test(int batch_size, else if (std::is_same::value) { cublas_wrapper->setFP32GemmConfig(); } - const int in_chans = 3; - const int inter_size = embed_dim * 4; - const int head_dim = embed_dim / head_num; + const int in_chans = 3; + const int inter_size = embed_dim * 4; + const int head_dim = embed_dim / head_num; const bool with_cls_token = has_cls_token > 0; - const int seq_len = (img_size / patch_size) * (img_size / patch_size) + (with_cls_token ? 
1 : 0); + const int seq_len = (img_size / patch_size) * (img_size / patch_size) + (with_cls_token ? 1 : 0); ViTINT8Weight params = ViTINT8Weight(embed_dim, inter_size, layer_num, img_size, patch_size, in_chans, with_cls_token); @@ -93,8 +98,8 @@ void test(int batch_size, AttentionType attention_type = getAttentionType(head_dim, getSMVersion(), true, seq_len); printf("attention_type: %d\n", int(attention_type)); fastertransformer::Allocator allocator(0); - int max_batch = batch_size; - ViTTransformerINT8* vit = new ViTTransformerINT8(max_batch, + int max_batch = batch_size; + ViTTransformerINT8* vit = new ViTTransformerINT8(max_batch, img_size, in_chans, patch_size, @@ -134,7 +139,7 @@ void test(int batch_size, vit->forward(&output_tensors, &input_tensors, ¶ms); } - int ite = 100; + int ite = 100; CudaTimer cuda_timer(stream); cuda_timer.start(); for (int i = 0; i < ite; i++) { @@ -188,15 +193,15 @@ int main(int argc, char* argv[]) return 0; } - const int batch_size = atoi(argv[1]); - const int img_size = atoi(argv[2]); - const int patch_size = atoi(argv[3]); - const int embed_dim = atoi(argv[4]); - const int head_num = atoi(argv[5]); - const int layer_num = atoi(argv[6]); + const int batch_size = atoi(argv[1]); + const int img_size = atoi(argv[2]); + const int patch_size = atoi(argv[3]); + const int embed_dim = atoi(argv[4]); + const int head_num = atoi(argv[5]); + const int layer_num = atoi(argv[6]); const int has_cls_token = atoi(argv[7]); - const int is_fp16 = atoi(argv[8]); - const int int8_mode = atoi(argv[9]); + const int is_fp16 = atoi(argv[8]); + const int int8_mode = atoi(argv[9]); if (is_fp16) { test(batch_size, img_size, patch_size, embed_dim, head_num, layer_num, int8_mode, has_cls_token); diff --git a/examples/cpp/xlnet/cnpy.cpp b/examples/cpp/xlnet/cnpy.cpp index 80cfd0590..9872c82c0 100644 --- a/examples/cpp/xlnet/cnpy.cpp +++ b/examples/cpp/xlnet/cnpy.cpp @@ -85,20 +85,20 @@ std::vector& cnpy::operator+=(std::vector& lhs, const char* rhs) void cnpy::parse_npy_header(unsigned char* buffer, size_t& word_size, std::vector& shape, bool& fortran_order) { - uint16_t header_len = *reinterpret_cast(buffer + 8); + uint16_t header_len = *reinterpret_cast(buffer + 8); std::string header(reinterpret_cast(buffer + 9), header_len); size_t loc1, loc2; // fortran order - loc1 = header.find("fortran_order") + 16; + loc1 = header.find("fortran_order") + 16; fortran_order = (header.substr(loc1, 4) == "True" ? true : false); // shape loc1 = header.find("("); loc2 = header.find(")"); - std::regex num_regex("[0-9][0-9]*"); + std::regex num_regex("[0-9][0-9]*"); std::smatch sm; shape.clear(); @@ -111,7 +111,7 @@ void cnpy::parse_npy_header(unsigned char* buffer, size_t& word_size, std::vecto // endian, word size, data type // byte order code | stands for not applicable. // not sure when this applies except for byte array - loc1 = header.find("descr") + 9; + loc1 = header.find("descr") + 9; bool littleEndian = (header[loc1] == '<' || header[loc1] == '|' ? 
true : false); assert(littleEndian); @@ -119,13 +119,13 @@ void cnpy::parse_npy_header(unsigned char* buffer, size_t& word_size, std::vecto // assert(type == map_type(T)); std::string str_ws = header.substr(loc1 + 2); - loc2 = str_ws.find("'"); - word_size = atoi(str_ws.substr(0, loc2).c_str()); + loc2 = str_ws.find("'"); + word_size = atoi(str_ws.substr(0, loc2).c_str()); } void cnpy::parse_npy_header(FILE* fp, size_t& word_size, std::vector& shape, bool& fortran_order) { - char buffer[256]; + char buffer[256]; size_t res = fread(buffer, sizeof(char), 11, fp); if (res != 11) throw std::runtime_error("parse_npy_header: failed fread"); @@ -147,7 +147,7 @@ void cnpy::parse_npy_header(FILE* fp, size_t& word_size, std::vector& sh if (loc1 == std::string::npos || loc2 == std::string::npos) throw std::runtime_error("parse_npy_header: failed to find header keyword: '(' or ')'"); - std::regex num_regex("[0-9][0-9]*"); + std::regex num_regex("[0-9][0-9]*"); std::smatch sm; shape.clear(); @@ -171,8 +171,8 @@ void cnpy::parse_npy_header(FILE* fp, size_t& word_size, std::vector& sh // assert(type == map_type(T)); std::string str_ws = header.substr(loc1 + 2); - loc2 = str_ws.find("'"); - word_size = atoi(str_ws.substr(0, loc2).c_str()); + loc2 = str_ws.find("'"); + word_size = atoi(str_ws.substr(0, loc2).c_str()); } void cnpy::parse_zip_footer(FILE* fp, uint16_t& nrecs, size_t& global_header_size, size_t& global_header_offset) @@ -184,13 +184,13 @@ void cnpy::parse_zip_footer(FILE* fp, uint16_t& nrecs, size_t& global_header_siz throw std::runtime_error("parse_zip_footer: failed fread"); uint16_t disk_no, disk_start, nrecs_on_disk, comment_len; - disk_no = *(uint16_t*)&footer[4]; - disk_start = *(uint16_t*)&footer[6]; - nrecs_on_disk = *(uint16_t*)&footer[8]; - nrecs = *(uint16_t*)&footer[10]; - global_header_size = *(uint32_t*)&footer[12]; + disk_no = *(uint16_t*)&footer[4]; + disk_start = *(uint16_t*)&footer[6]; + nrecs_on_disk = *(uint16_t*)&footer[8]; + nrecs = *(uint16_t*)&footer[10]; + global_header_size = *(uint32_t*)&footer[12]; global_header_offset = *(uint32_t*)&footer[16]; - comment_len = *(uint16_t*)&footer[20]; + comment_len = *(uint16_t*)&footer[20]; assert(disk_no == 0); assert(disk_start == 0); @@ -201,12 +201,12 @@ void cnpy::parse_zip_footer(FILE* fp, uint16_t& nrecs, size_t& global_header_siz cnpy::NpyArray load_the_npy_file(FILE* fp) { std::vector shape; - size_t word_size; - bool fortran_order; + size_t word_size; + bool fortran_order; cnpy::parse_npy_header(fp, word_size, shape, fortran_order); cnpy::NpyArray arr(shape, word_size, fortran_order); - size_t nread = fread(arr.data(), 1, arr.num_bytes(), fp); + size_t nread = fread(arr.data(), 1, arr.num_bytes(), fp); if (nread != arr.num_bytes()) throw std::runtime_error("load_the_npy_file: failed fread"); return arr; @@ -216,30 +216,30 @@ cnpy::NpyArray load_the_npz_array(FILE* fp, uint32_t compr_bytes, uint32_t uncom { std::vector buffer_compr(compr_bytes); std::vector buffer_uncompr(uncompr_bytes); - size_t nread = fread(&buffer_compr[0], 1, compr_bytes, fp); + size_t nread = fread(&buffer_compr[0], 1, compr_bytes, fp); if (nread != compr_bytes) throw std::runtime_error("load_the_npy_file: failed fread"); z_stream d_stream; - d_stream.zalloc = Z_NULL; - d_stream.zfree = Z_NULL; - d_stream.opaque = Z_NULL; + d_stream.zalloc = Z_NULL; + d_stream.zfree = Z_NULL; + d_stream.opaque = Z_NULL; d_stream.avail_in = 0; - d_stream.next_in = Z_NULL; + d_stream.next_in = Z_NULL; inflateInit2(&d_stream, -MAX_WBITS); - d_stream.avail_in = 
compr_bytes; - d_stream.next_in = &buffer_compr[0]; + d_stream.avail_in = compr_bytes; + d_stream.next_in = &buffer_compr[0]; d_stream.avail_out = uncompr_bytes; - d_stream.next_out = &buffer_uncompr[0]; + d_stream.next_out = &buffer_uncompr[0]; inflate(&d_stream, Z_FINISH); inflateEnd(&d_stream); std::vector shape; - size_t word_size; - bool fortran_order; + size_t word_size; + bool fortran_order; cnpy::parse_npy_header(&buffer_uncompr[0], word_size, shape, fortran_order); cnpy::NpyArray array(shape, word_size, fortran_order); @@ -263,7 +263,7 @@ cnpy::npz_t cnpy::npz_load(std::string fname) if (fp) { while (1) { std::vector local_header(30); - size_t headerres = fread(&local_header[0], sizeof(char), 30, fp); + size_t headerres = fread(&local_header[0], sizeof(char), 30, fp); if (headerres != 30) throw std::runtime_error("npz_load: failed fread"); @@ -272,9 +272,9 @@ cnpy::npz_t cnpy::npz_load(std::string fname) break; // read in the variable name - uint16_t name_len = *(uint16_t*)&local_header[26]; + uint16_t name_len = *(uint16_t*)&local_header[26]; std::string varname(name_len, ' '); - size_t vname_res = fread(&varname[0], sizeof(char), name_len, fp); + size_t vname_res = fread(&varname[0], sizeof(char), name_len, fp); if (vname_res != name_len) throw std::runtime_error("npz_load: failed fread"); @@ -285,13 +285,13 @@ cnpy::npz_t cnpy::npz_load(std::string fname) uint16_t extra_field_len = *(uint16_t*)&local_header[28]; if (extra_field_len > 0) { std::vector buff(extra_field_len); - size_t efield_res = fread(&buff[0], sizeof(char), extra_field_len, fp); + size_t efield_res = fread(&buff[0], sizeof(char), extra_field_len, fp); if (efield_res != extra_field_len) throw std::runtime_error("npz_load: failed fread"); } - uint16_t compr_method = *reinterpret_cast(&local_header[0] + 8); - uint32_t compr_bytes = *reinterpret_cast(&local_header[0] + 18); + uint16_t compr_method = *reinterpret_cast(&local_header[0] + 8); + uint32_t compr_bytes = *reinterpret_cast(&local_header[0] + 18); uint32_t uncompr_bytes = *reinterpret_cast(&local_header[0] + 22); if (compr_method == 0) { @@ -316,7 +316,7 @@ cnpy::NpyArray cnpy::npz_load(std::string fname, std::string varname) while (1) { std::vector local_header(30); - size_t header_res = fread(&local_header[0], sizeof(char), 30, fp); + size_t header_res = fread(&local_header[0], sizeof(char), 30, fp); if (header_res != 30) throw std::runtime_error("npz_load: failed fread"); @@ -325,9 +325,9 @@ cnpy::NpyArray cnpy::npz_load(std::string fname, std::string varname) break; // read in the variable name - uint16_t name_len = *(uint16_t*)&local_header[26]; + uint16_t name_len = *(uint16_t*)&local_header[26]; std::string vname(name_len, ' '); - size_t vname_res = fread(&vname[0], sizeof(char), name_len, fp); + size_t vname_res = fread(&vname[0], sizeof(char), name_len, fp); if (vname_res != name_len) throw std::runtime_error("npz_load: failed fread"); vname.erase(vname.end() - 4, vname.end()); // erase the lagging .npy @@ -336,8 +336,8 @@ cnpy::NpyArray cnpy::npz_load(std::string fname, std::string varname) uint16_t extra_field_len = *(uint16_t*)&local_header[28]; fseek(fp, extra_field_len, SEEK_CUR); // skip past the extra field - uint16_t compr_method = *reinterpret_cast(&local_header[0] + 8); - uint32_t compr_bytes = *reinterpret_cast(&local_header[0] + 18); + uint16_t compr_method = *reinterpret_cast(&local_header[0] + 8); + uint32_t compr_bytes = *reinterpret_cast(&local_header[0] + 18); uint32_t uncompr_bytes = *reinterpret_cast(&local_header[0] + 22); if 
(vname == varname) { diff --git a/examples/cpp/xlnet/cnpy.h b/examples/cpp/xlnet/cnpy.h index 86bbe491e..fcf4c7088 100644 --- a/examples/cpp/xlnet/cnpy.h +++ b/examples/cpp/xlnet/cnpy.h @@ -64,10 +64,10 @@ struct NpyArray { } std::shared_ptr> data_holder; - std::vector shape; - size_t word_size; - bool fortran_order; - size_t num_vals; + std::vector shape; + size_t word_size; + bool fortran_order; + size_t num_vals; }; using npz_t = std::map; @@ -76,10 +76,10 @@ char BigEndianTest(); char map_type(const std::type_info& t); template std::vector create_npy_header(const std::vector& shape); -void parse_npy_header(FILE* fp, size_t& word_size, std::vector& shape, bool& fortran_order); -void parse_npy_header(unsigned char* buffer, size_t& word_size, std::vector& shape, bool& fortran_order); -void parse_zip_footer(FILE* fp, uint16_t& nrecs, size_t& global_header_size, size_t& global_header_offset); -npz_t npz_load(std::string fname); +void parse_npy_header(FILE* fp, size_t& word_size, std::vector& shape, bool& fortran_order); +void parse_npy_header(unsigned char* buffer, size_t& word_size, std::vector& shape, bool& fortran_order); +void parse_zip_footer(FILE* fp, uint16_t& nrecs, size_t& global_header_size, size_t& global_header_offset); +npz_t npz_load(std::string fname); NpyArray npz_load(std::string fname, std::string varname); NpyArray npy_load(std::string fname); @@ -102,7 +102,7 @@ std::vector& operator+=(std::vector& lhs, const char* rhs); template void npy_save(std::string fname, const T* data, const std::vector shape, std::string mode = "w") { - FILE* fp = NULL; + FILE* fp = NULL; std::vector true_data_shape; // if appending, the shape of existing + new data if (mode == "a") @@ -112,7 +112,7 @@ void npy_save(std::string fname, const T* data, const std::vector shape, // file exists. we need to append to it. 
read the header, modify the array // size size_t word_size; - bool fortran_order; + bool fortran_order; parse_npy_header(fp, word_size, true_data_shape, fortran_order); assert(!fortran_order); @@ -137,12 +137,12 @@ void npy_save(std::string fname, const T* data, const std::vector shape, true_data_shape[0] += shape[0]; } else { - fp = fopen(fname.c_str(), "wb"); + fp = fopen(fname.c_str(), "wb"); true_data_shape = shape; } std::vector header = create_npy_header(true_data_shape); - size_t nels = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + size_t nels = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); fseek(fp, 0, SEEK_SET); fwrite(&header[0], sizeof(char), header.size(), fp); @@ -159,9 +159,9 @@ void npz_save( fname += ".npy"; // now, on with the show - FILE* fp = NULL; - uint16_t nrecs = 0; - size_t global_header_offset = 0; + FILE* fp = NULL; + uint16_t nrecs = 0; + size_t global_header_offset = 0; std::vector global_header; if (mode == "a") @@ -190,12 +190,12 @@ void npz_save( std::vector npy_header = create_npy_header(shape); - size_t nels = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + size_t nels = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); size_t nbytes = nels * sizeof(T) + npy_header.size(); // get the CRC of the data to be added uint32_t crc = crc32(0L, (uint8_t*)&npy_header[0], npy_header.size()); - crc = crc32(crc, (uint8_t*)data, nels * sizeof(T)); + crc = crc32(crc, (uint8_t*)data, nels * sizeof(T)); // build the local header std::vector local_header; diff --git a/examples/cpp/xlnet/xlnet_correctness_example.cc b/examples/cpp/xlnet/xlnet_correctness_example.cc index 815ee2d5f..e6f075e6a 100644 --- a/examples/cpp/xlnet/xlnet_correctness_example.cc +++ b/examples/cpp/xlnet/xlnet_correctness_example.cc @@ -16,6 +16,7 @@ #include "cnpy.h" #include "src/fastertransformer/models/xlnet/Xlnet.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" using namespace fastertransformer; using namespace std; #include @@ -30,7 +31,7 @@ int xlnetCorrectnessExample(size_t batch_size, string input_name, string model_name, string check_name, - bool allow_gemm_test = false); + bool allow_gemm_test = false); /*************** NPZ related operations *****************/ template @@ -68,6 +69,14 @@ float castToFloat(__half input) return output; } +#ifdef ENABLE_BF16 +template<> +float castToFloat(__nv_bfloat16 input) +{ + return __bfloat162float(input); +} +#endif + template void setByNpz(cnpy::npz_t& my_npz, std::string name, T* d_ptr, int size, int offset = 0) { @@ -86,8 +95,8 @@ void setByNpz<__half>(cnpy::npz_t& my_npz, std::string name, __half* d_ptr, int cnpy::NpyArray arr = my_npz[name]; // load it into a new array - float* loaded_data = arr.data(); - __half* half_data = (__half*)malloc(sizeof(__half) * size); + float* loaded_data = arr.data(); + __half* half_data = (__half*)malloc(sizeof(__half) * size); loaded_data = loaded_data + offset; for (int i = 0; i < size; i++) { @@ -97,6 +106,26 @@ void setByNpz<__half>(cnpy::npz_t& my_npz, std::string name, __half* d_ptr, int check_cuda_error(cudaMemcpy(d_ptr, half_data, sizeof(__half) * size, cudaMemcpyHostToDevice)); free(half_data); } +#ifdef ENABLE_BF16 +template<> +void setByNpz<__nv_bfloat16>(cnpy::npz_t& my_npz, std::string name, __nv_bfloat16* d_ptr, int size, int offset) +{ + // check that the loaded myVar1 matches myVar1 + cnpy::NpyArray arr = my_npz[name]; + + // load it into a new array + float* loaded_data = arr.data(); + __nv_bfloat16* 
half_data = (__nv_bfloat16*)malloc(sizeof(__nv_bfloat16) * size); + + loaded_data = loaded_data + offset; + for (int i = 0; i < size; i++) { + half_data[i] = __float2bfloat16_rn(loaded_data[i]); + } + + check_cuda_error(cudaMemcpy(d_ptr, half_data, sizeof(__nv_bfloat16) * size, cudaMemcpyHostToDevice)); + free(half_data); +} +#endif std::string paraName(int i_layer, std::string sub_para) { @@ -116,19 +145,19 @@ template<typename T> void checkByNpz(cnpy::npz_t& data_npz, cudaStream_t stream, std::string name, T* d_ptr, int size) { std::cout << name << " " << size << std::endl; - bool ifCorrect = 1; - cnpy::NpyArray arr = data_npz[name]; - T* loaded_data = arr.data<T>(); + bool ifCorrect = 1; + cnpy::NpyArray arr = data_npz[name]; + float* loaded_data = arr.data<float>(); T* h_ptr = (T*)malloc(size * sizeof(T)); check_cuda_error(cudaMemcpyAsync(h_ptr, d_ptr, sizeof(T) * size, cudaMemcpyDeviceToHost, stream)); float err = 0; float max = castToFloat(h_ptr[0]); - int i = 0; + int i = 0; for (i = 0; i < size; i++) { - float sub = abs(castToFloat(h_ptr[i]) - castToFloat(loaded_data[i])); + float sub = abs(castToFloat(h_ptr[i]) - loaded_data[i]); if (sub > err) { err = sub; } @@ -139,7 +168,7 @@ void checkByNpz(cnpy::npz_t& data_npz, cudaStream_t stream, std::string name, T* std::cout << name << " Max err :" << err << " Max value :" << max << " Ralative error rate: " << err / max << std::endl; - + assert(err < 0.004f || err / max < 0.001f); free(h_ptr); } @@ -167,28 +196,28 @@ int main(int argc, char** argv) { if (argc != 11) { printf("[ERROR] ./bin/xlnet_correctness_example batch_size num_layers seq_len " "head_num size_per_head num_token input_name model_name check_name " - "is_fp16\n"); + "data_type 0: fp32, 1: fp16, 2: bf16\n"); printf("e.g., ./bin/xlnet_correctness_example 8 12 128 12 64 32000 " "./data/data.npz ./data/model.npz ./data/output.npz 0\n"); return 0; } bool allow_gemm_test = false; - int batch_size = atoi(argv[1]); - int num_layers = atoi(argv[2]); - int seq_len = atoi(argv[3]); - int head_num = atoi(argv[4]); - int size_per_head = atoi(argv[5]); - int num_token = atoi(argv[6]); - input_name = argv[7]; - model_name = argv[8]; - check_name = argv[9]; - bool is_fp16 = atoi(argv[10]); + int batch_size = atoi(argv[1]); + int num_layers = atoi(argv[2]); + int seq_len = atoi(argv[3]); + int head_num = atoi(argv[4]); + int size_per_head = atoi(argv[5]); + int num_token = atoi(argv[6]); + input_name = argv[7]; + model_name = argv[8]; + check_name = argv[9]; + FtCudaDataType data_type = static_cast<FtCudaDataType>(atoi(argv[10])); // 0: fp32, 1: fp16, 2: bf16 cout << " " << batch_size << " " << num_layers << " " << seq_len << " " << head_num << " " << size_per_head << " " - << num_token << " " << input_name << " " << model_name << " " << check_name << " " << is_fp16 << endl; + << num_token << " " << input_name << " " << model_name << " " << check_name << " " << data_type << endl; - if (is_fp16 == 0) { + if (data_type == FP32) { return xlnetCorrectnessExample<float>(batch_size, num_layers, seq_len, @@ -200,7 +229,7 @@ int main(int argc, char** argv) check_name, allow_gemm_test); } - else if (is_fp16 == 1) { + else if (data_type == FP16) { return xlnetCorrectnessExample<half>(batch_size, num_layers, seq_len, @@ -212,9 +241,22 @@ int main(int argc, char** argv) check_name, allow_gemm_test); } +#ifdef ENABLE_BF16 + else if (data_type == BF16) { + return xlnetCorrectnessExample<__nv_bfloat16>(batch_size, + num_layers, + seq_len, + head_num, + size_per_head, + num_token, + input_name, + model_name, + check_name, + allow_gemm_test); + } +#endif
else { - throw std::runtime_error(std::string("[FT][ERROR] is_fp16 should be 0 (use float)" - "or 1 (use half). \n ")); + throw std::runtime_error(std::string("[FT][ERROR] data_type should be fp32, fp16, or bf16 \n ")); } } @@ -229,15 +271,15 @@ int xlnetCorrectnessExample(size_t batch_size, string input_name, string model_name, string check_name, - bool allow_gemm_test) + bool allow_gemm_test) { printf("[INFO] Device: %s \n", getDeviceName().c_str()); const size_t hidden_units = head_num * size_per_head; - const size_t inter_size = 4 * hidden_units; + const size_t inter_size = 4 * hidden_units; - cudaStream_t stream; - cublasHandle_t cublas_handle; + cudaStream_t stream; + cublasHandle_t cublas_handle; cublasLtHandle_t cublaslt_handle; cudaStreamCreate(&stream); cublasCreate(&cublas_handle); @@ -256,35 +298,40 @@ int xlnetCorrectnessExample(size_t batch_size, if (std::is_same<T, half>::value) { cublas_wrapper.setFP16GemmConfig(); } +#ifdef ENABLE_BF16 + else if (std::is_same<T, __nv_bfloat16>::value) { + cublas_wrapper.setBF16GemmConfig(); + } +#endif else if (std::is_same<T, float>::value) { cublas_wrapper.setFP32GemmConfig(); } // Set layer weight std::vector<XlnetLayerWeight<T>> xlnet_layer_weights(num_layers, XlnetLayerWeight<T>(hidden_units, inter_size)); - const int weight_nums = 17; - string weight_name[17] = {"/rel_attn/q/kernel:0", - "/rel_attn/k/kernel:0", - "/rel_attn/v/kernel:0", - "/rel_attn/r/kernel:0", - "model/transformer/r_w_bias:0", - "model/transformer/r_r_bias:0", - "model/transformer/r_s_bias:0", - "model/transformer/seg_embed:0", - "/rel_attn/o/kernel:0", - "/rel_attn/LayerNorm/gamma:0", - "/rel_attn/LayerNorm/beta:0", - "/ff/layer_1/kernel:0", - "/ff/layer_1/bias:0", - "/ff/layer_2/kernel:0", - "/ff/layer_2/bias:0", - "/ff/LayerNorm/gamma:0", - "/ff/LayerNorm/beta:0"}; + const int weight_nums = 17; + string weight_name[17] = {"/rel_attn/q/kernel:0", + "/rel_attn/k/kernel:0", + "/rel_attn/v/kernel:0", + "/rel_attn/r/kernel:0", + "model/transformer/r_w_bias:0", + "model/transformer/r_r_bias:0", + "model/transformer/r_s_bias:0", + "model/transformer/seg_embed:0", + "/rel_attn/o/kernel:0", + "/rel_attn/LayerNorm/gamma:0", + "/rel_attn/LayerNorm/beta:0", + "/ff/layer_1/kernel:0", + "/ff/layer_1/bias:0", + "/ff/layer_2/kernel:0", + "/ff/layer_2/bias:0", + "/ff/LayerNorm/gamma:0", + "/ff/LayerNorm/beta:0"}; cnpy::npz_t model_npz = cnpy::npz_load(model_name); for (int i = 0; i < num_layers; i++) { - T** weight_ptrs = xlnet_layer_weights[i].getWeightPtrs(); + T** weight_ptrs = xlnet_layer_weights[i].getWeightPtrs(); int* weight_sizes = xlnet_layer_weights[i].getWeightSizes(); for (int j = 0; j < weight_nums; j++) { string str; diff --git a/examples/cpp/xlnet/xlnet_example.cc b/examples/cpp/xlnet/xlnet_example.cc index d341376a3..d0c1c6399 100644 --- a/examples/cpp/xlnet/xlnet_example.cc +++ b/examples/cpp/xlnet/xlnet_example.cc @@ -25,27 +25,31 @@ int main(int argc, char** argv) { if (argc != 7) { printf("[ERROR] xlnet_example " - " <batch_size> <num_layers> <seq_len> <head_num> <size_per_head> <is_fp16> \n"); + " <batch_size> <num_layers> <seq_len> <head_num> <size_per_head> <data_type> \n"); printf("e.g., ./bin/xlnet_example 8 12 128 12 64 0\n"); return 0; } - int batch_size = atoi(argv[1]); - int num_layers = atoi(argv[2]); - int seq_len = atoi(argv[3]); - int head_num = atoi(argv[4]); - int size_per_head = atoi(argv[5]); - bool is_fp16 = atoi(argv[6]); + int batch_size = atoi(argv[1]); + int num_layers = atoi(argv[2]); + int seq_len = atoi(argv[3]); + int head_num = atoi(argv[4]); + int size_per_head = atoi(argv[5]); + FtCudaDataType data_type = static_cast<FtCudaDataType>(atoi(argv[6])); // 0: fp32, 1: fp16, 2: bf16 - if (is_fp16 == 0) { + if (data_type == FP32) { return
xlnetExample<float>(batch_size, num_layers, seq_len, head_num, size_per_head); } - else if (is_fp16 == 1) { +#ifdef ENABLE_BF16 + else if (data_type == BF16) { + return xlnetExample<__nv_bfloat16>(batch_size, num_layers, seq_len, head_num, size_per_head); + } +#endif + else if (data_type == FP16) { return xlnetExample<half>(batch_size, num_layers, seq_len, head_num, size_per_head); } else { - throw std::runtime_error(std::string("[FT][ERROR] is_fp16 should be 0 (use float)" - "or 1 (use half). \n ")); + throw std::runtime_error(std::string("[FT][ERROR] data_type should be fp32, fp16, or bf16 \n ")); } } @@ -55,10 +59,10 @@ int xlnetExample(size_t batch_size, size_t num_layers, size_t seq_len, size_t he printf("[INFO] Device: %s \n", getDeviceName().c_str()); const size_t hidden_units = head_num * size_per_head; - const size_t inter_size = 4 * hidden_units; + const size_t inter_size = 4 * hidden_units; - cudaStream_t stream; - cublasHandle_t cublas_handle; + cudaStream_t stream; + cublasHandle_t cublas_handle; cublasLtHandle_t cublaslt_handle; cudaStreamCreate(&stream); cublasCreate(&cublas_handle); @@ -77,6 +81,11 @@ int xlnetExample(size_t batch_size, size_t num_layers, size_t seq_len, size_t he if (std::is_same<T, half>::value) { cublas_wrapper.setFP16GemmConfig(); } +#ifdef ENABLE_BF16 + else if (std::is_same<T, __nv_bfloat16>::value) { + cublas_wrapper.setBF16GemmConfig(); + } +#endif else if (std::is_same<T, float>::value) { cublas_wrapper.setFP32GemmConfig(); } diff --git a/examples/pytorch/__init__.py b/examples/pytorch/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/pytorch/bert/bert-quantization-sparsity/README.md b/examples/pytorch/bert/bert-quantization-sparsity/README.md index 91705000f..9505be8f6 100644 --- a/examples/pytorch/bert/bert-quantization-sparsity/README.md +++ b/examples/pytorch/bert/bert-quantization-sparsity/README.md @@ -8,7 +8,7 @@ This directory contains examples for BERT PTQ/QAT and sparsity related training. ## Setup -Please follow the original README to do some inital setup. +Please follow the original README to do some initial setup. setup steps: ```bash diff --git a/examples/pytorch/bert/bert-quantization-sparsity/apex_sparsity/asp.py b/examples/pytorch/bert/bert-quantization-sparsity/apex_sparsity/asp.py index 17b1f3485..c38c63a6b 100644 --- a/examples/pytorch/bert/bert-quantization-sparsity/apex_sparsity/asp.py +++ b/examples/pytorch/bert/bert-quantization-sparsity/apex_sparsity/asp.py @@ -52,7 +52,7 @@ def init_model_for_pruning(cls, model, mask_calculator="m4n2_1d", Arguments: model The model mask_calculator Either callable that computes mask given a tensor OR pattern string for sparse mask lib. - verbosity Integer controling verbosity level. + verbosity Integer controlling verbosity level. 0 -> Only errors. 1 -> Errors and warnings. 2 -> Errors, warnings and info. @@ -62,7 +62,7 @@ def init_model_for_pruning(cls, model, mask_calculator="m4n2_1d", disallowed_layer_names If not [], only layer names that do not appear in this list are considered for sparsity. allow_recompute_mask If True, stores pruned values so that dense weights can be restored. Pruned weights are stored in CPU memory, hence this option does not increase GPU memory usage. - custom_layer_dict Dictionary of additional layer paremeters to sparsify. e.g. {CustomLinear: ['weight']} + custom_layer_dict Dictionary of additional layer parameters to sparsify. e.g. {CustomLinear: ['weight']} [Future] Support for allow_recompute_mask can be removed, it is not part of sparse inference recipe -- AKM.
""" diff --git a/examples/pytorch/bert/bert-quantization-sparsity/modeling.py b/examples/pytorch/bert/bert-quantization-sparsity/modeling.py index e6fb98424..e1be82e62 100755 --- a/examples/pytorch/bert/bert-quantization-sparsity/modeling.py +++ b/examples/pytorch/bert/bert-quantization-sparsity/modeling.py @@ -734,7 +734,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_d . `model.chkpt` a TensorFlow checkpoint from_tf: should we load the weights from a locally saved TensorFlow checkpoint cache_dir: an optional path to a folder in which the pre-trained models will be cached. - state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models + state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models *inputs, **kwargs: additional input for the specific Bert class (ex: num_labels for BertForSequenceClassification) """ @@ -852,7 +852,7 @@ class BertModel(BertPreTrainedModel): a batch has varying length sentences. Outputs: Tuple of (encoded_layers, pooled_output) - `encoded_layers`: controled by `output_all_encoded_layers` argument: + `encoded_layers`: controlled by `output_all_encoded_layers` argument: - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size], diff --git a/examples/pytorch/bert/bert-quantization-sparsity/run_pretraining.py b/examples/pytorch/bert/bert-quantization-sparsity/run_pretraining.py index a5873589b..dfe62f77d 100755 --- a/examples/pytorch/bert/bert-quantization-sparsity/run_pretraining.py +++ b/examples/pytorch/bert/bert-quantization-sparsity/run_pretraining.py @@ -211,7 +211,7 @@ def parse_arguments(): parser.add_argument('--gradient_accumulation_steps', type=int, default=1, - help="Number of updates steps to accumualte before performing a backward/update pass.") + help="Number of updates steps to accumulate before performing a backward/update pass.") parser.add_argument('--fp16', default=False, action='store_true', diff --git a/examples/pytorch/bert/bert-quantization-sparsity/tokenization.py b/examples/pytorch/bert/bert-quantization-sparsity/tokenization.py index c25c323e7..8992d9cd2 100755 --- a/examples/pytorch/bert/bert-quantization-sparsity/tokenization.py +++ b/examples/pytorch/bert/bert-quantization-sparsity/tokenization.py @@ -166,7 +166,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, logger.info("loading vocabulary file {} from cache at {}".format( vocab_file, resolved_vocab_file)) if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: - # if we're using a pretrained model, ensure the tokenizer wont index sequences longer + # if we're using a pretrained model, ensure the tokenizer won't index sequences longer # than the number of positional embeddings max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) @@ -354,7 +354,7 @@ def tokenize(self, text): def _is_whitespace(char): """Checks whether `chars` is a whitespace character.""" - # \t, \n, and \r are technically contorl characters but we treat them + # \t, \n, and \r are technically controlled characters but we treat them # as whitespace since they are generally 
considered as such. if char == " " or char == "\t" or char == "\n" or char == "\r": return True diff --git a/examples/pytorch/bert/bert_example.py b/examples/pytorch/bert/bert_example.py index 352e74826..467a70542 100644 --- a/examples/pytorch/bert/bert_example.py +++ b/examples/pytorch/bert/bert_example.py @@ -13,11 +13,13 @@ # limitations under the License. from __future__ import print_function +import threading import os import argparse import timeit import torch +import torch.distributed as dist import torch.cuda.nvtx as nvtx import time import sys @@ -25,11 +27,11 @@ import random dir_path = os.path.dirname(os.path.realpath(__file__)) sys.path.append(dir_path + "/../../..") -from examples.pytorch.bert.utils.encoder import EncoderWeights -from examples.pytorch.bert.utils.encoder import CustomEncoder -from examples.pytorch.bert.utils.encoder import HuggingFaceEncoder from examples.pytorch.utils import print_memory_usage -import threading +from examples.pytorch.bert.utils.encoder import HuggingFaceEncoder +from examples.pytorch.bert.utils.encoder import CustomEncoder +from examples.pytorch.bert.utils.encoder import EncoderWeights + def sequence_mask(lengths, max_len=None, is_2d=True): batch_size = lengths.numel() @@ -58,8 +60,7 @@ def main(): help='head number') parser.add_argument('head_size', type=int, help='size per head') - parser.add_argument('--fp16', action='store_true', - help='is fp16') + parser.add_argument('--data_type', type=str, choices=['fp32', 'fp16', 'bf16'], default='fp32') parser.add_argument('--int8_mode', type=int, default=0, metavar='NUMBER', help='int8 mode (default: 0)', choices=[0, 1, 2, 3]) parser.add_argument('--time', action='store_true', @@ -75,15 +76,33 @@ def main(): help='path of the pyt_fastertransformer dynamic lib file') parser.add_argument('-thread_num', '--thread_num', type=int, default=1, metavar='int', help='Testing multithread if thread_num > 1.') - + parser.add_argument('-tensor_para_size', '--tensor_para_size', type=int, default=1, metavar='int', + help='Size of tensor parallelism.') + parser.add_argument('-pipeline_para_size', '--pipeline_para_size', type=int, default=1, metavar='int', + help='Size of pipeline parallelism.') + parser.add_argument('--error-threshold', type=float, + help='Threshold of error') + args = parser.parse_args() bert_example(vars(args)) + def bert_example(args): torch.manual_seed(0) random.seed(0) np.random.seed(0) - + + if dist.is_mpi_available(): + try: + dist.init_process_group(backend='mpi') + rank = dist.get_rank() + world_size = dist.get_world_size() + except: + rank = dist.get_rank() + world_size = dist.get_world_size() + else: + rank = 0 + batch_size = args['batch_size'] seq_len = args['seq_len'] if args['weight_path'] is not None: @@ -112,10 +131,11 @@ def bert_example(args): elif args['int8_mode'] != 0: raise ValueError("wrong int8_mode argument") - print("\n=============== Argument ===============") - for key in args: - print("{}: {}".format(key, args[key])) - print("========================================\n") + if rank == 0: + print("\n=============== Argument ===============") + for key in args: + print("{}: {}".format(key, args[key])) + print("========================================\n") inp = torch.empty(batch_size, seq_len, hidden_dim).cuda() torch.nn.init.normal_(inp, -0.02, 0.02) @@ -123,127 +143,159 @@ def bert_example(args): mem_seq_lens = torch.ones((batch_size,)) * args['avg_seq_len'] mem_seq_lens = mem_seq_lens.to(torch.int32).cuda() elif args['avg_seq_len'] == -1: - mem_seq_lens = torch.randint(1, 
seq_len+1, (batch_size,), dtype=torch.int32).cuda() + mem_seq_lens = torch.randint(1, seq_len + 1, (batch_size,), dtype=torch.int32).cuda() else: raise ValueError("wrong avg_seq_len") mask = sequence_mask(mem_seq_lens, args['seq_len'], False).to(torch.float) # mask = torch.randint(0, 2, (batch_size, seq_len, seq_len), dtype=torch.float32).cuda() - if args['fp16'] or args['int8_mode'] != 0: + if args['data_type'] == 'fp16' or args['int8_mode'] != 0: inp = inp.half() mask = mask.half() + elif args['data_type'] == 'bf16': + inp = inp.bfloat16() + mask = mask.bfloat16() pretrained_weights = torch.load(args['weight_path']) if (args['weight_path'] is not None) else None weights = EncoderWeights(layer_num, hidden_dim, pretrained_weights, args['sparse']) - - hf_encoder = HuggingFaceEncoder(layer_num, head_num, head_size, weights) - hf_encoder.cuda() - if args['fp16'] or args['int8_mode'] != 0: - hf_encoder.half() - hf_encoder.eval() - hf_encoder = torch.jit.trace(hf_encoder, (inp, mask)) + ft_weights = EncoderWeights(layer_num, hidden_dim, weights.weights, args['sparse'], + args["tensor_para_size"], args["pipeline_para_size"]) + world_size = dist.get_world_size() if dist.is_mpi_available() else 1 + assert world_size == args["tensor_para_size"] * \ + args["pipeline_para_size"], f"[ERROR] world_size ({world_size}) != tensor_para_size ({args['tensor_para_size']}) * pipeline_para_size ({args['pipeline_para_size']})" + ft_weights._generated_weights = True # for int8 handling + if rank == 0: + hf_encoder = HuggingFaceEncoder(layer_num, head_num, head_size, weights) + hf_encoder.cuda() + if args['data_type'] == 'fp16' or args['int8_mode'] != 0: + hf_encoder.half() + elif args['data_type'] == 'bf16': + hf_encoder.bfloat16() + hf_encoder.eval() + hf_encoder = torch.jit.trace(hf_encoder, (inp, mask)) if args['int8_mode'] != 0: - weights.to_int8(args['sparse'], args['ths_path']) - elif args['fp16']: - weights.to_half() - weights.to_cuda() - custom_encoder = CustomEncoder(layer_num, head_num, head_size, weights, - int8_mode=args['int8_mode'], - remove_padding=False, - sparse=args['sparse'], - path=args['ths_path']) + ft_weights.to_int8(args['sparse'], args['ths_path']) + elif args['data_type'] == 'fp16': + ft_weights.to_half() + elif args['data_type'] == 'bf16': + ft_weights.to_bfloat16() + ft_weights.to_cuda() + custom_encoder = CustomEncoder(layer_num, head_num, head_size, ft_weights, + int8_mode=args['int8_mode'], + remove_padding=False, + sparse=args['sparse'], + path=args['ths_path'], + tensor_para_size=args["tensor_para_size"], + pipeline_para_size=args["pipeline_para_size"]) custom_encoder = torch.jit.script(custom_encoder) - eff_custom_encoder = CustomEncoder(layer_num, head_num, head_size, weights, - int8_mode=args['int8_mode'], - remove_padding=True, - sparse=args['sparse'], - path=args['ths_path']) + eff_custom_encoder = CustomEncoder(layer_num, head_num, head_size, ft_weights, + int8_mode=args['int8_mode'], + remove_padding=True, + sparse=args['sparse'], + path=args['ths_path'], + tensor_para_size=args["tensor_para_size"], + pipeline_para_size=args["pipeline_para_size"]) eff_custom_encoder = torch.jit.script(eff_custom_encoder) with torch.no_grad(): output_mask = sequence_mask(mem_seq_lens, args['seq_len']).to(mask.dtype).unsqueeze(-1) - hf_output = hf_encoder(inp, mask)[0] * output_mask - print(hf_output) - print(hf_output.size()) + if rank == 0: + hf_output = hf_encoder(inp, mask)[0] * output_mask + print(hf_output) + print(hf_output.size()) - ft_output = custom_encoder(inp, mask, 
mem_seq_lens)[0] * output_mask - print(ft_output) - print(ft_output.size()) + ft_inp = inp.to(f"cuda:{rank}") + ft_mask = mask.to(f"cuda:{rank}") + ft_mem_seq_lens = mem_seq_lens.to(f"cuda:{rank}") + ft_output_mask = output_mask.to(f"cuda:{rank}") + ft_output = custom_encoder(ft_inp, ft_mask, ft_mem_seq_lens)[0] * ft_output_mask + if rank == 0: + print(ft_output) + print(ft_output.size()) - eff_ft_output = eff_custom_encoder(inp, mask, mem_seq_lens)[0] * output_mask - print(eff_ft_output) - print(eff_ft_output.size()) + eff_ft_output = eff_custom_encoder(ft_inp, ft_mask, ft_mem_seq_lens)[0] * ft_output_mask + if rank == 0: + print(eff_ft_output) + print(eff_ft_output.size()) - FT_diff = torch.abs(hf_output - ft_output) - print('FT Mean diff: {}'.format(torch.mean(FT_diff))) - print('FT Max diff: {}'.format(torch.max(FT_diff))) - print('FT Min diff: {}'.format(torch.min(FT_diff))) + if rank == 0: + FT_diff = torch.abs(hf_output - ft_output).float() # Prevent error under bfloat16 + print('FT Mean diff: {}'.format(torch.mean(FT_diff))) + print('FT Max diff: {}'.format(torch.max(FT_diff))) + print('FT Min diff: {}'.format(torch.min(FT_diff))) - EFF_diff = torch.abs(hf_output - eff_ft_output) - print('EFF-FT Mean diff: {}'.format(torch.mean(EFF_diff))) - print('EFF-FT Max diff: {}'.format(torch.max(EFF_diff))) - print('EFF-FT Min diff: {}'.format(torch.min(EFF_diff))) + if rank == 0: + EFF_diff = torch.abs(hf_output - eff_ft_output).float() # Prevent error under bfloat16 + print('EFF-FT Mean diff: {}'.format(torch.mean(EFF_diff))) + print('EFF-FT Max diff: {}'.format(torch.max(EFF_diff))) + print('EFF-FT Min diff: {}'.format(torch.min(EFF_diff))) if args['time']: iterations = 100 - for i in range(iterations): - output = hf_encoder(inp, mask) - t10 = timeit.default_timer() - # nvtx.range_push("hf") - for i in range(iterations): - # nvtx.range_push("hf"+str(i)) - output = hf_encoder(inp, mask) + if rank == 0: + for i in range(iterations): + output = hf_encoder(inp, mask) + t10 = timeit.default_timer() + # nvtx.range_push("hf") + for i in range(iterations): + # nvtx.range_push("hf"+str(i)) + output = hf_encoder(inp, mask) + # nvtx.range_pop() # nvtx.range_pop() - # nvtx.range_pop() - t1 = timeit.default_timer() - t10 - # time.sleep(60) + t1 = timeit.default_timer() - t10 + # time.sleep(60) for i in range(iterations): - output = custom_encoder(inp, mask, mem_seq_lens) + output = custom_encoder(ft_inp, ft_mask, ft_mem_seq_lens) t20 = timeit.default_timer() # nvtx.range_push("ext") for i in range(iterations): # nvtx.range_push("ext"+str(i)) - output = custom_encoder(inp, mask, mem_seq_lens) + output = custom_encoder(ft_inp, ft_mask, ft_mem_seq_lens) # nvtx.range_pop() # nvtx.range_pop() t2 = timeit.default_timer() - t20 # time.sleep(60) for i in range(iterations): - output = eff_custom_encoder(inp, mask, mem_seq_lens) + output = eff_custom_encoder(ft_inp, ft_mask, ft_mem_seq_lens) t30 = timeit.default_timer() # nvtx.range_push("eff_ext") for i in range(iterations): # nvtx.range_push("eff_ext"+str(i)) - output = eff_custom_encoder(inp, mask, mem_seq_lens) + output = eff_custom_encoder(ft_inp, ft_mask, ft_mem_seq_lens) # nvtx.range_pop() # nvtx.range_pop() t3 = timeit.default_timer() - t30 # time.sleep(60) - print("[INFO] HuggingFaceEnocder time costs: {:.2f} ms".format(t1*1000/iterations)) - print("[INFO] FasterTransformer time costs: {:.2f} ms".format(t2*1000/iterations)) - print("[INFO] EFF-FasterTransformer time costs: {:.2f} ms".format(t3*1000/iterations)) + + if rank == 0: + print("[INFO] 
HuggingFaceEnocder time costs: {:.2f} ms".format(t1 * 1000 / iterations)) + print("[INFO] FasterTransformer time costs: {:.2f} ms".format(t2 * 1000 / iterations)) + print("[INFO] EFF-FasterTransformer time costs: {:.2f} ms".format(t3 * 1000 / iterations)) if args['thread_num'] > 1: # Multi-threading demonstration + assert world_size == 1, "[ERROR] multi thread does not support MGMN" thread_list = [] thread_num = args['thread_num'] iterations = 100 + def run(): t40 = timeit.default_timer() for i in range(iterations): - output = custom_encoder(inp, mask, mem_seq_lens) + ft_output = custom_encoder(ft_inp, ft_mask, ft_mem_seq_lens)[0] * ft_output_mask t4 = timeit.default_timer() - t40 - diff = torch.abs(hf_output - ft_output) - print('FT Mean diff: {}'.format(torch.mean(diff))) - print('FT Max diff: {}'.format(torch.max(diff))) - print('FT Min diff: {}'.format(torch.min(diff))) - print("[INFO] batch_size {} max_seq_len {} {} layer FT-OP-time {:6.2f} ms with {} threads".format(batch_size, - seq_len, layer_num, t4, thread_num)) + if rank == 0: + diff = torch.abs(hf_output - ft_output) + print('FT Mean diff: {}'.format(torch.mean(diff))) + print('FT Max diff: {}'.format(torch.max(diff))) + print('FT Min diff: {}'.format(torch.min(diff))) + print("[INFO] batch_size {} max_seq_len {} {} layer FT-OP-time {:6.2f} ms with {} threads".format(batch_size, + seq_len, layer_num, t4, thread_num)) for i in range(thread_num): thread_list.append(threading.Thread(target=run, name="RunFT")) @@ -251,10 +303,17 @@ def run(): t.start() for t in thread_list: t.join() - + torch.cuda.empty_cache() sys.stdout.flush() - return max(torch.mean(FT_diff), torch.mean(EFF_diff)) + if rank == 0: + if (args["error_threshold"] != None): + assert max(torch.mean(FT_diff), torch.mean(EFF_diff)) < args["error_threshold"], "[ERROR] TEST FAIL!" 
+ print("[INFO] TEST PASS!") + + return max(torch.mean(FT_diff), torch.mean(EFF_diff)) + else: + return 0 if __name__ == '__main__': diff --git a/examples/pytorch/bert/run_glue.py b/examples/pytorch/bert/run_glue.py index bf258c20c..d4b4373b1 100644 --- a/examples/pytorch/bert/run_glue.py +++ b/examples/pytorch/bert/run_glue.py @@ -46,6 +46,14 @@ def set_seed(args): if args.n_gpu > 0: torch.cuda.manual_seed_all(args.seed) +def convert_type(tensor, data_type): + if data_type == 'fp16': + return tensor.half() + elif data_type == 'fp32': + return tensor.float() + elif data_type == 'bf16': + return tensor.bfloat16() + def evaluate(args, model, tokenizer, prefix=""): # Loop to handle MNLI double evaluation (matched, mis-matched) @@ -59,7 +67,8 @@ def evaluate(args, model, tokenizer, prefix=""): if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]: os.makedirs(eval_output_dir) - args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) + # args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) + args.eval_batch_size = 1 # Note that DistributedSampler samples randomly eval_sampler = SequentialSampler(eval_dataset) eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) @@ -83,19 +92,20 @@ def evaluate(args, model, tokenizer, prefix=""): batch = tuple(t.to(args.device) for t in batch) with torch.no_grad(): - inputs = [batch[0], batch[1].half() if args.data_type == 'fp16' else batch[1], batch[2]] + inputs = [batch[0], convert_type(batch[1], args.data_type), batch[2]] outputs = model(*inputs) # tmp_eval_loss, logits = outputs[:2] logits = outputs[0] # eval_loss += tmp_eval_loss.mean().item() + nb_eval_steps += 1 if preds is None: - preds = logits.detach().cpu().numpy() - out_label_ids = batch[3].detach().cpu().numpy() + preds = logits.detach().float().cpu().numpy() + out_label_ids = batch[3].detach().float().cpu().numpy() else: - preds = np.append(preds, logits.detach().cpu().numpy(), axis=0) - out_label_ids = np.append(out_label_ids, batch[3].detach().cpu().numpy(), axis=0) + preds = np.append(preds, logits.detach().float().cpu().numpy(), axis=0) + out_label_ids = np.append(out_label_ids, batch[3].detach().float().cpu().numpy(), axis=0) evalTime = timeit.default_timer() - start_time logger.info(" Evaluation for " + eval_task + " done in total %f secs (%f sec per example)", evalTime, evalTime / len(eval_dataset)) @@ -244,7 +254,7 @@ def main(): parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") parser.add_argument("--model_type", type=str, help="ori, ths, thsext") - parser.add_argument("--data_type", type=str, help="fp32, fp16") + parser.add_argument('--data_type', type=str, choices=['fp32', 'fp16', 'bf16'], default='fp32') parser.add_argument('--ths_path', type=str, default='./lib/libth_bert.so', help='path of the pyt_fastertransformer dynamic lib file') parser.add_argument('--remove_padding', action='store_true', @@ -325,6 +335,9 @@ def main(): if args.data_type == 'fp16': logger.info("Use fp16") model.half() + elif args.data_type == 'bf16': + logger.info("Use bf16") + model.bfloat16() if args.model_type == 'thsext': logger.info("Use custom BERT encoder for TorchScript") from utils.encoder import EncoderWeights, CustomEncoder @@ -334,6 +347,8 @@ def main(): weights.to_cuda() if args.data_type == 'fp16': weights.to_half() + elif args.data_type == 'bf16': + weights.to_bfloat16() enc = CustomEncoder(model.config.num_hidden_layers, 
model.config.num_attention_heads, model.config.hidden_size//model.config.num_attention_heads, @@ -351,6 +366,8 @@ def main(): fake_type_id = fake_input_id.clone().detach() if args.data_type == 'fp16': fake_mask = fake_mask.half() + elif args.data_type == 'bf16': + fake_mask = fake_mask.bfloat16() model.eval() with torch.no_grad(): model_ = torch.jit.trace(model, (fake_input_id, fake_mask, fake_type_id)) diff --git a/examples/pytorch/bert/run_squad.py b/examples/pytorch/bert/run_squad.py index 8b3b38aeb..294582503 100644 --- a/examples/pytorch/bert/run_squad.py +++ b/examples/pytorch/bert/run_squad.py @@ -355,7 +355,7 @@ def main(): parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus") parser.add_argument("--threads", type=int, default=1, help="multiple threads for converting example to features") parser.add_argument("--model_type", type=str, help="ori, ths, thsext") - parser.add_argument("--data_type", type=str, help="fp32, fp16") + parser.add_argument("--data_type", type=str, help="fp32, fp16, bf16") parser.add_argument('--ths_path', type=str, default='./lib/libth_bert.so', help='path of the pyt_fastertransformer dynamic lib file') parser.add_argument('--int8_mode', type=int, default=0, metavar='NUMBER', @@ -442,6 +442,9 @@ def main(): elif args.data_type == 'fp16': logger.info("Use fp16") model.half() + elif args.data_type == 'bf16': + logger.info("Use bf16") + model.bfloat16() if args.sparse: logger.info("Sparse mode") if args.model_type == 'thsext': @@ -454,6 +457,8 @@ def main(): weights.to_int8(args.sparse, args.ths_path) elif args.data_type == 'fp16': weights.to_half() + elif args.data_type == 'bf16': + weights.to_bfloat16() weights.to_cuda() enc = CustomEncoder(model.config.num_hidden_layers, model.config.num_attention_heads, @@ -474,6 +479,8 @@ def main(): fake_type_id = fake_input_id.clone().detach() if args.data_type == 'fp16': fake_mask = fake_mask.half() + elif args.data_type == 'bf16': + fake_mask = fake_mask.bfloat16() model.eval() model_ = torch.jit.trace(model, (fake_input_id, fake_mask, fake_type_id)) model = model_ diff --git a/examples/pytorch/bert/scripts/run_mrpc.sh b/examples/pytorch/bert/scripts/run_mrpc.sh index f7a094701..3f5c4a960 100644 --- a/examples/pytorch/bert/scripts/run_mrpc.sh +++ b/examples/pytorch/bert/scripts/run_mrpc.sh @@ -19,7 +19,7 @@ if [ "$1" != "ori" ] && [ "$1" != "ths" ] && [ "$1" != "thsext" ]; then echo "[Usage]: bash PATH_TO_THIS_SCRIPT model_type[ori, ths, thsext] data_type[fp32, fp16]" exit 1 fi -if [ "$2" != "fp32" ] && [ "$2" != "fp16" ]; then +if [ "$2" != "fp32" ] && [ "$2" != "fp16" ] && [ "$2" != "bf16" ]; then echo "wrong data type" echo "[Usage]: bash PATH_TO_THIS_SCRIPT model_type[ori, ext] data_type[fp32, fp16]" exit 1 @@ -57,8 +57,10 @@ cd $MAIN_PATH if [ "$1" == "thsext" ]; then if [ "$2" == "fp32" ]; then $MAIN_PATH/bin/bert_gemm ${batch_size} ${seq_len} 12 64 0 0 - else + elif [ "$2" == "fp16" ]; then $MAIN_PATH/bin/bert_gemm ${batch_size} ${seq_len} 12 64 1 0 + else + $MAIN_PATH/bin/bert_gemm ${batch_size} ${seq_len} 12 64 2 0 fi fi diff --git a/examples/pytorch/bert/scripts/run_squad.sh b/examples/pytorch/bert/scripts/run_squad.sh index 5edb38011..ff33382e0 100644 --- a/examples/pytorch/bert/scripts/run_squad.sh +++ b/examples/pytorch/bert/scripts/run_squad.sh @@ -57,16 +57,18 @@ if [ "$MODEL_TYPE" != "ori" ] && [ "$MODEL_TYPE" != "ths" ] && [ "$MODEL_TYPE" ! 
echo "wrong model type, need be one of [ori, ths, thsext]" exit 1 fi -if [ "$DATA_TYPE" != "fp32" ] && [ "$DATA_TYPE" != "fp16" ] && [ "$DATA_TYPE" != "int8_1" ] && [ "$DATA_TYPE" != "int8_2" ] && [ "$DATA_TYPE" != "int8_3" ]; then +if [ "$DATA_TYPE" != "fp32" ] && [ "$DATA_TYPE" != "fp16" ] && [ "$DATA_TYPE" != "bf16" ] && [ "$DATA_TYPE" != "int8_1" ] && [ "$DATA_TYPE" != "int8_2" ] && [ "$DATA_TYPE" != "int8_3" ]; then echo "wrong data type, need be one of [fp32, fp16, int8_1, int8_2, int8_3]" exit 1 fi -if [ "$DATA_TYPE" == "fp32" ] || [ "$DATA_TYPE" == "fp16" ]; then +if [ "$DATA_TYPE" == "fp32" ] || [ "$DATA_TYPE" == "fp16" ] || [ "$DATA_TYPE" == "bf16" ]; then if [ "$DATA_TYPE" == "fp32" ]; then - FP16_MODE=0 + DATA_TYPE_ID=0 + elif [ "$DATA_TYPE" == "fp16" ]; then + DATA_TYPE_ID=1 else - FP16_MODE=1 + DATA_TYPE_ID=2 ## bf16 fi INT8_MODE=0 if [ "$MODEL_PATH" == "" ]; then @@ -90,7 +92,7 @@ if [ "$DATA_TYPE" == "fp32" ] || [ "$DATA_TYPE" == "fp16" ]; then HEAD_SIZE=64 fi else - FP16_MODE=1 + DATA_TYPE_ID=1 if [ "$DATA_TYPE" == "int8_1" ]; then INT8_MODE=1 elif [ "$DATA_TYPE" == "int8_2" ]; then @@ -110,7 +112,7 @@ cd $MAIN_PATH/pytorch/bert_squad/squad_data wget -c https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json if [ "$MODEL_TYPE" == "thsext" ]; then - $MAIN_PATH/bin/bert_gemm ${BATCH_SIZE} ${SEQ_LEN} ${HEAD_NUM} ${HEAD_SIZE} ${FP16_MODE} ${INT8_MODE} + $MAIN_PATH/bin/bert_gemm ${BATCH_SIZE} ${SEQ_LEN} ${HEAD_NUM} ${HEAD_SIZE} ${DATA_TYPE_ID} ${INT8_MODE} fi SPCMD="" diff --git a/examples/pytorch/bert/utils/encoder.py b/examples/pytorch/bert/utils/encoder.py index 886f28426..237447f7e 100644 --- a/examples/pytorch/bert/utils/encoder.py +++ b/examples/pytorch/bert/utils/encoder.py @@ -16,18 +16,45 @@ import sys import torch +import torch.distributed as dist from transformers import BertConfig from transformers.modeling_bert import BertEncoder from .checkpoint_quantization import checkpoint_quantization class EncoderWeights(object): - def __init__(self, layer_num, hidden_dim, weights=None, sparse=False): + def __init__(self, layer_num, hidden_dim, weights=None, sparse=False, tensor_para_size=1, pipeline_para_size=1): """weights need be a state_dict of bert model""" self.layer_num = layer_num self.int8 = False self.hidden_dim = hidden_dim self.weights = {} + self.tensor_para_size = tensor_para_size + self.pipeline_para_size = pipeline_para_size + + self.use_mpi = dist.is_mpi_available() + + if self.use_mpi: + try: + dist.init_process_group(backend='mpi') + except: + print("[INFO] WARNING: Exception occurred in dist.init_process_group(backend='mpi'). Maybe the process group has been initialized somewhere else.") + else: + print("[INFO] MPI is not available in this PyTorch build.") + assert tensor_para_size == 1, "[FATAL] MPI is required for tensor_para_size > 1." + assert pipeline_para_size == 1, "[FATAL] MPI is required for pipeline_para_size > 1." 
+ + self.rank = dist.get_rank() if self.use_mpi else 0 + self.device_count = torch.cuda.device_count() + self.device = self.rank % self.device_count + torch.cuda.set_device(self.device) + + world_size = dist.get_world_size() if self.use_mpi else 1 + self.tensor_para_rank = self.rank % self.tensor_para_size + self.pipeline_para_rank = self.rank // self.tensor_para_size + start_layer = self.pipeline_para_rank * self.layer_num // self.pipeline_para_size + end_layer = (self.pipeline_para_rank + 1) * self.layer_num // self.pipeline_para_size + if weights is None: self._generated_weights = True for i in range(layer_num): @@ -72,42 +99,98 @@ def __init__(self, layer_num, hidden_dim, weights=None, sparse=False): def listed_weights(self): ret = [] + start_layer = self.pipeline_para_rank * self.layer_num // self.pipeline_para_size + end_layer = (self.pipeline_para_rank + 1) * self.layer_num // self.pipeline_para_size if not self.int8: - ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + 'attention.self.query.weight'].transpose(-1, -2) for layer_idx in range(self.layer_num)], 0).contiguous()) # 0 - ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + 'attention.self.query.bias'] for layer_idx in range(self.layer_num)], 0).contiguous()) - ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + 'attention.self.key.weight'].transpose(-1, -2) for layer_idx in range(self.layer_num)], 0).contiguous()) # 2 - ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + 'attention.self.key.bias'] for layer_idx in range(self.layer_num)], 0).contiguous()) - ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + 'attention.self.value.weight'].transpose(-1, -2) for layer_idx in range(self.layer_num)], 0).contiguous()) # 4 - ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + 'attention.self.value.bias'] for layer_idx in range(self.layer_num)], 0).contiguous()) - ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + 'attention.output.dense.weight'].transpose(-1, -2) for layer_idx in range(self.layer_num)], 0).contiguous()) # 6 - ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + 'attention.output.dense.bias'] for layer_idx in range(self.layer_num)], 0).contiguous()) - ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + 'attention.output.LayerNorm.weight'] for layer_idx in range(self.layer_num)], 0).contiguous()) - ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + 'attention.output.LayerNorm.bias'] for layer_idx in range(self.layer_num)], 0).contiguous()) - ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + 'intermediate.dense.weight'].transpose(-1, -2) for layer_idx in range(self.layer_num)], 0).contiguous()) # 10 - ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + 'intermediate.dense.bias'] for layer_idx in range(self.layer_num)], 0).contiguous()) - ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + 'output.dense.weight'].transpose(-1, -2) for layer_idx in range(self.layer_num)], 0).contiguous()) # 12 - ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' 
+ 'output.dense.bias'] for layer_idx in range(self.layer_num)], 0).contiguous()) - ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + 'output.LayerNorm.weight'] for layer_idx in range(self.layer_num)], 0).contiguous()) - ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + 'output.LayerNorm.bias'] for layer_idx in range(self.layer_num)], 0).contiguous()) + ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + 'attention.self.query.weight'].transpose(-1, -2) + for layer_idx in range(start_layer, end_layer)], 0).contiguous()) # 0 + ret[-1] = ret[-1].split(ret[-1].shape[-1] // self.tensor_para_size, + dim=-1)[self.tensor_para_rank].contiguous() + ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + + 'attention.self.query.bias'] for layer_idx in range(start_layer, end_layer)], 0).contiguous()) + ret[-1] = ret[-1].split(ret[-1].shape[-1] // self.tensor_para_size, + dim=-1)[self.tensor_para_rank].contiguous() + ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + 'attention.self.key.weight'].transpose(-1, -2) + for layer_idx in range(start_layer, end_layer)], 0).contiguous()) # 2 + ret[-1] = ret[-1].split(ret[-1].shape[-1] // self.tensor_para_size, + dim=-1)[self.tensor_para_rank].contiguous() + ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + + 'attention.self.key.bias'] for layer_idx in range(start_layer, end_layer)], 0).contiguous()) + ret[-1] = ret[-1].split(ret[-1].shape[-1] // self.tensor_para_size, + dim=-1)[self.tensor_para_rank].contiguous() + ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + 'attention.self.value.weight'].transpose(-1, -2) + for layer_idx in range(start_layer, end_layer)], 0).contiguous()) # 4 + ret[-1] = ret[-1].split(ret[-1].shape[-1] // self.tensor_para_size, + dim=-1)[self.tensor_para_rank].contiguous() + ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + + 'attention.self.value.bias'] for layer_idx in range(start_layer, end_layer)], 0).contiguous()) + ret[-1] = ret[-1].split(ret[-1].shape[-1] // self.tensor_para_size, + dim=-1)[self.tensor_para_rank].contiguous() + ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + 'attention.output.dense.weight'].transpose(-1, -2) + for layer_idx in range(start_layer, end_layer)], 0).contiguous()) # 6 + ret[-1] = ret[-1].split(ret[-1].shape[1] // self.tensor_para_size, + dim=1)[self.tensor_para_rank].contiguous() + ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + + 'attention.output.dense.bias'] for layer_idx in range(start_layer, end_layer)], 0).contiguous()) + ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + + 'attention.output.LayerNorm.weight'] for layer_idx in range(start_layer, end_layer)], 0).contiguous()) + ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + + 'attention.output.LayerNorm.bias'] for layer_idx in range(start_layer, end_layer)], 0).contiguous()) + ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + 'intermediate.dense.weight'].transpose(-1, -2) + for layer_idx in range(start_layer, end_layer)], 0).contiguous()) # 10 + ret[-1] = ret[-1].split(ret[-1].shape[-1] // self.tensor_para_size, + dim=-1)[self.tensor_para_rank].contiguous() + ret.append(torch.stack([self.weights['bert.encoder.layer.' 
+ str(layer_idx) + '.' + + 'intermediate.dense.bias'] for layer_idx in range(start_layer, end_layer)], 0).contiguous()) + ret[-1] = ret[-1].split(ret[-1].shape[-1] // self.tensor_para_size, + dim=-1)[self.tensor_para_rank].contiguous() + ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + 'output.dense.weight'].transpose(-1, -2) + for layer_idx in range(start_layer, end_layer)], 0).contiguous()) # 12 + ret[-1] = ret[-1].split(ret[-1].shape[1] // self.tensor_para_size, + dim=1)[self.tensor_para_rank].contiguous() + ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + + 'output.dense.bias'] for layer_idx in range(start_layer, end_layer)], 0).contiguous()) + ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + + 'output.LayerNorm.weight'] for layer_idx in range(start_layer, end_layer)], 0).contiguous()) + ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + + 'output.LayerNorm.bias'] for layer_idx in range(start_layer, end_layer)], 0).contiguous()) else: - ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + 'attention.self.query.weight'] for layer_idx in range(self.layer_num)], 0).contiguous()) # 0 - ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + 'attention.self.query.bias'] for layer_idx in range(self.layer_num)], 0).contiguous()) - ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + 'attention.self.key.weight'] for layer_idx in range(self.layer_num)], 0).contiguous()) # 2 - ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + 'attention.self.key.bias'] for layer_idx in range(self.layer_num)], 0).contiguous()) - ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + 'attention.self.value.weight'] for layer_idx in range(self.layer_num)], 0).contiguous()) # 4 - ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + 'attention.self.value.bias'] for layer_idx in range(self.layer_num)], 0).contiguous()) - ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + 'attention.output.dense.weight'] for layer_idx in range(self.layer_num)], 0).contiguous()) # 6 - ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + 'attention.output.dense.bias'] for layer_idx in range(self.layer_num)], 0).contiguous()) - ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + 'attention.output.LayerNorm.weight'] for layer_idx in range(self.layer_num)], 0).contiguous()) - ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + 'attention.output.LayerNorm.bias'] for layer_idx in range(self.layer_num)], 0).contiguous()) - ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + 'intermediate.dense.weight'] for layer_idx in range(self.layer_num)], 0).contiguous()) # 10 - ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + 'intermediate.dense.bias'] for layer_idx in range(self.layer_num)], 0).contiguous()) - ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + 'output.dense.weight'] for layer_idx in range(self.layer_num)], 0).contiguous()) # 12 - ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' 
+ 'output.dense.bias'] for layer_idx in range(self.layer_num)], 0).contiguous()) - ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + 'output.LayerNorm.weight'] for layer_idx in range(self.layer_num)], 0).contiguous()) - ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + 'output.LayerNorm.bias'] for layer_idx in range(self.layer_num)], 0).contiguous()) - ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + 'amaxList'] for layer_idx in range(self.layer_num)], 0).contiguous()) - ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + 'h_amaxList'] for layer_idx in range(self.layer_num)], 0).contiguous()) + ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + + 'attention.self.query.weight'] for layer_idx in range(self.layer_num)], 0).contiguous()) # 0 + ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + + 'attention.self.query.bias'] for layer_idx in range(self.layer_num)], 0).contiguous()) + ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + 'attention.self.key.weight'] + for layer_idx in range(self.layer_num)], 0).contiguous()) # 2 + ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + + 'attention.self.key.bias'] for layer_idx in range(self.layer_num)], 0).contiguous()) + ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + + 'attention.self.value.weight'] for layer_idx in range(self.layer_num)], 0).contiguous()) # 4 + ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + + 'attention.self.value.bias'] for layer_idx in range(self.layer_num)], 0).contiguous()) + ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + + 'attention.output.dense.weight'] for layer_idx in range(self.layer_num)], 0).contiguous()) # 6 + ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + + 'attention.output.dense.bias'] for layer_idx in range(self.layer_num)], 0).contiguous()) + ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + + 'attention.output.LayerNorm.weight'] for layer_idx in range(self.layer_num)], 0).contiguous()) + ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + + 'attention.output.LayerNorm.bias'] for layer_idx in range(self.layer_num)], 0).contiguous()) + ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + 'intermediate.dense.weight'] + for layer_idx in range(self.layer_num)], 0).contiguous()) # 10 + ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + + 'intermediate.dense.bias'] for layer_idx in range(self.layer_num)], 0).contiguous()) + ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + 'output.dense.weight'] + for layer_idx in range(self.layer_num)], 0).contiguous()) # 12 + ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + + 'output.dense.bias'] for layer_idx in range(self.layer_num)], 0).contiguous()) + ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + + 'output.LayerNorm.weight'] for layer_idx in range(self.layer_num)], 0).contiguous()) + ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' 
+ + 'output.LayerNorm.bias'] for layer_idx in range(self.layer_num)], 0).contiguous()) + ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + 'amaxList'] + for layer_idx in range(self.layer_num)], 0).contiguous()) + ret.append(torch.stack([self.weights['bert.encoder.layer.' + str(layer_idx) + '.' + + 'h_amaxList'] for layer_idx in range(self.layer_num)], 0).contiguous()) return ret def to_cuda(self): @@ -120,7 +203,7 @@ def to_cuda(self): if "amaxList" in k: k_h = k.replace("amaxList", "h_amaxList") h_scale_list[k_h] = v - self.weights[k] = v.cuda() + self.weights[k] = v.cuda() for k, v in h_scale_list.items(): self.weights[k] = v @@ -130,6 +213,12 @@ def to_half(self): for k, v in self.weights.items(): self.weights[k] = v.half() + def to_bfloat16(self): + if self.int8: + raise RuntimeError("Cannot cast to bfloat16 if the weights have been casted to int8.") + for k, v in self.weights.items(): + self.weights[k] = v.bfloat16() + def to_int8(self, sparse=False, ths_path='./lib/libth_bert.so'): if self._generated_weights: amax_tensor_1 = torch.Tensor(self.hidden_dim).fill_(127.) @@ -180,8 +269,9 @@ def to_int8(self, sparse=False, ths_path='./lib/libth_bert.so'): class CustomEncoder(torch.nn.Module): def __init__(self, layer_num, head_num, head_size, weights, - int8_mode=0, remove_padding=False,sparse=False, - path='./lib/libth_bert.so'): + int8_mode=0, remove_padding=False, sparse=False, + path='./lib/libth_bert.so', tensor_para_size=1, + pipeline_para_size=1): super().__init__() self.layer_num = layer_num self.remove_padding = remove_padding @@ -189,28 +279,45 @@ def __init__(self, layer_num, head_num, head_size, weights, torch.classes.load_library(path) weights_ = weights.listed_weights() + + self.use_mpi = dist.is_mpi_available() + + if self.use_mpi: + try: + dist.init_process_group(backend='mpi') + except: + print("[INFO] WARNING: Exception occurred in dist.init_process_group(backend='mpi'). Maybe the process group has been initialized somewhere else.") + else: + print("[INFO] MPI is not available in this PyTorch build.") + assert tensor_para_size == 1, "[FATAL] MPI is required for tensor_para_size > 1." + assert pipeline_para_size == 1, "[FATAL] MPI is required for pipeline_para_size > 1." 
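# In the constructor branches below, the FP32/FP16/BF16 path passes tensor_para_size and pipeline_para_size through to the FasterTransformer Bert TorchScript class, while the INT8 path keeps the old argument list and asserts that both parallel sizes are still 1.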
+ if int8_mode == 0: assert len(weights_) == 16 try: self.encoders = torch.classes.FasterTransformer.Bert( - *weights_, - head_num, head_size, 4 * head_num * head_size, remove_padding, layer_num, sparse, 1.0) + *weights_, + head_num, head_size, 4 * head_num * head_size, remove_padding, layer_num, sparse, 1.0, + tensor_para_size, pipeline_para_size) except: # legacy ths for 20.03 image self.encoders = torch.classes.FasterTransformerBert( - *weights_, - head_num, head_size, 4 * head_num * head_size, remove_padding, layer_num, sparse, 1.0) + *weights_, + head_num, head_size, 4 * head_num * head_size, remove_padding, layer_num, sparse, 1.0, + tensor_para_size, pipeline_para_size) else: assert len(weights_) == 18 + assert tensor_para_size == 1, "INT8 BERT still only support tensor_para_size = 1" + assert pipeline_para_size == 1, "INT8 BERT still only support pipeline_para_size = 1" try: self.encoders = torch.classes.FasterTransformer.INT8Bert( - *weights_, - head_num, head_size, remove_padding, layer_num, int8_mode, sparse, 1.0) + *weights_, + head_num, head_size, remove_padding, layer_num, int8_mode, sparse, 1.0) except: # legacy ths for 20.03 image self.encoders = torch.classes.FasterTransformerINT8Bert( - *weights_, - head_num, head_size, remove_padding, layer_num, int8_mode, sparse, 1.0) + *weights_, + head_num, head_size, remove_padding, layer_num, int8_mode, sparse, 1.0) def forward(self, hidden_states, attention_mask, sequence_lengths): hidden_states = self.encoders.forward(hidden_states, sequence_lengths) @@ -223,7 +330,8 @@ def __init__(self, layer_num, head_num, head_size, weights=None): hidden_dim = head_num * head_size # TODO(bhsueh) The implementation of hidden_act='gelu' is different to FT's (and google BERT) implementation # FT's implementation is equivalent to hidden_act='gelu_new', but there are some issues for int8 sparse under gelu_new - conf = BertConfig(hidden_size=hidden_dim, intermediate_size=4*hidden_dim, num_attention_heads=head_num, num_hidden_layers=layer_num, hidden_act='gelu') + conf = BertConfig(hidden_size=hidden_dim, intermediate_size=4 * hidden_dim, + num_attention_heads=head_num, num_hidden_layers=layer_num, hidden_act='gelu') self.encoder = BertEncoder(conf) w = {} for k, v in weights.weights.items(): diff --git a/examples/pytorch/bert/utils/huggingface_bert_convert.py b/examples/pytorch/bert/utils/huggingface_bert_convert.py new file mode 100644 index 000000000..f8c2537fd --- /dev/null +++ b/examples/pytorch/bert/utils/huggingface_bert_convert.py @@ -0,0 +1,149 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +''' +Convert huggingface bert model. Use https://huggingface.co/bert-base-uncased as demo. 
+''' + +import argparse +import configparser +import multiprocessing +import numpy as np +import pathlib +import torch +import os +import sys + +# __root_package_path__ = pathlib.Path(__file__).parent.parent.parent.parent.parent.absolute().as_posix() +# if __root_package_path__ not in sys.path: +# print( +# f"[ERROR] add project root directory to your PYTHONPATH with " +# f"'export PYTHONPATH={__root_package_path__}:${{PYTHONPATH}}'" +# ) + +dir_path = os.path.dirname(os.path.realpath(__file__)) +sys.path.append(dir_path + "/../../../..") +sys.path.append(dir_path) +from examples.pytorch.utils import torch2np, safe_transpose, WEIGHT2DTYPE + +from transformers import BertModel # transformers-4.10.0-py3 + +def split_and_convert_process(i, saved_dir,factor,key, args, val): + + if key.find("attention.output.dense.bias") != -1 or \ + key.find("attention.output.LayerNorm.weight") != -1 or \ + key.find("attention.output.LayerNorm.bias") != -1 or \ + key.find("output.dense.bias") != -1 or \ + key.find("output.LayerNorm.weight") != -1 or \ + key.find("output.LayerNorm.bias") != -1 : + + # shared weights, only need to convert the weights of rank 0 + if i == 0: + saved_path = saved_dir + "/model." + key + ".bin" + val.tofile(saved_path) + + elif key.find("attention.output.dense.weight") != -1 or key.find("output.dense.weight") != -1: + split_vals = np.split(val, factor, axis=0) + for j in range(factor): + saved_path = f"{saved_dir}/model.{key}.{i * factor + j}.bin" + split_vals[j].tofile(saved_path) + + elif key.find("attention.self.query.weight") != -1 or \ + key.find("attention.self.query.bias") != -1 or \ + key.find("attention.self.key.weight") != -1 or \ + key.find("attention.self.key.bias") != -1 or \ + key.find("attention.self.value.weight") != -1 or \ + key.find("attention.self.value.bias") != -1 or \ + key.find("intermediate.dense.weight") != -1 or \ + key.find("intermediate.dense.bias") != -1: + + split_vals = np.split(val, factor, axis=-1) + for j in range(factor): + saved_path = saved_dir + "/model." 
+ key + ".%d.bin" % (i * factor + j) + split_vals[j].tofile(saved_path) + + else: + print("[WARNING] cannot convert key '{}'".format(key)) + +def split_and_convert(args): + saved_dir = args.saved_dir + "/%d-gpu/" % args.infer_tensor_para_size + + if(os.path.exists(saved_dir) == False): + os.makedirs(saved_dir) + ckpt_name = args.in_file + + t_gpu_num = args.training_tensor_para_size + i_gpu_num = args.infer_tensor_para_size + assert(i_gpu_num % t_gpu_num == 0) + + factor = (int)(i_gpu_num / t_gpu_num) + + # load position_embedding from rank 0 + torch_device = 'cuda' if torch.cuda.is_available() else 'cpu' + model = BertModel.from_pretrained(args.in_file).to(torch_device) + np_weight_data_type = WEIGHT2DTYPE[args.weight_data_type] + + hf_config = vars(model.config) + + # NOTE: save parameters to config files (loaded by triton backends) + config = configparser.ConfigParser() + config["bert"] = {} + try: + config["bert"]["model_name"] = "bert" if hf_config["model_type"] == '' else hf_config["model_type"] + config["bert"]["position_embedding_type"] = str(hf_config["position_embedding_type"]) + config["bert"]["hidden_size"] = str(hf_config["hidden_size"]) + config["bert"]["num_layer"] = str(hf_config["num_hidden_layers"]) + config["bert"]["head_num"] = str(hf_config["num_attention_heads"]) + config["bert"]["size_per_head"] = str(hf_config["hidden_size"] // hf_config["num_attention_heads"]) + config["bert"]["activation_type"] = str(hf_config["hidden_act"]) + config["bert"]["inter_size"] = str(hf_config["intermediate_size"]) + config["bert"]["max_position_embeddings"] = str(hf_config["max_position_embeddings"]) + config["bert"]["layer_norm_eps"] = str(hf_config["layer_norm_eps"]) + config["bert"]["weight_data_type"] = args.weight_data_type + config["bert"]["tensor_para_size"] = str(args.infer_tensor_para_size) + with open(saved_dir + "/config.ini", 'w') as configfile: + config.write(configfile) + except: + print(f"Fail to save the config in config.ini.") + + torch.multiprocessing.set_start_method("spawn") + torch.multiprocessing.set_sharing_strategy("file_system") + pool = multiprocessing.Pool(args.processes) + for name, param in model.named_parameters(): + if name.find("weight") == -1 and name.find("bias") == -1: + continue + else: + pool.starmap(split_and_convert_process, + [(0, saved_dir, factor, name, args, + torch2np(safe_transpose(param.detach()), np_weight_data_type))], ) + + pool.close() + pool.join() + +if __name__ == "__main__": + parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) + parser.add_argument('-saved_dir', '-o', type=str, help='file name of output file', required=True) + parser.add_argument('-in_file', '-i', type=str, help='file name of input checkpoint file', required=True) + parser.add_argument('-training_tensor_para_size', '-t_g', type=int, help='The size of tensor parallelism for training.', default=1) + parser.add_argument('-infer_tensor_para_size', '-i_g', type=int, help='The size of tensor parallelism for inference.', required=True) + parser.add_argument("-processes", "-p", type=int, help="How many processes to spawn for conversion (default: 4)", default=4) + parser.add_argument("-weight_data_type", type=str, default="fp32", choices=["fp32", "fp16"]) + + args = parser.parse_args() + print("\n=============== Argument ===============") + for key in vars(args): + print("{}: {}".format(key, vars(args)[key])) + print("========================================") + + split_and_convert(args) diff --git 
a/examples/pytorch/bert/utils/update_bert_config.py b/examples/pytorch/bert/utils/update_bert_config.py new file mode 100644 index 000000000..cf3df6a67 --- /dev/null +++ b/examples/pytorch/bert/utils/update_bert_config.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import configparser +import pathlib + + +def main(): + parser = argparse.ArgumentParser( + description="Script updating GPT config.ini hyper-parameters and requests parameters" + ) + # config.ini path + parser.add_argument("--config-ini-path", required=True, help="Path to config.ini file to be updated") + + # FT hyperparameters + parser.add_argument("--model-dir", type=str, required=True, help="Model path prefix") + parser.add_argument("--tensor-para-size", type=int, required=True, help="tensor parallelism size") + parser.add_argument("--pipeline-para-size", type=int, required=True, help="layer parallelism size") + parser.add_argument("--data-type", type=str, default="fp32", help="data type", choices=["fp32", "fp16", "bf16"]) + # request + parser.add_argument("--request-batch-size", type=int, default=8, help="batch size") + parser.add_argument("--request-seq-len", type=int, default=32, help="output length") + + args = parser.parse_args() + + config_path = pathlib.Path(args.config_ini_path) + + config = configparser.ConfigParser() + config.read(config_path) + + config["ft_instance_hyperparameter"] = { + "tensor_para_size": args.tensor_para_size, + "pipeline_para_size": args.pipeline_para_size, + "data_type": args.data_type, + "is_sparse": 0, + "is_remove_padding": 1, + "int8_mode": 0, + "enable_custom_all_reduce": 0, + "model_name": "bert_base", + "model_dir": args.model_dir, + } + + config["request"] = { + "request_batch_size": args.request_batch_size, + "request_seq_len": args.request_seq_len, + } + + with config_path.open("w") as config_file: + config.write(config_file) + +if __name__ == "__main__": + main() diff --git a/examples/pytorch/decoder/decoder_example.py b/examples/pytorch/decoder/decoder_example.py index 3278a0d8b..b62b19f71 100644 --- a/examples/pytorch/decoder/decoder_example.py +++ b/examples/pytorch/decoder/decoder_example.py @@ -50,7 +50,7 @@ def main(): parser.add_argument('--ths_path', type=str, default='./lib/libpyt_fastertransformer.so', help='path of the pyt_fastertransformer dynamic lib file') parser.add_argument('-d', '--data_type', type=str, default="fp32", metavar='STRING', - help='data type (default: fp32)', choices=['fp32', 'fp16']) + help='data type (default: fp32)', choices=['fp32', 'fp16', 'bf16']) args = parser.parse_args() @@ -101,6 +101,9 @@ def main(): if args.data_type == 'fp16': weights.to_half() ft_weights.to_half() + elif args.data_type == 'bf16': + weights.to_bfloat16() + ft_weights.to_bfloat16() custom_decoder = FTDecoder(args.head_num, args.head_size, hidden_dim, args.layer_num, ft_weights, args) with torch.no_grad(): diff --git 
a/examples/pytorch/decoder/utils/ft_decoder.py b/examples/pytorch/decoder/utils/ft_decoder.py index 35e79c3df..392e73ada 100644 --- a/examples/pytorch/decoder/utils/ft_decoder.py +++ b/examples/pytorch/decoder/utils/ft_decoder.py @@ -112,6 +112,10 @@ def to_half(self): for i in range(len(self.w)): self.w[i] = self.w[i].half() + def to_bfloat16(self): + for i in range(len(self.w)): + self.w[i] = self.w[i].bfloat16() + class FTDecoder(nn.Module): def __init__(self, head_num, head_size, mem_hidden_dim, layer_num, weights, args): super().__init__() diff --git a/examples/pytorch/decoding/decoding_example.py b/examples/pytorch/decoding/decoding_example.py index e64e1cea6..c649a92f7 100644 --- a/examples/pytorch/decoding/decoding_example.py +++ b/examples/pytorch/decoding/decoding_example.py @@ -49,8 +49,7 @@ def main(): help='beam size') parser.add_argument('vocab_size', type=int, help='vocab size') - parser.add_argument('--fp16', action='store_true', - help='is fp16') + parser.add_argument('--data_type', type=str, choices=['fp32', 'fp16', 'bf16'], default='fp32') parser.add_argument('--time', action='store_true', help='test the time or not.') parser.add_argument('--use_pretrained', action='store_true', @@ -95,13 +94,15 @@ def main(): print("{}: {}".format(key, vars(args)[key])) print("========================================") - decodingargs1 = ArgHelper('torch_decoding', 'fp16' if args.fp16 else 'fp32', os.path.abspath(args.decoder_ths_path), os.path.abspath(args.decoding_ths_path)) - decodingargs2 = ArgHelper('torch_decoding_with_decoder_ext', 'fp16' if args.fp16 else 'fp32', os.path.abspath(args.decoder_ths_path), os.path.abspath(args.decoding_ths_path)) + decodingargs1 = ArgHelper('torch_decoding', args.data_type, os.path.abspath(args.decoder_ths_path), os.path.abspath(args.decoding_ths_path)) + decodingargs2 = ArgHelper('torch_decoding_with_decoder_ext', args.data_type, os.path.abspath(args.decoder_ths_path), os.path.abspath(args.decoding_ths_path)) mem = torch.empty(args.batch_size, args.seq_len, args.memory_hidden_dim).cuda() torch.nn.init.uniform_(mem, -1, 1) - if args.fp16: + if args.data_type == "fp16": mem = mem.half() + elif args.data_type == "bf16": + mem = mem.bfloat16() mem_seq_lens = torch.randint(1, args.seq_len+1, (args.batch_size,), dtype=torch.int32).cuda() if args.use_pretrained: @@ -124,14 +125,19 @@ def fix_key(s): torch_decoding_with_decoder_ext = TorchDecoding(layer_num, head_num, head_size, vocab_size, start_id, end_id, weights, args=decodingargs2) torch_decoding.cuda() torch_decoding_with_decoder_ext.cuda() - if args.fp16: + if args.data_type == "fp16": torch_decoding.half() torch_decoding_with_decoder_ext.half() + elif args.data_type == "bf16": + torch_decoding.bfloat16() + torch_decoding_with_decoder_ext.bfloat16() torch_decoding.eval() torch_decoding_with_decoder_ext.eval() ft_weights.to_cuda() - if args.fp16: + if args.data_type == "fp16": ft_weights.to_half() + elif args.data_type == "bf16": + ft_weights.to_bfloat16() custom_decoding = CustomDecoding(head_num, head_size, inter_size, args.memory_hidden_dim, layer_num, vocab_size, start_id, end_id, args.beam_search_diversity_rate, diff --git a/examples/pytorch/decoding/translate_example.py b/examples/pytorch/decoding/translate_example.py index 522649b32..dae9ea19b 100644 --- a/examples/pytorch/decoding/translate_example.py +++ b/examples/pytorch/decoding/translate_example.py @@ -33,7 +33,7 @@ parser.add_argument("--max_seq_len", type=int, default=100, help="max_seq_len") parser.add_argument("--model_type", type=str,
help="decoding_ext, torch_decoding, torch_decoding_with_decoder_ext", choices=['decoding_ext', 'torch_decoding', 'torch_decoding_with_decoder_ext'], required=True) -parser.add_argument("--data_type", type=str, help="fp32, fp16") +parser.add_argument('--data_type', type=str, choices=['fp32', 'fp16', 'bf16'], default='fp32') parser.add_argument('--model_path', type=str, default='./pytorch/translation/models/averaged-10-epoch.pt', help='path for model checkpoint') parser.add_argument('--decoding_ths_path', type=str, default='./lib/libth_decoding.so', diff --git a/examples/pytorch/decoding/utils/bleu_score.py b/examples/pytorch/decoding/utils/bleu_score.py new file mode 100644 index 000000000..5c79235cf --- /dev/null +++ b/examples/pytorch/decoding/utils/bleu_score.py @@ -0,0 +1,42 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +from sacrebleu import corpus_bleu + +def bleu_score(pred_file, ref_file, bleu_score_threshold=None): + with open(pred_file, "r") as pred_stream, open(ref_file, "r") as ref_stream: + pred_stream_txt = pred_stream.readlines() + ref_stream_txt = ref_stream.readlines() + bleu = corpus_bleu(pred_stream_txt, [ref_stream_txt], force=True) + print(" bleu score: {:6.2f}".format(bleu.score)) + print(" bleu counts: {}".format(bleu.counts)) + print(" bleu totals: {}".format(bleu.totals)) + print(" bleu precisions: {}".format(bleu.precisions)) + print(" bleu sys_len: {}; ref_len: {}".format(bleu.sys_len, bleu.ref_len)) + if bleu_score_threshold != None: + assert bleu.score >= bleu_score_threshold, "TEST FAIL !" 
+ print("[INFO] TEST PASS !") + +if __name__ == "__main__": + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument('--pred_file', type=str, metavar='NUMBER', + help='The prediction files.', required=True) + parser.add_argument('--ref_file', type=str, metavar='NUMBER', + help='The reference files.', required=True) + parser.add_argument('--bleu_score_threshold', type=float, metavar='NUMBER', + help='The threshold of bleu score.') + args = parser.parse_args() + + bleu_score(args.pred_file, args.ref_file, args.bleu_score_threshold) diff --git a/examples/pytorch/decoding/utils/decoding.py b/examples/pytorch/decoding/utils/decoding.py index e8c6cf42b..81d565835 100644 --- a/examples/pytorch/decoding/utils/decoding.py +++ b/examples/pytorch/decoding/utils/decoding.py @@ -35,8 +35,10 @@ USE_CACHE_BATCH_MAJOR_ATTENTION = True -def get_op_cache_config(size_per_head, is_fp16): - x = 8 if is_fp16 else 4 +to_torch_type = {'fp32' : torch.float32, 'fp16' : torch.float16, 'bf16' : torch.bfloat16} + +def get_op_cache_config(size_per_head, is_fp32): + x = 4 if is_fp32 else 8 use_batch_major_op_cache = True if USE_CACHE_BATCH_MAJOR_ATTENTION == True and \ size_per_head % x == 0 \ else False @@ -117,6 +119,14 @@ def to_half(self): self.w[key][next_key] = self.w[key][next_key].half() else: self.w[key] = self.w[key].half() + + def to_bfloat16(self): + for key in self.w: + if isinstance(self.w[key], dict): + for next_key in self.w[key]: + self.w[key][next_key] = self.w[key][next_key].bfloat16() + else: + self.w[key] = self.w[key].bfloat16() def _get_position_encoding(self): pe = torch.zeros(self.max_step_for_pe, self.hidden_dim) @@ -233,8 +243,8 @@ def __init__(self, num_layers, d_model, heads, head_size, d_ff, # relevant to custom cache config # self.use_batch_major_op_cache = False # self.op_cache_dim_x = 1 - self.is_fp16 = True if self.args.data_type == 'fp16' else False - self.use_batch_major_op_cache, self.op_cache_dim_x = get_op_cache_config(head_size, self.is_fp16) + self.is_fp32 = True if self.args.data_type == 'fp32' else False + self.use_batch_major_op_cache, self.op_cache_dim_x = get_op_cache_config(head_size, self.is_fp32) self.head_num = heads self.size_per_head = head_size @@ -388,7 +398,7 @@ def _init_cache(self, memory_bank, decoding_max_seq_len): self.state["cache"]["layer_{}".format(i)] = layer_cache elif self.args.model_type == 'decoder_ext' or self.args.model_type == 'torch_decoding_with_decoder_ext': max_seq_len = memory_bank.size(0) - dtype = torch.half if self.args.data_type == 'fp16' else torch.float32 + dtype = to_torch_type[self.args.data_type] self.state['cache']['mem'] = [torch.zeros(self.transformer_layers[0].layer_num, batch_size, max_seq_len, depth, dtype=dtype, device='cuda'), torch.zeros(self.transformer_layers[0].layer_num, batch_size, max_seq_len, depth, dtype=dtype, device='cuda')] self.state['cache']['self'] = [ torch.zeros(self.transformer_layers[0].layer_num, batch_size, self.head_num, self.size_per_head // self.op_cache_dim_x, @@ -461,6 +471,8 @@ def __init__(self, layer_num, head_num, head_size, vocab_size, start_id, end_id, ft_decoder_weights.to_cuda() if args.data_type == 'fp16': ft_decoder_weights.to_half() + elif args.data_type == 'bf16': + ft_decoder_weights.to_bfloat16() self.decoder.transformer_layers = nn.ModuleList( [FTDecoder(head_num, head_size, head_num * head_size, layer_num, ft_decoder_weights, args)]) else: diff --git a/examples/pytorch/decoding/utils/ft_decoding.py 
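The dtype plumbing in decoding.py above boils down to two small helpers: a string-to-torch-dtype map and the op-cache layout check, which now keys on fp32 (4-element packing) versus the 16-bit types (8-element packing). A standalone sketch of the visible logic, under the assumption that the packing factor only matters when the batch-major cache is enabled:

import torch

USE_CACHE_BATCH_MAJOR_ATTENTION = True
to_torch_type = {'fp32': torch.float32, 'fp16': torch.float16, 'bf16': torch.bfloat16}

def get_op_cache_config(size_per_head, is_fp32):
    # fp32 packs 4 elements into a 16-byte vector; fp16/bf16 pack 8.
    x = 4 if is_fp32 else 8
    use_batch_major_op_cache = USE_CACHE_BATCH_MAJOR_ATTENTION and size_per_head % x == 0
    return use_batch_major_op_cache, x

print(get_op_cache_config(64, is_fp32=False))  # (True, 8)
print(to_torch_type['bf16'])                   # torch.bfloat16
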
b/examples/pytorch/decoding/utils/ft_decoding.py index db37e9188..f87f12b55 100644 --- a/examples/pytorch/decoding/utils/ft_decoding.py +++ b/examples/pytorch/decoding/utils/ft_decoding.py @@ -113,6 +113,10 @@ def to_cuda(self): def to_half(self): for i in range(len(self.w)): self.w[i] = self.w[i].half() + + def to_bfloat16(self): + for i in range(len(self.w)): + self.w[i] = self.w[i].bfloat16() def _get_position_encoding(self): pe = torch.zeros(self.max_step_for_pe, self.hidden_dim) diff --git a/examples/pytorch/decoding/utils/translation_model.py b/examples/pytorch/decoding/utils/translation_model.py index ace42364f..63400140f 100644 --- a/examples/pytorch/decoding/utils/translation_model.py +++ b/examples/pytorch/decoding/utils/translation_model.py @@ -97,6 +97,8 @@ def load_test_model(opt, args): model.float() elif args.data_type == 'fp16': model.half() + elif args.data_type == 'bf16': + model.bfloat16() else: raise ValueError('wrong data_type argument {}'.format(args.data_type)) model.eval() @@ -204,6 +206,8 @@ def fix_key(s): encoder_weights = EncoderWeights(model_opt.enc_layers, model_opt.enc_rnn_size, checkpoint['model']) if args.data_type == 'fp16': encoder_weights.to_half() + elif args.data_type == 'bf16': + encoder_weights.to_bfloat16() encoder_weights.to_cuda() encoder = CustomEncoder(model_opt.enc_layers, model_opt.heads, model_opt.enc_rnn_size // model_opt.heads, encoder_weights, path=args.encoder_ths_path, embedding=model.encoder.embeddings) @@ -217,6 +221,8 @@ def fix_key(s): ft_decoding_weights = FtDecodingWeights(model_opt.dec_layers, model_opt.dec_rnn_size, decoding_weights.w) if args.data_type == 'fp16': ft_decoding_weights.to_half() + elif args.data_type == 'bf16': + ft_decoding_weights.to_bfloat16() ft_decoding_weights.to_cuda() model.decoder = CustomDecoding(model_opt.heads, model_opt.dec_rnn_size // model_opt.heads, model_opt.dec_rnn_size * 4, model_opt.dec_rnn_size, model_opt.dec_layers, @@ -230,6 +236,8 @@ def fix_key(s): decoding_weights.to_cuda() if args.data_type == 'fp16': decoding_weights.to_half() + elif args.data_type == 'bf16': + decoding_weights.to_bfloat16() model.decoder = TorchDecoding(model_opt.dec_layers, model_opt.heads, model_opt.dec_rnn_size // model_opt.heads, vocab_size, bos_idx, eos_idx, decoding_weights, args=args) else: @@ -259,4 +267,6 @@ def fix_key(s): model.to(device) if model_opt.model_dtype == 'fp16' and model_opt.optim == 'fusedadam': model.half() + elif model_opt.model_dtype == 'bf16' and model_opt.optim == 'fusedadam': + model.bfloat16() return model diff --git a/examples/pytorch/encoder/encoder_example.py b/examples/pytorch/encoder/encoder_example.py index 61d9ed5b9..0bbb7d528 100644 --- a/examples/pytorch/encoder/encoder_example.py +++ b/examples/pytorch/encoder/encoder_example.py @@ -43,8 +43,7 @@ def main(): help='head number') parser.add_argument('head_size', type=int, help='size per head') - parser.add_argument('--fp16', action='store_true', - help='is fp16') + parser.add_argument('--data_type', type=str, choices=['fp32', 'fp16', 'bf16'], default='fp32') parser.add_argument('--time', action='store_true', help='test the time or not.') parser.add_argument('--avg_seq_len', type=int, default=-1, metavar='NUMBER', @@ -89,19 +88,19 @@ def encoder_example(args): raise ValueError("wrong avg_seq_len") mask = ~sequence_mask(mem_seq_lens, seq_len).unsqueeze(1) - if args['fp16']: + if args['data_type'] == 'fp16': inp = inp.half() weights = EncoderWeights(layer_num, hidden_dim) onmt_encoder = ONMTEncoder(layer_num, hidden_dim, head_num, 4 * 
hidden_dim, weights) onmt_encoder.cuda() - if args['fp16']: + if args['data_type'] == 'fp16': onmt_encoder.half() onmt_encoder.eval() onmt_encoder = torch.jit.trace(onmt_encoder, (inp, mask)) - if args['fp16']: + if args['data_type'] == 'fp16': weights.to_half() weights.to_cuda() custom_encoder = CustomEncoder(layer_num, head_num, head_size, weights, diff --git a/examples/pytorch/encoder/utils/ft_encoder.py b/examples/pytorch/encoder/utils/ft_encoder.py index 515afc12c..04a8061d7 100644 --- a/examples/pytorch/encoder/utils/ft_encoder.py +++ b/examples/pytorch/encoder/utils/ft_encoder.py @@ -84,6 +84,10 @@ def to_half(self): for k, v in self.weights.items(): self.weights[k] = v.half() + def to_bfloat16(self): + for k, v in self.weights.items(): + self.weights[k] = v.bfloat16() + class CustomEncoder(torch.nn.Module): def __init__(self, layer_num, head_num, head_size, weights, diff --git a/examples/pytorch/gpt/duplicate_input_ids.txt b/examples/pytorch/gpt/duplicate_input_ids.txt new file mode 100644 index 000000000..bd1869dd6 --- /dev/null +++ b/examples/pytorch/gpt/duplicate_input_ids.txt @@ -0,0 +1,8 @@ +Tyler Skaggs (July 13, 1991 – July 1, 2019) was an American +Tyler Skaggs (July 13, 1991 – July 1, 2019) was an American +Tyler Skaggs (July 13, 1991 – July 1, 2019) was an American +NASA releases the first operational image (shown) taken by +NASA releases the first operational image (shown) taken by +Did you know that the US Mint released the 1909-S VDB Lincoln Cent +Did you know that the US Mint released the 1909-S VDB Lincoln Cent +Tyler Skaggs (July 13, 1991 – July 1, 2019) was an American diff --git a/examples/pytorch/gpt/evaluate_zeroshot_gpt.py b/examples/pytorch/gpt/evaluate_zeroshot_gpt.py index db3c89a5e..42e02d14e 100644 --- a/examples/pytorch/gpt/evaluate_zeroshot_gpt.py +++ b/examples/pytorch/gpt/evaluate_zeroshot_gpt.py @@ -97,7 +97,7 @@ def get_tasks_args(parser): help='Whether to use negative examples during model ' 'training') group.add_argument('--train-hard-neg', type=int, default=0, - help='Number of hard negative exmaples to use during ' + help='Number of hard negative examples to use during ' 'training') # parameters for Av.rank validation method @@ -119,7 +119,13 @@ def get_tasks_args(parser): help='top k for sampling.') group.add_argument('--top_p', type=float, required=True, help='top p for sampling.') - + group.add_argument( + '--weights_data_type', + type=str, + default="fp32", + choices=["fp32", "fp16"], + help='Data type of FT checkpoint weights', + ) return parser @@ -158,7 +164,7 @@ def process_batch(batch): labels = tokens_[:, 1:].contiguous() tokens = tokens_[:, :-1].contiguous() - # Get the masks and postition ids. + # Get the masks and position ids. 
attention_mask, _, position_ids = get_ltor_masks_and_position_ids( tokens, tokenizer.eod, @@ -190,6 +196,7 @@ def forward_step(batch, model, eval_metric, args): start_lengths = torch.sum(tokens != model.end_id, axis=1).contiguous().int() input_len = torch.max(start_lengths).contiguous().int() output = [] + random_seed_tensor = 0 * torch.ones([max_batch_size], dtype=torch.int64) for i in range(input_len): tmp_length = torch.ones(args.micro_batch_size) * (i + 1) tmp_length = tmp_length.cuda().int() @@ -199,15 +206,14 @@ def forward_step(batch, model, eval_metric, args): output_id = model(input_ids, tmp_start_lengths, 1, - args.beam_width, - args.top_k, - args.top_p, - 0.0, - 1.0, - 1.0, - 1.0, - 0) - + args.top_k * torch.ones(size=[max_batch_size], dtype=torch.int32), + args.top_p * torch.ones(size=[max_batch_size], dtype=torch.float32), + 0.0 * torch.ones(size=[max_batch_size], dtype=torch.float32), + 1.0 * torch.ones(size=[max_batch_size], dtype=torch.float32), + 1.0 * torch.ones(size=[max_batch_size], dtype=torch.float32), + 1.0 * torch.ones(size=[max_batch_size], dtype=torch.float32), + random_seed_tensor) + output.append(output_id[:,0,-1].reshape([-1, 1])) output = torch.cat((output), 1) @@ -318,7 +324,7 @@ def main(): # Set up model and load checkpoint. model = GPT(args.num_attention_heads, (int)(args.hidden_size / args.num_attention_heads), args.padded_vocab_size, tokenzier.eod, tokenzier.eod, - args.num_layers, args.seq_length, 1, 1, "lib/libth_gpt.so") + args.num_layers, args.seq_length, 1, 1, "lib/libth_gpt.so", weights_data_type=args.weights_data_type) if not model.load(ckpt_path=args.ckpt_path): print("[ERROR] Checkpoint file not found at {}.".format(args.ckpt_path)) diff --git a/examples/pytorch/gpt/gpt_example.py b/examples/pytorch/gpt/gpt_example.py index d2d45c87d..5c6782963 100644 --- a/examples/pytorch/gpt/gpt_example.py +++ b/examples/pytorch/gpt/gpt_example.py @@ -49,7 +49,7 @@ def main(): help='top p probability threshold') parser.add_argument('--temperature', type=float, default=1., help='temperature') - parser.add_argument('--len_penalty', type=float, default=1., + parser.add_argument('--len_penalty', type=float, default=0., help='len_penalty') parser.add_argument('--beam_search_diversity_rate', type=float, default=0., help='beam_search_diversity_rate') @@ -82,10 +82,17 @@ def main(): help='path to sample input file. If not set, it runs with no context inputs.') parser.add_argument('--sample_output_file', type=str, default=None, help='path to sample output file.') - parser.add_argument('--is_fix_random_seed', type=bool, default=True, - help='is fixing the random seed.') + parser.add_argument('--enable_random_seed', action='store_true', + help='is enable the random seed.') parser.add_argument('--sparse', action='store_true', dest='sparse', help='Enable sparse matrix multiplication. (Need SM 8.0 or 8.6 and SPARSITY_SUPPORT=ON)') + parser.add_argument( + '--weights_data_type', + type=str, + default="fp32", + choices=["fp32", "fp16"], + help='Data type of FT checkpoint weights', + ) parser.add_argument('--return_cum_log_probs', type=int, default=0, choices=[0, 1, 2], help='Whether to compute the cumulative log probsbility of sentences.' 
' 0: do not return the cumulative log probs ' @@ -143,14 +150,15 @@ def main(): start_ids = pad_sequence(start_ids, batch_first=True, padding_value=end_id) start_lengths = torch.IntTensor(start_lengths) - if args.is_fix_random_seed == True: - random_seed = 0 + if args.enable_random_seed == True: + random_seed_tensor = torch.randint(0, 10000, size=[max_batch_size], dtype=torch.int64) else: - random_seed = random.randint(0, 100000) + random_seed_tensor = torch.zeros([max_batch_size], dtype=torch.int64) # Prepare model. gpt = GPT(head_num, size_per_head, vocab_size, start_id, end_id, layer_num, - max_seq_len, tensor_para_size, pipeline_para_size, lib_path=args.lib_path) + max_seq_len, tensor_para_size, pipeline_para_size, lib_path=args.lib_path, + weights_data_type=args.weights_data_type) if not gpt.load(ckpt_path=args.ckpt_path): print("[WARNING] Checkpoint file not found. Model loading is skipped.") if args.data_type == 'fp16': @@ -167,13 +175,13 @@ def main(): start_lengths, output_len, beam_width, - top_k, - top_p, - beam_search_diversity_rate, - temperature, - len_penalty, - repetition_penalty, - random_seed, + top_k * torch.ones(size=[max_batch_size], dtype=torch.int32), + top_p * torch.ones(size=[max_batch_size], dtype=torch.float32), + beam_search_diversity_rate * torch.ones(size=[max_batch_size], dtype=torch.float32), + temperature * torch.ones(size=[max_batch_size], dtype=torch.float32), + len_penalty * torch.ones(size=[max_batch_size], dtype=torch.float32), + repetition_penalty * torch.ones(size=[max_batch_size], dtype=torch.float32), + random_seed_tensor, return_output_length, return_cum_log_probs) if return_cum_log_probs > 0: @@ -204,13 +212,13 @@ def main(): start_lengths, output_len, beam_width, - top_k, - top_p, - beam_search_diversity_rate, - temperature, - len_penalty, - repetition_penalty, - random_seed, + top_k * torch.ones(size=[max_batch_size], dtype=torch.int32), + top_p * torch.ones(size=[max_batch_size], dtype=torch.float32), + beam_search_diversity_rate * torch.ones(size=[max_batch_size], dtype=torch.float32), + temperature * torch.ones(size=[max_batch_size], dtype=torch.float32), + len_penalty * torch.ones(size=[max_batch_size], dtype=torch.float32), + repetition_penalty * torch.ones(size=[max_batch_size], dtype=torch.float32), + random_seed_tensor, return_output_length, return_cum_log_probs) @@ -222,13 +230,13 @@ def main(): start_lengths, output_len, beam_width, - top_k, - top_p, - beam_search_diversity_rate, - temperature, - len_penalty, - repetition_penalty, - random_seed, + top_k * torch.ones(size=[max_batch_size], dtype=torch.int32), + top_p * torch.ones(size=[max_batch_size], dtype=torch.float32), + beam_search_diversity_rate * torch.ones(size=[max_batch_size], dtype=torch.float32), + temperature * torch.ones(size=[max_batch_size], dtype=torch.float32), + len_penalty * torch.ones(size=[max_batch_size], dtype=torch.float32), + repetition_penalty * torch.ones(size=[max_batch_size], dtype=torch.float32), + random_seed_tensor, return_output_length, return_cum_log_probs) batch_num += 1 diff --git a/examples/pytorch/gpt/gpt_summarization.py b/examples/pytorch/gpt/gpt_summarization.py index 923f675eb..9a01b9175 100644 --- a/examples/pytorch/gpt/gpt_summarization.py +++ b/examples/pytorch/gpt/gpt_summarization.py @@ -14,13 +14,13 @@ from __future__ import print_function import argparse -import configparser import json import numpy as np import os import sys import torch import torch.distributed as dist +import configparser from datetime import datetime from datasets 
import load_dataset, load_metric from transformers import GPT2Tokenizer, GPT2LMHeadModel @@ -33,9 +33,9 @@ def main(): parser = argparse.ArgumentParser() parser.add_argument('--ft_model_location', type=str, - default='/data2/byshiue/models/huggingface-gpt/gpt2-xl/c-models') + default='/models/GPT/HF/gpt2-xl/c-models') parser.add_argument('--hf_model_location', type=str, - default='/data2/byshiue/models/huggingface-gpt/gpt2-xl/gpt2-xl') + default='/models/GPT/HF/gpt2-xl/') parser.add_argument('--summarize', action='store_true') parser.add_argument('--test_hf', action='store_true') parser.add_argument('--data_type', type=str, choices=['fp32', 'fp16', 'bf16'], default='fp32') @@ -49,13 +49,24 @@ def main(): help='tensor parallel size') parser.add_argument('--pipeline_para_size', type=int, default=1, help='pipeline parallel size') + parser.add_argument( + '--weights_data_type', + type=str, + default="fp32", + choices=["fp32", "fp16"], + help='Data type of FT checkpoint weights', + ) + parser.add_argument('--rougeLsum_threshold', type=float, + help='Threshold of FT rougeLsum score') + + args = parser.parse_args() try: dist.init_process_group(backend='mpi') except: - print("[INFO] WARNING: Have initalize the process group") + print("[INFO] WARNING: Have initialized the process group") rank = dist.get_rank() summarize = args.summarize @@ -78,13 +89,12 @@ def main(): if not args.ft_use_hf_config: ft_config = configparser.ConfigParser() - ft_config.read(os.path.join(ft_model_location, '1-gpu/config.ini')) - - head_num = ft_config.getint('gpt', 'num_attention_heads') - layer_num = ft_config.getint('gpt', 'num_layers') - start_id = 50256 # TODO: get this from the tokenizer - end_id = 50256 # TODO: get this from the tokenizer - size_per_head = ft_config.getint('gpt', 'hidden_size') // head_num + assert ft_config.read(os.path.join(ft_model_location, '1-gpu/config.ini')) != [], "[ERROR] fail to read the config.ini of model" + head_num = ft_config.getint('gpt', 'head_num') + layer_num = ft_config.getint('gpt', 'num_layer') + start_id = ft_config.getint('gpt', 'start_id') # TODO: get this from the tokenizer + end_id = ft_config.getint('gpt', 'end_id') # TODO: get this from the tokenizer + size_per_head = ft_config.getint('gpt', 'size_per_head') if summarize: top_k = 2 @@ -95,8 +105,8 @@ def main(): top_p = 0.0 random_seed = 5 temperature = 1 - max_seq_len = hf_config['n_ctx'] if args.ft_use_hf_config else ft_config.getint('gpt', 'max_position_embeddings') - max_batch_size = 5 + max_seq_len = hf_config['n_ctx'] if args.ft_use_hf_config else ft_config.getint('gpt', 'max_pos_seq_len') + max_batch_size = 1 repetition_penalty = 1 vocab_size = 50257 tensor_para_size = args.tensor_para_size @@ -118,16 +128,17 @@ def main(): print(f"ckpt_path: {ckpt_path}") print(f"hf_config: {hf_config}") + random_seed_tensor = random_seed * torch.ones([max_batch_size], dtype=torch.int64) + gpt = ParallelGPT(head_num, size_per_head, vocab_size, start_id, end_id, layer_num, - max_seq_len, tensor_para_size, pipeline_para_size, lib_path=lib_path, int8_mode=0) + max_seq_len, tensor_para_size, pipeline_para_size, lib_path=lib_path, int8_mode=0, + weights_data_type=args.weights_data_type) if not gpt.load(ckpt_path=ckpt_path): print("[WARNING] Checkpoint file not found. 
Model loading is skipped.") if (test_hf and summarize) or not summarize: model = GPT2LMHeadModel.from_pretrained(hf_model_location) - # device_hf = 'cuda:1' - # model.to(device_hf) model.cuda() if args.data_type == 'fp16': model.half() @@ -158,13 +169,13 @@ def summarize_ft(datapoint): output, ft_output_len = gpt(line_encoded, torch.IntTensor([len(line_encoded[0])]), output_len, 1, - top_k, - top_p, - 0.0, - temperature, - 1.0, - repetition_penalty, - random_seed, + top_k * torch.ones(size=[max_batch_size], dtype=torch.int32), + top_p * torch.ones(size=[max_batch_size], dtype=torch.float32), + 0.0 * torch.ones(size=[max_batch_size], dtype=torch.float32), + temperature * torch.ones(size=[max_batch_size], dtype=torch.float32), + 1.0 * torch.ones(size=[max_batch_size], dtype=torch.float32), + repetition_penalty * torch.ones(size=[max_batch_size], dtype=torch.float32), + random_seed_tensor, True) tokens = output[0][0][len(line_encoded[0]):ft_output_len[0]].cpu().numpy() @@ -193,7 +204,7 @@ def summarize_hf(datapoint): output = model.generate(line_encoded, max_length=len(line_encoded[0]) + output_len, k=top_k, - temprature=temperature, + temperature=temperature, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id) @@ -286,6 +297,9 @@ def compute_exact_match(tokens, n_tokens=[1, 10, 25, 50, 100, 150, 200, 250]): print(f'Faster Transformers (total latency: {ft_time} sec)') for key in computed_metrics_ft.keys(): print(f'{key} : {computed_metrics_ft[key].mid[2]*100}') + if args.rougeLsum_threshold != None: + assert computed_metrics_ft["rougeLsum"].mid[2]*100 >= args.rougeLsum_threshold, "[INFO] TEST FAIL !" + print(f"[INFO] TEST PASS !") else: em_metrics = compute_exact_match(tokens) diff --git a/examples/pytorch/gpt/lambada_task_example.py b/examples/pytorch/gpt/lambada_task_example.py new file mode 100644 index 000000000..1b69fbee7 --- /dev/null +++ b/examples/pytorch/gpt/lambada_task_example.py @@ -0,0 +1,266 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
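A recurring change in the GPT example scripts above is that sampling parameters are no longer passed to the op as scalars but as one value per sequence in the batch. A minimal sketch of how the scalars are broadcast (variable names are illustrative, not the op's signature):

import torch

max_batch_size = 4
top_k, top_p, temperature, repetition_penalty = 1, 0.0, 1.0, 1.0

# One entry per sequence, so each sequence in the batch can use its own settings.
top_k_tensor = top_k * torch.ones(size=[max_batch_size], dtype=torch.int32)
top_p_tensor = top_p * torch.ones(size=[max_batch_size], dtype=torch.float32)
temperature_tensor = temperature * torch.ones(size=[max_batch_size], dtype=torch.float32)
repetition_penalty_tensor = repetition_penalty * torch.ones(size=[max_batch_size], dtype=torch.float32)
# Zeros reproduce the old fixed-seed behaviour; torch.randint gives per-sequence random seeds.
random_seed_tensor = torch.zeros([max_batch_size], dtype=torch.int64)

print(top_k_tensor, random_seed_tensor)
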
+ +import argparse +import configparser +import dataclasses +import json +import pathlib +import typing + +import numpy as np +import torch +import transformers + +from utils.gpt import GptInitModelParameters, GptRuntimeModelParameters +from utils.parallel_gpt import ParallelGPT + +class TensorEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, torch.Tensor): + return obj.tolist() + return super().default(obj) + +class LambadaDataset(torch.utils.data.Dataset): + def __init__(self, path, tokenizer, seq_len): + self.seq_len = seq_len + self.tokenizer = tokenizer + + with open(path, "r") as f: + texts = [json.loads(line)["text"] for line in f.readlines()] + + # this whitespace preprocessing (additional space and stripping) is required + labels = [" " + text.split()[-1] for text in texts] + inputs = [text[: text.rfind(label)].strip() for text, label in zip(texts, labels)] + self.encodings = self.tokenizer( + inputs, + labels, + padding="max_length", + max_length=self.seq_len, + return_token_type_ids=True, + return_tensors="pt", + ) + + def __len__(self): + return self.encodings["input_ids"].shape[0] + + def __getitem__(self, idx): + return { + "input_ids": self.encodings["input_ids"][idx], + "attention_mask": self.encodings["attention_mask"][idx], + "token_type_ids": self.encodings["token_type_ids"][idx], + } + + +@dataclasses.dataclass +class Metric: + acc: float + + +@dataclasses.dataclass +class RequestAndResult: + prompt: str + model_answer: str + target: str + input_ids: typing.List[int] + input_len: int + output_len: int + init_model_parameters: GptInitModelParameters + runtime_model_parameters: GptRuntimeModelParameters + output_ids: typing.List[int] + metrics: Metric + + +def _read_config_ini(args, checkpoint_path): + config_reader = configparser.ConfigParser() + config_ini_files_in_checkpoint_dir = list(checkpoint_path.rglob("config.ini")) + if args.config_ini_path is None and not config_ini_files_in_checkpoint_dir: + raise RuntimeError( + f"Missing config.ini file in {checkpoint_path}. Use --config-ini-path to point config.ini to load" + ) + config_ini_path = pathlib.Path(args.config_ini_path or config_ini_files_in_checkpoint_dir[0]) + if not config_ini_path.is_file(): + raise FileNotFoundError(f"Missing {config_ini_path}") + else: + config_reader.read(config_ini_path.as_posix()) + return config_reader + + +def _get_model(args, config_reader): + init_parameters = GptInitModelParameters.from_args(args, config_reader) + print("\n=============== GPT params ===============") + for key, value in dataclasses.asdict(init_parameters).items(): + print(f"{key}: {value}") + print(f"lib_path: {args.lib_path}") + print("========================================") + + gpt_params = init_parameters.gpt_init_kwargs() + gpt = ParallelGPT(**gpt_params, lib_path=args.lib_path) + + if not gpt.load(ckpt_path=args.checkpoint_path): + print("[WARNING] Checkpoint file not found. 
Model loading is skipped.") + + if init_parameters.data_type == "fp16": + gpt.half() + elif init_parameters.data_type == "bf16": + gpt.bfloat16() + if init_parameters.sparse: + gpt.sparse() + + gpt.eval() + + return gpt + + +def main(): + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("--checkpoint-path", type=str, required=True, help="Path to FasterTransformer checkpoint dir") + parser.add_argument("--lib-path", type=str, required=True, help="Path of FasterTransformer PyTorch GPT op library") + parser.add_argument( + "--config-ini-path", + type=str, + help="Path to config.ini file. If not provided /config.ini will be used.", + ) + parser.add_argument("--lambada-path", type=str, required=True, help="LAMBADA task data path") + parser.add_argument("--output-path", type=str, help="Path to sample output file.") + parser.add_argument("--batch-size", type=int, default=1, help="Batch size") + + GptInitModelParameters.update_argparser(parser) + GptRuntimeModelParameters.update_argparser(parser) + + args = parser.parse_args() + + print("\n============== Arguments ===============") + for key, value in vars(args).items(): + print(f"{key}: {value}") + print("========================================") + + checkpoint_path = pathlib.Path(args.checkpoint_path) + + config_reader = _read_config_ini(args, checkpoint_path) + + gpt = _get_model(args, config_reader) + + vocab_path = checkpoint_path / "vocab.json" + merges_path = checkpoint_path / "merges.txt" + max_seq_len = config_reader.getint("ft_instance_hyperparameter", "max_seq_len") + + tokenizer = transformers.GPT2TokenizerFast(vocab_path.as_posix(), merges_path.as_posix()) + tokenizer.add_special_tokens({"pad_token": tokenizer.eos_token}) + dataset = LambadaDataset(args.lambada_path, tokenizer=tokenizer, seq_len=max_seq_len) + data_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size) + + runtime_parameters = GptRuntimeModelParameters.from_args(args, config_reader) + inference_parameters_dict = dataclasses.asdict(runtime_parameters) + print("\n=========== Inference params ===========") + for key, value in inference_parameters_dict.items(): + print(f"{key}: {value}") + print("========================================") + + beam_idx = 0 # use only 1st beam result + + requested_num = 0 + correct_num = 0 + results = {"output": {"lambada": []}, "results": {"lambada": {}}} + with torch.no_grad(): + + for entries in data_loader: + inputs_tokens_batch = [ + input_ids[(attention_mask == 1) & (token_type_ids == 0)] + for input_ids, attention_mask, token_type_ids in zip( + entries["input_ids"], entries["attention_mask"], entries["token_type_ids"] + ) + ] + labels_tokens_batch = [ + input_ids[(attention_mask == 1) & (token_type_ids == 1)] + for input_ids, attention_mask, token_type_ids in zip( + entries["input_ids"], entries["attention_mask"], entries["token_type_ids"] + ) + ] + + inputs_tokens_batch_padded = [ + torch.nn.functional.pad( + input_tokens, + pad=[0, (max_seq_len - input_tokens.shape[0])], + mode="constant", + value=tokenizer.pad_token_id, + ) + for input_tokens in inputs_tokens_batch + ] + + input_tokens_lengths = [input_tokens.shape[0] for input_tokens in inputs_tokens_batch] + # max is required due to scalar is used for output_seq_len input + expected_tokens_lengths = max([label_tokens.shape[0] for label_tokens in labels_tokens_batch]) + + start_ids = torch.stack(inputs_tokens_batch_padded) # shape=(batch_size, max_seq_len) + runtime_parameters = 
GptRuntimeModelParameters.from_args(args, config_reader, start_ids.shape[0]) + inference_parameters_dict = dataclasses.asdict(runtime_parameters) + + start_ids = start_ids.to(torch.int32) + result_all_tokens_batch = gpt( + start_ids, + torch.IntTensor(input_tokens_lengths), + expected_tokens_lengths, + **inference_parameters_dict, + ) + + results_idxes = [ + torch.nonzero(token_type_ids, as_tuple=True)[0] for token_type_ids in entries["token_type_ids"] + ] + results_tokens_batch = [ + result_tokens_ids[beam_idx][result_idxes].cpu() + for result_tokens_ids, result_idxes in zip(result_all_tokens_batch, results_idxes) + ] + + labels_tokens_batch = [tokens.cpu() for tokens in labels_tokens_batch] + results_tokens_batch = [tokens.cpu() for tokens in results_tokens_batch] + + result_text_batch = tokenizer.batch_decode(results_tokens_batch) + input_text_batch = tokenizer.batch_decode(inputs_tokens_batch) + label_text_batch = tokenizer.batch_decode(labels_tokens_batch) + + for idx in range(len(inputs_tokens_batch)): + is_correct_answer = torch.all(labels_tokens_batch[idx] == results_tokens_batch[idx]) + correct_num += int(is_correct_answer) + result = RequestAndResult( + prompt=input_text_batch[idx], + model_answer=result_text_batch[idx], + target=label_text_batch[idx], + input_ids=list(map(int, inputs_tokens_batch[idx])), + input_len=int(input_tokens_lengths[idx]), + output_len=expected_tokens_lengths, + init_model_parameters=GptInitModelParameters.from_args(args, config_reader), + runtime_model_parameters=runtime_parameters.slice_args(idx), + output_ids=list(map(int, result_all_tokens_batch[idx][beam_idx])), + metrics=Metric(acc=float(is_correct_answer)), + ) + results["output"]["lambada"].append(dataclasses.asdict(result)) + + requested_num += len(inputs_tokens_batch) + + accuracy = correct_num * 100 / requested_num + print(f"[INFO] accuracy: {accuracy:0.4f}% (total : {requested_num})") + + # Dump prediction json + results["results"]["lambada"]["acc"] = accuracy + if args.output_path: + output_json_path = pathlib.Path(args.output_path) + output_json_path.parent.mkdir(parents=True, exist_ok=True) + with output_json_path.open(mode="w") as json_file: + json.dump(results, json_file, indent=2, cls=TensorEncoder) + print(f"[INFO] Detailed test results saved to {output_json_path}") + + +if __name__ == "__main__": + main() diff --git a/examples/pytorch/gpt/multi_gpu_gpt_example.py b/examples/pytorch/gpt/multi_gpu_gpt_example.py index 7fb286c4b..516fc336f 100644 --- a/examples/pytorch/gpt/multi_gpu_gpt_example.py +++ b/examples/pytorch/gpt/multi_gpu_gpt_example.py @@ -32,6 +32,8 @@ def main(): parser = argparse.ArgumentParser() parser.add_argument('--layer_num', type=int, default=24, help='number of layers') + parser.add_argument('--input_len', type=int, default=1, + help='input sequence length to generate.') parser.add_argument('--output_len', type=int, default=32, help='output sequence length to generate.') parser.add_argument('--head_num', type=int, default=16, @@ -48,7 +50,7 @@ def main(): help='top p probability threshold') parser.add_argument('--temperature', type=float, default=1., help='temperature') - parser.add_argument('--len_penalty', type=float, default=1., + parser.add_argument('--len_penalty', type=float, default=0., help='len_penalty') parser.add_argument('--beam_search_diversity_rate', type=float, default=0., help='beam_search_diversity_rate') @@ -81,16 +83,27 @@ def main(): help='path to sample input file. 
If not set, it runs with no context inputs.') parser.add_argument('--sample_output_file', type=str, default=None, help='path to sample output file.') - parser.add_argument('--is_fix_random_seed', type=bool, default=True, - help='is fixing the random seed.') + parser.add_argument('--enable_random_seed', action='store_true', + help='is use the random seed for sentences in a batch.') parser.add_argument('--int8_mode', type=int, default=0, help='int8 mode.') + parser.add_argument( + '--weights_data_type', + type=str, + default="fp32", + choices=["fp32", "fp16"], + help='Data type of FT checkpoint weights', + ) parser.add_argument('--return_cum_log_probs', type=int, default=0, choices=[0, 1, 2], help='Whether to compute the cumulative log probsbility of sentences.' ' 0: do not return the cumulative log probs ' ' 1: return the cumulative log probs of generated sequences' ' 2: return the cumulative log probs of sequences') + parser.add_argument('--shared_contexts_ratio', type=float, default=1.0, + help='Triggers the shared context optimization when' + 'compact_size <= shared_contexts_ratio * batch_size' + 'A value of 0.0 deactivate the optimization') args = parser.parse_args() @@ -113,8 +126,10 @@ def main(): max_seq_len = args.max_seq_len repetition_penalty = args.repetition_penalty int8_mode = args.int8_mode + weights_data_type = args.weights_data_type return_cum_log_probs = args.return_cum_log_probs return_output_length = return_cum_log_probs > 0 + shared_contexts_ratio = args.shared_contexts_ratio print("\n=============== Arguments ===============") for arg in vars(args): @@ -122,6 +137,7 @@ def main(): print("=========================================\n") enc = encoder.get_encoder(args.vocab_file, args.merges_file) + torch.manual_seed(0) # Inputs contexts = [] @@ -134,7 +150,7 @@ def main(): else: # unconditional case batch_size = max_batch_size contexts = ['<|endoftext|>'] * batch_size - start_ids = [torch.IntTensor([end_id])] * batch_size + start_ids = [torch.IntTensor([end_id for _ in range(args.input_len)])] * batch_size start_lengths = [len(ids) for ids in start_ids] input_len = max(start_lengths) @@ -142,15 +158,16 @@ def main(): start_ids = pad_sequence(start_ids, batch_first=True, padding_value=end_id) start_lengths = torch.IntTensor(start_lengths) - if args.is_fix_random_seed == True: - random_seed = 0 + if args.enable_random_seed == True: + random_seed_tensor = torch.randint(0, 10000, size=[max_batch_size], dtype=torch.int64) else: - random_seed = random.randint(0, 100000) + random_seed_tensor = torch.zeros([max_batch_size], dtype=torch.int64) # Prepare model. gpt = ParallelGPT(head_num, size_per_head, vocab_size, start_id, end_id, layer_num, max_seq_len, tensor_para_size, pipeline_para_size, - lib_path=args.lib_path, int8_mode=args.int8_mode) + lib_path=args.lib_path, int8_mode=args.int8_mode, weights_data_type=weights_data_type, + shared_contexts_ratio=shared_contexts_ratio) if not gpt.load(ckpt_path=args.ckpt_path): print("[WARNING] Checkpoint file not found. 
Model loading is skipped.") if args.data_type == 'fp16': @@ -164,13 +181,13 @@ def main(): start_lengths, output_len, beam_width, - top_k, - top_p, - beam_search_diversity_rate, - temperature, - len_penalty, - repetition_penalty, - random_seed, + top_k * torch.ones(size=[max_batch_size], dtype=torch.int32), + top_p * torch.ones(size=[max_batch_size], dtype=torch.float32), + beam_search_diversity_rate * torch.ones(size=[max_batch_size], dtype=torch.float32), + temperature * torch.ones(size=[max_batch_size], dtype=torch.float32), + len_penalty * torch.ones(size=[max_batch_size], dtype=torch.float32), + repetition_penalty * torch.ones(size=[max_batch_size], dtype=torch.float32), + random_seed_tensor, return_output_length, return_cum_log_probs) # only a thread (rank 0) gets the output, while the others are supposed to return None. @@ -200,13 +217,13 @@ def main(): start_lengths, output_len, beam_width, - top_k, - top_p, - beam_search_diversity_rate, - temperature, - len_penalty, - repetition_penalty, - random_seed, + top_k * torch.ones(size=[max_batch_size], dtype=torch.int32), + top_p * torch.ones(size=[max_batch_size], dtype=torch.float32), + beam_search_diversity_rate * torch.ones(size=[max_batch_size], dtype=torch.float32), + temperature * torch.ones(size=[max_batch_size], dtype=torch.float32), + len_penalty * torch.ones(size=[max_batch_size], dtype=torch.float32), + repetition_penalty * torch.ones(size=[max_batch_size], dtype=torch.float32), + random_seed_tensor, return_output_length, return_cum_log_probs) @@ -216,13 +233,13 @@ def main(): start_lengths, output_len, beam_width, - top_k, - top_p, - beam_search_diversity_rate, - temperature, - len_penalty, - repetition_penalty, - random_seed, + top_k * torch.ones(size=[max_batch_size], dtype=torch.int32), + top_p * torch.ones(size=[max_batch_size], dtype=torch.float32), + beam_search_diversity_rate * torch.ones(size=[max_batch_size], dtype=torch.float32), + temperature * torch.ones(size=[max_batch_size], dtype=torch.float32), + len_penalty * torch.ones(size=[max_batch_size], dtype=torch.float32), + repetition_penalty * torch.ones(size=[max_batch_size], dtype=torch.float32), + random_seed_tensor, return_output_length, return_cum_log_probs) time_elapsed = timeit.default_timer() - time diff --git a/examples/pytorch/gpt/opt_summarization.py b/examples/pytorch/gpt/opt_summarization.py new file mode 100644 index 000000000..548ee1f5b --- /dev/null +++ b/examples/pytorch/gpt/opt_summarization.py @@ -0,0 +1,301 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
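The OPT summarization example that continues below maps a few OPT-specific Hugging Face config fields onto FT's GPT parameters. A standalone sketch of that mapping, with the two input values assumed here rather than read from a real config:

# Assumed inputs (normally hf_config['do_layer_norm_before'] and hf_config['activation_function']).
do_layer_norm_before = True
activation_function = 'relu'

layernorm_type = 'pre_layernorm' if do_layer_norm_before else 'post_layernorm'
activation_type = 'Relu' if activation_function == 'relu' else 'Gelu'
# OPT only carries a final (post-decoder) layernorm in the pre-layernorm configuration.
has_post_decoder_layernorm = layernorm_type == 'pre_layernorm'

print(layernorm_type, activation_type, has_post_decoder_layernorm)
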
+ +from __future__ import print_function +import argparse +import numpy as np +import os +import sys +import torch +import torch.distributed as dist +from datetime import datetime +from datasets import load_dataset, load_metric +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig +from tqdm import tqdm + +dir_path = os.path.dirname(os.path.realpath(__file__)) +sys.path.append(dir_path + "/../../..") +from examples.pytorch.gpt.utils.parallel_gpt import ParallelGPT + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--ft_model_location', type=str, + default='/models/GPT/HF/gpt2-xl/c-models') + parser.add_argument('--hf_model_name', type=str, + default='facebook/opt-350m') + parser.add_argument('--summarize', action='store_true') + parser.add_argument('--test_hf', action='store_true') + parser.add_argument('--data_type', type=str, choices=['fp32', 'fp16', 'bf16'], default='fp32') + parser.add_argument("--cache_path", type=str, default="/workdir/datasets/ccdv/") + parser.add_argument("--max_ite", type=int, default=20) + parser.add_argument('--lib_path', type=str, default='./lib/libth_parallel_gpt.so', + help='path to the pyt_fastertransformer dynamic lib file.') + parser.add_argument('--tensor_para_size', type=int, default=1, + help='tensor parallel size') + parser.add_argument('--pipeline_para_size', type=int, default=1, + help='pipeline parallel size') + parser.add_argument( + '--weights_data_type', + type=str, + default="fp32", + choices=["fp32", "fp16"], + help='Data type of FT checkpoint weights', + ) + parser.add_argument('--rougeLsum_threshold', type=float, + help='Threshold of FT rougeLsum score') + + + + args = parser.parse_args() + + try: + dist.init_process_group(backend='mpi') + except: + print("[INFO] WARNING: Have initialized the process group") + rank = dist.get_rank() + + summarize = args.summarize + test_hf = args.test_hf + ft_model_location = args.ft_model_location + hf_model_name = args.hf_model_name + + tokenizer = AutoTokenizer.from_pretrained(hf_model_name) + tokenizer.pad_token = tokenizer.eos_token + dataset_cnn = load_dataset("ccdv/cnn_dailymail", '3.0.0', cache_dir=args.cache_path) + + hf_config = vars(AutoConfig.from_pretrained(hf_model_name)) + + head_num = hf_config['num_attention_heads'] + layer_num = hf_config['num_hidden_layers'] + start_id = hf_config['bos_token_id'] + end_id = hf_config['eos_token_id'] + size_per_head = hf_config['hidden_size'] // head_num + + # opt specific params: some are fixed + layernorm_eps = 1e-5 + layernorm_type = 'pre_layernorm' if hf_config['do_layer_norm_before'] else 'post_layernorm' + activation_type = 'Relu' if hf_config['activation_function'] == 'relu' else 'Gelu' + # https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py#L498 + # has post decoder layernorm when layernorm_type is pre layernorm + has_post_decoder_layernorm = layernorm_type == 'pre_layernorm' + + if summarize: + top_k = 2 + output_len = 100 + else: + top_k = 1 + output_len = 256 + top_p = 0.0 + temperature = 1 + max_seq_len = hf_config['max_position_embeddings'] + max_batch_size = 1 + repetition_penalty = 1 + vocab_size = hf_config['vocab_size'] + tensor_para_size = args.tensor_para_size + pipeline_para_size = args.pipeline_para_size + lib_path = args.lib_path + ckpt_path = os.path.join(ft_model_location, f'{tensor_para_size}-gpu') + + print(f"top_k: {top_k}") + print(f"top_p: {top_p}") + print(f"temperature: {temperature}") + print(f"max_seq_len: {max_seq_len}") + 
print(f"max_batch_size: {max_batch_size}") + print(f"repetition_penalty: {repetition_penalty}") + print(f"vocab_size: {vocab_size}") + print(f"tensor_para_size: {tensor_para_size}") + print(f"pipeline_para_size: {pipeline_para_size}") + print(f"lib_path: {lib_path}") + print(f"ckpt_path: {ckpt_path}") + print(f"hf_config: {hf_config}") + + gpt = ParallelGPT(head_num, size_per_head, vocab_size, start_id, end_id, layer_num, + max_seq_len, tensor_para_size, pipeline_para_size, lib_path, + layernorm_eps, layernorm_type, activation_type, has_post_decoder_layernorm, + int8_mode=0, weights_data_type=args.weights_data_type) + + random_seed_tensor = torch.zeros([max_batch_size], dtype=torch.int64) + + if not gpt.load(ckpt_path=ckpt_path): + print("[WARNING] Checkpoint file not found. Model loading is skipped.") + + if (test_hf and summarize) or not summarize: + model = AutoModelForCausalLM.from_pretrained(hf_model_name) + model.cuda() + if args.data_type == 'fp16': + model.half() + elif args.data_type == 'bf16': + model.bfloat16() + + if args.data_type == 'fp16': + gpt.half() + elif args.data_type == 'bf16': + gpt.bfloat16() + + def summarize_ft(datapoint): + if summarize: + line = datapoint['article'] + ' TL;DR: ' + else: + line = datapoint['article'] + line = line.strip() + line = line.replace(" n't", "n't") + + line_encoded = tokenizer.encode(line, return_tensors='pt') + if summarize: + line_encoded = line_encoded[:, -923:] + else: + line_encoded = line_encoded[:, -768:] + line_encoded = line_encoded.type(torch.int32) + + with torch.no_grad(): + output, ft_output_len = gpt(line_encoded, torch.IntTensor([len(line_encoded[0])]), + output_len, + 1, + top_k * torch.ones(size=[max_batch_size], dtype=torch.int32), + top_p * torch.ones(size=[max_batch_size], dtype=torch.float32), + 0.0 * torch.ones(size=[max_batch_size], dtype=torch.float32), + temperature * torch.ones(size=[max_batch_size], dtype=torch.float32), + 1.0 * torch.ones(size=[max_batch_size], dtype=torch.float32), + repetition_penalty * torch.ones(size=[max_batch_size], dtype=torch.float32), + random_seed_tensor, + True) + + tokens = output[0][0][len(line_encoded[0]):ft_output_len[0]].cpu().numpy() + + output_lines = tokenizer.decode(output[0][0][len(line_encoded[0]):ft_output_len[0]]) + output_lines = ".".join(output_lines.split('.')[:4]) + "." + return output_lines, tokens + + def summarize_hf(datapoint): + if summarize: + line = datapoint['article'] + ' TL;DR: ' + else: + line = datapoint['article'] + line = line.strip() + line = line.replace(" n't", "n't") + + line_encoded = tokenizer.encode(line, return_tensors='pt') + if summarize: + line_encoded = line_encoded[:, -923:] + else: + line_encoded = line_encoded[:, -768:] + # line_encoded = line_encoded.to(device_hf) + line_encoded = line_encoded.cuda() + + with torch.no_grad(): + output = model.generate(line_encoded, + max_length=len(line_encoded[0]) + output_len, + k=top_k, + temperature=temperature, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id) + + tokens = output[0][len(line_encoded[0]):].cpu().numpy() + output_lines = tokenizer.decode(output[0][len(line_encoded[0]):]) + output_lines = ".".join(output_lines.split('.')[:4]) + "." 
+ return output_lines, tokens + + if summarize: + datapoint = dataset_cnn['test'][0] + summary, _ = summarize_ft(datapoint) + print('---------------------------------------------------------') + print('FT Generated : ') + print(' Article : ', datapoint['article']) + print('\n Highlights : ', datapoint['highlights']) + print('\n Summary : ', summary) + print('---------------------------------------------------------') + + if test_hf: + summary, _ = summarize_hf(datapoint) + print('---------------------------------------------------------') + print('HF Generated : ') + print(' Article : ', datapoint['article']) + print('\n Highlights : ', datapoint['highlights']) + print('\n Summary : ', summary) + print('---------------------------------------------------------') + + if summarize: + metric_ft = load_metric("rouge") + metric_hf = load_metric("rouge") + else: + tokens = [] + + ft_time = 0.0 + hf_time = 0.0 + for data_point_idx in tqdm(range(1, 11490, int(11490 / args.max_ite))): + try: + datapoint = dataset_cnn['test'][data_point_idx] + + start_time = datetime.now() + summary_ft, tokens_ft = summarize_ft(datapoint) + stop_time = datetime.now() + ft_time += (stop_time - start_time).total_seconds() + if (test_hf and summarize) or not summarize: + start_time = datetime.now() + summary_hf, tokens_hf = summarize_hf(datapoint) + stop_time = datetime.now() + hf_time += (stop_time - start_time).total_seconds() + + if rank == 0: + if summarize: + metric_ft.add_batch(predictions=[summary_ft], references=[datapoint['highlights']]) + if test_hf: + metric_hf.add_batch(predictions=[summary_hf], references=[datapoint['highlights']]) + else: + tokens.append((tokens_ft, tokens_hf)) + except: + print('Error with datapoint : ', data_point_idx) + + def compute_exact_match(tokens, n_tokens=[1, 10, 25, 50, 100, 150, 200, 250]): + em_metrics = [] + for t in n_tokens: + errors = 0 + total = 0 + for ft_tokens, hf_tokens in tokens: + if len(ft_tokens) > t and len(hf_tokens) > t: + total = total + 1 + if not np.array_equal(ft_tokens[:t], hf_tokens[:t]): + errors = errors + 1 + + if total > 0: + print(f"{t}-token exact match acc: {100*(1-errors/total):.2f}") + em_metrics.append(1 - errors / total) + else: + em_metrics.append(np.nan) + + return em_metrics + + if rank == 0: + if summarize: + computed_metrics_ft = metric_ft.compute() + + if test_hf: + computed_metrics_hf = metric_hf.compute() + + print(f'Hugging Face (total latency: {hf_time} sec)') + for key in computed_metrics_hf.keys(): + print(f'{key} : {computed_metrics_hf[key].mid[2]*100}') + + print(f'Faster Transformers (total latency: {ft_time} sec)') + for key in computed_metrics_ft.keys(): + print(f'{key} : {computed_metrics_ft[key].mid[2]*100}') + if args.rougeLsum_threshold != None: + assert computed_metrics_ft["rougeLsum"].mid[2]*100 >= args.rougeLsum_threshold, "[INFO] TEST FAIL !" 
+ print(f"[INFO] TEST PASS !") + else: + em_metrics = compute_exact_match(tokens) + + +if __name__ == '__main__': + main() diff --git a/examples/pytorch/gpt/requirement.txt b/examples/pytorch/gpt/requirement.txt index 7e4943046..a84bd4859 100644 --- a/examples/pytorch/gpt/requirement.txt +++ b/examples/pytorch/gpt/requirement.txt @@ -1,4 +1,5 @@ -datasets -fire -rouge_score -transformers \ No newline at end of file +datasets~=2.3.2 +fire~=0.4.0 +omegaconf~=2.1.2 +rouge_score~=0.1.2 +transformers~=4.20.1 diff --git a/examples/pytorch/gpt/scripts/evaluate_zeroshot_gpt.sh b/examples/pytorch/gpt/scripts/evaluate_zeroshot_gpt.sh index 845d53a85..d7b990096 100644 --- a/examples/pytorch/gpt/scripts/evaluate_zeroshot_gpt.sh +++ b/examples/pytorch/gpt/scripts/evaluate_zeroshot_gpt.sh @@ -1,21 +1,21 @@ #!/bin/bash +VOCAB_FILE=$1 +MERGE_FILE=$2 +LAMBADA_PATH=$3 +CHECKPOINT=$4 + TASK="LAMBADA" -LAMBADA_PATH="../models/megatron-models/lambada_test.jsonl" VALID_DATA=$LAMBADA_PATH -VOCAB_FILE=../models/gpt2-vocab.json -MERGE_FILE=../models/gpt2-merges.txt -CHECKPOINT=../models/megatron-models/345m/ - python -m torch.distributed.run --nproc_per_node 1 ../examples/pytorch/gpt/evaluate_zeroshot_gpt.py \ --task $TASK \ - --valid-data $VALID_DATA \ + --valid-data "${VALID_DATA}" \ --tokenizer-type GPT2BPETokenizer \ --strict-lambada \ - --vocab-file $VOCAB_FILE \ - --merge-file $MERGE_FILE \ - --load $CHECKPOINT \ + --vocab-file "${VOCAB_FILE}" \ + --merge-file "${MERGE_FILE}" \ + --load "${CHECKPOINT}" \ --tensor-model-parallel-size 1 \ --num-layers 24 \ --hidden-size 1024 \ @@ -28,7 +28,7 @@ python -m torch.distributed.run --nproc_per_node 1 ../examples/pytorch/gpt/evalu --fp16 \ --no-load-optim \ --no-load-rng \ - --ckpt-path '../models/megatron-models/c-model/345m/1-gpu' \ + --ckpt-path "${CHECKPOINT}" \ --lib-path "lib/libth_gpt.so" \ --beam_width 1 \ --top_k 1 \ @@ -38,11 +38,11 @@ sleep 20 python -m torch.distributed.run --nproc_per_node 1 ../examples/pytorch/gpt/evaluate_zeroshot_gpt.py \ --task $TASK \ - --valid-data $VALID_DATA \ + --valid-data "${VALID_DATA}" \ --tokenizer-type GPT2BPETokenizer \ --strict-lambada \ - --vocab-file $VOCAB_FILE \ - --merge-file $MERGE_FILE \ + --vocab-file "${VOCAB_FILE}" \ + --merge-file "${MERGE_FILE}" \ --load $CHECKPOINT \ --tensor-model-parallel-size 1 \ --num-layers 24 \ @@ -56,7 +56,7 @@ python -m torch.distributed.run --nproc_per_node 1 ../examples/pytorch/gpt/evalu --fp16 \ --no-load-optim \ --no-load-rng \ - --ckpt-path '../models/megatron-models/c-model/345m/1-gpu' \ + --ckpt-path "${CHECKPOINT}" \ --lib-path "lib/libth_gpt.so" \ --beam_width 1 \ --top_k 0 \ diff --git a/examples/pytorch/gpt/utils/checkpoint_saver_fastertransformer.py b/examples/pytorch/gpt/utils/checkpoint_saver_fastertransformer.py index 2875f7475..cfb21553c 100644 --- a/examples/pytorch/gpt/utils/checkpoint_saver_fastertransformer.py +++ b/examples/pytorch/gpt/utils/checkpoint_saver_fastertransformer.py @@ -3,7 +3,6 @@ import numpy as np import torch -# This file is used with "https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/blob/checkpoint_util/tools/checkpoint_util.py" # Example # python tools/checkpoint_util.py --model-type GPT --loader megatron --saver fastertransformer \ # --input /home/scratch.bhsueh_sw/megatron_new_ckpt/357m-pipeline-2-tensor-2/ --output ./tmp --target-tensor-parallel-size 2 diff --git a/examples/pytorch/gpt/utils/generate_gpt_config.py b/examples/pytorch/gpt/utils/generate_gpt_config.py index 43b5a0ce6..af0a5a240 100644 --- 
a/examples/pytorch/gpt/utils/generate_gpt_config.py +++ b/examples/pytorch/gpt/utils/generate_gpt_config.py @@ -1,5 +1,4 @@ # Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. -# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -101,10 +100,12 @@ def generate_gpt_config(args): help='end id (default: 50256)') parser.add_argument('-repetition_penalty', '--repetition_penalty', type=float, default=1.0, metavar='NUMBER', help='repetition_penalty (default: 1.0)') - parser.add_argument('-len_penalty', '--len_penalty', type=float, default=1.0, metavar='NUMBER', - help='len_penalty (default: 1.0)') + parser.add_argument('-len_penalty', '--len_penalty', type=float, default=0.0, metavar='NUMBER', + help='len_penalty (default: 0.0)') parser.add_argument('-beam_search_diversity_rate', '--beam_search_diversity_rate', type=float, default=0.0, metavar='NUMBER', help='beam_search_diversity_rate (default: 0.0)') + parser.add_argument('-memory_len', '--memory_len', type=int, default=None, metavar='NUMBER', + help='Memory length (how many time steps to keep in memory) (default: None)') args = parser.parse_args() - generate_gpt_config(vars(args)) \ No newline at end of file + generate_gpt_config(vars(args)) diff --git a/examples/pytorch/gpt/utils/gpt.py b/examples/pytorch/gpt/utils/gpt.py index 9e0d14253..d8d412388 100644 --- a/examples/pytorch/gpt/utils/gpt.py +++ b/examples/pytorch/gpt/utils/gpt.py @@ -14,7 +14,13 @@ from __future__ import print_function +import argparse +import dataclasses +import json import os +import pathlib +import typing + import torch import torch.nn as nn import numpy as np @@ -22,7 +28,9 @@ class GPTWeights(object): - def __init__(self, head_num, size_per_head, layer_num, vocab_size, max_seq_len, tensor_para_size, pipeline_para_size, int8_mode = 0): + def __init__(self, head_num, size_per_head, layer_num, vocab_size, max_seq_len, tensor_para_size, pipeline_para_size, + has_adapters=False, adapter_inter_size = 0, has_post_decoder_layernorm=True, + int8_mode=0, weights_data_type=typing.Union[str, np.float32]): assert(head_num % tensor_para_size == 0) if int8_mode != 0: @@ -37,11 +45,16 @@ def __init__(self, head_num, size_per_head, layer_num, vocab_size, max_seq_len, self.pipeline_para_size = pipeline_para_size self.layers_per_device = layer_num // pipeline_para_size + self.has_adapters = has_adapters + self.adapter_inter_size = adapter_inter_size + self.has_post_decoder_layernorm = has_post_decoder_layernorm + local_head_num = head_num // tensor_para_size global_head_num = head_num local_hidden_units = local_head_num * size_per_head global_hidden_units = global_head_num * size_per_head local_inter_size = local_hidden_units * 4 + local_adapter_inter_size = self.adapter_inter_size // tensor_para_size self.local_head_num = local_head_num self.global_head_num = global_head_num @@ -51,6 +64,20 @@ def __init__(self, head_num, size_per_head, layer_num, vocab_size, max_seq_len, self.int8_mode = int8_mode + if isinstance(weights_data_type, str): + try: + weights_data_type = { + "fp16": np.float16, + "fp32": np.float32, + "float16": np.float16, + "float32": np.float32, + }[weights_data_type] + except KeyError: + raise ValueError(f"Don't know how to interpret weights_data_type: {weights_data_type}") + + assert weights_data_type in [np.float32, np.float16] + self.weights_data_type = weights_data_type + self.w = [] self.int8_w = 
[] self.scale = [] @@ -68,11 +95,22 @@ def __init__(self, head_num, size_per_head, layer_num, vocab_size, max_seq_len, self.w.extend([torch.zeros(local_inter_size, global_hidden_units)] * layer_num) # ffn_kernel2 self.w.extend([torch.zeros(global_hidden_units)] * layer_num) # ffn_bias2 # After Transformer blocks - self.w.append(torch.zeros(global_hidden_units)) # layernorm_gamma - self.w.append(torch.zeros(global_hidden_units)) # layernorm_beta + if self.has_post_decoder_layernorm: + self.w.append(torch.zeros(global_hidden_units)) # layernorm_gamma + self.w.append(torch.zeros(global_hidden_units)) # layernorm_beta self.w.append(torch.zeros(max_seq_len, global_hidden_units)) # position_encoding_table self.w.append(torch.zeros(vocab_size, global_hidden_units)) # embedding_table self.w.append(torch.zeros(vocab_size, global_hidden_units)) # embedding_kernel + # adapters + if self.has_adapters: + self.w.extend([torch.zeros(global_hidden_units, local_adapter_inter_size)] * layer_num) # adaptor1_kernel1 + self.w.extend([torch.zeros(local_adapter_inter_size)] * layer_num) # adaptor1_bias1 + self.w.extend([torch.zeros(local_adapter_inter_size, global_hidden_units)] * layer_num) # adaptor1_kernel2 + self.w.extend([torch.zeros(global_hidden_units)] * layer_num) # adaptor1_bias2 + self.w.extend([torch.zeros(global_hidden_units, local_adapter_inter_size)] * layer_num) # adaptor2_kernel1 + self.w.extend([torch.zeros(local_adapter_inter_size)] * layer_num) # adaptor2_bias1 + self.w.extend([torch.zeros(local_adapter_inter_size, global_hidden_units)] * layer_num) # adaptor2_kernel2 + self.w.extend([torch.zeros(global_hidden_units)] * layer_num) # adaptor2_bias2 # Initialization self._map(lambda w: torch.nn.init.normal_(w, mean=0., std=1.)) @@ -86,8 +124,15 @@ def __init__(self, head_num, size_per_head, layer_num, vocab_size, max_seq_len, self.scale.extend([torch.zeros(local_inter_size, dtype=torch.float)] * layer_num) # ffn_scale1 self.int8_w.extend([torch.zeros(local_inter_size, global_hidden_units, dtype=torch.int8)] * layer_num) # ffn_int8_kernel2 self.scale.extend([torch.zeros(global_hidden_units, dtype=torch.float)] * layer_num) # ffn_scale2 - - + if self.has_adapters: + self.int8_w.extend([torch.zeros(global_hidden_units, local_adapter_inter_size, dtype=torch.int8)] * layer_num) # adaptor1_int8_kernel1 + self.scale.extend([torch.zeros(local_adapter_inter_size, dtype=torch.float)] * layer_num) # adaptor1_scale1 + self.int8_w.extend([torch.zeros(local_adapter_inter_size, global_hidden_units, dtype=torch.int8)] * layer_num) # adaptor1_int8_kernel2 + self.scale.extend([torch.zeros(global_hidden_units, dtype=torch.float)] * layer_num) # adaptor1_scale2 + self.int8_w.extend([torch.zeros(global_hidden_units, local_adapter_inter_size, dtype=torch.int8)] * layer_num) # adaptor2_int8_kernel1 + self.scale.extend([torch.zeros(local_adapter_inter_size, dtype=torch.float)] * layer_num) # adaptor2_scale1 + self.int8_w.extend([torch.zeros(local_adapter_inter_size, global_hidden_units, dtype=torch.int8)] * layer_num) # adaptor2_int8_kernel2 + self.scale.extend([torch.zeros(global_hidden_units, dtype=torch.float)] * layer_num) # adaptor2_scale2 def __getitem__(self, idx): return self.w[idx] @@ -127,63 +172,99 @@ def load(self, ckpt_path, tensor_para_rank, pipeline_para_rank): return False w = [] + type_map = {np.float32: torch.float32, np.float16: torch.float16} # Load def is_load(i): return i >= self.layers_per_device * \ pipeline_para_rank and i < self.layers_per_device * (pipeline_para_rank + 1) 
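The load() path below switches every np.fromfile call from the hard-coded np.single to the configurable self.weights_data_type. Each tensor is stored as a raw binary blob with no header, so the dtype passed here has to match the weight_data_type the converter recorded in config.ini. A minimal sketch of the per-file pattern (the checkpoint path is hypothetical):

import numpy as np
import torch

weights_data_type = np.float16  # resolved via the "fp16"/"fp32" mapping above
ckpt_path = "../models/megatron-models/c-model/345m/1-gpu"  # hypothetical path

# np.fromfile reinterprets raw bytes; a dtype mismatch yields silently wrong
# values rather than an error, so fp16 checkpoints must be read as fp16.
layernorm_gamma = torch.from_numpy(
    np.fromfile(f"{ckpt_path}/model.final_layernorm.weight.bin", dtype=weights_data_type))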
w.extend([torch.from_numpy(np.fromfile(ckpt_path + "/model.layers.{}.input_layernorm.weight.bin".format(i), - dtype=np.single)) if is_load(i) else torch.empty(0) for i in range(self.layer_num)]) + dtype=self.weights_data_type)) if is_load(i) else torch.empty(0).to(type_map[self.weights_data_type]) for i in range(self.layer_num)]) w.extend([torch.from_numpy(np.fromfile(ckpt_path + "/model.layers.{}.input_layernorm.bias.bin".format(i), - dtype=np.single)) if is_load(i) else torch.empty(0) for i in range(self.layer_num)]) + dtype=self.weights_data_type)) if is_load(i) else torch.empty(0).to(type_map[self.weights_data_type]) for i in range(self.layer_num)]) w.extend([torch.from_numpy(np.fromfile(ckpt_path + "/model.layers.{}.attention.query_key_value.weight.{}.bin".format(i, - tensor_para_rank), dtype=np.single)) if is_load(i) else torch.empty(0) for i in range(self.layer_num)]) + tensor_para_rank), dtype=self.weights_data_type)) if is_load(i) else torch.empty(0).to(type_map[self.weights_data_type]) for i in range(self.layer_num)]) w.extend([torch.from_numpy(np.fromfile(ckpt_path + "/model.layers.{}.attention.query_key_value.bias.{}.bin".format(i, - tensor_para_rank), dtype=np.single)) if is_load(i) else torch.empty(0) for i in range(self.layer_num)]) + tensor_para_rank), dtype=self.weights_data_type)) if is_load(i) else torch.empty(0).to(type_map[self.weights_data_type]) for i in range(self.layer_num)]) w.extend([torch.from_numpy(np.fromfile(ckpt_path + "/model.layers.{}.attention.dense.weight.{}.bin".format(i, - tensor_para_rank), dtype=np.single)) if is_load(i) else torch.empty(0) for i in range(self.layer_num)]) + tensor_para_rank), dtype=self.weights_data_type)) if is_load(i) else torch.empty(0).to(type_map[self.weights_data_type]) for i in range(self.layer_num)]) w.extend([torch.from_numpy(np.fromfile(ckpt_path + "/model.layers.{}.attention.dense.bias.bin".format(i), - dtype=np.single)) if is_load(i) else torch.empty(0) for i in range(self.layer_num)]) + dtype=self.weights_data_type)) if is_load(i) else torch.empty(0).to(type_map[self.weights_data_type]) for i in range(self.layer_num)]) w.extend([torch.from_numpy(np.fromfile(ckpt_path + "/model.layers.{}.post_attention_layernorm.weight.bin".format(i), - dtype=np.single)) if is_load(i) else torch.empty(0) for i in range(self.layer_num)]) + dtype=self.weights_data_type)) if is_load(i) else torch.empty(0).to(type_map[self.weights_data_type]) for i in range(self.layer_num)]) w.extend([torch.from_numpy(np.fromfile(ckpt_path + "/model.layers.{}.post_attention_layernorm.bias.bin".format(i), - dtype=np.single)) if is_load(i) else torch.empty(0) for i in range(self.layer_num)]) + dtype=self.weights_data_type)) if is_load(i) else torch.empty(0).to(type_map[self.weights_data_type]) for i in range(self.layer_num)]) w.extend([torch.from_numpy(np.fromfile(ckpt_path + "/model.layers.{}.mlp.dense_h_to_4h.weight.{}.bin".format(i, - tensor_para_rank), dtype=np.single)) if is_load(i) else torch.empty(0) for i in range(self.layer_num)]) + tensor_para_rank), dtype=self.weights_data_type)) if is_load(i) else torch.empty(0).to(type_map[self.weights_data_type]) for i in range(self.layer_num)]) w.extend([torch.from_numpy(np.fromfile(ckpt_path + "/model.layers.{}.mlp.dense_h_to_4h.bias.{}.bin".format(i, - tensor_para_rank), dtype=np.single)) if is_load(i) else torch.empty(0) for i in range(self.layer_num)]) + tensor_para_rank), dtype=self.weights_data_type)) if is_load(i) else torch.empty(0).to(type_map[self.weights_data_type]) for i in range(self.layer_num)]) 
w.extend([torch.from_numpy(np.fromfile(ckpt_path + "/model.layers.{}.mlp.dense_4h_to_h.weight.{}.bin".format(i, - tensor_para_rank), dtype=np.single)) if is_load(i) else torch.empty(0) for i in range(self.layer_num)]) + tensor_para_rank), dtype=self.weights_data_type)) if is_load(i) else torch.empty(0).to(type_map[self.weights_data_type]) for i in range(self.layer_num)]) w.extend([torch.from_numpy(np.fromfile(ckpt_path + "/model.layers.{}.mlp.dense_4h_to_h.bias.bin".format(i), - dtype=np.single)) if is_load(i) else torch.empty(0) for i in range(self.layer_num)]) + dtype=self.weights_data_type)) if is_load(i) else torch.empty(0).to(type_map[self.weights_data_type]) for i in range(self.layer_num)]) - w.append(torch.from_numpy(np.fromfile(ckpt_path + "/model.final_layernorm.weight.bin", dtype=np.single))) - w.append(torch.from_numpy(np.fromfile(ckpt_path + "/model.final_layernorm.bias.bin", dtype=np.single))) + if self.has_post_decoder_layernorm: + w.append(torch.from_numpy(np.fromfile(ckpt_path + "/model.final_layernorm.weight.bin", dtype=self.weights_data_type))) + w.append(torch.from_numpy(np.fromfile(ckpt_path + "/model.final_layernorm.bias.bin", dtype=self.weights_data_type))) - wpe = torch.from_numpy(np.fromfile(ckpt_path + "/model.wpe.bin", dtype=np.single) + wpe = torch.from_numpy(np.fromfile(ckpt_path + "/model.wpe.bin", dtype=self.weights_data_type) ).reshape(-1, self.global_hidden_units) - assert self.max_seq_len <= wpe.size( - 0), "max_seq_len must not exceed the value of maximum sequence length during traning." + assert self.max_seq_len <= wpe.size(0), ( + f"max_seq_len ({self.max_seq_len} must not exceed " + f"the value of maximum sequence length during training ({wpe.size(0)})." + ) w.append(wpe) - w.append(torch.from_numpy(np.fromfile(ckpt_path + "/model.wte.bin", dtype=np.single))) - w.append(torch.from_numpy(np.fromfile(ckpt_path + "/model.wte.bin", dtype=np.single))) + w.append(torch.from_numpy(np.fromfile(ckpt_path + "/model.wte.bin", dtype=self.weights_data_type))) + if os.path.isfile(ckpt_path + "/model.lm_head.weight.bin"): + w.append(torch.from_numpy(np.fromfile(ckpt_path + "/model.lm_head.weight.bin", dtype=self.weights_data_type))) + else: + w.append(torch.from_numpy(np.fromfile(ckpt_path + "/model.wte.bin", dtype=self.weights_data_type))) + + + if self.has_adapters: + w.extend([torch.from_numpy(np.fromfile(ckpt_path + "/model.layers.{}.after_attention_adapter.dense_h_to_4h.weight.{}.bin".format(i, + tensor_para_rank), dtype=self.weights_data_type)) if is_load(i) else torch.empty(0).to(type_map[self.weights_data_type]) for i in range(self.layer_num)]) + w.extend([torch.from_numpy(np.fromfile(ckpt_path + "/model.layers.{}.after_attention_adapter.dense_h_to_4h.bias.{}.bin".format(i, + tensor_para_rank), dtype=self.weights_data_type)) if is_load(i) else torch.empty(0).to(type_map[self.weights_data_type]) for i in range(self.layer_num)]) + w.extend([torch.from_numpy(np.fromfile(ckpt_path + "/model.layers.{}.after_attention_adapter.dense_4h_to_h.weight.{}.bin".format(i, + tensor_para_rank), dtype=self.weights_data_type)) if is_load(i) else torch.empty(0).to(type_map[self.weights_data_type]) for i in range(self.layer_num)]) + w.extend([torch.from_numpy(np.fromfile(ckpt_path + "/model.layers.{}.after_attention_adapter.dense_4h_to_h.bias.bin".format(i), + dtype=self.weights_data_type)) if is_load(i) else torch.empty(0).to(type_map[self.weights_data_type]) for i in range(self.layer_num)]) + w.extend([torch.from_numpy(np.fromfile(ckpt_path + 
"/model.layers.{}.after_ffn_adapter.dense_h_to_4h.weight.{}.bin".format(i, + tensor_para_rank), dtype=self.weights_data_type)) if is_load(i) else torch.empty(0).to(type_map[self.weights_data_type]) for i in range(self.layer_num)]) + w.extend([torch.from_numpy(np.fromfile(ckpt_path + "/model.layers.{}.after_ffn_adapter.dense_h_to_4h.bias.{}.bin".format(i, + tensor_para_rank), dtype=self.weights_data_type)) if is_load(i) else torch.empty(0).to(type_map[self.weights_data_type]) for i in range(self.layer_num)]) + w.extend([torch.from_numpy(np.fromfile(ckpt_path + "/model.layers.{}.after_ffn_adapter.dense_4h_to_h.weight.{}.bin".format(i, + tensor_para_rank), dtype=self.weights_data_type)) if is_load(i) else torch.empty(0).to(type_map[self.weights_data_type]) for i in range(self.layer_num)]) + w.extend([torch.from_numpy(np.fromfile(ckpt_path + "/model.layers.{}.after_ffn_adapter.dense_4h_to_h.bias.bin".format(i), + dtype=self.weights_data_type)) if is_load(i) else torch.empty(0).to(type_map[self.weights_data_type]) for i in range(self.layer_num)]) # Reshape try: for i in range(len(w)): if w[i].nelement() > 0: self.w[i] = w[i].reshape(self.w[i].shape) + else: + self.w[i] = w[i] except RuntimeError: raise RuntimeError( - "head_num, size_per_head, vocab_size, and max_seq_len must be the same as the ones during training.") + f"head_num, size_per_head, vocab_size, and max_seq_len must be the same as the ones during training " + f"(idx: {i} expected shape: {self.w[i].shape} got shape: {w[i].shape})." + ) #transpose calibrate quantize the kernel layer_num = self.layer_num + final_layernorm_w_offset = 2 if self.has_post_decoder_layernorm else 0 if self.int8_mode != 0: for i in range(layer_num): self.int8_w[i + 0*layer_num], self.scale[i + 0*layer_num] = self.weight_transpose_calibrate_quantize(self.w[2*layer_num + i]) self.int8_w[i + 1*layer_num], self.scale[i + 1*layer_num] = self.weight_transpose_calibrate_quantize(self.w[4*layer_num + i]) self.int8_w[i + 2*layer_num], self.scale[i + 2*layer_num] = self.weight_transpose_calibrate_quantize(self.w[8*layer_num + i]) self.int8_w[i + 3*layer_num], self.scale[i + 3*layer_num] = self.weight_transpose_calibrate_quantize(self.w[10*layer_num + i]) + if self.has_adapters: + self.int8_w[i + 4*layer_num], self.scale[i + 0*layer_num] = self.weight_transpose_calibrate_quantize(self.w[12*layer_num + i + 3 + final_layernorm_w_offset]) + self.int8_w[i + 5*layer_num], self.scale[i + 1*layer_num] = self.weight_transpose_calibrate_quantize(self.w[14*layer_num + i + 3 + final_layernorm_w_offset]) + self.int8_w[i + 6*layer_num], self.scale[i + 2*layer_num] = self.weight_transpose_calibrate_quantize(self.w[16*layer_num + i + 3 + final_layernorm_w_offset]) + self.int8_w[i + 7*layer_num], self.scale[i + 3*layer_num] = self.weight_transpose_calibrate_quantize(self.w[18*layer_num + i + 3 + final_layernorm_w_offset]) return True @@ -195,7 +276,12 @@ def __init__(self, max_seq_len, tensor_para_size, pipeline_para_size, lib_path, - int8_mode = 0): + layernorm_eps = 1e-6, layernorm_type = "pre_layernorm", # gpt_variant_params + activation_type = "Gelu", has_post_decoder_layernorm = True, # gpt variant params + has_adapters = False, adapter_inter_size = 0, # gpt variant params + int8_mode = 0, + weights_data_type: np.dtype = np.float32, + shared_contexts_ratio=1.0): super().__init__() self.head_num = head_num self.size_per_head = size_per_head @@ -203,11 +289,21 @@ def __init__(self, self.start_id = start_id self.end_id = end_id self.layer_num = layer_num + # gpt_variant_params + 
self.layernorm_eps = layernorm_eps + self.layernorm_type = layernorm_type + self.activation_type = activation_type + self.has_post_decoder_layernorm = has_post_decoder_layernorm + self.has_adapters = has_adapters + self.adapter_inter_size = adapter_inter_size + # multi-gpu params self.tensor_para_size = tensor_para_size self.pipeline_para_size = pipeline_para_size self.use_sparse_gemm = False self.build_model = False self.int8_mode = int8_mode + self.weights_data_type = weights_data_type + self.shared_contexts_ratio= shared_contexts_ratio assert torch.cuda.is_available(), "CUDA is required for this model." @@ -220,13 +316,17 @@ def __init__(self, # Prepare weights self.weights = GPTWeights(head_num, size_per_head, layer_num, vocab_size, max_seq_len, tensor_para_size, pipeline_para_size, - int8_mode) + has_post_decoder_layernorm = has_post_decoder_layernorm, + has_adapters = has_adapters, + adapter_inter_size = adapter_inter_size, + weights_data_type=weights_data_type, + int8_mode=int8_mode) # Prepare for tensor/pipeline parallel try: dist.init_process_group(backend='mpi') except: - print("[INFO] WARNING: Have initalize the process group") + print("[INFO] WARNING: Have initialized the process group") self.rank = dist.get_rank() self.device_count = torch.cuda.device_count() self.device = self.rank % self.device_count @@ -270,7 +370,10 @@ def cuda(self): self.build_model = False self.model = torch.classes.FasterTransformer.GptOp(self.head_num, self.size_per_head, 4 * self.head_num * self.size_per_head, self.layer_num, self.vocab_size, self.start_id, self.end_id, - self.use_sparse_gemm, self.weights.w) + self.use_sparse_gemm, + self.layernorm_eps, self.layernorm_type, self.activation_type, + self.has_post_decoder_layernorm, self.has_adapters, self.adapter_inter_size, # gpt_variant_params + self.weights.w) self.build_model = True def forward(self, @@ -278,13 +381,13 @@ def forward(self, start_lengths, output_len, beam_width=1, - top_k=1, - top_p=0.0, - beam_search_diversity_rate=0.0, - temperature=1.0, - len_penalty=1.0, - repetition_penalty=1.0, - random_seed=0, + top_k=None, + top_p=None, + beam_search_diversity_rate=None, + temperature=None, + len_penalty=None, + repetition_penalty=None, + random_seed=None, return_output_length=False, return_cum_log_probs=0): if not self.build_model: @@ -299,15 +402,15 @@ def forward(self, outputs = self.model.forward(start_ids, start_lengths, output_len, - beam_width, - top_k, - top_p, - beam_search_diversity_rate, - temperature, - len_penalty, - repetition_penalty, - random_seed, - return_cum_log_probs) + beam_width, # optional, can be None + top_k, # optional, can be None + top_p, # optional, can be None + beam_search_diversity_rate, # optional, can be None + temperature, # optional, can be None + len_penalty, # optional, can be None + repetition_penalty, # optional, can be None + random_seed, # optional, can be None + return_cum_log_probs) # optional, can be None if return_cum_log_probs == 0: output_ids, output_lengths = outputs else: @@ -329,3 +432,168 @@ def set_input_tensor(self, input_tensor): used by internal code to bypass the input provided by the forward_step_func""" self.input_tensor = input_tensor + + +@dataclasses.dataclass +class GptInitModelParameters: + head_num: int + size_per_head: int + layer_num: int + max_seq_len: int + tensor_para_size: int + vocab_size: int + start_id: int + end_id: int + pipeline_para_size: int + weights_data_type: str + has_adapters: bool + adapter_inter_size: int + data_type: str + int8_mode: int + sparse: int + + 
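GptInitModelParameters collects everything needed to instantiate the FT GPT op; from_args() below fills the per-model fields from a converter-generated config.ini and the remaining ones from CLI flags. A minimal sketch of that flow, assuming a config.ini that actually contains the keys from_args() reads (the path and CLI values here are hypothetical):

import argparse
import configparser

parser = argparse.ArgumentParser()
GptInitModelParameters.update_argparser(parser)
args = parser.parse_args(["--model-name", "gpt", "--pipeline-para-size", "1",
                          "--data-type", "fp32", "--int8-mode", "0", "--sparse", "0"])

config_reader = configparser.ConfigParser()
config_reader.read("../models/huggingface-models/c-model/gpt2/1-gpu/config.ini")  # hypothetical path

init_params = GptInitModelParameters.from_args(args, config_reader)
# gpt_init_kwargs() drops the fields that are not constructor arguments
# ("data_type", "sparse"), so the remaining dict can be splatted into the
# GPT model wrapper's constructor.
print(init_params.gpt_init_kwargs())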
def gpt_init_kwargs(self): + do_not_include = ["data_type", "sparse"] + return {k: v for k, v in dataclasses.asdict(self).items() if k not in do_not_include} + + @classmethod + def from_args(cls, args, config_reader): + model_name = args.model_name + + return cls( + head_num=config_reader.getint(model_name, "head_num"), + size_per_head=config_reader.getint(model_name, "size_per_head"), + layer_num=config_reader.getint(model_name, "num_layer"), + max_seq_len=config_reader.getint(model_name, "max_pos_seq_len"), + tensor_para_size=config_reader.getint(model_name, "tensor_para_size"), + vocab_size=config_reader.getint(model_name, "vocab_size"), + start_id=config_reader.getint(model_name, "start_id"), + end_id=config_reader.getint(model_name, "end_id"), + weights_data_type=config_reader.get(model_name, "weight_data_type"), + has_adapters=config_reader.getboolean(model_name, "has_adapters", fallback=False), + adapter_inter_size=config_reader.getint(model_name, "adapter_inter_size", fallback=0), + pipeline_para_size=( + args.pipeline_para_size or config_reader.getint("ft_instance_hyperparameter", "pipeline_para_size") + ), + int8_mode=( + args.int8_mode + if args.int8_mode is not None + else config_reader.getint("ft_instance_hyperparameter", "int8_mode") + ), + data_type=(args.data_type or config_reader.get("ft_instance_hyperparameter", "data_type")), + sparse=int(args.sparse or False), + ) + + @classmethod + def update_argparser(cls, parser): + parser.add_argument("--model-name", type=str, default="gpt", help="Model name from config.ini file") + parser.add_argument("--pipeline-para-size", type=int, help="size of pipeline parallelism") + parser.add_argument("--data-type", type=str, help="data type", choices=["fp32", "bf16", "fp16"]) + parser.add_argument( + "--sparse", + type=int, + choices=[0, 1], + help="Set sparse matrix multiplication. 
(Need SM 8.0 or 8.6 and SPARSITY_SUPPORT=ON)", + ) + parser.add_argument("--int8-mode", type=int, choices=[0, 1], help="Set int8 mode") + + +@dataclasses.dataclass +class GptRuntimeModelParameters: + beam_width: int + top_k: torch.Tensor + top_p: torch.Tensor + beam_search_diversity_rate: torch.Tensor + temperature: torch.Tensor + len_penalty: torch.Tensor + repetition_penalty: torch.Tensor + + def gpt_forward_kwargs(self): + return dataclasses.asdict(self) + + @classmethod + def from_args(cls, args, config_reader, batch_size=None): + bs = args.batch_size + if batch_size is not None: + bs = batch_size + return cls( + beam_width=args.beam_width or config_reader.getint("ft_instance_hyperparameter", "beam_width"), + top_k=(args.sampling_top_k or config_reader.getint("ft_instance_hyperparameter", "top_k")) * torch.ones(size=[bs], dtype=torch.int32), + top_p=(args.sampling_top_p or config_reader.getfloat("ft_instance_hyperparameter", "top_p")) * torch.ones(size=[bs], dtype=torch.float32), + beam_search_diversity_rate=( + args.beam_search_diversity_rate + or config_reader.getfloat("ft_instance_hyperparameter", "beam_search_diversity_rate") + ) * torch.ones(size=[bs], dtype=torch.float32), + temperature=(args.temperature or config_reader.getfloat("ft_instance_hyperparameter", "temperature")) * torch.ones(size=[bs], dtype=torch.float32), + len_penalty=(args.len_penalty or config_reader.getfloat("ft_instance_hyperparameter", "len_penalty")) * torch.ones(size=[bs], dtype=torch.float32), + repetition_penalty=( + args.repetition_penalty or config_reader.getfloat("ft_instance_hyperparameter", "repetition_penalty") + ) * torch.ones(size=[bs], dtype=torch.float32), + ) + + def slice_args(self, idx): + return GptRuntimeModelParameters( + beam_width=self.beam_width, + top_k=self.top_k[idx], + top_p=self.top_p[idx], + beam_search_diversity_rate=self.beam_search_diversity_rate[idx], + temperature=self.temperature[idx], + len_penalty=self.len_penalty[idx], + repetition_penalty=self.repetition_penalty[idx], + ) + + @classmethod + def update_argparser(cls, parser): + parser.add_argument("--beam-width", type=int, help="beam width") + parser.add_argument("--sampling-top-k", type=int, help="Candidate (k) value of top k sampling in decoding") + parser.add_argument("--sampling-top-p", type=float, help="Probability (p) value of top p sampling in decoding.") + parser.add_argument("--temperature", type=float, help="temperature") + parser.add_argument("--len-penalty", type=float, help="len_penalty") + parser.add_argument("--repetition-penalty", type=float, help="repetition penalty") + parser.add_argument("--beam-search-diversity-rate", type=float, help="beam_search_diversity_rate") + + +DEFAULT_START_TAG = "<|endoftext|>" +DEFAULT_END_TAG = "<|endoftext|>" +OPENAI_GPT2_START_ID = 50256 +OPENAI_GPT2_END_ID = 50256 + + +@dataclasses.dataclass +class GptModelConfig: + model_name: str + tensor_para_size: int + head_num: int + size_per_head: int + inter_size: int + num_layer: int + max_pos_seq_len: int + weight_data_type: str + vocab_size: int + start_id: int + end_id: int + + @classmethod + def from_nemo_package( + cls, + *, + args: argparse.Namespace, + nemo_model_config: typing.Dict[str, typing.Any], + bos_id: int, + eos_id: int, + vocab_size: int, + ): + + return cls( + model_name="gpt", + tensor_para_size=args.infer_gpu_num, + head_num=nemo_model_config["num_attention_heads"], + size_per_head=nemo_model_config["hidden_size"] // nemo_model_config["num_attention_heads"], + inter_size=nemo_model_config["ffn_hidden_size"], 
+ num_layer=nemo_model_config["num_layers"], + max_pos_seq_len=nemo_model_config["max_position_embeddings"], + weight_data_type=args.weight_data_type, + vocab_size=vocab_size, + start_id=bos_id, + end_id=eos_id, + ) diff --git a/examples/pytorch/gpt/utils/gpt_token_converter.py b/examples/pytorch/gpt/utils/gpt_token_converter.py index 83c05569d..73e83de6a 100644 --- a/examples/pytorch/gpt/utils/gpt_token_converter.py +++ b/examples/pytorch/gpt/utils/gpt_token_converter.py @@ -24,11 +24,13 @@ def convert_token( vocab_file="../models/gpt2-vocab.json", bpe_file="../models/gpt2-merges.txt", out_file="out", - max_input_length=-1 + max_input_length=-1, + text_out_file=None, ): enc = encoder.get_encoder(vocab_file, bpe_file) tokens_batch = np.loadtxt(out_file, dtype=np.int32) end_id = 50256 + outputs = [] if(tokens_batch.ndim == 1): tokens_batch = tokens_batch.reshape([1, -1]) for batch_num, tokens in enumerate(tokens_batch): @@ -40,6 +42,11 @@ def convert_token( if len(end_index) > 0: end_pos = end_index[0] print(f"[INFO] batch {batch_num}: {enc.decode(tokens[:end_pos])}") + outputs.append(enc.decode(tokens[:end_pos])) + + if text_out_file != None: + with open(text_out_file, "w+") as f: + f.writelines("\n".join(outputs)) return tokens_batch if __name__ == "__main__": diff --git a/examples/pytorch/gpt/utils/gpt_token_encoder.py b/examples/pytorch/gpt/utils/gpt_token_encoder.py index 8a560b6a0..739c32b83 100644 --- a/examples/pytorch/gpt/utils/gpt_token_encoder.py +++ b/examples/pytorch/gpt/utils/gpt_token_encoder.py @@ -50,7 +50,7 @@ def bytes_to_unicode(): The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. - This is a signficant percentage of your normal, say, 32K bpe vocab. + This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup tables between utf-8 bytes and unicode strings. And avoids mapping to whitespace/control characters the bpe code barfs on. 
""" diff --git a/examples/pytorch/gpt/utils/huggingface_gpt_convert.py b/examples/pytorch/gpt/utils/huggingface_gpt_convert.py index a0ef1eab5..6d3c6bad6 100644 --- a/examples/pytorch/gpt/utils/huggingface_gpt_convert.py +++ b/examples/pytorch/gpt/utils/huggingface_gpt_convert.py @@ -100,21 +100,32 @@ def split_and_convert(args): factor = (int)(i_gpu_num / t_gpu_num) - # load position_embedding from rank 0 - model = GPT2Model.from_pretrained(args.in_file) - + # load position_embedding from rank 0 + torch_device = 'cuda' if torch.cuda.is_available() else 'cpu' + model = GPT2Model.from_pretrained(args.in_file).to(torch_device) + + hf_config = vars(model.config) + + # NOTE: save parameters to config files (loaded by triton backends) + config = configparser.ConfigParser() + config["gpt"] = {} try: - config = configparser.ConfigParser() - config["gpt"] = {} - for key in vars(args): - config["gpt"][key] = f"{vars(args)[key]}" - for k, v in vars(model.config).items(): - config["gpt"][k] = f"{v}" - config["gpt"]["weight_data_type"] = args.weight_data_type - with open((Path(saved_dir) / f"config.ini").as_posix(), 'w') as configfile: + config["gpt"]["model_name"] = "gpt" if hf_config["_name_or_path"] == '' else hf_config["_name_or_path"] + config["gpt"]["head_num"] = str(hf_config["n_head"]) + n_embd = hf_config["n_embd"] + config["gpt"]["size_per_head"] = str(n_embd // hf_config["n_head"]) + config["gpt"]["inter_size"] = str(n_embd * 4) + config['gpt']['max_pos_seq_len'] = str(hf_config['n_positions']) + config["gpt"]["num_layer"] = str(hf_config["n_layer"]) + config["gpt"]["vocab_size"] = str(hf_config["vocab_size"]) + config["gpt"]["start_id"] = str(hf_config["bos_token_id"]) + config["gpt"]["end_id"] = str(hf_config["eos_token_id"]) + config['gpt']['weight_data_type'] = args.weight_data_type + with open(saved_dir + "/config.ini", 'w') as configfile: config.write(configfile) except: print(f"Fail to save the config in config.ini.") + np_weight_data_type = get_weight_data_type(args.weight_data_type) huggingface_model_name_pattern = [ @@ -148,6 +159,7 @@ def split_and_convert(args): ] torch.multiprocessing.set_start_method("spawn") + torch.multiprocessing.set_sharing_strategy("file_system") pool = multiprocessing.Pool(args.processes) for name, param in model.named_parameters(): if name.find("weight") == -1 and name.find("bias") == -1: diff --git a/examples/pytorch/gpt/utils/huggingface_jp_gpt_convert.py b/examples/pytorch/gpt/utils/huggingface_jp_gpt_convert.py index bb4245694..9f484581c 100644 --- a/examples/pytorch/gpt/utils/huggingface_jp_gpt_convert.py +++ b/examples/pytorch/gpt/utils/huggingface_jp_gpt_convert.py @@ -98,20 +98,24 @@ def split_and_convert(args): # load position_embedding from rank 0 # model = torch.load(ckpt_name) - model = GPT2Model.from_pretrained(args.in_file) + model = GPT2Model.from_pretrained(args.in_file).to(torch.device('cuda:0')) + + hf_config = vars(model.config) + + config["gpt"]["model_name"] = "gpt" if hf_config["_name_or_path"] == '' else hf_config["_name_or_path"] + config["gpt"]["head_num"] = str(hf_config["n_head"]) + n_embd = hf_config["n_embd"] + config["gpt"]["size_per_head"] = str(n_embd // hf_config["n_head"]) + config["gpt"]["inter_size"] = str(n_embd * 4) + config['gpt']['max_pos_seq_len'] = str(hf_config['n_positions']) + config["gpt"]["num_layer"] = str(hf_config["n_layer"]) + config["gpt"]["vocab_size"] = str(hf_config["vocab_size"]) + config["gpt"]["start_id"] = str(hf_config["bos_token_id"]) + config["gpt"]["end_id"] = str(hf_config["eos_token_id"]) 
+ config['gpt']['weight_data_type'] = args.weight_data_type + with open(output_dir + "/config.ini", 'w') as configfile: + config.write(configfile) - try: - config = configparser.ConfigParser() - config["gpt"] = {} - for key in vars(args): - config["gpt"][key] = f"{vars(args)[key]}" - for k, v in vars(model.config).items(): - config["gpt"][k] = f"{v}" - config["gpt"]["weight_data_type"] = args.weight_data_type - with open((Path(saved_dir) / f"config.ini").as_posix(), 'w') as configfile: - config.write(configfile) - except: - print(f"Fail to save the config in config.ini.") np_weight_data_type = get_weight_data_type(args.weight_data_type) huggingface_model_name_pattern = [ @@ -145,6 +149,7 @@ def split_and_convert(args): ] torch.multiprocessing.set_start_method("spawn") + torch.multiprocessing.set_sharing_strategy("file_system") pool = multiprocessing.Pool(args.processes) for name, param in model.named_parameters(): if name.find("weight") == -1 and name.find("bias") == -1: diff --git a/examples/pytorch/gpt/utils/huggingface_opt_convert.py b/examples/pytorch/gpt/utils/huggingface_opt_convert.py new file mode 100644 index 000000000..27610e597 --- /dev/null +++ b/examples/pytorch/gpt/utils/huggingface_opt_convert.py @@ -0,0 +1,250 @@ +# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +''' +Convert huggingface Meta OPT model. Use https://huggingface.co/facebook/opt-125m as demo. +''' + +import argparse +import configparser +import multiprocessing +import numpy as np +from pathlib import Path +import torch + +import os +import sys +from datetime import datetime +from transformers import OPTForCausalLM, AutoModelForCausalLM # transformers-4.20.0.dev0 +dir_path = os.path.dirname(os.path.realpath(__file__)) +sys.path.append(dir_path + "/../../../..") +sys.path.append(dir_path) + +def get_weight_data_type(data_type): + if data_type == "fp32": + return np.float32 + elif data_type == "fp16": + return np.float16 + else: + assert False, f"Invalid weight data type {data_type}" + +def split_and_convert_process(i, saved_dir,factor,key,args, val): + + if key.find("input_layernorm.weight") != -1 or key.find("input_layernorm.bias") != -1 or \ + key.find("attention.dense.bias") != -1 or key.find("post_attention_layernorm.weight") != -1 or \ + key.find("post_attention_layernorm.bias") != -1 or key.find("mlp.dense_4h_to_h.bias") != -1 or \ + key.find("final_layernorm.weight") != -1 or key.find("final_layernorm.bias") != -1: + + # shared weights, only need to convert the weights of rank 0 + if i == 0: + saved_path = saved_dir + "/model." + key + ".bin" + val.tofile(saved_path) + + elif key.find("attention.dense.weight") != -1 or key.find("mlp.dense_4h_to_h.weight") != -1: + split_vals = np.split(val, factor, axis=0) + for j in range(factor): + saved_path = saved_dir + "/model." 
+ key + ".%d.bin" % (i * factor + j) + split_vals[j].tofile(saved_path) + + elif key.find("mlp.dense_h_to_4h.weight") != -1 or key.find("mlp.dense_h_to_4h.bias") != -1: + + split_vals = np.split(val, factor, axis=-1) + for j in range(factor): + saved_path = saved_dir + "/model." + key + ".%d.bin" % (i * factor + j) + split_vals[j].tofile(saved_path) + + elif key.find("attention.query_key_value.bias") != -1: + local_dim = (int)(val.shape[-1] / 3) + + val = val.reshape(3, local_dim) + split_vals = np.split(val, factor, axis=-1) + + for j in range(factor): + saved_path = saved_dir + "/model." + key + ".%d.bin" % (i * factor + j) + split_vals[j].tofile(saved_path) + + elif key.find("attention.query_key_value.weight") != -1: + hidden_dim = val.shape[0] + local_dim = (int)(val.shape[-1] / 3) + + val = val.reshape(hidden_dim, 3, local_dim) + split_vals = np.split(val, factor, axis=-1) + + for j in range(factor): + saved_path = saved_dir + "/model." + key + ".%d.bin" % (i * factor + j) + split_vals[j].tofile(saved_path) + + else: + print("[ERROR] cannot find key '{}'".format(key)) + +def fuse_qkv_weight(q, k, v): + qkv = torch.cat([q, k, v], dim=-1) + return qkv + +def split_and_convert(args): + saved_dir = args.saved_dir + "/%d-gpu/" % args.infer_gpu_num + + if(os.path.exists(saved_dir) == False): + os.makedirs(saved_dir) + ckpt_name = args.in_file + + t_gpu_num = args.trained_gpu_num + i_gpu_num = args.infer_gpu_num + assert(i_gpu_num % t_gpu_num == 0) + + factor = (int)(i_gpu_num / t_gpu_num) + + # load position_embedding from rank 0 + model = AutoModelForCausalLM.from_pretrained(args.in_file) + + hf_config = vars(model.config) + + num_layers = hf_config["num_hidden_layers"] + + layer_names = [name for name, param in model.named_parameters()] + + # NOTE: save parameters to config files (loaded by triton backends) + config = configparser.ConfigParser() + config["gpt"] = {} + has_post_decoder_layernorm = "model.decoder.final_layer_norm.bias" in layer_names + try: + config["gpt"]["model_name"] = "opt" if hf_config["_name_or_path"] == '' else hf_config["_name_or_path"] + config["gpt"]["head_num"] = str(hf_config["num_attention_heads"]) + n_embd = hf_config["hidden_size"] + config["gpt"]["size_per_head"] = str(n_embd // hf_config["num_attention_heads"]) + config["gpt"]["inter_size"] = str(hf_config["ffn_dim"]) + config['gpt']['max_pos_seq_len'] = str(hf_config['max_position_embeddings']) + config["gpt"]["num_layer"] = str(hf_config["num_hidden_layers"]) + config["gpt"]["layernorm_eps"] = "1e-5"; + config["gpt"]["layernorm_type"] = "pre_layernorm" if hf_config["do_layer_norm_before"] else "post_layernorm" + config["gpt"]["activation_type"] = "Relu" + config["gpt"]["has_post_decoder_layernorm"] = "1" if has_post_decoder_layernorm else "0" + config["gpt"]["vocab_size"] = str(hf_config["vocab_size"]) + config["gpt"]["start_id"] = str(hf_config["bos_token_id"]) + config["gpt"]["end_id"] = str(hf_config["eos_token_id"]) + config['gpt']['weight_data_type'] = args.weight_data_type + with open(saved_dir + "/config.ini", 'w') as configfile: + config.write(configfile) + except: + print(f"Fail to save the config in config.ini.") + + np_weight_data_type = get_weight_data_type(args.weight_data_type) + + huggingface_model_name_pattern = [ + "self_attn_layer_norm.bias", + "self_attn_layer_norm.weight", + "self_attn.qkv_proj.bias", + "self_attn.qkv_proj.weight", + "self_attn.out_proj.bias", + "self_attn.out_proj.weight", + "final_layer_norm.bias", + "final_layer_norm.weight", + "fc1.bias", + "fc1.weight", + 
"fc2.bias", + "fc2.weight", + ] + + ft_model_name_pattern = [ + "input_layernorm.bias", + "input_layernorm.weight", + "attention.query_key_value.bias", + "attention.query_key_value.weight", + "attention.dense.bias", + "attention.dense.weight", + "post_attention_layernorm.bias", + "post_attention_layernorm.weight", + "mlp.dense_h_to_4h.bias", + "mlp.dense_h_to_4h.weight", + "mlp.dense_4h_to_h.bias", + "mlp.dense_4h_to_h.weight", + ] + + model_named_parameters_iter = model.named_parameters() + model_named_parameters = dict() + for name, param in model_named_parameters_iter: + if name.find("embed") != -1: + model_named_parameters[name] = param + elif name.find("project_in") != -1: + model_named_parameters[name] = param.permute(1, 0) + elif name.find("project_out") != -1: + model_named_parameters[name] = param + else: + model_named_parameters[name] = param.permute(1, 0) if len(param.shape) == 2 else param + # print(model_named_parameters.keys()) + for l in range(num_layers): + q_weight = model_named_parameters[f'model.decoder.layers.{l}.self_attn.q_proj.weight'] + k_weight = model_named_parameters[f'model.decoder.layers.{l}.self_attn.k_proj.weight'] + v_weight = model_named_parameters[f'model.decoder.layers.{l}.self_attn.v_proj.weight'] + q_bias = model_named_parameters[f'model.decoder.layers.{l}.self_attn.q_proj.bias'] + k_bias = model_named_parameters[f'model.decoder.layers.{l}.self_attn.k_proj.bias'] + v_bias = model_named_parameters[f'model.decoder.layers.{l}.self_attn.v_proj.bias'] + qkv_weight = fuse_qkv_weight(q_weight, k_weight, v_weight) + qkv_bias = fuse_qkv_weight(q_bias, k_bias, v_bias) + model_named_parameters[f'model.decoder.layers.{l}.self_attn.qkv_proj.weight'] = qkv_weight + model_named_parameters[f'model.decoder.layers.{l}.self_attn.qkv_proj.bias'] = qkv_bias + + torch.multiprocessing.set_start_method("spawn") + torch.multiprocessing.set_sharing_strategy("file_system") + pool = multiprocessing.Pool(args.processes) + padding_offset = 2 + for name, param in model_named_parameters.items(): + if name == 'model.decoder.embed_positions.weight': + param[padding_offset:,...].detach().cpu().numpy().astype(np_weight_data_type).tofile(saved_dir + "model.wpe.bin") + elif name == 'model.decoder.embed_tokens.weight': + if 'model.decoder.project_in.weight' in model_named_parameters.keys(): + project_in = model_named_parameters['model.decoder.project_in.weight'] + project_out = model_named_parameters['model.decoder.project_out.weight'] + torch.matmul(param, project_in).detach().cpu().numpy().astype(np_weight_data_type).tofile(saved_dir + "model.wte.bin") + torch.matmul(param, project_out).detach().cpu().numpy().astype(np_weight_data_type).tofile(saved_dir + "model.lm_head.weight.bin") + else: + param.detach().cpu().numpy().astype(np_weight_data_type).tofile(saved_dir + "model.wte.bin") + param.detach().cpu().numpy().astype(np_weight_data_type).tofile(saved_dir + "model.lm_head.weight.bin") + elif name == 'model.decoder.final_layer_norm.weight': + param.detach().cpu().numpy().astype(np_weight_data_type).tofile(saved_dir + "model.final_layernorm.weight.bin") + elif name == 'model.decoder.final_layer_norm.bias': + param.detach().cpu().numpy().astype(np_weight_data_type).tofile(saved_dir + "model.final_layernorm.bias.bin") + elif name.find("project_in") != -1 or name.find("project_out") != -1: + continue + else: + for i in range(len(huggingface_model_name_pattern)): + if name.find(huggingface_model_name_pattern[i]) != -1: + new_name = name.replace("model.decoder.layers.", 
"layers.").replace(huggingface_model_name_pattern[i], ft_model_name_pattern[i]) + pool.starmap(split_and_convert_process, + [(0, saved_dir, factor, new_name, args, + param.detach().cpu().numpy().astype(np_weight_data_type))], ) + + pool.close() + pool.join() + +if __name__ == "__main__": + parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) + parser.add_argument('-saved_dir', '-o', type=str, help='file name of output file', required=True) + parser.add_argument('-in_file', '-i', type=str, help='file name of input checkpoint file', required=True) + parser.add_argument('-trained_gpu_num', '-t_g', type=int, help='How many gpus for inference', default=1) + parser.add_argument('-infer_gpu_num', '-i_g', type=int, help='How many gpus for inference', required=True) + parser.add_argument("-processes", "-p", type=int, help="How many processes to spawn for conversion (default: 4)", default=4) + parser.add_argument("-weight_data_type", type=str, default="fp32", choices=["fp32", "fp16"]) + + args = parser.parse_args() + print("\n=============== Argument ===============") + for key in vars(args): + print(f"{key}: {vars(args)[key]}") + print("========================================") + + start_time = datetime.now() + split_and_convert(args) + stop_time = datetime.now() + run_time = (stop_time - start_time) + print(f"[INFO] Spend {run_time} (h:m:s) to convert the model") diff --git a/examples/pytorch/gpt/utils/megatron_ckpt_convert.py b/examples/pytorch/gpt/utils/megatron_ckpt_convert.py index ec76ef0d6..62616376b 100644 --- a/examples/pytorch/gpt/utils/megatron_ckpt_convert.py +++ b/examples/pytorch/gpt/utils/megatron_ckpt_convert.py @@ -14,156 +14,224 @@ import argparse import configparser +import datetime +import json import multiprocessing -from pathlib import Path +import pathlib +import re +import shutil +import sys import numpy as np import torch # pytype: disable=import-error -def get_weight_data_type(data_type): - if data_type == "fp32": - return np.float32 - elif data_type == "fp16": - return np.float16 - else: - assert False, f"Invalid weight data type {data_type}" +# verify if root package is in PYTHONPATH +__root_package_path__ = pathlib.Path(__file__).parent.parent.parent.parent.parent.absolute().as_posix() +if __root_package_path__ not in sys.path: + print( + f"[ERROR] add project root directory to your PYTHONPATH with " + f"'export PYTHONPATH={__root_package_path__}:${{PYTHONPATH}}'" + ) -# more to less. e.g., trained by 8 gpus, infer by 2 gpus -def merge_and_convert(args): # noqa: C901 too complex - saved_dir = Path(args.saved_dir) - if args.fused_qkv == 1: - saved_dir = saved_dir / f"{args.infer_gpu_num:d}-gpu/" +from examples.pytorch.gpt.utils.gpt import DEFAULT_START_TAG, DEFAULT_END_TAG, OPENAI_GPT2_START_ID, OPENAI_GPT2_END_ID +from examples.pytorch.utils import torch2np, safe_transpose, cpu_map_location, gpu_map_location, WEIGHT2DTYPE + + +def _inject_model_parallel_rank( + filepath, + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + tensor_model_parallel_rank=0, + pipeline_model_parallel_rank=0, +): + """ + Injects tensor/pipeline model parallel ranks into the filepath. + Does nothing if not using model parallelism. 
+ """ + filepath = pathlib.Path(filepath) + if tensor_model_parallel_size > 1 or pipeline_model_parallel_size > 1: + # filepath needs to be updated to include mp_rank + if pipeline_model_parallel_size is None or pipeline_model_parallel_size == 1: + filepath = filepath.parent / f"mp_rank_{tensor_model_parallel_rank:02d}" / filepath.name + else: + filepath = ( + filepath.parent / + f"mp_rank_{tensor_model_parallel_rank:02d}_{pipeline_model_parallel_rank:03d}" / + filepath.name + ) + if not filepath.exists(): + filepath = ( + filepath.parent / + f"tp_rank_{tensor_model_parallel_rank:02d}_pp_rank_{pipeline_model_parallel_rank:03d}" / + filepath.name + ) + return filepath else: - saved_dir = saved_dir / f"unfusedQKV-{args.infer_gpu_num:d}-gpu" - ckpt_ver = args.checkpoint_version + if filepath.exists(): + return filepath + else: + return filepath.parent / "mp_rank_00" / filepath.name - saved_dir.mkdir(parents=True, exist_ok=True) - - config = configparser.ConfigParser() - config["gpt"] = {} - for key in vars(args): - config["gpt"][key] = f"{vars(args)[key]}" - config["gpt"]["weight_data_type"] = args.weight_data_type - with open((saved_dir / f"config.ini").as_posix(), 'w') as configfile: - config.write(configfile) - np_weight_data_type = get_weight_data_type(args.weight_data_type) - - prefix = Path(args.in_file) - ckpt_name = "model_optim_rng.pt" - t_gpu_num = args.trained_gpu_num - i_gpu_num = args.infer_gpu_num +def _create_model_training_args_for_checkpoint_version_0(args, model_00): + model_training_args = argparse.Namespace() + if args.head_num is None or args.trained_tensor_parallel_size is None: + raise ValueError( + "Provided checkpoint have missing training args. " + "Thus it is required to provide -head_num and -trained_tensor_parallel_size CLI arguments" + ) + model_training_args.num_attention_heads = args.head_num + model_training_args.tensor_model_parallel_size = args.trained_tensor_parallel_size + # megatron ckpt_ver=0 only supports pipeline_parallel_size = 1 + model_training_args.pipeline_model_parallel_size = 1 + model_training_args.max_position_embeddings = \ + model_00["model"]["language_model"]["embedding"]["position_embeddings"]["weight"].shape[0] + model_training_args.hidden_size = \ + model_00["model"]["language_model"]["embedding"]["position_embeddings"]["weight"].shape[1] + model_training_args.ffn_hidden_size = 4 * model_training_args.hidden_size + + def get_layer_num_from_weights(model_keys): + layer_num = 1 + for key in model_keys: + if re.search(r'\d+', key) is not None: + layer_num = max(int(re.search(r'\d+', key).group()), layer_num) + return layer_num + 1 + + model_training_args.num_layers = \ + get_layer_num_from_weights(model_00["model"]["language_model"]['transformer'].keys()) + + model_training_args.layernorm_epsilon = 1e-6 + + return model_training_args + + +# This tool is used to support the new megatron model trained by pipeline parallel + tensor parallel +def merge_and_convert_process(i, pipeline_para_rank, saved_dir, factor, key, model_training_args, transformer_model_list, ckpt_ver, np_weight_data_type): + saved_dir = pathlib.Path(saved_dir) + if key.find("layers.") != -1: + layer_index = (int)(key[7 : key.find(".", 7)]) + saved_key = key.replace( + "layers.%d." % layer_index, + "layers.%d." 
% (layer_index + pipeline_para_rank * model_training_args.num_layers // model_training_args.pipeline_model_parallel_size)) + + if saved_key.find("self_attention") != -1: + saved_key = saved_key.replace("self_attention", "attention") + if saved_key.find("adaptor1") != -1: + saved_key = saved_key.replace("adaptor1", "after_attention_adapter") + if saved_key.find("adaptor2") != -1: + saved_key = saved_key.replace("adaptor2", "after_ffn_adapter") + else: + saved_key = key + major_device = transformer_model_list[0][key].device - assert t_gpu_num % i_gpu_num == 0 - factor = int(t_gpu_num / i_gpu_num) + if ( + key.find("input_layernorm.weight") != -1 + or key.find("input_layernorm.bias") != -1 + or key.find("attention.dense.bias") != -1 + or key.find("post_attention_layernorm.weight") != -1 + or key.find("post_attention_layernorm.bias") != -1 + or key.find("mlp.dense_4h_to_h.bias") != -1 + or key.find("adaptor1.dense_4h_to_h.bias") != -1 + or key.find("adaptor2.dense_4h_to_h.bias") != -1 + or key.find("final_layernorm.weight") != -1 + or key.find("final_layernorm.bias") != -1): - # load position_embedding from rank 0 - model_00 = torch.load((prefix / "mp_rank_00" / ckpt_name).as_posix()) - model_00["model"]["language_model"]["embedding"]["position_embeddings"]["weight"].cpu().numpy().astype( - np_weight_data_type - ).tofile( - (saved_dir / "model.wpe.bin").as_posix() - ) # not weight, do not need transpose + # shared weights, only need to convert the weights of rank 0 + if i == 0: + saved_path = saved_dir / f"model.{saved_key}.bin" + val = safe_transpose(transformer_model_list[0][key]) + val = torch2np(val, np_weight_data_type) + val = np.squeeze(val) + val.tofile(saved_path) + + elif (key.find("attention.dense.weight") != -1 + or key.find("mlp.dense_4h_to_h.weight") != -1 + or key.find("adaptor1.dense_4h_to_h.weight") != -1 + or key.find("adaptor2.dense_4h_to_h.weight") != -1): + vals = [ + safe_transpose(transformer_model_list[k][key]).float().to(major_device) + for k in range(factor) + ] + val = torch.cat(vals, dim=0) + val = torch2np(val, np_weight_data_type) + saved_path = saved_dir / f"model.{saved_key}.{i:d}.bin" + val.tofile(saved_path) + + elif (key.find("mlp.dense_h_to_4h.weight") != -1 + or key.find("adaptor1.dense_h_to_4h.weight") != -1 + or key.find("adaptor2.dense_h_to_4h.weight") != -1 + or key.find("mlp.dense_h_to_4h.bias") != -1 + or key.find("adaptor1.dense_h_to_4h.bias") != -1 + or key.find("adaptor2.dense_h_to_4h.bias") != -1): + vals = [ + safe_transpose(transformer_model_list[k][key]).float().to(major_device) + for k in range(factor) + ] + val = torch.cat(vals, dim=-1) + val = torch2np(val, np_weight_data_type) + saved_path = saved_dir / f"model.{saved_key}.{i:d}.bin" + val.tofile(saved_path) - del model_00 - w_e_list = [] - for i in range(i_gpu_num): - transformer_models = [] - for j in range(factor): - model = torch.load(prefix / f"mp_rank_{i * factor + j:02d}" / ckpt_name) - w_e_list.append( - model["model"]["language_model"]["embedding"]["word_embeddings"]["weight"] - .cpu() - .numpy() - .astype(np_weight_data_type) - ) + elif key.find("attention.query_key_value.bias") != -1: + vals = [] + for k in range(factor): + val = safe_transpose(transformer_model_list[k][key]).float() + local_dim = int(val.shape[-1] / 3) if ckpt_ver == 3: - transformer_models.append(model["model"]["language_model"]["encoder"]) - else: - transformer_models.append(model["model"]["language_model"]["transformer"]) - - for key in transformer_models[0]: - print(key, transformer_models[0][key].shape) - - 
if ( - key.find("input_layernorm.weight") != -1 - or key.find("input_layernorm.bias") != -1 - or key.find("attention.dense.bias") != -1 - or key.find("post_attention_layernorm.weight") != -1 - or key.find("post_attention_layernorm.bias") != -1 - or key.find("mlp.dense_4h_to_h.bias") != -1 - or key.find("final_layernorm.weight") != -1 - or key.find("final_layernorm.bias") != -1 - ): - - # shared weights, only need to convert the weights of rank 0 - if i == 0: - val = transformer_models[0][key].T.cpu().numpy() - saved_path = saved_dir / f"model.{key}.bin" - np.squeeze(val).astype(np_weight_data_type).tofile(saved_path.as_posix()) - - elif key.find("attention.dense.weight") != -1 or key.find("mlp.dense_4h_to_h.weight") != -1: - vals = [] - for k in range(factor): - vals.append(transformer_models[k][key].T.cpu().numpy()) - saved_path = saved_dir / f"model.{key}.{i}.bin" - np.concatenate(vals, axis=0).astype(np_weight_data_type).tofile(saved_path.as_posix()) - - elif key.find("mlp.dense_h_to_4h.weight") != -1 or key.find("mlp.dense_h_to_4h.bias") != -1: - - vals = [] - for k in range(factor): - vals.append(transformer_models[k][key].T.cpu().numpy()) - saved_path = saved_dir / f"model.{key}.{i}.bin" - np.concatenate(vals, axis=-1).astype(np_weight_data_type).tofile(saved_path.as_posix()) - - elif key.find("attention.query_key_value.bias") != -1: - vals = [] - for k in range(factor): - val = transformer_models[k][key].T.cpu().numpy() - local_dim = (int)(val.shape[-1] / 3) - if ckpt_ver == 3: - num_splits = 3 - head_num = args.head_num // args.trained_gpu_num - size_per_head = local_dim // head_num - val = val.reshape(head_num, num_splits, size_per_head) - val = val.transpose(1, 0, 2) - val = val.reshape(3, local_dim) - vals.append(val) - - saved_path = saved_dir / f"model.{key}.{i}.bin" - np.concatenate(vals, axis=-1).astype(np_weight_data_type).tofile(saved_path.as_posix()) - - elif key.find("attention.query_key_value.weight") != -1: - vals = [] - for k in range(factor): - val = transformer_models[k][key].T.cpu().numpy() - hidden_dim = val.shape[0] - local_dim = (int)(val.shape[-1] / 3) - if ckpt_ver == 3: - num_splits = 3 - head_num = args.head_num - size_per_head = hidden_dim // head_num - head_num = head_num // args.trained_gpu_num - val = val.reshape(hidden_dim, head_num, num_splits, size_per_head) - val = val.transpose(0, 2, 1, 3) - val = val.reshape(hidden_dim, 3, local_dim) - vals.append(val) - - saved_path = saved_dir / f"model.{key}.{i}.bin" - if args.fused_qkv == 1: - np.concatenate(vals, axis=-1).astype(np_weight_data_type).tofile(saved_path.as_posix()) - elif args.fused_qkv == 0: - np.concatenate(vals, axis=-1).transpose(1, 0, 2).astype(np_weight_data_type).tofile(saved_path.as_posix()) + num_splits = 3 + head_num = model_training_args.num_attention_heads // model_training_args.tensor_model_parallel_size + size_per_head = local_dim // head_num + val = val.reshape(head_num, num_splits, size_per_head) + val = val.permute(1, 0, 2) + val = val.reshape(3, local_dim) + vals.append(val.to(major_device)) + val = torch.cat(vals, dim=-1) + val = torch2np(val, np_weight_data_type) + saved_path = saved_dir / f"model.{saved_key}.{i:d}.bin" + val.tofile(saved_path) - else: - print(f"[ERROR] cannot find key '{key}'") + elif key.find("attention.query_key_value.weight") != -1: + vals = [] + for k in range(factor): + val = safe_transpose(transformer_model_list[k][key]).float() + hidden_dim = val.shape[0] + local_dim = int(val.shape[-1] / 3) + if ckpt_ver == 3: + num_splits = 3 + head_num = 
model_training_args.num_attention_heads + size_per_head = hidden_dim // head_num + head_num = head_num // model_training_args.tensor_model_parallel_size + val = val.reshape(hidden_dim, head_num, num_splits, size_per_head) + val = val.permute(0, 2, 1, 3) + val = val.reshape(hidden_dim, 3, local_dim) + vals.append(val.to(major_device)) + val = torch.cat(vals, dim=-1) + val = torch2np(val, np_weight_data_type) + saved_path = saved_dir / f"model.{saved_key}.{i:d}.bin" + val.tofile(saved_path) + + else: + print(f"[ERROR] cannot find key '{key}'") - np.concatenate(w_e_list, axis=0).tofile((saved_dir / "model.wte.bin").as_posix()) +def split_and_convert_process(i, pipeline_para_rank, saved_dir, factor, key, model_training_args, transformer_model_list, ckpt_ver, np_weight_data_type): + val = safe_transpose(transformer_model_list[0][key]) + val = torch2np(val, np_weight_data_type) + if key.find("layers.") != -1: + layer_index = (int)(key[7 : key.find(".", 7)]) + saved_key = key.replace( + "layers.%d." % layer_index, + "layers.%d." % (layer_index + pipeline_para_rank * model_training_args.num_layers // model_training_args.pipeline_model_parallel_size)) + + if saved_key.find("self_attention") != -1: + saved_key = saved_key.replace("self_attention", "attention") + if saved_key.find("adaptor1") != -1: + saved_key = saved_key.replace("adaptor1", "after_attention_adapter") + if saved_key.find("adaptor2") != -1: + saved_key = saved_key.replace("adaptor2", "after_ffn_adapter") + else: + saved_key = key -def split_and_convert_process(i, saved_dir, factor, key, args, val, ckpt_ver): - saved_dir = Path(saved_dir) if ( key.find("input_layernorm.weight") != -1 or key.find("input_layernorm.bias") != -1 @@ -171,36 +239,43 @@ def split_and_convert_process(i, saved_dir, factor, key, args, val, ckpt_ver): or key.find("post_attention_layernorm.weight") != -1 or key.find("post_attention_layernorm.bias") != -1 or key.find("mlp.dense_4h_to_h.bias") != -1 + or key.find("adaptor1.dense_4h_to_h.bias") != -1 + or key.find("adaptor2.dense_4h_to_h.bias") != -1 or key.find("final_layernorm.weight") != -1 or key.find("final_layernorm.bias") != -1 ): - # shared weights, only need to convert the weights of rank 0 if i == 0: - saved_path = saved_dir / f"model.{key}.bin" + saved_path = saved_dir / f"model.{saved_key}.bin" val.tofile(saved_path.as_posix()) - elif key.find("attention.dense.weight") != -1 or key.find("mlp.dense_4h_to_h.weight") != -1: + elif (key.find("attention.dense.weight") != -1 + or key.find("mlp.dense_4h_to_h.weight") != -1 + or key.find("adaptor1.dense_4h_to_h.weight") != -1 + or key.find("adaptor2.dense_4h_to_h.weight") != -1): split_vals = np.split(val, factor, axis=0) for j in range(factor): - saved_path = saved_dir / f"model.{key}.{i * factor + j}.bin" + saved_path = saved_dir / f"model.{saved_key}.{i * factor + j:d}.bin" split_vals[j].tofile(saved_path.as_posix()) - elif key.find("mlp.dense_h_to_4h.weight") != -1 or key.find("mlp.dense_h_to_4h.bias") != -1: - + elif (key.find("mlp.dense_h_to_4h.weight") != -1 + or key.find("adaptor1.dense_h_to_4h.weight") != -1 + or key.find("adaptor2.dense_h_to_4h.weight") != -1 + or key.find("mlp.dense_h_to_4h.bias") != -1 + or key.find("adaptor1.dense_h_to_4h.bias") != -1 + or key.find("adaptor2.dense_h_to_4h.bias") != -1): split_vals = np.split(val, factor, axis=-1) for j in range(factor): - saved_path = saved_dir / f"model.{key}.{i * factor + j}.bin" + saved_path = saved_dir / f"model.{saved_key}.{i * factor + j:d}.bin" split_vals[j].tofile(saved_path.as_posix()) elif 
key.find("attention.query_key_value.bias") != -1: - local_dim = (int)(val.shape[-1] / 3) + local_dim = int(val.shape[-1] / 3) if ckpt_ver == 3: num_splits = 3 - head_num = args.head_num // args.trained_gpu_num + head_num = model_training_args.num_attention_heads // model_training_args.tensor_model_parallel_size size_per_head = local_dim // head_num - val = val.reshape(head_num, num_splits, size_per_head) val = val.transpose(1, 0, 2) @@ -208,125 +283,274 @@ def split_and_convert_process(i, saved_dir, factor, key, args, val, ckpt_ver): split_vals = np.split(val, factor, axis=-1) for j in range(factor): - saved_path = saved_dir / f"model.{key}.{i * factor + j}.bin" + saved_path = saved_dir / f"model.{saved_key}.{i * factor + j:d}.bin" split_vals[j].tofile(saved_path.as_posix()) elif key.find("attention.query_key_value.weight") != -1: hidden_dim = val.shape[0] - local_dim = (int)(val.shape[-1] / 3) + local_dim = int(val.shape[-1] / 3) if ckpt_ver == 3: num_splits = 3 - head_num = args.head_num + head_num = model_training_args.num_attention_heads size_per_head = hidden_dim // head_num - head_num = head_num // args.trained_gpu_num - + head_num = head_num // model_training_args.tensor_model_parallel_size val = val.reshape(hidden_dim, head_num, num_splits, size_per_head) val = val.transpose(0, 2, 1, 3) - if args.fused_qkv == 1: - val = val.reshape(hidden_dim, 3, local_dim) - elif args.fused_qkv == 0: - val = val.reshape(hidden_dim, 3, local_dim).transpose(1, 0, 2) + val = val.reshape(hidden_dim, 3, local_dim) split_vals = np.split(val, factor, axis=-1) for j in range(factor): - saved_path = saved_dir / f"model.{key}.{i * factor + j}.bin" + saved_path = saved_dir / f"model.{saved_key}.{i * factor + j:d}.bin" split_vals[j].tofile(saved_path.as_posix()) else: print(f"[ERROR] cannot find key '{key}'") -# less to more. 
e.g., trained by 2 gpus, infer by 8 gpus -def split_and_convert(args): - saved_dir = Path(args.saved_dir) - if args.fused_qkv == 1: - saved_dir = saved_dir / f"{args.infer_gpu_num}-gpu" - else: - saved_dir = saved_dir / f"unfusedQKV-{args.infer_gpu_num}-gpu/" +def _get_checkpoint_name(checkpoint_dir): - saved_dir.mkdir(parents=True, exist_ok=True) - - config = configparser.ConfigParser() - config["gpt"] = {} - for key in vars(args): - config["gpt"][key] = f"{vars(args)[key]}" - config["gpt"]["weight_data_type"] = args.weight_data_type - with open((saved_dir / f"config.ini").as_posix(), 'w') as configfile: - config.write(configfile) + checkpoint_dir = pathlib.Path(checkpoint_dir) + patterns = [ + "model_optim_rng.pt", # older megatron checkpoints + "*last.ckpt", # newer format of checkpoints + ] + for pattern in patterns: + model_files = sorted(list(checkpoint_dir.rglob(pattern))) + if model_files: + return model_files[0].name - np_weight_data_type = get_weight_data_type(args.weight_data_type) - prefix = Path(args.in_file) - ckpt_name = "model_optim_rng.pt" - t_gpu_num = args.trained_gpu_num - i_gpu_num = args.infer_gpu_num - assert i_gpu_num % t_gpu_num == 0 + raise ValueError(f"Could not find checkpoint files in {checkpoint_dir}") - factor = int(i_gpu_num / t_gpu_num) + +def convert_checkpoint(args): + saved_dir = pathlib.Path(args.saved_dir) / f"{args.infer_gpu_num:d}-gpu" + if saved_dir.exists(): + print(f"[ERROR] Remove {saved_dir} target directory before running conversion") + sys.exit(1) + saved_dir.mkdir(parents=True) + + if args.vocab_path: + shutil.copy(args.vocab_path, (saved_dir / "vocab.json").as_posix()) + if args.merges_path: + shutil.copy(args.merges_path, (saved_dir / "merges.txt").as_posix()) + + load_checkpoints_to_cpu = bool(args.load_checkpoints_to_cpu) + map_location_fn = cpu_map_location if load_checkpoints_to_cpu else gpu_map_location + + checkpoints_dir = pathlib.Path(args.in_file) + checkpoint_name = _get_checkpoint_name(checkpoints_dir) # load position_embedding from rank 0 - model_00 = torch.load((prefix / "mp_rank_00" / ckpt_name).as_posix()) - model_00["model"]["language_model"]["embedding"]["position_embeddings"]["weight"].cpu().numpy().astype( - np_weight_data_type - ).tofile( - (saved_dir / "model.wpe.bin").as_posix() - ) # not weight, do not need transpose + checkpoints_paths = sorted(checkpoints_dir.rglob(checkpoint_name)) + if not checkpoints_paths: + print(f"[ERROR] Cannot find checkpoint in {checkpoints_dir}.") + exit(1) + model_00 = torch.load(checkpoints_paths[0].as_posix(), map_location=map_location_fn) + + if "hyper_parameters" in list(model_00.keys()): + print("Use nemo_ckpt_converter.py script for conversion of this checkpoint") + exit(1) + elif "args" in list(model_00.keys()): + checkpoint_version = model_00["checkpoint_version"] + model_training_args = model_00["args"] + megatron_gpt_key = "encoder" + else: + checkpoint_version = 0 + model_training_args = _create_model_training_args_for_checkpoint_version_0(args, model_00) + megatron_gpt_key = "transformer" + + with (saved_dir / "args.txt").open("w") as training_args_file: + for k, v in vars(model_training_args).items(): + training_args_file.write(f"{k}:{v}\n") + + np_weight_data_type = WEIGHT2DTYPE[args.weight_data_type] + + val = model_00["model"]["language_model"]["embedding"]["position_embeddings"]["weight"] + val = torch2np(val, np_weight_data_type) + val.tofile((saved_dir / "model.wpe.bin").as_posix()) # not weight, do not need to transpose del model_00 w_e_list = [] - pool = 
multiprocessing.Pool(8) - for i in range(t_gpu_num): - m = torch.load(prefix / f"mp_rank_{i:02d}" / ckpt_name) - if args.checkpoint_version == 3: - transformer_model = m["model"]["language_model"]["encoder"] - else: - transformer_model = m["model"]["language_model"]["transformer"] - - w_e_list.append( - m["model"]["language_model"]["embedding"]["word_embeddings"]["weight"].cpu().numpy().astype(np_weight_data_type) - ) + training_tensor_para_size = model_training_args.tensor_model_parallel_size + training_pipeline_para_size = model_training_args.pipeline_model_parallel_size + inference_tensor_para_size = args.infer_gpu_num + + model_weights_paths = [ + [ + _inject_model_parallel_rank( + checkpoints_dir / checkpoint_name, + tensor_model_parallel_size=training_tensor_para_size, + pipeline_model_parallel_size=training_pipeline_para_size, + tensor_model_parallel_rank=tp_rank, + pipeline_model_parallel_rank=pp_rank, + ) + for pp_rank in range(training_pipeline_para_size) + ] + for tp_rank in range(training_tensor_para_size) + ] + + if training_tensor_para_size > inference_tensor_para_size: + assert training_tensor_para_size % inference_tensor_para_size == 0 + is_merge_ckpt = True + factor = int(training_tensor_para_size / inference_tensor_para_size) + else: + assert inference_tensor_para_size % training_tensor_para_size == 0 + is_merge_ckpt = False + factor = int(inference_tensor_para_size / training_tensor_para_size) - pool.starmap( - split_and_convert_process, - [ - ( - i, - saved_dir, - factor, - k, - args, - transformer_model[k].T.cpu().numpy().astype(np_weight_data_type), - args.checkpoint_version, - ) - for (k, v) in transformer_model.items() - ], - ) + main_loop = min(training_tensor_para_size, inference_tensor_para_size) + vocab_size_list = [0 for i in range(main_loop)] + + torch.multiprocessing.set_start_method("spawn") + torch.multiprocessing.set_sharing_strategy("file_system") + pool = multiprocessing.Pool(args.processes) + has_adapters = False + for i in range(main_loop): + for j in range(training_pipeline_para_size): + + transformer_models = [] + if is_merge_ckpt: + for k in range(factor): + m = torch.load(model_weights_paths[i * factor + k][j].as_posix(), map_location=map_location_fn) + if not has_adapters: + has_adapters = any("adaptor" in key for key in m['model']['language_model'][megatron_gpt_key].keys()) + transformer_models.append(m["model"]["language_model"][megatron_gpt_key]) + + if j == 0: + vocab_size_list[i] = m["model"]["language_model"]["embedding"]["word_embeddings"]["weight"].shape[0] + w_e_list.append(torch2np(m["model"]["language_model"]["embedding"]["word_embeddings"]["weight"], np_weight_data_type)) + else: + m = torch.load(model_weights_paths[i][j].as_posix(), map_location=map_location_fn) + if not has_adapters: + has_adapters = any("adaptor" in key for key in m['model']['language_model'][megatron_gpt_key].keys()) + + if j == 0: + vocab_size_list[i] = m["model"]["language_model"]["embedding"]["word_embeddings"]["weight"].shape[0] + w_e_list.append(torch2np( + m["model"]["language_model"]["embedding"]["word_embeddings"]["weight"], + np_weight_data_type + )) + transformer_models.append(m["model"]["language_model"][megatron_gpt_key]) + + pool.starmap( + merge_and_convert_process if is_merge_ckpt else split_and_convert_process, + [ + ( + i, + j, + saved_dir, + factor, + k, + model_training_args, + transformer_models, + checkpoint_version, + np_weight_data_type, + ) + for (k, v) in transformer_models[0].items() + ], + ) pool.close() pool.join() + 
torch.cuda.synchronize() + np.concatenate(w_e_list, axis=0).tofile((saved_dir / "model.wte.bin").as_posix()) + # save vocab_size + full_vocab_size = sum(vocab_size_list) + if not hasattr(model_training_args, "padded_vocab_size"): + model_training_args.padded_vocab_size = full_vocab_size -if __name__ == "__main__": + # Configuration for the model (load by triton backends) + config = configparser.ConfigParser() + config["gpt"] = {} + + if args.vocab_path: + vocab_path = pathlib.Path(args.vocab_path) + with vocab_path.open("r") as vocab_file: + vocab = json.load(vocab_file) + start_id, end_id = vocab[DEFAULT_START_TAG], vocab[DEFAULT_END_TAG] + else: + # hard coded values from english gpt_vocab.json file + start_id, end_id = str(OPENAI_GPT2_START_ID), str(OPENAI_GPT2_END_ID) + try: + config["gpt"]["model_name"] = "gpt" + config["gpt"]["head_num"] = str(model_training_args.num_attention_heads) + config["gpt"]["size_per_head"] = str(model_training_args.hidden_size // model_training_args.num_attention_heads) + config["gpt"]["inter_size"] = str(model_training_args.ffn_hidden_size) + config["gpt"]["num_layer"] = str(model_training_args.num_layers) + config["gpt"]["max_pos_seq_len"] = str(model_training_args.max_position_embeddings) + config["gpt"]["vocab_size"] = str(model_training_args.padded_vocab_size) + config["gpt"]["has_adapters"] = str(has_adapters) + config['gpt']['adapter_inter_size'] = str(model_training_args.project_size) if has_adapters else str(0) + config["gpt"]["layernorm_eps"] = str(model_training_args.layernorm_epsilon) + config["gpt"]["start_id"] = str(start_id) + config["gpt"]["end_id"] = str(end_id) + config["gpt"]["weight_data_type"] = args.weight_data_type + config["gpt"]["tensor_para_size"] = str(args.infer_gpu_num) + with open((saved_dir / f"config.ini").as_posix(), 'w') as configfile: + config.write(configfile) + except Exception as e: + print(f"Fail to save the config in config.ini: {e}") + + +def main(): parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) - parser.add_argument("-saved_dir", "-o", type=str, help="file name of output file", required=True) - parser.add_argument("-in_file", "-i", type=str, help="file name of input checkpoint file", required=True) - parser.add_argument("-trained_gpu_num", "-t_g", type=int, help="How many gpus for inference", required=True) - parser.add_argument("-infer_gpu_num", "-i_g", type=int, help="How many gpus for inference", required=True) + parser.add_argument("--saved-dir", "-saved_dir", "-o", help="folder name of output files", required=True) parser.add_argument( - "-fused_qkv", - "-fused_qkv", + "--in-file", "-in_file", "-i", help="file name of input checkpoint file", required=True + ) + parser.add_argument( + "--infer-gpu-num", "-infer_gpu_num", "-i_g", type=int, help="How many gpus for inference", required=True + ) + # -h_n and -t_g are needed when megatron_ckpt_version = 0, for example the public megatron 345M gpt model + parser.add_argument( + "--head-num", + "-head_num", + "-h_n", + type=int, + help="The number of heads, only needed when weight doesn't contain structure hyperparameters" + ) + parser.add_argument( + "--trained-tensor-parallel-size", + "-trained_tensor_parallel_size", + "-t_g", + type=int, + help="the tensor parallel size for training" + ) + parser.add_argument( + "--processes", + "-processes", + "-p", + type=int, + default=16, + help="How many processes to spawn for conversion", + ) + parser.add_argument( + "--weight-data-type", "-weight_data_type", choices=["fp32", "fp16"], 
default="fp32", help="" + ) + parser.add_argument( + "--load-checkpoints-to-cpu", + "-load_checkpoints_to_cpu", + "-cpu", type=int, - default=1, - help="Fuse the qkv weights or not. Default is true (1)", choices=[0, 1], + default=1, + help="Whether to load model weights to CPU", + ) + parser.add_argument( + "--vocab-path", + type=str, + help="Path to vocabulary file to embed in FasterTransformer checkpoint", + required=False, + ) + parser.add_argument( + "--merges-path", type=str, help="Path to merges file to embed in FasterTransformer checkpoint", required=False ) - parser.add_argument("-head_num", "-h_n", type=int, help="Number of heads", required=True) - parser.add_argument("-checkpoint_version", type=int, default=0, help="Checkpoint version of Megatron-LM") - parser.add_argument("-weight_data_type", type=str, default="fp32", choices=["fp32", "fp16"]) args = parser.parse_args() print("\n=============== Argument ===============") @@ -334,8 +558,11 @@ def split_and_convert(args): print(f"{key}: {vars(args)[key]}") print("========================================") - if args.trained_gpu_num > args.infer_gpu_num: - merge_and_convert(args) - else: - split_and_convert(args) + start_time = datetime.datetime.now() + convert_checkpoint(args) + run_time = datetime.datetime.now() - start_time + print(f"[INFO] Spent {run_time} (h:m:s) to convert the model") + +if __name__ == "__main__": + main() diff --git a/examples/pytorch/gpt/utils/nemo_ckpt_convert.py b/examples/pytorch/gpt/utils/nemo_ckpt_convert.py index fe274b8ea..cbca62f93 100644 --- a/examples/pytorch/gpt/utils/nemo_ckpt_convert.py +++ b/examples/pytorch/gpt/utils/nemo_ckpt_convert.py @@ -12,434 +12,648 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import os import argparse import configparser +import dataclasses +import datetime +import logging import multiprocessing -from pathlib import Path -import tarfile +import os +import pathlib +import shutil +import sys import tempfile +import typing + import numpy as np import torch # pytype: disable=import-error import yaml -def get_weight_data_type(data_type): - if data_type == "fp32": - return np.float32 - elif data_type == "fp16": - return np.float16 +# verify if root package is in PYTHONPATH +__root_package_path__ = pathlib.Path(__file__).parent.parent.parent.parent.parent.absolute().as_posix() +if __root_package_path__ not in sys.path: + print( + f"[ERROR] add project root directory to your PYTHONPATH with " + f"'export PYTHONPATH={__root_package_path__}:${{PYTHONPATH}}'" + ) + +from examples.pytorch.gpt.utils.gpt import GptModelConfig +from examples.pytorch.nemo import ( + UnpackedNemoCheckpointDir, + unpack_nemo_ckpt, + extract_layers_with_prefix, +) +from examples.pytorch.utils import ( + torch2np, + safe_transpose, + cpu_map_location, + gpu_map_location, + WEIGHT2DTYPE, +) + + +LOGGER = logging.getLogger(__name__) + + +# This tool is used to support the new NeMo megatron model trained by pipeline parallel + tensor parallel +def merge_and_convert_process( + tp_rank: int, + pp_rank: int, + saved_dir: typing.Union[str, pathlib.Path], + factor: int, + key: str, + nemo_model_config: typing.Dict[str, typing.Any], + transformer_model_list: typing.List, + np_weight_data_type, + args: argparse.Namespace, +): + # Config params + num_layers = nemo_model_config["num_layers"] + num_attention_heads = nemo_model_config["num_attention_heads"] + tensor_model_parallel_size = nemo_model_config.get("tensor_model_parallel_size", 1) + pipeline_model_parallel_size = nemo_model_config.get("pipeline_model_parallel_size", 1) + + if key.find("layers.") != -1: + layer_index = int(key[7 : key.find(".", 7)]) + saved_key = key.replace( + "layers.%d." % layer_index, + "layers.%d." % (layer_index + pp_rank * num_layers // pipeline_model_parallel_size), + ) + + if saved_key.find("self_attention") != -1: + saved_key = saved_key.replace("self_attention", "attention") else: - assert False, f"Invalid weight data type {data_type}" - -def unpack_nemo_ckpt(nemo_ckpt_path, out_folder): - """ - .nemo file is an archive (tar.gz) with the following: - model_config.yaml - model configuration in .yaml format. 
You can deserialize this into cfg argument for model's constructor - model_wights.chpt - model checkpoint - """ - if not os.path.exists(nemo_ckpt_path): - raise FileNotFoundError(f"{nemo_ckpt_path} does not exist") - tar = tarfile.open(nemo_ckpt_path, "r:gz") - tar.extractall(path=out_folder) - tar.close() - return out_folder - - -def _cpu_map_location(storage, loc): - return storage.cpu() - - -def _gpu_map_location(storage, loc): - if loc.startswith("cuda"): - training_gpu_idx = int(loc.split(":")[1]) - inference_gpu_idx = training_gpu_idx % torch.cuda.device_count() - return storage.cuda(inference_gpu_idx) - elif loc.startswith("cpu"): - return storage.cpu() + saved_key = key + + if ( + key.find("input_layernorm.weight") != -1 + or key.find("input_layernorm.bias") != -1 + or key.find("attention.dense.bias") != -1 + or key.find("post_attention_layernorm.weight") != -1 + or key.find("post_attention_layernorm.bias") != -1 + or key.find("mlp.dense_4h_to_h.bias") != -1 + or key.find("final_layernorm.weight") != -1 + or key.find("final_layernorm.bias") != -1 + ): + + # shared weights, only need to convert the weights of rank 0 + if tp_rank == 0: + val = safe_transpose(transformer_model_list[0][key]) + val = torch2np(val, np_weight_data_type) + saved_path = saved_dir / f"model.{saved_key}.bin" + np.squeeze(val).tofile(saved_path) + + elif key.find("attention.dense.weight") != -1 or key.find("mlp.dense_4h_to_h.weight") != -1: + vals = [] + for k in range(factor): + val = safe_transpose(transformer_model_list[k][key]) + val = torch2np(val, np_weight_data_type) + vals.append(val) + saved_path = saved_dir / f"model.{saved_key}.{tp_rank:d}.bin" + np.concatenate(vals, axis=0).tofile(saved_path) + + elif key.find("mlp.dense_h_to_4h.weight") != -1 or key.find("mlp.dense_h_to_4h.bias") != -1: + vals = [] + for k in range(factor): + val = safe_transpose(transformer_model_list[k][key]) + val = torch2np(val, np_weight_data_type) + vals.append(val) + saved_path = saved_dir / f"model.{saved_key}.{tp_rank:d}.bin" + np.concatenate(vals, axis=-1).tofile(saved_path) + + elif key.find("attention.query_key_value.bias") != -1: + vals = [] + for k in range(factor): + val = safe_transpose(transformer_model_list[k][key]) + val = torch2np(val, np_weight_data_type) + local_dim = int(val.shape[-1] / 3) + num_splits = 3 + head_num = num_attention_heads // tensor_model_parallel_size + size_per_head = local_dim // head_num + val = val.reshape(head_num, num_splits, size_per_head) + val = val.transpose(1, 0, 2) + val = val.reshape(3, local_dim) + vals.append(val) + + saved_path = saved_dir / f"model.{saved_key}.{tp_rank:d}.bin" + np.concatenate(vals, axis=-1).tofile(saved_path) + + elif key.find("attention.query_key_value.weight") != -1: + vals = [] + for k in range(factor): + val = safe_transpose(transformer_model_list[k][key]) + val = torch2np(val, np_weight_data_type) + hidden_dim = val.shape[0] + local_dim = int(val.shape[-1] / 3) + num_splits = 3 + head_num = num_attention_heads // tensor_model_parallel_size + size_per_head = local_dim // head_num + val = val.reshape(hidden_dim, head_num, num_splits, size_per_head) + val = val.transpose(0, 2, 1, 3) + val = val.reshape(hidden_dim, 3, local_dim) + vals.append(val) + + saved_path = saved_dir / f"model.{saved_key}.{tp_rank:d}.bin" + if args.fused_qkv == 1: + np.concatenate(vals, axis=-1).tofile(saved_path) + elif args.fused_qkv == 0: + np.concatenate(vals, axis=-1).transpose(1, 0, 2).tofile(saved_path) else: - raise NotImplementedError(f"Not handled {loc}") + 
LOGGER.error("cannot find key '%s'", key) + + +def split_and_convert_process( + tp_rank: int, + pp_rank: int, + saved_dir: typing.Union[str, pathlib.Path], + factor: int, + key: str, + nemo_model_config: typing.Dict[str, typing.Any], + transformer_model_list: typing.List, + np_weight_data_type, + args: argparse.Namespace, +): + + # Config params + num_layers = nemo_model_config["num_layers"] + num_attention_heads = nemo_model_config["num_attention_heads"] + tensor_model_parallel_size = nemo_model_config.get("tensor_model_parallel_size", 1) + pipeline_model_parallel_size = nemo_model_config.get("pipeline_model_parallel_size", 1) + + # Handle model[key] weights + transformer_model = transformer_model_list[0] + val = safe_transpose(transformer_model[key]) + val = torch2np(val, np_weight_data_type) + if key.find("layers.") != -1: + layer_index = (int)(key[7 : key.find(".", 7)]) + saved_key = key.replace( + "layers.%d." % layer_index, + "layers.%d." % (layer_index + pp_rank * num_layers // pipeline_model_parallel_size), + ) + if saved_key.find("self_attention") != -1: + saved_key = saved_key.replace("self_attention", "attention") + else: + saved_key = key + + if ( + key.find("input_layernorm.weight") != -1 + or key.find("input_layernorm.bias") != -1 + or key.find("attention.dense.bias") != -1 + or key.find("post_attention_layernorm.weight") != -1 + or key.find("post_attention_layernorm.bias") != -1 + or key.find("mlp.dense_4h_to_h.bias") != -1 + or key.find("final_layernorm.weight") != -1 + or key.find("final_layernorm.bias") != -1 + ): + # shared weights, only need to convert the weights of rank 0 + if tp_rank == 0: + saved_path = saved_dir / f"model.{saved_key}.bin" + val.tofile(saved_path) + + elif key.find("attention.dense.weight") != -1 or key.find("mlp.dense_4h_to_h.weight") != -1: + split_vals = np.split(val, factor, axis=0) + for j in range(factor): + saved_path = saved_dir / f"model.{saved_key}.{tp_rank * factor + j:d}.bin" + split_vals[j].tofile(saved_path) + + elif key.find("mlp.dense_h_to_4h.weight") != -1 or key.find("mlp.dense_h_to_4h.bias") != -1: + split_vals = np.split(val, factor, axis=-1) + for j in range(factor): + saved_path = saved_dir / f"model.{saved_key}.{tp_rank * factor + j:d}.bin" + split_vals[j].tofile(saved_path) + + elif key.find("attention.query_key_value.bias") != -1: + local_dim = int(val.shape[-1] / 3) + + num_splits = 3 + head_num = num_attention_heads // tensor_model_parallel_size + size_per_head = local_dim // head_num + val = val.reshape(head_num, num_splits, size_per_head) + val = val.transpose(1, 0, 2) + + val = val.reshape(3, local_dim) + split_vals = np.split(val, factor, axis=-1) + + for j in range(factor): + saved_path = saved_dir / f"model.{saved_key}.{tp_rank * factor + j:d}.bin" + split_vals[j].tofile(saved_path) + + elif key.find("attention.query_key_value.weight") != -1: + hidden_dim = val.shape[0] + local_dim = int(val.shape[-1] / 3) + + num_splits = 3 + head_num = num_attention_heads + size_per_head = hidden_dim // head_num + head_num = head_num // tensor_model_parallel_size + val = val.reshape(hidden_dim, head_num, num_splits, size_per_head) + val = val.transpose(0, 2, 1, 3) + + val = val.reshape(hidden_dim, 3, local_dim) + split_vals = np.split(val, factor, axis=-1) + + for j in range(factor): + saved_path = saved_dir / f"model.{saved_key}.{tp_rank * factor + j:d}.bin" + split_vals[j].tofile(saved_path) -# more to less. 
e.g., trained by 8 gpus, infer by 2 gpus -def merge_and_convert( - args, model_config, weight_files, *, load_checkpoints_to_cpu: bool = False -): # noqa: C901 too complex - saved_dir = Path(args.saved_dir) - if args.fused_qkv == 1: - saved_dir = saved_dir / f"{args.infer_gpu_num:d}-gpu/" else: - saved_dir = saved_dir / f"unfusedQKV-{args.infer_gpu_num:d}-gpu" + LOGGER.error("cannot find key '%s'", key) - saved_dir.mkdir(parents=True, exist_ok=True) - config = configparser.ConfigParser() - config["gpt"] = {} +def convert_checkpoint(unpacked_checkpoints_dir: UnpackedNemoCheckpointDir, args): + nemo_model_config = unpacked_checkpoints_dir.model_config - try: - for key in vars(args): - config["gpt"][key] = f"{vars(args)[key]}" - for k, v in model_config.items(): - config["gpt"][k] = f"{v}" - config["gpt"]["weight_data_type"] = args.weight_data_type - with open((saved_dir / f"config.ini").as_posix(), 'w') as configfile: - config.write(configfile) - except: - print(f"Fail to save the config in config.ini.") - - np_weight_data_type = get_weight_data_type(args.weight_data_type) - - prefix = Path(args.in_file) - i_gpu_num = args.infer_gpu_num - - t_gpu_num = model_config["tensor_model_parallel_size"] - num_attention_heads = model_config["num_attention_heads"] - - assert t_gpu_num % i_gpu_num == 0 - factor = int(t_gpu_num / i_gpu_num) - - num_checkpoints_per_convert = max(factor, 1) - if num_checkpoints_per_convert > torch.cuda.device_count(): - print( - f"[WARNING] Need to load #{num_checkpoints_per_convert} checkpoints at once " - f"while having {torch.cuda.device_count()} GPUs. Force load checkpoints on CPU" - ) - load_checkpoints_to_cpu = True + checkpoints_paths = unpacked_checkpoints_dir.get_checkpoints_paths( + nemo_model_config.get("tensor_model_parallel_size", 1), + nemo_model_config.get("pipeline_model_parallel_size", 1), + ) - map_location_fn = _cpu_map_location if load_checkpoints_to_cpu else _gpu_map_location + # if checkpoints files could be found - start preparing output dir + saved_dir = _prepare_saved_dir(args) - # load position_embedding from rank 0 - model_00 = torch.load(weight_files[0], map_location=map_location_fn) - model_00["model.language_model.embedding.position_embeddings.weight"].float().cpu().numpy().astype( - np_weight_data_type - ).tofile( - (saved_dir / "model.wpe.bin").as_posix() - ) # not weight, do not need transpose + map_location_fn = cpu_map_location if bool(args.load_checkpoints_to_cpu) else gpu_map_location + np_weight_data_type = WEIGHT2DTYPE[args.weight_data_type] + # load position_embedding from rank 0 + model_00 = torch.load(checkpoints_paths[0][0], map_location=map_location_fn) + val = model_00.get("state_dict", model_00)["model.language_model.embedding.position_embeddings.weight"] + # not weight, do not need to transpose + val = torch2np(val, np_weight_data_type) + val.tofile(saved_dir / "model.wpe.bin") del model_00 + w_e_list = [] - for i in range(i_gpu_num): - transformer_models = [] - for j in range(factor): - model = torch.load(weight_files[i * factor + j], map_location=map_location_fn) - - w_e_list.append( - model["model.language_model.embedding.word_embeddings.weight"] - .float() - .cpu() - .numpy() - .astype(np_weight_data_type) - ) - prefix = "model.language_model.encoder" - model["model"] = {} - model["model"]["language_model"] = {} - model["model"]["language_model"]["encoder"] = {} - model["model"]["language_model"]["embedding"] = {} - model["model"]["language_model"]["embedding"]["word_embeddings"] = {} - 
model["model"]["language_model"]["embedding"]["position_embeddings"] = {} - model["model"]["language_model"]["embedding"]["word_embeddings"]["weight"] = model[ - "model.language_model.embedding.word_embeddings.weight"] - model["model"]["language_model"]["embedding"]["position_embeddings"]["weight"] = model[ - "model.language_model.embedding.position_embeddings.weight"] - for key in model.keys(): - if prefix in key: - first = key[:len(prefix)] - second = key[len(prefix) + 1:] - model["model"]["language_model"]["encoder"][second] = model[key] - - # print(model["model"]["language_model"]["encoder"].keys()) - - # this model should be able to load into megatron - # torch.save(model, "model.pt") - - transformer_models.append(model["model"]["language_model"]["encoder"]) - - for key in transformer_models[0]: - if ( - key.find("input_layernorm.weight") != -1 - or key.find("input_layernorm.bias") != -1 - or key.find("attention.dense.bias") != -1 - or key.find("post_attention_layernorm.weight") != -1 - or key.find("post_attention_layernorm.bias") != -1 - or key.find("mlp.dense_4h_to_h.bias") != -1 - or key.find("final_layernorm.weight") != -1 - or key.find("final_layernorm.bias") != -1 - ): - # shared weights, only need to convert the weights of rank 0 - if i == 0: - val = transformer_models[0][key].T.float().cpu().numpy() - key = key.replace("self_attention", "attention") - saved_path = saved_dir / f"model.{key}.bin" - np.squeeze(val).astype(np_weight_data_type).tofile(saved_path.as_posix()) - - elif key.find("attention.dense.weight") != -1 or key.find("mlp.dense_4h_to_h.weight") != -1: - vals = [] - for k in range(factor): - vals.append(transformer_models[k][key].T.float().cpu().numpy()) - key = key.replace("self_attention", "attention") - saved_path = saved_dir / f"model.{key}.{i}.bin" - np.concatenate(vals, axis=0).astype(np_weight_data_type).tofile(saved_path.as_posix()) + training_tensor_para_size = nemo_model_config.get("tensor_model_parallel_size", 1) + training_pipeline_para_size = nemo_model_config.get("pipeline_model_parallel_size", 1) + inference_tensor_para_size = args.infer_gpu_num - elif key.find("mlp.dense_h_to_4h.weight") != -1 or key.find("mlp.dense_h_to_4h.bias") != -1: + if training_tensor_para_size > inference_tensor_para_size: + assert training_tensor_para_size % inference_tensor_para_size == 0 + is_merge_ckpt = True + factor = int(training_tensor_para_size / inference_tensor_para_size) + else: + assert inference_tensor_para_size % training_tensor_para_size == 0 + is_merge_ckpt = False + factor = int(inference_tensor_para_size / training_tensor_para_size) - vals = [] - for k in range(factor): - vals.append(transformer_models[k][key].T.float().cpu().numpy()) - saved_path = saved_dir / f"model.{key}.{i}.bin" - np.concatenate(vals, axis=-1).astype(np_weight_data_type).tofile(saved_path.as_posix()) + main_loop = min(training_tensor_para_size, inference_tensor_para_size) - elif key.find("attention.query_key_value.bias") != -1: - vals = [] - for k in range(factor): - val = transformer_models[k][key].T.float().cpu().numpy() - local_dim = (int)(val.shape[-1] / 3) - num_splits = 3 - head_num = num_attention_heads // t_gpu_num - size_per_head = local_dim // head_num - val = val.reshape(head_num, num_splits, size_per_head) - val = val.transpose(1, 0, 2) - val = val.reshape(3, local_dim) - vals.append(val) - - key = key.replace("self_attention", "attention") - saved_path = saved_dir / f"model.{key}.{i}.bin" - np.concatenate(vals, 
axis=-1).astype(np_weight_data_type).tofile(saved_path.as_posix()) - - elif key.find("attention.query_key_value.weight") != -1: - vals = [] - for k in range(factor): - val = transformer_models[k][key].T.float().cpu().numpy() - hidden_dim = val.shape[0] - local_dim = (int)(val.shape[-1] / 3) - num_splits = 3 - head_num = num_attention_heads // t_gpu_num - size_per_head = local_dim // head_num - val = val.reshape(hidden_dim, head_num, num_splits, size_per_head) - val = val.transpose(0, 2, 1, 3) - val = val.reshape(hidden_dim, 3, local_dim) - vals.append(val) - - key = key.replace("self_attention", "attention") - saved_path = saved_dir / f"model.{key}.{i}.bin" - if args.fused_qkv == 1: - np.concatenate(vals, axis=-1).astype(np_weight_data_type).tofile(saved_path.as_posix()) - elif args.fused_qkv == 0: - np.concatenate(vals, axis=-1).transpose(1, 0, 2).astype(np_weight_data_type).tofile(saved_path.as_posix()) + torch.multiprocessing.set_start_method("spawn") + torch.multiprocessing.set_sharing_strategy("file_system") + pool = multiprocessing.Pool(args.processes) + for i in range(main_loop): + for j in range(training_pipeline_para_size): + transformer_models = [] + if is_merge_ckpt: + for k in range(factor): + rank_weights = checkpoints_paths[i * factor + k][j] + model = torch.load(rank_weights, map_location=map_location_fn) + if j == 0: + val = model.get("state_dict", model)["model.language_model.embedding.word_embeddings.weight"] + val = torch2np(val, np_weight_data_type) + w_e_list.append(val) + layers = extract_layers_with_prefix(model, "model.language_model.encoder.") + transformer_models.append(layers) else: - print(f"[ERROR] cannot find key '{key}'") + rank_weights = checkpoints_paths[i][j] + model = torch.load(rank_weights, map_location=map_location_fn) + + if j == 0: + val = model.get("state_dict", model)["model.language_model.embedding.word_embeddings.weight"] + val = torch2np(val, np_weight_data_type) + w_e_list.append(val) + layers = extract_layers_with_prefix(model, "model.language_model.encoder.") + transformer_models.append(layers) + + pool.starmap( + merge_and_convert_process if is_merge_ckpt else split_and_convert_process, + [ + ( + i, # tp_rank + j, # pp_rank + saved_dir, + factor, + key, + nemo_model_config, + transformer_models, + np_weight_data_type, + args, + ) + for key in transformer_models[0] + ], + ) - np.concatenate(w_e_list, axis=0).tofile((saved_dir / "model.wte.bin").as_posix()) + pool.close() + pool.join() + + val = np.concatenate(w_e_list, axis=0) + val.tofile(saved_dir / "model.wte.bin") + + vocab_size = val.shape[0] + + tokenizer_config = nemo_model_config["tokenizer"] + tokenizer_config = _update_tokenizer_config(tokenizer_config, unpacked_checkpoints_dir) + if args.tokenizer_model_path: + LOGGER.debug("Use tokenizer model passed from CLI: %s", args.tokenizer_model_path) + tokenizer_config["model"] = args.tokenizer_model_path + if args.vocab_path: + LOGGER.debug("Use tokenizer vocab passed from CLI: %s", args.vocab_path) + tokenizer_config["vocab_file"] = args.vocab_path + if args.merges_path: + LOGGER.debug("Use tokenizer merge passed from CLI: %s", args.merges_path) + tokenizer_config["merge_file"] = args.merges_path + + _copy_tokenizer_file_if_defined("model", tokenizer_config["model"], saved_dir) + _copy_tokenizer_file_if_defined("vocab_file", tokenizer_config["vocab_file"], saved_dir) + _copy_tokenizer_file_if_defined("merge_file", tokenizer_config["merge_file"], saved_dir) + + bos_id, eos_id = _get_special_tokens_ids(tokenizer_config) + + 
gpt_model_config = GptModelConfig.from_nemo_package( + args=args, + nemo_model_config=nemo_model_config, + vocab_size=vocab_size, + bos_id=bos_id, + eos_id=eos_id, + ) + + # Configuration for the model (load by triton backends) + config = configparser.ConfigParser() + config["gpt"] = {k: str(v) for k, v in dataclasses.asdict(gpt_model_config).items()} + try: + config_path = saved_dir / "config.ini" + with config_path.open("w") as config_file: + config.write(config_file) + except Exception as e: + LOGGER.error("Fail to save the config; %s", e) -def split_and_convert(args, model_config, weight_files, *, load_checkpoints_to_cpu: bool = False): - saved_dir = Path(args.saved_dir) +def _prepare_saved_dir(args): + saved_dir = pathlib.Path(args.saved_dir) if args.fused_qkv == 1: saved_dir = saved_dir / f"{args.infer_gpu_num:d}-gpu/" else: saved_dir = saved_dir / f"unfusedQKV-{args.infer_gpu_num:d}-gpu" + if saved_dir.exists(): + LOGGER.error(f"Remove %s target directory before running conversion", saved_dir) + sys.exit(1) + saved_dir.mkdir(parents=True) + return saved_dir - saved_dir.mkdir(parents=True, exist_ok=True) - - config = configparser.ConfigParser() - config["gpt"] = {} - try: - for key in vars(args): - config["gpt"][key] = f"{vars(args)[key]}" - for k, v in model_config.items(): - config["gpt"][k] = f"{v}" - config["gpt"]["weight_data_type"] = args.weight_data_type - with open((saved_dir / f"config.ini").as_posix(), 'w') as configfile: - config.write(configfile) - except: - print(f"Fail to save the config in config.ini.") - - np_weight_data_type = get_weight_data_type(args.weight_data_type) - prefix = Path(args.in_file) - - i_gpu_num = args.infer_gpu_num - t_gpu_num = model_config["tensor_model_parallel_size"] - num_attention_heads = model_config["num_attention_heads"] - - assert i_gpu_num % t_gpu_num == 0 - factor = int(i_gpu_num / t_gpu_num) - - num_checkpoints_per_convert = max(int(1 / (i_gpu_num / t_gpu_num)), 1) - if num_checkpoints_per_convert > torch.cuda.device_count(): - print( - f"[WARNING] Need to load #{num_checkpoints_per_convert} checkpoints at once " - f"while having {torch.cuda.device_count()} GPUs. 
Force load checkpoints on CPU" - ) - load_checkpoints_to_cpu = True - map_location_fn = _cpu_map_location if load_checkpoints_to_cpu else _gpu_map_location +def prompt_convert(args, prompt_config, prompt_weights): - # load position_embedding from rank 0 - model_00 = torch.load(weight_files[0], map_location=map_location_fn) - model_00["model.language_model.embedding.position_embeddings.weight"].float().cpu().numpy().astype( - np_weight_data_type - ).tofile( - (saved_dir / "model.wpe.bin").as_posix() - ) # not weight, do not need transpose - del model_00 + prompt_templates = prompt_config["task_templates"] - w_e_list = [] + # model config save dir + config_saved_dir = _prepare_saved_dir(args) - # main_loop = min(t_gpu_num, i_gpu_num) - for i in range(t_gpu_num): - model = torch.load(weight_files[i], map_location=map_location_fn) + # Configuration for the model (load by triton backends) + config_path = config_saved_dir / "config.ini" + config = configparser.ConfigParser() + with config_path.open("r") as config_file: + config.read_file(config_file) + + num_tasks = len(prompt_templates) + prompt_learning_type = 3 # p_prompt_tuning + prompt_learning_start_id = 50257 # hard code here + config["gpt"]["num_tasks"] = str(num_tasks) + config["gpt"]["prompt_learning_start_id"] = str(prompt_learning_start_id) + config["gpt"]["prompt_learning_type"] = str(prompt_learning_type) + + for task_name_id, prompt_task in enumerate(prompt_templates): + prompt_task_name = prompt_task["taskname"] + prompt_length = int(prompt_task["total_virtual_tokens"]) + config[f"task_{task_name_id:d}"] = {} + config[f"task_{task_name_id:d}"]["task_name"] = prompt_task_name + config[f"task_{task_name_id:d}"]["prompt_length"] = str(prompt_length) + prompt_task_weights = prompt_weights["prompt_table"][ + f"prompt_table.{prompt_task_name}.prompt_embeddings.weight" + ] + # put converted prompts weights to the model weights saved dir + prompt_task_weights_output_path = config_saved_dir / f"model.prompt_table.{prompt_task_name}.weight.bin" + val = torch2np(prompt_task_weights) + val.tofile(prompt_task_weights_output_path) + + with config_path.open("w") as config_file: + config.write(config_file) + + LOGGER.info(">>>>>>>>>>>>>>>> model saved config") + LOGGER.info(config_path.read_text()) + + +def _update_tokenizer_config(tokenizer_config: typing.Dict, unpacked_checkpoints_dir): + def _update_config_entry(key, file_pattern): + old_file_path = tokenizer_config[key] + if old_file_path: + LOGGER.debug("tokenizer %s %s type %s", key, old_file_path, type(old_file_path)) + old_file_path = pathlib.Path(old_file_path) + new_file_path = unpacked_checkpoints_dir.get_tokenizer_file_path("tokenizer", key, file_pattern) + if new_file_path: + LOGGER.debug("Update tokenizer %s %s -> %s", key, old_file_path, new_file_path) + tokenizer_config[key] = new_file_path.as_posix() + elif not old_file_path.exists(): + LOGGER.warning("Because tokenizer %s %s does not exists - set it as None", key, old_file_path) + tokenizer_config[key] = None + + _update_config_entry("model", "*.model") + _update_config_entry("vocab_file", "*vocab*") + _update_config_entry("merge_file", "*merge*.txt") + + return tokenizer_config + + +def _copy_tokenizer_file_if_defined(key_name, tokenizer_file_path, saved_dir): + if tokenizer_file_path: + tokenizer_file_path = pathlib.Path(tokenizer_file_path) + if tokenizer_file_path.exists(): + tokenizer_basename = { + "model": "tokenizer", + "vocab_file": "vocab", + "merge_file": "merges", + }[key_name] + dst_path = saved_dir / 
f"{tokenizer_basename}{tokenizer_file_path.suffix}" + LOGGER.debug("Copy of %s %s file as %s", tokenizer_file_path, key_name, dst_path) + shutil.copy(tokenizer_file_path.as_posix(), dst_path.as_posix()) + else: + LOGGER.debug("%s %s file does not exists", tokenizer_file_path, key_name) + + +def _get_special_tokens_ids(tokenizer_config: typing.Dict): + from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer + from examples.pytorch.tokenizer import add_special_tokens_to_tokenizer + + logging.getLogger("git.cmd").setLevel(logging.INFO) + logging.getLogger("h5py._conv").setLevel(logging.INFO) + logging.getLogger("matplotlib").setLevel(logging.INFO) + logging.getLogger("matplotlib.font_manager").setLevel(logging.INFO) + logging.getLogger("matplotlib.pyplot").setLevel(logging.INFO) + + tokenizer = get_nmt_tokenizer( + library=tokenizer_config["library"], + model_name=tokenizer_config["type"], + tokenizer_model=tokenizer_config["model"], + vocab_file=tokenizer_config["vocab_file"], + merges_file=tokenizer_config["merge_file"], + legacy=True, + ) - w_e_list.append( - model["model.language_model.embedding.word_embeddings.weight"] - .float() - .cpu() - .numpy() - .astype(np_weight_data_type) - ) + if tokenizer_config["library"] == "sentencepiece": + add_special_tokens_to_tokenizer(tokenizer) - prefix = "model.language_model.encoder" - # Build dictionary - model["model"] = {} - model["model"]["language_model"] = {} - model["model"]["language_model"]["encoder"] = {} - model["model"]["language_model"]["embedding"] = {} - model["model"]["language_model"]["embedding"]["word_embeddings"] = {} - model["model"]["language_model"]["embedding"]["position_embeddings"] = {} - model["model"]["language_model"]["embedding"]["word_embeddings"]["weight"] = model[ - "model.language_model.embedding.word_embeddings.weight"] - model["model"]["language_model"]["embedding"]["position_embeddings"]["weight"] = model[ - "model.language_model.embedding.position_embeddings.weight"] - - for key in model.keys(): - if prefix in key: - first = key[:len(prefix)] - second = key[len(prefix) + 1:] - model["model"]["language_model"]["encoder"][second] = model[key] - - transformer_model = model["model"]["language_model"]["encoder"] - - for key in transformer_model: - val = transformer_model[key].T.float().cpu().numpy().astype(np_weight_data_type) - if key.find("layers.") != -1: - layer_index = (int)(key[7: key.find(".", 7)]) - saved_key = key - # saved_key = key.replace( - # "layers.%d." % layer_index, - # "layers.%d." 
% (layer_index + pipeline_para_rank * model_args.num_layers // model_args.pipeline_model_parallel_size)) - - if saved_key.find("self_attention") != -1: - saved_key = saved_key.replace("self_attention", "attention") - else: - saved_key = key - - if ( - key.find("input_layernorm.weight") != -1 - or key.find("input_layernorm.bias") != -1 - or key.find("attention.dense.bias") != -1 - or key.find("post_attention_layernorm.weight") != -1 - or key.find("post_attention_layernorm.bias") != -1 - or key.find("mlp.dense_4h_to_h.bias") != -1 - or key.find("final_layernorm.weight") != -1 - or key.find("final_layernorm.bias") != -1 - ): - # shared weights, only need to convert the weights of rank 0 - if i == 0: - saved_path = saved_dir / f"model.{saved_key}.bin" - val.tofile(saved_path.as_posix()) - - elif key.find("attention.dense.weight") != -1 or key.find("mlp.dense_4h_to_h.weight") != -1: - split_vals = np.split(val, factor, axis=0) - for j in range(factor): - saved_path = saved_dir / f"model.{saved_key}.{i * factor + j:d}.bin" - split_vals[j].tofile(saved_path.as_posix()) - - elif key.find("mlp.dense_h_to_4h.weight") != -1 or key.find("mlp.dense_h_to_4h.bias") != -1: - split_vals = np.split(val, factor, axis=-1) - for j in range(factor): - saved_path = saved_dir / f"model.{saved_key}.{i * factor + j:d}.bin" - split_vals[j].tofile(saved_path.as_posix()) - - elif key.find("attention.query_key_value.bias") != -1: - local_dim = int(val.shape[-1] / 3) - - # ckpt_ver == 3 - num_splits = 3 - head_num = num_attention_heads // t_gpu_num - size_per_head = local_dim // head_num - val = val.reshape(head_num, num_splits, size_per_head) - val = val.transpose(1, 0, 2) - - val = val.reshape(3, local_dim) - split_vals = np.split(val, factor, axis=-1) - - for j in range(factor): - saved_path = saved_dir / f"model.{saved_key}.{i * factor + j:d}.bin" - split_vals[j].tofile(saved_path.as_posix()) - - elif key.find("attention.query_key_value.weight") != -1: - hidden_dim = val.shape[0] - local_dim = int(val.shape[-1] / 3) - - # ckpt_ver == 3: - num_splits = 3 - head_num = num_attention_heads - size_per_head = hidden_dim // head_num - head_num = head_num // t_gpu_num - val = val.reshape(hidden_dim, head_num, num_splits, size_per_head) - val = val.transpose(0, 2, 1, 3) - - val = val.reshape(hidden_dim, 3, local_dim) - split_vals = np.split(val, factor, axis=-1) - - for j in range(factor): - saved_path = saved_dir / f"model.{saved_key}.{i * factor + j:d}.bin" - split_vals[j].tofile(saved_path.as_posix()) - # print(split_vals[j].shape) + bos_id = tokenizer.bos_id + eos_id = tokenizer.eos_id - else: - print(f"[ERROR] cannot find key '{key}'") + LOGGER.debug("for %s obtained tokenizer tokens ids bos_id=%d eos_id=%d", tokenizer_config, bos_id, eos_id) - np.concatenate(w_e_list, axis=0).tofile((saved_dir / "model.wte.bin").as_posix()) - # print(torch.from_numpy(np.fromfile((saved_dir / "model.wte.bin").as_posix(), dtype=np.single)).size()) + return bos_id, eos_id -if __name__ == "__main__": +def main(): parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) - parser.add_argument("-saved_dir", "-o", type=str, help="folder name of output files", required=True) - parser.add_argument("-in_file", "-i", type=str, help="file name of .nemo checkpoint file", required=True) - parser.add_argument("-infer_gpu_num", "-i_g", type=int, help="How many gpus for inference", required=True) parser.add_argument( - "-fused_qkv", + "--saved-dir", + "-saved_dir", + "-o", + help="folder name of output files", + required=True, + ) 
+ parser.add_argument( + "--in-file", + "-in_file", + "-i", + help="file name of .nemo checkpoint file", + required=True, + ) + parser.add_argument( + "--prompt-in-file", + "-prompt_in_file", + "-p_i", + help="file name of .nemo prompt checkpoint file", + ) + parser.add_argument( + "--prompt-saved-dir", + "-prompt_saved_dir", + "-p_o", + help="folder name of prompt checkpoint output files", + ) + parser.add_argument( + "--infer-gpu-num", + "-infer_gpu_num", + "-i_g", + type=int, + help="How many gpus for inference", + required=True, + ) + parser.add_argument( + "--fused-qkv", "-fused_qkv", type=int, + choices=[0, 1], default=1, - help="Fuse the qkv weights or not. Default is true (1)", + help="Fuse the qkv weights or not", + ) + parser.add_argument( + "--processes", + "-processes", + "-p", + type=int, + default=16, + help="How many processes to spawn for conversion", + ) + parser.add_argument( + "--weight-data-type", + "-weight_data_type", + choices=["fp32", "fp16"], + default="fp32", + help="Data type of results weights", + ) + parser.add_argument( + "--load-checkpoints-to-cpu", + "-load_checkpoints_to_cpu", + "-cpu", + type=int, choices=[0, 1], + default=1, + help="Whether to load model weights to CPU", + ) + parser.add_argument( + "--vocab-path", + help="Path to vocabulary file to embed in FasterTransformer checkpoint", + required=False, + ) + parser.add_argument( + "--merges-path", + help="Path to merges file to embed in FasterTransformer checkpoint", + required=False, ) - parser.add_argument("-weight_data_type", type=str, default="fp32", choices=["fp32", "fp16"]) + parser.add_argument( + "--tokenizer-model-path", + help="Path to tokenizer model file to embed in FasterTransformer checkpoint", + required=False, + ) + parser.add_argument("--verbose", action="store_true", help="Provide verbose messages") args = parser.parse_args() + + log_format = "%(asctime)s %(name)s [%(levelname)s] %(message)s" + logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO, format=log_format) + print("\n=============== Argument ===============") for key in vars(args): print(f"{key}: {vars(args)[key]}") print("========================================") - model_config_yaml = "model_config.yaml" - model_weights_ckpt = "model_weights.ckpt" - config_yaml = os.path.join(args.saved_dir, model_config_yaml) - - unpack_nemo_ckpt(args.in_file, args.saved_dir) + input_path = pathlib.Path(args.in_file) + if not input_path.exists(): + LOGGER.error("%s does not exists", input_path) + sys.exit(1) + + with tempfile.TemporaryDirectory() as temp_dir: + temp_dir = pathlib.Path(temp_dir) + + # unpack if needed + if input_path.is_file(): + checkpoint_dir_path = temp_dir / "unpacked" + start_time = datetime.datetime.now() + unpacked_checkpoint_dir = UnpackedNemoCheckpointDir( + unpack_nemo_ckpt(args.in_file, checkpoint_dir_path), + load_checkpoints_to_cpu=bool(args.load_checkpoints_to_cpu), + ) + LOGGER.info("Spent %s (h:m:s) to unpack NeMo archive", datetime.datetime.now() - start_time) + else: + unpacked_checkpoint_dir = UnpackedNemoCheckpointDir( + input_path, load_checkpoints_to_cpu=bool(args.load_checkpoints_to_cpu) + ) - with open(config_yaml) as f: - model_config = yaml.full_load(f) + start_time = datetime.datetime.now() + convert_checkpoint(unpacked_checkpoint_dir, args) + LOGGER.info("Spent %s (h:m:s) to convert the model", datetime.datetime.now() - start_time) + + map_location_fn = cpu_map_location if bool(args.load_checkpoints_to_cpu) else gpu_map_location + # prompt checkpoint converting + if 
args.prompt_in_file is not None: + start_time = datetime.datetime.now() + assert args.prompt_saved_dir is not None + unpack_nemo_ckpt(args.prompt_in_file, args.prompt_saved_dir) + LOGGER.info("Spent %s (h:m:s) to unpack NeMo prompt archive", datetime.datetime.now() - start_time) + + model_config_yaml = "model_config.yaml" + model_weights_ckpt = "model_weights.ckpt" + prompt_config_file = open(os.path.join(args.prompt_saved_dir, model_config_yaml), "r") + prompt_config = yaml.full_load(prompt_config_file) + LOGGER.info(prompt_config) + + start_time = datetime.datetime.now() + prompt_weights = torch.load( + os.path.join(args.prompt_saved_dir, model_weights_ckpt), + map_location=map_location_fn, + ) + prompt_convert(args, prompt_config, prompt_weights) + LOGGER.info(f"Spent %s (h:m:s) to unpack convert prompt model", datetime.datetime.now() - start_time) - t_gpu_num = model_config["tensor_model_parallel_size"] - if t_gpu_num == 1: - model_weights = [os.path.join(args.saved_dir, model_weights_ckpt)] - else: - model_weights = [os.path.join(args.saved_dir, f"mp_rank_{i:02d}", model_weights_ckpt) for i in range(t_gpu_num)] - print(model_config) - if t_gpu_num > args.infer_gpu_num: - merge_and_convert(args, model_config, model_weights) - else: - split_and_convert(args, model_config, model_weights) +if __name__ == "__main__": + main() diff --git a/examples/pytorch/gpt/utils/parallel_gpt.py b/examples/pytorch/gpt/utils/parallel_gpt.py index 85c117228..68b358a1d 100644 --- a/examples/pytorch/gpt/utils/parallel_gpt.py +++ b/examples/pytorch/gpt/utils/parallel_gpt.py @@ -15,13 +15,23 @@ from __future__ import print_function import torch +import numpy as np from examples.pytorch.gpt.utils.gpt import GPT class ParallelGPT(GPT): def __init__(self, head_num, size_per_head, vocab_size, start_id, end_id, layer_num, max_seq_len, - tensor_para_size, pipeline_para_size, lib_path, int8_mode): + tensor_para_size, pipeline_para_size, lib_path, + layernorm_eps = 1e-6, layernorm_type = "pre_layernorm", # gpt_variant_params + activation_type = "Gelu", has_post_decoder_layernorm = True, # gpt variant params + has_adapters = False, adapter_inter_size = 0, # gpt variant params + int8_mode = 0, + weights_data_type: np.dtype = np.float32, + shared_contexts_ratio=1.0): super().__init__(head_num, size_per_head, vocab_size, start_id, end_id, layer_num, max_seq_len, - tensor_para_size, pipeline_para_size, lib_path, int8_mode) + tensor_para_size, pipeline_para_size, lib_path, + layernorm_eps, layernorm_type, activation_type, has_post_decoder_layernorm, + has_adapters, adapter_inter_size, + int8_mode, weights_data_type, shared_contexts_ratio) def cuda(self): self.weights._map(lambda w: w.cuda(self.device)) @@ -34,5 +44,7 @@ def cuda(self): self.model = torch.classes.FasterTransformer.ParallelGptOp(self.head_num, self.size_per_head, 4 * self.head_num * self.size_per_head, self.layer_num, self.vocab_size, self.start_id, self.end_id, self.tensor_para_size, self.pipeline_para_size, self.int8_mode, - self.weights.w, self.weights.int8_w, self.weights.scale) + self.layernorm_eps, self.layernorm_type, self.activation_type, self.has_post_decoder_layernorm, + self.has_adapters, self.adapter_inter_size, # gpt_variant_params + self.weights.w, self.weights.int8_w, self.weights.scale, self.shared_contexts_ratio) self.build_model = True diff --git a/examples/pytorch/gpt/utils/update_gpt_config.py b/examples/pytorch/gpt/utils/update_gpt_config.py new file mode 100644 index 000000000..cc72a3a72 --- /dev/null +++ 
b/examples/pytorch/gpt/utils/update_gpt_config.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python3 +# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import configparser +import pathlib + + +def main(): + parser = argparse.ArgumentParser( + description="Script updating GPT config.ini hyper-parameters and requests parameters" + ) + + # config.ini path + parser.add_argument("--config-ini-path", required=True, help="Path to config.ini file to be updated") + + # FT hyperparameters + parser.add_argument("--model-dir", type=str, required=True, help="Model path prefix") + parser.add_argument("--tensor-para-size", type=int, required=True, help="tensor parallelism size") + parser.add_argument("--pipeline-para-size", type=int, required=True, help="layer parallelism size") + parser.add_argument("--max-batch-size", type=int, default=8, help="batch size") + parser.add_argument("--max-seq-len", type=int, default=256, help="max sequence length") + parser.add_argument("--beam-width", type=int, default=1, help="beam width for beam search") + parser.add_argument("--data-type", type=str, default="fp32", help="data type", choices=["fp32", "fp16", "bf16"]) + parser.add_argument( + "--sampling-top-k", + type=int, + default=1, + help="Candidate (k) value of top k sampling in decoding", + ) + parser.add_argument( + "--sampling-top-p", + type=float, + default=0.0, + help="Probability (p) value of top p sampling in decoding", + ) + parser.add_argument("--temperature", type=float, default=1.0, help="temperature of penalty") + parser.add_argument("--repetition-penalty", type=float, default=1.0, help="repetition_penalty") + parser.add_argument("--len-penalty", type=float, default=0.0, help="len_penalty") + parser.add_argument("--beam-search-diversity-rate", type=float, default=0.0, help="beam_search_diversity_rate") + + # request + parser.add_argument("--request-batch-size", type=int, default=8, help="batch size") + parser.add_argument("--request-output-len", type=int, default=32, help="output length") + parser.add_argument("--model-name", type=str, default="gpt", help="model-name for testing") + + args = parser.parse_args() + + config_path = pathlib.Path(args.config_ini_path) + + config = configparser.ConfigParser() + config.read(config_path) + + config["ft_instance_hyperparameter"] = { + "max_batch_size": args.max_batch_size, + "max_seq_len": args.max_seq_len, + "beam_width": args.beam_width, + "top_k": args.sampling_top_k, + "top_p": args.sampling_top_p, + "temperature": args.temperature, + "tensor_para_size": args.tensor_para_size, + "pipeline_para_size": args.pipeline_para_size, + "data_type": args.data_type, + "sparse": 0, + "int8_mode": 0, + "enable_custom_all_reduce": 0, + "model_name": args.model_name, + "model_dir": args.model_dir, + "repetition_penalty": args.repetition_penalty, + "len_penalty": args.len_penalty, + "beam_search_diversity_rate": args.beam_search_diversity_rate, + } + + config["request"] = { + "request_batch_size": 
args.request_batch_size, + "request_output_len": args.request_output_len, + "return_log_probs": "false", + "context_log_probs": "false", + "beam_width": args.beam_width, + "top_k": args.sampling_top_k, + "top_p": args.sampling_top_p, + "temperature": args.temperature, + "repetition_penalty": args.repetition_penalty, + "len_penalty": args.len_penalty, + "beam_search_diversity_rate": args.beam_search_diversity_rate, + } + + with config_path.open("w") as config_file: + config.write(config_file) + + +if __name__ == "__main__": + main() diff --git a/examples/pytorch/gptj/utils/generate_gptj_config.py b/examples/pytorch/gptj/utils/generate_gptj_config.py new file mode 100644 index 000000000..ed203b1b7 --- /dev/null +++ b/examples/pytorch/gptj/utils/generate_gptj_config.py @@ -0,0 +1,112 @@ +# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import configparser + +def generate_gpt_config(args): + config = configparser.ConfigParser() + config["ft_instance_hyperparameter"] = { + "max_batch_size": "{}".format(args['max_batch_size']), + "max_seq_len": "{}".format(args['max_seq_len']), + "beam_width": "{}".format(args['beam_width']), + "top_k": "{}".format(args['sampling_topk']), + "top_p": "{}".format(args['sampling_topp']), + "temperature": "{}".format(args['temperature']), + "tensor_para_size": "{}".format(args['tensor_para_size']), + "pipeline_para_size": "{}".format(args['pipeline_para_size']), + "data_type": "{}".format(args['data_type']), + "sparse": "0", + "int8_mode": "0", + "enable_custom_all_reduce": "0", + "model_name": "tmp_model", + "model_dir": "{}".format(args['model_dir']), + "repetition_penalty": "{}".format(args['repetition_penalty']), + "len_penalty": "{}".format(args['len_penalty']), + "beam_search_diversity_rate": "{}".format(args['beam_search_diversity_rate']), + } + + config["request"] = { + "request_batch_size": "{}".format(args['request_batch_size']), + "request_output_len": "{}".format(args['request_output_len']), + "return_log_probs": "false", + "context_log_probs": "false", + } + + config["tmp_model"] = { + "head_num": "{}".format(args['head_number']), + "size_per_head": "{}".format(args['size_per_head']), + "inter_size": "{}".format(args['inter_size']), + "vocab_size": "{}".format(args['vocab_size']), + "decoder_layers": "{}".format(args['num_layer']), + "rotary_embedding": f"{args['rotary_embedding']}", + "start_id": "{}".format(args['start_id']), + "end_id": "{}".format(args['end_id']), + } + + with open('.tmp.config.ini', 'w') as configfile: + config.write(configfile) + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument('-max_batch_size', '--max_batch_size', type=int, default=8, metavar='NUMBER', + help='batch size (default: 8)') + parser.add_argument('-max_seq_len', '--max_seq_len', type=int, default=256, metavar='NUMBER', + help='max sequence length (default: 256)') + parser.add_argument('-beam_width', '--beam_width', type=int, default=1, 
metavar='NUMBER', + help='beam width for beam search (default: 1)') + parser.add_argument('-n', '--head_number', type=int, default=16, metavar='NUMBER', + help='head number (default: 16)') + parser.add_argument('-size', '--size_per_head', type=int, default=256, metavar='NUMBER', + help='size per head (default: 256)') + parser.add_argument('-inter_size', '--inter_size', type=int, default=16384, metavar='NUMBER', + help='inter size for ffn (default: 16384)') + parser.add_argument('-l', '--num_layer', type=int, default=28, metavar='NUMBER', + help='number of layers (default: 28)') + parser.add_argument('-v', '--vocab_size', type=int, default=50400, metavar='NUMBER', + help='vocabulary size (default: 50400)') + parser.add_argument('-d', '--data_type', type=str, default="bf16", metavar='STRING', + help='data type (default: bf16)', choices=['fp32', 'fp16', 'bf16']) + parser.add_argument('-topk', '--sampling_topk', type=int, default=0, metavar='NUMBER', + help='Candidate (k) value of top k sampling in decoding. Default is 0.') + parser.add_argument('-topp', '--sampling_topp', type=float, default=0.5, metavar='NUMBER', + help='Probability (p) value of top p sampling in decoding. Default is 0.5.') + parser.add_argument('-tensor_para_size', '--tensor_para_size', type=int, default=1, metavar='NUMBER', + help='tensor parallelism size. Default is 1.') + parser.add_argument('-pipeline_para_size', '--pipeline_para_size', type=int, default=1, metavar='NUMBER', + help='layer parallelism size. Default is 1.') + parser.add_argument('--model_dir', type=str, default="./models/", metavar='STRING', + help='Model path prefix. Default is "./models".') + parser.add_argument('-temperature', '--temperature', type=float, default=1.0, metavar='NUMBER', + help='temperature of penalty. 
Default is 1.0.') + parser.add_argument('-request_batch_size', '--request_batch_size', type=int, default=8, metavar='NUMBER', + help='batch size (default: 8)') + parser.add_argument('-request_output_len', '--request_output_len', type=int, default=32, metavar='NUMBER', + help='output length (default: 32)') + parser.add_argument('-start_id', '--start_id', type=int, default=50256, metavar='NUMBER', + help='start id (default: 50256)') + parser.add_argument('-end_id', '--end_id', type=int, default=50256, metavar='NUMBER', + help='end id (default: 50256)') + parser.add_argument('-repetition_penalty', '--repetition_penalty', type=float, default=1.0, metavar='NUMBER', + help='repetition_penalty (default: 1.0)') + parser.add_argument('-len_penalty', '--len_penalty', type=float, default=0.0, metavar='NUMBER', + help='len_penalty (default: 0.0)') + parser.add_argument('-beam_search_diversity_rate', '--beam_search_diversity_rate', type=float, default=0.0, metavar='NUMBER', + help='beam_search_diversity_rate (default: 0.0)') + parser.add_argument('-rotary_embedding', '--rotary_embedding', type=int, default=64, metavar='NUMBER', + help='rotary embedding dimension (default: 64)') + + args = parser.parse_args() + generate_gpt_config(vars(args)) diff --git a/examples/pytorch/gptj/utils/gptj_ckpt_convert.py b/examples/pytorch/gptj/utils/gptj_ckpt_convert.py index 81b9b3509..cfa96d8d7 100644 --- a/examples/pytorch/gptj/utils/gptj_ckpt_convert.py +++ b/examples/pytorch/gptj/utils/gptj_ckpt_convert.py @@ -5,6 +5,7 @@ import numpy as np import torch +import configparser torch.set_printoptions(linewidth=130, sci_mode=False) np.set_printoptions(linewidth=130, suppress=True) @@ -210,7 +211,7 @@ def main(ckpt_dir, num_layers=28, total_shards=8): while len(transforms) > 0: print(f"loading shards for part {part}") shards = [ - read_shard(f"{ckpt_dir}shard_{i}/", part) for i in range(total_shards) + read_shard(f"{ckpt_dir}/shard_{i}/", part) for i in range(total_shards) ] print(f"read from checkpoint") @@ -278,10 +279,30 @@ def main(ckpt_dir, num_layers=28, total_shards=8): print("saving") # load as in: https://github.com/finetuneanon/misc/blob/main/SizeTest.ipynb out_path = args.output_dir + output_dir = out_path + f"/{args.n_inference_gpus}-gpu/" + if len(out_path)>3 and out_path[-3:] == ".pt": torch.save(checkpoint, out_path) else: - output_dir = out_path + f"/{args.n_inference_gpus}-gpu/" save(checkpoint, output_dir, args.n_inference_gpus, num_layers) + # NOTE: hard-coded for the GPT-J 6B configuration (TODO: make this automatic) + config = configparser.ConfigParser() + config["gptj"] = {} + try: + config["gptj"]["model_name"] = "gptj-6B" + config["gptj"]["head_num"] = "16" + config["gptj"]["size_per_head"] = "256" + config["gptj"]["inter_size"] = "16384" + config["gptj"]["num_layer"] = "28" + config["gptj"]["rotary_embedding"] = "64" + config["gptj"]["vocab_size"] = "50400" + config["gptj"]["start_id"] = "50256" + config["gptj"]["end_id"] = "50256" + config["gptj"]["weight_data_type"] = "fp32" + with open(output_dir + "/config.ini", 'w') as configfile: + config.write(configfile) + except: + print("Failed to save the config in config.ini.") + print("done") diff --git a/examples/pytorch/gptj/utils/huggingface_gptj_ckpt_convert.py b/examples/pytorch/gptj/utils/huggingface_gptj_ckpt_convert.py new file mode 100644 index 000000000..93f8917d2 --- /dev/null +++ b/examples/pytorch/gptj/utils/huggingface_gptj_ckpt_convert.py @@ -0,0 +1,152 @@ +from argparse import ArgumentParser +from os import makedirs +import numpy as np +from pathlib import 
Path + +import torch +import configparser +from transformers import PretrainedConfig + +torch.set_printoptions(linewidth=130, sci_mode=False) +np.set_printoptions(linewidth=130, suppress=True) + +# This converter is used to convert the huggingface gpt-j-6B model +# in https://huggingface.co/EleutherAI/gpt-j-6B/blob/main/pytorch_model.bin. + +def savebin(param, save_path): + if isinstance(param, torch.Tensor): + param = param.cpu().float().numpy() + np.squeeze(param).astype(np.float32).tofile(save_path + ".bin") + +def param2file(pt_param, layer_id, save_dir, dest_key): + base_n = save_dir + "/model.layers." + str(layer_id) + "." + save_path = base_n + dest_key + savebin(pt_param, save_path) + +def param2distributed( + pt_param, + layer_id, + save_dir, + dest_key, + n_inference_gpus, + split_axis, +): + np_param = pt_param.cpu().float().numpy() + base_n = save_dir + "/model.layers." + str(layer_id) + "." + save_path = base_n + dest_key + split_param = np.split(np_param, n_inference_gpus, axis=split_axis) + for i, p in enumerate(split_param): + savebin(p, save_path + f".{i}") + + +def save(w, save_dir, n_inference_gpus, layer_id): + makedirs(save_dir, exist_ok=True) + + savebin(w['transformer.wte.weight'], save_dir + "/model.wte") + l = layer_id + print(f"Saving layer {l} / 28") + base_k = "transformer.h." + str(l) + "." + param2file( + w[base_k + "ln_1.bias"], + l, save_dir, "input_layernorm.bias" + ) + param2file( + w[base_k + "ln_1.weight"], + l, save_dir, "input_layernorm.weight" + ) + param2distributed( + w[base_k + "mlp.fc_in.weight"].T, + l, save_dir, "mlp.dense_h_to_4h.weight", + n_inference_gpus, split_axis=-1 # split fast indx + ) + param2distributed( + w[base_k + "mlp.fc_in.bias"], + l, save_dir, "mlp.dense_h_to_4h.bias", + n_inference_gpus, split_axis=-1 # split fast indx + ) + + param2distributed( + w[base_k + "mlp.fc_out.weight"].T, + l, save_dir, "mlp.dense_4h_to_h.weight", + n_inference_gpus, split_axis=0 # split slow indx + ) + param2file( + w[base_k + "mlp.fc_out.bias"], + l, save_dir, "mlp.dense_4h_to_h.bias" + ) + param2distributed( + w[base_k + "attn.out_proj.weight"].T, + l, save_dir, "attention.dense.weight", + n_inference_gpus, split_axis=0 # split slow indx + ) + QKV_w = torch.stack([ + w[base_k + "attn.q_proj.weight"], + w[base_k + "attn.k_proj.weight"], + w[base_k + "attn.v_proj.weight"], + ]) # [qkv, n_heads * dim_head, latent_space] + QKV_w = QKV_w.permute(2, 0, 1) + param2distributed( + QKV_w, l, save_dir, "attention.query_key_value.weight", + n_inference_gpus, split_axis=-1 # split fast indx + ) + # Other unneeded per-layer params: + # attn.attention.masked_bias = torch.tensor(-1e9) + # attn.attention.bias = torch.tril(torch.ones(1, 1, 2048, 2048)) + +if __name__ == "__main__": + parser = ArgumentParser( + description="Convert GPT-J slim checkpoint to FasterTransformer", + ) + parser.add_argument( + "--output-dir", help="Folder where binary files are stored", default="gpt-j-6B/c-models/" + ) + parser.add_argument( + "--ckpt-dir", help="File of GPT-J huggingface checkpoint", default="gpt-j-6B/" + ) + parser.add_argument( + "--n-inference-gpus", help="Number of GPUs used for inference runtime", default=1, type=int + ) + args = parser.parse_args() + + NUM_LAYERS = 28 + + ckpt_file = args.ckpt_dir + "pytorch_model.bin" + checkpoint = torch.load(ckpt_file) + print(f"loading from {ckpt_file}") + + out_path = args.output_dir + output_dir = out_path + f"/{args.n_inference_gpus}-gpu/" + print(f"saving to {output_dir}") + + config_file = args.ckpt_dir + "config.json" + 
hf_config = PretrainedConfig.from_json_file(config_file).to_dict() + + # NOTE: save parameters to config files (loaded by triton backends) + config = configparser.ConfigParser() + config["gptj"] = {} + try: + config["gptj"]["model_name"] = "gptj" if hf_config["_name_or_path"] == '' else hf_config["_name_or_path"] + config["gptj"]["head_num"] = str(hf_config["n_head"]) + n_embd = hf_config["n_embd"] + config["gptj"]["size_per_head"] = str(n_embd // hf_config["n_head"]) + config["gptj"]["inter_size"] = str(n_embd * 4) + config["gptj"]["num_layer"] = str(hf_config["n_layer"]) + rotary_dim = n_embd // hf_config["n_head"] if hf_config["rotary_dim"] is None else hf_config["rotary_dim"] + config["gptj"]["rotary_embedding_dim"] = str(hf_config["rotary_dim"]) + config["gptj"]["vocab_size"] = str(hf_config["vocab_size"]) + config["gptj"]["start_id"] = str(hf_config["bos_token_id"]) + config["gptj"]["end_id"] = str(hf_config["eos_token_id"]) + config["gptj"]["weight_data_type"] = "fp32" + with open(output_dir + "/config.ini", 'w') as configfile: + config.write(configfile) + except: + print(f"Fail to save the config in config.ini.") + + for i in range(NUM_LAYERS): + save(checkpoint, output_dir, args.n_inference_gpus,i) + savebin(checkpoint['transformer.ln_f.weight'], output_dir + "/model.final_layernorm.weight") + savebin(checkpoint['transformer.ln_f.bias'], output_dir + "/model.final_layernorm.bias") + savebin(checkpoint['lm_head.weight'], output_dir + "/model.lm_head.weight") + savebin(checkpoint['lm_head.bias'], output_dir + "/model.lm_head.bias") + + print("done") diff --git a/examples/pytorch/gptj/utils/reference_gptj.py b/examples/pytorch/gptj/utils/reference_gptj.py index 41fbe7a42..77f1d38e5 100644 --- a/examples/pytorch/gptj/utils/reference_gptj.py +++ b/examples/pytorch/gptj/utils/reference_gptj.py @@ -1,5 +1,5 @@ import torch -from transformers import GPTNeoForCausalLM, AutoConfig +from transformers import GPTNeoForCausalLM, AutoConfig, GPT2Tokenizer from pathlib import Path # GPT-J 6B config @@ -49,6 +49,8 @@ def copy(self): state_dict=Checkpoint("j6b_ckpt.pt") ) +tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + input_ids = torch.as_tensor([ [ 818, 198, 464, 464, 818, 464, 198, 464], [ 262, 464, 968, 968, 257, 968, 198, 717], @@ -59,5 +61,9 @@ def copy(self): [ 11, 15198, 649, 663, 787, 41683, 628, 3807], [ 257, 318, 1182, 5079, 340, 423, 198, 11], ]).T.cuda() -o = model(input_ids) +output = model.generate(input_ids, max_length=40, k=1) + +# print(f"output ids: \n{output}") +for i in range(8): + print(f"[INFO] batch {i}: {tokenizer.decode(output[i][:])}") diff --git a/examples/pytorch/gptneox/utils/eleutherai_gpt_neox_convert.py b/examples/pytorch/gptneox/utils/eleutherai_gpt_neox_convert.py new file mode 100755 index 000000000..9e8953789 --- /dev/null +++ b/examples/pytorch/gptneox/utils/eleutherai_gpt_neox_convert.py @@ -0,0 +1,309 @@ +#! /usr/bin/env python3 +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
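+# Example invocation (illustrative; the checkpoint and output paths are placeholders):
+#
+#   python3 eleutherai_gpt_neox_convert.py 20B_checkpoints ../models/gptneox -t 2 -j 8
+#
+# The two positional arguments are the source checkpoint directory (it must contain
+# a "latest" file and configs/20B.yml, which convert_checkpoint reads) and the
+# FasterTransformer output directory; -t sets the inference tensor parallelism and
+# -j the number of conversion worker processes.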
+ + +import argparse +import multiprocessing +import numpy as np +import torch # pytype: disable=import-error +import yaml + +from dataclasses import dataclass, field +from datetime import datetime +from pathlib import Path +from tqdm import tqdm +from typing import List + +''' +GPT-NeoX 20B model +Download by wget --cut-dirs=5 -nH -r --no-parent --reject "index.html*" https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/ -P 20B_checkpoints + +layer_00-model_00-model_states.pt + word_embeddings.weight: embedding table, split by tensor parallelism +layer_02-model_00-model_states.pt ~ layer_45-model_01-model_states.pt: + input_layernorm.weight + input_layernorm.bias + attention.query_key_value.weight + attention.query_key_value.bias + attention.rotary_emb.inv_freq + attention.dense.weight + attention.dense.bias + post_attention_layernorm.weight + post_attention_layernorm.bias + mlp.dense_h_to_4h.weight + mlp.dense_h_to_4h.bias + mlp.dense_4h_to_h.weight + mlp.dense_4h_to_h.bias + +layer_47-model_00-model_states.pt: + finally layernorm. model_00 and model_01 have same weights. Using one of them is enough. +layer_48-model_00-model_states.pt + final_linear.weight. It should be the logit gemm weight. + +mp_rank_xx_model_states.pt: + some training states, useless in inference +''' + +weights_skip_tensor_split = ["input_layernorm.bias", + "input_layernorm.weight", + "attention.dense.bias", + "mlp.dense_4h_to_h.bias", + "post_attention_layernorm.bias", + "post_attention_layernorm.weight"] + +def write_config_file(save_dir): + file_template = """ +[gptneox] +model_name=gptneox_20B +head_num=64 +size_per_head=96 +vocab_size=50432 +num_layer=44 +rotary_embedding=24 +start_id=0 +end_id=2 +inter_size=24576 +use_gptj_residual=1 +weight_data_type=fp32 + """ + + with open(Path(save_dir) / "config.ini", "w") as f: + f.write(file_template) + +@dataclass +class KeyHandler: + outname: str + gather: str = "" + scatter: str = "copy" + reshape: List = field(default_factory=lambda: []) + transpose: List = field(default_factory=lambda: []) + + +def on_cpu(storage, loc): + return storage.cpu() + + +def handle_layer(chkpt_dir, in_filename, key_mapping, save_dir, + in_range, out_range, whole_range=None): + + def read_layers(filename, range): + if range is not None: + filename = [filename.format(i) for i in range] + else: + filename = [filename] + return [torch.load(chkpt_dir / fn, map_location=on_cpu) for fn in filename] + + layers = read_layers(in_filename, in_range) + layer_keys = set(layers[0].keys()) + + for key, value in key_mapping.items(): + key_templ, gather, scatter = value.outname, value.gather, value.scatter + reshape, transpose = value.reshape, value.transpose + layer_keys.remove(key) + if key_templ == "": + continue + + # Preprocess tensors + tensors = [np.array(layer[key], dtype=np.float32) for layer in layers] + if reshape: + tensors = [ten.reshape(reshape) for ten in tensors] + if transpose: + tensors = [ten.transpose(transpose) for ten in tensors] + + # Gather tensors + if len(tensors) == 1: + gather_tensor = tensors[0] + else: + if "join" in gather: + axis = int(gather.partition("_")[2]) + gather_tensor = np.concatenate(tensors, axis=axis) + elif gather == "mean": + gather_tensor = np.sum(tensors, axis=0) / len(tensors) + elif gather == "sum": + gather_tensor = np.sum(tensors, axis=0) + else: + raise NotImplementedError(f"Gather strategy {gather} is not supported") + + # Scatter tensors + if len(out_range) == 1: + scatter_tensors = [gather_tensor] + else: + if scatter == "copy": + 
scatter_tensors = [gather_tensor for i in out_range] + elif "split" in scatter: + axis = int(scatter.partition("_")[2]) + if gather_tensor.shape[axis] % out_range != 0: + raise ValueError(f"{key} cannot be divided in {len(out_range)} along axis {axis}") + + scatter_tensors = np.split(gather_tensor, len(out_range), axis=axis) + elif scatter == "divide": + scatter_tensors = [gather_tensor / len(out_range) for i in out_range] + else: + raise NotImplementedError(f"Scatter strategy {scatter} is not supported") + + for tensor, idx in zip(scatter_tensors, out_range): + output_name = key_templ.format(out_range[0]) + for weight_name in weights_skip_tensor_split: + if weight_name in output_name: + output_name = output_name.split('.') + del output_name[-1] + output_name = '.'.join(output_name) + tensor.tofile(save_dir / ("model." + output_name + ".bin")) + + if len(layer_keys) > 0: + print("[Warning] Remaining keys:", layer_keys) + + +def convert_checkpoint(args): + base_dir = Path(args.checkpoint_dir) + + with open(base_dir / "latest") as f: + chkpt_dir = f.readline().rstrip() + chkpt_dir = base_dir / chkpt_dir + + with open(base_dir / "configs/20B.yml") as f: + model_args = yaml.safe_load(f) + + hidden_dim = model_args["hidden-size"] + n_layers = model_args["num-layers"] + n_heads = model_args["num-attention-heads"] + hidden_per_head = hidden_dim // n_heads + + tp_source = model_args["model-parallel-size"] + tp_target = args.tensor_parallelism + print(f"Converting from {tp_source} to {tp_target} GPUs") + + save_dir = Path(args.save_dir) / f"{tp_target:d}-gpu" + save_dir.mkdir(parents=True, exist_ok=True) + + handle_layer_args = [] + handle_layer_args.append(( + chkpt_dir, + "layer_00-model_{:02d}-model_states.pt", + {"word_embeddings.weight": KeyHandler("wte", "join_0")}, + save_dir, + range(tp_source), + [0], + )) + handle_layer_args.append(( + chkpt_dir, + "layer_47-model_{:02d}-model_states.pt", + { + "norm.weight": KeyHandler("final_layernorm.weight", "mean"), + "norm.bias": KeyHandler("final_layernorm.bias", "mean"), + }, + save_dir, + range(tp_source), + [0], + )) + handle_layer_args.append(( + chkpt_dir, + "layer_48-model_{:02d}-model_states.pt", + { + "final_linear.weight": KeyHandler("lm_head.weight", "join_0"), + }, + save_dir, + range(tp_source), + [0], + )) + + gcd = np.gcd(tp_source, tp_target) + print(f"Strategy: group {tp_source//gcd} source gpu(s) into {tp_target//gcd} out gpu(s).\n") + + in_indices = np.split(np.arange(tp_source), gcd) + out_indices = np.split(np.arange(tp_target), gcd) + + for layer_id in range(model_args["num-layers"]): + for in_idx, out_idx in zip(in_indices, out_indices): + def make_fn_out(fn): + return f"layers.{layer_id}." 
+ fn + ".{:d}" + + handle_layer_args.append(( + chkpt_dir, + f"layer_{layer_id+2:02d}" + "-model_{:02d}-model_states.pt", + { + "attention.rotary_emb.inv_freq": KeyHandler(""), + "attention.dense.weight": KeyHandler( + make_fn_out("attention.dense.weight"), + "join_0", "split_0", + transpose=[1, 0]), + "attention.dense.bias": KeyHandler( + make_fn_out("attention.dense.bias"), "sum", "divide"), + "attention.query_key_value.weight": KeyHandler( + make_fn_out("attention.query_key_value.weight"), + "join_2", "split_2", + reshape=[n_heads // tp_source, 3, hidden_per_head, hidden_dim], + transpose=[3, 1, 0, 2]), + "attention.query_key_value.bias": KeyHandler( + make_fn_out("attention.query_key_value.bias"), + "join_1", "split_1", + reshape=[n_heads // tp_source, 3, hidden_per_head], + transpose=[1, 0, 2]), + "input_layernorm.weight": KeyHandler( + make_fn_out("input_layernorm.weight"), "mean"), + "input_layernorm.bias": KeyHandler( + make_fn_out("input_layernorm.bias"), "mean"), + "mlp.dense_4h_to_h.weight": KeyHandler( + make_fn_out("mlp.dense_4h_to_h.weight"), + "join_0", "split_0", + transpose=[1, 0]), + "mlp.dense_4h_to_h.bias": KeyHandler( + make_fn_out("mlp.dense_4h_to_h.bias"), "sum", "divide"), + "mlp.dense_h_to_4h.weight": KeyHandler( + make_fn_out("mlp.dense_h_to_4h.weight"), + "join_1", "split_1", + transpose=[1, 0]), + "mlp.dense_h_to_4h.bias": KeyHandler( + make_fn_out("mlp.dense_h_to_4h.bias"), "join_0", "split_0"), + "post_attention_layernorm.weight": KeyHandler( + make_fn_out("post_attention_layernorm.weight"), "mean"), + "post_attention_layernorm.bias": KeyHandler( + make_fn_out("post_attention_layernorm.bias"), "mean"), + }, + save_dir, + in_idx, + out_idx, + )) + + torch.multiprocessing.set_start_method("spawn") + with multiprocessing.Pool(args.jobs) as pool: + pool.starmap(handle_layer, handle_layer_args) + + # Post-process biases and lm_head (TODO: remove this) + for layer_idx in range(model_args["num-layers"]): + attn_bias = np.fromfile(save_dir / f"model.layers.{layer_idx}.attention.dense.bias.bin", dtype=np.float32) + mlp_bias = np.fromfile(save_dir / f"model.layers.{layer_idx}.mlp.dense_4h_to_h.bias.bin", dtype=np.float32) + + (attn_bias + mlp_bias).tofile(save_dir / f"model.layers.{layer_idx}.mlp.attention.bias.sum.bin") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("checkpoint_dir", metavar="checkpoint-dir", + help="directory where resides the source model. Must contain a \"latest\" file.") + parser.add_argument("save_dir", metavar="save-dir", + help="where to store the FT model") + parser.add_argument("--tensor-parallelism", "-t", type=int, default=1, + help="level of tensor parallelism used for inference") + parser.add_argument("--jobs", "-j", type=int, default=None, + help="how many processes to spawn for conversion (default: cpu_count)") + args = parser.parse_args() + + start_time = datetime.now() + convert_checkpoint(args) + write_config_file(args.save_dir + f"/{args.tensor_parallelism}-gpu") + stop_time = datetime.now() + run_time = (stop_time - start_time) + print("[INFO] Spend {} (h:m:s) to convert the model".format(run_time)) diff --git a/examples/pytorch/gptneox/utils/hftokenizer.py b/examples/pytorch/gptneox/utils/hftokenizer.py new file mode 100755 index 000000000..08a824de5 --- /dev/null +++ b/examples/pytorch/gptneox/utils/hftokenizer.py @@ -0,0 +1,92 @@ +#! /usr/bin/env python3 + +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from argparse import ArgumentParser +from tokenizers import Tokenizer +from typing import List, Union + + +class HFTokenizer: + def __init__(self, vocab_file): + self.tokenizer = Tokenizer.from_file(vocab_file) + + def tokenize(self, text: str): + return self.tokenizer.encode(text).ids + + def tokenize_batch(self, text_batch: Union[List[str], str]): + return self.tokenizer.encode_batch(text_batch) + + def detokenize(self, token_ids): + return self.tokenizer.decode(token_ids) + + +def handle_args(): + parser = ArgumentParser() + parser.add_argument("in_file") + parser.add_argument("--out-file") + parser.add_argument("--tokenizer", default="../models/20B_tokenizer.json") + parser.add_argument("--action", choices=["tokenize", "detokenize", "auto"], default="auto") + + return parser.parse_args() + + +def main(in_file, tokenizer, out_file, action): + tokenizer = HFTokenizer(tokenizer) + + with open(in_file) as f: + lines = f.read().split('\n') + + in_lines = None + do = None + if action != "tokenize": + if in_lines is None: + try: + in_lines = [[int(tok) for tok in line.split(' ') if tok] for line in lines if line] + do = "detokenize" + except ValueError: + pass + if in_lines is None: + try: + in_lines = [[int(tok) for tok in line.split(', ') if tok] for line in lines if line] + do = "detokenize" + except ValueError: + pass + + if action != "detokenize": + if in_lines is None: + try: + in_lines = [line for line in lines if line] + do = "tokenize" + except ValueError: + pass + + if do is not None: + if do == "detokenize": + output = [tokenizer.detokenize(token_list) for token_list in in_lines] + else: + output = [tokenizer.tokenize(line) for line in in_lines] + output = [",".join(str(tok) for tok in tok_seq) for tok_seq in output] + + if args.out_file: + with open(out_file, "w") as f: + f.write("\n".join(output)) + else: + print("\n---\n".join(output)) + + +if __name__ == "__main__": + args = handle_args() + main(args.in_file, args.tokenizer, args.out_file, args.action) diff --git a/examples/pytorch/gptneox/utils/huggingface_jp_gptneox_convert.py b/examples/pytorch/gptneox/utils/huggingface_jp_gptneox_convert.py new file mode 100644 index 000000000..3411c0421 --- /dev/null +++ b/examples/pytorch/gptneox/utils/huggingface_jp_gptneox_convert.py @@ -0,0 +1,262 @@ +# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
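+# Example invocation (illustrative; the paths and model name are placeholders):
+#
+#   python3 huggingface_jp_gptneox_convert.py \
+#       -i ./japanese-gpt-neox-hf -o ./models/gptneox -i_g 1 \
+#       -m_n gptneox-jp -weight_data_type fp16
+#
+# -i must be a checkpoint directory loadable with GPTNeoXForCausalLM.from_pretrained,
+# -i_g is the number of GPUs used at inference time, and -p_i_list can optionally
+# list prefix-prompt weight files to convert alongside the model.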
+ +import argparse +import configparser +import multiprocessing +import numpy as np +from pathlib import Path +import torch + +import os +import sys +from transformers import GPTNeoXForCausalLM # 4.21.1 + +def get_weight_data_type(data_type): + if data_type == "fp32": + return np.float32 + elif data_type == "fp16": + return np.float16 + else: + assert False, f"Invalid weight data type {data_type}" + +def prefix_prompt_convert(args, config, weight_data_type): + + saved_dir = args.saved_dir + "/%d-gpu/" % args.infer_gpu_num + + prompt_in_file_list = args.prompt_in_file_list.split(',') + + task_list = [] + for idx, prompt_in_file in enumerate(prompt_in_file_list): + weights=torch.load(prompt_in_file) + task_name = prompt_in_file.split("/")[-1].split(".")[-3] + + total_size = weights.nelement() + n_layers = config['n_layer'] + n_head = config['n_head'] + size_per_head = config['n_embd'] // n_head + prefix_prompt_len = total_size // (2 * n_layers * n_head * size_per_head) + + task_list.append((task_name, prefix_prompt_len)) + # GPT NeoX + weights=weights.view(prefix_prompt_len,n_layers,2,n_head,size_per_head) ## prefix_seq_len, num_layers, 2, num_heads, size_per_head + # weights=weights.view(prefix_prompt_len,28,2,16,256) ## prefix_seq_len, num_layers, 2, num_heads, size_per_head + weights=weights.permute(1,2,3,0,4) ## num_layers, 2, num_heads, perfix_seq_len, size_per_head + local_head_num = n_head // args.infer_gpu_num + weights_split = torch.split(weights, local_head_num, dim=2) + for i in range(args.infer_gpu_num): + output_file_path = saved_dir + "/model.prefix_prompt." + task_name + ".weight." + str(i) + ".bin" + weights_split[i].detach().cpu().numpy().astype(weight_data_type).tofile(output_file_path) + + return task_list + + +def split_and_convert_process(i, saved_dir,factor,key,args,config,val): + + if key.find("input_layernorm.weight") != -1 or key.find("input_layernorm.bias") != -1 or \ + key.find("attention.dense.bias") != -1 or key.find("post_attention_layernorm.weight") != -1 or \ + key.find("post_attention_layernorm.bias") != -1 or key.find("mlp.dense_4h_to_h.bias") != -1 or \ + key.find("final_layernorm.weight") != -1 or key.find("final_layernorm.bias") != -1: + + # shared weights, only need to convert the weights of rank 0 + if i == 0: + saved_path = saved_dir + "/model." + key + ".bin" + val.tofile(saved_path) + + elif key.find("attention.dense.weight") != -1 or key.find("mlp.dense_4h_to_h.weight") != -1: + split_vals = np.split(val, factor, axis=0) + for j in range(factor): + saved_path = saved_dir + "/model." + key + ".%d.bin" % (i * factor + j) + split_vals[j].tofile(saved_path) + + elif key.find("mlp.dense_h_to_4h.weight") != -1 or key.find("mlp.dense_h_to_4h.bias") != -1: + + split_vals = np.split(val, factor, axis=-1) + for j in range(factor): + saved_path = saved_dir + "/model." + key + ".%d.bin" % (i * factor + j) + split_vals[j].tofile(saved_path) + + elif key.find("attention.query_key_value.bias") != -1: + local_dim = (int)(val.shape[-1] / 3) + n_head = config['n_head'] + + val = val.reshape(n_head, 3, local_dim // n_head) + val = np.transpose(val, [1, 0, 2]).reshape(3, local_dim) + split_vals = np.split(val, factor, axis=-1) + + for j in range(factor): + saved_path = saved_dir + "/model." 
+ key + ".%d.bin" % (i * factor + j) + split_vals[j].tofile(saved_path) + + elif key.find("attention.query_key_value.weight") != -1: + hidden_dim = val.shape[0] + local_dim = (int)(val.shape[-1] / 3) + n_head = config['n_head'] + # Note that the HF qkv weight are stored as [hidden_size, num_heads, 3, head_hidden] + # FT needs the shape of [hidden_size, 3, num_heads, head_hidden] + val = val.reshape(hidden_dim, n_head, 3, local_dim // n_head) + val = np.transpose(val, [0, 2, 1, 3]).reshape(hidden_dim, 3, local_dim) + + # print(np.mean(np.abs(val[:, 0, :]))) + split_vals = np.split(val, factor, axis=-1) + + for j in range(factor): + saved_path = saved_dir + "/model." + key + ".%d.bin" % (i * factor + j) + split_vals[j].tofile(saved_path) + + else: + print("[ERROR] cannot find key '{}'".format(key)) + +def split_and_convert(args): + saved_dir = args.saved_dir + "/%d-gpu/" % args.infer_gpu_num + + if(os.path.exists(saved_dir) == False): + os.makedirs(saved_dir) + ckpt_name = args.in_file + + t_gpu_num = args.trained_gpu_num + i_gpu_num = args.infer_gpu_num + assert(i_gpu_num % t_gpu_num == 0) + + factor = (int)(i_gpu_num / t_gpu_num) + + # load position_embedding from rank 0 + # model = torch.load(ckpt_name) + model = GPTNeoXForCausalLM.from_pretrained(args.in_file) + hf_config = vars(model.config) + if "gpt_j_residual" not in hf_config: + hf_config["gpt_j_residual"] = 0 + + np_weight_data_type = get_weight_data_type(args.weight_data_type) + + task_list = [] + if args.prompt_in_file_list is not None: + task_list = prefix_prompt_convert(args, hf_config, np_weight_data_type) + + try: + model_name = args.model_name + config = configparser.ConfigParser() + config['gptneox'] = {} + config['gptneox']['model_name'] = model_name + config['gptneox']["head_num"] = str(hf_config["n_head"]) + n_embd = hf_config["n_embd"] + config['gptneox']["size_per_head"] = str(n_embd // hf_config["n_head"]) + config['gptneox']["inter_size"] = str(n_embd * 4) + config['gptneox']["num_layer"] = str(hf_config["n_layer"]) + rotary_dim = n_embd // hf_config["n_head"] if hf_config["rotary_dim"] is None else hf_config["rotary_dim"] + config['gptneox']["rotary_embedding"] = str(rotary_dim) + config['gptneox']["vocab_size"] = str(hf_config["vocab_size"]) + config['gptneox']["start_id"] = str(hf_config["bos_token_id"]) + config['gptneox']["end_id"] = str(hf_config["eos_token_id"]) + config['gptneox']['use_gptj_residual'] = str(int(hf_config['gpt_j_residual'])) + config['gptneox']["weight_data_type"] = args.weight_data_type + + if len(task_list) > 0: + config['gptneox']['num_tasks'] = str(len(task_list)) + config['gptneox']['prompt_learning_type'] = str(2) + for idx, (task_name, prompt_length) in enumerate(task_list): + config[f'task_{idx}'] = {} + config[f'task_{idx}']['task_name'] = task_name + config[f'task_{idx}']['prompt_length'] = str(prompt_length) + with open((Path(saved_dir) / f"config.ini").as_posix(), 'w') as configfile: + config.write(configfile) + except: + print(f"Fail to save the config in config.ini.") + + huggingface_model_name_pattern = [ + "ln_1.bias", + "ln_1.weight", + "attn.qkv_proj.bias", + "attn.qkv_proj.weight", + "attn.out_proj.bias", + "attn.out_proj.weight", + "ln_2.bias", + "ln_2.weight", + "mlp.fc_in.bias", + "mlp.fc_in.weight", + "mlp.fc_out.bias", + "mlp.fc_out.weight", + ] + + ft_model_name_pattern = [ + "input_layernorm.bias", + "input_layernorm.weight", + "attention.query_key_value.bias", + "attention.query_key_value.weight", + "attention.dense.bias", + "attention.dense.weight", + 
"post_attention_layernorm.bias", + "post_attention_layernorm.weight", + "mlp.dense_h_to_4h.bias", + "mlp.dense_h_to_4h.weight", + "mlp.dense_4h_to_h.bias", + "mlp.dense_4h_to_h.weight", + ] + + torch.multiprocessing.set_start_method("spawn") + pool = multiprocessing.Pool(args.processes) + for name, param in model.named_parameters(): + if name.find("weight") == -1 and name.find("bias") == -1: + continue + print(name) + if name == 'transformer.wpe.weight': + param.detach().cpu().numpy().astype(np_weight_data_type).tofile(saved_dir + "model.wpe.bin") + elif name == 'transformer.wte.weight': + param.detach().cpu().numpy().astype(np_weight_data_type).tofile(saved_dir + "model.wte.bin") + elif name == 'transformer.ln_f.bias': + param.detach().cpu().numpy().astype(np_weight_data_type).tofile(saved_dir + "model.final_layernorm.bias.bin") + elif name == 'transformer.ln_f.weight': + param.detach().cpu().numpy().astype(np_weight_data_type).tofile(saved_dir + "model.final_layernorm.weight.bin") + elif name == 'lm_head.weight': + param.detach().cpu().numpy().astype(np_weight_data_type).tofile(saved_dir + "model.lm_head.weight.bin") + elif name == 'lm_head.bias': + param.detach().cpu().numpy().astype(np_weight_data_type).tofile(saved_dir + "model.lm_head.bias.bin") + else: + for i in range(len(huggingface_model_name_pattern)): + if name.find(huggingface_model_name_pattern[i]) != -1: + new_name = name.replace("transformer.h.", "layers.").replace(huggingface_model_name_pattern[i], ft_model_name_pattern[i]) + pool.starmap(split_and_convert_process, + [(0, saved_dir, factor, new_name, args, vars(model.config), + param.detach().cpu().numpy().astype(np_weight_data_type).T)], ) + + pool.close() + pool.join() + + # Post-process biases if use_gptj_residual is True + if hf_config['gpt_j_residual']: + for layer_idx in range(hf_config["n_layer"]): + attn_bias = np.fromfile(saved_dir + f"/model.layers.{layer_idx}.attention.dense.bias.bin", dtype=np.float32) + mlp_bias = np.fromfile(saved_dir + f"/model.layers.{layer_idx}.mlp.dense_4h_to_h.bias.bin", dtype=np.float32) + + (attn_bias + mlp_bias).tofile(saved_dir + f"/model.layers.{layer_idx}.mlp.attention.bias.sum.bin") + +if __name__ == "__main__": + parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) + parser.add_argument('-saved_dir', '-o', type=str, help='file name of output file', required=True) + parser.add_argument('-in_file', '-i', type=str, help='file name of input checkpoint file', required=True) + parser.add_argument('-prompt_in_file_list','-p_i_list', type=str, help='list of the prompt weight file path,' + 'separate by (,). e.g. 
-prompt_in_file_list prefix_prompt.task0.weight,prefix_prompt.task1.weight') + parser.add_argument('-trained_gpu_num', '-t_g', type=int, help='How many gpus for inference', default=1) + parser.add_argument('-infer_gpu_num', '-i_g', type=int, help='How many gpus for inference', required=True) + parser.add_argument("-processes", "-p", type=int, help="How many processes to spawn for conversion (default: 4)", default=4) + parser.add_argument("-weight_data_type", type=str, default="fp32", choices=["fp32", "fp16"]) + parser.add_argument('-model_name', '-m_n', type=str, help='model name', required=True) + + args = parser.parse_args() + print("\n=============== Argument ===============") + for key in vars(args): + print("{}: {}".format(key, vars(args)[key])) + print("========================================") + + split_and_convert(args) \ No newline at end of file diff --git a/examples/pytorch/longformer/longformer_qa.py b/examples/pytorch/longformer/longformer_qa.py index 9895dd003..ec46446a4 100644 --- a/examples/pytorch/longformer/longformer_qa.py +++ b/examples/pytorch/longformer/longformer_qa.py @@ -30,15 +30,17 @@ def parse_from_config(model_dir): def build_ft_longformer(hf_model_dir, layer_num, head_num, size_per_head, intermediate_size, local_attn_window_size, max_global_token_num, batch_size, seq_len, - attn_scaler, ft_longformer_lib, fp16): + attn_scaler, ft_longformer_lib, data_type): weights_file = os.path.join(hf_model_dir, 'pytorch_model.bin') ft_encoder = FTLongformerEncoder(weights_file, layer_num, head_num, size_per_head, intermediate_size, local_attn_window_size, max_global_token_num, batch_size, seq_len, - attn_scaler, ft_longformer_lib, fp16) + attn_scaler, ft_longformer_lib, data_type) ft_longformer = build_hf_longformer(hf_model_dir) - if fp16: + if data_type == 'fp16': ft_longformer = ft_longformer.half() + elif data_type == 'bf16': + ft_longformer = ft_longformer.bfloat16() ft_longformer.cuda() ft_longformer.eval() ft_encoder.set_hf_plugin_mode(True) @@ -53,7 +55,7 @@ def build_hf_longformer(model_dir): return hf_longformer -def prepare_input(question, passage_text, seq_len, batch_size, model_dir, fp16): +def prepare_input(question, passage_text, seq_len, batch_size, model_dir, data_type): tokenizer = LongformerTokenizer.from_pretrained(model_dir) encoding = tokenizer(question, passage_text, return_token_type_ids=True) qa_sep_index = 0 @@ -78,9 +80,12 @@ def prepare_input(question, passage_text, seq_len, batch_size, model_dir, fp16): local_attn_mask_b = torch.stack([local_attn_mask for _ in range(batch_size)], axis=0).contiguous() global_attn_mask_b = torch.stack([global_attn_mask for _ in range(batch_size)], axis=0).contiguous() - if fp16: + if data_type == 'fp16': local_attn_mask_b = local_attn_mask_b.half() global_attn_mask_b = global_attn_mask_b.half() + elif data_type == 'bf16': + local_attn_mask_b = local_attn_mask_b.bfloat16() + global_attn_mask_b = global_attn_mask_b.bfloat16() local_attn_mask_b = local_attn_mask_b.cuda() global_attn_mask_b = global_attn_mask_b.cuda() @@ -107,7 +112,7 @@ def main(): help='Path to huggingface model dir where model file and config file is stored') parser.add_argument('-l', '--ft-longformer-lib', type=str, default=os.path.join(project_root, 'build', 'lib', 'libth_longformer.so'), help='Path to fastertransformer longformer pytorch op lib') - parser.add_argument('--fp16', action='store_true', help="Use FP16") + parser.add_argument('--data_type', type=str, choices=['fp32', 'fp16', 'bf16'], default='fp32') parser.add_argument('-p', 
'--passage', type=str, nargs='*', help='Text for paragraph/passage for LongformerBERT QA', default=None) parser.add_argument('-pf', '--passage-file', type=str, help='File containing input passage', @@ -123,7 +128,7 @@ def main(): parser.add_argument("-g", "--max-global-attention-num", default=128, help="Max global attention token num from start of the sequence to the end.", type=int) parser.add_argument('-r', '--repeat-test-num', - help='If specified, will run inference serveral rounds, to test average performace.', + help='If specified, will run inference several rounds, to test average performance.', type=int, default=None) args, _ = parser.parse_known_args() @@ -157,36 +162,44 @@ def main(): # huggeingFace longformer hf_longformer = build_hf_longformer(model_dir) - if args.fp16: + if args.data_type == 'fp16': hf_longformer = hf_longformer.half() + # fastertransformer longformer ft_longformer = build_ft_longformer(model_dir, layer_num, head_num, size_per_head, intermediate_size, local_attn_window_size, max_global_token_num, batch_size, seq_len, - attn_scaler, ft_longformer_lib, args.fp16) + attn_scaler, ft_longformer_lib, args.data_type) # prepare input input_ids_b, local_attn_mask_b, global_attn_mask_b, input_ids, actual_seq_len = prepare_input( - question, passage_text, seq_len, batch_size, model_dir, args.fp16) + question, passage_text, seq_len, batch_size, model_dir, args.data_type) # 1. Compare the performance between HF and FT, using dummy input dummy_local_attn_mask_b = torch.ones_like(local_attn_mask_b) extended_mask_b = (global_attn_mask_b + dummy_local_attn_mask_b) * 10000. - 10000. dummy_embedding_out = torch.rand(batch_size, seq_len, hidden_size, dtype=torch.float32) - if args.fp16: + if args.data_type == 'fp16': dummy_embedding_out = dummy_embedding_out.half() + elif args.data_type == 'bf16': + dummy_embedding_out = dummy_embedding_out.bfloat16() dummy_embedding_out = dummy_embedding_out.cuda() + hf_encoder = hf_longformer.longformer.encoder ft_encoder = ft_longformer.longformer.encoder + if args.data_type == 'bf16': + print("HF longerformer encoder doesn't support BFloat16, FallBack to FP32 !") + with torch.no_grad(): # HuggingFace warmup + for i in range(10): - output = hf_encoder(dummy_embedding_out, attention_mask=extended_mask_b, head_mask=None, + output = hf_encoder(dummy_embedding_out.float(), attention_mask=extended_mask_b, head_mask=None, output_attentions=None, output_hidden_states=None, return_dict=True) start = time.time() for i in range(repeat_num): - output = hf_encoder(dummy_embedding_out, attention_mask=extended_mask_b, head_mask=None, + output = hf_encoder(dummy_embedding_out.float(), attention_mask=extended_mask_b, head_mask=None, output_attentions=None, output_hidden_states=None, return_dict=True) stop = time.time() print("HuggingFace Longformer encoder average latency {:.3f} second ({} iterations)".format((stop - start) / repeat_num, repeat_num)) @@ -212,8 +225,8 @@ def main(): ft_answer = decode_output(outputs, model_dir, input_ids, actual_seq_len) outputs = hf_longformer(input_ids_b, - attention_mask=local_attn_mask_b, - global_attention_mask=global_attn_mask_b) + attention_mask=local_attn_mask_b.float(), + global_attention_mask=global_attn_mask_b.float()) hf_answer = decode_output(outputs, model_dir, input_ids, actual_seq_len) print("HuggingFace Answer: " + hf_answer) print("FasterTransformer Answer: " + ft_answer) diff --git a/examples/pytorch/longformer/model.py b/examples/pytorch/longformer/model.py index 5174fc6d5..00e624f3d 100644 --- 
a/examples/pytorch/longformer/model.py +++ b/examples/pytorch/longformer/model.py @@ -4,12 +4,12 @@ from transformers.models.longformer.modeling_longformer import LongformerBaseModelOutput -def from_hf_longformer_weight_to_ft(weights_file, layer_num, fp16): +def from_hf_longformer_weight_to_ft(weights_file, layer_num, data_type): weights = torch.load(weights_file) all_weights = [] for i in range(0, layer_num): # Need to transpose the kernel for torch.nn.Linear - # q k v kg vg weights and bias should be continous, required by the ft longformer encoder. + # q k v kg vg weights and bias should be continuous, required by the ft longformer encoder. all_weights.append(weights["longformer.encoder.layer.{}.attention.self.query.weight".format(i)].transpose(0, 1)) all_weights.append(weights["longformer.encoder.layer.{}.attention.self.key.weight".format(i)].transpose(0, 1)) all_weights.append(weights["longformer.encoder.layer.{}.attention.self.value.weight".format(i)].transpose(0, 1)) @@ -48,7 +48,12 @@ def from_hf_longformer_weight_to_ft(weights_file, layer_num, fp16): for i in range(0, len(all_weights)): all_weights[i] = all_weights[i].flatten() - all_weights = torch.cat(all_weights).type(torch.float16) if fp16 else torch.cat(all_weights) + if data_type == "fp16": + all_weights = torch.cat(all_weights).type(torch.float16) + elif data_type == "bf16": + all_weights = torch.cat(all_weights).type(torch.bfloat16) + elif data_type == "fp32": + all_weights = torch.cat(all_weights).type(torch.float32) return all_weights.contiguous() @@ -56,14 +61,14 @@ class FTLongformerEncoder(torch.nn.Module): def __init__(self, weights_file, layer_num, head_num, size_per_head, intermediate_size, local_attn_window_size, max_global_token_num, batch_size, seq_len, - attn_scaler, ft_longformer_lib, fp16=False, hf_plugin_mode=False): + attn_scaler, ft_longformer_lib, data_type='fp32', hf_plugin_mode=False): super().__init__() - self.fp16 = fp16 + self.data_type = data_type assert seq_len % local_attn_window_size == 0 and seq_len / \ local_attn_window_size >= 2, "seq_len need to be multiple of local_attn_window_size and at least 2 times big." self.hf_plugin_mode = hf_plugin_mode - all_weight = from_hf_longformer_weight_to_ft(weights_file, layer_num, self.fp16) + all_weight = from_hf_longformer_weight_to_ft(weights_file, layer_num, data_type) self.all_weight = all_weight.cuda() torch.classes.load_library(ft_longformer_lib) self.ft_encoder = torch.classes.FasterTransformer.LongformerEncoder(layer_num, head_num * size_per_head, diff --git a/examples/pytorch/nemo.py b/examples/pytorch/nemo.py new file mode 100644 index 000000000..27a2a3a65 --- /dev/null +++ b/examples/pytorch/nemo.py @@ -0,0 +1,188 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
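+# Minimal usage sketch (illustrative; assumes the repository root is on PYTHONPATH,
+# and the .nemo archive path is a placeholder):
+#
+#   from examples.pytorch.nemo import UnpackedNemoCheckpointDir, unpack_nemo_ckpt
+#
+#   unpack_nemo_ckpt("megatron_model.nemo", "unpacked_ckpt")
+#   ckpt = UnpackedNemoCheckpointDir("unpacked_ckpt", load_checkpoints_to_cpu=True)
+#   print(ckpt.model_config)  # model_config.yaml contents (or config found in a checkpoint)
+#   print(ckpt.get_checkpoints_paths(tensor_model_parallel_size=2))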
+import functools +import logging +import pathlib +import tarfile +import typing + +import torch +import yaml + +from .utils import cpu_map_location, gpu_map_location + + +LOGGER = logging.getLogger(__name__) + + +def unpack_nemo_ckpt( + nemo_archive_path: typing.Union[str, pathlib.Path], + out_dir_path: typing.Union[str, pathlib.Path], +): + nemo_archive_path = pathlib.Path(nemo_archive_path) + if not nemo_archive_path.exists(): + raise FileNotFoundError(f"{nemo_archive_path} does not exist") + + for tar_mode in ["r:", "r:gz"]: + try: + with tarfile.open(nemo_archive_path, mode=tar_mode) as tar_file: + tar_file.extractall(path=out_dir_path) + return out_dir_path + except tarfile.ReadError: + pass + + raise RuntimeError(f"Could not unpack {nemo_archive_path}") + + +def extract_layers_with_prefix(model_, prefix): + length_to_trim = len(prefix) + model_state = model_.get("state_dict", model_) + return {key[length_to_trim:]: model_state[key] for key in model_state.keys() if prefix in key} + + +class UnpackedNemoCheckpointDir: + def __init__(self, checkpoints_dir: typing.Union[str, pathlib.Path], load_checkpoints_to_cpu: bool = False): + self._checkpoints_dir = pathlib.Path(checkpoints_dir) + self._load_checkpoints_to_cpu = load_checkpoints_to_cpu + + @property + @functools.lru_cache + def model_config(self): + model_config = None + + model_config_filename = "model_config.yaml" + model_configs_paths = list(self._checkpoints_dir.rglob(model_config_filename)) + if model_configs_paths: + if len(model_configs_paths) > 1: + raise RuntimeError( + f"There are more than single {model_config_filename} " + f"in {self._checkpoints_dir}: {', '.join(map(lambda p: p.as_posix(), model_configs_paths))}" + ) + model_config_path = model_configs_paths[0] + LOGGER.debug("Loading model config from %s", model_config_path) + with model_config_path.open("r") as model_config_file: + model_config = yaml.load(model_config_file, Loader=yaml.SafeLoader) + else: + LOGGER.debug("Searching model config in checkpoints") + # try to obtain from checkpoint + checkpoint_name = self.checkpoint_name + checkpoints_paths = sorted(self._checkpoints_dir.rglob(checkpoint_name)) + if checkpoints_paths: + # assume that parallel ranks 0 checkpoint should have model config embedded + checkpoint_path = checkpoints_paths[0] + + map_location_fn = cpu_map_location if self._load_checkpoints_to_cpu else gpu_map_location + + model_00 = torch.load(checkpoint_path, map_location=map_location_fn) + if "hyper_parameters" in model_00 and "cfg" in model_00["hyper_parameters"]: + model_config = model_00["hyper_parameters"]["cfg"] + LOGGER.debug("Loaded model config from checkpoint %s", checkpoint_path) + else: + LOGGER.debug("Could not find model config in checkpoint %s", checkpoint_path) + + del model_00 + + if model_config is None: + LOGGER.warning("Could not find checkpoint with NeMo model config in %s", self._checkpoints_dir) + + LOGGER.debug("Loaded model config %s", model_config) + + return model_config + + @property + def checkpoints_dir(self): + return self._checkpoints_dir + + def get_checkpoints_paths(self, tensor_model_parallel_size=1, pipeline_model_parallel_size=1): + """ + Injects tensor/pipeline model parallel ranks into the filepath. + Does nothing if not using model parallelism. 
+ """ + + checkpoint_path_without_rank = self.checkpoints_dir / self.checkpoint_name + + def _inject_parallel_ranks(tp_rank, pp_rank): + if tensor_model_parallel_size > 1 or pipeline_model_parallel_size > 1: + if pipeline_model_parallel_size is None or pipeline_model_parallel_size == 1: + checkpoint_path = ( + checkpoint_path_without_rank.parent + / f"mp_rank_{tp_rank:02d}" + / checkpoint_path_without_rank.name + ) + else: + checkpoint_path = ( + checkpoint_path_without_rank.parent + / f"tp_rank_{tp_rank:02d}_pp_rank_{pp_rank:03d}" + / checkpoint_path_without_rank.name + ) + return checkpoint_path + else: + return checkpoint_path_without_rank + + return [ + [ + _inject_parallel_ranks(tp_rank=tp_rank, pp_rank=pp_rank) + for pp_rank in range(pipeline_model_parallel_size) + ] + for tp_rank in range(tensor_model_parallel_size) + ] + + @property + @functools.lru_cache + def checkpoint_name(self): + patterns = [ + "model_weights.ckpt", # older megatron checkpoints + "*last.ckpt", # newer format of checkpoints + ] + for pattern in patterns: + model_files = sorted(list(self._checkpoints_dir.rglob(pattern))) + if model_files: + return model_files[0].name + + raise ValueError(f"Could not find checkpoint files in {self._checkpoints_dir}") + + @functools.lru_cache + def get_tokenizer_file_path(self, tokenizer_key, file_key, default_filename_pattern): + model_config = self.model_config + file_property = None + if tokenizer_key in model_config and file_key in model_config[tokenizer_key]: + file_property = model_config[tokenizer_key][file_key] + elif file_key in model_config: + file_property = model_config[file_key] + + LOGGER.debug("model_config[%s][%s]=%s", tokenizer_key, file_key, file_property) + + if file_property and file_property.startswith("nemo:"): + filename = file_property.split("nemo:")[1] + filename_pattern = f"*{filename}" + elif file_property and file_property.startswith("/artifacts/"): + filename = pathlib.Path(file_property).name + filename_pattern = f"*{filename}" + elif file_property is None or file_property == "None": + filename_pattern = None + else: + filename_pattern = default_filename_pattern + LOGGER.warning( + f"Tokenizer file from config: {tokenizer_key}.{file_key}={file_property} " + f"looks like unsupported path. Pattern {filename_pattern} will be used." 
+ ) + + file_path = None + if filename_pattern is not None: + files_paths = list(self._checkpoints_dir.glob(filename_pattern)) + if files_paths: + assert len(files_paths) == 1 + file_path = files_paths[0] + + return file_path diff --git a/examples/pytorch/swin/Swin-Transformer-Quantization/main.py b/examples/pytorch/swin/Swin-Transformer-Quantization/main.py index a385421e6..fa49f5c2e 100644 --- a/examples/pytorch/swin/Swin-Transformer-Quantization/main.py +++ b/examples/pytorch/swin/Swin-Transformer-Quantization/main.py @@ -87,7 +87,7 @@ def parse_option(): parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps") parser.add_argument('--use-checkpoint', action='store_true', help="whether to use gradient checkpointing to save memory") - parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'], + parser.add_argument('--amp-opt-level', type=str, default='O0', choices=['O0', 'O1', 'O2'], help='mixed precision opt level, if O0, no amp is used') parser.add_argument('--output', default='output', type=str, metavar='PATH', help='root of output folder, the full path is // (default: output)') diff --git a/examples/pytorch/swin/Swin-Transformer-Quantization/models.py b/examples/pytorch/swin/Swin-Transformer-Quantization/models.py index e859ce02c..ef333d8de 100644 --- a/examples/pytorch/swin/Swin-Transformer-Quantization/models.py +++ b/examples/pytorch/swin/Swin-Transformer-Quantization/models.py @@ -49,7 +49,7 @@ def build_model(config): patch_norm=config.MODEL.SWIN.PATCH_NORM, use_checkpoint=config.TRAIN.USE_CHECKPOINT) else: - raise NotImplementedError(f"Unkown model: {model_type}") + raise NotImplementedError(f"Unknown model: {model_type}") return model @@ -360,7 +360,7 @@ class SwinTransformerBlock(nn.Module): Args: dim (int): Number of input channels. - input_resolution (tuple[int]): Input resulotion. + input_resolution (tuple[int]): Input resolution. num_heads (int): Number of attention heads. window_size (int): Window size. shift_size (int): Shift size for SW-MSA. 
diff --git a/examples/pytorch/swin/Swin-Transformer-Quantization/quant_utils.py b/examples/pytorch/swin/Swin-Transformer-Quantization/quant_utils.py index 6fa82e77f..468ff6076 100644 --- a/examples/pytorch/swin/Swin-Transformer-Quantization/quant_utils.py +++ b/examples/pytorch/swin/Swin-Transformer-Quantization/quant_utils.py @@ -277,7 +277,7 @@ def set_quantizers(name, mod, which='both', **kwargs): set_quantizer(name, mod, '_input_quantizer', k, v) if which in ['weight', 'both']: set_quantizer(name, mod, '_weight_quantizer', k, v) - # logger.info(s) + logger.info(s) def set_quantizer_by_name(model, names, **kwargs): @@ -295,4 +295,4 @@ def set_quantizer_by_name(model, names, **kwargs): for k, v in kwargs.items(): s += (f' {k}={v}') setattr(mod, k, v) - # logger.info(s) + logger.info(s) diff --git a/examples/pytorch/swin/SwinTransformerWeightTransposeQKVWeight.py b/examples/pytorch/swin/SwinTransformerWeightTransposeQKVWeight.py index 1acafc551..2de633992 100644 --- a/examples/pytorch/swin/SwinTransformerWeightTransposeQKVWeight.py +++ b/examples/pytorch/swin/SwinTransformerWeightTransposeQKVWeight.py @@ -120,3 +120,7 @@ def to_half(self): for idx, v in enumerate(self.weights): self.weights[idx] = v.half() + def to_bfloat16(self): + for idx, v in enumerate(self.weights): + self.weights[idx] = v.bfloat16() + diff --git a/examples/pytorch/swin/infer_swintransformer_int8_op.py b/examples/pytorch/swin/infer_swintransformer_int8_op.py index 7fda392ed..3235f2e46 100644 --- a/examples/pytorch/swin/infer_swintransformer_int8_op.py +++ b/examples/pytorch/swin/infer_swintransformer_int8_op.py @@ -89,7 +89,7 @@ def parse_option(): parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps") parser.add_argument('--use-checkpoint', action='store_true', help="whether to use gradient checkpointing to save memory") - parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'], + parser.add_argument('--amp-opt-level', type=str, default='O0', choices=['O0', 'O1', 'O2'], help='mixed precision opt level, if O0, no amp is used') parser.add_argument('--output', default='output', type=str, metavar='PATH', help='root of output folder, the full path is // (default: output)') @@ -304,7 +304,7 @@ def run_swintransformernv_op(config, args, model, images, use_fp16): # warm up for i in range(warmup_time): op_embedding = swin_transformer.forward(images) - op_output = model.head(op_embedding) + op_output = model.head(op_embedding.float()) torch.cuda.synchronize() op_begin = time.time() @@ -368,7 +368,7 @@ def validate_with_random_data(config, args, model): # torch.cuda.synchronize() # torch_start = time.time() # for i in range(test_time): - INT8_torch_output = model(images_half) + INT8_torch_output = model(images_float) # torch.cuda.synchronize() # torch_end = time.time() @@ -386,7 +386,8 @@ def validate_with_random_data(config, args, model): # print("diff between instance 0 and 1:", diff_int8.mean()) diff = abs(INT8_torch_output - INT8_op_output) print("INT8_torch_output vs INT8_op_output , avg diff : ", diff.mean((1)), "max diff : ", diff.max((1))) - + assert diff.mean() < 0.04, "[ERROR] SWIN INT8 Op TEST FAIL !" 
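+    # NOTE: a relatively loose mean-difference tolerance is used above because INT8
+    # quantization introduces rounding error beyond normal floating-point variation.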
+ print("[INFO] SWIN INT8 Op TEST PASS !") @torch.no_grad() diff --git a/examples/pytorch/swin/infer_swintransformer_op.py b/examples/pytorch/swin/infer_swintransformer_op.py index 7473b14cb..615848239 100644 --- a/examples/pytorch/swin/infer_swintransformer_op.py +++ b/examples/pytorch/swin/infer_swintransformer_op.py @@ -86,7 +86,7 @@ def main(args, config): validate_with_random_data(args, config, model) @torch.no_grad() -def run_swintransformernv_op(args, config, model, images, use_fp16): +def run_swintransformernv_op(args, config, model, images, data_type): th_path = args.th_path depths = config.MODEL.SWIN.DEPTHS depths_tensor = torch.tensor(depths, dtype=torch.int) @@ -109,9 +109,12 @@ def run_swintransformernv_op(args, config, model, images, use_fp16): qk_scale = 1.0 torch.classes.load_library(th_path) sw_weights = SwinTransformerWeightTransposeQKVWeight(layer_num, window_size, depths, num_heads, th_path, model.state_dict()) - if use_fp16: + if data_type == 'fp16': sw_weights.to_half() model.half() + elif data_type == 'bf16': + sw_weights.to_bfloat16() + model.bfloat16() sw_weights.to_cuda() ##run pytorch op @@ -134,9 +137,11 @@ def run_swintransformernv_op(args, config, model, images, use_fp16): #_nvtx.rangePop() torch.cuda.synchronize() op_end = time.time() - op_output = op_output.cpu().numpy() - if use_fp16: + op_output = op_output.float().cpu().numpy() + if data_type == 'fp16': print("FP16 op time : ", (op_end - op_begin)/test_time*1000.0, "ms") + elif data_type == 'bf16': + print("BF16 op time : ", (op_end - op_begin)/test_time*1000.0, "ms") else: print("FP32 op time : ", (op_end - op_begin)/test_time*1000.0, "ms") @@ -157,7 +162,7 @@ def run_torch(model, images, mark): #_nvtx.rangePop() torch.cuda.synchronize() torch_end = time.time() - torch_output = torch_output.cpu().numpy() + torch_output = torch_output.float().cpu().numpy() # Numpy doesn't support BF16 print(mark + " time : ", (torch_end - torch_start)/test_time*1000.0, "ms") return torch_output @@ -171,35 +176,49 @@ def validate_with_random_data(args, config, model): image = np.random.rand(1, in_chans, img_size, img_size) images = np.repeat(image, max_batch, axis=0) images_half = torch.tensor(images, dtype=torch.half) + images_bfloat16 = torch.tensor(images, dtype=torch.bfloat16) images_float = torch.tensor(images, dtype=torch.float) images_half = images_half.cuda(non_blocking=True) + images_bfloat16 = images_bfloat16.cuda(non_blocking=True) images_float = images_float.cuda(non_blocking=True) ##run original swin-transformer # run pytorch op - FP32_op_output = run_swintransformernv_op(args, config, model, images_float, False) + FP32_op_output = run_swintransformernv_op(args, config, model, images_float, 'fp32') traced_module_float = torch.jit.trace(model, images_float) FP32_torch_traced_output = run_torch(traced_module_float, images_float, "FP32 torch trace") FP32_torch_output = run_torch(model, images_float, "FP32 torch") - FP16_op_output = run_swintransformernv_op(args, config, model, images_half, True) + FP16_op_output = run_swintransformernv_op(args, config, model, images_half, 'fp16') - traced_module_half = torch.jit.trace(model, images_half) + traced_module_half = torch.jit.trace(model.half(), images_half) FP16_torch_traced_output = run_torch(traced_module_half, images_half, "FP16 torch trace") FP16_torch_output = run_torch(model, images_half, "FP16 torch") + BF16_op_output = run_swintransformernv_op(args, config, model, images_bfloat16, 'bf16') + traced_module_bfloat16 = torch.jit.trace(model.bfloat16(), 
images_bfloat16) + BF16_torch_traced_output = run_torch(traced_module_bfloat16, images_bfloat16, "BF16 torch trace") + BF16_torch_output = run_torch(model, images_bfloat16, "BF16 torch") diff = abs(FP32_torch_traced_output - FP32_op_output) print("FP32_torch_traced_output vs FP32_op_output , avg diff : ", diff.mean(), "max diff : ", diff.max()) + assert diff.mean() < 0.001, "[ERROR] SWIN FP32 Op TEST FAIL !" diff = abs(FP16_torch_traced_output - FP16_op_output) print("FP16_torch_traced_output vs FP16_op_output , avg diff : ", diff.mean(), "max diff : ", diff.max()) + assert diff.mean() < 0.001, "[ERROR] SWIN FP16 Op TEST FAIL !" + diff = abs(BF16_torch_traced_output - BF16_op_output) + print("BF16_torch_traced_output vs BF16_op_output , avg diff : ", diff.mean(), "max diff : ", diff.max()) + assert diff.mean() < 0.003, "[ERROR] SWIN BF16 Op TEST FAIL !" + + print("[INFO] SWIN Op TEST PASS !") if __name__ == '__main__': args, config = parse_option() - seed = config.SEED + int(time.time()) + # seed = config.SEED + int(time.time()) + seed = config.SEED torch.manual_seed(seed) np.random.seed(seed) cudnn.benchmark = True diff --git a/examples/pytorch/t5/mnli_task_example.py b/examples/pytorch/t5/mnli_task_example.py new file mode 100644 index 000000000..952f7d19d --- /dev/null +++ b/examples/pytorch/t5/mnli_task_example.py @@ -0,0 +1,394 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
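+# This example evaluates a Megatron T5 checkpoint on the MNLI task through the FasterTransformer T5 encoder/decoding ops (FTT5Encoder/FTT5Decoding), reports accuracy over the NeMo text-to-text GLUE dataset, and can dump per-example predictions to a JSON file via --output_path.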
+ +import argparse +import configparser +import dataclasses +import json +import os +import pathlib +import time + +import numpy as np +import torch +import torch.distributed as dist +import typing +from tqdm import tqdm + +from omegaconf.omegaconf import OmegaConf, open_dict +from nemo.collections.nlp.data.glue_benchmark.glue_benchmark_dataset import ( + TextToTextGLUEDataset, + TextToTextXNLIDataset, +) +from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer + +from examples.pytorch.t5.utils.ft_encoder import FTT5EncoderWeight, FTT5Encoder +from examples.pytorch.t5.utils.ft_decoding import FTT5DecodingWeight, FTT5Decoding, FTT5 + + +def _build_dataset(data_cfg, tokenizer): + if data_cfg.task_name == 'xnli': + dataset = TextToTextXNLIDataset( + data_cfg.file_path, + task_name=data_cfg.task_name, + tokenizer=tokenizer, + max_seq_length=data_cfg.max_seq_length, + lang_list=data_cfg.eval_languages, + ) + else: + dataset = TextToTextGLUEDataset( + data_cfg.file_path, + task_name=data_cfg.task_name, + tokenizer=tokenizer, + max_seq_length=data_cfg.max_seq_length, + ) + return dataset + + +def preds_and_labels_to_text(tokenizer, preds, labels): + preds = preds.cpu().numpy().tolist() + labels = labels.cpu().numpy().tolist() + preds = [pred[0] for pred in preds] + + preds_text, labels_text = [], [] + for _, (pred, label) in enumerate(zip(preds, labels)): + if tokenizer.eos_id in pred: + idx = pred.index(tokenizer.eos_id) + pred = pred[:idx] + + # Legacy sentencepiece detokenization still preserves special tokens which messes up exact string match. + if hasattr(tokenizer, 'special_token_to_id'): + pred = [id for id in pred if id not in tokenizer.special_token_to_id.values()] + label = [id for id in label if id not in tokenizer.special_token_to_id.values()] + pred = tokenizer.ids_to_text(pred) + label = tokenizer.ids_to_text(label) + preds_text.append(pred) + labels_text.append(label) + + return preds_text, labels_text + + +def accuracy_score(pred, ref): + assert len(pred) == len(ref) + total = len(pred) + correct = 0 + for p, r in zip(pred, ref): + if p == r: + correct += 1 + # else: + # print(f"[pred]: {p} [label]: {r}") + accuracy = correct / total + print(f"[accuracy]: {accuracy}") + return accuracy + + +@dataclasses.dataclass +class Metric: + acc: float + + +@dataclasses.dataclass +class RequestAndResult: + model_answer: str + target: str + metrics: Metric + + +class InputToken: + def __init__(self, input_ids, attention_mask): + self.input_ids = input_ids + self.attention_mask = attention_mask + + +class EncoderDecoderConfig: + def __init__(self, d_model, vocab_size, num_heads, d_kv, d_ff, num_layers, + relative_attention_num_buckets_or_max_pos_seq_len, decoder_start_token_id=0, decoder_end_token_id=1): + self.d_model = d_model + self.vocab_size = vocab_size + self.num_heads = num_heads + self.d_kv = d_kv + self.d_ff = d_ff + self.num_layers = num_layers + self.relative_attention_num_buckets = relative_attention_num_buckets_or_max_pos_seq_len + self.decoder_start_token_id = decoder_start_token_id + self.decoder_end_token_id = decoder_end_token_id + +data_type_mapping = {"fp32": 0, "fp16": 1, "bf16": 2} + +def mnli_task(args_dict): + torch.set_printoptions(precision=6) + batch_size = args_dict['batch_size'] + beam_size = args_dict['beam_width'] + max_output_len = args_dict['max_output_len'] + beam_search_diversity_rate = args_dict['beam_search_diversity_rate'] + topk = args_dict['sampling_topk'] + topp = args_dict['sampling_topp'] + tensor_para_size = 
args_dict['tensor_para_size'] + pipeline_para_size = args_dict['pipeline_para_size'] + + if args_dict['ckpt_path'] is None: + raise Exception("Megatron T5 model needs to specify checkpoint path !") + + if dist.is_mpi_available(): + try: + dist.init_process_group(backend='mpi') + rank = dist.get_rank() + except: + rank = dist.get_rank() + else: + rank = 0 + + assert dist.get_world_size() == tensor_para_size * pipeline_para_size + + ckpt_path = pathlib.Path(args_dict['ckpt_path']) + ## read checkpoint config if exists + ckpt_config = configparser.ConfigParser() + + vocab_path = ckpt_path / "vocab.txt" + ckpt_config_path = ckpt_path / "config.ini" + if ckpt_config_path.is_file(): + ckpt_config.read(ckpt_config_path) + ## update structure config + t5_with_bias = ckpt_config.getboolean('structure', 't5_with_bias') + ## megatron with bias and use absolute position embedding + ## relative position embedding -> 0, absolute position embedding -> 1 + position_embedding_type = 0 if ckpt_config.get('structure', 'position_embedding_type') == 'relative' else 1 + use_gated_activation = ckpt_config.getboolean('structure', 'use_gated_activation') + weight_data_type = {"fp16": np.float16, "fp32": np.float32}[ckpt_config.get("encoder", "weight_data_type")] + activation_type = ckpt_config.get('encoder', 'feed_forward_proj') + assert ckpt_config.getint("encoder", "tensor_para_size") == tensor_para_size + else: + raise Exception("config file does not exist with the ckpt !") + + if rank == 0: + print("\n=============== Argument ===============") + for key in args_dict: + print("{}: {}".format(key, args_dict[key])) + print("========================================") + + lib_path = args_dict['lib_path'] + + ## build tokenizer, dataset, dataloader + tokenizer_t5 = get_nmt_tokenizer( + library='megatron', + model_name='BertWordPieceCase', + tokenizer_model=None, + vocab_file=vocab_path.as_posix(), + merges_file=None, + legacy=False, + ) + + assert tokenizer_t5.bos_id == ckpt_config.getint("decoder", "decoder_start_token_id") + assert tokenizer_t5.eos_id == ckpt_config.getint("decoder", "eos_token_id") + + token_params = { + tokenizer_t5.bos_token: tokenizer_t5.bos_id, + tokenizer_t5.eos_token: tokenizer_t5.eos_id, + tokenizer_t5.pad_token: tokenizer_t5.pad_id, + } + print(f"tokenizer special tokens: {token_params}") + mnli_cfg = OmegaConf.create({ + "file_path": args_dict['data_path'], + "task_name": "mnli", + "max_seq_length": 512 + }) + + mnli_dataset = _build_dataset(mnli_cfg, tokenizer_t5) + + data_loader = torch.utils.data.DataLoader( + mnli_dataset, + collate_fn=mnli_dataset.collate_fn, + batch_size=batch_size, + num_workers=8, + pin_memory=True, + drop_last=True) + + q_scaling = 1.0 + + encoder_config = EncoderDecoderConfig(ckpt_config.getint('encoder', 'd_model'), + ckpt_config.getint('encoder', 'vocab_size'), + ckpt_config.getint('encoder', 'num_heads'), + ckpt_config.getint('encoder', 'd_kv'), + ckpt_config.getint('encoder', 'd_ff'), + ckpt_config.getint('encoder', 'num_layers'), + ckpt_config.getint('encoder', 'relative_attention_num_buckets_or_max_pos_seq_len') + ) + + decoder_config = EncoderDecoderConfig(ckpt_config.getint('decoder', 'd_model'), + ckpt_config.getint('decoder', 'vocab_size'), + ckpt_config.getint('decoder', 'num_heads'), + ckpt_config.getint('decoder', 'd_kv'), + ckpt_config.getint('decoder', 'd_ff'), + ckpt_config.getint('decoder', 'num_layers'), + ckpt_config.getint('decoder', 'relative_attention_num_buckets_or_max_pos_seq_len'), + tokenizer_t5.bos_id, + tokenizer_t5.eos_id + ) + + ## run 
gemm test + if os.path.isfile("gemm_config.in") and rank == 0: + cmd = f"rm gemm_config.in" + print(f"Run {cmd}") + os.system(cmd) + if rank == 0: + data_type = data_type_mapping[args_dict['data_type']] + cmd = f"./bin/t5_gemm {batch_size // pipeline_para_size} {beam_size} {128} " \ + f"{encoder_config.d_model} {encoder_config.num_heads} {encoder_config.d_kv} {encoder_config.d_ff} " \ + f"{decoder_config.d_model} {decoder_config.num_heads} {decoder_config.d_kv} {decoder_config.d_ff} " \ + f"{decoder_config.vocab_size} {data_type} {tensor_para_size} 1 > .tmp_gemm.log" + print(f"Run gemm test: {cmd}") + os.system(cmd) + + dist.barrier() + + ft_encoder_weight = FTT5EncoderWeight( + encoder_config, + tensor_para_size, + pipeline_para_size, + t5_with_bias=t5_with_bias, + use_gated_activation=use_gated_activation, + position_embedding_type=position_embedding_type, + weight_data_type=weight_data_type, + ) + ft_decoding_weight = FTT5DecodingWeight( + decoder_config, + tensor_para_size, + pipeline_para_size, + t5_with_bias=t5_with_bias, + use_gated_activation=use_gated_activation, + position_embedding_type=position_embedding_type, + weight_data_type=weight_data_type, + ) + + ft_encoder_weight.load_from_bin(ckpt_path.as_posix()) + ft_decoding_weight.load_from_bin(ckpt_path.as_posix()) + + if args_dict['data_type'] == 'fp16': + ft_encoder_weight.to_half() + ft_decoding_weight.to_half() + elif args_dict['data_type'] == 'fp32': + ft_encoder_weight.to_single() + ft_decoding_weight.to_single() + elif args_dict['data_type'] == 'bf16': + ft_encoder_weight.to_bfloat16() + ft_decoding_weight.to_bfloat16() + + remove_padding = True if batch_size > 32 else False + ft_encoder = FTT5Encoder(ft_encoder_weight.w, lib_path, encoder_config.num_heads, + encoder_config.d_kv, encoder_config.d_ff, + encoder_config.d_model, remove_padding, encoder_config.num_layers, + encoder_config.relative_attention_num_buckets, + 128, False, q_scaling, tensor_para_size, pipeline_para_size, t5_with_bias, + position_embedding_type, activation_type) + ft_decoding = FTT5Decoding(ft_decoding_weight.w, lib_path, + decoder_config.num_heads, decoder_config.d_kv, + decoder_config.d_ff, encoder_config.d_model, + decoder_config.d_model, decoder_config.num_layers, + decoder_config.decoder_start_token_id, decoder_config.decoder_end_token_id, + decoder_config.vocab_size, + q_scaling, + decoder_config.relative_attention_num_buckets, max_distance=128, + tensor_para_size=tensor_para_size, pipeline_para_size=pipeline_para_size, + activation_type=activation_type, + t5_with_bias=t5_with_bias, position_embedding_type = position_embedding_type) + + ft_t5 = FTT5(ft_encoder, ft_decoding) + + preds_list = [] + labels_list = [] + results_list = [] + start = time.time() + for idx, batch in tqdm(enumerate(data_loader)): + input_token = InputToken(batch['text_enc'], batch['enc_mask']) + ft_decoding_outputs, ft_decoding_seq_lens = ft_t5(input_token, + None, + beam_size, + max_output_len, + topk, + topp, + beam_search_diversity_rate=beam_search_diversity_rate, + is_return_output_log_probs=args_dict["return_output_log_probs"], + is_return_cum_log_probs=args_dict["return_cum_log_probs"]) + + preds, labels = preds_and_labels_to_text(tokenizer_t5, torch.IntTensor(ft_decoding_outputs), batch['labels']) + + labels_list += labels + preds_list += preds + results_list.extend([ + RequestAndResult( + model_answer=pred, + target=label, + metrics=Metric(acc=pred == label) + ) + for pred, label in zip(preds, labels) + ]) + + end = time.time() + if rank == 0: + print(f"\n[Elapsed 
Time]: {end - start} seconds") + + accuracy = accuracy_score(preds_list, labels_list) + output_path = args_dict.get("output_path") + if output_path is not None and rank == 0: + output_path = pathlib.Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + with output_path.open("w") as output_file: + results = { + "results": { + "mnli": { + "acc": accuracy + } + }, + "output": { + "mnli": [ + dataclasses.asdict(r) for r in results_list + ] + } + } + json.dump(results, output_file) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument('-batch', '--batch_size', type=int, default=1, metavar='NUMBER', + help='batch size (default: 1)') + parser.add_argument('-beam', '--beam_width', type=int, default=4, metavar='NUMBER', + help='beam width (default: 4)') + parser.add_argument('-s', '--max_output_len', type=int, default=10, metavar='NUMBER', + help='max output length (default: 10)') + parser.add_argument('-diversity_rate', '--beam_search_diversity_rate', type=float, default=0.0, metavar='NUMBER', + help='diversity rate of beam search. Default is 0. When diversity rate = 0, it is equivalent to the naive beam search.') + parser.add_argument('-topk', '--sampling_topk', type=int, default=1, metavar='NUMBER', + help='Candidate (k) value of top k sampling in decoding. Default is 1.') + parser.add_argument('-topp', '--sampling_topp', type=float, default=0.0, metavar='NUMBER', + help='Probability (p) value of top p sampling in decoding. Default is 0.0. ') + parser.add_argument('-d', '--data_type', type=str, default="fp32", metavar='STRING', + help='data type (default: fp32)', choices=['fp32', 'fp16', 'bf16']) + parser.add_argument('-lib_path', '--lib_path', type=str, default="/workspace/FasterTransformer/build/lib/libth_t5.so", metavar='STRING', + help='the path of FasterTransformer pytorch t5 op library.') + parser.add_argument('-data_path', '--data_path', type=str, required=True, help="the MNLI task data path") + parser.add_argument('-tensor_para_size', '--tensor_para_size', type=int, default=1, metavar='NUMBER', + help='size of tensor parallelism (default: 1)') + parser.add_argument('-pipeline_para_size', '--pipeline_para_size', type=int, default=1, metavar='NUMBER', + help='size of pipeline parallelism (default: 1)') + # assume checkpoint config is also in the same path + parser.add_argument('--ckpt_path', type=str, help='path to the checkpoint file.') + parser.add_argument('--output_path', help='path to results file with calculated metrics.') + parser.add_argument('--return_output_log_probs', action='store_true', + help='Return the log probability of generated tokens.') + parser.add_argument('--return_cum_log_probs', action='store_true', + help='Return the cumulative log probability of generated tokens.') + args = parser.parse_args() + + mnli_task(vars(args)) \ No newline at end of file diff --git a/examples/pytorch/t5/perf_benchmark.py b/examples/pytorch/t5/perf_benchmark.py index a08277f3f..8b1cb3410 100644 --- a/examples/pytorch/t5/perf_benchmark.py +++ b/examples/pytorch/t5/perf_benchmark.py @@ -91,25 +91,32 @@ def translate(args_dict): infer_iterations = args_dict['iterations'] infer_duration = args_dict['duration'] seed = args_dict['seed'] + skip_gemm = args_dict['skip_gemm'] torch.manual_seed(seed) ## huggingface without bias and use relative position embedding ## relative position embedding -> 0, absolute position embedding -> 1 - t5_with_bias = 0 + t5_with_bias = False + 
use_gated_activation = False position_embedding_type = 0 + weight_data_type = np.float32 ## only huggingface model path supported model_path = args_dict['model_path'] if args_dict['model_path'] != None else args_dict['model'] ckpt_path = args_dict['ckpt_path'] model_type = args_dict['model_type'] ## read checkpoint config if exists ckpt_config = configparser.ConfigParser() + activation_type = "relu" if (model_type == "Megatron"): ckpt_config_path = os.path.join(ckpt_path, 'config.ini') if os.path.isfile(ckpt_config_path): ckpt_config.read(ckpt_config_path) ## update structure config - t5_with_bias = ckpt_config.getint('structure', 't5_with_bias') - position_embedding_type = ckpt_config.getint('structure', 'position_embedding_type') + t5_with_bias = ckpt_config.getboolean('structure', 't5_with_bias') + position_embedding_type = 0 if ckpt_config.get('structure', 'position_embedding_type') == 'relative' else 1 + use_gated_activation = ckpt_config.getboolean('structure', 'use_gated_activation') + weight_data_type = {"fp16": np.float16, "fp32": np.float32}[ckpt_config.get("encoder", "weight_data_type")] + activation_type = "gated-gelu" if use_gated_activation else "gelu" # change to gelu, which is default setting of Megatron T5 else: raise Exception("config file does exist with the ckpt !") @@ -178,9 +185,9 @@ def translate(args_dict): if time_args.find("1") != -1: translation_result_list.append(TranslationResult("ft-beamsearch-warmup", "FT")) translation_result_list.append(TranslationResult("ft-beamsearch", "FT")) - if rank == 0: + if rank == 0 and not skip_gemm: is_fp16 = 1 if args_dict['data_type'] == 'fp16' else 0 - cmd = f"./bin/t5_gemm {batch_size // pipeline_para_size} {beam_size} {128} " \ + cmd = f"./bin/t5_gemm {math.ceil(batch_size / pipeline_para_size)} {beam_size} {128} " \ f"{encoder_config.d_model} {encoder_config.num_heads} {encoder_config.d_kv} {encoder_config.d_ff} " \ f"{decoder_config.d_model} {decoder_config.num_heads} {decoder_config.d_kv} {decoder_config.d_ff} " \ f"{decoder_config.vocab_size} {is_fp16} {tensor_para_size} 1 > .tmp_gemm.log" @@ -192,9 +199,9 @@ def translate(args_dict): if time_args.find("3") != -1: translation_result_list.append(TranslationResult("ft-sampling-warmup", "FT")) translation_result_list.append(TranslationResult("ft-sampling", "FT")) - if rank == 0: + if rank == 0 and not skip_gemm: is_fp16 = 1 if args_dict['data_type'] == 'fp16' else 0 - cmd = f"./bin/t5_gemm {batch_size // pipeline_para_size} {1} {128} " \ + cmd = f"./bin/t5_gemm {math.ceil(batch_size / pipeline_para_size)} {1} {128} " \ f"{encoder_config.d_model} {encoder_config.num_heads} {encoder_config.d_kv} {encoder_config.d_ff} " \ f"{decoder_config.d_model} {decoder_config.num_heads} {decoder_config.d_kv} {decoder_config.d_ff} " \ f"{decoder_config.vocab_size} {is_fp16} {tensor_para_size} 1 1 > .tmp_gemm.log" @@ -202,8 +209,24 @@ def translate(args_dict): os.system(cmd) if time_args.find("1") != -1 or time_args.find("3") != -1: - ft_encoder_weight = FTT5EncoderWeight(encoder_config, tensor_para_size, pipeline_para_size, t5_with_bias, position_embedding_type) - ft_decoding_weight = FTT5DecodingWeight(decoder_config, tensor_para_size, pipeline_para_size, t5_with_bias, position_embedding_type) + ft_encoder_weight = FTT5EncoderWeight( + encoder_config, + tensor_para_size, + pipeline_para_size, + t5_with_bias=t5_with_bias, + use_gated_activation=use_gated_activation, + position_embedding_type=position_embedding_type, + weight_data_type=weight_data_type + ) + ft_decoding_weight = 
FTT5DecodingWeight( + decoder_config, + tensor_para_size, + pipeline_para_size, + t5_with_bias=t5_with_bias, + use_gated_activation=use_gated_activation, + position_embedding_type=position_embedding_type, + weight_data_type=weight_data_type, + ) if args_dict["ckpt_path"] is not None: ft_encoder_weight.load_from_bin(args_dict["ckpt_path"]) @@ -223,7 +246,8 @@ def translate(args_dict): encoder_config.d_kv, encoder_config.d_ff, encoder_config.d_model, remove_padding, encoder_config.num_layers, encoder_config.relative_attention_num_buckets, - 128, False, q_scaling, tensor_para_size, pipeline_para_size, t5_with_bias, position_embedding_type) + 128, False, q_scaling, tensor_para_size, pipeline_para_size, t5_with_bias, + position_embedding_type, activation_type) ft_decoding = FTT5Decoding(ft_decoding_weight.w, lib_path, decoder_config.num_heads, decoder_config.d_kv, decoder_config.d_ff, encoder_config.d_model, @@ -236,7 +260,8 @@ def translate(args_dict): q_scaling, decoder_config.relative_attention_num_buckets, max_distance=128, tensor_para_size=tensor_para_size, pipeline_para_size=pipeline_para_size, - t5_with_bias=t5_with_bias, position_embedding_type = position_embedding_type) + t5_with_bias=t5_with_bias, activation_type=activation_type, + position_embedding_type = position_embedding_type) ft_t5 = FTT5(ft_encoder, ft_decoding) @@ -275,6 +300,7 @@ def translate(args_dict): if translation_result_list[i].name.find("sampling") != -1: tmp_beam_size = 1 ft_decoding_outputs, ft_decoding_seq_lens = ft_t5(input_token, + None, tmp_beam_size, output_seq_len, topk, @@ -359,6 +385,8 @@ def translate(args_dict): help='Minimal duration in seconds for inference iterations for each implementation.') parser.add_argument('-seed', '--seed', type=int, default=0, metavar='NUMBER', help='Random seed used to generate random input values.') + parser.add_argument('-skip_gemm', '--skip_gemm', action="store_true", + help='Skip the gemm autotuning by not calling the ./bin/t5_gemm binary.') args = parser.parse_args() translate(vars(args)) \ No newline at end of file diff --git a/examples/pytorch/t5/requirement.txt b/examples/pytorch/t5/requirement.txt index 030ae70f1..80e0e470a 100644 --- a/examples/pytorch/t5/requirement.txt +++ b/examples/pytorch/t5/requirement.txt @@ -1,5 +1,7 @@ -transformers==4.10.0 -tokenizers==0.10.1 -omegaconf -SentencePiece -sacrebleu +SentencePiece~=0.1.96 +datasets~=2.3.2 +omegaconf~=2.1.2 +rouge_score~=0.1.2 +sacrebleu~=2.1.0 +transformers~=4.20.1 +tokenizers~=0.12.1 \ No newline at end of file diff --git a/examples/pytorch/t5/summarization.py b/examples/pytorch/t5/summarization.py new file mode 100644 index 000000000..affe6ee49 --- /dev/null +++ b/examples/pytorch/t5/summarization.py @@ -0,0 +1,375 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +''' +This example is used to verify the correctness on the summarization task. So, we don't +put benchmark testing in this example. 
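+Run with --test_ft and/or --test_hf to choose which backend is evaluated on the CNN/DailyMail test split, and use --rougeLsum_threshold to turn the ROUGE check into a pass/fail test.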
+''' + +from __future__ import print_function +import argparse +import json +import numpy as np +import os +import sys +import torch +import torch.distributed as dist +from datasets import load_dataset, load_metric +# dir_path = os.path.dirname(os.path.realpath(__file__)) +# sys.path.append(dir_path + "/../../../3rdparty/transformers/src/") + +from transformers import T5ForConditionalGeneration, AutoTokenizer, T5Config +from tqdm import tqdm +import configparser +import math +import datetime + +dir_path = os.path.dirname(os.path.realpath(__file__)) +sys.path.append(dir_path + "/../../..") +from examples.pytorch.t5.utils.ft_decoding import FTT5DecodingWeight, FTT5Decoding, FTT5 +from examples.pytorch.t5.utils.ft_encoder import FTT5EncoderWeight, FTT5Encoder + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--ft_model_location', type=str, + default='/models/T5/HF/t5-base/c-models/') + parser.add_argument('--hf_model_location', type=str, + default='/models/T5/HF/t5-base/') + parser.add_argument('--disable_summarize', action='store_true') + parser.add_argument('--test_hf', action='store_true') + parser.add_argument('--test_ft', action='store_true') + parser.add_argument('--data_type', type=str, choices=['fp32', 'fp16', 'bf16'], default='fp32') + parser.add_argument("--cache_path", type=str, default="/workdir/datasets/ccdv/") + parser.add_argument("--max_ite", type=int, default=20) + parser.add_argument("--max_seq_len", type=int, default=200) + parser.add_argument("--ft_use_hf_config", action="store_true", + help="use the hyper-parameters from the hf model") + parser.add_argument('--lib_path', type=str, default='./lib/libth_t5.so', + help='path to the pyt_fastertransformer dynamic lib file.') + parser.add_argument('--tensor_para_size', type=int, default=1, + help='tensor parallel size') + parser.add_argument('--pipeline_para_size', type=int, default=1, + help='pipeline parallel size') + parser.add_argument('--rougeLsum_threshold', type=float, + help='Threshold of FT rougeLsum score') + + args = parser.parse_args() + + if dist.is_mpi_available(): + try: + dist.init_process_group(backend='mpi') + rank = dist.get_rank() + except: + rank = dist.get_rank() + else: + rank = 0 + + disable_summarize = args.disable_summarize + test_hf = args.test_hf + test_ft = args.test_ft + + tensor_para_size = args.tensor_para_size + pipeline_para_size = args.pipeline_para_size + ft_model_location = args.ft_model_location + f"/{tensor_para_size}-gpu/" + hf_model_location = args.hf_model_location + + tokenizer = AutoTokenizer.from_pretrained(hf_model_location) + tokenizer.pad_token = tokenizer.eos_token + dataset_cnn = load_dataset("ccdv/cnn_dailymail", '3.0.0', cache_dir=args.cache_path) + + if rank == 0 and test_hf: + start_time = datetime.datetime.now() + if args.data_type == "fp32": + model = T5ForConditionalGeneration.from_pretrained(hf_model_location, torch_dtype=torch.float32).cuda() + elif args.data_type == "fp16": + model = T5ForConditionalGeneration.from_pretrained(hf_model_location, torch_dtype=torch.float16).cuda() + elif args.data_type == "bf16": + model = T5ForConditionalGeneration.from_pretrained(hf_model_location, torch_dtype=torch.bfloat16).cuda() + stop_time = datetime.datetime.now() + print(f"[INFO] load HF model spend {(stop_time - start_time).total_seconds()} sec") + + if test_ft: + ckpt_config = configparser.ConfigParser() + + ckpt_config_path = os.path.join(ft_model_location, 'config.ini') + if os.path.isfile(ckpt_config_path): + ckpt_config.read(ckpt_config_path) + 
else: + assert False, "[ERROR] This example only support loading model with FT format directly." + + weight_data_type = np.float32 + weight_data_type = {"fp16": np.float16, "fp32": np.float32}[ckpt_config.get("encoder", "weight_data_type")] + relative_attention_max_distance = 128 + encoder_config = T5Config(vocab_size=ckpt_config.getint("encoder", "vocab_size"), + d_model=ckpt_config.getint("encoder", "d_model"), + d_kv=ckpt_config.getint("encoder", "d_kv"), + d_ff=ckpt_config.getint("encoder", "d_ff"), + num_layers=ckpt_config.getint("encoder", "num_layers"), + num_decoder_layers=ckpt_config.getint("encoder", "num_decoder_layers"), + num_heads=ckpt_config.getint("encoder", "num_heads"), + relative_attention_num_buckets=ckpt_config.getint( + "encoder", "relative_attention_num_buckets_or_max_pos_seq_len"), + feed_forward_proj=ckpt_config.get("encoder", "feed_forward_proj"), + pad_token_id=ckpt_config.getint("encoder", "pad_token_id"), + eos_token_id=ckpt_config.getint("encoder", "eos_token_id"), + is_gated_act=ckpt_config.getboolean("encoder", "is_gated_act", fallback=0), + ) + decoder_config = T5Config(vocab_size=ckpt_config.getint("decoder", "vocab_size"), + d_model=ckpt_config.getint("decoder", "d_model"), + d_kv=ckpt_config.getint("decoder", "d_kv"), + d_ff=ckpt_config.getint("decoder", "d_ff"), + num_layers=ckpt_config.getint("decoder", "num_layers"), + num_decoder_layers=ckpt_config.getint("decoder", "num_decoder_layers"), + num_heads=ckpt_config.getint("decoder", "num_heads"), + relative_attention_num_buckets=ckpt_config.getint( + "decoder", "relative_attention_num_buckets_or_max_pos_seq_len"), + feed_forward_proj=ckpt_config.get("decoder", "feed_forward_proj"), + pad_token_id=ckpt_config.getint("decoder", "pad_token_id"), + eos_token_id=ckpt_config.getint("decoder", "eos_token_id"), + decoder_start_token_id=ckpt_config.getint("decoder", "decoder_start_token_id"), + is_gated_act=ckpt_config.getboolean("decoder", "is_gated_act", fallback=0), + ) + assert decoder_config.feed_forward_proj == encoder_config.feed_forward_proj + assert decoder_config.feed_forward_proj == encoder_config.feed_forward_proj + + t5_with_bias = ckpt_config.getboolean("structure", "t5_with_bias") + use_gated_activation = encoder_config.is_gated_act + position_embedding_type = 0 if ckpt_config.get('structure', 'position_embedding_type') == 'relative' else 1 + activation_type = encoder_config.feed_forward_proj + + # https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L1660 + # if tie_word_embeddings == True, scale the decoder output by sequence_output = sequence_output * (self.model_dim**-0.5) + tie_word_embeddings = ckpt_config.getboolean("decoder", "tie_word_embeddings") + ft_encoder_weight = FTT5EncoderWeight( + encoder_config, + tensor_para_size, + pipeline_para_size, + t5_with_bias=t5_with_bias, + use_gated_activation=use_gated_activation, + position_embedding_type=position_embedding_type, + weight_data_type=weight_data_type + ) + ft_decoding_weight = FTT5DecodingWeight( + decoder_config, + tensor_para_size, + pipeline_para_size, + t5_with_bias=t5_with_bias, + use_gated_activation=use_gated_activation, + position_embedding_type=position_embedding_type, + weight_data_type=weight_data_type, + ) + + start_time = datetime.datetime.now() + ft_encoder_weight.load_from_bin(ft_model_location) + stop_time = datetime.datetime.now() + print(f"[INFO] load FT encoder model spend {(stop_time - start_time).total_seconds()} sec") + start_time = datetime.datetime.now() + 
ft_decoding_weight.load_from_bin(ft_model_location) + stop_time = datetime.datetime.now() + print(f"[INFO] load FT decoding model spend {(stop_time - start_time).total_seconds()} sec") + if args.data_type == "fp32": + ft_encoder_weight.to_float() + ft_decoding_weight.to_float() + elif args.data_type == "fp16": + ft_encoder_weight.to_half() + ft_decoding_weight.to_half() + elif args.data_type == "bf16": + ft_encoder_weight.to_bfloat16() + ft_decoding_weight.to_bfloat16() + + ft_encoder_weight.to_cuda() + ft_decoding_weight.to_cuda() + + q_scaling = 1.0 / (math.sqrt(encoder_config.d_kv)) + remove_padding = True + ft_encoder = FTT5Encoder(ft_encoder_weight.w, args.lib_path, encoder_config.num_heads, + encoder_config.d_kv, encoder_config.d_ff, + encoder_config.d_model, remove_padding, encoder_config.num_layers, + encoder_config.relative_attention_num_buckets, + relative_attention_max_distance, False, q_scaling, tensor_para_size, + pipeline_para_size, t5_with_bias, + position_embedding_type, activation_type=activation_type) + + ft_decoding = FTT5Decoding(ft_decoding_weight.w, args.lib_path, + decoder_config.num_heads, decoder_config.d_kv, + decoder_config.d_ff, encoder_config.d_model, + decoder_config.d_model, decoder_config.num_layers, + decoder_config.decoder_start_token_id, decoder_config.eos_token_id, + decoder_config.vocab_size, q_scaling, + decoder_config.relative_attention_num_buckets, max_distance=relative_attention_max_distance, + tensor_para_size=tensor_para_size, pipeline_para_size=pipeline_para_size, + t5_with_bias=t5_with_bias, position_embedding_type=position_embedding_type, + activation_type=activation_type, tie_word_embeddings=tie_word_embeddings) + + ft_t5 = FTT5(ft_encoder, ft_decoding) + + if disable_summarize: + top_k = 1 + output_len = args.max_seq_len + else: + top_k = 1 + output_len = args.max_seq_len + + def summarize_ft(datapoint): + if not disable_summarize: + line = "summarize: " + datapoint['article'] + else: + line = datapoint['article'] + line = line.strip() + line = line.replace(" n't", "n't") + + line_tokens = tokenizer(line, return_tensors='pt') + + with torch.no_grad(): + output, ft_output_len = ft_t5(line_tokens, + None, + 1, + output_len, + top_k, + 0.0, + beam_search_diversity_rate=0.0, + is_return_output_log_probs=False, + is_return_cum_log_probs=False) + tokens = output[0][0] + + output_lines = tokenizer.decode(output[0][0][:ft_output_len[0][0]]) + output_lines = ".".join(output_lines.split('.')[:4]) + "." + return output_lines, tokens + + def summarize_hf(datapoint): + if not disable_summarize: + line = "summarize: " + datapoint['article'] + else: + line = datapoint['article'] + line = line.strip() + line = line.replace(" n't", "n't") + + line_encoded = tokenizer.encode(line, return_tensors='pt') + line_encoded = line_encoded.cuda() + + with torch.no_grad(): + output = model.generate(line_encoded, + max_length=output_len + 1, + top_k=top_k, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id) + + tokens = output[0].cpu().numpy() + output_lines = tokenizer.decode(output[0]) + output_lines = ".".join(output_lines.split('.')[:4]) + "." 
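+ # Both summarize_ft and summarize_hf keep only the first four sentences of the decoded text, so the FT and HF outputs are scored on a comparable amount of generated text.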
+ return output_lines, tokens + + if disable_summarize: + tokens = [] + else: + metric_ft = load_metric("rouge") + metric_hf = load_metric("rouge") + + if not disable_summarize: + datapoint = dataset_cnn['test'][0] + if test_ft: + summary_ft, _ = summarize_ft(datapoint) + if rank == 0: + print('---------------------------------------------------------') + print('FT Generated : ') + print(' Article : ', datapoint['article']) + print('\n Highlights : ', datapoint['highlights']) + print('\n Summary : ', summary_ft) + print('---------------------------------------------------------') + metric_ft.add_batch(predictions=[summary_ft], references=[datapoint['highlights']]) + + if test_hf and rank == 0: + summary_hf, _ = summarize_hf(datapoint) + print('---------------------------------------------------------') + print('HF Generated : ') + print(' Article : ', datapoint['article']) + print('\n Highlights : ', datapoint['highlights']) + print('\n Summary : ', summary_hf) + print('---------------------------------------------------------') + metric_hf.add_batch(predictions=[summary_hf], references=[datapoint['highlights']]) + + ft_time = 0.0 + hf_time = 0.0 + for data_point_idx in tqdm(range(1, 11490, int(11490 / args.max_ite))): + try: + datapoint = dataset_cnn['test'][data_point_idx] + + start_time = datetime.datetime.now() + if test_ft: + summary_ft, tokens_ft = summarize_ft(datapoint) + stop_time = datetime.datetime.now() + ft_time += (stop_time - start_time).total_seconds() + + if rank == 0 and ((test_hf and not disable_summarize) or disable_summarize): + start_time = datetime.datetime.now() + summary_hf, tokens_hf = summarize_hf(datapoint) + stop_time = datetime.datetime.now() + hf_time += (stop_time - start_time).total_seconds() + + if rank == 0: + if not disable_summarize: + if test_ft: + metric_ft.add_batch(predictions=[summary_ft], references=[datapoint['highlights']]) + if test_hf: + metric_hf.add_batch(predictions=[summary_hf], references=[datapoint['highlights']]) + else: + tokens.append((tokens_ft, tokens_hf)) + except: + print('Error with datapoint : ', data_point_idx) + + def compute_exact_match(tokens, n_tokens=[1, 10, 25, 50, 100, 150, 200, 250]): + em_metrics = [] + for t in n_tokens: + errors = 0 + total = 0 + for ft_tokens, hf_tokens in tokens: + if len(ft_tokens) > t and len(hf_tokens) > t: + total = total + 1 + if not np.array_equal(ft_tokens[:t], hf_tokens[:t]): + errors = errors + 1 + + if total > 0: + print(f"{t}-token exact match acc: {100*(1-errors/total):.2f}") + em_metrics.append(1 - errors / total) + else: + em_metrics.append(np.nan) + + return em_metrics + + if rank == 0: + if not disable_summarize: + if test_ft: + computed_metrics_ft = metric_ft.compute() + + if test_hf: + computed_metrics_hf = metric_hf.compute() + + print(f'Hugging Face (total latency: {hf_time} sec)') + for key in computed_metrics_hf.keys(): + print(f'{key} : {computed_metrics_hf[key].mid[2]*100}') + + if test_ft: + print(f'Faster Transformers (total latency: {ft_time} sec)') + for key in computed_metrics_ft.keys(): + print(f'{key} : {computed_metrics_ft[key].mid[2]*100}') + if args.rougeLsum_threshold != None: + assert computed_metrics_ft["rougeLsum"].mid[2] * \ + 100 >= args.rougeLsum_threshold, "[INFO] TEST FAIL !" 
+ print(f"[INFO] TEST PASS !") + else: + em_metrics = compute_exact_match(tokens) + + +if __name__ == '__main__': + main() diff --git a/examples/pytorch/t5/translate_example.py b/examples/pytorch/t5/translate_example.py index 36d347f20..edd843b1e 100644 --- a/examples/pytorch/t5/translate_example.py +++ b/examples/pytorch/t5/translate_example.py @@ -17,6 +17,7 @@ import os import sys import math +import logging from datetime import datetime import numpy as np import torch @@ -33,14 +34,18 @@ from examples.pytorch.t5.utils.ft_decoding import FTT5DecodingWeight, FTT5Decoding, FTT5 from examples.pytorch.decoding.utils.recover_bpe import recover_bpe +LOGGER = logging.getLogger(__name__) + +gemm_data_type_mapping = {"fp32":0, "fp16":1, "bf16":2} + def bleu_score(pred, ref): from sacrebleu import corpus_bleu bleu = corpus_bleu(pred, [ref], force=True) - print(" bleu score: {:6.2f}".format(bleu.score)) - print(" bleu counts: {}".format(bleu.counts)) - print(" bleu totals: {}".format(bleu.totals)) - print(" bleu precisions: {}".format(bleu.precisions)) - print(" bleu sys_len: {}; ref_len: {}".format(bleu.sys_len, bleu.ref_len)) + LOGGER.info(" bleu score: {:6.2f}".format(bleu.score)) + LOGGER.info(" bleu counts: {}".format(bleu.counts)) + LOGGER.info(" bleu totals: {}".format(bleu.totals)) + LOGGER.info(" bleu precisions: {}".format(bleu.precisions)) + LOGGER.info(" bleu sys_len: {}; ref_len: {}".format(bleu.sys_len, bleu.ref_len)) return bleu class TranslationResult(object): @@ -74,31 +79,37 @@ def translate(args_dict): max_ite = args_dict['max_iteration'] ## huggingface without bias and use relative position embedding ## relative position embedding -> 0, absolute position embedding -> 1 - t5_with_bias = 0 + t5_with_bias = False + use_gated_activation = False position_embedding_type = 0 + weight_data_type = np.float32 ## only huggingface model path supported model_path = args_dict['model_path'] if args_dict['model_path'] != None else args_dict['model'] ckpt_path = args_dict['ckpt_path'] model_type = args_dict['model_type'] ## read checkpoint config if exists ckpt_config = configparser.ConfigParser() + activation_type = "relu" if (model_type == "Megatron"): ckpt_config_path = os.path.join(ckpt_path, 'config.ini') if os.path.isfile(ckpt_config_path): ckpt_config.read(ckpt_config_path) ## update structure config - t5_with_bias = ckpt_config.getint('structure', 't5_with_bias') - position_embedding_type = ckpt_config.getint('structure', 'position_embedding_type') + t5_with_bias = ckpt_config.getboolean('structure', 't5_with_bias') + position_embedding_type = 0 if ckpt_config.get('structure', 'position_embedding_type') == 'relative' else 1 + use_gated_activation = ckpt_config.getboolean('structure', 'use_gated_activation') + weight_data_type = {"fp16": np.float16, "fp32": np.float32}[ckpt_config.get("encoder", "weight_data_type")] + activation_type = "gated-gelu" if use_gated_activation else "gelu" # change to gelu, which is default setting of Megatron T5 else: raise Exception("config file does exist with the ckpt !") if model_type == "Megatron" and args_dict['ckpt_path'] == None: raise Exception("Megatron T5 model needs to specify checkpoint path !") - print("\n=============== Argument ===============") + LOGGER.info("\n=============== Argument ===============") for key in args_dict: - print("{}: {}".format(key, args_dict[key])) - print("========================================") + LOGGER.info("{}: {}".format(key, args_dict[key])) + LOGGER.info("========================================") lib_path = 
args_dict['lib_path'] @@ -117,13 +128,24 @@ def translate(args_dict): t5_model = t5_model.to(rank) if args_dict['data_type'] == 'fp16': t5_model = t5_model.half() + elif args_dict['data_type'] == 'bf16': + t5_model = t5_model ## bfloat inference not supported yet ## TODO: modidy Megatron T5 Converter ## TODO: add megatron t5 tokenizer tokenizer = T5Tokenizer.from_pretrained(model_path) - fast_tokenizer = PreTrainedTokenizerFast.from_pretrained(model_path) + try: + fast_tokenizer = PreTrainedTokenizerFast.from_pretrained(model_path) + except: + fast_tokenizer = T5Tokenizer.from_pretrained(model_path) encoder_config = t5_model.encoder.config decoder_config = t5_model.decoder.config + if model_type != "Megatron": + activation_type = encoder_config.feed_forward_proj + # https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L1660 + # if tie_word_embeddings == True, scale the decoder output by sequence_output = sequence_output * (self.model_dim**-0.5) + tie_word_embeddings = decoder_config.tie_word_embeddings + q_scaling = 1.0 / (math.sqrt(encoder_config.d_kv)) if (model_type == "Megatron"): ## update configs when using Megatron model structure @@ -149,12 +171,12 @@ def translate(args_dict): decoder_config.decoder_start_token_id = 30522 # Only for megatron t5 model decoder_config.eos_token_id = 30523 # Only for megatron t5 model - print(f"{model_type} encoder_config: {encoder_config}") - print(f"{model_type} decoder_config: {decoder_config}") + LOGGER.debug(f"{model_type} encoder_config: {encoder_config}") + LOGGER.debug(f"{model_type} decoder_config: {decoder_config}") if os.path.isfile("gemm_config.in") and rank == 0: cmd = f"rm gemm_config.in" - print(f"Run {cmd}") + LOGGER.info(f"Run {cmd}") os.system(cmd) translation_result_list = [] if time_args.find("0") != -1: @@ -164,13 +186,14 @@ def translate(args_dict): translation_result_list.append(TranslationResult("ft-beamsearch-warmup", "FT")) translation_result_list.append(TranslationResult("ft-beamsearch", "FT")) if rank == 0: - is_fp16 = 1 if args_dict['data_type'] == 'fp16' else 0 - cmd = f"./bin/t5_gemm {batch_size // pipeline_para_size} {beam_size} {128} " \ + data_type = gemm_data_type_mapping[args_dict['data_type']] + cmd = f"./bin/t5_gemm {math.ceil(batch_size / pipeline_para_size)} {beam_size} {128} " \ f"{encoder_config.d_model} {encoder_config.num_heads} {encoder_config.d_kv} {encoder_config.d_ff} " \ f"{decoder_config.d_model} {decoder_config.num_heads} {decoder_config.d_kv} {decoder_config.d_ff} " \ - f"{decoder_config.vocab_size} {is_fp16} {tensor_para_size} 1 > .tmp_gemm.log" - print(f"Run gemm test: {cmd}") + f"{decoder_config.vocab_size} {data_type} {tensor_para_size} 0 0 > .tmp_gemm.log" + LOGGER.info(f"Run gemm test: {cmd}") os.system(cmd) + if time_args.find("2") != -1: translation_result_list.append(TranslationResult("hf-sampling-warmup", "HF")) translation_result_list.append(TranslationResult("hf-sampling", "HF")) @@ -178,17 +201,33 @@ def translate(args_dict): translation_result_list.append(TranslationResult("ft-sampling-warmup", "FT")) translation_result_list.append(TranslationResult("ft-sampling", "FT")) if rank == 0: - is_fp16 = 1 if args_dict['data_type'] == 'fp16' else 0 - cmd = f"./bin/t5_gemm {batch_size // pipeline_para_size} {1} {128} " \ + data_type = gemm_data_type_mapping[args_dict['data_type']] + cmd = f"./bin/t5_gemm {math.ceil(batch_size / pipeline_para_size)} {1} {128} " \ f"{encoder_config.d_model} {encoder_config.num_heads} {encoder_config.d_kv} 
{encoder_config.d_ff} " \ f"{decoder_config.d_model} {decoder_config.num_heads} {decoder_config.d_kv} {decoder_config.d_ff} " \ - f"{decoder_config.vocab_size} {is_fp16} {tensor_para_size} 1 1 > .tmp_gemm.log" - print(f"Run gemm test: {cmd}") + f"{decoder_config.vocab_size} {data_type} {tensor_para_size} 0 1 > .tmp_gemm.log" + LOGGER.info(f"Run gemm test: {cmd}") os.system(cmd) if time_args.find("1") != -1 or time_args.find("3") != -1: - ft_encoder_weight = FTT5EncoderWeight(encoder_config, tensor_para_size, pipeline_para_size, t5_with_bias, position_embedding_type) - ft_decoding_weight = FTT5DecodingWeight(decoder_config, tensor_para_size, pipeline_para_size, t5_with_bias, position_embedding_type) + ft_encoder_weight = FTT5EncoderWeight( + encoder_config, + tensor_para_size, + pipeline_para_size, + t5_with_bias=t5_with_bias, + use_gated_activation=use_gated_activation, + position_embedding_type=position_embedding_type, + weight_data_type=weight_data_type, + ) + ft_decoding_weight = FTT5DecodingWeight( + decoder_config, + tensor_para_size, + pipeline_para_size, + t5_with_bias=t5_with_bias, + use_gated_activation=use_gated_activation, + position_embedding_type=position_embedding_type, + weight_data_type=weight_data_type, + ) if args_dict["ckpt_path"] is not None: ft_encoder_weight.load_from_bin(args_dict["ckpt_path"]) @@ -201,13 +240,19 @@ def translate(args_dict): t5_model = t5_model.half() ft_encoder_weight.to_half() ft_decoding_weight.to_half() + elif args_dict['data_type'] == 'bf16': + t5_model = t5_model ## bfloat inference not supported yet + ft_encoder_weight.to_bfloat16() + ft_decoding_weight.to_bfloat16() remove_padding = True if batch_size > 32 else False ft_encoder = FTT5Encoder(ft_encoder_weight.w, lib_path, encoder_config.num_heads, encoder_config.d_kv, encoder_config.d_ff, encoder_config.d_model, remove_padding, encoder_config.num_layers, encoder_config.relative_attention_num_buckets, - 128, False, q_scaling, tensor_para_size, pipeline_para_size, t5_with_bias, position_embedding_type) + 128, False, q_scaling, tensor_para_size, pipeline_para_size, t5_with_bias, + position_embedding_type, + activation_type=activation_type,) ft_decoding = FTT5Decoding(ft_decoding_weight.w, lib_path, decoder_config.num_heads, decoder_config.d_kv, decoder_config.d_ff, encoder_config.d_model, @@ -217,7 +262,9 @@ def translate(args_dict): q_scaling, decoder_config.relative_attention_num_buckets, max_distance=128, tensor_para_size=tensor_para_size, pipeline_para_size=pipeline_para_size, - t5_with_bias=t5_with_bias, position_embedding_type = position_embedding_type) + t5_with_bias=t5_with_bias, + position_embedding_type=position_embedding_type, + activation_type=activation_type, tie_word_embeddings=tie_word_embeddings,) ft_t5 = FTT5(ft_encoder, ft_decoding) @@ -257,6 +304,7 @@ def translate(args_dict): if translation_result_list[i].name.find("sampling") != -1: tmp_beam_size = 1 ft_decoding_outputs, ft_decoding_seq_lens = ft_t5(input_token, + None, tmp_beam_size, max_seq_len, topk, @@ -299,9 +347,17 @@ def translate(args_dict): for t in translation_result_list: if t.name.find("warmup") != -1: continue - print(f"[INFO] {t.name} translates {t.batch_num} batches taking {t.execution_time:.2f} sec to translate " + LOGGER.info(f"{t.name} translates {t.batch_num} batches taking {t.execution_time:.2f} sec to translate " f"{t.token_num} tokens, BLEU score: {t.bleu_score.score:.2f}, {(t.token_num / t.execution_time):.0f} tokens/sec." 
f" ({t.bleu_score.sys_len} words, {(t.bleu_score.sys_len / t.execution_time):.0f} words/sec)") + + if t.name == "ft-beamsearch" and args_dict["ft_beamsearch_BLEU_threshold"] != None: + assert t.bleu_score.score >= args_dict["ft_beamsearch_BLEU_threshold"], f"[ERROR] {t.name} test fail !" + LOGGER.info(f"{t.name} PASS !") + + if t.name == "ft-sampling" and args_dict["ft_sampling_BLEU_threshold"] != None: + assert t.bleu_score.score >= args_dict["ft_sampling_BLEU_threshold"], f"[ERROR] {t.name} test fail !" + LOGGER.info(f"{t.name} PASS !") if __name__ == "__main__": parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) @@ -332,13 +388,13 @@ def translate(args_dict): parser.add_argument('-topp', '--sampling_topp', type=float, default=0.0, metavar='NUMBER', help='Probability (p) value of top p sampling in decoding. Default is 0.0. ') parser.add_argument('-d', '--data_type', type=str, default="fp32", metavar='STRING', - help='data type (default: fp32)', choices=['fp32', 'fp16']) + help='data type (default: fp32)', choices=['fp32', 'fp16', 'bf16']) parser.add_argument('-lib_path', '--lib_path', type=str, default="lib/libth_t5.so", metavar='STRING', help='the path of FasterTransformer pytorch t5 op library.') parser.add_argument('-model_path', '--model_path', type=str, default=None, metavar='STRING', help='T5 model path.') parser.add_argument('-model', '--model', type=str, default="t5-small", metavar='STRING', - help='T5 model size. Only used when --model_path=None', choices=["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"]) + help='T5 model size. Only used when --model_path=None') parser.add_argument('-tensor_para_size', '--tensor_para_size', type=int, default=1, metavar='NUMBER', help='size of tensor parallelism (default: 1)') parser.add_argument('-pipeline_para_size', '--pipeline_para_size', type=int, default=1, metavar='NUMBER', @@ -354,6 +410,12 @@ def translate(args_dict): help='Return the log probability of generated tokens.') parser.add_argument('--return_cum_log_probs', action='store_true', help='Return the cumulative log probability of generated tokens.') + parser.add_argument('--ft_beamsearch_BLEU_threshold', type=float, + help='Threshold of FT beam search BLEU score') + parser.add_argument('--ft_sampling_BLEU_threshold', type=float, + help='Threshold of FT beam search BLEU score') + parser.add_argument("--verbose", action="store_true", help="Provide verbose messages") args = parser.parse_args() - + log_format = "%(asctime)s %(name)s [%(levelname)s] %(message)s" + logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO, format=log_format) translate(vars(args)) \ No newline at end of file diff --git a/examples/pytorch/t5/utils/ft_decoding.py b/examples/pytorch/t5/utils/ft_decoding.py index 5cd266e4f..a2cd643b3 100644 --- a/examples/pytorch/t5/utils/ft_decoding.py +++ b/examples/pytorch/t5/utils/ft_decoding.py @@ -17,15 +17,28 @@ import torch.distributed as dist import numpy as np + class FTT5DecodingWeight(object): - def __init__(self, config, tensor_para_size, pipeline_para_size, t5_with_bias = False, position_embedding_type = 0): + def __init__( + self, + config, + tensor_para_size, + pipeline_para_size, + *, + t5_with_bias=False, + use_gated_activation=False, + position_embedding_type=0, + weight_data_type + ): self.config = config self.num_layer = config.num_layers self.tensor_para_size = tensor_para_size self.pipeline_para_size = pipeline_para_size self.t5_with_bias = t5_with_bias + self.use_gated_activation = 
use_gated_activation self.position_embedding_type = position_embedding_type - self.real_weights_num = 27 if t5_with_bias else 14 + self.real_weights_num = 30 # assume all weights are allocated and converted to specific data type + self.weight_data_type = weight_data_type self.w = [] self.use_mpi = dist.is_mpi_available() @@ -33,7 +46,7 @@ def __init__(self, config, tensor_para_size, pipeline_para_size, t5_with_bias = try: dist.init_process_group(backend='mpi') except: - print("[INFO] WARNING: Exception occurred in dist.init_process_group(backend='mpi'). Maybe the process group has been initialized somewhere else.") + print("[INFO] WARNING: Exception occurred in dist.init_process_group(backend = 'mpi'). Maybe the process group has been initialized somewhere else.") else: print("[INFO] MPI is not available in this PyTorch build.") assert tensor_para_size == 1, "[FATAL] MPI is required for tensor_para_size > 1." @@ -45,17 +58,21 @@ def __init__(self, config, tensor_para_size, pipeline_para_size, t5_with_bias = torch.cuda.set_device(self.device) world_size = dist.get_world_size() if self.use_mpi else 1 - assert world_size == tensor_para_size * pipeline_para_size, "[ERROR] world_size != tensor_para_size * pipeline_para_size" + assert world_size == tensor_para_size * \ + pipeline_para_size, "[ERROR] world_size != tensor_para_size * pipeline_para_size" self.tensor_para_rank = self.rank % self.tensor_para_size self.pipeline_para_rank = self.rank // self.tensor_para_size def load_from_model(self, model): - start_layer = self.pipeline_para_rank * self.num_layer // self.pipeline_para_size - end_layer = (self.pipeline_para_rank + 1) * self.num_layer // self.pipeline_para_size - + start_layer = self.pipeline_para_rank * self.num_layer // self.pipeline_para_size + end_layer = (self.pipeline_para_rank + 1) * self.num_layer // self.pipeline_para_size + + np_weight_dtype = self.weight_data_type + torch_weight_dtype = {np.float32: torch.float32, np.float16: torch.float16}[np_weight_dtype] + weight_dict = {} qkv_tmp = [] - for name, param in model.named_parameters(): + for name, param in model.state_dict().items(): if param.dim() == 2: param_t = param.transpose(1, 0) elif param.dim() == 1: @@ -73,134 +90,209 @@ def load_from_model(self, model): weight_dict[name] = param_t elif name.find("decoder") != -1: weight_dict[name] = param_t + else: + weight_dict[name] = param_t # load by torch model directly - t = torch.stack([weight_dict["decoder.block.{}.layer.0.layer_norm.weight".format(i)] for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([weight_dict["decoder.block.{}.layer.0.layer_norm.weight".format(i)] + for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t) - t = torch.stack([weight_dict["decoder.block.{}.layer.0.SelfAttention.qkv.weight".format(i)] for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([weight_dict["decoder.block.{}.layer.0.SelfAttention.qkv.weight".format(i)] + for i in range(start_layer, end_layer)], 0).contiguous().cuda() t = t.reshape([t.shape[0], t.shape[1], 3, t.shape[2] // 3]) t = t.split(t.shape[-1] // self.tensor_para_size, dim=-1)[self.tensor_para_rank].contiguous() self.w.append(t) - t = torch.stack([weight_dict["decoder.block.{}.layer.0.SelfAttention.o.weight".format(i)] for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([weight_dict["decoder.block.{}.layer.0.SelfAttention.o.weight".format(i)] + for i in range(start_layer, end_layer)], 0).contiguous().cuda() t = 
t.split(t.shape[1] // self.tensor_para_size, dim=1)[self.tensor_para_rank].contiguous() self.w.append(t) - t = torch.stack([weight_dict["decoder.block.{}.layer.1.layer_norm.weight".format(i)] for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([weight_dict["decoder.block.{}.layer.1.layer_norm.weight".format(i)] + for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t) - t = torch.stack([weight_dict["decoder.block.{}.layer.1.EncDecAttention.q.weight".format(i)] for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([weight_dict["decoder.block.{}.layer.1.EncDecAttention.q.weight".format(i)] + for i in range(start_layer, end_layer)], 0).contiguous().cuda() t = t.split(t.shape[-1] // self.tensor_para_size, dim=-1)[self.tensor_para_rank].contiguous() self.w.append(t) - t = torch.stack([weight_dict["decoder.block.{}.layer.1.EncDecAttention.k.weight".format(i)] for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([weight_dict["decoder.block.{}.layer.1.EncDecAttention.k.weight".format(i)] + for i in range(start_layer, end_layer)], 0).contiguous().cuda() t = t.split(t.shape[-1] // self.tensor_para_size, dim=-1)[self.tensor_para_rank].contiguous() self.w.append(t) - t = torch.stack([weight_dict["decoder.block.{}.layer.1.EncDecAttention.v.weight".format(i)] for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([weight_dict["decoder.block.{}.layer.1.EncDecAttention.v.weight".format(i)] + for i in range(start_layer, end_layer)], 0).contiguous().cuda() t = t.split(t.shape[-1] // self.tensor_para_size, dim=-1)[self.tensor_para_rank].contiguous() self.w.append(t) - t = torch.stack([weight_dict["decoder.block.{}.layer.1.EncDecAttention.o.weight".format(i)] for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([weight_dict["decoder.block.{}.layer.1.EncDecAttention.o.weight".format(i)] + for i in range(start_layer, end_layer)], 0).contiguous().cuda() t = t.split(t.shape[1] // self.tensor_para_size, dim=1)[self.tensor_para_rank].contiguous() self.w.append(t) - t = torch.stack([weight_dict["decoder.block.{}.layer.2.layer_norm.weight".format(i)] for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([weight_dict["decoder.block.{}.layer.2.layer_norm.weight".format(i)] + for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t) - t = torch.stack([weight_dict["decoder.block.{}.layer.2.DenseReluDense.wi.weight".format(i)] for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([weight_dict["decoder.block.{}.layer.2.DenseReluDense.wi.weight".format(i)] + for i in range(start_layer, end_layer)], 0).contiguous().cuda() t = t.split(t.shape[-1] // self.tensor_para_size, dim=-1)[self.tensor_para_rank].contiguous() self.w.append(t) - t = torch.stack([weight_dict["decoder.block.{}.layer.2.DenseReluDense.wo.weight".format(i)] for i in range(start_layer, end_layer)], 0).contiguous().cuda() + # empty wi2 weight + self.w.append(torch.empty((1, 1), dtype=torch_weight_dtype).contiguous().cuda()) + t = torch.stack([weight_dict["decoder.block.{}.layer.2.DenseReluDense.wo.weight".format(i)] + for i in range(start_layer, end_layer)], 0).contiguous().cuda() t = t.split(t.shape[1] // self.tensor_para_size, dim=1)[self.tensor_para_rank].contiguous() self.w.append(t) t = weight_dict["decoder.final_layer_norm.weight"].contiguous().cuda() self.w.append(t) t = 
model.get_output_embeddings().weight.contiguous().cuda() self.w.append(t) + t = weight_dict["lm_head.weight"].transpose(1, 0).contiguous().cuda() # Transpose back to [vocab, hidden] + self.w.append(t) t = weight_dict["decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"].contiguous().cuda() t = t.split(t.shape[0] // self.tensor_para_size, dim=0)[self.tensor_para_rank].contiguous() self.w.append(t) - #TODO: pass None Type to Torch Op - for i in range(13): - self.w.append(torch.empty((1,1), dtype=torch.float32).contiguous().cuda()) - + # TODO: pass None Type to Torch Op + for i in range(14): + self.w.append(torch.empty((1, 1), dtype=torch_weight_dtype).contiguous().cuda()) + def load_from_bin(self, ckpt_path): - start_layer = self.pipeline_para_rank * self.num_layer // self.pipeline_para_size - end_layer = (self.pipeline_para_rank + 1) * self.num_layer // self.pipeline_para_size - + start_layer = self.pipeline_para_rank * self.num_layer // self.pipeline_para_size + end_layer = (self.pipeline_para_rank + 1) * self.num_layer // self.pipeline_para_size + + np_weight_dtype = self.weight_data_type + torch_weight_dtype = {np.float32: torch.float32, np.float16: torch.float16}[np_weight_dtype] + # load by binary files - t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.0.layer_norm.weight.bin", dtype=np.single)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.0.layer_norm.weight.bin", + dtype=np_weight_dtype)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + self.w.append(t) + t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.0.SelfAttention.qkv.weight.{self.tensor_para_rank}.bin", + dtype=np_weight_dtype)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t) - t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.0.SelfAttention.qkv.weight.{self.tensor_para_rank}.bin", dtype=np.single)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.0.SelfAttention.o.weight.{self.tensor_para_rank}.bin", + dtype=np_weight_dtype)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t) - t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.0.SelfAttention.o.weight.{self.tensor_para_rank}.bin", dtype=np.single)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.1.layer_norm.weight.bin", + dtype=np_weight_dtype)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t) - t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.1.layer_norm.weight.bin", dtype=np.single)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.1.EncDecAttention.q.weight.{self.tensor_para_rank}.bin", + dtype=np_weight_dtype)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t) - t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.1.EncDecAttention.q.weight.{self.tensor_para_rank}.bin", dtype=np.single)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = 
torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.1.EncDecAttention.k.weight.{self.tensor_para_rank}.bin", + dtype=np_weight_dtype)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t) - t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.1.EncDecAttention.k.weight.{self.tensor_para_rank}.bin", dtype=np.single)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.1.EncDecAttention.v.weight.{self.tensor_para_rank}.bin", + dtype=np_weight_dtype)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t) - t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.1.EncDecAttention.v.weight.{self.tensor_para_rank}.bin", dtype=np.single)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.1.EncDecAttention.o.weight.{self.tensor_para_rank}.bin", + dtype=np_weight_dtype)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t) - t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.1.EncDecAttention.o.weight.{self.tensor_para_rank}.bin", dtype=np.single)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.2.layer_norm.weight.bin", + dtype=np_weight_dtype)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t) - t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.2.layer_norm.weight.bin", dtype=np.single)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.2.DenseReluDense.wi.weight.{self.tensor_para_rank}.bin", + dtype=np_weight_dtype)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t) - t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.2.DenseReluDense.wi.weight.{self.tensor_para_rank}.bin", dtype=np.single)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + if self.use_gated_activation: + t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.2.DenseReluDense.wi2.weight.{self.tensor_para_rank}.bin", + dtype=np_weight_dtype)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + self.w.append(t) + else: + self.w.append(torch.empty((1, 1), dtype=torch_weight_dtype).contiguous().cuda()) + t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.2.DenseReluDense.wo.weight.{self.tensor_para_rank}.bin", + dtype=np_weight_dtype)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t) - t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.2.DenseReluDense.wo.weight.{self.tensor_para_rank}.bin", dtype=np.single)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.from_numpy(np.fromfile( + f"{ckpt_path}/decoder.final_layer_norm.weight.bin", dtype=np_weight_dtype)).contiguous().cuda() self.w.append(t) - t = torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.final_layer_norm.weight.bin", dtype=np.single)).contiguous().cuda() + t = torch.from_numpy(np.fromfile(f"{ckpt_path}/shared.weight_T.bin", dtype=np_weight_dtype).reshape( + 
[self.config.d_model, self.config.vocab_size])).contiguous().cuda() self.w.append(t) - t = torch.from_numpy(np.fromfile(f"{ckpt_path}/shared.weight_T.bin", dtype=np.single).reshape([self.config.d_model, self.config.vocab_size])).contiguous().cuda() + t = torch.from_numpy(np.fromfile(f"{ckpt_path}/lm_head.weight.bin", dtype=np_weight_dtype).reshape( + [self.config.d_model, self.config.vocab_size])).contiguous().cuda() self.w.append(t) t = None if (self.position_embedding_type == 0): - t = torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight.{self.tensor_para_rank}.bin", dtype=np.single)).contiguous().cuda() + t = torch.from_numpy(np.fromfile( + f"{ckpt_path}/decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight.{self.tensor_para_rank}.bin", dtype=np_weight_dtype)).contiguous().cuda() else: - t = torch.from_numpy(np.fromfile(f"{ckpt_path}/shared.ape.bin", dtype=np.single)).contiguous().cuda() + t = torch.from_numpy(np.fromfile(f"{ckpt_path}/shared.ape.bin", dtype=np_weight_dtype)).contiguous().cuda() self.w.append(t) - - # add 13 additional bias if it is t5 megatron structure + + # add 14 additional bias if it is t5 megatron structure if self.t5_with_bias: - t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.0.layer_norm.bias.bin", dtype=np.single)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.0.layer_norm.bias.bin", + dtype=np_weight_dtype)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t) - t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.0.SelfAttention.qkv.bias.{self.tensor_para_rank}.bin", dtype=np.single)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.0.SelfAttention.qkv.bias.{self.tensor_para_rank}.bin", + dtype=np_weight_dtype)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t) - t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.0.SelfAttention.o.bias.bin", dtype=np.single)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.0.SelfAttention.o.bias.bin", + dtype=np_weight_dtype)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t) - t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.1.layer_norm.bias.bin", dtype=np.single)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.1.layer_norm.bias.bin", + dtype=np_weight_dtype)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t) - t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.1.EncDecAttention.q.bias.{self.tensor_para_rank}.bin", dtype=np.single)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.1.EncDecAttention.q.bias.{self.tensor_para_rank}.bin", + dtype=np_weight_dtype)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t) - t = 
torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.1.EncDecAttention.k.bias.{self.tensor_para_rank}.bin", dtype=np.single)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.1.EncDecAttention.k.bias.{self.tensor_para_rank}.bin", + dtype=np_weight_dtype)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t) - t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.1.EncDecAttention.v.bias.{self.tensor_para_rank}.bin", dtype=np.single)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.1.EncDecAttention.v.bias.{self.tensor_para_rank}.bin", + dtype=np_weight_dtype)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t) - t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.1.EncDecAttention.o.bias.bin", dtype=np.single)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.1.EncDecAttention.o.bias.bin", + dtype=np_weight_dtype)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t) - t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.2.layer_norm.bias.bin", dtype=np.single)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.2.layer_norm.bias.bin", + dtype=np_weight_dtype)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t) - t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.2.DenseReluDense.wi.bias.{self.tensor_para_rank}.bin", dtype=np.single)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.2.DenseReluDense.wi.bias.{self.tensor_para_rank}.bin", + dtype=np_weight_dtype)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t) - t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.2.DenseReluDense.wo.bias.bin", dtype=np.single)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + if self.use_gated_activation: + t = torch.stack([torch.from_numpy(np.fromfile( + f"{ckpt_path}/decoder.block.{i}.layer.2.DenseReluDense.wi2.bias.{self.tensor_para_rank}.bin", dtype=np_weight_dtype)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + self.w.append(t) + else: + self.w.append(torch.empty((1, 1), dtype=torch_weight_dtype).contiguous().cuda()) + t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.block.{i}.layer.2.DenseReluDense.wo.bias.bin", + dtype=np_weight_dtype)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t) - t = torch.from_numpy(np.fromfile(f"{ckpt_path}/decoder.final_layer_norm.bias.bin", dtype=np.single)).contiguous().cuda() + t = torch.from_numpy(np.fromfile( + f"{ckpt_path}/decoder.final_layer_norm.bias.bin", dtype=np_weight_dtype)).contiguous().cuda() self.w.append(t) - t = torch.from_numpy(np.fromfile(f"{ckpt_path}/shared.bias.bin", dtype=np.single)).contiguous().cuda() + t = torch.from_numpy(np.fromfile(f"{ckpt_path}/shared.bias.bin", dtype=np_weight_dtype)).contiguous().cuda() self.w.append(t) else: - 
#TODO: pass None Type to Torch Op - for i in range(13): - self.w.append(torch.empty((1,1), dtype=torch.float32).contiguous().cuda()) - + # TODO: pass None Type to Torch Op + for i in range(14): + self.w.append(torch.empty((1, 1), dtype=torch_weight_dtype).contiguous().cuda()) + def to_cuda(self): for i in range(self.real_weights_num): self.w[i] = self.w[i].cuda() + def to_float(self): + for i in range(self.real_weights_num): + self.w[i] = self.w[i].float() + def to_half(self): for i in range(self.real_weights_num): self.w[i] = self.w[i].half() + def to_single(self): + for i in range(self.real_weights_num): + self.w[i] = self.w[i].float() + + def to_bfloat16(self): + for i in range(self.real_weights_num): + self.w[i] = self.w[i].bfloat16() + + class FTT5Decoding(nn.Module): def __init__(self, decoding_weight_list, lib_path, head_num, head_size, inter_size, - mem_d_model, d_model, num_layer, start_id, end_id, vocab_size, q_scaling = 1.0, num_bucket=32, - max_distance=128, tensor_para_size=1, pipeline_para_size=1, t5_with_bias=False, position_embedding_type=0): + mem_d_model, d_model, num_layer, start_id, end_id, vocab_size, q_scaling=1.0, num_bucket=32, + max_distance=128, tensor_para_size=1, pipeline_para_size=1, t5_with_bias=False, position_embedding_type=0, + activation_type="relu", tie_word_embeddings=True): super().__init__() self.use_mpi = dist.is_mpi_available() @@ -209,7 +301,7 @@ def __init__(self, decoding_weight_list, lib_path, head_num, head_size, inter_si try: dist.init_process_group(backend='mpi') except: - print("[INFO] WARNING: Exception occurred in dist.init_process_group(backend='mpi'). Maybe the process group has been initialized somewhere else.") + print("[INFO] WARNING: Exception occurred in dist.init_process_group(backend = 'mpi'). Maybe the process group has been initialized somewhere else.") else: print("[INFO] MPI is not available in this PyTorch build.") assert tensor_para_size == 1, "[FATAL] MPI is required for tensor_para_size > 1." @@ -219,58 +311,59 @@ def __init__(self, decoding_weight_list, lib_path, head_num, head_size, inter_si try: self.decoding = torch.classes.FasterTransformer.T5Decoding(head_num, head_size, inter_size, mem_d_model, d_model, num_layer, vocab_size, num_bucket, max_distance, q_scaling, start_id, end_id, - tensor_para_size, pipeline_para_size, t5_with_bias, position_embedding_type, - *decoding_weight_list) + tensor_para_size, pipeline_para_size, t5_with_bias, + position_embedding_type, activation_type, tie_word_embeddings, *decoding_weight_list) except: self.decoding = torch.classes.FasterTransformerT5Decoding(head_num, head_size, inter_size, mem_d_model, d_model, num_layer, vocab_size, num_bucket, max_distance, q_scaling, start_id, end_id, - tensor_para_size, pipeline_para_size, t5_with_bias, position_embedding_type, - *decoding_weight_list) + tensor_para_size, pipeline_para_size, t5_with_bias, + position_embedding_type, activation_type, tie_word_embeddings, *decoding_weight_list) def forward(self, beam_width, max_seq_len, top_k, top_p, beam_search_diversity_rate, temperature, len_penalty, repetition_penalty, random_seed, mem_hidden_states, mem_seq_len, - is_return_output_log_probs, is_return_cum_log_probs): + is_return_output_log_probs, is_return_cum_log_probs, is_return_cross_attentions=False): # TODO (bhsueh) Not found an method to put a None Type into op forward function # So, the top_k and top_p must be some values now. 
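# Descriptive note (illustrative values only, assumed rather than mandated by this file): because the
# underlying Torch op cannot accept NoneType, callers pass concrete sampling settings here, e.g.
# top_k=1 with top_p=0.0 for greedy-style decoding, or top_k=0 with top_p=0.9 for nucleus sampling;
# whichever pair is chosen, neither argument may be None.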
results = self.decoding.forward(beam_width, max_seq_len, top_k, top_p, beam_search_diversity_rate, temperature, len_penalty, repetition_penalty, random_seed, mem_hidden_states, mem_seq_len, - is_return_output_log_probs, is_return_cum_log_probs) - results + is_return_output_log_probs, is_return_cum_log_probs, is_return_cross_attentions) return results - + + class FTT5(nn.Module): def __init__(self, encoder, decoding): super().__init__() self.encoder = encoder self.decoding = decoding - - def forward(self, input_token, beam_size, max_seq_len, + + def forward(self, input_token, inputs_embeds, beam_size, max_seq_len, top_k, top_p, beam_search_diversity_rate, - temperature=1.0, len_penalty=1.0, repetition_penalty=1.0, random_seed=0, - is_return_output_log_probs=False, is_return_cum_log_probs=False): + temperature=1.0, len_penalty=0.0, repetition_penalty=1.0, random_seed=0, + is_return_output_log_probs=False, is_return_cum_log_probs=False, is_return_cross_attentions=False): input_ids = input_token.input_ids.to("cuda").type(torch.int32) mem_seq_len = 0 - if hasattr(input_token, "attention_mask") : + if hasattr(input_token, "attention_mask"): mem_seq_len = torch.sum(input_token.attention_mask, dim=1).type(torch.int32).to("cuda") - else : + else: mem_seq_len = input_token.seq_len.type(torch.int32).to("cuda") - ft_encoder_outputs = self.encoder.forward(input_ids, mem_seq_len) - results = self.decoding.forward(beam_size, + ft_encoder_outputs = self.encoder.forward(input_ids, mem_seq_len, inputs_embeds) + results = self.decoding.forward(beam_size, # optional, can be None max_seq_len, - top_k, - top_p, - beam_search_diversity_rate, - temperature, - len_penalty, - repetition_penalty, - random_seed, - is_return_output_log_probs, - is_return_cum_log_probs, + top_k, # optional, can be None + top_p, # optional, can be None + beam_search_diversity_rate, # optional, can be None + temperature, # optional, can be None + len_penalty, # optional, can be None + repetition_penalty, # optional, can be None + random_seed, # optional, can be None + is_return_output_log_probs, # optional, can be None + is_return_cum_log_probs, # optional, can be None + is_return_cross_attentions, # optional, can be None ft_encoder_outputs, mem_seq_len) ft_decoding_outputs = results.pop(0).reshape([-1, beam_size, max_seq_len]) @@ -279,5 +372,8 @@ def forward(self, input_token, beam_size, max_seq_len, ft_output_log_probs = results.pop(0) if is_return_cum_log_probs: ft_cum_log_probs = results.pop(0) + if is_return_cross_attentions: + ft_cross_attentions = results.pop(0) + return ft_decoding_outputs.cpu().numpy(), ft_decoding_seq_lens.cpu().numpy(), ft_cross_attentions.cpu().numpy() - return ft_decoding_outputs.cpu().numpy(), ft_decoding_seq_lens.cpu().numpy() \ No newline at end of file + return ft_decoding_outputs.cpu().numpy(), ft_decoding_seq_lens.cpu().numpy() diff --git a/examples/pytorch/t5/utils/ft_encoder.py b/examples/pytorch/t5/utils/ft_encoder.py index 6f9f09e7f..bab00bd7a 100644 --- a/examples/pytorch/t5/utils/ft_encoder.py +++ b/examples/pytorch/t5/utils/ft_encoder.py @@ -17,15 +17,28 @@ import torch.distributed as dist import numpy as np + class FTT5EncoderWeight(object): - def __init__(self, config, tensor_para_size, pipeline_para_size, t5_with_bias = False, position_embedding_type = 0): + def __init__( + self, + config, + tensor_para_size, + pipeline_para_size, + *, + t5_with_bias=False, + use_gated_activation=False, + position_embedding_type=0, + weight_data_type + ): self.num_layer = config.num_layers self.config = 
config self.tensor_para_size = tensor_para_size self.pipeline_para_size = pipeline_para_size self.t5_with_bias = t5_with_bias - self.real_weights_num = 20 if t5_with_bias else 11 + self.use_gated_activation = use_gated_activation + self.real_weights_num = 22 # assume all weights are allocated self.position_embedding_type = position_embedding_type + self.weight_data_type = weight_data_type self.w = [] self.use_mpi = dist.is_mpi_available() @@ -33,7 +46,7 @@ def __init__(self, config, tensor_para_size, pipeline_para_size, t5_with_bias = try: dist.init_process_group(backend='mpi') except: - print("[INFO] WARNING: Exception occurred in dist.init_process_group(backend='mpi'). Maybe the process group has been initialized somewhere else.") + print("[INFO] WARNING: Exception occurred in dist.init_process_group(backend = 'mpi'). Maybe the process group has been initialized somewhere else.") else: print("[INFO] MPI is not available in this PyTorch build.") assert tensor_para_size == 1, "[FATAL] MPI is required for tensor_para_size > 1." @@ -45,14 +58,18 @@ def __init__(self, config, tensor_para_size, pipeline_para_size, t5_with_bias = torch.cuda.set_device(self.device) world_size = dist.get_world_size() if self.use_mpi else 1 - assert world_size == tensor_para_size * pipeline_para_size, "[ERROR] world_size != tensor_para_size * pipeline_para_size" + assert world_size == tensor_para_size * \ + pipeline_para_size, "[ERROR] world_size != tensor_para_size * pipeline_para_size" self.tensor_para_rank = self.rank % self.tensor_para_size self.pipeline_para_rank = self.rank // self.tensor_para_size - def load_from_model(self, model): # assume this only applies to huggingface models - start_layer = self.pipeline_para_rank * self.num_layer // self.pipeline_para_size - end_layer = (self.pipeline_para_rank + 1) * self.num_layer // self.pipeline_para_size - + def load_from_model(self, model): # assume this only applies to huggingface models + start_layer = self.pipeline_para_rank * self.num_layer // self.pipeline_para_size + end_layer = (self.pipeline_para_rank + 1) * self.num_layer // self.pipeline_para_size + + np_weight_dtype = self.weight_data_type + torch_weight_dtype = {np.float32: torch.float32, np.float16: torch.float16}[np_weight_dtype] + encoder_weight_dict = {} for name, param in model.named_parameters(): if param.dim() == 2: @@ -64,21 +81,31 @@ def load_from_model(self, model): # assume this only applies to huggingface mode if name.find("encoder.block") != -1 or name.find("encoder.final_layer_norm.weight") != -1: encoder_weight_dict[name] = param_t - t = torch.stack([encoder_weight_dict["encoder.block.{}.layer.0.layer_norm.weight".format(i)] for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([encoder_weight_dict["encoder.block.{}.layer.0.layer_norm.weight".format(i)] + for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t) - t = torch.stack([encoder_weight_dict["encoder.block.{}.layer.0.SelfAttention.q.weight".format(i)] for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([encoder_weight_dict["encoder.block.{}.layer.0.SelfAttention.q.weight".format(i)] + for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t.split(t.shape[-1] // self.tensor_para_size, dim=-1)[self.tensor_para_rank].contiguous()) - t = torch.stack([encoder_weight_dict["encoder.block.{}.layer.0.SelfAttention.k.weight".format(i)] for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = 
torch.stack([encoder_weight_dict["encoder.block.{}.layer.0.SelfAttention.k.weight".format(i)] + for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t.split(t.shape[-1] // self.tensor_para_size, dim=-1)[self.tensor_para_rank].contiguous()) - t = torch.stack([encoder_weight_dict["encoder.block.{}.layer.0.SelfAttention.v.weight".format(i)] for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([encoder_weight_dict["encoder.block.{}.layer.0.SelfAttention.v.weight".format(i)] + for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t.split(t.shape[-1] // self.tensor_para_size, dim=-1)[self.tensor_para_rank].contiguous()) - t = torch.stack([encoder_weight_dict["encoder.block.{}.layer.0.SelfAttention.o.weight".format(i)] for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([encoder_weight_dict["encoder.block.{}.layer.0.SelfAttention.o.weight".format(i)] + for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t.split(t.shape[1] // self.tensor_para_size, dim=1)[self.tensor_para_rank].contiguous()) - t = torch.stack([encoder_weight_dict["encoder.block.{}.layer.1.layer_norm.weight".format(i)] for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([encoder_weight_dict["encoder.block.{}.layer.1.layer_norm.weight".format(i)] + for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t) - t = torch.stack([encoder_weight_dict["encoder.block.{}.layer.1.DenseReluDense.wi.weight".format(i)] for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([encoder_weight_dict["encoder.block.{}.layer.1.DenseReluDense.wi.weight".format(i)] + for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t.split(t.shape[-1] // self.tensor_para_size, dim=-1)[self.tensor_para_rank].contiguous()) - t = torch.stack([encoder_weight_dict["encoder.block.{}.layer.1.DenseReluDense.wo.weight".format(i)] for i in range(start_layer, end_layer)], 0).contiguous().cuda() + # add empty wi2 + self.w.append(torch.empty((1, 1), dtype=torch_weight_dtype).contiguous().cuda()) + t = torch.stack([encoder_weight_dict["encoder.block.{}.layer.1.DenseReluDense.wo.weight".format(i)] + for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t.split(t.shape[1] // self.tensor_para_size, dim=1)[self.tensor_para_rank].contiguous()) t = encoder_weight_dict["encoder.final_layer_norm.weight"].contiguous().cuda() self.w.append(t) @@ -87,77 +114,127 @@ def load_from_model(self, model): # assume this only applies to huggingface mode t = model.get_input_embeddings().weight.contiguous().cuda() self.w.append(t) - #TODO: pass None Type to Torch Op - for i in range(9): - self.w.append(torch.empty((1,1), dtype=torch.float32).contiguous().cuda()) - - def load_from_bin(self, ckpt_path): # assume this only applies to megatron models - start_layer = self.pipeline_para_rank * self.num_layer // self.pipeline_para_size - end_layer = (self.pipeline_para_rank + 1) * self.num_layer // self.pipeline_para_size + # TODO: pass None Type to Torch Op + for i in range(10): + self.w.append(torch.empty((1, 1), dtype=torch_weight_dtype).contiguous().cuda()) + + def load_from_bin(self, ckpt_path): # assume this only applies to megatron models + start_layer = self.pipeline_para_rank * self.num_layer // self.pipeline_para_size + end_layer = (self.pipeline_para_rank + 1) * self.num_layer // self.pipeline_para_size + + np_weight_dtype = 
self.weight_data_type + torch_weight_dtype = {np.float32: torch.float32, np.float16: torch.float16}[np_weight_dtype] + # load by binary files - t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/encoder.block.{i}.layer.0.layer_norm.weight.bin", dtype=np.single)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/encoder.block.{i}.layer.0.layer_norm.weight.bin", + dtype=np_weight_dtype)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t) - t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/encoder.block.{i}.layer.0.SelfAttention.q.weight.{self.tensor_para_rank}.bin", dtype=np.single)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/encoder.block.{i}.layer.0.SelfAttention.q.weight.{self.tensor_para_rank}.bin", + dtype=np_weight_dtype)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t) - t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/encoder.block.{i}.layer.0.SelfAttention.k.weight.{self.tensor_para_rank}.bin", dtype=np.single)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/encoder.block.{i}.layer.0.SelfAttention.k.weight.{self.tensor_para_rank}.bin", + dtype=np_weight_dtype)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t) - t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/encoder.block.{i}.layer.0.SelfAttention.v.weight.{self.tensor_para_rank}.bin", dtype=np.single)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/encoder.block.{i}.layer.0.SelfAttention.v.weight.{self.tensor_para_rank}.bin", + dtype=np_weight_dtype)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t) - t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/encoder.block.{i}.layer.0.SelfAttention.o.weight.{self.tensor_para_rank}.bin", dtype=np.single)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/encoder.block.{i}.layer.0.SelfAttention.o.weight.{self.tensor_para_rank}.bin", + dtype=np_weight_dtype)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t) - t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/encoder.block.{i}.layer.1.layer_norm.weight.bin", dtype=np.single)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/encoder.block.{i}.layer.1.layer_norm.weight.bin", + dtype=np_weight_dtype)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t) - t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/encoder.block.{i}.layer.1.DenseReluDense.wi.weight.{self.tensor_para_rank}.bin", dtype=np.single)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/encoder.block.{i}.layer.1.DenseReluDense.wi.weight.{self.tensor_para_rank}.bin", + dtype=np_weight_dtype)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t) - t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/encoder.block.{i}.layer.1.DenseReluDense.wo.weight.{self.tensor_para_rank}.bin", dtype=np.single)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + 
if self.use_gated_activation: + t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/encoder.block.{i}.layer.1.DenseReluDense.wi2.weight.{self.tensor_para_rank}.bin", + dtype=np_weight_dtype)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + self.w.append(t) + else: + self.w.append(torch.empty((1, 1), dtype=torch_weight_dtype).contiguous().cuda()) + t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/encoder.block.{i}.layer.1.DenseReluDense.wo.weight.{self.tensor_para_rank}.bin", + dtype=np_weight_dtype)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t) - t = torch.from_numpy(np.fromfile(f"{ckpt_path}/encoder.final_layer_norm.weight.bin", dtype=np.single)).contiguous().cuda() + t = torch.from_numpy(np.fromfile( + f"{ckpt_path}/encoder.final_layer_norm.weight.bin", dtype=np_weight_dtype)).contiguous().cuda() self.w.append(t) t = None - if (self.position_embedding_type == 0): - t = torch.from_numpy(np.fromfile(f"{ckpt_path}/encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight.{self.tensor_para_rank}.bin", dtype=np.single)).contiguous().cuda() + if (self.position_embedding_type == 0): + t = torch.from_numpy(np.fromfile( + f"{ckpt_path}/encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight.{self.tensor_para_rank}.bin", dtype=np_weight_dtype)).contiguous().cuda() else: - t = torch.from_numpy(np.fromfile(f"{ckpt_path}/shared.ape.bin", dtype=np.single)).contiguous().cuda() + t = torch.from_numpy(np.fromfile(f"{ckpt_path}/shared.ape.bin", dtype=np_weight_dtype)).contiguous().cuda() self.w.append(t) - t = torch.from_numpy(np.fromfile(f"{ckpt_path}/shared.weight_T.bin", dtype=np.single).reshape([self.config.d_model, self.config.vocab_size])).contiguous().cuda() + t = torch.from_numpy(np.fromfile(f"{ckpt_path}/shared.weight_T.bin", dtype=np_weight_dtype).reshape( + [self.config.d_model, self.config.vocab_size])).contiguous().cuda() self.w.append(t) - - # add 9 additional bias if it is t5 megatron structure + + # add 10 additional bias if it is t5 megatron structure if self.t5_with_bias: - t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/encoder.block.{i}.layer.0.layer_norm.bias.bin", dtype=np.single)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/encoder.block.{i}.layer.0.layer_norm.bias.bin", + dtype=np_weight_dtype)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t) - t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/encoder.block.{i}.layer.0.SelfAttention.q.bias.{self.tensor_para_rank}.bin", dtype=np.single)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/encoder.block.{i}.layer.0.SelfAttention.q.bias.{self.tensor_para_rank}.bin", + dtype=np_weight_dtype)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t) - t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/encoder.block.{i}.layer.0.SelfAttention.k.bias.{self.tensor_para_rank}.bin", dtype=np.single)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/encoder.block.{i}.layer.0.SelfAttention.k.bias.{self.tensor_para_rank}.bin", + dtype=np_weight_dtype)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t) - t = 
torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/encoder.block.{i}.layer.0.SelfAttention.v.bias.{self.tensor_para_rank}.bin", dtype=np.single)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/encoder.block.{i}.layer.0.SelfAttention.v.bias.{self.tensor_para_rank}.bin", + dtype=np_weight_dtype)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t) - t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/encoder.block.{i}.layer.0.SelfAttention.o.bias.bin", dtype=np.single)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/encoder.block.{i}.layer.0.SelfAttention.o.bias.bin", + dtype=np_weight_dtype)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t) - t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/encoder.block.{i}.layer.1.layer_norm.bias.bin", dtype=np.single)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/encoder.block.{i}.layer.1.layer_norm.bias.bin", + dtype=np_weight_dtype)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t) - t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/encoder.block.{i}.layer.1.DenseReluDense.wi.bias.{self.tensor_para_rank}.bin", dtype=np.single)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/encoder.block.{i}.layer.1.DenseReluDense.wi.bias.{self.tensor_para_rank}.bin", + dtype=np_weight_dtype)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t) - t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/encoder.block.{i}.layer.1.DenseReluDense.wo.bias.bin", dtype=np.single)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + if self.use_gated_activation: + t = torch.stack([torch.from_numpy(np.fromfile( + f"{ckpt_path}/encoder.block.{i}.layer.1.DenseReluDense.wi2.bias.{self.tensor_para_rank}.bin", dtype=np_weight_dtype)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() + self.w.append(t) + else: + self.w.append(torch.empty((1, 1), dtype=torch_weight_dtype).contiguous().cuda()) + t = torch.stack([torch.from_numpy(np.fromfile(f"{ckpt_path}/encoder.block.{i}.layer.1.DenseReluDense.wo.bias.bin", + dtype=np_weight_dtype)) for i in range(start_layer, end_layer)], 0).contiguous().cuda() self.w.append(t) - t = torch.from_numpy(np.fromfile(f"{ckpt_path}/encoder.final_layer_norm.bias.bin", dtype=np.single)).contiguous().cuda() + t = torch.from_numpy(np.fromfile( + f"{ckpt_path}/encoder.final_layer_norm.bias.bin", dtype=np_weight_dtype)).contiguous().cuda() self.w.append(t) else: - #TODO: pass None Type to Torch Op - for i in range(9): - self.w.append(torch.empty((1,1), dtype=torch.float32).contiguous().cuda()) - + # TODO: pass None Type to Torch Op + for i in range(10): + self.w.append(torch.empty((1, 1), dtype=torch_weight_dtype).contiguous().cuda()) + def to_cuda(self): for i in range(self.real_weights_num): self.w[i] = self.w[i].cuda() + def to_float(self): + for i in range(self.real_weights_num): + self.w[i] = self.w[i].float() + def to_half(self): for i in range(self.real_weights_num): self.w[i] = self.w[i].half() + def to_single(self): + for i in range(self.real_weights_num): + self.w[i] = self.w[i].float() + + def to_bfloat16(self): + for i in range(self.real_weights_num): + 
self.w[i] = self.w[i].bfloat16() + + class FTT5Encoder(nn.Module): def __init__(self, encoder_weight_list, lib_path, head_num, head_size, inter_size, d_model, is_remove_padding, - num_layer, num_bucket=32, max_distance=128, sparse=False, q_scaling=1.0, tensor_para_size=1, pipeline_para_size=1, t5_with_bias=False, position_embedding_type=0): + num_layer, num_bucket=32, max_distance=128, sparse=False, q_scaling=1.0, tensor_para_size=1, pipeline_para_size=1, t5_with_bias=False, + position_embedding_type=0, activation_type="relu"): super().__init__() self.use_mpi = dist.is_mpi_available() @@ -166,7 +243,7 @@ def __init__(self, encoder_weight_list, lib_path, head_num, head_size, inter_siz try: dist.init_process_group(backend='mpi') except: - print("[INFO] WARNING: Exception occurred in dist.init_process_group(backend='mpi'). Maybe the process group has been initialized somewhere else.") + print("[INFO] WARNING: Exception occurred in dist.init_process_group(backend = 'mpi'). Maybe the process group has been initialized somewhere else.") else: print("[INFO] MPI is not available in this PyTorch build.") assert tensor_para_size == 1, "[FATAL] MPI is required for tensor_para_size > 1." @@ -175,11 +252,13 @@ def __init__(self, encoder_weight_list, lib_path, head_num, head_size, inter_siz torch.classes.load_library(lib_path) try: self.encoder = torch.classes.FasterTransformer.T5Encoder(*encoder_weight_list, head_num, head_size, inter_size, d_model, - is_remove_padding, num_layer, num_bucket, max_distance, sparse, q_scaling, tensor_para_size, pipeline_para_size, t5_with_bias, position_embedding_type) + is_remove_padding, num_layer, num_bucket, max_distance, sparse, q_scaling, tensor_para_size, pipeline_para_size, + t5_with_bias, position_embedding_type, activation_type) except: self.encoder = torch.classes.FasterTransformerT5Encoder(*encoder_weight_list, head_num, head_size, inter_size, d_model, - is_remove_padding, num_layer, num_bucket, max_distance, sparse, q_scaling, tensor_para_size, pipeline_para_size, t5_with_bias, position_embedding_type) + is_remove_padding, num_layer, num_bucket, max_distance, sparse, q_scaling, tensor_para_size, pipeline_para_size, + t5_with_bias, position_embedding_type, activation_type) - def forward(self, input, seq_len): - output = self.encoder.forward(input, seq_len) + def forward(self, input, seq_len, inputs_embeds=None): + output = self.encoder.forward(input, seq_len, inputs_embeds) return output \ No newline at end of file diff --git a/examples/pytorch/t5/utils/huggingface_t5_ckpt_convert.py b/examples/pytorch/t5/utils/huggingface_t5_ckpt_convert.py index 41cb01d76..204d7a38b 100644 --- a/examples/pytorch/t5/utils/huggingface_t5_ckpt_convert.py +++ b/examples/pytorch/t5/utils/huggingface_t5_ckpt_convert.py @@ -15,6 +15,7 @@ import argparse import configparser from datetime import datetime +import logging from pathlib import Path import sys @@ -27,8 +28,10 @@ import numpy as np import torch # pytype: disable=import-error +LOGGER = logging.getLogger(__name__) + rename_mapping={"relative_attention_num_buckets":"relative_attention_num_buckets_or_max_pos_seq_len"} -new_configs={"structure":{"t5_with_bias":"0", "position_embedding_type":"0"}} +new_configs={"structure":{"t5_with_bias":"false", "use_gated_activation":"false", "position_embedding_type":"relative"}} def get_weight_data_type(data_type): if data_type == "fp32": @@ -59,8 +62,11 @@ def fuse_decoder_qkv(model, factor, saved_dir, np_weight_data_type): split_vals[j].tofile(saved_path.as_posix()) def 
split_and_convert_process(key, val, factor, saved_dir, np_weight_data_type): - val = val.T.cpu().detach().numpy().astype(np_weight_data_type) + if val.dim() == 2: + val = val.transpose(1, 0) + val = val.cpu().detach().numpy().astype(np_weight_data_type) saved_key = key + LOGGER.debug(f"key: {key}, val.shape: {val.shape}") if key.find("shared.weight") != -1: # shared weights, only need to convert the weights of rank 0 @@ -69,6 +75,12 @@ def split_and_convert_process(key, val, factor, saved_dir, np_weight_data_type): saved_path = saved_dir / f"{saved_key}_T.bin" val.T.tofile(saved_path.as_posix()) + elif key.find("lm_head.weight") != -1: + # lm_head weights, only need to convert the weights of rank 0 + val = val.transpose(1, 0) # For lm_head, we use TN gemm to compute, so we don't need to transpose + saved_path = saved_dir / f"{saved_key}.bin" + val.tofile(saved_path.as_posix()) + elif key.find("layer_norm.weight") != -1: # shared weights, only need to convert the weights of rank 0 saved_path = saved_dir / f"{saved_key}.bin" @@ -100,6 +112,19 @@ def split_and_convert_process(key, val, factor, saved_dir, np_weight_data_type): for j in range(factor): saved_path = saved_dir / f"{saved_key}.{j:d}.bin" split_vals[j].tofile(saved_path.as_posix()) + elif ( + key.find("DenseReluDense.wi_0.weight") != -1 + or key.find("DenseReluDense.wi_1.weight") != -1 + ): + # For gated activation. + if key.find("DenseReluDense.wi_0.weight") != -1: + saved_key = key.replace("wi_0", "wi") + elif key.find("DenseReluDense.wi_1.weight") != -1: + saved_key = key.replace("wi_1", "wi2") + split_vals = np.split(val, factor, axis=-1) + for j in range(factor): + saved_path = saved_dir / f"{saved_key}.{j:d}.bin" + split_vals[j].tofile(saved_path.as_posix()) elif key.find("relative_attention_bias") != -1: split_vals = np.split(val, factor, axis=0) for j in range(factor): @@ -114,16 +139,22 @@ def split_and_convert_process(key, val, factor, saved_dir, np_weight_data_type): ) ): pass + elif key.find("encoder.embed_tokens.weight") != -1 or \ + key.find("decoder.embed_tokens.weight") != -1: + LOGGER.warning(f"Not save {key}, using shared.weight directly.") else: - print(f"[ERROR] cannot find key '{key}'") + LOGGER.warning(f"cannot find key '{key}' with shape {val.shape}") def convert_checkpoint(args): - saved_dir = Path(args.saved_dir) / f"{args.infer_gpu_num:d}-gpu" + saved_dir = Path(args.saved_dir) / f"{args.inference_tensor_para_size:d}-gpu" saved_dir.mkdir(parents=True, exist_ok=True) t5_model = T5ForConditionalGeneration.from_pretrained(args.in_file) config = configparser.ConfigParser() + if t5_model.encoder.config.feed_forward_proj.find("gated") != -1: + new_configs["structure"]["use_gated_activation"] = "1" + config["encoder"] = {} for key, val in t5_model.encoder.config.to_dict().items(): config["encoder"][key] = f"{val}" @@ -143,9 +174,9 @@ def convert_checkpoint(args): config.write(configfile) np_weight_data_type = get_weight_data_type(args.weight_data_type) - i_gpu_num = args.infer_gpu_num + i_gpu_num = args.inference_tensor_para_size - for name, param in t5_model.named_parameters(): + for name, param in t5_model.state_dict().items(): split_and_convert_process(name, param, i_gpu_num, saved_dir, np_weight_data_type) fuse_decoder_qkv(t5_model, i_gpu_num, saved_dir, np_weight_data_type) @@ -153,16 +184,19 @@ def convert_checkpoint(args): parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) parser.add_argument("-saved_dir", "-o", type=str, help="file name of output file", required=True) 
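# Example invocation (hypothetical output path and model name, shown only to illustrate how the
# flags registered in this block fit together):
#   python examples/pytorch/t5/utils/huggingface_t5_ckpt_convert.py \
#       -o /tmp/ft_t5_ckpt -i t5-small -i_g 2 -weight_data_type fp32 --verbose
# With -i_g 2 (the renamed -inference_tensor_para_size flag below), the split *.bin weights land
# under /tmp/ft_t5_ckpt/2-gpu/, matching the "<N>-gpu" directory created in convert_checkpoint.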
parser.add_argument("-in_file", "-i", type=str, help="file name of input checkpoint file", required=True) - parser.add_argument("-infer_gpu_num", "-i_g", type=int, help="How many gpus for inference", required=True) + parser.add_argument("-inference_tensor_para_size", "-i_g", type=int, help="How many gpus for inference", required=True) parser.add_argument("-weight_data_type", type=str, default="fp32", choices=["fp32", "fp16"]) + parser.add_argument("--verbose", action="store_true", help="Provide verbose messages") args = parser.parse_args() - print("\n=============== Argument ===============") + log_format = "%(asctime)s %(name)s [%(levelname)s] %(message)s" + logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO, format=log_format) + LOGGER.info("\n=============== Argument ===============") for key in vars(args): - print(f"{key}: {vars(args)[key]}") - print("========================================") + LOGGER.info(f"{key}: {vars(args)[key]}") + LOGGER.info("========================================") start_time = datetime.now() convert_checkpoint(args) stop_time = datetime.now() run_time = (stop_time - start_time) - print("[INFO] Spend {} (h:m:s) to convert the model".format(run_time)) + LOGGER.info("Spend {} (h:m:s) to convert the model".format(run_time)) diff --git a/examples/pytorch/t5/utils/megatron_t5_ckpt_convert.py b/examples/pytorch/t5/utils/megatron_t5_ckpt_convert.py index 7d0d01b2d..f030289f6 100644 --- a/examples/pytorch/t5/utils/megatron_t5_ckpt_convert.py +++ b/examples/pytorch/t5/utils/megatron_t5_ckpt_convert.py @@ -16,6 +16,7 @@ import configparser from datetime import datetime import multiprocessing +import shutil from pathlib import Path import numpy as np @@ -89,7 +90,7 @@ "eos_token_id":30523 ## need to adjust } -model_new_config = {"structure":{"t5_with_bias":1, "position_embedding_type":0}} +model_new_config = {"structure":{"t5_with_bias": "true", "use_gated_activation": "false", "position_embedding_type": "absolute"}} def get_weight_data_type(data_type): if data_type == "fp32": @@ -351,10 +352,16 @@ def split_and_convert_process(model_type, i, pipeline_para_rank, saved_dir, fact else: print(f"[ERROR] cannot find key '{key}'") + def convert_checkpoint(args): saved_dir = Path(args.saved_dir) / f"{args.infer_gpu_num:d}-gpu" saved_dir.mkdir(parents=True, exist_ok=True) + if args.vocab_path: + shutil.copy(args.vocab_path, (saved_dir / "vocab.json").as_posix()) + if args.merges_path: + shutil.copy(args.merges_path, (saved_dir / "merges.txt").as_posix()) + prefix = Path(args.in_file) ckpt_name = "model_optim_rng.pt" @@ -371,7 +378,7 @@ def convert_checkpoint(args): # update model structure config if not hasattr(model_args, 'position_embedding_type') or model_args.position_embedding_type == "absolute": - model_new_config["structure"]["position_embedding_type"] = 1 + model_new_config["structure"]["position_embedding_type"] = "absolute" config = configparser.ConfigParser() config["encoder"] = {} @@ -451,6 +458,7 @@ def convert_checkpoint(args): w_e_list = [] torch.multiprocessing.set_start_method("spawn") + torch.multiprocessing.set_sharing_strategy("file_system") pool = multiprocessing.Pool(args.processes) for i in range(main_loop): for j in range(model_args.pipeline_model_parallel_size): @@ -535,6 +543,15 @@ def convert_checkpoint(args): parser.add_argument("-processes", "-p", type=int, help="How many processes to spawn for conversion (default: 64)", default=64) parser.add_argument("-weight_data_type", type=str, default="fp32", choices=["fp32", "fp16"]) 
parser.add_argument("-model_name", "-m", type=str, help="model name", required=True) + parser.add_argument( + "--vocab-path", + type=str, + help="Path to vocabulary file to embed in FasterTransformer checkpoint", + required=False, + ) + parser.add_argument( + "--merges-path", type=str, help="Path to merges file to embed in FasterTransformer checkpoint", required=False + ) args = parser.parse_args() print("\n=============== Argument ===============") for key in vars(args): diff --git a/examples/pytorch/t5/utils/nemo_t5_ckpt_convert.py b/examples/pytorch/t5/utils/nemo_t5_ckpt_convert.py index 3e881b89b..e2b58ce5a 100644 --- a/examples/pytorch/t5/utils/nemo_t5_ckpt_convert.py +++ b/examples/pytorch/t5/utils/nemo_t5_ckpt_convert.py @@ -13,155 +13,143 @@ # limitations under the License. import argparse +import concurrent.futures import configparser -from datetime import datetime -import multiprocessing -from pathlib import Path +import datetime +import logging +import pathlib +import shutil +import sys +import tempfile +import typing import numpy as np import torch # pytype: disable=import-error -import sys -import glob -import os -import tarfile -import yaml -sys.path.append("/workdir/megatron-lm") +# verify if root package is in PYTHONPATH +__root_package_path__ = pathlib.Path(__file__).parent.parent.parent.parent.parent.absolute().as_posix() +if __root_package_path__ not in sys.path: + print( + f"[ERROR] add project root directory to your PYTHONPATH with " + f"'export PYTHONPATH={__root_package_path__}:${{PYTHONPATH}}'" + ) + +from examples.pytorch.nemo import unpack_nemo_ckpt, UnpackedNemoCheckpointDir, extract_layers_with_prefix +from examples.pytorch.utils import gpu_map_location, WEIGHT2DTYPE, torch2np, cpu_map_location, safe_transpose + + +LOGGER = logging.getLogger(__name__) + shared_mapping = { - "wte":"shared.weight", - "wte_T":"shared.weight_T", - "ape":"shared.ape", - "encoder_rpe":"block.0.layer.0.SelfAttention.relative_attention_bias", - "decoder_rpe":"block.0.layer.0.SelfAttention.relative_attention_bias" + "wte": "shared.weight", + "wte_T": "shared.weight_T", + "ape": "shared.ape", + "encoder_rpe": "block.0.layer.0.SelfAttention.relative_attention_bias", + "decoder_rpe": "block.0.layer.0.SelfAttention.relative_attention_bias", } encoder_mapping = { - "input_layernorm":"layer.0.layer_norm", - "self_attention.query_key_value":["layer.0.SelfAttention.q", "layer.0.SelfAttention.k", "layer.0.SelfAttention.v"], - "self_attention.dense":"layer.0.SelfAttention.o", - "post_attention_layernorm":"layer.1.layer_norm", - "mlp.dense_h_to_4h":"layer.1.DenseReluDense.wi", - "mlp.dense_4h_to_h":"layer.1.DenseReluDense.wo", - "final_layernorm":"final_layer_norm" + "input_layernorm": "layer.0.layer_norm", + "self_attention.query_key_value": ["layer.0.SelfAttention.q", "layer.0.SelfAttention.k", "layer.0.SelfAttention.v"], + "self_attention.dense": "layer.0.SelfAttention.o", + "post_attention_layernorm": "layer.1.layer_norm", + "mlp.dense_h_to_4h": "layer.1.DenseReluDense.wi", + "mlp.dense_h_to_4h_2": "layer.1.DenseReluDense.wi2", ## gated activation + "mlp.dense_4h_to_h": "layer.1.DenseReluDense.wo", + "final_layernorm": "final_layer_norm", } decoder_mapping = { - "input_layernorm":"layer.0.layer_norm", - "self_attention.query_key_value":["layer.0.SelfAttention.qkv"], - "self_attention.dense":"layer.0.SelfAttention.o", - "post_attention_layernorm":"layer.1.layer_norm", - "inter_attention.query":["layer.1.EncDecAttention.q"], - 
"inter_attention.key_value":["layer.1.EncDecAttention.k","layer.1.EncDecAttention.v"], - "inter_attention.dense":"layer.1.EncDecAttention.o", - "post_inter_attention_layernorm":"layer.2.layer_norm", - "mlp.dense_h_to_4h":"layer.2.DenseReluDense.wi", - "mlp.dense_4h_to_h":"layer.2.DenseReluDense.wo", - "final_layernorm":"final_layer_norm" + "input_layernorm": "layer.0.layer_norm", + "self_attention.query_key_value": ["layer.0.SelfAttention.qkv"], + "self_attention.dense": "layer.0.SelfAttention.o", + "post_attention_layernorm": "layer.1.layer_norm", + "inter_attention.query": ["layer.1.EncDecAttention.q"], + "inter_attention.key_value": ["layer.1.EncDecAttention.k", "layer.1.EncDecAttention.v"], + "inter_attention.dense": "layer.1.EncDecAttention.o", + "post_inter_attention_layernorm": "layer.2.layer_norm", + "mlp.dense_h_to_4h": "layer.2.DenseReluDense.wi", + "mlp.dense_h_to_4h_2": "layer.2.DenseReluDense.wi2", + "mlp.dense_4h_to_h": "layer.2.DenseReluDense.wo", + "final_layernorm": "final_layer_norm", } -megatron_HF_name_mapping = { - "shared":shared_mapping, - "encoder":encoder_mapping, - "decoder":decoder_mapping -} +megatron_HF_name_mapping = {"shared": shared_mapping, "encoder": encoder_mapping, "decoder": decoder_mapping} encoder_config_mapping = { - "num_attention_heads":"num_heads", - "hidden_size":"d_model", - "kv_channels":"d_kv", - "ffn_hidden_size":"d_ff", - "num_layers":"num_layers", - "max_position_embeddings":"relative_attention_num_buckets_or_max_pos_seq_len", - "relative_position_num_buckets":"relative_attention_num_buckets_or_max_pos_seq_len" + "num_attention_heads": "num_heads", + "hidden_size": "d_model", + "kv_channels": "d_kv", + "ffn_hidden_size": "d_ff", + "num_layers": "num_layers", + "max_position_embeddings": "relative_attention_num_buckets_or_max_pos_seq_len", + "relative_position_num_buckets": "relative_attention_num_buckets_or_max_pos_seq_len", + "activation": "feed_forward_proj", } decoder_config_mapping = { - "num_attention_heads":"num_heads", - "hidden_size":"d_model", - "kv_channels":"d_kv", - "ffn_hidden_size":"d_ff", - "num_layers":"num_layers", - "max_position_embeddings":"relative_attention_num_buckets_or_max_pos_seq_len", - "relative_position_num_buckets":"relative_attention_num_buckets_or_max_pos_seq_len" -} - -decoder_new_config = { - "decoder_start_token_id":0, ## need to adjust - "eos_token_id":1 ## need to adjust + "num_attention_heads": "num_heads", + "hidden_size": "d_model", + "kv_channels": "d_kv", + "ffn_hidden_size": "d_ff", + "num_layers": "num_layers", + "max_position_embeddings": "relative_attention_num_buckets_or_max_pos_seq_len", + "relative_position_num_buckets": "relative_attention_num_buckets_or_max_pos_seq_len", } -model_new_config = {"structure":{"t5_with_bias":1, "position_embedding_type":0}} -def convert_megatron_to_HF_naming_style_single(saved_key, name_mapping): - saved_key = saved_key.replace("layers","block") +def megatron2hf_name(saved_key, name_mapping): + saved_key = saved_key.replace("layers", "block") mapping_key = saved_key.rsplit(sep=".", maxsplit=1)[0] mapping_key_no_num = mapping_key[mapping_key.find(".", 6) + 1 :] block_num = mapping_key[: mapping_key.find(".", 6) + 1] weight_or_bias = saved_key.rsplit(sep=".", maxsplit=1)[1] - saved_key = block_num + name_mapping[mapping_key_no_num] + "." 
+ weight_or_bias - return saved_key -def convert_megatron_to_HF_naming_style_multiple(saved_key, name_mapping): - saved_key = saved_key.replace("layers","block") - mapping_key = saved_key.rsplit(sep=".", maxsplit=1)[0] - mapping_key_no_num = mapping_key[mapping_key.find(".", 6) + 1 :] mapping_vals_no_num = name_mapping[mapping_key_no_num] - block_num = mapping_key[: mapping_key.find(".", 6) + 1] - weight_or_bias = saved_key.rsplit(sep=".", maxsplit=1)[1] + if not isinstance(mapping_vals_no_num, list): + mapping_vals_no_num = [mapping_vals_no_num] + saved_keys = [block_num + val + "." + weight_or_bias for val in mapping_vals_no_num] return saved_keys -def unpack_nemo_ckpt(nemo_ckpt_path, out_folder): - """ - .nemo file is an archive (tar.gz) with the following: - model_config.yaml - model configuration in .yaml format. You can deserialize this into cfg argument for model's constructor - model_wights.chpt - model checkpoint - """ - if not os.path.exists(nemo_ckpt_path): - raise FileNotFoundError(f"{nemo_ckpt_path} does not exist") - tar_header = "r:" - try: - tar = tarfile.open(nemo_ckpt_path, tar_header) - except tarfile.ReadError: - # can be older checkpoint => try compressed tar - tar_header = "r:gz" - tar = tarfile.open(nemo_ckpt_path, tar_header) - tar.extractall(path=out_folder) - tar.close() - return out_folder - -def get_weight_data_type(data_type): - if data_type == "fp32": - return np.float32 - elif data_type == "fp16": - return np.float16 - else: - assert False, f"Invalid weight data type {data_type}" - -def _gpu_map_location(storage, loc): - if loc.startswith("cuda"): - training_gpu_idx = int(loc.split(":")[1]) - inference_gpu_idx = training_gpu_idx % torch.cuda.device_count() - return storage.cuda(inference_gpu_idx) - elif loc.startswith("cpu"): - return storage.cpu() - else: - raise NotImplementedError(f"Not handled {loc}") # This tool is used to support the new megatron model trained by pipeline parallel + tensor parallel -def merge_and_convert_process(model_type, i, pipeline_para_rank, saved_dir, factor, key, model_args, transformer_model_list, np_weight_data_type): +def merge_and_convert_process( + model_type, + tensor_para_rank, + pipeline_para_rank, + saved_dir, + factor, + key, + nemo_model_config, + models_list, + np_weight_data_type, +): + assert model_type == "encoder" or model_type == "decoder" prefix = model_type + pipeline_para_size = nemo_model_config["pipeline_model_parallel_size"] + pipeline_model_parallel_split_rank = nemo_model_config.get("pipeline_model_parallel_split_rank", 0) + num_layers = nemo_model_config["num_layers"] + + major_device = models_list[0][key].device + name_mapping = megatron_HF_name_mapping[model_type] - saved_dir = Path(saved_dir) + saved_dir = pathlib.Path(saved_dir) if key.find("layers.") != -1: - layer_index = (int)(key[7 : key.find(".", 7)]) - saved_key = key.replace( - "layers.%d." % layer_index, - "layers.%d." 
% (layer_index + pipeline_para_rank * model_args['num_layers'] // model_args['pipeline_model_parallel_size'])) + layer_index = int(key[7 : key.find(".", 7)]) + encoder_num_pipeline_stages = pipeline_model_parallel_split_rank + decoder_num_pipeline_stages = pipeline_para_size - pipeline_model_parallel_split_rank + offset = 0 + if model_type == "encoder" and pipeline_para_size > 1: + offset = pipeline_para_rank * (num_layers // encoder_num_pipeline_stages) + elif model_type == "decoder" and pipeline_para_size > 1: + offset = (pipeline_para_rank - pipeline_model_parallel_split_rank) * ( + num_layers // decoder_num_pipeline_stages + ) + saved_key = key.replace(f"layers.{layer_index}.", f"layers.{layer_index + offset}.") else: saved_key = key - major_device = transformer_model_list[0][key].device if ( key.find("input_layernorm.weight") != -1 @@ -174,105 +162,229 @@ def merge_and_convert_process(model_type, i, pipeline_para_rank, saved_dir, fact or key.find("post_inter_attention_layernorm.bias") != -1 or key.find("mlp.dense_4h_to_h.bias") != -1 or key.find("final_layernorm.weight") != -1 - or key.find("final_layernorm.bias") != -1): - + or key.find("final_layernorm.bias") != -1 + ): # shared weights, only need to convert the weights of rank 0 - if i == 0: - val = transformer_model_list[0][key].T.float().cpu().numpy() - saved_key = convert_megatron_to_HF_naming_style_single(saved_key, name_mapping) - saved_path = saved_dir / f"{prefix}.{saved_key}.bin" - np.squeeze(val).astype(np_weight_data_type).tofile(saved_path) - - elif (key.find("self_attention.dense.weight") != -1 - or key.find("inter_attention.dense.weight") != -1 - or key.find("mlp.dense_4h_to_h.weight") != -1): + if tensor_para_rank == 0: + saved_keys = megatron2hf_name(saved_key, name_mapping) + saved_path = saved_dir / f"{prefix}.{saved_keys[0]}.bin" + val = safe_transpose(models_list[0][key]) + val = torch2np(val, np_weight_data_type) + val = np.squeeze(val) + LOGGER.debug( + "merge for pp_rank=%d tp_rank=%d only for tp_rank=0 src_key=%s filename=%s shape=%s dtype=%s", + pipeline_para_rank, + tensor_para_rank, + key, + saved_path.name, + val.shape, + val.dtype, + ) + val.tofile(saved_path) + + elif ( + key.find("self_attention.dense.weight") != -1 + or key.find("inter_attention.dense.weight") != -1 + or key.find("mlp.dense_4h_to_h.weight") != -1 + ): vals = [] for k in range(factor): - vals.append(transformer_model_list[k][key].T.float().to(major_device)) - saved_key = convert_megatron_to_HF_naming_style_single(saved_key, name_mapping) - torch.cat(vals, dim=0).cpu().numpy().astype(np_weight_data_type).tofile(saved_path) - - elif key.find("mlp.dense_h_to_4h.weight") != -1 or key.find("mlp.dense_h_to_4h.bias") != -1: + val = safe_transpose(models_list[k][key]) + val = val.float().to(major_device) + vals.append(val) + saved_keys = megatron2hf_name(saved_key, name_mapping) + saved_path = saved_dir / f"{prefix}.{saved_keys[0]}.{tensor_para_rank:d}.bin" + val = torch.cat(vals, dim=0) + val = torch2np(val, np_weight_data_type) + LOGGER.debug( + "merge for pp_rank=%d tp_rank=%d factor=%d src_key=%s filename=%s shape=%s dtype=%s", + pipeline_para_rank, + tensor_para_rank, + factor, + key, + saved_path.name, + val.shape, + val.dtype, + ) + val.tofile(saved_path) + + elif ( + key.find("mlp.dense_h_to_4h.weight") != -1 + or key.find("mlp.dense_h_to_4h.bias") != -1 + or key.find("mlp.dense_h_to_4h_2.weight") != -1 + or key.find("mlp.dense_h_to_4h_2.bias") != -1 + ): vals = [] for k in range(factor): - 
vals.append(transformer_model_list[k][key].T.float().to(major_device)) - saved_key = convert_megatron_to_HF_naming_style_single(saved_key, name_mapping) - saved_path = saved_dir / f"{prefix}.{saved_key}.{i:d}.bin" - torch.cat(vals, dim=-1).cpu().numpy().astype(np_weight_data_type).tofile(saved_path) - - elif (key.find("self_attention.query_key_value.bias") != -1 - or key.find("inter_attention.query.bias") != -1 - or key.find("inter_attention.key_value.bias") != -1): + val = safe_transpose(models_list[k][key]) + val = val.float().to(major_device) + vals.append(val) + saved_keys = megatron2hf_name(saved_key, name_mapping) + saved_path = saved_dir / f"{prefix}.{saved_keys[0]}.{tensor_para_rank:d}.bin" + val = torch.cat(vals, dim=-1) + val = torch2np(val, np_weight_data_type) + LOGGER.debug( + "merge for pp_rank=%d tp_rank=%d factor=%d src_key=%s filename=%s shape=%s dtype=%s", + pipeline_para_rank, + tensor_para_rank, + factor, + key, + saved_path.name, + val.shape, + val.dtype, + ) + val.tofile(saved_path) + + elif ( + key.find("self_attention.query_key_value.bias") != -1 + or key.find("inter_attention.query.bias") != -1 + or key.find("inter_attention.key_value.bias") != -1 + ): num_splits = 3 if key.find("inter_attention.key_value.bias") != -1: num_splits = 2 if key.find("inter_attention.query.bias") != -1: num_splits = 1 + vals = [] for k in range(factor): - val = transformer_model_list[k][key].T.float() + val = safe_transpose(models_list[k][key]) + val = val.float() local_dim = int(val.shape[-1] / num_splits) - head_num = model_args['num_attention_heads'] // model_args['tensor_model_parallel_size'] - size_per_head = model_args['kv_channels'] # t5 kv_channels may not be equal to hidden_size // head_num + head_num = nemo_model_config["num_attention_heads"] // nemo_model_config["tensor_model_parallel_size"] + # t5 kv_channels may not be equal to hidden_size // head_num + size_per_head = nemo_model_config["kv_channels"] val = val.reshape(head_num, num_splits, size_per_head) val = val.permute(1, 0, 2) val = val.reshape(num_splits, local_dim) vals.append(val.to(major_device)) saved_vals = torch.cat(vals, dim=-1) - saved_keys = convert_megatron_to_HF_naming_style_multiple(saved_key, name_mapping) + saved_keys = megatron2hf_name(saved_key, name_mapping) if len(saved_keys) == 1: - saved_path = saved_dir / f"{prefix}.{saved_keys[0]}.{i:d}.bin" - saved_vals.cpu().numpy().astype(np_weight_data_type).tofile(saved_path) - return - for index in range(len(saved_keys)): - saved_path = saved_dir / f"{prefix}.{saved_keys[index]}.{i:d}.bin" - saved_vals[index,...].cpu().numpy().astype(np_weight_data_type).tofile(saved_path) - - elif (key.find("self_attention.query_key_value.weight") != -1 - or key.find("inter_attention.query.weight") != -1 - or key.find("inter_attention.key_value.weight") != -1): + saved_path = saved_dir / f"{prefix}.{saved_keys[0]}.{tensor_para_rank:d}.bin" + val = torch2np(saved_vals, np_weight_data_type) + LOGGER.debug( + "merge for pp_rank=%d tp_rank=%d factor=%d src_key=%s filename=%s shape=%s dtype=%s", + pipeline_para_rank, + tensor_para_rank, + factor, + key, + saved_path.name, + val.shape, + val.dtype, + ) + val.tofile(saved_path) + else: + for index in range(len(saved_keys)): + saved_path = saved_dir / f"{prefix}.{saved_keys[index]}.{tensor_para_rank:d}.bin" + val = torch2np(saved_vals[index, ...], np_weight_data_type) + LOGGER.debug( + "merge for pp_rank=%d tp_rank=%d factor=%d src_key=%s filename=%s shape=%s dtype=%s", + pipeline_para_rank, + tensor_para_rank, + factor, + key, + 
saved_path.name, + val.shape, + val.dtype, + ) + val.tofile(saved_path) + + elif ( + key.find("self_attention.query_key_value.weight") != -1 + or key.find("inter_attention.query.weight") != -1 + or key.find("inter_attention.key_value.weight") != -1 + ): num_splits = 3 if key.find("inter_attention.key_value.weight") != -1: num_splits = 2 if key.find("inter_attention.query.weight") != -1: num_splits = 1 + vals = [] for k in range(factor): - val = transformer_model_list[k][key].T.float() + val = safe_transpose(models_list[k][key]) + val = val.float() hidden_dim = val.shape[0] local_dim = int(val.shape[-1] / num_splits) - head_num = model_args['num_attention_heads'] - size_per_head = model_args['kv_channels'] # t5 kv_channels may not be equal to hidden_size // head_num - head_num = head_num // model_args['tensor_model_parallel_size'] + head_num = nemo_model_config["num_attention_heads"] + # t5 kv_channels may not be equal to hidden_size // head_num + size_per_head = nemo_model_config["kv_channels"] + head_num = head_num // nemo_model_config["tensor_model_parallel_size"] val = val.reshape(hidden_dim, head_num, num_splits, size_per_head) val = val.permute(0, 2, 1, 3) val = val.reshape(hidden_dim, num_splits, local_dim) vals.append(val.to(major_device)) saved_vals = torch.cat(vals, dim=-1) - saved_keys = convert_megatron_to_HF_naming_style_multiple(saved_key, name_mapping) + saved_keys = megatron2hf_name(saved_key, name_mapping) if len(saved_keys) == 1: - saved_path = saved_dir / f"{prefix}.{saved_keys[0]}.{i:d}.bin" - saved_vals.cpu().numpy().astype(np_weight_data_type).tofile(saved_path) - return - for index in range(len(saved_keys)): - saved_path = saved_dir / f"{prefix}.{saved_keys[index]}.{i:d}.bin" - saved_vals[:, index, ...].cpu().numpy().astype(np_weight_data_type).tofile(saved_path) + saved_path = saved_dir / f"{prefix}.{saved_keys[0]}.{tensor_para_rank:d}.bin" + val = torch2np(saved_vals, np_weight_data_type) + LOGGER.debug( + "merge for pp_rank=%d tp_rank=%d factor=%d src_key=%s filename=%s shape=%s dtype=%s", + pipeline_para_rank, + tensor_para_rank, + factor, + key, + saved_path.name, + val.shape, + val.dtype, + ) + val.tofile(saved_path) + else: + for index in range(len(saved_keys)): + saved_path = saved_dir / f"{prefix}.{saved_keys[index]}.{tensor_para_rank:d}.bin" + val = torch2np(saved_vals[:, index, ...], np_weight_data_type) + LOGGER.debug( + "merge for pp_rank=%d tp_rank=%d factor=%d src_key=%s filename=%s shape=%s dtype=%s", + pipeline_para_rank, + tensor_para_rank, + factor, + key, + saved_path.name, + val.shape, + val.dtype, + ) + val.tofile(saved_path) else: - print(f"[ERROR] cannot find key '{key}'") - -def split_and_convert_process(model_type, i, pipeline_para_rank, saved_dir, factor, key, model_args, transformer_model_list, np_weight_data_type): + LOGGER.error(f"cannot find key '{key}'") + + +def split_and_convert_process( + model_type, + tensor_para_rank, + pipeline_para_rank, + saved_dir, + factor, + key, + nemo_model_config, + models_list, + np_weight_data_type, +): + assert model_type == "encoder" or model_type == "decoder" prefix = model_type + num_layers = nemo_model_config["num_layers"] + pipeline_para_size = nemo_model_config["pipeline_model_parallel_size"] + pipeline_model_parallel_split_rank = nemo_model_config.get("pipeline_model_parallel_split_rank", 0) + name_mapping = megatron_HF_name_mapping[model_type] - val = transformer_model_list[0][key].T.float().cpu().numpy().astype(np_weight_data_type) - del transformer_model_list[0][key] + val = 
safe_transpose(models_list[0][key]) + val = torch2np(val, np_weight_data_type) if key.find("layers.") != -1: - layer_index = (int)(key[7 : key.find(".", 7)]) - saved_key = key.replace( - "layers.%d." % layer_index, - "layers.%d." % (layer_index + pipeline_para_rank * model_args['num_layers'] // model_args['pipeline_model_parallel_size'])) + layer_index = int(key[7 : key.find(".", 7)]) + encoder_num_pipeline_stages = pipeline_model_parallel_split_rank + decoder_num_pipeline_stages = pipeline_para_size - pipeline_model_parallel_split_rank + offset = 0 + if model_type == "encoder" and pipeline_para_size > 1: + offset = pipeline_para_rank * (num_layers // encoder_num_pipeline_stages) + elif model_type == "decoder" and pipeline_para_size > 1: + offset = (pipeline_para_rank - pipeline_model_parallel_split_rank) * ( + num_layers // decoder_num_pipeline_stages + ) + saved_key = key.replace(f"layers.{layer_index}.", f"layers.{layer_index + offset}.") else: saved_key = key @@ -287,58 +399,128 @@ def split_and_convert_process(model_type, i, pipeline_para_rank, saved_dir, fact or key.find("post_inter_attention_layernorm.bias") != -1 or key.find("mlp.dense_4h_to_h.bias") != -1 or key.find("final_layernorm.weight") != -1 - or key.find("final_layernorm.bias") != -1): + or key.find("final_layernorm.bias") != -1 + ): # shared weights, only need to convert the weights of rank 0 - if i == 0: - saved_key = convert_megatron_to_HF_naming_style_single(saved_key, name_mapping) - saved_path = saved_dir / f"{prefix}.{saved_key}.bin" - val.tofile(saved_path.as_posix()) - - elif (key.find("self_attention.dense.weight") != -1 - or key.find("inter_attention.dense.weight") != -1 - or key.find("mlp.dense_4h_to_h.weight") != -1): + if tensor_para_rank == 0: + saved_keys = megatron2hf_name(saved_key, name_mapping) + saved_path = saved_dir / f"{prefix}.{saved_keys[0]}.bin" + LOGGER.debug( + "split for pp_rank=%d tp_rank=%d only for tp_rank=0 src_key=%s filename=%s " + "shape=%s (same as original) dtype=%s", + pipeline_para_rank, + tensor_para_rank, + key, + saved_path.name, + val.shape, + val.dtype, + ) + val.tofile(saved_path) + + elif ( + key.find("self_attention.dense.weight") != -1 + or key.find("inter_attention.dense.weight") != -1 + or key.find("mlp.dense_4h_to_h.weight") != -1 + ): split_vals = np.split(val, factor, axis=0) - saved_key = convert_megatron_to_HF_naming_style_single(saved_key, name_mapping) + saved_keys = megatron2hf_name(saved_key, name_mapping) for j in range(factor): - saved_path = saved_dir / f"{prefix}.{saved_key}.{i * factor + j:d}.bin" - split_vals[j].tofile(saved_path.as_posix()) - - elif key.find("mlp.dense_h_to_4h.weight") != -1 or key.find("mlp.dense_h_to_4h.bias") != -1: + saved_path = saved_dir / f"{prefix}.{saved_keys[0]}.{tensor_para_rank * factor + j:d}.bin" + LOGGER.debug( + "split for pp_rank=%d tp_rank=%d factor=%d src_key=%s filename=%s original_shape=%s shape=%s dtype=%s", + pipeline_para_rank, + tensor_para_rank, + factor, + key, + saved_path.name, + val.shape, + split_vals[j].shape, + split_vals[j].dtype, + ) + split_vals[j].tofile(saved_path) + + elif ( + key.find("mlp.dense_h_to_4h.weight") != -1 + or key.find("mlp.dense_h_to_4h.bias") != -1 + or key.find("mlp.dense_h_to_4h_2.weight") != -1 + or key.find("mlp.dense_h_to_4h_2.bias") != -1 + ): split_vals = np.split(val, factor, axis=-1) - saved_key = convert_megatron_to_HF_naming_style_single(saved_key, name_mapping) + saved_keys = megatron2hf_name(saved_key, name_mapping) for j in range(factor): - saved_path = saved_dir / 
f"{prefix}.{saved_key}.{i * factor + j:d}.bin" - split_vals[j].tofile(saved_path.as_posix()) + saved_path = saved_dir / f"{prefix}.{saved_keys[0]}.{tensor_para_rank * factor + j:d}.bin" + LOGGER.debug( + "split for pp_rank=%d tp_rank=%d factor=%d src_key=%s filename=%s original_shape=%s shape=%s dtype=%s", + pipeline_para_rank, + tensor_para_rank, + factor, + key, + saved_path.name, + val.shape, + split_vals[j].shape, + split_vals[j].dtype, + ) + split_vals[j].tofile(saved_path) - elif (key.find("self_attention.query_key_value.bias") != -1 - or key.find("inter_attention.query.bias") != -1 - or key.find("inter_attention.key_value.bias") != -1): + elif ( + key.find("self_attention.query_key_value.bias") != -1 + or key.find("inter_attention.query.bias") != -1 + or key.find("inter_attention.key_value.bias") != -1 + ): num_splits = 3 if key.find("inter_attention.key_value.bias") != -1: num_splits = 2 if key.find("inter_attention.query.bias") != -1: num_splits = 1 local_dim = int(val.shape[-1] / num_splits) - head_num = model_args['num_attention_heads'] // model_args['tensor_model_parallel_size'] - size_per_head = model_args['kv_channels'] # t5 kv_channels may not be equal to hidden_size // head_num + head_num = nemo_model_config["num_attention_heads"] // nemo_model_config["tensor_model_parallel_size"] + # t5 kv_channels may not be equal to hidden_size // head_num + size_per_head = nemo_model_config["kv_channels"] val = val.reshape(head_num, num_splits, size_per_head) val = val.transpose(1, 0, 2) val = val.reshape(num_splits, local_dim) split_vals = np.split(val, factor, axis=-1) - saved_keys = convert_megatron_to_HF_naming_style_multiple(saved_key, name_mapping) + saved_keys = megatron2hf_name(saved_key, name_mapping) for j in range(factor): if len(saved_keys) == 1: - saved_path = saved_dir / f"{prefix}.{saved_keys[0]}.{i * factor + j:d}.bin" - split_vals[j].tofile(saved_path.as_posix()) + saved_path = saved_dir / f"{prefix}.{saved_keys[0]}.{tensor_para_rank * factor + j:d}.bin" + LOGGER.debug( + "split for pp_rank=%d tp_rank=%d factor=%d src_key=%s filename=%s " + "preprocessed_shape=%s shape=%s dtype=%s", + pipeline_para_rank, + tensor_para_rank, + factor, + key, + saved_path.name, + val.shape, + split_vals[j].shape, + split_vals[j].dtype, + ) + split_vals[j].tofile(saved_path) continue for index in range(len(saved_keys)): - saved_path = saved_dir / f"{prefix}.{saved_keys[index]}.{i * factor + j:d}.bin" - split_vals[j][index, ...].tofile(saved_path.as_posix()) - - elif (key.find("self_attention.query_key_value.weight") != -1 - or key.find("inter_attention.query.weight") != -1 - or key.find("inter_attention.key_value.weight") != -1): + saved_path = saved_dir / f"{prefix}.{saved_keys[index]}.{tensor_para_rank * factor + j:d}.bin" + split_val_idxed = split_vals[j][index, ...] 
+ LOGGER.debug( + "split for pp_rank=%d tp_rank=%d factor=%d src_key=%s filename=%s " + "preprocessed_shape=%s shape=%s dtype=%s", + pipeline_para_rank, + tensor_para_rank, + factor, + key, + saved_path.name, + val.shape, + split_val_idxed.shape, + split_val_idxed.dtype, + ) + split_val_idxed.tofile(saved_path) + + elif ( + key.find("self_attention.query_key_value.weight") != -1 + or key.find("inter_attention.query.weight") != -1 + or key.find("inter_attention.key_value.weight") != -1 + ): num_splits = 3 if key.find("inter_attention.key_value.weight") != -1: num_splits = 2 @@ -347,280 +529,536 @@ def split_and_convert_process(model_type, i, pipeline_para_rank, saved_dir, fact hidden_dim = val.shape[0] local_dim = int(val.shape[-1] / num_splits) - head_num = model_args['num_attention_heads'] - size_per_head = model_args['kv_channels'] # t5 kv_channels may not be equal to hidden_size // head_num - head_num = head_num // model_args['tensor_model_parallel_size'] + head_num = nemo_model_config["num_attention_heads"] + # t5 kv_channels may not be equal to hidden_size // head_num + size_per_head = nemo_model_config["kv_channels"] + head_num = head_num // nemo_model_config["tensor_model_parallel_size"] val = val.reshape(hidden_dim, head_num, num_splits, size_per_head) val = val.transpose(0, 2, 1, 3) val = val.reshape(hidden_dim, num_splits, local_dim) split_vals = np.split(val, factor, axis=-1) - saved_keys = convert_megatron_to_HF_naming_style_multiple(saved_key, name_mapping) + saved_keys = megatron2hf_name(saved_key, name_mapping) for j in range(factor): if len(saved_keys) == 1: - saved_path = saved_dir / f"{prefix}.{saved_keys[0]}.{i * factor + j:d}.bin" - split_vals[j].tofile(saved_path.as_posix()) + saved_path = saved_dir / f"{prefix}.{saved_keys[0]}.{tensor_para_rank * factor + j:d}.bin" + LOGGER.debug( + "split for pp_rank=%d tp_rank=%d factor=%d src_key=%s filename=%s " + "preprocessed_shape=%s shape=%s dtype=%s", + pipeline_para_rank, + tensor_para_rank, + factor, + key, + saved_path.name, + val.shape, + split_vals[j].shape, + split_vals[j].dtype, + ) + split_vals[j].tofile(saved_path) continue for index in range(len(saved_keys)): - saved_path = saved_dir / f"{prefix}.{saved_keys[index]}.{i * factor + j:d}.bin" - split_vals[j][:, index, ...].tofile(saved_path.as_posix()) + saved_path = saved_dir / f"{prefix}.{saved_keys[index]}.{tensor_para_rank * factor + j:d}.bin" + split_val_idxed = split_vals[j][:, index, ...] 
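+                # Same split for the weight matrices: split_val_idxed keeps the hidden
+                # dimension and has shape (hidden_dim, local_dim // factor) for the
+                # index-th attention component.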
+ LOGGER.debug( + "split for pp_rank=%d tp_rank=%d factor=%d src_key=%s filename=%s " + "preprocessed_shape=%s shape=%s dtype=%s", + pipeline_para_rank, + tensor_para_rank, + factor, + key, + saved_path.name, + val.shape, + split_val_idxed.shape, + split_val_idxed.dtype, + ) + split_val_idxed.tofile(saved_path) else: - print(f"[ERROR] cannot find key '{key}'") - -def convert_checkpoint(args, model_config = None): - saved_dir = Path(args.saved_dir) / f"{args.infer_gpu_num:d}-gpu" - saved_dir.mkdir(parents=True, exist_ok=True) - - prefix = Path(args.in_file) if args.ckpt_type == "ckpt" else Path(args.saved_dir) - base_ckpt_name = "*last.ckpt" if args.ckpt_type == "ckpt" else "model_weights.ckpt" - - # load position_embedding from rank 0 - if (prefix).is_dir() and args.ckpt_type == "nemo" and model_config['tensor_model_parallel_size'] == 1: - model_00 = torch.load(os.path.join(args.saved_dir, base_ckpt_name), map_location=_gpu_map_location) - elif (prefix / "mp_rank_00").is_dir(): - ckpt_name = glob.glob((prefix / "mp_rank_00" / base_ckpt_name).as_posix())[0].split('/')[-1] - model_00 = torch.load((prefix / "mp_rank_00" / ckpt_name).as_posix(), map_location=_gpu_map_location) - elif (prefix / "mp_rank_00_000").is_dir(): - ckpt_name = glob.glob((prefix / "mp_rank_00_000" / base_ckpt_name).as_posix())[0].split('/')[-1] - model_00 = torch.load((prefix / "mp_rank_00_000" / ckpt_name).as_posix(), map_location=_gpu_map_location) + LOGGER.error(f"cannot find key '{key}'") + + +def convert_checkpoint(unpacked_checkpoints_dir: UnpackedNemoCheckpointDir, args: argparse.Namespace): + nemo_model_config = unpacked_checkpoints_dir.model_config + + if nemo_model_config.get("kv_channels", None) is None: + nemo_model_config["kv_channels"] = nemo_model_config["hidden_size"] // nemo_model_config["num_attention_heads"] + + inference_tensor_para_size = args.infer_gpu_num + + # if checkpoints files could be found - start preparing output dir + saved_dir = pathlib.Path(args.saved_dir) / f"{inference_tensor_para_size:d}-gpu" + if saved_dir.exists(): + LOGGER.error("Remove %s target directory before running conversion", saved_dir) + for file_path in saved_dir.rglob("*"): + LOGGER.debug(" %s", file_path.relative_to(saved_dir)) + sys.exit(1) + + saved_dir.mkdir(parents=True) + + checkpoints_paths = unpacked_checkpoints_dir.get_checkpoints_paths( + nemo_model_config.get("tensor_model_parallel_size", 1), + nemo_model_config.get("pipeline_model_parallel_size", 1), + ) + + LOGGER.debug("Expecting checkpoints paths in:") + for tp_rank_checkpoints_paths in checkpoints_paths: + for checkpoint_path in tp_rank_checkpoints_paths: + LOGGER.debug(" %s", checkpoint_path) + + map_location_fn = cpu_map_location if bool(args.load_checkpoints_to_cpu) else gpu_map_location + np_weight_data_type = WEIGHT2DTYPE[args.weight_data_type] + + has_gated_activations = False + + for pipeline_rank in range(len(checkpoints_paths[0])): + model_from_selected_pipeline = torch.load(checkpoints_paths[0][pipeline_rank], map_location=map_location_fn) + model_from_selected_pipeline = model_from_selected_pipeline.get("state_dict", model_from_selected_pipeline) + + LOGGER.debug(f"Existent pipeline_rank={pipeline_rank} keys:") + for key in model_from_selected_pipeline.keys(): + LOGGER.debug(" %s", key) + + encoder_ape_key = "enc_dec_model.encoder_embedding.position_embeddings.weight" + if encoder_ape_key in model_from_selected_pipeline.keys(): + saved_path = saved_dir / "shared.ape.bin" + # not weight, do not need to transpose + val = 
model_from_selected_pipeline[encoder_ape_key] + val = torch2np(val, np_weight_data_type) + LOGGER.debug( + "save for pp_rank=%d src_key=%s saved_keys=%s shape=%s dtype=%s", + pipeline_rank, + encoder_ape_key, + saved_path.name, + val.shape, + val.dtype, + ) + val.tofile(saved_path) + + has_gated_activations |= any("mlp.dense_h_to_4h_2" in key for key in model_from_selected_pipeline.keys()) + + def _split(src_key, dst_filename_fn): + if src_key in model_from_selected_pipeline.keys(): + _val = model_from_selected_pipeline[src_key] + _val = torch2np(_val, np_weight_data_type) + _val = np.split(_val, inference_tensor_para_size, axis=0) + for tensor_idx in range(inference_tensor_para_size): + saved_path = saved_dir / dst_filename_fn(tensor_idx) + LOGGER.debug( + "save for pp_rank=%d src_key=%s filename=%s shape=%s dtype=%s", + pipeline_rank, + src_key, + saved_path.name, + val.shape, + val.dtype, + ) + _val[tensor_idx].tofile(saved_path) + del _val + + # split rpe into tensor parallel ranks + _split( + "enc_dec_model.encoder_embedding.encoder_relative_position_embedding.weight", + lambda idx: f"encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight.{idx}.bin", + ) + _split( + "enc_dec_model.decoder_embedding.decoder_relative_position_embedding.weight", + lambda idx: f"decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight.{idx}.bin", + ) + + del model_from_selected_pipeline + + nemo_position_embedding_type = nemo_model_config.get("position_embedding_type", "absolute") + nenemo_position_embedding_type = ( + "absolute" if nemo_position_embedding_type == "learned_absolute" else nemo_position_embedding_type + ) + model_new_config = { + "structure": { + "t5_with_bias": str(True), + "position_embedding_type": nenemo_position_embedding_type, + "use_gated_activation": str(has_gated_activations), + } + } + + training_tensor_para_size = nemo_model_config.get("tensor_model_parallel_size", 1) + training_pipeline_para_size = nemo_model_config.get("pipeline_model_parallel_size", 1) + + if training_tensor_para_size > inference_tensor_para_size: + assert training_tensor_para_size % inference_tensor_para_size == 0 + is_merge_ckpt = True + factor = int(training_tensor_para_size / inference_tensor_para_size) else: - print(f"[ERROR] Cannot find checkpoint in {prefix}.") - exit(1) - - model_args = dict(model_00["hyper_parameters"]['cfg']) if args.ckpt_type == "ckpt" else model_config - - # checkpoint weights - model00_state = model_00['state_dict'] if args.ckpt_type == "ckpt" else model_00 - - # update model structure config - if 'position_embedding_type' not in model_args.keys() or model_args['position_embedding_type'] == "absolute": - model_new_config["structure"]["position_embedding_type"] = 1 - - ## 'pipeline_model_parallel_size' - if 'pipeline_model_parallel_size' not in model_args.keys(): - model_args['pipeline_model_parallel_size'] = 1 - - config = configparser.ConfigParser() - - config["encoder"] = {} - config["decoder"] = {} - for key, val in model_args.items(): - if key in encoder_config_mapping.keys(): - if key == "kv_channels" and val == None: - val = int(config["encoder"]["d_model"]) // int(config["encoder"]["num_heads"]) - config["encoder"][encoder_config_mapping[key]] = f"{val}" - if key in decoder_config_mapping.keys(): - if key == "kv_channels" and val == None: - val = int(config["decoder"]["d_model"]) // int(config["decoder"]["num_heads"]) - config["decoder"][decoder_config_mapping[key]] = f"{val}" - - # vocab is not stored in the config by default, we need to get it 
from word embedding's shape - vocab_size = model00_state['enc_dec_model.encoder_embedding.word_embeddings.weight'].shape[0] - config["encoder"]["vocab_size"] = f"{vocab_size}" - config["decoder"]["vocab_size"] = f"{vocab_size}" - for key, val in decoder_new_config.items(): - config["decoder"][key] = f"{val}" - for key, val in model_new_config.items(): - config[key] = {} - for val_key, val_val in val.items(): - config[key][val_key] = f"{val_val}" - - # add model name - config["encoder"]["_name_or_path"] = args.model_name - config["decoder"]["_name_or_path"] = args.model_name - - # add weight data type - config["encoder"]["weight_data_type"] = args.weight_data_type - config["decoder"]["weight_data_type"] = args.weight_data_type - - np_weight_data_type = get_weight_data_type(args.weight_data_type) - - with open((saved_dir / f"config.ini").as_posix(), 'w') as configfile: - config.write(configfile) + assert inference_tensor_para_size % training_tensor_para_size == 0 + is_merge_ckpt = False + factor = int(inference_tensor_para_size / training_tensor_para_size) - if "enc_dec_model.encoder_embedding.position_embeddings.weight" in model00_state.keys(): - model00_state["enc_dec_model.encoder_embedding.position_embeddings.weight"] \ - .float().cpu().numpy().astype(np_weight_data_type) \ - .tofile((saved_dir / "shared.ape.bin").as_posix()) + assert nemo_model_config["ffn_hidden_size"] % inference_tensor_para_size == 0 + assert nemo_model_config["num_attention_heads"] % inference_tensor_para_size == 0 - # inference factor calculation - t_gpu_num = model_args['tensor_model_parallel_size'] - i_gpu_num = args.infer_gpu_num + main_loop = min(training_tensor_para_size, inference_tensor_para_size) - if t_gpu_num > i_gpu_num: - assert t_gpu_num % i_gpu_num == 0 - is_merge_ckpt = True - factor = int(t_gpu_num / i_gpu_num) - else: - assert i_gpu_num % t_gpu_num == 0 - is_merge_ckpt = False - factor = int(i_gpu_num / t_gpu_num) - - main_loop = min(t_gpu_num, i_gpu_num) - - # split rpe into tensor parallel ranks - encoder_rpe_key = "enc_dec_model.encoder_embedding.encoder_relative_position_embedding.weight" - if encoder_rpe_key in model00_state.keys(): - encoder_relative_position_embedding = model00_state[encoder_rpe_key] \ - .T.float().cpu().numpy().astype(np_weight_data_type) - encoder_relative_position_embedding_split = np.split(encoder_relative_position_embedding, i_gpu_num, axis=0) - for i in range(i_gpu_num): - saved_path = saved_dir / f"encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight.{i}.bin" - encoder_relative_position_embedding_split[i].tofile(saved_path.as_posix()) - - del encoder_relative_position_embedding, encoder_relative_position_embedding_split - - decoder_rpe_key = "enc_dec_model.decoder_embedding.decoder_relative_position_embedding.weight" - if decoder_rpe_key in model00_state.keys(): - decoder_relative_position_embedding = model00_state[decoder_rpe_key] \ - .T.float().cpu().numpy().astype(np_weight_data_type) - decoder_relative_position_embedding_split = np.split(decoder_relative_position_embedding, i_gpu_num, axis=0) - for i in range(i_gpu_num): - saved_path = saved_dir / f"decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight.{i}.bin" - decoder_relative_position_embedding_split[i].tofile(saved_path.as_posix()) - - del decoder_relative_position_embedding, decoder_relative_position_embedding_split - - del model_00 w_e_list = [] + lm_head_list = [] torch.multiprocessing.set_start_method("spawn") - pool = multiprocessing.Pool(args.processes) - for i in 
range(main_loop): - for j in range(model_args['pipeline_model_parallel_size']): - if model_args['pipeline_model_parallel_size'] == 1: - layer_rank_num = "" - else: - layer_rank_num = f"_{j:03d}" - - encoder_models = [] - decoder_models = [] - - if is_merge_ckpt == True: - for k in range(factor): - ckpt_name = glob.glob((prefix / f"mp_rank_{i * factor + k:02d}{layer_rank_num}" / base_ckpt_name).as_posix())[0].split('/')[-1] - ckpt_path = glob.glob((prefix / f"mp_rank_{i * factor + k:02d}{layer_rank_num}" / ckpt_name).as_posix())[0] - m = torch.load(ckpt_path, map_location=_gpu_map_location) - m = m['state_dict'] if args.ckpt_type == "ckpt" else m - encoder_models_dict = {} - decoder_models_dict = {} - for key, val in m.items(): - encoder_prefix = "enc_dec_model.enc_dec_model.encoder.model." - decoder_prefix = "enc_dec_model.enc_dec_model.decoder.model." - if key.find(encoder_prefix) != -1: - encoder_models_dict[key.split(encoder_prefix, 1)[1]] = val - elif key.find(decoder_prefix) != -1: - decoder_models_dict[key.split(decoder_prefix, 1)[1]] = val - encoder_models.append(encoder_models_dict) - decoder_models.append(decoder_models_dict) - - if j == 0: - w_e_list.append(m["enc_dec_model.encoder_embedding.word_embeddings.weight"].float().cpu().numpy().astype(np_weight_data_type)) - else: - if t_gpu_num == 1 and args.ckpt_type == "nemo": - ckpt_path = glob.glob((prefix / ckpt_name).as_posix())[0] + torch.multiprocessing.set_sharing_strategy("file_system") + with concurrent.futures.ProcessPoolExecutor(args.processes) as pool: + for tp_idx in range(main_loop): + for pp_idx in range(training_pipeline_para_size): + + encoder_models = [] + decoder_models = [] + + word_embedding_key = "enc_dec_model.encoder_embedding.word_embeddings.weight" + lm_head_bias_key = "enc_dec_model.tokens_head.bias" + if is_merge_ckpt: + for k in range(factor): + rank_weights = checkpoints_paths[tp_idx * factor + k][pp_idx] + model = torch.load(rank_weights, map_location=map_location_fn) + + if pp_idx == 0: + w_e_val = model.get("state_dict", model)[word_embedding_key] + w_e_val = torch2np(w_e_val, np_weight_data_type) + w_e_list.append(w_e_val) + if pp_idx == training_pipeline_para_size - 1: + lm_head_val = model.get("state_dict", model)[lm_head_bias_key] + lm_head_val = torch2np(lm_head_val, np_weight_data_type) + lm_head_list.append(lm_head_val) + + encoder_models.append( + extract_layers_with_prefix(model, "enc_dec_model.enc_dec_model.encoder.model.") + ) + decoder_models.append( + extract_layers_with_prefix(model, "enc_dec_model.enc_dec_model.decoder.model.") + ) + LOGGER.debug( + "For pp_idx=%d tp_id=%d merging weights from %s extracted:", pp_idx, tp_idx, rank_weights + ) + LOGGER.debug(" encoder layers") + for name in encoder_models[-1]: + LOGGER.debug(" %s", name) + LOGGER.debug(" decoder layers") + for name in decoder_models[-1]: + LOGGER.debug(" %s", name) else: - ckpt_name = glob.glob((prefix / f"mp_rank_{i:02d}{layer_rank_num}" / base_ckpt_name).as_posix())[0].split('/')[-1] - ckpt_path = glob.glob((prefix / f"mp_rank_{i:02d}{layer_rank_num}" / ckpt_name).as_posix())[0] - m = torch.load(ckpt_path, map_location=_gpu_map_location) - m = m['state_dict'] if args.ckpt_type == "ckpt" else m - - if j == 0: - w_e_list.append( - m["enc_dec_model.encoder_embedding.word_embeddings.weight"] - .float() - .cpu() - .numpy() - .astype(np_weight_data_type) + rank_weights = checkpoints_paths[tp_idx][pp_idx] + model = torch.load(rank_weights, map_location=map_location_fn) + + if pp_idx == 0: + w_e_val = model.get("state_dict", 
model)[word_embedding_key] + w_e_val = torch2np(w_e_val, np_weight_data_type) + w_e_list.append(w_e_val) + if pp_idx == training_pipeline_para_size - 1: + lm_head_val = model.get("state_dict", model)[lm_head_bias_key] + lm_head_val = torch2np(lm_head_val, np_weight_data_type) + lm_head_list.append(lm_head_val) + + encoder_models.append( + extract_layers_with_prefix(model, "enc_dec_model.enc_dec_model.encoder.model.") ) - - encoder_models_dict = {} - decoder_models_dict = {} - for key, val in m.items(): - encoder_prefix = "enc_dec_model.enc_dec_model.encoder.model." - decoder_prefix = "enc_dec_model.enc_dec_model.decoder.model." - if key.find(encoder_prefix) != -1: - encoder_models_dict[key.split(encoder_prefix, 1)[1]] = val - elif key.find(decoder_prefix) != -1: - decoder_models_dict[key.split(decoder_prefix, 1)[1]] = val - encoder_models.append(encoder_models_dict) - decoder_models.append(decoder_models_dict) - - pool.starmap( - merge_and_convert_process if is_merge_ckpt == True else split_and_convert_process, - [ - ( + decoder_models.append( + extract_layers_with_prefix(model, "enc_dec_model.enc_dec_model.decoder.model.") + ) + LOGGER.debug( + "For pp_idx=%d tp_id=%d copy/splitting weights from %s extracted:", pp_idx, tp_idx, rank_weights + ) + LOGGER.debug(" encoder layers") + for name in encoder_models[-1]: + LOGGER.debug(" %s", name) + LOGGER.debug(" decoder layers") + for name in decoder_models[-1]: + LOGGER.debug(" %s", name) + + process_fn = merge_and_convert_process if is_merge_ckpt else split_and_convert_process + + for key in encoder_models[0]: + pool.submit( + process_fn, "encoder", - i, - j, + tp_idx, # tp_rank + pp_idx, # pp_rank saved_dir, factor, - k, - model_args, + key, + nemo_model_config, encoder_models, - np_weight_data_type + np_weight_data_type, ) - for k in encoder_models[0].keys() - ], - ) - pool.starmap( - merge_and_convert_process if is_merge_ckpt == True else split_and_convert_process, - [ - ( + + for key in decoder_models[0]: + pool.submit( + process_fn, "decoder", - i, - j, + tp_idx, # tp_rank + pp_idx, # pp_rank saved_dir, factor, - k, - model_args, + key, + nemo_model_config, decoder_models, - np_weight_data_type + np_weight_data_type, ) - for k in decoder_models[0].keys() - ], - ) - pool.close() - pool.join() + w_e_saved_path = saved_dir / "shared.weight_T.bin" + lm_head_weight_saved_path = saved_dir / "lm_head.weight.bin" + lm_head_saved_path = saved_dir / "shared.bias.bin" + w_e_val = np.concatenate(w_e_list, axis=0) + lm_head_val = np.concatenate(lm_head_list, axis=0) + LOGGER.debug( + "save for src_key=%s filename=%s shape=%s dtype=%s", + word_embedding_key, + w_e_saved_path.name, + w_e_val.shape, + w_e_val.dtype, + ) + LOGGER.debug( + "save for src_key=%s filename=%s shape=%s dtype=%s", + lm_head_bias_key, + lm_head_saved_path.name, + lm_head_val.shape, + lm_head_val.dtype, + ) + w_e_val.tofile(w_e_saved_path) + w_e_val.tofile(lm_head_weight_saved_path) + lm_head_val.tofile(lm_head_saved_path) + + vocab_size = w_e_val.shape[0] - np.concatenate(w_e_list, axis=0).tofile((saved_dir / "shared.weight_T.bin").as_posix()) - m["enc_dec_model.tokens_head.bias"].float().cpu().numpy().tofile((saved_dir / "shared.bias.bin").as_posix()) + config = configparser.ConfigParser() -if __name__ == "__main__": + config["encoder"] = { + **{ + "_name_or_path": args.model_name, + "model_type": "T5", + "weight_data_type": args.weight_data_type, + "tensor_para_size": str(inference_tensor_para_size), + "vocab_size": str(vocab_size), + }, + **{ + encoder_config_mapping[key]: 
str(val) + for key, val in nemo_model_config.items() + if key in encoder_config_mapping + }, + } + + tokenizer_config = nemo_model_config["tokenizer"] + tokenizer_config = _update_tokenizer_config(tokenizer_config, unpacked_checkpoints_dir) + if args.tokenizer_model_path: + LOGGER.debug("Use tokenizer model passed from CLI: %s", args.tokenizer_model_path) + tokenizer_config["model"] = args.tokenizer_model_path + if args.vocab_path: + LOGGER.debug("Use tokenizer vocab passed from CLI: %s", args.vocab_path) + tokenizer_config["vocab_file"] = args.vocab_path + if args.merges_path: + LOGGER.debug("Use tokenizer merge passed from CLI: %s", args.merges_path) + tokenizer_config["merge_file"] = args.merges_path + + _copy_tokenizer_file_if_defined("model", tokenizer_config["model"], saved_dir) + _copy_tokenizer_file_if_defined("vocab_file", tokenizer_config["vocab_file"], saved_dir) + _copy_tokenizer_file_if_defined("merge_file", tokenizer_config["merge_file"], saved_dir) + + bos_id, eos_id = _get_special_tokens_ids(tokenizer_config) + + config["decoder"] = { + **{ + "_name_or_path": args.model_name, + "model_type": "T5", + "weight_data_type": args.weight_data_type, + "tensor_para_size": str(inference_tensor_para_size), + "vocab_size": str(vocab_size), + "decoder_start_token_id": str(bos_id), + "eos_token_id": str(eos_id), + }, + **{ + decoder_config_mapping[key]: str(val) + for key, val in nemo_model_config.items() + if key in decoder_config_mapping + }, + } + + for section, section_dict in model_new_config.items(): + config[section] = {k: str(v) for k, v in section_dict.items()} + + with (saved_dir / f"config.ini").open("w") as configfile: + config.write(configfile) + + +def _update_tokenizer_config(tokenizer_config: typing.Dict, unpacked_checkpoints_dir): + def _update_config_entry(key, file_pattern): + old_file_path = tokenizer_config[key] + if old_file_path: + LOGGER.debug("tokenizer %s %s type %s", key, old_file_path, type(old_file_path)) + old_file_path = pathlib.Path(old_file_path) + new_file_path = unpacked_checkpoints_dir.get_tokenizer_file_path("tokenizer", key, file_pattern) + if new_file_path: + LOGGER.debug("Update tokenizer %s %s -> %s", key, old_file_path, new_file_path) + tokenizer_config[key] = new_file_path.as_posix() + elif not old_file_path.exists(): + LOGGER.warning("Because tokenizer %s %s does not exists - set it as None", key, old_file_path) + tokenizer_config[key] = None + + _update_config_entry("model", "*.model") + _update_config_entry("vocab_file", "*vocab*") + _update_config_entry("merge_file", "*merge*.txt") + + return tokenizer_config + + +def _copy_tokenizer_file_if_defined(key_name, tokenizer_file_path, saved_dir): + if tokenizer_file_path: + tokenizer_file_path = pathlib.Path(tokenizer_file_path) + if tokenizer_file_path.exists(): + tokenizer_basename = { + "model": "tokenizer", + "vocab_file": "vocab", + "merge_file": "merges", + }[key_name] + dst_path = saved_dir / f"{tokenizer_basename}{tokenizer_file_path.suffix}" + LOGGER.debug("Copy of %s %s file as %s", tokenizer_file_path, key_name, dst_path) + shutil.copy(tokenizer_file_path.as_posix(), dst_path.as_posix()) + else: + LOGGER.debug("%s %s file does not exists", tokenizer_file_path, key_name) + + +def _get_special_tokens_ids(tokenizer_config: typing.Dict): + from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer + from examples.pytorch.tokenizer import add_special_tokens_to_tokenizer + + logging.getLogger("git.cmd").setLevel(logging.INFO) + 
logging.getLogger("h5py._conv").setLevel(logging.INFO) + logging.getLogger("matplotlib").setLevel(logging.INFO) + logging.getLogger("matplotlib.font_manager").setLevel(logging.INFO) + logging.getLogger("matplotlib.pyplot").setLevel(logging.INFO) + + tokenizer = get_nmt_tokenizer( + library=tokenizer_config["library"], + model_name=tokenizer_config["type"], + tokenizer_model=tokenizer_config["model"], + vocab_file=tokenizer_config["vocab_file"], + merges_file=tokenizer_config["merge_file"], + legacy=True, + ) + + if tokenizer_config["library"] == "sentencepiece": + add_special_tokens_to_tokenizer(tokenizer) + + bos_id = tokenizer.bos_id + eos_id = tokenizer.eos_id + + LOGGER.debug("for %s obtained tokenizer tokens ids bos_id=%d eos_id=%d", tokenizer_config, bos_id, eos_id) + + return bos_id, eos_id + + +def main(): parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) - parser.add_argument("-saved_dir", "-o", type=str, help="file name of output file", required=True) - parser.add_argument("-in_file", "-i", type=str, help="file name of input checkpoint file", required=True) - parser.add_argument("-infer_gpu_num", "-i_g", type=int, help="How many gpus for inference", required=True) - parser.add_argument("-processes", "-p", type=int, help="How many processes to spawn for conversion (default: 64)", default=64) - parser.add_argument("-ckpt_type", "-ct", type=str, choices=['nemo', 'ckpt'], help="checkpoint type. nemo or ckpt", default="nemo") - parser.add_argument("-weight_data_type", type=str, default="fp32", choices=["fp32", "fp16"]) - parser.add_argument("-model_name", "-m", type=str, help="model name", required=True) + parser.add_argument( + "--saved-dir", + "-saved_dir", + "-o", + help="folder name of output files", + required=True, + ) + parser.add_argument( + "--in-file", + "-in_file", + "-i", + help="file name of .nemo checkpoint file or checkpoint dir", + required=True, + ) + parser.add_argument( + "--infer-gpu-num", + "-infer_gpu_num", + "-i_g", + type=int, + help="How many gpus for inference", + required=True, + ) + parser.add_argument( + "--processes", + "-processes", + "-p", + type=int, + default=64, + help="How many processes to spawn for conversion", + ) + parser.add_argument( + "--weight-data-type", + "-weight_data_type", + choices=["fp32", "fp16"], + default="fp32", + help="Data type of results weights", + ) + parser.add_argument( + "--model-name", + "-model_name", + "-m", + help="model name", + required=True, + ) + parser.add_argument( + "--vocab-path", + help="Path to vocabulary file to embed in FasterTransformer checkpoint", + required=False, + ) + parser.add_argument( + "--merges-path", + help="Path to merges file to embed in FasterTransformer checkpoint", + required=False, + ) + parser.add_argument( + "--tokenizer-model-path", + help="Path to tokenizer model file to embed in FasterTransformer checkpoint", + required=False, + ) + parser.add_argument( + "--load-checkpoints-to-cpu", + "-load_checkpoints_to_cpu", + "-cpu", + type=int, + choices=[0, 1], + default=1, + help="Whether to load model weights to CPU", + ) + parser.add_argument("--verbose", action="store_true", help="Provide verbose messages") args = parser.parse_args() + + log_format = "%(asctime)s %(name)s [%(levelname)s] %(message)s" + logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO, format=log_format) + print("\n=============== Argument ===============") for key in vars(args): print(f"{key}: {vars(args)[key]}") print("========================================") - ## 
unpack .nemo format if specified - if (args.ckpt_type == "nemo"): - model_config_yaml = "model_config.yaml" - config_yaml = os.path.join(args.saved_dir, model_config_yaml) + input_path = pathlib.Path(args.in_file) + if not input_path.exists(): + LOGGER.error("%s does not exists", input_path) + sys.exit(1) + + with tempfile.TemporaryDirectory() as temp_dir: + temp_dir = pathlib.Path(temp_dir) + + # unpack if needed + if input_path.is_file(): + checkpoint_dir_path = temp_dir / "unpacked" + start_time = datetime.datetime.now() + unpacked_checkpoint_dir = UnpackedNemoCheckpointDir( + unpack_nemo_ckpt(args.in_file, checkpoint_dir_path), + load_checkpoints_to_cpu=bool(args.load_checkpoints_to_cpu), + ) + LOGGER.info("Spent %s (h:m:s) to unpack NeMo archive", datetime.datetime.now() - start_time) + else: + unpacked_checkpoint_dir = UnpackedNemoCheckpointDir( + input_path, load_checkpoints_to_cpu=bool(args.load_checkpoints_to_cpu) + ) + + LOGGER.debug("Unpacked NeMo checkpoint contains:") + for file_path in unpacked_checkpoint_dir.checkpoints_dir.rglob("*"): + LOGGER.debug(" %s", file_path) - # unpack_nemo_ckpt(args.in_file, args.saved_dir) + start_time = datetime.datetime.now() + convert_checkpoint(unpacked_checkpoint_dir, args) + LOGGER.info("Spent %s (h:m:s) to convert the model", datetime.datetime.now() - start_time) - with open(config_yaml) as f: - model_config = yaml.full_load(f) - start_time = datetime.now() - convert_checkpoint(args, model_config) - stop_time = datetime.now() - run_time = (stop_time - start_time) - print("[INFO] Spend {} (h:m:s) to convert the model".format(run_time)) - else: - start_time = datetime.now() - convert_checkpoint(args) - stop_time = datetime.now() - run_time = (stop_time - start_time) - print("[INFO] Spend {} (h:m:s) to convert the model".format(run_time)) +if __name__ == "__main__": + main() diff --git a/examples/pytorch/t5/xnli_task_example.py b/examples/pytorch/t5/xnli_task_example.py new file mode 100644 index 000000000..61ee56859 --- /dev/null +++ b/examples/pytorch/t5/xnli_task_example.py @@ -0,0 +1,416 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
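+"""Evaluate a NeMo Megatron T5/mT5 checkpoint converted to FasterTransformer on the XNLI benchmark.
+
+The script builds the XNLI text-to-text dataset with NeMo, runs generation through the
+FasterTransformer T5 encoder/decoder ops (FTT5) and reports exact-string-match accuracy
+per language, optionally dumping predictions and metrics to a JSON file.
+"""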
+ +import argparse +import configparser +import dataclasses +import json +import os +import pathlib +import time + +import numpy as np +import torch +import torch.distributed as dist +from tqdm import tqdm + + +from omegaconf.omegaconf import OmegaConf +from nemo.collections.nlp.data.glue_benchmark.glue_benchmark_dataset import ( + TextToTextGLUEDataset, + TextToTextXNLIDataset, +) +from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer +from nemo.collections.common.metrics.classification_accuracy import ExactStringPerCategoryMatchMetric + +from examples.pytorch.t5.utils.ft_encoder import FTT5EncoderWeight, FTT5Encoder +from examples.pytorch.t5.utils.ft_decoding import FTT5DecodingWeight, FTT5Decoding, FTT5 +from examples.pytorch.tokenizer import add_special_tokens_to_tokenizer + + +def _build_dataset(data_cfg, tokenizer): + if data_cfg.task_name == 'xnli': + dataset = TextToTextXNLIDataset( + data_cfg.file_path, + task_name=data_cfg.task_name, + tokenizer=tokenizer, + max_seq_length=data_cfg.max_seq_length, + lang_list=data_cfg.eval_languages, + ) + else: + dataset = TextToTextGLUEDataset( + data_cfg.file_path, + task_name=data_cfg.task_name, + tokenizer=tokenizer, + max_seq_length=data_cfg.max_seq_length, + ) + return dataset + + +@dataclasses.dataclass +class Metric: + acc: float + + +@dataclasses.dataclass +class RequestAndResult: + model_answer: str + target: str + lang: str + metrics: Metric + + +def preds_and_labels_to_text(tokenizer, preds, labels): + preds = preds.cpu().numpy().tolist() + labels = labels.cpu().numpy().tolist() + # preds = [pred[0] for pred in preds] + + preds_text, labels_text = [], [] + for _, (pred, label) in enumerate(zip(preds, labels)): + if tokenizer.eos_id in pred: + idx = pred.index(tokenizer.eos_id) + pred = pred[:idx] + + # Legacy sentencepiece detokenization still preserves special tokens which messes up exact string match. 
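+        # Strip any remaining special-token ids (tokenizer.special_token_to_id) before
+        # calling ids_to_text() so the exact-match comparison sees plain text only.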
+ if hasattr(tokenizer, 'special_token_to_id'): + pred = [id for id in pred if id not in tokenizer.special_token_to_id.values()] + label = [id for id in label if id not in tokenizer.special_token_to_id.values()] + pred = tokenizer.ids_to_text(pred) + label = tokenizer.ids_to_text(label) + preds_text.append(pred) + labels_text.append(label) + + return preds_text, labels_text + + +def accuracy_score(pred, ref): + assert len(pred) == len(ref) + total = len(pred) + correct = 0 + for p, r in zip(pred, ref): + if p in r: + correct += 1 + # else: + # print(f"[pred]: {p} [label]: {r}") + print(f"[total_acc] {correct / total}") + return correct / total + + +class InputToken: + def __init__(self, input_ids, attention_mask): + self.input_ids = input_ids + self.attention_mask = attention_mask + + +class EncoderDecoderConfig: + def __init__(self, d_model, vocab_size, num_heads, d_kv, d_ff, num_layers, + relative_attention_num_buckets_or_max_pos_seq_len, decoder_start_token_id=0, decoder_end_token_id=1): + self.d_model = d_model + self.vocab_size = vocab_size + self.num_heads = num_heads + self.d_kv = d_kv + self.d_ff = d_ff + self.num_layers = num_layers + self.relative_attention_num_buckets = relative_attention_num_buckets_or_max_pos_seq_len + self.decoder_start_token_id = decoder_start_token_id + self.decoder_end_token_id = decoder_end_token_id + + +data_type_mapping = {"fp32": 0, "fp16": 1, "bf16": 2} + +def xnli_task(args_dict): + torch.set_printoptions(precision=6) + batch_size = args_dict['batch_size'] + beam_size = args_dict['beam_width'] + max_output_len = args_dict['max_output_len'] + beam_search_diversity_rate = args_dict['beam_search_diversity_rate'] + topk = args_dict['sampling_topk'] + topp = args_dict['sampling_topp'] + tensor_para_size = args_dict['tensor_para_size'] + pipeline_para_size = args_dict['pipeline_para_size'] + + if args_dict['ckpt_path'] is None: + raise Exception("Megatron T5 model needs to specify checkpoint path !") + + if dist.is_mpi_available(): + try: + dist.init_process_group(backend='mpi') + rank = dist.get_rank() + except: + rank = dist.get_rank() + else: + rank = 0 + + assert dist.get_world_size() == tensor_para_size * pipeline_para_size + + ckpt_path = args_dict['ckpt_path'] + ## read checkpoint config if exists + ckpt_config = configparser.ConfigParser() + + if args_dict['ckpt_path'] is None: + raise Exception("Megatron T5 model needs to specify checkpoint path !") + + tokenizer_model_path = os.path.join(ckpt_path, "tokenizer.model") + ckpt_config_path = os.path.join(ckpt_path, 'config.ini') + if os.path.isfile(ckpt_config_path): + ckpt_config.read(ckpt_config_path) + ## update structure config + t5_with_bias = ckpt_config.getboolean('structure', 't5_with_bias') + ## megatron with bias and use absolute position embedding + ## relative position embedding -> 0, absolute position embedding -> 1 + position_embedding_type = 0 if ckpt_config.get('structure', 'position_embedding_type') == 'relative' else 1 + use_gated_activation = ckpt_config.getboolean('structure', 'use_gated_activation') + weight_data_type = {"fp16": np.float16, "fp32": np.float32}[ckpt_config.get("encoder", "weight_data_type")] + activation_type = ckpt_config.get('encoder', 'feed_forward_proj') + assert ckpt_config.getint("encoder", "tensor_para_size") == tensor_para_size + else: + raise Exception("config file does exist with the ckpt !") + + if rank == 0: + print("\n=============== Argument ===============") + for key in args_dict: + print("{}: {}".format(key, args_dict[key])) + 
print("========================================") + + lib_path = args_dict['lib_path'] + + #xnli + tokenizer_mt5 = get_nmt_tokenizer( + library='sentencepiece', + model_name=None, + tokenizer_model=tokenizer_model_path, + vocab_file=None, + merges_file=None, + legacy=True, + ) + add_special_tokens_to_tokenizer(tokenizer_mt5) + + assert tokenizer_mt5.bos_id == ckpt_config.getint("decoder", "decoder_start_token_id") + assert tokenizer_mt5.eos_id == ckpt_config.getint("decoder", "eos_token_id") + + token_params = { + tokenizer_mt5.bos_token: tokenizer_mt5.bos_id, + tokenizer_mt5.eos_token: tokenizer_mt5.eos_id, + tokenizer_mt5.pad_token: tokenizer_mt5.pad_id, + } + print(f"tokenizer special tokens: {token_params}") + + xnli_cfg = OmegaConf.create({ + "file_path": args_dict['data_path'], + "task_name": "xnli", + "max_seq_length": 512, + "eval_languages": ['en', 'es', 'de', 'fr'] + }) + xnli_dataset = _build_dataset(xnli_cfg, tokenizer_mt5) + + data_loader = torch.utils.data.DataLoader( + xnli_dataset, + collate_fn=xnli_dataset.collate_fn, + batch_size=batch_size, + num_workers=1, + pin_memory=False, + drop_last=True) + + q_scaling = 1.0 + + encoder_config = EncoderDecoderConfig(ckpt_config.getint('encoder', 'd_model'), + ckpt_config.getint('encoder', 'vocab_size'), + ckpt_config.getint('encoder', 'num_heads'), + ckpt_config.getint('encoder', 'd_kv'), + ckpt_config.getint('encoder', 'd_ff'), + ckpt_config.getint('encoder', 'num_layers'), + ckpt_config.getint('encoder', 'relative_attention_num_buckets_or_max_pos_seq_len') + ) + + decoder_config = EncoderDecoderConfig(ckpt_config.getint('decoder', 'd_model'), + ckpt_config.getint('decoder', 'vocab_size'), + ckpt_config.getint('decoder', 'num_heads'), + ckpt_config.getint('decoder', 'd_kv'), + ckpt_config.getint('decoder', 'd_ff'), + ckpt_config.getint('decoder', 'num_layers'), + ckpt_config.getint('decoder', 'relative_attention_num_buckets_or_max_pos_seq_len'), + tokenizer_mt5.bos_id, + tokenizer_mt5.eos_id + ) + + ## run gemm test + if os.path.isfile("gemm_config.in") and rank == 0: + cmd = f"rm gemm_config.in" + print(f"Run {cmd}") + os.system(cmd) + if rank == 0: + data_type = data_type_mapping[args_dict['data_type']] + cmd = f"./bin/t5_gemm {batch_size // pipeline_para_size} {beam_size} {128} " \ + f"{encoder_config.d_model} {encoder_config.num_heads} {encoder_config.d_kv} {encoder_config.d_ff} " \ + f"{decoder_config.d_model} {decoder_config.num_heads} {decoder_config.d_kv} {decoder_config.d_ff} " \ + f"{decoder_config.vocab_size} {data_type} {tensor_para_size} 1 > .tmp_gemm.log" + print(f"Run gemm test: {cmd}") + os.system(cmd) + + dist.barrier() + + ft_encoder_weight = FTT5EncoderWeight( + encoder_config, + tensor_para_size, + pipeline_para_size, + t5_with_bias=t5_with_bias, + use_gated_activation=use_gated_activation, + position_embedding_type=position_embedding_type, + weight_data_type=weight_data_type, + ) + ft_decoding_weight = FTT5DecodingWeight( + decoder_config, + tensor_para_size, + pipeline_para_size, + t5_with_bias=t5_with_bias, + use_gated_activation=use_gated_activation, + position_embedding_type=position_embedding_type, + weight_data_type=weight_data_type, + ) + + ft_encoder_weight.load_from_bin(args_dict["ckpt_path"]) + ft_decoding_weight.load_from_bin(args_dict["ckpt_path"]) + + if args_dict['data_type'] == 'fp16': + ft_encoder_weight.to_half() + ft_decoding_weight.to_half() + elif args_dict['data_type'] == 'fp32': + ft_encoder_weight.to_single() + ft_decoding_weight.to_single() + elif args_dict['data_type'] == 'bf16': + 
ft_encoder_weight.to_bfloat16() + ft_decoding_weight.to_bfloat16() + + remove_padding = True if batch_size > 32 else False + ft_encoder = FTT5Encoder(ft_encoder_weight.w, lib_path, encoder_config.num_heads, + encoder_config.d_kv, encoder_config.d_ff, + encoder_config.d_model, remove_padding, encoder_config.num_layers, + encoder_config.relative_attention_num_buckets, + 128, False, q_scaling, tensor_para_size, pipeline_para_size, t5_with_bias, position_embedding_type, activation_type) + ft_decoding = FTT5Decoding(ft_decoding_weight.w, lib_path, + decoder_config.num_heads, decoder_config.d_kv, + decoder_config.d_ff, encoder_config.d_model, + decoder_config.d_model, decoder_config.num_layers, + decoder_config.decoder_start_token_id, decoder_config.decoder_end_token_id, + decoder_config.vocab_size, + q_scaling, + decoder_config.relative_attention_num_buckets, max_distance=128, + tensor_para_size=tensor_para_size, pipeline_para_size=pipeline_para_size, + t5_with_bias=t5_with_bias, activation_type=activation_type, position_embedding_type=position_embedding_type) + + ft_t5 = FTT5(ft_encoder, ft_decoding) + + #metric + languages = ['de','en','es','fr'] + acc_metric = ExactStringPerCategoryMatchMetric(languages) + + preds_list = [] + labels_list = [] + results_list = [] + start = time.time() + for idx, batch in tqdm(enumerate(data_loader)): + input_token = InputToken(batch['text_enc'], batch['enc_mask']) + ft_decoding_outputs, ft_decoding_seq_lens = ft_t5(input_token, + None, + beam_size, + max_output_len, + topk, + topp, + beam_search_diversity_rate=beam_search_diversity_rate, + is_return_output_log_probs=args_dict["return_output_log_probs"], + is_return_cum_log_probs=args_dict["return_cum_log_probs"]) + ft_decoding_outputs = ft_decoding_outputs.squeeze() + preds, labels = preds_and_labels_to_text(tokenizer_mt5, torch.IntTensor(ft_decoding_outputs), batch['labels']) + langs = batch['lang'] + for _, (pred, label, lang) in enumerate(zip(preds, labels, langs)): + _ = acc_metric(pred, label, lang) + labels_list += labels + preds_list += preds + + results_list.extend([ + RequestAndResult( + model_answer=pred, + target=label, + lang=lang, + metrics=Metric(acc=(pred == label)) + ) + for lang, pred, label in zip(langs, preds, labels) + ]) + + end = time.time() + + lang_accuracy = acc_metric.compute() + + if rank == 0: + + print(f"\n[Elapsed Time]: {end - start} seconds") + + # each language + for lang in languages: + print(f'[{lang}_acc]', lang_accuracy[lang].item()) + + # total accuracy + accuracy = accuracy_score(preds_list, labels_list) + output_path = args_dict.get("output_path") + if output_path is not None and rank == 0: + output_path = pathlib.Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + with output_path.open("w") as output_file: + results = { + "results": { + "xnli": { + "acc": accuracy + } + }, + "output": { + "xnli": [ + dataclasses.asdict(r) for r in results_list + ] + } + } + json.dump(results, output_file) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument('-batch', '--batch_size', type=int, default=1, metavar='NUMBER', + help='batch size (default: 1)') + parser.add_argument('-beam', '--beam_width', type=int, default=4, metavar='NUMBER', + help='beam width (default: 4)') + parser.add_argument('-s', '--max_output_len', type=int, default=10, metavar='NUMBER', + help='max output length (default: 10)') + parser.add_argument('-diversity_rate', '--beam_search_diversity_rate', 
type=float, default=0.0, metavar='NUMBER', + help='diversity rate of beam search. Default is 0. When diversity rate = 0, it is equivalent to the naive beam search.') + parser.add_argument('-topk', '--sampling_topk', type=int, default=1, metavar='NUMBER', + help='Candidate (k) value of top k sampling in decoding. Default is 1.') + parser.add_argument('-topp', '--sampling_topp', type=float, default=0.0, metavar='NUMBER', + help='Probability (p) value of top p sampling in decoding. Default is 0.0. ') + parser.add_argument('-d', '--data_type', type=str, default="fp32", metavar='STRING', + help='data type (default: fp32)', choices=['fp32', 'fp16', 'bf16']) + parser.add_argument('-lib_path', '--lib_path', type=str, default="/workspace/FasterTransformer/build/lib/libth_t5.so", metavar='STRING', + help='the path of FasterTransformer pytorch t5 op library.') + parser.add_argument('-data_path', '--data_path', type=str, required=True, help="the xnli task data path") + parser.add_argument('-tensor_para_size', '--tensor_para_size', type=int, default=1, metavar='NUMBER', + help='size of tensor parallelism (default: 1)') + parser.add_argument('-pipeline_para_size', '--pipeline_para_size', type=int, default=1, metavar='NUMBER', + help='size of pipeline parallelism (default: 1)') + # assume checkpoint config is also in the same path + parser.add_argument('--ckpt_path', type=str, help='path to the checkpoint file.') + parser.add_argument('--output_path', help='path to results file with calculated metrics.') + parser.add_argument('--return_output_log_probs', action='store_true', + help='Return the log probability of generated tokens.') + parser.add_argument('--return_cum_log_probs', action='store_true', + help='Return the cumulative log probability of generated tokens.') + args = parser.parse_args() + + xnli_task(vars(args)) \ No newline at end of file diff --git a/examples/pytorch/tokenizer.py b/examples/pytorch/tokenizer.py new file mode 100644 index 000000000..ba78984a6 --- /dev/null +++ b/examples/pytorch/tokenizer.py @@ -0,0 +1,50 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def add_special_tokens_to_tokenizer(tokenizer): + + # Need to add cls, sep, mask tokens to the tokenizer if they don't exist. + # If cls, sep and mask are not attributes of the tokenizer, add them. + if not hasattr(tokenizer, 'cls_token'): + tokenizer.add_special_tokens({'cls_token': ''}) + if not hasattr(tokenizer.tokenizer, 'sep_id'): + tokenizer.add_special_tokens({'sep_token': ''}) + if not hasattr(tokenizer.tokenizer, 'mask_id'): + tokenizer.add_special_tokens({'mask_token': ''}) + + # bos, eos, pad and unk may be present in the provided spm .model file, if they are, use it.
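# For each of pad/bos/eos handled below: if the wrapped SentencePiece model already defines the id
# (pad_id()/bos_id()/eos_id() > 0), the corresponding piece is reused as the special token;
# otherwise a new special token is registered on the tokenizer wrapper.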
+ if not hasattr(tokenizer, 'pad_token'): + if hasattr(tokenizer.tokenizer, 'pad_id') and tokenizer.tokenizer.pad_id() > 0: + tokenizer.pad_token = tokenizer.tokenizer.id_to_piece(tokenizer.tokenizer.pad_id()) + else: + tokenizer.add_special_tokens({'pad_token': ''}) + else: + tokenizer.add_special_tokens({'pad_token': ''}) + + if not hasattr(tokenizer, 'bos_token'): + if hasattr(tokenizer.tokenizer, 'bos_id') and tokenizer.tokenizer.bos_id() > 0: + tokenizer.bos_token = tokenizer.tokenizer.id_to_piece(tokenizer.tokenizer.bos_id()) + else: + tokenizer.add_special_tokens({'bos_token': ''}) + else: + tokenizer.add_special_tokens({'bos_token': ''}) + + if not hasattr(tokenizer, 'eos_token'): + if hasattr(tokenizer.tokenizer, 'eos_id') and tokenizer.tokenizer.eos_id() > 0: + tokenizer.eos_token = tokenizer.tokenizer.id_to_piece(tokenizer.tokenizer.eos_id()) + else: + tokenizer.add_special_tokens({'eos_token': ''}) + else: + tokenizer.add_special_tokens({'eos_token': ''}) diff --git a/examples/pytorch/utils.py b/examples/pytorch/utils.py index 711989719..830fe0620 100644 --- a/examples/pytorch/utils.py +++ b/examples/pytorch/utils.py @@ -1,8 +1,66 @@ +# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
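For reference, the XNLI evaluation entry point added above is driven by a plain dict keyed by the argparse destinations. A minimal sketch of a direct call (all paths below are placeholders, and whatever MPI/distributed initialization the script performs inside xnli_task(), not shown in this hunk, is assumed):

args = {
    'batch_size': 32, 'beam_width': 1, 'max_output_len': 10,
    'beam_search_diversity_rate': 0.0, 'sampling_topk': 1, 'sampling_topp': 0.0,
    'data_type': 'fp16',
    'lib_path': '/workspace/FasterTransformer/build/lib/libth_t5.so',
    'data_path': '/data/xnli/xnli.dev.tsv',        # placeholder
    'ckpt_path': '/models/mt5/c-model/1-gpu/',     # placeholder
    'tensor_para_size': 1, 'pipeline_para_size': 1,
    'output_path': 'xnli_results.json',
    'return_output_log_probs': False, 'return_cum_log_probs': False,
}
xnli_task(args)   # equivalent to xnli_task(vars(parser.parse_args()))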
+ +import numpy as np import torch +import typing + def print_memory_usage(info=""): - t = torch.cuda.get_device_properties(0).total_memory / 1024**2 - r = torch.cuda.memory_reserved(0) / 1024**2 - a = torch.cuda.memory_allocated(0) / 1024**2 - f = r-a # free inside reserved + t = torch.cuda.get_device_properties(0).total_memory / 1024 ** 2 + r = torch.cuda.memory_reserved(0) / 1024 ** 2 + a = torch.cuda.memory_allocated(0) / 1024 ** 2 + f = r - a # free inside reserved print(f"[INFO][{info}] total_memory: {t}, reversed: {r}, allocated: {a}") + + +def torch2np(tensor: torch.Tensor, np_data_type: typing.Optional[np.dtype] = None): + tensor = tensor.cpu() + if tensor.dtype == torch.bfloat16: + tensor = tensor.to(torch.float32) + + data = tensor.numpy() + if np_data_type is not None: + data = data.astype(np_data_type) + + return data + + +def safe_transpose(tensor): + if tensor.dim() == 1: + return tensor + if tensor.dim() == 2: + return tensor.T + raise ValueError("Tensor has more than 2 dimensions, unable to safely transpose.") + + +WEIGHT2DTYPE = { + "fp32": np.float32, + "fp16": np.float16, +} + + +def cpu_map_location(storage, loc): + return storage.cpu() + + +def gpu_map_location(storage, loc): + if loc.startswith("cuda"): + training_gpu_idx = int(loc.split(":")[1]) + inference_gpu_idx = training_gpu_idx % torch.cuda.device_count() + return storage.cuda(inference_gpu_idx) + elif loc.startswith("cpu"): + return storage.cpu() + else: + raise NotImplementedError(f"Not handled {loc}") \ No newline at end of file diff --git a/examples/pytorch/vit/VisionTransformerINT8WeightLoader.py b/examples/pytorch/vit/VisionTransformerINT8WeightLoader.py index ad7cd7b9c..a174dad3b 100644 --- a/examples/pytorch/vit/VisionTransformerINT8WeightLoader.py +++ b/examples/pytorch/vit/VisionTransformerINT8WeightLoader.py @@ -58,7 +58,7 @@ def __init__(self, layer_num, img_size, patch_size, weight_dict=None, classifier for name in pre_layer_weight_names: if name not in weight_dict.keys(): - print("Unsupport weight file: Missing weights %s" % name) + print("Unsupported weight file: Missing weights %s" % name) th_weight = weight_dict[name] if name.split('.')[-1] == "pos_embedding": @@ -90,7 +90,7 @@ def __init__(self, layer_num, img_size, patch_size, weight_dict=None, classifier # def load_weights(self, weight_path:str): # suffix = weight_path.split('.')[-1] # if suffix != 'pth': - # print("Unsupport weight file: Unrecognized format %s " % suffix) + # print("Unsupported weight file: Unrecognized format %s " % suffix) # exit(-1) # return th.load(weight_path) @@ -119,21 +119,34 @@ def listed_weights(self): continue if k.split('.')[0] == 'head': continue - # print(k, v.type()) ret.append(v) for i in range(self.layer_num): name = 'transformer.encoder.layer.{}.amaxList'.format(i) ret.append(self.weights[name]) - # print(name, self.weights[name].type()) name = 'transformer.encoder.layer.{}.h_amaxList'.format(i) ret.append(self.weights[name]) - # print(name, self.weights[name].type()) + + return ret + + def listed_weight_to_dict(self): + ret = {} + for k, v in self.weights.items(): + if k.split('.')[-1] == '_amax' or k.endswith('amaxList'): + continue + if k.split('.')[0] == 'head': + continue + ret[k] = v + + for i in range(self.layer_num): + name = 'transformer.encoder.layer.{}.amaxList'.format(i) + ret[name] = self.weights[name] + name = 'transformer.encoder.layer.{}.h_amaxList'.format(i) + ret[name] = self.weights[name] return ret def to_int8(self, ths_path='../../../lib/libpyt_vit.so'): - # print(self.weights.keys()) 
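Stepping back to the conversion helpers added to examples/pytorch/utils.py above, a short usage sketch (the import path, checkpoint file name and tensor key are illustrative assumptions, not taken from this patch):

import torch
from examples.pytorch.utils import WEIGHT2DTYPE, gpu_map_location, safe_transpose, torch2np

# Remap tensors saved on training GPUs onto whichever devices are visible now.
state_dict = torch.load("training_checkpoint.pt", map_location=gpu_map_location)
weight = state_dict["decoder.layers.0.attention.query.weight"]   # placeholder key
# Transpose 2-D weights (1-D tensors pass through unchanged), then export a fp16 numpy copy.
np_weight = torch2np(safe_transpose(weight), WEIGHT2DTYPE["fp16"])
np_weight.tofile("decoder.layers.0.attention.query.weight.bin")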
if 'transformer.encoder.layer.0.attn.query._input_quantizer._amax' not in self.weights: raise RuntimeError("There is no quantization node in the checkpoint, cannot be quantized to int8.") if self.int8: @@ -147,4 +160,3 @@ def to_int8(self, ths_path='../../../lib/libpyt_vit.so'): else: self.weights[k] = v.float().cpu() self.weights = checkpoint_quantization(self.weights, ths_path, verbose=False) - # print(self.weights.keys()) diff --git a/examples/pytorch/vit/VisionTransformerWeightLoader.py b/examples/pytorch/vit/VisionTransformerWeightLoader.py index be2ba9854..9c6e4ab7b 100644 --- a/examples/pytorch/vit/VisionTransformerWeightLoader.py +++ b/examples/pytorch/vit/VisionTransformerWeightLoader.py @@ -78,7 +78,7 @@ def __init__(self, layer_num, img_size, patch_size, weight_path=None, classifier for name in pre_layer_weight_names: if name not in weight_dict.files: - print("Unsupport weight file: Missing weights %s" % name) + print("Unsupported weight file: Missing weights %s" % name) is_conv = name == 'embedding/kernel' if classifier != 'token' and name == 'cls': @@ -114,20 +114,20 @@ def __init__(self, layer_num, img_size, patch_size, weight_path=None, classifier for name in layer_weight_names: w_name = name.format(layer_idx) if w_name not in weight_dict.files: - print("Unsupport weight file: Missing weights %s" % w_name) + print("Unsupported weight file: Missing weights %s" % w_name) th_weight = np2th(weight_dict[w_name]) self.weights.append(th_weight) for name in post_layer_weight_names: if name not in weight_dict.files: - print("Unsupport weight file: Missing weights %s" % name) + print("Unsupported weight file: Missing weights %s" % name) th_weight = np2th(weight_dict[name]) self.weights.append(th_weight) def load_weights(self, weight_path:str): suffix = weight_path.split('.')[-1] if suffix != 'npz': - print("Unsupport weight file: Unrecognized format %s " % suffix) + print("Unsupported weight file: Unrecognized format %s " % suffix) exit(-1) return np.load(weight_path) diff --git a/examples/pytorch/vit/infer_visiontransformer_op.py b/examples/pytorch/vit/infer_visiontransformer_op.py index 27117ba28..52d3f9e66 100644 --- a/examples/pytorch/vit/infer_visiontransformer_op.py +++ b/examples/pytorch/vit/infer_visiontransformer_op.py @@ -192,14 +192,18 @@ def validate_with_random_data(args, config, model): # diff = abs(FP32_torch_traced_output - FP32_op_output) diff = abs(FP32_torch_output - FP32_op_output) print("FP32_torch_traced_output vs FP32_op_output , avg diff : ", diff.mean(), "max diff : ", diff.max()) + assert diff.mean() < 0.004, "[ERROR] VIT OP TEST FAIL !" # diff = abs(FP16_torch_traced_output - FP16_op_output) diff = abs(FP16_torch_output - FP16_op_output) print("FP16_torch_traced_output vs FP16_op_output , avg diff : ", diff.mean(), "max diff : ", diff.max()) - + assert diff.mean() < 0.005, "[ERROR] VIT OP TEST FAIL !" 
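# The 0.004 (FP32) and 0.005 (FP16) thresholds above bound the mean absolute difference between the plain PyTorch output and the FasterTransformer ViT op.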
+ print("[INFO] VIT OP TEST PASS !") + if __name__ == '__main__': args = parse_option() - seed = args.seed + int(time.time()) + # seed = args.seed + int(time.time()) + seed = args.seed torch.manual_seed(seed) np.random.seed(seed) cudnn.benchmark = True diff --git a/examples/tensorflow/bert/bert-quantization/README_orig.md b/examples/tensorflow/bert/bert-quantization/README_orig.md index 4e89068fa..dddd70f21 100644 --- a/examples/tensorflow/bert/bert-quantization/README_orig.md +++ b/examples/tensorflow/bert/bert-quantization/README_orig.md @@ -4,71 +4,74 @@ This repository provides a script and recipe to train the BERT model for TensorF ## Table Of Contents -- [Model overview](#model-overview) - * [Model architecture](#model-architecture) - * [Default configuration](#default-configuration) - * [Feature support matrix](#feature-support-matrix) - * [Features](#features) - * [Mixed precision training](#mixed-precision-training) - * [Enabling mixed precision](#enabling-mixed-precision) - * [Glossary](#glossary) -- [Setup](#setup) - * [Requirements](#requirements) -- [Quick Start Guide](#quick-start-guide) -- [Advanced](#advanced) - * [Scripts and sample code](#scripts-and-sample-code) - * [Parameters](#parameters) - * [Command-line options](#command-line-options) - * [Getting the data](#getting-the-data) - * [Dataset guidelines](#dataset-guidelines) - * [Multi-dataset](#multi-dataset) - * [Training process](#training-process) - * [Pre-training](#pre-training) - * [Fine tuning](#fine-tuning) - * [Multi-node](#multi-node) - * [Inference process](#inference-process) - * [Inference Process With TensorRT](#inference-process-with-tensorrt) - * [Deploying the BERT model using TensorRT Inference Server](#deploying-the-bert-model-using-tensorrt-inference-server) - * [BioBERT](#biobert) -- [Performance](#performance) - * [Benchmarking](#benchmarking) - * [Training performance benchmark](#training-performance-benchmark) - * [Inference performance benchmark](#inference-performance-benchmark) - * [Results](#results) - * [Training accuracy results](#training-accuracy-results) - * [Pre-training accuracy: single-node](#pre-training-accuracy-single-node) - * [Pre-training accuracy: multi-node](#pre-training-accuracy-multi-node) - * [Fine-tuning accuracy for SQuAD: NVIDIA DGX-2 (16x V100 32G)](#fine-tuning-accuracy-for-squad-nvidia-dgx-2-16x-v100-32g) - * [Training stability test](#training-stability-test) - * [Pre-training SQuAD stability test: NVIDIA DGX-2 (512x V100 32G)](#fine-tuning-squad-stability-test-nvidia-dgx-2-512x-v100-32g) - * [Fine-tuning SQuAD stability test: NVIDIA DGX-2 (16x V100 32G)](#fine-tuning-squad-stability-test-nvidia-dgx-2-16x-v100-32g) - * [Training performance results](#training-performance-results) - * [Training performance: NVIDIA DGX-1 (8x V100 16G)](#training-performance-nvidia-dgx-1-8x-v100-16g) - * [Pre-training training performance: single-node on 16G](#pre-training-training-performance-single-node-on-16g) - * [Pre-training training performance: multi-node on 16G](#pre-training-training-performance-multi-node-on-16g) - * [Fine-tuning training performance for SQuAD on 16G](#fine-tuning-training-performance-for-squad-on-16g) - * [Training performance: NVIDIA DGX-1 (8x V100 32G)](#training-performance-nvidia-dgx-1-8x-v100-32g) - * [Pre-training training performance: single-node on 32G](#pre-training-training-performance-single-node-on-32g) - * [Fine-tuning training performance for SQuAD on 32G](#fine-tuning-training-performance-for-squad-on-32g) - * [Training performance: NVIDIA 
DGX-2 (16x V100 32G)](#training-performance-nvidia-dgx-2-16x-v100-32g) - * [Pre-training training performance: single-node on DGX-2 32G](#pre-training-training-performance-single-node-on-dgx-2-32g) - * [Pre-training training performance: multi-node on DGX-2 32G](#pre-training-training-performance-multi-node-on-dgx-2-32g) - * [Fine-tuning training performance for SQuAD on DGX-2 32G](#fine-tuning-training-performance-for-squad-on-dgx-2-32g) - * [Inference performance results](#inference-performance-results) - * [Inference performance: NVIDIA DGX-1 (1x V100 16G)](#inference-performance-nvidia-dgx-1-1x-v100-16g) - * [Pre-training inference performance on 16G](#pre-training-inference-performance-on-16g) - * [Fine-tuning inference performance for SQuAD on 16G](#fine-tuning-inference-performance-for-squad-on-16g) - * [Inference performance: NVIDIA DGX-1 (1x V100 32G)](#inference-performance-nvidia-dgx-1-1x-v100-32g) - * [Pre-training inference performance on 32G](#pre-training-inference-performance-on-32g) - * [Fine-tuning inference performance for SQuAD on 32G](#fine-tuning-inference-performance-for-squad-on-32g) - * [Inference performance: NVIDIA DGX-2 (1x V100 32G)](#inference-performance-nvidia-dgx-2-1x-v100-32g) - * [Pre-training inference performance on DGX-2 32G](#pre-training-inference-performance-on-dgx-2-32g) - * [Fine-tuning inference performance for SQuAD on DGX-2 32G](#fine-tuning-inference-performance-for-squad-on-dgx-2-32g) - * [Inference performance: NVIDIA Tesla T4 (1x T4 16G)](#inference-performance-nvidia-tesla-t4-1x-t4-16g) - * [Fine-tuning inference performance for SQuAD on Tesla T4 16G](#fine-tuning-inference-performance-for-squad-on-tesla-t4-16g) -- [Release notes](#release-notes) - * [Changelog](#changelog) - * [Known issues](#known-issues) +- [BERT For TensorFlow](#bert-for-tensorflow) + - [Table Of Contents](#table-of-contents) + - [Model overview](#model-overview) + - [Model architecture](#model-architecture) + - [Default configuration](#default-configuration) + - [Feature support matrix](#feature-support-matrix) + - [Features](#features) + - [Mixed precision training](#mixed-precision-training) + - [Enabling mixed precision](#enabling-mixed-precision) + - [Glossary](#glossary) + - [Setup](#setup) + - [Requirements](#requirements) + - [Quick Start Guide](#quick-start-guide) + - [Advanced](#advanced) + - [Scripts and sample code](#scripts-and-sample-code) + - [Parameters](#parameters) + - [Command-line options](#command-line-options) + - [Getting the data](#getting-the-data) + - [Dataset guidelines](#dataset-guidelines) + - [Multi-dataset](#multi-dataset) + - [Training process](#training-process) + - [Pre-training](#pre-training) + - [Fine tuning](#fine-tuning) + - [Multi-node](#multi-node) + - [Inference process](#inference-process) + - [Inference Process With TensorRT](#inference-process-with-tensorrt) + - [Deploying the BERT model using TensorRT Inference Server](#deploying-the-bert-model-using-tensorrt-inference-server) + - [BioBERT](#biobert) + - [Performance](#performance) + - [Benchmarking](#benchmarking) + - [Training performance benchmark](#training-performance-benchmark) + - [Inference performance benchmark](#inference-performance-benchmark) + - [Results](#results) + - [Training accuracy results](#training-accuracy-results) + - [Training accuracy](#training-accuracy) + - [Pre-training accuracy: single-node](#pre-training-accuracy-single-node) + - [Pre-training accuracy: multi-node](#pre-training-accuracy-multi-node) + - [Fine-tuning accuracy for SQuAD: NVIDIA 
DGX-2 (16x V100 32G)](#fine-tuning-accuracy-for-squad-nvidia-dgx-2-16x-v100-32g) + - [Training stability test](#training-stability-test) + - [Pre-training stability test: NVIDIA DGX-2 (512x V100 32G)](#pre-training-stability-test-nvidia-dgx-2-512x-v100-32g) + - [Fine-tuning SQuAD stability test: NVIDIA DGX-2 (16x V100 32G)](#fine-tuning-squad-stability-test-nvidia-dgx-2-16x-v100-32g) + - [Training performance results](#training-performance-results) + - [Training performance: NVIDIA DGX-1 (8x V100 16G)](#training-performance-nvidia-dgx-1-8x-v100-16g) + - [Pre-training training performance: single-node on 16G](#pre-training-training-performance-single-node-on-16g) + - [Pre-training training performance: multi-node on 16G](#pre-training-training-performance-multi-node-on-16g) + - [Fine-tuning training performance for SQuAD on 16G](#fine-tuning-training-performance-for-squad-on-16g) + - [Training performance: NVIDIA DGX-1 (8x V100 32G)](#training-performance-nvidia-dgx-1-8x-v100-32g) + - [Pre-training training performance: single-node on 32G](#pre-training-training-performance-single-node-on-32g) + - [Fine-tuning training performance for SQuAD on 32G](#fine-tuning-training-performance-for-squad-on-32g) + - [Training performance: NVIDIA DGX-2 (16x V100 32G)](#training-performance-nvidia-dgx-2-16x-v100-32g) + - [Pre-training training performance: single-node on DGX-2 32G](#pre-training-training-performance-single-node-on-dgx-2-32g) + - [Pre-training training performance: multi-node on DGX-2H 32G](#pre-training-training-performance-multi-node-on-dgx-2h-32g) + - [Fine-tuning training performance for SQuAD on DGX-2 32G](#fine-tuning-training-performance-for-squad-on-dgx-2-32g) + - [Inference performance results](#inference-performance-results) + - [Inference performance: NVIDIA DGX-1 (1x V100 16G)](#inference-performance-nvidia-dgx-1-1x-v100-16g) + - [Pre-training inference performance on 16G](#pre-training-inference-performance-on-16g) + - [Fine-tuning inference performance for SQuAD on 16G](#fine-tuning-inference-performance-for-squad-on-16g) + - [Inference performance: NVIDIA DGX-1 (1x V100 32G)](#inference-performance-nvidia-dgx-1-1x-v100-32g) + - [Pre-training inference performance on 32G](#pre-training-inference-performance-on-32g) + - [Fine-tuning inference performance for SQuAD on 32G](#fine-tuning-inference-performance-for-squad-on-32g) + - [Inference performance: NVIDIA DGX-2 (1x V100 32G)](#inference-performance-nvidia-dgx-2-1x-v100-32g) + - [Pre-training inference performance on DGX-2 32G](#pre-training-inference-performance-on-dgx-2-32g) + - [Fine-tuning inference performance for SQuAD on DGX-2 32G](#fine-tuning-inference-performance-for-squad-on-dgx-2--32g) + - [Inference performance: NVIDIA Tesla T4 (1x T4 16G)](#inference-performance-nvidia-tesla-t4-1x-t4-16g) + - [Fine-tuning inference performance for SQuAD on Tesla T4 16G](#fine-tuning-inference-performance-for-squad-on-tesla-t4-16g) + - [Release notes](#release-notes) + - [Changelog](#changelog) + - [Known issues](#known-issues) @@ -84,7 +87,7 @@ Other publicly available implementations of BERT include: 4. [gluon-nlp](https://github.com/dmlc/gluon-nlp/tree/master/scripts/bert) 5. [Google's official implementation](https://github.com/google-research/bert) -This model is trained with mixed precision using Tensor Cores on NVIDIA Volta and Turing GPUs. Therefore, researchers can get results upto 4x faster than training without Tensor Cores, while experiencing the benefits of mixed precision training. 
This model is tested against each NGC monthly container release to ensure consistent accuracy and performance over time. +This model is trained with mixed precision using Tensor Cores on NVIDIA Volta and Turing GPUs. Therefore, researchers can get results up to 4x faster than training without Tensor Cores, while experiencing the benefits of mixed precision training. This model is tested against each NGC monthly container release to ensure consistent accuracy and performance over time. ### Model architecture @@ -137,7 +140,7 @@ Multi-GPU training with Horovod - Our model uses Horovod to implement efficient [LAMB](https://arxiv.org/pdf/1904.00962.pdf) stands for Layerwise Adaptive Moments based optimizer, is a large batch optimization technique that helps accelerate training of deep neural networks using large minibatches. It allows using a global batch size of 65536 and 32768 on sequence lengths 128 and 512 respectively, compared to a batch size of 256 for Adam. The optimized implementation accumulates 1024 gradients batches in phase 1 and 4096 steps in phase 2 before updating weights once. This results in 27% training speedup on a single DGX2 node. On multi-node systems, LAMB allows scaling up to 1024 GPUs resulting in training speedups of up to 17x in comparison to [Adam](https://arxiv.org/pdf/1412.6980.pdf). Adam has limitations on the learning rate that can be used since it is applied globally on all parameters whereas LAMB follows a layerwise learning rate strategy. -NVLAMB adds necessary tweaks to [LAMB version 1](https://arxiv.org/abs/1904.00962v1), to ensure correct convergence. A guide to implementating the LAMB optimizer can be found in our [article](https://medium.com/@NvidiaAI/a-guide-to-optimizer-implementation-for-bert-at-scale-8338cc7f45fd) on Medium.com. The algorithm is as follows: +NVLAMB adds necessary tweaks to [LAMB version 1](https://arxiv.org/abs/1904.00962v1), to ensure correct convergence. A guide to implement the LAMB optimizer can be found in our [article](https://medium.com/@NvidiaAI/a-guide-to-optimizer-implementation-for-bert-at-scale-8338cc7f45fd) on Medium.com. The algorithm is as follows: ![NVLAMB](data/images/images_nvlamb.png) ### Mixed precision training @@ -690,7 +693,7 @@ Our results were obtained by running the `scripts/run_pretraining_lamb.sh` train | DGX2H | 64 | FP16 | 32, 8 | 2, 4 | 2.44 | 1.64 | | DGX2H | 64 | FP32 | 32, 4 | 2, 8 | 5.76 | 1.66 | -Note: Time to train includes upto 16 minutes of start up time for every restart. Experiments were run on clusters with a maximum wall clock time of 8 hours. +Note: Time to train includes up to 16 minutes of start up time for every restart. Experiments were run on clusters with a maximum wall clock time of 8 hours. 
###### Fine-tuning accuracy for SQuAD: NVIDIA DGX-2 (16x V100 32G) @@ -1137,7 +1140,7 @@ To achieve these same results, follow the [Quick Start Guide](#quick-start-guide ### Changelog -Janurary 2020 +January 2020 - Added inference with TensorRT November 2019 diff --git a/examples/tensorflow/bert/bert-quantization/ft-tensorflow-quantization/ft_tensorflow_quantization/python/calib/calibrator.py b/examples/tensorflow/bert/bert-quantization/ft-tensorflow-quantization/ft_tensorflow_quantization/python/calib/calibrator.py index a9257fecb..add22f8d5 100644 --- a/examples/tensorflow/bert/bert-quantization/ft-tensorflow-quantization/ft_tensorflow_quantization/python/calib/calibrator.py +++ b/examples/tensorflow/bert/bert-quantization/ft-tensorflow-quantization/ft_tensorflow_quantization/python/calib/calibrator.py @@ -30,7 +30,7 @@ class Calibrator(): - """A calibrator that wraps up a collector and relavent tensors and does calibration + """A calibrator that wraps up a collector and relevant tensors and does calibration Args: tensor_name_prefix: A string. The common name prefix of `quant_min`, `quant_max`, and `calib_tensor`. @@ -63,7 +63,7 @@ def calib_step_op(self, graph): """get the op for one step of calibration Args: - graph: The being excuted TensorFlow Graph. + graph: The being executed TensorFlow Graph. Returns: A wrapped TensorFlow op of `tf.py_function` for one calib step. @@ -258,7 +258,7 @@ def load_range(self, sess): def compute_and_load_range(self, sess, **compute_range_args): """wraps :func:`compute_range ` - and :func:`load_range ` for convinience""" + and :func:`load_range ` for convenience""" self.compute_range(**compute_range_args) self.load_range(sess) @@ -306,7 +306,7 @@ def get_calibrators(collection_name_prefix, collection_name_prefix: A string. Determine the collection of tensors. Need to be unified with FakeQuantizer. graph: an instance of `tf.Graph`, if None, use default graph. Default None. collector_types: A string. What collector to use. One of `["max", "histogram"]`. Default `"max"`. - Collector arugments can be passed by collector_args. + Collector arguments can be passed by collector_args. If :func:`MaxCollector ` is used, only `axis` and `track_minmax` can be passed to collector_args. If :func:`HistogramCollector ` is used, diff --git a/examples/tensorflow/bert/bert-quantization/ft-tensorflow-quantization/ft_tensorflow_quantization/python/calib/max.py b/examples/tensorflow/bert/bert-quantization/ft-tensorflow-quantization/ft_tensorflow_quantization/python/calib/max.py index 8f7c31c05..f234e7e32 100644 --- a/examples/tensorflow/bert/bert-quantization/ft-tensorflow-quantization/ft_tensorflow_quantization/python/calib/max.py +++ b/examples/tensorflow/bert/bert-quantization/ft-tensorflow-quantization/ft_tensorflow_quantization/python/calib/max.py @@ -28,7 +28,7 @@ class MaxCollector(): Args: axis: None or integer. axis which will have its own max for computing scaling factor. If None, collect per tensor min/max. Default None - track_minmax: A boolean. If true, track all min/max it sees in addtion to the returned calib_min/calib_max. + track_minmax: A boolean. If true, track all min/max it sees in addition to the returned calib_min/calib_max. 
Default False """ diff --git a/examples/tensorflow/bert/bert-quantization/run_pretraining.py b/examples/tensorflow/bert/bert-quantization/run_pretraining.py index a3f3a236e..557ffc093 100644 --- a/examples/tensorflow/bert/bert-quantization/run_pretraining.py +++ b/examples/tensorflow/bert/bert-quantization/run_pretraining.py @@ -465,7 +465,7 @@ def input_fn(): # We must `drop_remainder` on training because the TPU requires fixed # size dimensions. For eval, we assume we are evaluating on the CPU or GPU - # and we *don't* want to drop the remainder, otherwise we wont cover + # and we *don't* want to drop the remainder, otherwise we won't cover # every sample. d = d.apply( tf.contrib.data.map_and_batch( diff --git a/examples/tensorflow/bert/bert-quantization/run_squad.py b/examples/tensorflow/bert/bert-quantization/run_squad.py index 10daddb14..42f10671c 100644 --- a/examples/tensorflow/bert/bert-quantization/run_squad.py +++ b/examples/tensorflow/bert/bert-quantization/run_squad.py @@ -752,7 +752,7 @@ def write_predictions(all_examples, all_features, all_results, n_best_size, start_logit=pred.start_logit, end_logit=pred.end_logit)) - # if we didn't inlude the empty option in the n-best, inlcude it + # if we didn't include the empty option in the n-best, include it if FLAGS.version_2_with_negative: if "" not in seen_predictions: nbest.append( diff --git a/examples/tensorflow/bert/bert-quantization/tokenization.py b/examples/tensorflow/bert/bert-quantization/tokenization.py index 6e53ce767..ca3c60613 100644 --- a/examples/tensorflow/bert/bert-quantization/tokenization.py +++ b/examples/tensorflow/bert/bert-quantization/tokenization.py @@ -413,7 +413,7 @@ def tokenize(self, text): def _is_whitespace(char): """Checks whether `chars` is a whitespace character.""" - # \t, \n, and \r are technically contorl characters but we treat them + # \t, \n, and \r are technically control characters but we treat them # as whitespace since they are generally considered as such. if char == " " or char == "\t" or char == "\n" or char == "\r": return True diff --git a/examples/tensorflow/bert/utils/bert.py b/examples/tensorflow/bert/utils/bert.py index bd52101e0..fa2b35eb8 100755 --- a/examples/tensorflow/bert/utils/bert.py +++ b/examples/tensorflow/bert/utils/bert.py @@ -355,7 +355,7 @@ def ft_bert(inputs, encoder_vars_dict: A dict of tf.Tensor or numpy array. The variables for encoder. They can be either some tensor or some numpy array. The key is the name of the tensor, like 'layer_0/attention/self/query/kernel:0'. - Teh value is the corresponding tensor or numpy array + The value is the corresponding tensor or numpy array sequence_length: A tf.Tensor or numpy array with shape [batch_size]. The sequence length of the sentences Outputs: diff --git a/examples/tensorflow/bert/utils/common.py b/examples/tensorflow/bert/utils/common.py index da8ab321b..95aa37a8d 100644 --- a/examples/tensorflow/bert/utils/common.py +++ b/examples/tensorflow/bert/utils/common.py @@ -45,7 +45,7 @@ def __init__( self, dtype: The data type of weights initializer and inputs. kernel_init_range: The initializer range of kernel for all convolution layer and fully-connected layer. kernel_init_range: The initializer range of bias for all convolution layer and fully-connected layer. - fuse_qkv: bool. Wether fuse the q, k, v gemm or not. + fuse_qkv: bool. Whether fuse the q, k, v gemm or not. remove_padding: bool. Remove the padding of sentences of encoder. int8_mode: Mode of int8 quantization.
0 means not using int8 quantization, 1 means using int8 quantization without quantizing residuals, 2 means using int8 quantization with quantizing residuals. diff --git a/examples/tensorflow/common_utils/common.py b/examples/tensorflow/common_utils/common.py index 58d3725bb..5d7b31691 100644 --- a/examples/tensorflow/common_utils/common.py +++ b/examples/tensorflow/common_utils/common.py @@ -44,7 +44,7 @@ def __init__( self, dtype: The data type of weights initializer and inputs. kernel_init_range: The initializer range of kernel for all convolution layer and fully-connected layer. kernel_init_range: The initializer range of bias for all convolution layer and fully-connected layer. - fuse_qkv: bool. Wether fuse the q, k, v gemm or not. + fuse_qkv: bool. Whether fuse the q, k, v gemm or not. remove_padding: bool. Remove the padding of sentences of encoder. int8_mode: Mode of int8 quantization. 0 means not using int8 quantization, 1 means using int8 quantization without quantizing residuals, 2 means using int8 quantization with quantizing residuals. diff --git a/examples/tensorflow/decoder/decoder_example.py b/examples/tensorflow/decoder/decoder_example.py index d5df7e588..515004b13 100644 --- a/examples/tensorflow/decoder/decoder_example.py +++ b/examples/tensorflow/decoder/decoder_example.py @@ -72,7 +72,7 @@ parser.add_argument('-v', '--vocab_size', type=int, default=30000, metavar='BOOL', help='vocabulary size. (default: 30000).') parser.add_argument('-d', '--data_type', type=str, default="fp32", metavar='STRING', - help='data type (default: fp32)', choices=['fp32', 'fp16']) + help='data type (default: fp32)', choices=['fp32', 'fp16', 'bf16']) parser.add_argument('-time', '--test_time', type=int, default=0, metavar='BOOL', help='test the time or not. (default: False (0)), True is 1.', choices=[0, 1]) @@ -115,6 +115,8 @@ if args.data_type == "fp16": tf_datatype = tf.float16 np_datatype = np.float16 + elif args.data_type == "bf16": + tf_datatype = tf.bfloat16 ## numpy doesn't support bfloat16, fallback to float32 decoder_args = TransformerArgument(beam_width=beam_width, head_num=head_num, @@ -135,7 +137,7 @@ 0.0) embedding_table = np.random.randn(vocab_size, hidden_dim).astype(np_datatype) * 0.01 # a [vocab_size, hidden_dim] table - embedding_table = tf.convert_to_tensor(embedding_table) + embedding_table = tf.convert_to_tensor(embedding_table, dtype = tf_datatype) memory, memory_sequence_length = generate_encoder_result( batch_size, max_seq_len, memory_hidden_dim, tf_datatype) diff --git a/examples/tensorflow/decoder/utils/common.py b/examples/tensorflow/decoder/utils/common.py index beb22c8eb..4c4b1723c 100644 --- a/examples/tensorflow/decoder/utils/common.py +++ b/examples/tensorflow/decoder/utils/common.py @@ -46,7 +46,7 @@ def __init__( self, dtype: The data type of weights initializer and inputs. kernel_init_range: The initializer range of kernel for all convolution layer and fully-connected layer. kernel_init_range: The initializer range of bias for all convolution layer and fully-connected layer. - fuse_qkv: bool. Wether fuse the q, k, v gemm or not. + fuse_qkv: bool. Whether fuse the q, k, v gemm or not. remove_padding: bool. Remove the padding of sentences of encoder. int8_mode: Mode of int8 quantization. 0 means not using int8 quantization, 1 means using int8 quantization without quantizing residuals, 2 means using int8 quantization with quantizing residuals. 
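The bf16 support added to decoder_example.py above relies on one detail worth spelling out: numpy has no bfloat16 dtype, so host-side buffers stay float32 and the cast happens when the data is converted into a TensorFlow tensor. A condensed, standalone sketch of that pattern (shapes are arbitrary):

import numpy as np
import tensorflow as tf

def resolve_dtypes(data_type):
    # fp32/fp16 map to native numpy dtypes; bf16 keeps the numpy side at float32
    # and defers the cast to TensorFlow.
    if data_type == "fp16":
        return tf.float16, np.float16
    if data_type == "bf16":
        return tf.bfloat16, np.float32
    return tf.float32, np.float32

tf_datatype, np_datatype = resolve_dtypes("bf16")
embedding_table = (np.random.randn(30000, 512) * 0.01).astype(np_datatype)
# Same cast-at-conversion step as decoder_example.py uses:
embedding_table = tf.convert_to_tensor(embedding_table, dtype=tf_datatype)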
diff --git a/examples/tensorflow/decoder/utils/decoder.py b/examples/tensorflow/decoder/utils/decoder.py index 07419b1f0..9a43db526 100644 --- a/examples/tensorflow/decoder/utils/decoder.py +++ b/examples/tensorflow/decoder/utils/decoder.py @@ -90,7 +90,7 @@ def tf_decoder(decoder_args, The results of encoder transformer layer. The rank must be 3. Note that it must be extended by beam_width times memory_sequence_length: A tf.Tensor with shape [batch_size * beam_width], type tf.int. - The lenght of each sentence of results of encoder. + The length of each sentence of results of encoder. Note that it must be extended by beam_width times step: A tf.Tensor with tf.int type. The current step in the translation process. cache: A dict. The cache space to store the keys and values of attention layers. @@ -132,7 +132,7 @@ def tf_decoder(decoder_args, else: ''' This progress wants to prevent a addictional tf.concat to concat the q, k, v kernels for decoder op - becuase the concat bring large overhead for small batch size. + because the concat bring large overhead for small batch size. ''' queries = tf.layers.conv1d(norm_inputs, decoder_args.hidden_dim, 1, bias_initializer=create_initializer(b_init_range, data_type), @@ -346,7 +346,7 @@ def op_decoder(inputs, The results of encoder transformer layer. The rank must be 3. Note that it must be extended by beam_width times memory_sequence_length: A tf.Tensor with shape [batch_size * beam_width], type tf.int. - The lenght of each sentence of results of encoder. + The length of each sentence of results of encoder. Note that it must be extended by beam_width times op_self_cache: A tf.Tensor with shape [num_layer, 2, None, batch_size * beam_width, hidden_dimension]. The cache space to store the keys and values of first attention layer in each step. @@ -367,7 +367,7 @@ def op_decoder(inputs, ''' ''' - If fuse_qkv == Ture, this means that the computation of q, k, v in decoder are fused in one convolution. + If fuse_qkv == True, this means that the computation of q, k, v in decoder are fused in one convolution. Therefore, we need to split them and then passing into the decoder op. The split will bring additional overhead, especially when the batch size is small because the computation time is short. diff --git a/examples/tensorflow/decoder/utils/decoding.py b/examples/tensorflow/decoder/utils/decoding.py index 51ebc915c..09864dece 100644 --- a/examples/tensorflow/decoder/utils/decoding.py +++ b/examples/tensorflow/decoder/utils/decoding.py @@ -210,7 +210,7 @@ def tf_beamsearch_decoding(memory_tensor, The results of encoder transformer layer. The rank must be 3. Note that it must be extended by beam_width times. memory_sequence_length: A tf.Tensor with shape [batch_size * beam_width], type tf.int. - The lenght of each sentence of results of encoder. + The length of each sentence of results of encoder. Note that it must be extended by beam_width times. embedding_table: A tf.Tensor with shape [vocab_size, hidden_dimension]. The embedding table of embedding lookup for each step. @@ -350,7 +350,7 @@ def tf_sampling_decoding(memory_tensor, memory_tensor: A tf.tensor with shape [batch_size, max(memory_sequence_length), encoder_hidden_dimension]. The results of encoder transformer layer. The rank must be 3. memory_sequence_length: A tf.Tensor with shape [batch_size], type tf.int. - The lenght of each sentence of results of encoder. + The length of each sentence of results of encoder. embedding_table: A tf.Tensor with shape [vocab_size, hidden_dimension]. 
The embedding table of embedding lookup for each step. decoder_args: The arguments for decoding. The details are in the class "DecodingSamplingArgument" of common.py diff --git a/examples/tensorflow/decoding/decoding_example.py b/examples/tensorflow/decoding/decoding_example.py index 788a4fe67..bd88e7da2 100644 --- a/examples/tensorflow/decoding/decoding_example.py +++ b/examples/tensorflow/decoding/decoding_example.py @@ -66,7 +66,7 @@ parser.add_argument('-v', '--vocab_size', type=int, default=30000, metavar='BOOL', help='vocabulary size. (default: 30000).') parser.add_argument('-d', '--data_type', type=str, default="fp32", metavar='STRING', - help='data type (default: fp32)', choices=['fp32', 'fp16']) + help='data type (default: fp32)', choices=['fp32', 'fp16', 'bf16']) parser.add_argument('-x', '--use_XLA', type=int, default=0, metavar='BOOL', help='use XLA (default: False 0)', choices=[0, 1]) parser.add_argument('-time', '--test_time', type=str, default='', metavar='STRING', @@ -117,6 +117,8 @@ if args.data_type == "fp16": tf_datatype = tf.float16 np_datatype = np.float16 + elif args.data_type == 'bf16': + tf_datatype = tf.bfloat16 use_XLA = args.use_XLA beam_search_diversity_rate = args.beam_search_diversity_rate sampling_topk = args.sampling_topk diff --git a/examples/tensorflow/decoding/translate_example.py b/examples/tensorflow/decoding/translate_example.py index 5e2f1137e..955eecebb 100644 --- a/examples/tensorflow/decoding/translate_example.py +++ b/examples/tensorflow/decoding/translate_example.py @@ -95,6 +95,8 @@ def translate(args_dict): max_ite = args_dict['max_iteration'] if args_dict['data_type'] == "fp16": tf_datatype = tf.float16 + elif args_dict['data_type'] == "bf16": + tf_datatype = tf.bfloat16 print("\n=============== Argument ===============") for key in args_dict: @@ -302,14 +304,18 @@ def __init__(self, token_op, length_op, name): # Iterates on the dataset. float_checkpoint_path = tf.train.latest_checkpoint(model_dir) half_checkpoint_path = tf.train.latest_checkpoint(model_dir + "_fp16") + bf16_checkpoint_path = tf.train.latest_checkpoint(model_dir + "_bf16") float_var_list = [] half_var_list = [] + bf16_var_list = [] for var in tf.global_variables(): if var.dtype.base_dtype == tf.float32: float_var_list.append(var) elif var.dtype.base_dtype == tf.float16: half_var_list.append(var) + elif var.dtype.base_dtype == tf.bfloat16: + bf16_var_list.append(var) config = tf.ConfigProto() config.gpu_options.allow_growth = True @@ -321,6 +327,9 @@ def __init__(self, token_op, length_op, name): if(len(half_var_list) > 0): half_saver = tf.train.Saver(half_var_list) half_saver.restore(sess, half_checkpoint_path) + if(len(bf16_var_list) > 0): + bf16_saver = tf.train.Saver(bf16_var_list) + bf16_saver.restore(sess, bf16_checkpoint_path) sess.run(tf.tables_initializer()) sess.run(iterator.initializer) @@ -405,7 +414,7 @@ def main(): parser.add_argument('-topp', '--sampling_topp', type=float, default=0.0, metavar='NUMBER', help='Probability (p) value of top p sampling in decoding. Default is 0.0. 
') parser.add_argument('-d', '--data_type', type=str, default="fp32", metavar='STRING', - help='data type (default: fp32)', choices=['fp32', 'fp16']) + help='data type (default: fp32)', choices=['fp32', 'fp16', 'bf16']) parser.add_argument('-max_ite', '--max_iteration', type=int, default=100000, metavar='NUMBER', help='Maximum iteraiton for translation, default is 100000 (as large as possible to run all test set).') args = parser.parse_args() diff --git a/examples/tensorflow/decoding/utils/ft_decoding.py b/examples/tensorflow/decoding/utils/ft_decoding.py index ce78ad360..a7453fb46 100644 --- a/examples/tensorflow/decoding/utils/ft_decoding.py +++ b/examples/tensorflow/decoding/utils/ft_decoding.py @@ -55,7 +55,7 @@ def ft_decoding(memory_tensor, The results of encoder transformer layer. The rank must be 3. Note that it must be extended by beam_width times. memory_sequence_length: A tf.Tensor with shape [batch_size * beam_width], type tf.int. - The lenght of each sentence of results of encoder. + The length of each sentence of results of encoder. Note that it must be extended by beam_width times. embedding_table: A tf.Tensor with shape [vocab_size, hidden_dimension]. The embedding table of embedding lookup for each step. @@ -152,7 +152,7 @@ def ft_decoding(memory_tensor, top_k=decoding_args.top_k, top_p=decoding_args.top_p, temperature=1.0, - len_penalty=1.0, + len_penalty=0.0, repetition_penalty=1.0) if decoder_args.beam_width > 1: diff --git a/examples/tensorflow/encoder/utils/encoder.py b/examples/tensorflow/encoder/utils/encoder.py index f08661dcb..faca927f4 100644 --- a/examples/tensorflow/encoder/utils/encoder.py +++ b/examples/tensorflow/encoder/utils/encoder.py @@ -215,7 +215,7 @@ def ft_encoder_opennmt(inputs, encoder_vars_dict: A dict of tf.Tensor or numpy array. The variables for encoder. They can be either some tensor or some numpy array. The key is the name of the tensor, like 'layer_0/attention/self/query/kernel:0'. - Teh value is the corresponding tensor or numpy array + The value is the corresponding tensor or numpy array sequence_length: A tf.Tensor or numpy array with shape [batch_size]. The sequence length of the sentences Outputs: diff --git a/examples/tensorflow/gpt/gpt_example.py b/examples/tensorflow/gpt/gpt_example.py index ed3d845b3..4e6683185 100644 --- a/examples/tensorflow/gpt/gpt_example.py +++ b/examples/tensorflow/gpt/gpt_example.py @@ -73,7 +73,7 @@ def sample_model( :model_name=124M : String, which model to use :nsamples=0 : Number of samples to return, if 0, continues to - generate samples indefinately. + generate samples indefinitely. :batch_size=1 : Number of batches (only affects speed/memory). 
:length=None : Number of tokens in generated text, if None (default), is determined by model hyperparameters @@ -169,13 +169,9 @@ def sample_model( for i in range(batch_size): generated += 1 - if beam_width > 1: - for j in range(beam_width): - print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40) - print(enc.decode(op_out[i][j][:seq_len[i][j]])) - else: - print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40) - print(enc.decode(op_out[i][:seq_len[i]])) + for j in range(beam_width): + print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40) + print(enc.decode(op_out[i][j][:seq_len[i][j]])) def finalize(input_ids, beam_width, parent_ids, sequence_lengths, outputs, end_id, max_seq_len=None): maximum_lengths = tf.reduce_max(tf.reshape( @@ -229,7 +225,7 @@ def ft_gpt_op(var_dict, gpt_op_module = tf.load_op_library(os.path.join('./lib/libtf_gpt.so')) data_type = decoder_args.dtype - output_ids, parent_ids, sequence_length, cum_log_probs = gpt_op_module.gpt( + output_ids, sequence_length, cum_log_probs = gpt_op_module.gpt( input_ids, # 0 input_lengths, # 1 [tf.cast(var_dict["model/h%d/ln_1/b:0" % l], data_type) for l in range(decoder_args.num_layer)], # 2 @@ -262,7 +258,7 @@ def ft_gpt_op(var_dict, top_k=decoding_args.top_k, top_p=decoding_args.top_p, temperature=1.0, - len_penalty=1.0, + len_penalty=0.0, repetition_penalty=1.0, output_log_probs=True, request_output_length=decoding_args.max_seq_len - input_lengths.max()) diff --git a/examples/tensorflow/gpt/utils/gpt_token_encoder.py b/examples/tensorflow/gpt/utils/gpt_token_encoder.py index 8a560b6a0..739c32b83 100644 --- a/examples/tensorflow/gpt/utils/gpt_token_encoder.py +++ b/examples/tensorflow/gpt/utils/gpt_token_encoder.py @@ -50,7 +50,7 @@ def bytes_to_unicode(): The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. - This is a signficant percentage of your normal, say, 32K bpe vocab. + This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup tables between utf-8 bytes and unicode strings. And avoids mapping to whitespace/control characters the bpe code barfs on. """ diff --git a/examples/tensorflow/xlnet/modeling.py b/examples/tensorflow/xlnet/modeling.py index 93c60b386..bc600f31a 100644 --- a/examples/tensorflow/xlnet/modeling.py +++ b/examples/tensorflow/xlnet/modeling.py @@ -394,7 +394,7 @@ def transformer_xl(inp_k, n_token, n_layer, d_model, n_head, use_tpu=True, input_mask=None, perm_mask=None, seg_id=None, reuse_len=None, ff_activation='relu', target_mapping=None, - use_float16=False, scope='transformer', **kwargs): + data_type="fp32", scope='transformer', **kwargs): """ Defines a Transformer-XL computation graph with additional support for XLNet. @@ -442,7 +442,7 @@ def transformer_xl(inp_k, n_token, n_layer, d_model, n_head, init_std: float, initialize the parameters with a normal distribution with mean 0 and stddev init_std. Only effective when init="normal". mem_len: int, the number of tokens to cache. - reuse_len: int, the number of tokens in the currect batch to be cached + reuse_len: int, the number of tokens in the current batch to be cached and reused in the future. bi_data: bool, whether to use bidirectional input pipeline. Usually set to True during pretraining and False during finetuning. 
@@ -456,7 +456,11 @@ def transformer_xl(inp_k, n_token, n_layer, d_model, n_head, """ tf.logging.info('memory input {}'.format(mems)) - tf_float = tf.float16 if use_float16 else tf.float32 + tf_float = tf.float32 + if data_type == "fp16": + tf_float = tf.float16 + elif data_type == "bf16": + tf_float = tf.bfloat16 tf.logging.info('Use float type {}'.format(tf_float)) new_mems = [] diff --git a/examples/tensorrt/swin/builder_fp16.py b/examples/tensorrt/swin/builder_fp16.py index cc7a6787f..523e73854 100755 --- a/examples/tensorrt/swin/builder_fp16.py +++ b/examples/tensorrt/swin/builder_fp16.py @@ -246,9 +246,8 @@ def build_engine(config, args, weights_dict): sw_output.precision = trt.float16 sw_output.set_output_type(0, trt.float16) - output_size = weights_dict["head.bias"].shape[0] - output = network.add_fully_connected(sw_output.get_output(0), output_size, trt.Weights(weights_dict["head.weight"].numpy().astype(np.float16).flatten()), trt.Weights(weights_dict["head.bias"].numpy().astype(np.float16).flatten())) - network.mark_output(output.get_output(0)) + sw_output = network.add_identity(sw_output.get_output(0)) + network.mark_output(sw_output.get_output(0)) engine = builder.build_engine(network, builder_config) return engine diff --git a/examples/tensorrt/swin/builder_fp32.py b/examples/tensorrt/swin/builder_fp32.py index f25edc58d..bb1e1aa40 100755 --- a/examples/tensorrt/swin/builder_fp32.py +++ b/examples/tensorrt/swin/builder_fp32.py @@ -244,9 +244,7 @@ def build_engine(config, args, weights_dict): #import pdb;pdb.set_trace() sw_output = swin_transformer(network, config, args, input_img, weights_dict) - output_size = weights_dict["head.bias"].shape[0] - output = network.add_fully_connected(sw_output.get_output(0), output_size, trt.Weights(weights_dict["head.weight"].numpy().astype(np.float32).flatten()), trt.Weights(weights_dict["head.bias"].numpy().astype(np.float32).flatten())) - network.mark_output(output.get_output(0)) + network.mark_output(sw_output.get_output(0)) engine = builder.build_engine(network, builder_config) return engine diff --git a/examples/tensorrt/swin/builder_int8.py b/examples/tensorrt/swin/builder_int8.py index 5c6acb5a6..47eb856e1 100644 --- a/examples/tensorrt/swin/builder_int8.py +++ b/examples/tensorrt/swin/builder_int8.py @@ -256,9 +256,7 @@ def build_engine(config, args, weights_dict): sw_output.precision = trt.float16 sw_output.set_output_type(0, trt.float16) - output_size = weights_dict["head.bias"].shape[0] - output = network.add_fully_connected(sw_output.get_output(0), output_size, trt.Weights(weights_dict["head.weight"].cpu().numpy().astype(np.float16).flatten()), trt.Weights(weights_dict["head.bias"].cpu().numpy().astype(np.float16).flatten())) - network.mark_output(output.get_output(0)) + network.mark_output(sw_output.get_output(0)) print("Before build_engine") engine = builder.build_engine(network, builder_config) print("After build_engine") diff --git a/examples/tensorrt/swin/infer_swintransformer_plugin.py b/examples/tensorrt/swin/infer_swintransformer_plugin.py index 7cb0ddd12..5b0663345 100644 --- a/examples/tensorrt/swin/infer_swintransformer_plugin.py +++ b/examples/tensorrt/swin/infer_swintransformer_plugin.py @@ -23,13 +23,10 @@ import ctypes import tensorrt as trt -import pycuda.driver as cuda -import pycuda.autoinit import sys sys.path.insert(0, "../../pytorch/swin/Swin-Transformer-Quantization") from SwinTransformer.config import get_config - from SwinTransformer.models import build_model test_time = 100 @@ -61,7 +58,7 @@ def 
parse_option(): parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps") parser.add_argument('--use-checkpoint', action='store_true', help="whether to use gradient checkpointing to save memory") - parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'], + parser.add_argument('--amp-opt-level', type=str, default='O0', choices=['O0', 'O1', 'O2'], help='mixed precision opt level, if O0, no amp is used') parser.add_argument('--output', default='output', type=str, metavar='PATH', help='root of output folder, the full path is // (default: output)') @@ -88,7 +85,7 @@ def main(config, args): validate_with_random_data(config, args, model) @torch.no_grad() -def run_swintransformer_plugin(args, config, model, image): +def run_swintransformer_plugin(args, config, model, images): TRT_LOGGER = trt.Logger(trt.Logger.INFO) # Import necessary plugins for BERT TensorRT ctypes.CDLL("../../../build/lib/libswinTransformer_plugin.so", mode=ctypes.RTLD_GLOBAL) @@ -100,65 +97,46 @@ def run_swintransformer_plugin(args, config, model, image): in_chans = config.MODEL.SWIN.IN_CHANS embed_dim = config.MODEL.SWIN.EMBED_DIM - output_size = model.head.bias.shape[0] - print("output_size ", output_size) - with open(args.engine, 'rb') as f, trt.Runtime(TRT_LOGGER) as runtime, \ runtime.deserialize_cuda_engine(f.read()) as engine, \ engine.create_execution_context() as context: context.active_optimization_profile = 0 - if args.use_fp16: - input_nbytes0 = max_batch * in_chans * img_size * img_size * trt.float16.itemsize - else: - input_nbytes0 = max_batch * in_chans * img_size * img_size * trt.float32.itemsize - stream = cuda.Stream() - - d_inputs = [cuda.mem_alloc(input_nbytes0)] - output_shape = (max_batch * output_size) - h_output = cuda.pagelocked_empty(output_shape, dtype=np.float32) - d_output = cuda.mem_alloc(h_output.nbytes) - - #import pdb;pdb.set_trace() context.set_binding_shape(0, (max_batch, in_chans, img_size, img_size)) + output_shape = tuple(context.get_binding_shape(1)) + print('output_shape binding:', output_shape) - if args.use_fp16: - image = image.astype(np.float16) - else: - image = image.astype(np.float32) + d_inputs = [images] + d_output = torch.empty(output_shape, dtype=torch.float32).cuda() - h_input_embeds = cuda.register_host_memory(np.ascontiguousarray(image.ravel())) - cuda.memcpy_htod_async(d_inputs[0], h_input_embeds, stream) + stream = torch.cuda.Stream() # warm up for i in range(warmup_time): - context.execute_async_v2(bindings=[int(d_inp) for d_inp in d_inputs] + [int(d_output)], stream_handle=stream.handle) + context.execute_async_v2(bindings=[d_inp.data_ptr() for d_inp in d_inputs] + [d_output.data_ptr()], stream_handle=stream.cuda_stream) #ignore the last fc layer op_end = time.time() for i in range(test_time): - context.execute_async_v2(bindings=[int(d_inp) for d_inp in d_inputs] + [int(d_output)], stream_handle=stream.handle) + context.execute_async_v2(bindings=[d_inp.data_ptr() for d_inp in d_inputs] + [d_output.data_ptr()], stream_handle=stream.cuda_stream) stream.synchronize() print("plugin time : ", (time.time() - op_end)/test_time*1000.0, "ms") - cuda.memcpy_dtoh_async(h_output, d_output, stream) - stream.synchronize() - - return h_output + return d_output.cpu().numpy() @torch.no_grad() def run_torch(model, images, mark): # warm up for i in range(warmup_time): - output = model(images) + output = model.forward_features(images) torch.cuda.synchronize() torch_start = time.time() #_nvtx.rangePushA("torch") for i in 
range(test_time): - torch_output = model(images) + torch_output = model.forward_features(images) #_nvtx.rangePop() torch.cuda.synchronize() torch_end = time.time() @@ -172,10 +150,8 @@ def validate_with_random_data(config, args, model): max_batch = config.DATA.BATCH_SIZE img_size = config.DATA.IMG_SIZE in_chans = config.MODEL.SWIN.IN_CHANS - images = np.random.rand(max_batch, in_chans, img_size, img_size) - - ## run pytorch plugin - plugin_output = run_swintransformer_plugin(args, config, model, images) + image = np.random.rand(1, in_chans, img_size, img_size) + images = np.repeat(image, max_batch, axis=0) if args.use_fp16: images = torch.tensor(images, dtype=torch.half) @@ -183,12 +159,15 @@ def validate_with_random_data(config, args, model): else: images = torch.tensor(images, dtype=torch.float) images = images.cuda(non_blocking=True) - traced_module = torch.jit.trace(model, images) - torch_traced_output = run_torch(traced_module, images, "torch trace") + ## run pytorch plugin + plugin_output = run_swintransformer_plugin(args, config, model, images) torch_output = run_torch(model, images, "torch") diff = abs(torch_output - plugin_output.reshape(max_batch, -1)) + print('plugin_output', plugin_output.mean((1, 2, 3)), 'torch_output',torch_output.mean((1))) print("torch_output vs plugin_output , avg diff : ", diff.mean((1)), "max diff : ", diff.max((1))) + assert diff.mean() < 0.001, "[ERROR] SWIN PLUGIN TEST FAIL !" + print("[INFO] SWIN TRT PLUGIN TEST PASS !") if __name__ == '__main__': args, config = parse_option() diff --git a/examples/tensorrt/swin/infer_swintransformer_plugin_int8.py b/examples/tensorrt/swin/infer_swintransformer_plugin_int8.py index 40afd5cc4..c4ff92df3 100644 --- a/examples/tensorrt/swin/infer_swintransformer_plugin_int8.py +++ b/examples/tensorrt/swin/infer_swintransformer_plugin_int8.py @@ -23,11 +23,8 @@ import ctypes import tensorrt as trt -import pycuda.driver as cuda -import pycuda.autoinit import sys -#sys.path.insert(0, "../third_party/Swin-Transformer") sys.path.insert(0, "../../pytorch/swin/Swin-Transformer-Quantization") sys.path.insert(0, "../../pytorch/swin") @@ -95,7 +92,7 @@ def main(config, args): validate_with_random_data(config, args, model) @torch.no_grad() -def run_swintransformer_plugin(args, config, model, image): +def run_swintransformer_plugin(args, config, model, images): TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE) # Import necessary plugins for BERT TensorRT ctypes.CDLL("../../../build/lib/libswinTransformer_plugin.so", mode=ctypes.RTLD_GLOBAL) @@ -107,61 +104,48 @@ def run_swintransformer_plugin(args, config, model, image): in_chans = config.MODEL.SWIN.IN_CHANS embed_dim = config.MODEL.SWIN.EMBED_DIM - output_size = model.head.bias.shape[0] - print("output_size ", output_size) - with open(args.engine, 'rb') as f, trt.Runtime(TRT_LOGGER) as runtime, \ runtime.deserialize_cuda_engine(f.read()) as engine, \ engine.create_execution_context() as context: context.active_optimization_profile = 0 - input_nbytes0 = max_batch * in_chans * img_size * img_size * trt.float16.itemsize - stream = cuda.Stream() - - d_inputs = [cuda.mem_alloc(input_nbytes0)] - output_shape = (max_batch * output_size) - h_output = cuda.pagelocked_empty(output_shape, dtype=np.float32) - d_output = cuda.mem_alloc(h_output.nbytes) + stream = torch.cuda.Stream() #import pdb;pdb.set_trace() context.set_binding_shape(0, (max_batch, in_chans, img_size, img_size)) + output_shape = tuple(context.get_binding_shape(1)) + print('output_shape binding:', output_shape) - image = 
image.astype(np.float16) - - # Copy input h2d - h_input_embeds = cuda.register_host_memory(np.ascontiguousarray(image).ravel()) - cuda.memcpy_htod_async(d_inputs[0], h_input_embeds, stream) + d_inputs = [images] + d_output = torch.empty(output_shape, dtype=torch.float32).cuda() # warm up for i in range(warmup_time): - context.execute_async_v2(bindings=[int(d_inp) for d_inp in d_inputs] + [int(d_output)], stream_handle=stream.handle) + context.execute_async_v2(bindings=[d_inp.data_ptr() for d_inp in d_inputs] + [d_output.data_ptr()], stream_handle=stream.cuda_stream) #ignore the last fc layer + torch.cuda.synchronize() op_end = time.time() for i in range(test_time): - context.execute_async_v2(bindings=[int(d_inp) for d_inp in d_inputs] + [int(d_output)], stream_handle=stream.handle) + context.execute_async_v2(bindings=[d_inp.data_ptr() for d_inp in d_inputs] + [d_output.data_ptr()], stream_handle=stream.cuda_stream) stream.synchronize() + torch.cuda.synchronize() print("plugin time : ", (time.time() - op_end)/test_time*1000.0, "ms") - cuda.memcpy_dtoh_async(h_output, d_output, stream) - stream.synchronize() - - return h_output + return d_output.cpu().numpy() @torch.no_grad() def run_torch(model, images, mark): # warm up for i in range(warmup_time): - output = model(images) + output = model.forward_features(images) torch.cuda.synchronize() torch_start = time.time() - #_nvtx.rangePushA("torch") for i in range(test_time): - torch_output = model(images) - #_nvtx.rangePop() + torch_output = model.forward_features(images) torch.cuda.synchronize() torch_end = time.time() torch_output = torch_output.cpu().numpy() @@ -183,15 +167,16 @@ def validate_with_random_data(config, args, model): images_float = images_float.cuda(non_blocking=True) ## run pytorch plugin - plugin_output = run_swintransformer_plugin(args, config, model, images) + plugin_output = run_swintransformer_plugin(args, config, model, images_half) # warm up model.half() - # traced_module = torch.jit.trace(model, images_half) - # torch_traced_output = run_torch(traced_module, images_half, "torch trace") torch_output = run_torch(model, images_half, "torch") + # torch_output = model.forward_features(images_half) + # torch_output = torch_output.cpu().numpy() diff = abs(torch_output - plugin_output.reshape(max_batch, -1)) + print(diff.shape) print("torch_output vs plugin_output , avg diff : ", diff.mean((1)), "max diff : ", diff.max((1))) if __name__ == '__main__': diff --git a/examples/tensorrt/swin/run_builder_int8.sh b/examples/tensorrt/swin/run_builder_int8.sh index 7c1764879..9af3ab01a 100644 --- a/examples/tensorrt/swin/run_builder_int8.sh +++ b/examples/tensorrt/swin/run_builder_int8.sh @@ -3,4 +3,5 @@ python builder_int8.py \ --cfg ../../pytorch/swin/Swin-Transformer-Quantization/SwinTransformer/configs/swin_tiny_patch4_window7_224.yaml \ --resume ../../pytorch/swin/Swin-Transformer-Quantization/calib-checkpoint/swin_tiny_patch4_window7_224_calib.pth \ --th-path ../../../build/lib/libpyt_swintransformer.so \ + --int8-mode 1 \ --output swin_transformer_int8.engine \ No newline at end of file diff --git a/examples/tensorrt/t5/extractT5ModelToBIN.py b/examples/tensorrt/t5/extractT5ModelToBIN.py index 95499b85b..00cd4a3a6 100644 --- a/examples/tensorrt/t5/extractT5ModelToBIN.py +++ b/examples/tensorrt/t5/extractT5ModelToBIN.py @@ -1,12 +1,27 @@ -import os -import sys +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import argparse +import os +import configparser import numpy as np import torch from transformers import T5ForConditionalGeneration +from pathlib import Path -modelName = 't5-small' -savePath = './para' +rename_mapping={"relative_attention_num_buckets":"relative_attention_num_buckets_or_max_pos_seq_len"} +new_configs={"structure":{"t5_with_bias":"false", "use_gated_activation":"false", "position_embedding_type":"relative"}} def fuse_decoder_qkv(model, factor, saved_dir): model_dict = {} @@ -15,33 +30,35 @@ def fuse_decoder_qkv(model, factor, saved_dir): model_dict[name] = param for i in range(model.decoder.config.num_layers): - shape = model_dict[f"decoder.block.{i}.layer.0.SelfAttention.q.weight"].T.shape - qkv = torch.cat([model_dict[f"decoder.block.{i}.layer.0.SelfAttention.q.weight"].T, - model_dict[f"decoder.block.{i}.layer.0.SelfAttention.k.weight"].T, - model_dict[f"decoder.block.{i}.layer.0.SelfAttention.v.weight"].T], dim=-1) + shape = model_dict[f"decoder.block.{i}.layer.0.SelfAttention.q.weight"].transpose(1, 0).shape + qkv = torch.cat([model_dict[f"decoder.block.{i}.layer.0.SelfAttention.q.weight"].transpose(1, 0), + model_dict[f"decoder.block.{i}.layer.0.SelfAttention.k.weight"].transpose(1, 0), + model_dict[f"decoder.block.{i}.layer.0.SelfAttention.v.weight"].transpose(1, 0)], dim=-1) qkv = qkv.reshape([shape[0], 3, shape[1]]) qkv = qkv.float().cpu().detach().numpy() split_vals = np.split(qkv, factor, axis=-1) for j in range(factor): - saved_path = saved_dir + "/" + f"decoder.block.{i}.layer.0.SelfAttention.qkv.weight.{j}.bin" + saved_path = saved_dir / f"decoder.block.{i}.layer.0.SelfAttention.qkv.weight.{j}.bin" split_vals[j].tofile(saved_path) def split_and_convert_process(key, val, factor, saved_dir): - val = val.T.detach().numpy() + if val.dim() == 2: + val = val.transpose(1, 0) + val = val.detach().numpy() saved_key = key if key.find("shared.weight") != -1: # shared weights, only need to convert the weights of rank 0 - saved_path = saved_dir + "/" + f"{saved_key}.bin" + saved_path = saved_dir / f"{saved_key}.bin" val.tofile(saved_path) - saved_path = saved_dir + "/" + f"{saved_key}_T.bin" - val.T.tofile(saved_path) + saved_path = saved_dir / f"{saved_key}_T.bin" + val.transpose(1, 0).tofile(saved_path) elif key.find("layer_norm.weight") != -1: # shared weights, only need to convert the weights of rank 0 - saved_path = saved_dir + "/" + f"{saved_key}.bin" + saved_path = saved_dir / f"{saved_key}.bin" val.tofile(saved_path) elif ( @@ -51,7 +68,7 @@ def split_and_convert_process(key, val, factor, saved_dir): ): split_vals = np.split(val, factor, axis=0) for j in range(factor): - saved_path = saved_dir + "/" + f"{saved_key}.{j:d}.bin" + saved_path = saved_dir / f"{saved_key}.{j:d}.bin" split_vals[j].tofile(saved_path) elif ( @@ -68,12 +85,12 @@ def split_and_convert_process(key, val, factor, saved_dir): ): split_vals = np.split(val, factor, axis=-1) for j in range(factor): - saved_path = saved_dir + "/" + 
f"{saved_key}.{j:d}.bin" + saved_path = saved_dir / f"{saved_key}.{j:d}.bin" split_vals[j].tofile(saved_path) elif key.find("relative_attention_bias") != -1: split_vals = np.split(val, factor, axis=0) for j in range(factor): - saved_path = saved_dir + "/" + f"{saved_key}.{j:d}.bin" + saved_path = saved_dir / f"{saved_key}.{j:d}.bin" split_vals[j].tofile(saved_path) elif ( key.find("decoder") != -1 and @@ -88,12 +105,37 @@ def split_and_convert_process(key, val, factor, saved_dir): print(f"[ERROR] cannot find key '{key}'") if __name__ == "__main__": - os.system("mkdir -p para") + parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) + parser.add_argument("-saved_dir", "-o", type=str, help="file name of output file", required=True) + parser.add_argument("-in_file", "-i", type=str, help="file name of input checkpoint file. Using model name like 't5-small' is also ok.", required=True) + args = parser.parse_args() + + saved_dir = Path(args.saved_dir) / f"1-gpu" + saved_dir.mkdir(parents=True, exist_ok=True) - t5_model = T5ForConditionalGeneration.from_pretrained('t5-small') + t5_model = T5ForConditionalGeneration.from_pretrained(args.in_file) + config = configparser.ConfigParser() + + config["encoder"] = {} + for key, val in t5_model.encoder.config.to_dict().items(): + config["encoder"][key] = f"{val}" + config["encoder"]["weight_data_type"] = "fp32" + config["decoder"] = {} + for key, val in t5_model.decoder.config.to_dict().items(): + config["decoder"][key] = f"{val}" + config["decoder"]["weight_data_type"] = "fp32" + for key, val in rename_mapping.items(): + config['encoder'][val] = config['encoder'].pop(key) + config['decoder'][val] = config['decoder'].pop(key) + for key, val in new_configs.items(): + config[key] = {} + for val_key, val_val in val.items(): + config[key][val_key] = val_val + with open(f"{saved_dir}/config.ini", 'w') as configfile: + config.write(configfile) for name, param in t5_model.named_parameters(): - split_and_convert_process(name, param, 1, savePath) - fuse_decoder_qkv(t5_model, 1, savePath) + split_and_convert_process(name, param, 1, saved_dir) + fuse_decoder_qkv(t5_model, 1, saved_dir) print("extract T5 model weight finish!") diff --git a/examples/tensorrt/t5/testT5Plugin.py b/examples/tensorrt/t5/testT5Plugin.py index 8f9765665..1e2f44acd 100644 --- a/examples/tensorrt/t5/testT5Plugin.py +++ b/examples/tensorrt/t5/testT5Plugin.py @@ -14,6 +14,7 @@ # limitations under the License. 
# +import configparser import os import sys import ctypes @@ -26,46 +27,35 @@ from transformers import PreTrainedTokenizerFast from transformers import T5Tokenizer + dir_path = os.path.dirname(os.path.realpath(__file__)) sys.path.append(dir_path + "/../../..") from examples.pytorch.decoding.utils.recover_bpe import recover_bpe -npToTrt = {np.int8:trt.int8,np.float16:trt.float16,np.int32:trt.int32,np.float32:trt.float32} -npToPFT = {np.int8:trt.PluginFieldType.INT8,np.float16:trt.PluginFieldType.FLOAT16, - np.int32:trt.PluginFieldType.INT32,np.float32:trt.PluginFieldType.FLOAT32} -npToTorch = {np.dtype('float16'):torch.float16,np.dtype('int32'):torch.int32,np.dtype('float32'):torch.float32} -device = 0 +npToTrt = {np.int8: trt.int8, np.float16: trt.float16, np.int32: trt.int32, np.float32: trt.float32} +npToPFT = {np.int8: trt.PluginFieldType.INT8, np.float16: trt.PluginFieldType.FLOAT16, np.int32: trt.PluginFieldType.INT32, np.float32: trt.PluginFieldType.FLOAT32} +npToTorch = {np.dtype('float16'): torch.float16, np.dtype('int32'): torch.int32, np.dtype('float32'): torch.float32} +device = 0 # global variables with default value -globalNMaxBatchSize = 128 -globalNMaxSeqLen = 384 -globalNBeamSize = 4 -globalNUseFP16 = 0 - -globalNHead = 8 -globalNModelDim = 512 -globalNSizePerHead = globalNModelDim / 8 -globalNInterSize = globalNModelDim * 4 -globalNLayer = 6 -globalNBucket = 32 -globalNMaxDistance = 128 -globalNSM = (lambda x: x[0]*10 + x[1])( torch.cuda.get_device_capability() ) -globalFQScale = 1.0 / math.sqrt(globalNSizePerHead) -globalNVocabSize = 32128 -globalNStartId = 0 -globalNEndId = 1 -globalFBeamDiversity = 0.0 -globalFTopP = 0.0 -globalFTemperature = 1.0 -globalFLenPenalty = 1.0 -globalFRepPenalty = 1.0 - -nMinBatchSize = 1 -nOptBatchSize = globalNMaxBatchSize -nMaxBatchSize = globalNMaxBatchSize -nMinSeqLen = 32 -nOptSeqLen = globalNMaxSeqLen -nMaxSeqLen = globalNMaxSeqLen +globalNMaxBatchSize = 128 +globalNMaxSeqLen = 384 +globalNBeamSize = 4 +globalNUseFP16 = 0 + +globalNSM = (lambda x: x[0] * 10 + x[1])(torch.cuda.get_device_capability()) + +globalFBeamDiversity = 0.0 +globalFTemperature = 1.0 +globalFLenPenalty = 0.0 +globalFRepPenalty = 1.0 + +nMinBatchSize = 1 +nOptBatchSize = globalNMaxBatchSize +nMaxBatchSize = globalNMaxBatchSize +nMinSeqLen = 1 +nOptSeqLen = globalNMaxSeqLen +nMaxSeqLen = globalNMaxSeqLen def bleu_score(pred, ref): from sacrebleu import corpus_bleu @@ -78,104 +68,77 @@ def bleu_score(pred, ref): return bleu def getT5EncoderPlugin(arg): - nBatchSize = arg['batch_size'] - nMaxSeqLen = arg['max_seq_len'] - nBeamSize = arg['beam_width'], - nHead = globalNHead - nSizePerHead = globalNSizePerHead - nInterSize = globalNInterSize - nModelDim = globalNModelDim - nLayer = globalNLayer - nBucket = globalNBucket - nMaxDistance = globalNMaxDistance - nSM = globalNSM - fQScale = globalFQScale - useFP16 = int(arg['data_type']=='fp16') + nBatchSize = arg['batch_size'] + nMaxSeqLen = arg['max_seq_len'] + nBeamSize = arg['beam_width'], + nSM = globalNSM + useFP16 = int(arg['data_type'] == 'fp16') + ckpt_path = arg['ckpt_path'].encode() for c in trt.get_plugin_registry().plugin_creator_list: if c.name == 'T5EncoderPlugin': pList = [ - trt.PluginField('max_batch_size', np.int32(nBatchSize), npToPFT[np.int32]), - trt.PluginField('max_seq_len', np.int32(nMaxSeqLen), npToPFT[np.int32]), - trt.PluginField('beam_width', np.int32(nBeamSize), npToPFT[np.int32]), - trt.PluginField('head_num', np.int32(nHead), npToPFT[np.int32]), - trt.PluginField('size_per_head', 
np.int32(nSizePerHead), npToPFT[np.int32]), - trt.PluginField('inter_size', np.int32(nInterSize), npToPFT[np.int32]), - trt.PluginField('d_model', np.int32(nModelDim), npToPFT[np.int32]), - trt.PluginField('num_layer', np.int32(nLayer), npToPFT[np.int32]), - trt.PluginField('num_bucket', np.int32(nBucket), npToPFT[np.int32]), - trt.PluginField('max_distance', np.int32(nMaxDistance), npToPFT[np.int32]), - trt.PluginField('sm', np.int32(nSM), npToPFT[np.int32]), - trt.PluginField('q_scaling', np.float32(fQScale), npToPFT[np.float32]), - trt.PluginField('useFP16', np.int32(useFP16), npToPFT[np.int32]), - ] + trt.PluginField('max_batch_size', np.int32(nBatchSize), npToPFT[np.int32]), + trt.PluginField('max_seq_len', np.int32(nMaxSeqLen), npToPFT[np.int32]), + trt.PluginField('beam_width', np.int32(nBeamSize), npToPFT[np.int32]), + trt.PluginField('sm', np.int32(nSM), npToPFT[np.int32]), + trt.PluginField('useFP16', np.int32(useFP16), npToPFT[np.int32]), + trt.PluginField('ckpt_path', ckpt_path, trt.PluginFieldType.CHAR), + ] return c.create_plugin(c.name, trt.PluginFieldCollection(pList)) return None def getT5DecodingPlugin(arg): - nBatchSize = arg['batch_size'] - nMaxSeqLen = arg['max_seq_len'] - nMemMaxSeqLen = arg['max_seq_len'] - nBeamSize = arg['beam_width'] - nHead = globalNHead - nSizePerHead = globalNSizePerHead - nInterSize = globalNInterSize - nModelDim = globalNModelDim - nLayer = globalNLayer - nVocabSize = globalNVocabSize - nBucket = globalNBucket - nMaxDistance = globalNMaxDistance - nStartId = globalNStartId - nEndId = globalNEndId - fBeamDiversity = arg['beam_search_diversity_rate'] - nTopK = arg['sampling_topk'] - fTopP = arg['sampling_topp'] - fTemperature = globalFTemperature - fLenPenalty = globalFLenPenalty - fRepPenalty = globalFRepPenalty - useFP16 = int(arg['data_type']=='fp16') + nBatchSize = arg['batch_size'] + nMaxSeqLen = arg['max_seq_len'] + nMemMaxSeqLen = arg['max_seq_len'] + nBeamSize = arg['beam_width'] + useFP16 = int(arg['data_type'] == 'fp16') + ckpt_path = arg['ckpt_path'].encode() for c in trt.get_plugin_registry().plugin_creator_list: if c.name == 'T5DecodingPlugin': pList = [ - trt.PluginField('max_batch_size', np.int32(nBatchSize), npToPFT[np.int32]), - trt.PluginField('max_seq_len', np.int32(nMaxSeqLen), npToPFT[np.int32]), - trt.PluginField('mem_max_seq_len', np.int32(nMaxSeqLen), npToPFT[np.int32]), - trt.PluginField('beam_width', np.int32(nBeamSize), npToPFT[np.int32]), - trt.PluginField('head_num', np.int32(nHead), npToPFT[np.int32]), - trt.PluginField('size_per_head', np.int32(nSizePerHead), npToPFT[np.int32]), - trt.PluginField('inter_size', np.int32(nInterSize), npToPFT[np.int32]), - trt.PluginField('d_model', np.int32(nModelDim), npToPFT[np.int32]), - trt.PluginField('num_layer', np.int32(nLayer), npToPFT[np.int32]), - trt.PluginField('vocab_size', np.int32(nVocabSize), npToPFT[np.int32]), - trt.PluginField('num_bucket', np.int32(nBucket), npToPFT[np.int32]), - trt.PluginField('max_distance', np.int32(nMaxDistance), npToPFT[np.int32]), - trt.PluginField('start_id', np.int32(nStartId), npToPFT[np.int32]), - trt.PluginField('end_id', np.int32(nEndId), npToPFT[np.int32]), - trt.PluginField('beam_search_diversity_rate', np.float32(fBeamDiversity), npToPFT[np.float32]), - trt.PluginField('top_k', np.int32(nTopK), npToPFT[np.int32]), - trt.PluginField('top_p', np.float32(fTopP), npToPFT[np.float32]), - trt.PluginField('temperature', np.float32(fTemperature), npToPFT[np.float32]), - trt.PluginField('len_penalty', np.float32(fLenPenalty), 
npToPFT[np.float32]), - trt.PluginField('repetition_penalty', np.float32(fRepPenalty), npToPFT[np.float32]), - trt.PluginField('useFP16', np.int32(useFP16), npToPFT[np.int32]), - ] + trt.PluginField('max_batch_size', np.int32(nBatchSize), npToPFT[np.int32]), + trt.PluginField('max_seq_len', np.int32(nMaxSeqLen), npToPFT[np.int32]), + trt.PluginField('mem_max_seq_len', np.int32(nMaxSeqLen), npToPFT[np.int32]), + trt.PluginField('beam_width', np.int32(nBeamSize), npToPFT[np.int32]), + trt.PluginField('useFP16', np.int32(useFP16), npToPFT[np.int32]), + trt.PluginField('ckpt_path', ckpt_path, trt.PluginFieldType.CHAR), + ] return c.create_plugin(c.name, trt.PluginFieldCollection(pList)) return None -def buildEngine(logger, arg): - builder = trt.Builder(logger) - network = builder.create_network(1) - profile = builder.create_optimization_profile() - config = builder.create_builder_config() - config.max_workspace_size = 1 << 30 - config.flags = int(arg['data_type'] == 'fp16') - - inputT0 = network.add_input('inputId', npToTrt[np.int32], [-1,-1]) - inputT1 = network.add_input('inputSeqLen', npToTrt[np.int32], [-1]) - - profile.set_shape(inputT0.name, [nMinBatchSize,nMinSeqLen],[nOptBatchSize,nOptSeqLen],[nMaxBatchSize,nMaxSeqLen]) - profile.set_shape(inputT1.name, [nMinBatchSize],[nOptBatchSize],[nMaxBatchSize]) +def buildEngine(logger, arg, trtFileName): + builder = trt.Builder(logger) + network = builder.create_network(1) + profile = builder.create_optimization_profile() + config = builder.create_builder_config() + config.max_workspace_size = 1 << 30 + config.flags = int(arg['data_type'] == 'fp16') + + inputT0 = network.add_input('inputId', npToTrt[np.int32], [-1, -1]) + inputT1 = network.add_input('inputSeqLen', npToTrt[np.int32], [-1]) + inputT2 = network.add_input('inputTopK', npToTrt[np.int32], [-1]) + inputT3 = network.add_input('inputTopP', npToTrt[np.int32], [-1]) + inputT4 = network.add_input('inputBeam_search_diversity_rate', npToTrt[np.float32], [-1]) + inputT5 = network.add_input('inputTemperature', npToTrt[np.float32], [-1]) + inputT6 = network.add_input('inputLen_penalty', npToTrt[np.float32], [-1]) + inputT7 = network.add_input('inputRepetition_penalty', npToTrt[np.float32], [-1]) + + profile.set_shape(inputT0.name, [nMinBatchSize, nMinSeqLen], [nOptBatchSize, nOptSeqLen], [nMaxBatchSize, nMaxSeqLen]) + profile.set_shape(inputT1.name, [nMinBatchSize], [nOptBatchSize], [nMaxBatchSize]) + profile.set_shape(inputT2.name, [1], [nOptBatchSize], [nMaxBatchSize]) + profile.set_shape(inputT3.name, [1], [nOptBatchSize], [nMaxBatchSize]) + profile.set_shape(inputT4.name, [1], [nOptBatchSize], [nMaxBatchSize]) + profile.set_shape(inputT5.name, [1], [nOptBatchSize], [nMaxBatchSize]) + profile.set_shape(inputT6.name, [1], [nOptBatchSize], [nMaxBatchSize]) + profile.set_shape(inputT7.name, [1], [nOptBatchSize], [nMaxBatchSize]) config.add_optimization_profile(profile) + model_config = configparser.ConfigParser() + model_config_path = os.path.join(arg["ckpt_path"], 'config.ini') + if os.path.isfile(model_config_path): + model_config.read(model_config_path) + encoderPlugin = getT5EncoderPlugin(arg) decodingPlugin = getT5DecodingPlugin(arg) if encoderPlugin == None: @@ -185,29 +148,39 @@ def buildEngine(logger, arg): print("Failed making decoding plugin!") return None - encoderLayer = network.add_plugin_v2([inputT0,inputT1], encoderPlugin) - decodingLayer = network.add_plugin_v2([encoderLayer.get_output(0),inputT1], decodingPlugin) - decodingLayer.get_output(0).name = "decodingOutput0" - 
decodingLayer.get_output(1).name = "decodingOutput1"
+    encoderLayer = network.add_plugin_v2([inputT0, inputT1], encoderPlugin)
+    decodingLayer = network.add_plugin_v2([encoderLayer.get_output(0), inputT1, inputT2, inputT3, inputT4, inputT5, inputT6, inputT7], decodingPlugin)
+    decodingLayer.get_output(0).name = "decodingOutput0"
+    decodingLayer.get_output(1).name = "decodingOutput1"
     decodingLayer.get_output(0).dtype = npToTrt[np.int32]
     decodingLayer.get_output(1).dtype = npToTrt[np.int32]
     network.mark_output(decodingLayer.get_output(0))
     network.mark_output(decodingLayer.get_output(1))
-    return builder.build_engine(network,config)
+
+    engineString = builder.build_serialized_network(network, config)
+    if engineString == None:
+        print("Failed getting serialized engine!")
+        return None
+    print("Succeeded getting serialized engine!")
+    with open(trtFileName, "wb") as f:
+        f.write(engineString)
+    print("Succeeded saving .plan file!")
+    engine = trt.Runtime(logger).deserialize_cuda_engine(engineString)
+    return engine
 
 def testBoth(arg, stream):
-    useFP16 = int(arg['data_type'] == 'fp16')
-    nBatchSize = arg['batch_size']
-    nSeqLen = arg['max_seq_len']
-    testCase = "<fp%s,bs=%d,sl=%d>"%(['32','16'][useFP16],nBatchSize,nSeqLen)
-    print("Test both Encoder and Decoding",testCase)
+    useFP16 = int(arg['data_type'] == 'fp16')
+    nBatchSize = arg['batch_size']
+    nSeqLen = arg['max_seq_len']
+    testCase = "<fp%s,bs=%d,sl=%d>" % (['32', '16'][useFP16], nBatchSize, nSeqLen)
+    print("Test both Encoder and Decoding", testCase)
     logger = trt.Logger(trt.Logger.ERROR)
     trt.init_libnvinfer_plugins(logger, '')
     ctypes.cdll.LoadLibrary(arg['lib_path'])
-    trtFile = 'T5Engine-fp' + ['32','16'][useFP16] +'.trt'
+    trtFile = 'T5Engine-fp' + ['32', '16'][useFP16] + '.plan'
     if os.path.isfile(trtFile):
         with open(trtFile, 'rb') as f:
             engineString = f.read()
@@ -217,16 +190,10 @@ def testBoth(arg, stream):
             return
         print("Succeeded loading engine!")
     else:
-        engine = buildEngine(logger, arg)
-        if engine == None:
-            print("Failed building engine!")
-            return
-        print("Succeeded building engine!")
-        with open(trtFile, 'wb') as f:
-            f.write( engine.serialize() )
+        engine = buildEngine(logger, arg, trtFile)
     context = engine.create_execution_context()
-    nInput = np.sum([ engine.binding_is_input(i) for i in range(engine.num_bindings) ])
+    nInput = np.sum([engine.binding_is_input(i) for i in range(engine.num_bindings)])
     nOutput = engine.num_bindings - nInput
     #for i in range(engine.num_bindings):
     #    print("Bind[%2d]:i[%d]->"%(i,i) if engine.binding_is_input(i) else "Bind[%2d]:o[%d]->"%(i,i-nInput),
@@ -251,77 +218,101 @@ def testBoth(arg, stream):
     torch.cuda.synchronize()
     start_time = datetime.now()
     while prev < len(src_text):
-        input_texts = src_text[prev:prev+nBatchSize]
+        input_texts = src_text[prev:prev + nBatchSize]
         prev += nBatchSize
-
+
         input_token = tokenizer(input_texts, return_tensors='pt', padding=True)
         inputId = np.ascontiguousarray(input_token['input_ids'].numpy().astype(np.int32))
-        inputMask = np.ascontiguousarray(np.sum(input_token['attention_mask'].numpy(),1).astype(np.int32))
-        nRealBatchSize,nRealSeqLen = np.shape(inputId)
-
-        context.set_binding_shape(0,[nRealBatchSize,nRealSeqLen])
-        context.set_binding_shape(1,[nRealBatchSize])
+        inputMask = np.ascontiguousarray(np.sum(input_token['attention_mask'].numpy(), 1).astype(np.int32))
+        nRealBatchSize, nRealSeqLen = np.shape(inputId)
+
+        context.set_binding_shape(0, [nRealBatchSize, nRealSeqLen])
+        context.set_binding_shape(1, [nRealBatchSize])
+        context.set_binding_shape(2, [nRealBatchSize])
+
context.set_binding_shape(3, [nRealBatchSize]) + context.set_binding_shape(4, [nRealBatchSize]) + context.set_binding_shape(5, [nRealBatchSize]) + context.set_binding_shape(6, [nRealBatchSize]) + context.set_binding_shape(7, [nRealBatchSize]) + + inputTopK = np.full([nRealBatchSize], arg['sampling_topk'], dtype=np.int32) + inputTopP = np.full([nRealBatchSize], arg['sampling_topp'], dtype=np.float32) + + inputFBeamDiversity = np.full([nRealBatchSize], globalFBeamDiversity, dtype=np.float32) + inputFTemperature = np.full([nRealBatchSize], globalFTemperature, dtype=np.float32) + inputFLenPenalty = np.full([nRealBatchSize], globalFLenPenalty, dtype=np.float32) + inputFRepPenalty = np.full([nRealBatchSize], globalFRepPenalty, dtype=np.float32) bufferD = [] - bufferD.append( torch.from_numpy(inputId).to(device) ) - bufferD.append( torch.from_numpy(inputMask).to(device) ) - bufferD.append( torch.empty(tuple(context.get_binding_shape(2)), dtype=torch.int32, device=device) ) - bufferD.append( torch.empty(tuple(context.get_binding_shape(3)), dtype=torch.int32, device=device) ) + bufferD.append(torch.from_numpy(inputId).to(device)) + bufferD.append(torch.from_numpy(inputMask).to(device)) + bufferD.append(torch.from_numpy(inputTopK).to(device)) + bufferD.append(torch.from_numpy(inputTopP).to(device)) + bufferD.append(torch.from_numpy(inputFBeamDiversity).to(device)) + bufferD.append(torch.from_numpy(inputFTemperature).to(device)) + bufferD.append(torch.from_numpy(inputFLenPenalty).to(device)) + bufferD.append(torch.from_numpy(inputFRepPenalty).to(device)) + bufferD.append(torch.empty(tuple(context.get_binding_shape(8)), dtype=torch.int32, device=device)) + bufferD.append(torch.empty(tuple(context.get_binding_shape(9)), dtype=torch.int32, device=device)) torch.cuda.synchronize() if needWarmUp: for i in range(5): - context.execute_async_v2([ b.data_ptr() for b in bufferD ], stream) + context.execute_async_v2([b.data_ptr() for b in bufferD], stream) prev = 0 needWarmUp = False torch.cuda.synchronize() start_time = datetime.now() continue - context.execute_async_v2([ b.data_ptr() for b in bufferD ], stream) + context.execute_async_v2([b.data_ptr() for b in bufferD], stream) torch.cuda.synchronize() - outputId.append( bufferD[nInput+0].cpu().numpy() ) - outputSeqLen.append( bufferD[nInput+1].cpu().numpy() ) + outputId.append(bufferD[nInput + 0].cpu().numpy()) + outputSeqLen.append(bufferD[nInput + 1].cpu().numpy()) + + if len(outputId) >= arg["max_iteration"]: + break stop_time = datetime.now() execution_time = (stop_time - start_time).total_seconds() outputText = [] - for batch_token, batch_seq_len in zip(outputId,outputSeqLen): + for batch_token, batch_seq_len in zip(outputId, outputSeqLen): for j in range(len(batch_token)): - outputText.append( fast_tokenizer.decode(batch_token[j][0][:batch_seq_len[j][0]], skip_special_tokens=True)) - + outputText.append(fast_tokenizer.decode(batch_token[j][0][:batch_seq_len[j][0]], skip_special_tokens=True)) + bleuScore = bleu_score(outputText, tgt_text[:len(outputText)]) with open("output.txt", 'w') as f: for line in outputText: f.write(line + '\n') - print("[INFO] FT translates {} batches taking {:.2f} sec to translate {} tokens, BLEU score: {:.2f}, {:.0f} tokens/sec.".format( - len(outputText)//nBatchSize, execution_time, bleuScore.sys_len, bleuScore.score, bleuScore.sys_len / execution_time)) - print("Test both Encoder and Decoding",testCase,"finish!") + print("[INFO] FT translates {} batches taking {:.2f} sec to translate {} tokens, BLEU score: {:.2f}, {:.0f} 
tokens/sec.".format(len(outputText) // nBatchSize, execution_time, bleuScore.sys_len, bleuScore.score, bleuScore.sys_len / execution_time)) + + if arg["ft_BLEU_threshold"] != None: + assert bleuScore.score >= arg["ft_BLEU_threshold"], f"[ERROR] T5Plugin Test FAIL !" + print(f"[INFO] T5Plugin Test PASS !") + print(f"[INFO] Test both Encoder and Decoding {testCase} finish!") if __name__ == '__main__': - np.set_printoptions(precision = 4, linewidth = 200, suppress = True) + np.set_printoptions(precision=4, linewidth=200, suppress=True) torch.cuda.set_device(device) - stream = 0 #torch.cuda.Stream(device).cuda_stream - #os.system('rm -f ./*.trt ./*.in') + stream = 0 #torch.cuda.Stream(device).cuda_stream + #os.system('rm -f ./*.plan ./*.in') parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('-batch', '--batch_size', type=int, metavar='NUMBER', default=32, help='batch size (default: 32)') - parser.add_argument('-beam', '--beam_width', type=int, metavar='NUMBER', default=4, help='beam width (default: 4)') - parser.add_argument('-s', '--max_seq_len', type=int, metavar='NUMBER', default=128, help='max sequence length (default: 200)') - parser.add_argument( '--source', type=str, metavar='STRING', default="../examples/pytorch/decoding/utils/translation/test.en", help="Path to the source file.") - parser.add_argument( '--target', type=str, metavar='STRING', default="../examples/pytorch/decoding/utils/translation/test.de", help="Path to the target file.") - parser.add_argument('-diversity_rate', '--beam_search_diversity_rate', type=float, metavar='NUMBER', default=0.0, help='deviersity rate of beam search. default is 0. When diversity rate = 0, it is equivalent to the naive beams earch.') - parser.add_argument('-topk', '--sampling_topk', type=int, metavar='NUMBER', default=4, help='Candidate (k) value of top k sampling in decoding. Default is 1.') - parser.add_argument('-topp', '--sampling_topp', type=float, metavar='NUMBER', default=0.0, help='Probability (p) value of top p sampling in decoding. Default is 0.0. 
')
-    parser.add_argument('-d', '--data_type', type=str, metavar='STRING', default="fp32", help='data type (default: fp32)', choices=['fp32', 'fp16'])
-    parser.add_argument('-lib_path','--lib_path', type=str, metavar='STRING', default="lib/libtrt_t5.so", help='the path of FasterTransformer pytorch t5 op library.')
-    parser.add_argument('-model', '--model', type=str, metavar='STRING', default="t5-small", help='T5 model size.', choices=["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"])
-    #parser.add_argument('-tensor_para_size', '--tensor_para_size', type=int, metavar='NUMBER', default=1, help='size of tensor parallelism (default: 1)')
-    #parser.add_argument('-pipeline_para_size', '--pipeline_para_size', type=int, metavar='NUMBER', default=1, help='size of pipeline parallelism (default: 1)')
-    #parser.add_argument( '--ckpt_path', type=str, help='path to the checkpoint file.')
-    #parser.add_argument('-max_ite', '--max_iteration', type=int, metavar='NUMBER', default=100000, help='Maximum iteraiton for translation, default is 100000 (as large as possible to run all test set).')
+    parser.add_argument('-batch', '--batch_size', type=int, metavar='NUMBER', default=32, help='batch size (default: 32)')
+    parser.add_argument('-beam', '--beam_width', type=int, metavar='NUMBER', default=4, help='beam width (default: 4)')
+    parser.add_argument('-s', '--max_seq_len', type=int, metavar='NUMBER', default=128, help='max sequence length (default: 128)')
+    parser.add_argument('--source', type=str, metavar='STRING', default="../examples/pytorch/decoding/utils/translation/test.en", help="Path to the source file.")
+    parser.add_argument('--target', type=str, metavar='STRING', default="../examples/pytorch/decoding/utils/translation/test.de", help="Path to the target file.")
+    parser.add_argument('-diversity_rate', '--beam_search_diversity_rate', type=float, metavar='NUMBER', default=0.0, help='diversity rate of beam search. Default is 0. When diversity rate = 0, it is equivalent to the naive beam search.')
+    parser.add_argument('-topk', '--sampling_topk', type=int, metavar='NUMBER', default=4, help='Candidate (k) value of top k sampling in decoding. Default is 4.')
+    parser.add_argument('-topp', '--sampling_topp', type=float, metavar='NUMBER', default=0.0, help='Probability (p) value of top p sampling in decoding. Default is 0.0.')
+    parser.add_argument('-d', '--data_type', type=str, metavar='STRING', default="fp32", help='data type (default: fp32)', choices=['fp32', 'fp16'])
+    parser.add_argument('-lib_path', '--lib_path', type=str, metavar='STRING', default="lib/libtrt_t5.so", help='the path of FasterTransformer pytorch t5 op library.')
+    parser.add_argument('-model', '--model', type=str, metavar='STRING', default="t5-small", help='T5 model size.', choices=["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"])
+    parser.add_argument( '--ckpt_path', type=str, metavar='STRING', help='path to the checkpoint file.')
+    parser.add_argument('-max_ite', '--max_iteration', type=int, metavar='NUMBER', default=100000, help='Maximum iteration for translation, default is 100000 (as large as possible to run all test set).')
+    parser.add_argument('--ft_BLEU_threshold', type=float, help='Threshold of FT BLEU score')
     arg = vars(parser.parse_args())
-    testBoth(arg,stream)
+    testBoth(arg, stream)
     print("Test finish!")
-
diff --git a/examples/tensorrt/vit/infer_visiontransformer_int8_plugin.py b/examples/tensorrt/vit/infer_visiontransformer_int8_plugin.py
new file mode 100644
index 000000000..d4b4a2193
--- /dev/null
+++ b/examples/tensorrt/vit/infer_visiontransformer_int8_plugin.py
@@ -0,0 +1,196 @@
+# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os +import time +import argparse +import numpy as np + +import torch +import torch.backends.cudnn as cudnn + +import tensorrt as trt +import ctypes + +import sys +sys.path.insert(0, "../../pytorch/vit/ViT-quantization") +sys.path.insert(0, "../../pytorch/vit/ViT-quantization/ViT-pytorch") + +from vit_int8 import VisionTransformerINT8 +import quant_utils +from models.modeling import CONFIGS + +from plugin_loader_int8 import ViTINT8PluginLoader + + +test_time = 100 +warmup_time = 10 + +def setup_torch(args): + # Prepare model + config = CONFIGS[args.model_type] + print(config) + + model = VisionTransformerINT8(config, args.img_size, zero_head=False, num_classes=1000) + model.load_state_dict(torch.load(args.pretrained_dir)) + + quant_utils.configure_model(model, args, calib=False) + model.to(args.device) + return config, model + +def setup_trt(args, config, model): + p_loader = ViTINT8PluginLoader(args.plugin_path) + p_loader.load_model_config(config, args) + engine = p_loader.build_network(model.state_dict()) + return engine, p_loader + +def parse_option(): + parser = argparse.ArgumentParser('ViT evaluation script', add_help=False) + parser.add_argument("--model_type", choices=["ViT-B_16", "ViT-B_32", "ViT-L_16", + "ViT-L_32", "ViT-H_14"], + default="ViT-B_16", + help="Which variant to use.") + parser.add_argument("--img_size", default=384, type=int, + help="Resolution size") + parser.add_argument("--pretrained_dir", type=str, default="checkpoint/ViT-B_16.npz", + help="Where to search for pretrained ViT models.") + + # easy config modification + parser.add_argument('--plugin_path', type=str, default="../../../build/lib/libvit_plugin.so", help='path to plugin lib') + parser.add_argument('--batch-size', type=int, default=32, help="batch size for single GPU") + parser.add_argument('--int8-mode', type=int, default=2, choices=[1, 2], + help="Which int8 mode to use, choose from [1, 2]") + parser.add_argument('--seed', type=int, default=42, + help="random seed for initialization") + parser.add_argument("--local_rank", type=int, default=-1, help='local rank for DistributedDataParallel') + + quant_utils.add_arguments(parser) + args, unparsed = parser.parse_known_args() + if args.quant_mode is not None: + args = quant_utils.set_args(args) + quant_utils.set_default_quantizers(args) + + if args.quant_mode == 'ft1': + args.int8_mode = 1 + elif args.quant_mode == 'ft2': + args.int8_mode = 2 + else: + raise NotImplementedError("For ViT-INT8, we only support ft1/ft2 as quant_mode") + + return args + + +def main(args): + + config, model = setup_torch(args) + engine, p_loader = setup_trt(args, config, model) + + validate_with_random_data(p_loader, model, engine) + +@torch.no_grad() +def run_trt_plugin(plugin_loader:ViTINT8PluginLoader, images, engine): + TRT_LOGGER = trt.Logger(trt.Logger.INFO) + seq_len = plugin_loader.seq_len_ + embed_dim = plugin_loader.embed_dim_ + max_batch = plugin_loader.max_batch_ + img_size = plugin_loader.img_size_ + in_chans = plugin_loader.in_chans_ + + with engine.create_execution_context() as context: + + context.active_optimization_profile = 0 + + stream = torch.cuda.Stream() + + context.set_binding_shape(0, (max_batch, in_chans, img_size, img_size)) + output_shape = tuple(context.get_binding_shape(1)) + print(output_shape) + + # Copy input h2d + d_inputs = [images] + d_output = torch.empty(output_shape, dtype=torch.float32).cuda() + + # warm up + for i in range(warmup_time): + context.execute_async_v2([d_inp.data_ptr() for d_inp in d_inputs] + [d_output.data_ptr()], 
stream.cuda_stream) + + #ignore the last fc layer + torch.cuda.synchronize() + op_end = time.time() + for i in range(test_time): + context.execute_async_v2([d_inp.data_ptr() for d_inp in d_inputs] + [d_output.data_ptr()], stream.cuda_stream) + stream.synchronize() + + torch.cuda.synchronize() + print("plugin time : ", (time.time() - op_end)/test_time*1000.0, "ms") + + return d_output.cpu().numpy() + + +@torch.no_grad() +def run_torch(model, images, mark): + # warm up + for i in range(warmup_time): + output = model(images) + + torch.cuda.synchronize() + torch_start = time.time() + for i in range(test_time): + torch_output = model.transformer(images) + + torch.cuda.synchronize() + torch_end = time.time() + torch_output = torch_output[0].cpu().numpy() + print(mark + " time : ", (torch_end - torch_start)/test_time*1000.0, "ms") + + return torch_output + +@torch.no_grad() +def validate_with_random_data(plugin_loader:ViTINT8PluginLoader, model, engine): + model.eval() + model.half() + + dtype_torch = torch.float16 + + max_batch = plugin_loader.max_batch_ + img_size = plugin_loader.img_size_ + in_chans = plugin_loader.in_chans_ + image = np.random.rand(1, in_chans, img_size, img_size) + images = np.repeat(image, max_batch, axis=0) + images_tensor = torch.tensor(images, dtype=dtype_torch) + images_tensor = images_tensor.cuda(non_blocking=True) + + plugin_output = run_trt_plugin(plugin_loader, images_tensor, engine) + + torch_output = run_torch(model, images_tensor, "torch") + print(torch_output.shape) + print(plugin_output.shape) + + diff = abs(torch_output - plugin_output.reshape(torch_output.shape)) + print("torch_output vs plugin_output , avg diff : ", diff.mean(), "max diff : ", diff.max()) + + +if __name__ == '__main__': + args = parse_option() + + seed = args.seed + int(time.time()) + torch.manual_seed(seed) + np.random.seed(seed) + cudnn.benchmark = True + + # Setup CUDA, GPU + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + args.n_gpu = torch.cuda.device_count() + args.device = device + + main(args) diff --git a/examples/tensorrt/vit/infer_visiontransformer_plugin.py b/examples/tensorrt/vit/infer_visiontransformer_plugin.py index 156ad9d87..c99ddff7a 100644 --- a/examples/tensorrt/vit/infer_visiontransformer_plugin.py +++ b/examples/tensorrt/vit/infer_visiontransformer_plugin.py @@ -172,7 +172,8 @@ def validate_with_random_data(plugin_loader:ViTPluginLoader, model, engine): diff = abs(torch_output - plugin_output.reshape(torch_output.shape)) print("torch_output vs plugin_output , avg diff : ", diff.mean(), "max diff : ", diff.max()) - + assert diff.mean() < 0.006, "[ERROR] VIT TRT PLUGIN TEST FAIL !" + print("[INFO] VIT TRT PLUGIN TEST PASS !") if __name__ == '__main__': args = parse_option() diff --git a/examples/tensorrt/vit/plugin_loader.py b/examples/tensorrt/vit/plugin_loader.py index 6210ffade..f6ba3b664 100644 --- a/examples/tensorrt/vit/plugin_loader.py +++ b/examples/tensorrt/vit/plugin_loader.py @@ -22,7 +22,7 @@ def load_weights(weight_path:str): suffix = weight_path.split('.')[-1] if suffix != 'npz': - print("Unsupport weight file: Unrecognized format %s " % suffix) + print("Unsupported weight file: Unrecognized format %s " % suffix) exit(-1) return np.load(weight_path) diff --git a/examples/tensorrt/vit/plugin_loader_int8.py b/examples/tensorrt/vit/plugin_loader_int8.py new file mode 100644 index 000000000..fd932eb4f --- /dev/null +++ b/examples/tensorrt/vit/plugin_loader_int8.py @@ -0,0 +1,198 @@ +# Copyright (c) 2020-2022, NVIDIA CORPORATION. 
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import ctypes +import sys +sys.path.insert(0, "../../pytorch/vit") +from VisionTransformerINT8WeightLoader import ViTINT8WeightLoader +import numpy as np +import os +import os.path +import tensorrt as trt +import torch + +def load_weights(weight_path:str): + suffix = weight_path.split('.')[-1] + if suffix != 'pth': + print("Unsupported weight file: Unrecognized format %s " % suffix) + exit(-1) + return torch.load(weight_path, map_location="cpu") + +class ViTINT8PluginLoader: + def __init__(self, plugin_path) -> None: + + handle = ctypes.CDLL(plugin_path, mode=ctypes.RTLD_GLOBAL) + if not handle: + raise RuntimeError("Fail to load plugin library: %s" % plugin_path) + + self.logger_ = trt.Logger(trt.Logger.INFO) + trt.init_libnvinfer_plugins(self.logger_, "") + plg_registry = trt.get_plugin_registry() + + self.plg_creator = plg_registry.get_plugin_creator("CustomVisionTransformerINT8Plugin", "1", "") + + def load_model_config(self, config, args): + self.patch_size_ = config.patches.size[0] + self.num_heads_ = config.transformer.num_heads + self.layer_num_ = config.transformer.num_layers + self.inter_size_ = config.transformer.mlp_dim + self.embed_dim_ = config.hidden_size + self.max_batch_ = args.batch_size + self.img_size_ = args.img_size + self.with_class_token_ = (config.classifier == 'token') + self.seq_len_ = pow(self.img_size_//self.patch_size_, 2) + 1 if self.with_class_token_ else 0 + self.in_chans_ = 3 + self.int8_mode_ = args.int8_mode + self.serial_name_ = "ViTINT8Engine_{}_{}_{}_{}_{}_{}_{}_{}_{}".format(self.patch_size_, + self.num_heads_ , + self.layer_num_ , + self.inter_size_, + self.embed_dim_ , + self.max_batch_ , + self.img_size_ , + self.seq_len_, + self.int8_mode_) + self.value_holder = [] + + + def build_plugin_field_collection(self, weights): + field_type = trt.PluginFieldType.FLOAT16 + arr_type = np.float16 + + self.value_holder = [np.array([self.max_batch_ ]).astype(np.int32), + np.array([self.img_size_ ]).astype(np.int32), + np.array([self.patch_size_]).astype(np.int32), + np.array([self.in_chans_ ]).astype(np.int32), + np.array([self.embed_dim_ ]).astype(np.int32), + np.array([self.num_heads_ ]).astype(np.int32), + np.array([self.inter_size_]).astype(np.int32), + np.array([self.layer_num_ ]).astype(np.int32), + np.array([self.int8_mode_ ]).astype(np.int32), + np.array([self.with_class_token_]).astype(np.int32) + ] + + max_batch = trt.PluginField("max_batch", self.value_holder[0], trt.PluginFieldType.INT32) + img_size = trt.PluginField("img_size", self.value_holder[1], trt.PluginFieldType.INT32) + patch_size = trt.PluginField("patch_size", self.value_holder[2], trt.PluginFieldType.INT32) + in_chans = trt.PluginField("in_chans", self.value_holder[3], trt.PluginFieldType.INT32) + embed_dim = trt.PluginField("embed_dim", self.value_holder[4], trt.PluginFieldType.INT32) + num_heads = trt.PluginField("num_heads", self.value_holder[5], trt.PluginFieldType.INT32) + inter_size = 
trt.PluginField("inter_size", self.value_holder[6], trt.PluginFieldType.INT32) + layer_num = trt.PluginField("layer_num", self.value_holder[7], trt.PluginFieldType.INT32) + int8_mode = trt.PluginField("int8_mode", self.value_holder[8], trt.PluginFieldType.INT32) + with_cls_token = trt.PluginField("with_cls_token", self.value_holder[9], trt.PluginFieldType.INT32) + + vit_weights = ViTINT8WeightLoader(self.layer_num_, self.img_size_, self.patch_size_, weights, + classifier='token' if self.with_class_token_ else '' ) + vit_weights.to_int8(ths_path='../../../build/lib/libpyt_vit.so') + vit_weights.to_cuda() + weights = vit_weights.listed_weight_to_dict() + + part_fc = [] + for name in weights.keys(): + if name == 'transformer.embeddings.cls_token' and (not self.with_class_token_): + continue + elif name.split('.')[-1] == 'amaxList' or name.split('.')[-1] == 'h_amaxList': + self.value_holder.append(weights[name].cpu().numpy().astype(np.float32)) + part_fc.append(trt.PluginField(name, self.value_holder[-1], trt.PluginFieldType.FLOAT32)) + else: + self.value_holder.append(weights[name].cpu().numpy().astype(np.float16)) + part_fc.append(trt.PluginField(name, self.value_holder[-1], trt.PluginFieldType.FLOAT16)) + + return trt.PluginFieldCollection([max_batch, img_size, patch_size, in_chans, embed_dim, num_heads, inter_size, layer_num, int8_mode, with_cls_token] + part_fc) + + + def build_network(self, weights): + explicit_batch_flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + # weights = load_weights(weights_path) + + with trt.Builder(self.logger_) as builder, builder.create_network(explicit_batch_flag) as network, builder.create_builder_config() as builder_config: + builder_config.max_workspace_size = 8 << 30 + builder_config.set_flag(trt.BuilderFlag.FP16) + builder_config.set_flag(trt.BuilderFlag.STRICT_TYPES) + + # Create the network + input_tensor = network.add_input(name="input_img", dtype=trt.float16, shape=(-1, self.in_chans_, self.img_size_, self.img_size_)) + + # Specify profiles + profile = builder.create_optimization_profile() + min_shape = (1, self.in_chans_, self.img_size_, self.img_size_) + ##TODO: There is a bug in TRT when opt batch is large + max_shape = (self.max_batch_, self.in_chans_, self.img_size_, self.img_size_) + profile.set_shape("input_img", min=min_shape, opt=min_shape, max=max_shape) + builder_config.add_optimization_profile(profile) + + #import pdb;pdb.set_trace() + print("Generate plugin field collection...") + pfc = self.build_plugin_field_collection(weights) + + + fn = self.plg_creator.create_plugin("vision_transformer", pfc) + inputs = [input_tensor] + vit = network.add_plugin_v2(inputs, fn) + + output_tensor = vit.get_output(0) + output_tensor.name = "visiont_transformer_output" + + vit.precision = trt.float16 + vit.set_output_type(0, trt.float16) + network.mark_output(output_tensor) + + print("Building TRT engine....") + engine = builder.build_engine(network, builder_config) + return engine + + def serialize_engine(self, engine, file_folder='./'): + if not os.path.isdir(file_folder): + self.logger_.log(self.logger_.VERBOSE, "%s is not a folder." 
% file_folder) + exit(-1) + + file_path =os.path.join(file_folder, self.serial_name_) + + self.logger_.log(self.logger_.VERBOSE, "Serializing Engine...") + serialized_engine = engine.serialize() + self.logger_.log(self.logger_.INFO, "Saving Engine to {:}".format(file_path)) + with open(file_path, "wb") as fout: + fout.write(serialized_engine) + self.logger_.log(self.logger_.INFO, "Done.") + + def deserialize_engine(self, file_folder='./'): + if not os.path.isdir(file_folder): + self.logger_.log(self.logger_.VERBOSE, "%s is not a folder." % file_folder) + exit(-1) + + file_path =os.path.join(file_folder, self.serial_name_) + if not os.path.isfile(file_path): + self.logger_.log(self.logger_.VERBOSE, "%s not exists. " % file_path) + return None + + filename = os.path.basename(file_path) + info = filename.split('_') + self.patch_size_ = int(info[1]) + self.num_heads_ = int(info[2]) + self.layer_num_ = int(info[3]) + self.inter_size_ = int(info[4]) + self.embed_dim_ = int(info[5]) + self.max_batch_ = int(info[6]) + self.img_size_ = int(info[7]) + self.seq_len_ = int(info[8]) + self.int8_mode_ = int(info[9]) + self.in_chans_ = 3 + with open(file_path, 'rb') as f: + runtime = trt.Runtime(self.logger_) + return runtime.deserialize_cuda_engine(f.read()) + + + + diff --git a/src/fastertransformer/kernels/CMakeLists.txt b/src/fastertransformer/kernels/CMakeLists.txt index 3db08301f..1d207a245 100644 --- a/src/fastertransformer/kernels/CMakeLists.txt +++ b/src/fastertransformer/kernels/CMakeLists.txt @@ -107,7 +107,11 @@ set_property(TARGET xlnet_attention_kernels PROPERTY POSITION_INDEPENDENT_CODE set_property(TARGET xlnet_attention_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -add_library(decoder_masked_multihead_attention STATIC decoder_masked_multihead_attention.cu) +set(decoder_masked_multihead_attention_files + decoder_masked_multihead_attention.cu +) +file(GLOB decoder_masked_multihead_attention_files ${decoder_masked_multihead_attention_files} ./decoder_masked_multihead_attention/*.cu) +add_library(decoder_masked_multihead_attention STATIC ${decoder_masked_multihead_attention_files}) set_property(TARGET decoder_masked_multihead_attention PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET decoder_masked_multihead_attention PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) diff --git a/src/fastertransformer/kernels/activation_int8_kernels.cu b/src/fastertransformer/kernels/activation_int8_kernels.cu index ee595a68a..10312617c 100644 --- a/src/fastertransformer/kernels/activation_int8_kernels.cu +++ b/src/fastertransformer/kernels/activation_int8_kernels.cu @@ -39,18 +39,18 @@ __inline__ __device__ T gelu(T x) template<> __inline__ __device__ half gelu(half x) { - half val = half(0.7978845608028654) * (x + half(0.044715) * x * x * x); + half val = half(0.7978845608028654) * (x + half(0.044715) * x * x * x); half fast_val = fast_tanh(val); - half cdf = half(0.5) * (half(1.0) + fast_val); + half cdf = half(0.5) * (half(1.0) + fast_val); return x * cdf; } template<> __inline__ __device__ half2 gelu(half2 val) { - half2 val_pow3 = __hmul2(val, __hmul2(val, val)); - float2 tmp_pow = __half22float2(val_pow3); - float2 tmp = __half22float2(val); + half2 val_pow3 = __hmul2(val, __hmul2(val, val)); + float2 tmp_pow = __half22float2(val_pow3); + float2 tmp = __half22float2(val); tmp.x = 0.5f * (1.0f + tanhf((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x)))); tmp.y = 0.5f * (1.0f + tanhf((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y)))); @@ -61,101 +61,101 @@ __inline__ __device__ half2 
gelu(half2 val) // grid, thread = (m), (n/4) // using char4 as output // for per-channel-quantization weight -__global__ void add_bias_gelu_COL32_int32I_int8O(int8_t* out, +__global__ void add_bias_gelu_COL32_int32I_int8O(int8_t* out, const int32_t* input, - const float* bias, - const int m, - const int n, - const float* weight_amax, - const float* input_deQFactor_div127_ptr, - const float* out_scale_ptr) + const float* bias, + const int m, + const int n, + const float* weight_amax, + const float* input_deQFactor_div127_ptr, + const float* out_scale_ptr) { const float input_deQFactor_div127 = __ldg(input_deQFactor_div127_ptr); - const float out_scale = __ldg(out_scale_ptr); + const float out_scale = __ldg(out_scale_ptr); - int col_start = threadIdx.x << 2; + int col_start = threadIdx.x << 2; char4* outTmpPtr = (char4*)out; - char4 tmp; - int outIdx = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)) >> 2; - float val; + char4 tmp; + int outIdx = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)) >> 2; + float val; - const int4 input4 = __ldg(((const int4*)input) + outIdx); + const int4 input4 = __ldg(((const int4*)input) + outIdx); const float4 weight4 = __ldg(((const float4*)weight_amax) + threadIdx.x); - const float4 bias4 = __ldg(((const float4*)bias) + threadIdx.x); + const float4 bias4 = __ldg(((const float4*)bias) + threadIdx.x); - val = static_cast(input4.x) * weight4.x * input_deQFactor_div127 + bias4.x; - val = gelu(val); + val = static_cast(input4.x) * weight4.x * input_deQFactor_div127 + bias4.x; + val = gelu(val); tmp.x = float_to_int8_rn(val * out_scale); - val = static_cast(input4.y) * weight4.y * input_deQFactor_div127 + bias4.y; - val = gelu(val); + val = static_cast(input4.y) * weight4.y * input_deQFactor_div127 + bias4.y; + val = gelu(val); tmp.y = float_to_int8_rn(val * out_scale); col_start = col_start + 1; - val = static_cast(input4.z) * weight4.z * input_deQFactor_div127 + bias4.z; - val = gelu(val); - tmp.z = float_to_int8_rn(val * out_scale); + val = static_cast(input4.z) * weight4.z * input_deQFactor_div127 + bias4.z; + val = gelu(val); + tmp.z = float_to_int8_rn(val * out_scale); col_start = col_start + 1; - val = static_cast(input4.w) * weight4.w * input_deQFactor_div127 + bias4.w; - val = gelu(val); - tmp.w = float_to_int8_rn(val * out_scale); + val = static_cast(input4.w) * weight4.w * input_deQFactor_div127 + bias4.w; + val = gelu(val); + tmp.w = float_to_int8_rn(val * out_scale); outTmpPtr[outIdx] = tmp; } -__global__ void add_bias_gelu_COL32_int32I_int8O(char4* out, - const int4* input, - const half2* bias, - const int m, - const int n, +__global__ void add_bias_gelu_COL32_int32I_int8O(char4* out, + const int4* input, + const half2* bias, + const int m, + const int n, const float4* weight_amax, - const float* input_deQFactor_div127_ptr, - const float* out_scale_ptr) + const float* input_deQFactor_div127_ptr, + const float* out_scale_ptr) { const float input_deQFactor_div127 = __ldg(input_deQFactor_div127_ptr); - const float out_scale = __ldg(out_scale_ptr); - int col_start = threadIdx.x << 2; - int threadIdx2 = threadIdx.x << 1; - char4 tmp; - int outIdx = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)) >> 2; - float val; - - const int4 input4 = __ldg(input + outIdx); - const float4 weight4 = __ldg(weight_amax + threadIdx.x); - const half2 biasTmp = __ldg(bias + threadIdx2); - const half2 biasTmp2 = __ldg(bias + threadIdx2 + 1); - - val = static_cast(input4.x) * weight4.x * input_deQFactor_div127 + 
static_cast(biasTmp.x); - val = gelu(val); + const float out_scale = __ldg(out_scale_ptr); + int col_start = threadIdx.x << 2; + int threadIdx2 = threadIdx.x << 1; + char4 tmp; + int outIdx = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)) >> 2; + float val; + + const int4 input4 = __ldg(input + outIdx); + const float4 weight4 = __ldg(weight_amax + threadIdx.x); + const half2 biasTmp = __ldg(bias + threadIdx2); + const half2 biasTmp2 = __ldg(bias + threadIdx2 + 1); + + val = static_cast(input4.x) * weight4.x * input_deQFactor_div127 + static_cast(biasTmp.x); + val = gelu(val); tmp.x = float_to_int8_rn(out_scale * val); - val = static_cast(input4.y) * weight4.y * input_deQFactor_div127 + static_cast(biasTmp.y); - val = gelu(val); + val = static_cast(input4.y) * weight4.y * input_deQFactor_div127 + static_cast(biasTmp.y); + val = gelu(val); tmp.y = float_to_int8_rn(out_scale * val); - val = static_cast(input4.z) * weight4.z * input_deQFactor_div127 + static_cast(biasTmp2.x); - val = gelu(val); + val = static_cast(input4.z) * weight4.z * input_deQFactor_div127 + static_cast(biasTmp2.x); + val = gelu(val); tmp.z = float_to_int8_rn(out_scale * val); - val = static_cast(input4.w) * weight4.w * input_deQFactor_div127 + static_cast(biasTmp2.y); - val = gelu(val); + val = static_cast(input4.w) * weight4.w * input_deQFactor_div127 + static_cast(biasTmp2.y); + val = gelu(val); tmp.w = float_to_int8_rn(out_scale * val); out[outIdx] = tmp; } template -void invokeAddBiasGeluCol32(int8_t* out, +void invokeAddBiasGeluCol32(int8_t* out, const int32_t* in, - const T* bias, - const int m, - const int n, - cudaStream_t stream, - const float* weight_amax, - const float* input_deQFactor_div127_ptr, - const float* out_scale_ptr) + const T* bias, + const int m, + const int n, + cudaStream_t stream, + const float* weight_amax, + const float* input_deQFactor_div127_ptr, + const float* out_scale_ptr) { dim3 grid(m); dim3 block(n / 4); @@ -176,114 +176,114 @@ void invokeAddBiasGeluCol32(int8_t* out, } } -template void invokeAddBiasGeluCol32(int8_t* out, +template void invokeAddBiasGeluCol32(int8_t* out, const int32_t* in, - const float* bias, - const int m, - const int n, - cudaStream_t stream, - const float* weight_amax, - const float* input_deQFactor_div127_ptr, - const float* out_scale_ptr); -template void invokeAddBiasGeluCol32(int8_t* out, + const float* bias, + const int m, + const int n, + cudaStream_t stream, + const float* weight_amax, + const float* input_deQFactor_div127_ptr, + const float* out_scale_ptr); +template void invokeAddBiasGeluCol32(int8_t* out, const int32_t* in, - const half* bias, - const int m, - const int n, - cudaStream_t stream, - const float* weight_amax, - const float* input_deQFactor_div127_ptr, - const float* out_scale_ptr); + const half* bias, + const int m, + const int n, + cudaStream_t stream, + const float* weight_amax, + const float* input_deQFactor_div127_ptr, + const float* out_scale_ptr); // add bias to matrix of m * n, CUBLASLT_ORDER_COL32 // grid, thread = (m), (n/4) // using char4 // for per-tensor-quantization weight template -__global__ void add_bias_gelu_COL32_int8IO(int8_t* out, +__global__ void add_bias_gelu_COL32_int8IO(int8_t* out, const int8_t* input, - const T* bias, - const int m, - const int n, - const float* input_deQFactor_ptr, - const float* out_scale_ptr) + const T* bias, + const int m, + const int n, + const float* input_deQFactor_ptr, + const float* out_scale_ptr) { const float input_deQFactor = __ldg(input_deQFactor_ptr); - const 
float out_scale = __ldg(out_scale_ptr); + const float out_scale = __ldg(out_scale_ptr); // int col_start = threadIdx.x << 2; - char4* outTmpPtr = (char4*)out; + char4* outTmpPtr = (char4*)out; char4* inputTmpPtr = (char4*)input; - char4 tmp; + char4 tmp; for (int col_start = threadIdx.x << 2; col_start < n; col_start += (blockDim.x << 2)) { - int outIdx = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)) >> 2; + int outIdx = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)) >> 2; float val; - tmp = __ldg(inputTmpPtr + outIdx); - val = static_cast(tmp.x) * input_deQFactor + static_cast(__ldg(bias + col_start)); - val = gelu(val); + tmp = __ldg(inputTmpPtr + outIdx); + val = static_cast(tmp.x) * input_deQFactor + static_cast(__ldg(bias + col_start)); + val = gelu(val); tmp.x = float_to_int8_rn(val * out_scale); col_start = col_start + 1; - val = static_cast(tmp.y) * input_deQFactor + static_cast(__ldg(bias + col_start)); - val = gelu(val); - tmp.y = float_to_int8_rn(val * out_scale); + val = static_cast(tmp.y) * input_deQFactor + static_cast(__ldg(bias + col_start)); + val = gelu(val); + tmp.y = float_to_int8_rn(val * out_scale); col_start = col_start + 1; - val = static_cast(tmp.z) * input_deQFactor + static_cast(__ldg(bias + col_start)); - val = gelu(val); - tmp.z = float_to_int8_rn(val * out_scale); + val = static_cast(tmp.z) * input_deQFactor + static_cast(__ldg(bias + col_start)); + val = gelu(val); + tmp.z = float_to_int8_rn(val * out_scale); col_start = col_start + 1; - val = static_cast(tmp.w) * input_deQFactor + static_cast(__ldg(bias + col_start)); - val = gelu(val); - tmp.w = float_to_int8_rn(val * out_scale); + val = static_cast(tmp.w) * input_deQFactor + static_cast(__ldg(bias + col_start)); + val = gelu(val); + tmp.w = float_to_int8_rn(val * out_scale); outTmpPtr[outIdx] = tmp; } } template -void invokeAddBiasGeluCol32(int8_t* out, +void invokeAddBiasGeluCol32(int8_t* out, const int8_t* in, - const T* bias, - const int m, - const int n, - cudaStream_t stream, - const float* input_deQFactor_ptr, - const float* out_scale_ptr) + const T* bias, + const int m, + const int n, + cudaStream_t stream, + const float* input_deQFactor_ptr, + const float* out_scale_ptr) { dim3 grid; dim3 block; if (n / 4 <= 1024) { block.x = n / 4; - grid.x = m; + grid.x = m; } else { block.x = 1024; - grid.x = m; + grid.x = m; } add_bias_gelu_COL32_int8IO<<>>(out, in, bias, m, n, input_deQFactor_ptr, out_scale_ptr); } -template void invokeAddBiasGeluCol32(int8_t* out, +template void invokeAddBiasGeluCol32(int8_t* out, const int8_t* in, - const float* bias, - const int m, - const int n, - cudaStream_t stream, - const float* input_deQFactor_ptr, - const float* out_scale_ptr); - -template void invokeAddBiasGeluCol32(int8_t* out, + const float* bias, + const int m, + const int n, + cudaStream_t stream, + const float* input_deQFactor_ptr, + const float* out_scale_ptr); + +template void invokeAddBiasGeluCol32(int8_t* out, const int8_t* in, - const half* bias, - const int m, - const int n, - cudaStream_t stream, - const float* input_deQFactor_ptr, - const float* out_scale_ptr); + const half* bias, + const int m, + const int n, + cudaStream_t stream, + const float* input_deQFactor_ptr, + const float* out_scale_ptr); /******************* invokeAddBiasGeluCol32_v2 ***********************/ @@ -297,28 +297,28 @@ __global__ void add_bias_gelu_COL32_int8IO( { const float input_deQFactor = __ldg(input_deQFactor_ptr); - const float out_scale = __ldg(out_scale_ptr); + const float 
out_scale = __ldg(out_scale_ptr); for (int col_start = threadIdx.x << 2; col_start < n; col_start += (blockDim.x << 2)) { char4* outTmpPtr = (char4*)out; - char4 tmp; - int outIdx = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)) >> 2; - float val; - tmp = __ldg(outTmpPtr + outIdx); - val = static_cast(tmp.x) * input_deQFactor + static_cast(__ldg(bias + col_start)); - val = gelu(val); + char4 tmp; + int outIdx = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)) >> 2; + float val; + tmp = __ldg(outTmpPtr + outIdx); + val = static_cast(tmp.x) * input_deQFactor + static_cast(__ldg(bias + col_start)); + val = gelu(val); tmp.x = float_to_int8_rn(val * out_scale); - val = static_cast(tmp.y) * input_deQFactor + static_cast(__ldg(bias + col_start + 1)); - val = gelu(val); + val = static_cast(tmp.y) * input_deQFactor + static_cast(__ldg(bias + col_start + 1)); + val = gelu(val); tmp.y = float_to_int8_rn(val * out_scale); - val = static_cast(tmp.z) * input_deQFactor + static_cast(__ldg(bias + col_start + 2)); - val = gelu(val); + val = static_cast(tmp.z) * input_deQFactor + static_cast(__ldg(bias + col_start + 2)); + val = gelu(val); tmp.z = float_to_int8_rn(val * out_scale); - val = static_cast(tmp.w) * input_deQFactor + static_cast(__ldg(bias + col_start + 3)); - val = gelu(val); + val = static_cast(tmp.w) * input_deQFactor + static_cast(__ldg(bias + col_start + 3)); + val = gelu(val); tmp.w = float_to_int8_rn(val * out_scale); outTmpPtr[outIdx] = tmp; @@ -330,55 +330,55 @@ __global__ void add_bias_gelu_COL32_int8IO( // using char4 // for per-tensor-quantization weight template<> -__global__ void add_bias_gelu_COL32_int8IO(int8_t* out, - const half* bias, - const int m, - const int n, +__global__ void add_bias_gelu_COL32_int8IO(int8_t* out, + const half* bias, + const int m, + const int n, const float* input_deQFactor_ptr, const float* out_scale_ptr) { const float input_deQFactor = __ldg(input_deQFactor_ptr); - const float out_scale = __ldg(out_scale_ptr); + const float out_scale = __ldg(out_scale_ptr); for (int col_start = threadIdx.x << 2; col_start < n; col_start += (blockDim.x << 2)) { - char4* outTmpPtr = (char4*)out; + char4* outTmpPtr = (char4*)out; char4* inputTmpPtr = (char4*)out; - char4 tmp; - int outIdx = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)) >> 2; - half val; - half input_deQFactor_half = half(input_deQFactor); - half out_scale_half = half(out_scale); - tmp = __ldg(inputTmpPtr + outIdx); - - val = static_cast(tmp.x) * input_deQFactor_half + (__ldg(bias + col_start)); - val = gelu(val); + char4 tmp; + int outIdx = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)) >> 2; + half val; + half input_deQFactor_half = half(input_deQFactor); + half out_scale_half = half(out_scale); + tmp = __ldg(inputTmpPtr + outIdx); + + val = static_cast(tmp.x) * input_deQFactor_half + (__ldg(bias + col_start)); + val = gelu(val); tmp.x = float_to_int8_rn(float(val * out_scale_half)); col_start = col_start + 1; - val = static_cast(tmp.y) * input_deQFactor_half + (__ldg(bias + col_start)); - val = gelu(val); - tmp.y = float_to_int8_rn(float(val * out_scale_half)); + val = static_cast(tmp.y) * input_deQFactor_half + (__ldg(bias + col_start)); + val = gelu(val); + tmp.y = float_to_int8_rn(float(val * out_scale_half)); col_start = col_start + 1; - val = static_cast(tmp.z) * input_deQFactor_half + (__ldg(bias + col_start)); - val = gelu(val); - tmp.z = float_to_int8_rn(float(val * out_scale_half)); + val = 
static_cast(tmp.z) * input_deQFactor_half + (__ldg(bias + col_start)); + val = gelu(val); + tmp.z = float_to_int8_rn(float(val * out_scale_half)); col_start = col_start + 1; - val = static_cast(tmp.w) * input_deQFactor_half + (__ldg(bias + col_start)); - val = gelu(val); - tmp.w = float_to_int8_rn(float(val * out_scale_half)); + val = static_cast(tmp.w) * input_deQFactor_half + (__ldg(bias + col_start)); + val = gelu(val); + tmp.w = float_to_int8_rn(float(val * out_scale_half)); outTmpPtr[outIdx] = tmp; } } template -void invokeAddBiasGeluCol32_v2(int8_t* out, - const T* bias, - const int m, - const int n, +void invokeAddBiasGeluCol32_v2(int8_t* out, + const T* bias, + const int m, + const int n, cudaStream_t stream, const float* input_deQFactor_ptr, const float* out_scale_ptr) @@ -397,18 +397,18 @@ void invokeAddBiasGeluCol32_v2(int8_t* out, } } -template void invokeAddBiasGeluCol32_v2(int8_t* out, +template void invokeAddBiasGeluCol32_v2(int8_t* out, const float* bias, - const int m, - const int n, + const int m, + const int n, cudaStream_t stream, const float* input_deQFactor_ptr, const float* out_scale_ptr); -template void invokeAddBiasGeluCol32_v2(int8_t* out, - const half* bias, - const int m, - const int n, +template void invokeAddBiasGeluCol32_v2(int8_t* out, + const half* bias, + const int m, + const int n, cudaStream_t stream, const float* input_deQFactor_ptr, const float* out_scale_ptr); @@ -418,56 +418,56 @@ template void invokeAddBiasGeluCol32_v2(int8_t* out, // using char4 // for per-tensor-quantization weight template -__global__ void add_bias_gelu_ROW_int8IO(int8_t* out, +__global__ void add_bias_gelu_ROW_int8IO(int8_t* out, const int8_t* input, - const T* bias, - const int m, - const int n, - const float* input_deQFactor_ptr, - const float* out_scale_ptr) + const T* bias, + const int m, + const int n, + const float* input_deQFactor_ptr, + const float* out_scale_ptr) { const float input_deQFactor = __ldg(input_deQFactor_ptr); - const float out_scale = __ldg(out_scale_ptr); + const float out_scale = __ldg(out_scale_ptr); - int col_start = threadIdx.x << 2; - char4* outTmpPtr = (char4*)out; + int col_start = threadIdx.x << 2; + char4* outTmpPtr = (char4*)out; char4* inputTmpPtr = (char4*)input; - char4 tmp; - int outIdx = (blockIdx.x * n + col_start) >> 2; - float val; - tmp = __ldg(inputTmpPtr + outIdx); - val = static_cast(tmp.x) * input_deQFactor + static_cast(__ldg(bias + col_start)); - val = gelu(val); + char4 tmp; + int outIdx = (blockIdx.x * n + col_start) >> 2; + float val; + tmp = __ldg(inputTmpPtr + outIdx); + val = static_cast(tmp.x) * input_deQFactor + static_cast(__ldg(bias + col_start)); + val = gelu(val); tmp.x = float_to_int8_rn(val * out_scale); col_start = col_start + 1; - val = static_cast(tmp.y) * input_deQFactor + static_cast(__ldg(bias + col_start)); - val = gelu(val); - tmp.y = float_to_int8_rn(val * out_scale); + val = static_cast(tmp.y) * input_deQFactor + static_cast(__ldg(bias + col_start)); + val = gelu(val); + tmp.y = float_to_int8_rn(val * out_scale); col_start = col_start + 1; - val = static_cast(tmp.z) * input_deQFactor + static_cast(__ldg(bias + col_start)); - val = gelu(val); - tmp.z = float_to_int8_rn(val * out_scale); + val = static_cast(tmp.z) * input_deQFactor + static_cast(__ldg(bias + col_start)); + val = gelu(val); + tmp.z = float_to_int8_rn(val * out_scale); col_start = col_start + 1; - val = static_cast(tmp.w) * input_deQFactor + static_cast(__ldg(bias + col_start)); - val = gelu(val); - tmp.w = float_to_int8_rn(val * out_scale); + 
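// A standalone host-side sketch (illustrative, not part of this diff; names and sizes are made up)
// of the CUBLASLT_ORDER_COL32 offset math these kernels rely on: the m x n matrix is stored as
// column blocks of width 32, each block laid out row-major, so element (row, col) sits at
// (col / 32) * (m * 32) + row * 32 + (col % 32). The kernels compute the same value with bit
// operations (col & 0xffffffe0, row << 5, col & 31) and then shift right by 2 because they
// address char4/int4 vectors holding four elements.
#include <cassert>
#include <cstdio>

static int col32_offset(int row, int col, int m)  // arithmetic form
{
    return (col / 32) * (m * 32) + row * 32 + (col % 32);
}

static int col32_offset_bits(int row, int col, int m)  // bit-twiddled form used in the kernels
{
    return (col & 0xffffffe0) * m + (row << 5) + (col & 31);
}

int main()
{
    const int m = 128, n = 3072;  // hypothetical GEMM output shape
    for (int row = 0; row < m; ++row) {
        for (int col = 0; col < n; ++col) {
            assert(col32_offset(row, col, m) == col32_offset_bits(row, col, m));
        }
    }
    printf("COL32 offset forms agree; the char4 index is offset >> 2\n");
    return 0;
}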
val = static_cast(tmp.w) * input_deQFactor + static_cast(__ldg(bias + col_start)); + val = gelu(val); + tmp.w = float_to_int8_rn(val * out_scale); outTmpPtr[outIdx] = tmp; } template -void invokeAddBiasGeluRow(int8_t* out, +void invokeAddBiasGeluRow(int8_t* out, const int8_t* in, - const T* bias, - const int m, - const int n, - cudaStream_t stream, - const float* input_deQFactor_ptr, - const float* out_scale_ptr) + const T* bias, + const int m, + const int n, + cudaStream_t stream, + const float* input_deQFactor_ptr, + const float* out_scale_ptr) { dim3 grid(m); dim3 block(n / 4); @@ -476,22 +476,22 @@ void invokeAddBiasGeluRow(int8_t* out, add_bias_gelu_ROW_int8IO<<>>(out, in, bias, m, n, input_deQFactor_ptr, out_scale_ptr); } -template void invokeAddBiasGeluRow(int8_t* out, +template void invokeAddBiasGeluRow(int8_t* out, const int8_t* in, - const float* bias, - const int m, - const int n, - cudaStream_t stream, - const float* input_deQFactor_ptr, - const float* out_scale_ptr); - -template void invokeAddBiasGeluRow(int8_t* out, + const float* bias, + const int m, + const int n, + cudaStream_t stream, + const float* input_deQFactor_ptr, + const float* out_scale_ptr); + +template void invokeAddBiasGeluRow(int8_t* out, const int8_t* in, - const half* bias, - const int m, - const int n, - cudaStream_t stream, - const float* input_deQFactor_ptr, - const float* out_scale_ptr); + const half* bias, + const int m, + const int n, + cudaStream_t stream, + const float* input_deQFactor_ptr, + const float* out_scale_ptr); } // namespace fastertransformer \ No newline at end of file diff --git a/src/fastertransformer/kernels/activation_int8_kernels.h b/src/fastertransformer/kernels/activation_int8_kernels.h index ce285b3ad..e7bb8be7f 100644 --- a/src/fastertransformer/kernels/activation_int8_kernels.h +++ b/src/fastertransformer/kernels/activation_int8_kernels.h @@ -24,43 +24,43 @@ namespace fastertransformer { template -void invokeAddBiasGeluCol32(int8_t* out, +void invokeAddBiasGeluCol32(int8_t* out, const int32_t* in, - const T* bias, - const int m, - const int n, - cudaStream_t stream, - const float* weight_amax, - const float* input_deQFactor_div127_ptr, - const float* out_scale_ptr); + const T* bias, + const int m, + const int n, + cudaStream_t stream, + const float* weight_amax, + const float* input_deQFactor_div127_ptr, + const float* out_scale_ptr); template -void invokeAddBiasGeluCol32(int8_t* out, +void invokeAddBiasGeluCol32(int8_t* out, const int8_t* in, - const T* bias, - const int m, - const int n, - cudaStream_t stream, - const float* input_deQFactor_ptr, - const float* out_scale_ptr); + const T* bias, + const int m, + const int n, + cudaStream_t stream, + const float* input_deQFactor_ptr, + const float* out_scale_ptr); template -void invokeAddBiasGeluCol32_v2(int8_t* out, - const T* bias, - const int m, - const int n, +void invokeAddBiasGeluCol32_v2(int8_t* out, + const T* bias, + const int m, + const int n, cudaStream_t stream, const float* input_deQFactor_ptr, const float* out_scale_ptr); template -void invokeAddBiasGeluRow(int8_t* out, +void invokeAddBiasGeluRow(int8_t* out, const int8_t* in, - const T* bias, - const int m, - const int n, - cudaStream_t stream, - const float* input_deQFactor_ptr, - const float* out_scale_ptr); + const T* bias, + const int m, + const int n, + cudaStream_t stream, + const float* input_deQFactor_ptr, + const float* out_scale_ptr); } // namespace fastertransformer diff --git a/src/fastertransformer/kernels/activation_kernels.cu 
b/src/fastertransformer/kernels/activation_kernels.cu index 7ff8e0fa5..c16a6077d 100644 --- a/src/fastertransformer/kernels/activation_kernels.cu +++ b/src/fastertransformer/kernels/activation_kernels.cu @@ -19,22 +19,41 @@ #include "src/fastertransformer/utils/cuda_utils.h" namespace fastertransformer { +__forceinline__ __device__ float copysignf_pos(float a, float b) +{ + float r; + r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000)); + return r; +} + +__inline__ __device__ float tanh_opt(float x) +{ +#if (__CUDA_ARCH__ >= 750) + float r; + asm("tanh.approx.f32 %0,%1; \n\t" : "=f"(r) : "f"(x)); + return r; +#else + const float exp_val = -1.f * fabs(2 * x); + return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x); +#endif +} + template __inline__ __device__ T gelu(T x) { - float cdf = 0.5f * (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x)))); + float cdf = 0.5f * (1.0f + tanh_opt((0.7978845608028654f * (x + 0.044715f * x * x * x)))); return x * cdf; } template<> __inline__ __device__ half2 gelu(half2 val) { - half2 val_pow3 = __hmul2(val, __hmul2(val, val)); - float2 tmp_pow = __half22float2(val_pow3); - float2 tmp = __half22float2(val); + half2 val_pow3 = __hmul2(val, __hmul2(val, val)); + float2 tmp_pow = __half22float2(val_pow3); + float2 tmp = __half22float2(val); - tmp.x = 0.5f * (1.0f + tanhf((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x)))); - tmp.y = 0.5f * (1.0f + tanhf((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y)))); + tmp.x = 0.5f * (1.0f + tanh_opt((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x)))); + tmp.y = 0.5f * (1.0f + tanh_opt((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y)))); return __hmul2(val, __float22half2_rn(tmp)); } @@ -43,11 +62,11 @@ template<> __inline__ __device__ __nv_bfloat162 gelu(__nv_bfloat162 val) { __nv_bfloat162 val_pow3 = bf16hmul2(val, bf16hmul2(val, val)); - float2 tmp_pow = bf1622float2(val_pow3); - float2 tmp = bf1622float2(val); + float2 tmp_pow = bf1622float2(val_pow3); + float2 tmp = bf1622float2(val); - tmp.x = 0.5f * (1.0f + tanhf((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x)))); - tmp.y = 0.5f * (1.0f + tanhf((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y)))); + tmp.x = 0.5f * (1.0f + tanh_opt((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x)))); + tmp.y = 0.5f * (1.0f + tanh_opt((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y)))); return bf16hmul2(val, __floats2bfloat162_rn(tmp.x, tmp.y)); } #endif @@ -59,7 +78,7 @@ __global__ void addBiasGelu(T* out, const T* __restrict bias, int m, int n) T val = out[id]; if (bias != nullptr) { T reg_bias = __ldg(&bias[id % n]); - val = val + reg_bias; + val = val + reg_bias; } out[id] = (T)(gelu(val)); } @@ -68,14 +87,14 @@ __global__ void addBiasGelu(T* out, const T* __restrict bias, int m, int n) template<> __global__ void addBiasGelu(half* out, const half* __restrict bias, int m, int n) { - half2* out_ptr = (half2*)out; + half2* out_ptr = (half2*)out; const half2* bias_ptr = (half2*)bias; for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { half2 val = out_ptr[id]; if (bias != nullptr) { half2 reg_bias = __ldg(&bias_ptr[id % n]); - val = __hadd2(val, reg_bias); + val = __hadd2(val, reg_bias); } out_ptr[id] = gelu(val); } @@ -85,14 +104,14 @@ __global__ void addBiasGelu(half* out, const half* __restrict bias, int m, int n template<> __global__ void addBiasGelu(__nv_bfloat16* out, const __nv_bfloat16* __restrict bias, int m, int n) { - 
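// A host-side sketch (illustrative, not part of this diff) of the identity behind the non-sm_75
// fallback branch of tanh_opt above: tanh(x) = sign(x) * (1 - e^(-2|x|)) / (1 + e^(-2|x|)),
// which copysignf_pos reproduces by OR-ing in the sign bit of x. Function names below are made
// up for the check.
#include <cassert>
#include <cmath>

static float tanh_fallback_ref(float x)
{
    const float exp_val = std::exp(-2.0f * std::fabs(x));
    return std::copysign((1.0f - exp_val) / (1.0f + exp_val), x);
}

int main()
{
    for (float x = -8.0f; x <= 8.0f; x += 0.25f) {
        assert(std::fabs(tanh_fallback_ref(x) - std::tanh(x)) < 1e-6f);
    }
    return 0;
}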
__nv_bfloat162* out_ptr = (__nv_bfloat162*)out; + __nv_bfloat162* out_ptr = (__nv_bfloat162*)out; const __nv_bfloat162* bias_ptr = (__nv_bfloat162*)bias; for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { __nv_bfloat162 val = out_ptr[id]; if (bias != nullptr) { __nv_bfloat162 reg_bias = ldg(&bias_ptr[id % n]); - val = bf16hadd2(val, reg_bias); + val = bf16hadd2(val, reg_bias); } out_ptr[id] = gelu(val); } @@ -103,14 +122,14 @@ template void invokeAddBiasGelu(T* out, const T* bias, const int m, const int n, cudaStream_t stream) { const int data_type_factor = 4 / sizeof(T); // 1 for fp32, 2 for fp16 and bf16 - dim3 block, grid; + dim3 block, grid; if (n / 4 / data_type_factor <= 1024) { block.x = n / 4 / data_type_factor; - grid.x = m; + grid.x = m; } else { block.x = 1024; - grid.x = ceil(m * n / 1024.); + grid.x = ceil(m * n / 1024.); } addBiasGelu<<>>(out, bias, m, n / data_type_factor); } @@ -122,6 +141,122 @@ template void invokeAddBiasGelu(__nv_bfloat16* out, const __nv_bfloat16* bias, const int m, const int n, cudaStream_t stream); #endif +// Invoke GeGlu (gated glue) +template +__global__ void +addBiasGatedGelu(T* hidden1, const T* hidden2, const T* __restrict bias1, const T* __restrict bias2, int m, int n) +{ + const bool use_bias = bias1 != nullptr && bias2 != nullptr; + for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { + T val1 = hidden1[id]; + T val2 = hidden2[id]; + if (use_bias) { + T reg_bias1 = __ldg(&bias1[id % n]); + T reg_bias2 = __ldg(&bias2[id % n]); + hidden1[id] = (T)(gelu(val1 + reg_bias1) * (val2 + reg_bias2)); + } + else { + hidden1[id] = (T)(gelu(val1) * val2); + } + } +} + +template<> +__global__ void addBiasGatedGelu( + half* hidden1, const half* hidden2, const half* __restrict bias1, const half* __restrict bias2, int m, int n) +{ + half2* hidden1_ptr = (half2*)hidden1; + const half2* hidden2_ptr = (half2*)hidden2; + const half2* bias1_ptr = (half2*)bias1; + const half2* bias2_ptr = (half2*)bias2; + const bool use_bias = bias1 != nullptr && bias2 != nullptr; + + for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { + half2 val1 = hidden1_ptr[id]; + half2 val2 = hidden2_ptr[id]; + if (use_bias) { + half2 reg_bias1 = __ldg(&bias1_ptr[id % n]); + half2 reg_bias2 = __ldg(&bias2_ptr[id % n]); + hidden1_ptr[id] = __hmul2(gelu(__hadd2(val1, reg_bias1)), __hadd2(val2, reg_bias2)); + } + else { + hidden1_ptr[id] = __hmul2(gelu(val1), val2); + } + } +} + +#ifdef ENABLE_BF16 +template<> +__global__ void addBiasGatedGelu(__nv_bfloat16* hidden1, + const __nv_bfloat16* hidden2, + const __nv_bfloat16* __restrict bias1, + const __nv_bfloat16* __restrict bias2, + int m, + int n) +{ + __nv_bfloat162* hidden1_ptr = (__nv_bfloat162*)hidden1; + const __nv_bfloat162* hidden2_ptr = (__nv_bfloat162*)hidden2; + const __nv_bfloat162* bias1_ptr = (__nv_bfloat162*)bias1; + const __nv_bfloat162* bias2_ptr = (__nv_bfloat162*)bias2; + const bool use_bias = bias1 != nullptr && bias2 != nullptr; + + for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { + __nv_bfloat162 val1 = hidden1_ptr[id]; + __nv_bfloat162 val2 = hidden2_ptr[id]; + if (use_bias) { + __nv_bfloat162 reg_bias1 = ldg(&bias1_ptr[id % n]); + __nv_bfloat162 reg_bias2 = ldg(&bias2_ptr[id % n]); + hidden1_ptr[id] = bf16hmul2(gelu(bf16hadd2(val1, reg_bias1)), bf16hadd2(val2, reg_bias2)); + } + else { + hidden1_ptr[id] = bf16hmul2(gelu(val1), val2); + } + } +} 
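// A CPU reference sketch (float only, illustrative names, not part of this diff) of what the
// addBiasGatedGelu kernels above compute per element:
//   hidden1[i] = GELU(hidden1[i] + bias1[i % n]) * (hidden2[i] + bias2[i % n])
// using the same tanh-based GELU approximation. Handy as an oracle when testing the CUDA path.
#include <cmath>
#include <vector>

static float gelu_ref(float x)
{
    const float cdf = 0.5f * (1.0f + std::tanh(0.7978845608028654f * (x + 0.044715f * x * x * x)));
    return x * cdf;
}

static void add_bias_gated_gelu_ref(std::vector<float>&       hidden1,
                                    const std::vector<float>& hidden2,
                                    const std::vector<float>& bias1,
                                    const std::vector<float>& bias2,
                                    int                       m,
                                    int                       n)
{
    for (int i = 0; i < m * n; ++i) {
        const float v1 = hidden1[i] + bias1[i % n];
        const float v2 = hidden2[i] + bias2[i % n];
        hidden1[i] = gelu_ref(v1) * v2;  // gate the GELU activation with the second projection
    }
}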
+#endif + +template +void invokeAddBiasGatedGelu( + T* hidden1, const T* hidden2, const T* bias1, const T* bias2, const int m, const int n, cudaStream_t stream) +{ + const int data_type_factor = 4 / sizeof(T); // 1 for fp32, 2 for fp16 and bf16 + dim3 block, grid; + if (n / 4 / data_type_factor <= 1024) { + block.x = n / 4 / data_type_factor; + grid.x = m; + } + else { + block.x = 1024; + grid.x = ceil(m * n / 1024.); + } + addBiasGatedGelu<<>>(hidden1, hidden2, bias1, bias2, m, n / data_type_factor); +} + +// GELU(hidden1 + bias1) * (hidden2 + bias2) +template void invokeAddBiasGatedGelu(float* hidden1, + const float* hidden2, + const float* bias1, + const float* bias2, + const int m, + const int n, + cudaStream_t stream); +template void invokeAddBiasGatedGelu(half* hidden1, + const half* hidden2, + const half* bias1, + const half* bias2, + const int m, + const int n, + cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeAddBiasGatedGelu(__nv_bfloat16* hidden1, + const __nv_bfloat16* hidden2, + const __nv_bfloat16* bias1, + const __nv_bfloat16* bias2, + const int m, + const int n, + cudaStream_t stream); +#endif + template __global__ void add_bias_relu(T* out, const T* __restrict bias, int m, int n) { @@ -137,7 +272,7 @@ __global__ void add_bias_relu(T* out, const T* __restrict bias, int m, int n) template<> __global__ void add_bias_relu(half* out, const half* __restrict bias, int m, int n) { - half2* out_ptr = (half2*)out; + half2* out_ptr = (half2*)out; const half2* bias_ptr = (half2*)bias; for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { @@ -145,8 +280,8 @@ __global__ void add_bias_relu(half* out, const half* __restrict bias, int m, int if (bias != nullptr) { val = val + __ldg(&bias_ptr[id % n]); } - val.x = val.x > (half)0.0f ? val.x : (half)0.0f; - val.y = val.y > (half)0.0f ? val.y : (half)0.0f; + val.x = val.x > (half)0.0f ? val.x : (half)0.0f; + val.y = val.y > (half)0.0f ? val.y : (half)0.0f; out_ptr[id] = val; } } @@ -155,7 +290,7 @@ __global__ void add_bias_relu(half* out, const half* __restrict bias, int m, int template<> __global__ void add_bias_relu(__nv_bfloat16* out, const __nv_bfloat16* __restrict bias, int m, int n) { - __nv_bfloat162* out_ptr = (__nv_bfloat162*)out; + __nv_bfloat162* out_ptr = (__nv_bfloat162*)out; const __nv_bfloat162* bias_ptr = (__nv_bfloat162*)bias; for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { @@ -163,8 +298,8 @@ __global__ void add_bias_relu(__nv_bfloat16* out, const __nv_bfloat16* __restric if (bias != nullptr) { val = bf16hadd2(val, ldg(&bias_ptr[id % n])); } - val.x = val.x > (__nv_bfloat16)0.0f ? val.x : (__nv_bfloat16)0.0f; - val.y = val.y > (__nv_bfloat16)0.0f ? val.y : (__nv_bfloat16)0.0f; + val.x = val.x > (__nv_bfloat16)0.0f ? val.x : (__nv_bfloat16)0.0f; + val.y = val.y > (__nv_bfloat16)0.0f ? 
val.y : (__nv_bfloat16)0.0f; out_ptr[id] = val; } } @@ -174,14 +309,14 @@ template<typename T> void invokeAddBiasRelu(T* out, const T* bias, const int m, const int n, cudaStream_t stream) { const int data_type_factor = 4 / sizeof(T); // 1 for fp32, 2 for fp16 and bf16 - dim3 block, grid; + dim3 block, grid; if (n / 4 / data_type_factor <= 1024) { block.x = n / 4 / data_type_factor; - grid.x = m; + grid.x = m; } else { block.x = 1024; - grid.x = ceil(m * n / 1024.); + grid.x = ceil(m * n / 1024.); } add_bias_relu<<<grid, block, 0, stream>>>(out, bias, m, n / data_type_factor); } @@ -193,6 +328,123 @@ template void invokeAddBiasRelu(__nv_bfloat16* out, const __nv_bfloat16* bias, const int m, const int n, cudaStream_t stream); #endif +// Invoke ReGlu (gated relu) +template<typename T> +__global__ void +addBiasGatedRelu(T* hidden1, const T* hidden2, const T* __restrict bias1, const T* __restrict bias2, int m, int n) +{ + const bool use_bias = bias1 != nullptr && bias2 != nullptr; + for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { + T val1 = hidden1[id]; + T val2 = hidden2[id]; + if (use_bias) { + T reg_bias1 = __ldg(&bias1[id % n]); + T reg_bias2 = __ldg(&bias2[id % n]); + val1 += reg_bias1; + val2 += reg_bias2; + } + hidden1[id] = val1 > (T)0.0f ? val1 * val2 : (T)0.0f; + } +} + +template<> +__global__ void addBiasGatedRelu( + half* hidden1, const half* hidden2, const half* __restrict bias1, const half* __restrict bias2, int m, int n) +{ + half2* hidden1_ptr = (half2*)hidden1; + const half2* hidden2_ptr = (half2*)hidden2; + const half2* bias1_ptr = (half2*)bias1; + const half2* bias2_ptr = (half2*)bias2; + const bool use_bias = bias1 != nullptr && bias2 != nullptr; + + for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { + half2 val1 = hidden1_ptr[id]; + half2 val2 = hidden2_ptr[id]; + if (use_bias) { + half2 reg_bias1 = __ldg(&bias1_ptr[id % n]); + half2 reg_bias2 = __ldg(&bias2_ptr[id % n]); + val1 = __hadd2(val1, reg_bias1); + val2 = __hadd2(val2, reg_bias2); + } + val1.x = val1.x > (half)0.0f ? val1.x * val2.x : (half)0.0f; + val1.y = val1.y > (half)0.0f ? val1.y * val2.y : (half)0.0f; + hidden1_ptr[id] = val1; + } +} + +#ifdef ENABLE_BF16 +template<> +__global__ void addBiasGatedRelu(__nv_bfloat16* hidden1, + const __nv_bfloat16* hidden2, + const __nv_bfloat16* __restrict bias1, + const __nv_bfloat16* __restrict bias2, + int m, + int n) +{ + __nv_bfloat162* hidden1_ptr = (__nv_bfloat162*)hidden1; + const __nv_bfloat162* hidden2_ptr = (__nv_bfloat162*)hidden2; + const __nv_bfloat162* bias1_ptr = (__nv_bfloat162*)bias1; + const __nv_bfloat162* bias2_ptr = (__nv_bfloat162*)bias2; + const bool use_bias = bias1 != nullptr && bias2 != nullptr; + + for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { + __nv_bfloat162 val1 = hidden1_ptr[id]; + __nv_bfloat162 val2 = hidden2_ptr[id]; + if (use_bias) { + __nv_bfloat162 reg_bias1 = ldg(&bias1_ptr[id % n]); + __nv_bfloat162 reg_bias2 = ldg(&bias2_ptr[id % n]); + val1 = bf16hadd2(val1, reg_bias1); + val2 = bf16hadd2(val2, reg_bias2); + } + val1.x = val1.x > (__nv_bfloat16)0.0f ? bf16hmul(val1.x, val2.x) : (__nv_bfloat16)0.0f; + val1.y = val1.y > (__nv_bfloat16)0.0f ? 
bf16hmul(val1.y, val2.y) : (__nv_bfloat16)0.0f; + hidden1_ptr[id] = val1; + } +} +#endif + +template<typename T> +void invokeAddBiasGatedRelu( + T* hidden1, const T* hidden2, const T* bias1, const T* bias2, const int m, const int n, cudaStream_t stream) +{ + const int data_type_factor = 4 / sizeof(T); // 1 for fp32, 2 for fp16 and bf16 + dim3 block, grid; + if (n / 4 / data_type_factor <= 1024) { + block.x = n / 4 / data_type_factor; + grid.x = m; + } + else { + block.x = 1024; + grid.x = ceil(m * n / 1024.); + } + addBiasGatedRelu<<<grid, block, 0, stream>>>(hidden1, hidden2, bias1, bias2, m, n / data_type_factor); +} + +// ReLU(hidden1 + bias1) * (hidden2 + bias2) +template void invokeAddBiasGatedRelu(float* hidden1, + const float* hidden2, + const float* bias1, + const float* bias2, + const int m, + const int n, + cudaStream_t stream); +template void invokeAddBiasGatedRelu(half* hidden1, + const half* hidden2, + const half* bias1, + const half* bias2, + const int m, + const int n, + cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeAddBiasGatedRelu(__nv_bfloat16* hidden1, + const __nv_bfloat16* hidden2, + const __nv_bfloat16* bias1, + const __nv_bfloat16* bias2, + const int m, + const int n, + cudaStream_t stream); +#endif + template<typename H_T, typename B_T> __global__ void add_bias(H_T* out, const B_T* __restrict bias, int m, int n) { @@ -204,7 +456,7 @@ __global__ void add_bias(H_T* out, const B_T* __restrict bias, int m, int n) template<> __global__ void add_bias(half* out, const half* __restrict bias, int m, int n) { - half2* out_ptr = (half2*)out; + half2* out_ptr = (half2*)out; const half2* bias_ptr = (half2*)bias; for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { out_ptr[id] = out_ptr[id] + __ldg(&bias_ptr[id % n]); @@ -215,7 +467,7 @@ __global__ void add_bias(half* out, const half* __restrict bias, int m, int n) template<> __global__ void add_bias(__nv_bfloat16* out, const __nv_bfloat16* __restrict bias, int m, int n) { - __nv_bfloat162* out_ptr = (__nv_bfloat162*)out; + __nv_bfloat162* out_ptr = (__nv_bfloat162*)out; const __nv_bfloat162* bias_ptr = (__nv_bfloat162*)bias; for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { out_ptr[id] = bf16hadd2(out_ptr[id], ldg(&bias_ptr[id % n])); @@ -227,14 +479,14 @@ template<typename H_T, typename B_T> void invokeAddBias(H_T* out, const B_T* bias, const int m, const int n, cudaStream_t stream) { const int data_type_factor = 4 / sizeof(H_T); // 1 for fp32, 2 for fp16 and bf16 - dim3 block, grid; + dim3 block, grid; if (n / 4 / data_type_factor <= 1024) { block.x = n / 4 / data_type_factor; - grid.x = m; + grid.x = m; } else { block.x = 1024; - grid.x = ceil(m * n / 1024.); + grid.x = ceil(m * n / 1024.); } add_bias<<<grid, block, 0, stream>>>(out, bias, m, n / data_type_factor); } @@ -248,4 +500,305 @@ invokeAddBias(__nv_bfloat16* out, const __nv_bfloat16* bias, const int m, const template void invokeAddBias(float* out, const __nv_bfloat16* bias, const int m, const int n, cudaStream_t stream); #endif +template<typename T2, int N> +__global__ void addBiasGeluV2(T2* out, const T2* __restrict bias, const int size) +{ + for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < size; id += blockDim.x * gridDim.x) { + T2 val = out[id]; + if (bias != nullptr) { + T2 reg_bias = ldg(&bias[id % N]); + val = hadd2(val, reg_bias); + } + out[id] = gelu(val); + } +} + +template<typename T2, int N, int ELEMENT_PER_ROUND> +__global__ void addBiasGeluV3(T2* out, const T2* __restrict bias, const int size) +{ + T2 buffer[ELEMENT_PER_ROUND]; + T2 tmp_bias[ELEMENT_PER_ROUND]; + for (int id = blockIdx.x * blockDim.x * 
ELEMENT_PER_ROUND + threadIdx.x * ELEMENT_PER_ROUND; id < size; + id += blockDim.x * gridDim.x * ELEMENT_PER_ROUND) { +#pragma unroll + for (int i = 0; i < ELEMENT_PER_ROUND; i++) { + buffer[i] = out[id + i]; + if (bias != nullptr) { + tmp_bias[i] = ldg(&bias[(id + i) % N]); + } + } +#pragma unroll + for (int i = 0; i < ELEMENT_PER_ROUND; i++) { + if (bias != nullptr) { + buffer[i] = hadd2(buffer[i], tmp_bias[i]); + } + out[id + i] = gelu(buffer[i]); + } + } +} + +#define ADD_BIAS_GELU(HALF_N, ELEMENT_PER_ROUND) \ + case HALF_N: \ + if (ELEMENT_PER_ROUND > 1) { \ + grid.x = grid.x / ELEMENT_PER_ROUND; \ + addBiasGeluV3 \ + <<>>((T2*)out, (const T2*)bias, m * half_n); \ + } \ + else { \ + addBiasGeluV2<<>>((T2*)out, (const T2*)bias, m * half_n); \ + } \ + break; + +template +void invokeAddBiasGeluV2(T* out, const T* bias, const int m, const int n, cudaStream_t stream) +{ + if (n % 2 == 0 && sizeof(T) == 2) { + const int half_n = n / 2; + dim3 block, grid; + block.x = std::min(half_n, 512); + grid.x = (m * half_n + (block.x - 1)) / block.x; + using T2 = typename TypeConverter::Type; + + if (grid.x >= 512) { + switch (half_n) { + ADD_BIAS_GELU(256, 1) + ADD_BIAS_GELU(512, 1) + ADD_BIAS_GELU(1024, 1) + ADD_BIAS_GELU(1536, 1) + ADD_BIAS_GELU(2048, 1) + ADD_BIAS_GELU(4096, 2) + ADD_BIAS_GELU(8192, 2) + ADD_BIAS_GELU(16384, 2) + ADD_BIAS_GELU(24576, 2) + ADD_BIAS_GELU(40960, 4) + default: + invokeAddBiasGelu(out, bias, m, n, stream); + break; + } + } + else { + switch (half_n) { + ADD_BIAS_GELU(256, 1) + ADD_BIAS_GELU(512, 1) + ADD_BIAS_GELU(1024, 1) + ADD_BIAS_GELU(1536, 1) + ADD_BIAS_GELU(2048, 1) + ADD_BIAS_GELU(4096, 1) + ADD_BIAS_GELU(8192, 2) + ADD_BIAS_GELU(16384, 2) + ADD_BIAS_GELU(24576, 2) + ADD_BIAS_GELU(40960, 2) + default: + invokeAddBiasGelu(out, bias, m, n, stream); + break; + } + } + } + else { + invokeAddBiasGelu(out, bias, m, n, stream); + } +} + +#undef ADD_BIAS_GELU + +template void invokeAddBiasGeluV2(float* out, const float* bias, const int m, const int n, cudaStream_t stream); +template void invokeAddBiasGeluV2(half* out, const half* bias, const int m, const int n, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void +invokeAddBiasGeluV2(__nv_bfloat16* out, const __nv_bfloat16* bias, const int m, const int n, cudaStream_t stream); +#endif // ENABLE_BF16 + +template +__inline__ __device__ T silu(T x) +{ + return (T)((float)x / (1.0f + __expf((float)-x))); +} + +template +__global__ void add_bias_silu(T* out, const T* __restrict bias, int m, int n) +{ + for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { + T val = out[id]; + if (bias != nullptr) { + val = val + ldg(&bias[id % n]); + } + out[id] = silu(val); + } +} + +template<> +__global__ void add_bias_silu(half* out, const half* __restrict bias, int m, int n) +{ + half2* out_ptr = (half2*)out; + const half2* bias_ptr = (half2*)bias; + + for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { + half2 val = out_ptr[id]; + if (bias != nullptr) { + val = val + __ldg(&bias_ptr[id % n]); + } + val.x = silu(val.x); + val.y = silu(val.y); + out_ptr[id] = val; + } +} + +#ifdef ENABLE_BF16 +template<> +__global__ void add_bias_silu(__nv_bfloat16* out, const __nv_bfloat16* __restrict bias, int m, int n) +{ + __nv_bfloat162* out_ptr = (__nv_bfloat162*)out; + const __nv_bfloat162* bias_ptr = (__nv_bfloat162*)bias; + + for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { + __nv_bfloat162 val = out_ptr[id]; 
+ if (bias != nullptr) { + val = bf16hadd2(val, ldg(&bias_ptr[id % n])); + } + val.x = silu(val.x); + val.y = silu(val.y); + out_ptr[id] = val; + } +} +#endif + +template +void invokeAddBiasSilu(T* out, const T* bias, const int m, const int n, cudaStream_t stream) +{ + const int data_type_factor = 4 / sizeof(T); // 1 for fp32, 2 for fp16 and bf16 + dim3 block, grid; + if (n / 4 / data_type_factor <= 1024) { + block.x = n / 4 / data_type_factor; + grid.x = m; + } + else { + block.x = 1024; + grid.x = ceil(m * n / 1024.); + } + add_bias_silu<<>>(out, bias, m, n / data_type_factor); +} + +template void invokeAddBiasSilu(float* out, const float* bias, const int m, const int n, cudaStream_t stream); +template void invokeAddBiasSilu(half* out, const half* bias, const int m, const int n, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void +invokeAddBiasSilu(__nv_bfloat16* out, const __nv_bfloat16* bias, const int m, const int n, cudaStream_t stream); +#endif + +// Invoke GeGlu (gated glue) +template +__global__ void +addBiasGatedSilu(T* hidden1, const T* hidden2, const T* __restrict bias1, const T* __restrict bias2, int m, int n) +{ + const bool use_bias = bias1 != nullptr && bias2 != nullptr; + for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { + T val1 = hidden1[id]; + T val2 = hidden2[id]; + if (use_bias) { + T reg_bias1 = __ldg(&bias1[id % n]); + T reg_bias2 = __ldg(&bias2[id % n]); + val1 += reg_bias1; + val2 += reg_bias2; + } + hidden1[id] = silu(val1) * val2; + } +} + +template<> +__global__ void addBiasGatedSilu( + half* hidden1, const half* hidden2, const half* __restrict bias1, const half* __restrict bias2, int m, int n) +{ + half2* hidden1_ptr = (half2*)hidden1; + const half2* hidden2_ptr = (half2*)hidden2; + const half2* bias1_ptr = (half2*)bias1; + const half2* bias2_ptr = (half2*)bias2; + const bool use_bias = bias1 != nullptr && bias2 != nullptr; + + for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { + half2 val1 = hidden1_ptr[id]; + half2 val2 = hidden2_ptr[id]; + if (use_bias) { + half2 reg_bias1 = __ldg(&bias1_ptr[id % n]); + half2 reg_bias2 = __ldg(&bias2_ptr[id % n]); + val1 = __hadd2(val1, reg_bias1); + val2 = __hadd2(val2, reg_bias2); + } + val1.x = silu(val1.x) * val2.x; + val1.y = silu(val1.y) * val2.y; + hidden1_ptr[id] = val1; + } +} + +#ifdef ENABLE_BF16 +template<> +__global__ void addBiasGatedSilu(__nv_bfloat16* hidden1, + const __nv_bfloat16* hidden2, + const __nv_bfloat16* __restrict bias1, + const __nv_bfloat16* __restrict bias2, + int m, + int n) +{ + __nv_bfloat162* hidden1_ptr = (__nv_bfloat162*)hidden1; + const __nv_bfloat162* hidden2_ptr = (__nv_bfloat162*)hidden2; + const __nv_bfloat162* bias1_ptr = (__nv_bfloat162*)bias1; + const __nv_bfloat162* bias2_ptr = (__nv_bfloat162*)bias2; + const bool use_bias = bias1 != nullptr && bias2 != nullptr; + + for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { + __nv_bfloat162 val1 = hidden1_ptr[id]; + __nv_bfloat162 val2 = hidden2_ptr[id]; + if (use_bias) { + __nv_bfloat162 reg_bias1 = ldg(&bias1_ptr[id % n]); + __nv_bfloat162 reg_bias2 = ldg(&bias2_ptr[id % n]); + val1 = bf16hadd2(val1, reg_bias1); + val2 = bf16hadd2(val2, reg_bias2); + } + val1.x = (__nv_bfloat16)(silu((float)val1.x) * (float)val2.x); + val1.y = (__nv_bfloat16)(silu((float)val1.y) * (float)val2.y); + hidden1_ptr[id] = val1; + } +} +#endif + +template +void invokeAddBiasGatedSilu( + T* hidden1, const T* hidden2, 
const T* bias1, const T* bias2, const int m, const int n, cudaStream_t stream) +{ + const int data_type_factor = 4 / sizeof(T); // 1 for fp32, 2 for fp16 and bf16 + dim3 block, grid; + if (n / 4 / data_type_factor <= 1024) { + block.x = n / 4 / data_type_factor; + grid.x = m; + } + else { + block.x = 1024; + grid.x = ceil(m * n / 1024.); + } + addBiasGatedSilu<<>>(hidden1, hidden2, bias1, bias2, m, n / data_type_factor); +} + +template void invokeAddBiasGatedSilu(float* hidden1, + const float* hidden2, + const float* bias1, + const float* bias2, + const int m, + const int n, + cudaStream_t stream); +template void invokeAddBiasGatedSilu(half* hidden1, + const half* hidden2, + const half* bias1, + const half* bias2, + const int m, + const int n, + cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeAddBiasGatedSilu(__nv_bfloat16* hidden1, + const __nv_bfloat16* hidden2, + const __nv_bfloat16* bias1, + const __nv_bfloat16* bias2, + const int m, + const int n, + cudaStream_t stream); +#endif + } // namespace fastertransformer diff --git a/src/fastertransformer/kernels/activation_kernels.h b/src/fastertransformer/kernels/activation_kernels.h index 6600457ef..2a1a4a8fe 100644 --- a/src/fastertransformer/kernels/activation_kernels.h +++ b/src/fastertransformer/kernels/activation_kernels.h @@ -25,10 +25,28 @@ namespace fastertransformer { template void invokeAddBiasGelu(T* out, const T* bias, const int m, const int n, cudaStream_t stream); +template +void invokeAddBiasGatedGelu( + T* hidden1, const T* hidden2, const T* bias1, const T* bias2, const int m, const int n, cudaStream_t stream); + template void invokeAddBiasRelu(T* out, const T* bias, const int m, const int n, cudaStream_t stream); +template +void invokeAddBiasGatedRelu( + T* hidden1, const T* hidden2, const T* bias1, const T* bias2, const int m, const int n, cudaStream_t stream); + template void invokeAddBias(F_T* out, const B_T* bias, const int m, const int n, cudaStream_t stream); +template +void invokeAddBiasGeluV2(T* out, const T* bias, const int m, const int n, cudaStream_t stream); + +template +void invokeAddBiasGatedSilu( + T* hidden1, const T* hidden2, const T* bias1, const T* bias2, const int m, const int n, cudaStream_t stream); + +template +void invokeAddBiasSilu(T* out, const T* bias, const int m, const int n, cudaStream_t stream); + } // namespace fastertransformer diff --git a/src/fastertransformer/kernels/add_bias_transpose_kernels.cu b/src/fastertransformer/kernels/add_bias_transpose_kernels.cu index 952bcb863..55b086955 100644 --- a/src/fastertransformer/kernels/add_bias_transpose_kernels.cu +++ b/src/fastertransformer/kernels/add_bias_transpose_kernels.cu @@ -19,13 +19,14 @@ #include #include "add_bias_transpose_kernels.h" +#include "src/fastertransformer/kernels/bfloat16_fallback_kenrels.cuh" namespace fastertransformer { template -__global__ void addBiasTransposeToMultiHead(const T* matrices, - const T* biases, - T* output, +__global__ void addBiasTransposeToMultiHead(const T* matrices, + const T* biases, + T* output, const int batch_size, const int head_num, const int size_per_head, @@ -36,20 +37,20 @@ __global__ void addBiasTransposeToMultiHead(const T* matrices, { for (int j = 0; j < y_repeat_per_block; j++) { int y_offset = blockIdx.y * blockDim.y * y_repeat_per_block + j * blockDim.y + threadIdx.y; - int bias_id = -1; - T bias_element; + int bias_id = -1; + T bias_element; for (int i = 0; i < x_repeat_per_block; i++) { - int x_offset = blockIdx.x * blockDim.x * x_repeat_per_block + i * blockDim.x + 
threadIdx.x; + int x_offset = blockIdx.x * blockDim.x * x_repeat_per_block + i * blockDim.x + threadIdx.x; int bias_id_new = x_offset / (batch_size * seq_len); if (bias_id_new != bias_id) { bias_element = biases[bias_id_new * head_num * size_per_head + y_offset]; - bias_id = bias_id_new; + bias_id = bias_id_new; } if (x_offset < batch_size * seq_len * matrices_num && y_offset < head_num * size_per_head) { - int matrix_id = x_offset / (batch_size * seq_len); - int batch_id = (x_offset % (batch_size * seq_len)) / seq_len; - int seq_id = x_offset % seq_len; - int head_id = y_offset / size_per_head; + int matrix_id = x_offset / (batch_size * seq_len); + int batch_id = (x_offset % (batch_size * seq_len)) / seq_len; + int seq_id = x_offset % seq_len; + int head_id = y_offset / size_per_head; int head_y_offset = y_offset % size_per_head; int output_offset = matrix_id * batch_size * head_num * seq_len * size_per_head @@ -63,14 +64,14 @@ __global__ void addBiasTransposeToMultiHead(const T* matrices, } template -void invokeAddBiasTransposeToMultiHead(const T* matrices, - const T* biases, - T* output, - const int batch_size, - const int head_num, - const int size_per_head, - const int seq_len, - const int matrices_num, +void invokeAddBiasTransposeToMultiHead(const T* matrices, + const T* biases, + T* output, + const int batch_size, + const int head_num, + const int size_per_head, + const int seq_len, + const int matrices_num, const cudaStream_t stream) { /* @@ -86,8 +87,8 @@ void invokeAddBiasTransposeToMultiHead(const T* matrices, const int x_repeat_per_block = 8; const int y_repeat_per_block = 1; - const int block_dim_x = 1; - const int block_dim_y = 32; + const int block_dim_x = 1; + const int block_dim_y = 32; const int x_total_len = matrices_num * batch_size * seq_len; const int y_total_len = head_num * size_per_head; @@ -108,26 +109,38 @@ void invokeAddBiasTransposeToMultiHead(const T* matrices, y_repeat_per_block); } -template void invokeAddBiasTransposeToMultiHead(const float* matrices, - const float* biases, - float* output, - const int batch_size, - const int head_num, - const int size_per_head, - const int seq_len, - const int matrices_num, +template void invokeAddBiasTransposeToMultiHead(const float* matrices, + const float* biases, + float* output, + const int batch_size, + const int head_num, + const int size_per_head, + const int seq_len, + const int matrices_num, const cudaStream_t stream); -template void invokeAddBiasTransposeToMultiHead(const half* matrices, - const half* biases, - half* output, - const int batch_size, - const int head_num, - const int size_per_head, - const int seq_len, - const int matrices_num, +template void invokeAddBiasTransposeToMultiHead(const half* matrices, + const half* biases, + half* output, + const int batch_size, + const int head_num, + const int size_per_head, + const int seq_len, + const int matrices_num, const cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeAddBiasTransposeToMultiHead(const __nv_bfloat16* matrices, + const __nv_bfloat16* biases, + __nv_bfloat16* output, + const int batch_size, + const int head_num, + const int size_per_head, + const int seq_len, + const int matrices_num, + const cudaStream_t stream); +#endif + __inline__ __device__ int target_index(int id1, int id2, int id3, int id4, int dim_1, int dim_2, int dim_3, int dim_4) { return id1 * (dim_2 * dim_3 * dim_4) + id3 * (dim_2 * dim_4) + id2 * dim_4 + id4; @@ -136,46 +149,46 @@ __inline__ __device__ int target_index(int id1, int id2, int id3, int id4, int d 
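// A host-side sketch (illustrative, not part of this diff) of what target_index above encodes:
// it maps element [batch=id1, head=id2, seq=id3, d=id4] of a [batch, head, seq, d] tensor to its
// offset in the transposed [batch, seq, head, d] layout, which is the permutation that
// transposeMultiHeadToSingleKernel performs. Names below are made up for the check.
#include <cassert>

static int target_index_ref(int id1, int id2, int id3, int id4, int dim_1, int dim_2, int dim_3, int dim_4)
{
    (void)dim_1;  // the batch extent does not appear in the offset itself
    return id1 * (dim_2 * dim_3 * dim_4) + id3 * (dim_2 * dim_4) + id2 * dim_4 + id4;
}

int main()
{
    const int batch = 2, head = 3, seq = 4, size_per_head = 5;  // hypothetical sizes
    for (int b = 0; b < batch; ++b)
        for (int h = 0; h < head; ++h)
            for (int s = 0; s < seq; ++s)
                for (int d = 0; d < size_per_head; ++d) {
                    const int expected = ((b * seq + s) * head + h) * size_per_head + d;  // [b, s, h, d]
                    assert(target_index_ref(b, h, s, d, batch, head, seq, size_per_head) == expected);
                }
    return 0;
}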
template __global__ void transposeMultiHeadToSingleKernel( T* src, T* dst, const int batch_size, const int seq_len, const int head_num, const int size_per_head) -{ - int batch_id = blockIdx.x / (head_num * seq_len); - int seq_id = blockIdx.x % seq_len; - int head_id = (blockIdx.x % (head_num * seq_len)) / seq_len; - dst[batch_id * (head_num * seq_len * size_per_head) + seq_id * head_num * size_per_head + head_id * size_per_head - + threadIdx.x] = src[blockIdx.x * size_per_head + threadIdx.x]; -} - -template<> -__global__ void transposeMultiHeadToSingleKernel( - half* src, half* dst, const int batch_size, const int seq_len, const int head_num, const int size_per_head) { int tid = blockIdx.x * blockDim.x + threadIdx.x; int batch_id = tid / (head_num * seq_len * size_per_head); - int head_id = (tid % (head_num * seq_len * size_per_head)) / (seq_len * size_per_head); - int seq_id = (tid % (seq_len * size_per_head)) / size_per_head; - int id = tid % size_per_head; + int head_id = (tid % (head_num * seq_len * size_per_head)) / (seq_len * size_per_head); + int seq_id = (tid % (seq_len * size_per_head)) / size_per_head; + int id = tid % size_per_head; - int target_id = target_index(batch_id, head_id, seq_id, id, batch_size, head_num, seq_len, size_per_head); - half2* src_ptr = (half2*)src; - half2* dst_ptr = (half2*)dst; + int target_id = target_index(batch_id, head_id, seq_id, id, batch_size, head_num, seq_len, size_per_head); + uint32_t* src_ptr = (uint32_t*)src; + uint32_t* dst_ptr = (uint32_t*)dst; dst_ptr[target_id] = src_ptr[tid]; } +template<> +__global__ void transposeMultiHeadToSingleKernel( + float* src, float* dst, const int batch_size, const int seq_len, const int head_num, const int size_per_head) +{ + int batch_id = blockIdx.x / (head_num * seq_len); + int seq_id = blockIdx.x % seq_len; + int head_id = (blockIdx.x % (head_num * seq_len)) / seq_len; + dst[batch_id * (head_num * seq_len * size_per_head) + seq_id * head_num * size_per_head + head_id * size_per_head + + threadIdx.x] = src[blockIdx.x * size_per_head + threadIdx.x]; +} + template -void invokeTransposeMultiHeadToSingle(T* dst, - T* src, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, +void invokeTransposeMultiHeadToSingle(T* dst, + T* src, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, cudaStream_t stream) { dim3 grid, block; if (sizeof(T) == 2) { const int seq_per_block = 4; - grid.x = batch_size * head_num * seq_len / seq_per_block; - block.x = seq_per_block * size_per_head / 2; + grid.x = batch_size * head_num * seq_len / seq_per_block; + block.x = seq_per_block * size_per_head / 2; assert(grid.x * seq_per_block == batch_size * head_num * seq_len); @@ -184,27 +197,37 @@ void invokeTransposeMultiHeadToSingle(T* dst, } else { const int seq_per_block = 1; - grid.x = batch_size * head_num * seq_len / seq_per_block; - block.x = seq_per_block * size_per_head; + grid.x = batch_size * head_num * seq_len / seq_per_block; + block.x = seq_per_block * size_per_head; transposeMultiHeadToSingleKernel <<>>(src, dst, batch_size, seq_len, head_num, size_per_head); } } -template void invokeTransposeMultiHeadToSingle(float* dst, - float* src, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, +template void invokeTransposeMultiHeadToSingle(float* dst, + float* src, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, cudaStream_t stream); -template void 
invokeTransposeMultiHeadToSingle(half* dst, - half* src, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, +template void invokeTransposeMultiHeadToSingle(half* dst, + half* src, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeTransposeMultiHeadToSingle(__nv_bfloat16* dst, + __nv_bfloat16* src, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + cudaStream_t stream); +#endif + } // namespace fastertransformer \ No newline at end of file diff --git a/src/fastertransformer/kernels/add_bias_transpose_kernels.h b/src/fastertransformer/kernels/add_bias_transpose_kernels.h index d2cc10f12..f8ccf2302 100644 --- a/src/fastertransformer/kernels/add_bias_transpose_kernels.h +++ b/src/fastertransformer/kernels/add_bias_transpose_kernels.h @@ -17,22 +17,22 @@ namespace fastertransformer { template -void invokeAddBiasTransposeToMultiHead(const T* matrices, - const T* biases, - T* output, - const int batch_size, - const int head_num, - const int size_per_head, - const int seq_len, - const int matrices_num, +void invokeAddBiasTransposeToMultiHead(const T* matrices, + const T* biases, + T* output, + const int batch_size, + const int head_num, + const int size_per_head, + const int seq_len, + const int matrices_num, const cudaStream_t stream); template -void invokeTransposeMultiHeadToSingle(T* dst, - T* src, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, +void invokeTransposeMultiHeadToSingle(T* dst, + T* src, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, cudaStream_t stream); } // namespace fastertransformer \ No newline at end of file diff --git a/src/fastertransformer/kernels/add_residual_kernels.cu b/src/fastertransformer/kernels/add_residual_kernels.cu index 4cd9f0fbb..697ad066b 100644 --- a/src/fastertransformer/kernels/add_residual_kernels.cu +++ b/src/fastertransformer/kernels/add_residual_kernels.cu @@ -15,112 +15,165 @@ */ #include "src/fastertransformer/kernels/add_residual_kernels.h" +#include "src/fastertransformer/kernels/bfloat16_fallback_kenrels.cuh" namespace fastertransformer { -template -__global__ void addBiasResidual(T* output, const T* input, const T* bias, const int m, const int n) +template +__global__ void +addBiasResidual(T* output, const T* residual1, const T* residual2, const T* bias, const int m, const int n) { const int col_index = blockIdx.y * blockDim.x + threadIdx.x; if (col_index < n) { T bias_val = (bias == nullptr) ? 
(T)(0.0f) : bias[col_index]; - output[blockIdx.x * n + col_index] = - output[blockIdx.x * n + col_index] + input[blockIdx.x * n + col_index] + bias_val; + if (RESIDUAL_NUM == 1) { + output[blockIdx.x * n + col_index] = + output[blockIdx.x * n + col_index] + residual1[blockIdx.x * n + col_index] + bias_val; + } + else if (RESIDUAL_NUM == 2) { + output[blockIdx.x * n + col_index] = output[blockIdx.x * n + col_index] + + residual1[blockIdx.x * n + col_index] + + residual2[blockIdx.x * n + col_index] + bias_val; + } } } template<typename T> -void invokeAddBiasResidual(T* output, const T* input, const T* bias, const int m, const int n, cudaStream_t stream) +void invokeAddBiasResidual( + T* output, const T* residual1, const T* residual2, const T* bias, const int m, const int n, cudaStream_t stream) { - int blocks_per_row = ceil(float(n) / 1024); + int blocks_per_row = ceil(float(n) / 1024); dim3 grid(m, blocks_per_row); dim3 block(min(n, 1024)); - addBiasResidual<<<grid, block, 0, stream>>>(output, input, bias, m, n); + if (residual2 == nullptr) { + addBiasResidual<T, 1><<<grid, block, 0, stream>>>(output, residual1, residual2, bias, m, n); + } + else { + addBiasResidual<T, 2><<<grid, block, 0, stream>>>(output, residual1, residual2, bias, m, n); + } } template<typename T> -__global__ void addBiasAttentionFfnResidual(T* block_output, - const T* ffn_output, - const T* attn_output, - const T* block_input, - const T* bias, +__global__ void addBiasAttentionFfnResidual(T* block_output, + const T* ffn_output, + const T* attn_output, + const T* block_input, + const T* bias, const int m, - const int n) + const int n, + const int block_input_tp_split) { const int col_index = blockIdx.y * blockDim.x + threadIdx.x; if (col_index < n) { - block_output[blockIdx.x * n + col_index] = ffn_output[blockIdx.x * n + col_index] - + attn_output[blockIdx.x * n + col_index] - + block_input[blockIdx.x * n + col_index] + bias[col_index]; + block_output[blockIdx.x * n + col_index] = + ffn_output[blockIdx.x * n + col_index] + attn_output[blockIdx.x * n + col_index] + bias[col_index] + + ((block_input != nullptr) ? 
+ float2type((float)block_input[blockIdx.x * n + col_index] / (float)block_input_tp_split) : + static_cast(0.0f)); } } template -__global__ void addBiasAttentionFfnResidual( - T* block_output, const T* ffn_output, const T* attn_output, const T* bias, const int m, const int n) +__global__ void addBiasAttentionFfnResidual(T* block_output, + const T* ffn_output, + const T* attn_output, + const T* bias, + const int m, + const int n, + const int block_input_tp_split) { const int col_index = blockIdx.y * blockDim.x + threadIdx.x; if (col_index < n) { - block_output[blockIdx.x * n + col_index] += - ffn_output[blockIdx.x * n + col_index] + attn_output[blockIdx.x * n + col_index] + bias[col_index]; + const int global_index = blockIdx.x * n + col_index; + block_output[global_index] = add(float2type((float)block_output[global_index] / (float)block_input_tp_split), + ffn_output[global_index], + attn_output[global_index], + bias[col_index]); } } template -void invokeAddBiasAttentionFfnResidual(T* block_output, - const T* ffn_output, - const T* attn_output, - const T* block_input, - const T* bias, - const int m, - const int n, +void invokeAddBiasAttentionFfnResidual(T* block_output, + const T* ffn_output, + const T* attn_output, + const T* block_input, + const T* bias, + const int m, + const int n, + const int block_input_tp_split, cudaStream_t stream) { - int blocks_per_row = ceil(float(n) / 1024); + int blocks_per_row = ceil(float(n) / 1024); dim3 grid(m, blocks_per_row); dim3 block(min(n, 1024)); if (block_output == block_input) { - addBiasAttentionFfnResidual<<>>(block_output, ffn_output, attn_output, bias, m, n); + addBiasAttentionFfnResidual<<>>( + block_output, ffn_output, attn_output, bias, m, n, block_input_tp_split); } else { addBiasAttentionFfnResidual<<>>( - block_output, ffn_output, attn_output, block_input, bias, m, n); + block_output, ffn_output, attn_output, block_input, bias, m, n, block_input_tp_split); } } -template void invokeAddBiasResidual( - float* output, const float* input, const float* bias, const int m, const int n, cudaStream_t stream); +template void invokeAddBiasResidual(float* output, + const float* residual1, + const float* residual2, + const float* bias, + const int m, + const int n, + cudaStream_t stream); -template void -invokeAddBiasResidual(half* output, const half* input, const half* bias, const int m, const int n, cudaStream_t stream); +template void invokeAddBiasResidual(half* output, + const half* residual1, + const half* residual2, + const half* bias, + const int m, + const int n, + cudaStream_t stream); #ifdef ENABLE_BF16 -template void invokeAddBiasResidual(__nv_bfloat16* output, - const __nv_bfloat16* input, +template void invokeAddBiasResidual(__nv_bfloat16* output, + const __nv_bfloat16* residual1, + const __nv_bfloat16* residual2, const __nv_bfloat16* bias, - const int m, - const int n, - cudaStream_t stream); + const int m, + const int n, + cudaStream_t stream); #endif -template void invokeAddBiasAttentionFfnResidual(float* block_output, +template void invokeAddBiasAttentionFfnResidual(float* block_output, const float* ffn_output, const float* attn_output, const float* input, const float* bias, - const int m, - const int n, + const int m, + const int n, + const int block_input_tp_split, cudaStream_t stream); -template void invokeAddBiasAttentionFfnResidual(half* block_output, - const half* ffn_output, - const half* attn_output, - const half* input, - const half* bias, - const int m, - const int n, +template void invokeAddBiasAttentionFfnResidual(half* 
block_output, + const half* ffn_output, + const half* attn_output, + const half* input, + const half* bias, + const int m, + const int n, + const int block_input_tp_split, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeAddBiasAttentionFfnResidual(__nv_bfloat16* block_output, + const __nv_bfloat16* ffn_output, + const __nv_bfloat16* attn_output, + const __nv_bfloat16* input, + const __nv_bfloat16* bias, + const int m, + const int n, + const int block_input_tp_split, + cudaStream_t stream); +#endif + template __global__ void T5addResidual(T* output, const T* input, const int m, const int n) { @@ -137,7 +190,7 @@ __global__ void T5addResidual(T* output, const T* input, const int m, const int template void invokeT5AddResidual(T* output, const T* input, const int m, const int n, cudaStream_t stream) { - int blocks_per_row = ceil(float(n) / 1024); + int blocks_per_row = ceil(float(n) / 1024); dim3 grid(m, blocks_per_row); dim3 block(min(n, 1024)); T5addResidual<<>>(output, input, m, n); @@ -145,6 +198,10 @@ void invokeT5AddResidual(T* output, const T* input, const int m, const int n, cu template void invokeT5AddResidual(float* output, const float* input, const int m, const int n, cudaStream_t stream); template void invokeT5AddResidual(half* output, const half* input, const int m, const int n, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void +invokeT5AddResidual(__nv_bfloat16* output, const __nv_bfloat16* input, const int m, const int n, cudaStream_t stream); +#endif template void invokeT5AddBiasResidual(T* output, const T* input, const T* bias, const int m, const int n, cudaStream_t stream) @@ -162,6 +219,15 @@ template void invokeT5AddBiasResidual( float* output, const float* input, const float* bias, const int m, const int n, cudaStream_t stream); template void invokeT5AddBiasResidual( half* output, const half* input, const half* bias, const int m, const int n, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeT5AddBiasResidual(__nv_bfloat16* output, + const __nv_bfloat16* input, + const __nv_bfloat16* bias, + const int m, + const int n, + cudaStream_t stream); +#endif + /******************* invokeAddBiasResidualCol32 ***********************/ // input1/input2/out matrix with layout of cublasLt CUBLASLT_ORDER_COL32 (m*n) //(grid, block) must be (m, n/4) @@ -171,24 +237,24 @@ __global__ void add_bias_input_COL32_int8I_DataTypeO( T* output, const int8_t* input1, const T* input2, const T* bias, int m, int n, const float* input1_deQFactor_ptr) { const float input1_deQFactor = __ldg(input1_deQFactor_ptr); - int col_start = threadIdx.x << 2; + int col_start = threadIdx.x << 2; - float local_out[4]; - int outIdx = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)) >> 2; + float local_out[4]; + int outIdx = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)) >> 2; char4* input1TmpPtr = (char4*)input1; - char4 input1Tmp = __ldg(input1TmpPtr + outIdx); + char4 input1Tmp = __ldg(input1TmpPtr + outIdx); int col_start_tmp = col_start; local_out[0] = static_cast(input2[(outIdx << 2) + 0]) + static_cast(input1Tmp.x) * input1_deQFactor + static_cast(__ldg(bias + col_start_tmp)); col_start_tmp = col_start_tmp + 1; - local_out[1] = static_cast(input2[(outIdx << 2) + 1]) + static_cast(input1Tmp.y) * input1_deQFactor + local_out[1] = static_cast(input2[(outIdx << 2) + 1]) + static_cast(input1Tmp.y) * input1_deQFactor + static_cast(__ldg(bias + col_start_tmp)); col_start_tmp = col_start_tmp + 1; - local_out[2] = static_cast(input2[(outIdx << 
2) + 2]) + static_cast(input1Tmp.z) * input1_deQFactor + local_out[2] = static_cast(input2[(outIdx << 2) + 2]) + static_cast(input1Tmp.z) * input1_deQFactor + static_cast(__ldg(bias + col_start_tmp)); col_start_tmp = col_start_tmp + 1; - local_out[3] = static_cast(input2[(outIdx << 2) + 3]) + static_cast(input1Tmp.w) * input1_deQFactor + local_out[3] = static_cast(input2[(outIdx << 2) + 3]) + static_cast(input1Tmp.w) * input1_deQFactor + static_cast(__ldg(bias + col_start_tmp)); for (int i = 0; i < 4; i++) { @@ -197,25 +263,25 @@ __global__ void add_bias_input_COL32_int8I_DataTypeO( } template<> -__global__ void add_bias_input_COL32_int8I_DataTypeO(half4* output, +__global__ void add_bias_input_COL32_int8I_DataTypeO(half4* output, const int8_t* input1, - const half4* input2, - const half4* bias, - int m, - int n, - const float* input1_deQFactor_ptr) + const half4* input2, + const half4* bias, + int m, + int n, + const float* input1_deQFactor_ptr) { const float input1_deQFactor = __ldg(input1_deQFactor_ptr); - int col_start = (blockIdx.x << 5) + (threadIdx.x << 2); - int row_start = (blockIdx.y << 5) + (threadIdx.y); + int col_start = (blockIdx.x << 5) + (threadIdx.x << 2); + int row_start = (blockIdx.y << 5) + (threadIdx.y); if (col_start < n && row_start < m) { - half4 local_out; - int outIdx = ((col_start & 0xffffffe0) * m + (row_start << 5) + (col_start & 31)) >> 2; + half4 local_out; + int outIdx = ((col_start & 0xffffffe0) * m + (row_start << 5) + (col_start & 31)) >> 2; char4* input1TmpPtr = (char4*)input1; - char4 input1Tmp = input1TmpPtr[outIdx]; - half4 input2Tmp = input2[outIdx]; - half4 biasTmp = bias[col_start >> 2]; + char4 input1Tmp = input1TmpPtr[outIdx]; + half4 input2Tmp = input2[outIdx]; + half4 biasTmp = bias[col_start >> 2]; local_out.x = static_cast((float)input1Tmp.x * input1_deQFactor + (float)biasTmp.x + (float)input2Tmp.x); local_out.y = static_cast((float)input1Tmp.y * input1_deQFactor + (float)biasTmp.y + (float)input2Tmp.y); @@ -226,14 +292,14 @@ __global__ void add_bias_input_COL32_int8I_DataTypeO(half4* output, } template -void invokeAddBiasResidualCol32(T* output, +void invokeAddBiasResidualCol32(T* output, const int8_t* input1, - const T* input2, - const T* bias, - int m, - int n, - cudaStream_t stream, - const float* input1_deQFactor_ptr) + const T* input2, + const T* bias, + int m, + int n, + cudaStream_t stream, + const float* input1_deQFactor_ptr) { dim3 grid((n + 31) / 32, (m + 31) / 32); dim3 block(8, 32); @@ -248,63 +314,63 @@ void invokeAddBiasResidualCol32(T* output, } } -template void invokeAddBiasResidualCol32(float* output, +template void invokeAddBiasResidualCol32(float* output, const int8_t* input1, - const float* input2, - const float* bias, - int m, - int n, - cudaStream_t stream, - const float* input1_deQFactor_ptr); - -template void invokeAddBiasResidualCol32(half* output, + const float* input2, + const float* bias, + int m, + int n, + cudaStream_t stream, + const float* input1_deQFactor_ptr); + +template void invokeAddBiasResidualCol32(half* output, const int8_t* input1, - const half* input2, - const half* bias, - int m, - int n, - cudaStream_t stream, - const float* input1_deQFactor_ptr); + const half* input2, + const half* bias, + int m, + int n, + cudaStream_t stream, + const float* input1_deQFactor_ptr); /******************* invokeAddBiasResidualCol32 ***********************/ // input1/input2/out matrix with layout of cublasLt CUBLASLT_ORDER_COL32 (m*n) //(grid, block) must be (m, n/4) // using char4 template -__global__ void 
add_bias_input_COL32_int32I_DataTypeO(T* output, +__global__ void add_bias_input_COL32_int32I_DataTypeO(T* output, const int32_t* input1, - const T* input2, - const T* bias, - int m, - int n, - const float* weight_amax, - const float* input1_amax_ptr, - const int scale_is_vector) + const T* input2, + const T* bias, + int m, + int n, + const float* weight_amax, + const float* input1_amax_ptr, + const int scale_is_vector) { - int col_start = threadIdx.x << 2; + int col_start = threadIdx.x << 2; const float4* weight_scale_ptr = (const float4*)weight_amax; - const float4 weight_scale = __ldg(weight_scale_ptr + threadIdx.x * scale_is_vector); - const float input1_deQ = __ldg(input1_amax_ptr) / 127.0f; + const float4 weight_scale = __ldg(weight_scale_ptr + threadIdx.x * scale_is_vector); + const float input1_deQ = __ldg(input1_amax_ptr) / 127.0f; float local_out[4]; - int outIdx = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)) >> 2; + int outIdx = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)) >> 2; int4* input1TmpPtr = (int4*)input1; - int4 input1Tmp = input1TmpPtr[outIdx]; + int4 input1Tmp = input1TmpPtr[outIdx]; int col_start_tmp = col_start; - local_out[0] = static_cast(input2[(outIdx << 2) + 0]) + local_out[0] = static_cast(input2[(outIdx << 2) + 0]) + static_cast(input1Tmp.x) * input1_deQ * weight_scale.x / 127.0f + static_cast(__ldg(bias + col_start_tmp)); col_start_tmp = col_start_tmp + 1; - local_out[1] = static_cast(input2[(outIdx << 2) + 1]) + local_out[1] = static_cast(input2[(outIdx << 2) + 1]) + static_cast(input1Tmp.y) * input1_deQ * weight_scale.y / 127.0f + static_cast(__ldg(bias + col_start_tmp)); col_start_tmp = col_start_tmp + 1; - local_out[2] = static_cast(input2[(outIdx << 2) + 2]) + local_out[2] = static_cast(input2[(outIdx << 2) + 2]) + static_cast(input1Tmp.z) * input1_deQ * weight_scale.z / 127.0f + static_cast(__ldg(bias + col_start_tmp)); col_start_tmp = col_start_tmp + 1; - local_out[3] = static_cast(input2[(outIdx << 2) + 3]) + local_out[3] = static_cast(input2[(outIdx << 2) + 3]) + static_cast(input1Tmp.w) * input1_deQ * weight_scale.w / 127.0f + static_cast(__ldg(bias + col_start_tmp)); @@ -314,27 +380,31 @@ __global__ void add_bias_input_COL32_int32I_DataTypeO(T* output, } template<> -__global__ void add_bias_input_COL32_int32I_DataTypeO(half4* output, +__global__ void add_bias_input_COL32_int32I_DataTypeO(half4* output, const int32_t* input1, - const half4* input2, - const half4* bias, - int m, - int n, - const float* weight_amax, - const float* input1_amax_ptr, - const int scale_is_vector) + const half4* input2, + const half4* bias, + int m, + int n, + const float* weight_amax, + const float* input1_amax_ptr, + const int scale_is_vector) { - int col_start = threadIdx.x << 2; - const float4* weight_scale_ptr = (const float4*)weight_amax; - const float4 weight_scale = __ldg(weight_scale_ptr + threadIdx.x * scale_is_vector); + int col_start = threadIdx.x << 2; + const float4* weight_scale_ptr = (const float4*)weight_amax; + const float weight_scale_single = __ldg(weight_amax); + const float4 weight_scale = + scale_is_vector == 1 ? 
+ __ldg(weight_scale_ptr + threadIdx.x * scale_is_vector) : + make_float4(weight_scale_single, weight_scale_single, weight_scale_single, weight_scale_single); const float input1_deQ = __ldg(input1_amax_ptr) / 127.0f; float local_out[4]; - int outIdx = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)) >> 2; + int outIdx = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)) >> 2; int4* input1TmpPtr = (int4*)input1; - int4 input1Tmp = input1TmpPtr[outIdx]; - half4 input2Tmp = input2[outIdx]; - half4 biasTmp = bias[threadIdx.x]; + int4 input1Tmp = input1TmpPtr[outIdx]; + half4 input2Tmp = input2[outIdx]; + half4 biasTmp = bias[threadIdx.x]; local_out[0] = static_cast(input2Tmp.x) + static_cast(input1Tmp.x) * input1_deQ * weight_scale.x / 127.0f @@ -359,16 +429,16 @@ __global__ void add_bias_input_COL32_int32I_DataTypeO(half4* output, } template -void invokeAddBiasResidualCol32(T* output, +void invokeAddBiasResidualCol32(T* output, const int32_t* input1, - const T* input2, - const T* bias, - int m, - int n, - cudaStream_t stream, - const float* weight_amax, - const float* input1_amax_ptr, - const int scale_is_vector) + const T* input2, + const T* bias, + int m, + int n, + cudaStream_t stream, + const float* weight_amax, + const float* input1_amax_ptr, + const int scale_is_vector) { dim3 grid(m); dim3 block(n / 4); @@ -390,26 +460,26 @@ void invokeAddBiasResidualCol32(T* output, } } -template void invokeAddBiasResidualCol32(float* output, - const int* input1, +template void invokeAddBiasResidualCol32(float* output, + const int* input1, const float* input2, const float* bias, - int m, - int n, + int m, + int n, cudaStream_t stream, const float* weight_amax, const float* input1_amax_ptr, - const int scale_is_vector); - -template void invokeAddBiasResidualCol32(half* output, - const int* input1, - const half* input2, - const half* bias, - int m, - int n, + const int scale_is_vector); + +template void invokeAddBiasResidualCol32(half* output, + const int* input1, + const half* input2, + const half* bias, + int m, + int n, cudaStream_t stream, const float* weight_amax, const float* input1_amax_ptr, - const int scale_is_vector); + const int scale_is_vector); } // namespace fastertransformer diff --git a/src/fastertransformer/kernels/add_residual_kernels.h b/src/fastertransformer/kernels/add_residual_kernels.h index edd8179b1..669ec2f04 100644 --- a/src/fastertransformer/kernels/add_residual_kernels.h +++ b/src/fastertransformer/kernels/add_residual_kernels.h @@ -25,7 +25,14 @@ namespace fastertransformer { template -void invokeAddBiasResidual(T* output, const T* input, const T* bias, const int m, const int n, cudaStream_t stream); +void invokeAddBiasResidual( + T* output, const T* residual1, const T* residual2, const T* bias, const int m, const int n, cudaStream_t stream); + +template +void invokeAddBiasResidual(T* output, const T* residual1, const T* bias, const int m, const int n, cudaStream_t stream) +{ + invokeAddBiasResidual(output, residual1, (const T*)nullptr, bias, m, n, stream); +} template void invokeT5AddResidual(T* output, const T* input, const int m, const int n, cudaStream_t stream); @@ -34,35 +41,49 @@ template void invokeT5AddBiasResidual(T* output, const T* input, const T* bias, const int m, const int n, cudaStream_t stream); template -void invokeAddBiasAttentionFfnResidual(T* block_output, - const T* ffn_output, - const T* attn_output, - const T* block_input, - const T* bias, - const int m, - const int n, +void invokeAddBiasAttentionFfnResidual(T* 
block_output, + const T* ffn_output, + const T* attn_output, + const T* block_input, + const T* bias, + const int m, + const int n, + const int block_input_tp_split, cudaStream_t stream); template -void invokeAddBiasResidualCol32(T* output, +void invokeAddBiasAttentionFfnResidual(T* block_output, + const T* ffn_output, + const T* attn_output, + const T* block_input, + const T* bias, + const int m, + const int n, + cudaStream_t stream) +{ + invokeAddBiasAttentionFfnResidual(block_output, ffn_output, attn_output, block_input, bias, m, n, 1, stream); +} + +template +void invokeAddBiasResidualCol32(T* output, const int8_t* input1, - const T* input2, - const T* bias, - int m, - int n, - cudaStream_t stream, - const float* input1_deQFactor_ptr); + const T* input2, + const T* bias, + int m, + int n, + cudaStream_t stream, + const float* input1_deQFactor_ptr); template -void invokeAddBiasResidualCol32(T* output, +void invokeAddBiasResidualCol32(T* output, const int32_t* input1, - const T* input2, - const T* bias, - int m, - int n, - cudaStream_t stream, - const float* weight_amax, - const float* input1_amax_ptr, - const int scale_is_vector = 0); + const T* input2, + const T* bias, + int m, + int n, + cudaStream_t stream, + const float* weight_amax, + const float* input1_amax_ptr, + const int scale_is_vector = 0); } // namespace fastertransformer diff --git a/src/fastertransformer/kernels/ban_bad_words.cu b/src/fastertransformer/kernels/ban_bad_words.cu index b52031081..426111080 100644 --- a/src/fastertransformer/kernels/ban_bad_words.cu +++ b/src/fastertransformer/kernels/ban_bad_words.cu @@ -20,40 +20,40 @@ namespace fastertransformer { template -__global__ void ban_bad_words(T* logits, +__global__ void ban_bad_words(T* logits, const int* output_ids_buf, const int* parent_ids_buf, - int batch_size, - int beam_width, + int batch_size, + int beam_width, const int* bad_words, - size_t bad_words_len, - bool share_words, - int id_offset, - int vocab_size_padded, - size_t step) + size_t bad_words_len, + bool share_words, + int id_offset, + int vocab_size_padded, + size_t step) { - const int id = blockIdx.x * blockDim.x + threadIdx.x; + const int id = blockIdx.x * blockDim.x + threadIdx.x; const int batch_idx = blockIdx.y / beam_width; - const int beam_idx = blockIdx.y % beam_width; + const int beam_idx = blockIdx.y % beam_width; - const int* base_bad_words = share_words ? bad_words : bad_words + batch_idx * 2 * bad_words_len; + const int* base_bad_words = share_words ? bad_words : bad_words + batch_idx * 2 * bad_words_len; const int* base_bad_words_offsets = base_bad_words + bad_words_len; if (id >= bad_words_len || base_bad_words_offsets[id] < 0) { return; } - const int item_end = base_bad_words_offsets[id]; + const int item_end = base_bad_words_offsets[id]; const int item_start = (id > 0) ? 
base_bad_words_offsets[id - 1] : 0; - const int item_size = item_end - item_start; + const int item_size = item_end - item_start; /* The single-token case unconditionally bans the token */ bool should_ban = item_size == 1; /* Multi-token case and enough previously generated tokens to look for a match */ if (item_size > 1 && step >= item_size - 1) { - should_ban = true; - int parent_id = beam_idx; + should_ban = true; + int parent_id = beam_idx; const bool gather_beam = beam_width > 1; for (int token_idx = item_size - 2; token_idx >= 0; token_idx--) { @@ -86,24 +86,24 @@ __global__ void ban_bad_words(T* logits, } template -void invokeBanBadWords(T* logits, - const int* output_ids_buf, - const int* parent_ids_buf, - int batch_size, - int local_batch_size, - int beam_width, - const int* bad_words, - bool share_words, - size_t bad_words_len, - int id_offset, - int vocab_size_padded, - size_t step, +void invokeBanBadWords(T* logits, + const int* output_ids_buf, + const int* parent_ids_buf, + int batch_size, + int local_batch_size, + int beam_width, + const int* bad_words, + bool share_words, + size_t bad_words_len, + int id_offset, + int vocab_size_padded, + size_t step, cudaStream_t stream) { dim3 block, grid; block.x = min(((bad_words_len + 32 - 1) / 32) * 32, 256UL); - grid.x = (bad_words_len + block.x - 1) / block.x; - grid.y = local_batch_size * beam_width; + grid.x = (bad_words_len + block.x - 1) / block.x; + grid.y = local_batch_size * beam_width; ban_bad_words<<>>(logits, output_ids_buf, @@ -119,31 +119,46 @@ void invokeBanBadWords(T* logits, sync_check_cuda_error(); } -template void invokeBanBadWords(half* logits, - const int* output_ids_buf, - const int* parent_ids_buf, - int batch_size, - int local_batch_size, - int beam_width, - const int* bad_words, - bool share_words, - size_t bad_words_len, - int id_offset, - int vocab_size_padded, - size_t step, +template void invokeBanBadWords(half* logits, + const int* output_ids_buf, + const int* parent_ids_buf, + int batch_size, + int local_batch_size, + int beam_width, + const int* bad_words, + bool share_words, + size_t bad_words_len, + int id_offset, + int vocab_size_padded, + size_t step, cudaStream_t stream); -template void invokeBanBadWords(float* logits, - const int* output_ids_buf, - const int* parent_ids_buf, - int batch_size, - int local_batch_size, - int beam_width, - const int* bad_words, - bool share_words, - size_t bad_words_len, - int id_offset, - int vocab_size_padded, - size_t step, +#ifdef ENABLE_BF16 +template void invokeBanBadWords(__nv_bfloat16* logits, + const int* output_ids_buf, + const int* parent_ids_buf, + int batch_size, + int local_batch_size, + int beam_width, + const int* bad_words, + bool share_words, + size_t bad_words_len, + int id_offset, + int vocab_size_padded, + size_t step, + cudaStream_t stream); +#endif +template void invokeBanBadWords(float* logits, + const int* output_ids_buf, + const int* parent_ids_buf, + int batch_size, + int local_batch_size, + int beam_width, + const int* bad_words, + bool share_words, + size_t bad_words_len, + int id_offset, + int vocab_size_padded, + size_t step, cudaStream_t stream); } // namespace fastertransformer diff --git a/src/fastertransformer/kernels/ban_bad_words.h b/src/fastertransformer/kernels/ban_bad_words.h index 1cafe793c..c3159b8ce 100644 --- a/src/fastertransformer/kernels/ban_bad_words.h +++ b/src/fastertransformer/kernels/ban_bad_words.h @@ -22,18 +22,18 @@ namespace fastertransformer { template -void invokeBanBadWords(T* logits, - const int* 
output_ids_buf, - const int* parent_ids_buf, - int batch_size, - int local_batch_size, - int beam_width, - const int* bad_words, - bool share_words, - size_t bad_words_len, - int id_offset, - int vocab_size_padded, - size_t step, +void invokeBanBadWords(T* logits, + const int* output_ids_buf, + const int* parent_ids_buf, + int batch_size, + int local_batch_size, + int beam_width, + const int* bad_words, + bool share_words, + size_t bad_words_len, + int id_offset, + int vocab_size_padded, + size_t step, cudaStream_t stream); } // namespace fastertransformer diff --git a/src/fastertransformer/kernels/beam_search_penalty_kernels.cu b/src/fastertransformer/kernels/beam_search_penalty_kernels.cu index 8cd5c7a17..904244f77 100644 --- a/src/fastertransformer/kernels/beam_search_penalty_kernels.cu +++ b/src/fastertransformer/kernels/beam_search_penalty_kernels.cu @@ -14,167 +14,231 @@ * limitations under the License. */ +#include + #include "src/fastertransformer/kernels/beam_search_penalty_kernels.h" #include "src/fastertransformer/kernels/reduce_kernel_utils.cuh" namespace fastertransformer { template -__global__ void add_bias_apply_logit_penalties_kernel(int step, - int vocab_size, - const int vocab_size_padded, - int beam_width, - T* logits, - const int* current_ids, - const int* previous_ids, - const int* parent_ids, - const int* input_lengths, - const T* bias, - const int ite, - const int max_input_length, - const int batch_size, - const int* end_ids, - float inv_temp, - float len_penalty, - float repeat_penalty) +__global__ void add_bias_temperature(T* logits, + const T* bias, + const int batch_size, + const int beam_width, + const int vocab_size, + const int vocab_size_padded, + const float temperature) { - // TODO(bhsueh) Seems there are some problem for len_penalty implementation - int tid = threadIdx.x; - int bid = blockIdx.x; + int tid = threadIdx.x; + int bid = blockIdx.x; int bbid = blockIdx.y; - int bbsize = batch_size * beam_width; - int batch_id = blockIdx.y / beam_width; - - const int vocab_size_padded_offset = bbid * vocab_size_padded; - const bool IS_FP16 = std::is_same::value; - const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; + logits += bbid * vocab_size_padded; + const T MASK_VAL = (std::is_same::value) ? -HALF_FLT_MAX : -FLT_MAX; + const T inv_temp = static_cast(1.0f / temperature); for (int i = tid + bid * blockDim.x; i < vocab_size_padded; i += blockDim.x * gridDim.x) { if (i < vocab_size) { T bias_val = bias == nullptr ? (T)(0.0f) : bias[i]; - logits[i + vocab_size_padded_offset] = (logits[i + vocab_size_padded_offset] + bias_val) * (T)(inv_temp); + logits[i] = (logits[i] + bias_val) * inv_temp; } else { - logits[i + vocab_size_padded_offset] = -MAX_T_VAL; + logits[i] = MASK_VAL; } } - if (tid == 0 && bid == 0) { - // TODO(bhsueh) apply repetition penalty (this can apply the penalty multiple times to a repeated word). 
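For reference, the repetition-penalty rule that both the removed kernel above and the rewritten apply_repetition_penalty below implement is the CTRL-style one: a token that already appeared in the generated sequence has its logit divided by the penalty when positive and multiplied by it when negative. A minimal host-side sketch of that rule, assuming a hypothetical helper name and plain std::vector inputs (illustration only, not part of this patch):

#include <unordered_set>
#include <vector>

// Hypothetical CPU reference of the repetition penalty applied by these kernels.
void apply_repetition_penalty_ref(std::vector<float>&     logits,
                                  const std::vector<int>& previous_token_ids,
                                  float                   repetition_penalty)
{
    // Penalize each distinct previous token once; the rewritten kernel gets the same
    // effect by computing every penalized value from the original logit and writing
    // the results back only after the backward scan over the generation steps.
    std::unordered_set<int> seen(previous_token_ids.begin(), previous_token_ids.end());
    for (int token_id : seen) {
        const float logit = logits[token_id];
        logits[token_id]  = logit > 0.0f ? logit / repetition_penalty : logit * repetition_penalty;
    }
}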
- int prev_id = current_ids[bbid]; - const int end_id = end_ids[batch_id]; - if (logits[prev_id + vocab_size_padded_offset] > T(0)) { - logits[prev_id + vocab_size_padded_offset] = - float(logits[prev_id + vocab_size_padded_offset]) / repeat_penalty; - logits[end_id + vocab_size_padded_offset] = float(logits[end_id + vocab_size_padded_offset]) / len_penalty; - } - else { - logits[prev_id + vocab_size_padded_offset] = - float(logits[prev_id + vocab_size_padded_offset]) * repeat_penalty; - logits[end_id + vocab_size_padded_offset] = float(logits[end_id + vocab_size_padded_offset]) * len_penalty; +} + +template<> +__global__ void add_bias_temperature(half2* logits, + const half2* bias, + const int batch_size, + const int beam_width, + const int vocab_size, + const int vocab_size_padded, + const float temperature) +{ + assert(vocab_size % 2 == 0); + assert(vocab_size_padded % 2 == 0); + + const int tid = threadIdx.x; + const int bid = blockIdx.x; + const int bbid = blockIdx.y; + + const half2 mask_val = __float2half2_rn(-HALF_FLT_MAX); + const half2 inv_temp = __float2half2_rn(1.0f / temperature); + + const int half_vocab_size = vocab_size / 2; + const int half_vocab_size_padded = vocab_size_padded / 2; + + logits += bbid * half_vocab_size_padded; + for (int index = tid + bid * blockDim.x; index < half_vocab_size_padded; index += blockDim.x * gridDim.x) { + int vocab_idx = index % half_vocab_size_padded; + half2 logit = vocab_idx < half_vocab_size ? __ldg(&logits[index]) : mask_val; + if (vocab_idx < half_vocab_size) { + if (bias != nullptr) { + logit = __hadd2(logit, bias[vocab_idx]); + } + logit = __hmul2(logit, inv_temp); } + logits[index] = logit; + } +} + +template +__global__ void apply_repetition_penalty(T* logits, + const int batch_size, + const int beam_width, + const int vocab_size, + const int vocab_size_padded, + const int step, + const int* current_ids, + const int* previous_ids, + const int* parent_ids, + const int* input_lengths, + const int max_input_length, + const float repetition_penalty) +{ + assert(step > 0); + + const int tid = threadIdx.x; + const int bbid = blockIdx.x; + const int batch_id = bbid / beam_width; + const int bbsize = batch_size * beam_width; + + logits += bbid * vocab_size_padded; + extern __shared__ char sbuf[]; + T* penalty_logits = reinterpret_cast(sbuf); + int* penalty_indices = reinterpret_cast(penalty_logits + step); + const int input_length = (input_lengths != nullptr) ? input_lengths[bbid] : max_input_length; + if (tid == 0) { + T repet_penalty = static_cast(repetition_penalty); + int prev_id = current_ids[bbid]; + T prev_logit = logits[prev_id]; + penalty_indices[step - 1] = prev_id; + penalty_logits[step - 1] = prev_logit > T(0) ? 
prev_logit / repet_penalty : prev_logit * repet_penalty; if (step > 1) { - int parent_beamid = parent_ids[bbsize * (step - 2) + ite * gridDim.y + blockIdx.y]; - for (int i = step - 2; i > 0; --i) { - bool is_mask = input_lengths != nullptr && i >= input_lengths[bbid] && step < max_input_length; - if (is_mask == false) { - prev_id = previous_ids[bbsize * i + ite * gridDim.y + batch_id * beam_width + parent_beamid]; - if (logits[prev_id + vocab_size_padded_offset] > T(0)) { - logits[prev_id + vocab_size_padded_offset] = - float(logits[prev_id + vocab_size_padded_offset]) / repeat_penalty; - } - else { - logits[prev_id + vocab_size_padded_offset] = - float(logits[prev_id + vocab_size_padded_offset]) * repeat_penalty; - } + int parent_beam = bbid % beam_width; + for (int i = step - 2; i >= 0; --i) { + // Skip the padded tokens. + if (i >= input_length && i < max_input_length) { + continue; } - parent_beamid = parent_ids[bbsize * (i - 1) + ite * gridDim.y + batch_id * beam_width + parent_beamid]; + parent_beam = parent_ids[i * bbsize + batch_id * beam_width + parent_beam]; + prev_id = previous_ids[i * bbsize + batch_id * beam_width + parent_beam]; + prev_logit = logits[prev_id]; + penalty_indices[i] = prev_id; + penalty_logits[i] = prev_logit > T(0) ? prev_logit / repet_penalty : prev_logit * repet_penalty; } } } + __syncthreads(); + for (int i = tid; i < step; i += blockDim.x) { + if (i >= input_length && i < max_input_length) { + continue; + } + logits[penalty_indices[i]] = penalty_logits[i]; + } } template -void invokeAddBiasApplyPenalties(int step, - T* logits, - const int* current_ids, - const int* previous_ids, - const int* parent_ids, - const int* input_lengths, - const T* bias, - const int ite, - const int max_input_length, - const int local_batch_size, - const int batch_size, - const int beam_width, - const int vocab_size, - const int vocab_size_padded, - const int* end_ids, - const float temperature, - const float len_penalty, - const float repeat_penalty, +void invokeAddBiasApplyPenalties(int step, + T* logits, + const int* current_ids, + const int* previous_ids, + const int* parent_ids, + const int* input_lengths, + const T* bias, + const int ite, + const int max_input_length, + const int local_batch_size, + const int batch_size, + const int beam_width, + const int vocab_size, + const int vocab_size_padded, + const int* end_ids, + const float temperature, + const float repetition_penalty, cudaStream_t stream) { - dim3 block(256); - dim3 grid((vocab_size_padded + block.x - 1) / block.x, beam_width * local_batch_size); - add_bias_apply_logit_penalties_kernel<<>>(step, - vocab_size, - vocab_size_padded, - beam_width, - logits, - current_ids, - previous_ids, - parent_ids, - input_lengths, - bias, - ite, - max_input_length, - batch_size, - end_ids, - 1.f / temperature, - len_penalty, - repeat_penalty); - sync_check_cuda_error(); + if (bias != nullptr || temperature != 1.0f) { + dim3 block(512); + if (std::is_same::value && vocab_size % 2 == 0 && vocab_size_padded % 2 == 0) { + dim3 grid((vocab_size_padded / 2 + block.x - 1) / block.x, beam_width * local_batch_size); + add_bias_temperature<<>>(reinterpret_cast(logits), + reinterpret_cast(bias), + batch_size, + beam_width, + vocab_size, + vocab_size_padded, + temperature); + } + else { + dim3 grid((vocab_size_padded + block.x - 1) / block.x, beam_width * local_batch_size); + add_bias_temperature<<>>( + logits, bias, batch_size, beam_width, vocab_size, vocab_size_padded, temperature); + } + } + + if (repetition_penalty != 1.0f) { + size_t 
smem_size = (sizeof(T) + sizeof(int)) * step; + dim3 block(256); + dim3 grid(beam_width * local_batch_size); + apply_repetition_penalty<<>>( + logits, + batch_size, + beam_width, + vocab_size, + vocab_size_padded, + step, + current_ids, + previous_ids, + // TODO(jaedeokk): + // Remove (+ite ...) by getting parent_ids with offset + // and then remove 'ite' argument from the function. + parent_ids + ite * beam_width * local_batch_size, + input_lengths, + max_input_length, + repetition_penalty); + } } -template void invokeAddBiasApplyPenalties(int step, - float* logits, - const int* current_ids, - const int* previous_ids, - const int* parent_ids, - const int* input_lengths, +template void invokeAddBiasApplyPenalties(int step, + float* logits, + const int* current_ids, + const int* previous_ids, + const int* parent_ids, + const int* input_lengths, const float* bias, - const int ite, - const int max_input_length, - const int local_batch_size, - const int batch_size, - const int beam_width, - const int vocab_size, - const int vocab_size_padded, - const int* end_ids, - const float temerature, - const float len_penalty, - const float repeat_penalty, + const int ite, + const int max_input_length, + const int local_batch_size, + const int batch_size, + const int beam_width, + const int vocab_size, + const int vocab_size_padded, + const int* end_ids, + const float temperature, + const float repetition_penalty, cudaStream_t stream); -template void invokeAddBiasApplyPenalties(int step, - half* logits, - const int* current_ids, - const int* previous_ids, - const int* parent_ids, - const int* input_lengths, - const half* bias, - const int ite, - const int max_input_length, - const int local_batch_size, - const int batch_size, - const int beam_width, - const int vocab_size, - const int vocab_size_padded, - const int* end_ids, - const float temerature, - const float len_penalty, - const float repeat_penalty, +template void invokeAddBiasApplyPenalties(int step, + half* logits, + const int* current_ids, + const int* previous_ids, + const int* parent_ids, + const int* input_lengths, + const half* bias, + const int ite, + const int max_input_length, + const int local_batch_size, + const int batch_size, + const int beam_width, + const int vocab_size, + const int vocab_size_padded, + const int* end_ids, + const float temperature, + const float repetition_penalty, cudaStream_t stream); -} // namespace fastertransformer \ No newline at end of file +} // namespace fastertransformer diff --git a/src/fastertransformer/kernels/beam_search_penalty_kernels.h b/src/fastertransformer/kernels/beam_search_penalty_kernels.h index 462a90b03..27ee5e449 100644 --- a/src/fastertransformer/kernels/beam_search_penalty_kernels.h +++ b/src/fastertransformer/kernels/beam_search_penalty_kernels.h @@ -22,24 +22,23 @@ namespace fastertransformer { template -void invokeAddBiasApplyPenalties(int step, - T* logits, - const int* current_ids, - const int* previous_ids, - const int* parent_ids, - const int* input_lengths, - const T* bias, - const int ite, - const int max_input_length, - const int local_batch_size, - const int batch_size, - const int beam_width, - const int vocab_size, - const int vocab_size_padded, - const int* end_ids, - const float temerature, - const float len_penalty, - const float repeat_penalty, +void invokeAddBiasApplyPenalties(int step, + T* logits, + const int* current_ids, + const int* previous_ids, + const int* parent_ids, + const int* input_lengths, + const T* bias, + const int ite, + const int max_input_length, + 
const int local_batch_size, + const int batch_size, + const int beam_width, + const int vocab_size, + const int vocab_size_padded, + const int* end_ids, + const float temperature, + const float repetition_penalty, cudaStream_t stream); } // namespace fastertransformer diff --git a/src/fastertransformer/kernels/beam_search_topk_kernels.cu b/src/fastertransformer/kernels/beam_search_topk_kernels.cu index fe64a2d6a..3cef297e4 100644 --- a/src/fastertransformer/kernels/beam_search_topk_kernels.cu +++ b/src/fastertransformer/kernels/beam_search_topk_kernels.cu @@ -23,22 +23,37 @@ #endif #include "src/fastertransformer/kernels/beam_search_topk_kernels.h" +#include "src/fastertransformer/kernels/bfloat16_fallback_kenrels.cuh" #include "src/fastertransformer/kernels/reduce_kernel_utils.cuh" namespace fastertransformer { + +template +__device__ __forceinline__ T apply_length_penalty(T log_prob, int length, float length_penalty) +{ + // score = log(prob) / (length)^length_penalty. + return log_prob / static_cast(powf(length, length_penalty)); +} + template -__launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_topK_kernel( - const T* log_probs, int* topk_tmp_id_buf, T* topk_tmp_val_buf, const int vocab_size, T diversity_rate) +__launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_topK_kernel(const T* log_probs, + int* topk_tmp_id_buf, + T* topk_tmp_val_buf, + const bool* finished, + const int* sequence_lengths, + const int vocab_size, + T diversity_rate, + float length_penalty) { typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ typename BlockReduce::TempStorage temp_storage; - int thread_id = threadIdx.x; - int block_id = blockIdx.x; + int thread_id = threadIdx.x; + int block_id = blockIdx.x; // batch beam index. TopK partial; - const bool IS_FP16 = std::is_same::value; - const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; + const bool IS_FP16 = std::is_same::value; + const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; #pragma unroll for (int i = 0; i < MAX_K; ++i) { @@ -49,7 +64,12 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_topK_kernel( #pragma unroll for (int elem_id = thread_id; elem_id < vocab_size; elem_id += THREADBLOCK_SIZE) { int index = elem_id + block_id * vocab_size; - partial.insert(log_probs[index], index); + T score = length_penalty == 0.0f ? log_probs[index] : + apply_length_penalty(log_probs[index], + finished[block_id] ? sequence_lengths[block_id] : + sequence_lengths[block_id] + 1, + length_penalty); + partial.insert(score, index); } TopK total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op); @@ -59,7 +79,7 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_topK_kernel( #pragma unroll for (int i = 0; i < MAX_K; ++i) { - topk_tmp_id_buf[index + i] = total.p[i]; + topk_tmp_id_buf[index + i] = total.p[i]; topk_tmp_val_buf[index + i] = total.u[i] + diversity_rate * (T)i; } } @@ -69,10 +89,10 @@ template __launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topK_kernel(int* topk_tmp_id_buf, T* topk_tmp_val_buf, int* id_buf) { - int thread_id = threadIdx.x; - int block_id = blockIdx.x; - const bool IS_FP16 = std::is_same::value; - const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; + int thread_id = threadIdx.x; + int block_id = blockIdx.x; + const bool IS_FP16 = std::is_same::value; + const T MAX_T_VAL = (IS_FP16) ? 
HALF_FLT_MAX : FLT_MAX; TopK partial; if (thread_id == 0) { for (int i = 0; i < MAX_K; ++i) { @@ -97,13 +117,13 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topK_kernel_v2(int* topk_tmp_id_buf, T* topk_tmp_val_buf, int* id_buf) { typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ typename BlockReduce::TempStorage temp_storage; - int tid = threadIdx.x; - int bid = blockIdx.x; + int tid = threadIdx.x; + int bid = blockIdx.x; TopK partial; - const bool IS_FP16 = std::is_same::value; - const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; + const bool IS_FP16 = std::is_same::value; + const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; #pragma unroll for (int i = 0; i < MAX_K; ++i) { @@ -130,38 +150,44 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ template __global__ void topk_stage_1_opt3(const T* __restrict log_probs, - T* tmp_log_probs, - int* topk_tmp_id_buf, - T* topk_tmp_val_buf, + T* tmp_log_probs, + int* topk_tmp_id_buf, + T* topk_tmp_val_buf, const bool* finished, - const int k, - const int vocab_size, - const int* end_ids) + const int* sequence_lengths, + const int k, + const int vocab_size, + const float length_penalty, + const int* end_ids) { typedef cub::BlockReduce, BLOCK_SIZE_> BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ typename BlockReduce::TempStorage temp_storage; const int tid = threadIdx.x; const int bid = blockIdx.x; - const int row_id = bid / BLOCKS_PER_BEAM_; // row id for log_probs - const int block_lane = bid % BLOCKS_PER_BEAM_; // block id for a beam - const int tmp_log_buf_index = row_id * vocab_size; - const int tmp_topk_buf_index = row_id * BLOCKS_PER_BEAM_ * k + block_lane * k; - TopK_2 partial; - const bool IS_FP16 = std::is_same::value; - const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; + const int row_id = bid / BLOCKS_PER_BEAM_; // row id for log_probs (batchbeam index) + const int block_lane = bid % BLOCKS_PER_BEAM_; // block id for a beam + const int tmp_log_buf_index = row_id * vocab_size; + const int tmp_topk_buf_index = row_id * BLOCKS_PER_BEAM_ * k + block_lane * k; + TopK_2 partial; + const bool IS_FP16 = std::is_same::value; + const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; if (finished != nullptr && finished[row_id] == true) { if (tid < k) { const int index = tmp_topk_buf_index + tid; if (block_lane == 0 && tid == 0) { - const int end_id = end_ids[row_id / k]; - topk_tmp_id_buf[index] = tmp_log_buf_index + end_id; - topk_tmp_val_buf[index] = log_probs[tmp_log_buf_index + end_id]; + const int end_id = end_ids[row_id / k]; + topk_tmp_id_buf[index] = tmp_log_buf_index + end_id; + topk_tmp_val_buf[index] = length_penalty == 0.0f ? + log_probs[tmp_log_buf_index + end_id] : + apply_length_penalty(log_probs[tmp_log_buf_index + end_id], + sequence_lengths[row_id], + length_penalty); } else { - topk_tmp_id_buf[index] = -1; + topk_tmp_id_buf[index] = -1; topk_tmp_val_buf[index] = -MAX_T_VAL; } } @@ -170,8 +196,10 @@ __global__ void topk_stage_1_opt3(const T* __restrict log_probs, for (int elem_id = tid + block_lane * BLOCK_SIZE_; elem_id < vocab_size; elem_id += BLOCK_SIZE_ * BLOCKS_PER_BEAM_) { - int index = elem_id + tmp_log_buf_index; - tmp_log_probs[index] = log_probs[index]; + int index = elem_id + tmp_log_buf_index; + tmp_log_probs[index] = length_penalty == 0.0f ? 
+ log_probs[index] : + apply_length_penalty(log_probs[index], sequence_lengths[row_id] + 1, length_penalty); } for (int ite = 0; ite < k; ite++) { @@ -186,10 +214,10 @@ __global__ void topk_stage_1_opt3(const T* __restrict log_probs, TopK_2 total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2); if (tid == 0) { - const int index = tmp_topk_buf_index + ite; - topk_tmp_id_buf[index] = total.p; + const int index = tmp_topk_buf_index + ite; + topk_tmp_id_buf[index] = total.p; topk_tmp_val_buf[index] = total.u; - tmp_log_probs[total.p] = -MAX_T_VAL; + tmp_log_probs[total.p] = -MAX_T_VAL; } __syncthreads(); } @@ -198,17 +226,17 @@ __global__ void topk_stage_1_opt3(const T* __restrict log_probs, template __global__ void topk_stage_2_opt3(const int* __restrict topk_tmp_id_buf, T* topk_tmp_val_buf, int* ids, const int k) { - const int size = k * k * BLOCKS_PER_BEAM_; - const int tid = threadIdx.x; - const int batch_id = blockIdx.x; - const bool IS_FP16 = std::is_same::value; - const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; + const int size = k * k * BLOCKS_PER_BEAM_; + const int tid = threadIdx.x; + const int batch_id = blockIdx.x; + const bool IS_FP16 = std::is_same::value; + const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; typedef cub::BlockReduce, BLOCK_SIZE_> BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - extern __shared__ char array[]; - T* s_val = topk_tmp_val_buf + batch_id * size; - int* s_id = (int*)(array); + __shared__ typename BlockReduce::TempStorage temp_storage; + extern __shared__ char array[]; + T* s_val = topk_tmp_val_buf + batch_id * size; + int* s_id = (int*)(array); TopK_2 partial; @@ -222,7 +250,7 @@ __global__ void topk_stage_2_opt3(const int* __restrict topk_tmp_id_buf, T* topk TopK_2 total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2); if (tid == 0) { - s_id[ite] = total.p; + s_id[ite] = total.p; s_val[total.p] = -MAX_T_VAL; } __syncthreads(); @@ -234,28 +262,35 @@ __global__ void topk_stage_2_opt3(const int* __restrict topk_tmp_id_buf, T* topk template __global__ void topk_stage_1_opt2_general(const T* __restrict log_probs, - T* tmp_log_probs, - int* topk_tmp_id_buf, - T* topk_tmp_val_buf, - const int k, - const int vocab_size) + T* tmp_log_probs, + int* topk_tmp_id_buf, + T* topk_tmp_val_buf, + const bool* finished, + const int* sequence_lengths, + const int k, + const int vocab_size, + const float length_penalty) { - const bool IS_FP16 = std::is_same::value; - const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; + const bool IS_FP16 = std::is_same::value; + const T MAX_T_VAL = (IS_FP16) ? 
HALF_FLT_MAX : FLT_MAX; typedef cub::BlockReduce, BLOCK_SIZE> BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ typename BlockReduce::TempStorage temp_storage; - const int tid = threadIdx.x; - const int bid = blockIdx.x; - const int row_id = bid / BLOCKS_PER_BEAM; // row id for log_probs - const int block_lane = bid % BLOCKS_PER_BEAM; // block id for a beam - const int tmp_log_buf_index = row_id * vocab_size; + const int tid = threadIdx.x; + const int bid = blockIdx.x; + const int row_id = bid / BLOCKS_PER_BEAM; // row id for log_probs + const int block_lane = bid % BLOCKS_PER_BEAM; // block id for a beam + const int tmp_log_buf_index = row_id * vocab_size; const int tmp_topk_buf_index = row_id * BLOCKS_PER_BEAM * k + block_lane * k; TopK_2 partial; for (int elem_id = tid + block_lane * BLOCK_SIZE; elem_id < vocab_size; elem_id += BLOCK_SIZE * BLOCKS_PER_BEAM) { - int index = elem_id + tmp_log_buf_index; - tmp_log_probs[index] = log_probs[index]; + int index = elem_id + tmp_log_buf_index; + tmp_log_probs[index] = length_penalty == 0.0f ? log_probs[index] : + apply_length_penalty(log_probs[index], + finished[bid] ? sequence_lengths[bid] : + sequence_lengths[bid] + 1, + length_penalty); } for (int ite = 0; ite < k; ite++) { @@ -270,10 +305,10 @@ __global__ void topk_stage_1_opt2_general(const T* __restrict log_probs, TopK_2 total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2); if (tid == 0) { - const int index = tmp_topk_buf_index + ite; - topk_tmp_id_buf[index] = total.p; + const int index = tmp_topk_buf_index + ite; + topk_tmp_id_buf[index] = total.p; topk_tmp_val_buf[index] = total.u; - tmp_log_probs[total.p] = -MAX_T_VAL; + tmp_log_probs[total.p] = -MAX_T_VAL; } __syncthreads(); } @@ -283,17 +318,17 @@ template __global__ void topk_stage_2_opt2_general(const int* __restrict topk_tmp_id_buf, T* topk_tmp_val_buf, int* ids, const int k) { - const int size = k * k * BLOCKS_PER_BEAM; - const int tid = threadIdx.x; - const int batch_id = blockIdx.x; - const bool IS_FP16 = std::is_same::value; - const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; + const int size = k * k * BLOCKS_PER_BEAM; + const int tid = threadIdx.x; + const int batch_id = blockIdx.x; + const bool IS_FP16 = std::is_same::value; + const T MAX_T_VAL = (IS_FP16) ? 
HALF_FLT_MAX : FLT_MAX; typedef cub::BlockReduce, BLOCK_SIZE> BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - extern __shared__ char array[]; - T* s_val = topk_tmp_val_buf + batch_id * size; - int* s_id = (int*)(array); + __shared__ typename BlockReduce::TempStorage temp_storage; + extern __shared__ char array[]; + T* s_val = topk_tmp_val_buf + batch_id * size; + int* s_id = (int*)(array); TopK_2 partial; @@ -307,7 +342,7 @@ topk_stage_2_opt2_general(const int* __restrict topk_tmp_id_buf, T* topk_tmp_val TopK_2 total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2); if (tid == 0) { - s_id[ite] = total.p; + s_id[ite] = total.p; s_val[total.p] = -MAX_T_VAL; } __syncthreads(); @@ -319,8 +354,14 @@ topk_stage_2_opt2_general(const int* __restrict topk_tmp_id_buf, T* topk_tmp_val #define CASE_K_DIV(K, BLOCK_SIZE_1, BLOCK_SIZE_2) \ case K: \ - beam_topK_kernel<<>>( \ - log_probs, topk_tmp_id_buf, topk_tmp_val_buf, vocab_size, diversity_rate); \ + beam_topK_kernel<<>>(log_probs, \ + topk_tmp_id_buf, \ + topk_tmp_val_buf, \ + finished, \ + sequence_lengths, \ + vocab_size, \ + diversity_rate, \ + length_penalty); \ if (K < 10) \ batch_topK_kernel \ <<>>(topk_tmp_id_buf, topk_tmp_val_buf, ids); \ @@ -336,8 +377,10 @@ topk_stage_2_opt2_general(const int* __restrict topk_tmp_id_buf, T* topk_tmp_val topk_tmp_id_buf, \ topk_tmp_val_buf, \ finished, \ + sequence_lengths, \ beam_width, \ vocab_size, \ + length_penalty, \ end_ids); \ topk_stage_2_opt3 \ <<>>( \ @@ -345,29 +388,33 @@ topk_stage_2_opt2_general(const int* __restrict topk_tmp_id_buf, T* topk_tmp_val break; template -void invokeTopkBeamSearch(void* workspace, - size_t& workspace_size, - T* log_probs, - int* ids, - const bool* finished, - const int batch_size, - const int beam_width, - const int vocab_size_padded_, - const T diversity_rate, - const int* end_ids, +void invokeTopkBeamSearch(void* workspace, + size_t& workspace_size, + T* log_probs, + int* ids, + const bool* finished, + const int* sequence_lengths, + const int batch_size, + const int beam_width, + const int vocab_size_padded_, + const T diversity_rate, + const float length_penalty, + const int* end_ids, cudaStream_t stream) { + // log_probs: (batch, beam, vocab) cumulative log_probs of beams ending with a token. const int vocab_size = vocab_size_padded_; - - const int max_block_per_beam = 8; - int temp_log_probs_buf_size = batch_size * beam_width * vocab_size; // type float - int topk_tmp_ids_buf_size = batch_size * beam_width * beam_width * max_block_per_beam; // type int - int topk_tmp_val_buf_size = batch_size * beam_width * beam_width * max_block_per_beam; // type float - - // prevent memory misalinged address + // Beam search needs the sequence lengths of beams to apply length penalty. 
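As a side note on the scoring formula used by these kernels (score = log_prob / length^length_penalty), here is a host-side sketch with a hypothetical function name and a small worked example (illustration only, not part of the patch):

#include <cmath>

// Hypothetical CPU reference of apply_length_penalty: beams are ranked by their
// cumulative log-probability normalized by sequence length.
float apply_length_penalty_ref(float log_prob, int length, float length_penalty)
{
    // length_penalty == 0 gives length^0 == 1, i.e. no normalization, which is why
    // the kernels skip the division entirely in that case.
    return log_prob / std::pow(static_cast<float>(length), length_penalty);
}
// Example: log_prob = -6.0f, length = 4, length_penalty = 1.0f -> score = -1.5f,
// so a longer beam with the same per-token likelihood is not ranked below a shorter one.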
+ assert(length_penalty == 0.0f || sequence_lengths != nullptr); + const int max_block_per_beam = 8; + int temp_log_probs_buf_size = batch_size * beam_width * vocab_size; // type float + int topk_tmp_ids_buf_size = batch_size * beam_width * beam_width * max_block_per_beam; // type int + int topk_tmp_val_buf_size = batch_size * beam_width * beam_width * max_block_per_beam; // type float + + // prevent memory misaligned address temp_log_probs_buf_size = (int)(ceil(temp_log_probs_buf_size / 4.)) * 4; - topk_tmp_ids_buf_size = (int)(ceil(topk_tmp_ids_buf_size / 4.)) * 4; - topk_tmp_val_buf_size = (int)(ceil(topk_tmp_val_buf_size / 4.)) * 4; + topk_tmp_ids_buf_size = (int)(ceil(topk_tmp_ids_buf_size / 4.)) * 4; + topk_tmp_val_buf_size = (int)(ceil(topk_tmp_val_buf_size / 4.)) * 4; if (workspace == nullptr) { workspace_size = sizeof(float) * temp_log_probs_buf_size + sizeof(int) * topk_tmp_ids_buf_size @@ -375,9 +422,9 @@ void invokeTopkBeamSearch(void* workspace, return; } else { - T* temp_log_probs = (T*)workspace; - int* topk_tmp_id_buf = (int*)(temp_log_probs + temp_log_probs_buf_size); - T* topk_tmp_val_buf = (T*)(topk_tmp_id_buf + topk_tmp_ids_buf_size); + T* temp_log_probs = (T*)workspace; + int* topk_tmp_id_buf = (int*)(temp_log_probs + temp_log_probs_buf_size); + T* topk_tmp_val_buf = (T*)(topk_tmp_id_buf + topk_tmp_ids_buf_size); if (diversity_rate == 0.0f) { switch (beam_width) { CASE_K(1, 128, 128, 8); @@ -387,8 +434,16 @@ void invokeTopkBeamSearch(void* workspace, CASE_K(32, 256, 128, 1); CASE_K(64, 256, 256, 1); default: - topk_stage_1_opt2_general<<>>( - log_probs, temp_log_probs, topk_tmp_id_buf, topk_tmp_val_buf, beam_width, vocab_size); + topk_stage_1_opt2_general + <<>>(log_probs, + temp_log_probs, + topk_tmp_id_buf, + topk_tmp_val_buf, + finished, + sequence_lengths, + beam_width, + vocab_size, + length_penalty); topk_stage_2_opt2_general << -__global__ void tileEncoderResults(T* tiled_output, - int* tiled_sequence_length, - const T* output, +__global__ void tileEncoderResults(T* tiled_output, + int* tiled_sequence_length, + const T* output, const int* sequence_length, const uint batch_size, const uint beam_width, @@ -452,10 +509,10 @@ __global__ void tileEncoderResults(T* tiled_output, } template -void invokeTileEncoderResults(T* tiled_output, - int* tiled_sequence_length, - const T* output, - const int* sequence_length, +void invokeTileEncoderResults(T* tiled_output, + int* tiled_sequence_length, + const T* output, + const int* sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len, @@ -469,16 +526,18 @@ void invokeTileEncoderResults(T* tiled_output, // sequence_length [batch_size] dim3 grid(batch_size, beam_width, mem_max_seq_len); + bool is_half2 = (std::is_same::value) && (d_model % 2 == 0); - if (d_model % 2 == 0 && std::is_same::value) { + if (is_half2) { + using T2 = typename TypeConverter::Type; // fp16 to half2, bf16 to bf162 dim3 block(min(512, (int)(d_model / 2))); - tileEncoderResults<<>>((half2*)tiled_output, - tiled_sequence_length, - (const half2*)output, - sequence_length, - batch_size, - beam_width, - d_model / 2); + tileEncoderResults<<>>((T2*)tiled_output, + tiled_sequence_length, + (const T2*)output, + sequence_length, + batch_size, + beam_width, + d_model / 2); } else { dim3 block(min(512, (int)d_model)); @@ -487,33 +546,45 @@ void invokeTileEncoderResults(T* tiled_output, } } -template void invokeTileEncoderResults(float* tiled_output, - int* tiled_sequence_length, +template void invokeTileEncoderResults(float* 
tiled_output, + int* tiled_sequence_length, const float* output, - const int* sequence_length, + const int* sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len, const size_t d_model, cudaStream_t stream); -template void invokeTileEncoderResults(half* tiled_output, - int* tiled_sequence_length, - const half* output, - const int* sequence_length, +template void invokeTileEncoderResults(half* tiled_output, + int* tiled_sequence_length, + const half* output, + const int* sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len, const size_t d_model, cudaStream_t stream); -template void invokeTileEncoderResults(half2* tiled_output, - int* tiled_sequence_length, +template void invokeTileEncoderResults(half2* tiled_output, + int* tiled_sequence_length, const half2* output, - const int* sequence_length, + const int* sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len, const size_t d_model, cudaStream_t stream); -} // namespace fastertransformer \ No newline at end of file +#ifdef ENABLE_BF16 +template void invokeTileEncoderResults(__nv_bfloat16* tiled_output, + int* tiled_sequence_length, + const __nv_bfloat16* output, + const int* sequence_length, + const size_t batch_size, + const size_t beam_width, + const size_t mem_max_seq_len, + const size_t d_model, + cudaStream_t stream); +#endif + +} // namespace fastertransformer diff --git a/src/fastertransformer/kernels/beam_search_topk_kernels.h b/src/fastertransformer/kernels/beam_search_topk_kernels.h index 38a930d46..42f509261 100644 --- a/src/fastertransformer/kernels/beam_search_topk_kernels.h +++ b/src/fastertransformer/kernels/beam_search_topk_kernels.h @@ -21,23 +21,25 @@ namespace fastertransformer { template -void invokeTopkBeamSearch(void* workspace, - size_t& workspace_size, - T* log_probs, - int* ids, - const bool* finished, - const int batch_size, - const int beam_width, - const int vocab_size_padded_, - const T diversity_rate, - const int* end_ids, +void invokeTopkBeamSearch(void* workspace, + size_t& workspace_size, + T* log_probs, + int* ids, + const bool* finished, + const int* sequence_lengths, + const int batch_size, + const int beam_width, + const int vocab_size_padded_, + const T diversity_rate, + const float length_penalty, + const int* end_ids, cudaStream_t stream); template -void invokeTileEncoderResults(T* tiled_encoder_output, - int* tiled_encoder_sequence_length, - const T* encoder_output, - const int* encoder_sequence_length, +void invokeTileEncoderResults(T* tiled_encoder_output, + int* tiled_encoder_sequence_length, + const T* encoder_output, + const int* encoder_sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len, diff --git a/src/fastertransformer/kernels/bert_preprocess_kernels.cu b/src/fastertransformer/kernels/bert_preprocess_kernels.cu index c855fa157..a40fd6951 100644 --- a/src/fastertransformer/kernels/bert_preprocess_kernels.cu +++ b/src/fastertransformer/kernels/bert_preprocess_kernels.cu @@ -18,16 +18,16 @@ namespace fastertransformer { -__global__ void getPaddingOffsetKernel(size_t* valid_word_num, - int* tmp_mask_offset, +__global__ void getPaddingOffsetKernel(size_t* valid_word_num, + int* tmp_mask_offset, const int* sequence_length, - const int batch_size, - const int max_seq_len) + const int batch_size, + const int max_seq_len) { // do cumulated sum int total_seq_len = 0; - int cum_offset = 0; - int index = 0; + int cum_offset = 0; + int 
index = 0; for (int i = 0; i < batch_size; i++) { const int seq_len = sequence_length[i]; for (int j = 0; j < seq_len; j++) { @@ -40,12 +40,12 @@ __global__ void getPaddingOffsetKernel(size_t* valid_word_num, valid_word_num[0] = (size_t)total_seq_len; } -void invokeGetPaddingOffset(size_t* h_token_num, - size_t* d_token_num, - int* tmp_mask_offset, - const int* sequence_lengths, - const int batch_size, - const int max_seq_len, +void invokeGetPaddingOffset(size_t* h_token_num, + size_t* d_token_num, + int* tmp_mask_offset, + const int* sequence_lengths, + const int batch_size, + const int max_seq_len, cudaStream_t stream) { getPaddingOffsetKernel<<<1, 1, 0, stream>>>( @@ -83,16 +83,23 @@ void invokeBuildEncoderAttentionMask( buildEncoderAttentionMaskKernel<<>>(attention_mask, sequence_lengths, max_seq_len); } -template void invokeBuildEncoderAttentionMask(float* attention_mask, - const int* sequence_lengths, - const int batch_size, - const int max_seq_len, +template void invokeBuildEncoderAttentionMask(float* attention_mask, + const int* sequence_lengths, + const int batch_size, + const int max_seq_len, cudaStream_t stream); -template void invokeBuildEncoderAttentionMask(half* attention_mask, - const int* sequence_lengths, - const int batch_size, - const int max_seq_len, +template void invokeBuildEncoderAttentionMask(half* attention_mask, + const int* sequence_lengths, + const int batch_size, + const int max_seq_len, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeBuildEncoderAttentionMask(__nv_bfloat16* attention_mask, + const int* sequence_lengths, + const int batch_size, + const int max_seq_len, + cudaStream_t stream); +#endif __global__ void getTrtPaddingOffsetKernel(int* trt_mha_padding_offset, const int* sequence_length, const int batch_size) { @@ -113,19 +120,19 @@ __global__ void getTrtPaddingOffsetKernel(int* trt_mha_padding_offset, const int } } -void invokeGetTrtPaddingOffset(int* trt_mha_padding_offset, - const int* sequence_length, - const int batch_size, +void invokeGetTrtPaddingOffset(int* trt_mha_padding_offset, + const int* sequence_length, + const int batch_size, cudaStream_t stream) { getTrtPaddingOffsetKernel<<<1, 256, sizeof(int) * (batch_size + 1), stream>>>( trt_mha_padding_offset, sequence_length, batch_size); } -__global__ void getTrtPaddingOffsetKernel(int* trt_mha_padding_offset, +__global__ void getTrtPaddingOffsetKernel(int* trt_mha_padding_offset, const int* sequence_length, - const int request_batch_size, - const int request_seq_len) + const int request_batch_size, + const int request_seq_len) { // use for get tensorrt fused mha padding offset // when we keep the padding @@ -145,10 +152,10 @@ __global__ void getTrtPaddingOffsetKernel(int* trt_mha_padding_offset, } } -void invokeGetTrtPaddingOffset(int* trt_mha_padding_offset, - const int* sequence_length, - const int request_batch_size, - const int request_seq_len, +void invokeGetTrtPaddingOffset(int* trt_mha_padding_offset, + const int* sequence_length, + const int request_batch_size, + const int request_seq_len, cudaStream_t stream) { getTrtPaddingOffsetKernel<<<1, 256, sizeof(int) * (2 * request_batch_size + 1), stream>>>( @@ -158,8 +165,8 @@ void invokeGetTrtPaddingOffset(int* trt_mha_padding_offset, template __global__ void rebuild_sequence_length_padding(const T* src, T* dst, const int* padding_offset, const int n) { - const int tid = threadIdx.x; - const int bid = blockIdx.x; + const int tid = threadIdx.x; + const int bid = blockIdx.x; const int dst_seq_id = bid + padding_offset[bid]; 
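The padding_offset consumed here (and by remove_padding below) is the per-token count of padding slots that precede a valid token in the padded [batch, max_seq_len] layout, as produced by getPaddingOffsetKernel above: remove_padding reads the padded buffer at i + offset[i], while rebuild_sequence_length_padding writes back to it at i + offset[i]. A minimal host-side sketch of that bookkeeping, with a hypothetical helper name (illustration only, not part of the patch):

#include <vector>

// Hypothetical CPU reference of getPaddingOffsetKernel: for every valid token,
// record how many padding positions come before it in the padded layout.
std::vector<int> build_padding_offsets_ref(const std::vector<int>& sequence_lengths, int max_seq_len)
{
    std::vector<int> offsets;  // one entry per valid (unpadded) token
    int cum_offset = 0;
    for (int seq_len : sequence_lengths) {
        for (int j = 0; j < seq_len; ++j) {
            offsets.push_back(cum_offset);
        }
        cum_offset += max_seq_len - seq_len;
    }
    return offsets;  // offsets.size() equals the total valid token count
}
// Example: sequence_lengths = {2, 3}, max_seq_len = 4 -> offsets = {0, 0, 2, 2, 2};
// compact token 2 (the first token of batch 1) maps to padded slot 2 + 2 = 4,
// i.e. batch 1, position 0.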
const int src_seq_id = bid; @@ -180,24 +187,32 @@ void invokeRebuildPadding( template void invokeRebuildPadding( T* dst, const T* src, const int* padding_offset, const int token_num, const int hidden_dim, cudaStream_t stream); -template void invokeRebuildPadding(float* dst, +template void invokeRebuildPadding(float* dst, const float* src, - const int* padding_offset, - const int token_num, - const int hidden_dim, + const int* padding_offset, + const int token_num, + const int hidden_dim, cudaStream_t stream); -template void invokeRebuildPadding(half* dst, - const half* src, - const int* padding_offset, - const int token_num, - const int hidden_dim, +template void invokeRebuildPadding(half* dst, + const half* src, + const int* padding_offset, + const int token_num, + const int hidden_dim, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeRebuildPadding(__nv_bfloat16* dst, + const __nv_bfloat16* src, + const int* padding_offset, + const int token_num, + const int hidden_dim, + cudaStream_t stream); +#endif template __global__ void remove_padding(T* tgt, const T* src, const int* padding_offset, const int n) { - const int tid = threadIdx.x; - const int bid = blockIdx.x; + const int tid = threadIdx.x; + const int bid = blockIdx.x; const int src_seq_id = bid + padding_offset[bid]; const int tgt_seq_id = bid; @@ -213,28 +228,37 @@ void invokeRemovePadding( remove_padding<<>>(dst, src, padding_offset, hidden_dim); } -template void invokeRemovePadding(float* dst, +template void invokeRemovePadding(float* dst, const float* src, - const int* padding_offset, - const int token_num, - const int hidden_dim, + const int* padding_offset, + const int token_num, + const int hidden_dim, cudaStream_t stream); -template void invokeRemovePadding(half* dst, - const half* src, - const int* padding_offset, - const int token_num, - const int hidden_dim, +template void invokeRemovePadding(half* dst, + const half* src, + const int* padding_offset, + const int token_num, + const int hidden_dim, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeRemovePadding(__nv_bfloat16* dst, + const __nv_bfloat16* src, + const int* padding_offset, + const int token_num, + const int hidden_dim, + cudaStream_t stream); +#endif + template -__global__ void buildRelativeAttentionBias(T* relative_attention_bias, - const T* relative_attention_bias_table, - const int head_num, - const int seq_len, - const int num_bucket, +__global__ void buildRelativeAttentionBias(T* relative_attention_bias, + const T* relative_attention_bias_table, + const int head_num, + const int seq_len, + const int num_bucket, const bool is_bidirectional, - const int max_distance) + const int max_distance) { const int head_id = blockIdx.x; @@ -245,7 +269,7 @@ __global__ void buildRelativeAttentionBias(T* relative_attention_bias, int relative_position = col_id - row_id; int relative_buckets = 0; - int tmp_num_bucket = num_bucket; + int tmp_num_bucket = num_bucket; if (is_bidirectional) { tmp_num_bucket /= 2; if (relative_position > 0) { @@ -259,8 +283,8 @@ __global__ void buildRelativeAttentionBias(T* relative_attention_bias, relative_position = abs(relative_position); } - int max_exact = tmp_num_bucket / 2; - bool is_small = relative_position < max_exact; + int max_exact = tmp_num_bucket / 2; + bool is_small = relative_position < max_exact; int relative_position_if_large = max_exact @@ -277,15 +301,15 @@ __global__ void buildRelativeAttentionBias(T* relative_attention_bias, } template -void invokeBuildRelativeAttentionBias(T* 
relative_attention_bias, - const T* relative_attention_bias_table, - const int head_num, - const int seq_len, - const int num_bucket, - const bool is_bidirectional, - const int max_distance, +void invokeBuildRelativeAttentionBias(T* relative_attention_bias, + const T* relative_attention_bias_table, + const int head_num, + const int seq_len, + const int num_bucket, + const bool is_bidirectional, + const int max_distance, const PositionEmbeddingType position_embedding_type, - cudaStream_t stream) + cudaStream_t stream) { if (position_embedding_type == PositionEmbeddingType::absolute) { return; @@ -301,24 +325,36 @@ void invokeBuildRelativeAttentionBias(T* relative_attention_bias, max_distance); } -template void invokeBuildRelativeAttentionBias(float* relative_attention_bias, - const float* relative_attention_bias_table, - const int head_num, - const int seq_len, - const int num_bucket, - const bool is_bidirectional, - const int max_distance, +template void invokeBuildRelativeAttentionBias(float* relative_attention_bias, + const float* relative_attention_bias_table, + const int head_num, + const int seq_len, + const int num_bucket, + const bool is_bidirectional, + const int max_distance, + const PositionEmbeddingType position_embedding_type, + cudaStream_t stream); + +template void invokeBuildRelativeAttentionBias(half* relative_attention_bias, + const half* relative_attention_bias_table, + const int head_num, + const int seq_len, + const int num_bucket, + const bool is_bidirectional, + const int max_distance, const PositionEmbeddingType position_embedding_type, - cudaStream_t stream); - -template void invokeBuildRelativeAttentionBias(half* relative_attention_bias, - const half* relative_attention_bias_table, - const int head_num, - const int seq_len, - const int num_bucket, - const bool is_bidirectional, - const int max_distance, + cudaStream_t stream); + +#ifdef ENABLE_BF16 +template void invokeBuildRelativeAttentionBias(__nv_bfloat16* relative_attention_bias, + const __nv_bfloat16* relative_attention_bias_table, + const int head_num, + const int seq_len, + const int num_bucket, + const bool is_bidirectional, + const int max_distance, const PositionEmbeddingType position_embedding_type, - cudaStream_t stream); + cudaStream_t stream); +#endif } // namespace fastertransformer \ No newline at end of file diff --git a/src/fastertransformer/kernels/bert_preprocess_kernels.h b/src/fastertransformer/kernels/bert_preprocess_kernels.h index dcb8f85f2..0bd3cfa89 100644 --- a/src/fastertransformer/kernels/bert_preprocess_kernels.h +++ b/src/fastertransformer/kernels/bert_preprocess_kernels.h @@ -22,27 +22,27 @@ namespace fastertransformer { -void invokeGetPaddingOffset(size_t* h_token_num, - size_t* d_token_num, - int* tmp_mask_offset, - const int* sequence_length, - const int batch_size, - const int max_seq_len, +void invokeGetPaddingOffset(size_t* h_token_num, + size_t* d_token_num, + int* tmp_mask_offset, + const int* sequence_length, + const int batch_size, + const int max_seq_len, cudaStream_t stream); template void invokeBuildEncoderAttentionMask( T* attention_mask, const int* sequence_lengths, const int batch_size, const int max_seq_len, cudaStream_t stream); -void invokeGetTrtPaddingOffset(int* trt_mha_padding_offset, - const int* sequence_length, - const int request_batch_size, +void invokeGetTrtPaddingOffset(int* trt_mha_padding_offset, + const int* sequence_length, + const int request_batch_size, cudaStream_t stream); -void invokeGetTrtPaddingOffset(int* trt_mha_padding_offset, - const int* 
sequence_length, - const int request_batch_size, - const int request_seq_len, +void invokeGetTrtPaddingOffset(int* trt_mha_padding_offset, + const int* sequence_length, + const int request_batch_size, + const int request_seq_len, cudaStream_t stream); template @@ -54,14 +54,14 @@ void invokeRemovePadding( T* dst, const T* src, const int* padding_offset, const int token_num, const int hidden_dim, cudaStream_t stream); template -void invokeBuildRelativeAttentionBias(T* relative_attention_bias, - const T* relative_attention_bias_table, - const int head_num, - const int seq_len, - const int num_bucket, - const bool is_bidirectional, - const int max_distance, +void invokeBuildRelativeAttentionBias(T* relative_attention_bias, + const T* relative_attention_bias_table, + const int head_num, + const int seq_len, + const int num_bucket, + const bool is_bidirectional, + const int max_distance, const PositionEmbeddingType position_embedding_type, - cudaStream_t stream); + cudaStream_t stream); } // namespace fastertransformer diff --git a/src/fastertransformer/kernels/bfloat16_fallback_kenrels.cuh b/src/fastertransformer/kernels/bfloat16_fallback_kenrels.cuh index de1a72e82..c81028bb3 100644 --- a/src/fastertransformer/kernels/bfloat16_fallback_kenrels.cuh +++ b/src/fastertransformer/kernels/bfloat16_fallback_kenrels.cuh @@ -33,6 +33,14 @@ inline __device__ float2 bf1622float2(const __nv_bfloat162 val) { #endif } +inline __device__ __nv_bfloat162 float22bf162(const float2 val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __floats2bfloat162_rn(val.x, val.y); +#else + return __float22bfloat162_rn(val); +#endif +} + inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 __nv_bfloat162 val2; @@ -122,6 +130,14 @@ inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bf #endif } +inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y, const __nv_bfloat16 z) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) + __bfloat162float(z)); +#else + return __hfma(x, y, z); +#endif +} + inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 float fxl, fxh; @@ -195,7 +211,9 @@ inline __device__ __nv_bfloat162 float2type2(float a) { // Convert float to type (applied to half and bfloat16) template -inline __device__ T float2type(float a); +inline __device__ T float2type(float a) { + return a; +} template<> inline __device__ half float2type(float a) { @@ -209,6 +227,60 @@ inline __device__ __nv_bfloat16 float2type(float a) { } #endif // ENABLE_BF16 +// Convert type to float (applied to half and bfloat16) +template +inline __device__ float type2float(T a) { + return a; +} + +template<> +inline __device__ float type2float(half a) { + return __half2float(a); +} + +#ifdef ENABLE_BF16 +template<> +inline __device__ float type2float(__nv_bfloat16 a) { + return __bfloat162float(a); +} +#endif + +// Convert type2 to float2 (applied to half and bfloat16) +template +inline __device__ float2 type22float2(T a) { + return a; +} + +template<> +inline __device__ float2 type22float2(half2 a) { + return __half22float2(a); +} + +#ifdef ENABLE_BF16 +template<> +inline __device__ float2 type22float2(__nv_bfloat162 a) { + return bf1622float2(a); +} +#endif // ENABLE_BF16 + +// Convert float2 to type2 (applied to half and bfloat16) +template +inline __device__ T 
float22type2(float2 a) { + return a; +} + +template<> +inline __device__ half2 float22type2(float2 a) { + return __float22half2_rn(a); +} + +#ifdef ENABLE_BF16 +template<> +inline __device__ __nv_bfloat162 float22type2(float2 a) { + return float22bf162(a); +} +#endif // ENABLE_BF16 + // Convert type to type2 (applied to half and bfloat16) template inline __device__ T_OUT type2type2(T_IN a); @@ -238,6 +310,87 @@ inline __device__ __nv_bfloat162 hadd2(__nv_bfloat162 a, __nv_bfloat162 b) { } #endif // ENABLE_BF16 +template +inline __device__ T add(T a, T b) { + return a + b; +} + +#ifdef ENABLE_BF16 +template<> +inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) { + return bf16hadd2(a, b); +} + +template<> +inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) { + return bf16hadd(a, b); +} + +inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, float b) { + return bf16hadd(a, __float2bfloat16(b)); +} +#endif // ENABLE_BF16 + +template<> +inline __device__ half2 add(half2 a, half2 b) { + return __hadd2(a, b); +} + +template<> +inline __device__ half add(half a, half b) { + return __hadd(a, b); +} + +// applies to all 4 values addition +template +inline __device__ T add(T a, T b, T c) { + return a + b + c; +} + +#ifdef ENABLE_BF16 +template<> +inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c)); +#else + return a + b + c; +#endif +} + +template<> +inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fal, fah, fbl, fbh, fcl, fch; + fal = __low2float(a); + fah = __high2float(a); + fbl = __low2float(b); + fbh = __high2float(b); + fcl = __low2float(c); + fch = __high2float(c); + return __floats2bfloat162_rn(fal + fbl + fcl, fah + fbh + fch); +#else + return a + b + c; +#endif +} +#endif // ENABLE_BF16 + +// applies to all 4 values addition +template +inline __device__ T add(T a, T b, T c, T d) { + return (T)((float)a + (float)b + (float)c + (float)d); +} + +#ifdef ENABLE_BF16 +template<> +inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c) + __bfloat162float(d)); +#else + return (__nv_bfloat16)((float)a + (float)b + (float)c + (float)d); +#endif +} +#endif // ENABLE_BF16 + template inline __device__ T hsub2(T a, T b) { return __hsub2(a, b); @@ -262,6 +415,103 @@ inline __device__ __nv_bfloat162 hmul2(__nv_bfloat162 a, __nv_bfloat162 b) { } #endif // ENABLE_BF16 +template +inline __device__ T hmul2(T a, T b, T c) { + return a * b * c; +} + +#ifdef ENABLE_BF16 +template<> +inline __device__ __nv_bfloat162 hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fal, fah, fbl, fbh, fcl, fch; + fal = __low2float(a); + fah = __high2float(a); + fbl = __low2float(b); + fbh = __high2float(b); + fcl = __low2float(c); + fch = __high2float(c); + return __floats2bfloat162_rn(fal * fbl * fcl, fah * fbh * fch); +#else + return a * b * c; +#endif +} +#endif // ENABLE_BF16 + +template +inline __device__ T mul(T a, T b, T c) { + return a * b * c; +} + +#ifdef ENABLE_BF16 +template<> +inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, 
__nv_bfloat162 b, __nv_bfloat162 c) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fal, fah, fbl, fbh, fcl, fch; + fal = __low2float(a); + fah = __high2float(a); + fbl = __low2float(b); + fbh = __high2float(b); + fcl = __low2float(c); + fch = __high2float(c); + return __floats2bfloat162_rn(fal * fbl * fcl, fah * fbh * fch); +#else + return a * b * c; +#endif +} + +template<> +inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b) * __bfloat162float(c)); +#else + return a * b * c; +#endif +} +#endif // ENABLE_BF16 + +template +inline __device__ T fma(T a, T b, T c, T d) { + return a * b * c + d; +} + +#ifdef ENABLE_BF16 +template<> +inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fal, fah, fbl, fbh, fcl, fch, fdl, fdh; + fal = __low2float(a); + fah = __high2float(a); + fbl = __low2float(b); + fbh = __high2float(b); + fcl = __low2float(c); + fch = __high2float(c); + fdl = __low2float(d); + fdh = __high2float(d); + return __floats2bfloat162_rn(fal * fbl * fcl + fdl, fah * fbh * fch + fdh); +#else + return a * b * c + d; +#endif +} +#endif // ENABLE_BF16 + +template +inline __device__ T fma(T a, T b, T c) { + return a * b + c; +} + +#ifdef ENABLE_BF16 +template<> +inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { + return bf16hfma2(a, b, c); +} + +template<> +inline __device__ __nv_bfloat16 fma(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) { + return bf16hfma(a, b, c); +} +#endif // ENABLE_BF16 + template inline __device__ T hexp2(T a) { return h2exp(a); diff --git a/src/fastertransformer/kernels/calibrate_quantize_weight_kernels.cu b/src/fastertransformer/kernels/calibrate_quantize_weight_kernels.cu index 0c2ce3a21..6f8359804 100644 --- a/src/fastertransformer/kernels/calibrate_quantize_weight_kernels.cu +++ b/src/fastertransformer/kernels/calibrate_quantize_weight_kernels.cu @@ -81,25 +81,25 @@ __global__ void ldk_calibrate_quantize_weight_per_channel(int8_t* dst, float* sc scale += bidx; src += bidx * k; dst += bidx * k; - T amax_val = 0.0f; - const T zero = static_cast(0.0f); + T amax_val = 0.0f; + const T zero = static_cast(0.0f); for (int k_i = tidx; k_i < k; k_i += blockDim.x) { T val = src[k_i]; - val = val > zero ? val : -val; + val = val > zero ? 
val : -val; if (amax_val > val) { amax_val = val; } } __shared__ float s_amax; - const float block_amax_val = blockReduceMax(static_cast(amax_val)); + const float block_amax_val = blockReduceMax(static_cast(amax_val)); if (tidx == 0) { - s_amax = block_amax_val; + s_amax = block_amax_val; scale[0] = block_amax_val / 127.0f; } __syncthreads(); for (int k_i = tidx; k_i < k; k_i += blockDim.x) { - T val = src[k_i]; + T val = src[k_i]; dst[k_i] = float_to_int8_rn(127.0f * static_cast(val) / s_amax); } } @@ -132,10 +132,10 @@ __global__ void ldn_transpose_quantize_weight_per_channel(int8_t* dst, const float* scale, const T* src, const int k, const int n) { __shared__ T shm[32][33]; - const int tidx = threadIdx.x; - const int tidy = threadIdx.y; - int n_idx = blockIdx.x * 32 + tidx; - int k_idx = blockIdx.y * 32 + tidy; + const int tidx = threadIdx.x; + const int tidy = threadIdx.y; + int n_idx = blockIdx.x * 32 + tidx; + int k_idx = blockIdx.y * 32 + tidy; if (n_idx < n && k_idx < k) { shm[tidx][tidy] = src[k_idx * n + n_idx]; } diff --git a/src/fastertransformer/kernels/custom_ar_kernels.cu b/src/fastertransformer/kernels/custom_ar_kernels.cu index a03ed1cbf..2de640b56 100644 --- a/src/fastertransformer/kernels/custom_ar_kernels.cu +++ b/src/fastertransformer/kernels/custom_ar_kernels.cu @@ -44,7 +44,7 @@ static inline __device__ void st_flag_release(uint32_t& flag, uint32_t* flag_add #if __CUDA_ARCH__ >= 700 asm volatile("st.global.release.sys.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr)); #else - __threadfence(); + __threadfence_system(); asm volatile("st.global.volatile.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr)); #endif } @@ -128,9 +128,9 @@ inline __device__ uint4 init_packed_type() template<> inline __device__ bf168 init_packed_type() { - bf168 val; + bf168 val; uint4& val_u = reinterpret_cast(val); - val_u = make_uint4(0u, 0u, 0u, 0u); + val_u = make_uint4(0u, 0u, 0u, 0u); return val; } #endif @@ -173,7 +173,7 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) const T* src_d[RANKS_PER_NODE]; #pragma unroll for (int ii = 0; ii < RANKS_PER_NODE; ++ii) { - int rank = (params.local_rank + ii) % RANKS_PER_NODE; + int rank = (params.local_rank + ii) % RANKS_PER_NODE; src_d[ii] = params.peer_comm_buffer_ptrs[rank]; } @@ -239,8 +239,8 @@ static __global__ void twoShotAllReduceKernel(AllReduceParams params) size_t dst_rank[RANKS_PER_NODE]; #pragma unroll for (int ii = 0; ii < RANKS_PER_NODE; ++ii) { - int rank = (params.local_rank + ii) % RANKS_PER_NODE; - src_d[ii] = params.peer_comm_buffer_ptrs[rank]; + int rank = (params.local_rank + ii) % RANKS_PER_NODE; + src_d[ii] = params.peer_comm_buffer_ptrs[rank]; dst_rank[ii] = rank; } @@ -275,7 +275,7 @@ static __global__ void twoShotAllReduceKernel(AllReduceParams params) st_flag_release(params.barrier_flag, params.peer_barrier_ptrs[tidx] + flag_block_offset + params.local_rank); // Busy-wait until all ranks are ready. 
- uint32_t rank_barrier = 0; + uint32_t rank_barrier = 0; uint32_t* peer_barrier_d = params.peer_barrier_ptrs[params.local_rank] + flag_block_offset + tidx; do { ld_flag_acquire(rank_barrier, peer_barrier_d); @@ -303,15 +303,15 @@ void kernelLaunchConfig( int& blocks_per_grid, int& threads_per_block, size_t elts, int kernel_algo, size_t data_type_bytes) { assert(data_type_bytes == 2 || data_type_bytes == 4); - // NOTE: need to supprot FP16 and FP32 - int elts_per_thread = 16 / data_type_bytes; - int elts_per_warp = (16 * WARP_SIZE) / data_type_bytes; + // NOTE: need to support FP16 and FP32 + size_t elts_per_thread = 16 / data_type_bytes; + size_t elts_per_warp = (16 * WARP_SIZE) / data_type_bytes; switch (kernel_algo) { case 0: { // one stage all reduce algo assert(elts % elts_per_warp == 0); if (elts < (elts_per_thread * DEFAULT_BLOCK_SIZE)) { // local reduce threads_per_block = ((elts + elts_per_warp - 1) / elts_per_warp) * WARP_SIZE; - blocks_per_grid = 1; + blocks_per_grid = 1; } else { // local reduce if (elts % (elts_per_thread * threads_per_block) == 0) { @@ -328,7 +328,7 @@ void kernelLaunchConfig( } else { int total_threads = elts / elts_per_thread; - blocks_per_grid = 1; + blocks_per_grid = 1; while (total_threads % blocks_per_grid != 0 || total_threads / blocks_per_grid > DEFAULT_BLOCK_SIZE) { blocks_per_grid += 1; @@ -366,9 +366,9 @@ void kernelLaunchConfig( template void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, cudaStream_t stream) { - size_t elts_total = param.elts_total; - int blocks_per_grid = 1, threads_per_block = DEFAULT_BLOCK_SIZE; - int kernel_algo = 1; + size_t elts_total = param.elts_total; + int blocks_per_grid = 1, threads_per_block = DEFAULT_BLOCK_SIZE; + int kernel_algo = 1; if (elts_total <= DEFALUT_ALGO_AR_SIZE_THRESHOLD) { kernel_algo = 0; } @@ -376,14 +376,14 @@ void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, cudaStream_t s kernelLaunchConfig(blocks_per_grid, threads_per_block, elts_total, kernel_algo, sizeof(T)); if (kernel_algo == 0) { - param.elts_per_rank = elts_total; + param.elts_per_rank = elts_total; param.elts_per_block = param.elts_per_rank / blocks_per_grid; oneShotAllReduceKernel<<>>(param); } else { - param.elts_per_rank = param.elts_total / RANKS_PER_NODE; + param.elts_per_rank = param.elts_total / RANKS_PER_NODE; param.elts_per_block = param.elts_per_rank / blocks_per_grid; - param.rank_offset = param.rank * param.elts_per_rank; + param.rank_offset = param.rank * param.elts_per_rank; twoShotAllReduceKernel<<>>(param); } } @@ -392,7 +392,7 @@ void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, cudaStream_t s template void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, cudaStream_t stream); #ifdef ENABLE_BF16 template void invokeOneOrTwoShotAllReduceKernel<__nv_bfloat16>(AllReduceParams<__nv_bfloat16>& param, - cudaStream_t stream); + cudaStream_t stream); #endif template void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, cudaStream_t stream); } // namespace fastertransformer \ No newline at end of file diff --git a/src/fastertransformer/kernels/custom_ar_kernels.h b/src/fastertransformer/kernels/custom_ar_kernels.h index a63cc5984..2938c18d6 100644 --- a/src/fastertransformer/kernels/custom_ar_kernels.h +++ b/src/fastertransformer/kernels/custom_ar_kernels.h @@ -44,15 +44,15 @@ typedef struct bf168 { template struct AllReduceParams { - size_t elts_total; - size_t elts_per_rank; - size_t elts_per_block; - size_t rank_offset; - size_t rank, local_rank, node_id; - uint32_t 
barrier_flag; + size_t elts_total; + size_t elts_per_rank; + size_t elts_per_block; + size_t rank_offset; + size_t rank, local_rank, node_id; + uint32_t barrier_flag; uint32_t* peer_barrier_ptrs[RANKS_PER_NODE]; - T* peer_comm_buffer_ptrs[RANKS_PER_NODE]; - T* local_output_buffer_ptr; + T* peer_comm_buffer_ptrs[RANKS_PER_NODE]; + T* local_output_buffer_ptr; }; template diff --git a/src/fastertransformer/kernels/decoder_masked_multihead_attention.cu b/src/fastertransformer/kernels/decoder_masked_multihead_attention.cu index 2c33ba37a..193489bf9 100644 --- a/src/fastertransformer/kernels/decoder_masked_multihead_attention.cu +++ b/src/fastertransformer/kernels/decoder_masked_multihead_attention.cu @@ -15,1278 +15,13 @@ */ #include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp" #include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" #include "src/fastertransformer/utils/cuda_bf16_wrapper.h" #include #include #include -// #define MMHA_USE_HMMA_FOR_REDUCTION - -// Below are knobs to extend FP32 accumulation for higher FP16 accuracy - -// Does not seem to affect the accuracy that much -// #define MMHA_USE_FP32_ACUM_FOR_FMA - -// Seems to slightly improve the accuracy -#define MMHA_USE_FP32_ACUM_FOR_OUT - -#if 0 && defined(MMHA_USE_FP32_ACUM_FOR_OUT) - // Does not seem to improve the accuracy - //#define MMHA_USE_FP32_ACUM_FOR_LOGITS -#endif - -namespace mmha { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// -// We use the following terminology to describe the different dimensions. -// -// B: Batch size (number of sequences), -// L: Sequence length, -// D: Hidden dimension, -// H: Number of heads, -// Dh: Hidden dimension per head - Dh = D / H. -// -// The different kernels assign a threadblock for B x H pair. The grid has size (1, B, H). We use -// 64, 128 and 256 threads per block. -// -// Each threadblock loads Dh values from Q and its associated bias. The kernels run a loop to -// compute Q * K^T where K is loaded from a cache buffer -- except for the current timestep. The -// cache buffer helps with memory accesses and contains keys with bias. -// -// The layout of the cache buffer for the keys is [B, H, Dh/x, L, x] where x == 8 for FP16 and -// x == 4 for FP32 where the fastest moving dimension (contiguous data) is the rightmost one. The -// values for x are chosen to create chunks of 16 bytes. -// -// The different kernels use 1, 2 or 4 threads per key (THREADS_PER_KEY). The size of the LDGs -// depends on the number of threads per key. Each thread sums Dh / THREADS_PER_KEY elements. At -// the end of each iteration of the Q * K^T loop, we perform a reduction between lanes using an -// HMMA instruction (Tensor Core). Each Q * K^T valuey is stored in shared memory in FP32. -// -// After that loop, a parallel softmax is computed accross the different Q * K^T values stored in -// shared memory. -// -// The kernel ends with a loop over the values in V. We use THREADS_PER_VALUE to control how many -// timesteps are computed by loop iteration. As with the keys, the values are read from a cache -// except for the current timestep. The layout of the cache buffer for the values is much simpler -// as it is [B, H, L, Dh]. 
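// A small sketch (not part of this patch) spelling out the cache addressing described in the
// comment above.  With x = 16 / sizeof(T) (8 for FP16, 4 for FP32), element d of head h, batch
// entry b at timestep t sits at the index below in the [B, H, Dh/x, L, x] key cache, while the
// value cache uses the plain [B, H, L, Dh] layout.  k_cache_index and v_cache_index are
// illustrative helper names; the kernels compute the same expression inline.
#include <cstddef>

static inline size_t k_cache_index(size_t b, size_t h, size_t t, size_t d,
                                   size_t num_heads, size_t seq_len, size_t Dh, size_t x)
{
    const size_t bh       = b * num_heads + h;  // "bhi" in the kernel
    const size_t chunk    = d / x;              // which 16-byte chunk of the head dimension
    const size_t in_chunk = d % x;              // position inside that chunk
    return bh * seq_len * Dh        // start of this (batch, head) slab
         + chunk * seq_len * x      // chunks are separated by L * x elements
         + t * x                    // timestep inside the chunk
         + in_chunk;
}

static inline size_t v_cache_index(size_t b, size_t h, size_t t, size_t d,
                                   size_t num_heads, size_t seq_len, size_t Dh)
{
    return (b * num_heads + h) * seq_len * Dh + t * Dh + d;
}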
-// - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Qk_vec_ {}; - -template<> -struct Qk_vec_ { - using Type = float; -}; -template<> -struct Qk_vec_ { - using Type = float2; -}; -template<> -struct Qk_vec_ { - using Type = float4; -}; -template<> -struct Qk_vec_ { - using Type = float4; -}; -template<> -struct Qk_vec_ { - using Type = uint32_t; -}; -template<> -struct Qk_vec_ { - using Type = uint32_t; -}; -template<> -struct Qk_vec_ { - using Type = uint2; -}; -template<> -struct Qk_vec_ { - using Type = uint4; -}; -#ifdef ENABLE_BF16 -template<> -struct Qk_vec_<__nv_bfloat16, 32> { - using Type = __nv_bfloat162; -}; -template<> -struct Qk_vec_<__nv_bfloat16, 64> { - using Type = __nv_bfloat162; -}; -template<> -struct Qk_vec_<__nv_bfloat16, 128> { - using Type = bf16_4_t; -}; -template<> -struct Qk_vec_<__nv_bfloat16, 256> { - using Type = bf16_8_t; -}; -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct K_vec_ {}; - -template<> -struct K_vec_ { - using Type = float; -}; -template<> -struct K_vec_ { - using Type = float2; -}; -template<> -struct K_vec_ { - using Type = float4; -}; -template<> -struct K_vec_ { - using Type = uint32_t; -}; -template<> -struct K_vec_ { - using Type = uint2; -}; -template<> -struct K_vec_ { - using Type = uint4; -}; -#ifdef ENABLE_BF16 -template<> -struct K_vec_<__nv_bfloat16, 4> { - using Type = __nv_bfloat162; -}; -template<> -struct K_vec_<__nv_bfloat16, 2> { - using Type = bf16_4_t; -}; -template<> -struct K_vec_<__nv_bfloat16, 1> { - using Type = bf16_8_t; -}; -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct V_vec_ {}; - -template<> -struct V_vec_ { - using Type = float; -}; -template<> -struct V_vec_ { - using Type = float2; -}; -template<> -struct V_vec_ { - using Type = float4; -}; -template<> -struct V_vec_ { - using Type = uint32_t; -}; -template<> -struct V_vec_ { - using Type = uint2; -}; -template<> -struct V_vec_ { - using Type = uint4; -}; -#ifdef ENABLE_BF16 -template<> -struct V_vec_<__nv_bfloat16, 2> { - using Type = __nv_bfloat162; -}; -template<> -struct V_vec_<__nv_bfloat16, 4> { - using Type = bf16_4_t; -}; -template<> -struct V_vec_<__nv_bfloat16, 8> { - using Type = bf16_8_t; -}; -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA -template -struct Qk_vec_acum_fp32_ {}; - -template<> -struct Qk_vec_acum_fp32_ { - using Type = float; -}; -template<> -struct Qk_vec_acum_fp32_ { - using Type = float2; -}; -template<> -struct Qk_vec_acum_fp32_ { - using Type = float4; -}; -// template<> struct Qk_vec_acum_fp32_ { using Type = float; }; -template<> -struct Qk_vec_acum_fp32_ { - using Type = float2; -}; -template<> -struct Qk_vec_acum_fp32_ { - using Type = Float4_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct K_vec_acum_fp32_ {}; - -template<> -struct K_vec_acum_fp32_ { - using Type = float; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = float2; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = float4; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = float2; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = Float4_; -}; -template<> -struct 
K_vec_acum_fp32_ { - using Type = Float8_; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef MMHA_USE_FP32_ACUM_FOR_OUT -template -struct V_vec_acum_fp32_ {}; - -template<> -struct V_vec_acum_fp32_ { - using Type = float; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = float2; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = float4; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = float2; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = Float4_; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = Float8_; -}; -#ifdef ENABLE_BF16 -template<> -struct V_vec_acum_fp32_<__nv_bfloat162> { - using Type = float2; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = Float4_; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = Float8_; -}; -#endif // ENABLE_BF16 -#endif -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N]) -{ -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA - using K_vec_acum = typename K_vec_acum_fp32_::Type; -#else - using K_vec_acum = K_vec; -#endif - // Compute the parallel products for Q*K^T (treat vector lanes separately). - K_vec_acum qk_vec = mul(q[0], k[0]); -#pragma unroll - for (int ii = 1; ii < N; ++ii) { - qk_vec = fma(q[ii], k[ii], qk_vec); - } - - // Finalize the reduction across lanes. - float qk = sum(qk_vec); -#pragma unroll - for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) { - qk += __shfl_xor_sync(uint32_t(-1), qk, mask); - } - return qk; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Qk_dot { - template - static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N]) - { - return qk_dot_(q, k); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float4 hmma_fp32(const uint2& a, uint32_t b) -{ - float4 c; - float zero = 0.f; - asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n" - " {%0, %1, %2, %3}, \n" - " {%4, %5}, \n" - " {%6}, \n" - " {%7, %7, %7, %7}; \n" - - : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w) - : "r"(a.x) "r"(a.y), "r"(b), "f"(zero)); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], const uint32_t (&k)[N]) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA - using K_vec_acum = typename K_vec_acum_fp32_::Type; -#else - using K_vec_acum = uint32_t; -#endif - K_vec_acum qk_vec = mul(q[0], k[0]); -#pragma unroll - for (int ii = 1; ii < N; ++ii) { - qk_vec = fma(q[ii], k[ii], qk_vec); - } -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA - uint32_t qk_vec_ = float2_to_half2(qk_vec); - return hmma_fp32(make_uint2(qk_vec_, 0u), 0x3c003c00u).x; -#else - return hmma_fp32(make_uint2(qk_vec, 0u), 0x3c003c00u).x; -#endif -#else - return 0.f; -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -struct Qk_dot { - template - static inline __device__ float dot(const uint32_t (&q)[N], const uint32_t (&k)[N]) - { -#if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA_FOR_REDUCTION) - return qk_hmma_dot_(q, k); -#else - return qk_dot_<4>(q, k); -#endif // defined 
MMHA_USE_HMMA_FOR_REDUCTION - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ float block_sum(float* red_smem, float sum) -{ - - // Decompose the thread index into warp / lane. - int warp = threadIdx.x / WARP_SIZE; - int lane = threadIdx.x % WARP_SIZE; - -// Compute the sum per warp. -#pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { - sum += __shfl_xor_sync(uint32_t(-1), sum, mask); - } - - // Warp leaders store the data to shared memory. - if (lane == 0) { - red_smem[warp] = sum; - } - - // Make sure the data is in shared memory. - __syncthreads(); - - // The warps compute the final sums. - if (lane < WARPS_PER_BLOCK) { - sum = red_smem[lane]; - } - -// Parallel reduction inside the warp. -#pragma unroll - for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { - sum += __shfl_xor_sync(uint32_t(-1), sum, mask); - } - - // Broadcast to other threads. - return __shfl_sync(uint32_t(-1), sum, 0); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(float& dst, float src) -{ - dst = src; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(uint16_t& dst, float src) -{ - dst = float_to_half(src); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(uint32_t& dst, float2 src) -{ - dst = float2_to_half2(src); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -#ifdef ENABLE_BF16 -inline __device__ void convert_from_float(__nv_bfloat16& dst, float src) -{ - dst = __float2bfloat16(src); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(__nv_bfloat162& dst, float2 src) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - dst = __float22bfloat162_rn(src); -#else - dst = __floats2bfloat162_rn(src.x, src.y); -#endif -} -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(uint2& dst, Float4_ src) -{ - dst.x = float2_to_half2(src.x); - dst.y = float2_to_half2(src.y); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(uint4& dst, Float8_ src) -{ - dst.x = float2_to_half2(src.x); - dst.y = float2_to_half2(src.y); - dst.z = float2_to_half2(src.z); - dst.w = float2_to_half2(src.w); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -inline __device__ void convert_from_float(bf16_4_t& dst, Float4_ src) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - dst.x = __float22bfloat162_rn(src.x); - dst.y = __float22bfloat162_rn(src.y); -#else - dst.x = __floats2bfloat162_rn(src.x.x, src.x.y); - dst.y = __floats2bfloat162_rn(src.y.x, src.y.y); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(bf16_8_t& dst, Float8_ src) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - dst.x = __float22bfloat162_rn(src.x); - dst.y = 
__float22bfloat162_rn(src.y); - dst.z = __float22bfloat162_rn(src.z); - dst.w = __float22bfloat162_rn(src.w); -#else - dst.x = __floats2bfloat162_rn(src.x.x, src.x.y); - dst.y = __floats2bfloat162_rn(src.y.x, src.y.y); - dst.z = __floats2bfloat162_rn(src.z.x, src.z.y); - dst.w = __floats2bfloat162_rn(src.w.x, src.w.y); -#endif -} -#endif // ENABLE_BF16 - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(float2& dst, float2 src) -{ - dst = src; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(float4& dst, float4 src) -{ - dst = src; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float convert_to_float(float4 u) -{ - return u.x; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float convert_to_float(uint4 u) -{ - float2 tmp = half2_to_float2(u.x); - return tmp.x; -} - -#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float cast_to_float(float u) -{ - return u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 cast_to_float(float2 u) -{ - return u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float4 cast_to_float(float4 u) -{ - return u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ cast_to_float(Float4_ u) -{ - return u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ cast_to_float(Float8_ u) -{ - return u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 cast_to_float(uint32_t u) -{ - return half2_to_float2(u); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ cast_to_float(uint2 u) -{ - Float4_ tmp; - tmp.x = half2_to_float2(u.x); - tmp.y = half2_to_float2(u.y); - return tmp; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ cast_to_float(uint4 u) -{ - Float8_ tmp; - tmp.x = half2_to_float2(u.x); - tmp.y = half2_to_float2(u.y); - tmp.z = half2_to_float2(u.z); - tmp.w = half2_to_float2(u.w); - return tmp; -} - -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ __host__ T div_up(T m, T n) -{ - return (m + n - 1) / n; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline size_t smem_size_in_bytes(const Multihead_attention_params& params, - int threads_per_value, - int threads_per_block) -{ - // The amount of shared memory needed to store the Q*K^T values in float. - // TODO - size_t qk_sz = (DO_CROSS_ATTENTION) ? div_up(params.seq_length + 1, 4) * 16 : div_up(params.timestep + 1, 4) * 16; - - // The extra memory needed if we are not using floats for the final logits. 
- size_t logits_sz = 0; -#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS - if (sizeof(T) != 4) { - // TDOD - logits_sz = div_up(params.seq_length, 4) * 4 * sizeof(T); - } -#endif - - // The total size needed during softmax. - size_t softmax_sz = qk_sz + logits_sz; - - // The number of partial rows to reduce in the final reduction. - int rows_per_red = threads_per_block / threads_per_value; - // The amount of storage needed to finalize the outputs. - size_t red_sz = rows_per_red * params.hidden_size_per_head * sizeof(T) / 2; - - // The max. - return max(softmax_sz, red_sz); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ constexpr uint32_t shfl_mask(int threads) -{ - return threads == 32 ? uint32_t(-1) : (1u << threads) - 1u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< - // The type of the inputs. Supported types: float and half. - typename T, - // The hidden dimension per head. - int Dh, - int Dh_MAX, - // The number of threads per key. - int THREADS_PER_KEY, - // The number of threads per value. - int THREADS_PER_VALUE, - // The number of threads in a threadblock. - int THREADS_PER_BLOCK, - bool DO_CROSS_ATTENTION> -__global__ void masked_multihead_attention_kernel(Multihead_attention_params params) -{ - - // Make sure the hidden dimension per head is a multiple of the number of threads per key. - static_assert(Dh_MAX % THREADS_PER_KEY == 0, ""); - // Make sure the hidden dimension per head is a multiple of the number of threads per value. - static_assert(Dh_MAX % THREADS_PER_VALUE == 0, ""); - - // The size of a warp. - constexpr int WARP_SIZE = 32; - // The number of warps in a threadblock. - constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE; - - // Use smem_size_in_bytes (above) to determine the amount of shared memory. - extern __shared__ char smem_[]; - - // The shared memory for the Q*K^T values and partial logits in softmax. - float* qk_smem = reinterpret_cast(smem_); - - // The shared memory for the logits. For FP32, that's the same buffer as qk_smem. - char* logits_smem_ = smem_; -#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS - if (sizeof(T) != 4) { - // TODO - cahnge to tlength - logits_smem_ += - (DO_CROSS_ATTENTION) ? div_up(params.seq_length + 1, 4) * 16 : div_up(params.timestep + 1, 4) * 16; - } - T* logits_smem = reinterpret_cast(logits_smem_); -#else - float* logits_smem = reinterpret_cast(logits_smem_); -#endif - - // The shared memory to do the final reduction for the output values. Reuse qk_smem. - T* out_smem = reinterpret_cast(smem_); - - // The shared memory buffers for the block-wide reductions. One for max, one for sum. - __shared__ float red_smem[WARPS_PER_BLOCK * 2]; - - // A vector of Q or K elements for the current timestep. - using Qk_vec = typename Qk_vec_::Type; - - // Use alignment for safely casting the shared buffers as Qk_vec. - // Shared memory to store Q inputs. - __shared__ __align__(sizeof(Qk_vec)) T q_smem[Dh_MAX]; - - // This is one of the reasons we should have a separate kernel for cross attention - __shared__ __align__(sizeof(Qk_vec)) T bias_smem[DO_CROSS_ATTENTION ? Dh_MAX : 1]; - - // A vector of Q or K elements for the current timestep. - using Qk_vec = typename Qk_vec_::Type; - // The number of elements per vector. - constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T); - // Make sure the hidden size per head is a multiple of the vector size. 
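// A host-side sketch (not part of this patch) of the dynamic shared-memory budget that
// smem_size_in_bytes computes: room for the float Q*K^T row (padded to 16 bytes), optionally a
// separate half/bf16 logits buffer, and the scratch used by the final cross-group output
// reduction; the launcher takes the max of the two phases.  smem_bytes is an illustrative name.
#include <algorithm>
#include <cstddef>

static size_t smem_bytes(size_t timesteps,        // timestep (or seq_length for cross attention)
                         size_t seq_length,
                         size_t hidden_per_head,  // Dh
                         int threads_per_block,
                         int threads_per_value,
                         size_t elt_size,         // sizeof(T)
                         bool fp32_logits)        // true when MMHA_USE_FP32_ACUM_FOR_LOGITS is set
{
    auto div_up = [](size_t m, size_t n) { return (m + n - 1) / n; };
    const size_t qk_sz      = div_up(timesteps + 1, 4) * 16;  // Q*K^T values kept in float
    const size_t logits_sz  = (!fp32_logits && elt_size != 4) ? div_up(seq_length, 4) * 4 * elt_size : 0;
    const size_t softmax_sz = qk_sz + logits_sz;
    const size_t rows_per_red = threads_per_block / threads_per_value;
    const size_t red_sz     = rows_per_red * hidden_per_head * elt_size / 2;  // half the groups write per step
    return std::max(softmax_sz, red_sz);
}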
- static_assert(Dh_MAX % QK_VEC_SIZE == 0, ""); - // We will use block wide reduction if needed - // static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, ""); - // The number of vectors per warp. - constexpr int QK_VECS_PER_WARP = Dh_MAX / QK_VEC_SIZE; - - // The layout of the cache is [B, H, Dh/x, L, x] with x == 4/8 for FP32/FP16. Since each thread - // owns x elements, we have to decompose the linear index into chunks of x values and the posi- - // tion of the thread in that chunk. - - // The number of elements in a chunk of 16B (that's the x in the above formula). - constexpr int QK_ELTS_IN_16B = 16 / sizeof(T); - // The number of K vectors in 16B. - constexpr int QK_VECS_IN_16B = 16 / sizeof(Qk_vec); - - // The batch/beam idx - const int bi = blockIdx.y; - if (params.finished != nullptr && params.finished[bi] == true) { - return; - } - // The beam idx - const int beami = bi % params.beam_width; - // The "beam-aware" batch idx - const int bbi = bi / params.beam_width; - // The head. - const int hi = blockIdx.x; - // Combine the batch and the head indices. - const int bhi = bi * params.num_heads + hi; - // Combine the "beam-aware" batch idx and the head indices. - const int bbhi = bbi * params.beam_width * params.num_heads + hi; - // The thread in the block. - const int tidx = threadIdx.x; - - // While doing the product Q*K^T for the different keys we track the max. - float qk_max = -FLT_MAX; - - float qk = 0.0F; - - int qkv_base_offset = (params.stride == 0) ? bhi * Dh : bi * params.stride + hi * Dh; - - // int tlength = (DO_CROSS_ATTENTION)? params.memory_length_per_sample[bi] - 1 : params.timestep; - int tlength = (DO_CROSS_ATTENTION) ? params.memory_length_per_sample[bi] - 1 : - (params.length_per_sample == nullptr) ? params.timestep : - params.length_per_sample[bi]; - // First QK_VECS_PER_WARP load Q and K + the bias values for the current timestep. - if (tidx < QK_VECS_PER_WARP) { - - // The offset in the Q and K buffer also accounts for the batch. - int qk_offset = qkv_base_offset + tidx * QK_VEC_SIZE; - // The offset in the bias buffer. - int qk_bias_offset = hi * Dh + tidx * QK_VEC_SIZE; - - // Trigger the loads from the Q and K buffers. - Qk_vec q; - zero(q); - q = (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ? *reinterpret_cast(¶ms.q[qk_offset]) : q; - Qk_vec k; - zero(k); - if (DO_CROSS_ATTENTION) { - // The 16B chunk written by the thread. - int co = tidx / QK_VECS_IN_16B; - // The position of the thread in that 16B chunk. - int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE; - - // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements. - int offset = bhi * params.seq_length * Dh + co * params.seq_length * QK_ELTS_IN_16B + - // params.timestep*QK_ELTS_IN_16B + - tlength * QK_ELTS_IN_16B + ci; - k = (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ? *reinterpret_cast(¶ms.k_cache[offset]) : - k; - } - else { - k = (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ? *reinterpret_cast(¶ms.k[qk_offset]) : k; - } - - // Trigger the loads from the Q and K bias buffers. - Qk_vec q_bias; - zero(q_bias); - q_bias = (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.q_bias != nullptr ? - *reinterpret_cast(¶ms.q_bias[qk_bias_offset]) : - q_bias; - Qk_vec k_bias; - zero(k_bias); - - if (!DO_CROSS_ATTENTION || (DO_CROSS_ATTENTION && params.timestep == 0)) { - k_bias = (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.k_bias != nullptr ? - *reinterpret_cast(¶ms.k_bias[qk_bias_offset]) : - k_bias; - } - - // Computes the Q/K values with bias. 
- q = add(q, q_bias); - if (!DO_CROSS_ATTENTION || (DO_CROSS_ATTENTION && params.timestep == 0)) { - k = add(k, k_bias); - if (params.rotary_embedding_dim > 0) { - apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, params.timestep); - } - } - else { - if (params.rotary_embedding_dim > 0) { - apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, params.timestep); - } - } - - // Store the Q values to shared memory. - *reinterpret_cast(&q_smem[tidx * QK_VEC_SIZE]) = q; - - // Store Dh values of k_bias into smem, since will need to add later - // if params.timestep == 0 - if (DO_CROSS_ATTENTION && params.timestep == 0) { - *reinterpret_cast(&bias_smem[tidx * QK_VEC_SIZE]) = k_bias; - } - - // Write the K values to the global memory cache. - // - // NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory - // system. We designed it this way as it allows much better memory loads (and there are many - // more loads) + the stores are really "write and forget" since we won't need the ack before - // the end of the kernel. There's plenty of time for the transactions to complete. - - // The 16B chunk written by the thread. - int co = tidx / QK_VECS_IN_16B; - // The position of the thread in that 16B chunk. - int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE; - - // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements. - int offset = bhi * params.seq_length * Dh + co * params.seq_length * QK_ELTS_IN_16B + - // params.timestep*QK_ELTS_IN_16B + - tlength * QK_ELTS_IN_16B + ci; - - if (!DO_CROSS_ATTENTION || (DO_CROSS_ATTENTION && params.timestep == 0)) { - // Trigger the stores to global memory. - if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) { - *reinterpret_cast(¶ms.k_cache[offset]) = k; - } - } - - // Compute \sum_i Q[i] * K^T[i] for the current timestep. -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA - using Qk_vec_acum = typename Qk_vec_acum_fp32_::Type; -#else - using Qk_vec_acum = Qk_vec; -#endif - qk = dot(q, k); - if (QK_VECS_PER_WARP <= WARP_SIZE) { -#pragma unroll - for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) { - qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask); - } - } - } - - if (QK_VECS_PER_WARP > WARP_SIZE) { - constexpr int WARPS_PER_RED = (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE; - qk = block_sum(&red_smem[WARPS_PER_RED], qk); - } - - // Store that value in shared memory. Keep the Q*K^T value in register for softmax. - if (tidx == 0) { - // Normalize qk. - qk *= params.inv_sqrt_dh; - - if (params.relative_attention_bias_float != nullptr) { - qk = qk - + params.relative_attention_bias_float[hi * params.relative_attention_bias_stride - * params.relative_attention_bias_stride - + tlength * params.relative_attention_bias_stride + tlength]; - } - else if (params.relative_attention_bias_half != nullptr) { - qk = qk - + (float) - params.relative_attention_bias_half[hi * params.relative_attention_bias_stride - * params.relative_attention_bias_stride - + tlength * params.relative_attention_bias_stride + tlength]; - } - qk_max = qk; - qk_smem[tlength] = qk; - // qk_smem[params.timestep] = qk; - } - - // Make sure the data is in shared memory. - __syncthreads(); - - // The type of queries and keys for the math in the Q*K^T product. - using K_vec = typename K_vec_::Type; - // The number of elements per vector. - constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(T); - // Make sure the hidden size per head is a multiple of the vector size. 
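// A self-contained sketch (not part of this patch) of the butterfly reduction used just above to
// combine the per-lane Q*K^T partial sums: XOR shuffles halve the distance each step, so after
// log2(width) steps every participating lane holds the full sum.  warp_allreduce_sum is an
// illustrative name; the kernel writes the same loop inline (see block_sum for the block-wide variant).
#include <cstdint>

__device__ inline float warp_allreduce_sum(float val, int width)  // width: power of two, <= 32
{
    const uint32_t member_mask = (width == 32) ? uint32_t(-1) : (1u << width) - 1u;
    for (int mask = width / 2; mask >= 1; mask /= 2) {
        val += __shfl_xor_sync(member_mask, val, mask);
    }
    return val;  // identical on every lane of the group
}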
- static_assert(Dh_MAX % K_VEC_SIZE == 0, ""); - // The number of elements per thread. - constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY; - // The number of vectors per thread. - constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE; - - // The position the first key loaded by each thread from the cache buffer (for this B * H). - int ko = tidx / THREADS_PER_KEY; - // The position of the thread in the chunk of keys. - int ki = tidx % THREADS_PER_KEY * K_VEC_SIZE; - - static_assert(Dh_MAX == THREADS_PER_KEY * K_VEC_SIZE * K_VECS_PER_THREAD); - - // Load the Q values from shared memory. The values are reused during the loop on K. - K_vec q[K_VECS_PER_THREAD]; -#pragma unroll - for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { - q[ii] = *reinterpret_cast(&q_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]); - } - - K_vec k_bias[DO_CROSS_ATTENTION ? K_VECS_PER_THREAD : 1]; - if (DO_CROSS_ATTENTION && params.timestep == 0) { -#pragma unroll - for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { - k_bias[ii] = *reinterpret_cast(&bias_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]); - } - } - - // The number of timesteps loaded per iteration. - constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY; - // The number of keys per warp. - constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; - - // The base pointer for the key in the cache buffer. - T* k_cache = ¶ms.k_cache[bhi * params.seq_length * Dh + ki]; - // Base pointer for the beam's batch, before offsetting with indirection buffer - T* k_cache_batch = ¶ms.k_cache[bbhi * params.seq_length * Dh + ki]; - - // Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync). - // int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP; - int ti_end = div_up(tlength, K_PER_WARP) * K_PER_WARP; - - // Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values. - for (int ti = ko; ti < ti_end; ti += K_PER_ITER) { - - // The keys loaded from the key cache. - K_vec k[K_VECS_PER_THREAD]; - K_vec k_vec_zero; - zero(k_vec_zero); -#pragma unroll - for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { - int jj = ii * params.seq_length + ti; - // if( ti < params.timestep ) { - if (ti < tlength) { - const int beam_src = - (params.cache_indir != nullptr) ? - params.cache_indir[(bbi * params.beam_width + beami) * params.seq_length + ti] : - 0; - const int beam_offset = beam_src * params.num_heads * params.seq_length * Dh; - k[ii] = (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.seq_length) ? - *reinterpret_cast(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]) : - k_vec_zero; - // add bias and update k_cache - if (DO_CROSS_ATTENTION && params.timestep == 0) { - k[ii] = add(k[ii], k_bias[ii]); - if (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.seq_length) { - *reinterpret_cast(&k_cache[jj * QK_ELTS_IN_16B]) = k[ii]; - } - } - } - } - - // Perform the dot product and normalize qk. - // - // WARNING: ALL THE THREADS OF A WARP MUST ENTER!!! - float qk = Qk_dot::dot(q, k) * params.inv_sqrt_dh; - bool is_mask = (params.input_lengths != nullptr && ti >= params.input_lengths[bi] && ti < params.max_input_len); - - // Store the product to shared memory. There's one qk value per timestep. Update the max. 
- // if( ti < params.timestep && tidx % THREADS_PER_KEY == 0 ) { - if (ti < tlength && tidx % THREADS_PER_KEY == 0) { - if (params.relative_attention_bias_float != nullptr) { - qk = qk - + params.relative_attention_bias_float[hi * params.relative_attention_bias_stride - * params.relative_attention_bias_stride - + tlength * params.relative_attention_bias_stride + ti]; - } - else if (params.relative_attention_bias_half != nullptr) { - qk = qk - + (float) - params.relative_attention_bias_half[hi * params.relative_attention_bias_stride - * params.relative_attention_bias_stride - + tlength * params.relative_attention_bias_stride + ti]; - } - qk_max = is_mask ? qk_max : fmaxf(qk_max, qk); - qk_smem[ti] = qk; - } - } - -// Perform the final reduction to compute the max inside each warp. -// -// NOTE: In a group of THREADS_PER_KEY threads, the leader already has the max value for the -// group so it's not needed to run the reduction inside the group (again). -#pragma unroll - for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) { - qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); - } - - // Decompose the thread index into warp and lane. - const int warp = tidx / WARP_SIZE; - const int lane = tidx % WARP_SIZE; - - // The warp leader writes the max to shared memory. - if (lane == 0) { - red_smem[warp] = qk_max; - } - - // Make sure the products are in shared memory. - __syncthreads(); - - // The warps finalize the reduction. - qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX; -#pragma unroll - for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { - qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); - } - - // Broadcast to all the threads in the warp. - qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); - - // Compute the logits and start the sum. - float sum = 0.f; - // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) { - for (int ti = tidx; ti <= tlength; ti += THREADS_PER_BLOCK) { - bool is_mask = (params.input_lengths != nullptr && ti >= params.input_lengths[bi] && ti < params.max_input_len); - float logit = is_mask ? 0.f : __expf(qk_smem[ti] - qk_max); - sum += logit; - qk_smem[ti] = logit; - } - - // Compute the sum. - sum = block_sum(&red_smem[WARPS_PER_BLOCK], sum); - - // Normalize the logits. - float inv_sum = __fdividef(1.f, sum + 1.e-6f); - // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) { - for (int ti = tidx; ti <= tlength; ti += THREADS_PER_BLOCK) { - convert_from_float(logits_smem[ti], qk_smem[ti] * inv_sum); - } - - // Put Values part below so we leverage __syncthreads - // from the previous step - - // The number of elements per vector. - constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE; - // A vector of V elements for the current timestep. - using V_vec = typename V_vec_::Type; - - // The value computed by this thread. - int vo = tidx / THREADS_PER_VALUE; - // The hidden dimensions computed by this particular thread. - int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; - - // The base pointer for the value in the cache buffer. - T* v_cache = ¶ms.v_cache[bhi * params.seq_length * Dh + vi]; - // Base pointer for the beam's batch, before offsetting with indirection buffer - T* v_cache_batch = ¶ms.v_cache[bbhi * params.seq_length * Dh + vi]; - - // The number of values processed per iteration of the loop. - constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; - - // One group of threads computes the product(s) for the current timestep. 
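// A host-side sketch (not part of this patch) of the numerically stable, masked softmax the block
// just finished over qk_smem[0..tlength]: padded prompt positions (>= input_length and
// < max_input_len) are excluded from the max and given zero probability, and the normalizer adds
// 1e-6 exactly as the kernel does.  masked_softmax is an illustrative name.
#include <cfloat>
#include <cmath>
#include <vector>

static void masked_softmax(std::vector<float>& qk, int tlength, int input_length, int max_input_len)
{
    float qk_max = -FLT_MAX;
    for (int ti = 0; ti <= tlength; ++ti) {
        const bool is_mask = (ti >= input_length && ti < max_input_len);
        if (!is_mask) {
            qk_max = fmaxf(qk_max, qk[ti]);
        }
    }
    float sum = 0.f;
    for (int ti = 0; ti <= tlength; ++ti) {
        const bool is_mask = (ti >= input_length && ti < max_input_len);
        const float logit = is_mask ? 0.f : expf(qk[ti] - qk_max);
        qk[ti] = logit;
        sum += logit;
    }
    const float inv_sum = 1.f / (sum + 1.e-6f);
    for (int ti = 0; ti <= tlength; ++ti) {
        qk[ti] *= inv_sum;
    }
}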
-    V_vec v_bias;
-    zero(v_bias);
-    // if( vo == params.timestep % V_PER_ITER ) {
-    if (Dh == Dh_MAX || vi < Dh) {
-        if (!DO_CROSS_ATTENTION || (DO_CROSS_ATTENTION && params.timestep == 0)) {
-            if (vo == tlength % V_PER_ITER) {
-                // Trigger the loads from the V bias buffer.
-                if (params.v_bias != nullptr) {
-                    v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi * Dh + vi]);
-                }
-                if (DO_CROSS_ATTENTION) {
-                    *reinterpret_cast<V_vec*>(&bias_smem[vi]) = v_bias;
-                }
-            }
-        }
-    }
-
-    // From the previous, before-values step:
-    // also make sure the logits are in shared memory.
-    __syncthreads();
-
-    // Values continued
-#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
-    using V_vec_acum = typename V_vec_acum_fp32_<V_vec>::Type;
-#else
-    using V_vec_acum = V_vec;
-#endif
-    // The partial outputs computed by each thread.
-    V_vec_acum out;
-    zero(out);
-
-    // Loop over the timesteps to compute the partial outputs.
-    // for( int ti = vo; ti < params.timestep; ti += V_PER_ITER ) {
-    if (Dh == Dh_MAX || vi < Dh) {
-        for (int ti = vo; ti < tlength; ti += V_PER_ITER) {
-
-            // Fetch offset based on cache_indir when beam sampling
-            const int beam_src = (params.cache_indir != nullptr) ?
-                                     params.cache_indir[(bbi * params.beam_width + beami) * params.seq_length + ti] :
-                                     0;
-            const int beam_offset = beam_src * params.num_heads * params.seq_length * Dh;
-            // Load the values from the cache.
-            V_vec v = *reinterpret_cast<const V_vec*>(&v_cache_batch[beam_offset + ti * Dh]);
-            if (DO_CROSS_ATTENTION && params.timestep == 0) {
-                v = add(v, *reinterpret_cast<V_vec*>(&bias_smem[vi]));
-                *reinterpret_cast<V_vec*>(&v_cache[ti * Dh]) = v;
-            }
-            // Load the logits from shared memory.
-#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
-            float logit = logits_smem[ti];
-            out = fma(logit, cast_to_float(v), out);
-#else
-            T logit = logits_smem[ti];
-
-            // Update the partial sums.
-            out = fma(logit, v, out);
-#endif
-        }
-    }
-
-    // One group of threads computes the product(s) for the current timestep.
-    // if( vo == params.timestep % V_PER_ITER ) {
-    if (vo == tlength % V_PER_ITER && (Dh == Dh_MAX || vi < Dh)) {
-
-        V_vec v;
-        if (DO_CROSS_ATTENTION) {
-            v = *reinterpret_cast<const V_vec*>(&v_cache[tlength * Dh]);
-        }
-        else {
-            // Trigger the loads from the V buffer.
-            v = *reinterpret_cast<const V_vec*>(&params.v[qkv_base_offset + vi]);
-            // Trigger the loads from the V bias buffer.
-            // V_vec v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi*Dh + vi]);
-        }
-
-        // Compute the V values with bias.
-        if (!DO_CROSS_ATTENTION || (DO_CROSS_ATTENTION && params.timestep == 0)) {
-            v = add(v, v_bias);
-
-            // Store the values with bias back to global memory in the cache for V.
-            //*reinterpret_cast<V_vec*>(&v_cache[params.timestep*Dh]) = v;
-            *reinterpret_cast<V_vec*>(&v_cache[tlength * Dh]) = v;
-        }
-
-        // Initialize the output value with the current timestep.
-#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
-        // out = fma(logits_smem[params.timestep], cast_to_float(v), out);
-        out = fma(logits_smem[tlength], cast_to_float(v), out);
-#else
-        // out = fma(logits_smem[params.timestep], v, out);
-        out = fma(logits_smem[tlength], v, out);
-#endif
-    }
-
-    // Make sure we can start writing to shared memory.
-    __syncthreads();
-
-    // Run the final reduction amongst the different groups computing different partial outputs.
-    if (Dh == Dh_MAX || vi < Dh)
-#pragma unroll
-        for (int active_groups = V_PER_ITER; active_groups >= 2; active_groups /= 2) {
-
-            // The midpoint in the number of active groups.
-            int midpoint = active_groups / 2;
-
-            // The upper part of active threads store to shared memory.
-            if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) {
-#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
-                convert_from_float(*reinterpret_cast<V_vec*>(&out_smem[(vo - midpoint) * Dh + vi]), out);
-#else
-                *reinterpret_cast<V_vec*>(&out_smem[(vo - midpoint) * Dh + vi]) = out;
-#endif
-            }
-            __syncthreads();
-
-            // The bottom warps update their values.
-            if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) {
-                out = add(*reinterpret_cast<const V_vec*>(&out_smem[vo * Dh + vi]), out);
-            }
-            __syncthreads();
-        }
-
-    // Output the final values.
-    if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) {
-#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
-        convert_from_float(*reinterpret_cast<V_vec*>(&params.out[bhi * Dh + vi]), out);
-#else
-        *reinterpret_cast<V_vec*>(&params.out[bhi * Dh + vi]) = out;
-#endif
-    }
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-}  // namespace mmha
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, stream)   \
-    size_t smem_sz = mmha::smem_size_in_bytes<T, DO_CROSS_ATTENTION>(params, THDS_PER_VALUE, THDS_PER_BLOCK);         \
-    dim3 grid(params.num_heads, params.batch_size);                                                                   \
-    mmha::masked_multihead_attention_kernel<T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK,              \
-                                            DO_CROSS_ATTENTION><<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-// !!! Specialize the launcher for Cross attention
-template<typename T, int Dh, int Dh_MAX, typename KERNEL_PARAMS_TYPE>
-void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
-{
-    constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16;
-    constexpr bool DO_CROSS_ATTENTION = std::is_same<KERNEL_PARAMS_TYPE, Cross_multihead_attention_params<T>>::value;
-    int tlength = (DO_CROSS_ATTENTION) ? params.seq_length : params.timestep;
-    // printf("tlength, CROSS_ATTENTION = %d, %d\n", tlength, DO_CROSS_ATTENTION);
-    if (tlength < 32) {
-        MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, stream);
-    }
-    else if (tlength < 2048) {
-        MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, stream);
-    }
-    else {
-        MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, stream);
-    }
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
 template<typename T, typename KERNEL_PARAMS_TYPE>
 void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
 {
@@ -1341,7 +76,7 @@
 void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params,
-                                const cudaStream_t& stream)
+                                const cudaStream_t& stream)
 {
     multihead_attention_<__nv_bfloat16, Masked_multihead_attention_params<__nv_bfloat16>>(params, stream);
 }
@@ -1362,4 +97,12 @@ void cross_multihead_attention(const Cross_multihead_attention_params<uint16_t>& params,
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-#undef MMHA_LAUNCH_KERNEL
+#ifdef ENABLE_BF16
+void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params,
+                               const cudaStream_t& stream)
+{
+    multihead_attention_<__nv_bfloat16, Cross_multihead_attention_params<__nv_bfloat16>>(params, stream);
+}
+#endif
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/src/fastertransformer/kernels/decoder_masked_multihead_attention.h b/src/fastertransformer/kernels/decoder_masked_multihead_attention.h
index 4f4f8c05d..898edb95c 100644
--- a/src/fastertransformer/kernels/decoder_masked_multihead_attention.h
+++ b/src/fastertransformer/kernels/decoder_masked_multihead_attention.h
@@ -50,77 +50,89 @@ template<typename T>
 struct Multihead_attention_params_base {
 
     // The output buffer. Dimensions B x D.
-    T* out;
+    T* out = nullptr;
 
     // The input Qs and the associated bias. Dimensions B x D and D, resp.
-    const T *q, *q_bias;
+    const T *q = nullptr, *q_bias = nullptr;
     // The input Ks and the associated bias. Dimensions B x D and D, resp.
-    const T *k, *k_bias;
+    const T *k = nullptr, *k_bias = nullptr;
     // The input Vs and the associated bias. Dimensions B x D and D, resp.
-    const T *v, *v_bias;
+    const T *v = nullptr, *v_bias = nullptr;
 
     // The cache for the Ks. The size must be at least B x L x D.
-    T* k_cache;
+    T* k_cache = nullptr;
     // The cache for the Vs. The size must be at least B x L x D.
-    T* v_cache;
+    T* v_cache = nullptr;
     // The indirections to use for cache when beam sampling.
-    const int* cache_indir;
+    const int* cache_indir = nullptr;
 
     // Stride to handle the case when KQV is a single buffer
-    int stride;
+    int stride = 0;
 
     // The batch size.
-    int batch_size;
+    int batch_size = 0;
     // The beam width
-    int beam_width;
+    int beam_width = 0;
     // The sequence length.
-    int seq_length;
+    int memory_max_len = 0;
     // The number of heads (H).
-    int num_heads;
+    int num_heads = 0;
     // The hidden dimension per head (Dh).
-    int hidden_size_per_head;
+    int hidden_size_per_head = 0;
     // The per-head latent space reserved for rotary embeddings.
-    int rotary_embedding_dim = 0;
+    int rotary_embedding_dim = 0;
+    bool neox_rotary_style = false;
     // The current timestep. TODO(bhsueh) Check that do we only this param in cross attention?
-    int timestep;
+    int timestep = 0;
     // The current timestep of each sentences (support different timestep for different sentences)
 
     // The 1.f / sqrt(Dh). Computed on the host.
-    float inv_sqrt_dh;
+    float inv_sqrt_dh = 0.0f;
 
     // Used when we have some input context like gpt
-    const int* input_lengths;
-    int max_input_len;
+    const int* total_padding_tokens = nullptr;
 
-    const float* relative_attention_bias_float = nullptr;
-    const half* relative_attention_bias_half = nullptr;
-    int relative_attention_bias_stride;
+    const bool* masked_tokens = nullptr;
+    const int* prefix_prompt_lengths = nullptr;
+    int max_prefix_prompt_length = 0;
+
+    const T* relative_attention_bias = nullptr;
+    int relative_attention_bias_stride = 0;
 };
 
 template<typename T, bool CROSS_ATTENTION>
 struct Multihead_attention_params: public Multihead_attention_params_base<T> {
+    // output cross attentions
+    float* cross_attention_out = nullptr;
+    int max_decoder_seq_len = 0;
+    bool is_return_cross_attentions = false;
 
     // allows to exit attention early
-    bool* finished;
+    bool* finished = nullptr;
 
     // required in case of cross attention
     // will need it here till if constexpr in c++17
-    int* memory_length_per_sample;
+    int* memory_length_per_sample = nullptr;
 
     // required in case of masked attention with different length
-    const int* length_per_sample;
+    const int* length_per_sample = nullptr;
 };
 
 template<typename T>
 struct Multihead_attention_params<T, true>: public Multihead_attention_params_base<T> {
+    // output cross attentions
+    float* cross_attention_out = nullptr;
+    int max_decoder_seq_len = 0;
+    bool is_return_cross_attentions = false;
+
     // allows to exit attention early
-    bool* finished;
+    bool* finished = nullptr;
 
     // required in case of cross attention
-    int* memory_length_per_sample;
+    int* memory_length_per_sample = nullptr;
 
     // required in case of masked attention with different length
-    const int* length_per_sample;
+    const int* length_per_sample = nullptr;
 };
 
 template<class T>
@@ -129,19 +141,27 @@ using Masked_multihead_attention_params = Multihead_attention_params<T, false>;
 
 template<class T>
 using Cross_multihead_attention_params = Multihead_attention_params<T, true>;
 
+template<typename T>
+struct outputCrossAttentionParam {
+    // max decoder output length
+    int max_decoder_seq_len = 0;
+    T* cross_attention_out = nullptr;
+    bool is_return_cross_attentions = false;
+};
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 void masked_multihead_attention(const Masked_multihead_attention_params<float>& params, const cudaStream_t& stream);
 void masked_multihead_attention(const Masked_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream);
 #ifdef ENABLE_BF16
 void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params,
-                                const cudaStream_t& stream);
+                                const cudaStream_t& stream);
 #endif
 void cross_multihead_attention(const Cross_multihead_attention_params<float>& params, const cudaStream_t& stream);
 void cross_multihead_attention(const Cross_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream);
 #ifdef ENABLE_BF16
 void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params,
-                               const cudaStream_t& stream);
+                               const cudaStream_t& stream);
 #endif
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu b/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu
new file mode 100644
index 000000000..73b8111e1
--- /dev/null
+++ b/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu
@@ -0,0 +1,79 @@
+/*
+ * Copyright (c)
2020-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "decoder_masked_multihead_attention_template.hpp" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_multihead_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! Specialize the launcher for Cross attention +template +void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16; + constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; + // printf("tlength, CROSS_ATTENTION = %d, %d\n", tlength, DO_CROSS_ATTENTION); + if (tlength < 32) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, stream); + } + else if (tlength < 2048) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, stream); + } + else { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, stream); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mmha_launch_kernel>( + const Masked_multihead_attention_params& params, const cudaStream_t& stream); +template void mmha_launch_kernel>( + const Masked_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mmha_launch_kernel<__nv_bfloat16, 128, 128, Masked_multihead_attention_params<__nv_bfloat16>>( + const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif + +template void mmha_launch_kernel>( + const Cross_multihead_attention_params& params, const cudaStream_t& stream); +template void mmha_launch_kernel>( + const Cross_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mmha_launch_kernel<__nv_bfloat16, 128, 128, Cross_multihead_attention_params<__nv_bfloat16>>( + const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif + +#undef MMHA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_160.cu b/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_160.cu new file mode 
100644 index 000000000..2ad2fd890 --- /dev/null +++ b/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_160.cu @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "decoder_masked_multihead_attention_template.hpp" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_multihead_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! Specialize the launcher for Cross attention +template +void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16; + constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + int tlength = (DO_CROSS_ATTENTION) ? 
params.memory_max_len : params.timestep; + // printf("tlength, CROSS_ATTENTION = %d, %d\n", tlength, DO_CROSS_ATTENTION); + if (tlength < 32) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, stream); + } + else if (tlength < 2048) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, stream); + } + else { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, stream); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mmha_launch_kernel>( + const Masked_multihead_attention_params& params, const cudaStream_t& stream); +template void mmha_launch_kernel>( + const Masked_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mmha_launch_kernel<__nv_bfloat16, 160, 256, Masked_multihead_attention_params<__nv_bfloat16>>( + const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif + +template void mmha_launch_kernel>( + const Cross_multihead_attention_params& params, const cudaStream_t& stream); +template void mmha_launch_kernel>( + const Cross_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mmha_launch_kernel<__nv_bfloat16, 160, 256, Cross_multihead_attention_params<__nv_bfloat16>>( + const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif + +#undef MMHA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_192.cu b/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_192.cu new file mode 100644 index 000000000..fade57f91 --- /dev/null +++ b/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_192.cu @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "decoder_masked_multihead_attention_template.hpp" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_multihead_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! 
Specialize the launcher for Cross attention +template +void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16; + constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; + // printf("tlength, CROSS_ATTENTION = %d, %d\n", tlength, DO_CROSS_ATTENTION); + if (tlength < 32) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, stream); + } + else if (tlength < 2048) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, stream); + } + else { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, stream); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mmha_launch_kernel>( + const Masked_multihead_attention_params& params, const cudaStream_t& stream); +template void mmha_launch_kernel>( + const Masked_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mmha_launch_kernel<__nv_bfloat16, 192, 256, Masked_multihead_attention_params<__nv_bfloat16>>( + const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif + +template void mmha_launch_kernel>( + const Cross_multihead_attention_params& params, const cudaStream_t& stream); +template void mmha_launch_kernel>( + const Cross_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mmha_launch_kernel<__nv_bfloat16, 192, 256, Cross_multihead_attention_params<__nv_bfloat16>>( + const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif + +#undef MMHA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_224.cu b/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_224.cu new file mode 100644 index 000000000..7b54f42aa --- /dev/null +++ b/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_224.cu @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "decoder_masked_multihead_attention_template.hpp" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_multihead_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! Specialize the launcher for Cross attention +template +void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16; + constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; + // printf("tlength, CROSS_ATTENTION = %d, %d\n", tlength, DO_CROSS_ATTENTION); + if (tlength < 32) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, stream); + } + else if (tlength < 2048) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, stream); + } + else { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, stream); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mmha_launch_kernel>( + const Masked_multihead_attention_params& params, const cudaStream_t& stream); +template void mmha_launch_kernel>( + const Masked_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mmha_launch_kernel<__nv_bfloat16, 224, 256, Masked_multihead_attention_params<__nv_bfloat16>>( + const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif + +template void mmha_launch_kernel>( + const Cross_multihead_attention_params& params, const cudaStream_t& stream); +template void mmha_launch_kernel>( + const Cross_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mmha_launch_kernel<__nv_bfloat16, 224, 256, Cross_multihead_attention_params<__nv_bfloat16>>( + const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif + +#undef MMHA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_256.cu b/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_256.cu new file mode 100644 index 000000000..96a19e13b --- /dev/null +++ b/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_256.cu @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "decoder_masked_multihead_attention_template.hpp" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_multihead_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! Specialize the launcher for Cross attention +template +void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16; + constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; + // printf("tlength, CROSS_ATTENTION = %d, %d\n", tlength, DO_CROSS_ATTENTION); + if (tlength < 32) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, stream); + } + else if (tlength < 2048) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, stream); + } + else { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, stream); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mmha_launch_kernel>( + const Masked_multihead_attention_params& params, const cudaStream_t& stream); +template void mmha_launch_kernel>( + const Masked_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mmha_launch_kernel<__nv_bfloat16, 256, 256, Masked_multihead_attention_params<__nv_bfloat16>>( + const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif + +template void mmha_launch_kernel>( + const Cross_multihead_attention_params& params, const cudaStream_t& stream); +template void mmha_launch_kernel>( + const Cross_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mmha_launch_kernel<__nv_bfloat16, 256, 256, Cross_multihead_attention_params<__nv_bfloat16>>( + const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif + +#undef MMHA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_32.cu b/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_32.cu new file mode 100644 index 000000000..bde032890 --- /dev/null +++ b/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_32.cu @@ -0,0 +1,79 @@ +/* + * Copyright 
(c) 2020-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "decoder_masked_multihead_attention_template.hpp" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_multihead_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! Specialize the launcher for Cross attention +template +void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16; + constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + int tlength = (DO_CROSS_ATTENTION) ? 
params.memory_max_len : params.timestep; + // printf("tlength, CROSS_ATTENTION = %d, %d\n", tlength, DO_CROSS_ATTENTION); + if (tlength < 32) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, stream); + } + else if (tlength < 2048) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, stream); + } + else { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, stream); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mmha_launch_kernel>( + const Masked_multihead_attention_params& params, const cudaStream_t& stream); +template void mmha_launch_kernel>( + const Masked_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mmha_launch_kernel<__nv_bfloat16, 32, 32, Masked_multihead_attention_params<__nv_bfloat16>>( + const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif + +template void mmha_launch_kernel>( + const Cross_multihead_attention_params& params, const cudaStream_t& stream); +template void mmha_launch_kernel>( + const Cross_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mmha_launch_kernel<__nv_bfloat16, 32, 32, Cross_multihead_attention_params<__nv_bfloat16>>( + const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif + +#undef MMHA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_64.cu b/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_64.cu new file mode 100644 index 000000000..5033517df --- /dev/null +++ b/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_64.cu @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "decoder_masked_multihead_attention_template.hpp" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_multihead_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! 
Specialize the launcher for Cross attention +template +void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16; + constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; + // printf("tlength, CROSS_ATTENTION = %d, %d\n", tlength, DO_CROSS_ATTENTION); + if (tlength < 32) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, stream); + } + else if (tlength < 2048) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, stream); + } + else { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, stream); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mmha_launch_kernel>( + const Masked_multihead_attention_params& params, const cudaStream_t& stream); +template void mmha_launch_kernel>( + const Masked_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mmha_launch_kernel<__nv_bfloat16, 64, 64, Masked_multihead_attention_params<__nv_bfloat16>>( + const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif + +template void mmha_launch_kernel>( + const Cross_multihead_attention_params& params, const cudaStream_t& stream); +template void mmha_launch_kernel>( + const Cross_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mmha_launch_kernel<__nv_bfloat16, 64, 64, Cross_multihead_attention_params<__nv_bfloat16>>( + const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif + +#undef MMHA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_80.cu b/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_80.cu new file mode 100644 index 000000000..c1931a19f --- /dev/null +++ b/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_80.cu @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "decoder_masked_multihead_attention_template.hpp" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_multihead_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! Specialize the launcher for Cross attention +template +void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16; + constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; + // printf("tlength, CROSS_ATTENTION = %d, %d\n", tlength, DO_CROSS_ATTENTION); + if (tlength < 32) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, stream); + } + else if (tlength < 2048) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, stream); + } + else { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, stream); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mmha_launch_kernel>( + const Masked_multihead_attention_params& params, const cudaStream_t& stream); +template void mmha_launch_kernel>( + const Masked_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mmha_launch_kernel<__nv_bfloat16, 80, 128, Masked_multihead_attention_params<__nv_bfloat16>>( + const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif + +template void mmha_launch_kernel>( + const Cross_multihead_attention_params& params, const cudaStream_t& stream); +template void mmha_launch_kernel>( + const Cross_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mmha_launch_kernel<__nv_bfloat16, 80, 128, Cross_multihead_attention_params<__nv_bfloat16>>( + const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif + +#undef MMHA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_96.cu b/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_96.cu new file mode 100644 index 000000000..d4a563ede --- /dev/null +++ b/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_96.cu @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "decoder_masked_multihead_attention_template.hpp" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_multihead_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! Specialize the launcher for Cross attention +template +void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16; + constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; + // printf("tlength, CROSS_ATTENTION = %d, %d\n", tlength, DO_CROSS_ATTENTION); + if (tlength < 32) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, stream); + } + else if (tlength < 2048) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, stream); + } + else { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, stream); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mmha_launch_kernel>( + const Masked_multihead_attention_params& params, const cudaStream_t& stream); +template void mmha_launch_kernel>( + const Masked_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mmha_launch_kernel<__nv_bfloat16, 96, 128, Masked_multihead_attention_params<__nv_bfloat16>>( + const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif + +template void mmha_launch_kernel>( + const Cross_multihead_attention_params& params, const cudaStream_t& stream); +template void mmha_launch_kernel>( + const Cross_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mmha_launch_kernel<__nv_bfloat16, 96, 128, Cross_multihead_attention_params<__nv_bfloat16>>( + const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif + +#undef MMHA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp new file mode 100644 index 000000000..46d533bbc --- /dev/null +++ b/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp @@ -0,0 +1,1403 
@@ +/* + * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +// #define MMHA_USE_HMMA_FOR_REDUCTION + +// Below are knobs to extend FP32 accumulation for higher FP16 accuracy + +// Does not seem to affect the accuracy that much +// #define MMHA_USE_FP32_ACUM_FOR_FMA + +// Seems to slightly improve the accuracy +#define MMHA_USE_FP32_ACUM_FOR_OUT + +#if 0 && defined(MMHA_USE_FP32_ACUM_FOR_OUT) + // Does not seem to improve the accuracy + //#define MMHA_USE_FP32_ACUM_FOR_LOGITS +#endif + +namespace mmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// We use the following terminology to describe the different dimensions. +// +// B: Batch size (number of sequences), +// L: Sequence length, +// D: Hidden dimension, +// H: Number of heads, +// Dh: Hidden dimension per head - Dh = D / H. +// +// The different kernels assign a threadblock for B x H pair. The grid has size (1, B, H). We use +// 64, 128 and 256 threads per block. +// +// Each threadblock loads Dh values from Q and its associated bias. The kernels run a loop to +// compute Q * K^T where K is loaded from a cache buffer -- except for the current timestep. The +// cache buffer helps with memory accesses and contains keys with bias. +// +// The layout of the cache buffer for the keys is [B, H, Dh/x, L, x] where x == 8 for FP16 and +// x == 4 for FP32 where the fastest moving dimension (contiguous data) is the rightmost one. The +// values for x are chosen to create chunks of 16 bytes. +// +// The different kernels use 1, 2 or 4 threads per key (THREADS_PER_KEY). The size of the LDGs +// depends on the number of threads per key. Each thread sums Dh / THREADS_PER_KEY elements. At +// the end of each iteration of the Q * K^T loop, we perform a reduction between lanes using an +// HMMA instruction (Tensor Core). Each Q * K^T valuey is stored in shared memory in FP32. +// +// After that loop, a parallel softmax is computed across the different Q * K^T values stored in +// shared memory. +// +// The kernel ends with a loop over the values in V. We use THREADS_PER_VALUE to control how many +// timesteps are computed by loop iteration. As with the keys, the values are read from a cache +// except for the current timestep. The layout of the cache buffer for the values is much simpler +// as it is [B, H, L, Dh]. 
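+//
+// (Editorial note, not part of the original file: as a worked example of the key-cache
+// layout described above, with x = 8 for FP16 (x = 4 for FP32), element (b, h, l, d) of K
+// lives at the linear offset
+//
+//     ((b * H + h) * (Dh / x) + d / x) * L * x + l * x + (d % x)
+//
+// so the x consecutive channels d, d+1, ..., d+x-1 of a single timestep form one
+// contiguous 16-byte chunk, which is what the 16-byte loads in the Q*K^T loop rely on.
+// The value cache is plain row-major [B, H, L, Dh], i.e. element (b, h, l, d) sits at
+// ((b * H + h) * L + l) * Dh + d.)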
+// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Qk_vec_ { +}; + +template<> +struct Qk_vec_ { + using Type = float; +}; +template<> +struct Qk_vec_ { + using Type = float2; +}; +template<> +struct Qk_vec_ { + using Type = float4; +}; +template<> +struct Qk_vec_ { + using Type = float4; +}; +template<> +struct Qk_vec_ { + using Type = uint32_t; +}; +template<> +struct Qk_vec_ { + using Type = uint32_t; +}; +template<> +struct Qk_vec_ { + using Type = uint2; +}; +template<> +struct Qk_vec_ { + using Type = uint4; +}; +#ifdef ENABLE_BF16 +template<> +struct Qk_vec_<__nv_bfloat16, 32> { + using Type = __nv_bfloat162; +}; +template<> +struct Qk_vec_<__nv_bfloat16, 64> { + using Type = __nv_bfloat162; +}; +template<> +struct Qk_vec_<__nv_bfloat16, 128> { + using Type = bf16_4_t; +}; +template<> +struct Qk_vec_<__nv_bfloat16, 256> { + using Type = bf16_8_t; +}; +#endif // ENABLE_BF16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct K_vec_ { +}; + +template<> +struct K_vec_ { + using Type = float; +}; +template<> +struct K_vec_ { + using Type = float2; +}; +template<> +struct K_vec_ { + using Type = float4; +}; +template<> +struct K_vec_ { + using Type = uint32_t; +}; +template<> +struct K_vec_ { + using Type = uint2; +}; +template<> +struct K_vec_ { + using Type = uint4; +}; +#ifdef ENABLE_BF16 +template<> +struct K_vec_<__nv_bfloat16, 4> { + using Type = __nv_bfloat162; +}; +template<> +struct K_vec_<__nv_bfloat16, 2> { + using Type = bf16_4_t; +}; +template<> +struct K_vec_<__nv_bfloat16, 1> { + using Type = bf16_8_t; +}; +#endif // ENABLE_BF16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct V_vec_ { +}; + +template<> +struct V_vec_ { + using Type = float; +}; +template<> +struct V_vec_ { + using Type = float2; +}; +template<> +struct V_vec_ { + using Type = float4; +}; +template<> +struct V_vec_ { + using Type = uint32_t; +}; +template<> +struct V_vec_ { + using Type = uint2; +}; +template<> +struct V_vec_ { + using Type = uint4; +}; +#ifdef ENABLE_BF16 +template<> +struct V_vec_<__nv_bfloat16, 2> { + using Type = __nv_bfloat162; +}; +template<> +struct V_vec_<__nv_bfloat16, 4> { + using Type = bf16_4_t; +}; +template<> +struct V_vec_<__nv_bfloat16, 8> { + using Type = bf16_8_t; +}; +#endif // ENABLE_BF16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA +template +struct Qk_vec_acum_fp32_ { +}; + +template<> +struct Qk_vec_acum_fp32_ { + using Type = float; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = float4; +}; +// template<> struct Qk_vec_acum_fp32_ { using Type = float; }; +template<> +struct Qk_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float8_; +}; +template<> +struct Qk_vec_acum_fp32_<__nv_bfloat16> { + using Type = float; +}; +template<> +struct Qk_vec_acum_fp32_<__nv_bfloat162> { + using Type = float2; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float8_; +}; + +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float8_; +}; +template<> +struct Qk_vec_acum_fp32_<__nv_bfloat16> { + using 
Type = float; +}; +template<> +struct Qk_vec_acum_fp32_<__nv_bfloat162> { + using Type = float2; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float8_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct K_vec_acum_fp32_ { +}; + +template<> +struct K_vec_acum_fp32_ { + using Type = float; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = float4; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = Float8_; +}; +template<> +struct K_vec_acum_fp32_<__nv_bfloat16> { + using Type = float; +}; +template<> +struct K_vec_acum_fp32_<__nv_bfloat162> { + using Type = float2; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = Float8_; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT +template +struct V_vec_acum_fp32_ { +}; + +template<> +struct V_vec_acum_fp32_ { + using Type = float; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = float4; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = Float8_; +}; +#ifdef ENABLE_BF16 +template<> +struct V_vec_acum_fp32_<__nv_bfloat162> { + using Type = float2; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = Float8_; +}; +#endif // ENABLE_BF16 +#endif +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N]) +{ +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + using K_vec_acum = typename K_vec_acum_fp32_::Type; +#else + using K_vec_acum = K_vec; +#endif + // Compute the parallel products for Q*K^T (treat vector lanes separately). + K_vec_acum qk_vec = mul(q[0], k[0]); +#pragma unroll + for (int ii = 1; ii < N; ++ii) { + qk_vec = fma(q[ii], k[ii], qk_vec); + } + + // Finalize the reduction across lanes. 
+ float qk = sum(qk_vec); +#pragma unroll + for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) { + qk += __shfl_xor_sync(uint32_t(-1), qk, mask); + } + return qk; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Qk_dot { + template + static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N]) + { + return qk_dot_(q, k); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 hmma_fp32(const uint2& a, uint32_t b) +{ + float4 c; + float zero = 0.f; + asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5}, \n" + " {%6}, \n" + " {%7, %7, %7, %7}; \n" + + : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w) + : "r"(a.x) "r"(a.y), "r"(b), "f"(zero)); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], const uint32_t (&k)[N]) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + using K_vec_acum = typename K_vec_acum_fp32_::Type; +#else + using K_vec_acum = uint32_t; +#endif + K_vec_acum qk_vec = mul(q[0], k[0]); +#pragma unroll + for (int ii = 1; ii < N; ++ii) { + qk_vec = fma(q[ii], k[ii], qk_vec); + } +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + uint32_t qk_vec_ = float2_to_half2(qk_vec); + return hmma_fp32(make_uint2(qk_vec_, 0u), 0x3c003c00u).x; +#else + return hmma_fp32(make_uint2(qk_vec, 0u), 0x3c003c00u).x; +#endif +#else + return 0.f; +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Qk_dot { + template + static inline __device__ float dot(const uint32_t (&q)[N], const uint32_t (&k)[N]) + { +#if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA_FOR_REDUCTION) + return qk_hmma_dot_(q, k); +#else + return qk_dot_<4>(q, k); +#endif // defined MMHA_USE_HMMA_FOR_REDUCTION + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ float block_sum(float* red_smem, float sum) +{ + + // Decompose the thread index into warp / lane. + int warp = threadIdx.x / WARP_SIZE; + int lane = threadIdx.x % WARP_SIZE; + +// Compute the sum per warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Warp leaders store the data to shared memory. + if (lane == 0) { + red_smem[warp] = sum; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The warps compute the final sums. + if (lane < WARPS_PER_BLOCK) { + sum = red_smem[lane]; + } + +// Parallel reduction inside the warp. +#pragma unroll + for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Broadcast to other threads. 
+ return __shfl_sync(uint32_t(-1), sum, 0); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(float& dst, float src) +{ + dst = src; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint16_t& dst, float src) +{ + dst = float_to_half(src); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint32_t& dst, float2 src) +{ + dst = float2_to_half2(src); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +#ifdef ENABLE_BF16 +inline __device__ void convert_from_float(__nv_bfloat16& dst, float src) +{ + dst = __float2bfloat16(src); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(__nv_bfloat162& dst, float2 src) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + dst = __float22bfloat162_rn(src); +#else + dst = __floats2bfloat162_rn(src.x, src.y); +#endif +} +#endif // ENABLE_BF16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint2& dst, Float4_ src) +{ + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint4& dst, Float8_ src) +{ + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); + dst.z = float2_to_half2(src.z); + dst.w = float2_to_half2(src.w); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_BF16 +inline __device__ void convert_from_float(bf16_4_t& dst, Float4_ src) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + dst.x = __float22bfloat162_rn(src.x); + dst.y = __float22bfloat162_rn(src.y); +#else + dst.x = __floats2bfloat162_rn(src.x.x, src.x.y); + dst.y = __floats2bfloat162_rn(src.y.x, src.y.y); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(bf16_8_t& dst, Float8_ src) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + dst.x = __float22bfloat162_rn(src.x); + dst.y = __float22bfloat162_rn(src.y); + dst.z = __float22bfloat162_rn(src.z); + dst.w = __float22bfloat162_rn(src.w); +#else + dst.x = __floats2bfloat162_rn(src.x.x, src.x.y); + dst.y = __floats2bfloat162_rn(src.y.x, src.y.y); + dst.z = __floats2bfloat162_rn(src.z.x, src.z.y); + dst.w = __floats2bfloat162_rn(src.w.x, src.w.y); +#endif +} +#endif // ENABLE_BF16 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(float2& dst, float2 src) +{ + dst = src; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(float4& dst, float4 src) +{ + dst = src; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float convert_to_float(float4 u) +{ + return u.x; +} + 
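The convert_from_float / convert_to_float overloads above are thin wrappers around packed fp16/bf16 conversion. A device-side sketch of the same round trip written with the standard cuda_fp16.h intrinsics (pack_unpack_demo is an illustrative helper, not part of the patch):

#include <cuda_fp16.h>

__device__ void pack_unpack_demo(float2 src, float2* dst)
{
    __half2 packed = __float22half2_rn(src);   // two fp32 values -> one 32-bit fp16x2 register
    *dst           = __half22float2(packed);   // and back, the same effect as half2_to_float2
}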
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float convert_to_float(uint4 u) +{ + float2 tmp = half2_to_float2(u.x); + return tmp.x; +} + +#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float cast_to_float(float u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 cast_to_float(float2 u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 cast_to_float(float4 u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ cast_to_float(Float4_ u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ cast_to_float(Float8_ u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 cast_to_float(uint32_t u) +{ + return half2_to_float2(u); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ cast_to_float(uint2 u) +{ + Float4_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + return tmp; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ cast_to_float(uint4 u) +{ + Float8_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + tmp.z = half2_to_float2(u.z); + tmp.w = half2_to_float2(u.w); + return tmp; +} + +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ __host__ T div_up(T m, T n) +{ + return (m + n - 1) / n; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline size_t smem_size_in_bytes(const Multihead_attention_params& params, + int threads_per_value, + int threads_per_block) +{ + // The amount of shared memory needed to store the Q*K^T values in float. + const int max_timesteps = min(params.timestep, params.memory_max_len); + size_t qk_sz = (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 16 : div_up(max_timesteps + 1, 4) * 16; + + // The extra memory needed if we are not using floats for the final logits. + size_t logits_sz = 0; +#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS + if (sizeof(T) != 4) { + // TDOD + logits_sz = div_up(max_timesteps, 4) * 4 * sizeof(T); + } +#endif + + // The total size needed during softmax. + size_t softmax_sz = qk_sz + logits_sz; + + // The number of partial rows to reduce in the final reduction. + int rows_per_red = threads_per_block / threads_per_value; + // The amount of storage needed to finalize the outputs. + size_t red_sz = rows_per_red * params.hidden_size_per_head * sizeof(T) / 2; + + size_t transpose_rotary_size = 0; + if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) { + transpose_rotary_size = 2 * params.rotary_embedding_dim * sizeof(T); + } + + // The max. 
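For concreteness, a worked evaluation of this sizing formula for one assumed configuration (fp16, Dh = 128, timestep = 1000, memory_max_len = 2048, 256 threads per block, 16 threads per value, no NeoX rotary buffer, MMHA_USE_FP32_ACUM_FOR_LOGITS off); all numbers are illustrative, not a required setup:

#include <algorithm>
#include <cstdio>

static size_t div_up(size_t m, size_t n) { return (m + n - 1) / n; }

int main()
{
    const size_t sizeof_T          = 2;     // fp16
    const size_t timestep          = 1000;
    const size_t memory_max_len    = 2048;
    const size_t Dh                = 128;   // hidden_size_per_head
    const size_t threads_per_value = 16;    // Dh * sizeof_T / 16
    const size_t threads_per_block = 256;

    const size_t max_timesteps = std::min(timestep, memory_max_len);
    const size_t qk_sz         = div_up(max_timesteps + 1, 4) * 16;        // 4016 B of float Q*K^T
    const size_t logits_sz     = div_up(max_timesteps, 4) * 4 * sizeof_T;  // 2000 B of fp16 logits
    const size_t softmax_sz    = qk_sz + logits_sz;                        // 6016 B

    const size_t rows_per_red = threads_per_block / threads_per_value;     // 16
    const size_t red_sz       = rows_per_red * Dh * sizeof_T / 2;          // 2048 B
    const size_t rotary_sz    = 0;                                         // rotary_embedding_dim == 0 here

    printf("dynamic smem = %zu bytes\n",
           std::max(std::max(softmax_sz, red_sz), rotary_sz));             // 6016
    return 0;
}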
+ return max(max(softmax_sz, red_sz), transpose_rotary_size); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ constexpr uint32_t shfl_mask(int threads) +{ + return threads == 32 ? uint32_t(-1) : (1u << threads) - 1u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< + // The type of the inputs. Supported types: float and half. + typename T, + // The hidden dimension per head. + int Dh, + int Dh_MAX, + // The number of threads per key. + int THREADS_PER_KEY, + // The number of threads per value. + int THREADS_PER_VALUE, + // The number of threads in a threadblock. + int THREADS_PER_BLOCK, + bool DO_CROSS_ATTENTION> +__global__ void masked_multihead_attention_kernel(Multihead_attention_params params) +{ + + // Make sure the hidden dimension per head is a multiple of the number of threads per key. + static_assert(Dh_MAX % THREADS_PER_KEY == 0, ""); + // Make sure the hidden dimension per head is a multiple of the number of threads per value. + static_assert(Dh_MAX % THREADS_PER_VALUE == 0, ""); + + // The size of a warp. + constexpr int WARP_SIZE = 32; + // The number of warps in a threadblock. + constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE; + + // Use smem_size_in_bytes (above) to determine the amount of shared memory. + extern __shared__ char smem_[]; + + // The shared memory for the Q*K^T values and partial logits in softmax. + float* qk_smem = reinterpret_cast(smem_); + + // The shared memory for the logits. For FP32, that's the same buffer as qk_smem. + char* logits_smem_ = smem_; +#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS + if (sizeof(T) != 4) { + // TODO - change to tlength + const int max_timesteps = min(params.timestep, params.memory_max_len); + logits_smem_ += + (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 16 : div_up(max_timesteps + 1, 4) * 16; + } + T* logits_smem = reinterpret_cast(logits_smem_); +#else + float* logits_smem = reinterpret_cast(logits_smem_); +#endif + + // The shared memory to do the final reduction for the output values. Reuse qk_smem. + T* out_smem = reinterpret_cast(smem_); + + // The shared memory buffers for the block-wide reductions. One for max, one for sum. + __shared__ float red_smem[WARPS_PER_BLOCK * 2]; + + // A vector of Q or K elements for the current timestep. + using Qk_vec = typename Qk_vec_::Type; + + // Use alignment for safely casting the shared buffers as Qk_vec. + // Shared memory to store Q inputs. + __shared__ __align__(sizeof(Qk_vec)) T q_smem[Dh_MAX]; + + // This is one of the reasons we should have a separate kernel for cross attention + __shared__ __align__(sizeof(Qk_vec)) T bias_smem[DO_CROSS_ATTENTION ? Dh_MAX : 1]; + + // A vector of Q or K elements for the current timestep. + using Qk_vec = typename Qk_vec_::Type; + // The number of elements per vector. + constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T); + // Make sure the hidden size per head is a multiple of the vector size. + static_assert(Dh_MAX % QK_VEC_SIZE == 0, ""); + // We will use block wide reduction if needed + // static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, ""); + // The number of vectors per warp. + constexpr int QK_VECS_PER_WARP = Dh_MAX / QK_VEC_SIZE; + + // The layout of the cache is [B, H, Dh/x, L, x] with x == 4/8 for FP32/FP16. 
Since each thread + // owns x elements, we have to decompose the linear index into chunks of x values and the posi- + // tion of the thread in that chunk. + + // The number of elements in a chunk of 16B (that's the x in the above formula). + constexpr int QK_ELTS_IN_16B = 16 / sizeof(T); + // The number of K vectors in 16B. + constexpr int QK_VECS_IN_16B = 16 / sizeof(Qk_vec); + + // The batch/beam idx + const int bi = blockIdx.y; + if (params.finished != nullptr && params.finished[bi] == true) { + return; + } + // The beam idx + const int beami = bi % params.beam_width; + // The "beam-aware" batch idx + const int bbi = bi / params.beam_width; + // The head. + const int hi = blockIdx.x; + // Combine the batch and the head indices. + const int bhi = bi * params.num_heads + hi; + // Combine the "beam-aware" batch idx and the head indices. + const int bbhi = bbi * params.beam_width * params.num_heads + hi; + // The thread in the block. + const int tidx = threadIdx.x; + + const bool handle_k = !DO_CROSS_ATTENTION || (DO_CROSS_ATTENTION && params.timestep == 0); + + // While doing the product Q*K^T for the different keys we track the max. + float qk_max = -FLT_MAX; + + float qk = 0.0F; + + int qkv_base_offset = (params.stride == 0) ? bhi * Dh : bi * params.stride + hi * Dh; + + const size_t bi_seq_len_offset = bi * params.memory_max_len; + + // int tlength = (DO_CROSS_ATTENTION)? params.memory_length_per_sample[bi] - 1 : params.timestep; + int tlength = (DO_CROSS_ATTENTION) ? params.memory_length_per_sample[bi] - 1 : + (params.length_per_sample == nullptr) ? + params.timestep : + params.length_per_sample[bi] + params.max_prefix_prompt_length; + const int first_step = max(0, tlength + 1 - params.memory_max_len); + const int tlength_circ = tlength % params.memory_max_len; + + // First QK_VECS_PER_WARP load Q and K + the bias values for the current timestep. + const bool is_masked = tidx >= QK_VECS_PER_WARP; + + // The offset in the Q and K buffer also accounts for the batch. + int qk_offset = qkv_base_offset + tidx * QK_VEC_SIZE; + // The offset in the bias buffer. + int qk_bias_offset = hi * Dh + tidx * QK_VEC_SIZE; + + // Trigger the loads from the Q and K buffers. + Qk_vec q; + zero(q); + q = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ? + *reinterpret_cast(¶ms.q[qk_offset]) : + q; + + Qk_vec k; + zero(k); + if (DO_CROSS_ATTENTION) { + // The 16B chunk written by the thread. + int co = tidx / QK_VECS_IN_16B; + // The position of the thread in that 16B chunk. + int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE; + + // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements. + int offset = bhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B + + // params.timestep*QK_ELTS_IN_16B + + tlength * QK_ELTS_IN_16B + ci; + k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ? + *reinterpret_cast(¶ms.k_cache[offset]) : + k; + } + else { + k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ? + *reinterpret_cast(¶ms.k[qk_offset]) : + k; + } + + // Trigger the loads from the Q and K bias buffers. + Qk_vec q_bias; + zero(q_bias); + q_bias = (!is_masked && Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.q_bias != nullptr ? + *reinterpret_cast(¶ms.q_bias[qk_bias_offset]) : + q_bias; + + Qk_vec k_bias; + zero(k_bias); + if (handle_k) { + k_bias = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.k_bias != nullptr ? + *reinterpret_cast(¶ms.k_bias[qk_bias_offset]) : + k_bias; + } + + // Computes the Q/K values with bias. 
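A worked example of the [B, H, Dh/x, L, x] cache offset used for the K reads and writes above, with x = 8 (fp16), Dh = 128 and memory_max_len = 2048 chosen purely for illustration:

#include <cstdio>

int main()
{
    const int Dh = 128, L = 2048, x = 8;      // x = 16 bytes / sizeof(fp16)
    const int bhi = 3, t = 17, d = 42;        // arbitrary (batch*head, timestep, depth) element

    const int co = d / x;                     // which 16B chunk along the depth
    const int ci = d % x;                     // position inside that chunk
    const long offset = (long)bhi * L * Dh    // start of this batch*head slab
                      + (long)co * L * x      // skip co full [L, x] planes
                      + (long)t * x           // row for this timestep
                      + ci;                   // element inside the 16B chunk

    printf("element (bhi=%d, t=%d, d=%d) -> offset %ld\n", bhi, t, d, offset);
    return 0;
}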
+ q = add(q, q_bias); + if (handle_k) { + k = add(k, k_bias); + } + // Padded len + const int padd_len = (params.total_padding_tokens == nullptr) ? 0 : params.total_padding_tokens[bi]; + if (params.rotary_embedding_dim > 0 && !params.neox_rotary_style) { + if (handle_k) { + apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, params.timestep - padd_len); + } + else { + apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, params.timestep - padd_len); + } + } + else if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) { + const bool do_rotary = !is_masked && QK_VEC_SIZE * tidx < params.rotary_embedding_dim; + + T* q_smem = reinterpret_cast(smem_); + T* k_smem = q_smem + params.rotary_embedding_dim; + + const int half_rotary_dim = params.rotary_embedding_dim / 2; + const int half_idx = (tidx * QK_VEC_SIZE) / half_rotary_dim; + const int intra_half_idx = (tidx * QK_VEC_SIZE) % half_rotary_dim; + const int smem_pitch = half_rotary_dim; // TODO: adjust for bank conflicts + + assert(half_rotary_dim % QK_VEC_SIZE == 0); + + if (do_rotary) { + *reinterpret_cast(q_smem + half_idx * smem_pitch + intra_half_idx) = q; + + if (handle_k) { + *reinterpret_cast(k_smem + half_idx * smem_pitch + intra_half_idx) = k; + } + } + + __syncthreads(); + + const int transpose_idx = half_idx * (half_rotary_dim / 2) + intra_half_idx / 2; + constexpr int tidx_factor = (QK_VEC_SIZE > 1) ? QK_VEC_SIZE / 2 : 1; + if (do_rotary) { + mmha::vec_from_smem_transpose(q, q_smem, transpose_idx, smem_pitch); + + if (handle_k) { + mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch); + + mmha::apply_rotary_embedding( + q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, params.timestep - padd_len); + + mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch); + } + else { + mmha::apply_rotary_embedding( + q, transpose_idx / tidx_factor, params.rotary_embedding_dim, params.timestep); + } + mmha::write_smem_transpose(q, q_smem, transpose_idx, smem_pitch); + } + + __syncthreads(); + + if (do_rotary) { + q = *reinterpret_cast(q_smem + half_idx * smem_pitch + intra_half_idx); + if (handle_k) { + k = *reinterpret_cast(k_smem + half_idx * smem_pitch + intra_half_idx); + } + } + + __syncthreads(); + } + + if (!is_masked) { + // Store the Q values to shared memory. + *reinterpret_cast(&q_smem[tidx * QK_VEC_SIZE]) = q; + + // Store Dh values of k_bias into smem, since will need to add later + // if params.timestep == 0 + if (DO_CROSS_ATTENTION && params.timestep == 0) { + *reinterpret_cast(&bias_smem[tidx * QK_VEC_SIZE]) = k_bias; + } + + // Write the K values to the global memory cache. + // + // NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory + // system. We designed it this way as it allows much better memory loads (and there are many + // more loads) + the stores are really "write and forget" since we won't need the ack before + // the end of the kernel. There's plenty of time for the transactions to complete. + + // The 16B chunk written by the thread. + int co = tidx / QK_VECS_IN_16B; + // The position of the thread in that 16B chunk. + int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE; + + // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements. + int offset = bhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B + + // params.timestep*QK_ELTS_IN_16B + + tlength_circ * QK_ELTS_IN_16B + ci; + + if (handle_k) { + // Trigger the stores to global memory. 
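The rotary path above rotates each (even, odd) channel pair of Q and K by a timestep-dependent angle before the Q*K^T product. A host-side sketch of one such rotation, assuming the usual angle = t / 10000^(zid / rot_embed_dim) convention behind rotary_embedding_coefficient (that helper's body is defined elsewhere in this header, so the exact constant is an assumption here):

#include <cmath>
#include <cstdio>

int main()
{
    const int   rot_embed_dim = 64;     // illustrative values only
    const int   t_step        = 10;     // decoding timestep
    const int   zid           = 8;      // even channel index of the pair
    const float angle = t_step / powf(10000.0f, float(zid) / rot_embed_dim);
    const float c = cosf(angle), s = sinf(angle);

    const float x = 0.25f, y = -0.75f;  // one (even, odd) channel pair of q
    const float rx = c * x - s * y;     // same complex-rotation form as
    const float ry = c * y + s * x;     // rotary_embedding_transform()
    printf("(%.3f, %.3f) -> (%.3f, %.3f)\n", x, y, rx, ry);
    return 0;
}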
+ if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) { + *reinterpret_cast(¶ms.k_cache[offset]) = k; + } + } + + // Compute \sum_i Q[i] * K^T[i] for the current timestep. +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + using Qk_vec_acum = typename Qk_vec_acum_fp32_::Type; +#else + using Qk_vec_acum = Qk_vec; +#endif + qk = dot(q, k); + if (QK_VECS_PER_WARP <= WARP_SIZE) { +#pragma unroll + for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) { + qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask); + } + } + } + + if (QK_VECS_PER_WARP > WARP_SIZE) { + constexpr int WARPS_PER_RED = (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE; + qk = block_sum(&red_smem[WARPS_PER_RED], qk); + } + + // Store that value in shared memory. Keep the Q*K^T value in register for softmax. + if (tidx == 0) { + // Normalize qk. + qk *= params.inv_sqrt_dh; + if (params.relative_attention_bias != nullptr) { + qk = add(qk, + params.relative_attention_bias[hi * params.relative_attention_bias_stride + * params.relative_attention_bias_stride + + (tlength - padd_len) * params.relative_attention_bias_stride + + (tlength - padd_len)]); + } + qk_max = qk; + qk_smem[tlength - first_step] = qk; + // qk_smem[params.timestep] = qk; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The type of queries and keys for the math in the Q*K^T product. + using K_vec = typename K_vec_::Type; + // The number of elements per vector. + constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(T); + // Make sure the hidden size per head is a multiple of the vector size. + static_assert(Dh_MAX % K_VEC_SIZE == 0, ""); + // The number of elements per thread. + constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY; + // The number of vectors per thread. + constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE; + + // The position the first key loaded by each thread from the cache buffer (for this B * H). + int ko = tidx / THREADS_PER_KEY; + // The position of the thread in the chunk of keys. + int ki = tidx % THREADS_PER_KEY * K_VEC_SIZE; + + static_assert(Dh_MAX == THREADS_PER_KEY * K_VEC_SIZE * K_VECS_PER_THREAD); + + // Load the Q values from shared memory. The values are reused during the loop on K. + K_vec q_vec[K_VECS_PER_THREAD]; +#pragma unroll + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + q_vec[ii] = *reinterpret_cast(&q_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]); + } + + K_vec k_bias_vec[DO_CROSS_ATTENTION ? K_VECS_PER_THREAD : 1]; + if (DO_CROSS_ATTENTION && params.timestep == 0) { +#pragma unroll + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + k_bias_vec[ii] = *reinterpret_cast(&bias_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]); + } + } + + // The number of timesteps loaded per iteration. + constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY; + // The number of keys per warp. + constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; + + // The base pointer for the key in the cache buffer. + T* k_cache = ¶ms.k_cache[bhi * params.memory_max_len * Dh + ki]; + // Base pointer for the beam's batch, before offsetting with indirection buffer + T* k_cache_batch = ¶ms.k_cache[bbhi * params.memory_max_len * Dh + ki]; + + // Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync). + // int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP; + int ti_end = div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step; + + // prefix prompt length if has + const int prefix_prompt_length = (params.prefix_prompt_lengths == nullptr) ? 
0 : params.prefix_prompt_lengths[bi]; + + // Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values. + const bool has_beams = params.cache_indir != nullptr; + const int* beam_indices = has_beams ? ¶ms.cache_indir[bi_seq_len_offset] : nullptr; + + for (int ti = first_step + ko; ti < ti_end; ti += K_PER_ITER) { + const int ti_circ = ti % params.memory_max_len; + + // The keys loaded from the key cache. + K_vec k[K_VECS_PER_THREAD]; + K_vec k_vec_zero; + zero(k_vec_zero); +#pragma unroll + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + int jj = ii * params.memory_max_len + ti_circ; + // if( ti < params.timestep ) { + const bool within_bounds = (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len); + if (ti < tlength) { + if (!within_bounds) { + k[ii] = k_vec_zero; + } + else { + if (has_beams) { + const int beam_offset = beam_indices[ti_circ] * params.num_heads * params.memory_max_len * Dh; + k[ii] = *reinterpret_cast(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]); + } + else { + k[ii] = *reinterpret_cast(&k_cache_batch[jj * QK_ELTS_IN_16B]); + } + } + // add bias and update k_cache + if (DO_CROSS_ATTENTION && params.timestep == 0) { + k[ii] = add(k[ii], k_bias_vec[ii]); + if (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len) { + *reinterpret_cast(&k_cache[jj * QK_ELTS_IN_16B]) = k[ii]; + } + } + } + } + + // Perform the dot product and normalize qk. + // + // WARNING: ALL THE THREADS OF A WARP MUST ENTER!!! + float qk = Qk_dot::dot(q_vec, k) * params.inv_sqrt_dh; + bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti]; + + // Store the product to shared memory. There's one qk value per timestep. Update the max. + // if( ti < params.timestep && tidx % THREADS_PER_KEY == 0 ) { + if (ti < tlength && tidx % THREADS_PER_KEY == 0) { + if (params.relative_attention_bias != nullptr) { + qk = add(qk, + params.relative_attention_bias[hi * params.relative_attention_bias_stride + * params.relative_attention_bias_stride + + tlength * params.relative_attention_bias_stride + ti]); + } + qk_max = is_mask ? qk_max : fmaxf(qk_max, qk); + qk_smem[ti - first_step] = qk; + } + } + +// Perform the final reduction to compute the max inside each warp. +// +// NOTE: In a group of THREADS_PER_KEY threads, the leader already has the max value for the +// group so it's not needed to run the reduction inside the group (again). +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + + // Decompose the thread index into warp and lane. + const int warp = tidx / WARP_SIZE; + const int lane = tidx % WARP_SIZE; + + // The warp leader writes the max to shared memory. + if (lane == 0) { + red_smem[warp] = qk_max; + } + + // Make sure the products are in shared memory. + __syncthreads(); + + // The warps finalize the reduction. + qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + + // Broadcast to all the threads in the warp. + qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); + + // Compute the logits and start the sum. 
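The next few steps are a standard numerically stable softmax over the scores staged in qk_smem: subtract qk_max, exponentiate, block-sum, then scale by 1/(sum + 1e-6). A host-side sketch of the same arithmetic on a toy score vector:

#include <cmath>
#include <cstdio>

int main()
{
    float qk[4] = {3.0f, 1.0f, 0.5f, 2.5f};             // example Q*K^T scores
    float qk_max = qk[0];
    for (float v : qk) qk_max = fmaxf(qk_max, v);

    float sum = 0.0f;
    for (float& v : qk) { v = expf(v - qk_max); sum += v; }

    const float inv_sum = 1.0f / (sum + 1.e-6f);        // same epsilon as the kernel
    for (float v : qk) printf("%.4f ", v * inv_sum);    // normalized attention weights
    printf("\n");
    return 0;
}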
+ float sum = 0.f; + // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) { + for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) { + bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti]; + float logit = is_mask ? 0.f : __expf(qk_smem[ti - first_step] - qk_max); + sum += logit; + qk_smem[ti - first_step] = logit; + } + + // Compute the sum. + sum = block_sum(&red_smem[WARPS_PER_BLOCK], sum); + + // Normalize the logits. + float inv_sum = __fdividef(1.f, sum + 1.e-6f); + // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) { + const size_t cross_attention_out_offset = + params.is_return_cross_attentions ? + bhi * params.max_decoder_seq_len * params.memory_max_len + params.timestep * params.memory_max_len : + 0; + for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) { + float logit = qk_smem[ti - first_step] * inv_sum; + if (params.is_return_cross_attentions) { + params.cross_attention_out[cross_attention_out_offset + ti] = logit; + } + convert_from_float(logits_smem[ti - first_step], logit); + } + + // Put Values part below so we leverage __syncthreads + // from the previous step + + // The number of elements per vector. + constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE; + // A vector of V elements for the current timestep. + using V_vec = typename V_vec_::Type; + + // The value computed by this thread. + int vo = tidx / THREADS_PER_VALUE; + // The hidden dimensions computed by this particular thread. + int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; + + // The base pointer for the value in the cache buffer. + T* v_cache = ¶ms.v_cache[bhi * params.memory_max_len * Dh + vi]; + // Base pointer for the beam's batch, before offsetting with indirection buffer + T* v_cache_batch = ¶ms.v_cache[bbhi * params.memory_max_len * Dh + vi]; + + // The number of values processed per iteration of the loop. + constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; + + // One group of threads computes the product(s) for the current timestep. + V_vec v_bias; + zero(v_bias); + // if( vo == params.timestep % V_PER_ITER ) { + if (Dh == Dh_MAX || vi < Dh) { + if (handle_k) { + if (vo == tlength % V_PER_ITER) { + // Trigger the loads from the V bias buffer. + if (params.v_bias != nullptr) { + v_bias = *reinterpret_cast(¶ms.v_bias[hi * Dh + vi]); + } + if (DO_CROSS_ATTENTION) { + *reinterpret_cast(&bias_smem[vi]) = v_bias; + } + } + } + } + + // From previous, before values, step + // Also make sure the logits are in shared memory. + __syncthreads(); + + // Values continued +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT + using V_vec_acum = typename V_vec_acum_fp32_::Type; +#else + using V_vec_acum = V_vec; +#endif + // The partial outputs computed by each thread. + V_vec_acum out; + zero(out); + + // Loop over the timesteps to compute the partial outputs. + // for( int ti = vo; ti < params.timestep; ti += V_PER_ITER ) { + if (Dh == Dh_MAX || vi < Dh) { + for (int ti = first_step + vo; ti < tlength; ti += V_PER_ITER) { + const int ti_circ = ti % params.memory_max_len; + + // Fetch offset based on cache_indir when beam sampling + const int beam_src = (params.cache_indir != nullptr) ? params.cache_indir[bi_seq_len_offset + ti_circ] : 0; + const int beam_offset = beam_src * params.num_heads * params.memory_max_len * Dh; + // Load the values from the cache. 
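The value loop that follows accumulates out[d] = sum_t logit[t] * V[t][d]; in other words, the attention output is the logit-weighted combination of the cached value vectors. A toy host-side version of that accumulation (sizes and values chosen only for illustration):

#include <cstdio>

int main()
{
    constexpr int STEPS = 3, DH = 4;
    const float logit[STEPS] = {0.2f, 0.5f, 0.3f};      // normalized attention weights
    const float v[STEPS][DH] = {{1, 0, 0, 0},
                                {0, 1, 0, 0},
                                {0, 0, 1, 0}};          // toy cached value vectors
    float out[DH] = {0, 0, 0, 0};

    for (int t = 0; t < STEPS; ++t)
        for (int d = 0; d < DH; ++d)
            out[d] += logit[t] * v[t][d];               // the fma(logit, v, out) of the loop

    for (int d = 0; d < DH; ++d) printf("%.2f ", out[d]);  // 0.20 0.50 0.30 0.00
    printf("\n");
    return 0;
}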
+ V_vec v = *reinterpret_cast(&v_cache_batch[beam_offset + ti_circ * Dh]); + if (DO_CROSS_ATTENTION && params.timestep == 0) { + v = add(v, *reinterpret_cast(&bias_smem[vi])); + *reinterpret_cast(&v_cache[ti * Dh]) = v; + } + // Load the logits from shared memory. +#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) + float logit = logits_smem[ti - first_step]; + out = fma(logit, cast_to_float(v), out); +#else + T logit = logits_smem[ti - first_step]; + + // Update the partial sums. + out = fma(logit, v, out); +#endif + } + } + + // One group of threads computes the product(s) for the current timestep. + // if( vo == params.timestep % V_PER_ITER ) { + if (vo == tlength % V_PER_ITER && (Dh == Dh_MAX || vi < Dh)) { + + V_vec v; + if (DO_CROSS_ATTENTION) { + v = *reinterpret_cast(&v_cache[tlength * Dh]); + } + else { + // Trigger the loads from the V buffer. + v = *reinterpret_cast(¶ms.v[qkv_base_offset + vi]); + // Trigger the loads from the V bias buffer. + // V_vec v_bias = *reinterpret_cast(¶ms.v_bias[hi*Dh + vi]); + } + + // Compute the V values with bias. + if (handle_k) { + v = add(v, v_bias); + + // Store the values with bias back to global memory in the cache for V. + //*reinterpret_cast(&v_cache[params.timestep*Dh]) = v; + *reinterpret_cast(&v_cache[tlength_circ * Dh]) = v; + } + + // Initialize the output value with the current timestep. +#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) + // out = fma(logits_smem[params.timestep], cast_to_float(v), out); + out = fma(logits_smem[tlength - first_step], cast_to_float(v), out); +#else + // out = fma(logits_smem[params.timestep], v, out); + out = fma(logits_smem[tlength - first_step], v, out); +#endif + } + + // Make sure we can start writing to shared memory. + __syncthreads(); + + // Run the final reduction amongst the different groups computing different partial outputs. + if (Dh == Dh_MAX || vi < Dh) { +#pragma unroll + for (int active_groups = V_PER_ITER; active_groups >= 2; active_groups /= 2) { + + // The midpoint in the number of active groups. + int midpoint = active_groups / 2; + + // The upper part of active threads store to shared memory. + if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) { +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT + convert_from_float(*reinterpret_cast(&out_smem[(vo - midpoint) * Dh + vi]), out); +#else + *reinterpret_cast(&out_smem[(vo - midpoint) * Dh + vi]) = out; +#endif + } + __syncthreads(); + + // The bottom warps update their values. + if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) { + out = add(*reinterpret_cast(&out_smem[vo * Dh + vi]), out); + } + __syncthreads(); + } + } + + // Output the final values. 
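The loop above halves the number of active value groups each step: the upper half parks its partial outputs in out_smem and the lower half adds them in, until group 0 holds the final vector. The same reduction on a toy array of per-group partial sums:

#include <cstdio>

int main()
{
    float partial[8] = {1, 2, 3, 4, 5, 6, 7, 8};    // one partial output per thread group
    for (int active = 8; active >= 2; active /= 2) {
        const int mid = active / 2;
        for (int g = 0; g < mid; ++g)
            partial[g] += partial[g + mid];         // "upper half stores, lower half adds"
    }
    printf("total = %.1f\n", partial[0]);           // 36.0
    return 0;
}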
+ if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) { +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT + convert_from_float(*reinterpret_cast(¶ms.out[bhi * Dh + vi]), out); +#else + *reinterpret_cast(¶ms.out[bhi * Dh + vi]) = out; +#endif + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace mmha + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream); \ No newline at end of file diff --git a/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h b/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h index 780b6de96..fd2cf804c 100644 --- a/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h +++ b/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h @@ -219,6 +219,22 @@ inline __device__ float2 half2_to_float2(uint32_t v) //////////////////////////////////////////////////////////////////////////////////////////////////// +inline __device__ float add(float a, uint16_t b) +{ + return a + half_to_float(b); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_BF16 +inline __device__ float add(float a, __nv_bfloat16 b) +{ + return a + __bfloat162float(b); +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + inline __device__ float2 add(uint32_t a, float2 fb) { float2 fa = half2_to_float2(a); @@ -392,7 +408,7 @@ inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) { uint32_t s = h0_h0(a); - uint2 d; + uint2 d; d.x = fma(s, b.x, c.x); d.y = fma(s, b.y, c.y); return d; @@ -415,7 +431,7 @@ inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) { uint32_t s = h0_h0(a); - uint4 d; + uint4 d; d.x = fma(s, b.x, c.x); d.y = fma(s, b.y, c.y); d.z = fma(s, b.z, c.z); @@ -463,7 +479,7 @@ inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) { uint32_t s = h0_h0(a); - Float4_ fd; + Float4_ fd; fd.x = fma(s, b.x, fc.x); fd.y = fma(s, b.y, fc.y); return fd; @@ -486,7 +502,7 @@ inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) { uint32_t s = h0_h0(a); - Float8_ fd; + Float8_ fd; fd.x = fma(s, b.x, fc.x); fd.y = fma(s, b.y, fc.y); fd.z = fma(s, b.z, fc.z); @@ -522,7 +538,7 @@ inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) { __nv_bfloat162 s = bf162bf162(a); - bf16_4_t d; + bf16_4_t d; d.x = fma(s, b.x, c.x); d.y = fma(s, b.y, c.y); return d; @@ -545,7 +561,7 @@ inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) { __nv_bfloat162 s = bf162bf162(a); - bf16_8_t d; + bf16_8_t d; d.x = fma(s, b.x, c.x); d.y = fma(s, b.y, c.y); d.z = fma(s, b.z, c.z); @@ -591,7 +607,7 @@ inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) { __nv_bfloat162 s = bf162bf162(a); - Float4_ fd; + Float4_ fd; fd.x = fma(s, b.x, fc.x); fd.y = fma(s, b.y, fc.y); return fd; @@ -614,7 +630,7 @@ inline __device__ Float8_ 
fma(bf16_8_t a, bf16_8_t b, Float8_ fc) inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) { __nv_bfloat162 s = bf162bf162(a); - Float8_ fd; + Float8_ fd; fd.x = fma(s, b.x, fc.x); fd.y = fma(s, b.y, fc.y); fd.z = fma(s, b.z, fc.z); @@ -728,7 +744,7 @@ template<> inline __device__ uint2 mul(uint16_t a, uint2 b) { uint32_t s = h0_h0(a); - uint2 c; + uint2 c; c.x = mul(s, b.x); c.y = mul(s, b.y); return c; @@ -753,7 +769,7 @@ template<> inline __device__ uint4 mul(uint16_t a, uint4 b) { uint32_t s = h0_h0(a); - uint4 c; + uint4 c; c.x = mul(s, b.x); c.y = mul(s, b.y); c.z = mul(s, b.z); @@ -806,7 +822,7 @@ template<> inline __device__ Float4_ mul(uint16_t a, uint2 b) { uint32_t s = h0_h0(a); - Float4_ fc; + Float4_ fc; fc.x = mul(s, b.x); fc.y = mul(s, b.y); return fc; @@ -831,7 +847,7 @@ template<> inline __device__ Float8_ mul(uint16_t a, uint4 b) { uint32_t s = h0_h0(a); - Float8_ fc; + Float8_ fc; fc.x = mul(s, b.x); fc.y = mul(s, b.y); fc.z = mul(s, b.z); @@ -885,7 +901,7 @@ template<> inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) { __nv_bfloat162 s = bf162bf162(a); - bf16_4_t c; + bf16_4_t c; c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x); c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y); return c; @@ -910,7 +926,7 @@ template<> inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) { __nv_bfloat162 s = bf162bf162(a); - bf16_8_t c; + bf16_8_t c; c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x); c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y); c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z); @@ -963,7 +979,7 @@ template<> inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) { __nv_bfloat162 s = bf162bf162(a); - Float4_ fc; + Float4_ fc; fc.x = mul(s, b.x); fc.y = mul(s, b.y); return fc; @@ -988,7 +1004,7 @@ template<> inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) { __nv_bfloat162 s = bf162bf162(a); - Float8_ fc; + Float8_ fc; fc.x = mul(s, b.x); fc.y = mul(s, b.y); fc.z = mul(s, b.z); @@ -1069,12 +1085,12 @@ inline __device__ float sum(uint4 v) { #if 1 uint32_t c = add(v.x, v.y); - c = add(c, v.z); - c = add(c, v.w); + c = add(c, v.z); + c = add(c, v.w); #else uint32_t c = add(v.x, v.y); uint32_t d = add(v.z, v.w); - c = add(c, d); + c = add(c, d); #endif return sum(c); } @@ -1123,7 +1139,7 @@ inline __device__ void zero(T& dst) { constexpr int WORDS = sizeof(T) / 4; union { - T raw; + T raw; uint32_t words[WORDS]; } tmp; #pragma unroll @@ -1151,7 +1167,7 @@ inline __device__ float2 rotary_embedding_transform(const float2 v, const float2 inline __device__ uint32_t rotary_embedding_transform(const uint32_t v, const float2 coef) { - float2 fv = half2_to_float2(v); + float2 fv = half2_to_float2(v); float2 rot_fv = rotary_embedding_transform(fv, coef); return float2_to_half2(rot_fv); } @@ -1159,7 +1175,7 @@ inline __device__ uint32_t rotary_embedding_transform(const uint32_t v, const fl #ifdef ENABLE_BF16 inline __device__ __nv_bfloat162 rotary_embedding_transform(const __nv_bfloat162 v, const float2 coef) { - float2 fv = bf1622float2(v); + float2 fv = bf1622float2(v); float2 rot_fv = rotary_embedding_transform(fv, coef); return __floats2bfloat162_rn(rot_fv.x, rot_fv.y); } @@ -1181,7 +1197,7 @@ inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_ return; } const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step); - q = rotary_embedding_transform(q, coef); + q = rotary_embedding_transform(q, coef); 
} inline __device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int rot_embed_dim, int t_step) @@ -1190,8 +1206,8 @@ inline __device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int return; } const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step); - q = rotary_embedding_transform(q, coef); - k = rotary_embedding_transform(k, coef); + q = rotary_embedding_transform(q, coef); + k = rotary_embedding_transform(k, coef); } inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_dim, int t_step) @@ -1200,11 +1216,11 @@ inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_ return; } - Float4_& q_ = *reinterpret_cast(&q); + Float4_& q_ = *reinterpret_cast(&q); const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step); - q_.x = rotary_embedding_transform(q_.x, coef0); + q_.x = rotary_embedding_transform(q_.x, coef0); const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step); - q_.y = rotary_embedding_transform(q_.y, coef1); + q_.y = rotary_embedding_transform(q_.y, coef1); } inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int rot_embed_dim, int t_step) @@ -1213,14 +1229,14 @@ inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int return; } - Float4_& q_ = *reinterpret_cast(&q); - Float4_& k_ = *reinterpret_cast(&k); + Float4_& q_ = *reinterpret_cast(&q); + Float4_& k_ = *reinterpret_cast(&k); const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step); - q_.x = rotary_embedding_transform(q_.x, coef0); - k_.x = rotary_embedding_transform(k_.x, coef0); + q_.x = rotary_embedding_transform(q_.x, coef0); + k_.x = rotary_embedding_transform(k_.x, coef0); const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step); - q_.y = rotary_embedding_transform(q_.y, coef1); - k_.y = rotary_embedding_transform(k_.y, coef1); + q_.y = rotary_embedding_transform(q_.y, coef1); + k_.y = rotary_embedding_transform(k_.y, coef1); } inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embed_dim, int t_step) @@ -1229,7 +1245,7 @@ inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embe return; } const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step); - q = rotary_embedding_transform(q, coef); + q = rotary_embedding_transform(q, coef); } inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid, int rot_embed_dim, int t_step) @@ -1238,8 +1254,8 @@ inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid, return; } const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step); - q = rotary_embedding_transform(q, coef); - k = rotary_embedding_transform(k, coef); + q = rotary_embedding_transform(q, coef); + k = rotary_embedding_transform(k, coef); } inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_dim, int t_step) @@ -1248,9 +1264,9 @@ inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_d return; } const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step); - q.x = rotary_embedding_transform(q.x, coef0); + q.x = rotary_embedding_transform(q.x, coef0); const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step); - q.y = rotary_embedding_transform(q.y, coef1); + q.y = rotary_embedding_transform(q.y, coef1); } inline __device__ void 
apply_rotary_embedding(uint2& q, uint2& k, int tid, int rot_embed_dim, int t_step) @@ -1259,11 +1275,11 @@ inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int r return; } const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step); - q.x = rotary_embedding_transform(q.x, coef0); - k.x = rotary_embedding_transform(k.x, coef0); + q.x = rotary_embedding_transform(q.x, coef0); + k.x = rotary_embedding_transform(k.x, coef0); const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step); - q.y = rotary_embedding_transform(q.y, coef1); - k.y = rotary_embedding_transform(k.y, coef1); + q.y = rotary_embedding_transform(q.y, coef1); + k.y = rotary_embedding_transform(k.y, coef1); } inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_dim, int t_step) @@ -1272,13 +1288,13 @@ inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_d return; } const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step); - q.x = rotary_embedding_transform(q.x, coef0); + q.x = rotary_embedding_transform(q.x, coef0); const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step); - q.y = rotary_embedding_transform(q.y, coef1); + q.y = rotary_embedding_transform(q.y, coef1); const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step); - q.z = rotary_embedding_transform(q.z, coef2); + q.z = rotary_embedding_transform(q.z, coef2); const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step); - q.w = rotary_embedding_transform(q.w, coef3); + q.w = rotary_embedding_transform(q.w, coef3); } inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int rot_embed_dim, int t_step) @@ -1287,17 +1303,17 @@ inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int r return; } const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step); - q.x = rotary_embedding_transform(q.x, coef0); - k.x = rotary_embedding_transform(k.x, coef0); + q.x = rotary_embedding_transform(q.x, coef0); + k.x = rotary_embedding_transform(k.x, coef0); const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step); - q.y = rotary_embedding_transform(q.y, coef1); - k.y = rotary_embedding_transform(k.y, coef1); + q.y = rotary_embedding_transform(q.y, coef1); + k.y = rotary_embedding_transform(k.y, coef1); const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step); - q.z = rotary_embedding_transform(q.z, coef2); - k.z = rotary_embedding_transform(k.z, coef2); + q.z = rotary_embedding_transform(q.z, coef2); + k.z = rotary_embedding_transform(k.z, coef2); const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step); - q.w = rotary_embedding_transform(q.w, coef3); - k.w = rotary_embedding_transform(k.w, coef3); + q.w = rotary_embedding_transform(q.w, coef3); + k.w = rotary_embedding_transform(k.w, coef3); } #ifdef ENABLE_BF16 @@ -1307,7 +1323,7 @@ inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int ro return; } const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step); - q = rotary_embedding_transform(q, coef); + q = rotary_embedding_transform(q, coef); } inline __device__ void @@ -1317,8 +1333,8 @@ apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_em return; } const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step); - q = 
rotary_embedding_transform(q, coef); - k = rotary_embedding_transform(k, coef); + q = rotary_embedding_transform(q, coef); + k = rotary_embedding_transform(k, coef); } inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embed_dim, int t_step) @@ -1327,9 +1343,9 @@ inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embe return; } const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step); - q.x = rotary_embedding_transform(q.x, coef0); + q.x = rotary_embedding_transform(q.x, coef0); const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step); - q.y = rotary_embedding_transform(q.y, coef1); + q.y = rotary_embedding_transform(q.y, coef1); } inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, int t_step) @@ -1338,11 +1354,11 @@ inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, return; } const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step); - q.x = rotary_embedding_transform(q.x, coef0); - k.x = rotary_embedding_transform(k.x, coef0); + q.x = rotary_embedding_transform(q.x, coef0); + k.x = rotary_embedding_transform(k.x, coef0); const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step); - q.y = rotary_embedding_transform(q.y, coef1); - k.y = rotary_embedding_transform(k.y, coef1); + q.y = rotary_embedding_transform(q.y, coef1); + k.y = rotary_embedding_transform(k.y, coef1); } inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embed_dim, int t_step) @@ -1351,13 +1367,13 @@ inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embe return; } const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step); - q.x = rotary_embedding_transform(q.x, coef0); + q.x = rotary_embedding_transform(q.x, coef0); const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step); - q.y = rotary_embedding_transform(q.y, coef1); + q.y = rotary_embedding_transform(q.y, coef1); const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step); - q.z = rotary_embedding_transform(q.z, coef2); + q.z = rotary_embedding_transform(q.z, coef2); const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step); - q.w = rotary_embedding_transform(q.w, coef3); + q.w = rotary_embedding_transform(q.w, coef3); } inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, int t_step) @@ -1366,18 +1382,286 @@ inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, return; } const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step); - q.x = rotary_embedding_transform(q.x, coef0); - k.x = rotary_embedding_transform(k.x, coef0); + q.x = rotary_embedding_transform(q.x, coef0); + k.x = rotary_embedding_transform(k.x, coef0); const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step); - q.y = rotary_embedding_transform(q.y, coef1); - k.y = rotary_embedding_transform(k.y, coef1); + q.y = rotary_embedding_transform(q.y, coef1); + k.y = rotary_embedding_transform(k.y, coef1); const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step); - q.z = rotary_embedding_transform(q.z, coef2); - k.z = rotary_embedding_transform(k.z, coef2); + q.z = rotary_embedding_transform(q.z, coef2); + k.z = rotary_embedding_transform(k.z, coef2); const auto coef3 = 
rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step); - q.w = rotary_embedding_transform(q.w, coef3); - k.w = rotary_embedding_transform(k.w, coef3); + q.w = rotary_embedding_transform(q.w, coef3); + k.w = rotary_embedding_transform(k.w, coef3); +} +#endif // ENABLE_BF16 + +template +__device__ __inline__ void vec_from_smem_transpose(Vec_T& vec, T* smem, int transpose_idx, int smem_pitch); + +template<> +__device__ __inline__ void vec_from_smem_transpose(float& vec, float* smem, int transpose_idx, int smem_pitch) +{ + return; +} + +template<> +__device__ __inline__ void vec_from_smem_transpose(uint32_t& vec, uint16_t* smem, int transpose_idx, int smem_pitch) +{ + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; + tmp.u16[0] = smem[transpose_idx]; + tmp.u16[1] = smem[smem_pitch + transpose_idx]; + + vec = tmp.u32; +} + +template<> +__device__ __inline__ void vec_from_smem_transpose(uint2& vec, uint16_t* smem, int transpose_idx, int smem_pitch) +{ + union { + uint32_t u32; + uint16_t u16[2]; + } tmp_1, tmp_2; + tmp_1.u32 = *reinterpret_cast(&smem[transpose_idx]); + tmp_2.u32 = *reinterpret_cast(&smem[smem_pitch + transpose_idx]); + + union { + uint2 u32x2; + uint16_t u16[4]; + } tmp_3; + tmp_3.u16[0] = tmp_1.u16[0]; + tmp_3.u16[1] = tmp_2.u16[0]; + tmp_3.u16[2] = tmp_1.u16[1]; + tmp_3.u16[3] = tmp_2.u16[1]; + + vec = tmp_3.u32x2; +} + +template<> +__device__ __inline__ void vec_from_smem_transpose(uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch) +{ + union { + uint64_t u64; + uint16_t u16[4]; + } tmp_1, tmp_2; + tmp_1.u64 = *reinterpret_cast(&smem[transpose_idx]); + tmp_2.u64 = *reinterpret_cast(&smem[smem_pitch + transpose_idx]); + + union { + uint4 u32x4; + uint16_t u16[8]; + } tmp_3; + tmp_3.u16[0] = tmp_1.u16[0]; + tmp_3.u16[1] = tmp_2.u16[0]; + tmp_3.u16[2] = tmp_1.u16[1]; + tmp_3.u16[3] = tmp_2.u16[1]; + tmp_3.u16[4] = tmp_1.u16[2]; + tmp_3.u16[5] = tmp_2.u16[2]; + tmp_3.u16[6] = tmp_1.u16[3]; + tmp_3.u16[7] = tmp_2.u16[3]; + + vec = tmp_3.u32x4; +} + +#ifdef ENABLE_BF16 +template<> +__device__ __inline__ void +vec_from_smem_transpose(bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) +{ + union { + uint32_t u32; + __nv_bfloat16 bf16[2]; + } tmp_1, tmp_2; + tmp_1.u32 = *reinterpret_cast(&smem[transpose_idx]); + tmp_2.u32 = *reinterpret_cast(&smem[smem_pitch + transpose_idx]); + + vec.x = __nv_bfloat162{tmp_1.bf16[0], tmp_2.bf16[0]}; + vec.y = __nv_bfloat162{tmp_1.bf16[1], tmp_2.bf16[1]}; +} + +template<> +__device__ __inline__ void +vec_from_smem_transpose(bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) +{ + union { + uint64_t u64; + __nv_bfloat16 bf16[4]; + } tmp_1, tmp_2; + tmp_1.u64 = *reinterpret_cast(&smem[transpose_idx]); + tmp_2.u64 = *reinterpret_cast(&smem[smem_pitch + transpose_idx]); + + vec.x = __nv_bfloat162{tmp_1.bf16[0], tmp_2.bf16[0]}; + vec.y = __nv_bfloat162{tmp_1.bf16[1], tmp_2.bf16[1]}; + vec.z = __nv_bfloat162{tmp_1.bf16[2], tmp_2.bf16[2]}; + vec.w = __nv_bfloat162{tmp_1.bf16[3], tmp_2.bf16[3]}; } #endif // ENABLE_BF16 -} // namespace mmha \ No newline at end of file +template<> +__device__ __inline__ void vec_from_smem_transpose(float4& vec, float* smem, int transpose_idx, int smem_pitch) +{ + vec.x = smem[transpose_idx]; + vec.z = smem[transpose_idx + 1]; + vec.y = smem[smem_pitch + transpose_idx]; + vec.w = smem[smem_pitch + transpose_idx + 1]; +} + +template<> +__device__ __inline__ void vec_from_smem_transpose(uint32_t& vec, half* smem, int transpose_idx, int smem_pitch) +{ + union { + 
uint32_t u32; + half u16[2]; + } tmp; + tmp.u16[0] = smem[transpose_idx]; + tmp.u16[1] = smem[smem_pitch + transpose_idx]; + + vec = tmp.u32; +} + +#ifdef ENABLE_BF16 +template<> +__device__ __inline__ void +vec_from_smem_transpose(__nv_bfloat162& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) +{ + vec.x = smem[transpose_idx]; + vec.y = smem[smem_pitch + transpose_idx]; +} +#endif + +template<> +__device__ __inline__ void vec_from_smem_transpose(float2& vec, float* smem, int transpose_idx, int smem_pitch) +{ + vec.x = smem[transpose_idx]; + vec.y = smem[smem_pitch + transpose_idx]; +} + +template +__device__ __inline__ void write_smem_transpose(const Vec_T& vec, T* smem, int transpose_idx, int smem_pitch); + +template<> +__device__ __inline__ void write_smem_transpose(const float& vec, float* smem, int transpose_idx, int smem_pitch) +{ + return; +} + +#ifdef ENABLE_BF16 +template<> +__device__ __inline__ void +write_smem_transpose(const bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) +{ + return; +} + +template<> +__device__ __inline__ void +write_smem_transpose(const bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) +{ + return; +} +#endif + +template<> +__device__ __inline__ void write_smem_transpose(const uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch) +{ + union { + uint64_t u64; + uint16_t u16[4]; + } tmp_1, tmp_2; + + union { + uint4 u32x4; + uint16_t u16[8]; + } tmp_3; + tmp_3.u32x4 = vec; + tmp_1.u16[0] = tmp_3.u16[0]; + tmp_2.u16[0] = tmp_3.u16[1]; + tmp_1.u16[1] = tmp_3.u16[2]; + tmp_2.u16[1] = tmp_3.u16[3]; + tmp_1.u16[2] = tmp_3.u16[4]; + tmp_2.u16[2] = tmp_3.u16[5]; + tmp_1.u16[3] = tmp_3.u16[6]; + tmp_2.u16[3] = tmp_3.u16[7]; + + *reinterpret_cast(&smem[transpose_idx]) = tmp_1.u64; + *reinterpret_cast(&smem[smem_pitch + transpose_idx]) = tmp_2.u64; +} + +template<> +__device__ __inline__ void write_smem_transpose(const uint2& vec, uint16_t* smem, int transpose_idx, int smem_pitch) +{ + union { + uint32_t u32; + uint16_t u16[2]; + } tmp_1, tmp_2; + + union { + uint2 u32x2; + uint16_t u16[4]; + } tmp_3; + tmp_3.u32x2 = vec; + tmp_1.u16[0] = tmp_3.u16[0]; + tmp_2.u16[0] = tmp_3.u16[1]; + tmp_1.u16[1] = tmp_3.u16[2]; + tmp_2.u16[1] = tmp_3.u16[3]; + + *reinterpret_cast(&smem[transpose_idx]) = tmp_1.u32; + *reinterpret_cast(&smem[smem_pitch + transpose_idx]) = tmp_2.u32; +} + +template<> +__device__ __inline__ void write_smem_transpose(const uint32_t& vec, uint16_t* smem, int transpose_idx, int smem_pitch) +{ + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; + tmp.u32 = vec; + + smem[transpose_idx] = tmp.u16[0]; + smem[smem_pitch + transpose_idx] = tmp.u16[1]; +} + +template<> +__device__ __inline__ void write_smem_transpose(const float4& vec, float* smem, int transpose_idx, int smem_pitch) +{ + smem[transpose_idx] = vec.x; + smem[transpose_idx + 1] = vec.z; + smem[smem_pitch + transpose_idx] = vec.y; + smem[smem_pitch + transpose_idx + 1] = vec.w; +} + +template<> +__device__ __inline__ void write_smem_transpose(const uint32_t& vec, half* smem, int transpose_idx, int smem_pitch) +{ + union { + uint32_t u32; + half u16[2]; + } tmp; + + tmp.u32 = vec; + smem[transpose_idx] = tmp.u16[0]; + smem[smem_pitch + transpose_idx] = tmp.u16[1]; +} + +#ifdef ENABLE_BF16 +template<> +__device__ __inline__ void +write_smem_transpose(const __nv_bfloat162& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) +{ + smem[transpose_idx] = vec.x; + smem[smem_pitch + transpose_idx] = vec.y; +} +#endif + +template<> 
+__device__ __inline__ void write_smem_transpose(const float2& vec, float* smem, int transpose_idx, int smem_pitch) +{ + smem[transpose_idx] = vec.x; + smem[smem_pitch + transpose_idx] = vec.y; +} + +} // namespace mmha diff --git a/src/fastertransformer/kernels/decoding_kernels.cu b/src/fastertransformer/kernels/decoding_kernels.cu index ccca4319d..93e7ed932 100644 --- a/src/fastertransformer/kernels/decoding_kernels.cu +++ b/src/fastertransformer/kernels/decoding_kernels.cu @@ -23,20 +23,20 @@ namespace fastertransformer { static const float HALF_FLT_MAX = 65504.F; template -__global__ void decodingInitialize(bool* finished, - int* sequence_length, - int* word_ids, - T* cum_log_probs, +__global__ void decodingInitialize(bool* finished, + int* sequence_length, + int* word_ids, + T* cum_log_probs, const int* sentence_ids, - const int batch_size, - const int beam_width, - const int max_input_length) + const int batch_size, + const int beam_width, + const int max_input_length) { - const bool IS_FP16 = std::is_same::value; - const T MAX_T_VAL = (IS_FP16) ? (T)HALF_FLT_MAX : (T)1e20f; // BF16 and FP32 have the same dynamic range + const bool IS_FP16 = std::is_same::value; + const T MAX_T_VAL = (IS_FP16) ? (T)HALF_FLT_MAX : (T)1e20f; // BF16 and FP32 have the same dynamic range for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size * beam_width; index += blockDim.x * gridDim.x) { - finished[index] = false; + finished[index] = false; sequence_length[index] = max_input_length; if (word_ids != nullptr) { word_ids[index] = sentence_ids[index / beam_width]; @@ -46,14 +46,14 @@ __global__ void decodingInitialize(bool* finished, } template -void invokeDecodingInitialize(bool* finished, - int* sequence_length, - int* word_ids, - T* cum_log_probs, - const int* sentence_ids, - const int batch_size, - const int beam_width, - const int max_input_length, +void invokeDecodingInitialize(bool* finished, + int* sequence_length, + int* word_ids, + T* cum_log_probs, + const int* sentence_ids, + const int batch_size, + const int beam_width, + const int max_input_length, cudaStream_t stream) { dim3 grid((int)ceil(batch_size * beam_width * 1.0 / 256)); @@ -63,83 +63,92 @@ void invokeDecodingInitialize(bool* finished, finished, sequence_length, word_ids, cum_log_probs, sentence_ids, batch_size, beam_width, max_input_length); } -template void invokeDecodingInitialize(bool* finished, - int* sequence_length, - int* word_ids, - float* cum_log_probs, - const int* sentence_ids, - const int batch_size, - const int beam_width, - const int max_input_length, +template void invokeDecodingInitialize(bool* finished, + int* sequence_length, + int* word_ids, + float* cum_log_probs, + const int* sentence_ids, + const int batch_size, + const int beam_width, + const int max_input_length, cudaStream_t stream); -template void invokeDecodingInitialize(bool* finished, - int* sequence_length, - int* word_ids, - half* cum_log_probs, - const int* sentence_ids, - const int batch_size, - const int beam_width, - const int max_input_length, +template void invokeDecodingInitialize(bool* finished, + int* sequence_length, + int* word_ids, + half* cum_log_probs, + const int* sentence_ids, + const int batch_size, + const int beam_width, + const int max_input_length, cudaStream_t stream); #ifdef ENABLE_BF16 -template void invokeDecodingInitialize(bool* finished, - int* sequence_length, - int* word_ids, +template void invokeDecodingInitialize(bool* finished, + int* sequence_length, + int* word_ids, __nv_bfloat16* cum_log_probs, - 
const int* sentence_ids, - const int batch_size, - const int beam_width, - const int max_input_length, - cudaStream_t stream); + const int* sentence_ids, + const int batch_size, + const int beam_width, + const int max_input_length, + cudaStream_t stream); #endif template -__global__ void embeddingLookupPosEncoding(T* from_tensor, - const T* embedding_table, - const T* position_encoding, +__global__ void embeddingLookupPosEncoding(T* from_tensor, + const T* embedding_table, + const T* position_encoding, const int* all_ids, + const int* padding_count, const int* input_lengths, - const int local_batch_size, - const int hidden_units, - const int step, - const int max_input_length, - const int batch_size, - const int ite, - const T scale) + const int local_batch_size, + const int hidden_units, + const int step, + const int max_input_length, + const int batch_size, + const int ite, + const T scale) { // 1. lookup from embedding table // 2. multiply scale // 3. add the position encoding const int id_offset = step * batch_size + ite * local_batch_size; + const bool use_padding_count = padding_count != nullptr; + const bool use_input_len = input_lengths != nullptr; + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < local_batch_size * hidden_units; index += blockDim.x * gridDim.x) { - const int row_index = index / hidden_units; - const int col_index = index % hidden_units; - const int step_offset = input_lengths == nullptr ? - step * hidden_units : - (step - max_input_length + input_lengths[row_index]) * hidden_units; + const int row_index = index / hidden_units; + const int col_index = index % hidden_units; + int step_offset = step; + if (use_padding_count) { + step_offset -= padding_count[row_index]; + } + else if (use_input_len) { + step_offset -= max_input_length - input_lengths[row_index]; + } + step_offset *= hidden_units; + T val = embedding_table[all_ids[id_offset + row_index] * hidden_units + col_index] * scale; - val = val + position_encoding[step_offset + col_index]; + val = val + position_encoding[step_offset + col_index]; from_tensor[index] = val; } } -// No aboluste position embedding +// No absolute position embedding template -__global__ void embeddingLookup(T* from_tensor, - const T* embedding_table, +__global__ void embeddingLookup(T* from_tensor, + const T* embedding_table, const int* all_ids, - const int local_batch_size, - const int hidden_units, - const int step, - const int max_input_length, - const int batch_size, - const int ite, - const T scale) + const int local_batch_size, + const int hidden_units, + const int step, + const int batch_size, + const int ite, + const T scale) { // 1. lookup from embedding table // 2. 
multiply scale @@ -149,24 +158,26 @@ __global__ void embeddingLookup(T* from_tensor, index += blockDim.x * gridDim.x) { const int row_index = index / hidden_units; const int col_index = index % hidden_units; - T val = embedding_table[all_ids[id_offset + row_index] * hidden_units + col_index] * scale; - from_tensor[index] = val; + T val = embedding_table[all_ids[id_offset + row_index] * hidden_units + col_index] * scale; + from_tensor[index] = val; } } +/* Adapter function for invokeEmbeddingLookupPosEncoding{PadCount,InputLen} */ template -void invokeEmbeddingLookupPosEncoding(T* from_tensor, - const T* embedding_table, - const T* position_encoding, - const int* all_ids, - const int* input_lengths, - const int local_batch_size, - const int hidden_units, - const T scale, - const int step, - const int max_input_length, - const int batch_size, - const int ite, +void invokeEmbeddingLookupPosEncoding(T* from_tensor, + const T* embedding_table, + const T* position_encoding, + const int* all_ids, + const int* padding_count, + const int* input_lengths, + const int local_batch_size, + const int hidden_units, + const T scale, + const int step, + const int max_input_length, + const int batch_size, + const int ite, cudaStream_t stream) { dim3 grid(min(local_batch_size, 65536)); @@ -176,6 +187,7 @@ void invokeEmbeddingLookupPosEncoding(T* from_tensor, embedding_table, position_encoding, all_ids, + padding_count, input_lengths, local_batch_size, hidden_units, @@ -186,68 +198,66 @@ void invokeEmbeddingLookupPosEncoding(T* from_tensor, scale); } else { - embeddingLookup<<>>(from_tensor, - embedding_table, - all_ids, - local_batch_size, - hidden_units, - step, - max_input_length, - batch_size, - ite, - scale); + embeddingLookup<<>>( + from_tensor, embedding_table, all_ids, local_batch_size, hidden_units, step, batch_size, ite, scale); } } -template void invokeEmbeddingLookupPosEncoding(float* from_tensor, - const float* embedding_table, - const float* position_encoding, - const int* all_ids, - const int* input_lengths, - const int local_batch_size, - const int hidden_units, - const float scale, - const int step, - const int max_input_length, - const int batch_size, - const int ite, - cudaStream_t stream); - -template void invokeEmbeddingLookupPosEncoding(half* from_tensor, - const half* embedding_table, - const half* position_encoding, - const int* all_ids, - const int* input_lengths, - const int local_batch_size, - const int hidden_units, - const half scale, - const int step, - const int max_input_length, - const int batch_size, - const int ite, - cudaStream_t stream); +template +void invokeEmbeddingLookupPosEncodingPadCount(T* from_tensor, + const T* embedding_table, + const T* position_encoding, + const int* all_ids, + const int* pad_count, + const int local_batch_size, + const int hidden_units, + const T scale, + const int step, + const int batch_size, + const int ite, + cudaStream_t stream) +{ + invokeEmbeddingLookupPosEncoding(from_tensor, + embedding_table, + position_encoding, + all_ids, + pad_count, + nullptr, + local_batch_size, + hidden_units, + scale, + step, + 0, + batch_size, + ite, + stream); +} +#define INSTANTIATE_LOOKUP_POS_ENCODING_PAD_COUNT(T) \ + template void invokeEmbeddingLookupPosEncodingPadCount(T* from_tensor, \ + const T* embedding_table, \ + const T* position_encoding, \ + const int* all_ids, \ + const int* pad_count, \ + const int local_batch_size, \ + const int hidden_units, \ + const T scale, \ + const int step, \ + const int batch_size, \ + const int ite, \ + cudaStream_t 
stream) +INSTANTIATE_LOOKUP_POS_ENCODING_PAD_COUNT(float); +INSTANTIATE_LOOKUP_POS_ENCODING_PAD_COUNT(half); #ifdef ENABLE_BF16 -template void invokeEmbeddingLookupPosEncoding(__nv_bfloat16* from_tensor, - const __nv_bfloat16* embedding_table, - const __nv_bfloat16* position_encoding, - const int* all_ids, - const int* input_lengths, - const int local_batch_size, - const int hidden_units, - const __nv_bfloat16 scale, - const int step, - const int max_input_length, - const int batch_size, - const int ite, - cudaStream_t stream); +INSTANTIATE_LOOKUP_POS_ENCODING_PAD_COUNT(__nv_bfloat16); #endif +#undef INSTANTIATE_LOOKUP_POS_ENCODING_PAD_COUNT template -__global__ void paddingEmbedding(T* padded_embedding_kernel, - T* padded_embedding_bias, - const T* embedding_kernel, - const T* embedding_bias, +__global__ void paddingEmbedding(T* padded_embedding_kernel, + T* padded_embedding_bias, + const T* embedding_kernel, + const T* embedding_bias, const int hidden_unit, const int vocab_size, const int vocab_size_padded) @@ -275,13 +285,13 @@ __global__ void paddingEmbedding(T* padded_embedding_kernel, } template -void invokePaddingEmbedding(T* padded_embedding_kernel, - T* padded_embedding_bias, - const T* embedding_kernel, - const T* embedding_bias, - const int hidden_unit, - const int vocab_size, - const int vocab_size_padded, +void invokePaddingEmbedding(T* padded_embedding_kernel, + T* padded_embedding_bias, + const T* embedding_kernel, + const T* embedding_bias, + const int hidden_unit, + const int vocab_size, + const int vocab_size_padded, cudaStream_t stream) { dim3 block(512); @@ -295,27 +305,37 @@ void invokePaddingEmbedding(T* padded_embedding_kernel, vocab_size_padded); } -template void invokePaddingEmbedding(float* padded_embedding_kernel, - float* padded_embedding_bias, +template void invokePaddingEmbedding(float* padded_embedding_kernel, + float* padded_embedding_bias, const float* embedding_kernel, const float* embedding_bias, - const int hidden_unit, - const int vocab_size, - const int vocab_size_padded, + const int hidden_unit, + const int vocab_size, + const int vocab_size_padded, cudaStream_t stream); -template void invokePaddingEmbedding(half* padded_embedding_kernel, - half* padded_embedding_bias, - const half* embedding_kernel, - const half* embedding_bias, - const int hidden_unit, - const int vocab_size, - const int vocab_size_padded, +template void invokePaddingEmbedding(half* padded_embedding_kernel, + half* padded_embedding_bias, + const half* embedding_kernel, + const half* embedding_bias, + const int hidden_unit, + const int vocab_size, + const int vocab_size_padded, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokePaddingEmbedding(__nv_bfloat16* padded_embedding_kernel, + __nv_bfloat16* padded_embedding_bias, + const __nv_bfloat16* embedding_kernel, + const __nv_bfloat16* embedding_bias, + const int hidden_unit, + const int vocab_size, + const int vocab_size_padded, + cudaStream_t stream); +#endif template -__global__ void paddingEmbeddingKernel(T* padded_embedding_kernel, - const T* embedding_kernel, +__global__ void paddingEmbeddingKernel(T* padded_embedding_kernel, + const T* embedding_kernel, const int hidden_unit, const int vocab_size, const int vocab_size_padded) @@ -334,11 +354,11 @@ __global__ void paddingEmbeddingKernel(T* padded_embedding_kernel, } template -void invokePaddingEmbeddingKernel(T* padded_embedding_kernel, - const T* embedding_kernel, - const int hidden_unit, - const int vocab_size, - const int vocab_size_padded, +void 
invokePaddingEmbeddingKernel(T* padded_embedding_kernel, + const T* embedding_kernel, + const int hidden_unit, + const int vocab_size, + const int vocab_size_padded, cudaStream_t stream) { dim3 block(512); @@ -347,22 +367,32 @@ void invokePaddingEmbeddingKernel(T* padded_embedding_kernel, padded_embedding_kernel, embedding_kernel, hidden_unit, vocab_size, vocab_size_padded); } -template void invokePaddingEmbeddingKernel(float* padded_embedding_kernel, +template void invokePaddingEmbeddingKernel(float* padded_embedding_kernel, const float* embedding_kernel, - const int hidden_unit, - const int vocab_size, - const int vocab_size_padded, + const int hidden_unit, + const int vocab_size, + const int vocab_size_padded, cudaStream_t stream); -template void invokePaddingEmbeddingKernel(half* padded_embedding_kernel, - const half* embedding_kernel, - const int hidden_unit, - const int vocab_size, - const int vocab_size_padded, +template void invokePaddingEmbeddingKernel(half* padded_embedding_kernel, + const half* embedding_kernel, + const int hidden_unit, + const int vocab_size, + const int vocab_size_padded, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokePaddingEmbeddingKernel(__nv_bfloat16* padded_embedding_kernel, + const __nv_bfloat16* embedding_kernel, + const int hidden_unit, + const int vocab_size, + const int vocab_size_padded, + cudaStream_t stream); +#endif + __global__ void gatherTree(gatherTreeParam param) { + // PREFIX SOFT PROMPT // beam: have six parts // [prompt | input | input_padding | prompt_padding | generated output | padding (use end_token)] // parents: have five parts @@ -372,27 +402,41 @@ __global__ void gatherTree(gatherTreeParam param) // need to transpose to output_ids [bs, beam_width, input_length + requested_output_length] // max_input_length: input + input_padding + prompt_padding + // P/PROMPT TUNING + // NOTE: input (real ids | prompt virtual ids) have already been preprocessed during embedding lookup, no prompt + // templates now beam: [input (real ids | prompt virtual ids) | input_padding | generated output | padding (use + // end_token)] parents: [input (real ids | prompt virtual ids) | input_padding | generated output | padding (use + // 0)] step_ids: need to remove virtual prompt ids in input ids + // the shape is [input_length (real input length, prompt length) + requested_output_length, bs, beam_width] + // need to transpose to output_ids [bs, beam_width, input_length + requested_output_length] + // max_input_length: input (real ids | prompt virtual ids) + input_padding + const int max_input_length = param.input_lengths == nullptr ? 0 : param.max_input_length; for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < param.batch_size * param.beam_width; i += gridDim.x * blockDim.x) { const int batch = i / param.beam_width; - const int beam = i % param.beam_width; + const int beam = i % param.beam_width; const int prompt_len = param.prefix_soft_prompt_lengths == nullptr ? 0 : param.prefix_soft_prompt_lengths[batch]; - const int input_len = param.input_lengths == nullptr ? 0 : param.input_lengths[i]; + int input_len = param.input_lengths == nullptr ? 0 : param.input_lengths[i]; + // virtual prompts mean the prompt embedded in input ids (with prompt templates) [p/prompt tuning] + const int virtual_prompt_length = + param.p_prompt_tuning_prompt_lengths == nullptr ? 
0 : param.p_prompt_tuning_prompt_lengths[batch]; + // real input length (without virtual prompts) [p/prompt tuning] + input_len -= virtual_prompt_length; const int* parent_ids = param.parent_ids; - const int* step_ids = param.step_ids; + const int* step_ids = param.step_ids; // TODO(bhsueh) optimize the reduce_max operation for large beam_width - int max_len = -1; - int selected_beam_index = 0; + int max_len = -1; + // int selected_beam_index = 0; for (int j = 0; j < param.beam_width; j++) { int tmp_len = __ldg(param.max_sequence_lengths + batch * param.beam_width + j); if (tmp_len > max_len) { - max_len = tmp_len; - selected_beam_index = j; + max_len = tmp_len; + // selected_beam_index = j; } } const int max_seq_len_b = min(param.max_time, max_len); @@ -404,10 +448,10 @@ __global__ void gatherTree(gatherTreeParam param) (param.batch_size * param.beam_width * (time_ix) + param.beam_width * batch + (beam_ix)) const int padding_offset_and_prompt_offset = max_input_length - input_len + prompt_len; - const int initial_tgt_ix = GET_IX(max_seq_len_b - 1 - padding_offset_and_prompt_offset, beam); - const int initial_parent_ix = GET_IX(max_seq_len_b - 1, beam); - param.beams[initial_tgt_ix] = __ldg(step_ids + initial_parent_ix); - int parent = parent_ids == nullptr ? 0 : __ldg(parent_ids + initial_parent_ix) % param.beam_width; + const int initial_tgt_ix = GET_IX(max_seq_len_b - 1 - padding_offset_and_prompt_offset, beam); + const int initial_parent_ix = GET_IX(max_seq_len_b - 1, beam); + param.beams[initial_tgt_ix] = __ldg(step_ids + initial_parent_ix); + int parent = parent_ids == nullptr ? 0 : __ldg(parent_ids + initial_parent_ix) % param.beam_width; bool found_bad = false; for (int level = max_seq_len_b - 2; level >= 0; --level) { @@ -415,13 +459,13 @@ __global__ void gatherTree(gatherTreeParam param) continue; } int tgt_level = level >= max_input_length ? 
level - padding_offset_and_prompt_offset : level - prompt_len; - const int level_beam_ix = GET_IX(tgt_level, beam); + const int level_beam_ix = GET_IX(tgt_level, beam); const int level_parent_ix = GET_IX(level, parent); if (parent < 0 || parent > param.beam_width) { // param.beams[level_beam_ix] = -1; param.beams[level_beam_ix] = param.end_tokens[batch]; - parent = -1; - found_bad = true; + parent = -1; + found_bad = true; } else { param.beams[level_beam_ix] = __ldg(step_ids + level_parent_ix); @@ -464,64 +508,64 @@ __global__ void gatherTree(gatherTreeParam param) } } -void invokeGatherTree(int* beams, - int* max_sequence_lengths, - const int max_time, - const int batch_size, - const int beam_width, - const int* step_ids, - const int* parent_ids, - const int* end_tokens, +void invokeGatherTree(int* beams, + int* max_sequence_lengths, + const int max_time, + const int batch_size, + const int beam_width, + const int* step_ids, + const int* parent_ids, + const int* end_tokens, cudaStream_t stream) { gatherTreeParam param; - param.beams = beams; - param.max_sequence_lengths = max_sequence_lengths; - param.max_time = max_time; - param.batch_size = batch_size; - param.beam_width = beam_width; - param.step_ids = step_ids; - param.parent_ids = parent_ids; - param.end_tokens = end_tokens; - param.max_input_length = 1; + param.beams = beams; + param.max_sequence_lengths = max_sequence_lengths; + param.max_time = max_time; + param.batch_size = batch_size; + param.beam_width = beam_width; + param.step_ids = step_ids; + param.parent_ids = parent_ids; + param.end_tokens = end_tokens; + param.max_input_length = 1; param.prefix_soft_prompt_lengths = nullptr; - param.stream = stream; + param.stream = stream; invokeGatherTree(param); } -void invokeGatherTree(int* beams, - int* max_sequence_lengths, - const int max_time, - const int batch_size, - const int beam_width, - const int* step_ids, - const int* parent_ids, - const int* end_tokens, - const int max_input_length, +void invokeGatherTree(int* beams, + int* max_sequence_lengths, + const int max_time, + const int batch_size, + const int beam_width, + const int* step_ids, + const int* parent_ids, + const int* end_tokens, + const int max_input_length, cudaStream_t stream) { gatherTreeParam param; - param.beams = beams; - param.max_sequence_lengths = max_sequence_lengths; - param.max_time = max_time; - param.batch_size = batch_size; - param.beam_width = beam_width; - param.step_ids = step_ids; - param.parent_ids = parent_ids; - param.end_tokens = end_tokens; - param.max_input_length = max_input_length; + param.beams = beams; + param.max_sequence_lengths = max_sequence_lengths; + param.max_time = max_time; + param.batch_size = batch_size; + param.beam_width = beam_width; + param.step_ids = step_ids; + param.parent_ids = parent_ids; + param.end_tokens = end_tokens; + param.max_input_length = max_input_length; param.prefix_soft_prompt_lengths = nullptr; - param.stream = stream; + param.stream = stream; invokeGatherTree(param); } void invokeGatherTree(gatherTreeParam param) { - int batchbeam = param.batch_size * param.beam_width; + int batchbeam = param.batch_size * param.beam_width; dim3 grid(1), block(batchbeam); // though decoder do not support > 1024 for now if (batchbeam > 1024) { - grid.x = ceil(param.batch_size * param.beam_width / 1024.); + grid.x = ceil(param.batch_size * param.beam_width / 1024.); block.x = 1024; } gatherTree<<>>(param); diff --git a/src/fastertransformer/kernels/decoding_kernels.h b/src/fastertransformer/kernels/decoding_kernels.h 
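The gatherTree kernel above reconstructs each finished beam by walking parent_ids backwards from the last step and copying step_ids along the surviving path into the time-major beams buffer. Below is a minimal host-side reference of that backtrace, illustrative only: it assumes no prefix prompts or input padding (so the tgt_level shifting, the tgt_level < 0 skip, and the out-of-range-parent handling of the real kernel are omitted) and uses the [max_time, batch_size, beam_width] layout of the GET_IX macro.

#include <vector>

// Host-side sketch of the per-(batch, beam) backtrace performed by gatherTree,
// with prompt/padding offsets ignored for clarity. beams, step_ids and parent_ids
// all use the kernel's time-major layout: index = (batch_size * beam_width) * t
//                                                 + beam_width * batch + beam.
void gather_tree_reference(std::vector<int>&       beams,
                           const std::vector<int>& step_ids,
                           const std::vector<int>& parent_ids,
                           int max_seq_len, int batch_size, int beam_width)
{
    auto ix = [&](int t, int b, int k) { return (batch_size * beam_width) * t + beam_width * b + k; };
    for (int b = 0; b < batch_size; ++b) {
        for (int k = 0; k < beam_width; ++k) {
            int level = max_seq_len - 1;
            // seed with the token chosen at the final step, then follow the parent chain
            beams[ix(level, b, k)] = step_ids[ix(level, b, k)];
            int parent = parent_ids[ix(level, b, k)] % beam_width;
            for (level = max_seq_len - 2; level >= 0; --level) {
                beams[ix(level, b, k)] = step_ids[ix(level, b, parent)];
                parent = parent_ids[ix(level, b, parent)] % beam_width;
            }
        }
    }
}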
index c378dcd37..116bec4ed 100644 --- a/src/fastertransformer/kernels/decoding_kernels.h +++ b/src/fastertransformer/kernels/decoding_kernels.h @@ -22,86 +22,87 @@ namespace fastertransformer { template -void invokeDecodingInitialize(bool* finished, - int* sequence_length, - int* word_ids, - T* cum_log_probs, - const int* sentence_ids, - const int batch_size, - const int beam_width, - const int max_input_length, +void invokeDecodingInitialize(bool* finished, + int* sequence_length, + int* word_ids, + T* cum_log_probs, + const int* sentence_ids, + const int batch_size, + const int beam_width, + const int max_input_length, cudaStream_t stream); // get token from all_ids at step, then lookup from the embedding table // by the token template -void invokeEmbeddingLookupPosEncoding(T* from_tensor, - const T* embedding_table, - const T* position_encoding, - const int* all_ids, - const int* input_lengths, - const int local_batch_size, - const int hidden_units, - const T scale, - const int step, - const int max_input_length, - const int batch_size, - const int ite, - cudaStream_t stream); +void invokeEmbeddingLookupPosEncodingPadCount(T* from_tensor, + const T* embedding_table, + const T* position_encoding, + const int* all_ids, + const int* padding_count, + const int local_batch_size, + const int hidden_units, + const T scale, + const int step, + const int batch_size, + const int ite, + cudaStream_t stream); template -void invokePaddingEmbedding(T* padded_embedding_kernel, - T* padded_embedding_bias, - const T* embedding_kernel, - const T* embedding_bias, - const int hidden_unit, - const int vocab_size, - const int vocab_size_padded, +void invokePaddingEmbedding(T* padded_embedding_kernel, + T* padded_embedding_bias, + const T* embedding_kernel, + const T* embedding_bias, + const int hidden_unit, + const int vocab_size, + const int vocab_size_padded, cudaStream_t stream); template -void invokePaddingEmbeddingKernel(T* padded_embedding_kernel, - const T* embedding_kernel, - const int hidden_unit, - const int vocab_size, - const int vocab_size_padded, +void invokePaddingEmbeddingKernel(T* padded_embedding_kernel, + const T* embedding_kernel, + const int hidden_unit, + const int vocab_size, + const int vocab_size_padded, cudaStream_t stream); -void invokeGatherTree(int* beams, - int* max_sequence_lengths, - const int max_time, - const int batch_size, - const int beam_width, - const int* step_ids, - const int* parent_ids, - const int* end_tokens, +void invokeGatherTree(int* beams, + int* max_sequence_lengths, + const int max_time, + const int batch_size, + const int beam_width, + const int* step_ids, + const int* parent_ids, + const int* end_tokens, cudaStream_t stream); -void invokeGatherTree(int* beams, - int* max_sequence_lengths, - const int max_time, - const int batch_size, - const int beam_width, - const int* step_ids, - const int* parent_ids, - const int* end_tokens, - const int max_input_length, +void invokeGatherTree(int* beams, + int* max_sequence_lengths, + const int max_time, + const int batch_size, + const int beam_width, + const int* step_ids, + const int* parent_ids, + const int* end_tokens, + const int max_input_length, cudaStream_t stream); struct gatherTreeParam { - int* beams; + int* beams; const int* max_sequence_lengths = nullptr; - const int* input_lengths = nullptr; - int max_time; - int batch_size; - int beam_width; - const int* step_ids = nullptr; + const int* input_lengths = nullptr; + int max_time; + int batch_size; + int beam_width; + const int* step_ids = nullptr; const 
int* parent_ids = nullptr; const int* end_tokens; - int max_input_length; + int max_input_length; const int* prefix_soft_prompt_lengths = nullptr; - int max_prefix_soft_prompt_length; - int* output_ids = nullptr; + // p_prompt_tuning prompt leangths, used to remove prompts during post-processing + const int* p_prompt_tuning_prompt_lengths = nullptr; + int max_prefix_soft_prompt_length; + int* output_ids = nullptr; cudaStream_t stream; }; diff --git a/src/fastertransformer/kernels/dequantize_kernels.cu b/src/fastertransformer/kernels/dequantize_kernels.cu index 8151ac849..a78b11356 100644 --- a/src/fastertransformer/kernels/dequantize_kernels.cu +++ b/src/fastertransformer/kernels/dequantize_kernels.cu @@ -24,9 +24,9 @@ __global__ void dequantized_kernel(float4* dst, const char4* src, const int size { int tid = (blockIdx.x * blockDim.x + threadIdx.x); if (tid < size_div_4) { - const float scale = __ldg(scale_ptr); - char4 tmp = __ldg(src + tid); - int outIdx = tid; + const float scale = __ldg(scale_ptr); + char4 tmp = __ldg(src + tid); + int outIdx = tid; float4 float4Tmp; float4Tmp.x = static_cast(tmp.x) * scale; @@ -41,15 +41,15 @@ __global__ void dequantized_kernel(half4* dst, const char4* src, const int size_ { int tid = (blockIdx.x * blockDim.x + threadIdx.x); if (tid < size_div_4) { - const float scale = __ldg(scale_ptr); - char4 tmp = __ldg(src + tid); - int outIdx = tid; + const float scale = __ldg(scale_ptr); + char4 tmp = __ldg(src + tid); + int outIdx = tid; half4 half4Tmp; - half4Tmp.x = static_cast(static_cast(tmp.x) * scale); - half4Tmp.y = static_cast(static_cast(tmp.y) * scale); - half4Tmp.z = static_cast(static_cast(tmp.z) * scale); - half4Tmp.w = static_cast(static_cast(tmp.w) * scale); + half4Tmp.x = static_cast(static_cast(tmp.x) * scale); + half4Tmp.y = static_cast(static_cast(tmp.y) * scale); + half4Tmp.z = static_cast(static_cast(tmp.z) * scale); + half4Tmp.w = static_cast(static_cast(tmp.w) * scale); dst[outIdx] = half4Tmp; } } @@ -84,9 +84,9 @@ __global__ void dequantized_kernel_INT32( { int tid = (blockIdx.x * blockDim.x + threadIdx.x); if (tid < size_div_4) { - const float scale = (1.0f / __ldg(input_amax_ptr)) * (__ldg(weight_amax_ptr) / 127.0f); - int4 tmp = src[tid]; - int outIdx = tid; + const float scale = (1.0f / __ldg(input_amax_ptr)) * (__ldg(weight_amax_ptr) / 127.0f); + int4 tmp = src[tid]; + int outIdx = tid; float4 float4Tmp; float4Tmp.x = static_cast(tmp.x) * scale; @@ -102,26 +102,26 @@ __global__ void dequantized_kernel_INT32( { int tid = (blockIdx.x * blockDim.x + threadIdx.x); if (tid < size_div_4) { - const float scale = (1.0f / __ldg(input_amax_ptr)) * (__ldg(weight_amax_ptr) / 127.0f); - int4 tmp = src[tid]; - int outIdx = tid; + const float scale = (1.0f / __ldg(input_amax_ptr)) * (__ldg(weight_amax_ptr) / 127.0f); + int4 tmp = src[tid]; + int outIdx = tid; half4 half4Tmp; - half4Tmp.x = static_cast(static_cast(tmp.x) * scale); - half4Tmp.y = static_cast(static_cast(tmp.y) * scale); - half4Tmp.z = static_cast(static_cast(tmp.z) * scale); - half4Tmp.w = static_cast(static_cast(tmp.w) * scale); + half4Tmp.x = static_cast(static_cast(tmp.x) * scale); + half4Tmp.y = static_cast(static_cast(tmp.y) * scale); + half4Tmp.z = static_cast(static_cast(tmp.z) * scale); + half4Tmp.w = static_cast(static_cast(tmp.w) * scale); dst[outIdx] = half4Tmp; } } template -void invokeDequantization_INT32(T* dst, +void invokeDequantization_INT32(T* dst, const int32_t* src, - const int size, - cudaStream_t stream, - const float* input_amax_ptr, - const float* 
weight_amax_ptr) + const int size, + cudaStream_t stream, + const float* input_amax_ptr, + const float* weight_amax_ptr) { if (size % 4 != 0) { @@ -140,18 +140,18 @@ void invokeDequantization_INT32(T* dst, } } -template void invokeDequantization_INT32(float* dst, +template void invokeDequantization_INT32(float* dst, const int32_t* src, - const int size, - cudaStream_t stream, - const float* input_amax_ptr, - const float* weight_amax_ptr); + const int size, + cudaStream_t stream, + const float* input_amax_ptr, + const float* weight_amax_ptr); -template void invokeDequantization_INT32(half* dst, +template void invokeDequantization_INT32(half* dst, const int32_t* src, - const int size, - cudaStream_t stream, - const float* input_amax_ptr, - const float* weight_amax_ptr); + const int size, + cudaStream_t stream, + const float* input_amax_ptr, + const float* weight_amax_ptr); } // namespace fastertransformer \ No newline at end of file diff --git a/src/fastertransformer/kernels/dequantize_kernels.h b/src/fastertransformer/kernels/dequantize_kernels.h index e690c762b..6ec8e8881 100644 --- a/src/fastertransformer/kernels/dequantize_kernels.h +++ b/src/fastertransformer/kernels/dequantize_kernels.h @@ -26,11 +26,11 @@ template void invokeDequantization(T* dst, const int8_t* src, const int size, const float* scale_ptr, cudaStream_t stream); template -void invokeDequantization_INT32(T* dst, +void invokeDequantization_INT32(T* dst, const int32_t* src, - const int size, - cudaStream_t stream, - const float* input_amax_ptr, - const float* weight_amax_ptr); + const int size, + cudaStream_t stream, + const float* input_amax_ptr, + const float* weight_amax_ptr); } // namespace fastertransformer \ No newline at end of file diff --git a/src/fastertransformer/kernels/gen_relative_pos_bias.cu b/src/fastertransformer/kernels/gen_relative_pos_bias.cu index 12af8de2d..0cb1c58c3 100644 --- a/src/fastertransformer/kernels/gen_relative_pos_bias.cu +++ b/src/fastertransformer/kernels/gen_relative_pos_bias.cu @@ -31,18 +31,18 @@ namespace fastertransformer { // grid(window_size*window_size, head_num) // block(window_size*window_size) template -__global__ void gen_relative_pos_bias(T* relative_position_bias, - const T* relative_position_bias_table, +__global__ void gen_relative_pos_bias(T* relative_position_bias, + const T* relative_position_bias_table, const Tindex* relative_position_bias_index, - const int window_size, - const int head_num) + const int window_size, + const int head_num) { - const int h_in_window = blockIdx.x / window_size; - const int w_in_window = blockIdx.x % window_size; - const int h_in_token = threadIdx.x / window_size; - const int w_in_token = threadIdx.x % window_size; - const int head_idx = blockIdx.y; - const int elements_per_window = window_size * window_size; + const int h_in_window = blockIdx.x / window_size; + const int w_in_window = blockIdx.x % window_size; + const int h_in_token = threadIdx.x / window_size; + const int w_in_token = threadIdx.x % window_size; + const int head_idx = blockIdx.y; + const int elements_per_window = window_size * window_size; const size_t elements_per_window_2 = elements_per_window * elements_per_window; const size_t output_idx = head_idx * elements_per_window_2 + blockIdx.x * elements_per_window + threadIdx.x; if (output_idx < head_num * elements_per_window_2) { @@ -54,12 +54,12 @@ __global__ void gen_relative_pos_bias(T* relative_position_bias, } template -void invokeGenRelativePosBias(T* relative_position_bias, - const T* relative_position_bias_table, 
+void invokeGenRelativePosBias(T* relative_position_bias, + const T* relative_position_bias_table, const Tindex* relative_position_bias_index, - const int window_size, - const int head_num, - cudaStream_t stream) + const int window_size, + const int head_num, + cudaStream_t stream) { dim3 grid(window_size * window_size, head_num); dim3 block(window_size * window_size); @@ -75,32 +75,32 @@ void invokeGenRelativePosBias(T* relative_position_bias, /******************* instantiation ***********************/ -template void invokeGenRelativePosBias(float* relative_position_bias, +template void invokeGenRelativePosBias(float* relative_position_bias, const float* relative_position_bias_table, - const int* relative_position_bias_index, - const int window_size, - const int head_num, + const int* relative_position_bias_index, + const int window_size, + const int head_num, cudaStream_t stream); -template void invokeGenRelativePosBias(half* relative_position_bias, - const half* relative_position_bias_table, - const int* relative_position_bias_index, - const int window_size, - const int head_num, +template void invokeGenRelativePosBias(half* relative_position_bias, + const half* relative_position_bias_table, + const int* relative_position_bias_index, + const int window_size, + const int head_num, cudaStream_t stream); -template void invokeGenRelativePosBias(float* relative_position_bias, - const float* relative_position_bias_table, +template void invokeGenRelativePosBias(float* relative_position_bias, + const float* relative_position_bias_table, const int64_t* relative_position_bias_index, - const int window_size, - const int head_num, - cudaStream_t stream); + const int window_size, + const int head_num, + cudaStream_t stream); -template void invokeGenRelativePosBias(half* relative_position_bias, - const half* relative_position_bias_table, +template void invokeGenRelativePosBias(half* relative_position_bias, + const half* relative_position_bias_table, const int64_t* relative_position_bias_index, - const int window_size, - const int head_num, - cudaStream_t stream); + const int window_size, + const int head_num, + cudaStream_t stream); } // namespace fastertransformer diff --git a/src/fastertransformer/kernels/gen_relative_pos_bias.h b/src/fastertransformer/kernels/gen_relative_pos_bias.h index 1d560d721..67364c708 100644 --- a/src/fastertransformer/kernels/gen_relative_pos_bias.h +++ b/src/fastertransformer/kernels/gen_relative_pos_bias.h @@ -28,11 +28,11 @@ enum class PositionEmbeddingType { }; template -void invokeGenRelativePosBias(T* relative_position_bias, - const T* relative_position_bias_table, +void invokeGenRelativePosBias(T* relative_position_bias, + const T* relative_position_bias_table, const Tindex* relative_position_bias_index, - const int window_size, - const int head_num, - cudaStream_t stream); + const int window_size, + const int head_num, + cudaStream_t stream); } // namespace fastertransformer diff --git a/src/fastertransformer/kernels/gpt_kernels.cu b/src/fastertransformer/kernels/gpt_kernels.cu index bf3dc8455..f1f723193 100644 --- a/src/fastertransformer/kernels/gpt_kernels.cu +++ b/src/fastertransformer/kernels/gpt_kernels.cu @@ -14,157 +14,195 @@ * limitations under the License. */ +#ifndef CUDART_VERSION +#error CUDART_VERSION Undefined! 
+#elif (CUDART_VERSION >= 11050) +#include +#else +#include "3rdparty/cub/cub.cuh" +#endif #include "src/fastertransformer/kernels/gpt_kernels.h" +#include "src/fastertransformer/utils/memory_utils.h" namespace fastertransformer { -template -__global__ void start_id_embedding_position_lookups_kernel(T* from_tensor, - int* output_ids, - const T* embedding_table, - const T* pos_table, - const int* input_ids, - const int start_step, - const int length, - const int max_length, - const int batch_size, - const int hidden_units) +// PROMPT_SRC: 0 --> no prompts, 1 --> from loaded prompts, 2 --> from request prompts +template +__global__ void start_id_embedding_position_lookups_kernel(T* from_tensor, + int* output_ids, + const T* embedding_table, + const T* pos_table, + pPromptTuningParam prompt_param, + const int* input_ids, + const int start_step, + const int length, + const int max_length, + const int batch_size, + const int hidden_units) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size * length * hidden_units; index += blockDim.x * gridDim.x) { // transpose the input_ids [batch, length] (part of [batch, max_length]) to output_ids [length, batch] - if (index < batch_size * max_length) { - const int seq_id = index % max_length; - const int batch_id = index / max_length; - if (seq_id < length) { - output_ids[seq_id * batch_size + batch_id] = input_ids[index]; + if (OUTPUT_ID && index < batch_size * max_length) { + // for p/prompt_tuning (have prompt templates like [input1, prompt1, input2, prompt2]) + // we have to process it to like [input1, input2, prompt1, prompt2], and then remove the prompts during post + // processing + if (PROMPT_SRC > 0) { + if (index < batch_size) { + int no_prompt_output_seq_id = 0; +#pragma unroll 1 + for (int seq_id = 0; seq_id < max_length; seq_id++) { + int current_input_id = input_ids[index * max_length + seq_id]; + if (current_input_id < prompt_param.p_prompt_tuning_id_start) { + output_ids[no_prompt_output_seq_id * batch_size + index] = current_input_id; + no_prompt_output_seq_id++; + } + } + } + } + else { + const int seq_id = index % max_length; + const int batch_id = index / max_length; + if (seq_id < length) { + output_ids[seq_id * batch_size + batch_id] = input_ids[index]; + } } - // output_ids[index] = input_ids[index]; } // embedding lookup from word ids [batch, length] (part of [batch, max_length]) and [vocab, hidden] to generate // embedding [batch, length, hidden] - const int word_index = index / hidden_units; - const int word_index_row = word_index / length; - const int word_index_col = word_index % length; + const int word_index = index / hidden_units; + const int word_index_row = word_index / length; // batch_id + const int word_index_col = word_index % length; const int real_word_index = word_index_row * max_length + word_index_col; - const int step = start_step + word_index % length; - const int col_index = index % hidden_units; - T embedding = embedding_table[input_ids[real_word_index] * hidden_units + col_index]; - T pos_embed = pos_table == nullptr ? (T)0.f : pos_table[(step - 1) * hidden_units + col_index]; + const int step = start_step + word_index % length; + const int col_index = index % hidden_units; + const int input_id = input_ids == nullptr ? 
real_word_index : input_ids[real_word_index]; + const int prompt_id = input_id - prompt_param.p_prompt_tuning_id_start; + T embedding = (T)0.0f; + if (PROMPT_SRC > 0 && prompt_id >= 0) { + if (PROMPT_SRC == 1) { + // from loaded prompt embedding tables + embedding = + prompt_param.p_prompt_tuning_batch_weights[word_index_row][prompt_id * hidden_units + col_index]; + } + else { + // from request prompt embedding + embedding = + prompt_param + .request_prompt_embedding[word_index_row * prompt_param.request_prompt_max_length * hidden_units + + prompt_id * hidden_units + col_index]; + } + } + else { + embedding = embedding_table[input_id * hidden_units + col_index]; + } + T pos_embed = pos_table == nullptr ? (T)0.f : pos_table[(step - 1) * hidden_units + col_index]; from_tensor[index] = embedding + pos_embed; } } -template -__global__ void start_id_embedding_position_lookups_kernel(T* from_tensor, - const T* embedding_table, - const T* pos_table, - const int* input_ids, - const int start_step, - const int length, - const int max_length, - const int batch_size, - const int hidden_units) -{ - for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size * length * hidden_units; - index += blockDim.x * gridDim.x) { - // embedding lookup from word ids [batch, length] (part of [batch, max_length]) and [vocab, hidden] to generate - // embedding [batch, length, hidden] - const int word_index = index / hidden_units; - const int word_index_row = word_index / length; - const int word_index_col = word_index % length; - const int real_word_index = word_index_row * max_length + word_index_col; - const int step = start_step + word_index % length; - const int col_index = index % hidden_units; - T embedding = embedding_table[input_ids[real_word_index] * hidden_units + col_index]; - T pos_embed = pos_table == nullptr ? 
(T)0.f : pos_table[(step - 1) * hidden_units + col_index]; - from_tensor[index] = embedding + pos_embed; - } -} +#define WORD_POS_EMBEDDING_LOOPUP_KERNEL(OUTPUT_ID, PROMPT_SRC) \ + start_id_embedding_position_lookups_kernel<<>>(from_tensor, \ + output_ids, \ + embedding_table, \ + pos_table, \ + prompt_param, \ + input_ids, \ + start_step, \ + length, \ + max_length, \ + batch_size, \ + hidden_units); template -void invokeInputIdsEmbeddingLookupPosEncoding(T* from_tensor, - int* output_ids, - const T* embedding_table, - const T* pos_table, - const int* input_ids, - const int start_step, - const int length, - const int max_length, - const int batch_size, - const int hidden_units, - cudaStream_t stream) +void invokeInputIdsEmbeddingLookupPosEncoding(T* from_tensor, + int* output_ids, + const T* embedding_table, // can also be inputs_embeds + const T* pos_table, + pPromptTuningParam prompt_param, + const int* input_ids, + const int start_step, + const int length, + const int max_length, + const int batch_size, + const int hidden_units, + cudaStream_t stream) { - dim3 grid(min(batch_size * length, 65536)); - dim3 block(min(hidden_units, 512)); - if (output_ids == nullptr) { - start_id_embedding_position_lookups_kernel<<>>(from_tensor, - embedding_table, - pos_table, - input_ids, - start_step, - length, - max_length, - batch_size, - hidden_units); + dim3 grid(min(batch_size * length, 65536)); + dim3 block(min(hidden_units, 512)); + const bool has_output_ids = output_ids != nullptr; + FT_CHECK(!(has_output_ids && input_ids == nullptr)); + + if (has_output_ids) { + if (prompt_param.use_request_p_prompt_embedding) { + WORD_POS_EMBEDDING_LOOPUP_KERNEL(true, 2); + } + else if (prompt_param.p_prompt_tuning_batch_weights != nullptr) { + WORD_POS_EMBEDDING_LOOPUP_KERNEL(true, 1); + } + else { + WORD_POS_EMBEDDING_LOOPUP_KERNEL(true, 0); + } } else { - start_id_embedding_position_lookups_kernel<<>>(from_tensor, - output_ids, - embedding_table, - pos_table, - input_ids, - start_step, - length, - max_length, - batch_size, - hidden_units); + if (prompt_param.use_request_p_prompt_embedding) { + WORD_POS_EMBEDDING_LOOPUP_KERNEL(false, 2); + } + else if (prompt_param.p_prompt_tuning_batch_weights != nullptr) { + WORD_POS_EMBEDDING_LOOPUP_KERNEL(false, 1); + } + else { + WORD_POS_EMBEDDING_LOOPUP_KERNEL(false, 0); + } } } -template void invokeInputIdsEmbeddingLookupPosEncoding(float* from_tensor, - int* output_ids, - const float* embedding_table, - const float* pos_table, - const int* input_ids, - const int start_step, - const int length, - const int max_length, - const int batch_size, - const int hidden_units, - cudaStream_t stream); - -template void invokeInputIdsEmbeddingLookupPosEncoding(half* from_tensor, - int* output_ids, - const half* embedding_table, - const half* pos_table, - const int* input_ids, - const int start_step, - const int length, - const int max_length, - const int batch_size, - const int hidden_units, - cudaStream_t stream); +template void invokeInputIdsEmbeddingLookupPosEncoding(float* from_tensor, + int* output_ids, + const float* embedding_table, + const float* pos_table, + pPromptTuningParam prompt_param, + const int* input_ids, + const int start_step, + const int length, + const int max_length, + const int batch_size, + const int hidden_units, + cudaStream_t stream); + +template void invokeInputIdsEmbeddingLookupPosEncoding(half* from_tensor, + int* output_ids, + const half* embedding_table, + const half* pos_table, + pPromptTuningParam prompt_param, + const int* input_ids, + const 
int start_step, + const int length, + const int max_length, + const int batch_size, + const int hidden_units, + cudaStream_t stream); #ifdef ENABLE_BF16 -template void invokeInputIdsEmbeddingLookupPosEncoding(__nv_bfloat16* from_tensor, - int* output_ids, - const __nv_bfloat16* embedding_table, - const __nv_bfloat16* pos_table, - const int* input_ids, - const int start_step, - const int length, - const int max_length, - const int batch_size, - const int hidden_units, - cudaStream_t stream); +template void invokeInputIdsEmbeddingLookupPosEncoding(__nv_bfloat16* from_tensor, + int* output_ids, + const __nv_bfloat16* embedding_table, + const __nv_bfloat16* pos_table, + pPromptTuningParam<__nv_bfloat16> prompt_param, + const int* input_ids, + const int start_step, + const int length, + const int max_length, + const int batch_size, + const int hidden_units, + cudaStream_t stream); #endif template __global__ void inputIdsEmbeddingLookupPosEncodingSoftPrompt(inputIdsEmbeddingLookupPosEncodingSoftPromptParam param) { - // 1. Copy the input ids to output ids and transpose ouptut ids to [seq_len, batch_size, beam_width]. + // 1. Copy the input ids to output ids and transpose output ids to [seq_len, batch_size, beam_width]. // 2. Embedding lookup by input ids and concat with soft prompt. The axis of concatenation is on axis of seq_len. // Assume batch size is 2 and prompts are [[t1, t2], [t3], [t4, t5]], input_ids are [[s1, s2], [s3], [s4]] @@ -187,12 +225,12 @@ __global__ void inputIdsEmbeddingLookupPosEncodingSoftPrompt(inputIdsEmbeddingLo // ouptut_ids need to add padding in the beginning for soft prompting. if (index < param.batch_size * param.beam_width * param.max_input_length) { - int tmp_index = index; - const int seq_id = tmp_index % param.max_input_length; - tmp_index = (tmp_index - seq_id) / param.max_input_length; - const int beam_id = tmp_index % param.beam_width; - tmp_index = (tmp_index - beam_id) / param.beam_width; - const int batch_id = tmp_index % param.batch_size; + int tmp_index = index; + const int seq_id = tmp_index % param.max_input_length; + tmp_index = (tmp_index - seq_id) / param.max_input_length; + const int beam_id = tmp_index % param.beam_width; + tmp_index = (tmp_index - beam_id) / param.beam_width; + const int batch_id = tmp_index % param.batch_size; if (seq_id < param.max_input_length) { param.output_ids[(param.prefix_soft_prompt_lengths[batch_id] + seq_id) * param.batch_size * param.beam_width @@ -203,28 +241,28 @@ __global__ void inputIdsEmbeddingLookupPosEncodingSoftPrompt(inputIdsEmbeddingLo // embedding lookup from word ids [batch, beam, length] (part of [batch, beam, max_input_length]), [vocab, // hidden] and [batch, max_prefix_soft_prompt_length, hidden] to generate embedding [batch, beam, length + // max_prefix_soft_prompt_length, hidden] - int tmp_index = index; + int tmp_index = index; const int hidden_id = tmp_index % param.hidden_units; - tmp_index = (tmp_index - hidden_id) / param.hidden_units; - const int seq_id = tmp_index % (param.max_prefix_soft_prompt_length + param.max_input_length); - tmp_index = (tmp_index - seq_id) / (param.max_prefix_soft_prompt_length + param.max_input_length); - const int beam_id = tmp_index % param.beam_width; - tmp_index = (tmp_index - beam_id) / param.beam_width; - const int batch_id = tmp_index % param.batch_size; - T embedding = + tmp_index = (tmp_index - hidden_id) / param.hidden_units; + const int seq_id = tmp_index % (param.max_prefix_soft_prompt_length + param.max_input_length); + tmp_index = (tmp_index - seq_id) / 
(param.max_prefix_soft_prompt_length + param.max_input_length); + const int beam_id = tmp_index % param.beam_width; + tmp_index = (tmp_index - beam_id) / param.beam_width; + const int batch_id = tmp_index % param.batch_size; + T embedding = (seq_id < param.prefix_soft_prompt_lengths[batch_id]) ? - (T)param + (T)param .prefix_soft_prompt_embedding[batch_id * param.max_prefix_soft_prompt_length * param.hidden_units + seq_id * param.hidden_units + hidden_id] : - param.embedding_table[param.input_ids[batch_id * param.beam_width * param.max_input_length + param.embedding_table[param.input_ids[batch_id * param.beam_width * param.max_input_length + beam_id * param.max_input_length + (seq_id - param.prefix_soft_prompt_lengths[batch_id])] * param.hidden_units + hidden_id]; - T pos_embed = param.pos_table == nullptr ? - (T)0.0f : - param.pos_table[(param.start_step + seq_id - 1) * param.hidden_units + hidden_id]; + T pos_embed = param.pos_table == nullptr ? + (T)0.0f : + param.pos_table[(param.start_step + seq_id - 1) * param.hidden_units + hidden_id]; param.from_tensor[index] = embedding + pos_embed; if (seq_id == 0 && hidden_id == 0) { @@ -260,9 +298,9 @@ __global__ void transposeAxis01(T* out, T* in, const int dim0, const int dim1, c int index = threadIdx.x + blockIdx.x * blockDim.x; if (index < dim0 * dim1 * dim2) { const int input_dim2_index = index % dim2; - index = (index - input_dim2_index) / dim2; + index = (index - input_dim2_index) / dim2; const int input_dim1_index = index % dim1; - index = (index - input_dim1_index) / dim1; + index = (index - input_dim1_index) / dim1; const int input_dim0_index = index % dim0; out[input_dim1_index * dim0 * dim2 + input_dim0_index * dim2 + input_dim2_index] = @@ -297,9 +335,9 @@ __global__ void transposeAxis01(T* out, T* in, const int* in_skipping_dim1, cons int index = threadIdx.x + blockIdx.x * blockDim.x; if (index < dim0 * dim1) { const int input_dim1_index = index % dim1; - index = (index - input_dim1_index) / dim1; + index = (index - input_dim1_index) / dim1; const int input_dim0_index = index % dim0; - const int in_offset = in_skipping_dim1 == nullptr ? 0 : in_skipping_dim1[input_dim1_index] * dim1; + const int in_offset = in_skipping_dim1 == nullptr ? 
0 : in_skipping_dim1[input_dim1_index] * dim1; out[input_dim1_index * dim0 + input_dim0_index] = in[in_offset + input_dim0_index * dim1 + input_dim1_index]; } @@ -317,17 +355,24 @@ void invokeTransposeAxis01( template void invokeTransposeAxis01( int* out, int* in, const int* in_skipping_dim1, const int dim0, const int dim1, cudaStream_t stream); -template -__global__ void buildDecoderAttentionMaskKernel(T* attention_mask, const int* sequence_lengths, const int max_seq_len) +template +__global__ void buildDecoderAttentionMaskKernel(T* attention_mask, + const int* sequence_lengths, + const int* prefix_prompt_lengths, + const int max_seq_len, + const int max_prompt_length) { // sequence_lengths: [batch_size] - // attention_mask: [batch_size, 1, max_seq_len, max_seq_len] - attention_mask += blockIdx.x * max_seq_len * max_seq_len; - const int length = sequence_lengths[blockIdx.x]; - for (int i = threadIdx.x; i < max_seq_len * max_seq_len; i += blockDim.x) { - int row_id = i / max_seq_len; - int col_id = i % max_seq_len; - if (row_id < length && col_id <= row_id) { + // attention_mask: [batch_size, 1, max_seq_len, max_seq_len + max_prompt_length] + const int max_prompt_seq_length = max_seq_len + max_prompt_length; + const int mask_size_per_seq = max_seq_len * max_prompt_seq_length; + attention_mask += blockIdx.x * mask_size_per_seq; + const int seq_length = sequence_lengths[blockIdx.x]; + const int prompt_length = PREFIX_PROMPT ? prefix_prompt_lengths[blockIdx.x] : 0; + for (int i = threadIdx.x; i < mask_size_per_seq; i += blockDim.x) { + int row_id = i / max_prompt_seq_length; + int col_id = i % max_prompt_seq_length; + if (row_id < seq_length && col_id <= (row_id + prompt_length)) { attention_mask[i] = (T)(1.0f); } else { @@ -337,97 +382,121 @@ __global__ void buildDecoderAttentionMaskKernel(T* attention_mask, const int* se } template -void invokeBuildDecoderAttentionMask( - T* attention_mask, const int* sequence_lengths, const int batch_size, const int max_seq_len, cudaStream_t stream) +void invokeBuildDecoderAttentionMask(T* attention_mask, + const int* sequence_lengths, + const int* prefix_prompt_lengths, + const int batch_size, + const int max_seq_len, + const int max_prompt_length, + cudaStream_t stream) { - buildDecoderAttentionMaskKernel<<>>(attention_mask, sequence_lengths, max_seq_len); + if (max_prompt_length == 0) { + buildDecoderAttentionMaskKernel<<>>( + attention_mask, sequence_lengths, prefix_prompt_lengths, max_seq_len, max_prompt_length); + } + else { + buildDecoderAttentionMaskKernel<<>>( + attention_mask, sequence_lengths, prefix_prompt_lengths, max_seq_len, max_prompt_length); + } } -template void invokeBuildDecoderAttentionMask(float* attention_mask, - const int* sequence_lengths, - const int batch_size, - const int max_seq_len, +template void invokeBuildDecoderAttentionMask(float* attention_mask, + const int* sequence_lengths, + const int* prefix_prompt_lengths, + const int batch_size, + const int max_seq_len, + const int max_prompt_length, cudaStream_t stream); -template void invokeBuildDecoderAttentionMask(half* attention_mask, - const int* sequence_lengths, - const int batch_size, - const int max_seq_len, +template void invokeBuildDecoderAttentionMask(half* attention_mask, + const int* sequence_lengths, + const int* prefix_prompt_lengths, + const int batch_size, + const int max_seq_len, + const int max_prompt_length, cudaStream_t stream); #ifdef ENABLE_BF16 template void invokeBuildDecoderAttentionMask(__nv_bfloat16* attention_mask, - const int* sequence_lengths, - 
const int batch_size, - const int max_seq_len, - cudaStream_t stream); + const int* sequence_lengths, + const int* prefix_prompt_lengths, + const int batch_size, + const int max_seq_len, + const int max_prompt_length, + cudaStream_t stream); #endif template -__launch_bounds__(1024, 1) __global__ void lookupHiddenStateOfLastToken(T* from_tensor, - const T* hidden_state, +__launch_bounds__(1024, 1) __global__ void lookupHiddenStateOfLastToken(T* from_tensor, + const T* hidden_state, const int* input_lengths, - const int max_input_length, - const int batch_size, - const int hidden_units) + const int max_input_length, + const int batch_size, + const int hidden_units) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size * hidden_units; index += blockDim.x * gridDim.x) { const int col_index = index % hidden_units; - const int batch_id = index / hidden_units; - from_tensor[index] = hidden_state[batch_id * max_input_length * hidden_units + const int batch_id = index / hidden_units; + from_tensor[index] = hidden_state[batch_id * max_input_length * hidden_units + (input_lengths[batch_id] - 1) * hidden_units + col_index]; } } template -void invokeLookupHiddenStateOfLastToken(T* from_tensor, - const T* hidden_state, - const int* input_lengths, - const int max_input_length, - const int batch_size, - const int hidden_units, +void invokeLookupHiddenStateOfLastToken(T* from_tensor, + const T* hidden_state, + const int* input_lengths, + const int max_input_length, + const int batch_size, + const int hidden_units, cudaStream_t stream) { const int grid_size = (int)(ceil(batch_size * hidden_units / 1024.)); - dim3 grid(min(grid_size, 65536)); - dim3 block(min(hidden_units, 1024)); + dim3 grid(min(grid_size, 65536)); + dim3 block(min(hidden_units, 1024)); lookupHiddenStateOfLastToken<<>>( from_tensor, hidden_state, input_lengths, max_input_length, batch_size, hidden_units); } -template void invokeLookupHiddenStateOfLastToken(float* from_tensor, +template void invokeLookupHiddenStateOfLastToken(float* from_tensor, const float* hidden_state, - const int* input_lengths, - const int max_input_length, - const int batch_size, - const int hidden_units, + const int* input_lengths, + const int max_input_length, + const int batch_size, + const int hidden_units, cudaStream_t stream); -template void invokeLookupHiddenStateOfLastToken(half* from_tensor, - const half* hidden_state, - const int* input_lengths, - const int max_input_length, - const int batch_size, - const int hidden_units, +template void invokeLookupHiddenStateOfLastToken(half* from_tensor, + const half* hidden_state, + const int* input_lengths, + const int max_input_length, + const int batch_size, + const int hidden_units, cudaStream_t stream); #ifdef ENABLE_BF16 -template void invokeLookupHiddenStateOfLastToken(__nv_bfloat16* from_tensor, +template void invokeLookupHiddenStateOfLastToken(__nv_bfloat16* from_tensor, const __nv_bfloat16* hidden_state, - const int* input_lengths, - const int max_input_length, - const int batch_size, - const int hidden_units, - cudaStream_t stream); + const int* input_lengths, + const int max_input_length, + const int batch_size, + const int hidden_units, + cudaStream_t stream); #endif -__global__ void tileGptInputs(int* tiled_input_ids, - int* tiled_input_lengths, - const int* input_ids, - const int* input_lengths, - const int max_input_length) +template +__global__ void tileGptPromptInputs(int* tiled_input_ids, + int* tiled_input_lengths, + int* tiled_prompt_lengths, + const int* input_ids, + const 
int* input_lengths, + const int* prefix_prompt_lengths, + const int max_input_length) { if (threadIdx.x == 0) { tiled_input_lengths[blockIdx.x * gridDim.y + blockIdx.y] = input_lengths[blockIdx.x]; + if (PREFIX_PROMPT) { + tiled_prompt_lengths[blockIdx.x * gridDim.y + blockIdx.y] = prefix_prompt_lengths[blockIdx.x]; + } } for (int index = threadIdx.x; index < max_input_length; index += blockDim.x) { tiled_input_ids[(blockIdx.x * gridDim.y + blockIdx.y) * max_input_length + index] = @@ -435,74 +504,552 @@ __global__ void tileGptInputs(int* tiled_input_ids, } } -void invokeTileGptInputs(int* tiled_input_ids, - int* tiled_input_lengths, - const int* input_ids, - const int* input_lengths, - const int batch_size, - const int beam_width, - const int max_input_length, - cudaStream_t stream) +void invokeTileGptPromptInputs(int* tiled_input_ids, + int* tiled_input_lengths, + int* tiled_prompt_lengths, + const int* input_ids, + const int* input_lengths, + const int* prefix_prompt_lengths, + const int batch_size, + const int beam_width, + const int max_input_length, + cudaStream_t stream) { dim3 grid(batch_size, beam_width); dim3 block(min(1024, max_input_length)); - tileGptInputs<<>>( - tiled_input_ids, tiled_input_lengths, input_ids, input_lengths, max_input_length); + if (prefix_prompt_lengths != nullptr) { + tileGptPromptInputs<<>>(tiled_input_ids, + tiled_input_lengths, + tiled_prompt_lengths, + input_ids, + input_lengths, + prefix_prompt_lengths, + max_input_length); + } + else { + tileGptPromptInputs<<>>(tiled_input_ids, + tiled_input_lengths, + tiled_prompt_lengths, + input_ids, + input_lengths, + prefix_prompt_lengths, + max_input_length); + } } -bool hasDiffRuntimeArgs(const std::unordered_map* input_tensors) +void invokeTileGptInputs(int* tiled_input_ids, + int* tiled_input_lengths, + const int* input_ids, + const int* input_lengths, + const int batch_size, + const int beam_width, + const int max_input_length, + cudaStream_t stream) { - // runtime_top_k [1] or [batch_size] on cpu, optional. 
- // runtime_top_p [1] or [batch_size] on cpu, optional - // beam_search_diversity_rate [1] or [batch_size] on cpu, optional - // temperature [1] or [batch_size] on cpu, optional - // len_penalty [1] or [batch_size] on cpu, optional - // repetition_penalty [1] or [batch_size] on cpu, optional - // random_seed [1] or [batch_size] on cpu, optional - - std::vector check_list = {"runtime_top_k", - "runtime_top_p", - "beam_search_diversity_rate", - "temperature", - "len_penalty", - "repetition_penalty", - "random_seed"}; - - for (int i = 0; i < (int)check_list.size(); i++) { - if (input_tensors->count(check_list[i])) { - auto tensor = input_tensors->at(check_list[i]); - FT_CHECK(tensor.shape.size() == 1); - for (int j = 1; j < (int)tensor.shape[0]; j++) { - const void* data = tensor.data; - switch (tensor.type) { - case TYPE_FP32: - if (((const float*)data)[0] != ((const float*)data)[j]) { - return true; - } - break; - case TYPE_INT32: - if (((const int*)data)[0] != ((const int*)data)[j]) { - return true; - } - break; - case TYPE_UINT32: - if (((const uint*)data)[0] != ((const uint*)data)[j]) { - return true; - } - break; - case TYPE_UINT64: - if (((const unsigned long long int*)data)[0] != ((const unsigned long long int*)data)[j]) { - return true; - } - break; - default: - FT_CHECK_WITH_INFO(false, check_list[i] + ": " + tensor.toString() + " is invalid."); - break; - } - } + invokeTileGptPromptInputs(tiled_input_ids, + tiled_input_lengths, + nullptr, + input_ids, + input_lengths, + nullptr, + batch_size, + beam_width, + max_input_length, + stream); +} + +void setSeqLimitLen(uint32_t* seq_len_d, Tensor seq_len, int limit_len_offset, int batch_size) +{ + std::vector seq_len_h(batch_size); + for (int i = 0; i < batch_size; i++) { + seq_len_h[i] = seq_len.getPtr()[i] + limit_len_offset; + } + cudaH2Dcpy(seq_len_d, seq_len_h.data(), batch_size); +} + +template +__global__ void +find_context_dups(int* shared_contexts, const int* input_ids, const size_t batch_size, const size_t input_seq_len) +{ + /* We compare all context pairs (i, j), with i (tgt) < j (src) , to detect duplicate + * inputs. If there's a match between i and j, we store i at the + * j-th position of shared_context. So that we know that j can be + * represented by i. shared_contexts is initialized like shared_contexts[i] = i + * and when there's a match, we actually use shared_contexts[j] = min(shared_contexts[j], i) + * so that in the end, shared_contexts effectively contains an index + * to the match with the lowest index context. + * Note that shared_contexts[i] <= i, a property that will be used when uncompacting + * inputs. + */ + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ bool match; + + /* Each block is responsible for a (i, j) pair. To map the block space to + * the i < j space, we need to convert a linear addressing to a triangle, of + * size (batch_size * (batch_size - 1)) / 2 + * For more information, check https://en.wikipedia.org/wiki/Triangular_number + */ + + // blockIdx = [0, 1, 2, ... 
n(n-1)/2] -> base_index = [0, 1, 1, 2, 2, 2, 3, 3, 3, 3, ..., n - 2] + const int base_index = floorf(0.5f * (sqrtf(1 + 8 * blockIdx.x) - 1)); + const int src_idx = base_index + 1; // base_index \in [1, batch_size) + + const int rev_base_index = base_index * (base_index + 1) / 2; + const int tgt_idx = blockIdx.x - rev_base_index; // tgt_idx \in [0, src_idx) + + const int padded_length = TB_SIZE * ((input_seq_len + TB_SIZE - 1) / TB_SIZE); + + int sum = 0; + for (int i = threadIdx.x; i < padded_length; i += TB_SIZE) { + int compare = + (i >= input_seq_len) ? 1 : input_ids[tgt_idx * input_seq_len + i] == input_ids[src_idx * input_seq_len + i]; + + sum = BlockReduce(temp_storage).Sum(compare); + + if (threadIdx.x == 0) { + match = (sum == TB_SIZE); + } + + __syncthreads(); + + if (!match) { + break; + } + } + + if (threadIdx.x == 0 && match) { + atomicMin(&shared_contexts[src_idx], tgt_idx); + } +} + +constexpr int DUPS_INDICES_BLOCK_SIZE = 128; + +__global__ void generate_dups_indices(int* batch_to_compact, + int* compact_to_batch, + int* compact_size, + const int* shared_contexts, + const size_t batch_size, + const size_t input_seq_len) +{ + const int padded_batchsize = blockDim.x * ((batch_size + blockDim.x - 1) / blockDim.x); + + typedef cub::BlockScan BlockScan; + __shared__ typename BlockScan::TempStorage temp_storage; + __shared__ int scan_offset; + + int scan = 0; + for (int batch = threadIdx.x; batch < padded_batchsize; batch += blockDim.x) { + bool masked = (batch >= batch_size); + bool first_iter = batch < blockDim.x; + + int is_first_occur = masked ? 0 : shared_contexts[batch] == batch; + BlockScan(temp_storage).ExclusiveSum(is_first_occur, scan); + + if (!masked && is_first_occur) { + int compact_idx = scan + (first_iter ? 0 : scan_offset); + // Context rep. writes initial index + batch_to_compact[batch] = compact_idx; + compact_to_batch[compact_idx] = batch; + } + + if (threadIdx.x == blockDim.x - 1) { + scan_offset = scan + is_first_occur + (first_iter ? 0 : scan_offset); + } + + __syncthreads(); + + if (!masked && !is_first_occur) { + // Fill the rest of batch_to_compact based on what rep. 
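The linear-block-index-to-pair mapping described in the comment above can be sanity-checked on the host. The snippet below is a sketch (function names are illustrative): it reproduces base_index = floor((sqrt(1 + 8k) - 1) / 2), src = base_index + 1, tgt = k - base_index * (base_index + 1) / 2, and verifies that this enumerates every pair tgt < src exactly once.

#include <cassert>
#include <cmath>

// Maps a linear index k in [0, n*(n-1)/2) to the pair (tgt, src) with tgt < src,
// exactly as find_context_dups derives it from blockIdx.x.
static void linearToPair(int k, int& tgt, int& src)
{
    const int base = static_cast<int>(std::floor(0.5 * (std::sqrt(1.0 + 8.0 * k) - 1.0)));
    src = base + 1;                   // src in [1, n)
    tgt = k - base * (base + 1) / 2;  // tgt in [0, src)
}

int main()
{
    const int n = 6;  // plays the role of batch_size
    int k = 0;
    for (int src = 1; src < n; src++) {
        for (int tgt = 0; tgt < src; tgt++) {
            int t, s;
            linearToPair(k++, t, s);
            assert(t == tgt && s == src);
        }
    }
    return 0;
}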
wrote + const int src_idx = batch_to_compact[shared_contexts[batch]]; + batch_to_compact[batch] = src_idx; + } + } + + if (threadIdx.x == 0) { + *compact_size = scan_offset; + } +} + +__global__ void init_shared_contexts(int* shared_contexts, const size_t batch_size) +{ + const int global_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (global_idx >= batch_size) { + return; + } + shared_contexts[global_idx] = global_idx; +} + +void invokeFindContextDups(int* shared_contexts, + int* batch_to_compact, + int* compact_to_batch, + int* compact_size, + const int* input_ids, + const size_t batch_size, + const size_t input_seq_len, + cudaStream_t stream) +{ + dim3 block{512}; + dim3 grid{((int)batch_size + block.x - 1) / block.x}; + init_shared_contexts<<>>(shared_contexts, batch_size); + + grid = dim3{(unsigned int)(batch_size * (batch_size - 1)) / 2}; + if (input_seq_len <= 128) { + block = 128; + find_context_dups<128><<>>(shared_contexts, input_ids, batch_size, input_seq_len); + } + else { + block = 256; + find_context_dups<256><<>>(shared_contexts, input_ids, batch_size, input_seq_len); + } + + generate_dups_indices<<<1, DUPS_INDICES_BLOCK_SIZE, 0, stream>>>( + batch_to_compact, compact_to_batch, compact_size, shared_contexts, batch_size, input_seq_len); +} + +template +__global__ void compact_inputs(T* compact_input, + T* compact_attention_mask, + int* compact_input_lengths, + const T* decoder_input, + const T* decoder_mask, + const int* input_lengths, + const int* compact_idx, + size_t compact_size, + size_t seq_len, + size_t hidden_dimension) +{ + const int global_idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (global_idx < compact_size * seq_len * hidden_dimension) { + const int h_id = global_idx % hidden_dimension; + const int seq_id = (global_idx / hidden_dimension) % seq_len; + const int batch_id = global_idx / (hidden_dimension * seq_len); + + compact_input[global_idx] = decoder_input[(compact_idx[batch_id] * seq_len + seq_id) * hidden_dimension + h_id]; + } + + if (global_idx < compact_size * seq_len * seq_len) { + const int seq1_id = global_idx % seq_len; + const int seq2_id = (global_idx / seq_len) % seq_len; + const int batch_id = global_idx / (seq_len * seq_len); + + compact_attention_mask[global_idx] = + decoder_mask[(compact_idx[batch_id] * seq_len + seq2_id) * seq_len + seq1_id]; + } + + if (global_idx < compact_size) { + compact_input_lengths[global_idx] = input_lengths[compact_idx[global_idx]]; + } +} + +template +void invokeCompactInputs(T* compact_input, + T* compact_attention_mask, + int* compact_input_lengths, + const T* decoder_input, + const T* decoder_mask, + const int* input_lengths, + const int* compact_idx, + size_t compact_size, + size_t seq_len, + size_t hidden_dimension, + cudaStream_t stream) +{ + /* Compact relevant decoder_layer inputs based on the identical contexts. + * For example, decoder_input is [batch_size, seq_len, H]. It's compacted + * into compact_input [compact_size, seq_len, H] such that + * compact_input[i, ...] = decoder_input[compact_idx[i], ...] 
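The bookkeeping built by generate_dups_indices can be written as a simple sequential loop. The sketch below (illustrative names, std::vector instead of device memory) relies on the property noted earlier that shared_contexts[b] <= b, so a duplicate always finds its representative's compact index already written.

#include <vector>

// Sequential reference: contexts with shared_contexts[b] == b are representatives and
// receive consecutive compact indices; duplicates reuse their representative's index.
int generateDupsIndicesRef(std::vector<int>& batch_to_compact,
                           std::vector<int>& compact_to_batch,
                           const std::vector<int>& shared_contexts)
{
    int compact_size = 0;
    for (int b = 0; b < static_cast<int>(shared_contexts.size()); b++) {
        if (shared_contexts[b] == b) {  // first occurrence
            batch_to_compact[b]            = compact_size;
            compact_to_batch[compact_size] = b;
            compact_size++;
        }
        else {  // duplicate of an earlier, lower-indexed context
            batch_to_compact[b] = batch_to_compact[shared_contexts[b]];
        }
    }
    return compact_size;
}

For example, shared_contexts = {0, 1, 0, 1} yields batch_to_compact = {0, 1, 0, 1}, compact_to_batch = {0, 1}, and a compact size of 2.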
*/ + const size_t elems_n = compact_size * seq_len * max(hidden_dimension, seq_len); + const dim3 blockDim(512); + const dim3 gridDim((elems_n + 512 - 1) / 512); + + compact_inputs<<>>(compact_input, + compact_attention_mask, + compact_input_lengths, + decoder_input, + decoder_mask, + input_lengths, + compact_idx, + compact_size, + seq_len, + hidden_dimension); +} + +#define INSTANTIATE_INVOKE_COMPACT_INPUTS(T) \ + template void invokeCompactInputs(T * compact_input, \ + T * compact_attention_mask, \ + int* compact_input_lengths, \ + const T* decoder_input, \ + const T* decoder_mask, \ + const int* input_lengths, \ + const int* compact_idx, \ + size_t compact_size, \ + size_t seq_len, \ + size_t hidden_dimension, \ + cudaStream_t stream) +INSTANTIATE_INVOKE_COMPACT_INPUTS(half); +INSTANTIATE_INVOKE_COMPACT_INPUTS(float); +#ifdef ENABLE_BF16 +INSTANTIATE_INVOKE_COMPACT_INPUTS(__nv_bfloat16); +#endif + +template +__global__ void uncompact_outputs(T* uncompact_buffer, + const T* compact_buffer, + const int* batch_to_compact_idx, + size_t batch_size, + size_t buffer_stride) +{ + /* Uncompact a buffer IN of size [Compact, Stride] into OUT of size [Batch, Stride] + * so that \forall i, OUT[i, :] = IN[batch_to_compact_idx[i], :] + */ + const int global_idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (global_idx >= batch_size * buffer_stride) { + return; + } + + const int stride_idx = global_idx % buffer_stride; + const int batch_idx = global_idx / buffer_stride; + + const int src = batch_to_compact_idx[batch_idx]; + uncompact_buffer[global_idx] = compact_buffer[src * buffer_stride + stride_idx]; +} + +template +void invokeUnCompactOutputs(T* uncompact_buffer, + const T* compact_buffer, + const int* batch_to_compact_idx, + size_t batch_size, + size_t buffer_stride, + cudaStream_t stream) +{ + const size_t num_elems = batch_size * buffer_stride; + const dim3 blockDim(1024); + const dim3 gridDim((num_elems + blockDim.x - 1) / blockDim.x); + + uncompact_outputs<<>>( + uncompact_buffer, compact_buffer, batch_to_compact_idx, batch_size, buffer_stride); +} + +#define INSTANTIATE_INVOKE_UNCOMPACT_OUTPUTS(T) \ + template void invokeUnCompactOutputs(T* uncompact_buffer, \ + const T* compact_buffer, \ + const int* batch_to_compact_idx, \ + size_t batch_size, \ + size_t buffer_stride, \ + cudaStream_t stream) +INSTANTIATE_INVOKE_UNCOMPACT_OUTPUTS(half); +INSTANTIATE_INVOKE_UNCOMPACT_OUTPUTS(float); +#ifdef ENABLE_BF16 +INSTANTIATE_INVOKE_UNCOMPACT_OUTPUTS(__nv_bfloat16); +#endif + +template +__global__ void uncompact_caches(T* uncompact_k_cache, + T* uncompact_v_cache, + const T* compact_k_cache, + const T* compact_v_cache, + const int* batch_to_compact_idx, + size_t batch_size, + size_t num_heads, + size_t max_seq_len, + size_t seq_len, + size_t size_per_head, + size_t local_batch_size, + size_t ite) +{ + const int hidden_dimension = num_heads * size_per_head; + const int num_elems_per_batch = seq_len * hidden_dimension; + const int num_elems_cache = batch_size * num_elems_per_batch; + const int x_size = 16 / sizeof(T); + + for (int global_idx = blockIdx.x * blockDim.x + threadIdx.x; global_idx < 2 * num_elems_cache; + global_idx += blockDim.x * gridDim.x) { + + const bool handle_k = global_idx < num_elems_cache; + const T* const cache_src = handle_k ? compact_k_cache : compact_v_cache; + T* const cache_dst = handle_k ? uncompact_k_cache : uncompact_v_cache; + const int idx = handle_k ? 
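uncompact_outputs is the inverse of the compaction step: every batch row is filled from its representative's row in the compact buffer. A minimal host sketch of that semantics (illustrative name, T fixed to float):

#include <cstddef>
#include <vector>

// Host reference of uncompact_outputs: row b of the batch-sized output buffer is a copy
// of row batch_to_compact_idx[b] of the compact buffer (buffer_stride elements per row).
void uncompactOutputsRef(std::vector<float>& uncompact_buffer,
                         const std::vector<float>& compact_buffer,
                         const std::vector<int>& batch_to_compact_idx,
                         size_t buffer_stride)
{
    for (size_t b = 0; b < batch_to_compact_idx.size(); b++) {
        const size_t src = static_cast<size_t>(batch_to_compact_idx[b]);
        for (size_t s = 0; s < buffer_stride; s++) {
            uncompact_buffer[b * buffer_stride + s] = compact_buffer[src * buffer_stride + s];
        }
    }
}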
global_idx : global_idx - num_elems_cache; + + const int src_offset = idx % num_elems_per_batch; + const int batch_idx = idx / num_elems_per_batch; + const int batch_src = batch_to_compact_idx[batch_idx] - ite * local_batch_size; + + if (batch_src < 0 || batch_src >= local_batch_size) { + continue; } + + int dst_offset; + if (handle_k) { + const int i0 = idx % (x_size * seq_len); + const int i1 = (idx / (x_size * seq_len)) % (num_heads * size_per_head / x_size); + dst_offset = i1 * max_seq_len * x_size + i0; + } + else { + const int i0 = idx % (size_per_head * seq_len); + const int i1 = (idx / (size_per_head * seq_len)) % (num_heads); + dst_offset = i1 * max_seq_len * size_per_head + i0; + } + + cache_dst[batch_idx * max_seq_len * hidden_dimension + dst_offset] = + cache_src[batch_src * num_elems_per_batch + src_offset]; + } +} + +template +void invokeUnCompactCaches(T* uncompact_k_cache, + T* uncompact_v_cache, + const T* compact_k_cache, + const T* compact_v_cache, + const int* batch_to_compact_idx, + size_t batch_size, + size_t num_heads, + size_t max_seq_len, + size_t seq_len, + size_t size_per_head, + size_t local_batch_size, + size_t ite, + cudaStream_t stream) +{ + const dim3 blockDim(512); + const dim3 gridDim(1024); + uncompact_caches<<>>(uncompact_k_cache, + uncompact_v_cache, + compact_k_cache, + compact_v_cache, + batch_to_compact_idx, + batch_size, + num_heads, + max_seq_len, + seq_len, + size_per_head, + local_batch_size, + ite); +} + +#define INSTANTIATE_INVOKE_UNCOMPACT_CACHES(T) \ + template void invokeUnCompactCaches(T* uncompact_k_cache, \ + T* uncompact_v_cache, \ + const T* compact_k_cache, \ + const T* compact_v_cache, \ + const int* batch_to_compact_idx, \ + size_t batch_size, \ + size_t num_heads, \ + size_t max_seq_len, \ + size_t seq_len, \ + size_t size_per_head, \ + size_t local_batch_size, \ + size_t ite, \ + cudaStream_t stream) +INSTANTIATE_INVOKE_UNCOMPACT_CACHES(half); +INSTANTIATE_INVOKE_UNCOMPACT_CACHES(float); +#ifdef ENABLE_BF16 +INSTANTIATE_INVOKE_UNCOMPACT_CACHES(__nv_bfloat16); +#endif + +template +__global__ void update_padding_count(int* total_padding_count, + const int* input_lengths, + const int* tiled_prompt_lengths, + size_t max_input_length, + size_t max_prompt_length, + size_t batch_size, + size_t beam_width) +{ + const int gidx = blockIdx.x * blockDim.x + threadIdx.x; + + if (gidx >= batch_size * beam_width) { + return; + } + + const int batch_idx = gidx / beam_width; + + total_padding_count[gidx] += + PREFIX_PROMPT ? 
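The only non-trivial part of uncompact_caches is the destination offset: per batch entry the K cache effectively keeps the interleaved layout [num_heads * size_per_head / x, max_seq_len, x] with x = 16 / sizeof(T) (8 for half, 4 for float), while the V cache keeps [num_heads, max_seq_len, size_per_head], so uncompacting just re-bases a seq_len-long span onto the max_seq_len axis. A host sketch of the per-batch offset math (function names are illustrative):

#include <cstddef>

// Per-batch mirror of the dst_offset computation in uncompact_caches. src_offset indexes
// the compact cache of one batch entry (seq_len positions); the returned offset indexes
// the full cache of one batch entry (max_seq_len positions).
size_t kCacheDstOffset(size_t src_offset, size_t seq_len, size_t max_seq_len, size_t x)
{
    const size_t i0 = src_offset % (x * seq_len);  // position inside one [seq_len, x] tile
    const size_t i1 = src_offset / (x * seq_len);  // which (head, size_per_head / x) tile
    return i1 * max_seq_len * x + i0;
}

size_t vCacheDstOffset(size_t src_offset, size_t seq_len, size_t max_seq_len, size_t size_per_head)
{
    const size_t i0 = src_offset % (size_per_head * seq_len);  // position inside one head
    const size_t i1 = src_offset / (size_per_head * seq_len);  // which head
    return i1 * max_seq_len * size_per_head + i0;
}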
(max_input_length + max_prompt_length - input_lengths[batch_idx] - tiled_prompt_lengths[gidx]) : + (max_input_length - input_lengths[batch_idx]); +} + +void invokeUpdatePaddingCount(int* total_padding_count, + const int* input_lengths, + const int* tiled_prompt_lengths, + size_t max_input_length, + size_t max_prompt_length, + size_t batch_size, + size_t beam_width, + cudaStream_t stream) +{ + dim3 blockSize(256); + dim3 gridSize((batch_size * beam_width + blockSize.x - 1) / blockSize.x); + + if (tiled_prompt_lengths != nullptr) { + update_padding_count<<>>(total_padding_count, + input_lengths, + tiled_prompt_lengths, + max_input_length, + max_prompt_length, + batch_size, + beam_width); + } + else { + update_padding_count<<>>(total_padding_count, + input_lengths, + tiled_prompt_lengths, + max_input_length, + max_prompt_length, + batch_size, + beam_width); + } +} + +template +__global__ void mask_padding_tokens(bool* masked_tokens, + const int* input_lengths, + const int* tiled_prefix_prompt_lengths, + const size_t memory_len, + const size_t max_input_length, + const size_t initial_step, + size_t beam_width) +{ + const int seq_len = PREFIX_PROMPT ? + (input_lengths[blockIdx.x / beam_width] + tiled_prefix_prompt_lengths[blockIdx.x]) : + input_lengths[blockIdx.x / beam_width]; + for (int step = initial_step + seq_len + threadIdx.x; step < initial_step + max_input_length; step += blockDim.x) { + masked_tokens[blockIdx.x * memory_len + step % memory_len] = true; + } +} + +void invokeMaskPaddingTokens(bool* masked_tokens, + const int* input_lengths, + const int* tiled_prefix_prompt_lengths, + const size_t memory_len, + const size_t max_input_length, + const size_t initial_step, + size_t batch_size, + size_t beam_width, + cudaStream_t stream) +{ + dim3 blockSize(128); + dim3 gridSize(batch_size * beam_width); + if (tiled_prefix_prompt_lengths != nullptr) { + mask_padding_tokens<<>>(masked_tokens, + input_lengths, + tiled_prefix_prompt_lengths, + memory_len, + max_input_length, + initial_step, + beam_width); + } + else { + mask_padding_tokens<<>>(masked_tokens, + input_lengths, + tiled_prefix_prompt_lengths, + memory_len, + max_input_length, + initial_step, + beam_width); } - return false; } -} // namespace fastertransformer \ No newline at end of file +} // namespace fastertransformer diff --git a/src/fastertransformer/kernels/gpt_kernels.h b/src/fastertransformer/kernels/gpt_kernels.h index 5f0fcdbfc..aecc9e5ce 100644 --- a/src/fastertransformer/kernels/gpt_kernels.h +++ b/src/fastertransformer/kernels/gpt_kernels.h @@ -24,38 +24,54 @@ namespace fastertransformer { -template -void invokeInputIdsEmbeddingLookupPosEncoding(T* from_tensor, - int* output_ids, - const T* embedding_table, - const T* pos_table, - const int* input_ids, - const int start_step, - const int length, - const int max_length, - const int batch_size, - const int hidden_units, - cudaStream_t stream); - template struct inputIdsEmbeddingLookupPosEncodingSoftPromptParam { - T* from_tensor; - int* output_ids; - int* input_lengths; - const T* embedding_table; - const T* pos_table; + T* from_tensor; + int* output_ids; + int* input_lengths; + const T* embedding_table; + const T* pos_table; const float* prefix_soft_prompt_embedding; - const int* prefix_soft_prompt_lengths; - int* input_ids; - int start_step; - int max_input_length; - int max_prefix_soft_prompt_length; - int batch_size; - int beam_width; - int hidden_units; + const int* prefix_soft_prompt_lengths; + int* input_ids; + int start_step; + int max_input_length; + int 
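A quick numeric check of the update above (a sketch with arbitrarily chosen values): with max_input_length = 8, max_prompt_length = 4, an input of length 5 and a prefix prompt of length 3, the sequence fills 5 + 3 = 8 of the 8 + 4 = 12 leading positions, so 4 positions count as padding; without a prefix prompt the count is simply max_input_length - input_length.

#include <cassert>

int main()
{
    const int max_input_length = 8, max_prompt_length = 4;
    const int input_length = 5, prompt_length = 3;

    // PREFIX_PROMPT == true branch of update_padding_count
    const int with_prompt = max_input_length + max_prompt_length - input_length - prompt_length;
    assert(with_prompt == 4);

    // PREFIX_PROMPT == false branch
    const int without_prompt = max_input_length - input_length;
    assert(without_prompt == 3);
    return 0;
}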
max_prefix_soft_prompt_length; + int batch_size; + int beam_width; + int hidden_units; cudaStream_t stream; }; +template +struct pPromptTuningParam { + // Batch number of ptrs, each ptr is the ptr of the specific p/prompt tuning weights for this sequence + const T** p_prompt_tuning_batch_weights = nullptr; + // The start id of p_prompt_tuning token ids (based on the tokenizer) + // PROMPT_0 --> p_prompt_tuning_id_start; PROMPT_1 --> p_prompt_tuning_id_start + 1; ... + const int p_prompt_tuning_id_start = 0; + // Request prompt embeddding's max length + const int request_prompt_max_length = 0; + // Whether or not use the request prompt embeddings + const bool use_request_p_prompt_embedding = false; + // Request prompt embeddings + const T* request_prompt_embedding = nullptr; +}; + +template +void invokeInputIdsEmbeddingLookupPosEncoding(T* from_tensor, + int* output_ids, + const T* embedding_table, + const T* pos_table, + pPromptTuningParam prompt_param, + const int* input_ids, + const int start_step, + const int length, + const int max_length, + const int batch_size, + const int hidden_units, + cudaStream_t stream); + template void invokeInputIdsEmbeddingLookupPosEncodingSoftPrompt(inputIdsEmbeddingLookupPosEncodingSoftPromptParam param); @@ -67,35 +83,58 @@ void invokeTransposeAxis01( T* out, T* in, const int* in_skipping_dim1, const int dim0, const int dim1, cudaStream_t stream); template -void invokeBuildDecoderAttentionMask( - T* attention_mask, const int* sequence_lengths, const int batch_size, const int max_seq_len, cudaStream_t stream); +void invokeBuildDecoderAttentionMask(T* attention_mask, + const int* sequence_lengths, + const int* prefix_prompt_lengths, + const int batch_size, + const int max_seq_len, + const int max_prompt_length, + cudaStream_t stream); template -void invokeLookupHiddenStateOfLastToken(T* from_tensor, - const T* hidden_state, - const int* input_lengths, - const int max_input_length, - const int batch_size, - const int hidden_units, +void invokeLookupHiddenStateOfLastToken(T* from_tensor, + const T* hidden_state, + const int* input_lengths, + const int max_input_length, + const int batch_size, + const int hidden_units, cudaStream_t stream); -void invokeTileGptInputs(int* tiled_input_ids, - int* tiled_input_lengths, - const int* input_ids, - const int* input_lengths, - const int batch_size, - const int beam_width, - const int max_input_length, +void invokeTileGptPromptInputs(int* tiled_input_ids, + int* tiled_input_lengths, + int* tiled_prompt_lengths, + const int* input_ids, + const int* input_lengths, + const int* prefix_prompt_lengths, + const int batch_size, + const int beam_width, + const int max_input_length, + cudaStream_t stream); + +void invokeTileGptInputs(int* tiled_input_ids, + int* tiled_input_lengths, + const int* input_ids, + const int* input_lengths, + const int batch_size, + const int beam_width, + const int max_input_length, cudaStream_t stream); -bool hasDiffRuntimeArgs(const std::unordered_map* input_tensors); +void invokeFindContextDups(int* shared_contexts, + int* batch_to_compact, + int* compact_to_batch, + int* compact_size, + const int* input_ids, + const size_t batch_size, + const size_t input_seq_len, + cudaStream_t stream = 0); template void handleOptArg(const std::unordered_map* input_tensors, - const std::string& arg_name, - T* d_ptr, - T default_value, - size_t size) + const std::string& arg_name, + T* d_ptr, + T default_value, + size_t size) { if (input_tensors->find(arg_name) != input_tensors->end()) { 
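The id convention documented in pPromptTuningParam can be summarised with a tiny helper. This is only an illustration of the convention (the struct and method names below are hypothetical; the actual lookup happens inside invokeInputIdsEmbeddingLookupPosEncoding): ids below p_prompt_tuning_id_start are ordinary vocabulary tokens, while p_prompt_tuning_id_start + k selects row k of that sequence's prompt-embedding table.

// Hypothetical helper illustrating the p/prompt-tuning id convention only.
struct PromptIdConvention {
    int p_prompt_tuning_id_start;  // first id reserved for prompt tokens (PROMPT_0)

    bool isPromptToken(int token_id) const
    {
        return token_id >= p_prompt_tuning_id_start;
    }
    int promptRow(int token_id) const
    {
        return token_id - p_prompt_tuning_id_start;  // PROMPT_k -> row k of the prompt table
    }
};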
FT_CHECK(input_tensors->at(arg_name).size() == size); @@ -106,4 +145,92 @@ void handleOptArg(const std::unordered_map* input_tensors, } } +void setSeqLimitLen(uint32_t* seq_len_d, Tensor seq_len, int limit_len_offset, int batch_size); + +template +void invokeCompactInputs(T* compact_input, + T* compact_attention_mask, + int* compact_input_lengths, + const T* decoder_input, + const T* decoder_mask, + const int* input_lengths, + const int* compact_idx, + size_t compact_size, + size_t seq_len, + size_t hidden_dimension, + cudaStream_t stream = 0); + +template +void invokeUnCompactOutputs(T* uncompact_buffer, + const T* compact_buffer, + const int* batch_to_compact_idx, + size_t batch_size, + size_t buffer_stride, + cudaStream_t stream = 0); + +template +void invokeUnCompactCaches(T* uncompact_k_cache, + T* uncompact_v_cache, + const T* compact_k_cache, + const T* compact_v_cache, + const int* batch_to_compact_idx, + size_t batch_size, + size_t num_heads, + size_t max_seq_len, + size_t seq_len, + size_t size_per_head, + size_t local_batch_size, + size_t ite, + cudaStream_t stream = 0); + +void invokeUpdatePaddingCount(int* total_padding_count, + const int* input_lengths, + const int* tiled_prompt_lengths, + size_t max_input_length, + size_t max_prompt_length, + size_t batch_size, + size_t beam_width, + cudaStream_t stream = 0); + +inline void invokeUpdatePaddingCount(int* total_padding_count, + const int* input_lengths, + size_t max_input_length, + size_t batch_size, + size_t beam_width, + cudaStream_t stream = 0) +{ + invokeUpdatePaddingCount( + total_padding_count, input_lengths, (const int*)nullptr, max_input_length, 0, batch_size, beam_width, stream); +} + +void invokeMaskPaddingTokens(bool* masked_tokens, + const int* input_lengths, + const int* tiled_prefix_prompt_lengths, + const size_t memory_len, + const size_t max_input_length, + const size_t initial_step, + size_t batch_size, + size_t beam_width, + cudaStream_t stream = 0); + +inline void invokeMaskPaddingTokens(bool* masked_tokens, + const int* input_lengths, + const size_t memory_len, + const size_t max_input_length, + const size_t initial_step, + size_t batch_size, + size_t beam_width, + cudaStream_t stream = 0) +{ + invokeMaskPaddingTokens(masked_tokens, + input_lengths, + (const int*)nullptr, + memory_len, + max_input_length, + initial_step, + batch_size, + beam_width, + stream); +} + } // namespace fastertransformer diff --git a/src/fastertransformer/kernels/layernorm_int8_kernels.cu b/src/fastertransformer/kernels/layernorm_int8_kernels.cu index 297f4dd1e..1be6f5364 100644 --- a/src/fastertransformer/kernels/layernorm_int8_kernels.cu +++ b/src/fastertransformer/kernels/layernorm_int8_kernels.cu @@ -22,27 +22,27 @@ namespace fastertransformer { // input1/input2/output matrix with layout of cublasLt CUBLASLT_ORDER_COL32 (m*n) //(grid, block) must be (m, n) // for per_channel_quantization for weight -__global__ void add_bias_input_layernorm_COL32_int32I_DataTypeO(float* output, +__global__ void add_bias_input_layernorm_COL32_int32I_DataTypeO(float* output, const int32_t* input1, - const float* input2, - const float* bias, - const float* gamma, - const float* beta, - int m, - int n, - const float* weight_amax, - const float* input1_amax_ptr) + const float* input2, + const float* bias, + const float* gamma, + const float* beta, + int m, + int n, + const float* weight_amax, + const float* input1_amax_ptr) { const float input1_amax = __ldg(input1_amax_ptr); - int col_start = threadIdx.x; + int col_start = threadIdx.x; __shared__ 
float s_mean; __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; + float mean = 0.0f; + float variance = 0.0f; float local_out; - int outIdx = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)); + int outIdx = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)); float tmp = static_cast(__ldg(input1 + outIdx)) * __ldg(weight_amax + col_start) * input1_amax * 0.000062f; //(1/127/127); @@ -70,16 +70,16 @@ __global__ void add_bias_input_layernorm_COL32_int32I_DataTypeO(float* output, output[outIdx] = local_out; } -__global__ void add_bias_input_layernorm_COL32_int32I_DataTypeO(half2* output, - const int2* input1, - const half2* input2, - const half2* bias, - const half2* gamma, - const half2* beta, - int m, - int n, +__global__ void add_bias_input_layernorm_COL32_int32I_DataTypeO(half2* output, + const int2* input1, + const half2* input2, + const half2* bias, + const half2* gamma, + const half2* beta, + int m, + int n, const float2* weight_amax, - const float* input1_amax_ptr) + const float* input1_amax_ptr) { int col_start = threadIdx.x << 1; @@ -87,13 +87,13 @@ __global__ void add_bias_input_layernorm_COL32_int32I_DataTypeO(half2* output, __shared__ float s_mean; __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; + float mean = 0.0f; + float variance = 0.0f; float2 local_out; - int outIdx = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)) >> 1; + int outIdx = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)) >> 1; - const int2 input1Tmp = __ldg(input1 + outIdx); + const int2 input1Tmp = __ldg(input1 + outIdx); const float2 weightTmp = __ldg(weight_amax + threadIdx.x); float2 addTmp2; @@ -101,9 +101,9 @@ __global__ void add_bias_input_layernorm_COL32_int32I_DataTypeO(half2* output, addTmp2.y = static_cast(input1Tmp.y) * weightTmp.y * input1_amax * 0.000062f; //(1/127/127); const half2 inputTmp = __ldg(input2 + outIdx); - const half2 biasTmp = __ldg(bias + threadIdx.x); + const half2 biasTmp = __ldg(bias + threadIdx.x); - local_out = __half22float2(__hadd2(inputTmp, biasTmp)); + local_out = __half22float2(__hadd2(inputTmp, biasTmp)); local_out.x = local_out.x + addTmp2.x; local_out.y = local_out.y + addTmp2.y; @@ -123,9 +123,9 @@ __global__ void add_bias_input_layernorm_COL32_int32I_DataTypeO(half2* output, } __syncthreads(); - float2 outputTmp; + float2 outputTmp; const half2 gammaTmp = __ldg(gamma + threadIdx.x); - const half2 betaTmp = __ldg(beta + threadIdx.x); + const half2 betaTmp = __ldg(beta + threadIdx.x); outputTmp.x = (local_out.x * s_variance) * static_cast(gammaTmp.x) + static_cast(betaTmp.x); outputTmp.y = (local_out.y * s_variance) * static_cast(gammaTmp.y) + static_cast(betaTmp.y); @@ -134,17 +134,17 @@ __global__ void add_bias_input_layernorm_COL32_int32I_DataTypeO(half2* output, } template -void invokeAddBiasResidualLayerNormCol32(T* output, +void invokeAddBiasResidualLayerNormCol32(T* output, const int32_t* input1, - const T* input2, - const T* bias, - const T* gamma, - const T* beta, - int m, - int n, - cudaStream_t stream, - const float* weight_amax, - const float* input1_amax_ptr) + const T* input2, + const T* bias, + const T* gamma, + const T* beta, + int m, + int n, + cudaStream_t stream, + const float* weight_amax, + const float* input1_amax_ptr) { dim3 grid(m); @@ -178,77 +178,77 @@ void invokeAddBiasResidualLayerNormCol32(T* output, } } -template void invokeAddBiasResidualLayerNormCol32(float* output, +template void 
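The index expression ((col & 0xffffffe0) * m + (row << 5) + (col & 31)), used by every kernel in this file, is the cublasLt CUBLASLT_ORDER_COL32 address of element (row, col) in an m x n matrix: columns are grouped into tiles of 32, tiles are stored one after another, and inside a tile the data is row-major. A small host check (a sketch; names are illustrative):

#include <cassert>

// COL32 address of (row, col) in an m x n matrix: tile index (col / 32) times the tile
// size 32 * m, plus row * 32, plus the position inside the tile (col % 32).
static int col32Index(int row, int col, int m)
{
    return ((col & 0xffffffe0) * m) + (row << 5) + (col & 31);
}

int main()
{
    const int m = 7;  // arbitrary row count
    for (int row = 0; row < m; row++) {
        for (int col = 0; col < 96; col++) {  // three 32-column tiles
            const int expected = (col / 32) * (32 * m) + row * 32 + (col % 32);
            assert(col32Index(row, col, m) == expected);
        }
    }
    return 0;
}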
invokeAddBiasResidualLayerNormCol32(float* output, const int32_t* input1, - const float* input2, - const float* bias, - const float* gamma, - const float* beta, - int m, - int n, - cudaStream_t stream, - const float* weight_amax, - const float* input1_amax_ptr); -template void invokeAddBiasResidualLayerNormCol32(half* output, + const float* input2, + const float* bias, + const float* gamma, + const float* beta, + int m, + int n, + cudaStream_t stream, + const float* weight_amax, + const float* input1_amax_ptr); +template void invokeAddBiasResidualLayerNormCol32(half* output, const int32_t* input1, - const half* input2, - const half* bias, - const half* gamma, - const half* beta, - int m, - int n, - cudaStream_t stream, - const float* weight_amax, - const float* input1_amax_ptr); + const half* input2, + const half* bias, + const half* gamma, + const half* beta, + int m, + int n, + cudaStream_t stream, + const float* weight_amax, + const float* input1_amax_ptr); // input1/input2/out matrix with layout of cublasLt CUBLASLT_ORDER_COL32 (m*n) //(grid, block) must be (m, n/4) // using char4 template -__global__ void add_bias_input_layernorm_COL32_int8IO(int8_t* output, +__global__ void add_bias_input_layernorm_COL32_int8IO(int8_t* output, const int8_t* input1, const int8_t* input2, - const T* bias, - const T* gamma, - const T* beta, - int m, - int n, - const float* input1_deQFactor_ptr, - const float* input2_deQFactor_ptr, - const float* output_scale_ptr) + const T* bias, + const T* gamma, + const T* beta, + int m, + int n, + const float* input1_deQFactor_ptr, + const float* input2_deQFactor_ptr, + const float* output_scale_ptr) { const float input1_deQFactor = __ldg(input1_deQFactor_ptr); const float input2_deQFactor = __ldg(input2_deQFactor_ptr); - const float output_scale = __ldg(output_scale_ptr); - int col_start = threadIdx.x << 2; + const float output_scale = __ldg(output_scale_ptr); + int col_start = threadIdx.x << 2; __shared__ float s_mean; __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; + float mean = 0.0f; + float variance = 0.0f; - float local_out[4]; - int outIdx = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)) >> 2; - char4* outTmpPtr = (char4*)output; + float local_out[4]; + int outIdx = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)) >> 2; + char4* outTmpPtr = (char4*)output; char4* input1TmpPtr = (char4*)input1; char4* input2TmpPtr = (char4*)input2; - char4 input1Tmp = __ldg(input1TmpPtr + outIdx); - char4 input2Tmp = __ldg(input2TmpPtr + outIdx); + char4 input1Tmp = __ldg(input1TmpPtr + outIdx); + char4 input2Tmp = __ldg(input2TmpPtr + outIdx); int col_start_tmp = col_start; - local_out[0] = static_cast(input2Tmp.x) * input2_deQFactor + local_out[0] = static_cast(input2Tmp.x) * input2_deQFactor + static_cast(input1Tmp.x) * input1_deQFactor + static_cast(__ldg(bias + col_start_tmp)); col_start_tmp = col_start_tmp + 1; - local_out[1] = static_cast(input2Tmp.y) * input2_deQFactor + local_out[1] = static_cast(input2Tmp.y) * input2_deQFactor + static_cast(input1Tmp.y) * input1_deQFactor + static_cast(__ldg(bias + col_start_tmp)); col_start_tmp = col_start_tmp + 1; - local_out[2] = static_cast(input2Tmp.z) * input2_deQFactor + local_out[2] = static_cast(input2Tmp.z) * input2_deQFactor + static_cast(input1Tmp.z) * input1_deQFactor + static_cast(__ldg(bias + col_start_tmp)); col_start_tmp = col_start_tmp + 1; - local_out[3] = static_cast(input2Tmp.w) * input2_deQFactor + local_out[3] = static_cast(input2Tmp.w) * 
input2_deQFactor + static_cast(input1Tmp.w) * input1_deQFactor + static_cast(__ldg(bias + col_start_tmp)); @@ -262,7 +262,7 @@ __global__ void add_bias_input_layernorm_COL32_int8IO(int8_t* output, local_out[1] = local_out[1] - s_mean; local_out[2] = local_out[2] - s_mean; local_out[3] = local_out[3] - s_mean; - variance = blockReduceSum(local_out[0] * local_out[0] + local_out[1] * local_out[1] + variance = blockReduceSum(local_out[0] * local_out[0] + local_out[1] * local_out[1] + local_out[2] * local_out[2] + local_out[3] * local_out[3]); if (threadIdx.x == 0) { s_variance = variance * __fdividef(1.0f, n) + 1e-6f; @@ -274,17 +274,17 @@ __global__ void add_bias_input_layernorm_COL32_int8IO(int8_t* output, + static_cast(__ldg(beta + col_start)); input2Tmp.x = float_to_int8_rn(local_out[0] * output_scale); - col_start = col_start + 1; + col_start = col_start + 1; local_out[1] = (local_out[1] * s_variance) * static_cast(__ldg(gamma + col_start)) + static_cast(__ldg(beta + col_start)); input2Tmp.y = float_to_int8_rn(local_out[1] * output_scale); - col_start = col_start + 1; + col_start = col_start + 1; local_out[2] = (local_out[2] * s_variance) * static_cast(__ldg(gamma + col_start)) + static_cast(__ldg(beta + col_start)); input2Tmp.z = float_to_int8_rn(local_out[2] * output_scale); - col_start = col_start + 1; + col_start = col_start + 1; local_out[3] = (local_out[3] * s_variance) * static_cast(__ldg(gamma + col_start)) + static_cast(__ldg(beta + col_start)); input2Tmp.w = float_to_int8_rn(local_out[3] * output_scale); @@ -293,50 +293,50 @@ __global__ void add_bias_input_layernorm_COL32_int8IO(int8_t* output, } template<> -__global__ void add_bias_input_layernorm_COL32_int8IO(int8_t* output, +__global__ void add_bias_input_layernorm_COL32_int8IO(int8_t* output, const int8_t* input1, const int8_t* input2, - const half2* bias, - const half2* gamma, - const half2* beta, - int m, - int n, - const float* input1_deQFactor_ptr, - const float* input2_deQFactor_ptr, - const float* output_scale_ptr) + const half2* bias, + const half2* gamma, + const half2* beta, + int m, + int n, + const float* input1_deQFactor_ptr, + const float* input2_deQFactor_ptr, + const float* output_scale_ptr) { const float input1_deQFactor = __ldg(input1_deQFactor_ptr); const float input2_deQFactor = __ldg(input2_deQFactor_ptr); - const float output_scale = __ldg(output_scale_ptr); - int col_start = threadIdx.x << 2; + const float output_scale = __ldg(output_scale_ptr); + int col_start = threadIdx.x << 2; __shared__ float s_mean; __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; + float mean = 0.0f; + float variance = 0.0f; - float local_out[4]; - int outIdx = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)) >> 2; - char4* outTmpPtr = (char4*)output; + float local_out[4]; + int outIdx = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)) >> 2; + char4* outTmpPtr = (char4*)output; char4* input1TmpPtr = (char4*)input1; char4* input2TmpPtr = (char4*)input2; - char4 input1Tmp = __ldg(input1TmpPtr + outIdx); - char4 input2Tmp = __ldg(input2TmpPtr + outIdx); + char4 input1Tmp = __ldg(input1TmpPtr + outIdx); + char4 input2Tmp = __ldg(input2TmpPtr + outIdx); - int col_start_tmp = col_start; - half2 biasTmp = __ldg(bias + (col_start_tmp >> 1)); - local_out[0] = static_cast(input2Tmp.x) * input2_deQFactor + int col_start_tmp = col_start; + half2 biasTmp = __ldg(bias + (col_start_tmp >> 1)); + local_out[0] = static_cast(input2Tmp.x) * input2_deQFactor + 
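These kernels use plain per-tensor scales: int8 inputs are dequantized by multiplying with their *_deQFactor, the layer norm runs in float, and the result is requantized with output_scale through float_to_int8_rn. The host sketch below mirrors that scheme; the round-to-nearest rounding and the saturation bounds are my reading of float_to_int8_rn, so treat them as an assumption. The 0.000062f constant used for int32 GEMM outputs elsewhere in this file is simply an approximation of 1 / (127 * 127).

#include <algorithm>
#include <cmath>
#include <cstdint>

// Per-tensor int8 dequantize / requantize, mirroring the device-side scheme.
static float dequantize(int8_t q, float deQFactor)
{
    return static_cast<float>(q) * deQFactor;
}

static int8_t quantize(float x, float output_scale)
{
    const float r = std::nearbyint(x * output_scale);                    // round to nearest
    return static_cast<int8_t>(std::min(127.0f, std::max(-128.0f, r)));  // saturate
}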
static_cast(input1Tmp.x) * input1_deQFactor + static_cast(biasTmp.x); col_start_tmp = col_start_tmp + 1; - local_out[1] = static_cast(input2Tmp.y) * input2_deQFactor + local_out[1] = static_cast(input2Tmp.y) * input2_deQFactor + static_cast(input1Tmp.y) * input1_deQFactor + static_cast(biasTmp.y); col_start_tmp = col_start_tmp + 1; - biasTmp = __ldg(bias + (col_start_tmp >> 1)); - local_out[2] = static_cast(input2Tmp.z) * input2_deQFactor + biasTmp = __ldg(bias + (col_start_tmp >> 1)); + local_out[2] = static_cast(input2Tmp.z) * input2_deQFactor + static_cast(input1Tmp.z) * input1_deQFactor + static_cast(biasTmp.x); col_start_tmp = col_start_tmp + 1; - local_out[3] = static_cast(input2Tmp.w) * input2_deQFactor + local_out[3] = static_cast(input2Tmp.w) * input2_deQFactor + static_cast(input1Tmp.w) * input1_deQFactor + static_cast(biasTmp.y); mean = blockReduceSum(local_out[0] + local_out[1] + local_out[2] + local_out[3]); @@ -349,7 +349,7 @@ __global__ void add_bias_input_layernorm_COL32_int8IO(int8_t* output, local_out[1] = local_out[1] - s_mean; local_out[2] = local_out[2] - s_mean; local_out[3] = local_out[3] - s_mean; - variance = blockReduceSum(local_out[0] * local_out[0] + local_out[1] * local_out[1] + variance = blockReduceSum(local_out[0] * local_out[0] + local_out[1] * local_out[1] + local_out[2] * local_out[2] + local_out[3] * local_out[3]); if (threadIdx.x == 0) { s_variance = variance * __fdividef(1.0f, n) + 1e-6f; @@ -358,43 +358,43 @@ __global__ void add_bias_input_layernorm_COL32_int8IO(int8_t* output, __syncthreads(); col_start_tmp = col_start >> 1; - biasTmp = __ldg(gamma + col_start_tmp); + biasTmp = __ldg(gamma + col_start_tmp); half2 betaTmp = __ldg(beta + col_start_tmp); local_out[0] = (local_out[0] * s_variance) * static_cast(biasTmp.x) + static_cast(betaTmp.x); - input2Tmp.x = float_to_int8_rn(local_out[0] * output_scale); + input2Tmp.x = float_to_int8_rn(local_out[0] * output_scale); - col_start = col_start + 1; + col_start = col_start + 1; local_out[1] = (local_out[1] * s_variance) * static_cast(biasTmp.y) + static_cast(betaTmp.y); - input2Tmp.y = float_to_int8_rn(local_out[1] * output_scale); + input2Tmp.y = float_to_int8_rn(local_out[1] * output_scale); - col_start = col_start + 1; + col_start = col_start + 1; col_start_tmp = col_start >> 1; - biasTmp = __ldg(gamma + col_start_tmp); - betaTmp = __ldg(beta + col_start_tmp); - local_out[2] = (local_out[2] * s_variance) * static_cast(biasTmp.x) + static_cast(betaTmp.x); - input2Tmp.z = float_to_int8_rn(local_out[2] * output_scale); + biasTmp = __ldg(gamma + col_start_tmp); + betaTmp = __ldg(beta + col_start_tmp); + local_out[2] = (local_out[2] * s_variance) * static_cast(biasTmp.x) + static_cast(betaTmp.x); + input2Tmp.z = float_to_int8_rn(local_out[2] * output_scale); - col_start = col_start + 1; + col_start = col_start + 1; local_out[3] = (local_out[3] * s_variance) * static_cast(biasTmp.y) + static_cast(betaTmp.y); - input2Tmp.w = float_to_int8_rn(local_out[3] * output_scale); + input2Tmp.w = float_to_int8_rn(local_out[3] * output_scale); outTmpPtr[outIdx] = input2Tmp; } template -void invokeAddBiasResidualLayerNormCol32(int8_t* output, +void invokeAddBiasResidualLayerNormCol32(int8_t* output, const int8_t* input1, const int8_t* input2, - const T* bias, - const T* gamma, - const T* beta, - int m, - int n, - cudaStream_t stream, - const float* input1_deQFactor_ptr, - const float* input2_deQFactor_ptr, - const float* output_scale_ptr) + const T* bias, + const T* gamma, + const T* beta, + int m, + int n, + 
cudaStream_t stream, + const float* input1_deQFactor_ptr, + const float* input2_deQFactor_ptr, + const float* output_scale_ptr) { dim3 grid(m); dim3 block(n / 4); @@ -427,57 +427,57 @@ void invokeAddBiasResidualLayerNormCol32(int8_t* output, } } -template void invokeAddBiasResidualLayerNormCol32(int8_t* output, +template void invokeAddBiasResidualLayerNormCol32(int8_t* output, const int8_t* input1, const int8_t* input2, - const float* bias, - const float* gamma, - const float* beta, - int m, - int n, - cudaStream_t stream, - const float* input1_deQFactor_ptr, - const float* input2_deQFactor_ptr, - const float* output_scale_ptr); - -template void invokeAddBiasResidualLayerNormCol32(int8_t* output, + const float* bias, + const float* gamma, + const float* beta, + int m, + int n, + cudaStream_t stream, + const float* input1_deQFactor_ptr, + const float* input2_deQFactor_ptr, + const float* output_scale_ptr); + +template void invokeAddBiasResidualLayerNormCol32(int8_t* output, const int8_t* input1, const int8_t* input2, - const half* bias, - const half* gamma, - const half* beta, - int m, - int n, - cudaStream_t stream, - const float* input1_deQFactor_ptr, - const float* input2_deQFactor_ptr, - const float* output_scale_ptr); + const half* bias, + const half* gamma, + const half* beta, + int m, + int n, + cudaStream_t stream, + const float* input1_deQFactor_ptr, + const float* input2_deQFactor_ptr, + const float* output_scale_ptr); // input1/input2/out matrix with layout of cublasLt CUBLASLT_ORDER_COL32 (m*n) //(grid, block) must be (m, n) template -__global__ void add_bias_input_layernorm_COL32_int8I_DataTypeO(T* output, +__global__ void add_bias_input_layernorm_COL32_int8I_DataTypeO(T* output, const int8_t* input1, const int8_t* input2, - const T* bias, - const T* gamma, - const T* beta, - int m, - int n, - const float* input1_deQFactor_ptr, - const float* input2_deQFactor_ptr) + const T* bias, + const T* gamma, + const T* beta, + int m, + int n, + const float* input1_deQFactor_ptr, + const float* input2_deQFactor_ptr) { const float input1_deQFactor = __ldg(input1_deQFactor_ptr); const float input2_deQFactor = __ldg(input2_deQFactor_ptr); - int col_start = threadIdx.x; + int col_start = threadIdx.x; __shared__ float s_mean; __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; + float mean = 0.0f; + float variance = 0.0f; float local_out; - int idx = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)); + int idx = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)); local_out = static_cast(__ldg(input2 + idx)) * input2_deQFactor + static_cast(__ldg(input1 + idx)) * input1_deQFactor @@ -508,33 +508,33 @@ __global__ void add_bias_input_layernorm_COL32_int8I_DataTypeO(T* output, // input1/input2/out matrix with layout of cublasLt CUBLASLT_ORDER_COL32 (m*n) //(grid, block) must be (m, n/2) template<> -__global__ void add_bias_input_layernorm_COL32_int8I_DataTypeO(half2* output, +__global__ void add_bias_input_layernorm_COL32_int8I_DataTypeO(half2* output, const int8_t* input1, const int8_t* input2, - const half2* bias, - const half2* gamma, - const half2* beta, - int m, - int n, - const float* input1_deQFactor_ptr, - const float* input2_deQFactor_ptr) + const half2* bias, + const half2* gamma, + const half2* beta, + int m, + int n, + const float* input1_deQFactor_ptr, + const float* input2_deQFactor_ptr) { const float input1_deQFactor = __ldg(input1_deQFactor_ptr); const float input2_deQFactor = __ldg(input2_deQFactor_ptr); - int col_start 
= threadIdx.x << 1; + int col_start = threadIdx.x << 1; __shared__ float s_mean; __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; + float mean = 0.0f; + float variance = 0.0f; float2 local_out; - int idx = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)) >> 1; + int idx = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)) >> 1; const char2* input1_ptr2 = (const char2*)input1; const char2* input2_ptr2 = (const char2*)input2; - char2 input_tmp1 = __ldg(input1_ptr2 + idx); - char2 input_tmp2 = __ldg(input2_ptr2 + idx); + char2 input_tmp1 = __ldg(input1_ptr2 + idx); + char2 input_tmp2 = __ldg(input2_ptr2 + idx); half2 bias_tmp = __ldg(bias + threadIdx.x); @@ -563,7 +563,7 @@ __global__ void add_bias_input_layernorm_COL32_int8I_DataTypeO(half2* output, __syncthreads(); half2 gamma_tmp = __ldg(gamma + threadIdx.x); - half2 beta_tmp = __ldg(beta + threadIdx.x); + half2 beta_tmp = __ldg(beta + threadIdx.x); local_out.x = (local_out.x * s_variance) * static_cast(gamma_tmp.x) + static_cast(beta_tmp.x); local_out.y = (local_out.y * s_variance) * static_cast(gamma_tmp.y) + static_cast(beta_tmp.y); @@ -575,17 +575,17 @@ __global__ void add_bias_input_layernorm_COL32_int8I_DataTypeO(half2* output, } template -void invokeAddBiasResidualLayerNormCol32(T* output, +void invokeAddBiasResidualLayerNormCol32(T* output, const int8_t* input1, const int8_t* input2, - const T* bias, - const T* gamma, - const T* beta, - int m, - int n, - cudaStream_t stream, - const float* input1_deQFactor_ptr, - const float* input2_deQFactor_ptr) + const T* bias, + const T* gamma, + const T* beta, + int m, + int n, + cudaStream_t stream, + const float* input1_deQFactor_ptr, + const float* input2_deQFactor_ptr) { dim3 grid(m); dim3 block(n); @@ -610,78 +610,78 @@ void invokeAddBiasResidualLayerNormCol32(T* output, } } -template void invokeAddBiasResidualLayerNormCol32(float* output, +template void invokeAddBiasResidualLayerNormCol32(float* output, const int8_t* input1, const int8_t* input2, - const float* bias, - const float* gamma, - const float* beta, - int m, - int n, - cudaStream_t stream, - const float* input1_deQFactor_ptr, - const float* input2_deQFactor_ptr); - -template void invokeAddBiasResidualLayerNormCol32(half* output, + const float* bias, + const float* gamma, + const float* beta, + int m, + int n, + cudaStream_t stream, + const float* input1_deQFactor_ptr, + const float* input2_deQFactor_ptr); + +template void invokeAddBiasResidualLayerNormCol32(half* output, const int8_t* input1, const int8_t* input2, - const half* bias, - const half* gamma, - const half* beta, - int m, - int n, - cudaStream_t stream, - const float* input1_deQFactor_ptr, - const float* input2_deQFactor_ptr); + const half* bias, + const half* gamma, + const half* beta, + int m, + int n, + cudaStream_t stream, + const float* input1_deQFactor_ptr, + const float* input2_deQFactor_ptr); // input1/input2/out matrix with layout of cublasLt CUBLASLT_ORDER_COL32 (m*n) //(grid, block) must be (m, n/4) // using char4 template -__global__ void add_bias_input_layernorm_COL32_int8IO_noRes(int8_t* output, - int8_t* input1, - T* input2, - const T* bias, - const T* gamma, - const T* beta, - int m, - int n, +__global__ void add_bias_input_layernorm_COL32_int8IO_noRes(int8_t* output, + int8_t* input1, + T* input2, + const T* bias, + const T* gamma, + const T* beta, + int m, + int n, const float* input1_deQFactor_ptr, const float* output_scale_ptr) { const float input1_deQFactor = 
__ldg(input1_deQFactor_ptr); - const float output_scale = __ldg(output_scale_ptr); - int col_start = threadIdx.x << 2; - bool qual = (col_start < n); + const float output_scale = __ldg(output_scale_ptr); + int col_start = threadIdx.x << 2; + bool qual = (col_start < n); __shared__ float s_mean; __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; + float mean = 0.0f; + float variance = 0.0f; - float local_out[4]; - int outIdx = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)) >> 2; + float local_out[4]; + int outIdx = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)) >> 2; char4* input1TmpPtr = (char4*)input1; - char4 input1Tmp; + char4 input1Tmp; if (qual) { - input1Tmp = __ldg(input1TmpPtr + outIdx); + input1Tmp = __ldg(input1TmpPtr + outIdx); int col_start_tmp = col_start; - local_out[0] = static_cast(input1Tmp.x) * input1_deQFactor + local_out[0] = static_cast(input1Tmp.x) * input1_deQFactor + static_cast(input2[(outIdx << 2) + 0]) + static_cast(bias[col_start_tmp]); input2[(outIdx << 2) + 0] = local_out[0]; col_start_tmp = col_start_tmp + 1; - local_out[1] = static_cast(input1Tmp.y) * input1_deQFactor + local_out[1] = static_cast(input1Tmp.y) * input1_deQFactor + static_cast(input2[(outIdx << 2) + 1]) + static_cast(bias[col_start_tmp]); input2[(outIdx << 2) + 1] = local_out[1]; col_start_tmp = col_start_tmp + 1; - local_out[2] = static_cast(input1Tmp.z) * input1_deQFactor + local_out[2] = static_cast(input1Tmp.z) * input1_deQFactor + static_cast(input2[(outIdx << 2) + 2]) + static_cast(bias[col_start_tmp]); input2[(outIdx << 2) + 2] = local_out[2]; col_start_tmp = col_start_tmp + 1; - local_out[3] = static_cast(input1Tmp.w) * input1_deQFactor + local_out[3] = static_cast(input1Tmp.w) * input1_deQFactor + static_cast(input2[(outIdx << 2) + 3]) + static_cast(bias[col_start_tmp]); input2[(outIdx << 2) + 3] = local_out[3]; } @@ -707,24 +707,24 @@ __global__ void add_bias_input_layernorm_COL32_int8IO_noRes(int8_t* output, } __syncthreads(); - char4 outputTmp; + char4 outputTmp; char4* outputTmpPtr = (char4*)output; if (qual) { local_out[0] = (local_out[0] * s_variance) * static_cast(__ldg(gamma + col_start)) + static_cast(__ldg(beta + col_start)); outputTmp.x = float_to_int8_rn(local_out[0] * output_scale); - col_start = col_start + 1; + col_start = col_start + 1; local_out[1] = (local_out[1] * s_variance) * static_cast(__ldg(gamma + col_start)) + static_cast(__ldg(beta + col_start)); outputTmp.y = float_to_int8_rn(local_out[1] * output_scale); - col_start = col_start + 1; + col_start = col_start + 1; local_out[2] = (local_out[2] * s_variance) * static_cast(__ldg(gamma + col_start)) + static_cast(__ldg(beta + col_start)); outputTmp.z = float_to_int8_rn(local_out[2] * output_scale); - col_start = col_start + 1; + col_start = col_start + 1; local_out[3] = (local_out[3] * s_variance) * static_cast(__ldg(gamma + col_start)) + static_cast(__ldg(beta + col_start)); outputTmp.w = float_to_int8_rn(local_out[3] * output_scale); @@ -734,43 +734,43 @@ __global__ void add_bias_input_layernorm_COL32_int8IO_noRes(int8_t* output, } template<> -__global__ void add_bias_input_layernorm_COL32_int8IO_noRes(int8_t* output, - int8_t* input1, - half2* input2, +__global__ void add_bias_input_layernorm_COL32_int8IO_noRes(int8_t* output, + int8_t* input1, + half2* input2, const half2* bias, const half2* gamma, const half2* beta, - int m, - int n, + int m, + int n, const float* input1_deQFactor_ptr, const float* output_scale_ptr) { const float 
input1_deQFactor = __ldg(input1_deQFactor_ptr); - const float output_scale = __ldg(output_scale_ptr); - int col_start = threadIdx.x << 1; - bool qual = (col_start < n); + const float output_scale = __ldg(output_scale_ptr); + int col_start = threadIdx.x << 1; + bool qual = (col_start < n); __shared__ float s_mean; __shared__ float s_variance; - float sums[2] = {0.0f, 0.0f}; + float sums[2] = {0.0f, 0.0f}; // float mean = 0.0f; // float variance = 0.0f; - float local_out[2]; - int outIdx = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)) >> 1; + float local_out[2]; + int outIdx = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)) >> 1; char2* input1TmpPtr = (char2*)input1; - char2 input1Tmp; + char2 input1Tmp; if (qual) { - input1Tmp = input1TmpPtr[outIdx]; - half2 biasTmp = bias[threadIdx.x]; + input1Tmp = input1TmpPtr[outIdx]; + half2 biasTmp = bias[threadIdx.x]; half2 input2Tmp = input2[outIdx]; - local_out[0] = static_cast(input1Tmp.x) * input1_deQFactor + static_cast(input2Tmp.x) + local_out[0] = static_cast(input1Tmp.x) * input1_deQFactor + static_cast(input2Tmp.x) + static_cast(biasTmp.x); local_out[1] = static_cast(input1Tmp.y) * input1_deQFactor + static_cast(input2Tmp.y) + static_cast(biasTmp.y); - input2Tmp.x = local_out[0]; - input2Tmp.y = local_out[1]; + input2Tmp.x = local_out[0]; + input2Tmp.y = local_out[1]; input2[outIdx] = input2Tmp; for (int i = 0; i < 2; i++) { sums[0] += local_out[i]; @@ -781,16 +781,16 @@ __global__ void add_bias_input_layernorm_COL32_int8IO_noRes(int8_t* output, blockReduceSumV2(sums); if (threadIdx.x == 0) { - s_mean = sums[0] * __fdividef(1.0f, n); + s_mean = sums[0] * __fdividef(1.0f, n); s_variance = rsqrtf(sums[1] * __fdividef(1.0f, n) - s_mean * s_mean + 1e-6); } __syncthreads(); - char2 outputTmp; + char2 outputTmp; char2* outputTmpPtr = (char2*)output; if (qual) { half2 gammaTmp = gamma[threadIdx.x]; - half2 betaTmp = beta[threadIdx.x]; + half2 betaTmp = beta[threadIdx.x]; local_out[0] = (local_out[0] - s_mean) * s_variance * static_cast(gammaTmp.x) + static_cast(betaTmp.x); outputTmp.x = float_to_int8_rn(local_out[0] * output_scale); @@ -804,20 +804,20 @@ __global__ void add_bias_input_layernorm_COL32_int8IO_noRes(int8_t* output, } template -void invokeAddBiasResidualLayerNormCol32_noRes(int8_t* output, - int8_t* input1, - T* input2, - const T* bias, - const T* gamma, - const T* beta, - int m, - int n, +void invokeAddBiasResidualLayerNormCol32_noRes(int8_t* output, + int8_t* input1, + T* input2, + const T* bias, + const T* gamma, + const T* beta, + int m, + int n, cudaStream_t stream, const float* input1_deQFactor_ptr, const float* output_scale_ptr) { dim3 grid(m); - int blockSize = (n / 4 + 31) / 32 * 32; + int blockSize = (n / 4 + 31) / 32 * 32; dim3 block(blockSize); assert(blockSize <= 1024); if (sizeof(T) == sizeof(half)) { @@ -840,26 +840,26 @@ void invokeAddBiasResidualLayerNormCol32_noRes(int8_t* output, } } -template void invokeAddBiasResidualLayerNormCol32_noRes(int8_t* output, - int8_t* input1, - float* input2, +template void invokeAddBiasResidualLayerNormCol32_noRes(int8_t* output, + int8_t* input1, + float* input2, const float* bias, const float* gamma, const float* beta, - int m, - int n, + int m, + int n, cudaStream_t stream, const float* input1_deQFactor_ptr, const float* output_scale_ptr); -template void invokeAddBiasResidualLayerNormCol32_noRes(int8_t* output, - int8_t* input1, - half* input2, - const half* bias, - const half* gamma, - const half* beta, - int m, - int n, +template void 
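The *_noRes kernels fold the two layer-norm reductions into a single pass by accumulating sum(x) and sum(x^2) (reduced together via blockReduceSumV2) and then taking inv_std = rsqrt(E[x^2] - E[x]^2 + eps). A host sketch of the same computation (illustrative name, eps fixed to the 1e-6 used above):

#include <cmath>
#include <vector>

// Single-pass mean / inverse standard deviation from sum and sum-of-squares.
static void meanInvStd(const std::vector<float>& x, float& mean, float& inv_std)
{
    float s = 0.0f, s2 = 0.0f;
    for (float v : x) {
        s += v;
        s2 += v * v;
    }
    const float n = static_cast<float>(x.size());
    mean    = s / n;
    inv_std = 1.0f / std::sqrt(s2 / n - mean * mean + 1e-6f);
}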
invokeAddBiasResidualLayerNormCol32_noRes(int8_t* output, + int8_t* input1, + half* input2, + const half* bias, + const half* gamma, + const half* beta, + int m, + int n, cudaStream_t stream, const float* input1_deQFactor_ptr, const float* output_scale_ptr); @@ -868,34 +868,34 @@ template void invokeAddBiasResidualLayerNormCol32_noRes(int8_t* output, //(grid, block) must be (m, n/4) // using char4 template -__global__ void add_bias_input_layernorm_COL32_int8IO_noRes(int8_t* output, - int32_t* input1, - T* input2, - const T* bias, - const T* gamma, - const T* beta, - int m, - int n, +__global__ void add_bias_input_layernorm_COL32_int8IO_noRes(int8_t* output, + int32_t* input1, + T* input2, + const T* bias, + const T* gamma, + const T* beta, + int m, + int n, const float* weight_amax, const float* input1_amax_ptr, const float* output_scale_ptr) { - const float input1_amax = __ldg(input1_amax_ptr); + const float input1_amax = __ldg(input1_amax_ptr); const float output_scale = __ldg(output_scale_ptr); - int col_start = threadIdx.x << 2; - bool qual = (col_start < n); + int col_start = threadIdx.x << 2; + bool qual = (col_start < n); __shared__ float s_mean; __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; + float mean = 0.0f; + float variance = 0.0f; float local_out[4]; - int outIdx = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)) >> 2; + int outIdx = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)) >> 2; int4* input1TmpPtr = (int4*)input1; - int4 input1Tmp; + int4 input1Tmp; if (qual) { - input1Tmp = __ldg(input1TmpPtr + outIdx); + input1Tmp = __ldg(input1TmpPtr + outIdx); int col_start_tmp = col_start; // NOTE: 0.000062f = 1 / 127/.0f / 127.0f local_out[0] = static_cast(input1Tmp.x) * input1_amax * weight_amax[col_start_tmp] * 0.000062f @@ -903,17 +903,17 @@ __global__ void add_bias_input_layernorm_COL32_int8IO_noRes(int8_t* output, input2[(outIdx << 2) + 0] = local_out[0]; col_start_tmp = col_start_tmp + 1; - local_out[1] = static_cast(input1Tmp.y) * input1_amax * weight_amax[col_start_tmp] * 0.000062f + local_out[1] = static_cast(input1Tmp.y) * input1_amax * weight_amax[col_start_tmp] * 0.000062f + static_cast(input2[(outIdx << 2) + 1]) + static_cast(bias[col_start_tmp]); input2[(outIdx << 2) + 1] = local_out[1]; col_start_tmp = col_start_tmp + 1; - local_out[2] = static_cast(input1Tmp.z) * input1_amax * weight_amax[col_start_tmp] * 0.000062f + local_out[2] = static_cast(input1Tmp.z) * input1_amax * weight_amax[col_start_tmp] * 0.000062f + static_cast(input2[(outIdx << 2) + 2]) + static_cast(bias[col_start_tmp]); input2[(outIdx << 2) + 2] = local_out[2]; col_start_tmp = col_start_tmp + 1; - local_out[3] = static_cast(input1Tmp.w) * input1_amax * weight_amax[col_start_tmp] * 0.000062f + local_out[3] = static_cast(input1Tmp.w) * input1_amax * weight_amax[col_start_tmp] * 0.000062f + static_cast(input2[(outIdx << 2) + 3]) + static_cast(bias[col_start_tmp]); input2[(outIdx << 2) + 3] = local_out[3]; } @@ -940,23 +940,23 @@ __global__ void add_bias_input_layernorm_COL32_int8IO_noRes(int8_t* output, __syncthreads(); char4* outputTmpPtr = (char4*)output; - char4 outputTmp; + char4 outputTmp; if (qual) { local_out[0] = (local_out[0] * s_variance) * static_cast(__ldg(gamma + col_start)) + static_cast(__ldg(beta + col_start)); outputTmp.x = float_to_int8_rn(local_out[0] * output_scale); - col_start = col_start + 1; + col_start = col_start + 1; local_out[1] = (local_out[1] * s_variance) * static_cast(__ldg(gamma + col_start)) + 
static_cast(__ldg(beta + col_start)); outputTmp.y = float_to_int8_rn(local_out[1] * output_scale); - col_start = col_start + 1; + col_start = col_start + 1; local_out[2] = (local_out[2] * s_variance) * static_cast(__ldg(gamma + col_start)) + static_cast(__ldg(beta + col_start)); outputTmp.z = float_to_int8_rn(local_out[2] * output_scale); - col_start = col_start + 1; + col_start = col_start + 1; local_out[3] = (local_out[3] * s_variance) * static_cast(__ldg(gamma + col_start)) + static_cast(__ldg(beta + col_start)); outputTmp.w = float_to_int8_rn(local_out[3] * output_scale); @@ -966,39 +966,39 @@ __global__ void add_bias_input_layernorm_COL32_int8IO_noRes(int8_t* output, } template<> -__global__ void add_bias_input_layernorm_COL32_int8IO_noRes(int8_t* output, - int32_t* input1, - half2* input2, +__global__ void add_bias_input_layernorm_COL32_int8IO_noRes(int8_t* output, + int32_t* input1, + half2* input2, const half2* bias, const half2* gamma, const half2* beta, - int m, - int n, + int m, + int n, const float* weight_amax, const float* input1_amax_ptr, const float* output_scale_ptr) { const float2* weight_scale_ptr = (const float2*)weight_amax; - const float input1_amax = __ldg(input1_amax_ptr); - const float output_scale = __ldg(output_scale_ptr); - int col_start = threadIdx.x << 1; - bool qual = (col_start < n); + const float input1_amax = __ldg(input1_amax_ptr); + const float output_scale = __ldg(output_scale_ptr); + int col_start = threadIdx.x << 1; + bool qual = (col_start < n); __shared__ float s_mean; __shared__ float s_variance; - float sums[2] = {0.0f, 0.0f}; + float sums[2] = {0.0f, 0.0f}; // float mean = 0.0f; // float variance = 0.0f; float local_out[2]; - int outIdx = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)) >> 1; + int outIdx = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)) >> 1; int2* input1TmpPtr = (int2*)input1; - int2 input1Tmp; + int2 input1Tmp; if (qual) { const float2 weight_scale = __ldg(weight_scale_ptr + threadIdx.x); - input1Tmp = input1TmpPtr[outIdx]; - half2 biasTmp = bias[threadIdx.x]; - half2 input2Tmp = input2[outIdx]; + input1Tmp = input1TmpPtr[outIdx]; + half2 biasTmp = bias[threadIdx.x]; + half2 input2Tmp = input2[outIdx]; // NOTE: 0.000062f = 1 / 127/.0f / 127.0f local_out[0] = static_cast(input1Tmp.x) * input1_amax * weight_scale.x * 0.000062f + static_cast(biasTmp.x); @@ -1008,8 +1008,8 @@ __global__ void add_bias_input_layernorm_COL32_int8IO_noRes(int8_t* output, local_out[0] += static_cast(input2Tmp.x); local_out[1] += static_cast(input2Tmp.y); - input2Tmp.x = local_out[0]; - input2Tmp.y = local_out[1]; + input2Tmp.x = local_out[0]; + input2Tmp.y = local_out[1]; input2[outIdx] = input2Tmp; for (int i = 0; i < 2; i++) { sums[0] += local_out[i]; @@ -1020,16 +1020,16 @@ __global__ void add_bias_input_layernorm_COL32_int8IO_noRes(int8_t* output, blockReduceSumV2(sums); if (threadIdx.x == 0) { - s_mean = sums[0] * __fdividef(1.0f, n); + s_mean = sums[0] * __fdividef(1.0f, n); s_variance = rsqrtf(sums[1] * __fdividef(1.0f, n) - s_mean * s_mean + 1e-6); } __syncthreads(); char2* outputTmpPtr = (char2*)output; - char2 outputTmp; + char2 outputTmp; if (qual) { half2 gammaTmp = gamma[threadIdx.x]; - half2 betaTmp = beta[threadIdx.x]; + half2 betaTmp = beta[threadIdx.x]; local_out[0] = (local_out[0] - s_mean) * s_variance * static_cast(gammaTmp.x) + static_cast(betaTmp.x); outputTmp.x = float_to_int8_rn(local_out[0] * output_scale); @@ -1043,21 +1043,21 @@ __global__ void 
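For the int32-input variants, the GEMM accumulator is dequantized with a per-tensor activation amax and a per-output-channel weight amax, which is where the 0.000062f ≈ 1 / (127 * 127) constant comes from: real_value ≈ acc * input_amax * weight_amax[col] / (127 * 127). A one-line host sketch (illustrative name):

#include <cstdint>

// Dequantize one int32 accumulator produced by an int8 x int8 GEMM whose activations were
// quantized against input_amax and whose weights against a per-column weight_amax.
static float dequantizeInt32(int32_t acc, float input_amax, float weight_amax_col)
{
    return static_cast<float>(acc) * input_amax * weight_amax_col * (1.0f / (127.0f * 127.0f));
}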
add_bias_input_layernorm_COL32_int8IO_noRes(int8_t* output, } template -void invokeAddBiasResidualLayerNormCol32_noRes(int8_t* output, - int32_t* input1, - T* input2, - const T* bias, - const T* gamma, - const T* beta, - int m, - int n, +void invokeAddBiasResidualLayerNormCol32_noRes(int8_t* output, + int32_t* input1, + T* input2, + const T* bias, + const T* gamma, + const T* beta, + int m, + int n, cudaStream_t stream, const float* weight_amax, const float* input1_amax_ptr, const float* output_scale_ptr) { dim3 grid(m); - int blockSize = (n / 4 + 31) / 32 * 32; + int blockSize = (n / 4 + 31) / 32 * 32; dim3 block(blockSize); assert(blockSize <= 1024); if (sizeof(T) == sizeof(half)) { @@ -1081,27 +1081,27 @@ void invokeAddBiasResidualLayerNormCol32_noRes(int8_t* output, } } -template void invokeAddBiasResidualLayerNormCol32_noRes(int8_t* output, - int32_t* input1, - float* input2, +template void invokeAddBiasResidualLayerNormCol32_noRes(int8_t* output, + int32_t* input1, + float* input2, const float* bias, const float* gamma, const float* beta, - int m, - int n, + int m, + int n, cudaStream_t stream, const float* weight_amax, const float* input1_amax_ptr, const float* output_scale_ptr); -template void invokeAddBiasResidualLayerNormCol32_noRes(int8_t* output, - int32_t* input1, - half* input2, - const half* bias, - const half* gamma, - const half* beta, - int m, - int n, +template void invokeAddBiasResidualLayerNormCol32_noRes(int8_t* output, + int32_t* input1, + half* input2, + const half* bias, + const half* gamma, + const half* beta, + int m, + int n, cudaStream_t stream, const float* weight_amax, const float* input1_amax_ptr, @@ -1117,16 +1117,16 @@ __global__ void layernorm_COL32_DataTypeI_int8O( int8_t* out, const T* input, const T* gamma, const T* beta, int m, int n, const float* output_scale_ptr) { const float output_scale = __ldg(output_scale_ptr); - int col_start = threadIdx.x << 2; - bool qual = (col_start < n); + int col_start = threadIdx.x << 2; + bool qual = (col_start < n); __shared__ float s_mean; __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; + float mean = 0.0f; + float variance = 0.0f; float local_out[4]; - int outIdx = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)) >> 2; + int outIdx = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)) >> 2; if (qual) { local_out[0] = static_cast(input[(outIdx << 2) + 0]); @@ -1156,24 +1156,24 @@ __global__ void layernorm_COL32_DataTypeI_int8O( } __syncthreads(); - char4 outTmp; + char4 outTmp; char4* outPtr = (char4*)out; if (qual) { local_out[0] = (local_out[0] * s_variance) * static_cast(__ldg(gamma + col_start)) + static_cast(__ldg(beta + col_start)); outTmp.x = float_to_int8_rn(local_out[0] * output_scale); - col_start = col_start + 1; + col_start = col_start + 1; local_out[1] = (local_out[1] * s_variance) * static_cast(__ldg(gamma + col_start)) + static_cast(__ldg(beta + col_start)); outTmp.y = float_to_int8_rn(local_out[1] * output_scale); - col_start = col_start + 1; + col_start = col_start + 1; local_out[2] = (local_out[2] * s_variance) * static_cast(__ldg(gamma + col_start)) + static_cast(__ldg(beta + col_start)); outTmp.z = float_to_int8_rn(local_out[2] * output_scale); - col_start = col_start + 1; + col_start = col_start + 1; local_out[3] = (local_out[3] * s_variance) * static_cast(__ldg(gamma + col_start)) + static_cast(__ldg(beta + col_start)); outTmp.w = float_to_int8_rn(local_out[3] * output_scale); @@ -1187,22 +1187,22 @@ __global__ void 
layernorm_COL32_DataTypeI_int8O( int8_t* out, const half2* input, const half2* gamma, const half2* beta, int m, int n, const float* output_scale_ptr) { const float output_scale = __ldg(output_scale_ptr); - int col_start = threadIdx.x << 1; - bool qual = (col_start < n); + int col_start = threadIdx.x << 1; + bool qual = (col_start < n); __shared__ float s_mean; __shared__ float s_variance; - float sums[2] = {0.0f, 0.0f}; + float sums[2] = {0.0f, 0.0f}; // float mean = 0.0f; // float variance = 0.0f; float local_out[2]; - int outIdx = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)) >> 1; + int outIdx = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)) >> 1; if (qual) { half2 inputTmp = input[outIdx]; - local_out[0] = static_cast(inputTmp.x); - local_out[1] = static_cast(inputTmp.y); + local_out[0] = static_cast(inputTmp.x); + local_out[1] = static_cast(inputTmp.y); for (int i = 0; i < 2; i++) { sums[0] += local_out[i]; @@ -1213,16 +1213,16 @@ __global__ void layernorm_COL32_DataTypeI_int8O( blockReduceSumV2(sums); if (threadIdx.x == 0) { - s_mean = sums[0] * __fdividef(1.0f, n); + s_mean = sums[0] * __fdividef(1.0f, n); s_variance = rsqrtf(sums[1] * __fdividef(1.0f, n) - s_mean * s_mean + 1e-6); } __syncthreads(); - char2 outTmp; + char2 outTmp; char2* outPtr = (char2*)out; if (qual) { half2 gammaTmp = gamma[threadIdx.x]; - half2 betaTmp = beta[threadIdx.x]; + half2 betaTmp = beta[threadIdx.x]; local_out[0] = (local_out[0] - s_mean) * s_variance * static_cast(gammaTmp.x) + static_cast(betaTmp.x); outTmp.x = float_to_int8_rn(local_out[0] * output_scale); @@ -1236,17 +1236,17 @@ __global__ void layernorm_COL32_DataTypeI_int8O( } template -void invokeLayernormCol32(int8_t* out, - const T* input, - const T* gamma, - const T* beta, - int m, - int n, +void invokeLayernormCol32(int8_t* out, + const T* input, + const T* gamma, + const T* beta, + int m, + int n, const float* output_scale_ptr, cudaStream_t stream) { dim3 grid(m); - int blockSize = (n / 4 + 31) / 32 * 32; + int blockSize = (n / 4 + 31) / 32 * 32; dim3 block(blockSize); assert(blockSize <= 1024); if (sizeof(T) == sizeof(half)) { @@ -1260,59 +1260,59 @@ void invokeLayernormCol32(int8_t* out, } } -template void invokeLayernormCol32(int8_t* out, +template void invokeLayernormCol32(int8_t* out, const float* input, const float* gamma, const float* beta, - int m, - int n, + int m, + int n, const float* output_scale_ptr, cudaStream_t stream); -template void invokeLayernormCol32(int8_t* out, - const half* input, - const half* gamma, - const half* beta, - int m, - int n, +template void invokeLayernormCol32(int8_t* out, + const half* input, + const half* gamma, + const half* beta, + int m, + int n, const float* output_scale_ptr, cudaStream_t stream); /******************* invokeLayernormShiftPartitionCol32 ***********************/ template -__global__ void layernorm_shift_partition_COL32_noRes(int8_t* out, - const T* input, - const T* gamma, - const T* beta, - int batch, - int H, - int W, - int n, +__global__ void layernorm_shift_partition_COL32_noRes(int8_t* out, + const T* input, + const T* gamma, + const T* beta, + int batch, + int H, + int W, + int n, const float* norm_scale_ptr, - int shift_size, - int window_size) + int shift_size, + int window_size) { - float norm_scale = __ldg(norm_scale_ptr); - int tid = threadIdx.x; - const int batch_offset = blockIdx.z * gridDim.y * gridDim.x; - const int m = gridDim.z * gridDim.y * gridDim.x; - const int bid = batch_offset + blockIdx.y * gridDim.x + blockIdx.x; 
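// [Editor's note -- reference sketch, not part of the patch.] The repeated expression
// (col & 0xffffffe0) * m + (row << 5) + (col & 31) is the linear offset of element (row, col) in an
// m x n matrix stored in cuBLASLt COL32 order: columns are grouped into tiles of 32, each tile stores
// all m rows contiguously, 32 columns per row. The trailing >> 2 / >> 1 shifts in the kernels then
// address char4/int4/half2 vectors instead of scalars. Host-side reference (the function name is an
// assumption):
#include <cstddef>
static inline size_t col32_offset(size_t row, size_t col, size_t m)
{
    return (col & ~static_cast<size_t>(31)) * m   // which 32-column tile, times tile stride 32*m
         + (row << 5)                             // row within the tile, 32 columns per row
         + (col & 31);                            // column within the tile
}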
- const int shifted_H_idx = (shift_size != 0) ? ((blockIdx.y - shift_size + gridDim.y) % gridDim.y) : blockIdx.y; - const int shifted_W_idx = (shift_size != 0) ? ((blockIdx.x - shift_size + gridDim.x) % gridDim.x) : blockIdx.x; - const int window_H_idx = shifted_H_idx / window_size; - const int window_W_idx = shifted_W_idx / window_size; + float norm_scale = __ldg(norm_scale_ptr); + int tid = threadIdx.x; + const int batch_offset = blockIdx.z * gridDim.y * gridDim.x; + const int m = gridDim.z * gridDim.y * gridDim.x; + const int bid = batch_offset + blockIdx.y * gridDim.x + blockIdx.x; + const int shifted_H_idx = (shift_size != 0) ? ((blockIdx.y - shift_size + gridDim.y) % gridDim.y) : blockIdx.y; + const int shifted_W_idx = (shift_size != 0) ? ((blockIdx.x - shift_size + gridDim.x) % gridDim.x) : blockIdx.x; + const int window_H_idx = shifted_H_idx / window_size; + const int window_W_idx = shifted_W_idx / window_size; const int stride_of_window_H = W / window_size; - const int window_idx = window_H_idx * stride_of_window_H + window_W_idx; - const int idx_in_window = (shifted_H_idx % window_size) * window_size + (shifted_W_idx % window_size); - const int output_bid = batch_offset + window_idx * window_size * window_size + idx_in_window; + const int window_idx = window_H_idx * stride_of_window_H + window_W_idx; + const int idx_in_window = (shifted_H_idx % window_size) * window_size + (shifted_W_idx % window_size); + const int output_bid = batch_offset + window_idx * window_size * window_size + idx_in_window; __shared__ float s_mean; __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; + float mean = 0.0f; + float variance = 0.0f; const int offset_col32_in = (tid & 0xffffffe0) * m + (bid << 5) + (tid & 31); - float local_out = (tid < n) ? (float)(__ldg(input + offset_col32_in)) : 0.0f; + float local_out = (tid < n) ? (float)(__ldg(input + offset_col32_in)) : 0.0f; mean = blockReduceSum(local_out); if (threadIdx.x == 0) { @@ -1321,7 +1321,7 @@ __global__ void layernorm_shift_partition_COL32_noRes(int8_t* out, __syncthreads(); float diff = (tid < n) ? (local_out - s_mean) : 0.0f; - variance = blockReduceSum(diff * diff); + variance = blockReduceSum(diff * diff); if (threadIdx.x == 0) { s_variance = variance / n + 1e-6f; } @@ -1336,46 +1336,46 @@ __global__ void layernorm_shift_partition_COL32_noRes(int8_t* out, } template<> -__global__ void layernorm_shift_partition_COL32_noRes(int8_t* out_ptr, +__global__ void layernorm_shift_partition_COL32_noRes(int8_t* out_ptr, const half4* input_ptr, const half4* gamma_ptr, const half4* beta_ptr, - int batch, - int H, - int W, - int n, + int batch, + int H, + int W, + int n, const float* norm_scale_ptr, - int shift_size, - int window_size) + int shift_size, + int window_size) { float norm_scale = __ldg(norm_scale_ptr); const int batch_offset = blockIdx.z * gridDim.y * gridDim.x; - const int m = gridDim.z * gridDim.y * gridDim.x; - const int bid = batch_offset + blockIdx.y * gridDim.x + blockIdx.x; + const int m = gridDim.z * gridDim.y * gridDim.x; + const int bid = batch_offset + blockIdx.y * gridDim.x + blockIdx.x; - const int shifted_H_idx = (shift_size != 0) ? ((blockIdx.y - shift_size + gridDim.y) % gridDim.y) : blockIdx.y; - const int shifted_W_idx = (shift_size != 0) ? ((blockIdx.x - shift_size + gridDim.x) % gridDim.x) : blockIdx.x; - const int window_H_idx = shifted_H_idx / window_size; - const int window_W_idx = shifted_W_idx / window_size; + const int shifted_H_idx = (shift_size != 0) ? 
((blockIdx.y - shift_size + gridDim.y) % gridDim.y) : blockIdx.y; + const int shifted_W_idx = (shift_size != 0) ? ((blockIdx.x - shift_size + gridDim.x) % gridDim.x) : blockIdx.x; + const int window_H_idx = shifted_H_idx / window_size; + const int window_W_idx = shifted_W_idx / window_size; const int stride_of_window_H = W / window_size; - const int window_idx = window_H_idx * stride_of_window_H + window_W_idx; - const int idx_in_window = (shifted_H_idx % window_size) * window_size + (shifted_W_idx % window_size); - const int output_bid = batch_offset + window_idx * window_size * window_size + idx_in_window; + const int window_idx = window_H_idx * stride_of_window_H + window_W_idx; + const int idx_in_window = (shifted_H_idx % window_size) * window_size + (shifted_W_idx % window_size); + const int output_bid = batch_offset + window_idx * window_size * window_size + idx_in_window; - int tid = threadIdx.x << 2; + int tid = threadIdx.x << 2; __shared__ float s_mean; __shared__ float s_variance; - float sums[2] = {0.0f, 0.0f}; - half4 inputTmp; - float inputBuf[4]; + float sums[2] = {0.0f, 0.0f}; + half4 inputTmp; + float inputBuf[4]; char4* output_ptr = (char4*)out_ptr; - char4 int8_buf; + char4 int8_buf; const int offset_col32 = (tid & 0xffffffe0) * m + (bid << 5) + (tid & 31); if (tid < n) { - inputTmp = input_ptr[offset_col32 >> 2]; + inputTmp = input_ptr[offset_col32 >> 2]; inputBuf[0] = static_cast(inputTmp.x); inputBuf[1] = static_cast(inputTmp.y); inputBuf[2] = static_cast(inputTmp.z); @@ -1389,14 +1389,14 @@ __global__ void layernorm_shift_partition_COL32_noRes(int8_t* out_ptr, blockReduceSumV2(sums); if (threadIdx.x == 0) { - s_mean = sums[0] / n; + s_mean = sums[0] / n; s_variance = rsqrtf(sums[1] / n - s_mean * s_mean + 1e-6); } __syncthreads(); if (tid < n) { half4 gamma_val = gamma_ptr[tid >> 2]; - half4 beta_val = beta_ptr[tid >> 2]; + half4 beta_val = beta_ptr[tid >> 2]; inputBuf[0] = (inputBuf[0] - s_mean) * s_variance * static_cast(gamma_val.x) + static_cast(beta_val.x); inputBuf[1] = @@ -1409,10 +1409,10 @@ __global__ void layernorm_shift_partition_COL32_noRes(int8_t* out_ptr, const int offset_col32_out = (tid & 0xffffffe0) * m + (output_bid << 5) + (tid & 31); // const int offset_colMajor_out = output_bid * n + tid; // const int offset_out = index_CUBLASLT_ORDER_COL32_2R_4R4(tid, output_bid, m << 5); - int8_buf.x = float_to_int8_rn(norm_scale * inputBuf[0]); - int8_buf.y = float_to_int8_rn(norm_scale * inputBuf[1]); - int8_buf.z = float_to_int8_rn(norm_scale * inputBuf[2]); - int8_buf.w = float_to_int8_rn(norm_scale * inputBuf[3]); + int8_buf.x = float_to_int8_rn(norm_scale * inputBuf[0]); + int8_buf.y = float_to_int8_rn(norm_scale * inputBuf[1]); + int8_buf.z = float_to_int8_rn(norm_scale * inputBuf[2]); + int8_buf.w = float_to_int8_rn(norm_scale * inputBuf[3]); output_ptr[offset_col32_out >> 2] = int8_buf; } } @@ -1422,34 +1422,34 @@ __global__ void layernorm_shift_partition_v2_COL32_noRes(int8_t* out, const T* __restrict input, const T* __restrict gamma, const T* __restrict beta, - int batch, - int H, - int W, - int n, + int batch, + int H, + int W, + int n, const float* norm_scale_ptr, - int shift_size, - int window_size) + int shift_size, + int window_size) { - float norm_scale = __ldg(norm_scale_ptr); - const int ite = 4; - const int tid = threadIdx.x; - const int batch_offset = blockIdx.z * gridDim.y * gridDim.x; - const int m = gridDim.z * gridDim.y * gridDim.x; - const int bid = batch_offset + blockIdx.y * gridDim.x + blockIdx.x; - const int shifted_H_idx = 
(shift_size != 0) ? ((blockIdx.y - shift_size + gridDim.y) % gridDim.y) : blockIdx.y; - const int shifted_W_idx = (shift_size != 0) ? ((blockIdx.x - shift_size + gridDim.x) % gridDim.x) : blockIdx.x; - const int window_H_idx = shifted_H_idx / window_size; - const int window_W_idx = shifted_W_idx / window_size; + float norm_scale = __ldg(norm_scale_ptr); + const int ite = 4; + const int tid = threadIdx.x; + const int batch_offset = blockIdx.z * gridDim.y * gridDim.x; + const int m = gridDim.z * gridDim.y * gridDim.x; + const int bid = batch_offset + blockIdx.y * gridDim.x + blockIdx.x; + const int shifted_H_idx = (shift_size != 0) ? ((blockIdx.y - shift_size + gridDim.y) % gridDim.y) : blockIdx.y; + const int shifted_W_idx = (shift_size != 0) ? ((blockIdx.x - shift_size + gridDim.x) % gridDim.x) : blockIdx.x; + const int window_H_idx = shifted_H_idx / window_size; + const int window_W_idx = shifted_W_idx / window_size; const int stride_of_window_H = W / window_size; - const int window_idx = window_H_idx * stride_of_window_H + window_W_idx; - const int idx_in_window = (shifted_H_idx % window_size) * window_size + (shifted_W_idx % window_size); - const int output_bid = batch_offset + window_idx * window_size * window_size + idx_in_window; + const int window_idx = window_H_idx * stride_of_window_H + window_W_idx; + const int idx_in_window = (shifted_H_idx % window_size) * window_size + (shifted_W_idx % window_size); + const int output_bid = batch_offset + window_idx * window_size * window_size + idx_in_window; __shared__ float s_mean; __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; - float local_out[ite]; + float mean = 0.0f; + float variance = 0.0f; + float local_out[ite]; float sum = 0.0f; #pragma unroll @@ -1457,7 +1457,7 @@ __global__ void layernorm_shift_partition_v2_COL32_noRes(int8_t* out, int col_id = i * blockDim.x + tid; if (col_id < n) { const int offset_col32_in = (col_id & 0xffffffe0) * m + (bid << 5) + (col_id & 31); - local_out[i] = (float)(__ldg(input + offset_col32_in)); + local_out[i] = (float)(__ldg(input + offset_col32_in)); sum += local_out[i]; } } @@ -1473,7 +1473,7 @@ __global__ void layernorm_shift_partition_v2_COL32_noRes(int8_t* out, for (int i = 0; i < ite; i++) { int col_id = i * blockDim.x + tid; if (col_id < n) { - float diff = local_out[i] - s_mean; + float diff = local_out[i] - s_mean; local_out[i] = diff; var += diff * diff; } @@ -1490,7 +1490,7 @@ __global__ void layernorm_shift_partition_v2_COL32_noRes(int8_t* out, int col_id = i * blockDim.x + tid; if (col_id < n) { const int offset_col32_out = (col_id & 0xffffffe0) * m + (output_bid << 5) + (col_id & 31); - out[offset_col32_out] = float_to_int8_rn( + out[offset_col32_out] = float_to_int8_rn( norm_scale * (local_out[i] * s_variance * (float)__ldg(&gamma[col_id]) + (float)__ldg(&beta[col_id]))); } } @@ -1501,39 +1501,39 @@ __global__ void layernorm_shift_partition_v2_COL32_noRes(int8_t* out_ptr, const half2* __restrict input_ptr, const half2* __restrict gamma_ptr, const half2* __restrict beta_ptr, - int batch, - int H, - int W, - int n, + int batch, + int H, + int W, + int n, const float* norm_scale_ptr, - int shift_size, - int window_size) + int shift_size, + int window_size) { - float norm_scale = __ldg(norm_scale_ptr); - const int ite = 4; - const int tid = threadIdx.x; + float norm_scale = __ldg(norm_scale_ptr); + const int ite = 4; + const int tid = threadIdx.x; const int batch_offset = blockIdx.z * gridDim.y * gridDim.x; - const int m = gridDim.z * gridDim.y * gridDim.x; - 
const int bid = batch_offset + blockIdx.y * gridDim.x + blockIdx.x; + const int m = gridDim.z * gridDim.y * gridDim.x; + const int bid = batch_offset + blockIdx.y * gridDim.x + blockIdx.x; - const int shifted_H_idx = (shift_size != 0) ? ((blockIdx.y - shift_size + gridDim.y) % gridDim.y) : blockIdx.y; - const int shifted_W_idx = (shift_size != 0) ? ((blockIdx.x - shift_size + gridDim.x) % gridDim.x) : blockIdx.x; - const int window_H_idx = shifted_H_idx / window_size; - const int window_W_idx = shifted_W_idx / window_size; + const int shifted_H_idx = (shift_size != 0) ? ((blockIdx.y - shift_size + gridDim.y) % gridDim.y) : blockIdx.y; + const int shifted_W_idx = (shift_size != 0) ? ((blockIdx.x - shift_size + gridDim.x) % gridDim.x) : blockIdx.x; + const int window_H_idx = shifted_H_idx / window_size; + const int window_W_idx = shifted_W_idx / window_size; const int stride_of_window_H = W / window_size; - const int window_idx = window_H_idx * stride_of_window_H + window_W_idx; - const int idx_in_window = (shifted_H_idx % window_size) * window_size + (shifted_W_idx % window_size); - const int output_bid = batch_offset + window_idx * window_size * window_size + idx_in_window; + const int window_idx = window_H_idx * stride_of_window_H + window_W_idx; + const int idx_in_window = (shifted_H_idx % window_size) * window_size + (shifted_W_idx % window_size); + const int output_bid = batch_offset + window_idx * window_size * window_size + idx_in_window; __shared__ float s_mean; __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; - half2 local_out_half2[ite]; - const half2 zero = {static_cast(0.0f), static_cast(0.0f)}; + float mean = 0.0f; + float variance = 0.0f; + half2 local_out_half2[ite]; + const half2 zero = {static_cast(0.0f), static_cast(0.0f)}; char2* output_ptr = (char2*)out_ptr; - char2 int8_buf; + char2 int8_buf; // float sum = 0.0f; half2 sum = __float2half2_rn(0.0f); @@ -1542,7 +1542,7 @@ __global__ void layernorm_shift_partition_v2_COL32_noRes(int8_t* out_ptr, int col_id = (i * blockDim.x + tid) << 1; if (col_id < n) { const int offset_col32 = (col_id & 0xffffffe0) * m + (bid << 5) + (col_id & 31); - local_out_half2[i] = __ldg(input_ptr + (offset_col32 >> 1)); + local_out_half2[i] = __ldg(input_ptr + (offset_col32 >> 1)); sum += local_out_half2[i]; } @@ -1554,15 +1554,15 @@ __global__ void layernorm_shift_partition_v2_COL32_noRes(int8_t* out_ptr, } __syncthreads(); - float var = 0.0f; + float var = 0.0f; half2 s_mean_2 = __float2half2_rn(s_mean); #pragma unroll for (int i = 0; i < ite; i++) { int col_id = (i * blockDim.x + tid) << 1; if (col_id < n) { local_out_half2[i] = local_out_half2[i] - s_mean_2; - float v1 = (float)local_out_half2[i].x; - float v2 = (float)local_out_half2[i].y; + float v1 = (float)local_out_half2[i].x; + float v2 = (float)local_out_half2[i].y; var += v1 * v1 + v2 * v2; } } @@ -1579,31 +1579,31 @@ __global__ void layernorm_shift_partition_v2_COL32_noRes(int8_t* out_ptr, int col_id = (i * blockDim.x + tid) << 1; if (col_id < n) { const int offset_col32_out = (col_id & 0xffffffe0) * m + (output_bid << 5) + (col_id & 31); - half2 outVal = + half2 outVal = local_out_half2[i] * s_var_2 * __ldg(&gamma_ptr[col_id >> 1]) + __ldg(&beta_ptr[col_id >> 1]); - int8_buf.x = float_to_int8_rn(norm_scale * static_cast(outVal.x)); - int8_buf.y = float_to_int8_rn(norm_scale * static_cast(outVal.y)); + int8_buf.x = float_to_int8_rn(norm_scale * static_cast(outVal.x)); + int8_buf.y = float_to_int8_rn(norm_scale * static_cast(outVal.y)); 
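// [Editor's sketch -- not part of the patch.] The shift/window index bookkeeping in the
// layernorm_shift_partition_*_COL32_noRes kernels appears to implement the usual Swin-style
// shifted-window partition: each (h, w) token is cyclically shifted by shift_size, assigned to a
// window_size x window_size window, and written out in window-major order (gridDim.y = H,
// gridDim.x = W, one block per token). Host-side reference of the output row index; add the batch
// offset (blockIdx.z * H * W) for the full output_bid. Names are assumptions.
static inline int shifted_window_row(int h, int w, int H, int W, int shift_size, int window_size)
{
    const int sh = (shift_size != 0) ? (h - shift_size + H) % H : h;  // cyclically shifted H index
    const int sw = (shift_size != 0) ? (w - shift_size + W) % W : w;  // cyclically shifted W index
    const int window_idx    = (sh / window_size) * (W / window_size) + (sw / window_size);
    const int idx_in_window = (sh % window_size) * window_size + (sw % window_size);
    return window_idx * window_size * window_size + idx_in_window;
}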
output_ptr[offset_col32_out >> 1] = int8_buf; } } } template -void invokeLayernormShiftPartitionCol32(int8_t* out, - const T* input, - const T* gamma, - const T* beta, - int batch, - int H, - int W, - int n, +void invokeLayernormShiftPartitionCol32(int8_t* out, + const T* input, + const T* gamma, + const T* beta, + int batch, + int H, + int W, + int n, const float* norm_scale_ptr, - int shift_size, - int window_size, + int shift_size, + int window_size, cudaStream_t stream) { dim3 grid(W, H, batch); - int blockSize = (n + 31) / 32 * 32; + int blockSize = (n + 31) / 32 * 32; if (blockSize >= 768) { blockSize = ((blockSize / 4) + 31) / 32 * 32; layernorm_shift_partition_v2_COL32_noRes<<>>( @@ -1616,22 +1616,22 @@ void invokeLayernormShiftPartitionCol32(int8_t* out, } template<> -void invokeLayernormShiftPartitionCol32(int8_t* out, - const half* input, - const half* gamma, - const half* beta, - int batch, - int H, - int W, - int n, +void invokeLayernormShiftPartitionCol32(int8_t* out, + const half* input, + const half* gamma, + const half* beta, + int batch, + int H, + int W, + int n, const float* norm_scale_ptr, - int shift_size, - int window_size, + int shift_size, + int window_size, cudaStream_t stream) { dim3 grid(W, H, batch); - int blockSize = n / 2; - blockSize = (blockSize + 31) / 32 * 32; + int blockSize = n / 2; + blockSize = (blockSize + 31) / 32 * 32; if ((batch * H * W >= 512 && blockSize >= 768) || blockSize > 1024) { blockSize = ((blockSize / 4) + 31) / 32 * 32; @@ -1663,30 +1663,30 @@ void invokeLayernormShiftPartitionCol32(int8_t* out, } } -template void invokeLayernormShiftPartitionCol32(int8_t* out, +template void invokeLayernormShiftPartitionCol32(int8_t* out, const float* input, const float* gamma, const float* beta, - int batch, - int H, - int W, - int n, + int batch, + int H, + int W, + int n, const float* norm_scale_ptr, - int shift_size, - int window_size, + int shift_size, + int window_size, cudaStream_t stream); -template void invokeLayernormShiftPartitionCol32(int8_t* out, - const half* input, - const half* gamma, - const half* beta, - int batch, - int H, - int W, - int n, +template void invokeLayernormShiftPartitionCol32(int8_t* out, + const half* input, + const half* gamma, + const half* beta, + int batch, + int H, + int W, + int n, const float* norm_scale_ptr, - int shift_size, - int window_size, + int shift_size, + int window_size, cudaStream_t stream); /******************* invokeMergeLayerNormCol32 ***********************/ @@ -1700,35 +1700,35 @@ __global__ void merge_layernorm_v2(int8_t* out, const T* __restrict input, const T* __restrict gamma, const T* __restrict beta, - int batch, + int batch, const float* merge_inFactor, - int H, - int W, - int n) + int H, + int W, + int n) { - const int ite = 4; - const int tid = threadIdx.x; - const int W_idx = blockIdx.x; - const int H_idx = blockIdx.y; + const int ite = 4; + const int tid = threadIdx.x; + const int W_idx = blockIdx.x; + const int H_idx = blockIdx.y; const float out_scale = __ldg(merge_inFactor); // const size_t batch_offset = blockIdx.z * H * W * n; // const int input_H_stride = W*n/2; // const int output_H_stride = W*n; const int n_4 = n >> 2; - const int m = batch * 4 * H * W; + const int m = batch * 4 * H * W; __shared__ float s_mean; __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; - float local_out[ite]; + float mean = 0.0f; + float variance = 0.0f; + float local_out[ite]; float sum = 0.0f; #pragma unroll for (int i = 0; i < ite; i++) { int col_id = i * blockDim.x + tid; 
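// [Editor's note -- simplified sketch, not part of the patch.] The invokeLayernormShiftPartitionCol32
// launchers above use one block per token (dim3 grid(W, H, batch)) and pick the block size by rounding
// the per-row work up to a multiple of the 32-thread warp; when that block would be large (the >= 768
// checks, plus CUDA's 1024-thread limit) they switch to the *_v2 kernels, which process ite = 4 columns
// per thread so the block can shrink by 4x. A simplified host-side sketch of that choice (the real
// launchers also factor in batch * H * W on the half path; names are assumptions):
static inline int round_up_to_warp(int x) { return (x + 31) / 32 * 32; }
static inline int pick_block_size(int n_cols, bool* use_v2)
{
    int block = round_up_to_warp(n_cols);
    *use_v2   = (block >= 768 || block > 1024);            // fall back to the 4-columns-per-thread kernel
    if (*use_v2) { block = round_up_to_warp(block / 4); }
    return block;
}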
if (col_id < n) { - int part_id = col_id / n_4; + int part_id = col_id / n_4; int offset_in_W = part_id / 2; int offset_in_H = part_id % 2; // size_t input_id = batch_offset + (2*H_idx + offset_in_H)*input_H_stride + (2*W_idx + offset_in_W)*n_4 + @@ -1737,7 +1737,7 @@ __global__ void merge_layernorm_v2(int8_t* out, int col_input = col_id % n_4; int row_input = blockIdx.z * H * W * 4 + (2 * H_idx + offset_in_H) * W * 2 + (2 * W_idx + offset_in_W); int input_idx_col32 = ((col_input >> 5) << 5) * m + (row_input << 5) + (col_input & 31); - local_out[i] = (float)(__ldg(input + input_idx_col32)); + local_out[i] = (float)(__ldg(input + input_idx_col32)); sum += local_out[i]; } } @@ -1770,9 +1770,9 @@ __global__ void merge_layernorm_v2(int8_t* out, if (col_id < n) { // size_t output_idx = batch_offset + (H_idx*W + W_idx)*n + col_id; - int col_output = col_id; - int row_output = blockIdx.z * H * W + H_idx * W + W_idx; - int output_idx_col32 = ((col_output >> 5) << 5) * (m >> 2) + (row_output << 5) + (col_output & 31); + int col_output = col_id; + int row_output = blockIdx.z * H * W + H_idx * W + W_idx; + int output_idx_col32 = ((col_output >> 5) << 5) * (m >> 2) + (row_output << 5) + (col_output & 31); out[output_idx_col32] = float_to_int8_rn( out_scale * (local_out[i] * s_variance * (float)__ldg(&gamma[col_id]) + (float)__ldg(&beta[col_id]))); } @@ -1781,15 +1781,15 @@ __global__ void merge_layernorm_v2(int8_t* out, // TODO : accelerate with half2 template -void invokeMergeLayerNormCol32(int8_t* output, - const T* input, - const T* gamma, - const T* beta, - int batch, +void invokeMergeLayerNormCol32(int8_t* output, + const T* input, + const T* gamma, + const T* beta, + int batch, const float* merge_inFactor, - int H, - int W, - int n, + int H, + int W, + int n, cudaStream_t stream) { if ((W % 2 != 0) || (H % 2 != 0)) { @@ -1797,8 +1797,8 @@ void invokeMergeLayerNormCol32(int8_t* output, return; } dim3 grid(W / 2, H / 2, batch); - int blockSize = 4 * n; - blockSize = (blockSize + 31) / 32 * 32; + int blockSize = 4 * n; + blockSize = (blockSize + 31) / 32 * 32; // TODO // if (blockSize >= 768) { @@ -1812,76 +1812,76 @@ void invokeMergeLayerNormCol32(int8_t* output, */ } -template void invokeMergeLayerNormCol32(int8_t* output, +template void invokeMergeLayerNormCol32(int8_t* output, const float* input, const float* gamma, const float* beta, - int batch, + int batch, const float* merge_inFactor, - int H, - int W, - int n, + int H, + int W, + int n, cudaStream_t stream); -template void invokeMergeLayerNormCol32(int8_t* output, - const half* input, - const half* gamma, - const half* beta, - int batch, +template void invokeMergeLayerNormCol32(int8_t* output, + const half* input, + const half* gamma, + const half* beta, + int batch, const float* merge_inFactor, - int H, - int W, - int n, + int H, + int W, + int n, cudaStream_t stream); // input1/input2/out matrix with layout of row major (m*n) //(grid, block) must be (m, n/4) // using char4 template -__global__ void add_bias_input_layernorm_ROW_int8IO(int8_t* output, +__global__ void add_bias_input_layernorm_ROW_int8IO(int8_t* output, const int8_t* input1, const int8_t* input2, - const T* bias, - const T* gamma, - const T* beta, - int m, - int n, - const float* input1_deQFactor_ptr, - const float* input2_deQFactor_ptr, - const float* output_scale_ptr) + const T* bias, + const T* gamma, + const T* beta, + int m, + int n, + const float* input1_deQFactor_ptr, + const float* input2_deQFactor_ptr, + const float* output_scale_ptr) { const float input1_deQFactor 
= __ldg(input1_deQFactor_ptr); const float input2_deQFactor = __ldg(input2_deQFactor_ptr); - const float output_scale = __ldg(output_scale_ptr); - int col_start = threadIdx.x << 2; + const float output_scale = __ldg(output_scale_ptr); + int col_start = threadIdx.x << 2; __shared__ float s_mean; __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; + float mean = 0.0f; + float variance = 0.0f; - float local_out[4]; - int outIdx = (blockIdx.x * n + col_start) >> 2; - char4* outTmpPtr = (char4*)output; + float local_out[4]; + int outIdx = (blockIdx.x * n + col_start) >> 2; + char4* outTmpPtr = (char4*)output; char4* input1TmpPtr = (char4*)input1; char4* input2TmpPtr = (char4*)input2; - char4 input1Tmp = __ldg(input1TmpPtr + outIdx); - char4 input2Tmp = __ldg(input2TmpPtr + outIdx); + char4 input1Tmp = __ldg(input1TmpPtr + outIdx); + char4 input2Tmp = __ldg(input2TmpPtr + outIdx); int col_start_tmp = col_start; - local_out[0] = static_cast(input2Tmp.x) * input2_deQFactor + local_out[0] = static_cast(input2Tmp.x) * input2_deQFactor + static_cast(input1Tmp.x) * input1_deQFactor + static_cast(__ldg(bias + col_start_tmp)); col_start_tmp = col_start_tmp + 1; - local_out[1] = static_cast(input2Tmp.y) * input2_deQFactor + local_out[1] = static_cast(input2Tmp.y) * input2_deQFactor + static_cast(input1Tmp.y) * input1_deQFactor + static_cast(__ldg(bias + col_start_tmp)); col_start_tmp = col_start_tmp + 1; - local_out[2] = static_cast(input2Tmp.z) * input2_deQFactor + local_out[2] = static_cast(input2Tmp.z) * input2_deQFactor + static_cast(input1Tmp.z) * input1_deQFactor + static_cast(__ldg(bias + col_start_tmp)); col_start_tmp = col_start_tmp + 1; - local_out[3] = static_cast(input2Tmp.w) * input2_deQFactor + local_out[3] = static_cast(input2Tmp.w) * input2_deQFactor + static_cast(input1Tmp.w) * input1_deQFactor + static_cast(__ldg(bias + col_start_tmp)); @@ -1895,7 +1895,7 @@ __global__ void add_bias_input_layernorm_ROW_int8IO(int8_t* output, local_out[1] = local_out[1] - s_mean; local_out[2] = local_out[2] - s_mean; local_out[3] = local_out[3] - s_mean; - variance = blockReduceSum(local_out[0] * local_out[0] + local_out[1] * local_out[1] + variance = blockReduceSum(local_out[0] * local_out[0] + local_out[1] * local_out[1] + local_out[2] * local_out[2] + local_out[3] * local_out[3]); if (threadIdx.x == 0) { s_variance = variance * __fdividef(1.0f, n) + 1e-6f; @@ -1907,17 +1907,17 @@ __global__ void add_bias_input_layernorm_ROW_int8IO(int8_t* output, + static_cast(__ldg(beta + col_start)); input2Tmp.x = float_to_int8_rn(local_out[0] * output_scale); - col_start = col_start + 1; + col_start = col_start + 1; local_out[1] = (local_out[1] * s_variance) * static_cast(__ldg(gamma + col_start)) + static_cast(__ldg(beta + col_start)); input2Tmp.y = float_to_int8_rn(local_out[1] * output_scale); - col_start = col_start + 1; + col_start = col_start + 1; local_out[2] = (local_out[2] * s_variance) * static_cast(__ldg(gamma + col_start)) + static_cast(__ldg(beta + col_start)); input2Tmp.z = float_to_int8_rn(local_out[2] * output_scale); - col_start = col_start + 1; + col_start = col_start + 1; local_out[3] = (local_out[3] * s_variance) * static_cast(__ldg(gamma + col_start)) + static_cast(__ldg(beta + col_start)); input2Tmp.w = float_to_int8_rn(local_out[3] * output_scale); @@ -1926,50 +1926,50 @@ __global__ void add_bias_input_layernorm_ROW_int8IO(int8_t* output, } template<> -__global__ void add_bias_input_layernorm_ROW_int8IO(int8_t* output, +__global__ void 
add_bias_input_layernorm_ROW_int8IO(int8_t* output, const int8_t* input1, const int8_t* input2, - const half2* bias, - const half2* gamma, - const half2* beta, - int m, - int n, - const float* input1_deQFactor_ptr, - const float* input2_deQFactor_ptr, - const float* output_scale_ptr) + const half2* bias, + const half2* gamma, + const half2* beta, + int m, + int n, + const float* input1_deQFactor_ptr, + const float* input2_deQFactor_ptr, + const float* output_scale_ptr) { const float input1_deQFactor = __ldg(input1_deQFactor_ptr); const float input2_deQFactor = __ldg(input2_deQFactor_ptr); - const float output_scale = __ldg(output_scale_ptr); - int col_start = threadIdx.x << 2; + const float output_scale = __ldg(output_scale_ptr); + int col_start = threadIdx.x << 2; __shared__ float s_mean; __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; + float mean = 0.0f; + float variance = 0.0f; - float local_out[4]; - int outIdx = (blockIdx.x * n + col_start) >> 2; - char4* outTmpPtr = (char4*)output; + float local_out[4]; + int outIdx = (blockIdx.x * n + col_start) >> 2; + char4* outTmpPtr = (char4*)output; char4* input1TmpPtr = (char4*)input1; char4* input2TmpPtr = (char4*)input2; - char4 input1Tmp = __ldg(input1TmpPtr + outIdx); - char4 input2Tmp = __ldg(input2TmpPtr + outIdx); + char4 input1Tmp = __ldg(input1TmpPtr + outIdx); + char4 input2Tmp = __ldg(input2TmpPtr + outIdx); - int col_start_tmp = col_start; - half2 biasTmp = __ldg(bias + (col_start_tmp >> 1)); - local_out[0] = static_cast(input2Tmp.x) * input2_deQFactor + int col_start_tmp = col_start; + half2 biasTmp = __ldg(bias + (col_start_tmp >> 1)); + local_out[0] = static_cast(input2Tmp.x) * input2_deQFactor + static_cast(input1Tmp.x) * input1_deQFactor + static_cast(biasTmp.x); col_start_tmp = col_start_tmp + 1; - local_out[1] = static_cast(input2Tmp.y) * input2_deQFactor + local_out[1] = static_cast(input2Tmp.y) * input2_deQFactor + static_cast(input1Tmp.y) * input1_deQFactor + static_cast(biasTmp.y); col_start_tmp = col_start_tmp + 1; - biasTmp = __ldg(bias + (col_start_tmp >> 1)); - local_out[2] = static_cast(input2Tmp.z) * input2_deQFactor + biasTmp = __ldg(bias + (col_start_tmp >> 1)); + local_out[2] = static_cast(input2Tmp.z) * input2_deQFactor + static_cast(input1Tmp.z) * input1_deQFactor + static_cast(biasTmp.x); col_start_tmp = col_start_tmp + 1; - local_out[3] = static_cast(input2Tmp.w) * input2_deQFactor + local_out[3] = static_cast(input2Tmp.w) * input2_deQFactor + static_cast(input1Tmp.w) * input1_deQFactor + static_cast(biasTmp.y); mean = blockReduceSum(local_out[0] + local_out[1] + local_out[2] + local_out[3]); @@ -1982,7 +1982,7 @@ __global__ void add_bias_input_layernorm_ROW_int8IO(int8_t* output, local_out[1] = local_out[1] - s_mean; local_out[2] = local_out[2] - s_mean; local_out[3] = local_out[3] - s_mean; - variance = blockReduceSum(local_out[0] * local_out[0] + local_out[1] * local_out[1] + variance = blockReduceSum(local_out[0] * local_out[0] + local_out[1] * local_out[1] + local_out[2] * local_out[2] + local_out[3] * local_out[3]); if (threadIdx.x == 0) { s_variance = variance * __fdividef(1.0f, n) + 1e-6f; @@ -1991,43 +1991,43 @@ __global__ void add_bias_input_layernorm_ROW_int8IO(int8_t* output, __syncthreads(); col_start_tmp = col_start >> 1; - biasTmp = __ldg(gamma + col_start_tmp); + biasTmp = __ldg(gamma + col_start_tmp); half2 betaTmp = __ldg(beta + col_start_tmp); local_out[0] = (local_out[0] * s_variance) * static_cast(biasTmp.x) + static_cast(betaTmp.x); - input2Tmp.x = 
float_to_int8_rn(local_out[0] * output_scale); + input2Tmp.x = float_to_int8_rn(local_out[0] * output_scale); - col_start = col_start + 1; + col_start = col_start + 1; local_out[1] = (local_out[1] * s_variance) * static_cast(biasTmp.y) + static_cast(betaTmp.y); - input2Tmp.y = float_to_int8_rn(local_out[1] * output_scale); + input2Tmp.y = float_to_int8_rn(local_out[1] * output_scale); - col_start = col_start + 1; + col_start = col_start + 1; col_start_tmp = col_start >> 1; - biasTmp = __ldg(gamma + col_start_tmp); - betaTmp = __ldg(beta + col_start_tmp); - local_out[2] = (local_out[2] * s_variance) * static_cast(biasTmp.x) + static_cast(betaTmp.x); - input2Tmp.z = float_to_int8_rn(local_out[2] * output_scale); + biasTmp = __ldg(gamma + col_start_tmp); + betaTmp = __ldg(beta + col_start_tmp); + local_out[2] = (local_out[2] * s_variance) * static_cast(biasTmp.x) + static_cast(betaTmp.x); + input2Tmp.z = float_to_int8_rn(local_out[2] * output_scale); - col_start = col_start + 1; + col_start = col_start + 1; local_out[3] = (local_out[3] * s_variance) * static_cast(biasTmp.y) + static_cast(betaTmp.y); - input2Tmp.w = float_to_int8_rn(local_out[3] * output_scale); + input2Tmp.w = float_to_int8_rn(local_out[3] * output_scale); outTmpPtr[outIdx] = input2Tmp; } template -void invokeAddBiasResidualLayerNormRow(int8_t* output, +void invokeAddBiasResidualLayerNormRow(int8_t* output, const int8_t* input1, const int8_t* input2, - const T* bias, - const T* gamma, - const T* beta, - int m, - int n, - cudaStream_t stream, - const float* input1_deQFactor_ptr, - const float* input2_deQFactor_ptr, - const float* output_scale_ptr) + const T* bias, + const T* gamma, + const T* beta, + int m, + int n, + cudaStream_t stream, + const float* input1_deQFactor_ptr, + const float* input2_deQFactor_ptr, + const float* output_scale_ptr) { dim3 grid(m); dim3 block(n / 4); @@ -2060,57 +2060,57 @@ void invokeAddBiasResidualLayerNormRow(int8_t* output, } } -template void invokeAddBiasResidualLayerNormRow(int8_t* output, +template void invokeAddBiasResidualLayerNormRow(int8_t* output, const int8_t* input1, const int8_t* input2, - const float* bias, - const float* gamma, - const float* beta, - int m, - int n, - cudaStream_t stream, - const float* input1_deQFactor_ptr, - const float* input2_deQFactor_ptr, - const float* output_scale_ptr); - -template void invokeAddBiasResidualLayerNormRow(int8_t* output, + const float* bias, + const float* gamma, + const float* beta, + int m, + int n, + cudaStream_t stream, + const float* input1_deQFactor_ptr, + const float* input2_deQFactor_ptr, + const float* output_scale_ptr); + +template void invokeAddBiasResidualLayerNormRow(int8_t* output, const int8_t* input1, const int8_t* input2, - const half* bias, - const half* gamma, - const half* beta, - int m, - int n, - cudaStream_t stream, - const float* input1_deQFactor_ptr, - const float* input2_deQFactor_ptr, - const float* output_scale_ptr); + const half* bias, + const half* gamma, + const half* beta, + int m, + int n, + cudaStream_t stream, + const float* input1_deQFactor_ptr, + const float* input2_deQFactor_ptr, + const float* output_scale_ptr); // input1/input2/out matrix with layout of row major (m*n) //(grid, block) must be (m, n) template -__global__ void add_bias_input_layernorm_ROW_int8I_DataTypeO(T* output, +__global__ void add_bias_input_layernorm_ROW_int8I_DataTypeO(T* output, const int8_t* input1, const int8_t* input2, - const T* bias, - const T* gamma, - const T* beta, - int m, - int n, - const float* input1_deQFactor_ptr, - 
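// [Editor's sketch -- not part of the patch; names and the scalar form are assumptions.] What one row
// of add_bias_input_layernorm_ROW_int8IO computes, written out scalar: dequantize both int8 inputs,
// add the bias, LayerNorm over the n columns, then requantize. float_to_int8_rn is assumed to behave
// like the round-to-nearest, saturating conversion below.
#include <cstdint>
#include <cmath>
#include <algorithm>
#include <vector>
static inline int8_t quantize_rn(float x, float scale)
{
    const float v = nearbyintf(x * scale);  // round to nearest (even), like the *_rn conversion
    return static_cast<int8_t>(std::max(-128.0f, std::min(127.0f, v)));
}
static void add_bias_residual_layernorm_row_ref(int8_t* out, const int8_t* input1, const int8_t* input2,
                                                const float* bias, const float* gamma, const float* beta,
                                                int n, float deq1, float deq2, float out_scale)
{
    std::vector<float> x(n);
    float mean = 0.0f, var = 0.0f;
    for (int i = 0; i < n; ++i) {
        x[i] = input1[i] * deq1 + input2[i] * deq2 + bias[i];  // dequantized residual sum + bias
        mean += x[i];
    }
    mean /= n;
    for (int i = 0; i < n; ++i) {
        var += (x[i] - mean) * (x[i] - mean);
    }
    const float inv_std = 1.0f / std::sqrt(var / n + 1e-6f);
    for (int i = 0; i < n; ++i) {
        out[i] = quantize_rn((x[i] - mean) * inv_std * gamma[i] + beta[i], out_scale);
    }
}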
const float* input2_deQFactor_ptr) + const T* bias, + const T* gamma, + const T* beta, + int m, + int n, + const float* input1_deQFactor_ptr, + const float* input2_deQFactor_ptr) { const float input1_deQFactor = __ldg(input1_deQFactor_ptr); const float input2_deQFactor = __ldg(input2_deQFactor_ptr); - int col_start = threadIdx.x; + int col_start = threadIdx.x; __shared__ float s_mean; __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; + float mean = 0.0f; + float variance = 0.0f; float local_out; - int idx = blockIdx.x * n + col_start; + int idx = blockIdx.x * n + col_start; local_out = static_cast(__ldg(input2 + idx)) * input2_deQFactor + static_cast(__ldg(input1 + idx)) * input1_deQFactor @@ -2141,33 +2141,33 @@ __global__ void add_bias_input_layernorm_ROW_int8I_DataTypeO(T* output, // input1/input2/out matrix with layout of row major (m*n) //(grid, block) must be (m, n/2) template<> -__global__ void add_bias_input_layernorm_ROW_int8I_DataTypeO(half2* output, +__global__ void add_bias_input_layernorm_ROW_int8I_DataTypeO(half2* output, const int8_t* input1, const int8_t* input2, - const half2* bias, - const half2* gamma, - const half2* beta, - int m, - int n, - const float* input1_deQFactor_ptr, - const float* input2_deQFactor_ptr) + const half2* bias, + const half2* gamma, + const half2* beta, + int m, + int n, + const float* input1_deQFactor_ptr, + const float* input2_deQFactor_ptr) { const float input1_deQFactor = __ldg(input1_deQFactor_ptr); const float input2_deQFactor = __ldg(input2_deQFactor_ptr); - int col_start = threadIdx.x << 1; + int col_start = threadIdx.x << 1; __shared__ float s_mean; __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; + float mean = 0.0f; + float variance = 0.0f; float2 local_out; - int idx = (blockIdx.x * n + col_start) >> 1; + int idx = (blockIdx.x * n + col_start) >> 1; const char2* input1_ptr2 = (const char2*)input1; const char2* input2_ptr2 = (const char2*)input2; - char2 input_tmp1 = __ldg(input1_ptr2 + idx); - char2 input_tmp2 = __ldg(input2_ptr2 + idx); + char2 input_tmp1 = __ldg(input1_ptr2 + idx); + char2 input_tmp2 = __ldg(input2_ptr2 + idx); half2 bias_tmp = __ldg(bias + threadIdx.x); @@ -2196,7 +2196,7 @@ __global__ void add_bias_input_layernorm_ROW_int8I_DataTypeO(half2* output, __syncthreads(); half2 gamma_tmp = __ldg(gamma + threadIdx.x); - half2 beta_tmp = __ldg(beta + threadIdx.x); + half2 beta_tmp = __ldg(beta + threadIdx.x); local_out.x = (local_out.x * s_variance) * static_cast(gamma_tmp.x) + static_cast(beta_tmp.x); local_out.y = (local_out.y * s_variance) * static_cast(gamma_tmp.y) + static_cast(beta_tmp.y); @@ -2208,17 +2208,17 @@ __global__ void add_bias_input_layernorm_ROW_int8I_DataTypeO(half2* output, } template -void invokeAddBiasResidualLayerNormRow(T* output, +void invokeAddBiasResidualLayerNormRow(T* output, const int8_t* input1, const int8_t* input2, - const T* bias, - const T* gamma, - const T* beta, - int m, - int n, - cudaStream_t stream, - const float* input1_deQFactor_ptr, - const float* input2_deQFactor_ptr) + const T* bias, + const T* gamma, + const T* beta, + int m, + int n, + cudaStream_t stream, + const float* input1_deQFactor_ptr, + const float* input2_deQFactor_ptr) { dim3 grid(m); dim3 block(n); @@ -2243,28 +2243,28 @@ void invokeAddBiasResidualLayerNormRow(T* output, } } -template void invokeAddBiasResidualLayerNormRow(float* output, +template void invokeAddBiasResidualLayerNormRow(float* output, const int8_t* input1, const int8_t* input2, - const float* bias, - 
const float* gamma, - const float* beta, - int m, - int n, - cudaStream_t stream, - const float* input1_deQFactor_ptr, - const float* input2_deQFactor_ptr); - -template void invokeAddBiasResidualLayerNormRow(half* output, + const float* bias, + const float* gamma, + const float* beta, + int m, + int n, + cudaStream_t stream, + const float* input1_deQFactor_ptr, + const float* input2_deQFactor_ptr); + +template void invokeAddBiasResidualLayerNormRow(half* output, const int8_t* input1, const int8_t* input2, - const half* bias, - const half* gamma, - const half* beta, - int m, - int n, - cudaStream_t stream, - const float* input1_deQFactor_ptr, - const float* input2_deQFactor_ptr); + const half* bias, + const half* gamma, + const half* beta, + int m, + int n, + cudaStream_t stream, + const float* input1_deQFactor_ptr, + const float* input2_deQFactor_ptr); } // namespace fastertransformer \ No newline at end of file diff --git a/src/fastertransformer/kernels/layernorm_int8_kernels.h b/src/fastertransformer/kernels/layernorm_int8_kernels.h index 53f01b25c..be44980ef 100644 --- a/src/fastertransformer/kernels/layernorm_int8_kernels.h +++ b/src/fastertransformer/kernels/layernorm_int8_kernels.h @@ -26,133 +26,133 @@ namespace fastertransformer { template -void invokeAddBiasResidualLayerNormCol32(T* output, +void invokeAddBiasResidualLayerNormCol32(T* output, const int32_t* input1, - const T* input2, - const T* bias, - const T* gamma, - const T* beta, - int m, - int n, - cudaStream_t stream, - const float* weight_amax, - const float* input1_amax_ptr); + const T* input2, + const T* bias, + const T* gamma, + const T* beta, + int m, + int n, + cudaStream_t stream, + const float* weight_amax, + const float* input1_amax_ptr); template -void invokeAddBiasResidualLayerNormCol32(int8_t* output, +void invokeAddBiasResidualLayerNormCol32(int8_t* output, const int8_t* input1, const int8_t* input2, - const T* bias, - const T* gamma, - const T* beta, - int m, - int n, - cudaStream_t stream, - const float* input1_deQFactor_ptr, - const float* input2_deQFactor_ptr, - const float* output_scale_ptr); + const T* bias, + const T* gamma, + const T* beta, + int m, + int n, + cudaStream_t stream, + const float* input1_deQFactor_ptr, + const float* input2_deQFactor_ptr, + const float* output_scale_ptr); template -void invokeAddBiasResidualLayerNormCol32(T* output, +void invokeAddBiasResidualLayerNormCol32(T* output, const int8_t* input1, const int8_t* input2, - const T* bias, - const T* gamma, - const T* beta, - int m, - int n, - cudaStream_t stream, - const float* input1_deQFactor_ptr, - const float* input2_deQFactor_ptr); + const T* bias, + const T* gamma, + const T* beta, + int m, + int n, + cudaStream_t stream, + const float* input1_deQFactor_ptr, + const float* input2_deQFactor_ptr); template -void invokeAddBiasResidualLayerNormCol32_noRes(int8_t* output, - int8_t* input1, - T* input2, - const T* bias, - const T* gamma, - const T* beta, - int m, - int n, +void invokeAddBiasResidualLayerNormCol32_noRes(int8_t* output, + int8_t* input1, + T* input2, + const T* bias, + const T* gamma, + const T* beta, + int m, + int n, cudaStream_t stream, const float* input1_deQFactor_ptr, const float* output_scale_ptr); template -void invokeAddBiasResidualLayerNormCol32_noRes(int8_t* output, - int32_t* input1, - T* input2, - const T* bias, - const T* gamma, - const T* beta, - int m, - int n, +void invokeAddBiasResidualLayerNormCol32_noRes(int8_t* output, + int32_t* input1, + T* input2, + const T* bias, + const T* gamma, + const T* 
beta, + int m, + int n, cudaStream_t stream, const float* weight_amax, const float* input1_amax_ptr, const float* output_scale_ptr); template -void invokeLayernormCol32(int8_t* out, - const T* input, - const T* gamma, - const T* beta, - int m, - int n, +void invokeLayernormCol32(int8_t* out, + const T* input, + const T* gamma, + const T* beta, + int m, + int n, const float* norm_scale_ptr, cudaStream_t stream); template -void invokeLayernormShiftPartitionCol32(int8_t* out, - const T* input, - const T* gamma, - const T* beta, - int batch, - int H, - int W, - int n, +void invokeLayernormShiftPartitionCol32(int8_t* out, + const T* input, + const T* gamma, + const T* beta, + int batch, + int H, + int W, + int n, const float* norm_scale_ptr, - int shift_size, - int window_size, + int shift_size, + int window_size, cudaStream_t stream); template -void invokeMergeLayerNormCol32(int8_t* output, - const T* input, - const T* gamma, - const T* beta, - int batch, +void invokeMergeLayerNormCol32(int8_t* output, + const T* input, + const T* gamma, + const T* beta, + int batch, const float* merge_inFactor, - int H, - int W, - int n, + int H, + int W, + int n, cudaStream_t stream); template -void invokeAddBiasResidualLayerNormRow(int8_t* output, +void invokeAddBiasResidualLayerNormRow(int8_t* output, const int8_t* input1, const int8_t* input2, - const T* bias, - const T* gamma, - const T* beta, - int m, - int n, - cudaStream_t stream, - const float* input1_deQFactor_ptr, - const float* input2_deQFactor_ptr, - const float* output_scale_ptr); + const T* bias, + const T* gamma, + const T* beta, + int m, + int n, + cudaStream_t stream, + const float* input1_deQFactor_ptr, + const float* input2_deQFactor_ptr, + const float* output_scale_ptr); template -void invokeAddBiasResidualLayerNormRow(T* output, +void invokeAddBiasResidualLayerNormRow(T* output, const int8_t* input1, const int8_t* input2, - const T* bias, - const T* gamma, - const T* beta, - int m, - int n, - cudaStream_t stream, - const float* input1_deQFactor_ptr, - const float* input2_deQFactor_ptr); + const T* bias, + const T* gamma, + const T* beta, + int m, + int n, + cudaStream_t stream, + const float* input1_deQFactor_ptr, + const float* input2_deQFactor_ptr); } // namespace fastertransformer diff --git a/src/fastertransformer/kernels/layernorm_kernels.cu b/src/fastertransformer/kernels/layernorm_kernels.cu index 96a090ed9..10f7a1383 100644 --- a/src/fastertransformer/kernels/layernorm_kernels.cu +++ b/src/fastertransformer/kernels/layernorm_kernels.cu @@ -21,39 +21,44 @@ namespace fastertransformer { // * Note that typename T is half2 or bfloat2 type -template +template __global__ void generalAddBiasResidualLayerNormOpt(T* normed_output, T* output, const T* __restrict bias, - const T* __restrict residual, + const T* __restrict residual1, + const T* __restrict residual2, const T* __restrict gamma, const T* __restrict beta, - int m, - int n) + const float layernorm_eps, + int m, + int n) { __shared__ float s_mean; __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; + float mean = 0.0f; + float variance = 0.0f; T local_sum = float2type2(0.0f); #pragma unroll for (int i = threadIdx.x; i < n; i += blockDim.x) { const int index = blockIdx.x * n + i; - T val = float2type2(0.0f); + T val = float2type2(0.0f); if (IS_BIAS) { val = hadd2(val, ldg(&bias[i])); } - if (IS_RESIDUAL) { - val = hadd2(val, ldg(&residual[index])); + if (RESIDUAL_NUM == 1) { + val = hadd2(val, ldg(&residual1[index])); + } + else if (RESIDUAL_NUM == 2) { + val 
= hadd2(hadd2(val, ldg(&residual1[index])), ldg(&residual2[index])); } if (IS_OUTPUT) { val = hadd2(val, output[index]); } output[index] = val; - local_sum = hadd2(local_sum, val); + local_sum = hadd2(local_sum, val); } mean = blockReduceSum((float)(local_sum.x + local_sum.y)); @@ -66,7 +71,7 @@ __global__ void generalAddBiasResidualLayerNormOpt(T* normed_output, float local_var_sum = 0.0f; #pragma unroll UNROLL_FACTOR for (int i = threadIdx.x; i < n; i += blockDim.x) { - T val = output[blockIdx.x * n + i]; + T val = output[blockIdx.x * n + i]; float diff_1 = (float)(val.x) - s_mean; float diff_2 = (float)(val.y) - s_mean; local_var_sum += (diff_1 * diff_1 + diff_2 * diff_2); @@ -74,16 +79,16 @@ __global__ void generalAddBiasResidualLayerNormOpt(T* normed_output, variance = blockReduceSum(local_var_sum); if (threadIdx.x == 0) { - s_variance = rsqrtf(variance / n / 2 + 1e-6f); + s_variance = rsqrtf(variance / n / 2 + layernorm_eps); } __syncthreads(); T mean_2 = float2type2(s_mean); - T var_2 = float2type2(s_variance); + T var_2 = float2type2(s_variance); #pragma unroll UNROLL_FACTOR for (int i = threadIdx.x; i < n; i += blockDim.x) { const int index = blockIdx.x * n + i; - T val = hmul2(hmul2(hsub2(output[index], mean_2), var_2), ldg(&gamma[i])); + T val = hmul2(hsub2(output[index], mean_2), var_2, ldg(&gamma[i])); if (IS_BETA) { val = hadd2(val, ldg(&beta[i])); } @@ -92,48 +97,56 @@ __global__ void generalAddBiasResidualLayerNormOpt(T* normed_output, } // * Note that typename T is half2 or bfloat2 type -template +template __global__ void generalAddBiasResidualLayerNormOpt2(T* normed_output, T* output, const T* __restrict bias, - const T* __restrict residual, + const T* __restrict residual1, + const T* __restrict residual2, const T* __restrict gamma, const T* __restrict beta, - int m, - int n) + const float layernorm_eps, + int m, + int n) { __shared__ float s_mean; __shared__ float s_variance; - float x_sum = 0.0f; - float x2_sum = 0.0f; - const int b_offset = blockIdx.x * n; - using T1 = typename TypeConverter::Type; + float x_sum = 0.0f; + float x2_sum = 0.0f; + const int b_offset = blockIdx.x * n; + using T1 = typename TypeConverter::Type; #pragma unroll UNROLL_FACTOR for (int i = threadIdx.x; i < n; i += blockDim.x) { const int index = b_offset + i; - float val_1 = 0.0f; - float val_2 = 0.0f; - T tmp; + float val_1 = 0.0f; + float val_2 = 0.0f; + T tmp; if (IS_BIAS) { tmp = ldg(&bias[i]); val_1 += static_cast(tmp.x); val_2 += static_cast(tmp.y); } - if (IS_RESIDUAL) { - tmp = ldg(&residual[index]); + if (RESIDUAL_NUM == 1) { + tmp = ldg(&residual1[index]); val_1 += static_cast(tmp.x); val_2 += static_cast(tmp.y); } + else if (RESIDUAL_NUM == 2) { + tmp = ldg(&residual1[index]); + T tmp2 = ldg(&residual2[index]); + val_1 += (static_cast(tmp.x) + static_cast(tmp2.x)); + val_2 += (static_cast(tmp.y) + static_cast(tmp2.y)); + } if (IS_OUTPUT) { tmp = ldg(&output[index]); val_1 += static_cast(tmp.x); val_2 += static_cast(tmp.y); } - tmp.x = float2type(val_1); - tmp.y = float2type(val_2); + tmp.x = float2type(val_1); + tmp.y = float2type(val_2); output[index] = tmp; x_sum += val_1 + val_2; x2_sum += val_1 * val_1 + val_2 * val_2; @@ -144,18 +157,18 @@ __global__ void generalAddBiasResidualLayerNormOpt2(T* normed_output, blockReduceSumV2(sums); if (threadIdx.x == 0) { - s_mean = sums[0] / n / 2; - s_variance = rsqrtf(sums[1] / n / 2 - s_mean * s_mean + 1e-6f); + s_mean = sums[0] / n / 2; + s_variance = rsqrtf(sums[1] / n / 2 - s_mean * s_mean + layernorm_eps); } __syncthreads(); T mean_2 = 
float2type2(s_mean); - T var_2 = float2type2(s_variance); + T var_2 = float2type2(s_variance); #pragma unroll UNROLL_FACTOR for (int i = threadIdx.x; i < n; i += blockDim.x) { const int index = b_offset + i; - T val = hmul2(hmul2(hsub2(output[index], mean_2), var_2), ldg(&gamma[i])); + T val = hmul2(hsub2(output[index], mean_2), var_2, ldg(&gamma[i])); if (IS_BETA) { val = hadd2(val, ldg(&beta[i])); } @@ -165,18 +178,18 @@ __global__ void generalAddBiasResidualLayerNormOpt2(T* normed_output, // TODO(bhsueh) add half2 implementation template -__global__ void -addBiasResidualPostLayerNorm(T* out, const T* input, const T* bias, const T* gamma, const T* beta, int m, int n) +__global__ void addBiasResidualPostLayerNorm( + T* out, const T* input, const T* bias, const T* gamma, const T* beta, const float layernorm_eps, int m, int n) { __shared__ float s_mean; __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; - float local_out_cache[N]; + float mean = 0.0f; + float variance = 0.0f; + float local_out_cache[N]; #pragma unroll N for (int idx = threadIdx.x, i = 0; idx < n && i < N; ++i) { - float local_out = (float)(out[blockIdx.x * n + idx] + input[blockIdx.x * n + idx] + __ldg(&bias[idx])); + float local_out = (float)(add(out[blockIdx.x * n + idx], input[blockIdx.x * n + idx], ldg(&bias[idx]))); mean += local_out; // save local_out to local_out_cache to save some recompute local_out_cache[i] = local_out; @@ -197,7 +210,7 @@ addBiasResidualPostLayerNorm(T* out, const T* input, const T* bias, const T* gam } variance = blockReduceSum(variance); if (threadIdx.x == 0) { - s_variance = variance / n + 1e-6f; + s_variance = variance / n + layernorm_eps; } __syncthreads(); @@ -205,32 +218,38 @@ addBiasResidualPostLayerNorm(T* out, const T* input, const T* bias, const T* gam for (int idx = threadIdx.x, i = 0; idx < n && i < N; ++i) { float local_out = local_out_cache[i]; out[blockIdx.x * n + idx] = - (T)(((local_out - s_mean) * rsqrtf(s_variance)) * (float)(__ldg(&gamma[idx])) + (float)(__ldg(&beta[idx]))); + (T)(((local_out - s_mean) * rsqrtf(s_variance)) * (float)(ldg(&gamma[idx])) + (float)(ldg(&beta[idx]))); idx += blockDim.x; } } template -__global__ void addBiasResidualPostLayerNormHalf( - half* out, const half* input, const half* bias, const half* gamma, const half* beta, int m, int n) +__global__ void addBiasResidualPostLayerNormHalf(half* out, + const half* input, + const half* bias, + const half* gamma, + const half* beta, + const float layernorm_eps, + int m, + int n) { __shared__ float s_mean; __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; + float mean = 0.0f; + float variance = 0.0f; - half2* out_ptr = (half2*)out; + half2* out_ptr = (half2*)out; const half2* input_ptr = (const half2*)input; - const half2* bias_ptr = (const half2*)bias; + const half2* bias_ptr = (const half2*)bias; const half2* gamma_ptr = (const half2*)gamma; - const half2* beta_ptr = (const half2*)beta; + const half2* beta_ptr = (const half2*)beta; float2 out_fp2_cache[N]; float local_out = 0.0f; #pragma unroll N for (int idx = threadIdx.x, i = 0; idx < n / 2 && i < N; ++i) { - int id = blockIdx.x * n / 2 + idx; + int id = blockIdx.x * n / 2 + idx; float2 local_out_fp2 = __half22float2(__hadd2(__hadd2(out_ptr[id], input_ptr[id]), __ldg(&bias_ptr[idx]))); local_out += local_out_fp2.x; local_out += local_out_fp2.y; @@ -255,329 +274,378 @@ __global__ void addBiasResidualPostLayerNormHalf( variance = blockReduceSum(variance); if (threadIdx.x == 0) { - s_variance = rsqrtf(variance 
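// [Editor's note -- not part of the patch.] The edits to generalAddBiasResidualLayerNormOpt /
// generalAddBiasResidualLayerNormOpt2 above do two things: the epsilon becomes a layernorm_eps
// argument instead of a hard-coded 1e-6f, and a RESIDUAL_NUM template parameter lets the same fused
// kernel add either one or two residual streams (residual1, residual2) before normalizing. Scalar
// sketch of the pre-normalization accumulation only (illustrative, not the kernel itself):
template<int RESIDUAL_NUM>
static inline float fused_prenorm_input(float prev_output, float bias, float residual1, float residual2,
                                        bool is_bias, bool is_output)
{
    float v = is_bias ? bias : 0.0f;
    if (RESIDUAL_NUM == 1) {
        v += residual1;
    }
    else if (RESIDUAL_NUM == 2) {
        v += residual1 + residual2;
    }
    if (is_output) {
        v += prev_output;  // IS_OUTPUT accumulates the existing contents of the output buffer
    }
    return v;              // this value is what the kernel then layer-normalizes with layernorm_eps
}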
/ n + 1e-6f); + s_variance = rsqrtf(variance / n + layernorm_eps); } __syncthreads(); #pragma unroll N for (int idx = threadIdx.x, i = 0; i < N && idx < n / 2; ++i) { - int id = blockIdx.x * n / 2 + idx; + int id = blockIdx.x * n / 2 + idx; float2 local_out_fp2 = out_fp2_cache[i]; - float2 gamma_val = __half22float2(__ldg(&gamma_ptr[idx])); - float2 beta_val = __half22float2(__ldg(&beta_ptr[idx])); - local_out_fp2.x = (local_out_fp2.x - s_mean) * s_variance * gamma_val.x + beta_val.x; - local_out_fp2.y = (local_out_fp2.y - s_mean) * s_variance * gamma_val.y + beta_val.y; - out_ptr[id] = __float22half2_rn(local_out_fp2); + float2 gamma_val = __half22float2(__ldg(&gamma_ptr[idx])); + float2 beta_val = __half22float2(__ldg(&beta_ptr[idx])); + local_out_fp2.x = (local_out_fp2.x - s_mean) * s_variance * gamma_val.x + beta_val.x; + local_out_fp2.y = (local_out_fp2.y - s_mean) * s_variance * gamma_val.y + beta_val.y; + out_ptr[id] = __float22half2_rn(local_out_fp2); idx += blockDim.x; } } +// Optimization for fp16 and fp16 (bf162 and half2) template -__global__ void -generalAddBiasResidualPostLayerNorm(T* out, const T* input, const T* bias, const T* gamma, const T* beta, int m, int n) +__global__ void generalAddBiasResidualPostLayerNorm( + T* out, const T* input, const T* bias, const T* gamma, const T* beta, const float layernorm_eps, int m, int n) { + using T2 = typename TypeConverter::Type; __shared__ float s_mean; __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; + float mean = 0.0f; + float variance = 0.0f; - for (int idx = threadIdx.x; idx < n; idx += blockDim.x) { - float local_out = (float)(out[blockIdx.x * n + idx] + input[blockIdx.x * n + idx] + __ldg(&bias[idx])); - mean += local_out; - // save local_out to out to save some recompute - out[blockIdx.x * n + idx] = local_out; + T2* out_ptr = (T2*)out; + const T2* input_ptr = (const T2*)input; + const T2* bias_ptr = (const T2*)bias; + const T2* gamma_ptr = (const T2*)gamma; + const T2* beta_ptr = (const T2*)beta; + + float local_out = 0.0f; + for (int idx = threadIdx.x; idx < n / 2; idx += blockDim.x) { + int id = blockIdx.x * n / 2 + idx; + T2 tmp = hadd2(hadd2(out_ptr[id], input_ptr[id]), ldg(&bias_ptr[idx])); + float2 local_out_fp2 = type22float2(tmp); + local_out += local_out_fp2.x; + local_out += local_out_fp2.y; + // save tmp to out_ptr to save some recomputation + out_ptr[id] = tmp; } - mean = blockReduceSum(mean); + mean = blockReduceSum(local_out); if (threadIdx.x == 0) { s_mean = mean / n; } __syncthreads(); - for (int idx = threadIdx.x; idx < n; idx += blockDim.x) { - float local_out = out[blockIdx.x * n + idx]; - variance += (local_out - s_mean) * (local_out - s_mean); + for (int idx = threadIdx.x; idx < n / 2; idx += blockDim.x) { + int id = blockIdx.x * n / 2 + idx; + float2 local_out_fp2 = type22float2(out_ptr[id]); + variance += (local_out_fp2.x - s_mean) * (local_out_fp2.x - s_mean); + variance += (local_out_fp2.y - s_mean) * (local_out_fp2.y - s_mean); } + variance = blockReduceSum(variance); if (threadIdx.x == 0) { - s_variance = variance / n + 1e-6f; + s_variance = rsqrtf(variance / n + layernorm_eps); } __syncthreads(); - for (int idx = threadIdx.x; idx < n; idx += blockDim.x) { - float local_out = out[blockIdx.x * n + idx]; - out[blockIdx.x * n + idx] = - (T)(((local_out - s_mean) * rsqrtf(s_variance)) * (float)(__ldg(&gamma[idx])) + (float)(__ldg(&beta[idx]))); + for (int idx = threadIdx.x; idx < n / 2; idx += blockDim.x) { + int id = blockIdx.x * n / 2 + idx; + float2 local_out_fp2 = 
type22float2(out_ptr[id]); + float2 gamma_val = type22float2(ldg(&gamma_ptr[idx])); + float2 beta_val = type22float2(ldg(&beta_ptr[idx])); + local_out_fp2.x = (local_out_fp2.x - s_mean) * s_variance * gamma_val.x + beta_val.x; + local_out_fp2.y = (local_out_fp2.y - s_mean) * s_variance * gamma_val.y + beta_val.y; + out_ptr[id] = float22type2(local_out_fp2); } } template<> -__global__ void generalAddBiasResidualPostLayerNorm( - half* out, const half* input, const half* bias, const half* gamma, const half* beta, int m, int n) +__global__ void generalAddBiasResidualPostLayerNorm(float* out, + const float* input, + const float* bias, + const float* gamma, + const float* beta, + const float layernorm_eps, + int m, + int n) { __shared__ float s_mean; __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; - - half2* out_ptr = (half2*)out; - const half2* input_ptr = (const half2*)input; - const half2* bias_ptr = (const half2*)bias; - const half2* gamma_ptr = (const half2*)gamma; - const half2* beta_ptr = (const half2*)beta; + float mean = 0.0f; + float variance = 0.0f; - float local_out = 0.0f; - for (int idx = threadIdx.x; idx < n / 2; idx += blockDim.x) { - int id = blockIdx.x * n / 2 + idx; - half2 tmp = __hadd2(__hadd2(out_ptr[id], input_ptr[id]), __ldg(&bias_ptr[idx])); - float2 local_out_fp2 = __half22float2(tmp); - local_out += local_out_fp2.x; - local_out += local_out_fp2.y; - // save tmp to out_ptr to save some recomputation - out_ptr[id] = tmp; + for (int idx = threadIdx.x; idx < n; idx += blockDim.x) { + float local_out = (float)(out[blockIdx.x * n + idx] + input[blockIdx.x * n + idx] + __ldg(&bias[idx])); + mean += local_out; + // save local_out to out to save some recompute + out[blockIdx.x * n + idx] = local_out; } - mean = blockReduceSum(local_out); + mean = blockReduceSum(mean); if (threadIdx.x == 0) { s_mean = mean / n; } __syncthreads(); - for (int idx = threadIdx.x; idx < n / 2; idx += blockDim.x) { - int id = blockIdx.x * n / 2 + idx; - float2 local_out_fp2 = __half22float2(out_ptr[id]); - variance += (local_out_fp2.x - s_mean) * (local_out_fp2.x - s_mean); - variance += (local_out_fp2.y - s_mean) * (local_out_fp2.y - s_mean); + for (int idx = threadIdx.x; idx < n; idx += blockDim.x) { + float local_out = out[blockIdx.x * n + idx]; + variance += (local_out - s_mean) * (local_out - s_mean); } - variance = blockReduceSum(variance); if (threadIdx.x == 0) { - s_variance = rsqrtf(variance / n + 1e-6f); + s_variance = rsqrtf(variance / n + layernorm_eps); } __syncthreads(); - for (int idx = threadIdx.x; idx < n / 2; idx += blockDim.x) { - int id = blockIdx.x * n / 2 + idx; - float2 local_out_fp2 = __half22float2(out_ptr[id]); - float2 gamma_val = __half22float2(__ldg(&gamma_ptr[idx])); - float2 beta_val = __half22float2(__ldg(&beta_ptr[idx])); - local_out_fp2.x = (local_out_fp2.x - s_mean) * s_variance * gamma_val.x + beta_val.x; - local_out_fp2.y = (local_out_fp2.y - s_mean) * s_variance * gamma_val.y + beta_val.y; - out_ptr[id] = __float22half2_rn(local_out_fp2); + for (int idx = threadIdx.x; idx < n; idx += blockDim.x) { + float local_out = out[blockIdx.x * n + idx]; + out[blockIdx.x * n + idx] = + (float)(((local_out - s_mean) * s_variance) * (float)(__ldg(&gamma[idx])) + (float)(__ldg(&beta[idx]))); } } +// applied to half and b16 template __global__ void addBiasResidualPostLayerNormV2(T* out, const T* __restrict input, const T* __restrict bias, const T* __restrict gamma, const T* __restrict beta, - int n) + const float layernorm_eps, + int n) { - const int 
ite = 4; - const int tid = threadIdx.x; - const int bid = blockIdx.x; - + using T2 = typename TypeConverter::Type; + const int ite = 4; + const int tid = threadIdx.x; + const int bid = blockIdx.x; __shared__ float s_mean; __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; - float local_out[ite]; + float mean = 0.0f; + float variance = 0.0f; + T2 local_out_half2[ite]; - float sum = 0.0f; + T2* out_ptr = (T2*)out; + const T2* input_ptr = (const T2*)input; + const T2* bias_ptr = (const T2*)bias; + const T2* gamma_ptr = (const T2*)gamma; + const T2* beta_ptr = (const T2*)beta; + + // float sum = 0.0f; + T2 sum = float2type2(0.0f); #pragma unroll for (int i = 0; i < ite; i++) { - int col_id = i * blockDim.x + tid; - int id = bid * n + col_id; - local_out[i] = (float)(out[id] + __ldg(&input[id]) + __ldg(&bias[col_id])); - sum += local_out[i]; + int col_id = i * blockDim.x + tid; + int id = bid * n / 2 + col_id; + local_out_half2[i] = add(out_ptr[id], ldg(&input_ptr[id]), ldg(&bias_ptr[col_id])); + sum = add(sum, local_out_half2[i]); } - mean = blockReduceSum(sum); - if (tid == 0) { + mean = blockReduceSum((float)(sum.x + sum.y)); + if (threadIdx.x == 0) { s_mean = mean / n; } __syncthreads(); - float var = 0.0f; + float var = 0.0f; + T2 s_mean_2 = float2type2(s_mean); #pragma unroll for (int i = 0; i < ite; i++) { - float diff = local_out[i] - s_mean; - var += diff * diff; + local_out_half2[i] = hsub2(local_out_half2[i], s_mean_2); + float v1 = (float)local_out_half2[i].x; + float v2 = (float)local_out_half2[i].y; + var += v1 * v1 + v2 * v2; } variance = blockReduceSum(var); if (tid == 0) { - s_variance = rsqrtf(variance / n + 1e-6f); + s_variance = rsqrtf(variance / n + layernorm_eps); } __syncthreads(); + T2 s_var_2 = float2type2(s_variance); #pragma unroll for (int i = 0; i < ite; i++) { - int col_id = i * blockDim.x + tid; - int id = bid * n + col_id; - out[id] = - (T)((local_out[i] - s_mean) * s_variance * (float)__ldg(&gamma[col_id]) + (float)__ldg(&beta[col_id])); + int col_id = i * blockDim.x + tid; + int id = bid * n / 2 + col_id; + out_ptr[id] = fma(local_out_half2[i], s_var_2, ldg(&gamma_ptr[col_id]), ldg(&beta_ptr[col_id])); } } template<> -__global__ void addBiasResidualPostLayerNormV2(half* out, - const half* __restrict input, - const half* __restrict bias, - const half* __restrict gamma, - const half* __restrict beta, - int n) +__global__ void addBiasResidualPostLayerNormV2(float* out, + const float* __restrict input, + const float* __restrict bias, + const float* __restrict gamma, + const float* __restrict beta, + const float layernorm_eps, + int n) { const int ite = 4; const int tid = threadIdx.x; const int bid = blockIdx.x; + __shared__ float s_mean; __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; - half2 local_out_half2[ite]; + float mean = 0.0f; + float variance = 0.0f; + float local_out[ite]; - half2* out_ptr = (half2*)out; - const half2* input_ptr = (const half2*)input; - const half2* bias_ptr = (const half2*)bias; - const half2* gamma_ptr = (const half2*)gamma; - const half2* beta_ptr = (const half2*)beta; - - // float sum = 0.0f; - half2 sum = __float2half2_rn(0.0f); + float sum = 0.0f; #pragma unroll for (int i = 0; i < ite; i++) { - int col_id = i * blockDim.x + tid; - int id = bid * n / 2 + col_id; - local_out_half2[i] = out_ptr[id] + __ldg(&input_ptr[id]) + __ldg(&bias_ptr[col_id]); - sum += local_out_half2[i]; + int col_id = i * blockDim.x + tid; + int id = bid * n + col_id; + local_out[i] = (float)(out[id] + 
__ldg(&input[id]) + __ldg(&bias[col_id])); + sum += local_out[i]; } - mean = blockReduceSum((float)(sum.x + sum.y)); - if (threadIdx.x == 0) { + mean = blockReduceSum(sum); + if (tid == 0) { s_mean = mean / n; } __syncthreads(); float var = 0.0f; - half2 s_mean_2 = __float2half2_rn(s_mean); #pragma unroll for (int i = 0; i < ite; i++) { - local_out_half2[i] = local_out_half2[i] - s_mean_2; - float v1 = (float)local_out_half2[i].x; - float v2 = (float)local_out_half2[i].y; - var += v1 * v1 + v2 * v2; + float diff = local_out[i] - s_mean; + var += diff * diff; } variance = blockReduceSum(var); - if (threadIdx.x == 0) { - s_variance = rsqrtf(variance / n + 1e-6f); + if (tid == 0) { + s_variance = rsqrtf(variance / n + layernorm_eps); } __syncthreads(); - half2 s_var_2 = __float2half2_rn(s_variance); #pragma unroll for (int i = 0; i < ite; i++) { int col_id = i * blockDim.x + tid; - int id = bid * n / 2 + col_id; - out_ptr[id] = local_out_half2[i] * s_var_2 * __ldg(&gamma_ptr[col_id]) + __ldg(&beta_ptr[col_id]); + int id = bid * n + col_id; + out[id] = + (float)((local_out[i] - s_mean) * s_variance * (float)__ldg(&gamma[col_id]) + (float)__ldg(&beta[col_id])); } } +// bf16 and half data type template -void invokeAddBiasResidualLayerNorm( - T* out, const T* input, const T* bias, const T* gamma, const T* beta, int m, int n, cudaStream_t stream) +void invokeAddBiasResidualLayerNorm(T* out, + const T* input, + const T* bias, + const T* gamma, + const T* beta, + const float layernorm_eps, + int m, + int n, + cudaStream_t stream) { dim3 grid(m); dim3 block(std::min(n, 1024)); - if (n == 768 || n == 1024) { - addBiasResidualPostLayerNormV2<<>>(out, input, bias, gamma, beta, n); + + if (m >= 512 && (n == 768 || n == 1024)) { + addBiasResidualPostLayerNormV2<<>>(out, input, bias, gamma, beta, layernorm_eps, n); } else { - block.x = std::min(n, 1024); + block.x = std::min(n, 1024); int num_trips = (n + block.x - 1) / block.x; if (num_trips == 1) { - addBiasResidualPostLayerNorm<<>>(out, input, bias, gamma, beta, m, n); + addBiasResidualPostLayerNorm + <<>>(out, input, bias, gamma, beta, layernorm_eps, m, n); } else if (num_trips == 2) { - addBiasResidualPostLayerNorm<<>>(out, input, bias, gamma, beta, m, n); + addBiasResidualPostLayerNorm + <<>>(out, input, bias, gamma, beta, layernorm_eps, m, n); } else { - generalAddBiasResidualPostLayerNorm<<>>(out, input, bias, gamma, beta, m, n); + generalAddBiasResidualPostLayerNorm + <<>>(out, input, bias, gamma, beta, layernorm_eps, m, n); } } } template<> -void invokeAddBiasResidualLayerNorm(half* out, - const half* input, - const half* bias, - const half* gamma, - const half* beta, - int m, - int n, +void invokeAddBiasResidualLayerNorm(float* out, + const float* input, + const float* bias, + const float* gamma, + const float* beta, + const float layernorm_eps, + int m, + int n, cudaStream_t stream) { dim3 grid(m); dim3 block(std::min(n, 1024)); - - if (m >= 512 && (n == 768 || n == 1024)) { - addBiasResidualPostLayerNormV2<<>>(out, input, bias, gamma, beta, n); + if (n == 768 || n == 1024) { + addBiasResidualPostLayerNormV2 + <<>>(out, input, bias, gamma, beta, layernorm_eps, n); } else { - block.x = std::min(n, 1024); + block.x = std::min(n, 1024); int num_trips = (n + block.x - 1) / block.x; if (num_trips == 1) { - addBiasResidualPostLayerNorm<<>>(out, input, bias, gamma, beta, m, n); + addBiasResidualPostLayerNorm + <<>>(out, input, bias, gamma, beta, layernorm_eps, m, n); } else if (num_trips == 2) { - addBiasResidualPostLayerNorm<<>>(out, input, bias, 
gamma, beta, m, n); + addBiasResidualPostLayerNorm + <<>>(out, input, bias, gamma, beta, layernorm_eps, m, n); } else { - generalAddBiasResidualPostLayerNorm<<>>(out, input, bias, gamma, beta, m, n); + generalAddBiasResidualPostLayerNorm + <<>>(out, input, bias, gamma, beta, layernorm_eps, m, n); } } } -template void invokeAddBiasResidualLayerNorm(float* out, +template void invokeAddBiasResidualLayerNorm(float* out, const float* input, const float* bias, const float* gamma, const float* beta, - int m, - int n, + const float layernorm_eps, + int m, + int n, cudaStream_t stream); -template void invokeAddBiasResidualLayerNorm(half* out, - const half* input, - const half* bias, - const half* gamma, - const half* beta, - int m, - int n, +template void invokeAddBiasResidualLayerNorm(half* out, + const half* input, + const half* bias, + const half* gamma, + const half* beta, + const float layernorm_eps, + int m, + int n, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeAddBiasResidualLayerNorm(__nv_bfloat16* out, + const __nv_bfloat16* input, + const __nv_bfloat16* bias, + const __nv_bfloat16* gamma, + const __nv_bfloat16* beta, + const float layernorm_eps, + int m, + int n, + cudaStream_t stream); +#endif -template -__global__ void generalAddBiasResidualLayerNorm(const T* __restrict input, +template +__global__ void generalAddBiasResidualLayerNorm(const T* __restrict residual1, + const T* __restrict residual2, const T* __restrict gamma, const T* __restrict beta, const T* __restrict bias, - T* output, - T* norm_output, - int m, - int n) + T* output, + T* norm_output, + const float layernorm_eps, + int m, + int n) { int tid = threadIdx.x; __shared__ float s_mean; __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; + float mean = 0.0f; + float variance = 0.0f; float local_sum = 0.0f; for (int i = tid; i < n; i += blockDim.x) { - float local_out = (float)(ldg(&input[blockIdx.x * n + i])); + float local_out = 0.0f; + if (RESIDUAL_NUM == 1) { + local_out = (float)(ldg(&residual1[blockIdx.x * n + i])); + } + else if (RESIDUAL_NUM == 2) { + local_out = (float)(ldg(&residual1[blockIdx.x * n + i])) + float(ldg(&residual2[blockIdx.x * n + i])); + } local_out += (float)(output[blockIdx.x * n + i]); if (bias != nullptr) { local_out += (float)(ldg(&bias[i])); @@ -601,7 +669,7 @@ __global__ void generalAddBiasResidualLayerNorm(const T* __restrict input, variance = blockReduceSum(local_var_sum); if (threadIdx.x == 0) { - s_variance = rsqrtf(variance / n + 1e-6f); + s_variance = rsqrtf(variance / n + layernorm_eps); } __syncthreads(); @@ -612,83 +680,108 @@ __global__ void generalAddBiasResidualLayerNorm(const T* __restrict input, } } -#define HALF_LAYERNORM_OPT(UNROLL_FACTOR) \ - generalAddBiasResidualLayerNormOpt \ +#define HALF_LAYERNORM_OPT(UNROLL_FACTOR, RESIDUAL_NUM) \ + generalAddBiasResidualLayerNormOpt \ <<>>((T2*)norm_output, \ (T2*)output, \ (const T2*)bias, \ - (const T2*)input, \ + (const T2*)residual1, \ + (const T2*)residual2, \ (const T2*)gamma, \ (const T2*)beta, \ + layernorm_eps, \ m, \ half_n); -#define HALF_LAYERNORM_OPT2(UNROLL_FACTOR) \ - generalAddBiasResidualLayerNormOpt2 \ +#define HALF_LAYERNORM_OPT2(UNROLL_FACTOR, RESIDUAL_NUM) \ + generalAddBiasResidualLayerNormOpt2 \ <<>>((T2*)norm_output, \ (T2*)output, \ (const T2*)bias, \ - (const T2*)input, \ + (const T2*)residual1, \ + (const T2*)residual2, \ (const T2*)gamma, \ (const T2*)beta, \ + layernorm_eps, \ m, \ half_n); +#define HALF_LAYERNORM_OPT_RESIDUAL(UNROLL_FACTOR, OPT_TYPE) \ + if 
(residual_num == 1) { \ + if (OPT_TYPE == 1) { \ + HALF_LAYERNORM_OPT(UNROLL_FACTOR, 1); \ + } \ + else if (OPT_TYPE == 2) { \ + HALF_LAYERNORM_OPT2(UNROLL_FACTOR, 1); \ + } \ + } \ + else if (residual_num == 2) { \ + if (OPT_TYPE == 1) { \ + HALF_LAYERNORM_OPT(UNROLL_FACTOR, 2); \ + } \ + else if (OPT_TYPE == 2) { \ + HALF_LAYERNORM_OPT2(UNROLL_FACTOR, 2); \ + } \ + } + template -void invokeGeneralAddBiasResidualPreLayerNorm(T* output, - T* norm_output, - const T* input, - const T* gamma, - const T* beta, - const T* bias, - int m, - int n, +void invokeGeneralAddBiasResidualPreLayerNorm(T* output, + T* norm_output, + const T* residual1, + const T* residual2, + const T* gamma, + const T* beta, + const T* bias, + const float layernorm_eps, + int m, + int n, cudaStream_t stream, - int opt_version) + int opt_version) { + const int residual_num = residual2 == nullptr ? 1 : 2; if (opt_version > 0 && sizeof(T) == 2 && n % 2 == 0) { dim3 grid(m); - int half_n = n / 2; - int half_n_32 = (half_n + 31) / 32 * 32; + int half_n = n / 2; + int half_n_32 = (half_n + 31) / 32 * 32; dim3 block(min(half_n_32, 512)); - int rolls_per_thread = half_n / block.x; - int unroll_factor = 8; + int rolls_per_thread = half_n / block.x; + int unroll_factor = 8; while (unroll_factor > rolls_per_thread && unroll_factor > 1) { unroll_factor /= 2; } using T2 = typename TypeConverter::Type; if (opt_version == 1) { if (unroll_factor == 1) { - HALF_LAYERNORM_OPT(1); + HALF_LAYERNORM_OPT_RESIDUAL(1, 1); } else if (unroll_factor == 2) { - HALF_LAYERNORM_OPT(2); + HALF_LAYERNORM_OPT_RESIDUAL(2, 1); } else if (unroll_factor == 3) { - HALF_LAYERNORM_OPT(3); + HALF_LAYERNORM_OPT_RESIDUAL(3, 1); } else if (unroll_factor == 4) { - HALF_LAYERNORM_OPT(4); + HALF_LAYERNORM_OPT_RESIDUAL(4, 1); } else if (unroll_factor == 8) { - HALF_LAYERNORM_OPT(8); + HALF_LAYERNORM_OPT_RESIDUAL(8, 1); } } else { if (unroll_factor == 1) { - HALF_LAYERNORM_OPT2(1); + HALF_LAYERNORM_OPT_RESIDUAL(1, 2); } else if (unroll_factor == 2) { - HALF_LAYERNORM_OPT2(2); + HALF_LAYERNORM_OPT_RESIDUAL(2, 2); } else if (unroll_factor == 3) { - HALF_LAYERNORM_OPT2(3); + HALF_LAYERNORM_OPT_RESIDUAL(3, 2); } else if (unroll_factor == 4) { - HALF_LAYERNORM_OPT2(4); + HALF_LAYERNORM_OPT_RESIDUAL(4, 2); } else if (unroll_factor == 8) { - HALF_LAYERNORM_OPT2(8); + HALF_LAYERNORM_OPT_RESIDUAL(8, 2); } } } @@ -708,61 +801,78 @@ void invokeGeneralAddBiasResidualPreLayerNorm(T* output, block.x = block.x / (4 / sizeof(T)); // if using half, only need half of block.x /* should pay attention to the rsqrt precision*/ - generalAddBiasResidualLayerNorm - <<>>(input, gamma, beta, bias, output, norm_output, m, n); // For gpt-3 + if (residual_num == 1) { + generalAddBiasResidualLayerNorm<<>>( + residual1, residual2, gamma, beta, bias, output, norm_output, layernorm_eps, m, n); // For gpt-3 + } + else if (residual_num == 2) { + generalAddBiasResidualLayerNorm<<>>( + residual1, residual2, gamma, beta, bias, output, norm_output, layernorm_eps, m, n); // For gpt-3 + } } } #undef HALF_LAYERNORM_OPT #undef HALF_LAYERNORM_OPT2 -template void invokeGeneralAddBiasResidualPreLayerNorm(float* output, - float* norm_output, - const float* input, +template void invokeGeneralAddBiasResidualPreLayerNorm(float* output, + float* norm_output, + const float* residual1, + const float* residual2, const float* gamma, const float* beta, const float* bias, - int m, - int n, + const float layernorm_eps, + int m, + int n, cudaStream_t stream, - int opt_version); - -template void 
invokeGeneralAddBiasResidualPreLayerNorm(half* output, - half* norm_output, - const half* input, - const half* gamma, - const half* beta, - const half* bias, - int m, - int n, + int opt_version); + +template void invokeGeneralAddBiasResidualPreLayerNorm(half* output, + half* norm_output, + const half* residual1, + const half* residual2, + const half* gamma, + const half* beta, + const half* bias, + const float layernorm_eps, + int m, + int n, cudaStream_t stream, - int opt_version); + int opt_version); #ifdef ENABLE_BF16 -template void invokeGeneralAddBiasResidualPreLayerNorm(__nv_bfloat16* output, - __nv_bfloat16* norm_output, - const __nv_bfloat16* input, +template void invokeGeneralAddBiasResidualPreLayerNorm(__nv_bfloat16* output, + __nv_bfloat16* norm_output, + const __nv_bfloat16* residual1, + const __nv_bfloat16* residual2, const __nv_bfloat16* gamma, const __nv_bfloat16* beta, const __nv_bfloat16* bias, - int m, - int n, - cudaStream_t stream, - int opt_version); + const float layernorm_eps, + int m, + int n, + cudaStream_t stream, + int opt_version); #endif template -__global__ void generalAddResidualT5LayerNorm( - const T* __restrict input, const T* __restrict gamma, T* output, T* norm_output, int m, int n) +__global__ void generalAddResidualT5LayerNorm(const T* __restrict input, + const T* __restrict gamma, + T* output, + T* norm_output, + const float layernorm_eps, + int m, + int n) { // layernorm module in the T5 style No bias and no subtraction of mean. __shared__ float s_variance; - float variance = 0.0f; + float variance = 0.0f; float local_var_sum = 0.0f; for (int i = threadIdx.x; i < n; i += blockDim.x) { output[blockIdx.x * n + i] = - clamp_inf_for_half((float)__ldg(&input[blockIdx.x * n + i]) + (float)output[blockIdx.x * n + i]); + clamp_inf_for_half((float)ldg(&input[blockIdx.x * n + i]) + (float)output[blockIdx.x * n + i]); float diff = (float)(output[blockIdx.x * n + i]); local_var_sum += diff * diff; @@ -770,20 +880,25 @@ __global__ void generalAddResidualT5LayerNorm( variance = blockReduceSum(local_var_sum); if (threadIdx.x == 0) { - s_variance = rsqrtf(variance / n + 1e-6f); + s_variance = rsqrtf(variance / (float)n + layernorm_eps); } __syncthreads(); for (int i = threadIdx.x; i < n; i += blockDim.x) { - float out_val = (((float)output[blockIdx.x * n + i]) * s_variance) * (float)(__ldg(&gamma[i])); norm_output[blockIdx.x * n + i] = - clamp_inf_for_half((((float)output[blockIdx.x * n + i]) * s_variance) * (float)(__ldg(&gamma[i]))); + clamp_inf_for_half((((float)output[blockIdx.x * n + i]) * s_variance) * (float)(ldg(&gamma[i]))); } } template -void invokeGeneralAddResidualT5PreLayerNorm( - T* output, T* norm_output, const T* input, const T* gamma, int m, int n, cudaStream_t stream) +void invokeGeneralAddResidualT5PreLayerNorm(T* output, + T* norm_output, + const T* input, + const T* gamma, + const float layernorm_eps, + int m, + int n, + cudaStream_t stream) { dim3 grid(m); dim3 block(min(n, 1024)); @@ -796,68 +911,101 @@ void invokeGeneralAddResidualT5PreLayerNorm( block.x = 1024; } - block.x = block.x / (4 / sizeof(T)); // if using half, only need half of block.x - + // TODO(bhsueh) add 16bitx2 implementation /* should pay attention to the rsqrt precision*/ - generalAddResidualT5LayerNorm<<>>(input, gamma, output, norm_output, m, n); + generalAddResidualT5LayerNorm + <<>>(input, gamma, output, norm_output, layernorm_eps, m, n); } -template void invokeGeneralAddResidualT5PreLayerNorm( - float* output, float* norm_output, const float* input, const float* gamma, 
int m, int n, cudaStream_t stream); - -template void invokeGeneralAddResidualT5PreLayerNorm( - half* output, half* norm_output, const half* input, const half* gamma, int m, int n, cudaStream_t stream); +template void invokeGeneralAddResidualT5PreLayerNorm(float* output, + float* norm_output, + const float* input, + const float* gamma, + const float layernorm_eps, + int m, + int n, + cudaStream_t stream); + +template void invokeGeneralAddResidualT5PreLayerNorm(half* output, + half* norm_output, + const half* input, + const half* gamma, + const float layernorm_eps, + int m, + int n, + cudaStream_t stream); template -void invokeGeneralAddBiasResidualT5PreLayerNorm(T* output, - T* norm_output, - const T* input, - const T* gamma, - const T* beta, - const T* bias, - int m, - int n, +void invokeGeneralAddBiasResidualT5PreLayerNorm(T* output, + T* norm_output, + const T* input, + const T* gamma, + const T* beta, + const T* bias, + const float layernorm_eps, + int m, + int n, cudaStream_t stream) { if (beta != nullptr && bias != nullptr) { - invokeGeneralAddBiasResidualPreLayerNorm(output, norm_output, input, gamma, beta, bias, m, n, stream); + invokeGeneralAddBiasResidualPreLayerNorm( + output, norm_output, input, (const T*)nullptr, gamma, beta, bias, layernorm_eps, m, n, stream); } else { - invokeGeneralAddResidualT5PreLayerNorm(output, norm_output, input, gamma, m, n, stream); + invokeGeneralAddResidualT5PreLayerNorm(output, norm_output, input, gamma, layernorm_eps, m, n, stream); } return; } -template void invokeGeneralAddBiasResidualT5PreLayerNorm(float* output, - float* norm_output, +template void invokeGeneralAddBiasResidualT5PreLayerNorm(float* output, + float* norm_output, const float* input, const float* gamma, const float* beta, const float* bias, - int m, - int n, + const float layernorm_eps, + int m, + int n, cudaStream_t stream); -template void invokeGeneralAddBiasResidualT5PreLayerNorm(half* output, - half* norm_output, - const half* input, - const half* gamma, - const half* beta, - const half* bias, - int m, - int n, +template void invokeGeneralAddBiasResidualT5PreLayerNorm(half* output, + half* norm_output, + const half* input, + const half* gamma, + const half* beta, + const half* bias, + const float layernorm_eps, + int m, + int n, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeGeneralAddBiasResidualT5PreLayerNorm(__nv_bfloat16* output, + __nv_bfloat16* norm_output, + const __nv_bfloat16* input, + const __nv_bfloat16* gamma, + const __nv_bfloat16* beta, + const __nv_bfloat16* bias, + const float layernorm_eps, + int m, + int n, + cudaStream_t stream); +#endif template -__global__ void generalLayerNorm( - const T* __restrict input, const T* __restrict gamma, const T* __restrict beta, T* output, int m, int n) +__global__ void generalLayerNorm(const T* __restrict input, + const T* __restrict gamma, + const T* __restrict beta, + T* output, + const float layernorm_eps, + int m, + int n) { const int tid = threadIdx.x; __shared__ float s_mean; __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; + float mean = 0.0f; + float variance = 0.0f; float local_sum = 0.0f; for (int i = tid; i < n; i += blockDim.x) { @@ -879,7 +1027,7 @@ __global__ void generalLayerNorm( variance = blockReduceSum(local_var_sum); if (threadIdx.x == 0) { - s_variance = rsqrtf(variance / n + 1e-6f); + s_variance = rsqrtf(variance / n + layernorm_eps); } __syncthreads(); @@ -891,30 +1039,49 @@ __global__ void generalLayerNorm( } #define HALF_LAYERNORM_OPT(UNROLL_FACTOR) \ - 
generalAddBiasResidualLayerNormOpt<<>>( \ - (T2*)out, (T2*)out, nullptr, (const T2*)input, (const T2*)gamma, (const T2*)beta, m, half_n); + generalAddBiasResidualLayerNormOpt \ + <<>>((T2*)out, \ + (T2*)out, \ + nullptr, \ + (const T2*)input, \ + nullptr, \ + (const T2*)gamma, \ + (const T2*)beta, \ + layernorm_eps, \ + m, \ + half_n); #define HALF_LAYERNORM_OPT2(UNROLL_FACTOR) \ - generalAddBiasResidualLayerNormOpt2<<>>( \ - (T2*)out, (T2*)out, nullptr, (const T2*)input, (const T2*)gamma, (const T2*)beta, m, half_n); + generalAddBiasResidualLayerNormOpt2 \ + <<>>((T2*)out, \ + (T2*)out, \ + nullptr, \ + (const T2*)input, \ + nullptr, \ + (const T2*)gamma, \ + (const T2*)beta, \ + layernorm_eps, \ + m, \ + half_n); template -void invokeGeneralLayerNorm(T* out, - const T* input, - const T* gamma, - const T* beta, - const int m, - const int n, +void invokeGeneralLayerNorm(T* out, + const T* input, + const T* gamma, + const T* beta, + const float layernorm_eps, + const int m, + const int n, cudaStream_t stream, - int opt_version) + int opt_version) { dim3 grid(m); if (n % 2 == 0 && std::is_same::value && opt_version > 0) { - int half_n = n / 2; - int half_n_32 = (half_n + 31) / 32 * 32; + int half_n = n / 2; + int half_n_32 = (half_n + 31) / 32 * 32; dim3 block(min(half_n_32, 512)); - int rolls_per_thread = half_n / block.x; - int unroll_factor = 8; + int rolls_per_thread = half_n / block.x; + int unroll_factor = 8; while (unroll_factor > rolls_per_thread && unroll_factor > 1) { unroll_factor /= 2; } @@ -965,73 +1132,83 @@ void invokeGeneralLayerNorm(T* out, } /* should pay attention to the rsqrt precision*/ - generalLayerNorm<<>>(input, gamma, beta, out, m, n); // For gpt-3 + generalLayerNorm<<>>(input, gamma, beta, out, layernorm_eps, m, n); // For gpt-3 } } #undef HALF_LAYERNORM_OPT #undef HALF_LAYERNORM_OPT2 -template void invokeGeneralLayerNorm(float* out, +template void invokeGeneralLayerNorm(float* out, const float* input, const float* gamma, const float* beta, - const int m, - const int n, + const float layernorm_eps, + const int m, + const int n, cudaStream_t stream, - int opt_version); -template void invokeGeneralLayerNorm(half* out, - const half* input, - const half* gamma, - const half* beta, - const int m, - const int n, + int opt_version); +template void invokeGeneralLayerNorm(half* out, + const half* input, + const half* gamma, + const half* beta, + const float layernorm_eps, + const int m, + const int n, cudaStream_t stream, - int opt_version); + int opt_version); #ifdef ENABLE_BF16 -template void invokeGeneralLayerNorm(__nv_bfloat16* out, +template void invokeGeneralLayerNorm(__nv_bfloat16* out, const __nv_bfloat16* input, const __nv_bfloat16* gamma, const __nv_bfloat16* beta, - const int m, - const int n, - cudaStream_t stream, - int opt_version); + const float layernorm_eps, + const int m, + const int n, + cudaStream_t stream, + int opt_version); #endif template -__global__ void generalT5LayerNorm(const T* __restrict input, const T* __restrict gamma, T* output, int m, int n) +__global__ void generalT5LayerNorm( + const T* __restrict input, const T* __restrict gamma, T* output, const float layernorm_eps, int m, int n) { // layernorm module in the T5 style No bias and no subtraction of mean. 
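    // Reference formula (illustrative, mirroring the block reduction below): for a row x of
    // length n this kernel computes
    //     out[i] = x[i] * gamma[i] * rsqrtf(sum_j(x[j] * x[j]) / n + layernorm_eps)
    // i.e. an RMSNorm: no mean subtraction and no beta/bias term, as in the T5 reference.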
const int tid = threadIdx.x; __shared__ float s_variance; - float variance = 0.0f; + float variance = 0.0f; float local_var_sum = 0.0f; for (int i = tid; i < n; i += blockDim.x) { - float diff = (float)(__ldg(&input[blockIdx.x * n + i])); + float diff = (float)(ldg(&input[blockIdx.x * n + i])); local_var_sum += diff * diff; } variance = blockReduceSum(local_var_sum); if (threadIdx.x == 0) { - s_variance = rsqrtf(variance / n + 1e-6f); + s_variance = rsqrtf(variance / (float)n + layernorm_eps); } __syncthreads(); for (int i = tid; i < n; i += blockDim.x) { output[blockIdx.x * n + i] = - clamp_inf_for_half((((float)input[blockIdx.x * n + i]) * s_variance) * (float)(__ldg(&gamma[i]))); + clamp_inf_for_half((((float)input[blockIdx.x * n + i]) * s_variance) * (float)(ldg(&gamma[i]))); } } template -void invokeGeneralT5LayerNorm( - T* out, const T* input, const T* gamma, const T* beta, const int m, const int n, cudaStream_t stream) +void invokeGeneralT5LayerNorm(T* out, + const T* input, + const T* gamma, + const T* beta, + const float layernorm_eps, + const int m, + const int n, + cudaStream_t stream) { if (beta != nullptr) { - invokeGeneralLayerNorm(out, input, gamma, beta, m, n, stream); + invokeGeneralLayerNorm(out, input, gamma, beta, layernorm_eps, m, n, stream); return; } @@ -1048,393 +1225,441 @@ void invokeGeneralT5LayerNorm( block.x = block.x / (4 / sizeof(T)); // if using half, only need half of block.x /* should pay attention to the rsqrt precision*/ - generalT5LayerNorm<<>>(input, gamma, out, m, n); // For gpt-3 + generalT5LayerNorm<<>>(input, gamma, out, layernorm_eps, m, n); // For gpt-3 } -template void invokeGeneralT5LayerNorm(float* out, +template void invokeGeneralT5LayerNorm(float* out, const float* input, const float* gamma, const float* beta, - const int m, - const int n, + const float layernorm_eps, + const int m, + const int n, cudaStream_t stream); -template void invokeGeneralT5LayerNorm( - half* out, const half* input, const half* gamma, const half* beta, const int m, const int n, cudaStream_t stream); +template void invokeGeneralT5LayerNorm(half* out, + const half* input, + const half* gamma, + const half* beta, + const float layernorm_eps, + const int m, + const int n, + cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeGeneralT5LayerNorm(__nv_bfloat16* out, + const __nv_bfloat16* input, + const __nv_bfloat16* gamma, + const __nv_bfloat16* beta, + const float layernorm_eps, + const int m, + const int n, + cudaStream_t stream); +#endif /******************* invokeLayernormShiftPartition ***********************/ -template -__global__ void layernorm_shift_partition(T* out, - const T* input, - const T* gamma, - const T* beta, - int batch, - int H, - int W, - int n, - int shift_size, - int window_size) +// applied to half2 and bfloat162 +template +__global__ void layernorm_shift_partition(T2* out_ptr, + const T2* input_ptr, + const T2* gamma_ptr, + const T2* beta_ptr, + const float layernorm_eps, + int batch, + int H, + int W, + int n, + int shift_size, + int window_size) { - int tid = threadIdx.x; - const int batch_offset = blockIdx.z * gridDim.y * gridDim.x; - const int bid = batch_offset + blockIdx.y * gridDim.x + blockIdx.x; - const int shifted_H_idx = (shift_size != 0) ? ((blockIdx.y - shift_size + gridDim.y) % gridDim.y) : blockIdx.y; - const int shifted_W_idx = (shift_size != 0) ? 
((blockIdx.x - shift_size + gridDim.x) % gridDim.x) : blockIdx.x; - const int window_H_idx = shifted_H_idx / window_size; - const int window_W_idx = shifted_W_idx / window_size; + const int batch_offset = blockIdx.z * gridDim.y * gridDim.x; + const int bid = batch_offset + blockIdx.y * gridDim.x + blockIdx.x; + const int shifted_H_idx = (shift_size != 0) ? ((blockIdx.y - shift_size + gridDim.y) % gridDim.y) : blockIdx.y; + const int shifted_W_idx = (shift_size != 0) ? ((blockIdx.x - shift_size + gridDim.x) % gridDim.x) : blockIdx.x; + const int window_H_idx = shifted_H_idx / window_size; + const int window_W_idx = shifted_W_idx / window_size; const int stride_of_window_H = W / window_size; - const int window_idx = window_H_idx * stride_of_window_H + window_W_idx; - const int idx_in_window = (shifted_H_idx % window_size) * window_size + (shifted_W_idx % window_size); - const int output_bid = batch_offset + window_idx * window_size * window_size + idx_in_window; + const int window_idx = window_H_idx * stride_of_window_H + window_W_idx; + const int idx_in_window = (shifted_H_idx % window_size) * window_size + (shifted_W_idx % window_size); + const int output_bid = batch_offset + window_idx * window_size * window_size + idx_in_window; + int tid = threadIdx.x; __shared__ float s_mean; __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; + float mean = 0.0f; + float variance = 0.0f; + float2 local_out_fp2; - float local_out = (tid < n) ? (float)(__ldg(input + bid * n + tid)) : 0.0f; + float local_out = 0.0f; + int id = bid * n + tid; + if (tid < n) { + local_out_fp2 = type22float2(ldg(input_ptr + id)); + local_out += local_out_fp2.x; + local_out += local_out_fp2.y; + } mean = blockReduceSum(local_out); - if (threadIdx.x == 0) { - s_mean = mean / n; - } + if (threadIdx.x == 0) + s_mean = mean / (n * 2); __syncthreads(); - float diff = (tid < n) ? (local_out - s_mean) : 0.0f; - variance = blockReduceSum(diff * diff); + if (tid < n) { + variance = (local_out_fp2.x - s_mean) * (local_out_fp2.x - s_mean); + variance += (local_out_fp2.y - s_mean) * (local_out_fp2.y - s_mean); + } + variance = blockReduceSum(variance); if (threadIdx.x == 0) { - s_variance = variance / n + 1e-6f; + s_variance = rsqrtf(variance / (n * 2) + layernorm_eps); } __syncthreads(); if (tid < n) { - out[output_bid * n + tid] = - (T)(((local_out - s_mean) * rsqrtf(s_variance)) * (float)(__ldg(&gamma[tid])) + (float)(__ldg(&beta[tid]))); + float2 gamma_val = type22float2(ldg(&gamma_ptr[tid])); + float2 beta_val = type22float2(ldg(&beta_ptr[tid])); + local_out_fp2.x = (local_out_fp2.x - s_mean) * s_variance * gamma_val.x + beta_val.x; + local_out_fp2.y = (local_out_fp2.y - s_mean) * s_variance * gamma_val.y + beta_val.y; + out_ptr[output_bid * n + tid] = float22type2(local_out_fp2); } } +// applied to float template<> -__global__ void layernorm_shift_partition(half2* out_ptr, - const half2* input_ptr, - const half2* gamma_ptr, - const half2* beta_ptr, - int batch, - int H, - int W, - int n, - int shift_size, - int window_size) +__global__ void layernorm_shift_partition(float* out, + const float* input, + const float* gamma, + const float* beta, + const float layernorm_eps, + int batch, + int H, + int W, + int n, + int shift_size, + int window_size) { - const int batch_offset = blockIdx.z * gridDim.y * gridDim.x; - const int bid = batch_offset + blockIdx.y * gridDim.x + blockIdx.x; - const int shifted_H_idx = (shift_size != 0) ? 
((blockIdx.y - shift_size + gridDim.y) % gridDim.y) : blockIdx.y; - const int shifted_W_idx = (shift_size != 0) ? ((blockIdx.x - shift_size + gridDim.x) % gridDim.x) : blockIdx.x; - const int window_H_idx = shifted_H_idx / window_size; - const int window_W_idx = shifted_W_idx / window_size; + int tid = threadIdx.x; + const int batch_offset = blockIdx.z * gridDim.y * gridDim.x; + const int bid = batch_offset + blockIdx.y * gridDim.x + blockIdx.x; + const int shifted_H_idx = (shift_size != 0) ? ((blockIdx.y - shift_size + gridDim.y) % gridDim.y) : blockIdx.y; + const int shifted_W_idx = (shift_size != 0) ? ((blockIdx.x - shift_size + gridDim.x) % gridDim.x) : blockIdx.x; + const int window_H_idx = shifted_H_idx / window_size; + const int window_W_idx = shifted_W_idx / window_size; const int stride_of_window_H = W / window_size; - const int window_idx = window_H_idx * stride_of_window_H + window_W_idx; - const int idx_in_window = (shifted_H_idx % window_size) * window_size + (shifted_W_idx % window_size); - const int output_bid = batch_offset + window_idx * window_size * window_size + idx_in_window; - int tid = threadIdx.x; + const int window_idx = window_H_idx * stride_of_window_H + window_W_idx; + const int idx_in_window = (shifted_H_idx % window_size) * window_size + (shifted_W_idx % window_size); + const int output_bid = batch_offset + window_idx * window_size * window_size + idx_in_window; __shared__ float s_mean; __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; - float2 local_out_fp2; + float mean = 0.0f; + float variance = 0.0f; - float local_out = 0.0f; - int id = bid * n + tid; - if (tid < n) { - local_out_fp2 = __half22float2(__ldg(input_ptr + id)); - local_out += local_out_fp2.x; - local_out += local_out_fp2.y; - } + float local_out = (tid < n) ? (float)(__ldg(input + bid * n + tid)) : 0.0f; mean = blockReduceSum(local_out); if (threadIdx.x == 0) { - s_mean = mean / (n * 2); + s_mean = mean / n; } __syncthreads(); - if (tid < n) { - variance = (local_out_fp2.x - s_mean) * (local_out_fp2.x - s_mean); - variance += (local_out_fp2.y - s_mean) * (local_out_fp2.y - s_mean); - } - variance = blockReduceSum(variance); + float diff = (tid < n) ? 
(local_out - s_mean) : 0.0f; + variance = blockReduceSum(diff * diff); if (threadIdx.x == 0) { - s_variance = rsqrtf(variance / (n * 2) + 1e-6f); + s_variance = rsqrtf(variance / n + layernorm_eps); } __syncthreads(); if (tid < n) { - float2 gamma_val = __half22float2(__ldg(&gamma_ptr[tid])); - float2 beta_val = __half22float2(__ldg(&beta_ptr[tid])); - local_out_fp2.x = (local_out_fp2.x - s_mean) * s_variance * gamma_val.x + beta_val.x; - local_out_fp2.y = (local_out_fp2.y - s_mean) * s_variance * gamma_val.y + beta_val.y; - out_ptr[output_bid * n + tid] = __float22half2_rn(local_out_fp2); + out[output_bid * n + tid] = + (float)(((local_out - s_mean) * s_variance) * (float)(__ldg(&gamma[tid])) + (float)(__ldg(&beta[tid]))); } } -template -__global__ void layernorm_shift_partition_v2(T* out, - const T* __restrict input, - const T* __restrict gamma, - const T* __restrict beta, - int batch, - int H, - int W, - int n, - int shift_size, - int window_size) +// Applied to half2 and bfloat162 +template +__global__ void layernorm_shift_partition_v2(T2* out_ptr, + const T2* __restrict input_ptr, + const T2* __restrict gamma_ptr, + const T2* __restrict beta_ptr, + const float layernorm_eps, + int batch, + int H, + int W, + int n, + int shift_size, + int window_size) { - const int ite = 4; - const int tid = threadIdx.x; - const int batch_offset = blockIdx.z * gridDim.y * gridDim.x; - const int bid = batch_offset + blockIdx.y * gridDim.x + blockIdx.x; - const int shifted_H_idx = (shift_size != 0) ? ((blockIdx.y - shift_size + gridDim.y) % gridDim.y) : blockIdx.y; - const int shifted_W_idx = (shift_size != 0) ? ((blockIdx.x - shift_size + gridDim.x) % gridDim.x) : blockIdx.x; - const int window_H_idx = shifted_H_idx / window_size; - const int window_W_idx = shifted_W_idx / window_size; + using T1 = typename TypeConverter::Type; // half2 to half, bfloat162 to bfloat + const int ite = 4; + const int tid = threadIdx.x; + const int batch_offset = blockIdx.z * gridDim.y * gridDim.x; + const int bid = batch_offset + blockIdx.y * gridDim.x + blockIdx.x; + const int shifted_H_idx = (shift_size != 0) ? ((blockIdx.y - shift_size + gridDim.y) % gridDim.y) : blockIdx.y; + const int shifted_W_idx = (shift_size != 0) ? 
((blockIdx.x - shift_size + gridDim.x) % gridDim.x) : blockIdx.x; + const int window_H_idx = shifted_H_idx / window_size; + const int window_W_idx = shifted_W_idx / window_size; const int stride_of_window_H = W / window_size; - const int window_idx = window_H_idx * stride_of_window_H + window_W_idx; - const int idx_in_window = (shifted_H_idx % window_size) * window_size + (shifted_W_idx % window_size); - const int output_bid = batch_offset + window_idx * window_size * window_size + idx_in_window; - const int offset = bid * n; - const int output_offset = output_bid * n; - + const int window_idx = window_H_idx * stride_of_window_H + window_W_idx; + const int idx_in_window = (shifted_H_idx % window_size) * window_size + (shifted_W_idx % window_size); + const int output_bid = batch_offset + window_idx * window_size * window_size + idx_in_window; + const int offset = bid * n; + const int output_offset = output_bid * n; __shared__ float s_mean; __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; - float local_out[ite]; + float mean = 0.0f; + float variance = 0.0f; + T2 local_out_half2[ite]; + const T2 zero = {static_cast(0.0f), static_cast(0.0f)}; - float sum = 0.0f; + // float sum = 0.0f; + T2 sum = float2type2(0.0f); #pragma unroll for (int i = 0; i < ite; i++) { int col_id = i * blockDim.x + tid; if (col_id < n) { - local_out[i] = (float)(__ldg(input + offset + col_id)); - sum += local_out[i]; + local_out_half2[i] = ldg(input_ptr + offset + col_id); + sum = add(sum, local_out_half2[i]); } } - mean = blockReduceSum(sum); - if (tid == 0) { - s_mean = mean / n; + mean = blockReduceSum((float)(sum.x + sum.y)); + if (threadIdx.x == 0) { + s_mean = mean / (n * 2); } __syncthreads(); - float var = 0.0f; + float var = 0.0f; + T2 s_mean_2 = float2type2(s_mean); #pragma unroll for (int i = 0; i < ite; i++) { int col_id = i * blockDim.x + tid; if (col_id < n) { - float diff = local_out[i] - s_mean; - local_out[i] = diff; - var += diff * diff; + local_out_half2[i] = hsub2(local_out_half2[i], s_mean_2); + float v1 = (float)local_out_half2[i].x; + float v2 = (float)local_out_half2[i].y; + var += v1 * v1 + v2 * v2; } } variance = blockReduceSum(var); if (tid == 0) { - s_variance = rsqrtf(variance / n + 1e-6f); + s_variance = rsqrtf(variance / (n * 2) + layernorm_eps); } __syncthreads(); + T2 s_var_2 = float2type2(s_variance); #pragma unroll for (int i = 0; i < ite; i++) { int col_id = i * blockDim.x + tid; if (col_id < n) { - out[output_offset + col_id] = - (T)(local_out[i] * s_variance * (float)__ldg(&gamma[col_id]) + (float)__ldg(&beta[col_id])); + out_ptr[output_offset + col_id] = + fma(out_ptr[output_offset + col_id], s_var_2, ldg(&gamma_ptr[col_id]), ldg(&beta_ptr[col_id])); } } } template<> -__global__ void layernorm_shift_partition_v2(half2* out_ptr, - const half2* __restrict input_ptr, - const half2* __restrict gamma_ptr, - const half2* __restrict beta_ptr, - int batch, - int H, - int W, - int n, - int shift_size, - int window_size) +__global__ void layernorm_shift_partition_v2(float* out, + const float* __restrict input, + const float* __restrict gamma, + const float* __restrict beta, + const float layernorm_eps, + int batch, + int H, + int W, + int n, + int shift_size, + int window_size) { - const int ite = 4; - const int tid = threadIdx.x; - const int batch_offset = blockIdx.z * gridDim.y * gridDim.x; - const int bid = batch_offset + blockIdx.y * gridDim.x + blockIdx.x; - const int shifted_H_idx = (shift_size != 0) ? 
((blockIdx.y - shift_size + gridDim.y) % gridDim.y) : blockIdx.y; - const int shifted_W_idx = (shift_size != 0) ? ((blockIdx.x - shift_size + gridDim.x) % gridDim.x) : blockIdx.x; - const int window_H_idx = shifted_H_idx / window_size; - const int window_W_idx = shifted_W_idx / window_size; + const int ite = 4; + const int tid = threadIdx.x; + const int batch_offset = blockIdx.z * gridDim.y * gridDim.x; + const int bid = batch_offset + blockIdx.y * gridDim.x + blockIdx.x; + const int shifted_H_idx = (shift_size != 0) ? ((blockIdx.y - shift_size + gridDim.y) % gridDim.y) : blockIdx.y; + const int shifted_W_idx = (shift_size != 0) ? ((blockIdx.x - shift_size + gridDim.x) % gridDim.x) : blockIdx.x; + const int window_H_idx = shifted_H_idx / window_size; + const int window_W_idx = shifted_W_idx / window_size; const int stride_of_window_H = W / window_size; - const int window_idx = window_H_idx * stride_of_window_H + window_W_idx; - const int idx_in_window = (shifted_H_idx % window_size) * window_size + (shifted_W_idx % window_size); - const int output_bid = batch_offset + window_idx * window_size * window_size + idx_in_window; - const int offset = bid * n; - const int output_offset = output_bid * n; + const int window_idx = window_H_idx * stride_of_window_H + window_W_idx; + const int idx_in_window = (shifted_H_idx % window_size) * window_size + (shifted_W_idx % window_size); + const int output_bid = batch_offset + window_idx * window_size * window_size + idx_in_window; + const int offset = bid * n; + const int output_offset = output_bid * n; + __shared__ float s_mean; __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; - half2 local_out_half2[ite]; - const half2 zero = {static_cast(0.0f), static_cast(0.0f)}; + float mean = 0.0f; + float variance = 0.0f; + float local_out[ite]; - // float sum = 0.0f; - half2 sum = __float2half2_rn(0.0f); + float sum = 0.0f; #pragma unroll for (int i = 0; i < ite; i++) { int col_id = i * blockDim.x + tid; if (col_id < n) { - local_out_half2[i] = __ldg(input_ptr + offset + col_id); - sum += local_out_half2[i]; + local_out[i] = (float)(__ldg(input + offset + col_id)); + sum += local_out[i]; } } - mean = blockReduceSum((float)(sum.x + sum.y)); - if (threadIdx.x == 0) { - s_mean = mean / (n * 2); + mean = blockReduceSum(sum); + if (tid == 0) { + s_mean = mean / n; } __syncthreads(); float var = 0.0f; - half2 s_mean_2 = __float2half2_rn(s_mean); #pragma unroll for (int i = 0; i < ite; i++) { int col_id = i * blockDim.x + tid; if (col_id < n) { - local_out_half2[i] = local_out_half2[i] - s_mean_2; - float v1 = (float)local_out_half2[i].x; - float v2 = (float)local_out_half2[i].y; - var += v1 * v1 + v2 * v2; + float diff = local_out[i] - s_mean; + local_out[i] = diff; + var += diff * diff; } } variance = blockReduceSum(var); - if (threadIdx.x == 0) { - s_variance = rsqrtf(variance / (n * 2) + 1e-6f); + if (tid == 0) { + s_variance = rsqrtf(variance / n + layernorm_eps); } __syncthreads(); - half2 s_var_2 = __float2half2_rn(s_variance); #pragma unroll for (int i = 0; i < ite; i++) { int col_id = i * blockDim.x + tid; if (col_id < n) { - out_ptr[output_offset + col_id] = - local_out_half2[i] * s_var_2 * __ldg(&gamma_ptr[col_id]) + __ldg(&beta_ptr[col_id]); + out[output_offset + col_id] = + (float)(local_out[i] * s_variance * (float)__ldg(&gamma[col_id]) + (float)__ldg(&beta[col_id])); } } } +// Applied to half or Bfloat16 template -void invokeLayernormShiftPartition(T* out, - const T* input, - const T* gamma, - const T* beta, - int batch, - int H, 
- int W, - int n, - int shift_size, - int window_size, +void invokeLayernormShiftPartition(T* out, + const T* input, + const T* gamma, + const T* beta, + const float layernorm_eps, + int batch, + int H, + int W, + int n, + int shift_size, + int window_size, cudaStream_t stream) { dim3 grid(W, H, batch); - int blockSize = (n + 31) / 32 * 32; - if (blockSize >= 768) { + int blockSize = n / 2; + blockSize = (blockSize + 31) / 32 * 32; + + using T2 = typename TypeConverter::Type; // bf162 or half2 + + if ((batch * H * W >= 512 && blockSize >= 768) || blockSize > 1024) { blockSize = ((blockSize / 4) + 31) / 32 * 32; - layernorm_shift_partition_v2 - <<>>(out, input, gamma, beta, batch, H, W, n, shift_size, window_size); + layernorm_shift_partition_v2<<>>((T2*)out, + (const T2*)input, + (const T2*)gamma, + (const T2*)beta, + layernorm_eps, + batch, + H, + W, + n / 2, + shift_size, + window_size); } else { - layernorm_shift_partition - <<>>(out, input, gamma, beta, batch, H, W, n, shift_size, window_size); + layernorm_shift_partition<<>>((T2*)out, + (const T2*)input, + (const T2*)gamma, + (const T2*)beta, + layernorm_eps, + batch, + H, + W, + n / 2, + shift_size, + window_size); } } template<> -void invokeLayernormShiftPartition(half* out, - const half* input, - const half* gamma, - const half* beta, - int batch, - int H, - int W, - int n, - int shift_size, - int window_size, - cudaStream_t stream) +void invokeLayernormShiftPartition(float* out, + const float* input, + const float* gamma, + const float* beta, + const float layernorm_eps, + int batch, + int H, + int W, + int n, + int shift_size, + int window_size, + cudaStream_t stream) { dim3 grid(W, H, batch); - int blockSize = n / 2; - blockSize = (blockSize + 31) / 32 * 32; - - if ((batch * H * W >= 512 && blockSize >= 768) || blockSize > 1024) { + int blockSize = (n + 31) / 32 * 32; + if (blockSize >= 768) { blockSize = ((blockSize / 4) + 31) / 32 * 32; - layernorm_shift_partition_v2<<>>((half2*)out, - (const half2*)input, - (const half2*)gamma, - (const half2*)beta, - batch, - H, - W, - n / 2, - shift_size, - window_size); + layernorm_shift_partition_v2<<>>( + out, input, gamma, beta, layernorm_eps, batch, H, W, n, shift_size, window_size); } else { - layernorm_shift_partition<<>>((half2*)out, - (const half2*)input, - (const half2*)gamma, - (const half2*)beta, - batch, - H, - W, - n / 2, - shift_size, - window_size); + layernorm_shift_partition<<>>( + out, input, gamma, beta, layernorm_eps, batch, H, W, n, shift_size, window_size); } } -template void invokeLayernormShiftPartition(float* out, +template void invokeLayernormShiftPartition(float* out, const float* input, const float* gamma, const float* beta, - int batch, - int H, - int W, - int n, - int shift_size, - int window_size, + const float layernorm_eps, + int batch, + int H, + int W, + int n, + int shift_size, + int window_size, cudaStream_t stream); -template void invokeLayernormShiftPartition(half* out, - const half* input, - const half* gamma, - const half* beta, - int batch, - int H, - int W, - int n, - int shift_size, - int window_size, +template void invokeLayernormShiftPartition(half* out, + const half* input, + const half* gamma, + const half* beta, + const float layernorm_eps, + int batch, + int H, + int W, + int n, + int shift_size, + int window_size, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeLayernormShiftPartition<__nv_bfloat16>(__nv_bfloat16* out, + const __nv_bfloat16* input, + const __nv_bfloat16* gamma, + const __nv_bfloat16* beta, + const float 
layernorm_eps, + int batch, + int H, + int W, + int n, + int shift_size, + int window_size, + cudaStream_t stream); +#endif + /******************* invokeAddBiasLayernorm ***********************/ template -__global__ void add_bias_layernorm(T* out, const T* bias, const T* gamma, const T* beta, int n) +__global__ void add_bias_layernorm(T* out, const T* bias, const T* gamma, const T* beta, float layernorm_eps, int n) { - int tid = threadIdx.x; - const int bid = blockIdx.x; + int tid = threadIdx.x; + const int bid = blockIdx.x; __shared__ float s_mean; __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; + float mean = 0.0f; + float variance = 0.0f; float local_out = (tid < n) ? (float)(out[bid * n + tid] + ldg(&bias[tid])) : 0.0f; @@ -1445,9 +1670,9 @@ __global__ void add_bias_layernorm(T* out, const T* bias, const T* gamma, const __syncthreads(); float diff = (tid < n) ? (local_out - s_mean) : 0.0f; - variance = blockReduceSum(diff * diff); + variance = blockReduceSum(diff * diff); if (threadIdx.x == 0) { - s_variance = variance / n + 1e-6f; + s_variance = variance / n + layernorm_eps; } __syncthreads(); @@ -1458,24 +1683,24 @@ __global__ void add_bias_layernorm(T* out, const T* bias, const T* gamma, const } template -__global__ void -add_bias_layernorm_v2(T* out, const T* __restrict bias, const T* __restrict gamma, const T* __restrict beta, int n) +__global__ void add_bias_layernorm_v2( + T* out, const T* __restrict bias, const T* __restrict gamma, const T* __restrict beta, float layernorm_eps, int n) { - const int ite = 4; - const int tid = threadIdx.x; - const int bid = blockIdx.x; + const int ite = 4; + const int tid = threadIdx.x; + const int bid = blockIdx.x; const int offset = bid * n; __shared__ float s_mean; __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; - float local_out[ite]; + float mean = 0.0f; + float variance = 0.0f; + float local_out[ite]; float sum = 0.0f; #pragma unroll for (int i = 0; i < ite; i++) { - int col_id = i * blockDim.x + tid; + int col_id = i * blockDim.x + tid; local_out[i] = (col_id < n) ? (float)(out[offset + col_id] + ldg(&bias[col_id])) : 0.0f; sum += local_out[i]; } @@ -1489,14 +1714,14 @@ add_bias_layernorm_v2(T* out, const T* __restrict bias, const T* __restrict gamm float var = 0.0f; #pragma unroll for (int i = 0; i < ite; i++) { - int col_id = i * blockDim.x + tid; - float diff = (col_id < n) ? (local_out[i] - s_mean) : 0.0f; + int col_id = i * blockDim.x + tid; + float diff = (col_id < n) ? 
(local_out[i] - s_mean) : 0.0f; var += diff * diff; } variance = blockReduceSum(var); if (tid == 0) { - s_variance = rsqrtf(variance / n + 1e-6f); + s_variance = rsqrtf(variance / n + layernorm_eps); } __syncthreads(); @@ -1511,24 +1736,49 @@ add_bias_layernorm_v2(T* out, const T* __restrict bias, const T* __restrict gamm } #define HALF_LAYERNORM_OPT(UNROLL_FACTOR) \ - generalAddBiasResidualLayerNormOpt<<>>( \ - (T2*)out, (T2*)out, (const T2*)bias, (const T2*)out, (const T2*)gamma, (const T2*)beta, m, half_n); + generalAddBiasResidualLayerNormOpt \ + <<>>((T2*)out, \ + (T2*)out, \ + (const T2*)bias, \ + (const T2*)out, \ + nullptr, \ + (const T2*)gamma, \ + (const T2*)beta, \ + layernorm_eps, \ + m, \ + half_n); #define HALF_LAYERNORM_OPT2(UNROLL_FACTOR) \ - generalAddBiasResidualLayerNormOpt2<<>>( \ - (T2*)out, (T2*)out, (const T2*)bias, (const T2*)out, (const T2*)gamma, (const T2*)beta, m, half_n); + generalAddBiasResidualLayerNormOpt2 \ + <<>>((T2*)out, \ + (T2*)out, \ + (const T2*)bias, \ + (const T2*)out, \ + nullptr, \ + (const T2*)gamma, \ + (const T2*)beta, \ + layernorm_eps, \ + m, \ + half_n); template -void invokeAddBiasLayernorm( - T* out, const T* bias, const T* gamma, const T* beta, int m, int n, cudaStream_t stream, int opt_version) +void invokeAddBiasLayernorm(T* out, + const T* bias, + const T* gamma, + const T* beta, + float layernorm_eps, + int m, + int n, + cudaStream_t stream, + int opt_version) { dim3 grid(m); if (n % 2 == 0 && std::is_same::value && opt_version > 0) { - int half_n = n / 2; - int half_n_32 = (half_n + 31) / 32 * 32; + int half_n = n / 2; + int half_n_32 = (half_n + 31) / 32 * 32; dim3 block(min(half_n_32, 512)); - int rolls_per_thread = half_n / block.x; - int unroll_factor = 8; + int rolls_per_thread = half_n / block.x; + int unroll_factor = 8; while (unroll_factor > rolls_per_thread && unroll_factor > 1) { unroll_factor /= 2; } @@ -1572,10 +1822,10 @@ void invokeAddBiasLayernorm( int blockSize = (n + 31) / 32 * 32; if (blockSize >= 768) { blockSize = ((blockSize / 4) + 31) / 32 * 32; - add_bias_layernorm_v2<<>>(out, bias, gamma, beta, n); + add_bias_layernorm_v2<<>>(out, bias, gamma, beta, layernorm_eps, n); } else { - add_bias_layernorm<<>>(out, bias, gamma, beta, n); + add_bias_layernorm<<>>(out, bias, gamma, beta, layernorm_eps, n); } } } @@ -1583,32 +1833,35 @@ void invokeAddBiasLayernorm( #undef HALF_LAYERNORM_OPT #undef HALF_LAYERNORM_OPT2 -template void invokeAddBiasLayernorm(float* out, +template void invokeAddBiasLayernorm(float* out, const float* bias, const float* gamma, const float* beta, - int m, - int n, + const float layernorm_eps, + int m, + int n, cudaStream_t stream, - int opt_version); - -template void invokeAddBiasLayernorm(half* out, - const half* bias, - const half* gamma, - const half* beta, - int m, - int n, + int opt_version); + +template void invokeAddBiasLayernorm(half* out, + const half* bias, + const half* gamma, + const half* beta, + const float layernorm_eps, + int m, + int n, cudaStream_t stream, - int opt_version); + int opt_version); #ifdef ENABLE_BF16 -template void invokeAddBiasLayernorm<__nv_bfloat16>(__nv_bfloat16* out, +template void invokeAddBiasLayernorm<__nv_bfloat16>(__nv_bfloat16* out, const __nv_bfloat16* bias, const __nv_bfloat16* gamma, const __nv_bfloat16* beta, - int m, - int n, - cudaStream_t stream, - int opt_version); + const float layernorm_eps, + int m, + int n, + cudaStream_t stream, + int opt_version); #endif /******************* invokeMergeLayernorm ***********************/ @@ -1622,37 
+1875,38 @@ __global__ void merge_layernorm_v2(T* out, const T* __restrict input, const T* __restrict gamma, const T* __restrict beta, - int batch, - int H, - int W, - int n) + const float layernorm_eps, + int batch, + int H, + int W, + int n) { - const int ite = 4; - const int tid = threadIdx.x; - const int W_idx = blockIdx.x; - const int H_idx = blockIdx.y; - const size_t batch_offset = blockIdx.z * H * W * n; - const int input_H_stride = W * n / 2; - const int output_H_stride = W * n; - const int n_4 = n >> 2; + const int ite = 4; + const int tid = threadIdx.x; + const int W_idx = blockIdx.x; + const int H_idx = blockIdx.y; + const size_t batch_offset = blockIdx.z * H * W * n; + const int input_H_stride = W * n / 2; + const int output_H_stride = W * n; + const int n_4 = n >> 2; __shared__ float s_mean; __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; - float local_out[ite]; + float mean = 0.0f; + float variance = 0.0f; + float local_out[ite]; float sum = 0.0f; #pragma unroll for (int i = 0; i < ite; i++) { int col_id = i * blockDim.x + tid; if (col_id < n) { - int part_id = col_id / n_4; - int offset_in_W = part_id / 2; - int offset_in_H = part_id % 2; - size_t input_id = batch_offset + (2 * H_idx + offset_in_H) * input_H_stride + int part_id = col_id / n_4; + int offset_in_W = part_id / 2; + int offset_in_H = part_id % 2; + size_t input_id = batch_offset + (2 * H_idx + offset_in_H) * input_H_stride + (2 * W_idx + offset_in_W) * n_4 + (col_id % n_4); - local_out[i] = (float)(__ldg(input + input_id)); + local_out[i] = (float)(ldg(input + input_id)); sum += local_out[i]; } } @@ -1675,7 +1929,7 @@ __global__ void merge_layernorm_v2(T* out, variance = blockReduceSum(var); if (tid == 0) { - s_variance = rsqrtf(variance / n + 1e-6f); + s_variance = rsqrtf(variance / n + layernorm_eps); } __syncthreads(); @@ -1684,29 +1938,37 @@ __global__ void merge_layernorm_v2(T* out, int col_id = i * blockDim.x + tid; if (col_id < n) { size_t output_idx = batch_offset + H_idx * output_H_stride + W_idx * n + col_id; - out[output_idx] = - (T)(local_out[i] * s_variance * (float)__ldg(&gamma[col_id]) + (float)__ldg(&beta[col_id])); + out[output_idx] = (T)(local_out[i] * s_variance * (float)ldg(&gamma[col_id]) + (float)ldg(&beta[col_id])); } } } // TODO : accelerate with half2 template -void invokeMergeLayernorm( - T* output, const T* input, const T* gamma, const T* beta, int batch, int H, int W, int n, cudaStream_t stream) +void invokeMergeLayernorm(T* output, + const T* input, + const T* gamma, + const T* beta, + float layernorm_eps, + int batch, + int H, + int W, + int n, + cudaStream_t stream) { if ((W % 2 != 0) || (H % 2 != 0)) { printf("[ERROR][invokeMergeLayernorm] H(W) should be a multiple of 2.\n"); return; } dim3 grid(W / 2, H / 2, batch); - int blockSize = 4 * n; - blockSize = (blockSize + 31) / 32 * 32; + int blockSize = 4 * n; + blockSize = (blockSize + 31) / 32 * 32; // TODO // if (blockSize >= 768) { blockSize = ((blockSize / 4) + 31) / 32 * 32; - merge_layernorm_v2<<>>(output, input, gamma, beta, batch, H / 2, W / 2, n * 4); + merge_layernorm_v2 + <<>>(output, input, gamma, beta, layernorm_eps, batch, H / 2, W / 2, n * 4); } /* else @@ -1714,24 +1976,39 @@ void invokeMergeLayernorm( */ } -template void invokeMergeLayernorm(float* output, +template void invokeMergeLayernorm(float* output, const float* input, const float* gamma, const float* beta, - int batch, - int H, - int W, - int n, + float layernorm_eps, + int batch, + int H, + int W, + int n, cudaStream_t stream); 
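// Illustrative host-side sketch (an assumption for clarity, not part of this diff): the 2x2
// patch-merge indexing that merge_layernorm_v2 performs on the GPU. An output channel
// col_id in [0, 4*C) selects one of the four merged input patches and a channel within it,
// assuming the same row-major (batch, H_in, W_in, C) layout the kernel reads from.
static inline size_t merged_layernorm_input_index(
    int b, int h_out, int w_out, int col_id, int H_in, int W_in, int C)
{
    const int part_id     = col_id / C;   // which of the four patches in the 2x2 window
    const int offset_in_W = part_id / 2;  // column offset (0 or 1) inside the window
    const int offset_in_H = part_id % 2;  // row offset (0 or 1) inside the window
    const int c           = col_id % C;   // channel inside the source patch
    return (size_t)b * H_in * W_in * C
           + (size_t)(2 * h_out + offset_in_H) * W_in * C
           + (size_t)(2 * w_out + offset_in_W) * C + c;
}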
-template void invokeMergeLayernorm(half* output, - const half* input, - const half* gamma, - const half* beta, - int batch, - int H, - int W, - int n, +template void invokeMergeLayernorm(half* output, + const half* input, + const half* gamma, + const half* beta, + float layernorm_eps, + int batch, + int H, + int W, + int n, cudaStream_t stream); -} // namespace fastertransformer \ No newline at end of file +#ifdef ENABLE_BF16 +template void invokeMergeLayernorm<__nv_bfloat16>(__nv_bfloat16* output, + const __nv_bfloat16* input, + const __nv_bfloat16* gamma, + const __nv_bfloat16* beta, + const float layernorm_eps, + int batch, + int H, + int W, + int n, + cudaStream_t stream); +#endif + +} // namespace fastertransformer diff --git a/src/fastertransformer/kernels/layernorm_kernels.h b/src/fastertransformer/kernels/layernorm_kernels.h index e8319debd..9a749e4c5 100644 --- a/src/fastertransformer/kernels/layernorm_kernels.h +++ b/src/fastertransformer/kernels/layernorm_kernels.h @@ -17,93 +17,159 @@ #pragma once #include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include "src/fastertransformer/utils/cuda_utils.h" #include #include #include namespace fastertransformer { -enum LayerNormType { +enum class LayerNormType { pre_layernorm, - post_layernorm + post_layernorm, + InvalidType }; +inline LayerNormType getLayerNormType(std::string layernorm_type_str) +{ + if (layernorm_type_str == "pre_layernorm") { + return LayerNormType::pre_layernorm; + } + else if (layernorm_type_str == "post_layernorm") { + return LayerNormType::post_layernorm; + } + else { + FT_CHECK_WITH_INFO(false, "Layernorm Type: " + layernorm_type_str + " not supported !"); + } + return LayerNormType::InvalidType; +} + template struct LayerNormWeight { const T* gamma = nullptr; - const T* beta = nullptr; + const T* beta = nullptr; }; template -void invokeAddBiasResidualLayerNorm(T* out, - const T* input, - const T* bias, - const T* gamma, - const T* beta, - const int m, - const int n, +void invokeAddBiasResidualLayerNorm(T* out, + const T* input, + const T* bias, + const T* gamma, + const T* beta, + const float layernorm_eps, + const int m, + const int n, cudaStream_t stream); template -void invokeGeneralAddBiasResidualPreLayerNorm(T* output, - T* norm_output, - const T* input, - const T* gamma, - const T* beta, - const T* bias, - int m, - int n, +void invokeGeneralAddBiasResidualPreLayerNorm(T* output, + T* norm_output, + const T* residual1, + const T* residual2, + const T* gamma, + const T* beta, + const T* bias, + const float layernorm_eps, + int m, + int n, cudaStream_t stream, - int opt_version = 2); + int opt_version = 2); template -void invokeGeneralLayerNorm(T* out, - const T* input, - const T* gamma, - const T* beta, - const int m, - const int n, +void invokeGeneralAddBiasResidualPreLayerNorm(T* output, + T* norm_output, + const T* input, + const T* gamma, + const T* beta, + const T* bias, + const float layernorm_eps, + int m, + int n, + cudaStream_t stream, + int opt_version = 2) +{ + invokeGeneralAddBiasResidualPreLayerNorm( + output, norm_output, input, (const T*)nullptr, gamma, beta, bias, layernorm_eps, m, n, stream, opt_version); +} + +template +void invokeGeneralLayerNorm(T* out, + const T* input, + const T* gamma, + const T* beta, + const float layernorm_eps, + const int m, + const int n, cudaStream_t stream, - int opt_version = 2); + int opt_version = 2); template -void invokeGeneralT5LayerNorm( - T* out, const T* input, const T* gamma, const T* beta, const int m, const int n, cudaStream_t stream); 
+void invokeGeneralT5LayerNorm(T* out, + const T* input, + const T* gamma, + const T* beta, + const float layernorm_eps, + const int m, + const int n, + cudaStream_t stream); template -void invokeGeneralAddResidualT5PreLayerNorm( - T* output, T* norm_output, const T* input, const T* gamma, int m, int n, cudaStream_t stream); +void invokeGeneralAddResidualT5PreLayerNorm(T* output, + T* norm_output, + const T* input, + const T* gamma, + const float layernorm_eps, + int m, + int n, + cudaStream_t stream); template -void invokeGeneralAddBiasResidualT5PreLayerNorm(T* output, - T* norm_output, - const T* input, - const T* gamma, - const T* beta, - const T* bias, - int m, - int n, +void invokeGeneralAddBiasResidualT5PreLayerNorm(T* output, + T* norm_output, + const T* input, + const T* gamma, + const T* beta, + const T* bias, + const float layernorm_eps, + int m, + int n, cudaStream_t stream); template -void invokeLayernormShiftPartition(T* out, - const T* input, - const T* gamma, - const T* beta, - int batch, - int H, - int W, - int n, - int shift_size, - int window_size, +void invokeLayernormShiftPartition(T* out, + const T* input, + const T* gamma, + const T* beta, + const float layernorm_eps, + int batch, + int H, + int W, + int n, + int shift_size, + int window_size, cudaStream_t stream); template -void invokeAddBiasLayernorm( - T* out, const T* bias, const T* gamma, const T* beta, int m, int n, cudaStream_t stream, int opt_version = 2); +void invokeAddBiasLayernorm(T* out, + const T* bias, + const T* gamma, + const T* beta, + const float layernorm_eps, + int m, + int n, + cudaStream_t stream, + int opt_version = 2); template -void invokeMergeLayernorm( - T* output, const T* input, const T* gamma, const T* beta, int batch, int H, int W, int n, cudaStream_t stream); +void invokeMergeLayernorm(T* output, + const T* input, + const T* gamma, + const T* beta, + const float layernorm_eps, + int batch, + int H, + int W, + int n, + cudaStream_t stream); } // namespace fastertransformer diff --git a/src/fastertransformer/kernels/layout_transformer_int8_kernels.cu b/src/fastertransformer/kernels/layout_transformer_int8_kernels.cu index ab8905ea7..f718f8d44 100644 --- a/src/fastertransformer/kernels/layout_transformer_int8_kernels.cu +++ b/src/fastertransformer/kernels/layout_transformer_int8_kernels.cu @@ -83,10 +83,10 @@ invokeTransposeMatrixCOL32ToColMajor(half* dst, const half* src, const int template void invokeTransposeMatrixCOL32ToColMajor( int8_t* dst, const int8_t* src, const int m, const int n, cudaStream_t stream); -// transpose matrix & transfrom col-major to COL32 +// transpose matrix & transform col-major to COL32 // input matrix is (m, n) col-major // output matrix is (n, m) COL32 -// m should be a mutiple of 32 +// m should be a multiple of 32 // grid((m+31)/32, (n+31)/32) // block(32, 32) template @@ -105,10 +105,10 @@ __global__ void transposeMatrix_colMajorToCOL32_kernel(T* dst, const T* src, con } } -// transpose matrix & transfrom col-major to COL32 +// transpose matrix & transform col-major to COL32 // input matrix is (m, n) col-major // output matrix is (n, m) COL32 -// m should be a mutiple of 32 +// m should be a multiple of 32 // grid((m+31)/32, (n+31)/32) // block(16, 32) template<> @@ -127,10 +127,10 @@ __global__ void transposeMatrix_colMajorToCOL32_kernel(half2* dst, const half2* } } -// transpose matrix & transfrom col-major to COL32 +// transpose matrix & transform col-major to COL32 // input matrix is (m, n) col-major // output matrix is (n, m) COL32, using char4 to 
write out -// m should be a mutiple of 32 +// m should be a multiple of 32 // grid((m+31)/32, (n+31)/32) // block(8, 32) template @@ -153,10 +153,10 @@ template void invokeTransposeMatrixColMajorToCOL32( template void invokeTransposeMatrixColMajorToCOL32(half* dst, const half* src, const int m, const int n, cudaStream_t stream); -// transpose matrix & transfrom col-major to COL32 & quantize +// transpose matrix & transform col-major to COL32 & quantize // input matrix is (m, n) col-major // output matrix is (n, m) COL32, using char4 to write out -// m should be a mutiple of 32 +// m should be a multiple of 32 // grid((m+31)/32, (n+31)/32) // block(8, 32) template @@ -184,10 +184,10 @@ __global__ void transposeMatrix_colMajorToCOL32_quantize_kernel( } } -// transpose matrix & transfrom col-major to COL32 & quantize +// transpose matrix & transform col-major to COL32 & quantize // input matrix is (m, n) col-major // output matrix is (n, m) COL32, using char4 to write out -// m should be a mutiple of 32 +// m should be a multiple of 32 // grid((m+31)/32, (n+31)/32) // block(8, 32) template @@ -205,10 +205,10 @@ template void invokeTransposeMatrixColMajorToCOL32Quantize( template void invokeTransposeMatrixColMajorToCOL32Quantize( int8_t* dst, const half* src, const int m, const int n, const float* scale_ptr, cudaStream_t stream); -// transfrom row-major to COL32 +// transform row-major to COL32 // input matrix is (m, n) row-major // output matrix is (m, n) COL32 -// n should be a mutiple of 32 +// n should be a multiple of 32 // grid((n+31)/32, (m+31)/32) // block(8, 32) __global__ void rowMajorToCOL32_kernel(char4* dst, const char4* src, const int m, const int n) @@ -226,10 +226,10 @@ __global__ void rowMajorToCOL32_kernel(char4* dst, const char4* src, const int m } } -// transfrom row-major to COL32 +// transform row-major to COL32 // input matrix is (m, n) row-major // output matrix is (m, n) COL32 -// n should be a mutiple of 32 +// n should be a multiple of 32 // grid((n+31)/32, (m+31)/32) // block(8, 32) void invokeRowMajorToCOL32(int8_t* dst, const int8_t* src, const int m, const int n, cudaStream_t stream) diff --git a/src/fastertransformer/kernels/logprob_kernels.cu b/src/fastertransformer/kernels/logprob_kernels.cu index 5d4938766..37e998671 100644 --- a/src/fastertransformer/kernels/logprob_kernels.cu +++ b/src/fastertransformer/kernels/logprob_kernels.cu @@ -33,20 +33,20 @@ namespace fastertransformer { template -__global__ void log_probs_kernel(float* log_probs, - const T* logits, - const int* ids, - const int* lengths, +__global__ void log_probs_kernel(float* log_probs, + const T* logits, + const int* ids, + const int* lengths, const size_t max_input_length, const size_t batch_size, const size_t vocab_size, const size_t vocab_size_padded, - bool batch_first) + bool batch_first) { // Calculate the log probability from logits. - // log_probs[i,j] = log(softmax(logits))[ids[i,j]] + // log_probs[t, :] = log(softmax(logits))[ids[t + 1, :]] // - // log_probs: [max_length, batch_size] or [batch_size, max_length], + // log_probs: [max_length - 1, batch_size] or [batch_size, max_length -1], // log probabilities of each token. // logits: [max_length, batch_size, vocab_size_padded] or [batch_size, max_length, vocab_size_padded] // lengths: [batch_size], sequence lengths @@ -55,8 +55,8 @@ __global__ void log_probs_kernel(float* log_probs, // vocab_size: [1], vocab_size, // vocab_size: [1], vocab_size_padded, padded vocab size. 
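The comment block above describes the updated indexing: log_probs stores, for each step t in [0, max_length - 1), the log-softmax of the logits at step t evaluated at the next token ids[t + 1]; accumulate_log_probs then sums these over steps 0..length - 2 to form cum_log_probs. A minimal host-side reference for that per-token quantity, assuming an unpadded vocabulary and batch_first layout (the function name is illustrative, not from this patch):

#include <algorithm>
#include <cmath>
#include <vector>

// log_probs[t] = logits[t][ids[t + 1]] - max_j logits[t][j]
//                - log( sum_j exp(logits[t][j] - max_j logits[t][j]) )
float tokenLogProbReference(const std::vector<float>& step_logits, int next_token_id)
{
    const float max_logit = *std::max_element(step_logits.begin(), step_logits.end());

    float sum_exp = 0.f;
    for (float v : step_logits) {
        sum_exp += std::exp(v - max_logit);
    }
    // Same numerically stabilized form as the kernel, including the 1e-9f guard.
    return step_logits[next_token_id] - max_logit - std::log(sum_exp + 1e-9f);
}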
- const bool IS_FP16 = std::is_same::value; - const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; + const bool IS_FP16 = std::is_same::value; + const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; int tidx = threadIdx.x; // vocab dim int bidx = batch_first ? blockIdx.x : blockIdx.y; // batch dim @@ -64,17 +64,17 @@ __global__ void log_probs_kernel(float* log_probs, __shared__ float s_max_logit; - if (bidx < batch_size && step < lengths[bidx]) { + if (bidx < batch_size && step < lengths[bidx] - 1) { // reposition logits to data for the current batch. - int step_offset = batch_first ? step * vocab_size_padded : step * batch_size * vocab_size_padded; + int step_offset = batch_first ? step * vocab_size_padded : step * batch_size * vocab_size_padded; int batch_offset = batch_first ? bidx * max_input_length * vocab_size_padded : bidx * vocab_size_padded; logits += step_offset + batch_offset; // Find max(logits). float local_max = -MAX_T_VAL; - float val = -MAX_T_VAL; + float val = -MAX_T_VAL; for (int i = tidx; i < vocab_size; i += blockDim.x) { - val = static_cast(logits[i]); + val = static_cast(logits[i]); local_max = fmax(local_max, val); } @@ -93,24 +93,27 @@ __global__ void log_probs_kernel(float* log_probs, float sum_exp = blockDim.x <= 32 ? warpReduceSum(local_sum_exp) : blockReduceSum(local_sum_exp); if (tidx == 0) { - int idx = batch_first ? step + bidx * max_input_length : step * batch_size + bidx; - log_probs[idx] = static_cast(logits[ids[idx]]) - s_max_logit - __logf(sum_exp + 1e-9f); + int idx = batch_first ? step + bidx * (max_input_length - 1) : step * batch_size + bidx; + // log_probs[step, ...] is the log probability of a token at step t + 1. + int token_idx = batch_first ? step + 1 + bidx * max_input_length : (step + 1) * batch_size + bidx; + log_probs[idx] = static_cast(logits[ids[token_idx]]) - s_max_logit - __logf(sum_exp + 1e-9f); } } } -__global__ void accumulate_log_probs(float* cum_log_probs, +__global__ void accumulate_log_probs(float* cum_log_probs, const float* log_probs, - const int* lengths, + const int* lengths, const size_t max_input_length, const size_t batch_size, - const bool batch_first) + const bool batch_first) { // Accumulate the log probability along with the sequence dimension. // cum_log_probs[j] = sum_i log(softmax(logits))[ids[i,j]] // // cum_log_probs: [batch_size], cumulative log probability - // log_probs: [max_length, batch_size] or [batch_size, max_length], log probability of each token + // log_probs: [max_length - 1, batch_size] or [batch_size, max_length - 1], + // log probability of each token // lengths: [batch_size], sequence lengths // batch_size: [1], batch_size. in case of beam > 1, batch x beam. @@ -120,10 +123,10 @@ __global__ void accumulate_log_probs(float* cum_log_probs, if (bidx < batch_size) { int length = lengths[bidx]; // reposition logits to data for the current batch. - log_probs += batch_first ? bidx * max_input_length : bidx; + log_probs += batch_first ? bidx * (max_input_length - 1) : bidx; int stride = batch_first ? 1 : batch_size; // stride along with seq dim. float local_accum = 0.0f; - for (int step = tidx; step < length; step += blockDim.x) { + for (int step = tidx; step < length - 1; step += blockDim.x) { local_accum += static_cast(log_probs[step * stride]); } float accum = blockDim.x <= 32 ? 
warpReduceSum(local_accum) : blockReduceSum(local_accum); @@ -134,18 +137,18 @@ __global__ void accumulate_log_probs(float* cum_log_probs, } template -void invokeLogProbFromLogits(float* cum_log_probs, - const T* logits, - const int* input_ids, - const int* input_lengths, +void invokeLogProbFromLogits(float* cum_log_probs, + const T* logits, + const int* input_ids, + const int* input_lengths, const size_t max_input_length, const size_t batch_size, const size_t vocab_size, const size_t vocab_size_padded, - void* workspace, + void* workspace, const size_t workspace_size, cudaStream_t stream, - const bool batch_first) + const bool batch_first) { // A batched version of log prob computation. // @@ -163,8 +166,8 @@ void invokeLogProbFromLogits(float* cum_log_probs, assert(vocab_size <= vocab_size_padded); float* log_probs = reinterpret_cast(workspace); - int gx = batch_first ? batch_size : max_input_length; - int gy = batch_first ? max_input_length : batch_size; + int gx = batch_first ? batch_size : max_input_length - 1; + int gy = batch_first ? max_input_length - 1 : batch_size; dim3 grid(gx, gy); log_probs_kernel<<>>(log_probs, logits, @@ -179,29 +182,29 @@ void invokeLogProbFromLogits(float* cum_log_probs, cum_log_probs, log_probs, input_lengths, max_input_length, batch_size, batch_first); } -template void invokeLogProbFromLogits(float* cum_log_probs, +template void invokeLogProbFromLogits(float* cum_log_probs, const float* logits, - const int* input_ids, - const int* input_lengths, + const int* input_ids, + const int* input_lengths, const size_t max_input_length, const size_t batch_size, const size_t vocab_size, const size_t vocab_size_padded, - void* workspace, + void* workspace, const size_t workspace_size, cudaStream_t stream, - const bool batch_first); + const bool batch_first); -template void invokeLogProbFromLogits(float* cum_log_probs, - const half* logits, - const int* input_ids, - const int* input_lengths, +template void invokeLogProbFromLogits(float* cum_log_probs, + const half* logits, + const int* input_ids, + const int* input_lengths, const size_t max_input_length, const size_t batch_size, const size_t vocab_size, const size_t vocab_size_padded, - void* workspace, + void* workspace, const size_t workspace_size, cudaStream_t stream, - const bool batch_first); + const bool batch_first); } // end of namespace fastertransformer diff --git a/src/fastertransformer/kernels/logprob_kernels.h b/src/fastertransformer/kernels/logprob_kernels.h index 8e480c108..94f0d0111 100644 --- a/src/fastertransformer/kernels/logprob_kernels.h +++ b/src/fastertransformer/kernels/logprob_kernels.h @@ -19,16 +19,16 @@ namespace fastertransformer { template -void invokeLogProbFromLogits(float* cum_log_probs, - const T* logits, - const int* input_ids, - const int* input_lengths, +void invokeLogProbFromLogits(float* cum_log_probs, + const T* logits, + const int* input_ids, + const int* input_lengths, const size_t max_input_length, const size_t batch_size, const size_t vocab_size, const size_t vocab_size_padded, - void* workspace, + void* workspace, const size_t workspace_size, cudaStream_t stream, - const bool batch_first = false); + const bool batch_first = false); } // namespace fastertransformer diff --git a/src/fastertransformer/kernels/longformer_kernels.cu b/src/fastertransformer/kernels/longformer_kernels.cu index ca7d2c03c..3dcc51a66 100644 --- a/src/fastertransformer/kernels/longformer_kernels.cu +++ b/src/fastertransformer/kernels/longformer_kernels.cu @@ -24,12 +24,35 @@ #endif #include 
"longformer_kernels.h" +#include "src/fastertransformer/kernels/bfloat16_fallback_kenrels.cuh" #include "src/fastertransformer/utils/cuda_utils.h" #include namespace fastertransformer { +// when sm < 80, cub device partition flagged bf16 not supported +// error: more than one operator "=" matches these operands: __nv_bfloat16::operator=(float), +// __nv_bfloat16::operator=(double) +template +struct CubBF16FallBackType { + using Type = T; +}; + +#ifdef ENABLE_BF16 + +// global attn mask can be converted to uint16_t +template<> +struct CubBF16FallBackType<__nv_bfloat16> { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + using Type = uint16_t; +#else + using Type = __nv_bfloat16; +#endif +}; + +#endif // ENABLE_BF16 + __global__ void initSeqIdxKernel(int* seq_idx, int seq_len) { int offset = blockIdx.x * blockDim.x + threadIdx.x; @@ -42,11 +65,13 @@ template size_t getInitLongformerCubStorage(const int seq_len) { size_t tmp_storage_bytes = 0; - int *seq_idx = NULL, *global_idx = NULL, *global_token_nums = NULL; - void* global_attn_mask = NULL; + int * seq_idx = NULL, *global_idx = NULL, *global_token_nums = NULL; + void* global_attn_mask = NULL; + + using AttenMaskFBType = typename CubBF16FallBackType::Type; check_cuda_error(cub::DevicePartition::Flagged( - NULL, tmp_storage_bytes, seq_idx, (T*)global_attn_mask, global_idx, global_token_nums, seq_len)); + NULL, tmp_storage_bytes, seq_idx, (AttenMaskFBType*)global_attn_mask, global_idx, global_token_nums, seq_len)); return tmp_storage_bytes; } @@ -54,10 +79,10 @@ size_t getInitLongformerCubStorage(const int seq_len) template __global__ void localAttnMaskShiftKernel(T* local_attn_mask, T* out, int thread_block_repeat, int total_len) { - int i = blockIdx.x * blockDim.x * thread_block_repeat + threadIdx.x; + int i = blockIdx.x * blockDim.x * thread_block_repeat + threadIdx.x; int end = i + thread_block_repeat * blockDim.x; for (; i < end && i < total_len; i += blockDim.x) { - out[i] = local_attn_mask[i] * (T)10000 - (T)10000; + out[i] = fma(local_attn_mask[i], (T)10000.f, (T)(-10000.f)); } } @@ -65,8 +90,8 @@ template void invokeLocalAttnMaskShift(T* local_attn_mask, T* out, int batch_size, int seq_len, cudaStream_t stream) { const int thread_block_repeat = 4; - const int block_dim = 128; - int block_num = std::ceil(batch_size * seq_len / (float)block_dim / (float)thread_block_repeat); + const int block_dim = 128; + int block_num = std::ceil(batch_size * seq_len / (float)block_dim / (float)thread_block_repeat); localAttnMaskShiftKernel<<>>( local_attn_mask, out, thread_block_repeat, batch_size * seq_len); } @@ -77,27 +102,33 @@ invokeLocalAttnMaskShift(float* local_attn_mask, float* out, int batch_size, int template void invokeLocalAttnMaskShift(half* local_attn_mask, half* out, int batch_size, int seq_len, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeLocalAttnMaskShift( + __nv_bfloat16* local_attn_mask, __nv_bfloat16* out, int batch_size, int seq_len, cudaStream_t stream); +#endif + template -void invokeInitLongformerIdx(T* global_attn_mask, - int* seq_idx, - int* global_idx, - int* global_token_nums, - int seq_len, - int batch_size, - void* cub_storage, +void invokeInitLongformerIdx(T* global_attn_mask, + int* seq_idx, + int* global_idx, + int* global_token_nums, + int seq_len, + int batch_size, + void* cub_storage, cudaStream_t stream) { const int threads = 1024; - int blocks = std::ceil(seq_len / float(threads)); + int blocks = std::ceil(seq_len / float(threads)); initSeqIdxKernel<<>>(seq_idx, seq_len); 
sync_check_cuda_error(); size_t storages_bytes = getInitLongformerCubStorage(seq_len); + using AttenMaskFBType = typename CubBF16FallBackType::Type; for (int i = 0; i < batch_size; ++i) { check_cuda_error(cub::DevicePartition::Flagged(cub_storage, storages_bytes, seq_idx, - global_attn_mask + i * seq_len, + (AttenMaskFBType*)(global_attn_mask + i * seq_len), global_idx + i * seq_len, global_token_nums + i, seq_len, @@ -105,64 +136,75 @@ void invokeInitLongformerIdx(T* global_attn_mask, } } -template void invokeInitLongformerIdx(float* global_attn_mask, - int* seq_idx, - int* global_idx, - int* global_token_nums, - int seq_len, - int batch_size, - void* cub_storage, +template void invokeInitLongformerIdx(float* global_attn_mask, + int* seq_idx, + int* global_idx, + int* global_token_nums, + int seq_len, + int batch_size, + void* cub_storage, cudaStream_t stream); -template void invokeInitLongformerIdx(half* global_attn_mask, - int* seq_idx, - int* global_idx, - int* global_token_nums, - int seq_len, - int batch_size, - void* cub_storage, +template void invokeInitLongformerIdx(half* global_attn_mask, + int* seq_idx, + int* global_idx, + int* global_token_nums, + int seq_len, + int batch_size, + void* cub_storage, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeInitLongformerIdx(__nv_bfloat16* global_attn_mask, + int* seq_idx, + int* global_idx, + int* global_token_nums, + int seq_len, + int batch_size, + void* cub_storage, + cudaStream_t stream); +#endif + // Apply softmax to local and global attention. Rewrite the result to the same buffer in-place template -__launch_bounds__(blockSize) __global__ void longformerMHASoftmaxKernel(const T* global_attn, +__launch_bounds__(blockSize) __global__ void longformerMHASoftmaxKernel(const T* global_attn, const int* global_idx, const int* global_token_nums, - void* input_ptrs, - const T* attn_mask, - float scaler, - int seq_len, - int head_num, - int attn_window_size) + void* input_ptrs, + const T* attn_mask, + float scaler, + int seq_len, + int head_num, + int attn_window_size) { - typedef cub::BlockReduce BlockReduce; + typedef cub::BlockReduce BlockReduce; __shared__ typename BlockReduce::TempStorage breduce_temp; size_t* p_inputs = (size_t*)(input_ptrs); // use input buffer as output buffer - size_t* p_outputs = (size_t*)(input_ptrs); - size_t* input_sizes = (size_t*)(input_ptrs) + 5; + size_t* p_outputs = (size_t*)(input_ptrs); + size_t* input_sizes = (size_t*)(input_ptrs) + 5; size_t* input_strides = (size_t*)(input_ptrs) + 10; - int tid = threadIdx.x; + int tid = threadIdx.x; const int batch_idx = blockIdx.x / (seq_len * head_num); - const int row_idx = blockIdx.x % seq_len; - const int head_idx = (blockIdx.x / seq_len) % head_num; + const int row_idx = blockIdx.x % seq_len; + const int head_idx = (blockIdx.x / seq_len) % head_num; // adjust the pointers for the batch - const T* mask_blk = attn_mask + seq_len * batch_idx; - const int global_num = global_token_nums[batch_idx]; + const T* mask_blk = attn_mask + seq_len * batch_idx; + const int global_num = global_token_nums[batch_idx]; const int* global_idx_blk = global_idx + seq_len * batch_idx; T* inputs[5]; T* outputs[5]; for (int i = 0; i < 5; ++i) { - inputs[i] = (T*)p_inputs[i] + batch_idx * head_num * input_sizes[i]; + inputs[i] = (T*)p_inputs[i] + batch_idx * head_num * input_sizes[i]; outputs[i] = (T*)p_outputs[i] + batch_idx * head_num * input_sizes[i]; } int col_start = 0; - int col_end = seq_len; + int col_end = seq_len; // is it local attention token int is_local_row = 
global_attn[row_idx + seq_len * batch_idx] == (T)0.f; @@ -170,7 +212,7 @@ __launch_bounds__(blockSize) __global__ void longformerMHASoftmaxKernel(const T* // if local token if (is_local_row) { col_start = row_idx - attn_window_size; - col_end = row_idx + attn_window_size + 1; + col_end = row_idx + attn_window_size + 1; } if (col_start < 0) { @@ -183,27 +225,27 @@ __launch_bounds__(blockSize) __global__ void longformerMHASoftmaxKernel(const T* // if mask is set then set everything to zero to match Python implementation if (mask_blk[row_idx] != (T)0.f) { if (is_local_row) { - T* output_blk = nullptr; - T* output_glb = nullptr; + T* output_blk = nullptr; + T* output_glb = nullptr; int local_offset = row_idx % attn_window_size; - int local_start = 0; - int local_end = 3 * attn_window_size; + int local_start = 0; + int local_end = 3 * attn_window_size; if (row_idx < attn_window_size) { local_start = 0; - local_end = 2 * attn_window_size; - output_blk = outputs[0] + row_idx * input_strides[0] + head_idx * input_sizes[0]; + local_end = 2 * attn_window_size; + output_blk = outputs[0] + row_idx * input_strides[0] + head_idx * input_sizes[0]; } else if (row_idx < seq_len - attn_window_size) { output_blk = outputs[1] + (row_idx - attn_window_size) * input_strides[1] + head_idx * input_sizes[1]; } else { local_start = 0; - local_end = 2 * attn_window_size; - output_blk = outputs[2] + local_offset * input_strides[2] + head_idx * input_sizes[2]; + local_end = 2 * attn_window_size; + output_blk = outputs[2] + local_offset * input_strides[2] + head_idx * input_sizes[2]; } for (int i = local_start + tid; i < local_end; i += blockSize) { - output_blk[i] = 0; + output_blk[i] = (T)0.f; } if ((row_idx - 2 * attn_window_size) >= 0) { @@ -212,57 +254,57 @@ __launch_bounds__(blockSize) __global__ void longformerMHASoftmaxKernel(const T* if (output_glb != nullptr) { for (int i = tid; i < global_num; i += blockSize) { - output_glb[i] = 0; + output_glb[i] = (T)0.f; } } } else { T* output_blk = outputs[4]; for (int i = tid; i < seq_len; i += blockSize) { - output_blk[i] = 0; + output_blk[i] = (T)0.f; } } return; } __shared__ float sum_shared; - float sum_input = 0.; + float sum_input = 0.; // calculate max input - float max_input = -FLT_MAX; + float max_input = -FLT_MAX; __shared__ float max_shared; if (is_local_row) { - const T* input_blk = nullptr; - T* output_blk = nullptr; - T* output_glb = nullptr; - int local_offset = row_idx % attn_window_size; - int local_start = local_offset; - int local_end = local_start + 2 * attn_window_size + 1; - int zero_start = 0; - int zero_end = 3 * attn_window_size; + const T* input_blk = nullptr; + T* output_blk = nullptr; + T* output_glb = nullptr; + int local_offset = row_idx % attn_window_size; + int local_start = local_offset; + int local_end = local_start + 2 * attn_window_size + 1; + int zero_start = 0; + int zero_end = 3 * attn_window_size; if (row_idx < attn_window_size) { local_start = 0; - local_end = local_offset + attn_window_size + 1; - zero_end = 2 * attn_window_size; + local_end = local_offset + attn_window_size + 1; + zero_end = 2 * attn_window_size; - input_blk = inputs[0] + row_idx * input_strides[0] + head_idx * input_sizes[0]; + input_blk = inputs[0] + row_idx * input_strides[0] + head_idx * input_sizes[0]; output_blk = outputs[0] + row_idx * input_strides[0] + head_idx * input_sizes[0]; } else if (row_idx < seq_len - attn_window_size) { - input_blk = inputs[1] + (row_idx - attn_window_size) * input_strides[1] + head_idx * input_sizes[1]; + input_blk = 
inputs[1] + (row_idx - attn_window_size) * input_strides[1] + head_idx * input_sizes[1]; output_blk = outputs[1] + (row_idx - attn_window_size) * input_strides[1] + head_idx * input_sizes[1]; } else { local_start = local_offset; - local_end = 2 * attn_window_size; - zero_end = 2 * attn_window_size; + local_end = 2 * attn_window_size; + zero_end = 2 * attn_window_size; - input_blk = inputs[2] + local_offset * input_strides[2] + head_idx * input_sizes[2]; + input_blk = inputs[2] + local_offset * input_strides[2] + head_idx * input_sizes[2]; output_blk = outputs[2] + local_offset * input_strides[2] + head_idx * input_sizes[2]; } - const T* input_glb = nullptr; - int local_global = row_idx - attn_window_size; + const T* input_glb = nullptr; + int local_global = row_idx - attn_window_size; if (local_global > global_num) { local_global = global_num; } @@ -282,7 +324,7 @@ __launch_bounds__(blockSize) __global__ void longformerMHASoftmaxKernel(const T* for (int i = local_start + tid, ii = col_start + tid; i < local_end; i += blockSize, ii += blockSize) { float x = (float)input_blk[i]; - x = x * scaler + (float)mask_blk[ii]; + x = x * scaler + (float)mask_blk[ii]; if (max_input < x) { max_input = x; } @@ -291,7 +333,7 @@ __launch_bounds__(blockSize) __global__ void longformerMHASoftmaxKernel(const T* if (input_glb != nullptr) { for (int i = tid; i < local_global; i += blockSize) { float x = (float)input_glb[global_idx_blk[i]]; - x = x * scaler + (float)mask_blk[global_idx_blk[i]]; + x = x * scaler + (float)mask_blk[global_idx_blk[i]]; if (max_input < x) { max_input = x; } @@ -306,14 +348,14 @@ __launch_bounds__(blockSize) __global__ void longformerMHASoftmaxKernel(const T* for (int i = local_start + tid, ii = col_start + tid; i < local_end; i += blockSize, ii += blockSize) { float x = (float)input_blk[i]; - x = expf(x * scaler + (float)mask_blk[ii] - max_shared); + x = expf(x * scaler + (float)mask_blk[ii] - max_shared); sum_input += x; } if (input_glb != nullptr) { for (int i = tid, ii = col_start + tid; i < local_global; i += blockSize, ii += blockSize) { float x = (float)input_glb[global_idx_blk[i]]; - x = expf(x * scaler + (float)mask_blk[ii] - max_shared); + x = expf(x * scaler + (float)mask_blk[ii] - max_shared); sum_input += x; } } @@ -336,27 +378,27 @@ __launch_bounds__(blockSize) __global__ void longformerMHASoftmaxKernel(const T* __syncthreads(); for (int i = local_start + tid, ii = col_start + tid; i < local_end; i += blockSize, ii += blockSize) { - float x = (float)input_blk[i]; - x = expf(x * scaler + (float)mask_blk[ii] - max_shared); + float x = (float)input_blk[i]; + x = expf(x * scaler + (float)mask_blk[ii] - max_shared); output_blk[i] = (T)(recip_sum * x); } if (input_glb != nullptr) { for (int i = tid; i < local_global; i += blockSize) { - float x = (float)input_glb[global_idx_blk[i]]; - x = expf(x * scaler + (float)mask_blk[global_idx_blk[i]] - max_shared); + float x = (float)input_glb[global_idx_blk[i]]; + x = expf(x * scaler + (float)mask_blk[global_idx_blk[i]] - max_shared); output_glb[i] = (T)(recip_sum * x); } } } else { // global tokens - const T* input_blk = inputs[4] + row_idx * input_strides[4] + head_idx * input_sizes[4]; - T* output_blk = outputs[4] + row_idx * input_strides[4] + head_idx * input_sizes[4]; + const T* input_blk = inputs[4] + row_idx * input_strides[4] + head_idx * input_sizes[4]; + T* output_blk = outputs[4] + row_idx * input_strides[4] + head_idx * input_sizes[4]; for (int i = tid; i < seq_len; i += blockSize) { float x = (float)input_blk[i]; - x = x 
* scaler + (float)mask_blk[i]; + x = x * scaler + (float)mask_blk[i]; if (max_input < x) { max_input = x; } @@ -370,7 +412,7 @@ __launch_bounds__(blockSize) __global__ void longformerMHASoftmaxKernel(const T* for (int i = tid; i < seq_len; i += blockSize) { float x = (float)input_blk[i]; - x = expf(x * scaler + (float)mask_blk[i] - max_shared); + x = expf(x * scaler + (float)mask_blk[i] - max_shared); sum_input += x; } @@ -382,28 +424,28 @@ __launch_bounds__(blockSize) __global__ void longformerMHASoftmaxKernel(const T* float recip_sum = 1.f / sum_shared; for (int i = tid; i < seq_len; i += blockSize) { - float x = (float)input_blk[i]; - x = expf(x * scaler + (float)mask_blk[i] - max_shared); + float x = (float)input_blk[i]; + x = expf(x * scaler + (float)mask_blk[i] - max_shared); output_blk[i] = (T)(recip_sum * x); } } } template -void invokeLongformerMHASoftmax(const T* global_attn_mask, - const int* global_idx, - const int* global_token_nums, - void* input_ptrs, - const T* local_attn_mask, - float scaler, - int seq_len, - int head_num, - int batch_size, - int local_attn_window_size, +void invokeLongformerMHASoftmax(const T* global_attn_mask, + const int* global_idx, + const int* global_token_nums, + void* input_ptrs, + const T* local_attn_mask, + float scaler, + int seq_len, + int head_num, + int batch_size, + int local_attn_window_size, cudaStream_t stream) { const int block_size = 64; - const int grid_size = seq_len * head_num * batch_size; + const int grid_size = seq_len * head_num * batch_size; longformerMHASoftmaxKernel<<>>(global_attn_mask, global_idx, global_token_nums, @@ -416,27 +458,41 @@ void invokeLongformerMHASoftmax(const T* global_attn_mask, } template void invokeLongformerMHASoftmax(const float* global_attn_mask, - const int* global_idx, - const int* global_token_nums, - void* input_ptrs, + const int* global_idx, + const int* global_token_nums, + void* input_ptrs, const float* local_attn_mask, - float scaler, - int seq_len, - int head_num, - int batch_size, - int local_attn_window_size, + float scaler, + int seq_len, + int head_num, + int batch_size, + int local_attn_window_size, cudaStream_t stream); -template void invokeLongformerMHASoftmax(const half* global_attn_mask, - const int* global_idx, - const int* global_token_nums, - void* input_ptrs, - const half* local_attn_mask, - float scaler, - int seq_len, - int head_num, - int batch_size, - int local_attn_window_size, +template void invokeLongformerMHASoftmax(const half* global_attn_mask, + const int* global_idx, + const int* global_token_nums, + void* input_ptrs, + const half* local_attn_mask, + float scaler, + int seq_len, + int head_num, + int batch_size, + int local_attn_window_size, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeLongformerMHASoftmax(const __nv_bfloat16* global_attn_mask, + const int* global_idx, + const int* global_token_nums, + void* input_ptrs, + const __nv_bfloat16* local_attn_mask, + float scaler, + int seq_len, + int head_num, + int batch_size, + int local_attn_window_size, + cudaStream_t stream); +#endif + } // namespace fastertransformer \ No newline at end of file diff --git a/src/fastertransformer/kernels/longformer_kernels.h b/src/fastertransformer/kernels/longformer_kernels.h index e84887beb..537dd85ce 100644 --- a/src/fastertransformer/kernels/longformer_kernels.h +++ b/src/fastertransformer/kernels/longformer_kernels.h @@ -25,26 +25,26 @@ template void invokeLocalAttnMaskShift(T* local_attn_mask, T* out, int batch_size, int seq_len, cudaStream_t stream); template 
-void invokeInitLongformerIdx(T* global_attn_mask, - int* seq_idx, - int* global_idx, - int* global_token_nums, - int seq_len, - int batch_size, - void* cub_storage, +void invokeInitLongformerIdx(T* global_attn_mask, + int* seq_idx, + int* global_idx, + int* global_token_nums, + int seq_len, + int batch_size, + void* cub_storage, cudaStream_t stream); template -void invokeLongformerMHASoftmax(const T* global_attn_mask, - const int* global_idx, - const int* global_token_nums, - void* input_ptrs, - const T* local_attn_mask, - float scaler, - int seq_len, - int head_num, - int batch_size, - int local_attn_window_size, +void invokeLongformerMHASoftmax(const T* global_attn_mask, + const int* global_idx, + const int* global_token_nums, + void* input_ptrs, + const T* local_attn_mask, + float scaler, + int seq_len, + int head_num, + int batch_size, + int local_attn_window_size, cudaStream_t stream); } // namespace fastertransformer \ No newline at end of file diff --git a/src/fastertransformer/kernels/matrix_transpose_kernels.cu b/src/fastertransformer/kernels/matrix_transpose_kernels.cu index e7f44c62d..67c01bdfd 100644 --- a/src/fastertransformer/kernels/matrix_transpose_kernels.cu +++ b/src/fastertransformer/kernels/matrix_transpose_kernels.cu @@ -24,10 +24,10 @@ template __global__ void matrix_transpose(T* dst, const T* src, const int k, const int n) { __shared__ T shm[32][33]; - const int tidx = threadIdx.x; - const int tidy = threadIdx.y; - int n_idx = blockIdx.x * 32 + tidx; - int k_idx = blockIdx.y * 32 + tidy; + const int tidx = threadIdx.x; + const int tidy = threadIdx.y; + int n_idx = blockIdx.x * 32 + tidx; + int k_idx = blockIdx.y * 32 + tidy; if (n_idx < n && k_idx < k) { shm[tidx][tidy] = src[k_idx * n + n_idx]; } diff --git a/src/fastertransformer/kernels/matrix_vector_multiplication.cu b/src/fastertransformer/kernels/matrix_vector_multiplication.cu index f88566990..cfb628088 100644 --- a/src/fastertransformer/kernels/matrix_vector_multiplication.cu +++ b/src/fastertransformer/kernels/matrix_vector_multiplication.cu @@ -56,14 +56,14 @@ __global__ void int8WeightPerChannelLdkMultiplication( const char4* weight, const float4* input, const float* scale_list, void* output, const int k_4) { - const int tidx = threadIdx.x; - const int bidx = blockIdx.x; - const int row_idx = bidx * nPerThread; + const int tidx = threadIdx.x; + const int bidx = blockIdx.x; + const int row_idx = bidx * nPerThread; const size_t b_offset = row_idx * k_4; - using array = struct ARRAY; + using array = struct ARRAY; const array scale = *((const array*)scale_list + bidx); - array sum_list[m]; + array sum_list[m]; #pragma unroll for (int m_i = 0; m_i < m; m_i++) { #pragma unroll @@ -119,9 +119,9 @@ __global__ void int8WeightPerChannelLdkMultiplication( const char4* weight, const half4* input, const float* scale_list, void* output, const int k_4) { - const int tidx = threadIdx.x; - const int bidx = blockIdx.x; - const int row_idx = bidx * nPerThread; + const int tidx = threadIdx.x; + const int bidx = blockIdx.x; + const int row_idx = bidx * nPerThread; const size_t b_offset = row_idx * k_4; using array = struct ARRAY; @@ -140,12 +140,12 @@ __global__ void int8WeightPerChannelLdkMultiplication( #pragma unroll for (int m_i = 0; m_i < m; m_i++) { const half4 input_val = input[k_idx + m_i * k_4]; - input_val_0[m_i] = {input_val.x, input_val.y}; - input_val_1[m_i] = {input_val.z, input_val.w}; + input_val_0[m_i] = {input_val.x, input_val.y}; + input_val_1[m_i] = {input_val.z, input_val.w}; } #pragma unroll for (int i = 
0; i < nPerThread; i++) { - const char4 weight_val = weight[b_offset + i * k_4 + k_idx]; + const char4 weight_val = weight[b_offset + i * k_4 + k_idx]; const half2 weight_val_0 = {static_cast(weight_val.x), static_cast(weight_val.y)}; const half2 weight_val_1 = {static_cast(weight_val.z), static_cast(weight_val.w)}; #pragma unroll @@ -162,7 +162,7 @@ __global__ void int8WeightPerChannelLdkMultiplication( __syncthreads(); } if (tidx == 0) { - using array_half = struct ARRAY; + using array_half = struct ARRAY; const array scale = *((const array*)scale_list + bidx); #pragma unroll for (int m_i = 0; m_i < m; m_i++) { @@ -182,9 +182,9 @@ __global__ void int8WeightPerChannelLdkMultiplication( const char4* weight, const bf164* input, const float* scale_list, void* output, const int k_4) { - const int tidx = threadIdx.x; - const int bidx = blockIdx.x; - const int row_idx = bidx * nPerThread; + const int tidx = threadIdx.x; + const int bidx = blockIdx.x; + const int row_idx = bidx * nPerThread; const size_t b_offset = row_idx * k_4; using array = struct ARRAY; @@ -203,8 +203,8 @@ __global__ void int8WeightPerChannelLdkMultiplication( #pragma unroll for (int m_i = 0; m_i < m; m_i++) { const bf164 input_val = input[k_idx + m_i * k_4]; - input_val_0[m_i] = {input_val.x, input_val.y}; - input_val_1[m_i] = {input_val.z, input_val.w}; + input_val_0[m_i] = {input_val.x, input_val.y}; + input_val_1[m_i] = {input_val.z, input_val.w}; } #pragma unroll for (int i = 0; i < nPerThread; i++) { @@ -234,7 +234,7 @@ __global__ void int8WeightPerChannelLdkMultiplication( __syncthreads(); } if (tidx == 0) { - using array_half = struct ARRAY; + using array_half = struct ARRAY; const array scale = *((const array*)scale_list + bidx); #pragma unroll for (int m_i = 0; m_i < m; m_i++) { @@ -256,17 +256,17 @@ __global__ void int8WeightPerChannelLdkMultiplication( template void int8WeightPerChannelLdkMultiplicationLauncher(const int8_t* weight, - const T* input, - const float* scale_list, - T* output, - const int m, - const int n, - const int k, - cudaStream_t stream) + const T* input, + const float* scale_list, + T* output, + const int m, + const int n, + const int k, + cudaStream_t stream) { const int nPerThread = 2; if ((n % nPerThread != 0) || (k % 4 != 0)) { - printf("[ERROR][int8WeightPerChannelLdkMultiplicationLauncher] (%d % %d != 0) || (%d % 4 != 0).\n", + printf("[ERROR][int8WeightPerChannelLdkMultiplicationLauncher] (%d %% %d != 0) || (%d %% 4 != 0).\n", n, nPerThread, k); @@ -285,10 +285,10 @@ void int8WeightPerChannelLdkMultiplicationLauncher(const int8_t* weight, else { block.x = 64; } - while (block.x * 4 > k) { + while (block.x * 4 > (size_t)k) { block.x /= 2; } - block.x = (block.x + 31) / 32 * 32; + block.x = (block.x + 31) / 32 * 32; const size_t shm_size = block.x * nPerThread * sizeof(float); if (m == 1) { if (std::is_same::value) { @@ -323,32 +323,32 @@ void int8WeightPerChannelLdkMultiplicationLauncher(const int8_t* weight, } template void int8WeightPerChannelLdkMultiplicationLauncher(const int8_t* matrix, - const float* vector, - const float* scale_list, - float* output, - const int m, - const int n, - const int k, - cudaStream_t stream); + const float* vector, + const float* scale_list, + float* output, + const int m, + const int n, + const int k, + cudaStream_t stream); template void int8WeightPerChannelLdkMultiplicationLauncher(const int8_t* matrix, - const half* vector, - const float* scale_list, - half* output, - const int m, - const int n, - const int k, - cudaStream_t stream); + const half* 
vector, + const float* scale_list, + half* output, + const int m, + const int n, + const int k, + cudaStream_t stream); #ifdef ENABLE_BF16 -template void int8WeightPerChannelLdkMultiplicationLauncher(const int8_t* matrix, +template void int8WeightPerChannelLdkMultiplicationLauncher(const int8_t* matrix, const __nv_bfloat16* vector, - const float* scale_list, - __nv_bfloat16* output, - const int m, - const int n, - const int k, - cudaStream_t stream); + const float* scale_list, + __nv_bfloat16* output, + const int m, + const int n, + const int k, + cudaStream_t stream); #endif ///////////////////////////////////////////////////////////////////// diff --git a/src/fastertransformer/kernels/matrix_vector_multiplication.h b/src/fastertransformer/kernels/matrix_vector_multiplication.h index 7923d6e60..e6b0fb1f3 100644 --- a/src/fastertransformer/kernels/matrix_vector_multiplication.h +++ b/src/fastertransformer/kernels/matrix_vector_multiplication.h @@ -24,12 +24,12 @@ namespace fastertransformer { template void int8WeightPerChannelLdkMultiplicationLauncher(const int8_t* weight, - const T* input, - const float* scale_list, - T* output, - const int m, - const int n, - const int k, - cudaStream_t stream); + const T* input, + const float* scale_list, + T* output, + const int m, + const int n, + const int k, + cudaStream_t stream); } // namespace fastertransformer diff --git a/src/fastertransformer/kernels/online_softmax_beamsearch_kernels.cu b/src/fastertransformer/kernels/online_softmax_beamsearch_kernels.cu index b12bbf6bb..bce35f8c9 100644 --- a/src/fastertransformer/kernels/online_softmax_beamsearch_kernels.cu +++ b/src/fastertransformer/kernels/online_softmax_beamsearch_kernels.cu @@ -33,12 +33,19 @@ static const int SMALL_TOP_K_SOFTMAX_THREADBLOCK_SIZE = 256; #define TOPK_FP16_STORAGE 0 +template +__device__ __forceinline__ T apply_length_penalty(T log_prob, int length, float length_penalty) +{ + // score = log(prob) / (length)^length_penalty. 
+ return log_prob / static_cast(powf(length, length_penalty)); +} + template __launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topK_kernel(int* topk_tmp_id_buf, T* topk_tmp_val_buf, int* id_buf) { - int thread_id = threadIdx.x; - int block_id = blockIdx.x; + int thread_id = threadIdx.x; + int block_id = blockIdx.x; TopK partial; if (thread_id == 0) { for (int i = 0; i < MAX_K; ++i) { @@ -64,8 +71,8 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topK_kernel(const int* int* __restrict id_buf, T* __restrict val_buf) { - int thread_id = threadIdx.x; - int block_id = blockIdx.x; + int thread_id = threadIdx.x; + int block_id = blockIdx.x; TopK partial; if (thread_id == 0) { for (int i = 0; i < MAX_K; ++i) { @@ -80,7 +87,7 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topK_kernel(const int* index = block_id * MAX_K; for (int i = 0; i < MAX_K; i++) { - id_buf[index + i] = partial.p[i]; + id_buf[index + i] = partial.p[i]; val_buf[index + i] = partial.u[i]; } } @@ -91,11 +98,14 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topk_kernel(const int* const T* __restrict y, int* __restrict z, float* __restrict v, - float* output_log_probs, - int V, - int K, - int vocab_size, - T diversity_rate) + float* output_log_probs, + const bool* finished, + const int* sequence_lengths, + const int V, + const int K, + const int vocab_size, + const float length_penalty, + const T diversity_rate) { int thread_id = threadIdx.x; int vector_id = blockIdx.x; @@ -113,9 +123,15 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topk_kernel(const int* partial.p[i] = -1; partial.u[i] = -FLT_MAX; } + for (int elem_id = thread_id; elem_id < V; elem_id += THREADBLOCK_SIZE) { - int i = elem_id % K; - T elem = y[elem_id] + diversity_rate * (T)i; + int i = elem_id % K; + T elem = length_penalty == 0.0f ? y[elem_id] : + apply_length_penalty(y[elem_id], + finished[vector_id] ? sequence_lengths[vector_id] : + sequence_lengths[vector_id] + 1, + length_penalty); + elem += diversity_rate * (T)i; int elem_idx = elem_id; // x[elem_id]; partial.insert(elem, elem_idx); } @@ -146,10 +162,10 @@ struct __align__(8) MD __device__ __forceinline__ MD reduce_md_op(MD a, MD b) { - bool a_bigger = (a.m > b.m); - MD bigger_m = a_bigger ? a : b; - MD smaller_m = a_bigger ? b : a; - MD res; + bool a_bigger = (a.m > b.m); + MD bigger_m = a_bigger ? a : b; + MD smaller_m = a_bigger ? b : a; + MD res; res.d = bigger_m.d + smaller_m.d * __expf(smaller_m.m - bigger_m.m); res.m = bigger_m.m; return res; @@ -157,7 +173,7 @@ __device__ __forceinline__ MD reduce_md_op(MD a, MD b) template struct TopKMD { - MD md; + MD md; TopK topk; }; @@ -165,7 +181,7 @@ template __device__ __forceinline__ TopKMD reduce_topk_md_op(const TopKMD& a, const TopKMD& b) { TopKMD res; - res.md = reduce_md_op(a.md, b.md); + res.md = reduce_md_op(a.md, b.md); res.topk = reduce_topk_op(a.topk, b.topk); return res; } @@ -184,17 +200,17 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_ker int thread_id = threadIdx.x; int vector_id = blockIdx.x; - const bool IS_FP16 = std::is_same::value; - const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; + const bool IS_FP16 = std::is_same::value; + const T MAX_T_VAL = (IS_FP16) ? 
HALF_FLT_MAX : FLT_MAX; // reposition y to data for the current vector x += vector_id * V; typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ typename BlockReduce::TempStorage temp_storage; TopKMD partial; - bool finish = finished[vector_id]; + bool finish = finished[vector_id]; for (int i = 0; i < MAX_K; ++i) { partial.topk.p[i] = -1; partial.topk.u[i] = -MAX_T_VAL; @@ -205,7 +221,7 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_ker if (finish) { for (int elem_id = thread_id; elem_id < V; elem_id += THREADBLOCK_SIZE) { float elem = (elem_id == end_ids[vector_id / K]) ? MAX_T_VAL : -MAX_T_VAL; - MD new_elem{elem, 1.0F}; + MD new_elem{elem, 1.0F}; partial.md = reduce_md_op(partial.md, new_elem); partial.topk.insert(elem, elem_id); // if (elem_id > THREADBLOCK_SIZE * MAX_K && (elem_id == E)) break; @@ -214,7 +230,7 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_ker else { for (int elem_id = thread_id; elem_id < V; elem_id += THREADBLOCK_SIZE) { float elem = x[elem_id] + b[elem_id]; - MD new_elem{elem, 1.0F}; + MD new_elem{elem, 1.0F}; partial.md = reduce_md_op(partial.md, new_elem); partial.topk.insert(elem, elem_id); } @@ -251,18 +267,18 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ const int* __restrict end_ids) { int thread_id = threadIdx.x; - int vector_id = blockIdx.x; + int vector_id = blockIdx.x; // batch beam index. const int PACKED_TOP_KMD_SIZE = 2 * MAX_K + 2; - const bool IS_FP16 = std::is_same::value; - const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; + const bool IS_FP16 = std::is_same::value; + const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; // one will have multiple sections per V - const int v_local = (V + gridDim.y - 1) / gridDim.y; + const int v_local = (V + gridDim.y - 1) / gridDim.y; const int section_start = v_local * blockIdx.y; - int section_end = section_start + v_local; - section_end = (section_end > V) ? V : section_end; + int section_end = section_start + v_local; + section_end = (section_end > V) ? V : section_end; // reposition x to data for the current vector x += vector_id * V; @@ -272,12 +288,12 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce; #endif __shared__ typename BlockReduce::TempStorage temp_storage; - __shared__ float buf_s[PACKED_TOP_KMD_SIZE]; // save intermediate result + __shared__ float buf_s[PACKED_TOP_KMD_SIZE]; // save intermediate result #if TOPK_FP16_STORAGE == 1 TopKMD<__half, MAX_K> partial; #else - TopKMD partial; + TopKMD partial; #endif bool finish = finished[vector_id]; for (int i = 0; i < MAX_K; ++i) { @@ -291,7 +307,7 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ #pragma unroll 1 for (int elem_id = section_start + thread_id; elem_id < section_end; elem_id += THREADBLOCK_SIZE) { float elem = (elem_id == end_ids[vector_id / K]) ? MAX_T_VAL : -MAX_T_VAL; - MD new_elem{elem, 1.0F}; + MD new_elem{elem, 1.0F}; partial.md = reduce_md_op(partial.md, new_elem); partial.topk.insert(elem, elem_id); } @@ -299,8 +315,8 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ else { #pragma unroll 1 for (int elem_id = section_start + thread_id; elem_id < section_end; elem_id += THREADBLOCK_SIZE) { - T bias = b == nullptr ? (T)0.0f : b[elem_id]; // gpt-2 does not use bias - T elem = x[elem_id] + bias; + T bias = b == nullptr ? 
(T)0.0f : b[elem_id]; // gpt-2 does not use bias + T elem = x[elem_id] + bias; MD new_elem{elem, 1.0F}; partial.md = reduce_md_op(partial.md, new_elem); partial.topk.insert(elem, elem_id); @@ -316,9 +332,9 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ if (thread_id == 0) { for (int i = 0; i < K; i++) { reinterpret_cast(buf_s)[i] = total.topk.p[i] + vector_id * V; // faster transformer needs absolute id - buf_s[MAX_K + i] = total.topk.u[i]; + buf_s[MAX_K + i] = total.topk.u[i]; } - buf_s[2 * MAX_K] = total.md.d; + buf_s[2 * MAX_K] = total.md.d; buf_s[2 * MAX_K + 1] = total.md.m; } __syncthreads(); @@ -332,19 +348,19 @@ template __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_stage2_kernel( const float* __restrict x, const float* __restrict c, int* __restrict z, T* __restrict v, int K, int parts_per_beam) { - const int vector_id = blockIdx.x; - const int thread_id = threadIdx.x; + const int vector_id = blockIdx.x; + const int thread_id = threadIdx.x; const int PACKED_TOP_KMD_SIZE = 2 * MAX_K + 2; - const bool IS_FP16 = std::is_same::value; - const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; + const bool IS_FP16 = std::is_same::value; + const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; extern __shared__ char buf_s_[]; // intermediate result - float* buf_s = reinterpret_cast(buf_s_); + float* buf_s = reinterpret_cast(buf_s_); //__shared__ float buf_s[PACKED_TOP_KMD_SIZE * THREADBLOCK_SIZE]; // intermediate result typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ typename BlockReduce::TempStorage temp_storage; x += vector_id * PACKED_TOP_KMD_SIZE * parts_per_beam; @@ -394,11 +410,11 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_sta template void beam_online_softmax_topk_stage2_kernelLauncher(const float* temp_storage, const float* cum_log_probs, - int* ids, - T* vals, - int batch_size, - int beam_width, - int parts_per_beam, + int* ids, + T* vals, + int batch_size, + int beam_width, + int parts_per_beam, cudaStream_t stream) { // might rewrite beam_online_softmax_topk_stage2_kernel no to depend on constant block size @@ -425,39 +441,43 @@ void beam_online_softmax_topk_stage2_kernelLauncher(const float* temp_storage, } template -void topK_softMax_kernelLauncher(const T* log_probs, - const T* bias, - const bool* finished, - float* cum_log_probs, - float* output_log_probs, - int* ids, - void* temp_storage, - const int temp_storage_size, - const int batch_size, - const int beam_width, - const int vocab_size, - const int* end_ids, - T diversity_rate, +void topK_softMax_kernelLauncher(const T* log_probs, + const T* bias, + const bool* finished, + const int* sequence_lengths, + float* cum_log_probs, + float* output_log_probs, + int* ids, + void* temp_storage, + const int temp_storage_size, + const int batch_size, + const int beam_width, + const int vocab_size, + const int* end_ids, + T diversity_rate, + const float length_penalty, cudaStream_t stream) { const int items_per_thread = 1; - const int block_sz = (MAX_K < 16) ? (MAX_K < 8) ? SMALL_TOP_K_SOFTMAX_THREADBLOCK_SIZE : 128 : 64; + const int block_sz = (MAX_K < 16) ? (MAX_K < 8) ? SMALL_TOP_K_SOFTMAX_THREADBLOCK_SIZE : 128 : 64; // const int block_sz = SMALL_TOP_K_SOFTMAX_THREADBLOCK_SIZE; assert(temp_storage_size % 2 == 0); assert(temp_storage_size >= 2 * batch_size * beam_width * beam_width); + // Beam search needs the sequence lengths of beams to apply length penalty. 
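Since apply_length_penalty divides by length^length_penalty, a zero penalty leaves the scores untouched (which is why the assert that follows only requires sequence_lengths when the penalty is non-zero), while a penalty of 1.0 effectively turns a cumulative log probability into an average per-token log probability, removing the bias toward shorter beams. As in the kernel above, finished beams use their current length and live beams use length + 1, because the token being scored would extend them. A small host-side illustration (function name and values are illustrative, not from this patch):

#include <cmath>
#include <cstdio>

// Mirrors the device helper added above: score = log_prob / length^length_penalty.
float applyLengthPenaltyRef(float log_prob, int length, float length_penalty)
{
    return log_prob / std::pow(static_cast<float>(length), length_penalty);
}

int main()
{
    std::printf("%f\n", applyLengthPenaltyRef(-4.0f, 4, 1.0f));  // -1.0: 4 tokens, average -1 each
    std::printf("%f\n", applyLengthPenaltyRef(-8.0f, 8, 1.0f));  // -1.0: longer beam is no longer penalized
    std::printf("%f\n", applyLengthPenaltyRef(-8.0f, 8, 0.0f));  // -8.0: penalty disabled
    return 0;
}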
+ assert(length_penalty == 0.0f || sequence_lengths != nullptr); - const int topk_buf_offset = ceil(batch_size * beam_width * beam_width / 4.) * 4; - int* topk_tmp_id_buf = reinterpret_cast(temp_storage); - T* topk_tmp_val_buf = reinterpret_cast(topk_tmp_id_buf + topk_buf_offset); - float* tmp_buffer = reinterpret_cast(topk_tmp_val_buf + topk_buf_offset); + const int topk_buf_offset = ceil(batch_size * beam_width * beam_width / 4.) * 4; + int* topk_tmp_id_buf = reinterpret_cast(temp_storage); + T* topk_tmp_val_buf = reinterpret_cast(topk_tmp_id_buf + topk_buf_offset); + float* tmp_buffer = reinterpret_cast(topk_tmp_val_buf + topk_buf_offset); #ifdef DO_SPLIT_SMALL_TOP_K_SOFTMAX int voc_parts = 4; if (batch_size * beam_width < 256) { // Volta has 80 SMs, so we aim for three waves voc_parts = (240 + batch_size * beam_width - 1) / (batch_size * beam_width); - voc_parts = std::min(128, voc_parts); // we implment up to 128 + voc_parts = std::min(128, voc_parts); // we implement up to 128 } dim3 grid(batch_size * beam_width, voc_parts); cudaFuncSetAttribute(beam_online_softmax_topk_stage1_kernel, @@ -492,9 +512,12 @@ void topK_softMax_kernelLauncher(const T* log_probs, ids, cum_log_probs, output_log_probs, + finished, + sequence_lengths, beam_width * beam_width, beam_width, vocab_size, + length_penalty, diversity_rate); #endif } @@ -511,170 +534,91 @@ void topK_softMax_kernelLauncher(const T* log_probs, } } +#define CASE_K(K, MAX_K) \ + case K: \ + topK_softMax_kernelLauncher(log_probs, \ + bias, \ + finished, \ + sequence_lengths, \ + cum_log_probs, \ + output_log_probs, \ + ids, \ + temp_storage, \ + temp_storage_size, \ + batch_size, \ + beam_width, \ + vocab_size, \ + end_ids, \ + diversity_rate, \ + length_penalty, \ + stream); \ + break; + template -void invokeTopkSoftMax(const T* log_probs, - const T* bias, - const bool* finished, - float* cum_log_probs, - float* output_log_probs, - int* ids, - void* temp_storage, - const int temp_storage_size, - const int batch_size, - const int beam_width, - const int vocab_size, - const int* end_ids, - const float diversity_rate, +void invokeTopkSoftMax(const T* log_probs, + const T* bias, + const bool* finished, + const int* sequence_lengths, + float* cum_log_probs, + float* output_log_probs, + int* ids, + void* temp_storage, + const int temp_storage_size, + const int batch_size, + const int beam_width, + const int vocab_size, + const int* end_ids, + const float diversity_rate, + const float length_penalty, cudaStream_t stream) { switch (beam_width) { - case 1: - topK_softMax_kernelLauncher(log_probs, - bias, - finished, - cum_log_probs, - output_log_probs, - ids, - temp_storage, - temp_storage_size, - batch_size, - beam_width, - vocab_size, - end_ids, - diversity_rate, - stream); - break; - case 2: - topK_softMax_kernelLauncher(log_probs, - bias, - finished, - cum_log_probs, - output_log_probs, - ids, - temp_storage, - temp_storage_size, - batch_size, - beam_width, - vocab_size, - end_ids, - diversity_rate, - stream); - break; - case 3: - topK_softMax_kernelLauncher(log_probs, - bias, - finished, - cum_log_probs, - output_log_probs, - ids, - temp_storage, - temp_storage_size, - batch_size, - beam_width, - vocab_size, - end_ids, - diversity_rate, - stream); - break; - case 4: - topK_softMax_kernelLauncher(log_probs, - bias, - finished, - cum_log_probs, - output_log_probs, - ids, - temp_storage, - temp_storage_size, - batch_size, - beam_width, - vocab_size, - end_ids, - diversity_rate, - stream); - break; - case 8: - 
topK_softMax_kernelLauncher(log_probs, - bias, - finished, - cum_log_probs, - output_log_probs, - ids, - temp_storage, - temp_storage_size, - batch_size, - beam_width, - vocab_size, - end_ids, - diversity_rate, - stream); - break; - case 16: - topK_softMax_kernelLauncher(log_probs, - bias, - finished, - cum_log_probs, - output_log_probs, - ids, - temp_storage, - temp_storage_size, - batch_size, - beam_width, - vocab_size, - end_ids, - diversity_rate, - stream); - break; - case 32: - topK_softMax_kernelLauncher(log_probs, - bias, - finished, - cum_log_probs, - output_log_probs, - ids, - temp_storage, - temp_storage_size, - batch_size, - beam_width, - vocab_size, - end_ids, - diversity_rate, - stream); - break; + CASE_K(1, 1); + CASE_K(2, 2); + CASE_K(3, 3); + CASE_K(4, 4); + CASE_K(8, 8); + CASE_K(16, 16); + CASE_K(32, 32); default: - printf("[ERROR] Topk kernel does not support beamwidth = %d \n", beam_width); - exit(0); - break; + throw std::runtime_error(fmtstr("Topk kernel does not support beam_width=%d", beam_width)); } } +#undef CASE_K + template void invokeTopkSoftMax(const float* log_probs, const float* bias, - const bool* finished, - float* cum_log_probs, - float* output_log_probs, - int* ids, - void* tmp_storage, - const int temp_storage_size, - const int batch_size, - const int beam_width, - const int vocab_size, - const int* end_ids, - const float diversity_rate, + const bool* finished, + const int* sequence_lengths, + float* cum_log_probs, + float* output_log_probs, + int* ids, + void* tmp_storage, + const int temp_storage_size, + const int batch_size, + const int beam_width, + const int vocab_size, + const int* end_ids, + const float diversity_rate, + const float length_penalty, cudaStream_t stream); -template void invokeTopkSoftMax(const half* log_probs, - const half* bias, - const bool* finished, - float* cum_log_probs, - float* output_log_probs, - int* ids, - void* tmp_storage, - const int temp_storage_size, - const int batch_size, - const int beam_width, - const int vocab_size, - const int* end_ids, - const float diversity_rate, +template void invokeTopkSoftMax(const half* log_probs, + const half* bias, + const bool* finished, + const int* sequence_lengths, + float* cum_log_probs, + float* output_log_probs, + int* ids, + void* tmp_storage, + const int temp_storage_size, + const int batch_size, + const int beam_width, + const int vocab_size, + const int* end_ids, + const float diversity_rate, + const float length_penalty, cudaStream_t stream); -} // end of namespace fastertransformer \ No newline at end of file +} // end of namespace fastertransformer diff --git a/src/fastertransformer/kernels/online_softmax_beamsearch_kernels.h b/src/fastertransformer/kernels/online_softmax_beamsearch_kernels.h index 5407865f6..415705c4e 100644 --- a/src/fastertransformer/kernels/online_softmax_beamsearch_kernels.h +++ b/src/fastertransformer/kernels/online_softmax_beamsearch_kernels.h @@ -18,19 +18,21 @@ namespace fastertransformer { template -void invokeTopkSoftMax(const T* log_probs, - const T* bias, - const bool* finished, - float* cum_log_probs, - float* output_log_probs, - int* ids, - void* tmp_storage, - const int temp_storage_size, - const int batch_size, - const int beam_width, - const int vocab_size, - const int* end_ids, - const float diversity_rate, +void invokeTopkSoftMax(const T* log_probs, + const T* bias, + const bool* finished, + const int* sequence_lengths, + float* cum_log_probs, + float* output_log_probs, + int* ids, + void* tmp_storage, + const int temp_storage_size, + 
const int batch_size, + const int beam_width, + const int vocab_size, + const int* end_ids, + const float diversity_rate, + const float length_penalty, cudaStream_t stream); } // namespace fastertransformer diff --git a/src/fastertransformer/kernels/quantization_int8_kernels.cu b/src/fastertransformer/kernels/quantization_int8_kernels.cu index 6ef3897ca..89641cf2d 100644 --- a/src/fastertransformer/kernels/quantization_int8_kernels.cu +++ b/src/fastertransformer/kernels/quantization_int8_kernels.cu @@ -21,14 +21,14 @@ __global__ void quantized_kernel(char4* dst, const float4* src, const int size_d { int tid = (blockIdx.x * blockDim.x + threadIdx.x); if (tid < size_div_4) { - const float scale = __ldg(scale_ptr); - char4 tmp; + const float scale = __ldg(scale_ptr); + char4 tmp; const float4 floatTmp = __ldg(src + tid); - tmp.x = float_to_int8_rn(floatTmp.x * scale); - tmp.y = float_to_int8_rn(floatTmp.y * scale); - tmp.z = float_to_int8_rn(floatTmp.z * scale); - tmp.w = float_to_int8_rn(floatTmp.w * scale); - dst[tid] = tmp; + tmp.x = float_to_int8_rn(floatTmp.x * scale); + tmp.y = float_to_int8_rn(floatTmp.y * scale); + tmp.z = float_to_int8_rn(floatTmp.z * scale); + tmp.w = float_to_int8_rn(floatTmp.w * scale); + dst[tid] = tmp; } } @@ -37,17 +37,17 @@ __global__ void quantized_kernel(char4* dst, const half2* src, const int size_di int tid = (blockIdx.x * blockDim.x + threadIdx.x); if (tid < size_div_4) { const float scale = __ldg(scale_ptr); - char4 tmp; - int src_id = tid << 1; + char4 tmp; + int src_id = tid << 1; const half2 half2Tmp = __ldg(src + src_id); - tmp.x = float_to_int8_rn(static_cast(half2Tmp.x) * scale); - tmp.y = float_to_int8_rn(static_cast(half2Tmp.y) * scale); + tmp.x = float_to_int8_rn(static_cast(half2Tmp.x) * scale); + tmp.y = float_to_int8_rn(static_cast(half2Tmp.y) * scale); const half2 half2Tmp2 = __ldg(src + src_id + 1); - tmp.z = float_to_int8_rn(static_cast(half2Tmp2.x) * scale); - tmp.w = float_to_int8_rn(static_cast(half2Tmp2.y) * scale); - dst[tid] = tmp; + tmp.z = float_to_int8_rn(static_cast(half2Tmp2.x) * scale); + tmp.w = float_to_int8_rn(static_cast(half2Tmp2.y) * scale); + dst[tid] = tmp; } } diff --git a/src/fastertransformer/kernels/quantize_weight.cu b/src/fastertransformer/kernels/quantize_weight.cu index 84b4af84f..86945c1f4 100644 --- a/src/fastertransformer/kernels/quantize_weight.cu +++ b/src/fastertransformer/kernels/quantize_weight.cu @@ -35,25 +35,25 @@ __device__ __host__ int index_CUBLASLT_ORDER_COL4_4R2_8C(int col_id, int row_id, __device__ __host__ int index_CUBLASLT_ORDER_COL32_2R_4R4(int col_id, int row_id, int m_32) { - int new_col = col_id >> 5; + int new_col = col_id >> 5; int row_in_tile = row_id & 31; int col_in_tile = col_id & 31; - int new_row = // CUBLASLT_ORDER_COL32_2R_4R4 + int new_row = // CUBLASLT_ORDER_COL32_2R_4R4 (((row_id >> 5) << 10) + //(((row%8)/2*4+row/8)*2+row%2)*32+col (((((((row_in_tile & 7) >> 1) << 2) + (row_in_tile >> 3)) << 1) + (row_in_tile & 1)) << 5) + col_in_tile); return new_col * m_32 + new_row; } -__global__ void quantize_weight_kernel(int8_t* dst, +__global__ void quantize_weight_kernel(int8_t* dst, const float* src, const float* amax, - const int n, - const int k, - const int format, - const int scale_is_vector) + const int n, + const int k, + const int format, + const int scale_is_vector) { - int tid = (blockIdx.x * blockDim.x + threadIdx.x); + int tid = (blockIdx.x * blockDim.x + threadIdx.x); int col_idx = tid / n; int row_idx = tid - col_idx * n; int new_idx; @@ -67,20 +67,20 @@ __global__ void 
quantize_weight_kernel(int8_t* dst, new_idx = index_CUBLASLT_ORDER_COL4_4R2_8C(col_idx, row_idx, 32 * n); } if (tid < n * k) { - int8_t v = float_to_int8_rn(src[tid] * 127.0 / amax[row_idx * scale_is_vector]); + int8_t v = float_to_int8_rn(src[tid] * 127.0 / amax[row_idx * scale_is_vector]); dst[new_idx] = v; } } -__global__ void quantize_weight_kernel(int8_t* dst, - const half* src, +__global__ void quantize_weight_kernel(int8_t* dst, + const half* src, const float* amax, - const int n, - const int k, - const int format, - const int scale_is_vector) + const int n, + const int k, + const int format, + const int scale_is_vector) { - int tid = (blockIdx.x * blockDim.x + threadIdx.x); + int tid = (blockIdx.x * blockDim.x + threadIdx.x); int col_idx = tid / n; int row_idx = tid - col_idx * n; int new_idx; @@ -94,20 +94,20 @@ __global__ void quantize_weight_kernel(int8_t* dst, new_idx = index_CUBLASLT_ORDER_COL4_4R2_8C(col_idx, row_idx, 32 * n); } if (tid < n * k) { - int8_t v = float_to_int8_rn(__half2float(src[tid]) * 127.0 / amax[row_idx * scale_is_vector]); + int8_t v = float_to_int8_rn(__half2float(src[tid]) * 127.0 / amax[row_idx * scale_is_vector]); dst[new_idx] = v; } } template -void invokeQuantizeWeight(int8_t* dst, - const T* src, +void invokeQuantizeWeight(int8_t* dst, + const T* src, const float* amax, - const int n, - const int k, - const int format, + const int n, + const int k, + const int format, cudaStream_t stream, - const int scale_is_vector) + const int scale_is_vector) { if (format != 0 && format != 1 & format != 2) { printf("[ERROR][invokeQuantizeWeight] format must be one of 0, 1, 2. current value: %d\n", format); @@ -128,22 +128,22 @@ void invokeQuantizeWeight(int8_t* dst, } } -template void invokeQuantizeWeight(int8_t* dst, +template void invokeQuantizeWeight(int8_t* dst, const float* src, const float* amax, - const int n, - const int k, - const int format, + const int n, + const int k, + const int format, cudaStream_t stream, - const int scale_is_vector); + const int scale_is_vector); -template void invokeQuantizeWeight(int8_t* dst, - const half* src, +template void invokeQuantizeWeight(int8_t* dst, + const half* src, const float* amax, - const int n, - const int k, - const int format, + const int n, + const int k, + const int format, cudaStream_t stream, - const int scale_is_vector); + const int scale_is_vector); } // namespace fastertransformer \ No newline at end of file diff --git a/src/fastertransformer/kernels/quantize_weight.h b/src/fastertransformer/kernels/quantize_weight.h index d3ae0b787..207a6afba 100644 --- a/src/fastertransformer/kernels/quantize_weight.h +++ b/src/fastertransformer/kernels/quantize_weight.h @@ -27,13 +27,13 @@ namespace fastertransformer { // 1: CUBLASLT_ORDER_COL32_2R_4R4 // 2: CUBLASLT_ORDER_COL4_4R2_8C template -void invokeQuantizeWeight(int8_t* dst, - const T* src, +void invokeQuantizeWeight(int8_t* dst, + const T* src, const float* amax, - const int n, - const int k, - const int format, + const int n, + const int k, + const int format, cudaStream_t stream, - const int scale_is_vector = 1); + const int scale_is_vector = 1); } // namespace fastertransformer diff --git a/src/fastertransformer/kernels/reduce_kernel_utils.cuh b/src/fastertransformer/kernels/reduce_kernel_utils.cuh index b3ee2e04e..7af2e20d2 100644 --- a/src/fastertransformer/kernels/reduce_kernel_utils.cuh +++ b/src/fastertransformer/kernels/reduce_kernel_utils.cuh @@ -22,10 +22,12 @@ #include #endif #include +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" 
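Both the activation path (quantized_kernel above) and invokeQuantizeWeight reduce to the same scalar step: symmetric round-to-nearest int8 with saturation, using a per-tensor scale for activations and a per-row 127/amax scale for weights. A CPU reference sketch (illustrative names, not the library API):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Symmetric int8 quantization: q = clamp(round(x * scale), -128, 127).
int8_t quantizeRoundNearest(float x, float scale)
{
    float r = std::nearbyint(x * scale);
    r = std::min(127.0f, std::max(-128.0f, r));
    return static_cast<int8_t>(r);
}

// Per-row weight quantization as in quantize_weight_kernel: scale = 127 / amax[row].
void quantizeRow(const float* w, int8_t* q, int n, float row_amax)
{
    const float scale = 127.0f / row_amax;
    for (int i = 0; i < n; ++i) {
        q[i] = quantizeRoundNearest(w[i], scale);
    }
}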
#include #include #include #include +#include "src/fastertransformer/kernels/bfloat16_fallback_kenrels.cuh" namespace cg = cooperative_groups; @@ -39,7 +41,7 @@ __inline__ __device__ T warpReduceSum(T val) { #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) - val += __shfl_xor_sync(FINAL_MASK, val, mask, 32); + val = add(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32));//__shfl_sync bf16 return float when sm < 80 return val; } @@ -288,15 +290,23 @@ __device__ __forceinline__ TopK_2 reduce_topk_op_2(const TopK_2& a, const } template -__device__ __forceinline__ T clamp_inf_for_half(const float input) +__device__ __forceinline__ T clamp_inf_for_half(const float input) { + return input; +} + +template<> +__device__ __forceinline__ half clamp_inf_for_half(const float input) { - if (std::is_same::value == true) { - // clamp inf values to enable fp16 training - return (float)input > 0.0f ? min(input, HALF_FLT_MAX - 1000) : max(input, -HALF_FLT_MAX + 1000); - } - else { - return input; - } + // clamp inf values to enable fp16 training + return input > 0.0f ? (half) min(input, HALF_FLT_MAX - 1000) : (half) max(input, -HALF_FLT_MAX + 1000); } +#ifdef ENABLE_BF16 +template<> +__device__ __forceinline__ __nv_bfloat16 clamp_inf_for_half(const float input) +{ + return __float2bfloat16(input); +} +#endif + } // namespace fastertransformer diff --git a/src/fastertransformer/kernels/reverse_roll_kernels.cu b/src/fastertransformer/kernels/reverse_roll_kernels.cu index f7f64b00b..9b7f751a7 100644 --- a/src/fastertransformer/kernels/reverse_roll_kernels.cu +++ b/src/fastertransformer/kernels/reverse_roll_kernels.cu @@ -25,50 +25,50 @@ namespace fastertransformer { // grid(W, H, batch) // block(min(1024, dim)) -__global__ void reverse_roll_col32(int8_t* dst, +__global__ void reverse_roll_col32(int8_t* dst, const int8_t* src, - const int batch, - const int window_num, - const int window_len, - const int window_size, - const int H, - const int W, - const int shift_size, - const int dim) + const int batch, + const int window_num, + const int window_len, + const int window_size, + const int H, + const int W, + const int shift_size, + const int dim) { const int batch_idx = blockIdx.z; - const int HW_idx = (blockIdx.y << 5) + threadIdx.y; + const int HW_idx = (blockIdx.y << 5) + threadIdx.y; if (HW_idx < H * W) { - const int H_idx = HW_idx / W; - const int W_idx = HW_idx % W; + const int H_idx = HW_idx / W; + const int W_idx = HW_idx % W; const int H_idx_shifted = (H_idx + shift_size) % H; const int W_idx_shifted = (W_idx + shift_size) % W; - const int window_idx = H_idx / window_size * (W / window_size) + W_idx / window_size; - const int idx_in_window = (H_idx % window_size) * window_size + (W_idx % window_size); - const int input_offset = (batch_idx * window_num + window_idx) * window_len + idx_in_window; - const int output_offset = (batch_idx * H + H_idx_shifted) * W + W_idx_shifted; - const int m = H * W * batch; - char4* inPtr = (char4*)src; - char4* outPtr = (char4*)dst; - const int col_start = (blockIdx.x << 5) + (threadIdx.x << 2); - const int offset_col32_in = (col_start & 0xffffffe0) * m + (input_offset << 5) + (col_start & 31); - const int offset_col32_out = (col_start & 0xffffffe0) * m + (output_offset << 5) + (col_start & 31); + const int window_idx = H_idx / window_size * (W / window_size) + W_idx / window_size; + const int idx_in_window = (H_idx % window_size) * window_size + (W_idx % window_size); + const int input_offset = (batch_idx * window_num + window_idx) * window_len + idx_in_window; + 
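The warpReduceSum change above wraps the shuffled value in add() because, as the new comment notes, __shfl_xor_sync returns float for __nv_bfloat16 on architectures below sm_80. For plain float the butterfly reduction it builds on looks like this (a standalone sketch, not the header's exact code):

#include <cuda_runtime.h>

#define FULL_WARP_MASK 0xffffffffu

// Butterfly (XOR) reduction over a 32-lane warp: after five steps every lane
// holds the sum of all lanes' inputs.
__device__ float warpSumFloat(float val)
{
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        val += __shfl_xor_sync(FULL_WARP_MASK, val, mask, 32);
    }
    return val;
}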
const int output_offset = (batch_idx * H + H_idx_shifted) * W + W_idx_shifted; + const int m = H * W * batch; + char4* inPtr = (char4*)src; + char4* outPtr = (char4*)dst; + const int col_start = (blockIdx.x << 5) + (threadIdx.x << 2); + const int offset_col32_in = (col_start & 0xffffffe0) * m + (input_offset << 5) + (col_start & 31); + const int offset_col32_out = (col_start & 0xffffffe0) * m + (output_offset << 5) + (col_start & 31); outPtr[offset_col32_out >> 2] = inPtr[offset_col32_in >> 2]; } } -void invokeReverseRollCol32(int8_t* dst, +void invokeReverseRollCol32(int8_t* dst, const int8_t* src, - int batch, - int window_num, - int window_len, - int window_size, - int H, - int W, - int dim, - int shift_size, - cudaStream_t stream) + int batch, + int window_num, + int window_len, + int window_size, + int H, + int W, + int dim, + int shift_size, + cudaStream_t stream) { dim3 grid((dim + 31) / 32, (H * W + 31) / 32, batch); dim3 block(8, 32); @@ -83,8 +83,8 @@ void invokeReverseRollCol32(int8_t* dst, // block(min(1024, dim)) template -__global__ void reverse_roll(T* dst, - const T* src, +__global__ void reverse_roll(T* dst, + const T* src, const int batch, const int window_num, const int window_len, @@ -94,14 +94,14 @@ __global__ void reverse_roll(T* dst, const int shift_size, const int dim) { - const int batch_idx = blockIdx.z; + const int batch_idx = blockIdx.z; const int H_idx_shifted = (blockIdx.y + shift_size) % H; const int W_idx_shifted = (blockIdx.x + shift_size) % W; - const int H_idx = blockIdx.y; - const int W_idx = blockIdx.x; - const int window_idx = H_idx / window_size * (W / window_size) + W_idx / window_size; + const int H_idx = blockIdx.y; + const int W_idx = blockIdx.x; + const int window_idx = H_idx / window_size * (W / window_size) + W_idx / window_size; const int idx_in_window = (H_idx % window_size) * window_size + (W_idx % window_size); - const int input_offset = (batch_idx * window_num + window_idx) * window_len + idx_in_window; + const int input_offset = (batch_idx * window_num + window_idx) * window_len + idx_in_window; const int output_offset = (batch_idx * H + H_idx_shifted) * W + W_idx_shifted; for (int tid = threadIdx.x; tid < dim; tid += blockDim.x) { dst[output_offset * dim + tid] = src[input_offset * dim + tid]; @@ -113,27 +113,32 @@ __global__ void reverse_roll(T* dst, // grid(W, H, batch) // block(min(1024, dim)) template -void invokeReverseRoll(T* dst, - const T* src, - int batch, - int window_num, - int window_len, - int window_size, - int H, - int W, - int dim, - int shift_size, +void invokeReverseRoll(T* dst, + const T* src, + int batch, + int window_num, + int window_len, + int window_size, + int H, + int W, + int dim, + int shift_size, cudaStream_t stream) { dim3 grid(W, H, batch); - int blockSize = dim; + int blockSize = dim; +#ifdef ENABLE_BF16 + if ((std::is_same::value || std::is_same::value) && (dim % 2 == 0)) { +#else if (std::is_same::value && (dim % 2 == 0)) { +#endif blockSize = dim / 2; if (blockSize > 1024) { blockSize = 1024; } + using T2 = typename TypeConverter::Type; // bfloat162 or half2 reverse_roll<<>>( - (half2*)dst, (const half2*)src, batch, window_num, window_len, window_size, H, W, shift_size, dim / 2); + (T2*)dst, (const T2*)src, batch, window_num, window_len, window_size, H, W, shift_size, dim / 2); } else { if (blockSize > 1024) { @@ -144,28 +149,42 @@ void invokeReverseRoll(T* dst, } } -template void invokeReverseRoll(float* dst, +template void invokeReverseRoll(float* dst, const float* src, - int batch, - int window_num, 
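invokeReverseRoll undoes the cyclic shift used by shifted-window attention while scattering tokens from the window-major layout back to the row-major feature map. The same index arithmetic written as a plain CPU loop for a single image (an illustrative sketch; H and W are assumed divisible by window_size, as in the kernel):

// Reverse roll for one image: src is [window_num, window_len, dim] (tokens
// grouped by window), dst is [H, W, dim]. Each (h, w) position is read from
// its window and written back shifted by +shift_size along both axes.
void reverseRollCPU(float* dst, const float* src,
                    int H, int W, int window_size, int shift_size, int dim)
{
    const int windows_per_row = W / window_size;
    const int window_len      = window_size * window_size;
    for (int h = 0; h < H; ++h) {
        for (int w = 0; w < W; ++w) {
            const int h_shifted  = (h + shift_size) % H;
            const int w_shifted  = (w + shift_size) % W;
            const int window_idx = (h / window_size) * windows_per_row + w / window_size;
            const int idx_in_win = (h % window_size) * window_size + (w % window_size);
            const int src_offset = (window_idx * window_len + idx_in_win) * dim;
            const int dst_offset = (h_shifted * W + w_shifted) * dim;
            for (int d = 0; d < dim; ++d) {
                dst[dst_offset + d] = src[src_offset + d];
            }
        }
    }
}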
- int window_len, - int window_size, - int H, - int W, - int dim, - int shift_size, + int batch, + int window_num, + int window_len, + int window_size, + int H, + int W, + int dim, + int shift_size, cudaStream_t stream); -template void invokeReverseRoll(half* dst, - const half* src, - int batch, - int window_num, - int window_len, - int window_size, - int H, - int W, - int dim, - int shift_size, +template void invokeReverseRoll(half* dst, + const half* src, + int batch, + int window_num, + int window_len, + int window_size, + int H, + int W, + int dim, + int shift_size, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeReverseRoll(__nv_bfloat16* dst, + const __nv_bfloat16* src, + int batch, + int window_num, + int window_len, + int window_size, + int H, + int W, + int dim, + int shift_size, + cudaStream_t stream); +#endif + } // namespace fastertransformer \ No newline at end of file diff --git a/src/fastertransformer/kernels/reverse_roll_kernels.h b/src/fastertransformer/kernels/reverse_roll_kernels.h index 1e076d4b0..c8f88808b 100644 --- a/src/fastertransformer/kernels/reverse_roll_kernels.h +++ b/src/fastertransformer/kernels/reverse_roll_kernels.h @@ -21,29 +21,29 @@ namespace fastertransformer { -void invokeReverseRollCol32(int8_t* dst, +void invokeReverseRollCol32(int8_t* dst, const int8_t* src, - int batch, - int window_num, - int window_len, - int window_size, - int H, - int W, - int dim, - int shift_size, - cudaStream_t stream); + int batch, + int window_num, + int window_len, + int window_size, + int H, + int W, + int dim, + int shift_size, + cudaStream_t stream); template -void invokeReverseRoll(T* dst, - const T* src, - int batch, - int window_num, - int window_len, - int window_size, - int H, - int W, - int dim, - int shift_size, +void invokeReverseRoll(T* dst, + const T* src, + int batch, + int window_num, + int window_len, + int window_size, + int H, + int W, + int dim, + int shift_size, cudaStream_t stream); } // namespace fastertransformer \ No newline at end of file diff --git a/src/fastertransformer/kernels/sampling_penalty_kernels.cu b/src/fastertransformer/kernels/sampling_penalty_kernels.cu index 18b7a7057..80e56ac7e 100644 --- a/src/fastertransformer/kernels/sampling_penalty_kernels.cu +++ b/src/fastertransformer/kernels/sampling_penalty_kernels.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,6 +14,7 @@ * limitations under the License. */ +#include #include #include "src/fastertransformer/kernels/sampling_penalty_kernels.h" @@ -22,15 +23,15 @@ namespace fastertransformer { // TODO Add half2 implementation template -__global__ void applyTemperaturePenalty(T* logits, - const T* bias, +__global__ void applyTemperaturePenalty(T* logits, + const T* bias, const float temperature_inverse, - const int m, - const int vocab_size, - const int vocab_size_padd) + const int m, + const int vocab_size, + const int vocab_size_padd) { - const bool IS_FP16 = std::is_same::value; - const T MAX_T_VAL = (IS_FP16) ? 65504.F : FLT_MAX; + const bool IS_FP16 = std::is_same::value; + const T MAX_T_VAL = (IS_FP16) ? 65504.F : FLT_MAX; for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < m * vocab_size_padd; index += blockDim.x * gridDim.x) { T bias_val = bias == nullptr ? 
(T)(0.0f) : bias[index % vocab_size_padd]; @@ -43,56 +44,200 @@ __global__ void applyTemperaturePenalty(T* logits, } } +template<> +__global__ void applyTemperaturePenalty(half2* logits, + const half2* bias, + const float temperature_inverse, + const int batch_size, + const int vocab_size, + const int vocab_size_padded) +{ + assert(vocab_size % 2 == 0); + assert(vocab_size_padded % 2 == 0); + const half2 mask_val = __float2half2_rn(-65504.0f); + const half2 temp_inv = __float2half2_rn(temperature_inverse); + + const int half_vocab_size = vocab_size / 2; + const int half_vocab_size_padded = vocab_size_padded / 2; + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size * half_vocab_size_padded; + index += blockDim.x * gridDim.x) { + int vocab_idx = index % half_vocab_size_padded; + half2 logit = vocab_idx < half_vocab_size ? __ldg(&logits[index]) : mask_val; + if (vocab_idx < half_vocab_size) { + if (bias != nullptr) { + logit = __hadd2(logit, bias[vocab_idx]); + } + logits[index] = __hmul2(logit, temp_inv); + } + } +} + template -void invokeApplyTemperaturePenalty(T* logits, - const T* bias, - const float temperature, - const int m, - const int vocab_size, - const int vocab_size_padd, +void invokeApplyTemperaturePenalty(T* logits, + const T* bias, + const float temperature, + const int batch_size, + const int vocab_size, + const int vocab_size_padd, cudaStream_t stream) { - dim3 grid(min(m, 65536)); - dim3 block(min(vocab_size_padd, 1024)); + dim3 block(min(vocab_size_padd, 1024)); + dim3 grid(min(batch_size * vocab_size_padd / block.x, 65536)); const T temperature_inverse = (T)(1.f / (float)temperature); - applyTemperaturePenalty - <<>>(logits, bias, temperature_inverse, m, vocab_size, vocab_size_padd); + if (std::is_same::value && vocab_size % 2 == 0 && vocab_size_padd % 2 == 0) { + applyTemperaturePenalty<<>>(reinterpret_cast(logits), + reinterpret_cast(bias), + temperature_inverse, + batch_size, + vocab_size, + vocab_size_padd); + } + else { + applyTemperaturePenalty + <<>>(logits, bias, temperature_inverse, batch_size, vocab_size, vocab_size_padd); + } } -template void invokeApplyTemperaturePenalty(float* logits, +template void invokeApplyTemperaturePenalty(float* logits, const float* bias, - const float temperature, - const int m, - const int vocab_size, - const int vocab_size_padd, + const float temperature, + const int batch_size, + const int vocab_size, + const int vocab_size_padd, cudaStream_t stream); -template void invokeApplyTemperaturePenalty(half* logits, - const half* bias, - const float temperature, - const int m, - const int vocab_size, - const int vocab_size_padd, +template void invokeApplyTemperaturePenalty(half* logits, + const half* bias, + const float temperature, + const int batch_size, + const int vocab_size, + const int vocab_size_padd, cudaStream_t stream); template -__global__ void applyRepetitionPenalty(T* logits, +__global__ void batchApplyTemperaturePenalty(T* logits, + const T* bias, + const float* temperatures, + const int batch_size, + const int vocab_size, + const int vocab_size_padd) +{ + // TODO: Add macro or device function to get MAX_T_VAL. + const bool IS_FP16 = std::is_same::value; + const T MAX_T_VAL = (IS_FP16) ? 
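Temperature scaling multiplies every bias-added logit by 1/T, so T < 1 sharpens the softmax and T > 1 flattens it; the half2 specialization above simply processes two vocabulary entries per instruction. A scalar reference for one row (sketch only):

// Scalar reference for one row of logits: add the optional bias, then scale
// by 1 / temperature. Padding slots beyond vocab_size are left untouched here
// (the kernels mask them to a large negative value instead).
void applyTemperatureCPU(float* logits, const float* bias, float temperature, int vocab_size)
{
    const float inv_temperature = 1.0f / temperature;
    for (int i = 0; i < vocab_size; ++i) {
        const float v = logits[i] + (bias != nullptr ? bias[i] : 0.0f);
        logits[i] = v * inv_temperature;
    }
}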
65504.F : FLT_MAX; + extern __shared__ float inv_temperatures[]; + if (threadIdx.x < batch_size) { + inv_temperatures[threadIdx.x] = 1 / temperatures[threadIdx.x]; + } + __syncthreads(); + + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size * vocab_size_padd; + index += blockDim.x * gridDim.x) { + int batch_idx = index / vocab_size_padd; + int vocab_idx = index % vocab_size_padd; + T logit = (vocab_idx < vocab_size) ? logits[index] : -MAX_T_VAL; + if (vocab_idx < vocab_size) { + if (bias != nullptr) { + logit += bias[vocab_idx]; + } + logit *= inv_temperatures[batch_idx]; + } + logits[index] = logit; + } +} + +__global__ void batchApplyTemperaturePenalty_h2(half2* logits, + const half2* bias, + const float* temperatures, + const int batch_size, + const int vocab_size, + const int vocab_size_padded) +{ + assert(vocab_size % 2 == 0); + assert(vocab_size_padded % 2 == 0); + extern __shared__ half2 h2_inv_temperatures[]; + if (threadIdx.x < batch_size) { + h2_inv_temperatures[threadIdx.x] = __float2half2_rn(1 / temperatures[threadIdx.x]); + } + __syncthreads(); + + const half2 mask_val = __float2half2_rn(-65504.0f); + const int half_vocab_size = vocab_size / 2; + const int half_vocab_size_padded = vocab_size_padded / 2; + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size * half_vocab_size_padded; + index += blockDim.x * gridDim.x) { + int batch_idx = index / half_vocab_size_padded; + int vocab_idx = index % half_vocab_size_padded; + half2 logit = vocab_idx < half_vocab_size ? __ldg(&logits[index]) : mask_val; + if (vocab_idx < half_vocab_size) { + if (bias != nullptr) { + logit = __hadd2(logit, bias[vocab_idx]); + } + logits[index] = __hmul2(logit, h2_inv_temperatures[batch_idx]); + } + } +} + +template +void invokeBatchApplyTemperaturePenalty(T* logits, + const T* bias, + const float* temperatures, + const int batch_size, + const int vocab_size, + const int vocab_size_padd, + cudaStream_t stream) +{ + dim3 block(min(vocab_size_padd, 1024)); + dim3 grid(min(batch_size * vocab_size_padd / block.x, 65536)); + if (std::is_same::value && vocab_size % 2 == 0 && vocab_size_padd % 2 == 0) { + size_t smem_size = sizeof(half2) * batch_size; + batchApplyTemperaturePenalty_h2<<>>(reinterpret_cast(logits), + reinterpret_cast(bias), + temperatures, + batch_size, + vocab_size, + vocab_size_padd); + } + else { + size_t smem_size = sizeof(float) * batch_size; + batchApplyTemperaturePenalty + <<>>(logits, bias, temperatures, batch_size, vocab_size, vocab_size_padd); + } +} + +template void invokeBatchApplyTemperaturePenalty(float* logits, + const float* bias, + const float* temperatures, + const int batch_size, + const int vocab_size, + const int vocab_size_padd, + cudaStream_t stream); + +template void invokeBatchApplyTemperaturePenalty(half* logits, + const half* bias, + const float* temperatures, + const int batch_size, + const int vocab_size, + const int vocab_size_padd, + cudaStream_t stream); + +template +__global__ void applyRepetitionPenalty(T* logits, const float penalty, - const int* start_ids, - int* output_ids, - const int batch_size, - const int local_batch_size, - const int vocab_size, - const int vocab_size_padd, - const int* input_lengths, - const int max_input_len, - const int step, - const int ite) + const int* start_ids, + int* output_ids, + const int batch_size, + const int local_batch_size, + const int vocab_size, + const int vocab_size_padd, + const int* input_lengths, + const int max_input_len, + const int step) { extern __shared__ float 
penalty_logits[]; - int* penalty_indices = (int*)(penalty_logits + step); + int* penalty_indices = (int*)(penalty_logits + step); - logits = logits + blockIdx.x * vocab_size_padd; + logits = logits + blockIdx.x * vocab_size_padd; const int input_length = input_lengths != nullptr ? input_lengths[blockIdx.x] : max_input_len; for (int index = threadIdx.x; index < step; index += blockDim.x) { @@ -101,13 +246,13 @@ __global__ void applyRepetitionPenalty(T* logits, } // output_ids shape: (input_len + output_len, batch_size) - int penalty_index = output_ids[index * batch_size + local_batch_size * ite + blockIdx.x]; + int penalty_index = output_ids[index * batch_size + blockIdx.x]; if (penalty_index >= vocab_size) { continue; } penalty_indices[index] = penalty_index; - float logit = (float)logits[penalty_index]; - penalty_logits[index] = logit < 0.0f ? logit * penalty : logit / penalty; + float logit = (float)logits[penalty_index]; + penalty_logits[index] = logit < 0.0f ? logit * penalty : logit / penalty; } if (blockDim.x > 32) { @@ -129,62 +274,152 @@ __global__ void applyRepetitionPenalty(T* logits, } template -void invokeApplyRepetitionPenalty(T* logits, - const float penalty, - const int* start_ids, - int* output_ids, - const int batch_size, - const int local_batch_size, - const int vocab_size, - const int vocab_size_padd, - const int* input_lengths, - const int max_input_len, - const int step, - const int ite, +void invokeApplyRepetitionPenalty(T* logits, + const float penalty, + const int* start_ids, + int* output_ids, + const int batch_size, + const int local_batch_size, + const int vocab_size, + const int vocab_size_padd, + const int* input_lengths, + const int max_input_len, + const int step, cudaStream_t stream) { - dim3 block(min(512, step)); - dim3 grid((int)(local_batch_size)); - applyRepetitionPenalty<<>>(logits, - penalty, - start_ids, - output_ids, - batch_size, - local_batch_size, - vocab_size, - vocab_size_padd, - input_lengths, - max_input_len, - step, - ite); + dim3 block(min(step, 1024)); + dim3 grid(local_batch_size); + applyRepetitionPenalty<<>>(logits, + penalty, + start_ids, + output_ids, + batch_size, + local_batch_size, + vocab_size, + vocab_size_padd, + input_lengths, + max_input_len, + step); } -template void invokeApplyRepetitionPenalty(float* logits, - const float penalty, - const int* start_ids, - int* output_ids, - const int batch_size, - const int local_batch_size, - const int vocab_size, - const int vocab_size_padd, - const int* input_lengths, - const int max_input_len, - const int step, - const int ite, +template void invokeApplyRepetitionPenalty(float* logits, + const float penalty, + const int* start_ids, + int* output_ids, + const int batch_size, + const int local_batch_size, + const int vocab_size, + const int vocab_size_padd, + const int* input_lengths, + const int max_input_len, + const int step, cudaStream_t stream); -template void invokeApplyRepetitionPenalty(half* logits, - const float penalty, - const int* start_ids, - int* output_ids, - const int batch_size, - const int local_batch_size, - const int vocab_size, - const int vocab_size_padd, - const int* input_lengths, - const int max_input_len, - const int step, - const int ite, +template void invokeApplyRepetitionPenalty(half* logits, + const float penalty, + const int* start_ids, + int* output_ids, + const int batch_size, + const int local_batch_size, + const int vocab_size, + const int vocab_size_padd, + const int* input_lengths, + const int max_input_len, + const int step, cudaStream_t stream); -} 
// namespace fastertransformer \ No newline at end of file +template +__global__ void batchApplyRepetitionPenalty(T* logits, + const float* penalties, + const int* output_ids, + const int batch_size, + const int vocab_size, + const int* input_lengths, + const int max_input_length, + const int step) +{ + extern __shared__ float penalty_logits[]; + int* penalty_indices = (int*)(penalty_logits + step); + const int batch_idx = blockIdx.x; + const float penalty = penalties[batch_idx]; + const int input_length = input_lengths != nullptr ? input_lengths[batch_idx] : max_input_length; + + logits += batch_idx * vocab_size; + + // Phase 1. Find indices to penalize and keep the penalized values. + // A vocab id can appear multiple times but should be penalized once. + for (int index = threadIdx.x; index < step; index += blockDim.x) { + // Skip the padding tokens in input sequences. + if (index >= input_length && index < max_input_length) { + continue; + } + // output_ids shape: (input_len + output_len, batch_size) + int penalty_index = output_ids[index * batch_size + batch_idx]; + assert(penalty_index < vocab_size); + penalty_indices[index] = penalty_index; + float logit = (float)logits[penalty_index]; + penalty_logits[index] = logit < 0.0f ? logit * penalty : logit / penalty; + } + + if (blockDim.x > 32) { + __syncthreads(); + } + + // Phase 2. Replace a logit value by the penalized one. + for (int index = threadIdx.x; index < step; index += blockDim.x) { + // Skip the padding tokens in input sequences. + if (index >= input_length && index < max_input_length) { + continue; + } + logits[penalty_indices[index]] = penalty_logits[index]; + } +} + +template +void invokeBatchApplyRepetitionPenalty(T* logits, + const float* penalties, + const int* output_ids, + const int batch_size, + const int local_batch_size, + const int vocab_size, + const int* input_lengths, + const int max_input_length, + const int step, + cudaStream_t stream) +{ + // Inputs + // logits [local_batch_size, vocab_size] : logit values. + // penalties [local_batch_size] : repetition penalty factors. + // output_ids [step, batch_size] : output token ids (with offset ite * local_batch_size). + // input_lengths [local_batch_size], input lengths (optional). + // Padding tokens at [input_length, max_input_length) of input will not be penalized. 
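batchApplyRepetitionPenalty applies the CTRL-style rule: every token id already emitted (or present in the prompt, excluding padding) has its logit divided by the penalty when positive and multiplied by it when negative, and the two-phase write-back guarantees a repeated id is penalized only once. The same rule as a CPU sketch:

#include <unordered_set>
#include <vector>

// CTRL-style repetition penalty over one sample: each distinct previously
// generated (or prompt) token id is penalized exactly once.
void repetitionPenaltyCPU(std::vector<float>& logits,
                          const std::vector<int>& previous_ids,
                          float penalty)
{
    const std::unordered_set<int> seen(previous_ids.begin(), previous_ids.end());
    for (int id : seen) {
        const float v = logits[id];
        logits[id] = v < 0.0f ? v * penalty : v / penalty;
    }
}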
+ dim3 block(min(step, 1024)); + dim3 grid(local_batch_size); + size_t smem_size = step * (sizeof(float) + sizeof(int)); + batchApplyRepetitionPenalty<<>>( + logits, penalties, output_ids, batch_size, vocab_size, input_lengths, max_input_length, step); +} + +template void invokeBatchApplyRepetitionPenalty(float* logits, + const float* penalties, + const int* output_ids, + const int batch_size, + const int local_batch_size, + const int vocab_size, + const int* input_lengths, + const int max_input_length, + const int step, + cudaStream_t stream); + +template void invokeBatchApplyRepetitionPenalty(half* logits, + const float* penalties, + const int* output_ids, + const int batch_size, + const int local_batch_size, + const int vocab_size, + const int* input_lengths, + const int max_input_length, + const int step, + cudaStream_t stream); + +} // namespace fastertransformer diff --git a/src/fastertransformer/kernels/sampling_penalty_kernels.h b/src/fastertransformer/kernels/sampling_penalty_kernels.h index 958fc2c16..b0830b104 100644 --- a/src/fastertransformer/kernels/sampling_penalty_kernels.h +++ b/src/fastertransformer/kernels/sampling_penalty_kernels.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,27 +22,46 @@ namespace fastertransformer { template -void invokeApplyRepetitionPenalty(T* logits, - const float penalty, - const int* start_ids, - int* output_ids, - const int batch_size, - const int local_batch_size, - const int vocab_size, - const int vocab_size_padd, - const int* input_lengths, - const int max_input_len, - const int step, - const int ite, +void invokeApplyRepetitionPenalty(T* logits, + const float penalty, + const int* start_ids, + int* output_ids, + const int batch_size, + const int local_batch_size, + const int vocab_size, + const int vocab_size_padd, + const int* input_lengths, + const int max_input_len, + const int step, cudaStream_t stream); template -void invokeApplyTemperaturePenalty(T* logits, - const T* bias, - const float temperature, - const int m, - const int vocab_size, - const int vocab_size_padd, +void invokeBatchApplyRepetitionPenalty(T* logits, + const float* penalties, + const int* output_ids, + const int batch_size, + const int local_batch_size, + const int vocab_size, + const int* input_lengths, + const int max_input_length, + const int step, + cudaStream_t stream); + +template +void invokeApplyTemperaturePenalty(T* logits, + const T* bias, + const float temperature, + const int batch_size, + const int vocab_size, + const int vocab_size_padd, cudaStream_t stream); -} // namespace fastertransformer \ No newline at end of file +template +void invokeBatchApplyTemperaturePenalty(T* logits, + const T* bias, + const float* temperatures, + const int batch_size, + const int vocab_size, + const int vocab_size_padd, + cudaStream_t stream); +} // namespace fastertransformer diff --git a/src/fastertransformer/kernels/sampling_topk_kernels.cu b/src/fastertransformer/kernels/sampling_topk_kernels.cu index 6da5f558f..bd5f46fd6 100644 --- a/src/fastertransformer/kernels/sampling_topk_kernels.cu +++ b/src/fastertransformer/kernels/sampling_topk_kernels.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. 
* Copyright (c) 2021, NAVER Corp. Authored by CLOVA. * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,6 +15,7 @@ * limitations under the License. */ +#include #ifndef CUDART_VERSION #error CUDART_VERSION Undefined! #elif (CUDART_VERSION >= 11050) @@ -35,50 +36,123 @@ __global__ void curandInitialize(curandState_t* state, const int size, const uns } } -void invokeCurandInitialize(curandState_t* state, - const size_t batch_size, +void invokeCurandInitialize(curandState_t* state, + const size_t batch_size, const unsigned long long random_seed, - cudaStream_t stream) + cudaStream_t stream) { dim3 block(256); dim3 grid((int)(ceil(batch_size * 1.0 / 256))); curandInitialize<<>>(state, batch_size, random_seed); } +__global__ void curandBatchInitialize(curandState_t* states, const int size, const unsigned long long* random_seeds) +{ + int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < size) { + curand_init(random_seeds[idx], 0, 0, &states[idx]); + } +} + +void invokeCurandBatchInitialize(curandState_t* states, + const size_t batch_size, + const unsigned long long* random_seeds, + cudaStream_t stream) +{ + dim3 block(256); + dim3 grid((int)(ceil(batch_size * 1.0 / 256))); + curandBatchInitialize<<>>(states, batch_size, random_seeds); +} + +template +__global__ void addBiasEndMask(T* logits, const T* bias, const int* end_ids, const bool* finished, const int n) +{ + int bid = blockIdx.x; + bool finish = finished != nullptr ? finished[bid] : false; + int offset = bid * n; + + const bool IS_FP16 = std::is_same::value; + const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; + for (int tid = threadIdx.x; tid < n; tid += blockDim.x) { + if (finish) { + logits[offset + tid] = (tid == end_ids[bid]) ? MAX_T_VAL : -MAX_T_VAL; + } + else { + if (bias != nullptr) { + logits[offset + tid] += bias[tid]; + } + } + } +} + +template +void invokeAddBiasEndMask( + T* logits, const T* bias, const int* end_ids, const bool* finished, const int m, const int n, cudaStream_t stream) +{ + dim3 grid(m); + dim3 block(min(n, 1024)); + /*n is the vocab_size, e.g., 30000, 7000.... vocab_size is usually very big. 
*/ + addBiasEndMask<<>>(logits, bias, end_ids, finished, n); +} + +template void invokeAddBiasEndMask(float* logits, + const float* bias, + const int* end_ids, + const bool* finished, + const int m, + const int n, + cudaStream_t stream); + +template void invokeAddBiasEndMask(half* logits, + const half* bias, + const int* end_ids, + const bool* finished, + const int m, + const int n, + cudaStream_t stream); + template -__global__ void topk_stage_1_opt3(const T* __restrict log_probs, - T* tmp_log_probs, - int* topk_tmp_id_buf, - T* topk_tmp_val_buf, - const bool* finished, - const int k, - const int vocab_size, - const int* end_ids) +__global__ void topk_stage1(const T* __restrict log_probs, + T* tmp_log_probs, + int* topk_tmp_id_buf, + T* topk_tmp_val_buf, + const bool* finished, + const int max_top_k, + const int* top_ks, + const int vocab_size, + const int* end_ids, + const bool* skip_decode) { typedef cub::BlockReduce, BLOCK_SIZE_> BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ typename BlockReduce::TempStorage temp_storage; const int tid = threadIdx.x; const int bid = blockIdx.x; - const int row_id = bid / BLOCKS_PER_BEAM_; // row id for log_probs - const int block_lane = bid % BLOCKS_PER_BEAM_; // block id for a beam - const int tmp_log_buf_index = row_id * vocab_size; - const int tmp_topk_buf_index = row_id * BLOCKS_PER_BEAM_ * k + block_lane * k; - TopK_2 partial; - const bool IS_FP16 = std::is_same::value; - const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; + const int batch_id = bid / BLOCKS_PER_BEAM_; // row id for log_probs + if (skip_decode != nullptr && skip_decode[batch_id]) { + return; + } + const int block_lane = bid % BLOCKS_PER_BEAM_; // block id for a beam + const int k = (top_ks != nullptr) ? top_ks[batch_id] : max_top_k; // batch_id = batch index + + const int tmp_log_buf_index = batch_id * vocab_size; + const int tmp_topk_buf_index = batch_id * BLOCKS_PER_BEAM_ * max_top_k + block_lane * k; + + TopK_2 partial; + const bool IS_FP16 = std::is_same::value; + const T MAX_T_VAL = (IS_FP16) ? 
HALF_FLT_MAX : FLT_MAX; - if (finished != nullptr && finished[row_id] == true) { + if (finished != nullptr && finished[batch_id] == true) { if (tid < k) { const int index = tmp_topk_buf_index + tid; if (block_lane == 0 && tid == 0) { - const int end_id = end_ids[row_id]; - topk_tmp_id_buf[index] = tmp_log_buf_index + end_id; + const int end_id = end_ids[batch_id]; + topk_tmp_id_buf[index] = tmp_log_buf_index + end_id; topk_tmp_val_buf[index] = log_probs[tmp_log_buf_index + end_id]; } else { - topk_tmp_id_buf[index] = -1; + topk_tmp_id_buf[index] = -1; topk_tmp_val_buf[index] = -MAX_T_VAL; } } @@ -87,7 +161,7 @@ __global__ void topk_stage_1_opt3(const T* __restrict log_probs, for (int elem_id = tid + block_lane * BLOCK_SIZE_; elem_id < vocab_size; elem_id += BLOCK_SIZE_ * BLOCKS_PER_BEAM_) { - int index = elem_id + tmp_log_buf_index; + int index = elem_id + tmp_log_buf_index; tmp_log_probs[index] = log_probs[index]; } @@ -103,106 +177,64 @@ __global__ void topk_stage_1_opt3(const T* __restrict log_probs, TopK_2 total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2); if (tid == 0) { - const int index = tmp_topk_buf_index + ite; - topk_tmp_id_buf[index] = total.p; + const int index = tmp_topk_buf_index + ite; + topk_tmp_id_buf[index] = total.p; topk_tmp_val_buf[index] = total.u; - tmp_log_probs[total.p] = -MAX_T_VAL; + tmp_log_probs[total.p] = -MAX_T_VAL; } __syncthreads(); } } -template -__global__ void addBiasEndMask(T* logits, const T* bias, const int* end_ids, const bool* finished, const int n) +template +__global__ void topk_stage2_sampling(const int* __restrict topk_tmp_id_buf, + T* topk_tmp_val_buf, + int* ids, + int* sequence_length, + bool* finished, + float* cum_log_probs, + float* output_log_probs, + const int max_top_k, + const int* top_ks, + const float top_p, + const float* top_ps, + curandState_t* curandstate, + const int* end_ids, + const int vocab_size, + const bool* skip_decode) { - int bid = blockIdx.x; - bool finish = finished != nullptr ? finished[bid] : false; - int offset = bid * n; + const bool IS_FP16 = std::is_same::value; + const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; - const bool IS_FP16 = std::is_same::value; - const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; - for (int tid = threadIdx.x; tid < n; tid += blockDim.x) { - if (finish) { - logits[offset + tid] = (tid == end_ids[bid]) ? MAX_T_VAL : -MAX_T_VAL; - } - else { - if (bias != nullptr) { - logits[offset + tid] += bias[tid]; - } - } + const int tid = threadIdx.x; + const int batch_id = blockIdx.x; + if (skip_decode != nullptr && skip_decode[batch_id]) { + return; } -} - -template -void invokeAddBiasEndMask( - T* logits, const T* bias, const int* end_ids, const bool* finished, const int m, const int n, cudaStream_t stream) -{ - dim3 grid(m); - dim3 block(min(n, 1024)); - /*n is the vocab_size, e.g., 30000, 7000.... vocab_size is usually very big. 
*/ - addBiasEndMask<<>>(logits, bias, end_ids, finished, n); -} - -template void invokeAddBiasEndMask(float* logits, - const float* bias, - const int* end_ids, - const bool* finished, - const int m, - const int n, - cudaStream_t stream); -template void invokeAddBiasEndMask(half* logits, - const half* bias, - const int* end_ids, - const bool* finished, - const int m, - const int n, - cudaStream_t stream); - -template -__global__ void topk_topp_stage_2_opt3_sampling(const int* __restrict topk_tmp_id_buf, - T* topk_tmp_val_buf, - T* topk_tmp2_val_buf, - int* ids, - int* sequence_length, - bool* finished_buf, - float* cum_log_probs, - float* output_log_probs, - const int k, - const T prob_threshold, - curandState_t* curandstate, - const int* end_ids, - const int vocab_size) -{ - const int size = k * BLOCKS_PER_BEAM_; - const int tid = threadIdx.x; - const int batch_id = blockIdx.x; - const bool IS_FP16 = std::is_same::value; - const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; + const int k = (top_ks != nullptr) ? top_ks[batch_id] : max_top_k; + const float prob_threshold = (top_ps != nullptr) ? top_ps[batch_id] : top_p; + const int size = k * BLOCKS_PER_BEAM_; + const int stride = max_top_k * BLOCKS_PER_BEAM_; typedef cub::BlockReduce, BLOCK_SIZE_> BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - extern __shared__ char array[]; - __shared__ float rand_num; - __shared__ float s_sum; - __shared__ float s_max; - T* s_val = topk_tmp_val_buf + batch_id * size; - int* s_id = (int*)(array); - s_max = 0.0f; - s_sum = 0.0f; + __shared__ typename BlockReduce::TempStorage temp_storage; + extern __shared__ char array[]; + __shared__ float rand_num; + __shared__ float s_sum; + __shared__ float s_max; + T* s_val = topk_tmp_val_buf + batch_id * stride; + int* s_id = reinterpret_cast(array); + s_max = 0.0f; + s_sum = 0.0f; TopK_2 partial; - if (finished_buf != nullptr && finished_buf[batch_id] == true) { + if (finished != nullptr && finished[batch_id] == true) { ids[batch_id] = end_ids[batch_id]; return; } - for (int index = tid; index < size; index += BLOCK_SIZE_) { - topk_tmp2_val_buf[batch_id * size + index] = topk_tmp_val_buf[batch_id * size + index]; - } - __syncthreads(); float* s_val2 = reinterpret_cast(s_id + k); - for (int ite = 0; ite < k; ite++) { partial.init(); #pragma unroll @@ -217,7 +249,7 @@ __global__ void topk_topp_stage_2_opt3_sampling(const int* __restrict topk_tmp_i } if (tid == 0) { - s_id[ite] = total.p; + s_id[ite] = total.p; s_val[total.p] = -MAX_T_VAL; // when cum_log_probs are computed, topk_tmp_val_buf (logits_buf_) are already pre-processed by @@ -230,13 +262,14 @@ __global__ void topk_topp_stage_2_opt3_sampling(const int* __restrict topk_tmp_i } __syncthreads(); } + if (tid == 0) { - rand_num = (float)curand_uniform(curandstate + blockIdx.x) * (float)prob_threshold * s_sum; + rand_num = (float)curand_uniform(curandstate + blockIdx.x) * prob_threshold * s_sum; for (int i = 0; i < k; i++) { float exp_logit = s_val2[i]; - rand_num = rand_num - exp_logit; + rand_num = rand_num - exp_logit; if (rand_num <= 0.0f || i == k - 1) { - ids[batch_id] = topk_tmp_id_buf[batch_id * size + s_id[i]] % vocab_size; + ids[batch_id] = topk_tmp_id_buf[batch_id * stride + s_id[i]] % vocab_size; if (cum_log_probs != nullptr || output_log_probs != nullptr) { float log_prob = logf(exp_logit); if (cum_log_probs != nullptr) { @@ -255,241 +288,275 @@ __global__ void topk_topp_stage_2_opt3_sampling(const int* __restrict topk_tmp_i break; } } - if (sequence_length != nullptr 
&& finished_buf != nullptr) { - sequence_length[batch_id] = - finished_buf[batch_id] ? sequence_length[batch_id] : sequence_length[batch_id] + 1; - finished_buf[batch_id] = ids[batch_id] == end_ids[batch_id] ? 1 : 0; + if (sequence_length != nullptr && finished != nullptr) { + sequence_length[batch_id] = finished[batch_id] ? sequence_length[batch_id] : sequence_length[batch_id] + 1; + finished[batch_id] = ids[batch_id] == end_ids[batch_id] ? true : false; } } } #define CASE_K(K_MIN, K_MAX, BLOCK_SIZE_1_, BLOCK_SIZE_2_, BLOCKS_PER_BEAM_) \ case K_MIN ... K_MAX: \ - topk_stage_1_opt3 \ + topk_stage1 \ <<>>(log_probs, \ temp_log_probs, \ topk_tmp_id_buf, \ topk_tmp_val_buf, \ - finished_buf, \ - candidate_num, \ + finished, \ + max_top_k, \ + top_ks, \ vocab_size, \ - end_ids); \ - topk_topp_stage_2_opt3_sampling \ + end_ids, \ + skip_decode); \ + topk_stage2_sampling \ <<>>(topk_tmp_id_buf, \ topk_tmp_val_buf, \ - topk_tmp2_val_buf, \ ids, \ sequence_length, \ - finished_buf, \ + finished, \ cum_log_probs, \ output_log_probs, \ - candidate_num, \ - 1.0f, \ + max_top_k, \ + top_ks, \ + top_p, \ + top_ps, \ curandstate, \ end_ids, \ - vocab_size); \ + vocab_size, \ + skip_decode); \ break; template -void invokeTopKSampling(void* workspace, - size_t& workspace_size, - T* log_probs, - int* ids, - int* sequence_length, - bool* finished_buf, - float* cum_log_probs, - float* output_log_probs, - curandState_t* curandstate, - const int top_k, - const int vocab_size_padded, - const int* end_ids, - cudaStream_t stream, - const int batch_size) +void invokeBatchTopKSampling(void* workspace, + size_t& workspace_size, + const T* log_probs, + int* ids, + int* sequence_length, + bool* finished, + float* cum_log_probs, + float* output_log_probs, + curandState_t* curandstate, + const int max_top_k, + const int* top_ks, + const float top_p, + const float* top_ps, + const int vocab_size_padded, + const int* end_ids, + cudaStream_t stream, + const int batch_size, + const bool* skip_decode) { - // Here, we put batch size as an argument because the batch size of initialization - // and inference may be different due to pipelint parallelism. - const int candidate_num = top_k; - const int vocab_size = vocab_size_padded; - - const int max_block_per_beam = 8; - int temp_log_probs_buf_size = batch_size * vocab_size; // type float - int topk_tmp_ids_buf_size = batch_size * candidate_num * max_block_per_beam; // type int - int topk_tmp_val_buf_size = batch_size * candidate_num * max_block_per_beam; // type float - - // prevent memory misalinged address + // Not allow an ambiguous inputs top_p and top_ps. 
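topk_stage1 produces BLOCKS_PER_BEAM_ partial top-k lists per sample and topk_stage2_sampling merges them, exponentiates relative to the running max, and draws one id using the per-sample top_k and top_p. The stage-2 draw, reduced to a CPU sketch over the k gathered candidate logits (illustrative only):

#include <algorithm>
#include <cmath>
#include <numeric>
#include <random>
#include <vector>

// Draw one index from k candidate logits. The threshold is
// U(0,1) * top_p * sum(exp(logit - max)), mirroring rand_num in
// topk_stage2_sampling; candidates are consumed in descending order.
int sampleTopKTopP(const std::vector<float>& logits, float top_p, std::mt19937& gen)
{
    const float max_logit = *std::max_element(logits.begin(), logits.end());

    std::vector<int> order(logits.size());
    std::iota(order.begin(), order.end(), 0);
    std::sort(order.begin(), order.end(),
              [&](int a, int b) { return logits[a] > logits[b]; });

    std::vector<float> exp_vals(logits.size());
    float sum = 0.0f;
    for (int i : order) {
        exp_vals[i] = std::exp(logits[i] - max_logit);
        sum += exp_vals[i];
    }

    std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
    float r = uniform(gen) * top_p * sum;
    for (int i : order) {
        r -= exp_vals[i];
        if (r <= 0.0f) {
            return i;
        }
    }
    return order.back();  // numerical safety net, matches the i == k - 1 case
}

With top_p set to 1.0f the threshold can cover the full candidate mass, which reduces this to plain top-k sampling.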
+ assert(top_p == 1.0f || top_ps == nullptr); + const int vocab_size = vocab_size_padded; + const int max_block_per_beam = 8; + int temp_log_probs_buf_size = batch_size * vocab_size; // type float + int topk_tmp_ids_buf_size = batch_size * max_top_k * max_block_per_beam; // type int + int topk_tmp_val_buf_size = batch_size * max_top_k * max_block_per_beam; // type float + + // prevent memory misaligned address temp_log_probs_buf_size = (int)(ceil(temp_log_probs_buf_size / 4.)) * 4; - topk_tmp_ids_buf_size = (int)(ceil(topk_tmp_ids_buf_size / 4.)) * 4; - topk_tmp_val_buf_size = (int)(ceil(topk_tmp_val_buf_size / 4.)) * 4; + topk_tmp_ids_buf_size = (int)(ceil(topk_tmp_ids_buf_size / 4.)) * 4; + topk_tmp_val_buf_size = (int)(ceil(topk_tmp_val_buf_size / 4.)) * 4; if (workspace == nullptr) { workspace_size = sizeof(T) * temp_log_probs_buf_size + sizeof(int) * topk_tmp_ids_buf_size - + 2 * sizeof(T) * topk_tmp_val_buf_size; + + sizeof(T) * topk_tmp_val_buf_size; return; } - else { - T* temp_log_probs = (T*)workspace; - int* topk_tmp_id_buf = (int*)(temp_log_probs + temp_log_probs_buf_size); - T* topk_tmp_val_buf = (T*)(topk_tmp_id_buf + topk_tmp_ids_buf_size); - T* topk_tmp2_val_buf = (T*)(topk_tmp_val_buf + topk_tmp_val_buf_size); - - switch (candidate_num) { - CASE_K(1, 16, 128, 128, 8); - CASE_K(17, 32, 256, 128, 8); - CASE_K(33, 64, 256, 256, 8); - default: - printf("[ERROR] Topk kernel does not support candidate_num = %d \n", candidate_num); - exit(0); - break; - } - return; + + T* temp_log_probs = (T*)workspace; + int* topk_tmp_id_buf = (int*)(temp_log_probs + temp_log_probs_buf_size); + T* topk_tmp_val_buf = (T*)(topk_tmp_id_buf + topk_tmp_ids_buf_size); + + switch (max_top_k) { + CASE_K(1, 16, 128, 128, 8); + CASE_K(17, 32, 256, 128, 8); + CASE_K(33, 64, 256, 256, 8); + CASE_K(65, 1024, 256, 256, 8); + default: + throw std::domain_error(fmtstr("top-k kernel supports 1<=k<=1024 but got k=%d", max_top_k)); } } #undef CASE_K -template void invokeTopKSampling(void* workspace, - size_t& workspace_size, - float* log_probs, - int* ids, - int* sequence_length, - bool* finished_buf, - float* cum_log_probs, - float* output_log_probs, +template void invokeBatchTopKSampling(void* workspace, + size_t& workspace_size, + const float* log_probs, + int* ids, + int* sequence_length, + bool* finished_buf, + float* cum_log_probs, + float* output_log_probs, + curandState_t* curandstate, + const int max_top_k, + const int* top_ks, + const float top_p, + const float* top_ps, + const int vocab_size_padded, + const int* end_ids, + cudaStream_t stream, + const int batch_size, + const bool* skip_decode); + +template void invokeBatchTopKSampling(void* workspace, + size_t& workspace_size, + const half* log_probs, + int* ids, + int* sequence_length, + bool* finished_buf, + float* cum_log_probs, + float* output_log_probs, + curandState_t* curandstate, + const int max_top_k, + const int* top_ks, + const float top_p, + const float* top_ps, + const int vocab_size_padded, + const int* end_ids, + cudaStream_t stream, + const int batch_size, + const bool* skip_decode); + +template +void invokeTopKSampling(void* workspace, + size_t& workspace_size, + const T* log_probs, + int* ids, + int* sequence_length, + bool* finished_buf, + float* cum_log_probs, + float* output_log_probs, + curandState_t* curandstate, + const int top_k, + const float top_p, + const int vocab_size_padded, + const int* end_ids, + cudaStream_t stream, + const int batch_size, + const bool* skip_decode) +{ + invokeBatchTopKSampling(workspace, + 
workspace_size, + log_probs, + ids, + sequence_length, + finished_buf, + cum_log_probs, + output_log_probs, + curandstate, + top_k, + nullptr, + top_p, + nullptr, + vocab_size_padded, + end_ids, + stream, + batch_size, + skip_decode); +} + +template void invokeTopKSampling(void* workspace, + size_t& workspace_size, + const float* log_probs, + int* ids, + int* sequence_length, + bool* finished_buf, + float* cum_log_probs, + float* output_log_probs, curandState_t* curandstate, - const int top_k, - const int vocab_size_padded, - const int* end_ids, - cudaStream_t stream, - const int batch_size); - -template void invokeTopKSampling(void* workspace, - size_t& workspace_size, - half* log_probs, - int* ids, - int* sequence_length, - bool* finished_buf, - float* cum_log_probs, - float* output_log_probs, + const int top_k, + const float top_p, + const int vocab_size_padded, + const int* end_ids, + cudaStream_t stream, + const int batch_size, + const bool* skip_decode); + +template void invokeTopKSampling(void* workspace, + size_t& workspace_size, + const half* log_probs, + int* ids, + int* sequence_length, + bool* finished_buf, + float* cum_log_probs, + float* output_log_probs, curandState_t* curandstate, - const int top_k, - const int vocab_size_padded, - const int* end_ids, - cudaStream_t stream, - const int batch_size); - -#define CASE_K(K_MIN, K_MAX, BLOCK_SIZE_1_, BLOCK_SIZE_2_, BLOCKS_PER_BEAM_) \ - case K_MIN ... K_MAX: \ - topk_stage_1_opt3 \ - <<>>(logits, \ - temp_logits, \ - topk_tmp_id_buf, \ - topk_tmp_val_buf, \ - finished_buf, \ - candidate_num, \ - vocab_size, \ - end_ids); \ - topk_topp_stage_2_opt3_sampling \ - <<>>(topk_tmp_id_buf, \ - topk_tmp_val_buf, \ - topk_tmp2_val_buf, \ - output_ids, \ - sequence_length, \ - finished_buf, \ - cum_log_probs, \ - output_log_probs, \ - candidate_num, \ - prob_threshold, \ - curandstate, \ - end_ids, \ - vocab_size); \ - break; + const int top_k, + const float top_p, + const int vocab_size_padded, + const int* end_ids, + cudaStream_t stream, + const int batch_size, + const bool* skip_decode); template -void invokeTopKTopPSampling(void* workspace, - size_t& workspace_size, - int* output_ids, - const T* logits, - int* sequence_length, - bool* finished_buf, - float* cum_log_probs, - float* output_log_probs, +void invokeTopKTopPSampling(void* workspace, + size_t& workspace_size, + int* output_ids, + const T* logits, + int* sequence_length, + bool* finished_buf, + float* cum_log_probs, + float* output_log_probs, curandState_t* curandstate, - const int batch_size, - const int top_k, - const T top_p, - const int vocab_size_padded, - const int* end_ids, - cudaStream_t stream) + const int batch_size, + const int top_k, + const float top_p, + const int vocab_size_padded, + const int* end_ids, + cudaStream_t stream) { - // Here, we put batch size as an argument because the batch size of initialization - // and inference may be different due to pipeline parallelism. 
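invokeBatchTopKSampling and its invokeTopKSampling wrapper follow the usual two-call workspace convention: a first call with workspace == nullptr only reports the required size, the caller allocates, and a second call launches the kernels. A hedged usage sketch (buffer arguments are placeholders assumed to be valid device allocations; error checking omitted):

#include <cuda_runtime.h>
#include <curand_kernel.h>

#include "src/fastertransformer/kernels/sampling_topk_kernels.h"

using namespace fastertransformer;

// Illustrative call site; all device pointers are assumed to be prepared by
// the caller and the numeric values are placeholders.
void runTopKSampling(const float* d_log_probs, int* d_output_ids, int* d_sequence_lengths,
                     bool* d_finished, curandState_t* d_curand_states, const int* d_end_ids,
                     int batch_size, int vocab_size_padded, cudaStream_t stream)
{
    size_t workspace_size = 0;
    // First call: workspace == nullptr, so only workspace_size is written.
    invokeTopKSampling<float>(nullptr, workspace_size, d_log_probs, d_output_ids,
                              d_sequence_lengths, d_finished,
                              /*cum_log_probs=*/nullptr, /*output_log_probs=*/nullptr,
                              d_curand_states, /*top_k=*/4, /*top_p=*/1.0f,
                              vocab_size_padded, d_end_ids, stream, batch_size,
                              /*skip_decode=*/nullptr);

    void* d_workspace = nullptr;
    cudaMalloc(&d_workspace, workspace_size);

    // Second call: the same arguments plus a real workspace actually samples.
    invokeTopKSampling<float>(d_workspace, workspace_size, d_log_probs, d_output_ids,
                              d_sequence_lengths, d_finished, nullptr, nullptr,
                              d_curand_states, 4, 1.0f, vocab_size_padded, d_end_ids,
                              stream, batch_size, nullptr);

    cudaFree(d_workspace);
}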
- const int candidate_num = top_k; - const T prob_threshold = top_p; - const int vocab_size = vocab_size_padded; - - const int max_block_per_beam = 8; - int temp_logits_buf_size = batch_size * vocab_size; // type T - int topk_tmp_ids_buf_size = batch_size * candidate_num * max_block_per_beam; // type int - int topk_tmp_val_buf_size = batch_size * candidate_num * max_block_per_beam; // type T - - // prevent memory misalinged address - temp_logits_buf_size = (int)(ceil(temp_logits_buf_size / 4.)) * 4; - topk_tmp_ids_buf_size = (int)(ceil(topk_tmp_ids_buf_size / 4.)) * 4; - topk_tmp_val_buf_size = (int)(ceil(topk_tmp_val_buf_size / 4.)) * 4; - - if (workspace == nullptr) { - workspace_size = sizeof(T) * temp_logits_buf_size + sizeof(int) * topk_tmp_ids_buf_size - + 2 * sizeof(T) * topk_tmp_val_buf_size; - return; - } - else { - T* temp_logits = (T*)workspace; - int* topk_tmp_id_buf = (int*)(temp_logits + temp_logits_buf_size); - T* topk_tmp_val_buf = (T*)(topk_tmp_id_buf + topk_tmp_ids_buf_size); - T* topk_tmp2_val_buf = (T*)(topk_tmp_val_buf + topk_tmp_val_buf_size); - - switch (candidate_num) { - CASE_K(1, 16, 128, 128, 8); - CASE_K(17, 32, 256, 128, 8); - CASE_K(33, 64, 256, 256, 8); - default: - printf("[ERROR] Topk kernel does not support candidate_num = %d \n", candidate_num); - exit(0); - break; - } - return; - } + // invokeTopKTopPSampling will be deprecated. Please use invokeTopKSampling instead. + invokeTopKSampling(workspace, + workspace_size, + logits, + output_ids, + sequence_length, + finished_buf, + cum_log_probs, + output_log_probs, + curandstate, + top_k, + top_p, + vocab_size_padded, + end_ids, + stream, + batch_size, + nullptr); } -template void invokeTopKTopPSampling(void* workspace, - size_t& workspace_size, - int* output_ids, - const float* logits, - int* sequence_length, - bool* finished_buf, - float* cum_log_probs, - float* output_log_probs, +template void invokeTopKTopPSampling(void* workspace, + size_t& workspace_size, + int* output_ids, + const float* logits, + int* sequence_length, + bool* finished_buf, + float* cum_log_probs, + float* output_log_probs, curandState_t* curandstate, - const int batch_size, - const int top_k, - const float top_p, - const int vocab_size_padded, - const int* end_ids, - cudaStream_t stream); - -template void invokeTopKTopPSampling(void* workspace, - size_t& workspace_size, - int* output_ids, - const half* logits, - int* sequence_length, - bool* finished_buf, - float* cum_log_probs, - float* output_log_probs, + const int batch_size, + const int top_k, + const float top_p, + const int vocab_size_padded, + const int* end_ids, + cudaStream_t stream); + +template void invokeTopKTopPSampling(void* workspace, + size_t& workspace_size, + int* output_ids, + const half* logits, + int* sequence_length, + bool* finished_buf, + float* cum_log_probs, + float* output_log_probs, curandState_t* curandstate, - const int batch_size, - const int top_k, - const half top_p, - const int vocab_size_padded, - const int* end_ids, - cudaStream_t stream); + const int batch_size, + const int top_k, + const float top_p, + const int vocab_size_padded, + const int* end_ids, + cudaStream_t stream); + } // namespace fastertransformer diff --git a/src/fastertransformer/kernels/sampling_topk_kernels.h b/src/fastertransformer/kernels/sampling_topk_kernels.h index 0adef4fc3..fd02204ce 100644 --- a/src/fastertransformer/kernels/sampling_topk_kernels.h +++ b/src/fastertransformer/kernels/sampling_topk_kernels.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA 
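// Migration sketch for the deprecation note above: the same launch expressed through the
// old and the new entry point. The argument order changes (logits/output_ids swap,
// batch_size moves to the end) and the new API takes a trailing skip_decode pointer.
// Variable names are hypothetical.
#include <curand_kernel.h>
#include "src/fastertransformer/kernels/sampling_topk_kernels.h"

void topKTopPMigrationSketch(void* d_workspace, size_t workspace_size,
                             int* d_output_ids, const float* d_logits,
                             int* d_sequence_length, bool* d_finished,
                             float* d_cum_log_probs, float* d_output_log_probs,
                             curandState_t* d_curand_states, const int* d_end_ids,
                             int batch_size, int vocab_size_padded, cudaStream_t stream)
{
    using namespace fastertransformer;
    // Deprecated form:
    invokeTopKTopPSampling(d_workspace, workspace_size, d_output_ids, d_logits,
                           d_sequence_length, d_finished, d_cum_log_probs,
                           d_output_log_probs, d_curand_states, batch_size,
                           /*top_k=*/8, /*top_p=*/0.9f, vocab_size_padded, d_end_ids, stream);
    // Equivalent call through the replacement API:
    invokeTopKSampling(d_workspace, workspace_size, d_logits, d_output_ids,
                       d_sequence_length, d_finished, d_cum_log_probs, d_output_log_probs,
                       d_curand_states, /*top_k=*/8, /*top_p=*/0.9f,
                       vocab_size_padded, d_end_ids, stream, batch_size,
                       /*skip_decode=*/nullptr);
}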
CORPORATION. All rights reserved. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2021, NAVER Corp. Authored by CLOVA. * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -21,45 +21,72 @@ namespace fastertransformer { template -void invokeTopKSampling(void* workspace, - size_t& workspace_size, - T* log_probs, - int* ids, - int* sequence_length, - bool* finished_buf, - float* cum_log_probs, - float* output_log_probs, +void invokeTopKSampling(void* workspace, + size_t& workspace_size, + const T* log_probs, + int* ids, + int* sequence_length, + bool* finished_buf, + float* cum_log_probs, + float* output_log_probs, curandState_t* curandstate, - const int top_k, - const int vocab_size_padded, - const int* end_ids, - cudaStream_t stream, - const int batch_size); + const int top_k, + const float top_p, + const int vocab_size_padded, + const int* end_ids, + cudaStream_t stream, + const int batch_size, + const bool* skip_decode); -void invokeCurandInitialize(curandState_t* state, - const size_t batch_size, +template +void invokeBatchTopKSampling(void* workspace, + size_t& workspace_size, + const T* log_probs, + int* ids, + int* sequence_length, + bool* finished, + float* cum_log_probs, + float* output_log_probs, + curandState_t* curandstate, + const int max_top_k, + const int* top_ks, + const float top_p, + const float* top_ps, + const int vocab_size_padded, + const int* end_ids, + cudaStream_t stream, + const int batch_size, + const bool* skip_decode); + +void invokeCurandInitialize(curandState_t* state, + const size_t batch_size, unsigned long long random_seed, - cudaStream_t stream); + cudaStream_t stream); + +void invokeCurandBatchInitialize(curandState_t* states, + const size_t batch_size, + const unsigned long long* random_seeds, + cudaStream_t stream); template void invokeAddBiasEndMask( T* logits, const T* bias, const int* end_ids, const bool* finished, const int m, const int n, cudaStream_t stream); template -void invokeTopKTopPSampling(void* workspace, - size_t& workspace_size, - int* output_ids, - const T* logits, - int* sequence_length, - bool* finished_buf, - float* cum_log_probs, - float* output_log_probs, +void invokeTopKTopPSampling(void* workspace, + size_t& workspace_size, + int* output_ids, + const T* logits, + int* sequence_length, + bool* finished_buf, + float* cum_log_probs, + float* output_log_probs, curandState_t* curandstate, - const int batch_size, - const int top_k, - const T top_p, - const int vocab_size_padded, - const int* end_ids, - cudaStream_t stream); + const int batch_size, + const int top_k, + const float top_p, + const int vocab_size_padded, + const int* end_ids, + cudaStream_t stream); } // namespace fastertransformer diff --git a/src/fastertransformer/kernels/sampling_topp_kernels.cu b/src/fastertransformer/kernels/sampling_topp_kernels.cu index 5b7aa9411..86d1ae4dd 100644 --- a/src/fastertransformer/kernels/sampling_topp_kernels.cu +++ b/src/fastertransformer/kernels/sampling_topp_kernels.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
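// Sketch of preparing the curand states consumed by the sampling kernels declared above.
// invokeCurandInitialize seeds every state with one value; invokeCurandBatchInitialize
// takes one seed per batch entry (assumed here to live in a device array), so requests
// batched together can stay individually reproducible. Names are hypothetical.
#include <vector>
#include <cuda_runtime.h>
#include <curand_kernel.h>
#include "src/fastertransformer/kernels/sampling_topk_kernels.h"

void curandInitSketch(int batch_size, cudaStream_t stream)
{
    using namespace fastertransformer;
    curandState_t* d_states = nullptr;
    cudaMalloc(&d_states, sizeof(curandState_t) * batch_size);

    // One seed for the whole batch:
    invokeCurandInitialize(d_states, batch_size, /*random_seed=*/42ULL, stream);

    // Or one seed per batch entry:
    std::vector<unsigned long long> h_seeds(batch_size);
    for (int i = 0; i < batch_size; ++i) {
        h_seeds[i] = 1000ULL + i;
    }
    unsigned long long* d_seeds = nullptr;
    cudaMalloc(&d_seeds, sizeof(unsigned long long) * batch_size);
    cudaMemcpyAsync(d_seeds, h_seeds.data(), sizeof(unsigned long long) * batch_size,
                    cudaMemcpyHostToDevice, stream);
    invokeCurandBatchInitialize(d_states, batch_size, d_seeds, stream);
}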
@@ -26,8 +26,8 @@ #include "src/fastertransformer/kernels/sampling_topp_kernels.h" #include "src/fastertransformer/utils/cuda_utils.h" -constexpr int ENABLE_SINGLE_PASS_TOP_P = 0; -constexpr float SINGLE_PASS_THRESHOLD = 0.9; +constexpr int ENABLE_SINGLE_PASS_TOP_P = 0; +constexpr float SINGLE_PASS_THRESHOLD = 0.9; namespace fastertransformer { @@ -46,7 +46,8 @@ template using Copy_t = Copy_half_t; template -struct Float_as_int_ {}; +struct Float_as_int_ { +}; template<> struct Float_as_int_ { using Type = uint32_t; @@ -56,10 +57,10 @@ struct Float_as_int_<__half> { using Type = uint16_t; }; -using kernel_params_float = Segmented_topk_kernel_params; +using kernel_params_float = Segmented_topk_kernel_params; using kernel_params_float_1 = Segmented_topk_kernel_params; -using kernel_params_half = Segmented_topk_kernel_params<__half, int32_t, 256, 4>; -using kernel_params_half_1 = Segmented_topk_kernel_params<__half, int32_t, 256, 1>; +using kernel_params_half = Segmented_topk_kernel_params<__half, int32_t, 256, 4>; +using kernel_params_half_1 = Segmented_topk_kernel_params<__half, int32_t, 256, 1>; /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -81,13 +82,13 @@ static inline __device__ float to_float(uint16_t src) // sort one segment per cta template __global__ void blockSortKernel(const T_SCORE* d_keys_in, - T_SCORE* d_keys_out, + T_SCORE* d_keys_out, const int32_t* d_values_in, - int32_t* d_values_out, + int32_t* d_values_out, const int32_t* active_counts, - int num_items_, - int stride_items, - int num_segments) + int num_items_, + int stride_items, + int num_segments) { // Specialize BlockRadixSort for a 1D block typedef cub::BlockRadixSort BlockRadixSort; @@ -127,23 +128,23 @@ __global__ void blockSortKernel(const T_SCORE* d_keys_in, /// block sort kernel template void blockSort(const T_SCORE* d_keys_in, - T_SCORE* d_keys_out, + T_SCORE* d_keys_out, const int32_t* d_values_in, - int32_t* d_values_out, + int32_t* d_values_out, const int32_t* active_counts, - int num_items, - int stride_items, - int num_segments, - cudaStream_t stream) + int num_items, + int stride_items, + int num_segments, + cudaStream_t stream) { if (num_items == 0) { return; } - int kernel_index = div_up(num_items, 128) - 1; + int kernel_index = div_up(num_items, 128) - 1; int warps_per_cta = (kernel_index + 1) * 128 / 32; if (kernel_index > 7) { - kernel_index = 7 + div_up(num_items, 1024) - 1; + kernel_index = 7 + div_up(num_items, 1024) - 1; warps_per_cta = 1024 / 32; } assert(warps_per_cta <= 32); @@ -152,13 +153,13 @@ void blockSort(const T_SCORE* d_keys_in, dim3 grid(num_segments); using kernel_func = void (*)(const T_SCORE* d_keys_in, - T_SCORE* d_keys_out, + T_SCORE* d_keys_out, const int32_t* d_values_in, - int32_t* d_values_out, + int32_t* d_values_out, const int32_t* active_counts, - int num_items, - int stride_items, - int num_segments); + int num_items, + int stride_items, + int num_segments); static const kernel_func kernel_funcs[] = { &blockSortKernel, @@ -200,10 +201,10 @@ struct BlockPrefixCallbackOp { // governs the split between regs and smem constexpr float SMEM_FRACTION = 0.5F; -constexpr float P_EPSILON = 0.01F; +constexpr float P_EPSILON = 0.01F; constexpr int MAX_TOP_K = 3072; -constexpr int WARP_SZ = 32; +constexpr int WARP_SZ = 32; template __global__ __launch_bounds__(Kernel_params::BLOCK_THREADS, @@ -213,14 +214,14 @@ __global__ __launch_bounds__(Kernel_params::BLOCK_THREADS, constexpr int debug_block_id = 26; #endif - using 
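// The Float_as_int_ specializations above map float -> uint32_t and __half -> uint16_t so
// the segmented top-k can work on raw key bits. A small host check of the property this
// relies on: for non-negative IEEE-754 values the unsigned bit pattern is monotone in the
// value, so comparing (or bit-slicing) the integers orders the probabilities correctly.
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstring>

static uint32_t float_bits(float f)
{
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));  // well-defined type pun
    return u;
}

int main()
{
    const float probs[] = {0.0f, 1e-6f, 0.25f, 0.5f, 0.999f, 1.0f};
    for (int i = 0; i + 1 < 6; ++i) {
        // Larger probability <=> larger bit pattern; only true because all inputs are >= 0.
        assert(float_bits(probs[i]) < float_bits(probs[i + 1]));
    }
    std::printf("bit pattern of 0.5f = 0x%08x\n", float_bits(0.5f));
    return 0;
}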
Key_Data_Type = typename Kernel_params::Key_Data_Type; + using Key_Data_Type = typename Kernel_params::Key_Data_Type; using Int_Key_Data_Type = typename Float_as_int_::Type; // 4 fp16 keys or 2 fp32 keys - constexpr int KEYS_PER_LDG = Kernel_params::KEYS_PER_LDG; + constexpr int KEYS_PER_LDG = Kernel_params::KEYS_PER_LDG; typedef Copy_t copy_t; union access_t { - copy_t v; + copy_t v; Int_Key_Data_Type x[KEYS_PER_LDG]; // supported size 1,2,4 }; @@ -238,11 +239,11 @@ __global__ __launch_bounds__(Kernel_params::BLOCK_THREADS, } #endif - constexpr int MIN_KEY = 0; - constexpr int ENABLED_PER_THREAD = (ITEMS_PER_THREAD + 32 - 1) / 32; + constexpr int MIN_KEY = 0; + constexpr int ENABLED_PER_THREAD = (ITEMS_PER_THREAD + 32 - 1) / 32; extern __shared__ int2 dynamic_smem[]; - int2* smem_selected_elements = dynamic_smem; - Int_Key_Data_Type* smem_thread_items = reinterpret_cast(smem_selected_elements + MAX_TOP_K); + int2* smem_selected_elements = dynamic_smem; + Int_Key_Data_Type* smem_thread_items = reinterpret_cast(smem_selected_elements + MAX_TOP_K); __shared__ unsigned int smem_selected_count; @@ -251,7 +252,7 @@ __global__ __launch_bounds__(Kernel_params::BLOCK_THREADS, // Specialize BlockScan type for our thread block typedef cub::BlockReduce BlockReduce; - __shared__ float smem_p_sum_total; + __shared__ float smem_p_sum_total; __shared__ union { typename BlockScan::TempStorage scan; @@ -265,7 +266,7 @@ __global__ __launch_bounds__(Kernel_params::BLOCK_THREADS, uint32_t segment = blockIdx.y * gridDim.x + blockIdx.x; - // Preceeding TopK has shortcutted this segment + // Preceding TopK has shortcutted this segment if (params.gmem_begin_offsets[segment] == params.gmem_end_offsets[segment]) { if (threadIdx.x == 0) { params.gmem_active_count_per_segment[segment] = 1; @@ -276,22 +277,22 @@ __global__ __launch_bounds__(Kernel_params::BLOCK_THREADS, Int_Key_Data_Type* gmem_src_keys = reinterpret_cast(params.gmem_src_keys); Int_Key_Data_Type* gmem_dst_keys = reinterpret_cast(params.gmem_dst_keys); - int32_t* gmem_dst_vals = reinterpret_cast(params.gmem_dst_vals); + int32_t* gmem_dst_vals = reinterpret_cast(params.gmem_dst_vals); constexpr int BITS_IN_KEY = sizeof(Key_Data_Type) * 8; - int items = params.num_items / params.num_segments; + int items = params.num_items / params.num_segments; int first_index = segment * items; gmem_src_keys += first_index; gmem_dst_keys += first_index; gmem_dst_vals += first_index; - int index_limit = items; + int index_limit = items; Int_Key_Data_Type thread_items[ITEMS_PER_THREAD_IN_REGS] = {0}; // Load all keys into registers and smem - const int lane_id = threadIdx.x % WARP_SZ; - const int warp_id = threadIdx.x / WARP_SZ; + const int lane_id = threadIdx.x % WARP_SZ; + const int warp_id = threadIdx.x / WARP_SZ; constexpr int NUM_WARPS = BLOCK_THREADS / WARP_SZ; access_t ZERO; @@ -301,14 +302,14 @@ __global__ __launch_bounds__(Kernel_params::BLOCK_THREADS, // registers for (int iter = 0; iter < ITEMS_PER_THREAD_IN_REGS; iter++) { - int offset = (iter + threadIdx.x * ITEMS_PER_THREAD); + int offset = (iter + threadIdx.x * ITEMS_PER_THREAD); thread_items[iter] = (offset < index_limit) ? 
gmem_src_keys[offset] : MIN_KEY; } // shared memory for (int c = warp_id; c < BLOCK_THREADS; c += NUM_WARPS) { for (int iter = lane_id * KEYS_PER_LDG; iter < ITEMS_PER_THREAD_IN_SMEM; iter += WARP_SZ * KEYS_PER_LDG) { - int offset = iter + c * ITEMS_PER_THREAD + ITEMS_PER_THREAD_IN_REGS; + int offset = iter + c * ITEMS_PER_THREAD + ITEMS_PER_THREAD_IN_REGS; access_t val; val.v = (offset < index_limit) ? *reinterpret_cast(&gmem_src_keys[offset]) : ZERO.v; for (int i = 0; i < KEYS_PER_LDG; i++) { @@ -319,7 +320,7 @@ __global__ __launch_bounds__(Kernel_params::BLOCK_THREADS, } Int_Key_Data_Type select_mask = 0; - Int_Key_Data_Type save_mask = 0; + Int_Key_Data_Type save_mask = 0; // Int_Key_Data_Type save_bit = 0; // set to true when we finish with too few keys, so we go back to last_save_mask one more time @@ -327,7 +328,7 @@ __global__ __launch_bounds__(Kernel_params::BLOCK_THREADS, if (threadIdx.x == 0) { smem_selected_count = 0; - old_selected_count = 0; + old_selected_count = 0; } // iterate over bits. @@ -335,25 +336,25 @@ __global__ __launch_bounds__(Kernel_params::BLOCK_THREADS, // * bit 31 is the sign bit. all values are positive // * bit 30 is only set for values >= 2, but the input consists only of values in // the range of [0,1] - constexpr int START_BIT = BITS_IN_KEY - 1; - constexpr int SKIP_BITS = 2; - constexpr Int_Key_Data_Type ONE = (Int_Key_Data_Type)1; - uint32_t selected; - uint32_t sc; - float p_sum_total = 0.0F; - float old_p_sum_total = 0.0F; - uint32_t offset = 0; + constexpr int START_BIT = BITS_IN_KEY - 1; + constexpr int SKIP_BITS = 2; + constexpr Int_Key_Data_Type ONE = (Int_Key_Data_Type)1; + uint32_t selected; + uint32_t sc; + float p_sum_total = 0.0F; + float old_p_sum_total = 0.0F; + uint32_t offset = 0; for (Int_Key_Data_Type bit = START_BIT - SKIP_BITS; true; --bit) { __syncthreads(); Int_Key_Data_Type bit_mask = select_mask | (ONE << bit); uint32_t enabled[ENABLED_PER_THREAD] = {0}; - float thread_sum = 0.0F; + float thread_sum = 0.0F; for (int item = 0; item < ITEMS_PER_THREAD_IN_REGS; ++item) { // check if all the bits from bit mask are contained in the thread_item. If yes, set respective // bit of enabled - auto val = thread_items[item]; + auto val = thread_items[item]; uint32_t is_enabled = uint32_t(((val ^ bit_mask) & bit_mask) == 0); // thread_sum += (is_enabled)? to_float(val) : 0.0F; thread_sum += is_enabled * to_float(val); @@ -363,7 +364,7 @@ __global__ __launch_bounds__(Kernel_params::BLOCK_THREADS, for (int item = 0; item < ITEMS_PER_THREAD_IN_SMEM; ++item) { int idx = threadIdx.x + item * BLOCK_THREADS; // int idx = item + ITEMS_PER_THREAD_IN_SMEM * threadIdx.x; - auto val = smem_thread_items[idx]; + auto val = smem_thread_items[idx]; uint32_t is_enabled = uint32_t(((val ^ bit_mask) & bit_mask) == 0); // thread_sum += (is_enabled)? 
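// The selection predicate used in the loop above, ((val ^ bit_mask) & bit_mask) == 0, is a
// bit-containment test: it holds exactly when every bit of the current radix prefix is set
// in the key, i.e. (val & bit_mask) == bit_mask. A tiny host check of that equivalence:
#include <cassert>
#include <cstdint>

int main()
{
    for (uint32_t val = 0; val < 256; ++val) {
        for (uint32_t mask = 0; mask < 256; ++mask) {
            bool xor_form    = ((val ^ mask) & mask) == 0;
            bool subset_form = (val & mask) == mask;
            assert(xor_form == subset_form);  // both mean "mask's bits are a subset of val's bits"
        }
    }
    return 0;
}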
to_float(val) : 0.0F; thread_sum += is_enabled * to_float(val); @@ -476,7 +477,7 @@ __global__ __launch_bounds__(Kernel_params::BLOCK_THREADS, if (threadIdx.x == 0) { smem_selected_count = old_selected_count; - p_sum_total = old_p_sum_total; + p_sum_total = old_p_sum_total; prefix_op.running_total = old_selected_count; } @@ -490,7 +491,7 @@ __global__ __launch_bounds__(Kernel_params::BLOCK_THREADS, } if (threadIdx.x == 0) { old_selected_count = smem_selected_count; - old_p_sum_total = p_sum_total; + old_p_sum_total = p_sum_total; } } } @@ -508,8 +509,8 @@ __global__ __launch_bounds__(Kernel_params::BLOCK_THREADS, } for (int i = threadIdx.x; i < sc; i += blockDim.x) { int2 selected_element = smem_selected_elements[i]; - gmem_dst_keys[i] = selected_element.x; - gmem_dst_vals[i] = selected_element.y; + gmem_dst_keys[i] = selected_element.x; + gmem_dst_vals[i] = selected_element.y; } } @@ -518,23 +519,21 @@ __global__ __launch_bounds__(Kernel_params::BLOCK_THREADS, template int getSmemSizeAndCheck(const TopKPerSegmentContext& context, const TopKPerSegmentParams& params) { - constexpr int BLOCK_THREADS = Kernel_params::BLOCK_THREADS; - using Key_Data_Type = typename Kernel_params::Key_Data_Type; - int num_items_per_segment = params.num_items / params.num_segments; - constexpr int ITEMS_INCREMENT = Kernel_params::ITEMS_INCREMENT; - int kernel_index = div_up(num_items_per_segment, BLOCK_THREADS * ITEMS_INCREMENT) - 1; - - int smem_size = MAX_TOP_K * sizeof(int2); - const int items_per_thread = (kernel_index + 1) * ITEMS_INCREMENT; + constexpr int BLOCK_THREADS = Kernel_params::BLOCK_THREADS; + using Key_Data_Type = typename Kernel_params::Key_Data_Type; + int num_items_per_segment = params.num_items / params.num_segments; + constexpr int ITEMS_INCREMENT = Kernel_params::ITEMS_INCREMENT; + int kernel_index = div_up(num_items_per_segment, BLOCK_THREADS * ITEMS_INCREMENT) - 1; + + int smem_size = MAX_TOP_K * sizeof(int2); + const int items_per_thread = (kernel_index + 1) * ITEMS_INCREMENT; const int items_per_thread_in_regs = items_per_thread * (1.0F - SMEM_FRACTION); const int items_per_thread_in_smem = items_per_thread - items_per_thread_in_regs; - constexpr int KEYS_PER_LDG = Kernel_params::KEYS_PER_LDG; - smem_size += items_per_thread_in_smem * BLOCK_THREADS * sizeof(typename Float_as_int_::Type); int keys_per_ldg = 2 * sizeof(Key_Data_Type) / 2; - if (smem_size + BLOCK_THREADS * sizeof(float) > context.sm_shared_size || // dynamic + static memory + if (smem_size + BLOCK_THREADS * sizeof(float) > (size_t)context.sm_shared_size || // dynamic + static memory items_per_thread_in_regs + items_per_thread_in_smem != items_per_thread || params.top_p + P_EPSILON > 1.0F || items_per_thread_in_regs % keys_per_ldg != 0 || items_per_thread_in_smem % keys_per_ldg != 0 || num_items_per_segment % keys_per_ldg != 0) { @@ -547,8 +546,8 @@ int getSmemSizeAndCheck(const TopKPerSegmentContext& context, const TopKPerSegme /////////////////////////////////////////////////////////////////////////////////////////////////// int getSmemSizeAndCheck(const TopKPerSegmentContext& context, - const TopKPerSegmentParams& params, - const DType_t DT_SCORE) + const TopKPerSegmentParams& params, + const DType_t DT_SCORE) { int num_items_per_segment = params.num_items / params.num_segments; if (DT_SCORE == kFLOAT) { @@ -572,19 +571,19 @@ int getSmemSizeAndCheck(const TopKPerSegmentContext& context, /////////////////////////////////////////////////////////////////////////////////////////////////// template -void 
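// Host restatement of the shared-memory feasibility check in getSmemSizeAndCheck above,
// handy for reasoning about when the single-pass top-p kernel can run at all. MAX_TOP_K and
// SMEM_FRACTION mirror the constants defined earlier in this file; BLOCK_THREADS,
// ITEMS_INCREMENT and the vocabulary size below are illustrative assumptions, and the other
// feasibility conditions (load alignment, top_p + P_EPSILON <= 1) are omitted.
#include <cstdio>

static int div_up(int a, int b) { return (a + b - 1) / b; }

int main()
{
    const int    BLOCK_THREADS     = 256;        // assumed, as in the kernel_params aliases above
    const int    ITEMS_INCREMENT   = 4;          // assumed for illustration
    const int    MAX_TOP_K         = 3072;
    const float  SMEM_FRACTION     = 0.5f;
    const int    vocab_per_segment = 51200;      // example padded vocabulary
    const size_t sm_shared_size    = 164 * 1024; // e.g. shared memory per SM on an A100

    int kernel_index     = div_up(vocab_per_segment, BLOCK_THREADS * ITEMS_INCREMENT) - 1;
    int items_per_thread = (kernel_index + 1) * ITEMS_INCREMENT;
    int in_regs          = (int)(items_per_thread * (1.0f - SMEM_FRACTION));
    int in_smem          = items_per_thread - in_regs;

    size_t smem = MAX_TOP_K * (2 * sizeof(int))                            // int2 (key, index) pairs
                + (size_t)in_smem * BLOCK_THREADS * sizeof(unsigned int);  // fp32 keys kept in smem

    bool fits = smem + BLOCK_THREADS * sizeof(float) <= sm_shared_size;
    std::printf("items/thread=%d (regs=%d, smem=%d), smem bytes=%zu, fits=%d\n",
                items_per_thread, in_regs, in_smem, smem, (int)fits);
    return 0;
}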
segmentedTopPSinglePass_dispatch(const TopKPerSegmentParams& params, +void segmentedTopPSinglePass_dispatch(const TopKPerSegmentParams& params, const TopKPerSegmentContext& context, - cudaStream_t stream) + cudaStream_t stream) { constexpr int BLOCK_THREADS = Kernel_params::BLOCK_THREADS; - using Key_Data_Type = typename Kernel_params::Key_Data_Type; - using Value_Data_Type = typename Kernel_params::Value_Data_Type; + using Key_Data_Type = typename Kernel_params::Key_Data_Type; + using Value_Data_Type = typename Kernel_params::Value_Data_Type; int num_items_per_segment = params.num_items / params.num_segments; constexpr int ITEMS_INCREMENT = Kernel_params::ITEMS_INCREMENT; - int kernel_index = div_up(num_items_per_segment, BLOCK_THREADS * ITEMS_INCREMENT) - 1; + int kernel_index = div_up(num_items_per_segment, BLOCK_THREADS * ITEMS_INCREMENT) - 1; #define KERNEL_RUN(INDEX) \ { \ @@ -627,13 +626,13 @@ void segmentedTopPSinglePass_dispatch(const TopKPerSegmentParams& params, template void topPPerSegment_dispatch(const TopKPerSegmentContext& context, - TopKPerSegmentParams& params, - void* temp_storage, - size_t& temp_storage_bytes, - cudaStream_t stream) + TopKPerSegmentParams& params, + void* temp_storage, + size_t& temp_storage_bytes, + cudaStream_t stream) { - using Key_Data_Type = typename Kernel_params::Key_Data_Type; + using Key_Data_Type = typename Kernel_params::Key_Data_Type; using Value_Data_Type = typename Kernel_params::Value_Data_Type; if (temp_storage == nullptr) { @@ -674,7 +673,7 @@ void topPPerSegment_dispatch(const TopKPerSegmentContext& context, size_t cub_temp_storage_bytes = temp_storage_bytes - div_up(sizeof(int), 256) * 256 - div_up(sizeof(int) * params.num_segments, 256) * 256; - void* cub_temp_storage = temp_storage; + void* cub_temp_storage = temp_storage; params.gmem_active_count_total = reinterpret_cast((char*)temp_storage + cub_temp_storage_bytes); params.gmem_active_count_per_segment = reinterpret_cast((char*)params.gmem_active_count_total + div_up(sizeof(int), 256) * 256); @@ -682,6 +681,7 @@ void topPPerSegment_dispatch(const TopKPerSegmentContext& context, int num_items_per_segment = params.num_items / params.num_segments; cudaMemsetAsync(params.gmem_active_count_total, 0, sizeof(int), stream); + cudaMemsetAsync(params.gmem_dst_keys, 0, params.num_items * sizeof(Key_Data_Type), stream); segmentedTopPSinglePass_dispatch(params, context, stream); int max_num_items = 0; @@ -735,11 +735,11 @@ void topPPerSegment_dispatch(const TopKPerSegmentContext& context, /////////////////////////////////////////////////////////////////////////////////////////////////// int topPPerSegment(const TopKPerSegmentContext& context, - TopKPerSegmentParams& params, - const DType_t DT_SCORE, - void* temp_storage, - size_t& temp_storage_bytes, - cudaStream_t stream) + TopKPerSegmentParams& params, + const DType_t DT_SCORE, + void* temp_storage, + size_t& temp_storage_bytes, + cudaStream_t stream) { int num_items_per_segment = params.num_items / params.num_segments; if (DT_SCORE == kFLOAT) { @@ -774,7 +774,7 @@ __global__ void topPInitialize( if (bid == 0) { for (int i = tid; i < batch_size + 1; i += blockDim.x) { - topp_offset_buf[i] = i * n; + topp_offset_buf[i] = i * n; begin_topp_offset_buf_[i] = topp_offset_buf[i]; } } @@ -787,35 +787,41 @@ __global__ void topPInitialize( } } -void invokeTopPInitialize(int* topp_id_val_buf, - int* topp_offset_buf, - int* begin_topp_offset_buf_, +void invokeTopPInitialize(int* topp_id_val_buf, + int* topp_offset_buf, + int* begin_topp_offset_buf_, 
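// The temp-storage carving in topPPerSegment_dispatch above places the cub scratch first
// and then two int buffers, each rounded up to a 256-byte boundary. A minimal host sketch
// of the same sub-allocation pattern (sizes and names are illustrative; the real code works
// on device memory):
#include <cstdio>
#include <cstdlib>

static size_t round_up_256(size_t bytes) { return (bytes + 255) / 256 * 256; }

int main()
{
    const int num_segments = 8;      // e.g. the batch size
    size_t cub_bytes   = 1 << 20;    // whatever the cub size query reported
    size_t total_bytes = cub_bytes
                       + round_up_256(sizeof(int))                  // gmem_active_count_total
                       + round_up_256(sizeof(int) * num_segments);  // gmem_active_count_per_segment

    char* workspace                = static_cast<char*>(std::malloc(total_bytes));
    void* cub_temp_storage         = workspace;
    int*  active_count_total       = reinterpret_cast<int*>(workspace + cub_bytes);
    int*  active_count_per_segment = reinterpret_cast<int*>(
        reinterpret_cast<char*>(active_count_total) + round_up_256(sizeof(int)));

    std::printf("cub=%p total=%p per_segment=%p (workspace bytes=%zu)\n",
                cub_temp_storage, (void*)active_count_total,
                (void*)active_count_per_segment, total_bytes);
    std::free(workspace);
    return 0;
}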
const size_t batch_size, - const int n, + const int n, cudaStream_t stream) { - // n: the coloumn number of logits_buffer for top_p sampling + // n: the column number of logits_buffer for top_p sampling topPInitialize<<<32, 512, 0, stream>>>(topp_id_val_buf, topp_offset_buf, begin_topp_offset_buf_, batch_size, n); } template -__launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_topK_kernel_for_topP(const T* log_probs, - int* topk_tmp_id_buf, - T* topk_tmp_val_buf, - const int vocab_size, - int* offset_buf, - int* begin_offset_buf, - float p_threshold) +__launch_bounds__(THREADBLOCK_SIZE) __global__ void topp_beam_topk_kernel(const T* log_probs, // prob. + int* topk_tmp_id_buf, + T* topk_tmp_val_buf, + const int vocab_size, + int* offset_buf, + int* begin_offset_buf, + const float top_p, + const float* top_ps, + const bool* skip_decode) { - typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - int thread_id = threadIdx.x; - int block_id = blockIdx.x; - TopK partial; + int batch_id = blockIdx.x; + if (skip_decode != nullptr && skip_decode[batch_id]) { + return; + } + float p_threshold = (top_ps != nullptr) ? top_ps[batch_id] : top_p; - const bool IS_FP16 = std::is_same::value; - const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; + typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + TopK partial; + + const bool IS_FP16 = std::is_same::value; + const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; #pragma unroll for (int i = 0; i < MAX_K; ++i) { @@ -825,15 +831,15 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_topK_kernel_for_topP(co #pragma unroll for (int elem_id = thread_id; elem_id < vocab_size; elem_id += THREADBLOCK_SIZE) { - int index = elem_id + block_id * vocab_size; + int index = elem_id + batch_id * vocab_size; partial.insert(log_probs[index], index); } TopK total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op); if (thread_id == 0) { - begin_offset_buf[block_id] = offset_buf[block_id]; - T sum_prob = (T)(0.0f); + begin_offset_buf[batch_id] = offset_buf[batch_id]; + T sum_prob = (T)(0.0f); #pragma unroll for (int i = 0; i < MAX_K; i++) { @@ -841,12 +847,12 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_topK_kernel_for_topP(co } if ((float)sum_prob >= p_threshold) { - begin_offset_buf[block_id] += vocab_size; - int index = block_id * vocab_size; + begin_offset_buf[batch_id] += vocab_size; + int index = batch_id * vocab_size; #pragma unroll for (int i = 0; i < MAX_K; ++i) { - topk_tmp_id_buf[index + i] = total.p[i] % vocab_size; + topk_tmp_id_buf[index + i] = total.p[i] % vocab_size; topk_tmp_val_buf[index + i] = total.u[i]; } } @@ -869,58 +875,74 @@ struct BlockPrefixCallbackOp { }; template -__global__ void top_p_sampling_opt(T* sorted_log_probs, - int* sorted_id_vals, - int* ids, - int* sequence_length, - bool* finished_buf, - float* cum_log_probs, - float* output_log_probs, - const int* begin_offset_buf, - const int* offset_buf, - const int vocab_size, - curandState_t* curandstate, - const float prob_threshold, - const int* end_ids, - const int batch_size) +__global__ void topp_sampling(T* sorted_log_probs, + int* sorted_id_vals, + int* ids, + int* sequence_length, + bool* finished_buf, + float* cum_log_probs, + float* output_log_probs, + const int* begin_offset_buf, + const int* offset_buf, + const int vocab_size, + curandState_t* curandstate, + const float top_p, + const float* top_ps, + const int* 
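// topp_beam_topk_kernel above short-circuits a segment whenever its MAX_K largest
// probabilities already cover the segment's top-p mass: the top-k list is written out and
// begin_offset is advanced past the segment so the radix sort and the sampling kernel skip
// it. A small host analogue of that decision (MAX_K, top_p and the inputs are illustrative):
#include <algorithm>
#include <cstdio>
#include <functional>
#include <vector>

int main()
{
    const int   MAX_K = 4;
    const float top_p = 0.75f;
    std::vector<float> probs = {0.40f, 0.05f, 0.30f, 0.02f, 0.15f, 0.08f};  // sums to 1.0

    std::vector<float> sorted = probs;
    std::partial_sort(sorted.begin(), sorted.begin() + MAX_K, sorted.end(), std::greater<float>());

    float topk_mass = 0.0f;
    for (int i = 0; i < MAX_K; ++i) {
        topk_mass += sorted[i];
    }
    // 0.40 + 0.30 + 0.15 + 0.08 = 0.93 >= 0.75, so sorting the full vocabulary is unnecessary.
    std::printf("top-%d mass = %.2f, shortcut = %s\n", MAX_K, topk_mass,
                topk_mass >= top_p ? "yes" : "no");
    return 0;
}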
end_ids, + const int batch_size, + const bool* skip_decode) { - // Generate the random number every round to prevent that we may - // skip the curand_uniform sometimes and the seed would be different. - __shared__ int stop_shared; + __shared__ int stop_shared; __shared__ float rand_num_s; + const int tid = threadIdx.x; + const int batch_id = blockIdx.x; + if (skip_decode != nullptr && skip_decode[batch_id]) { + return; + } + + constexpr int WARP_SIZE = 32; + constexpr int NUM_WARPS = BLOCK_SIZE / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + const int warp_id = threadIdx.x / WARP_SIZE; + const float prob_threshold = (top_ps != nullptr) ? top_ps[batch_id] : top_p; + if (threadIdx.x == 0) { stop_shared = 0; - rand_num_s = curand_uniform(curandstate + blockIdx.x) * prob_threshold; + rand_num_s = curand_uniform(curandstate + blockIdx.x) * prob_threshold; } // if begin_offset_buf and offset_buf of sorting have same value, // this means that we have find best one in beam_topK_kernel_for_topP // and skip the sorting. So, we can skip then during sampling. - if (begin_offset_buf[blockIdx.x] == offset_buf[blockIdx.x]) { - if (threadIdx.x == 0) { - ids[blockIdx.x] = sorted_id_vals[blockIdx.x * vocab_size]; - if (sequence_length != nullptr && finished_buf != nullptr) { + if (begin_offset_buf[batch_id] == offset_buf[batch_id]) { + if (tid == 0) { + int offset = batch_id * vocab_size; + ids[batch_id] = sorted_id_vals[offset]; - sequence_length[blockIdx.x] = - finished_buf[blockIdx.x] ? sequence_length[blockIdx.x] : sequence_length[blockIdx.x] + 1; - finished_buf[blockIdx.x] = ids[blockIdx.x] == end_ids[blockIdx.x] ? 1 : 0; + if (cum_log_probs != nullptr || output_log_probs != nullptr) { + float lprob = logf(sorted_log_probs[offset]); + if (cum_log_probs != nullptr) { + cum_log_probs[batch_id] += lprob; + } + if (output_log_probs != nullptr) { + output_log_probs[batch_id] = lprob; + } + } + if (sequence_length != nullptr && finished_buf != nullptr) { + sequence_length[batch_id] = + finished_buf[batch_id] ? sequence_length[batch_id] : sequence_length[batch_id] + 1; + finished_buf[batch_id] = ids[batch_id] == end_ids[batch_id] ? 1 : 0; } } return; } - constexpr int WARP_SIZE = 32; - constexpr int NUM_WARPS = BLOCK_SIZE / WARP_SIZE; - typedef cub::BlockScan BlockScan; + typedef cub::BlockScan BlockScan; __shared__ typename BlockScan::TempStorage temp_storage; - __shared__ uint32_t selected_shared[NUM_WARPS]; + __shared__ uint32_t selected_shared[NUM_WARPS]; // Initialize running total BlockPrefixCallbackOp prefix_op(0); - const int tid = threadIdx.x; - const int batch_id = blockIdx.x; - const int lane_id = threadIdx.x % WARP_SIZE; - const int warp_id = threadIdx.x / WARP_SIZE; if (lane_id == 0) { selected_shared[warp_id] = 0; @@ -928,10 +950,10 @@ __global__ void top_p_sampling_opt(T* sorted_log_probs, __syncthreads(); - int offset = batch_id * vocab_size; - ids[batch_id] = sorted_id_vals[offset]; - int end = ((vocab_size + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE; - int i_active = 0; + int offset = batch_id * vocab_size; + ids[batch_id] = sorted_id_vals[offset]; + int end = ((vocab_size + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE; + int i_active = 0; float thread_offset = 0; for (int i = tid; i < end; i += BLOCK_SIZE) { float thread_count = (i < vocab_size) ? 
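// Given probabilities sorted in descending order, the sampling step above draws
// u = uniform() * top_p and returns the first token whose running prefix sum exceeds u,
// which samples proportionally from the top-p nucleus without an explicit renormalization.
// A host sketch of that selection rule (the distribution and top_p are illustrative):
#include <cstdio>
#include <random>
#include <vector>

static int sample_top_p(const std::vector<float>& sorted_probs,  // descending order
                        float top_p, std::mt19937& gen)
{
    std::uniform_real_distribution<float> dist(0.0f, 1.0f);
    float u   = dist(gen) * top_p;
    float sum = 0.0f;
    for (size_t i = 0; i < sorted_probs.size(); ++i) {
        sum += sorted_probs[i];
        if (sum > u) {
            return (int)i;  // index in sorted order; the GPU maps it back through sorted_id_vals
        }
    }
    return (int)sorted_probs.size() - 1;
}

int main()
{
    std::mt19937 gen(0);
    std::vector<float> sorted_probs = {0.5f, 0.3f, 0.1f, 0.06f, 0.04f};
    int counts[5] = {0, 0, 0, 0, 0};
    for (int i = 0; i < 10000; ++i) {
        counts[sample_top_p(sorted_probs, /*top_p=*/0.8f, gen)]++;
    }
    // With top_p = 0.8 only the first two tokens can be drawn (0.5 + 0.3 covers the nucleus).
    std::printf("counts: %d %d %d %d %d\n", counts[0], counts[1], counts[2], counts[3], counts[4]);
    return 0;
}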
(float)sorted_log_probs[offset + i] : 0.f; @@ -969,9 +991,6 @@ __global__ void top_p_sampling_opt(T* sorted_log_probs, cum_log_probs[batch_id] += lprob; } if (output_log_probs != nullptr) { - // TODO(jaedeokk): Check if should we normalize. - // lprob is the probability of the token, not the probability - // induced by the top-p decoding. output_log_probs[batch_id] = lprob; } } @@ -985,70 +1004,71 @@ __global__ void top_p_sampling_opt(T* sorted_log_probs, } template -void invokeTopPSampling(void* workspace, - size_t& workspace_size, - size_t& cub_temp_storage_size, - int* output_ids, - int* sequence_length, - bool* finished_buf, - float* cum_log_probs, - float* output_log_probs, - const T* log_probs, - const int* id_vals, - int* offset_buf, - int* begin_offset_buf, - curandState_t* curandstate, - const int batch_size, - const size_t vocab_size_padded, - const int* end_ids, - const float top_p, - cudaStream_t stream, - cudaDeviceProp* cuda_device_prop) +void invokeBatchTopPSampling(void* workspace, + size_t& workspace_size, + size_t& cub_temp_storage_size, + int* output_ids, + int* sequence_length, + bool* finished_buf, + float* cum_log_probs, + float* output_log_probs, + const T* log_probs, + const int* id_vals, + int* offset_buf, + int* begin_offset_buf, + curandState_t* curandstate, + const int batch_size, + const size_t vocab_size_padded, + const int* end_ids, + const float max_top_p, + const float* top_ps, + cudaStream_t stream, + cudaDeviceProp* cuda_device_prop, + const bool* skip_decode) { // Here, we put batch size as an argument because the batch size of initialization - // and inference may be different due to pipelint parallelism. + // and inference may be different due to pipeline parallelism. const int vocab_size = vocab_size_padded; const int block_size = 256; - size_t sorted_log_prob_buf_size = batch_size * vocab_size * sizeof(T); // type T - size_t sorted_id_vals_buf_size = batch_size * vocab_size * sizeof(int); // type int - sorted_log_prob_buf_size = div_up(sorted_log_prob_buf_size, 256) * 256; - sorted_id_vals_buf_size = div_up(sorted_id_vals_buf_size, 256) * 256; + size_t sorted_log_prob_buf_size = batch_size * vocab_size * sizeof(T); // type T + size_t sorted_id_vals_buf_size = batch_size * vocab_size * sizeof(int); // type int + sorted_log_prob_buf_size = div_up(sorted_log_prob_buf_size, 256) * 256; + sorted_id_vals_buf_size = div_up(sorted_id_vals_buf_size, 256) * 256; void* cub_temp_storage = workspace; - T* sorted_log_probs = (T*)((char*)cub_temp_storage + cub_temp_storage_size); - int* sorted_id_vals = (int*)((char*)sorted_log_probs + sorted_log_prob_buf_size); + T* sorted_log_probs = (T*)((char*)cub_temp_storage + cub_temp_storage_size); + int* sorted_id_vals = (int*)((char*)sorted_log_probs + sorted_log_prob_buf_size); - bool do_radix_sort = (ENABLE_SINGLE_PASS_TOP_P == 0 || top_p >= SINGLE_PASS_THRESHOLD); - int smem_size = -1; + bool do_radix_sort = (ENABLE_SINGLE_PASS_TOP_P == 0 || max_top_p >= SINGLE_PASS_THRESHOLD); + int smem_size = -1; - FT_CHECK(cuda_device_prop != nullptr); segmented_topp_impl::TopKPerSegmentContext context; - segmented_topp_impl::TopKPerSegmentParams params; - segmented_topp_impl::DType_t dataTypeKind = + segmented_topp_impl::TopKPerSegmentParams params; + segmented_topp_impl::DType_t dataTypeKind = (std::is_same::value) ? 
segmented_topp_impl::kFLOAT : segmented_topp_impl::kHALF; if (!do_radix_sort) { + FT_CHECK(cuda_device_prop != nullptr); memset(&context, 0, sizeof(context)); - context.sm_count = cuda_device_prop->multiProcessorCount; + context.sm_count = cuda_device_prop->multiProcessorCount; context.sm_shared_size = cuda_device_prop->sharedMemPerMultiprocessor; - context.sm_version = cuda_device_prop->major * 100 + cuda_device_prop->minor * 10; + context.sm_version = cuda_device_prop->major * 100 + cuda_device_prop->minor * 10; memset(¶ms, 0, sizeof(params)); - params.gmem_src_keys = reinterpret_cast(const_cast(log_probs)); - params.gmem_dst_keys = sorted_log_probs; - params.gmem_src_vals = reinterpret_cast(const_cast(id_vals)); - params.gmem_dst_vals = reinterpret_cast(sorted_id_vals); - params.gmem_begin_offsets = begin_offset_buf; - params.gmem_end_offsets = offset_buf + 1; - params.workspace = nullptr; - params.num_items = vocab_size * batch_size; - params.num_segments = batch_size; - params.top_p = top_p; + params.gmem_src_keys = reinterpret_cast(const_cast(log_probs)); + params.gmem_dst_keys = sorted_log_probs; + params.gmem_src_vals = reinterpret_cast(const_cast(id_vals)); + params.gmem_dst_vals = reinterpret_cast(sorted_id_vals); + params.gmem_begin_offsets = begin_offset_buf; + params.gmem_end_offsets = offset_buf + 1; + params.workspace = nullptr; + params.num_items = vocab_size * batch_size; + params.num_segments = batch_size; + params.top_p = max_top_p; params.confidence_threshold = 0.0F; - smem_size = getSmemSizeAndCheck(context, params, dataTypeKind); - + smem_size = getSmemSizeAndCheck(context, params, dataTypeKind); do_radix_sort = smem_size < 0; } @@ -1069,11 +1089,19 @@ void invokeTopPSampling(void* workspace, sizeof(T) * 8, // end_bit = sizeof(KeyT) * 8 stream)); // cudaStream_t cub_temp_storage_size = div_up(cub_temp_storage_size, 256) * 256; - workspace_size = sorted_log_prob_buf_size + sorted_id_vals_buf_size + cub_temp_storage_size; + workspace_size = sorted_log_prob_buf_size + sorted_id_vals_buf_size + cub_temp_storage_size; return; } - beam_topK_kernel_for_topP<<>>( - log_probs, sorted_id_vals, sorted_log_probs, vocab_size, offset_buf, begin_offset_buf, top_p); + + topp_beam_topk_kernel<<>>(log_probs, + sorted_id_vals, + sorted_log_probs, + vocab_size, + offset_buf, + begin_offset_buf, + max_top_p, + top_ps, + skip_decode); check_cuda_error( cub::DeviceSegmentedRadixSort::SortPairsDescending(cub_temp_storage, @@ -1091,7 +1119,6 @@ void invokeTopPSampling(void* workspace, stream)); // cudaStream_t } else { - if (workspace == nullptr) { segmented_topp_impl::topPPerSegment( context, params, dataTypeKind, cub_temp_storage, cub_temp_storage_size, stream); @@ -1099,82 +1126,182 @@ void invokeTopPSampling(void* workspace, return; } else { - beam_topK_kernel_for_topP<<>>( - log_probs, sorted_id_vals, sorted_log_probs, vocab_size, offset_buf, begin_offset_buf, top_p); + topp_beam_topk_kernel<<>>(log_probs, + sorted_id_vals, + sorted_log_probs, + vocab_size, + offset_buf, + begin_offset_buf, + max_top_p, + top_ps, + skip_decode); segmented_topp_impl::topPPerSegment( context, params, dataTypeKind, cub_temp_storage, cub_temp_storage_size, stream); } } constexpr int SAMPLING_BLOCK_SIZE = 256; - dim3 grid(batch_size); - top_p_sampling_opt<<>>(sorted_log_probs, - sorted_id_vals, - output_ids, - sequence_length, - finished_buf, - cum_log_probs, - output_log_probs, - begin_offset_buf, - offset_buf + 1, - vocab_size, - curandstate, - top_p, - end_ids, - batch_size); + dim3 grid(batch_size); + 
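// The descending segmented sort above follows cub's standard two-phase pattern: a first
// call with d_temp_storage == nullptr only reports temp_storage_bytes, the second call
// sorts each segment independently using begin/end offset arrays (offset_buf and
// offset_buf + 1 in the code). Stand-alone sketch, compiled as a .cu file; the data and
// segment sizes are illustrative.
#include <cstdio>
#include <cub/cub.cuh>

int main()
{
    const int num_segments = 2;
    const int num_items    = 8;  // 4 items per segment

    float h_keys[num_items] = {0.1f, 0.7f, 0.2f, 0.0f, 0.3f, 0.3f, 0.4f, 0.0f};
    int   h_vals[num_items] = {0, 1, 2, 3, 4, 5, 6, 7};
    int   h_offsets[num_segments + 1] = {0, 4, 8};

    float *d_keys_in, *d_keys_out;
    int   *d_vals_in, *d_vals_out, *d_offsets;
    cudaMalloc(&d_keys_in,  sizeof(h_keys));
    cudaMalloc(&d_keys_out, sizeof(h_keys));
    cudaMalloc(&d_vals_in,  sizeof(h_vals));
    cudaMalloc(&d_vals_out, sizeof(h_vals));
    cudaMalloc(&d_offsets,  sizeof(h_offsets));
    cudaMemcpy(d_keys_in, h_keys, sizeof(h_keys), cudaMemcpyHostToDevice);
    cudaMemcpy(d_vals_in, h_vals, sizeof(h_vals), cudaMemcpyHostToDevice);
    cudaMemcpy(d_offsets, h_offsets, sizeof(h_offsets), cudaMemcpyHostToDevice);

    void*  d_temp_storage     = nullptr;
    size_t temp_storage_bytes = 0;
    // Phase 1: size query.
    cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes,
                                                       d_keys_in, d_keys_out, d_vals_in, d_vals_out,
                                                       num_items, num_segments,
                                                       d_offsets, d_offsets + 1);
    cudaMalloc(&d_temp_storage, temp_storage_bytes);
    // Phase 2: the actual per-segment descending sort.
    cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes,
                                                       d_keys_in, d_keys_out, d_vals_in, d_vals_out,
                                                       num_items, num_segments,
                                                       d_offsets, d_offsets + 1);
    cudaMemcpy(h_vals, d_vals_out, sizeof(h_vals), cudaMemcpyDeviceToHost);
    std::printf("segment 0 ids by descending key: %d %d %d %d\n",
                h_vals[0], h_vals[1], h_vals[2], h_vals[3]);
    return 0;
}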
topp_sampling<<>>(sorted_log_probs, + sorted_id_vals, + output_ids, + sequence_length, + finished_buf, + cum_log_probs, + output_log_probs, + begin_offset_buf, + offset_buf + 1, + vocab_size, + curandstate, + max_top_p, + top_ps, + end_ids, + batch_size, + skip_decode); +} + +template void invokeBatchTopPSampling(void* workspace, + size_t& workspace_size, + size_t& cub_temp_storage_size, + int* output_ids, + int* sequence_length, + bool* finished_buf, + float* cum_log_probs, + float* output_log_probs, + const float* log_probs, + const int* id_vals, + int* offset_buf, + int* begin_offset_buf, + curandState_t* curandstate, + const int batch_size, + const size_t vocab_size_padded, + const int* end_ids, + const float max_top_p, + const float* top_ps, + cudaStream_t stream, + cudaDeviceProp* cuda_device_prop, + const bool* skip_decode); + +template void invokeBatchTopPSampling(void* workspace, + size_t& workspace_size, + size_t& cub_temp_storage_size, + int* output_ids, + int* sequence_length, + bool* finished_buf, + float* cum_log_probs, + float* output_log_probs, + const half* log_probs, + const int* id_vals, + int* offset_buf, + int* begin_offset_buf, + curandState_t* curandstate, + const int batch_size, + const size_t vocab_size_padded, + const int* end_ids, + const float max_top_p, + const float* top_ps, + cudaStream_t stream, + cudaDeviceProp* cuda_device_prop, + const bool* skip_decode); + +template +void invokeTopPSampling(void* workspace, + size_t& workspace_size, + size_t& cub_temp_storage_size, + int* output_ids, + int* sequence_length, + bool* finished_buf, + float* cum_log_probs, + float* output_log_probs, + const T* log_probs, + const int* id_vals, + int* offset_buf, + int* begin_offset_buf, + curandState_t* curandstate, + const int batch_size, + const size_t vocab_size_padded, + const int* end_ids, + const float top_p, + cudaStream_t stream, + cudaDeviceProp* cuda_device_prop, + const bool* skip_decode) +{ + invokeBatchTopPSampling(workspace, + workspace_size, + cub_temp_storage_size, + output_ids, + sequence_length, + finished_buf, + cum_log_probs, + output_log_probs, + log_probs, + id_vals, + offset_buf, + begin_offset_buf, + curandstate, + batch_size, + vocab_size_padded, + end_ids, + top_p, + nullptr, + stream, + cuda_device_prop, + skip_decode); } -template void invokeTopPSampling(void* workspace, - size_t& workspace_size, - size_t& cub_temp_storage_size, - int* output_ids, - int* sequence_length, - bool* finished_buf, - float* cum_log_probs, - float* output_log_probs, - const float* log_probs, - const int* id_vals, - int* offset_buf, - int* begin_offset_buf, - curandState_t* curandstate, - const int batch_size, - const size_t vocab_size_padded, - const int* end_ids, - const float top_p, - cudaStream_t stream, - cudaDeviceProp* cuda_device_prop); - -template void invokeTopPSampling(void* workspace, - size_t& workspace_size, - size_t& cub_temp_storage_size, - int* output_ids, - int* sequence_length, - bool* finished_buf, - float* cum_log_probs, - float* output_log_probs, - const half* log_probs, - const int* id_vals, - int* offset_buf, - int* begin_offset_buf, - curandState_t* curandstate, - const int batch_size, - const size_t vocab_size_padded, - const int* end_ids, - const float top_p, - cudaStream_t stream, - cudaDeviceProp* cuda_device_prop); +template void invokeTopPSampling(void* workspace, + size_t& workspace_size, + size_t& cub_temp_storage_size, + int* output_ids, + int* sequence_length, + bool* finished_buf, + float* cum_log_probs, + float* output_log_probs, + 
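// End-to-end sketch of the batched top-p path defined above: invokeTopPInitialize fills the
// id/offset buffers, a nullptr-workspace call reports workspace_size and
// cub_temp_storage_size, and the second call runs the sampling with a per-batch top_ps
// device array (max_top_p is assumed to be their maximum). Buffer names, sizes and the
// nullptr optional outputs are hypothetical.
#include <cuda_runtime.h>
#include <curand_kernel.h>
#include "src/fastertransformer/kernels/sampling_topp_kernels.h"

void batchTopPSketch(const float* d_probs,             // [batch, vocab], already softmaxed
                     int* d_output_ids, const int* d_end_ids,
                     curandState_t* d_curand_states,
                     const float* d_top_ps,            // one p per batch entry, max 0.9f here
                     const bool*  d_skip_decode,
                     int batch_size, size_t vocab_size_padded,
                     cudaDeviceProp* device_prop, cudaStream_t stream)
{
    using namespace fastertransformer;

    int *d_id_vals, *d_offsets, *d_begin_offsets;
    cudaMalloc(&d_id_vals,       sizeof(int) * batch_size * vocab_size_padded);
    cudaMalloc(&d_offsets,       sizeof(int) * (batch_size + 1));
    cudaMalloc(&d_begin_offsets, sizeof(int) * (batch_size + 1));
    invokeTopPInitialize(d_id_vals, d_offsets, d_begin_offsets, batch_size,
                         (int)vocab_size_padded, stream);

    size_t workspace_size = 0, cub_temp_storage_size = 0;
    // Pass 1: query both sizes.
    invokeBatchTopPSampling<float>(nullptr, workspace_size, cub_temp_storage_size,
                                   d_output_ids, /*sequence_length=*/nullptr,
                                   /*finished_buf=*/nullptr, /*cum_log_probs=*/nullptr,
                                   /*output_log_probs=*/nullptr, d_probs, d_id_vals,
                                   d_offsets, d_begin_offsets, d_curand_states, batch_size,
                                   vocab_size_padded, d_end_ids,
                                   /*max_top_p=*/0.9f, d_top_ps, stream, device_prop,
                                   d_skip_decode);
    void* d_workspace = nullptr;
    cudaMalloc(&d_workspace, workspace_size);
    // Pass 2: run the sampling.
    invokeBatchTopPSampling<float>(d_workspace, workspace_size, cub_temp_storage_size,
                                   d_output_ids, nullptr, nullptr, nullptr, nullptr,
                                   d_probs, d_id_vals, d_offsets, d_begin_offsets,
                                   d_curand_states, batch_size, vocab_size_padded, d_end_ids,
                                   0.9f, d_top_ps, stream, device_prop, d_skip_decode);
}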
const float* log_probs, + const int* id_vals, + int* offset_buf, + int* begin_offset_buf, + curandState_t* curandstate, + const int batch_size, + const size_t vocab_size_padded, + const int* end_ids, + const float top_p, + cudaStream_t stream, + cudaDeviceProp* cuda_device_prop, + const bool* skip_decode); + +template void invokeTopPSampling(void* workspace, + size_t& workspace_size, + size_t& cub_temp_storage_size, + int* output_ids, + int* sequence_length, + bool* finished_buf, + float* cum_log_probs, + float* output_log_probs, + const half* log_probs, + const int* id_vals, + int* offset_buf, + int* begin_offset_buf, + curandState_t* curandstate, + const int batch_size, + const size_t vocab_size_padded, + const int* end_ids, + const float top_p, + cudaStream_t stream, + cudaDeviceProp* cuda_device_prop, + const bool* skip_decode); template __global__ void addBiasSoftMax(T* logits, const T* bias, const int* end_ids, const bool* finished, const int n_padded, const int n) { - int bid = blockIdx.x; + int bid = blockIdx.x; bool finish = (finished != nullptr) ? finished[bid] : false; - int offset = bid * n_padded; + int offset = bid * n_padded; - float max_val = -1 * FLT_MAX; - const bool IS_FP16 = std::is_same::value; - const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; + float max_val = -1 * FLT_MAX; + const bool IS_FP16 = std::is_same::value; + const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; __shared__ float s_max_val; __shared__ float s_sum_val; @@ -1218,13 +1345,13 @@ addBiasSoftMax(T* logits, const T* bias, const int* end_ids, const bool* finishe } template -void invokeAddBiasSoftMax(T* logits, - const T* bias, - const int* end_ids, - const bool* finished, - const int m, - const int n_padded, - const int n, +void invokeAddBiasSoftMax(T* logits, + const T* bias, + const int* end_ids, + const bool* finished, + const int m, + const int n_padded, + const int n, cudaStream_t stream) { dim3 grid(m); @@ -1233,22 +1360,22 @@ void invokeAddBiasSoftMax(T* logits, addBiasSoftMax<<>>(logits, bias, end_ids, finished, n_padded, n); } -template void invokeAddBiasSoftMax(float* logits, +template void invokeAddBiasSoftMax(float* logits, const float* bias, - const int* end_ids, - const bool* finished, - const int m, - const int n_padded, - const int n, + const int* end_ids, + const bool* finished, + const int m, + const int n_padded, + const int n, cudaStream_t stream); -template void invokeAddBiasSoftMax(half* logits, - const half* bias, - const int* end_ids, - const bool* finished, - const int m, - const int n_padded, - const int n, +template void invokeAddBiasSoftMax(half* logits, + const half* bias, + const int* end_ids, + const bool* finished, + const int m, + const int n_padded, + const int n, cudaStream_t stream); } // namespace fastertransformer diff --git a/src/fastertransformer/kernels/sampling_topp_kernels.h b/src/fastertransformer/kernels/sampling_topp_kernels.h index 75a8100d1..17d0d0068 100644 --- a/src/fastertransformer/kernels/sampling_topp_kernels.h +++ b/src/fastertransformer/kernels/sampling_topp_kernels.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
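// invokeAddBiasSoftMax applies an optional bias and a numerically stable softmax over the
// first n of n_padded columns of each row, so the padding tail contributes nothing to the
// normalizer. A host reference of that per-row computation, with the finished/end_id
// handling omitted; names and the small example row are illustrative.
#include <cmath>
#include <cstdio>
#include <vector>

static void add_bias_softmax_row(std::vector<float>& logits, const std::vector<float>& bias, int n)
{
    // 1) add bias and find the row max over the n valid columns
    float max_val = -INFINITY;
    for (int i = 0; i < n; ++i) {
        logits[i] += bias.empty() ? 0.0f : bias[i];
        max_val = std::fmax(max_val, logits[i]);
    }
    // 2) exponentiate with the max subtracted, accumulate the normalizer
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
        logits[i] = std::exp(logits[i] - max_val);
        sum += logits[i];
    }
    // 3) normalize the valid columns; padded columns (i >= n) end up with probability 0
    for (size_t i = 0; i < logits.size(); ++i) {
        logits[i] = (i < (size_t)n) ? logits[i] / sum : 0.0f;
    }
}

int main()
{
    std::vector<float> row = {2.0f, 1.0f, 0.5f, -100.0f};  // last column is vocabulary padding
    add_bias_softmax_row(row, {}, /*n=*/3);
    std::printf("%.3f %.3f %.3f %.3f\n", row[0], row[1], row[2], row[3]);
    return 0;
}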
@@ -19,42 +19,66 @@ namespace fastertransformer { -void invokeTopPInitialize(int* topp_id_val_buf, - int* topp_offset_buf, - int* begin_topp_offset_buf_, +void invokeTopPInitialize(int* topp_id_val_buf, + int* topp_offset_buf, + int* begin_topp_offset_buf_, const size_t batch_size, - const int n, + const int n, cudaStream_t stream); template -void invokeTopPSampling(void* workspace, - size_t& workspace_size, - size_t& cub_temp_storage_size, - int* output_ids, - int* sequence_length, - bool* finished_buf, - float* cum_log_probs, - float* output_log_probs, - const T* log_probs, - const int* id_vals, - int* offset_buf, - int* begin_offset_buf, - curandState_t* curandstate, - const int batch_size, - const size_t vocab_size_padded, - const int* end_ids, - const float top_p, - cudaStream_t stream, - cudaDeviceProp* cuda_device_prop); +void invokeTopPSampling(void* workspace, + size_t& workspace_size, + size_t& cub_temp_storage_size, + int* output_ids, + int* sequence_length, + bool* finished_buf, + float* cum_log_probs, + float* output_log_probs, + const T* log_probs, + const int* id_vals, + int* offset_buf, + int* begin_offset_buf, + curandState_t* curandstate, + const int batch_size, + const size_t vocab_size_padded, + const int* end_ids, + const float top_p, + cudaStream_t stream, + cudaDeviceProp* cuda_device_prop, + const bool* skip_decode); template -void invokeAddBiasSoftMax(T* logits, - const T* bias, - const int* end_ids, - const bool* finished, - const int m, - const int n_padded, - const int n, +void invokeBatchTopPSampling(void* workspace, + size_t& workspace_size, + size_t& cub_temp_storage_size, + int* output_ids, + int* sequence_length, + bool* finished_buf, + float* cum_log_probs, + float* output_log_probs, + const T* log_probs, + const int* id_vals, + int* offset_buf, + int* begin_offset_buf, + curandState_t* curandstate, + const int batch_size, + const size_t vocab_size_padded, + const int* end_ids, + const float max_top_p, + const float* top_ps, + cudaStream_t stream, + cudaDeviceProp* cuda_device_prop, + const bool* skip_decode); + +template +void invokeAddBiasSoftMax(T* logits, + const T* bias, + const int* end_ids, + const bool* finished, + const int m, + const int n_padded, + const int n, cudaStream_t stream); namespace segmented_topp_impl { @@ -64,12 +88,12 @@ enum DType_t { kINT8 }; -template + int BLOCK_THREADS_ = 256, + int KEYS_PER_LDG_ = 1> struct Segmented_topk_kernel_params { - typedef Key_Data_Type_ Key_Data_Type; + typedef Key_Data_Type_ Key_Data_Type; typedef Value_Data_Type_ Value_Data_Type; enum { BLOCK_THREADS = BLOCK_THREADS_ @@ -93,30 +117,30 @@ struct TopKPerSegmentContext { struct TopKPerSegmentParams { // input/output keys and values void *gmem_src_keys, *gmem_dst_keys, *gmem_dst_vals; - // not used in the custom implementaiton + // not used in the custom implementation void* gmem_src_vals; // int array of size num_segments int* gmem_active_count_per_segment; int* gmem_active_count_total; int* gmem_begin_offsets; // gmem_end_offsets will be populated - int* gmem_end_offsets; + int* gmem_end_offsets; void* workspace; // total number of items for all segments int num_items; int num_segments; // top_k per segment - int num_top_k; + int num_top_k; float top_p; float confidence_threshold; }; int topPPerSegment(const TopKPerSegmentContext& context, - TopKPerSegmentParams& params, - const DType_t DT_SCORE, - void* temp_storage, - size_t& temp_storage_bytes, - cudaStream_t stream); + TopKPerSegmentParams& params, + const DType_t DT_SCORE, + void* temp_storage, + 
size_t& temp_storage_bytes, + cudaStream_t stream); } // namespace segmented_topp_impl } // namespace fastertransformer diff --git a/src/fastertransformer/kernels/softmax_int8_kernels.cu b/src/fastertransformer/kernels/softmax_int8_kernels.cu index 0342e9454..b09e4c081 100644 --- a/src/fastertransformer/kernels/softmax_int8_kernels.cu +++ b/src/fastertransformer/kernels/softmax_int8_kernels.cu @@ -25,30 +25,30 @@ namespace fastertransformer { // block.x = max(32, (seq_len/4 + 31)/32*32) // for int32_t I; int8 O; template -__global__ void softmax_COL32(int8_t* output, +__global__ void softmax_COL32(int8_t* output, const int32_t* input, - const T* attr_mask, - const int batch_size, - const int head_num, - const int seq_len, - const float scalar1a, - const float* scalar1b, - const float* scalar1c, - const float* amax_ptr, - const int head_num_x_seq_len, - const int seq_len_x_seq_len) + const T* attr_mask, + const int batch_size, + const int head_num, + const int seq_len, + const float scalar1a, + const float* scalar1b, + const float* scalar1c, + const float* amax_ptr, + const int head_num_x_seq_len, + const int seq_len_x_seq_len) { - const float amax = __ldg(amax_ptr); + const float amax = __ldg(amax_ptr); const float scalar1 = scalar1a * __ldg(scalar1b) * __ldg(scalar1c); - int mask_id; - int threadIdx4 = threadIdx.x << 2; + int mask_id; + int threadIdx4 = threadIdx.x << 2; char4* buf4Ptr = (char4*)output; bool qual = threadIdx4 < seq_len; for (int seq_id = blockIdx.x; seq_id < seq_len; seq_id += gridDim.x) { - char4 tmp4 = {0, 0, 0, 0}; - int inIdx = (blockIdx.y * head_num + blockIdx.z) * (seq_len_x_seq_len) + (threadIdx4 & 0xffffffe0) * seq_len + char4 tmp4 = {0, 0, 0, 0}; + int inIdx = (blockIdx.y * head_num + blockIdx.z) * (seq_len_x_seq_len) + (threadIdx4 & 0xffffffe0) * seq_len + (seq_id << 5) + (threadIdx4 & 31); // set softmax of padding word to 0 @@ -77,24 +77,24 @@ __global__ void softmax_COL32(int8_t* output, if (qual) { mask_id = threadIdx4 + blockIdx.y * seq_len_x_seq_len + seq_id * seq_len; // for x - mask_val = (1.0f - static_cast(__ldg(attr_mask + mask_id))) * -10000.0f; + mask_val = (1.0f - static_cast(__ldg(attr_mask + mask_id))) * -10000.0f; floatTmp4.x = floatTmp4.x + mask_val; - max_val = fmaxf(max_val, floatTmp4.x); + max_val = fmaxf(max_val, floatTmp4.x); // for y - mask_val = (1.0f - static_cast(__ldg(attr_mask + mask_id + 1))) * -10000.0f; + mask_val = (1.0f - static_cast(__ldg(attr_mask + mask_id + 1))) * -10000.0f; floatTmp4.y = floatTmp4.y + mask_val; - max_val = fmaxf(max_val, floatTmp4.y); + max_val = fmaxf(max_val, floatTmp4.y); // for z - mask_val = (1.0f - static_cast(__ldg(attr_mask + mask_id + 2))) * -10000.0f; + mask_val = (1.0f - static_cast(__ldg(attr_mask + mask_id + 2))) * -10000.0f; floatTmp4.z = floatTmp4.z + mask_val; - max_val = fmaxf(max_val, floatTmp4.z); + max_val = fmaxf(max_val, floatTmp4.z); // for w - mask_val = (1.0f - static_cast(__ldg(attr_mask + mask_id + 3))) * -10000.0f; + mask_val = (1.0f - static_cast(__ldg(attr_mask + mask_id + 3))) * -10000.0f; floatTmp4.w = floatTmp4.w + mask_val; - max_val = fmaxf(max_val, floatTmp4.w); + max_val = fmaxf(max_val, floatTmp4.w); } max_val = blockDim.x <= 32 ? 
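// The inIdx arithmetic in the softmax_COL32 kernels above is the COL32 layout used by the
// int8 GEMM path: columns are grouped into tiles of 32, and each tile stores its 32 columns
// interleaved per row. A host helper showing the same offset computation next to plain
// row-major indexing (the matrix shape is illustrative).
#include <cstdio>

// offset of element (row, col) in a rows x cols matrix stored in COL32 order
static int col32_offset(int row, int col, int rows)
{
    return (col & 0xffffffe0) * rows   // which 32-column tile; each tile holds rows * 32 elements
         + (row << 5)                  // row offset inside the tile (32 elements per row)
         + (col & 31);                 // column inside the tile
}

int main()
{
    const int seq_len = 64;  // the kernels above use rows = cols = seq_len per attention head
    int row = 3, col = 40;
    std::printf("COL32 offset     = %d\n", col32_offset(row, col, seq_len));  // 32*64 + 3*32 + 8 = 2152
    std::printf("row-major offset = %d\n", row * seq_len + col);              // 3*64 + 40 = 232
    return 0;
}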
warpReduceMax(max_val) : blockReduceMax(max_val); @@ -143,33 +143,33 @@ __global__ void softmax_COL32(int8_t* output, // block.x = max(32, (seq_len_padded/4 + 31)/32*32) // for int8_t IO; template -__global__ void softmax_COL32_varlen(int8_t* output, +__global__ void softmax_COL32_varlen(int8_t* output, const int8_t* input, - const T* attr_mask, - const int batch_size, - const int head_num, - const int seq_len, - const int seq_len_padded, - const float scalar1a, - const float* scalar1b, - const float* amax_ptr, - const int seq_len_x_seq_len, - const int seq_len_x_seq_len_padded) + const T* attr_mask, + const int batch_size, + const int head_num, + const int seq_len, + const int seq_len_padded, + const float scalar1a, + const float* scalar1b, + const float* amax_ptr, + const int seq_len_x_seq_len, + const int seq_len_x_seq_len_padded) { - const float amax = __ldg(amax_ptr); + const float amax = __ldg(amax_ptr); const float scalar1 = scalar1a * __ldg(scalar1b); - int mask_id; - int threadIdx4 = threadIdx.x << 2; + int mask_id; + int threadIdx4 = threadIdx.x << 2; - char4* buf4Ptr = (char4*)output; + char4* buf4Ptr = (char4*)output; const char4* inBuf4Ptr = (const char4*)input; - const bool qual = threadIdx4 < seq_len; + const bool qual = threadIdx4 < seq_len; const bool qual_padded = threadIdx4 < seq_len_padded; for (int seq_id = blockIdx.x; seq_id < seq_len; seq_id += gridDim.x) { - char4 tmp4 = {0, 0, 0, 0}; - int inIdx = ((blockIdx.y * head_num + blockIdx.z) * (seq_len_x_seq_len_padded) + char4 tmp4 = {0, 0, 0, 0}; + int inIdx = ((blockIdx.y * head_num + blockIdx.z) * (seq_len_x_seq_len_padded) + (threadIdx4 & 0xffffffe0) * seq_len + (seq_id << 5) + (threadIdx4 & 31)) >> 2; @@ -185,7 +185,7 @@ __global__ void softmax_COL32_varlen(int8_t* output, // set softmax of padding word in cols to 0 float4 floatTmp4 = {0.0f, 0.0f, 0.0f, 0.0f}; if (qual) { - tmp4 = __ldg(inBuf4Ptr + inIdx); + tmp4 = __ldg(inBuf4Ptr + inIdx); floatTmp4.x = static_cast(tmp4.x) * scalar1; floatTmp4.y = static_cast(tmp4.y) * scalar1; floatTmp4.z = static_cast(tmp4.z) * scalar1; @@ -200,24 +200,24 @@ __global__ void softmax_COL32_varlen(int8_t* output, if (qual) { mask_id = threadIdx4 + blockIdx.y * seq_len_x_seq_len + seq_id * seq_len; // for x - mask_val = (1.0f - static_cast(__ldg(attr_mask + mask_id))) * -10000.0f; + mask_val = (1.0f - static_cast(__ldg(attr_mask + mask_id))) * -10000.0f; floatTmp4.x = floatTmp4.x + mask_val; - max_val = fmaxf(max_val, floatTmp4.x); + max_val = fmaxf(max_val, floatTmp4.x); // for y - mask_val = (1.0f - static_cast(__ldg(attr_mask + mask_id + 1))) * -10000.0f; + mask_val = (1.0f - static_cast(__ldg(attr_mask + mask_id + 1))) * -10000.0f; floatTmp4.y = floatTmp4.y + mask_val; - max_val = fmaxf(max_val, floatTmp4.y); + max_val = fmaxf(max_val, floatTmp4.y); // for z - mask_val = (1.0f - static_cast(__ldg(attr_mask + mask_id + 2))) * -10000.0f; + mask_val = (1.0f - static_cast(__ldg(attr_mask + mask_id + 2))) * -10000.0f; floatTmp4.z = floatTmp4.z + mask_val; - max_val = fmaxf(max_val, floatTmp4.z); + max_val = fmaxf(max_val, floatTmp4.z); // for w - mask_val = (1.0f - static_cast(__ldg(attr_mask + mask_id + 3))) * -10000.0f; + mask_val = (1.0f - static_cast(__ldg(attr_mask + mask_id + 3))) * -10000.0f; floatTmp4.w = floatTmp4.w + mask_val; - max_val = fmaxf(max_val, floatTmp4.w); + max_val = fmaxf(max_val, floatTmp4.w); } max_val = blockDim.x <= 32 ? 
warpReduceMax(max_val) : blockReduceMax(max_val); @@ -266,30 +266,30 @@ __global__ void softmax_COL32_varlen(int8_t* output, // block.x = max(32, (seq_len_padded + 31)/32*32) // for int8_t IO, I/O with int8_t element; template -__global__ void softmax_COL32_perElement_varlen(int8_t* output, +__global__ void softmax_COL32_perElement_varlen(int8_t* output, const int8_t* input, - const T* attr_mask, - const int batch_size, - const int head_num, - const int seq_len, - const int seq_len_padded, - const float scalar1a, - const float* scalar1b, - const float* amax_ptr, - const int seq_len_x_seq_len, - const int seq_len_x_seq_len_padded) + const T* attr_mask, + const int batch_size, + const int head_num, + const int seq_len, + const int seq_len_padded, + const float scalar1a, + const float* scalar1b, + const float* amax_ptr, + const int seq_len_x_seq_len, + const int seq_len_x_seq_len_padded) { - const float amax = __ldg(amax_ptr); + const float amax = __ldg(amax_ptr); const float scalar1 = scalar1a * __ldg(scalar1b); - int mask_id; - const int tidx = threadIdx.x; + int mask_id; + const int tidx = threadIdx.x; - const bool qual = tidx < seq_len; + const bool qual = tidx < seq_len; const bool qual_padded = tidx < seq_len_padded; for (int seq_id = blockIdx.x; seq_id < seq_len; seq_id += gridDim.x) { int8_t tmp = 0; - int inIdx = ((blockIdx.y * head_num + blockIdx.z) * (seq_len_x_seq_len_padded) + (tidx & 0xffffffe0) * seq_len + int inIdx = ((blockIdx.y * head_num + blockIdx.z) * (seq_len_x_seq_len_padded) + (tidx & 0xffffffe0) * seq_len + (seq_id << 5) + (tidx & 31)); // set softmax of padding word in rows to 0 @@ -310,7 +310,7 @@ __global__ void softmax_COL32_perElement_varlen(int8_t* output, __shared__ float s_max, s_sum; if (qual) { - mask_id = tidx + blockIdx.y * seq_len_x_seq_len + seq_id * seq_len; + mask_id = tidx + blockIdx.y * seq_len_x_seq_len + seq_id * seq_len; mask_val = (1.0f - static_cast(__ldg(attr_mask + mask_id))) * -10000.0f; floatTmp = floatTmp + mask_val; } @@ -335,7 +335,7 @@ __global__ void softmax_COL32_perElement_varlen(int8_t* output, __syncthreads(); if (qual_padded) { - tmp = qual ? float_to_int8_rn(floatTmp * s_sum) : static_cast(0); + tmp = qual ? 
float_to_int8_rn(floatTmp * s_sum) : static_cast(0); output[inIdx] = tmp; } } @@ -347,24 +347,24 @@ __global__ void softmax_COL32_perElement_varlen(int8_t* output, // for int32_t I; int8 O; // for seq_len <= 32 template -__global__ void softmax_COL32_LE32(int8_t* output, +__global__ void softmax_COL32_LE32(int8_t* output, const int32_t* input, - const T* attr_mask, - const int batch_size, - const int head_num, - const int seq_len, - const float scalar1a, - const float* scalar1b, - const float* scalar1c, - const float* amax_ptr, - const int head_num_x_seq_len, - const int seq_len_x_seq_len) + const T* attr_mask, + const int batch_size, + const int head_num, + const int seq_len, + const float scalar1a, + const float* scalar1b, + const float* scalar1c, + const float* amax_ptr, + const int head_num_x_seq_len, + const int seq_len_x_seq_len) { - const float amax = __ldg(amax_ptr); + const float amax = __ldg(amax_ptr); const float scalar1 = scalar1a * __ldg(scalar1b) * __ldg(scalar1c); - int mask_id; - int threadIdxx = threadIdx.x; - bool qual = threadIdxx < seq_len; + int mask_id; + int threadIdxx = threadIdx.x; + bool qual = threadIdxx < seq_len; for (int seq_id = blockIdx.x; seq_id < seq_len; seq_id += gridDim.x) { int inIdx = (blockIdx.y * head_num + blockIdx.z) * (seq_len_x_seq_len) + (threadIdxx & 0xffffffe0) * seq_len + (seq_id << 5) + (threadIdxx & 31); @@ -384,10 +384,10 @@ __global__ void softmax_COL32_LE32(int8_t* output, __shared__ float s_max, s_sum; - mask_id = qual ? threadIdxx + blockIdx.y * seq_len_x_seq_len + seq_id * seq_len : 0; + mask_id = qual ? threadIdxx + blockIdx.y * seq_len_x_seq_len + seq_id * seq_len : 0; mask_val = qual ? (1.0f - static_cast(__ldg(attr_mask + mask_id))) * -10000.0f : 0.0f; floatTmp = qual ? floatTmp + mask_val : 0.0f; - max_val = qual ? floatTmp : -1e20f; + max_val = qual ? floatTmp : -1e20f; max_val = blockDim.x <= 32 ? 
warpReduceMax(max_val) : blockReduceMax(max_val); @@ -420,25 +420,25 @@ __global__ void softmax_COL32_LE32(int8_t* output, // for int8_t IO; // for seq_len_padded == 32 template -__global__ void softmax_COL32_LE32_varlen(int8_t* output, +__global__ void softmax_COL32_LE32_varlen(int8_t* output, const int8_t* input, - const T* attr_mask, - const int batch_size, - const int head_num, - const int seq_len, - const int seq_len_padded, - const float scalar1a, - const float* scalar1b, - const float* amax_ptr, - const int seq_len_x_seq_len, - const int seq_len_x_seq_len_padded) + const T* attr_mask, + const int batch_size, + const int head_num, + const int seq_len, + const int seq_len_padded, + const float scalar1a, + const float* scalar1b, + const float* amax_ptr, + const int seq_len_x_seq_len, + const int seq_len_x_seq_len_padded) { - const float amax = __ldg(amax_ptr); + const float amax = __ldg(amax_ptr); const float scalar1 = scalar1a * __ldg(scalar1b); - int mask_id; - int threadIdxx = threadIdx.x; - const bool qual = threadIdxx < seq_len; - const bool qual_padded = threadIdxx < seq_len_padded; + int mask_id; + int threadIdxx = threadIdx.x; + const bool qual = threadIdxx < seq_len; + const bool qual_padded = threadIdxx < seq_len_padded; for (int seq_id = blockIdx.x; seq_id < seq_len; seq_id += gridDim.x) { int inIdx = (blockIdx.y * head_num + blockIdx.z) * (seq_len_x_seq_len_padded) + (threadIdxx & 0xffffffe0) * seq_len + (seq_id << 5) + (threadIdxx & 31); @@ -452,15 +452,15 @@ __global__ void softmax_COL32_LE32_varlen(int8_t* output, continue; } - float mask_val, max_val; + float mask_val, max_val; __shared__ float s_max, s_sum; // set softmax of padding word in cols to 0 float floatTmp = qual ? static_cast(__ldg(input + inIdx)) * scalar1 : 0.0f; - mask_id = qual ? threadIdxx + blockIdx.y * seq_len_x_seq_len + seq_id * seq_len : 0; - mask_val = qual ? (1.0f - static_cast(__ldg(attr_mask + mask_id))) * -10000.0f : 0.0f; - floatTmp = qual ? floatTmp + mask_val : 0.0f; - max_val = qual ? floatTmp : -1e20f; + mask_id = qual ? threadIdxx + blockIdx.y * seq_len_x_seq_len + seq_id * seq_len : 0; + mask_val = qual ? (1.0f - static_cast(__ldg(attr_mask + mask_id))) * -10000.0f : 0.0f; + floatTmp = qual ? floatTmp + mask_val : 0.0f; + max_val = qual ? 
floatTmp : -1e20f; max_val = warpReduceMax(max_val); @@ -491,30 +491,30 @@ __global__ void softmax_COL32_LE32_varlen(int8_t* output, // for int32_t I; int8 O; // for seq_len in (32, 64] template -__global__ void softmax_COL32_LE64(int8_t* output, +__global__ void softmax_COL32_LE64(int8_t* output, const int32_t* input, - const T* attr_mask, - const int batch_size, - const int head_num, - const int seq_len, - const float scalar1a, - const float* scalar1b, - const float* scalar1c, - const float* amax_ptr, - const int head_num_x_seq_len, - const int seq_len_x_seq_len) + const T* attr_mask, + const int batch_size, + const int head_num, + const int seq_len, + const float scalar1a, + const float* scalar1b, + const float* scalar1c, + const float* amax_ptr, + const int head_num_x_seq_len, + const int seq_len_x_seq_len) { - const float amax = __ldg(amax_ptr); + const float amax = __ldg(amax_ptr); const float scalar1 = scalar1a * __ldg(scalar1b) * __ldg(scalar1c); - int mask_id; - int threadIdx2 = threadIdx.x << 1; + int mask_id; + int threadIdx2 = threadIdx.x << 1; char2* buf2Ptr = (char2*)output; bool qual = threadIdx2 < seq_len; for (int seq_id = blockIdx.x; seq_id < seq_len; seq_id += gridDim.x) { - char2 tmp2 = {0, 0}; - int inIdx = (blockIdx.y * head_num + blockIdx.z) * (seq_len_x_seq_len) + (threadIdx2 & 0xffffffe0) * seq_len + char2 tmp2 = {0, 0}; + int inIdx = (blockIdx.y * head_num + blockIdx.z) * (seq_len_x_seq_len) + (threadIdx2 & 0xffffffe0) * seq_len + (seq_id << 5) + (threadIdx2 & 31); // set softmax of padding word to 0 @@ -540,11 +540,11 @@ __global__ void softmax_COL32_LE64(int8_t* output, if (qual) { mask_id = threadIdx2 + blockIdx.y * seq_len_x_seq_len + seq_id * seq_len; // for x - mask_val = (1.0f - static_cast(__ldg(attr_mask + mask_id))) * -10000.0f; + mask_val = (1.0f - static_cast(__ldg(attr_mask + mask_id))) * -10000.0f; floatTmp2.x = floatTmp2.x + mask_val; // for y - mask_val = (1.0f - static_cast(__ldg(attr_mask + mask_id + 1))) * -10000.0f; + mask_val = (1.0f - static_cast(__ldg(attr_mask + mask_id + 1))) * -10000.0f; floatTmp2.y = floatTmp2.y + mask_val; max_val = fmaxf(floatTmp2.x, floatTmp2.y); @@ -575,8 +575,8 @@ __global__ void softmax_COL32_LE64(int8_t* output, __syncthreads(); if (qual) { - tmp2.x = float_to_int8_rn(floatTmp2.x * s_sum); - tmp2.y = float_to_int8_rn(floatTmp2.y * s_sum); + tmp2.x = float_to_int8_rn(floatTmp2.x * s_sum); + tmp2.y = float_to_int8_rn(floatTmp2.y * s_sum); buf2Ptr[inIdx >> 1] = tmp2; } } @@ -589,32 +589,32 @@ __global__ void softmax_COL32_LE64(int8_t* output, // for int8_t IO // for seq_len in (32, 64] template -__global__ void softmax_COL32_LE64_varlen(int8_t* output, +__global__ void softmax_COL32_LE64_varlen(int8_t* output, const int8_t* input, - const T* attr_mask, - const int batch_size, - const int head_num, - const int seq_len, - const int seq_len_padded, - const float scalar1a, - const float* scalar1b, - const float* amax_ptr, - const int seq_len_x_seq_len, - const int seq_len_x_seq_len_padded) + const T* attr_mask, + const int batch_size, + const int head_num, + const int seq_len, + const int seq_len_padded, + const float scalar1a, + const float* scalar1b, + const float* amax_ptr, + const int seq_len_x_seq_len, + const int seq_len_x_seq_len_padded) { - const float amax = __ldg(amax_ptr); + const float amax = __ldg(amax_ptr); const float scalar1 = scalar1a * __ldg(scalar1b); - int mask_id; - int threadIdx2 = threadIdx.x << 1; + int mask_id; + int threadIdx2 = threadIdx.x << 1; - char2* buf2Ptr = (char2*)output; + char2* 
buf2Ptr = (char2*)output; const char2* inBuf2Ptr = (const char2*)input; - const bool qual = threadIdx2 < seq_len; + const bool qual = threadIdx2 < seq_len; const bool qual_padded = threadIdx2 < seq_len_padded; for (int seq_id = blockIdx.x; seq_id < seq_len; seq_id += gridDim.x) { - char2 tmp2 = {0, 0}; - int inIdx = ((blockIdx.y * head_num + blockIdx.z) * (seq_len_x_seq_len_padded) + char2 tmp2 = {0, 0}; + int inIdx = ((blockIdx.y * head_num + blockIdx.z) * (seq_len_x_seq_len_padded) + (threadIdx2 & 0xffffffe0) * seq_len + (seq_id << 5) + (threadIdx2 & 31)) >> 1; @@ -630,7 +630,7 @@ __global__ void softmax_COL32_LE64_varlen(int8_t* output, // set softmax of padding word in cols to 0 float2 floatTmp2 = {0.0f, 0.0f}; if (qual) { - tmp2 = __ldg(inBuf2Ptr + inIdx); + tmp2 = __ldg(inBuf2Ptr + inIdx); floatTmp2.x = static_cast(tmp2.x) * scalar1; floatTmp2.y = static_cast(tmp2.y) * scalar1; } @@ -643,11 +643,11 @@ __global__ void softmax_COL32_LE64_varlen(int8_t* output, if (qual) { mask_id = threadIdx2 + blockIdx.y * seq_len_x_seq_len + seq_id * seq_len; // for x - mask_val = (1.0f - static_cast(__ldg(attr_mask + mask_id))) * -10000.0f; + mask_val = (1.0f - static_cast(__ldg(attr_mask + mask_id))) * -10000.0f; floatTmp2.x = floatTmp2.x + mask_val; // for y - mask_val = (1.0f - static_cast(__ldg(attr_mask + mask_id + 1))) * -10000.0f; + mask_val = (1.0f - static_cast(__ldg(attr_mask + mask_id + 1))) * -10000.0f; floatTmp2.y = floatTmp2.y + mask_val; max_val = fmaxf(floatTmp2.x, floatTmp2.y); @@ -678,25 +678,25 @@ __global__ void softmax_COL32_LE64_varlen(int8_t* output, __syncthreads(); if (qual_padded) { - tmp2.x = qual ? float_to_int8_rn(floatTmp2.x * s_sum) : static_cast(0); - tmp2.y = qual ? float_to_int8_rn(floatTmp2.y * s_sum) : static_cast(0); + tmp2.x = qual ? float_to_int8_rn(floatTmp2.x * s_sum) : static_cast(0); + tmp2.y = qual ? 
float_to_int8_rn(floatTmp2.y * s_sum) : static_cast(0); buf2Ptr[inIdx] = tmp2; } } } template -void invokeSoftmaxCOL32(int8_t* output, +void invokeSoftmaxCOL32(int8_t* output, const int32_t* input, - const T* attr_mask, - const int batch_size, - const int head_num, - const int seq_len, - const float scalar1a, - const float* scalar1b, - const float* scalar1c, - const float* amax_ptr, - cudaStream_t stream) + const T* attr_mask, + const int batch_size, + const int head_num, + const int seq_len, + const float scalar1a, + const float* scalar1b, + const float* scalar1c, + const float* amax_ptr, + cudaStream_t stream) { dim3 grid, block; grid.x = seq_len; @@ -758,46 +758,46 @@ void invokeSoftmaxCOL32(int8_t* output, } } -template void invokeSoftmaxCOL32(int8_t* output, +template void invokeSoftmaxCOL32(int8_t* output, const int32_t* input, - const float* attr_mask, - const int batch_size, - const int head_num, - const int seq_len, - const float scalar1a, - const float* scalar1b, - const float* scalar1c, - const float* amax_ptr, - cudaStream_t stream); - -template void invokeSoftmaxCOL32(int8_t* output, + const float* attr_mask, + const int batch_size, + const int head_num, + const int seq_len, + const float scalar1a, + const float* scalar1b, + const float* scalar1c, + const float* amax_ptr, + cudaStream_t stream); + +template void invokeSoftmaxCOL32(int8_t* output, const int32_t* input, - const half* attr_mask, - const int batch_size, - const int head_num, - const int seq_len, - const float scalar1a, - const float* scalar1b, - const float* scalar1c, - const float* amax_ptr, - cudaStream_t stream); + const half* attr_mask, + const int batch_size, + const int head_num, + const int seq_len, + const float scalar1a, + const float* scalar1b, + const float* scalar1c, + const float* amax_ptr, + cudaStream_t stream); template -void invokeSoftmaxCOL32(int8_t* output, +void invokeSoftmaxCOL32(int8_t* output, const int8_t* input, - const T* attr_mask, - const int batch_size, - const int head_num, - const int seq_len, - const float scalar1a, - const float* scalar1b, - const float* amax_ptr, - cudaStream_t stream) + const T* attr_mask, + const int batch_size, + const int head_num, + const int seq_len, + const float scalar1a, + const float* scalar1b, + const float* amax_ptr, + cudaStream_t stream) { dim3 grid, block; - grid.x = seq_len; - grid.y = batch_size; - grid.z = head_num; + grid.x = seq_len; + grid.y = batch_size; + grid.z = head_num; const int seq_len_padded = (seq_len + 31) / 32 * 32; if (seq_len <= 32) { @@ -868,27 +868,27 @@ void invokeSoftmaxCOL32(int8_t* output, } } -template void invokeSoftmaxCOL32(int8_t* output, +template void invokeSoftmaxCOL32(int8_t* output, const int8_t* input, - const float* attr_mask, - const int batch_size, - const int head_num, - const int seq_len, - const float scalar1a, - const float* scalar1b, - const float* amax_ptr, - cudaStream_t stream); - -template void invokeSoftmaxCOL32(int8_t* output, + const float* attr_mask, + const int batch_size, + const int head_num, + const int seq_len, + const float scalar1a, + const float* scalar1b, + const float* amax_ptr, + cudaStream_t stream); + +template void invokeSoftmaxCOL32(int8_t* output, const int8_t* input, - const half* attr_mask, - const int batch_size, - const int head_num, - const int seq_len, - const float scalar1a, - const float* scalar1b, - const float* amax_ptr, - cudaStream_t stream); + const half* attr_mask, + const int batch_size, + const int head_num, + const int seq_len, + const float scalar1a, + const float* 
scalar1b, + const float* amax_ptr, + cudaStream_t stream); /******************* invokeSoftmaxCOL32 ***********************/ @@ -898,36 +898,36 @@ template void invokeSoftmaxCOL32(int8_t* output, // attn_mask is [window_num, window_len, window_len] + row-major // relative_pos_bias is [num_head, window_len, window_len] + row-majot template -__global__ void softmax_INT8IO_kernel_COL32(int8_t* a_buf, - int8_t* qk_buf_int8, - const T* attn_mask, - const T* relative_pos_bias, - const int batch_size, - const int num_head, - const int window_num, - const int window_len, - const int window_len_x_window_len, - const float scalar, +__global__ void softmax_INT8IO_kernel_COL32(int8_t* a_buf, + int8_t* qk_buf_int8, + const T* attn_mask, + const T* relative_pos_bias, + const int batch_size, + const int num_head, + const int window_num, + const int window_len, + const int window_len_x_window_len, + const float scalar, const float* deQ_scale_ptr, const float* out_scale_ptr) { - bool qual = threadIdx.x < window_len; + bool qual = threadIdx.x < window_len; const int padded_winlen = (window_len + 31) / 32 * 32; for (int window_id = blockIdx.x; window_id < window_len; window_id += gridDim.x) { - float tmp = -1e20f; + float tmp = -1e20f; __shared__ float s_mean, s_max; - int qk_offset = (blockIdx.z * gridDim.y + blockIdx.y) * window_len * padded_winlen + int qk_offset = (blockIdx.z * gridDim.y + blockIdx.y) * window_len * padded_winlen + ((threadIdx.x >> 5) << 5) * window_len + (window_id << 5) + (threadIdx.x & 31); ; if (qual) { const int offset_in_window = window_id * window_len + threadIdx.x; const int relative_pos_bias_offset = (blockIdx.y % num_head) * window_len_x_window_len + offset_in_window; - float mask_val = + float mask_val = (attn_mask == nullptr) ? - 0.0f : - static_cast( + 0.0f : + static_cast( __ldg(attn_mask + ((blockIdx.y / num_head) * window_len_x_window_len + offset_in_window))); tmp = scalar * static_cast(qk_buf_int8[qk_offset]) * __ldg(deQ_scale_ptr) + mask_val + static_cast(__ldg(relative_pos_bias + relative_pos_bias_offset)); @@ -939,7 +939,7 @@ __global__ void softmax_INT8IO_kernel_COL32(int8_t* a_buf, } __syncthreads(); - float qk_tmp = qual ? __expf(tmp - s_max) : 0.0f; + float qk_tmp = qual ? 
__expf(tmp - s_max) : 0.0f; float sum_val = blockReduceSum(qk_tmp); if (threadIdx.x == 0) { s_mean = sum_val + 1e-6f; @@ -951,15 +951,15 @@ __global__ void softmax_INT8IO_kernel_COL32(int8_t* a_buf, } template -void invokeSoftmaxWithRelPosBiasCOL32(int8_t* a_buf, - int8_t* qk_buf_int8, - const T* attn_mask, - const T* relative_pos_bias, - const int batch_size, - const int num_head, - const int window_num, - const int window_len, - const float scalar, +void invokeSoftmaxWithRelPosBiasCOL32(int8_t* a_buf, + int8_t* qk_buf_int8, + const T* attn_mask, + const T* relative_pos_bias, + const int batch_size, + const int num_head, + const int window_num, + const int window_len, + const float scalar, const float* deQ_scale_ptr, const float* out_scale_ptr, cudaStream_t stream) @@ -980,28 +980,28 @@ void invokeSoftmaxWithRelPosBiasCOL32(int8_t* a_buf, out_scale_ptr); } -template void invokeSoftmaxWithRelPosBiasCOL32(int8_t* a_buf, - int8_t* qk_buf_int8, +template void invokeSoftmaxWithRelPosBiasCOL32(int8_t* a_buf, + int8_t* qk_buf_int8, const float* attn_mask, const float* relative_pos_bias, - const int batch_size, - const int num_head, - const int window_num, - const int window_len, - const float scalar, + const int batch_size, + const int num_head, + const int window_num, + const int window_len, + const float scalar, const float* deQ_scale_ptr, const float* output_scale_ptr, cudaStream_t stream); -template void invokeSoftmaxWithRelPosBiasCOL32(int8_t* a_buf, - int8_t* qk_buf_int8, - const half* attn_mask, - const half* relative_pos_bias, - const int batch_size, - const int num_head, - const int window_num, - const int window_len, - const float scalar, +template void invokeSoftmaxWithRelPosBiasCOL32(int8_t* a_buf, + int8_t* qk_buf_int8, + const half* attn_mask, + const half* relative_pos_bias, + const int batch_size, + const int num_head, + const int window_num, + const int window_len, + const float scalar, const float* deQ_scale_ptr, const float* output_scale_ptr, cudaStream_t stream); diff --git a/src/fastertransformer/kernels/softmax_int8_kernels.h b/src/fastertransformer/kernels/softmax_int8_kernels.h index 65d641925..54f70f3b1 100644 --- a/src/fastertransformer/kernels/softmax_int8_kernels.h +++ b/src/fastertransformer/kernels/softmax_int8_kernels.h @@ -23,40 +23,40 @@ namespace fastertransformer { template -void invokeSoftmaxCOL32(int8_t* output, +void invokeSoftmaxCOL32(int8_t* output, const int32_t* input, - const T* attr_mask, - const int batch_size, - const int head_num, - const int seq_len, - const float scalar1a, - const float* scalar1b, - const float* scalar1c, - const float* amax_ptr, - cudaStream_t stream); + const T* attr_mask, + const int batch_size, + const int head_num, + const int seq_len, + const float scalar1a, + const float* scalar1b, + const float* scalar1c, + const float* amax_ptr, + cudaStream_t stream); template -void invokeSoftmaxCOL32(int8_t* output, +void invokeSoftmaxCOL32(int8_t* output, const int8_t* input, - const T* attr_mask, - const int batch_size, - const int head_num, - const int seq_len, - const float scalar1a, - const float* scalar1b, - const float* amax_ptr, - cudaStream_t stream); + const T* attr_mask, + const int batch_size, + const int head_num, + const int seq_len, + const float scalar1a, + const float* scalar1b, + const float* amax_ptr, + cudaStream_t stream); template -void invokeSoftmaxWithRelPosBiasCOL32(int8_t* a_buf, - int8_t* qk_buf_int8, - const T* attn_mask, - const T* relative_pos_bias, - const int batch_size, - const int num_head, - const int 
window_num, - const int window_len, - const float scalar, +void invokeSoftmaxWithRelPosBiasCOL32(int8_t* a_buf, + int8_t* qk_buf_int8, + const T* attn_mask, + const T* relative_pos_bias, + const int batch_size, + const int num_head, + const int window_num, + const int window_len, + const float scalar, const float* deQ_scale_ptr, const float* out_scale_ptr, cudaStream_t stream); diff --git a/src/fastertransformer/kernels/stop_criteria_kernels.cu b/src/fastertransformer/kernels/stop_criteria_kernels.cu index 43b291a72..5420e90e3 100644 --- a/src/fastertransformer/kernels/stop_criteria_kernels.cu +++ b/src/fastertransformer/kernels/stop_criteria_kernels.cu @@ -14,43 +14,54 @@ * limitations under the License. */ +#ifndef CUDART_VERSION +#error CUDART_VERSION Undefined! +#elif (CUDART_VERSION >= 11050) +#include +#else +#include "3rdparty/cub/cub.cuh" +#endif + #include "src/fastertransformer/kernels/stop_criteria_kernels.h" #include "src/fastertransformer/utils/cuda_utils.h" +#include "src/fastertransformer/utils/memory_utils.h" namespace fastertransformer { +constexpr int LENGTH_CRITERION_BLOCKSIZE = 256; + __global__ void stop_words_criterion(const int* output_ids, const int* parent_ids, const int* stop_words, - bool* finished, - size_t id_offset, - size_t stop_words_len, - int batch_size, - int beam_width, - int step) + bool* finished, + size_t id_offset, + size_t stop_words_len, + int batch_size, + int beam_width, + int step) { - const int id = blockIdx.x * blockDim.x + threadIdx.x; + const int id = blockIdx.x * blockDim.x + threadIdx.x; const int batch_idx = blockIdx.y / beam_width; - const int beam_idx = blockIdx.y % beam_width; + const int beam_idx = blockIdx.y % beam_width; const int* base_stop_words = stop_words + batch_idx * 2 * stop_words_len; - const int* base_offsets = base_stop_words + stop_words_len; + const int* base_offsets = base_stop_words + stop_words_len; if (id >= stop_words_len || base_offsets[id] < 0) { return; } - const int item_end = base_offsets[id]; + const int item_end = base_offsets[id]; const int item_start = (id > 0) ? base_offsets[id - 1] : 0; - const int item_size = item_end - item_start; + const int item_size = item_end - item_start; /* The single-token case unconditionally bans the token */ bool should_stop = false; /* Enough previously generated tokens to look for a match */ if (step + 1 >= item_size) { - should_stop = true; - int parent_id = beam_idx; + should_stop = true; + int parent_id = beam_idx; const bool gather_beam = beam_width > 1; for (int token_idx = item_size - 1; token_idx >= 0; token_idx--) { @@ -78,25 +89,75 @@ __global__ void stop_words_criterion(const int* output_ids, } } -void invokeStopWordsCriterion(const int* output_ids, - const int* parent_ids, - const int* stop_words, - bool* finished, - size_t id_offset, - size_t stop_words_len, - int batch_size, - int beam_width, - int step, +void invokeStopWordsCriterion(const int* output_ids, + const int* parent_ids, + const int* stop_words, + bool* finished, + size_t id_offset, + size_t stop_words_len, + int batch_size, + int beam_width, + int step, cudaStream_t stream) { + // Check if we have sampled a word from the stop_words list. If so, stop the sequence. 
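    // The indexing in the kernel above (base_stop_words / base_offsets and the negative-offset
    // early return) implies that, for every batch entry, `stop_words` holds a [2, stop_words_len]
    // int block: row 0 is the flattened tokens of all stop sequences for that entry, row 1 is the
    // exclusive end offset of each sequence, with unused offset slots set to a negative value.
    // A minimal host-side sketch (packStopWords is a hypothetical helper, not an existing
    // FasterTransformer API; requires <vector>) that packs the sequences {{42}, {7, 13}} with
    // stop_words_len = 4 could look like this:
    //
    //     std::vector<int> packStopWords(const std::vector<std::vector<int>>& seqs,
    //                                    size_t stop_words_len)
    //     {
    //         std::vector<int> buf(2 * stop_words_len, -1);  // row 0: tokens, row 1: offsets, padded with -1
    //         size_t pos = 0;
    //         for (size_t i = 0; i < seqs.size(); ++i) {
    //             for (int tok : seqs[i]) {
    //                 buf[pos++] = tok;                                 // flattened stop-sequence tokens
    //             }
    //             buf[stop_words_len + i] = static_cast<int>(pos);      // exclusive end offset of sequence i
    //         }
    //         return buf;  // yields tokens [42, 7, 13, -1] and offsets [1, 3, -1, -1];
    //                      // one such block per batch entry, copied to the device buffer
    //     }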
dim3 block, grid; block.x = min(((stop_words_len + 32 - 1) / 32) * 32, 256UL); - grid.x = (stop_words_len + block.x - 1) / block.x; - grid.y = batch_size * beam_width; + grid.x = (stop_words_len + block.x - 1) / block.x; + grid.y = batch_size * beam_width; stop_words_criterion<<>>( output_ids, parent_ids, stop_words, finished, id_offset, stop_words_len, batch_size, beam_width, step); sync_check_cuda_error(); } +__global__ void length_criterion(bool* finished, + bool* should_stop, + int* finished_sum, + const uint32_t* sequence_limit_length, + int batch_size, + int beam_width, + int step) +{ + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + const int index = blockIdx.x * blockDim.x + threadIdx.x; + + // const int beam_idx = index % beam_width; + const int batch_idx = index / beam_width; + + if (index >= batch_size * beam_width) { + return; + } + + finished[index] |= step >= sequence_limit_length[batch_idx]; + + int agg = BlockReduce(temp_storage).Sum((int)finished[index]); + atomicAdd(finished_sum, agg); +} + +void invokeLengthCriterion(bool* finished, + bool* should_stop, + int* finished_sum, + const uint32_t* sequence_limit_length, + int batch_size, + int beam_width, + int step, + cudaStream_t stream) +{ + // Check if we have attained the sequence length limit. If so, stop the sequence. + // In addition, check if all sequences are stopped and return the result in should_stop + dim3 block{LENGTH_CRITERION_BLOCKSIZE}; + dim3 grid{(batch_size * beam_width + block.x - 1) / block.x}; + + length_criterion<<>>( + finished, should_stop, finished_sum, sequence_limit_length, batch_size, beam_width, step); + + int h_finished_sum = 0; + cudaD2Hcpy(&h_finished_sum, finished_sum, 1); + + *should_stop = h_finished_sum == batch_size * beam_width; +} + } // namespace fastertransformer diff --git a/src/fastertransformer/kernels/stop_criteria_kernels.h b/src/fastertransformer/kernels/stop_criteria_kernels.h index 88ee1edfd..2723284c0 100644 --- a/src/fastertransformer/kernels/stop_criteria_kernels.h +++ b/src/fastertransformer/kernels/stop_criteria_kernels.h @@ -19,15 +19,24 @@ namespace fastertransformer { -void invokeStopWordsCriterion(const int* output_ids, - const int* parent_ids, - const int* stop_words, - bool* finished, - size_t id_offset, - size_t stop_words_len, - int batch_size, - int beam_width, - int step, +void invokeStopWordsCriterion(const int* output_ids, + const int* parent_ids, + const int* stop_words, + bool* finished, + size_t id_offset, + size_t stop_words_len, + int batch_size, + int beam_width, + int step, cudaStream_t stream); +void invokeLengthCriterion(bool* finished, + bool* should_stop, + int* finished_sum, + const uint32_t* sequence_limit_length, + int batch_size, + int beam_width, + int step, + cudaStream_t stream); + } // namespace fastertransformer diff --git a/src/fastertransformer/kernels/transform_mask_kernels.cu b/src/fastertransformer/kernels/transform_mask_kernels.cu index b645049d5..96c3efb19 100644 --- a/src/fastertransformer/kernels/transform_mask_kernels.cu +++ b/src/fastertransformer/kernels/transform_mask_kernels.cu @@ -27,8 +27,8 @@ namespace fastertransformer { // in transformed_mask, the masks of one warp are stored in 4 continuous rows ([4, 64]), with two elements of one thread // stored in 2 continuous halfs. 
one cta calculates warps_m*warps_n mma == 16*warps_m*16*warps_n elements grid(B, // S2*S2/64) block(32) -__global__ void transform_mask_kernel(half2* tranformed_mask, - const half2* mask, +__global__ void transform_mask_kernel(half2* tranformed_mask, + const half2* mask, const uint32_t warps_m, const uint32_t warps_n, const uint32_t B, @@ -36,31 +36,31 @@ __global__ void transform_mask_kernel(half2* tranformed_mask, const uint32_t S2) { const int bi = blockIdx.x; - const int r = blockIdx.y; + const int r = blockIdx.y; - const int N_per_XMMAS = warps_n << 4; - const int M_per_XMMAS = warps_m << 4; - const int N_XMMAS = (S2 + N_per_XMMAS - 1) / (N_per_XMMAS); - const int warps_in_XMMAS = warps_m * warps_n; - const half2* mask_b = mask + ((bi * S * S) >> 1); - half2* tranformed_mask_b = tranformed_mask + (bi * gridDim.y << 5); //((bi * gridDim.y << 6) >> 1); + const int N_per_XMMAS = warps_n << 4; + const int M_per_XMMAS = warps_m << 4; + const int N_XMMAS = (S2 + N_per_XMMAS - 1) / (N_per_XMMAS); + const int warps_in_XMMAS = warps_m * warps_n; + const half2* mask_b = mask + ((bi * S * S) >> 1); + half2* tranformed_mask_b = tranformed_mask + (bi * gridDim.y << 5); //((bi * gridDim.y << 6) >> 1); half2 tmp = {half(-30000.0f), half(-30000.0f)}; - int c = threadIdx.x * 2; - int elt_offset = c % 2; - int warp_id = r / 4; - int elt_in_thread = (r % 4) * 2 + elt_offset; + int c = threadIdx.x * 2; + int elt_offset = c % 2; + int warp_id = r / 4; + int elt_in_thread = (r % 4) * 2 + elt_offset; int noffset_in_warp = (((elt_in_thread & 3) >> 1) << 3) + (elt_in_thread & 1); int moffset_in_warp = ((elt_in_thread >> 2) & 1) << 3; - int XMMAS_mi = warp_id / (N_XMMAS * warps_in_XMMAS); - int XMMAS_ni = warp_id % (N_XMMAS * warps_in_XMMAS) / warps_in_XMMAS; + int XMMAS_mi = warp_id / (N_XMMAS * warps_in_XMMAS); + int XMMAS_ni = warp_id % (N_XMMAS * warps_in_XMMAS) / warps_in_XMMAS; int warp_id_in_XMMAS = warp_id - (XMMAS_mi * N_XMMAS + XMMAS_ni) * warps_in_XMMAS; - int warp_mi = warp_id_in_XMMAS % warps_m; - int warp_ni = warp_id_in_XMMAS / warps_m; - int noffset = XMMAS_ni * N_per_XMMAS + (warp_ni << 4) + noffset_in_warp; - int moffset = XMMAS_mi * M_per_XMMAS + (warp_mi << 4) + moffset_in_warp; + int warp_mi = warp_id_in_XMMAS % warps_m; + int warp_ni = warp_id_in_XMMAS / warps_m; + int noffset = XMMAS_ni * N_per_XMMAS + (warp_ni << 4) + noffset_in_warp; + int moffset = XMMAS_mi * M_per_XMMAS + (warp_mi << 4) + moffset_in_warp; int mi = moffset + (c >> 3); int ni = noffset + (((c >> 1) & 3) << 1); @@ -78,8 +78,8 @@ __global__ void transform_mask_kernel(half2* tranformed_mask, // in transformed_mask, the masks of one warp are stored in 4 continuous rows ([4, 64]), with two elements of one thread // stored in 2 continuous halfs. 
one cta calculates warps_m*warps_n mma == 16*warps_m*16*warps_n elements grid(B, // S2*S2/64) block(32) -__global__ void transform_mask_kernel(half* tranformed_mask, - const half* mask, +__global__ void transform_mask_kernel(half* tranformed_mask, + const half* mask, const uint32_t warps_m, const uint32_t warps_n, const uint32_t B, @@ -87,30 +87,30 @@ __global__ void transform_mask_kernel(half* tranformed_mask, const uint32_t S2) { const int bi = blockIdx.x; - const int r = blockIdx.y; + const int r = blockIdx.y; - const int N_per_XMMAS = warps_n << 4; - const int M_per_XMMAS = warps_m << 4; - const int N_XMMAS = (S2 + N_per_XMMAS - 1) / (N_per_XMMAS); - const int warps_in_XMMAS = warps_m * warps_n; - half2* tranformed_mask_b = (half2*)(tranformed_mask + (bi * gridDim.y << 6)); + const int N_per_XMMAS = warps_n << 4; + const int M_per_XMMAS = warps_m << 4; + const int N_XMMAS = (S2 + N_per_XMMAS - 1) / (N_per_XMMAS); + const int warps_in_XMMAS = warps_m * warps_n; + half2* tranformed_mask_b = (half2*)(tranformed_mask + (bi * gridDim.y << 6)); half2 tmp = {half(-30000.0f), half(-30000.0f)}; - int c = threadIdx.x * 2; - int elt_offset = c % 2; - int warp_id = r / 4; - int elt_in_thread = (r % 4) * 2 + elt_offset; + int c = threadIdx.x * 2; + int elt_offset = c % 2; + int warp_id = r / 4; + int elt_in_thread = (r % 4) * 2 + elt_offset; int noffset_in_warp = (((elt_in_thread & 3) >> 1) << 3) + (elt_in_thread & 1); int moffset_in_warp = ((elt_in_thread >> 2) & 1) << 3; - int XMMAS_mi = warp_id / (N_XMMAS * warps_in_XMMAS); - int XMMAS_ni = warp_id % (N_XMMAS * warps_in_XMMAS) / warps_in_XMMAS; + int XMMAS_mi = warp_id / (N_XMMAS * warps_in_XMMAS); + int XMMAS_ni = warp_id % (N_XMMAS * warps_in_XMMAS) / warps_in_XMMAS; int warp_id_in_XMMAS = warp_id - (XMMAS_mi * N_XMMAS + XMMAS_ni) * warps_in_XMMAS; - int warp_mi = warp_id_in_XMMAS % warps_m; - int warp_ni = warp_id_in_XMMAS / warps_m; - int noffset = XMMAS_ni * N_per_XMMAS + (warp_ni << 4) + noffset_in_warp; - int moffset = XMMAS_mi * M_per_XMMAS + (warp_mi << 4) + moffset_in_warp; + int warp_mi = warp_id_in_XMMAS % warps_m; + int warp_ni = warp_id_in_XMMAS / warps_m; + int noffset = XMMAS_ni * N_per_XMMAS + (warp_ni << 4) + noffset_in_warp; + int moffset = XMMAS_mi * M_per_XMMAS + (warp_mi << 4) + moffset_in_warp; int mi = moffset + (c >> 3); int ni = noffset + (((c >> 1) & 3) << 1); @@ -141,17 +141,17 @@ void invokeTransformMask( S2 = 128; } else if (S <= 256) { - S2 = 256; + S2 = 256; warps_m = 1; warps_n = 4; } else if (S <= 384) { - S2 = 384; + S2 = 384; warps_m = 1; warps_n = 8; } else { - printf("[ERROR][invokeTransformMask]unsupport seq_len %d\n", S); + printf("[ERROR][invokeTransformMask]unsupported seq_len %d\n", S); exit(-1); } assert(S2 * S2 % 64 == 0); diff --git a/src/fastertransformer/kernels/transpose_int8_kernels.cu b/src/fastertransformer/kernels/transpose_int8_kernels.cu index 7de8300f8..8d6e6020c 100644 --- a/src/fastertransformer/kernels/transpose_int8_kernels.cu +++ b/src/fastertransformer/kernels/transpose_int8_kernels.cu @@ -23,23 +23,23 @@ namespace fastertransformer { // grid(seq_len, batch_size) // block(size_per_head/4, head_num) // assume size_per_head is multiples of 32 -__global__ void transpose_COL32_kernel(char4* dst, - const int4* src, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, +__global__ void transpose_COL32_kernel(char4* dst, + const int4* src, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, const float* 
v_buf_addBias_deQFactor, const float* qk_afterSM_deQFactor, const float* out_scale_ptr, - const int batch_size_x_seq_len, - const int seq_len_x_size_per_head) + const int batch_size_x_seq_len, + const int seq_len_x_size_per_head) { - const float scale = __ldg(v_buf_addBias_deQFactor) * __ldg(qk_afterSM_deQFactor) * __ldg(out_scale_ptr); - int threadIdx4 = threadIdx.x << 2; - int batch_id = blockIdx.y; - int seq_id = blockIdx.x; - int head_id = threadIdx.y; + const float scale = __ldg(v_buf_addBias_deQFactor) * __ldg(qk_afterSM_deQFactor) * __ldg(out_scale_ptr); + int threadIdx4 = threadIdx.x << 2; + int batch_id = blockIdx.y; + int seq_id = blockIdx.x; + int head_id = threadIdx.y; // get the (row, col) output layout of m*k // m = batch_size*seq_len // k = head_num*size_per_head @@ -64,19 +64,19 @@ __global__ void transpose_COL32_kernel(char4* dst, char4 tmp; int4 srcTmp4 = __ldg(src + inIdx); - tmp.x = float_to_int8_rn(srcTmp4.x * scale); - tmp.y = float_to_int8_rn(srcTmp4.y * scale); - tmp.z = float_to_int8_rn(srcTmp4.z * scale); - tmp.w = float_to_int8_rn(srcTmp4.w * scale); - dst[outIdx] = tmp; + tmp.x = float_to_int8_rn(srcTmp4.x * scale); + tmp.y = float_to_int8_rn(srcTmp4.y * scale); + tmp.z = float_to_int8_rn(srcTmp4.z * scale); + tmp.w = float_to_int8_rn(srcTmp4.w * scale); + dst[outIdx] = tmp; } -void invokeTransposeCOL32(int8_t* dst, - const int* src, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, +void invokeTransposeCOL32(int8_t* dst, + const int* src, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, const float* v_buf_addBias_deQFactor, const float* qk_afterSM_deQFactor, const float* out_scale_ptr, @@ -102,21 +102,21 @@ void invokeTransposeCOL32(int8_t* dst, // grid(seq_len, batch_size) // block(size_per_head/4, head_num) // assume size_per_head is multiples of 32 -__global__ void transpose_COL32_kernel(int8_t* dst, +__global__ void transpose_COL32_kernel(int8_t* dst, const int8_t* src, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* bmm2_deQFactor, - const float* out_scale_ptr, - const int batch_size_x_seq_len, - const int seq_len_x_size_per_head) + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* bmm2_deQFactor, + const float* out_scale_ptr, + const int batch_size_x_seq_len, + const int seq_len_x_size_per_head) { int threadIdx4 = threadIdx.x << 2; - int batch_id = blockIdx.y; - int seq_id = blockIdx.x; - int head_id = threadIdx.y; + int batch_id = blockIdx.y; + int seq_id = blockIdx.x; + int head_id = threadIdx.y; // get the (row, col) output layout of m*k // m = batch_size*seq_len // k = head_num*size_per_head @@ -125,7 +125,7 @@ __global__ void transpose_COL32_kernel(int8_t* dst, // get the (row, col) layout of COL32; leading dimension = 32*m = 32*batch_size*seq_len int COL32_row = (mk_row << 5) + (mk_col & 31); int COL32_col = mk_col >> 5; - int outIdx = ((COL32_col << 5) * batch_size_x_seq_len + COL32_row) >> 2; + int outIdx = ((COL32_col << 5) * batch_size_x_seq_len + COL32_row) >> 2; // get the (row, col) input layout of m'*k' // m' = seq_len @@ -139,19 +139,19 @@ __global__ void transpose_COL32_kernel(int8_t* dst, int inIdx = ((batch_id * head_num + head_id) * seq_len_x_size_per_head + (COL32_col << 5) * seq_len + COL32_row) >> 2; const char4* src_ptr4 = (const char4*)src; - char4* dst_ptr4 = (char4*)dst; - dst_ptr4[outIdx] = __ldg(src_ptr4 + inIdx); + char4* 
dst_ptr4 = (char4*)dst; + dst_ptr4[outIdx] = __ldg(src_ptr4 + inIdx); } -void invokeTransposeCOL32(int8_t* dst, +void invokeTransposeCOL32(int8_t* dst, const int8_t* src, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* bmm2_deQFactor, - const float* out_scale_ptr, - cudaStream_t stream) + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* bmm2_deQFactor, + const float* out_scale_ptr, + cudaStream_t stream) { assert(size_per_head % 32 == 0); transpose_COL32_kernel<<>>( @@ -172,24 +172,24 @@ void invokeTransposeCOL32(int8_t* dst, // grid(seq_len, batch_size) // block(size_per_head/4, head_num) // assume size_per_head is multiples of 32 -__global__ void transpose_COL32_rebuild_padding_kernel(int8_t* dst, +__global__ void transpose_COL32_rebuild_padding_kernel(int8_t* dst, const int32_t* src, - const int* sequence_id_map, - const int valid_word_num, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* v_buf_addBias_deQFactor, - const float* qk_afterSM_deQFactor, - const float* out_scale_ptr, - const int seq_len_x_size_per_head) + const int* sequence_id_map, + const int valid_word_num, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* v_buf_addBias_deQFactor, + const float* qk_afterSM_deQFactor, + const float* out_scale_ptr, + const int seq_len_x_size_per_head) { - const float scale = __ldg(v_buf_addBias_deQFactor) * __ldg(qk_afterSM_deQFactor) * __ldg(out_scale_ptr); - int threadIdx4 = threadIdx.x << 2; - int batch_id = blockIdx.y; - int seq_id = blockIdx.x; - int head_id = threadIdx.y; + const float scale = __ldg(v_buf_addBias_deQFactor) * __ldg(qk_afterSM_deQFactor) * __ldg(out_scale_ptr); + int threadIdx4 = threadIdx.x << 2; + int batch_id = blockIdx.y; + int seq_id = blockIdx.x; + int head_id = threadIdx.y; // get the (row, col) output layout of m*k // m = valid_word_num // k = head_num*size_per_head @@ -199,7 +199,7 @@ __global__ void transpose_COL32_rebuild_padding_kernel(int8_t* dst, // get the (row, col) layout of COL32; leading dimension = 32*m = 32*valid_word_num int COL32_row = (mk_row << 5) + (mk_col & 31); int COL32_col = mk_col >> 5; - int outIdx = ((COL32_col << 5) * valid_word_num + COL32_row) >> 2; + int outIdx = ((COL32_col << 5) * valid_word_num + COL32_row) >> 2; // get the (row, col) input layout of m'*k' // m' = seq_len @@ -212,23 +212,23 @@ __global__ void transpose_COL32_rebuild_padding_kernel(int8_t* dst, int inIdx = (batch_id * head_num + head_id) * seq_len_x_size_per_head + (COL32_col << 5) * seq_len + COL32_row; char4 tmp; - tmp.x = float_to_int8_rn(__ldg(src + inIdx) * scale); - tmp.y = float_to_int8_rn(__ldg(src + inIdx + 1) * scale); - tmp.z = float_to_int8_rn(__ldg(src + inIdx + 2) * scale); - tmp.w = float_to_int8_rn(__ldg(src + inIdx + 3) * scale); - char4* dst_ptr4 = (char4*)dst; + tmp.x = float_to_int8_rn(__ldg(src + inIdx) * scale); + tmp.y = float_to_int8_rn(__ldg(src + inIdx + 1) * scale); + tmp.z = float_to_int8_rn(__ldg(src + inIdx + 2) * scale); + tmp.w = float_to_int8_rn(__ldg(src + inIdx + 3) * scale); + char4* dst_ptr4 = (char4*)dst; dst_ptr4[outIdx] = tmp; } } -void invokeTransposeCOL32RebuildPadding(int8_t* dst, - const int* src, - const int* sequence_id_map, - const int valid_word_num, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, +void 
invokeTransposeCOL32RebuildPadding(int8_t* dst, + const int* src, + const int* sequence_id_map, + const int valid_word_num, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, const float* v_buf_addBias_deQFactor, const float* qk_afterSM_deQFactor, const float* out_scale_ptr, @@ -255,22 +255,22 @@ void invokeTransposeCOL32RebuildPadding(int8_t* dst, // grid(seq_len, batch_size) // block(size_per_head/4, head_num) // assume size_per_head is multiples of 32 -__global__ void transpose_COL32_rebuild_padding_kernel(int8_t* dst, +__global__ void transpose_COL32_rebuild_padding_kernel(int8_t* dst, const int8_t* src, - const int* sequence_id_map, - const int valid_word_num, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* bmm2_deQFactor, - const float* out_scale_ptr, - const int seq_len_x_size_per_head) + const int* sequence_id_map, + const int valid_word_num, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* bmm2_deQFactor, + const float* out_scale_ptr, + const int seq_len_x_size_per_head) { int threadIdx4 = threadIdx.x << 2; - int batch_id = blockIdx.y; - int seq_id = blockIdx.x; - int head_id = threadIdx.y; + int batch_id = blockIdx.y; + int seq_id = blockIdx.x; + int head_id = threadIdx.y; // get the (row, col) output layout of m*k // m = valid_word_num // k = head_num*size_per_head @@ -280,7 +280,7 @@ __global__ void transpose_COL32_rebuild_padding_kernel(int8_t* dst, // get the (row, col) layout of COL32; leading dimension = 32*m = 32*valid_word_num int COL32_row = (mk_row << 5) + (mk_col & 31); int COL32_col = mk_col >> 5; - int outIdx = ((COL32_col << 5) * valid_word_num + COL32_row) >> 2; + int outIdx = ((COL32_col << 5) * valid_word_num + COL32_row) >> 2; // get the (row, col) input layout of m'*k' // m' = seq_len @@ -296,22 +296,22 @@ __global__ void transpose_COL32_rebuild_padding_kernel(int8_t* dst, const char4* src_ptr4 = (const char4*)src; - char4* dst_ptr4 = (char4*)dst; + char4* dst_ptr4 = (char4*)dst; dst_ptr4[outIdx] = __ldg(src_ptr4 + inIdx); } } -void invokeTransposeCOL32RebuildPadding(int8_t* dst, +void invokeTransposeCOL32RebuildPadding(int8_t* dst, const int8_t* src, - const int* sequence_id_map, - const int valid_word_num, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* bmm2_deQFactor, - const float* out_scale_ptr, - cudaStream_t stream) + const int* sequence_id_map, + const int valid_word_num, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* bmm2_deQFactor, + const float* out_scale_ptr, + cudaStream_t stream) { assert(size_per_head % 32 == 0); transpose_COL32_rebuild_padding_kernel<<>>( @@ -333,21 +333,21 @@ void invokeTransposeCOL32RebuildPadding(int8_t* dst, // grid(seq_len, batch_size) // block(size_per_head/4, head_num) // assume size_per_head is multiples of 32 -__global__ void transpose_COL32_ROW_kernel(int8_t* dst, +__global__ void transpose_COL32_ROW_kernel(int8_t* dst, const int8_t* src, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* bmm2_deQFactor, - const float* out_scale_ptr, - const int head_num_x_size_per_head, - const int seq_len_x_size_per_head) + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* bmm2_deQFactor, + const float* out_scale_ptr, + const int 
head_num_x_size_per_head, + const int seq_len_x_size_per_head) { int threadIdx4 = threadIdx.x << 2; - int batch_id = blockIdx.y; - int seq_id = blockIdx.x; - int head_id = threadIdx.y; + int batch_id = blockIdx.y; + int seq_id = blockIdx.x; + int head_id = threadIdx.y; // get the (row, col) output layout of m*k // m = batch_size*seq_len // k = head_num*size_per_head @@ -367,19 +367,19 @@ __global__ void transpose_COL32_ROW_kernel(int8_t* dst, int inIdx = ((batch_id * head_num + head_id) * seq_len_x_size_per_head + (COL32_col << 5) * seq_len + COL32_row) >> 2; const char4* src_ptr4 = (const char4*)src; - char4* dst_ptr4 = (char4*)dst; - dst_ptr4[outIdx] = __ldg(src_ptr4 + inIdx); + char4* dst_ptr4 = (char4*)dst; + dst_ptr4[outIdx] = __ldg(src_ptr4 + inIdx); } -void invokeTransposeCOL32ToRow(int8_t* dst, +void invokeTransposeCOL32ToRow(int8_t* dst, const int8_t* src, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* bmm2_deQFactor, - const float* out_scale_ptr, - cudaStream_t stream) + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* bmm2_deQFactor, + const float* out_scale_ptr, + cudaStream_t stream) { assert(size_per_head % 32 == 0); transpose_COL32_ROW_kernel<<>>( @@ -400,23 +400,23 @@ void invokeTransposeCOL32ToRow(int8_t* dst, // grid(seq_len, batch_size) // block(size_per_head/4, head_num) // assume size_per_head is multiples of 32 -__global__ void transpose_COL32_ROW_rebuild_padding_kernel(int8_t* dst, +__global__ void transpose_COL32_ROW_rebuild_padding_kernel(int8_t* dst, const int8_t* src, - const int* sequence_id_map, - const int valid_word_num, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* bmm2_deQFactor, - const float* out_scale_ptr, - const int seq_len_x_size_per_head, - const int head_num_x_size_per_head) + const int* sequence_id_map, + const int valid_word_num, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* bmm2_deQFactor, + const float* out_scale_ptr, + const int seq_len_x_size_per_head, + const int head_num_x_size_per_head) { int threadIdx4 = threadIdx.x << 2; - int batch_id = blockIdx.y; - int seq_id = blockIdx.x; - int head_id = threadIdx.y; + int batch_id = blockIdx.y; + int seq_id = blockIdx.x; + int head_id = threadIdx.y; // get the (row, col) output layout of m*k // m = valid_word_num // k = head_num*size_per_head @@ -439,22 +439,22 @@ __global__ void transpose_COL32_ROW_rebuild_padding_kernel(int8_t* dst, const char4* src_ptr4 = (const char4*)src; - char4* dst_ptr4 = (char4*)dst; + char4* dst_ptr4 = (char4*)dst; dst_ptr4[outIdx] = __ldg(src_ptr4 + inIdx); } } -void invokeTransposeCOL32ToRowRebuildPadding(int8_t* dst, +void invokeTransposeCOL32ToRowRebuildPadding(int8_t* dst, const int8_t* src, - const int* sequence_id_map, - const int valid_word_num, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* bmm2_deQFactor, - const float* out_scale_ptr, - cudaStream_t stream) + const int* sequence_id_map, + const int valid_word_num, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* bmm2_deQFactor, + const float* out_scale_ptr, + cudaStream_t stream) { assert(size_per_head % 32 == 0); transpose_COL32_ROW_rebuild_padding_kernel<< -__global__ void add_QK_bias_transform(int8_t* q_buf_, - int8_t* k_buf_, +__global__ void 
add_QK_bias_transform(int8_t* q_buf_, + int8_t* k_buf_, const int32_t* Q, - const T* bias_Q, + const T* bias_Q, const int32_t* K, - const T* bias_K, - const int m, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - int stride, - const float* q_weight_amax, - const float* q_input_deQFactor_div127_ptr, - const float* k_weight_amax, - const float* k_input_deQFactor_div127_ptr, - const float* q_output_scale_ptr, - const float* k_output_scale_ptr, - bool use_ORDER_COL32_2R_4R4) + const T* bias_K, + const int m, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + int stride, + const float* q_weight_amax, + const float* q_input_deQFactor_div127_ptr, + const float* k_weight_amax, + const float* k_input_deQFactor_div127_ptr, + const float* q_output_scale_ptr, + const float* k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4) { const int32_t* data_ptr; - char4* buf_ptr4; - const T* bias_ptr; - const float* weight_amax; - int qk_id = blockIdx.x / m; + char4* buf_ptr4; + const T* bias_ptr; + const float* weight_amax; + int qk_id = blockIdx.x / m; data_ptr = qk_id == 0 ? Q : K; buf_ptr4 = qk_id == 0 ? (char4*)q_buf_ : (char4*)k_buf_; bias_ptr = qk_id == 0 ? bias_Q : bias_K; const float input_deQFactor_div127 = qk_id == 0 ? __ldg(q_input_deQFactor_div127_ptr) : __ldg(k_input_deQFactor_div127_ptr); - weight_amax = qk_id == 0 ? q_weight_amax : k_weight_amax; + weight_amax = qk_id == 0 ? q_weight_amax : k_weight_amax; const float output_scale = qk_id == 0 ? __ldg(q_output_scale_ptr) : __ldg(k_output_scale_ptr); int threadIdx4 = threadIdx.x << 2; - int batch_id = (blockIdx.x % m) / seq_len; - int head_id = threadIdx4 / size_per_head; + int batch_id = (blockIdx.x % m) / seq_len; + int head_id = threadIdx4 / size_per_head; int id_in_head = threadIdx4 % size_per_head; - int word_id = blockIdx.x % seq_len; + int word_id = blockIdx.x % seq_len; int data_id = (((threadIdx4 >> 5) << 5) * m + ((blockIdx.x % m) << 5) + (threadIdx4 & 31)); float scale; float tmp; char4 tmp4; - scale = static_cast(__ldg(data_ptr + data_id)) * __ldg(weight_amax + threadIdx4) * input_deQFactor_div127; - tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; + scale = static_cast(__ldg(data_ptr + data_id)) * __ldg(weight_amax + threadIdx4) * input_deQFactor_div127; + tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; tmp4.x = float_to_int8_rn(tmp * output_scale); - data_id = data_id + 1; + data_id = data_id + 1; threadIdx4 = threadIdx4 + 1; - scale = static_cast(__ldg(data_ptr + data_id)) * __ldg(weight_amax + threadIdx4) * input_deQFactor_div127; - tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; + scale = static_cast(__ldg(data_ptr + data_id)) * __ldg(weight_amax + threadIdx4) * input_deQFactor_div127; + tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; tmp4.y = float_to_int8_rn(tmp * output_scale); - data_id = data_id + 1; + data_id = data_id + 1; threadIdx4 = threadIdx4 + 1; - scale = static_cast(__ldg(data_ptr + data_id)) * __ldg(weight_amax + threadIdx4) * input_deQFactor_div127; - tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; + scale = static_cast(__ldg(data_ptr + data_id)) * __ldg(weight_amax + threadIdx4) * input_deQFactor_div127; + tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; tmp4.z = float_to_int8_rn(tmp * output_scale); - data_id = data_id + 1; + data_id = data_id + 1; threadIdx4 = threadIdx4 + 1; - scale = static_cast(__ldg(data_ptr + data_id)) * __ldg(weight_amax + threadIdx4) * 
input_deQFactor_div127; - tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; + scale = static_cast(__ldg(data_ptr + data_id)) * __ldg(weight_amax + threadIdx4) * input_deQFactor_div127; + tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; tmp4.w = float_to_int8_rn(tmp * output_scale); // row_id, col_id of sub-matrix (m = seq_len, n = size_per_head), column-major @@ -138,10 +138,10 @@ __global__ void add_QK_bias_transform(int8_t* q_buf_, if (use_ORDER_COL32_2R_4R4) { int row_in_tile = row_id & 31; int col_in_tile = col_id & 31; - new_row = (qk_id != 1) ? - // COL32 + new_row = (qk_id != 1) ? + // COL32 ((row_id << 5) + (col_id & 31)) : - // COL32_2R_4R4 + // COL32_2R_4R4 (((row_id >> 5) << 10) + //(((row%8)/2*4+row/8)*2+row%2)*32+col (((((((row_in_tile % 8) >> 1) << 2) + (row_in_tile >> 3)) << 1) + (row_in_tile & 1)) << 5) @@ -179,73 +179,73 @@ __global__ void add_QK_bias_transform(int8_t* q_buf_, // block.x = head_num * size_per_head / 4; // using char4 template -__global__ void add_QK_bias_transform_varlen(int8_t* q_buf_, - int8_t* k_buf_, +__global__ void add_QK_bias_transform_varlen(int8_t* q_buf_, + int8_t* k_buf_, const int32_t* Q, - const T* bias_Q, + const T* bias_Q, const int32_t* K, - const T* bias_K, - const int m, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const int seq_len_padded, - int stride_q, - int stride_k, - const float* q_weight_amax, - const float* q_input_deQFactor_div127_ptr, - const float* k_weight_amax, - const float* k_input_deQFactor_div127_ptr, - const float* q_output_scale_ptr, - const float* k_output_scale_ptr, - bool use_ORDER_COL32_2R_4R4) + const T* bias_K, + const int m, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const int seq_len_padded, + int stride_q, + int stride_k, + const float* q_weight_amax, + const float* q_input_deQFactor_div127_ptr, + const float* k_weight_amax, + const float* k_input_deQFactor_div127_ptr, + const float* q_output_scale_ptr, + const float* k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4) { const int32_t* data_ptr; - char4* buf_ptr4; - const T* bias_ptr; - const float* weight_amax; - int qk_id = blockIdx.x / m; + char4* buf_ptr4; + const T* bias_ptr; + const float* weight_amax; + int qk_id = blockIdx.x / m; data_ptr = qk_id == 0 ? Q : K; buf_ptr4 = qk_id == 0 ? (char4*)q_buf_ : (char4*)k_buf_; bias_ptr = qk_id == 0 ? bias_Q : bias_K; const float input_deQFactor_div127 = qk_id == 0 ? __ldg(q_input_deQFactor_div127_ptr) : __ldg(k_input_deQFactor_div127_ptr); - weight_amax = qk_id == 0 ? q_weight_amax : k_weight_amax; + weight_amax = qk_id == 0 ? q_weight_amax : k_weight_amax; const float output_scale = qk_id == 0 ? 
__ldg(q_output_scale_ptr) : __ldg(k_output_scale_ptr); int threadIdx4 = threadIdx.x << 2; - int batch_id = (blockIdx.x % m) / seq_len; - int head_id = threadIdx4 / size_per_head; + int batch_id = (blockIdx.x % m) / seq_len; + int head_id = threadIdx4 / size_per_head; int id_in_head = threadIdx4 % size_per_head; - int word_id = blockIdx.x % seq_len; + int word_id = blockIdx.x % seq_len; int data_id = (((threadIdx4 >> 5) << 5) * m + ((blockIdx.x % m) << 5) + (threadIdx4 & 31)); float scale; float tmp; char4 tmp4; - scale = static_cast(__ldg(data_ptr + data_id)) * __ldg(weight_amax + threadIdx4) * input_deQFactor_div127; - tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; + scale = static_cast(__ldg(data_ptr + data_id)) * __ldg(weight_amax + threadIdx4) * input_deQFactor_div127; + tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; tmp4.x = float_to_int8_rn(tmp * output_scale); - data_id = data_id + 1; + data_id = data_id + 1; threadIdx4 = threadIdx4 + 1; - scale = static_cast(__ldg(data_ptr + data_id)) * __ldg(weight_amax + threadIdx4) * input_deQFactor_div127; - tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; + scale = static_cast(__ldg(data_ptr + data_id)) * __ldg(weight_amax + threadIdx4) * input_deQFactor_div127; + tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; tmp4.y = float_to_int8_rn(tmp * output_scale); - data_id = data_id + 1; + data_id = data_id + 1; threadIdx4 = threadIdx4 + 1; - scale = static_cast(__ldg(data_ptr + data_id)) * __ldg(weight_amax + threadIdx4) * input_deQFactor_div127; - tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; + scale = static_cast(__ldg(data_ptr + data_id)) * __ldg(weight_amax + threadIdx4) * input_deQFactor_div127; + tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; tmp4.z = float_to_int8_rn(tmp * output_scale); - data_id = data_id + 1; + data_id = data_id + 1; threadIdx4 = threadIdx4 + 1; - scale = static_cast(__ldg(data_ptr + data_id)) * __ldg(weight_amax + threadIdx4) * input_deQFactor_div127; - tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; + scale = static_cast(__ldg(data_ptr + data_id)) * __ldg(weight_amax + threadIdx4) * input_deQFactor_div127; + tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; tmp4.w = float_to_int8_rn(tmp * output_scale); // row_id, col_id of sub-matrix (m = seq_len, n = size_per_head), column-major @@ -258,10 +258,10 @@ __global__ void add_QK_bias_transform_varlen(int8_t* q_buf_, if (use_ORDER_COL32_2R_4R4) { int row_in_tile = row_id & 31; int col_in_tile = col_id & 31; - new_row = (qk_id != 1) ? - // COL32 + new_row = (qk_id != 1) ? + // COL32 ((row_id << 5) + (col_id & 31)) : - // COL32_2R_4R4 + // COL32_2R_4R4 (((row_id >> 5) << 10) + //(((row%8)/2*4+row/8)*2+row%2)*32+col (((((((row_in_tile % 8) >> 1) << 2) + (row_in_tile >> 3)) << 1) + (row_in_tile & 1)) << 5) @@ -284,29 +284,29 @@ __global__ void add_QK_bias_transform_varlen(int8_t* q_buf_, } const int act_seq_len = (qk_id == 0) ? seq_len : seq_len_padded; - const int stride = (qk_id == 0) ? stride_q : stride_k; + const int stride = (qk_id == 0) ? 
stride_q : stride_k; buf_ptr4[(((batch_id * head_num + head_id) * stride + (new_col << 5) * act_seq_len + new_row) >> 2)] = tmp4; } template -void invokeAddQKBiasTransform(int8_t* q_buf, - int8_t* k_buf, +void invokeAddQKBiasTransform(int8_t* q_buf, + int8_t* k_buf, const int32_t* Q, - const T* bias_Q, + const T* bias_Q, const int32_t* K, - const T* bias_K, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* q_weight_amax, - const float* q_input_deQFactor_div127_ptr, - const float* k_weight_amax, - const float* k_input_deQFactor_div127_ptr, - const float* q_output_scale_ptr, - const float* k_output_scale_ptr, - bool use_ORDER_COL32_2R_4R4, - cudaStream_t stream) + const T* bias_K, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* q_weight_amax, + const float* q_input_deQFactor_div127_ptr, + const float* k_weight_amax, + const float* k_input_deQFactor_div127_ptr, + const float* q_output_scale_ptr, + const float* k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4, + cudaStream_t stream) { if (seq_len % 32 == 0) { add_QK_bias_transform<<>>( @@ -359,43 +359,43 @@ void invokeAddQKBiasTransform(int8_t* q_buf, } } -template void invokeAddQKBiasTransform(int8_t* q_buf, - int8_t* k_buf, +template void invokeAddQKBiasTransform(int8_t* q_buf, + int8_t* k_buf, const int32_t* Q, - const float* bias_Q, + const float* bias_Q, const int32_t* K, - const float* bias_K, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* q_weight_amax, - const float* q_input_deQFactor_div127_ptr, - const float* k_weight_amax, - const float* k_input_deQFactor_div127_ptr, - const float* q_output_scale_ptr, - const float* k_output_scale_ptr, - bool use_ORDER_COL32_2R_4R4, - cudaStream_t stream); - -template void invokeAddQKBiasTransform(int8_t* q_buf, - int8_t* k_buf, + const float* bias_K, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* q_weight_amax, + const float* q_input_deQFactor_div127_ptr, + const float* k_weight_amax, + const float* k_input_deQFactor_div127_ptr, + const float* q_output_scale_ptr, + const float* k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4, + cudaStream_t stream); + +template void invokeAddQKBiasTransform(int8_t* q_buf, + int8_t* k_buf, const int32_t* Q, - const half* bias_Q, + const half* bias_Q, const int32_t* K, - const half* bias_K, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* q_weight_amax, - const float* q_input_deQFactor_div127_ptr, - const float* k_weight_amax, - const float* k_input_deQFactor_div127_ptr, - const float* q_output_scale_ptr, - const float* k_output_scale_ptr, - bool use_ORDER_COL32_2R_4R4, - cudaStream_t stream); + const half* bias_K, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* q_weight_amax, + const float* q_input_deQFactor_div127_ptr, + const float* k_weight_amax, + const float* k_input_deQFactor_div127_ptr, + const float* q_output_scale_ptr, + const float* k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4, + cudaStream_t stream); // add_QK_bias_padding_transform for batch int8 cublasLtMatmul & per tensor quantization for weight // 1.add QK bias @@ -411,66 +411,66 @@ template void invokeAddQKBiasTransform(int8_t* q_buf, // block.x = head_num * size_per_head / 4; // using char4 template -__global__ void 
add_QK_bias_transform_varlen(int8_t* q_buf_, - int8_t* k_buf_, +__global__ void add_QK_bias_transform_varlen(int8_t* q_buf_, + int8_t* k_buf_, const int8_t* Q, - const T* bias_Q, + const T* bias_Q, const int8_t* K, - const T* bias_K, - const int m, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const int seq_len_padded, - const int stride_q, - const int stride_k, - const float* q_input_deQFactor_ptr, - const float* k_input_deQFactor_ptr, - const float* q_output_scale_ptr, - const float* k_output_scale_ptr, - bool use_ORDER_COL32_2R_4R4) + const T* bias_K, + const int m, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const int seq_len_padded, + const int stride_q, + const int stride_k, + const float* q_input_deQFactor_ptr, + const float* k_input_deQFactor_ptr, + const float* q_output_scale_ptr, + const float* k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4) { const char4* data_ptr; - char4* buf_ptr4; - const T* bias_ptr; - int qk_id = blockIdx.x / m; + char4* buf_ptr4; + const T* bias_ptr; + int qk_id = blockIdx.x / m; - data_ptr = qk_id == 0 ? (const char4*)Q : (const char4*)K; - buf_ptr4 = qk_id == 0 ? (char4*)q_buf_ : (char4*)k_buf_; - bias_ptr = qk_id == 0 ? bias_Q : bias_K; + data_ptr = qk_id == 0 ? (const char4*)Q : (const char4*)K; + buf_ptr4 = qk_id == 0 ? (char4*)q_buf_ : (char4*)k_buf_; + bias_ptr = qk_id == 0 ? bias_Q : bias_K; const float input_deQFactor = qk_id == 0 ? __ldg(q_input_deQFactor_ptr) : __ldg(k_input_deQFactor_ptr); - const float output_scale = qk_id == 0 ? __ldg(q_output_scale_ptr) : __ldg(k_output_scale_ptr); + const float output_scale = qk_id == 0 ? __ldg(q_output_scale_ptr) : __ldg(k_output_scale_ptr); int threadIdx4 = threadIdx.x << 2; - int batch_id = (blockIdx.x % m) / seq_len; - int head_id = threadIdx4 / size_per_head; + int batch_id = (blockIdx.x % m) / seq_len; + int head_id = threadIdx4 / size_per_head; int id_in_head = threadIdx4 % size_per_head; - int word_id = blockIdx.x % seq_len; + int word_id = blockIdx.x % seq_len; int data_id = (((threadIdx4 >> 5) << 5) * m + ((blockIdx.x % m) << 5) + (threadIdx4 & 31)) >> 2; float scale; float tmp; char4 tmp4 = __ldg(data_ptr + data_id); - scale = static_cast(tmp4.x) * input_deQFactor; - tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; - tmp4.x = float_to_int8_rn(tmp * output_scale); + scale = static_cast(tmp4.x) * input_deQFactor; + tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; + tmp4.x = float_to_int8_rn(tmp * output_scale); threadIdx4 = threadIdx4 + 1; - scale = static_cast(tmp4.y) * input_deQFactor; - tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; - tmp4.y = float_to_int8_rn(tmp * output_scale); + scale = static_cast(tmp4.y) * input_deQFactor; + tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; + tmp4.y = float_to_int8_rn(tmp * output_scale); threadIdx4 = threadIdx4 + 1; - scale = static_cast(tmp4.z) * input_deQFactor; - tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; - tmp4.z = float_to_int8_rn(tmp * output_scale); + scale = static_cast(tmp4.z) * input_deQFactor; + tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; + tmp4.z = float_to_int8_rn(tmp * output_scale); threadIdx4 = threadIdx4 + 1; - scale = static_cast(tmp4.w) * input_deQFactor; + scale = static_cast(tmp4.w) * input_deQFactor; ; - tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; + tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; tmp4.w = float_to_int8_rn(tmp * output_scale); // 
row_id, col_id of sub-matrix (m = seq_len/seq_len_padded, n = size_per_head), column-major @@ -483,10 +483,10 @@ __global__ void add_QK_bias_transform_varlen(int8_t* q_buf_, if (use_ORDER_COL32_2R_4R4) { int row_in_tile = row_id & 31; int col_in_tile = col_id & 31; - new_row = (qk_id != 1) ? - // COL32 + new_row = (qk_id != 1) ? + // COL32 ((row_id << 5) + (col_id & 31)) : - // COL32_2R_4R4 + // COL32_2R_4R4 (((row_id >> 5) << 10) + //(((row%8)/2*4+row/8)*2+row%2)*32+col (((((((row_in_tile % 8) >> 1) << 2) + (row_in_tile >> 3)) << 1) + (row_in_tile & 1)) << 5) @@ -509,7 +509,7 @@ __global__ void add_QK_bias_transform_varlen(int8_t* q_buf_, } const int act_seq_len = (qk_id == 0) ? seq_len : seq_len_padded; - const int stride = (qk_id == 0) ? stride_q : stride_k; + const int stride = (qk_id == 0) ? stride_q : stride_k; buf_ptr4[(((batch_id * head_num + head_id) * stride + (new_col << 5) * act_seq_len + new_row) >> 2)] = tmp4; } @@ -526,64 +526,64 @@ __global__ void add_QK_bias_transform_varlen(int8_t* q_buf_, // block.x = head_num * size_per_head / 4; // using char4 template -__global__ void add_QK_bias_transform(int8_t* q_buf_, - int8_t* k_buf_, +__global__ void add_QK_bias_transform(int8_t* q_buf_, + int8_t* k_buf_, const int8_t* Q, - const T* bias_Q, + const T* bias_Q, const int8_t* K, - const T* bias_K, - const int m, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - int stride, - const float* q_input_deQFactor_ptr, - const float* k_input_deQFactor_ptr, - const float* q_output_scale_ptr, - const float* k_output_scale_ptr, - bool use_ORDER_COL32_2R_4R4) + const T* bias_K, + const int m, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + int stride, + const float* q_input_deQFactor_ptr, + const float* k_input_deQFactor_ptr, + const float* q_output_scale_ptr, + const float* k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4) { const char4* data_ptr; - char4* buf_ptr4; - const T* bias_ptr; - int qk_id = blockIdx.x / m; + char4* buf_ptr4; + const T* bias_ptr; + int qk_id = blockIdx.x / m; - data_ptr = qk_id == 0 ? (const char4*)Q : (const char4*)K; - buf_ptr4 = qk_id == 0 ? (char4*)q_buf_ : (char4*)k_buf_; - bias_ptr = qk_id == 0 ? bias_Q : bias_K; + data_ptr = qk_id == 0 ? (const char4*)Q : (const char4*)K; + buf_ptr4 = qk_id == 0 ? (char4*)q_buf_ : (char4*)k_buf_; + bias_ptr = qk_id == 0 ? bias_Q : bias_K; const float input_deQFactor = qk_id == 0 ? __ldg(q_input_deQFactor_ptr) : __ldg(k_input_deQFactor_ptr); - const float output_scale = qk_id == 0 ? __ldg(q_output_scale_ptr) : __ldg(k_output_scale_ptr); + const float output_scale = qk_id == 0 ? 
__ldg(q_output_scale_ptr) : __ldg(k_output_scale_ptr); int threadIdx4 = threadIdx.x << 2; - int batch_id = (blockIdx.x % m) / seq_len; - int head_id = threadIdx4 / size_per_head; + int batch_id = (blockIdx.x % m) / seq_len; + int head_id = threadIdx4 / size_per_head; int id_in_head = threadIdx4 % size_per_head; - int word_id = blockIdx.x % seq_len; + int word_id = blockIdx.x % seq_len; int data_id = (((threadIdx4 >> 5) << 5) * m + ((blockIdx.x % m) << 5) + (threadIdx4 & 31)) >> 2; float scale; float tmp; char4 tmp4 = __ldg(data_ptr + data_id); - scale = static_cast(tmp4.x) * input_deQFactor; - tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; - tmp4.x = float_to_int8_rn(tmp * output_scale); + scale = static_cast(tmp4.x) * input_deQFactor; + tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; + tmp4.x = float_to_int8_rn(tmp * output_scale); threadIdx4 = threadIdx4 + 1; - scale = static_cast(tmp4.y) * input_deQFactor; - tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; - tmp4.y = float_to_int8_rn(tmp * output_scale); + scale = static_cast(tmp4.y) * input_deQFactor; + tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; + tmp4.y = float_to_int8_rn(tmp * output_scale); threadIdx4 = threadIdx4 + 1; - scale = static_cast(tmp4.z) * input_deQFactor; - tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; - tmp4.z = float_to_int8_rn(tmp * output_scale); + scale = static_cast(tmp4.z) * input_deQFactor; + tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; + tmp4.z = float_to_int8_rn(tmp * output_scale); threadIdx4 = threadIdx4 + 1; - scale = static_cast(tmp4.w) * input_deQFactor; + scale = static_cast(tmp4.w) * input_deQFactor; ; - tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; + tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; tmp4.w = float_to_int8_rn(tmp * output_scale); // row_id, col_id of sub-matrix (m = seq_len, n = size_per_head), column-major @@ -596,10 +596,10 @@ __global__ void add_QK_bias_transform(int8_t* q_buf_, if (use_ORDER_COL32_2R_4R4) { int row_in_tile = row_id & 31; int col_in_tile = col_id & 31; - new_row = (qk_id != 1) ? - // COL32 + new_row = (qk_id != 1) ? 
+ // COL32 ((row_id << 5) + (col_id & 31)) : - // COL32_2R_4R4 + // COL32_2R_4R4 (((row_id >> 5) << 10) + //(((row%8)/2*4+row/8)*2+row%2)*32+col (((((((row_in_tile % 8) >> 1) << 2) + (row_in_tile >> 3)) << 1) + (row_in_tile & 1)) << 5) @@ -625,22 +625,22 @@ __global__ void add_QK_bias_transform(int8_t* q_buf_, } template -void invokeAddQKBiasTransform(int8_t* q_buf, - int8_t* k_buf, +void invokeAddQKBiasTransform(int8_t* q_buf, + int8_t* k_buf, const int8_t* Q, - const T* bias_Q, + const T* bias_Q, const int8_t* K, - const T* bias_K, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* q_input_deQFactor_ptr, - const float* k_input_deQFactor_ptr, - const float* q_output_scale_ptr, - const float* k_output_scale_ptr, - bool use_ORDER_COL32_2R_4R4, - cudaStream_t stream) + const T* bias_K, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* q_input_deQFactor_ptr, + const float* k_input_deQFactor_ptr, + const float* q_output_scale_ptr, + const float* k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4, + cudaStream_t stream) { assert(size_per_head % 32 == 0); if (seq_len % 32 == 0) { @@ -696,39 +696,39 @@ void invokeAddQKBiasTransform(int8_t* q_buf, } } -template void invokeAddQKBiasTransform(int8_t* q_buf, - int8_t* k_buf, +template void invokeAddQKBiasTransform(int8_t* q_buf, + int8_t* k_buf, const int8_t* Q, - const float* bias_Q, + const float* bias_Q, const int8_t* K, - const float* bias_K, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* q_input_deQFactor_ptr, - const float* k_input_deQFactor_ptr, - const float* q_output_scale_ptr, - const float* k_output_scale_ptr, - bool use_ORDER_COL32_2R_4R4, - cudaStream_t stream); - -template void invokeAddQKBiasTransform(int8_t* q_buf, - int8_t* k_buf, + const float* bias_K, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* q_input_deQFactor_ptr, + const float* k_input_deQFactor_ptr, + const float* q_output_scale_ptr, + const float* k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4, + cudaStream_t stream); + +template void invokeAddQKBiasTransform(int8_t* q_buf, + int8_t* k_buf, const int8_t* Q, - const half* bias_Q, + const half* bias_Q, const int8_t* K, - const half* bias_K, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* q_input_deQFactor_ptr, - const float* k_input_deQFactor_ptr, - const float* q_output_scale_ptr, - const float* k_output_scale_ptr, - bool use_ORDER_COL32_2R_4R4, - cudaStream_t stream); + const half* bias_K, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* q_input_deQFactor_ptr, + const float* k_input_deQFactor_ptr, + const float* q_output_scale_ptr, + const float* k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4, + cudaStream_t stream); // add_QK_bias_padding_transform for batch int8 cublasLtMatmul & per tensor quantization for weight // 1.add QK bias @@ -744,68 +744,68 @@ template void invokeAddQKBiasTransform(int8_t* q_buf, // block.x = head_num * size_per_head / 4; // using char4 template -__global__ void add_QK_bias_transform_varlen_row(int8_t* q_buf_, - int8_t* k_buf_, +__global__ void add_QK_bias_transform_varlen_row(int8_t* q_buf_, + int8_t* k_buf_, const int8_t* Q, - const T* bias_Q, + const T* bias_Q, const int8_t* K, - const T* bias_K, - const int m, - const int 
batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const int seq_len_padded, - const int stride_q, - const int stride_k, - const float* q_input_deQFactor_ptr, - const float* k_input_deQFactor_ptr, - const float* q_output_scale_ptr, - const float* k_output_scale_ptr, - bool use_ORDER_COL32_2R_4R4, - const int head_num_x_size_per_head) + const T* bias_K, + const int m, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const int seq_len_padded, + const int stride_q, + const int stride_k, + const float* q_input_deQFactor_ptr, + const float* k_input_deQFactor_ptr, + const float* q_output_scale_ptr, + const float* k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4, + const int head_num_x_size_per_head) { const char4* data_ptr; - char4* buf_ptr4; - const T* bias_ptr; - int qk_id = blockIdx.x / m; + char4* buf_ptr4; + const T* bias_ptr; + int qk_id = blockIdx.x / m; - data_ptr = qk_id == 0 ? (const char4*)Q : (const char4*)K; - buf_ptr4 = qk_id == 0 ? (char4*)q_buf_ : (char4*)k_buf_; - bias_ptr = qk_id == 0 ? bias_Q : bias_K; + data_ptr = qk_id == 0 ? (const char4*)Q : (const char4*)K; + buf_ptr4 = qk_id == 0 ? (char4*)q_buf_ : (char4*)k_buf_; + bias_ptr = qk_id == 0 ? bias_Q : bias_K; const float input_deQFactor = qk_id == 0 ? __ldg(q_input_deQFactor_ptr) : __ldg(k_input_deQFactor_ptr); - const float output_scale = qk_id == 0 ? __ldg(q_output_scale_ptr) : __ldg(k_output_scale_ptr); + const float output_scale = qk_id == 0 ? __ldg(q_output_scale_ptr) : __ldg(k_output_scale_ptr); - int threadIdx4 = threadIdx.x << 2; + int threadIdx4 = threadIdx.x << 2; int batch_seq_id = blockIdx.x % m; - int batch_id = (batch_seq_id) / seq_len; - int head_id = threadIdx4 / size_per_head; - int id_in_head = threadIdx4 % size_per_head; - int word_id = blockIdx.x % seq_len; + int batch_id = (batch_seq_id) / seq_len; + int head_id = threadIdx4 / size_per_head; + int id_in_head = threadIdx4 % size_per_head; + int word_id = blockIdx.x % seq_len; int data_id = (batch_seq_id * head_num_x_size_per_head + threadIdx4) >> 2; float scale; float tmp; char4 tmp4 = __ldg(data_ptr + data_id); - scale = static_cast(tmp4.x) * input_deQFactor; - tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; - tmp4.x = float_to_int8_rn(tmp * output_scale); + scale = static_cast(tmp4.x) * input_deQFactor; + tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; + tmp4.x = float_to_int8_rn(tmp * output_scale); threadIdx4 = threadIdx4 + 1; - scale = static_cast(tmp4.y) * input_deQFactor; - tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; - tmp4.y = float_to_int8_rn(tmp * output_scale); + scale = static_cast(tmp4.y) * input_deQFactor; + tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; + tmp4.y = float_to_int8_rn(tmp * output_scale); threadIdx4 = threadIdx4 + 1; - scale = static_cast(tmp4.z) * input_deQFactor; - tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; - tmp4.z = float_to_int8_rn(tmp * output_scale); + scale = static_cast(tmp4.z) * input_deQFactor; + tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; + tmp4.z = float_to_int8_rn(tmp * output_scale); threadIdx4 = threadIdx4 + 1; - scale = static_cast(tmp4.w) * input_deQFactor; + scale = static_cast(tmp4.w) * input_deQFactor; ; - tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; + tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; tmp4.w = float_to_int8_rn(tmp * output_scale); // row_id, col_id of sub-matrix (m = seq_len/seq_len_padded, n = size_per_head), column-major 
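
// ---------------------------------------------------------------------------
// Illustrative host-side reference (not part of the patch above): the hunks
// around here keep re-touching the same two pieces of logic, the
// COL32 / COL32_2R_4R4 index remapping and the dequantize-add-bias-requantize
// step, so a small standalone sketch may help when reading them.  Helper
// names below are hypothetical; only the formulas are taken from the kernels
// (new_col = col >> 5 and the two new_row expressions).  `rows` is the
// sub-matrix row count (seq_len for Q, seq_len_padded for K); the 2R_4R4 form
// assumes it is a multiple of 32.  The kernels additionally shift the final
// offset right by 2 because they store char4 vectors.
#include <cstdio>

// CUBLASLT_ORDER_COL32: columns are grouped into tiles of 32; inside a tile
// storage is row-major, i.e. offset = (col/32)*32*rows + row*32 + col%32.
static int col32_offset(int row, int col, int rows)
{
    int new_col = col >> 5;                 // which 32-column tile
    int new_row = (row << 5) + (col & 31);  // row*32 + col-within-tile
    return (new_col << 5) * rows + new_row;
}

// CUBLASLT_ORDER_COL32_2R_4R4: same 32-column tiling, but rows are swizzled
// inside each 32x32 tile, following the comment carried in the kernels:
//   (((row%8)/2*4 + row/8)*2 + row%2)*32 + col
static int col32_2r_4r4_offset(int row, int col, int rows)
{
    int new_col     = col >> 5;
    int row_in_tile = row & 31;
    int col_in_tile = col & 31;
    int swizzled    = ((((row_in_tile % 8) / 2) * 4 + (row_in_tile / 8)) * 2
                       + (row_in_tile % 2)) * 32
                      + col_in_tile;
    int new_row = ((row >> 5) << 10) + swizzled;  // each 32x32 tile holds 1024 elements
    return (new_col << 5) * rows + new_row;
}

// Per-element math mirrored from the per-tensor kernels: dequantize the int8
// GEMM output, add the bias, then rescale back to int8.  The round-and-saturate
// behaviour is assumed to match the device helper float_to_int8_rn.
static signed char requantize(signed char v, float input_deQFactor, float bias, float output_scale)
{
    float deq = static_cast<float>(v) * input_deQFactor + bias;
    float out = deq * output_scale;
    int   r   = static_cast<int>(out >= 0.f ? out + 0.5f : out - 0.5f);  // round to nearest
    if (r > 127)  r = 127;
    if (r < -128) r = -128;
    return static_cast<signed char>(r);
}

int main()
{
    // Element (row 3, col 40) of a 64-row sub-matrix under both layouts,
    // plus one requantized sample value.
    printf("COL32        offset: %d\n", col32_offset(3, 40, 64));
    printf("COL32_2R_4R4 offset: %d\n", col32_2r_4r4_offset(3, 40, 64));
    printf("requantized        : %d\n", (int)requantize(17, 0.05f, 0.1f, 12.0f));
    return 0;
}
// ---------------------------------------------------------------------------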
@@ -818,10 +818,10 @@ __global__ void add_QK_bias_transform_varlen_row(int8_t* q_buf_, if (use_ORDER_COL32_2R_4R4) { int row_in_tile = row_id & 31; int col_in_tile = col_id & 31; - new_row = (qk_id != 1) ? - // COL32 + new_row = (qk_id != 1) ? + // COL32 ((row_id << 5) + (col_id & 31)) : - // COL32_2R_4R4 + // COL32_2R_4R4 (((row_id >> 5) << 10) + //(((row%8)/2*4+row/8)*2+row%2)*32+col (((((((row_in_tile % 8) >> 1) << 2) + (row_in_tile >> 3)) << 1) + (row_in_tile & 1)) << 5) @@ -844,27 +844,27 @@ __global__ void add_QK_bias_transform_varlen_row(int8_t* q_buf_, } const int act_seq_len = (qk_id == 0) ? seq_len : seq_len_padded; - const int stride = (qk_id == 0) ? stride_q : stride_k; + const int stride = (qk_id == 0) ? stride_q : stride_k; buf_ptr4[(((batch_id * head_num + head_id) * stride + (new_col << 5) * act_seq_len + new_row) >> 2)] = tmp4; } template -void invokeAddQKBiasTransformRow(int8_t* q_buf, - int8_t* k_buf, +void invokeAddQKBiasTransformRow(int8_t* q_buf, + int8_t* k_buf, const int8_t* Q, - const T* bias_Q, + const T* bias_Q, const int8_t* K, - const T* bias_K, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* q_input_deQFactor_ptr, - const float* k_input_deQFactor_ptr, - const float* q_output_scale_ptr, - const float* k_output_scale_ptr, - bool use_ORDER_COL32_2R_4R4, - cudaStream_t stream) + const T* bias_K, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* q_input_deQFactor_ptr, + const float* k_input_deQFactor_ptr, + const float* q_output_scale_ptr, + const float* k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4, + cudaStream_t stream) { assert(size_per_head % 32 == 0); if (seq_len % 32 == 0) { @@ -922,39 +922,39 @@ void invokeAddQKBiasTransformRow(int8_t* q_buf, } } -template void invokeAddQKBiasTransformRow(int8_t* q_buf, - int8_t* k_buf, +template void invokeAddQKBiasTransformRow(int8_t* q_buf, + int8_t* k_buf, const int8_t* Q, - const float* bias_Q, + const float* bias_Q, const int8_t* K, - const float* bias_K, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* q_input_deQFactor_ptr, - const float* k_input_deQFactor_ptr, - const float* q_output_scale_ptr, - const float* k_output_scale_ptr, - bool use_ORDER_COL32_2R_4R4, - cudaStream_t stream); - -template void invokeAddQKBiasTransformRow(int8_t* q_buf, - int8_t* k_buf, + const float* bias_K, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* q_input_deQFactor_ptr, + const float* k_input_deQFactor_ptr, + const float* q_output_scale_ptr, + const float* k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4, + cudaStream_t stream); + +template void invokeAddQKBiasTransformRow(int8_t* q_buf, + int8_t* k_buf, const int8_t* Q, - const half* bias_Q, + const half* bias_Q, const int8_t* K, - const half* bias_K, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* q_input_deQFactor_ptr, - const float* k_input_deQFactor_ptr, - const float* q_output_scale_ptr, - const float* k_output_scale_ptr, - bool use_ORDER_COL32_2R_4R4, - cudaStream_t stream); + const half* bias_K, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* q_input_deQFactor_ptr, + const float* k_input_deQFactor_ptr, + const float* q_output_scale_ptr, + const float* k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4, + cudaStream_t 
stream); // add_QK_bias_transform & rebuild padding for batch int8 cublasLtMatmul & per axis quantization for weight // 1.add QK bias @@ -970,33 +970,33 @@ template void invokeAddQKBiasTransformRow(int8_t* q_buf, // block.x = head_num * size_per_head / 4; // using char4 template -__global__ void add_QK_bias_transform_rebuild_padding(int8_t* q_buf_, - int8_t* k_buf_, +__global__ void add_QK_bias_transform_rebuild_padding(int8_t* q_buf_, + int8_t* k_buf_, const int32_t* Q, - const T* bias_Q, + const T* bias_Q, const int32_t* K, - const T* bias_K, - const int* sequence_id_offset, - const int valid_word_num, - const int m, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - int stride, - const float* q_weight_amax, - const float* q_input_deQFactor_div127_ptr, - const float* k_weight_amax, - const float* k_input_deQFactor_div127_ptr, - const float* q_output_scale_ptr, - const float* k_output_scale_ptr, - bool use_ORDER_COL32_2R_4R4) + const T* bias_K, + const int* sequence_id_offset, + const int valid_word_num, + const int m, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + int stride, + const float* q_weight_amax, + const float* q_input_deQFactor_div127_ptr, + const float* k_weight_amax, + const float* k_input_deQFactor_div127_ptr, + const float* q_output_scale_ptr, + const float* k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4) { const int32_t* data_ptr; - char4* buf_ptr4; - const T* bias_ptr; - const float* weight_amax; - int qk_id = blockIdx.x / valid_word_num; + char4* buf_ptr4; + const T* bias_ptr; + const float* weight_amax; + int qk_id = blockIdx.x / valid_word_num; data_ptr = qk_id == 0 ? Q : K; buf_ptr4 = qk_id == 0 ? (char4*)q_buf_ : (char4*)k_buf_; @@ -1004,15 +1004,15 @@ __global__ void add_QK_bias_transform_rebuild_padding(int8_t* q_buf_, int threadIdx4 = threadIdx.x << 2; int m_full_idx = blockIdx.x % valid_word_num; - m_full_idx = (valid_word_num != m) ? (m_full_idx + __ldg(sequence_id_offset + m_full_idx)) : m_full_idx; - int batch_id = m_full_idx / seq_len; - int head_id = threadIdx4 / size_per_head; + m_full_idx = (valid_word_num != m) ? (m_full_idx + __ldg(sequence_id_offset + m_full_idx)) : m_full_idx; + int batch_id = m_full_idx / seq_len; + int head_id = threadIdx4 / size_per_head; int id_in_head = threadIdx4 % size_per_head; - int word_id = m_full_idx % seq_len; + int word_id = m_full_idx % seq_len; const float input_deQFactor_div127 = qk_id == 0 ? __ldg(q_input_deQFactor_div127_ptr) : __ldg(k_input_deQFactor_div127_ptr); - weight_amax = qk_id == 0 ? q_weight_amax : k_weight_amax; + weight_amax = qk_id == 0 ? q_weight_amax : k_weight_amax; const float output_scale = qk_id == 0 ? 
__ldg(q_output_scale_ptr) : __ldg(k_output_scale_ptr); int data_id = @@ -1021,26 +1021,26 @@ __global__ void add_QK_bias_transform_rebuild_padding(int8_t* q_buf_, float scale; float tmp; char4 tmp4; - scale = static_cast(__ldg(data_ptr + data_id)) * __ldg(weight_amax + threadIdx4) * input_deQFactor_div127; - tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; + scale = static_cast(__ldg(data_ptr + data_id)) * __ldg(weight_amax + threadIdx4) * input_deQFactor_div127; + tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; tmp4.x = float_to_int8_rn(tmp * output_scale); - data_id = data_id + 1; + data_id = data_id + 1; threadIdx4 = threadIdx4 + 1; - scale = static_cast(__ldg(data_ptr + data_id)) * __ldg(weight_amax + threadIdx4) * input_deQFactor_div127; - tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; + scale = static_cast(__ldg(data_ptr + data_id)) * __ldg(weight_amax + threadIdx4) * input_deQFactor_div127; + tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; tmp4.y = float_to_int8_rn(tmp * output_scale); - data_id = data_id + 1; + data_id = data_id + 1; threadIdx4 = threadIdx4 + 1; - scale = static_cast(__ldg(data_ptr + data_id)) * __ldg(weight_amax + threadIdx4) * input_deQFactor_div127; - tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; + scale = static_cast(__ldg(data_ptr + data_id)) * __ldg(weight_amax + threadIdx4) * input_deQFactor_div127; + tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; tmp4.z = float_to_int8_rn(tmp * output_scale); - data_id = data_id + 1; + data_id = data_id + 1; threadIdx4 = threadIdx4 + 1; - scale = static_cast(__ldg(data_ptr + data_id)) * __ldg(weight_amax + threadIdx4) * input_deQFactor_div127; - tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; + scale = static_cast(__ldg(data_ptr + data_id)) * __ldg(weight_amax + threadIdx4) * input_deQFactor_div127; + tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; tmp4.w = float_to_int8_rn(tmp * output_scale); // row_id, col_id of sub-matrix (m = seq_len, n = size_per_head), column-major @@ -1052,10 +1052,10 @@ __global__ void add_QK_bias_transform_rebuild_padding(int8_t* q_buf_, if (use_ORDER_COL32_2R_4R4) { int row_in_tile = row_id & 31; int col_in_tile = col_id & 31; - new_row = (qk_id != 1) ? - // COL32 + new_row = (qk_id != 1) ? 
+ // COL32 ((row_id << 5) + (col_id & 31)) : - // COL32_2R_4R4 + // COL32_2R_4R4 (((row_id >> 5) << 10) + //(((row%8)/2*4+row/8)*2+row%2)*32+col (((((((row_in_tile % 8) >> 1) << 2) + (row_in_tile >> 3)) << 1) + (row_in_tile & 1)) << 5) @@ -1081,26 +1081,26 @@ __global__ void add_QK_bias_transform_rebuild_padding(int8_t* q_buf_, } template -void invokeAddQKBiasTransformRebuildPadding(int8_t* q_buf, - int8_t* k_buf, +void invokeAddQKBiasTransformRebuildPadding(int8_t* q_buf, + int8_t* k_buf, const int32_t* Q, - const T* bias_Q, + const T* bias_Q, const int32_t* K, - const T* bias_K, - const int* sequence_id_offset, - const int valid_word_num, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* q_weight_amax, - const float* q_input_deQFactor_div127_ptr, - const float* k_weight_amax, - const float* k_input_deQFactor_div127_ptr, - const float* q_output_scale_ptr, - const float* k_output_scale_ptr, - bool use_ORDER_COL32_2R_4R4, - cudaStream_t stream) + const T* bias_K, + const int* sequence_id_offset, + const int valid_word_num, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* q_weight_amax, + const float* q_input_deQFactor_div127_ptr, + const float* k_weight_amax, + const float* k_input_deQFactor_div127_ptr, + const float* q_output_scale_ptr, + const float* k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4, + cudaStream_t stream) { add_QK_bias_transform_rebuild_padding<< -__global__ void add_QK_bias_transform_rebuild_padding(int8_t* q_buf_, - int8_t* k_buf_, +__global__ void add_QK_bias_transform_rebuild_padding(int8_t* q_buf_, + int8_t* k_buf_, const int8_t* Q, - const T* bias_Q, + const T* bias_Q, const int8_t* K, - const T* bias_K, - const int* sequence_id_offset, - const int valid_word_num, - const int m, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - int stride, - const float* q_deQFactor_ptr, - const float* k_deQFactor_ptr, - const float* q_output_scale_ptr, - const float* k_output_scale_ptr, - bool use_ORDER_COL32_2R_4R4) + const T* bias_K, + const int* sequence_id_offset, + const int valid_word_num, + const int m, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + int stride, + const float* q_deQFactor_ptr, + const float* k_deQFactor_ptr, + const float* q_output_scale_ptr, + const float* k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4) { const char4* data_ptr; - char4* buf_ptr4; - const T* bias_ptr; - int qk_id = blockIdx.x / valid_word_num; + char4* buf_ptr4; + const T* bias_ptr; + int qk_id = blockIdx.x / valid_word_num; data_ptr = qk_id == 0 ? (const char4*)Q : (const char4*)K; buf_ptr4 = qk_id == 0 ? (char4*)q_buf_ : (char4*)k_buf_; @@ -1215,13 +1215,13 @@ __global__ void add_QK_bias_transform_rebuild_padding(int8_t* q_buf_, int threadIdx4 = threadIdx.x << 2; int m_full_idx = blockIdx.x % valid_word_num; - m_full_idx = (valid_word_num != m) ? (m_full_idx + __ldg(sequence_id_offset + m_full_idx)) : m_full_idx; - int batch_id = m_full_idx / seq_len; - int head_id = threadIdx4 / size_per_head; + m_full_idx = (valid_word_num != m) ? (m_full_idx + __ldg(sequence_id_offset + m_full_idx)) : m_full_idx; + int batch_id = m_full_idx / seq_len; + int head_id = threadIdx4 / size_per_head; int id_in_head = threadIdx4 % size_per_head; - int word_id = m_full_idx % seq_len; + int word_id = m_full_idx % seq_len; - const float deQFactor = qk_id == 0 ? 
__ldg(q_deQFactor_ptr) : __ldg(k_deQFactor_ptr); + const float deQFactor = qk_id == 0 ? __ldg(q_deQFactor_ptr) : __ldg(k_deQFactor_ptr); const float output_scale = qk_id == 0 ? __ldg(q_output_scale_ptr) : __ldg(k_output_scale_ptr); int data_id = @@ -1233,24 +1233,24 @@ __global__ void add_QK_bias_transform_rebuild_padding(int8_t* q_buf_, tmp4 = __ldg(data_ptr + data_id); - scale = static_cast(tmp4.x) * deQFactor; - tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; + scale = static_cast(tmp4.x) * deQFactor; + tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; tmp4.x = float_to_int8_rn(tmp * output_scale); threadIdx4 = threadIdx4 + 1; - scale = static_cast(tmp4.y) * deQFactor; - tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; - tmp4.y = float_to_int8_rn(tmp * output_scale); + scale = static_cast(tmp4.y) * deQFactor; + tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; + tmp4.y = float_to_int8_rn(tmp * output_scale); threadIdx4 = threadIdx4 + 1; - scale = static_cast(tmp4.z) * deQFactor; - tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; - tmp4.z = float_to_int8_rn(tmp * output_scale); + scale = static_cast(tmp4.z) * deQFactor; + tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; + tmp4.z = float_to_int8_rn(tmp * output_scale); threadIdx4 = threadIdx4 + 1; - scale = static_cast(tmp4.w) * deQFactor; - tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; - tmp4.w = float_to_int8_rn(tmp * output_scale); + scale = static_cast(tmp4.w) * deQFactor; + tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; + tmp4.w = float_to_int8_rn(tmp * output_scale); // row_id, col_id of sub-matrix (m = seq_len, n = size_per_head), column-major int row_id = word_id; @@ -1261,10 +1261,10 @@ __global__ void add_QK_bias_transform_rebuild_padding(int8_t* q_buf_, if (use_ORDER_COL32_2R_4R4) { int row_in_tile = row_id & 31; int col_in_tile = col_id & 31; - new_row = (qk_id != 1) ? - // COL32 + new_row = (qk_id != 1) ? 
+ // COL32 ((row_id << 5) + (col_id & 31)) : - // COL32_2R_4R4 + // COL32_2R_4R4 (((row_id >> 5) << 10) + //(((row%8)/2*4+row/8)*2+row%2)*32+col (((((((row_in_tile % 8) >> 1) << 2) + (row_in_tile >> 3)) << 1) + (row_in_tile & 1)) << 5) @@ -1304,32 +1304,32 @@ __global__ void add_QK_bias_transform_rebuild_padding(int8_t* q_buf_, // block.x = head_num * size_per_head / 4; // using char4 template -__global__ void add_QK_bias_transform_rebuild_padding_varlen(int8_t* q_buf_, - int8_t* k_buf_, +__global__ void add_QK_bias_transform_rebuild_padding_varlen(int8_t* q_buf_, + int8_t* k_buf_, const int8_t* Q, - const T* bias_Q, + const T* bias_Q, const int8_t* K, - const T* bias_K, - const int* sequence_id_offset, - const int valid_word_num, - const int m, - const int batch_size, - const int seq_len, - const int seq_len_padded, - const int head_num, - const int size_per_head, - int stride_q, - int stride_k, - const float* q_deQFactor_ptr, - const float* k_deQFactor_ptr, - const float* q_output_scale_ptr, - const float* k_output_scale_ptr, - bool use_ORDER_COL32_2R_4R4) + const T* bias_K, + const int* sequence_id_offset, + const int valid_word_num, + const int m, + const int batch_size, + const int seq_len, + const int seq_len_padded, + const int head_num, + const int size_per_head, + int stride_q, + int stride_k, + const float* q_deQFactor_ptr, + const float* k_deQFactor_ptr, + const float* q_output_scale_ptr, + const float* k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4) { const char4* data_ptr; - char4* buf_ptr4; - const T* bias_ptr; - int qk_id = blockIdx.x / valid_word_num; + char4* buf_ptr4; + const T* bias_ptr; + int qk_id = blockIdx.x / valid_word_num; data_ptr = qk_id == 0 ? (const char4*)Q : (const char4*)K; buf_ptr4 = qk_id == 0 ? (char4*)q_buf_ : (char4*)k_buf_; @@ -1337,13 +1337,13 @@ __global__ void add_QK_bias_transform_rebuild_padding_varlen(int8_t* q_buf_, int threadIdx4 = threadIdx.x << 2; int m_full_idx = blockIdx.x % valid_word_num; - m_full_idx = (valid_word_num != m) ? (m_full_idx + __ldg(sequence_id_offset + m_full_idx)) : m_full_idx; - int batch_id = m_full_idx / seq_len; - int head_id = threadIdx4 / size_per_head; + m_full_idx = (valid_word_num != m) ? (m_full_idx + __ldg(sequence_id_offset + m_full_idx)) : m_full_idx; + int batch_id = m_full_idx / seq_len; + int head_id = threadIdx4 / size_per_head; int id_in_head = threadIdx4 % size_per_head; - int word_id = m_full_idx % seq_len; + int word_id = m_full_idx % seq_len; - const float deQFactor = qk_id == 0 ? __ldg(q_deQFactor_ptr) : __ldg(k_deQFactor_ptr); + const float deQFactor = qk_id == 0 ? __ldg(q_deQFactor_ptr) : __ldg(k_deQFactor_ptr); const float output_scale = qk_id == 0 ? 
__ldg(q_output_scale_ptr) : __ldg(k_output_scale_ptr); int data_id = @@ -1355,24 +1355,24 @@ __global__ void add_QK_bias_transform_rebuild_padding_varlen(int8_t* q_buf_, tmp4 = __ldg(data_ptr + data_id); - scale = static_cast(tmp4.x) * deQFactor; - tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; + scale = static_cast(tmp4.x) * deQFactor; + tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; tmp4.x = float_to_int8_rn(tmp * output_scale); threadIdx4 = threadIdx4 + 1; - scale = static_cast(tmp4.y) * deQFactor; - tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; - tmp4.y = float_to_int8_rn(tmp * output_scale); + scale = static_cast(tmp4.y) * deQFactor; + tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; + tmp4.y = float_to_int8_rn(tmp * output_scale); threadIdx4 = threadIdx4 + 1; - scale = static_cast(tmp4.z) * deQFactor; - tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; - tmp4.z = float_to_int8_rn(tmp * output_scale); + scale = static_cast(tmp4.z) * deQFactor; + tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; + tmp4.z = float_to_int8_rn(tmp * output_scale); threadIdx4 = threadIdx4 + 1; - scale = static_cast(tmp4.w) * deQFactor; - tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; - tmp4.w = float_to_int8_rn(tmp * output_scale); + scale = static_cast(tmp4.w) * deQFactor; + tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; + tmp4.w = float_to_int8_rn(tmp * output_scale); // row_id, col_id of sub-matrix (m = seq_len or seq_len_padded, n = size_per_head), column-major int row_id = word_id; @@ -1383,10 +1383,10 @@ __global__ void add_QK_bias_transform_rebuild_padding_varlen(int8_t* q_buf_, if (use_ORDER_COL32_2R_4R4) { int row_in_tile = row_id & 31; int col_in_tile = col_id & 31; - new_row = (qk_id != 1) ? - // COL32 + new_row = (qk_id != 1) ? + // COL32 ((row_id << 5) + (col_id & 31)) : - // COL32_2R_4R4 + // COL32_2R_4R4 (((row_id >> 5) << 10) + //(((row%8)/2*4+row/8)*2+row%2)*32+col (((((((row_in_tile % 8) >> 1) << 2) + (row_in_tile >> 3)) << 1) + (row_in_tile & 1)) << 5) @@ -1409,29 +1409,29 @@ __global__ void add_QK_bias_transform_rebuild_padding_varlen(int8_t* q_buf_, } const int stride = (qk_id != 1) ? stride_q : stride_k; - const int len = (qk_id != 1) ? seq_len : seq_len_padded; + const int len = (qk_id != 1) ? 
seq_len : seq_len_padded; buf_ptr4[(((batch_id * head_num + head_id) * stride + (new_col << 5) * len + new_row) >> 2)] = tmp4; } template -void invokeAddQKBiasTransformRebuildPadding(int8_t* q_buf, - int8_t* k_buf, +void invokeAddQKBiasTransformRebuildPadding(int8_t* q_buf, + int8_t* k_buf, const int8_t* Q, - const T* bias_Q, + const T* bias_Q, const int8_t* K, - const T* bias_K, - const int* sequence_id_offset, - const int valid_word_num, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* q_deQFactor_ptr, - const float* k_deQFactor_ptr, - const float* q_output_scale_ptr, - const float* k_output_scale_ptr, - bool use_ORDER_COL32_2R_4R4, - cudaStream_t stream) + const T* bias_K, + const int* sequence_id_offset, + const int valid_word_num, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* q_deQFactor_ptr, + const float* k_deQFactor_ptr, + const float* q_output_scale_ptr, + const float* k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4, + cudaStream_t stream) { int seq_len_padded = (seq_len + 31) / 32 * 32; add_QK_bias_transform_rebuild_padding_varlen<< -__global__ void add_QK_bias_transform_rebuild_padding_varlen_row(int8_t* q_buf_, - int8_t* k_buf_, +__global__ void add_QK_bias_transform_rebuild_padding_varlen_row(int8_t* q_buf_, + int8_t* k_buf_, const int8_t* Q, - const T* bias_Q, + const T* bias_Q, const int8_t* K, - const T* bias_K, - const int* sequence_id_offset, - const int valid_word_num, - const int m, - const int batch_size, - const int seq_len, - const int seq_len_padded, - const int head_num, - const int size_per_head, - int stride_q, - int stride_k, - const float* q_deQFactor_ptr, - const float* k_deQFactor_ptr, - const float* q_output_scale_ptr, - const float* k_output_scale_ptr, - bool use_ORDER_COL32_2R_4R4, - const int head_num_x_size_per_head) + const T* bias_K, + const int* sequence_id_offset, + const int valid_word_num, + const int m, + const int batch_size, + const int seq_len, + const int seq_len_padded, + const int head_num, + const int size_per_head, + int stride_q, + int stride_k, + const float* q_deQFactor_ptr, + const float* k_deQFactor_ptr, + const float* q_output_scale_ptr, + const float* k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4, + const int head_num_x_size_per_head) { const char4* data_ptr; - char4* buf_ptr4; - const T* bias_ptr; - int qk_id = blockIdx.x / valid_word_num; + char4* buf_ptr4; + const T* bias_ptr; + int qk_id = blockIdx.x / valid_word_num; data_ptr = qk_id == 0 ? (const char4*)Q : (const char4*)K; buf_ptr4 = qk_id == 0 ? (char4*)q_buf_ : (char4*)k_buf_; bias_ptr = qk_id == 0 ? bias_Q : bias_K; - int threadIdx4 = threadIdx.x << 2; + int threadIdx4 = threadIdx.x << 2; int batch_seq_id = blockIdx.x % valid_word_num; - int m_full_idx = (valid_word_num != m) ? (batch_seq_id + __ldg(sequence_id_offset + batch_seq_id)) : batch_seq_id; - int batch_id = m_full_idx / seq_len; - int head_id = threadIdx4 / size_per_head; - int id_in_head = threadIdx4 % size_per_head; - int word_id = m_full_idx % seq_len; + int m_full_idx = (valid_word_num != m) ? (batch_seq_id + __ldg(sequence_id_offset + batch_seq_id)) : batch_seq_id; + int batch_id = m_full_idx / seq_len; + int head_id = threadIdx4 / size_per_head; + int id_in_head = threadIdx4 % size_per_head; + int word_id = m_full_idx % seq_len; - const float deQFactor = qk_id == 0 ? __ldg(q_deQFactor_ptr) : __ldg(k_deQFactor_ptr); + const float deQFactor = qk_id == 0 ? 
__ldg(q_deQFactor_ptr) : __ldg(k_deQFactor_ptr); const float output_scale = qk_id == 0 ? __ldg(q_output_scale_ptr) : __ldg(k_output_scale_ptr); int data_id = (batch_seq_id * head_num_x_size_per_head + threadIdx4) >> 2; @@ -1564,24 +1564,24 @@ __global__ void add_QK_bias_transform_rebuild_padding_varlen_row(int8_t* q_buf_, tmp4 = __ldg(data_ptr + data_id); - scale = static_cast(tmp4.x) * deQFactor; - tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; + scale = static_cast(tmp4.x) * deQFactor; + tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; tmp4.x = float_to_int8_rn(tmp * output_scale); threadIdx4 = threadIdx4 + 1; - scale = static_cast(tmp4.y) * deQFactor; - tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; - tmp4.y = float_to_int8_rn(tmp * output_scale); + scale = static_cast(tmp4.y) * deQFactor; + tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; + tmp4.y = float_to_int8_rn(tmp * output_scale); threadIdx4 = threadIdx4 + 1; - scale = static_cast(tmp4.z) * deQFactor; - tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; - tmp4.z = float_to_int8_rn(tmp * output_scale); + scale = static_cast(tmp4.z) * deQFactor; + tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; + tmp4.z = float_to_int8_rn(tmp * output_scale); threadIdx4 = threadIdx4 + 1; - scale = static_cast(tmp4.w) * deQFactor; - tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; - tmp4.w = float_to_int8_rn(tmp * output_scale); + scale = static_cast(tmp4.w) * deQFactor; + tmp = static_cast(__ldg(bias_ptr + threadIdx4)) + scale; + tmp4.w = float_to_int8_rn(tmp * output_scale); // row_id, col_id of sub-matrix (m = seq_len or seq_len_padded, n = size_per_head), column-major int row_id = word_id; @@ -1592,10 +1592,10 @@ __global__ void add_QK_bias_transform_rebuild_padding_varlen_row(int8_t* q_buf_, if (use_ORDER_COL32_2R_4R4) { int row_in_tile = row_id & 31; int col_in_tile = col_id & 31; - new_row = (qk_id != 1) ? - // COL32 + new_row = (qk_id != 1) ? + // COL32 ((row_id << 5) + (col_id & 31)) : - // COL32_2R_4R4 + // COL32_2R_4R4 (((row_id >> 5) << 10) + //(((row%8)/2*4+row/8)*2+row%2)*32+col (((((((row_in_tile % 8) >> 1) << 2) + (row_in_tile >> 3)) << 1) + (row_in_tile & 1)) << 5) @@ -1618,29 +1618,29 @@ __global__ void add_QK_bias_transform_rebuild_padding_varlen_row(int8_t* q_buf_, } const int stride = (qk_id != 1) ? stride_q : stride_k; - const int len = (qk_id != 1) ? seq_len : seq_len_padded; + const int len = (qk_id != 1) ? 
seq_len : seq_len_padded; buf_ptr4[(((batch_id * head_num + head_id) * stride + (new_col << 5) * len + new_row) >> 2)] = tmp4; } template -void invokeAddQKBiasTransformRebuildPaddingRow(int8_t* q_buf, - int8_t* k_buf, +void invokeAddQKBiasTransformRebuildPaddingRow(int8_t* q_buf, + int8_t* k_buf, const int8_t* Q, - const T* bias_Q, + const T* bias_Q, const int8_t* K, - const T* bias_K, - const int* sequence_id_offset, - const int valid_word_num, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* q_deQFactor_ptr, - const float* k_deQFactor_ptr, - const float* q_output_scale_ptr, - const float* k_output_scale_ptr, - bool use_ORDER_COL32_2R_4R4, - cudaStream_t stream) + const T* bias_K, + const int* sequence_id_offset, + const int valid_word_num, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* q_deQFactor_ptr, + const float* k_deQFactor_ptr, + const float* q_output_scale_ptr, + const float* k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4, + cudaStream_t stream) { int seq_len_padded = (seq_len + 31) / 32 * 32; add_QK_bias_transform_rebuild_padding_varlen_row<< -__global__ void add_V_bias_transform(int8_t* v_buf_, +__global__ void add_V_bias_transform(int8_t* v_buf_, const int32_t* V, - const T* V_bias, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - int stride, - const float* weight_amax, - const float* input_deQFactor_div127_ptr, - const float* out_scale_ptr, - bool use_ORDER_COL32_2R_4R4) + const T* V_bias, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + int stride, + const float* weight_amax, + const float* input_deQFactor_div127_ptr, + const float* out_scale_ptr, + bool use_ORDER_COL32_2R_4R4) { - const float input_deQFactor_div127 = __ldg(input_deQFactor_div127_ptr); - const float out_scale = __ldg(out_scale_ptr); + const float input_deQFactor_div127 = __ldg(input_deQFactor_div127_ptr); + const float out_scale = __ldg(out_scale_ptr); __shared__ int8_t shm[32][33]; - const int32_t* data_ptr = V; - char4* buf_ptr4 = (char4*)v_buf_; - const T* bias_ptr = V_bias; + const int32_t* data_ptr = V; + char4* buf_ptr4 = (char4*)v_buf_; + const T* bias_ptr = V_bias; int threadIdx4 = threadIdx.x << 2; // for src of (seq_len, size_per_head) - int batch_id = blockIdx.z / head_num; - int head_id = blockIdx.z % head_num; - int word_id = (blockIdx.y << 5) + threadIdx.y; + int batch_id = blockIdx.z / head_num; + int head_id = blockIdx.z % head_num; + int word_id = (blockIdx.y << 5) + threadIdx.y; int id_in_size = (blockIdx.x << 5) + threadIdx4; // for V layout (batch_size*seq_len, head_num*size_per_head) - int col = head_id * size_per_head + id_in_size; - int row = batch_id * seq_len + word_id; + int col = head_id * size_per_head + id_in_size; + int row = batch_id * seq_len + word_id; int inIdx = (((col >> 5) << 5) * batch_size * seq_len + ((row << 5) + (col & 31))); // for shm row-major int sh_col = threadIdx4; @@ -1760,35 +1760,35 @@ __global__ void add_V_bias_transform(int8_t* v_buf_, // tmp2 = __ldg(&bias_ptr2[col >> 1]); - scale = __ldg(data_ptr + inIdx) * __ldg(weight_amax + col) * input_deQFactor_div127; - tmp = scale + static_cast(__ldg(bias_ptr + col)); //(tmp2.x); + scale = __ldg(data_ptr + inIdx) * __ldg(weight_amax + col) * input_deQFactor_div127; + tmp = scale + static_cast(__ldg(bias_ptr + col)); //(tmp2.x); shm[sh_row][sh_col] = float_to_int8_rn(tmp * out_scale); - scale = 
__ldg(data_ptr + inIdx + 1) * __ldg(weight_amax + col + 1) * input_deQFactor_div127; - tmp = scale + static_cast(__ldg(bias_ptr + col + 1)); //(tmp2.y); + scale = __ldg(data_ptr + inIdx + 1) * __ldg(weight_amax + col + 1) * input_deQFactor_div127; + tmp = scale + static_cast(__ldg(bias_ptr + col + 1)); //(tmp2.y); shm[sh_row][sh_col + 1] = float_to_int8_rn(tmp * out_scale); // tmp2 = __ldg(&bias_ptr2[(col >> 1) + 1]); - scale = __ldg(data_ptr + inIdx + 2) * __ldg(weight_amax + col + 2) * input_deQFactor_div127; - tmp = scale + static_cast(__ldg(bias_ptr + col + 2)); //(tmp2.x); + scale = __ldg(data_ptr + inIdx + 2) * __ldg(weight_amax + col + 2) * input_deQFactor_div127; + tmp = scale + static_cast(__ldg(bias_ptr + col + 2)); //(tmp2.x); shm[sh_row][sh_col + 2] = float_to_int8_rn(tmp * out_scale); - scale = __ldg(data_ptr + inIdx + 3) * __ldg(weight_amax + col + 3) * input_deQFactor_div127; - tmp = scale + static_cast(__ldg(bias_ptr + col + 3)); //(tmp2.y); + scale = __ldg(data_ptr + inIdx + 3) * __ldg(weight_amax + col + 3) * input_deQFactor_div127; + tmp = scale + static_cast(__ldg(bias_ptr + col + 3)); //(tmp2.y); shm[sh_row][sh_col + 3] = float_to_int8_rn(tmp * out_scale); __syncthreads(); // for dst of (size_per_head, seq_len) - word_id = (blockIdx.y << 5) + threadIdx4; + word_id = (blockIdx.y << 5) + threadIdx4; id_in_size = (blockIdx.x << 5) + threadIdx.y; - col = (word_id >> 5); + col = (word_id >> 5); if (use_ORDER_COL32_2R_4R4) { int row_in_tile = id_in_size & 31; int col_in_tile = word_id & 31; - row = ( + row = ( // COL32_2R_4R4 ((id_in_size >> 5) << 10) + //(((row%8)/2*4+row/8)*2+row%2)*32+col @@ -1809,89 +1809,89 @@ __global__ void add_V_bias_transform(int8_t* v_buf_, } char4 dataTmp; - dataTmp.x = shm[sh_col][sh_row]; - dataTmp.y = shm[sh_col + 1][sh_row]; - dataTmp.z = shm[sh_col + 2][sh_row]; - dataTmp.w = shm[sh_col + 3][sh_row]; + dataTmp.x = shm[sh_col][sh_row]; + dataTmp.y = shm[sh_col + 1][sh_row]; + dataTmp.z = shm[sh_col + 2][sh_row]; + dataTmp.w = shm[sh_col + 3][sh_row]; buf_ptr4[(blockIdx.z * stride + (col << 5) * size_per_head + row) >> 2] = dataTmp; } template<> -__global__ void add_V_bias_transform(int8_t* v_buf_, +__global__ void add_V_bias_transform(int8_t* v_buf_, const int32_t* V, - const half* V_bias, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - int stride, - const float* weight_amax, - const float* input_deQFactor_div127_ptr, - const float* out_scale_ptr, - bool use_ORDER_COL32_2R_4R4) + const half* V_bias, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + int stride, + const float* weight_amax, + const float* input_deQFactor_div127_ptr, + const float* out_scale_ptr, + bool use_ORDER_COL32_2R_4R4) { - const float input_deQFactor_div127 = __ldg(input_deQFactor_div127_ptr); - const float out_scale = __ldg(out_scale_ptr); + const float input_deQFactor_div127 = __ldg(input_deQFactor_div127_ptr); + const float out_scale = __ldg(out_scale_ptr); __shared__ int8_t shm[32][33]; - const int32_t* data_ptr = V; - char4* buf_ptr4 = (char4*)v_buf_; + const int32_t* data_ptr = V; + char4* buf_ptr4 = (char4*)v_buf_; int threadIdx4 = threadIdx.x << 2; // for src of (seq_len, size_per_head) int batch_id = blockIdx.z / head_num; - int head_id = blockIdx.z % head_num; + int head_id = blockIdx.z % head_num; int blockIdy32 = (blockIdx.y << 5); int blockIdx32 = (blockIdx.x << 5); - int word_id = blockIdy32 + threadIdx.y; + int word_id = blockIdy32 + threadIdx.y; int id_in_size = 
blockIdx32 + threadIdx4; // for V layout (batch_size*seq_len, head_num*size_per_head) - int col = head_id * size_per_head + id_in_size; - int row = batch_id * seq_len + word_id; + int col = head_id * size_per_head + id_in_size; + int row = batch_id * seq_len + word_id; int inIdx = ((col & 0xffffffe0) * batch_size * seq_len + ((row << 5) + (col & 31))); // for shm row-major int sh_col = threadIdx4; int sh_row = threadIdx.y; - int col_2 = col >> 1; + int col_2 = col >> 1; float scale; const half2* bias_ptr2 = (const half2*)V_bias; - half2 tmp2; + half2 tmp2; tmp2 = __ldg(bias_ptr2 + col_2); - scale = __ldg(data_ptr + inIdx) * __ldg(weight_amax + col) * input_deQFactor_div127; - scale = scale + static_cast(tmp2.x); + scale = __ldg(data_ptr + inIdx) * __ldg(weight_amax + col) * input_deQFactor_div127; + scale = scale + static_cast(tmp2.x); shm[sh_row][sh_col] = float_to_int8_rn(scale * out_scale); - scale = __ldg(data_ptr + inIdx + 1) * __ldg(weight_amax + col + 1) * input_deQFactor_div127; - scale = scale + static_cast(tmp2.y); + scale = __ldg(data_ptr + inIdx + 1) * __ldg(weight_amax + col + 1) * input_deQFactor_div127; + scale = scale + static_cast(tmp2.y); shm[sh_row][sh_col + 1] = float_to_int8_rn(scale * out_scale); tmp2 = __ldg(bias_ptr2 + col_2 + 1); - scale = __ldg(data_ptr + inIdx + 2) * __ldg(weight_amax + col + 2) * input_deQFactor_div127; - scale = scale + static_cast(tmp2.x); + scale = __ldg(data_ptr + inIdx + 2) * __ldg(weight_amax + col + 2) * input_deQFactor_div127; + scale = scale + static_cast(tmp2.x); shm[sh_row][sh_col + 2] = float_to_int8_rn(scale * out_scale); - scale = __ldg(data_ptr + inIdx + 3) * __ldg(weight_amax + col + 3) * input_deQFactor_div127; - scale = scale + static_cast(tmp2.y); + scale = __ldg(data_ptr + inIdx + 3) * __ldg(weight_amax + col + 3) * input_deQFactor_div127; + scale = scale + static_cast(tmp2.y); shm[sh_row][sh_col + 3] = float_to_int8_rn(scale * out_scale); __syncthreads(); // for dst of (size_per_head, seq_len) - word_id = blockIdy32 + threadIdx4; + word_id = blockIdy32 + threadIdx4; id_in_size = blockIdx32 + threadIdx.y; - col = (word_id >> 5); + col = (word_id >> 5); if (use_ORDER_COL32_2R_4R4) { int row_in_tile = id_in_size & 31; int col_in_tile = word_id & 31; - row = ( + row = ( // COL32_2R_4R4 ((id_in_size >> 5) << 10) + //(((row%8)/2*4+row/8)*2+row%2)*32+col @@ -1912,47 +1912,47 @@ __global__ void add_V_bias_transform(int8_t* v_buf_, } char4 dataTmp; - dataTmp.x = shm[sh_col][sh_row]; - dataTmp.y = shm[sh_col + 1][sh_row]; - dataTmp.z = shm[sh_col + 2][sh_row]; - dataTmp.w = shm[sh_col + 3][sh_row]; + dataTmp.x = shm[sh_col][sh_row]; + dataTmp.y = shm[sh_col + 1][sh_row]; + dataTmp.z = shm[sh_col + 2][sh_row]; + dataTmp.w = shm[sh_col + 3][sh_row]; buf_ptr4[(blockIdx.z * stride + (col << 5) * size_per_head + row) >> 2] = dataTmp; } template -__global__ void add_V_bias_transform_varlen(int8_t* v_buf_, +__global__ void add_V_bias_transform_varlen(int8_t* v_buf_, const int32_t* V, - const T* V_bias, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - int stride, - const float* weight_amax, - const float* input_deQFactor_div127_ptr, - const float* out_scale_ptr, - bool use_ORDER_COL32_2R_4R4) + const T* V_bias, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + int stride, + const float* weight_amax, + const float* input_deQFactor_div127_ptr, + const float* out_scale_ptr, + bool use_ORDER_COL32_2R_4R4) { - const float input_deQFactor_div127 = 
__ldg(input_deQFactor_div127_ptr); - const float out_scale = __ldg(out_scale_ptr); + const float input_deQFactor_div127 = __ldg(input_deQFactor_div127_ptr); + const float out_scale = __ldg(out_scale_ptr); __shared__ int8_t shm[32][33]; - const int32_t* data_ptr = V; - char4* buf_ptr4 = (char4*)v_buf_; + const int32_t* data_ptr = V; + char4* buf_ptr4 = (char4*)v_buf_; int threadIdx4 = threadIdx.x << 2; // for src of (seq_len, size_per_head) int batch_id = blockIdx.z / head_num; - int head_id = blockIdx.z % head_num; + int head_id = blockIdx.z % head_num; int blockIdy32 = (blockIdx.y << 5); int blockIdx32 = (blockIdx.x << 5); - int word_id = blockIdy32 + threadIdx.y; + int word_id = blockIdy32 + threadIdx.y; int id_in_size = blockIdx32 + threadIdx4; // for V layout (batch_size*seq_len, head_num*size_per_head) - int col = head_id * size_per_head + id_in_size; - int row = batch_id * seq_len + word_id; + int col = head_id * size_per_head + id_in_size; + int row = batch_id * seq_len + word_id; int inIdx = ((col & 0xffffffe0) * batch_size * seq_len + ((row << 5) + (col & 31))); // for shm row-major int sh_col = threadIdx4; @@ -1963,20 +1963,20 @@ __global__ void add_V_bias_transform_varlen(int8_t* v_buf_, if (word_id < seq_len) { const T* bias_ptr = V_bias; - scale = __ldg(data_ptr + inIdx) * __ldg(weight_amax + col) * input_deQFactor_div127; - scale = scale + static_cast(__ldg(bias_ptr + col)); + scale = __ldg(data_ptr + inIdx) * __ldg(weight_amax + col) * input_deQFactor_div127; + scale = scale + static_cast(__ldg(bias_ptr + col)); shm[sh_row][sh_col] = float_to_int8_rn(scale * out_scale); - scale = __ldg(data_ptr + inIdx + 1) * __ldg(weight_amax + col + 1) * input_deQFactor_div127; - scale = scale + static_cast(__ldg(bias_ptr + col + 1)); + scale = __ldg(data_ptr + inIdx + 1) * __ldg(weight_amax + col + 1) * input_deQFactor_div127; + scale = scale + static_cast(__ldg(bias_ptr + col + 1)); shm[sh_row][sh_col + 1] = float_to_int8_rn(scale * out_scale); - scale = __ldg(data_ptr + inIdx + 2) * __ldg(weight_amax + col + 2) * input_deQFactor_div127; - scale = scale + static_cast(__ldg(bias_ptr + col + 2)); + scale = __ldg(data_ptr + inIdx + 2) * __ldg(weight_amax + col + 2) * input_deQFactor_div127; + scale = scale + static_cast(__ldg(bias_ptr + col + 2)); shm[sh_row][sh_col + 2] = float_to_int8_rn(scale * out_scale); - scale = __ldg(data_ptr + inIdx + 3) * __ldg(weight_amax + col + 3) * input_deQFactor_div127; - scale = scale + static_cast(__ldg(bias_ptr + col + 3)); + scale = __ldg(data_ptr + inIdx + 3) * __ldg(weight_amax + col + 3) * input_deQFactor_div127; + scale = scale + static_cast(__ldg(bias_ptr + col + 3)); shm[sh_row][sh_col + 3] = float_to_int8_rn(scale * out_scale); } else { @@ -1986,14 +1986,14 @@ __global__ void add_V_bias_transform_varlen(int8_t* v_buf_, __syncthreads(); // for dst of (size_per_head, seq_len) - word_id = blockIdy32 + threadIdx4; + word_id = blockIdy32 + threadIdx4; id_in_size = blockIdx32 + threadIdx.y; - col = (word_id >> 5); + col = (word_id >> 5); if (use_ORDER_COL32_2R_4R4) { int row_in_tile = id_in_size & 31; int col_in_tile = word_id & 31; - row = ( + row = ( // COL32_2R_4R4 ((id_in_size >> 5) << 10) + //(((row%8)/2*4+row/8)*2+row%2)*32+col @@ -2014,26 +2014,26 @@ __global__ void add_V_bias_transform_varlen(int8_t* v_buf_, } char4 dataTmp; - dataTmp.x = shm[sh_col][sh_row]; - dataTmp.y = shm[sh_col + 1][sh_row]; - dataTmp.z = shm[sh_col + 2][sh_row]; - dataTmp.w = shm[sh_col + 3][sh_row]; + dataTmp.x = shm[sh_col][sh_row]; + dataTmp.y = shm[sh_col + 
1][sh_row]; + dataTmp.z = shm[sh_col + 2][sh_row]; + dataTmp.w = shm[sh_col + 3][sh_row]; buf_ptr4[(blockIdx.z * stride + (col << 5) * size_per_head + row) >> 2] = dataTmp; } template -void invokeAddVBiasTransform(int8_t* v_buf, +void invokeAddVBiasTransform(int8_t* v_buf, const int32_t* V, - const T* V_bias, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* weight_amax, - const float* input_deQFactor_div127_ptr, - const float* out_scale_ptr, - bool use_ORDER_COL32_2R_4R4, - cudaStream_t stream) + const T* V_bias, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* weight_amax, + const float* input_deQFactor_div127_ptr, + const float* out_scale_ptr, + bool use_ORDER_COL32_2R_4R4, + cudaStream_t stream) { if (seq_len % 32 == 0) { add_V_bias_transform<<>>( @@ -2070,31 +2070,31 @@ void invokeAddVBiasTransform(int8_t* v_buf, } } -template void invokeAddVBiasTransform(int8_t* v_buf, +template void invokeAddVBiasTransform(int8_t* v_buf, const int32_t* V, - const float* V_bias, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* weight_amax, - const float* input_deQFactor_div127_ptr, - const float* out_scale_ptr, - bool use_ORDER_COL32_2R_4R4, - cudaStream_t stream); - -template void invokeAddVBiasTransform(int8_t* v_buf, + const float* V_bias, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* weight_amax, + const float* input_deQFactor_div127_ptr, + const float* out_scale_ptr, + bool use_ORDER_COL32_2R_4R4, + cudaStream_t stream); + +template void invokeAddVBiasTransform(int8_t* v_buf, const int32_t* V, - const half* V_bias, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* weight_amax, - const float* input_deQFactor_div127_ptr, - const float* out_scale_ptr, - bool use_ORDER_COL32_2R_4R4, - cudaStream_t stream); + const half* V_bias, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* weight_amax, + const float* input_deQFactor_div127_ptr, + const float* out_scale_ptr, + bool use_ORDER_COL32_2R_4R4, + cudaStream_t stream); // input matrix a matrix of m = batch_size*seq_len , n = head_num*size_per_head, CUBLASLT_ORDER_COL32 // seq_len_padded = (seq_len+31)/32*32 @@ -2105,43 +2105,43 @@ template void invokeAddVBiasTransform(int8_t* v_buf, // using char4 // per tensor quantization for weight template -__global__ void add_V_bias_transform_varlen(int8_t* v_buf_, +__global__ void add_V_bias_transform_varlen(int8_t* v_buf_, const int8_t* V, - const T* V_bias, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const int seq_len_padded, - int stride, - const float* input_deQFactor_ptr, - const float* out_scale_ptr, - bool use_ORDER_COL32_2R_4R4) + const T* V_bias, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const int seq_len_padded, + int stride, + const float* input_deQFactor_ptr, + const float* out_scale_ptr, + bool use_ORDER_COL32_2R_4R4) { - const float input_deQFactor = __ldg(input_deQFactor_ptr); - const float out_scale = __ldg(out_scale_ptr); + const float input_deQFactor = __ldg(input_deQFactor_ptr); + const float out_scale = __ldg(out_scale_ptr); __shared__ int8_t shm[32][33]; - const char4* data_ptr = (const char4*)V; - char4* buf_ptr4 = (char4*)v_buf_; - 
const T* bias_ptr = V_bias; + const char4* data_ptr = (const char4*)V; + char4* buf_ptr4 = (char4*)v_buf_; + const T* bias_ptr = V_bias; int threadIdx4 = threadIdx.x << 2; // for src of (seq_len, size_per_head) - int batch_id = blockIdx.z / head_num; - int head_id = blockIdx.z % head_num; - int word_id = (blockIdx.y << 5) + threadIdx.y; + int batch_id = blockIdx.z / head_num; + int head_id = blockIdx.z % head_num; + int word_id = (blockIdx.y << 5) + threadIdx.y; int id_in_size = (blockIdx.x << 5) + threadIdx4; int col, row; // for shm row-major - int sh_col = threadIdx4; - int sh_row = threadIdx.y; + int sh_col = threadIdx4; + int sh_row = threadIdx.y; char4 dataTmp; if (word_id < seq_len) { // for V layout (batch_size*seq_len, head_num*size_per_head) - col = head_id * size_per_head + id_in_size; - row = batch_id * seq_len + word_id; + col = head_id * size_per_head + id_in_size; + row = batch_id * seq_len + word_id; int inIdx = (((col >> 5) << 5) * batch_size * seq_len + ((row << 5) + (col & 31))) >> 2; float tmp; @@ -2149,20 +2149,20 @@ __global__ void add_V_bias_transform_varlen(int8_t* v_buf_, dataTmp = __ldg(data_ptr + inIdx); - scale = dataTmp.x * input_deQFactor; - tmp = scale + static_cast(__ldg(bias_ptr + col)); //(tmp2.x); + scale = dataTmp.x * input_deQFactor; + tmp = scale + static_cast(__ldg(bias_ptr + col)); //(tmp2.x); shm[sh_row][sh_col] = float_to_int8_rn(tmp * out_scale); - scale = dataTmp.y * input_deQFactor; - tmp = scale + static_cast(__ldg(bias_ptr + col + 1)); //(tmp2.y); + scale = dataTmp.y * input_deQFactor; + tmp = scale + static_cast(__ldg(bias_ptr + col + 1)); //(tmp2.y); shm[sh_row][sh_col + 1] = float_to_int8_rn(tmp * out_scale); - scale = dataTmp.z * input_deQFactor; - tmp = scale + static_cast(__ldg(bias_ptr + col + 2)); //(tmp2.x); + scale = dataTmp.z * input_deQFactor; + tmp = scale + static_cast(__ldg(bias_ptr + col + 2)); //(tmp2.x); shm[sh_row][sh_col + 2] = float_to_int8_rn(tmp * out_scale); - scale = dataTmp.w * input_deQFactor; - tmp = scale + static_cast(__ldg(bias_ptr + col + 3)); //(tmp2.y); + scale = dataTmp.w * input_deQFactor; + tmp = scale + static_cast(__ldg(bias_ptr + col + 3)); //(tmp2.y); shm[sh_row][sh_col + 3] = float_to_int8_rn(tmp * out_scale); } else { @@ -2172,14 +2172,14 @@ __global__ void add_V_bias_transform_varlen(int8_t* v_buf_, __syncthreads(); // for dst of (size_per_head, seq_len_padded) - word_id = (blockIdx.y << 5) + threadIdx4; + word_id = (blockIdx.y << 5) + threadIdx4; id_in_size = (blockIdx.x << 5) + threadIdx.y; - col = (word_id >> 5); + col = (word_id >> 5); if (use_ORDER_COL32_2R_4R4) { int row_in_tile = id_in_size & 31; int col_in_tile = word_id & 31; - row = ( + row = ( // COL32_2R_4R4 ((id_in_size >> 5) << 10) + //(((row%8)/2*4+row/8)*2+row%2)*32+col @@ -2199,25 +2199,25 @@ __global__ void add_V_bias_transform_varlen(int8_t* v_buf_, (word_id & 3)); } - dataTmp.x = shm[sh_col][sh_row]; - dataTmp.y = shm[sh_col + 1][sh_row]; - dataTmp.z = shm[sh_col + 2][sh_row]; - dataTmp.w = shm[sh_col + 3][sh_row]; + dataTmp.x = shm[sh_col][sh_row]; + dataTmp.y = shm[sh_col + 1][sh_row]; + dataTmp.z = shm[sh_col + 2][sh_row]; + dataTmp.w = shm[sh_col + 3][sh_row]; buf_ptr4[(blockIdx.z * stride + (col << 5) * size_per_head + row) >> 2] = dataTmp; } template -void invokeAddVBiasTransform(int8_t* v_buf, +void invokeAddVBiasTransform(int8_t* v_buf, const int8_t* V, - const T* V_bias, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* input_deQFactor_ptr, - const float* 
out_scale_ptr, - bool use_ORDER_COL32_2R_4R4, - cudaStream_t stream) + const T* V_bias, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* input_deQFactor_ptr, + const float* out_scale_ptr, + bool use_ORDER_COL32_2R_4R4, + cudaStream_t stream) { assert(size_per_head % 32 == 0); if (seq_len % 32 == 0) { @@ -2257,29 +2257,29 @@ void invokeAddVBiasTransform(int8_t* v_buf, } } -template void invokeAddVBiasTransform(int8_t* v_buf, +template void invokeAddVBiasTransform(int8_t* v_buf, const int8_t* V, - const float* V_bias, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* input_deQFactor_ptr, - const float* out_scale_ptr, - bool use_ORDER_COL32_2R_4R4, - cudaStream_t stream); - -template void invokeAddVBiasTransform(int8_t* v_buf, + const float* V_bias, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* input_deQFactor_ptr, + const float* out_scale_ptr, + bool use_ORDER_COL32_2R_4R4, + cudaStream_t stream); + +template void invokeAddVBiasTransform(int8_t* v_buf, const int8_t* V, - const half* V_bias, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* input_deQFactor_ptr, - const float* out_scale_ptr, - bool use_ORDER_COL32_2R_4R4, - cudaStream_t stream); + const half* V_bias, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* input_deQFactor_ptr, + const float* out_scale_ptr, + bool use_ORDER_COL32_2R_4R4, + cudaStream_t stream); // input matrix a matrix of m = batch_size*seq_len , n = head_num*size_per_head, row major // seq_len_padded = (seq_len+31)/32*32 // output matrixes are a series of sub-matrixes with size of m = size_per_head, n = seq_len_padded , @@ -2289,44 +2289,44 @@ template void invokeAddVBiasTransform(int8_t* v_buf, // using char4 // per tensor quantization for weight template -__global__ void add_V_bias_transform_varlen_row(int8_t* v_buf_, +__global__ void add_V_bias_transform_varlen_row(int8_t* v_buf_, const int8_t* V, - const T* V_bias, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const int seq_len_padded, - int stride, - const float* input_deQFactor_ptr, - const float* out_scale_ptr, - bool use_ORDER_COL32_2R_4R4, - const int head_num_x_size_per_head) + const T* V_bias, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const int seq_len_padded, + int stride, + const float* input_deQFactor_ptr, + const float* out_scale_ptr, + bool use_ORDER_COL32_2R_4R4, + const int head_num_x_size_per_head) { - const float input_deQFactor = __ldg(input_deQFactor_ptr); - const float out_scale = __ldg(out_scale_ptr); + const float input_deQFactor = __ldg(input_deQFactor_ptr); + const float out_scale = __ldg(out_scale_ptr); __shared__ int8_t shm[32][33]; - const char4* data_ptr = (const char4*)V; - char4* buf_ptr4 = (char4*)v_buf_; - const T* bias_ptr = V_bias; + const char4* data_ptr = (const char4*)V; + char4* buf_ptr4 = (char4*)v_buf_; + const T* bias_ptr = V_bias; int threadIdx4 = threadIdx.x << 2; // for src of (seq_len, size_per_head) - int batch_id = blockIdx.z / head_num; - int head_id = blockIdx.z % head_num; - int word_id = (blockIdx.y << 5) + threadIdx.y; + int batch_id = blockIdx.z / head_num; + int head_id = blockIdx.z % head_num; + int word_id = (blockIdx.y << 5) + threadIdx.y; int id_in_size = 
(blockIdx.x << 5) + threadIdx4; int col, row; // for shm row-major - int sh_col = threadIdx4; - int sh_row = threadIdx.y; + int sh_col = threadIdx4; + int sh_row = threadIdx.y; char4 dataTmp; if (word_id < seq_len) { // for V layout (batch_size*seq_len, head_num*size_per_head) - col = head_id * size_per_head + id_in_size; - row = batch_id * seq_len + word_id; + col = head_id * size_per_head + id_in_size; + row = batch_id * seq_len + word_id; int inIdx = (row * head_num_x_size_per_head + col) >> 2; float tmp; @@ -2334,20 +2334,20 @@ __global__ void add_V_bias_transform_varlen_row(int8_t* v_buf_, dataTmp = __ldg(data_ptr + inIdx); - scale = dataTmp.x * input_deQFactor; - tmp = scale + static_cast(__ldg(bias_ptr + col)); //(tmp2.x); + scale = dataTmp.x * input_deQFactor; + tmp = scale + static_cast(__ldg(bias_ptr + col)); //(tmp2.x); shm[sh_row][sh_col] = float_to_int8_rn(tmp * out_scale); - scale = dataTmp.y * input_deQFactor; - tmp = scale + static_cast(__ldg(bias_ptr + col + 1)); //(tmp2.y); + scale = dataTmp.y * input_deQFactor; + tmp = scale + static_cast(__ldg(bias_ptr + col + 1)); //(tmp2.y); shm[sh_row][sh_col + 1] = float_to_int8_rn(tmp * out_scale); - scale = dataTmp.z * input_deQFactor; - tmp = scale + static_cast(__ldg(bias_ptr + col + 2)); //(tmp2.x); + scale = dataTmp.z * input_deQFactor; + tmp = scale + static_cast(__ldg(bias_ptr + col + 2)); //(tmp2.x); shm[sh_row][sh_col + 2] = float_to_int8_rn(tmp * out_scale); - scale = dataTmp.w * input_deQFactor; - tmp = scale + static_cast(__ldg(bias_ptr + col + 3)); //(tmp2.y); + scale = dataTmp.w * input_deQFactor; + tmp = scale + static_cast(__ldg(bias_ptr + col + 3)); //(tmp2.y); shm[sh_row][sh_col + 3] = float_to_int8_rn(tmp * out_scale); } else { @@ -2357,14 +2357,14 @@ __global__ void add_V_bias_transform_varlen_row(int8_t* v_buf_, __syncthreads(); // for dst of (size_per_head, seq_len_padded) - word_id = (blockIdx.y << 5) + threadIdx4; + word_id = (blockIdx.y << 5) + threadIdx4; id_in_size = (blockIdx.x << 5) + threadIdx.y; - col = (word_id >> 5); + col = (word_id >> 5); if (use_ORDER_COL32_2R_4R4) { int row_in_tile = id_in_size & 31; int col_in_tile = word_id & 31; - row = ( + row = ( // COL32_2R_4R4 ((id_in_size >> 5) << 10) + (((((((row_in_tile % 8) / 2) * 4) + (row_in_tile / 8)) * 2) + (row_in_tile % 2)) * 32) + col_in_tile @@ -2385,25 +2385,25 @@ __global__ void add_V_bias_transform_varlen_row(int8_t* v_buf_, (word_id & 3)); } - dataTmp.x = shm[sh_col][sh_row]; - dataTmp.y = shm[sh_col + 1][sh_row]; - dataTmp.z = shm[sh_col + 2][sh_row]; - dataTmp.w = shm[sh_col + 3][sh_row]; + dataTmp.x = shm[sh_col][sh_row]; + dataTmp.y = shm[sh_col + 1][sh_row]; + dataTmp.z = shm[sh_col + 2][sh_row]; + dataTmp.w = shm[sh_col + 3][sh_row]; buf_ptr4[(blockIdx.z * stride + (col << 5) * size_per_head + row) >> 2] = dataTmp; } template -void invokeAddVBiasTransformRow(int8_t* v_buf, +void invokeAddVBiasTransformRow(int8_t* v_buf, const int8_t* V, - const T* V_bias, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* input_deQFactor_ptr, - const float* out_scale_ptr, - bool use_ORDER_COL32_2R_4R4, - cudaStream_t stream) + const T* V_bias, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* input_deQFactor_ptr, + const float* out_scale_ptr, + bool use_ORDER_COL32_2R_4R4, + cudaStream_t stream) { assert(size_per_head % 32 == 0); if (seq_len % 32 == 0) { @@ -2445,29 +2445,29 @@ void invokeAddVBiasTransformRow(int8_t* v_buf, } } 
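// Reference sketch, illustrative only: the add_V_bias_transform* kernels in this patch
// all share the same per-element arithmetic before the 32x33 shared-memory transpose --
// dequantize the GEMM output, add the bias in float, then requantize with the output
// scale -- and, when use_ORDER_COL32_2R_4R4 is set, the same destination-row permutation.
// The host-side helpers below reproduce that math under stated assumptions; the names
// (quantize_int8_rn, add_v_bias_requant, col32_2r_4r4_row) are invented for this sketch
// and are not FasterTransformer APIs.
#include <algorithm>
#include <cmath>
#include <cstdint>

// Round-to-nearest with int8 saturation, approximating the device-side float_to_int8_rn.
static int8_t quantize_int8_rn(float x)
{
    float r = std::nearbyintf(x);               // round to nearest (ties to even)
    r = std::min(127.0f, std::max(-128.0f, r)); // saturate to the int8 range
    return static_cast<int8_t>(r);
}

// Per-tensor variant (int8 GEMM output):
//   v_buf = rn((v * input_deQFactor + bias) * out_scale)
static int8_t add_v_bias_requant(int8_t v, float bias, float input_deQFactor, float out_scale)
{
    const float dequant = static_cast<float>(v) * input_deQFactor;
    return quantize_int8_rn((dequant + bias) * out_scale);
}

// Per-axis variant (int32 GEMM output): the accumulator is additionally scaled by the
// column's weight amax, matching the weight_amax / input_deQFactor_div127 kernels above.
static int8_t add_v_bias_requant_per_axis(
    int32_t acc, float bias, float weight_amax_col, float input_deQFactor_div127, float out_scale)
{
    const float dequant = static_cast<float>(acc) * weight_amax_col * input_deQFactor_div127;
    return quantize_int8_rn((dequant + bias) * out_scale);
}

// Destination row in the COL32_2R_4R4 layout, mirroring the kernels' bit arithmetic:
// rows of a 32-wide tile are permuted as (((r % 8) / 2 * 4 + r / 8) * 2 + r % 2), and
// each 32-row group of id_in_size contributes a 32*32 = 1024 offset ((id_in_size>>5)<<10).
static int col32_2r_4r4_row(int id_in_size, int word_id)
{
    const int row_in_tile = id_in_size & 31;
    const int col_in_tile = word_id & 31;
    return ((id_in_size >> 5) << 10)
           + (((row_in_tile % 8) / 2 * 4 + row_in_tile / 8) * 2 + row_in_tile % 2) * 32
           + col_in_tile;
}

// Example: an int8 value of 64 with input_deQFactor 0.05f, bias 0.1f and out_scale 4.0f
// would be written back as quantize_int8_rn((64 * 0.05f + 0.1f) * 4.0f) == 13.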
-template void invokeAddVBiasTransformRow(int8_t* v_buf, +template void invokeAddVBiasTransformRow(int8_t* v_buf, const int8_t* V, - const float* V_bias, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* input_deQFactor_ptr, - const float* out_scale_ptr, - bool use_ORDER_COL32_2R_4R4, - cudaStream_t stream); - -template void invokeAddVBiasTransformRow(int8_t* v_buf, + const float* V_bias, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* input_deQFactor_ptr, + const float* out_scale_ptr, + bool use_ORDER_COL32_2R_4R4, + cudaStream_t stream); + +template void invokeAddVBiasTransformRow(int8_t* v_buf, const int8_t* V, - const half* V_bias, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* input_deQFactor_ptr, - const float* out_scale_ptr, - bool use_ORDER_COL32_2R_4R4, - cudaStream_t stream); + const half* V_bias, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* input_deQFactor_ptr, + const float* out_scale_ptr, + bool use_ORDER_COL32_2R_4R4, + cudaStream_t stream); // add bias into V & rebuild padding // input matrix a matrix of m = valid_word_num, n = head_num*size_per_head, CUBLASLT_ORDER_COL32 @@ -2478,32 +2478,32 @@ template void invokeAddVBiasTransformRow(int8_t* v_buf, // using char4 // per axis quantization for weight template -__global__ void add_V_bias_transform_rebuild_padding(int8_t* v_buf_, +__global__ void add_V_bias_transform_rebuild_padding(int8_t* v_buf_, const int32_t* V, - const T* V_bias, - const int* sequence_id_map, - const int valid_word_num, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - int stride, - const float* weight_amax, - const float* input_deQFactor_div127_ptr, - const float* out_scale_ptr, - bool use_ORDER_COL32_2R_4R4) + const T* V_bias, + const int* sequence_id_map, + const int valid_word_num, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + int stride, + const float* weight_amax, + const float* input_deQFactor_div127_ptr, + const float* out_scale_ptr, + bool use_ORDER_COL32_2R_4R4) { __shared__ int8_t shm[32][33]; - const int32_t* data_ptr = V; - char4* buf_ptr4 = (char4*)v_buf_; - const T* bias_ptr = V_bias; + const int32_t* data_ptr = V; + char4* buf_ptr4 = (char4*)v_buf_; + const T* bias_ptr = V_bias; int threadIdx4 = threadIdx.x << 2; // for src of (seq_len, size_per_head) - int batch_id = blockIdx.z / head_num; - int head_id = blockIdx.z % head_num; - int word_id = (blockIdx.y << 5) + threadIdx.y; + int batch_id = blockIdx.z / head_num; + int head_id = blockIdx.z % head_num; + int word_id = (blockIdx.y << 5) + threadIdx.y; int id_in_size = (blockIdx.x << 5) + threadIdx4; // for shm row-major @@ -2515,29 +2515,29 @@ __global__ void add_V_bias_transform_rebuild_padding(int8_t* v_buf_, int row = __ldg(sequence_id_map + batch_id * seq_len + word_id); if (row != -1) { - col = head_id * size_per_head + id_in_size; + col = head_id * size_per_head + id_in_size; int inIdx = ((col & 0xffffffe0) * valid_word_num + ((row << 5) + (col & 31))); float tmp; float scale; const float input_deQFactor_div127 = __ldg(input_deQFactor_div127_ptr); - const float out_scale = __ldg(out_scale_ptr); + const float out_scale = __ldg(out_scale_ptr); - scale = __ldg(data_ptr + inIdx) * __ldg(weight_amax + col) * input_deQFactor_div127; - tmp = scale + 
static_cast(__ldg(bias_ptr + col)); + scale = __ldg(data_ptr + inIdx) * __ldg(weight_amax + col) * input_deQFactor_div127; + tmp = scale + static_cast(__ldg(bias_ptr + col)); shm[sh_row][sh_col] = float_to_int8_rn(tmp * out_scale); - scale = __ldg(data_ptr + inIdx + 1) * __ldg(weight_amax + col + 1) * input_deQFactor_div127; - tmp = scale + static_cast(__ldg(bias_ptr + col + 1)); + scale = __ldg(data_ptr + inIdx + 1) * __ldg(weight_amax + col + 1) * input_deQFactor_div127; + tmp = scale + static_cast(__ldg(bias_ptr + col + 1)); shm[sh_row][sh_col + 1] = float_to_int8_rn(tmp * out_scale); - scale = __ldg(data_ptr + inIdx + 2) * __ldg(weight_amax + col + 2) * input_deQFactor_div127; - tmp = scale + static_cast(__ldg(bias_ptr + col + 2)); + scale = __ldg(data_ptr + inIdx + 2) * __ldg(weight_amax + col + 2) * input_deQFactor_div127; + tmp = scale + static_cast(__ldg(bias_ptr + col + 2)); shm[sh_row][sh_col + 2] = float_to_int8_rn(tmp * out_scale); - scale = __ldg(data_ptr + inIdx + 3) * __ldg(weight_amax + col + 3) * input_deQFactor_div127; - tmp = scale + static_cast(__ldg(bias_ptr + col + 3)); + scale = __ldg(data_ptr + inIdx + 3) * __ldg(weight_amax + col + 3) * input_deQFactor_div127; + tmp = scale + static_cast(__ldg(bias_ptr + col + 3)); shm[sh_row][sh_col + 3] = float_to_int8_rn(tmp * out_scale); } else { @@ -2552,14 +2552,14 @@ __global__ void add_V_bias_transform_rebuild_padding(int8_t* v_buf_, dataTmp.w = shm[sh_col + 3][sh_row]; // for dst of (size_per_head, seq_len) - word_id = (blockIdx.y << 5) + threadIdx4; + word_id = (blockIdx.y << 5) + threadIdx4; id_in_size = (blockIdx.x << 5) + threadIdx.y; - col = (word_id >> 5); + col = (word_id >> 5); if (use_ORDER_COL32_2R_4R4) { int row_in_tile = id_in_size & 31; int col_in_tile = word_id & 31; - row = ( + row = ( // COL32_2R_4R4 ((id_in_size >> 5) << 10) + //(((row%8)/2*4+row/8)*2+row%2)*32+col @@ -2583,34 +2583,34 @@ __global__ void add_V_bias_transform_rebuild_padding(int8_t* v_buf_, } template<> -__global__ void add_V_bias_transform_rebuild_padding(int8_t* v_buf_, +__global__ void add_V_bias_transform_rebuild_padding(int8_t* v_buf_, const int32_t* V, - const half* V_bias, - const int* sequence_id_map, - const int valid_word_num, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - int stride, - const float* weight_amax, - const float* input_deQFactor_div127_ptr, - const float* out_scale_ptr, - bool use_ORDER_COL32_2R_4R4) + const half* V_bias, + const int* sequence_id_map, + const int valid_word_num, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + int stride, + const float* weight_amax, + const float* input_deQFactor_div127_ptr, + const float* out_scale_ptr, + bool use_ORDER_COL32_2R_4R4) { __shared__ int8_t shm[32][33]; - const int32_t* data_ptr = V; - char4* buf_ptr4 = (char4*)v_buf_; + const int32_t* data_ptr = V; + char4* buf_ptr4 = (char4*)v_buf_; int threadIdx4 = threadIdx.x << 2; // for src of (seq_len, size_per_head) int batch_id = blockIdx.z / head_num; - int head_id = blockIdx.z % head_num; + int head_id = blockIdx.z % head_num; int blockIdy32 = (blockIdx.y << 5); int blockIdx32 = (blockIdx.x << 5); - int word_id = blockIdy32 + threadIdx.y; + int word_id = blockIdy32 + threadIdx.y; int id_in_size = blockIdx32 + threadIdx4; // for shm row-major @@ -2623,33 +2623,33 @@ __global__ void add_V_bias_transform_rebuild_padding(int8_t* v_buf_, if (row >= 0) { const float input_deQFactor_div127 = __ldg(input_deQFactor_div127_ptr); - const 
float out_scale = __ldg(out_scale_ptr); - col = head_id * size_per_head + id_in_size; - int inIdx = ((col & 0xffffffe0) * valid_word_num + ((row << 5) + (col & 31))); - int col_2 = col >> 1; + const float out_scale = __ldg(out_scale_ptr); + col = head_id * size_per_head + id_in_size; + int inIdx = ((col & 0xffffffe0) * valid_word_num + ((row << 5) + (col & 31))); + int col_2 = col >> 1; float scale; const half2* bias_ptr2 = (const half2*)V_bias; - half2 tmp2; + half2 tmp2; tmp2 = __ldg(bias_ptr2 + col_2); - scale = __ldg(data_ptr + inIdx) * __ldg(weight_amax + col) * input_deQFactor_div127; - scale = scale + static_cast(tmp2.x); + scale = __ldg(data_ptr + inIdx) * __ldg(weight_amax + col) * input_deQFactor_div127; + scale = scale + static_cast(tmp2.x); shm[sh_row][sh_col] = float_to_int8_rn(scale * out_scale); - scale = __ldg(data_ptr + inIdx + 1) * __ldg(weight_amax + col + 1) * input_deQFactor_div127; - scale = scale + static_cast(tmp2.y); + scale = __ldg(data_ptr + inIdx + 1) * __ldg(weight_amax + col + 1) * input_deQFactor_div127; + scale = scale + static_cast(tmp2.y); shm[sh_row][sh_col + 1] = float_to_int8_rn(scale * out_scale); tmp2 = __ldg(bias_ptr2 + col_2 + 1); - scale = __ldg(data_ptr + inIdx + 2) * __ldg(weight_amax + col + 2) * input_deQFactor_div127; - scale = scale + static_cast(tmp2.x); + scale = __ldg(data_ptr + inIdx + 2) * __ldg(weight_amax + col + 2) * input_deQFactor_div127; + scale = scale + static_cast(tmp2.x); shm[sh_row][sh_col + 2] = float_to_int8_rn(scale * out_scale); - scale = __ldg(data_ptr + inIdx + 3) * __ldg(weight_amax + col + 3) * input_deQFactor_div127; - scale = scale + static_cast(tmp2.y); + scale = __ldg(data_ptr + inIdx + 3) * __ldg(weight_amax + col + 3) * input_deQFactor_div127; + scale = scale + static_cast(tmp2.y); shm[sh_row][sh_col + 3] = float_to_int8_rn(scale * out_scale); } else { @@ -2664,14 +2664,14 @@ __global__ void add_V_bias_transform_rebuild_padding(int8_t* v_buf_, dataTmp.w = shm[sh_col + 3][sh_row]; // for dst of (size_per_head, seq_len) - word_id = blockIdy32 + threadIdx4; + word_id = blockIdy32 + threadIdx4; id_in_size = blockIdx32 + threadIdx.y; - col = (word_id >> 5); + col = (word_id >> 5); if (use_ORDER_COL32_2R_4R4) { int row_in_tile = id_in_size & 31; int col_in_tile = word_id & 31; - row = ( + row = ( // COL32_2R_4R4 ((id_in_size >> 5) << 10) + //(((row%8)/2*4+row/8)*2+row%2)*32+col @@ -2695,20 +2695,20 @@ __global__ void add_V_bias_transform_rebuild_padding(int8_t* v_buf_, } template -void invokeAddVBiasTransformRebuildPadding(int8_t* v_buf, +void invokeAddVBiasTransformRebuildPadding(int8_t* v_buf, const int32_t* V, - const T* V_bias, - const int* sequence_id_map, - const int valid_word_num, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* weight_amax, - const float* input_deQFactor_div127_ptr, - const float* out_scale_ptr, - bool use_ORDER_COL32_2R_4R4, - cudaStream_t stream) + const T* V_bias, + const int* sequence_id_map, + const int valid_word_num, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* weight_amax, + const float* input_deQFactor_div127_ptr, + const float* out_scale_ptr, + bool use_ORDER_COL32_2R_4R4, + cudaStream_t stream) { add_V_bias_transform_rebuild_padding<< -__global__ void add_V_bias_transform_rebuild_padding(int8_t* v_buf_, +__global__ void add_V_bias_transform_rebuild_padding(int8_t* v_buf_, const int8_t* V, - const T* V_bias, - const int* sequence_id_map, - const int 
valid_word_num, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - int stride, - const float* deQFactor_ptr, - const float* out_scale_ptr, - bool use_ORDER_COL32_2R_4R4) + const T* V_bias, + const int* sequence_id_map, + const int valid_word_num, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + int stride, + const float* deQFactor_ptr, + const float* out_scale_ptr, + bool use_ORDER_COL32_2R_4R4) { __shared__ int8_t shm[32][33]; - const char4* data_ptr = (const char4*)V; - char4* buf_ptr4 = (char4*)v_buf_; - const T* bias_ptr = V_bias; + const char4* data_ptr = (const char4*)V; + char4* buf_ptr4 = (char4*)v_buf_; + const T* bias_ptr = V_bias; int threadIdx4 = threadIdx.x << 2; // for src of (seq_len, size_per_head) - int batch_id = blockIdx.z / head_num; - int head_id = blockIdx.z % head_num; - int word_id = (blockIdx.y << 5) + threadIdx.y; + int batch_id = blockIdx.z / head_num; + int head_id = blockIdx.z % head_num; + int word_id = (blockIdx.y << 5) + threadIdx.y; int id_in_size = (blockIdx.x << 5) + threadIdx4; // for shm row-major @@ -2801,7 +2801,7 @@ __global__ void add_V_bias_transform_rebuild_padding(int8_t* v_buf_, int row = __ldg(sequence_id_map + batch_id * seq_len + word_id); if (row != -1) { - col = head_id * size_per_head + id_in_size; + col = head_id * size_per_head + id_in_size; int inIdx = ((col & 0xffffffe0) * valid_word_num + ((row << 5) + (col & 31))) >> 2; float tmp; @@ -2812,20 +2812,20 @@ __global__ void add_V_bias_transform_rebuild_padding(int8_t* v_buf_, char4 dataTmp = __ldg(data_ptr + inIdx); - scale = dataTmp.x * deQFactor; - tmp = scale + static_cast(__ldg(bias_ptr + col)); + scale = dataTmp.x * deQFactor; + tmp = scale + static_cast(__ldg(bias_ptr + col)); shm[sh_row][sh_col] = float_to_int8_rn(tmp * out_scale); - scale = dataTmp.y * deQFactor; - tmp = scale + static_cast(__ldg(bias_ptr + col + 1)); + scale = dataTmp.y * deQFactor; + tmp = scale + static_cast(__ldg(bias_ptr + col + 1)); shm[sh_row][sh_col + 1] = float_to_int8_rn(tmp * out_scale); - scale = dataTmp.z * deQFactor; - tmp = scale + static_cast(__ldg(bias_ptr + col + 2)); + scale = dataTmp.z * deQFactor; + tmp = scale + static_cast(__ldg(bias_ptr + col + 2)); shm[sh_row][sh_col + 2] = float_to_int8_rn(tmp * out_scale); - scale = dataTmp.w * deQFactor; - tmp = scale + static_cast(__ldg(bias_ptr + col + 3)); + scale = dataTmp.w * deQFactor; + tmp = scale + static_cast(__ldg(bias_ptr + col + 3)); shm[sh_row][sh_col + 3] = float_to_int8_rn(tmp * out_scale); } else { @@ -2840,14 +2840,14 @@ __global__ void add_V_bias_transform_rebuild_padding(int8_t* v_buf_, dataTmp.w = shm[sh_col + 3][sh_row]; // for dst of (size_per_head, seq_len) - word_id = (blockIdx.y << 5) + threadIdx4; + word_id = (blockIdx.y << 5) + threadIdx4; id_in_size = (blockIdx.x << 5) + threadIdx.y; - col = (word_id >> 5); + col = (word_id >> 5); if (use_ORDER_COL32_2R_4R4) { int row_in_tile = id_in_size & 31; int col_in_tile = word_id & 31; - row = ( + row = ( // COL32_2R_4R4 ((id_in_size >> 5) << 10) + //(((row%8)/2*4+row/8)*2+row%2)*32+col @@ -2877,32 +2877,32 @@ __global__ void add_V_bias_transform_rebuild_padding(int8_t* v_buf_, // multiple of 32 grid = (size_per_head/32, seq_len_padded/32, batch_size*head_num) block = (8, 32); using char4 per // tensor quantization for weight template -__global__ void add_V_bias_transform_rebuild_padding_varlen(int8_t* v_buf_, +__global__ void add_V_bias_transform_rebuild_padding_varlen(int8_t* v_buf_, const 
int8_t* V, - const T* V_bias, - const int* sequence_id_map, - const int valid_word_num, - const int batch_size, - const int seq_len, - const int seq_len_padded, - const int head_num, - const int size_per_head, - int stride, - const float* deQFactor_ptr, - const float* out_scale_ptr, - bool use_ORDER_COL32_2R_4R4) + const T* V_bias, + const int* sequence_id_map, + const int valid_word_num, + const int batch_size, + const int seq_len, + const int seq_len_padded, + const int head_num, + const int size_per_head, + int stride, + const float* deQFactor_ptr, + const float* out_scale_ptr, + bool use_ORDER_COL32_2R_4R4) { __shared__ int8_t shm[32][33]; - const char4* data_ptr = (const char4*)V; - char4* buf_ptr4 = (char4*)v_buf_; - const T* bias_ptr = V_bias; + const char4* data_ptr = (const char4*)V; + char4* buf_ptr4 = (char4*)v_buf_; + const T* bias_ptr = V_bias; int threadIdx4 = threadIdx.x << 2; // for src of (seq_len, size_per_head) - int batch_id = blockIdx.z / head_num; - int head_id = blockIdx.z % head_num; - int word_id = (blockIdx.y << 5) + threadIdx.y; + int batch_id = blockIdx.z / head_num; + int head_id = blockIdx.z % head_num; + int word_id = (blockIdx.y << 5) + threadIdx.y; int id_in_size = (blockIdx.x << 5) + threadIdx4; // for shm row-major @@ -2914,7 +2914,7 @@ __global__ void add_V_bias_transform_rebuild_padding_varlen(int8_t* v_buf_, int row = word_id < seq_len ? __ldg(sequence_id_map + batch_id * seq_len + word_id) : -1; if (row != -1) { - col = head_id * size_per_head + id_in_size; + col = head_id * size_per_head + id_in_size; int inIdx = ((col & 0xffffffe0) * valid_word_num + ((row << 5) + (col & 31))) >> 2; float tmp; @@ -2925,20 +2925,20 @@ __global__ void add_V_bias_transform_rebuild_padding_varlen(int8_t* v_buf_, char4 dataTmp = __ldg(data_ptr + inIdx); - scale = dataTmp.x * deQFactor; - tmp = scale + static_cast(__ldg(bias_ptr + col)); + scale = dataTmp.x * deQFactor; + tmp = scale + static_cast(__ldg(bias_ptr + col)); shm[sh_row][sh_col] = float_to_int8_rn(tmp * out_scale); - scale = dataTmp.y * deQFactor; - tmp = scale + static_cast(__ldg(bias_ptr + col + 1)); + scale = dataTmp.y * deQFactor; + tmp = scale + static_cast(__ldg(bias_ptr + col + 1)); shm[sh_row][sh_col + 1] = float_to_int8_rn(tmp * out_scale); - scale = dataTmp.z * deQFactor; - tmp = scale + static_cast(__ldg(bias_ptr + col + 2)); + scale = dataTmp.z * deQFactor; + tmp = scale + static_cast(__ldg(bias_ptr + col + 2)); shm[sh_row][sh_col + 2] = float_to_int8_rn(tmp * out_scale); - scale = dataTmp.w * deQFactor; - tmp = scale + static_cast(__ldg(bias_ptr + col + 3)); + scale = dataTmp.w * deQFactor; + tmp = scale + static_cast(__ldg(bias_ptr + col + 3)); shm[sh_row][sh_col + 3] = float_to_int8_rn(tmp * out_scale); } else { @@ -2953,14 +2953,14 @@ __global__ void add_V_bias_transform_rebuild_padding_varlen(int8_t* v_buf_, dataTmp.w = shm[sh_col + 3][sh_row]; // for dst of (size_per_head, seq_len_padded) - word_id = (blockIdx.y << 5) + threadIdx4; + word_id = (blockIdx.y << 5) + threadIdx4; id_in_size = (blockIdx.x << 5) + threadIdx.y; - col = (word_id >> 5); + col = (word_id >> 5); if (use_ORDER_COL32_2R_4R4) { int row_in_tile = id_in_size & 31; int col_in_tile = word_id & 31; - row = ( + row = ( // COL32_2R_4R4 ((id_in_size >> 5) << 10) + //(((row%8)/2*4+row/8)*2+row%2)*32+col @@ -2984,19 +2984,19 @@ __global__ void add_V_bias_transform_rebuild_padding_varlen(int8_t* v_buf_, } template -void invokeAddVBiasTransformRebuildPadding(int8_t* v_buf, +void invokeAddVBiasTransformRebuildPadding(int8_t* v_buf, 
const int8_t* V, - const T* V_bias, - const int* sequence_id_map, - const int valid_word_num, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* deQFactor_ptr, - const float* out_scale_ptr, - bool use_ORDER_COL32_2R_4R4, - cudaStream_t stream) + const T* V_bias, + const int* sequence_id_map, + const int valid_word_num, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* deQFactor_ptr, + const float* out_scale_ptr, + bool use_ORDER_COL32_2R_4R4, + cudaStream_t stream) { int seq_len_padded = (seq_len + 31) / 32 * 32; add_V_bias_transform_rebuild_padding_varlen<< -__global__ void add_V_bias_transform_rebuild_padding_varlen_row(int8_t* v_buf_, +__global__ void add_V_bias_transform_rebuild_padding_varlen_row(int8_t* v_buf_, const int8_t* V, - const T* V_bias, - const int* sequence_id_map, - const int valid_word_num, - const int batch_size, - const int seq_len, - const int seq_len_padded, - const int head_num, - const int size_per_head, - int stride, - const float* deQFactor_ptr, - const float* out_scale_ptr, - bool use_ORDER_COL32_2R_4R4, - const int head_num_x_size_per_head) + const T* V_bias, + const int* sequence_id_map, + const int valid_word_num, + const int batch_size, + const int seq_len, + const int seq_len_padded, + const int head_num, + const int size_per_head, + int stride, + const float* deQFactor_ptr, + const float* out_scale_ptr, + bool use_ORDER_COL32_2R_4R4, + const int head_num_x_size_per_head) { __shared__ int8_t shm[32][33]; - const char4* data_ptr = (const char4*)V; - char4* buf_ptr4 = (char4*)v_buf_; - const T* bias_ptr = V_bias; + const char4* data_ptr = (const char4*)V; + char4* buf_ptr4 = (char4*)v_buf_; + const T* bias_ptr = V_bias; int threadIdx4 = threadIdx.x << 2; // for src of (seq_len, size_per_head) - int batch_id = blockIdx.z / head_num; - int head_id = blockIdx.z % head_num; - int word_id = (blockIdx.y << 5) + threadIdx.y; + int batch_id = blockIdx.z / head_num; + int head_id = blockIdx.z % head_num; + int word_id = (blockIdx.y << 5) + threadIdx.y; int id_in_size = (blockIdx.x << 5) + threadIdx4; // for shm row-major @@ -3091,7 +3091,7 @@ __global__ void add_V_bias_transform_rebuild_padding_varlen_row(int8_t* v_buf_, int row = word_id < seq_len ? 
__ldg(sequence_id_map + batch_id * seq_len + word_id) : -1; if (row != -1) { - col = head_id * size_per_head + id_in_size; + col = head_id * size_per_head + id_in_size; int inIdx = (row * head_num_x_size_per_head + col) >> 2; float tmp; @@ -3102,20 +3102,20 @@ __global__ void add_V_bias_transform_rebuild_padding_varlen_row(int8_t* v_buf_, char4 dataTmp = __ldg(data_ptr + inIdx); - scale = dataTmp.x * deQFactor; - tmp = scale + static_cast(__ldg(bias_ptr + col)); + scale = dataTmp.x * deQFactor; + tmp = scale + static_cast(__ldg(bias_ptr + col)); shm[sh_row][sh_col] = float_to_int8_rn(tmp * out_scale); - scale = dataTmp.y * deQFactor; - tmp = scale + static_cast(__ldg(bias_ptr + col + 1)); + scale = dataTmp.y * deQFactor; + tmp = scale + static_cast(__ldg(bias_ptr + col + 1)); shm[sh_row][sh_col + 1] = float_to_int8_rn(tmp * out_scale); - scale = dataTmp.z * deQFactor; - tmp = scale + static_cast(__ldg(bias_ptr + col + 2)); + scale = dataTmp.z * deQFactor; + tmp = scale + static_cast(__ldg(bias_ptr + col + 2)); shm[sh_row][sh_col + 2] = float_to_int8_rn(tmp * out_scale); - scale = dataTmp.w * deQFactor; - tmp = scale + static_cast(__ldg(bias_ptr + col + 3)); + scale = dataTmp.w * deQFactor; + tmp = scale + static_cast(__ldg(bias_ptr + col + 3)); shm[sh_row][sh_col + 3] = float_to_int8_rn(tmp * out_scale); } else { @@ -3130,14 +3130,14 @@ __global__ void add_V_bias_transform_rebuild_padding_varlen_row(int8_t* v_buf_, dataTmp.w = shm[sh_col + 3][sh_row]; // for dst of (size_per_head, seq_len_padded) - word_id = (blockIdx.y << 5) + threadIdx4; + word_id = (blockIdx.y << 5) + threadIdx4; id_in_size = (blockIdx.x << 5) + threadIdx.y; - col = (word_id >> 5); + col = (word_id >> 5); if (use_ORDER_COL32_2R_4R4) { int row_in_tile = id_in_size & 31; int col_in_tile = word_id & 31; - row = ( + row = ( // COL32_2R_4R4 ((id_in_size >> 5) << 10) + (((((((row_in_tile % 8) / 2) * 4) + (row_in_tile / 8)) * 2) + (row_in_tile % 2)) * 32) + col_in_tile @@ -3162,19 +3162,19 @@ __global__ void add_V_bias_transform_rebuild_padding_varlen_row(int8_t* v_buf_, } template -void invokeAddVBiasTransformRebuildPaddingRow(int8_t* v_buf, +void invokeAddVBiasTransformRebuildPaddingRow(int8_t* v_buf, const int8_t* V, - const T* V_bias, - const int* sequence_id_map, - const int valid_word_num, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* deQFactor_ptr, - const float* out_scale_ptr, - bool use_ORDER_COL32_2R_4R4, - cudaStream_t stream) + const T* V_bias, + const int* sequence_id_map, + const int valid_word_num, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* deQFactor_ptr, + const float* out_scale_ptr, + bool use_ORDER_COL32_2R_4R4, + cudaStream_t stream) { int seq_len_padded = (seq_len + 31) / 32 * 32; add_V_bias_transform_rebuild_padding_varlen_row<<< @@ -3198,32 +3198,32 @@ void invokeAddVBiasTransformRebuildPaddingRow(int8_t* v_buf, head_num * size_per_head); } -template void invokeAddVBiasTransformRebuildPaddingRow(int8_t* v_buf, +template void invokeAddVBiasTransformRebuildPaddingRow(int8_t* v_buf, const int8_t* V, - const float* V_bias, - const int* sequence_id_map, - const int valid_word_num, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* deQFactor_ptr, - const float* out_scale_ptr, - bool use_ORDER_COL32_2R_4R4, - cudaStream_t stream); - -template void invokeAddVBiasTransformRebuildPaddingRow(int8_t* v_buf, + const float* V_bias, + 
const int* sequence_id_map, + const int valid_word_num, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* deQFactor_ptr, + const float* out_scale_ptr, + bool use_ORDER_COL32_2R_4R4, + cudaStream_t stream); + +template void invokeAddVBiasTransformRebuildPaddingRow(int8_t* v_buf, const int8_t* V, - const half* V_bias, - const int* sequence_id_map, - const int valid_word_num, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* deQFactor_ptr, - const float* out_scale_ptr, - bool use_ORDER_COL32_2R_4R4, - cudaStream_t stream); + const half* V_bias, + const int* sequence_id_map, + const int valid_word_num, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* deQFactor_ptr, + const float* out_scale_ptr, + bool use_ORDER_COL32_2R_4R4, + cudaStream_t stream); } // namespace fastertransformer diff --git a/src/fastertransformer/kernels/unfused_attention_int8_kernels.h b/src/fastertransformer/kernels/unfused_attention_int8_kernels.h index 9e1dbf572..a81a46576 100644 --- a/src/fastertransformer/kernels/unfused_attention_int8_kernels.h +++ b/src/fastertransformer/kernels/unfused_attention_int8_kernels.h @@ -16,215 +16,215 @@ namespace fastertransformer { -void invokeMappingRemovePaddingData(const int batch_size, - const int seq_len, - const int valid_word_num, - int* mapping, - const int* sequence_id_offset, +void invokeMappingRemovePaddingData(const int batch_size, + const int seq_len, + const int valid_word_num, + int* mapping, + const int* sequence_id_offset, cudaStream_t stream); template -void invokeAddQKBiasTransform(int8_t* q_buf, - int8_t* k_buf, +void invokeAddQKBiasTransform(int8_t* q_buf, + int8_t* k_buf, const int32_t* Q, - const T* bias_Q, + const T* bias_Q, const int32_t* K, - const T* bias_K, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* q_weight_amax, - const float* q_input_deQFactor_div127_ptr, - const float* k_weight_amax, - const float* k_input_deQFactor_div127_ptr, - const float* q_output_scale_ptr, - const float* k_output_scale_ptr, - bool use_ORDER_COL32_2R_4R4, - cudaStream_t stream); + const T* bias_K, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* q_weight_amax, + const float* q_input_deQFactor_div127_ptr, + const float* k_weight_amax, + const float* k_input_deQFactor_div127_ptr, + const float* q_output_scale_ptr, + const float* k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4, + cudaStream_t stream); template -void invokeAddQKBiasTransform(int8_t* q_buf, - int8_t* k_buf, +void invokeAddQKBiasTransform(int8_t* q_buf, + int8_t* k_buf, const int8_t* Q, - const T* bias_Q, + const T* bias_Q, const int8_t* K, - const T* bias_K, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* q_input_deQFactor_ptr, - const float* k_input_deQFactor_ptr, - const float* q_output_scale_ptr, - const float* k_output_scale_ptr, - bool use_ORDER_COL32_2R_4R4, - cudaStream_t stream); + const T* bias_K, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* q_input_deQFactor_ptr, + const float* k_input_deQFactor_ptr, + const float* q_output_scale_ptr, + const float* k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4, + cudaStream_t stream); template -void invokeAddQKBiasTransformRow(int8_t* q_buf, - int8_t* 
k_buf, +void invokeAddQKBiasTransformRow(int8_t* q_buf, + int8_t* k_buf, const int8_t* Q, - const T* bias_Q, + const T* bias_Q, const int8_t* K, - const T* bias_K, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* q_input_deQFactor_ptr, - const float* k_input_deQFactor_ptr, - const float* q_output_scale_ptr, - const float* k_output_scale_ptr, - bool use_ORDER_COL32_2R_4R4, - cudaStream_t stream); + const T* bias_K, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* q_input_deQFactor_ptr, + const float* k_input_deQFactor_ptr, + const float* q_output_scale_ptr, + const float* k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4, + cudaStream_t stream); template -void invokeAddQKBiasTransformRebuildPadding(int8_t* q_buf, - int8_t* k_buf, +void invokeAddQKBiasTransformRebuildPadding(int8_t* q_buf, + int8_t* k_buf, const int32_t* Q, - const T* bias_Q, + const T* bias_Q, const int32_t* K, - const T* bias_K, - const int* sequence_id_offset, - const int valid_word_num, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* q_weight_amax, - const float* q_input_deQFactor_div127_ptr, - const float* k_weight_amax, - const float* k_input_deQFactor_div127_ptr, - const float* q_output_scale_ptr, - const float* k_output_scale_ptr, - bool use_ORDER_COL32_2R_4R4, - cudaStream_t stream); + const T* bias_K, + const int* sequence_id_offset, + const int valid_word_num, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* q_weight_amax, + const float* q_input_deQFactor_div127_ptr, + const float* k_weight_amax, + const float* k_input_deQFactor_div127_ptr, + const float* q_output_scale_ptr, + const float* k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4, + cudaStream_t stream); template -void invokeAddQKBiasTransformRebuildPadding(int8_t* q_buf, - int8_t* k_buf, +void invokeAddQKBiasTransformRebuildPadding(int8_t* q_buf, + int8_t* k_buf, const int8_t* Q, - const T* bias_Q, + const T* bias_Q, const int8_t* K, - const T* bias_K, - const int* sequence_id_offset, - const int valid_word_num, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* q_deQFactor_ptr, - const float* k_deQFactor_ptr, - const float* q_output_scale_ptr, - const float* k_output_scale_ptr, - bool use_ORDER_COL32_2R_4R4, - cudaStream_t stream); + const T* bias_K, + const int* sequence_id_offset, + const int valid_word_num, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* q_deQFactor_ptr, + const float* k_deQFactor_ptr, + const float* q_output_scale_ptr, + const float* k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4, + cudaStream_t stream); template -void invokeAddQKBiasTransformRebuildPaddingRow(int8_t* q_buf, - int8_t* k_buf, +void invokeAddQKBiasTransformRebuildPaddingRow(int8_t* q_buf, + int8_t* k_buf, const int8_t* Q, - const T* bias_Q, + const T* bias_Q, const int8_t* K, - const T* bias_K, - const int* sequence_id_offset, - const int valid_word_num, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* q_deQFactor_ptr, - const float* k_deQFactor_ptr, - const float* q_output_scale_ptr, - const float* k_output_scale_ptr, - bool use_ORDER_COL32_2R_4R4, - cudaStream_t stream); + const T* bias_K, + const int* sequence_id_offset, + const int valid_word_num, + 
const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* q_deQFactor_ptr, + const float* k_deQFactor_ptr, + const float* q_output_scale_ptr, + const float* k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4, + cudaStream_t stream); template -void invokeAddVBiasTransform(int8_t* v_buf, +void invokeAddVBiasTransform(int8_t* v_buf, const int32_t* V, - const T* V_bias, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* weight_amax, - const float* input_deQFactor_div127_ptr, - const float* out_scale_ptr, - bool use_ORDER_COL32_2R_4R4, - cudaStream_t stream); + const T* V_bias, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* weight_amax, + const float* input_deQFactor_div127_ptr, + const float* out_scale_ptr, + bool use_ORDER_COL32_2R_4R4, + cudaStream_t stream); template -void invokeAddVBiasTransform(int8_t* v_buf, +void invokeAddVBiasTransform(int8_t* v_buf, const int8_t* V, - const T* V_bias, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* input_deQFactor_ptr, - const float* out_scale_ptr, - bool use_ORDER_COL32_2R_4R4, - cudaStream_t stream); + const T* V_bias, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* input_deQFactor_ptr, + const float* out_scale_ptr, + bool use_ORDER_COL32_2R_4R4, + cudaStream_t stream); template -void invokeAddVBiasTransformRow(int8_t* v_buf, +void invokeAddVBiasTransformRow(int8_t* v_buf, const int8_t* V, - const T* V_bias, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* input_deQFactor_ptr, - const float* out_scale_ptr, - bool use_ORDER_COL32_2R_4R4, - cudaStream_t stream); + const T* V_bias, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* input_deQFactor_ptr, + const float* out_scale_ptr, + bool use_ORDER_COL32_2R_4R4, + cudaStream_t stream); template -void invokeAddVBiasTransformRebuildPadding(int8_t* v_buf, +void invokeAddVBiasTransformRebuildPadding(int8_t* v_buf, const int32_t* V, - const T* V_bias, - const int* sequence_id_map, - const int valid_word_num, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* weight_amax, - const float* input_deQFactor_div127_ptr, - const float* out_scale_ptr, - bool use_ORDER_COL32_2R_4R4, - cudaStream_t stream); + const T* V_bias, + const int* sequence_id_map, + const int valid_word_num, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* weight_amax, + const float* input_deQFactor_div127_ptr, + const float* out_scale_ptr, + bool use_ORDER_COL32_2R_4R4, + cudaStream_t stream); template -void invokeAddVBiasTransformRebuildPadding(int8_t* v_buf, +void invokeAddVBiasTransformRebuildPadding(int8_t* v_buf, const int8_t* V, - const T* V_bias, - const int* sequence_id_map, - const int valid_word_num, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* deQFactor_ptr, - const float* out_scale_ptr, - bool use_ORDER_COL32_2R_4R4, - cudaStream_t stream); + const T* V_bias, + const int* sequence_id_map, + const int valid_word_num, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* deQFactor_ptr, + const float* 
out_scale_ptr, + bool use_ORDER_COL32_2R_4R4, + cudaStream_t stream); template -void invokeAddVBiasTransformRebuildPaddingRow(int8_t* v_buf, +void invokeAddVBiasTransformRebuildPaddingRow(int8_t* v_buf, const int8_t* V, - const T* V_bias, - const int* sequence_id_map, - const int valid_word_num, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const float* deQFactor_ptr, - const float* out_scale_ptr, - bool use_ORDER_COL32_2R_4R4, - cudaStream_t stream); + const T* V_bias, + const int* sequence_id_map, + const int valid_word_num, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const float* deQFactor_ptr, + const float* out_scale_ptr, + bool use_ORDER_COL32_2R_4R4, + cudaStream_t stream); } // namespace fastertransformer diff --git a/src/fastertransformer/kernels/unfused_attention_kernels.cu b/src/fastertransformer/kernels/unfused_attention_kernels.cu index f951e718a..1cbe44db1 100644 --- a/src/fastertransformer/kernels/unfused_attention_kernels.cu +++ b/src/fastertransformer/kernels/unfused_attention_kernels.cu @@ -43,25 +43,25 @@ __global__ void addQKVBiasTranspose(T* q_out, const int head_num, const int size_per_head) { - const int n = head_num * size_per_head; + const int n = head_num * size_per_head; const int batch_id = blockIdx.x; - const int word_id = blockIdx.y; - const int row_id = batch_id * seq_len + word_id; + const int word_id = blockIdx.y; + const int row_id = batch_id * seq_len + word_id; for (int col_id = threadIdx.x; col_id < n; col_id += blockDim.x) { - const int head_id = col_id / size_per_head; - const int size_id = col_id % size_per_head; + const int head_id = col_id / size_per_head; + const int size_id = col_id % size_per_head; const int target_id = batch_id * (head_num * seq_len * size_per_head) + head_id * seq_len * size_per_head + word_id * size_per_head + size_id; const int src_id = row_id * n + col_id; - q_out[target_id] = __ldg(&q_in[src_id]); - q_out[target_id] = q_out[target_id] + __ldg(&bias_q[col_id]); + q_out[target_id] = ldg(&q_in[src_id]); + q_out[target_id] = add(q_out[target_id], ldg(&bias_q[col_id])); - k_out[target_id] = __ldg(&k_in[src_id]); - k_out[target_id] = k_out[target_id] + __ldg(&bias_k[col_id]); + k_out[target_id] = ldg(&k_in[src_id]); + k_out[target_id] = add(k_out[target_id], ldg(&bias_k[col_id])); - v_out[target_id] = __ldg(&v_in[src_id]); - v_out[target_id] = v_out[target_id] + __ldg(&bias_v[col_id]); + v_out[target_id] = ldg(&v_in[src_id]); + v_out[target_id] = add(v_out[target_id], ldg(&bias_v[col_id])); } } @@ -77,42 +77,42 @@ __global__ void QKVTranspose(T* q_out, const int head_num, const int size_per_head) { - const int n = head_num * size_per_head; + const int n = head_num * size_per_head; const int batch_id = blockIdx.x; - const int word_id = blockIdx.y; - const int row_id = batch_id * seq_len + word_id; + const int word_id = blockIdx.y; + const int row_id = batch_id * seq_len + word_id; for (int col_id = threadIdx.x; col_id < n; col_id += blockDim.x) { - const int head_id = col_id / size_per_head; - const int size_id = col_id % size_per_head; + const int head_id = col_id / size_per_head; + const int size_id = col_id % size_per_head; const int target_id = batch_id * (head_num * seq_len * size_per_head) + head_id * seq_len * size_per_head + word_id * size_per_head + size_id; const int src_id = row_id * n + col_id; - q_out[target_id] = __ldg(&q_in[src_id]); - k_out[target_id] = __ldg(&k_in[src_id]); - v_out[target_id] = 
__ldg(&v_in[src_id]); + q_out[target_id] = ldg(&q_in[src_id]); + k_out[target_id] = ldg(&k_in[src_id]); + v_out[target_id] = ldg(&v_in[src_id]); } } template -void invokeAddQKVBiasTranspose(T* q_buf, - T* k_buf, - T* v_buf, - T* Q, - const T* bias_Q, - T* K, - const T* bias_K, - T* V, - const T* bias_V, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, +void invokeAddQKVBiasTranspose(T* q_buf, + T* k_buf, + T* v_buf, + T* Q, + const T* bias_Q, + T* K, + const T* bias_K, + T* V, + const T* bias_V, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, cudaStream_t stream) { const int k = head_num * size_per_head; - dim3 grid(batch_size, seq_len); - bool is_add_bias = bias_Q != nullptr; + dim3 grid(batch_size, seq_len); + bool is_add_bias = bias_Q != nullptr; if (sizeof(T) == 4 || k % 2 != 0) { dim3 block(min(k, 512)); if (is_add_bias) { @@ -126,94 +126,112 @@ void invokeAddQKVBiasTranspose(T* q_buf, sync_check_cuda_error(); } else { + using T2 = typename TypeConverter::Type; // fp16 to half2, bf16 to bf162 dim3 block(min(k / 2, 512)); if (is_add_bias) { - addQKVBiasTranspose<<>>((half2*)q_buf, - (half2*)k_buf, - (half2*)v_buf, - (const half2*)Q, - (const half2*)bias_Q, - (const half2*)K, - (const half2*)bias_K, - (const half2*)V, - (const half2*)bias_V, - batch_size, - seq_len, - head_num, - size_per_head / 2); + addQKVBiasTranspose<<>>((T2*)q_buf, + (T2*)k_buf, + (T2*)v_buf, + (const T2*)Q, + (const T2*)bias_Q, + (const T2*)K, + (const T2*)bias_K, + (const T2*)V, + (const T2*)bias_V, + batch_size, + seq_len, + head_num, + size_per_head / 2); } else { - QKVTranspose<<>>((half2*)q_buf, - (half2*)k_buf, - (half2*)v_buf, - (const half2*)Q, - (const half2*)K, - (const half2*)V, - batch_size, - seq_len, - head_num, - size_per_head / 2); + QKVTranspose<<>>((T2*)q_buf, + (T2*)k_buf, + (T2*)v_buf, + (const T2*)Q, + (const T2*)K, + (const T2*)V, + batch_size, + seq_len, + head_num, + size_per_head / 2); } sync_check_cuda_error(); } } -template void invokeAddQKVBiasTranspose(float* q_buf, - float* k_buf, - float* v_buf, - float* Q, +template void invokeAddQKVBiasTranspose(float* q_buf, + float* k_buf, + float* v_buf, + float* Q, const float* bias_Q, - float* K, + float* K, const float* bias_K, - float* V, + float* V, const float* bias_V, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, cudaStream_t stream); -template void invokeAddQKVBiasTranspose(half* q_buf, - half* k_buf, - half* v_buf, - half* Q, - const half* bias_Q, - half* K, - const half* bias_K, - half* V, - const half* bias_V, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, +template void invokeAddQKVBiasTranspose(half* q_buf, + half* k_buf, + half* v_buf, + half* Q, + const half* bias_Q, + half* K, + const half* bias_K, + half* V, + const half* bias_V, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeAddQKVBiasTranspose(__nv_bfloat16* q_buf, + __nv_bfloat16* k_buf, + __nv_bfloat16* v_buf, + __nv_bfloat16* Q, + const __nv_bfloat16* bias_Q, + __nv_bfloat16* K, + const __nv_bfloat16* bias_K, + __nv_bfloat16* V, + const __nv_bfloat16* bias_V, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + cudaStream_t stream); +#endif // 
TODO(bhsueh) Rename the softmax_kernel_v4 to softmax_kernel template -__global__ void softmax_kernel_v4(T* qk_buf_, - const T_IN* qk_buf_src, - const T* attr_mask, - const int batch_size, - const int head_num, - const int seq_len, - const T scalar) +__global__ void softmax_kernel_v4(T* qk_buf_, + const T_IN* qk_buf_src, // shape [batch_size, head_num, seq_len_1, seq_len_2] + const T* attr_mask, // shape [batch_size, seq_len_1, seq_len_2] + const int batch_size, + const int head_num, + const int seq_len_1, + const int seq_len_2, + const T scalar) { - for (int seq_id = blockIdx.x; seq_id < seq_len; seq_id += gridDim.x) { - float data[ITEMS_PER_THREAD]; - int qk_offset; + for (int seq_id = blockIdx.x; seq_id < seq_len_1; seq_id += gridDim.x) { + float data[ITEMS_PER_THREAD]; + int qk_offset; __shared__ float s_mean, s_max; - float local_max = -1e20f; - for (int i = 0; blockDim.x * i + threadIdx.x < seq_len; i++) { + float local_max = -1e20f; + for (int i = 0; blockDim.x * i + threadIdx.x < seq_len_2; i++) { qk_offset = - ((blockIdx.y * head_num + blockIdx.z) * seq_len + seq_id) * seq_len + blockDim.x * i + threadIdx.x; - int mask_offset = (blockIdx.y * seq_len + seq_id) * seq_len + blockDim.x * i + threadIdx.x; + ((blockIdx.y * head_num + blockIdx.z) * seq_len_1 + seq_id) * seq_len_2 + blockDim.x * i + threadIdx.x; + int mask_offset = (blockIdx.y * seq_len_1 + seq_id) * seq_len_2 + blockDim.x * i + threadIdx.x; - float qk = static_cast(qk_buf_src[qk_offset]); + float qk = static_cast(qk_buf_src[qk_offset]); float mask_val = static_cast(ldg(&attr_mask[mask_offset])); mask_val = (1.0f - mask_val) * -10000.0f; - data[i] = qk * static_cast(scalar) + mask_val; + data[i] = qk * static_cast(scalar) + mask_val; local_max = fmax(local_max, data[i]); } @@ -224,7 +242,7 @@ __global__ void softmax_kernel_v4(T* qk_buf_, __syncthreads(); float local_sum = 0; - for (int i = 0; blockDim.x * i + threadIdx.x < seq_len; i++) { + for (int i = 0; blockDim.x * i + threadIdx.x < seq_len_2; i++) { data[i] = __expf(data[i] - s_max); local_sum += data[i]; } @@ -235,35 +253,40 @@ __global__ void softmax_kernel_v4(T* qk_buf_, } __syncthreads(); - for (int i = 0; blockDim.x * i + threadIdx.x < seq_len; i++) { + for (int i = 0; blockDim.x * i + threadIdx.x < seq_len_2; i++) { qk_offset = - ((blockIdx.y * head_num + blockIdx.z) * seq_len + seq_id) * seq_len + blockDim.x * i + threadIdx.x; + ((blockIdx.y * head_num + blockIdx.z) * seq_len_1 + seq_id) * seq_len_2 + blockDim.x * i + threadIdx.x; qk_buf_[qk_offset] = (T)(data[i] * s_mean); } } } template -__global__ void softmax_kernel_v4_half2( - T* qk_buf_, const T* attr_mask, const int batch_size, const int head_num, const int seq_len, const T scalar) +__global__ void softmax_kernel_v4_half2(T* qk_buf_, + const T* attr_mask, + const int batch_size, + const int head_num, + const int seq_len_1, + const int seq_len_2, + const T scalar) { - using T2 = typename TypeConverter::Type; - T2* qk_buf_half2 = (T2*)qk_buf_; + using T2 = typename TypeConverter::Type; + T2* qk_buf_half2 = (T2*)qk_buf_; const T2* attr_mask_half2 = (const T2*)attr_mask; - for (int seq_id = blockIdx.x; seq_id < seq_len; seq_id += gridDim.x) { - T2 data[ITEMS_PER_THREAD]; - int qk_offset; + for (int seq_id = blockIdx.x; seq_id < seq_len_1; seq_id += gridDim.x) { + T2 data[ITEMS_PER_THREAD]; + int qk_offset; __shared__ float s_mean, s_max; - float local_max = -1e20f; - for (int i = 0; blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; i++) { - qk_offset = ((blockIdx.y * head_num + 
blockIdx.z) * seq_len + seq_id) * (seq_len / 2) + blockDim.x * i + float local_max = -1e20f; + for (int i = 0; blockDim.x * i + threadIdx.x < (seq_len_2 / 2) && i < ITEMS_PER_THREAD; i++) { + qk_offset = ((blockIdx.y * head_num + blockIdx.z) * seq_len_1 + seq_id) * (seq_len_2 / 2) + blockDim.x * i + threadIdx.x; - int mask_offset = (blockIdx.y * seq_len + seq_id) * (seq_len / 2) + blockDim.x * i + threadIdx.x; + int mask_offset = (blockIdx.y * seq_len_1 + seq_id) * (seq_len_2 / 2) + blockDim.x * i + threadIdx.x; - T2 qk = qk_buf_half2[qk_offset]; + T2 qk = qk_buf_half2[qk_offset]; T2 mask_val = ldg(&attr_mask_half2[mask_offset]); - mask_val = hmul2(hsub2(float2type2(1.0f), mask_val), float2type2(-10000.0f)); + mask_val = hmul2(hsub2(float2type2(1.0f), mask_val), float2type2(-10000.0f)); data[i] = hadd2(hmul2(qk, type2type2(scalar)), mask_val); @@ -277,7 +300,7 @@ __global__ void softmax_kernel_v4_half2( __syncthreads(); float local_sum = 0; - for (int i = 0; blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; i++) { + for (int i = 0; blockDim.x * i + threadIdx.x < (seq_len_2 / 2) && i < ITEMS_PER_THREAD; i++) { data[i] = hexp2(hsub2(data[i], float2type2(s_max))); local_sum += (float)(data[i].x + data[i].y); } @@ -290,8 +313,8 @@ __global__ void softmax_kernel_v4_half2( } __syncthreads(); - for (int i = 0; blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; i++) { - qk_offset = ((blockIdx.y * head_num + blockIdx.z) * seq_len + seq_id) * (seq_len / 2) + blockDim.x * i + for (int i = 0; blockDim.x * i + threadIdx.x < (seq_len_2 / 2) && i < ITEMS_PER_THREAD; i++) { + qk_offset = ((blockIdx.y * head_num + blockIdx.z) * seq_len_1 + seq_id) * (seq_len_2 / 2) + blockDim.x * i + threadIdx.x; qk_buf_half2[qk_offset] = hmul2(data[i], float2type2(s_mean)); } @@ -299,55 +322,62 @@ __global__ void softmax_kernel_v4_half2( } template -__global__ void softmax_kernel_v5_half2( - T* qk_buf_, const T* attr_mask, const int batch_size, const int head_num, const int seq_len, const T scalar) +__global__ void softmax_kernel_v5_half2(T* qk_buf_, + const T* attr_mask, + const int batch_size, + const int head_num, + const int seq_len_1, + const int seq_len_2, + const T scalar) { - using T2 = typename TypeConverter::Type; - T2* qk_buf_half2 = (T2*)qk_buf_; + using T2 = typename TypeConverter::Type; + T2* qk_buf_half2 = (T2*)qk_buf_; const T2* attr_mask_half2 = (const T2*)attr_mask; - for (int seq_id = blockIdx.x; seq_id < seq_len; seq_id += gridDim.x * NUM) { + for (int seq_id = blockIdx.x; seq_id < seq_len_1; seq_id += gridDim.x * NUM) { T2 data[NUM][ITEMS_PER_THREAD]; int qk_offset[NUM]; __shared__ float s_sum[NUM], s_max[NUM]; - float local_max[NUM]; + float local_max[NUM]; #pragma unroll for (int j = 0; j < NUM; j++) { local_max[j] = -1e20f; } - for (int i = 0; blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; i++) { + const int MAX_NUM = min((seq_len_1 - seq_id + gridDim.x - 1) / gridDim.x, NUM); + for (int i = 0; blockDim.x * i + threadIdx.x < (seq_len_2 / 2) && i < ITEMS_PER_THREAD; i++) { int mask_offset[NUM]; #pragma unroll - for (int j = 0; j < NUM; j++) { - qk_offset[j] = ((blockIdx.y * head_num + blockIdx.z) * seq_len + seq_id + j * gridDim.x) * (seq_len / 2) - + blockDim.x * i + threadIdx.x; + for (int j = 0; j < MAX_NUM; j++) { + qk_offset[j] = + ((blockIdx.y * head_num + blockIdx.z) * seq_len_1 + seq_id + j * gridDim.x) * (seq_len_2 / 2) + + blockDim.x * i + threadIdx.x; mask_offset[j] = - (blockIdx.y * seq_len + seq_id + j * gridDim.x) * (seq_len 
/ 2) + blockDim.x * i + threadIdx.x; + (blockIdx.y * seq_len_1 + seq_id + j * gridDim.x) * (seq_len_2 / 2) + blockDim.x * i + threadIdx.x; } T2 mask_val[NUM]; #pragma unroll - for (int j = 0; j < NUM; j++) { + for (int j = 0; j < MAX_NUM; j++) { mask_val[j] = ldg(&attr_mask_half2[mask_offset[j]]); } T2 qk[NUM]; #pragma unroll - for (int j = 0; j < NUM; j++) { + for (int j = 0; j < MAX_NUM; j++) { qk[j] = qk_buf_half2[qk_offset[j]]; } #pragma unroll - for (int j = 0; j < NUM; j++) { + for (int j = 0; j < MAX_NUM; j++) { mask_val[j] = hmul2(hsub2(float2type2(1.0f), mask_val[j]), float2type2(-10000.0f)); } #pragma unroll - for (int j = 0; j < NUM; j++) { - data[j][i] = hadd2(hmul2(qk[j], type2type2(scalar)), mask_val[j]); + for (int j = 0; j < MAX_NUM; j++) { + data[j][i] = hadd2(hmul2(qk[j], type2type2(scalar)), mask_val[j]); local_max[j] = fmax(local_max[j], fmax((float)data[j][i].x, (float)data[j][i].y)); } } @@ -373,14 +403,14 @@ __global__ void softmax_kernel_v5_half2( local_sum[j] = {0.f}; } - for (int i = 0; blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; i++) { + for (int i = 0; blockDim.x * i + threadIdx.x < (seq_len_2 / 2) && i < ITEMS_PER_THREAD; i++) { #pragma unroll - for (int j = 0; j < NUM; j++) { + for (int j = 0; j < MAX_NUM; j++) { data[j][i] = hexp2(hsub2(data[j][i], float2type2(s_max[j]))); } #pragma unroll - for (int j = 0; j < NUM; j++) { + for (int j = 0; j < MAX_NUM; j++) { local_sum[j] += (float)(data[j][i].x + data[j][i].y); } } @@ -400,15 +430,16 @@ __global__ void softmax_kernel_v5_half2( } __syncthreads(); - for (int i = 0; blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; i++) { + for (int i = 0; blockDim.x * i + threadIdx.x < (seq_len_2 / 2) && i < ITEMS_PER_THREAD; i++) { #pragma unroll - for (int j = 0; j < NUM; j++) { - qk_offset[j] = ((blockIdx.y * head_num + blockIdx.z) * seq_len + seq_id + j * gridDim.x) * (seq_len / 2) - + blockDim.x * i + threadIdx.x; + for (int j = 0; j < MAX_NUM; j++) { + qk_offset[j] = + ((blockIdx.y * head_num + blockIdx.z) * seq_len_1 + seq_id + j * gridDim.x) * (seq_len_2 / 2) + + blockDim.x * i + threadIdx.x; } #pragma unroll - for (int j = 0; j < NUM; j++) { + for (int j = 0; j < MAX_NUM; j++) { qk_buf_half2[qk_offset[j]] = hmul2(data[j][i], float2type2(s_sum[j])); } } @@ -421,17 +452,27 @@ __global__ void softmax_kernel_v5_half2( if (is_half2) { \ if (grid.x % 4 == 0) { \ grid.x /= 4; \ - softmax_kernel_v5_half2<<>>( \ - (half*)buffer, (const half*)attr_mask, batch_size, head_num, seq_len, (const half)scalar); \ + softmax_kernel_v5_half2<<>>((half*)buffer, \ + (const half*)attr_mask, \ + batch_size, \ + head_num, \ + seq_len_1, \ + seq_len_2, \ + (const half)scalar); \ } \ else { \ - softmax_kernel_v4_half2<<>>( \ - (half*)buffer, (const half*)attr_mask, batch_size, head_num, seq_len, (const half)scalar); \ + softmax_kernel_v4_half2<<>>((half*)buffer, \ + (const half*)attr_mask, \ + batch_size, \ + head_num, \ + seq_len_1, \ + seq_len_2, \ + (const half)scalar); \ } \ } \ else { \ - softmax_kernel_v4 \ - <<>>(buffer, buffer_src, attr_mask, batch_size, head_num, seq_len, scalar); \ + softmax_kernel_v4<<>>( \ + buffer, buffer_src, attr_mask, batch_size, head_num, seq_len_1, seq_len_2, scalar); \ } #ifdef ENABLE_BF16 @@ -446,7 +487,8 @@ __global__ void softmax_kernel_v5_half2( (const __nv_bfloat16*)attr_mask, \ batch_size, \ head_num, \ - seq_len, \ + seq_len_1, \ + seq_len_2, \ (const __nv_bfloat16)scalar); \ } \ else { \ @@ -455,34 +497,37 @@ __global__ void softmax_kernel_v5_half2( 
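// NOTE: with the seq_len_1 / seq_len_2 split, the score buffer is a rectangular
// (batch_size, head_num, seq_len_1, seq_len_2) tensor and the softmax runs along seq_len_2,
// so row and column lengths may differ (e.g. cross-attention). In half2 units the offsets
// used in the kernels above are
//     qk_offset   = ((b * head_num + h) * seq_len_1 + s1) * (seq_len_2 / 2) + col
//     mask_offset = ( b              * seq_len_1 + s1) * (seq_len_2 / 2) + col
// i.e. the attention mask is broadcast over heads but not over the batch dimension.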
(const __nv_bfloat16*)attr_mask, \ batch_size, \ head_num, \ - seq_len, \ + seq_len_1, \ + seq_len_2, \ (const __nv_bfloat16)scalar); \ } \ } \ else { \ - softmax_kernel_v4 \ - <<>>(buffer, buffer_src, attr_mask, batch_size, head_num, seq_len, scalar); \ + softmax_kernel_v4<<>>( \ + buffer, buffer_src, attr_mask, batch_size, head_num, seq_len_1, seq_len_2, scalar); \ } #endif // ENABLE_BF16 template -void invokeMaskedSoftMax(T* buffer, - const T_IN* buffer_src, - const T* attr_mask, - const int batch_size, - const int seq_len, - const int head_num, - const T scalar, +void invokeMaskedSoftMax(T* buffer, + const T_IN* buffer_src, + const T* attr_mask, + const int batch_size, + const int seq_len_1, + const int seq_len_2, + const int head_num, + const T scalar, cudaStream_t stream) { + // NOTE: attention scores shape (batch_size, head_num, seq_len_1, seq_len_2) - dim3 grid(seq_len, batch_size, head_num); + dim3 grid(seq_len_1, batch_size, head_num); if (batch_size * head_num > 360) { - grid.x = ceil(float(seq_len) / 32.0f); + grid.x = ceil(float(seq_len_1) / 32.0f); } - bool is_half2 = sizeof(T) == 2 && sizeof(T_IN) == 2 && seq_len % 2 == 0; - dim3 block((seq_len / (is_half2 ? 2 : 1) + 31) / 32 * 32); + bool is_half2 = sizeof(T) == 2 && sizeof(T_IN) == 2 && seq_len_2 % 2 == 0; + dim3 block((seq_len_2 / (is_half2 ? 2 : 1) + 31) / 32 * 32); if (block.x > 3072 && block.x <= 4096) { SOFTMAX_KERNEL(4) @@ -497,30 +542,31 @@ void invokeMaskedSoftMax(T* buffer, SOFTMAX_KERNEL(1) } else { - FT_CHECK(seq_len <= 4096); + FT_CHECK(seq_len_2 <= 4096); } } #ifdef ENABLE_BF16 template<> -void invokeMaskedSoftMax(__nv_bfloat16* buffer, +void invokeMaskedSoftMax(__nv_bfloat16* buffer, const __nv_bfloat16* buffer_src, const __nv_bfloat16* attr_mask, - const int batch_size, - const int seq_len, - const int head_num, - const __nv_bfloat16 scalar, - cudaStream_t stream) + const int batch_size, + const int seq_len_1, + const int seq_len_2, + const int head_num, + const __nv_bfloat16 scalar, + cudaStream_t stream) { using T_IN = __nv_bfloat16; - dim3 grid(seq_len, batch_size, head_num); + dim3 grid(seq_len_1, batch_size, head_num); if (batch_size * head_num > 360) { - grid.x = ceil(float(seq_len) / 32.0f); + grid.x = ceil(float(seq_len_1) / 32.0f); } - bool is_half2 = seq_len % 2 == 0; - dim3 block((seq_len / (is_half2 ? 2 : 1) + 31) / 32 * 32); + bool is_half2 = seq_len_2 % 2 == 0; + dim3 block((seq_len_2 / (is_half2 ? 2 : 1) + 31) / 32 * 32); if (block.x > 3072 && block.x <= 4096) { SOFTMAX_KERNEL_BF16(4) @@ -535,28 +581,29 @@ void invokeMaskedSoftMax(__nv_bfloat16* buffer, SOFTMAX_KERNEL_BF16(1) } else { - FT_CHECK(seq_len <= 4096); + FT_CHECK(seq_len_2 <= 4096); } } template<> -void invokeMaskedSoftMax(__nv_bfloat16* buffer, - const float* buffer_src, +void invokeMaskedSoftMax(__nv_bfloat16* buffer, + const float* buffer_src, const __nv_bfloat16* attr_mask, - const int batch_size, - const int seq_len, - const int head_num, - const __nv_bfloat16 scalar, - cudaStream_t stream) + const int batch_size, + const int seq_len_1, + const int seq_len_2, + const int head_num, + const __nv_bfloat16 scalar, + cudaStream_t stream) { using T_IN = float; - dim3 grid(seq_len, batch_size, head_num); + dim3 grid(seq_len_1, batch_size, head_num); if (batch_size * head_num > 360) { - grid.x = ceil(float(seq_len) / 32.0f); + grid.x = ceil(float(seq_len_1) / 32.0f); } bool is_half2 = false; - dim3 block((seq_len / (is_half2 ? 2 : 1) + 31) / 32 * 32); + dim3 block((seq_len_2 / (is_half2 ? 
2 : 1) + 31) / 32 * 32); if (block.x > 3072 && block.x <= 4096) { SOFTMAX_KERNEL_BF16(4) @@ -571,65 +618,70 @@ void invokeMaskedSoftMax(__nv_bfloat16* buffer, SOFTMAX_KERNEL_BF16(1) } else { - FT_CHECK(seq_len <= 4096); + FT_CHECK(seq_len_2 <= 4096); } } #endif // ENABLE_BF16 -template void invokeMaskedSoftMax(float* buffer, +template void invokeMaskedSoftMax(float* buffer, const float* buffer_src, const float* attr_mask, - const int batch_size, - const int seq_len, - const int head_num, - const float scalar, + const int batch_size, + const int seq_len_1, + const int seq_len_2, + const int head_num, + const float scalar, cudaStream_t stream); -template void invokeMaskedSoftMax(half* buffer, +template void invokeMaskedSoftMax(half* buffer, const float* buffer_src, - const half* attr_mask, - const int batch_size, - const int seq_len, - const int head_num, - const half scalar, + const half* attr_mask, + const int batch_size, + const int seq_len_1, + const int seq_len_2, + const int head_num, + const half scalar, cudaStream_t stream); -template void invokeMaskedSoftMax(half* buffer, - const half* buffer_src, - const half* attr_mask, - const int batch_size, - const int seq_len, - const int head_num, - const half scalar, +template void invokeMaskedSoftMax(half* buffer, + const half* buffer_src, + const half* attr_mask, + const int batch_size, + const int seq_len_1, + const int seq_len_2, + const int head_num, + const half scalar, cudaStream_t stream); #ifdef ENABLE_BF16 -template void invokeMaskedSoftMax(__nv_bfloat16* buffer, +template void invokeMaskedSoftMax(__nv_bfloat16* buffer, const __nv_bfloat16* buffer_src, const __nv_bfloat16* attr_mask, - const int batch_size, - const int seq_len, - const int head_num, - const __nv_bfloat16 scalar, - cudaStream_t stream); - -template void invokeMaskedSoftMax(__nv_bfloat16* buffer, - const float* buffer_src, + const int batch_size, + const int seq_len_1, + const int seq_len_2, + const int head_num, + const __nv_bfloat16 scalar, + cudaStream_t stream); + +template void invokeMaskedSoftMax(__nv_bfloat16* buffer, + const float* buffer_src, const __nv_bfloat16* attr_mask, - const int batch_size, - const int seq_len, - const int head_num, - const __nv_bfloat16 scalar, - cudaStream_t stream); + const int batch_size, + const int seq_len_1, + const int seq_len_2, + const int head_num, + const __nv_bfloat16 scalar, + cudaStream_t stream); #endif // ENABLE_BF16 template __global__ void transpose(T* src, T* dst, const int batch_size, const int seq_len, const int head_num, const int size_per_head) { - int batch_id = blockIdx.x / (head_num * seq_len); - int seq_id = blockIdx.x % seq_len; - int head_id = (blockIdx.x % (head_num * seq_len)) / seq_len; + int batch_id = blockIdx.x / (head_num * seq_len); + int seq_id = blockIdx.x % seq_len; + int head_id = (blockIdx.x % (head_num * seq_len)) / seq_len; dst[batch_id * (head_num * seq_len * size_per_head) + seq_id * head_num * size_per_head + head_id * size_per_head + threadIdx.x] = src[blockIdx.x * size_per_head + threadIdx.x]; } @@ -641,9 +693,9 @@ transpose(half* src, half* dst, const int batch_size, const int seq_len, const i int tid = blockIdx.x * blockDim.x + threadIdx.x; int batch_id = tid / (head_num * seq_len * size_per_head); - int head_id = (tid % (head_num * seq_len * size_per_head)) / (seq_len * size_per_head); - int seq_id = (tid % (seq_len * size_per_head)) / size_per_head; - int id = tid % size_per_head; + int head_id = (tid % (head_num * seq_len * size_per_head)) / (seq_len * size_per_head); + int seq_id 
= (tid % (seq_len * size_per_head)) / size_per_head; + int id = tid % size_per_head; int target_id = target_index(batch_id, head_id, seq_id, id, batch_size, head_num, seq_len, size_per_head); @@ -657,9 +709,9 @@ transpose(half2* src, half2* dst, const int batch_size, const int seq_len, const int tid = blockIdx.x * blockDim.x + threadIdx.x; int batch_id = tid / (head_num * seq_len * size_per_head); - int head_id = (tid % (head_num * seq_len * size_per_head)) / (seq_len * size_per_head); - int seq_id = (tid % (seq_len * size_per_head)) / size_per_head; - int id = tid % size_per_head; + int head_id = (tid % (head_num * seq_len * size_per_head)) / (seq_len * size_per_head); + int seq_id = (tid % (seq_len * size_per_head)) / size_per_head; + int id = tid % size_per_head; int target_id = target_index(batch_id, head_id, seq_id, id, batch_size, head_num, seq_len, size_per_head); @@ -670,17 +722,17 @@ transpose(half2* src, half2* dst, const int batch_size, const int seq_len, const template<> __global__ void transpose(__nv_bfloat16* src, __nv_bfloat16* dst, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head) + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head) { int tid = blockIdx.x * blockDim.x + threadIdx.x; int batch_id = tid / (head_num * seq_len * size_per_head); - int head_id = (tid % (head_num * seq_len * size_per_head)) / (seq_len * size_per_head); - int seq_id = (tid % (seq_len * size_per_head)) / size_per_head; - int id = tid % size_per_head; + int head_id = (tid % (head_num * seq_len * size_per_head)) / (seq_len * size_per_head); + int seq_id = (tid % (seq_len * size_per_head)) / size_per_head; + int id = tid % size_per_head; int target_id = target_index(batch_id, head_id, seq_id, id, batch_size, head_num, seq_len, size_per_head); @@ -690,17 +742,17 @@ __global__ void transpose(__nv_bfloat16* src, template<> __global__ void transpose(__nv_bfloat162* src, __nv_bfloat162* dst, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head) + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head) { int tid = blockIdx.x * blockDim.x + threadIdx.x; int batch_id = tid / (head_num * seq_len * size_per_head); - int head_id = (tid % (head_num * seq_len * size_per_head)) / (seq_len * size_per_head); - int seq_id = (tid % (seq_len * size_per_head)) / size_per_head; - int id = tid % size_per_head; + int head_id = (tid % (head_num * seq_len * size_per_head)) / (seq_len * size_per_head); + int seq_id = (tid % (seq_len * size_per_head)) / size_per_head; + int id = tid % size_per_head; int target_id = target_index(batch_id, head_id, seq_id, id, batch_size, head_num, seq_len, size_per_head); @@ -709,24 +761,24 @@ __global__ void transpose(__nv_bfloat162* src, #endif template -void invokeTransposeQKV(T* dst, - T* src, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, +void invokeTransposeQKV(T* dst, + T* src, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, cudaStream_t stream) { dim3 grid, block; if (sizeof(T) == 2) { int seq_per_block = 1; - grid.x = batch_size * head_num * seq_len / seq_per_block; + grid.x = batch_size * head_num * seq_len / seq_per_block; while (seq_per_block < 4 && grid.x % 2 == 0) { grid.x /= 2; seq_per_block *= 2; } - FT_CHECK(grid.x * seq_per_block == batch_size * head_num * seq_len); + FT_CHECK(grid.x * seq_per_block == (size_t)batch_size * 
head_num * seq_len); if (seq_per_block * size_per_head % 2 == 0) { block.x = seq_per_block * size_per_head / 2; @@ -748,124 +800,129 @@ void invokeTransposeQKV(T* dst, } else { const int seq_per_block = 1; - grid.x = batch_size * head_num * seq_len / seq_per_block; - block.x = seq_per_block * size_per_head; + grid.x = batch_size * head_num * seq_len / seq_per_block; + block.x = seq_per_block * size_per_head; transpose<<>>(src, dst, batch_size, seq_len, head_num, size_per_head); } } -template void invokeTransposeQKV(float* src, - float* dst, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, +template void invokeTransposeQKV(float* src, + float* dst, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, cudaStream_t stream); -template void invokeTransposeQKV(half* src, - half* dst, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, +template void invokeTransposeQKV(half* src, + half* dst, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, cudaStream_t stream); #ifdef ENABLE_BF16 template void invokeTransposeQKV(__nv_bfloat16* src, __nv_bfloat16* dst, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - cudaStream_t stream); + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + cudaStream_t stream); #endif template -__global__ void add_QKV_bias_rebuild_padding(const T* Q, - const T* bias_Q, - const T* K, - const T* bias_K, - const T* V, - const T* bias_V, - T* q_buf_, - T* k_buf_, - T* v_buf_, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, +__global__ void add_QKV_bias_rebuild_padding(const T* Q, + const T* bias_Q, + const T* K, + const T* bias_K, + const T* V, + const T* bias_V, + T* q_buf_, + T* k_buf_, + T* v_buf_, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, const int* mask_offset) { const int bid = blockIdx.x; const int tgt_batch_id = (bid + mask_offset[bid]) / seq_len; - const int tgt_seq_id = (bid + mask_offset[bid]) % seq_len; - const int n = head_num * size_per_head; + const int tgt_seq_id = (bid + mask_offset[bid]) % seq_len; + const int n = head_num * size_per_head; for (int idx = threadIdx.x; idx < n; idx += blockDim.x) { - const int tgt_head_id = idx / size_per_head; + const int tgt_head_id = idx / size_per_head; const int tgt_hidden_id = idx % size_per_head; const int src_id = bid * n + idx; const int tgt_id = tgt_batch_id * head_num * seq_len * size_per_head + tgt_head_id * seq_len * size_per_head + tgt_seq_id * size_per_head + tgt_hidden_id; - q_buf_[tgt_id] = __ldg(&Q[src_id]) + __ldg(&bias_Q[idx]); - k_buf_[tgt_id] = __ldg(&K[src_id]) + __ldg(&bias_K[idx]); - v_buf_[tgt_id] = __ldg(&V[src_id]) + __ldg(&bias_V[idx]); + q_buf_[tgt_id] = add(ldg(&Q[src_id]), ldg(&bias_Q[idx])); + k_buf_[tgt_id] = add(ldg(&K[src_id]), ldg(&bias_K[idx])); + v_buf_[tgt_id] = add(ldg(&V[src_id]), ldg(&bias_V[idx])); } } template -__global__ void rebuild_padding(const T* Q, - const T* K, - const T* V, - T* q_buf_, - T* k_buf_, - T* v_buf_, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, +__global__ void rebuild_padding(const T* Q, + const T* K, + const T* V, + T* q_buf_, + T* k_buf_, + T* v_buf_, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, const int* 
mask_offset) { const int bid = blockIdx.x; const int tgt_batch_id = (bid + mask_offset[bid]) / seq_len; - const int tgt_seq_id = (bid + mask_offset[bid]) % seq_len; - const int n = head_num * size_per_head; + const int tgt_seq_id = (bid + mask_offset[bid]) % seq_len; + const int n = head_num * size_per_head; for (int idx = threadIdx.x; idx < n; idx += blockDim.x) { - const int tgt_head_id = idx / size_per_head; + const int tgt_head_id = idx / size_per_head; const int tgt_hidden_id = idx % size_per_head; const int src_id = bid * n + idx; const int tgt_id = tgt_batch_id * head_num * seq_len * size_per_head + tgt_head_id * seq_len * size_per_head + tgt_seq_id * size_per_head + tgt_hidden_id; - q_buf_[tgt_id] = __ldg(&Q[src_id]); - k_buf_[tgt_id] = __ldg(&K[src_id]); - v_buf_[tgt_id] = __ldg(&V[src_id]); + q_buf_[tgt_id] = ldg(&Q[src_id]); + k_buf_[tgt_id] = ldg(&K[src_id]); + v_buf_[tgt_id] = ldg(&V[src_id]); } } template -void invokeAddQKVBiasRebuildPadding(T* Q, - const T* bias_Q, - T* K, - const T* bias_K, - T* V, - const T* bias_V, - T* q_buf, - T* k_buf, - T* v_buf, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const int valid_word_num, - const int* mask_offset, +void invokeAddQKVBiasRebuildPadding(T* Q, + const T* bias_Q, + T* K, + const T* bias_K, + T* V, + const T* bias_V, + T* q_buf, + T* k_buf, + T* v_buf, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const int valid_word_num, + const int* mask_offset, cudaStream_t stream) { - bool is_half2 = std::is_same::value && (size_per_head % 2 == 0); +#ifdef ENABLE_BF16 + bool is_half2 = (std::is_same::value || std::is_same::value) && (size_per_head % 2 == 0); +#else + bool is_half2 = (std::is_same::value) && (size_per_head % 2 == 0); +#endif + using T2 = typename TypeConverter::Type; // fp16 to half2, bf16 to bf162 int block_size = head_num * size_per_head; if (is_half2) { while (block_size > 512) { @@ -873,7 +930,7 @@ void invokeAddQKVBiasRebuildPadding(T* Q, block_size /= 2; } else { - is_half2 = false; + is_half2 = false; block_size = std::min(block_size, 512); break; } @@ -885,12 +942,12 @@ void invokeAddQKVBiasRebuildPadding(T* Q, if (bias_Q == nullptr && bias_K == nullptr && bias_V == nullptr) { if (is_half2) { - rebuild_padding<<>>((half2*)Q, - (half2*)K, - (half2*)V, - (half2*)q_buf, - (half2*)k_buf, - (half2*)v_buf, + rebuild_padding<<>>((T2*)Q, + (T2*)K, + (T2*)V, + (T2*)q_buf, + (T2*)k_buf, + (T2*)v_buf, batch_size, seq_len, head_num, @@ -904,15 +961,15 @@ void invokeAddQKVBiasRebuildPadding(T* Q, } else if (bias_Q != nullptr && bias_K != nullptr && bias_V != nullptr) { if (is_half2) { - add_QKV_bias_rebuild_padding<<>>((half2*)Q, - (const half2*)bias_Q, - (half2*)K, - (const half2*)bias_K, - (half2*)V, - (const half2*)bias_V, - (half2*)q_buf, - (half2*)k_buf, - (half2*)v_buf, + add_QKV_bias_rebuild_padding<<>>((T2*)Q, + (const T2*)bias_Q, + (T2*)K, + (const T2*)bias_K, + (T2*)V, + (const T2*)bias_V, + (T2*)q_buf, + (T2*)k_buf, + (T2*)v_buf, batch_size, seq_len, head_num, @@ -941,47 +998,66 @@ void invokeAddQKVBiasRebuildPadding(T* Q, } } -template void invokeAddQKVBiasRebuildPadding(float* Q, +template void invokeAddQKVBiasRebuildPadding(float* Q, const float* bias_Q, - float* K, + float* K, const float* bias_K, - float* V, + float* V, const float* bias_V, - float* q_buf, - float* k_buf, - float* v_buf, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const int valid_word_num, - const 
int* mask_offset, + float* q_buf, + float* k_buf, + float* v_buf, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const int valid_word_num, + const int* mask_offset, cudaStream_t stream); -template void invokeAddQKVBiasRebuildPadding(half* Q, - const half* bias_Q, - half* K, - const half* bias_K, - half* V, - const half* bias_V, - half* q_buf, - half* k_buf, - half* v_buf, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const int valid_word_num, - const int* mask_offset, +template void invokeAddQKVBiasRebuildPadding(half* Q, + const half* bias_Q, + half* K, + const half* bias_K, + half* V, + const half* bias_V, + half* q_buf, + half* k_buf, + half* v_buf, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const int valid_word_num, + const int* mask_offset, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeAddQKVBiasRebuildPadding(__nv_bfloat16* Q, + const __nv_bfloat16* bias_Q, + __nv_bfloat16* K, + const __nv_bfloat16* bias_K, + __nv_bfloat16* V, + const __nv_bfloat16* bias_V, + __nv_bfloat16* q_buf, + __nv_bfloat16* k_buf, + __nv_bfloat16* v_buf, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const int valid_word_num, + const int* mask_offset, + cudaStream_t stream); +#endif + template -__global__ void transpose_remove_padding(const T* src, - T* dst, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, +__global__ void transpose_remove_padding(const T* src, + T* dst, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, const int* mask_offset) { // TODO: optimize this kernel? 
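// NOTE: the rebuild-padding / remove-padding kernels in this file recover a compact token's
// padded position as (bid + mask_offset[bid]), where mask_offset[bid] counts the padding
// slots that precede valid token bid. A minimal host-side sketch of how such an offset table
// can be built from per-sequence lengths is shown below; the helper is illustrative only and
// is not part of this file.

#include <vector>

static std::vector<int> build_padding_offsets(const std::vector<int>& seq_lengths, int max_seq_len)
{
    std::vector<int> offsets;  // one entry per valid (non-padded) token
    int cum_pad = 0;
    for (int len : seq_lengths) {
        for (int i = 0; i < len; ++i) {
            offsets.push_back(cum_pad);  // padded index = compact index + cum_pad
        }
        cum_pad += max_seq_len - len;  // padding sits at the tail of each sequence
    }
    return offsets;
}

// Example: max_seq_len = 4, seq_lengths = {3, 2}  ->  offsets = {0, 0, 0, 1, 1};
// compact bid = 3 then maps to padded index 3 + 1 = 4, i.e. batch_id = 1, seq_id = 0.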
@@ -989,31 +1065,37 @@ __global__ void transpose_remove_padding(const T* src, const int bid = blockIdx.x; // batch * seq_len or valid_word_num const int src_batch_id = (bid + mask_offset[bid]) / seq_len; - const int src_seq_id = (bid + mask_offset[bid]) % seq_len; + const int src_seq_id = (bid + mask_offset[bid]) % seq_len; const int dst_seq_id = bid; + const int src_offset_base = src_batch_id * seq_len * head_num * size_per_head + src_seq_id * size_per_head; + const int dst_offset_base = dst_seq_id * head_num * size_per_head; + for (int idx = threadIdx.x; idx < head_num * size_per_head; idx += blockDim.x) { - const int head_id = idx / size_per_head; - const int hidden_id = idx % size_per_head; - dst[dst_seq_id * head_num * size_per_head + idx] = - __ldg(&src[src_batch_id * head_num * seq_len * size_per_head + head_id * seq_len * size_per_head - + src_seq_id * size_per_head + hidden_id]); + const int head_id = idx / size_per_head; + const int hidden_id = idx % size_per_head; + dst[dst_offset_base + idx] = ldg(&src[src_offset_base + head_id * seq_len * size_per_head + hidden_id]); } } template -void invokeTransposeAttentionOutRemovePadding(T* src, - T* dst, - const int valid_word_num, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const int* mask_offset, +void invokeTransposeAttentionOutRemovePadding(T* src, + T* dst, + const int valid_word_num, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const int* mask_offset, cudaStream_t stream) { - bool is_half2 = std::is_same::value && (size_per_head % 2 == 0); +#ifdef ENABLE_BF16 + bool is_half2 = (std::is_same::value || std::is_same::value) && (size_per_head % 2 == 0); +#else + bool is_half2 = (std::is_same::value) && (size_per_head % 2 == 0); +#endif + using T2 = typename TypeConverter::Type; // fp16 to half2, bf16 to bf162 int block_size = head_num * size_per_head; if (is_half2) { while (block_size > 512) { @@ -1021,7 +1103,7 @@ void invokeTransposeAttentionOutRemovePadding(T* src, block_size /= 2; } else { - is_half2 = false; + is_half2 = false; block_size = std::min(block_size, 1024); break; } @@ -1032,8 +1114,8 @@ void invokeTransposeAttentionOutRemovePadding(T* src, } if (is_half2) { - transpose_remove_padding<<>>( - (half2*)src, (half2*)dst, batch_size, seq_len, head_num, size_per_head / 2, mask_offset); + transpose_remove_padding<<>>( + (T2*)src, (T2*)dst, batch_size, seq_len, head_num, size_per_head / 2, mask_offset); } else { transpose_remove_padding<<>>( @@ -1041,25 +1123,36 @@ void invokeTransposeAttentionOutRemovePadding(T* src, } } -template void invokeTransposeAttentionOutRemovePadding(float* src, - float* dst, - const int valid_word_num, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const int* mask_offset, +template void invokeTransposeAttentionOutRemovePadding(float* src, + float* dst, + const int valid_word_num, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const int* mask_offset, cudaStream_t stream); -template void invokeTransposeAttentionOutRemovePadding(half* src, - half* dst, - const int valid_word_num, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const int* mask_offset, +template void invokeTransposeAttentionOutRemovePadding(half* src, + half* dst, + const int valid_word_num, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const 
int* mask_offset, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeTransposeAttentionOutRemovePadding(__nv_bfloat16* src, + __nv_bfloat16* dst, + const int valid_word_num, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const int* mask_offset, + cudaStream_t stream); +#endif template __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, @@ -1067,31 +1160,32 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, T* v_buf, const T* __restrict QKV, const T* __restrict qkv_bias, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head) + const int* padding_offset, + const int batch_size, + const int seq_len, + const int token_num, + const int head_num, + const int size_per_head) { - // QKV: [m, 3, n] + // QKV: [token_num, 3, n] // qkv_bias: [3, n] // q_buf, k_buf, v_buf: [batch, head_num, seq_len, size_per_head] - T* qkv_ptr[3] = {q_buf, k_buf, v_buf}; - const int n = head_num * size_per_head; - for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < batch_size * seq_len * 3 * n; + T* qkv_ptr[3] = {q_buf, k_buf, v_buf}; + const int n = head_num * size_per_head; + for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < token_num * 3 * n; index += gridDim.x * blockDim.x) { int bias_id = index % (3 * n); - T val = ldg(&QKV[index]) + ldg(&qkv_bias[bias_id]); - - int tmp_index = index; - const int target_batch_id = tmp_index / (seq_len * 3 * n); - tmp_index -= target_batch_id * seq_len * 3 * n; - const int seq_id = tmp_index / (3 * n); - tmp_index -= seq_id * 3 * n; - const int qkv_id = tmp_index / n; - tmp_index -= qkv_id * n; - const int head_id = tmp_index / size_per_head; - const int size_id = tmp_index - head_id * size_per_head; + T val = ldg(&QKV[index]) + ldg(&qkv_bias[bias_id]); + + const int token_idx = index / (3 * n); + const int token_padded_idx = token_idx + (padding_offset == nullptr ? 
0 : padding_offset[token_idx]); + const int target_batch_id = token_padded_idx / seq_len; + const int seq_id = token_padded_idx % seq_len; + + const int qkv_id = (index % (3 * n)) / n; + const int head_id = (index % n) / size_per_head; + const int size_id = index % size_per_head; qkv_ptr[qkv_id][target_batch_id * head_num * seq_len * size_per_head + head_id * seq_len * size_per_head + seq_id * size_per_head + size_id] = val; @@ -1099,148 +1193,308 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, } template -struct Vec_t {}; +struct Vec_t { + static constexpr int size = 0; +}; + template<> struct Vec_t { - using Type = float2; + using Type = float2; + static constexpr int size = 2; }; + template<> struct Vec_t { - using Type = uint32_t; + using Type = uint32_t; + static constexpr int size = 2; }; #ifdef ENABLE_BF16 template<> struct Vec_t<__nv_bfloat16> { - using Type = __nv_bfloat162; + using Type = __nv_bfloat162; + static constexpr int size = 2; }; #endif -template -__global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, - T* k_buf, - T* v_buf, +template +__global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, + T* k_buf, + T* v_buf, + PrefixPromptBatchWeightsParam param, const T* __restrict QKV, const T* __restrict qkv_bias, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const int rotary_embedding_dim) + const int* padding_offset, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const int rotary_embedding_dim, + const bool neox_rotary_style) { - using Vec_t = typename Vec_t::Type; - const int batch_idx = blockIdx.z; + // This kernel add bias to QKV, which has shape [batch_size, seq_len, 3, head_num, size_per_head], and + // QKV split to 3 split buffer q, k, v and transpose them to [batch_size, head_num, seq_len, size_per_head]. + // For q and k, also apply the rotary embedding. + + // When we pass prefix prompt, this kernel also concatenate the prefix prompt and key/value along + // seq_len dimension like [prompt, key/value]. + // So, the final shape of q is same ([batch_size, head_num, seq_len, size_per_head]), but + // the shapes of key and values become [batch_size, head_num, max_prefix_prompt_length + seq_len, size_per_head]. + + // NOTE: QKV src shape (batch_size, seq_len, 3, head_num, size_per_head) + // QKV dst shape (3, batch_size, head_num, seq_len, size_per_head) + extern __shared__ __align__(sizeof(float2)) char smem_[]; // align on largest vector type + + constexpr int vec_size = Vec_t::size; + using Vec_t = typename Vec_t::Type; + const int token_idx = blockIdx.x - batch_size * param.max_prefix_prompt_length; + const int token_padding_offset = (padding_offset == nullptr || token_idx < 0) ? 
0 : padding_offset[token_idx]; + const int tgt_token_idx = token_idx + token_padding_offset; + + const int batch_idx = tgt_token_idx / seq_len; + const int seq_idx = tgt_token_idx % seq_len; + const int head_idx = blockIdx.y; - const int seq_idx = blockIdx.x; - const int tidx = threadIdx.x; - if (tidx * 2 >= size_per_head) { + const int tidx = threadIdx.x; + + const int total_seq_len = param.max_prefix_prompt_length + seq_len; + + const bool is_masked = tidx * vec_size >= size_per_head; + // NOTE: blockIdx.x < batch_size * param.max_prefix_prompt_length really handles prefix prompts + if (PREFIX_PROMPT && token_idx < 0) { + const int prompt_batch_idx = blockIdx.x / param.max_prefix_prompt_length; + const int prompt_seq_idx = blockIdx.x % param.max_prefix_prompt_length; + const int prompt_length = param.d_prefix_prompt_lengths[prompt_batch_idx]; + + if (prompt_seq_idx < prompt_length) { + const int dest_kv_idx = prompt_batch_idx * size_per_head * total_seq_len * head_num + + head_idx * size_per_head * total_seq_len + prompt_seq_idx * size_per_head + + tidx * vec_size; + const int prefix_kv_idx = + size_per_head * prompt_length * head_idx + size_per_head * prompt_seq_idx + tidx * vec_size; + + const T* prefix_prompt_k = param.d_prefix_prompt_batch[prompt_batch_idx] + + param.prefix_prompt_layer_offset_per_seq * prompt_length; + const T* prefix_prompt_v = prefix_prompt_k + prompt_length * head_num * size_per_head; + if (!is_masked) { + *reinterpret_cast(&k_buf[dest_kv_idx]) = + *reinterpret_cast(&prefix_prompt_k[prefix_kv_idx]); + *reinterpret_cast(&v_buf[dest_kv_idx]) = + *reinterpret_cast(&prefix_prompt_v[prefix_kv_idx]); + } + } return; } - const int batch_time_idx = seq_len * batch_idx + seq_idx; - const int hidden_idx = head_idx * size_per_head + tidx * 2; - const int n = head_num * size_per_head; + const int prefix_prompt_length = PREFIX_PROMPT ? 
param.d_prefix_prompt_lengths[batch_idx] : 0; + const int hidden_idx = head_idx * size_per_head + tidx * vec_size; + const int n = head_num * size_per_head; - // src QKV: [batch, time, 3, head, hidden] - const int q_idx = batch_time_idx * 3 * n + hidden_idx; - const int k_idx = batch_time_idx * 3 * n + hidden_idx + n; - const int v_idx = batch_time_idx * 3 * n + hidden_idx + 2 * n; + // the [0..seq_len) indices really handle KV [max_pp_len..seq_len+max_pp_len) + // and Q [0..seq_len) + // Note: if !PREFIX_PROMPT, max_pp_len = 0, so it's no-op + const int dst_kv_seq_idx = seq_idx + prefix_prompt_length; - Vec_t q = *reinterpret_cast(&QKV[q_idx]); - Vec_t k = *reinterpret_cast(&QKV[k_idx]); - Vec_t v = *reinterpret_cast(&QKV[v_idx]); + // NOTE: q has seq len excluding prefix prompt + const int batch_time_qkv_idx = seq_len * batch_idx + seq_idx; - // qkv_bias: [3, head, hidden] - Vec_t q_bias = *reinterpret_cast(&qkv_bias[hidden_idx]); - Vec_t k_bias = *reinterpret_cast(&qkv_bias[hidden_idx + n]); - Vec_t v_bias = *reinterpret_cast(&qkv_bias[hidden_idx + 2 * n]); + // src QKV: [batch, time, 3, head, hidden] + const int src_q_idx = batch_time_qkv_idx * 3 * n + hidden_idx; + const int src_k_idx = batch_time_qkv_idx * 3 * n + hidden_idx + n; + const int src_v_idx = batch_time_qkv_idx * 3 * n + hidden_idx + 2 * n; + + Vec_t q, k, v; + Vec_t q_bias, k_bias, v_bias; + if (!is_masked) { + q = *reinterpret_cast(&QKV[src_q_idx]); + k = *reinterpret_cast(&QKV[src_k_idx]); + v = *reinterpret_cast(&QKV[src_v_idx]); + + q_bias = *reinterpret_cast(&qkv_bias[hidden_idx]); + k_bias = *reinterpret_cast(&qkv_bias[hidden_idx + n]); + v_bias = *reinterpret_cast(&qkv_bias[hidden_idx + 2 * n]); + } q = mmha::add(q, q_bias); k = mmha::add(k, k_bias); v = mmha::add(v, v_bias); - mmha::apply_rotary_embedding(q, k, tidx, rotary_embedding_dim, seq_idx); + if (!neox_rotary_style) { + mmha::apply_rotary_embedding(q, k, tidx, rotary_embedding_dim, dst_kv_seq_idx); + } + else { + const bool do_rotary = !is_masked && vec_size * tidx < rotary_embedding_dim; - // q_buf, k_buf, v_buf: [batch, head_num, seq_len, size_per_head] - const int dest_idx = size_per_head * seq_len * head_num * batch_idx + size_per_head * seq_len * head_idx - + size_per_head * seq_idx + tidx * 2; + T* q_smem = reinterpret_cast(smem_); + T* k_smem = q_smem + rotary_embedding_dim; + + const int half_rotary_dim = rotary_embedding_dim / 2; + const int half_idx = (tidx * vec_size) / half_rotary_dim; + const int intra_half_idx = (tidx * vec_size) % half_rotary_dim; + const int smem_pitch = half_rotary_dim; // TODO: adjust for bank conflicts? 
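// NOTE: the NeoX-style path rotates pairs of the form (x[i], x[i + rotary_embedding_dim / 2])
// rather than the interleaved (x[2i], x[2i + 1]) pairs handled by the branch above. Each
// thread loads a contiguous vec_size chunk, so q and k are staged in shared memory below and
// re-read through vec_from_smem_transpose() so that a thread holds one such half-offset pair
// when apply_rotary_embedding() runs; write_smem_transpose() then restores the contiguous
// layout before the vectors are written out to q_buf / k_buf.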
+ + if (do_rotary) { + *reinterpret_cast(q_smem + half_idx * smem_pitch + intra_half_idx) = q; + *reinterpret_cast(k_smem + half_idx * smem_pitch + intra_half_idx) = k; + } + + __syncthreads(); + + const int transpose_idx = half_idx * (half_rotary_dim / 2) + intra_half_idx / 2; + constexpr int tidx_factor = vec_size / 2; + if (do_rotary) { + mmha::vec_from_smem_transpose(q, q_smem, transpose_idx, smem_pitch); + mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch); + + mmha::apply_rotary_embedding(q, k, transpose_idx / tidx_factor, rotary_embedding_dim, dst_kv_seq_idx); + + mmha::write_smem_transpose(q, q_smem, transpose_idx, smem_pitch); + mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch); + } + + __syncthreads(); + + if (do_rotary) { + q = *reinterpret_cast(q_smem + half_idx * smem_pitch + intra_half_idx); + k = *reinterpret_cast(k_smem + half_idx * smem_pitch + intra_half_idx); + } + } + + const int dest_q_idx = batch_idx * size_per_head * seq_len * head_num + head_idx * size_per_head * seq_len + + seq_idx * size_per_head + tidx * vec_size; - *reinterpret_cast(&q_buf[dest_idx]) = q; - *reinterpret_cast(&k_buf[dest_idx]) = k; - *reinterpret_cast(&v_buf[dest_idx]) = v; + const int dest_kv_idx = batch_idx * size_per_head * total_seq_len * head_num + + head_idx * size_per_head * total_seq_len + dst_kv_seq_idx * size_per_head + + tidx * vec_size; + + if (!is_masked) { + *reinterpret_cast(&q_buf[dest_q_idx]) = q; + *reinterpret_cast(&k_buf[dest_kv_idx]) = k; + *reinterpret_cast(&v_buf[dest_kv_idx]) = v; + } } +#define FUSED_QKV_BIAS_TRANSPOSE_LAUNCH(T, PREFIX_PROMPT) \ + add_fusedQKV_bias_transpose_kernel<<>>(q_buf, \ + k_buf, \ + v_buf, \ + param, \ + QKV, \ + qkv_bias, \ + padding_offset, \ + batch_size, \ + seq_len, \ + head_num, \ + size_per_head, \ + rotary_embedding_dim, \ + neox_rotary_style); + template -void invokeAddFusedQKVBiasTranspose(T* q_buf, - T* k_buf, - T* v_buf, - T* QKV, - const T* qkv_bias, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const int rotary_embedding_dim, - cudaStream_t stream) +void invokeAddFusedQKVBiasTranspose(T* q_buf, + T* k_buf, + T* v_buf, + PrefixPromptBatchWeightsParam param, + T* QKV, + const T* qkv_bias, + const int* padding_offset, + const int batch_size, + const int seq_len, + const int token_num, + const int head_num, + const int size_per_head, + const int rotary_embedding_dim, + const int neox_rotary_style, + cudaStream_t stream) { - if (rotary_embedding_dim == 0) { - const int m = batch_size * seq_len; + // [bs, seq_len, 3, head, Dh] + if (rotary_embedding_dim == 0 && param.max_prefix_prompt_length == 0) { + const int m = token_num; const int n = head_num * size_per_head; - dim3 block(384); - dim3 grid((int)(ceil(1.0 * m * n / 384))); - add_fusedQKV_bias_transpose_kernel<<>>( - q_buf, k_buf, v_buf, QKV, qkv_bias, batch_size, seq_len, head_num, size_per_head); + dim3 block(384); + dim3 grid((int)(ceil(1.0 * m * n / 384))); + add_fusedQKV_bias_transpose_kernel<<>>(q_buf, + k_buf, + v_buf, + QKV, + qkv_bias, + padding_offset, + batch_size, + seq_len, + token_num, + head_num, + size_per_head); } else { // To implement rotary embeddings, each thread processes two QKV elems: - dim3 block((size_per_head / 2 + 31) / 32 * 32); - dim3 grid(seq_len, head_num, batch_size); - add_fusedQKV_bias_transpose_kernel<<>>( - q_buf, k_buf, v_buf, QKV, qkv_bias, batch_size, seq_len, head_num, size_per_head, rotary_embedding_dim); + dim3 block((size_per_head / Vec_t::size + 31) / 32 * 32); 
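// NOTE: one thread covers Vec_t<T>::size elements of size_per_head, rounded up to a full
// warp (for example, size_per_head = 64 with 2-element vectors gives block.x = 32). In the
// grid set up below, grid.y walks over heads, while grid.x spends its first
// batch_size * max_prefix_prompt_length blocks copying prefix-prompt K/V into k_buf / v_buf
// (the token_idx < 0 path in the kernel) and the remaining token_num blocks adding the QKV
// bias, applying the rotary embedding, and scattering q / k / v.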
+ dim3 grid(token_num + batch_size * param.max_prefix_prompt_length, head_num); + size_t smem_size = neox_rotary_style ? 2 * rotary_embedding_dim * sizeof(T) : 0; + // NOTE: add offset for rotary embedding + // add_fusedQKV_bias_transpose_kernel<<>>( + // q_buf, k_buf, v_buf, param, QKV, qkv_bias, batch_size, seq_len, head_num, size_per_head, + // rotary_embedding_dim); + if (param.max_prefix_prompt_length == 0) { + FUSED_QKV_BIAS_TRANSPOSE_LAUNCH(T, false); + } + else { + FUSED_QKV_BIAS_TRANSPOSE_LAUNCH(T, true); + } } } -template void invokeAddFusedQKVBiasTranspose(float* q_buf, - float* k_buf, - float* v_buf, - float* QKV, - const float* qkv_bias, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const int rotary_embedding_dim, - cudaStream_t stream); - -template void invokeAddFusedQKVBiasTranspose(half* q_buf, - half* k_buf, - half* v_buf, - half* QKV, - const half* qkv_bias, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const int rotary_embedding_dim, - cudaStream_t stream); +template void invokeAddFusedQKVBiasTranspose(float* q_buf, + float* k_buf, + float* v_buf, + PrefixPromptBatchWeightsParam param, + float* QKV, + const float* qkv_bias, + const int* padding_offset, + const int batch_size, + const int seq_len, + const int token_num, + const int head_num, + const int size_per_head, + const int rotary_embedding_dim, + const int neox_rotary_style, + cudaStream_t stream); + +template void invokeAddFusedQKVBiasTranspose(half* q_buf, + half* k_buf, + half* v_buf, + PrefixPromptBatchWeightsParam param, + half* QKV, + const half* qkv_bias, + const int* padding_offset, + const int batch_size, + const int seq_len, + const int token_num, + const int head_num, + const int size_per_head, + const int rotary_embedding_dim, + const int neox_rotary_style, + cudaStream_t stream); #ifdef ENABLE_BF16 -template void invokeAddFusedQKVBiasTranspose(__nv_bfloat16* q_buf, - __nv_bfloat16* k_buf, - __nv_bfloat16* v_buf, - __nv_bfloat16* QKV, - const __nv_bfloat16* qkv_bias, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const int rotary_embedding_dim, - cudaStream_t stream); +template void invokeAddFusedQKVBiasTranspose(__nv_bfloat16* q_buf, + __nv_bfloat16* k_buf, + __nv_bfloat16* v_buf, + PrefixPromptBatchWeightsParam<__nv_bfloat16> param, + __nv_bfloat16* QKV, + const __nv_bfloat16* qkv_bias, + const int* padding_offset, + const int batch_size, + const int seq_len, + const int token_num, + const int head_num, + const int size_per_head, + const int rotary_embedding_dim, + const int neox_rotary_style, + cudaStream_t stream); #endif template -__global__ void transpose_4d(T* dst, - T* src, +__global__ void transpose_4d(T* dst, + T* src, const int dim0, const int dim1, const int dim2, @@ -1251,22 +1505,22 @@ __global__ void transpose_4d(T* dst, // transpose from [dim0, dim1, dim2, dim3] to [dim2, X, dim1, dim3] // where the dimension of X is dim0_leading_dim, and offset is ite * dim0 for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < dim0 * dim1 * dim2 * dim3; i += blockDim.x * gridDim.x) { - int index = i; - const int d3 = index % dim3; - index = (index - d3) / dim3; - const int d2 = index % dim2; - index = (index - d2) / dim2; - const int d1 = index % dim1; - index = (index - d1) / dim1; - const int d0 = index % dim0; - index = (index - d0) / dim0; + int index = i; + const int d3 = index % dim3; + index = (index - d3) / dim3; + const int d2 = index % dim2; + 
index = (index - d2) / dim2; + const int d1 = index % dim1; + index = (index - d1) / dim1; + const int d0 = index % dim0; + index = (index - d0) / dim0; dst[d2 * dim0_leading_dim * dim1 * dim3 + (d0 + dim0 * ite) * dim1 * dim3 + d1 * dim3 + d3] = src[i]; } } template<> -__global__ void transpose_4d(half* dst, - half* src, +__global__ void transpose_4d(half* dst, + half* src, const int dim0, const int dim1, const int dim2, @@ -1274,87 +1528,87 @@ __global__ void transpose_4d(half* dst, const int dim0_leading_dim, const int ite) { - half2* dst_ptr = (half2*)dst; - half2* src_ptr = (half2*)src; + half2* dst_ptr = (half2*)dst; + half2* src_ptr = (half2*)src; const int half_dim3 = dim3 / 2; // transpose from [dim0, dim1, dim2, half_dim3] to [dim2, dim0, dim1, half_dim3] // where the dimension of X is dim0_leading_dim, and offset is ite * dim0 for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < dim0 * dim1 * dim2 * half_dim3; i += blockDim.x * gridDim.x) { - int index = i; - const int d3 = index % half_dim3; - index = (index - d3) / half_dim3; - const int d2 = index % dim2; - index = (index - d2) / dim2; - const int d1 = index % dim1; - index = (index - d1) / dim1; - const int d0 = index % dim0; - index = (index - d0) / dim0; + int index = i; + const int d3 = index % half_dim3; + index = (index - d3) / half_dim3; + const int d2 = index % dim2; + index = (index - d2) / dim2; + const int d1 = index % dim1; + index = (index - d1) / dim1; + const int d0 = index % dim0; + index = (index - d0) / dim0; dst_ptr[d2 * dim0_leading_dim * dim1 * half_dim3 + (d0 + dim0 * ite) * dim1 * half_dim3 + d1 * half_dim3 + d3] = src_ptr[i]; } } template -void invokeTranspose4d(T* dst, - T* src, - const int local_batch_size, - const int seq_len, - const int size_per_head, - const int local_hidden_units, - const int local_head_num, - const int batch_size, - const int ite, +void invokeTranspose4d(T* dst, + T* src, + const int local_batch_size, + const int seq_len, + const int size_per_head, + const int local_hidden_units, + const int local_head_num, + const int batch_size, + const int ite, cudaStream_t stream) { transpose_4d<<>>( dst, src, local_batch_size, local_head_num, seq_len, size_per_head, batch_size, ite); } -template void invokeTranspose4d(float* dst, - float* src, - const int local_batch_size, - const int seq_len, - const int size_per_head, - const int local_hidden_units, - const int local_head_num, - const int batch_size, - const int ite, +template void invokeTranspose4d(float* dst, + float* src, + const int local_batch_size, + const int seq_len, + const int size_per_head, + const int local_hidden_units, + const int local_head_num, + const int batch_size, + const int ite, cudaStream_t stream); -template void invokeTranspose4d(half* dst, - half* src, - const int local_batch_size, - const int seq_len, - const int size_per_head, - const int local_hidden_units, - const int local_head_num, - const int batch_size, - const int ite, +template void invokeTranspose4d(half* dst, + half* src, + const int local_batch_size, + const int seq_len, + const int size_per_head, + const int local_hidden_units, + const int local_head_num, + const int batch_size, + const int ite, cudaStream_t stream); template __global__ void transpose_4d_batch_major_k_cache( T* k_dst, const T* k_src, const int head_num, const int size_per_head, const int seq_len, const int max_seq_len) { - const int batch_id = blockIdx.y; - const int head_id = blockIdx.z; - constexpr int X_ELEMS = (sizeof(T) == 4) ? 
4 : 8; + const int batch_id = blockIdx.y; + const int head_id = blockIdx.z; + constexpr int X_ELEMS = (sizeof(T) == 4) ? 4 : 8; auto key_src = reinterpret_cast(k_src + batch_id * head_num * size_per_head * seq_len + head_id * size_per_head * seq_len); auto key_dst = reinterpret_cast(k_dst + batch_id * head_num * size_per_head * max_seq_len + head_id * size_per_head * max_seq_len); - const int out_idx = blockIdx.x * blockDim.x + threadIdx.x; - int size_per_head_div_x = size_per_head / X_ELEMS; + const int out_idx = blockIdx.x * blockDim.x + threadIdx.x; + int size_per_head_div_x = size_per_head / X_ELEMS; if (out_idx >= size_per_head_div_x * max_seq_len) { return; } - int idx = out_idx; - const int k_seq_len_id = idx % max_seq_len; - idx = (idx - k_seq_len_id) / max_seq_len; + int idx = out_idx; + const int k_seq_len_id = idx % max_seq_len; + idx = (idx - k_seq_len_id) / max_seq_len; const int k_head_size_id = idx % size_per_head_div_x; if (k_seq_len_id < seq_len) { @@ -1367,7 +1621,7 @@ __global__ void transpose_4d_batch_major_v_cache( T* v_dst, const T* v_src, const int head_num, const int size_per_head, const int seq_len, const int max_seq_len) { const int batch_id = blockIdx.y; - const int head_id = blockIdx.z; + const int head_id = blockIdx.z; // 16 byte loads will handle "x" dimension auto val_src = reinterpret_cast(v_src + batch_id * head_num * size_per_head * seq_len @@ -1378,8 +1632,8 @@ __global__ void transpose_4d_batch_major_v_cache( // idx is over output dimension L * size_per_head / x for values const int idx = blockIdx.x * blockDim.x + threadIdx.x; - constexpr int X_ELEMS = (sizeof(T) == 4) ? 4 : 8; - const int size_per_head_div_x = size_per_head / X_ELEMS; + constexpr int X_ELEMS = (sizeof(T) == 4) ? 4 : 8; + const int size_per_head_div_x = size_per_head / X_ELEMS; if (idx >= size_per_head_div_x * seq_len) { return; @@ -1389,22 +1643,22 @@ __global__ void transpose_4d_batch_major_v_cache( } template -void invokeTranspose4dBatchMajor(T* k_dst, - T* v_dst, - const T* k_src, - const T* v_src, - const int local_batch_size, - const int seq_len, - const int max_seq_len, - const int size_per_head, - const int local_head_num, +void invokeTranspose4dBatchMajor(T* k_dst, + T* v_dst, + const T* k_src, + const T* v_src, + const int local_batch_size, + const int seq_len, + const int max_seq_len, + const int size_per_head, + const int local_head_num, cudaStream_t stream) { constexpr int block_sz = 128; - constexpr int x = (sizeof(T) == 4) ? 4 : 8; - int size = max_seq_len * size_per_head / x; - dim3 grid((size + block_sz - 1) / block_sz, local_batch_size, local_head_num); - dim3 grid_v((seq_len * size_per_head / x + block_sz - 1) / block_sz, local_batch_size, local_head_num); + constexpr int x = (sizeof(T) == 4) ? 
4 : 8; + int size = max_seq_len * size_per_head / x; + dim3 grid((size + block_sz - 1) / block_sz, local_batch_size, local_head_num); + dim3 grid_v((seq_len * size_per_head / x + block_sz - 1) / block_sz, local_batch_size, local_head_num); transpose_4d_batch_major_k_cache<<>>( k_dst, k_src, local_head_num, size_per_head, seq_len, max_seq_len); @@ -1413,39 +1667,39 @@ void invokeTranspose4dBatchMajor(T* k_dst, v_dst, v_src, local_head_num, size_per_head, seq_len, max_seq_len); } -template void invokeTranspose4dBatchMajor(float* k_dst, - float* v_dst, +template void invokeTranspose4dBatchMajor(float* k_dst, + float* v_dst, const float* k_src, const float* v_src, - const int local_batch_size, - const int seq_len, - const int max_seq_len, - const int size_per_head, - const int local_head_num, + const int local_batch_size, + const int seq_len, + const int max_seq_len, + const int size_per_head, + const int local_head_num, cudaStream_t stream); -template void invokeTranspose4dBatchMajor(half* k_dst, - half* v_dst, - const half* k_src, - const half* v_src, - const int local_batch_size, - const int seq_len, - const int max_seq_len, - const int size_per_head, - const int local_head_num, +template void invokeTranspose4dBatchMajor(half* k_dst, + half* v_dst, + const half* k_src, + const half* v_src, + const int local_batch_size, + const int seq_len, + const int max_seq_len, + const int size_per_head, + const int local_head_num, cudaStream_t stream); #ifdef ENABLE_BF16 -template void invokeTranspose4dBatchMajor(__nv_bfloat16* k_dst, - __nv_bfloat16* v_dst, +template void invokeTranspose4dBatchMajor(__nv_bfloat16* k_dst, + __nv_bfloat16* v_dst, const __nv_bfloat16* k_src, const __nv_bfloat16* v_src, - const int local_batch_size, - const int seq_len, - const int max_seq_len, - const int size_per_head, - const int local_head_num, - cudaStream_t stream); + const int local_batch_size, + const int seq_len, + const int max_seq_len, + const int size_per_head, + const int local_head_num, + cudaStream_t stream); #endif template @@ -1454,44 +1708,35 @@ __global__ void addRelativeAttentionBias( { for (int i = threadIdx.x; i < batch_size * seq_len; i += blockDim.x) { int batch_id = i / seq_len; - int seq_id = i % seq_len; + int seq_id = i % seq_len; const int bias_index = blockIdx.x * seq_len + seq_id; - const int qk_index = batch_id * head_num * seq_len * seq_len + bias_index; - qk_buf[qk_index] = qk_buf[qk_index] + relative_attention_bias[bias_index]; + const int qk_index = batch_id * gridDim.x * seq_len + bias_index; + qk_buf[qk_index] = add(qk_buf[qk_index], relative_attention_bias[bias_index]); } } template -__global__ void addRelativeAttentionBias( - half2* qk_buf, const half2* relative_attention_bias, const int batch_size, const int head_num, const int seq_len) -{ - const int half2_seq_len = seq_len / 2; - for (int i = threadIdx.x; i < batch_size * half2_seq_len; i += blockDim.x) { - int batch_id = i / half2_seq_len; - int seq_id = i % half2_seq_len; - - const int bias_index = blockIdx.x * half2_seq_len + seq_id; - const int qk_index = batch_id * gridDim.x * half2_seq_len + bias_index; - qk_buf[qk_index] += relative_attention_bias[bias_index]; - } -} - -template -void invokeAddRelativeAttentionBias(T* qk_buf, - const T* relative_attention_bias, - const int batch_size, - const int head_num, - const int seq_len, +void invokeAddRelativeAttentionBias(T* qk_buf, + const T* relative_attention_bias, + const int batch_size, + const int head_num, + const int seq_len, cudaStream_t stream) { // qk_buf: [batch_size, 
head_num, seq_len, seq_len] // relative_attention_bias: [1, head_num, seq_len, seq_len] dim3 grid(head_num * seq_len); dim3 block(512); - if (std::is_same::value && seq_len % 2 == 0) { - addRelativeAttentionBias<<>>( - (half2*)qk_buf, (const half2*)relative_attention_bias, batch_size, head_num, seq_len); + using T2 = typename TypeConverter::Type; +#ifdef ENABLE_BF16 + const bool is_half2 = (std::is_same::value || std::is_same::value) && (seq_len % 2 == 0); +#else + const bool is_half2 = (std::is_same::value) && (seq_len % 2 == 0); +#endif + if (is_half2) { + addRelativeAttentionBias<<>>( + (T2*)qk_buf, (const T2*)relative_attention_bias, batch_size, head_num, seq_len / 2); } else { addRelativeAttentionBias<<>>( @@ -1499,19 +1744,27 @@ void invokeAddRelativeAttentionBias(T* qk_buf, } } -template void invokeAddRelativeAttentionBias(float* qk_buf, +template void invokeAddRelativeAttentionBias(float* qk_buf, const float* relative_attention_bias, - const int batch_size, - const int head_num, - const int seq_len, + const int batch_size, + const int head_num, + const int seq_len, cudaStream_t stream); -template void invokeAddRelativeAttentionBias(half* qk_buf, - const half* relative_attention_bias, - const int batch_size, - const int head_num, - const int seq_len, +template void invokeAddRelativeAttentionBias(half* qk_buf, + const half* relative_attention_bias, + const int batch_size, + const int head_num, + const int seq_len, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeAddRelativeAttentionBias(__nv_bfloat16* qk_buf, + const __nv_bfloat16* relative_attention_bias, + const int batch_size, + const int head_num, + const int seq_len, + cudaStream_t stream); +#endif /******************* invokeAddHead3SizeQKVBias ***********************/ // m = batch*window_num*window_len @@ -1521,11 +1774,11 @@ template void invokeAddRelativeAttentionBias(half* qk_buf, // grid(window_len, window_num, 3*batch); // block(num_head * size_per_head) template -__global__ void add_head3Size_QKV_bias(const T* mm_qkv, - const T* bias_qkv, - T* q_buf_, - T* k_buf_, - T* v_buf_, +__global__ void add_head3Size_QKV_bias(const T* mm_qkv, + const T* bias_qkv, + T* q_buf_, + T* k_buf_, + T* v_buf_, const int batch, const int window_num, const int window_len, @@ -1533,7 +1786,7 @@ __global__ void add_head3Size_QKV_bias(const T* mm_qkv, const int size_per_head) { - T* buf_ptr; + T* buf_ptr; int qkv_id = blockIdx.z / batch; if (qkv_id == 0) { buf_ptr = q_buf_; @@ -1545,14 +1798,14 @@ __global__ void add_head3Size_QKV_bias(const T* mm_qkv, buf_ptr = v_buf_; } - const int batch_id = blockIdx.z % batch; - const int token_id = blockIdx.x; - const int window_id = blockIdx.y; - const int head_id = threadIdx.x / size_per_head; + const int batch_id = blockIdx.z % batch; + const int token_id = blockIdx.x; + const int window_id = blockIdx.y; + const int head_id = threadIdx.x / size_per_head; const int id_in_head = threadIdx.x % size_per_head; const int bias_idx = (head_id * 3 + qkv_id) * size_per_head + id_in_head; - const T bias = __ldg(bias_qkv + bias_idx); + const T bias = ldg(bias_qkv + bias_idx); const int input_idx = ((batch_id * window_num + window_id) * window_len + token_id) * num_head * 3 * size_per_head + bias_idx; @@ -1574,18 +1827,18 @@ __global__ void add_head3Size_QKV_bias(const T* mm_qkv, template<> __global__ void add_head3Size_QKV_bias(const float2* mm_qkv, const float2* bias_qkv, - float2* q_buf_, - float2* k_buf_, - float2* v_buf_, - const int batch, - const int window_num, - const int window_len, - const 
int num_head, - const int size_per_head) + float2* q_buf_, + float2* k_buf_, + float2* v_buf_, + const int batch, + const int window_num, + const int window_len, + const int num_head, + const int size_per_head) { float2* buf_ptr; - int qkv_id = blockIdx.z / batch; + int qkv_id = blockIdx.z / batch; if (qkv_id == 0) { buf_ptr = q_buf_; } @@ -1596,14 +1849,14 @@ __global__ void add_head3Size_QKV_bias(const float2* mm_qkv, buf_ptr = v_buf_; } - const int batch_id = blockIdx.z % batch; - const int token_id = blockIdx.x; - const int window_id = blockIdx.y; - const int head_id = threadIdx.x / size_per_head; + const int batch_id = blockIdx.z % batch; + const int token_id = blockIdx.x; + const int window_id = blockIdx.y; + const int head_id = threadIdx.x / size_per_head; const int id_in_head = threadIdx.x % size_per_head; - const int bias_idx = (head_id * 3 + qkv_id) * size_per_head + id_in_head; - const float2 bias = __ldg(bias_qkv + bias_idx); + const int bias_idx = (head_id * 3 + qkv_id) * size_per_head + id_in_head; + const float2 bias = ldg(bias_qkv + bias_idx); const int input_idx = ((batch_id * window_num + window_id) * window_len + token_id) * num_head * 3 * size_per_head + bias_idx; @@ -1627,20 +1880,20 @@ __global__ void add_head3Size_QKV_bias(const float2* mm_qkv, template<> __global__ void add_head3Size_QKV_bias(const half2* mm_qkv, const half2* bias_qkv, - half2* q_buf_, - half2* k_buf_, - half2* v_buf_, - const int batch, - const int window_num, - const int window_len, - const int num_head, - const int size_per_head) + half2* q_buf_, + half2* k_buf_, + half2* v_buf_, + const int batch, + const int window_num, + const int window_len, + const int num_head, + const int size_per_head) { - const int batch_id = blockIdx.z; - const int token_id = blockIdx.x; - const int window_id = blockIdx.y; - const int head_id = threadIdx.x / size_per_head; + const int batch_id = blockIdx.z; + const int token_id = blockIdx.x; + const int window_id = blockIdx.y; + const int head_id = threadIdx.x / size_per_head; const int id_in_head = threadIdx.x % size_per_head; const int input_offset = @@ -1649,42 +1902,94 @@ __global__ void add_head3Size_QKV_bias(const half2* mm_qkv, (((batch_id * window_num + window_id) * num_head + head_id) * window_len + token_id) * size_per_head + id_in_head; - int qkv_id = 0; - int bias_idx = (head_id * 3 + qkv_id) * size_per_head + id_in_head; - half2 bias = __ldg(bias_qkv + bias_idx); - int input_idx = input_offset + bias_idx; - half2 tmp = mm_qkv[input_idx]; - tmp = __hadd2(tmp, bias); + int qkv_id = 0; + int bias_idx = (head_id * 3 + qkv_id) * size_per_head + id_in_head; + half2 bias = __ldg(bias_qkv + bias_idx); + int input_idx = input_offset + bias_idx; + half2 tmp = mm_qkv[input_idx]; + tmp = __hadd2(tmp, bias); q_buf_[target_id] = tmp; - qkv_id = 1; - bias_idx = (head_id * 3 + qkv_id) * size_per_head + id_in_head; - bias = __ldg(bias_qkv + bias_idx); - input_idx = input_offset + bias_idx; - tmp = mm_qkv[input_idx]; - tmp = __hadd2(tmp, bias); + qkv_id = 1; + bias_idx = (head_id * 3 + qkv_id) * size_per_head + id_in_head; + bias = __ldg(bias_qkv + bias_idx); + input_idx = input_offset + bias_idx; + tmp = mm_qkv[input_idx]; + tmp = __hadd2(tmp, bias); + k_buf_[target_id] = tmp; + + qkv_id = 2; + bias_idx = (head_id * 3 + qkv_id) * size_per_head + id_in_head; + bias = __ldg(bias_qkv + bias_idx); + input_idx = input_offset + bias_idx; + tmp = mm_qkv[input_idx]; + tmp = __hadd2(tmp, bias); + v_buf_[target_id] = tmp; +} + +#ifdef ENABLE_BF16 +template<> +__global__ void 
add_head3Size_QKV_bias(const __nv_bfloat162* mm_qkv, + const __nv_bfloat162* bias_qkv, + __nv_bfloat162* q_buf_, + __nv_bfloat162* k_buf_, + __nv_bfloat162* v_buf_, + const int batch, + const int window_num, + const int window_len, + const int num_head, + const int size_per_head) +{ + + const int batch_id = blockIdx.z; + const int token_id = blockIdx.x; + const int window_id = blockIdx.y; + const int head_id = threadIdx.x / size_per_head; + const int id_in_head = threadIdx.x % size_per_head; + + const int input_offset = + ((batch_id * window_num + window_id) * window_len + token_id) * num_head * 3 * size_per_head; + const int target_id = + (((batch_id * window_num + window_id) * num_head + head_id) * window_len + token_id) * size_per_head + + id_in_head; + + int qkv_id = 0; + int bias_idx = (head_id * 3 + qkv_id) * size_per_head + id_in_head; + __nv_bfloat162 bias = ldg(bias_qkv + bias_idx); + int input_idx = input_offset + bias_idx; + __nv_bfloat162 tmp = mm_qkv[input_idx]; + tmp = bf16hadd2(tmp, bias); + q_buf_[target_id] = tmp; + + qkv_id = 1; + bias_idx = (head_id * 3 + qkv_id) * size_per_head + id_in_head; + bias = ldg(bias_qkv + bias_idx); + input_idx = input_offset + bias_idx; + tmp = mm_qkv[input_idx]; + tmp = bf16hadd2(tmp, bias); k_buf_[target_id] = tmp; - qkv_id = 2; - bias_idx = (head_id * 3 + qkv_id) * size_per_head + id_in_head; - bias = __ldg(bias_qkv + bias_idx); - input_idx = input_offset + bias_idx; - tmp = mm_qkv[input_idx]; - tmp = __hadd2(tmp, bias); + qkv_id = 2; + bias_idx = (head_id * 3 + qkv_id) * size_per_head + id_in_head; + bias = ldg(bias_qkv + bias_idx); + input_idx = input_offset + bias_idx; + tmp = mm_qkv[input_idx]; + tmp = bf16hadd2(tmp, bias); v_buf_[target_id] = tmp; } +#endif template -void invokeAddHead3SizeQKVBias(const T* mm_qkv, - const T* bias_qkv, - T* q_buf_, - T* k_buf_, - T* v_buf_, - const int batch, - const int window_num, - const int window_len, - const int num_head, - const int size_per_head, +void invokeAddHead3SizeQKVBias(const T* mm_qkv, + const T* bias_qkv, + T* q_buf_, + T* k_buf_, + T* v_buf_, + const int batch, + const int window_num, + const int window_len, + const int num_head, + const int size_per_head, cudaStream_t stream) { if (std::is_same::value) { @@ -1709,24 +2014,30 @@ void invokeAddHead3SizeQKVBias(const T* mm_qkv, size_per_head / 2); } else { - printf("[ERROR][invokeAddHead3SizeQKVBias] unsupport block.x!\n"); + printf("[ERROR][invokeAddHead3SizeQKVBias] unsupported block.x!\n"); exit(-1); } } +#ifdef ENABLE_BF16 + else if (std::is_same::value || std::is_same::value) { +#else else if (std::is_same::value) { +#endif dim3 grid(window_len, window_num, batch); dim3 block(num_head * size_per_head / 2); + using T2 = typename TypeConverter::Type; // half2 or bfloat16 + if (block.x > 1024) { printf("[ERROR][invokeAddHead3SizeQKVBias] block.x > 1024!\n"); exit(-1); } - add_head3Size_QKV_bias<<>>((const half2*)mm_qkv, - (const half2*)bias_qkv, - (half2*)q_buf_, - (half2*)k_buf_, - (half2*)v_buf_, + add_head3Size_QKV_bias<<>>((const T2*)mm_qkv, + (const T2*)bias_qkv, + (T2*)q_buf_, + (T2*)k_buf_, + (T2*)v_buf_, batch, window_num, window_len, @@ -1737,28 +2048,42 @@ void invokeAddHead3SizeQKVBias(const T* mm_qkv, template void invokeAddHead3SizeQKVBias(const float* mm_qkv, const float* bias_qkv, - float* q_buf_, - float* k_buf_, - float* v_buf_, - const int batch, - const int window_num, - const int window_len, - const int num_head, - const int size_per_head, + float* q_buf_, + float* k_buf_, + float* v_buf_, + const int batch, + 
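// invokeAddHead3SizeQKVBias above picks its packed type through TypeConverter
// (half -> half2, __nv_bfloat16 -> __nv_bfloat162) so one template body serves both
// FP16 and BF16. The real trait ships with the bfloat16 fallback header; the sketch
// below only illustrates the assumed shape of such a trait (DemoTypeConverter is an
// illustrative name, not the library's definition).
#include <cuda_fp16.h>
#ifdef ENABLE_BF16
#include <cuda_bf16.h>
#endif

template<typename T>
struct DemoTypeConverter;  // left undefined for unsupported element types

template<>
struct DemoTypeConverter<half> {
    using Type = half2;
};

#ifdef ENABLE_BF16
template<>
struct DemoTypeConverter<__nv_bfloat16> {
    using Type = __nv_bfloat162;
};
#endif

// Typical use inside a templated invoker: resolve the packed type once, then
// reinterpret the scalar pointers before launching the packed kernel.
template<typename T>
void demo_pick_packed_type(T* qkv)
{
    using T2 = typename DemoTypeConverter<T>::Type;
    T2* packed = reinterpret_cast<T2*>(qkv);
    (void)packed;
}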
const int window_num, + const int window_len, + const int num_head, + const int size_per_head, cudaStream_t stream); -template void invokeAddHead3SizeQKVBias(const half* mm_qkv, - const half* bias_qkv, - half* q_buf_, - half* k_buf_, - half* v_buf_, - const int batch, - const int window_num, - const int window_len, - const int num_head, - const int size_per_head, +template void invokeAddHead3SizeQKVBias(const half* mm_qkv, + const half* bias_qkv, + half* q_buf_, + half* k_buf_, + half* v_buf_, + const int batch, + const int window_num, + const int window_len, + const int num_head, + const int size_per_head, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeAddHead3SizeQKVBias<__nv_bfloat16>(const __nv_bfloat16* mm_qkv, + const __nv_bfloat16* bias_qkv, + __nv_bfloat16* q_buf_, + __nv_bfloat16* k_buf_, + __nv_bfloat16* v_buf_, + const int batch, + const int window_num, + const int window_len, + const int num_head, + const int size_per_head, + cudaStream_t stream); +#endif + /******************* invokeMaskedSoftMax ***********************/ // grid = (window_len/word_per_thread, window_num*num_head, batch_size) @@ -1767,33 +2092,33 @@ template void invokeAddHead3SizeQKVBias(const half* mm_qkv, // attn_mask is [window_num, window_len, window_len] + row-major // relative_pos_bias is [num_head, window_len, window_len] + row-majot template -__global__ void softmax_kernel(T* qk_buf, - const T* attn_mask, - const T* relative_pos_bias, - const int batch_size, - const int num_head, - const int window_num, - const int window_len, - const int window_len_x_window_len, +__global__ void softmax_kernel(T* qk_buf, + const T* attn_mask, + const T* relative_pos_bias, + const int batch_size, + const int num_head, + const int window_num, + const int window_len, + const int window_len_x_window_len, const float qk_scale) { bool qual = threadIdx.x < window_len; for (int window_id = blockIdx.x; window_id < window_len; window_id += gridDim.x) { - float tmp = -1e20f; + float tmp = -1e20f; __shared__ float s_mean, s_max; - int qk_offset; + int qk_offset; if (qual) { const int offset_in_window = window_id * window_len + threadIdx.x; qk_offset = (blockIdx.z * gridDim.y + blockIdx.y) * window_len_x_window_len + offset_in_window; const int relative_pos_bias_offset = (blockIdx.y % num_head) * window_len_x_window_len + offset_in_window; - float mask_val = + float mask_val = (attn_mask == nullptr) ? - 0.0f : - static_cast( - __ldg(attn_mask + ((blockIdx.y / num_head) * window_len_x_window_len + offset_in_window))); + 0.0f : + static_cast( + ldg(attn_mask + ((blockIdx.y / num_head) * window_len_x_window_len + offset_in_window))); tmp = qk_scale * static_cast(qk_buf[qk_offset]) + mask_val - + static_cast(__ldg(relative_pos_bias + relative_pos_bias_offset)); + + static_cast(ldg(relative_pos_bias + relative_pos_bias_offset)); } float max_val = blockReduceMax(tmp); @@ -1802,7 +2127,7 @@ __global__ void softmax_kernel(T* qk_buf, } __syncthreads(); - float qk_tmp = qual ? __expf(tmp - s_max) : 0.0f; + float qk_tmp = qual ? 
__expf(tmp - s_max) : 0.0f; float sum_val = blockReduceSum(qk_tmp); if (threadIdx.x == 0) { s_mean = sum_val + 1e-6f; @@ -1816,19 +2141,19 @@ __global__ void softmax_kernel(T* qk_buf, } template -void invokeMaskedSoftMaxWithRelPosBias(T* qk_buf, - const T* attn_mask, - const T* relative_pos_bias, - const int batch_size, - const int num_head, - const int window_num, - const int window_len, - float qk_scale, +void invokeMaskedSoftMaxWithRelPosBias(T* qk_buf, + const T* attn_mask, + const T* relative_pos_bias, + const int batch_size, + const int num_head, + const int window_num, + const int window_len, + float qk_scale, cudaStream_t stream) { const int word_per_thread = 1; - dim3 grid(window_len / word_per_thread, window_num * num_head, batch_size); - dim3 block((window_len + 31) / 32 * 32); + dim3 grid(window_len / word_per_thread, window_num * num_head, batch_size); + dim3 block((window_len + 31) / 32 * 32); softmax_kernel<<>>(qk_buf, attn_mask, relative_pos_bias, @@ -1840,24 +2165,36 @@ void invokeMaskedSoftMaxWithRelPosBias(T* qk_buf, qk_scale); } -template void invokeMaskedSoftMaxWithRelPosBias(float* qk_buf, +template void invokeMaskedSoftMaxWithRelPosBias(float* qk_buf, const float* attn_mask, const float* relative_pos_bias, - const int batch_size, - const int num_head, - const int window_num, - const int window_len, - const float qk_scale, + const int batch_size, + const int num_head, + const int window_num, + const int window_len, + const float qk_scale, cudaStream_t stream); -template void invokeMaskedSoftMaxWithRelPosBias(half* qk_buf, - const half* attn_mask, - const half* relative_pos_bias, - const int batch_size, - const int num_head, - const int window_num, - const int window_len, - const float qk_scale, +template void invokeMaskedSoftMaxWithRelPosBias(half* qk_buf, + const half* attn_mask, + const half* relative_pos_bias, + const int batch_size, + const int num_head, + const int window_num, + const int window_len, + const float qk_scale, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeMaskedSoftMaxWithRelPosBias(__nv_bfloat16* qk_buf, + const __nv_bfloat16* attn_mask, + const __nv_bfloat16* relative_pos_bias, + const int batch_size, + const int num_head, + const int window_num, + const int window_len, + const float qk_scale, + cudaStream_t stream); +#endif + } // namespace fastertransformer diff --git a/src/fastertransformer/kernels/unfused_attention_kernels.h b/src/fastertransformer/kernels/unfused_attention_kernels.h index be8b178e9..3815eec78 100644 --- a/src/fastertransformer/kernels/unfused_attention_kernels.h +++ b/src/fastertransformer/kernels/unfused_attention_kernels.h @@ -18,152 +18,182 @@ namespace fastertransformer { template -void invokeAddQKVBiasTranspose(T* q_buf, - T* k_buf, - T* v_buf, - T* Q, - const T* bias_Q, - T* K, - const T* bias_K, - T* V, - const T* bias_V, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, +void invokeAddQKVBiasTranspose(T* q_buf, + T* k_buf, + T* v_buf, + T* Q, + const T* bias_Q, + T* K, + const T* bias_K, + T* V, + const T* bias_V, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, cudaStream_t stream); template -void invokeMaskedSoftMax(T* buffer, - const T_IN* buffer_src, - const T* attr_mask, - const int batch_size, - const int seq_len, - const int head_num, - const T scalar, +void invokeMaskedSoftMax(T* buffer, + const T_IN* buffer_src, + const T* attr_mask, + const int batch_size, + const int seq_len_1, + const int seq_len_2, + 
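// Per row of length window_len, the windowed softmax above evaluates
// softmax(qk_scale * qk + attn_mask + relative_pos_bias), with attn_mask optional and
// 1e-6 added to the denominator. A plain single-row host reference of that math
// (float only, illustrative name, no batching) can be useful for unit checks:
#include <algorithm>
#include <cmath>
#include <vector>

std::vector<float> demo_masked_softmax_with_rel_pos_bias(const std::vector<float>& qk,
                                                         const std::vector<float>& mask,      // may be empty
                                                         const std::vector<float>& rel_bias,  // same length as qk
                                                         float qk_scale)
{
    std::vector<float> out(qk.size());
    float max_val = -1e20f;
    for (size_t i = 0; i < qk.size(); ++i) {
        const float m = mask.empty() ? 0.0f : mask[i];
        out[i]        = qk_scale * qk[i] + m + rel_bias[i];
        max_val       = std::max(max_val, out[i]);
    }
    float sum = 1e-6f;  // same epsilon the kernel adds to the softmax denominator
    for (float& v : out) {
        v = std::exp(v - max_val);
        sum += v;
    }
    for (float& v : out) {
        v /= sum;
    }
    return out;
}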
const int head_num, + const T scalar, cudaStream_t stream); template -void invokeTransposeQKV(T* dst, - T* src, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, +void invokeTransposeQKV(T* dst, + T* src, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, cudaStream_t stream); template -void invokeAddQKVBiasRebuildPadding(T* Q, - const T* bias_Q, - T* K, - const T* bias_K, - T* V, - const T* bias_V, - T* q_buf, - T* k_buf, - T* v_buf, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const int valid_word_num, - const int* mask_offset, +void invokeAddQKVBiasRebuildPadding(T* Q, + const T* bias_Q, + T* K, + const T* bias_K, + T* V, + const T* bias_V, + T* q_buf, + T* k_buf, + T* v_buf, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const int valid_word_num, + const int* mask_offset, cudaStream_t stream); template -void invokeTransposeAttentionOutRemovePadding(T* src, - T* dst, - const int valid_word_num, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const int* mask_offset, +void invokeTransposeAttentionOutRemovePadding(T* src, + T* dst, + const int valid_word_num, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const int* mask_offset, cudaStream_t stream); +// Prefix Prompt Parameters template -void invokeAddFusedQKVBiasTranspose(T* q_buf, - T* k_buf, - T* v_buf, - T* QKV, - const T* qkv_bias, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, +struct PrefixPromptBatchWeightsParam { + const T** d_prefix_prompt_batch = nullptr; + const int* d_prefix_prompt_lengths = nullptr; + const int max_prefix_prompt_length = 0; + // l * 2 * hidden_units_ / tensor_para_.world_size_ + const size_t prefix_prompt_layer_offset_per_seq = 0; +}; + +template +void invokeAddFusedQKVBiasTranspose(T* q_buf, + T* k_buf, + T* v_buf, + T* QKV, + const T* qkv_bias, + const int* padding_offset, + const int batch_size, + const int seq_len, + const int token_num, + const int head_num, + const int size_per_head, cudaStream_t stream) { - invokeAddFusedQKVBiasTranspose( - q_buf, k_buf, v_buf, QKV, qkv_bias, batch_size, seq_len, head_num, size_per_head, 0, stream); + invokeAddFusedQKVBiasTranspose(q_buf, + k_buf, + v_buf, + PrefixPromptBatchWeightsParam{}, + QKV, + qkv_bias, + padding_offset, + batch_size, + seq_len, + token_num, + head_num, + size_per_head, + 0, + false, + stream); } template -void invokeAddFusedQKVBiasTranspose(T* q_buf, - T* k_buf, - T* v_buf, - T* QKV, - const T* qkv_bias, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const int rotary_embedding_dim, - cudaStream_t stream); +void invokeAddFusedQKVBiasTranspose(T* q_buf, + T* k_buf, + T* v_buf, + PrefixPromptBatchWeightsParam param, + T* QKV, + const T* qkv_bias, + const int* padding_offset, + const int batch_size, + const int seq_len, + const int token_num, + const int head_num, + const int size_per_head, + const int rotary_embedding_dim, + const int neox_rotary_style, + cudaStream_t stream); template -void invokeTranspose4d(T* dst, - T* src, - const int local_batch_size, - const int seq_len, - const int size_per_head, - const int local_hidden_units, - const int local_head_num, - const int batch_size, - const int ite, +void invokeTranspose4d(T* dst, + T* src, + const int local_batch_size, + 
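// The header above keeps the old invokeAddFusedQKVBiasTranspose signature as a thin
// inline wrapper that forwards to the extended overload with a default-constructed
// PrefixPromptBatchWeightsParam, a zero rotary embedding dimension and the non-NeoX
// rotary style, so existing call sites compile and behave as before. The standalone
// sketch below shows the same API-evolution pattern with made-up names
// (DemoParam, demo_launch).
#include <cstddef>

template<typename T>
struct DemoParam {
    const T**  d_prefix_prompt_batch    = nullptr;  // no prefix prompt by default
    const int* d_prefix_prompt_lengths  = nullptr;
    int        max_prefix_prompt_length = 0;
    size_t     layer_offset_per_seq     = 0;
};

// Extended entry point: new features are passed explicitly through the struct.
template<typename T>
void demo_launch(T* q_buf, T* k_buf, T* v_buf, DemoParam<T> param, int batch_size, int seq_len)
{
    // ... the kernel launch would go here ...
    (void)q_buf; (void)k_buf; (void)v_buf; (void)param; (void)batch_size; (void)seq_len;
}

// Legacy entry point: unchanged behaviour, implemented as a one-line forwarder.
template<typename T>
void demo_launch(T* q_buf, T* k_buf, T* v_buf, int batch_size, int seq_len)
{
    demo_launch(q_buf, k_buf, v_buf, DemoParam<T>{}, batch_size, seq_len);
}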
const int seq_len, + const int size_per_head, + const int local_hidden_units, + const int local_head_num, + const int batch_size, + const int ite, cudaStream_t stream); template -void invokeTranspose4dBatchMajor(T* k_dst, - T* v_dst, - const T* k_src, - const T* v_src, - const int local_batch_size, - const int seq_len, - const int max_seq_len, - const int size_per_head, - const int local_head_num, +void invokeTranspose4dBatchMajor(T* k_dst, + T* v_dst, + const T* k_src, + const T* v_src, + const int local_batch_size, + const int seq_len, + const int max_seq_len, + const int size_per_head, + const int local_head_num, cudaStream_t stream); template -void invokeAddRelativeAttentionBias(T* qk_buf, - const T* relative_attention_bias, - const int batch_size, - const int head_num, - const int seq_len, +void invokeAddRelativeAttentionBias(T* qk_buf, + const T* relative_attention_bias, + const int batch_size, + const int head_num, + const int seq_len, cudaStream_t stream); template -void invokeAddHead3SizeQKVBias(const T* mm_qkv, - const T* bias_qkv, - T* q_buf_, - T* k_buf_, - T* v_buf_, - const int batch, - const int window_num, - const int window_len, - const int head_num, - const int size_per_head, +void invokeAddHead3SizeQKVBias(const T* mm_qkv, + const T* bias_qkv, + T* q_buf_, + T* k_buf_, + T* v_buf_, + const int batch, + const int window_num, + const int window_len, + const int head_num, + const int size_per_head, cudaStream_t stream); template -void invokeMaskedSoftMaxWithRelPosBias(T* qk_buf, - const T* attn_mask, - const T* relative_pos_bias, - const int batch_size, - const int num_head, - const int window_num, - const int window_len, - const float qk_scale, +void invokeMaskedSoftMaxWithRelPosBias(T* qk_buf, + const T* attn_mask, + const T* relative_pos_bias, + const int batch_size, + const int num_head, + const int window_num, + const int window_len, + const float qk_scale, cudaStream_t stream); } // namespace fastertransformer diff --git a/src/fastertransformer/kernels/vit_kernels.cu b/src/fastertransformer/kernels/vit_kernels.cu index 5cdd30f11..14565c3db 100644 --- a/src/fastertransformer/kernels/vit_kernels.cu +++ b/src/fastertransformer/kernels/vit_kernels.cu @@ -28,7 +28,7 @@ __global__ void add_bias_slice( const int offset = on_top ? 1 : 0; for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { - int slice_id = id / (s * n); + int slice_id = id / (s * n); out[id + (slice_id + offset) * n] = __ldg(&in[id]) + __ldg(&bias[id % n]); } } @@ -37,15 +37,15 @@ template<> __global__ void add_bias_slice( const half* __restrict in, half* __restrict out, const half* __restrict bias, int m, int n, int s, bool on_top) { - const int offset = on_top ? 1 : 0; - const half2* in_ptr = (half2*)in; + const int offset = on_top ? 
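// The ViT invokers above size their launches from data_type_factor = 4 / sizeof(T)
// (1 for FP32, 2 for FP16): a half kernel is launched over half2 elements, so n is
// divided by the factor before computing the grid. The host helper below reproduces
// just that grid/block computation (demo_vit_launch_dims is an illustrative name).
#include <cuda_runtime.h>

template<typename T>
void demo_vit_launch_dims(int m, int n, dim3& grid, dim3& block)
{
    const int data_type_factor = 4 / sizeof(T);  // 1 for fp32, 2 for fp16
    if (n / 4 / data_type_factor <= 1024) {
        block.x = n / 4 / data_type_factor;  // one block per row; the kernel's grid-stride
        grid.x  = m;                         // loop then covers 4 packed elements per thread
    }
    else {
        block.x = 1024;                      // cap the block and let the grid grow instead
        grid.x  = (m * n + 1023) / 1024;
    }
}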
1 : 0; + const half2* in_ptr = (half2*)in; const half2* bias_ptr = (half2*)bias; - half2* out_ptr = (half2*)out; + half2* out_ptr = (half2*)out; for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { - half2 d1 = __ldg(&in_ptr[id]); - half2 d2 = __ldg(&bias_ptr[id % n]); - int slice_id = id / (s * n); + half2 d1 = __ldg(&in_ptr[id]); + half2 d2 = __ldg(&bias_ptr[id % n]); + int slice_id = id / (s * n); out_ptr[id + (slice_id + offset) * n] = __hadd2(d1, d2); } } @@ -54,14 +54,14 @@ template void invokeAddBiasSlice(T* in, T* out, const T* bias, const int m, const int n, const int s, cudaStream_t stream) { const int data_type_factor = 4 / sizeof(T); // 1 for fp32, 2 for fp16 - dim3 block, grid; + dim3 block, grid; if (n / 4 / data_type_factor <= 1024) { block.x = n / 4 / data_type_factor; - grid.x = m; + grid.x = m; } else { block.x = 1024; - grid.x = (m * n + 1023) / 1024; + grid.x = (m * n + 1023) / 1024; } add_bias_slice<<>>(in, out, bias, m, n / data_type_factor, s); } @@ -75,18 +75,18 @@ __global__ void add_bias_concat_clstoken_add_posembed(const T* __restrict in, const int m, // b * (h*w+1) const int n, const int s, // h*w+1 - bool on_top = true) + bool on_top = true) { const int concat_row_idx = on_top ? 0 : (s - 1); - const int offset = on_top ? 1 : 0; + const int offset = on_top ? 1 : 0; for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { - int col_idx = id % n; - int row_idx = id / n; + int col_idx = id % n; + int row_idx = id / n; int slice_row_idx = row_idx % s; - int slice_idx = row_idx / s; - int idx_s = slice_row_idx * n + col_idx; - int idx_i = (slice_row_idx - offset + slice_idx * (s - 1)) * n + col_idx; + int slice_idx = row_idx / s; + int idx_s = slice_row_idx * n + col_idx; + int idx_i = (slice_row_idx - offset + slice_idx * (s - 1)) * n + col_idx; if (slice_row_idx == concat_row_idx) { out[id] = __ldg(&cls_token[col_idx]) + __ldg(&pos_embed[idx_s]); @@ -106,81 +106,81 @@ __global__ void add_bias_concat_clstoken_add_posembed(const half* __restrict in, const int m, // b * (h*w+1) const int n, const int s, // h*w+1 - bool on_top) + bool on_top) { - const int concat_row_idx = on_top ? 0 : (s - 1); - const int offset = on_top ? 1 : 0; - half2* out_ptr = (half2*)out; - const half2* in_ptr = (half2*)in; - const half2* bias_ptr = (half2*)bias; - const half2* token_ptr = (half2*)cls_token; - const half2* embed_ptr = (half2*)pos_embed; + const int concat_row_idx = on_top ? 0 : (s - 1); + const int offset = on_top ? 
1 : 0; + half2* out_ptr = (half2*)out; + const half2* in_ptr = (half2*)in; + const half2* bias_ptr = (half2*)bias; + const half2* token_ptr = (half2*)cls_token; + const half2* embed_ptr = (half2*)pos_embed; for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { - int col_idx = id % n; - int row_idx = id / n; + int col_idx = id % n; + int row_idx = id / n; int slice_row_idx = row_idx % s; - int slice_idx = row_idx / s; - int idx_s = slice_row_idx * n + col_idx; - int idx_i = (slice_row_idx - offset + slice_idx * (s - 1)) * n + col_idx; + int slice_idx = row_idx / s; + int idx_s = slice_row_idx * n + col_idx; + int idx_i = (slice_row_idx - offset + slice_idx * (s - 1)) * n + col_idx; if (slice_row_idx == concat_row_idx) { - half2 d1 = __ldg(&token_ptr[col_idx]); - half2 d2 = __ldg(&embed_ptr[idx_s]); + half2 d1 = __ldg(&token_ptr[col_idx]); + half2 d2 = __ldg(&embed_ptr[idx_s]); out_ptr[id] = __hadd2(d1, d2); } else { - half2 d1 = __ldg(&in_ptr[idx_i]); - half2 d2 = __ldg(&bias_ptr[col_idx]); - half2 d3 = __ldg(&embed_ptr[idx_s]); + half2 d1 = __ldg(&in_ptr[idx_i]); + half2 d2 = __ldg(&bias_ptr[col_idx]); + half2 d3 = __ldg(&embed_ptr[idx_s]); out_ptr[id] = __hadd2(d3, __hadd2(d1, d2)); } } } template -void invokeAddBiasConcatClsTokenAddPosEmbed(const T* in, - T* out, - const T* bias, - const T* cls_token, - const T* pos_embed, - const int m, - const int n, - const int s, +void invokeAddBiasConcatClsTokenAddPosEmbed(const T* in, + T* out, + const T* bias, + const T* cls_token, + const T* pos_embed, + const int m, + const int n, + const int s, cudaStream_t stream) { const int data_type_factor = 4 / sizeof(T); // 1 for fp32, 2 for fp16 - dim3 block, grid; + dim3 block, grid; if (n / 4 / data_type_factor <= 1024) { block.x = n / 4 / data_type_factor; - grid.x = m; + grid.x = m; } else { block.x = 1024; - grid.x = (m * n + 1023) / 1024; + grid.x = (m * n + 1023) / 1024; } add_bias_concat_clstoken_add_posembed<<>>( in, out, bias, cls_token, pos_embed, m, n / data_type_factor, s); } template void invokeAddBiasConcatClsTokenAddPosEmbed(const float* in, - float* out, + float* out, const float* bias, const float* cls_token, const float* pos_embed, - const int m, - const int n, - const int s, + const int m, + const int n, + const int s, cudaStream_t stream); -template void invokeAddBiasConcatClsTokenAddPosEmbed(const half* in, - half* out, - const half* bias, - const half* cls_token, - const half* pos_embed, - const int m, - const int n, - const int s, +template void invokeAddBiasConcatClsTokenAddPosEmbed(const half* in, + half* out, + const half* bias, + const half* cls_token, + const half* pos_embed, + const int m, + const int n, + const int s, cudaStream_t stream); template @@ -192,8 +192,8 @@ slice_copy(const T* __restrict in, T* __restrict out, const int m, const int n, return; } - int m_idx = idx / n; - int col = idx % n; + int m_idx = idx / n; + int col = idx % n; int in_idx = (m_idx * s + offset_s) * n + col; out[idx] = __ldg(&in[in_idx]); @@ -208,12 +208,12 @@ slice_copy(const half* __restrict in, half* __restrict out, const int m, const i return; } - int m_idx = idx / n; - int col = idx % n; + int m_idx = idx / n; + int col = idx % n; int in_idx = (m_idx * s + offset_s) * n + col; - half2* out_ptr = (half2*)out; - const half2* in_ptr = (half2*)in; + half2* out_ptr = (half2*)out; + const half2* in_ptr = (half2*)in; out_ptr[idx] = __ldg(&in_ptr[in_idx]); } @@ -223,14 +223,14 @@ void invokeSliceCopy( const T* in, T* out, const int m, const int n, const 
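// add_bias_concat_clstoken_add_posembed fuses three embedding steps: place the learned
// [CLS] token at the first (or last) row, add the patch-projection bias to every patch
// row, and add the position embedding to all rows. The host reference below handles
// one batch item for the on_top == true case only (demo_* name is illustrative);
// patches is [(s - 1) * n] row-major and the result is [s * n].
#include <vector>

std::vector<float> demo_concat_cls_add_pos(const std::vector<float>& patches,
                                           const std::vector<float>& bias,       // [n]
                                           const std::vector<float>& cls_token,  // [n]
                                           const std::vector<float>& pos_embed,  // [s * n]
                                           int s, int n)
{
    std::vector<float> out(s * n);
    for (int t = 0; t < s; ++t) {
        for (int c = 0; c < n; ++c) {
            if (t == 0) {  // [CLS] row: token value plus its position embedding
                out[t * n + c] = cls_token[c] + pos_embed[t * n + c];
            }
            else {  // patch rows: projected patch + bias + position embedding
                out[t * n + c] = patches[(t - 1) * n + c] + bias[c] + pos_embed[t * n + c];
            }
        }
    }
    return out;
}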
int s, const int offset_s, cudaStream_t stream) { const int data_type_factor = 4 / sizeof(T); // 1 for fp32, 2 for fp16 - dim3 block, grid; + dim3 block, grid; if (n / data_type_factor <= 1024) { block.x = n / data_type_factor; - grid.x = m; + grid.x = m; } else { block.x = 1024; - grid.x = (m * n + 1023) / 1024; + grid.x = (m * n + 1023) / 1024; } slice_copy<<>>(in, out, m, n / data_type_factor, s, offset_s); } @@ -266,15 +266,15 @@ __global__ void add_bias_add_posembed(half* __restrict out, // b*(h const int s // h*w *n ) { - half2* out_ptr = (half2*)out; - const half2* bias_ptr = (half2*)bias; + half2* out_ptr = (half2*)out; + const half2* bias_ptr = (half2*)bias; const half2* embed_ptr = (half2*)pos_embed; for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { - int b_idx = id % n; - int p_idx = id % s; - half2 d1 = __ldg(&bias_ptr[b_idx]); - half2 d2 = __ldg(&embed_ptr[p_idx]); + int b_idx = id % n; + int p_idx = id % s; + half2 d1 = __ldg(&bias_ptr[b_idx]); + half2 d2 = __ldg(&embed_ptr[p_idx]); out_ptr[id] = __hadd2(out_ptr[id], __hadd2(d1, d2)); } } @@ -284,14 +284,14 @@ void invokeAddBiasAddPosEmbed( T* out, const T* bias, const T* pos_embed, const int m, const int n, const int s, cudaStream_t stream) { const int data_type_factor = 4 / sizeof(T); // 1 for fp32, 2 for fp16 - dim3 block, grid; + dim3 block, grid; if (n / 4 / data_type_factor <= 1024) { block.x = n / 4 / data_type_factor; - grid.x = m; + grid.x = m; } else { block.x = 1024; - grid.x = (m * n + 1023) / 1024; + grid.x = (m * n + 1023) / 1024; } add_bias_add_posembed<<>>(out, bias, pos_embed, m, n / data_type_factor, s); } diff --git a/src/fastertransformer/kernels/vit_kernels.h b/src/fastertransformer/kernels/vit_kernels.h index 354b07384..b9527fe87 100644 --- a/src/fastertransformer/kernels/vit_kernels.h +++ b/src/fastertransformer/kernels/vit_kernels.h @@ -24,14 +24,14 @@ template void invokeAddBiasSlice(T* in, T* out, const T* bias, const int m, const int n, const int s, cudaStream_t stream); template -void invokeAddBiasConcatClsTokenAddPosEmbed(const T* in, - T* out, - const T* bias, - const T* cls_token, - const T* pos_embed, - const int m, - const int n, - const int s, +void invokeAddBiasConcatClsTokenAddPosEmbed(const T* in, + T* out, + const T* bias, + const T* cls_token, + const T* pos_embed, + const int m, + const int n, + const int s, cudaStream_t stream); template diff --git a/src/fastertransformer/kernels/xlnet_attention_kernels.cu b/src/fastertransformer/kernels/xlnet_attention_kernels.cu index 1ae80a9b7..1214c0f03 100644 --- a/src/fastertransformer/kernels/xlnet_attention_kernels.cu +++ b/src/fastertransformer/kernels/xlnet_attention_kernels.cu @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ +#include "src/fastertransformer/kernels/bfloat16_fallback_kenrels.cuh" #include "src/fastertransformer/kernels/xlnet_attention_kernels.h" namespace fastertransformer { template @@ -36,6 +37,14 @@ __device__ half2 cH2(const half* ptr, int offset) { return __ldg((half2*)(ptr + offset)); } + +#ifdef ENABLE_BF16 +__device__ __nv_bfloat162 cH2(const __nv_bfloat16* ptr, int offset) +{ + return ldg((__nv_bfloat162*)(ptr + offset)); +} +#endif + __forceinline__ __device__ unsigned lane_id() { unsigned ret; @@ -73,8 +82,8 @@ template __inline__ __device__ T blockReduceSum(T val) { static __shared__ T shared[32]; - int lane = threadIdx.x & 0x1f; - int wid = threadIdx.x >> 5; + int lane = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; val = warpReduceSum(val); @@ -94,8 +103,8 @@ template __inline__ __device__ T blockReduceMax(T val) { static __shared__ T shared[32]; - int lane = threadIdx.x & 0x1f; // in-warp idx - int wid = threadIdx.x >> 5; // warp idx + int lane = threadIdx.x & 0x1f; // in-warp idx + int wid = threadIdx.x >> 5; // warp idx val = warpReduceMax(val); // get maxx in each warp @@ -112,219 +121,184 @@ __inline__ __device__ T blockReduceMax(T val) } /********************** Kernels ************************/ + +// Applied to half and bfloat16 template -void __global__ prepareMatrixes(T* q_buf, - T* q_buf_bd, - T* q_buf_ef, - T* k_buf, - T* k_buf_bd, - T* k_buf_ef, - const T* query_buf, - const T* key_buf, - const T* k_head_r, - const T* attr_seg_embed, - const T* attr_bias_Q_w, - const T* attr_bias_Q_r, - const T* attr_bias_Q_s, +void __global__ prepareMatrixes(T* q_buf, + T* q_buf_bd, + T* q_buf_ef, + T* k_buf, + T* k_buf_bd, + T* k_buf_ef, + const T* query_buf, + const T* key_buf, + const T* k_head_r, + const T* attr_seg_embed, + const T* attr_bias_Q_w, + const T* attr_bias_Q_r, + const T* attr_bias_Q_s, const int off0, const int i_off1, const int o_off1, - int off2) + int off2) { - int batch = blockIdx.y; - int seq = blockIdx.x; - int head_loc = threadIdx.x; + using T2 = typename TypeConverter::Type; // half2 or bfloat162 + + int batch = blockIdx.y; + int seq = blockIdx.x; + int head_loc = threadIdx.x * 2; - T tmp; + T2 tmp; if (head_loc < i_off1) { - int head = head_loc / off2; - int loc = head_loc % off2; + int head = head_loc / off2; + int loc = head_loc % off2; + int h2_index = (batch * off0 + seq * i_off1 + head_loc) >> 1; - int index = batch * off0 + seq * i_off1 + head_loc; - tmp = query_buf[index]; - int index_out = batch * off0 + head * o_off1 + seq * off2 + loc; + tmp = ((T2*)query_buf)[h2_index]; + int h2_index_out = (batch * off0 + head * o_off1 + seq * off2 + loc) >> 1; // left matrix - q_buf[index_out] = tmp + __ldg(attr_bias_Q_w + head_loc); // tex2D(t_attr_bias_Q_w, loc, head); - q_buf_bd[index_out] = tmp + __ldg(attr_bias_Q_r + head_loc); // tex2D(t_attr_bias_Q_r, loc, head); - q_buf_ef[index_out] = tmp + __ldg(attr_bias_Q_s + head_loc); // tex2D(t_attr_bias_Q_s, loc, head); + ((T2*)q_buf)[h2_index_out] = hadd2(tmp, cH2(attr_bias_Q_w, head_loc)); + ((T2*)q_buf_bd)[h2_index_out] = hadd2(tmp, cH2(attr_bias_Q_r, head_loc)); + ((T2*)q_buf_ef)[h2_index_out] = hadd2(tmp, cH2(attr_bias_Q_s, head_loc)); // right matrix - k_buf[index_out] = key_buf[index]; // ac + ((T2*)k_buf)[h2_index_out] = ((T2*)key_buf)[h2_index]; // ac // bd - index = seq * i_off1 + head_loc; //(seq, head_loc) - tmp = k_head_r[index]; - index_out = index_out + batch * off0 + head * o_off1; //(batch, head,seq,loc) - k_buf_bd[index_out] = tmp; + h2_index = (seq * i_off1 + head_loc) >> 1; //(seq, 
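// cH2 loads two packed elements through the read-only data cache; the BF16 overload
// added above goes through the library's ldg() wrapper so the load can fall back to a
// plain dereference where the __nv_bfloat162 __ldg overload is unavailable. The
// guarded helper below sketches the assumed shape of such a wrapper; demo_ldg_bf162
// and the SM80 cut-off are illustrative, not the library's exact definition.
#ifdef ENABLE_BF16
#include <cuda_bf16.h>

__device__ __inline__ __nv_bfloat162 demo_ldg_bf162(const __nv_bfloat162* ptr)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    return __ldg(ptr);  // native read-only-cache load for packed BF16
#else
    return *ptr;        // plain load fallback on older targets
#endif
}
#endif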
head_loc) + tmp = ((T2*)k_head_r)[h2_index]; + h2_index_out = (batch * off0 * 2 + head * o_off1 * 2 + seq * off2 + loc) >> 1; //(batch, head,seq,loc) + ((T2*)k_buf_bd)[h2_index_out] = tmp; - index = index + off0; //(seq+seq_len, head_loc) - tmp = k_head_r[index]; - index_out = index_out + o_off1; //(batch, head,seq+seq_len,loc) - k_buf_bd[index_out] = tmp; + h2_index = (seq * i_off1 + head_loc + off0) >> 1; //(seq+seq_len, head_loc) + tmp = ((T2*)k_head_r)[h2_index]; + h2_index_out = + (batch * off0 * 2 + (head * 2 + 1) * o_off1 + seq * off2 + loc) >> 1; //(batch, head,seq+seq_len,loc) + ((T2*)k_buf_bd)[h2_index_out] = tmp; // ef if (seq <= 1) { - index = seq * i_off1 + head_loc; //(seq, head, loc) - tmp = attr_seg_embed[index]; - index_out = batch * 2 * i_off1 + (head * 2 + seq) * off2 + loc; //(head,seq,loc) - k_buf_ef[index_out] = tmp; + h2_index = (seq * i_off1 + head_loc) >> 1; //(seq, head, loc) + tmp = ((T2*)attr_seg_embed)[h2_index]; + h2_index_out = (batch * 2 * i_off1 + (head * 2 + seq) * off2 + loc) >> 1; //(head,seq,loc) + ((T2*)k_buf_ef)[h2_index_out] = tmp; } } } + template<> -void __global__ prepareMatrixes(__half* q_buf, - __half* q_buf_bd, - __half* q_buf_ef, - __half* k_buf, - __half* k_buf_bd, - __half* k_buf_ef, - const __half* query_buf, - const __half* key_buf, - const __half* k_head_r, - const __half* attr_seg_embed, - const __half* attr_bias_Q_w, - const __half* attr_bias_Q_r, - const __half* attr_bias_Q_s, - const int off0, - const int i_off1, - const int o_off1, - int off2) +void __global__ prepareMatrixes(float* q_buf, + float* q_buf_bd, + float* q_buf_ef, + float* k_buf, + float* k_buf_bd, + float* k_buf_ef, + const float* query_buf, + const float* key_buf, + const float* k_head_r, + const float* attr_seg_embed, + const float* attr_bias_Q_w, + const float* attr_bias_Q_r, + const float* attr_bias_Q_s, + const int off0, + const int i_off1, + const int o_off1, + int off2) { - int batch = blockIdx.y; - int seq = blockIdx.x; - int head_loc = threadIdx.x * 2; + int batch = blockIdx.y; + int seq = blockIdx.x; + int head_loc = threadIdx.x; - half2 tmp; + float tmp; if (head_loc < i_off1) { int head = head_loc / off2; - int loc = head_loc % off2; - int h2_index = (batch * off0 + seq * i_off1 + head_loc) >> 1; + int loc = head_loc % off2; - tmp = ((half2*)query_buf)[h2_index]; - int h2_index_out = (batch * off0 + head * o_off1 + seq * off2 + loc) >> 1; + int index = batch * off0 + seq * i_off1 + head_loc; + tmp = query_buf[index]; + int index_out = batch * off0 + head * o_off1 + seq * off2 + loc; // left matrix - ((half2*)q_buf)[h2_index_out] = __hadd2(tmp, cH2(attr_bias_Q_w, head_loc)); - ((half2*)q_buf_bd)[h2_index_out] = __hadd2(tmp, cH2(attr_bias_Q_r, head_loc)); - ((half2*)q_buf_ef)[h2_index_out] = __hadd2(tmp, cH2(attr_bias_Q_s, head_loc)); + q_buf[index_out] = tmp + __ldg(attr_bias_Q_w + head_loc); // tex2D(t_attr_bias_Q_w, loc, head); + q_buf_bd[index_out] = tmp + __ldg(attr_bias_Q_r + head_loc); // tex2D(t_attr_bias_Q_r, loc, head); + q_buf_ef[index_out] = tmp + __ldg(attr_bias_Q_s + head_loc); // tex2D(t_attr_bias_Q_s, loc, head); // right matrix - ((half2*)k_buf)[h2_index_out] = ((half2*)key_buf)[h2_index]; // ac + k_buf[index_out] = key_buf[index]; // ac // bd - h2_index = (seq * i_off1 + head_loc) >> 1; //(seq, head_loc) - tmp = ((half2*)k_head_r)[h2_index]; - h2_index_out = (batch * off0 * 2 + head * o_off1 * 2 + seq * off2 + loc) >> 1; //(batch, head,seq,loc) - ((half2*)k_buf_bd)[h2_index_out] = tmp; + index = seq * i_off1 + head_loc; //(seq, head_loc) + 
tmp = k_head_r[index]; + index_out = index_out + batch * off0 + head * o_off1; //(batch, head,seq,loc) + k_buf_bd[index_out] = tmp; - h2_index = (seq * i_off1 + head_loc + off0) >> 1; //(seq+seq_len, head_loc) - tmp = ((half2*)k_head_r)[h2_index]; - h2_index_out = - (batch * off0 * 2 + (head * 2 + 1) * o_off1 + seq * off2 + loc) >> 1; //(batch, head,seq+seq_len,loc) - ((half2*)k_buf_bd)[h2_index_out] = tmp; + index = index + off0; //(seq+seq_len, head_loc) + tmp = k_head_r[index]; + index_out = index_out + o_off1; //(batch, head,seq+seq_len,loc) + k_buf_bd[index_out] = tmp; // ef if (seq <= 1) { - h2_index = (seq * i_off1 + head_loc) >> 1; //(seq, head, loc) - tmp = ((half2*)attr_seg_embed)[h2_index]; - h2_index_out = (batch * 2 * i_off1 + (head * 2 + seq) * off2 + loc) >> 1; //(head,seq,loc) - ((half2*)k_buf_ef)[h2_index_out] = tmp; + index = seq * i_off1 + head_loc; //(seq, head, loc) + tmp = attr_seg_embed[index]; + index_out = batch * 2 * i_off1 + (head * 2 + seq) * off2 + loc; //(head,seq,loc) + k_buf_ef[index_out] = tmp; } } } +// Applied to half and bfloat16 template void __global__ transpose102(T* dst, const T* src, const int off0, const int i_off1, const int o_off1, const int off2) { + using T2 = typename TypeConverter::Type; // half2 or bfloat162 int x[4] = {0}; - x[0] = blockIdx.x; //[0,7] - x[1] = blockIdx.y; //[0,11] - x[2] = threadIdx.x; //[0,127] - x[3] = threadIdx.y; //[0,1] + x[0] = blockIdx.x; //[0,7] + x[1] = blockIdx.y; //[0,11] + x[2] = threadIdx.x; //[0,127] - int input_index = x[0] * off0 + x[1] * i_off1 + x[2] * off2 + x[3]; // [batch, 0, 1, 2]=[d0,d1,d2,d3] + int input_index = (x[0] * off0 + x[1] * i_off1 + x[2] * off2) >> 1; // [batch, 0, 1, 2]=[d0,d1,d2,d3] - int out_index = x[0] * off0 + x[2] * o_off1 + x[1] * off2 + x[3]; // [batch, 1, 0, 2]=[d0,d2,d1,d3] + int out_index = (x[0] * off0 + x[2] * o_off1 + x[1] * off2) >> 1; // [batch, 1, 0, 2]=[d0,d2,d1,d3] - dst[out_index] = src[input_index]; + ((T2*)dst)[out_index] = ((T2*)src)[input_index]; } template<> void __global__ -transpose102(__half* dst, const __half* src, const int off0, const int i_off1, const int o_off1, const int off2) +transpose102(float* dst, const float* src, const int off0, const int i_off1, const int o_off1, const int off2) { int x[4] = {0}; - x[0] = blockIdx.x; //[0,7] - x[1] = blockIdx.y; //[0,11] - x[2] = threadIdx.x; //[0,127] + x[0] = blockIdx.x; //[0,7] + x[1] = blockIdx.y; //[0,11] + x[2] = threadIdx.x; //[0,127] + x[3] = threadIdx.y; //[0,1] - int input_index = (x[0] * off0 + x[1] * i_off1 + x[2] * off2) >> 1; // [batch, 0, 1, 2]=[d0,d1,d2,d3] + int input_index = x[0] * off0 + x[1] * i_off1 + x[2] * off2 + x[3]; // [batch, 0, 1, 2]=[d0,d1,d2,d3] - int out_index = (x[0] * off0 + x[2] * o_off1 + x[1] * off2) >> 1; // [batch, 1, 0, 2]=[d0,d2,d1,d3] + int out_index = x[0] * off0 + x[2] * o_off1 + x[1] * off2 + x[3]; // [batch, 1, 0, 2]=[d0,d2,d1,d3] - ((half2*)dst)[out_index] = ((half2*)src)[input_index]; + dst[out_index] = src[input_index]; } -void __global__ transpose201(float* dst, - const float* src, - const int off0, - const int i_off1, - const int head_num, - const int o_off1, - const int seq_len) +template +void __global__ transpose201( + T* dst, const T* src, const int off0, const int i_off1, const int head_num, const int o_off1, const int seq_len) { int batch = blockIdx.x; - int d0 = blockIdx.y; - int d1 = threadIdx.x; - - extern __shared__ float sdata[]; - int i = 0; - - // Read data into shared memory - int index = batch * off0 + d0 * i_off1; // d1*i_off2+d2 - int offset = d1; - 
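// transpose102 permutes a [d0, d1, d2, d3] tensor to [d0, d2, d1, d3]; the packed
// specialization simply moves two innermost elements per thread. The host reference
// below performs the same permutation on a flat row-major buffer (illustrative name).
#include <vector>

std::vector<float> demo_transpose_102(const std::vector<float>& src, int d0, int d1, int d2, int d3)
{
    std::vector<float> dst(src.size());
    for (int a = 0; a < d0; ++a)
        for (int b = 0; b < d1; ++b)
            for (int c = 0; c < d2; ++c)
                for (int d = 0; d < d3; ++d)
                    // src index [a, b, c, d] maps to dst index [a, c, b, d]
                    dst[((a * d2 + c) * d1 + b) * d3 + d] = src[((a * d1 + b) * d2 + c) * d3 + d];
    return dst;
}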
src = src + index; - int row = offset / head_num; - int col = offset % head_num; - for (i = 0; i < head_num; i++) { - sdata[row * (head_num + 1) + col] = src[offset]; - offset += seq_len; - row = offset / head_num; - col = offset % head_num; - } - - __syncthreads(); + int d0 = blockIdx.y; + int d1 = threadIdx.x; - index = batch * off0 + d0 * seq_len + d1; - offset = 0; - dst = dst + index; - for (i = 0; i < head_num; i++) { - dst[offset] = sdata[d1 * (head_num + 1) + i]; - offset += o_off1; - } -} - -void __global__ transpose201(__half* dst, - const __half* src, - const int off0, - const int i_off1, - const int head_num, - const int o_off1, - const int seq_len) -{ - int batch = blockIdx.x; - int d0 = blockIdx.y; - int d1 = threadIdx.x; extern __shared__ float sdata[]; - int i = 0; + int i = 0; // Read data into shared memory - int index = batch * off0 + d0 * i_off1; // d1*i_off2+d2 + int index = batch * off0 + d0 * i_off1; // d1*i_off2+d2 int offset = d1; - src = src + index; - int row = offset / head_num; - int col = offset % head_num; + src = src + index; + int row = offset / head_num; + int col = offset % head_num; for (i = 0; i < head_num; i++) { - sdata[row * (head_num + 1) + col] = __half2float(src[offset]); + sdata[row * (head_num + 1) + col] = type2float(src[offset]); offset += seq_len; row = offset / head_num; col = offset % head_num; @@ -332,75 +306,78 @@ void __global__ transpose201(__half* dst, __syncthreads(); - index = batch * off0 + d0 * seq_len + d1; + index = batch * off0 + d0 * seq_len + d1; offset = 0; - dst = dst + index; + dst = dst + index; for (i = 0; i < head_num; i++) { - dst[offset] = __float2half(sdata[d1 * (head_num + 1) + i]); + dst[offset] = float2type(sdata[d1 * (head_num + 1) + i]); offset += o_off1; } } -/*dim3 grid_shift(batch_size, head_num, seq_len); - dim3 block_shift(seq_len*2); - int off0=head_num*seq_len*seq_len; - int off1=seq_len*seq_len; */ -template -void __global__ relShiftBd(T* outMatrix, const T* inputMatrix, const int off0, const int off1, const int seq_len) -{ - int batch = blockIdx.x; //[0,7] - int head = blockIdx.y; //[0,11] - int row = blockIdx.z; //[0,127] - int col = threadIdx.x; //[0,255] - int input_index = (batch * off0 + head * off1 + row * seq_len) * 2 + col; - if (col >= seq_len || row != 0) { - T idata = inputMatrix[input_index]; - // int tmp_index=row*(2*seq_len-1)+row+col-seq_len; - int tmp_index = row * 2 * seq_len + col - seq_len; - int out_row = tmp_index / (seq_len * 2 - 1); - int out_col = tmp_index % (seq_len * 2 - 1); - if (out_col < seq_len) { - int out_index = batch * off0 + head * off1 + out_row * seq_len + out_col; - outMatrix[out_index] = idata; - } - } -} /*int threads=512; seq_dim1=threads/seq_len seq_dim2=seq_len/dimx dim3 grid_rel(batch_size, head_num, seq_dim2); dim3 block_rel(seq_dim1, seq_len);*/ -template<> -void __global__ -relShiftBd(__half* outMatrix, const __half* inputMatrix, const int off0, const int off1, const int seq_len) +template +void __global__ relShiftBd(T* outMatrix, const T* inputMatrix, const int off0, const int off1, const int seq_len) { - int batch = blockIdx.x; //[0,7] - int head = blockIdx.y; //[0,11] - int row = blockIdx.z * blockDim.x + threadIdx.x; //[0,127] - int col = threadIdx.y * 2; //[0,255] + using T2 = typename TypeConverter::Type; // half2 or bfloat162 + int batch = blockIdx.x; //[0,7] + int head = blockIdx.y; //[0,11] + int row = blockIdx.z * blockDim.x + threadIdx.x; //[0,127] + int col = threadIdx.y * 2; //[0,255] int input_index = (batch * off0 + head * off1 + row * 
seq_len) * 2 + col; int out_index; int out_row; int out_col; int tmp_index; - half2 idata; + T2 idata; if (col >= seq_len || row != 0) { - idata = ((half2*)inputMatrix)[input_index >> 1]; + idata = ((T2*)inputMatrix)[input_index >> 1]; // int tmp_index=row*(2*seq_len-1)+row+col-seq_len; tmp_index = row * 2 * seq_len + col - seq_len; - out_row = tmp_index / (seq_len * 2 - 1); - out_col = tmp_index % (seq_len * 2 - 1); + out_row = tmp_index / (seq_len * 2 - 1); + out_col = tmp_index % (seq_len * 2 - 1); if (out_col < seq_len) { - out_index = (batch * off0 + head * off1 + out_row * seq_len + out_col); - outMatrix[out_index] = __low2half(idata); + out_index = (batch * off0 + head * off1 + out_row * seq_len + out_col); + outMatrix[out_index] = idata.x; } tmp_index += 1; out_row = tmp_index / (seq_len * 2 - 1); out_col = tmp_index % (seq_len * 2 - 1); if (out_col < seq_len) { - out_index = (batch * off0 + head * off1 + out_row * seq_len + out_col); - outMatrix[out_index] = __high2half(idata); + out_index = (batch * off0 + head * off1 + out_row * seq_len + out_col); + outMatrix[out_index] = idata.y; + } + } +} + +/*dim3 grid_shift(batch_size, head_num, seq_len); + dim3 block_shift(seq_len*2); + int off0=head_num*seq_len*seq_len; + int off1=seq_len*seq_len; */ +template<> +void __global__ +relShiftBd(float* outMatrix, const float* inputMatrix, const int off0, const int off1, const int seq_len) +{ + int batch = blockIdx.x; //[0,7] + int head = blockIdx.y; //[0,11] + int row = blockIdx.z; //[0,127] + int col = threadIdx.x; //[0,255] + + int input_index = (batch * off0 + head * off1 + row * seq_len) * 2 + col; + if (col >= seq_len || row != 0) { + float idata = inputMatrix[input_index]; + // int tmp_index=row*(2*seq_len-1)+row+col-seq_len; + int tmp_index = row * 2 * seq_len + col - seq_len; + int out_row = tmp_index / (seq_len * 2 - 1); + int out_col = tmp_index % (seq_len * 2 - 1); + if (out_col < seq_len) { + int out_index = batch * off0 + head * off1 + out_row * seq_len + out_col; + outMatrix[out_index] = idata; } } } @@ -416,33 +393,33 @@ relShiftBd(__half* outMatrix, const __half* inputMatrix, const int off0, const i int voff2=size_per_head; int v_i_off1=head_num*size_per_head;*/ template -__global__ void calAttnScore_valueBuf(T* attn_score, - const T* ac, - const T* bd, - const T* ef, - const T* attn_mask, - const int off0, - const int off1, - const int seq_len, +__global__ void calAttnScore_valueBuf(T* attn_score, + const T* ac, + const T* bd, + const T* ef, + const T* attn_mask, + const int off0, + const int off1, + const int seq_len, const float p, - T* value_buf_trans, - const T* value_buf, - const int voff0, - const int v_i_off1, - const int v_o_off1, - const int voff2) + T* value_buf_trans, + const T* value_buf, + const int voff0, + const int v_i_off1, + const int v_o_off1, + const int voff2) { int batch = blockIdx.x; - int head = blockIdx.y; - int seq1 = blockIdx.z; - int seq2 = threadIdx.x; + int head = blockIdx.y; + int seq1 = blockIdx.z; + int seq2 = threadIdx.x; int offset = batch * off0 + head * off1 + seq1 * seq_len; - int index = offset + seq2; + int index = offset + seq2; int out_index; - T score; - T mask; - T large_value = -1e4; + T score; + T mask; + T large_value = -1e4; if (sizeof(T) == 4) { large_value = -1e30f; } @@ -451,18 +428,18 @@ __global__ void calAttnScore_valueBuf(T* attn_score, score = score * p; out_index = batch * off1 + seq1 * seq_len + seq2; - mask = attn_mask[out_index] * (large_value); - score = score + mask; + mask = attn_mask[out_index] * (large_value); + 
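// relShiftBd is the Transformer-XL / XLNet "relative shift": flatten the
// [seq_len, 2*seq_len] BD score matrix, drop its first seq_len entries, re-view the
// remainder as rows of length (2*seq_len - 1), and keep the first seq_len columns.
// The host reference below applies that re-indexing to one [seq_len, 2*seq_len]
// slice (the kernels above additionally loop over batch and head); demo_rel_shift
// is an illustrative name.
#include <vector>

std::vector<float> demo_rel_shift(const std::vector<float>& bd, int seq_len)
{
    std::vector<float> out(seq_len * seq_len, 0.0f);  // row-major [seq_len, seq_len]
    for (int row = 0; row < seq_len; ++row) {
        for (int col = 0; col < 2 * seq_len; ++col) {
            if (row == 0 && col < seq_len) {
                continue;  // these entries are discarded by the shift
            }
            const int tmp     = row * 2 * seq_len + col - seq_len;  // index after the drop
            const int out_row = tmp / (2 * seq_len - 1);
            const int out_col = tmp % (2 * seq_len - 1);
            if (out_col < seq_len) {
                out[out_row * seq_len + out_col] = bd[row * 2 * seq_len + col];
            }
        }
    }
    return out;
}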
score = score + mask; } // softmax(attn_score+offset,seq_len, seq2); __shared__ float s_sum, s_max; - float tmp = seq2 < seq_len ? score : large_value; - float max_val = blockReduceMax(tmp); + float tmp = seq2 < seq_len ? score : large_value; + float max_val = blockReduceMax(tmp); if (seq2 == 0) { s_max = max_val; } __syncthreads(); - float qk_tmp = seq2 < seq_len ? __expf((float)(tmp - s_max)) : 0.0f; + float qk_tmp = seq2 < seq_len ? __expf((float)(tmp - s_max)) : 0.0f; float sum_val = blockReduceSum(qk_tmp); __syncthreads(); if (seq2 == 0) { @@ -476,52 +453,55 @@ __global__ void calAttnScore_valueBuf(T* attn_score, offset = seq2; while (offset < voff2) { - out_index = batch * voff0 + head * v_o_off1 + seq1 * voff2 + offset; - index = batch * voff0 + seq1 * v_i_off1 + head * voff2 + offset; + out_index = batch * voff0 + head * v_o_off1 + seq1 * voff2 + offset; + index = batch * voff0 + seq1 * v_i_off1 + head * voff2 + offset; value_buf_trans[out_index] = value_buf[index]; offset += seq_len; } } -void __global__ calAttnScore_valueBuf_small(__half* attn_score, - const __half* ac, - const __half* bd, - const __half* ef, - const __half* attn_mask, - const int off0, - const int off1, - const int seq_len, - int n_seq1, +// Applied to half and bfloat16 +template +void __global__ calAttnScore_valueBuf_small(T* attn_score, + const T* ac, + const T* bd, + const T* ef, + const T* attn_mask, + const int off0, + const int off1, + const int seq_len, + int n_seq1, const float p, - __half* value_buf_trans, - const __half* value_buf, - const int voff0, - const int v_i_off1, - const int v_o_off1, - const int voff2) + T* value_buf_trans, + const T* value_buf, + const int voff0, + const int v_i_off1, + const int v_o_off1, + const int voff2) { - int lid = lane_id(); - int tid = threadIdx.x; - int wid = tid / 32; + using T2 = typename TypeConverter::Type; // half2 or bfloat162 + int lid = lane_id(); + int tid = threadIdx.x; + int wid = tid / 32; int seq2 = lid << 1; int batch = blockIdx.x; - int head = blockIdx.y; - int seq1 = blockIdx.z * n_seq1 + wid; + int head = blockIdx.y; + int seq1 = blockIdx.z * n_seq1 + wid; - int offset = batch * off0 + head * off1 + seq1 * seq_len; - int index = (offset + seq2) >> 1; - int out_index; + int offset = batch * off0 + head * off1 + seq1 * seq_len; + int index = (offset + seq2) >> 1; + int out_index; float2 tmp1, tmp2; // Data prepare section if (seq2 < seq_len) { - tmp1 = __half22float2(((half2*)ac)[index]); - tmp2 = __half22float2(((half2*)bd)[index]); + tmp1 = type22float2(((T2*)ac)[index]); + tmp2 = type22float2(((T2*)bd)[index]); tmp1.x += tmp2.x; tmp1.y += tmp2.y; // tmp1=__hadd2(tmp1,tmp2); - tmp2 = __half22float2(((half2*)ef)[index]); + tmp2 = type22float2(((T2*)ef)[index]); tmp1.x += tmp2.x; tmp1.y += tmp2.y; @@ -530,7 +510,7 @@ void __global__ calAttnScore_valueBuf_small(__half* attn_score, tmp1.y = tmp1.y * p; out_index = (batch * off1 + seq1 * seq_len + seq2) >> 1; - tmp2 = __half22float2(((half2*)attn_mask)[out_index]); + tmp2 = type22float2(((T2*)attn_mask)[out_index]); tmp1.x = tmp1.x + -1e4 * tmp2.x; tmp1.y = tmp1.y + -1e4 * tmp2.y; @@ -549,7 +529,7 @@ void __global__ calAttnScore_valueBuf_small(__half* attn_score, /// normalize the input tmp1.x = seq2 < seq_len ? __expf((float)(tmp1.x - tmp)) : 0.0f; tmp1.y = seq2 < seq_len ? 
__expf((float)(tmp1.y - tmp)) : 0.0f; - tmp = tmp1.x + tmp1.y; + tmp = tmp1.x + tmp1.y; /// get sum of the normalized value for (int mask = 16; mask > 0; mask >>= 1) { tmp = tmp + __shfl_xor_sync(FINAL_MASK, tmp, mask, 32); @@ -561,62 +541,66 @@ void __global__ calAttnScore_valueBuf_small(__half* attn_score, /// set the value if (seq2 < seq_len) { - tmp1.x = tmp1.x / tmp; - tmp1.y = tmp1.y / tmp; - ((half2*)attn_score)[index] = __float22half2_rn(tmp1); + tmp1.x = tmp1.x / tmp; + tmp1.y = tmp1.y / tmp; + ((T2*)attn_score)[index] = float22type2(tmp1); } // value_buf section offset = seq2; while (offset < voff2) { index = (batch * voff0 + seq1 * v_i_off1 + head * voff2 + offset) >> 1; - half2 v = ((half2*)value_buf)[index]; + T2 v = ((T2*)value_buf)[index]; - out_index = (batch * voff0 + head * v_o_off1 + seq1 * voff2 + offset) >> 1; - ((half2*)value_buf_trans)[out_index] = v; + out_index = (batch * voff0 + head * v_o_off1 + seq1 * voff2 + offset) >> 1; + ((T2*)value_buf_trans)[out_index] = v; offset += seq_len; } } -void __global__ calAttnScore_valueBuf_large(__half* attn_score, - const __half* ac, - const __half* bd, - const __half* ef, - const __half* attn_mask, - const int off0, - const int off1, - const int seq_len, + +// Applied to half and bfloat16 +template +void __global__ calAttnScore_valueBuf_large(T* attn_score, + const T* ac, + const T* bd, + const T* ef, + const T* attn_mask, + const int off0, + const int off1, + const int seq_len, const float p, - __half* value_buf_trans, - const __half* value_buf, - const int voff0, - const int v_i_off1, - const int v_o_off1, - const int voff2) + T* value_buf_trans, + const T* value_buf, + const int voff0, + const int v_i_off1, + const int v_o_off1, + const int voff2) { + using T2 = typename TypeConverter::Type; // half2 or bfloat162 int batch = blockIdx.x; - int head = blockIdx.y; - int seq1 = blockIdx.z; + int head = blockIdx.y; + int seq1 = blockIdx.z; - int lid = lane_id(); - int tid = threadIdx.x; - int wid = tid / 32; + int lid = lane_id(); + int tid = threadIdx.x; + int wid = tid / 32; int seq2 = tid << 1; - int offset = batch * off0 + head * off1 + seq1 * seq_len; - int index = (offset + seq2) >> 1; - int out_index; - float2 tmp1, tmp2; + int offset = batch * off0 + head * off1 + seq1 * seq_len; + int index = (offset + seq2) >> 1; + int out_index; + float2 tmp1, tmp2; __shared__ float sdata[32]; __shared__ float s_max; __shared__ float s_sum; // Data prepare section if (seq2 < seq_len) { - tmp1 = __half22float2(((half2*)ac)[index]); - tmp2 = __half22float2(((half2*)bd)[index]); + tmp1 = type22float2(((T2*)ac)[index]); + tmp2 = type22float2(((T2*)bd)[index]); tmp1.x += tmp2.x; tmp1.y += tmp2.y; // tmp1=__hadd2(tmp1,tmp2); - tmp2 = __half22float2(((half2*)ef)[index]); + tmp2 = type22float2(((T2*)ef)[index]); tmp1.x += tmp2.x; tmp1.y += tmp2.y; @@ -625,7 +609,7 @@ void __global__ calAttnScore_valueBuf_large(__half* attn_score, tmp1.y = tmp1.y * p; out_index = (batch * off1 + seq1 * seq_len + seq2) >> 1; - tmp2 = __half22float2(((half2*)attn_mask)[out_index]); + tmp2 = type22float2(((T2*)attn_mask)[out_index]); tmp1.x = tmp1.x + -1e4 * tmp2.x; tmp1.y = tmp1.y + -1e4 * tmp2.y; @@ -662,7 +646,7 @@ void __global__ calAttnScore_valueBuf_large(__half* attn_score, tmp1.x = seq2 < seq_len ? __expf((float)(tmp1.x - s_max)) : 0.0f; tmp1.y = seq2 < seq_len ? 
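// The _small and _large score kernels reduce each row's max and sum with butterfly
// __shfl_xor_sync exchanges inside a warp (the _large variant then combines warps
// through shared memory). The standalone device helpers below show that shuffle
// pattern; DEMO_FULL_MASK and the demo_* names are illustrative.
#define DEMO_FULL_MASK 0xffffffffu

__inline__ __device__ float demo_warp_reduce_max(float val)
{
    for (int offset = 16; offset > 0; offset >>= 1) {
        val = fmaxf(val, __shfl_xor_sync(DEMO_FULL_MASK, val, offset, 32));
    }
    return val;  // every lane now holds the warp-wide maximum
}

__inline__ __device__ float demo_warp_reduce_sum(float val)
{
    for (int offset = 16; offset > 0; offset >>= 1) {
        val += __shfl_xor_sync(DEMO_FULL_MASK, val, offset, 32);
    }
    return val;  // every lane now holds the warp-wide sum
}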
__expf((float)(tmp1.y - s_max)) : 0.0f; - tmp = tmp1.x + tmp1.y; + tmp = tmp1.x + tmp1.y; /// get sum of the normalized value for (int mask = 16; mask > 0; mask /= 2) { @@ -690,124 +674,132 @@ void __global__ calAttnScore_valueBuf_large(__half* attn_score, /// set the value if (seq2 < seq_len) { - tmp1.x = tmp1.x / s_sum; - tmp1.y = tmp1.y / s_sum; - ((half2*)attn_score)[index] = __float22half2_rn(tmp1); + tmp1.x = tmp1.x / s_sum; + tmp1.y = tmp1.y / s_sum; + ((T2*)attn_score)[index] = float22type2(tmp1); } // value_buf section offset = seq2; while (offset < voff2) { index = (batch * voff0 + seq1 * v_i_off1 + head * voff2 + offset) >> 1; - half2 v = ((half2*)value_buf)[index]; + T2 v = ((T2*)value_buf)[index]; - out_index = (batch * voff0 + head * v_o_off1 + seq1 * voff2 + offset) >> 1; - ((half2*)value_buf_trans)[out_index] = v; + out_index = (batch * voff0 + head * v_o_off1 + seq1 * voff2 + offset) >> 1; + ((T2*)value_buf_trans)[out_index] = v; offset += seq_len; } } -// dim3 grid_trans_v(batch_size,seq_len, head_num); -// dim3 block_trans_v(size_per_head); +// Applied to half or bfloat16 template __global__ void transpose102_v2(T* dst, const T* src, const int off0, const int i_off1, const int o_off1, const int off2) { + using T2 = typename TypeConverter::Type; // half2 or bfloat162 int x[4] = {0}; - x[0] = blockIdx.x; - x[1] = threadIdx.x / off2; - x[2] = blockIdx.y; //[0,128] seq_len - x[3] = threadIdx.x % off2; //[0,31] size_per_head + x[0] = blockIdx.x; //[0,7] batch_size + x[1] = threadIdx.x * 2 / off2; // head_num + x[2] = blockIdx.y; // seq_len + x[3] = threadIdx.x * 2 % off2; //[0,63] size_per_head - T tmp; if (x[3] < off2) { - int input_index = x[0] * off0 + x[1] * i_off1 + x[2] * off2 + x[3]; // [batch, 0, 1, 2]=[d0,d1,d2,d3] - tmp = src[input_index]; - int out_index = x[0] * off0 + x[2] * o_off1 + x[1] * off2 + x[3]; // [batch, 1, 0, 2]=[d0,d2,d1,d3] - - dst[out_index] = tmp; + T2 tmp; + int in_index = (x[0] * off0 + x[1] * i_off1 + x[2] * off2 + x[3]) >> 1; // [batch, 0, 1, 2]=[d0,d1,d2,d3] + tmp = ((T2*)src)[in_index]; + int out_index = (x[0] * off0 + x[2] * o_off1 + x[1] * off2 + x[3]) >> 1; // [batch, 1, 0, 2]=[d0,d2,d1,d3] + ((T2*)dst)[out_index] = tmp; } } +// dim3 grid_trans_v(batch_size,seq_len, head_num); +// dim3 block_trans_v(size_per_head); template<> __global__ void -transpose102_v2(__half* dst, const __half* src, const int off0, const int i_off1, const int o_off1, const int off2) +transpose102_v2(float* dst, const float* src, const int off0, const int i_off1, const int o_off1, const int off2) { int x[4] = {0}; - x[0] = blockIdx.x; //[0,7] batch_size - x[1] = threadIdx.x * 2 / off2; // head_num - x[2] = blockIdx.y; // seq_len - x[3] = threadIdx.x * 2 % off2; //[0,63] size_per_head + x[0] = blockIdx.x; + x[1] = threadIdx.x / off2; + x[2] = blockIdx.y; //[0,128] seq_len + x[3] = threadIdx.x % off2; //[0,31] size_per_head + float tmp; if (x[3] < off2) { - half2 tmp; - int in_index = (x[0] * off0 + x[1] * i_off1 + x[2] * off2 + x[3]) >> 1; // [batch, 0, 1, 2]=[d0,d1,d2,d3] - tmp = ((half2*)src)[in_index]; - int out_index = (x[0] * off0 + x[2] * o_off1 + x[1] * off2 + x[3]) >> 1; // [batch, 1, 0, 2]=[d0,d2,d1,d3] - ((half2*)dst)[out_index] = tmp; + int input_index = x[0] * off0 + x[1] * i_off1 + x[2] * off2 + x[3]; // [batch, 0, 1, 2]=[d0,d1,d2,d3] + tmp = src[input_index]; + int out_index = x[0] * off0 + x[2] * o_off1 + x[1] * off2 + x[3]; // [batch, 1, 0, 2]=[d0,d2,d1,d3] + + dst[out_index] = tmp; } } +// Applied to half and bfloat16 template __global__ void 
addBias_layerNorm(T* out, const T* input, const T* bias, const T* gamma, const T* beta, int m, int n, float epsilon) { - int tid = threadIdx.x; - + using T2 = typename TypeConverter::Type; // half2 or bfloat162 + int tid = threadIdx.x; __shared__ float s_mean; __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; + float mean = 0.0f; + float variance = 0.0f; + float2 local_out_fp2; + + T2* out_ptr = (T2*)out; + const T2* input_ptr = (const T2*)input; + const T2* bias_ptr = (const T2*)bias; // blockIdx.x * n + i + const T2* gamma_ptr = (const T2*)gamma; + const T2* beta_ptr = (const T2*)beta; float local_out = 0.0f; - int id = blockIdx.x * n + tid; - local_out += (float)(input[id]); - local_out += (float)(__ldg(&bias[id])); + int id = (blockIdx.x * n + tid * 2) >> 1; + local_out_fp2 = type22float2(hadd2(input_ptr[id], ldg(&bias_ptr[id]))); + + local_out += local_out_fp2.x; + local_out += local_out_fp2.y; mean = blockReduceSum(local_out); if (threadIdx.x == 0) { s_mean = mean / n; } __syncthreads(); - variance = blockReduceSum((local_out - s_mean) * (local_out - s_mean)); - if (threadIdx.x == 0) { - s_variance = variance / n + epsilon; - } + + variance = (local_out_fp2.x - s_mean) * (local_out_fp2.x - s_mean); + variance += (local_out_fp2.y - s_mean) * (local_out_fp2.y - s_mean); + variance = blockReduceSum(variance); + if (threadIdx.x == 0) + s_variance = rsqrtf(variance / n + epsilon); __syncthreads(); - out[id] = - (T)(((local_out - s_mean) * rsqrtf(s_variance)) * (float)(__ldg(&gamma[tid])) + (float)(__ldg(&beta[tid]))); + float2 gamma_val = type22float2(ldg(&gamma_ptr[tid])); + float2 beta_val = type22float2(ldg(&beta_ptr[tid])); + local_out_fp2.x = (local_out_fp2.x - s_mean) * s_variance * gamma_val.x + beta_val.x; + local_out_fp2.y = (local_out_fp2.y - s_mean) * s_variance * gamma_val.y + beta_val.y; + out_ptr[id] = float22type2(local_out_fp2); } template<> -__global__ void addBias_layerNorm(__half* out, - const __half* input, - const __half* bias, - const __half* gamma, - const __half* beta, - int m, - int n, - float epsilon) +__global__ void addBias_layerNorm(float* out, + const float* input, + const float* bias, + const float* gamma, + const float* beta, + int m, + int n, + float epsilon) { - int tid = threadIdx.x; + __shared__ float s_mean; __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; - float2 local_out_fp2; - - half2* out_ptr = (half2*)out; - const half2* input_ptr = (const half2*)input; - const half2* bias_ptr = (const half2*)bias; // blockIdx.x * n + i - const half2* gamma_ptr = (const half2*)gamma; - const half2* beta_ptr = (const half2*)beta; + float mean = 0.0f; + float variance = 0.0f; float local_out = 0.0f; - int id = (blockIdx.x * n + tid * 2) >> 1; - local_out_fp2 = __half22float2(__hadd2(input_ptr[id], __ldg(&bias_ptr[id]))); - - local_out += local_out_fp2.x; - local_out += local_out_fp2.y; + int id = blockIdx.x * n + tid; + local_out += (float)(input[id]); + local_out += (float)(__ldg(&bias[id])); mean = blockReduceSum(local_out); if (threadIdx.x == 0) { @@ -815,19 +807,14 @@ __global__ void addBias_layerNorm(__half* out, } __syncthreads(); - variance = (local_out_fp2.x - s_mean) * (local_out_fp2.x - s_mean); - variance += (local_out_fp2.y - s_mean) * (local_out_fp2.y - s_mean); - variance = blockReduceSum(variance); + float diff = local_out - s_mean; + variance = blockReduceSum(diff * diff); if (threadIdx.x == 0) { s_variance = rsqrtf(variance / n + epsilon); } __syncthreads(); - float2 gamma_val = 
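// addBias_layerNorm fuses the bias add with layer normalization: per row of length n
// it computes the mean and variance of (input + bias), then writes
// (x - mean) * rsqrt(variance / n + epsilon) * gamma + beta; the packed template above
// processes two channels per thread. The single-row, float-only host reference below
// spells out that arithmetic (demo_* name is illustrative).
#include <cmath>
#include <vector>

std::vector<float> demo_add_bias_layer_norm_row(const std::vector<float>& input,
                                                const std::vector<float>& bias,
                                                const std::vector<float>& gamma,
                                                const std::vector<float>& beta,
                                                float epsilon)
{
    const int n = static_cast<int>(input.size());
    std::vector<float> x(n), out(n);
    float mean = 0.0f;
    for (int i = 0; i < n; ++i) {
        x[i] = input[i] + bias[i];
        mean += x[i];
    }
    mean /= n;
    float variance = 0.0f;
    for (int i = 0; i < n; ++i) {
        variance += (x[i] - mean) * (x[i] - mean);
    }
    const float inv_std = 1.0f / std::sqrt(variance / n + epsilon);
    for (int i = 0; i < n; ++i) {
        out[i] = (x[i] - mean) * inv_std * gamma[i] + beta[i];
    }
    return out;
}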
__half22float2(__ldg(&gamma_ptr[tid])); - float2 beta_val = __half22float2(__ldg(&beta_ptr[tid])); - local_out_fp2.x = (local_out_fp2.x - s_mean) * s_variance * gamma_val.x + beta_val.x; - local_out_fp2.y = (local_out_fp2.y - s_mean) * s_variance * gamma_val.y + beta_val.y; - out_ptr[id] = __float22half2_rn(local_out_fp2); + out[id] = (float)(((local_out - s_mean) * s_variance) * (float)(__ldg(&gamma[tid])) + (float)(__ldg(&beta[tid]))); } /*width=hidden_dim_ff; @@ -839,18 +826,18 @@ template __global__ void gelu_bias_loop(T* src, const T* bias, int width, int height) { int batch = blockIdx.x; - int x = blockIdx.y; - int y = threadIdx.x; + int x = blockIdx.y; + int y = threadIdx.x; if (x < height) { - int index = batch * width * height + x * width; + int index = batch * width * height + x * width; float v_src; float v_bias; float v; for (; y < width; y = y + blockDim.x) { v_bias = bias[y]; - v_src = src[index + y]; - v = v_src + v_bias; + v_src = src[index + y]; + v = v_src + v_bias; src[index + y] = (T)(0.5f * v * (1.0f + tanhf(0.79788456f * (v + 0.044715f * v * v * v)))); } @@ -860,22 +847,22 @@ template<> __global__ void gelu_bias_loop(__half* src, const __half* bias, int width, int height) { int batch = blockIdx.x; - int x = blockIdx.y; - int y = threadIdx.x * 2; + int x = blockIdx.y; + int y = threadIdx.x * 2; if (x < height) { - int index = batch * width * height + x * width; - half2 v_src; - half2 v_bias; - half2 v; + int index = batch * width * height + x * width; + half2 v_src; + half2 v_bias; + half2 v; float2 t; for (; y < width; y = y + blockDim.x * 2) { v_bias = ((half2*)bias)[y >> 1]; - v_src = ((half2*)src)[(index + y) >> 1]; - v = __hadd2(v_src, v_bias); - t = __half22float2(v); - t.x = (0.5f * t.x * (1.0f + tanhf(0.79788456f * (t.x + 0.044715f * t.x * t.x * t.x)))); - t.y = (0.5f * t.y * (1.0f + tanhf(0.79788456f * (t.y + 0.044715f * t.y * t.y * t.y)))); + v_src = ((half2*)src)[(index + y) >> 1]; + v = __hadd2(v_src, v_bias); + t = __half22float2(v); + t.x = (0.5f * t.x * (1.0f + tanhf(0.79788456f * (t.x + 0.044715f * t.x * t.x * t.x)))); + t.y = (0.5f * t.y * (1.0f + tanhf(0.79788456f * (t.y + 0.044715f * t.y * t.y * t.y)))); ((half2*)src)[(index + y) >> 1] = __float22half2_rn(t); } @@ -885,29 +872,29 @@ __global__ void gelu_bias_loop(__half* src, const __half* bias, int width, int h /*********************Invoke Functions***********************/ template -void invokePrepareMatrixes(int batch_size, - int seq_len, - int hidden_dim, - int size_per_head, - T* q_buf, - T* q_buf_bd, - T* q_buf_ef, - T* k_buf, - T* k_buf_bd, - T* k_buf_ef, - T* query_buf, - T* key_buf, - T* k_head_r, - T* attr_seg_embed, - T* attr_bias_Q_w, - T* attr_bias_Q_r, - T* attr_bias_Q_s, +void invokePrepareMatrixes(int batch_size, + int seq_len, + int hidden_dim, + int size_per_head, + T* q_buf, + T* q_buf_bd, + T* q_buf_ef, + T* k_buf, + T* k_buf_bd, + T* k_buf_ef, + T* query_buf, + T* key_buf, + T* k_head_r, + T* attr_seg_embed, + T* attr_bias_Q_w, + T* attr_bias_Q_r, + T* attr_bias_Q_s, cudaStream_t stream) { - int off0 = seq_len * hidden_dim; // seq_len*head_num*size_per_head - int i_off1 = hidden_dim; // head_num*size_per_head + int off0 = seq_len * hidden_dim; // seq_len*head_num*size_per_head + int i_off1 = hidden_dim; // head_num*size_per_head int o_off1 = seq_len * size_per_head; - int off2 = size_per_head; + int off2 = size_per_head; dim3 grid(seq_len, batch_size); dim3 block(next_pow2(hidden_dim) / numPerThread()); @@ -940,10 +927,10 @@ void invokeTranspose102( // dim3 
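The refactored kernels above (calAttnScore_valueBuf_large, transpose102_v2, addBias_layerNorm) all follow the same pattern: the primary template works on a packed two-element type so one kernel body serves both half and __nv_bfloat16, while float keeps its own scalar template<> specialization. The sketch below restates that pattern in isolation; TypeConverterSketch and the conversion helpers are local stand-ins for the trait and helpers pulled in from the bfloat16 fallback header, and addBiasVec2 is a toy kernel, not code from this patch.

#include <cuda_fp16.h>
#ifdef ENABLE_BF16
#include <cuda_bf16.h>
#endif

// Map a scalar type to its packed two-element counterpart.
template<typename T> struct TypeConverterSketch;
template<> struct TypeConverterSketch<half> { using Type = half2; };
#ifdef ENABLE_BF16
template<> struct TypeConverterSketch<__nv_bfloat16> { using Type = __nv_bfloat162; };
#endif

// fp32 promotion helpers for the packed types.
template<typename T2> __device__ inline float2 to_float2(T2 v);
template<typename T2> __device__ inline T2     from_float2(float2 v);
template<> __device__ inline float2 to_float2<half2>(half2 v)    { return __half22float2(v); }
template<> __device__ inline half2  from_float2<half2>(float2 v) { return __float22half2_rn(v); }
#ifdef ENABLE_BF16
template<> __device__ inline float2         to_float2<__nv_bfloat162>(__nv_bfloat162 v) { return __bfloat1622float2(v); }
template<> __device__ inline __nv_bfloat162 from_float2<__nv_bfloat162>(float2 v)       { return __float22bfloat162_rn(v); }
#endif

// Toy kernel: add a per-column bias to an (m x n) row-major matrix, two scalars per
// thread, launched as addBiasVec2<half><<<m, n / 2, 0, stream>>>(out, in, bias, n).
// n is assumed even; arithmetic is done in fp32, as in the kernels above.
template<typename T>
__global__ void addBiasVec2(T* out, const T* in, const T* bias, int n)
{
    using T2 = typename TypeConverterSketch<T>::Type;  // half2 or __nv_bfloat162
    int    id = (blockIdx.x * n + threadIdx.x * 2) >> 1;  // packed element index
    float2 v  = to_float2(((const T2*)in)[id]);
    float2 b  = to_float2(((const T2*)bias)[threadIdx.x]);
    v.x += b.x;
    v.y += b.y;
    ((T2*)out)[id] = from_float2<T2>(v);
}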
block_trans_v(seq_len);//__half dim3 block_trans_v(seq_len, 2 / (numPerThread())); - int toff0 = head_num * seq_len * 2; + int toff0 = head_num * seq_len * 2; int ti_off1 = seq_len * 2; int to_off1 = head_num * 2; - int toff2 = 2; + int toff2 = 2; transpose102<<>>( qk_buf_ef_trans, qk_buf_ef, toff0, ti_off1, to_off1, toff2); @@ -955,29 +942,21 @@ void invokeTranspose201( { dim3 grid_trans2(batch_size, seq_len); dim3 block_trans2(seq_len); - int t2_off0 = seq_len * seq_len * head_num; - int t2_i_off1 = seq_len * head_num; - int t2_o_off1 = seq_len * seq_len; - int t2_i_off2 = head_num; - int t2_o_off2 = seq_len; + int t2_off0 = seq_len * seq_len * head_num; + int t2_i_off1 = seq_len * head_num; + int t2_o_off1 = seq_len * seq_len; + int t2_i_off2 = head_num; + int t2_o_off2 = seq_len; transpose201<<>>( qk_buf_ef_seg_trans, qk_buf_ef_seg, t2_off0, t2_i_off1, t2_i_off2, t2_o_off1, t2_o_off2); } +// applied to half and bfloat16 template void blockRelShiftBd(dim3& grid, dim3& block, int batch_size, int head_num, int seq_len) { - grid.x = batch_size; - grid.y = head_num; - grid.z = seq_len; - - block.x = seq_len * 2; -} -template<> -void blockRelShiftBd(dim3& grid, dim3& block, int batch_size, int head_num, int seq_len) -{ - int threads = 512; + int threads = 512; int seq_dim1 = threads / seq_len; int seq_dim2 = seq_len / seq_dim1; @@ -989,6 +968,16 @@ void blockRelShiftBd(dim3& grid, dim3& block, int batch_size, int head_num block.y = seq_len; } +template<> +void blockRelShiftBd(dim3& grid, dim3& block, int batch_size, int head_num, int seq_len) +{ + grid.x = batch_size; + grid.y = head_num; + grid.z = seq_len; + + block.x = seq_len * 2; +} + template void invokeRelShiftBd(int batch_size, int head_num, int seq_len, T* qk_buf_bd_shift, T* qk_buf_bd, cudaStream_t stream) { @@ -1003,70 +992,27 @@ void invokeRelShiftBd(int batch_size, int head_num, int seq_len, T* qk_buf_bd_sh } template -void invokeCalAttnScore(int batch_size, - int head_num, - int seq_len, - int size_per_head, - float q_scaling, - T* attn_score, - T* qk_buf, - T* qk_buf_bd_shift, - T* qk_buf_ef_seg_trans, - T* attn_mask, - T* value_buf_trans, - T* value_buf, +void invokeCalAttnScore(int batch_size, + int head_num, + int seq_len, + int size_per_head, + float q_scaling, + T* attn_score, + T* qk_buf, + T* qk_buf_bd_shift, + T* qk_buf_ef_seg_trans, + T* attn_mask, + T* value_buf_trans, + T* value_buf, cudaStream_t stream) { - int off0 = head_num * seq_len * seq_len; - int off1 = seq_len * seq_len; - float p = 1 / ((pow(size_per_head, 0.5)) * q_scaling); - - int voff0 = head_num * seq_len * size_per_head; - int v_o_off1 = seq_len * size_per_head; - int voff2 = size_per_head; - int v_i_off1 = head_num * size_per_head; - - dim3 grid_score(batch_size, head_num, seq_len); - dim3 block_score(next_pow2(seq_len)); - calAttnScore_valueBuf<<>>(attn_score, - qk_buf, - qk_buf_bd_shift, - qk_buf_ef_seg_trans, - attn_mask, - off0, - off1, - seq_len, - p, - value_buf_trans, - value_buf, - voff0, - v_i_off1, - v_o_off1, - voff2); -} + int off0 = head_num * seq_len * seq_len; + int off1 = seq_len * seq_len; + float p = 1 / ((pow(size_per_head, 0.5)) * q_scaling); -template<> -void invokeCalAttnScore(int batch_size, - int head_num, - int seq_len, - int size_per_head, - float q_scaling, - half* attn_score, - half* qk_buf, - half* qk_buf_bd_shift, - half* qk_buf_ef_seg_trans, - half* attn_mask, - half* value_buf_trans, - half* value_buf, - cudaStream_t stream) -{ - int off0 = head_num * seq_len * seq_len; - int off1 = seq_len * seq_len; - float p = 1 / 
((pow(size_per_head, 0.5)) * q_scaling); - - int voff0 = head_num * seq_len * size_per_head; + int voff0 = head_num * seq_len * size_per_head; int v_o_off1 = seq_len * size_per_head; - int voff2 = size_per_head; + int voff2 = size_per_head; int v_i_off1 = head_num * size_per_head; if (seq_len <= 32) { dim3 grid_score(batch_size, head_num, 2); @@ -1132,6 +1078,49 @@ void invokeCalAttnScore(int batch_size, } } +template<> +void invokeCalAttnScore(int batch_size, + int head_num, + int seq_len, + int size_per_head, + float q_scaling, + float* attn_score, + float* qk_buf, + float* qk_buf_bd_shift, + float* qk_buf_ef_seg_trans, + float* attn_mask, + float* value_buf_trans, + float* value_buf, + cudaStream_t stream) +{ + int off0 = head_num * seq_len * seq_len; + int off1 = seq_len * seq_len; + float p = 1 / ((pow(size_per_head, 0.5)) * q_scaling); + + int voff0 = head_num * seq_len * size_per_head; + int v_o_off1 = seq_len * size_per_head; + int voff2 = size_per_head; + int v_i_off1 = head_num * size_per_head; + + dim3 grid_score(batch_size, head_num, seq_len); + dim3 block_score(next_pow2(seq_len)); + calAttnScore_valueBuf<<>>(attn_score, + qk_buf, + qk_buf_bd_shift, + qk_buf_ef_seg_trans, + attn_mask, + off0, + off1, + seq_len, + p, + value_buf_trans, + value_buf, + voff0, + v_i_off1, + v_o_off1, + voff2); +} + template void invokeTranspose102v2( int batch_size, int seq_len, int head_num, int size_per_head, T* attn_vec_trans, T* attn_vec, cudaStream_t stream) @@ -1139,23 +1128,23 @@ void invokeTranspose102v2( dim3 grid_trans_v(batch_size, seq_len); dim3 block_trans_v(head_num * size_per_head / numPerThread()); - int off0 = head_num * seq_len * size_per_head; + int off0 = head_num * seq_len * size_per_head; int i_off1 = seq_len * size_per_head; int o_off1 = head_num * size_per_head; - int off2 = size_per_head; + int off2 = size_per_head; transpose102_v2<<>>(attn_vec_trans, attn_vec, off0, i_off1, o_off1, off2); } template -void invokeAddResidualLayerNorm(int batch_size, - int seq_len, - int hidden_dim, - T* attn_layernorm, - T* attn_out, - const T* in_tensor, - const T* attr_layernorm_gamma, - const T* attr_layernorm_beta, +void invokeAddResidualLayerNorm(int batch_size, + int seq_len, + int hidden_dim, + T* attn_layernorm, + T* attn_out, + const T* in_tensor, + const T* attr_layernorm_gamma, + const T* attr_layernorm_beta, cudaStream_t stream) { dim3 grid(batch_size * seq_len); @@ -1180,124 +1169,208 @@ void invokeGelu( } /*********************The explicit instantiation part***********************/ -template void invokePrepareMatrixes(int batch_size, - int seq_len, - int hidden_dim, - int size_per_head, - float* q_buf, - float* q_buf_bd, - float* q_buf_ef, - float* k_buf, - float* k_buf_bd, - float* k_buf_ef, - float* query_buf, - float* key_buf, - float* k_head_r, - float* attr_seg_embed, - float* attr_bias_Q_w, - float* attr_bias_Q_r, - float* attr_bias_Q_s, +template void invokePrepareMatrixes(int batch_size, + int seq_len, + int hidden_dim, + int size_per_head, + float* q_buf, + float* q_buf_bd, + float* q_buf_ef, + float* k_buf, + float* k_buf_bd, + float* k_buf_ef, + float* query_buf, + float* key_buf, + float* k_head_r, + float* attr_seg_embed, + float* attr_bias_Q_w, + float* attr_bias_Q_r, + float* attr_bias_Q_s, cudaStream_t stream); -template void invokePrepareMatrixes(int batch_size, - int seq_len, - int hidden_dim, - int size_per_head, - half* q_buf, - half* q_buf_bd, - half* q_buf_ef, - half* k_buf, - half* k_buf_bd, - half* k_buf_ef, - half* query_buf, - half* key_buf, - half* 
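For reference, the score scaling used by both invokeCalAttnScore paths above is the usual softmax scale 1/sqrt(size_per_head), multiplied by the user-supplied q_scaling, and the float path launches one block per (batch, head, query position) with next_pow2(seq_len) threads. A small host-side restatement follows; next_pow2_sketch is an illustrative stand-in for the helper used in this file.

#include <cmath>

// Round up to the next power of two so one block covers every key position.
static int next_pow2_sketch(int x)
{
    int p = 1;
    while (p < x) {
        p <<= 1;
    }
    return p;
}

// p = 1 / (sqrt(size_per_head) * q_scaling), applied to the attention logits before softmax.
static float attn_score_scale(int size_per_head, float q_scaling)
{
    return 1.0f / (sqrtf((float)size_per_head) * q_scaling);
}

// float path launch configuration:
//   dim3 grid_score(batch_size, head_num, seq_len);
//   dim3 block_score(next_pow2_sketch(seq_len));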
k_head_r, - half* attr_seg_embed, - half* attr_bias_Q_w, - half* attr_bias_Q_r, - half* attr_bias_Q_s, +template void invokePrepareMatrixes(int batch_size, + int seq_len, + int hidden_dim, + int size_per_head, + half* q_buf, + half* q_buf_bd, + half* q_buf_ef, + half* k_buf, + half* k_buf_bd, + half* k_buf_ef, + half* query_buf, + half* key_buf, + half* k_head_r, + half* attr_seg_embed, + half* attr_bias_Q_w, + half* attr_bias_Q_r, + half* attr_bias_Q_s, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokePrepareMatrixes<__nv_bfloat16>(int batch_size, + int seq_len, + int hidden_dim, + int size_per_head, + __nv_bfloat16* q_buf, + __nv_bfloat16* q_buf_bd, + __nv_bfloat16* q_buf_ef, + __nv_bfloat16* k_buf, + __nv_bfloat16* k_buf_bd, + __nv_bfloat16* k_buf_ef, + __nv_bfloat16* query_buf, + __nv_bfloat16* key_buf, + __nv_bfloat16* k_head_r, + __nv_bfloat16* attr_seg_embed, + __nv_bfloat16* attr_bias_Q_w, + __nv_bfloat16* attr_bias_Q_r, + __nv_bfloat16* attr_bias_Q_s, + cudaStream_t stream); +#endif + template void invokeTranspose102( int batch_size, int seq_len, int head_num, float* qk_buf_ef_trans, float* qk_buf_ef, cudaStream_t stream); template void invokeTranspose102( int batch_size, int seq_len, int head_num, half* qk_buf_ef_trans, half* qk_buf_ef, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeTranspose102<__nv_bfloat16>(int batch_size, + int seq_len, + int head_num, + __nv_bfloat16* qk_buf_ef_trans, + __nv_bfloat16* qk_buf_ef, + cudaStream_t stream); +#endif + template void invokeTranspose201( int batch_size, int seq_len, int head_num, float* qk_buf_ef_seg_trans, float* qk_buf_ef_seg, cudaStream_t stream); template void invokeTranspose201( int batch_size, int seq_len, int head_num, half* qk_buf_ef_seg_trans, half* qk_buf_ef_seg, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeTranspose201<__nv_bfloat16>(int batch_size, + int seq_len, + int head_num, + __nv_bfloat16* qk_buf_ef_seg_trans, + __nv_bfloat16* qk_buf_ef_seg, + cudaStream_t stream); +#endif + template void invokeRelShiftBd( int batch_size, int head_num, int seq_len, float* qk_buf_bd_shift, float* qk_buf_bd, cudaStream_t stream); template void invokeRelShiftBd( int batch_size, int head_num, int seq_len, half* qk_buf_bd_shift, half* qk_buf_bd, cudaStream_t stream); -template void invokeCalAttnScore(int batch_size, - int head_num, - int seq_len, - int size_per_head, - float q_scaling, - float* attn_score, - float* qk_buf, - float* qk_buf_bd_shift, - float* qk_buf_ef_seg_trans, - float* attn_mask, - float* value_buf_trans, - float* value_buf, +#ifdef ENABLE_BF16 +template void invokeRelShiftBd<__nv_bfloat16>(int batch_size, + int head_num, + int seq_len, + __nv_bfloat16* qk_buf_bd_shift, + __nv_bfloat16* qk_buf_bd, + cudaStream_t stream); +#endif + +template void invokeCalAttnScore(int batch_size, + int head_num, + int seq_len, + int size_per_head, + float q_scaling, + float* attn_score, + float* qk_buf, + float* qk_buf_bd_shift, + float* qk_buf_ef_seg_trans, + float* attn_mask, + float* value_buf_trans, + float* value_buf, cudaStream_t stream); -template void invokeCalAttnScore(int batch_size, - int head_num, - int seq_len, - int size_per_head, - float q_scaling, - half* attn_score, - half* qk_buf, - half* qk_buf_bd_shift, - half* qk_buf_ef_seg_trans, - half* attn_mask, - half* value_buf_trans, - half* value_buf, +template void invokeCalAttnScore(int batch_size, + int head_num, + int seq_len, + int size_per_head, + float q_scaling, + half* attn_score, + half* qk_buf, + half* 
qk_buf_bd_shift, + half* qk_buf_ef_seg_trans, + half* attn_mask, + half* value_buf_trans, + half* value_buf, cudaStream_t stream); -template void invokeTranspose102v2(int batch_size, - int seq_len, - int head_num, - int size_per_head, - float* attn_vec_trans, - float* attn_vec, +#ifdef ENABLE_BF16 +template void invokeCalAttnScore<__nv_bfloat16>(int batch_size, + int head_num, + int seq_len, + int size_per_head, + float q_scaling, + __nv_bfloat16* attn_score, + __nv_bfloat16* qk_buf, + __nv_bfloat16* qk_buf_bd_shift, + __nv_bfloat16* qk_buf_ef_seg_trans, + __nv_bfloat16* attn_mask, + __nv_bfloat16* value_buf_trans, + __nv_bfloat16* value_buf, + cudaStream_t stream); +#endif + +template void invokeTranspose102v2(int batch_size, + int seq_len, + int head_num, + int size_per_head, + float* attn_vec_trans, + float* attn_vec, cudaStream_t stream); -template void invokeTranspose102v2(int batch_size, - int seq_len, - int head_num, - int size_per_head, - half* attn_vec_trans, - half* attn_vec, +template void invokeTranspose102v2(int batch_size, + int seq_len, + int head_num, + int size_per_head, + half* attn_vec_trans, + half* attn_vec, cudaStream_t stream); - -template void invokeAddResidualLayerNorm(int batch_size, - int seq_len, - int hidden_dim, - float* attn_layernorm, - float* attn_out, +#ifdef ENABLE_BF16 +template void invokeTranspose102v2<__nv_bfloat16>(int batch_size, + int seq_len, + int head_num, + int size_per_head, + __nv_bfloat16* attn_vec_trans, + __nv_bfloat16* attn_vec, + cudaStream_t stream); +#endif + +template void invokeAddResidualLayerNorm(int batch_size, + int seq_len, + int hidden_dim, + float* attn_layernorm, + float* attn_out, const float* in_tensor, const float* attr_layernorm_gamma, const float* attr_layernorm_beta, cudaStream_t stream); -template void invokeAddResidualLayerNorm(int batch_size, - int seq_len, - int hidden_dim, - half* attn_layernorm, - half* attn_out, - const half* in_tensor, - const half* attr_layernorm_gamma, - const half* attr_layernorm_beta, +template void invokeAddResidualLayerNorm(int batch_size, + int seq_len, + int hidden_dim, + half* attn_layernorm, + half* attn_out, + const half* in_tensor, + const half* attr_layernorm_gamma, + const half* attr_layernorm_beta, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeAddResidualLayerNorm<__nv_bfloat16>(int batch_size, + int seq_len, + int hidden_dim, + __nv_bfloat16* attn_layernorm, + __nv_bfloat16* attn_out, + const __nv_bfloat16* in_tensor, + const __nv_bfloat16* attr_layernorm_gamma, + const __nv_bfloat16* attr_layernorm_beta, + cudaStream_t stream); +#endif template void invokeGelu( int batch_size, int seq_len, int hidden_dim_ff, float* output_fc1, const float* attr_fc1_bias, cudaStream_t stream); diff --git a/src/fastertransformer/kernels/xlnet_attention_kernels.h b/src/fastertransformer/kernels/xlnet_attention_kernels.h index 9579b86b0..df8092afb 100644 --- a/src/fastertransformer/kernels/xlnet_attention_kernels.h +++ b/src/fastertransformer/kernels/xlnet_attention_kernels.h @@ -24,23 +24,23 @@ namespace fastertransformer { #define FINAL_MASK 0xffffffff const float epsilon = 0.0f; template -void invokePrepareMatrixes(int batch_size, - int seq_len, - int hidden_dim, - int size_per_head, - T* q_buf, - T* q_buf_bd, - T* q_buf_ef, - T* k_buf, - T* k_buf_bd, - T* k_buf_ef, - T* query_buf, - T* key_buf, - T* k_head_r, - T* attr_seg_embed, - T* attr_bias_Q_w, - T* attr_bias_Q_r, - T* attr_bias_Q_s, +void invokePrepareMatrixes(int batch_size, + int seq_len, + int hidden_dim, + int 
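The instantiation lists above all follow the same shape: float and half are always emitted, and the __nv_bfloat16 variant is fenced with ENABLE_BF16 so builds without bfloat16 support still compile and link. A minimal self-contained sketch of that pattern; foo here is an illustrative template, not a kernel from this file.

#include <cuda_fp16.h>
#include <cuda_runtime.h>
#ifdef ENABLE_BF16
#include <cuda_bf16.h>
#endif

template<typename T>
void foo(T* dst, const T* src, int n, cudaStream_t stream)
{
    // Definition lives in the .cu file; callers only see the declaration in the header.
    cudaMemcpyAsync(dst, src, sizeof(T) * size_t(n), cudaMemcpyDeviceToDevice, stream);
}

// Explicit instantiations provide the symbols the header promises.
template void foo<float>(float* dst, const float* src, int n, cudaStream_t stream);
template void foo<half>(half* dst, const half* src, int n, cudaStream_t stream);
#ifdef ENABLE_BF16
template void foo<__nv_bfloat16>(__nv_bfloat16* dst, const __nv_bfloat16* src, int n, cudaStream_t stream);
#endif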
size_per_head, + T* q_buf, + T* q_buf_bd, + T* q_buf_ef, + T* k_buf, + T* k_buf_bd, + T* k_buf_ef, + T* query_buf, + T* key_buf, + T* k_head_r, + T* attr_seg_embed, + T* attr_bias_Q_w, + T* attr_bias_Q_r, + T* attr_bias_Q_s, cudaStream_t stream); template @@ -55,18 +55,18 @@ template void invokeRelShiftBd(int batch_size, int head_num, int seq_len, T* qk_buf_bd_shift, T* qk_buf_bd, cudaStream_t stream); template -void invokeCalAttnScore(int batch_size, - int head_num, - int seq_len, - int size_per_head, - float q_scaling, - T* attn_score, - T* qk_buf, - T* qk_buf_bd_shift, - T* qk_buf_ef_seg_trans, - T* attn_mask, - T* value_buf_trans, - T* value_buf, +void invokeCalAttnScore(int batch_size, + int head_num, + int seq_len, + int size_per_head, + float q_scaling, + T* attn_score, + T* qk_buf, + T* qk_buf_bd_shift, + T* qk_buf_ef_seg_trans, + T* attn_mask, + T* value_buf_trans, + T* value_buf, cudaStream_t stream); template @@ -74,14 +74,14 @@ void invokeTranspose102v2( int batch_size, int seq_len, int head_num, int size_per_head, T* attn_vec_trans, T* attn_vec, cudaStream_t stream); template -void invokeAddResidualLayerNorm(int batch_size, - int seq_len, - int hidden_dim, - T* attn_layernorm, - T* attn_out, - const T* in_tensor, - const T* attr_layernorm_gamma, - const T* attr_layernorm_beta, +void invokeAddResidualLayerNorm(int batch_size, + int seq_len, + int hidden_dim, + T* attn_layernorm, + T* attn_out, + const T* in_tensor, + const T* attr_layernorm_gamma, + const T* attr_layernorm_beta, cudaStream_t stream); template diff --git a/src/fastertransformer/kernels/xlnet_preprocess_kernels.cu b/src/fastertransformer/kernels/xlnet_preprocess_kernels.cu index 9f10c6d6c..92d426f4e 100644 --- a/src/fastertransformer/kernels/xlnet_preprocess_kernels.cu +++ b/src/fastertransformer/kernels/xlnet_preprocess_kernels.cu @@ -14,6 +14,7 @@ * limitations under the License. 
*/ +#include "src/fastertransformer/kernels/bfloat16_fallback_kenrels.cuh" #include "xlnet_preprocess_kernels.h" namespace fastertransformer { @@ -36,57 +37,47 @@ int numPerThread() return sizeof(float) / sizeof(T); } /********************** Kernels ************************/ + +// Applied to half or bfloat16 +// dim3 grid(batch_size, seq_len); +// getWordEmdK<<>>(word_emb_k, params_word_emb_k, inp_k, seq_len, hidden_dim); template void __global__ getWordEmdK(T* word_emb_k, T* params_word_emb_k, int* inp_k, int seq_len, int hidden_dim) { + using T2 = typename TypeConverter::Type; // half2 or bfloat162 + int col = threadIdx.x; // the index of column + int row = blockIdx.y; // the index of row + int batch = blockIdx.x; // the index of batch - int col = threadIdx.x; // the index of column - int row = blockIdx.y; // the index of row - int batch = blockIdx.x; // the index of batch + int index = ldg(inp_k + batch * seq_len + row); + T2 data = ((T2*)params_word_emb_k)[(index * hidden_dim + col * 2) >> 1]; - int index = inp_k[batch * seq_len + row]; - T data = params_word_emb_k[index * hidden_dim + col]; - - word_emb_k[batch * seq_len * hidden_dim + row * hidden_dim + col] = data; + ((T2*)word_emb_k)[(batch * seq_len * hidden_dim + row * hidden_dim + col * 2) >> 1] = data; } -// dim3 grid(batch_size, seq_len); -// getWordEmdK<<>>(word_emb_k, params_word_emb_k, inp_k, seq_len, hidden_dim); template<> -void __global__ getWordEmdK(__half* word_emb_k, __half* params_word_emb_k, int* inp_k, int seq_len, int hidden_dim) +void __global__ getWordEmdK(float* word_emb_k, float* params_word_emb_k, int* inp_k, int seq_len, int hidden_dim) { - int col = threadIdx.x; // the index of column - int row = blockIdx.y; // the index of row - int batch = blockIdx.x; // the index of batch - int index = __ldg(inp_k + batch * seq_len + row); - half2 data = ((half2*)params_word_emb_k)[(index * hidden_dim + col * 2) >> 1]; + int col = threadIdx.x; // the index of column + int row = blockIdx.y; // the index of row + int batch = blockIdx.x; // the index of batch + + int index = inp_k[batch * seq_len + row]; + float data = params_word_emb_k[index * hidden_dim + col]; - ((half2*)word_emb_k)[(batch * seq_len * hidden_dim + row * hidden_dim + col * 2) >> 1] = data; + word_emb_k[batch * seq_len * hidden_dim + row * hidden_dim + col] = data; } +// Applied to half or bfloat16 template void __global__ getAttnMask(T* attn_mask, float* input_mask, int seq_len) { - int col = threadIdx.x; - int row = blockIdx.y; - int batch = blockIdx.x; - - float data = 1; - if (col == row) { - data = 0; - } - float mask = input_mask[batch * seq_len + col]; - attn_mask[batch * seq_len * seq_len + row * seq_len + col] = cast(data * mask); -} - -template<> -void __global__ getAttnMask(__half* attn_mask, float* input_mask, int seq_len) -{ + using T2 = typename TypeConverter::Type; // half2 or bfloat162 int in_index = blockIdx.y * blockDim.x + threadIdx.x; - int col = in_index % (seq_len / 2) * 2; - int row = in_index / (seq_len / 2); - int batch = blockIdx.x; + int col = in_index % (seq_len / 2) * 2; + int row = in_index / (seq_len / 2); + int batch = blockIdx.x; float2 tmp; if (row < seq_len && col < seq_len - 1) { @@ -103,52 +94,67 @@ void __global__ getAttnMask(__half* attn_mask, float* input_mask, int seq_len) } tmp.y = input_mask[batch * seq_len + col] * data; - int out_index = (batch * seq_len * seq_len + row * seq_len + col) >> 1; - ((half2*)attn_mask)[out_index] = __float22half2_rn(tmp); + int out_index = (batch * seq_len * seq_len + row * seq_len + 
col) >> 1; + ((T2*)attn_mask)[out_index] = float22type2(tmp); + } +} + +template<> +void __global__ getAttnMask(float* attn_mask, float* input_mask, int seq_len) +{ + int col = threadIdx.x; + int row = blockIdx.y; + int batch = blockIdx.x; + + float data = 1; + if (col == row) { + data = 0; } + float mask = input_mask[batch * seq_len + col]; + attn_mask[batch * seq_len * seq_len + row * seq_len + col] = cast(data * mask); } +// Applied to half or bfloat16 template void __global__ getSegMat(T* seg_mat, int* seg_id, int seq_len) { - int col = threadIdx.x; - int row = blockIdx.y; + using T2 = typename TypeConverter::Type; // half2 or bfloat162 + int col = threadIdx.x; + int row = blockIdx.y; int batch = blockIdx.x; int w[4] = {0, 1, 1, 0}; - int d1 = seg_id[batch * seq_len + col]; - int d2 = seg_id[batch * seq_len + row]; - int d = 0; + int d1 = seg_id[batch * seq_len + col]; + int d2 = seg_id[batch * seq_len + row]; + int d = 0; d = int(floor(exp(-1 * abs(double(d1 - d2))))); - int index = batch * seq_len * seq_len + row * seq_len + col; - seg_mat[index * 2] = w[d * 2 + 0]; - seg_mat[index * 2 + 1] = w[d * 2 + 1]; + int index = batch * seq_len * seq_len + row * seq_len + col; + float2 tmp_w; + tmp_w.x = w[d * 2 + 0]; + tmp_w.y = w[d * 2 + 1]; + + ((T2*)seg_mat)[index] = float22type2(tmp_w); } template<> -void __global__ getSegMat(__half* seg_mat, int* seg_id, int seq_len) +void __global__ getSegMat(float* seg_mat, int* seg_id, int seq_len) { - int col = threadIdx.x; - int row = blockIdx.y; + int col = threadIdx.x; + int row = blockIdx.y; int batch = blockIdx.x; int w[4] = {0, 1, 1, 0}; - int d1 = seg_id[batch * seq_len + col]; - int d2 = seg_id[batch * seq_len + row]; - int d = 0; + int d1 = seg_id[batch * seq_len + col]; + int d2 = seg_id[batch * seq_len + row]; + int d = 0; d = int(floor(exp(-1 * abs(double(d1 - d2))))); - int index = batch * seq_len * seq_len + row * seq_len + col; - float2 tmp_w; - tmp_w.x = w[d * 2 + 0]; - tmp_w.y = w[d * 2 + 1]; - - ((half2*)seg_mat)[index] = __float22half2_rn(tmp_w); - // seg_mat[index*2]=__int2half_rn(w[d*2+0]); - // seg_mat[index*2+1]=__int2half_rn(w[d*2+1]); + int index = batch * seq_len * seq_len + row * seq_len + col; + seg_mat[index * 2] = w[d * 2 + 0]; + seg_mat[index * 2 + 1] = w[d * 2 + 1]; } template @@ -163,30 +169,31 @@ void __global__ relativePosition(T* attr_k_head_r, int hidden_dim, int seq_len) float fwd_pos_seq = seq_len - row; float pos_emd = inv_freq * fwd_pos_seq; - float s = sinf(pos_emd); - float c = cosf(pos_emd); + float s = sinf(pos_emd); + float c = cosf(pos_emd); - attr_k_head_r[row * hidden_dim + col] = cast(s); + attr_k_head_r[row * hidden_dim + col] = cast(s); attr_k_head_r[row * hidden_dim + hidden_dim / 2 + col] = cast(c); } /***********************Pre-Process************************/ -template<> -void blockAttnMask(dim3& grid, dim3& block, int batch_size, int seq_len) +// Applied to half or bfloat16 +template +void blockAttnMask(dim3& grid, dim3& block, int batch_size, int seq_len) { - grid.x = batch_size; - grid.y = seq_len; - block.x = seq_len; + int numThreads = 512; + int numBlocky = (seq_len * seq_len / 2 - 1) / numThreads + 1; + grid.x = batch_size; + grid.y = numBlocky; + block.x = numThreads; } template<> -void blockAttnMask(dim3& grid, dim3& block, int batch_size, int seq_len) +void blockAttnMask(dim3& grid, dim3& block, int batch_size, int seq_len) { - int numThreads = 512; - int numBlocky = (seq_len * seq_len / 2 - 1) / numThreads + 1; - grid.x = batch_size; - grid.y = numBlocky; - block.x = numThreads; + 
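The generic (half/bfloat16) blockAttnMask above packs two adjacent mask entries per thread, so its launch only needs to cover seq_len * seq_len / 2 work items with fixed 512-thread blocks. The same arithmetic restated as a standalone host helper:

#include <cuda_runtime.h>

static void block_attn_mask_config(dim3& grid, dim3& block, int batch_size, int seq_len)
{
    const int num_threads = 512;
    const int num_block_y = (seq_len * seq_len / 2 - 1) / num_threads + 1;  // ceil(seq_len^2 / 2 / 512)
    grid  = dim3(batch_size, num_block_y, 1);
    block = dim3(num_threads, 1, 1);
}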
grid.x = batch_size; + grid.y = seq_len; + block.x = seq_len; } template @@ -201,14 +208,14 @@ void genWordEmdK( } template -void preProcess(int batch_size, - int seq_len, - int hidden_dim, - T* attn_mask, - float* input_mask, - T* seg_mat, - int* seg_id, - T* attr_k_head_r, +void preProcess(int batch_size, + int seq_len, + int hidden_dim, + T* attn_mask, + float* input_mask, + T* seg_mat, + int* seg_id, + T* attr_k_head_r, cudaStream_t stream) { dim3 grid_attn_mask; @@ -226,38 +233,59 @@ void preProcess(int batch_size, relativePosition<<>>(attr_k_head_r, hidden_dim, seq_len); } -template void preProcess(int batch_size, - int seq_len, - int hidden_dim, - float* attn_mask, - float* input_mask, - float* seg_mat, - int* seg_id, - float* attr_k_head_r, +template void preProcess(int batch_size, + int seq_len, + int hidden_dim, + float* attn_mask, + float* input_mask, + float* seg_mat, + int* seg_id, + float* attr_k_head_r, cudaStream_t stream); -template void preProcess(int batch_size, - int seq_len, - int hidden_dim, - half* attn_mask, - float* input_mask, - half* seg_mat, - int* seg_id, - half* attr_k_head_r, +template void preProcess(int batch_size, + int seq_len, + int hidden_dim, + half* attn_mask, + float* input_mask, + half* seg_mat, + int* seg_id, + half* attr_k_head_r, cudaStream_t stream); -template void genWordEmdK(int batch_size, - int seq_len, - int hidden_dim, - float* word_emb_k, - float* params_word_emb_k, - int* inp_k, +#ifdef ENABLE_BF16 +template void preProcess<__nv_bfloat16>(int batch_size, + int seq_len, + int hidden_dim, + __nv_bfloat16* attn_mask, + float* input_mask, + __nv_bfloat16* seg_mat, + int* seg_id, + __nv_bfloat16* attr_k_head_r, + cudaStream_t stream); +#endif + +template void genWordEmdK(int batch_size, + int seq_len, + int hidden_dim, + float* word_emb_k, + float* params_word_emb_k, + int* inp_k, cudaStream_t stream); -template void genWordEmdK(int batch_size, - int seq_len, - int hidden_dim, - half* word_emb_k, - half* params_word_emb_k, - int* inp_k, +template void genWordEmdK(int batch_size, + int seq_len, + int hidden_dim, + half* word_emb_k, + half* params_word_emb_k, + int* inp_k, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void genWordEmdK<__nv_bfloat16>(int batch_size, + int seq_len, + int hidden_dim, + __nv_bfloat16* word_emb_k, + __nv_bfloat16* params_word_emb_k, + int* inp_k, + cudaStream_t stream); +#endif } // namespace fastertransformer diff --git a/src/fastertransformer/kernels/xlnet_preprocess_kernels.h b/src/fastertransformer/kernels/xlnet_preprocess_kernels.h index 7548aeb47..be46c2a22 100644 --- a/src/fastertransformer/kernels/xlnet_preprocess_kernels.h +++ b/src/fastertransformer/kernels/xlnet_preprocess_kernels.h @@ -29,13 +29,13 @@ void genWordEmdK( int batch_size, int seq_len, int hidden_dim, T* word_emb_k, T* params_word_emb_k, int* inp_k, cudaStream_t stream); template -void preProcess(int batch_size, - int seq_len, - int hidden_dim, - T* attn_mask, - float* input_mask, - T* seg_mat, - int* seg_id, - T* attr_k_head_r, +void preProcess(int batch_size, + int seq_len, + int hidden_dim, + T* attn_mask, + float* input_mask, + T* seg_mat, + int* seg_id, + T* attr_k_head_r, cudaStream_t stream); } // namespace fastertransformer diff --git a/src/fastertransformer/layers/BaseLayer.h b/src/fastertransformer/layers/BaseLayer.h index ad3a38707..ded83e993 100644 --- a/src/fastertransformer/layers/BaseLayer.h +++ b/src/fastertransformer/layers/BaseLayer.h @@ -26,12 +26,12 @@ namespace fastertransformer { class BaseLayer { public: - 
BaseLayer(cudaStream_t stream, + BaseLayer(cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - cudaDeviceProp* cuda_device_prop = nullptr, - bool sparse = false): + IAllocator* allocator, + bool is_free_buffer_after_forward, + cudaDeviceProp* cuda_device_prop = nullptr, + bool sparse = false): stream_(stream), cublas_wrapper_(cublas_wrapper), allocator_(allocator), @@ -47,13 +47,13 @@ class BaseLayer { protected: virtual void allocateBuffer() = 0; - virtual void freeBuffer() = 0; + virtual void freeBuffer() = 0; // device environments - cudaStream_t stream_; + cudaStream_t stream_; cublasMMWrapper* cublas_wrapper_; - IAllocator* allocator_; - cudaDeviceProp* cuda_device_prop_ = nullptr; + IAllocator* allocator_; + cudaDeviceProp* cuda_device_prop_ = nullptr; bool is_free_buffer_after_forward_; bool is_allocate_buffer_ = false; // TODO (bhsueh) to be deprecated diff --git a/src/fastertransformer/layers/CMakeLists.txt b/src/fastertransformer/layers/CMakeLists.txt index cbaf4fac4..c8d391910 100644 --- a/src/fastertransformer/layers/CMakeLists.txt +++ b/src/fastertransformer/layers/CMakeLists.txt @@ -23,26 +23,32 @@ add_subdirectory(sampling_layers) add_library(FfnLayer STATIC FfnLayer.cc) set_property(TARGET FfnLayer PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET FfnLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -target_link_libraries(FfnLayer PUBLIC -lcublas -lcudart cublasMMWrapper activation_kernels memory_utils matrix_vector_multiplication) +target_link_libraries(FfnLayer PUBLIC -lcublas -lcudart cublasMMWrapper activation_kernels memory_utils matrix_vector_multiplication tensor) add_library(FfnLayerINT8 STATIC FfnLayerINT8.cc) set_property(TARGET FfnLayerINT8 PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET FfnLayerINT8 PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -target_link_libraries(FfnLayerINT8 PUBLIC -lcublasLt -lcublas -lcudart cublasMMWrapper cublasINT8MMWrapper activation_int8_kernels memory_utils) +target_link_libraries(FfnLayerINT8 PUBLIC -lcublasLt -lcublas -lcudart cublasMMWrapper cublasINT8MMWrapper activation_int8_kernels memory_utils tensor) add_library(TensorParallelGeluFfnLayer STATIC TensorParallelGeluFfnLayer.cc) set_property(TARGET TensorParallelGeluFfnLayer PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET TensorParallelGeluFfnLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -target_link_libraries(TensorParallelGeluFfnLayer PUBLIC -lcudart FfnLayer nccl_utils) +target_link_libraries(TensorParallelGeluFfnLayer PUBLIC -lcudart FfnLayer nccl_utils tensor) add_library(TensorParallelReluFfnLayer STATIC TensorParallelReluFfnLayer.cc) set_property(TARGET TensorParallelReluFfnLayer PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET TensorParallelReluFfnLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -target_link_libraries(TensorParallelReluFfnLayer PUBLIC -lcudart FfnLayer nccl_utils) +target_link_libraries(TensorParallelReluFfnLayer PUBLIC -lcudart FfnLayer nccl_utils tensor) add_library(DynamicDecodeLayer STATIC DynamicDecodeLayer.cc) set_property(TARGET DynamicDecodeLayer PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET DynamicDecodeLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -target_link_libraries(DynamicDecodeLayer PUBLIC -lcudart - TopKSamplingLayer TopPSamplingLayer TopKTopPSamplingLayer - OnlineBeamSearchLayer BeamSearchLayer ban_bad_words stop_criteria) +target_link_libraries(DynamicDecodeLayer PUBLIC -lcudart + TopKSamplingLayer TopPSamplingLayer 
TopKTopPSamplingLayer + OnlineBeamSearchLayer BeamSearchLayer ban_bad_words stop_criteria + gpt_kernels tensor) + +add_library(TensorParallelSiluFfnLayer STATIC TensorParallelSiluFfnLayer.cc) +set_property(TARGET TensorParallelSiluFfnLayer PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET TensorParallelSiluFfnLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) +target_link_libraries(TensorParallelSiluFfnLayer PUBLIC -lcudart FfnLayer nccl_utils) diff --git a/src/fastertransformer/layers/DenseWeight.h b/src/fastertransformer/layers/DenseWeight.h index 5a5eb6a60..d787baaa7 100644 --- a/src/fastertransformer/layers/DenseWeight.h +++ b/src/fastertransformer/layers/DenseWeight.h @@ -20,12 +20,12 @@ namespace fastertransformer { template struct DenseWeight { - const T* kernel = nullptr; - const T* bias = nullptr; + const T* kernel = nullptr; + const T* bias = nullptr; const T* sp_kernel = nullptr; // for int8 kernel const int8_t* int8_kernel = nullptr; - const float* scale = nullptr; + const float* scale = nullptr; }; } // namespace fastertransformer \ No newline at end of file diff --git a/src/fastertransformer/layers/DynamicDecodeBaseLayer.h b/src/fastertransformer/layers/DynamicDecodeBaseLayer.h index c59479f43..6f5de13c1 100644 --- a/src/fastertransformer/layers/DynamicDecodeBaseLayer.h +++ b/src/fastertransformer/layers/DynamicDecodeBaseLayer.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,21 +26,24 @@ namespace fastertransformer { class DynamicDecodeBaseLayer: public BaseLayer { protected: virtual void allocateBuffer() = 0; - virtual void freeBuffer() = 0; + virtual void freeBuffer() = 0; public: - DynamicDecodeBaseLayer(cudaStream_t stream, + DynamicDecodeBaseLayer(cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - cudaDeviceProp* cuda_device_prop): + IAllocator* allocator, + bool is_free_buffer_after_forward, + cudaDeviceProp* cuda_device_prop): BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, cuda_device_prop){}; ~DynamicDecodeBaseLayer() = default; DynamicDecodeBaseLayer(DynamicDecodeBaseLayer const& dynamic_decode_layer): BaseLayer(dynamic_decode_layer){}; - virtual void forward(std::vector* output_tensors, - const std::vector* input_tensors) = 0; - virtual void forward(std::unordered_map* output_tensors, + virtual void setup(const size_t batch_size, + const size_t beam_width, + const std::unordered_map* runtime_args) = 0; + virtual void forward(std::vector* output_tensors, + const std::vector* input_tensors) = 0; + virtual void forward(std::unordered_map* output_tensors, const std::unordered_map* input_tensors) = 0; }; diff --git a/src/fastertransformer/layers/DynamicDecodeLayer.cc b/src/fastertransformer/layers/DynamicDecodeLayer.cc index 8b834bc41..0caebb7b9 100644 --- a/src/fastertransformer/layers/DynamicDecodeLayer.cc +++ b/src/fastertransformer/layers/DynamicDecodeLayer.cc @@ -30,6 +30,7 @@ template void DynamicDecodeLayer::allocateBuffer() { FT_LOG_DEBUG(__PRETTY_FUNCTION__); + finished_sum_ = (int*)allocator_->reMalloc(finished_sum_, sizeof(int), true); return; } @@ -37,6 +38,7 @@ template void DynamicDecodeLayer::freeBuffer() { FT_LOG_DEBUG(__PRETTY_FUNCTION__); + allocator_->free((void**)(&finished_sum_)); return; } @@ -53,7 
+55,7 @@ void DynamicDecodeLayer::initialize() 0, // end_id, deprecated 0.0f, // beam_search_diversity_rate_, deprecated 1.0f, // temperature_, deprecated - 1.0f, // len_penalty_, deprecated + 0.0f, // len_penalty_, deprecated 1.0f, // repetition_penalty_, deprecated stream_, cublas_wrapper_, @@ -69,7 +71,7 @@ void DynamicDecodeLayer::initialize() 0, // end_id, deprecated 0.0f, // beam_search_diversity_rate_, deprecated 1.0f, // temperature_, deprecated - 1.0f, // len_penalty_, deprecated + 0.0f, // len_penalty_, deprecated 1.0f, // repetition_penalty_, deprecated stream_, cublas_wrapper_, @@ -83,7 +85,7 @@ void DynamicDecodeLayer::initialize() 0, // top_k_, deprecated 0, // random_seed_, deprecated 1.0f, // temperature_, deprecated - 1.0f, // len_penalty_, deprecated + 0.0f, // len_penalty_, deprecated 1.0f, // repetition_penalty_, deprecated stream_, cublas_wrapper_, @@ -97,7 +99,7 @@ void DynamicDecodeLayer::initialize() 0.0f, // top_p_, deprecated 0, // random_seed_, deprecated 1.0f, // temperature_, deprecated - 1.0f, // len_penalty_, deprecated + 0.0f, // len_penalty_, deprecated 1.0f, // repetition_penalty_, deprecated stream_, cublas_wrapper_, @@ -105,31 +107,18 @@ void DynamicDecodeLayer::initialize() false, cuda_device_prop_); - topk_topp_decode_ = new TopKTopPSamplingLayer(0, - vocab_size_, - vocab_size_padded_, - 0, // end_id, deprecated - 0, // top_k_, deprecated - 0.0f, // top_p_, deprecated - 0, // random_seed_, deprecated - 1.0f, // temperature_, deprecated - 1.0f, // len_penalty_, deprecated - 1.0f, // repetition_penalty_, deprecated - stream_, - cublas_wrapper_, - allocator_, - false); + allocateBuffer(); } template -DynamicDecodeLayer::DynamicDecodeLayer(size_t vocab_size, - size_t vocab_size_padded, - int end_id, - cudaStream_t stream, +DynamicDecodeLayer::DynamicDecodeLayer(size_t vocab_size, + size_t vocab_size_padded, + int end_id, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - cudaDeviceProp* cuda_device_prop): + IAllocator* allocator, + bool is_free_buffer_after_forward, + cudaDeviceProp* cuda_device_prop): BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward), vocab_size_(vocab_size), vocab_size_padded_(vocab_size_padded), @@ -143,6 +132,11 @@ template DynamicDecodeLayer::~DynamicDecodeLayer() { FT_LOG_DEBUG(__PRETTY_FUNCTION__); + delete online_beamsearch_decode_; + delete beamsearch_decode_; + delete topk_decode_; + delete topp_decode_; + freeBuffer(); } template @@ -157,7 +151,30 @@ DynamicDecodeLayer::DynamicDecodeLayer(DynamicDecodeLayer const& dynamic_deco } template -void DynamicDecodeLayer::forward(std::unordered_map* output_tensors, +void DynamicDecodeLayer::setup(const size_t batch_size, + const size_t beam_width, + const std::unordered_map* runtime_args) +{ + // Set up the dynamic decode layer for given input runtime arguments. + // + // input_tensors: + // runtime_top_k [1] or [batch_size] on cpu, optional. 
+ // runtime_top_p [1] or [batch_size] on cpu, optional + // beam_search_diversity_rate [1] or [batch_size] on cpu, optional + // temperature [1] or [batch_size] on cpu, optional + // len_penalty [1] or [batch_size] on cpu, optional + // repetition_penalty [1] or [batch_size] on cpu, optional + + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + has_diff_runtime_args_ = hasDiffRuntimeArgs(runtime_args); + if (beam_width == 1) { // sampling layers + topk_decode_->setup(batch_size, beam_width, runtime_args); + topp_decode_->setup(batch_size, beam_width, runtime_args); + } +} + +template +void DynamicDecodeLayer::forward(std::unordered_map* output_tensors, const std::unordered_map* input_tensors) { // TODO(bhsueh) @@ -170,16 +187,16 @@ void DynamicDecodeLayer::forward(std::unordered_map* out * \param step [1] on cpu * \param max_input_length [1] on cpu * \param input_lengths [batch_size, beam_width] + * \param sequence_limit_length [batch_size] * \param ite [1] on cpu * \param local_batch_size [1] on cpu - * \param has_diff_runtime_args [1] on cpu * \param stop_words_list [batch_size, 2, stop_words_length], optional - * \param runtime_top_k [1] or [batch_size] on cpu, optional - * \param runtime_top_p [1] or [batch_size] on cpu, optional - * \param temperature [1] or [batch_size] on cpu, optional - * \param len_penalty [1] or [batch_size] on cpu, optional - * \param repetition_penalty [1] or [batch_size] on cpu, optional - * \param random_seed [1] or [batch_size] on cpu, optional + * \param runtime_top_k [1] or [batch_size] on cpu, optional, uint + * \param runtime_top_p [1] or [batch_size] on cpu, optional, float + * \param temperature [1] or [batch_size] on cpu, optional, float + * \param len_penalty [1] or [batch_size] on cpu, optional, float + * \param repetition_penalty [1] or [batch_size] on cpu, optional, float + * \param random_seed [1] or [batch_size] on cpu, optional, unsigned long long int * \param bad_words_list [2, bad_words_length] or [batch_size, 2, bad_words_length], optional * \param src_key_cache [layer, batch_size * beam_width, local_head_num, @@ -196,6 +213,7 @@ void DynamicDecodeLayer::forward(std::unordered_map* out * output_tensors: * \param output_ids [max_seq_len, batch_size] * \param finished [batch_size * beam_width] + * \param should_stop [1] on cpu * \param cum_log_probs [batch_size * beam_width], necessary in beam search * \param parent_ids [max_seq_len, batch_size * beam_width] * \param sequence_length [batch_size * beam_width] @@ -205,28 +223,22 @@ void DynamicDecodeLayer::forward(std::unordered_map* out the k/v cache index for beam search **/ - FT_LOG_DEBUG(__PRETTY_FUNCTION__); - const int ite = input_tensors->at("ite").getVal(); + const int ite = input_tensors->at("ite").getVal(); const int step = input_tensors->at("step").getVal(); - const int has_diff_runtime_args = input_tensors->at("has_diff_runtime_args").getVal(); FT_CHECK(input_tensors->at("logits").shape.size() == 3); - const size_t batch_size = input_tensors->at("logits").shape[0]; - const size_t beam_width = input_tensors->at("logits").shape[1]; - const size_t local_batch_size = (size_t)input_tensors->at("local_batch_size").getVal(); - int* tmp_seed_ptr = new int(0); - Tensor tmp_seed_tensor = Tensor{MEMORY_CPU, TYPE_INT32, {1}, tmp_seed_ptr}; - int* tmp_k_ptr = new int(1); - Tensor tmp_k_tensor = Tensor{MEMORY_CPU, TYPE_UINT32, {1}, tmp_k_ptr}; + const size_t batch_size = input_tensors->at("logits").shape[0]; + const size_t beam_width = input_tensors->at("logits").shape[1]; + const size_t local_batch_size = 
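Since setup() is new in this change, a minimal sketch of the intended call pattern may help. Everything outside the setup()/forward() calls (the layer pointer, the pre-built tensor maps, the loop bounds) is assumed context for the sketch, and the Tensor initializers mirror the style used inside forward() above.

// Sketch only: drive the new setup()/forward() split for one request.
void run_dynamic_decode_sketch(DynamicDecodeLayer<float>*               dynamic_decode_layer,
                               size_t                                   batch_size,
                               size_t                                   beam_width,
                               int                                      max_input_length,
                               int                                      max_seq_len,
                               std::unordered_map<std::string, Tensor>& input_tensors,
                               std::unordered_map<std::string, Tensor>& output_tensors)
{
    unsigned int top_k = 4;
    float        top_p = 0.9f;
    std::unordered_map<std::string, Tensor> runtime_args{
        {"runtime_top_k", Tensor{MEMORY_CPU, TYPE_UINT32, {1}, &top_k}},
        {"runtime_top_p", Tensor{MEMORY_CPU, TYPE_FP32, {1}, &top_p}}};

    // Called once per request; the sampling layers pick up their per-batch arguments here.
    dynamic_decode_layer->setup(batch_size, beam_width, &runtime_args);

    // Called once per generation step, with the per-step logits already placed in input_tensors.
    for (int step = max_input_length; step < max_seq_len; step++) {
        dynamic_decode_layer->forward(&output_tensors, &input_tensors);
    }
}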
(size_t)input_tensors->at("local_batch_size").getVal(); if (input_tensors->find("bad_words_list") != input_tensors->end()) { - const auto& bad_words = input_tensors->at("bad_words_list"); - const int* bad_words_ptr = reinterpret_cast(bad_words.data); - const bool shared_bad_words = bad_words.shape.size() == 2; - const size_t bad_words_len = bad_words.shape[shared_bad_words ? 1 : 2]; + const auto& bad_words = input_tensors->at("bad_words_list"); + const int* bad_words_ptr = reinterpret_cast(bad_words.data); + const bool shared_bad_words = bad_words.shape.size() == 2; + const size_t bad_words_len = bad_words.shape[shared_bad_words ? 1 : 2]; - const int id_offset = ite * local_batch_size; + const int id_offset = ite * local_batch_size; const int decode_vocab_size_units_offset = id_offset * vocab_size_padded_; invokeBanBadWords((T*)input_tensors->at("logits").getPtrWithOffset(decode_vocab_size_units_offset), @@ -247,84 +259,86 @@ void DynamicDecodeLayer::forward(std::unordered_map* out } // dynamic decode GPT - const size_t dynamic_decode_batch_size = has_diff_runtime_args ? 1 : local_batch_size; - const int dynamic_decode_total_iteration = local_batch_size / dynamic_decode_batch_size; - - for (int dynamic_ite = ite * dynamic_decode_total_iteration; - dynamic_ite < (ite + 1) * dynamic_decode_total_iteration; - ++dynamic_ite) { - const int dynamic_id_offset = dynamic_ite * dynamic_decode_batch_size * beam_width; - const int dynamic_decode_vocab_size_units_offset = dynamic_id_offset * vocab_size_padded_; - - // common inputs - Tensor logits = input_tensors->at("logits"); - Tensor input_lengths = input_tensors->at("input_lengths"); - Tensor end_id = input_tensors->at("end_id"); - std::unordered_map dynamic_decode_input_tensors{ - {"logits", - Tensor{logits.where, - logits.type, - {dynamic_decode_batch_size, logits.shape[1], logits.shape[2]}, - logits.getPtrWithOffset(dynamic_decode_vocab_size_units_offset)}}, - {"embedding_bias", input_tensors->at("embedding_bias")}, - {"step", input_tensors->at("step")}, - {"max_input_length", input_tensors->at("max_input_length")}, - {"end_id", - Tensor{end_id.where, - end_id.type, - {dynamic_decode_batch_size}, - end_id.getPtrWithOffset(dynamic_ite * dynamic_decode_batch_size)}}, - {"input_lengths", - Tensor{input_lengths.where, - input_lengths.type, - {dynamic_decode_batch_size, input_lengths.shape[1]}, - input_tensors->at("input_lengths").getPtrWithOffset(dynamic_id_offset)}}, - {"ite", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &dynamic_ite}}}; - - for (auto t = input_tensors->begin(); t != input_tensors->end(); ++t) { - if (t->first.find("random_seed") == std::string::npos) { - dynamic_decode_input_tensors.insert(*t); + if (beam_width > 1) { + // Because we still not support batch beam search now, so we need to compute one by one if there are different + // runtime arguments. + const size_t dynamic_decode_batch_size = has_diff_runtime_args_ ? 
1 : local_batch_size; + const int dynamic_decode_total_iteration = local_batch_size / dynamic_decode_batch_size; + + for (int dynamic_ite = ite * dynamic_decode_total_iteration; + dynamic_ite < (ite + 1) * dynamic_decode_total_iteration; + ++dynamic_ite) { + const int dynamic_id_offset = dynamic_ite * dynamic_decode_batch_size * beam_width; + const int dynamic_decode_vocab_size_units_offset = dynamic_id_offset * vocab_size_padded_; + + // common inputs + Tensor logits = input_tensors->at("logits"); + Tensor input_lengths = input_tensors->at("input_lengths"); + Tensor end_id = input_tensors->at("end_id"); + + std::unordered_map dynamic_decode_input_tensors{ + {"logits", + Tensor{logits.where, + logits.type, + {dynamic_decode_batch_size, logits.shape[1], logits.shape[2]}, + logits.getPtrWithOffset(dynamic_decode_vocab_size_units_offset)}}, + {"embedding_bias", input_tensors->at("embedding_bias")}, + {"step", input_tensors->at("step")}, + {"max_input_length", input_tensors->at("max_input_length")}, + {"end_id", + Tensor{end_id.where, + end_id.type, + {dynamic_decode_batch_size}, + end_id.getPtrWithOffset(dynamic_ite * dynamic_decode_batch_size)}}, + {"input_lengths", + Tensor{input_lengths.where, + input_lengths.type, + {dynamic_decode_batch_size, input_lengths.shape[1]}, + input_tensors->at("input_lengths").getPtrWithOffset(dynamic_id_offset)}}, + {"ite", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &dynamic_ite}}}; + + for (auto t = input_tensors->begin(); t != input_tensors->end(); ++t) { + if (t->first.find("random_seed") == std::string::npos) { + dynamic_decode_input_tensors.insert(*t); + } } - } - Tensor finished = output_tensors->at("finished"); - Tensor sequence_length = output_tensors->at("sequence_length"); - // common outputs - std::unordered_map dynamic_decode_output_tensors{ - {"output_ids", output_tensors->at("output_ids")}, - {"finished", - Tensor{finished.where, - finished.type, - {dynamic_decode_batch_size * beam_width}, - finished.getPtrWithOffset(dynamic_id_offset)}}, - {"sequence_length", - Tensor{sequence_length.where, - sequence_length.type, - {dynamic_decode_batch_size * beam_width}, - sequence_length.getPtrWithOffset(dynamic_id_offset)}}}; - - if (output_tensors->count("cum_log_probs") > 0) { - Tensor cum_log_probs = output_tensors->at("cum_log_probs"); - dynamic_decode_output_tensors.insert({"cum_log_probs", - Tensor{cum_log_probs.where, - cum_log_probs.type, - {dynamic_decode_batch_size * beam_width}, - cum_log_probs.getPtrWithOffset(dynamic_id_offset)}}); - } - - if (output_tensors->count("output_log_probs")) { - dynamic_decode_output_tensors.insert( - {"output_log_probs", - Tensor{MEMORY_GPU, - TYPE_FP32, + Tensor finished = output_tensors->at("finished"); + Tensor sequence_length = output_tensors->at("sequence_length"); + // common outputs + std::unordered_map dynamic_decode_output_tensors{ + {"output_ids", output_tensors->at("output_ids")}, + {"finished", + Tensor{finished.where, + finished.type, {dynamic_decode_batch_size * beam_width}, - output_tensors->at("output_log_probs") - .getPtrWithOffset((step - input_tensors->at("max_input_length").getVal()) * batch_size - * beam_width - + dynamic_id_offset)}}); - } + finished.getPtrWithOffset(dynamic_id_offset)}}, + {"sequence_length", + Tensor{sequence_length.where, + sequence_length.type, + {dynamic_decode_batch_size * beam_width}, + sequence_length.getPtrWithOffset(dynamic_id_offset)}}}; + + if (output_tensors->count("cum_log_probs") > 0) { + Tensor cum_log_probs = output_tensors->at("cum_log_probs"); + 
dynamic_decode_output_tensors.insert({"cum_log_probs", + Tensor{cum_log_probs.where, + cum_log_probs.type, + {dynamic_decode_batch_size * beam_width}, + cum_log_probs.getPtrWithOffset(dynamic_id_offset)}}); + } + + if (output_tensors->count("output_log_probs")) { + size_t step_offset = + (step - input_tensors->at("max_input_length").getVal()) * batch_size * beam_width; + dynamic_decode_output_tensors.insert( + {"output_log_probs", + Tensor{MEMORY_GPU, + TYPE_FP32, + {dynamic_decode_batch_size * beam_width}, + output_tensors->at("output_log_probs").getPtrWithOffset(step_offset + dynamic_id_offset)}}); + } - if (beam_width > 1) { dynamic_decode_input_tensors.insert({"src_cache_indirection", input_tensors->at("src_cache_indirection")}); dynamic_decode_output_tensors.insert({"parent_ids", output_tensors->at("parent_ids")}); @@ -340,45 +354,61 @@ void DynamicDecodeLayer::forward(std::unordered_map* out else { beamsearch_decode_->forward(&dynamic_decode_output_tensors, &dynamic_decode_input_tensors); } - } - else { // beam_width = 1 - if (input_tensors->at("is_initialize_random_table").getVal()) { - // only insert random seed for first generation to initialize the random table - if (input_tensors->count("random_seed")) { - dynamic_decode_input_tensors.insert({"random_seed", input_tensors->at("random_seed")}); - } - else { - dynamic_decode_input_tensors.insert({"random_seed", tmp_seed_tensor}); - } - } + } // end of dynamic_ite + } + else { // beam_width=1 + // In sampling, we have supported batch sampling. So, we always compute all sentences once. + const size_t local_batch_offset = ite * local_batch_size * beam_width; - if (input_tensors->count("runtime_top_p") == 0 - || input_tensors->at("runtime_top_p").getVal(has_diff_runtime_args ? dynamic_ite : 0) == 0.0f) { + Tensor logits = input_tensors->at("logits"); + Tensor input_lengths = input_tensors->at("input_lengths"); + Tensor end_id = input_tensors->at("end_id"); - if (input_tensors->count("runtime_top_k") == 0 - || input_tensors->at("runtime_top_k").getVal() == 0) { - FT_LOG_WARNING("beam_width = 1 and top_k = 0 and top_p == 0.0f at the same time is invalid." - "Using Greedy search by default."); + std::unordered_map decode_input_tensors{ + {"logits", + logits.slice({local_batch_size, beam_width, logits.shape[2]}, local_batch_offset * logits.shape[2])}, + {"embedding_bias", input_tensors->at("embedding_bias")}, + {"step", input_tensors->at("step")}, + {"max_input_length", input_tensors->at("max_input_length")}, + {"end_id", end_id.slice({local_batch_size}, ite * local_batch_size)}, + {"input_lengths", + input_lengths.slice({local_batch_size * beam_width, input_lengths.shape[1]}, local_batch_offset)}, + {"ite", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &ite}}}; - if (dynamic_decode_input_tensors.count("dynamic_decode_input_tensors")) { - dynamic_decode_input_tensors.erase("runtime_top_k"); - } - dynamic_decode_input_tensors.insert({"runtime_top_k", tmp_k_tensor}); - } - topk_decode_->forward(&dynamic_decode_output_tensors, &dynamic_decode_input_tensors); - } - else if (input_tensors->count("runtime_top_k") == 0 - || input_tensors->at("runtime_top_k").getVal(has_diff_runtime_args ? 
dynamic_ite : 0) == 0) { - topp_decode_->forward(&dynamic_decode_output_tensors, &dynamic_decode_input_tensors); - } - else { - topk_topp_decode_->forward(&dynamic_decode_output_tensors, &dynamic_decode_input_tensors); - } + Tensor finished = output_tensors->at("finished"); + Tensor sequence_length = output_tensors->at("sequence_length"); + + std::unordered_map decode_output_tensors{ + {"output_ids", output_tensors->at("output_ids")}, + {"finished", finished.slice({local_batch_size * beam_width}, local_batch_offset)}, + {"sequence_length", sequence_length.slice({local_batch_size * beam_width}, local_batch_offset)}}; + if (output_tensors->count("cum_log_probs") > 0) { + Tensor cum_log_probs = output_tensors->at("cum_log_probs"); + decode_output_tensors.insert( + {"cum_log_probs", cum_log_probs.slice({local_batch_size * beam_width}, local_batch_offset)}); } - } // dynamic_ite + if (output_tensors->count("output_log_probs")) { + Tensor output_log_probs = output_tensors->at("output_log_probs"); + int max_input_length = input_tensors->at("max_input_length").getVal(); + size_t step_offset = (step - max_input_length) * batch_size * beam_width; + decode_output_tensors.insert({"output_log_probs", + output_log_probs.slice({output_log_probs.shape[0] - (step - max_input_length), + local_batch_size * beam_width}, + step_offset + local_batch_offset)}); + } + + // Run topk / topp decode layers. + // Currently, we support batch sampling. If the runtime arguments are like + // topk = [4, 0, 4]. topp = [0.0, 0.5, 0.5] + // then topk_decode handles [4, x, 4 + 0.5] + // topp_decode handles [x, 0.5, x] + // where "x" are skipped. + topk_decode_->forward(&decode_output_tensors, &decode_input_tensors); + topp_decode_->forward(&decode_output_tensors, &decode_input_tensors); + } if (input_tensors->find("stop_words_list") != input_tensors->end()) { - const size_t id_offset = ite * local_batch_size * beam_width; + const size_t id_offset = ite * local_batch_size * beam_width; const size_t stop_words_length = input_tensors->at("stop_words_list").shape[2]; invokeStopWordsCriterion((const int*)output_tensors->at("output_ids").data, @@ -394,8 +424,56 @@ void DynamicDecodeLayer::forward(std::unordered_map* out stream_); } - delete tmp_seed_ptr; - delete tmp_k_ptr; + if (input_tensors->find("sequence_limit_length") != input_tensors->end()) { + invokeLengthCriterion(output_tensors->at("finished").getPtr(), + output_tensors->at("should_stop").getPtr(), + finished_sum_, + input_tensors->at("sequence_limit_length").getPtr(), + batch_size, + beam_width, + step, + stream_); + } +} + +template +bool DynamicDecodeLayer::hasDiffRuntimeArgs(const std::unordered_map* input_tensors) +{ + for (int i = 0; i < (int)runtime_arg_names_.size(); i++) { + if (input_tensors->count(runtime_arg_names_[i])) { + auto tensor = input_tensors->at(runtime_arg_names_[i]); + FT_CHECK(tensor.shape.size() == 1); + for (int j = 1; j < (int)tensor.shape[0]; j++) { + const void* data = tensor.data; + switch (tensor.type) { + case TYPE_FP32: + if (((const float*)data)[0] != ((const float*)data)[j]) { + return true; + } + break; + case TYPE_INT32: + if (((const int*)data)[0] != ((const int*)data)[j]) { + return true; + } + break; + case TYPE_UINT32: + if (((const uint*)data)[0] != ((const uint*)data)[j]) { + return true; + } + break; + case TYPE_UINT64: + if (((const unsigned long long int*)data)[0] != ((const unsigned long long int*)data)[j]) { + return true; + } + break; + default: + FT_CHECK_WITH_INFO(false, runtime_arg_names_[i] + ": " + tensor.toString() 
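The batched-sampling comment above can be restated as a small predicate: an entry with top_k > 0 is consumed by TopKSamplingLayer (a nonzero top_p additionally filters the k candidates), and an entry with top_k == 0 and top_p > 0 by TopPSamplingLayer. A sketch of that split; the degenerate top_k == 0, top_p == 0 case is left to the layers themselves.

enum class SamplerChoice { kTopK, kTopP, kNeither };

// Which sampling layer consumes a batch entry, given its runtime arguments.
static SamplerChoice pick_sampler(unsigned int top_k, float top_p)
{
    if (top_k > 0) {
        return SamplerChoice::kTopK;  // top-k sampling; nonzero top_p also filters the k candidates
    }
    if (top_p > 0.0f) {
        return SamplerChoice::kTopP;  // pure top-p (nucleus) sampling
    }
    return SamplerChoice::kNeither;   // degenerate case, handled inside the layers
}

// Example from the comment above: top_k = {4, 0, 4}, top_p = {0.0, 0.5, 0.5}
//   entries 0 and 2 -> TopKSamplingLayer (entry 2 also applies top_p = 0.5)
//   entry 1         -> TopPSamplingLayer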
+ " is invalid."); + break; + } + } + } + } + return false; } template class DynamicDecodeLayer; diff --git a/src/fastertransformer/layers/DynamicDecodeLayer.h b/src/fastertransformer/layers/DynamicDecodeLayer.h index 1f64db4e8..695ec56a2 100644 --- a/src/fastertransformer/layers/DynamicDecodeLayer.h +++ b/src/fastertransformer/layers/DynamicDecodeLayer.h @@ -31,31 +31,41 @@ class DynamicDecodeLayer: public BaseLayer { void allocateBuffer() override; void freeBuffer() override; void initialize(); + bool hasDiffRuntimeArgs(const std::unordered_map* input_tensors); DynamicDecodeBaseLayer* online_beamsearch_decode_; DynamicDecodeBaseLayer* beamsearch_decode_; DynamicDecodeBaseLayer* topk_decode_; DynamicDecodeBaseLayer* topp_decode_; - DynamicDecodeBaseLayer* topk_topp_decode_; - size_t vocab_size_; - size_t vocab_size_padded_; + size_t vocab_size_; + size_t vocab_size_padded_; cudaDeviceProp* cuda_device_prop_; + // List of argument names which can have different values in runtime + // and does not support a batched version of kernel in beam search. + const std::vector runtime_arg_names_ = { + "beam_search_diversity_rate", "temperature", "len_penalty", "repetition_penalty"}; + bool has_diff_runtime_args_ = false; + int* finished_sum_ = nullptr; + public: - DynamicDecodeLayer(size_t vocab_size, - size_t vocab_size_padded, - int end_id, - cudaStream_t stream, + DynamicDecodeLayer(size_t vocab_size, + size_t vocab_size_padded, + int end_id, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - cudaDeviceProp* cuda_device_prop); + IAllocator* allocator, + bool is_free_buffer_after_forward, + cudaDeviceProp* cuda_device_prop); ~DynamicDecodeLayer(); DynamicDecodeLayer(DynamicDecodeLayer const& dynamic_decode_layer); - void forward(std::unordered_map* output_tensors, + void setup(const size_t batch_size, + const size_t beam_width, + const std::unordered_map* runtime_args); + void forward(std::unordered_map* output_tensors, const std::unordered_map* input_tensors); }; diff --git a/src/fastertransformer/layers/FfnLayer.cc b/src/fastertransformer/layers/FfnLayer.cc index e05bea949..509b183c9 100644 --- a/src/fastertransformer/layers/FfnLayer.cc +++ b/src/fastertransformer/layers/FfnLayer.cc @@ -19,9 +19,9 @@ namespace fastertransformer { template -void FfnLayer::forward(std::vector* output_tensors, +void FfnLayer::forward(std::vector* output_tensors, const std::vector* input_tensors, - const FfnWeight* ffn_weights) + const FfnWeight* ffn_weights) { // input tensors: // ffn_input [token_num, hidden_dimension], @@ -35,9 +35,12 @@ void FfnLayer::forward(std::vector* output_tensors // FT_CHECK(isValidTokenNum(input_tensors->at(0).shape[0])); allocateBuffer(input_tensors->at(0).shape[0]); - const int m = input_tensors->at(0).shape[0]; - T* output_tensor = (T*)output_tensors->at(0).data; - const T* input_tensor = (const T*)input_tensors->at(0).data; + const int m = input_tensors->at(0).shape[0]; + T* output_tensor = (T*)output_tensors->at(0).data; + const T* input_tensor = (const T*)input_tensors->at(0).data; + + // TODO: INT8 and Sparsity are currently not implemented (geglu or reglu) + const bool use_gated_activation = use_gated_activation_ && ffn_weights->intermediate_weight2.kernel != nullptr; #ifdef SPARSITY_ENABLED int m_tmp = input_tensors->at(0).shape[0]; @@ -46,6 +49,7 @@ void FfnLayer::forward(std::vector* output_tensors } const int m_padded = m_tmp; if (sparse_ && cublas_wrapper_->isUseSparse(1, inter_size_, m, hidden_units_)) { + 
FT_CHECK(!use_gated_activation); cublas_wrapper_->SpGemm(CUBLAS_OP_N, CUBLAS_OP_N, inter_size_, @@ -58,6 +62,7 @@ void FfnLayer::forward(std::vector* output_tensors else { #endif if (int8_mode_ == 1 && m <= 2) { + FT_CHECK(!use_gated_activation); FT_CHECK(ffn_weights->intermediate_weight.int8_kernel != NULL && ffn_weights->intermediate_weight.scale != NULL); int8WeightPerChannelLdkMultiplicationLauncher(ffn_weights->intermediate_weight.int8_kernel, @@ -84,12 +89,30 @@ void FfnLayer::forward(std::vector* output_tensors hidden_units_, inter_buf_, inter_size_); + if (use_gated_activation) { + cublas_wrapper_->Gemm(CUBLAS_OP_N, + CUBLAS_OP_N, + inter_size_, + m, + hidden_units_, + ffn_weights->intermediate_weight2.kernel, + inter_size_, + input_tensor, + hidden_units_, + inter_buf_2_, + inter_size_); + } } #ifdef SPARSITY_ENABLED } #endif - invokeAddBiasActivation(m, ffn_weights->intermediate_weight.bias); + if (use_gated_activation) { + invokeAddBiasGatedActivation(m, ffn_weights->intermediate_weight.bias, ffn_weights->intermediate_weight2.bias); + } + else { + invokeAddBiasActivation(m, ffn_weights->intermediate_weight.bias); + } sync_check_cuda_error(); #ifdef SPARSITY_ENABLED @@ -140,24 +163,27 @@ void FfnLayer::forward(std::vector* output_tensors } template -FfnLayer::FfnLayer(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t inter_size, - cudaStream_t stream, +FfnLayer::FfnLayer(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t inter_size, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse, - int int8_mode): + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse, + int int8_mode, + bool use_gated_activation): BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, nullptr, sparse), max_token_num_(max_batch_size * max_seq_len), head_num_(head_num), size_per_head_(size_per_head), hidden_units_(head_num * size_per_head), + max_inter_size_(inter_size), inter_size_(inter_size), - int8_mode_(int8_mode) + int8_mode_(int8_mode), + use_gated_activation_(use_gated_activation) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); } @@ -174,8 +200,10 @@ FfnLayer::FfnLayer(FfnLayer const& ffn_layer): head_num_(ffn_layer.head_num_), size_per_head_(ffn_layer.size_per_head_), hidden_units_(ffn_layer.hidden_units_), + max_inter_size_(ffn_layer.max_inter_size_), inter_size_(ffn_layer.inter_size_), - int8_mode_(ffn_layer.int8_mode_) + int8_mode_(ffn_layer.int8_mode_), + use_gated_activation_(ffn_layer.use_gated_activation_) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); } @@ -193,7 +221,10 @@ void FfnLayer::allocateBuffer() { FT_LOG_DEBUG(__PRETTY_FUNCTION__); if (is_allocate_buffer_ == false) { - inter_buf_ = (T*)allocator_->malloc(sizeof(T) * max_token_num_ * inter_size_, false); + inter_buf_ = (T*)allocator_->reMalloc(inter_buf_, sizeof(T) * max_token_num_ * max_inter_size_, false); + if (use_gated_activation_) { + inter_buf_2_ = (T*)allocator_->reMalloc(inter_buf_2_, sizeof(T) * max_token_num_ * max_inter_size_, false); + } is_allocate_buffer_ = true; } } @@ -202,7 +233,10 @@ template void FfnLayer::allocateBuffer(size_t token_num) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); - inter_buf_ = (T*)allocator_->reMalloc(inter_buf_, sizeof(T) * token_num * inter_size_, false); + inter_buf_ = (T*)allocator_->reMalloc(inter_buf_, sizeof(T) * token_num * max_inter_size_, false); + if (use_gated_activation_) { + inter_buf_2_ = 
(T*)allocator_->reMalloc(inter_buf_2_, sizeof(T) * token_num * max_inter_size_, false); + } is_allocate_buffer_ = true; } @@ -211,7 +245,10 @@ void FfnLayer::freeBuffer() { FT_LOG_DEBUG(__PRETTY_FUNCTION__); if (is_allocate_buffer_) { - allocator_->free(inter_buf_); + allocator_->free((void**)(&inter_buf_)); + if (use_gated_activation_) { + allocator_->free((void**)(&inter_buf_2_)); + } is_allocate_buffer_ = false; } } @@ -232,17 +269,18 @@ template class FfnLayer<__nv_bfloat16>; #endif template -GeluFfnLayer::GeluFfnLayer(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t inter_size, - cudaStream_t stream, +GeluFfnLayer::GeluFfnLayer(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t inter_size, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse, - int int8_mode): + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse, + int int8_mode, + bool use_gated_activation): FfnLayer(max_batch_size, max_seq_len, head_num, @@ -253,7 +291,8 @@ GeluFfnLayer::GeluFfnLayer(size_t max_batch_size, allocator, is_free_buffer_after_forward, sparse, - int8_mode) + int8_mode, + use_gated_activation) { } @@ -265,7 +304,13 @@ GeluFfnLayer::GeluFfnLayer(GeluFfnLayer const& gelu_ffn_layer): FfnLayer void GeluFfnLayer::invokeAddBiasActivation(const int m, const T* bias) { - invokeAddBiasGelu(inter_buf_, bias, m, inter_size_, stream_); + invokeAddBiasGeluV2(inter_buf_, bias, m, inter_size_, stream_); +} + +template +void GeluFfnLayer::invokeAddBiasGatedActivation(const int m, const T* bias1, const T* bias2) +{ + invokeAddBiasGatedGelu(inter_buf_, inter_buf_2_, bias1, bias2, m, inter_size_, stream_); } template class GeluFfnLayer; @@ -275,16 +320,17 @@ template class GeluFfnLayer<__nv_bfloat16>; #endif template -ReluFfnLayer::ReluFfnLayer(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t inter_size, - cudaStream_t stream, +ReluFfnLayer::ReluFfnLayer(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t inter_size, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse): + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse, + bool use_gated_activation): FfnLayer(max_batch_size, max_seq_len, head_num, @@ -294,7 +340,9 @@ ReluFfnLayer::ReluFfnLayer(size_t max_batch_size, cublas_wrapper, allocator, is_free_buffer_after_forward, - sparse) + sparse, + 0, + use_gated_activation) { } @@ -309,10 +357,66 @@ void ReluFfnLayer::invokeAddBiasActivation(const int m, const T* bias) invokeAddBiasRelu(inter_buf_, bias, m, inter_size_, stream_); } +template +void ReluFfnLayer::invokeAddBiasGatedActivation(const int m, const T* bias1, const T* bias2) +{ + invokeAddBiasGatedRelu(inter_buf_, inter_buf_2_, bias1, bias2, m, inter_size_, stream_); +} + template class ReluFfnLayer; template class ReluFfnLayer; #ifdef ENABLE_BF16 template class ReluFfnLayer<__nv_bfloat16>; #endif +template +SiluFfnLayer::SiluFfnLayer(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t inter_size, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse, + bool use_gated_activation): + FfnLayer(max_batch_size, + max_seq_len, + head_num, + size_per_head, + 
inter_size, + stream, + cublas_wrapper, + allocator, + is_free_buffer_after_forward, + sparse, + 0, + use_gated_activation) +{ +} + +template +SiluFfnLayer::SiluFfnLayer(SiluFfnLayer const& gelu_ffn_layer): FfnLayer(gelu_ffn_layer) +{ +} + +template +void SiluFfnLayer::invokeAddBiasActivation(const int m, const T* bias) +{ + invokeAddBiasSilu(inter_buf_, bias, m, inter_size_, stream_); +} + +template +void SiluFfnLayer::invokeAddBiasGatedActivation(const int m, const T* bias1, const T* bias2) +{ + invokeAddBiasGatedSilu(inter_buf_, inter_buf_2_, bias1, bias2, m, inter_size_, stream_); +} + +template class SiluFfnLayer; +template class SiluFfnLayer; +#ifdef ENABLE_BF16 +template class SiluFfnLayer<__nv_bfloat16>; +#endif + } // namespace fastertransformer diff --git a/src/fastertransformer/layers/FfnLayer.h b/src/fastertransformer/layers/FfnLayer.h index f5d1bc87d..e28fce570 100644 --- a/src/fastertransformer/layers/FfnLayer.h +++ b/src/fastertransformer/layers/FfnLayer.h @@ -20,16 +20,54 @@ #include "src/fastertransformer/kernels/matrix_vector_multiplication.h" #include "src/fastertransformer/layers/BaseLayer.h" #include "src/fastertransformer/layers/FfnWeight.h" +#include "src/fastertransformer/utils/cuda_utils.h" #include "src/fastertransformer/utils/memory_utils.h" #include namespace fastertransformer { -enum ActivationType { +enum class ActivationType { Gelu, - Relu + Relu, + Silu, + GeGLU, + ReGLU, + SiGLU, + InvalidType }; +inline ActivationType getActivationType(std::string activation_type_str) +{ + if (activation_type_str == "Gelu" || activation_type_str == "gelu") { + return ActivationType::Gelu; + } + else if (activation_type_str == "Relu" || activation_type_str == "relu") { + return ActivationType::Relu; + } + else if (activation_type_str == "Silu" || activation_type_str == "silu") { + return ActivationType::Silu; + } + else if (activation_type_str == "GeGLU" || activation_type_str == "geglu" || activation_type_str == "gated-gelu") { + return ActivationType::GeGLU; + } + else if (activation_type_str == "ReGLU" || activation_type_str == "reglu" || activation_type_str == "gated-relu") { + return ActivationType::ReGLU; + } + else if (activation_type_str == "SiGLU" || activation_type_str == "gated-silu") { + return ActivationType::SiGLU; + } + else { + FT_CHECK_WITH_INFO(false, "Activation Type: " + activation_type_str + " not supported !"); + } + return ActivationType::InvalidType; +} + +inline bool isGatedActivation(ActivationType activaiton_type) +{ + return activaiton_type == ActivationType::GeGLU || activaiton_type == ActivationType::ReGLU + || activaiton_type == ActivationType::SiGLU; +} + template class FfnLayer: public BaseLayer { private: @@ -46,52 +84,71 @@ class FfnLayer: public BaseLayer { // calculated data size_t hidden_units_; + // gated activation + bool use_gated_activation_; + void allocateBuffer() override; void freeBuffer() override; bool isValidTokenNum(size_t token_num); void allocateBuffer(size_t token_num); protected: - T* inter_buf_ = nullptr; + T* inter_buf_ = nullptr; + T* inter_buf_2_ = nullptr; // for gated activation + // the inter size for runtime ffn layer size_t inter_size_; - virtual void invokeAddBiasActivation(const int m, const T* bias) = 0; + /* used to allocater memory buffers + different ffn layers (inter_size) will + reuse the same ffn layer with the max inter size. 
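(concretely: a caller that needs a smaller FFN at runtime can call resetInterSize(runtime_inter_size) before forward(); inter_buf_ and inter_buf_2_ are sized with max_inter_size_, so shrinking the runtime inter size does not require reallocating them)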
+ max_inter_size will be passed as inter_size when initializing the ffn layer + */ + size_t max_inter_size_; + virtual void invokeAddBiasActivation(const int m, const T* bias) = 0; + virtual void invokeAddBiasGatedActivation(const int m, const T* bias1, const T* bias2) = 0; public: - FfnLayer(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t inter_size, - cudaStream_t stream, + FfnLayer(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t inter_size, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse = false, - int int8_mode = 0); + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse = false, + int int8_mode = 0, + bool use_gated_activation = false); FfnLayer(FfnLayer const& ffn_layer); virtual ~FfnLayer(); - virtual void forward(std::vector* output_tensors, + void resetInterSize(size_t runtime_inter_size) + { + inter_size_ = runtime_inter_size; + } + + virtual void forward(std::vector* output_tensors, const std::vector* input_tensors, - const FfnWeight* ffn_weights); + const FfnWeight* ffn_weights); }; template class GeluFfnLayer: public FfnLayer { public: - GeluFfnLayer(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t inter_size, - cudaStream_t stream, + GeluFfnLayer(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t inter_size, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse = false, - int int8_mode = 0); + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse = false, + int int8_mode = 0, + bool use_gated_activation = false); GeluFfnLayer(GeluFfnLayer const& ffn_layer); @@ -102,23 +159,26 @@ class GeluFfnLayer: public FfnLayer { private: using FfnLayer::inter_buf_; + using FfnLayer::inter_buf_2_; using FfnLayer::inter_size_; void invokeAddBiasActivation(const int m, const T* bias) override; + void invokeAddBiasGatedActivation(const int m, const T* bias1, const T* bias2) override; }; template class ReluFfnLayer: public FfnLayer { public: - ReluFfnLayer(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t inter_size, - cudaStream_t stream, + ReluFfnLayer(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t inter_size, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse = false); + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse = false, + bool use_gated_activation = false); ReluFfnLayer(ReluFfnLayer const& ffn_layer); @@ -129,8 +189,40 @@ class ReluFfnLayer: public FfnLayer { private: using FfnLayer::inter_buf_; + using FfnLayer::inter_buf_2_; + using FfnLayer::inter_size_; + void invokeAddBiasActivation(const int m, const T* bias) override; + void invokeAddBiasGatedActivation(const int m, const T* bias1, const T* bias2) override; +}; + +template +class SiluFfnLayer: public FfnLayer { +public: + SiluFfnLayer(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t inter_size, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse = false, + bool use_gated_activation = false); + + 
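// A minimal construction sketch for a gated-SiLU (SiGLU) FFN, assuming only the constructor
// declared above; argument values are illustrative, not part of this patch:
//   SiluFfnLayer<half> ffn(max_batch_size, max_seq_len, head_num, size_per_head, inter_size,
//                          stream, cublas_wrapper, allocator,
//                          /*is_free_buffer_after_forward=*/false, /*sparse=*/false,
//                          /*use_gated_activation=*/true);
// forward() then expects ffn_weights->intermediate_weight2.kernel to be set; if it is nullptr,
// FfnLayer<T>::forward falls back to the ordinary (non-gated) activation path.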
SiluFfnLayer(SiluFfnLayer const& ffn_layer); + + virtual ~SiluFfnLayer() = default; + +protected: + using FfnLayer::stream_; + +private: + using FfnLayer::inter_buf_; + using FfnLayer::inter_buf_2_; using FfnLayer::inter_size_; void invokeAddBiasActivation(const int m, const T* bias) override; + void invokeAddBiasGatedActivation(const int m, const T* bias1, const T* bias2) override; }; } // namespace fastertransformer diff --git a/src/fastertransformer/layers/FfnLayerINT8.cc b/src/fastertransformer/layers/FfnLayerINT8.cc index 46a597fc9..dc9cb7f66 100644 --- a/src/fastertransformer/layers/FfnLayerINT8.cc +++ b/src/fastertransformer/layers/FfnLayerINT8.cc @@ -15,13 +15,14 @@ */ #include "FfnLayerINT8.h" +#include "src/fastertransformer/utils/nvtx_utils.h" namespace fastertransformer { template -void FfnLayerINT8::forward(std::vector* output_tensors, +void FfnLayerINT8::forward(std::vector* output_tensors, const std::vector* input_tensors, - const FfnWeight* ffn_weights) + const FfnWeight* ffn_weights) { // input_tensors: [input (token_num, hidden_dimension)] // output_tensors: [output (token_num, hidden_dimension)] @@ -41,9 +42,10 @@ void FfnLayerINT8::forward(std::vector* output_ten const int m_padded = m_tmp; #endif - int32_t* output_tensor = (int32_t*)output_tensors->at(0).data; - const int8_t* input_tensor = (const int8_t*)input_tensors->at(0).data; + int32_t* output_tensor = (int32_t*)output_tensors->at(0).data; + const int8_t* input_tensor = (const int8_t*)input_tensors->at(0).data; + PUSH_RANGE("FFN gemm 1"); if (int8_mode_ == 1) { cublas_wrapper->Gemm(inter_int_buf_, 1, @@ -84,10 +86,14 @@ void FfnLayerINT8::forward(std::vector* output_ten } #endif } + POP_RANGE; + PUSH_RANGE("add bias act"); invokeAddBiasActivation(m, ffn_weights->intermediate_weight.bias, scale_list); + POP_RANGE; sync_check_cuda_error(); + PUSH_RANGE("FFN gemm 2"); if (int8_mode_ == 1) { cublas_wrapper->Gemm(output_tensor, 1, @@ -128,6 +134,7 @@ void FfnLayerINT8::forward(std::vector* output_ten } #endif } + POP_RANGE; sync_check_cuda_error(); if (is_free_buffer_after_forward_ == true) { @@ -137,17 +144,17 @@ void FfnLayerINT8::forward(std::vector* output_ten } template -FfnLayerINT8::FfnLayerINT8(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t inter_size, - int int8_mode, - cudaStream_t stream, +FfnLayerINT8::FfnLayerINT8(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t inter_size, + int int8_mode, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse): + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse): BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward), max_token_num_(max_batch_size * max_seq_len), head_num_(head_num), @@ -184,8 +191,9 @@ template void FfnLayerINT8::allocateBuffer() { if (is_allocate_buffer_ == false) { - inter_int_buf_ = (int32_t*)allocator_->malloc(sizeof(int32_t) * max_token_num_ * inter_size_, false); - inter_buf_ = (int8_t*)allocator_->malloc(sizeof(int8_t) * max_token_num_ * inter_size_, false); + inter_int_buf_ = + (int32_t*)allocator_->reMalloc(inter_int_buf_, sizeof(int32_t) * max_token_num_ * inter_size_, false); + inter_buf_ = (int8_t*)allocator_->reMalloc(inter_buf_, sizeof(int8_t) * max_token_num_ * inter_size_, false); is_allocate_buffer_ = true; } } @@ -194,8 +202,8 @@ template void FfnLayerINT8::freeBuffer() { if (is_allocate_buffer_ == true) { - 
allocator_->free(inter_int_buf_); - allocator_->free(inter_buf_); + allocator_->free((void**)(&inter_int_buf_)); + allocator_->free((void**)(&inter_buf_)); is_allocate_buffer_ = false; } } @@ -216,17 +224,17 @@ template class FfnLayerINT8; template class FfnLayerINT8; template -GeluFfnLayerINT8::GeluFfnLayerINT8(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t inter_size, - int int8_mode, - cudaStream_t stream, +GeluFfnLayerINT8::GeluFfnLayerINT8(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t inter_size, + int int8_mode, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse): + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse): FfnLayerINT8(max_batch_size, max_seq_len, head_num, @@ -292,16 +300,16 @@ template class GeluFfnLayerINT8; template class GeluFfnLayerINT8; template -ReluFfnLayerINT8::ReluFfnLayerINT8(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t inter_size, - int int8_mode, - cudaStream_t stream, +ReluFfnLayerINT8::ReluFfnLayerINT8(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t inter_size, + int int8_mode, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward): + IAllocator* allocator, + bool is_free_buffer_after_forward): FfnLayerINT8(max_batch_size, max_seq_len, head_num, diff --git a/src/fastertransformer/layers/FfnLayerINT8.h b/src/fastertransformer/layers/FfnLayerINT8.h index 50c541307..f2b4d46ae 100644 --- a/src/fastertransformer/layers/FfnLayerINT8.h +++ b/src/fastertransformer/layers/FfnLayerINT8.h @@ -53,33 +53,33 @@ class FfnLayerINT8: public BaseLayer { protected: size_t inter_size_; - int int8_mode_; - bool sparse_; + int int8_mode_; + bool sparse_; - int* inter_int_buf_; - int8_t* inter_buf_; + int* inter_int_buf_; + int8_t* inter_buf_; virtual void invokeAddBiasActivation(const int m, const T* bias, ScaleList* scale_list) = 0; public: - FfnLayerINT8(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t inter_size, - int int8_mode, - cudaStream_t stream, + FfnLayerINT8(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t inter_size, + int int8_mode, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse = false); + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse = false); FfnLayerINT8(FfnLayerINT8 const& ffn_layer); ~FfnLayerINT8(); - void forward(std::vector* output_tensors, + void forward(std::vector* output_tensors, const std::vector* input_tensors, - const FfnWeight* ffn_weights); + const FfnWeight* ffn_weights); friend GeluFfnLayerINT8; friend ReluFfnLayerINT8; @@ -88,17 +88,17 @@ class FfnLayerINT8: public BaseLayer { template class GeluFfnLayerINT8: public FfnLayerINT8 { public: - GeluFfnLayerINT8(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t inter_size, - int int8_mode, - cudaStream_t stream, + GeluFfnLayerINT8(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t inter_size, + int int8_mode, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse = 
false); + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse = false); GeluFfnLayerINT8(GeluFfnLayerINT8 const& ffn_layer); @@ -118,16 +118,16 @@ class GeluFfnLayerINT8: public FfnLayerINT8 { template class ReluFfnLayerINT8: public FfnLayerINT8 { public: - ReluFfnLayerINT8(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t inter_size, - int int8_mode, - cudaStream_t stream, + ReluFfnLayerINT8(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t inter_size, + int int8_mode, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward); + IAllocator* allocator, + bool is_free_buffer_after_forward); ReluFfnLayerINT8(ReluFfnLayerINT8 const& ffn_layer); diff --git a/src/fastertransformer/layers/FfnWeight.h b/src/fastertransformer/layers/FfnWeight.h index a079aa47f..0e74698f4 100644 --- a/src/fastertransformer/layers/FfnWeight.h +++ b/src/fastertransformer/layers/FfnWeight.h @@ -23,6 +23,7 @@ namespace fastertransformer { template struct FfnWeight { DenseWeight intermediate_weight; + DenseWeight intermediate_weight2; // for gated activation DenseWeight output_weight; }; diff --git a/src/fastertransformer/layers/TensorParallelGeluFfnLayer.cc b/src/fastertransformer/layers/TensorParallelGeluFfnLayer.cc index a1f3e7e79..509a94229 100644 --- a/src/fastertransformer/layers/TensorParallelGeluFfnLayer.cc +++ b/src/fastertransformer/layers/TensorParallelGeluFfnLayer.cc @@ -19,15 +19,16 @@ namespace fastertransformer { template -void TensorParallelGeluFfnLayer::forward(std::vector* output_tensors, +void TensorParallelGeluFfnLayer::forward(std::vector* output_tensors, const std::vector* input_tensors, - const FfnWeight* ffn_weights) + const FfnWeight* ffn_weights) { - const size_t token_num = output_tensors->at(0).shape[0]; + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + const size_t token_num = output_tensors->at(0).shape[0]; const size_t hidden_units = output_tensors->at(0).shape[1]; bool use_custom_all_reduce_kernel = false; - if (enable_custom_all_reduce_ && custom_all_reduce_comm_ != nullptr) { + if (do_all_reduce_ && enable_custom_all_reduce_ && custom_all_reduce_comm_ != nullptr) { use_custom_all_reduce_kernel = custom_all_reduce_comm_->swapInternalBuffer(output_tensors, token_num * hidden_units); } @@ -35,7 +36,7 @@ void TensorParallelGeluFfnLayer::forward(std::vector::forward(output_tensors, input_tensors, ffn_weights); T* ffn_out = (T*)(output_tensors->at(0).data); - if (tensor_para_.world_size_ > 1) { + if (do_all_reduce_ && tensor_para_.world_size_ > 1) { if (!use_custom_all_reduce_kernel) { ftNcclAllReduceSum(ffn_out, ffn_out, token_num * hidden_units, tensor_para_, GeluFfnLayer::stream_); } @@ -47,20 +48,22 @@ void TensorParallelGeluFfnLayer::forward(std::vector -TensorParallelGeluFfnLayer::TensorParallelGeluFfnLayer(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t inter_size, - NcclParam tensor_para, - cudaStream_t stream, +TensorParallelGeluFfnLayer::TensorParallelGeluFfnLayer(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t inter_size, + NcclParam tensor_para, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool is_sparse, - int int8_mode, + IAllocator* allocator, + bool do_all_reduce, + bool is_free_buffer_after_forward, + bool is_sparse, + int int8_mode, + bool 
use_gated_activation, std::shared_ptr custom_all_reduce_comm, - int enable_custom_all_reduce): + int enable_custom_all_reduce): GeluFfnLayer(max_batch_size, max_seq_len, head_num, @@ -71,11 +74,14 @@ TensorParallelGeluFfnLayer::TensorParallelGeluFfnLayer(size_t max_batch_size, allocator, is_free_buffer_after_forward, is_sparse, - int8_mode), + int8_mode, + use_gated_activation), tensor_para_(tensor_para), custom_all_reduce_comm_(custom_all_reduce_comm), - enable_custom_all_reduce_(enable_custom_all_reduce) + enable_custom_all_reduce_(enable_custom_all_reduce), + do_all_reduce_(do_all_reduce) { + FT_LOG_DEBUG(__PRETTY_FUNCTION__); FT_CHECK(inter_size % tensor_para_.world_size_ == 0); } @@ -84,7 +90,8 @@ TensorParallelGeluFfnLayer::TensorParallelGeluFfnLayer(TensorParallelGeluFfnL GeluFfnLayer(ffn_layer), tensor_para_(ffn_layer.tensor_para_), custom_all_reduce_comm_(ffn_layer.custom_all_reduce_comm_), - enable_custom_all_reduce_(ffn_layer.enable_custom_all_reduce_) + enable_custom_all_reduce_(ffn_layer.enable_custom_all_reduce_), + do_all_reduce_(ffn_layer.do_all_reduce_) { } diff --git a/src/fastertransformer/layers/TensorParallelGeluFfnLayer.h b/src/fastertransformer/layers/TensorParallelGeluFfnLayer.h index 4a9330886..6965fa0b6 100644 --- a/src/fastertransformer/layers/TensorParallelGeluFfnLayer.h +++ b/src/fastertransformer/layers/TensorParallelGeluFfnLayer.h @@ -25,34 +25,37 @@ namespace fastertransformer { template class TensorParallelGeluFfnLayer: public GeluFfnLayer { private: - NcclParam tensor_para_; + NcclParam tensor_para_; std::shared_ptr custom_all_reduce_comm_; - int enable_custom_all_reduce_; + int enable_custom_all_reduce_; + bool do_all_reduce_; protected: public: - TensorParallelGeluFfnLayer(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t inter_size, - NcclParam tensor_para, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool is_sparse = false, - int int8_mode = 0, - std::shared_ptr custom_all_reduce_comm = nullptr, - int enable_custom_all_reduce = 0); + TensorParallelGeluFfnLayer(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t inter_size, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool do_all_reduce, + bool is_free_buffer_after_forward, + bool is_sparse = false, + int int8_mode = 0, + bool use_gated_activation = false, + std::shared_ptr custom_all_reduce_comm = nullptr, + int enable_custom_all_reduce = 0); TensorParallelGeluFfnLayer(TensorParallelGeluFfnLayer const& ffn_layer); virtual ~TensorParallelGeluFfnLayer() = default; - void forward(std::vector* output_tensors, + void forward(std::vector* output_tensors, const std::vector* input_tensors, - const FfnWeight* ffn_weights) override; + const FfnWeight* ffn_weights) override; }; } // namespace fastertransformer diff --git a/src/fastertransformer/layers/TensorParallelReluFfnLayer.cc b/src/fastertransformer/layers/TensorParallelReluFfnLayer.cc index 1b798b9ca..af9df99b1 100644 --- a/src/fastertransformer/layers/TensorParallelReluFfnLayer.cc +++ b/src/fastertransformer/layers/TensorParallelReluFfnLayer.cc @@ -19,11 +19,11 @@ namespace fastertransformer { template -void TensorParallelReluFfnLayer::forward(std::vector* output_tensors, +void TensorParallelReluFfnLayer::forward(std::vector* output_tensors, const std::vector* input_tensors, - const FfnWeight* ffn_weights) + const FfnWeight* 
ffn_weights) { - const size_t token_num = output_tensors->at(0).shape[0]; + const size_t token_num = output_tensors->at(0).shape[0]; const size_t hidden_units = output_tensors->at(0).shape[1]; bool use_custom_all_reduce_kernel = false; @@ -35,7 +35,7 @@ void TensorParallelReluFfnLayer::forward(std::vector::forward(output_tensors, input_tensors, ffn_weights); T* ffn_out = (T*)(output_tensors->at(0).data); - if (tensor_para_.world_size_ > 1) { + if (do_all_reduce_ && tensor_para_.world_size_ > 1) { if (!use_custom_all_reduce_kernel) { ftNcclAllReduceSum(ffn_out, ffn_out, token_num * hidden_units, tensor_para_, ReluFfnLayer::stream_); } @@ -47,19 +47,21 @@ void TensorParallelReluFfnLayer::forward(std::vector -TensorParallelReluFfnLayer::TensorParallelReluFfnLayer(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t inter_size, - NcclParam tensor_para, - cudaStream_t stream, +TensorParallelReluFfnLayer::TensorParallelReluFfnLayer(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t inter_size, + NcclParam tensor_para, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool is_sparse, + IAllocator* allocator, + bool do_all_reduce, + bool is_free_buffer_after_forward, + bool is_sparse, + bool use_gated_activation, std::shared_ptr custom_all_reduce_comm, - int enable_custom_all_reduce): + int enable_custom_all_reduce): ReluFfnLayer(max_batch_size, max_seq_len, head_num, @@ -69,21 +71,26 @@ TensorParallelReluFfnLayer::TensorParallelReluFfnLayer(size_t max_batch_size, cublas_wrapper, allocator, is_free_buffer_after_forward, - is_sparse), + is_sparse, + use_gated_activation), tensor_para_(tensor_para), custom_all_reduce_comm_(custom_all_reduce_comm), - enable_custom_all_reduce_(enable_custom_all_reduce) + enable_custom_all_reduce_(enable_custom_all_reduce), + do_all_reduce_(do_all_reduce) { FT_CHECK(inter_size % tensor_para_.world_size_ == 0); } template TensorParallelReluFfnLayer::TensorParallelReluFfnLayer(TensorParallelReluFfnLayer const& ffn_layer): - ReluFfnLayer(ffn_layer), tensor_para_(ffn_layer.tensor_para_) + ReluFfnLayer(ffn_layer), tensor_para_(ffn_layer.tensor_para_), do_all_reduce_(ffn_layer.do_all_reduce_) { } template class TensorParallelReluFfnLayer; template class TensorParallelReluFfnLayer; +#ifdef ENABLE_BF16 +template class TensorParallelReluFfnLayer<__nv_bfloat16>; +#endif } // namespace fastertransformer \ No newline at end of file diff --git a/src/fastertransformer/layers/TensorParallelReluFfnLayer.h b/src/fastertransformer/layers/TensorParallelReluFfnLayer.h index 8d6adec42..0c307ca94 100644 --- a/src/fastertransformer/layers/TensorParallelReluFfnLayer.h +++ b/src/fastertransformer/layers/TensorParallelReluFfnLayer.h @@ -25,33 +25,36 @@ namespace fastertransformer { template class TensorParallelReluFfnLayer: public ReluFfnLayer { private: - NcclParam tensor_para_; + NcclParam tensor_para_; std::shared_ptr custom_all_reduce_comm_; - int enable_custom_all_reduce_; + int enable_custom_all_reduce_; + bool do_all_reduce_; protected: public: - TensorParallelReluFfnLayer(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t inter_size, - NcclParam tensor_para, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool is_sparse, - std::shared_ptr custom_all_reduce_comm = nullptr, - int enable_custom_all_reduce = 0); + 
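// do_all_reduce selects whether this layer performs the tensor-parallel reduction on its own
// output: when it is false, the NCCL / custom all-reduce after FfnLayer<T>::forward is skipped
// and the output buffer holds each rank's partial sum.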
TensorParallelReluFfnLayer(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t inter_size, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool do_all_reduce, + bool is_free_buffer_after_forward, + bool is_sparse, + bool use_gated_activation = false, + std::shared_ptr custom_all_reduce_comm = nullptr, + int enable_custom_all_reduce = 0); TensorParallelReluFfnLayer(TensorParallelReluFfnLayer const& ffn_layer); virtual ~TensorParallelReluFfnLayer() = default; - void forward(std::vector* output_tensors, + void forward(std::vector* output_tensors, const std::vector* input_tensors, - const FfnWeight* ffn_weights) override; + const FfnWeight* ffn_weights) override; }; } // namespace fastertransformer diff --git a/src/fastertransformer/layers/TensorParallelSiluFfnLayer.cc b/src/fastertransformer/layers/TensorParallelSiluFfnLayer.cc new file mode 100644 index 000000000..63f962cba --- /dev/null +++ b/src/fastertransformer/layers/TensorParallelSiluFfnLayer.cc @@ -0,0 +1,96 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/fastertransformer/layers/TensorParallelSiluFfnLayer.h" + +namespace fastertransformer { + +template +void TensorParallelSiluFfnLayer::forward(std::vector* output_tensors, + const std::vector* input_tensors, + const FfnWeight* ffn_weights) +{ + const size_t token_num = output_tensors->at(0).shape[0]; + const size_t hidden_units = output_tensors->at(0).shape[1]; + + bool use_custom_all_reduce_kernel = false; + if (enable_custom_all_reduce_ && custom_all_reduce_comm_ != nullptr) { + use_custom_all_reduce_kernel = + custom_all_reduce_comm_->swapInternalBuffer(output_tensors, token_num * hidden_units); + } + + SiluFfnLayer::forward(output_tensors, input_tensors, ffn_weights); + + T* ffn_out = (T*)(output_tensors->at(0).data); + if (do_all_reduce_ && tensor_para_.world_size_ > 1) { + if (!use_custom_all_reduce_kernel) { + ftNcclAllReduceSum(ffn_out, ffn_out, token_num * hidden_units, tensor_para_, SiluFfnLayer::stream_); + } + else { + custom_all_reduce_comm_->customAllReduce(token_num * hidden_units, SiluFfnLayer::stream_); + } + sync_check_cuda_error(); + } +} + +template +TensorParallelSiluFfnLayer::TensorParallelSiluFfnLayer(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t inter_size, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool do_all_reduce, + bool is_free_buffer_after_forward, + bool is_sparse, + bool use_gated_activation, + std::shared_ptr custom_all_reduce_comm, + int enable_custom_all_reduce): + SiluFfnLayer(max_batch_size, + max_seq_len, + head_num, + size_per_head, + inter_size / tensor_para.world_size_, + stream, + cublas_wrapper, + allocator, + is_free_buffer_after_forward, + is_sparse, + use_gated_activation), + tensor_para_(tensor_para), + 
custom_all_reduce_comm_(custom_all_reduce_comm), + enable_custom_all_reduce_(enable_custom_all_reduce), + do_all_reduce_(do_all_reduce) +{ + FT_CHECK(inter_size % tensor_para_.world_size_ == 0); +} + +template +TensorParallelSiluFfnLayer::TensorParallelSiluFfnLayer(TensorParallelSiluFfnLayer const& ffn_layer): + SiluFfnLayer(ffn_layer), tensor_para_(ffn_layer.tensor_para_), do_all_reduce_(ffn_layer.do_all_reduce_) +{ +} + +template class TensorParallelSiluFfnLayer; +template class TensorParallelSiluFfnLayer; +#ifdef ENABLE_BF16 +template class TensorParallelSiluFfnLayer<__nv_bfloat16>; +#endif + +} // namespace fastertransformer \ No newline at end of file diff --git a/src/fastertransformer/layers/TensorParallelSiluFfnLayer.h b/src/fastertransformer/layers/TensorParallelSiluFfnLayer.h new file mode 100644 index 000000000..89e0ef7eb --- /dev/null +++ b/src/fastertransformer/layers/TensorParallelSiluFfnLayer.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "src/fastertransformer/layers/FfnLayer.h" +#include "src/fastertransformer/utils/custom_ar_comm.h" +#include "src/fastertransformer/utils/nccl_utils.h" + +namespace fastertransformer { + +template +class TensorParallelSiluFfnLayer: public SiluFfnLayer { +private: + NcclParam tensor_para_; + std::shared_ptr custom_all_reduce_comm_; + int enable_custom_all_reduce_; + bool do_all_reduce_; + +protected: +public: + TensorParallelSiluFfnLayer(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t inter_size, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool do_all_reduce, + bool is_free_buffer_after_forward, + bool is_sparse, + bool use_gated_activation = false, + std::shared_ptr custom_all_reduce_comm = nullptr, + int enable_custom_all_reduce = 0); + + TensorParallelSiluFfnLayer(TensorParallelSiluFfnLayer const& ffn_layer); + + virtual ~TensorParallelSiluFfnLayer() = default; + + void forward(std::vector* output_tensors, + const std::vector* input_tensors, + const FfnWeight* ffn_weights) override; +}; + +} // namespace fastertransformer diff --git a/src/fastertransformer/layers/attention_layers/BaseAttentionLayer.h b/src/fastertransformer/layers/attention_layers/BaseAttentionLayer.h index b21e3a7c2..ee3c64763 100644 --- a/src/fastertransformer/layers/attention_layers/BaseAttentionLayer.h +++ b/src/fastertransformer/layers/attention_layers/BaseAttentionLayer.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -66,18 +66,22 @@ template class BaseAttentionLayer: public BaseLayer { public: - virtual void forward(std::vector* output_tensors, + virtual void forward(std::vector* output_tensors, const std::vector* input_tensors, - const AttentionWeight* attention_weights) = 0; - BaseAttentionLayer(cudaStream_t stream, + const AttentionWeight* attention_weights) = 0; + BaseAttentionLayer(cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse = false): + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse = false): BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, nullptr, sparse) { } virtual ~BaseAttentionLayer() = default; + virtual bool isValidSeqLen(const size_t seq_len) + { + return true; + } }; } // namespace fastertransformer diff --git a/src/fastertransformer/layers/attention_layers/CMakeLists.txt b/src/fastertransformer/layers/attention_layers/CMakeLists.txt index 9cef31540..3af2bd9fd 100644 --- a/src/fastertransformer/layers/attention_layers/CMakeLists.txt +++ b/src/fastertransformer/layers/attention_layers/CMakeLists.txt @@ -37,7 +37,7 @@ target_link_libraries(LongformerAttentionLayer PUBLIC -lcublas -lcudart cublasMM add_library(DecoderSelfAttentionLayer STATIC DecoderSelfAttentionLayer.cc) set_property(TARGET DecoderSelfAttentionLayer PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET DecoderSelfAttentionLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -target_link_libraries(DecoderSelfAttentionLayer PUBLIC -lcublas -lcudart cublasMMWrapper memory_utils decoder_masked_multihead_attention matrix_vector_multiplication) +target_link_libraries(DecoderSelfAttentionLayer PUBLIC -lcublas -lcudart cublasMMWrapper memory_utils decoder_masked_multihead_attention matrix_vector_multiplication tensor) add_library(GptContextAttentionLayer STATIC GptContextAttentionLayer.cc) set_property(TARGET GptContextAttentionLayer PROPERTY POSITION_INDEPENDENT_CODE ON) diff --git a/src/fastertransformer/layers/attention_layers/DecoderCrossAttentionLayer.cu b/src/fastertransformer/layers/attention_layers/DecoderCrossAttentionLayer.cu index fae0f4f52..2b86dfd90 100644 --- a/src/fastertransformer/layers/attention_layers/DecoderCrossAttentionLayer.cu +++ b/src/fastertransformer/layers/attention_layers/DecoderCrossAttentionLayer.cu @@ -22,6 +22,7 @@ #include "3rdparty/cub/cub.cuh" #endif +#include "src/fastertransformer/kernels/bfloat16_fallback_kenrels.cuh" #include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" #include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" #include "src/fastertransformer/kernels/reduce_kernel_utils.cuh" @@ -29,9 +30,9 @@ namespace fastertransformer { -const int WARP_SIZE = 32; -const bool ATTENION_OPT = true; -const int ATTENTION_BLOCK_SIZE = 256; +const int WARP_SIZE = 32; +const bool ATTENION_OPT = true; +const int ATTENTION_BLOCK_SIZE = 256; /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -50,40 +51,40 @@ using Copy_t = Copy_half_t; /////////////////////////////////////////////////////////////////////////////////////////////////// template -__global__ void cross_attention_kernel(T* query_buf, - const T* Q_bias, - T* key_cache, - const T* K_bias, - T* value_cache, - const T* V_bias, - const int* length_per_sample, - T* context_buf, +__global__ void cross_attention_kernel(T* query_buf, + const T* Q_bias, + T* key_cache, + const T* K_bias, + T* value_cache, + 
const T* V_bias, + const int* length_per_sample, + T* context_buf, const bool* finished, - int batch_size, - int head_num, - int size_per_head, - int step, - const int seq_len, - const T scalar) + int batch_size, + int head_num, + int size_per_head, + int step, + const int seq_len, + const T scalar) { if (finished != nullptr && finished[blockIdx.x / head_num] == true) { return; } - int tid = threadIdx.x; - int bid = blockIdx.x / head_num; + int tid = threadIdx.x; + int bid = blockIdx.x / head_num; int head_id = blockIdx.x % head_num; - extern __shared__ __align__(sizeof(T)) unsigned s_buf[]; - T* sq = reinterpret_cast(s_buf); - T* logits = reinterpret_cast(&sq[size_per_head]); + extern __shared__ __align__(sizeof(float)) unsigned s_buf[]; // align on largest type + T* sq = reinterpret_cast(s_buf); + T* logits = reinterpret_cast(&sq[size_per_head]); int length = __ldg(&length_per_sample[bid]); - int qkv_id = bid * head_num * size_per_head + head_id * size_per_head + tid; + int qkv_id = bid * head_num * size_per_head + head_id * size_per_head + tid; int qkv_bias_id = head_id * size_per_head + tid; if (tid < size_per_head) { - sq[tid] = query_buf[qkv_id] + Q_bias[qkv_bias_id]; + sq[tid] = add(query_buf[qkv_id], Q_bias[qkv_bias_id]); } __syncthreads(); @@ -96,12 +97,12 @@ __global__ void cross_attention_kernel(T* query_buf, // For the first step, we should add bias to key memory cache. // The KV memory cache only need to be updated at the first step. if (step == 1 && tid < size_per_head) { - key += K_bias[head_id * size_per_head + tid]; + key = add(key, K_bias[head_id * size_per_head + tid]); key_cache[key_id] = key; } - T val = (tid < size_per_head) ? key * sq[tid] * scalar : (T)(0.0f); - T qk = blockReduceSum(val); + T val = (tid < size_per_head) ? mul(key, sq[tid], scalar) : (T)(0.0f); + T qk = blockReduceSum(val); if (threadIdx.x == 0) { logits[ite] = qk; } @@ -120,7 +121,7 @@ __global__ void cross_attention_kernel(T* query_buf, local_i -= s_max_val; float local_o = tid < length ? 
__expf(local_i) : 0.0f; - float val = blockReduceSum(local_o); + float val = blockReduceSum(local_o); if (tid == 0) { s_sum = val + 1e-6; @@ -141,10 +142,10 @@ __global__ void cross_attention_kernel(T* query_buf, // for the first step, we should add bias to key memory cache if (step == 1) { - value += V_bias[head_id * size_per_head + tid]; + value = add(value, V_bias[head_id * size_per_head + tid]); value_cache[value_id] = value; } - sum += value * logits[ite]; + sum = fma(value, logits[ite], sum); } context_buf[bid * head_num * size_per_head + head_id * size_per_head + tid] = sum; } @@ -160,56 +161,56 @@ __global__ void cross_attention_kernel_opt(T* __restrict query_buf, const int* length_per_sample, T* __restrict context_buf, const bool* finished, - int batch_size, - int head_num, - const int step, - const int seq_len, + int batch_size, + int head_num, + const int step, + const int seq_len, const float scalar) { if (finished != nullptr && finished[blockIdx.x / head_num] == true) { return; } typedef Copy_t copy_t; - const int elems_per_thread = size_per_head / WARP_SIZE; + const int elems_per_thread = size_per_head / WARP_SIZE; union Access_t { copy_t v; - T x[elems_per_thread]; // supported size 1,2,4 + T x[elems_per_thread]; // supported size 1,2,4 }; typedef struct Float_n_t { float x[elems_per_thread]; // supported size 1,2,4 } float_n_t; - __shared__ float_n_t sq[block_sz]; + __shared__ float_n_t sq[block_sz]; extern __shared__ float logits[]; // use to store the logits from [0~step] - const int warp_id = threadIdx.x / WARP_SIZE; + const int warp_id = threadIdx.x / WARP_SIZE; const int warp_num = block_sz / WARP_SIZE; - typedef cub::BlockReduce MaxValBlockReduce; - typedef cub::BlockReduce BlockReduce; + typedef cub::BlockReduce MaxValBlockReduce; + typedef cub::BlockReduce BlockReduce; __shared__ typename MaxValBlockReduce::TempStorage max_val_block_temp_storage; - __shared__ typename BlockReduce::TempStorage block_temp_storage; + __shared__ typename BlockReduce::TempStorage block_temp_storage; __shared__ typename cub::WarpReduce::TempStorage temp_storage[warp_num]; - const int tid = threadIdx.x; - const int bid = blockIdx.x / head_num; + const int tid = threadIdx.x; + const int bid = blockIdx.x / head_num; const int head_id = blockIdx.x % head_num; int length = __ldg(&length_per_sample[bid]); const int lane_id = tid % WARP_SIZE; - int qkv_id = bid * head_num * size_per_head + head_id * size_per_head; + int qkv_id = bid * head_num * size_per_head + head_id * size_per_head; int qkv_bias_id = head_id * size_per_head; int key_value_id = bid * (seq_len * head_num * size_per_head) + +head_id * size_per_head; - query_buf = &query_buf[qkv_id]; - K_bias = &K_bias[qkv_bias_id]; - key_cache = &key_cache[key_value_id]; - Q_bias = &Q_bias[qkv_bias_id]; - V_bias = &V_bias[qkv_bias_id]; + query_buf = &query_buf[qkv_id]; + K_bias = &K_bias[qkv_bias_id]; + key_cache = &key_cache[key_value_id]; + Q_bias = &Q_bias[qkv_bias_id]; + V_bias = &V_bias[qkv_bias_id]; value_cache = &value_cache[key_value_id]; context_buf = &context_buf[qkv_id]; @@ -217,7 +218,7 @@ __global__ void cross_attention_kernel_opt(T* __restrict query_buf, // each warp will have its own copy of sq query_buf_r.v = *((copy_t*)query_buf + lane_id); - bias_r.v = *((copy_t*)Q_bias + lane_id); + bias_r.v = *((copy_t*)Q_bias + lane_id); float qb_r[elems_per_thread]; for (int i = 0; i < elems_per_thread; ++i) { qb_r[i] = (float)query_buf_r.x[i] + (float)bias_r.x[i]; @@ -250,7 +251,7 @@ __global__ void cross_attention_kernel_opt(T* __restrict 
query_buf, __syncthreads(); __shared__ float s_max_val, s_sum; - float local_i = -1e20f; + float local_i = -1e20f; for (int i = tid; i < length; i += blockDim.x) { local_i = max(local_i, logits[i]); } @@ -281,7 +282,7 @@ __global__ void cross_attention_kernel_opt(T* __restrict query_buf, // This optimization introduces discrepancy because of different order in FP32 summation float sum_r[elems_per_thread] = {0.f}; - bias_r.v = *((copy_t*)V_bias + lane_id); + bias_r.v = *((copy_t*)V_bias + lane_id); for (int ite = warp_id; ite < length; ite += warp_num) { key_val_r.v = *((copy_t*)&value_cache[ite * offset] + lane_id); @@ -319,28 +320,39 @@ __global__ void cross_attention_kernel_opt(T* __restrict query_buf, } template -void cross_attention_dispatch(T* query_buf, - const T* Q_bias, - T* key_cache, - const T* K_bias, - T* value_cache, - const T* V_bias, - const int* length, - T* context_buf, - const bool* finished, - const int max_batch_size, - const int inference_batch_size, - const int head_num, - const int size_per_head, - const int step, - const int seq_len, - const bool batch_major_cache, - const float q_scaling, - cudaStream_t stream) +struct CATypeConverter { + using Type = T; +}; + +template<> +struct CATypeConverter { + using Type = uint16_t; +}; + +template +void cross_attention_dispatch(T* query_buf, + const T* Q_bias, + T* key_cache, + const T* K_bias, + T* value_cache, + const T* V_bias, + const int* length, + T* context_buf, + const bool* finished, + const int max_batch_size, + const int inference_batch_size, + const int head_num, + const int size_per_head, + const int step, + const int memory_max_len, + const bool batch_major_cache, + const float q_scaling, + outputCrossAttentionParam output_cross_attention_params, + cudaStream_t stream) { if (!batch_major_cache) { const int block_sz = ATTENTION_BLOCK_SIZE; - float scalar = 1.f / (sqrtf(size_per_head * 1.0f) * q_scaling); + float scalar = 1.f / (sqrtf(size_per_head * 1.0f) * q_scaling); dim3 grid(inference_batch_size * head_num); @@ -348,70 +360,70 @@ void cross_attention_dispatch(T* query_buf, switch (cond) { case 32: cross_attention_kernel_opt - <<>>(query_buf, - Q_bias, - key_cache, - K_bias, - value_cache, - V_bias, - length, - context_buf, - finished, - max_batch_size, - head_num, - step, - seq_len, - scalar); + <<>>(query_buf, + Q_bias, + key_cache, + K_bias, + value_cache, + V_bias, + length, + context_buf, + finished, + max_batch_size, + head_num, + step, + memory_max_len, + scalar); break; case 64: cross_attention_kernel_opt - <<>>(query_buf, - Q_bias, - key_cache, - K_bias, - value_cache, - V_bias, - length, - context_buf, - finished, - max_batch_size, - head_num, - step, - seq_len, - scalar); + <<>>(query_buf, + Q_bias, + key_cache, + K_bias, + value_cache, + V_bias, + length, + context_buf, + finished, + max_batch_size, + head_num, + step, + memory_max_len, + scalar); break; case 128: cross_attention_kernel_opt - <<>>(query_buf, - Q_bias, - key_cache, - K_bias, - value_cache, - V_bias, - length, - context_buf, - finished, - max_batch_size, - head_num, - step, - seq_len, - scalar); + <<>>(query_buf, + Q_bias, + key_cache, + K_bias, + value_cache, + V_bias, + length, + context_buf, + finished, + max_batch_size, + head_num, + step, + memory_max_len, + scalar); break; default: // default path int block_size = 128; - if (seq_len <= 64) { + if (memory_max_len <= 64) { block_size = 64; } - else if (seq_len <= 128 && seq_len > size_per_head) { + else if (memory_max_len <= 128 && memory_max_len > size_per_head) { block_size = 
128; } - else if (seq_len > 128 && seq_len <= 256) { + else if (memory_max_len > 128 && memory_max_len <= 256) { block_size = 256; } - else if (seq_len > 256 && seq_len <= 512) { + else if (memory_max_len > 256 && memory_max_len <= 512) { block_size = 512; } else { @@ -425,7 +437,7 @@ void cross_attention_dispatch(T* query_buf, assert(block_size <= 1024); dim3 block(block_size); - int shared_size = sizeof(T) * (size_per_head + seq_len); + int shared_size = sizeof(T) * (size_per_head + memory_max_len); cross_attention_kernel<<>>(query_buf, Q_bias, key_cache, @@ -439,17 +451,17 @@ void cross_attention_dispatch(T* query_buf, head_num, size_per_head, step, - seq_len, + memory_max_len, scalar); } } else { assert(step > 0); // assert(size_per_head == 32 || size_per_head == 64 || size_per_head == 128); - using DataType = typename std::conditional::type; + // using DataType = typename std::conditional::type; + using DataType = typename CATypeConverter::Type; // Prepare the parameters. Cross_multihead_attention_params params; - memset(¶ms, 0, sizeof(params)); params.q_bias = reinterpret_cast(Q_bias); params.k_bias = reinterpret_cast(K_bias); params.v_bias = reinterpret_cast(V_bias); @@ -458,67 +470,96 @@ void cross_attention_dispatch(T* query_buf, params.out = reinterpret_cast(context_buf); // Set the input buffers. - params.q = reinterpret_cast(query_buf); - params.k = nullptr; - params.v = nullptr; - params.stride = 0; + params.q = reinterpret_cast(query_buf); + params.k = nullptr; + params.v = nullptr; + params.stride = 0; params.finished = const_cast(finished); params.memory_length_per_sample = const_cast(length); - params.k_cache = reinterpret_cast(key_cache); - params.v_cache = reinterpret_cast(value_cache); + params.k_cache = reinterpret_cast(key_cache); + params.v_cache = reinterpret_cast(value_cache); params.batch_size = inference_batch_size; // TODO(bhsueh) We can use batch but not batch * beam_width in k/v cache in cross attention // because they are same for all beams. - params.beam_width = 1; // We don't care the beam_width in cross attention, set to 1 is enough. - params.seq_length = seq_len; - params.timestep = step - 1; - params.num_heads = head_num; + params.beam_width = 1; // We don't care the beam_width in cross attention, set to 1 is enough. 
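// memory_max_len (previously named seq_len here) is the maximum encoder-memory length, i.e. the
// allocated stride of key_cache / value_cache; the actual per-sample memory lengths still come
// from params.memory_length_per_sample above.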
+ params.memory_max_len = memory_max_len; + params.timestep = step - 1; + params.num_heads = head_num; params.hidden_size_per_head = size_per_head; - params.inv_sqrt_dh = 1.F / (sqrtf((float)params.hidden_size_per_head) * q_scaling); + params.inv_sqrt_dh = 1.F / (sqrtf((float)params.hidden_size_per_head) * q_scaling); + + // output cross attentions + params.max_decoder_seq_len = output_cross_attention_params.max_decoder_seq_len; + params.cross_attention_out = output_cross_attention_params.cross_attention_out; + params.is_return_cross_attentions = output_cross_attention_params.is_return_cross_attentions; cross_multihead_attention(params, stream); } } -template void cross_attention_dispatch(float* query_buf, - const float* Q_bias, - float* key_cache, - const float* K_bias, - float* value_cache, - const float* V_bias, - const int* length, - float* context_buf, - const bool* finished, - const int max_batch_size, - const int inference_batch_size, - const int head_num, - const int size_per_head, - const int step, - const int seq_len, - const bool batch_major_cache, - const float q_scaling, - cudaStream_t stream); - -template void cross_attention_dispatch(half* query_buf, - const half* Q_bias, - half* key_cache, - const half* K_bias, - half* value_cache, - const half* V_bias, - const int* length, - half* context_buf, - const bool* finished, - const int max_batch_size, - const int inference_batch_size, - const int head_num, - const int size_per_head, - const int step, - const int seq_len, - const bool batch_major_cache, - const float q_scaling, - cudaStream_t stream); +template void cross_attention_dispatch(float* query_buf, + const float* Q_bias, + float* key_cache, + const float* K_bias, + float* value_cache, + const float* V_bias, + const int* length, + float* context_buf, + const bool* finished, + const int max_batch_size, + const int inference_batch_size, + const int head_num, + const int size_per_head, + const int step, + const int memory_max_len, + const bool batch_major_cache, + const float q_scaling, + outputCrossAttentionParam output_cross_attention_params, + cudaStream_t stream); + +template void cross_attention_dispatch(half* query_buf, + const half* Q_bias, + half* key_cache, + const half* K_bias, + half* value_cache, + const half* V_bias, + const int* length, + half* context_buf, + const bool* finished, + const int max_batch_size, + const int inference_batch_size, + const int head_num, + const int size_per_head, + const int step, + const int memory_max_len, + const bool batch_major_cache, + const float q_scaling, + outputCrossAttentionParam output_cross_attention_params, + cudaStream_t stream); + +#ifdef ENABLE_BF16 +template void cross_attention_dispatch(__nv_bfloat16* query_buf, + const __nv_bfloat16* Q_bias, + __nv_bfloat16* key_cache, + const __nv_bfloat16* K_bias, + __nv_bfloat16* value_cache, + const __nv_bfloat16* V_bias, + const int* length, + __nv_bfloat16* context_buf, + const bool* finished, + const int max_batch_size, + const int inference_batch_size, + const int head_num, + const int size_per_head, + const int step, + const int memory_max_len, + const bool batch_major_cache, + const float q_scaling, + outputCrossAttentionParam output_cross_attention_params, + cudaStream_t stream); +#endif // Currently need to transpose at the first step in Cross attention template @@ -526,24 +567,24 @@ __global__ void transpose_4d_batch_major_mem_k_cache( T* k_dst, const T* k_src, const int head_num, const int size_per_head, const int max_seq_len) { // B, L, H, Dh -> B, H, Dh/x, L, x - 
const int batch_id = blockIdx.y; - const int head_id = blockIdx.z; - constexpr int X_ELEMS = (sizeof(T) == 4) ? 4 : 8; + const int batch_id = blockIdx.y; + const int head_id = blockIdx.z; + constexpr int X_ELEMS = (sizeof(T) == 4) ? 4 : 8; auto key_src = reinterpret_cast(k_src + batch_id * head_num * size_per_head * max_seq_len + head_id * size_per_head); auto key_dst = reinterpret_cast(k_dst + batch_id * head_num * size_per_head * max_seq_len + head_id * size_per_head * max_seq_len); - const int out_idx = blockIdx.x * blockDim.x + threadIdx.x; - int size_per_head_div_x = size_per_head / X_ELEMS; + const int out_idx = blockIdx.x * blockDim.x + threadIdx.x; + int size_per_head_div_x = size_per_head / X_ELEMS; if (out_idx >= size_per_head_div_x * max_seq_len) { return; } - int idx = out_idx; - const int k_seq_len_id = idx % max_seq_len; - idx = (idx - k_seq_len_id) / max_seq_len; + int idx = out_idx; + const int k_seq_len_id = idx % max_seq_len; + idx = (idx - k_seq_len_id) / max_seq_len; const int k_head_size_id = idx % size_per_head_div_x; key_dst[out_idx] = key_src[k_seq_len_id * head_num * size_per_head_div_x + k_head_size_id]; @@ -555,7 +596,7 @@ __global__ void transpose_4d_batch_major_mem_v_cache( { // B, L, H, Dh -> B, H, L, Dh const int batch_id = blockIdx.y; - const int head_id = blockIdx.z; + const int head_id = blockIdx.z; // 16 byte loads will handle "x" dimension auto val_src = reinterpret_cast(v_src + batch_id * head_num * size_per_head * max_seq_len @@ -566,35 +607,35 @@ __global__ void transpose_4d_batch_major_mem_v_cache( // idx is over output dimension L * size_per_head / x for values const int out_idx = blockIdx.x * blockDim.x + threadIdx.x; - constexpr int X_ELEMS = (sizeof(T) == 4) ? 4 : 8; - const int size_per_head_div_x = size_per_head / X_ELEMS; + constexpr int X_ELEMS = (sizeof(T) == 4) ? 4 : 8; + const int size_per_head_div_x = size_per_head / X_ELEMS; if (out_idx >= size_per_head_div_x * max_seq_len) { return; } - int idx = out_idx; + int idx = out_idx; const int v_head_size_id = idx % size_per_head_div_x; - idx = (idx - v_head_size_id) / size_per_head_div_x; - const int v_seq_len_id = idx % max_seq_len; + idx = (idx - v_head_size_id) / size_per_head_div_x; + const int v_seq_len_id = idx % max_seq_len; val_dst[out_idx] = val_src[v_seq_len_id * head_num * size_per_head_div_x + v_head_size_id]; } template -void transpose_4d_batch_major_memory_kernelLauncher(T* dst, - const T* src, - const int local_batch_size, - const int max_seq_len, - const int size_per_head, - const int local_head_num, - const bool k_cache, +void transpose_4d_batch_major_memory_kernelLauncher(T* dst, + const T* src, + const int local_batch_size, + const int max_seq_len, + const int size_per_head, + const int local_head_num, + const bool k_cache, cudaStream_t stream) { constexpr int block_sz = 128; - constexpr int x = (sizeof(T) == 4) ? 4 : 8; - int size = max_seq_len * size_per_head / x; - dim3 grid((size + block_sz - 1) / block_sz, local_batch_size, local_head_num); + constexpr int x = (sizeof(T) == 4) ? 
4 : 8; + int size = max_seq_len * size_per_head / x; + dim3 grid((size + block_sz - 1) / block_sz, local_batch_size, local_head_num); if (k_cache) { transpose_4d_batch_major_mem_k_cache<<>>( @@ -606,34 +647,46 @@ void transpose_4d_batch_major_memory_kernelLauncher(T* dst, } } -template void transpose_4d_batch_major_memory_kernelLauncher(float* dst, +template void transpose_4d_batch_major_memory_kernelLauncher(float* dst, const float* src, - const int local_batch_size, - const int max_seq_len, - const int size_per_head, - const int local_head_num, - const bool k_cache, + const int local_batch_size, + const int max_seq_len, + const int size_per_head, + const int local_head_num, + const bool k_cache, cudaStream_t stream); -template void transpose_4d_batch_major_memory_kernelLauncher(half* dst, - const half* src, - const int local_batch_size, - const int max_seq_len, - const int size_per_head, - const int local_head_num, - const bool k_cache, +template void transpose_4d_batch_major_memory_kernelLauncher(half* dst, + const half* src, + const int local_batch_size, + const int max_seq_len, + const int size_per_head, + const int local_head_num, + const bool k_cache, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void transpose_4d_batch_major_memory_kernelLauncher(__nv_bfloat16* dst, + const __nv_bfloat16* src, + const int local_batch_size, + const int max_seq_len, + const int size_per_head, + const int local_head_num, + const bool k_cache, + cudaStream_t stream); +#endif + template void DecoderCrossAttentionLayer::allocateBuffer() { if (is_allocate_buffer_ == false) { - q_buf_ = reinterpret_cast(allocator_->malloc(sizeof(T) * max_batch_size_ * hidden_units_, false)); - context_buf_ = reinterpret_cast(allocator_->malloc(sizeof(T) * max_batch_size_ * hidden_units_, false)); + q_buf_ = reinterpret_cast(allocator_->reMalloc(q_buf_, sizeof(T) * max_batch_size_ * hidden_units_, false)); + context_buf_ = reinterpret_cast( + allocator_->reMalloc(context_buf_, sizeof(T) * max_batch_size_ * hidden_units_, false)); if (is_batch_major_cache_) { - mem_cache_buf_ = reinterpret_cast( - allocator_->malloc(sizeof(T) * max_batch_size_ * max_mem_seq_len_ * hidden_units_, false)); + mem_cache_buf_ = reinterpret_cast(allocator_->reMalloc( + mem_cache_buf_, sizeof(T) * max_batch_size_ * max_mem_seq_len_ * hidden_units_, false)); } is_allocate_buffer_ = true; } @@ -658,10 +711,10 @@ template void DecoderCrossAttentionLayer::freeBuffer() { if (is_allocate_buffer_) { - allocator_->free(q_buf_); - allocator_->free(context_buf_); + allocator_->free((void**)(&q_buf_)); + allocator_->free((void**)(&context_buf_)); if (is_batch_major_cache_) { - allocator_->free(mem_cache_buf_); + allocator_->free((void**)(&mem_cache_buf_)); } is_allocate_buffer_ = false; } @@ -694,15 +747,15 @@ bool DecoderCrossAttentionLayer::isValidSeqLen(size_t seq_len) } template -DecoderCrossAttentionLayer::DecoderCrossAttentionLayer(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - size_t d_model, - const float q_scaling, - cudaStream_t stream, +DecoderCrossAttentionLayer::DecoderCrossAttentionLayer(size_t max_batch_size, + size_t head_num, + size_t size_per_head, + size_t d_model, + const float q_scaling, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward): + IAllocator* allocator, + bool is_free_buffer_after_forward): BaseAttentionLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward), max_batch_size_(max_batch_size), head_num_(head_num), @@ 
-711,19 +764,19 @@ DecoderCrossAttentionLayer::DecoderCrossAttentionLayer(size_t max_batch_size, d_model_(d_model), q_scaling_(q_scaling) { - FT_CHECK(size_per_head_ == 32 || size_per_head_ == 64 || size_per_head_ == 96 || size_per_head_ == 80 + FT_CHECK(size_per_head_ == 32 || size_per_head_ == 64 || size_per_head_ == 80 || size_per_head_ == 96 || size_per_head_ == 128 || size_per_head_ == 160 || size_per_head_ == 192 || size_per_head_ == 224 || size_per_head_ == 256); } template -DecoderCrossAttentionLayer::DecoderCrossAttentionLayer(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - cudaStream_t stream, +DecoderCrossAttentionLayer::DecoderCrossAttentionLayer(size_t max_batch_size, + size_t head_num, + size_t size_per_head, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward): + IAllocator* allocator, + bool is_free_buffer_after_forward): DecoderCrossAttentionLayer(max_batch_size, head_num, size_per_head, @@ -737,14 +790,14 @@ DecoderCrossAttentionLayer::DecoderCrossAttentionLayer(size_t max_batch_size, } template -DecoderCrossAttentionLayer::DecoderCrossAttentionLayer(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - const float q_scaling, - cudaStream_t stream, +DecoderCrossAttentionLayer::DecoderCrossAttentionLayer(size_t max_batch_size, + size_t head_num, + size_t size_per_head, + const float q_scaling, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward): + IAllocator* allocator, + bool is_free_buffer_after_forward): DecoderCrossAttentionLayer(max_batch_size, head_num, size_per_head, @@ -759,19 +812,16 @@ DecoderCrossAttentionLayer::DecoderCrossAttentionLayer(size_t max_batch_size, template DecoderCrossAttentionLayer::DecoderCrossAttentionLayer(DecoderCrossAttentionLayer const& attention_layer): - BaseAttentionLayer(attention_layer.stream_, - attention_layer.cublas_wrapper_, - attention_layer.allocator_, - attention_layer.is_free_buffer_after_forward_), - max_batch_size_(attention_layer.max_batch_size_), - head_num_(attention_layer.head_num_), - size_per_head_(attention_layer.size_per_head_), - hidden_units_(attention_layer.hidden_units_), - d_model_(attention_layer.d_model_), - q_scaling_(attention_layer.q_scaling_) + DecoderCrossAttentionLayer(attention_layer.max_batch_size_, + attention_layer.head_num_, + attention_layer.size_per_head_, + attention_layer.d_model_, + attention_layer.q_scaling_, + attention_layer.stream_, + attention_layer.cublas_wrapper_, + attention_layer.allocator_, + attention_layer.is_free_buffer_after_forward_) { - FT_CHECK(size_per_head_ == 32 || size_per_head_ == 64 || size_per_head_ == 96 || size_per_head_ == 128 - || size_per_head_ == 160 || size_per_head_ == 192 || size_per_head_ == 224 || size_per_head_ == 256); } template @@ -782,9 +832,9 @@ DecoderCrossAttentionLayer::~DecoderCrossAttentionLayer() } template -void DecoderCrossAttentionLayer::forward(std::vector* output_tensors, +void DecoderCrossAttentionLayer::forward(std::vector* output_tensors, const std::vector* input_tensors, - const AttentionWeight* attention_weights) + const AttentionWeight* attention_weights) { // input tensors: // attention_input [batch_size, d_model], @@ -797,24 +847,28 @@ void DecoderCrossAttentionLayer::forward(std::vectorsize() == 5); - FT_CHECK(output_tensors->size() == 3); + FT_CHECK(output_tensors->size() == 3 || output_tensors->size() == 4); FT_CHECK(isValidBatchSize(input_tensors->at(0).shape[0])); 
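// Illustrative sketch (not from the patch): the parameter bundle that carries the
// optional cross-attention output through cross_attention_dispatch. The field names
// come from the assignments visible in this diff; the exact declaration -- whether
// it is templated and the element type of cross_attention_out -- is an assumption.
struct outputCrossAttentionParam {
    int    max_decoder_seq_len        = 0;       // decoder length of the output buffer
    float* cross_attention_out        = nullptr; // written only when requested
    bool   is_return_cross_attentions = false;   // enables the extra write
};
// forward() only fills this in when the caller supplies a fourth output tensor,
// which is what the size() == 3 || size() == 4 check above allows.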
FT_CHECK(isValidSeqLen(input_tensors->at(1).shape[1])); allocateBuffer(input_tensors->at(0).shape[0], input_tensors->at(1).shape[1]); - const T* attention_input = reinterpret_cast(input_tensors->at(0).data); - Tensor encoder_output_tensor = input_tensors->at(1); - const int* memory_sequence_length = reinterpret_cast(input_tensors->at(2).data); - const bool* finished = reinterpret_cast(input_tensors->at(3).data); - const int step = *reinterpret_cast(input_tensors->at(4).data); + const T* attention_input = reinterpret_cast(input_tensors->at(0).data); + Tensor encoder_output_tensor = input_tensors->at(1); + const int* memory_sequence_length = reinterpret_cast(input_tensors->at(2).data); + const bool* finished = reinterpret_cast(input_tensors->at(3).data); + const int step = *reinterpret_cast(input_tensors->at(4).data); - T* attention_out = (T*)(output_tensors->at(0).data); - T* key_mem_cache = (T*)(output_tensors->at(1).data); + T* attention_out = (T*)(output_tensors->at(0).data); + T* key_mem_cache = (T*)(output_tensors->at(1).data); T* value_mem_cache = (T*)(output_tensors->at(2).data); - const int batch_size = input_tensors->at(0).shape[0]; + const bool output_cross_attentions = output_tensors->size() == 4; + const int max_decoder_seq_len = output_cross_attentions ? output_tensors->at(3).shape[2] : 0; + + const int batch_size = input_tensors->at(0).shape[0]; const int mem_max_seq_len = encoder_output_tensor.shape[1]; cublas_wrapper_->Gemm(CUBLAS_OP_N, CUBLAS_OP_N, @@ -894,6 +948,14 @@ void DecoderCrossAttentionLayer::forward(std::vector output_attention_param{}; + // output cross attentions + if (output_cross_attentions) { + output_attention_param.max_decoder_seq_len = max_decoder_seq_len; + output_attention_param.cross_attention_out = output_tensors->at(3).getPtr(); + output_attention_param.is_return_cross_attentions = true; + } + cross_attention_dispatch(q_buf_, attention_weights->query_weight.bias, key_mem_cache, @@ -911,6 +973,7 @@ void DecoderCrossAttentionLayer::forward(std::vectorGemm(CUBLAS_OP_N, @@ -931,5 +994,8 @@ void DecoderCrossAttentionLayer::forward(std::vector; template class DecoderCrossAttentionLayer; +#ifdef ENABLE_BF16 +template class DecoderCrossAttentionLayer<__nv_bfloat16>; +#endif -} // namespace fastertransformer +} // namespace fastertransformer \ No newline at end of file diff --git a/src/fastertransformer/layers/attention_layers/DecoderCrossAttentionLayer.h b/src/fastertransformer/layers/attention_layers/DecoderCrossAttentionLayer.h index b92dba84b..234a08f5c 100644 --- a/src/fastertransformer/layers/attention_layers/DecoderCrossAttentionLayer.h +++ b/src/fastertransformer/layers/attention_layers/DecoderCrossAttentionLayer.h @@ -27,14 +27,14 @@ class DecoderCrossAttentionLayer: public BaseAttentionLayer { const size_t head_num_; const size_t size_per_head_; const size_t d_model_; - bool is_batch_major_cache_ = true; + bool is_batch_major_cache_ = true; // calculated params const size_t hidden_units_; - const float q_scaling_; + const float q_scaling_; // buffer handling - size_t max_batch_size_ = 0; + size_t max_batch_size_ = 0; size_t max_mem_seq_len_ = 0; void allocateBuffer() override; @@ -50,45 +50,45 @@ class DecoderCrossAttentionLayer: public BaseAttentionLayer { using BaseAttentionLayer::cublas_wrapper_; using BaseAttentionLayer::allocator_; - T* q_buf_ = nullptr; - T* context_buf_ = nullptr; + T* q_buf_ = nullptr; + T* context_buf_ = nullptr; T* mem_cache_buf_ = nullptr; public: - DecoderCrossAttentionLayer(size_t max_batch_size, - size_t head_num, - 
size_t size_per_head, - cudaStream_t stream, + DecoderCrossAttentionLayer(size_t max_batch_size, + size_t head_num, + size_t size_per_head, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward); - - DecoderCrossAttentionLayer(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - const float q_scaling, - cudaStream_t stream, + IAllocator* allocator, + bool is_free_buffer_after_forward); + + DecoderCrossAttentionLayer(size_t max_batch_size, + size_t head_num, + size_t size_per_head, + const float q_scaling, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward); - - DecoderCrossAttentionLayer(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - size_t d_model, - const float q_scaling, - cudaStream_t stream, + IAllocator* allocator, + bool is_free_buffer_after_forward); + + DecoderCrossAttentionLayer(size_t max_batch_size, + size_t head_num, + size_t size_per_head, + size_t d_model, + const float q_scaling, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward); + IAllocator* allocator, + bool is_free_buffer_after_forward); DecoderCrossAttentionLayer(DecoderCrossAttentionLayer const& attention_layer); ~DecoderCrossAttentionLayer(); - void forward(std::vector* output_tensors, + void forward(std::vector* output_tensors, const std::vector* input_tensors, - const AttentionWeight* attention_weights) override; + const AttentionWeight* attention_weights) override; }; } // namespace fastertransformer diff --git a/src/fastertransformer/layers/attention_layers/DecoderSelfAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/DecoderSelfAttentionLayer.cc index e6d8be22f..7d825f739 100644 --- a/src/fastertransformer/layers/attention_layers/DecoderSelfAttentionLayer.cc +++ b/src/fastertransformer/layers/attention_layers/DecoderSelfAttentionLayer.cc @@ -21,40 +21,44 @@ namespace fastertransformer { template -struct TypeConverter { +struct SATypeConverter { using Type = T; }; template<> -struct TypeConverter { +struct SATypeConverter { using Type = uint16_t; }; template -void fusedQKV_masked_attention_dispatch(const T* qkv_buf, - const T* qkv_bias, - const T* relative_attention_bias, - T* key_cache, - T* value_cache, - const int* cache_indir, - T* context_buf, - const bool* finished, - const int* sequence_lengths, - const int max_batch_size, - const int inference_batch_size, - const int beam_width, - const int head_num, - const int size_per_head, - const int rotary_embedding_dim, - const int max_seq_len, - const int max_input_len, - const int* input_lengths, - const int step, - const float q_scaling, - const int relative_attention_bias_stride, +void fusedQKV_masked_attention_dispatch(const T* qkv_buf, + const T* qkv_bias, + const T* relative_attention_bias, + T* key_cache, + T* value_cache, + const int* cache_indir, + T* context_buf, + const bool* finished, + const int* sequence_lengths, + const int max_batch_size, + const int inference_batch_size, + const int beam_width, + const int head_num, + const int size_per_head, + const int rotary_embedding_dim, + const bool neox_rotary_style, + const int memory_max_len, + const int* prefix_prompt_lengths, + const int max_prefix_prompt_length, + const int max_input_len, + const int* total_padding_tokens, + const int step, + const float q_scaling, + const int relative_attention_bias_stride, + const bool* masked_tokens, cudaStream_t stream) { 
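// Illustrative sketch (not from the patch) of how the dispatch body below carves the
// fused QKV activation into the three views the masked-attention kernel consumes:
// each row of the [batch, 3 * hidden_units] buffer holds [ Q | K | V ], so Q, K and V
// are just offset pointers sharing a stride of 3 * hidden_units. This mirrors the
// params.q / params.k / params.v / params.stride assignments that follow; the names
// QkvViews and slice_fused_qkv are illustrative only.
struct QkvViews {
    const float* q;
    const float* k;
    const float* v;
    int          stride;  // elements between the same view of consecutive rows
};

inline QkvViews slice_fused_qkv(const float* qkv_buf, int head_num, int size_per_head)
{
    const int hidden_units = head_num * size_per_head;
    return QkvViews{qkv_buf,                     // params.q
                    qkv_buf + hidden_units,      // params.k
                    qkv_buf + 2 * hidden_units,  // params.v
                    3 * hidden_units};           // params.stride
}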
- using DataType = typename TypeConverter::Type; + using DataType = typename SATypeConverter::Type; // Prepare the parameters. Masked_multihead_attention_params params; memset(¶ms, 0, sizeof(params)); @@ -74,38 +78,36 @@ void fusedQKV_masked_attention_dispatch(const T* qkv_buf, params.out = reinterpret_cast(context_buf); // Set the input buffers. - params.q = reinterpret_cast(qkv_buf); - params.k = reinterpret_cast(qkv_buf) + hidden_units; - params.v = reinterpret_cast(qkv_buf) + 2 * hidden_units; - params.stride = 3 * hidden_units; + params.q = reinterpret_cast(qkv_buf); + params.k = reinterpret_cast(qkv_buf) + hidden_units; + params.v = reinterpret_cast(qkv_buf) + 2 * hidden_units; + params.stride = 3 * hidden_units; params.finished = const_cast(finished); - params.k_cache = reinterpret_cast(key_cache); - params.v_cache = reinterpret_cast(value_cache); - params.cache_indir = cache_indir; - params.batch_size = inference_batch_size; - params.beam_width = beam_width; - params.seq_length = max_seq_len; - params.length_per_sample = sequence_lengths; - params.timestep = step - 1; - params.num_heads = head_num; + params.k_cache = reinterpret_cast(key_cache); + params.v_cache = reinterpret_cast(value_cache); + params.cache_indir = cache_indir; + params.batch_size = inference_batch_size; + params.beam_width = beam_width; + params.memory_max_len = memory_max_len; + params.prefix_prompt_lengths = prefix_prompt_lengths; + params.max_prefix_prompt_length = max_prefix_prompt_length; + params.length_per_sample = sequence_lengths; // max_input_length + current output length + // timestep adding max_prefix_prompt_length for shared memory size calculation and rotary embedding computation + params.timestep = step + max_prefix_prompt_length - 1; + params.num_heads = head_num; params.hidden_size_per_head = size_per_head; params.rotary_embedding_dim = rotary_embedding_dim; + params.neox_rotary_style = neox_rotary_style; // Note: keep norm factor (sqrt(K_dim)) when adopting megatron T5 structure (may adjust) params.inv_sqrt_dh = 1.F / (sqrtf((float)params.hidden_size_per_head) * q_scaling); - params.input_lengths = input_lengths; - params.max_input_len = max_input_len; - // TODO(bhsueh) Need better implementation + params.total_padding_tokens = total_padding_tokens; if (relative_attention_bias != nullptr) { - if (sizeof(T) == 4) { - params.relative_attention_bias_float = reinterpret_cast(relative_attention_bias); - } - else { - params.relative_attention_bias_half = reinterpret_cast(relative_attention_bias); - } + params.relative_attention_bias = reinterpret_cast(relative_attention_bias); } params.relative_attention_bias_stride = relative_attention_bias_stride; + params.masked_tokens = masked_tokens; masked_multihead_attention(params, stream); } @@ -113,57 +115,65 @@ void fusedQKV_masked_attention_dispatch(const T* qkv_buf, template void fusedQKV_masked_attention_dispatch(const float* qkv_buf, const float* qkv_bias, const float* relative_attention_bias, - float* key_cache, - float* value_cache, - const int* cache_indir, - float* context_buf, - const bool* finished, - const int* sequence_lengths, - const int max_batch_size, - const int inference_batch_size, - const int beam_width, - const int head_num, - const int size_per_head, - const int rotary_embedding_dim, - const int max_seq_len, - const int max_input_len, - const int* input_lengths, - const int step, - const float q_scaling, - const int relative_attention_bias_stride, + float* key_cache, + float* value_cache, + const int* cache_indir, + float* 
context_buf, + const bool* finished, + const int* sequence_lengths, + const int max_batch_size, + const int inference_batch_size, + const int beam_width, + const int head_num, + const int size_per_head, + const int rotary_embedding_dim, + const bool neox_rotary_style, + const int memory_max_len, + const int* prefix_prompt_lengths, + const int max_prefix_prompt_length, + const int max_input_len, + const int* total_padding_tokens, + const int step, + const float q_scaling, + const int relative_attention_bias_stride, + const bool* masked_tokens, cudaStream_t stream); -template void fusedQKV_masked_attention_dispatch(const half* qkv_buf, - const half* qkv_bias, - const half* relative_attention_bias, - half* key_cache, - half* value_cache, - const int* cache_indir, - half* context_buf, - const bool* finished, - const int* sequence_lengths, - const int max_batch_size, - const int inference_batch_size, - const int beam_width, - const int head_num, - const int size_per_head, - const int rotary_embedding_dim, - const int max_seq_len, - const int max_input_len, - const int* input_lengths, - const int step, - const float q_scaling, - const int relative_attention_bias_stride, +template void fusedQKV_masked_attention_dispatch(const half* qkv_buf, + const half* qkv_bias, + const half* relative_attention_bias, + half* key_cache, + half* value_cache, + const int* cache_indir, + half* context_buf, + const bool* finished, + const int* sequence_lengths, + const int max_batch_size, + const int inference_batch_size, + const int beam_width, + const int head_num, + const int size_per_head, + const int rotary_embedding_dim, + const bool neox_rotary_style, + const int memory_max_len, + const int* prefix_prompt_lengths, + const int max_prefix_prompt_length, + const int max_input_len, + const int* total_padding_tokens, + const int step, + const float q_scaling, + const int relative_attention_bias_stride, + const bool* masked_tokens, cudaStream_t stream); template void DecoderSelfAttentionLayer::allocateBuffer() { if (is_allocate_buffer_ == false) { - qkv_buf_ = - reinterpret_cast(allocator_->malloc(sizeof(T) * max_batch_size_ * 3 * local_hidden_units_, false)); - context_buf_ = - reinterpret_cast(allocator_->malloc(sizeof(T) * max_batch_size_ * local_hidden_units_, false)); + qkv_buf_ = reinterpret_cast( + allocator_->reMalloc(qkv_buf_, sizeof(T) * max_batch_size_ * 3 * local_hidden_units_, false)); + context_buf_ = reinterpret_cast( + allocator_->reMalloc(context_buf_, sizeof(T) * max_batch_size_ * local_hidden_units_, false)); is_allocate_buffer_ = true; } } @@ -183,8 +193,8 @@ template void DecoderSelfAttentionLayer::freeBuffer() { if (is_allocate_buffer_) { - allocator_->free(qkv_buf_); - allocator_->free(context_buf_); + allocator_->free((void**)(&qkv_buf_)); + allocator_->free((void**)(&context_buf_)); is_allocate_buffer_ = false; } } @@ -203,19 +213,20 @@ bool DecoderSelfAttentionLayer::isValidBatchSize(size_t batch_size) } template -DecoderSelfAttentionLayer::DecoderSelfAttentionLayer(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - size_t local_head_num, - size_t rotary_embedding_dim, - size_t d_model, - const float q_scaling, - cudaStream_t stream, +DecoderSelfAttentionLayer::DecoderSelfAttentionLayer(size_t max_batch_size, + size_t head_num, + size_t size_per_head, + size_t local_head_num, + size_t rotary_embedding_dim, + bool neox_rotary_style, + size_t d_model, + const float q_scaling, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool 
is_free_buffer_after_forward, - bool sparse, - int int8_mode): + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse, + int int8_mode): BaseAttentionLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, sparse), max_batch_size_(max_batch_size), head_num_(head_num), @@ -224,30 +235,32 @@ DecoderSelfAttentionLayer::DecoderSelfAttentionLayer(size_t max_batch_size, local_head_num_(local_head_num), local_hidden_units_(local_head_num_ * size_per_head_), rotary_embedding_dim_(rotary_embedding_dim), + neox_rotary_style_(neox_rotary_style), d_model_(d_model), q_scaling_(q_scaling), int8_mode_(int8_mode) { - FT_CHECK(size_per_head_ == 32 || size_per_head_ == 64 || size_per_head_ == 80 || size_per_head_ == 96 + FT_CHECK(size_per_head_ == 32 || size_per_head_ == 64 || size_per_head_ == 80 || size_per_head_ == 96 || size_per_head_ == 128 || size_per_head_ == 160 || size_per_head_ == 192 || size_per_head_ == 224 || size_per_head_ == 256); } template -DecoderSelfAttentionLayer::DecoderSelfAttentionLayer(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - cudaStream_t stream, +DecoderSelfAttentionLayer::DecoderSelfAttentionLayer(size_t max_batch_size, + size_t head_num, + size_t size_per_head, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse, - int int8_mode): + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse, + int int8_mode): DecoderSelfAttentionLayer(max_batch_size, head_num, size_per_head, head_num, 0, + false, head_num * size_per_head, 1.0f, stream, @@ -260,21 +273,22 @@ DecoderSelfAttentionLayer::DecoderSelfAttentionLayer(size_t max_batch_size, } template -DecoderSelfAttentionLayer::DecoderSelfAttentionLayer(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - const float q_scaling, - cudaStream_t stream, +DecoderSelfAttentionLayer::DecoderSelfAttentionLayer(size_t max_batch_size, + size_t head_num, + size_t size_per_head, + const float q_scaling, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse, - int int8_mode): + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse, + int int8_mode): DecoderSelfAttentionLayer(max_batch_size, head_num, size_per_head, head_num, 0, + false, head_num * size_per_head, q_scaling, stream, @@ -287,21 +301,22 @@ DecoderSelfAttentionLayer::DecoderSelfAttentionLayer(size_t max_batch_size, } template -DecoderSelfAttentionLayer::DecoderSelfAttentionLayer(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - size_t local_head_num, - cudaStream_t stream, +DecoderSelfAttentionLayer::DecoderSelfAttentionLayer(size_t max_batch_size, + size_t head_num, + size_t size_per_head, + size_t local_head_num, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse, - int int8_mode): + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse, + int int8_mode): DecoderSelfAttentionLayer(max_batch_size, head_num, size_per_head, local_head_num, 0, + false, head_num * size_per_head, 1.0f, stream, @@ -314,23 +329,24 @@ DecoderSelfAttentionLayer::DecoderSelfAttentionLayer(size_t max_batch_size, } template -DecoderSelfAttentionLayer::DecoderSelfAttentionLayer(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - size_t local_head_num, - size_t d_model, - const float q_scaling, - cudaStream_t stream, 
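// Illustrative CPU sketch (not from the patch) of what the new neox_rotary_style flag
// threaded through these constructors selects. The CUDA kernels are not part of this
// hunk; the point is only the pairing convention: rotary embeddings rotate channel
// pairs, and the two styles differ in how the pairs are formed -- adjacent pairs for
// the GPT-J style, first-half/second-half pairs for the GPT-NeoX style.
#include <cmath>
#include <vector>

void apply_rotary_reference(std::vector<float>& x, int t, int rotary_dim, bool neox_rotary_style)
{
    for (int i = 0; i < rotary_dim / 2; ++i) {
        // Pair selection differs between the two conventions.
        const int i0 = neox_rotary_style ? i : 2 * i;
        const int i1 = neox_rotary_style ? i + rotary_dim / 2 : 2 * i + 1;
        const float angle = t * std::pow(10000.f, -2.f * i / rotary_dim);
        const float c = std::cos(angle);
        const float s = std::sin(angle);
        const float a = x[i0];
        const float b = x[i1];
        x[i0] = a * c - b * s;
        x[i1] = a * s + b * c;
    }
}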
+DecoderSelfAttentionLayer::DecoderSelfAttentionLayer(size_t max_batch_size, + size_t head_num, + size_t size_per_head, + size_t local_head_num, + size_t d_model, + const float q_scaling, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse, - int int8_mode): + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse, + int int8_mode): DecoderSelfAttentionLayer(max_batch_size, head_num, size_per_head, local_head_num, 0, + false, d_model, q_scaling, stream, @@ -343,22 +359,24 @@ DecoderSelfAttentionLayer::DecoderSelfAttentionLayer(size_t max_batch_size, } template -DecoderSelfAttentionLayer::DecoderSelfAttentionLayer(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - size_t local_head_num, - size_t rotary_embedding_dim, - cudaStream_t stream, +DecoderSelfAttentionLayer::DecoderSelfAttentionLayer(size_t max_batch_size, + size_t head_num, + size_t size_per_head, + size_t local_head_num, + size_t rotary_embedding_dim, + bool neox_rotary_style, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse, - int int8_mode): + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse, + int int8_mode): DecoderSelfAttentionLayer(max_batch_size, head_num, size_per_head, local_head_num, rotary_embedding_dim, + neox_rotary_style, head_num * size_per_head, 1.0f, stream, @@ -377,6 +395,7 @@ DecoderSelfAttentionLayer::DecoderSelfAttentionLayer(DecoderSelfAttentionLaye attention_layer.size_per_head_, attention_layer.local_head_num_, attention_layer.rotary_embedding_dim_, + attention_layer.neox_rotary_style_, attention_layer.d_model_, attention_layer.q_scaling_, attention_layer.stream_, @@ -396,48 +415,53 @@ DecoderSelfAttentionLayer::~DecoderSelfAttentionLayer() } template -void DecoderSelfAttentionLayer::forward(std::vector* output_tensors, +void DecoderSelfAttentionLayer::forward(std::vector* output_tensors, const std::vector* input_tensors, - const AttentionWeight* attention_weights) + const AttentionWeight* attention_weights) { // input tensors: // attention_input [batch_size, d_model_], // finished [batch_size], // sequence_lengths [batch_size] - // input_lengths [batch_size] + // total_padding_tokens [batch_size] + // d_prefix_prompt_lengths [batch_size] on gpu + // max_prefix_prompt_length [1] on cpu // max_input_length [1] on cpu // step [1] on cpu - // cache_indirection [batch_size / beam_width, beam_width, max_seq_len] + // cache_indirection [batch_size / beam_width, beam_width, memory_max_len] + // masked_tokens [batch_size, memory_len] // relative_attention_bias [1, head_num, step, step] or [1, head_num, max_seq_len, max_seq_len] (option) // output tensors: // attention_output [batch_size, d_model_], - // key_cache [batch, local_head_num, size_per_head // x, max_seq_len, x] - // value_cache [batch, local_head_num, max_seq_len, size_per_head] + // key_cache [batch, local_head_num, size_per_head // x, memory_max_len, x] + // value_cache [batch, local_head_num, memory_max_len, size_per_head] FT_LOG_DEBUG(__PRETTY_FUNCTION__); - FT_CHECK(input_tensors->size() == 7 || input_tensors->size() == 8); + FT_CHECK(input_tensors->size() == 10 || input_tensors->size() == 11); FT_CHECK(output_tensors->size() == 3); FT_CHECK(output_tensors->at(1).shape.size() == 5 || output_tensors->at(1).shape.size() == 3); FT_CHECK(output_tensors->at(2).shape.size() == 4 || output_tensors->at(2).shape.size() == 3); - // 
FT_CHECK(isValidBatchSize(input_tensors->at(0).shape[0])); allocateBuffer(input_tensors->at(0).shape[0]); - const T* attention_input = reinterpret_cast(input_tensors->at(0).data); - const bool* finished = reinterpret_cast(input_tensors->at(1).data); - const int* sequence_lengths = reinterpret_cast(input_tensors->at(2).data); - const int* cache_indir = reinterpret_cast(input_tensors->at(6).data); - const T* relative_attention_bias = - reinterpret_cast(input_tensors->size() == 8 ? input_tensors->at(7).data : nullptr); - const int relative_attention_bias_stride = input_tensors->size() == 8 ? input_tensors->at(7).shape[3] : 0; + const T* attention_input = input_tensors->at(0).getPtr(); + const bool* finished = input_tensors->at(1).getPtr(); + const int* sequence_lengths = input_tensors->at(2).getPtr(); + const int* cache_indir = input_tensors->at(8).getPtr(); + const bool* masked_tokens = input_tensors->at(9).getPtr(); + const T* relative_attention_bias = input_tensors->size() == 11 ? input_tensors->at(10).getPtr() : nullptr; + const int relative_attention_bias_stride = input_tensors->size() == 11 ? input_tensors->at(10).shape[3] : 0; T* attention_out = (T*)(output_tensors->at(0).data); - T* key_cache = (T*)(output_tensors->at(1).data); - T* value_cache = (T*)(output_tensors->at(2).data); + T* key_cache = (T*)(output_tensors->at(1).data); + T* value_cache = (T*)(output_tensors->at(2).data); + + const int batch_size = input_tensors->at(0).shape[0]; + const int beam_width = input_tensors->at(8).shape[1]; + const int memory_max_len = output_tensors->at(1).shape[3]; - const int batch_size = input_tensors->at(0).shape[0]; - const int beam_width = input_tensors->at(6).shape[1]; - const int max_seq_len = output_tensors->at(1).shape[3]; + const int* d_prefix_prompt_lengths = input_tensors->at(4).getPtr(); + const int max_prefix_prompt_length = input_tensors->at(5).getVal(); #ifdef SPARSITY_ENABLED const int m_padded = 8 * div_up(batch_size, 8); @@ -486,28 +510,33 @@ void DecoderSelfAttentionLayer::forward(std::vector(qkv_buf_, - attention_weights->query_weight.bias, - relative_attention_bias, - key_cache, - value_cache, - cache_indir, - context_buf_, - finished, - sequence_lengths, - batch_size, - batch_size, - beam_width, - local_head_num_, - size_per_head_, - rotary_embedding_dim_, - max_seq_len, - *(int*)(input_tensors->at(4).data), - (int*)(input_tensors->at(3).data), - *(int*)(input_tensors->at(5).data), - q_scaling_, - relative_attention_bias_stride, - stream_); + fusedQKV_masked_attention_dispatch( + qkv_buf_, + attention_weights->query_weight.bias, + relative_attention_bias, + key_cache, + value_cache, + cache_indir, + context_buf_, + finished, + sequence_lengths, // NOTE: current seq len including padding (fixed after meeting the finished id) + batch_size, + batch_size, + beam_width, + local_head_num_, + size_per_head_, + rotary_embedding_dim_, + neox_rotary_style_, + memory_max_len, + d_prefix_prompt_lengths, + max_prefix_prompt_length, + input_tensors->at(6).getVal(), + input_tensors->at(3).getPtr(), + input_tensors->at(7).getVal(), + q_scaling_, + relative_attention_bias_stride, + masked_tokens, + stream_); sync_check_cuda_error(); #ifdef SPARSITY_ENABLED diff --git a/src/fastertransformer/layers/attention_layers/DecoderSelfAttentionLayer.h b/src/fastertransformer/layers/attention_layers/DecoderSelfAttentionLayer.h index ab3cda929..ceb688d67 100644 --- a/src/fastertransformer/layers/attention_layers/DecoderSelfAttentionLayer.h +++ 
b/src/fastertransformer/layers/attention_layers/DecoderSelfAttentionLayer.h @@ -34,8 +34,9 @@ class DecoderSelfAttentionLayer: public BaseAttentionLayer { const size_t local_head_num_; const size_t local_hidden_units_; const size_t d_model_; - const float q_scaling_; + const float q_scaling_; const size_t rotary_embedding_dim_; + const bool neox_rotary_style_; const int int8_mode_ = 0; @@ -50,146 +51,120 @@ class DecoderSelfAttentionLayer: public BaseAttentionLayer { using BaseAttentionLayer::allocator_; protected: - T* qkv_buf_ = nullptr; + T* qkv_buf_ = nullptr; T* context_buf_ = nullptr; using BaseAttentionLayer::stream_; using BaseAttentionLayer::sparse_; public: - DecoderSelfAttentionLayer(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - size_t local_head_num, - size_t rotary_embedding_dim, - size_t d_model, - const float q_scaling, - cudaStream_t stream, + DecoderSelfAttentionLayer(size_t max_batch_size, + size_t head_num, + size_t size_per_head, + size_t local_head_num, + size_t rotary_embedding_dim, + bool neox_rotary_style, + size_t d_model, + const float q_scaling, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse = false, - int int8_mode = 0); - - DecoderSelfAttentionLayer(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - cudaStream_t stream, + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse = false, + int int8_mode = 0); + + DecoderSelfAttentionLayer(size_t max_batch_size, + size_t head_num, + size_t size_per_head, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse = false, - int int8_mode = 0); - - DecoderSelfAttentionLayer(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - const float q_scaling, - cudaStream_t stream, + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse = false, + int int8_mode = 0); + + DecoderSelfAttentionLayer(size_t max_batch_size, + size_t head_num, + size_t size_per_head, + const float q_scaling, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse = false, - int int8_mode = 0); - - DecoderSelfAttentionLayer(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - size_t local_head_num, - cudaStream_t stream, + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse = false, + int int8_mode = 0); + + DecoderSelfAttentionLayer(size_t max_batch_size, + size_t head_num, + size_t size_per_head, + size_t local_head_num, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse = false, - int int8_mode = 0); - - DecoderSelfAttentionLayer(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - size_t local_head_num, - size_t d_model, - const float q_scaling, - cudaStream_t stream, + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse = false, + int int8_mode = 0); + + DecoderSelfAttentionLayer(size_t max_batch_size, + size_t head_num, + size_t size_per_head, + size_t local_head_num, + size_t d_model, + const float q_scaling, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse = false, - int int8_mode = 0); - - DecoderSelfAttentionLayer(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - size_t 
local_head_num, - size_t rotary_embedding_dim, - cudaStream_t stream, + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse = false, + int int8_mode = 0); + + DecoderSelfAttentionLayer(size_t max_batch_size, + size_t head_num, + size_t size_per_head, + size_t local_head_num, + size_t rotary_embedding_dim, + bool neox_rotary_style, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse = false, - int int8_mode = 0); + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse = false, + int int8_mode = 0); DecoderSelfAttentionLayer(DecoderSelfAttentionLayer const& attention_layer); ~DecoderSelfAttentionLayer(); - void forward(std::vector* output_tensors, + void forward(std::vector* output_tensors, const std::vector* input_tensors, - const AttentionWeight* attention_weights) override; + const AttentionWeight* attention_weights) override; }; template -void fusedQKV_masked_attention_dispatch(const T* qkv_buf, - const T* qkv_bias, - T* key_cache, - T* value_cache, - T* context_buf, - const bool* finished, - const int* sequence_lengths, - const int max_batch_size, - const int inference_batch_size, - const int head_num, - const int size_per_head, - const int rotary_embedding_dim, - const int max_seq_len, - const int max_input_len, - const int* input_lengths, - const int step, +void fusedQKV_masked_attention_dispatch(const T* qkv_buf, + const T* qkv_bias, + const T* relative_attention_bias, + T* key_cache, + T* value_cache, + const int* cache_indir, + T* context_buf, + const bool* finished, + const int* sequence_lengths, + const int max_batch_size, + const int inference_batch_size, + const int beam_width, + const int head_num, + const int size_per_head, + const int rotary_embedding_dim, + const bool neox_rotary_style, + const int memory_max_len, + const int* prefix_prompt_lengths, + const int max_prefix_prompt_length, + const int max_input_len, + const int* total_padding_tokens, + const int step, + const float q_scaling, + const int relative_attention_bias_stride, + const bool* masked_tokens, cudaStream_t stream); -template -void fusedQKV_masked_attention_dispatch(const T* qkv_buf, - const T* qkv_bias, - T* key_cache, - T* value_cache, - T* context_buf, - const bool* finished, - const int* sequence_lengths, - const int max_batch_size, - const int inference_batch_size, - const int head_num, - const int size_per_head, - const int max_seq_len, - const int max_input_len, - const int* input_lengths, - const int step, - cudaStream_t stream) -{ - fusedQKV_masked_attention_dispatch(qkv_buf, - qkv_bias, - key_cache, - value_cache, - context_buf, - finished, - sequence_lengths, - max_batch_size, - inference_batch_size, - head_num, - size_per_head, - /*rotary_embedding_dim */ 0, - max_seq_len, - max_input_len, - input_lengths, - step, - stream); -} - } // namespace fastertransformer diff --git a/src/fastertransformer/layers/attention_layers/FusedAttentionLayer.cu b/src/fastertransformer/layers/attention_layers/FusedAttentionLayer.cu index d72fcaee8..b4a36e96f 100644 --- a/src/fastertransformer/layers/attention_layers/FusedAttentionLayer.cu +++ b/src/fastertransformer/layers/attention_layers/FusedAttentionLayer.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,16 +18,16 @@ namespace fastertransformer { -__global__ void trt_add_QKV_bias(half2* qkv_buf, +__global__ void trt_add_QKV_bias(half2* qkv_buf, const half2* Q, const half2* bias_Q, const half2* K, const half2* bias_K, const half2* V, const half2* bias_V, - const int valid_word_num, - const int head_num, - const int size_per_head) + const int valid_word_num, + const int head_num, + const int size_per_head) { // Add bias, and then transpose from // [3, valid_word_num, head, size] -> [valid_word_num, head, 3, size] @@ -41,7 +41,7 @@ __global__ void trt_add_QKV_bias(half2* qkv_buf, const int head_id = (index - size_id) / size_per_head; const int target_offset = blockIdx.x * head_num * 3 * size_per_head + head_id * 3 * size_per_head; - const int src_id = seq_id * head_num * size_per_head + index; + const int src_id = seq_id * head_num * size_per_head + index; qkv_buf[target_offset + 0 * size_per_head + size_id] = Q[src_id] + bias_Q[index]; qkv_buf[target_offset + 1 * size_per_head + size_id] = K[src_id] + bias_K[index]; @@ -52,6 +52,7 @@ __global__ void trt_add_QKV_bias(half2* qkv_buf, template void FusedAttentionLayer::invokeTrtAddQkvBias(size_t token_num, const AttentionWeight* attention_weights) { + FT_LOG_DEBUG(__PRETTY_FUNCTION__); dim3 grid(token_num); dim3 block(min((int)(head_num_ * size_per_head_ / 2), 512)); @@ -68,25 +69,24 @@ void FusedAttentionLayer::invokeTrtAddQkvBias(size_t token_num, const Attenti } template -void FusedAttentionLayer::forward(std::vector* output_tensors, +void FusedAttentionLayer::forward(std::vector* output_tensors, const std::vector* input_tensors, - const AttentionWeight* attention_weights) + const AttentionWeight* attention_weights) { - // input_tensors: [input_query (token_num, hidden_dimension), + // input_tensors: [input_query (h_token_num, d_model), // attention_mask (batch, 1, seqlen, seqlen), // padding_offset (batch + 1 or batch * 2 + 1))] // If padding_offset.data is nullptr, then not remove padding - FT_CHECK(isValidBatchSize(input_tensors->at(1).shape[0])); - FT_CHECK(isValidSeqLen(input_tensors->at(1).shape[2])); + FT_LOG_DEBUG(__PRETTY_FUNCTION__); const int request_batch_size = input_tensors->at(1).shape[0]; - const int request_seq_len = input_tensors->at(1).shape[2]; + const int request_seq_len = input_tensors->at(1).shape[2]; allocateBuffer(request_batch_size, request_seq_len); - T* attention_out = (T*)output_tensors->at(0).data; - const T* from_tensor = (const T*)input_tensors->at(0).data; - const T* attention_mask = (const T*)input_tensors->at(1).data; + T* attention_out = (T*)output_tensors->at(0).data; + const T* from_tensor = (const T*)input_tensors->at(0).data; + const T* attention_mask = (const T*)input_tensors->at(1).data; const int* padding_offset = (const int*)input_tensors->at(2).data; size_t m_tmp = input_tensors->at(0).shape[0]; @@ -94,8 +94,8 @@ void FusedAttentionLayer::forward(std::vector* out m_tmp = (m_tmp / 8 + 1) * 8; } const size_t m = input_tensors->at(0).shape[0]; - const int k = hidden_units_; - const int n = hidden_units_; + int k = d_model_; + int n = hidden_units_; #ifdef SPARSITY_ENABLED const size_t m_padded = m_tmp; @@ -181,6 +181,8 @@ void FusedAttentionLayer::forward(std::vector* out dispatcher_fp16->run(qkv_buf_, nullptr, (int*)input_tensors->at(2).data, attn_workspace_, qkv_buf_2_, stream_); sync_check_cuda_error(); + k = hidden_units_; + n = d_model_; #ifdef 
SPARSITY_ENABLED if (sparse_ && cublas_wrapper_->isUseSparse(1, n, m, k)) { cublas_wrapper_->SpGemm(CUBLAS_OP_N, @@ -215,26 +217,27 @@ void FusedAttentionLayer::forward(std::vector* out } template -FusedAttentionLayer::FusedAttentionLayer(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - int sm, - float q_scaling, - cudaStream_t stream, +FusedAttentionLayer::FusedAttentionLayer(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t d_model, + int sm, + float q_scaling, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse): + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse): BaseAttentionLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward), - max_batch_size_(max_batch_size), - max_seq_len_(max_seq_len), head_num_(head_num), size_per_head_(size_per_head), + d_model_(d_model), sm_(sm), q_scaling_(q_scaling), sparse_(sparse) { + FT_LOG_DEBUG(__PRETTY_FUNCTION__); if ((sm_ == kSM_70 || sm_ == kSM_86 || sm_ == kSM_80 || sm_ == kSM_75 || sm_ == kSM_72) && size_per_head_ == 64) { dispatcher_fp16.reset(new FusedMHARunnerFP16v2(head_num_, size_per_head_, sm_, q_scaling_)); } @@ -246,30 +249,25 @@ FusedAttentionLayer::FusedAttentionLayer(size_t max_batch_size, template FusedAttentionLayer::FusedAttentionLayer(FusedAttentionLayer const& attention_layer): - BaseAttentionLayer(attention_layer.stream_, - attention_layer.cublas_wrapper_, - attention_layer.allocator_, - attention_layer.is_free_buffer_after_forward_), - max_batch_size_(attention_layer.max_batch_size_), - max_seq_len_(attention_layer.max_seq_len_), - head_num_(attention_layer.head_num_), - size_per_head_(attention_layer.size_per_head_), - hidden_units_(attention_layer.hidden_units_), - sm_(attention_layer.sm_), - q_scaling_(attention_layer.q_scaling_), - sparse_(attention_layer.sparse_) + FusedAttentionLayer(0, + 0, + attention_layer.head_num_, + attention_layer.size_per_head_, + attention_layer.d_model_, + attention_layer.sm_, + attention_layer.q_scaling_, + attention_layer.stream_, + attention_layer.cublas_wrapper_, + attention_layer.allocator_, + attention_layer.is_free_buffer_after_forward_, + attention_layer.sparse_) { - if ((sm_ == kSM_70 || sm_ == kSM_86 || sm_ == kSM_80 || sm_ == kSM_75 || sm_ == kSM_72) && size_per_head_ == 64) { - dispatcher_fp16.reset(new FusedMHARunnerFP16v2(head_num_, size_per_head_, sm_, q_scaling_)); - } - else { - throw std::runtime_error(std::string("[FT][ERROR] FusedAttentionLayer not support \n")); - } } template FusedAttentionLayer::~FusedAttentionLayer() { + FT_LOG_DEBUG(__PRETTY_FUNCTION__); cublas_wrapper_ = nullptr; freeBuffer(); } @@ -277,72 +275,53 @@ FusedAttentionLayer::~FusedAttentionLayer() template void FusedAttentionLayer::allocateBuffer() { - if (is_allocate_buffer_ == false) { - q_buf_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false); - k_buf_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false); - v_buf_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false); - qkv_buf_ = (T*)allocator_->malloc(sizeof(T) * 3 * max_batch_size_ * max_seq_len_ * hidden_units_, false); - qkv_buf_2_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false); - attn_workspace_ = (T*)allocator_->malloc(dispatcher_fp16->getWorkspaceSize(), false); - - 
batch_qkv_kernel_ptr_ = (T**)allocator_->malloc(sizeof(T*) * 12, false); - batch_qkv_input_ptr_ = batch_qkv_kernel_ptr_ + 4; - batch_qkv_buf_ptr_ = batch_qkv_input_ptr_ + 4; - is_allocate_buffer_ = true; - } + FT_CHECK(false); } template void FusedAttentionLayer::allocateBuffer(size_t batch_size, size_t seq_len) { - q_buf_ = (T*)allocator_->reMalloc(q_buf_, sizeof(T) * batch_size * seq_len * hidden_units_, false); - k_buf_ = (T*)allocator_->reMalloc(k_buf_, sizeof(T) * batch_size * seq_len * hidden_units_, false); - v_buf_ = (T*)allocator_->reMalloc(v_buf_, sizeof(T) * batch_size * seq_len * hidden_units_, false); - qkv_buf_ = (T*)allocator_->reMalloc(qkv_buf_, sizeof(T) * 3 * batch_size * seq_len * hidden_units_, false); - qkv_buf_2_ = (T*)allocator_->reMalloc(qkv_buf_2_, sizeof(T) * batch_size * seq_len * hidden_units_, false); + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + q_buf_ = (T*)allocator_->reMalloc(q_buf_, sizeof(T) * batch_size * seq_len * hidden_units_, false); + k_buf_ = (T*)allocator_->reMalloc(k_buf_, sizeof(T) * batch_size * seq_len * hidden_units_, false); + v_buf_ = (T*)allocator_->reMalloc(v_buf_, sizeof(T) * batch_size * seq_len * hidden_units_, false); + qkv_buf_ = (T*)allocator_->reMalloc(qkv_buf_, sizeof(T) * 3 * batch_size * seq_len * hidden_units_, false); + qkv_buf_2_ = (T*)allocator_->reMalloc(qkv_buf_2_, sizeof(T) * batch_size * seq_len * hidden_units_, false); attn_workspace_ = (T*)allocator_->reMalloc(attn_workspace_, dispatcher_fp16->getWorkspaceSize(), false); batch_qkv_kernel_ptr_ = (T**)allocator_->reMalloc(batch_qkv_kernel_ptr_, sizeof(T*) * 12, false); - batch_qkv_input_ptr_ = batch_qkv_kernel_ptr_ + 4; - batch_qkv_buf_ptr_ = batch_qkv_input_ptr_ + 4; - is_allocate_buffer_ = true; + batch_qkv_input_ptr_ = batch_qkv_kernel_ptr_ + 4; + batch_qkv_buf_ptr_ = batch_qkv_input_ptr_ + 4; + is_allocate_buffer_ = true; } template void FusedAttentionLayer::freeBuffer() { + FT_LOG_DEBUG(__PRETTY_FUNCTION__); if (is_allocate_buffer_) { - allocator_->free(q_buf_); - allocator_->free(k_buf_); - allocator_->free(v_buf_); - allocator_->free(qkv_buf_); - allocator_->free(qkv_buf_2_); - allocator_->free(attn_workspace_); - allocator_->free(batch_qkv_kernel_ptr_); + allocator_->free((void**)(&q_buf_)); + allocator_->free((void**)(&k_buf_)); + allocator_->free((void**)(&v_buf_)); + allocator_->free((void**)(&qkv_buf_)); + allocator_->free((void**)(&qkv_buf_2_)); + allocator_->free((void**)(&attn_workspace_)); + allocator_->free((void**)(&batch_qkv_kernel_ptr_)); sync_check_cuda_error(); is_allocate_buffer_ = false; } } template -bool FusedAttentionLayer::isValidBatchSize(size_t batch_size) +bool FusedAttentionLayer::isValidSeqLen(const size_t seq_len) { - if (max_batch_size_ < batch_size) { - max_batch_size_ = batch_size; - } - return true; -} - -template -bool FusedAttentionLayer::isValidSeqLen(size_t seq_len) -{ - if (max_seq_len_ < seq_len) { - max_seq_len_ = seq_len; - } return seq_len <= 384; } template class FusedAttentionLayer; template class FusedAttentionLayer; +#ifdef ENABLE_BF16 +template class FusedAttentionLayer<__nv_bfloat16>; +#endif } // namespace fastertransformer \ No newline at end of file diff --git a/src/fastertransformer/layers/attention_layers/FusedAttentionLayer.h b/src/fastertransformer/layers/attention_layers/FusedAttentionLayer.h index 685b021e1..ef7d12468 100644 --- a/src/fastertransformer/layers/attention_layers/FusedAttentionLayer.h +++ b/src/fastertransformer/layers/attention_layers/FusedAttentionLayer.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, 
NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -32,23 +32,18 @@ class FusedAttentionLayer: public BaseAttentionLayer { // metadata size_t head_num_; size_t size_per_head_; - bool sparse_; + size_t d_model_; + bool sparse_; // calculated params size_t hidden_units_; - // buffer handling - size_t max_batch_size_ = 0; - size_t max_seq_len_ = 0; - void allocateBuffer() override; void freeBuffer() override; - bool isValidBatchSize(size_t batch_size); - bool isValidSeqLen(size_t seq_len); void allocateBuffer(size_t batch_size, size_t seq_len); - int sm_; - float q_scaling_; + int sm_; + float q_scaling_; std::unique_ptr dispatcher_fp16; using BaseAttentionLayer::stream_; @@ -58,43 +53,45 @@ class FusedAttentionLayer: public BaseAttentionLayer { using BaseAttentionLayer::allocator_; protected: - T* q_buf_ = nullptr; - T* k_buf_ = nullptr; - T* v_buf_ = nullptr; - T* q_buf_2_ = nullptr; - T* k_buf_2_ = nullptr; - T* v_buf_2_ = nullptr; - T* qk_buf_ = nullptr; - T* qkv_buf_ = nullptr; - T* qkv_buf_2_ = nullptr; + T* q_buf_ = nullptr; + T* k_buf_ = nullptr; + T* v_buf_ = nullptr; + T* q_buf_2_ = nullptr; + T* k_buf_2_ = nullptr; + T* v_buf_2_ = nullptr; + T* qk_buf_ = nullptr; + T* qkv_buf_ = nullptr; + T* qkv_buf_2_ = nullptr; T* attn_workspace_ = nullptr; T** batch_qkv_kernel_ptr_ = nullptr; - T** batch_qkv_input_ptr_ = nullptr; - T** batch_qkv_buf_ptr_ = nullptr; + T** batch_qkv_input_ptr_ = nullptr; + T** batch_qkv_buf_ptr_ = nullptr; public: - FusedAttentionLayer(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - int sm, - float q_scaling, - cudaStream_t stream, + FusedAttentionLayer(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t d_model, + int sm, + float q_scaling, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse = false); + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse = false); FusedAttentionLayer(FusedAttentionLayer const& attention_layer); ~FusedAttentionLayer(); - void forward(std::vector* output_tensors, + void forward(std::vector* output_tensors, const std::vector* input_tensors, - const AttentionWeight* attention_weights) override; + const AttentionWeight* attention_weights) override; void invokeTrtAddQkvBias(size_t token_num, const AttentionWeight* attention_weights); + bool isValidSeqLen(const size_t seq_len) override; }; } // namespace fastertransformer diff --git a/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.cc index bada6408b..c628bd64c 100644 --- a/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.cc +++ b/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.cc @@ -21,35 +21,44 @@ namespace fastertransformer { template -void GptContextAttentionLayer::forward(std::vector* output_tensors, +void GptContextAttentionLayer::forward(std::vector* output_tensors, const std::vector* input_tensors, - const AttentionWeight* attention_weights) + const AttentionWeight* attention_weights) { // input_tensors: - // input_query [batch_size * seq_len, hidden_dimension] - // attention_mask [batch_size, 1, seq_len, seq_len] + // input_query [token_num, 
hidden_dimension] + // attention_mask [batch_size, 1, seq_len, seq_len + max_prompt_length] // is_final_layer [1], bool on cpu + // d_prefix_prompt_batch [global_batch_size], + // each element contains ptr with buffer shape[2, local_head_num_, prompt_length, size_per_head] + // d_prefix_prompt_lengths [batch_size], int + // layer_id [1], int on cpu + // padding_offset, int, [token_num] // output_tensors: - // attention_out [batch_size * seq_len, hidden_dimension] + // attention_out [token_num, hidden_dimension] // key_cache [batch, local_head_num, size_per_head // x, max_seq_len, x] // value_cache [batch, local_head_num, max_seq_len, size_per_head] - FT_CHECK(input_tensors->size() == 3); + FT_CHECK(input_tensors->size() >= 6); FT_CHECK(output_tensors->size() == 3); FT_CHECK(output_tensors->at(1).shape.size() == 5); FT_CHECK(output_tensors->at(2).shape.size() == 4 || output_tensors->at(2).shape.size() == 3); - // FT_CHECK(isValidBatchSize(input_tensors->at(1).shape[0])); - // FT_CHECK(isValidSeqLen(input_tensors->at(1).shape[2])); - const int request_batch_size = input_tensors->at(1).shape[0]; - const int request_seq_len = input_tensors->at(1).shape[2]; - allocateBuffer(request_batch_size, request_seq_len); + const int request_batch_size = input_tensors->at(1).shape[0]; + const int request_seq_len = input_tensors->at(1).shape[2]; + const int max_prompt_length = input_tensors->at(1).shape[3] - input_tensors->at(1).shape[2]; + const int layer_id = *(int*)input_tensors->at(5).data; + const T** d_prefix_prompt_batch = (const T**)input_tensors->at(3).data; + const int* d_prefix_prompt_lengths = (const int*)input_tensors->at(4).data; + const int* padding_offset = input_tensors->size() == 7 ? input_tensors->at(6).getPtr() : nullptr; + + allocateBuffer(request_batch_size, request_seq_len + max_prompt_length); sync_check_cuda_error(); - T* attention_out = (T*)output_tensors->at(0).data; - const T* attention_input = (const T*)input_tensors->at(0).data; - const T* attention_mask = (const T*)input_tensors->at(1).data; - const bool is_final = *((bool*)(input_tensors->at(2).data)); + T* attention_out = (T*)output_tensors->at(0).data; + const T* attention_input = (const T*)input_tensors->at(0).data; + const T* attention_mask = (const T*)input_tensors->at(1).data; + const bool is_final = *((bool*)(input_tensors->at(2).data)); const int m = input_tensors->at(0).shape[0]; @@ -81,67 +90,92 @@ void GptContextAttentionLayer::forward(std::vector #ifdef SPARSITY_ENABLED } #endif + sync_check_cuda_error(); + + // IDEA: append prefix prompt key value here + PrefixPromptBatchWeightsParam param{d_prefix_prompt_batch, + d_prefix_prompt_lengths, + max_prompt_length, + (size_t)layer_id * 2 * local_head_num_ * size_per_head_}; + + if (padding_offset != nullptr) { + // q_buf_2_, k_buf_2_ and v_buf_2_ are continuous + cudaMemsetAsync( + q_buf_2_, 0, request_batch_size * request_seq_len * 3 * local_hidden_units_ * sizeof(T), stream_); + } invokeAddFusedQKVBiasTranspose(q_buf_2_, k_buf_2_, v_buf_2_, + param, // prefix prompt qkv_buf_, attention_weights->query_weight.bias, + padding_offset, request_batch_size, request_seq_len, + m, local_head_num_, size_per_head_, rotary_embedding_dim_, + neox_rotary_style_, stream_); sync_check_cuda_error(); - const int max_seq_len = (int)(output_tensors->at(1).shape[3]); + const int max_seq_len = (int)(output_tensors->at(1).shape[3]); // max output seq length // Use batch major - // put k/v_buf from shape [B, H, L, Dh] - // to cache [B, H, Dh/x, L, x] and [B, H, L, Dh/x, x] + // put k/v_buf 
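Because the key/value side now includes the prefix prompt while the queries cover only the new tokens, the layer carries two lengths: the mask is [batch_size, 1, seq_len, seq_len + max_prompt_length] and the buffers are sized with request_seq_len + max_prompt_length. A minimal CPU sketch (illustrative only: one head, arbitrary values, no batching) of the rectangular score matrix this produces:

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Minimal CPU sketch (assumption: illustrative only, not the CUDA path) of why
    // the context attention tracks two lengths: queries cover the seq_len_1 new
    // tokens, while keys/values also include the prefix prompt
    // (seq_len_2 = prompt_len + seq_len_1), so the score matrix is
    // seq_len_1 x seq_len_2 rather than square.
    int main() {
        const int size_per_head = 2, seq_len_1 = 2, prompt_len = 3;
        const int seq_len_2 = prompt_len + seq_len_1;
        std::vector<float> q(seq_len_1 * size_per_head, 1.0f);   // new-token queries
        std::vector<float> k(seq_len_2 * size_per_head, 0.5f);   // prompt + new-token keys
        std::vector<float> scores(seq_len_1 * seq_len_2, 0.0f);
        const float scalar = 1.0f / std::sqrt((float)size_per_head);
        for (int i = 0; i < seq_len_1; ++i) {
            for (int j = 0; j < seq_len_2; ++j) {
                float dot = 0.0f;
                for (int d = 0; d < size_per_head; ++d) {
                    dot += q[i * size_per_head + d] * k[j * size_per_head + d];
                }
                scores[i * seq_len_2 + j] = scalar * dot;
            }
        }
        std::printf("score matrix is %d x %d, scores[0][0] = %.3f\n", seq_len_1, seq_len_2, scores[0]);
        return 0;
    }

The stridedBatchedGemm calls further down encode the same shapes through their m/n arguments and strides.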
from shape [B, H, PL + L, Dh] + // to cache [B, H, Dh/x, PL + L, x] and [B, H, PL + L, Dh/x, x], PL denotes prompt length invokeTranspose4dBatchMajor((T*)output_tensors->at(1).data, (T*)output_tensors->at(2).data, k_buf_2_, v_buf_2_, request_batch_size, - request_seq_len, + max_prompt_length + request_seq_len, // max input length + prefix prompt length max_seq_len, size_per_head_, local_head_num_, stream_); + // IDEA : after this, k_cache = (batch_size, num_heads, Dh/x, prefix_prompt_len + L, x) + // k_cache = (batch_size, num_heads, prefix_prompt_len + L, Dh) sync_check_cuda_error(); + // NOTE: qkv buffer shape (batch_size, num_heads,L or prompt_len + L, Dh) + if (is_final == false) { - const cudaDataType_t gemm_data_type = getCudaDataType(); + const cudaDataType_t gemm_data_type = getCudaDataType(); + const int attention_seq_len_1 = request_seq_len; // q length + const int attention_seq_len_2 = max_prompt_length + request_seq_len; // kv length if (is_qk_buf_float_ == true && gemm_data_type != CUDA_R_32F) { cublas_wrapper_->stridedBatchedGemm(CUBLAS_OP_T, CUBLAS_OP_N, - request_seq_len, - request_seq_len, - size_per_head_, + attention_seq_len_2, // n + attention_seq_len_1, // m + size_per_head_, // k 1.0f, k_buf_2_, gemm_data_type, - size_per_head_, - request_seq_len * size_per_head_, + size_per_head_, // k + attention_seq_len_2 * size_per_head_, // n * k q_buf_2_, gemm_data_type, - size_per_head_, - request_seq_len * size_per_head_, + size_per_head_, // k + attention_seq_len_1 * size_per_head_, // m * k 0.0f, qk_buf_float_, CUDA_R_32F, - request_seq_len, - request_seq_len * request_seq_len, - request_batch_size * local_head_num_, + attention_seq_len_2, // n + attention_seq_len_2 * attention_seq_len_1, + request_batch_size * local_head_num_, // global batch size CUDA_R_32F); + sync_check_cuda_error(); T scalar = 1 / sqrtf(size_per_head_ * 1.0f); invokeMaskedSoftMax(qk_buf_, qk_buf_float_, attention_mask, request_batch_size, - request_seq_len, + attention_seq_len_1, // seq_len_1 + attention_seq_len_2, // seq_len_2 local_head_num_, scalar, stream_); @@ -150,18 +184,18 @@ void GptContextAttentionLayer::forward(std::vector else { cublas_wrapper_->stridedBatchedGemm(CUBLAS_OP_T, CUBLAS_OP_N, - request_seq_len, - request_seq_len, + attention_seq_len_2, + attention_seq_len_1, size_per_head_, k_buf_2_, size_per_head_, - request_seq_len * size_per_head_, + attention_seq_len_2 * size_per_head_, q_buf_2_, size_per_head_, - request_seq_len * size_per_head_, + attention_seq_len_1 * size_per_head_, qk_buf_, - request_seq_len, - request_seq_len * request_seq_len, + attention_seq_len_2, + attention_seq_len_2 * attention_seq_len_1, request_batch_size * local_head_num_); T scalar = 1 / sqrtf(size_per_head_ * 1.0f); @@ -169,7 +203,8 @@ void GptContextAttentionLayer::forward(std::vector qk_buf_, attention_mask, request_batch_size, - request_seq_len, + attention_seq_len_1, + attention_seq_len_2, local_head_num_, scalar, stream_); @@ -179,21 +214,43 @@ void GptContextAttentionLayer::forward(std::vector cublas_wrapper_->stridedBatchedGemm(CUBLAS_OP_N, CUBLAS_OP_N, size_per_head_, - request_seq_len, - request_seq_len, + attention_seq_len_1, + attention_seq_len_2, v_buf_2_, size_per_head_, - request_seq_len * size_per_head_, + attention_seq_len_2 * size_per_head_, qk_buf_, - request_seq_len, - request_seq_len * request_seq_len, + attention_seq_len_2, + attention_seq_len_1 * attention_seq_len_2, qkv_buf_2_, size_per_head_, - request_seq_len * size_per_head_, + attention_seq_len_1 * size_per_head_, 
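invokeTranspose4dBatchMajor above rewrites the keys from [B, H, PL + L, Dh] into the batch-major cache [B, H, Dh/x, PL + L, x]. The sketch below is only an index-mapping illustration of that layout, assuming x is the element count of a 16-byte vector (8 for half precision) so the decoder can read the head dimension in vectorized chunks:

    #include <cstdio>

    // Index-mapping sketch (assumptions: illustrative only; x taken as 8, i.e. a
    // 16-byte vector of half-precision elements) for the batch-major K cache:
    // element [b][h][s][d] of the [B, H, S, Dh] key buffer lands at
    // [b][h][d / x][s][d % x] in the [B, H, Dh/x, S, x] cache, so x consecutive
    // head-dimension elements stay contiguous for vectorized reads while decoding.
    int main() {
        const int H = 2, S = 3, Dh = 16, x = 8;   // S = prompt length + input length
        const int b = 0, h = 1, s = 2, d = 3;
        long src_offset = (((long)b * H + h) * S + s) * Dh + d;
        long cache_offset = ((((long)b * H + h) * (Dh / x) + d / x) * S + s) * (long)x + d % x;
        std::printf("[b=%d h=%d s=%d d=%d]: source %ld -> cache %ld\n", b, h, s, d, src_offset, cache_offset);
        return 0;
    }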
request_batch_size * local_head_num_); + // transpose (batch_size, num_heads, L, Dh) to (batch_size, L, num_heads * Dh) invokeTransposeQKV( - qkv_buf_3_, qkv_buf_2_, request_batch_size, request_seq_len, local_head_num_, size_per_head_, stream_); + qkv_buf_3_, qkv_buf_2_, request_batch_size, attention_seq_len_1, local_head_num_, size_per_head_, stream_); + if (padding_offset == nullptr) { + invokeTransposeQKV(qkv_buf_3_, + qkv_buf_2_, + request_batch_size, + attention_seq_len_1, + local_head_num_, + size_per_head_, + stream_); + sync_check_cuda_error(); + } + else { + invokeTransposeAttentionOutRemovePadding(qkv_buf_2_, + qkv_buf_3_, + m, + request_batch_size, + attention_seq_len_1, + local_head_num_, + size_per_head_, + padding_offset, + stream_); + } sync_check_cuda_error(); #ifdef SPARSITY_ENABLED @@ -220,6 +277,7 @@ void GptContextAttentionLayer::forward(std::vector local_hidden_units_, attention_out, hidden_units_); + #ifdef SPARSITY_ENABLED } #endif @@ -232,16 +290,16 @@ void GptContextAttentionLayer::forward(std::vector } template -GptContextAttentionLayer::GptContextAttentionLayer(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - cudaStream_t stream, +GptContextAttentionLayer::GptContextAttentionLayer(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool is_qk_buf_float, - bool sparse): + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool is_qk_buf_float, + bool sparse): BaseAttentionLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, sparse), max_batch_size_(max_batch_size), max_seq_len_(max_seq_len), @@ -251,22 +309,23 @@ GptContextAttentionLayer::GptContextAttentionLayer(size_t max_batch_size, local_head_num_(head_num), local_hidden_units_(local_head_num_ * size_per_head), rotary_embedding_dim_(0), + neox_rotary_style_(false), is_qk_buf_float_(is_qk_buf_float) { } template -GptContextAttentionLayer::GptContextAttentionLayer(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t local_head_num, - cudaStream_t stream, +GptContextAttentionLayer::GptContextAttentionLayer(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t local_head_num, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool is_qk_buf_float, - bool sparse): + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool is_qk_buf_float, + bool sparse): BaseAttentionLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, sparse), max_batch_size_(max_batch_size), max_seq_len_(max_seq_len), @@ -276,23 +335,25 @@ GptContextAttentionLayer::GptContextAttentionLayer(size_t max_batch_size, local_head_num_(local_head_num), local_hidden_units_(local_head_num_ * size_per_head), rotary_embedding_dim_(0), + neox_rotary_style_(false), is_qk_buf_float_(is_qk_buf_float) { } template -GptContextAttentionLayer::GptContextAttentionLayer(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t local_head_num, - size_t rotary_embedding_dim, - cudaStream_t stream, +GptContextAttentionLayer::GptContextAttentionLayer(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t local_head_num, + size_t rotary_embedding_dim, + bool neox_rotary_style, + cudaStream_t 
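When padding_offset is non-null the incoming tokens are packed with padding removed, and invokeTransposeAttentionOutRemovePadding scatters the attention output back using that table. A simplified host-side sketch of the offsets, assuming the usual convention that padding_offset[i] is the padding accumulated before packed token i:

    #include <cstdio>
    #include <vector>

    // Sketch (assumption: simplified host code, not the CUDA kernel) of the
    // padding_offset bookkeeping: tokens are packed back-to-back across sequences
    // of different lengths, and padding_offset[i] is how far packed token i must
    // be shifted to land at its padded position in the [batch, seq_len] layout.
    int main() {
        const int batch_size = 2, max_seq_len = 4;
        const int lengths[] = {2, 3};                 // valid tokens per sequence
        std::vector<int> padding_offset;
        int cum_pad = 0;
        for (int b = 0; b < batch_size; ++b) {
            for (int t = 0; t < lengths[b]; ++t) {
                padding_offset.push_back(cum_pad);    // shift for each packed token
            }
            cum_pad += max_seq_len - lengths[b];      // padding skipped so far
        }
        for (size_t i = 0; i < padding_offset.size(); ++i) {
            int padded_index = (int)i + padding_offset[i];
            std::printf("packed %zu -> padded %d\n", i, padded_index);
        }
        return 0;
    }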
stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool is_qk_buf_float, - bool sparse): + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool is_qk_buf_float, + bool sparse): BaseAttentionLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, sparse), max_batch_size_(max_batch_size), max_seq_len_(max_seq_len), @@ -302,6 +363,7 @@ GptContextAttentionLayer::GptContextAttentionLayer(size_t max_batch_size, local_head_num_(local_head_num), local_hidden_units_(local_head_num_ * size_per_head), rotary_embedding_dim_(rotary_embedding_dim), + neox_rotary_style_(neox_rotary_style), is_qk_buf_float_(is_qk_buf_float) { } @@ -321,6 +383,7 @@ GptContextAttentionLayer::GptContextAttentionLayer(GptContextAttentionLayer void GptContextAttentionLayer::allocateBuffer() { FT_CHECK(false); - if (is_allocate_buffer_ == false) { - qkv_buf_ = (T*)allocator_->malloc(sizeof(T) * 3 * max_batch_size_ * max_seq_len_ * local_hidden_units_, true); - q_buf_2_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * local_hidden_units_, true); - k_buf_2_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * local_hidden_units_, true); - v_buf_2_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * local_hidden_units_, true); - - qk_buf_ = - (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * local_head_num_ * max_seq_len_ * max_seq_len_, true); - qkv_buf_2_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * local_hidden_units_, true); - qkv_buf_3_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * local_hidden_units_, true); - - if (is_qk_buf_float_ == true) { - qk_buf_float_ = (float*)allocator_->malloc( - sizeof(float) * max_batch_size_ * local_head_num_ * max_seq_len_ * max_seq_len_, true); - } - - is_allocate_buffer_ = true; - } } template @@ -361,11 +406,11 @@ void GptContextAttentionLayer::allocateBuffer(size_t batch_size, size_t seq_l { FT_LOG_DEBUG(__PRETTY_FUNCTION__); qkv_buf_ = (T*)allocator_->reMalloc(qkv_buf_, sizeof(T) * 3 * batch_size * seq_len * local_hidden_units_, true); - q_buf_2_ = (T*)allocator_->reMalloc(q_buf_2_, sizeof(T) * batch_size * seq_len * local_hidden_units_, true); - k_buf_2_ = (T*)allocator_->reMalloc(k_buf_2_, sizeof(T) * batch_size * seq_len * local_hidden_units_, true); - v_buf_2_ = (T*)allocator_->reMalloc(v_buf_2_, sizeof(T) * batch_size * seq_len * local_hidden_units_, true); + q_buf_2_ = (T*)allocator_->reMalloc(q_buf_2_, sizeof(T) * batch_size * seq_len * 3 * local_hidden_units_, true); + k_buf_2_ = q_buf_2_ + batch_size * seq_len * local_hidden_units_; + v_buf_2_ = k_buf_2_ + batch_size * seq_len * local_hidden_units_; - qk_buf_ = (T*)allocator_->reMalloc(qk_buf_, sizeof(T) * batch_size * local_head_num_ * seq_len * seq_len, true); + qk_buf_ = (T*)allocator_->reMalloc(qk_buf_, sizeof(T) * batch_size * local_head_num_ * seq_len * seq_len, true); qkv_buf_2_ = (T*)allocator_->reMalloc(qkv_buf_2_, sizeof(T) * batch_size * seq_len * local_hidden_units_, true); qkv_buf_3_ = (T*)allocator_->reMalloc(qkv_buf_3_, sizeof(T) * batch_size * seq_len * local_hidden_units_, true); @@ -381,47 +426,19 @@ void GptContextAttentionLayer::freeBuffer() { if (is_allocate_buffer_) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); - allocator_->free(qkv_buf_); - allocator_->free(q_buf_2_); - allocator_->free(k_buf_2_); - allocator_->free(v_buf_2_); - allocator_->free(qk_buf_); - allocator_->free(qkv_buf_2_); - allocator_->free(qkv_buf_3_); + 
allocator_->free((void**)(&qkv_buf_)); + allocator_->free((void**)(&q_buf_2_)); + allocator_->free((void**)(&qk_buf_)); + allocator_->free((void**)(&qkv_buf_2_)); + allocator_->free((void**)(&qkv_buf_3_)); if (is_qk_buf_float_ == true) { - allocator_->free(qk_buf_float_); + allocator_->free((void**)(&qk_buf_float_)); } is_allocate_buffer_ = false; } } -template -bool GptContextAttentionLayer::isValidBatchSize(size_t batch_size) -{ - if (batch_size <= max_batch_size_) { - return true; - } - else { - freeBuffer(); - max_batch_size_ = batch_size * 1.2; - return true; - } -} - -template -bool GptContextAttentionLayer::isValidSeqLen(size_t seq_len) -{ - if (seq_len <= max_seq_len_) { - return true; - } - else { - freeBuffer(); - max_seq_len_ = seq_len * 1.2; - return true; - } -} - template class GptContextAttentionLayer; template class GptContextAttentionLayer; #ifdef ENABLE_BF16 diff --git a/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.h b/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.h index 92e2175ef..e0ff3d30c 100644 --- a/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.h +++ b/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.h @@ -26,7 +26,7 @@ class GptContextAttentionLayer: public BaseAttentionLayer { private: // buffer handling size_t max_batch_size_ = 0; - size_t max_seq_len_ = 0; + size_t max_seq_len_ = 0; // metadata const size_t head_num_; @@ -35,12 +35,11 @@ class GptContextAttentionLayer: public BaseAttentionLayer { const size_t local_head_num_; const size_t local_hidden_units_; const size_t rotary_embedding_dim_; + const bool neox_rotary_style_; void allocateBuffer() override; void allocateBuffer(size_t batch_size, size_t seq_len); void freeBuffer() override; - bool isValidBatchSize(size_t batch_size); - bool isValidSeqLen(size_t seq_len); using BaseAttentionLayer::is_free_buffer_after_forward_; using BaseAttentionLayer::is_allocate_buffer_; @@ -52,59 +51,60 @@ class GptContextAttentionLayer: public BaseAttentionLayer { protected: using BaseAttentionLayer::stream_; using BaseAttentionLayer::sparse_; - T* qkv_buf_ = nullptr; - T* q_buf_2_ = nullptr; - T* k_buf_2_ = nullptr; - T* v_buf_2_ = nullptr; - T* qk_buf_ = nullptr; + T* qkv_buf_ = nullptr; + T* q_buf_2_ = nullptr; + T* k_buf_2_ = nullptr; + T* v_buf_2_ = nullptr; + T* qk_buf_ = nullptr; float* qk_buf_float_ = nullptr; - T* qkv_buf_2_ = nullptr; - T* qkv_buf_3_ = nullptr; + T* qkv_buf_2_ = nullptr; + T* qkv_buf_3_ = nullptr; public: - GptContextAttentionLayer(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - cudaStream_t stream, + GptContextAttentionLayer(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool is_qk_buf_float, - bool sparse = false); - - GptContextAttentionLayer(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t local_head_num, - cudaStream_t stream, + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool is_qk_buf_float, + bool sparse = false); + + GptContextAttentionLayer(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t local_head_num, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool is_qk_buf_float, - bool sparse = false); - - 
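The buffer-handling changes across these layers follow one pattern: allocateBuffer(batch_size, seq_len) sizes buffers for the actual request via reMalloc instead of pre-allocating max_batch_size_ * max_seq_len_, and freeBuffer passes the address of each pointer so the allocator can reset it, which removes the need for the old isValidBatchSize/isValidSeqLen growth checks. A toy allocator sketch of those two idioms (a simplified stand-in, not the FT IAllocator interface):

    #include <cstdio>
    #include <cstdlib>

    // Toy host-side allocator (assumption: simplified; the real allocator tracks
    // sizes internally and manages device memory): reMalloc keeps an existing
    // buffer when it is already large enough, and free(void**) nulls the caller's
    // pointer so a later reMalloc starts from a clean state.
    struct ToyAllocator {
        void* reMalloc(void* ptr, size_t size, size_t* tracked_size) {
            if (ptr != nullptr && *tracked_size >= size) {
                return ptr;                       // reuse the existing allocation
            }
            std::free(ptr);
            *tracked_size = size;
            return std::malloc(size);             // grow to the requested size
        }
        void free(void** ptr) {
            std::free(*ptr);
            *ptr = nullptr;                       // caller's pointer is reset
        }
    };

    int main() {
        ToyAllocator allocator;
        size_t capacity = 0;
        float* buf = nullptr;
        buf = (float*)allocator.reMalloc(buf, 8 * sizeof(float), &capacity);
        buf = (float*)allocator.reMalloc(buf, 4 * sizeof(float), &capacity);  // reused
        allocator.free((void**)(&buf));
        std::printf("buf after free: %p\n", (void*)buf);  // null pointer
        return 0;
    }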
GptContextAttentionLayer(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t local_head_num, - size_t rotary_embedding_dim, - cudaStream_t stream, + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool is_qk_buf_float, + bool sparse = false); + + GptContextAttentionLayer(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t local_head_num, + size_t rotary_embedding_dim, + bool neox_rotary_style_, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool is_qk_buf_float, - bool sparse = false); + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool is_qk_buf_float, + bool sparse = false); GptContextAttentionLayer(GptContextAttentionLayer const& attention_layer); virtual ~GptContextAttentionLayer(); - void forward(std::vector* output_tensors, + void forward(std::vector* output_tensors, const std::vector* input_tensors, - const AttentionWeight* attention_weights) override; + const AttentionWeight* attention_weights) override; }; } // namespace fastertransformer diff --git a/src/fastertransformer/layers/attention_layers/LongformerAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/LongformerAttentionLayer.cc index 22f9ba2e0..8a19da659 100644 --- a/src/fastertransformer/layers/attention_layers/LongformerAttentionLayer.cc +++ b/src/fastertransformer/layers/attention_layers/LongformerAttentionLayer.cc @@ -27,17 +27,17 @@ namespace fastertransformer { template -LongformerAttentionLayer::LongformerAttentionLayer(size_t head_num, - size_t size_per_head, - size_t local_attn_window_size, - size_t max_global_token_num, - size_t max_batch_size, - size_t max_seq_len, - float attn_scaler, - cudaStream_t stream, +LongformerAttentionLayer::LongformerAttentionLayer(size_t head_num, + size_t size_per_head, + size_t local_attn_window_size, + size_t max_global_token_num, + size_t max_batch_size, + size_t max_seq_len, + float attn_scaler, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward): + IAllocator* allocator, + bool is_free_buffer_after_forward): BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward), head_num_(head_num), size_per_head_(size_per_head), @@ -64,27 +64,27 @@ void LongformerAttentionLayer::forward(std::vector* output_tensors, c allocateBuffer(); const int batch_size = input_tensors->at(0).shape[0]; - const int seq_len = input_tensors->at(0).shape[2]; + const int seq_len = input_tensors->at(0).shape[2]; FT_CHECK(seq_len % local_attn_window_size_ == 0); FT_CHECK(size_per_head_ == 64); - const int batch_stride = head_num_ * seq_len * size_per_head_; - const int global_batch_stride = head_num_ * max_global_token_num_ * size_per_head_; - const int attn_head_stride = seq_len * size_per_head_; - const int attn_window_stride = local_attn_window_size_ * size_per_head_; + const int batch_stride = head_num_ * seq_len * size_per_head_; + const int global_batch_stride = head_num_ * max_global_token_num_ * size_per_head_; + const int attn_head_stride = seq_len * size_per_head_; + const int attn_window_stride = local_attn_window_size_ * size_per_head_; const int local_attn_head_tail_gemm_strides_count = batch_size * head_num_; - const int local_attn_middle_gemm_strides_count = (seq_len / local_attn_window_size_) - 2; - - const void* q = input_tensors->at(0).data; - const void* k = input_tensors->at(1).data; - const void* v = 
input_tensors->at(2).data; - const void* qg = input_tensors->at(3).data; - const void* kg = input_tensors->at(4).data; - const void* vg = input_tensors->at(5).data; - const void* local_attn_mask = (const T*)input_tensors->at(6).data; - const void* global_attn_mask = (const T*)input_tensors->at(7).data; - const int* global_idx = (const int*)input_tensors->at(8).data; - const int* global_token_nums = (const int*)input_tensors->at(9).data; + const int local_attn_middle_gemm_strides_count = (seq_len / local_attn_window_size_) - 2; + + const void* q = input_tensors->at(0).data; + const void* k = input_tensors->at(1).data; + const void* v = input_tensors->at(2).data; + const void* qg = input_tensors->at(3).data; + const void* kg = input_tensors->at(4).data; + const void* vg = input_tensors->at(5).data; + const void* local_attn_mask = (const T*)input_tensors->at(6).data; + const void* global_attn_mask = (const T*)input_tensors->at(7).data; + const int* global_idx = (const int*)input_tensors->at(8).data; + const int* global_token_nums = (const int*)input_tensors->at(9).data; void* output = (void*)output_tensors->at(0).data; @@ -101,8 +101,8 @@ void LongformerAttentionLayer::forward(std::vector* output_tensors, c 10 - 14: store buf_strides per head for each buffer. size: 5 * size_t */ size_t* internal_var[15]; - void** buf_ptrs = (void**)&internal_var[0]; - size_t* buf_sizes = (size_t*)(&internal_var[5]); + void** buf_ptrs = (void**)&internal_var[0]; + size_t* buf_sizes = (size_t*)(&internal_var[5]); size_t* buf_strides = (size_t*)(&internal_var[10]); buf_sizes[0] = (size_t)local_attn_window_size_ * 2 * local_attn_window_size_; @@ -162,7 +162,7 @@ void LongformerAttentionLayer::forward(std::vector* output_tensors, c (char*)q + (i * batch_stride + j * size_per_head_ * seq_len + local_attn_window_size_ * size_per_head_) * sizeof(T); - void* k_middle = (char*)k + (i * batch_stride + j * size_per_head_ * seq_len) * sizeof(T); + void* k_middle = (char*)k + (i * batch_stride + j * size_per_head_ * seq_len) * sizeof(T); void* qk_middle = (char*)buf_ptrs[1] + (i * head_num_ + j) * buf_sizes[1] * sizeof(T); cublas_wrapper_->stridedBatchedGemm(CUBLAS_OP_T, @@ -184,10 +184,10 @@ void LongformerAttentionLayer::forward(std::vector* output_tensors, c } // local attn per head - tail - int tail_blk_q = (seq_len / local_attn_window_size_) - 1; - int tail_blk_k = (seq_len / local_attn_window_size_) - 2; - void* q_tail = (char*)q + tail_blk_q * local_attn_window_size_ * size_per_head_ * sizeof(T); - void* k_tail = (char*)k + tail_blk_k * local_attn_window_size_ * size_per_head_ * sizeof(T); + int tail_blk_q = (seq_len / local_attn_window_size_) - 1; + int tail_blk_k = (seq_len / local_attn_window_size_) - 2; + void* q_tail = (char*)q + tail_blk_q * local_attn_window_size_ * size_per_head_ * sizeof(T); + void* k_tail = (char*)k + tail_blk_k * local_attn_window_size_ * size_per_head_ * sizeof(T); cublas_wrapper_->stridedBatchedGemm(CUBLAS_OP_T, CUBLAS_OP_N, @@ -214,8 +214,8 @@ void LongformerAttentionLayer::forward(std::vector* output_tensors, c } if (global_token_nums_cpu[i] > 0) { // local tokens attending global tokens - void* q_local = (char*)q + (i * batch_stride + local_attn_window_size_ * size_per_head_) * sizeof(T); - void* k_local = (char*)k + i * batch_stride * sizeof(T); + void* q_local = (char*)q + (i * batch_stride + local_attn_window_size_ * size_per_head_) * sizeof(T); + void* k_local = (char*)k + i * batch_stride * sizeof(T); void* qk_local = (char*)buf_ptrs[3] + i * buf_sizes[3] * head_num_ * 
sizeof(T); cublas_wrapper_->stridedBatchedGemm(CUBLAS_OP_T, @@ -235,8 +235,8 @@ void LongformerAttentionLayer::forward(std::vector* output_tensors, c head_num_); // global token attending everything - void* q_global = (char*)qg + (i * global_batch_stride) * sizeof(T); - void* k_global = (char*)kg + (i * batch_stride) * sizeof(T); + void* q_global = (char*)qg + (i * global_batch_stride) * sizeof(T); + void* k_global = (char*)kg + (i * batch_stride) * sizeof(T); void* qk_global = (char*)buf_ptrs[4] + (i * buf_sizes[4] * head_num_) * sizeof(T); cublas_wrapper_->stridedBatchedGemm(CUBLAS_OP_T, CUBLAS_OP_N, @@ -291,8 +291,8 @@ void LongformerAttentionLayer::forward(std::vector* output_tensors, c for (int i = 0; i < batch_size; ++i) { for (int j = 0; j < (int)head_num_; ++j) { void* v_local = (char*)v + (i * batch_stride + j * size_per_head_ * seq_len) * sizeof(T); - void* prob = (char*)buf_ptrs[1] + (i * head_num_ + j) * buf_sizes[1] * sizeof(T); - void* out = (char*)output + void* prob = (char*)buf_ptrs[1] + (i * head_num_ + j) * buf_sizes[1] * sizeof(T); + void* out = (char*)output + (i * batch_stride + j * size_per_head_ * seq_len + local_attn_window_size_ * size_per_head_) * sizeof(T); cublas_wrapper_->stridedBatchedGemm(CUBLAS_OP_N, @@ -314,10 +314,10 @@ void LongformerAttentionLayer::forward(std::vector* output_tensors, c } // local attn per head - tail - int tail_blk_v = (seq_len / local_attn_window_size_) - 2; - int tail_blk_out = (seq_len / local_attn_window_size_) - 1; - void* tail_v = (char*)v + tail_blk_v * local_attn_window_size_ * size_per_head_ * sizeof(T); - void* tail_out = (char*)output + tail_blk_out * local_attn_window_size_ * size_per_head_ * sizeof(T); + int tail_blk_v = (seq_len / local_attn_window_size_) - 2; + int tail_blk_out = (seq_len / local_attn_window_size_) - 1; + void* tail_v = (char*)v + tail_blk_v * local_attn_window_size_ * size_per_head_ * sizeof(T); + void* tail_out = (char*)output + tail_blk_out * local_attn_window_size_ * size_per_head_ * sizeof(T); cublas_wrapper_->stridedBatchedGemm(CUBLAS_OP_N, CUBLAS_OP_N, size_per_head_, @@ -340,7 +340,7 @@ void LongformerAttentionLayer::forward(std::vector* output_tensors, c int glob_longdim_mm = seq_len - 2 * local_attn_window_size_; void* v_local = (char*)v + (i * batch_stride) * sizeof(T); - void* prob = (char*)buf_ptrs[3] + void* prob = (char*)buf_ptrs[3] + (i * buf_sizes[3] * head_num_ + local_attn_window_size_ * buf_strides[3]) * sizeof(T); void* out = (char*)output + (i * batch_stride + 2 * local_attn_window_size_ * size_per_head_) * sizeof(T); @@ -364,8 +364,8 @@ void LongformerAttentionLayer::forward(std::vector* output_tensors, c // global tokens void* v_global = (char*)vg + (i * batch_stride) * sizeof(T); - prob = (char*)buf_ptrs[4] + (i * buf_sizes[4] * head_num_) * sizeof(T); - out = (char*)output + (i * batch_stride) * sizeof(T); + prob = (char*)buf_ptrs[4] + (i * buf_sizes[4] * head_num_) * sizeof(T); + out = (char*)output + (i * batch_stride) * sizeof(T); cublas_wrapper_->stridedBatchedGemm(CUBLAS_OP_N, CUBLAS_OP_N, @@ -396,17 +396,18 @@ void LongformerAttentionLayer::allocateBuffer() { if (!is_allocate_buffer_) { - internal_vars_device_ = (void*)allocator_->malloc(sizeof(size_t) * 15); + internal_vars_device_ = (void*)allocator_->reMalloc(internal_vars_device_, sizeof(size_t) * 15); attn_buffer_ = - (T*)allocator_->malloc(sizeof(T) * head_num_ * max_batch_size_ - * (2 * local_attn_window_size_ * local_attn_window_size_ - + 3 * local_attn_window_size_ * local_attn_window_size_ - * (max_seq_len_ / 
local_attn_window_size_ - 2) - + 2 * local_attn_window_size_ * local_attn_window_size_ - + local_attn_window_size_ * (max_seq_len_ - local_attn_window_size_) - + local_attn_window_size_ * max_seq_len_), - false); + (T*)allocator_->reMalloc(attn_buffer_, + sizeof(T) * head_num_ * max_batch_size_ + * (2 * local_attn_window_size_ * local_attn_window_size_ + + 3 * local_attn_window_size_ * local_attn_window_size_ + * (max_seq_len_ / local_attn_window_size_ - 2) + + 2 * local_attn_window_size_ * local_attn_window_size_ + + local_attn_window_size_ * (max_seq_len_ - local_attn_window_size_) + + local_attn_window_size_ * max_seq_len_), + false); is_allocate_buffer_ = true; } @@ -416,8 +417,8 @@ template void LongformerAttentionLayer::freeBuffer() { if (is_allocate_buffer_) { - allocator_->free(internal_vars_device_); - allocator_->free(attn_buffer_); + allocator_->free((void**)(&internal_vars_device_)); + allocator_->free((void**)(&attn_buffer_)); is_allocate_buffer_ = false; } @@ -425,5 +426,8 @@ void LongformerAttentionLayer::freeBuffer() template class LongformerAttentionLayer; template class LongformerAttentionLayer; +#ifdef ENABLE_BF16 +template class LongformerAttentionLayer<__nv_bfloat16>; +#endif } // namespace fastertransformer \ No newline at end of file diff --git a/src/fastertransformer/layers/attention_layers/LongformerAttentionLayer.h b/src/fastertransformer/layers/attention_layers/LongformerAttentionLayer.h index 5e5dc3e3a..f3f3397b7 100644 --- a/src/fastertransformer/layers/attention_layers/LongformerAttentionLayer.h +++ b/src/fastertransformer/layers/attention_layers/LongformerAttentionLayer.h @@ -30,11 +30,11 @@ class LongformerAttentionLayer: public BaseLayer { size_t max_global_token_num_; size_t max_batch_size_; size_t max_seq_len_; - float attn_scaler_; + float attn_scaler_; - // interal buffers + // internal buffers void* internal_vars_device_; - T* attn_buffer_; + T* attn_buffer_; cudaStream_t memcpy_stream_; @@ -42,19 +42,19 @@ class LongformerAttentionLayer: public BaseLayer { void freeBuffer() override; public: - LongformerAttentionLayer(size_t head_num, - size_t size_per_head, - size_t local_attn_window_size, - size_t max_global_token_num, - size_t max_batch_size, - size_t max_seq_len, - float attn_scaler, - cudaStream_t stream, + LongformerAttentionLayer(size_t head_num, + size_t size_per_head, + size_t local_attn_window_size, + size_t max_global_token_num, + size_t max_batch_size, + size_t max_seq_len, + float attn_scaler, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward = false); + IAllocator* allocator, + bool is_free_buffer_after_forward = false); ~LongformerAttentionLayer(); - void forward(std::vector* output_tensors, + void forward(std::vector* output_tensors, const std::vector* input_tensors); }; diff --git a/src/fastertransformer/layers/attention_layers/TensorParallelDecoderCrossAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/TensorParallelDecoderCrossAttentionLayer.cc index fc704570d..1365a1a34 100644 --- a/src/fastertransformer/layers/attention_layers/TensorParallelDecoderCrossAttentionLayer.cc +++ b/src/fastertransformer/layers/attention_layers/TensorParallelDecoderCrossAttentionLayer.cc @@ -20,18 +20,18 @@ namespace fastertransformer { template TensorParallelDecoderCrossAttentionLayer::TensorParallelDecoderCrossAttentionLayer( - size_t max_batch_size, - size_t head_num, - size_t size_per_head, - size_t d_model, - float q_scaling, - NcclParam tensor_para, - cudaStream_t 
stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, + size_t max_batch_size, + size_t head_num, + size_t size_per_head, + size_t d_model, + float q_scaling, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, std::shared_ptr custom_all_reduce_comm, - int enable_custom_all_reduce): + int enable_custom_all_reduce): DecoderCrossAttentionLayer(max_batch_size, head_num / tensor_para.world_size_, size_per_head, @@ -50,16 +50,16 @@ TensorParallelDecoderCrossAttentionLayer::TensorParallelDecoderCrossAttention template TensorParallelDecoderCrossAttentionLayer::TensorParallelDecoderCrossAttentionLayer( - size_t max_batch_size, - size_t head_num, - size_t size_per_head, - NcclParam tensor_para, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, + size_t max_batch_size, + size_t head_num, + size_t size_per_head, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, std::shared_ptr custom_all_reduce_comm, - int enable_custom_all_reduce): + int enable_custom_all_reduce): TensorParallelDecoderCrossAttentionLayer(max_batch_size, head_num, size_per_head, @@ -86,7 +86,7 @@ TensorParallelDecoderCrossAttentionLayer::TensorParallelDecoderCrossAttention } template -void TensorParallelDecoderCrossAttentionLayer::forward(std::vector* output_tensors, +void TensorParallelDecoderCrossAttentionLayer::forward(std::vector* output_tensors, const std::vector* input_tensors, const AttentionWeight* attention_weights) { @@ -103,7 +103,7 @@ void TensorParallelDecoderCrossAttentionLayer::forward(std::vectorat(0).shape[0]; + const size_t batch_size = output_tensors->at(0).shape[0]; const size_t hidden_units = output_tensors->at(0).shape[1]; bool use_custom_all_reduce_kernel = false; @@ -132,5 +132,8 @@ void TensorParallelDecoderCrossAttentionLayer::forward(std::vector; template class TensorParallelDecoderCrossAttentionLayer; +#ifdef ENABLE_BF16 +template class TensorParallelDecoderCrossAttentionLayer<__nv_bfloat16>; +#endif } // namespace fastertransformer \ No newline at end of file diff --git a/src/fastertransformer/layers/attention_layers/TensorParallelDecoderCrossAttentionLayer.h b/src/fastertransformer/layers/attention_layers/TensorParallelDecoderCrossAttentionLayer.h index e7e2d2b68..167e1c9d1 100644 --- a/src/fastertransformer/layers/attention_layers/TensorParallelDecoderCrossAttentionLayer.h +++ b/src/fastertransformer/layers/attention_layers/TensorParallelDecoderCrossAttentionLayer.h @@ -25,43 +25,43 @@ namespace fastertransformer { template class TensorParallelDecoderCrossAttentionLayer: public DecoderCrossAttentionLayer { private: - NcclParam tensor_para_; + NcclParam tensor_para_; std::shared_ptr custom_all_reduce_comm_; - int enable_custom_all_reduce_; + int enable_custom_all_reduce_; protected: public: - TensorParallelDecoderCrossAttentionLayer(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - NcclParam tensor_para, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - std::shared_ptr custom_all_reduce_comm_ = nullptr, - int enable_custom_all_reduce_ = 0); + TensorParallelDecoderCrossAttentionLayer(size_t max_batch_size, + size_t head_num, + size_t size_per_head, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* 
cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + std::shared_ptr custom_all_reduce_comm_ = nullptr, + int enable_custom_all_reduce_ = 0); - TensorParallelDecoderCrossAttentionLayer(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - size_t d_model, - float q_scaling, - NcclParam tensor_para, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - std::shared_ptr custom_all_reduce_comm_ = nullptr, - int enable_custom_all_reduce_ = 0); + TensorParallelDecoderCrossAttentionLayer(size_t max_batch_size, + size_t head_num, + size_t size_per_head, + size_t d_model, + float q_scaling, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + std::shared_ptr custom_all_reduce_comm_ = nullptr, + int enable_custom_all_reduce_ = 0); TensorParallelDecoderCrossAttentionLayer(TensorParallelDecoderCrossAttentionLayer const& attention_layer); ~TensorParallelDecoderCrossAttentionLayer() = default; - void forward(std::vector* output_tensors, + void forward(std::vector* output_tensors, const std::vector* input_tensors, - const AttentionWeight* attention_weights) override; + const AttentionWeight* attention_weights) override; }; } // namespace fastertransformer diff --git a/src/fastertransformer/layers/attention_layers/TensorParallelDecoderSelfAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/TensorParallelDecoderSelfAttentionLayer.cc index e48bc4462..25dd59ec7 100644 --- a/src/fastertransformer/layers/attention_layers/TensorParallelDecoderSelfAttentionLayer.cc +++ b/src/fastertransformer/layers/attention_layers/TensorParallelDecoderSelfAttentionLayer.cc @@ -20,26 +20,29 @@ namespace fastertransformer { template TensorParallelDecoderSelfAttentionLayer::TensorParallelDecoderSelfAttentionLayer( - size_t max_batch_size, - size_t head_num, - size_t size_per_head, - size_t rotary_embedding_dim, - size_t d_model, - float q_scaling, - NcclParam tensor_para, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool is_sparse, - int int8_mode, + size_t max_batch_size, + size_t head_num, + size_t size_per_head, + size_t rotary_embedding_dim, + bool neox_rotary_style, + size_t d_model, + float q_scaling, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool do_all_reduce, + bool is_free_buffer_after_forward, + bool is_sparse, + int int8_mode, std::shared_ptr custom_all_reduce_comm, - int enable_custom_all_reduce): + int enable_custom_all_reduce): DecoderSelfAttentionLayer(max_batch_size, head_num, size_per_head, head_num / tensor_para.world_size_, rotary_embedding_dim, + neox_rotary_style, d_model, q_scaling, // NOTE stream, @@ -48,6 +51,7 @@ TensorParallelDecoderSelfAttentionLayer::TensorParallelDecoderSelfAttentionLa is_free_buffer_after_forward, is_sparse, int8_mode), + do_all_reduce_(do_all_reduce), tensor_para_(tensor_para), custom_all_reduce_comm_(custom_all_reduce_comm), enable_custom_all_reduce_(enable_custom_all_reduce) @@ -57,28 +61,31 @@ TensorParallelDecoderSelfAttentionLayer::TensorParallelDecoderSelfAttentionLa template TensorParallelDecoderSelfAttentionLayer::TensorParallelDecoderSelfAttentionLayer( - size_t max_batch_size, - size_t head_num, - size_t size_per_head, - NcclParam tensor_para, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* 
allocator, - bool is_free_buffer_after_forward, - bool is_sparse, - int int8_mode, + size_t max_batch_size, + size_t head_num, + size_t size_per_head, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool do_all_reduce, + bool is_free_buffer_after_forward, + bool is_sparse, + int int8_mode, std::shared_ptr custom_all_reduce_comm, - int enable_custom_all_reduce): + int enable_custom_all_reduce): TensorParallelDecoderSelfAttentionLayer(max_batch_size, head_num, size_per_head, 0, + false, head_num * size_per_head, 1.0f, tensor_para, stream, cublas_wrapper, allocator, + do_all_reduce, is_free_buffer_after_forward, is_sparse, int8_mode, @@ -90,30 +97,33 @@ TensorParallelDecoderSelfAttentionLayer::TensorParallelDecoderSelfAttentionLa template TensorParallelDecoderSelfAttentionLayer::TensorParallelDecoderSelfAttentionLayer( - size_t max_batch_size, - size_t head_num, - size_t size_per_head, - size_t d_model, - float q_scaling, - NcclParam tensor_para, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool is_sparse, - int int8_mode, + size_t max_batch_size, + size_t head_num, + size_t size_per_head, + size_t d_model, + float q_scaling, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool do_all_reduce, + bool is_free_buffer_after_forward, + bool is_sparse, + int int8_mode, std::shared_ptr custom_all_reduce_comm, - int enable_custom_all_reduce): + int enable_custom_all_reduce): TensorParallelDecoderSelfAttentionLayer(max_batch_size, head_num, size_per_head, 0, + false, d_model, q_scaling, tensor_para, stream, cublas_wrapper, allocator, + do_all_reduce, is_free_buffer_after_forward, is_sparse, int8_mode, @@ -124,29 +134,33 @@ TensorParallelDecoderSelfAttentionLayer::TensorParallelDecoderSelfAttentionLa template TensorParallelDecoderSelfAttentionLayer::TensorParallelDecoderSelfAttentionLayer( - size_t max_batch_size, - size_t head_num, - size_t size_per_head, - size_t rotary_embedding_dim, - NcclParam tensor_para, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool is_sparse, - int int8_mode, + size_t max_batch_size, + size_t head_num, + size_t size_per_head, + size_t rotary_embedding_dim, + bool neox_rotary_style, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool do_all_reduce, + bool is_free_buffer_after_forward, + bool is_sparse, + int int8_mode, std::shared_ptr custom_all_reduce_comm, - int enable_custom_all_reduce): + int enable_custom_all_reduce): TensorParallelDecoderSelfAttentionLayer(max_batch_size, head_num, size_per_head, rotary_embedding_dim, + neox_rotary_style, head_num * size_per_head, 1.0f, tensor_para, stream, cublas_wrapper, allocator, + do_all_reduce, is_free_buffer_after_forward, is_sparse, int8_mode, @@ -159,6 +173,7 @@ template TensorParallelDecoderSelfAttentionLayer::TensorParallelDecoderSelfAttentionLayer( TensorParallelDecoderSelfAttentionLayer const& attention_layer): DecoderSelfAttentionLayer(attention_layer), + do_all_reduce_(attention_layer.do_all_reduce_), tensor_para_(attention_layer.tensor_para_), custom_all_reduce_comm_(attention_layer.custom_all_reduce_comm_), enable_custom_all_reduce_(attention_layer.enable_custom_all_reduce_) @@ -166,7 +181,7 @@ TensorParallelDecoderSelfAttentionLayer::TensorParallelDecoderSelfAttentionLa } template 
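The new do_all_reduce_ flag changes when the reduction happens, not what is computed: each tensor-parallel rank runs attention over head_num / world_size heads and projects back to the full hidden size, so the per-rank outputs are partial sums that still have to be summed across ranks (via ftNcclAllReduceSum or the custom all-reduce), unless the caller sets do_all_reduce to false and presumably performs that reduction itself later. A toy in-process sketch of the partial-sum view (an illustrative stand-in, not the NCCL path):

    #include <cstdio>
    #include <vector>

    // Toy model of tensor-parallel attention output reduction (assumption: an
    // illustrative stand-in, not the FasterTransformer/NCCL API). Each "rank"
    // owns head_num / world_size heads, so its output projection yields only a
    // partial sum of the full hidden state; summing the per-rank partials is what
    // the all-reduce does across GPUs.
    int main() {
        const int world_size = 2;
        const int hidden = 4;
        // Pretend per-rank partial outputs for one token (values are arbitrary).
        std::vector<std::vector<float>> partial = {{0.5f, 1.0f, -0.25f, 2.0f},
                                                   {1.5f, -1.0f, 0.75f, 0.0f}};
        std::vector<float> reduced(hidden, 0.0f);
        for (int r = 0; r < world_size; ++r) {       // the "all-reduce"
            for (int i = 0; i < hidden; ++i) {
                reduced[i] += partial[r][i];
            }
        }
        for (int i = 0; i < hidden; ++i) {
            std::printf("%s%.2f", i ? ", " : "", reduced[i]);
        }
        std::printf("\n");  // expected: 2.00, 0.00, 0.50, 2.00
        return 0;
    }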
-void TensorParallelDecoderSelfAttentionLayer::forward(std::vector* output_tensors, +void TensorParallelDecoderSelfAttentionLayer::forward(std::vector* output_tensors, const std::vector* input_tensors, const AttentionWeight* attention_weights) { @@ -183,11 +198,11 @@ void TensorParallelDecoderSelfAttentionLayer::forward(std::vectorat(0).shape[0]; + const size_t batch_size = output_tensors->at(0).shape[0]; const size_t hidden_units = output_tensors->at(0).shape[1]; bool use_custom_all_reduce_kernel = false; - if (enable_custom_all_reduce_ && custom_all_reduce_comm_ != nullptr) { + if (enable_custom_all_reduce_ && custom_all_reduce_comm_ != nullptr && do_all_reduce_) { use_custom_all_reduce_kernel = custom_all_reduce_comm_->swapInternalBuffer(output_tensors, batch_size * hidden_units); } @@ -195,7 +210,7 @@ void TensorParallelDecoderSelfAttentionLayer::forward(std::vector::forward(output_tensors, input_tensors, attention_weights); T* attention_out = (T*)(output_tensors->at(0).data); - if (tensor_para_.world_size_ > 1) { + if (tensor_para_.world_size_ > 1 && do_all_reduce_) { if (!use_custom_all_reduce_kernel) { ftNcclAllReduceSum(attention_out, attention_out, diff --git a/src/fastertransformer/layers/attention_layers/TensorParallelDecoderSelfAttentionLayer.h b/src/fastertransformer/layers/attention_layers/TensorParallelDecoderSelfAttentionLayer.h index 9bdc595a8..1befb6990 100644 --- a/src/fastertransformer/layers/attention_layers/TensorParallelDecoderSelfAttentionLayer.h +++ b/src/fastertransformer/layers/attention_layers/TensorParallelDecoderSelfAttentionLayer.h @@ -25,77 +25,84 @@ namespace fastertransformer { template class TensorParallelDecoderSelfAttentionLayer: public DecoderSelfAttentionLayer { private: - NcclParam tensor_para_; + NcclParam tensor_para_; std::shared_ptr custom_all_reduce_comm_; - int enable_custom_all_reduce_; + int enable_custom_all_reduce_; + bool do_all_reduce_; protected: public: - TensorParallelDecoderSelfAttentionLayer(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - size_t rotary_embedding_dim, - size_t d_model, - float q_scaling, - NcclParam tensor_para, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool is_sparse = false, - int int8_mode = 0, - std::shared_ptr custom_all_reduce_comm = nullptr, - int enable_custom_all_reduce = 0); + TensorParallelDecoderSelfAttentionLayer(size_t max_batch_size, + size_t head_num, + size_t size_per_head, + size_t rotary_embedding_dim, + bool neox_rotary_style, + size_t d_model, + float q_scaling, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool do_all_reduce, + bool is_free_buffer_after_forward, + bool is_sparse = false, + int int8_mode = 0, + std::shared_ptr custom_all_reduce_comm = nullptr, + int enable_custom_all_reduce = 0); - TensorParallelDecoderSelfAttentionLayer(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - NcclParam tensor_para, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool is_sparse = false, - int int8_mode = 0, - std::shared_ptr custom_all_reduce_comm = nullptr, - int enable_custom_all_reduce = 0); + TensorParallelDecoderSelfAttentionLayer(size_t max_batch_size, + size_t head_num, + size_t size_per_head, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool do_all_reduce, + bool 
is_free_buffer_after_forward, + bool is_sparse = false, + int int8_mode = 0, + std::shared_ptr custom_all_reduce_comm = nullptr, + int enable_custom_all_reduce = 0); - TensorParallelDecoderSelfAttentionLayer(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - size_t d_model, - float q_scaling, - NcclParam tensor_para, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse = false, - int int8_mode = 0, - std::shared_ptr custom_all_reduce_comm = nullptr, - int enable_custom_all_reduce = 0); + TensorParallelDecoderSelfAttentionLayer(size_t max_batch_size, + size_t head_num, + size_t size_per_head, + size_t d_model, + float q_scaling, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool do_all_reduce, + bool is_free_buffer_after_forward, + bool sparse = false, + int int8_mode = 0, + std::shared_ptr custom_all_reduce_comm = nullptr, + int enable_custom_all_reduce = 0); - TensorParallelDecoderSelfAttentionLayer(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - size_t rotary_embedding_dim, - NcclParam tensor_para, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse = false, - int int8_mode = 0, - std::shared_ptr custom_all_reduce_comm = nullptr, - int enable_custom_all_reduce = 0); + TensorParallelDecoderSelfAttentionLayer(size_t max_batch_size, + size_t head_num, + size_t size_per_head, + size_t rotary_embedding_dim, + bool neox_rotary_style, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool do_all_reduce, + bool is_free_buffer_after_forward, + bool sparse = false, + int int8_mode = 0, + std::shared_ptr custom_all_reduce_comm = nullptr, + int enable_custom_all_reduce = 0); TensorParallelDecoderSelfAttentionLayer(TensorParallelDecoderSelfAttentionLayer const& attention_layer); ~TensorParallelDecoderSelfAttentionLayer() = default; - void forward(std::vector* output_tensors, + void forward(std::vector* output_tensors, const std::vector* input_tensors, - const AttentionWeight* attention_weights) override; + const AttentionWeight* attention_weights) override; }; } // namespace fastertransformer diff --git a/src/fastertransformer/layers/attention_layers/TensorParallelGptContextAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/TensorParallelGptContextAttentionLayer.cc index 89507d0d2..f12b1aa5d 100644 --- a/src/fastertransformer/layers/attention_layers/TensorParallelGptContextAttentionLayer.cc +++ b/src/fastertransformer/layers/attention_layers/TensorParallelGptContextAttentionLayer.cc @@ -19,9 +19,9 @@ namespace fastertransformer { template -void TensorParallelGptContextAttentionLayer::forward(std::vector* output_tensors, +void TensorParallelGptContextAttentionLayer::forward(std::vector* output_tensors, const std::vector* input_tensors, - const AttentionWeight* attention_weights) + const AttentionWeight* attention_weights) { // input_tensors: // input_query [batch_size * seq_len, hidden_dimension] @@ -33,18 +33,18 @@ void TensorParallelGptContextAttentionLayer::forward(std::vectorat(0).shape[0]; + const size_t m = output_tensors->at(0).shape[0]; const size_t hidden_units = output_tensors->at(0).shape[1]; bool use_custom_all_reduce_kernel = false; - if (enable_custom_all_reduce_ && custom_all_reduce_comm_ != nullptr) { + if (do_all_reduce_ && enable_custom_all_reduce_ && 
custom_all_reduce_comm_ != nullptr) { use_custom_all_reduce_kernel = custom_all_reduce_comm_->swapInternalBuffer(output_tensors, m * hidden_units); } GptContextAttentionLayer::forward(output_tensors, input_tensors, attention_weights); T* attention_out = (T*)(output_tensors->at(0).data); - if (tensor_para_.world_size_ > 1) { + if (do_all_reduce_ && tensor_para_.world_size_ > 1) { if (!use_custom_all_reduce_kernel) { ftNcclAllReduceSum( attention_out, attention_out, m * hidden_units, tensor_para_, GptContextAttentionLayer::stream_); @@ -58,19 +58,20 @@ void TensorParallelGptContextAttentionLayer::forward(std::vector TensorParallelGptContextAttentionLayer::TensorParallelGptContextAttentionLayer( - size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - NcclParam tensor_para, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool is_qk_buf_float, - bool sparse, + size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool do_all_reduce, + bool is_free_buffer_after_forward, + bool is_qk_buf_float, + bool sparse, std::shared_ptr custom_all_reduce_comm, - int enable_custom_all_reduce): + int enable_custom_all_reduce): GptContextAttentionLayer(max_batch_size, max_seq_len, head_num, @@ -84,33 +85,37 @@ TensorParallelGptContextAttentionLayer::TensorParallelGptContextAttentionLaye sparse), tensor_para_(tensor_para), custom_all_reduce_comm_(custom_all_reduce_comm), - enable_custom_all_reduce_(enable_custom_all_reduce) + enable_custom_all_reduce_(enable_custom_all_reduce), + do_all_reduce_(do_all_reduce) { FT_CHECK(head_num % tensor_para_.world_size_ == 0); } template TensorParallelGptContextAttentionLayer::TensorParallelGptContextAttentionLayer( - size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t rotary_embedding_dim, - NcclParam tensor_para, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool is_qk_buf_float, - bool sparse, + size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t rotary_embedding_dim, + bool neox_rotary_style, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool do_all_reduce, + bool is_free_buffer_after_forward, + bool is_qk_buf_float, + bool sparse, std::shared_ptr custom_all_reduce_comm, - int enable_custom_all_reduce): + int enable_custom_all_reduce): GptContextAttentionLayer(max_batch_size, max_seq_len, head_num, size_per_head, head_num / tensor_para.world_size_, rotary_embedding_dim, + neox_rotary_style, stream, cublas_wrapper, allocator, @@ -119,8 +124,10 @@ TensorParallelGptContextAttentionLayer::TensorParallelGptContextAttentionLaye sparse), tensor_para_(tensor_para), custom_all_reduce_comm_(custom_all_reduce_comm), - enable_custom_all_reduce_(enable_custom_all_reduce) + enable_custom_all_reduce_(enable_custom_all_reduce), + do_all_reduce_(do_all_reduce) { + FT_CHECK(head_num % tensor_para_.world_size_ == 0); } template @@ -129,7 +136,8 @@ TensorParallelGptContextAttentionLayer::TensorParallelGptContextAttentionLaye GptContextAttentionLayer(attention_layer), tensor_para_(attention_layer.tensor_para_), custom_all_reduce_comm_(attention_layer.custom_all_reduce_comm_), - 
enable_custom_all_reduce_(attention_layer.enable_custom_all_reduce_) + enable_custom_all_reduce_(attention_layer.enable_custom_all_reduce_), + do_all_reduce_(attention_layer.do_all_reduce_) { } diff --git a/src/fastertransformer/layers/attention_layers/TensorParallelGptContextAttentionLayer.h b/src/fastertransformer/layers/attention_layers/TensorParallelGptContextAttentionLayer.h index c71f7cbcc..271aa16ae 100644 --- a/src/fastertransformer/layers/attention_layers/TensorParallelGptContextAttentionLayer.h +++ b/src/fastertransformer/layers/attention_layers/TensorParallelGptContextAttentionLayer.h @@ -25,47 +25,51 @@ namespace fastertransformer { template class TensorParallelGptContextAttentionLayer: public GptContextAttentionLayer { private: - NcclParam tensor_para_; + NcclParam tensor_para_; std::shared_ptr custom_all_reduce_comm_; - int enable_custom_all_reduce_; + int enable_custom_all_reduce_; + bool do_all_reduce_; public: - TensorParallelGptContextAttentionLayer(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - NcclParam tensor_para, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool is_qk_buf_float, - bool sparse = false, - std::shared_ptr custom_all_reduce_comm = nullptr, - int enable_custom_all_reduce = 0); + TensorParallelGptContextAttentionLayer(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool do_all_reduce, + bool is_free_buffer_after_forward, + bool is_qk_buf_float, + bool sparse = false, + std::shared_ptr custom_all_reduce_comm = nullptr, + int enable_custom_all_reduce = 0); - TensorParallelGptContextAttentionLayer(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t rotary_embedding_dim, - NcclParam tensor_para, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool is_qk_buf_float, - bool sparse = false, - std::shared_ptr custom_all_reduce_comm = nullptr, - int enable_custom_all_reduce = 0); + TensorParallelGptContextAttentionLayer(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t rotary_embedding_dim, + bool neox_rotary_style, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool do_all_reduce, + bool is_free_buffer_after_forward, + bool is_qk_buf_float, + bool sparse = false, + std::shared_ptr custom_all_reduce_comm = nullptr, + int enable_custom_all_reduce = 0); TensorParallelGptContextAttentionLayer(TensorParallelGptContextAttentionLayer const& attention_layer); ~TensorParallelGptContextAttentionLayer() = default; - void forward(std::vector* output_tensors, + void forward(std::vector* output_tensors, const std::vector* input_tensors, - const AttentionWeight* attention_weights) override; + const AttentionWeight* attention_weights) override; }; } // namespace fastertransformer diff --git a/src/fastertransformer/layers/attention_layers/TensorParallelUnfusedAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/TensorParallelUnfusedAttentionLayer.cc index 9e0278c3c..6ed606b85 100644 --- a/src/fastertransformer/layers/attention_layers/TensorParallelUnfusedAttentionLayer.cc +++ b/src/fastertransformer/layers/attention_layers/TensorParallelUnfusedAttentionLayer.cc @@ -19,9 +19,9 @@ 
namespace fastertransformer { template -void TensorParallelUnfusedAttentionLayer::forward(std::vector* output_tensors, +void TensorParallelUnfusedAttentionLayer::forward(std::vector* output_tensors, const std::vector* input_tensors, - const AttentionWeight* attention_weights) + const AttentionWeight* attention_weights) { // input_tensors: // input_query [token_num, d_model], @@ -56,20 +56,20 @@ void TensorParallelUnfusedAttentionLayer::forward(std::vector TensorParallelUnfusedAttentionLayer::TensorParallelUnfusedAttentionLayer( - size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t d_model, - float q_scaling, - NcclParam tensor_para, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool is_sparse, + size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t d_model, + float q_scaling, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool is_sparse, std::shared_ptr custom_all_reduce_comm, - int enable_custom_all_reduce): + int enable_custom_all_reduce): UnfusedAttentionLayer(max_batch_size, max_seq_len, head_num / tensor_para.world_size_, @@ -97,5 +97,8 @@ TensorParallelUnfusedAttentionLayer::TensorParallelUnfusedAttentionLayer( template class TensorParallelUnfusedAttentionLayer; template class TensorParallelUnfusedAttentionLayer; +#ifdef ENABLE_BF16 +template class TensorParallelUnfusedAttentionLayer<__nv_bfloat16>; +#endif } // namespace fastertransformer \ No newline at end of file diff --git a/src/fastertransformer/layers/attention_layers/TensorParallelUnfusedAttentionLayer.h b/src/fastertransformer/layers/attention_layers/TensorParallelUnfusedAttentionLayer.h index b0b6ee53d..d5040564e 100644 --- a/src/fastertransformer/layers/attention_layers/TensorParallelUnfusedAttentionLayer.h +++ b/src/fastertransformer/layers/attention_layers/TensorParallelUnfusedAttentionLayer.h @@ -25,33 +25,33 @@ namespace fastertransformer { template class TensorParallelUnfusedAttentionLayer: public UnfusedAttentionLayer { private: - NcclParam tensor_para_; + NcclParam tensor_para_; std::shared_ptr custom_all_reduce_comm_; - int enable_custom_all_reduce_; + int enable_custom_all_reduce_; public: - TensorParallelUnfusedAttentionLayer(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t d_model, - float q_scaling, - NcclParam tensor_para, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool is_sparse, - std::shared_ptr custom_all_reduce_comm = nullptr, - int enable_custom_all_reduce = 0); + TensorParallelUnfusedAttentionLayer(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t d_model, + float q_scaling, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool is_sparse, + std::shared_ptr custom_all_reduce_comm = nullptr, + int enable_custom_all_reduce = 0); TensorParallelUnfusedAttentionLayer(TensorParallelUnfusedAttentionLayer const& attention_layer); ~TensorParallelUnfusedAttentionLayer() = default; - void forward(std::vector* output_tensors, + void forward(std::vector* output_tensors, const std::vector* input_tensors, - const AttentionWeight* attention_weights) override; + const AttentionWeight* 
attention_weights) override; }; } // namespace fastertransformer diff --git a/src/fastertransformer/layers/attention_layers/UnfusedAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/UnfusedAttentionLayer.cc index 1edd52f87..6dafdab9d 100644 --- a/src/fastertransformer/layers/attention_layers/UnfusedAttentionLayer.cc +++ b/src/fastertransformer/layers/attention_layers/UnfusedAttentionLayer.cc @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,9 +20,9 @@ namespace fastertransformer { template -void UnfusedAttentionLayer::forward(std::vector* output_tensors, +void UnfusedAttentionLayer::forward(std::vector* output_tensors, const std::vector* input_tensors, - const AttentionWeight* attention_weights) + const AttentionWeight* attention_weights) { // input_tensors: // input_query (token_num, d_model), @@ -32,26 +32,23 @@ void UnfusedAttentionLayer::forward(std::vector* o // If padding_offset.data is nullptr, then not remove padding FT_LOG_DEBUG(__PRETTY_FUNCTION__); - FT_CHECK(isValidBatchSize(input_tensors->at(1).shape[0])); - FT_CHECK(isValidSeqLen(input_tensors->at(1).shape[2])); FT_CHECK(input_tensors->size() == 3 || input_tensors->size() == 4); - // allocateBuffer(); const int request_batch_size = input_tensors->at(1).shape[0]; - const int request_seq_len = input_tensors->at(1).shape[2]; + const int request_seq_len = input_tensors->at(1).shape[2]; allocateBuffer(request_batch_size, request_seq_len); - T* attention_out = (T*)output_tensors->at(0).data; - const T* from_tensor = (const T*)input_tensors->at(0).data; - const T* attention_mask = (const T*)input_tensors->at(1).data; - const int* padding_offset = (const int*)input_tensors->at(2).data; - const T* relative_attention_bias = input_tensors->size() == 4 ? (const T*)input_tensors->at(3).data : nullptr; + T* attention_out = (T*)output_tensors->at(0).data; + const T* from_tensor = (const T*)input_tensors->at(0).data; + const T* attention_mask = (const T*)input_tensors->at(1).data; + const int* padding_offset = (const int*)input_tensors->at(2).data; + const T* relative_attention_bias = input_tensors->size() == 4 ? (const T*)input_tensors->at(3).data : nullptr; - bool with_bias = attention_weights->query_weight.bias != nullptr ? true : false; + bool with_bias = attention_weights->query_weight.bias != nullptr ? true : false; bool use_relative_position_bias = relative_attention_bias != nullptr ? 
true : false; const int m = input_tensors->at(0).shape[0]; - int k = d_model_; - int n = hidden_units_; + int k = d_model_; + int n = hidden_units_; #ifdef SPARSITY_ENABLED int m_tmp = m; if (m_tmp % 8 != 0) { @@ -195,8 +192,15 @@ void UnfusedAttentionLayer::forward(std::vector* o qk_buf_, relative_attention_bias, request_batch_size, head_num_, request_seq_len, stream_); } - invokeMaskedSoftMax( - qk_buf_, qk_buf_, attention_mask, request_batch_size, request_seq_len, head_num_, (T)1.0f, stream_); + invokeMaskedSoftMax(qk_buf_, + qk_buf_, + attention_mask, + request_batch_size, + request_seq_len, + request_seq_len, + head_num_, + (T)1.0f, + stream_); sync_check_cuda_error(); cublas_wrapper_->stridedBatchedGemm(CUBLAS_OP_N, @@ -270,16 +274,16 @@ void UnfusedAttentionLayer::forward(std::vector* o } template -UnfusedAttentionLayer::UnfusedAttentionLayer(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - float q_scaling, - cudaStream_t stream, +UnfusedAttentionLayer::UnfusedAttentionLayer(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + float q_scaling, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse): + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse): UnfusedAttentionLayer(max_batch_size, max_seq_len, head_num, @@ -295,20 +299,18 @@ UnfusedAttentionLayer::UnfusedAttentionLayer(size_t max_batch_size, } template -UnfusedAttentionLayer::UnfusedAttentionLayer(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t d_model, - float q_scaling, - cudaStream_t stream, +UnfusedAttentionLayer::UnfusedAttentionLayer(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t d_model, + float q_scaling, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse): + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse): BaseAttentionLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward), - max_batch_size_(max_batch_size), - max_seq_len_(max_seq_len), head_num_(head_num), size_per_head_(size_per_head), d_model_(d_model), @@ -324,8 +326,6 @@ UnfusedAttentionLayer::UnfusedAttentionLayer(UnfusedAttentionLayer const& attention_layer.cublas_wrapper_, attention_layer.allocator_, attention_layer.is_free_buffer_after_forward_), - max_batch_size_(attention_layer.max_batch_size_), - max_seq_len_(attention_layer.max_seq_len_), head_num_(attention_layer.head_num_), size_per_head_(attention_layer.size_per_head_), d_model_(attention_layer.d_model_), @@ -346,39 +346,26 @@ UnfusedAttentionLayer::~UnfusedAttentionLayer() template void UnfusedAttentionLayer::allocateBuffer() { - if (is_allocate_buffer_ == false) { - q_buf_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false); - k_buf_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false); - v_buf_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false); - q_buf_2_ = (T*)allocator_->malloc(sizeof(T) * 3 * max_batch_size_ * max_seq_len_ * hidden_units_, false); - k_buf_2_ = q_buf_2_ + max_batch_size_ * max_seq_len_ * hidden_units_; - v_buf_2_ = k_buf_2_ + max_batch_size_ * max_seq_len_ * hidden_units_; - qk_buf_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * head_num_ * max_seq_len_ * 
max_seq_len_, false); - qkv_buf_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false); - qkv_buf_2_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false); - batch_qkv_kernel_ptr_ = (T**)allocator_->malloc(sizeof(T*) * 12, false); - batch_qkv_input_ptr_ = batch_qkv_kernel_ptr_ + 4; - batch_qkv_buf_ptr_ = batch_qkv_input_ptr_ + 4; - } + FT_CHECK(false); } template void UnfusedAttentionLayer::allocateBuffer(size_t batch_size, size_t seq_len) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); - q_buf_ = (T*)allocator_->reMalloc(q_buf_, sizeof(T) * batch_size * seq_len * hidden_units_, false); - k_buf_ = (T*)allocator_->reMalloc(k_buf_, sizeof(T) * batch_size * seq_len * hidden_units_, false); - v_buf_ = (T*)allocator_->reMalloc(v_buf_, sizeof(T) * batch_size * seq_len * hidden_units_, false); - q_buf_2_ = (T*)allocator_->reMalloc(q_buf_2_, sizeof(T) * 3 * batch_size * seq_len * hidden_units_, false); - k_buf_2_ = q_buf_2_ + batch_size * seq_len * hidden_units_; - v_buf_2_ = k_buf_2_ + batch_size * seq_len * hidden_units_; - qk_buf_ = (T*)allocator_->reMalloc(qk_buf_, sizeof(T) * batch_size * head_num_ * seq_len * seq_len, false); - qkv_buf_ = (T*)allocator_->reMalloc(qkv_buf_, sizeof(T) * batch_size * seq_len * hidden_units_, false); + q_buf_ = (T*)allocator_->reMalloc(q_buf_, sizeof(T) * batch_size * seq_len * hidden_units_, false); + k_buf_ = (T*)allocator_->reMalloc(k_buf_, sizeof(T) * batch_size * seq_len * hidden_units_, false); + v_buf_ = (T*)allocator_->reMalloc(v_buf_, sizeof(T) * batch_size * seq_len * hidden_units_, false); + q_buf_2_ = (T*)allocator_->reMalloc(q_buf_2_, sizeof(T) * 3 * batch_size * seq_len * hidden_units_, false); + k_buf_2_ = q_buf_2_ + batch_size * seq_len * hidden_units_; + v_buf_2_ = k_buf_2_ + batch_size * seq_len * hidden_units_; + qk_buf_ = (T*)allocator_->reMalloc(qk_buf_, sizeof(T) * batch_size * head_num_ * seq_len * seq_len, false); + qkv_buf_ = (T*)allocator_->reMalloc(qkv_buf_, sizeof(T) * batch_size * seq_len * hidden_units_, false); qkv_buf_2_ = (T*)allocator_->reMalloc(qkv_buf_2_, sizeof(T) * batch_size * seq_len * hidden_units_, false); batch_qkv_kernel_ptr_ = (T**)allocator_->reMalloc(batch_qkv_kernel_ptr_, sizeof(T*) * 12, false); - batch_qkv_input_ptr_ = batch_qkv_kernel_ptr_ + 4; - batch_qkv_buf_ptr_ = batch_qkv_input_ptr_ + 4; - is_allocate_buffer_ = true; + batch_qkv_input_ptr_ = batch_qkv_kernel_ptr_ + 4; + batch_qkv_buf_ptr_ = batch_qkv_input_ptr_ + 4; + is_allocate_buffer_ = true; } template @@ -386,38 +373,23 @@ void UnfusedAttentionLayer::freeBuffer() { FT_LOG_DEBUG(__PRETTY_FUNCTION__); if (is_allocate_buffer_) { - allocator_->free(q_buf_); - allocator_->free(k_buf_); - allocator_->free(v_buf_); - allocator_->free(q_buf_2_); - allocator_->free(qk_buf_); - allocator_->free(qkv_buf_); - allocator_->free(qkv_buf_2_); - allocator_->free(batch_qkv_kernel_ptr_); + allocator_->free((void**)(&q_buf_)); + allocator_->free((void**)(&k_buf_)); + allocator_->free((void**)(&v_buf_)); + allocator_->free((void**)(&q_buf_2_)); + allocator_->free((void**)(&qk_buf_)); + allocator_->free((void**)(&qkv_buf_)); + allocator_->free((void**)(&qkv_buf_2_)); + allocator_->free((void**)(&batch_qkv_kernel_ptr_)); sync_check_cuda_error(); is_allocate_buffer_ = false; } } -template -bool UnfusedAttentionLayer::isValidBatchSize(size_t batch_size) -{ - if (max_batch_size_ < batch_size) { - max_batch_size_ = batch_size; - } - return true; -} - -template -bool UnfusedAttentionLayer::isValidSeqLen(size_t 
seq_len) -{ - if (max_seq_len_ < seq_len) { - max_seq_len_ = seq_len; - } - return true; -} - template class UnfusedAttentionLayer; template class UnfusedAttentionLayer; +#ifdef ENABLE_BF16 +template class UnfusedAttentionLayer<__nv_bfloat16>; +#endif } // namespace fastertransformer \ No newline at end of file diff --git a/src/fastertransformer/layers/attention_layers/UnfusedAttentionLayer.h b/src/fastertransformer/layers/attention_layers/UnfusedAttentionLayer.h index a6e1374d7..c5fc63424 100644 --- a/src/fastertransformer/layers/attention_layers/UnfusedAttentionLayer.h +++ b/src/fastertransformer/layers/attention_layers/UnfusedAttentionLayer.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,22 +23,16 @@ namespace fastertransformer { template class UnfusedAttentionLayer: public BaseAttentionLayer { private: - // buffer handling - size_t max_batch_size_ = 0; - size_t max_seq_len_ = 0; - // metadata size_t head_num_; size_t size_per_head_; size_t hidden_units_; size_t d_model_; - bool sparse_; - float q_scaling_; + bool sparse_; + float q_scaling_; void allocateBuffer() override; void freeBuffer() override; - bool isValidBatchSize(size_t batch_size); - bool isValidSeqLen(size_t seq_len); void allocateBuffer(size_t batch_size, size_t seq_len); protected: @@ -48,51 +42,51 @@ class UnfusedAttentionLayer: public BaseAttentionLayer { using BaseAttentionLayer::cublas_wrapper_; using BaseAttentionLayer::allocator_; - T* q_buf_ = nullptr; - T* k_buf_ = nullptr; - T* v_buf_ = nullptr; - T* q_buf_2_ = nullptr; - T* k_buf_2_ = nullptr; - T* v_buf_2_ = nullptr; - T* qk_buf_ = nullptr; - T* qkv_buf_ = nullptr; + T* q_buf_ = nullptr; + T* k_buf_ = nullptr; + T* v_buf_ = nullptr; + T* q_buf_2_ = nullptr; + T* k_buf_2_ = nullptr; + T* v_buf_2_ = nullptr; + T* qk_buf_ = nullptr; + T* qkv_buf_ = nullptr; T* qkv_buf_2_ = nullptr; T** batch_qkv_kernel_ptr_ = nullptr; - T** batch_qkv_input_ptr_ = nullptr; - T** batch_qkv_buf_ptr_ = nullptr; + T** batch_qkv_input_ptr_ = nullptr; + T** batch_qkv_buf_ptr_ = nullptr; public: - UnfusedAttentionLayer(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - float q_scaling, - cudaStream_t stream, + UnfusedAttentionLayer(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + float q_scaling, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse = false); + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse = false); - UnfusedAttentionLayer(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t d_model, - float q_scaling, - cudaStream_t stream, + UnfusedAttentionLayer(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t d_model, + float q_scaling, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse = false); + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse = false); UnfusedAttentionLayer(UnfusedAttentionLayer const& attention_layer); ~UnfusedAttentionLayer(); - void forward(std::vector* output_tensors, + void forward(std::vector* output_tensors, const std::vector* 
input_tensors, - const AttentionWeight* attention_weights) override; + const AttentionWeight* attention_weights) override; }; } // namespace fastertransformer diff --git a/src/fastertransformer/layers/attention_layers/WindowAttention.cc b/src/fastertransformer/layers/attention_layers/WindowAttention.cc index d74d2769f..ea714819a 100644 --- a/src/fastertransformer/layers/attention_layers/WindowAttention.cc +++ b/src/fastertransformer/layers/attention_layers/WindowAttention.cc @@ -85,27 +85,30 @@ void WindowAttention::allocateBuffer() } if (is_allocate_buffer_ == false) { if (use_trt_) { - int S = trt_getS(window_len_); - trt_attention_mask_ = (half*)allocator_->malloc(roundByteSize(window_num_ * S * S * sizeof(T), 4), false); - trt_relative_position_bias_ = - (half*)allocator_->malloc(roundByteSize(num_head * S * S * sizeof(T), 4), false); - qkv_buf_ = - (T*)allocator_->malloc(3 * max_batch_ * window_num_ * window_len_ * embed_dim_ * sizeof(T), false); - q_buf_ = (T*)allocator_->malloc(3 * max_batch_ * window_num_ * window_len_ * embed_dim_ * sizeof(T), false); - k_buf_ = q_buf_ + max_batch_ * window_num_ * window_len_ * embed_dim_; - v_buf_ = k_buf_ + max_batch_ * window_num_ * window_len_ * embed_dim_; + int S = trt_getS(window_len_); + trt_attention_mask_ = (half*)allocator_->reMalloc( + trt_attention_mask_, roundByteSize(window_num_ * S * S * sizeof(T), 4), false); + trt_relative_position_bias_ = (half*)allocator_->reMalloc( + trt_relative_position_bias_, roundByteSize(num_head * S * S * sizeof(T), 4), false); + qkv_buf_ = (T*)allocator_->reMalloc( + qkv_buf_, 3 * max_batch_ * window_num_ * window_len_ * embed_dim_ * sizeof(T), false); + q_buf_ = (T*)allocator_->reMalloc( + q_buf_, 3 * max_batch_ * window_num_ * window_len_ * embed_dim_ * sizeof(T), false); + k_buf_ = q_buf_ + max_batch_ * window_num_ * window_len_ * embed_dim_; + v_buf_ = k_buf_ + max_batch_ * window_num_ * window_len_ * embed_dim_; qk_buf_ = nullptr; } else { - trt_attention_mask_ = nullptr; + trt_attention_mask_ = nullptr; trt_relative_position_bias_ = nullptr; - qkv_buf_ = - (T*)allocator_->malloc(3 * max_batch_ * window_num_ * window_len_ * embed_dim_ * sizeof(T), false); - q_buf_ = (T*)allocator_->malloc(3 * max_batch_ * window_num_ * window_len_ * embed_dim_ * sizeof(T), false); - k_buf_ = q_buf_ + max_batch_ * window_num_ * window_len_ * embed_dim_; - v_buf_ = k_buf_ + max_batch_ * window_num_ * window_len_ * embed_dim_; - qk_buf_ = (T*)allocator_->malloc( - 3 * max_batch_ * window_num_ * num_head * window_len_ * window_len_ * sizeof(T), false); + qkv_buf_ = (T*)allocator_->reMalloc( + qkv_buf_, 3 * max_batch_ * window_num_ * window_len_ * embed_dim_ * sizeof(T), false); + q_buf_ = (T*)allocator_->reMalloc( + q_buf_, 3 * max_batch_ * window_num_ * window_len_ * embed_dim_ * sizeof(T), false); + k_buf_ = q_buf_ + max_batch_ * window_num_ * window_len_ * embed_dim_; + v_buf_ = k_buf_ + max_batch_ * window_num_ * window_len_ * embed_dim_; + qk_buf_ = (T*)allocator_->reMalloc( + qk_buf_, 3 * max_batch_ * window_num_ * num_head * window_len_ * window_len_ * sizeof(T), false); } is_allocate_buffer_ = true; } @@ -116,29 +119,29 @@ void WindowAttention::freeBuffer() { if (is_allocate_buffer_ == true) { if (use_trt_) { - allocator_->free(trt_attention_mask_); - allocator_->free(trt_relative_position_bias_); - allocator_->free(qkv_buf_); - allocator_->free(q_buf_); + allocator_->free((void**)(&trt_attention_mask_)); + allocator_->free((void**)(&trt_relative_position_bias_)); + allocator_->free((void**)(&qkv_buf_)); + 
allocator_->free((void**)(&q_buf_)); } else { - allocator_->free(qkv_buf_); - allocator_->free(q_buf_); - allocator_->free(qk_buf_); + allocator_->free((void**)(&qkv_buf_)); + allocator_->free((void**)(&q_buf_)); + allocator_->free((void**)(&qk_buf_)); } is_allocate_buffer_ = false; } } template -WindowAttention::WindowAttention(int max_batch, - int window_size, - cudaStream_t stream, +WindowAttention::WindowAttention(int max_batch, + int window_size, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool qkv_bias, - float qk_scale): + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool qkv_bias, + float qk_scale): BaseAttentionLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward), max_batch_(max_batch), window_size_(window_size), @@ -154,9 +157,9 @@ WindowAttention::~WindowAttention() } template -void WindowAttention::forward(std::vector* output_tensors, +void WindowAttention::forward(std::vector* output_tensors, const std::vector* input_tensors, - const AttentionWeight* attention_weights) + const AttentionWeight* attention_weights) { // input_tensors: // input [batch * window_num * window_len, dim] @@ -166,28 +169,28 @@ void WindowAttention::forward(std::vector* output_ // output_tensors: // output [batch * window_num * window_len, dim] - T* output = (T*)output_tensors->at(0).data; - const T* input = (const T*)input_tensors->at(0).data; - const T* attention_mask = (const T*)input_tensors->at(1).data; - const T* attention_relative_pos_bias = (const T*)input_tensors->at(2).data; - const int* additional_params = (const int*)input_tensors->at(3).data; - const int batch = additional_params[0]; - const int dim = additional_params[1]; - const int input_resolution = additional_params[2]; - const int num_head = additional_params[3]; - const int shift_size = additional_params[4]; - const int sm = additional_params[5]; + T* output = (T*)output_tensors->at(0).data; + const T* input = (const T*)input_tensors->at(0).data; + const T* attention_mask = (const T*)input_tensors->at(1).data; + const T* attention_relative_pos_bias = (const T*)input_tensors->at(2).data; + const int* additional_params = (const int*)input_tensors->at(3).data; + const int batch = additional_params[0]; + const int dim = additional_params[1]; + const int input_resolution = additional_params[2]; + const int num_head = additional_params[3]; + const int shift_size = additional_params[4]; + const int sm = additional_params[5]; int size_per_head = dim / num_head; - int trt_S = 1024; + int trt_S = 1024; if ((sm == 75 || sm == 80 || sm == 86) && size_per_head == 32 && window_len_ <= TRT_MAX_LEN && std::is_same::value) { - trt_S = trt_getS(window_len_); + trt_S = trt_getS(window_len_); use_trt_ = true; } - num_head_ = num_head; + num_head_ = num_head; window_num_ = (input_resolution / window_size_) * (input_resolution / window_size_); - embed_dim_ = dim; + embed_dim_ = dim; allocateBuffer(); float scale = 1.0f / sqrt(size_per_head); @@ -387,5 +390,8 @@ void WindowAttention::forward(std::vector* output_ template class WindowAttention; template class WindowAttention; +#ifdef ENABLE_BF16 +template class WindowAttention<__nv_bfloat16>; +#endif } // namespace fastertransformer \ No newline at end of file diff --git a/src/fastertransformer/layers/attention_layers/WindowAttention.h b/src/fastertransformer/layers/attention_layers/WindowAttention.h index 13f066e84..ee82e1798 100644 --- 
a/src/fastertransformer/layers/attention_layers/WindowAttention.h +++ b/src/fastertransformer/layers/attention_layers/WindowAttention.h @@ -37,24 +37,24 @@ template class WindowAttention: public BaseAttentionLayer { private: - int max_batch_ = 1; - int dim_ = 96; - int num_head_ = 2; - int window_size_ = 7; - int head_dim_ = 48; - int input_resolution_ = 56; - int window_len_ = 49; - int embed_dim_ = 96; - int window_num_ = 64; - int size_per_head_; - bool qkv_bias_ = true; + int max_batch_ = 1; + int dim_ = 96; + int num_head_ = 2; + int window_size_ = 7; + int head_dim_ = 48; + int input_resolution_ = 56; + int window_len_ = 49; + int embed_dim_ = 96; + int window_num_ = 64; + int size_per_head_; + bool qkv_bias_ = true; float qk_scale_ = 1.0f; - bool use_trt_ = false; + bool use_trt_ = false; void allocateBuffer() override; void freeBuffer() override; - int dispatcher_fp16_num_head_ = -1; + int dispatcher_fp16_num_head_ = -1; std::unique_ptr dispatcher_fp16_; using BaseAttentionLayer::stream_; @@ -76,20 +76,20 @@ class WindowAttention: public BaseAttentionLayer { static size_t getBufSize(const int batch, const int num_head, const int window_num, const int window_len, const int dim); - WindowAttention(int max_batch, - int window_size, - cudaStream_t stream, + WindowAttention(int max_batch, + int window_size, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward = false, - bool qkv_bias = true, - float qk_scale = 1.0f); + IAllocator* allocator, + bool is_free_buffer_after_forward = false, + bool qkv_bias = true, + float qk_scale = 1.0f); ~WindowAttention(); - void forward(std::vector* output_tensors, + void forward(std::vector* output_tensors, const std::vector* input_tensors, - const AttentionWeight* attention_weights); + const AttentionWeight* attention_weights); }; // class WindowAttention } // namespace fastertransformer diff --git a/src/fastertransformer/layers/attention_layers_int8/CMakeLists.txt b/src/fastertransformer/layers/attention_layers_int8/CMakeLists.txt index 03f6b240c..44a79636e 100644 --- a/src/fastertransformer/layers/attention_layers_int8/CMakeLists.txt +++ b/src/fastertransformer/layers/attention_layers_int8/CMakeLists.txt @@ -17,12 +17,14 @@ cmake_minimum_required(VERSION 3.8) add_library(UnfusedAttentionLayerINT8 STATIC UnfusedAttentionLayerINT8.cc) set_property(TARGET UnfusedAttentionLayerINT8 PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET UnfusedAttentionLayerINT8 PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -target_link_libraries(UnfusedAttentionLayerINT8 PUBLIC -lcublasLt -lcublas -lcudart cublasMMWrapper cublasINT8MMWrapper softmax_int8_kernels transpose_int8_kernels unfused_attention_int8_kernels) +target_link_libraries(UnfusedAttentionLayerINT8 PUBLIC -lcublasLt -lcublas -lcudart nvtx_utils cublasMMWrapper + cublasINT8MMWrapper softmax_int8_kernels transpose_int8_kernels unfused_attention_int8_kernels) add_library(FusedAttentionLayerINT8 STATIC FusedAttentionLayerINT8.cu) set_property(TARGET FusedAttentionLayerINT8 PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET FusedAttentionLayerINT8 PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -target_link_libraries(FusedAttentionLayerINT8 PUBLIC -lcublasLt -lcublas -lcudart cublasMMWrapper cublasINT8MMWrapper trt_fused_multi_head_attention) +target_link_libraries(FusedAttentionLayerINT8 PUBLIC -lcublasLt -lcublas -lcudart nvtx_utils cublasMMWrapper + cublasINT8MMWrapper trt_fused_multi_head_attention) add_library(WindowAttentionINT8 STATIC 
WindowAttentionINT8.cu) set_property(TARGET WindowAttentionINT8 PROPERTY POSITION_INDEPENDENT_CODE ON) diff --git a/src/fastertransformer/layers/attention_layers_int8/FusedAttentionLayerINT8.cu b/src/fastertransformer/layers/attention_layers_int8/FusedAttentionLayerINT8.cu index 527db49f5..2c73faec0 100644 --- a/src/fastertransformer/layers/attention_layers_int8/FusedAttentionLayerINT8.cu +++ b/src/fastertransformer/layers/attention_layers_int8/FusedAttentionLayerINT8.cu @@ -17,6 +17,7 @@ #include "src/fastertransformer/kernels/int8_utils.cuh" #include "src/fastertransformer/kernels/layout_transformer_int8_kernels.h" #include "src/fastertransformer/layers/attention_layers_int8/FusedAttentionLayerINT8.h" +#include "src/fastertransformer/utils/nvtx_utils.h" namespace fastertransformer { @@ -29,32 +30,32 @@ namespace fastertransformer { // size should be a multiple of 4 // using char4 as output, int4 as input template -__global__ void trt_add_QKV_bias_COL32_int32IInt8O(char4* output, - const int4* QKV, - const T* bias_Q, - const T* bias_K, - const T* bias_V, +__global__ void trt_add_QKV_bias_COL32_int32IInt8O(char4* output, + const int4* QKV, + const T* bias_Q, + const T* bias_K, + const T* bias_V, const float* input_deQFactor_div127_ptr, const float* q_weight_amax, const float* k_weight_amax, const float* v_weight_amax, - const float qkv_output_scale, - const int valid_word_num, - const int head_num, - const int size_per_head, - const int head_num_x_size_per_head) + const float qkv_output_scale, + const int valid_word_num, + const int head_num, + const int size_per_head, + const int head_num_x_size_per_head) { - const int qkv_id = blockIdx.z; - const int seq_id = (blockIdx.y << 5) + threadIdx.y; + const int qkv_id = blockIdx.z; + const int seq_id = (blockIdx.y << 5) + threadIdx.y; const int threadIdx4 = threadIdx.x << 2; - int hidden_id = (blockIdx.x << 5) + threadIdx4; - const int size_id = hidden_id % size_per_head; - const int head_id = hidden_id / size_per_head; + int hidden_id = (blockIdx.x << 5) + threadIdx4; + const int size_id = hidden_id % size_per_head; + const int head_id = hidden_id / size_per_head; const bool qual = (seq_id < valid_word_num) && (hidden_id < head_num_x_size_per_head); if (qual) { const float* weight_amax = qkv_id == 0 ? q_weight_amax : (qkv_id == 1 ? k_weight_amax : v_weight_amax); - const float input_deQFactor_div127 = __ldg(input_deQFactor_div127_ptr); + const float input_deQFactor_div127 = __ldg(input_deQFactor_div127_ptr); const T* bias_ptr = (qkv_id == 0) ? bias_Q : ((qkv_id == 1) ? 
bias_K : bias_V); @@ -62,7 +63,7 @@ __global__ void trt_add_QKV_bias_COL32_int32IInt8O(char4* output, + ((hidden_id & 0xffffffe0) * valid_word_num + (seq_id << 5) + (hidden_id & 31))) >> 2; - char4 tmp; + char4 tmp; const int4 tmp_int4 = __ldg(QKV + input_id); tmp.x = @@ -98,23 +99,23 @@ __global__ void trt_add_QKV_bias_COL32_int32IInt8O(char4* output, } template -void invokeTrtAddQkvBiasInt32Iint8O(int8_t* output, - const int32_t* Q, - const T* bias_Q, - const T* bias_K, - const T* bias_V, - const size_t token_num, - const size_t head_num, - const size_t size_per_head, - const float* input_deQFactor_div127_ptr, - const float* q_weight_amax, - const float* k_weight_amax, - const float* v_weight_amax, - const float mScaleQkv, +void invokeTrtAddQkvBiasInt32Iint8O(int8_t* output, + const int32_t* Q, + const T* bias_Q, + const T* bias_K, + const T* bias_V, + const size_t token_num, + const size_t head_num, + const size_t size_per_head, + const float* input_deQFactor_div127_ptr, + const float* q_weight_amax, + const float* k_weight_amax, + const float* v_weight_amax, + const float mScaleQkv, const cudaStream_t stream) { - int head_num_x_size_per_head = head_num * size_per_head; + int head_num_x_size_per_head = head_num * size_per_head; dim3 grid((head_num_x_size_per_head + 31) / 32, (token_num + 31) / 32, 3); dim3 block(8, 32); @@ -142,26 +143,26 @@ void invokeTrtAddQkvBiasInt32Iint8O(int8_t* output, // block(8, 32) // size should be a multiple of 4 template -__global__ void trt_add_QKV_bias_COL32_int8IO(char4* output, +__global__ void trt_add_QKV_bias_COL32_int8IO(char4* output, const char4* QKV, - const T* bias_Q, - const T* bias_K, - const T* bias_V, + const T* bias_Q, + const T* bias_K, + const T* bias_V, const float* q_input_deQFactor_ptr, const float* k_input_deQFactor_ptr, const float* v_input_deQFactor_ptr, - const float qkv_output_scale, - const int valid_word_num, - const int head_num, - const int size_per_head, - const int head_num_x_size_per_head) + const float qkv_output_scale, + const int valid_word_num, + const int head_num, + const int size_per_head, + const int head_num_x_size_per_head) { - const int qkv_id = blockIdx.z; - const int seq_id = (blockIdx.y << 5) + threadIdx.y; + const int qkv_id = blockIdx.z; + const int seq_id = (blockIdx.y << 5) + threadIdx.y; const int threadIdx4 = threadIdx.x << 2; - const int hidden_id = (blockIdx.x << 5) + threadIdx4; - const int size_id = hidden_id % size_per_head; - const int head_id = hidden_id / size_per_head; + const int hidden_id = (blockIdx.x << 5) + threadIdx4; + const int size_id = hidden_id % size_per_head; + const int head_id = hidden_id / size_per_head; const bool qual = (seq_id < valid_word_num) && (hidden_id < head_num_x_size_per_head); if (qual) { @@ -203,22 +204,22 @@ __global__ void trt_add_QKV_bias_COL32_int8IO(char4* output, } template -void invokeTrtAddQkvBiasInt8IO(int8_t* output, - const int8_t* Q, - const T* bias_Q, - const T* bias_K, - const T* bias_V, - const size_t token_num, - const size_t head_num, - const size_t size_per_head, - const float* q_input_deQFactor_ptr, - const float* k_input_deQFactor_ptr, - const float* v_input_deQFactor_ptr, - const float mScaleQkv, +void invokeTrtAddQkvBiasInt8IO(int8_t* output, + const int8_t* Q, + const T* bias_Q, + const T* bias_K, + const T* bias_V, + const size_t token_num, + const size_t head_num, + const size_t size_per_head, + const float* q_input_deQFactor_ptr, + const float* k_input_deQFactor_ptr, + const float* v_input_deQFactor_ptr, + const float mScaleQkv, const 
cudaStream_t stream) { - int head_num_x_size_per_head = head_num * size_per_head; + int head_num_x_size_per_head = head_num * size_per_head; dim3 grid((head_num_x_size_per_head + 31) / 32, (token_num + 31) / 32, 3); dim3 block(8, 32); @@ -245,26 +246,26 @@ void invokeTrtAddQkvBiasInt8IO(int8_t* output, // block(8, 32) // size should be a multiple of 4 template -__global__ void trt_add_QKV_bias_ROW_int8IO(char4* output, +__global__ void trt_add_QKV_bias_ROW_int8IO(char4* output, const char4* QKV, - const T* bias_Q, - const T* bias_K, - const T* bias_V, + const T* bias_Q, + const T* bias_K, + const T* bias_V, const float* q_input_deQFactor_ptr, const float* k_input_deQFactor_ptr, const float* v_input_deQFactor_ptr, - const float qkv_output_scale, - const int valid_word_num, - const int head_num, - const int size_per_head, - const int head_num_x_size_per_head) + const float qkv_output_scale, + const int valid_word_num, + const int head_num, + const int size_per_head, + const int head_num_x_size_per_head) { - const int qkv_id = blockIdx.z; - const int seq_id = (blockIdx.y << 5) + threadIdx.y; + const int qkv_id = blockIdx.z; + const int seq_id = (blockIdx.y << 5) + threadIdx.y; const int threadIdx4 = threadIdx.x << 2; - const int hidden_id = (blockIdx.x << 5) + threadIdx4; - const int size_id = hidden_id % size_per_head; - const int head_id = hidden_id / size_per_head; + const int hidden_id = (blockIdx.x << 5) + threadIdx4; + const int size_id = hidden_id % size_per_head; + const int head_id = hidden_id / size_per_head; const bool qual = (seq_id < valid_word_num) && (hidden_id < head_num_x_size_per_head); if (qual) { @@ -303,22 +304,22 @@ __global__ void trt_add_QKV_bias_ROW_int8IO(char4* output, } template -void invokeTrtAddQkvBiasInt8IORow(int8_t* output, - const int8_t* Q, - const T* bias_Q, - const T* bias_K, - const T* bias_V, - const size_t token_num, - const size_t head_num, - const size_t size_per_head, - const float* q_input_deQFactor_ptr, - const float* k_input_deQFactor_ptr, - const float* v_input_deQFactor_ptr, - const float mScaleQkv, +void invokeTrtAddQkvBiasInt8IORow(int8_t* output, + const int8_t* Q, + const T* bias_Q, + const T* bias_K, + const T* bias_V, + const size_t token_num, + const size_t head_num, + const size_t size_per_head, + const float* q_input_deQFactor_ptr, + const float* k_input_deQFactor_ptr, + const float* v_input_deQFactor_ptr, + const float mScaleQkv, const cudaStream_t stream) { - int head_num_x_size_per_head = head_num * size_per_head; + int head_num_x_size_per_head = head_num * size_per_head; dim3 grid((head_num_x_size_per_head + 31) / 32, (token_num + 31) / 32, 3); dim3 block(8, 32); @@ -340,9 +341,9 @@ void invokeTrtAddQkvBiasInt8IORow(int8_t* output, } template -void FusedAttentionLayerINT8::forward(std::vector* output_tensors, +void FusedAttentionLayerINT8::forward(std::vector* output_tensors, const std::vector* input_tensors, - const AttentionWeight* attention_weights) + const AttentionWeight* attention_weights) { // input_tensors: [input (token_num, hidden_dimension), // attention_mask (batch, 1, seqlen, seqlen), @@ -350,7 +351,7 @@ void FusedAttentionLayerINT8::forward(std::vector* // output_tensors: [output (token_num, hidden_dimension)] // If padding_offset.data is nullptr, then not remove padding - const ScaleList* scale_list = ((const AttentionINT8Weight*)attention_weights)->scale_list_ptr; + const ScaleList* scale_list = ((const AttentionINT8Weight*)attention_weights)->scale_list_ptr; cublasINT8MMWrapper* cublas_wrapper = 
(cublasINT8MMWrapper*)cublas_wrapper_; // input_tensors: [input_query (token_num, hidden_dimension), @@ -362,16 +363,16 @@ void FusedAttentionLayerINT8::forward(std::vector* FT_CHECK(isValidSeqLen(input_tensors->at(1).shape[2])); allocateBuffer(); - int32_t* attention_out = (int32_t*)output_tensors->at(0).data; - const int8_t* from_tensor = (const int8_t*)input_tensors->at(0).data; - const T* attention_mask = (const T*)input_tensors->at(1).data; - const int* padding_offset = (const int*)input_tensors->at(2).data; + int32_t* attention_out = (int32_t*)output_tensors->at(0).data; + const int8_t* from_tensor = (const int8_t*)input_tensors->at(0).data; + const T* attention_mask = (const T*)input_tensors->at(1).data; + const int* padding_offset = (const int*)input_tensors->at(2).data; const int request_batch_size = input_tensors->at(1).shape[0]; - const int request_seq_len = input_tensors->at(1).shape[2]; - const int m = input_tensors->at(0).shape[0]; - const int k = hidden_units_; - const int n = hidden_units_; + const int request_seq_len = input_tensors->at(1).shape[2]; + const int m = input_tensors->at(0).shape[0]; + const int k = hidden_units_; + const int n = hidden_units_; #ifdef SPARSITY_ENABLED int m_tmp = m; if (m_tmp % 16 != 0) { @@ -379,12 +380,12 @@ void FusedAttentionLayerINT8::forward(std::vector* } const int m_padded = m_tmp; #endif - const int fusedINT8QKV_type = cublas_wrapper->getFusedINT8QKVType(k, n, attention_weights); if (int8_mode_ == 1) { // K_int_buf_ V_int_buf_ should point to correct buffer according to m K_int_buf_ = (int*)Q_int_buf_ + m * head_num_ * size_per_head_; V_int_buf_ = (int*)K_int_buf_ + m * head_num_ * size_per_head_; + PUSH_RANGE("qkv_gemm"); if (fusedINT8QKV_type == 0) { cublas_wrapper->Gemm( Q_int_buf_, 1, m, n, k, 0, 0, 0, from_tensor, (int8_t*)(attention_weights->query_weight.kernel)); @@ -406,6 +407,9 @@ void FusedAttentionLayerINT8::forward(std::vector* from_tensor, (int8_t*)(attention_weights->query_weight.kernel)); } + POP_RANGE; + + PUSH_RANGE("invokeTrtAddQkvBiasInt32Iint8O"); invokeTrtAddQkvBiasInt32Iint8O(qkv_buf_, Q_int_buf_, attention_weights->query_weight.bias, @@ -420,6 +424,7 @@ void FusedAttentionLayerINT8::forward(std::vector* &(scale_list->d_scale_list_[scale_list->p2_offset_ + 2 * hidden_units_]), scale_list->h_scale_list_[scale_list->p4_offset_] / 127.0f, stream_); + POP_RANGE; } else if (int8_mode_ == 2 || int8_mode_ == 3) { // K_int_buf_ V_int_buf_ should point to correct buffer according to m @@ -428,6 +433,7 @@ void FusedAttentionLayerINT8::forward(std::vector* #ifdef SPARSITY_ENABLED if (sparse_) { + PUSH_RANGE("qkv_gemm"); cublas_wrapper->SpGemm(n, m_padded, k, @@ -449,9 +455,11 @@ void FusedAttentionLayerINT8::forward(std::vector* (int8_t*)(attention_weights->value_weight.sp_kernel), from_tensor, (int8_t*)V_int_buf_); + POP_RANGE; } else { #endif + PUSH_RANGE("qkv_gemm"); if (fusedINT8QKV_type == 0) { cublas_wrapper->Gemm((int8_t*)Q_int_buf_, 1, @@ -501,9 +509,11 @@ void FusedAttentionLayerINT8::forward(std::vector* from_tensor, (int8_t*)(attention_weights->query_weight.kernel)); } + POP_RANGE; #ifdef SPARSITY_ENABLED } if (sparse_) { + PUSH_RANGE("invokeTrtAddQkvBiasInt8IO"); invokeTrtAddQkvBiasInt8IORow(qkv_buf_, (int8_t*)Q_int_buf_, attention_weights->query_weight.bias, @@ -517,9 +527,11 @@ void FusedAttentionLayerINT8::forward(std::vector* &(scale_list->d_scale_list_[20 + 1]), scale_list->h_scale_list_[scale_list->p4_offset_] / 127.0f, stream_); + POP_RANGE; } else { #endif + PUSH_RANGE("invokeTrtAddQkvBiasInt8IO"); 
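The PUSH_RANGE/POP_RANGE pairs added in this file come with the newly included src/fastertransformer/utils/nvtx_utils.h and wrap each GEMM and kernel launch in a named range, so the individual attention phases show up as labeled spans in Nsight Systems timelines. A minimal sketch of the underlying NVTX calls follows; the actual macros in nvtx_utils.h may add extra handling, so treat this only as an illustration.

// Sketch of the NVTX calls behind ranges such as PUSH_RANGE("qkv_gemm").
#include <nvToolsExt.h>

void profiledQkvGemm()
{
    nvtxRangePushA("qkv_gemm");   // open a named, nestable range on this thread
    // ... launch the cuBLAS GEMM / CUDA kernels being measured ...
    nvtxRangePop();               // close the innermost open range
}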
invokeTrtAddQkvBiasInt8IO(qkv_buf_, (int8_t*)Q_int_buf_, attention_weights->query_weight.bias, @@ -533,6 +545,7 @@ void FusedAttentionLayerINT8::forward(std::vector* &(scale_list->d_scale_list_[20 + 1]), scale_list->h_scale_list_[scale_list->p4_offset_] / 127.0f, stream_); + POP_RANGE; #ifdef SPARSITY_ENABLED } #endif @@ -546,7 +559,9 @@ void FusedAttentionLayerINT8::forward(std::vector* scale_list->h_scale_list_[scale_list->p4_offset_ + 1] / 127.0f, scale_list->h_scale_list_[scale_list->p4_offset_ + 2] / 127.0f); dispatcher_int8_->setup(S, B); + PUSH_RANGE("fused mha"); dispatcher_int8_->run(qkv_buf_, nullptr, (int*)input_tensors->at(2).data, attn_workspace_, qkv_buf_2_, stream_); + POP_RANGE; sync_check_cuda_error(); #ifdef SPARSITY_ENABLED @@ -558,6 +573,7 @@ void FusedAttentionLayerINT8::forward(std::vector* } #endif + PUSH_RANGE("proj gemm"); if (int8_mode_ == 1) { cublas_wrapper->Gemm( attention_out, 1, m, n, k, 0, 0, 0, qkv_buf_, (int8_t*)(attention_weights->attention_output_weight.kernel)); @@ -590,24 +606,25 @@ void FusedAttentionLayerINT8::forward(std::vector* } #endif } + POP_RANGE; if (is_free_buffer_after_forward_ == true) { freeBuffer(); } } template -FusedAttentionLayerINT8::FusedAttentionLayerINT8(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - int sm, - float q_scaling, - int int8_mode, - cudaStream_t stream, +FusedAttentionLayerINT8::FusedAttentionLayerINT8(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + int sm, + float q_scaling, + int int8_mode, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse): + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse): BaseAttentionLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward), max_batch_size_(max_batch_size), max_seq_len_(max_seq_len), @@ -662,15 +679,15 @@ template void FusedAttentionLayerINT8::allocateBuffer() { if (is_allocate_buffer_ == false) { - Q_int_buf_ = - (int32_t*)allocator_->malloc(sizeof(int32_t) * max_batch_size_ * max_seq_len_ * hidden_units_ * 3, false); + Q_int_buf_ = (int32_t*)allocator_->reMalloc( + Q_int_buf_, sizeof(int32_t) * max_batch_size_ * max_seq_len_ * hidden_units_ * 3, false); K_int_buf_ = Q_int_buf_ + max_batch_size_ * max_seq_len_ * hidden_units_; V_int_buf_ = K_int_buf_ + max_batch_size_ * max_seq_len_ * hidden_units_; - qkv_buf_ = (int8_t*)allocator_->malloc( - (sizeof(int8_t) * 3 * max_batch_size_ * max_seq_len_ * hidden_units_ + 3) / 4 * 4, false); - qkv_buf_2_ = (int8_t*)allocator_->malloc( - (sizeof(int8_t) * max_batch_size_ * max_seq_len_ * hidden_units_ + 3) / 4 * 4, false); - attn_workspace_ = (T*)allocator_->malloc(dispatcher_int8_->getWorkspaceSize(), false); + qkv_buf_ = (int8_t*)allocator_->reMalloc( + qkv_buf_, (sizeof(int8_t) * 3 * max_batch_size_ * max_seq_len_ * hidden_units_ + 3) / 4 * 4, false); + qkv_buf_2_ = (int8_t*)allocator_->reMalloc( + qkv_buf_2_, (sizeof(int8_t) * max_batch_size_ * max_seq_len_ * hidden_units_ + 3) / 4 * 4, false); + attn_workspace_ = (T*)allocator_->reMalloc(attn_workspace_, dispatcher_int8_->getWorkspaceSize(), false); is_allocate_buffer_ = true; } @@ -680,10 +697,10 @@ template void FusedAttentionLayerINT8::freeBuffer() { if (is_allocate_buffer_ == true) { - allocator_->free(Q_int_buf_); - allocator_->free(qkv_buf_); - allocator_->free(qkv_buf_2_); - allocator_->free(attn_workspace_); + allocator_->free((void**)(&Q_int_buf_)); + 
allocator_->free((void**)(&qkv_buf_)); + allocator_->free((void**)(&qkv_buf_2_)); + allocator_->free((void**)(&attn_workspace_)); is_allocate_buffer_ = false; sync_check_cuda_error(); diff --git a/src/fastertransformer/layers/attention_layers_int8/FusedAttentionLayerINT8.h b/src/fastertransformer/layers/attention_layers_int8/FusedAttentionLayerINT8.h index accb33c82..c09242892 100644 --- a/src/fastertransformer/layers/attention_layers_int8/FusedAttentionLayerINT8.h +++ b/src/fastertransformer/layers/attention_layers_int8/FusedAttentionLayerINT8.h @@ -41,18 +41,18 @@ class FusedAttentionLayerINT8: public BaseAttentionLayer { // buffer handling size_t max_batch_size_ = 0; - size_t max_seq_len_ = 0; + size_t max_seq_len_ = 0; void allocateBuffer() override; void freeBuffer() override; bool isValidBatchSize(size_t batch_size); bool isValidSeqLen(size_t seq_len); - float q_scaling_; - int sm_; - int int8_mode_; + float q_scaling_; + int sm_; + int int8_mode_; std::unique_ptr dispatcher_int8_; - bool sparse_; + bool sparse_; using BaseAttentionLayer::stream_; using BaseAttentionLayer::is_free_buffer_after_forward_; @@ -64,31 +64,31 @@ class FusedAttentionLayerINT8: public BaseAttentionLayer { int32_t* Q_int_buf_; int32_t* K_int_buf_; int32_t* V_int_buf_; - int8_t* qkv_buf_; - int8_t* qkv_buf_2_; - T* attn_workspace_; + int8_t* qkv_buf_; + int8_t* qkv_buf_2_; + T* attn_workspace_; public: - FusedAttentionLayerINT8(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - int sm, - float q_scaling, - int int8_mode, - cudaStream_t stream, + FusedAttentionLayerINT8(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + int sm, + float q_scaling, + int int8_mode, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse = false); + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse = false); FusedAttentionLayerINT8(FusedAttentionLayerINT8 const& attention_layer); ~FusedAttentionLayerINT8(); - void forward(std::vector* output_tensors, + void forward(std::vector* output_tensors, const std::vector* input_tensors, - const AttentionWeight* attention_weights) override; + const AttentionWeight* attention_weights) override; void invokeTrtAddQkvBias(size_t token_num, const AttentionWeight* attention_weights); }; diff --git a/src/fastertransformer/layers/attention_layers_int8/UnfusedAttentionLayerINT8.cc b/src/fastertransformer/layers/attention_layers_int8/UnfusedAttentionLayerINT8.cc index 9e2ef77c8..9c785e1d9 100644 --- a/src/fastertransformer/layers/attention_layers_int8/UnfusedAttentionLayerINT8.cc +++ b/src/fastertransformer/layers/attention_layers_int8/UnfusedAttentionLayerINT8.cc @@ -18,13 +18,14 @@ #include "src/fastertransformer/kernels/softmax_int8_kernels.h" #include "src/fastertransformer/kernels/transpose_int8_kernels.h" #include "src/fastertransformer/kernels/unfused_attention_int8_kernels.h" +#include "src/fastertransformer/utils/nvtx_utils.h" namespace fastertransformer { template -void UnfusedAttentionLayerINT8::forward(std::vector* output_tensors, +void UnfusedAttentionLayerINT8::forward(std::vector* output_tensors, const std::vector* input_tensors, - const AttentionWeight* attention_weights) + const AttentionWeight* attention_weights) { // input_tensors: [input (token_num, hidden_dimension), @@ -33,24 +34,24 @@ void UnfusedAttentionLayerINT8::forward(std::vector*)attention_weights)->scale_list_ptr; + const ScaleList* scale_list = 
((const AttentionINT8Weight*)attention_weights)->scale_list_ptr; cublasINT8MMWrapper* cublas_wrapper = (cublasINT8MMWrapper*)cublas_wrapper_; FT_CHECK(isValidBatchSize(input_tensors->at(1).shape[0])); FT_CHECK(isValidSeqLen(input_tensors->at(1).shape[2])); allocateBuffer(); - int32_t* attention_out = (int32_t*)output_tensors->at(0).data; - const int8_t* from_tensor = (const int8_t*)input_tensors->at(0).data; - const T* attention_mask = (const T*)input_tensors->at(1).data; - const int* padding_offset = (const int*)input_tensors->at(2).data; + int32_t* attention_out = (int32_t*)output_tensors->at(0).data; + const int8_t* from_tensor = (const int8_t*)input_tensors->at(0).data; + const T* attention_mask = (const T*)input_tensors->at(1).data; + const int* padding_offset = (const int*)input_tensors->at(2).data; const int request_batch_size = input_tensors->at(1).shape[0]; - const int request_seq_len = input_tensors->at(1).shape[2]; - const int m = input_tensors->at(0).shape[0]; - const int k = hidden_units_; - const int n = hidden_units_; - int m_tmp = m; + const int request_seq_len = input_tensors->at(1).shape[2]; + const int m = input_tensors->at(0).shape[0]; + const int k = hidden_units_; + const int n = hidden_units_; + int m_tmp = m; if (m_tmp % 16 != 0) { m_tmp = (m_tmp / 16 + 1) * 16; } @@ -70,6 +71,7 @@ void UnfusedAttentionLayerINT8::forward(std::vectorGemm( Q_int_buf_, 1, m, n, k, 0, 0, 0, from_tensor, (int8_t*)(attention_weights->query_weight.kernel)); @@ -91,6 +93,7 @@ void UnfusedAttentionLayerINT8::forward(std::vectorquery_weight.kernel)); } + POP_RANGE; } else if (int8_mode_ == 2 || int8_mode_ == 3) { // K_int_buf_ V_int_buf_ should point to correct buffer according to m @@ -99,6 +102,7 @@ void UnfusedAttentionLayerINT8::forward(std::vectorSpGemm(n, m_padded, k, @@ -120,10 +124,12 @@ void UnfusedAttentionLayerINT8::forward(std::vectorvalue_weight.sp_kernel), from_tensor, (int8_t*)V_int_buf_); + POP_RANGE; } else { #endif if (fusedINT8QKV_type == 0) { + PUSH_RANGE("qkv_gemm"); cublas_wrapper->Gemm((int8_t*)Q_int_buf_, 1, m, @@ -157,9 +163,11 @@ void UnfusedAttentionLayerINT8::forward(std::vectorh_scale_list_[scale_list->p3_offset_ + 2], from_tensor, (int8_t*)(attention_weights->value_weight.kernel)); + POP_RANGE; } else { int strideFactor = (fusedINT8QKV_type == 1) ? 
(sizeof(T) / sizeof(int8_t)) : 1; + PUSH_RANGE("qkv_gemm"); cublas_wrapper->Gemm((int8_t*)Q_int_buf_, 3, m, @@ -171,6 +179,7 @@ void UnfusedAttentionLayerINT8::forward(std::vectorh_scale_list_[scale_list->p3_offset_ + 0], from_tensor, (int8_t*)(attention_weights->query_weight.kernel)); + POP_RANGE; } #ifdef SPARSITY_ENABLED } @@ -180,6 +189,7 @@ void UnfusedAttentionLayerINT8::forward(std::vector::forward(std::vectord_scale_list_[24 + 3]), cublas_wrapper->getUseOrderCol322R4R4(), stream_); + POP_RANGE; } else if (int8_mode_ == 2 || int8_mode_ == 3) { #ifdef SPARSITY_ENABLED if (sparse_) { + PUSH_RANGE("addQKVBiasTransformer"); invokeAddQKBiasTransformRow(q_buf_, k_buf_, (const int8_t*)Q_int_buf_, @@ -241,9 +253,11 @@ void UnfusedAttentionLayerINT8::forward(std::vectord_scale_list_[24 + 3]), cublas_wrapper->getUseOrderCol322R4R4(), stream_); + POP_RANGE; } else { #endif + PUSH_RANGE("addQKVBiasTransformer"); invokeAddQKBiasTransform(q_buf_, k_buf_, (const int8_t*)Q_int_buf_, @@ -271,6 +285,7 @@ void UnfusedAttentionLayerINT8::forward(std::vectord_scale_list_[24 + 3]), cublas_wrapper->getUseOrderCol322R4R4(), stream_); + POP_RANGE; #ifdef SPARSITY_ENABLED } #endif @@ -285,7 +300,7 @@ void UnfusedAttentionLayerINT8::forward(std::vector::forward(std::vectord_scale_list_[24 + 3]), cublas_wrapper->getUseOrderCol322R4R4(), stream_); + POP_RANGE; } else if (int8_mode_ == 2 || int8_mode_ == 3) { #ifdef SPARSITY_ENABLED if (sparse_) { + PUSH_RANGE("addQKVBiasTransformerRebuildPadding"); invokeAddQKBiasTransformRebuildPaddingRow(q_buf_, k_buf_, (const int8_t*)Q_int_buf_, @@ -357,9 +374,11 @@ void UnfusedAttentionLayerINT8::forward(std::vectord_scale_list_[24 + 3]), cublas_wrapper->getUseOrderCol322R4R4(), stream_); + POP_RANGE; } else { #endif + PUSH_RANGE("addQKVBiasTransformerRebuildPadding"); invokeAddQKBiasTransformRebuildPadding(q_buf_, k_buf_, (const int8_t*)Q_int_buf_, @@ -392,6 +411,7 @@ void UnfusedAttentionLayerINT8::forward(std::vectord_scale_list_[24 + 3]), cublas_wrapper->getUseOrderCol322R4R4(), stream_); + POP_RANGE; #ifdef SPARSITY_ENABLED } #endif @@ -399,9 +419,10 @@ void UnfusedAttentionLayerINT8::forward(std::vectorGemm(qk_int_buf_, batchCount, request_seq_len, @@ -412,7 +433,8 @@ void UnfusedAttentionLayerINT8::forward(std::vector::forward(std::vectord_scale_list_[16 + 1]), &(scale_list->d_scale_list_[32]), stream_); - + POP_RANGE; + PUSH_RANGE("QK*V batch gemm"); cublas_wrapper->Gemm(transpose_dst_int_buf_, batchCount, request_seq_len, @@ -435,7 +458,9 @@ void UnfusedAttentionLayerINT8::forward(std::vector::forward(std::vectord_scale_list_[36 + 3]), stream_); } + POP_RANGE; } else if (int8_mode_ == 2 || int8_mode_ == 3) { + PUSH_RANGE("Q*K batch gemm"); cublas_wrapper->Gemm((int8_t*)qk_int_buf_, batchCount, request_seq_len, @@ -476,7 +503,8 @@ void UnfusedAttentionLayerINT8::forward(std::vectorh_scale_list_[scale_list->p3_offset_ + 3], q_buf_, k_buf_); - + POP_RANGE; + PUSH_RANGE("softmax"); invokeSoftmaxCOL32(qk_buf_, (int8_t*)qk_int_buf_, attention_mask, @@ -487,7 +515,8 @@ void UnfusedAttentionLayerINT8::forward(std::vectord_scale_list_[28 + 1]), &(scale_list->d_scale_list_[32]), stream_); - + POP_RANGE; + PUSH_RANGE("QK*V batch gemm"); cublas_wrapper->Gemm((int8_t*)transpose_dst_int_buf_, batchCount, request_seq_len, @@ -499,8 +528,11 @@ void UnfusedAttentionLayerINT8::forward(std::vectorh_scale_list_[scale_list->p3_offset_ + 4], qk_buf_, v_buf_); + POP_RANGE; + #ifdef SPARSITY_ENABLED if (sparse_) { + PUSH_RANGE("transposeRebuildPadding"); if (padding_offset == nullptr) { 
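A recurring change across these files replaces one-shot allocator_->malloc(...) calls sized from max_batch_size_/max_seq_len_ with allocator_->reMalloc(ptr, size, ...) calls, and replaces allocator_->free(ptr) with allocator_->free((void**)(&ptr)). The double-pointer form lets the allocator reset the caller's pointer after release. The stand-in below sketches what such a contract can look like; SimpleAllocator and its reuse-or-grow behavior are assumptions for illustration, not the real fastertransformer::IAllocator.

// Hypothetical stand-in for the allocator contract implied by these hunks.
#include <cstddef>
#include <unordered_map>
#include <cuda_runtime.h>

struct SimpleAllocator {
    std::unordered_map<void*, size_t> sizes_;

    // Reuse the existing buffer when it is already large enough,
    // otherwise release it and allocate a bigger one.
    void* reMalloc(void* ptr, size_t size)
    {
        auto it = sizes_.find(ptr);
        if (it != sizes_.end()) {
            if (it->second >= size) {
                return ptr;
            }
            cudaFree(ptr);
            sizes_.erase(it);
        }
        void* new_ptr = nullptr;
        cudaMalloc(&new_ptr, size);
        sizes_[new_ptr] = size;
        return new_ptr;
    }

    // Taking a pointer-to-pointer lets the allocator null the caller's handle,
    // which is why call sites change to free((void**)(&buf_)).
    void free(void** ptr)
    {
        if (ptr != nullptr && *ptr != nullptr) {
            cudaFree(*ptr);
            sizes_.erase(*ptr);
            *ptr = nullptr;
        }
    }
};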
invokeTransposeCOL32ToRow(dst_, (const int8_t*)transpose_dst_int_buf_, @@ -525,9 +557,11 @@ void UnfusedAttentionLayerINT8::forward(std::vectord_scale_list_[36 + 3]), stream_); } + POP_RANGE; } else { #endif + PUSH_RANGE("transposeRebuildPadding"); if (padding_offset == nullptr) { invokeTransposeCOL32(dst_, (const int8_t*)transpose_dst_int_buf_, @@ -552,11 +586,13 @@ void UnfusedAttentionLayerINT8::forward(std::vectord_scale_list_[36 + 3]), stream_); } + POP_RANGE; #ifdef SPARSITY_ENABLED } #endif } + PUSH_RANGE("proj gemm"); if (int8_mode_ == 1) { cublas_wrapper->Gemm( attention_out, 1, m, n, k, 0, 0, 0, dst_, (int8_t*)(attention_weights->attention_output_weight.kernel)); @@ -589,6 +625,7 @@ void UnfusedAttentionLayerINT8::forward(std::vector::forward(std::vector -UnfusedAttentionLayerINT8::UnfusedAttentionLayerINT8(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - float q_scaling, - int int8_mode, - cudaStream_t stream, +UnfusedAttentionLayerINT8::UnfusedAttentionLayerINT8(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + float q_scaling, + int int8_mode, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse): + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse): BaseAttentionLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward), max_batch_size_(max_batch_size), max_seq_len_(max_seq_len), @@ -649,24 +686,28 @@ void UnfusedAttentionLayerINT8::allocateBuffer() { if (is_allocate_buffer_ == false) { int padded_max_seq_len = (max_seq_len_ + 31) / 32 * 32; - q_buf_ = (int8_t*)allocator_->malloc(sizeof(int8_t) * max_batch_size_ * padded_max_seq_len * hidden_units_ * 3, - false); - k_buf_ = q_buf_ + max_batch_size_ * padded_max_seq_len * hidden_units_; - v_buf_ = k_buf_ + max_batch_size_ * padded_max_seq_len * hidden_units_; - qk_buf_ = (int8_t*)allocator_->malloc( - sizeof(int8_t) * max_batch_size_ * head_num_ * padded_max_seq_len * padded_max_seq_len, false); - dst_ = (int8_t*)allocator_->malloc(sizeof(int8_t) * max_batch_size_ * max_seq_len_ * hidden_units_, false); - - Q_int_buf_ = - (int32_t*)allocator_->malloc(sizeof(int32_t) * max_batch_size_ * max_seq_len_ * hidden_units_ * 3, false); - V_int_buf_ = Q_int_buf_ + max_batch_size_ * max_seq_len_ * hidden_units_; - K_int_buf_ = V_int_buf_ + max_batch_size_ * max_seq_len_ * hidden_units_; - qk_int_buf_ = (int32_t*)allocator_->malloc( - sizeof(int32_t) * max_batch_size_ * head_num_ * padded_max_seq_len * padded_max_seq_len, false); - transpose_dst_int_buf_ = - (int32_t*)allocator_->malloc(sizeof(int32_t) * max_batch_size_ * max_seq_len_ * hidden_units_, false); - - sequence_id_map_ = (int32_t*)allocator_->malloc(sizeof(int32_t) * max_batch_size_ * max_seq_len_, false); + q_buf_ = (int8_t*)allocator_->reMalloc( + q_buf_, sizeof(int8_t) * max_batch_size_ * padded_max_seq_len * hidden_units_ * 3, false); + k_buf_ = q_buf_ + max_batch_size_ * padded_max_seq_len * hidden_units_; + v_buf_ = k_buf_ + max_batch_size_ * padded_max_seq_len * hidden_units_; + qk_buf_ = (int8_t*)allocator_->reMalloc( + qk_buf_, sizeof(int8_t) * max_batch_size_ * head_num_ * padded_max_seq_len * padded_max_seq_len, false); + dst_ = + (int8_t*)allocator_->reMalloc(dst_, sizeof(int8_t) * max_batch_size_ * max_seq_len_ * hidden_units_, false); + + Q_int_buf_ = (int32_t*)allocator_->reMalloc( + Q_int_buf_, sizeof(int32_t) * max_batch_size_ * max_seq_len_ * hidden_units_ * 3, 
false); + V_int_buf_ = Q_int_buf_ + max_batch_size_ * max_seq_len_ * hidden_units_; + K_int_buf_ = V_int_buf_ + max_batch_size_ * max_seq_len_ * hidden_units_; + qk_int_buf_ = (int32_t*)allocator_->reMalloc(qk_int_buf_, + sizeof(int32_t) * max_batch_size_ * head_num_ * padded_max_seq_len + * padded_max_seq_len, + false); + transpose_dst_int_buf_ = (int32_t*)allocator_->reMalloc( + transpose_dst_int_buf_, sizeof(int32_t) * max_batch_size_ * max_seq_len_ * hidden_units_, false); + + sequence_id_map_ = + (int32_t*)allocator_->reMalloc(sequence_id_map_, sizeof(int32_t) * max_batch_size_ * max_seq_len_, false); is_allocate_buffer_ = true; } @@ -676,13 +717,13 @@ template void UnfusedAttentionLayerINT8::freeBuffer() { if (is_allocate_buffer_ == true) { - allocator_->free(q_buf_); - allocator_->free(Q_int_buf_); - allocator_->free(qk_buf_); - allocator_->free(qk_int_buf_); - allocator_->free(dst_); - allocator_->free(transpose_dst_int_buf_); - allocator_->free(sequence_id_map_); + allocator_->free((void**)(&q_buf_)); + allocator_->free((void**)(&Q_int_buf_)); + allocator_->free((void**)(&qk_buf_)); + allocator_->free((void**)(&qk_int_buf_)); + allocator_->free((void**)(&dst_)); + allocator_->free((void**)(&transpose_dst_int_buf_)); + allocator_->free((void**)(&sequence_id_map_)); is_allocate_buffer_ = false; } } diff --git a/src/fastertransformer/layers/attention_layers_int8/UnfusedAttentionLayerINT8.h b/src/fastertransformer/layers/attention_layers_int8/UnfusedAttentionLayerINT8.h index c95884859..9169c57ed 100644 --- a/src/fastertransformer/layers/attention_layers_int8/UnfusedAttentionLayerINT8.h +++ b/src/fastertransformer/layers/attention_layers_int8/UnfusedAttentionLayerINT8.h @@ -29,7 +29,7 @@ class UnfusedAttentionLayerINT8: public BaseAttentionLayer { private: // buffer handling size_t max_batch_size_ = 0; - size_t max_seq_len_ = 0; + size_t max_seq_len_ = 0; // metadata size_t head_num_; @@ -37,8 +37,8 @@ class UnfusedAttentionLayerINT8: public BaseAttentionLayer { size_t hidden_units_; float q_scaling_; - int int8_mode_; - bool sparse_; + int int8_mode_; + bool sparse_; void allocateBuffer() override; void freeBuffer() override; @@ -52,29 +52,29 @@ class UnfusedAttentionLayerINT8: public BaseAttentionLayer { using BaseAttentionLayer::allocator_; protected: - int8_t *q_buf_, *k_buf_, *v_buf_, *qk_buf_, *dst_; + int8_t * q_buf_, *k_buf_, *v_buf_, *qk_buf_, *dst_; int32_t *Q_int_buf_, *V_int_buf_, *K_int_buf_, *qk_int_buf_, *transpose_dst_int_buf_, *sequence_id_map_; public: - UnfusedAttentionLayerINT8(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - float q_scaling, - int int8_mode, - cudaStream_t stream, + UnfusedAttentionLayerINT8(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + float q_scaling, + int int8_mode, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse = false); + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse = false); UnfusedAttentionLayerINT8(UnfusedAttentionLayerINT8 const& attention_layer); ~UnfusedAttentionLayerINT8(); - void forward(std::vector* output_tensors, + void forward(std::vector* output_tensors, const std::vector* input_tensors, - const AttentionWeight* attention_weights) override; + const AttentionWeight* attention_weights) override; }; } // namespace fastertransformer diff --git a/src/fastertransformer/layers/attention_layers_int8/WindowAttentionINT8.cu 
b/src/fastertransformer/layers/attention_layers_int8/WindowAttentionINT8.cu index 74037aaf8..eb29e459f 100644 --- a/src/fastertransformer/layers/attention_layers_int8/WindowAttentionINT8.cu +++ b/src/fastertransformer/layers/attention_layers_int8/WindowAttentionINT8.cu @@ -24,27 +24,27 @@ namespace fastertransformer { // block(8, 32) // size should be a multiple of 4 template -__global__ void swin_trt_add_QKV_bias_COL32_int8IO(char4* output, +__global__ void swin_trt_add_QKV_bias_COL32_int8IO(char4* output, const char4* QKV, - const T* bias_Q, - const T* bias_K, - const T* bias_V, + const T* bias_Q, + const T* bias_K, + const T* bias_V, const float* q_bias_QFactor_ptr, const float* k_bias_QFactor_ptr, const float* v_bias_QFactor_ptr, - const float qkv_deQFactor, - const int valid_word_num, - const int head_num, - const int size_per_head, - const int head_num_x_size_per_head) + const float qkv_deQFactor, + const int valid_word_num, + const int head_num, + const int size_per_head, + const int head_num_x_size_per_head) { - const int qkv_id = blockIdx.z; - const int seq_id = (blockIdx.y << 5) + threadIdx.y; + const int qkv_id = blockIdx.z; + const int seq_id = (blockIdx.y << 5) + threadIdx.y; const int threadIdx4 = threadIdx.x << 2; - const int hidden_id = (blockIdx.x << 5) + threadIdx4; - const int size_id = hidden_id % size_per_head; - const int head_id = hidden_id / size_per_head; - const int col_id = qkv_id * head_num_x_size_per_head + hidden_id; + const int hidden_id = (blockIdx.x << 5) + threadIdx4; + const int size_id = hidden_id % size_per_head; + const int head_id = hidden_id / size_per_head; + const int col_id = qkv_id * head_num_x_size_per_head + hidden_id; const bool qual = (seq_id < valid_word_num) && (hidden_id < head_num_x_size_per_head); if (qual) { @@ -82,22 +82,22 @@ __global__ void swin_trt_add_QKV_bias_COL32_int8IO(char4* output, } template -void invokeSwinTrtAddQkvBiasInt8IO(int8_t* output, - const int8_t* Q, - const T* bias_Q, - const T* bias_K, - const T* bias_V, - const size_t token_num, - const size_t head_num, - const size_t size_per_head, - const float* q_bias_QFactor_ptr, - const float* k_bias_QFactor_ptr, - const float* v_bias_QFactor_ptr, - const float qkv_deQFactor, +void invokeSwinTrtAddQkvBiasInt8IO(int8_t* output, + const int8_t* Q, + const T* bias_Q, + const T* bias_K, + const T* bias_V, + const size_t token_num, + const size_t head_num, + const size_t size_per_head, + const float* q_bias_QFactor_ptr, + const float* k_bias_QFactor_ptr, + const float* v_bias_QFactor_ptr, + const float qkv_deQFactor, const cudaStream_t stream) { - int head_num_x_size_per_head = head_num * size_per_head; + int head_num_x_size_per_head = head_num * size_per_head; dim3 grid((head_num_x_size_per_head + 31) / 32, (token_num + 31) / 32, 3); dim3 block(8, 32); @@ -152,37 +152,44 @@ void WindowAttentionINT8::allocateBuffer() } if (is_allocate_buffer_ == false) { if (use_trt_) { - int S = trt_getS(window_len_); - trt_attention_mask_ = (half*)allocator_->malloc(roundByteSize(window_num_ * S * S * sizeof(T), 4), false); - trt_relative_position_bias_ = - (half*)allocator_->malloc(roundByteSize(num_head * S * S * sizeof(T), 4), false); - Q_buf_ = (int8_t*)allocator_->malloc( - 3 * max_batch_ * patches_resolution_ * patches_resolution_ * embed_dim_ * sizeof(int8_t), false); - K_buf_ = Q_buf_ + max_batch_ * patches_resolution_ * patches_resolution_ * embed_dim_; - V_buf_ = K_buf_ + max_batch_ * patches_resolution_ * patches_resolution_ * embed_dim_; - q_buf_ = (int8_t*)allocator_->malloc( - 3 
* max_batch_ * patches_resolution_ * patches_resolution_ * embed_dim_ * sizeof(int8_t), false); - k_buf_ = q_buf_ + max_batch_ * patches_resolution_ * patches_resolution_ * embed_dim_; - v_buf_ = k_buf_ + max_batch_ * patches_resolution_ * patches_resolution_ * embed_dim_; - dst_buf_ = (int8_t*)allocator_->malloc( - max_batch_ * patches_resolution_ * patches_resolution_ * embed_dim_ * sizeof(int8_t), false); + int S = trt_getS(window_len_); + trt_attention_mask_ = (half*)allocator_->reMalloc( + trt_attention_mask_, roundByteSize(window_num_ * S * S * sizeof(T), 4), false); + trt_relative_position_bias_ = (half*)allocator_->reMalloc( + trt_relative_position_bias_, roundByteSize(num_head * S * S * sizeof(T), 4), false); + Q_buf_ = (int8_t*)allocator_->reMalloc(Q_buf_, + 3 * max_batch_ * patches_resolution_ * patches_resolution_ + * embed_dim_ * sizeof(int8_t), + false); + K_buf_ = Q_buf_ + max_batch_ * patches_resolution_ * patches_resolution_ * embed_dim_; + V_buf_ = K_buf_ + max_batch_ * patches_resolution_ * patches_resolution_ * embed_dim_; + q_buf_ = (int8_t*)allocator_->reMalloc(q_buf_, + 3 * max_batch_ * patches_resolution_ * patches_resolution_ + * embed_dim_ * sizeof(int8_t), + false); + k_buf_ = q_buf_ + max_batch_ * patches_resolution_ * patches_resolution_ * embed_dim_; + v_buf_ = k_buf_ + max_batch_ * patches_resolution_ * patches_resolution_ * embed_dim_; + dst_buf_ = (int8_t*)allocator_->reMalloc( + dst_buf_, max_batch_ * patches_resolution_ * patches_resolution_ * embed_dim_ * sizeof(int8_t), false); } else { int padded_winlen = (window_len_ + 31) / 32 * 32; - Q_buf_ = (int8_t*)allocator_->malloc( - 3 * max_batch_ * patches_resolution_ * patches_resolution_ * embed_dim_ * sizeof(int8_t), false); - K_buf_ = Q_buf_ + max_batch_ * patches_resolution_ * patches_resolution_ * embed_dim_; - V_buf_ = K_buf_ + max_batch_ * patches_resolution_ * patches_resolution_ * embed_dim_; - q_buf_ = (int8_t*)allocator_->malloc(max_batch_ * window_num_ * window_len_ * embed_dim_ * sizeof(int8_t), - false); - k_buf_ = (int8_t*)allocator_->malloc(max_batch_ * window_num_ * padded_winlen * embed_dim_ * sizeof(int8_t), - false); - v_buf_ = (int8_t*)allocator_->malloc(max_batch_ * window_num_ * padded_winlen * embed_dim_ * sizeof(int8_t), - false); - qk_buf_ = - (int8_t*)allocator_->malloc(max_batch_ * window_num_ * num_head * window_len_ * padded_winlen, false); - dst_buf_ = (int8_t*)allocator_->malloc(max_batch_ * window_num_ * window_len_ * embed_dim_ * sizeof(int8_t), + Q_buf_ = (int8_t*)allocator_->reMalloc(Q_buf_, + 3 * max_batch_ * patches_resolution_ * patches_resolution_ + * embed_dim_ * sizeof(int8_t), false); + K_buf_ = Q_buf_ + max_batch_ * patches_resolution_ * patches_resolution_ * embed_dim_; + V_buf_ = K_buf_ + max_batch_ * patches_resolution_ * patches_resolution_ * embed_dim_; + q_buf_ = (int8_t*)allocator_->reMalloc( + q_buf_, max_batch_ * window_num_ * window_len_ * embed_dim_ * sizeof(int8_t), false); + k_buf_ = (int8_t*)allocator_->reMalloc( + k_buf_, max_batch_ * window_num_ * padded_winlen * embed_dim_ * sizeof(int8_t), false); + v_buf_ = (int8_t*)allocator_->reMalloc( + v_buf_, max_batch_ * window_num_ * padded_winlen * embed_dim_ * sizeof(int8_t), false); + qk_buf_ = (int8_t*)allocator_->reMalloc( + qk_buf_, max_batch_ * window_num_ * num_head * window_len_ * padded_winlen, false); + dst_buf_ = (int8_t*)allocator_->reMalloc( + dst_buf_, max_batch_ * window_num_ * window_len_ * embed_dim_ * sizeof(int8_t), false); } is_allocate_buffer_ = true; } @@ -193,34 +200,34 @@ void 
WindowAttentionINT8::freeBuffer() { if (is_allocate_buffer_ == true) { if (use_trt_) { - allocator_->free(trt_attention_mask_); - allocator_->free(trt_relative_position_bias_); - allocator_->free(Q_buf_); - allocator_->free(q_buf_); - allocator_->free(dst_buf_); + allocator_->free((void**)(&trt_attention_mask_)); + allocator_->free((void**)(&trt_relative_position_bias_)); + allocator_->free((void**)(&Q_buf_)); + allocator_->free((void**)(&q_buf_)); + allocator_->free((void**)(&dst_buf_)); } else { - allocator_->free(Q_buf_); - allocator_->free(q_buf_); - allocator_->free(k_buf_); - allocator_->free(v_buf_); - allocator_->free(dst_buf_); + allocator_->free((void**)(&Q_buf_)); + allocator_->free((void**)(&q_buf_)); + allocator_->free((void**)(&k_buf_)); + allocator_->free((void**)(&v_buf_)); + allocator_->free((void**)(&dst_buf_)); } is_allocate_buffer_ = false; } } template -WindowAttentionINT8::WindowAttentionINT8(int max_batch, - int window_size, - int patches_resolution, - int embed_dim, - cudaStream_t stream, +WindowAttentionINT8::WindowAttentionINT8(int max_batch, + int window_size, + int patches_resolution, + int embed_dim, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool qkv_bias, - float qk_scale): + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool qkv_bias, + float qk_scale): BaseAttentionLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward), patches_resolution_(patches_resolution), embed_dim_(embed_dim), @@ -238,9 +245,9 @@ WindowAttentionINT8::~WindowAttentionINT8() } template -void WindowAttentionINT8::forward(std::vector* output_tensors, +void WindowAttentionINT8::forward(std::vector* output_tensors, const std::vector* input_tensors, - const AttentionWeight* attention_weights) + const AttentionWeight* attention_weights) { // input_tensors: // input [batch * window_num * window_len, dim] @@ -250,32 +257,32 @@ void WindowAttentionINT8::forward(std::vector* out // output_tensors: // output [batch * window_num * window_len, dim] - cublasINT8MMWrapper* cublas_wrapper = (cublasINT8MMWrapper*)cublas_wrapper_; - int8_t* attention_out = (int8_t*)output_tensors->at(0).data; - const int8_t* from_tensor = (const int8_t*)input_tensors->at(0).data; - const T* attention_mask = (const T*)input_tensors->at(1).data; - const T* attention_relative_pos_bias = (const T*)input_tensors->at(2).data; - const int* input_parameters = (const int*)input_tensors->at(3).data; - const int batch = input_parameters[0]; - const int dim = input_parameters[1]; - const int input_resolution = input_parameters[2]; - const int num_head = input_parameters[3]; - const int shift_size = input_parameters[4]; - const int sm = input_parameters[5]; + cublasINT8MMWrapper* cublas_wrapper = (cublasINT8MMWrapper*)cublas_wrapper_; + int8_t* attention_out = (int8_t*)output_tensors->at(0).data; + const int8_t* from_tensor = (const int8_t*)input_tensors->at(0).data; + const T* attention_mask = (const T*)input_tensors->at(1).data; + const T* attention_relative_pos_bias = (const T*)input_tensors->at(2).data; + const int* input_parameters = (const int*)input_tensors->at(3).data; + const int batch = input_parameters[0]; + const int dim = input_parameters[1]; + const int input_resolution = input_parameters[2]; + const int num_head = input_parameters[3]; + const int shift_size = input_parameters[4]; + const int sm = input_parameters[5]; int use_ORDER_COL32_2R_4R4 = (sm >= 80 ? 
1 : 0); int size_per_head = dim / num_head; - int trt_S = 1024; + int trt_S = 1024; if ((sm == 75 || sm == 80 || sm == 86) && size_per_head == 32 && window_len_ <= TRT_MAX_LEN && std::is_same::value) { - trt_S = trt_getS(window_len_); + trt_S = trt_getS(window_len_); use_trt_ = true; } - num_head_ = num_head; + num_head_ = num_head; patches_resolution_ = input_resolution; - window_num_ = (input_resolution / window_size_) * (input_resolution / window_size_); - embed_dim_ = dim; + window_num_ = (input_resolution / window_size_) * (input_resolution / window_size_); + embed_dim_ = dim; allocateBuffer(); float scale = 1.0f / sqrt(size_per_head); diff --git a/src/fastertransformer/layers/attention_layers_int8/WindowAttentionINT8.h b/src/fastertransformer/layers/attention_layers_int8/WindowAttentionINT8.h index ffb291009..7df57c111 100644 --- a/src/fastertransformer/layers/attention_layers_int8/WindowAttentionINT8.h +++ b/src/fastertransformer/layers/attention_layers_int8/WindowAttentionINT8.h @@ -39,26 +39,26 @@ namespace fastertransformer { template class WindowAttentionINT8: public BaseAttentionLayer { private: - int max_batch_ = 1; - int dim_ = 96; - int num_head_ = 2; - int window_size_ = 7; - int head_dim_ = 48; - int input_resolution_ = 56; - int window_len_ = 49; - int patches_resolution_ = 56; - int embed_dim_ = 96; - int window_num_ = 64; - int size_per_head_; - bool qkv_bias_ = true; - float qk_scale_ = 1.0f; + int max_batch_ = 1; + int dim_ = 96; + int num_head_ = 2; + int window_size_ = 7; + int head_dim_ = 48; + int input_resolution_ = 56; + int window_len_ = 49; + int patches_resolution_ = 56; + int embed_dim_ = 96; + int window_num_ = 64; + int size_per_head_; + bool qkv_bias_ = true; + float qk_scale_ = 1.0f; size_t max_buf_size_ = 0; - bool use_trt_ = false; + bool use_trt_ = false; void allocateBuffer() override; void freeBuffer() override; - int dispatcher_int8_num_head_ = -1; + int dispatcher_int8_num_head_ = -1; std::unique_ptr dispatcher_int8_; using BaseAttentionLayer::stream_; @@ -77,22 +77,22 @@ class WindowAttentionINT8: public BaseAttentionLayer { static size_t roundByteSize(const size_t size, const int factor); public: - WindowAttentionINT8(int max_batch, - int window_size, - int patches_resolution, - int embed_dim, - cudaStream_t stream, + WindowAttentionINT8(int max_batch, + int window_size, + int patches_resolution, + int embed_dim, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward = false, - bool qkv_bias = true, - float qk_scale = 1.0f); + IAllocator* allocator, + bool is_free_buffer_after_forward = false, + bool qkv_bias = true, + float qk_scale = 1.0f); ~WindowAttentionINT8(); - void forward(std::vector* output_tensors, + void forward(std::vector* output_tensors, const std::vector* input_tensors, - const AttentionWeight* attention_weights); + const AttentionWeight* attention_weights); }; // class WindowAttentionINT8 } // namespace fastertransformer diff --git a/src/fastertransformer/layers/beam_search_layers/BaseBeamSearchLayer.cu b/src/fastertransformer/layers/beam_search_layers/BaseBeamSearchLayer.cu index b6b844cc0..e6e011d2f 100644 --- a/src/fastertransformer/layers/beam_search_layers/BaseBeamSearchLayer.cu +++ b/src/fastertransformer/layers/beam_search_layers/BaseBeamSearchLayer.cu @@ -20,51 +20,57 @@ namespace fastertransformer { -__global__ void update_indir_cache_kernel(int* tgt_indir_cache, - const int* src_indir_cache, - const int* beam_ids, +__global__ void 
update_indir_cache_kernel(int* tgt_indir_cache, + const int* src_indir_cache, + const int* beam_ids, const bool* finished, - int batch_dim, - int local_batch_size, - int beam_width, - int max_seq_len, - int step) + int start_step, + int batch_dim, + int local_batch_size, + int beam_width, + int max_seq_len, + int step) { - int time_step = threadIdx.x + blockIdx.x * blockDim.x; - int bb_id = threadIdx.y + blockIdx.y * blockDim.y; - const int batch_id = bb_id / beam_width; - const int beam_id = bb_id % beam_width; + int time_step = threadIdx.x + blockIdx.x * blockDim.x; + int bb_id = threadIdx.y + blockIdx.y * blockDim.y; + const int batch_id = bb_id / beam_width; + const int beam_id = bb_id % beam_width; if (bb_id >= beam_width * local_batch_size || time_step >= min(step + 1, max_seq_len) || finished[bb_id]) { return; } + time_step += start_step; + const int time_step_circ = time_step % max_seq_len; const int src_beam = beam_ids[batch_id * beam_width + beam_id]; - const uint tgt_offset = batch_id * beam_width * max_seq_len + beam_id * max_seq_len + time_step; - const uint src_offset = batch_id * beam_width * max_seq_len + src_beam * max_seq_len + time_step; + const uint tgt_offset = batch_id * beam_width * max_seq_len + beam_id * max_seq_len + time_step_circ; + const uint src_offset = batch_id * beam_width * max_seq_len + src_beam * max_seq_len + time_step_circ; tgt_indir_cache[tgt_offset] = (time_step == step) ? beam_id : src_indir_cache[src_offset]; } -void update_indir_cache_kernelLauncher(int* tgt_indir_cache, - const int* src_indir_cache, - const int* beam_ids, - const bool* finished, - int batch_dim, - int local_batch_size, - int beam_width, - int max_seq_len, - int step, +void update_indir_cache_kernelLauncher(int* tgt_indir_cache, + const int* src_indir_cache, + const int* beam_ids, + const bool* finished, + int batch_dim, + int local_batch_size, + int beam_width, + int max_seq_len, + int step, cudaStream_t stream) { const dim3 block(32); - // Update indirections steps [0, step], included - const dim3 grid((step + 1 + block.x - 1) / block.x, local_batch_size * beam_width); + const int start_step = max(0, step + 1 - max_seq_len); + const int num_steps = min(step + 1, max_seq_len); + // Update indirections steps [start_step, step], included + const dim3 grid((num_steps + block.x - 1) / block.x, local_batch_size * beam_width); update_indir_cache_kernel<<>>(tgt_indir_cache, src_indir_cache, beam_ids, finished, + start_step, batch_dim, local_batch_size, beam_width, @@ -73,21 +79,21 @@ void update_indir_cache_kernelLauncher(int* tgt_indir_cache, } template -BaseBeamSearchLayer::BaseBeamSearchLayer(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - size_t beam_width, - size_t vocab_size, - size_t vocab_size_padded, - int end_id, - float diversity_rate, - float temperature, - float len_penalty, - float repetition_penalty, - cudaStream_t stream, +BaseBeamSearchLayer::BaseBeamSearchLayer(size_t max_batch_size, + size_t head_num, + size_t size_per_head, + size_t beam_width, + size_t vocab_size, + size_t vocab_size_padded, + int end_id, + float diversity_rate, + float temperature, + float len_penalty, + float repetition_penalty, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward): + IAllocator* allocator, + bool is_free_buffer_after_forward): DynamicDecodeBaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, nullptr), vocab_size_(vocab_size), vocab_size_padded_(vocab_size_padded) @@ -106,6 
+112,7 @@ BaseBeamSearchLayer::BaseBeamSearchLayer(BaseBeamSearchLayer const& beam_s template BaseBeamSearchLayer::~BaseBeamSearchLayer() { + FT_LOG_DEBUG(__PRETTY_FUNCTION__); freeBuffer(); } @@ -113,11 +120,19 @@ template void BaseBeamSearchLayer::freeBuffer() { if (is_allocate_buffer_) { - allocator_->free(topk_softmax_workspace_); + allocator_->free((void**)(&topk_softmax_workspace_)); is_allocate_buffer_ = false; } } +template +void BaseBeamSearchLayer::setup(const size_t batch_size, + const size_t beam_width, + const std::unordered_map* runtime_args) +{ + // do nothing. +} + template void BaseBeamSearchLayer::forward(std::vector* output_tensors, const std::vector* input_tensors) { @@ -156,7 +171,7 @@ void BaseBeamSearchLayer::forward(std::vector* output_tensors, const } template -void BaseBeamSearchLayer::forward(std::unordered_map* output_tensors, +void BaseBeamSearchLayer::forward(std::unordered_map* output_tensors, const std::unordered_map* input_tensors) { // input_tensors: @@ -188,16 +203,12 @@ void BaseBeamSearchLayer::forward(std::unordered_map* ou const int beam_width = output_tensors->at("output_ids").shape[2]; allocateBuffer(batch_size, beam_width); - const int step = *((int*)input_tensors->at("step").data); - const int ite = *((int*)input_tensors->at("ite").data); + const int step = *((int*)input_tensors->at("step").data); + const int ite = *((int*)input_tensors->at("ite").data); const int local_batch_size = input_tensors->at("logits").shape[0]; - const int head_num = input_tensors->at("src_value_cache").shape[2]; - const int size_per_head = input_tensors->at("src_value_cache").shape[4]; const float temperature = input_tensors->count("temperature") ? input_tensors->at("temperature").getVal() : 1.0f; - const float len_penalty = - input_tensors->count("len_penalty") ? input_tensors->at("len_penalty").getVal() : 1.0f; const float repetition_penalty = input_tensors->count("repetition_penalty") ? 
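
// Illustrative sketch, not part of this patch: the start_step / time_step_circ arithmetic
// introduced in update_indir_cache_kernel above turns the per-beam indirection cache into a
// ring buffer over max_seq_len slots once the generation step exceeds the cache length (the
// CUDA kernel additionally skips finished beams). A minimal CPU restatement of that indexing,
// with ad hoc names and the flat [batch, beam, max_seq_len] layout implied by the offsets:

#include <algorithm>

void update_indir_cache_reference(int* tgt_indir_cache, const int* src_indir_cache,
                                  const int* beam_ids, int batch_id, int beam_id,
                                  int beam_width, int max_seq_len, int step)
{
    // Only positions [start_step, step] can have changed; step t lives in slot t % max_seq_len.
    const int start_step = std::max(0, step + 1 - max_seq_len);
    const int src_beam   = beam_ids[batch_id * beam_width + beam_id];
    for (int t = start_step; t <= step; ++t) {
        const int slot       = t % max_seq_len;
        const int tgt_offset = (batch_id * beam_width + beam_id) * max_seq_len + slot;
        const int src_offset = (batch_id * beam_width + src_beam) * max_seq_len + slot;
        // The newest step records the chosen beam; older steps are copied from the parent beam.
        tgt_indir_cache[tgt_offset] = (t == step) ? beam_id : src_indir_cache[src_offset];
    }
}
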
input_tensors->at("repetition_penalty").getVal() : 1.0f; @@ -218,7 +229,6 @@ void BaseBeamSearchLayer::forward(std::unordered_map* ou vocab_size_padded_, (const int*)input_tensors->at("end_id").data, temperature, - len_penalty, repetition_penalty, stream_); sync_check_cuda_error(); diff --git a/src/fastertransformer/layers/beam_search_layers/BaseBeamSearchLayer.h b/src/fastertransformer/layers/beam_search_layers/BaseBeamSearchLayer.h index 73d4065d6..6a7fa1c0c 100644 --- a/src/fastertransformer/layers/beam_search_layers/BaseBeamSearchLayer.h +++ b/src/fastertransformer/layers/beam_search_layers/BaseBeamSearchLayer.h @@ -31,50 +31,53 @@ class BaseBeamSearchLayer: public DynamicDecodeBaseLayer { size_t vocab_size_padded_; size_t topk_softmax_workspace_size_; - void* topk_softmax_workspace_ = nullptr; + void* topk_softmax_workspace_ = nullptr; - virtual void allocateBuffer() = 0; - virtual void allocateBuffer(size_t batch_size, size_t beam_width) = 0; - virtual void invokeSoftMax(std::vector* output_tensors, - const std::vector* input_tensors) = 0; - virtual void invokeSoftMax(std::unordered_map* output_tensors, + virtual void allocateBuffer() = 0; + virtual void allocateBuffer(size_t batch_size, size_t beam_width) = 0; + virtual void invokeSoftMax(std::vector* output_tensors, + const std::vector* input_tensors) = 0; + virtual void invokeSoftMax(std::unordered_map* output_tensors, const std::unordered_map* input_tensors) = 0; public: - BaseBeamSearchLayer(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - size_t beam_width, - size_t vocab_size, - size_t vocab_size_padded, - int end_id, - float diversity_rate, - float temperature, - float len_penalty, - float repetition_penalty, - cudaStream_t stream, + BaseBeamSearchLayer(size_t max_batch_size, + size_t head_num, + size_t size_per_head, + size_t beam_width, + size_t vocab_size, + size_t vocab_size_padded, + int end_id, + float diversity_rate, + float temperature, + float len_penalty, + float repetition_penalty, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward); + IAllocator* allocator, + bool is_free_buffer_after_forward); BaseBeamSearchLayer(BaseBeamSearchLayer const& beam_search_layer); ~BaseBeamSearchLayer(); - void forward(std::vector* output_tensors, + void setup(const size_t batch_size, + const size_t beam_width, + const std::unordered_map* runtime_args) override; + void forward(std::vector* output_tensors, const std::vector* input_tensors) override; - void forward(std::unordered_map* output_tensors, + void forward(std::unordered_map* output_tensors, const std::unordered_map* input_tensors) override; }; -void update_indir_cache_kernelLauncher(int* tgt_indir_cache, - const int* src_indir_cache, - const int* beam_ids, - const bool* finished, - int batch_dim, - int beam_width, - int max_seq_len, - int ite, +void update_indir_cache_kernelLauncher(int* tgt_indir_cache, + const int* src_indir_cache, + const int* beam_ids, + const bool* finished, + int batch_dim, + int beam_width, + int max_seq_len, + int ite, cudaStream_t stream); } // namespace fastertransformer diff --git a/src/fastertransformer/layers/beam_search_layers/BeamSearchLayer.cu b/src/fastertransformer/layers/beam_search_layers/BeamSearchLayer.cu index bc0f5543a..e2aa9d001 100644 --- a/src/fastertransformer/layers/beam_search_layers/BeamSearchLayer.cu +++ b/src/fastertransformer/layers/beam_search_layers/BeamSearchLayer.cu @@ -20,19 +20,19 @@ namespace fastertransformer { template -__global__ 
void logProbAddCumLogProb(float* log_probs, - const T* logits, +__global__ void logProbAddCumLogProb(float* log_probs, + const T* logits, const float* cum_log_probs, - const int* end_ids, - const bool* finished, - const int beam_width, - const int n) + const int* end_ids, + const bool* finished, + const int beam_width, + const int n) { - int bid = blockIdx.x; - bool finish = finished[bid]; - int offset = bid * n; + int bid = blockIdx.x; + bool finish = finished != nullptr ? finished[bid] : false; + int offset = bid * n; - float max_val = -1 * FLT_MAX; + float max_val = -1 * FLT_MAX; __shared__ float s_max_val; __shared__ float s_sum_val; @@ -44,7 +44,7 @@ __global__ void logProbAddCumLogProb(float* log_probs, else { for (int tid = threadIdx.x; tid < n; tid += blockDim.x) { log_probs[offset + tid] = (float)(logits[offset + tid]); - max_val = max(max_val, log_probs[offset + tid]); + max_val = max(max_val, log_probs[offset + tid]); } max_val = blockReduceMax(max_val); @@ -72,14 +72,14 @@ __global__ void logProbAddCumLogProb(float* log_probs, } template -void invokeLogProbAddCumLogProb(float* log_probs, - const T* logits, +void invokeLogProbAddCumLogProb(float* log_probs, + const T* logits, const float* cum_log_probs, - const int* end_ids, - const bool* finished, - const int m, - const int beam_width, - const int n, + const int* end_ids, + const bool* finished, + const int m, + const int beam_width, + const int n, cudaStream_t stream) { dim3 grid(m); @@ -90,23 +90,23 @@ void invokeLogProbAddCumLogProb(float* log_probs, } template -__global__ void updateStatesKernel(T* log_probs, - T* cum_log_probs, - float* output_log_probs, - bool* finished, - int* parent_ids, - int* sequence_length, - int* word_ids, - int* output_ids, - const int local_batch_size, - const int beam_width, - const int vocab_size, +__global__ void updateStatesKernel(T* log_probs, + T* cum_log_probs, + float* output_log_probs, + bool* finished, + int* parent_ids, + int* sequence_length, + int* word_ids, + int* output_ids, + const int local_batch_size, + const int beam_width, + const int vocab_size, const int* end_ids) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < local_batch_size * beam_width; index += blockDim.x * gridDim.x) { - int batch_id = index / beam_width; + int batch_id = index / beam_width; sequence_length[index] = finished[index] ? sequence_length[index] : sequence_length[index] + 1; int beam_id = (word_ids[index] / vocab_size) % beam_width; @@ -117,27 +117,27 @@ __global__ void updateStatesKernel(T* log_probs, output_log_probs[index] = log_probs[batch_id * beam_width * vocab_size + beam_id * vocab_size + word_id] - cum_log_probs[batch_id * beam_width + beam_id]; } - cum_log_probs[index] = log_probs[batch_id * beam_width * vocab_size + beam_id * vocab_size + word_id]; + cum_log_probs[index] = log_probs[batch_id * beam_width * vocab_size + beam_id * vocab_size + word_id]; sequence_length[index] = sequence_length[batch_id * beam_width + beam_id]; - finished[index] = word_id == end_ids[batch_id] ? 1 : 0; - parent_ids[index] = beam_id; - word_ids[index] = word_id; - output_ids[index] = word_id; + finished[index] = word_id == end_ids[batch_id] ? 
1 : 0; + parent_ids[index] = beam_id; + word_ids[index] = word_id; + output_ids[index] = word_id; } } -void invokeUpdateStates(float* log_probs, - float* cum_log_probs, - float* output_log_probs, - bool* finished, - int* parent_ids, - int* sequence_length, - int* word_ids, - int* output_ids, - const int local_batch_size, - const int beam_width, - const int vocab_size, - const int* end_ids, +void invokeUpdateStates(float* log_probs, + float* cum_log_probs, + float* output_log_probs, + bool* finished, + int* parent_ids, + int* sequence_length, + int* word_ids, + int* output_ids, + const int local_batch_size, + const int beam_width, + const int vocab_size, + const int* end_ids, cudaStream_t stream) { dim3 grid((int)ceil(local_batch_size * beam_width * 1.0 / 256)); @@ -195,7 +195,7 @@ void BeamSearchLayer::invokeSoftMax(std::vector* output_tensors, cons } template -void BeamSearchLayer::invokeSoftMax(std::unordered_map* output_tensors, +void BeamSearchLayer::invokeSoftMax(std::unordered_map* output_tensors, const std::unordered_map* input_tensors) { // input_tensors: @@ -223,14 +223,16 @@ void BeamSearchLayer::invokeSoftMax(std::unordered_map* FT_CHECK(input_tensors->size() >= 7); FT_CHECK(output_tensors->size() >= 6); - const int batch_size = output_tensors->at("output_ids").shape[1]; - const int beam_width = output_tensors->at("output_ids").shape[2]; - const int step = *((int*)input_tensors->at("step").data); - const int ite = *((int*)input_tensors->at("ite").data); - const int local_batch_size = input_tensors->at("logits").shape[0]; - const float diversity_rate = input_tensors->count("beam_search_diversity_rate") ? - input_tensors->at("beam_search_diversity_rate").getVal() : - 0.0f; + const int batch_size = output_tensors->at("output_ids").shape[1]; + const int beam_width = output_tensors->at("output_ids").shape[2]; + const int step = *((int*)input_tensors->at("step").data); + const int ite = *((int*)input_tensors->at("ite").data); + const int local_batch_size = input_tensors->at("logits").shape[0]; + const float diversity_rate = input_tensors->count("beam_search_diversity_rate") ? + input_tensors->at("beam_search_diversity_rate").getVal() : + 0.0f; + const float length_penalty = + input_tensors->count("len_penalty") ? input_tensors->at("len_penalty").getVal() : 0.0f; const int id_offset = step * batch_size * beam_width + ite * local_batch_size * beam_width; invokeLogProbAddCumLogProb(float_log_prob_buf_, @@ -244,17 +246,20 @@ void BeamSearchLayer::invokeSoftMax(std::unordered_map* stream_); sync_check_cuda_error(); - invokeTopkBeamSearch(topk_softmax_workspace_, - topk_softmax_workspace_size_, - float_log_prob_buf_, - ((int*)output_tensors->at("output_ids").data) + id_offset, - (bool*)output_tensors->at("finished").data, - local_batch_size, - beam_width, - vocab_size_padded_, - diversity_rate, - (const int*)input_tensors->at("end_id").data, - stream_); + invokeTopkBeamSearch( + topk_softmax_workspace_, + topk_softmax_workspace_size_, + float_log_prob_buf_, + output_tensors->at("output_ids").getPtrWithOffset(id_offset), + output_tensors->at("finished").getPtr(), + output_tensors->count("sequence_length") ? 
output_tensors->at("sequence_length").getPtr() : (int*)nullptr, + local_batch_size, + beam_width, + vocab_size_padded_, + diversity_rate, + length_penalty, + (const int*)input_tensors->at("end_id").data, + stream_); sync_check_cuda_error(); float* output_log_probs = @@ -291,36 +296,38 @@ void BeamSearchLayer::allocateBuffer(size_t batch_size, size_t beam_width) nullptr, nullptr, nullptr, + nullptr, batch_size, beam_width, vocab_size_padded_, - 0.0f, + 0.0f, // diversity rate + 0.0f, // length penalty nullptr, stream_); topk_softmax_workspace_ = reinterpret_cast(allocator_->reMalloc( topk_softmax_workspace_, topk_softmax_workspace_size_ + sizeof(float) * batch_size * beam_width * vocab_size_padded_, false)); - float_log_prob_buf_ = (float*)((char*)topk_softmax_workspace_ + topk_softmax_workspace_size_); - is_allocate_buffer_ = true; + float_log_prob_buf_ = (float*)((char*)topk_softmax_workspace_ + topk_softmax_workspace_size_); + is_allocate_buffer_ = true; } template -BeamSearchLayer::BeamSearchLayer(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - size_t beam_width, - size_t vocab_size, - size_t vocab_size_padded, - int end_id, - float diversity_rate, - float temperature, - float len_penalty, - float repetition_penalty, - cudaStream_t stream, +BeamSearchLayer::BeamSearchLayer(size_t max_batch_size, + size_t head_num, + size_t size_per_head, + size_t beam_width, + size_t vocab_size, + size_t vocab_size_padded, + int end_id, + float diversity_rate, + float temperature, + float len_penalty, + float repetition_penalty, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward): + IAllocator* allocator, + bool is_free_buffer_after_forward): BaseBeamSearchLayer(max_batch_size, head_num, size_per_head, @@ -348,9 +355,10 @@ BeamSearchLayer::BeamSearchLayer(BeamSearchLayer const& beam_search_layer) template BeamSearchLayer::~BeamSearchLayer() { + FT_LOG_DEBUG(__PRETTY_FUNCTION__); } template class BeamSearchLayer; template class BeamSearchLayer; -} // namespace fastertransformer \ No newline at end of file +} // namespace fastertransformer diff --git a/src/fastertransformer/layers/beam_search_layers/BeamSearchLayer.h b/src/fastertransformer/layers/beam_search_layers/BeamSearchLayer.h index 25d97861a..293f67ee4 100644 --- a/src/fastertransformer/layers/beam_search_layers/BeamSearchLayer.h +++ b/src/fastertransformer/layers/beam_search_layers/BeamSearchLayer.h @@ -34,9 +34,9 @@ class BeamSearchLayer: public BaseBeamSearchLayer { void allocateBuffer() override; void allocateBuffer(size_t batch_size, size_t beam_width) override; - void invokeSoftMax(std::vector* output_tensors, + void invokeSoftMax(std::vector* output_tensors, const std::vector* input_tensors) override; - void invokeSoftMax(std::unordered_map* output_tensors, + void invokeSoftMax(std::unordered_map* output_tensors, const std::unordered_map* input_tensors) override; using BaseBeamSearchLayer::stream_; @@ -47,21 +47,21 @@ class BeamSearchLayer: public BaseBeamSearchLayer { protected: public: - BeamSearchLayer(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - size_t beam_width, - size_t vocab_size, - size_t vocab_size_padded, - int end_id, - float diversity_rate, - float temperature, - float len_penalty, - float repetition_penalty, - cudaStream_t stream, + BeamSearchLayer(size_t max_batch_size, + size_t head_num, + size_t size_per_head, + size_t beam_width, + size_t vocab_size, + size_t vocab_size_padded, + int end_id, + float diversity_rate, 
+ float temperature, + float len_penalty, + float repetition_penalty, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward); + IAllocator* allocator, + bool is_free_buffer_after_forward); BeamSearchLayer(BeamSearchLayer const& beam_search_layer); diff --git a/src/fastertransformer/layers/beam_search_layers/OnlineBeamSearchLayer.cu b/src/fastertransformer/layers/beam_search_layers/OnlineBeamSearchLayer.cu index 5e237f221..4d71104e6 100644 --- a/src/fastertransformer/layers/beam_search_layers/OnlineBeamSearchLayer.cu +++ b/src/fastertransformer/layers/beam_search_layers/OnlineBeamSearchLayer.cu @@ -19,45 +19,45 @@ namespace fastertransformer { static const int SMALL_TOP_K_SOFTMAX_MAX_VOC_PARTS = 128; -static const int MAX_K = 4; +static const int MAX_K = 4; template -__global__ void update_kernel(bool* finished, - int* parent_ids, - int* sequence_length, - int* word_ids, - int* output_ids, - const int vocab_size, +__global__ void update_kernel(bool* finished, + int* parent_ids, + int* sequence_length, + int* word_ids, + int* output_ids, + const int vocab_size, const int* end_ids, - const int local_batch_size, - const int beam_width) + const int local_batch_size, + const int beam_width) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < local_batch_size * beam_width; index += blockDim.x * gridDim.x) { - int batch_id = index / beam_width; + int batch_id = index / beam_width; sequence_length[index] = finished[index] ? sequence_length[index] : sequence_length[index] + 1; int beam_id = (word_ids[index] / vocab_size) % beam_width; int word_id = word_ids[index] % vocab_size; sequence_length[index] = sequence_length[batch_id * beam_width + beam_id]; - finished[index] = word_id == end_ids[index / beam_width] ? 1 : 0; - parent_ids[index] = beam_id; - word_ids[index] = word_id; - output_ids[index] = word_id; + finished[index] = word_id == end_ids[index / beam_width] ? 
1 : 0; + parent_ids[index] = beam_id; + word_ids[index] = word_id; + output_ids[index] = word_id; } } -void invokeUpdate(bool* finished, - int* parent_ids, - int* sequence_length, - int* word_ids, - int* output_ids, - const int local_batch_size, - const int beam_width, - const int vocab_size_padded, - const int* end_ids, +void invokeUpdate(bool* finished, + int* parent_ids, + int* sequence_length, + int* word_ids, + int* output_ids, + const int local_batch_size, + const int beam_width, + const int vocab_size_padded, + const int* end_ids, cudaStream_t stream) { dim3 grid((int)ceil(local_batch_size * beam_width * 1.0 / 256)); @@ -75,7 +75,7 @@ void invokeUpdate(bool* finished, } template -void OnlineBeamSearchLayer::invokeSoftMax(std::vector* output_tensors, +void OnlineBeamSearchLayer::invokeSoftMax(std::vector* output_tensors, const std::vector* input_tensors) { // input_tensors: @@ -113,7 +113,7 @@ void OnlineBeamSearchLayer::invokeSoftMax(std::vector* output_tensors } template -void OnlineBeamSearchLayer::invokeSoftMax(std::unordered_map* output_tensors, +void OnlineBeamSearchLayer::invokeSoftMax(std::unordered_map* output_tensors, const std::unordered_map* input_tensors) { // input_tensors: @@ -141,30 +141,35 @@ void OnlineBeamSearchLayer::invokeSoftMax(std::unordered_mapsize() >= 7); FT_CHECK(output_tensors->size() >= 6); - const int batch_size = output_tensors->at("output_ids").shape[1]; - const int beam_width = output_tensors->at("output_ids").shape[2]; - const int step = *((int*)input_tensors->at("step").data); - const int ite = *((int*)input_tensors->at("ite").data); - const int local_batch_size = input_tensors->at("logits").shape[0]; - const float diversity_rate = input_tensors->count("beam_search_diversity_rate") ? - input_tensors->at("beam_search_diversity_rate").getVal() : - 0.0f; + const int batch_size = output_tensors->at("output_ids").shape[1]; + const int beam_width = output_tensors->at("output_ids").shape[2]; + const int step = *((int*)input_tensors->at("step").data); + const int ite = *((int*)input_tensors->at("ite").data); + const int local_batch_size = input_tensors->at("logits").shape[0]; + const float diversity_rate = input_tensors->count("beam_search_diversity_rate") ? + input_tensors->at("beam_search_diversity_rate").getVal() : + 0.0f; + const float length_penalty = + input_tensors->count("len_penalty") ? input_tensors->at("len_penalty").getVal() : 0.0f; + float* output_log_probs = output_tensors->count("output_log_probs") ? 
(float*)output_tensors->at("output_log_probs").data : nullptr; const int id_offset = step * batch_size * beam_width + local_batch_size * ite * beam_width; - invokeTopkSoftMax((const T*)input_tensors->at("logits").data, + invokeTopkSoftMax(input_tensors->at("logits").getPtr(), (const T*)(nullptr), - (const bool*)output_tensors->at("finished").data, - (float*)output_tensors->at("cum_log_probs").data, + output_tensors->at("finished").getPtr(), + output_tensors->at("sequence_length").getPtr(), + output_tensors->at("cum_log_probs").getPtr(), output_log_probs, - ((int*)output_tensors->at("output_ids").data) + id_offset, + output_tensors->at("output_ids").getPtrWithOffset(id_offset), topk_softmax_workspace_, topk_softmax_workspace_size_, local_batch_size, beam_width, vocab_size_padded_, - (const int*)input_tensors->at("end_id").data, + input_tensors->at("end_id").getPtr(), diversity_rate, + length_penalty, stream_); sync_check_cuda_error(); @@ -201,21 +206,21 @@ void OnlineBeamSearchLayer::allocateBuffer(size_t batch_size, size_t beam_wid } template -OnlineBeamSearchLayer::OnlineBeamSearchLayer(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - size_t beam_width, - size_t vocab_size, - size_t vocab_size_padded, - int end_id, - float diversity_rate, - float temperature, - float len_penalty, - float repetition_penalty, - cudaStream_t stream, +OnlineBeamSearchLayer::OnlineBeamSearchLayer(size_t max_batch_size, + size_t head_num, + size_t size_per_head, + size_t beam_width, + size_t vocab_size, + size_t vocab_size_padded, + int end_id, + float diversity_rate, + float temperature, + float len_penalty, + float repetition_penalty, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward): + IAllocator* allocator, + bool is_free_buffer_after_forward): BaseBeamSearchLayer(max_batch_size, head_num, size_per_head, @@ -238,14 +243,16 @@ template OnlineBeamSearchLayer::OnlineBeamSearchLayer(OnlineBeamSearchLayer const& beam_search_layer): BaseBeamSearchLayer(beam_search_layer) { + FT_LOG_DEBUG(__PRETTY_FUNCTION__); } template OnlineBeamSearchLayer::~OnlineBeamSearchLayer() { + FT_LOG_DEBUG(__PRETTY_FUNCTION__); } template class OnlineBeamSearchLayer; template class OnlineBeamSearchLayer; -} // namespace fastertransformer \ No newline at end of file +} // namespace fastertransformer diff --git a/src/fastertransformer/layers/beam_search_layers/OnlineBeamSearchLayer.h b/src/fastertransformer/layers/beam_search_layers/OnlineBeamSearchLayer.h index 506c03f8b..a753b57f0 100644 --- a/src/fastertransformer/layers/beam_search_layers/OnlineBeamSearchLayer.h +++ b/src/fastertransformer/layers/beam_search_layers/OnlineBeamSearchLayer.h @@ -33,9 +33,9 @@ class OnlineBeamSearchLayer: public BaseBeamSearchLayer { void allocateBuffer() override; void allocateBuffer(size_t batch_size, size_t beam_width) override; - void invokeSoftMax(std::vector* output_tensors, + void invokeSoftMax(std::vector* output_tensors, const std::vector* input_tensors) override; - void invokeSoftMax(std::unordered_map* output_tensors, + void invokeSoftMax(std::unordered_map* output_tensors, const std::unordered_map* input_tensors) override; using BaseBeamSearchLayer::stream_; @@ -44,21 +44,21 @@ class OnlineBeamSearchLayer: public BaseBeamSearchLayer { protected: public: - OnlineBeamSearchLayer(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - size_t beam_width, - size_t vocab_size, - size_t vocab_size_padded, - int end_id, - float diversity_rate, - float 
temperature, - float len_penalty, - float repetition_penalty, - cudaStream_t stream, + OnlineBeamSearchLayer(size_t max_batch_size, + size_t head_num, + size_t size_per_head, + size_t beam_width, + size_t vocab_size, + size_t vocab_size_padded, + int end_id, + float diversity_rate, + float temperature, + float len_penalty, + float repetition_penalty, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward); + IAllocator* allocator, + bool is_free_buffer_after_forward); OnlineBeamSearchLayer(OnlineBeamSearchLayer const& beam_search_layer); diff --git a/src/fastertransformer/layers/sampling_layers/BaseSamplingLayer.cc b/src/fastertransformer/layers/sampling_layers/BaseSamplingLayer.cc index 6f7df8849..323800916 100644 --- a/src/fastertransformer/layers/sampling_layers/BaseSamplingLayer.cc +++ b/src/fastertransformer/layers/sampling_layers/BaseSamplingLayer.cc @@ -17,26 +17,73 @@ #include "src/fastertransformer/layers/sampling_layers/BaseSamplingLayer.h" #include "src/fastertransformer/kernels/sampling_penalty_kernels.h" +#include "src/fastertransformer/kernels/sampling_topk_kernels.h" #include "src/fastertransformer/utils/cuda_utils.h" +#include "src/fastertransformer/utils/memory_utils.h" + +#include namespace fastertransformer { template -BaseSamplingLayer::BaseSamplingLayer(size_t max_batch_size, - size_t vocab_size, - size_t vocab_size_padded, - int end_id, - size_t top_k, - float top_p, +void BaseSamplingLayer::allocateBuffer(size_t batch_size, Tensor top_k, Tensor top_p) +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + curandstate_buf_ = reinterpret_cast( + allocator_->reMalloc(curandstate_buf_, sizeof(curandState_t) * batch_size, false)); + random_seeds_buf_ = reinterpret_cast( + allocator_->reMalloc(random_seeds_buf_, sizeof(unsigned long long) * batch_size, false)); + temperature_buf_ = + reinterpret_cast(allocator_->reMalloc(temperature_buf_, sizeof(float) * batch_size, false)); + repetition_penalty_buf_ = + reinterpret_cast(allocator_->reMalloc(repetition_penalty_buf_, sizeof(float) * batch_size, false)); + runtime_logits_buf_ = reinterpret_cast( + allocator_->reMalloc(runtime_logits_buf_, sizeof(T) * batch_size * vocab_size_padded_, false)); + skip_decode_buf_ = + reinterpret_cast(allocator_->reMalloc(skip_decode_buf_, sizeof(bool) * batch_size, false)); + + // host buffers. 
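
// Illustrative sketch, not part of this patch: the allocation pattern used here and in the
// INT8 attention layers above pairs allocator_->reMalloc(ptr, bytes, flag) with
// allocator_->free((void**)(&ptr)), plus plain host arrays mirroring the per-batch sampling
// parameters. ExampleLayer is hypothetical, and the exact reMalloc semantics (reuse of a
// still-large-enough allocation, meaning of the boolean flag) are assumptions, not taken
// from this diff.

struct ExampleLayer {
    IAllocator* allocator_;
    float*      temperature_buf_    = nullptr;  // device copy, consumed by penalty kernels
    float*      temperature_        = nullptr;  // host copy, used to decide whether a kernel can be skipped
    bool        is_allocate_buffer_ = false;

    void allocateBuffer(size_t batch_size)
    {
        // reMalloc returns the (possibly new) pointer, so repeated calls with growing
        // batch sizes are safe without an explicit free in between.
        temperature_buf_ = reinterpret_cast<float*>(
            allocator_->reMalloc(temperature_buf_, sizeof(float) * batch_size, false));
        delete[] temperature_;  // drop any previous host mirror before resizing
        temperature_        = new float[batch_size];
        is_allocate_buffer_ = true;
    }

    void freeBuffer()
    {
        // free takes the address of the pointer and resets it to nullptr, so a later
        // allocateBuffer() starts from a clean state.
        allocator_->free((void**)(&temperature_buf_));
        delete[] temperature_;
        temperature_        = nullptr;
        is_allocate_buffer_ = false;
    }
};
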
+ temperature_ = new float[batch_size]; + repetition_penalty_ = new float[batch_size]; + skip_decode_ = new bool[batch_size]; + + is_allocate_buffer_ = true; +} + +template +void BaseSamplingLayer::freeBuffer() +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + if (is_allocate_buffer_) { + allocator_->free((void**)(&curandstate_buf_)); + allocator_->free((void**)(&random_seeds_buf_)); + allocator_->free((void**)(&temperature_buf_)); + allocator_->free((void**)(&repetition_penalty_buf_)); + allocator_->free((void**)(&runtime_logits_buf_)); + allocator_->free((void**)(&skip_decode_buf_)); + delete[] temperature_; + delete[] repetition_penalty_; + delete[] skip_decode_; + is_allocate_buffer_ = false; + } +} + +template +BaseSamplingLayer::BaseSamplingLayer(size_t max_batch_size, + size_t vocab_size, + size_t vocab_size_padded, + int end_id, + size_t top_k, + float top_p, unsigned long long random_seed, - float temperature_, - float len_penalty_, - float repetition_penalty_, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - cudaDeviceProp* cuda_device_prop): + float temperature, + float len_penalty, + float repetition_penalty, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + cudaDeviceProp* cuda_device_prop): DynamicDecodeBaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, cuda_device_prop), vocab_size_(vocab_size), vocab_size_padded_(vocab_size_padded) @@ -57,6 +104,79 @@ BaseSamplingLayer::~BaseSamplingLayer() { } +template +void BaseSamplingLayer::setup(const size_t batch_size, + const size_t beam_width, + const std::unordered_map* runtime_args) +{ + // Set up the sampling layer for given runtime arguments. + // + // runtime_args: + // runtime_top_k [1] or [batch_size] on cpu, optional. + // runtime_top_p [1] or [batch_size] on cpu, optional + // temperature [1] or [batch_size] on cpu, optional + // repetition_penalty [1] or [batch_size] on cpu, optional + + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + Tensor runtime_top_k = runtime_args->count("runtime_top_k") ? runtime_args->at("runtime_top_k") : Tensor(); + Tensor runtime_top_p = runtime_args->count("runtime_top_p") ? runtime_args->at("runtime_top_p") : Tensor(); + allocateBuffer(batch_size, runtime_top_k, runtime_top_p); + + // If runtime argument has single random seed, using this random seed to initialize the random table of all + // sentences. If the argument has [batch_size] random seeds, initializing the random table by different random seeds + // respectively. If no random seed, initialize the random table of all sentences by 0 directly. + if (runtime_args->count("random_seed")) { + Tensor random_seeds = runtime_args->at("random_seed"); + FT_CHECK_WITH_INFO(random_seeds.shape.size() == 1 + && (random_seeds.size() == 1 || random_seeds.size() == batch_size), + fmtstr("random_seeds must be of shape [1] or [batch_size(%ld)], got random_seeds.shape=%s", + batch_size, + vec2str(random_seeds.shape).c_str())); + if (random_seeds.size() == 1) { + invokeCurandInitialize(curandstate_buf_, batch_size, random_seeds.getVal(), stream_); + sync_check_cuda_error(); + } + else { + unsigned long long* random_seed_ptr = random_seeds.getPtr(); + cudaAutoCpy(random_seeds_buf_, random_seed_ptr, batch_size, stream_); + invokeCurandBatchInitialize(curandstate_buf_, batch_size, random_seeds_buf_, stream_); + sync_check_cuda_error(); + } + } + else { + // Initialize curand states using the default seed 0. 
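
// Illustrative sketch, not part of this patch: invokeCurandInitialize and
// invokeCurandBatchInitialize are not shown in this diff. Initializers of this kind are
// usually written with the cuRAND device API roughly as below; the kernel names and launch
// shape are assumptions, not FasterTransformer's implementation.

#include <curand_kernel.h>

// Single shared seed: every sentence gets its own subsequence of the same generator.
__global__ void init_curand_states(curandState_t* states, int batch_size, unsigned long long seed)
{
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < batch_size) {
        curand_init(seed, /*subsequence=*/idx, /*offset=*/0, &states[idx]);
    }
}

// Per-sentence seeds: the [batch_size] random_seed case handled in setup() above.
__global__ void init_curand_states_batch(curandState_t* states, int batch_size, const unsigned long long* seeds)
{
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < batch_size) {
        curand_init(seeds[idx], /*subsequence=*/0, /*offset=*/0, &states[idx]);
    }
}
// Either kernel would be launched, e.g., as init_curand_states<<<(batch_size + 255) / 256, 256, 0, stream>>>(...).
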
+ invokeCurandInitialize(curandstate_buf_, batch_size, 0, stream_); + } + + // Setup penalties. + const float default_temperature = 1.0f; + Tensor temperature = runtime_args->count("temperature") ? runtime_args->at("temperature") : + Tensor(MEMORY_CPU, TYPE_FP32, {1}, &default_temperature); + if (temperature.size() == 1) { + float tp = temperature.getVal(); + deviceFill(temperature_buf_, batch_size, tp, stream_); + std::fill_n(temperature_, batch_size, tp); + } + else { + cudaAutoCpy(temperature_buf_, temperature.getPtr(), batch_size, stream_); + std::copy_n(temperature.getPtr(), batch_size, temperature_); + } + + const float default_repetition_penalty = 1.0f; + Tensor repetition_penalty = runtime_args->count("repetition_penalty") ? + runtime_args->at("repetition_penalty") : + Tensor(MEMORY_CPU, TYPE_FP32, {1}, &default_repetition_penalty); + if (repetition_penalty.size() == 1) { + float rp = repetition_penalty.getVal(); + deviceFill(repetition_penalty_buf_, batch_size, rp, stream_); + std::fill_n(repetition_penalty_, batch_size, rp); + } + else { + cudaAutoCpy(repetition_penalty_buf_, repetition_penalty.getPtr(), batch_size, stream_); + std::copy_n(repetition_penalty.getPtr(), batch_size, repetition_penalty_); + } +} + template void BaseSamplingLayer::forward(std::vector* output_tensors, const std::vector* input_tensors) { @@ -94,7 +214,7 @@ void BaseSamplingLayer::forward(std::vector* output_tensors, const st } template -void BaseSamplingLayer::forward(std::unordered_map* output_tensors, +void BaseSamplingLayer::forward(std::unordered_map* output_tensors, const std::unordered_map* input_tensors) { // input_tensors: @@ -104,12 +224,6 @@ void BaseSamplingLayer::forward(std::unordered_map* outp // max_input_length [1] on cpu // input_lengths [local_batch_size] // ite [1] on cpu - // runtime_top_k [1] on cpu, optional. - // runtime_top_p [1] on cpu, optional - // temperature [1] on cpu, optional - // len_penalty [1] on cpu, optional - // repetition_penalty [1] on cpu, optional - // random_seed [1] on cpu, unsigned long long, optional // output_tensors: // output_ids [max_seq_len, batch_size] @@ -123,43 +237,56 @@ void BaseSamplingLayer::forward(std::unordered_map* outp FT_LOG_DEBUG(__PRETTY_FUNCTION__); FT_CHECK(input_tensors->size() >= 6); FT_CHECK(output_tensors->size() >= 3); - const int batch_size = output_tensors->at("output_ids").shape[1]; - + const int batch_size = output_tensors->at("output_ids").shape[1]; const int local_batch_size = input_tensors->at("logits").shape[0]; - const int step = *((int*)input_tensors->at("step").data); - const int ite = *((int*)input_tensors->at("ite").data); - - float temperature = input_tensors->count("temperature") ? input_tensors->at("temperature").getVal() : 1.0f; - - if ((const T*)(input_tensors->at("embedding_bias").data) != nullptr || temperature != 1.0f) { - invokeApplyTemperaturePenalty((T*)input_tensors->at("logits").data, - (const T*)(input_tensors->at("embedding_bias").data), - temperature, - local_batch_size, - vocab_size_, - vocab_size_padded_, - stream_); + const int step = input_tensors->at("step").getVal(); + const int ite = input_tensors->at("ite").getVal(); + T* logits = input_tensors->at("logits").getPtr(); + +#define ALL_OF(p_, sz_, dt_, v_) (std::all_of(p_, p_ + sz_, [](dt_ b) { return b == v_; })) + + bool* skip_decode = skip_decode_ + ite * local_batch_size; + if (ALL_OF(skip_decode, local_batch_size, bool, true)) { + // No sample in the current batch to do TopX sampling. 
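
// Illustrative sketch, not part of this patch: the skip_decode_ / skip_any_ logic below lets
// top-k and top-p layers share one batch. setup() marks the samples a layer does not own,
// and forward() then either returns early, works in place, or works on a private copy of the
// logits. A compact CPU restatement of that decision, with ad hoc names:

#include <algorithm>
#include <vector>

enum class SamplingPath { SkipBatch, InPlace, OnCopy };

SamplingPath choose_sampling_path(const std::vector<bool>& skip_decode)
{
    const bool all_skipped = std::all_of(skip_decode.begin(), skip_decode.end(), [](bool s) { return s; });
    if (all_skipped) {
        return SamplingPath::SkipBatch;  // no sample in this batch belongs to this layer
    }
    const bool any_skipped = std::any_of(skip_decode.begin(), skip_decode.end(), [](bool s) { return s; });
    // With a mixed batch, the other sampling layer still needs the untouched logits, so the
    // penalties must be applied to a private buffer (runtime_logits_buf_ in this patch).
    return any_skipped ? SamplingPath::OnCopy : SamplingPath::InPlace;
}
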
+ return; + } + skip_any_ = std::any_of(skip_decode, skip_decode + local_batch_size, [](bool b) { return b; }); + if (skip_any_) { + // A TopX Sampling layer directly changes the logit values. In case of skip_any==true, + // meaning topk and topp layers will run simultaneously for a batch in the same step. + // We copy the logits to an internal buffer, not affecting the other sampling layers. + FT_CHECK(input_tensors->at("logits").size() == local_batch_size * vocab_size_padded_); + cudaD2Dcpy(runtime_logits_buf_, logits, input_tensors->at("logits").size()); + logits = runtime_logits_buf_; + } + + T* embedding_bias = input_tensors->at("embedding_bias").getPtr(); + if (embedding_bias != nullptr || !ALL_OF(temperature_ + ite * local_batch_size, local_batch_size, float, 1.0f)) { + invokeBatchApplyTemperaturePenalty(logits, + embedding_bias, + temperature_buf_ + ite * local_batch_size, + local_batch_size, + vocab_size_, + vocab_size_padded_, + stream_); } sync_check_cuda_error(); - if (step > 1 - && (input_tensors->count("repetition_penalty") - && input_tensors->at("repetition_penalty").getVal() != 1.0f)) { - invokeApplyRepetitionPenalty((T*)input_tensors->at("logits").data, - input_tensors->at("repetition_penalty").getVal(), - nullptr, - (int*)output_tensors->at("output_ids").data, - batch_size, - local_batch_size, - vocab_size_, - vocab_size_padded_, - (int*)input_tensors->at("input_lengths").data, - *((int*)input_tensors->at("max_input_length").data), - step, - ite, - stream_); + if (step > 1 && !ALL_OF(repetition_penalty_ + ite * local_batch_size, local_batch_size, float, 1.0f)) { + invokeBatchApplyRepetitionPenalty( + logits, + repetition_penalty_buf_ + ite * local_batch_size, + output_tensors->at("output_ids").getPtrWithOffset(ite * local_batch_size), + batch_size, + local_batch_size, + vocab_size_padded_, + input_tensors->at("input_lengths").getPtr(), + input_tensors->at("max_input_length").getVal(), + step, + stream_); sync_check_cuda_error(); } +#undef ALL_OF runSampling(output_tensors, input_tensors); @@ -172,4 +299,4 @@ void BaseSamplingLayer::forward(std::unordered_map* outp template class BaseSamplingLayer; template class BaseSamplingLayer; -} // namespace fastertransformer \ No newline at end of file +} // namespace fastertransformer diff --git a/src/fastertransformer/layers/sampling_layers/BaseSamplingLayer.h b/src/fastertransformer/layers/sampling_layers/BaseSamplingLayer.h index 5224154d6..12502751e 100644 --- a/src/fastertransformer/layers/sampling_layers/BaseSamplingLayer.h +++ b/src/fastertransformer/layers/sampling_layers/BaseSamplingLayer.h @@ -32,45 +32,57 @@ class BaseSamplingLayer: public DynamicDecodeBaseLayer { size_t vocab_size_; size_t vocab_size_padded_; - size_t sampling_workspace_size_; - void* sampling_workspace_ = nullptr; - curandState_t* curandstate_buf_ = nullptr; + size_t sampling_workspace_size_; + void* sampling_workspace_ = nullptr; + curandState_t* curandstate_buf_ = nullptr; + unsigned long long* random_seeds_buf_ = nullptr; - virtual void runSampling(std::vector* output_tensors, - const std::vector* input_tensors) = 0; - virtual void runSampling(std::unordered_map* output_tensors, + float* temperature_buf_ = nullptr; + float* repetition_penalty_buf_ = nullptr; + bool* skip_decode_buf_ = nullptr; + T* runtime_logits_buf_ = nullptr; + + float* temperature_ = nullptr; + float* repetition_penalty_ = nullptr; + bool* skip_decode_ = nullptr; + bool skip_any_ = false; + + virtual void runSampling(std::vector* output_tensors, + const std::vector* 
input_tensors) = 0; + virtual void runSampling(std::unordered_map* output_tensors, const std::unordered_map* input_tensors) = 0; - virtual void freeBuffer() = 0; + virtual void freeBuffer(); virtual void allocateBuffer() = 0; - virtual void allocateBuffer(size_t batch_size, size_t top_k, float top_p) = 0; - virtual void - invokeInitialize(size_t batch_size, unsigned long long random_seed, curandState_t* curandstate_buf) = 0; + virtual void allocateBuffer(size_t batch_size, Tensor top_k, Tensor top_p); public: - BaseSamplingLayer(size_t max_batch_size, - size_t vocab_size, - size_t vocab_size_padded, - int end_id, - size_t top_k, - float top_p, + BaseSamplingLayer(size_t max_batch_size, + size_t vocab_size, + size_t vocab_size_padded, + int end_id, + size_t top_k, + float top_p, unsigned long long random_seed, // TODO(bhsueh) delete - float temperature, - float len_penalty, - float repetition_penalty, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - cudaDeviceProp* cuda_device_prop); + float temperature, + float len_penalty, + float repetition_penalty, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + cudaDeviceProp* cuda_device_prop); BaseSamplingLayer(BaseSamplingLayer const& sampling_layer); ~BaseSamplingLayer(); - void forward(std::vector* output_tensors, + void setup(const size_t batch_size, + const size_t beam_width, + const std::unordered_map* runtime_args) override; + void forward(std::vector* output_tensors, const std::vector* input_tensors) override; - void forward(std::unordered_map* output_tensors, + void forward(std::unordered_map* output_tensors, const std::unordered_map* input_tensors) override; }; diff --git a/src/fastertransformer/layers/sampling_layers/CMakeLists.txt b/src/fastertransformer/layers/sampling_layers/CMakeLists.txt index 88e42648a..215127a21 100644 --- a/src/fastertransformer/layers/sampling_layers/CMakeLists.txt +++ b/src/fastertransformer/layers/sampling_layers/CMakeLists.txt @@ -17,7 +17,7 @@ cmake_minimum_required(VERSION 3.8) add_library(BaseSamplingLayer STATIC BaseSamplingLayer.cc) set_property(TARGET BaseSamplingLayer PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET BaseSamplingLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -target_link_libraries(BaseSamplingLayer PUBLIC -lcudart sampling_penalty_kernels) +target_link_libraries(BaseSamplingLayer PUBLIC -lcudart sampling_penalty_kernels memory_utils) add_library(TopKSamplingLayer STATIC TopKSamplingLayer.cu) set_property(TARGET TopKSamplingLayer PROPERTY POSITION_INDEPENDENT_CODE ON) diff --git a/src/fastertransformer/layers/sampling_layers/TopKSamplingLayer.cu b/src/fastertransformer/layers/sampling_layers/TopKSamplingLayer.cu index b7d14a885..3c261041d 100644 --- a/src/fastertransformer/layers/sampling_layers/TopKSamplingLayer.cu +++ b/src/fastertransformer/layers/sampling_layers/TopKSamplingLayer.cu @@ -21,9 +21,65 @@ #include "src/fastertransformer/kernels/sampling_topp_kernels.h" #include "src/fastertransformer/layers/sampling_layers/TopKSamplingLayer.h" #include "src/fastertransformer/utils/logger.h" +#include "src/fastertransformer/utils/memory_utils.h" namespace fastertransformer { +template +__global__ void setup_topk_runtime_args(int batch_size, + uint top_k, + uint* top_ks, + int top_ks_size, + float top_p, + float* top_ps, + int top_ps_size, + bool* skip_decode) +{ + int index = blockIdx.x * gridDim.x + threadIdx.x; + for (int i = index; i 
< batch_size; i += gridDim.x * blockDim.x) { + uint k = top_ks_size > 1 ? top_ks[i] : top_k; + float p = top_ps_size > 1 ? top_ps[i] : top_p; + if (k == 0 && p == 0.0f) { + // Invalid runtime topk/topp combination. Use a greedy decoding as default. + printf("[WARNING] Invalid runtime topk/topp combination for token %d (topk: %d, topp: %f). " + "Use a greedy decoding as default.\n", + i, + k, + p); + k = 1; + } + if (k > 0 && p == 0.0f) { + // for compatibility <= FT5.0. + // This case corresponds to the old topk sampling, which is equivalent to + // the old topk_topp sampling with topp=1.0f. TopKSamplingLayer and + // TopKTopPSamplingLayer are now merged by TopKSamplingLayer. Thus, we + // replace the case topk>0 and topp=0.0f by topk>0 and topp=1.0f for the + // compatibility. + p = 1.0f; + } + // Clip k value. A topk sampling kernel supports up to TOP_K_MAX=64. + top_ks[i] = k > TOP_K_MAX ? TOP_K_MAX : k; + if (k > TOP_K_MAX) { + printf("[WARNING] topk (%d) is larger than max supported number (%d) for token %d" + " clip to max supported number %d. \n", + k, + TOP_K_MAX, + i, + top_ks[i]); + } + // Clip p value if it is out of range. range = [0.0, 1.0]. + top_ps[i] = p < 0.0f ? 0.0f : (p > 1.0f ? 1.0f : p); + if (p < 0.0f || p > 1.0f) { + printf("[WARNING] topp (%f) is out of range ([0.0, 1.0f]) for token %d" + " clip to closest number %f.\n", + p, + i, + top_ps[i]); + } + skip_decode[i] = k == 0; + } +} + template void TopKSamplingLayer::allocateBuffer() { @@ -31,9 +87,16 @@ void TopKSamplingLayer::allocateBuffer() } template -void TopKSamplingLayer::allocateBuffer(size_t batch_size, size_t top_k, float top_p) +void TopKSamplingLayer::allocateBuffer(size_t batch_size, Tensor top_k, Tensor top_p) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); + BaseSamplingLayer::allocateBuffer(batch_size, top_k, top_p); + uint max_top_k = top_k.size() > 0 ? top_k.max() : 1; + if (max_top_k == 0) { + // for safety. TopKSamplingLayer handles a case of top_k=0 and top_p=0 as + // a greedy decode, i.e. top_k=1, although such case has max_top_k=0. 
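// Keep max_top_k >= 1 so that the workspace-size query below (invokeTopKSampling
// called with a nullptr workspace, which only fills in sampling_workspace_size_
// and performs no sampling) sizes the buffer for at least one candidate per sequence;
// the actual workspace is reMalloc'ed right after the query.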
+ max_top_k = 1; + } invokeTopKSampling(nullptr, sampling_workspace_size_, nullptr, @@ -43,14 +106,18 @@ void TopKSamplingLayer::allocateBuffer(size_t batch_size, size_t top_k, float nullptr, nullptr, nullptr, - top_k, + max_top_k, + 1.0f, vocab_size_padded_, nullptr, stream_, - batch_size); + batch_size, + skip_decode_buf_); sampling_workspace_ = allocator_->reMalloc(sampling_workspace_, sampling_workspace_size_, false); - curandstate_buf_ = reinterpret_cast( - allocator_->reMalloc(curandstate_buf_, sizeof(curandState_t) * batch_size, false)); + runtime_top_k_buf_ = + reinterpret_cast(allocator_->reMalloc(runtime_top_k_buf_, sizeof(uint) * batch_size, false)); + runtime_top_p_buf_ = + reinterpret_cast(allocator_->reMalloc(runtime_top_p_buf_, sizeof(float) * batch_size, false)); is_allocate_buffer_ = true; } @@ -59,24 +126,73 @@ void TopKSamplingLayer::freeBuffer() { FT_LOG_DEBUG(__PRETTY_FUNCTION__); if (is_allocate_buffer_) { - allocator_->free(sampling_workspace_); - allocator_->free(curandstate_buf_); - is_allocate_buffer_ = false; + allocator_->free((void**)(&sampling_workspace_)); + allocator_->free((void**)(&runtime_top_k_buf_)); + allocator_->free((void**)(&runtime_top_p_buf_)); } + BaseSamplingLayer::freeBuffer(); + is_allocate_buffer_ = false; } template -void TopKSamplingLayer::invokeInitialize(size_t batch_size, - unsigned long long random_seed, - curandState_t* curandstate_buf) +void TopKSamplingLayer::setup(const size_t batch_size, + const size_t beam_width, + const std::unordered_map* runtime_args) { + // Setup runtime topk and topp arguments. + // + // runtime_args: + // runtime_top_k [1] or [batch_size] on cpu, optional, uint. + // runtime_top_p [1] or [batch_size] on cpu, optional, float. + // temperature [1] or [batch_size] on cpu, optional + // repetition_penalty [1] or [batch_size] on cpu, optional FT_LOG_DEBUG(__PRETTY_FUNCTION__); - invokeCurandInitialize(curandstate_buf, batch_size, random_seed, stream_); - sync_check_cuda_error(); + BaseSamplingLayer::setup(batch_size, beam_width, runtime_args); + + uint tmp_top_k = 0; + const Tensor runtime_top_k = runtime_args->count("runtime_top_k") ? + runtime_args->at("runtime_top_k") : + Tensor(MEMORY_CPU, TYPE_UINT32, {1}, &tmp_top_k); + const Tensor runtime_top_p = runtime_args->count("runtime_top_p") ? runtime_args->at("runtime_top_p") : Tensor(); + const size_t runtime_top_k_size = runtime_top_k.size(); + const size_t runtime_top_p_size = runtime_top_p.size(); + + uint top_k = runtime_top_k.max(); + float top_p = runtime_top_p_size == 0 ? 0.0f : runtime_top_p.getVal(); + + if (runtime_top_k_size > 1) { + FT_CHECK_WITH_INFO( + runtime_top_k.size() == batch_size, + fmtstr("runtime_top_k.size() (%d) == batch_size (%d) is not satisfied!", runtime_top_k.size(), batch_size)); + cudaAutoCpy(runtime_top_k_buf_, runtime_top_k.getPtr(), batch_size, stream_); + } + if (runtime_top_p_size > 1) { + FT_CHECK_WITH_INFO( + runtime_top_p.size() == batch_size, + fmtstr("runtime_top_p.size() (%d) == batch_size (%d) is not satisfied!", runtime_top_p.size(), batch_size)); + cudaAutoCpy(runtime_top_p_buf_, runtime_top_p.getPtr(), batch_size, stream_); + } + + dim3 block(std::min((int)batch_size, 1024)); + dim3 grid(div_up((int)batch_size, (int)block.x)); + // support top_k up to 1024. 
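// One thread handles one batch entry: block = min(batch_size, 1024) threads and
// grid = ceil(batch_size / block) blocks, as computed just above. Each thread of
// setup_topk_runtime_args resolves the effective per-sequence (top_k, top_p) pair,
// applies the clipping rules defined in the kernel, and writes
// skip_decode_buf_[i] = (k == 0), so entries that request pure top_p sampling are
// skipped by this top-k layer and left to TopPSamplingLayer.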
+ setup_topk_runtime_args<1024><<>>(batch_size, + top_k, + runtime_top_k_buf_, + runtime_top_k_size, + top_p, + runtime_top_p_buf_, + runtime_top_p_size, + skip_decode_buf_); + cudaAutoCpy(skip_decode_, skip_decode_buf_, batch_size, stream_); + uint* runtime_top_ks = new uint[batch_size]; + cudaAutoCpy(runtime_top_ks, runtime_top_k_buf_, batch_size, stream_); + runtime_max_top_k_ = static_cast(*std::max_element(runtime_top_ks, runtime_top_ks + batch_size)); + delete[] runtime_top_ks; } template -void TopKSamplingLayer::runSampling(std::vector* output_tensors, +void TopKSamplingLayer::runSampling(std::vector* output_tensors, const std::vector* input_tensors) { // input_tensors: @@ -115,7 +231,7 @@ void TopKSamplingLayer::runSampling(std::vector* o } template -void TopKSamplingLayer::runSampling(std::unordered_map* output_tensors, +void TopKSamplingLayer::runSampling(std::unordered_map* output_tensors, const std::unordered_map* input_tensors) { // input_tensors: @@ -125,11 +241,6 @@ void TopKSamplingLayer::runSampling(std::unordered_map* // max_input_length [1] on cpu // input_lengths [local_batch_size] // ite [1] on cpu - // runtime_top_k [1] or [batch_size] on cpu, optional - // temperature [1] or [batch_size] on cpu, optional - // len_penalty [1] or [batch_size] on cpu, optional - // repetition_penalty [1] or [batch_size] on cpu, optional - // random_seed [1] or [batch_size] on cpu, optional // output_tensors: // output_ids [max_seq_len, batch_size] @@ -144,24 +255,15 @@ void TopKSamplingLayer::runSampling(std::unordered_map* FT_CHECK(input_tensors->size() >= 6); FT_CHECK(output_tensors->size() >= 3); - const int batch_size = output_tensors->at("output_ids").shape[1]; + const int batch_size = output_tensors->at("output_ids").shape[1]; const int local_batch_size = input_tensors->at("logits").shape[0]; - const int step = *((int*)input_tensors->at("step").data); - const int ite = *((int*)input_tensors->at("ite").data); - - const int runtime_top_k = input_tensors->at("runtime_top_k").shape[0] == 1 ? - input_tensors->at("runtime_top_k").getVal(0) : - input_tensors->at("runtime_top_k").getVal(ite * local_batch_size); - allocateBuffer(batch_size, runtime_top_k, 0.0f); - if (input_tensors->count("random_seed")) { - unsigned long long int random_seed = - input_tensors->at("random_seed").shape[0] == 1 ? - (unsigned long long int)input_tensors->at("random_seed").getVal(0) : - (unsigned long long int)input_tensors->at("random_seed").getVal(ite * local_batch_size); - invokeInitialize(local_batch_size, random_seed, curandstate_buf_ + ite * local_batch_size); - } + const int ite = *((int*)input_tensors->at("ite").data); + const int step = *((int*)input_tensors->at("step").data); + + // in case of skip any, the logit value is already copied and processed. + T* logits = !skip_any_ ? input_tensors->at("logits").getPtr() : runtime_logits_buf_; - invokeAddBiasEndMask((T*)(input_tensors->at("logits").data), + invokeAddBiasEndMask(logits, (T*)(nullptr), (const int*)input_tensors->at("end_id").data, (bool*)output_tensors->at("finished").data, @@ -176,48 +278,53 @@ void TopKSamplingLayer::runSampling(std::unordered_map* output_tensors->count("output_log_probs") ? 
output_tensors->at("output_log_probs").getPtr() : nullptr; if (cum_log_probs != nullptr || output_log_probs != nullptr) { - invokeAddBiasSoftMax((T*)(input_tensors->at("logits").data), + invokeAddBiasSoftMax(logits, (T*)(nullptr), - (const int*)input_tensors->at("end_id").data, - (bool*)output_tensors->at("finished").data, + input_tensors->at("end_id").getPtr(), + output_tensors->at("finished").getPtr(), local_batch_size, vocab_size_padded_, vocab_size_, stream_); + sync_check_cuda_error(); } - invokeTopKSampling( + invokeBatchTopKSampling( sampling_workspace_, sampling_workspace_size_, - input_tensors->at("logits").getPtr(), + logits, output_tensors->at("output_ids").getPtrWithOffset(step * batch_size + ite * local_batch_size), output_tensors->at("sequence_length").getPtr(), output_tensors->at("finished").getPtr(), cum_log_probs, output_log_probs, curandstate_buf_ + ite * local_batch_size, - runtime_top_k, + (int)runtime_max_top_k_, // useless because runtime_top_k_buf_ is never nullptr. Keep for legacy. + (int*)(runtime_top_k_buf_ + ite * local_batch_size), + 1.0f, // useless because runtime_top_p_buf_ is never nullptr. Keep for legacy. + runtime_top_p_buf_ + ite * local_batch_size, vocab_size_padded_, - (const int*)input_tensors->at("end_id").data, + input_tensors->at("end_id").getPtr(), stream_, - local_batch_size); + local_batch_size, + skip_decode_buf_ + ite * local_batch_size); sync_check_cuda_error(); } template -TopKSamplingLayer::TopKSamplingLayer(size_t max_batch_size, - size_t vocab_size, - size_t vocab_size_padded, - int end_id, - size_t top_k, +TopKSamplingLayer::TopKSamplingLayer(size_t max_batch_size, + size_t vocab_size, + size_t vocab_size_padded, + int end_id, + size_t top_k, unsigned long long random_seed, - float temperature, - float len_penalty, - float repetition_penalty, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward): + float temperature, + float len_penalty, + float repetition_penalty, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward): BaseSamplingLayer(max_batch_size, vocab_size, vocab_size_padded, @@ -245,6 +352,7 @@ TopKSamplingLayer::TopKSamplingLayer(TopKSamplingLayer const& top_k_sampli template TopKSamplingLayer::~TopKSamplingLayer() { + FT_LOG_DEBUG(__PRETTY_FUNCTION__); freeBuffer(); } diff --git a/src/fastertransformer/layers/sampling_layers/TopKSamplingLayer.h b/src/fastertransformer/layers/sampling_layers/TopKSamplingLayer.h index 7983be0b8..f4380e0da 100644 --- a/src/fastertransformer/layers/sampling_layers/TopKSamplingLayer.h +++ b/src/fastertransformer/layers/sampling_layers/TopKSamplingLayer.h @@ -18,29 +18,36 @@ #pragma once #include "src/fastertransformer/layers/sampling_layers/BaseSamplingLayer.h" +#include "src/fastertransformer/utils/memory_utils.h" namespace fastertransformer { template class TopKSamplingLayer: public BaseSamplingLayer { private: - void runSampling(std::vector* output_tensors, + void runSampling(std::vector* output_tensors, const std::vector* input_tensors) override; - void runSampling(std::unordered_map* output_tensors, + void runSampling(std::unordered_map* output_tensors, const std::unordered_map* input_tensors) override; void freeBuffer() override; void allocateBuffer() override; - void allocateBuffer(size_t batch_size, size_t top_k, float top_p) override; - void invokeInitialize(size_t batch_size, unsigned long long random_seed, curandState_t* curandstate_buf) override; - bool 
isValidTopK(size_t runtime_top_k); + void allocateBuffer(size_t batch_size, Tensor top_k, Tensor top_p) override; + uint runtime_max_top_k_ = 1; + uint* runtime_top_k_buf_ = nullptr; + float* runtime_top_p_buf_ = nullptr; using BaseSamplingLayer::vocab_size_; using BaseSamplingLayer::vocab_size_padded_; using BaseSamplingLayer::sampling_workspace_size_; using BaseSamplingLayer::sampling_workspace_; using BaseSamplingLayer::curandstate_buf_; + using BaseSamplingLayer::random_seeds_buf_; + using BaseSamplingLayer::skip_decode_buf_; + using BaseSamplingLayer::skip_decode_; + using BaseSamplingLayer::skip_any_; + using BaseSamplingLayer::runtime_logits_buf_; using BaseSamplingLayer::stream_; using BaseSamplingLayer::allocator_; @@ -48,23 +55,25 @@ class TopKSamplingLayer: public BaseSamplingLayer { protected: public: - TopKSamplingLayer(size_t max_batch_size, - size_t vocab_size, - size_t vocab_size_padded, - int end_id, - size_t top_k, + TopKSamplingLayer(size_t max_batch_size, + size_t vocab_size, + size_t vocab_size_padded, + int end_id, + size_t top_k, unsigned long long random_seed, - float temperature, - float len_penalty, - float repetition_penalty, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward); - + float temperature, + float len_penalty, + float repetition_penalty, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward); TopKSamplingLayer(TopKSamplingLayer const& top_k_sampling_layer); - ~TopKSamplingLayer(); + + void setup(const size_t batch_size, + const size_t beam_width, + const std::unordered_map* runtime_args) override; }; } // namespace fastertransformer diff --git a/src/fastertransformer/layers/sampling_layers/TopKTopPSamplingLayer.cu b/src/fastertransformer/layers/sampling_layers/TopKTopPSamplingLayer.cu index 3413db103..282aabaeb 100644 --- a/src/fastertransformer/layers/sampling_layers/TopKTopPSamplingLayer.cu +++ b/src/fastertransformer/layers/sampling_layers/TopKTopPSamplingLayer.cu @@ -31,8 +31,10 @@ void TopKTopPSamplingLayer::allocateBuffer() } template -void TopKTopPSamplingLayer::allocateBuffer(size_t batch_size, size_t top_k, float top_p) +void TopKTopPSamplingLayer::allocateBuffer(size_t batch_size, Tensor top_k, Tensor top_p) { + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + BaseSamplingLayer::allocateBuffer(batch_size, top_k, top_p); invokeTopKTopPSampling(nullptr, // workspace sampling_workspace_size_, nullptr, // output_ids @@ -43,14 +45,12 @@ void TopKTopPSamplingLayer::allocateBuffer(size_t batch_size, size_t top_k, f nullptr, // output_log_probs nullptr, // curandstate batch_size, - top_k, - (T)top_p, + static_cast(top_k.max()), + top_p.max(), vocab_size_padded_, nullptr, stream_); sampling_workspace_ = allocator_->reMalloc(sampling_workspace_, sampling_workspace_size_, true); - curandstate_buf_ = reinterpret_cast( - allocator_->reMalloc(curandstate_buf_, sizeof(curandState_t) * batch_size, false)); is_allocate_buffer_ = true; } @@ -59,24 +59,14 @@ void TopKTopPSamplingLayer::freeBuffer() { FT_LOG_DEBUG(__PRETTY_FUNCTION__); if (is_allocate_buffer_) { - allocator_->free(sampling_workspace_); - allocator_->free(curandstate_buf_); - is_allocate_buffer_ = false; + allocator_->free((void**)(&sampling_workspace_)); } + BaseSamplingLayer::freeBuffer(); + is_allocate_buffer_ = false; } template -void TopKTopPSamplingLayer::invokeInitialize(size_t batch_size, - unsigned long long random_seed, - curandState_t* curandstate_buf) -{ - 
FT_LOG_DEBUG(__PRETTY_FUNCTION__); - invokeCurandInitialize(curandstate_buf, batch_size, random_seed, stream_); - sync_check_cuda_error(); -} - -template -void TopKTopPSamplingLayer::runSampling(std::vector* output_tensors, +void TopKTopPSamplingLayer::runSampling(std::vector* output_tensors, const std::vector* input_tensors) { // input_tensors: @@ -113,7 +103,7 @@ void TopKTopPSamplingLayer::runSampling(std::vector -void TopKTopPSamplingLayer::runSampling(std::unordered_map* output_tensors, +void TopKTopPSamplingLayer::runSampling(std::unordered_map* output_tensors, const std::unordered_map* input_tensors) { // input_tensors: @@ -123,12 +113,12 @@ void TopKTopPSamplingLayer::runSampling(std::unordered_map::runSampling(std::unordered_mapsize() >= 6); FT_CHECK(output_tensors->size() >= 3); - const int batch_size = output_tensors->at("output_ids").shape[1]; + const int batch_size = output_tensors->at("output_ids").shape[1]; const int local_batch_size = input_tensors->at("logits").shape[0]; - const int step = *((int*)input_tensors->at("step").data); - const int ite = *((int*)input_tensors->at("ite").data); - - const int runtime_top_k = input_tensors->at("runtime_top_k").shape[0] == 1 ? - input_tensors->at("runtime_top_k").getVal(0) : - input_tensors->at("runtime_top_k").getVal(ite * local_batch_size); - - const float runtime_top_p = input_tensors->at("runtime_top_p").shape[0] == 1 ? - input_tensors->at("runtime_top_p").getVal(0) : - input_tensors->at("runtime_top_p").getVal(ite * local_batch_size); - allocateBuffer(batch_size, runtime_top_k, runtime_top_p); - if (input_tensors->find("random_seed") != input_tensors->end()) { - unsigned long long int random_seed = - input_tensors->at("random_seed").shape[0] == 1 ? - (unsigned long long int)input_tensors->at("random_seed").getVal(0) : - (unsigned long long int)input_tensors->at("random_seed").getVal(ite * local_batch_size); - invokeInitialize(local_batch_size, random_seed, curandstate_buf_ + ite * local_batch_size); + const int step = *((int*)input_tensors->at("step").data); + const int ite = *((int*)input_tensors->at("ite").data); + + int runtime_top_k = (int)(input_tensors->at("runtime_top_k").shape[0] == 1 ? + input_tensors->at("runtime_top_k").getVal(0) : + input_tensors->at("runtime_top_k").getVal(ite * local_batch_size)); + + float runtime_top_p = input_tensors->at("runtime_top_p").shape[0] == 1 ? 
+ input_tensors->at("runtime_top_p").getVal(0) : + input_tensors->at("runtime_top_p").getVal(ite * local_batch_size); + if ((runtime_top_k <= 0 || runtime_top_k > 64)) { + FT_LOG_ERROR("Invalid runtime topk: %d, using default topk value %d to do topk sampling.", runtime_top_k, 1); + runtime_top_k = 1; + } + if ((runtime_top_p <= 0.0f || runtime_top_p > 1.0f)) { + FT_LOG_ERROR( + "Invalid runtime topp: %f, using default topp value %f to do topp sampling.", runtime_top_p, 0.0001f); + runtime_top_p = 0.0001f; } invokeAddBiasEndMask((T*)(input_tensors->at("logits").data), @@ -200,7 +191,7 @@ void TopKTopPSamplingLayer::runSampling(std::unordered_mapat("end_id").data, stream_); @@ -208,20 +199,20 @@ void TopKTopPSamplingLayer::runSampling(std::unordered_map -TopKTopPSamplingLayer::TopKTopPSamplingLayer(size_t max_batch_size, - size_t vocab_size, - size_t vocab_size_padded, - int end_id, - int top_k, - float top_p, +TopKTopPSamplingLayer::TopKTopPSamplingLayer(size_t max_batch_size, + size_t vocab_size, + size_t vocab_size_padded, + int end_id, + int top_k, + float top_p, unsigned long long random_seed, - float temperature, - float len_penalty, - float repetition_penalty, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward): + float temperature, + float len_penalty, + float repetition_penalty, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward): BaseSamplingLayer(max_batch_size, vocab_size, vocab_size_padded, @@ -238,6 +229,7 @@ TopKTopPSamplingLayer::TopKTopPSamplingLayer(size_t max_batch_size, is_free_buffer_after_forward, nullptr) { + FT_LOG_WARNING("TopKTopPSamplingLayer is deprecated. Please use TopKSamplingLayer instead."); } template diff --git a/src/fastertransformer/layers/sampling_layers/TopKTopPSamplingLayer.h b/src/fastertransformer/layers/sampling_layers/TopKTopPSamplingLayer.h index afffe015d..202013faf 100644 --- a/src/fastertransformer/layers/sampling_layers/TopKTopPSamplingLayer.h +++ b/src/fastertransformer/layers/sampling_layers/TopKTopPSamplingLayer.h @@ -24,15 +24,14 @@ namespace fastertransformer { template class TopKTopPSamplingLayer: public BaseSamplingLayer { private: - void runSampling(std::vector* output_tensors, + void runSampling(std::vector* output_tensors, const std::vector* input_tensors) override; - void runSampling(std::unordered_map* output_tensors, + void runSampling(std::unordered_map* output_tensors, const std::unordered_map* input_tensors) override; void allocateBuffer() override; - void allocateBuffer(size_t batch_size, size_t top_k, float top_p) override; + void allocateBuffer(size_t batch_size, Tensor top_k, Tensor top_p) override; void freeBuffer() override; - void invokeInitialize(size_t batch_size, unsigned long long random_seed, curandState_t* curandstate_buf) override; bool isValidTopK(size_t runtime_top_k); using BaseSamplingLayer::vocab_size_; @@ -41,6 +40,7 @@ class TopKTopPSamplingLayer: public BaseSamplingLayer { using BaseSamplingLayer::sampling_workspace_size_; using BaseSamplingLayer::sampling_workspace_; using BaseSamplingLayer::curandstate_buf_; + using BaseSamplingLayer::random_seeds_buf_; using BaseSamplingLayer::stream_; using BaseSamplingLayer::allocator_; @@ -48,20 +48,20 @@ class TopKTopPSamplingLayer: public BaseSamplingLayer { protected: public: - TopKTopPSamplingLayer(size_t max_batch_size, - size_t vocab_size, - size_t vocab_size_padded, - int end_id, - int top_k, - float top_p, + 
TopKTopPSamplingLayer(size_t max_batch_size, + size_t vocab_size, + size_t vocab_size_padded, + int end_id, + int top_k, + float top_p, unsigned long long random_seed, - float temperature, - float len_penalty, - float repetition_penalty, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward); + float temperature, + float len_penalty, + float repetition_penalty, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward); TopKTopPSamplingLayer(TopKTopPSamplingLayer const& topk_topp_sampling_layer); diff --git a/src/fastertransformer/layers/sampling_layers/TopPSamplingLayer.cu b/src/fastertransformer/layers/sampling_layers/TopPSamplingLayer.cu index caf178d89..858df7384 100644 --- a/src/fastertransformer/layers/sampling_layers/TopPSamplingLayer.cu +++ b/src/fastertransformer/layers/sampling_layers/TopPSamplingLayer.cu @@ -15,6 +15,7 @@ * limitations under the License. */ +#include #include #include "src/fastertransformer/kernels/reduce_kernel_utils.cuh" @@ -22,9 +23,47 @@ #include "src/fastertransformer/kernels/sampling_topp_kernels.h" #include "src/fastertransformer/layers/sampling_layers/TopPSamplingLayer.h" #include "src/fastertransformer/utils/logger.h" +#include "src/fastertransformer/utils/memory_utils.h" namespace fastertransformer { +static __global__ void set_topp_runtime_args(int batch_size, + uint top_k, + uint* top_ks, + int top_ks_size, + float top_p, + float* top_ps, + int top_ps_size, + bool* skip_decode) +{ + int index = blockIdx.x * gridDim.x + threadIdx.x; + for (int i = index; i < batch_size; i += gridDim.x * blockDim.x) { + uint k = top_ks_size > 1 ? top_ks[i] : top_k; + float p = top_ps_size > 1 ? top_ps[i] : top_p; + if (k == 0 && p == 0.0f) { + // Invalid runtime topk/topp combination. Use a greedy + // decoding as default and topk sampling will handle. + printf("[WARNING] Invalid runtime topk/topp combination for token %d (topk: %d, topp: %f). " + "Use a greedy decoding as default.\n", + i, + k, + p); + k = 1; + } + top_ks[i] = k; + // Clip p value if it is out of range. range = [0.0, 1.0]. + top_ps[i] = p < 0.0f ? 0.0f : (p > 1.0f ? 1.0f : p); + if (p < 0.0f || p > 1.0f) { + printf("[WARNING] topp (%f) is out of range ([0.0, 1.0f]) for token %d" + " clip to closest number %f.\n", + p, + i, + top_ps[i]); + } + skip_decode[i] = k > 0; + } +} + template void TopPSamplingLayer::allocateBuffer() { @@ -32,9 +71,10 @@ void TopPSamplingLayer::allocateBuffer() } template -void TopPSamplingLayer::allocateBuffer(size_t batch_size, size_t top_k, float top_p) +void TopPSamplingLayer::allocateBuffer(size_t batch_size, Tensor top_k, Tensor top_p) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); + BaseSamplingLayer::allocateBuffer(batch_size, top_k, top_p); invokeTopPSampling(nullptr, // workspace sampling_workspace_size_, cub_temp_storage_size_, @@ -51,14 +91,15 @@ void TopPSamplingLayer::allocateBuffer(size_t batch_size, size_t top_k, float batch_size, vocab_size_padded_, nullptr, - top_p, + top_p.size() > 0 ? 
top_p.max() : 0.0f, stream_, - cuda_device_prop_); - + cuda_device_prop_, + skip_decode_buf_); sampling_workspace_ = allocator_->reMalloc(sampling_workspace_, sampling_workspace_size_, true); - curandstate_buf_ = reinterpret_cast( - allocator_->reMalloc(curandstate_buf_, sizeof(curandState_t) * batch_size, true)); - + runtime_top_k_buf_ = + reinterpret_cast(allocator_->reMalloc(runtime_top_k_buf_, sizeof(uint) * batch_size, false)); + runtime_top_p_buf_ = + reinterpret_cast(allocator_->reMalloc(runtime_top_p_buf_, sizeof(float) * batch_size, false)); topp_id_vals_buf_ = reinterpret_cast( allocator_->reMalloc(topp_id_vals_buf_, sizeof(int) * batch_size * vocab_size_padded_, false)); topp_offset_buf_ = @@ -72,28 +113,77 @@ template void TopPSamplingLayer::freeBuffer() { FT_LOG_DEBUG(__PRETTY_FUNCTION__); - if (is_allocate_buffer_ == true) { - allocator_->free(sampling_workspace_); - allocator_->free(curandstate_buf_); - allocator_->free(topp_id_vals_buf_); - allocator_->free(topp_offset_buf_); - allocator_->free(begin_topp_offset_buf_); - is_allocate_buffer_ = false; + if (is_allocate_buffer_) { + allocator_->free((void**)(&sampling_workspace_)); + allocator_->free((void**)(&topp_id_vals_buf_)); + allocator_->free((void**)(&topp_offset_buf_)); + allocator_->free((void**)(&begin_topp_offset_buf_)); + allocator_->free((void**)(&runtime_top_k_buf_)); + allocator_->free((void**)(&runtime_top_p_buf_)); } + BaseSamplingLayer::freeBuffer(); + is_allocate_buffer_ = false; } template -void TopPSamplingLayer::invokeInitialize(size_t batch_size, - unsigned long long random_seed, - curandState_t* curandstate_buf) +void TopPSamplingLayer::setup(const size_t batch_size, + const size_t beam_width, + const std::unordered_map* runtime_args) { + // Set up the sampling layer for given runtime arguments. + // + // runtime_args: + // runtime_top_k [1] or [batch_size] on cpu, optional. + // runtime_top_p [1] or [batch_size] on cpu, optional + // temperature [1] or [batch_size] on cpu, optional + // repetition_penalty [1] or [batch_size] on cpu, optional + FT_LOG_DEBUG(__PRETTY_FUNCTION__); - invokeCurandInitialize(curandstate_buf, batch_size, random_seed, stream_); - sync_check_cuda_error(); + BaseSamplingLayer::setup(batch_size, beam_width, runtime_args); + const Tensor runtime_top_p = runtime_args->count("runtime_top_p") ? runtime_args->at("runtime_top_p") : Tensor(); + const size_t runtime_top_p_size = runtime_top_p.size(); + if (runtime_top_p_size == 0) { + std::fill_n(skip_decode_, batch_size, true); + return; + } + + uint tmp_top_k = 0; + const Tensor runtime_top_k = runtime_args->count("runtime_top_k") ? 
+ runtime_args->at("runtime_top_k") : + Tensor(MEMORY_CPU, TYPE_UINT32, {1}, &tmp_top_k); + const size_t runtime_top_k_size = runtime_top_k.size(); + + uint top_k = runtime_top_k.getVal(); + float top_p = runtime_top_p.getVal(); + + if (runtime_top_k_size > 1) { + FT_CHECK(runtime_top_k.size() == batch_size); + cudaH2Dcpy(runtime_top_k_buf_, runtime_top_k.getPtr(), batch_size); + } + if (runtime_top_p_size > 1) { + FT_CHECK(runtime_top_p.size() == batch_size); + cudaH2Dcpy(runtime_top_p_buf_, runtime_top_p.getPtr(), batch_size); + } + + dim3 block(std::min((int)batch_size, 1024)); + dim3 grid(div_up((int)batch_size, (int)block.x)); + set_topp_runtime_args<<>>(batch_size, + top_k, + runtime_top_k_buf_, + runtime_top_k_size, + top_p, + runtime_top_p_buf_, + runtime_top_p_size, + skip_decode_buf_); + cudaAutoCpy(skip_decode_, skip_decode_buf_, batch_size, stream_); + float* runtime_top_ps = new float[batch_size]; + cudaAutoCpy(runtime_top_ps, runtime_top_p_buf_, batch_size, stream_); + runtime_max_top_p_ = *std::max_element(runtime_top_ps, runtime_top_ps + batch_size); + delete[] runtime_top_ps; } template -void TopPSamplingLayer::runSampling(std::vector* output_tensors, +void TopPSamplingLayer::runSampling(std::vector* output_tensors, const std::vector* input_tensors) { // input_tensors: @@ -134,7 +224,7 @@ void TopPSamplingLayer::runSampling(std::vector* o } template -void TopPSamplingLayer::runSampling(std::unordered_map* output_tensors, +void TopPSamplingLayer::runSampling(std::unordered_map* output_tensors, const std::unordered_map* input_tensors) { // input_tensors: @@ -144,11 +234,6 @@ void TopPSamplingLayer::runSampling(std::unordered_map* // max_input_length [1] on cpu // input_lengths [local_batch_size] // ite [1] on cpu - // runtime_top_p [1] or [batch_size] on cpu, optional - // temperature [1] or [batch_size] on cpu, optional - // len_penalty [1] or [batch_size] on cpu, optional - // repetition_penalty [1] or [batch_size] on cpu, optional - // random_seed [1] or [batch_size] on cpu, optional // output_tensors: // output_ids [max_seq_len, batch_size] @@ -162,32 +247,22 @@ void TopPSamplingLayer::runSampling(std::unordered_map* FT_CHECK(input_tensors->size() >= 6); FT_CHECK(output_tensors->size() >= 3); - const int batch_size = output_tensors->at("output_ids").shape[1]; + const int batch_size = output_tensors->at("output_ids").shape[1]; const int local_batch_size = input_tensors->at("logits").shape[0]; - const int step = *((int*)input_tensors->at("step").data); - const int ite = *((int*)input_tensors->at("ite").data); + const int step = *((int*)input_tensors->at("step").data); + const int ite = *((int*)input_tensors->at("ite").data); - const float runtime_top_p = input_tensors->at("runtime_top_p").shape[0] == 1 ? - input_tensors->at("runtime_top_p").getVal(0) : - input_tensors->at("runtime_top_p").getVal(ite * local_batch_size); - allocateBuffer(batch_size, 0, runtime_top_p); + // in case of skip any, the logit value is already copied and processed. + T* logits = !skip_any_ ? input_tensors->at("logits").getPtr() : runtime_logits_buf_; invokeTopPInitialize( topp_id_vals_buf_, topp_offset_buf_, begin_topp_offset_buf_, local_batch_size, vocab_size_padded_, stream_); sync_check_cuda_error(); - if (input_tensors->find("random_seed") != input_tensors->end()) { - unsigned long long int random_seed = - input_tensors->at("random_seed").shape[0] == 1 ? 
- (unsigned long long int)input_tensors->at("random_seed").getVal(0) : - (unsigned long long int)input_tensors->at("random_seed").getVal(ite * local_batch_size); - invokeInitialize(local_batch_size, random_seed, curandstate_buf_ + ite * local_batch_size); - } - - invokeAddBiasSoftMax((T*)(input_tensors->at("logits").data), + invokeAddBiasSoftMax(logits, (T*)(nullptr), - (const int*)input_tensors->at("end_id").data, - (bool*)output_tensors->at("finished").data, + input_tensors->at("end_id").getPtr(), + output_tensors->at("finished").getPtr(), local_batch_size, vocab_size_padded_, vocab_size_, @@ -198,43 +273,47 @@ void TopPSamplingLayer::runSampling(std::unordered_map* output_tensors->count("cum_log_probs") ? output_tensors->at("cum_log_probs").getPtr() : nullptr; float* output_log_probs = output_tensors->count("output_log_probs") ? output_tensors->at("output_log_probs").getPtr() : nullptr; - invokeTopPSampling(sampling_workspace_, - sampling_workspace_size_, - cub_temp_storage_size_, - ((int*)output_tensors->at("output_ids").data) + step * batch_size + ite * local_batch_size, - (int*)output_tensors->at("sequence_length").data, - (bool*)output_tensors->at("finished").data, - cum_log_probs, - output_log_probs, - (T*)(input_tensors->at("logits").data), - topp_id_vals_buf_, - topp_offset_buf_, - begin_topp_offset_buf_, - curandstate_buf_ + ite * local_batch_size, - local_batch_size, - vocab_size_padded_, - (const int*)input_tensors->at("end_id").data, - runtime_top_p, - stream_, - cuda_device_prop_); + + invokeBatchTopPSampling( + sampling_workspace_, + sampling_workspace_size_, + cub_temp_storage_size_, + output_tensors->at("output_ids").getPtrWithOffset(step * batch_size + ite * local_batch_size), + output_tensors->at("sequence_length").getPtr(), + output_tensors->at("finished").getPtr(), + cum_log_probs, + output_log_probs, + logits, + topp_id_vals_buf_, + topp_offset_buf_, + begin_topp_offset_buf_, + curandstate_buf_ + ite * local_batch_size, + local_batch_size, + vocab_size_padded_, + input_tensors->at("end_id").getPtr(), + runtime_max_top_p_, + runtime_top_p_buf_ + ite * local_batch_size, + stream_, + cuda_device_prop_, + skip_decode_buf_ + ite * local_batch_size); sync_check_cuda_error(); } template -TopPSamplingLayer::TopPSamplingLayer(size_t max_batch_size, - size_t vocab_size, - size_t vocab_size_padded, - int end_id, - float top_p, +TopPSamplingLayer::TopPSamplingLayer(size_t max_batch_size, + size_t vocab_size, + size_t vocab_size_padded, + int end_id, + float top_p, unsigned long long random_seed, - float temperature, - float len_penalty, - float repetition_penalty, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - cudaDeviceProp* cuda_device_prop): + float temperature, + float len_penalty, + float repetition_penalty, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + cudaDeviceProp* cuda_device_prop): BaseSamplingLayer(max_batch_size, vocab_size, vocab_size_padded, diff --git a/src/fastertransformer/layers/sampling_layers/TopPSamplingLayer.h b/src/fastertransformer/layers/sampling_layers/TopPSamplingLayer.h index 2895d39fb..9adcb0577 100644 --- a/src/fastertransformer/layers/sampling_layers/TopPSamplingLayer.h +++ b/src/fastertransformer/layers/sampling_layers/TopPSamplingLayer.h @@ -24,19 +24,22 @@ namespace fastertransformer { template class TopPSamplingLayer: public BaseSamplingLayer { private: - void runSampling(std::vector* 
output_tensors, + void runSampling(std::vector* output_tensors, const std::vector* input_tensors) override; - void runSampling(std::unordered_map* output_tensors, + void runSampling(std::unordered_map* output_tensors, const std::unordered_map* input_tensors) override; void allocateBuffer() override; - void allocateBuffer(size_t batch_size, size_t top_k, float top_p) override; + void allocateBuffer(size_t batch_size, Tensor top_k, Tensor top_p) override; void freeBuffer() override; - void invokeInitialize(size_t batch_size, unsigned long long random_seed, curandState_t* curandstate_buf) override; - int* topp_id_vals_buf_; - int* topp_offset_buf_; - int* begin_topp_offset_buf_; + uint* runtime_top_k_buf_ = nullptr; + float* runtime_top_p_buf_ = nullptr; + float runtime_max_top_p_; + + int* topp_id_vals_buf_; + int* topp_offset_buf_; + int* begin_topp_offset_buf_; size_t cub_temp_storage_size_; using BaseSamplingLayer::vocab_size_; @@ -45,6 +48,11 @@ class TopPSamplingLayer: public BaseSamplingLayer { using BaseSamplingLayer::sampling_workspace_size_; using BaseSamplingLayer::sampling_workspace_; using BaseSamplingLayer::curandstate_buf_; + using BaseSamplingLayer::random_seeds_buf_; + using BaseSamplingLayer::skip_decode_buf_; + using BaseSamplingLayer::skip_decode_; + using BaseSamplingLayer::skip_any_; + using BaseSamplingLayer::runtime_logits_buf_; using BaseSamplingLayer::stream_; using BaseSamplingLayer::allocator_; @@ -53,24 +61,26 @@ class TopPSamplingLayer: public BaseSamplingLayer { protected: public: - TopPSamplingLayer(size_t max_batch_size, - size_t vocab_size, - size_t vocab_size_padded, - int end_id, - float top_p, + TopPSamplingLayer(size_t max_batch_size, + size_t vocab_size, + size_t vocab_size_padded, + int end_id, + float top_p, unsigned long long random_seed, - float temperature, - float len_penalty, - float repetition_penalty, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - cudaDeviceProp* cuda_device_prop); - + float temperature, + float len_penalty, + float repetition_penalty, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + cudaDeviceProp* cuda_device_prop); TopPSamplingLayer(TopPSamplingLayer const& top_p_sampling_layer); - ~TopPSamplingLayer(); + + void setup(const size_t batch_size, + const size_t beam_width, + const std::unordered_map* runtime_args) override; }; } // namespace fastertransformer diff --git a/src/fastertransformer/layers/xlnet_attention_layers/XlnetAttentionLayer.cc b/src/fastertransformer/layers/xlnet_attention_layers/XlnetAttentionLayer.cc index 62e1a4698..81803e987 100644 --- a/src/fastertransformer/layers/xlnet_attention_layers/XlnetAttentionLayer.cc +++ b/src/fastertransformer/layers/xlnet_attention_layers/XlnetAttentionLayer.cc @@ -19,12 +19,12 @@ namespace fastertransformer { template -void XlnetAttentionLayer::forward(std::vector* output_tensors, +void XlnetAttentionLayer::forward(std::vector* output_tensors, const std::vector* input_tensors, - const XlnetAttentionWeight* attention_weights) + const XlnetAttentionWeight* attention_weights) { const size_t request_batch_size = input_tensors->at(0).shape[0]; - const size_t request_seq_len = input_tensors->at(0).shape[1]; + const size_t request_seq_len = input_tensors->at(0).shape[1]; FT_CHECK(isValidBatchSize(input_tensors->at(1).shape[0])); FT_CHECK(isValidSeqLen(input_tensors->at(1).shape[2])); @@ -40,11 +40,11 @@ void 
XlnetAttentionLayer::forward(std::vector* out FT_CHECK(input_tensors->at(1).shape[1] == request_seq_len); FT_CHECK(input_tensors->at(2).shape[1] == request_seq_len); - T* out_tensor = (T*)output_tensors->at(0).data; - T* in_tensor = (T*)input_tensors->at(0).data; + T* out_tensor = (T*)output_tensors->at(0).data; + T* in_tensor = (T*)input_tensors->at(0).data; T* attention_mask = (T*)input_tensors->at(1).data; - T* seg_mat = (T*)input_tensors->at(2).data; - T* attr_k_head_r = (T*)input_tensors->at(3).data; + T* seg_mat = (T*)input_tensors->at(2).data; + T* attr_k_head_r = (T*)input_tensors->at(3).data; cublas_wrapper_->stridedBatchedGemm(CUBLAS_OP_N, CUBLAS_OP_N, @@ -161,7 +161,6 @@ void XlnetAttentionLayer::forward(std::vector* out invokeTranspose201(request_batch_size, request_seq_len, head_num_, qk_buf_ef_seg_trans_, qk_buf_ef_seg_, stream_); invokeRelShiftBd(request_batch_size, head_num_, request_seq_len, qk_buf_bd_shift_, qk_buf_bd_, stream_); - invokeCalAttnScore(request_batch_size, head_num_, request_seq_len, @@ -209,15 +208,15 @@ void XlnetAttentionLayer::forward(std::vector* out } template -XlnetAttentionLayer::XlnetAttentionLayer(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - float q_scaling, - cudaStream_t stream, +XlnetAttentionLayer::XlnetAttentionLayer(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + float q_scaling, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward): + IAllocator* allocator, + bool is_free_buffer_after_forward): BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward), max_batch_size_(max_batch_size), max_seq_len_(max_seq_len), @@ -233,33 +232,44 @@ template void XlnetAttentionLayer::allocateBuffer() { if (is_allocate_buffer_ == false) { - k_head_r_ = (T*)allocator_->malloc(sizeof(T) * max_seq_len_ * 2 * hidden_units_, false); - query_buf_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_ * 3, false); - key_buf_ = query_buf_ + max_batch_size_ * max_seq_len_ * hidden_units_; + k_head_r_ = (T*)allocator_->reMalloc(k_head_r_, sizeof(T) * max_seq_len_ * 2 * hidden_units_, false); + query_buf_ = + (T*)allocator_->reMalloc(query_buf_, sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_ * 3, false); + key_buf_ = query_buf_ + max_batch_size_ * max_seq_len_ * hidden_units_; value_buf_ = query_buf_ + 2 * max_batch_size_ * max_seq_len_ * hidden_units_; - q_buf_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false); - k_buf_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false); - qk_buf_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * max_seq_len_ * head_num_, false); - q_buf_bd_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false); - k_buf_bd_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * 2 * hidden_units_, false); - qk_buf_bd_ = - (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * head_num_ * max_seq_len_ * 2, false); - qk_buf_bd_shift_ = - (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * head_num_ * max_seq_len_, false); - q_buf_ef_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false); - k_buf_ef_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * hidden_units_ * 2, false); - qk_buf_ef_ = (T*)allocator_->malloc(sizeof(T) * 
max_batch_size_ * max_seq_len_ * head_num_ * 2, false); - qk_buf_ef_trans_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * head_num_ * 2, false); - qk_buf_ef_seg_ = - (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * max_seq_len_ * head_num_, false); - qk_buf_ef_seg_trans_ = - (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * max_seq_len_ * head_num_, false); - attn_score_ = - (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * max_seq_len_ * head_num_, false); - value_buf_trans_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false); - attn_vec_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false); - attn_vec_trans_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false); - attn_out_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false); + q_buf_ = (T*)allocator_->reMalloc(q_buf_, sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false); + k_buf_ = (T*)allocator_->reMalloc(k_buf_, sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false); + qk_buf_ = (T*)allocator_->reMalloc( + qk_buf_, sizeof(T) * max_batch_size_ * max_seq_len_ * max_seq_len_ * head_num_, false); + q_buf_bd_ = + (T*)allocator_->reMalloc(q_buf_bd_, sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false); + k_buf_bd_ = + (T*)allocator_->reMalloc(k_buf_bd_, sizeof(T) * max_batch_size_ * max_seq_len_ * 2 * hidden_units_, false); + qk_buf_bd_ = (T*)allocator_->reMalloc( + qk_buf_bd_, sizeof(T) * max_batch_size_ * max_seq_len_ * head_num_ * max_seq_len_ * 2, false); + qk_buf_bd_shift_ = (T*)allocator_->reMalloc( + qk_buf_bd_shift_, sizeof(T) * max_batch_size_ * max_seq_len_ * head_num_ * max_seq_len_, false); + q_buf_ef_ = + (T*)allocator_->reMalloc(q_buf_ef_, sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false); + k_buf_ef_ = (T*)allocator_->reMalloc(k_buf_ef_, sizeof(T) * max_batch_size_ * hidden_units_ * 2, false); + qk_buf_ef_ = + (T*)allocator_->reMalloc(qk_buf_ef_, sizeof(T) * max_batch_size_ * max_seq_len_ * head_num_ * 2, false); + qk_buf_ef_trans_ = (T*)allocator_->reMalloc( + qk_buf_ef_trans_, sizeof(T) * max_batch_size_ * max_seq_len_ * head_num_ * 2, false); + qk_buf_ef_seg_ = (T*)allocator_->reMalloc( + qk_buf_ef_seg_, sizeof(T) * max_batch_size_ * max_seq_len_ * max_seq_len_ * head_num_, false); + qk_buf_ef_seg_trans_ = (T*)allocator_->reMalloc( + qk_buf_ef_seg_trans_, sizeof(T) * max_batch_size_ * max_seq_len_ * max_seq_len_ * head_num_, false); + attn_score_ = (T*)allocator_->reMalloc( + attn_score_, sizeof(T) * max_batch_size_ * max_seq_len_ * max_seq_len_ * head_num_, false); + value_buf_trans_ = (T*)allocator_->reMalloc( + value_buf_trans_, sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false); + attn_vec_ = + (T*)allocator_->reMalloc(attn_vec_, sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false); + attn_vec_trans_ = (T*)allocator_->reMalloc( + attn_vec_trans_, sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false); + attn_out_ = + (T*)allocator_->reMalloc(attn_out_, sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false); is_allocate_buffer_ = true; } @@ -293,26 +303,26 @@ template void XlnetAttentionLayer::freeBuffer() { if (is_allocate_buffer_ == true) { - allocator_->free(k_head_r_); - allocator_->free(query_buf_); - allocator_->free(q_buf_); - allocator_->free(k_buf_); - 
allocator_->free(qk_buf_); - allocator_->free(q_buf_bd_); - allocator_->free(k_buf_bd_); - allocator_->free(qk_buf_bd_); - allocator_->free(qk_buf_bd_shift_); - allocator_->free(q_buf_ef_); - allocator_->free(k_buf_ef_); - allocator_->free(qk_buf_ef_); - allocator_->free(qk_buf_ef_trans_); - allocator_->free(qk_buf_ef_seg_); - allocator_->free(qk_buf_ef_seg_trans_); - allocator_->free(attn_score_); - allocator_->free(value_buf_trans_); - allocator_->free(attn_vec_); - allocator_->free(attn_vec_trans_); - allocator_->free(attn_out_); + allocator_->free((void**)(&k_head_r_)); + allocator_->free((void**)(&query_buf_)); + allocator_->free((void**)(&q_buf_)); + allocator_->free((void**)(&k_buf_)); + allocator_->free((void**)(&qk_buf_)); + allocator_->free((void**)(&q_buf_bd_)); + allocator_->free((void**)(&k_buf_bd_)); + allocator_->free((void**)(&qk_buf_bd_)); + allocator_->free((void**)(&qk_buf_bd_shift_)); + allocator_->free((void**)(&q_buf_ef_)); + allocator_->free((void**)(&k_buf_ef_)); + allocator_->free((void**)(&qk_buf_ef_)); + allocator_->free((void**)(&qk_buf_ef_trans_)); + allocator_->free((void**)(&qk_buf_ef_seg_)); + allocator_->free((void**)(&qk_buf_ef_seg_trans_)); + allocator_->free((void**)(&attn_score_)); + allocator_->free((void**)(&value_buf_trans_)); + allocator_->free((void**)(&attn_vec_)); + allocator_->free((void**)(&attn_vec_trans_)); + allocator_->free((void**)(&attn_out_)); is_allocate_buffer_ = false; } @@ -327,5 +337,8 @@ XlnetAttentionLayer::~XlnetAttentionLayer() template class XlnetAttentionLayer; template class XlnetAttentionLayer; +#ifdef ENABLE_BF16 +template class XlnetAttentionLayer<__nv_bfloat16>; +#endif } // namespace fastertransformer diff --git a/src/fastertransformer/layers/xlnet_attention_layers/XlnetAttentionLayer.h b/src/fastertransformer/layers/xlnet_attention_layers/XlnetAttentionLayer.h index 7b855b84d..1c1b0fb5f 100644 --- a/src/fastertransformer/layers/xlnet_attention_layers/XlnetAttentionLayer.h +++ b/src/fastertransformer/layers/xlnet_attention_layers/XlnetAttentionLayer.h @@ -27,11 +27,11 @@ class XlnetAttentionLayer: public BaseLayer { private: // buffer handling size_t max_batch_size_ = 0; - size_t max_seq_len_ = 0; + size_t max_seq_len_ = 0; // metadata - int head_num_; - int size_per_head_; + int head_num_; + int size_per_head_; float q_scaling_; // calculated params @@ -73,39 +73,39 @@ class XlnetAttentionLayer: public BaseLayer { T* attn_out_; public: - XlnetAttentionLayer(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - float q_scaling, - cudaStream_t stream, + XlnetAttentionLayer(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + float q_scaling, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward); + IAllocator* allocator, + bool is_free_buffer_after_forward); XlnetAttentionLayer(XlnetAttentionLayer const& attention_layer); ~XlnetAttentionLayer(); - void oneToManyCublasGemm(T* d_A, - T* d_B, - T* d_C, + void oneToManyCublasGemm(T* d_A, + T* d_B, + T* d_C, cublasOperation_t transa, cublasOperation_t transb, - int v_m, - int v_n, - int v_k, - int lda, - int strideA, - int ldb, - int strideB, - int ldc, - int strideC, - int batch); + int v_m, + int v_n, + int v_k, + int lda, + int strideA, + int ldb, + int strideB, + int ldc, + int strideC, + int batch); - void forward(std::vector* output_tensors, + void forward(std::vector* output_tensors, const std::vector* input_tensors, - const 
XlnetAttentionWeight* xlnet_attention_weights); + const XlnetAttentionWeight* xlnet_attention_weights); }; } // namespace fastertransformer diff --git a/src/fastertransformer/models/BaseWeight.h b/src/fastertransformer/models/BaseWeight.h new file mode 100644 index 000000000..f610baa07 --- /dev/null +++ b/src/fastertransformer/models/BaseWeight.h @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#pragma once + +namespace fastertransformer { + +template +struct FtWeight { + +public: + std::string name_; + std::vector shape_; + size_t size_ = 0; + T* ptr_ = nullptr; + + FtWeight() {} + FtWeight(const std::string name, const std::vector shape, T* ptr): name_(name), shape_(shape), ptr_(ptr) + { + size_ = 1; + for (uint i = 0; i < shape_.size(); i++) { + size_ *= shape_[i]; + } + } + + ~FtWeight() + { + size_ = 0; + ptr_ = nullptr; + } +}; + +} // namespace fastertransformer \ No newline at end of file diff --git a/src/fastertransformer/models/CMakeLists.txt b/src/fastertransformer/models/CMakeLists.txt index af33e764c..7cce05b88 100644 --- a/src/fastertransformer/models/CMakeLists.txt +++ b/src/fastertransformer/models/CMakeLists.txt @@ -21,6 +21,7 @@ add_subdirectory(xlnet) add_subdirectory(t5) add_subdirectory(gptj) +add_subdirectory(gptneox) add_subdirectory(multi_gpu_gpt) add_subdirectory(swin) add_subdirectory(swin_int8) diff --git a/src/fastertransformer/models/bert/Bert.cc b/src/fastertransformer/models/bert/Bert.cc index ac727df8f..7acbb97f1 100644 --- a/src/fastertransformer/models/bert/Bert.cc +++ b/src/fastertransformer/models/bert/Bert.cc @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
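Stepping back to the new models/BaseWeight.h added a few hunks above: FtWeight is a small non-owning holder that records a weight's name, shape, and raw pointer and derives its element count from the shape. The sketch below is illustrative only and is not part of the patch; the weight name and shape values are made up, and it assumes the struct is templated on the element type with shape_ holding size_t dimensions.

#include <vector>
#include <string>
#include "src/fastertransformer/models/BaseWeight.h"

// Illustrative only: wrap an externally owned device buffer in the new
// FtWeight metadata holder. FtWeight never allocates or frees the pointer;
// its destructor merely resets size_ and ptr_.
void describe_qkv_weight(float* d_qkv_kernel)
{
    using fastertransformer::FtWeight;
    // Hypothetical name/shape for a fused QKV kernel of a 1024-wide model.
    FtWeight<float> w("layers.0.attention.query_key_value.weight",
                      {1024, 3 * 1024},
                      d_qkv_kernel);
    // w.size_ is the product of the shape dims: 1024 * 3072 elements.
    (void)w;
}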
@@ -22,82 +22,93 @@ namespace fastertransformer { template void Bert::initialize() { - if ((attention_type_ == AttentionType::FUSED_MHA || attention_type_ == AttentionType::FUSED_PADDED_MHA) - && std::is_same::value == true && max_seq_len_ <= 384) { - attention_layer_ = new FusedAttentionLayer(max_batch_size_, - max_seq_len_, - head_num_, - size_per_head_, - sm_, - q_scaling_, - stream_, - cublas_wrapper_, - allocator_, - is_free_buffer_after_forward_, - sparse_); + if (std::is_same::value + && (attention_type_ == AttentionType::FUSED_MHA || attention_type_ == AttentionType::FUSED_PADDED_MHA)) { + fused_attention_layer_ = new FusedAttentionLayer(0, + 0, + head_num_ / tensor_para_.world_size_, + size_per_head_, + head_num_ * size_per_head_, + sm_, + q_scaling_, + stream_, + cublas_wrapper_, + allocator_, + is_free_buffer_after_forward_, + sparse_); } - else if (attention_type_ == AttentionType::UNFUSED_MHA || attention_type_ == AttentionType::UNFUSED_PADDED_MHA) { - attention_layer_ = new UnfusedAttentionLayer(max_batch_size_, - max_seq_len_, - head_num_, - size_per_head_, - q_scaling_, - stream_, - cublas_wrapper_, - allocator_, - is_free_buffer_after_forward_, - sparse_); - } - else { - throw std::runtime_error(std::string("[FT][ERROR] Invalid attention type \n")); - } - + unfused_attention_layer_ = new UnfusedAttentionLayer(0, + 0, + head_num_ / tensor_para_.world_size_, + size_per_head_, + head_num_ * size_per_head_, + q_scaling_, + stream_, + cublas_wrapper_, + allocator_, + is_free_buffer_after_forward_, + sparse_); + + bool use_gated_activation = activation_type_ == ActivationType::GeGLU || activation_type_ == ActivationType::ReGLU; if (activation_type_ == ActivationType::Gelu) { - ffn_layer_ = new GeluFfnLayer(max_batch_size_, - max_seq_len_, - head_num_, - size_per_head_, - inter_size_, - stream_, - cublas_wrapper_, - allocator_, - is_free_buffer_after_forward_, - sparse_); + ffn_layer_ = new TensorParallelGeluFfnLayer(0, + 0, + head_num_, + size_per_head_, + inter_size_, + tensor_para_, + stream_, + cublas_wrapper_, + allocator_, + true, + is_free_buffer_after_forward_, + sparse_, + 0, + use_gated_activation, + custom_all_reduce_comm_, + enable_custom_all_reduce_); } else if (activation_type_ == ActivationType::Relu) { - ffn_layer_ = new ReluFfnLayer(max_batch_size_, - max_seq_len_, - head_num_, - size_per_head_, - inter_size_, - stream_, - cublas_wrapper_, - allocator_, - is_free_buffer_after_forward_, - sparse_); + ffn_layer_ = new TensorParallelReluFfnLayer(0, + 0, + head_num_, + size_per_head_, + inter_size_, + tensor_para_, + stream_, + cublas_wrapper_, + allocator_, + true, + is_free_buffer_after_forward_, + sparse_, + use_gated_activation, + custom_all_reduce_comm_, + enable_custom_all_reduce_); } } template -Bert::Bert(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t inter_size, - size_t num_layer, - int sm, - float q_scaling, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - AttentionType attention_type, - bool sparse, - ActivationType activation_type, - LayerNormType layernorm_type): +Bert::Bert(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + int sm, + float q_scaling, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + AttentionType attention_type, + bool sparse, + ActivationType activation_type, + 
LayerNormType layernorm_type, + NcclParam tensor_para, + NcclParam pipeline_para, + std::shared_ptr custom_all_reduce_comm, + bool enable_custom_all_reduce): BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward), - max_batch_size_(max_batch_size), - max_seq_len_(max_seq_len), head_num_(head_num), size_per_head_(size_per_head), inter_size_(inter_size), @@ -108,35 +119,87 @@ Bert::Bert(size_t max_batch_size, attention_type_(attention_type), sparse_(sparse), activation_type_(activation_type), - layernorm_type_(layernorm_type) + layernorm_type_(layernorm_type), + tensor_para_(tensor_para), + pipeline_para_(pipeline_para), + custom_all_reduce_comm_(custom_all_reduce_comm), + enable_custom_all_reduce_(enable_custom_all_reduce) { initialize(); } +template +Bert::Bert(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + int sm, + float q_scaling, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + AttentionType attention_type, + bool sparse, + ActivationType activation_type, + LayerNormType layernorm_type): + Bert(max_batch_size, + max_seq_len, + head_num, + size_per_head, + inter_size, + num_layer, + sm, + q_scaling, + stream, + cublas_wrapper, + allocator, + is_free_buffer_after_forward, + attention_type, + sparse, + activation_type, + layernorm_type, + NcclParam(0, 1), + NcclParam(0, 1), + nullptr, + false) +{ +} + template Bert::Bert(Bert const& bert): - BaseLayer(bert), - max_batch_size_(bert.max_batch_size_), - max_seq_len_(bert.max_seq_len_), - head_num_(bert.head_num_), - size_per_head_(bert.size_per_head_), - inter_size_(bert.inter_size_), - hidden_units_(bert.hidden_units_), - num_layer_(bert.num_layer_), - sm_(bert.sm_), - q_scaling_(bert.q_scaling_), - attention_type_(bert.attention_type_), - sparse_(bert.sparse_), - activation_type_(bert.activation_type_), - layernorm_type_(bert.layernorm_type_) + Bert(0, + 0, + bert.head_num_, + bert.size_per_head_, + bert.inter_size_, + bert.num_layer_, + bert.sm_, + bert.q_scaling_, + bert.stream_, + bert.cublas_wrapper_, + bert.allocator_, + bert.is_free_buffer_after_forward_, + bert.attention_type_, + bert.sparse_, + bert.activation_type_, + bert.layernorm_type_, + bert.tensor_para_, + bert.pipeline_para_, + bert.custom_all_reduce_comm_, + bert.enable_custom_all_reduce_) { - initialize(); } template Bert::~Bert() { - delete attention_layer_; + if (fused_attention_layer_ != nullptr) { + delete fused_attention_layer_; + } + delete unfused_attention_layer_; delete ffn_layer_; freeBuffer(); } @@ -144,39 +207,14 @@ Bert::~Bert() template void Bert::allocateBuffer() { - FT_LOG_DEBUG(__PRETTY_FUNCTION__); - if (is_allocate_buffer_ == false) { - token_num_ = (size_t*)allocator_->malloc(sizeof(size_t) * 1, false); - padding_offset_ = (int*)allocator_->malloc(sizeof(int) * max_batch_size_ * max_seq_len_, false); - trt_mha_padding_offset_ = (int*)allocator_->malloc(sizeof(int) * (2 * max_batch_size_ + 1), false); - - attention_mask_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * max_seq_len_, false); - - bert_in_buffer_ = - (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * head_num_ * size_per_head_, false); - attn_out_buf_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false); - bert_out_buffer_ = - (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * head_num_ * size_per_head_, false); - - if 
(layernorm_type_ == LayerNormType::post_layernorm) { - normed_from_tensor_ = nullptr; - normed_attn_out_buf_ = nullptr; - } - else { - normed_from_tensor_ = - (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false); - normed_attn_out_buf_ = - (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false); - } - is_allocate_buffer_ = true; - } + FT_CHECK(false); } template void Bert::allocateBuffer(size_t batch_size, size_t seq_len) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); - token_num_ = (size_t*)allocator_->reMalloc(token_num_, sizeof(size_t) * 1, false); + token_num_ = (size_t*)allocator_->reMalloc(token_num_, sizeof(size_t) * 1, false); padding_offset_ = (int*)allocator_->reMalloc(padding_offset_, sizeof(int) * batch_size * seq_len, false); trt_mha_padding_offset_ = (int*)allocator_->reMalloc(trt_mha_padding_offset_, sizeof(int) * (2 * batch_size + 1), false); @@ -185,12 +223,12 @@ void Bert::allocateBuffer(size_t batch_size, size_t seq_len) bert_in_buffer_ = (T*)allocator_->reMalloc(bert_in_buffer_, sizeof(T) * batch_size * seq_len * head_num_ * size_per_head_, false); - attn_out_buf_ = (T*)allocator_->reMalloc(attn_out_buf_, sizeof(T) * batch_size * seq_len * hidden_units_, false); + attn_out_buf_ = (T*)allocator_->reMalloc(attn_out_buf_, sizeof(T) * batch_size * seq_len * hidden_units_, false); bert_out_buffer_ = (T*)allocator_->reMalloc( bert_out_buffer_, sizeof(T) * batch_size * seq_len * head_num_ * size_per_head_, false); if (layernorm_type_ == LayerNormType::post_layernorm) { - normed_from_tensor_ = nullptr; + normed_from_tensor_ = nullptr; normed_attn_out_buf_ = nullptr; } else { @@ -204,261 +242,413 @@ void Bert::allocateBuffer(size_t batch_size, size_t seq_len) template void Bert::freeBuffer() { - allocator_->free(token_num_); - allocator_->free(padding_offset_); - allocator_->free(trt_mha_padding_offset_); + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + allocator_->free((void**)(&token_num_)); + allocator_->free((void**)(&padding_offset_)); + allocator_->free((void**)(&trt_mha_padding_offset_)); - allocator_->free(attention_mask_); - allocator_->free(bert_in_buffer_); - allocator_->free(attn_out_buf_); - allocator_->free(bert_out_buffer_); + allocator_->free((void**)(&attention_mask_)); + allocator_->free((void**)(&bert_in_buffer_)); + allocator_->free((void**)(&attn_out_buf_)); + allocator_->free((void**)(&bert_out_buffer_)); if (layernorm_type_ == LayerNormType::post_layernorm) { - normed_from_tensor_ = nullptr; + normed_from_tensor_ = nullptr; normed_attn_out_buf_ = nullptr; } else { - allocator_->free(normed_from_tensor_); - allocator_->free(normed_attn_out_buf_); + allocator_->free((void**)(&normed_from_tensor_)); + allocator_->free((void**)(&normed_attn_out_buf_)); } } template -void Bert::forward(std::vector* output_tensors, - const std::vector* input_tensors, - const BertWeight* bert_weights) +bool Bert::isValidLayerParallelId(uint l) { - // input_tensors: - // input_query [batch, seqlen, hidden] - // sequence_length [batch] - // output tensors: - // output hidden state [batch, seqlen, hidden] - - const size_t request_batch_size = input_tensors->at(0).shape[0]; - const size_t request_seq_len = input_tensors->at(0).shape[1]; - FT_CHECK(input_tensors->size() == 2); - FT_CHECK(isValidBatchSize(request_batch_size)); - FT_CHECK(isValidSeqLen(request_seq_len)); - FT_CHECK(request_batch_size == input_tensors->at(1).shape[0]); - FT_CHECK(input_tensors->at(0).shape.size() == 3); - FT_CHECK(input_tensors->at(1).shape.size() == 1); 
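
[Editor's note] The hunks around here replace the old std::vector<Tensor> forward path (inputs "input_query" and "sequence_length") with a named TensorMap interface keyed by "input_hidden_state", "sequence_lengths", and "output_hidden_state". A minimal calling sketch under stated assumptions: the function name run_bert_fp16 and the device buffers d_input, d_seq_len, d_output are illustrative and not part of the patch; shapes follow the comments in the new forward().

#include "src/fastertransformer/models/bert/Bert.h"

namespace ft = fastertransformer;

// Sketch: drive the new TensorMap-based forward(). All pointers are assumed to be
// valid device buffers owned by the caller; bert and bert_weights are assumed to be
// already constructed and loaded.
void run_bert_fp16(ft::Bert<half>&             bert,
                   const ft::BertWeight<half>& bert_weights,
                   half*                       d_input,    // [batch, seq_len, hidden]
                   const int*                  d_seq_len,  // [batch]
                   half*                       d_output,   // [batch, seq_len, hidden]
                   size_t                      batch_size,
                   size_t                      seq_len,
                   size_t                      hidden_units)
{
    const ft::DataType data_type = ft::getTensorType<half>();

    ft::TensorMap input_tensors(
        {{"input_hidden_state",
          ft::Tensor{ft::MEMORY_GPU, data_type, std::vector<size_t>{batch_size, seq_len, hidden_units}, d_input}},
         {"sequence_lengths",
          ft::Tensor{ft::MEMORY_GPU, ft::TYPE_INT32, std::vector<size_t>{batch_size}, d_seq_len}}});

    ft::TensorMap output_tensors(
        {{"output_hidden_state",
          ft::Tensor{ft::MEMORY_GPU, data_type, std::vector<size_t>{batch_size, seq_len, hidden_units}, d_output}}});

    bert.forward(&output_tensors, &input_tensors, &bert_weights);
}

The older vector-based forward() is kept as a thin wrapper that builds exactly this TensorMap, so existing callers keep working.
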
- allocateBuffer(request_batch_size, request_seq_len); + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + int local_num_layer = (int)(ceil(num_layer_ * 1.0f / pipeline_para_.world_size_)); + return l < num_layer_ && (l >= local_num_layer * pipeline_para_.rank_) + && (l < local_num_layer * (pipeline_para_.rank_ + 1)); +} - const int* sequence_lengths = reinterpret_cast(input_tensors->at(1).data); +template +bool Bert::isFirstLayerParallelId(uint l) +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + int local_num_layer = (int)(ceil(num_layer_ * 1.0f / pipeline_para_.world_size_)); + return l < num_layer_ && (l == local_num_layer * pipeline_para_.rank_); +} - size_t h_token_num; - T* bert_input_ptr; - T* bert_output_ptr; - Tensor* padding_offset_tensor_ptr; +template +bool Bert::isLastLayerParallelId(uint l) +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + int local_num_layer = (int)(ceil(num_layer_ * 1.0f / pipeline_para_.world_size_)); + return l < num_layer_ && (l == local_num_layer * (pipeline_para_.rank_ + 1) - 1); +} - // preprocess (remove padding and build mask) - switch (attention_type_) { - case AttentionType::UNFUSED_MHA: { - invokeBuildEncoderAttentionMask( - attention_mask_, sequence_lengths, request_batch_size, request_seq_len, stream_); - sync_check_cuda_error(); - invokeGetPaddingOffset(&h_token_num, - token_num_, - padding_offset_, - sequence_lengths, - request_batch_size, - request_seq_len, - stream_); - - invokeRemovePadding(bert_in_buffer_, - (const T*)input_tensors->at(0).data, - padding_offset_, - h_token_num, - head_num_ * size_per_head_, - stream_); - sync_check_cuda_error(); - bert_input_ptr = bert_in_buffer_; - bert_output_ptr = bert_out_buffer_; +template +int Bert::getFirstLayerParallelId() +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + int local_num_layer = (int)(ceil(num_layer_ * 1.0f / pipeline_para_.world_size_)); + return local_num_layer * pipeline_para_.rank_; +} - padding_offset_tensor_ptr = - new Tensor(MEMORY_GPU, TYPE_INT32, std::vector{h_token_num}, padding_offset_); - break; - } - case AttentionType::UNFUSED_PADDED_MHA: { - invokeBuildEncoderAttentionMask( - attention_mask_, sequence_lengths, request_batch_size, request_seq_len, stream_); - sync_check_cuda_error(); - h_token_num = request_batch_size * request_seq_len; - bert_input_ptr = (T*)input_tensors->at(0).data; - bert_output_ptr = (T*)output_tensors->at(0).data; - padding_offset_tensor_ptr = new Tensor(MEMORY_GPU, TYPE_INT32, std::vector{0}, nullptr); - break; - } - case AttentionType::FUSED_MHA: { - invokeGetPaddingOffset(&h_token_num, - token_num_, - padding_offset_, - sequence_lengths, - request_batch_size, - request_seq_len, - stream_); - - invokeRemovePadding(bert_in_buffer_, - (const T*)input_tensors->at(0).data, - padding_offset_, - h_token_num, - head_num_ * size_per_head_, - stream_); - sync_check_cuda_error(); - bert_input_ptr = bert_in_buffer_; - bert_output_ptr = bert_out_buffer_; +template +void Bert::forward(std::vector* output_tensors, + const std::vector* input_tensors, + const BertWeight* bert_weights) +{ + TensorMap input_tensors_map = + TensorMap({{"input_hidden_state", input_tensors->at(0)}, {"sequence_lengths", input_tensors->at(1)}}); + TensorMap output_tensors_map = TensorMap({{"output_hidden_state", output_tensors->at(0)}}); + forward(&output_tensors_map, &input_tensors_map, bert_weights); +} - invokeGetTrtPaddingOffset(trt_mha_padding_offset_, sequence_lengths, request_batch_size, stream_); +template +void Bert::forward(TensorMap* output_tensors, TensorMap* input_tensors, const BertWeight* bert_weights) 
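
[Editor's note] The new isValidLayerParallelId / isFirstLayerParallelId / isLastLayerParallelId / getFirstLayerParallelId helpers above split the encoder layers into contiguous blocks of ceil(num_layer / pipeline_para.world_size) layers per pipeline rank; inside the new forward() below, layers that fail isValidLayerParallelId() are simply skipped on the local rank. A small standalone sketch of that arithmetic (not FT code, just the same ceil-based block partition) for checking which rank owns which layer:

#include <cmath>
#include <cstdio>

// Mirrors Bert<T>::isValidLayerParallelId(): rank r owns layers
// [r * local_num_layer, (r + 1) * local_num_layer), clipped to num_layer.
int owner_rank(int layer, int num_layer, int pipeline_world_size)
{
    const int local_num_layer = (int)std::ceil(num_layer * 1.0f / pipeline_world_size);
    return layer / local_num_layer;
}

int main()
{
    // Example: 12 layers on a 5-stage pipeline -> blocks of 3 layers;
    // ranks 0..3 own layers 0-2, 3-5, 6-8, 9-11 and rank 4 owns none.
    for (int l = 0; l < 12; l++) {
        std::printf("layer %2d -> pipeline rank %d\n", l, owner_rank(l, 12, 5));
    }
    return 0;
}
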
+{ + // input_tensors: + // input_hidden_state [batch, seqlen, hidden] + // sequence_lengths [batch] + // output tensors: + // output_hidden_state [batch, seqlen, hidden] - padding_offset_tensor_ptr = new Tensor( - MEMORY_GPU, TYPE_INT32, std::vector{request_batch_size + 1}, trt_mha_padding_offset_); - break; - } - case AttentionType::FUSED_PADDED_MHA: { - h_token_num = request_batch_size * request_seq_len; - invokeGetTrtPaddingOffset( - trt_mha_padding_offset_, sequence_lengths, request_batch_size, request_seq_len, stream_); - padding_offset_tensor_ptr = new Tensor( - MEMORY_GPU, TYPE_INT32, std::vector{request_batch_size * 2 + 1}, trt_mha_padding_offset_); - bert_input_ptr = (T*)input_tensors->at(0).data; - bert_output_ptr = (T*)output_tensors->at(0).data; - break; + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + const size_t request_batch_size = input_tensors->at("input_hidden_state").shape[0]; + const size_t request_seq_len = input_tensors->at("input_hidden_state").shape[1]; + FT_CHECK(input_tensors->size() >= 2); + FT_CHECK(request_batch_size == input_tensors->at("sequence_lengths").shape[0]); + FT_CHECK(input_tensors->at("input_hidden_state").shape.size() == 3); + FT_CHECK(input_tensors->at("sequence_lengths").shape.size() == 1); + allocateBuffer(request_batch_size, request_seq_len); + + const int* sequence_lengths = input_tensors->at("sequence_lengths").getPtr(); + + DataType data_type = getTensorType(); + const size_t local_batch_size = getLocalBatchSize(request_batch_size, request_seq_len, pipeline_para_.world_size_); + FT_CHECK(request_batch_size % local_batch_size == 0); + const size_t iteration_num = request_batch_size / local_batch_size; + AttentionType attention_type = attention_type_; + if (fused_attention_layer_ == nullptr || fused_attention_layer_->isValidSeqLen(request_seq_len) == false) { + if (attention_type == AttentionType::FUSED_MHA) { + FT_LOG_WARNING("Because the input is invalid for fused mha, switch to unfused mha."); + attention_type = AttentionType::UNFUSED_MHA; } - default: { - throw std::runtime_error(std::string("[FT][ERROR] Invalid attention type \n")); + else if (attention_type == AttentionType::FUSED_PADDED_MHA) { + FT_LOG_WARNING("Because the input is invalid for fused mha, switch to unfused mha."); + attention_type = AttentionType::UNFUSED_PADDED_MHA; } } - DataType data_type = getTensorType(); - for (uint i = 0; i < num_layer_; i++) { - const T* from_tensor = (const T*)(i == 0 ? 
bert_input_ptr : bert_output_ptr); - T* out_tensor = bert_output_ptr; - - if (layernorm_type_ == LayerNormType::pre_layernorm) { - invokeGeneralLayerNorm(normed_from_tensor_, - from_tensor, - bert_weights->bert_layer_weights[i].attn_layernorm_weights.gamma, - bert_weights->bert_layer_weights[i].attn_layernorm_weights.beta, - h_token_num, - hidden_units_, - stream_); + for (uint ite = 0; ite < iteration_num; ite++) { + Tensor* padding_offset_tensor_ptr = nullptr; + const size_t hidden_offset = ite * local_batch_size * request_seq_len * hidden_units_; + size_t h_token_num = 0; + + T* bert_input_ptr; + T* bert_output_ptr; + + switch (attention_type) { + case AttentionType::UNFUSED_MHA: { + invokeBuildEncoderAttentionMask( + attention_mask_, + input_tensors->at("sequence_lengths").getPtrWithOffset(ite * local_batch_size), + local_batch_size, + request_seq_len, + stream_); + sync_check_cuda_error(); + invokeGetPaddingOffset( + &h_token_num, + token_num_, + padding_offset_, + input_tensors->at("sequence_lengths").getPtrWithOffset(ite * local_batch_size), + local_batch_size, + request_seq_len, + stream_); + + invokeRemovePadding(bert_in_buffer_, + input_tensors->at("input_hidden_state").getPtrWithOffset(hidden_offset), + padding_offset_, + h_token_num, + head_num_ * size_per_head_, + stream_); + sync_check_cuda_error(); + + padding_offset_tensor_ptr = + new Tensor(MEMORY_GPU, TYPE_INT32, std::vector{h_token_num}, padding_offset_); + bert_input_ptr = bert_in_buffer_; + bert_output_ptr = bert_out_buffer_; + sync_check_cuda_error(); + break; + } + case AttentionType::UNFUSED_PADDED_MHA: { + invokeBuildEncoderAttentionMask( + attention_mask_, + input_tensors->at("sequence_lengths").getPtrWithOffset(ite * local_batch_size), + local_batch_size, + request_seq_len, + stream_); + sync_check_cuda_error(); + h_token_num = local_batch_size * request_seq_len; + bert_input_ptr = input_tensors->at("input_hidden_state").getPtrWithOffset(hidden_offset); + bert_output_ptr = output_tensors->at("output_hidden_state").getPtrWithOffset(hidden_offset); + padding_offset_tensor_ptr = new Tensor(MEMORY_GPU, TYPE_INT32, std::vector{0}, nullptr); + sync_check_cuda_error(); + break; + } + case AttentionType::FUSED_MHA: { + invokeGetPaddingOffset( + &h_token_num, + token_num_, + padding_offset_, + input_tensors->at("sequence_lengths").getPtrWithOffset(ite * local_batch_size), + local_batch_size, + request_seq_len, + stream_); + + invokeRemovePadding(bert_in_buffer_, + input_tensors->at("input_hidden_state").getPtrWithOffset(hidden_offset), + padding_offset_, + h_token_num, + head_num_ * size_per_head_, + stream_); + sync_check_cuda_error(); + + invokeGetTrtPaddingOffset( + trt_mha_padding_offset_, + input_tensors->at("sequence_lengths").getPtrWithOffset(ite * local_batch_size), + local_batch_size, + stream_); + + padding_offset_tensor_ptr = new Tensor( + MEMORY_GPU, TYPE_INT32, std::vector{local_batch_size + 1}, trt_mha_padding_offset_); + bert_input_ptr = bert_in_buffer_; + bert_output_ptr = bert_out_buffer_; + sync_check_cuda_error(); + break; + } + case AttentionType::FUSED_PADDED_MHA: { + h_token_num = local_batch_size * request_seq_len; + invokeGetTrtPaddingOffset( + trt_mha_padding_offset_, + input_tensors->at("sequence_lengths").getPtrWithOffset(ite * local_batch_size), + local_batch_size, + request_seq_len, + stream_); + sync_check_cuda_error(); + padding_offset_tensor_ptr = new Tensor( + MEMORY_GPU, TYPE_INT32, std::vector{local_batch_size * 2 + 1}, trt_mha_padding_offset_); + bert_input_ptr = 
input_tensors->at("input_hidden_state").getPtrWithOffset(hidden_offset); + bert_output_ptr = output_tensors->at("output_hidden_state").getPtrWithOffset(hidden_offset); + break; + } + default: { + throw std::runtime_error(std::string("[FT][ERROR] Invalid attention type \n")); + } } - // Attention - { - std::vector attn_input_tensors{ - Tensor{MEMORY_GPU, - data_type, - std::vector{h_token_num, hidden_units_}, - layernorm_type_ == LayerNormType::pre_layernorm ? normed_from_tensor_ : from_tensor}, - Tensor{MEMORY_GPU, - data_type, - std::vector{request_batch_size, 1, request_seq_len, request_seq_len}, - attention_mask_}, - *padding_offset_tensor_ptr}; - std::vector attn_output_tensors{ - Tensor{MEMORY_GPU, data_type, std::vector{h_token_num, hidden_units_}, attn_out_buf_}}; - - attention_layer_->forward( - &attn_output_tensors, &attn_input_tensors, &bert_weights->bert_layer_weights[i].attention_weights); - } + for (uint l = 0; l < num_layer_; l++) { + if (isValidLayerParallelId(l) == false) { + continue; + } + T* from_tensor = l == 0 ? bert_input_ptr : bert_output_ptr; + T* out_tensor = bert_output_ptr; + + if (isFirstLayerParallelId(l) && pipeline_para_.rank_ != 0 && pipeline_para_.world_size_ > 1) { + ftNcclRecv(from_tensor + h_token_num * hidden_units_ / tensor_para_.world_size_ * tensor_para_.rank_, + h_token_num * hidden_units_ / tensor_para_.world_size_, + pipeline_para_.rank_ - 1, + pipeline_para_, + stream_); + if (tensor_para_.world_size_ > 1) { + ftNcclAllGather(from_tensor, + from_tensor, + h_token_num * hidden_units_ / tensor_para_.world_size_, + tensor_para_.rank_, + tensor_para_, + stream_); + } + } + + if (layernorm_type_ == LayerNormType::pre_layernorm) { + invokeGeneralLayerNorm(normed_from_tensor_, + from_tensor, + bert_weights->bert_layer_weights[l].attn_layernorm_weights.gamma, + bert_weights->bert_layer_weights[l].attn_layernorm_weights.beta, + layernorm_eps_, + h_token_num, + hidden_units_, + stream_); + sync_check_cuda_error(); + } + // Attention + { + std::vector attn_input_tensors{ + Tensor{MEMORY_GPU, + data_type, + std::vector{h_token_num, hidden_units_}, + layernorm_type_ == LayerNormType::pre_layernorm ? 
normed_from_tensor_ : from_tensor}, + Tensor{MEMORY_GPU, + data_type, + std::vector{local_batch_size, 1, request_seq_len, request_seq_len}, + attention_mask_}, + *padding_offset_tensor_ptr}; + std::vector attn_output_tensors{ + Tensor{MEMORY_GPU, data_type, std::vector{h_token_num, hidden_units_}, attn_out_buf_}}; + + bool use_custom_all_reduce_kernel = false; + if (enable_custom_all_reduce_ && custom_all_reduce_comm_ != nullptr) { + use_custom_all_reduce_kernel = + custom_all_reduce_comm_->swapInternalBuffer(&attn_output_tensors, h_token_num * hidden_units_); + } + + if (attention_type == AttentionType::FUSED_MHA || attention_type == AttentionType::FUSED_PADDED_MHA) { + fused_attention_layer_->forward(&attn_output_tensors, + &attn_input_tensors, + &bert_weights->bert_layer_weights[l].attention_weights); + } + else if (attention_type == AttentionType::UNFUSED_MHA + || attention_type == AttentionType::UNFUSED_PADDED_MHA) { + unfused_attention_layer_->forward(&attn_output_tensors, + &attn_input_tensors, + &bert_weights->bert_layer_weights[l].attention_weights); + } + + if (tensor_para_.world_size_ > 1) { + if (!use_custom_all_reduce_kernel) { + ftNcclAllReduceSum( + attn_out_buf_, attn_out_buf_, h_token_num * hidden_units_, tensor_para_, stream_); + } + else { + custom_all_reduce_comm_->customAllReduce(h_token_num * hidden_units_, stream_); + } + sync_check_cuda_error(); + } + } + + if (layernorm_type_ == LayerNormType::post_layernorm) { + invokeAddBiasResidualLayerNorm( + attn_out_buf_, + from_tensor, + bert_weights->bert_layer_weights[l].attention_weights.attention_output_weight.bias, + bert_weights->bert_layer_weights[l].attn_layernorm_weights.gamma, + bert_weights->bert_layer_weights[l].attn_layernorm_weights.beta, + layernorm_eps_, + h_token_num, + hidden_units_, + stream_); + } + else if (layernorm_type_ == LayerNormType::pre_layernorm) { + invokeGeneralAddBiasResidualPreLayerNorm( + attn_out_buf_, + normed_attn_out_buf_, + from_tensor, + bert_weights->bert_layer_weights[l].ffn_layernorm_weights.gamma, + bert_weights->bert_layer_weights[l].ffn_layernorm_weights.beta, + bert_weights->bert_layer_weights[l].attention_weights.attention_output_weight.bias, + layernorm_eps_, + h_token_num, + hidden_units_, + stream_); + } + sync_check_cuda_error(); - if (layernorm_type_ == LayerNormType::post_layernorm) { - invokeAddBiasResidualLayerNorm( - attn_out_buf_, - from_tensor, - bert_weights->bert_layer_weights[i].attention_weights.attention_output_weight.bias, - bert_weights->bert_layer_weights[i].attn_layernorm_weights.gamma, - bert_weights->bert_layer_weights[i].attn_layernorm_weights.beta, - h_token_num, - hidden_units_, - stream_); - } - else if (layernorm_type_ == LayerNormType::pre_layernorm) { - invokeGeneralAddBiasResidualPreLayerNorm( - attn_out_buf_, - normed_attn_out_buf_, - from_tensor, - bert_weights->bert_layer_weights[i].ffn_layernorm_weights.gamma, - bert_weights->bert_layer_weights[i].ffn_layernorm_weights.beta, - bert_weights->bert_layer_weights[i].attention_weights.attention_output_weight.bias, - h_token_num, - hidden_units_, - stream_); - } + // FFN + { + std::vector ffn_input_tensors{ + Tensor{MEMORY_GPU, + data_type, + std::vector{h_token_num, hidden_units_}, + layernorm_type_ == LayerNormType::pre_layernorm ? 
normed_attn_out_buf_ : attn_out_buf_}}; + std::vector ffn_output_tensors{ + Tensor{MEMORY_GPU, data_type, std::vector{h_token_num, hidden_units_}, out_tensor}}; + ffn_layer_->forward( + &ffn_output_tensors, &ffn_input_tensors, &bert_weights->bert_layer_weights[l].ffn_weights); + } + + if (layernorm_type_ == LayerNormType::post_layernorm) { + invokeAddBiasResidualLayerNorm(out_tensor, + attn_out_buf_, + bert_weights->bert_layer_weights[l].ffn_weights.output_weight.bias, + bert_weights->bert_layer_weights[l].ffn_layernorm_weights.gamma, + bert_weights->bert_layer_weights[l].ffn_layernorm_weights.beta, + layernorm_eps_, + h_token_num, + hidden_units_, + stream_); + } + else if (layernorm_type_ == LayerNormType::pre_layernorm) { + invokeAddBiasResidual(out_tensor, + attn_out_buf_, + bert_weights->bert_layer_weights[l].ffn_weights.output_weight.bias, + h_token_num, + hidden_units_, + stream_); + } + sync_check_cuda_error(); - // FFN - { - std::vector ffn_input_tensors{ - Tensor{MEMORY_GPU, - data_type, - std::vector{h_token_num, hidden_units_}, - layernorm_type_ == LayerNormType::pre_layernorm ? normed_attn_out_buf_ : attn_out_buf_}}; - std::vector ffn_output_tensors{ - Tensor{MEMORY_GPU, data_type, std::vector{h_token_num, hidden_units_}, out_tensor}}; - ffn_layer_->forward( - &ffn_output_tensors, &ffn_input_tensors, &bert_weights->bert_layer_weights[i].ffn_weights); + if (isLastLayerParallelId(l) && pipeline_para_.rank_ != pipeline_para_.world_size_ - 1 + && pipeline_para_.world_size_ > 1) { + + ftNcclSend(out_tensor + h_token_num * hidden_units_ / tensor_para_.world_size_ * tensor_para_.rank_, + h_token_num * hidden_units_ / tensor_para_.world_size_, + pipeline_para_.rank_ + 1, + pipeline_para_, + stream_); + } + } // transformer layers + + if (pipeline_para_.rank_ == pipeline_para_.world_size_ - 1) { + if (layernorm_type_ == LayerNormType::pre_layernorm) { + invokeGeneralLayerNorm(bert_output_ptr, + bert_output_ptr, + bert_weights->post_transformer_layernorm_weights.gamma, + bert_weights->post_transformer_layernorm_weights.beta, + layernorm_eps_, + h_token_num, + hidden_units_, + stream_); + sync_check_cuda_error(); + } + + // post process (rebuild padding) + switch (attention_type) { + case AttentionType::UNFUSED_MHA: { + invokeRebuildPadding(output_tensors->at("output_hidden_state").getPtrWithOffset(hidden_offset), + bert_out_buffer_, + padding_offset_, + h_token_num, + head_num_ * size_per_head_, + stream_); + sync_check_cuda_error(); + break; + } + case AttentionType::UNFUSED_PADDED_MHA: { + break; + } + case AttentionType::FUSED_MHA: { + invokeRebuildPadding(output_tensors->at("output_hidden_state").getPtrWithOffset(hidden_offset), + bert_out_buffer_, + padding_offset_, + h_token_num, + head_num_ * size_per_head_, + stream_); + sync_check_cuda_error(); + break; + } + case AttentionType::FUSED_PADDED_MHA: { + break; + } + default: { + throw std::runtime_error(std::string("[FT][ERROR] Invalid attention type \n")); + } + } } - if (layernorm_type_ == LayerNormType::post_layernorm) { - invokeAddBiasResidualLayerNorm(out_tensor, - attn_out_buf_, - bert_weights->bert_layer_weights[i].ffn_weights.output_weight.bias, - bert_weights->bert_layer_weights[i].ffn_layernorm_weights.gamma, - bert_weights->bert_layer_weights[i].ffn_layernorm_weights.beta, - h_token_num, - hidden_units_, - stream_); - } - else if (layernorm_type_ == LayerNormType::pre_layernorm) { - invokeAddBiasResidual(out_tensor, - attn_out_buf_, - bert_weights->bert_layer_weights[i].ffn_weights.output_weight.bias, - h_token_num, - 
hidden_units_, - stream_); - } - sync_check_cuda_error(); - } - - if (layernorm_type_ == LayerNormType::pre_layernorm) { - invokeGeneralLayerNorm(bert_output_ptr, - bert_output_ptr, - bert_weights->post_transformer_layernorm_weights.gamma, - bert_weights->post_transformer_layernorm_weights.beta, - h_token_num, - hidden_units_, - stream_); - } - - // post process (rebuild padding) - switch (attention_type_) { - case AttentionType::UNFUSED_MHA: { - invokeRebuildPadding((T*)output_tensors->at(0).data, - bert_out_buffer_, - padding_offset_, - h_token_num, - head_num_ * size_per_head_, - stream_); - break; - } - case AttentionType::UNFUSED_PADDED_MHA: { - break; - } - case AttentionType::FUSED_MHA: { - invokeRebuildPadding((T*)output_tensors->at(0).data, - bert_out_buffer_, - padding_offset_, - h_token_num, - head_num_ * size_per_head_, - stream_); - break; - } - case AttentionType::FUSED_PADDED_MHA: { - break; - } - default: { - throw std::runtime_error(std::string("[FT][ERROR] Invalid attention type \n")); + if (padding_offset_tensor_ptr != nullptr) { + delete padding_offset_tensor_ptr; } } @@ -467,28 +657,34 @@ void Bert::forward(std::vector* output_tensors, } sync_check_cuda_error(); - delete padding_offset_tensor_ptr; -} + if (pipeline_para_.world_size_ > 1) { + ftNcclGroupStart(); + const int data_size = request_batch_size * request_seq_len * hidden_units_ / tensor_para_.world_size_; + ftNcclBroadCast(output_tensors->at("output_hidden_state").getPtr() + data_size * tensor_para_.rank_, + data_size, + pipeline_para_.world_size_ - 1, + pipeline_para_, + stream_); + ftNcclGroupEnd(); -template -bool Bert::isValidBatchSize(size_t batch_size) -{ - if (max_batch_size_ < batch_size) { - max_batch_size_ = batch_size; - } - return true; -} - -template -bool Bert::isValidSeqLen(size_t seq_len) -{ - if (max_seq_len_ < seq_len) { - max_seq_len_ = seq_len; + sync_check_cuda_error(); + if (tensor_para_.world_size_ > 1) { + ftNcclAllGather(output_tensors->at("output_hidden_state").getPtr(), + output_tensors->at("output_hidden_state").getPtr(), + data_size, + tensor_para_.rank_, + tensor_para_, + stream_); + } + // throw errors when detected + ftNcclStreamSynchronize(tensor_para_, pipeline_para_, stream_); } - return true; } template class Bert; template class Bert; +#ifdef ENABLE_BF16 +template class Bert<__nv_bfloat16>; +#endif } // namespace fastertransformer diff --git a/src/fastertransformer/models/bert/Bert.h b/src/fastertransformer/models/bert/Bert.h index f18edaff7..41c0a83d1 100644 --- a/src/fastertransformer/models/bert/Bert.h +++ b/src/fastertransformer/models/bert/Bert.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
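
[Editor's note] The Bert.h hunk that follows declares a constructor taking the communication state explicitly (tensor_para, pipeline_para, custom_all_reduce_comm, enable_custom_all_reduce), while the original 16-argument constructor is kept and delegates with single-process defaults (NcclParam(0, 1), nullptr, false). A hedged construction sketch for the single-GPU case: the function name make_single_gpu_bert, the choice of UNFUSED_MHA / Gelu / post_layernorm / q_scaling = 1.0f, and the reading of NcclParam(0, 1) as (rank, world_size) are illustrative assumptions, not taken from the patch; stream, cublas_wrapper, allocator, and sm are assumed to be set up by the caller as in the existing bert examples.

#include "src/fastertransformer/models/bert/Bert.h"

namespace ft = fastertransformer;

// Sketch: build a single-process Bert<half> through the new constructor.
ft::Bert<half>* make_single_gpu_bert(size_t               head_num,
                                     size_t               size_per_head,
                                     size_t               inter_size,
                                     size_t               num_layer,
                                     int                  sm,
                                     cudaStream_t         stream,
                                     ft::cublasMMWrapper* cublas_wrapper,
                                     ft::IAllocator*      allocator)
{
    return new ft::Bert<half>(0,                    // max_batch_size (buffers are now sized per request)
                              0,                    // max_seq_len
                              head_num,
                              size_per_head,
                              inter_size,
                              num_layer,
                              sm,
                              1.0f,                 // q_scaling
                              stream,
                              cublas_wrapper,
                              allocator,
                              false,                // is_free_buffer_after_forward
                              ft::AttentionType::UNFUSED_MHA,
                              false,                // sparse
                              ft::ActivationType::Gelu,
                              ft::LayerNormType::post_layernorm,
                              ft::NcclParam(0, 1),  // tensor_para
                              ft::NcclParam(0, 1),  // pipeline_para
                              nullptr,              // custom_all_reduce_comm
                              false);               // enable_custom_all_reduce
}

For multi-GPU runs, the same constructor is called with the NcclParam objects produced by the NCCL setup utilities and, optionally, a custom all-reduce communicator, matching the tensor_para_ / pipeline_para_ members used throughout the new forward().
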
@@ -20,85 +20,114 @@ #include "src/fastertransformer/kernels/bert_preprocess_kernels.h" #include "src/fastertransformer/kernels/layernorm_kernels.h" -#include "src/fastertransformer/layers/FfnLayer.h" +#include "src/fastertransformer/layers/TensorParallelGeluFfnLayer.h" +#include "src/fastertransformer/layers/TensorParallelReluFfnLayer.h" #include "src/fastertransformer/layers/attention_layers/FusedAttentionLayer.h" #include "src/fastertransformer/layers/attention_layers/UnfusedAttentionLayer.h" #include "src/fastertransformer/models/bert/BertWeight.h" +#include "src/fastertransformer/utils/nccl_utils.h" namespace fastertransformer { template class Bert: public BaseLayer { private: - // buffer handling - size_t max_batch_size_ = 0; - size_t max_seq_len_ = 0; - // meta data - size_t head_num_; - size_t size_per_head_; - size_t inter_size_; - size_t hidden_units_; - size_t num_layer_; - int sm_; - float q_scaling_; - AttentionType attention_type_; - bool sparse_; - - BaseAttentionLayer* attention_layer_; - FfnLayer* ffn_layer_; + size_t head_num_; + size_t size_per_head_; + size_t inter_size_; + size_t hidden_units_; + size_t num_layer_; + int sm_; + static constexpr float layernorm_eps_ = 1e-6f; + float q_scaling_; + AttentionType attention_type_; + bool sparse_; + + BaseAttentionLayer* unfused_attention_layer_ = nullptr; + BaseAttentionLayer* fused_attention_layer_ = nullptr; + FfnLayer* ffn_layer_; bool is_allocate_buffer_ = false; + NcclParam tensor_para_; + NcclParam pipeline_para_; + std::shared_ptr custom_all_reduce_comm_; + bool enable_custom_all_reduce_; + void allocateBuffer(); void freeBuffer(); - bool isValidBatchSize(size_t batch_size); - bool isValidSeqLen(size_t seq_len); void initialize(); const ActivationType activation_type_; - const LayerNormType layernorm_type_; + const LayerNormType layernorm_type_; void allocateBuffer(size_t batch_size, size_t seq_len); + bool isValidLayerParallelId(uint l); + bool isFirstLayerParallelId(uint l); + bool isLastLayerParallelId(uint l); + int getFirstLayerParallelId(); protected: // model params - size_t* token_num_ = nullptr; - int* padding_offset_ = nullptr; - int* trt_mha_padding_offset_ = nullptr; - T* attention_mask_ = nullptr; - T* bert_in_buffer_ = nullptr; - T* attn_out_buf_ = nullptr; - T* bert_out_buffer_ = nullptr; - - T* normed_from_tensor_ = nullptr; + size_t* token_num_ = nullptr; + int* padding_offset_ = nullptr; + int* trt_mha_padding_offset_ = nullptr; + T* attention_mask_ = nullptr; + T* bert_in_buffer_ = nullptr; + T* attn_out_buf_ = nullptr; + T* bert_out_buffer_ = nullptr; + + T* normed_from_tensor_ = nullptr; T* normed_attn_out_buf_ = nullptr; public: - Bert(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t inter_size, - size_t num_layer, - int sm, - float q_scaling, - cudaStream_t stream, + Bert(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + int sm, + float q_scaling, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + AttentionType attention_type, + bool sparse, + ActivationType activation_type, + LayerNormType layernorm_type, + NcclParam tensor_para, + NcclParam pipeline_para, + std::shared_ptr custom_all_reduce_comm, + bool enable_custom_all_reduce); + + Bert(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + int sm, + float q_scaling, + 
cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - AttentionType attention_type, - bool sparse, - ActivationType activation_type, - LayerNormType layernorm_type); + IAllocator* allocator, + bool is_free_buffer_after_forward, + AttentionType attention_type, + bool sparse, + ActivationType activation_type, + LayerNormType layernorm_type); Bert(Bert const& bert_layer); ~Bert(); - void forward(std::vector* output_tensors, + void forward(std::vector* output_tensors, const std::vector* input_tensors, - const BertWeight* bert_weights); + const BertWeight* bert_weights); + void forward(TensorMap* output_tensors, TensorMap* input_tensors, const BertWeight* bert_weights); }; } // namespace fastertransformer diff --git a/src/fastertransformer/models/bert/BertLayerWeight.h b/src/fastertransformer/models/bert/BertLayerWeight.h index 9c05c2223..b6f14b3d1 100644 --- a/src/fastertransformer/models/bert/BertLayerWeight.h +++ b/src/fastertransformer/models/bert/BertLayerWeight.h @@ -19,8 +19,10 @@ #include "src/fastertransformer/kernels/layernorm_kernels.h" #include "src/fastertransformer/layers/FfnWeight.h" #include "src/fastertransformer/layers/attention_layers/AttentionWeight.h" +#include "src/fastertransformer/models/BaseWeight.h" #include "src/fastertransformer/utils/cublasMMWrapper.h" #include "src/fastertransformer/utils/memory_utils.h" +#include namespace fastertransformer { @@ -28,143 +30,128 @@ template struct BertLayerWeight { BertLayerWeight() = default; - BertLayerWeight(const int hidden_units, const int inter_size): hidden_units_(hidden_units), inter_size_(inter_size) + BertLayerWeight(const size_t hidden_units, + const size_t inter_size, + const size_t tensor_para_size, + const size_t tensor_para_rank): + hidden_units_(hidden_units), + inter_size_(inter_size), + tensor_para_size_(tensor_para_size), + tensor_para_rank_(tensor_para_rank) { - deviceMalloc(&weights_ptr[0], hidden_units_ * hidden_units_); - deviceMalloc(&weights_ptr[1], hidden_units_); - deviceMalloc(&weights_ptr[2], hidden_units_ * hidden_units_); - deviceMalloc(&weights_ptr[3], hidden_units_); - deviceMalloc(&weights_ptr[4], hidden_units_ * hidden_units_); - deviceMalloc(&weights_ptr[5], hidden_units_); - deviceMalloc(&weights_ptr[6], hidden_units_ * hidden_units_); - deviceMalloc(&weights_ptr[7], hidden_units_); - deviceMalloc(&weights_ptr[8], hidden_units_); - deviceMalloc(&weights_ptr[9], hidden_units_); - deviceMalloc(&weights_ptr[10], hidden_units_ * inter_size_); - deviceMalloc(&weights_ptr[11], inter_size_); - deviceMalloc(&weights_ptr[12], inter_size_ * hidden_units_); - deviceMalloc(&weights_ptr[13], hidden_units_); - deviceMalloc(&weights_ptr[14], hidden_units_); - deviceMalloc(&weights_ptr[15], hidden_units_); - + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + std::string name; + + name = "attention.self.query.weight." + std::to_string(tensor_para_rank_) + ".bin"; + weights_ptr.insert({name, FtWeight(name, {hidden_units_, hidden_units_ / tensor_para_size_}, nullptr)}); + name = "attention.self.query.bias." + std::to_string(tensor_para_rank_) + ".bin"; + weights_ptr.insert({name, FtWeight(name, {hidden_units_ / tensor_para_size_}, nullptr)}); + name = "attention.self.key.weight." + std::to_string(tensor_para_rank_) + ".bin"; + weights_ptr.insert({name, FtWeight(name, {hidden_units_, hidden_units_ / tensor_para_size_}, nullptr)}); + name = "attention.self.key.bias." 
+ std::to_string(tensor_para_rank_) + ".bin"; + weights_ptr.insert({name, FtWeight(name, {hidden_units_ / tensor_para_size_}, nullptr)}); + name = "attention.self.value.weight." + std::to_string(tensor_para_rank_) + ".bin"; + weights_ptr.insert({name, FtWeight(name, {hidden_units_, hidden_units_ / tensor_para_size_}, nullptr)}); + name = "attention.self.value.bias." + std::to_string(tensor_para_rank_) + ".bin"; + weights_ptr.insert({name, FtWeight(name, {hidden_units_ / tensor_para_size_}, nullptr)}); + name = "attention.output.dense.weight." + std::to_string(tensor_para_rank_) + ".bin"; + weights_ptr.insert({name, FtWeight(name, {hidden_units_, hidden_units_ / tensor_para_size_}, nullptr)}); + name = "attention.output.dense.bias.bin"; + weights_ptr.insert({name, FtWeight(name, {hidden_units_}, nullptr)}); + name = "attention.output.LayerNorm.weight.bin"; + weights_ptr.insert({name, FtWeight(name, {hidden_units_}, nullptr)}); + name = "attention.output.LayerNorm.bias.bin"; + weights_ptr.insert({name, FtWeight(name, {hidden_units_}, nullptr)}); + name = "intermediate.dense.weight." + std::to_string(tensor_para_rank_) + ".bin"; + weights_ptr.insert({name, FtWeight(name, {hidden_units_, inter_size_ / tensor_para_size_}, nullptr)}); + name = "intermediate.dense.bias." + std::to_string(tensor_para_rank_) + ".bin"; + weights_ptr.insert({name, FtWeight(name, {inter_size_ / tensor_para_size_}, nullptr)}); + name = "output.dense.weight." + std::to_string(tensor_para_rank_) + ".bin"; + weights_ptr.insert({name, FtWeight(name, {inter_size_ / tensor_para_size_, hidden_units_}, nullptr)}); + name = "output.dense.bias.bin"; + weights_ptr.insert({name, FtWeight(name, {hidden_units_}, nullptr)}); + name = "output.LayerNorm.weight.bin"; + weights_ptr.insert({name, FtWeight(name, {hidden_units_}, nullptr)}); + name = "output.LayerNorm.bias.bin"; + weights_ptr.insert({name, FtWeight(name, {hidden_units_}, nullptr)}); + + for (auto it = weights_ptr.begin(); it != weights_ptr.end(); ++it) { + deviceMalloc(&it->second.ptr_, it->second.size_); + } setWeightPtr(); + FT_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); } + BertLayerWeight(const int hidden_units, const int inter_size): BertLayerWeight(hidden_units, inter_size, 1, 0) {} + ~BertLayerWeight() { if (is_maintain_buffer == true) { - for (int i = 0; i < 16; i++) { - deviceFree(weights_ptr[i]); + for (auto it = weights_ptr.begin(); it != weights_ptr.end(); ++it) { + deviceFree(it->second.ptr_); } - - attention_weights.query_weight.kernel = nullptr; - attention_weights.query_weight.bias = nullptr; - attention_weights.key_weight.kernel = nullptr; - attention_weights.key_weight.bias = nullptr; - attention_weights.value_weight.kernel = nullptr; - attention_weights.value_weight.bias = nullptr; + weights_ptr.clear(); + + attention_weights.query_weight.kernel = nullptr; + attention_weights.query_weight.bias = nullptr; + attention_weights.key_weight.kernel = nullptr; + attention_weights.key_weight.bias = nullptr; + attention_weights.value_weight.kernel = nullptr; + attention_weights.value_weight.bias = nullptr; attention_weights.attention_output_weight.kernel = nullptr; - attention_weights.attention_output_weight.bias = nullptr; - attn_layernorm_weights.gamma = nullptr; - attn_layernorm_weights.beta = nullptr; - ffn_weights.intermediate_weight.kernel = nullptr; - ffn_weights.intermediate_weight.bias = nullptr; - ffn_weights.output_weight.kernel = nullptr; - ffn_weights.output_weight.bias = nullptr; - ffn_layernorm_weights.gamma = nullptr; - ffn_layernorm_weights.beta = 
nullptr; - is_maintain_buffer = false; + attention_weights.attention_output_weight.bias = nullptr; + attn_layernorm_weights.gamma = nullptr; + attn_layernorm_weights.beta = nullptr; + ffn_weights.intermediate_weight.kernel = nullptr; + ffn_weights.intermediate_weight.bias = nullptr; + ffn_weights.output_weight.kernel = nullptr; + ffn_weights.output_weight.bias = nullptr; + ffn_layernorm_weights.gamma = nullptr; + ffn_layernorm_weights.beta = nullptr; + is_maintain_buffer = false; } if (is_maintain_sp_buffer == true) { for (int i = 0; i < 6; i++) { deviceFree(sp_weights_ptr[i]); } - attention_weights.query_weight.sp_kernel = nullptr; - attention_weights.key_weight.sp_kernel = nullptr; - attention_weights.value_weight.sp_kernel = nullptr; + attention_weights.query_weight.sp_kernel = nullptr; + attention_weights.key_weight.sp_kernel = nullptr; + attention_weights.value_weight.sp_kernel = nullptr; attention_weights.attention_output_weight.sp_kernel = nullptr; - ffn_weights.intermediate_weight.sp_kernel = nullptr; - ffn_weights.output_weight.sp_kernel = nullptr; - is_maintain_sp_buffer = false; + ffn_weights.intermediate_weight.sp_kernel = nullptr; + ffn_weights.output_weight.sp_kernel = nullptr; + is_maintain_sp_buffer = false; } } - BertLayerWeight(const BertLayerWeight& other): hidden_units_(other.hidden_units_), inter_size_(other.inter_size_) + BertLayerWeight(const BertLayerWeight& other): + BertLayerWeight(other.hidden_units_, other.inter_size_, other.tensor_para_size_, other.tensor_para_rank_) { - deviceMalloc(&weights_ptr[0], hidden_units_ * hidden_units_); - cudaD2Dcpy(weights_ptr[0], other.weights_ptr[0], hidden_units_ * hidden_units_); - deviceMalloc(&weights_ptr[1], hidden_units_); - cudaD2Dcpy(weights_ptr[1], other.weights_ptr[1], hidden_units_); - deviceMalloc(&weights_ptr[2], hidden_units_ * hidden_units_); - cudaD2Dcpy(weights_ptr[2], other.weights_ptr[2], hidden_units_ * hidden_units_); - deviceMalloc(&weights_ptr[3], hidden_units_); - cudaD2Dcpy(weights_ptr[3], other.weights_ptr[3], hidden_units_); - deviceMalloc(&weights_ptr[4], hidden_units_ * hidden_units_); - cudaD2Dcpy(weights_ptr[4], other.weights_ptr[4], hidden_units_ * hidden_units_); - deviceMalloc(&weights_ptr[5], hidden_units_); - cudaD2Dcpy(weights_ptr[5], other.weights_ptr[5], hidden_units_); - deviceMalloc(&weights_ptr[6], hidden_units_ * hidden_units_); - cudaD2Dcpy(weights_ptr[6], other.weights_ptr[6], hidden_units_ * hidden_units_); - deviceMalloc(&weights_ptr[7], hidden_units_); - cudaD2Dcpy(weights_ptr[7], other.weights_ptr[7], hidden_units_); - deviceMalloc(&weights_ptr[8], hidden_units_); - cudaD2Dcpy(weights_ptr[8], other.weights_ptr[8], hidden_units_); - deviceMalloc(&weights_ptr[9], hidden_units_); - cudaD2Dcpy(weights_ptr[9], other.weights_ptr[9], hidden_units_); - deviceMalloc(&weights_ptr[10], hidden_units_ * inter_size_); - cudaD2Dcpy(weights_ptr[10], other.weights_ptr[10], hidden_units_ * inter_size_); - deviceMalloc(&weights_ptr[11], inter_size_); - cudaD2Dcpy(weights_ptr[11], other.weights_ptr[11], inter_size_); - deviceMalloc(&weights_ptr[12], inter_size_ * hidden_units_); - cudaD2Dcpy(weights_ptr[12], other.weights_ptr[12], inter_size_ * hidden_units_); - deviceMalloc(&weights_ptr[13], hidden_units_); - cudaD2Dcpy(weights_ptr[13], other.weights_ptr[13], hidden_units_); - deviceMalloc(&weights_ptr[14], hidden_units_); - cudaD2Dcpy(weights_ptr[14], other.weights_ptr[14], hidden_units_); - deviceMalloc(&weights_ptr[15], hidden_units_); - cudaD2Dcpy(weights_ptr[15], other.weights_ptr[15], 
hidden_units_); - - setWeightPtr(); + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + for (auto it = other.weights_ptr.begin(); it != other.weights_ptr.end(); ++it) { + cudaD2Dcpy(weights_ptr.at(it->first).ptr_, it->second.ptr_, it->second.size_); + } + FT_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); } BertLayerWeight& operator=(const BertLayerWeight& other) { - hidden_units_ = other.hidden_units_; - inter_size_ = other.inter_size_; - deviceMalloc(&weights_ptr[0], hidden_units_ * hidden_units_); - cudaD2Dcpy(weights_ptr[0], other.weights_ptr[0], hidden_units_ * hidden_units_); - deviceMalloc(&weights_ptr[1], hidden_units_); - cudaD2Dcpy(weights_ptr[1], other.weights_ptr[1], hidden_units_); - deviceMalloc(&weights_ptr[2], hidden_units_ * hidden_units_); - cudaD2Dcpy(weights_ptr[2], other.weights_ptr[2], hidden_units_ * hidden_units_); - deviceMalloc(&weights_ptr[3], hidden_units_); - cudaD2Dcpy(weights_ptr[3], other.weights_ptr[3], hidden_units_); - deviceMalloc(&weights_ptr[4], hidden_units_ * hidden_units_); - cudaD2Dcpy(weights_ptr[4], other.weights_ptr[4], hidden_units_ * hidden_units_); - deviceMalloc(&weights_ptr[5], hidden_units_); - cudaD2Dcpy(weights_ptr[5], other.weights_ptr[5], hidden_units_); - deviceMalloc(&weights_ptr[6], hidden_units_ * hidden_units_); - cudaD2Dcpy(weights_ptr[6], other.weights_ptr[6], hidden_units_ * hidden_units_); - deviceMalloc(&weights_ptr[7], hidden_units_); - cudaD2Dcpy(weights_ptr[7], other.weights_ptr[7], hidden_units_); - deviceMalloc(&weights_ptr[8], hidden_units_); - cudaD2Dcpy(weights_ptr[8], other.weights_ptr[8], hidden_units_); - deviceMalloc(&weights_ptr[9], hidden_units_); - cudaD2Dcpy(weights_ptr[9], other.weights_ptr[9], hidden_units_); - deviceMalloc(&weights_ptr[10], hidden_units_ * inter_size_); - cudaD2Dcpy(weights_ptr[10], other.weights_ptr[10], hidden_units_ * inter_size_); - deviceMalloc(&weights_ptr[11], inter_size_); - cudaD2Dcpy(weights_ptr[11], other.weights_ptr[11], inter_size_); - deviceMalloc(&weights_ptr[12], inter_size_ * hidden_units_); - cudaD2Dcpy(weights_ptr[12], other.weights_ptr[12], inter_size_ * hidden_units_); - deviceMalloc(&weights_ptr[13], hidden_units_); - cudaD2Dcpy(weights_ptr[13], other.weights_ptr[13], hidden_units_); - deviceMalloc(&weights_ptr[14], hidden_units_); - cudaD2Dcpy(weights_ptr[14], other.weights_ptr[14], hidden_units_); - deviceMalloc(&weights_ptr[15], hidden_units_); - cudaD2Dcpy(weights_ptr[15], other.weights_ptr[15], hidden_units_); - + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + hidden_units_ = other.hidden_units_; + inter_size_ = other.inter_size_; + tensor_para_size_ = other.tensor_para_size_; + tensor_para_rank_ = other.tensor_para_rank_; + + for (auto it = other.weights_ptr.begin(); it != other.weights_ptr.end(); ++it) { + weights_ptr.insert({it->first, it->second}); + weights_ptr.at(it->first).ptr_ = nullptr; + deviceMalloc(weights_ptr.at(it->first).ptr_, it->second.size_); + cudaD2Dcpy(weights_ptr.at(it->first).ptr_, it->second.ptr_, it->second.size_); + } setWeightPtr(); + FT_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); + + return *this; } #ifdef SPARSITY_ENABLED @@ -185,49 +172,72 @@ struct BertLayerWeight { cublas_wrapper.compressMatrix( ffn_weights.intermediate_weight.kernel, sp_weights_ptr[4], inter_size, hidden_dim); cublas_wrapper.compressMatrix(ffn_weights.output_weight.kernel, sp_weights_ptr[5], hidden_dim, inter_size); - attention_weights.query_weight.sp_kernel = sp_weights_ptr[0]; - attention_weights.key_weight.sp_kernel = sp_weights_ptr[1]; - attention_weights.value_weight.sp_kernel = 
sp_weights_ptr[2]; + attention_weights.query_weight.sp_kernel = sp_weights_ptr[0]; + attention_weights.key_weight.sp_kernel = sp_weights_ptr[1]; + attention_weights.value_weight.sp_kernel = sp_weights_ptr[2]; attention_weights.attention_output_weight.sp_kernel = sp_weights_ptr[3]; - ffn_weights.intermediate_weight.sp_kernel = sp_weights_ptr[4]; - ffn_weights.output_weight.sp_kernel = sp_weights_ptr[5]; - is_maintain_sp_buffer = true; + ffn_weights.intermediate_weight.sp_kernel = sp_weights_ptr[4]; + ffn_weights.output_weight.sp_kernel = sp_weights_ptr[5]; + is_maintain_sp_buffer = true; } #endif AttentionWeight attention_weights; LayerNormWeight attn_layernorm_weights; - FfnWeight ffn_weights; + FfnWeight ffn_weights; LayerNormWeight ffn_layernorm_weights; + void loadModel(std::string dir_path, FtCudaDataType model_file_type) + { + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + for (auto it = weights_ptr.begin(); it != weights_ptr.end(); ++it) { + loadWeightFromBin(it->second.ptr_, it->second.shape_, dir_path + it->first, model_file_type); + } + FT_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); + } + private: void setWeightPtr() { - attention_weights.query_weight.kernel = weights_ptr[0]; - attention_weights.query_weight.bias = weights_ptr[1]; - attention_weights.key_weight.kernel = weights_ptr[2]; - attention_weights.key_weight.bias = weights_ptr[3]; - attention_weights.value_weight.kernel = weights_ptr[4]; - attention_weights.value_weight.bias = weights_ptr[5]; - attention_weights.attention_output_weight.kernel = weights_ptr[6]; - attention_weights.attention_output_weight.bias = weights_ptr[7]; - attn_layernorm_weights.gamma = weights_ptr[8]; - attn_layernorm_weights.beta = weights_ptr[9]; - ffn_weights.intermediate_weight.kernel = weights_ptr[10]; - ffn_weights.intermediate_weight.bias = weights_ptr[11]; - ffn_weights.output_weight.kernel = weights_ptr[12]; - ffn_weights.output_weight.bias = weights_ptr[13]; - ffn_layernorm_weights.gamma = weights_ptr[14]; - ffn_layernorm_weights.beta = weights_ptr[15]; + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + attention_weights.query_weight.kernel = + weights_ptr.at("attention.self.query.weight." + std::to_string(tensor_para_rank_) + ".bin").ptr_; + attention_weights.query_weight.bias = + weights_ptr.at("attention.self.query.bias." + std::to_string(tensor_para_rank_) + ".bin").ptr_; + attention_weights.key_weight.kernel = + weights_ptr.at("attention.self.key.weight." + std::to_string(tensor_para_rank_) + ".bin").ptr_; + attention_weights.key_weight.bias = + weights_ptr.at("attention.self.key.bias." + std::to_string(tensor_para_rank_) + ".bin").ptr_; + attention_weights.value_weight.kernel = + weights_ptr.at("attention.self.value.weight." + std::to_string(tensor_para_rank_) + ".bin").ptr_; + attention_weights.value_weight.bias = + weights_ptr.at("attention.self.value.bias." + std::to_string(tensor_para_rank_) + ".bin").ptr_; + attention_weights.attention_output_weight.kernel = + weights_ptr.at("attention.output.dense.weight." + std::to_string(tensor_para_rank_) + ".bin").ptr_; + attention_weights.attention_output_weight.bias = weights_ptr.at("attention.output.dense.bias.bin").ptr_; + attn_layernorm_weights.gamma = weights_ptr.at("attention.output.LayerNorm.weight.bin").ptr_; + attn_layernorm_weights.beta = weights_ptr.at("attention.output.LayerNorm.bias.bin").ptr_; + ffn_weights.intermediate_weight.kernel = + weights_ptr.at("intermediate.dense.weight." 
+ std::to_string(tensor_para_rank_) + ".bin").ptr_; + ffn_weights.intermediate_weight.bias = + weights_ptr.at("intermediate.dense.bias." + std::to_string(tensor_para_rank_) + ".bin").ptr_; + ffn_weights.output_weight.kernel = + weights_ptr.at("output.dense.weight." + std::to_string(tensor_para_rank_) + ".bin").ptr_; + ffn_weights.output_weight.bias = weights_ptr.at("output.dense.bias.bin").ptr_; + ffn_layernorm_weights.gamma = weights_ptr.at("output.LayerNorm.weight.bin").ptr_; + ffn_layernorm_weights.beta = weights_ptr.at("output.LayerNorm.bias.bin").ptr_; is_maintain_buffer = true; + FT_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); } - int hidden_units_; - int inter_size_; - bool is_maintain_buffer = false; - T* weights_ptr[16]; - T* sp_weights_ptr[6]; - bool is_maintain_sp_buffer = false; + size_t hidden_units_; + size_t inter_size_; + size_t tensor_para_size_; + size_t tensor_para_rank_; + bool is_maintain_buffer = false; + std::unordered_map> weights_ptr; + T* sp_weights_ptr[6]; + bool is_maintain_sp_buffer = false; }; } // namespace fastertransformer diff --git a/src/fastertransformer/models/bert/BertWeight.h b/src/fastertransformer/models/bert/BertWeight.h index ec9341361..d53fe581c 100644 --- a/src/fastertransformer/models/bert/BertWeight.h +++ b/src/fastertransformer/models/bert/BertWeight.h @@ -24,17 +24,38 @@ template struct BertWeight { BertWeight() = default; - BertWeight(const int hidden_units, const int inter_size, const int num_layer): - hidden_units_(hidden_units), inter_size_(inter_size), num_layer_(num_layer) + BertWeight(const size_t hidden_units, + const size_t inter_size, + const size_t num_layer, + const size_t tensor_para_size, + const size_t tensor_para_rank, + const size_t pipeline_para_size, + const size_t pipeline_para_rank): + hidden_units_(hidden_units), + inter_size_(inter_size), + num_layer_(num_layer), + tensor_para_size_(tensor_para_size), + tensor_para_rank_(tensor_para_rank), + pipeline_para_size_(pipeline_para_size), + pipeline_para_rank_(pipeline_para_rank) { + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + deviceMalloc(&weights_ptr[0], hidden_units_); deviceMalloc(&weights_ptr[1], hidden_units_); setWeightPtr(); bert_layer_weights.reserve(num_layer_); for (int i = 0; i < num_layer_; i++) { - bert_layer_weights.push_back(BertLayerWeight(hidden_units_, inter_size_)); + bert_layer_weights.push_back( + BertLayerWeight(hidden_units_, inter_size_, tensor_para_size_, tensor_para_rank_)); } + FT_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); + } + + BertWeight(const int hidden_units, const int inter_size, const int num_layer): + BertWeight(hidden_units, inter_size, num_layer, 1, 0, 1, 0) + { } ~BertWeight() @@ -46,14 +67,21 @@ struct BertWeight { } post_transformer_layernorm_weights.gamma = nullptr; - post_transformer_layernorm_weights.beta = nullptr; - is_maintain_buffer = false; + post_transformer_layernorm_weights.beta = nullptr; + is_maintain_buffer = false; } } BertWeight(const BertWeight& other): - hidden_units_(other.hidden_units_), inter_size_(other.inter_size_), num_layer_(other.num_layer_) + BertWeight(other.hidden_units_, + other.inter_size_, + other.num_layer_, + other.tensor_para_size_, + other.tensor_para_rank_, + other.pipeline_para_size_, + other.pipeline_para_rank_) { + FT_LOG_DEBUG(__PRETTY_FUNCTION__); bert_layer_weights.clear(); bert_layer_weights.reserve(num_layer_); for (int i = 0; i < num_layer_; i++) { @@ -65,13 +93,20 @@ struct BertWeight { cudaD2Dcpy(weights_ptr[1], other.weights_ptr[1], hidden_units_); setWeightPtr(); + FT_LOG_DEBUG("%s stop", 
__PRETTY_FUNCTION__); } BertWeight& operator=(const BertWeight& other) { - hidden_units_ = other.hidden_units_; - inter_size_ = other.inter_size_; - num_layer_ = other.num_layer_; + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + hidden_units_ = other.hidden_units_; + inter_size_ = other.inter_size_; + num_layer_ = other.num_layer_; + tensor_para_size_ = other.tensor_para_size_; + tensor_para_rank_ = other.tensor_para_rank_; + pipeline_para_size_ = other.pipeline_para_size_; + pipeline_para_rank_ = other.pipeline_para_rank_; + bert_layer_weights.clear(); bert_layer_weights.reserve(num_layer_); for (int i = 0; i < num_layer_; i++) { @@ -83,24 +118,51 @@ struct BertWeight { cudaD2Dcpy(weights_ptr[1], other.weights_ptr[1], hidden_units_); setWeightPtr(); + FT_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); + + return *this; } std::vector> bert_layer_weights; - LayerNormWeight post_transformer_layernorm_weights; + LayerNormWeight post_transformer_layernorm_weights; + + bool isValidLayerParallelId(int l) + { + int local_num_layer = (int)(ceil(num_layer_ * 1.0f / pipeline_para_size_)); + return l < num_layer_ && (l >= local_num_layer * pipeline_para_rank_) + && (l < local_num_layer * (pipeline_para_rank_ + 1)); + } + + void loadModel(std::string dir_path) + { + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + FtCudaDataType model_file_type = getModelFileType(dir_path + "/config.ini", "bert"); + for (uint l = 0; l < num_layer_; l++) { + if (isValidLayerParallelId(l)) { + bert_layer_weights[l].loadModel(dir_path + "model.encoder.layer." + std::to_string(l) + ".", + model_file_type); + } + } + FT_LOG_DEBUG(__PRETTY_FUNCTION__, " stop"); + } private: void setWeightPtr() { post_transformer_layernorm_weights.gamma = weights_ptr[0]; - post_transformer_layernorm_weights.beta = weights_ptr[1]; + post_transformer_layernorm_weights.beta = weights_ptr[1]; is_maintain_buffer = true; } - int hidden_units_; - int inter_size_; - int num_layer_; - bool is_maintain_buffer = false; - T* weights_ptr[2]; + size_t hidden_units_; + size_t inter_size_; + size_t num_layer_; + size_t tensor_para_size_; + size_t tensor_para_rank_; + size_t pipeline_para_size_; + size_t pipeline_para_rank_; + bool is_maintain_buffer = false; + T* weights_ptr[2]; }; } // namespace fastertransformer diff --git a/src/fastertransformer/models/bert/CMakeLists.txt b/src/fastertransformer/models/bert/CMakeLists.txt index 57611b5fd..0d5c9961d 100644 --- a/src/fastertransformer/models/bert/CMakeLists.txt +++ b/src/fastertransformer/models/bert/CMakeLists.txt @@ -18,8 +18,8 @@ add_library(Bert STATIC Bert.cc) set_property(TARGET Bert PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET Bert PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) target_link_libraries(Bert PUBLIC -lcudart bert_preprocess_kernels cublasMMWrapper - UnfusedAttentionLayer FusedAttentionLayer FfnLayer layernorm_kernels - add_residual_kernels) + UnfusedAttentionLayer FusedAttentionLayer TensorParallelGeluFfnLayer TensorParallelReluFfnLayer + layernorm_kernels add_residual_kernels nccl_utils custom_ar_comm tensor) add_executable(bert_gemm bert_gemm.cc) -target_link_libraries(bert_gemm PUBLIC -lcublas -lcublasLt -lcudart encoder_gemm_func encoder_igemm_func memory_utils) \ No newline at end of file +target_link_libraries(bert_gemm PUBLIC -lcublas -lcublasLt -lcudart encoder_gemm_func encoder_igemm_func memory_utils tensor) diff --git a/src/fastertransformer/models/bert/bert_gemm.cc b/src/fastertransformer/models/bert/bert_gemm.cc index a2f4a9f0a..6abf81506 100644 --- 
a/src/fastertransformer/models/bert/bert_gemm.cc +++ b/src/fastertransformer/models/bert/bert_gemm.cc @@ -22,37 +22,39 @@ namespace ft = fastertransformer; int main(int argc, char* argv[]) { - if (argc != 7) { - printf("[ERROR] bert_gemm batch_size seq_len head_number size_per_head data_type int8_mode \n"); - printf("e.g. ./bin/bert_gemm 1 32 12 64 0 0\n"); + if (argc != 7 && argc != 8) { + FT_LOG_ERROR("bert_gemm batch_size seq_len head_number size_per_head data_type int8_mode tensor_para_size"); + FT_LOG_ERROR("e.g. ./bin/bert_gemm 1 32 12 64 0 0 1 "); return 0; } - const int batch_size = atoi(argv[1]); - const int seq_len = atoi(argv[2]); - const int head_num = atoi(argv[3]); - const int size_per_head = atoi(argv[4]); - const ft::CublasDataType data_type = static_cast(atoi(argv[5])); // 0 FP32, 1 FP16, 2 BF 16 - const int int8_mode = atoi(argv[6]); + const int batch_size = atoi(argv[1]); + const int seq_len = atoi(argv[2]); + const int head_num = atoi(argv[3]); + const int size_per_head = atoi(argv[4]); + const ft::CublasDataType data_type = static_cast(atoi(argv[5])); // 0 FP32, 1 FP16, 2 BF 16 + const int int8_mode = atoi(argv[6]); + const int tensor_para_size = argc < 8 ? 1 : atoi(argv[7]); const int inter_size = 4 * head_num * size_per_head; - - printf("[INFO] arguments: \n"); - printf(" batch_size: %d \n", batch_size); - printf(" head_num: %d \n", head_num); - printf(" size_per_head: %d \n", size_per_head); - printf(" data_type: %d \n", data_type); - printf(" int8_mode: %d \n", int8_mode); + ft::FT_CHECK_WITH_INFO(head_num % tensor_para_size == 0, fmtstr("[ERROR] head_num (%d) %% tensor_para_size (%d) != 0", head_num, tensor_para_size)); + FT_LOG_INFO("arguments:"); + FT_LOG_INFO(" batch_size: %d", batch_size); + FT_LOG_INFO(" head_num: %d", head_num); + FT_LOG_INFO(" size_per_head: %d", size_per_head); + FT_LOG_INFO(" data_type: %d", data_type); + FT_LOG_INFO(" int8_mode: %d", int8_mode); + FT_LOG_INFO(" tensor_para_size: %d", tensor_para_size); std::cout << std::endl; - void* gemm_test_buf; + void* gemm_test_buf; size_t buf_size_in_byte = ft::calGemmTestBufSizeInByte(batch_size, seq_len, head_num, size_per_head, inter_size, 0, int8_mode, data_type); size_t total, free; ft::check_cuda_error(cudaMemGetInfo(&free, &total)); if (free < buf_size_in_byte + 10 * 1024 * 1024) { - printf("[ERROR] There is no enough device memory for gemm test!\n" - " %ld Bytes is needed, but only %ld Bytes is free.\n", + FT_LOG_ERROR(" There is no enough device memory for gemm test!\n" + " %ld Bytes is needed, but only %ld Bytes is free.", buf_size_in_byte, free); gemm_test_buf = NULL; diff --git a/src/fastertransformer/models/bert_int8/BertINT8.cc b/src/fastertransformer/models/bert_int8/BertINT8.cc index 7c6347bcf..db31b9740 100644 --- a/src/fastertransformer/models/bert_int8/BertINT8.cc +++ b/src/fastertransformer/models/bert_int8/BertINT8.cc @@ -19,21 +19,21 @@ namespace fastertransformer { template -BertINT8::BertINT8(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t inter_size, - size_t num_layer, - int sm, - float q_scaling, - int int8_mode, - cudaStream_t stream, +BertINT8::BertINT8(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + int sm, + float q_scaling, + int int8_mode, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - AttentionType attention_type, - bool sparse): + IAllocator* allocator, + bool 
is_free_buffer_after_forward, + AttentionType attention_type, + bool sparse): max_batch_size_(max_batch_size), max_seq_len_(max_seq_len), head_num_(head_num), @@ -124,16 +124,19 @@ template void BertINT8::allocateBuffer() { if (is_allocate_buffer_ == false) { - token_num_ = (size_t*)allocator_->malloc(sizeof(size_t) * 1, false); - padding_offset_ = (int*)allocator_->malloc(sizeof(int) * max_batch_size_ * max_seq_len_, false); - trt_mha_padding_offset_ = (int*)allocator_->malloc(sizeof(int) * (2 * max_batch_size_ + 1), false); + token_num_ = (size_t*)allocator_->reMalloc(token_num_, sizeof(size_t) * 1, false); + padding_offset_ = + (int*)allocator_->reMalloc(padding_offset_, sizeof(int) * max_batch_size_ * max_seq_len_, false); + trt_mha_padding_offset_ = + (int*)allocator_->reMalloc(trt_mha_padding_offset_, sizeof(int) * (2 * max_batch_size_ + 1), false); - attention_mask_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * max_seq_len_, false); + attention_mask_ = + (T*)allocator_->reMalloc(attention_mask_, sizeof(T) * max_batch_size_ * max_seq_len_ * max_seq_len_, false); - bert_in_buffer_ = - (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * head_num_ * size_per_head_, false); - bert_out_buffer_ = - (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * head_num_ * size_per_head_, false); + bert_in_buffer_ = (T*)allocator_->reMalloc( + bert_in_buffer_, sizeof(T) * max_batch_size_ * max_seq_len_ * head_num_ * size_per_head_, false); + bert_out_buffer_ = (T*)allocator_->reMalloc( + bert_out_buffer_, sizeof(T) * max_batch_size_ * max_seq_len_ * head_num_ * size_per_head_, false); is_allocate_buffer_ = true; } } @@ -142,26 +145,26 @@ template void BertINT8::freeBuffer() { if (is_allocate_buffer_ == true) { - allocator_->free(token_num_); - allocator_->free(padding_offset_); - allocator_->free(trt_mha_padding_offset_); + allocator_->free((void**)(&token_num_)); + allocator_->free((void**)(&padding_offset_)); + allocator_->free((void**)(&trt_mha_padding_offset_)); - allocator_->free(attention_mask_); - allocator_->free(bert_in_buffer_); - allocator_->free(bert_out_buffer_); + allocator_->free((void**)(&attention_mask_)); + allocator_->free((void**)(&bert_in_buffer_)); + allocator_->free((void**)(&bert_out_buffer_)); is_allocate_buffer_ = false; } } template -void BertINT8::forward(std::vector* output_tensors, - const std::vector* input_tensors, +void BertINT8::forward(std::vector* output_tensors, + const std::vector* input_tensors, const std::vector>* bert_layer_weights) { // input_tensors: [input_query (batch, seqlen, hidden), sequence_length (batch)] // output_tensors: [output (batch, seqlen, size_per_head*head_num)] const size_t request_batch_size = input_tensors->at(0).shape[0]; - const size_t request_seq_len = input_tensors->at(0).shape[1]; + const size_t request_seq_len = input_tensors->at(0).shape[1]; FT_CHECK(input_tensors->size() == 2); FT_CHECK(isValidBatchSize(request_batch_size)); FT_CHECK(isValidSeqLen(request_seq_len)); @@ -173,8 +176,8 @@ void BertINT8::forward(std::vector* output_tensors, const int* sequence_lengths = reinterpret_cast(input_tensors->at(1).data); size_t h_token_num; - T* bert_input_ptr; - T* bert_output_ptr; + T* bert_input_ptr; + T* bert_output_ptr; Tensor* padding_offset_tensor_ptr; switch (attention_type_) { @@ -197,7 +200,7 @@ void BertINT8::forward(std::vector* output_tensors, head_num_ * size_per_head_, stream_); sync_check_cuda_error(); - bert_input_ptr = bert_in_buffer_; + bert_input_ptr = 
bert_in_buffer_; bert_output_ptr = bert_out_buffer_; padding_offset_tensor_ptr = @@ -208,9 +211,9 @@ void BertINT8::forward(std::vector* output_tensors, invokeBuildEncoderAttentionMask( attention_mask_, sequence_lengths, request_batch_size, request_seq_len, stream_); sync_check_cuda_error(); - h_token_num = request_batch_size * request_seq_len; - bert_input_ptr = (T*)input_tensors->at(0).data; - bert_output_ptr = (T*)output_tensors->at(0).data; + h_token_num = request_batch_size * request_seq_len; + bert_input_ptr = (T*)input_tensors->at(0).data; + bert_output_ptr = (T*)output_tensors->at(0).data; padding_offset_tensor_ptr = new Tensor(MEMORY_GPU, TYPE_INT32, std::vector{0}, nullptr); break; } @@ -230,7 +233,7 @@ void BertINT8::forward(std::vector* output_tensors, head_num_ * size_per_head_, stream_); sync_check_cuda_error(); - bert_input_ptr = bert_in_buffer_; + bert_input_ptr = bert_in_buffer_; bert_output_ptr = bert_out_buffer_; invokeGetTrtPaddingOffset(trt_mha_padding_offset_, sequence_lengths, request_batch_size, stream_); @@ -245,7 +248,7 @@ void BertINT8::forward(std::vector* output_tensors, trt_mha_padding_offset_, sequence_lengths, request_batch_size, request_seq_len, stream_); padding_offset_tensor_ptr = new Tensor( MEMORY_GPU, TYPE_INT32, std::vector{request_batch_size * 2 + 1}, trt_mha_padding_offset_); - bert_input_ptr = (T*)input_tensors->at(0).data; + bert_input_ptr = (T*)input_tensors->at(0).data; bert_output_ptr = (T*)output_tensors->at(0).data; break; } @@ -254,7 +257,7 @@ void BertINT8::forward(std::vector* output_tensors, } } - DataType data_type = getTensorType(); + DataType data_type = getTensorType(); std::vector tmp_output_tensors = { Tensor{MEMORY_GPU, data_type, std::vector{h_token_num, hidden_units_}, bert_output_ptr}, }; diff --git a/src/fastertransformer/models/bert_int8/BertINT8.h b/src/fastertransformer/models/bert_int8/BertINT8.h index 2bdd73376..859dda064 100644 --- a/src/fastertransformer/models/bert_int8/BertINT8.h +++ b/src/fastertransformer/models/bert_int8/BertINT8.h @@ -28,22 +28,22 @@ class BertINT8 { private: // buffer handling size_t max_batch_size_ = 0; - size_t max_seq_len_ = 0; + size_t max_seq_len_ = 0; // meta data - size_t head_num_; - size_t size_per_head_; - size_t inter_size_; - size_t hidden_units_; - size_t num_layer_; - int sm_; - float q_scaling_; - int int8_mode_; - cudaStream_t stream_; + size_t head_num_; + size_t size_per_head_; + size_t inter_size_; + size_t hidden_units_; + size_t num_layer_; + int sm_; + float q_scaling_; + int int8_mode_; + cudaStream_t stream_; cublasMMWrapper* cublas_wrapper_; - IAllocator* allocator_; - bool is_free_buffer_after_forward_; - AttentionType attention_type_; - bool sparse_; + IAllocator* allocator_; + bool is_free_buffer_after_forward_; + AttentionType attention_type_; + bool sparse_; bool is_allocate_buffer_ = false; @@ -56,35 +56,35 @@ class BertINT8 { protected: size_t* token_num_; - int* padding_offset_; - int* trt_mha_padding_offset_; - T* attention_mask_; - T* bert_in_buffer_; - T* bert_out_buffer_; + int* padding_offset_; + int* trt_mha_padding_offset_; + T* attention_mask_; + T* bert_in_buffer_; + T* bert_out_buffer_; public: - BertINT8(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t inter_size, - size_t num_layer, - int sm, - float q_scaling, - int int8_mode, - cudaStream_t stream, + BertINT8(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + int sm, + float 
q_scaling, + int int8_mode, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - AttentionType attention_type, - bool sparse = false); + IAllocator* allocator, + bool is_free_buffer_after_forward, + AttentionType attention_type, + bool sparse = false); BertINT8(BertINT8 const& bert_layer); ~BertINT8(); - void forward(std::vector* output_tensors, - const std::vector* input_tensors, + void forward(std::vector* output_tensors, + const std::vector* input_tensors, const std::vector>* bert_layer_weights); // friend class BertDebug; }; diff --git a/src/fastertransformer/models/bert_int8/BertLayerINT8.cc b/src/fastertransformer/models/bert_int8/BertLayerINT8.cc index b831cdf04..31a7292f3 100644 --- a/src/fastertransformer/models/bert_int8/BertLayerINT8.cc +++ b/src/fastertransformer/models/bert_int8/BertLayerINT8.cc @@ -15,6 +15,7 @@ */ #include "BertLayerINT8.h" +#include "src/fastertransformer/utils/nvtx_utils.h" namespace fastertransformer { @@ -66,20 +67,20 @@ void BertLayerINT8::initialize() } template -BertLayerINT8::BertLayerINT8(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t inter_size, - int sm, - float q_scaling, - int int8_mode, - cudaStream_t stream, +BertLayerINT8::BertLayerINT8(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t inter_size, + int sm, + float q_scaling, + int int8_mode, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - AttentionType attention_type, - bool sparse): + IAllocator* allocator, + bool is_free_buffer_after_forward, + AttentionType attention_type, + bool sparse): BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward), max_batch_size_(max_batch_size), max_seq_len_(max_seq_len), @@ -129,20 +130,20 @@ template void BertLayerINT8::allocateBuffer() { if (is_allocate_buffer_ == false) { - attn_out_buf_ = reinterpret_cast( - allocator_->malloc(sizeof(int32_t) * max_batch_size_ * max_seq_len_ * hidden_units_, false)); + attn_out_buf_ = reinterpret_cast(allocator_->reMalloc( + attn_out_buf_, sizeof(int32_t) * max_batch_size_ * max_seq_len_ * hidden_units_, false)); int8_buf_ = reinterpret_cast( - allocator_->malloc(sizeof(int8_t) * max_batch_size_ * max_seq_len_ * hidden_units_, false)); + allocator_->reMalloc(int8_buf_, sizeof(int8_t) * max_batch_size_ * max_seq_len_ * hidden_units_, false)); - layer_norm_tmp_buf_ = - reinterpret_cast(allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false)); + layer_norm_tmp_buf_ = reinterpret_cast(allocator_->reMalloc( + layer_norm_tmp_buf_, sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false)); - transformer_out_tmp_DataType_ = - reinterpret_cast(allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false)); + transformer_out_tmp_DataType_ = reinterpret_cast(allocator_->reMalloc( + transformer_out_tmp_DataType_, sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false)); // re-use transformer_out_tmp_DataType_ as col32_from_tensor_ - col32_from_tensor_ = transformer_out_tmp_DataType_; + col32_from_tensor_ = transformer_out_tmp_DataType_; is_allocate_buffer_ = true; } } @@ -151,10 +152,10 @@ template void BertLayerINT8::freeBuffer() { if (is_allocate_buffer_ == true) { - allocator_->free(attn_out_buf_); - allocator_->free(int8_buf_); - allocator_->free(layer_norm_tmp_buf_); - 
allocator_->free(transformer_out_tmp_DataType_); + allocator_->free((void**)(&attn_out_buf_)); + allocator_->free((void**)(&int8_buf_)); + allocator_->free((void**)(&layer_norm_tmp_buf_)); + allocator_->free((void**)(&transformer_out_tmp_DataType_)); is_allocate_buffer_ = false; } } @@ -165,12 +166,12 @@ void BertLayerINT8::freeBuffer() // quantize for int8_mode=1); for layer_idx != 0, the layout of input should be COL32. template -void BertLayerINT8::forward(std::vector* output_tensors, +void BertLayerINT8::forward(std::vector* output_tensors, const std::vector* input_tensors, - const BertLayerWeight* bert_layer_weight) + const BertLayerWeight* bert_layer_weight) { const BertLayerINT8Weight* bert_layer_int8_weight = (const BertLayerINT8Weight*)bert_layer_weight; - const ScaleList* scale_list = &(bert_layer_int8_weight->scale_list_); + const ScaleList* scale_list = &(bert_layer_int8_weight->scale_list_); // input_tensors: [input_query (token_num, hidden_dimension), // attention_mask (batch, 1, seqlen, seqlen), @@ -189,7 +190,7 @@ void BertLayerINT8::forward(std::vector* output_tensors, allocateBuffer(); T* from_tensor = (T*)input_tensors->at(0).data; - T* out_tensor = (T*)(output_tensors->at(0).data); + T* out_tensor = (T*)(output_tensors->at(0).data); const size_t m = input_tensors->at(0).shape[0]; const size_t n = hidden_units_; @@ -215,6 +216,7 @@ void BertLayerINT8::forward(std::vector* output_tensors, &attn_output_tensors, &int8_input_tensors, &bert_layer_int8_weight->attention_weights); // int32 I ; DataType O + PUSH_RANGE("post layernorm 1"); invokeAddBiasResidualLayerNormCol32(layer_norm_tmp_buf_, attn_out_buf_, from_tensor, @@ -226,10 +228,12 @@ void BertLayerINT8::forward(std::vector* output_tensors, stream_, &(scale_list->d_scale_list_[scale_list->p2_offset_ + 3 * hidden_units_]), &(scale_list->d_scale_list_[36])); + POP_RANGE; invokeQuantization(int8_buf_, layer_norm_tmp_buf_, m * n, &(scale_list->d_scale_list_[44 + 3]), stream_); std::vector ffn_input_tensors{Tensor{MEMORY_GPU, TYPE_INT8, std::vector{m, n}, int8_buf_}}; // reuse attn_output_tensors as ffn_output_tensors ffn_layer_->forward(&attn_output_tensors, &ffn_input_tensors, &bert_layer_int8_weight->ffn_weights); + PUSH_RANGE("post layernorm 2"); if (layer_idx != num_layer - 1) { // int32 I ; DataType O invokeAddBiasResidualLayerNormCol32( @@ -262,6 +266,7 @@ void BertLayerINT8::forward(std::vector* output_tensors, invokeTransposeMatrixCOL32ToColMajor(out_tensor, transformer_out_tmp_DataType_, m, n, stream_); } + POP_RANGE; } else if (int8_mode_ == 2 || int8_mode_ == 3) { @@ -286,7 +291,7 @@ void BertLayerINT8::forward(std::vector* output_tensors, else { attention_layer_->forward(&attn_output_tensors, input_tensors, &bert_layer_int8_weight->attention_weights); } - + PUSH_RANGE("post layernorm 1"); const int8_t* residual = layer_idx == 0 ? 
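// The BertLayerINT8 hunks above wrap the post-attention and post-FFN layernorm
// launches in PUSH_RANGE("post layernorm ...")/POP_RANGE pairs from
// nvtx_utils.h so they appear as named ranges in Nsight Systems timelines. A
// sketch of the underlying NVTX calls such macros typically wrap (the exact
// macro definitions in nvtx_utils.h may differ):
#include <nvToolsExt.h>

void launch_post_layernorm_example(/* kernel arguments elided */)
{
    nvtxRangePushA("post layernorm 1");
    // invokeAddBiasResidualLayerNormCol32(...);  // the kernel launch being profiled
    nvtxRangePop();
}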
int8_buf_ : (const int8_t*)from_tensor; // int8 IO #ifdef SPARSITY_ENABLED @@ -321,6 +326,7 @@ void BertLayerINT8::forward(std::vector* output_tensors, #ifdef SPARSITY_ENABLED } #endif + POP_RANGE; std::vector ffn_input_tensors{ Tensor{MEMORY_GPU, TYPE_INT8, std::vector{m, n}, layer_norm_tmp_buf_}}; // reuse attn_output_tensors as ffn_output_tensors @@ -329,6 +335,7 @@ void BertLayerINT8::forward(std::vector* output_tensors, // int8 IO #ifdef SPARSITY_ENABLED if (sparse_) { + PUSH_RANGE("post layernorm 2"); invokeAddBiasResidualLayerNormRow((int8_t*)out_tensor, (int8_t*)attn_out_buf_, (int8_t*)layer_norm_tmp_buf_, @@ -341,9 +348,11 @@ void BertLayerINT8::forward(std::vector* output_tensors, &(scale_list->d_scale_list_[56 + 1]), &(scale_list->d_scale_list_[44 + 1]), &(scale_list->d_scale_list_[60 + 3])); + POP_RANGE; } else { #endif + PUSH_RANGE("post layernorm 2"); invokeAddBiasResidualLayerNormCol32((int8_t*)out_tensor, (int8_t*)attn_out_buf_, (int8_t*)layer_norm_tmp_buf_, @@ -356,6 +365,7 @@ void BertLayerINT8::forward(std::vector* output_tensors, &(scale_list->d_scale_list_[56 + 1]), &(scale_list->d_scale_list_[44 + 1]), &(scale_list->d_scale_list_[60 + 3])); + POP_RANGE; #ifdef SPARSITY_ENABLED } #endif @@ -363,6 +373,7 @@ void BertLayerINT8::forward(std::vector* output_tensors, else { #ifdef SPARSITY_ENABLED if (sparse_) { + PUSH_RANGE("post layernorm 2"); invokeAddBiasResidualLayerNormRow(out_tensor, (int8_t*)attn_out_buf_, (int8_t*)layer_norm_tmp_buf_, @@ -374,10 +385,12 @@ void BertLayerINT8::forward(std::vector* output_tensors, stream_, &(scale_list->d_scale_list_[56 + 1]), &(scale_list->d_scale_list_[44 + 1])); + POP_RANGE; } else { #endif // int8 I ; DataType O + PUSH_RANGE("post layernorm 2"); invokeAddBiasResidualLayerNormCol32(transformer_out_tmp_DataType_, (int8_t*)attn_out_buf_, (int8_t*)layer_norm_tmp_buf_, @@ -391,6 +404,7 @@ void BertLayerINT8::forward(std::vector* output_tensors, &(scale_list->d_scale_list_[44 + 1])); invokeTransposeMatrixCOL32ToColMajor(out_tensor, transformer_out_tmp_DataType_, m, n, stream_); + POP_RANGE; #ifdef SPARSITY_ENABLED } #endif diff --git a/src/fastertransformer/models/bert_int8/BertLayerINT8.h b/src/fastertransformer/models/bert_int8/BertLayerINT8.h index ed90f13b7..e97a012cf 100644 --- a/src/fastertransformer/models/bert_int8/BertLayerINT8.h +++ b/src/fastertransformer/models/bert_int8/BertLayerINT8.h @@ -36,21 +36,21 @@ class BertLayerINT8: public BaseLayer { private: // buffer handling size_t max_batch_size_ = 0; - size_t max_seq_len_ = 0; + size_t max_seq_len_ = 0; // meta data - size_t head_num_; - size_t size_per_head_; - size_t inter_size_; - int sm_; - float q_scaling_; - size_t hidden_units_; + size_t head_num_; + size_t size_per_head_; + size_t inter_size_; + int sm_; + float q_scaling_; + size_t hidden_units_; AttentionType attention_type_; - int int8_mode_; - bool sparse_; + int int8_mode_; + bool sparse_; BaseAttentionLayer* attention_layer_; - FfnLayerINT8* ffn_layer_; + FfnLayerINT8* ffn_layer_; void allocateBuffer() override; void freeBuffer() override; @@ -61,34 +61,34 @@ class BertLayerINT8: public BaseLayer { protected: int32_t* attn_out_buf_; - int8_t* int8_buf_; - T* layer_norm_tmp_buf_; - T* transformer_out_tmp_DataType_; - T* col32_from_tensor_; + int8_t* int8_buf_; + T* layer_norm_tmp_buf_; + T* transformer_out_tmp_DataType_; + T* col32_from_tensor_; public: - BertLayerINT8(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t inter_size, - int sm, - float q_scaling, - int 
int8_mode, - cudaStream_t stream, + BertLayerINT8(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t inter_size, + int sm, + float q_scaling, + int int8_mode, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - AttentionType attention_type = AttentionType::UNFUSED_PADDED_MHA, - bool sparse = false); + IAllocator* allocator, + bool is_free_buffer_after_forward, + AttentionType attention_type = AttentionType::UNFUSED_PADDED_MHA, + bool sparse = false); BertLayerINT8(BertLayerINT8 const& bert_layer); ~BertLayerINT8(); - void forward(std::vector* output_tensors, + void forward(std::vector* output_tensors, const std::vector* input_tensors, - const BertLayerWeight* bert_layer_weight); + const BertLayerWeight* bert_layer_weight); }; } // namespace fastertransformer diff --git a/src/fastertransformer/models/bert_int8/BertLayerINT8Weight.h b/src/fastertransformer/models/bert_int8/BertLayerINT8Weight.h index aeef861a0..f9bb84409 100644 --- a/src/fastertransformer/models/bert_int8/BertLayerINT8Weight.h +++ b/src/fastertransformer/models/bert_int8/BertLayerINT8Weight.h @@ -63,37 +63,37 @@ struct BertLayerINT8Weight: BertLayerWeight { deviceFree(scale_list_ptr[0]); free(scale_list_ptr[1]); - attention_weights.query_weight.kernel = nullptr; - attention_weights.query_weight.bias = nullptr; - attention_weights.key_weight.kernel = nullptr; - attention_weights.key_weight.bias = nullptr; - attention_weights.value_weight.kernel = nullptr; - attention_weights.value_weight.bias = nullptr; + attention_weights.query_weight.kernel = nullptr; + attention_weights.query_weight.bias = nullptr; + attention_weights.key_weight.kernel = nullptr; + attention_weights.key_weight.bias = nullptr; + attention_weights.value_weight.kernel = nullptr; + attention_weights.value_weight.bias = nullptr; attention_weights.attention_output_weight.kernel = nullptr; - attention_weights.attention_output_weight.bias = nullptr; - attention_weights.scale_list_ptr = nullptr; - attn_layernorm_weights.gamma = nullptr; - attn_layernorm_weights.beta = nullptr; - ffn_weights.intermediate_weight.kernel = nullptr; - ffn_weights.intermediate_weight.bias = nullptr; - ffn_weights.output_weight.kernel = nullptr; - ffn_weights.output_weight.bias = nullptr; - ffn_weights.scale_list_ptr = nullptr; - ffn_layernorm_weights.gamma = nullptr; - ffn_layernorm_weights.beta = nullptr; - is_maintain_buffer = false; + attention_weights.attention_output_weight.bias = nullptr; + attention_weights.scale_list_ptr = nullptr; + attn_layernorm_weights.gamma = nullptr; + attn_layernorm_weights.beta = nullptr; + ffn_weights.intermediate_weight.kernel = nullptr; + ffn_weights.intermediate_weight.bias = nullptr; + ffn_weights.output_weight.kernel = nullptr; + ffn_weights.output_weight.bias = nullptr; + ffn_weights.scale_list_ptr = nullptr; + ffn_layernorm_weights.gamma = nullptr; + ffn_layernorm_weights.beta = nullptr; + is_maintain_buffer = false; } if (is_maintain_sp_buffer == true) { for (int i = 0; i < 6; i++) { deviceFree(sp_weights_ptr[i]); } - attention_weights.query_weight.sp_kernel = nullptr; - attention_weights.key_weight.sp_kernel = nullptr; - attention_weights.value_weight.sp_kernel = nullptr; + attention_weights.query_weight.sp_kernel = nullptr; + attention_weights.key_weight.sp_kernel = nullptr; + attention_weights.value_weight.sp_kernel = nullptr; attention_weights.attention_output_weight.sp_kernel = nullptr; - ffn_weights.intermediate_weight.sp_kernel 
= nullptr; - ffn_weights.output_weight.sp_kernel = nullptr; - is_maintain_sp_buffer = false; + ffn_weights.intermediate_weight.sp_kernel = nullptr; + ffn_weights.output_weight.sp_kernel = nullptr; + is_maintain_sp_buffer = false; } } @@ -125,7 +125,7 @@ struct BertLayerINT8Weight: BertLayerWeight { deviceMalloc(&weights_ptr[11], hidden_units_); cudaD2Dcpy(weights_ptr[11], other.weights_ptr[11], hidden_units_); - scale_list_.size_ = other.scale_list_.size_; + scale_list_.size_ = other.scale_list_.size_; scale_list_.p3_offset_ = other.scale_list_.p3_offset_; scale_list_.p4_offset_ = other.scale_list_.p4_offset_; deviceMalloc(&scale_list_ptr[0], scale_list_.size_); @@ -151,7 +151,7 @@ struct BertLayerINT8Weight: BertLayerWeight { */ hidden_units_ = other.hidden_units_; - inter_size_ = other.inter_size_; + inter_size_ = other.inter_size_; deviceMalloc(&weights_ptr[0], hidden_units_ * hidden_units_ * 3); cudaD2Dcpy(weights_ptr[0], other.weights_ptr[0], hidden_units_ * hidden_units_ * 3); deviceMalloc(&weights_ptr[1], hidden_units_ * 3); @@ -177,7 +177,7 @@ struct BertLayerINT8Weight: BertLayerWeight { deviceMalloc(&weights_ptr[11], hidden_units_); cudaD2Dcpy(weights_ptr[11], other.weights_ptr[11], hidden_units_); - scale_list_.size_ = other.scale_list_.size_; + scale_list_.size_ = other.scale_list_.size_; scale_list_.p3_offset_ = other.scale_list_.p3_offset_; scale_list_.p4_offset_ = other.scale_list_.p4_offset_; deviceMalloc(&scale_list_ptr[0], scale_list_.size_); @@ -188,47 +188,47 @@ struct BertLayerINT8Weight: BertLayerWeight { setWeightPtr(); } - LayerNormWeight attn_layernorm_weights; - LayerNormWeight ffn_layernorm_weights; + LayerNormWeight attn_layernorm_weights; + LayerNormWeight ffn_layernorm_weights; AttentionINT8Weight attention_weights; - FfnINT8Weight ffn_weights; - ScaleList scale_list_; + FfnINT8Weight ffn_weights; + ScaleList scale_list_; private: void setWeightPtr() { - attention_weights.query_weight.kernel = weights_ptr[0]; - attention_weights.query_weight.bias = weights_ptr[1]; - attention_weights.key_weight.kernel = weights_ptr[0] + hidden_units_ * hidden_units_; - attention_weights.key_weight.bias = weights_ptr[1] + hidden_units_; - attention_weights.value_weight.kernel = weights_ptr[0] + hidden_units_ * hidden_units_ * 2; - attention_weights.value_weight.bias = weights_ptr[1] + hidden_units_ * 2; + attention_weights.query_weight.kernel = weights_ptr[0]; + attention_weights.query_weight.bias = weights_ptr[1]; + attention_weights.key_weight.kernel = weights_ptr[0] + hidden_units_ * hidden_units_; + attention_weights.key_weight.bias = weights_ptr[1] + hidden_units_; + attention_weights.value_weight.kernel = weights_ptr[0] + hidden_units_ * hidden_units_ * 2; + attention_weights.value_weight.bias = weights_ptr[1] + hidden_units_ * 2; attention_weights.attention_output_weight.kernel = weights_ptr[2]; - attention_weights.attention_output_weight.bias = weights_ptr[3]; - attn_layernorm_weights.gamma = weights_ptr[4]; - attn_layernorm_weights.beta = weights_ptr[5]; - ffn_weights.intermediate_weight.kernel = weights_ptr[6]; - ffn_weights.intermediate_weight.bias = weights_ptr[7]; - ffn_weights.output_weight.kernel = weights_ptr[8]; - ffn_weights.output_weight.bias = weights_ptr[9]; - ffn_layernorm_weights.gamma = weights_ptr[10]; - ffn_layernorm_weights.beta = weights_ptr[11]; - - scale_list_.d_scale_list_ = scale_list_ptr[0]; - scale_list_.h_scale_list_ = scale_list_ptr[1]; + attention_weights.attention_output_weight.bias = weights_ptr[3]; + attn_layernorm_weights.gamma = 
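// setWeightPtr() above carves the fused QKV weight out of a single allocation:
// weights_ptr[0] holds the Q, K and V kernels back to back (each
// hidden_units_ x hidden_units_ elements) and weights_ptr[1] holds the three
// biases. A minimal sketch of that packed layout with plain host pointers
// (hypothetical helper; the real buffers are device memory):
#include <cstddef>

struct PackedQkv {
    const float* q_kernel;
    const float* k_kernel;
    const float* v_kernel;
};

PackedQkv splitQkv(const float* fused_kernel, size_t hidden_units)
{
    const size_t stride = hidden_units * hidden_units;  // element count of one kernel
    return PackedQkv{fused_kernel,                      // Q at offset 0
                     fused_kernel + stride,             // K follows Q
                     fused_kernel + 2 * stride};        // V follows K
}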
weights_ptr[4]; + attn_layernorm_weights.beta = weights_ptr[5]; + ffn_weights.intermediate_weight.kernel = weights_ptr[6]; + ffn_weights.intermediate_weight.bias = weights_ptr[7]; + ffn_weights.output_weight.kernel = weights_ptr[8]; + ffn_weights.output_weight.bias = weights_ptr[9]; + ffn_layernorm_weights.gamma = weights_ptr[10]; + ffn_layernorm_weights.beta = weights_ptr[11]; + + scale_list_.d_scale_list_ = scale_list_ptr[0]; + scale_list_.h_scale_list_ = scale_list_ptr[1]; attention_weights.scale_list_ptr = &scale_list_; - ffn_weights.scale_list_ptr = &scale_list_; + ffn_weights.scale_list_ptr = &scale_list_; is_maintain_buffer = true; } - int hidden_units_; - int inter_size_; - bool is_maintain_buffer = false; - T* weights_ptr[12]; + int hidden_units_; + int inter_size_; + bool is_maintain_buffer = false; + T* weights_ptr[12]; float* scale_list_ptr[2]; - T* sp_weights_ptr[6]; - bool is_maintain_sp_buffer = false; + T* sp_weights_ptr[6]; + bool is_maintain_sp_buffer = false; }; } // namespace fastertransformer diff --git a/src/fastertransformer/models/bert_int8/CMakeLists.txt b/src/fastertransformer/models/bert_int8/CMakeLists.txt index 47568e619..e965d63d7 100644 --- a/src/fastertransformer/models/bert_int8/CMakeLists.txt +++ b/src/fastertransformer/models/bert_int8/CMakeLists.txt @@ -20,9 +20,9 @@ set_property(TARGET BertLayerINT8 PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) target_link_libraries(BertLayerINT8 PUBLIC -lcublasLt -lcublas -lcudart -lcurand cublasMMWrapper cublasINT8MMWrapper UnfusedAttentionLayerINT8 FusedAttentionLayerINT8 FfnLayerINT8 layernorm_int8_kernels - layout_transformer_int8_kernels quantization_int8_kernels) + layout_transformer_int8_kernels quantization_int8_kernels nvtx_utils tensor) add_library(BertINT8 STATIC BertINT8.cc) set_property(TARGET BertINT8 PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET BertINT8 PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -target_link_libraries(BertINT8 PUBLIC -lcublasLt -lcublas -lcudart -lcurand BertLayerINT8 bert_preprocess_kernels) +target_link_libraries(BertINT8 PUBLIC -lcublasLt -lcublas -lcudart -lcurand BertLayerINT8 bert_preprocess_kernels nvtx_utils tensor) diff --git a/src/fastertransformer/models/decoder/CMakeLists.txt b/src/fastertransformer/models/decoder/CMakeLists.txt index 95a4d5b2c..b04b2d939 100644 --- a/src/fastertransformer/models/decoder/CMakeLists.txt +++ b/src/fastertransformer/models/decoder/CMakeLists.txt @@ -18,4 +18,4 @@ add_library(Decoder STATIC Decoder.cc) set_property(TARGET Decoder PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET Decoder PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) target_link_libraries(Decoder PUBLIC -lcudart cublasMMWrapper DecoderSelfAttentionLayer - DecoderCrossAttentionLayer FfnLayer layernorm_kernels add_residual_kernels) \ No newline at end of file + DecoderCrossAttentionLayer FfnLayer layernorm_kernels add_residual_kernels tensor) diff --git a/src/fastertransformer/models/decoder/Decoder.cc b/src/fastertransformer/models/decoder/Decoder.cc index 2826de178..17f0c6a46 100644 --- a/src/fastertransformer/models/decoder/Decoder.cc +++ b/src/fastertransformer/models/decoder/Decoder.cc @@ -52,18 +52,18 @@ template void Decoder::allocateBuffer() { if (is_allocate_buffer_ == false) { - decoder_normed_input_ = - reinterpret_cast(allocator_->malloc(sizeof(T) * max_batch_size_ * hidden_units_, false)); - self_attn_output_ = - reinterpret_cast(allocator_->malloc(sizeof(T) * max_batch_size_ * hidden_units_, false)); - normed_self_attn_output_ = - 
reinterpret_cast(allocator_->malloc(sizeof(T) * max_batch_size_ * hidden_units_, false)); - cross_attn_output_ = - reinterpret_cast(allocator_->malloc(sizeof(T) * max_batch_size_ * hidden_units_, false)); - normed_cross_attn_output_ = - reinterpret_cast(allocator_->malloc(sizeof(T) * max_batch_size_ * hidden_units_, false)); - decoder_layer_output_ = - reinterpret_cast(allocator_->malloc(sizeof(T) * max_batch_size_ * hidden_units_, false)); + decoder_normed_input_ = reinterpret_cast( + allocator_->reMalloc(decoder_normed_input_, sizeof(T) * max_batch_size_ * hidden_units_, false)); + self_attn_output_ = reinterpret_cast( + allocator_->reMalloc(self_attn_output_, sizeof(T) * max_batch_size_ * hidden_units_, false)); + normed_self_attn_output_ = reinterpret_cast( + allocator_->reMalloc(normed_self_attn_output_, sizeof(T) * max_batch_size_ * hidden_units_, false)); + cross_attn_output_ = reinterpret_cast( + allocator_->reMalloc(cross_attn_output_, sizeof(T) * max_batch_size_ * hidden_units_, false)); + normed_cross_attn_output_ = reinterpret_cast( + allocator_->reMalloc(normed_cross_attn_output_, sizeof(T) * max_batch_size_ * hidden_units_, false)); + decoder_layer_output_ = reinterpret_cast( + allocator_->reMalloc(decoder_layer_output_, sizeof(T) * max_batch_size_ * hidden_units_, false)); is_allocate_buffer_ = true; } } @@ -71,29 +71,32 @@ void Decoder::allocateBuffer() template void Decoder::allocateBuffer(size_t batch_size) { - decoder_normed_input_ = - reinterpret_cast(decoder_normed_input_, allocator_->malloc(sizeof(T) * batch_size * hidden_units_, false)); - self_attn_output_ = - reinterpret_cast(self_attn_output_, allocator_->malloc(sizeof(T) * batch_size * hidden_units_, false)); - normed_self_attn_output_ = reinterpret_cast(normed_self_attn_output_, - allocator_->malloc(sizeof(T) * batch_size * hidden_units_, false)); - cross_attn_output_ = - reinterpret_cast(cross_attn_output_, allocator_->malloc(sizeof(T) * batch_size * hidden_units_, false)); - normed_cross_attn_output_ = reinterpret_cast(normed_cross_attn_output_, - allocator_->malloc(sizeof(T) * batch_size * hidden_units_, false)); - decoder_layer_output_ = - reinterpret_cast(decoder_layer_output_, allocator_->malloc(sizeof(T) * batch_size * hidden_units_, false)); + if (is_allocate_buffer_ == false) { + decoder_normed_input_ = reinterpret_cast( + allocator_->reMalloc(decoder_normed_input_, sizeof(T) * batch_size * hidden_units_, false)); + self_attn_output_ = reinterpret_cast( + allocator_->reMalloc(self_attn_output_, sizeof(T) * batch_size * hidden_units_, false)); + normed_self_attn_output_ = reinterpret_cast( + allocator_->reMalloc(normed_self_attn_output_, sizeof(T) * batch_size * hidden_units_, false)); + cross_attn_output_ = reinterpret_cast( + allocator_->reMalloc(cross_attn_output_, sizeof(T) * batch_size * hidden_units_, false)); + normed_cross_attn_output_ = reinterpret_cast( + allocator_->reMalloc(normed_cross_attn_output_, sizeof(T) * batch_size * hidden_units_, false)); + decoder_layer_output_ = reinterpret_cast( + allocator_->reMalloc(decoder_layer_output_, sizeof(T) * batch_size * hidden_units_, false)); + is_allocate_buffer_ = true; + } } template void Decoder::freeBuffer() { - allocator_->free(decoder_normed_input_); - allocator_->free(self_attn_output_); - allocator_->free(normed_self_attn_output_); - allocator_->free(cross_attn_output_); - allocator_->free(normed_cross_attn_output_); - allocator_->free(decoder_layer_output_); + allocator_->free((void**)(&decoder_normed_input_)); + 
allocator_->free((void**)(&self_attn_output_)); + allocator_->free((void**)(&normed_self_attn_output_)); + allocator_->free((void**)(&cross_attn_output_)); + allocator_->free((void**)(&normed_cross_attn_output_)); + allocator_->free((void**)(&decoder_layer_output_)); } template @@ -109,15 +112,15 @@ bool Decoder::isValidBatchSize(size_t batch_size) } template -Decoder::Decoder(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - size_t inter_size, - size_t num_layer, - cudaStream_t stream, +Decoder::Decoder(size_t max_batch_size, + size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward): + IAllocator* allocator, + bool is_free_buffer_after_forward): BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward), max_batch_size_(max_batch_size), head_num_(head_num), @@ -153,8 +156,8 @@ Decoder::~Decoder() } template -void Decoder::forward(std::vector* output_tensors, - const std::vector* input_tensors, +void Decoder::forward(std::vector* output_tensors, + const std::vector* input_tensors, const std::vector>* decoder_layer_weight) { // input tensors: @@ -180,13 +183,13 @@ void Decoder::forward(std::vector* output_tensors, isValidBatchSize(input_tensors->at(0).shape[0]); allocateBuffer(input_tensors->at(0).shape[0]); - const size_t batch_size = (size_t)input_tensors->at(0).shape[0]; - const size_t mem_max_seq_len = (size_t)input_tensors->at(1).shape[1]; - const DataType data_type = getTensorType(); + const size_t batch_size = (size_t)input_tensors->at(0).shape[0]; + const size_t mem_max_seq_len = (size_t)input_tensors->at(1).shape[1]; + const DataType data_type = getTensorType(); for (uint l = 0; l < num_layer_; l++) { - const T* decoder_input = (const T*)((l == 0) ? input_tensors->at(0).data : decoder_layer_output_); - T* decoder_output = (T*)((l == num_layer_ - 1) ? output_tensors->at(0).data : decoder_layer_output_); + const T* decoder_input = (const T*)((l == 0) ? input_tensors->at(0).data : decoder_layer_output_); + T* decoder_output = (T*)((l == num_layer_ - 1) ? 
output_tensors->at(0).data : decoder_layer_output_); size_t self_key_cache_offset = l; for (auto t = output_tensors->at(1).shape.begin() + 1; t != output_tensors->at(1).shape.end(); ++t) { @@ -202,20 +205,24 @@ void Decoder::forward(std::vector* output_tensors, decoder_input, decoder_layer_weight->at(l).pre_layernorm_weights.gamma, decoder_layer_weight->at(l).pre_layernorm_weights.beta, + layernorm_eps_, batch_size, hidden_units_, stream_); sync_check_cuda_error(); - int tmp_0 = 0; + int tmp_0 = 0; std::vector self_attention_input_tensors{ Tensor{MEMORY_GPU, data_type, {batch_size, hidden_units_}, decoder_normed_input_}, input_tensors->at(3), input_tensors->at(5), Tensor{MEMORY_GPU, TYPE_INT32, {batch_size}, nullptr}, + Tensor{MEMORY_GPU, data_type, {batch_size}, nullptr}, + Tensor{MEMORY_CPU, TYPE_INT32, {1}, &tmp_0}, Tensor{MEMORY_CPU, TYPE_INT32, {1}, &tmp_0}, input_tensors->at(4), - input_tensors->at(6)}; + input_tensors->at(6), + Tensor{MEMORY_GPU, TYPE_BOOL, {batch_size}, (const bool*)nullptr}}; std::vector self_attention_output_tensors{ Tensor{MEMORY_GPU, data_type, {batch_size, hidden_units_}, self_attn_output_}, Tensor{MEMORY_GPU, @@ -237,6 +244,7 @@ void Decoder::forward(std::vector* output_tensors, decoder_layer_weight->at(l).self_attn_layernorm_weights.gamma, decoder_layer_weight->at(l).self_attn_layernorm_weights.beta, decoder_layer_weight->at(l).self_attention_weights.attention_output_weight.bias, + layernorm_eps_, batch_size, hidden_units_, stream_); @@ -269,6 +277,7 @@ void Decoder::forward(std::vector* output_tensors, decoder_layer_weight->at(l).cross_attn_layernorm_weights.gamma, decoder_layer_weight->at(l).cross_attn_layernorm_weights.beta, decoder_layer_weight->at(l).cross_attention_weights.attention_output_weight.bias, + layernorm_eps_, batch_size, hidden_units_, stream_); @@ -296,5 +305,8 @@ void Decoder::forward(std::vector* output_tensors, template class Decoder; template class Decoder; +#ifdef ENABLE_BF16 +template class Decoder<__nv_bfloat16>; +#endif } // namespace fastertransformer \ No newline at end of file diff --git a/src/fastertransformer/models/decoder/Decoder.h b/src/fastertransformer/models/decoder/Decoder.h index 1f806d642..ac512a459 100644 --- a/src/fastertransformer/models/decoder/Decoder.h +++ b/src/fastertransformer/models/decoder/Decoder.h @@ -37,15 +37,16 @@ class Decoder: public BaseLayer { // buffer handling size_t max_batch_size_ = 0; // meta data - size_t head_num_; - size_t size_per_head_; - size_t inter_size_; - size_t num_layer_; - size_t hidden_units_; + size_t head_num_; + size_t size_per_head_; + size_t inter_size_; + size_t num_layer_; + size_t hidden_units_; + static constexpr float layernorm_eps_ = 1e-6f; BaseAttentionLayer* self_attention_layer_; BaseAttentionLayer* cross_attention_layer_; - FfnLayer* ffn_layer_; + FfnLayer* ffn_layer_; void allocateBuffer() override; void allocateBuffer(size_t batch_size); @@ -55,30 +56,30 @@ class Decoder: public BaseLayer { void initialize(); protected: - T* decoder_normed_input_ = nullptr; - T* self_attn_output_ = nullptr; - T* normed_self_attn_output_ = nullptr; - T* cross_attn_output_ = nullptr; + T* decoder_normed_input_ = nullptr; + T* self_attn_output_ = nullptr; + T* normed_self_attn_output_ = nullptr; + T* cross_attn_output_ = nullptr; T* normed_cross_attn_output_ = nullptr; - T* decoder_layer_output_ = nullptr; + T* decoder_layer_output_ = nullptr; public: - Decoder(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - size_t inter_size, - size_t num_layer, - cudaStream_t 
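// The Decoder hunks above thread an explicit epsilon (layernorm_eps_ = 1e-6f,
// now a constexpr member) into every layernorm launch instead of relying on a
// value baked into the kernels. A reference host-side layer norm showing where
// that epsilon enters (a sketch for clarity only; the real work happens in the
// CUDA kernels):
#include <cmath>
#include <vector>

void layer_norm_ref(const std::vector<float>& x, const std::vector<float>& gamma,
                    const std::vector<float>& beta, std::vector<float>& out,
                    float eps = 1e-6f)
{
    const float n    = static_cast<float>(x.size());
    float       mean = 0.f;
    for (float v : x) mean += v;
    mean /= n;
    float var = 0.f;
    for (float v : x) var += (v - mean) * (v - mean);
    var /= n;
    const float inv_std = 1.f / std::sqrt(var + eps);  // eps keeps this finite for constant rows
    for (size_t i = 0; i < x.size(); ++i) {
        out[i] = gamma[i] * (x[i] - mean) * inv_std + beta[i];
    }
}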
stream, + Decoder(size_t max_batch_size, + size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward); + IAllocator* allocator, + bool is_free_buffer_after_forward); Decoder(Decoder const& decoder); ~Decoder(); - void forward(std::vector* output_tensors, - const std::vector* input_tensors, + void forward(std::vector* output_tensors, + const std::vector* input_tensors, const std::vector>* decoder_layer_weights); }; diff --git a/src/fastertransformer/models/decoder/DecoderLayerWeight.h b/src/fastertransformer/models/decoder/DecoderLayerWeight.h index 51dc2342d..26b4af226 100644 --- a/src/fastertransformer/models/decoder/DecoderLayerWeight.h +++ b/src/fastertransformer/models/decoder/DecoderLayerWeight.h @@ -41,31 +41,31 @@ struct DecoderLayerWeight { deviceFree(weights_ptr[i]); } - pre_layernorm_weights.beta = nullptr; - pre_layernorm_weights.gamma = nullptr; - self_attention_weights.query_weight.kernel = nullptr; - self_attention_weights.query_weight.bias = nullptr; + pre_layernorm_weights.beta = nullptr; + pre_layernorm_weights.gamma = nullptr; + self_attention_weights.query_weight.kernel = nullptr; + self_attention_weights.query_weight.bias = nullptr; self_attention_weights.attention_output_weight.kernel = nullptr; - self_attention_weights.attention_output_weight.bias = nullptr; - self_attn_layernorm_weights.beta = nullptr; - self_attn_layernorm_weights.gamma = nullptr; - - cross_attention_weights.query_weight.kernel = nullptr; - cross_attention_weights.query_weight.bias = nullptr; - cross_attention_weights.key_weight.kernel = nullptr; - cross_attention_weights.key_weight.bias = nullptr; - cross_attention_weights.value_weight.kernel = nullptr; - cross_attention_weights.value_weight.bias = nullptr; + self_attention_weights.attention_output_weight.bias = nullptr; + self_attn_layernorm_weights.beta = nullptr; + self_attn_layernorm_weights.gamma = nullptr; + + cross_attention_weights.query_weight.kernel = nullptr; + cross_attention_weights.query_weight.bias = nullptr; + cross_attention_weights.key_weight.kernel = nullptr; + cross_attention_weights.key_weight.bias = nullptr; + cross_attention_weights.value_weight.kernel = nullptr; + cross_attention_weights.value_weight.bias = nullptr; cross_attention_weights.attention_output_weight.kernel = nullptr; - cross_attention_weights.attention_output_weight.bias = nullptr; - cross_attn_layernorm_weights.beta = nullptr; - cross_attn_layernorm_weights.gamma = nullptr; + cross_attention_weights.attention_output_weight.bias = nullptr; + cross_attn_layernorm_weights.beta = nullptr; + cross_attn_layernorm_weights.gamma = nullptr; ffn_weights.intermediate_weight.kernel = nullptr; - ffn_weights.intermediate_weight.bias = nullptr; - ffn_weights.output_weight.kernel = nullptr; - ffn_weights.output_weight.bias = nullptr; - is_maintain_buffer = false; + ffn_weights.intermediate_weight.bias = nullptr; + ffn_weights.output_weight.kernel = nullptr; + ffn_weights.output_weight.bias = nullptr; + is_maintain_buffer = false; } } @@ -103,8 +103,8 @@ struct DecoderLayerWeight { DecoderLayerWeight& operator=(const DecoderLayerWeight& other) { - hidden_units_ = other.hidden_units_; - inter_size_ = other.inter_size_; + hidden_units_ = other.hidden_units_; + inter_size_ = other.inter_size_; mem_hidden_units_ = other.mem_hidden_units_; mallocWeights(); @@ -141,35 +141,35 @@ struct DecoderLayerWeight { LayerNormWeight 
self_attn_layernorm_weights; AttentionWeight cross_attention_weights; LayerNormWeight cross_attn_layernorm_weights; - FfnWeight ffn_weights; + FfnWeight ffn_weights; private: void setWeightPtr() { - pre_layernorm_weights.beta = weights_ptr[0]; - pre_layernorm_weights.gamma = weights_ptr[1]; - self_attention_weights.query_weight.kernel = weights_ptr[2]; - self_attention_weights.query_weight.bias = weights_ptr[3]; + pre_layernorm_weights.beta = weights_ptr[0]; + pre_layernorm_weights.gamma = weights_ptr[1]; + self_attention_weights.query_weight.kernel = weights_ptr[2]; + self_attention_weights.query_weight.bias = weights_ptr[3]; self_attention_weights.attention_output_weight.kernel = weights_ptr[4]; - self_attention_weights.attention_output_weight.bias = weights_ptr[5]; - self_attn_layernorm_weights.beta = weights_ptr[6]; - self_attn_layernorm_weights.gamma = weights_ptr[7]; - - cross_attention_weights.query_weight.kernel = weights_ptr[8]; - cross_attention_weights.query_weight.bias = weights_ptr[9]; - cross_attention_weights.key_weight.kernel = weights_ptr[10]; - cross_attention_weights.key_weight.bias = weights_ptr[11]; - cross_attention_weights.value_weight.kernel = weights_ptr[12]; - cross_attention_weights.value_weight.bias = weights_ptr[13]; + self_attention_weights.attention_output_weight.bias = weights_ptr[5]; + self_attn_layernorm_weights.beta = weights_ptr[6]; + self_attn_layernorm_weights.gamma = weights_ptr[7]; + + cross_attention_weights.query_weight.kernel = weights_ptr[8]; + cross_attention_weights.query_weight.bias = weights_ptr[9]; + cross_attention_weights.key_weight.kernel = weights_ptr[10]; + cross_attention_weights.key_weight.bias = weights_ptr[11]; + cross_attention_weights.value_weight.kernel = weights_ptr[12]; + cross_attention_weights.value_weight.bias = weights_ptr[13]; cross_attention_weights.attention_output_weight.kernel = weights_ptr[14]; - cross_attention_weights.attention_output_weight.bias = weights_ptr[15]; - cross_attn_layernorm_weights.beta = weights_ptr[16]; - cross_attn_layernorm_weights.gamma = weights_ptr[17]; + cross_attention_weights.attention_output_weight.bias = weights_ptr[15]; + cross_attn_layernorm_weights.beta = weights_ptr[16]; + cross_attn_layernorm_weights.gamma = weights_ptr[17]; ffn_weights.intermediate_weight.kernel = weights_ptr[18]; - ffn_weights.intermediate_weight.bias = weights_ptr[19]; - ffn_weights.output_weight.kernel = weights_ptr[20]; - ffn_weights.output_weight.bias = weights_ptr[21]; + ffn_weights.intermediate_weight.bias = weights_ptr[19]; + ffn_weights.output_weight.kernel = weights_ptr[20]; + ffn_weights.output_weight.bias = weights_ptr[21]; } void mallocWeights() @@ -201,11 +201,11 @@ struct DecoderLayerWeight { is_maintain_buffer = true; } - int hidden_units_; - int inter_size_; - int mem_hidden_units_; + int hidden_units_; + int inter_size_; + int mem_hidden_units_; bool is_maintain_buffer = false; - T* weights_ptr[22]; + T* weights_ptr[22]; }; } // namespace fastertransformer diff --git a/src/fastertransformer/models/decoding/CMakeLists.txt b/src/fastertransformer/models/decoding/CMakeLists.txt index 001ef48ff..9dfe7b028 100644 --- a/src/fastertransformer/models/decoding/CMakeLists.txt +++ b/src/fastertransformer/models/decoding/CMakeLists.txt @@ -18,7 +18,7 @@ add_library(Decoding STATIC Decoding.cc) set_property(TARGET Decoding PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET Decoding PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) target_link_libraries(Decoding PUBLIC -lcublas -lcudart -lcurand Decoder 
decoding_kernels - BeamSearchLayer DynamicDecodeLayer) + BeamSearchLayer DynamicDecodeLayer tensor) add_executable(decoding_gemm decoding_gemm.cc) target_link_libraries(decoding_gemm PUBLIC -lcublas -lcublasLt -lcudart decoding_gemm_func memory_utils) diff --git a/src/fastertransformer/models/decoding/Decoding.cc b/src/fastertransformer/models/decoding/Decoding.cc index cf7cd79cd..89587e8cf 100644 --- a/src/fastertransformer/models/decoding/Decoding.cc +++ b/src/fastertransformer/models/decoding/Decoding.cc @@ -37,14 +37,14 @@ void Decoding::initialize() allocator_, is_free_buffer_after_forward_); - dynamic_decode_layer_ = new DynamicDecodeLayer(vocab_size_, - vocab_size_padded_, - end_id_, - stream_, - cublas_wrapper_, - allocator_, - is_free_buffer_after_forward_, - cuda_device_prop_); + dynamic_decode_layer_ = new DynamicDecodeLayer(vocab_size_, + vocab_size_padded_, + end_id_, + stream_, + cublas_wrapper_, + allocator_, + is_free_buffer_after_forward_, + cuda_device_prop_); } template @@ -52,40 +52,48 @@ void Decoding::allocateBuffer() { FT_LOG_DEBUG(__PRETTY_FUNCTION__); if (is_allocate_buffer_ == false) { - const size_t batchxbeam = max_batch_size_ * beam_width_; + const size_t batchxbeam = max_batch_size_ * beam_width_; const size_t self_cache_size = num_layer_ * batchxbeam * max_seq_len_ * hidden_units_; - const size_t mem_cache_size = num_layer_ * batchxbeam * mem_max_seq_len_ * hidden_units_; + const size_t mem_cache_size = num_layer_ * batchxbeam * mem_max_seq_len_ * hidden_units_; if (vocab_size_ != vocab_size_padded_) { - padded_embedding_kernel_ = (T*)(allocator_->malloc(sizeof(T) * hidden_units_ * vocab_size_padded_, true)); - padded_embedding_bias_ = (T*)(allocator_->malloc(sizeof(T) * vocab_size_padded_, true)); + padded_embedding_kernel_ = (T*)(allocator_->reMalloc( + padded_embedding_kernel_, sizeof(T) * hidden_units_ * vocab_size_padded_, true)); + padded_embedding_bias_ = + (T*)(allocator_->reMalloc(padded_embedding_bias_, sizeof(T) * vocab_size_padded_, true)); padded_embedding_kernel_ptr_ = padded_embedding_kernel_; - padded_embedding_bias_ptr_ = padded_embedding_bias_; + padded_embedding_bias_ptr_ = padded_embedding_bias_; } - decoder_input_buf_ = (T*)(allocator_->malloc(sizeof(T) * batchxbeam * hidden_units_, false)); - decoder_output_buf_ = (T*)(allocator_->malloc(sizeof(T) * batchxbeam * hidden_units_, false)); - normed_decoder_output_buf_ = (T*)(allocator_->malloc(sizeof(T) * batchxbeam * hidden_units_, false)); - logits_buf_ = (T*)(allocator_->malloc(sizeof(T) * batchxbeam * vocab_size_padded_, false)); - cum_log_probs_ = (float*)(allocator_->malloc(sizeof(float) * batchxbeam, false)); - finished_buf_ = (bool*)(allocator_->malloc(sizeof(bool) * batchxbeam, false)); + decoder_input_buf_ = + (T*)(allocator_->reMalloc(decoder_input_buf_, sizeof(T) * batchxbeam * hidden_units_, false)); + decoder_output_buf_ = + (T*)(allocator_->reMalloc(decoder_output_buf_, sizeof(T) * batchxbeam * hidden_units_, false)); + normed_decoder_output_buf_ = + (T*)(allocator_->reMalloc(normed_decoder_output_buf_, sizeof(T) * batchxbeam * hidden_units_, false)); + logits_buf_ = (DynamicDecodeType*)(allocator_->reMalloc( + logits_buf_, sizeof(DynamicDecodeType) * batchxbeam * vocab_size_padded_, false)); + cum_log_probs_ = (float*)(allocator_->reMalloc(cum_log_probs_, sizeof(float) * batchxbeam, false)); + finished_buf_ = (bool*)(allocator_->reMalloc(finished_buf_, sizeof(bool) * batchxbeam, false)); h_finished_buf_ = new bool[batchxbeam]; - start_ids_buf_ = 
(int*)(allocator_->malloc(sizeof(int) * max_batch_size_, false)); - end_ids_buf_ = (int*)(allocator_->malloc(sizeof(int) * max_batch_size_, false)); + start_ids_buf_ = (int*)(allocator_->reMalloc(start_ids_buf_, sizeof(int) * max_batch_size_, false)); + end_ids_buf_ = (int*)(allocator_->reMalloc(end_ids_buf_, sizeof(int) * max_batch_size_, false)); - key_cache_ = (T*)(allocator_->malloc(sizeof(T) * self_cache_size, false)); - value_cache_ = (T*)(allocator_->malloc(sizeof(T) * self_cache_size, false)); + key_cache_ = (T*)(allocator_->reMalloc(key_cache_, sizeof(T) * self_cache_size, false)); + value_cache_ = (T*)(allocator_->reMalloc(value_cache_, sizeof(T) * self_cache_size, false)); if (beam_width_ > 1) { - cache_indirections_[0] = (int*)(allocator_->malloc(sizeof(int) * batchxbeam * max_seq_len_ * 2, true)); + cache_indirections_[0] = + (int*)(allocator_->reMalloc(cache_indirections_, sizeof(int) * batchxbeam * max_seq_len_ * 2, true)); cache_indirections_[1] = cache_indirections_[0] + batchxbeam * max_seq_len_; } - key_mem_cache_ = (T*)(allocator_->malloc(sizeof(T) * mem_cache_size, false)); - value_mem_cache_ = (T*)(allocator_->malloc(sizeof(T) * mem_cache_size, false)); + key_mem_cache_ = (T*)(allocator_->reMalloc(key_mem_cache_, sizeof(T) * mem_cache_size, false)); + value_mem_cache_ = (T*)(allocator_->reMalloc(value_mem_cache_, sizeof(T) * mem_cache_size, false)); - padded_pos_embedding_bias_ = (T*)(allocator_->malloc(sizeof(T) * vocab_size_padded_, false)); - output_ids_buf_ = (int*)(allocator_->malloc(sizeof(int) * batchxbeam * max_seq_len_, false)); - parent_ids_buf_ = (int*)(allocator_->malloc(sizeof(int) * batchxbeam * max_seq_len_, false)); + padded_pos_embedding_bias_ = + (T*)(allocator_->reMalloc(padded_pos_embedding_bias_, sizeof(T) * vocab_size_padded_, false)); + output_ids_buf_ = (int*)(allocator_->reMalloc(output_ids_buf_, sizeof(int) * batchxbeam * max_seq_len_, false)); + parent_ids_buf_ = (int*)(allocator_->reMalloc(parent_ids_buf_, sizeof(int) * batchxbeam * max_seq_len_, false)); is_allocate_buffer_ = true; } @@ -98,33 +106,33 @@ void Decoding::freeBuffer() if (is_allocate_buffer_ == true) { if (vocab_size_ != vocab_size_padded_) { padded_embedding_kernel_ptr_ = nullptr; - padded_embedding_bias_ptr_ = nullptr; - allocator_->free(padded_embedding_kernel_); - allocator_->free(padded_embedding_bias_); + padded_embedding_bias_ptr_ = nullptr; + allocator_->free((void**)(&padded_embedding_kernel_)); + allocator_->free((void**)(&padded_embedding_bias_)); } - allocator_->free(start_ids_buf_); - allocator_->free(end_ids_buf_); - - allocator_->free(decoder_input_buf_); - allocator_->free(decoder_output_buf_); - allocator_->free(normed_decoder_output_buf_); - allocator_->free(logits_buf_); - allocator_->free(cum_log_probs_); - allocator_->free(finished_buf_); + allocator_->free((void**)(&start_ids_buf_)); + allocator_->free((void**)(&end_ids_buf_)); + + allocator_->free((void**)(&decoder_input_buf_)); + allocator_->free((void**)(&decoder_output_buf_)); + allocator_->free((void**)(&normed_decoder_output_buf_)); + allocator_->free((void**)(&logits_buf_)); + allocator_->free((void**)(&cum_log_probs_)); + allocator_->free((void**)(&finished_buf_)); delete[] h_finished_buf_; - allocator_->free(key_cache_); - allocator_->free(value_cache_); + allocator_->free((void**)(&key_cache_)); + allocator_->free((void**)(&value_cache_)); if (cache_indirections_[0] != nullptr) { - allocator_->free(cache_indirections_[0]); + allocator_->free((void**)(&cache_indirections_)[0]); } - 
allocator_->free(key_mem_cache_); - allocator_->free(value_mem_cache_); + allocator_->free((void**)(&key_mem_cache_)); + allocator_->free((void**)(&value_mem_cache_)); - allocator_->free(padded_pos_embedding_bias_); + allocator_->free((void**)(&padded_pos_embedding_bias_)); - allocator_->free(output_ids_buf_); - allocator_->free(parent_ids_buf_); + allocator_->free((void**)(&output_ids_buf_)); + allocator_->free((void**)(&parent_ids_buf_)); is_allocate_buffer_ = false; } @@ -168,28 +176,28 @@ bool Decoding::isValidMemSeqLen(size_t seq_len) } template -Decoding::Decoding(size_t max_batch_size, - size_t max_seq_len, - size_t mem_max_seq_len, - size_t beam_width, - size_t head_num, - size_t size_per_head, - size_t inter_size, - size_t num_layer, - size_t vocab_size, - int start_id, - int end_id, - float beam_search_diversity_rate, - size_t top_k, - float top_p, - float temperature, - float len_penalty, - float repetition_penalty, - cudaStream_t stream, +Decoding::Decoding(size_t max_batch_size, + size_t max_seq_len, + size_t mem_max_seq_len, + size_t beam_width, + size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + size_t vocab_size, + int start_id, + int end_id, + float beam_search_diversity_rate, + uint top_k, + float top_p, + float temperature, + float len_penalty, + float repetition_penalty, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - cudaDeviceProp* cuda_device_prop): + IAllocator* allocator, + bool is_free_buffer_after_forward, + cudaDeviceProp* cuda_device_prop): BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, cuda_device_prop), max_batch_size_(max_batch_size), max_seq_len_(max_seq_len + 1), // allocater additional one to put the start token @@ -254,9 +262,9 @@ Decoding::~Decoding() } template -void Decoding::forward(std::vector* output_tensors, +void Decoding::forward(std::vector* output_tensors, const std::vector* input_tensors, - const DecodingWeight* decoding_weights) + const DecodingWeight* decoding_weights) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); // input_tensors: @@ -279,14 +287,24 @@ void Decoding::forward(std::vector* output_tensors, isValidMemSeqLen(input_tensors->at(0).shape[1]); allocateBuffer(); - const size_t batch_size = output_tensors->at(0).shape[1]; - const int max_input_length = 0; - const DataType data_type = getTensorType(); - const size_t mem_max_seq_len = input_tensors->at(0).shape[1]; + const size_t batch_size = output_tensors->at(0).shape[1]; + const int max_input_length = 0; + const DataType data_type = getTensorType(); + const size_t mem_max_seq_len = input_tensors->at(0).shape[1]; deviceFill(start_ids_buf_, batch_size, start_id_); deviceFill(end_ids_buf_, batch_size, end_id_); + const unsigned long long int random_seed = 0; + std::unordered_map runtime_args{ + {"random_seed", Tensor{MEMORY_CPU, TYPE_UINT64, {1}, &random_seed}}, + {"beam_search_diversity_rate", Tensor{MEMORY_CPU, TYPE_FP32, {1}, &beam_search_diversity_rate_}}, + {"temperature", Tensor{MEMORY_CPU, TYPE_FP32, {1}, &temperature_}}, + {"len_penalty", Tensor{MEMORY_CPU, TYPE_FP32, {1}, &len_penalty_}}, + {"repetition_penalty", Tensor{MEMORY_CPU, TYPE_FP32, {1}, &repetition_penalty_}}, + {"runtime_top_k", Tensor{MEMORY_CPU, TYPE_UINT32, {1}, &top_k_}}, + {"runtime_top_p", Tensor{MEMORY_CPU, TYPE_FP32, {1}, &top_p_}}}; + dynamic_decode_layer_->setup(batch_size, beam_width_, &runtime_args); if (beam_width_ > 1) { cudaMemsetAsync(cache_indirections_[0], 0, 2 * sizeof(int) 
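// Decoding::forward above now gathers the per-run sampling/beam-search knobs
// (random_seed, temperature, penalties, runtime_top_k/top_p, ...) into one map
// of named host tensors and hands it once to DynamicDecodeLayer::setup(),
// rather than re-sending them as step-time input tensors. A self-contained
// sketch of that pattern using a simplified stand-in for ft::Tensor (the real
// class also records memory location, dtype and shape, as in the hunk above):
#include <string>
#include <unordered_map>

struct HostScalarView {   // stand-in for a {MEMORY_CPU, dtype, {1}, ptr} Tensor
    const void* data;
};

void example_setup()
{
    const unsigned long long random_seed = 0ULL;
    const float              temperature = 1.0f;
    const float              top_p       = 0.5f;
    const unsigned int       top_k       = 4u;

    std::unordered_map<std::string, HostScalarView> runtime_args{
        {"random_seed", {&random_seed}},
        {"temperature", {&temperature}},
        {"runtime_top_k", {&top_k}},
        {"runtime_top_p", {&top_p}}};

    // dynamic_decode_layer->setup(batch_size, beam_width, &runtime_args);  // as in the diff
    (void)runtime_args;
}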
* batch_size * beam_width_ * max_seq_len_, stream_); } @@ -304,7 +322,7 @@ void Decoding::forward(std::vector* output_tensors, if (vocab_size_ == vocab_size_padded_) { padded_embedding_kernel_ptr_ = decoding_weights->post_decoder_embedding.kernel; - padded_embedding_bias_ptr_ = decoding_weights->post_decoder_embedding.bias; + padded_embedding_bias_ptr_ = decoding_weights->post_decoder_embedding.bias; } else { invokePaddingEmbedding(padded_embedding_kernel_, @@ -327,10 +345,10 @@ void Decoding::forward(std::vector* output_tensors, const std::vector self_v_cache_size = { num_layer_, batch_size * beam_width_, head_num_, (size_t)(max_seq_len_), size_per_head_}; - for (int step = 1; step <= (int)max_seq_len_; step++) { - const int ite = 0; + for (int step = 1; step < (int)max_seq_len_; step++) { + const int ite = 0; const int local_batch_size = batch_size; - const int id_offset = ite * local_batch_size * beam_width_; + const int id_offset = ite * local_batch_size * beam_width_; cudaD2Hcpy(h_finished_buf_, finished_buf_, batch_size * beam_width_); uint sum = 0; @@ -344,19 +362,18 @@ void Decoding::forward(std::vector* output_tensors, const int src_indir_idx = (step - 1) % 2; const int tgt_indir_idx = 1 - src_indir_idx; - invokeEmbeddingLookupPosEncoding(decoder_input_buf_, - decoding_weights->pre_decoder_embedding_table, - decoding_weights->position_encoding_table, - output_ids_buf_, - nullptr, - batch_size * beam_width_, - hidden_units_, - (T)sqrtf(float(hidden_units_)), - step - 1, - 0, - batch_size * beam_width_, - 0, - stream_); + invokeEmbeddingLookupPosEncodingPadCount(decoder_input_buf_, + decoding_weights->pre_decoder_embedding_table, + decoding_weights->position_encoding_table, + output_ids_buf_, + nullptr, + batch_size * beam_width_, + hidden_units_, + (T)sqrtf(float(hidden_units_)), + step - 1, + batch_size * beam_width_, + 0, + stream_); sync_check_cuda_error(); std::vector decoder_input_tensors{ @@ -389,54 +406,81 @@ void Decoding::forward(std::vector* output_tensors, decoder_output_buf_, decoding_weights->post_decoder_layernorm.gamma, decoding_weights->post_decoder_layernorm.beta, + layernorm_eps_, batch_size * beam_width_, hidden_units_, stream_); sync_check_cuda_error(); - cublas_wrapper_->Gemm(CUBLAS_OP_N, - CUBLAS_OP_N, - vocab_size_padded_, // n - batch_size * beam_width_, - hidden_units_, // k - padded_embedding_kernel_ptr_, - vocab_size_padded_, // n - normed_decoder_output_buf_, - hidden_units_, // k - logits_buf_, - vocab_size_padded_ /* n */); - - const int tmp_ite = 0; +#ifdef ENABLE_BF16 + bool is_bf16 = std::is_same::value; +#else + bool is_bf16 = false; +#endif + + if (is_bf16) { + float alpha = 1.0f; + float beta = 0.0f; +#ifdef ENABLE_BF16 + cublas_wrapper_->Gemm(CUBLAS_OP_N, + CUBLAS_OP_N, + vocab_size_padded_, // n + batch_size * beam_width_, + hidden_units_, // k + &alpha, + padded_embedding_kernel_ptr_, + CUDA_R_16BF, + vocab_size_padded_, // k + normed_decoder_output_buf_, + CUDA_R_16BF, + hidden_units_, // k + &beta, + logits_buf_, + CUDA_R_32F, + vocab_size_padded_, /* n */ + CUDA_R_32F, + cublasGemmAlgo_t(-1)); + + invokeAddBias( + logits_buf_, padded_embedding_bias_ptr_, batch_size * beam_width_, vocab_size_padded_, stream_); +#endif + } + else { + cublas_wrapper_->Gemm(CUBLAS_OP_N, + CUBLAS_OP_N, + vocab_size_padded_, // n + batch_size * beam_width_, + hidden_units_, // k + padded_embedding_kernel_ptr_, + vocab_size_padded_, // n + normed_decoder_output_buf_, + hidden_units_, // k + logits_buf_, + vocab_size_padded_ /* n */); + } + + const int tmp_ite = 0; 
const int tmp_local_batch_size = batch_size; - const bool has_diff_runtime_args = false; - const int runtime_top_k = (int)top_k_; - const float runtime_top_p = (float)top_p_; - const unsigned long long int random_seed = 0; - const bool is_initialize_random_table = step == 1; std::unordered_map dynamic_decode_input_tensors{ {"logits", Tensor{MEMORY_GPU, data_type, {batch_size, beam_width_, vocab_size_padded_}, logits_buf_}}, - {"embedding_bias", Tensor{MEMORY_GPU, data_type, {vocab_size_padded_}, padded_embedding_bias_ptr_}}, + {"embedding_bias", + Tensor{MEMORY_GPU, data_type, {vocab_size_padded_}, is_bf16 ? nullptr : padded_embedding_bias_ptr_}}, {"end_id", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size}, end_ids_buf_}}, {"step", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &step}}, {"max_input_length", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &max_input_length}}, {"input_lengths", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size, beam_width_}, nullptr}}, {"ite", Tensor{MEMORY_CPU, TYPE_UINT32, {1}, &tmp_ite}}, - {"has_diff_runtime_args", Tensor{MEMORY_CPU, TYPE_BOOL, {1}, &has_diff_runtime_args}}, {"src_key_cache", Tensor{MEMORY_GPU, data_type, self_k_cache_size, key_cache_}}, {"src_value_cache", Tensor{MEMORY_GPU, data_type, self_v_cache_size, value_cache_}}, {"src_cache_indirection", Tensor{ MEMORY_GPU, TYPE_INT32, {batch_size, beam_width_, max_seq_len_}, cache_indirections_[src_indir_idx]}}, {"local_batch_size", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &tmp_local_batch_size}}, - {"random_seed", Tensor{MEMORY_CPU, TYPE_UINT64, {1}, &random_seed}}, {"beam_search_diversity_rate", Tensor{MEMORY_CPU, TYPE_FP32, {1}, &beam_search_diversity_rate_}}, {"temperature", Tensor{MEMORY_CPU, TYPE_FP32, {1}, &temperature_}}, {"len_penalty", Tensor{MEMORY_CPU, TYPE_FP32, {1}, &len_penalty_}}, - {"repetition_penalty", Tensor{MEMORY_CPU, TYPE_FP32, {1}, &repetition_penalty_}}, - {"runtime_top_k", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &runtime_top_k}}, - {"runtime_top_p", Tensor{MEMORY_CPU, TYPE_FP32, {1}, &runtime_top_p}}, - {"is_initialize_random_table", Tensor{MEMORY_CPU, TYPE_BOOL, {1}, &is_initialize_random_table}}}; + {"repetition_penalty", Tensor{MEMORY_CPU, TYPE_FP32, {1}, &repetition_penalty_}}}; // TODO(bhsueh) Need to modify the forward function to use unordered_map // for (auto t = input_tensors->begin(); t != input_tensors->end(); ++t) { @@ -494,5 +538,8 @@ void Decoding::forward(std::vector* output_tensors, template class Decoding; template class Decoding; +#ifdef ENABLE_BF16 +template class Decoding<__nv_bfloat16>; +#endif } // namespace fastertransformer diff --git a/src/fastertransformer/models/decoding/Decoding.h b/src/fastertransformer/models/decoding/Decoding.h index 73a2945ad..a96f3326a 100644 --- a/src/fastertransformer/models/decoding/Decoding.h +++ b/src/fastertransformer/models/decoding/Decoding.h @@ -26,36 +26,50 @@ namespace fastertransformer { +// fallback to fp32 dynamic decoder when bf16 specified +template +struct fallBackType { + using Type = float; +}; + +template<> +struct fallBackType { + using Type = half; +}; + template class Decoding: public BaseLayer { private: // buffer handling - size_t max_batch_size_ = 0; - size_t max_seq_len_ = 0; + size_t max_batch_size_ = 0; + size_t max_seq_len_ = 0; size_t mem_max_seq_len_ = 0; // meta data - size_t beam_width_; - size_t head_num_; - size_t size_per_head_; - size_t inter_size_; - size_t num_layer_; - size_t vocab_size_; - - int start_id_; - int end_id_; - float beam_search_diversity_rate_; + size_t beam_width_; + size_t head_num_; + size_t 
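// Decoding.h above introduces a fallBackType trait so that a bf16 model still
// produces fp32 logits for the dynamic decode layer (fp16 stays fp16). Written
// out with full template syntax, the trait and its use look roughly like this
// (sketch; __nv_bfloat16 deliberately has no specialization, so it falls back
// to the primary template's float):
#include <cuda_fp16.h>

template<typename T>
struct fallBackType {
    using Type = float;   // default: anything without a specialization decodes in fp32
};

template<>
struct fallBackType<half> {
    using Type = half;    // fp16 models keep fp16 logits
};

// Inside Decoding<T>, per the header hunk:
//   using DynamicDecodeType = typename fallBackType<T>::Type;
//   DynamicDecodeType*                     logits_buf_ = nullptr;
//   DynamicDecodeLayer<DynamicDecodeType>* dynamic_decode_layer_;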
size_per_head_; + size_t inter_size_; + size_t num_layer_; + size_t vocab_size_; + static constexpr float layernorm_eps_ = 1e-6f; + + int start_id_; + int end_id_; + float beam_search_diversity_rate_; size_t hidden_units_; - size_t top_k_; - float top_p_; - float temperature_; - float len_penalty_; - float repetition_penalty_; + uint top_k_; + float top_p_; + float temperature_; + float len_penalty_; + float repetition_penalty_; // calculated data size_t vocab_size_padded_; - Decoder* decoder_; - DynamicDecodeLayer* dynamic_decode_layer_; + using DynamicDecodeType = typename fallBackType::Type; + + Decoder* decoder_; + DynamicDecodeLayer* dynamic_decode_layer_; void allocateBuffer() override; void freeBuffer() override; @@ -66,26 +80,26 @@ class Decoding: public BaseLayer { void initialize(); protected: - T* padded_embedding_kernel_ = nullptr; - T* padded_embedding_bias_ = nullptr; + T* padded_embedding_kernel_ = nullptr; + T* padded_embedding_bias_ = nullptr; const T* padded_embedding_kernel_ptr_ = nullptr; - const T* padded_embedding_bias_ptr_ = nullptr; + const T* padded_embedding_bias_ptr_ = nullptr; - T* decoder_input_buf_ = nullptr; - T* decoder_output_buf_ = nullptr; - T* normed_decoder_output_buf_ = nullptr; - T* logits_buf_ = nullptr; - float* cum_log_probs_ = nullptr; - bool* finished_buf_ = nullptr; - bool* h_finished_buf_ = nullptr; + T* decoder_input_buf_ = nullptr; + T* decoder_output_buf_ = nullptr; + T* normed_decoder_output_buf_ = nullptr; + DynamicDecodeType* logits_buf_ = nullptr; + float* cum_log_probs_ = nullptr; + bool* finished_buf_ = nullptr; + bool* h_finished_buf_ = nullptr; int* start_ids_buf_; int* end_ids_buf_; - T* key_cache_ = nullptr; - T* value_cache_ = nullptr; - T* key_mem_cache_ = nullptr; - T* value_mem_cache_ = nullptr; + T* key_cache_ = nullptr; + T* value_cache_ = nullptr; + T* key_mem_cache_ = nullptr; + T* value_mem_cache_ = nullptr; int* cache_indirections_[2] = {nullptr, nullptr}; T* padded_pos_embedding_bias_ = nullptr; @@ -94,36 +108,36 @@ class Decoding: public BaseLayer { int* parent_ids_buf_ = nullptr; public: - Decoding(size_t max_batch_size, - size_t max_seq_len, - size_t mem_max_seq_len, - size_t beam_width, - size_t head_num, - size_t size_per_head, - size_t inter_size, - size_t num_layer, - size_t vocab_size, - int start_id, - int end_id, - float beam_search_diversity_rate, - size_t top_k, - float top_p, - float temperature, - float len_penalty, - float repetition_penalty, - cudaStream_t stream, + Decoding(size_t max_batch_size, + size_t max_seq_len, + size_t mem_max_seq_len, + size_t beam_width, + size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + size_t vocab_size, + int start_id, + int end_id, + float beam_search_diversity_rate, + uint top_k, + float top_p, + float temperature, + float len_penalty, + float repetition_penalty, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - cudaDeviceProp* cuda_device_prop); + IAllocator* allocator, + bool is_free_buffer_after_forward, + cudaDeviceProp* cuda_device_prop); Decoding(Decoding const& Decoding); ~Decoding(); - void forward(std::vector* output_tensors, + void forward(std::vector* output_tensors, const std::vector* input_tensors, - const DecodingWeight* Decoding_weights); + const DecodingWeight* Decoding_weights); }; } // namespace fastertransformer diff --git a/src/fastertransformer/models/decoding/DecodingWeight.h b/src/fastertransformer/models/decoding/DecodingWeight.h index 
0cce2bd69..9f30603c3 100644 --- a/src/fastertransformer/models/decoding/DecodingWeight.h +++ b/src/fastertransformer/models/decoding/DecodingWeight.h @@ -55,13 +55,13 @@ struct DecodingWeight { deviceFree(weights_ptr[i]); } - position_encoding_table = nullptr; - pre_decoder_embedding_table = nullptr; - post_decoder_layernorm.beta = nullptr; - post_decoder_layernorm.gamma = nullptr; + position_encoding_table = nullptr; + pre_decoder_embedding_table = nullptr; + post_decoder_layernorm.beta = nullptr; + post_decoder_layernorm.gamma = nullptr; post_decoder_embedding.kernel = nullptr; - post_decoder_embedding.bias = nullptr; - is_maintain_buffer = false; + post_decoder_embedding.bias = nullptr; + is_maintain_buffer = false; } } @@ -90,11 +90,11 @@ struct DecodingWeight { DecodingWeight& operator=(const DecodingWeight& other) { - hidden_units_ = other.hidden_units_; - inter_size_ = other.inter_size_; - num_layer_ = other.num_layer_; - vocab_size_ = other.vocab_size_; - max_seq_len_ = other.max_seq_len_; + hidden_units_ = other.hidden_units_; + inter_size_ = other.inter_size_; + num_layer_ = other.num_layer_; + vocab_size_ = other.vocab_size_; + max_seq_len_ = other.max_seq_len_; mem_hidden_units_ = other.mem_hidden_units_; mallocWeights(); @@ -125,30 +125,30 @@ struct DecodingWeight { } std::vector> decoder_layer_weights; - const T* position_encoding_table = nullptr; - const T* pre_decoder_embedding_table = nullptr; - LayerNormWeight post_decoder_layernorm; - DenseWeight post_decoder_embedding; + const T* position_encoding_table = nullptr; + const T* pre_decoder_embedding_table = nullptr; + LayerNormWeight post_decoder_layernorm; + DenseWeight post_decoder_embedding; private: void setWeightPtr() { - position_encoding_table = weights_ptr[0]; - pre_decoder_embedding_table = weights_ptr[1]; - post_decoder_layernorm.beta = weights_ptr[2]; - post_decoder_layernorm.gamma = weights_ptr[3]; + position_encoding_table = weights_ptr[0]; + pre_decoder_embedding_table = weights_ptr[1]; + post_decoder_layernorm.beta = weights_ptr[2]; + post_decoder_layernorm.gamma = weights_ptr[3]; post_decoder_embedding.kernel = weights_ptr[4]; - post_decoder_embedding.bias = weights_ptr[5]; + post_decoder_embedding.bias = weights_ptr[5]; } - int hidden_units_; - int inter_size_; - int vocab_size_; - int num_layer_; - int max_seq_len_; - int mem_hidden_units_; + int hidden_units_; + int inter_size_; + int vocab_size_; + int num_layer_; + int max_seq_len_; + int mem_hidden_units_; bool is_maintain_buffer = false; - T* weights_ptr[6]; + T* weights_ptr[6]; }; } // namespace fastertransformer diff --git a/src/fastertransformer/models/decoding/decoding_gemm.cc b/src/fastertransformer/models/decoding/decoding_gemm.cc index e5f60d3f4..3eb8ab358 100644 --- a/src/fastertransformer/models/decoding/decoding_gemm.cc +++ b/src/fastertransformer/models/decoding/decoding_gemm.cc @@ -28,16 +28,16 @@ int main(int argc, char* argv[]) return 0; } - const int batch_size = atoi(argv[1]); - const int beam_width = atoi(argv[2]); - const int head_num = atoi(argv[3]); - const int size_per_head = atoi(argv[4]); - const int inter_size = atoi(argv[5]); - const int vocab_size = atoi(argv[6]); - const int max_mem_seq_len = atoi(argv[7]); - const int memory_hidden_units = atoi(argv[8]); + const int batch_size = atoi(argv[1]); + const int beam_width = atoi(argv[2]); + const int head_num = atoi(argv[3]); + const int size_per_head = atoi(argv[4]); + const int inter_size = atoi(argv[5]); + const int vocab_size = atoi(argv[6]); + const int max_mem_seq_len = 
atoi(argv[7]); + const int memory_hidden_units = atoi(argv[8]); const ft::CublasDataType data_type = static_cast(atoi(argv[9])); // 0 FP32, 1 FP16, 2 BF 16 - const bool is_append = argc == 11 ? ((bool)atoi(argv[10])) : false; + const bool is_append = argc == 11 ? ((bool)atoi(argv[10])) : false; printf("[INFO] arguments: \n"); printf(" batch_size: %d \n", batch_size); @@ -51,7 +51,7 @@ int main(int argc, char* argv[]) printf(" data_type: %d \n", data_type); std::cout << std::endl; - void* gemm_test_buf; + void* gemm_test_buf; size_t buf_size_in_byte = ft::calDecodingGemmTestBufSizeInByte(batch_size, beam_width, max_mem_seq_len, diff --git a/src/fastertransformer/models/gptj/CMakeLists.txt b/src/fastertransformer/models/gptj/CMakeLists.txt index d7d9d3e3e..759bae513 100644 --- a/src/fastertransformer/models/gptj/CMakeLists.txt +++ b/src/fastertransformer/models/gptj/CMakeLists.txt @@ -28,6 +28,7 @@ target_link_libraries(GptJDecoder PUBLIC -lcudart cublasMMWrapper layernorm_kernels add_residual_kernels GptJDecoderLayerWeight + tensor nccl_utils) add_library(GptJContextDecoder STATIC GptJContextDecoder.cc) @@ -39,6 +40,7 @@ target_link_libraries(GptJContextDecoder PUBLIC -lcudart cublasMMWrapper layernorm_kernels add_residual_kernels gpt_kernels + tensor nccl_utils) add_library(GptJWeight STATIC GptJWeight.cc) @@ -57,4 +59,5 @@ target_link_libraries(GptJ PUBLIC -lcudart DynamicDecodeLayer BaseBeamSearchLayer bert_preprocess_kernels + tensor GptJWeight) diff --git a/src/fastertransformer/models/gptj/GptJ.cc b/src/fastertransformer/models/gptj/GptJ.cc index cfd72b8a0..e2684e13b 100644 --- a/src/fastertransformer/models/gptj/GptJ.cc +++ b/src/fastertransformer/models/gptj/GptJ.cc @@ -33,6 +33,8 @@ void GptJ::initialize() inter_size_, num_layer_, rotary_embedding_dim_, + neox_rotary_style_, + layernorm_eps_, tensor_para_, pipeline_para_, stream_, @@ -49,6 +51,8 @@ void GptJ::initialize() inter_size_, num_layer_, rotary_embedding_dim_, + neox_rotary_style_, + layernorm_eps_, tensor_para_, pipeline_para_, stream_, @@ -75,12 +79,13 @@ void GptJ::allocateBuffer() } template -void GptJ::allocateBuffer(size_t batch_size, size_t beam_width, size_t max_seq_len, size_t max_input_len) +void GptJ::allocateBuffer( + size_t batch_size, size_t beam_width, size_t max_seq_len, size_t max_cache_seq_len, size_t max_input_len) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); - const size_t batchxbeam = batch_size * beam_width; - const size_t self_cache_size = - (num_layer_ / pipeline_para_.world_size_) * batchxbeam * max_seq_len * hidden_units_ / tensor_para_.world_size_; + const size_t batchxbeam = batch_size * beam_width; + const size_t self_cache_size = (num_layer_ / pipeline_para_.world_size_) * batchxbeam * max_cache_seq_len + * hidden_units_ / tensor_para_.world_size_; if (vocab_size_ != vocab_size_padded_) { padded_embedding_kernel_ = @@ -92,8 +97,9 @@ void GptJ::allocateBuffer(size_t batch_size, size_t beam_width, size_t max_se padded_embedding_bias_ptr_ = padded_embedding_bias_; } - input_attention_mask_ = - (T*)(allocator_->reMalloc(input_attention_mask_, sizeof(T) * batchxbeam * max_seq_len * max_seq_len, false)); + // TODO : memory allocation optimization --> [max_input_len, max_input_len + max_prefix_prompt_len] + input_attention_mask_ = (T*)(allocator_->reMalloc( + input_attention_mask_, sizeof(T) * batchxbeam * max_seq_len * max_cache_seq_len, false)); decoder_input_buf_ = (T*)(allocator_->reMalloc(decoder_input_buf_, sizeof(T) * batchxbeam * hidden_units_, false)); decoder_output_buf_ = 
(T*)(allocator_->reMalloc(decoder_output_buf_, sizeof(T) * batchxbeam * hidden_units_, false)); @@ -102,17 +108,28 @@ void GptJ::allocateBuffer(size_t batch_size, size_t beam_width, size_t max_se logits_buf_ = (float*)(allocator_->reMalloc(logits_buf_, sizeof(float) * batchxbeam * vocab_size_padded_, false)); nccl_logits_buf_ = (float*)(allocator_->reMalloc(nccl_logits_buf_, sizeof(float) * batchxbeam * vocab_size_padded_, false)); - cum_log_probs_ = (float*)(allocator_->reMalloc(cum_log_probs_, sizeof(float) * batchxbeam, false)); - finished_buf_ = (bool*)(allocator_->reMalloc(finished_buf_, sizeof(bool) * batchxbeam, false)); + cum_log_probs_ = (float*)(allocator_->reMalloc(cum_log_probs_, sizeof(float) * batchxbeam, false)); + finished_buf_ = (bool*)(allocator_->reMalloc(finished_buf_, sizeof(bool) * batchxbeam, false)); + sequence_lengths_ = (int*)(allocator_->reMalloc(sequence_lengths_, sizeof(int) * batchxbeam, false)); + masked_tokens_ = (bool*)(allocator_->reMalloc(masked_tokens_, sizeof(bool) * batchxbeam * max_cache_seq_len, true)); + h_finished_buf_ = new bool[batchxbeam]; - key_cache_ = (T*)(allocator_->reMalloc(key_cache_, sizeof(T) * self_cache_size * 2, true)); + key_cache_ = (T*)(allocator_->reMalloc(key_cache_, sizeof(T) * self_cache_size * 2, true)); value_cache_ = key_cache_ + self_cache_size; if (beam_width > 1) { - cache_indirections_[0] = - (int*)(allocator_->reMalloc(cache_indirections_[0], sizeof(int) * batchxbeam * max_seq_len * 2, true)); - cache_indirections_[1] = cache_indirections_[0] + batchxbeam * max_seq_len; + cache_indirections_[0] = (int*)(allocator_->reMalloc( + cache_indirections_[0], sizeof(int) * batchxbeam * max_cache_seq_len * 2, true)); + cache_indirections_[1] = cache_indirections_[0] + batchxbeam * max_cache_seq_len; } + tiled_total_padding_count_ = + (int*)allocator_->reMalloc(tiled_total_padding_count_, batchxbeam * sizeof(int), false); + + // prompt_learning weight batch ptrs + prompt_learning_weight_batch_ = + (const T**)(allocator_->reMalloc(prompt_learning_weight_batch_, sizeof(T*) * batchxbeam, false)); + tiled_prompt_lengths_buf_ = + (int*)(allocator_->reMalloc(tiled_prompt_lengths_buf_, sizeof(int) * batchxbeam, false)); tiled_input_ids_buf_ = (int*)(allocator_->reMalloc(tiled_input_ids_buf_, sizeof(int) * batchxbeam * max_input_len, true)); @@ -124,14 +141,20 @@ void GptJ::allocateBuffer(size_t batch_size, size_t beam_width, size_t max_se parent_ids_buf_ = (int*)(allocator_->reMalloc(parent_ids_buf_, sizeof(int) * batchxbeam * max_seq_len, true)); start_ids_buf_ = (int*)(allocator_->reMalloc(start_ids_buf_, sizeof(int) * batch_size, false)); - end_ids_buf_ = (int*)(allocator_->reMalloc(end_ids_buf_, sizeof(int) * batch_size, false)); + end_ids_buf_ = (int*)(allocator_->reMalloc(end_ids_buf_, sizeof(int) * batch_size, false)); + seq_limit_len_ = (uint32_t*)(allocator_->reMalloc(seq_limit_len_, sizeof(uint32_t) * batch_size, false)); - context_decoder_input_buf_ = (T*)(allocator_->reMalloc( + context_decoder_input_buf_ = (T*)(allocator_->reMalloc( context_decoder_input_buf_, sizeof(T) * batchxbeam * max_input_len * hidden_units_, false)); context_decoder_output_buf_ = (T*)(allocator_->reMalloc( context_decoder_output_buf_, sizeof(T) * batchxbeam * max_input_len * hidden_units_, false)); output_log_probs_buf_ = (float*)(allocator_->reMalloc(output_log_probs_buf_, sizeof(float) * batchxbeam * max_seq_len, false)); + + if (generation_should_stop_ == nullptr) { + cudaMallocHost(&generation_should_stop_, 1 * sizeof(bool)); + } + 
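// [Editor's sketch, not part of the patch] The buffers above are now sized by
// max_cache_seq_len (memory_len + max_prefix_prompt_length, set in the forward() changes
// further down) rather than by max_seq_len, and the K/V cache is split across pipeline and
// tensor parallel ranks. A standalone model of that arithmetic; the sizes in main() are
// placeholders, not values from the patch.
#include <cstddef>
#include <cstdio>

std::size_t kvCacheElemsPerRank(std::size_t num_layer, std::size_t pp_size, std::size_t tp_size,
                                std::size_t batch, std::size_t beam,
                                std::size_t max_cache_seq_len, std::size_t hidden_units)
{
    // Elements for K alone on one rank; allocateBuffer() reserves twice this amount and
    // places value_cache_ = key_cache_ + self_cache_size.
    return (num_layer / pp_size) * (batch * beam) * max_cache_seq_len * hidden_units / tp_size;
}

int main()
{
    const std::size_t elem_bytes = 2;  // FP16/BF16 storage
    std::size_t k_elems = kvCacheElemsPerRank(/*num_layer=*/28, /*pp=*/1, /*tp=*/2,
                                              /*batch=*/8, /*beam=*/1,
                                              /*max_cache_seq_len=*/512, /*hidden=*/4096);
    std::printf("K+V cache per rank: %.2f GiB\n",
                2.0 * double(k_elems) * double(elem_bytes) / double(1ull << 30));
    return 0;
}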
is_allocate_buffer_ = true; } @@ -141,70 +164,82 @@ void GptJ::freeBuffer() if (is_allocate_buffer_) { if (vocab_size_ != vocab_size_padded_) { padded_embedding_kernel_ptr_ = nullptr; - padded_embedding_bias_ptr_ = nullptr; - allocator_->free(padded_embedding_kernel_); - allocator_->free(padded_embedding_bias_); + padded_embedding_bias_ptr_ = nullptr; + allocator_->free((void**)(&padded_embedding_kernel_)); + allocator_->free((void**)(&padded_embedding_bias_)); } - allocator_->free(input_attention_mask_); - allocator_->free(decoder_input_buf_); - allocator_->free(decoder_output_buf_); - allocator_->free(normed_decoder_output_buf_); - allocator_->free(logits_buf_); - allocator_->free(nccl_logits_buf_); - allocator_->free(cum_log_probs_); - allocator_->free(finished_buf_); + allocator_->free((void**)(&input_attention_mask_)); + allocator_->free((void**)(&decoder_input_buf_)); + allocator_->free((void**)(&decoder_output_buf_)); + allocator_->free((void**)(&normed_decoder_output_buf_)); + allocator_->free((void**)(&logits_buf_)); + allocator_->free((void**)(&nccl_logits_buf_)); + allocator_->free((void**)(&cum_log_probs_)); + allocator_->free((void**)(&sequence_lengths_)); + allocator_->free((void**)(&finished_buf_)); delete[] h_finished_buf_; - allocator_->free(key_cache_); + allocator_->free((void**)(&key_cache_)); if (cache_indirections_[0] != nullptr) { - allocator_->free(cache_indirections_[0]); + allocator_->free((void**)(&cache_indirections_)[0]); } - allocator_->free(tiled_input_ids_buf_); - allocator_->free(tiled_input_lengths_buf_); + allocator_->free((void**)(&prompt_learning_weight_batch_)); + allocator_->free((void**)(&tiled_prompt_lengths_buf_)); + allocator_->free((void**)(&tiled_total_padding_count_)); + + allocator_->free((void**)(&tiled_input_ids_buf_)); + allocator_->free((void**)(&tiled_input_lengths_buf_)); + + allocator_->free((void**)(&transposed_output_ids_buf_)); + allocator_->free((void**)(&output_ids_buf_)); + allocator_->free((void**)(&parent_ids_buf_)); + allocator_->free((void**)(&masked_tokens_)); + + allocator_->free((void**)(&start_ids_buf_)); + allocator_->free((void**)(&end_ids_buf_)); + allocator_->free((void**)(&seq_limit_len_)); - allocator_->free(transposed_output_ids_buf_); - allocator_->free(output_ids_buf_); - allocator_->free(parent_ids_buf_); + allocator_->free((void**)(&context_decoder_input_buf_)); + allocator_->free((void**)(&context_decoder_output_buf_)); + allocator_->free((void**)(&output_log_probs_buf_)); - allocator_->free(start_ids_buf_); - allocator_->free(end_ids_buf_); + cudaFreeHost(generation_should_stop_); - allocator_->free(context_decoder_input_buf_); - allocator_->free(context_decoder_output_buf_); - allocator_->free(output_log_probs_buf_); is_allocate_buffer_ = false; } } template -GptJ::GptJ(size_t max_batch_size, - size_t max_seq_len, - size_t max_input_len, - size_t beam_width, - size_t head_num, - size_t size_per_head, - size_t inter_size, - size_t num_layer, - size_t vocab_size, - size_t rotary_embedding_dim, - int start_id, - int end_id, - float beam_search_diversity_rate, - size_t top_k, - float top_p, - unsigned long long random_seed, - float temperature, - float len_penalty, - float repetition_penalty, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - cudaDeviceProp* cuda_device_prop, +GptJ::GptJ(size_t max_batch_size, + size_t max_seq_len, + size_t max_input_len, + size_t beam_width, + size_t head_num, + size_t size_per_head, + size_t inter_size, + 
size_t num_layer, + size_t vocab_size, + size_t rotary_embedding_dim, + int start_id, + int end_id, + int prompt_learning_start_id, // only needed by p/prompt-tuning + PromptLearningType prompt_learning_type, + float beam_search_diversity_rate, + size_t top_k, + float top_p, + unsigned long long random_seed, + float temperature, + float len_penalty, + float repetition_penalty, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + cudaDeviceProp* cuda_device_prop, std::shared_ptr custom_all_reduce_comm, - int enable_custom_all_reduce): + int enable_custom_all_reduce): BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, cuda_device_prop), head_num_(head_num), size_per_head_(size_per_head), @@ -214,13 +249,15 @@ GptJ::GptJ(size_t max_batch_size, rotary_embedding_dim_(rotary_embedding_dim), start_id_(start_id), end_id_(end_id), + prompt_learning_start_id_(prompt_learning_start_id), + prompt_learning_type_(prompt_learning_type), hidden_units_(head_num * size_per_head), local_head_num_(head_num / 1) { - tensor_para_.world_size_ = 1; - tensor_para_.rank_ = 0; + tensor_para_.world_size_ = 1; + tensor_para_.rank_ = 0; pipeline_para_.world_size_ = 1; - pipeline_para_.rank_ = 0; + pipeline_para_.rank_ = 0; int local_vacab_size = ceil(vocab_size_ / 1.f / tensor_para_.world_size_); if (std::is_same::value) { @@ -231,34 +268,36 @@ GptJ::GptJ(size_t max_batch_size, } template -GptJ::GptJ(size_t max_batch_size, - size_t max_seq_len, - size_t max_input_len, - size_t beam_width, - size_t head_num, - size_t size_per_head, - size_t inter_size, - size_t num_layer, - size_t vocab_size, - size_t rotary_embedding_dim, - int start_id, - int end_id, - float beam_search_diversity_rate, - size_t top_k, - float top_p, - unsigned long long random_seed, - float temperature, - float len_penalty, - float repetition_penalty, - NcclParam tensor_para, - NcclParam pipeline_para, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - cudaDeviceProp* cuda_device_prop, +GptJ::GptJ(size_t max_batch_size, + size_t max_seq_len, + size_t max_input_len, + size_t beam_width, + size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + size_t vocab_size, + size_t rotary_embedding_dim, + int start_id, + int end_id, + int prompt_learning_start_id, // only needed by p/prompt-tuning + PromptLearningType prompt_learning_type, + float beam_search_diversity_rate, + size_t top_k, + float top_p, + unsigned long long random_seed, + float temperature, + float len_penalty, + float repetition_penalty, + NcclParam tensor_para, + NcclParam pipeline_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + cudaDeviceProp* cuda_device_prop, std::shared_ptr custom_all_reduce_comm, - int enable_custom_all_reduce): + int enable_custom_all_reduce): BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, cuda_device_prop), head_num_(head_num), size_per_head_(size_per_head), @@ -268,6 +307,8 @@ GptJ::GptJ(size_t max_batch_size, rotary_embedding_dim_(rotary_embedding_dim), start_id_(start_id), end_id_(end_id), + prompt_learning_start_id_(prompt_learning_start_id), + prompt_learning_type_(prompt_learning_type), hidden_units_(head_num * size_per_head), tensor_para_(tensor_para), pipeline_para_(pipeline_para), @@ -294,6 +335,8 @@ GptJ::GptJ(GptJ const& gpt): 
rotary_embedding_dim_(gpt.rotary_embedding_dim_), start_id_(gpt.start_id_), end_id_(gpt.end_id_), + prompt_learning_start_id_(gpt.prompt_learning_start_id_), + prompt_learning_type_(gpt.prompt_learning_type_), hidden_units_(gpt.hidden_units_), tensor_para_(gpt.tensor_para_), pipeline_para_(gpt.pipeline_para_), @@ -315,35 +358,52 @@ GptJ::~GptJ() } template -void GptJ::forward(std::vector* output_tensors, +void GptJ::registerCallback(callback_sig* fn, void* ctx) +{ + token_generated_cb_ = fn; + token_generated_ctx_ = ctx; +} + +template +void GptJ::unRegisterCallback() +{ + token_generated_cb_ = nullptr; + token_generated_ctx_ = nullptr; +} + +template +void GptJ::forward(std::vector* output_tensors, const std::vector* input_tensors, - const GptJWeight* gpt_weights) + const GptJWeight* gpt_weights) { FT_CHECK(false); } template -void GptJ::forward(std::unordered_map* output_tensors, +void GptJ::forward(std::unordered_map* output_tensors, const std::unordered_map* input_tensors, - const GptJWeight* gpt_weights) + const GptJWeight* gpt_weights) { // input_tensors: // input_ids [batch_size, max_input_length] // input_lengths [batch_size] - // max_output_seq_len [1] or [batch_size] on cpu + // prompt_learning_task_name_ids [batch_size] on cpu, optional + // output_seq_len [batch_size] on cpu // start_id [batch_size] on cpu, optional // end_id [batch_size] on cpu, optional // stop_words_list [batch_size, 2, stop_words_length], optional // bad_words_list [2, bad_words_length] or [batch_size, 2, bad_words_length], optional - // runtime_top_k [1] or [batch_size] on cpu, optional - // runtime_top_p [1] or [batch_size] on cpu, optional - // beam_search_diversity_rate [1] or [batch_size] on cpu, optional - // temperature [1] or [batch_size] on cpu, optional - // len_penalty [1] or [batch_size] on cpu, optional - // repetition_penalty [1] or [batch_size] on cpu, optional - // random_seed [1] or [batch_size] on cpu, optional - // prefix_soft_prompt_lengths [batch_size], optional - // prefix_soft_prompt_embedding [batch_size, max_prefix_soft_prompt_length, hidden_units], float, optional + // runtime_top_k [1] or [batch_size] on cpu, optional, uint. + // runtime_top_p [1] or [batch_size] on cpu, optional, float. + // beam_search_diversity_rate [1] or [batch_size] on cpu, optional, float. + // temperature [1] or [batch_size] on cpu, optional, float. + // len_penalty [1] or [batch_size] on cpu, optional, float. + // repetition_penalty [1] or [batch_size] on cpu, optional, float. + // random_seed [1] or [batch_size] on cpu, optional, unsigned long long int. 
+ // request_prompt_lengths [batch_size], optional + // request_prompt_embedding [batch_size, max_prompt_length, hidden_units], float, optional + // requst_prompt_type [batch_size], int, optional + // memory_len [1] on cpu, uint32, optional // output_tensors: // output_ids [batch_size, beam_width, max_output_seq_len] @@ -363,8 +423,8 @@ void GptJ::forward(std::unordered_map* output_tensors, FT_CHECK_WITH_INFO(output_tensors->size() >= 2, "output_tensors->size() >= 2"); FT_CHECK(input_tensors->at("input_ids").shape.size() == 2); FT_CHECK(input_tensors->at("input_lengths").shape.size() == 1); - FT_CHECK(input_tensors->at("max_output_seq_len").shape.size() == 1 - || input_tensors->at("max_output_seq_len").shape.size() == 2); + FT_CHECK(input_tensors->find("output_seq_len") != input_tensors->end() + && input_tensors->at("output_seq_len").shape.size() == 1); FT_CHECK(output_tensors->at("output_ids").shape.size() == 3); FT_CHECK(output_tensors->at("sequence_length").shape.size() == 2); FT_CHECK_WITH_INFO(input_tensors->at("input_ids").shape[0] == output_tensors->at("output_ids").shape[0], @@ -372,25 +432,121 @@ void GptJ::forward(std::unordered_map* output_tensors, const size_t batch_size = output_tensors->at("output_ids").shape[0]; const size_t beam_width = output_tensors->at("output_ids").shape[1]; + + PromptLearningType request_prompt_type = PromptLearningType::no_prompt; + int valid_prompt_inputs = input_tensors->count("request_prompt_type") + + input_tensors->count("request_prompt_lengths") + + input_tensors->count("request_prompt_embedding"); + + if (valid_prompt_inputs == 3) { + request_prompt_type = static_cast(input_tensors->at("request_prompt_type").getVal()); + FT_LOG_INFO("Apply prompt embedding from input, will ignore task name ids"); + } + else if (valid_prompt_inputs > 0) { + FT_LOG_WARNING( + "Prompts not applied: request_prompt_embedding, request_prompt_lengths, request_prompt_type are all needed!"); + } + if (request_prompt_type == PromptLearningType::prefix_prompt) { + FT_LOG_WARNING("Request prompt doesn't support prefix prompt currently!"); + } + + // Prefix Prompt Inputs + // Padding works as follows: p p x x i i i x x --> p p i i i x x x x (p denotes prompt, i denotes input, x denotes + // pad) + // TODO (perkzz): move unnecessary paddings + const int* prompt_learning_task_name_ids = + input_tensors->count("prompt_learning_task_name_ids") ? 
+ (const int*)(input_tensors->at("prompt_learning_task_name_ids").data) : + nullptr; + has_prefix_prompt_ = + (prompt_learning_task_name_ids != nullptr) && (prompt_learning_type_ == PromptLearningType::prefix_prompt); + int max_prefix_prompt_length = 0; + + FT_CHECK_WITH_INFO( + !(prompt_learning_task_name_ids != nullptr + && (prompt_learning_type_ == PromptLearningType::no_prompt + || prompt_learning_type_ == PromptLearningType::soft_prompt)), + "prompt_learning_type is prefix_prompt either p_prompt_tuning when prompt_learning_task_name_ids are provided."); + + // NOTE: Prefix Prompt PreProcessing + // get prefix_prompt_weight for each batch --> shape [batch, beam_width] + // --> ptrs with shape [num_layers, 2, num_heads, perfix_seq_len, size_per_head] + std::vector prefix_prompt_weight_batch_ptrs; + std::vector prefix_prompt_lengths; + if (has_prefix_prompt_) { + for (int bs_id = 0; bs_id < batch_size; ++bs_id) { + int task_id = prompt_learning_task_name_ids[bs_id]; + // throw errors when prompt task_name_ids are not found + std::pair prefix_prompt_weight_length_pair; + try { + prefix_prompt_weight_length_pair = gpt_weights->prompt_learning_table.at(task_id); + } + catch (const std::out_of_range& oor) { + FT_LOG_ERROR("prefix_prompt_weights_lengths not found for prompt task id: " + task_id); + throw oor; + } + for (int bw_id = 0; bw_id < beam_width; ++bw_id) { + prefix_prompt_weight_batch_ptrs.push_back(prefix_prompt_weight_length_pair.first); + prefix_prompt_lengths.push_back(prefix_prompt_weight_length_pair.second); + } + } + + max_prefix_prompt_length = *max_element(prefix_prompt_lengths.begin(), prefix_prompt_lengths.end()); + + FT_LOG_DEBUG("max_prefix_prompt_length: %d", max_prefix_prompt_length); + + if (max_prefix_prompt_length == 0) { + has_prefix_prompt_ = false; + FT_LOG_DEBUG("prompts are not applied !"); + } + } + int max_input_length = input_tensors->at("input_ids").shape[1]; - const size_t max_prefix_soft_prompt_length = input_tensors->count("prefix_soft_prompt_embedding") ? - input_tensors->at("prefix_soft_prompt_embedding").shape[1] : - 0; - const size_t max_output_seq_len = *std::max_element(input_tensors->at("max_output_seq_len").getPtr(), - input_tensors->at("max_output_seq_len").getPtr() - + input_tensors->at("max_output_seq_len").shape[0]) - + (max_input_length == 0 ? 1 : 0) // additional 1 to put start token - + max_prefix_soft_prompt_length; - const size_t max_seq_len = max_output_seq_len; - allocateBuffer(batch_size, beam_width, max_seq_len, max_input_length + max_prefix_soft_prompt_length); + FT_CHECK_WITH_INFO(!(max_input_length == 0 && max_prefix_prompt_length > 0), + "Prefix Prompt should come with inputs!"); + + // Prefix Soft Prompt (only support request prompt embedding currently) + has_prefix_soft_prompt_ = request_prompt_type == PromptLearningType::soft_prompt; + const size_t max_prefix_soft_prompt_length = + has_prefix_soft_prompt_ ? input_tensors->at("request_prompt_embedding").shape[1] : 0; + const size_t limit_len_offset = max_prefix_soft_prompt_length + (max_input_length == 0 ? 
1 : 0); + const size_t max_output_seq_len = input_tensors->at("output_seq_len").max() + limit_len_offset; + const size_t max_seq_len = max_output_seq_len; + + size_t memory_len = max_output_seq_len; + if (input_tensors->find("memory_len") != input_tensors->end()) { + memory_len = input_tensors->at("memory_len").getVal(); + } + /* TODO: could remove this constraint by changing how context decoder operates */ + FT_CHECK_WITH_INFO(max_input_length <= memory_len, + fmtstr("Memory size too low (%d) vs. input length (%d)", memory_len, max_input_length)); + + // max cache seq len should include max prefix prompt length as it has k/v states + const size_t max_cache_seq_len = memory_len + max_prefix_prompt_length; + + if (max_cache_seq_len < max_seq_len) { + FT_LOG_WARNING("max_cache_seq_len (%d) is less than max_seq_len (%d). " + "Note that this reduces the memory cost of k/v cache, but may hurt the accuracy.", + max_cache_seq_len, + max_seq_len); + } + else if (max_cache_seq_len > max_seq_len) { + FT_LOG_WARNING("max_cache_seq_len (%d) is larger than max_seq_len (%d). " + "This may lead to additional memory cost. Suggest to use smaller max_cache_seq_len.", + max_cache_seq_len, + max_seq_len); + } + + allocateBuffer( + batch_size, beam_width, max_seq_len, max_cache_seq_len, max_input_length + max_prefix_soft_prompt_length); sync_check_cuda_error(); - bool has_diff_runtime_args = hasDiffRuntimeArgs(input_tensors); - const bool has_per_item_requested_length = input_tensors->at("max_output_seq_len").shape.size() > 1; + setSeqLimitLen(seq_limit_len_, input_tensors->at("output_seq_len"), limit_len_offset, batch_size); - int* sequence_lengths = (int*)(output_tensors->at("sequence_length").data); - const DataType data_type = getTensorType(); + const DataType data_type = getTensorType(); + const cudaDataType_t gemm_data_type = getCudaDataType(); + dynamic_decode_layer_->setup(batch_size, beam_width, input_tensors); handleOptArg(input_tensors, "start_id", start_ids_buf_, start_id_, batch_size); handleOptArg(input_tensors, "end_id", end_ids_buf_, end_id_, batch_size); @@ -398,24 +554,41 @@ void GptJ::forward(std::unordered_map* output_tensors, batch_size * beam_width, local_head_num_, size_per_head_ / (16 / sizeof(T)), - max_output_seq_len, + max_cache_seq_len, 16 / sizeof(T)}; const std::vector self_v_cache_shape = {num_layer_ / pipeline_para_.world_size_, batch_size * beam_width, local_head_num_, - max_output_seq_len, + max_cache_seq_len, size_per_head_}; // initialize the output ids and parent ids cudaMemsetAsync(output_ids_buf_, 0, sizeof(int) * batch_size * beam_width * max_seq_len, stream_); cudaMemsetAsync(parent_ids_buf_, 0, sizeof(int) * batch_size * beam_width * max_seq_len, stream_); + cudaMemsetAsync(masked_tokens_, false, sizeof(bool) * batch_size * beam_width * max_cache_seq_len, stream_); + cudaMemsetAsync(tiled_total_padding_count_, 0, sizeof(int) * batch_size * beam_width, stream_); if (beam_width > 1) { cudaMemsetAsync(cache_indirections_[0], 0, 2 * sizeof(int) * batch_size * beam_width * max_seq_len, stream_); } + + // Prefix prompts + if (has_prefix_prompt_) { + cudaMemcpyAsync(prompt_learning_weight_batch_, + prefix_prompt_weight_batch_ptrs.data(), + sizeof(T*) * batch_size * beam_width, + cudaMemcpyDefault, + stream_); + cudaMemcpyAsync(tiled_prompt_lengths_buf_, + prefix_prompt_lengths.data(), + sizeof(int) * batch_size * beam_width, + cudaMemcpyDefault, + stream_); + } + sync_check_cuda_error(); // handle first step - if (input_tensors->count("prefix_soft_prompt_embedding") || 
max_input_length >= 1) { + if (has_prefix_prompt_ || has_prefix_soft_prompt_ || max_input_length > 1) { invokeTileGptInputs(tiled_input_ids_buf_, tiled_input_lengths_buf_, (int*)input_tensors->at("input_ids").data, @@ -426,23 +599,23 @@ void GptJ::forward(std::unordered_map* output_tensors, stream_); sync_check_cuda_error(); - if (input_tensors->count("prefix_soft_prompt_embedding")) { + if (has_prefix_soft_prompt_) { inputIdsEmbeddingLookupPosEncodingSoftPromptParam param; - param.from_tensor = context_decoder_input_buf_; - param.output_ids = output_ids_buf_; - param.input_lengths = tiled_input_lengths_buf_; - param.embedding_table = gpt_weights->pre_decoder_embedding_table; - param.pos_table = gpt_weights->position_encoding_table; - param.prefix_soft_prompt_embedding = input_tensors->at("prefix_soft_prompt_embedding").getPtr(); - param.prefix_soft_prompt_lengths = input_tensors->at("prefix_soft_prompt_lengths").getPtr(); - param.input_ids = tiled_input_ids_buf_; - param.start_step = 1; - param.max_input_length = max_input_length; + param.from_tensor = context_decoder_input_buf_; + param.output_ids = output_ids_buf_; + param.input_lengths = tiled_input_lengths_buf_; + param.embedding_table = gpt_weights->pre_decoder_embedding_table; + param.pos_table = gpt_weights->position_encoding_table; + param.prefix_soft_prompt_embedding = input_tensors->at("request_prompt_embedding").getPtr(); + param.prefix_soft_prompt_lengths = input_tensors->at("request_prompt_lengths").getPtr(); + param.input_ids = tiled_input_ids_buf_; + param.start_step = 1; + param.max_input_length = max_input_length; param.max_prefix_soft_prompt_length = max_prefix_soft_prompt_length; - param.batch_size = batch_size; - param.beam_width = beam_width; - param.hidden_units = hidden_units_; - param.stream = stream_; + param.batch_size = batch_size; + param.beam_width = beam_width; + param.hidden_units = hidden_units_; + param.stream = stream_; invokeInputIdsEmbeddingLookupPosEncodingSoftPrompt(param); sync_check_cuda_error(); @@ -453,6 +626,7 @@ void GptJ::forward(std::unordered_map* output_tensors, output_ids_buf_, gpt_weights->pre_decoder_embedding_table, gpt_weights->position_encoding_table, + pPromptTuningParam{}, // p/prompt tuning tiled_input_ids_buf_, 1, max_input_length, @@ -463,8 +637,13 @@ void GptJ::forward(std::unordered_map* output_tensors, sync_check_cuda_error(); } - invokeBuildDecoderAttentionMask( - input_attention_mask_, tiled_input_lengths_buf_, batch_size * beam_width, max_input_length, stream_); + invokeBuildDecoderAttentionMask(input_attention_mask_, + tiled_input_lengths_buf_, + tiled_prompt_lengths_buf_, + batch_size * beam_width, + max_input_length, + max_prefix_prompt_length, + stream_); sync_check_cuda_error(); std::unordered_map decoder_input_tensors{ @@ -476,9 +655,22 @@ void GptJ::forward(std::unordered_map* output_tensors, {"attention_mask", Tensor{MEMORY_GPU, data_type, - {batch_size * beam_width, 1, (size_t)max_input_length, (size_t)max_input_length}, + {batch_size * beam_width, + 1, + (size_t)max_input_length, + (size_t)(max_input_length + max_prefix_prompt_length)}, input_attention_mask_}}, - {"input_lengths", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size * beam_width}, tiled_input_lengths_buf_}}}; + {"input_lengths", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size * beam_width}, tiled_input_lengths_buf_}}, + {"d_prefix_prompt_batch", + Tensor{MEMORY_GPU, + data_type, + {batch_size * beam_width}, + has_prefix_prompt_ ? 
prompt_learning_weight_batch_ : nullptr}}, + {"d_prefix_prompt_lengths", + Tensor{MEMORY_GPU, + TYPE_INT32, + {batch_size * beam_width}, + has_prefix_prompt_ ? tiled_prompt_lengths_buf_ : nullptr}}}; std::unordered_map decoder_output_tensors{ {"decoder_output", @@ -493,9 +685,9 @@ void GptJ::forward(std::unordered_map* output_tensors, gpt_context_decoder_->forward( &decoder_output_tensors, &decoder_input_tensors, &gpt_weights->decoder_layer_weights); - + sync_check_cuda_error(); invokeDecodingInitialize(finished_buf_, - sequence_lengths, + sequence_lengths_, nullptr, cum_log_probs_, start_ids_buf_, @@ -506,9 +698,11 @@ void GptJ::forward(std::unordered_map* output_tensors, sync_check_cuda_error(); } else if (max_input_length == 0) { + FT_CHECK(prompt_learning_type_ == PromptLearningType::no_prompt + && request_prompt_type == PromptLearningType::no_prompt); // Not support prompts in this case max_input_length++; invokeDecodingInitialize(finished_buf_, - sequence_lengths, + sequence_lengths_, output_ids_buf_, cum_log_probs_, start_ids_buf_, @@ -525,8 +719,10 @@ void GptJ::forward(std::unordered_map* output_tensors, sync_check_cuda_error(); } else if (max_input_length == 1) { + FT_CHECK(prompt_learning_type_ == PromptLearningType::no_prompt + && request_prompt_type == PromptLearningType::no_prompt); // Not support prompts in this case invokeDecodingInitialize(finished_buf_, - sequence_lengths, + sequence_lengths_, nullptr, cum_log_probs_, start_ids_buf_, @@ -554,7 +750,7 @@ void GptJ::forward(std::unordered_map* output_tensors, if (vocab_size_ == vocab_size_padded_) { padded_embedding_kernel_ptr_ = gpt_weights->post_decoder_embedding.kernel; - padded_embedding_bias_ptr_ = gpt_weights->post_decoder_embedding.bias; + padded_embedding_bias_ptr_ = gpt_weights->post_decoder_embedding.bias; } else { cudaMemcpyAsync(padded_embedding_kernel_, @@ -570,6 +766,16 @@ void GptJ::forward(std::unordered_map* output_tensors, sync_check_cuda_error(); } + invokeMaskPaddingTokens(masked_tokens_, + (const int*)(input_tensors->at("input_lengths").data), // not_tiled + tiled_prompt_lengths_buf_, + max_cache_seq_len, + max_input_length + max_prefix_prompt_length, + 0, + batch_size, + beam_width, + stream_); + for (int step = max_input_length; step < (int)max_output_seq_len; step++) { const int src_indir_idx = (step - max_input_length) % 2; const int tgt_indir_idx = 1 - src_indir_idx; @@ -577,29 +783,30 @@ void GptJ::forward(std::unordered_map* output_tensors, const size_t local_batch_size = getLocalBatchSize(batch_size, 1, pipeline_para_.world_size_); FT_CHECK(batch_size % local_batch_size == 0); const size_t iteration_num = batch_size / local_batch_size; + *generation_should_stop_ = true; for (uint ite = 0; ite < iteration_num; ++ite) { - const int id_offset = ite * local_batch_size * beam_width; - const int hidden_units_offset = id_offset * hidden_units_; + const int id_offset = ite * local_batch_size * beam_width; + const int hidden_units_offset = id_offset * hidden_units_; const int vocab_size_units_offset = id_offset * vocab_size_padded_; if (!(max_input_length > 1 && step == max_input_length)) { if (pipeline_para_.rank_ == 0) { - invokeEmbeddingLookupPosEncoding(decoder_input_buf_ + hidden_units_offset, - gpt_weights->pre_decoder_embedding_table, - gpt_weights->position_encoding_table, - output_ids_buf_ + id_offset, - tiled_input_lengths_buf_ + id_offset, - local_batch_size * beam_width, - hidden_units_, - (T)(1.0f), - step - 1, - max_input_length, - batch_size * beam_width, - 0, - stream_); + 
invokeEmbeddingLookupPosEncodingPadCount(decoder_input_buf_ + hidden_units_offset, + gpt_weights->pre_decoder_embedding_table, + gpt_weights->position_encoding_table, + output_ids_buf_ + id_offset, + tiled_total_padding_count_ + id_offset, + local_batch_size * beam_width, + hidden_units_, + (T)(1.0f), + step - 1, + batch_size * beam_width, + 0, + stream_); sync_check_cuda_error(); } + std::unordered_map decoder_input_tensors{ {"decoder_input", Tensor{MEMORY_GPU, @@ -609,12 +816,18 @@ void GptJ::forward(std::unordered_map* output_tensors, {"finished", Tensor{MEMORY_GPU, TYPE_BOOL, {local_batch_size * beam_width}, finished_buf_ + id_offset}}, {"sequence_lengths", - Tensor{MEMORY_GPU, TYPE_INT32, {local_batch_size * beam_width}, sequence_lengths + id_offset}}, - {"input_lengths", + Tensor{MEMORY_GPU, TYPE_INT32, {local_batch_size * beam_width}, sequence_lengths_ + id_offset}}, + {"total_padding_tokens", Tensor{MEMORY_GPU, TYPE_INT32, {local_batch_size * beam_width}, - tiled_input_lengths_buf_ + id_offset}}, + tiled_total_padding_count_ + id_offset}}, + {"d_prefix_prompt_lengths", + Tensor{MEMORY_GPU, + TYPE_INT32, + {local_batch_size}, + has_prefix_prompt_ ? (tiled_prompt_lengths_buf_ + id_offset) : nullptr}}, + {"max_prefix_prompt_length", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &max_prefix_prompt_length}}, {"max_input_length", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &max_input_length}}, {"step", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &step}}, {"ite", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &ite}}, @@ -623,7 +836,12 @@ void GptJ::forward(std::unordered_map* output_tensors, TYPE_INT32, {local_batch_size, beam_width, max_output_seq_len}, beam_width > 1 ? cache_indirections_[src_indir_idx] + id_offset * max_output_seq_len : - nullptr}}}; + nullptr}}, + {"masked_tokens", + Tensor{MEMORY_GPU, + TYPE_BOOL, + {local_batch_size * beam_width, max_cache_seq_len}, + masked_tokens_ + id_offset * max_cache_seq_len}}}; std::unordered_map decoder_output_tensors{ {"decoder_output", Tensor{MEMORY_GPU, @@ -641,6 +859,7 @@ void GptJ::forward(std::unordered_map* output_tensors, decoder_output_buf_ + hidden_units_offset, gpt_weights->post_decoder_layernorm.gamma, gpt_weights->post_decoder_layernorm.beta, + layernorm_eps_, local_batch_size * beam_width, hidden_units_, stream_); @@ -648,7 +867,7 @@ void GptJ::forward(std::unordered_map* output_tensors, if (tensor_para_.world_size_ == 1) { float alpha = 1.0f; - float beta = 0.0f; + float beta = 0.0f; cublas_wrapper_->Gemm(CUBLAS_OP_T, CUBLAS_OP_N, vocab_size_padded_, // n @@ -656,10 +875,10 @@ void GptJ::forward(std::unordered_map* output_tensors, hidden_units_, // k &alpha, padded_embedding_kernel_ptr_, - sizeof(T) == 2 ? CUDA_R_16F : CUDA_R_32F, + gemm_data_type, hidden_units_, // k normed_decoder_output_buf_ + hidden_units_offset, - sizeof(T) == 2 ? CUDA_R_16F : CUDA_R_32F, + gemm_data_type, hidden_units_, // k &beta, logits_buf_ + vocab_size_units_offset, @@ -671,8 +890,8 @@ void GptJ::forward(std::unordered_map* output_tensors, else { FT_CHECK(vocab_size_padded_ % tensor_para_.world_size_ == 0); const int local_vocab_size = vocab_size_padded_ / tensor_para_.world_size_; - float alpha = 1.0f; - float beta = 0.0f; + float alpha = 1.0f; + float beta = 0.0f; cublas_wrapper_->Gemm(CUBLAS_OP_T, CUBLAS_OP_N, local_vocab_size, // n @@ -681,10 +900,10 @@ void GptJ::forward(std::unordered_map* output_tensors, &alpha, padded_embedding_kernel_ptr_ + tensor_para_.rank_ * local_vocab_size * hidden_units_, - sizeof(T) == 2 ? 
CUDA_R_16F : CUDA_R_32F, + gemm_data_type, hidden_units_, // k normed_decoder_output_buf_ + hidden_units_offset, - sizeof(T) == 2 ? CUDA_R_16F : CUDA_R_32F, + gemm_data_type, hidden_units_, // k &beta, nccl_logits_buf_ + vocab_size_units_offset @@ -699,7 +918,6 @@ void GptJ::forward(std::unordered_map* output_tensors, tensor_para_.rank_, tensor_para_, stream_); - check_cuda_error(cudaStreamSynchronize(stream_)); invokeTransposeAxis01(logits_buf_ + vocab_size_units_offset, nccl_logits_buf_ + vocab_size_units_offset, tensor_para_.world_size_, @@ -714,8 +932,8 @@ void GptJ::forward(std::unordered_map* output_tensors, vocab_size_padded_, stream_); - int tmp_local_batch_size = local_batch_size; - bool is_initialize_random_table = step == max_input_length; + int tmp_local_batch_size = local_batch_size; + bool is_initialize_random_table = step == max_input_length; std::unordered_map dynamic_decode_input_tensors{ {"logits", Tensor{MEMORY_GPU, TYPE_FP32, {batch_size, beam_width, vocab_size_padded_}, logits_buf_}}, @@ -724,8 +942,8 @@ void GptJ::forward(std::unordered_map* output_tensors, {"max_input_length", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &max_input_length}}, {"input_lengths", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size, beam_width}, tiled_input_lengths_buf_}}, + {"sequence_limit_length", Tensor{MEMORY_GPU, TYPE_UINT32, {batch_size}, seq_limit_len_}}, {"ite", Tensor{MEMORY_CPU, TYPE_UINT32, {1}, &ite}}, - {"has_diff_runtime_args", Tensor{MEMORY_CPU, TYPE_BOOL, {1}, &has_diff_runtime_args}}, {"src_key_cache", Tensor{MEMORY_GPU, data_type, self_k_cache_shape, key_cache_}}, {"src_value_cache", Tensor{MEMORY_GPU, data_type, self_v_cache_shape, value_cache_}}, {"src_cache_indirection", @@ -744,6 +962,7 @@ void GptJ::forward(std::unordered_map* output_tensors, } // common outputs + bool subbatch_should_stop = false; std::unordered_map dynamic_decode_output_tensors{ {"output_ids", Tensor{MEMORY_GPU, TYPE_INT32, {max_seq_len, batch_size, beam_width}, output_ids_buf_}}, @@ -765,12 +984,13 @@ void GptJ::forward(std::unordered_map* output_tensors, nullptr}}, {"parent_ids", Tensor{MEMORY_GPU, TYPE_INT32, {max_seq_len, batch_size, beam_width}, parent_ids_buf_}}, - {"sequence_length", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size * beam_width}, sequence_lengths}}, + {"sequence_length", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size * beam_width}, sequence_lengths_}}, {"tgt_cache_indirection", Tensor{MEMORY_GPU, TYPE_INT32, {local_batch_size, beam_width, max_output_seq_len}, - cache_indirections_[tgt_indir_idx] + id_offset * max_output_seq_len}}}; + cache_indirections_[tgt_indir_idx] + id_offset * max_output_seq_len}}, + {"should_stop", Tensor{MEMORY_CPU, TYPE_BOOL, {1}, &subbatch_should_stop}}}; for (auto t = output_tensors->begin(); t != output_tensors->end(); ++t) { // Handle exceptions. 
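// [Editor's sketch, not part of the patch] In the tensor-parallel branch above, each rank
// projects onto its own vocab shard (local_vocab_size = vocab_size_padded_ / tp_size), the
// per-rank shards are gathered rank-major into nccl_logits_buf_, and invokeTransposeAxis01
// rearranges the [tp_size, batch*beam, local_vocab] gather result into contiguous
// [batch*beam, vocab_size_padded_] logits. A CPU model of that final rearrangement:
#include <cstddef>
#include <vector>

std::vector<float> assembleShardedLogits(const std::vector<float>& gathered,  // [tp, rows, local_vocab]
                                         std::size_t tp_size, std::size_t rows, std::size_t local_vocab)
{
    std::vector<float> logits(rows * tp_size * local_vocab);
    for (std::size_t rank = 0; rank < tp_size; ++rank) {
        for (std::size_t row = 0; row < rows; ++row) {
            for (std::size_t v = 0; v < local_vocab; ++v) {
                // out[row][rank][v] = in[rank][row][v]  (swap of axes 0 and 1)
                logits[(row * tp_size + rank) * local_vocab + v] =
                    gathered[(rank * rows + row) * local_vocab + v];
            }
        }
    }
    return logits;
}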
@@ -781,11 +1001,12 @@ void GptJ::forward(std::unordered_map* output_tensors, } dynamic_decode_layer_->forward(&dynamic_decode_output_tensors, &dynamic_decode_input_tensors); + *generation_should_stop_ &= subbatch_should_stop; } } if (pipeline_para_.world_size_ > 1) { - NCCLCHECK(ncclGroupStart()); + ftNcclGroupStart(); ftNcclBroadCast(output_ids_buf_ + step * batch_size * beam_width, batch_size * beam_width, pipeline_para_.world_size_ - 1, @@ -793,10 +1014,9 @@ void GptJ::forward(std::unordered_map* output_tensors, stream_); ftNcclBroadCast( - sequence_lengths, batch_size * beam_width, pipeline_para_.world_size_ - 1, pipeline_para_, stream_); + sequence_lengths_, batch_size * beam_width, pipeline_para_.world_size_ - 1, pipeline_para_, stream_); - ftNcclBroadCast( - finished_buf_, batch_size * beam_width, pipeline_para_.world_size_ - 1, pipeline_para_, stream_); + ftNcclBroadCast(generation_should_stop_, 1, pipeline_para_.world_size_ - 1, pipeline_para_, stream_); if (beam_width > 1) { ftNcclBroadCast(cache_indirections_[tgt_indir_idx], @@ -805,167 +1025,164 @@ void GptJ::forward(std::unordered_map* output_tensors, pipeline_para_, stream_); } - NCCLCHECK(ncclGroupEnd()); - check_cuda_error(cudaStreamSynchronize(stream_)); + ftNcclGroupEnd(); + // throw errors when detected + ftNcclStreamSynchronize(tensor_para_, pipeline_para_, stream_); sync_check_cuda_error(); } - cudaD2Hcpy(h_finished_buf_, finished_buf_, batch_size * beam_width); - uint sum = 0; - for (uint i = 0; i < batch_size * beam_width; i++) { - if (has_per_item_requested_length) { - h_finished_buf_[i] |= - step >= reinterpret_cast(input_tensors->at("max_output_seq_len").data)[i / beam_width]; - } - sum += (int)h_finished_buf_[i]; - } - if (has_per_item_requested_length) { - cudaH2Dcpy(finished_buf_, h_finished_buf_, batch_size * beam_width); - } - if (sum == batch_size * beam_width) { + if (*generation_should_stop_) { break; } - } + if (token_generated_cb_ && step + 1 < (int)max_output_seq_len) { + setOutputTensors(output_tensors, input_tensors, max_output_seq_len); + sendTensorsToFirstPipelineNode(output_tensors, input_tensors); - if (pipeline_para_.rank_ == pipeline_para_.world_size_ - 1) { - if (input_tensors->at("input_ids").shape[1] == 0) { - if (beam_width > 1) { - // For beam search, do gather_tree - // take output_parent_ids as inter buffer - invokeGatherTree(transposed_output_ids_buf_, - sequence_lengths, - max_output_seq_len, - batch_size, - beam_width, - output_ids_buf_ + batch_size * beam_width, - parent_ids_buf_ + batch_size * beam_width, - end_ids_buf_, - stream_); - - // transpose and take output_parent_ids as inter buffer - invokeTransposeAxis01((int*)output_tensors->at("output_ids").data, - transposed_output_ids_buf_, - max_output_seq_len - 1, - batch_size * beam_width, - 1, - stream_); - } - else { - // For sampling, only copy the results to output_tensor - invokeTransposeAxis01((int*)output_tensors->at("output_ids").data, - output_ids_buf_ + batch_size * beam_width, - max_output_seq_len - 1, - batch_size * beam_width, - 1, - stream_); + if (pipeline_para_.rank_ == 0 && tensor_para_.rank_ == 0) { + token_generated_cb_(output_tensors, token_generated_ctx_); } } - else { - // add sequence_length 1 here because the sequence_length of time step t is t - 1 - invokePlusScalar(sequence_lengths, 1, batch_size * beam_width, stream_); - - // For sampling, it is equivalent to all parent ids are 0. 
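// [Editor's sketch, not part of the patch] The callback path above streams partial results:
// after each generated token (except the final one) the output tensors are gathered to the
// first pipeline rank and token_generated_cb_ fires on rank (0, 0), while the pinned
// generation_should_stop_ flag (ANDed over sub-batches and broadcast across pipeline ranks)
// ends the loop early once every sequence has finished. Judging from the call site,
// callback_sig is assumed to be void(std::unordered_map<std::string, Tensor>*, void*).
// A hypothetical registration:
#include <cstdio>
#include <string>
#include <unordered_map>
#include "src/fastertransformer/models/gptj/GptJ.h"  // assumed to bring in Tensor and GptJ<T>

using namespace fastertransformer;

struct StreamState {
    int steps_seen = 0;
};

void onTokenGenerated(std::unordered_map<std::string, Tensor>* output_tensors, void* ctx)
{
    auto* state = static_cast<StreamState*>(ctx);
    ++state->steps_seen;
    // output_tensors->at("output_ids") holds everything generated so far; a real consumer
    // could detokenize and flush only the newest column here.
    std::printf("streamed step %d\n", state->steps_seen);
}

// Usage (assuming a constructed GptJ<half> named gptj and prepared tensor maps):
//   StreamState state;
//   gptj.registerCallback(&onTokenGenerated, &state);
//   gptj.forward(&output_tensors, &input_tensors, &gpt_weights);
//   gptj.unRegisterCallback();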
- gatherTreeParam param; - param.beams = transposed_output_ids_buf_; - param.max_sequence_lengths = sequence_lengths; - param.max_time = max_output_seq_len; - param.batch_size = batch_size; - param.beam_width = beam_width; - param.step_ids = output_ids_buf_; - param.parent_ids = beam_width == 1 ? nullptr : parent_ids_buf_; - param.end_tokens = end_ids_buf_; - param.max_input_length = max_input_length; - param.prefix_soft_prompt_lengths = input_tensors->count("prefix_soft_prompt_lengths") ? - input_tensors->at("prefix_soft_prompt_lengths").getPtr() : - nullptr; - param.input_lengths = tiled_input_lengths_buf_; - param.max_prefix_soft_prompt_length = max_prefix_soft_prompt_length; - param.stream = stream_; - param.output_ids = (int*)output_tensors->at("output_ids").data; - invokeGatherTree(param); - sync_check_cuda_error(); - } - if ((output_tensors->count("output_log_probs") > 0 && output_tensors->at("output_log_probs").data != nullptr)) { - invokeTransposeAxis01(output_tensors->at("output_log_probs").getPtr(), - output_log_probs_buf_, - input_tensors->at("max_output_seq_len").getVal() - max_input_length, - batch_size * beam_width, - 1, - stream_); - } - // Return the cumulative log probability if requested. - if (output_tensors->count("cum_log_probs") > 0) { - Tensor cum_log_probs = output_tensors->at("cum_log_probs"); - FT_CHECK_WITH_INFO(cum_log_probs.size() == batch_size * beam_width, - "The shape of cum_log_probs does not match with batch_size x beam_width."); - cudaD2Dcpy(cum_log_probs.getPtr(), cum_log_probs_, batch_size * beam_width); + if (step == max_input_length) { + /* We have just finished processing input: update the padding count: + * total_padding_count += (max_input_length - input_lengths) + * if has prefix prompts, += (max_prefix_prompt_length - prompt_length) + */ + invokeUpdatePaddingCount(tiled_total_padding_count_, + (const int*)(input_tensors->at("input_lengths").data), // not_tiled + has_prefix_prompt_ ? tiled_prompt_lengths_buf_ : (const int*)nullptr, + max_input_length, + has_prefix_prompt_ ? 
max_prefix_prompt_length : 0, + batch_size, + beam_width, + stream_); } } + setOutputTensors(output_tensors, input_tensors, max_output_seq_len); + sendTensorsToFirstPipelineNode(output_tensors, input_tensors); +} - if (pipeline_para_.world_size_ > 1) { - NCCLCHECK(ncclGroupStart()); - if (pipeline_para_.rank_ == pipeline_para_.world_size_ - 1) { - ftNcclSend(output_tensors->at("output_ids").getPtr(), - batch_size * beam_width * max_output_seq_len, - 0, - pipeline_para_, - stream_); - - ftNcclSend(output_tensors->at("sequence_length").getPtr(), - batch_size * beam_width, - 0, - pipeline_para_, - stream_); +template +void GptJ::sendTensorsToFirstPipelineNode(std::unordered_map* output_tensors, + const std::unordered_map* input_tensors) +{ + if (pipeline_para_.world_size_ == 1) { + // throw errors when detected + ftNcclStreamSynchronize(tensor_para_, pipeline_para_, stream_); + return; + } - if (output_tensors->count("cum_log_probs") > 0 && output_tensors->at("cum_log_probs").data != nullptr) { - ftNcclSend(output_tensors->at("cum_log_probs").getPtr(), - batch_size * beam_width, - 0, - pipeline_para_, - stream_); - } + const auto pp_rank = pipeline_para_.rank_; - if (output_tensors->count("output_log_probs") > 0 - && output_tensors->at("output_log_probs").data != nullptr) { - ftNcclSend(output_tensors->at("output_log_probs").getPtr(), - output_tensors->at("output_log_probs").size(), - 0, - pipeline_para_, - stream_); - } + ftNcclGroupStart(); + for (auto const& it : *output_tensors) { + if (it.second.data == nullptr) { + continue; } - else if (pipeline_para_.rank_ == 0) { - ftNcclRecv(output_tensors->at("output_ids").getPtr(), - batch_size * beam_width * max_output_seq_len, - pipeline_para_.world_size_ - 1, - pipeline_para_, - stream_); - ftNcclRecv(output_tensors->at("sequence_length").getPtr(), - batch_size * beam_width, + if (pp_rank == pipeline_para_.world_size_ - 1) { + ftNcclSend(it.second.getPtr(), it.second.sizeBytes(), 0, pipeline_para_, stream_); + } + else if (pp_rank == 0) { + ftNcclRecv(it.second.getPtr(), + it.second.sizeBytes(), pipeline_para_.world_size_ - 1, pipeline_para_, stream_); + } + } + ftNcclGroupEnd(); + // throw errors when detected + ftNcclStreamSynchronize(tensor_para_, pipeline_para_, stream_); +} - if (output_tensors->count("cum_log_probs") > 0 && output_tensors->at("cum_log_probs").data != nullptr) { - ftNcclRecv(output_tensors->at("cum_log_probs").getPtr(), - batch_size * beam_width, - pipeline_para_.world_size_ - 1, - pipeline_para_, - stream_); - } +template +void GptJ::setOutputTensors(std::unordered_map* output_tensors, + const std::unordered_map* input_tensors, + const size_t max_output_seq_len) +{ + if (pipeline_para_.rank_ != pipeline_para_.world_size_ - 1) { + return; + } - if (output_tensors->count("output_log_probs") > 0 - && output_tensors->at("output_log_probs").data != nullptr) { - ftNcclRecv(output_tensors->at("output_log_probs").getPtr(), - output_tensors->at("output_log_probs").size(), - pipeline_para_.world_size_ - 1, - pipeline_para_, - stream_); - } + const size_t batch_size = output_tensors->at("output_ids").shape[0]; + const size_t beam_width = output_tensors->at("output_ids").shape[1]; + int* sequence_lengths = output_tensors->at("sequence_length").getPtr(); + const int max_input_length = input_tensors->at("input_ids").shape[1]; + const size_t max_prefix_soft_prompt_length = + has_prefix_soft_prompt_ ? 
input_tensors->at("request_prompt_embedding").shape[1] : 0; + + cudaAutoCpy(sequence_lengths, sequence_lengths_, output_tensors->at("sequence_length").size(), stream_); + if (input_tensors->at("input_ids").shape[1] == 0) { + // TODO: D2D sequence_lenghts + if (beam_width > 1) { + // For beam search, do gather_tree + // take output_parent_ids as inter buffer + invokeGatherTree(transposed_output_ids_buf_, + sequence_lengths_, + max_output_seq_len, + batch_size, + beam_width, + output_ids_buf_ + batch_size * beam_width, + parent_ids_buf_ + batch_size * beam_width, + end_ids_buf_, + stream_); + + // transpose and take output_parent_ids as inter buffer + invokeTransposeAxis01((int*)output_tensors->at("output_ids").data, + transposed_output_ids_buf_, + max_output_seq_len - 1, + batch_size * beam_width, + 1, + stream_); } - NCCLCHECK(ncclGroupEnd()); - check_cuda_error(cudaStreamSynchronize(stream_)); + else { + // For sampling, only copy the results to output_tensor + invokeTransposeAxis01((int*)output_tensors->at("output_ids").data, + output_ids_buf_ + batch_size * beam_width, + max_output_seq_len - 1, + batch_size * beam_width, + 1, + stream_); + } + } + else { + // add sequence_length 1 here because the sequence_length of time step t is t - 1 + invokePlusScalar(sequence_lengths, 1, batch_size * beam_width, stream_); + + // For sampling, it is equivalent to all parent ids are 0. + gatherTreeParam param; + param.beams = transposed_output_ids_buf_; + param.max_sequence_lengths = sequence_lengths; + param.max_time = max_output_seq_len; + param.batch_size = batch_size; + param.beam_width = beam_width; + param.step_ids = output_ids_buf_; + param.parent_ids = beam_width == 1 ? nullptr : parent_ids_buf_; + param.end_tokens = end_ids_buf_; + param.max_input_length = max_input_length; + param.prefix_soft_prompt_lengths = + has_prefix_soft_prompt_ ? input_tensors->at("request_prompt_lengths").getPtr() : nullptr; + param.input_lengths = tiled_input_lengths_buf_; + param.max_prefix_soft_prompt_length = max_prefix_soft_prompt_length; + param.stream = stream_; + param.output_ids = (int*)output_tensors->at("output_ids").data; + invokeGatherTree(param); + sync_check_cuda_error(); + } + if ((output_tensors->count("output_log_probs") > 0 && output_tensors->at("output_log_probs").data != nullptr)) { + invokeTransposeAxis01(output_tensors->at("output_log_probs").getPtr(), + output_log_probs_buf_, + input_tensors->at("output_seq_len").getVal() - max_input_length, + batch_size * beam_width, + 1, + stream_); + } + // Return the cumulative log probability if requested. 
+ if (output_tensors->count("cum_log_probs") > 0) { + Tensor cum_log_probs = output_tensors->at("cum_log_probs"); + FT_CHECK_WITH_INFO(cum_log_probs.size() == batch_size * beam_width, + "The shape of cum_log_probs does not match with batch_size x beam_width."); + cudaAutoCpy(cum_log_probs.getPtr(), cum_log_probs_, cum_log_probs.size(), stream_); } } @@ -1001,5 +1218,8 @@ bool* GptJ::getFinishBuffer() template class GptJ; template class GptJ; +#ifdef ENABLE_BF16 +template class GptJ<__nv_bfloat16>; +#endif } // namespace fastertransformer diff --git a/src/fastertransformer/models/gptj/GptJ.h b/src/fastertransformer/models/gptj/GptJ.h index 0b8c0eb9f..f8f720d2d 100644 --- a/src/fastertransformer/models/gptj/GptJ.h +++ b/src/fastertransformer/models/gptj/GptJ.h @@ -31,40 +31,50 @@ template class GptJ: public BaseLayer { private: // meta data - size_t head_num_; - size_t size_per_head_; - size_t inter_size_; - size_t num_layer_; - size_t vocab_size_; - size_t rotary_embedding_dim_; - - int start_id_; - int end_id_; + size_t head_num_; + size_t size_per_head_; + size_t inter_size_; + size_t num_layer_; + size_t vocab_size_; + size_t rotary_embedding_dim_; + const bool neox_rotary_style_ = false; // A unify way for GPT-NeoX in the future, not used now. + + static constexpr float layernorm_eps_ = 1e-6f; + + int start_id_; + int end_id_; size_t hidden_units_; - size_t local_head_num_; + size_t local_head_num_; NcclParam tensor_para_; NcclParam pipeline_para_; std::shared_ptr custom_all_reduce_comm_; - int enable_custom_all_reduce_; + int enable_custom_all_reduce_; - size_t vocab_size_padded_; + size_t vocab_size_padded_; const bool is_context_qk_buf_float_ = true; - GptJDecoder* gpt_decoder_; - GptJContextDecoder* gpt_context_decoder_; + // Prompt Learning Parameters + PromptLearningType prompt_learning_type_; + int prompt_learning_start_id_; // start_id for prompt_learning (only needed by prefix prompts) + bool has_prefix_soft_prompt_; + bool has_prefix_prompt_; + + GptJDecoder* gpt_decoder_; + GptJContextDecoder* gpt_context_decoder_; DynamicDecodeLayer* dynamic_decode_layer_; void allocateBuffer() override; - void allocateBuffer(size_t batch_size, size_t beam_width, size_t max_seq_len, size_t max_input_len); + void allocateBuffer( + size_t batch_size, size_t beam_width, size_t max_seq_len, size_t max_cache_seq_len, size_t max_input_len); void freeBuffer() override; void initialize(); protected: - T* padded_embedding_kernel_; - T* padded_embedding_bias_; + T* padded_embedding_kernel_; + T* padded_embedding_bias_; const T* padded_embedding_kernel_ptr_; const T* padded_embedding_bias_ptr_; @@ -80,97 +90,125 @@ class GptJ: public BaseLayer { bool* finished_buf_; bool* h_finished_buf_; + int* sequence_lengths_ = nullptr; - T* key_cache_; - T* value_cache_; + T* key_cache_; + T* value_cache_; int* cache_indirections_[2] = {nullptr, nullptr}; - int* tiled_input_ids_buf_; - int* tiled_input_lengths_buf_; - int* transposed_output_ids_buf_; - int* output_ids_buf_; - int* parent_ids_buf_; - int* start_ids_buf_; - int* end_ids_buf_; - - T* context_decoder_input_buf_; - T* context_decoder_output_buf_; + // prompt_learning weight_batch ptrs + const T** prompt_learning_weight_batch_; + int* tiled_prompt_lengths_buf_; // only needed by prefix prompts + + int* tiled_input_ids_buf_; + int* tiled_input_lengths_buf_; + int* transposed_output_ids_buf_; + int* output_ids_buf_; + int* parent_ids_buf_; + int* start_ids_buf_; + int* end_ids_buf_; + bool* masked_tokens_ = nullptr; + uint32_t* seq_limit_len_ = nullptr; + 
int* tiled_total_padding_count_ = nullptr; + + bool* generation_should_stop_ = nullptr; + + T* context_decoder_input_buf_; + T* context_decoder_output_buf_; float* output_log_probs_buf_; + // function pointer callback + using callback_sig = void(std::unordered_map*, void*); + callback_sig* token_generated_cb_ = nullptr; + void* token_generated_ctx_ = nullptr; + + void setOutputTensors(std::unordered_map* output_tensors, + const std::unordered_map* input_tensors, + size_t max_seq_len); + void sendTensorsToFirstPipelineNode(std::unordered_map* output_tensors, + const std::unordered_map* input_tensors); + public: - GptJ(size_t max_batch_size, - size_t max_seq_len, - size_t max_input_len, - size_t beam_width, - size_t head_num, - size_t size_per_head, - size_t inter_size, - size_t num_layer, - size_t vocab_size, - size_t rotary_embedding_dim, - int start_id, - int end_id, - float beam_search_diversity_rate, - size_t top_k, - float top_p, - unsigned long long random_seed, - float temperature, - float len_penalty, - float repetition_penalty, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - cudaDeviceProp* cuda_device_prop, - std::shared_ptr custom_all_reduce_comm = nullptr, - int enable_custom_all_reduce = 0); - - GptJ(size_t max_batch_size, - size_t max_seq_len, - size_t max_input_len, - size_t beam_width, - size_t head_num, - size_t size_per_head, - size_t inter_size, - size_t num_layer, - size_t vocab_size, - size_t rotary_embedding_dim, - int start_id, - int end_id, - float beam_search_diversity_rate, - size_t top_k, - float top_p, - unsigned long long random_seed, - float temperature, - float len_penalty, - float repetition_penalty, - NcclParam tensor_para, - NcclParam pipeline_para, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - cudaDeviceProp* cuda_device_prop = nullptr, - std::shared_ptr custom_all_reduce_comm = nullptr, - int enable_custom_all_reduce = 0); + GptJ(size_t max_batch_size, + size_t max_seq_len, + size_t max_input_len, + size_t beam_width, + size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + size_t vocab_size, + size_t rotary_embedding_dim, + int start_id, + int end_id, + int prompt_learning_start_id, // only needed by p/prompt-tuning + PromptLearningType prompt_learning_type, + float beam_search_diversity_rate, + size_t top_k, + float top_p, + unsigned long long random_seed, + float temperature, + float len_penalty, + float repetition_penalty, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + cudaDeviceProp* cuda_device_prop, + std::shared_ptr custom_all_reduce_comm = nullptr, + int enable_custom_all_reduce = 0); + + GptJ(size_t max_batch_size, + size_t max_seq_len, + size_t max_input_len, + size_t beam_width, + size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + size_t vocab_size, + size_t rotary_embedding_dim, + int start_id, + int end_id, + int prompt_learning_start_id, // only needed by p/prompt-tuning + PromptLearningType prompt_learning_type, + float beam_search_diversity_rate, + size_t top_k, + float top_p, + unsigned long long random_seed, + float temperature, + float len_penalty, + float repetition_penalty, + NcclParam tensor_para, + NcclParam pipeline_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + 
cudaDeviceProp* cuda_device_prop = nullptr, + std::shared_ptr custom_all_reduce_comm = nullptr, + int enable_custom_all_reduce = 0); GptJ(GptJ const& GptJ); ~GptJ(); - void forward(std::vector* output_tensors, + void forward(std::vector* output_tensors, const std::vector* input_tensors, - const GptJWeight* gpt_weights); + const GptJWeight* gpt_weights); - void forward(std::unordered_map* output_tensors, + void forward(std::unordered_map* output_tensors, const std::unordered_map* input_tensors, - const GptJWeight* gpt_weights); + const GptJWeight* gpt_weights); size_t getPipelineParallelRank(); size_t getPipelineParallelSize(); size_t getTensorParallelRank(); size_t getTensorParallelSize(); - bool* getFinishBuffer(); + bool* getFinishBuffer(); + + void registerCallback(callback_sig* fn, void* ctx); + void unRegisterCallback(); }; } // namespace fastertransformer diff --git a/src/fastertransformer/models/gptj/GptJContextDecoder.cc b/src/fastertransformer/models/gptj/GptJContextDecoder.cc index b5e226861..6cbc61635 100644 --- a/src/fastertransformer/models/gptj/GptJContextDecoder.cc +++ b/src/fastertransformer/models/gptj/GptJContextDecoder.cc @@ -30,10 +30,12 @@ void GptJContextDecoder::initialize() head_num_, size_per_head_, rotary_embedding_dim_, + neox_rotary_style_, tensor_para_, stream_, cublas_wrapper_, allocator_, + true, is_free_buffer_after_forward_, is_qk_buf_float_, false, @@ -49,9 +51,11 @@ void GptJContextDecoder::initialize() stream_, cublas_wrapper_, allocator_, + true, is_free_buffer_after_forward_, false, 0, + false, // use_gated_activation = false; custom_all_reduce_comm_, enable_custom_all_reduce_); } @@ -59,54 +63,32 @@ void GptJContextDecoder::initialize() template void GptJContextDecoder::allocateBuffer() { - if (is_allocate_buffer_ == false) { - decoder_normed_input_ = - reinterpret_cast(allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false)); - self_attn_output_ = - reinterpret_cast(allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false)); - ffn_output_ = - reinterpret_cast(allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false)); - decoder_layer_output_ = - reinterpret_cast(allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false)); - is_allocate_buffer_ = true; - } -} - -template -void GptJContextDecoder::freeBuffer() -{ - if (is_allocate_buffer_ == true) { - allocator_->free(decoder_normed_input_); - allocator_->free(self_attn_output_); - allocator_->free(ffn_output_); - allocator_->free(decoder_layer_output_); - is_allocate_buffer_ = false; - } + FT_CHECK(false); } template -bool GptJContextDecoder::isValidBatchSize(size_t batch_size) +void GptJContextDecoder::allocateBuffer(size_t batch_size, size_t seq_len) { - if (batch_size <= max_batch_size_) { - return true; - } - else { - freeBuffer(); - max_batch_size_ = batch_size * 1.2; - return true; - } + decoder_normed_input_ = reinterpret_cast( + allocator_->reMalloc(decoder_normed_input_, sizeof(T) * batch_size * seq_len * hidden_units_, false)); + self_attn_output_ = reinterpret_cast( + allocator_->reMalloc(self_attn_output_, sizeof(T) * batch_size * seq_len * hidden_units_, false)); + ffn_output_ = reinterpret_cast( + allocator_->reMalloc(ffn_output_, sizeof(T) * batch_size * seq_len * hidden_units_, false)); + decoder_layer_output_ = reinterpret_cast( + allocator_->reMalloc(decoder_layer_output_, sizeof(T) * batch_size * seq_len * hidden_units_, false)); + is_allocate_buffer_ = true; } template 
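The allocateBuffer(batch_size, seq_len) overload above sizes the scratch buffers for the actual request rather than reserving max_batch_size_ * max_seq_len_ up front, and relies on the allocator's reMalloc to reuse or grow an existing allocation. A simplified stand-in for that grow-or-reuse behaviour, assuming plain CUDA calls (the real IAllocator additionally tracks streams and can zero memory on request):

#include <cuda_runtime.h>
#include <cstddef>
#include <unordered_map>

// Hypothetical allocator sketch: reuse the old block if it is big enough,
// otherwise release it and allocate a larger one.
struct GrowOnlyAllocator {
    std::unordered_map<void*, size_t> sizes_;

    void* reMalloc(void* ptr, size_t bytes, bool set_zero)
    {
        if (ptr != nullptr && sizes_[ptr] >= bytes) {
            return ptr;  // current allocation is large enough: keep it
        }
        if (ptr != nullptr) {
            cudaFree(ptr);
            sizes_.erase(ptr);
        }
        void* new_ptr = nullptr;
        cudaMalloc(&new_ptr, bytes);
        if (set_zero) {
            cudaMemset(new_ptr, 0, bytes);
        }
        sizes_[new_ptr] = bytes;
        return new_ptr;
    }
};

freeBuffer passes the address of each pointer (free((void**)&ptr)), presumably so the allocator can reset the member to nullptr after releasing it and a later reMalloc starts from a clean state.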
-bool GptJContextDecoder::isValidSeqLen(size_t seq_len) +void GptJContextDecoder::freeBuffer() { - if (seq_len <= max_seq_len_) { - return true; - } - else { - freeBuffer(); - max_seq_len_ = seq_len * 1.2; - return true; + if (is_allocate_buffer_ == true) { + allocator_->free((void**)(&decoder_normed_input_)); + allocator_->free((void**)(&self_attn_output_)); + allocator_->free((void**)(&ffn_output_)); + allocator_->free((void**)(&decoder_layer_output_)); + is_allocate_buffer_ = false; } } @@ -140,22 +122,24 @@ int GptJContextDecoder::getFirstLayerParallelId() } template -GptJContextDecoder::GptJContextDecoder(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t inter_size, - size_t num_layer, - size_t rotary_embedding_dim, - NcclParam tensor_para, - NcclParam pipeline_para, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool is_qk_buf_float, +GptJContextDecoder::GptJContextDecoder(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + size_t rotary_embedding_dim, + bool neox_rotary_style, + float layernorm_eps, + NcclParam tensor_para, + NcclParam pipeline_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool is_qk_buf_float, std::shared_ptr custom_all_reduce_comm, - int enable_custom_all_reduce): + int enable_custom_all_reduce): BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward), max_batch_size_(max_batch_size), max_seq_len_(max_seq_len), @@ -164,6 +148,8 @@ GptJContextDecoder::GptJContextDecoder(size_t max_batch_size, inter_size_(inter_size), num_layer_(num_layer), rotary_embedding_dim_(rotary_embedding_dim), + neox_rotary_style_(neox_rotary_style), + layernorm_eps_(layernorm_eps), hidden_units_(head_num * size_per_head), tensor_para_(tensor_para), pipeline_para_(pipeline_para), @@ -184,6 +170,7 @@ GptJContextDecoder::GptJContextDecoder(GptJContextDecoder const& decoder): inter_size_(decoder.inter_size_), num_layer_(decoder.num_layer_), rotary_embedding_dim_(decoder.rotary_embedding_dim_), + layernorm_eps_(decoder.layernorm_eps_), hidden_units_(decoder.hidden_units_), tensor_para_(decoder.tensor_para_), pipeline_para_(decoder.pipeline_para_), @@ -203,8 +190,8 @@ GptJContextDecoder::~GptJContextDecoder() } template -void GptJContextDecoder::forward(std::vector* output_tensors, - const std::vector* input_tensors, +void GptJContextDecoder::forward(std::vector* output_tensors, + const std::vector* input_tensors, const std::vector>* gpt_decoder_layer_weight) { std::unordered_map input_tensors_map{{"decoder_input", input_tensors->at(0)}, @@ -219,14 +206,17 @@ void GptJContextDecoder::forward(std::vector* output_tensors, } template -void GptJContextDecoder::forward(std::unordered_map* output_tensors, +void GptJContextDecoder::forward(std::unordered_map* output_tensors, const std::unordered_map* input_tensors, - const std::vector>* gpt_decoder_layer_weight) + const std::vector>* gpt_decoder_layer_weight) { // input tensors: // decoder_input [batch_size, seq_len, hidden_dimension], - // attention_mask [batch_size, 1, seq_len, seq_len] + // attention_mask [batch_size, 1, seq_len, seq_len + max_prompt_length] // input_lengths [batch_size] + // d_prefix_prompt_batch [batch_size], + // each element contains ptr with buffer shape[2, local_head_num_, prompt_length, size_per_head] + // prefix_prompt_lengths 
[batch size] // output tensors: // decoder_output [batch_size, seq_len, hidden_dimension], @@ -238,26 +228,28 @@ void GptJContextDecoder::forward(std::unordered_map* out // For example, the shape of decoder_input becomes [ite, batch_size, seq_len, hidden_dimension] during // computing. - FT_CHECK(input_tensors->size() == 3); + FT_CHECK(input_tensors->size() == 5); FT_CHECK(output_tensors->size() == 4); - isValidBatchSize(input_tensors->at("decoder_input").shape[0]); - isValidSeqLen(input_tensors->at("decoder_input").shape[1]); - allocateBuffer(); const int batch_size = input_tensors->at("decoder_input").shape[0]; - const int seq_len = input_tensors->at("decoder_input").shape[1]; + const int seq_len = input_tensors->at("decoder_input").shape[1]; // max_input_len + const int max_prompt_length = + input_tensors->at("attention_mask").shape[3] - input_tensors->at("attention_mask").shape[2]; const DataType data_type = getTensorType(); + allocateBuffer(batch_size, seq_len); - T* decoder_input = (T*)input_tensors->at("decoder_input").data; - T* decoder_output = (T*)output_tensors->at("decoder_output").data; - const T* attention_mask = (const T*)input_tensors->at("attention_mask").data; + T* decoder_input = (T*)input_tensors->at("decoder_input").data; + T* decoder_output = (T*)output_tensors->at("decoder_output").data; + const T* attention_mask = (const T*)input_tensors->at("attention_mask").data; + const T** d_prefix_prompt_batch = (const T**)input_tensors->at("d_prefix_prompt_batch").data; + const int* d_prefix_prompt_lengths = (const int*)input_tensors->at("d_prefix_prompt_lengths").data; const int local_batch_size = getLocalBatchSize(batch_size, seq_len, pipeline_para_.world_size_); FT_CHECK(batch_size % local_batch_size == 0); const int iteration_num = batch_size / local_batch_size; - Tensor& k_cache = output_tensors->at("key_cache"); - Tensor& v_cache = output_tensors->at("value_cache"); + Tensor& k_cache = output_tensors->at("key_cache"); + Tensor& v_cache = output_tensors->at("value_cache"); std::vector self_k_cache_size; self_k_cache_size.push_back(local_batch_size); for (auto t = k_cache.shape.begin() + 2; t != k_cache.shape.end(); ++t) { @@ -276,7 +268,7 @@ void GptJContextDecoder::forward(std::unordered_map* out } const bool is_final = false; // TODO(bhsueh) remove this flag - T* layer_input = + T* layer_input = ((l == 0) ? decoder_input : decoder_layer_output_) + ite * local_batch_size * seq_len * hidden_units_; T* layer_output = ((l == num_layer_ - 1) ? 
decoder_output : decoder_layer_output_) + ite * local_batch_size * seq_len * hidden_units_; @@ -297,6 +289,7 @@ void GptJContextDecoder::forward(std::unordered_map* out layer_input, gpt_decoder_layer_weight->at(l).pre_layernorm_weights.gamma, gpt_decoder_layer_weight->at(l).pre_layernorm_weights.beta, + layernorm_eps_, local_batch_size * seq_len, hidden_units_, stream_); @@ -309,10 +302,20 @@ void GptJContextDecoder::forward(std::unordered_map* out decoder_normed_input_}, Tensor{MEMORY_GPU, data_type, - {(size_t)local_batch_size, (size_t)1, (size_t)seq_len, (size_t)seq_len}, - attention_mask + local_batch_size * ite * seq_len * seq_len}, - Tensor{MEMORY_CPU, TYPE_BOOL, {(size_t)1}, &is_final}}; + {(size_t)local_batch_size, (size_t)1, (size_t)seq_len, (size_t)(seq_len + max_prompt_length)}, + attention_mask + local_batch_size * ite * seq_len * (seq_len + max_prompt_length)}, + Tensor{MEMORY_CPU, TYPE_BOOL, {(size_t)1}, &is_final}, // NOTE: assume batch size = 1 + Tensor{MEMORY_GPU, + data_type, + {(size_t)local_batch_size}, + d_prefix_prompt_batch != nullptr ? d_prefix_prompt_batch + ite * local_batch_size : nullptr}, + Tensor{MEMORY_GPU, + TYPE_INT32, + {(size_t)local_batch_size}, + d_prefix_prompt_lengths != nullptr ? d_prefix_prompt_lengths + ite * local_batch_size : nullptr}, + Tensor{MEMORY_CPU, TYPE_INT32, {(size_t)1}, &l}}; // layer_id + // NOTE: cache offset for the specific layer size_t cache_offset = l - getFirstLayerParallelId(); for (auto t = k_cache.shape.begin() + 1; t != k_cache.shape.end(); ++t) { cache_offset *= *t; @@ -377,7 +380,7 @@ void GptJContextDecoder::forward(std::unordered_map* out batch_size, hidden_units_, stream_); - + sync_check_cuda_error(); if (is_free_buffer_after_forward_ == true) { freeBuffer(); } @@ -385,5 +388,7 @@ void GptJContextDecoder::forward(std::unordered_map* out template class GptJContextDecoder; template class GptJContextDecoder; - +#ifdef ENABLE_BF16 +template class GptJContextDecoder<__nv_bfloat16>; +#endif } // namespace fastertransformer diff --git a/src/fastertransformer/models/gptj/GptJContextDecoder.h b/src/fastertransformer/models/gptj/GptJContextDecoder.h index 61a825792..e8a579586 100644 --- a/src/fastertransformer/models/gptj/GptJContextDecoder.h +++ b/src/fastertransformer/models/gptj/GptJContextDecoder.h @@ -37,7 +37,7 @@ class GptJContextDecoder: public BaseLayer { private: // buffer handling size_t max_batch_size_ = 0; - size_t max_seq_len_ = 0; + size_t max_seq_len_ = 0; // meta data size_t head_num_; @@ -45,6 +45,9 @@ class GptJContextDecoder: public BaseLayer { size_t inter_size_; size_t num_layer_; size_t rotary_embedding_dim_; + bool neox_rotary_style_; // A unified way for GPT-NeoX in the future, not used now.
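The neox_rotary_style_ flag above selects between two rotary position embedding layouts: GPT-J rotates adjacent element pairs (2i, 2i+1) inside rotary_embedding_dim, whereas the NeoX layout pairs element i with i + rotary_dim / 2. A minimal sketch of the GPT-J-style rotation on a single head vector, assuming the conventional base of 10000 (illustrative only; the production kernels fuse this into the attention ops):

#include <cmath>
#include <vector>

// Rotate each adjacent pair (2i, 2i+1) of the rotary section by a per-pair angle.
void applyRotaryGptJStyle(std::vector<float>& q, int rotary_dim, int position, float base = 10000.f)
{
    for (int i = 0; i + 1 < rotary_dim; i += 2) {
        const float inv_freq = std::pow(base, -static_cast<float>(i) / rotary_dim);
        const float angle    = position * inv_freq;
        const float c = std::cos(angle), s = std::sin(angle);
        const float x0 = q[i], x1 = q[i + 1];
        q[i]     = x0 * c - x1 * s;
        q[i + 1] = x0 * s + x1 * c;
    }
}

Keys receive the same rotation, which is what makes the attention score between positions m and n depend only on their distance m - n.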
+ + float layernorm_eps_; // calculated data size_t hidden_units_; @@ -53,61 +56,61 @@ class GptJContextDecoder: public BaseLayer { NcclParam pipeline_para_; std::shared_ptr custom_all_reduce_comm_; - int enable_custom_all_reduce_; + int enable_custom_all_reduce_; bool is_qk_buf_float_; BaseAttentionLayer* self_attention_layer_; - FfnLayer* ffn_layer_; + FfnLayer* ffn_layer_; void allocateBuffer() override; + void allocateBuffer(size_t batch_size, size_t seq_len); void freeBuffer() override; - bool isValidBatchSize(size_t batch_size); - bool isValidSeqLen(size_t seq_len); - bool isValidLayerParallelId(uint l); bool isFirstLayerParallelId(uint l); bool isLastLayerParallelId(uint l); - int getFirstLayerParallelId(); + int getFirstLayerParallelId(); void initialize(); protected: - T* decoder_normed_input_; - T* self_attn_output_; - T* ffn_output_; - T* decoder_layer_output_; + T* decoder_normed_input_ = nullptr; + T* self_attn_output_ = nullptr; + T* ffn_output_ = nullptr; + T* decoder_layer_output_ = nullptr; public: - GptJContextDecoder(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t inter_size, - size_t num_layer, - size_t rotary_embedding_dim, - NcclParam tensor_para, - NcclParam pipeline_para, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool is_qk_buf_float, - std::shared_ptr custom_all_reduce_comm = nullptr, - int enable_custom_all_reduce_ = 0); + GptJContextDecoder(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + size_t rotary_embedding_dim, + bool neox_rotary_style, + float layernorm_eps, + NcclParam tensor_para, + NcclParam pipeline_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool is_qk_buf_float, + std::shared_ptr custom_all_reduce_comm = nullptr, + int enable_custom_all_reduce_ = 0); GptJContextDecoder(GptJContextDecoder const& decoder); ~GptJContextDecoder(); - void forward(std::vector* output_tensors, - const std::vector* input_tensors, + void forward(std::vector* output_tensors, + const std::vector* input_tensors, const std::vector>* decoder_layer_weights); - void forward(std::unordered_map* output_tensors, + void forward(std::unordered_map* output_tensors, const std::unordered_map* input_tensors, - const std::vector>* gpt_decoder_layer_weight); + const std::vector>* gpt_decoder_layer_weight); }; } // namespace fastertransformer diff --git a/src/fastertransformer/models/gptj/GptJDecoder.cc b/src/fastertransformer/models/gptj/GptJDecoder.cc index 268d664dd..2aaf8945d 100644 --- a/src/fastertransformer/models/gptj/GptJDecoder.cc +++ b/src/fastertransformer/models/gptj/GptJDecoder.cc @@ -27,10 +27,12 @@ void GptJDecoder::initialize() head_num_, size_per_head_, rotary_embedding_dim_, + neox_rotary_style_, tensor_para_, stream_, cublas_wrapper_, allocator_, + true, is_free_buffer_after_forward_, false, 0, @@ -46,51 +48,44 @@ void GptJDecoder::initialize() stream_, cublas_wrapper_, allocator_, + true, is_free_buffer_after_forward_, false, 0, + false, // use_gated_activation = false; custom_all_reduce_comm_, enable_custom_all_reduce_); - allocateBuffer(); } template void GptJDecoder::allocateBuffer() { - if (is_allocate_buffer_ == false) { - decoder_normed_input_ = - reinterpret_cast(allocator_->malloc(sizeof(T) * max_batch_size_ * hidden_units_, false)); - self_attn_output_ = - 
reinterpret_cast(allocator_->malloc(sizeof(T) * max_batch_size_ * hidden_units_, false)); - ffn_output_ = reinterpret_cast(allocator_->malloc(sizeof(T) * max_batch_size_ * hidden_units_, false)); - decoder_layer_output_ = - reinterpret_cast(allocator_->malloc(sizeof(T) * max_batch_size_ * hidden_units_, false)); - is_allocate_buffer_ = true; - } + FT_CHECK(false); } template -void GptJDecoder::freeBuffer() +void GptJDecoder::allocateBuffer(size_t batch_size) { - if (is_allocate_buffer_ == true) { - allocator_->free(decoder_normed_input_); - allocator_->free(self_attn_output_); - allocator_->free(ffn_output_); - allocator_->free(decoder_layer_output_); - is_allocate_buffer_ = false; - } + decoder_normed_input_ = reinterpret_cast( + allocator_->reMalloc(decoder_normed_input_, sizeof(T) * batch_size * hidden_units_, false)); + self_attn_output_ = + reinterpret_cast(allocator_->reMalloc(self_attn_output_, sizeof(T) * batch_size * hidden_units_, false)); + ffn_output_ = + reinterpret_cast(allocator_->reMalloc(ffn_output_, sizeof(T) * batch_size * hidden_units_, false)); + decoder_layer_output_ = reinterpret_cast( + allocator_->reMalloc(decoder_layer_output_, sizeof(T) * batch_size * hidden_units_, false)); + is_allocate_buffer_ = true; } template -bool GptJDecoder::isValidBatchSize(size_t batch_size) +void GptJDecoder::freeBuffer() { - if (batch_size <= max_batch_size_) { - return true; - } - else { - freeBuffer(); - max_batch_size_ = batch_size * 1.2; - return true; + if (is_allocate_buffer_ == true) { + allocator_->free((void**)(&decoder_normed_input_)); + allocator_->free((void**)(&self_attn_output_)); + allocator_->free((void**)(&ffn_output_)); + allocator_->free((void**)(&decoder_layer_output_)); + is_allocate_buffer_ = false; } } @@ -124,20 +119,22 @@ int GptJDecoder::getFirstLayerParallelId() } template -GptJDecoder::GptJDecoder(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - size_t inter_size, - size_t num_layer, - size_t rotary_embedding_dim, - NcclParam tensor_para, - NcclParam pipeline_para, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, +GptJDecoder::GptJDecoder(size_t max_batch_size, + size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + size_t rotary_embedding_dim, + bool neox_rotary_style, + float layernorm_eps, + NcclParam tensor_para, + NcclParam pipeline_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, std::shared_ptr custom_all_reduce_comm, - int enable_custom_all_reduce): + int enable_custom_all_reduce): BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward), max_batch_size_(max_batch_size), head_num_(head_num), @@ -145,6 +142,8 @@ GptJDecoder::GptJDecoder(size_t max_batch_size, inter_size_(inter_size), num_layer_(num_layer), rotary_embedding_dim_(rotary_embedding_dim), + neox_rotary_style_(neox_rotary_style), + layernorm_eps_(layernorm_eps), hidden_units_(head_num_ * size_per_head), tensor_para_(tensor_para), pipeline_para_(pipeline_para), @@ -163,6 +162,7 @@ GptJDecoder::GptJDecoder(GptJDecoder const& decoder): inter_size_(decoder.inter_size_), num_layer_(decoder.num_layer_), rotary_embedding_dim_(decoder.rotary_embedding_dim_), + layernorm_eps_(decoder.layernorm_eps_), hidden_units_(decoder.hidden_units_), tensor_para_(decoder.tensor_para_), pipeline_para_(decoder.pipeline_para_), @@ -181,61 +181,51 @@ GptJDecoder::~GptJDecoder() } template -void 
GptJDecoder::forward(std::vector* output_tensors, - const std::vector* input_tensors, +void GptJDecoder::forward(std::vector* output_tensors, + const std::vector* input_tensors, const std::vector>* gpt_decoder_layer_weight) { - std::unordered_map input_tensors_map{{"decoder_input", input_tensors->at(0)}, - {"finished", input_tensors->at(1)}, - {"sequence_lengths", input_tensors->at(2)}, - {"input_lengths", input_tensors->at(3)}, - {"max_input_length", input_tensors->at(4)}, - {"step", input_tensors->at(5)}, - {"ite", input_tensors->at(6)}}; - std::unordered_map output_tensors_map{ - {"decoder_output", output_tensors->at(0)}, - {"key_cache", output_tensors->at(1)}, - {"value_cache", output_tensors->at(2)}, - }; - - forward(&output_tensors_map, &input_tensors_map, gpt_decoder_layer_weight); + FT_CHECK(false); } + template -void GptJDecoder::forward(std::unordered_map* output_tensors, +void GptJDecoder::forward(std::unordered_map* output_tensors, const std::unordered_map* input_tensors, - const std::vector>* gpt_decoder_layer_weight) + const std::vector>* gpt_decoder_layer_weight) { // input tensors: // decoder_input [local_batch_size, hidden_dimension], // finished [local_batch_size], // sequence_lengths [local_batch_size] - // input_lengths [local_batch_size], + // total_padding_tokens [local_batch_size], + // d_prefix_prompt_lengths [local_batch_size],on GPU + // max_prefix_prompt_length [1] on cpu // max_input_length [1] on cpu // step [1] on cpu // ite [1] on cpu - // cache_indirection [local_batch_size / beam_width, beam_width, max_seq_len] + // cache_indirection [local_batch_size / beam_width, beam_width, memory_len] // Here, local_batch_size contains the beam_width, so local_batch_size / beam_width // is real local_batch_size. + // masked_tokens[local_batch_size, memory_len] // output tensors: // decoder_output [local_batch_size, hidden_dimension], - // key_cache [num_layer, batch_size, head_num, size_per_head // x, max_seq_len, x] - // value_cache [num_layer, batch_size, head_num, max_seq_len, size_per_head] + // key_cache [num_layer, batch_size, head_num, size_per_head // x, memory_len, x] + // value_cache [num_layer, batch_size, head_num, memory_len, size_per_head] - FT_CHECK(input_tensors->size() == 8); + FT_CHECK(input_tensors->size() == 11); FT_CHECK(output_tensors->size() == 3); - isValidBatchSize(input_tensors->at("decoder_input").shape[0]); - allocateBuffer(); - const DataType data_type = getTensorType(); - const size_t local_batch_size = input_tensors->at("decoder_input").shape[0]; - const int ite = *((int*)(input_tensors->at("ite").data)); + const DataType data_type = getTensorType(); + const size_t local_batch_size = input_tensors->at("decoder_input").shape[0]; + const int ite = *((int*)(input_tensors->at("ite").data)); + allocateBuffer(local_batch_size); - T* decoder_input = (T*)input_tensors->at("decoder_input").data; + T* decoder_input = (T*)input_tensors->at("decoder_input").data; T* decoder_output = (T*)output_tensors->at("decoder_output").data; - Tensor& k_cache = output_tensors->at("key_cache"); - Tensor& v_cache = output_tensors->at("value_cache"); + Tensor& k_cache = output_tensors->at("key_cache"); + Tensor& v_cache = output_tensors->at("value_cache"); std::vector self_k_cache_size; self_k_cache_size.push_back(local_batch_size); for (auto t = k_cache.shape.begin() + 2; t != k_cache.shape.end(); ++t) { @@ -251,7 +241,7 @@ void GptJDecoder::forward(std::unordered_map* output_ten if (isValidLayerParallelId(l) == false) { continue; } - T* layer_input = (l == 0) ? 
decoder_input : decoder_layer_output_; + T* layer_input = (l == 0) ? decoder_input : decoder_layer_output_; T* layer_output = (l == num_layer_ - 1) ? decoder_output : decoder_layer_output_; if (isFirstLayerParallelId(l) == true && pipeline_para_.rank_ != 0 && pipeline_para_.world_size_ > 1) { @@ -273,6 +263,7 @@ void GptJDecoder::forward(std::unordered_map* output_ten layer_input, gpt_decoder_layer_weight->at(l).pre_layernorm_weights.gamma, gpt_decoder_layer_weight->at(l).pre_layernorm_weights.beta, + layernorm_eps_, local_batch_size, hidden_units_, stream_); @@ -282,10 +273,13 @@ void GptJDecoder::forward(std::unordered_map* output_ten Tensor{MEMORY_GPU, data_type, {local_batch_size, hidden_units_}, decoder_normed_input_}, input_tensors->at("finished"), input_tensors->at("sequence_lengths"), - input_tensors->at("input_lengths"), + input_tensors->at("total_padding_tokens"), + input_tensors->at("d_prefix_prompt_lengths"), + input_tensors->at("max_prefix_prompt_length"), input_tensors->at("max_input_length"), input_tensors->at("step"), - input_tensors->at("cache_indirection")}; + input_tensors->at("cache_indirection"), + input_tensors->at("masked_tokens")}; size_t cache_offset = l - getFirstLayerParallelId(); for (auto t = k_cache.shape.begin() + 1; t != k_cache.shape.end(); ++t) { @@ -343,5 +337,8 @@ void GptJDecoder::forward(std::unordered_map* output_ten template class GptJDecoder; template class GptJDecoder; +#ifdef ENABLE_BF16 +template class GptJDecoder<__nv_bfloat16>; +#endif } // namespace fastertransformer diff --git a/src/fastertransformer/models/gptj/GptJDecoder.h b/src/fastertransformer/models/gptj/GptJDecoder.h index 01c551606..63c5889b3 100644 --- a/src/fastertransformer/models/gptj/GptJDecoder.h +++ b/src/fastertransformer/models/gptj/GptJDecoder.h @@ -36,13 +36,13 @@ template class GptJDecoder: public BaseLayer { private: protected: - void allocateBuffer() override; - void freeBuffer() override; - bool isValidBatchSize(size_t batch_size); - bool isValidLayerParallelId(uint l); - bool isFirstLayerParallelId(uint l); - bool isLastLayerParallelId(uint l); - int getFirstLayerParallelId(); + void allocateBuffer() override; + void allocateBuffer(size_t batch_size); + void freeBuffer() override; + bool isValidLayerParallelId(uint l); + bool isFirstLayerParallelId(uint l); + bool isLastLayerParallelId(uint l); + int getFirstLayerParallelId(); virtual void initialize(); // buffer handling size_t max_batch_size_ = 0; @@ -53,48 +53,53 @@ class GptJDecoder: public BaseLayer { size_t inter_size_; size_t num_layer_; size_t rotary_embedding_dim_; + bool neox_rotary_style_; // A unify way for GPT-NeoX in the future, not used now. 
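Both the context decoder and the incremental decoder skip layers owned by other pipeline ranks (the isValidLayerParallelId / getFirstLayerParallelId checks above). Assuming num_layer divides evenly by the pipeline world size, the partitioning reduces to one contiguous block per rank, roughly:

#include <cstddef>

// Sketch of contiguous layer ownership under pipeline parallelism,
// assuming num_layer % pp_world_size == 0.
struct LayerPartition {
    size_t num_layer;
    size_t pp_world_size;
    size_t pp_rank;

    size_t firstLayer() const { return num_layer / pp_world_size * pp_rank; }
    size_t lastLayer() const { return firstLayer() + num_layer / pp_world_size; }  // exclusive
    bool   ownsLayer(size_t l) const { return l >= firstLayer() && l < lastLayer(); }
};

With 24 layers over 4 pipeline ranks, for example, rank 1 owns layers 6-11 and hands its activations to rank 2, which owns layers 12-17; the cache_offset computation above likewise starts counting from the first locally owned layer.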
size_t hidden_units_; + float layernorm_eps_; + NcclParam tensor_para_; NcclParam pipeline_para_; std::shared_ptr custom_all_reduce_comm_; - int enable_custom_all_reduce_; + int enable_custom_all_reduce_; - T* decoder_normed_input_; - T* self_attn_output_; - T* ffn_output_; - T* decoder_layer_output_; + T* decoder_normed_input_ = nullptr; + T* self_attn_output_ = nullptr; + T* ffn_output_ = nullptr; + T* decoder_layer_output_ = nullptr; BaseAttentionLayer* self_attention_layer_; - FfnLayer* ffn_layer_; + FfnLayer* ffn_layer_; public: - GptJDecoder(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - size_t inter_size, - size_t num_layer, - size_t rotary_embedding_dim, - NcclParam tensor_para, - NcclParam pipeline_para, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - std::shared_ptr custom_all_reduce_comm = nullptr, - int enable_custom_all_reduce_ = 0); + GptJDecoder(size_t max_batch_size, + size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + size_t rotary_embedding_dim, + bool neox_rotary_style, + float layernorm_eps, + NcclParam tensor_para, + NcclParam pipeline_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + std::shared_ptr custom_all_reduce_comm = nullptr, + int enable_custom_all_reduce_ = 0); GptJDecoder(GptJDecoder const& decoder); virtual ~GptJDecoder(); - virtual void forward(std::unordered_map* output_tensors, + virtual void forward(std::unordered_map* output_tensors, const std::unordered_map* input_tensors, - const std::vector>* decoder_layer_weights); + const std::vector>* decoder_layer_weights); - virtual void forward(std::vector* output_tensors, - const std::vector* input_tensors, + virtual void forward(std::vector* output_tensors, + const std::vector* input_tensors, const std::vector>* decoder_layer_weights); }; diff --git a/src/fastertransformer/models/gptj/GptJDecoderLayerWeight.cc b/src/fastertransformer/models/gptj/GptJDecoderLayerWeight.cc index cb3a3aad7..409f24df6 100644 --- a/src/fastertransformer/models/gptj/GptJDecoderLayerWeight.cc +++ b/src/fastertransformer/models/gptj/GptJDecoderLayerWeight.cc @@ -41,17 +41,17 @@ GptJDecoderLayerWeight::~GptJDecoderLayerWeight() deviceFree(weights_ptr[i]); } - pre_layernorm_weights.beta = nullptr; - pre_layernorm_weights.gamma = nullptr; - self_attention_weights.query_weight.kernel = nullptr; - self_attention_weights.query_weight.bias = nullptr; + pre_layernorm_weights.beta = nullptr; + pre_layernorm_weights.gamma = nullptr; + self_attention_weights.query_weight.kernel = nullptr; + self_attention_weights.query_weight.bias = nullptr; self_attention_weights.attention_output_weight.kernel = nullptr; ffn_weights.intermediate_weight.kernel = nullptr; - ffn_weights.intermediate_weight.bias = nullptr; - ffn_weights.output_weight.kernel = nullptr; - ffn_weights.output_weight.bias = nullptr; - is_maintain_buffer = false; + ffn_weights.intermediate_weight.bias = nullptr; + ffn_weights.output_weight.kernel = nullptr; + ffn_weights.output_weight.bias = nullptr; + is_maintain_buffer = false; } } @@ -81,8 +81,8 @@ GptJDecoderLayerWeight::GptJDecoderLayerWeight(const GptJDecoderLayerWeight& template GptJDecoderLayerWeight& GptJDecoderLayerWeight::operator=(const GptJDecoderLayerWeight& other) { - hidden_units_ = other.hidden_units_; - inter_size_ = other.inter_size_; + hidden_units_ = other.hidden_units_; + inter_size_ = other.inter_size_; 
tensor_para_size_ = other.tensor_para_size_; tensor_para_rank_ = other.tensor_para_rank_; @@ -109,48 +109,51 @@ void GptJDecoderLayerWeight::loadModel(std::string dir_path, FtCudaDataType m FT_CHECK(is_maintain_buffer == true); const std::string rank_spec = std::to_string(tensor_para_rank_); - loadWeightFromBin(weights_ptr[0], {hidden_units_}, dir_path + ".input_layernorm.bias.bin", model_file_type); - loadWeightFromBin(weights_ptr[1], {hidden_units_}, dir_path + ".input_layernorm.weight.bin", model_file_type); + loadWeightFromBin( + weights_ptr[0], {(size_t)hidden_units_}, dir_path + ".input_layernorm.bias.bin", model_file_type); + loadWeightFromBin( + weights_ptr[1], {(size_t)hidden_units_}, dir_path + ".input_layernorm.weight.bin", model_file_type); loadWeightFromBin(weights_ptr[2], - {hidden_units_, 3 * hidden_units_ / tensor_para_size_}, + {(size_t)hidden_units_, (size_t)(3 * hidden_units_ / tensor_para_size_)}, dir_path + ".attention.query_key_value.weight." + rank_spec + ".bin", model_file_type); // GPT-J does not have bias for QKV cudaMemset(weights_ptr[3], 0, sizeof(T) * 3 * hidden_units_ / tensor_para_size_); loadWeightFromBin(weights_ptr[4], - {hidden_units_ / tensor_para_size_, hidden_units_}, + {(size_t)(hidden_units_ / tensor_para_size_), (size_t)hidden_units_}, dir_path + ".attention.dense.weight." + rank_spec + ".bin", model_file_type); loadWeightFromBin(weights_ptr[5], - {hidden_units_, inter_size_ / tensor_para_size_}, + {(size_t)hidden_units_, (size_t)(inter_size_ / tensor_para_size_)}, dir_path + ".mlp.dense_h_to_4h.weight." + rank_spec + ".bin", model_file_type); loadWeightFromBin(weights_ptr[6], - {inter_size_ / tensor_para_size_}, + {(size_t)(inter_size_ / tensor_para_size_)}, dir_path + ".mlp.dense_h_to_4h.bias." + rank_spec + ".bin", model_file_type); loadWeightFromBin(weights_ptr[7], - {inter_size_ / tensor_para_size_, hidden_units_}, + {(size_t)(inter_size_ / tensor_para_size_), (size_t)hidden_units_}, dir_path + ".mlp.dense_4h_to_h.weight." 
+ rank_spec + ".bin", model_file_type); - loadWeightFromBin(weights_ptr[8], {hidden_units_}, dir_path + ".mlp.dense_4h_to_h.bias.bin", model_file_type); + loadWeightFromBin( + weights_ptr[8], {(size_t)hidden_units_}, dir_path + ".mlp.dense_4h_to_h.bias.bin", model_file_type); } template void GptJDecoderLayerWeight::setWeightPtr() { - pre_layernorm_weights.beta = weights_ptr[0]; - pre_layernorm_weights.gamma = weights_ptr[1]; - self_attention_weights.query_weight.kernel = weights_ptr[2]; - self_attention_weights.query_weight.bias = weights_ptr[3]; + pre_layernorm_weights.beta = weights_ptr[0]; + pre_layernorm_weights.gamma = weights_ptr[1]; + self_attention_weights.query_weight.kernel = weights_ptr[2]; + self_attention_weights.query_weight.bias = weights_ptr[3]; self_attention_weights.attention_output_weight.kernel = weights_ptr[4]; ffn_weights.intermediate_weight.kernel = weights_ptr[5]; - ffn_weights.intermediate_weight.bias = weights_ptr[6]; - ffn_weights.output_weight.kernel = weights_ptr[7]; - ffn_weights.output_weight.bias = weights_ptr[8]; + ffn_weights.intermediate_weight.bias = weights_ptr[6]; + ffn_weights.output_weight.kernel = weights_ptr[7]; + ffn_weights.output_weight.bias = weights_ptr[8]; is_maintain_buffer = true; } @@ -172,5 +175,8 @@ void GptJDecoderLayerWeight::mallocWeights() template struct GptJDecoderLayerWeight; template struct GptJDecoderLayerWeight; +#ifdef ENABLE_BF16 +template struct GptJDecoderLayerWeight<__nv_bfloat16>; +#endif } // namespace fastertransformer diff --git a/src/fastertransformer/models/gptj/GptJDecoderLayerWeight.h b/src/fastertransformer/models/gptj/GptJDecoderLayerWeight.h index e20d2bdaa..713436dd1 100644 --- a/src/fastertransformer/models/gptj/GptJDecoderLayerWeight.h +++ b/src/fastertransformer/models/gptj/GptJDecoderLayerWeight.h @@ -41,15 +41,15 @@ struct GptJDecoderLayerWeight { LayerNormWeight pre_layernorm_weights; AttentionWeight self_attention_weights; - FfnWeight ffn_weights; + FfnWeight ffn_weights; private: - int hidden_units_; - int inter_size_; - int tensor_para_size_; - int tensor_para_rank_; + int hidden_units_; + int inter_size_; + int tensor_para_size_; + int tensor_para_rank_; bool is_maintain_buffer = false; - T* weights_ptr[9]; + T* weights_ptr[9]; void setWeightPtr(); void mallocWeights(); diff --git a/src/fastertransformer/models/gptj/GptJWeight.cc b/src/fastertransformer/models/gptj/GptJWeight.cc index ce8441db2..c2fb6898f 100644 --- a/src/fastertransformer/models/gptj/GptJWeight.cc +++ b/src/fastertransformer/models/gptj/GptJWeight.cc @@ -19,15 +19,17 @@ namespace fastertransformer { template -GptJWeight::GptJWeight(const int hidden_units, - const int inter_size, - const int vocab_size, - const int num_layer, - const int max_seq_len, - const int tensor_para_size, - const int tensor_para_rank, - const int layer_para_size, - const int layer_para_rank): +GptJWeight::GptJWeight(const int hidden_units, + const int inter_size, + const int vocab_size, + const int num_layer, + const int max_seq_len, + const int tensor_para_size, + const int tensor_para_rank, + const int layer_para_size, + const int layer_para_rank, + PromptLearningType prompt_learning_type, + std::map> prompt_learning_pair): hidden_units_(hidden_units), inter_size_(inter_size), vocab_size_(vocab_size), @@ -36,8 +38,24 @@ GptJWeight::GptJWeight(const int hidden_units, tensor_para_size_(tensor_para_size), tensor_para_rank_(tensor_para_rank), layer_para_size_(layer_para_size), - layer_para_rank_(layer_para_rank) + layer_para_rank_(layer_para_rank), + 
prompt_learning_type_(prompt_learning_type), + prompt_learning_pair_(prompt_learning_pair) { + FT_CHECK(num_layer_ % layer_para_size_ == 0); + // set prompt weight size + if (prompt_learning_type_ == PromptLearningType::prefix_prompt) { + prompt_token_weight_size_ = 2 * num_layer_ * hidden_units_ / tensor_para_size_; + } + else if (prompt_learning_type_ == PromptLearningType::p_prompt_tuning) { + prompt_token_weight_size_ = hidden_units_; + } + + // set if load and malloc prompt weights + malloc_load_prompt_weights_ = !prompt_learning_pair_.empty() + && (prompt_learning_type_ == PromptLearningType::p_prompt_tuning + || prompt_learning_type_ == PromptLearningType::prefix_prompt); + decoder_layer_weights.reserve(num_layer_); for (int l = 0; l < num_layer_; l++) { if (isValidLayerParallelId(l)) { @@ -59,16 +77,16 @@ template GptJWeight::~GptJWeight() { if (is_maintain_buffer == true) { - for (int i = 0; i < 5; i++) { + for (int i = 0; i < weights_ptr.size(); i++) { deviceFree(weights_ptr[i]); } - pre_decoder_embedding_table = nullptr; - post_decoder_layernorm.beta = nullptr; - post_decoder_layernorm.gamma = nullptr; + pre_decoder_embedding_table = nullptr; + post_decoder_layernorm.beta = nullptr; + post_decoder_layernorm.gamma = nullptr; post_decoder_embedding.kernel = nullptr; - post_decoder_embedding.bias = nullptr; - is_maintain_buffer = false; + post_decoder_embedding.bias = nullptr; + is_maintain_buffer = false; } } @@ -82,14 +100,33 @@ GptJWeight::GptJWeight(const GptJWeight& other): tensor_para_size_(other.tensor_para_size_), tensor_para_rank_(other.tensor_para_rank_), layer_para_size_(other.layer_para_size_), - layer_para_rank_(other.layer_para_rank_) + layer_para_rank_(other.layer_para_rank_), + prompt_token_weight_size_(other.prompt_token_weight_size_), + malloc_load_prompt_weights_(other.malloc_load_prompt_weights_), + prompt_learning_type_(other.prompt_learning_type_), + prompt_learning_pair_(other.prompt_learning_pair_) { mallocWeights(); + cudaD2Dcpy(weights_ptr[0], other.weights_ptr[0], vocab_size_ * hidden_units_); cudaD2Dcpy(weights_ptr[1], other.weights_ptr[1], hidden_units_); cudaD2Dcpy(weights_ptr[2], other.weights_ptr[2], hidden_units_); cudaD2Dcpy(weights_ptr[3], other.weights_ptr[3], hidden_units_ * vocab_size_); cudaD2Dcpy(weights_ptr[4], other.weights_ptr[4], vocab_size_); + + // prompt learning table: malloc weights and set weight ptr + if (malloc_load_prompt_weights_) { + for (auto const& prompt : prompt_learning_pair_) { + std::string task_name = prompt.first; + int task_name_id = prompt.second.first; + int prompt_length = prompt.second.second; + size_t prompt_id = num_base_weights + (size_t)task_name_id; + + // cuda device to device memcpy prompt table weights buffer memory + cudaD2Dcpy(weights_ptr[prompt_id], other.weights_ptr[prompt_id], prompt_length * prompt_token_weight_size_); + } + } + setWeightPtr(); decoder_layer_weights.clear(); @@ -102,15 +139,19 @@ GptJWeight::GptJWeight(const GptJWeight& other): template GptJWeight& GptJWeight::operator=(const GptJWeight& other) { - hidden_units_ = other.hidden_units_; - inter_size_ = other.inter_size_; - vocab_size_ = other.vocab_size_; - num_layer_ = other.num_layer_; - max_seq_len_ = other.max_seq_len_; - tensor_para_size_ = other.tensor_para_size_; - tensor_para_rank_ = other.tensor_para_rank_; - layer_para_size_ = other.layer_para_size_; - layer_para_rank_ = other.layer_para_rank_; + hidden_units_ = other.hidden_units_; + inter_size_ = other.inter_size_; + vocab_size_ = other.vocab_size_; + num_layer_ = 
other.num_layer_; + max_seq_len_ = other.max_seq_len_; + tensor_para_size_ = other.tensor_para_size_; + tensor_para_rank_ = other.tensor_para_rank_; + layer_para_size_ = other.layer_para_size_; + layer_para_rank_ = other.layer_para_rank_; + prompt_token_weight_size_ = other.prompt_token_weight_size_; + malloc_load_prompt_weights_ = other.malloc_load_prompt_weights_; + prompt_learning_type_ = other.prompt_learning_type_; + prompt_learning_pair_ = other.prompt_learning_pair_; mallocWeights(); cudaD2Dcpy(weights_ptr[0], other.weights_ptr[0], vocab_size_ * hidden_units_); @@ -118,9 +159,23 @@ GptJWeight& GptJWeight::operator=(const GptJWeight& other) cudaD2Dcpy(weights_ptr[2], other.weights_ptr[2], hidden_units_); cudaD2Dcpy(weights_ptr[3], other.weights_ptr[3], hidden_units_ * vocab_size_); cudaD2Dcpy(weights_ptr[4], other.weights_ptr[4], vocab_size_); + + // prompt learning tables: malloc weights and set weight ptr + if (malloc_load_prompt_weights_) { + for (auto const& prompt : prompt_learning_pair_) { + int task_name_id = prompt.second.first; + int prompt_length = prompt.second.second; + size_t prompt_id = num_base_weights + (size_t)task_name_id; + + // cuda device to device memcpy prompt weights buffer memory + cudaD2Dcpy(weights_ptr[prompt_id], other.weights_ptr[prompt_id], prompt_length * prompt_token_weight_size_); + } + } + setWeightPtr(); decoder_layer_weights.clear(); + decoder_layer_weights.reserve(num_layer_); for (int l = 0; l < num_layer_; l++) { decoder_layer_weights.push_back(other.decoder_layer_weights[l]); } @@ -130,39 +185,93 @@ GptJWeight& GptJWeight::operator=(const GptJWeight& other) template void GptJWeight::setWeightPtr() { - pre_decoder_embedding_table = weights_ptr[0]; - post_decoder_layernorm.beta = weights_ptr[1]; - post_decoder_layernorm.gamma = weights_ptr[2]; + prompt_learning_table.resize(prompt_learning_pair_.size()); + + pre_decoder_embedding_table = weights_ptr[0]; + post_decoder_layernorm.beta = weights_ptr[1]; + post_decoder_layernorm.gamma = weights_ptr[2]; post_decoder_embedding.kernel = weights_ptr[3]; - post_decoder_embedding.bias = weights_ptr[4]; + post_decoder_embedding.bias = weights_ptr[4]; + + // prompt learning tables: set weight ptr + if (malloc_load_prompt_weights_) { + for (auto const& prompt : prompt_learning_pair_) { + int task_name_id = prompt.second.first; + int prompt_length = prompt.second.second; + size_t task_weight_id = num_base_weights + (size_t)task_name_id; + + // set weight ptr + prompt_learning_table[task_name_id] = {weights_ptr[task_weight_id], prompt_length}; + } + } } template void GptJWeight::mallocWeights() { + weights_ptr.resize(num_base_weights + prompt_learning_pair_.size()); + deviceMalloc(&weights_ptr[0], vocab_size_ * hidden_units_); deviceMalloc(&weights_ptr[1], hidden_units_); deviceMalloc(&weights_ptr[2], hidden_units_); deviceMalloc(&weights_ptr[3], hidden_units_ * vocab_size_); deviceMalloc(&weights_ptr[4], vocab_size_); + + // prompt learning tables: malloc weights + if (malloc_load_prompt_weights_) { + for (auto const& prompt : prompt_learning_pair_) { + int task_name_id = prompt.second.first; + int prompt_length = prompt.second.second; + size_t task_weight_id = num_base_weights + (size_t)task_name_id; + + // malloc weights + T* prompt_weights_ptr = nullptr; + deviceMalloc(&prompt_weights_ptr, prompt_length * prompt_token_weight_size_); + weights_ptr[task_weight_id] = prompt_weights_ptr; + } + } is_maintain_buffer = true; } template void GptJWeight::loadModel(std::string dir_path) { - // FtCudaDataType 
model_file_type = getModelFileType(dir_path + "/config.ini"); - FtCudaDataType model_file_type = FtCudaDataType::FP32; // only support FP32 now + FtCudaDataType model_file_type = getModelFileType(dir_path + "/config.ini", "gptj"); FT_CHECK(is_maintain_buffer == true); - loadWeightFromBin(weights_ptr[0], {vocab_size_ * hidden_units_}, dir_path + "/model.wte.bin", model_file_type); loadWeightFromBin( - weights_ptr[1], {hidden_units_}, dir_path + "/model.final_layernorm.bias.bin", model_file_type); + weights_ptr[0], {(size_t)(vocab_size_ * hidden_units_)}, dir_path + "/model.wte.bin", model_file_type); loadWeightFromBin( - weights_ptr[2], {hidden_units_}, dir_path + "/model.final_layernorm.weight.bin", model_file_type); + weights_ptr[1], {(size_t)hidden_units_}, dir_path + "/model.final_layernorm.bias.bin", model_file_type); loadWeightFromBin( - weights_ptr[3], {vocab_size_ * hidden_units_}, dir_path + "/model.lm_head.weight.bin", model_file_type); - loadWeightFromBin(weights_ptr[4], {vocab_size_}, dir_path + "/model.lm_head.bias.bin", model_file_type); + weights_ptr[2], {(size_t)hidden_units_}, dir_path + "/model.final_layernorm.weight.bin", model_file_type); + loadWeightFromBin(weights_ptr[3], + {(size_t)(vocab_size_ * hidden_units_)}, + dir_path + "/model.lm_head.weight.bin", + model_file_type); + loadWeightFromBin(weights_ptr[4], {(size_t)vocab_size_}, dir_path + "/model.lm_head.bias.bin", model_file_type); + + // prompt table: load weights from bin + if (malloc_load_prompt_weights_) { + for (auto const& prompt : prompt_learning_pair_) { + std::string task_name = prompt.first; + int task_name_id = prompt.second.first; + int prompt_length = prompt.second.second; + size_t task_weight_id = num_base_weights + (size_t)task_name_id; + + std::string prompt_weight_path_name = (prompt_learning_type_ == PromptLearningType::p_prompt_tuning) ? + (dir_path + "/model.prompt_table." + task_name + ".weight.bin") : + (dir_path + "/model.prefix_prompt." + task_name + ".weight." + + std::to_string(tensor_para_rank_) + ".bin"); + + if (prompt_length > 0) { + loadWeightFromBin(weights_ptr[task_weight_id], + {(size_t)(prompt_length * (int)prompt_token_weight_size_)}, + prompt_weight_path_name, + model_file_type); + } + } + } for (int l = 0; l < num_layer_; l++) { decoder_layer_weights[l].loadModel(dir_path + "/model.layers." 
+ std::to_string(l), model_file_type); @@ -179,5 +288,8 @@ bool GptJWeight::isValidLayerParallelId(int l) template struct GptJWeight; template struct GptJWeight; +#ifdef ENABLE_BF16 +template struct GptJWeight<__nv_bfloat16>; +#endif } // namespace fastertransformer diff --git a/src/fastertransformer/models/gptj/GptJWeight.h b/src/fastertransformer/models/gptj/GptJWeight.h index 04d9ca760..90a638703 100644 --- a/src/fastertransformer/models/gptj/GptJWeight.h +++ b/src/fastertransformer/models/gptj/GptJWeight.h @@ -19,6 +19,7 @@ #include "src/fastertransformer/kernels/layernorm_kernels.h" #include "src/fastertransformer/models/gptj/GptJDecoderLayerWeight.h" #include "src/fastertransformer/utils/memory_utils.h" +#include "src/fastertransformer/utils/prompt_learning.h" namespace fastertransformer { @@ -26,15 +27,18 @@ template struct GptJWeight { GptJWeight() = default; - GptJWeight(const int hidden_units, - const int inter_size, - const int vocab_size, - const int num_layer, - const int max_seq_len, - const int tensor_para_size = 1, - const int tensor_para_rank = 0, - const int layer_para_size = 1, - const int layer_para_rank = 0); + GptJWeight( + const int hidden_units, + const int inter_size, + const int vocab_size, + const int num_layer, + const int max_seq_len, + const int tensor_para_size = 1, + const int tensor_para_rank = 0, + const int layer_para_size = 1, + const int layer_para_rank = 0, + PromptLearningType prompt_learning_type = PromptLearningType::no_prompt, + std::map> prompt_learning_pair = std::map>{}); ~GptJWeight(); GptJWeight(const GptJWeight& other); @@ -43,13 +47,22 @@ struct GptJWeight { void loadModel(std::string dir_path); std::vector> decoder_layer_weights; - const T* pre_decoder_embedding_table = nullptr; + const T* pre_decoder_embedding_table = nullptr; // GPT-J does not use embedding table, but we leave the ptr such that // GptJ::forward and Gpt::forward become identical const T* position_encoding_table = nullptr; + /* + prompt_learning_pair = vectors of [weight ptr, prompt length] pair + prompt_length is stored here for compatible prompt learning table + prefix_prompt weights store as shape [num_layers, 2, num_heads, perfix_seq_len, size_per_head] + p/prompt tuning weights store as shape [prompt_len, hidden_units] + idx is the task_name_id of the prompt tables + */ + std::vector> prompt_learning_table = {}; + LayerNormWeight post_decoder_layernorm; - DenseWeight post_decoder_embedding; + DenseWeight post_decoder_embedding; private: void setWeightPtr(); @@ -67,8 +80,16 @@ struct GptJWeight { int layer_para_size_; int layer_para_rank_; - bool is_maintain_buffer = false; - T* weights_ptr[5]; + // prompt learning pair (task_name, (task_name_id, prompt_len)) + PromptLearningType prompt_learning_type_; + std::map> prompt_learning_pair_; + bool malloc_load_prompt_weights_ = false; + // each prompt token's weight size + size_t prompt_token_weight_size_ = 0; + + bool is_maintain_buffer = false; + size_t num_base_weights = 5; + std::vector weights_ptr = std::vector(num_base_weights); }; } // namespace fastertransformer diff --git a/src/fastertransformer/models/gptneox/CMakeLists.txt b/src/fastertransformer/models/gptneox/CMakeLists.txt new file mode 100644 index 000000000..f84dba40f --- /dev/null +++ b/src/fastertransformer/models/gptneox/CMakeLists.txt @@ -0,0 +1,63 @@ +# Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +cmake_minimum_required(VERSION 3.8) + +add_library(GptNeoXDecoderLayerWeight STATIC GptNeoXDecoderLayerWeight.cc) +set_property(TARGET GptNeoXDecoderLayerWeight PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET GptNeoXDecoderLayerWeight PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) +target_link_libraries(GptNeoXDecoderLayerWeight PUBLIC memory_utils) + +add_library(GptNeoXDecoder STATIC GptNeoXDecoder.cc) +set_property(TARGET GptNeoXDecoder PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET GptNeoXDecoder PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) +target_link_libraries(GptNeoXDecoder PUBLIC -lcudart cublasMMWrapper + TensorParallelDecoderSelfAttentionLayer + TensorParallelGeluFfnLayer + layernorm_kernels + add_residual_kernels + GptNeoXDecoderLayerWeight + tensor + nccl_utils) + +add_library(GptNeoXContextDecoder STATIC GptNeoXContextDecoder.cc) +set_property(TARGET GptNeoXContextDecoder PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET GptNeoXContextDecoder PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) +target_link_libraries(GptNeoXContextDecoder PUBLIC -lcudart cublasMMWrapper + TensorParallelGptContextAttentionLayer + TensorParallelGeluFfnLayer + layernorm_kernels + add_residual_kernels + gpt_kernels + tensor + nccl_utils) + +add_library(GptNeoXWeight STATIC GptNeoXWeight.cc) +set_property(TARGET GptNeoXWeight PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET GptNeoXWeight PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) +target_link_libraries(GptNeoXWeight PUBLIC GptNeoXDecoderLayerWeight) + +add_library(GptNeoX STATIC GptNeoX.cc) +set_property(TARGET GptNeoX PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET GptNeoX PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) +target_link_libraries(GptNeoX PUBLIC -lcudart + GptNeoXDecoder + GptNeoXContextDecoder + decoding_kernels + gpt_kernels + DynamicDecodeLayer + BaseBeamSearchLayer + bert_preprocess_kernels + tensor + GptNeoXWeight) diff --git a/src/fastertransformer/models/gptneox/GptNeoX.cc b/src/fastertransformer/models/gptneox/GptNeoX.cc new file mode 100644 index 000000000..4f8492dee --- /dev/null +++ b/src/fastertransformer/models/gptneox/GptNeoX.cc @@ -0,0 +1,1197 @@ +/* + * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/fastertransformer/models/gptneox/GptNeoX.h" +#include "src/fastertransformer/kernels/bert_preprocess_kernels.h" +#include "src/fastertransformer/kernels/decoding_kernels.h" +#include "src/fastertransformer/kernels/gpt_kernels.h" +#include "src/fastertransformer/layers/beam_search_layers/BaseBeamSearchLayer.h" +#include + +namespace fastertransformer { + +template +void GptNeoX::initialize() +{ + gpt_context_decoder_ = new GptNeoXContextDecoder(head_num_, + size_per_head_, + inter_size_, + num_layer_, + rotary_embedding_dim_, + neox_rotary_style_, + use_gptj_residual_, + layernorm_eps_, + tensor_para_, + pipeline_para_, + stream_, + cublas_wrapper_, + allocator_, + is_free_buffer_after_forward_, + is_context_qk_buf_float_, + custom_all_reduce_comm_, + enable_custom_all_reduce_); + + gpt_decoder_ = new GptNeoXDecoder(head_num_, + size_per_head_, + inter_size_, + num_layer_, + rotary_embedding_dim_, + neox_rotary_style_, + use_gptj_residual_, + layernorm_eps_, + tensor_para_, + pipeline_para_, + stream_, + cublas_wrapper_, + allocator_, + is_free_buffer_after_forward_, + custom_all_reduce_comm_, + enable_custom_all_reduce_); + + dynamic_decode_layer_ = new DynamicDecodeLayer(vocab_size_, + vocab_size_padded_, + 0, // end_id, deprecated + stream_, + cublas_wrapper_, + allocator_, + is_free_buffer_after_forward_, + cuda_device_prop_); +} + +template +void GptNeoX::allocateBuffer() +{ + FT_CHECK(false); +} + +template +void GptNeoX::allocateBuffer( + size_t batch_size, size_t beam_width, size_t max_seq_len, size_t max_cache_seq_len, size_t max_input_len) +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + const size_t batchxbeam = batch_size * beam_width; + const size_t self_cache_size = (num_layer_ / pipeline_para_.world_size_) * batchxbeam * max_cache_seq_len + * hidden_units_ / tensor_para_.world_size_; + + if (vocab_size_ != vocab_size_padded_) { + padded_embedding_kernel_ = + (T*)(allocator_->reMalloc(padded_embedding_kernel_, sizeof(T) * hidden_units_ * vocab_size_padded_, true)); + padded_embedding_kernel_ptr_ = padded_embedding_kernel_; + + padded_embedding_bias_ = + (T*)(allocator_->reMalloc(padded_embedding_bias_, sizeof(T) * vocab_size_padded_, true)); + } + + input_attention_mask_ = (T*)(allocator_->reMalloc( + input_attention_mask_, sizeof(T) * batchxbeam * max_seq_len * max_cache_seq_len, false)); + decoder_input_buf_ = (T*)(allocator_->reMalloc(decoder_input_buf_, sizeof(T) * batchxbeam * hidden_units_, false)); + decoder_output_buf_ = + (T*)(allocator_->reMalloc(decoder_output_buf_, sizeof(T) * batchxbeam * hidden_units_, false)); + normed_decoder_output_buf_ = + (T*)(allocator_->reMalloc(normed_decoder_output_buf_, sizeof(T) * batchxbeam * hidden_units_, false)); + logits_buf_ = (float*)(allocator_->reMalloc(logits_buf_, sizeof(float) * batchxbeam * vocab_size_padded_, false)); + nccl_logits_buf_ = + (float*)(allocator_->reMalloc(nccl_logits_buf_, sizeof(float) * batchxbeam * vocab_size_padded_, false)); + cum_log_probs_ = (float*)(allocator_->reMalloc(cum_log_probs_, sizeof(float) * batchxbeam, false)); + finished_buf_ = (bool*)(allocator_->reMalloc(finished_buf_, sizeof(bool) * batchxbeam, false)); + h_finished_buf_ = new bool[batchxbeam]; + sequence_lengths_ = (int*)(allocator_->reMalloc(sequence_lengths_, sizeof(int) * batchxbeam, false)); + + key_cache_ = (T*)(allocator_->reMalloc(key_cache_, sizeof(T) * self_cache_size * 2, true)); + value_cache_ = key_cache_ + self_cache_size; + if (beam_width > 1) { + cache_indirections_[0] = + 
(int*)(allocator_->reMalloc(cache_indirections_[0], sizeof(int) * batchxbeam * max_seq_len * 2, true)); + cache_indirections_[1] = cache_indirections_[0] + batchxbeam * max_seq_len; + } + + // prompt_learning weight batch ptrs + prompt_learning_weight_batch_ = + (const T**)(allocator_->reMalloc(prompt_learning_weight_batch_, sizeof(T*) * batchxbeam, false)); + tiled_prompt_lengths_buf_ = + (int*)(allocator_->reMalloc(tiled_prompt_lengths_buf_, sizeof(int) * batchxbeam, false)); + + tiled_input_ids_buf_ = + (int*)(allocator_->reMalloc(tiled_input_ids_buf_, sizeof(int) * batchxbeam * max_input_len, true)); + tiled_input_lengths_buf_ = (int*)(allocator_->reMalloc(tiled_input_lengths_buf_, sizeof(int) * batchxbeam, true)); + tiled_total_padding_count_ = + (int*)allocator_->reMalloc(tiled_total_padding_count_, batchxbeam * sizeof(int), false); + + transposed_output_ids_buf_ = + (int*)(allocator_->reMalloc(transposed_output_ids_buf_, sizeof(int) * batchxbeam * max_seq_len, true)); + output_ids_buf_ = (int*)(allocator_->reMalloc(output_ids_buf_, sizeof(int) * batchxbeam * max_seq_len, true)); + parent_ids_buf_ = (int*)(allocator_->reMalloc(parent_ids_buf_, sizeof(int) * batchxbeam * max_seq_len, true)); + seq_limit_len_ = (uint32_t*)(allocator_->reMalloc(seq_limit_len_, sizeof(uint32_t) * batch_size, false)); + masked_tokens_ = (bool*)(allocator_->reMalloc(masked_tokens_, sizeof(bool) * batchxbeam * max_cache_seq_len, true)); + + start_ids_buf_ = (int*)(allocator_->reMalloc(start_ids_buf_, sizeof(int) * batch_size, false)); + end_ids_buf_ = (int*)(allocator_->reMalloc(end_ids_buf_, sizeof(int) * batch_size, false)); + + context_decoder_input_buf_ = (T*)(allocator_->reMalloc( + context_decoder_input_buf_, sizeof(T) * batchxbeam * max_input_len * hidden_units_, false)); + context_decoder_output_buf_ = (T*)(allocator_->reMalloc( + context_decoder_output_buf_, sizeof(T) * batchxbeam * max_input_len * hidden_units_, false)); + output_log_probs_buf_ = + (float*)(allocator_->reMalloc(output_log_probs_buf_, sizeof(float) * batchxbeam * max_seq_len, false)); + + if (generation_should_stop_ == nullptr) { + cudaMallocHost(&generation_should_stop_, 1 * sizeof(bool)); + } + + is_allocate_buffer_ = true; +} + +template +void GptNeoX::freeBuffer() +{ + if (is_allocate_buffer_) { + if (vocab_size_ != vocab_size_padded_) { + padded_embedding_kernel_ptr_ = nullptr; + allocator_->free((void**)(&padded_embedding_kernel_)); + allocator_->free((void**)(&padded_embedding_bias_)); + } + + allocator_->free((void**)(&input_attention_mask_)); + allocator_->free((void**)(&decoder_input_buf_)); + allocator_->free((void**)(&decoder_output_buf_)); + allocator_->free((void**)(&normed_decoder_output_buf_)); + allocator_->free((void**)(&logits_buf_)); + allocator_->free((void**)(&nccl_logits_buf_)); + allocator_->free((void**)(&cum_log_probs_)); + allocator_->free((void**)(&finished_buf_)); + delete[] h_finished_buf_; + allocator_->free((void**)(&sequence_lengths_)); + + allocator_->free((void**)(&key_cache_)); + if (cache_indirections_[0] != nullptr) { + allocator_->free((void**)(&cache_indirections_)[0]); + } + + allocator_->free((void**)(&prompt_learning_weight_batch_)); + allocator_->free((void**)(&tiled_prompt_lengths_buf_)); + + allocator_->free((void**)(&tiled_input_ids_buf_)); + allocator_->free((void**)(&tiled_input_lengths_buf_)); + allocator_->free((void**)(&tiled_total_padding_count_)); + + allocator_->free((void**)(&transposed_output_ids_buf_)); + allocator_->free((void**)(&output_ids_buf_)); + 
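    /*
     * (For reference, not part of this patch: a worked size estimate for the K/V cache
     * allocated in allocateBuffer() above. self_cache_size is
     *     (num_layer_ / pipeline_para_.world_size_) * batch_size * beam_width
     *         * max_cache_seq_len * hidden_units_ / tensor_para_.world_size_
     * elements, and key_cache_ holds K and V back to back
     * (value_cache_ = key_cache_ + self_cache_size), i.e. 2 * self_cache_size * sizeof(T) bytes.
     * With hypothetical values num_layer_ = 24, hidden_units_ = 2048,
     * batch_size * beam_width = 8, max_cache_seq_len = 1024, TP = PP = 1 and T = half:
     *     self_cache_size = 24 * 8 * 1024 * 2048 = 402,653,184 elements
     *     K + V bytes     = 2 * 402,653,184 * 2  = 1,610,612,736 B  (~1.5 GiB)
     */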
allocator_->free((void**)(&parent_ids_buf_)); + allocator_->free((void**)(&seq_limit_len_)); + allocator_->free((void**)(&masked_tokens_)); + + allocator_->free((void**)(&start_ids_buf_)); + allocator_->free((void**)(&end_ids_buf_)); + + allocator_->free((void**)(&context_decoder_input_buf_)); + allocator_->free((void**)(&context_decoder_output_buf_)); + allocator_->free((void**)(&output_log_probs_buf_)); + + cudaFreeHost(generation_should_stop_); + + is_allocate_buffer_ = false; + } +} + +template +GptNeoX::GptNeoX(size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + size_t vocab_size, + size_t rotary_embedding_dim, + int start_id, + int end_id, + int prompt_learning_start_id, // only needed by p/prompt-tuning + PromptLearningType prompt_learning_type, + bool use_gptj_residual, + float beam_search_diversity_rate, + size_t top_k, + float top_p, + unsigned long long random_seed, + float temperature, + float len_penalty, + float repetition_penalty, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + cudaDeviceProp* cuda_device_prop, + std::shared_ptr custom_all_reduce_comm, + int enable_custom_all_reduce): + BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, cuda_device_prop), + head_num_(head_num), + size_per_head_(size_per_head), + inter_size_(inter_size), + num_layer_(num_layer), + vocab_size_(vocab_size), + rotary_embedding_dim_(rotary_embedding_dim), + start_id_(start_id), + end_id_(end_id), + prompt_learning_start_id_(prompt_learning_start_id), + prompt_learning_type_(prompt_learning_type), + use_gptj_residual_(use_gptj_residual), + hidden_units_(head_num * size_per_head), + local_head_num_(head_num / 1) +{ + tensor_para_.world_size_ = 1; + tensor_para_.rank_ = 0; + pipeline_para_.world_size_ = 1; + pipeline_para_.rank_ = 0; + + int local_vacab_size = ceil(vocab_size_ / 1.f / tensor_para_.world_size_); + if (std::is_same::value) { + local_vacab_size = ceil(local_vacab_size / 8.f) * 8; + } + vocab_size_padded_ = (size_t)local_vacab_size * tensor_para_.world_size_; + initialize(); +} + +template +GptNeoX::GptNeoX(size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + size_t vocab_size, + size_t rotary_embedding_dim, + int start_id, + int end_id, + int prompt_learning_start_id, // only needed by p/prompt-tuning + PromptLearningType prompt_learning_type, + bool use_gptj_residual, + float beam_search_diversity_rate, + size_t top_k, + float top_p, + unsigned long long random_seed, + float temperature, + float len_penalty, + float repetition_penalty, + NcclParam tensor_para, + NcclParam pipeline_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + cudaDeviceProp* cuda_device_prop, + std::shared_ptr custom_all_reduce_comm, + int enable_custom_all_reduce): + BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, cuda_device_prop), + head_num_(head_num), + size_per_head_(size_per_head), + inter_size_(inter_size), + num_layer_(num_layer), + vocab_size_(vocab_size), + rotary_embedding_dim_(rotary_embedding_dim), + start_id_(start_id), + end_id_(end_id), + prompt_learning_start_id_(prompt_learning_start_id), + prompt_learning_type_(prompt_learning_type), + use_gptj_residual_(use_gptj_residual), + hidden_units_(head_num * size_per_head), + tensor_para_(tensor_para), + pipeline_para_(pipeline_para), + local_head_num_(head_num / tensor_para.world_size_), + 
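        /*
         * (For reference, a worked example of the vocabulary padding computed in the
         * constructor bodies here, using hypothetical values vocab_size_ = 50257,
         * tensor_para_.world_size_ = 2, T = half:
         *     local_vacab_size   = ceil(50257 / 2)      = 25129
         *     rounded up to 8    = ceil(25129 / 8) * 8  = 25136   (applied only when T is half)
         *     vocab_size_padded_ = 25136 * 2            = 50272
         * The rounding keeps each tensor-parallel slice of the embedding / logit GEMM a
         * multiple of 8 columns, presumably to stay Tensor-Core friendly for fp16.)
         */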
custom_all_reduce_comm_(custom_all_reduce_comm), + enable_custom_all_reduce_(enable_custom_all_reduce) +{ + int local_vacab_size = ceil(vocab_size_ / 1.f / tensor_para_.world_size_); + if (std::is_same::value) { + local_vacab_size = ceil(local_vacab_size / 8.f) * 8; + } + vocab_size_padded_ = (size_t)local_vacab_size * tensor_para_.world_size_; + initialize(); +} + +template +GptNeoX::GptNeoX(GptNeoX const& gpt): + BaseLayer(gpt), + head_num_(gpt.head_num_), + size_per_head_(gpt.size_per_head_), + inter_size_(gpt.inter_size_), + num_layer_(gpt.num_layer_), + vocab_size_(gpt.vocab_size_), + rotary_embedding_dim_(gpt.rotary_embedding_dim_), + start_id_(gpt.start_id_), + end_id_(gpt.end_id_), + prompt_learning_start_id_(gpt.prompt_learning_start_id_), + prompt_learning_type_(gpt.prompt_learning_type_), + use_gptj_residual_(gpt.use_gptj_residual_), + hidden_units_(gpt.hidden_units_), + tensor_para_(gpt.tensor_para_), + pipeline_para_(gpt.pipeline_para_), + local_head_num_(gpt.local_head_num_), + vocab_size_padded_(gpt.vocab_size_padded_), + custom_all_reduce_comm_(gpt.custom_all_reduce_comm_), + enable_custom_all_reduce_(gpt.enable_custom_all_reduce_) +{ + initialize(); +} + +template +GptNeoX::~GptNeoX() +{ + delete gpt_decoder_; + delete dynamic_decode_layer_; + delete gpt_context_decoder_; + freeBuffer(); +} + +template +void GptNeoX::registerCallback(callback_sig* fn, void* ctx) +{ + token_generated_cb_ = fn; + token_generated_ctx_ = ctx; +} + +template +void GptNeoX::unRegisterCallback() +{ + token_generated_cb_ = nullptr; + token_generated_ctx_ = nullptr; +} + +template +void GptNeoX::forward(std::vector* output_tensors, + const std::vector* input_tensors, + const GptNeoXWeight* gpt_weights) +{ + FT_CHECK(false); +} + +template +void GptNeoX::forward(std::unordered_map* output_tensors, + const std::unordered_map* input_tensors, + const GptNeoXWeight* gpt_weights) +{ + // input_tensors: + // input_ids [batch_size, max_input_length] + // input_lengths [batch_size] + // prompt_learning_task_name_ids [batch_size] on cpu, optional + // output_seq_len [batch_size] on cpu + // start_id [batch_size] on cpu, optional + // end_id [batch_size] on cpu, optional + // stop_words_list [batch_size, 2, stop_words_length], optional + // bad_words_list [2, bad_words_length] or [batch_size, 2, bad_words_length], optional + // runtime_top_k [1] or [batch_size] on cpu, optional, uint. + // runtime_top_p [1] or [batch_size] on cpu, optional, float. + // beam_search_diversity_rate [1] or [batch_size] on cpu, optional, float. + // temperature [1] or [batch_size] on cpu, optional, float. + // len_penalty [1] or [batch_size] on cpu, optional, float. + // repetition_penalty [1] or [batch_size] on cpu, optional, float. + // random_seed [1] or [batch_size] on cpu, optional, unsigned long long int. + // request_prompt_lengths [batch_size], optional + // request_prompt_embedding [batch_size, max_prompt_length, hidden_units], float, optional + // requst_prompt_type [batch_size], int, optional + + // output_tensors: + // output_ids [batch_size, beam_width, max_output_seq_len] + // sequence_length [batch_size, beam_width] + // output_log_probs [batch_size, beam_width, request_output_seq_len], must be float*. + // optional. It leads to additional computing cost. If we don't need this result, don't put it. + // cum_log_probs [batch_size, beam], optional, must be float*. + // optional. It leads to additional computing cost. If we don't need this result, don't put it. 
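    // For reference, a minimal single-GPU caller could assemble these maps roughly as
    // follows. All d_*/h_* pointers and the sizes are hypothetical, caller-owned buffers;
    // the element type of output_seq_len follows the bundled GPT examples (uint32 on CPU).
    //
    //   std::unordered_map<std::string, Tensor> input_tensors{
    //       {"input_ids",      Tensor{MEMORY_GPU, TYPE_INT32,  {batch_size, max_input_len}, d_input_ids}},
    //       {"input_lengths",  Tensor{MEMORY_GPU, TYPE_INT32,  {batch_size},                d_input_lengths}},
    //       {"output_seq_len", Tensor{MEMORY_CPU, TYPE_UINT32, {batch_size},                h_output_seq_len}}};
    //   std::unordered_map<std::string, Tensor> output_tensors{
    //       {"output_ids",      Tensor{MEMORY_GPU, TYPE_INT32, {batch_size, beam_width, max_output_len}, d_output_ids}},
    //       {"sequence_length", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size, beam_width},                 d_sequence_lengths}}};
    //   gpt_neox.forward(&output_tensors, &input_tensors, &gpt_weights);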
+ + // Step is from max_input_length ~ max_output_seq_len, + // When step = k, we put output ids and caches at step k, and the sequence_length would be k - 1 before + // complete this step. + // When there is no input_ids, put the start token at step 0 of output_ids_buf_. After forward, only copy + // the step 1 ~ max_output_seq_len of output_ids_buf_ to output_tensors->at(0).data + + FT_CHECK_WITH_INFO(input_tensors->size() >= 3, "input_tensors->size() >= 3"); + FT_CHECK_WITH_INFO(output_tensors->size() >= 2, "output_tensors->size() >= 2"); + FT_CHECK(input_tensors->at("input_ids").shape.size() == 2); + FT_CHECK(input_tensors->at("input_lengths").shape.size() == 1); + FT_CHECK(input_tensors->find("output_seq_len") != input_tensors->end() + && input_tensors->at("output_seq_len").shape.size() == 1); + FT_CHECK(output_tensors->at("output_ids").shape.size() == 3); + FT_CHECK(output_tensors->at("sequence_length").shape.size() == 2); + FT_CHECK_WITH_INFO(input_tensors->at("input_ids").shape[0] == output_tensors->at("output_ids").shape[0], + "input_tensors->at(\"input_ids\").shape[0] == output_tensors->at(\"output_ids\").shape[0]"); + + const size_t batch_size = output_tensors->at("output_ids").shape[0]; + const size_t beam_width = output_tensors->at("output_ids").shape[1]; + + PromptLearningType request_prompt_type = PromptLearningType::no_prompt; + int valid_prompt_inputs = input_tensors->count("request_prompt_type") + + input_tensors->count("request_prompt_lengths") + + input_tensors->count("request_prompt_embedding"); + + if (valid_prompt_inputs == 3) { + request_prompt_type = static_cast(input_tensors->at("request_prompt_type").getVal()); + FT_LOG_INFO("Apply prompt embedding from input, will ignore task name ids"); + } + else if (valid_prompt_inputs > 0) { + FT_LOG_WARNING( + "Prompts not applied: request_prompt_embedding, request_prompt_lengths, request_prompt_type are all needed!"); + } + if (request_prompt_type == PromptLearningType::prefix_prompt) { + FT_LOG_WARNING("Request prompt doesn't support prefix prompt currently!"); + } + + // Prefix Prompt Inputs + // Padding works as follows: p p x x i i i x x --> p p i i i x x x x (p denotes prompt, i denotes input, x denotes + // pad) + // TODO (perkzz): move unnecessary paddings + const int* prompt_learning_task_name_ids = + input_tensors->count("prompt_learning_task_name_ids") ? 
+ (const int*)(input_tensors->at("prompt_learning_task_name_ids").data) : + nullptr; + has_prefix_prompt_ = + (prompt_learning_task_name_ids != nullptr) && (prompt_learning_type_ == PromptLearningType::prefix_prompt); + int max_prefix_prompt_length = 0; + + FT_CHECK_WITH_INFO( + !(prompt_learning_task_name_ids != nullptr + && (prompt_learning_type_ == PromptLearningType::no_prompt + || prompt_learning_type_ == PromptLearningType::soft_prompt)), + "prompt_learning_type is prefix_prompt either p_prompt_tuning when prompt_learning_task_name_ids are provided."); + + // NOTE: Prefix Prompt PreProcessing + // get prefix_prompt_weight for each batch --> shape [batch, beam_width] + // --> ptrs with shape [num_layers, 2, num_heads, perfix_seq_len, size_per_head] + std::vector prefix_prompt_weight_batch_ptrs; + std::vector prefix_prompt_lengths; + if (has_prefix_prompt_) { + for (int bs_id = 0; bs_id < batch_size; ++bs_id) { + int task_id = prompt_learning_task_name_ids[bs_id]; + // throw errors when prompt task_name_ids are not found + std::pair prefix_prompt_weight_length_pair; + try { + prefix_prompt_weight_length_pair = gpt_weights->prompt_learning_table.at(task_id); + } + catch (const std::out_of_range& oor) { + FT_LOG_ERROR("prefix_prompt_weights_lengths not found for prompt task id: " + task_id); + throw oor; + } + for (int bw_id = 0; bw_id < beam_width; ++bw_id) { + prefix_prompt_weight_batch_ptrs.push_back(prefix_prompt_weight_length_pair.first); + prefix_prompt_lengths.push_back(prefix_prompt_weight_length_pair.second); + } + } + + max_prefix_prompt_length = *max_element(prefix_prompt_lengths.begin(), prefix_prompt_lengths.end()); + + FT_LOG_DEBUG("max_prefix_prompt_length: %d", max_prefix_prompt_length); + + if (max_prefix_prompt_length == 0) { + has_prefix_prompt_ = false; + FT_LOG_DEBUG("prompts are not applied !"); + } + } + + int max_input_length = input_tensors->at("input_ids").shape[1]; + FT_CHECK_WITH_INFO(!(max_input_length == 0 && max_prefix_prompt_length > 0), + "Prefix Prompt should come with inputs!"); + + // Prefix Soft Prompt + has_prefix_soft_prompt_ = request_prompt_type == PromptLearningType::soft_prompt; + const size_t max_prefix_soft_prompt_length = + has_prefix_soft_prompt_ ? input_tensors->at("request_prompt_embedding").shape[1] : 0; + const size_t limit_len_offset = max_prefix_soft_prompt_length + (max_input_length == 0 ? 1 : 0); + const size_t max_output_seq_len = input_tensors->at("output_seq_len").max() + limit_len_offset; + const size_t max_seq_len = max_output_seq_len; + // max cache seq len should include max prefix prompt length as it has k/v states + const size_t max_cache_seq_len = max_output_seq_len + max_prefix_prompt_length; + if (max_cache_seq_len < max_seq_len) { + FT_LOG_WARNING("max_cache_seq_len (%d) is less than max_seq_len (%d). " + "Note that this reduces the memory cost of k/v cache, but may hurt the accuracy.", + max_cache_seq_len, + max_seq_len); + } + else if (max_cache_seq_len > max_seq_len) { + FT_LOG_WARNING("max_cache_seq_len (%d) is larger than max_seq_len (%d). " + "This may lead to additional memory cost. 
Suggest to use smaller max_cache_seq_len.", + max_cache_seq_len, + max_seq_len); + } + const cudaDataType_t gemm_data_type = getCudaDataType(); + allocateBuffer( + batch_size, beam_width, max_seq_len, max_cache_seq_len, max_input_length + max_prefix_soft_prompt_length); + setSeqLimitLen(seq_limit_len_, input_tensors->at("output_seq_len"), limit_len_offset, batch_size); + + sync_check_cuda_error(); + dynamic_decode_layer_->setup(batch_size, beam_width, input_tensors); + + const DataType data_type = getTensorType(); + + handleOptArg(input_tensors, "start_id", start_ids_buf_, start_id_, batch_size); + handleOptArg(input_tensors, "end_id", end_ids_buf_, end_id_, batch_size); + + const std::vector self_k_cache_shape = {num_layer_ / pipeline_para_.world_size_, + batch_size * beam_width, + local_head_num_, + size_per_head_ / (16 / sizeof(T)), + max_cache_seq_len, + 16 / sizeof(T)}; + const std::vector self_v_cache_shape = {num_layer_ / pipeline_para_.world_size_, + batch_size * beam_width, + local_head_num_, + max_cache_seq_len, + size_per_head_}; + + // initialize the output ids and parent ids + cudaMemsetAsync(output_ids_buf_, 0, sizeof(int) * batch_size * beam_width * max_seq_len, stream_); + cudaMemsetAsync(parent_ids_buf_, 0, sizeof(int) * batch_size * beam_width * max_seq_len, stream_); + cudaMemsetAsync(masked_tokens_, false, sizeof(bool) * batch_size * beam_width * max_cache_seq_len, stream_); + cudaMemsetAsync(tiled_total_padding_count_, 0, sizeof(int) * batch_size * beam_width, stream_); + if (beam_width > 1) { + cudaMemsetAsync(cache_indirections_[0], 0, 2 * sizeof(int) * batch_size * beam_width * max_seq_len, stream_); + } + + // Prefix prompts + if (has_prefix_prompt_) { + cudaMemcpyAsync(prompt_learning_weight_batch_, + prefix_prompt_weight_batch_ptrs.data(), + sizeof(T*) * batch_size * beam_width, + cudaMemcpyDefault, + stream_); + cudaMemcpyAsync(tiled_prompt_lengths_buf_, + prefix_prompt_lengths.data(), + sizeof(int) * batch_size * beam_width, + cudaMemcpyDefault, + stream_); + } + + sync_check_cuda_error(); + + // handle first step + if (has_prefix_prompt_ || has_prefix_soft_prompt_ || max_input_length > 1) { + invokeTileGptInputs(tiled_input_ids_buf_, + tiled_input_lengths_buf_, + (int*)input_tensors->at("input_ids").data, + (const int*)(input_tensors->at("input_lengths").data), + batch_size, + beam_width, + max_input_length, + stream_); + sync_check_cuda_error(); + + if (has_prefix_soft_prompt_) { + inputIdsEmbeddingLookupPosEncodingSoftPromptParam param; + param.from_tensor = context_decoder_input_buf_; + param.output_ids = output_ids_buf_; + param.input_lengths = tiled_input_lengths_buf_; + param.embedding_table = gpt_weights->pre_decoder_embedding_table; + param.pos_table = gpt_weights->position_encoding_table; + param.prefix_soft_prompt_embedding = input_tensors->at("request_prompt_embedding").getPtr(); + param.prefix_soft_prompt_lengths = input_tensors->at("request_prompt_lengths").getPtr(); + param.input_ids = tiled_input_ids_buf_; + param.start_step = 1; + param.max_input_length = max_input_length; + param.max_prefix_soft_prompt_length = max_prefix_soft_prompt_length; + param.batch_size = batch_size; + param.beam_width = beam_width; + param.hidden_units = hidden_units_; + param.stream = stream_; + + invokeInputIdsEmbeddingLookupPosEncodingSoftPrompt(param); + sync_check_cuda_error(); + max_input_length += max_prefix_soft_prompt_length; // view soft_prompt as input + } + else { + invokeInputIdsEmbeddingLookupPosEncoding(context_decoder_input_buf_, + output_ids_buf_, + 
gpt_weights->pre_decoder_embedding_table, + gpt_weights->position_encoding_table, + pPromptTuningParam{}, // no p/prompt tuning + tiled_input_ids_buf_, + 1, + max_input_length, + max_input_length, + batch_size * beam_width, + hidden_units_, + stream_); + sync_check_cuda_error(); + } + + invokeBuildDecoderAttentionMask(input_attention_mask_, + tiled_input_lengths_buf_, + tiled_prompt_lengths_buf_, + batch_size * beam_width, + max_input_length, + max_prefix_prompt_length, + stream_); + sync_check_cuda_error(); + + std::unordered_map decoder_input_tensors{ + {"decoder_input", + Tensor{MEMORY_GPU, + data_type, + {batch_size * beam_width, (size_t)max_input_length, hidden_units_}, + context_decoder_input_buf_}}, + {"attention_mask", + Tensor{MEMORY_GPU, + data_type, + {batch_size * beam_width, + 1, + (size_t)max_input_length, + (size_t)(max_input_length + max_prefix_prompt_length)}, + input_attention_mask_}}, + {"input_lengths", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size * beam_width}, tiled_input_lengths_buf_}}, + {"d_prefix_prompt_batch", + Tensor{MEMORY_GPU, + data_type, + {batch_size * beam_width}, + has_prefix_prompt_ ? prompt_learning_weight_batch_ : nullptr}}, + {"d_prefix_prompt_lengths", + Tensor{MEMORY_GPU, + TYPE_INT32, + {batch_size * beam_width}, + has_prefix_prompt_ ? tiled_prompt_lengths_buf_ : nullptr}}}; + + std::unordered_map decoder_output_tensors{ + {"decoder_output", + Tensor{MEMORY_GPU, + data_type, + {batch_size * beam_width, (size_t)max_input_length, hidden_units_}, + context_decoder_output_buf_}}, + {"key_cache", Tensor{MEMORY_GPU, data_type, self_k_cache_shape, key_cache_}}, + {"value_cache", Tensor{MEMORY_GPU, data_type, self_v_cache_shape, value_cache_}}, + {"last_token_hidden_units", + Tensor{MEMORY_GPU, data_type, {batch_size * beam_width, hidden_units_}, decoder_output_buf_}}}; + + gpt_context_decoder_->forward( + &decoder_output_tensors, &decoder_input_tensors, &gpt_weights->decoder_layer_weights); + sync_check_cuda_error(); + invokeDecodingInitialize(finished_buf_, + sequence_lengths_, + nullptr, + cum_log_probs_, + start_ids_buf_, + batch_size, + beam_width, + max_input_length - 1, + stream_); + sync_check_cuda_error(); + } + else if (max_input_length == 0) { + FT_CHECK(prompt_learning_type_ == PromptLearningType::no_prompt + && request_prompt_type == PromptLearningType::no_prompt); // Not support prompts in this case + max_input_length++; + invokeDecodingInitialize(finished_buf_, + sequence_lengths_, + output_ids_buf_, + cum_log_probs_, + start_ids_buf_, + batch_size, + beam_width, + max_input_length - 1, + stream_); + std::vector h_input_lengths(batch_size * beam_width, 1); + cudaMemcpyAsync(tiled_input_lengths_buf_, + h_input_lengths.data(), + sizeof(int) * batch_size * beam_width, + cudaMemcpyHostToDevice, + stream_); + sync_check_cuda_error(); + } + else if (max_input_length == 1) { + FT_CHECK(prompt_learning_type_ == PromptLearningType::no_prompt + && request_prompt_type == PromptLearningType::no_prompt); // Not support prompts in this case + invokeDecodingInitialize(finished_buf_, + sequence_lengths_, + nullptr, + cum_log_probs_, + start_ids_buf_, + batch_size, + beam_width, + max_input_length - 1, + stream_); + sync_check_cuda_error(); + invokeTileGptInputs(tiled_input_ids_buf_, + tiled_input_lengths_buf_, + (int*)input_tensors->at("input_ids").data, + (const int*)(input_tensors->at("input_lengths").data), + batch_size, + beam_width, + max_input_length, + stream_); + sync_check_cuda_error(); + + cudaMemcpyAsync(output_ids_buf_, + tiled_input_ids_buf_, + 
sizeof(int) * batch_size * beam_width, + cudaMemcpyDeviceToDevice, + stream_); + } + + if (vocab_size_ == vocab_size_padded_) { + padded_embedding_kernel_ptr_ = gpt_weights->post_decoder_embedding.kernel; + } + else { + cudaMemcpyAsync(padded_embedding_kernel_, + gpt_weights->post_decoder_embedding.kernel, + sizeof(T) * vocab_size_ * hidden_units_, + cudaMemcpyDeviceToDevice, + stream_); + cudaMemcpyAsync(padded_embedding_bias_, + gpt_weights->post_decoder_embedding.bias, + sizeof(T) * vocab_size_, + cudaMemcpyDeviceToDevice, + stream_); + sync_check_cuda_error(); + } + + invokeMaskPaddingTokens(masked_tokens_, + (const int*)(input_tensors->at("input_lengths").data), // not_tiled + tiled_prompt_lengths_buf_, + max_cache_seq_len, + max_input_length + max_prefix_prompt_length, + 0, + batch_size, + beam_width, + stream_); + + for (int step = max_input_length; step < (int)max_output_seq_len; step++) { + const int src_indir_idx = (step - max_input_length) % 2; + const int tgt_indir_idx = 1 - src_indir_idx; + + const size_t local_batch_size = getLocalBatchSize(batch_size, 1, pipeline_para_.world_size_); + FT_CHECK(batch_size % local_batch_size == 0); + const size_t iteration_num = batch_size / local_batch_size; + *generation_should_stop_ = true; + + for (uint ite = 0; ite < iteration_num; ++ite) { + const int id_offset = ite * local_batch_size * beam_width; + const int hidden_units_offset = id_offset * hidden_units_; + const int vocab_size_units_offset = id_offset * vocab_size_padded_; + + if (!(max_input_length > 1 && step == max_input_length)) { + if (pipeline_para_.rank_ == 0) { + invokeEmbeddingLookupPosEncodingPadCount(decoder_input_buf_ + hidden_units_offset, + gpt_weights->pre_decoder_embedding_table, + gpt_weights->position_encoding_table, + output_ids_buf_ + id_offset, + tiled_total_padding_count_ + id_offset, + local_batch_size * beam_width, + hidden_units_, + (T)(1.0f), + step - 1, + batch_size * beam_width, + 0, + stream_); + sync_check_cuda_error(); + } + std::unordered_map decoder_input_tensors{ + {"decoder_input", + Tensor{MEMORY_GPU, + data_type, + {local_batch_size * beam_width, hidden_units_}, + decoder_input_buf_ + hidden_units_offset}}, + {"finished", + Tensor{MEMORY_GPU, TYPE_BOOL, {local_batch_size * beam_width}, finished_buf_ + id_offset}}, + {"sequence_lengths", + Tensor{MEMORY_GPU, TYPE_INT32, {local_batch_size * beam_width}, sequence_lengths_ + id_offset}}, + {"total_padding_tokens", + Tensor{MEMORY_GPU, + TYPE_INT32, + {local_batch_size * beam_width}, + tiled_total_padding_count_ + id_offset}}, + {"d_prefix_prompt_lengths", + Tensor{MEMORY_GPU, + TYPE_INT32, + {local_batch_size}, + has_prefix_prompt_ ? (tiled_prompt_lengths_buf_ + id_offset) : nullptr}}, + {"max_prefix_prompt_length", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &max_prefix_prompt_length}}, + {"max_input_length", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &max_input_length}}, + {"step", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &step}}, + {"ite", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &ite}}, + {"cache_indirection", + Tensor{MEMORY_GPU, + TYPE_INT32, + {local_batch_size, beam_width, max_output_seq_len}, + beam_width > 1 ? 
cache_indirections_[src_indir_idx] + id_offset * max_output_seq_len : + nullptr}}, + {"masked_tokens", + Tensor{MEMORY_GPU, + TYPE_BOOL, + {local_batch_size * beam_width, max_cache_seq_len}, + masked_tokens_ + id_offset * max_cache_seq_len}}}; + std::unordered_map decoder_output_tensors{ + {"decoder_output", + Tensor{MEMORY_GPU, + data_type, + {local_batch_size * beam_width, hidden_units_}, + decoder_output_buf_ + hidden_units_offset}}, + {"key_cache", Tensor{MEMORY_GPU, data_type, self_k_cache_shape, key_cache_}}, + {"value_cache", Tensor{MEMORY_GPU, data_type, self_v_cache_shape, value_cache_}}}; + gpt_decoder_->forward( + &decoder_output_tensors, &decoder_input_tensors, &gpt_weights->decoder_layer_weights); + } + + if (pipeline_para_.rank_ == pipeline_para_.world_size_ - 1) { + invokeGeneralLayerNorm(normed_decoder_output_buf_ + hidden_units_offset, + decoder_output_buf_ + hidden_units_offset, + gpt_weights->post_decoder_layernorm.gamma, + gpt_weights->post_decoder_layernorm.beta, + layernorm_eps_, + local_batch_size * beam_width, + hidden_units_, + stream_); + sync_check_cuda_error(); + + if (tensor_para_.world_size_ == 1) { + float alpha = 1.0f; + float beta = 0.0f; + cublas_wrapper_->Gemm(CUBLAS_OP_T, + CUBLAS_OP_N, + vocab_size_padded_, // n + local_batch_size * beam_width, + hidden_units_, // k + &alpha, + padded_embedding_kernel_ptr_, + gemm_data_type, + hidden_units_, // k + normed_decoder_output_buf_ + hidden_units_offset, + gemm_data_type, + hidden_units_, // k + &beta, + logits_buf_ + vocab_size_units_offset, + CUDA_R_32F, + vocab_size_padded_, /* n */ + CUDA_R_32F, + cublasGemmAlgo_t(-1)); + } + else { + FT_CHECK(vocab_size_padded_ % tensor_para_.world_size_ == 0); + const int local_vocab_size = vocab_size_padded_ / tensor_para_.world_size_; + float alpha = 1.0f; + float beta = 0.0f; + cublas_wrapper_->Gemm(CUBLAS_OP_T, + CUBLAS_OP_N, + local_vocab_size, // n + local_batch_size * beam_width, + hidden_units_, // k + &alpha, + padded_embedding_kernel_ptr_ + + tensor_para_.rank_ * local_vocab_size * hidden_units_, + gemm_data_type, + hidden_units_, // k + normed_decoder_output_buf_ + hidden_units_offset, + gemm_data_type, + hidden_units_, // k + &beta, + nccl_logits_buf_ + vocab_size_units_offset + + tensor_para_.rank_ * local_batch_size * beam_width * local_vocab_size, + CUDA_R_32F, + local_vocab_size, /* n */ + CUDA_R_32F, + cublasGemmAlgo_t(-1)); + ftNcclAllGather(nccl_logits_buf_ + vocab_size_units_offset, + nccl_logits_buf_ + vocab_size_units_offset, + local_batch_size * beam_width * local_vocab_size, + tensor_para_.rank_, + tensor_para_, + stream_); + invokeTransposeAxis01(logits_buf_ + vocab_size_units_offset, + nccl_logits_buf_ + vocab_size_units_offset, + tensor_para_.world_size_, + local_batch_size * beam_width, + local_vocab_size, + stream_); + } + + int tmp_local_batch_size = local_batch_size; + bool is_initialize_random_table = step == max_input_length; + std::unordered_map dynamic_decode_input_tensors{ + {"logits", + Tensor{MEMORY_GPU, TYPE_FP32, {batch_size, beam_width, vocab_size_padded_}, logits_buf_}}, + {"embedding_bias", Tensor{MEMORY_GPU, data_type, {vocab_size_padded_}, nullptr}}, + {"step", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &step}}, + {"max_input_length", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &max_input_length}}, + {"input_lengths", + Tensor{MEMORY_GPU, TYPE_INT32, {batch_size, beam_width}, tiled_input_lengths_buf_}}, + {"sequence_limit_length", Tensor{MEMORY_GPU, TYPE_UINT32, {batch_size}, seq_limit_len_}}, + {"ite", Tensor{MEMORY_CPU, TYPE_UINT32, {1}, 
&ite}}, + {"src_key_cache", Tensor{MEMORY_GPU, data_type, self_k_cache_shape, key_cache_}}, + {"src_value_cache", Tensor{MEMORY_GPU, data_type, self_v_cache_shape, value_cache_}}, + {"src_cache_indirection", + Tensor{MEMORY_GPU, + TYPE_INT32, + {local_batch_size, beam_width, max_output_seq_len}, + cache_indirections_[src_indir_idx] + id_offset * max_output_seq_len}}, + {"local_batch_size", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &tmp_local_batch_size}}, + {"end_id", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size}, end_ids_buf_}}, + {"is_initialize_random_table", Tensor{MEMORY_CPU, TYPE_BOOL, {1}, &is_initialize_random_table}}}; + + for (auto t = input_tensors->begin(); t != input_tensors->end(); ++t) { + if (dynamic_decode_input_tensors.find(t->first) == dynamic_decode_input_tensors.end()) { + dynamic_decode_input_tensors.insert(*t); + } + } + + // common outputs + bool subbatch_should_stop = false; + std::unordered_map dynamic_decode_output_tensors{ + {"output_ids", + Tensor{MEMORY_GPU, TYPE_INT32, {max_seq_len, batch_size, beam_width}, output_ids_buf_}}, + {"finished", Tensor{MEMORY_GPU, TYPE_BOOL, {batch_size * beam_width}, finished_buf_}}, + // cum_log_probs is necessary for beam search, while it is optional for sampling. + {"cum_log_probs", + Tensor{MEMORY_GPU, + TYPE_FP32, + {batch_size * beam_width}, + ((beam_width > 1) || (output_tensors->count("cum_log_probs") > 0)) ? cum_log_probs_ : + nullptr}}, + {"output_log_probs", + Tensor{MEMORY_GPU, + TYPE_FP32, + {max_seq_len, batch_size, beam_width}, + output_tensors->count("output_log_probs") > 0 + && output_tensors->at("output_log_probs").data != nullptr ? + output_log_probs_buf_ : + nullptr}}, + {"parent_ids", + Tensor{MEMORY_GPU, TYPE_INT32, {max_seq_len, batch_size, beam_width}, parent_ids_buf_}}, + {"sequence_length", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size * beam_width}, sequence_lengths_}}, + {"tgt_cache_indirection", + Tensor{MEMORY_GPU, + TYPE_INT32, + {local_batch_size, beam_width, max_output_seq_len}, + cache_indirections_[tgt_indir_idx] + id_offset * max_output_seq_len}}, + {"should_stop", Tensor{MEMORY_CPU, TYPE_BOOL, {1}, &subbatch_should_stop}}}; + + for (auto t = output_tensors->begin(); t != output_tensors->end(); ++t) { + // Handle exceptions. 
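                    // (For reference: "cum_log_probs" and "output_log_probs" are skipped here because
                    // the dynamic decode layer writes them into the internal cum_log_probs_ /
                    // output_log_probs_buf_ buffers bound above; the caller-provided tensors are
                    // filled from those buffers later in setOutputTensors(), where output_log_probs
                    // is additionally transposed out of the internal [max_seq_len, batch, beam] layout.)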
+ if (t->first == "cum_log_probs" || t->first == "output_log_probs") { + continue; + } + dynamic_decode_output_tensors.insert(*t); + } + + dynamic_decode_layer_->forward(&dynamic_decode_output_tensors, &dynamic_decode_input_tensors); + *generation_should_stop_ &= subbatch_should_stop; + } + } + + if (pipeline_para_.world_size_ > 1) { + ftNcclGroupStart(); + ftNcclBroadCast(output_ids_buf_ + step * batch_size * beam_width, + batch_size * beam_width, + pipeline_para_.world_size_ - 1, + pipeline_para_, + stream_); + + ftNcclBroadCast( + sequence_lengths_, batch_size * beam_width, pipeline_para_.world_size_ - 1, pipeline_para_, stream_); + + ftNcclBroadCast(generation_should_stop_, 1, pipeline_para_.world_size_ - 1, pipeline_para_, stream_); + + if (beam_width > 1) { + ftNcclBroadCast(cache_indirections_[tgt_indir_idx], + batch_size * beam_width * max_output_seq_len, + pipeline_para_.world_size_ - 1, + pipeline_para_, + stream_); + } + ftNcclGroupEnd(); + // throw errors when detected + ftNcclStreamSynchronize(tensor_para_, pipeline_para_, stream_); + sync_check_cuda_error(); + } + + if (*generation_should_stop_) { + break; + } + if (token_generated_cb_ && step + 1 < (int)max_output_seq_len) { + setOutputTensors(output_tensors, input_tensors, max_output_seq_len); + sendTensorsToFirstPipelineNode(output_tensors, input_tensors); + + if (pipeline_para_.rank_ == 0 && tensor_para_.rank_ == 0) { + token_generated_cb_(output_tensors, token_generated_ctx_); + } + } + if (step == max_input_length) { + /* We have just finished processing input: update the padding count: + * total_padding_count += (max_input_length - input_lengths) + * if has prefix prompts, += (max_prefix_prompt_length - prompt_length) + */ + invokeUpdatePaddingCount(tiled_total_padding_count_, + (const int*)(input_tensors->at("input_lengths").data), // not_tiled + has_prefix_prompt_ ? tiled_prompt_lengths_buf_ : (const int*)nullptr, + max_input_length, + has_prefix_prompt_ ? 
max_prefix_prompt_length : 0, + batch_size, + beam_width, + stream_); + } + } + + setOutputTensors(output_tensors, input_tensors, max_output_seq_len); + sendTensorsToFirstPipelineNode(output_tensors, input_tensors); +} + +template +void GptNeoX::sendTensorsToFirstPipelineNode(std::unordered_map* output_tensors, + const std::unordered_map* input_tensors) +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + if (pipeline_para_.world_size_ == 1) { + // throw errors when detected + ftNcclStreamSynchronize(tensor_para_, pipeline_para_, stream_); + return; + } + + const auto pp_rank = pipeline_para_.rank_; + + ftNcclGroupStart(); + for (auto const& it : *output_tensors) { + if (it.second.data == nullptr) { + continue; + } + + if (pp_rank == pipeline_para_.world_size_ - 1) { + ftNcclSend(it.second.getPtr(), it.second.sizeBytes(), 0, pipeline_para_, stream_); + } + else if (pp_rank == 0) { + ftNcclRecv(it.second.getPtr(), + it.second.sizeBytes(), + pipeline_para_.world_size_ - 1, + pipeline_para_, + stream_); + } + } + ftNcclGroupEnd(); + // throw errors when detected + ftNcclStreamSynchronize(tensor_para_, pipeline_para_, stream_); +} + +template +void GptNeoX::setOutputTensors(std::unordered_map* output_tensors, + const std::unordered_map* input_tensors, + const size_t max_output_seq_len) +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + if (pipeline_para_.rank_ != pipeline_para_.world_size_ - 1) { + return; + } + + const size_t batch_size = output_tensors->at("output_ids").shape[0]; + const size_t beam_width = output_tensors->at("output_ids").shape[1]; + int* sequence_lengths = output_tensors->at("sequence_length").getPtr(); + const int max_input_length = input_tensors->at("input_ids").shape[1]; + const size_t max_prefix_soft_prompt_length = + has_prefix_soft_prompt_ ? input_tensors->at("request_prompt_embedding").shape[1] : 0; + + cudaAutoCpy(sequence_lengths, sequence_lengths_, output_tensors->at("sequence_length").size(), stream_); + if (input_tensors->at("input_ids").shape[1] == 0) { + // TODO: D2D sequence_lenghts + if (beam_width > 1) { + // For beam search, do gather_tree + // take output_parent_ids as inter buffer + invokeGatherTree(transposed_output_ids_buf_, + sequence_lengths_, + max_output_seq_len, + batch_size, + beam_width, + output_ids_buf_ + batch_size * beam_width, + parent_ids_buf_ + batch_size * beam_width, + end_ids_buf_, + stream_); + + // transpose and take output_parent_ids as inter buffer + invokeTransposeAxis01((int*)output_tensors->at("output_ids").data, + transposed_output_ids_buf_, + max_output_seq_len - 1, + batch_size * beam_width, + 1, + stream_); + } + else { + // For sampling, only copy the results to output_tensor + invokeTransposeAxis01((int*)output_tensors->at("output_ids").data, + output_ids_buf_ + batch_size * beam_width, + max_output_seq_len - 1, + batch_size * beam_width, + 1, + stream_); + } + } + else { + // add sequence_length 1 here because the sequence_length of time step t is t - 1 + invokePlusScalar(sequence_lengths, 1, batch_size * beam_width, stream_); + + // For sampling, it is equivalent to all parent ids are 0. + gatherTreeParam param; + param.beams = transposed_output_ids_buf_; + param.max_sequence_lengths = sequence_lengths; + param.max_time = max_output_seq_len; + param.batch_size = batch_size; + param.beam_width = beam_width; + param.step_ids = output_ids_buf_; + param.parent_ids = beam_width == 1 ? 
nullptr : parent_ids_buf_; + param.end_tokens = end_ids_buf_; + param.max_input_length = max_input_length; + param.prefix_soft_prompt_lengths = + has_prefix_soft_prompt_ ? input_tensors->at("request_prompt_lengths").getPtr() : nullptr; + param.input_lengths = tiled_input_lengths_buf_; + param.max_prefix_soft_prompt_length = max_prefix_soft_prompt_length; + param.stream = stream_; + param.output_ids = (int*)output_tensors->at("output_ids").data; + invokeGatherTree(param); + sync_check_cuda_error(); + } + if ((output_tensors->count("output_log_probs") > 0 && output_tensors->at("output_log_probs").data != nullptr)) { + invokeTransposeAxis01(output_tensors->at("output_log_probs").getPtr(), + output_log_probs_buf_, + input_tensors->at("output_seq_len").max() - max_input_length, + batch_size * beam_width, + 1, + stream_); + } + // Return the cumulative log probability if requested. + if (output_tensors->count("cum_log_probs") > 0) { + Tensor cum_log_probs = output_tensors->at("cum_log_probs"); + FT_CHECK_WITH_INFO(cum_log_probs.size() == batch_size * beam_width, + "The shape of cum_log_probs does not match with batch_size x beam_width."); + cudaAutoCpy(cum_log_probs.getPtr(), cum_log_probs_, cum_log_probs.size(), stream_); + } +} + +template +size_t GptNeoX::getPipelineParallelRank() +{ + return pipeline_para_.rank_; +} + +template +size_t GptNeoX::getPipelineParallelSize() +{ + return pipeline_para_.world_size_; +} + +template +size_t GptNeoX::getTensorParallelRank() +{ + return tensor_para_.rank_; +} + +template +size_t GptNeoX::getTensorParallelSize() +{ + return tensor_para_.world_size_; +} + +template +bool* GptNeoX::getFinishBuffer() +{ + return finished_buf_; +} + +template class GptNeoX; +template class GptNeoX; + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/gptneox/GptNeoX.h b/src/fastertransformer/models/gptneox/GptNeoX.h new file mode 100644 index 000000000..60517939a --- /dev/null +++ b/src/fastertransformer/models/gptneox/GptNeoX.h @@ -0,0 +1,211 @@ +/* + * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
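// For reference: the per-step streaming callback declared further below in this header
// (registerCallback / callback_sig) is invoked from forward() after each generation step,
// on tensor/pipeline rank 0 only, with the partially filled output tensor map. A minimal,
// hypothetical handler might look like this:
//
//   void on_step(std::unordered_map<std::string, Tensor>* output_tensors, void* ctx)
//   {
//       const Tensor& ids = output_tensors->at("output_ids");
//       // stream or inspect the tokens generated so far ...
//   }
//   ...
//   gpt_neox.registerCallback(&on_step, nullptr);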
+ */ + +#pragma once + +#include +#include + +#include "src/fastertransformer/layers/DynamicDecodeLayer.h" +#include "src/fastertransformer/models/gptneox/GptNeoXContextDecoder.h" +#include "src/fastertransformer/models/gptneox/GptNeoXDecoder.h" +#include "src/fastertransformer/models/gptneox/GptNeoXWeight.h" +#include "src/fastertransformer/utils/custom_ar_comm.h" +#include "src/fastertransformer/utils/prompt_learning.h" + +namespace fastertransformer { + +template +class GptNeoX: public BaseLayer { +private: + // meta data + size_t head_num_; + size_t size_per_head_; + size_t inter_size_; + size_t num_layer_; + size_t vocab_size_; + size_t rotary_embedding_dim_; + + static constexpr bool neox_rotary_style_ = true; + static constexpr float layernorm_eps_ = 1e-5f; + + int start_id_; + int end_id_; + size_t hidden_units_; + + size_t local_head_num_; + NcclParam tensor_para_; + NcclParam pipeline_para_; + + std::shared_ptr custom_all_reduce_comm_; + int enable_custom_all_reduce_; + + size_t vocab_size_padded_; + const bool is_context_qk_buf_float_ = true; + + // Residual Type + const bool use_gptj_residual_ = true; + + // Prompt Learning Parameters + PromptLearningType prompt_learning_type_; + int prompt_learning_start_id_; // start_id for prompt_learning (only needed by prefix prompts) + bool has_prefix_prompt_; + bool has_prefix_soft_prompt_; + + GptNeoXDecoder* gpt_decoder_; + GptNeoXContextDecoder* gpt_context_decoder_; + DynamicDecodeLayer* dynamic_decode_layer_; + + void allocateBuffer() override; + void allocateBuffer( + size_t batch_size, size_t beam_width, size_t max_seq_len, size_t max_cache_seq_len, size_t max_input_len); + void freeBuffer() override; + + void initialize(); + +protected: + T* padded_embedding_kernel_; + T* padded_embedding_bias_; + const T* padded_embedding_kernel_ptr_; + + T* input_attention_mask_; + + T* decoder_input_buf_; + T* decoder_output_buf_; + T* normed_decoder_output_buf_; + + float* logits_buf_; + float* nccl_logits_buf_; + float* cum_log_probs_; + + bool* finished_buf_; + bool* h_finished_buf_; + int* sequence_lengths_ = nullptr; + int* tiled_total_padding_count_ = nullptr; + uint32_t* seq_limit_len_ = nullptr; + + T* key_cache_; + T* value_cache_; + int* cache_indirections_[2] = {nullptr, nullptr}; + + // prompt_learning weight_batch ptrs + const T** prompt_learning_weight_batch_; + int* tiled_prompt_lengths_buf_; // only needed by prefix prompts + + int* tiled_input_ids_buf_; + int* tiled_input_lengths_buf_; + int* transposed_output_ids_buf_; + int* output_ids_buf_; + int* parent_ids_buf_; + int* start_ids_buf_; + int* end_ids_buf_; + bool* masked_tokens_ = nullptr; + + bool* generation_should_stop_ = nullptr; + + T* context_decoder_input_buf_; + T* context_decoder_output_buf_; + float* output_log_probs_buf_; + + // function pointer callback + using callback_sig = void(std::unordered_map*, void*); + callback_sig* token_generated_cb_ = nullptr; + void* token_generated_ctx_ = nullptr; + + void setOutputTensors(std::unordered_map* output_tensors, + const std::unordered_map* input_tensors, + size_t max_seq_len); + void sendTensorsToFirstPipelineNode(std::unordered_map* output_tensors, + const std::unordered_map* input_tensors); + +public: + GptNeoX(size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + size_t vocab_size, + size_t rotary_embedding_dim, + int start_id, + int end_id, + int prompt_learning_start_id, // only needed by p/prompt-tuning + PromptLearningType prompt_learning_type, + bool use_gptj_residual, + float 
beam_search_diversity_rate, + size_t top_k, + float top_p, + unsigned long long random_seed, + float temperature, + float len_penalty, + float repetition_penalty, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + cudaDeviceProp* cuda_device_prop, + std::shared_ptr custom_all_reduce_comm = nullptr, + int enable_custom_all_reduce = 0); + + GptNeoX(size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + size_t vocab_size, + size_t rotary_embedding_dim, + int start_id, + int end_id, + int prompt_learning_start_id, // only needed by p/prompt-tuning + PromptLearningType prompt_learning_type, + bool use_gptj_residual, + float beam_search_diversity_rate, + size_t top_k, + float top_p, + unsigned long long random_seed, + float temperature, + float len_penalty, + float repetition_penalty, + NcclParam tensor_para, + NcclParam pipeline_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + cudaDeviceProp* cuda_device_prop = nullptr, + std::shared_ptr custom_all_reduce_comm = nullptr, + int enable_custom_all_reduce = 0); + + GptNeoX(GptNeoX const& GptNeoX); + + ~GptNeoX(); + + void forward(std::vector* output_tensors, + const std::vector* input_tensors, + const GptNeoXWeight* gpt_weights); + + void forward(std::unordered_map* output_tensors, + const std::unordered_map* input_tensors, + const GptNeoXWeight* gpt_weights); + + size_t getPipelineParallelRank(); + size_t getPipelineParallelSize(); + size_t getTensorParallelRank(); + size_t getTensorParallelSize(); + bool* getFinishBuffer(); + + void registerCallback(callback_sig* fn, void* ctx); + void unRegisterCallback(); +}; + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/gptneox/GptNeoXContextDecoder.cc b/src/fastertransformer/models/gptneox/GptNeoXContextDecoder.cc new file mode 100644 index 000000000..802c9db4f --- /dev/null +++ b/src/fastertransformer/models/gptneox/GptNeoXContextDecoder.cc @@ -0,0 +1,442 @@ +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
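// For reference, a minimal single-process instantiation of the GptNeoX class declared
// above might look roughly as follows (stream / cublas_wrapper / allocator setup omitted,
// all numeric arguments hypothetical). Note that the sampling-related constructor
// arguments are not stored by the constructors shown earlier in this patch; per-request
// values are taken from the runtime_* input tensors instead.
//
//   GptNeoX<half> gpt_neox(head_num, size_per_head, inter_size, num_layer,
//                          vocab_size, rotary_embedding_dim,
//                          start_id, end_id,
//                          prompt_learning_start_id, PromptLearningType::no_prompt,
//                          use_gptj_residual,
//                          0.0f,            // beam_search_diversity_rate
//                          1,               // top_k
//                          0.0f,            // top_p
//                          0ULL,            // random_seed
//                          1.0f,            // temperature
//                          0.0f,            // len_penalty
//                          1.0f,            // repetition_penalty
//                          stream, cublas_wrapper, allocator,
//                          false,           // is_free_buffer_after_forward
//                          &device_prop);
//   gpt_neox.forward(&output_tensors, &input_tensors, &gpt_weights);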
+ */ + +#include "src/fastertransformer/models/gptneox/GptNeoXContextDecoder.h" +#include "src/fastertransformer/kernels/gpt_kernels.h" + +#include "src/fastertransformer/layers/TensorParallelGeluFfnLayer.h" +#include "src/fastertransformer/layers/attention_layers/TensorParallelGptContextAttentionLayer.h" + +namespace fastertransformer { + +template +void GptNeoXContextDecoder::initialize() +{ + self_attention_layer_ = new TensorParallelGptContextAttentionLayer(0, // max_batch_size + 0, // max_seq_len + head_num_, + size_per_head_, + rotary_embedding_dim_, + neox_rotary_style_, + tensor_para_, + stream_, + cublas_wrapper_, + allocator_, + !use_gptj_residual_, + is_free_buffer_after_forward_, + is_qk_buf_float_, + false, + custom_all_reduce_comm_, + enable_custom_all_reduce_); + + ffn_layer_ = new TensorParallelGeluFfnLayer(0, // max_batch_size + 0, // max_seq_len + head_num_, + size_per_head_, + inter_size_, + tensor_para_, + stream_, + cublas_wrapper_, + allocator_, + !use_gptj_residual_, + is_free_buffer_after_forward_, + false, + 0, + false, // use_gated_activation = false; + custom_all_reduce_comm_, + enable_custom_all_reduce_); +} + +template +void GptNeoXContextDecoder::allocateBuffer() +{ + FT_CHECK(false); +} + +template +void GptNeoXContextDecoder::allocateBuffer(size_t batch_size, size_t seq_len) +{ + decoder_normed_input_ = reinterpret_cast( + allocator_->reMalloc(decoder_normed_input_, sizeof(T) * batch_size * seq_len * hidden_units_, false)); + self_attn_output_ = reinterpret_cast( + allocator_->reMalloc(self_attn_output_, sizeof(T) * batch_size * seq_len * hidden_units_, false)); + ffn_output_ = reinterpret_cast( + allocator_->reMalloc(ffn_output_, sizeof(T) * batch_size * seq_len * hidden_units_, false)); + decoder_layer_output_ = reinterpret_cast( + allocator_->reMalloc(decoder_layer_output_, sizeof(T) * batch_size * seq_len * hidden_units_, false)); + is_allocate_buffer_ = true; +} + +template +void GptNeoXContextDecoder::freeBuffer() +{ + if (is_allocate_buffer_ == true) { + allocator_->free((void**)(&decoder_normed_input_)); + allocator_->free((void**)(&self_attn_output_)); + allocator_->free((void**)(&ffn_output_)); + allocator_->free((void**)(&decoder_layer_output_)); + is_allocate_buffer_ = false; + } +} + +template +bool GptNeoXContextDecoder::isValidLayerParallelId(uint l) +{ + int local_num_layer = (int)(ceil(num_layer_ * 1.0f / pipeline_para_.world_size_)); + return l < num_layer_ && (l >= local_num_layer * pipeline_para_.rank_) + && (l < local_num_layer * (pipeline_para_.rank_ + 1)); +} + +template +bool GptNeoXContextDecoder::isFirstLayerParallelId(uint l) +{ + int local_num_layer = (int)(ceil(num_layer_ * 1.0f / pipeline_para_.world_size_)); + return l < num_layer_ && (l == local_num_layer * pipeline_para_.rank_); +} + +template +bool GptNeoXContextDecoder::isLastLayerParallelId(uint l) +{ + int local_num_layer = (int)(ceil(num_layer_ * 1.0f / pipeline_para_.world_size_)); + return l < num_layer_ && (l == local_num_layer * (pipeline_para_.rank_ + 1) - 1); +} + +template +int GptNeoXContextDecoder::getFirstLayerParallelId() +{ + int local_num_layer = (int)(ceil(num_layer_ * 1.0f / pipeline_para_.world_size_)); + return local_num_layer * pipeline_para_.rank_; +} + +template +GptNeoXContextDecoder::GptNeoXContextDecoder(size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + size_t rotary_embedding_dim, + bool neox_rotary_style, + bool use_gptj_residual, + float layernorm_eps, + NcclParam tensor_para, + NcclParam pipeline_para, + 
cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool is_qk_buf_float, + std::shared_ptr custom_all_reduce_comm, + int enable_custom_all_reduce): + BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward), + head_num_(head_num), + size_per_head_(size_per_head), + inter_size_(inter_size), + num_layer_(num_layer), + rotary_embedding_dim_(rotary_embedding_dim), + neox_rotary_style_(neox_rotary_style), + use_gptj_residual_(use_gptj_residual), + layernorm_eps_(layernorm_eps), + hidden_units_(head_num * size_per_head), + tensor_para_(tensor_para), + pipeline_para_(pipeline_para), + is_qk_buf_float_(is_qk_buf_float), + custom_all_reduce_comm_(custom_all_reduce_comm), + enable_custom_all_reduce_(enable_custom_all_reduce) +{ + initialize(); +} + +template +GptNeoXContextDecoder::GptNeoXContextDecoder(GptNeoXContextDecoder const& decoder): + BaseLayer(decoder.stream_, decoder.cublas_wrapper_, decoder.allocator_, decoder.is_free_buffer_after_forward_), + head_num_(decoder.head_num_), + size_per_head_(decoder.size_per_head_), + inter_size_(decoder.inter_size_), + num_layer_(decoder.num_layer_), + rotary_embedding_dim_(decoder.rotary_embedding_dim_), + neox_rotary_style_(decoder.neox_rotary_style_), + use_gptj_residual_(decoder.use_gptj_residual_), + layernorm_eps_(decoder.layernorm_eps_), + hidden_units_(decoder.hidden_units_), + tensor_para_(decoder.tensor_para_), + pipeline_para_(decoder.pipeline_para_), + is_qk_buf_float_(decoder.is_qk_buf_float_), + custom_all_reduce_comm_(decoder.custom_all_reduce_comm_), + enable_custom_all_reduce_(decoder.enable_custom_all_reduce_) +{ + initialize(); +} + +template +GptNeoXContextDecoder::~GptNeoXContextDecoder() +{ + delete self_attention_layer_; + delete ffn_layer_; + freeBuffer(); +} + +template +void GptNeoXContextDecoder::forward(std::vector* output_tensors, + const std::vector* input_tensors, + const std::vector*>* gpt_decoder_layer_weight) +{ + std::unordered_map input_tensors_map{{"decoder_input", input_tensors->at(0)}, + {"attention_mask", input_tensors->at(1)}, + {"input_lengths", input_tensors->at(2)}}; + std::unordered_map output_tensors_map{{"decoder_output", output_tensors->at(0)}, + {"key_cache", output_tensors->at(1)}, + {"value_cache", output_tensors->at(2)}, + {"last_token_hidden_units", output_tensors->at(3)}}; + + forward(&output_tensors_map, &input_tensors_map, gpt_decoder_layer_weight); +} + +template +void GptNeoXContextDecoder::forward(std::unordered_map* output_tensors, + const std::unordered_map* input_tensors, + const std::vector*>* gpt_decoder_layer_weight) +{ + // input tensors: + // decoder_input [batch_size, seq_len, hidden_dimension], + // attention_mask [batch_size, 1, seq_len, seq_len + max_prompt_length] + // input_lengths [batch_size] + // d_prefix_prompt_batch [batch_size], + // each element contains ptr with buffer shape[2, local_head_num_, prompt_length, size_per_head] + // prefix_prompt_lengths [batch size] + + // output tensors: + // decoder_output [batch_size, seq_len, hidden_dimension], + // key_cache [num_layer, batch, local_head_num, size_per_head // x, max_seq_len, x] + // value_cache [num_layer, batch, local_head_num, max_seq_len, size_per_head] + // last_token_hidden_units [batch_size, hidden_dimension] + + // To use layer/pipeline parallelism, we view the shape of 'batch_size' to 'ite * local_batch_size'. + // For example, the shape of decoder_input becomes [ite, batch_size, seq_len, hidden_dimension] during + // computing. 
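    /*
     * For reference, a worked example of the cache layout documented above, assuming
     * T = half (so x = 16 / sizeof(T) = 8), size_per_head = 128, local_head_num = 16,
     * max_seq_len = 1024:
     *     key_cache   : [num_layer, batch, 16, 128 / 8, 1024, 8]
     *     value_cache : [num_layer, batch, 16, 1024, 128]
     * i.e. the key cache keeps 8 half elements (16 bytes) of each key contiguous in its
     * innermost dimension, presumably so the attention kernels can load one 16-byte
     * vector per access, while the value cache stays in plain [seq_len, size_per_head]
     * order.
     */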
+ + FT_CHECK(input_tensors->size() == 5); + FT_CHECK(output_tensors->size() == 4); + + const int batch_size = input_tensors->at("decoder_input").shape[0]; + const int seq_len = input_tensors->at("decoder_input").shape[1]; + const int max_prompt_length = + input_tensors->at("attention_mask").shape[3] - input_tensors->at("attention_mask").shape[2]; + const DataType data_type = getTensorType(); + allocateBuffer(batch_size, seq_len); + + T* decoder_input = (T*)input_tensors->at("decoder_input").data; + T* decoder_output = (T*)output_tensors->at("decoder_output").data; + const T* attention_mask = (const T*)input_tensors->at("attention_mask").data; + const T** d_prefix_prompt_batch = (const T**)input_tensors->at("d_prefix_prompt_batch").data; + const int* d_prefix_prompt_lengths = (const int*)input_tensors->at("d_prefix_prompt_lengths").data; + + const int local_batch_size = getLocalBatchSize(batch_size, seq_len, pipeline_para_.world_size_); + FT_CHECK(batch_size % local_batch_size == 0); + const int iteration_num = batch_size / local_batch_size; + + Tensor& k_cache = output_tensors->at("key_cache"); + Tensor& v_cache = output_tensors->at("value_cache"); + std::vector self_k_cache_size; + self_k_cache_size.push_back(local_batch_size); + for (auto t = k_cache.shape.begin() + 2; t != k_cache.shape.end(); ++t) { + self_k_cache_size.push_back(*t); + } + std::vector self_v_cache_size; + self_v_cache_size.push_back(local_batch_size); + for (auto t = v_cache.shape.begin() + 2; t != v_cache.shape.end(); ++t) { + self_v_cache_size.push_back(*t); + } + + for (int ite = 0; ite < iteration_num; ite++) { + for (int l = 0; l < num_layer_; l++) { + if (isValidLayerParallelId(l) == false) { + continue; + } + + const bool is_final = false; // TODO(bhsueh) remove this flag + T* layer_input = + ((l == 0) ? decoder_input : decoder_layer_output_) + ite * local_batch_size * seq_len * hidden_units_; + T* layer_output = ((l == num_layer_ - 1) ? decoder_output : decoder_layer_output_) + + ite * local_batch_size * seq_len * hidden_units_; + + if (isFirstLayerParallelId(l) && pipeline_para_.rank_ != 0 && pipeline_para_.world_size_ > 1) { + int data_size = local_batch_size * seq_len * hidden_units_ / tensor_para_.world_size_; + ftNcclRecv(layer_input + data_size * tensor_para_.rank_, + data_size, + pipeline_para_.rank_ - 1, + pipeline_para_, + stream_); + if (tensor_para_.world_size_ > 1) { + ftNcclAllGather(layer_input, layer_input, data_size, tensor_para_.rank_, tensor_para_, stream_); + } + } + + invokeGeneralLayerNorm(decoder_normed_input_, + layer_input, + gpt_decoder_layer_weight->at(l)->pre_layernorm_weights.gamma, + gpt_decoder_layer_weight->at(l)->pre_layernorm_weights.beta, + layernorm_eps_, + local_batch_size * seq_len, + hidden_units_, + stream_); + sync_check_cuda_error(); + + std::vector self_attention_input_tensors{ + Tensor{MEMORY_GPU, + data_type, + {(size_t)(local_batch_size * seq_len), (size_t)hidden_units_}, + decoder_normed_input_}, + Tensor{MEMORY_GPU, + data_type, + {(size_t)local_batch_size, (size_t)1, (size_t)seq_len, (size_t)(seq_len + max_prompt_length)}, + attention_mask + local_batch_size * ite * seq_len * (seq_len + max_prompt_length)}, + Tensor{MEMORY_CPU, TYPE_BOOL, {(size_t)1}, &is_final}, + Tensor{MEMORY_GPU, + data_type, + {(size_t)local_batch_size}, + d_prefix_prompt_batch != nullptr ? d_prefix_prompt_batch + ite * local_batch_size : nullptr}, + Tensor{MEMORY_GPU, + TYPE_INT32, + {(size_t)local_batch_size}, + d_prefix_prompt_lengths != nullptr ? 
d_prefix_prompt_lengths + ite * local_batch_size : nullptr}, + Tensor{MEMORY_CPU, TYPE_INT32, {(size_t)1}, &l}}; + + size_t cache_offset = l - getFirstLayerParallelId(); + for (auto t = k_cache.shape.begin() + 1; t != k_cache.shape.end(); ++t) { + cache_offset *= *t; + }; + size_t ite_cache_offset = ite * local_batch_size; + for (auto t = k_cache.shape.begin() + 2; t != k_cache.shape.end(); ++t) { + ite_cache_offset *= *t; + } + cache_offset += ite_cache_offset; + + std::vector self_attention_output_tensors{ + Tensor{MEMORY_GPU, + data_type, + {(size_t)(local_batch_size * seq_len), (size_t)hidden_units_}, + self_attn_output_}, + Tensor{MEMORY_GPU, data_type, self_k_cache_size, ((const T*)k_cache.data) + cache_offset}, + Tensor{MEMORY_GPU, data_type, self_v_cache_size, ((const T*)v_cache.data) + cache_offset}}; + + self_attention_layer_->forward(&self_attention_output_tensors, + &self_attention_input_tensors, + &gpt_decoder_layer_weight->at(l)->self_attention_weights); + + if (is_final == false) { + if (use_gptj_residual_) { + invokeGeneralLayerNorm(decoder_normed_input_, + layer_input, + gpt_decoder_layer_weight->at(l)->post_attention_layernorm_weights.gamma, + gpt_decoder_layer_weight->at(l)->post_attention_layernorm_weights.beta, + layernorm_eps_, + local_batch_size * seq_len, + hidden_units_, + stream_); + } + else { + invokeGeneralAddBiasResidualPreLayerNorm( + self_attn_output_, + decoder_normed_input_, + layer_input, + gpt_decoder_layer_weight->at(l)->post_attention_layernorm_weights.gamma, + gpt_decoder_layer_weight->at(l)->post_attention_layernorm_weights.beta, + gpt_decoder_layer_weight->at(l)->self_attention_weights.attention_output_weight.bias, + layernorm_eps_, + local_batch_size * seq_len, + hidden_units_, + stream_); + } + + std::vector ffn_input_tensors{ + Tensor{MEMORY_GPU, + data_type, + {(size_t)(local_batch_size * seq_len), (size_t)hidden_units_}, + decoder_normed_input_}}; + std::vector ffn_output_tensors{ + Tensor{MEMORY_GPU, + data_type, + {(size_t)(local_batch_size * seq_len), (size_t)hidden_units_}, + use_gptj_residual_ ? 
ffn_output_ : layer_output}}; + ffn_layer_->forward( + &ffn_output_tensors, &ffn_input_tensors, &gpt_decoder_layer_weight->at(l)->ffn_weights); + + if (use_gptj_residual_) { + // Original workflow: + // layer_output = layer_input + reduceSum(ffn_output + self_attn_output + ffn_output_bias) + // Our workflow: + // layer_output = reduceSum(ffn_output + self_attn_output + ffn_output_bias + layer_input / + // TP_size) + // They are equivalent on math, but we can use same buffer for layer_input and layer_output + + invokeAddBiasAttentionFfnResidual(layer_output, + ffn_output_, + self_attn_output_, + layer_input, + gpt_decoder_layer_weight->at(l)->ffn_weights.output_weight.bias, + local_batch_size * seq_len, + hidden_units_, + tensor_para_.world_size_, + stream_); + if (tensor_para_.world_size_ > 1) { + ftNcclAllReduceSum(layer_output, + layer_output, + local_batch_size * seq_len * hidden_units_, + tensor_para_, + stream_); + } + } + else { + invokeAddBiasResidual(layer_output, + self_attn_output_, + gpt_decoder_layer_weight->at(l)->ffn_weights.output_weight.bias, + local_batch_size * seq_len, + hidden_units_, + stream_); + } + + sync_check_cuda_error(); + + if (isLastLayerParallelId(l) && pipeline_para_.rank_ != pipeline_para_.world_size_ - 1 + && pipeline_para_.world_size_ > 1) { + int data_size = local_batch_size * seq_len * hidden_units_ / tensor_para_.world_size_; + ftNcclSend(layer_output + data_size * tensor_para_.rank_, + data_size, + pipeline_para_.rank_ + 1, + pipeline_para_, + stream_); + } + } + } + } + + // TODO(bhsueh) We could optimize this point by only computing the last token for the last layer + invokeLookupHiddenStateOfLastToken((T*)output_tensors->at("last_token_hidden_units").data, + (T*)output_tensors->at("decoder_output").data, + (int*)input_tensors->at("input_lengths").data, + seq_len, + batch_size, + hidden_units_, + stream_); + sync_check_cuda_error(); + if (is_free_buffer_after_forward_ == true) { + freeBuffer(); + } +} + +template class GptNeoXContextDecoder; +template class GptNeoXContextDecoder; + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/gptneox/GptNeoXContextDecoder.h b/src/fastertransformer/models/gptneox/GptNeoXContextDecoder.h new file mode 100644 index 000000000..48e81331d --- /dev/null +++ b/src/fastertransformer/models/gptneox/GptNeoXContextDecoder.h @@ -0,0 +1,111 @@ +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include + +#include "src/fastertransformer/kernels/add_residual_kernels.h" +#include "src/fastertransformer/kernels/layernorm_kernels.h" +#include "src/fastertransformer/layers/BaseLayer.h" +#include "src/fastertransformer/layers/FfnLayer.h" +#include "src/fastertransformer/layers/attention_layers/BaseAttentionLayer.h" +#include "src/fastertransformer/models/gptneox/GptNeoXDecoderLayerWeight.h" +#include "src/fastertransformer/utils/Tensor.h" +#include "src/fastertransformer/utils/allocator.h" +#include "src/fastertransformer/utils/cublasMMWrapper.h" +#include "src/fastertransformer/utils/custom_ar_comm.h" +#include "src/fastertransformer/utils/nccl_utils.h" + +namespace fastertransformer { + +template +class GptNeoXContextDecoder: public BaseLayer { +private: + // meta data + size_t head_num_; + size_t size_per_head_; + size_t inter_size_; + size_t num_layer_; + size_t rotary_embedding_dim_; + bool neox_rotary_style_; + bool use_gptj_residual_; + float layernorm_eps_; + + // calculated data + size_t hidden_units_; + + NcclParam tensor_para_; + NcclParam pipeline_para_; + + std::shared_ptr custom_all_reduce_comm_; + int enable_custom_all_reduce_; + + bool is_qk_buf_float_; + + BaseAttentionLayer* self_attention_layer_; + FfnLayer* ffn_layer_; + + void allocateBuffer() override; + void allocateBuffer(size_t batch_size, size_t seq_len); + void freeBuffer() override; + + bool isValidLayerParallelId(uint l); + bool isFirstLayerParallelId(uint l); + bool isLastLayerParallelId(uint l); + int getFirstLayerParallelId(); + + void initialize(); + +protected: + T* decoder_normed_input_ = nullptr; + T* self_attn_output_ = nullptr; + T* ffn_output_ = nullptr; + T* decoder_layer_output_ = nullptr; + +public: + GptNeoXContextDecoder(size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + size_t rotary_embedding_dim, + bool neox_rotary_style, + bool use_gptj_residual, + float layernorm_eps, + NcclParam tensor_para, + NcclParam pipeline_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool is_qk_buf_float, + std::shared_ptr custom_all_reduce_comm = nullptr, + int enable_custom_all_reduce_ = 0); + + GptNeoXContextDecoder(GptNeoXContextDecoder const& decoder); + + ~GptNeoXContextDecoder(); + + void forward(std::vector* output_tensors, + const std::vector* input_tensors, + const std::vector*>* decoder_layer_weights); + + void forward(std::unordered_map* output_tensors, + const std::unordered_map* input_tensors, + const std::vector*>* gpt_decoder_layer_weight); +}; + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/gptneox/GptNeoXDecoder.cc b/src/fastertransformer/models/gptneox/GptNeoXDecoder.cc new file mode 100644 index 000000000..fd8fe3bd2 --- /dev/null +++ b/src/fastertransformer/models/gptneox/GptNeoXDecoder.cc @@ -0,0 +1,385 @@ +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/fastertransformer/models/gptneox/GptNeoXDecoder.h" +#include "src/fastertransformer/layers/TensorParallelGeluFfnLayer.h" +#include "src/fastertransformer/layers/attention_layers/TensorParallelDecoderSelfAttentionLayer.h" + +namespace fastertransformer { + +template +void GptNeoXDecoder::initialize() +{ + self_attention_layer_ = new TensorParallelDecoderSelfAttentionLayer(0, // max_batch_size + head_num_, + size_per_head_, + rotary_embedding_dim_, + neox_rotary_style_, + tensor_para_, + stream_, + cublas_wrapper_, + allocator_, + !use_gptj_residual_, + is_free_buffer_after_forward_, + false, + 0, + custom_all_reduce_comm_, + enable_custom_all_reduce_); + + ffn_layer_ = new TensorParallelGeluFfnLayer(0, // max_batch_size + 1, + head_num_, + size_per_head_, + inter_size_, + tensor_para_, + stream_, + cublas_wrapper_, + allocator_, + !use_gptj_residual_, + is_free_buffer_after_forward_, + false, + 0, + false, // use_gated_activation = false; + custom_all_reduce_comm_, + enable_custom_all_reduce_); +} + +template +void GptNeoXDecoder::allocateBuffer() +{ + FT_CHECK(false); +} + +template +void GptNeoXDecoder::allocateBuffer(size_t batch_size) +{ + decoder_normed_input_ = reinterpret_cast( + allocator_->reMalloc(decoder_normed_input_, sizeof(T) * batch_size * hidden_units_, false)); + self_attn_output_ = + reinterpret_cast(allocator_->reMalloc(self_attn_output_, sizeof(T) * batch_size * hidden_units_, false)); + ffn_output_ = + reinterpret_cast(allocator_->reMalloc(ffn_output_, sizeof(T) * batch_size * hidden_units_, false)); + decoder_layer_output_ = reinterpret_cast( + allocator_->reMalloc(decoder_layer_output_, sizeof(T) * batch_size * hidden_units_, false)); + is_allocate_buffer_ = true; +} + +template +void GptNeoXDecoder::freeBuffer() +{ + if (is_allocate_buffer_ == true) { + allocator_->free((void**)(&decoder_normed_input_)); + allocator_->free((void**)(&self_attn_output_)); + allocator_->free((void**)(&ffn_output_)); + allocator_->free((void**)(&decoder_layer_output_)); + is_allocate_buffer_ = false; + } +} + +template +bool GptNeoXDecoder::isValidLayerParallelId(uint l) +{ + int local_num_layer = (int)(ceil(num_layer_ * 1.0f / pipeline_para_.world_size_)); + return l < num_layer_ && (l >= local_num_layer * pipeline_para_.rank_) + && (l < local_num_layer * (pipeline_para_.rank_ + 1)); +} + +template +bool GptNeoXDecoder::isFirstLayerParallelId(uint l) +{ + int local_num_layer = (int)(ceil(num_layer_ * 1.0f / pipeline_para_.world_size_)); + return l < num_layer_ && (l == local_num_layer * pipeline_para_.rank_); +} + +template +bool GptNeoXDecoder::isLastLayerParallelId(uint l) +{ + int local_num_layer = (int)(ceil(num_layer_ * 1.0f / pipeline_para_.world_size_)); + return l < num_layer_ && (l == local_num_layer * (pipeline_para_.rank_ + 1) - 1); +} + +template +int GptNeoXDecoder::getFirstLayerParallelId() +{ + int local_num_layer = (int)(ceil(num_layer_ * 1.0f / pipeline_para_.world_size_)); + return local_num_layer * pipeline_para_.rank_; +} + +template +GptNeoXDecoder::GptNeoXDecoder(size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + size_t rotary_embedding_dim, + bool neox_rotary_style, + bool use_gptj_residual, + float layernorm_eps, + NcclParam tensor_para, + NcclParam pipeline_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + std::shared_ptr 
custom_all_reduce_comm, + int enable_custom_all_reduce): + BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward), + head_num_(head_num), + size_per_head_(size_per_head), + inter_size_(inter_size), + num_layer_(num_layer), + rotary_embedding_dim_(rotary_embedding_dim), + neox_rotary_style_(neox_rotary_style), + use_gptj_residual_(use_gptj_residual), + layernorm_eps_(layernorm_eps), + hidden_units_(head_num_ * size_per_head), + tensor_para_(tensor_para), + pipeline_para_(pipeline_para), + custom_all_reduce_comm_(custom_all_reduce_comm), + enable_custom_all_reduce_(enable_custom_all_reduce) +{ + initialize(); +} + +template +GptNeoXDecoder::GptNeoXDecoder(GptNeoXDecoder const& decoder): + BaseLayer(decoder.stream_, decoder.cublas_wrapper_, decoder.allocator_, decoder.is_free_buffer_after_forward_), + head_num_(decoder.head_num_), + size_per_head_(decoder.size_per_head_), + inter_size_(decoder.inter_size_), + num_layer_(decoder.num_layer_), + rotary_embedding_dim_(decoder.rotary_embedding_dim_), + neox_rotary_style_(decoder.neox_rotary_style_), + use_gptj_residual_(decoder.use_gptj_residual_), + layernorm_eps_(decoder.layernorm_eps_), + hidden_units_(decoder.hidden_units_), + tensor_para_(decoder.tensor_para_), + pipeline_para_(decoder.pipeline_para_), + custom_all_reduce_comm_(decoder.custom_all_reduce_comm_), + enable_custom_all_reduce_(decoder.enable_custom_all_reduce_) +{ + initialize(); +} + +template +GptNeoXDecoder::~GptNeoXDecoder() +{ + delete self_attention_layer_; + delete ffn_layer_; + freeBuffer(); +} + +template +void GptNeoXDecoder::forward(std::vector* output_tensors, + const std::vector* input_tensors, + const std::vector*>* gpt_decoder_layer_weight) +{ + FT_CHECK(false); +} + +template +void GptNeoXDecoder::forward(std::unordered_map* output_tensors, + const std::unordered_map* input_tensors, + const std::vector*>* gpt_decoder_layer_weight) +{ + // input tensors: + // decoder_input [local_batch_size, hidden_dimension], + // finished [local_batch_size], + // sequence_lengths [local_batch_size] + // total_padding_tokens [local_batch_size], + // max_input_length [1] on cpu + // d_prefix_prompt_lengths [local_batch_size],on GPU + // max_prefix_prompt_length [1] on cpu + // step [1] on cpu + // ite [1] on cpu + // cache_indirection [local_batch_size / beam_width, beam_width, memory_len] + // Here, local_batch_size contains the beam_width, so local_batch_size / beam_width + // is real local_batch_size. 
+ // masked_tokens[local_batch_size, memory_len] + + // output tensors: + // decoder_output [local_batch_size, hidden_dimension], + // key_cache [num_layer, batch_size, head_num, size_per_head // x, memory_len, x] + // value_cache [num_layer, batch_size, head_num, memory_len, size_per_head] + + FT_CHECK(input_tensors->size() == 11); + FT_CHECK(output_tensors->size() == 3); + + const DataType data_type = getTensorType(); + const size_t local_batch_size = input_tensors->at("decoder_input").shape[0]; + allocateBuffer(local_batch_size); + const int ite = *((int*)(input_tensors->at("ite").data)); + + T* decoder_input = (T*)input_tensors->at("decoder_input").data; + T* decoder_output = (T*)output_tensors->at("decoder_output").data; + + Tensor& k_cache = output_tensors->at("key_cache"); + Tensor& v_cache = output_tensors->at("value_cache"); + std::vector self_k_cache_size; + self_k_cache_size.push_back(local_batch_size); + for (auto t = k_cache.shape.begin() + 2; t != k_cache.shape.end(); ++t) { + self_k_cache_size.push_back(*t); + } + std::vector self_v_cache_size; + self_v_cache_size.push_back(local_batch_size); + for (auto t = v_cache.shape.begin() + 2; t != v_cache.shape.end(); ++t) { + self_v_cache_size.push_back(*t); + } + + for (uint l = 0; l < num_layer_; l++) { + if (isValidLayerParallelId(l) == false) { + continue; + } + T* layer_input = (l == 0) ? decoder_input : decoder_layer_output_; + T* layer_output = (l == num_layer_ - 1) ? decoder_output : decoder_layer_output_; + + if (isFirstLayerParallelId(l) == true && pipeline_para_.rank_ != 0 && pipeline_para_.world_size_ > 1) { + int data_size = local_batch_size * hidden_units_ / tensor_para_.world_size_; + // ftNcclRecv(layer_input, local_batch_size * hidden_units_, pipeline_para_.rank_ - 1, pipeline_para_, + // stream_); + + ftNcclRecv(layer_input + data_size * tensor_para_.rank_, + data_size, + pipeline_para_.rank_ - 1, + pipeline_para_, + stream_); + if (tensor_para_.world_size_ > 1) { + ftNcclAllGather(layer_input, layer_input, data_size, tensor_para_.rank_, tensor_para_, stream_); + } + } + + invokeGeneralLayerNorm(decoder_normed_input_, + layer_input, + gpt_decoder_layer_weight->at(l)->pre_layernorm_weights.gamma, + gpt_decoder_layer_weight->at(l)->pre_layernorm_weights.beta, + layernorm_eps_, + local_batch_size, + hidden_units_, + stream_); + sync_check_cuda_error(); + + std::vector self_attention_input_tensors{ + Tensor{MEMORY_GPU, data_type, {local_batch_size, hidden_units_}, decoder_normed_input_}, + input_tensors->at("finished"), + input_tensors->at("sequence_lengths"), + input_tensors->at("total_padding_tokens"), + input_tensors->at("d_prefix_prompt_lengths"), + input_tensors->at("max_prefix_prompt_length"), + input_tensors->at("max_input_length"), + input_tensors->at("step"), + input_tensors->at("cache_indirection"), + input_tensors->at("masked_tokens")}; + + size_t cache_offset = l - getFirstLayerParallelId(); + for (auto t = k_cache.shape.begin() + 1; t != k_cache.shape.end(); ++t) { + cache_offset *= *t; + }; + size_t ite_cache_offset = ite * local_batch_size; + for (auto t = k_cache.shape.begin() + 2; t != k_cache.shape.end(); ++t) { + ite_cache_offset *= *t; + } + cache_offset += ite_cache_offset; + + std::vector self_attention_output_tensors{ + Tensor{MEMORY_GPU, data_type, {local_batch_size, hidden_units_}, self_attn_output_}, + Tensor{MEMORY_GPU, data_type, self_k_cache_size, ((const T*)k_cache.data) + cache_offset}, + Tensor{MEMORY_GPU, data_type, self_v_cache_size, ((const T*)v_cache.data) + cache_offset}}; + + 
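+        // How cache_offset above is formed (shapes follow the comment at the top of this
+        // function): k_cache is laid out as
+        //     [num_layer, batch_size, head_num, size_per_head / x, memory_len, x]
+        // so (l - getFirstLayerParallelId()) * prod(shape[1:]) skips to the slab owned by
+        // local layer l, and ite * local_batch_size * prod(shape[2:]) skips to micro-batch
+        // `ite` inside that slab. v_cache has the same per-layer slab size
+        // (batch_size * head_num * memory_len * size_per_head), so the same cache_offset is
+        // reused for both cache tensors below.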
self_attention_layer_->forward(&self_attention_output_tensors, + &self_attention_input_tensors, + &gpt_decoder_layer_weight->at(l)->self_attention_weights); + if (use_gptj_residual_) { + invokeGeneralLayerNorm(decoder_normed_input_, + layer_input, + gpt_decoder_layer_weight->at(l)->post_attention_layernorm_weights.gamma, + gpt_decoder_layer_weight->at(l)->post_attention_layernorm_weights.beta, + layernorm_eps_, + local_batch_size, + hidden_units_, + stream_); + } + else { + invokeGeneralAddBiasResidualPreLayerNorm( + self_attn_output_, + decoder_normed_input_, + layer_input, + gpt_decoder_layer_weight->at(l)->post_attention_layernorm_weights.gamma, + gpt_decoder_layer_weight->at(l)->post_attention_layernorm_weights.beta, + gpt_decoder_layer_weight->at(l)->self_attention_weights.attention_output_weight.bias, + layernorm_eps_, + local_batch_size, + hidden_units_, + stream_); + } + + std::vector ffn_input_tensors{ + Tensor{MEMORY_GPU, data_type, {local_batch_size, hidden_units_}, decoder_normed_input_}}; + std::vector ffn_output_tensors{Tensor{ + MEMORY_GPU, data_type, {local_batch_size, hidden_units_}, use_gptj_residual_ ? ffn_output_ : layer_output}}; + ffn_layer_->forward(&ffn_output_tensors, &ffn_input_tensors, &gpt_decoder_layer_weight->at(l)->ffn_weights); + + if (use_gptj_residual_) { + // Original workflow: + // layer_output = layer_input + reduceSum(ffn_output + self_attn_output + ffn_output_bias) + // Our workflow: + // layer_output = reduceSum(ffn_output + self_attn_output + ffn_output_bias + layer_input / TP_size) + // They are equivalent on math, but we can use same buffer for layer_input and layer_output + invokeAddBiasAttentionFfnResidual(layer_output, + ffn_output_, + self_attn_output_, + layer_input, + gpt_decoder_layer_weight->at(l)->ffn_weights.output_weight.bias, + local_batch_size, + hidden_units_, + tensor_para_.world_size_, + stream_); + if (tensor_para_.world_size_ > 1) { + ftNcclAllReduceSum(layer_output, layer_output, local_batch_size * hidden_units_, tensor_para_, stream_); + } + } + else { + invokeAddBiasResidual(layer_output, + self_attn_output_, + gpt_decoder_layer_weight->at(l)->ffn_weights.output_weight.bias, + local_batch_size, + hidden_units_, + stream_); + } + + sync_check_cuda_error(); + + if (isLastLayerParallelId(l) == true && pipeline_para_.rank_ != pipeline_para_.world_size_ - 1 + && pipeline_para_.world_size_ > 1) { + int data_size = local_batch_size * hidden_units_ / tensor_para_.world_size_; + // ftNcclSend(layer_output, local_batch_size * hidden_units_, pipeline_para_.rank_ + 1, pipeline_para_, + // stream_); + + ftNcclSend(layer_output + data_size * tensor_para_.rank_, + data_size, + pipeline_para_.rank_ + 1, + pipeline_para_, + stream_); + } + } + + if (is_free_buffer_after_forward_ == true) { + freeBuffer(); + } +} + +template class GptNeoXDecoder; +template class GptNeoXDecoder; + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/gptneox/GptNeoXDecoder.h b/src/fastertransformer/models/gptneox/GptNeoXDecoder.h new file mode 100644 index 000000000..c9d7fa7a6 --- /dev/null +++ b/src/fastertransformer/models/gptneox/GptNeoXDecoder.h @@ -0,0 +1,104 @@ +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include "src/fastertransformer/kernels/add_residual_kernels.h" +#include "src/fastertransformer/kernels/layernorm_kernels.h" +#include "src/fastertransformer/layers/BaseLayer.h" +#include "src/fastertransformer/layers/FfnLayer.h" +#include "src/fastertransformer/layers/attention_layers/BaseAttentionLayer.h" +#include "src/fastertransformer/models/gptneox/GptNeoXDecoderLayerWeight.h" +#include "src/fastertransformer/utils/Tensor.h" +#include "src/fastertransformer/utils/allocator.h" +#include "src/fastertransformer/utils/cublasMMWrapper.h" +#include "src/fastertransformer/utils/custom_ar_comm.h" +#include "src/fastertransformer/utils/nccl_utils.h" + +namespace fastertransformer { + +template +class GptNeoXDecoder: public BaseLayer { +private: +protected: + void allocateBuffer() override; + void allocateBuffer(size_t batch_size); + void freeBuffer() override; + bool isValidLayerParallelId(uint l); + bool isFirstLayerParallelId(uint l); + bool isLastLayerParallelId(uint l); + int getFirstLayerParallelId(); + virtual void initialize(); + + // meta data + size_t head_num_; + size_t size_per_head_; + size_t inter_size_; + size_t num_layer_; + size_t rotary_embedding_dim_; + bool neox_rotary_style_; + bool use_gptj_residual_; + size_t hidden_units_; + float layernorm_eps_; + + NcclParam tensor_para_; + NcclParam pipeline_para_; + + std::shared_ptr custom_all_reduce_comm_; + int enable_custom_all_reduce_; + + T* decoder_normed_input_ = nullptr; + T* self_attn_output_ = nullptr; + T* ffn_output_ = nullptr; + T* decoder_layer_output_ = nullptr; + + BaseAttentionLayer* self_attention_layer_; + FfnLayer* ffn_layer_; + +public: + GptNeoXDecoder(size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + size_t rotary_embedding_dim, + bool neox_rotary_style, + bool use_gptj_residual, + float layernorm_eps, + NcclParam tensor_para, + NcclParam pipeline_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + std::shared_ptr custom_all_reduce_comm = nullptr, + int enable_custom_all_reduce_ = 0); + + GptNeoXDecoder(GptNeoXDecoder const& decoder); + + virtual ~GptNeoXDecoder(); + + virtual void forward(std::unordered_map* output_tensors, + const std::unordered_map* input_tensors, + const std::vector*>* decoder_layer_weights); + + virtual void forward(std::vector* output_tensors, + const std::vector* input_tensors, + const std::vector*>* decoder_layer_weights); +}; + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/gptneox/GptNeoXDecoderLayerWeight.cc b/src/fastertransformer/models/gptneox/GptNeoXDecoderLayerWeight.cc new file mode 100644 index 000000000..148329706 --- /dev/null +++ b/src/fastertransformer/models/gptneox/GptNeoXDecoderLayerWeight.cc @@ -0,0 +1,220 @@ +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/fastertransformer/models/gptneox/GptNeoXDecoderLayerWeight.h" +#include "src/fastertransformer/utils/memory_utils.h" + +namespace fastertransformer { + +template +GptNeoXDecoderLayerWeight::GptNeoXDecoderLayerWeight(const int hidden_units, + const int inter_size, + const int tensor_para_size, + const int tensor_para_rank, + const bool use_gptj_residual): + hidden_units_(hidden_units), + inter_size_(inter_size), + tensor_para_size_(tensor_para_size), + tensor_para_rank_(tensor_para_rank), + use_gptj_residual_(use_gptj_residual) +{ + mallocWeights(); + setWeightPtr(); +} + +template +GptNeoXDecoderLayerWeight::~GptNeoXDecoderLayerWeight() +{ + if (is_maintain_buffer == true) { + for (int i = 0; i < 12; i++) { + if (!use_gptj_residual_ && i != attention_dense_bias_weight_id) { + cudaFree(weights_ptr[i]); + } + } + + pre_layernorm_weights.beta = nullptr; + pre_layernorm_weights.gamma = nullptr; + self_attention_weights.query_weight.kernel = nullptr; + self_attention_weights.query_weight.bias = nullptr; + self_attention_weights.attention_output_weight.kernel = nullptr; + self_attention_weights.attention_output_weight.bias = nullptr; + post_attention_layernorm_weights.beta = nullptr; + post_attention_layernorm_weights.gamma = nullptr; + + ffn_weights.intermediate_weight.kernel = nullptr; + ffn_weights.intermediate_weight.bias = nullptr; + ffn_weights.output_weight.kernel = nullptr; + ffn_weights.output_weight.bias = nullptr; + is_maintain_buffer = false; + } +} + +template +GptNeoXDecoderLayerWeight::GptNeoXDecoderLayerWeight(const GptNeoXDecoderLayerWeight& other): + hidden_units_(other.hidden_units_), + inter_size_(other.inter_size_), + tensor_para_size_(other.tensor_para_size_), + tensor_para_rank_(other.tensor_para_rank_), + use_gptj_residual_(other.use_gptj_residual_) +{ + mallocWeights(); + cudaD2Dcpy(weights_ptr[0], other.weights_ptr[0], hidden_units_); + cudaD2Dcpy(weights_ptr[1], other.weights_ptr[1], hidden_units_); + cudaD2Dcpy(weights_ptr[2], other.weights_ptr[2], hidden_units_ * 3 * hidden_units_ / tensor_para_size_); + cudaD2Dcpy(weights_ptr[3], other.weights_ptr[3], 3 * hidden_units_ / tensor_para_size_); + cudaD2Dcpy(weights_ptr[4], other.weights_ptr[4], hidden_units_ / tensor_para_size_ * hidden_units_); + if (!use_gptj_residual_) { + cudaD2Dcpy(weights_ptr[5], other.weights_ptr[5], hidden_units_); + } + + cudaD2Dcpy(weights_ptr[6], other.weights_ptr[6], hidden_units_ * inter_size_ / tensor_para_size_); + cudaD2Dcpy(weights_ptr[7], other.weights_ptr[7], inter_size_ / tensor_para_size_); + cudaD2Dcpy(weights_ptr[8], other.weights_ptr[8], inter_size_ / tensor_para_size_ * hidden_units_); + cudaD2Dcpy(weights_ptr[9], other.weights_ptr[9], hidden_units_); + cudaD2Dcpy(weights_ptr[10], other.weights_ptr[10], hidden_units_); + cudaD2Dcpy(weights_ptr[11], other.weights_ptr[11], hidden_units_); + setWeightPtr(); +} + +template +GptNeoXDecoderLayerWeight& GptNeoXDecoderLayerWeight::operator=(const GptNeoXDecoderLayerWeight& other) +{ + hidden_units_ = other.hidden_units_; + inter_size_ = other.inter_size_; + tensor_para_size_ = 
other.tensor_para_size_; + tensor_para_rank_ = other.tensor_para_rank_; + use_gptj_residual_ = other.use_gptj_residual_; + + mallocWeights(); + + cudaD2Dcpy(weights_ptr[0], other.weights_ptr[0], hidden_units_); + cudaD2Dcpy(weights_ptr[1], other.weights_ptr[1], hidden_units_); + cudaD2Dcpy(weights_ptr[2], other.weights_ptr[2], hidden_units_ * 3 * hidden_units_ / tensor_para_size_); + cudaD2Dcpy(weights_ptr[3], other.weights_ptr[3], 3 * hidden_units_ / tensor_para_size_); + cudaD2Dcpy(weights_ptr[4], other.weights_ptr[4], hidden_units_ / tensor_para_size_ * hidden_units_); + if (!use_gptj_residual_) { + cudaD2Dcpy(weights_ptr[5], other.weights_ptr[5], hidden_units_); + } + cudaD2Dcpy(weights_ptr[6], other.weights_ptr[6], hidden_units_ * inter_size_ / tensor_para_size_); + cudaD2Dcpy(weights_ptr[7], other.weights_ptr[7], inter_size_ / tensor_para_size_); + cudaD2Dcpy(weights_ptr[8], other.weights_ptr[8], inter_size_ / tensor_para_size_ * hidden_units_); + cudaD2Dcpy(weights_ptr[9], other.weights_ptr[9], hidden_units_); + cudaD2Dcpy(weights_ptr[10], other.weights_ptr[10], hidden_units_); + cudaD2Dcpy(weights_ptr[11], other.weights_ptr[11], hidden_units_); + setWeightPtr(); + return *this; +} + +template +void GptNeoXDecoderLayerWeight::loadModel(std::string dir_path, FtCudaDataType model_file_type) +{ + FT_CHECK(is_maintain_buffer == true); + const std::string rank_spec = std::to_string(tensor_para_rank_); + + loadWeightFromBin( + weights_ptr[0], {(size_t)hidden_units_}, dir_path + ".input_layernorm.bias.bin", model_file_type); + loadWeightFromBin( + weights_ptr[1], {(size_t)hidden_units_}, dir_path + ".input_layernorm.weight.bin", model_file_type); + loadWeightFromBin(weights_ptr[2], + {(size_t)hidden_units_, (size_t)(3 * hidden_units_ / tensor_para_size_)}, + dir_path + ".attention.query_key_value.weight." + rank_spec + ".bin", + model_file_type); + + loadWeightFromBin(weights_ptr[3], + {(size_t)(3 * hidden_units_ / tensor_para_size_)}, + dir_path + ".attention.query_key_value.bias." + rank_spec + ".bin", + model_file_type); + + loadWeightFromBin(weights_ptr[4], + {(size_t)(hidden_units_ / tensor_para_size_), (size_t)hidden_units_}, + dir_path + ".attention.dense.weight." + rank_spec + ".bin", + model_file_type); + + if (!use_gptj_residual_) { + loadWeightFromBin( + weights_ptr[5], {(size_t)hidden_units_}, dir_path + ".attention.dense.bias.bin", model_file_type); + } + + loadWeightFromBin(weights_ptr[6], + {(size_t)hidden_units_, (size_t)(inter_size_ / tensor_para_size_)}, + dir_path + ".mlp.dense_h_to_4h.weight." + rank_spec + ".bin", + model_file_type); + loadWeightFromBin(weights_ptr[7], + {(size_t)(inter_size_ / tensor_para_size_)}, + dir_path + ".mlp.dense_h_to_4h.bias." + rank_spec + ".bin", + model_file_type); + loadWeightFromBin(weights_ptr[8], + {(size_t)(inter_size_ / tensor_para_size_), (size_t)hidden_units_}, + dir_path + ".mlp.dense_4h_to_h.weight." 
+ rank_spec + ".bin", + model_file_type); + if (use_gptj_residual_) { + loadWeightFromBin( + weights_ptr[9], {(size_t)hidden_units_}, dir_path + ".mlp.attention.bias.sum.bin", model_file_type); + } + else { + loadWeightFromBin( + weights_ptr[9], {(size_t)hidden_units_}, dir_path + ".mlp.dense_4h_to_h.bias.bin", model_file_type); + } + loadWeightFromBin( + weights_ptr[10], {(size_t)hidden_units_}, dir_path + ".post_attention_layernorm.bias.bin", model_file_type); + loadWeightFromBin( + weights_ptr[11], {(size_t)hidden_units_}, dir_path + ".post_attention_layernorm.weight.bin", model_file_type); +} + +template +void GptNeoXDecoderLayerWeight::setWeightPtr() +{ + pre_layernorm_weights.beta = weights_ptr[0]; + pre_layernorm_weights.gamma = weights_ptr[1]; + self_attention_weights.query_weight.kernel = weights_ptr[2]; + self_attention_weights.query_weight.bias = weights_ptr[3]; + self_attention_weights.attention_output_weight.kernel = weights_ptr[4]; + self_attention_weights.attention_output_weight.bias = use_gptj_residual_ ? nullptr : weights_ptr[5]; + + ffn_weights.intermediate_weight.kernel = weights_ptr[6]; + ffn_weights.intermediate_weight.bias = weights_ptr[7]; + ffn_weights.output_weight.kernel = weights_ptr[8]; + ffn_weights.output_weight.bias = weights_ptr[9]; + + post_attention_layernorm_weights.beta = weights_ptr[10]; + post_attention_layernorm_weights.gamma = weights_ptr[11]; + is_maintain_buffer = true; +} + +template +void GptNeoXDecoderLayerWeight::mallocWeights() +{ + deviceMalloc(&weights_ptr[0], hidden_units_); + deviceMalloc(&weights_ptr[1], hidden_units_); + deviceMalloc(&weights_ptr[2], hidden_units_ * 3 * hidden_units_ / tensor_para_size_); + deviceMalloc(&weights_ptr[3], 3 * hidden_units_ / tensor_para_size_); + deviceMalloc(&weights_ptr[4], hidden_units_ / tensor_para_size_ * hidden_units_); + if (!use_gptj_residual_) { + deviceMalloc(&weights_ptr[5], hidden_units_); + } + + deviceMalloc(&weights_ptr[6], hidden_units_ * inter_size_ / tensor_para_size_); + deviceMalloc(&weights_ptr[7], inter_size_ / tensor_para_size_); + deviceMalloc(&weights_ptr[8], inter_size_ / tensor_para_size_ * hidden_units_); + deviceMalloc(&weights_ptr[9], hidden_units_); + deviceMalloc(&weights_ptr[10], hidden_units_); + deviceMalloc(&weights_ptr[11], hidden_units_); +} + +template struct GptNeoXDecoderLayerWeight; +template struct GptNeoXDecoderLayerWeight; + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/gptneox/GptNeoXDecoderLayerWeight.h b/src/fastertransformer/models/gptneox/GptNeoXDecoderLayerWeight.h new file mode 100644 index 000000000..31e22224e --- /dev/null +++ b/src/fastertransformer/models/gptneox/GptNeoXDecoderLayerWeight.h @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include + +#include "src/fastertransformer/kernels/layernorm_kernels.h" +#include "src/fastertransformer/layers/FfnWeight.h" +#include "src/fastertransformer/layers/attention_layers/AttentionWeight.h" +#include "src/fastertransformer/utils/cuda_utils.h" + +namespace fastertransformer { + +template +struct GptNeoXDecoderLayerWeight { +public: + GptNeoXDecoderLayerWeight() = delete; + GptNeoXDecoderLayerWeight(const int hidden_units, + const int inter_size, + const int tensor_para_size = 1, + const int tensor_para_rank = 0, + const bool use_gptj_residual = true); + ~GptNeoXDecoderLayerWeight(); + GptNeoXDecoderLayerWeight(const GptNeoXDecoderLayerWeight& other); + GptNeoXDecoderLayerWeight& operator=(const GptNeoXDecoderLayerWeight& other); + + void loadModel(std::string dir_path, FtCudaDataType model_file_type); + + LayerNormWeight pre_layernorm_weights; + AttentionWeight self_attention_weights; + LayerNormWeight post_attention_layernorm_weights; + FfnWeight ffn_weights; + +private: + int hidden_units_; + int inter_size_; + int tensor_para_size_; + int tensor_para_rank_; + bool use_gptj_residual_; + const int attention_dense_bias_weight_id = 5; + bool is_maintain_buffer = false; + T* weights_ptr[12]; + + void setWeightPtr(); + void mallocWeights(); +}; + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/gptneox/GptNeoXWeight.cc b/src/fastertransformer/models/gptneox/GptNeoXWeight.cc new file mode 100644 index 000000000..94be0948a --- /dev/null +++ b/src/fastertransformer/models/gptneox/GptNeoXWeight.cc @@ -0,0 +1,290 @@ +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/fastertransformer/models/gptneox/GptNeoXWeight.h" + +namespace fastertransformer { + +template +GptNeoXWeight::GptNeoXWeight(const int hidden_units, + const int inter_size, + const int vocab_size, + const int num_layer, + const int max_seq_len, + const int tensor_para_size, + const int tensor_para_rank, + const int layer_para_size, + const int layer_para_rank, + const bool use_gptj_residual, + PromptLearningType prompt_learning_type, + std::map> prompt_learning_pair): + hidden_units_(hidden_units), + inter_size_(inter_size), + vocab_size_(vocab_size), + num_layer_(num_layer), + max_seq_len_(max_seq_len), + tensor_para_size_(tensor_para_size), + tensor_para_rank_(tensor_para_rank), + layer_para_size_(layer_para_size), + layer_para_rank_(layer_para_rank), + use_gptj_residual_(use_gptj_residual), + prompt_learning_type_(prompt_learning_type), + prompt_learning_pair_(prompt_learning_pair) +{ + FT_CHECK(num_layer_ % layer_para_size_ == 0); + // set prompt weight size + if (prompt_learning_type_ == PromptLearningType::prefix_prompt) { + prompt_token_weight_size_ = 2 * num_layer_ * hidden_units_ / tensor_para_size_; + } + else if (prompt_learning_type_ == PromptLearningType::p_prompt_tuning) { + prompt_token_weight_size_ = hidden_units_; + } + + // set if load and malloc prompt weights + malloc_load_prompt_weights_ = !prompt_learning_pair_.empty() + && (prompt_learning_type_ == PromptLearningType::p_prompt_tuning + || prompt_learning_type_ == PromptLearningType::prefix_prompt); + + decoder_layer_weights.reserve(num_layer_); + for (int l = 0; l < num_layer_; l++) { + if (isValidLayerParallelId(l)) { + decoder_layer_weights.push_back(new GptNeoXDecoderLayerWeight( + hidden_units_, inter_size_, tensor_para_size_, tensor_para_rank_, use_gptj_residual_)); + } + else { + // Layer-parallelism: allocate empty layer because + // this rank does not compute it: + decoder_layer_weights.push_back(new GptNeoXDecoderLayerWeight(0, 0)); + } + } + + mallocWeights(); + setWeightPtr(); +} + +template +GptNeoXWeight::~GptNeoXWeight() +{ + if (is_maintain_buffer == true) { + for (int i = 0; i < weights_ptr.size(); i++) { + deviceFree(weights_ptr[i]); + } + + pre_decoder_embedding_table = nullptr; + post_decoder_layernorm.beta = nullptr; + post_decoder_layernorm.gamma = nullptr; + post_decoder_embedding.kernel = nullptr; + is_maintain_buffer = false; + } +} + +template +GptNeoXWeight::GptNeoXWeight(const GptNeoXWeight& other): + hidden_units_(other.hidden_units_), + inter_size_(other.inter_size_), + vocab_size_(other.vocab_size_), + num_layer_(other.num_layer_), + max_seq_len_(other.max_seq_len_), + tensor_para_size_(other.tensor_para_size_), + tensor_para_rank_(other.tensor_para_rank_), + layer_para_size_(other.layer_para_size_), + layer_para_rank_(other.layer_para_rank_), + use_gptj_residual_(other.use_gptj_residual_), + prompt_token_weight_size_(other.prompt_token_weight_size_), + malloc_load_prompt_weights_(other.malloc_load_prompt_weights_), + prompt_learning_type_(other.prompt_learning_type_), + prompt_learning_pair_(other.prompt_learning_pair_) +{ + mallocWeights(); + cudaD2Dcpy(weights_ptr[0], other.weights_ptr[0], vocab_size_ * hidden_units_); + cudaD2Dcpy(weights_ptr[1], other.weights_ptr[1], hidden_units_); + cudaD2Dcpy(weights_ptr[2], other.weights_ptr[2], hidden_units_); + cudaD2Dcpy(weights_ptr[3], other.weights_ptr[3], hidden_units_ * vocab_size_); + + // prompt learning table: malloc weights and set weight ptr + if (malloc_load_prompt_weights_) { + for (auto const& prompt : 
prompt_learning_pair_) { + std::string task_name = prompt.first; + int task_name_id = prompt.second.first; + int prompt_length = prompt.second.second; + size_t prompt_id = num_base_weights + (size_t)task_name_id; + + // cuda device to device memcpy prompt table weights buffer memory + cudaD2Dcpy(weights_ptr[prompt_id], other.weights_ptr[prompt_id], prompt_length * prompt_token_weight_size_); + } + } + + setWeightPtr(); + + decoder_layer_weights.clear(); + decoder_layer_weights.reserve(num_layer_); + for (int l = 0; l < num_layer_; l++) { + decoder_layer_weights.push_back(other.decoder_layer_weights[l]); + } +} + +template +GptNeoXWeight& GptNeoXWeight::operator=(const GptNeoXWeight& other) +{ + hidden_units_ = other.hidden_units_; + inter_size_ = other.inter_size_; + vocab_size_ = other.vocab_size_; + num_layer_ = other.num_layer_; + max_seq_len_ = other.max_seq_len_; + tensor_para_size_ = other.tensor_para_size_; + tensor_para_rank_ = other.tensor_para_rank_; + layer_para_size_ = other.layer_para_size_; + layer_para_rank_ = other.layer_para_rank_; + use_gptj_residual_ = other.use_gptj_residual_; + prompt_token_weight_size_ = other.prompt_token_weight_size_; + malloc_load_prompt_weights_ = other.malloc_load_prompt_weights_; + prompt_learning_type_ = other.prompt_learning_type_; + prompt_learning_pair_ = other.prompt_learning_pair_; + + mallocWeights(); + cudaD2Dcpy(weights_ptr[0], other.weights_ptr[0], vocab_size_ * hidden_units_); + cudaD2Dcpy(weights_ptr[1], other.weights_ptr[1], hidden_units_); + cudaD2Dcpy(weights_ptr[2], other.weights_ptr[2], hidden_units_); + cudaD2Dcpy(weights_ptr[3], other.weights_ptr[3], hidden_units_ * vocab_size_); + + // prompt learning table: malloc weights and set weight ptr + if (malloc_load_prompt_weights_) { + for (auto const& prompt : prompt_learning_pair_) { + std::string task_name = prompt.first; + int task_name_id = prompt.second.first; + int prompt_length = prompt.second.second; + size_t prompt_id = num_base_weights + (size_t)task_name_id; + + // cuda device to device memcpy prompt table weights buffer memory + cudaD2Dcpy(weights_ptr[prompt_id], other.weights_ptr[prompt_id], prompt_length * prompt_token_weight_size_); + } + } + + setWeightPtr(); + + decoder_layer_weights.clear(); + decoder_layer_weights.reserve(num_layer_); + for (int l = 0; l < num_layer_; l++) { + decoder_layer_weights.push_back(other.decoder_layer_weights[l]); + } + return *this; +} + +template +void GptNeoXWeight::setWeightPtr() +{ + prompt_learning_table.resize(prompt_learning_pair_.size()); + + pre_decoder_embedding_table = weights_ptr[0]; + post_decoder_layernorm.beta = weights_ptr[1]; + post_decoder_layernorm.gamma = weights_ptr[2]; + post_decoder_embedding.kernel = weights_ptr[3]; + + // prompt learning tables: set weight ptr + if (malloc_load_prompt_weights_) { + for (auto const& prompt : prompt_learning_pair_) { + int task_name_id = prompt.second.first; + int prompt_length = prompt.second.second; + size_t task_weight_id = num_base_weights + (size_t)task_name_id; + + // set weight ptr + prompt_learning_table[task_name_id] = {weights_ptr[task_weight_id], prompt_length}; + } + } +} + +template +void GptNeoXWeight::mallocWeights() +{ + weights_ptr.resize(num_base_weights + prompt_learning_pair_.size()); + + deviceMalloc(&weights_ptr[0], vocab_size_ * hidden_units_); + deviceMalloc(&weights_ptr[1], hidden_units_); + deviceMalloc(&weights_ptr[2], hidden_units_); + deviceMalloc(&weights_ptr[3], hidden_units_ * vocab_size_); + + // prompt learning tables: malloc weights + if 
(malloc_load_prompt_weights_) { + for (auto const& prompt : prompt_learning_pair_) { + int task_name_id = prompt.second.first; + int prompt_length = prompt.second.second; + size_t task_weight_id = num_base_weights + (size_t)task_name_id; + + // malloc weights + T* prompt_weights_ptr = nullptr; + deviceMalloc(&prompt_weights_ptr, prompt_length * prompt_token_weight_size_); + weights_ptr[task_weight_id] = prompt_weights_ptr; + } + } + is_maintain_buffer = true; +} + +template +void GptNeoXWeight::loadModel(std::string dir_path) +{ + FtCudaDataType model_file_type = getModelFileType(dir_path + "/config.ini", "gptneox"); + FT_CHECK(is_maintain_buffer == true); + + loadWeightFromBin( + weights_ptr[0], {(size_t)(vocab_size_ * hidden_units_)}, dir_path + "/model.wte.bin", model_file_type); + loadWeightFromBin( + weights_ptr[1], {(size_t)hidden_units_}, dir_path + "/model.final_layernorm.bias.bin", model_file_type); + loadWeightFromBin( + weights_ptr[2], {(size_t)hidden_units_}, dir_path + "/model.final_layernorm.weight.bin", model_file_type); + loadWeightFromBin(weights_ptr[3], + {(size_t)(vocab_size_ * hidden_units_)}, + dir_path + "/model.lm_head.weight.bin", + model_file_type); + + // prompt table: load weights from bin + if (malloc_load_prompt_weights_) { + for (auto const& prompt : prompt_learning_pair_) { + std::string task_name = prompt.first; + int task_name_id = prompt.second.first; + int prompt_length = prompt.second.second; + size_t task_weight_id = num_base_weights + (size_t)task_name_id; + + std::string prompt_weight_path_name = (prompt_learning_type_ == PromptLearningType::p_prompt_tuning) ? + (dir_path + "/model.prompt_table." + task_name + ".weight.bin") : + (dir_path + "/model.prefix_prompt." + task_name + ".weight." + + std::to_string(tensor_para_rank_) + ".bin"); + + if (prompt_length > 0) { + loadWeightFromBin(weights_ptr[task_weight_id], + {(size_t)(prompt_length * (int)prompt_token_weight_size_)}, + prompt_weight_path_name, + model_file_type); + } + } + } + + for (int l = 0; l < num_layer_; l++) { + decoder_layer_weights[l]->loadModel(dir_path + "/model.layers." + std::to_string(l), model_file_type); + } +} + +template +bool GptNeoXWeight::isValidLayerParallelId(int l) +{ + int local_num_layer = (int)(ceil(num_layer_ * 1.0f / layer_para_size_)); + return l < num_layer_ && (l >= local_num_layer * layer_para_rank_) + && (l < local_num_layer * (layer_para_rank_ + 1)); +} + +template struct GptNeoXWeight; +template struct GptNeoXWeight; + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/gptneox/GptNeoXWeight.h b/src/fastertransformer/models/gptneox/GptNeoXWeight.h new file mode 100644 index 000000000..4cdb06163 --- /dev/null +++ b/src/fastertransformer/models/gptneox/GptNeoXWeight.h @@ -0,0 +1,99 @@ +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "src/fastertransformer/kernels/layernorm_kernels.h" +#include "src/fastertransformer/models/gptneox/GptNeoXDecoderLayerWeight.h" +#include "src/fastertransformer/utils/memory_utils.h" +#include "src/fastertransformer/utils/prompt_learning.h" + +namespace fastertransformer { + +template +struct GptNeoXWeight { + + GptNeoXWeight() = default; + GptNeoXWeight( + const int hidden_units, + const int inter_size, + const int vocab_size, + const int num_layer, + const int max_seq_len, + const int tensor_para_size = 1, + const int tensor_para_rank = 0, + const int layer_para_size = 1, + const int layer_para_rank = 0, + const bool use_gptj_residual_ = true, + PromptLearningType prompt_learning_type = PromptLearningType::no_prompt, + std::map> prompt_learning_pair = std::map>{}); + + ~GptNeoXWeight(); + GptNeoXWeight(const GptNeoXWeight& other); + GptNeoXWeight& operator=(const GptNeoXWeight& other); + + void loadModel(std::string dir_path); + + std::vector*> decoder_layer_weights; + const T* pre_decoder_embedding_table = nullptr; + // GPT-J does not use embedding table, but we leave the ptr such that + // GptNeoX::forward and Gpt::forward become identical + const T* position_encoding_table = nullptr; + + /* + prompt_learning_pair = vectors of [weight ptr, prompt length] pair + prompt_length is stored here for compatible prompt learning table + prefix_prompt weights store as shape [num_layers, 2, num_heads, perfix_seq_len, size_per_head] + p/prompt tuning weights store as shape [prompt_len, hidden_units] + idx is the task_name_id of the prompt tables + */ + std::vector> prompt_learning_table = {}; + + LayerNormWeight post_decoder_layernorm; + DenseWeight post_decoder_embedding; + +private: + void setWeightPtr(); + void mallocWeights(); + bool isValidLayerParallelId(int l); + + int hidden_units_; + int inter_size_; + int vocab_size_; + int num_layer_; + int max_seq_len_; + + int tensor_para_size_; + int tensor_para_rank_; + int layer_para_size_; + int layer_para_rank_; + + // residual type + bool use_gptj_residual_; + + // prompt learning pair (task_name, (task_name_id, prompt_len)) + PromptLearningType prompt_learning_type_; + std::map> prompt_learning_pair_; + bool malloc_load_prompt_weights_ = false; + // each prompt token's weight size + size_t prompt_token_weight_size_ = 0; + + bool is_maintain_buffer = false; + const size_t num_base_weights = 4; + std::vector weights_ptr = std::vector(num_base_weights); +}; + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/longformer/CMakeLists.txt b/src/fastertransformer/models/longformer/CMakeLists.txt index 46c3c8d62..495f762ca 100644 --- a/src/fastertransformer/models/longformer/CMakeLists.txt +++ b/src/fastertransformer/models/longformer/CMakeLists.txt @@ -19,4 +19,4 @@ set_property(TARGET LongformerEncoder PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET LongformerEncoder PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) target_link_libraries(LongformerEncoder PUBLIC -lcublas -lcudart -lcurand cublasMMWrapper LongformerAttentionLayer longformer_kernels add_bias_transpose_kernels - activation_kernels layernorm_kernels FfnLayer) \ No newline at end of file + activation_kernels layernorm_kernels FfnLayer tensor) diff --git a/src/fastertransformer/models/longformer/LongformerEncoder.cc b/src/fastertransformer/models/longformer/LongformerEncoder.cc index baab93c08..4b35fed74 100644 --- a/src/fastertransformer/models/longformer/LongformerEncoder.cc +++ 
b/src/fastertransformer/models/longformer/LongformerEncoder.cc @@ -27,20 +27,20 @@ namespace fastertransformer { template -LongformerEncoder::LongformerEncoder(size_t layers_num, - size_t in_dim, - size_t head_num, - size_t size_per_head, - size_t intermediate_size, - size_t local_attn_window_size, - size_t max_global_token_num, - size_t max_batch_size, - size_t max_seq_len, - float attn_scaler, - cudaStream_t stream, +LongformerEncoder::LongformerEncoder(size_t layers_num, + size_t in_dim, + size_t head_num, + size_t size_per_head, + size_t intermediate_size, + size_t local_attn_window_size, + size_t max_global_token_num, + size_t max_batch_size, + size_t max_seq_len, + float attn_scaler, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward): + IAllocator* allocator, + bool is_free_buffer_after_forward): layers_num_(layers_num), in_dim_(in_dim), head_num_(head_num), @@ -67,7 +67,7 @@ LongformerEncoder::LongformerEncoder(size_t layers_num, cublas_wrapper, allocator, is_free_buffer_after_forward); - inter_gelu_out_ffn_ = new GeluFfnLayer(max_batch_size, + inter_gelu_out_ffn_ = new GeluFfnLayer(max_batch_size, max_seq_len, head_num, size_per_head, @@ -91,21 +91,22 @@ template void LongformerEncoder::allocateBuffer() { if (!is_allocate_buffer_) { - cub_storage_ = (void*)allocator_->malloc(getInitLongformerCubStorage(max_seq_len_), false); - global_idx_ = (int*)allocator_->malloc(sizeof(int) * max_seq_len_ * max_batch_size_, false); - global_token_nums_ = (int*)allocator_->malloc(sizeof(int) * max_batch_size_, false); - seq_idx_ = (int*)allocator_->malloc(sizeof(int) * max_seq_len_, false); + cub_storage_ = (void*)allocator_->reMalloc(cub_storage_, getInitLongformerCubStorage(max_seq_len_), false); + global_idx_ = (int*)allocator_->reMalloc(global_idx_, sizeof(int) * max_seq_len_ * max_batch_size_, false); + global_token_nums_ = (int*)allocator_->reMalloc(global_token_nums_, sizeof(int) * max_batch_size_, false); + seq_idx_ = (int*)allocator_->reMalloc(seq_idx_, sizeof(int) * max_seq_len_, false); - local_attn_mask_shifted_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_, false); + local_attn_mask_shifted_ = + (T*)allocator_->reMalloc(local_attn_mask_shifted_, sizeof(T) * max_batch_size_ * max_seq_len_, false); size_t qkv_buffer_size = sizeof(T) * max_batch_size_ * hidden_units_ * 6 * max_seq_len_; size_t input_output_buffer_size = sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_; - qkv_buffer_ = (T*)allocator_->malloc(qkv_buffer_size, false); - mha_qkv_buffer_ = (T*)allocator_->malloc(qkv_buffer_size, false); - mha_out_buffer_ = (T*)allocator_->malloc(input_output_buffer_size, false); - attn_out_buffer_ = (T*)allocator_->malloc(input_output_buffer_size, false); - attn_output_buffer_ = (T*)allocator_->malloc(input_output_buffer_size, false); + qkv_buffer_ = (T*)allocator_->reMalloc(qkv_buffer_, qkv_buffer_size, false); + mha_qkv_buffer_ = (T*)allocator_->reMalloc(mha_qkv_buffer_, qkv_buffer_size, false); + mha_out_buffer_ = (T*)allocator_->reMalloc(mha_out_buffer_, input_output_buffer_size, false); + attn_out_buffer_ = (T*)allocator_->reMalloc(attn_out_buffer_, input_output_buffer_size, false); + attn_output_buffer_ = (T*)allocator_->reMalloc(attn_output_buffer_, input_output_buffer_size, false); is_allocate_buffer_ = true; } } @@ -114,16 +115,16 @@ template void LongformerEncoder::freeBuffer() { if (is_allocate_buffer_) { - allocator_->free(cub_storage_); - allocator_->free(global_idx_); - 
allocator_->free(global_token_nums_); - allocator_->free(seq_idx_); - allocator_->free(local_attn_mask_shifted_); - allocator_->free(qkv_buffer_); - allocator_->free(mha_qkv_buffer_); - allocator_->free(mha_out_buffer_); - allocator_->free(attn_out_buffer_); - allocator_->free(attn_output_buffer_); + allocator_->free((void**)(&cub_storage_)); + allocator_->free((void**)(&global_idx_)); + allocator_->free((void**)(&global_token_nums_)); + allocator_->free((void**)(&seq_idx_)); + allocator_->free((void**)(&local_attn_mask_shifted_)); + allocator_->free((void**)(&qkv_buffer_)); + allocator_->free((void**)(&mha_qkv_buffer_)); + allocator_->free((void**)(&mha_out_buffer_)); + allocator_->free((void**)(&attn_out_buffer_)); + allocator_->free((void**)(&attn_output_buffer_)); is_allocate_buffer_ = false; } } @@ -147,7 +148,7 @@ void LongformerEncoder::forward(std::vector* output_tensors, std::vec allocateBuffer(); const size_t batch_size = input_tensors->at(0).shape[0]; - const size_t seq_len = input_tensors->at(0).shape[1]; + const size_t seq_len = input_tensors->at(0).shape[1]; invokeInitLongformerIdx((T*)input_tensors->at(2).data, seq_idx_, @@ -179,20 +180,20 @@ void LongformerEncoder::forward(std::vector* output_tensors, std::vec } template -void LongformerEncoder::forwardLayer(T* input, - T* output, - const T* local_attn_mask, - const T* global_attn_mask, - const int* global_idx, - const int* global_token_nums, +void LongformerEncoder::forwardLayer(T* input, + T* output, + const T* local_attn_mask, + const T* global_attn_mask, + const int* global_idx, + const int* global_token_nums, const LongformerLayerWeight* weight, - const size_t batch_size, - const size_t seq_len, - const size_t in_dim_) + const size_t batch_size, + const size_t seq_len, + const size_t in_dim_) { - T* const q = qkv_buffer_; - T* const k = qkv_buffer_ + batch_size * hidden_units_ * seq_len; - T* const v = qkv_buffer_ + batch_size * hidden_units_ * 2 * seq_len; + T* const q = qkv_buffer_; + T* const k = qkv_buffer_ + batch_size * hidden_units_ * seq_len; + T* const v = qkv_buffer_ + batch_size * hidden_units_ * 2 * seq_len; T* const kg = qkv_buffer_ + batch_size * hidden_units_ * 3 * seq_len; T* const vg = qkv_buffer_ + batch_size * hidden_units_ * 4 * seq_len; T* const qg = qkv_buffer_ + batch_size * hidden_units_ * 5 * seq_len; @@ -276,9 +277,9 @@ void LongformerEncoder::forwardLayer(T* input, max_global_token_num_ * hidden_units_, batch_size); // reset all qkv pointer to transposed - T* const q_mha = mha_qkv_buffer_; - T* const k_mha = mha_qkv_buffer_ + batch_size * head_num_ * size_per_head_ * seq_len; - T* const v_mha = mha_qkv_buffer_ + batch_size * head_num_ * size_per_head_ * 2 * seq_len; + T* const q_mha = mha_qkv_buffer_; + T* const k_mha = mha_qkv_buffer_ + batch_size * head_num_ * size_per_head_ * seq_len; + T* const v_mha = mha_qkv_buffer_ + batch_size * head_num_ * size_per_head_ * 2 * seq_len; T* const kg_mha = mha_qkv_buffer_ + batch_size * head_num_ * size_per_head_ * 3 * seq_len; T* const vg_mha = mha_qkv_buffer_ + batch_size * head_num_ * size_per_head_ * 4 * seq_len; T* const qg_mha = mha_qkv_buffer_ + batch_size * head_num_ * size_per_head_ * 5 * seq_len; @@ -289,13 +290,13 @@ void LongformerEncoder::forwardLayer(T* input, qkv_buffer_, qkv_5_bias, mha_qkv_buffer_, batch_size, head_num_, size_per_head_, seq_len, 5, stream_); sync_check_cuda_error(); - // calculate qg seperately cause the dimension is not the same with others. + // calculate qg separately cause the dimension is not the same with others. 
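// [Editor's note] Illustrative sketch of the allocator pattern the hunks above switch to:
// buffers are obtained with allocator_->reMalloc(ptr, size, set_zero) and released with
// allocator_->free((void**)(&ptr)). Passing the address of the member pointer lets the
// allocator reset it to nullptr after freeing, so allocateBuffer()/freeBuffer() can be
// called repeatedly without double-free or stale-pointer bugs. ToyAllocator below is a
// stand-in for illustration only, not FasterTransformer's IAllocator implementation.
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <cstring>

struct ToyAllocator {
    // Simplified: always realloc to the requested size (the real allocator keeps track
    // of previous allocations and only reallocates when a buffer needs to grow).
    void* reMalloc(void* old_ptr, std::size_t size, bool set_zero)
    {
        void* ptr = std::realloc(old_ptr, size);
        if (set_zero && ptr != nullptr) {
            std::memset(ptr, 0, size);
        }
        return ptr;
    }
    // Takes the address of the caller's pointer so it can be reset to nullptr.
    void free(void** ptr)
    {
        if (ptr != nullptr && *ptr != nullptr) {
            std::free(*ptr);
            *ptr = nullptr;  // the caller's member pointer is now safe to reuse
        }
    }
};

int main()
{
    ToyAllocator allocator;
    int* seq_idx = nullptr;  // mirrors a member buffer such as seq_idx_
    seq_idx = (int*)allocator.reMalloc(seq_idx, sizeof(int) * 128, false);
    allocator.free((void**)(&seq_idx));
    std::printf("seq_idx after free: %p\n", (void*)seq_idx);  // prints a null pointer
    return 0;
}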
const T* qg_bias = weight->global_query_weights.bias; invokeAddBiasTransposeToMultiHead( qg, qg_bias, qg_mha, batch_size, head_num_, size_per_head_, max_global_token_num_, 1, stream_); sync_check_cuda_error(); - DataType data_type = getTensorType(); + DataType data_type = getTensorType(); std::vector attn_inputs{ Tensor{MEMORY_GPU, data_type, std::vector{batch_size, head_num_, seq_len, size_per_head_}, q_mha}, Tensor{MEMORY_GPU, data_type, std::vector{batch_size, head_num_, seq_len, size_per_head_}, k_mha}, @@ -337,6 +338,7 @@ void LongformerEncoder::forwardLayer(T* input, weight->attention_output_weights.bias, weight->attention_output_layernorm_weights.gamma, weight->attention_output_layernorm_weights.beta, + layernorm_eps_, batch_size * seq_len, hidden_units_, stream_); @@ -357,6 +359,7 @@ void LongformerEncoder::forwardLayer(T* input, weight->ffn_weights.output_weight.bias, weight->output_layernorm_weights.gamma, weight->output_layernorm_weights.beta, + layernorm_eps_, batch_size * seq_len, hidden_units_, stream_); @@ -365,5 +368,8 @@ void LongformerEncoder::forwardLayer(T* input, template class LongformerEncoder; template class LongformerEncoder; +#ifdef ENABLE_BF16 +template class LongformerEncoder<__nv_bfloat16>; +#endif } // namespace fastertransformer \ No newline at end of file diff --git a/src/fastertransformer/models/longformer/LongformerEncoder.h b/src/fastertransformer/models/longformer/LongformerEncoder.h index bb70c7875..edf731054 100644 --- a/src/fastertransformer/models/longformer/LongformerEncoder.h +++ b/src/fastertransformer/models/longformer/LongformerEncoder.h @@ -28,14 +28,14 @@ namespace fastertransformer { template struct LongformerLayerWeight { - DenseWeight query_weights; - DenseWeight key_weights; - DenseWeight value_weights; - DenseWeight global_query_weights; - DenseWeight global_key_weights; - DenseWeight global_value_weights; - DenseWeight attention_output_weights; - FfnWeight ffn_weights; + DenseWeight query_weights; + DenseWeight key_weights; + DenseWeight value_weights; + DenseWeight global_query_weights; + DenseWeight global_key_weights; + DenseWeight global_value_weights; + DenseWeight attention_output_weights; + FfnWeight ffn_weights; LayerNormWeight attention_output_layernorm_weights; LayerNormWeight output_layernorm_weights; }; @@ -43,74 +43,75 @@ struct LongformerLayerWeight { template class LongformerEncoder { private: - size_t layers_num_; - size_t in_dim_; - size_t head_num_; - size_t size_per_head_; - size_t hidden_units_; - size_t intermediate_size_; - size_t max_batch_size_; - size_t max_global_token_num_; - size_t max_seq_len_; + size_t layers_num_; + size_t in_dim_; + size_t head_num_; + size_t size_per_head_; + size_t hidden_units_; + size_t intermediate_size_; + size_t max_batch_size_; + size_t max_global_token_num_; + size_t max_seq_len_; + static constexpr float layernorm_eps_ = 1e-6f; // internal buffers void* cub_storage_; - int* global_idx_; - int* global_token_nums_; - int* seq_idx_; - T* local_attn_mask_shifted_; - T* qkv_buffer_; - T* mha_qkv_buffer_; - T* mha_out_buffer_; - T* attn_out_buffer_; - T* attn_output_buffer_; - T* intermediate_buffer_; + int* global_idx_; + int* global_token_nums_; + int* seq_idx_; + T* local_attn_mask_shifted_; + T* qkv_buffer_; + T* mha_qkv_buffer_; + T* mha_out_buffer_; + T* attn_out_buffer_; + T* attn_output_buffer_; + T* intermediate_buffer_; - GeluFfnLayer* inter_gelu_out_ffn_; + GeluFfnLayer* inter_gelu_out_ffn_; LongformerAttentionLayer* longformer_attn_layer_; std::vector> weights_; 
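// [Editor's note] The header above adds `static constexpr float layernorm_eps_ = 1e-6f;`
// and the .cc hunks now pass this epsilon explicitly into the layernorm kernel calls in
// LongformerEncoder::forwardLayer instead of relying on a value baked into the kernels.
// A minimal CPU reference of the computation those kernels perform, shown only to make
// clear where the new eps parameter enters (this is not the CUDA implementation):
#include <cmath>
#include <cstddef>
#include <vector>

// y[i] = gamma[i] * (x[i] - mean) / sqrt(var + eps) + beta[i]
void layer_norm_reference(const std::vector<float>& x,
                          const std::vector<float>& gamma,
                          const std::vector<float>& beta,
                          float                     eps,  // e.g. layernorm_eps_ = 1e-6f
                          std::vector<float>&       y)
{
    const std::size_t n = x.size();
    float mean = 0.f;
    float var  = 0.f;
    for (float v : x) {
        mean += v;
    }
    mean /= n;
    for (float v : x) {
        var += (v - mean) * (v - mean);
    }
    var /= n;
    const float inv_std = 1.f / std::sqrt(var + eps);  // eps keeps this finite for tiny variance
    y.resize(n);
    for (std::size_t i = 0; i < n; ++i) {
        y[i] = gamma[i] * (x[i] - mean) * inv_std + beta[i];
    }
}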
cublasMMWrapper* cublas_wrapper_; - IAllocator* allocator_; - cudaStream_t stream_; - bool is_free_buffer_after_forward_; - bool is_allocate_buffer_ = false; + IAllocator* allocator_; + cudaStream_t stream_; + bool is_free_buffer_after_forward_; + bool is_allocate_buffer_ = false; public: - LongformerEncoder(size_t layers_num, - size_t in_dim, - size_t head_num, - size_t size_per_head, - size_t intermediate_size, - size_t local_attn_window_size, - size_t max_global_token_num, - size_t max_batch_size, - size_t max_seq_len, - float attn_scaler, - cudaStream_t stream, + LongformerEncoder(size_t layers_num, + size_t in_dim, + size_t head_num, + size_t size_per_head, + size_t intermediate_size, + size_t local_attn_window_size, + size_t max_global_token_num, + size_t max_batch_size, + size_t max_seq_len, + float attn_scaler, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward); + IAllocator* allocator, + bool is_free_buffer_after_forward); ~LongformerEncoder(); void forward(std::vector* output_tensors, std::vector* input_tensors); std::vector>* getWeightsPtr(); private: size_t getInitCubStorage(const int seq_len); - void initLongformerIdx(const T* global_attn_mask, const int seq_len, const int batch_size); - void allocateBuffer(); - void freeBuffer(); - void forwardLayer(T* input, - T* output, - const T* local_attn_mask, - const T* global_attn_mask, - const int* global_idx, - const int* global_token_nums, - const LongformerLayerWeight* weight, - const size_t batch_size, - const size_t seq_len, - const size_t in_dim); + void initLongformerIdx(const T* global_attn_mask, const int seq_len, const int batch_size); + void allocateBuffer(); + void freeBuffer(); + void forwardLayer(T* input, + T* output, + const T* local_attn_mask, + const T* global_attn_mask, + const int* global_idx, + const int* global_token_nums, + const LongformerLayerWeight* weight, + const size_t batch_size, + const size_t seq_len, + const size_t in_dim); }; } // namespace fastertransformer diff --git a/src/fastertransformer/models/multi_gpu_gpt/CMakeLists.txt b/src/fastertransformer/models/multi_gpu_gpt/CMakeLists.txt index 10b9e0b80..202cd6fc4 100644 --- a/src/fastertransformer/models/multi_gpu_gpt/CMakeLists.txt +++ b/src/fastertransformer/models/multi_gpu_gpt/CMakeLists.txt @@ -27,16 +27,16 @@ target_link_libraries(ParallelGptWeight PUBLIC ParallelGptDecoderLayerWeight) add_library(ParallelGptContextDecoder STATIC ParallelGptContextDecoder.cc) set_property(TARGET ParallelGptContextDecoder PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET ParallelGptContextDecoder PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -target_link_libraries(ParallelGptContextDecoder PUBLIC -lcudart TensorParallelGeluFfnLayer +target_link_libraries(ParallelGptContextDecoder PUBLIC -lcudart TensorParallelGeluFfnLayer TensorParallelReluFfnLayer TensorParallelGptContextAttentionLayer layernorm_kernels - add_residual_kernels nccl_utils gpt_kernels) + add_residual_kernels bert_preprocess_kernels nccl_utils gpt_kernels tensor) add_library(ParallelGptDecoder STATIC ParallelGptDecoder.cc) set_property(TARGET ParallelGptDecoder PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET ParallelGptDecoder PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -target_link_libraries(ParallelGptDecoder PUBLIC -lcudart TensorParallelGeluFfnLayer +target_link_libraries(ParallelGptDecoder PUBLIC -lcudart TensorParallelGeluFfnLayer TensorParallelReluFfnLayer TensorParallelDecoderSelfAttentionLayer layernorm_kernels 
- add_residual_kernels nccl_utils) + add_residual_kernels nccl_utils tensor) add_library(ParallelGpt STATIC ParallelGpt.cc) set_property(TARGET ParallelGpt PROPERTY POSITION_INDEPENDENT_CODE ON) diff --git a/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc b/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc index d74dc6c45..0fa16dbc7 100644 --- a/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc +++ b/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc @@ -1,6 +1,7 @@ /* * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2021, NAVER Corp. Authored by CLOVA. + * Copyright (c) 2022, SK Telecom Authored by A. Dialog * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,6 +35,8 @@ void ParallelGpt::initialize() size_per_head_, inter_size_, num_layer_, + layernorm_eps_, + gpt_variant_params_, tensor_para_, pipeline_para_, stream_, @@ -43,13 +46,16 @@ void ParallelGpt::initialize() is_context_qk_buf_float_, sparse_, custom_all_reduce_comm_, - enable_custom_all_reduce_); + enable_custom_all_reduce_, + remove_padding_); gpt_decoder_ = new ParallelGptDecoder(0, head_num_, size_per_head_, inter_size_, num_layer_, + layernorm_eps_, + gpt_variant_params_, tensor_para_, pipeline_para_, stream_, @@ -80,14 +86,15 @@ void ParallelGpt::allocateBuffer() template void ParallelGpt::allocateBuffer(size_t batch_size, size_t beam_width, - size_t max_seq_len, + size_t max_session_len, + size_t memory_len, size_t max_input_len, - bool is_return_context_cum_log_probs) + bool is_return_context_cum_log_probs) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); const size_t batchxbeam = batch_size * beam_width; const size_t self_cache_size = - (num_layer_ / pipeline_para_.world_size_) * batchxbeam * max_seq_len * hidden_units_ / tensor_para_.world_size_; + (num_layer_ / pipeline_para_.world_size_) * batchxbeam * memory_len * hidden_units_ / tensor_para_.world_size_; if (vocab_size_ != vocab_size_padded_) { padded_embedding_kernel_ = @@ -95,8 +102,8 @@ void ParallelGpt::allocateBuffer(size_t batch_size, padded_embedding_kernel_ptr_ = padded_embedding_kernel_; } - input_attention_mask_ = - (T*)(allocator_->reMalloc(input_attention_mask_, sizeof(T) * batchxbeam * max_seq_len * max_seq_len, false)); + input_attention_mask_ = (T*)(allocator_->reMalloc( + input_attention_mask_, sizeof(T) * batchxbeam * max_input_len * max_input_len, false)); decoder_input_buf_ = (T*)(allocator_->reMalloc(decoder_input_buf_, sizeof(T) * batchxbeam * hidden_units_, false)); decoder_output_buf_ = (T*)(allocator_->reMalloc(decoder_output_buf_, sizeof(T) * batchxbeam * hidden_units_, false)); @@ -105,46 +112,67 @@ void ParallelGpt::allocateBuffer(size_t batch_size, logits_buf_ = (float*)(allocator_->reMalloc(logits_buf_, sizeof(float) * batchxbeam * vocab_size_padded_, false)); nccl_logits_buf_ = (float*)(allocator_->reMalloc(nccl_logits_buf_, sizeof(float) * batchxbeam * vocab_size_padded_, false)); - cum_log_probs_ = (float*)(allocator_->reMalloc(cum_log_probs_, sizeof(float) * batchxbeam, false)); - finished_buf_ = (bool*)(allocator_->reMalloc(finished_buf_, sizeof(bool) * batchxbeam, false)); - h_finished_buf_ = new bool[batchxbeam]; + cum_log_probs_ = (float*)(allocator_->reMalloc(cum_log_probs_, sizeof(float) * batchxbeam, false)); + finished_buf_ = (bool*)(allocator_->reMalloc(finished_buf_, sizeof(bool) * batchxbeam, false)); + h_finished_buf_ = new bool[batchxbeam]; + sequence_lengths_ = 
(int*)(allocator_->reMalloc(sequence_lengths_, sizeof(int) * batchxbeam, false)); - key_cache_ = (T*)(allocator_->reMalloc(key_cache_, sizeof(T) * self_cache_size * 2, true)); + key_cache_ = (T*)(allocator_->reMalloc(key_cache_, sizeof(T) * self_cache_size * 2, true)); value_cache_ = key_cache_ + self_cache_size; if (beam_width > 1) { cache_indirections_[0] = - (int*)(allocator_->reMalloc(cache_indirections_[0], sizeof(int) * batchxbeam * max_seq_len * 2, true)); - cache_indirections_[1] = cache_indirections_[0] + batchxbeam * max_seq_len; + (int*)(allocator_->reMalloc(cache_indirections_[0], sizeof(int) * batchxbeam * memory_len * 2, true)); + cache_indirections_[1] = cache_indirections_[0] + batchxbeam * memory_len; } tiled_input_ids_buf_ = - (int*)(allocator_->reMalloc(tiled_input_ids_buf_, sizeof(int) * batchxbeam * max_input_len, true)); + (int*)(allocator_->reMalloc(tiled_input_ids_buf_, sizeof(int) * batchxbeam * max_session_len, true)); tiled_input_lengths_buf_ = (int*)(allocator_->reMalloc(tiled_input_lengths_buf_, sizeof(int) * batchxbeam, true)); + // prompt_learning weight batch ptrs + prompt_learning_weight_batch_ = + (const T**)(allocator_->reMalloc(prompt_learning_weight_batch_, sizeof(T*) * batchxbeam, false)); + tiled_prompt_lengths_buf_ = + (int*)(allocator_->reMalloc(tiled_prompt_lengths_buf_, sizeof(int) * batchxbeam, false)); + start_ids_buf_ = (int*)(allocator_->reMalloc(start_ids_buf_, sizeof(int) * batch_size, false)); - end_ids_buf_ = (int*)(allocator_->reMalloc(end_ids_buf_, sizeof(int) * batch_size, false)); + end_ids_buf_ = (int*)(allocator_->reMalloc(end_ids_buf_, sizeof(int) * batch_size, false)); transposed_output_ids_buf_ = - (int*)(allocator_->reMalloc(transposed_output_ids_buf_, sizeof(int) * batchxbeam * max_seq_len, true)); - output_ids_buf_ = (int*)(allocator_->reMalloc(output_ids_buf_, sizeof(int) * batchxbeam * max_seq_len, true)); - parent_ids_buf_ = (int*)(allocator_->reMalloc(parent_ids_buf_, sizeof(int) * batchxbeam * max_seq_len, true)); + (int*)(allocator_->reMalloc(transposed_output_ids_buf_, sizeof(int) * batchxbeam * max_session_len, true)); + output_ids_buf_ = (int*)(allocator_->reMalloc(output_ids_buf_, sizeof(int) * batchxbeam * max_session_len, true)); + parent_ids_buf_ = (int*)(allocator_->reMalloc(parent_ids_buf_, sizeof(int) * batchxbeam * max_session_len, true)); + seq_limit_len_ = (uint32_t*)(allocator_->reMalloc(seq_limit_len_, sizeof(uint32_t) * batch_size, false)); + masked_tokens_ = (bool*)(allocator_->reMalloc(masked_tokens_, sizeof(bool) * batchxbeam * memory_len, true)); - context_decoder_input_buf_ = (T*)(allocator_->reMalloc( + context_decoder_input_buf_ = (T*)(allocator_->reMalloc( context_decoder_input_buf_, sizeof(T) * batchxbeam * max_input_len * hidden_units_, false)); context_decoder_output_buf_ = (T*)(allocator_->reMalloc( context_decoder_output_buf_, sizeof(T) * batchxbeam * max_input_len * hidden_units_, false)); output_log_probs_buf_ = - (float*)(allocator_->reMalloc(output_log_probs_buf_, sizeof(float) * batchxbeam * max_seq_len, false)); + (float*)(allocator_->reMalloc(output_log_probs_buf_, sizeof(float) * batchxbeam * max_session_len, false)); if (is_return_context_cum_log_probs) { lp_normed_decoder_output_buf_ = (T*)allocator_->reMalloc( lp_normed_decoder_output_buf_, sizeof(T) * batchxbeam * max_input_len * hidden_units_); - lp_logits_buf_ = (float*)allocator_->reMalloc(lp_logits_buf_, + lp_logits_buf_ = (float*)allocator_->reMalloc(lp_logits_buf_, sizeof(float) * batchxbeam * max_input_len * 
vocab_size_padded_); lp_nccl_logits_buf_ = (float*)allocator_->reMalloc( lp_nccl_logits_buf_, sizeof(float) * batchxbeam * max_input_len * vocab_size_padded_); lp_logprob_buf_ = (float*)allocator_->reMalloc(lp_logprob_buf_, sizeof(float) * batchxbeam * max_input_len); } + if (shared_contexts_ratio_ > 0.0f) { + shared_contexts_idx_ = (int*)allocator_->reMalloc(shared_contexts_idx_, 3 * batch_size * sizeof(int), false); + batch_to_compact_idx_ = shared_contexts_idx_ + batch_size; + compact_idx_ = shared_contexts_idx_ + 2 * batch_size; + compact_size_ = (int*)allocator_->reMalloc(compact_size_, sizeof(int), false); + } + + if (generation_should_stop_ == nullptr) { + cudaMallocHost(&generation_should_stop_, 1 * sizeof(bool)); + } + tiled_total_padding_count_ = + (int*)allocator_->reMalloc(tiled_total_padding_count_, batchxbeam * sizeof(int), false); is_allocate_buffer_ = true; } @@ -155,77 +183,97 @@ void ParallelGpt::freeBuffer() if (is_allocate_buffer_) { if (vocab_size_ != vocab_size_padded_) { padded_embedding_kernel_ptr_ = nullptr; - allocator_->free(padded_embedding_kernel_); + allocator_->free((void**)(&padded_embedding_kernel_)); } - allocator_->free(input_attention_mask_); - allocator_->free(decoder_input_buf_); - allocator_->free(decoder_output_buf_); - allocator_->free(normed_decoder_output_buf_); - allocator_->free(logits_buf_); - allocator_->free(nccl_logits_buf_); - allocator_->free(cum_log_probs_); - allocator_->free(finished_buf_); + allocator_->free((void**)(&input_attention_mask_)); + allocator_->free((void**)(&decoder_input_buf_)); + allocator_->free((void**)(&decoder_output_buf_)); + allocator_->free((void**)(&normed_decoder_output_buf_)); + allocator_->free((void**)(&logits_buf_)); + allocator_->free((void**)(&nccl_logits_buf_)); + allocator_->free((void**)(&cum_log_probs_)); + allocator_->free((void**)(&finished_buf_)); delete[] h_finished_buf_; + allocator_->free((void**)(&sequence_lengths_)); - allocator_->free(key_cache_); + allocator_->free((void**)(&key_cache_)); if (cache_indirections_[0] != nullptr) { - allocator_->free(cache_indirections_[0]); + allocator_->free((void**)(&cache_indirections_)[0]); } - allocator_->free(tiled_input_ids_buf_); - allocator_->free(tiled_input_lengths_buf_); + allocator_->free((void**)(&tiled_input_ids_buf_)); + allocator_->free((void**)(&tiled_input_lengths_buf_)); + + allocator_->free((void**)(&prompt_learning_weight_batch_)); + allocator_->free((void**)(&tiled_prompt_lengths_buf_)); + + allocator_->free((void**)(&transposed_output_ids_buf_)); + allocator_->free((void**)(&output_ids_buf_)); + allocator_->free((void**)(&parent_ids_buf_)); + allocator_->free((void**)(&masked_tokens_)); + + allocator_->free((void**)(&seq_limit_len_)); - allocator_->free(transposed_output_ids_buf_); - allocator_->free(output_ids_buf_); - allocator_->free(parent_ids_buf_); + allocator_->free((void**)(&start_ids_buf_)); + allocator_->free((void**)(&end_ids_buf_)); - allocator_->free(start_ids_buf_); - allocator_->free(end_ids_buf_); + allocator_->free((void**)(&context_decoder_input_buf_)); + allocator_->free((void**)(&context_decoder_output_buf_)); + allocator_->free((void**)(&output_log_probs_buf_)); - allocator_->free(context_decoder_input_buf_); - allocator_->free(context_decoder_output_buf_); - allocator_->free(output_log_probs_buf_); + allocator_->free((void**)(&lp_normed_decoder_output_buf_)); + allocator_->free((void**)(&lp_logits_buf_)); + allocator_->free((void**)(&lp_nccl_logits_buf_)); + allocator_->free((void**)(&lp_logprob_buf_)); - 
allocator_->free(lp_normed_decoder_output_buf_); - allocator_->free(lp_logits_buf_); - allocator_->free(lp_nccl_logits_buf_); - allocator_->free(lp_logprob_buf_); + cudaFreeHost(generation_should_stop_); + + if (shared_contexts_ratio_ > 0.0f) { + allocator_->free((void**)(&shared_contexts_idx_)); + allocator_->free((void**)(&compact_size_)); + } + allocator_->free((void**)(&tiled_total_padding_count_)); is_allocate_buffer_ = false; } } template -ParallelGpt::ParallelGpt(size_t max_batch_size, - size_t max_seq_len, - size_t max_input_len, - size_t beam_width, - size_t head_num, - size_t size_per_head, - size_t inter_size, - size_t num_layer, - size_t vocab_size, - int start_id, - int end_id, - float beam_search_diversity_rate, - size_t top_k, - float top_p, - unsigned long long random_seed, - float temperature, - float len_penalty, - float repetition_penalty, - NcclParam tensor_para, - NcclParam pipeline_para, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - cudaDeviceProp* cuda_device_prop, - bool sparse, - int int8_mode, +ParallelGpt::ParallelGpt(size_t max_batch_size, + size_t max_seq_len, + size_t max_input_len, + size_t beam_width, + size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + size_t vocab_size, + int start_id, + int end_id, + int prompt_learning_start_id, + PromptLearningType prompt_learning_type, + gptVariantParams gpt_variant_params, + float beam_search_diversity_rate, + size_t top_k, + float top_p, + unsigned long long random_seed, + float temperature, + float len_penalty, + float repetition_penalty, + NcclParam tensor_para, + NcclParam pipeline_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + cudaDeviceProp* cuda_device_prop, + bool sparse, + int int8_mode, std::shared_ptr custom_all_reduce_comm, - int enable_custom_all_reduce): + int enable_custom_all_reduce, + bool remove_padding, + float shared_contexts_ratio): BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, cuda_device_prop, sparse), head_num_(head_num), size_per_head_(size_per_head), @@ -234,6 +282,10 @@ ParallelGpt::ParallelGpt(size_t max_batch_size, vocab_size_(vocab_size), start_id_(start_id), end_id_(end_id), + prompt_learning_start_id_(prompt_learning_start_id), + prompt_learning_type_(prompt_learning_type), + layernorm_eps_(gpt_variant_params.layernorm_eps), + gpt_variant_params_(gpt_variant_params), beam_search_diversity_rate_(beam_search_diversity_rate), hidden_units_(head_num_ * size_per_head), top_k_(top_k), @@ -247,7 +299,9 @@ ParallelGpt::ParallelGpt(size_t max_batch_size, local_head_num_(head_num / tensor_para.world_size_), int8_mode_(int8_mode), custom_all_reduce_comm_(custom_all_reduce_comm), - enable_custom_all_reduce_(enable_custom_all_reduce) + enable_custom_all_reduce_(enable_custom_all_reduce), + remove_padding_(remove_padding), + shared_contexts_ratio_(shared_contexts_ratio) { int local_vacab_size = ceil(vocab_size_ / 1.f / tensor_para_.world_size_); if (std::is_same::value) { @@ -267,7 +321,11 @@ ParallelGpt::ParallelGpt(ParallelGpt const& gpt): vocab_size_(gpt.vocab_size_), start_id_(gpt.start_id_), end_id_(gpt.end_id_), + prompt_learning_start_id_(gpt.prompt_learning_start_id_), + prompt_learning_type_(gpt.prompt_learning_type_), beam_search_diversity_rate_(gpt.beam_search_diversity_rate_), + layernorm_eps_(gpt.gpt_variant_params_.layernorm_eps), + 
gpt_variant_params_(gpt.gpt_variant_params_), hidden_units_(gpt.hidden_units_), top_k_(gpt.top_k_), top_p_(gpt.top_p_), @@ -281,7 +339,9 @@ ParallelGpt::ParallelGpt(ParallelGpt const& gpt): vocab_size_padded_(gpt.vocab_size_padded_), int8_mode_(gpt.int8_mode_), custom_all_reduce_comm_(gpt.custom_all_reduce_comm_), - enable_custom_all_reduce_(gpt.enable_custom_all_reduce_) + enable_custom_all_reduce_(gpt.enable_custom_all_reduce_), + remove_padding_(gpt.remove_padding_), + shared_contexts_ratio_(gpt.shared_contexts_ratio_) { initialize(); } @@ -296,13 +356,13 @@ ParallelGpt::~ParallelGpt() } template -void ParallelGpt::computeContextCumLogProbs(float* cum_log_probs, - const T* context_decoder_outputs, - const int* input_ids, - const int* input_lengths, - const size_t batch_size, - const size_t beam_width, - const size_t max_input_length, +void ParallelGpt::computeContextCumLogProbs(float* cum_log_probs, + const T* context_decoder_outputs, + const int* input_ids, + const int* input_lengths, + const size_t batch_size, + const size_t beam_width, + const size_t max_input_length, const ParallelGptWeight* gpt_weights) { // Compute the log probabilties of prompt inputs. @@ -313,7 +373,7 @@ void ParallelGpt::computeContextCumLogProbs(float* cum_log_probs, // input_lengths [batch_size, beam_width]; input lengths. FT_LOG_DEBUG(__PRETTY_FUNCTION__); - const size_t batchxbeam = batch_size * beam_width; + const size_t batchxbeam = batch_size * beam_width; const size_t n_hidden_states = batchxbeam * max_input_length; if (pipeline_para_.rank_ == pipeline_para_.world_size_ - 1) { @@ -322,13 +382,14 @@ void ParallelGpt::computeContextCumLogProbs(float* cum_log_probs, context_decoder_outputs, gpt_weights->post_decoder_layernorm.gamma, gpt_weights->post_decoder_layernorm.beta, + layernorm_eps_, n_hidden_states, hidden_units_, stream_); sync_check_cuda_error(); if (tensor_para_.world_size_ == 1) { float alpha = 1.0f; - float beta = 0.0f; + float beta = 0.0f; cublas_wrapper_->Gemm(CUBLAS_OP_T, CUBLAS_OP_N, vocab_size_padded_, // n @@ -352,8 +413,8 @@ void ParallelGpt::computeContextCumLogProbs(float* cum_log_probs, else { FT_CHECK(vocab_size_padded_ % tensor_para_.world_size_ == 0); const int local_vocab_size = vocab_size_padded_ / tensor_para_.world_size_; - float alpha = 1.0f; - float beta = 0.0f; + float alpha = 1.0f; + float beta = 0.0f; cublas_wrapper_->Gemm(CUBLAS_OP_T, CUBLAS_OP_N, local_vocab_size, // n @@ -408,8 +469,22 @@ void ParallelGpt::computeContextCumLogProbs(float* cum_log_probs, } template -void ParallelGpt::forward(std::vector* output_tensors, - const std::vector* input_tensors, +void ParallelGpt::registerCallback(callback_sig* fn, void* ctx) +{ + token_generated_cb_ = fn; + token_generated_ctx_ = ctx; +} + +template +void ParallelGpt::unRegisterCallback() +{ + token_generated_cb_ = nullptr; + token_generated_ctx_ = nullptr; +} + +template +void ParallelGpt::forward(std::vector* output_tensors, + const std::vector* input_tensors, const ParallelGptWeight* gpt_weights) { // input_tensors: @@ -419,7 +494,6 @@ void ParallelGpt::forward(std::vector* output_tensors, // output_tensors: // output_ids [batch_size, beam, max_output_seq_len] - // parent_ids [max_output_seq_len, batch_size, beam] // sequence_length [batch_size, beam] // output_log_probs [batch_size, beam, request_output_seq_len], must be float*. // It leads to additional computing cost. 
If we don't need this result, please put nullptr @@ -436,13 +510,12 @@ void ParallelGpt::forward(std::vector* output_tensors, {"input_lengths", input_tensors->at(1)}, {"max_output_seq_len", input_tensors->at(2)}}; input_tensors_map.insert({"random_seed", {MEMORY_CPU, TYPE_INT32, {1}, &random_seed_}}); - input_tensors_map.insert({"runtime_top_k", {MEMORY_CPU, TYPE_INT32, {1}, &top_k_}}); + input_tensors_map.insert({"runtime_top_k", {MEMORY_CPU, TYPE_UINT32, {1}, &top_k_}}); input_tensors_map.insert({"runtime_top_p", {MEMORY_CPU, TYPE_FP32, {1}, &top_p_}}); std::unordered_map output_tensors_map{{"output_ids", output_tensors->at(0)}, - {"parent_ids", output_tensors->at(1)}, - {"sequence_length", output_tensors->at(2)}, - {"output_log_probs", output_tensors->at(3)}}; + {"sequence_length", output_tensors->at(1)}, + {"output_log_probs", output_tensors->at(2)}}; if (output_tensors->size() > 3) { output_tensors_map.insert({"cum_log_probs", output_tensors->at(4)}); } @@ -450,31 +523,36 @@ void ParallelGpt::forward(std::vector* output_tensors, } template -void ParallelGpt::forward(std::unordered_map* output_tensors, +void ParallelGpt::forward(std::unordered_map* output_tensors, const std::unordered_map* input_tensors, - const ParallelGptWeight* gpt_weights) + const ParallelGptWeight* gpt_weights) { // input_tensors: // input_ids [batch_size, max_input_length] // input_lengths [batch_size] - // max_output_seq_len [1] on cpu + // prompt_learning_task_name_ids [batch_size] on cpu + // output_seq_len [batch_size] on cpu // stop_words_list [batch_size, 2, stop_words_length], optional + // bad_words_list [2, bad_words_length] or [batch_size, 2, bad_words_length], optional // start_id [batch_size] on cpu, optional // end_id [batch_size] on cpu, optional - // runtime_top_k [1] or [batch_size] on cpu, optional - // runtime_top_p [1] or [batch_size] on cpu, optional - // beam_search_diversity_rate [1] or [batch_size] on cpu, optional - // temperature [1] or [batch_size] on cpu, optional - // len_penalty [1] or [batch_size] on cpu, optional - // repetition_penalty [1] or [batch_size] on cpu, optional - // random_seed [1] or [batch_size] on cpu, optional - // prefix_soft_prompt_lengths [batch_size], optional - // prefix_soft_prompt_embedding [batch_size, max_prefix_soft_prompt_length, hidden_units], float, optional + // runtime_top_k [1] or [batch_size] on cpu, optional, uint. + // runtime_top_p [1] or [batch_size] on cpu, optional, float. + // beam_search_diversity_rate [1] or [batch_size] on cpu, optional, float. + // temperature [1] or [batch_size] on cpu, optional, float. + // len_penalty [1] or [batch_size] on cpu, optional, float. + // repetition_penalty [1] or [batch_size] on cpu, optional, float. + // random_seed [1] or [batch_size] on cpu, optional, unsigned long long int. + // request_prompt_lengths [batch_size], optional + // request_prompt_embedding [batch_size, max_prompt_length, hidden_units], float, optional + // request_prompt_type [batch_size], int, optional // is_return_context_cum_log_probs [1] on cpu, bool, optional + // session_len [1] on cpu, uint32, optional + // memory_len [1] on cpu, uint32, optional + // continue_gen [1] on cpu, bool, optional // output_tensors: // output_ids [batch_size, beam_width, max_output_seq_len] - // parent_ids [max_output_seq_len, batch_size, beam_width] // sequence_length [batch_size, beam_width] // output_log_probs [batch_size, beam_width, request_output_seq_len], must be float*. // optional. It leads to additional computing cost. 
If we don't need this result, don't put it. @@ -488,12 +566,12 @@ void ParallelGpt::forward(std::unordered_map* output_ten // the step 1 ~ max_output_seq_len of output_ids_buf_ to output_tensors->at(0).data FT_CHECK_WITH_INFO(input_tensors->size() >= 3, "input_tensors->size() >= 3"); - FT_CHECK_WITH_INFO(output_tensors->size() >= 3, "output_tensors->size() >= 3"); + FT_CHECK_WITH_INFO(output_tensors->size() >= 2, "output_tensors->size() >= 2"); FT_CHECK(input_tensors->at("input_ids").shape.size() == 2); FT_CHECK(input_tensors->at("input_lengths").shape.size() == 1); - FT_CHECK(input_tensors->at("max_output_seq_len").shape.size() == 1); + FT_CHECK(input_tensors->find("output_seq_len") != input_tensors->end() + && input_tensors->at("output_seq_len").shape.size() == 1); FT_CHECK(output_tensors->at("output_ids").shape.size() == 3); - FT_CHECK(output_tensors->at("parent_ids").shape.size() == 3); FT_CHECK(output_tensors->at("sequence_length").shape.size() == 2); FT_CHECK_WITH_INFO(input_tensors->at("input_ids").shape[0] == output_tensors->at("output_ids").shape[0], "input_tensors->at(\"input_ids\").shape[0] == output_tensors->at(\"output_ids\").shape[0]"); @@ -505,13 +583,120 @@ void ParallelGpt::forward(std::unordered_map* output_ten || output_tensors->at("cum_log_probs").size() == batch_size * beam_width, "The shape of cum_log_probs should match with batch_size x beam_width if provided."); int max_input_length = input_tensors->at("input_ids").shape[1]; - const size_t max_prefix_soft_prompt_length = input_tensors->count("prefix_soft_prompt_embedding") ? - input_tensors->at("prefix_soft_prompt_embedding").shape[1] : - 0; - const size_t max_output_seq_len = (*((int*)input_tensors->at("max_output_seq_len").data)) - + (max_input_length == 0 ? 1 : 0) // additional 1 to put start token - + max_prefix_soft_prompt_length; - const size_t max_seq_len = max_output_seq_len; + + bool continue_gen = input_tensors->find("continue_gen") != input_tensors->end() ? + input_tensors->at("continue_gen").getVal() : + false; + // triton backend will send START flag + if (input_tensors->find("START") != input_tensors->end()) { + continue_gen = !((bool)input_tensors->at("START").getVal()); + } + + const int initial_step = continue_gen ? step_ : 0; + int max_context_len = max_input_length + initial_step; + + // NOTE: the input already contains the p/prompt-tunning tokens ids for p/prompt tuning task + // prompt_learning_task_name_ids are used by both p/prompt-tunning and prefix_prompt task + const int* prompt_learning_task_name_ids = input_tensors->count("prompt_learning_task_name_ids") ? 
+ (const int*)input_tensors->at("prompt_learning_task_name_ids").data : + nullptr; + + FT_CHECK_WITH_INFO( + !(prompt_learning_task_name_ids != nullptr + && (prompt_learning_type_ == PromptLearningType::no_prompt + || prompt_learning_type_ == PromptLearningType::soft_prompt)), + "prompt_learning_type is prefix_prompt either p_prompt_tuning when prompt_learning_task_name_ids are provided."); + + PromptLearningType request_prompt_type = PromptLearningType::no_prompt; + int valid_prompt_inputs = input_tensors->count("request_prompt_type") + + input_tensors->count("request_prompt_lengths") + + input_tensors->count("request_prompt_embedding"); + + if (valid_prompt_inputs == 3) { + request_prompt_type = static_cast(input_tensors->at("request_prompt_type").getVal()); + if (prompt_learning_task_name_ids != nullptr) { + FT_LOG_INFO("Apply prompt embedding from input, will ignore task name ids"); + } + } + else if (valid_prompt_inputs > 0) { + FT_LOG_WARNING( + "Prompts not applied: request_prompt_embedding, request_prompt_lengths, request_prompt_type are all needed!"); + } + if (request_prompt_type == PromptLearningType::prefix_prompt) { + FT_LOG_WARNING("Request prompt doesn't support prefix prompt currently!"); + } + + // whether or not use prompt embeddings from the request. + // If true, staticlly loaded prompts weights during model loading and task name ids will be ignored + bool use_request_p_prompt_embedding = request_prompt_type == PromptLearningType::p_prompt_tuning; + int max_request_p_prompt_length = + use_request_p_prompt_embedding ? input_tensors->at("request_prompt_embedding").shape[1] : 0; + + has_prefix_prompt_ = + (prompt_learning_task_name_ids != nullptr && prompt_learning_type_ == PromptLearningType::prefix_prompt); + has_p_prompt_tuning_ = + prompt_learning_task_name_ids != nullptr && prompt_learning_type_ == PromptLearningType::p_prompt_tuning + || use_request_p_prompt_embedding; + bool use_loaded_p_prompt_embedding = has_p_prompt_tuning_ && !use_request_p_prompt_embedding; + has_prefix_soft_prompt_ = request_prompt_type == PromptLearningType::soft_prompt; + + // NOTE: soft prompt + const size_t max_prefix_soft_prompt_length = + has_prefix_soft_prompt_ ? input_tensors->at("request_prompt_embedding").shape[1] : 0; + const size_t limit_len_offset = max_prefix_soft_prompt_length + (max_input_length == 0 ? 1 : 0); + const size_t gen_len = input_tensors->at("output_seq_len").max() + limit_len_offset; + + size_t session_len = 0; + if (continue_gen) { + session_len = session_len_; // Record the size of allocated buffer in previous round. + } + else if (input_tensors->find("session_len") != input_tensors->end()) { + session_len = input_tensors->at("session_len").getVal(); // Use for allocate buffer in first round. + } + else { + session_len = gen_len; // When the interactive generation mode is disabled. + } + session_len_ = session_len; + FT_CHECK_WITH_INFO( + gen_len + initial_step <= session_len, + fmtstr("Session size too low (%d) vs. total output size (%d)", session_len, gen_len + initial_step)); + size_t memory_len = 0; + if (continue_gen) { + memory_len = memory_len_; // Record the size of allocated buffer in previous round. + } + else if (input_tensors->find("memory_len") != input_tensors->end()) { + memory_len = input_tensors->at("memory_len").getVal(); // Use for allocate buffer in first round. + } + else { + memory_len = session_len; // When the interactive generation mode is disabled. 
+ } + memory_len_ = memory_len; + /* TODO: could remove this constraint by changing how context decoder operates */ + FT_CHECK_WITH_INFO(max_input_length <= memory_len, + fmtstr("Memory size too low (%d) vs. input length (%d)", memory_len, max_input_length)); + + if (memory_len < session_len) { + FT_LOG_WARNING("memory_len (%d) is less than session_len (%d). " + "Note that this reduces the memory cost of k/v cache, but may hurt the accuracy.", + memory_len, + session_len); + } + else if (memory_len > session_len) { + FT_LOG_WARNING("memory_len (%d) is larger than session_len (%d). " + "This may lead to additional memory cost. Suggest to use smaller memory_len.", + memory_len, + session_len); + } + + if (session_len_ > gpt_weights->getMaxSeqLen()) { + FT_LOG_ERROR("The session_len_ (%d) of request is longer than max_seq_len (%d) of embedding table." + " This is a invalid input. Setting the session_len_ to %d.", + session_len_, + gpt_weights->getMaxSeqLen(), + gpt_weights->getMaxSeqLen()); + session_len_ = gpt_weights->getMaxSeqLen(); + } + const bool is_return_context_cum_log_probs = input_tensors->count("is_return_context_cum_log_probs") > 0 && input_tensors->at("is_return_context_cum_log_probs").getVal(); if (is_return_context_cum_log_probs) { @@ -521,56 +706,35 @@ void ParallelGpt::forward(std::unordered_map* output_ten "the cumulative log probability computation of input contexts."); } - allocateBuffer(batch_size, - beam_width, - max_seq_len, - max_input_length + max_prefix_soft_prompt_length, - is_return_context_cum_log_probs); - sync_check_cuda_error(); - bool has_diff_runtime_args = hasDiffRuntimeArgs(input_tensors); + if (!continue_gen) { + allocateBuffer(batch_size, + beam_width, + session_len, + memory_len, + max_input_length + max_prefix_soft_prompt_length, + is_return_context_cum_log_probs); + sync_check_cuda_error(); + } - int* sequence_lengths = (int*)(output_tensors->at("sequence_length").data); - const DataType data_type = getTensorType(); + setSeqLimitLen(seq_limit_len_, input_tensors->at("output_seq_len"), limit_len_offset, batch_size); + + const DataType data_type = getTensorType(); const cudaDataType_t gemm_data_type = getCudaDataType(); const std::vector self_k_cache_shape = {num_layer_ / pipeline_para_.world_size_, batch_size * beam_width, local_head_num_, size_per_head_ / (16 / sizeof(T)), - max_output_seq_len, + memory_len, 16 / sizeof(T)}; - const std::vector self_v_cache_shape = {num_layer_ / pipeline_para_.world_size_, - batch_size * beam_width, - local_head_num_, - max_output_seq_len, - size_per_head_}; + const std::vector self_v_cache_shape = { + num_layer_ / pipeline_para_.world_size_, batch_size * beam_width, local_head_num_, memory_len, size_per_head_}; + dynamic_decode_layer_->setup(batch_size, beam_width, input_tensors); handleOptArg(input_tensors, "start_id", start_ids_buf_, start_id_, batch_size); handleOptArg(input_tensors, "end_id", end_ids_buf_, end_id_, batch_size); - // TODO(bhsueh) Initilaize them in one kernel - // initialize the output ids and parent ids - cudaMemsetAsync(output_ids_buf_, 0, sizeof(int) * batch_size * beam_width * max_seq_len, stream_); - cudaMemsetAsync(parent_ids_buf_, 0, sizeof(int) * batch_size * beam_width * max_seq_len, stream_); - if (beam_width > 1) { - cudaMemsetAsync(cache_indirections_[0], 0, 2 * sizeof(int) * batch_size * beam_width * max_seq_len, stream_); - } - sync_check_cuda_error(); - - if (vocab_size_ == vocab_size_padded_) { - padded_embedding_kernel_ptr_ = gpt_weights->post_decoder_embedding.kernel; - } - else 
{ - cudaMemcpyAsync(padded_embedding_kernel_, - gpt_weights->post_decoder_embedding.kernel, - sizeof(T) * vocab_size_ * hidden_units_, - cudaMemcpyDeviceToDevice, - stream_); - sync_check_cuda_error(); - } - - // handle first step - if (input_tensors->count("prefix_soft_prompt_embedding") || max_input_length >= 1) { + if (continue_gen) { invokeTileGptInputs(tiled_input_ids_buf_, tiled_input_lengths_buf_, (int*)input_tensors->at("input_ids").data, @@ -579,169 +743,304 @@ void ParallelGpt::forward(std::unordered_map* output_ten beam_width, max_input_length, stream_); + invokePlusScalar(tiled_input_lengths_buf_, initial_step, batch_size * beam_width, stream_); + sync_check_cuda_error(); + invokeDecodingInitialize(finished_buf_, + sequence_lengths_, + nullptr, + cum_log_probs_, + start_ids_buf_, + batch_size, + beam_width, + initial_step - 1, + stream_); + invokeTransposeAxis01(output_ids_buf_ + initial_step * batch_size * beam_width, + tiled_input_ids_buf_, + batch_size * beam_width, + max_input_length, + 1, + stream_); + } + else { + // TODO(bhsueh) Initilaize them in one kernel + // initialize the output ids and parent ids + cudaMemsetAsync(output_ids_buf_, 0, sizeof(int) * batch_size * beam_width * session_len, stream_); + cudaMemsetAsync(parent_ids_buf_, 0, sizeof(int) * batch_size * beam_width * session_len, stream_); + cudaMemsetAsync(masked_tokens_, false, sizeof(bool) * batch_size * beam_width * memory_len, stream_); + cudaMemsetAsync(tiled_total_padding_count_, 0, sizeof(int) * batch_size * beam_width, stream_); + if (beam_width > 1) { + cudaMemsetAsync(cache_indirections_[0], 0, 2 * sizeof(int) * batch_size * beam_width * memory_len, stream_); + } sync_check_cuda_error(); - if (input_tensors->count("prefix_soft_prompt_embedding")) { - inputIdsEmbeddingLookupPosEncodingSoftPromptParam param; - param.from_tensor = context_decoder_input_buf_; - param.output_ids = output_ids_buf_; - param.input_lengths = tiled_input_lengths_buf_; - param.embedding_table = gpt_weights->pre_decoder_embedding_table; - param.pos_table = gpt_weights->position_encoding_table; - param.prefix_soft_prompt_embedding = input_tensors->at("prefix_soft_prompt_embedding").getPtr(); - param.prefix_soft_prompt_lengths = input_tensors->at("prefix_soft_prompt_lengths").getPtr(); - param.input_ids = tiled_input_ids_buf_; - param.start_step = 1; - param.max_input_length = max_input_length; - param.max_prefix_soft_prompt_length = max_prefix_soft_prompt_length; - param.batch_size = batch_size; - param.beam_width = beam_width; - param.hidden_units = hidden_units_; - param.stream = stream_; - - invokeInputIdsEmbeddingLookupPosEncodingSoftPrompt(param); - sync_check_cuda_error(); - max_input_length += max_prefix_soft_prompt_length; // view soft_prompt as input + if (vocab_size_ == vocab_size_padded_) { + padded_embedding_kernel_ptr_ = gpt_weights->post_decoder_embedding.kernel; } else { - invokeInputIdsEmbeddingLookupPosEncoding(context_decoder_input_buf_, - output_ids_buf_, - gpt_weights->pre_decoder_embedding_table, - gpt_weights->position_encoding_table, - tiled_input_ids_buf_, - 1, - max_input_length, - max_input_length, - batch_size * beam_width, - hidden_units_, - stream_); + cudaAutoCpy(padded_embedding_kernel_, + gpt_weights->post_decoder_embedding.kernel, + vocab_size_ * hidden_units_, + stream_); sync_check_cuda_error(); } - invokeBuildDecoderAttentionMask( - input_attention_mask_, tiled_input_lengths_buf_, batch_size * beam_width, max_input_length, stream_); - sync_check_cuda_error(); + int compact_size; + bool 
use_shared_contexts = (shared_contexts_ratio_ > 0.0f) && (max_input_length >= 1) && (batch_size > 1); + if (use_shared_contexts) { + invokeFindContextDups(shared_contexts_idx_, + batch_to_compact_idx_, + compact_idx_, + compact_size_, + input_tensors->at("input_ids").getPtr(), + batch_size, + max_input_length, + stream_); + cudaD2Hcpy(&compact_size, compact_size_, 1); + use_shared_contexts = compact_size <= shared_contexts_ratio_ * batch_size; + sync_check_cuda_error(); + } - std::vector decoder_input_tensors{ - Tensor{MEMORY_GPU, - data_type, - {batch_size * beam_width, (size_t)max_input_length, hidden_units_}, - context_decoder_input_buf_}, - Tensor{MEMORY_GPU, - data_type, - {batch_size * beam_width, 1, (size_t)max_input_length, (size_t)max_input_length}, - input_attention_mask_}, - Tensor{MEMORY_GPU, TYPE_INT32, {batch_size * beam_width}, tiled_input_lengths_buf_}}; - - std::vector decoder_output_tensors{ - Tensor{MEMORY_GPU, - data_type, - {batch_size * beam_width, (size_t)max_input_length, hidden_units_}, - context_decoder_output_buf_}, - Tensor{MEMORY_GPU, data_type, self_k_cache_shape, key_cache_}, - Tensor{MEMORY_GPU, data_type, self_v_cache_shape, value_cache_}, - Tensor{MEMORY_GPU, data_type, {batch_size * beam_width, hidden_units_}, decoder_output_buf_}}; - - gpt_context_decoder_->forward( - &decoder_output_tensors, &decoder_input_tensors, &gpt_weights->decoder_layer_weights); + // NOTE: p/prompt-tuning process here (lookup prompt embedding tables by task name ids) + // get p/prompt-tuning weight for each batch --> shape [batch, beam_width] + // --> ptrs with shape [prompt_len, hidden_size] + std::vector p_prompt_tuning_batch_ptrs; + std::vector p_prompt_tuning_lengths; + if (use_loaded_p_prompt_embedding) { + for (int bs_id = 0; bs_id < batch_size; ++bs_id) { + int task_id = prompt_learning_task_name_ids[bs_id]; + std::pair p_prompt_tuning_pair = {}; + bool valid_task_name_id = task_id < gpt_weights->prompt_learning_table.size(); + if (valid_task_name_id) { + p_prompt_tuning_pair = gpt_weights->prompt_learning_table.at(task_id); + } + else { + // don't throw oor in case of model server failing + FT_LOG_ERROR("p_prompt_tuning_weights not found for task id: " + std::to_string(task_id) + + "\n return with invalid output tensors"); + return; + } + for (int bw_id = 0; bw_id < beam_width; ++bw_id) { + // only weight ptrs needed here + p_prompt_tuning_batch_ptrs.push_back(p_prompt_tuning_pair.first); + p_prompt_tuning_lengths.push_back(p_prompt_tuning_pair.second); + } + } - invokeDecodingInitialize(finished_buf_, - sequence_lengths, - nullptr, - cum_log_probs_, - start_ids_buf_, - batch_size, - beam_width, - max_input_length - 1, - stream_); + cudaAutoCpy( + prompt_learning_weight_batch_, p_prompt_tuning_batch_ptrs.data(), batch_size * beam_width, stream_); + + cudaAutoCpy(tiled_prompt_lengths_buf_, p_prompt_tuning_lengths.data(), batch_size * beam_width, stream_); + + sync_check_cuda_error(); + } - if (is_return_context_cum_log_probs) { - computeContextCumLogProbs(cum_log_probs_, - context_decoder_output_buf_, - tiled_input_ids_buf_, + // handle first step + if (has_p_prompt_tuning_ || has_prefix_prompt_ || has_prefix_soft_prompt_ || max_input_length > 1) { + invokeTileGptPromptInputs(tiled_input_ids_buf_, tiled_input_lengths_buf_, + use_request_p_prompt_embedding ? tiled_prompt_lengths_buf_ : nullptr, + (int*)input_tensors->at("input_ids").data, + (const int*)(input_tensors->at("input_lengths").data), + use_request_p_prompt_embedding ? 
+ (const int*)(input_tensors->at("request_prompt_lengths").data) : + nullptr, batch_size, beam_width, - (size_t)max_input_length, - gpt_weights); + max_input_length, + stream_); + sync_check_cuda_error(); + + if (has_prefix_soft_prompt_) { + inputIdsEmbeddingLookupPosEncodingSoftPromptParam param; + param.from_tensor = context_decoder_input_buf_; + param.output_ids = output_ids_buf_; + param.input_lengths = tiled_input_lengths_buf_; + param.embedding_table = gpt_weights->pre_decoder_embedding_table; + param.pos_table = gpt_weights->position_encoding_table; + param.prefix_soft_prompt_embedding = input_tensors->at("request_prompt_embedding").getPtr(); + param.prefix_soft_prompt_lengths = input_tensors->at("request_prompt_lengths").getPtr(); + param.input_ids = tiled_input_ids_buf_; + param.start_step = 1; + param.max_input_length = max_input_length; + param.max_prefix_soft_prompt_length = max_prefix_soft_prompt_length; + param.batch_size = batch_size; + param.beam_width = beam_width; + param.hidden_units = hidden_units_; + param.stream = stream_; + + invokeInputIdsEmbeddingLookupPosEncodingSoftPrompt(param); + sync_check_cuda_error(); + max_input_length += max_prefix_soft_prompt_length; // view soft_prompt as input + } + else { + // NOTE: add prompt embeddings here (for p/prompt tuning) + pPromptTuningParam prompt_param{ + use_loaded_p_prompt_embedding ? prompt_learning_weight_batch_ : (const T**)nullptr, + prompt_learning_start_id_, + max_request_p_prompt_length, + use_request_p_prompt_embedding, + use_request_p_prompt_embedding ? input_tensors->at("request_prompt_embedding").getPtr() : + nullptr}; + invokeInputIdsEmbeddingLookupPosEncoding(context_decoder_input_buf_, + output_ids_buf_, + gpt_weights->pre_decoder_embedding_table, + gpt_weights->position_encoding_table, + prompt_param, + tiled_input_ids_buf_, + 1, + max_input_length, + max_input_length, + batch_size * beam_width, + hidden_units_, + stream_); + sync_check_cuda_error(); + } + + invokeBuildDecoderAttentionMask(input_attention_mask_, + tiled_input_lengths_buf_, + nullptr, + batch_size * beam_width, + max_input_length, + 0, + stream_); + sync_check_cuda_error(); + + std::vector decoder_input_tensors{ + Tensor{MEMORY_GPU, + data_type, + {batch_size * beam_width, (size_t)max_input_length, hidden_units_}, + context_decoder_input_buf_}, + Tensor{MEMORY_GPU, + data_type, + {batch_size * beam_width, 1, (size_t)max_input_length, (size_t)max_input_length}, + input_attention_mask_}, + Tensor{MEMORY_GPU, TYPE_INT32, {batch_size * beam_width}, tiled_input_lengths_buf_}}; + + if (use_shared_contexts) { + decoder_input_tensors.push_back({MEMORY_GPU, TYPE_INT32, {(size_t)compact_size}, compact_idx_}); + decoder_input_tensors.push_back({MEMORY_GPU, TYPE_INT32, {batch_size}, batch_to_compact_idx_}); + } + + std::vector decoder_output_tensors{ + Tensor{MEMORY_GPU, + data_type, + {batch_size * beam_width, (size_t)max_input_length, hidden_units_}, + context_decoder_output_buf_}, + Tensor{MEMORY_GPU, data_type, self_k_cache_shape, key_cache_}, + Tensor{MEMORY_GPU, data_type, self_v_cache_shape, value_cache_}, + Tensor{MEMORY_GPU, data_type, {batch_size * beam_width, hidden_units_}, decoder_output_buf_}}; + + gpt_context_decoder_->forward( + &decoder_output_tensors, &decoder_input_tensors, &gpt_weights->decoder_layer_weights); + + invokeDecodingInitialize(finished_buf_, + sequence_lengths_, + nullptr, + cum_log_probs_, + start_ids_buf_, + batch_size, + beam_width, + max_input_length - 1, + stream_); + + if (is_return_context_cum_log_probs) { + 
computeContextCumLogProbs(cum_log_probs_, + context_decoder_output_buf_, + tiled_input_ids_buf_, + tiled_input_lengths_buf_, + batch_size, + beam_width, + (size_t)max_input_length, + gpt_weights); + } + sync_check_cuda_error(); + } + else if (max_input_length == 0) { + FT_CHECK(prompt_learning_type_ == PromptLearningType::no_prompt + && request_prompt_type == PromptLearningType::no_prompt); + max_input_length++; + invokeDecodingInitialize(finished_buf_, + sequence_lengths_, + output_ids_buf_, + cum_log_probs_, + start_ids_buf_, + batch_size, + beam_width, + max_input_length - 1, + stream_); + std::vector h_input_lengths(batch_size * beam_width, 1); + cudaAutoCpy(tiled_input_lengths_buf_, h_input_lengths.data(), batch_size * beam_width, stream_); + sync_check_cuda_error(); + } + else if (max_input_length == 1) { + FT_CHECK(prompt_learning_type_ == PromptLearningType::no_prompt + && request_prompt_type == PromptLearningType::no_prompt); + invokeDecodingInitialize(finished_buf_, + sequence_lengths_, + nullptr, + cum_log_probs_, + start_ids_buf_, + batch_size, + beam_width, + max_input_length - 1, + stream_); + sync_check_cuda_error(); + invokeTileGptInputs(tiled_input_ids_buf_, + tiled_input_lengths_buf_, + (int*)input_tensors->at("input_ids").data, + (const int*)(input_tensors->at("input_lengths").data), + batch_size, + beam_width, + max_input_length, + stream_); + sync_check_cuda_error(); + + cudaAutoCpy(output_ids_buf_, tiled_input_ids_buf_, batch_size * beam_width, stream_); } - sync_check_cuda_error(); - } - else if (max_input_length == 0) { - max_input_length++; - invokeDecodingInitialize(finished_buf_, - sequence_lengths, - output_ids_buf_, - cum_log_probs_, - start_ids_buf_, - batch_size, - beam_width, - max_input_length - 1, - stream_); - std::vector h_input_lengths(batch_size * beam_width, 1); - cudaMemcpyAsync(tiled_input_lengths_buf_, - h_input_lengths.data(), - sizeof(int) * batch_size * beam_width, - cudaMemcpyHostToDevice, - stream_); - sync_check_cuda_error(); } - else if (max_input_length == 1) { - invokeDecodingInitialize(finished_buf_, - sequence_lengths, - nullptr, - cum_log_probs_, - start_ids_buf_, - batch_size, - beam_width, - max_input_length - 1, - stream_); - sync_check_cuda_error(); - invokeTileGptInputs(tiled_input_ids_buf_, - tiled_input_lengths_buf_, - (int*)input_tensors->at("input_ids").data, + + invokeMaskPaddingTokens(masked_tokens_, (const int*)(input_tensors->at("input_lengths").data), + memory_len, + max_input_length, + initial_step, batch_size, beam_width, - max_input_length, stream_); - sync_check_cuda_error(); - - cudaMemcpyAsync(output_ids_buf_, - tiled_input_ids_buf_, - sizeof(int) * batch_size * beam_width, - cudaMemcpyDeviceToDevice, - stream_); - } - for (int step = max_input_length; step < (int)max_output_seq_len; step++) { - const int src_indir_idx = (step - max_input_length) % 2; - const int tgt_indir_idx = 1 - src_indir_idx; + // If continue, we restart from initial_step because last token hasn't been processed in decoder + const int step_start = continue_gen ? initial_step : max_input_length; + for (step_ = step_start; step_ < (int)gen_len; step_++) { + // Loop body produces Nth token by embedding && encoding token (N-1) + // if necessary. 
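// [Editor's note] Sketch of the double-buffered cache indirection used in this loop:
// for beam search, the decoder reads the beam-reordering table from
// cache_indirections_[src_indir_idx] while dynamic decode writes the updated table into
// cache_indirections_[tgt_indir_idx], and the two roles swap every step via the parity
// of (step_ - step_start). Standalone illustration of the index arithmetic only, not
// the FasterTransformer code itself.
#include <cstdio>

int main()
{
    const int step_start = 4;  // e.g. max_input_length, or initial_step when continuing
    for (int step = step_start; step < step_start + 4; ++step) {
        const int src_indir_idx = (step - step_start) % 2;  // buffer read this step
        const int tgt_indir_idx = 1 - src_indir_idx;        // buffer written this step
        std::printf("step %d: read cache_indirections_[%d], write cache_indirections_[%d]\n",
                    step, src_indir_idx, tgt_indir_idx);
    }
    return 0;
}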
+ const bool fill_caches_only = continue_gen && (step_ < max_context_len); + const int src_indir_idx = (step_ - step_start) % 2; + const int tgt_indir_idx = 1 - src_indir_idx; const size_t local_batch_size = getLocalBatchSize(batch_size, 1, pipeline_para_.world_size_); FT_CHECK(batch_size % local_batch_size == 0); const size_t iteration_num = batch_size / local_batch_size; + *generation_should_stop_ = !fill_caches_only; for (uint ite = 0; ite < iteration_num; ++ite) { - const int id_offset = ite * local_batch_size * beam_width; - const int hidden_units_offset = id_offset * hidden_units_; + const int id_offset = ite * local_batch_size * beam_width; + const int hidden_units_offset = id_offset * hidden_units_; const int vocab_size_units_offset = id_offset * vocab_size_padded_; - if (!(max_input_length > 1 && step == max_input_length)) { + if ((max_input_length <= 1) || (step_ > step_start) || continue_gen) { if (pipeline_para_.rank_ == 0) { - invokeEmbeddingLookupPosEncoding(decoder_input_buf_ + hidden_units_offset, - gpt_weights->pre_decoder_embedding_table, - gpt_weights->position_encoding_table, - output_ids_buf_ + id_offset, - tiled_input_lengths_buf_ + id_offset, - local_batch_size * beam_width, - hidden_units_, - (T)(1.0f), - step - 1, - max_input_length, - batch_size * beam_width, - 0, - stream_); + invokeEmbeddingLookupPosEncodingPadCount(decoder_input_buf_ + hidden_units_offset, + gpt_weights->pre_decoder_embedding_table, + gpt_weights->position_encoding_table, + output_ids_buf_ + id_offset, + tiled_total_padding_count_ + id_offset, + local_batch_size * beam_width, + hidden_units_, + (T)(1.0f), + step_ - 1, + batch_size * beam_width, + 0, + stream_); sync_check_cuda_error(); } @@ -751,17 +1050,22 @@ void ParallelGpt::forward(std::unordered_map* output_ten {local_batch_size * beam_width, hidden_units_}, decoder_input_buf_ + hidden_units_offset}, Tensor{MEMORY_GPU, TYPE_BOOL, {local_batch_size * beam_width}, finished_buf_ + id_offset}, - Tensor{MEMORY_GPU, TYPE_INT32, {local_batch_size * beam_width}, sequence_lengths + id_offset}, - Tensor{ - MEMORY_GPU, TYPE_INT32, {local_batch_size * beam_width}, tiled_input_lengths_buf_ + id_offset}, - Tensor{MEMORY_CPU, TYPE_INT32, {1}, &max_input_length}, - Tensor{MEMORY_CPU, TYPE_INT32, {1}, &step}, + Tensor{MEMORY_GPU, TYPE_INT32, {local_batch_size * beam_width}, sequence_lengths_ + id_offset}, + Tensor{MEMORY_GPU, + TYPE_INT32, + {local_batch_size * beam_width}, + tiled_total_padding_count_ + id_offset}, + Tensor{MEMORY_CPU, TYPE_INT32, {1}, &max_context_len}, + Tensor{MEMORY_CPU, TYPE_INT32, {1}, &step_}, Tensor{MEMORY_CPU, TYPE_INT32, {1}, &ite}, Tensor{MEMORY_GPU, TYPE_INT32, - {local_batch_size, beam_width, max_output_seq_len}, - beam_width > 1 ? cache_indirections_[src_indir_idx] + id_offset * max_output_seq_len : - nullptr}}; + {local_batch_size, beam_width, memory_len}, + beam_width > 1 ? 
cache_indirections_[src_indir_idx] + id_offset * memory_len : nullptr}, + Tensor{MEMORY_GPU, + TYPE_BOOL, + {local_batch_size * beam_width, memory_len}, + masked_tokens_ + id_offset * memory_len}}; std::vector decoder_output_tensors{ Tensor{MEMORY_GPU, @@ -774,19 +1078,25 @@ void ParallelGpt::forward(std::unordered_map* output_ten &decoder_output_tensors, &decoder_input_tensors, &gpt_weights->decoder_layer_weights); } - if (pipeline_para_.rank_ == pipeline_para_.world_size_ - 1) { - invokeGeneralLayerNorm(normed_decoder_output_buf_ + hidden_units_offset, - decoder_output_buf_ + hidden_units_offset, - gpt_weights->post_decoder_layernorm.gamma, - gpt_weights->post_decoder_layernorm.beta, - local_batch_size * beam_width, - hidden_units_, - stream_); + if (!fill_caches_only && pipeline_para_.rank_ == pipeline_para_.world_size_ - 1) { + // OPT + T* decoder_output_final_buf = + gpt_variant_params_.has_post_decoder_layernorm ? normed_decoder_output_buf_ : decoder_output_buf_; + if (gpt_variant_params_.has_post_decoder_layernorm) { + invokeGeneralLayerNorm(normed_decoder_output_buf_ + hidden_units_offset, + decoder_output_buf_ + hidden_units_offset, + gpt_weights->post_decoder_layernorm.gamma, + gpt_weights->post_decoder_layernorm.beta, + layernorm_eps_, + local_batch_size * beam_width, + hidden_units_, + stream_); + } sync_check_cuda_error(); if (tensor_para_.world_size_ == 1) { float alpha = 1.0f; - float beta = 0.0f; + float beta = 0.0f; cublas_wrapper_->Gemm(CUBLAS_OP_T, CUBLAS_OP_N, vocab_size_padded_, // n @@ -795,8 +1105,8 @@ void ParallelGpt::forward(std::unordered_map* output_ten &alpha, padded_embedding_kernel_ptr_, gemm_data_type, - hidden_units_, // k - normed_decoder_output_buf_ + hidden_units_offset, + hidden_units_, // k + decoder_output_final_buf + hidden_units_offset, // OPT: no final layer norm gemm_data_type, hidden_units_, // k &beta, @@ -809,8 +1119,8 @@ void ParallelGpt::forward(std::unordered_map* output_ten else { FT_CHECK(vocab_size_padded_ % tensor_para_.world_size_ == 0); const int local_vocab_size = vocab_size_padded_ / tensor_para_.world_size_; - float alpha = 1.0f; - float beta = 0.0f; + float alpha = 1.0f; + float beta = 0.0f; cublas_wrapper_->Gemm(CUBLAS_OP_T, CUBLAS_OP_N, local_vocab_size, // n @@ -820,8 +1130,8 @@ void ParallelGpt::forward(std::unordered_map* output_ten padded_embedding_kernel_ptr_ + tensor_para_.rank_ * local_vocab_size * hidden_units_, gemm_data_type, - hidden_units_, // k - normed_decoder_output_buf_ + hidden_units_offset, + hidden_units_, // k + decoder_output_final_buf + hidden_units_offset, // OPT: no final layer norm gemm_data_type, hidden_units_, // k &beta, @@ -837,7 +1147,6 @@ void ParallelGpt::forward(std::unordered_map* output_ten tensor_para_.rank_, tensor_para_, stream_); - check_cuda_error(cudaStreamSynchronize(stream_)); invokeTransposeAxis01(logits_buf_ + vocab_size_units_offset, nccl_logits_buf_ + vocab_size_units_offset, tensor_para_.world_size_, @@ -846,26 +1155,26 @@ void ParallelGpt::forward(std::unordered_map* output_ten stream_); } - int tmp_local_batch_size = local_batch_size; - bool is_initialize_random_table = step == max_input_length; + int tmp_local_batch_size = local_batch_size; + bool is_initialize_random_table = step_ == max_context_len; std::unordered_map dynamic_decode_input_tensors{ {"logits", Tensor{MEMORY_GPU, TYPE_FP32, {batch_size, beam_width, vocab_size_padded_}, logits_buf_}}, {"embedding_bias", Tensor{MEMORY_GPU, data_type, {vocab_size_padded_}, nullptr}}, - {"step", Tensor{MEMORY_CPU, TYPE_INT32, {1}, 
&step}}, - {"max_input_length", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &max_input_length}}, + {"step", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &step_}}, + {"max_input_length", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &max_context_len}}, + {"sequence_limit_length", Tensor{MEMORY_GPU, TYPE_UINT32, {batch_size}, seq_limit_len_}}, {"end_id", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size}, end_ids_buf_}}, {"input_lengths", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size, beam_width}, tiled_input_lengths_buf_}}, {"ite", Tensor{MEMORY_CPU, TYPE_UINT32, {1}, &ite}}, - {"has_diff_runtime_args", Tensor{MEMORY_CPU, TYPE_BOOL, {1}, &has_diff_runtime_args}}, {"src_key_cache", Tensor{MEMORY_GPU, data_type, self_k_cache_shape, key_cache_}}, {"src_value_cache", Tensor{MEMORY_GPU, data_type, self_v_cache_shape, value_cache_}}, {"src_cache_indirection", Tensor{MEMORY_GPU, TYPE_INT32, - {local_batch_size, beam_width, max_output_seq_len}, - cache_indirections_[src_indir_idx] + id_offset * max_output_seq_len}}, + {local_batch_size, beam_width, memory_len}, + cache_indirections_[src_indir_idx] + id_offset * memory_len}}, {"local_batch_size", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &tmp_local_batch_size}}, {"is_initialize_random_table", Tensor{MEMORY_CPU, TYPE_BOOL, {1}, &is_initialize_random_table}}}; @@ -876,9 +1185,9 @@ void ParallelGpt::forward(std::unordered_map* output_ten } // common outputs + bool subbatch_should_stop = false; std::unordered_map dynamic_decode_output_tensors{ - {"output_ids", - Tensor{MEMORY_GPU, TYPE_INT32, {max_seq_len, batch_size, beam_width}, output_ids_buf_}}, + {"output_ids", Tensor{MEMORY_GPU, TYPE_INT32, {gen_len, batch_size, beam_width}, output_ids_buf_}}, {"finished", Tensor{MEMORY_GPU, TYPE_BOOL, {batch_size * beam_width}, finished_buf_}}, // cum_log_probs is necessary for beam search, while it is optional for sampling. {"cum_log_probs", @@ -890,19 +1199,19 @@ void ParallelGpt::forward(std::unordered_map* output_ten {"output_log_probs", Tensor{MEMORY_GPU, TYPE_FP32, - {max_seq_len, batch_size, beam_width}, + {gen_len, batch_size, beam_width}, output_tensors->count("output_log_probs") > 0 && output_tensors->at("output_log_probs").data != nullptr ? output_log_probs_buf_ : nullptr}}, - {"parent_ids", - Tensor{MEMORY_GPU, TYPE_INT32, {max_seq_len, batch_size, beam_width}, parent_ids_buf_}}, - {"sequence_length", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size * beam_width}, sequence_lengths}}, + {"parent_ids", Tensor{MEMORY_GPU, TYPE_INT32, {gen_len, batch_size, beam_width}, parent_ids_buf_}}, + {"sequence_length", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size * beam_width}, sequence_lengths_}}, {"tgt_cache_indirection", Tensor{MEMORY_GPU, TYPE_INT32, - {local_batch_size, beam_width, max_output_seq_len}, - cache_indirections_[tgt_indir_idx] + id_offset * max_output_seq_len}}}; + {local_batch_size, beam_width, memory_len}, + cache_indirections_[tgt_indir_idx] + id_offset * memory_len}}, + {"should_stop", Tensor{MEMORY_CPU, TYPE_BOOL, {1}, &subbatch_should_stop}}}; for (auto t = output_tensors->begin(); t != output_tensors->end(); ++t) { // Handle exceptions. 
if (t->first == "cum_log_probs" || t->first == "output_log_probs") { @@ -912,189 +1221,193 @@ void ParallelGpt::forward(std::unordered_map* output_ten } dynamic_decode_layer_->forward(&dynamic_decode_output_tensors, &dynamic_decode_input_tensors); + *generation_should_stop_ &= subbatch_should_stop; } } - if (pipeline_para_.world_size_ > 1) { - NCCLCHECK(ncclGroupStart()); - ftNcclBroadCast(output_ids_buf_ + step * batch_size * beam_width, + if (fill_caches_only) { + invokePlusScalar(sequence_lengths_, 1, batch_size * beam_width, stream_); + } + else if (pipeline_para_.world_size_ > 1) { + ftNcclGroupStart(); + ftNcclBroadCast(output_ids_buf_ + step_ * batch_size * beam_width, batch_size * beam_width, pipeline_para_.world_size_ - 1, pipeline_para_, stream_); ftNcclBroadCast( - sequence_lengths, batch_size * beam_width, pipeline_para_.world_size_ - 1, pipeline_para_, stream_); + sequence_lengths_, batch_size * beam_width, pipeline_para_.world_size_ - 1, pipeline_para_, stream_); - ftNcclBroadCast( - finished_buf_, batch_size * beam_width, pipeline_para_.world_size_ - 1, pipeline_para_, stream_); + ftNcclBroadCast(generation_should_stop_, 1, pipeline_para_.world_size_ - 1, pipeline_para_, stream_); if (beam_width > 1) { - ftNcclBroadCast(cache_indirections_[tgt_indir_idx], - batch_size * beam_width * max_output_seq_len, + batch_size * beam_width * memory_len, pipeline_para_.world_size_ - 1, pipeline_para_, stream_); } - NCCLCHECK(ncclGroupEnd()); - check_cuda_error(cudaStreamSynchronize(stream_)); + ftNcclGroupEnd(); + // throw errors when detected + ftNcclStreamSynchronize(tensor_para_, pipeline_para_, stream_); sync_check_cuda_error(); } - cudaD2Hcpy(h_finished_buf_, finished_buf_, batch_size * beam_width); - uint sum = 0; - for (uint i = 0; i < batch_size * beam_width; i++) { - sum += (int)h_finished_buf_[i]; - } - if (sum == batch_size * beam_width) { + if (*generation_should_stop_) { break; } - } - - if (pipeline_para_.rank_ == pipeline_para_.world_size_ - 1) { - if (input_tensors->at("input_ids").shape[1] == 0) { - if (beam_width > 1) { - // For beam search, do gather_tree - invokeGatherTree(transposed_output_ids_buf_, - sequence_lengths, - max_output_seq_len, - batch_size, - beam_width, - output_ids_buf_ + batch_size * beam_width, - parent_ids_buf_ + batch_size * beam_width, - end_ids_buf_, - stream_); - // transpose and take output_parent_ids as inter buffer - invokeTransposeAxis01((int*)output_tensors->at("output_ids").data, - transposed_output_ids_buf_, - max_output_seq_len - 1, - batch_size * beam_width, - 1, - stream_); + if (token_generated_cb_ && step_ + 1 < (int)gen_len) { + setOutputTensors(output_tensors, input_tensors, gen_len, session_len, max_context_len); + sendTensorsToFirstPipelineNode(output_tensors, input_tensors); - cudaD2Dcpy((int*)output_tensors->at("parent_ids").data, - parent_ids_buf_ + batch_size * beam_width, - batch_size * beam_width * (max_output_seq_len - 1)); - } - else { - // For sampling, only transpose the results to output_tensor - invokeTransposeAxis01((int*)output_tensors->at("output_ids").data, - output_ids_buf_ + batch_size * beam_width, - max_output_seq_len - 1, - batch_size * beam_width, - 1, - stream_); + if (pipeline_para_.rank_ == 0 && tensor_para_.rank_ == 0) { + token_generated_cb_(output_tensors, token_generated_ctx_); } } - else { - // add sequence_length 1 here because the sequence_length of time step t is t - 1 - invokePlusScalar(sequence_lengths, 1, batch_size * beam_width, stream_); - - // For sampling, it is equivalent to all parent 
ids are 0. - gatherTreeParam param; - param.beams = transposed_output_ids_buf_; - param.max_sequence_lengths = sequence_lengths; - param.max_time = max_output_seq_len; - param.batch_size = batch_size; - param.beam_width = beam_width; - param.step_ids = output_ids_buf_; - param.parent_ids = beam_width == 1 ? nullptr : parent_ids_buf_; - param.end_tokens = end_ids_buf_; - param.max_input_length = max_input_length; - param.prefix_soft_prompt_lengths = input_tensors->count("prefix_soft_prompt_lengths") ? - input_tensors->at("prefix_soft_prompt_lengths").getPtr() : - nullptr; - param.input_lengths = tiled_input_lengths_buf_; - param.max_prefix_soft_prompt_length = max_prefix_soft_prompt_length; - param.stream = stream_; - param.output_ids = (int*)output_tensors->at("output_ids").data; - invokeGatherTree(param); - sync_check_cuda_error(); - } - if (output_tensors->count("output_log_probs")) { - invokeTransposeAxis01(output_tensors->at("output_log_probs").getPtr(), - output_log_probs_buf_, - input_tensors->at("max_output_seq_len").getVal() - max_input_length, - batch_size * beam_width, - 1, - stream_); - } - // Return the cumulative log probability if requested. - if (output_tensors->count("cum_log_probs") > 0 && output_tensors->at("cum_log_probs").data != nullptr) { - Tensor cum_log_probs = output_tensors->at("cum_log_probs"); - FT_CHECK_WITH_INFO(cum_log_probs.size() == batch_size * beam_width, - "The shape of cum_log_probs does not match with batch_size x beam_width."); - cudaD2Dcpy(cum_log_probs.getPtr(), cum_log_probs_, batch_size * beam_width); + if (step_ == initial_step + max_input_length) { + /* We have just finished processing input: update the padding count: + * total_padding_count += (max_input_length - input_lengths) */ + invokeUpdatePaddingCount(tiled_total_padding_count_, + (const int*)(input_tensors->at("input_lengths").data), + max_input_length, + batch_size, + beam_width, + stream_); } } + setOutputTensors(output_tensors, input_tensors, gen_len, session_len, max_context_len); + sendTensorsToFirstPipelineNode(output_tensors, input_tensors); +} - if (pipeline_para_.world_size_ > 1) { - NCCLCHECK(ncclGroupStart()); - if (pipeline_para_.rank_ == pipeline_para_.world_size_ - 1) { - ftNcclSend(output_tensors->at("output_ids").getPtr(), - batch_size * beam_width * max_output_seq_len, - 0, - pipeline_para_, - stream_); - - ftNcclSend(output_tensors->at("sequence_length").getPtr(), - batch_size * beam_width, - 0, - pipeline_para_, - stream_); +template +void ParallelGpt::sendTensorsToFirstPipelineNode(std::unordered_map* output_tensors, + const std::unordered_map* input_tensors) +{ + if (pipeline_para_.world_size_ == 1) { + // throw errors when detected + ftNcclStreamSynchronize(tensor_para_, pipeline_para_, stream_); + return; + } - if (output_tensors->count("cum_log_probs") > 0 && output_tensors->at("cum_log_probs").data != nullptr) { - ftNcclSend(output_tensors->at("cum_log_probs").getPtr(), - batch_size * beam_width, - 0, - pipeline_para_, - stream_); - } + const auto pp_rank = pipeline_para_.rank_; - if (output_tensors->count("output_log_probs") > 0 - && output_tensors->at("output_log_probs").data != nullptr) { - ftNcclSend(output_tensors->at("output_log_probs").getPtr(), - output_tensors->at("output_log_probs").size(), - 0, - pipeline_para_, - stream_); - } + ftNcclGroupStart(); + for (auto const& it : *output_tensors) { + if (it.second.data == nullptr) { + continue; } - else if (pipeline_para_.rank_ == 0) { - ftNcclRecv(output_tensors->at("output_ids").getPtr(), - batch_size * 
beam_width * max_output_seq_len, - pipeline_para_.world_size_ - 1, - pipeline_para_, - stream_); - ftNcclRecv(output_tensors->at("sequence_length").getPtr(), - batch_size * beam_width, + if (pp_rank == pipeline_para_.world_size_ - 1) { + ftNcclSend(it.second.getPtr(), it.second.sizeBytes(), 0, pipeline_para_, stream_); + } + else if (pp_rank == 0) { + ftNcclRecv(it.second.getPtr(), + it.second.sizeBytes(), pipeline_para_.world_size_ - 1, pipeline_para_, stream_); + } + } + ftNcclGroupEnd(); + // throw errors when detected + ftNcclStreamSynchronize(tensor_para_, pipeline_para_, stream_); +} - if (output_tensors->count("cum_log_probs") > 0 && output_tensors->at("cum_log_probs").data != nullptr) { - ftNcclRecv(output_tensors->at("cum_log_probs").getPtr(), - batch_size * beam_width, - pipeline_para_.world_size_ - 1, - pipeline_para_, - stream_); - } +template +void ParallelGpt::setOutputTensors(std::unordered_map* output_tensors, + const std::unordered_map* input_tensors, + const size_t gen_len, + const size_t session_len, + const size_t max_context_len) +{ + if (pipeline_para_.rank_ != pipeline_para_.world_size_ - 1) { + return; + } - if (output_tensors->count("output_log_probs") > 0 - && output_tensors->at("output_log_probs").data != nullptr) { - ftNcclRecv(output_tensors->at("output_log_probs").getPtr(), - output_tensors->at("output_log_probs").size(), - pipeline_para_.world_size_ - 1, - pipeline_para_, - stream_); - } + const size_t batch_size = output_tensors->at("output_ids").shape[0]; + const size_t beam_width = output_tensors->at("output_ids").shape[1]; + int* sequence_lengths = output_tensors->at("sequence_length").getPtr(); + const int max_input_length = input_tensors->at("input_ids").shape[1]; + const bool has_prefix_soft_prompt_ = input_tensors->count("prefix_soft_prompt_embedding"); + const size_t max_prefix_soft_prompt_length = + has_prefix_soft_prompt_ ? input_tensors->at("request_prompt_embedding").shape[1] : 0; + + cudaAutoCpy(sequence_lengths, sequence_lengths_, output_tensors->at("sequence_length").size(), stream_); + if (input_tensors->at("input_ids").shape[1] == 0) { + // TODO: D2D sequence_lenghts + if (beam_width > 1) { + // For beam search, do gather_tree + // take output_parent_ids as inter buffer + invokeGatherTree(transposed_output_ids_buf_, + sequence_lengths_, + session_len, + batch_size, + beam_width, + output_ids_buf_ + batch_size * beam_width, + parent_ids_buf_ + batch_size * beam_width, + end_ids_buf_, + stream_); + + // transpose and take output_parent_ids as inter buffer + invokeTransposeAxis01((int*)output_tensors->at("output_ids").data, + transposed_output_ids_buf_, + gen_len - 1, + batch_size * beam_width, + 1, + stream_); } - NCCLCHECK(ncclGroupEnd()); - check_cuda_error(cudaStreamSynchronize(stream_)); + else { + // For sampling, only copy the results to output_tensor + invokeTransposeAxis01((int*)output_tensors->at("output_ids").data, + output_ids_buf_ + batch_size * beam_width, + gen_len - 1, + batch_size * beam_width, + 1, + stream_); + } + } + else { + // add sequence_length 1 here because the sequence_length of time step t is t - 1 + invokePlusScalar(sequence_lengths, 1, batch_size * beam_width, stream_); + + // For sampling, it is equivalent to all parent ids are 0. + gatherTreeParam param; + param.beams = transposed_output_ids_buf_; + param.max_sequence_lengths = sequence_lengths; + param.max_time = gen_len; + param.batch_size = batch_size; + param.beam_width = beam_width; + param.step_ids = output_ids_buf_; + param.parent_ids = beam_width == 1 ? 
nullptr : parent_ids_buf_; + param.end_tokens = end_ids_buf_; + param.max_input_length = max_context_len; + param.prefix_soft_prompt_lengths = + has_prefix_soft_prompt_ ? input_tensors->at("request_prompt_lengths").getPtr() : nullptr; + param.input_lengths = tiled_input_lengths_buf_; + param.p_prompt_tuning_prompt_lengths = has_p_prompt_tuning_ ? tiled_prompt_lengths_buf_ : nullptr; + param.max_prefix_soft_prompt_length = max_prefix_soft_prompt_length; + param.stream = stream_; + param.output_ids = (int*)output_tensors->at("output_ids").data; + invokeGatherTree(param); + sync_check_cuda_error(); + } + if ((output_tensors->count("output_log_probs") > 0 && output_tensors->at("output_log_probs").data != nullptr)) { + invokeTransposeAxis01(output_tensors->at("output_log_probs").getPtr(), + output_log_probs_buf_, + input_tensors->at("output_seq_len").max() - max_input_length, + batch_size * beam_width, + 1, + stream_); + } + // Return the cumulative log probability if requested. + if (output_tensors->count("cum_log_probs") > 0) { + Tensor cum_log_probs = output_tensors->at("cum_log_probs"); + FT_CHECK_WITH_INFO(cum_log_probs.size() == batch_size * beam_width, + "The shape of cum_log_probs does not match with batch_size x beam_width."); + cudaAutoCpy(cum_log_probs.getPtr(), cum_log_probs_, cum_log_probs.size(), stream_); } } diff --git a/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.h b/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.h index 6533f45b6..660225557 100644 --- a/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.h +++ b/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.h @@ -38,71 +38,101 @@ class ParallelGpt: public BaseLayer { size_t num_layer_; size_t vocab_size_; - int start_id_; - int end_id_; - float beam_search_diversity_rate_; + int start_id_; + int end_id_; + float beam_search_diversity_rate_; size_t hidden_units_; + const float layernorm_eps_; // OPT + float shared_contexts_ratio_; + // TODO(bhsueh) remove these member because they are runtime parameters - size_t top_k_; - float top_p_; + size_t top_k_; + float top_p_; unsigned long long random_seed_; float temperature_; float len_penalty_; float repetition_penalty_; - size_t local_head_num_; + size_t local_head_num_; NcclParam tensor_para_; NcclParam pipeline_para_; std::shared_ptr custom_all_reduce_comm_; - int enable_custom_all_reduce_; + int enable_custom_all_reduce_; const bool is_context_qk_buf_float_ = true; - size_t vocab_size_padded_; - const int int8_mode_ = 0; + size_t vocab_size_padded_; + const int int8_mode_ = 0; + bool remove_padding_ = true; + + // Prompt Learning Parameters + PromptLearningType prompt_learning_type_; + int prompt_learning_start_id_; // start_id for prompt_learning (only needed by prefix prompts) + bool has_p_prompt_tuning_; + bool has_prefix_prompt_; + bool has_prefix_soft_prompt_; - ParallelGptDecoder* gpt_decoder_; + // GPT Variants parameters: e.g. 
Meta OPT + gptVariantParams gpt_variant_params_; + + ParallelGptDecoder* gpt_decoder_; ParallelGptContextDecoder* gpt_context_decoder_; - DynamicDecodeLayer* dynamic_decode_layer_; + DynamicDecodeLayer* dynamic_decode_layer_; void allocateBuffer() override; void allocateBuffer(size_t batch_size, size_t beam_width, size_t max_seq_len, + size_t memory_len, size_t max_input_len, - bool is_return_context_cum_log_probs); + bool is_return_context_cum_log_probs); void freeBuffer() override; void initialize(); - void computeContextCumLogProbs(float* cum_log_probs, - const T* context_decoder_outputs, - const int* input_ids, - const int* input_lengths, - const size_t batch_size, - const size_t beam_width, - const size_t max_input_length, + void computeContextCumLogProbs(float* cum_log_probs, + const T* context_decoder_outputs, + const int* input_ids, + const int* input_lengths, + const size_t batch_size, + const size_t beam_width, + const size_t max_input_length, const ParallelGptWeight* gpt_weights); protected: - T* padded_embedding_kernel_; + // For stateful processing (interactive generation) + int step_; + size_t session_len_; + size_t memory_len_; + int* tiled_total_padding_count_ = nullptr; + + T* padded_embedding_kernel_; const T* padded_embedding_kernel_ptr_; T* input_attention_mask_; - T* decoder_input_buf_; - T* decoder_output_buf_; - T* normed_decoder_output_buf_; - float* logits_buf_; - float* nccl_logits_buf_; - float* cum_log_probs_; - bool* finished_buf_; - bool* h_finished_buf_; - - T* key_cache_; - T* value_cache_; + T* decoder_input_buf_; + T* decoder_output_buf_; + T* normed_decoder_output_buf_; + float* logits_buf_; + float* nccl_logits_buf_; + float* cum_log_probs_; + bool* finished_buf_; + bool* h_finished_buf_; + int* sequence_lengths_ = nullptr; + uint32_t* seq_limit_len_ = nullptr; + bool* generation_should_stop_ = nullptr; + + int* shared_contexts_idx_ = nullptr; + T* compact_decoder_features_ = nullptr; + int* compact_idx_ = nullptr; + int* batch_to_compact_idx_ = nullptr; + int* compact_size_ = nullptr; + + T* key_cache_; + T* value_cache_; int* cache_indirections_[2] = {nullptr, nullptr}; int* start_ids_buf_; @@ -111,68 +141,94 @@ class ParallelGpt: public BaseLayer { int* tiled_input_ids_buf_; int* tiled_input_lengths_buf_; - int* transposed_output_ids_buf_; - int* output_ids_buf_; - int* parent_ids_buf_; + // prompt_learning weight_batch ptrs + const T** prompt_learning_weight_batch_; + int* tiled_prompt_lengths_buf_; // only needed by prefix prompts - T* context_decoder_input_buf_; - T* context_decoder_output_buf_; + int* transposed_output_ids_buf_; + int* output_ids_buf_; + int* parent_ids_buf_; + bool* masked_tokens_ = nullptr; + + T* context_decoder_input_buf_; + T* context_decoder_output_buf_; float* output_log_probs_buf_; // buffers dedicated to log prob computation - T* lp_normed_decoder_output_buf_ = nullptr; - float* lp_logits_buf_ = nullptr; - float* lp_nccl_logits_buf_ = nullptr; - float* lp_logprob_buf_ = nullptr; + T* lp_normed_decoder_output_buf_ = nullptr; + float* lp_logits_buf_ = nullptr; + float* lp_nccl_logits_buf_ = nullptr; + float* lp_logprob_buf_ = nullptr; + + // function pointer callback + using callback_sig = void(std::unordered_map*, void*); + callback_sig* token_generated_cb_ = nullptr; + void* token_generated_ctx_ = nullptr; + + void setOutputTensors(std::unordered_map* output_tensors, + const std::unordered_map* input_tensors, + const size_t gen_len, + const size_t session_len, + const size_t max_context_len); + void 
sendTensorsToFirstPipelineNode(std::unordered_map* output_tensors, + const std::unordered_map* input_tensors); public: - ParallelGpt(size_t max_batch_size, - size_t max_seq_len, - size_t max_input_len, - size_t beam_width, - size_t head_num, - size_t size_per_head, - size_t inter_size, - size_t num_layer, - size_t vocab_size, - int start_id, - int end_id, - float beam_search_diversity_rate, - size_t top_k, - float top_p, - unsigned long long random_seed, - float temperature, - float len_penalty, - float repetition_penalty, - NcclParam tensor_para, - NcclParam pipeline_para, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - cudaDeviceProp* cuda_device_prop = nullptr, - bool sparse = false, - int int8_mode = 0, - std::shared_ptr custom_all_reduce_comm = nullptr, - int enable_custom_all_reduce = 0); + ParallelGpt(size_t max_batch_size, + size_t max_seq_len, + size_t max_input_len, + size_t beam_width, + size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + size_t vocab_size, + int start_id, + int end_id, + int prompt_learning_start_id, // only needed by p/prompt-tuning + PromptLearningType prompt_learning_type, + gptVariantParams gpt_variant_params, + float beam_search_diversity_rate, + size_t top_k, + float top_p, + unsigned long long random_seed, + float temperature, + float len_penalty, + float repetition_penalty, + NcclParam tensor_para, + NcclParam pipeline_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + cudaDeviceProp* cuda_device_prop = nullptr, + bool sparse = false, + int int8_mode = 0, + std::shared_ptr custom_all_reduce_comm = nullptr, + int enable_custom_all_reduce = 0, + bool remove_padding = true, + float shared_contexts_ratio = 1.0f); ParallelGpt(ParallelGpt const& gpt); ~ParallelGpt(); - void forward(std::vector* output_tensors, - const std::vector* input_tensors, + void forward(std::vector* output_tensors, + const std::vector* input_tensors, const ParallelGptWeight* gpt_weights); - void forward(std::unordered_map* output_tensors, + void forward(std::unordered_map* output_tensors, const std::unordered_map* input_tensors, - const ParallelGptWeight* gpt_weights); + const ParallelGptWeight* gpt_weights); size_t getPipelineParallelRank(); size_t getPipelineParallelSize(); size_t getTensorParallelRank(); size_t getTensorParallelSize(); - bool* getFinishBuffer(); + bool* getFinishBuffer(); + + void registerCallback(callback_sig* fn, void* ctx); + void unRegisterCallback(); }; } // namespace fastertransformer diff --git a/src/fastertransformer/models/multi_gpu_gpt/ParallelGptContextDecoder.cc b/src/fastertransformer/models/multi_gpu_gpt/ParallelGptContextDecoder.cc index 0af476996..80361587d 100644 --- a/src/fastertransformer/models/multi_gpu_gpt/ParallelGptContextDecoder.cc +++ b/src/fastertransformer/models/multi_gpu_gpt/ParallelGptContextDecoder.cc @@ -1,5 +1,6 @@ /* * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2022, SK Telecom Authored by A. Dialog * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
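// ---------------------------------------------------------------------------
// Editor's usage sketch for the streaming callback added to ParallelGpt.h above
// (registerCallback / unRegisterCallback, invoked once per generated step on
// tensor/pipeline rank 0). Everything except those two method names and
// callback_sig is hypothetical; the real map type is the FT output-tensor map,
// whose template arguments are elided in this rendering of the patch.
#include <cstdio>
#include <string>
#include <unordered_map>

struct DemoTensor {};                                         // stand-in for fastertransformer::Tensor
using DemoTensorMap = std::unordered_map<std::string, DemoTensor>;

struct StreamState { int tokens_seen = 0; };

// Shaped like callback_sig: receives the current output tensors plus the
// user-supplied context pointer registered alongside the callback.
static void onTokenGenerated(DemoTensorMap* output_tensors, void* ctx)
{
    auto* state = static_cast<StreamState*>(ctx);
    ++state->tokens_seen;
    std::printf("streamed step %d, %zu output tensors available\n",
                state->tokens_seen, output_tensors->size());
}

// A client would do roughly (sketch only):
//   StreamState state;
//   gpt.registerCallback(&onTokenGenerated, &state);
//   gpt.forward(&output_tensors, &input_tensors, &gpt_weights);
//   gpt.unRegisterCallback();
// ---------------------------------------------------------------------------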
@@ -15,6 +16,7 @@ */ #include "src/fastertransformer/models/multi_gpu_gpt/ParallelGptContextDecoder.h" +#include "src/fastertransformer/kernels/bert_preprocess_kernels.h" #include "src/fastertransformer/kernels/gpt_kernels.h" namespace fastertransformer { @@ -31,56 +33,91 @@ void ParallelGptContextDecoder::initialize() stream_, cublas_wrapper_, allocator_, + true, is_free_buffer_after_forward_, is_qk_buf_float_, sparse_, custom_all_reduce_comm_, enable_custom_all_reduce_); - ffn_layer_ = new TensorParallelGeluFfnLayer(max_batch_size_, - max_seq_len_, - head_num_, - size_per_head_, - inter_size_, - tensor_para_, - stream_, - cublas_wrapper_, - allocator_, - is_free_buffer_after_forward_, - sparse_, - 0, - custom_all_reduce_comm_, - enable_custom_all_reduce_); + bool use_gated_activation = activation_type_ == ActivationType::GeGLU || activation_type_ == ActivationType::ReGLU; + size_t max_inter_size = has_adapters_ ? std::max(inter_size_, adapter_inter_size_) : inter_size_; + if (activation_type_ == ActivationType::Gelu || activation_type_ == ActivationType::GeGLU) { + ffn_layer_ = new TensorParallelGeluFfnLayer(max_batch_size_, + max_seq_len_, + head_num_, + size_per_head_, + max_inter_size, + tensor_para_, + stream_, + cublas_wrapper_, + allocator_, + true, + is_free_buffer_after_forward_, + sparse_, + 0, + use_gated_activation, + custom_all_reduce_comm_, + enable_custom_all_reduce_); + } + else if (activation_type_ == ActivationType::Relu || activation_type_ == ActivationType::ReGLU) { + ffn_layer_ = new TensorParallelReluFfnLayer(max_batch_size_, + max_seq_len_, + head_num_, + size_per_head_, + max_inter_size, + tensor_para_, + stream_, + cublas_wrapper_, + allocator_, + true, + is_free_buffer_after_forward_, + sparse_, + use_gated_activation, + custom_all_reduce_comm_, + enable_custom_all_reduce_); + } } template void ParallelGptContextDecoder::allocateBuffer() { FT_CHECK(false); - if (is_allocate_buffer_ == false) { - decoder_normed_input_ = - reinterpret_cast(allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false)); - self_attn_output_ = - reinterpret_cast(allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false)); - normed_self_attn_output_ = decoder_normed_input_; // reuse the buffer - decoder_layer_output_ = - reinterpret_cast(allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false)); - is_allocate_buffer_ = true; - } } template -void ParallelGptContextDecoder::allocateBuffer(size_t batch_size, size_t seq_len) +void ParallelGptContextDecoder::allocateBuffer(size_t batch_size, size_t seq_len, bool use_shared_contexts) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); decoder_normed_input_ = reinterpret_cast( allocator_->reMalloc(decoder_normed_input_, sizeof(T) * batch_size * seq_len * hidden_units_, false)); self_attn_output_ = reinterpret_cast( - allocator_->reMalloc(self_attn_output_, sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false)); + allocator_->reMalloc(self_attn_output_, sizeof(T) * batch_size * seq_len * hidden_units_, false)); normed_self_attn_output_ = decoder_normed_input_; // reuse the buffer + // only allocate additionl buffers when has adapters + after_adapter_attn_output_ = + has_adapters_ ? 
reinterpret_cast( + allocator_->reMalloc(after_adapter_attn_output_, sizeof(T) * batch_size * seq_len * hidden_units_, false)) : + self_attn_output_; decoder_layer_output_ = reinterpret_cast( allocator_->reMalloc(decoder_layer_output_, sizeof(T) * batch_size * seq_len * hidden_units_, false)); + token_num_ = reinterpret_cast(allocator_->reMalloc(token_num_, sizeof(size_t) * 1, false)); + padding_offset_ = + reinterpret_cast(allocator_->reMalloc(padding_offset_, sizeof(int) * batch_size * seq_len, false)); is_allocate_buffer_ = true; + + if (use_shared_contexts) { + compact_decoder_features_ = reinterpret_cast( + allocator_->reMalloc(compact_decoder_features_, sizeof(T) * batch_size * seq_len * hidden_units_, false)); + compact_attention_mask_ = reinterpret_cast( + allocator_->reMalloc(compact_attention_mask_, sizeof(T) * batch_size * seq_len * seq_len, false)); + compact_input_lengths_ = + reinterpret_cast(allocator_->reMalloc(compact_input_lengths_, sizeof(int) * batch_size, false)); + k_cache_layer_ = reinterpret_cast( + allocator_->reMalloc(k_cache_layer_, sizeof(T) * batch_size * seq_len * hidden_units_, false)); + v_cache_layer_ = reinterpret_cast( + allocator_->reMalloc(v_cache_layer_, sizeof(T) * batch_size * seq_len * hidden_units_, false)); + } } template @@ -88,39 +125,25 @@ void ParallelGptContextDecoder::freeBuffer() { if (is_allocate_buffer_) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); - allocator_->free(decoder_normed_input_); - allocator_->free(self_attn_output_); - allocator_->free(decoder_layer_output_); + allocator_->free((void**)(&decoder_normed_input_)); + allocator_->free((void**)(&self_attn_output_)); + if (has_adapters_) { + allocator_->free((void**)(&after_adapter_attn_output_)); + } + allocator_->free((void**)(&decoder_layer_output_)); + allocator_->free((void**)(&token_num_)); + allocator_->free((void**)(&padding_offset_)); + if (compact_attention_mask_ != nullptr) { + allocator_->free((void**)(&compact_decoder_features_)); + allocator_->free((void**)(&compact_attention_mask_)); + allocator_->free((void**)(&compact_input_lengths_)); + allocator_->free((void**)(&k_cache_layer_)); + allocator_->free((void**)(&v_cache_layer_)); + } is_allocate_buffer_ = false; } } -template -bool ParallelGptContextDecoder::isValidBatchSize(size_t batch_size) -{ - if (batch_size <= max_batch_size_) { - return true; - } - else { - freeBuffer(); - max_batch_size_ = batch_size * 1.2; - return true; - } -} - -template -bool ParallelGptContextDecoder::isValidSeqLen(size_t seq_len) -{ - if (seq_len <= max_seq_len_) { - return true; - } - else { - freeBuffer(); - max_seq_len_ = seq_len * 1.2; - return true; - } -} - template bool ParallelGptContextDecoder::isValidLayerParallelId(uint l) { @@ -151,22 +174,25 @@ int ParallelGptContextDecoder::getFirstLayerParallelId() } template -ParallelGptContextDecoder::ParallelGptContextDecoder(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t inter_size, - size_t num_layer, - NcclParam tensor_para, - NcclParam pipeline_para, - cudaStream_t stream, +ParallelGptContextDecoder::ParallelGptContextDecoder(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + float layernorm_eps, + gptVariantParams gpt_variant_params, + NcclParam tensor_para, + NcclParam pipeline_para, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool is_qk_buf_float, - bool sparse, + IAllocator* 
allocator, + bool is_free_buffer_after_forward, + bool is_qk_buf_float, + bool sparse, std::shared_ptr custom_all_reduce_comm, - int enable_custom_all_reduce): + int enable_custom_all_reduce, + bool remove_padding): BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, nullptr, sparse), max_batch_size_(max_batch_size), max_seq_len_(max_seq_len), @@ -174,12 +200,18 @@ ParallelGptContextDecoder::ParallelGptContextDecoder(size_t max_batch_size, size_per_head_(size_per_head), inter_size_(inter_size), num_layer_(num_layer), + layernorm_eps_(layernorm_eps), + layernorm_type_(gpt_variant_params.layernorm_type), + activation_type_(gpt_variant_params.activation_type), + adapter_inter_size_(gpt_variant_params.adapter_inter_size), + has_adapters_(gpt_variant_params.has_adapters), hidden_units_(head_num_ * size_per_head), tensor_para_(tensor_para), pipeline_para_(pipeline_para), is_qk_buf_float_(is_qk_buf_float), custom_all_reduce_comm_(custom_all_reduce_comm), - enable_custom_all_reduce_(enable_custom_all_reduce) + enable_custom_all_reduce_(enable_custom_all_reduce), + remove_padding_(remove_padding) { initialize(); } @@ -193,12 +225,18 @@ ParallelGptContextDecoder::ParallelGptContextDecoder(ParallelGptContextDecode size_per_head_(decoder.size_per_head_), inter_size_(decoder.inter_size_), num_layer_(decoder.num_layer_), + layernorm_eps_(decoder.layernorm_eps_), + layernorm_type_(decoder.layernorm_type_), + activation_type_(decoder.activation_type_), + adapter_inter_size_(decoder.adapter_inter_size_), + has_adapters_(decoder.has_adapters_), hidden_units_(decoder.hidden_units_), tensor_para_(decoder.tensor_para_), pipeline_para_(decoder.pipeline_para_), is_qk_buf_float_(decoder.is_qk_buf_float_), custom_all_reduce_comm_(decoder.custom_all_reduce_comm_), - enable_custom_all_reduce_(decoder.enable_custom_all_reduce_) + enable_custom_all_reduce_(decoder.enable_custom_all_reduce_), + remove_padding_(decoder.remove_padding_) { initialize(); } @@ -213,14 +251,16 @@ ParallelGptContextDecoder::~ParallelGptContextDecoder() template void ParallelGptContextDecoder::forward( - std::vector* output_tensors, - const std::vector* input_tensors, + std::vector* output_tensors, + const std::vector* input_tensors, const std::vector*>* gpt_decoder_layer_weight) { // input tensors: // decoder_input [batch_size, seq_len, hidden_dimension], // attention_mask [batch_size, 1, seq_len, seq_len] // input_lengths [batch_size] + // compact_idx [compact_size] // optional + // batch_to_compact_idx [batch_size] // optional // output tensors: // decoder_output [batch_size, seq_len, hidden_dimension], @@ -233,16 +273,32 @@ void ParallelGptContextDecoder::forward( // computing. FT_LOG_DEBUG(__PRETTY_FUNCTION__); - FT_CHECK(input_tensors->size() == 3); FT_CHECK(output_tensors->size() == 4); - isValidBatchSize(input_tensors->at(0).shape[0]); - isValidSeqLen(input_tensors->at(0).shape[1]); - // allocateBuffer(); - allocateBuffer(max_batch_size_, max_seq_len_); - const size_t batch_size = (size_t)input_tensors->at(0).shape[0]; - const size_t seq_len = (size_t)input_tensors->at(0).shape[1]; - const DataType data_type = getTensorType(); + FT_CHECK(input_tensors->size() == 3 || input_tensors->size() == 5); + const bool use_shared_contexts = input_tensors->size() == 5; + + const size_t batch_size = + use_shared_contexts ? 
input_tensors->at(3).shape[0] : (size_t)input_tensors->at(0).shape[0]; + const size_t seq_len = input_tensors->at(0).shape[1]; + const size_t hidden_dimension = input_tensors->at(0).shape[2]; + const size_t max_seq_len = output_tensors->at(2).shape[3]; + const DataType data_type = getTensorType(); + allocateBuffer(batch_size, seq_len, use_shared_contexts); + + if (use_shared_contexts) { + invokeCompactInputs(compact_decoder_features_, + compact_attention_mask_, + compact_input_lengths_, + input_tensors->at(0).getPtr(), + input_tensors->at(1).getPtr(), + input_tensors->at(2).getPtr(), + input_tensors->at(3).getPtr(), + batch_size, + seq_len, + hidden_dimension, + stream_); + } const size_t local_batch_size = getLocalBatchSize(batch_size, seq_len, pipeline_para_.world_size_); FT_CHECK(batch_size % local_batch_size == 0); @@ -259,24 +315,59 @@ void ParallelGptContextDecoder::forward( self_v_cache_size.push_back(*t); } + if (use_shared_contexts) { + // we use k_cache_layer_ and v_cache_layer_ + self_k_cache_size[3] = seq_len; + self_v_cache_size[2] = seq_len; + } + for (uint ite = 0; ite < iteration_num; ite++) { + size_t h_token_num = local_batch_size * seq_len; + if (remove_padding_) { + const int* base_input_lengths = + (use_shared_contexts ? compact_input_lengths_ : input_tensors->at(2).getPtr()); + invokeGetPaddingOffset(&h_token_num, + token_num_, + padding_offset_, + base_input_lengths + ite * local_batch_size, + local_batch_size, + seq_len, + stream_); + } + for (uint l = 0; l < num_layer_; l++) { if (isValidLayerParallelId(l) == false) { continue; } - // const bool is_final = l == (num_layer_ - 1); - const bool is_final = false; // TODO(bhsueh) remove this flag - T* decoder_input = - (l == 0) ? (T*)(input_tensors->at(0).data) + (int)(ite * local_batch_size * seq_len * hidden_units_) : - decoder_layer_output_; - T* decoder_output = - (l == (num_layer_ - 1)) ? - (T*)(output_tensors->at(0).data) + (int)(ite * local_batch_size * seq_len * hidden_units_) : - decoder_layer_output_; + if (l == 0 && remove_padding_) { + const T* base_input = + (use_shared_contexts ? compact_decoder_features_ : input_tensors->at(0).getPtr()); + invokeRemovePadding(decoder_layer_output_, + base_input + ite * local_batch_size * seq_len * hidden_units_, + padding_offset_, + h_token_num, + hidden_units_, + stream_); + } + + const bool is_final = false; // TODO(bhsueh) remove this flag + T* decoder_input = decoder_layer_output_; + T* decoder_output = decoder_layer_output_; + if (!remove_padding_) { + if (l == 0) { + decoder_input = use_shared_contexts ? compact_decoder_features_ : input_tensors->at(0).getPtr(); + decoder_input += ite * local_batch_size * seq_len * hidden_units_; + } + if (l == num_layer_ - 1) { + decoder_output = + use_shared_contexts ? 
compact_decoder_features_ : output_tensors->at(0).getPtr(); + decoder_output += ite * local_batch_size * seq_len * hidden_units_; + } + } if (isFirstLayerParallelId(l) && pipeline_para_.rank_ != 0) { - const int data_size = local_batch_size * seq_len * hidden_units_ / tensor_para_.world_size_; + const int data_size = h_token_num * hidden_units_ / tensor_para_.world_size_; ftNcclRecv(decoder_input + data_size * tensor_para_.rank_, data_size, pipeline_para_.rank_ - 1, @@ -287,89 +378,225 @@ void ParallelGptContextDecoder::forward( } } - invokeGeneralLayerNorm(decoder_normed_input_, - decoder_input, - gpt_decoder_layer_weight->at(l)->pre_layernorm_weights.gamma, - gpt_decoder_layer_weight->at(l)->pre_layernorm_weights.beta, - local_batch_size * seq_len, - hidden_units_, - stream_); + if (layernorm_type_ == LayerNormType::pre_layernorm) { + invokeGeneralLayerNorm(decoder_normed_input_, + decoder_input, + gpt_decoder_layer_weight->at(l)->pre_layernorm_weights.gamma, + gpt_decoder_layer_weight->at(l)->pre_layernorm_weights.beta, + layernorm_eps_, + h_token_num, + hidden_units_, + stream_); + } sync_check_cuda_error(); std::vector self_attention_input_tensors{ - Tensor{MEMORY_GPU, data_type, {local_batch_size * seq_len, hidden_units_}, decoder_normed_input_}, + Tensor{MEMORY_GPU, + data_type, + {h_token_num, hidden_units_}, + layernorm_type_ == LayerNormType::pre_layernorm ? decoder_normed_input_ : decoder_input}, Tensor{MEMORY_GPU, data_type, {local_batch_size, 1, seq_len, seq_len}, (const T*)input_tensors->at(1).data + local_batch_size * ite * seq_len * seq_len}, - Tensor{MEMORY_CPU, TYPE_BOOL, {1}, &is_final}}; - - size_t cache_offset = l - getFirstLayerParallelId(); - for (auto t = output_tensors->at(1).shape.begin() + 1; t != output_tensors->at(1).shape.end(); ++t) { - cache_offset *= *t; - }; - size_t ite_cache_offset = ite * local_batch_size; - for (auto t = output_tensors->at(1).shape.begin() + 2; t != output_tensors->at(1).shape.end(); ++t) { - ite_cache_offset *= *t; + Tensor{MEMORY_CPU, TYPE_BOOL, {1}, &is_final}, + Tensor{MEMORY_GPU, + data_type, + {(size_t)local_batch_size, (size_t)l}, + nullptr}, // prefix prompt weight batch + Tensor{MEMORY_GPU, TYPE_INT32, {(size_t)local_batch_size}, nullptr}, // prefix prompt lengths + Tensor{MEMORY_CPU, TYPE_INT32, {(size_t)1}, &l}, // layer_id + Tensor{MEMORY_GPU, TYPE_INT32, {h_token_num}, (remove_padding_ ? padding_offset_ : nullptr)}}; + + size_t cache_stride_batch = 1; + for (auto it = output_tensors->at(1).shape.begin() + 2; it != output_tensors->at(1).shape.end(); ++it) { + cache_stride_batch *= *it; } - cache_offset += ite_cache_offset; + + const size_t cache_layer_offset = + (l - getFirstLayerParallelId()) * output_tensors->at(1).shape[1] * cache_stride_batch; + const size_t ite_cache_offset = ite * local_batch_size * cache_stride_batch; + const size_t cache_offset = cache_layer_offset + ite_cache_offset; + + T* k_cache_ptr = use_shared_contexts ? k_cache_layer_ : output_tensors->at(1).getPtr() + cache_offset; + T* v_cache_ptr = use_shared_contexts ? 
v_cache_layer_ : output_tensors->at(2).getPtr() + cache_offset; std::vector self_attention_output_tensors{ - Tensor{MEMORY_GPU, data_type, {local_batch_size * seq_len, hidden_units_}, self_attn_output_}, - Tensor{MEMORY_GPU, data_type, self_k_cache_size, ((const T*)output_tensors->at(1).data) + cache_offset}, - Tensor{ - MEMORY_GPU, data_type, self_v_cache_size, ((const T*)output_tensors->at(2).data) + cache_offset}}; + Tensor{MEMORY_GPU, data_type, {h_token_num, hidden_units_}, self_attn_output_}, + Tensor{MEMORY_GPU, data_type, self_k_cache_size, k_cache_ptr}, + Tensor{MEMORY_GPU, data_type, self_v_cache_size, v_cache_ptr}}; self_attention_layer_->forward(&self_attention_output_tensors, &self_attention_input_tensors, &gpt_decoder_layer_weight->at(l)->self_attention_weights); - if (is_final == false) { + if (use_shared_contexts) { + // Even with local batches, we must process the whole K/V caches as any + // element in batch_idx_to_compact_idx may reference the local batch + // we're processing. We also need to discard references that aren't in + // that particular local batch. + invokeUnCompactCaches(output_tensors->at(1).getPtr() + cache_layer_offset, + output_tensors->at(2).getPtr() + cache_layer_offset, + k_cache_layer_, + v_cache_layer_, + input_tensors->at(4).getPtr(), + output_tensors->at(2).shape[1], // batch_size (uncompact) + output_tensors->at(2).shape[2], // local_head_num + max_seq_len, + seq_len, + size_per_head_, + local_batch_size, + ite, + stream_); + sync_check_cuda_error(); + } + + // the adapter after attention (only pre layernorm currently) + if (has_adapters_) { + invokeAddBias(self_attn_output_, + gpt_decoder_layer_weight->at(l)->self_attention_weights.attention_output_weight.bias, + h_token_num, + hidden_units_, + stream_); + + std::vector ffn_input_tensors{ + Tensor{MEMORY_GPU, data_type, {h_token_num, hidden_units_}, self_attn_output_}}; + std::vector ffn_output_tensors{ + Tensor{MEMORY_GPU, data_type, {h_token_num, hidden_units_}, after_adapter_attn_output_}}; + + ffn_layer_->resetInterSize(adapter_inter_size_ / tensor_para_.world_size_); + ffn_layer_->forward(&ffn_output_tensors, + &ffn_input_tensors, + &gpt_decoder_layer_weight->at(l)->after_attention_adapter_weights); + } + + if (layernorm_type_ == LayerNormType::pre_layernorm) { invokeGeneralAddBiasResidualPreLayerNorm( - self_attn_output_, + after_adapter_attn_output_, normed_self_attn_output_, decoder_input, + has_adapters_ ? self_attn_output_ : nullptr, gpt_decoder_layer_weight->at(l)->self_attn_layernorm_weights.gamma, gpt_decoder_layer_weight->at(l)->self_attn_layernorm_weights.beta, - gpt_decoder_layer_weight->at(l)->self_attention_weights.attention_output_weight.bias, - local_batch_size * seq_len, + has_adapters_ ? + gpt_decoder_layer_weight->at(l)->after_attention_adapter_weights.output_weight.bias : + gpt_decoder_layer_weight->at(l)->self_attention_weights.attention_output_weight.bias, + layernorm_eps_, + h_token_num, hidden_units_, stream_); - sync_check_cuda_error(); + } + else if (layernorm_type_ == LayerNormType::post_layernorm) { + invokeAddBiasResidualLayerNorm( + after_adapter_attn_output_, + decoder_input, + has_adapters_ ? 
+ gpt_decoder_layer_weight->at(l)->after_attention_adapter_weights.output_weight.bias : + gpt_decoder_layer_weight->at(l)->self_attention_weights.attention_output_weight.bias, + gpt_decoder_layer_weight->at(l)->pre_layernorm_weights.gamma, + gpt_decoder_layer_weight->at(l)->pre_layernorm_weights.beta, + layernorm_eps_, + h_token_num, + hidden_units_, + stream_); + } + sync_check_cuda_error(); - std::vector ffn_input_tensors{Tensor{ - MEMORY_GPU, data_type, {local_batch_size * seq_len, hidden_units_}, normed_self_attn_output_}}; + T* ffn_output_ptr = has_adapters_ ? self_attn_output_ : decoder_output; + + std::vector ffn_input_tensors{Tensor{MEMORY_GPU, + data_type, + {h_token_num, hidden_units_}, + layernorm_type_ == LayerNormType::pre_layernorm ? + normed_self_attn_output_ : + after_adapter_attn_output_}}; + std::vector ffn_output_tensors{ + Tensor{MEMORY_GPU, data_type, {h_token_num, hidden_units_}, ffn_output_ptr}}; + + ffn_layer_->resetInterSize(inter_size_ / tensor_para_.world_size_); + ffn_layer_->forward(&ffn_output_tensors, &ffn_input_tensors, &gpt_decoder_layer_weight->at(l)->ffn_weights); + + // the adapter after ffn (only pre layernorm currently) + if (has_adapters_) { + invokeAddBias(ffn_output_ptr, + gpt_decoder_layer_weight->at(l)->ffn_weights.output_weight.bias, + h_token_num, + hidden_units_, + stream_); + + std::vector ffn_input_tensors{ + Tensor{MEMORY_GPU, data_type, {h_token_num, hidden_units_}, ffn_output_ptr}}; std::vector ffn_output_tensors{ - Tensor{MEMORY_GPU, data_type, {local_batch_size * seq_len, hidden_units_}, decoder_output}}; - - ffn_layer_->forward( - &ffn_output_tensors, &ffn_input_tensors, &gpt_decoder_layer_weight->at(l)->ffn_weights); - invokeAddBiasResidual(decoder_output, - self_attn_output_, - gpt_decoder_layer_weight->at(l)->ffn_weights.output_weight.bias, - local_batch_size * seq_len, - hidden_units_, - stream_); - sync_check_cuda_error(); + Tensor{MEMORY_GPU, data_type, {h_token_num, hidden_units_}, decoder_output}}; - if (isLastLayerParallelId(l) == true && pipeline_para_.rank_ != pipeline_para_.world_size_ - 1) { - const int data_size = local_batch_size * seq_len * hidden_units_ / tensor_para_.world_size_; - ftNcclSend(decoder_output + data_size * tensor_para_.rank_, - data_size, - pipeline_para_.rank_ + 1, - pipeline_para_, - stream_); - } + ffn_layer_->resetInterSize(adapter_inter_size_ / tensor_para_.world_size_); + ffn_layer_->forward(&ffn_output_tensors, + &ffn_input_tensors, + &gpt_decoder_layer_weight->at(l)->after_ffn_adapter_weights); + } + + if (layernorm_type_ == LayerNormType::pre_layernorm) { + invokeAddBiasResidual( + decoder_output, + after_adapter_attn_output_, + has_adapters_ ? ffn_output_ptr : nullptr, + has_adapters_ ? gpt_decoder_layer_weight->at(l)->after_ffn_adapter_weights.output_weight.bias : + gpt_decoder_layer_weight->at(l)->ffn_weights.output_weight.bias, + h_token_num, + hidden_units_, + stream_); + } + else if (layernorm_type_ == LayerNormType::post_layernorm) { + invokeAddBiasResidualLayerNorm( + decoder_output, + after_adapter_attn_output_, + has_adapters_ ? 
gpt_decoder_layer_weight->at(l)->after_ffn_adapter_weights.output_weight.bias : + gpt_decoder_layer_weight->at(l)->ffn_weights.output_weight.bias, + gpt_decoder_layer_weight->at(l)->self_attn_layernorm_weights.gamma, + gpt_decoder_layer_weight->at(l)->self_attn_layernorm_weights.beta, + layernorm_eps_, + h_token_num, + hidden_units_, + stream_); + } + sync_check_cuda_error(); + + if (isLastLayerParallelId(l) == true && (pipeline_para_.rank_ != pipeline_para_.world_size_ - 1)) { + const int data_size = h_token_num * hidden_units_ / tensor_para_.world_size_; + ftNcclSend(decoder_output + data_size * tensor_para_.rank_, + data_size, + pipeline_para_.rank_ + 1, + pipeline_para_, + stream_); + } + + if ((l == num_layer_ - 1) && remove_padding_) { + T* base_ptr = use_shared_contexts ? compact_decoder_features_ : output_tensors->at(0).getPtr(); + invokeRebuildPadding(base_ptr + ite * local_batch_size * seq_len * hidden_units_, + decoder_layer_output_, + padding_offset_, + h_token_num, + head_num_ * size_per_head_, + stream_); } } } + if (use_shared_contexts) { + invokeUnCompactOutputs(output_tensors->at(0).getPtr(), + compact_decoder_features_, + input_tensors->at(4).getPtr(), + output_tensors->at(2).shape[1], // batch + seq_len * hidden_units_, + stream_); + } + // TODO(bhsueh) We could optimize this point by only computing the last token for the last layer invokeLookupHiddenStateOfLastToken((T*)output_tensors->at(3).data, (T*)output_tensors->at(0).data, (int*)input_tensors->at(2).data, seq_len, - batch_size, + input_tensors->at(0).shape[0], hidden_units_, stream_); diff --git a/src/fastertransformer/models/multi_gpu_gpt/ParallelGptContextDecoder.h b/src/fastertransformer/models/multi_gpu_gpt/ParallelGptContextDecoder.h index d53c5b374..c6c305007 100644 --- a/src/fastertransformer/models/multi_gpu_gpt/ParallelGptContextDecoder.h +++ b/src/fastertransformer/models/multi_gpu_gpt/ParallelGptContextDecoder.h @@ -22,6 +22,7 @@ #include "src/fastertransformer/kernels/layernorm_kernels.h" #include "src/fastertransformer/layers/BaseLayer.h" #include "src/fastertransformer/layers/TensorParallelGeluFfnLayer.h" +#include "src/fastertransformer/layers/TensorParallelReluFfnLayer.h" #include "src/fastertransformer/layers/attention_layers/TensorParallelGptContextAttentionLayer.h" #include "src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoderLayerWeight.h" #include "src/fastertransformer/utils/Tensor.h" @@ -36,13 +37,21 @@ class ParallelGptContextDecoder: public BaseLayer { private: // buffer handling size_t max_batch_size_ = 0; - size_t max_seq_len_ = 0; + size_t max_seq_len_ = 0; // meta data - size_t head_num_; - size_t size_per_head_; - size_t inter_size_; - size_t num_layer_; + size_t head_num_; + size_t size_per_head_; + size_t inter_size_; + size_t num_layer_; + float layernorm_eps_; + LayerNormType layernorm_type_; + ActivationType activation_type_; + + // adapter + bool has_adapters_; + size_t adapter_inter_size_; + T* after_adapter_attn_output_; // calculated data size_t hidden_units_; @@ -51,54 +60,64 @@ class ParallelGptContextDecoder: public BaseLayer { NcclParam pipeline_para_; std::shared_ptr custom_all_reduce_comm_; - int enable_custom_all_reduce_; + int enable_custom_all_reduce_; + bool remove_padding_; bool is_qk_buf_float_; BaseAttentionLayer* self_attention_layer_; - FfnLayer* ffn_layer_; + FfnLayer* ffn_layer_; void allocateBuffer() override; - void allocateBuffer(size_t batch_size, size_t seq_len); + void allocateBuffer(size_t batch_size, size_t seq_len, bool use_shared_contexts); 
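// ---------------------------------------------------------------------------
// Editor's CPU sketch (hypothetical helper, not the FT kernel) of the
// bookkeeping that the new remove_padding_ path appears to rely on: the context
// decoder packs the batch down to h_token_num valid tokens, and for every packed
// token records how many padding slots precede it, so invokeRebuildPadding can
// later scatter the hidden states back into the [batch, seq_len, hidden] layout.
#include <vector>

static size_t buildPaddingOffsets(const std::vector<int>& input_lengths,  // [batch]
                                  int seq_len,
                                  std::vector<int>& padding_offset)       // out: one entry per valid token
{
    padding_offset.clear();
    int cum_padding = 0;
    for (size_t b = 0; b < input_lengths.size(); ++b) {
        for (int t = 0; t < input_lengths[b]; ++t) {
            padding_offset.push_back(cum_padding);   // padded index = packed index + offset
        }
        cum_padding += seq_len - input_lengths[b];   // padding contributed by this sequence
    }
    return padding_offset.size();                    // h_token_num = sum of input_lengths
}
// ---------------------------------------------------------------------------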
void freeBuffer() override; - bool isValidBatchSize(size_t batch_size); - bool isValidSeqLen(size_t seq_len); bool isValidLayerParallelId(uint l); void initialize(); bool isFirstLayerParallelId(uint l); bool isLastLayerParallelId(uint l); - int getFirstLayerParallelId(); + int getFirstLayerParallelId(); - T* decoder_normed_input_ = nullptr; - T* self_attn_output_ = nullptr; - T* normed_self_attn_output_ = nullptr; - T* decoder_layer_output_ = nullptr; + T* decoder_normed_input_ = nullptr; + T* self_attn_output_ = nullptr; + T* normed_self_attn_output_ = nullptr; + T* decoder_layer_output_ = nullptr; + size_t* token_num_ = nullptr; + int* padding_offset_ = nullptr; + + T* compact_decoder_features_ = nullptr; + T* compact_attention_mask_ = nullptr; + int* compact_input_lengths_ = nullptr; + T* k_cache_layer_ = nullptr; + T* v_cache_layer_ = nullptr; protected: public: - ParallelGptContextDecoder(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t inter_size, - size_t num_layer, - NcclParam tensor_para, - NcclParam pipeline_para, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool is_qk_buf_float, - bool sparse = false, - std::shared_ptr custom_all_reduce_comm = nullptr, - int enable_custom_all_reduce_ = 0); + ParallelGptContextDecoder(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + float layernorm_eps, + gptVariantParams gpt_variant_params, + NcclParam tensor_para, + NcclParam pipeline_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool is_qk_buf_float, + bool sparse = false, + std::shared_ptr custom_all_reduce_comm = nullptr, + int enable_custom_all_reduce = 0, + bool remove_padding = true); ParallelGptContextDecoder(ParallelGptContextDecoder const& decoder); ~ParallelGptContextDecoder(); - void forward(std::vector* output_tensors, - const std::vector* input_tensors, + void forward(std::vector* output_tensors, + const std::vector* input_tensors, const std::vector*>* decoder_layer_weights); }; diff --git a/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoder.cc b/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoder.cc index 4822a4c1d..321a944c1 100644 --- a/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoder.cc +++ b/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoder.cc @@ -1,5 +1,6 @@ /* * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2022, SK Telecom Authored by A. Dialog * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,50 +30,81 @@ void ParallelGptDecoder::initialize() stream_, cublas_wrapper_, allocator_, + true, is_free_buffer_after_forward_, sparse_, int8_mode_, custom_all_reduce_comm_, enable_custom_all_reduce_); - ffn_layer_ = new TensorParallelGeluFfnLayer(max_batch_size_, - 1, - head_num_, - size_per_head_, - inter_size_, - tensor_para_, - stream_, - cublas_wrapper_, - allocator_, - is_free_buffer_after_forward_, - sparse_, - int8_mode_, - custom_all_reduce_comm_, - enable_custom_all_reduce_); + bool use_gated_activation = activation_type_ == ActivationType::GeGLU || activation_type_ == ActivationType::ReGLU; + size_t max_inter_size = has_adapters_ ? 
std::max(inter_size_, adapter_inter_size_) : inter_size_; + if (activation_type_ == ActivationType::Gelu || activation_type_ == ActivationType::GeGLU) { + ffn_layer_ = new TensorParallelGeluFfnLayer(max_batch_size_, + 1, + head_num_, + size_per_head_, + max_inter_size, + tensor_para_, + stream_, + cublas_wrapper_, + allocator_, + true, + is_free_buffer_after_forward_, + sparse_, + int8_mode_, + use_gated_activation, + custom_all_reduce_comm_, + enable_custom_all_reduce_); + } + else if (activation_type_ == ActivationType::Relu || activation_type_ == ActivationType::ReGLU) { + ffn_layer_ = new TensorParallelReluFfnLayer(max_batch_size_, + 1, + head_num_, + size_per_head_, + max_inter_size, + tensor_para_, + stream_, + cublas_wrapper_, + allocator_, + true, + is_free_buffer_after_forward_, + sparse_, + use_gated_activation, + custom_all_reduce_comm_, + enable_custom_all_reduce_); + } } template -ParallelGptDecoder::ParallelGptDecoder(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - size_t inter_size, - size_t num_layer, - NcclParam tensor_para, - NcclParam pipeline_para, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse, - int int8_mode, +ParallelGptDecoder::ParallelGptDecoder(size_t max_batch_size, + size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + float layernorm_eps, + gptVariantParams gpt_variant_params, + NcclParam tensor_para, + NcclParam pipeline_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse, + int int8_mode, std::shared_ptr custom_all_reduce_comm, - int enable_custom_all_reduce): + int enable_custom_all_reduce): BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, nullptr, sparse), max_batch_size_(max_batch_size), head_num_(head_num), size_per_head_(size_per_head), inter_size_(inter_size), num_layer_(num_layer), + layernorm_eps_(layernorm_eps), + layernorm_type_(gpt_variant_params.layernorm_type), + activation_type_(gpt_variant_params.activation_type), + adapter_inter_size_(gpt_variant_params.adapter_inter_size), + has_adapters_(gpt_variant_params.has_adapters), hidden_units_(head_num_ * size_per_head_), tensor_para_(tensor_para), pipeline_para_(pipeline_para), @@ -96,6 +128,11 @@ ParallelGptDecoder::ParallelGptDecoder(ParallelGptDecoder const& decoder): size_per_head_(decoder.size_per_head_), inter_size_(decoder.inter_size_), num_layer_(decoder.num_layer_), + layernorm_eps_(decoder.layernorm_eps_), + layernorm_type_(decoder.layernorm_type_), + activation_type_(decoder.activation_type_), + adapter_inter_size_(decoder.adapter_inter_size_), + has_adapters_(decoder.has_adapters_), hidden_units_(decoder.hidden_units_), tensor_para_(decoder.tensor_para_), pipeline_para_(decoder.pipeline_para_), @@ -110,17 +147,6 @@ template void ParallelGptDecoder::allocateBuffer() { FT_CHECK(false); - if (is_allocate_buffer_ == false) { - decoder_layer_output_ = - reinterpret_cast(allocator_->malloc(sizeof(T) * max_batch_size_ * hidden_units_, false)); - decoder_normed_input_ = - reinterpret_cast(allocator_->malloc(sizeof(T) * max_batch_size_ * hidden_units_, false)); - self_attn_output_ = - reinterpret_cast(allocator_->malloc(sizeof(T) * max_batch_size_ * hidden_units_, false)); - normed_self_attn_output_ = - reinterpret_cast(allocator_->malloc(sizeof(T) * max_batch_size_ * hidden_units_, false)); - is_allocate_buffer_ = true; - } } template @@ -135,7 
+161,11 @@ void ParallelGptDecoder::allocateBuffer(size_t batch_size) reinterpret_cast(allocator_->reMalloc(self_attn_output_, sizeof(T) * batch_size * hidden_units_, false)); normed_self_attn_output_ = reinterpret_cast( allocator_->reMalloc(normed_self_attn_output_, sizeof(T) * batch_size * hidden_units_, false)); - is_allocate_buffer_ = true; + // only allocate additionl buffers when has adapters + after_adapter_attn_output_ = has_adapters_ ? reinterpret_cast(allocator_->reMalloc( + after_adapter_attn_output_, sizeof(T) * batch_size * hidden_units_, false)) : + self_attn_output_; + is_allocate_buffer_ = true; } template @@ -143,27 +173,17 @@ void ParallelGptDecoder::freeBuffer() { if (is_allocate_buffer_) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); - allocator_->free(decoder_layer_output_); - allocator_->free(decoder_normed_input_); - allocator_->free(self_attn_output_); - allocator_->free(normed_self_attn_output_); + allocator_->free((void**)(&decoder_layer_output_)); + allocator_->free((void**)(&decoder_normed_input_)); + allocator_->free((void**)(&self_attn_output_)); + allocator_->free((void**)(&normed_self_attn_output_)); + if (has_adapters_) { + allocator_->free((void**)(&after_adapter_attn_output_)); + } is_allocate_buffer_ = false; } } -template -bool ParallelGptDecoder::isValidBatchSize(size_t batch_size) -{ - if (batch_size <= max_batch_size_) { - return true; - } - else { - freeBuffer(); - max_batch_size_ = batch_size * 1.2; - return true; - } -} - template bool ParallelGptDecoder::isValidLayerParallelId(uint l) { @@ -202,35 +222,36 @@ ParallelGptDecoder::~ParallelGptDecoder() } template -void ParallelGptDecoder::forward(std::vector* output_tensors, - const std::vector* input_tensors, +void ParallelGptDecoder::forward(std::vector* output_tensors, + const std::vector* input_tensors, const std::vector*>* gpt_decoder_layer_weight) { // input tensors: // decoder_input [local_batch_size, hidden_dimension], // finished [local_batch_size], - // sequence_lengths [local_batch_size] // input_lengths [local_batch_size], + // total_padding_tokens [local_batch_size] // max_input_length [1] on cpu // step [1] on cpu // ite [1] on cpu - // cache_indirection [local_batch_size / beam_width, beam_width, max_seq_len] + // cache_indirection [local_batch_size / beam_width, beam_width, memory_len] // Here, local_batch_size contains the beam_width, so local_batch_size / beam_width // is real local_batch_size. + // masked_tokens [local_batch_size, memory_len] // output tensors: // decoder_output [local_batch_size, hidden_dimension], - // key_cache [num_layer, batch_size, head_num, size_per_head // x, max_seq_len, x] - // value_cache [num_layer, batch_size, head_num, max_seq_len, size_per_head] + // key_cache [num_layer, batch_size, head_num, size_per_head // x, memory_len, x] + // value_cache [num_layer, batch_size, head_num, memory_len, size_per_head] FT_LOG_DEBUG(__PRETTY_FUNCTION__); - FT_CHECK(input_tensors->size() == 8); + FT_CHECK(input_tensors->size() == 9); FT_CHECK(output_tensors->size() == 3); const size_t local_batch_size = input_tensors->at(0).shape[0]; allocateBuffer(local_batch_size); const DataType data_type = getTensorType(); - const int ite = *((int*)(input_tensors->at(6).data)); + const int ite = *((int*)(input_tensors->at(6).data)); std::vector self_k_cache_size; self_k_cache_size.push_back(local_batch_size); @@ -247,7 +268,7 @@ void ParallelGptDecoder::forward(std::vector* output_tensors, if (isValidLayerParallelId(l) == false) { continue; } - T* decoder_input = (T*)((l == 0) ? 
input_tensors->at(0).data : decoder_layer_output_); + T* decoder_input = (T*)((l == 0) ? input_tensors->at(0).data : decoder_layer_output_); T* decoder_output = (T*)((l == num_layer_ - 1) ? output_tensors->at(0).data : decoder_layer_output_); if (isFirstLayerParallelId(l) == true && pipeline_para_.rank_ != 0 && pipeline_para_.world_size_ > 1) { @@ -279,23 +300,34 @@ void ParallelGptDecoder::forward(std::vector* output_tensors, } cache_offset += ite_cache_offset; - invokeGeneralLayerNorm(decoder_normed_input_, - decoder_input, - gpt_decoder_layer_weight->at(l)->pre_layernorm_weights.gamma, - gpt_decoder_layer_weight->at(l)->pre_layernorm_weights.beta, - local_batch_size, - hidden_units_, - stream_); + if (layernorm_type_ == LayerNormType::pre_layernorm) { + invokeGeneralLayerNorm(decoder_normed_input_, + decoder_input, + gpt_decoder_layer_weight->at(l)->pre_layernorm_weights.gamma, + gpt_decoder_layer_weight->at(l)->pre_layernorm_weights.beta, + layernorm_eps_, + local_batch_size, + hidden_units_, + stream_); + } sync_check_cuda_error(); + const int max_prefix_prompt_length = 0; std::vector self_attention_input_tensors{ - Tensor{MEMORY_GPU, data_type, {local_batch_size, hidden_units_}, decoder_normed_input_}, + Tensor{MEMORY_GPU, + data_type, + {local_batch_size, hidden_units_}, + layernorm_type_ == LayerNormType::pre_layernorm ? decoder_normed_input_ : decoder_input}, input_tensors->at(1), input_tensors->at(2), input_tensors->at(3), + Tensor{ + MEMORY_GPU, data_type, {(size_t)local_batch_size, (size_t)l}, nullptr}, // prefix prompt weight batch + Tensor{MEMORY_CPU, TYPE_INT32, {(size_t)1}, &max_prefix_prompt_length}, // max prefix prompt length input_tensors->at(4), input_tensors->at(5), - input_tensors->at(7)}; + input_tensors->at(7), + input_tensors->at(8)}; std::vector self_attention_output_tensors{ Tensor{MEMORY_GPU, data_type, {local_batch_size, hidden_units_}, self_attn_output_}, @@ -306,30 +338,111 @@ void ParallelGptDecoder::forward(std::vector* output_tensors, &self_attention_input_tensors, &gpt_decoder_layer_weight->at(l)->self_attention_weights); - invokeGeneralAddBiasResidualPreLayerNorm( - self_attn_output_, - normed_self_attn_output_, - decoder_input, - gpt_decoder_layer_weight->at(l)->self_attn_layernorm_weights.gamma, - gpt_decoder_layer_weight->at(l)->self_attn_layernorm_weights.beta, - gpt_decoder_layer_weight->at(l)->self_attention_weights.attention_output_weight.bias, - local_batch_size, - hidden_units_, - stream_); + // the adapter after attention + if (has_adapters_) { + invokeAddBias(self_attn_output_, + gpt_decoder_layer_weight->at(l)->self_attention_weights.attention_output_weight.bias, + local_batch_size, + hidden_units_, + stream_); + + std::vector ffn_input_tensors{ + Tensor{MEMORY_GPU, data_type, {local_batch_size, hidden_units_}, self_attn_output_}}; + std::vector ffn_output_tensors{ + Tensor{MEMORY_GPU, data_type, {local_batch_size, hidden_units_}, after_adapter_attn_output_}}; + + ffn_layer_->resetInterSize(adapter_inter_size_ / tensor_para_.world_size_); + ffn_layer_->forward(&ffn_output_tensors, + &ffn_input_tensors, + &gpt_decoder_layer_weight->at(l)->after_attention_adapter_weights); + } + + if (layernorm_type_ == LayerNormType::pre_layernorm) { + invokeGeneralAddBiasResidualPreLayerNorm( + after_adapter_attn_output_, + normed_self_attn_output_, + decoder_input, + has_adapters_ ? 
self_attn_output_ : nullptr, + gpt_decoder_layer_weight->at(l)->self_attn_layernorm_weights.gamma, + gpt_decoder_layer_weight->at(l)->self_attn_layernorm_weights.beta, + has_adapters_ ? gpt_decoder_layer_weight->at(l)->after_attention_adapter_weights.output_weight.bias : + gpt_decoder_layer_weight->at(l)->self_attention_weights.attention_output_weight.bias, + layernorm_eps_, + local_batch_size, + hidden_units_, + stream_); + } + else if (layernorm_type_ == LayerNormType::post_layernorm) { + invokeAddBiasResidualLayerNorm( + after_adapter_attn_output_, + decoder_input, + has_adapters_ ? gpt_decoder_layer_weight->at(l)->after_attention_adapter_weights.output_weight.bias : + gpt_decoder_layer_weight->at(l)->self_attention_weights.attention_output_weight.bias, + gpt_decoder_layer_weight->at(l)->pre_layernorm_weights.gamma, + gpt_decoder_layer_weight->at(l)->pre_layernorm_weights.beta, + layernorm_eps_, + local_batch_size, + hidden_units_, + stream_); + } sync_check_cuda_error(); - std::vector ffn_input_tensors{ - Tensor{MEMORY_GPU, data_type, {local_batch_size, hidden_units_}, normed_self_attn_output_}}; + T* ffn_output_ptr = has_adapters_ ? self_attn_output_ : decoder_output; + + std::vector ffn_input_tensors{Tensor{ + MEMORY_GPU, + data_type, + {local_batch_size, hidden_units_}, + layernorm_type_ == LayerNormType::pre_layernorm ? normed_self_attn_output_ : after_adapter_attn_output_}}; std::vector ffn_output_tensors{ - Tensor{MEMORY_GPU, data_type, {local_batch_size, hidden_units_}, decoder_output}}; + Tensor{MEMORY_GPU, data_type, {local_batch_size, hidden_units_}, ffn_output_ptr}}; + + ffn_layer_->resetInterSize(inter_size_ / tensor_para_.world_size_); ffn_layer_->forward(&ffn_output_tensors, &ffn_input_tensors, &gpt_decoder_layer_weight->at(l)->ffn_weights); - invokeAddBiasResidual(decoder_output, - self_attn_output_, - gpt_decoder_layer_weight->at(l)->ffn_weights.output_weight.bias, - local_batch_size, - hidden_units_, - stream_); + + // the adapter after ffn + if (has_adapters_) { + invokeAddBias(ffn_output_ptr, + gpt_decoder_layer_weight->at(l)->ffn_weights.output_weight.bias, + local_batch_size, + hidden_units_, + stream_); + + std::vector ffn_input_tensors{ + Tensor{MEMORY_GPU, data_type, {local_batch_size, hidden_units_}, ffn_output_ptr}}; + std::vector ffn_output_tensors{ + Tensor{MEMORY_GPU, data_type, {local_batch_size, hidden_units_}, decoder_output}}; + + ffn_layer_->resetInterSize(adapter_inter_size_ / tensor_para_.world_size_); + ffn_layer_->forward( + &ffn_output_tensors, &ffn_input_tensors, &gpt_decoder_layer_weight->at(l)->after_ffn_adapter_weights); + } + + if (layernorm_type_ == LayerNormType::pre_layernorm) { + invokeAddBiasResidual(decoder_output, + after_adapter_attn_output_, + has_adapters_ ? ffn_output_ptr : nullptr, + has_adapters_ ? + gpt_decoder_layer_weight->at(l)->after_ffn_adapter_weights.output_weight.bias : + gpt_decoder_layer_weight->at(l)->ffn_weights.output_weight.bias, + local_batch_size, + hidden_units_, + stream_); + } + else if (layernorm_type_ == LayerNormType::post_layernorm) { + invokeAddBiasResidualLayerNorm( + decoder_output, + after_adapter_attn_output_, + has_adapters_ ? 
gpt_decoder_layer_weight->at(l)->after_ffn_adapter_weights.output_weight.bias : + gpt_decoder_layer_weight->at(l)->ffn_weights.output_weight.bias, + gpt_decoder_layer_weight->at(l)->self_attn_layernorm_weights.gamma, + gpt_decoder_layer_weight->at(l)->self_attn_layernorm_weights.beta, + layernorm_eps_, + local_batch_size, + hidden_units_, + stream_); + } sync_check_cuda_error(); if (isLastLayerParallelId(l) == true && pipeline_para_.rank_ != pipeline_para_.world_size_ - 1 diff --git a/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoder.h b/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoder.h index 43b665cfb..a11075437 100644 --- a/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoder.h +++ b/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoder.h @@ -20,6 +20,7 @@ #include "src/fastertransformer/kernels/layernorm_kernels.h" #include "src/fastertransformer/layers/BaseLayer.h" #include "src/fastertransformer/layers/TensorParallelGeluFfnLayer.h" +#include "src/fastertransformer/layers/TensorParallelReluFfnLayer.h" #include "src/fastertransformer/layers/attention_layers/TensorParallelDecoderSelfAttentionLayer.h" #include "src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoderLayerWeight.h" #include "src/fastertransformer/utils/Tensor.h" @@ -35,10 +36,18 @@ class ParallelGptDecoder: public BaseLayer { // buffer handling size_t max_batch_size_ = 0; // meta data - size_t head_num_; - size_t size_per_head_; - size_t inter_size_; - size_t num_layer_; + size_t head_num_; + size_t size_per_head_; + size_t inter_size_; + size_t num_layer_; + float layernorm_eps_; + LayerNormType layernorm_type_; + ActivationType activation_type_; + + // adapter + bool has_adapters_; + size_t adapter_inter_size_; + T* after_adapter_attn_output_; int int8_mode_ = 0; @@ -49,51 +58,52 @@ class ParallelGptDecoder: public BaseLayer { NcclParam pipeline_para_; std::shared_ptr custom_all_reduce_comm_; - int enable_custom_all_reduce_; + int enable_custom_all_reduce_; // buffers - T* decoder_normed_input_ = nullptr; - T* self_attn_output_ = nullptr; + T* decoder_normed_input_ = nullptr; + T* self_attn_output_ = nullptr; T* normed_self_attn_output_ = nullptr; - T* decoder_layer_output_ = nullptr; + T* decoder_layer_output_ = nullptr; BaseAttentionLayer* self_attention_layer_; - FfnLayer* ffn_layer_; + FfnLayer* ffn_layer_; void initialize(); void allocateBuffer() override; void allocateBuffer(size_t batch_size); void freeBuffer() override; - bool isValidBatchSize(size_t batch_size); bool isValidLayerParallelId(uint l); bool isFirstLayerParallelId(uint l); bool isLastLayerParallelId(uint l); - int getFirstLayerParallelId(); + int getFirstLayerParallelId(); protected: public: - ParallelGptDecoder(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - size_t inter_size, - size_t num_layer, - NcclParam tensor_para, - NcclParam pipeline_para, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse = false, - int int8_mode = 0, - std::shared_ptr custom_all_reduce_comm = nullptr, - int enable_custom_all_reduce_ = 0); + ParallelGptDecoder(size_t max_batch_size, + size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + float layernorm_eps, + gptVariantParams gpt_variant_params, + NcclParam tensor_para, + NcclParam pipeline_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse = false, + int 
int8_mode = 0, + std::shared_ptr custom_all_reduce_comm = nullptr, + int enable_custom_all_reduce_ = 0); ParallelGptDecoder(ParallelGptDecoder const& decoder); ~ParallelGptDecoder(); - void forward(std::vector* output_tensors, - const std::vector* input_tensors, + void forward(std::vector* output_tensors, + const std::vector* input_tensors, const std::vector*>* decoder_layer_weights); }; diff --git a/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoderLayerWeight.cc b/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoderLayerWeight.cc index 5c20fa91f..bc98383b4 100644 --- a/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoderLayerWeight.cc +++ b/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoderLayerWeight.cc @@ -20,16 +20,18 @@ namespace fastertransformer { template -ParallelGptDecoderLayerWeight::ParallelGptDecoderLayerWeight(const int hidden_units, - const int inter_size, - const int tensor_para_size, - const int tensor_para_rank, - const int int8_mode): +ParallelGptDecoderLayerWeight::ParallelGptDecoderLayerWeight(const int hidden_units, + const int inter_size, + const int tensor_para_size, + const int tensor_para_rank, + const int int8_mode, + gptVariantParams gpt_variant_params): hidden_units_(hidden_units), inter_size_(inter_size), tensor_para_size_(tensor_para_size), tensor_para_rank_(tensor_para_rank), - int8_mode_(int8_mode) + int8_mode_(int8_mode), + gpt_variant_params_(gpt_variant_params) { mallocWeights(); setWeightPtr(); @@ -47,37 +49,63 @@ template ParallelGptDecoderLayerWeight::~ParallelGptDecoderLayerWeight() { if (is_maintain_buffer == true) { - for (int i = 0; i < 12; i++) { - deviceFree(weights_ptr[i]); + for (int i = 0; i < weights_ptr.size(); i++) { + if (weights_ptr[i] != nullptr) { + deviceFree(weights_ptr[i]); + } } - pre_layernorm_weights.beta = nullptr; - pre_layernorm_weights.gamma = nullptr; - self_attention_weights.query_weight.kernel = nullptr; - self_attention_weights.query_weight.bias = nullptr; + pre_layernorm_weights.beta = nullptr; + pre_layernorm_weights.gamma = nullptr; + self_attention_weights.query_weight.kernel = nullptr; + self_attention_weights.query_weight.bias = nullptr; self_attention_weights.attention_output_weight.kernel = nullptr; - self_attention_weights.attention_output_weight.bias = nullptr; - self_attn_layernorm_weights.beta = nullptr; - self_attn_layernorm_weights.gamma = nullptr; + self_attention_weights.attention_output_weight.bias = nullptr; + self_attn_layernorm_weights.beta = nullptr; + self_attn_layernorm_weights.gamma = nullptr; ffn_weights.intermediate_weight.kernel = nullptr; - ffn_weights.intermediate_weight.bias = nullptr; - ffn_weights.output_weight.kernel = nullptr; - ffn_weights.output_weight.bias = nullptr; + ffn_weights.intermediate_weight.bias = nullptr; + ffn_weights.output_weight.kernel = nullptr; + ffn_weights.output_weight.bias = nullptr; + + after_attention_adapter_weights.intermediate_weight.kernel = nullptr; + after_attention_adapter_weights.intermediate_weight.bias = nullptr; + after_attention_adapter_weights.output_weight.kernel = nullptr; + after_attention_adapter_weights.output_weight.bias = nullptr; + + after_ffn_adapter_weights.intermediate_weight.kernel = nullptr; + after_ffn_adapter_weights.intermediate_weight.bias = nullptr; + after_ffn_adapter_weights.output_weight.kernel = nullptr; + after_ffn_adapter_weights.output_weight.bias = nullptr; if (int8_mode_ != 0) { - for (int i = 0; i < 4; i++) { - deviceFree(int8_weights_ptr[i]); - deviceFree(scale_ptr[i]); + for (int 
i = 0; i < int8_weights_ptr.size(); i++) { + if (int8_weights_ptr[i] != nullptr) { + deviceFree(int8_weights_ptr[i]); + } } - self_attention_weights.query_weight.int8_kernel = nullptr; - self_attention_weights.query_weight.scale = nullptr; - self_attention_weights.attention_output_weight.int8_kernel = nullptr; - self_attention_weights.attention_output_weight.scale = nullptr; - ffn_weights.intermediate_weight.int8_kernel = nullptr; - ffn_weights.intermediate_weight.scale = nullptr; - ffn_weights.output_weight.int8_kernel = nullptr; - ffn_weights.output_weight.scale = nullptr; + for (int i = 0; i < scale_ptr.size(); i++) { + if (scale_ptr[i] != nullptr) { + deviceFree(scale_ptr[i]); + } + } + self_attention_weights.query_weight.int8_kernel = nullptr; + self_attention_weights.query_weight.scale = nullptr; + self_attention_weights.attention_output_weight.int8_kernel = nullptr; + self_attention_weights.attention_output_weight.scale = nullptr; + ffn_weights.intermediate_weight.int8_kernel = nullptr; + ffn_weights.intermediate_weight.scale = nullptr; + ffn_weights.output_weight.int8_kernel = nullptr; + ffn_weights.output_weight.scale = nullptr; + after_attention_adapter_weights.intermediate_weight.int8_kernel = nullptr; + after_attention_adapter_weights.intermediate_weight.scale = nullptr; + after_attention_adapter_weights.output_weight.int8_kernel = nullptr; + after_attention_adapter_weights.output_weight.scale = nullptr; + after_ffn_adapter_weights.intermediate_weight.int8_kernel = nullptr; + after_ffn_adapter_weights.intermediate_weight.scale = nullptr; + after_ffn_adapter_weights.output_weight.int8_kernel = nullptr; + after_ffn_adapter_weights.output_weight.scale = nullptr; } is_maintain_buffer = false; @@ -90,7 +118,8 @@ ParallelGptDecoderLayerWeight::ParallelGptDecoderLayerWeight(const ParallelGp inter_size_(other.inter_size_), tensor_para_size_(other.tensor_para_size_), tensor_para_rank_(other.tensor_para_rank_), - int8_mode_(other.int8_mode_) + int8_mode_(other.int8_mode_), + gpt_variant_params_(other.gpt_variant_params_) { mallocWeights(); cudaD2Dcpy(weights_ptr[0], other.weights_ptr[0], hidden_units_); @@ -107,6 +136,25 @@ ParallelGptDecoderLayerWeight::ParallelGptDecoderLayerWeight(const ParallelGp cudaD2Dcpy(weights_ptr[10], other.weights_ptr[10], inter_size_ / tensor_para_size_ * hidden_units_); cudaD2Dcpy(weights_ptr[11], other.weights_ptr[11], hidden_units_); + if (gpt_variant_params_.has_adapters) { + cudaD2Dcpy(weights_ptr[12], + other.weights_ptr[12], + hidden_units_ * gpt_variant_params_.adapter_inter_size / tensor_para_size_); + cudaD2Dcpy(weights_ptr[13], other.weights_ptr[13], gpt_variant_params_.adapter_inter_size / tensor_para_size_); + cudaD2Dcpy(weights_ptr[14], + other.weights_ptr[14], + gpt_variant_params_.adapter_inter_size / tensor_para_size_ * hidden_units_); + cudaD2Dcpy(weights_ptr[15], other.weights_ptr[15], hidden_units_); + cudaD2Dcpy(weights_ptr[16], + other.weights_ptr[16], + hidden_units_ * gpt_variant_params_.adapter_inter_size / tensor_para_size_); + cudaD2Dcpy(weights_ptr[17], other.weights_ptr[17], gpt_variant_params_.adapter_inter_size / tensor_para_size_); + cudaD2Dcpy(weights_ptr[18], + other.weights_ptr[18], + gpt_variant_params_.adapter_inter_size / tensor_para_size_ * hidden_units_); + cudaD2Dcpy(weights_ptr[19], other.weights_ptr[19], hidden_units_); + } + if (int8_mode_ != 0) { cudaD2Dcpy( int8_weights_ptr[0], other.int8_weights_ptr[0], hidden_units_ * 3 * hidden_units_ / tensor_para_size_); @@ -117,6 +165,24 @@ 
ParallelGptDecoderLayerWeight::ParallelGptDecoderLayerWeight(const ParallelGp cudaD2Dcpy(scale_ptr[1], other.scale_ptr[1], hidden_units_); cudaD2Dcpy(scale_ptr[2], other.scale_ptr[2], inter_size_ / tensor_para_size_); cudaD2Dcpy(scale_ptr[3], other.scale_ptr[3], hidden_units_); + if (gpt_variant_params_.has_adapters) { + cudaD2Dcpy(int8_weights_ptr[4], + other.int8_weights_ptr[4], + hidden_units_ * gpt_variant_params_.adapter_inter_size / tensor_para_size_); + cudaD2Dcpy(int8_weights_ptr[5], + other.int8_weights_ptr[5], + gpt_variant_params_.adapter_inter_size / tensor_para_size_ * hidden_units_); + cudaD2Dcpy(int8_weights_ptr[6], + other.int8_weights_ptr[6], + hidden_units_ * gpt_variant_params_.adapter_inter_size / tensor_para_size_); + cudaD2Dcpy(int8_weights_ptr[7], + other.int8_weights_ptr[7], + gpt_variant_params_.adapter_inter_size / tensor_para_size_ * hidden_units_); + cudaD2Dcpy(scale_ptr[4], other.scale_ptr[4], gpt_variant_params_.adapter_inter_size / tensor_para_size_); + cudaD2Dcpy(scale_ptr[5], other.scale_ptr[5], hidden_units_); + cudaD2Dcpy(scale_ptr[6], other.scale_ptr[6], gpt_variant_params_.adapter_inter_size / tensor_para_size_); + cudaD2Dcpy(scale_ptr[7], other.scale_ptr[7], hidden_units_); + } } setWeightPtr(); @@ -126,11 +192,12 @@ template ParallelGptDecoderLayerWeight& ParallelGptDecoderLayerWeight::operator=(const ParallelGptDecoderLayerWeight& other) { - hidden_units_ = other.hidden_units_; - inter_size_ = other.inter_size_; - tensor_para_size_ = other.tensor_para_size_; - tensor_para_rank_ = other.tensor_para_rank_; - int8_mode_ = other.int8_mode_; + hidden_units_ = other.hidden_units_; + inter_size_ = other.inter_size_; + tensor_para_size_ = other.tensor_para_size_; + tensor_para_rank_ = other.tensor_para_rank_; + int8_mode_ = other.int8_mode_; + gpt_variant_params_ = other.gpt_variant_params_; mallocWeights(); cudaD2Dcpy(weights_ptr[0], other.weights_ptr[0], hidden_units_); @@ -147,6 +214,25 @@ ParallelGptDecoderLayerWeight::operator=(const ParallelGptDecoderLayerWeight& cudaD2Dcpy(weights_ptr[10], other.weights_ptr[10], inter_size_ / tensor_para_size_ * hidden_units_); cudaD2Dcpy(weights_ptr[11], other.weights_ptr[11], hidden_units_); + if (gpt_variant_params_.has_adapters) { + cudaD2Dcpy(weights_ptr[12], + other.weights_ptr[12], + hidden_units_ * gpt_variant_params_.adapter_inter_size / tensor_para_size_); + cudaD2Dcpy(weights_ptr[13], other.weights_ptr[13], gpt_variant_params_.adapter_inter_size / tensor_para_size_); + cudaD2Dcpy(weights_ptr[14], + other.weights_ptr[14], + gpt_variant_params_.adapter_inter_size / tensor_para_size_ * hidden_units_); + cudaD2Dcpy(weights_ptr[15], other.weights_ptr[15], hidden_units_); + cudaD2Dcpy(weights_ptr[16], + other.weights_ptr[16], + hidden_units_ * gpt_variant_params_.adapter_inter_size / tensor_para_size_); + cudaD2Dcpy(weights_ptr[17], other.weights_ptr[17], gpt_variant_params_.adapter_inter_size / tensor_para_size_); + cudaD2Dcpy(weights_ptr[18], + other.weights_ptr[18], + gpt_variant_params_.adapter_inter_size / tensor_para_size_ * hidden_units_); + cudaD2Dcpy(weights_ptr[19], other.weights_ptr[19], hidden_units_); + } + if (int8_mode_ != 0) { cudaD2Dcpy( int8_weights_ptr[0], other.int8_weights_ptr[0], hidden_units_ * 3 * hidden_units_ / tensor_para_size_); @@ -157,6 +243,24 @@ ParallelGptDecoderLayerWeight::operator=(const ParallelGptDecoderLayerWeight& cudaD2Dcpy(scale_ptr[1], other.scale_ptr[1], hidden_units_); cudaD2Dcpy(scale_ptr[2], other.scale_ptr[2], inter_size_ / tensor_para_size_); 
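// ---------------------------------------------------------------------------
// Illustrative summary (not a hunk of this patch): with
// gpt_variant_params_.has_adapters set, the per-layer weight table grows from
// 12 to 20 entries. From the allocations, copies and loads in this file, the
// adapter slots appear to be laid out as:
//
//   weights_ptr[12]  after_attention_adapter  h->4h kernel  [hidden_units_, adapter_inter_size / tensor_para_size_]
//   weights_ptr[13]  after_attention_adapter  h->4h bias    [adapter_inter_size / tensor_para_size_]
//   weights_ptr[14]  after_attention_adapter  4h->h kernel  [adapter_inter_size / tensor_para_size_, hidden_units_]
//   weights_ptr[15]  after_attention_adapter  4h->h bias    [hidden_units_]
//   weights_ptr[16..19]  after_ffn_adapter, same four shapes in the same order
//
// The int8 path mirrors this layout: int8_weights_ptr[4..7] hold the four
// quantized adapter kernels and scale_ptr[4..7] their per-channel scales.
// ---------------------------------------------------------------------------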
cudaD2Dcpy(scale_ptr[3], other.scale_ptr[3], hidden_units_); + if (gpt_variant_params_.has_adapters) { + cudaD2Dcpy(int8_weights_ptr[4], + other.int8_weights_ptr[4], + hidden_units_ * gpt_variant_params_.adapter_inter_size / tensor_para_size_); + cudaD2Dcpy(int8_weights_ptr[5], + other.int8_weights_ptr[5], + gpt_variant_params_.adapter_inter_size / tensor_para_size_ * hidden_units_); + cudaD2Dcpy(int8_weights_ptr[6], + other.int8_weights_ptr[6], + hidden_units_ * gpt_variant_params_.adapter_inter_size / tensor_para_size_); + cudaD2Dcpy(int8_weights_ptr[7], + other.int8_weights_ptr[7], + gpt_variant_params_.adapter_inter_size / tensor_para_size_ * hidden_units_); + cudaD2Dcpy(scale_ptr[4], other.scale_ptr[4], gpt_variant_params_.adapter_inter_size / tensor_para_size_); + cudaD2Dcpy(scale_ptr[5], other.scale_ptr[5], hidden_units_); + cudaD2Dcpy(scale_ptr[6], other.scale_ptr[6], gpt_variant_params_.adapter_inter_size / tensor_para_size_); + cudaD2Dcpy(scale_ptr[7], other.scale_ptr[7], hidden_units_); + } } setWeightPtr(); @@ -168,41 +272,78 @@ void ParallelGptDecoderLayerWeight::loadModel(std::string dir_path, FtCudaDat { FT_CHECK(is_maintain_buffer == true); - loadWeightFromBin(weights_ptr[0], {(int)hidden_units_}, dir_path + ".input_layernorm.bias.bin", model_file_type); - loadWeightFromBin( - weights_ptr[1], {(int)hidden_units_}, dir_path + ".input_layernorm.weight.bin", model_file_type); + loadWeightFromBin(weights_ptr[0], {hidden_units_}, dir_path + ".input_layernorm.bias.bin", model_file_type); + loadWeightFromBin(weights_ptr[1], {hidden_units_}, dir_path + ".input_layernorm.weight.bin", model_file_type); loadWeightFromBin(weights_ptr[2], - {(int)hidden_units_, (int)(3 * hidden_units_ / tensor_para_size_)}, + {hidden_units_, 3 * hidden_units_ / tensor_para_size_}, dir_path + ".attention.query_key_value.weight." + std::to_string(tensor_para_rank_) + ".bin", model_file_type); loadWeightFromBin(weights_ptr[3], - {3, (int)(hidden_units_ / tensor_para_size_)}, + {3, hidden_units_ / tensor_para_size_}, dir_path + ".attention.query_key_value.bias." + std::to_string(tensor_para_rank_) + ".bin", model_file_type); loadWeightFromBin(weights_ptr[4], - {(int)(hidden_units_ / tensor_para_size_), (int)hidden_units_}, + {hidden_units_ / tensor_para_size_, hidden_units_}, dir_path + ".attention.dense.weight." + std::to_string(tensor_para_rank_) + ".bin", model_file_type); - loadWeightFromBin(weights_ptr[5], {(int)hidden_units_}, dir_path + ".attention.dense.bias.bin", model_file_type); + loadWeightFromBin(weights_ptr[5], {hidden_units_}, dir_path + ".attention.dense.bias.bin", model_file_type); loadWeightFromBin( - weights_ptr[6], {(int)hidden_units_}, dir_path + ".post_attention_layernorm.bias.bin", model_file_type); + weights_ptr[6], {hidden_units_}, dir_path + ".post_attention_layernorm.bias.bin", model_file_type); loadWeightFromBin( - weights_ptr[7], {(int)hidden_units_}, dir_path + ".post_attention_layernorm.weight.bin", model_file_type); + weights_ptr[7], {hidden_units_}, dir_path + ".post_attention_layernorm.weight.bin", model_file_type); loadWeightFromBin(weights_ptr[8], - {(int)hidden_units_, (int)(inter_size_ / tensor_para_size_)}, + {hidden_units_, inter_size_ / tensor_para_size_}, dir_path + ".mlp.dense_h_to_4h.weight." + std::to_string(tensor_para_rank_) + ".bin", model_file_type); loadWeightFromBin(weights_ptr[9], - {(int)(inter_size_ / tensor_para_size_)}, + {inter_size_ / tensor_para_size_}, dir_path + ".mlp.dense_h_to_4h.bias." 
+ std::to_string(tensor_para_rank_) + ".bin", model_file_type); loadWeightFromBin(weights_ptr[10], - {(int)(inter_size_ / tensor_para_size_), (int)hidden_units_}, + {inter_size_ / tensor_para_size_, hidden_units_}, dir_path + ".mlp.dense_4h_to_h.weight." + std::to_string(tensor_para_rank_) + ".bin", model_file_type); - loadWeightFromBin( - weights_ptr[11], {(int)hidden_units_}, dir_path + ".mlp.dense_4h_to_h.bias.bin", model_file_type); + loadWeightFromBin(weights_ptr[11], {hidden_units_}, dir_path + ".mlp.dense_4h_to_h.bias.bin", model_file_type); + + if (gpt_variant_params_.has_adapters) { + loadWeightFromBin(weights_ptr[12], + {hidden_units_, gpt_variant_params_.adapter_inter_size / tensor_para_size_}, + dir_path + ".after_attention_adapter.dense_h_to_4h.weight." + + std::to_string(tensor_para_rank_) + ".bin", + model_file_type); + loadWeightFromBin(weights_ptr[13], + {gpt_variant_params_.adapter_inter_size / tensor_para_size_}, + dir_path + ".after_attention_adapter.dense_h_to_4h.bias." + + std::to_string(tensor_para_rank_) + ".bin", + model_file_type); + loadWeightFromBin(weights_ptr[14], + {gpt_variant_params_.adapter_inter_size / tensor_para_size_, hidden_units_}, + dir_path + ".after_attention_adapter.dense_4h_to_h.weight." + + std::to_string(tensor_para_rank_) + ".bin", + model_file_type); + loadWeightFromBin(weights_ptr[15], + {hidden_units_}, + dir_path + ".after_attention_adapter.dense_4h_to_h.bias.bin", + model_file_type); + loadWeightFromBin(weights_ptr[16], + {hidden_units_, gpt_variant_params_.adapter_inter_size / tensor_para_size_}, + dir_path + ".after_ffn_adapter.dense_h_to_4h.weight." + std::to_string(tensor_para_rank_) + + ".bin", + model_file_type); + loadWeightFromBin(weights_ptr[17], + {gpt_variant_params_.adapter_inter_size / tensor_para_size_}, + dir_path + ".after_ffn_adapter.dense_h_to_4h.bias." + std::to_string(tensor_para_rank_) + + ".bin", + model_file_type); + loadWeightFromBin(weights_ptr[18], + {gpt_variant_params_.adapter_inter_size / tensor_para_size_, hidden_units_}, + dir_path + ".after_ffn_adapter.dense_4h_to_h.weight." 
+ std::to_string(tensor_para_rank_) + + ".bin", + model_file_type); + loadWeightFromBin( + weights_ptr[19], {hidden_units_}, dir_path + ".after_ffn_adapter.dense_4h_to_h.bias.bin", model_file_type); + } if (int8_mode_ != 0) { transposeCalibrateQuantizeWeight(); @@ -212,29 +353,47 @@ void ParallelGptDecoderLayerWeight::loadModel(std::string dir_path, FtCudaDat template void ParallelGptDecoderLayerWeight::setWeightPtr() { - pre_layernorm_weights.beta = weights_ptr[0]; - pre_layernorm_weights.gamma = weights_ptr[1]; - self_attention_weights.query_weight.kernel = weights_ptr[2]; - self_attention_weights.query_weight.bias = weights_ptr[3]; + pre_layernorm_weights.beta = weights_ptr[0]; + pre_layernorm_weights.gamma = weights_ptr[1]; + self_attention_weights.query_weight.kernel = weights_ptr[2]; + self_attention_weights.query_weight.bias = weights_ptr[3]; self_attention_weights.attention_output_weight.kernel = weights_ptr[4]; - self_attention_weights.attention_output_weight.bias = weights_ptr[5]; - self_attn_layernorm_weights.beta = weights_ptr[6]; - self_attn_layernorm_weights.gamma = weights_ptr[7]; + self_attention_weights.attention_output_weight.bias = weights_ptr[5]; + self_attn_layernorm_weights.beta = weights_ptr[6]; + self_attn_layernorm_weights.gamma = weights_ptr[7]; ffn_weights.intermediate_weight.kernel = weights_ptr[8]; - ffn_weights.intermediate_weight.bias = weights_ptr[9]; - ffn_weights.output_weight.kernel = weights_ptr[10]; - ffn_weights.output_weight.bias = weights_ptr[11]; + ffn_weights.intermediate_weight.bias = weights_ptr[9]; + ffn_weights.output_weight.kernel = weights_ptr[10]; + ffn_weights.output_weight.bias = weights_ptr[11]; + + after_attention_adapter_weights.intermediate_weight.kernel = weights_ptr[12]; + after_attention_adapter_weights.intermediate_weight.bias = weights_ptr[13]; + after_attention_adapter_weights.output_weight.kernel = weights_ptr[14]; + after_attention_adapter_weights.output_weight.bias = weights_ptr[15]; + + after_ffn_adapter_weights.intermediate_weight.kernel = weights_ptr[16]; + after_ffn_adapter_weights.intermediate_weight.bias = weights_ptr[17]; + after_ffn_adapter_weights.output_weight.kernel = weights_ptr[18]; + after_ffn_adapter_weights.output_weight.bias = weights_ptr[19]; if (int8_mode_ != 0) { - self_attention_weights.query_weight.int8_kernel = int8_weights_ptr[0]; - self_attention_weights.query_weight.scale = scale_ptr[0]; - self_attention_weights.attention_output_weight.int8_kernel = int8_weights_ptr[1]; - self_attention_weights.attention_output_weight.scale = scale_ptr[1]; - ffn_weights.intermediate_weight.int8_kernel = int8_weights_ptr[2]; - ffn_weights.intermediate_weight.scale = scale_ptr[2]; - ffn_weights.output_weight.int8_kernel = int8_weights_ptr[3]; - ffn_weights.output_weight.scale = scale_ptr[3]; + self_attention_weights.query_weight.int8_kernel = int8_weights_ptr[0]; + self_attention_weights.query_weight.scale = scale_ptr[0]; + self_attention_weights.attention_output_weight.int8_kernel = int8_weights_ptr[1]; + self_attention_weights.attention_output_weight.scale = scale_ptr[1]; + ffn_weights.intermediate_weight.int8_kernel = int8_weights_ptr[2]; + ffn_weights.intermediate_weight.scale = scale_ptr[2]; + ffn_weights.output_weight.int8_kernel = int8_weights_ptr[3]; + ffn_weights.output_weight.scale = scale_ptr[3]; + after_attention_adapter_weights.intermediate_weight.int8_kernel = int8_weights_ptr[4]; + after_attention_adapter_weights.intermediate_weight.scale = scale_ptr[4]; + 
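// Illustrative usage sketch (not a hunk of this patch): constructing and
// loading one layer's weights for an adapter-equipped GPT variant. The
// constructor and the gptVariantParams fields are the ones this diff adds in
// ParallelGptDecoderLayerWeight.h; the sizes, the checkpoint path and the
// FP32 file-type value below are assumptions made only for the example.
//
//   gptVariantParams variant{};                  // defaults: pre-layernorm, GeLU, eps 1e-6
//   variant.has_adapters       = true;
//   variant.adapter_inter_size = 1024;           // hypothetical adapter bottleneck size
//
//   auto* layer_weight = new ParallelGptDecoderLayerWeight<float>(
//       /*hidden_units=*/4096, /*inter_size=*/16384,
//       /*tensor_para_size=*/2, /*tensor_para_rank=*/0,
//       /*int8_mode=*/0, variant);
//
//   // Reads the 12 base tensors plus the eight ".after_attention_adapter.*"
//   // and ".after_ffn_adapter.*" files listed above.
//   layer_weight->loadModel("/path/to/c-model/2-gpu/model.layers.0", FtCudaDataType::FP32);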
after_attention_adapter_weights.output_weight.int8_kernel = int8_weights_ptr[5]; + after_attention_adapter_weights.output_weight.scale = scale_ptr[5]; + after_ffn_adapter_weights.intermediate_weight.int8_kernel = int8_weights_ptr[6]; + after_ffn_adapter_weights.intermediate_weight.scale = scale_ptr[6]; + after_ffn_adapter_weights.output_weight.int8_kernel = int8_weights_ptr[7]; + after_ffn_adapter_weights.output_weight.scale = scale_ptr[7]; } is_maintain_buffer = true; @@ -257,6 +416,17 @@ void ParallelGptDecoderLayerWeight::mallocWeights() deviceMalloc(&weights_ptr[10], inter_size_ / tensor_para_size_ * hidden_units_); deviceMalloc(&weights_ptr[11], hidden_units_); + if (gpt_variant_params_.has_adapters) { + deviceMalloc(&weights_ptr[12], hidden_units_ * gpt_variant_params_.adapter_inter_size / tensor_para_size_); + deviceMalloc(&weights_ptr[13], gpt_variant_params_.adapter_inter_size / tensor_para_size_); + deviceMalloc(&weights_ptr[14], gpt_variant_params_.adapter_inter_size / tensor_para_size_ * hidden_units_); + deviceMalloc(&weights_ptr[15], hidden_units_); + deviceMalloc(&weights_ptr[16], hidden_units_ * gpt_variant_params_.adapter_inter_size / tensor_para_size_); + deviceMalloc(&weights_ptr[17], gpt_variant_params_.adapter_inter_size / tensor_para_size_); + deviceMalloc(&weights_ptr[18], gpt_variant_params_.adapter_inter_size / tensor_para_size_ * hidden_units_); + deviceMalloc(&weights_ptr[19], hidden_units_); + } + if (int8_mode_ != 0) { deviceMalloc(&int8_weights_ptr[0], hidden_units_ * 3 * hidden_units_ / tensor_para_size_); deviceMalloc(&int8_weights_ptr[1], hidden_units_ / tensor_para_size_ * hidden_units_); @@ -267,6 +437,21 @@ void ParallelGptDecoderLayerWeight::mallocWeights() deviceMalloc(&scale_ptr[1], hidden_units_); deviceMalloc(&scale_ptr[2], inter_size_ / tensor_para_size_); deviceMalloc(&scale_ptr[3], hidden_units_); + + if (gpt_variant_params_.has_adapters) { + deviceMalloc(&int8_weights_ptr[4], + hidden_units_ * gpt_variant_params_.adapter_inter_size / tensor_para_size_); + deviceMalloc(&int8_weights_ptr[5], + gpt_variant_params_.adapter_inter_size / tensor_para_size_ * hidden_units_); + deviceMalloc(&int8_weights_ptr[6], + hidden_units_ * gpt_variant_params_.adapter_inter_size / tensor_para_size_); + deviceMalloc(&int8_weights_ptr[7], + gpt_variant_params_.adapter_inter_size / tensor_para_size_ * hidden_units_); + deviceMalloc(&scale_ptr[4], gpt_variant_params_.adapter_inter_size / tensor_para_size_); + deviceMalloc(&scale_ptr[5], hidden_units_); + deviceMalloc(&scale_ptr[6], gpt_variant_params_.adapter_inter_size / tensor_para_size_); + deviceMalloc(&scale_ptr[7], hidden_units_); + } } } @@ -275,32 +460,46 @@ template void ParallelGptDecoderLayerWeight::compress_weights(cublasMMWrapper& cublas_wrapper, int hidden_dim) { hidden_units_ = hidden_dim; - inter_size_ = 4 * hidden_units_; - - const size_t num_sparse_weights = 4; - size_t shapes[num_sparse_weights][2] = {{hidden_units_, 3 * hidden_units_ / tensor_para_size_}, - {hidden_units_ / tensor_para_size_, hidden_units_}, - {hidden_units_, inter_size_ / tensor_para_size_}, - {inter_size_ / tensor_para_size_, hidden_units_}}; + inter_size_ = 4 * hidden_units_; + + const size_t num_sparse_weights = 8; + size_t shapes[num_sparse_weights][2] = { + {hidden_units_, 3 * hidden_units_ / tensor_para_size_}, + {hidden_units_ / tensor_para_size_, hidden_units_}, + {hidden_units_, inter_size_ / tensor_para_size_}, + {inter_size_ / tensor_para_size_, hidden_units_}, + {hidden_units_, gpt_variant_params_.adapter_inter_size / 
tensor_para_size_}, + {gpt_variant_params_.adapter_inter_size / tensor_para_size_, hidden_units_}, + {hidden_units_, gpt_variant_params_.adapter_inter_size / tensor_para_size_}, + {gpt_variant_params_.adapter_inter_size / tensor_para_size_, hidden_units_}}; const T* dense_weights[num_sparse_weights] = {self_attention_weights.query_weight.kernel, self_attention_weights.attention_output_weight.kernel, ffn_weights.intermediate_weight.kernel, - ffn_weights.output_weight.kernel}; - - for (size_t i = 0; i < num_sparse_weights; ++i) { - int m = shapes[i][1]; - int k = shapes[i][0]; + ffn_weights.output_weight.kernel, + after_attention_adapter_weights.intermediate_weight.kernel, + after_attention_adapter_weights.output_weight.kernel, + after_ffn_adapter_weights.intermediate_weight.kernel, + after_ffn_adapter_weights.output_weight.kernel}; + + size_t real_num_sparse_weights = gpt_variant_params_.has_adapters ? num_sparse_weights : (num_sparse_weights - 4); + for (size_t i = 0; i < real_num_sparse_weights; ++i) { + int m = shapes[i][1]; + int k = shapes[i][0]; size_t compressed_size = cublas_wrapper.getSparseMatrixSize(m, k); deviceMalloc(&sp_weights_ptr[i], static_cast(compressed_size), false); cublas_wrapper.compressMatrix(dense_weights[i], sp_weights_ptr[i], m, k); } - self_attention_weights.query_weight.sp_kernel = sp_weights_ptr[0]; - self_attention_weights.attention_output_weight.sp_kernel = sp_weights_ptr[1]; - ffn_weights.intermediate_weight.sp_kernel = sp_weights_ptr[2]; - ffn_weights.output_weight.sp_kernel = sp_weights_ptr[3]; - is_maintain_sp_buffer = true; + self_attention_weights.query_weight.sp_kernel = sp_weights_ptr[0]; + self_attention_weights.attention_output_weight.sp_kernel = sp_weights_ptr[1]; + ffn_weights.intermediate_weight.sp_kernel = sp_weights_ptr[2]; + ffn_weights.output_weight.sp_kernel = sp_weights_ptr[3]; + after_attention_adapter_weights.intermediate_weight.sp_kernel = sp_weights_ptr[4]; + after_attention_adapter_weights.output_weight.sp_kernel = sp_weights_ptr[5]; + after_ffn_adapter_weights.intermediate_weight.sp_kernel = sp_weights_ptr[6]; + after_ffn_adapter_weights.output_weight.sp_kernel = sp_weights_ptr[7]; + is_maintain_sp_buffer = true; } #endif @@ -330,6 +529,54 @@ void ParallelGptDecoderLayerWeight::transposeCalibrateQuantizeWeight() scale_ptr[3], weights_ptr[10], inter_size_ / tensor_para_size_, hidden_units_, stream_); invokeLdnTransposeQuantizeWeightPerChannel( int8_weights_ptr[3], scale_ptr[3], weights_ptr[10], inter_size_ / tensor_para_size_, hidden_units_, stream_); + + invokeLdnCalibrateWeightPerChannel(scale_ptr[4], + weights_ptr[12], + hidden_units_, + gpt_variant_params_.adapter_inter_size / tensor_para_size_, + stream_); + invokeLdnTransposeQuantizeWeightPerChannel(int8_weights_ptr[4], + scale_ptr[4], + weights_ptr[12], + hidden_units_, + gpt_variant_params_.adapter_inter_size / tensor_para_size_, + stream_); + + invokeLdnCalibrateWeightPerChannel(scale_ptr[5], + weights_ptr[14], + gpt_variant_params_.adapter_inter_size / tensor_para_size_, + hidden_units_, + stream_); + invokeLdnTransposeQuantizeWeightPerChannel(int8_weights_ptr[5], + scale_ptr[5], + weights_ptr[14], + gpt_variant_params_.adapter_inter_size / tensor_para_size_, + hidden_units_, + stream_); + + invokeLdnCalibrateWeightPerChannel(scale_ptr[6], + weights_ptr[16], + hidden_units_, + gpt_variant_params_.adapter_inter_size / tensor_para_size_, + stream_); + invokeLdnTransposeQuantizeWeightPerChannel(int8_weights_ptr[6], + scale_ptr[6], + weights_ptr[16], + hidden_units_, + 
gpt_variant_params_.adapter_inter_size / tensor_para_size_, + stream_); + + invokeLdnCalibrateWeightPerChannel(scale_ptr[7], + weights_ptr[18], + gpt_variant_params_.adapter_inter_size / tensor_para_size_, + hidden_units_, + stream_); + invokeLdnTransposeQuantizeWeightPerChannel(int8_weights_ptr[7], + scale_ptr[7], + weights_ptr[18], + gpt_variant_params_.adapter_inter_size / tensor_para_size_, + hidden_units_, + stream_); } template struct ParallelGptDecoderLayerWeight; diff --git a/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoderLayerWeight.h b/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoderLayerWeight.h index 7927819db..2480ef64f 100644 --- a/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoderLayerWeight.h +++ b/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoderLayerWeight.h @@ -20,26 +20,39 @@ #include "src/fastertransformer/kernels/calibrate_quantize_weight_kernels.h" #include "src/fastertransformer/kernels/layernorm_kernels.h" +#include "src/fastertransformer/layers/FfnLayer.h" #include "src/fastertransformer/layers/FfnWeight.h" #include "src/fastertransformer/layers/attention_layers/AttentionWeight.h" #include "src/fastertransformer/utils/cublasMMWrapper.h" namespace fastertransformer { +struct gptVariantParams { + // GPT default params + float layernorm_eps = 1e-6f; + LayerNormType layernorm_type = LayerNormType::pre_layernorm; + ActivationType activation_type = ActivationType::Gelu; + bool has_post_decoder_layernorm = true; + // detoxification adapters. refer to + bool has_adapters = false; + size_t adapter_inter_size = 0; +}; + template struct ParallelGptDecoderLayerWeight { public: ParallelGptDecoderLayerWeight() = default; ParallelGptDecoderLayerWeight(const int int8_mode); - ParallelGptDecoderLayerWeight(const int hidden_units, - const int inter_size, - const int tensor_para_size, - const int tensor_para_rank, - const int int8_mode = 0); + ParallelGptDecoderLayerWeight(const int hidden_units, + const int inter_size, + const int tensor_para_size, + const int tensor_para_rank, + const int int8_mode = 0, + gptVariantParams gpt_variant_params = {}); ~ParallelGptDecoderLayerWeight(); ParallelGptDecoderLayerWeight(const ParallelGptDecoderLayerWeight& other); ParallelGptDecoderLayerWeight& operator=(const ParallelGptDecoderLayerWeight& other); - void loadModel(std::string dir_path, FtCudaDataType model_file_type); + void loadModel(std::string dir_path, FtCudaDataType model_file_type); #ifdef SPARSITY_ENABLED void compress_weights(cublasMMWrapper& cublas_wrapper, int hidden_dim); #endif @@ -48,7 +61,9 @@ struct ParallelGptDecoderLayerWeight { LayerNormWeight pre_layernorm_weights; AttentionWeight self_attention_weights; LayerNormWeight self_attn_layernorm_weights; - FfnWeight ffn_weights; + FfnWeight ffn_weights; + FfnWeight after_attention_adapter_weights; + FfnWeight after_ffn_adapter_weights; private: void setWeightPtr(); @@ -57,19 +72,24 @@ struct ParallelGptDecoderLayerWeight { protected: size_t hidden_units_; size_t inter_size_; - size_t tensor_para_size_ = 1; - size_t tensor_para_rank_ = 0; - bool is_maintain_buffer = false; - int int8_mode_ = 0; - T* weights_ptr[12]; + size_t tensor_para_size_ = 1; + size_t tensor_para_rank_ = 0; + bool is_maintain_buffer = false; + int int8_mode_ = 0; + + // gpt varians params. e.g. 
detoxification adapters + gptVariantParams gpt_variant_params_; + + std::vector weights_ptr = std::vector(20, nullptr); + + std::vector int8_weights_ptr = std::vector(8, nullptr); - int8_t* int8_weights_ptr[4]; - float* scale_ptr[4]; - cudaStream_t stream_ = 0; + std::vector scale_ptr = std::vector(8, nullptr); + cudaStream_t stream_ = 0; #ifdef SPARSITY_ENABLED - T* sp_weights_ptr[4]; - bool is_maintain_sp_buffer = false; + std::vector sp_weights_ptr = std::vector(8, nullptr); + bool is_maintain_sp_buffer = false; #endif }; diff --git a/src/fastertransformer/models/multi_gpu_gpt/ParallelGptWeight.cc b/src/fastertransformer/models/multi_gpu_gpt/ParallelGptWeight.cc index ee3a82b57..6d5a0c99f 100644 --- a/src/fastertransformer/models/multi_gpu_gpt/ParallelGptWeight.cc +++ b/src/fastertransformer/models/multi_gpu_gpt/ParallelGptWeight.cc @@ -19,16 +19,19 @@ namespace fastertransformer { template -ParallelGptWeight::ParallelGptWeight(const int hidden_units, - const int inter_size, - const int vocab_size, - const int num_layer, - const int max_seq_len, - const int tensor_para_size, - const int tensor_para_rank, - const int layer_para_size, - const int layer_para_rank, - const int int8_mode): +ParallelGptWeight::ParallelGptWeight(const int hidden_units, + const int inter_size, + const int vocab_size, + const int num_layer, + const int max_seq_len, + const int tensor_para_size, + const int tensor_para_rank, + const int layer_para_size, + const int layer_para_rank, + const int int8_mode, + PromptLearningType prompt_learning_type, + std::map> prompt_learning_pair, + gptVariantParams gpt_variant_params): hidden_units_(hidden_units), inter_size_(inter_size), vocab_size_(vocab_size), @@ -38,14 +41,30 @@ ParallelGptWeight::ParallelGptWeight(const int hidden_units, tensor_para_rank_(tensor_para_rank), layer_para_size_(layer_para_size), layer_para_rank_(layer_para_rank), - int8_mode_(int8_mode) + int8_mode_(int8_mode), + prompt_learning_type_(prompt_learning_type), + prompt_learning_pair_(prompt_learning_pair), + gpt_variant_params_(gpt_variant_params) { + FT_CHECK(num_layer_ % layer_para_size_ == 0); + // set prompt weight size + if (prompt_learning_type_ == PromptLearningType::prefix_prompt) { + prompt_token_weight_size_ = 2 * num_layer_ * hidden_units_ / tensor_para_size_; + } + else if (prompt_learning_type_ == PromptLearningType::p_prompt_tuning) { + prompt_token_weight_size_ = hidden_units_; + } + + // set if load and malloc prompt weights + malloc_load_prompt_weights_ = !prompt_learning_pair_.empty() + && (prompt_learning_type_ == PromptLearningType::p_prompt_tuning + || prompt_learning_type_ == PromptLearningType::prefix_prompt); decoder_layer_weights.clear(); decoder_layer_weights.reserve(num_layer_); for (int l = 0; l < num_layer_; l++) { if (isValidLayerParallelId(l)) { decoder_layer_weights.push_back(new ParallelGptDecoderLayerWeight( - hidden_units_, inter_size_, tensor_para_size_, tensor_para_rank_, int8_mode_)); + hidden_units_, inter_size_, tensor_para_size_, tensor_para_rank_, int8_mode_, gpt_variant_params_)); } else { // Don't malloc and load these layers since we don't use them. 
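// Worked example (not a hunk of this patch): size of one task's prompt table
// as implied by prompt_token_weight_size_ in the constructor above. The helper
// name and the concrete numbers are invented for illustration.
size_t promptTableElems(PromptLearningType type,
                        size_t prompt_length,
                        size_t num_layer,
                        size_t hidden_units,
                        size_t tensor_para_size)
{
    // prefix prompt: one K and one V vector per layer and per virtual token,
    // sliced across tensor parallelism; p/prompt tuning: one hidden_units-wide
    // embedding per virtual token.
    const size_t per_token = (type == PromptLearningType::prefix_prompt) ?
                                 2 * num_layer * hidden_units / tensor_para_size :
                                 hidden_units;
    return prompt_length * per_token;
}
// e.g. a 16-token prefix prompt on a 24-layer, hidden_units = 2048 model with
// 2-way tensor parallelism: 16 * (2 * 24 * 2048 / 2) = 786,432 elements.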
@@ -61,17 +80,17 @@ template ParallelGptWeight::~ParallelGptWeight() { if (is_maintain_buffer == true) { - for (int i = 0; i < 5; i++) { + for (int i = 0; i < weights_ptr.size(); i++) { deviceFree(weights_ptr[i]); } - position_encoding_table = nullptr; - pre_decoder_embedding_table = nullptr; - post_decoder_layernorm.beta = nullptr; - post_decoder_layernorm.gamma = nullptr; + position_encoding_table = nullptr; + pre_decoder_embedding_table = nullptr; + post_decoder_layernorm.beta = nullptr; + post_decoder_layernorm.gamma = nullptr; post_decoder_embedding.kernel = nullptr; - post_decoder_embedding.bias = nullptr; - is_maintain_buffer = false; + post_decoder_embedding.bias = nullptr; + is_maintain_buffer = false; } for (int i = 0; i < num_layer_; i++) { @@ -89,7 +108,13 @@ ParallelGptWeight::ParallelGptWeight(const ParallelGptWeight& other): tensor_para_size_(other.tensor_para_size_), tensor_para_rank_(other.tensor_para_rank_), layer_para_size_(other.layer_para_size_), - layer_para_rank_(other.layer_para_rank_) + layer_para_rank_(other.layer_para_rank_), + int8_mode_(other.int8_mode_), + prompt_token_weight_size_(other.prompt_token_weight_size_), + malloc_load_prompt_weights_(other.malloc_load_prompt_weights_), + prompt_learning_type_(other.prompt_learning_type_), + prompt_learning_pair_(other.prompt_learning_pair_), + gpt_variant_params_(other.gpt_variant_params_) { mallocWeights(); cudaD2Dcpy(weights_ptr[0], other.weights_ptr[0], max_seq_len_ * vocab_size_); @@ -97,6 +122,20 @@ ParallelGptWeight::ParallelGptWeight(const ParallelGptWeight& other): cudaD2Dcpy(weights_ptr[2], other.weights_ptr[2], hidden_units_); cudaD2Dcpy(weights_ptr[3], other.weights_ptr[3], hidden_units_); cudaD2Dcpy(weights_ptr[4], other.weights_ptr[4], hidden_units_ * vocab_size_); + + // prompt learning table: malloc weights and set weight ptr + if (malloc_load_prompt_weights_) { + for (auto const& prompt : prompt_learning_pair_) { + std::string task_name = prompt.first; + int task_name_id = prompt.second.first; + int prompt_length = prompt.second.second; + size_t prompt_id = num_base_weights + (size_t)task_name_id; + + // cuda device to device memcpy prompt table weights buffer memory + cudaD2Dcpy(weights_ptr[prompt_id], other.weights_ptr[prompt_id], prompt_length * prompt_token_weight_size_); + } + } + setWeightPtr(); decoder_layer_weights.clear(); @@ -109,15 +148,21 @@ ParallelGptWeight::ParallelGptWeight(const ParallelGptWeight& other): template ParallelGptWeight& ParallelGptWeight::operator=(const ParallelGptWeight& other) { - hidden_units_ = other.hidden_units_; - inter_size_ = other.inter_size_; - num_layer_ = other.num_layer_; - vocab_size_ = other.vocab_size_; - max_seq_len_ = other.max_seq_len_; - tensor_para_size_ = other.tensor_para_size_; - tensor_para_rank_ = other.tensor_para_rank_; - layer_para_size_ = other.layer_para_size_; - layer_para_rank_ = other.layer_para_rank_; + hidden_units_ = other.hidden_units_; + inter_size_ = other.inter_size_; + num_layer_ = other.num_layer_; + vocab_size_ = other.vocab_size_; + max_seq_len_ = other.max_seq_len_; + tensor_para_size_ = other.tensor_para_size_; + tensor_para_rank_ = other.tensor_para_rank_; + layer_para_size_ = other.layer_para_size_; + layer_para_rank_ = other.layer_para_rank_; + int8_mode_ = other.int8_mode_; + prompt_token_weight_size_ = other.prompt_token_weight_size_; + malloc_load_prompt_weights_ = other.malloc_load_prompt_weights_; + prompt_learning_type_ = other.prompt_learning_type_; + prompt_learning_pair_ = other.prompt_learning_pair_; + 
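// Illustrative sketch (not a hunk of this patch): the prompt_learning_pair
// argument whose template parameters are stripped in this copy of the diff.
// From the accesses below it is assumed to be
// std::map<std::string, std::pair<int, int>>, i.e. task_name -> (task_name_id,
// prompt_length); the task names and lengths here are invented.
//
//   std::map<std::string, std::pair<int, int>> prompt_learning_pair{
//       {"sentiment", {0, 10}},  // table kept in weights_ptr[5 + 0], 10 virtual tokens
//       {"squad",     {1, 16}},  // table kept in weights_ptr[5 + 1], 16 virtual tokens
//   };
//
// loadModel() then reads each task's table from
//   <dir>/model.prompt_table.<task_name>.weight.bin             (p/prompt tuning)
//   <dir>/model.prefix_prompt.<task_name>.weight.<tp_rank>.bin  (prefix prompt)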
gpt_variant_params_ = other.gpt_variant_params_; mallocWeights(); cudaD2Dcpy(weights_ptr[0], other.weights_ptr[0], max_seq_len_ * vocab_size_); @@ -125,9 +170,23 @@ ParallelGptWeight& ParallelGptWeight::operator=(const ParallelGptWeight& o cudaD2Dcpy(weights_ptr[2], other.weights_ptr[2], hidden_units_); cudaD2Dcpy(weights_ptr[3], other.weights_ptr[3], hidden_units_); cudaD2Dcpy(weights_ptr[4], other.weights_ptr[4], hidden_units_ * vocab_size_); + + // prompt learning tables: malloc weights and set weight ptr + if (malloc_load_prompt_weights_) { + for (auto const& prompt : prompt_learning_pair_) { + int task_name_id = prompt.second.first; + int prompt_length = prompt.second.second; + size_t prompt_id = num_base_weights + (size_t)task_name_id; + + // cuda device to device memcpy prompt weights buffer memory + cudaD2Dcpy(weights_ptr[prompt_id], other.weights_ptr[prompt_id], prompt_length * prompt_token_weight_size_); + } + } + setWeightPtr(); decoder_layer_weights.clear(); + decoder_layer_weights.reserve(num_layer_); for (int l = 0; l < num_layer_; l++) { decoder_layer_weights.push_back(other.decoder_layer_weights[l]); } @@ -137,22 +196,52 @@ ParallelGptWeight& ParallelGptWeight::operator=(const ParallelGptWeight& o template void ParallelGptWeight::setWeightPtr() { - position_encoding_table = weights_ptr[0]; - pre_decoder_embedding_table = weights_ptr[1]; - post_decoder_layernorm.beta = weights_ptr[2]; - post_decoder_layernorm.gamma = weights_ptr[3]; + prompt_learning_table.resize(prompt_learning_pair_.size()); + + position_encoding_table = weights_ptr[0]; + pre_decoder_embedding_table = weights_ptr[1]; + post_decoder_layernorm.beta = weights_ptr[2]; + post_decoder_layernorm.gamma = weights_ptr[3]; post_decoder_embedding.kernel = weights_ptr[4]; - post_decoder_embedding.bias = nullptr; + post_decoder_embedding.bias = nullptr; + + // prompt learning tables: set weight ptr + if (malloc_load_prompt_weights_) { + for (auto const& prompt : prompt_learning_pair_) { + int task_name_id = prompt.second.first; + int prompt_length = prompt.second.second; + size_t task_weight_id = num_base_weights + (size_t)task_name_id; + + // set weight ptr + prompt_learning_table[task_name_id] = {weights_ptr[task_weight_id], prompt_length}; + } + } } template void ParallelGptWeight::mallocWeights() { + weights_ptr.resize(num_base_weights + prompt_learning_pair_.size()); + deviceMalloc(&weights_ptr[0], max_seq_len_ * vocab_size_); deviceMalloc(&weights_ptr[1], vocab_size_ * hidden_units_); deviceMalloc(&weights_ptr[2], hidden_units_); deviceMalloc(&weights_ptr[3], hidden_units_); deviceMalloc(&weights_ptr[4], hidden_units_ * vocab_size_); + + // prompt learning tables: malloc weights + if (malloc_load_prompt_weights_) { + for (auto const& prompt : prompt_learning_pair_) { + int task_name_id = prompt.second.first; + int prompt_length = prompt.second.second; + size_t task_weight_id = num_base_weights + (size_t)task_name_id; + + // malloc weights + T* prompt_weights_ptr = nullptr; + deviceMalloc(&prompt_weights_ptr, prompt_length * prompt_token_weight_size_); + weights_ptr[task_weight_id] = prompt_weights_ptr; + } + } is_maintain_buffer = true; } @@ -163,11 +252,42 @@ void ParallelGptWeight::loadModel(std::string dir_path) FT_CHECK(is_maintain_buffer == true); loadWeightFromBin(weights_ptr[0], {max_seq_len_, hidden_units_}, dir_path + "/model.wpe.bin", model_file_type); loadWeightFromBin(weights_ptr[1], {vocab_size_ * hidden_units_}, dir_path + "/model.wte.bin", model_file_type); - loadWeightFromBin( - weights_ptr[2], 
{hidden_units_}, dir_path + "/model.final_layernorm.bias.bin", model_file_type); - loadWeightFromBin( - weights_ptr[3], {hidden_units_}, dir_path + "/model.final_layernorm.weight.bin", model_file_type); - loadWeightFromBin(weights_ptr[4], {vocab_size_ * hidden_units_}, dir_path + "/model.wte.bin", model_file_type); + if (gpt_variant_params_.has_post_decoder_layernorm) { + loadWeightFromBin( + weights_ptr[2], {hidden_units_}, dir_path + "/model.final_layernorm.bias.bin", model_file_type); + loadWeightFromBin( + weights_ptr[3], {hidden_units_}, dir_path + "/model.final_layernorm.weight.bin", model_file_type); + } + if (checkIfFileExist(dir_path + "/model.lm_head.weight.bin")) { + loadWeightFromBin( + weights_ptr[4], {vocab_size_ * hidden_units_}, dir_path + "/model.lm_head.weight.bin", model_file_type); + } + else { + loadWeightFromBin( + weights_ptr[4], {vocab_size_ * hidden_units_}, dir_path + "/model.wte.bin", model_file_type); + } + + // prompt table: load weights from bin + if (malloc_load_prompt_weights_) { + for (auto const& prompt : prompt_learning_pair_) { + std::string task_name = prompt.first; + int task_name_id = prompt.second.first; + int prompt_length = prompt.second.second; + size_t task_weight_id = num_base_weights + (size_t)task_name_id; + + std::string prompt_weight_path_name = (prompt_learning_type_ == PromptLearningType::p_prompt_tuning) ? + (dir_path + "/model.prompt_table." + task_name + ".weight.bin") : + (dir_path + "/model.prefix_prompt." + task_name + ".weight." + + std::to_string(tensor_para_rank_) + ".bin"); + + if (prompt_length > 0) { + loadWeightFromBin(weights_ptr[task_weight_id], + {prompt_length * prompt_token_weight_size_}, + prompt_weight_path_name, + model_file_type); + } + } + } for (int l = 0; l < num_layer_; l++) { if (isValidLayerParallelId(l)) { @@ -189,6 +309,7 @@ void ParallelGptWeight::resizeLayer(const int num_layer, const int int8_mode) { int8_mode_ = int8_mode; num_layer_ = num_layer; + decoder_layer_weights.reserve(num_layer_); for (int l = 0; l < num_layer_; l++) { decoder_layer_weights.push_back(new ParallelGptDecoderLayerWeight(int8_mode_)); } diff --git a/src/fastertransformer/models/multi_gpu_gpt/ParallelGptWeight.h b/src/fastertransformer/models/multi_gpu_gpt/ParallelGptWeight.h index eb9d97556..38458b186 100644 --- a/src/fastertransformer/models/multi_gpu_gpt/ParallelGptWeight.h +++ b/src/fastertransformer/models/multi_gpu_gpt/ParallelGptWeight.h @@ -17,8 +17,11 @@ #pragma once #include "src/fastertransformer/kernels/layernorm_kernels.h" + +#include "src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoder.h" #include "src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoderLayerWeight.h" #include "src/fastertransformer/utils/memory_utils.h" +#include "src/fastertransformer/utils/prompt_learning.h" namespace fastertransformer { @@ -26,48 +29,80 @@ template struct ParallelGptWeight { ParallelGptWeight() = default; - ParallelGptWeight(const int hidden_units, - const int inter_size, - const int vocab_size, - const int num_layer, - const int max_seq_len, - const int tensor_para_size, - const int tensor_para_rank, - const int layer_para_size, - const int layer_para_rank, - const int int8_mode = 0); + ParallelGptWeight(const int hidden_units, + const int inter_size, + const int vocab_size, + const int num_layer, + const int max_seq_len, + const int tensor_para_size, + const int tensor_para_rank, + const int layer_para_size, + const int layer_para_rank, + const int int8_mode = 0, + PromptLearningType prompt_learning_type = 
PromptLearningType::no_prompt, + std::map> prompt_learning_pair = {}, + gptVariantParams gpt_variant_params = {}); ~ParallelGptWeight(); ParallelGptWeight(const ParallelGptWeight& other); ParallelGptWeight& operator=(const ParallelGptWeight& other); - void loadModel(std::string dir_path); - void resizeLayer(const int num_layer, const int int8_mode = 0); + void loadModel(std::string dir_path); + void resizeLayer(const int num_layer, const int int8_mode = 0); #ifdef SPARSITY_ENABLED void compress_weights(cublasMMWrapper& cublas_wrapper); #endif std::vector*> decoder_layer_weights; - const T* position_encoding_table = nullptr; - const T* pre_decoder_embedding_table = nullptr; - LayerNormWeight post_decoder_layernorm; - DenseWeight post_decoder_embedding; + const T* position_encoding_table = nullptr; + const T* pre_decoder_embedding_table = nullptr; + LayerNormWeight post_decoder_layernorm; + DenseWeight post_decoder_embedding; + + /* + prompt_learning_pair = vectors of [weight ptr, prompt length] pair + prompt_length is stored here for compatible prompt learning table + prefix_prompt weights store as shape [num_layers, 2, num_heads, perfix_seq_len, size_per_head] + p/prompt tuning weights store as shape [prompt_len, hidden_units] + idx is the task_name_id of the prompt tables + */ + std::vector> prompt_learning_table = {}; + inline size_t getMaxSeqLen() const + { + return max_seq_len_; + } + inline void setMaxSeqLen(size_t max_seq_len) + { + max_seq_len_ = max_seq_len; + } private: void setWeightPtr(); void mallocWeights(); bool isValidLayerParallelId(int l); - int hidden_units_; - int inter_size_; - int vocab_size_; - int num_layer_; - int max_seq_len_; - int tensor_para_size_; - int tensor_para_rank_; - int layer_para_size_; - int layer_para_rank_; - int int8_mode_ = 0; - bool is_maintain_buffer = false; - T* weights_ptr[5]; + size_t hidden_units_; + size_t inter_size_; + size_t vocab_size_; + size_t num_layer_; + size_t max_seq_len_; + size_t tensor_para_size_; + size_t tensor_para_rank_; + size_t layer_para_size_; + size_t layer_para_rank_; + size_t int8_mode_ = 0; + + // gpt variants: e.g. meta opt + gptVariantParams gpt_variant_params_; + + // prompt learning pair (task_name, (task_name_id, prompt_len)) + PromptLearningType prompt_learning_type_; + std::map> prompt_learning_pair_; + bool malloc_load_prompt_weights_ = false; + // each prompt token's weight size + size_t prompt_token_weight_size_ = 0; + + bool is_maintain_buffer = false; + size_t num_base_weights = 5; + std::vector weights_ptr = std::vector(num_base_weights); }; } // namespace fastertransformer diff --git a/src/fastertransformer/models/multi_gpu_gpt/gpt_gemm.cc b/src/fastertransformer/models/multi_gpu_gpt/gpt_gemm.cc index 7e4590e33..149c0668b 100644 --- a/src/fastertransformer/models/multi_gpu_gpt/gpt_gemm.cc +++ b/src/fastertransformer/models/multi_gpu_gpt/gpt_gemm.cc @@ -21,23 +21,23 @@ namespace ft = fastertransformer; int main(int argc, char* argv[]) { - if (argc != 9 && argc != 10) { + if (argc != 9 && argc != 10 && argc != 11) { printf( "[ERROR] gpt_gemm batch_size beam_width max_input_len head_number size_per_head inter_size vocab_size data_type tensor_para_size\n"); printf("e.g. 
./bin/gpt_gemm 8 4 32 96 128 49152 51200 1 8\n"); return 0; } - const int batch_size = atoi(argv[1]); - const int beam_width = atoi(argv[2]); - const int max_input_len = atoi(argv[3]); - const int head_num = atoi(argv[4]); - const int size_per_head = atoi(argv[5]); - const int inter_size = atoi(argv[6]); - const int vocab_size = atoi(argv[7]); - const ft::CublasDataType data_type = static_cast(atoi(argv[8])); // 0 FP32, 1 FP16, 2 BF 16 - const int tensor_para_size = argc <= 9 ? 1 : atoi(argv[9]); - int is_fp16_compute_type = argc <= 10 ? 0 : atoi(argv[10]); + const int batch_size = atoi(argv[1]); + const int beam_width = atoi(argv[2]); + const int max_input_len = atoi(argv[3]); + const int head_num = atoi(argv[4]); + const int size_per_head = atoi(argv[5]); + const int inter_size = atoi(argv[6]); + const int vocab_size = atoi(argv[7]); + const ft::CublasDataType data_type = static_cast(atoi(argv[8])); // 0 FP32, 1 FP16, 2 BF 16 + const int tensor_para_size = argc < 10 ? 1 : atoi(argv[9]); + int is_fp16_compute_type = argc < 11 ? 0 : atoi(argv[10]); if (data_type == ft::BFLOAT16_DATATYPE && is_fp16_compute_type != 0) { printf("[ERROR] BFLOAT16_DATATYPE does not support is_fp16_compute_type = True\n"); return 0; @@ -55,7 +55,7 @@ int main(int argc, char* argv[]) printf(" tensor_para_size: %d \n", tensor_para_size); std::cout << std::endl; - void* gemm_test_buf; + void* gemm_test_buf; size_t buf_size_in_byte = ft::calGptGemmTestBufSizeInByte(batch_size, beam_width, max_input_len, diff --git a/src/fastertransformer/models/swin/CMakeLists.txt b/src/fastertransformer/models/swin/CMakeLists.txt index bba6bd4d8..77990d0ce 100644 --- a/src/fastertransformer/models/swin/CMakeLists.txt +++ b/src/fastertransformer/models/swin/CMakeLists.txt @@ -31,4 +31,4 @@ set_property(TARGET Swin PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) target_link_libraries(Swin PUBLIC -lcudart SwinBasicLayer memory_utils) add_executable(swin_gemm swin_gemm.cc) -target_link_libraries(swin_gemm PUBLIC -lcublas -lcublasLt -lcudart swin_igemm_func swin_gemm_func memory_utils) \ No newline at end of file +target_link_libraries(swin_gemm PUBLIC -lcublas -lcublasLt -lcudart swin_igemm_func swin_gemm_func memory_utils tensor) diff --git a/src/fastertransformer/models/swin/Swin.cc b/src/fastertransformer/models/swin/Swin.cc index c12564c84..790ba70c9 100644 --- a/src/fastertransformer/models/swin/Swin.cc +++ b/src/fastertransformer/models/swin/Swin.cc @@ -37,25 +37,25 @@ SwinTransformer::getBufSize(const int batch, const int patches_resolution, co } template -SwinTransformer::SwinTransformer(int max_batch, - int img_size, - int patch_size, - int in_chans, - int embed_dim, - int window_size, - int* depths, - int* num_heads, - bool ape, - bool patch_norm, - int layer_num, - float mlp_ratio, - cudnnHandle_t cudnn_handle, - cudaStream_t stream, +SwinTransformer::SwinTransformer(int max_batch, + int img_size, + int patch_size, + int in_chans, + int embed_dim, + int window_size, + int* depths, + int* num_heads, + bool ape, + bool patch_norm, + int layer_num, + float mlp_ratio, + cudnnHandle_t cudnn_handle, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool qkv_bias, - float qk_scale): + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool qkv_bias, + float qk_scale): max_batch_(max_batch), img_size_(img_size), patch_size_(patch_size), @@ -78,11 +78,12 @@ SwinTransformer::SwinTransformer(int max_batch, { patches_resolution_ = img_size / patch_size; - 
max_buf_size_ = getBufSize(max_batch_, patches_resolution_, layer_num_, embed_dim_); + max_buf_size_ = getBufSize(max_batch_, patches_resolution_, layer_num_, embed_dim_); basic_layer_ = new SwinTransformerBasicLayer(max_batch_, window_size_, mlp_ratio_, + layernorm_eps_, stream, cublas_wrapper, allocator, @@ -106,7 +107,7 @@ void SwinTransformer::allocateBuffer() { if (is_allocate_buffer_ == false) { - buf_ = reinterpret_cast(allocator_->malloc(max_buf_size_, false)); + buf_ = reinterpret_cast(allocator_->reMalloc(buf_, max_buf_size_, false)); x_patch_embed_ = buf_; basic_layer_output_ = x_patch_embed_ + max_batch_ * embed_dim_ * patches_resolution_ * patches_resolution_; @@ -133,34 +134,40 @@ void SwinTransformer::freeBuffer() printf("[ERROR][SwinTransformer][freeBuffer] allocator_ is NULL!\n"); exit(-1); } - allocator_->free(buf_); - allocator_ = nullptr; - buf_ = nullptr; + allocator_->free((void**)(&buf_)); + allocator_ = nullptr; + buf_ = nullptr; is_allocate_buffer_ = false; } } template -void SwinTransformer::patchEmbed(T* output, - const T* input, - const T* kernel, - const T* bias, - const T* gamma, - const T* beta, - const int batch, - const int img_size, - const int patch_size, - const int patches_resolution, - const int in_chans, - const int embed_dim, +void SwinTransformer::patchEmbed(T* output, + const T* input, + const T* kernel, + const T* bias, + const T* gamma, + const T* beta, + const int batch, + const int img_size, + const int patch_size, + const int patches_resolution, + const int in_chans, + const int embed_dim, const bool patch_norm) { conv2d( output, input, kernel, batch, img_size, img_size, in_chans, embed_dim, patch_size, patch_size, cudnn_handle_); if (patch_norm) { - invokeAddBiasLayernorm( - output, bias, gamma, beta, batch * patches_resolution * patches_resolution, embed_dim, stream_); + invokeAddBiasLayernorm(output, + bias, + gamma, + beta, + layernorm_eps_, + batch * patches_resolution * patches_resolution, + embed_dim, + stream_); } else { invokeAddBias(output, bias, batch * patches_resolution * patches_resolution, embed_dim, stream_); @@ -168,9 +175,9 @@ void SwinTransformer::patchEmbed(T* output, } template -void SwinTransformer::forward(std::vector* output_tensors, +void SwinTransformer::forward(std::vector* output_tensors, const std::vector* input_tensors, - SwinTransformerWeight& swin_weights) + SwinTransformerWeight& swin_weights) { // input_tensors: // input_images [batch, in_channels, input_resolution, input_resolution] @@ -178,10 +185,10 @@ void SwinTransformer::forward(std::vector* output_tensors, // output_tensors: // output_embedding [batch, final_len] - T* output = (T*)output_tensors->at(0).data; - const T* input = (const T*)input_tensors->at(0).data; - const size_t batch = input_tensors->at(0).shape[0]; - const int sm = *(const int*)input_tensors->at(1).data; + T* output = (T*)output_tensors->at(0).data; + const T* input = (const T*)input_tensors->at(0).data; + const size_t batch = input_tensors->at(0).shape[0]; + const int sm = *(const int*)input_tensors->at(1).data; allocateBuffer(); patchEmbed(x_patch_embed_, input, @@ -197,12 +204,12 @@ void SwinTransformer::forward(std::vector* output_tensors, embed_dim_, patch_norm_); - size_t basic_layer_dim = embed_dim_; - size_t basic_layer_input_resolution = patches_resolution_; - int basic_layer_output_size = batch * patches_resolution_ * patches_resolution_ * embed_dim_ / 2; - size_t m = batch * patches_resolution_ * patches_resolution_; - size_t n = embed_dim_; - DataType data_type = 
getTensorType(); + size_t basic_layer_dim = embed_dim_; + size_t basic_layer_input_resolution = patches_resolution_; + int basic_layer_output_size = batch * patches_resolution_ * patches_resolution_ * embed_dim_ / 2; + size_t m = batch * patches_resolution_ * patches_resolution_; + size_t n = embed_dim_; + DataType data_type = getTensorType(); bool do_patch_merge = true; if (layer_num_ == 1) { @@ -219,7 +226,7 @@ void SwinTransformer::forward(std::vector* output_tensors, data_type, std::vector{batch, basic_layer_input_resolution, basic_layer_input_resolution, basic_layer_dim}, basic_layer_output_ + (i % 2) * basic_layer_output_size}}; - int additional_params[4] = {depths_[i], num_heads_[i], do_patch_merge ? 1 : 0, sm}; + int additional_params[4] = {depths_[i], num_heads_[i], do_patch_merge ? 1 : 0, sm}; std::vector tmp_input_tensors{ Tensor{ MEMORY_GPU, @@ -239,6 +246,7 @@ void SwinTransformer::forward(std::vector* output_tensors, basic_layer_output_ + ((layer_num_ - 1) % 2) * basic_layer_output_size, swin_weights.norm_weights.gamma, swin_weights.norm_weights.beta, + layernorm_eps_, batch * basic_layer_input_resolution * basic_layer_input_resolution, basic_layer_dim, stream_); @@ -269,5 +277,8 @@ void SwinTransformer::forward(std::vector* output_tensors, template class SwinTransformer; template class SwinTransformer; +#ifdef ENABLE_BF16 +template class SwinTransformer<__nv_bfloat16>; +#endif } // namespace fastertransformer diff --git a/src/fastertransformer/models/swin/Swin.h b/src/fastertransformer/models/swin/Swin.h index d1221b28e..244a755f9 100644 --- a/src/fastertransformer/models/swin/Swin.h +++ b/src/fastertransformer/models/swin/Swin.h @@ -26,31 +26,32 @@ template class SwinTransformer { private: - int max_batch_ = 1; - int img_size_ = 224; - int patch_size_ = 4; - int in_chans_ = 3; - int embed_dim_ = 96; - int window_size_ = 7; - int* depths_; - int* num_heads_; - bool ape_ = false; - bool patch_norm_ = true; - float mlp_ratio_ = 4.0f; - bool qkv_bias_ = true; - int patches_resolution_ = 56; - int layer_num_ = 4; - int qk_scale_ = 1.0f; - size_t max_buf_size_ = 0; - IAllocator* allocator_ = nullptr; - cudnnHandle_t cudnn_handle_; - cudaStream_t stream_; - cublasMMWrapper* cublas_wrapper_; - bool is_free_buffer_after_forward_; - bool is_allocate_buffer_ = false; + int max_batch_ = 1; + int img_size_ = 224; + int patch_size_ = 4; + int in_chans_ = 3; + int embed_dim_ = 96; + int window_size_ = 7; + int* depths_; + int* num_heads_; + bool ape_ = false; + bool patch_norm_ = true; + float mlp_ratio_ = 4.0f; + bool qkv_bias_ = true; + int patches_resolution_ = 56; + int layer_num_ = 4; + int qk_scale_ = 1.0f; + static constexpr float layernorm_eps_ = 1e-6f; + size_t max_buf_size_ = 0; + IAllocator* allocator_ = nullptr; + cudnnHandle_t cudnn_handle_; + cudaStream_t stream_; + cublasMMWrapper* cublas_wrapper_; + bool is_free_buffer_after_forward_; + bool is_allocate_buffer_ = false; - T* buf_ = nullptr; - T* x_patch_embed_ = nullptr; + T* buf_ = nullptr; + T* x_patch_embed_ = nullptr; T* basic_layer_output_ = nullptr; // for avgPool_ones T* avg_pool_ones_ = nullptr; @@ -62,25 +63,25 @@ class SwinTransformer { void allocateBuffer(); - SwinTransformer(int max_batch, - int img_size, - int patch_size, - int in_chans, - int embed_dim, - int window_size, - int* depths, - int* num_heads, - bool ape, - bool patch_norm, - int layer_num, - float mlp_ratio, - cudnnHandle_t cudnn_handle, - cudaStream_t stream, + SwinTransformer(int max_batch, + int img_size, + int patch_size, + int in_chans, + int 
embed_dim, + int window_size, + int* depths, + int* num_heads, + bool ape, + bool patch_norm, + int layer_num, + float mlp_ratio, + cudnnHandle_t cudnn_handle, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool qkv_bias = true, - float qk_scale = 1.0f); + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool qkv_bias = true, + float qk_scale = 1.0f); void freeBuffer(); @@ -88,23 +89,23 @@ class SwinTransformer { // input is [B, C_in, H, W] // output is [B, H, W, C_out] - void patchEmbed(T* output, - const T* input, - const T* kernel, - const T* bias, - const T* gamma, - const T* beta, - const int batch, - const int img_size, - const int patch_size, - const int patches_resolution, - const int in_chans, - const int embed_dim, + void patchEmbed(T* output, + const T* input, + const T* kernel, + const T* bias, + const T* gamma, + const T* beta, + const int batch, + const int img_size, + const int patch_size, + const int patches_resolution, + const int in_chans, + const int embed_dim, const bool patch_norm); - void forward(std::vector* output_tensors, + void forward(std::vector* output_tensors, const std::vector* input_tensors, - SwinTransformerWeight& swin_weights); + SwinTransformerWeight& swin_weights); }; // class SwinTransformer } // namespace fastertransformer diff --git a/src/fastertransformer/models/swin/SwinBasicLayer.cc b/src/fastertransformer/models/swin/SwinBasicLayer.cc index 433281ce4..34e4b7493 100644 --- a/src/fastertransformer/models/swin/SwinBasicLayer.cc +++ b/src/fastertransformer/models/swin/SwinBasicLayer.cc @@ -22,8 +22,8 @@ template void SwinTransformerBasicLayer::allocateBuffer() { if (is_allocate_buffer_ == false) { - block_output_ = (T*)allocator_->malloc( - 2 * max_batch_ * patches_resolution_ * patches_resolution_ * embed_dim_ * sizeof(T), false); + block_output_ = (T*)allocator_->reMalloc( + block_output_, 2 * max_batch_ * patches_resolution_ * patches_resolution_ * embed_dim_ * sizeof(T), false); is_allocate_buffer_ = true; } @@ -33,31 +33,34 @@ template void SwinTransformerBasicLayer::freeBuffer() { if (is_allocate_buffer_ == true) { - allocator_->free(block_output_); + allocator_->free((void**)(&block_output_)); is_allocate_buffer_ = false; } } template -SwinTransformerBasicLayer::SwinTransformerBasicLayer(int max_batch, - int window_size, - float mlp_ratio, - cudaStream_t stream, +SwinTransformerBasicLayer::SwinTransformerBasicLayer(int max_batch, + int window_size, + float mlp_ratio, + float layernorm_eps, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool qkv_bias, - float qk_scale): + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool qkv_bias, + float qk_scale): BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward), max_batch_(max_batch), window_size_(window_size), mlp_ratio_(mlp_ratio), + layernorm_eps_(layernorm_eps), qkv_bias_(qkv_bias), qk_scale_(qk_scale) { block_ = new SwinTransformerBlock(max_batch_, window_size_, mlp_ratio_, + layernorm_eps_, stream, cublas_wrapper, allocator, @@ -77,18 +80,26 @@ SwinTransformerBasicLayer::~SwinTransformerBasicLayer() } template -void SwinTransformerBasicLayer::patchMerge(T* output, - T* merge_layernorm_buf, +void SwinTransformerBasicLayer::patchMerge(T* output, + T* merge_layernorm_buf, const T* input, const T* gamma, const T* beta, const T* weight, - int batch, - int input_resolution, - int dim) + int batch, + int 
input_resolution, + int dim) { - invokeMergeLayernorm( - merge_layernorm_buf, input, gamma, beta, batch, input_resolution, input_resolution, dim, stream_); + invokeMergeLayernorm(merge_layernorm_buf, + input, + gamma, + beta, + layernorm_eps_, + batch, + input_resolution, + input_resolution, + dim, + stream_); cublas_wrapper_->Gemm(CUBLAS_OP_T, CUBLAS_OP_N, @@ -104,8 +115,8 @@ void SwinTransformerBasicLayer::patchMerge(T* output, } template -void SwinTransformerBasicLayer::forward(std::vector* output_tensors, - std::vector* input_tensors, +void SwinTransformerBasicLayer::forward(std::vector* output_tensors, + std::vector* input_tensors, SwinTransformerBasicLayerWeight& swin_basic_layer_weights) { // input_tensors: @@ -114,29 +125,29 @@ void SwinTransformerBasicLayer::forward(std::vector* output_tensors, // output_tensors: // output [batch, output_resolution, output_resolution, output_dim] - T* from_tensor = (T*)input_tensors->at(0).data; - T* output_tensor = (T*)(output_tensors->at(0).data); - size_t batch = input_tensors->at(0).shape[0]; + T* from_tensor = (T*)input_tensors->at(0).data; + T* output_tensor = (T*)(output_tensors->at(0).data); + size_t batch = input_tensors->at(0).shape[0]; size_t input_resolution = input_tensors->at(0).shape[1]; assert(input_resolution == input_tensors->at(0).shape[2]); - size_t dim = input_tensors->at(0).shape[3]; - int* input_paramters = (int*)input_tensors->at(1).data; - const int depth = input_paramters[0]; - const int num_head = input_paramters[1]; - bool do_patch_merge = (input_paramters[2] == 1) ? true : false; - const int sm = input_paramters[3]; + size_t dim = input_tensors->at(0).shape[3]; + int* input_paramters = (int*)input_tensors->at(1).data; + const int depth = input_paramters[0]; + const int num_head = input_paramters[1]; + bool do_patch_merge = (input_paramters[2] == 1) ? true : false; + const int sm = input_paramters[3]; patches_resolution_ = input_resolution; - embed_dim_ = dim; + embed_dim_ = dim; allocateBuffer(); - int block_output_size = batch * input_resolution * input_resolution * dim; - size_t m = batch * input_resolution * input_resolution; - size_t n = dim; - size_t window_num = (input_resolution / window_size_) * (input_resolution / window_size_); - size_t window_len = window_size_ * window_size_; - int shift_size = 0; - DataType data_type = getTensorType(); + int block_output_size = batch * input_resolution * input_resolution * dim; + size_t m = batch * input_resolution * input_resolution; + size_t n = dim; + size_t window_num = (input_resolution / window_size_) * (input_resolution / window_size_); + size_t window_len = window_size_ * window_size_; + int shift_size = 0; + DataType data_type = getTensorType(); if (do_patch_merge) { for (int i = 0; i < depth; i++) { @@ -145,8 +156,8 @@ void SwinTransformerBasicLayer::forward(std::vector* output_tensors, data_type, std::vector{batch, input_resolution, input_resolution, dim}, block_output_ + (i % 2) * block_output_size}}; - shift_size = (i % 2 == 0) ? 0 : (window_size_ / 2); - int additional_parameters[3] = {num_head, shift_size, sm}; + shift_size = (i % 2 == 0) ? 0 : (window_size_ / 2); + int additional_parameters[3] = {num_head, shift_size, sm}; std::vector tmp_input_tensors{ Tensor{MEMORY_GPU, data_type, @@ -177,8 +188,8 @@ void SwinTransformerBasicLayer::forward(std::vector* output_tensors, data_type, std::vector{batch, input_resolution, input_resolution, dim}, i == depth - 1 ? output_tensor : block_output_ + (i % 2) * block_output_size}}; - shift_size = (i % 2 == 0) ? 
0 : (window_size_ / 2); - int additional_parameters[3] = {num_head, shift_size, sm}; + shift_size = (i % 2 == 0) ? 0 : (window_size_ / 2); + int additional_parameters[3] = {num_head, shift_size, sm}; std::vector tmp_input_tensors{ Tensor{MEMORY_GPU, data_type, @@ -200,4 +211,7 @@ void SwinTransformerBasicLayer::forward(std::vector* output_tensors, template class SwinTransformerBasicLayer; template class SwinTransformerBasicLayer; +#ifdef ENABLE_BF16 +template class SwinTransformerBasicLayer<__nv_bfloat16>; +#endif } // namespace fastertransformer \ No newline at end of file diff --git a/src/fastertransformer/models/swin/SwinBasicLayer.h b/src/fastertransformer/models/swin/SwinBasicLayer.h index b523f5499..0b4b83496 100644 --- a/src/fastertransformer/models/swin/SwinBasicLayer.h +++ b/src/fastertransformer/models/swin/SwinBasicLayer.h @@ -23,16 +23,17 @@ template class SwinTransformerBasicLayer: public BaseLayer { private: - int max_batch_ = 1; - int patches_resolution_ = 64; - int embed_dim_ = 96; - int window_size_ = 7; - float mlp_ratio_ = 4.0f; - bool qkv_bias_ = true; - float qk_scale_ = 1.0f; + int max_batch_ = 1; + int patches_resolution_ = 64; + int embed_dim_ = 96; + int window_size_ = 7; + float mlp_ratio_ = 4.0f; + bool qkv_bias_ = true; + float qk_scale_ = 1.0f; + float layernorm_eps_; - T* buf_ = nullptr; - T *block_output_ = nullptr, *merge_layernorm_buf_ = nullptr; + T* buf_ = nullptr; + T * block_output_ = nullptr, *merge_layernorm_buf_ = nullptr; SwinTransformerBlock* block_ = nullptr; public: @@ -43,15 +44,16 @@ class SwinTransformerBasicLayer: public BaseLayer { } // dim & input_resolution will be used to malloc the max buf size - SwinTransformerBasicLayer(int max_batch, - int window_size, - float mlp_ratio, - cudaStream_t stream, + SwinTransformerBasicLayer(int max_batch, + int window_size, + float mlp_ratio, + float layernorm_eps, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool qkv_bias, - float qk_scale); + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool qkv_bias, + float qk_scale); void allocateBuffer(); @@ -62,18 +64,18 @@ class SwinTransformerBasicLayer: public BaseLayer { // input is [B, H, W, C] // merge_layernorm_buf is [B, H/2, W/2, 4*C] // output is [B, H/2, W/2, 2*C] - void patchMerge(T* output, - T* merge_layernorm_buf, + void patchMerge(T* output, + T* merge_layernorm_buf, const T* input, const T* gamma, const T* beta, const T* weight, - int batch, - int input_resolution, - int dim); + int batch, + int input_resolution, + int dim); - void forward(std::vector* output_tensors, - std::vector* input_tensors, + void forward(std::vector* output_tensors, + std::vector* input_tensors, SwinTransformerBasicLayerWeight& swin_basic_layer_weights); }; // class SwinTransformerBasicLayer diff --git a/src/fastertransformer/models/swin/SwinBlock.cc b/src/fastertransformer/models/swin/SwinBlock.cc index 34998e4c4..a3c3d0378 100644 --- a/src/fastertransformer/models/swin/SwinBlock.cc +++ b/src/fastertransformer/models/swin/SwinBlock.cc @@ -22,15 +22,15 @@ template void SwinTransformerBlock::allocateBuffer() { if (is_allocate_buffer_ == false) { - attention_output_ = - (T*)allocator_->malloc(max_batch_ * window_num_ * window_len_ * embed_dim_ * sizeof(T), false); - normed_attn_out_buf_ = - (T*)allocator_->malloc(max_batch_ * window_num_ * window_len_ * embed_dim_ * sizeof(T), false); - mlp_buf_ = (T*)allocator_->malloc( - max_batch_ * window_num_ * window_len_ * int(embed_dim_ * 
mlp_ratio_) * sizeof(T), false); + attention_output_ = (T*)allocator_->reMalloc( + attention_output_, max_batch_ * window_num_ * window_len_ * embed_dim_ * sizeof(T), false); + normed_attn_out_buf_ = (T*)allocator_->reMalloc( + normed_attn_out_buf_, max_batch_ * window_num_ * window_len_ * embed_dim_ * sizeof(T), false); + mlp_buf_ = (T*)allocator_->reMalloc( + mlp_buf_, max_batch_ * window_num_ * window_len_ * int(embed_dim_ * mlp_ratio_) * sizeof(T), false); normed_shifted_input_ = mlp_buf_; - is_allocate_buffer_ = true; + is_allocate_buffer_ = true; } } @@ -38,32 +38,34 @@ template void SwinTransformerBlock::freeBuffer() { if (is_allocate_buffer_ == true) { - allocator_->free(attention_output_); - allocator_->free(normed_attn_out_buf_); - allocator_->free(mlp_buf_); + allocator_->free((void**)(&attention_output_)); + allocator_->free((void**)(&normed_attn_out_buf_)); + allocator_->free((void**)(&mlp_buf_)); is_allocate_buffer_ = false; } } template -SwinTransformerBlock::SwinTransformerBlock(int max_batch, - int window_size, - float mlp_ratio, - cudaStream_t stream, +SwinTransformerBlock::SwinTransformerBlock(int max_batch, + int window_size, + float mlp_ratio, + float layernorm_eps, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool qkv_bias, - float qk_scale): + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool qkv_bias, + float qk_scale): BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward), max_batch_(max_batch), window_size_(window_size), mlp_ratio_(mlp_ratio), + layernorm_eps_(layernorm_eps), qkv_bias_(qkv_bias), qk_scale_(qk_scale) { window_len_ = window_size_ * window_size_; - atten_ = new WindowAttention(max_batch_, + atten_ = new WindowAttention(max_batch_, window_size_, stream, cublas_wrapper, @@ -83,8 +85,8 @@ SwinTransformerBlock::~SwinTransformerBlock() } template -void SwinTransformerBlock::forward(std::vector* output_tensors, - const std::vector* input_tensors, +void SwinTransformerBlock::forward(std::vector* output_tensors, + const std::vector* input_tensors, SwinTransformerBlockWeight& swin_block_weights) { // input_tensors: @@ -94,28 +96,29 @@ void SwinTransformerBlock::forward(std::vector* output_tensors, // output_tensors: // output [batch, input_resolution, input_resolution, dim] - T* input = (T*)input_tensors->at(0).data; - T* output = (T*)(output_tensors->at(0).data); - const int batch = input_tensors->at(0).shape[0]; + T* input = (T*)input_tensors->at(0).data; + T* output = (T*)(output_tensors->at(0).data); + const int batch = input_tensors->at(0).shape[0]; const int input_resolution = input_tensors->at(0).shape[1]; assert(input_resolution == input_tensors->at(0).shape[2]); - const int dim = input_tensors->at(0).shape[3]; + const int dim = input_tensors->at(0).shape[3]; const int* input_parameters = (const int*)input_tensors->at(2).data; - const int num_head = input_parameters[0]; - int shift_size = input_parameters[1]; - const int sm = input_parameters[2]; + const int num_head = input_parameters[0]; + int shift_size = input_parameters[1]; + const int sm = input_parameters[2]; - shift_size = (input_resolution <= window_size_) ? 0 : shift_size; + shift_size = (input_resolution <= window_size_) ? 
0 : shift_size; int window_num = (input_resolution / window_size_) * (input_resolution / window_size_); - window_num_ = window_num; - embed_dim_ = dim; - int mlp_dim = int(mlp_ratio_ * dim); + window_num_ = window_num; + embed_dim_ = dim; + int mlp_dim = int(mlp_ratio_ * dim); allocateBuffer(); invokeLayernormShiftPartition(normed_shifted_input_, input, swin_block_weights.attn_layernorm_weights.gamma, swin_block_weights.attn_layernorm_weights.beta, + layernorm_eps_, batch, input_resolution, input_resolution, @@ -124,10 +127,10 @@ void SwinTransformerBlock::forward(std::vector* output_tensors, window_size_, stream_); - const size_t m = batch * input_resolution * input_resolution; - const size_t n = dim; - DataType data_type = getTensorType(); - int additional_parameters[6] = {batch, dim, input_resolution, num_head, shift_size, sm}; + const size_t m = batch * input_resolution * input_resolution; + const size_t n = dim; + DataType data_type = getTensorType(); + int additional_parameters[6] = {batch, dim, input_resolution, num_head, shift_size, sm}; std::vector attn_output_tensors{ Tensor{MEMORY_GPU, data_type, std::vector{m, n}, attention_output_}}; std::vector attn_input_tensors{ @@ -147,6 +150,7 @@ void SwinTransformerBlock::forward(std::vector* output_tensors, swin_block_weights.ffn_layernorm_weights.gamma, swin_block_weights.ffn_layernorm_weights.beta, swin_block_weights.attention_weights.attention_output_weight.bias, + layernorm_eps_, batch * input_resolution * input_resolution, dim, stream_); @@ -164,11 +168,11 @@ void SwinTransformerBlock::forward(std::vector* output_tensors, mlp_buf_, mlp_dim); - invokeAddBiasGelu(mlp_buf_, - swin_block_weights.ffn_weights.intermediate_weight.bias, - batch * input_resolution * input_resolution, - mlp_dim, - stream_); + invokeAddBiasGeluV2(mlp_buf_, + swin_block_weights.ffn_weights.intermediate_weight.bias, + batch * input_resolution * input_resolution, + mlp_dim, + stream_); cublas_wrapper_->Gemm(CUBLAS_OP_T, CUBLAS_OP_N, @@ -196,5 +200,8 @@ void SwinTransformerBlock::forward(std::vector* output_tensors, template class SwinTransformerBlock; template class SwinTransformerBlock; +#ifdef ENABLE_BF16 +template class SwinTransformerBlock<__nv_bfloat16>; +#endif } // namespace fastertransformer \ No newline at end of file diff --git a/src/fastertransformer/models/swin/SwinBlock.h b/src/fastertransformer/models/swin/SwinBlock.h index 4344e1a6f..c25bd3850 100644 --- a/src/fastertransformer/models/swin/SwinBlock.h +++ b/src/fastertransformer/models/swin/SwinBlock.h @@ -26,14 +26,15 @@ namespace fastertransformer { template class SwinTransformerBlock: public BaseLayer { private: - int max_batch_ = 1; - int window_size_ = 7; - int window_len_ = 49; - int window_num_ = 64; - int embed_dim_ = 96; - float mlp_ratio_ = 4.0f; - bool qkv_bias_ = true; - float qk_scale_ = 1.0f; + int max_batch_ = 1; + int window_size_ = 7; + int window_len_ = 49; + int window_num_ = 64; + int embed_dim_ = 96; + float mlp_ratio_ = 4.0f; + bool qkv_bias_ = true; + float qk_scale_ = 1.0f; + float layernorm_eps_; T* buf_ = nullptr; @@ -63,20 +64,21 @@ class SwinTransformerBlock: public BaseLayer { void freeBuffer(); - SwinTransformerBlock(int max_batch, - int window_size, - float mlp_ratio, - cudaStream_t stream, + SwinTransformerBlock(int max_batch, + int window_size, + float mlp_ratio, + float layernorm_eps_, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool qkv_bias, - float qk_scale = 1.0f); + IAllocator* 
allocator, + bool is_free_buffer_after_forward, + bool qkv_bias, + float qk_scale = 1.0f); ~SwinTransformerBlock(); - void forward(std::vector* output_tensors, - const std::vector* input_tensors, + void forward(std::vector* output_tensors, + const std::vector* input_tensors, SwinTransformerBlockWeight& swin_block_weights); }; // class SwinTransformerBlock diff --git a/src/fastertransformer/models/swin/SwinWeight.h b/src/fastertransformer/models/swin/SwinWeight.h index 68ff690b1..02df08530 100644 --- a/src/fastertransformer/models/swin/SwinWeight.h +++ b/src/fastertransformer/models/swin/SwinWeight.h @@ -26,27 +26,27 @@ template class SwinTransformerBlockWeight { public: AttentionWeight attention_weights; - FfnWeight ffn_weights; + FfnWeight ffn_weights; LayerNormWeight attn_layernorm_weights; LayerNormWeight ffn_layernorm_weights; - const T* attention_relative_pos_bias = nullptr; + const T* attention_relative_pos_bias = nullptr; }; // SwinTransformerBlockWeight template class SwinTransformerBasicLayerWeight { public: - LayerNormWeight merge_layernorm_weights; - DenseWeight merge_linear_weights; - const T* attn_mask = nullptr; + LayerNormWeight merge_layernorm_weights; + DenseWeight merge_linear_weights; + const T* attn_mask = nullptr; std::vector> block_weight_list; }; // SwinTransformerBasicLayerWeight template class SwinTransformerWeight { public: - DenseWeight patchEmbed_linear_weights; - LayerNormWeight patchEmbed_norm_weights; - LayerNormWeight norm_weights; + DenseWeight patchEmbed_linear_weights; + LayerNormWeight patchEmbed_norm_weights; + LayerNormWeight norm_weights; std::vector> basic_layer_weight_list; }; // class SwinTransformerWeight diff --git a/src/fastertransformer/models/swin/swin_gemm.cc b/src/fastertransformer/models/swin/swin_gemm.cc index 1620ab840..b9ab286a8 100644 --- a/src/fastertransformer/models/swin/swin_gemm.cc +++ b/src/fastertransformer/models/swin/swin_gemm.cc @@ -29,13 +29,13 @@ int main(int argc, char* argv[]) return 0; } - const int batch_img = atoi(argv[1]); - const int image_width = atoi(argv[2]); - const int window_width = atoi(argv[3]); - const int head_num = atoi(argv[4]); - const int size_per_head = atoi(argv[5]); - const ft::CublasDataType data_type = static_cast(atoi(argv[6])); // 0 FP32, 1 FP16, 2 BF 16 - const int is_int8 = atoi(argv[7]); + const int batch_img = atoi(argv[1]); + const int image_width = atoi(argv[2]); + const int window_width = atoi(argv[3]); + const int head_num = atoi(argv[4]); + const int size_per_head = atoi(argv[5]); + const ft::CublasDataType data_type = static_cast(atoi(argv[6])); // 0 FP32, 1 FP16, 2 BF 16 + const int is_int8 = atoi(argv[7]); printf("[INFO] arguments: \n"); printf(" batch_img: %d \n", batch_img); @@ -50,13 +50,13 @@ int main(int argc, char* argv[]) const int patch_width = 4; const int batch_size = batch_img * (image_width / (patch_width * window_width)) * (image_width / (patch_width * window_width)); - const int seq_len = window_width * window_width; + const int seq_len = window_width * window_width; const int inter_size = 4 * head_num * size_per_head; - void* gemm_test_buf; + void* gemm_test_buf; size_t buf_size_in_byte = ft::calGemmTestBufSizeInByte( batch_size, seq_len, head_num, size_per_head, inter_size, 0, 1, ft::FLOAT_DATATYPE); - int batch_tmp = batch_size; + int batch_tmp = batch_size; int head_num_tmp = head_num; for (int i = 1; i < 4; i++) { batch_tmp /= 4; diff --git a/src/fastertransformer/models/swin_int8/CMakeLists.txt b/src/fastertransformer/models/swin_int8/CMakeLists.txt index 
9847d449d..dd8edfd05 100644 --- a/src/fastertransformer/models/swin_int8/CMakeLists.txt +++ b/src/fastertransformer/models/swin_int8/CMakeLists.txt @@ -18,14 +18,14 @@ add_library(SwinBlockINT8 STATIC SwinBlockINT8.cc) set_property(TARGET SwinBlockINT8 PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET SwinBlockINT8 PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) target_link_libraries(SwinBlockINT8 PUBLIC -lcublasLt -lcublas -lcudart - WindowAttentionINT8 activation_int8_kernels add_residual_kernels) + WindowAttentionINT8 activation_int8_kernels add_residual_kernels tensor) add_library(SwinBasicLayerINT8 STATIC SwinBasicLayerINT8.cc) set_property(TARGET SwinBasicLayerINT8 PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET SwinBasicLayerINT8 PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -target_link_libraries(SwinBasicLayerINT8 PUBLIC -lcublasLt -lcublas -lcudart SwinBlockINT8 dequantize_kernels) +target_link_libraries(SwinBasicLayerINT8 PUBLIC -lcublasLt -lcublas -lcudart SwinBlockINT8 dequantize_kernels tensor) add_library(SwinINT8 STATIC SwinINT8.cc) set_property(TARGET SwinINT8 PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET SwinINT8 PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -target_link_libraries(SwinINT8 PUBLIC -lcudart SwinBasicLayerINT8 activation_kernels memory_utils) +target_link_libraries(SwinINT8 PUBLIC -lcudart SwinBasicLayerINT8 activation_kernels memory_utils tensor) diff --git a/src/fastertransformer/models/swin_int8/SwinBasicLayerINT8.cc b/src/fastertransformer/models/swin_int8/SwinBasicLayerINT8.cc index 84288ef45..e656e1a2f 100644 --- a/src/fastertransformer/models/swin_int8/SwinBasicLayerINT8.cc +++ b/src/fastertransformer/models/swin_int8/SwinBasicLayerINT8.cc @@ -22,11 +22,13 @@ template void SwinTransformerINT8BasicLayer::allocateBuffer() { if (is_allocate_buffer_ == false) { - block_output_ = (T*)allocator_->malloc( - 2 * max_batch_ * patches_resolution_ * patches_resolution_ * embed_dim_ * sizeof(T), false); + block_output_ = (T*)allocator_->reMalloc( + block_output_, 2 * max_batch_ * patches_resolution_ * patches_resolution_ * embed_dim_ * sizeof(T), false); - gemm_out_buf_ = (int8_t*)allocator_->malloc( - max_batch_ * patches_resolution_ * patches_resolution_ * embed_dim_ / 2 * sizeof(int32_t), false); + gemm_out_buf_ = (int8_t*)allocator_->reMalloc(gemm_out_buf_, + max_batch_ * patches_resolution_ * patches_resolution_ + * embed_dim_ / 2 * sizeof(int32_t), + false); is_allocate_buffer_ = true; } @@ -36,8 +38,8 @@ template void SwinTransformerINT8BasicLayer::freeBuffer() { if (is_allocate_buffer_ == true) { - allocator_->free(block_output_); - allocator_->free(gemm_out_buf_); + allocator_->free((void**)(&block_output_)); + allocator_->free((void**)(&gemm_out_buf_)); is_allocate_buffer_ = false; } } @@ -46,18 +48,18 @@ void SwinTransformerINT8BasicLayer::freeBuffer() // merge_layernorm_buf is [B, H/2, W/2, 4*C] // output is [B, H/2, W/2, 2*C] template -void SwinTransformerINT8BasicLayer::patchMerge(T* output, - int8_t* gemm_out_buf, - int8_t* merge_layernorm_buf, - const T* input, - const T* gamma, - const T* beta, - const int8_t* weight, - int batch, +void SwinTransformerINT8BasicLayer::patchMerge(T* output, + int8_t* gemm_out_buf, + int8_t* merge_layernorm_buf, + const T* input, + const T* gamma, + const T* beta, + const int8_t* weight, + int batch, const ScaleList* scalePtr, - int input_resolution, - int dim, - int sm) + int input_resolution, + int dim, + int sm) { cublasINT8MMWrapper* cublas_wrapper = (cublasINT8MMWrapper*)cublas_wrapper_; 
invokeMergeLayerNormCol32(merge_layernorm_buf, @@ -112,18 +114,19 @@ void SwinTransformerINT8BasicLayer::patchMerge(T* output, } template -SwinTransformerINT8BasicLayer::SwinTransformerINT8BasicLayer(int int8_mode, - int max_batch, - int window_size, - int patches_resolution, - int embed_dim, - float mlp_ratio, - bool qkv_bias, - float qk_scale, - cudaStream_t stream, +SwinTransformerINT8BasicLayer::SwinTransformerINT8BasicLayer(int int8_mode, + int max_batch, + int window_size, + int patches_resolution, + int embed_dim, + float mlp_ratio, + float layernorm_eps, + bool qkv_bias, + float qk_scale, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward): + IAllocator* allocator, + bool is_free_buffer_after_forward): BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward), int8_mode(int8_mode), max_batch_(max_batch), @@ -131,6 +134,7 @@ SwinTransformerINT8BasicLayer::SwinTransformerINT8BasicLayer(int int8_mode, embed_dim_(embed_dim), window_size_(window_size), mlp_ratio_(mlp_ratio), + layernorm_eps_(layernorm_eps), qkv_bias_(qkv_bias), qk_scale_(qk_scale) { @@ -140,6 +144,7 @@ SwinTransformerINT8BasicLayer::SwinTransformerINT8BasicLayer(int int8_mode, patches_resolution_, embed_dim_, mlp_ratio_, + layernorm_eps_, qkv_bias_, stream, cublas_wrapper, @@ -158,8 +163,8 @@ SwinTransformerINT8BasicLayer::~SwinTransformerINT8BasicLayer() } template -void SwinTransformerINT8BasicLayer::forward(std::vector* output_tensors, - std::vector* input_tensors, +void SwinTransformerINT8BasicLayer::forward(std::vector* output_tensors, + std::vector* input_tensors, SwinTransformerINT8BasicLayerWeight& swin_basic_layer_weights) { // input_tensors: @@ -168,26 +173,26 @@ void SwinTransformerINT8BasicLayer::forward(std::vector* output_tenso // output_tensors: // output [batch, output_resolution, output_resolution, output_dim] - T* from_tensor = (T*)input_tensors->at(0).data; - T* out_tensor = (T*)(output_tensors->at(0).data); - size_t batch = input_tensors->at(0).shape[0]; + T* from_tensor = (T*)input_tensors->at(0).data; + T* out_tensor = (T*)(output_tensors->at(0).data); + size_t batch = input_tensors->at(0).shape[0]; size_t input_resolution = input_tensors->at(0).shape[1]; assert(input_resolution == input_tensors->at(0).shape[2]); - size_t dim = input_tensors->at(0).shape[3]; - int* input_paramters = (int*)input_tensors->at(1).data; - const int depth = input_paramters[0]; - const int num_head = input_paramters[1]; - bool do_patch_merge = (input_paramters[2] == 1) ? true : false; - const int sm = input_paramters[3]; + size_t dim = input_tensors->at(0).shape[3]; + int* input_paramters = (int*)input_tensors->at(1).data; + const int depth = input_paramters[0]; + const int num_head = input_paramters[1]; + bool do_patch_merge = (input_paramters[2] == 1) ? 
true : false; + const int sm = input_paramters[3]; allocateBuffer(); - int block_output_size = batch * input_resolution * input_resolution * dim; - size_t m = batch * input_resolution * input_resolution; - size_t n = dim; - size_t window_num = (input_resolution / window_size_) * (input_resolution / window_size_); - size_t window_len = window_size_ * window_size_; - int shift_size = 0; - DataType data_type = getTensorType(); + int block_output_size = batch * input_resolution * input_resolution * dim; + size_t m = batch * input_resolution * input_resolution; + size_t n = dim; + size_t window_num = (input_resolution / window_size_) * (input_resolution / window_size_); + size_t window_len = window_size_ * window_size_; + int shift_size = 0; + DataType data_type = getTensorType(); if (do_patch_merge) { for (int i = 0; i < depth; i++) { @@ -196,8 +201,8 @@ void SwinTransformerINT8BasicLayer::forward(std::vector* output_tenso data_type, std::vector{batch, input_resolution, input_resolution, dim}, block_output_ + (i % 2) * block_output_size}}; - shift_size = (i % 2 == 0) ? 0 : (window_size_ / 2); - int additional_parameters[3] = {num_head, shift_size, sm}; + shift_size = (i % 2 == 0) ? 0 : (window_size_ / 2); + int additional_parameters[3] = {num_head, shift_size, sm}; std::vector tmp_input_tensors{ Tensor{MEMORY_GPU, data_type, @@ -233,8 +238,8 @@ void SwinTransformerINT8BasicLayer::forward(std::vector* output_tenso data_type, std::vector{batch, input_resolution, input_resolution, dim}, i == depth - 1 ? out_tensor : block_output_ + (i % 2) * block_output_size}}; - shift_size = (i % 2 == 0) ? 0 : (window_size_ / 2); - int additional_parameters[3] = {num_head, shift_size, sm}; + shift_size = (i % 2 == 0) ? 0 : (window_size_ / 2); + int additional_parameters[3] = {num_head, shift_size, sm}; std::vector tmp_input_tensors{ Tensor{MEMORY_GPU, data_type, diff --git a/src/fastertransformer/models/swin_int8/SwinBasicLayerINT8.h b/src/fastertransformer/models/swin_int8/SwinBasicLayerINT8.h index fc6c197ae..4e9f67b31 100644 --- a/src/fastertransformer/models/swin_int8/SwinBasicLayerINT8.h +++ b/src/fastertransformer/models/swin_int8/SwinBasicLayerINT8.h @@ -22,20 +22,21 @@ namespace fastertransformer { template class SwinTransformerINT8BasicLayer: public BaseLayer { private: - int int8_mode = 0; - int max_batch_ = 1; - int patches_resolution_ = 64; - int embed_dim_ = 96; - int window_size_ = 7; - float mlp_ratio_ = 4.0f; - bool qkv_bias_ = true; - float qk_scale_ = 1.0f; + int int8_mode = 0; + int max_batch_ = 1; + int patches_resolution_ = 64; + int embed_dim_ = 96; + int window_size_ = 7; + float mlp_ratio_ = 4.0f; + bool qkv_bias_ = true; + float qk_scale_ = 1.0f; + float layernorm_eps_; const size_t max_buf_size_ = 0; // T* buf_ = nullptr; - T* block_output_ = nullptr; - int8_t* gemm_out_buf_ = nullptr; - SwinTransformerINT8Block* block_ = nullptr; + T* block_output_ = nullptr; + int8_t* gemm_out_buf_ = nullptr; + SwinTransformerINT8Block* block_ = nullptr; void allocateBuffer(); @@ -44,38 +45,39 @@ class SwinTransformerINT8BasicLayer: public BaseLayer { // input is [B, H, W, C] // merge_layernorm_buf is [B, H/2, W/2, 4*C] // output is [B, H/2, W/2, 2*C] - void patchMerge(T* output, - int8_t* gemm_out_buf, - int8_t* merge_layernorm_buf, - const T* input, - const T* gamma, - const T* beta, - const int8_t* weight, - int batch, + void patchMerge(T* output, + int8_t* gemm_out_buf, + int8_t* merge_layernorm_buf, + const T* input, + const T* gamma, + const T* beta, + const int8_t* weight, + int batch, const 
ScaleList* scalePtr, - int input_resolution, - int dim, - int sm); + int input_resolution, + int dim, + int sm); public: // dim & input_resolution will be used to malloc the max buf size - SwinTransformerINT8BasicLayer(int int8_mode, - int max_batch, - int window_size, - int patches_resolution, - int embed_dim, - float mlp_ratio, - bool qkv_bias, - float qk_scale, - cudaStream_t stream, + SwinTransformerINT8BasicLayer(int int8_mode, + int max_batch, + int window_size, + int patches_resolution, + int embed_dim, + float mlp_ratio, + float layernorm_eps, + bool qkv_bias, + float qk_scale, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward); + IAllocator* allocator, + bool is_free_buffer_after_forward); ~SwinTransformerINT8BasicLayer(); - void forward(std::vector* output_tensors, - std::vector* input_tensors, + void forward(std::vector* output_tensors, + std::vector* input_tensors, SwinTransformerINT8BasicLayerWeight& swin_basic_layer_weights); }; // class SwinTransformerINT8BasicLayer diff --git a/src/fastertransformer/models/swin_int8/SwinBlockINT8.cc b/src/fastertransformer/models/swin_int8/SwinBlockINT8.cc index 1914a2e52..c36ca78bc 100644 --- a/src/fastertransformer/models/swin_int8/SwinBlockINT8.cc +++ b/src/fastertransformer/models/swin_int8/SwinBlockINT8.cc @@ -22,18 +22,21 @@ template void SwinTransformerINT8Block::allocateBuffer() { if (is_allocate_buffer_ == false) { - attention_output_ = (int8_t*)allocator_->malloc( - max_batch_ * patches_resolution_ * patches_resolution_ * embed_dim_ * sizeof(int8_t), false); - skip_buf_ = (int8_t*)allocator_->malloc( - max_batch_ * patches_resolution_ * patches_resolution_ * embed_dim_ * sizeof(int8_t), false); - mlp_buf_ = (int8_t*)allocator_->malloc(max_batch_ * patches_resolution_ * patches_resolution_ - * int(embed_dim_ * mlp_ratio_) * sizeof(int8_t), - false); - mlp_output_ = (int8_t*)allocator_->malloc( - max_batch_ * patches_resolution_ * patches_resolution_ * embed_dim_ * sizeof(int32_t), false); + attention_output_ = (int8_t*)allocator_->reMalloc(attention_output_, + max_batch_ * patches_resolution_ * patches_resolution_ + * embed_dim_ * sizeof(int8_t), + false); + skip_buf_ = (int8_t*)allocator_->reMalloc( + skip_buf_, max_batch_ * patches_resolution_ * patches_resolution_ * embed_dim_ * sizeof(int8_t), false); + mlp_buf_ = (int8_t*)allocator_->reMalloc(mlp_buf_, + max_batch_ * patches_resolution_ * patches_resolution_ + * int(embed_dim_ * mlp_ratio_) * sizeof(int8_t), + false); + mlp_output_ = (int8_t*)allocator_->reMalloc( + mlp_output_, max_batch_ * patches_resolution_ * patches_resolution_ * embed_dim_ * sizeof(int32_t), false); normed_shifted_input_ = mlp_buf_; - is_allocate_buffer_ = true; + is_allocate_buffer_ = true; } } @@ -41,27 +44,28 @@ template void SwinTransformerINT8Block::freeBuffer() { if (is_allocate_buffer_ == true) { - allocator_->free(attention_output_); - allocator_->free(skip_buf_); - allocator_->free(mlp_buf_); - allocator_->free(mlp_output_); + allocator_->free((void**)(&attention_output_)); + allocator_->free((void**)(&skip_buf_)); + allocator_->free((void**)(&mlp_buf_)); + allocator_->free((void**)(&mlp_output_)); is_allocate_buffer_ = false; } } template -SwinTransformerINT8Block::SwinTransformerINT8Block(int int8_mode, - int max_batch, - int window_size, - int patches_resolution, - int embed_dim, - float mlp_ratio, - bool qkv_bias, - cudaStream_t stream, +SwinTransformerINT8Block::SwinTransformerINT8Block(int int8_mode, + int max_batch, + int 
window_size, + int patches_resolution, + int embed_dim, + float mlp_ratio, + float layernorm_eps, + bool qkv_bias, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - float qk_scale): + IAllocator* allocator, + bool is_free_buffer_after_forward, + float qk_scale): BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward), int8_mode(int8_mode), max_batch_(max_batch), @@ -69,11 +73,12 @@ SwinTransformerINT8Block::SwinTransformerINT8Block(int int8_mode, patches_resolution_(patches_resolution), embed_dim_(embed_dim), mlp_ratio_(mlp_ratio), + layernorm_eps_(layernorm_eps), qkv_bias_(qkv_bias), qk_scale_(qk_scale) { window_len_ = window_size_ * window_size_; - atten_ = new WindowAttentionINT8(max_batch_, + atten_ = new WindowAttentionINT8(max_batch_, window_size_, patches_resolution, embed_dim, @@ -95,8 +100,8 @@ SwinTransformerINT8Block::~SwinTransformerINT8Block() } template -void SwinTransformerINT8Block::forward(std::vector* output_tensors, - const std::vector* input_tensors, +void SwinTransformerINT8Block::forward(std::vector* output_tensors, + const std::vector* input_tensors, SwinTransformerINT8BlockWeight& swin_block_weights) { // input_tensors: @@ -106,20 +111,20 @@ void SwinTransformerINT8Block::forward(std::vector* output_tensors, // output_tensors: // output [batch, input_resolution, input_resolution, dim] - cublasINT8MMWrapper* cublas_wrapper = (cublasINT8MMWrapper*)cublas_wrapper_; - T* from_tensor = (T*)input_tensors->at(0).data; - T* out_tensor = (T*)(output_tensors->at(0).data); - const int batch = input_tensors->at(0).shape[0]; - const int input_resolution = input_tensors->at(0).shape[1]; - const int dim = input_tensors->at(0).shape[3]; - const int* input_parameters = (const int*)input_tensors->at(2).data; - const int num_head = input_parameters[0]; - int shift_size = input_parameters[1]; - const int sm = input_parameters[2]; - - shift_size = (input_resolution <= window_size_) ? 0 : shift_size; + cublasINT8MMWrapper* cublas_wrapper = (cublasINT8MMWrapper*)cublas_wrapper_; + T* from_tensor = (T*)input_tensors->at(0).data; + T* out_tensor = (T*)(output_tensors->at(0).data); + const int batch = input_tensors->at(0).shape[0]; + const int input_resolution = input_tensors->at(0).shape[1]; + const int dim = input_tensors->at(0).shape[3]; + const int* input_parameters = (const int*)input_tensors->at(2).data; + const int num_head = input_parameters[0]; + int shift_size = input_parameters[1]; + const int sm = input_parameters[2]; + + shift_size = (input_resolution <= window_size_) ? 
0 : shift_size; size_t window_num = (input_resolution / window_size_) * (input_resolution / window_size_); - int mlp_dim = int(mlp_ratio_ * dim); + int mlp_dim = int(mlp_ratio_ * dim); allocateBuffer(); const size_t m = batch * input_resolution * input_resolution; @@ -142,7 +147,7 @@ void SwinTransformerINT8Block::forward(std::vector* output_tensors, stream_); sync_check_cuda_error(); - int additional_parameters[6] = {batch, dim, input_resolution, num_head, shift_size, sm}; + int additional_parameters[6] = {batch, dim, input_resolution, num_head, shift_size, sm}; std::vector attn_output_tensors{ Tensor{MEMORY_GPU, TYPE_INT8, std: vector{m, n}, attention_output_}}; std::vector int8_input_tensors{ diff --git a/src/fastertransformer/models/swin_int8/SwinBlockINT8.h b/src/fastertransformer/models/swin_int8/SwinBlockINT8.h index f4e477025..f7b1ce80f 100644 --- a/src/fastertransformer/models/swin_int8/SwinBlockINT8.h +++ b/src/fastertransformer/models/swin_int8/SwinBlockINT8.h @@ -27,15 +27,16 @@ namespace fastertransformer { template class SwinTransformerINT8Block: public BaseLayer { private: - int int8_mode = 0; - int max_batch_ = 1; - int window_size_ = 7; - int window_len_ = 49; - int patches_resolution_ = 56; - int embed_dim_ = 96; - float mlp_ratio_ = 4.0f; - bool qkv_bias_ = true; - float qk_scale_ = 1.0f; + int int8_mode = 0; + int max_batch_ = 1; + int window_size_ = 7; + int window_len_ = 49; + int patches_resolution_ = 56; + int embed_dim_ = 96; + float mlp_ratio_ = 4.0f; + float layernorm_eps_; + bool qkv_bias_ = true; + float qk_scale_ = 1.0f; size_t max_buf_size_ = 0; int8_t* buf_ = nullptr; @@ -49,23 +50,24 @@ class SwinTransformerINT8Block: public BaseLayer { void freeBuffer(); public: - SwinTransformerINT8Block(int int8_mode, - int max_batch, - int window_size, - int patches_resolution, - int embed_dim, - float mlp_ratio, - bool qkv_bias, - cudaStream_t stream, + SwinTransformerINT8Block(int int8_mode, + int max_batch, + int window_size, + int patches_resolution, + int embed_dim, + float mlp_ratio, + float layernorm_eps_, + bool qkv_bias, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - float qk_scale = 1.0f); + IAllocator* allocator, + bool is_free_buffer_after_forward, + float qk_scale = 1.0f); ~SwinTransformerINT8Block(); - void forward(std::vector* output_tensors, - const std::vector* input_tensors, + void forward(std::vector* output_tensors, + const std::vector* input_tensors, SwinTransformerINT8BlockWeight& swin_block_weights); }; // class SwinTransformerINT8Block diff --git a/src/fastertransformer/models/swin_int8/SwinINT8.cc b/src/fastertransformer/models/swin_int8/SwinINT8.cc index 056996ede..c32368a7c 100644 --- a/src/fastertransformer/models/swin_int8/SwinINT8.cc +++ b/src/fastertransformer/models/swin_int8/SwinINT8.cc @@ -45,7 +45,7 @@ template void SwinTransformerINT8::allocateBuffer() { if (is_allocate_buffer_ == false) { - buf_ = (T*)(allocator_->malloc(max_buf_size_, false)); + buf_ = (T*)(allocator_->reMalloc(buf_, max_buf_size_, false)); x_patch_embed_ = buf_; @@ -71,7 +71,7 @@ template void SwinTransformerINT8::freeBuffer() { if (is_allocate_buffer_ == true) { - allocator_->free(buf_); + allocator_->free((void**)(&buf_)); is_allocate_buffer_ = false; } } @@ -79,26 +79,32 @@ void SwinTransformerINT8::freeBuffer() // input is [B, C_in, H, W] // output is [B, H, W, C_out] template -void SwinTransformerINT8::patchEmbed(T* output, - const T* input, - const T* kernel, - const T* bias, - const T* 
gamma, - const T* beta, - const int batch, - const int img_size, - const int patch_size, - const int patches_resolution, - const int in_chans, - const int embed_dim, +void SwinTransformerINT8::patchEmbed(T* output, + const T* input, + const T* kernel, + const T* bias, + const T* gamma, + const T* beta, + const int batch, + const int img_size, + const int patch_size, + const int patches_resolution, + const int in_chans, + const int embed_dim, const bool patch_norm) { conv2d( output, input, kernel, batch, img_size, img_size, in_chans, embed_dim, patch_size, patch_size, cudnn_handle_); if (patch_norm) { - invokeAddBiasLayernorm( - output, bias, gamma, beta, batch * patches_resolution * patches_resolution, embed_dim, stream_); + invokeAddBiasLayernorm(output, + bias, + gamma, + beta, + layernorm_eps_, + batch * patches_resolution * patches_resolution, + embed_dim, + stream_); } else { invokeAddBias(output, bias, batch * patches_resolution * patches_resolution, embed_dim, stream_); @@ -106,26 +112,26 @@ void SwinTransformerINT8::patchEmbed(T* output, } template -SwinTransformerINT8::SwinTransformerINT8(int int8_mode, - int max_batch, - int img_size, - int patch_size, - int in_chans, - int embed_dim, - int window_size, - int* depths, - int* num_heads, - bool ape, - bool patch_norm, - int layer_num, - float mlp_ratio, - cudnnHandle_t cudnn_handle, - cudaStream_t stream, +SwinTransformerINT8::SwinTransformerINT8(int int8_mode, + int max_batch, + int img_size, + int patch_size, + int in_chans, + int embed_dim, + int window_size, + int* depths, + int* num_heads, + bool ape, + bool patch_norm, + int layer_num, + float mlp_ratio, + cudnnHandle_t cudnn_handle, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool qkv_bias, - float qk_scale): + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool qkv_bias, + float qk_scale): int8_mode(int8_mode), max_batch_(max_batch), img_size_(img_size), @@ -149,7 +155,7 @@ SwinTransformerINT8::SwinTransformerINT8(int int8_mode, { patches_resolution_ = img_size / patch_size; - max_buf_size_ = getBufSize(max_batch_, patches_resolution_, layer_num_, embed_dim_); + max_buf_size_ = getBufSize(max_batch_, patches_resolution_, layer_num_, embed_dim_); basic_layer_ = new SwinTransformerINT8BasicLayer(int8_mode, max_batch_, @@ -157,6 +163,7 @@ SwinTransformerINT8::SwinTransformerINT8(int int8_mode, patches_resolution_, embed_dim_, mlp_ratio_, + layernorm_eps_, qkv_bias_, qk_scale_, stream_, @@ -175,8 +182,8 @@ SwinTransformerINT8::~SwinTransformerINT8() } template -void SwinTransformerINT8::forward(std::vector* output_tensors, - const std::vector* input_tensors, +void SwinTransformerINT8::forward(std::vector* output_tensors, + const std::vector* input_tensors, SwinTransformerINT8Weight& swin_weights) { // input_tensors: @@ -185,10 +192,10 @@ void SwinTransformerINT8::forward(std::vector* output_tensors, // output_tensors: // output_embedding [batch, final_len] - T* from_tensor = (T*)input_tensors->at(0).data; - T* output = (T*)(output_tensors->at(0).data); - const size_t batch = input_tensors->at(0).shape[0]; - const int sm = *(const int*)input_tensors->at(1).data; + T* from_tensor = (T*)input_tensors->at(0).data; + T* output = (T*)(output_tensors->at(0).data); + const size_t batch = input_tensors->at(0).shape[0]; + const int sm = *(const int*)input_tensors->at(1).data; allocateBuffer(); patchEmbed(x_patch_embed_, from_tensor, @@ -204,12 +211,12 @@ void 
SwinTransformerINT8::forward(std::vector* output_tensors, embed_dim_, patch_norm_); - size_t basic_layer_dim = embed_dim_; - size_t basic_layer_input_resolution = patches_resolution_; - int basic_layer_output_size = batch * patches_resolution_ * patches_resolution_ * embed_dim_ / 2; - size_t m = batch * patches_resolution_ * patches_resolution_; - size_t n = embed_dim_; - DataType data_type = getTensorType(); + size_t basic_layer_dim = embed_dim_; + size_t basic_layer_input_resolution = patches_resolution_; + int basic_layer_output_size = batch * patches_resolution_ * patches_resolution_ * embed_dim_ / 2; + size_t m = batch * patches_resolution_ * patches_resolution_; + size_t n = embed_dim_; + DataType data_type = getTensorType(); bool do_patch_merge = true; if (layer_num_ == 1) { @@ -229,7 +236,7 @@ void SwinTransformerINT8::forward(std::vector* output_tensors, data_type, std::vector{batch, basic_layer_input_resolution, basic_layer_input_resolution, basic_layer_dim}, basic_layer_output_ + (i % 2) * basic_layer_output_size}}; - int additional_parameters[4] = {depths_[i], num_heads_[i], do_patch_merge ? 1 : 0, sm}; + int additional_parameters[4] = {depths_[i], num_heads_[i], do_patch_merge ? 1 : 0, sm}; std::vector tmp_input_tensors{ Tensor{ MEMORY_GPU, @@ -254,6 +261,7 @@ void SwinTransformerINT8::forward(std::vector* output_tensors, buffer_COL32, swin_weights.norm_weights.gamma, swin_weights.norm_weights.beta, + layernorm_eps_, batch * basic_layer_input_resolution * basic_layer_input_resolution, basic_layer_dim, stream_); diff --git a/src/fastertransformer/models/swin_int8/SwinINT8.h b/src/fastertransformer/models/swin_int8/SwinINT8.h index de0c52a55..a9e41d937 100644 --- a/src/fastertransformer/models/swin_int8/SwinINT8.h +++ b/src/fastertransformer/models/swin_int8/SwinINT8.h @@ -26,39 +26,40 @@ namespace fastertransformer { template class SwinTransformerINT8 { private: - int int8_mode = 0; - int max_batch_ = 1; - int img_size_ = 224; - int patch_size_ = 4; - int in_chans_ = 3; - int embed_dim_ = 96; - int window_size_ = 7; - int* depths_; - int* num_heads_; - bool ape_ = false; - bool patch_norm_ = true; - float mlp_ratio_ = 4.0f; - bool qkv_bias_ = true; - int patches_resolution_ = 56; - int layer_num_ = 4; - int qk_scale_ = 1.0f; - size_t max_buf_size_ = 0; - size_t max_basic_layer_buf_size_ = 0; - size_t max_block_buf_size_ = 0; - size_t max_window_attention_buf_size_ = 0; - IAllocator* allocator_ = nullptr; - cudnnHandle_t cudnn_handle_; - cudaStream_t stream_; - cublasMMWrapper* cublas_wrapper_; - bool is_free_buffer_after_forward_; - bool is_allocate_buffer_ = false; + int int8_mode = 0; + int max_batch_ = 1; + int img_size_ = 224; + int patch_size_ = 4; + int in_chans_ = 3; + int embed_dim_ = 96; + int window_size_ = 7; + int* depths_; + int* num_heads_; + bool ape_ = false; + bool patch_norm_ = true; + constexpr static float layernorm_eps_ = 1e-6f; + float mlp_ratio_ = 4.0f; + bool qkv_bias_ = true; + int patches_resolution_ = 56; + int layer_num_ = 4; + int qk_scale_ = 1.0f; + size_t max_buf_size_ = 0; + size_t max_basic_layer_buf_size_ = 0; + size_t max_block_buf_size_ = 0; + size_t max_window_attention_buf_size_ = 0; + IAllocator* allocator_ = nullptr; + cudnnHandle_t cudnn_handle_; + cudaStream_t stream_; + cublasMMWrapper* cublas_wrapper_; + bool is_free_buffer_after_forward_; + bool is_allocate_buffer_ = false; - T* buf_ = nullptr; - T* x_patch_embed_ = nullptr; + T* buf_ = nullptr; + T* x_patch_embed_ = nullptr; T* basic_layer_output_ = nullptr; // for avgPool_ones T* 
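
// A minimal CPU sketch of the layer norm these hunks parameterize: the new layernorm_eps_
// (a constexpr 1e-6f in SwinINT8.h) is the epsilon added under the square root before the
// scale/shift. The function below is a reference implementation for illustration only, not
// the invokeAddBiasLayernorm / invokeGeneralLayerNorm CUDA kernels themselves.
#include <cmath>
#include <vector>

void layer_norm_ref(const std::vector<float>& x, const std::vector<float>& gamma,
                    const std::vector<float>& beta, float eps, int rows, int cols,
                    std::vector<float>& y)
{
    for (int r = 0; r < rows; ++r) {
        float mean = 0.f, var = 0.f;
        for (int c = 0; c < cols; ++c) {
            mean += x[r * cols + c];
        }
        mean /= cols;
        for (int c = 0; c < cols; ++c) {
            const float d = x[r * cols + c] - mean;
            var += d * d;
        }
        var /= cols;
        const float inv_std = 1.f / std::sqrt(var + eps);  // eps plays the role of layernorm_eps_
        for (int c = 0; c < cols; ++c) {
            y[r * cols + c] = (x[r * cols + c] - mean) * inv_std * gamma[c] + beta[c];
        }
    }
}
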
avg_pool_ones_ = nullptr; - T* buffer_COL32 = nullptr; + T* buffer_COL32 = nullptr; SwinTransformerINT8BasicLayer* basic_layer_ = nullptr; @@ -70,46 +71,46 @@ class SwinTransformerINT8 { // input is [B, C_in, H, W] // output is [B, H, W, C_out] - void patchEmbed(T* output, - const T* input, - const T* kernel, - const T* bias, - const T* gamma, - const T* beta, - const int batch, - const int img_size, - const int patch_size, - const int patches_resolution, - const int in_chans, - const int embed_dim, + void patchEmbed(T* output, + const T* input, + const T* kernel, + const T* bias, + const T* gamma, + const T* beta, + const int batch, + const int img_size, + const int patch_size, + const int patches_resolution, + const int in_chans, + const int embed_dim, const bool patch_norm); public: - SwinTransformerINT8(int int8_mode, - int max_batch, - int img_size, - int patch_size, - int in_chans, - int embed_dim, - int window_size, - int* depths, - int* num_heads, - bool ape, - bool patch_norm, - int layer_num, - float mlp_ratio, - cudnnHandle_t cudnn_handle, - cudaStream_t stream, + SwinTransformerINT8(int int8_mode, + int max_batch, + int img_size, + int patch_size, + int in_chans, + int embed_dim, + int window_size, + int* depths, + int* num_heads, + bool ape, + bool patch_norm, + int layer_num, + float mlp_ratio, + cudnnHandle_t cudnn_handle, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool qkv_bias = true, - float qk_scale = 1.0f); + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool qkv_bias = true, + float qk_scale = 1.0f); ~SwinTransformerINT8(); - void forward(std::vector* output_tensors, - const std::vector* input_tensors, + void forward(std::vector* output_tensors, + const std::vector* input_tensors, SwinTransformerINT8Weight& swin_weights); }; // class SwinTransformerINT8 diff --git a/src/fastertransformer/models/swin_int8/SwinINT8Weight.h b/src/fastertransformer/models/swin_int8/SwinINT8Weight.h index 4429c7fe6..989cb4062 100644 --- a/src/fastertransformer/models/swin_int8/SwinINT8Weight.h +++ b/src/fastertransformer/models/swin_int8/SwinINT8Weight.h @@ -29,28 +29,28 @@ template class SwinTransformerINT8BlockWeight { public: AttentionINT8Weight attention_weights; - FfnINT8Weight ffn_weights; - LayerNormWeight attn_layernorm_weights; - LayerNormWeight ffn_layernorm_weights; - const T* attention_relative_pos_bias = nullptr; - ScaleList scalelist; + FfnINT8Weight ffn_weights; + LayerNormWeight attn_layernorm_weights; + LayerNormWeight ffn_layernorm_weights; + const T* attention_relative_pos_bias = nullptr; + ScaleList scalelist; }; // SwinTransformerINT8BlockWeight template class SwinTransformerINT8BasicLayerWeight { public: - LayerNormWeight merge_layernorm_weights; - DenseWeight merge_linear_weights; - const T* attn_mask = nullptr; + LayerNormWeight merge_layernorm_weights; + DenseWeight merge_linear_weights; + const T* attn_mask = nullptr; vector> block_weight_list; }; // SwinTransformerINT8BasicLayerWeight template class SwinTransformerINT8Weight { public: - DenseWeight patchEmbed_linear_weights; - LayerNormWeight patchEmbed_norm_weights; - LayerNormWeight norm_weights; + DenseWeight patchEmbed_linear_weights; + LayerNormWeight patchEmbed_norm_weights; + LayerNormWeight norm_weights; vector> basic_layer_weight_list; }; // class SwinTransformerINT8Weight diff --git a/src/fastertransformer/models/t5/CMakeLists.txt b/src/fastertransformer/models/t5/CMakeLists.txt index 9f3455d3f..21804554b 100644 
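
// The allocateBuffer/freeBuffer hunks in this diff move from malloc/free(ptr) to
// reMalloc(ptr, size) and free((void**)&ptr). A minimal sketch of that idiom, assuming
// reMalloc reuses a tracked block when it is already large enough and free() nulls the
// caller's pointer; SimpleAllocator is illustrative and not the project's IAllocator.
#include <cstddef>
#include <cstdlib>
#include <unordered_map>

class SimpleAllocator {
public:
    void* reMalloc(void* old_ptr, size_t size, bool set_zero)
    {
        auto it = sizes_.find(old_ptr);
        if (old_ptr != nullptr && it != sizes_.end() && it->second >= size) {
            return old_ptr;  // existing block is big enough: hand it back unchanged
        }
        if (old_ptr != nullptr && it != sizes_.end()) {
            std::free(old_ptr);  // grow: release the old block first
            sizes_.erase(it);
        }
        void* p = set_zero ? std::calloc(1, size) : std::malloc(size);
        sizes_[p] = size;
        return p;
    }

    // Taking void** lets the allocator reset the caller's pointer, which is why the
    // call sites change from free(buf_) to free((void**)(&buf_)).
    void free(void** ptr)
    {
        if (ptr != nullptr && *ptr != nullptr) {
            sizes_.erase(*ptr);
            std::free(*ptr);
            *ptr = nullptr;
        }
    }

private:
    std::unordered_map<void*, size_t> sizes_;
};
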
--- a/src/fastertransformer/models/t5/CMakeLists.txt +++ b/src/fastertransformer/models/t5/CMakeLists.txt @@ -18,22 +18,22 @@ add_library(T5Decoder STATIC T5Decoder.cc T5DecoderLayerWeight.cc) set_property(TARGET T5Decoder PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET T5Decoder PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) target_link_libraries(T5Decoder PUBLIC -lcudart cublasMMWrapper TensorParallelDecoderSelfAttentionLayer - TensorParallelDecoderCrossAttentionLayer TensorParallelReluFfnLayer - layernorm_kernels add_residual_kernels nccl_utils memory_utils) + TensorParallelDecoderCrossAttentionLayer TensorParallelReluFfnLayer TensorParallelSiluFfnLayer + layernorm_kernels add_residual_kernels nccl_utils memory_utils tensor) add_library(T5Decoding STATIC T5Decoding.cc T5DecodingWeight.cc) set_property(TARGET T5Decoding PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET T5Decoding PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) target_link_libraries(T5Decoding PUBLIC -lcudart cublasMMWrapper T5Decoder bert_preprocess_kernels decoding_kernels DynamicDecodeLayer BaseBeamSearchLayer - beam_search_topk_kernels gpt_kernels) + beam_search_topk_kernels gpt_kernels tensor) add_library(T5Encoder STATIC T5Encoder.cc T5EncoderWeight.cc T5EncoderLayerWeight.cc) set_property(TARGET T5Encoder PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET T5Encoder PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) target_link_libraries(T5Encoder PUBLIC -lcudart bert_preprocess_kernels cublasMMWrapper TensorParallelUnfusedAttentionLayer FusedAttentionLayer TensorParallelReluFfnLayer - TensorParallelGeluFfnLayer layernorm_kernels add_residual_kernels nccl_utils) + TensorParallelGeluFfnLayer TensorParallelSiluFfnLayer layernorm_kernels add_residual_kernels nccl_utils tensor) add_executable(t5_gemm t5_gemm.cc) -target_link_libraries(t5_gemm PUBLIC -lcudart t5_gemm_func memory_utils) \ No newline at end of file +target_link_libraries(t5_gemm PUBLIC -lcudart t5_gemm_func memory_utils) diff --git a/src/fastertransformer/models/t5/T5Decoder.cc b/src/fastertransformer/models/t5/T5Decoder.cc index 73c22310c..c482fa7e9 100644 --- a/src/fastertransformer/models/t5/T5Decoder.cc +++ b/src/fastertransformer/models/t5/T5Decoder.cc @@ -31,6 +31,7 @@ void T5Decoder::initialize() stream_, cublas_wrapper_, allocator_, + true, is_free_buffer_after_forward_, false, 0, @@ -51,7 +52,9 @@ void T5Decoder::initialize() custom_all_reduce_comm_, enable_custom_all_reduce_); - if (activation_type_ == ActivationType::Gelu) { + bool use_gated_activation = activation_type_ == ActivationType::GeGLU || activation_type_ == ActivationType::ReGLU + || activation_type_ == ActivationType::SiGLU; + if (activation_type_ == ActivationType::Gelu || activation_type_ == ActivationType::GeGLU) { ffn_layer_ = new TensorParallelGeluFfnLayer(max_batch_size_, 1, 1, @@ -61,13 +64,15 @@ void T5Decoder::initialize() stream_, cublas_wrapper_, allocator_, + true, is_free_buffer_after_forward_, false, 0, + use_gated_activation, custom_all_reduce_comm_, enable_custom_all_reduce_); } - else { + else if (activation_type_ == ActivationType::Relu || activation_type_ == ActivationType::ReGLU) { ffn_layer_ = new TensorParallelReluFfnLayer(max_batch_size_, 1, 1, @@ -77,8 +82,27 @@ void T5Decoder::initialize() stream_, cublas_wrapper_, allocator_, + true, + is_free_buffer_after_forward_, + false, + use_gated_activation, + custom_all_reduce_comm_, + enable_custom_all_reduce_); + } + else if (activation_type_ == ActivationType::Silu || activation_type_ == 
ActivationType::SiGLU) { + ffn_layer_ = new TensorParallelSiluFfnLayer(max_batch_size_, + 1, + 1, + d_model_, + inter_size_, + tensor_para_, + stream_, + cublas_wrapper_, + allocator_, + true, is_free_buffer_after_forward_, false, + use_gated_activation, custom_all_reduce_comm_, enable_custom_all_reduce_); } @@ -88,14 +112,18 @@ template void T5Decoder::allocateBuffer() { if (is_allocate_buffer_ == false) { - decoder_normed_input_ = reinterpret_cast(allocator_->malloc(sizeof(T) * max_batch_size_ * d_model_, false)); - self_attn_output_ = reinterpret_cast(allocator_->malloc(sizeof(T) * max_batch_size_ * d_model_, false)); - normed_self_attn_output_ = - reinterpret_cast(allocator_->malloc(sizeof(T) * max_batch_size_ * d_model_, false)); - cross_attn_output_ = reinterpret_cast(allocator_->malloc(sizeof(T) * max_batch_size_ * d_model_, false)); - normed_cross_attn_output_ = - reinterpret_cast(allocator_->malloc(sizeof(T) * max_batch_size_ * d_model_, false)); - decoder_layer_output_ = reinterpret_cast(allocator_->malloc(sizeof(T) * max_batch_size_ * d_model_, false)); + decoder_normed_input_ = reinterpret_cast( + allocator_->reMalloc(decoder_normed_input_, sizeof(T) * max_batch_size_ * d_model_, false)); + self_attn_output_ = reinterpret_cast( + allocator_->reMalloc(self_attn_output_, sizeof(T) * max_batch_size_ * d_model_, false)); + normed_self_attn_output_ = reinterpret_cast( + allocator_->reMalloc(normed_self_attn_output_, sizeof(T) * max_batch_size_ * d_model_, false)); + cross_attn_output_ = reinterpret_cast( + allocator_->reMalloc(cross_attn_output_, sizeof(T) * max_batch_size_ * d_model_, false)); + normed_cross_attn_output_ = reinterpret_cast( + allocator_->reMalloc(normed_cross_attn_output_, sizeof(T) * max_batch_size_ * d_model_, false)); + decoder_layer_output_ = reinterpret_cast( + allocator_->reMalloc(decoder_layer_output_, sizeof(T) * max_batch_size_ * d_model_, false)); is_allocate_buffer_ = true; } } @@ -122,13 +150,14 @@ void T5Decoder::allocateBuffer(size_t batch_size) template void T5Decoder::freeBuffer() { + FT_LOG_DEBUG(__PRETTY_FUNCTION__); if (is_allocate_buffer_ == true) { - allocator_->free(decoder_normed_input_); - allocator_->free(self_attn_output_); - allocator_->free(normed_self_attn_output_); - allocator_->free(cross_attn_output_); - allocator_->free(normed_cross_attn_output_); - allocator_->free(decoder_layer_output_); + allocator_->free((void**)(&decoder_normed_input_)); + allocator_->free((void**)(&self_attn_output_)); + allocator_->free((void**)(&normed_self_attn_output_)); + allocator_->free((void**)(&cross_attn_output_)); + allocator_->free((void**)(&normed_cross_attn_output_)); + allocator_->free((void**)(&decoder_layer_output_)); is_allocate_buffer_ = false; } } @@ -147,22 +176,23 @@ bool T5Decoder::isValidBatchSize(size_t batch_size) } template -T5Decoder::T5Decoder(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - size_t inter_size, - size_t d_model, - size_t num_layer, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - NcclParam tensor_para, - NcclParam pipeline_para, - ActivationType activation_type, - float q_scaling, +T5Decoder::T5Decoder(size_t max_batch_size, + size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t d_model, + size_t num_layer, + float layernorm_eps, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + NcclParam tensor_para, + NcclParam pipeline_para, + 
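
// T5Decoder::initialize() now chooses the FFN from the activation type and treats
// GeGLU / ReGLU / SiGLU as the gated forms of Gelu / Relu / Silu (use_gated_activation).
// The enum and helper below are a trimmed, illustrative restatement of that dispatch,
// not the actual layer construction.
#include <stdexcept>

enum class ActivationType { Gelu, Relu, Silu, GeGLU, ReGLU, SiGLU };

struct FfnChoice {
    enum class Base { Gelu, Relu, Silu };
    Base base;
    bool use_gated_activation;  // gated variants add the second intermediate projection (wi2)
};

inline FfnChoice chooseFfn(ActivationType t)
{
    switch (t) {
        case ActivationType::Gelu:  return {FfnChoice::Base::Gelu, false};
        case ActivationType::GeGLU: return {FfnChoice::Base::Gelu, true};
        case ActivationType::Relu:  return {FfnChoice::Base::Relu, false};
        case ActivationType::ReGLU: return {FfnChoice::Base::Relu, true};
        case ActivationType::Silu:  return {FfnChoice::Base::Silu, false};
        case ActivationType::SiGLU: return {FfnChoice::Base::Silu, true};
    }
    throw std::invalid_argument("unsupported activation type");
}
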
ActivationType activation_type, + float q_scaling, std::shared_ptr custom_all_reduce_comm, - int enable_custom_all_reduce): + int enable_custom_all_reduce): BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward), max_batch_size_(max_batch_size), head_num_(head_num), @@ -170,6 +200,7 @@ T5Decoder::T5Decoder(size_t max_batch_size, inter_size_(inter_size), d_model_(d_model), num_layer_(num_layer), + layernorm_eps_(layernorm_eps), hidden_units_(head_num_ * size_per_head), tensor_para_(tensor_para), pipeline_para_(pipeline_para), @@ -190,6 +221,7 @@ T5Decoder::T5Decoder(T5Decoder const& decoder): inter_size_(decoder.inter_size_), d_model_(decoder.d_model_), num_layer_(decoder.num_layer_), + layernorm_eps_(decoder.layernorm_eps_), hidden_units_(decoder.hidden_units_), tensor_para_(decoder.tensor_para_), pipeline_para_(decoder.pipeline_para_), @@ -249,8 +281,8 @@ int T5Decoder::getFirstLayerParallelId() } template -void T5Decoder::forward(std::vector* output_tensors, - const std::vector* input_tensors, +void T5Decoder::forward(std::vector* output_tensors, + const std::vector* input_tensors, const std::vector*>* decoder_layer_weight) { // input tensors: @@ -272,16 +304,19 @@ void T5Decoder::forward(std::vector* output_tensors, // value_cache [num_layer / pipeline_para_.world_size_, batch, head_num, max_seq_len, size_per_head] // key_mem_cache [num_layer / pipeline_para_.world_size_, batch_size, mem_max_seq_len, hidden_dimension], // value_mem_cache [num_layer / pipeline_para_.world_size_, batch_size, mem_max_seq_len, hidden_dimension] + // attention_output: shape = [num_layer / pipeline_para_.world_size_, batch_size, beam, + // head_num / tensor_para_.world_size_, max_seq_len, mem_max_seq_len] + // offset = [batch_offset, layer_offset_base] optional, float* FT_CHECK(input_tensors->size() == 9); - FT_CHECK(output_tensors->size() == 5); + FT_CHECK(output_tensors->size() == 5 || output_tensors->size() == 6); isValidBatchSize(input_tensors->at(0).shape[0]); const size_t local_batch_size = input_tensors->at(0).shape[0]; allocateBuffer(local_batch_size); - const size_t mem_max_seq_len = (size_t)input_tensors->at(1).shape[1]; - const uint ite = *((uint*)(input_tensors->at(7).data)); - const DataType data_type = getTensorType(); + const size_t mem_max_seq_len = (size_t)input_tensors->at(1).shape[1]; + const uint ite = *((uint*)(input_tensors->at(7).data)); + const DataType data_type = getTensorType(); std::vector self_k_cache_shape; self_k_cache_shape.push_back(local_batch_size); @@ -297,12 +332,15 @@ void T5Decoder::forward(std::vector* output_tensors, const std::vector mem_cache_shape = { local_batch_size, output_tensors->at(3).shape[2], output_tensors->at(3).shape[3]}; + const bool output_cross_attention = output_tensors->size() == 6; + const uint max_seq_len = output_cross_attention ? output_tensors->at(5).shape[4] : 0; + for (uint l = 0; l < num_layer_; l++) { if (isValidLayerParallelId(l) == false) { continue; } - T* decoder_input = (T*)((l == 0) ? input_tensors->at(0).data : decoder_layer_output_); + T* decoder_input = (T*)((l == 0) ? input_tensors->at(0).data : decoder_layer_output_); T* decoder_output = (T*)((l == num_layer_ - 1) ? 
output_tensors->at(0).data : decoder_layer_output_); if (isFirstLayerParallelId(l) == true && pipeline_para_.rank_ != 0 && pipeline_para_.world_size_ > 1) { @@ -348,20 +386,24 @@ void T5Decoder::forward(std::vector* output_tensors, decoder_input, decoder_layer_weight->at(l)->pre_layernorm_weights.gamma, decoder_layer_weight->at(l)->pre_layernorm_weights.beta, + layernorm_eps_, local_batch_size, d_model_, stream_); sync_check_cuda_error(); - int tmp_0 = 0; + int tmp_0 = 0; std::vector self_attention_input_tensors{ Tensor{MEMORY_GPU, data_type, {local_batch_size, d_model_}, decoder_normed_input_}, input_tensors->at(3), input_tensors->at(5), Tensor{MEMORY_GPU, TYPE_INT32, {local_batch_size}, (T*)nullptr}, + Tensor{MEMORY_GPU, data_type, {local_batch_size}, (T*)nullptr}, + Tensor{MEMORY_CPU, TYPE_INT32, {1}, &tmp_0}, Tensor{MEMORY_CPU, TYPE_INT32, {1}, &tmp_0}, input_tensors->at(4), input_tensors->at(8), + Tensor{MEMORY_GPU, TYPE_BOOL, {local_batch_size}, (bool*)nullptr}, input_tensors->at(6)}; std::vector self_attention_output_tensors{ Tensor{MEMORY_GPU, data_type, {local_batch_size, d_model_}, self_attn_output_}, @@ -378,6 +420,7 @@ void T5Decoder::forward(std::vector* output_tensors, decoder_layer_weight->at(l)->self_attn_layernorm_weights.gamma, decoder_layer_weight->at(l)->self_attn_layernorm_weights.beta, decoder_layer_weight->at(l)->self_attention_weights.attention_output_weight.bias, + layernorm_eps_, local_batch_size, d_model_, stream_); @@ -393,6 +436,17 @@ void T5Decoder::forward(std::vector* output_tensors, Tensor{MEMORY_GPU, data_type, {local_batch_size, d_model_}, cross_attn_output_}, Tensor{MEMORY_GPU, data_type, mem_cache_shape, ((const T*)output_tensors->at(3).data) + mem_cache_offset}, Tensor{MEMORY_GPU, data_type, mem_cache_shape, ((const T*)output_tensors->at(4).data) + mem_cache_offset}}; + if (output_cross_attention) { + int local_layer_id = l - getFirstLayerParallelId(); + const size_t cross_attentions_offset = local_layer_id * output_tensors->at(5).offsets[1] + + output_tensors->at(5).offsets[0] * head_num_ + / tensor_para_.world_size_ * max_seq_len * mem_max_seq_len; + cross_attention_output_tensors.push_back( + Tensor{MEMORY_GPU, + TYPE_FP32, + {local_batch_size, head_num_ / tensor_para_.world_size_, max_seq_len, mem_max_seq_len}, + output_tensors->at(5).getPtrWithOffset(cross_attentions_offset)}); + } cross_attention_layer_->forward(&cross_attention_output_tensors, &cross_attention_input_tensors, &decoder_layer_weight->at(l)->cross_attention_weights); @@ -404,6 +458,7 @@ void T5Decoder::forward(std::vector* output_tensors, decoder_layer_weight->at(l)->cross_attn_layernorm_weights.gamma, decoder_layer_weight->at(l)->cross_attn_layernorm_weights.beta, decoder_layer_weight->at(l)->cross_attention_weights.attention_output_weight.bias, + layernorm_eps_, local_batch_size, d_model_, stream_); @@ -443,5 +498,8 @@ void T5Decoder::forward(std::vector* output_tensors, template class T5Decoder; template class T5Decoder; +#ifdef ENABLE_BF16 +template class T5Decoder<__nv_bfloat16>; +#endif -} // namespace fastertransformer \ No newline at end of file +} // namespace fastertransformer diff --git a/src/fastertransformer/models/t5/T5Decoder.h b/src/fastertransformer/models/t5/T5Decoder.h index f86612cd1..57947e8d0 100644 --- a/src/fastertransformer/models/t5/T5Decoder.h +++ b/src/fastertransformer/models/t5/T5Decoder.h @@ -23,6 +23,7 @@ #include "src/fastertransformer/layers/BaseLayer.h" #include "src/fastertransformer/layers/TensorParallelGeluFfnLayer.h" #include 
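
// The optional cross_attentions output added above has shape
// [num_layer / pipeline_para_size, batch, beam, head_num / tensor_para_size,
//  max_seq_len, mem_max_seq_len], and each layer writes at an element offset computed from
// the tensor's offsets metadata. The helper below restates that arithmetic for clarity,
// assuming offsets[0] is a batch offset and offsets[1] a per-layer stride, as the diff's
// comments describe; it is illustrative, not part of the change.
#include <cstddef>

inline size_t crossAttentionOffset(size_t local_layer_id,
                                   size_t layer_stride,      // offsets[1]: elements per local layer
                                   size_t batch_offset,      // offsets[0]: batch rows already written
                                   size_t head_num_per_tp,   // head_num_ / tensor_para_.world_size_
                                   size_t max_seq_len,
                                   size_t mem_max_seq_len)
{
    return local_layer_id * layer_stride
           + batch_offset * head_num_per_tp * max_seq_len * mem_max_seq_len;
}
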
"src/fastertransformer/layers/TensorParallelReluFfnLayer.h" +#include "src/fastertransformer/layers/TensorParallelSiluFfnLayer.h" #include "src/fastertransformer/layers/attention_layers/TensorParallelDecoderCrossAttentionLayer.h" #include "src/fastertransformer/layers/attention_layers/TensorParallelDecoderSelfAttentionLayer.h" #include "src/fastertransformer/models/t5/T5DecoderLayerWeight.h" @@ -39,18 +40,19 @@ class T5Decoder: public BaseLayer { // buffer handling size_t max_batch_size_ = 0; // meta data - const size_t head_num_; - const size_t size_per_head_; - const size_t inter_size_; - const size_t d_model_; - const size_t num_layer_; - const size_t hidden_units_; + const size_t head_num_; + const size_t size_per_head_; + const size_t inter_size_; + const size_t d_model_; + const size_t num_layer_; + const size_t hidden_units_; const ActivationType activation_type_; - float q_scaling_; + const float layernorm_eps_; + float q_scaling_; BaseAttentionLayer* self_attention_layer_; BaseAttentionLayer* cross_attention_layer_; - FfnLayer* ffn_layer_; + FfnLayer* ffn_layer_; void allocateBuffer() override; void freeBuffer() override; @@ -63,45 +65,46 @@ class T5Decoder: public BaseLayer { NcclParam pipeline_para_; std::shared_ptr custom_all_reduce_comm_; - int enable_custom_all_reduce_; + int enable_custom_all_reduce_; bool isValidLayerParallelId(uint l); bool isFirstLayerParallelId(uint l); bool isLastLayerParallelId(uint l); - int getFirstLayerParallelId(); + int getFirstLayerParallelId(); protected: - T* decoder_normed_input_ = nullptr; - T* self_attn_output_ = nullptr; - T* normed_self_attn_output_ = nullptr; - T* cross_attn_output_ = nullptr; + T* decoder_normed_input_ = nullptr; + T* self_attn_output_ = nullptr; + T* normed_self_attn_output_ = nullptr; + T* cross_attn_output_ = nullptr; T* normed_cross_attn_output_ = nullptr; - T* decoder_layer_output_ = nullptr; + T* decoder_layer_output_ = nullptr; public: - T5Decoder(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - size_t inter_size, - size_t d_model, - size_t num_layer, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - NcclParam tensor_para, - NcclParam pipeline_para, - ActivationType activation_type, - float q_scaling = 1.0f, - std::shared_ptr custom_all_reduce_comm = nullptr, - int enable_custom_all_reduce = 0); + T5Decoder(size_t max_batch_size, + size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t d_model, + size_t num_layer, + float layernorm_eps_, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + NcclParam tensor_para, + NcclParam pipeline_para, + ActivationType activation_type, + float q_scaling = 1.0f, + std::shared_ptr custom_all_reduce_comm = nullptr, + int enable_custom_all_reduce = 0); T5Decoder(T5Decoder const& decoder); ~T5Decoder(); - void forward(std::vector* output_tensors, - const std::vector* input_tensors, + void forward(std::vector* output_tensors, + const std::vector* input_tensors, const std::vector*>* decoder_layer_weights); void setStream(cudaStream_t stream) override; }; diff --git a/src/fastertransformer/models/t5/T5DecoderLayerWeight.cc b/src/fastertransformer/models/t5/T5DecoderLayerWeight.cc index 85125ba1a..c3c9610c4 100644 --- a/src/fastertransformer/models/t5/T5DecoderLayerWeight.cc +++ b/src/fastertransformer/models/t5/T5DecoderLayerWeight.cc @@ -29,7 +29,8 @@ T5DecoderLayerWeight::T5DecoderLayerWeight(const size_t 
head_num, const size_t mem_d_model, const size_t tensor_para_size, const size_t tensor_para_rank, - const bool t5_with_bias): + const bool t5_with_bias, + const bool use_gated_activation): head_num_(head_num), size_per_head_(size_per_head), d_model_(d_model), @@ -38,8 +39,10 @@ T5DecoderLayerWeight::T5DecoderLayerWeight(const size_t head_num, tensor_para_size_(tensor_para_size), tensor_para_rank_(tensor_para_rank), t5_with_bias_(t5_with_bias), - real_weights_num_(t5_with_bias ? 22 : 11) + use_gated_activation_(use_gated_activation) { + real_weights_num_ = (11 + (use_gated_activation ? 1 : 0)) * (t5_with_bias ? 2 : 1); + FT_LOG_DEBUG("T5DecoderLayerWeight " + std::string(__func__) + " start"); initialize(); @@ -63,21 +66,44 @@ void T5DecoderLayerWeight::initialize() weights_size[6] = mem_d_model_ * (head_num_ / tensor_para_size_) * size_per_head_; weights_size[7] = (head_num_ / tensor_para_size_) * size_per_head_ * d_model_; weights_size[8] = d_model_; - weights_size[9] = d_model_ * (inter_size_ / tensor_para_size_); - weights_size[10] = (inter_size_ / tensor_para_size_) * d_model_; + if (use_gated_activation_) { + weights_size[9] = d_model_ * (inter_size_ / tensor_para_size_); + weights_size[10] = d_model_ * (inter_size_ / tensor_para_size_); // for gated activation + weights_size[11] = (inter_size_ / tensor_para_size_) * d_model_; + } + else { + weights_size[9] = d_model_ * (inter_size_ / tensor_para_size_); + weights_size[10] = (inter_size_ / tensor_para_size_) * d_model_; + } if (t5_with_bias_) { - weights_size[11] = d_model_; - weights_size[12] = 3 * (head_num_ / tensor_para_size_) * size_per_head_; - weights_size[13] = d_model_; - weights_size[14] = d_model_; - weights_size[15] = (head_num_ / tensor_para_size_) * size_per_head_; - weights_size[16] = (head_num_ / tensor_para_size_) * size_per_head_; - weights_size[17] = (head_num_ / tensor_para_size_) * size_per_head_; - weights_size[18] = d_model_; - weights_size[19] = d_model_; - weights_size[20] = (inter_size_ / tensor_para_size_); - weights_size[21] = d_model_; + if (use_gated_activation_) { + weights_size[12] = d_model_; + weights_size[13] = 3 * (head_num_ / tensor_para_size_) * size_per_head_; + weights_size[14] = d_model_; + weights_size[15] = d_model_; + weights_size[16] = (head_num_ / tensor_para_size_) * size_per_head_; + weights_size[17] = (head_num_ / tensor_para_size_) * size_per_head_; + weights_size[18] = (head_num_ / tensor_para_size_) * size_per_head_; + weights_size[19] = d_model_; + weights_size[20] = d_model_; + weights_size[21] = (inter_size_ / tensor_para_size_); + weights_size[22] = (inter_size_ / tensor_para_size_); // for gated activation + weights_size[23] = d_model_; + } + else { + weights_size[11] = d_model_; + weights_size[12] = 3 * (head_num_ / tensor_para_size_) * size_per_head_; + weights_size[13] = d_model_; + weights_size[14] = d_model_; + weights_size[15] = (head_num_ / tensor_para_size_) * size_per_head_; + weights_size[16] = (head_num_ / tensor_para_size_) * size_per_head_; + weights_size[17] = (head_num_ / tensor_para_size_) * size_per_head_; + weights_size[18] = d_model_; + weights_size[19] = d_model_; + weights_size[20] = (inter_size_ / tensor_para_size_); + weights_size[21] = d_model_; + } } FT_LOG_DEBUG("T5DecoderLayerWeight " + std::string(__func__) + " end"); @@ -93,36 +119,36 @@ T5DecoderLayerWeight::~T5DecoderLayerWeight() deviceFree(weights_ptr[i]); } - pre_layernorm_weights.gamma = nullptr; - self_attention_weights.query_weight.kernel = nullptr; + pre_layernorm_weights.gamma = nullptr; + 
self_attention_weights.query_weight.kernel = nullptr; self_attention_weights.attention_output_weight.kernel = nullptr; - self_attn_layernorm_weights.gamma = nullptr; + self_attn_layernorm_weights.gamma = nullptr; - cross_attention_weights.query_weight.kernel = nullptr; - cross_attention_weights.key_weight.kernel = nullptr; - cross_attention_weights.value_weight.kernel = nullptr; + cross_attention_weights.query_weight.kernel = nullptr; + cross_attention_weights.key_weight.kernel = nullptr; + cross_attention_weights.value_weight.kernel = nullptr; cross_attention_weights.attention_output_weight.kernel = nullptr; - cross_attn_layernorm_weights.gamma = nullptr; - - ffn_weights.intermediate_weight.kernel = nullptr; - ffn_weights.output_weight.kernel = nullptr; - - if (t5_with_bias_) { - pre_layernorm_weights.beta = nullptr; - self_attention_weights.query_weight.bias = nullptr; - self_attention_weights.attention_output_weight.bias = nullptr; - self_attn_layernorm_weights.beta = nullptr; - - cross_attention_weights.query_weight.bias = nullptr; - cross_attention_weights.key_weight.bias = nullptr; - cross_attention_weights.value_weight.bias = nullptr; - cross_attention_weights.attention_output_weight.bias = nullptr; - cross_attn_layernorm_weights.beta = nullptr; - - ffn_weights.intermediate_weight.bias = nullptr; - ffn_weights.output_weight.bias = nullptr; - } - is_maintain_buffer = false; + cross_attn_layernorm_weights.gamma = nullptr; + + ffn_weights.intermediate_weight.kernel = nullptr; + ffn_weights.intermediate_weight2.kernel = nullptr; + ffn_weights.output_weight.kernel = nullptr; + + pre_layernorm_weights.beta = nullptr; + self_attention_weights.query_weight.bias = nullptr; + self_attention_weights.attention_output_weight.bias = nullptr; + self_attn_layernorm_weights.beta = nullptr; + + cross_attention_weights.query_weight.bias = nullptr; + cross_attention_weights.key_weight.bias = nullptr; + cross_attention_weights.value_weight.bias = nullptr; + cross_attention_weights.attention_output_weight.bias = nullptr; + cross_attn_layernorm_weights.beta = nullptr; + + ffn_weights.intermediate_weight.bias = nullptr; + ffn_weights.intermediate_weight2.bias = nullptr; + ffn_weights.output_weight.bias = nullptr; + is_maintain_buffer = false; } FT_LOG_DEBUG("T5DecoderLayerWeight " + std::string(__func__) + " end"); @@ -138,6 +164,7 @@ T5DecoderLayerWeight::T5DecoderLayerWeight(const T5DecoderLayerWeight& other) tensor_para_size_(other.tensor_para_size_), tensor_para_rank_(other.tensor_para_rank_), t5_with_bias_(other.t5_with_bias_), + use_gated_activation_(other.use_gated_activation_), real_weights_num_(other.real_weights_num_) { @@ -152,15 +179,16 @@ T5DecoderLayerWeight::T5DecoderLayerWeight(const T5DecoderLayerWeight& other) template T5DecoderLayerWeight& T5DecoderLayerWeight::operator=(const T5DecoderLayerWeight& other) { - head_num_ = other.head_num_; - size_per_head_ = other.size_per_head_; - d_model_ = other.d_model_; - inter_size_ = other.inter_size_; - mem_d_model_ = other.mem_d_model_; - tensor_para_size_ = other.tensor_para_size_; - tensor_para_rank_ = other.tensor_para_rank_; - t5_with_bias_ = other.t5_with_bias_; - real_weights_num_ = other.real_weights_num_; + head_num_ = other.head_num_; + size_per_head_ = other.size_per_head_; + d_model_ = other.d_model_; + inter_size_ = other.inter_size_; + mem_d_model_ = other.mem_d_model_; + tensor_para_size_ = other.tensor_para_size_; + tensor_para_rank_ = other.tensor_para_rank_; + t5_with_bias_ = other.t5_with_bias_; + use_gated_activation_ = 
other.use_gated_activation_; + real_weights_num_ = other.real_weights_num_; initialize(); mallocWeights(); @@ -175,34 +203,58 @@ T5DecoderLayerWeight& T5DecoderLayerWeight::operator=(const T5DecoderLayer template void T5DecoderLayerWeight::setWeightPtr() { - pre_layernorm_weights.gamma = weights_ptr[0]; - self_attention_weights.query_weight.kernel = weights_ptr[1]; + pre_layernorm_weights.gamma = weights_ptr[0]; + self_attention_weights.query_weight.kernel = weights_ptr[1]; self_attention_weights.attention_output_weight.kernel = weights_ptr[2]; - self_attn_layernorm_weights.gamma = weights_ptr[3]; + self_attn_layernorm_weights.gamma = weights_ptr[3]; - cross_attention_weights.query_weight.kernel = weights_ptr[4]; - cross_attention_weights.key_weight.kernel = weights_ptr[5]; - cross_attention_weights.value_weight.kernel = weights_ptr[6]; + cross_attention_weights.query_weight.kernel = weights_ptr[4]; + cross_attention_weights.key_weight.kernel = weights_ptr[5]; + cross_attention_weights.value_weight.kernel = weights_ptr[6]; cross_attention_weights.attention_output_weight.kernel = weights_ptr[7]; - cross_attn_layernorm_weights.gamma = weights_ptr[8]; - - ffn_weights.intermediate_weight.kernel = weights_ptr[9]; - ffn_weights.output_weight.kernel = weights_ptr[10]; + cross_attn_layernorm_weights.gamma = weights_ptr[8]; + if (use_gated_activation_) { + ffn_weights.intermediate_weight.kernel = weights_ptr[9]; + ffn_weights.intermediate_weight2.kernel = weights_ptr[10]; + ffn_weights.output_weight.kernel = weights_ptr[11]; + } + else { + ffn_weights.intermediate_weight.kernel = weights_ptr[9]; + ffn_weights.output_weight.kernel = weights_ptr[10]; + } if (t5_with_bias_) { - pre_layernorm_weights.beta = weights_ptr[11]; - self_attention_weights.query_weight.bias = weights_ptr[12]; - self_attention_weights.attention_output_weight.bias = weights_ptr[13]; - self_attn_layernorm_weights.beta = weights_ptr[14]; - - cross_attention_weights.query_weight.bias = weights_ptr[15]; - cross_attention_weights.key_weight.bias = weights_ptr[16]; - cross_attention_weights.value_weight.bias = weights_ptr[17]; - cross_attention_weights.attention_output_weight.bias = weights_ptr[18]; - cross_attn_layernorm_weights.beta = weights_ptr[19]; - - ffn_weights.intermediate_weight.bias = weights_ptr[20]; - ffn_weights.output_weight.bias = weights_ptr[21]; + if (use_gated_activation_) { + pre_layernorm_weights.beta = weights_ptr[12]; + self_attention_weights.query_weight.bias = weights_ptr[13]; + self_attention_weights.attention_output_weight.bias = weights_ptr[14]; + self_attn_layernorm_weights.beta = weights_ptr[15]; + + cross_attention_weights.query_weight.bias = weights_ptr[16]; + cross_attention_weights.key_weight.bias = weights_ptr[17]; + cross_attention_weights.value_weight.bias = weights_ptr[18]; + cross_attention_weights.attention_output_weight.bias = weights_ptr[19]; + cross_attn_layernorm_weights.beta = weights_ptr[20]; + + ffn_weights.intermediate_weight.bias = weights_ptr[21]; + ffn_weights.intermediate_weight2.bias = weights_ptr[22]; + ffn_weights.output_weight.bias = weights_ptr[23]; + } + else { + pre_layernorm_weights.beta = weights_ptr[11]; + self_attention_weights.query_weight.bias = weights_ptr[12]; + self_attention_weights.attention_output_weight.bias = weights_ptr[13]; + self_attn_layernorm_weights.beta = weights_ptr[14]; + + cross_attention_weights.query_weight.bias = weights_ptr[15]; + cross_attention_weights.key_weight.bias = weights_ptr[16]; + cross_attention_weights.value_weight.bias = 
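
// T5DecoderLayerWeight bookkeeping after this change: 11 kernel tensors in the base layout,
// 12 when a gated activation adds the wi2 projection, and biases (when t5_with_bias) double
// the count; loadModel shifts the bias indices by gated_activation_weight_offset. The struct
// below only restates that arithmetic; it is illustrative and not part of the class.
#include <cstddef>

struct T5DecoderLayerWeightLayout {
    bool t5_with_bias;
    bool use_gated_activation;

    size_t kernelCount() const { return 11 + (use_gated_activation ? 1 : 0); }

    // Matches real_weights_num_ = (11 + gated) * (t5_with_bias ? 2 : 1) in the diff.
    size_t realWeightsNum() const { return kernelCount() * (t5_with_bias ? 2 : 1); }

    // With wi2 present, the bias block starts one slot later (non-gated indices 11..20
    // become 12..21), and the extra wi2 bias is stored just before the final wo bias
    // (slots 22 and 23 in the gated layout).
    size_t gatedActivationWeightOffset() const { return use_gated_activation ? 1 : 0; }
};
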
weights_ptr[17]; + cross_attention_weights.attention_output_weight.bias = weights_ptr[18]; + cross_attn_layernorm_weights.beta = weights_ptr[19]; + + ffn_weights.intermediate_weight.bias = weights_ptr[20]; + ffn_weights.output_weight.bias = weights_ptr[21]; + } } } @@ -222,89 +274,123 @@ void T5DecoderLayerWeight::loadModel(std::string dir_path, FtCudaDataType mod FT_CHECK(is_maintain_buffer == true); loadWeightFromBin( - weights_ptr[0], {(int)weights_size[0]}, dir_path + "layer.0.layer_norm.weight.bin", model_file_type); + weights_ptr[0], {weights_size[0]}, dir_path + "layer.0.layer_norm.weight.bin", model_file_type); loadWeightFromBin(weights_ptr[1], - {(int)weights_size[1]}, + {weights_size[1]}, dir_path + "layer.0.SelfAttention.qkv.weight." + std::to_string(tensor_para_rank_) + ".bin", model_file_type); loadWeightFromBin(weights_ptr[2], - {(int)weights_size[2]}, + {weights_size[2]}, dir_path + "layer.0.SelfAttention.o.weight." + std::to_string(tensor_para_rank_) + ".bin", model_file_type); loadWeightFromBin( - weights_ptr[3], {(int)weights_size[3]}, dir_path + "layer.1.layer_norm.weight.bin", model_file_type); + weights_ptr[3], {weights_size[3]}, dir_path + "layer.1.layer_norm.weight.bin", model_file_type); loadWeightFromBin(weights_ptr[4], - {(int)weights_size[4]}, + {weights_size[4]}, dir_path + "layer.1.EncDecAttention.q.weight." + std::to_string(tensor_para_rank_) + ".bin", model_file_type); loadWeightFromBin(weights_ptr[5], - {(int)weights_size[5]}, + {weights_size[5]}, dir_path + "layer.1.EncDecAttention.k.weight." + std::to_string(tensor_para_rank_) + ".bin", model_file_type); loadWeightFromBin(weights_ptr[6], - {(int)weights_size[6]}, + {weights_size[6]}, dir_path + "layer.1.EncDecAttention.v.weight." + std::to_string(tensor_para_rank_) + ".bin", model_file_type); loadWeightFromBin(weights_ptr[7], - {(int)weights_size[7]}, + {weights_size[7]}, dir_path + "layer.1.EncDecAttention.o.weight." + std::to_string(tensor_para_rank_) + ".bin", model_file_type); loadWeightFromBin( - weights_ptr[8], {(int)weights_size[8]}, dir_path + "layer.2.layer_norm.weight.bin", model_file_type); + weights_ptr[8], {weights_size[8]}, dir_path + "layer.2.layer_norm.weight.bin", model_file_type); loadWeightFromBin(weights_ptr[9], - {(int)weights_size[9]}, + {weights_size[9]}, dir_path + "layer.2.DenseReluDense.wi.weight." + std::to_string(tensor_para_rank_) + ".bin", model_file_type); - loadWeightFromBin(weights_ptr[10], - {(int)weights_size[10]}, + + const int gated_activation_weight_offset = use_gated_activation_ ? 1 : 0; + if (use_gated_activation_) { + loadWeightFromBin(weights_ptr[10], + {weights_size[10]}, + dir_path + "layer.2.DenseReluDense.wi2.weight." + std::to_string(tensor_para_rank_) + + ".bin", + model_file_type); + } + loadWeightFromBin(weights_ptr[10 + gated_activation_weight_offset], + {weights_size[10 + gated_activation_weight_offset]}, dir_path + "layer.2.DenseReluDense.wo.weight." 
+ std::to_string(tensor_para_rank_) + ".bin", model_file_type); if (t5_with_bias_) { - loadWeightFromBin( - weights_ptr[11], {(int)weights_size[11]}, dir_path + "layer.0.layer_norm.bias.bin", model_file_type); - loadWeightFromBin(weights_ptr[12], - {(int)weights_size[12]}, + loadWeightFromBin(weights_ptr[11 + gated_activation_weight_offset], + {weights_size[11 + gated_activation_weight_offset]}, + dir_path + "layer.0.layer_norm.bias.bin", + model_file_type); + loadWeightFromBin(weights_ptr[12 + gated_activation_weight_offset], + {weights_size[12 + gated_activation_weight_offset]}, dir_path + "layer.0.SelfAttention.qkv.bias." + std::to_string(tensor_para_rank_) + ".bin", model_file_type); - loadWeightFromBin( - weights_ptr[13], {(int)weights_size[13]}, dir_path + "layer.0.SelfAttention.o.bias.bin", model_file_type); - loadWeightFromBin( - weights_ptr[14], {(int)weights_size[14]}, dir_path + "layer.1.layer_norm.bias.bin", model_file_type); - loadWeightFromBin(weights_ptr[15], - {(int)weights_size[15]}, + loadWeightFromBin(weights_ptr[13 + gated_activation_weight_offset], + {weights_size[13 + gated_activation_weight_offset]}, + dir_path + "layer.0.SelfAttention.o.bias.bin", + model_file_type); + loadWeightFromBin(weights_ptr[14 + gated_activation_weight_offset], + {weights_size[14 + gated_activation_weight_offset]}, + dir_path + "layer.1.layer_norm.bias.bin", + model_file_type); + loadWeightFromBin(weights_ptr[15 + gated_activation_weight_offset], + {weights_size[15 + gated_activation_weight_offset]}, dir_path + "layer.1.EncDecAttention.q.bias." + std::to_string(tensor_para_rank_) + ".bin", model_file_type); - loadWeightFromBin(weights_ptr[16], - {(int)weights_size[16]}, + loadWeightFromBin(weights_ptr[16 + gated_activation_weight_offset], + {weights_size[16 + gated_activation_weight_offset]}, dir_path + "layer.1.EncDecAttention.k.bias." + std::to_string(tensor_para_rank_) + ".bin", model_file_type); - loadWeightFromBin(weights_ptr[17], - {(int)weights_size[17]}, + loadWeightFromBin(weights_ptr[17 + gated_activation_weight_offset], + {weights_size[17 + gated_activation_weight_offset]}, dir_path + "layer.1.EncDecAttention.v.bias." + std::to_string(tensor_para_rank_) + ".bin", model_file_type); - loadWeightFromBin( - weights_ptr[18], {(int)weights_size[18]}, dir_path + "layer.1.EncDecAttention.o.bias.bin", model_file_type); - loadWeightFromBin( - weights_ptr[19], {(int)weights_size[19]}, dir_path + "layer.2.layer_norm.bias.bin", model_file_type); - loadWeightFromBin(weights_ptr[20], - {(int)weights_size[20]}, + loadWeightFromBin(weights_ptr[18 + gated_activation_weight_offset], + {weights_size[18 + gated_activation_weight_offset]}, + dir_path + "layer.1.EncDecAttention.o.bias.bin", + model_file_type); + loadWeightFromBin(weights_ptr[19 + gated_activation_weight_offset], + {weights_size[19 + gated_activation_weight_offset]}, + dir_path + "layer.2.layer_norm.bias.bin", + model_file_type); + loadWeightFromBin(weights_ptr[20 + gated_activation_weight_offset], + {weights_size[20 + gated_activation_weight_offset]}, dir_path + "layer.2.DenseReluDense.wi.bias." + std::to_string(tensor_para_rank_) + ".bin", model_file_type); - loadWeightFromBin( - weights_ptr[21], {(int)weights_size[21]}, dir_path + "layer.2.DenseReluDense.wo.bias.bin", model_file_type); + if (use_gated_activation_) { + loadWeightFromBin(weights_ptr[22], + {weights_size[22]}, + dir_path + "layer.2.DenseReluDense.wi2.bias." 
+ std::to_string(tensor_para_rank_) + + ".bin", + model_file_type); + loadWeightFromBin( + weights_ptr[23], {weights_size[23]}, dir_path + "layer.2.DenseReluDense.wo.bias.bin", model_file_type); + } + else { + loadWeightFromBin( + weights_ptr[21], {weights_size[21]}, dir_path + "layer.2.DenseReluDense.wo.bias.bin", model_file_type); + } } FT_LOG_DEBUG("T5DecoderLayerWeight " + std::string(__func__) + " end"); } template -void T5DecoderLayerWeight::setT5WithBias(bool t5_with_bias_para) +void T5DecoderLayerWeight::setT5WithBias(bool t5_with_bias_para, bool use_gated_activation_para) { - t5_with_bias_ = t5_with_bias_para; + t5_with_bias_ = t5_with_bias_para; + use_gated_activation_ = use_gated_activation_para; } template struct T5DecoderLayerWeight; template struct T5DecoderLayerWeight; +#ifdef ENABLE_BF16 +template struct T5DecoderLayerWeight<__nv_bfloat16>; +#endif } // namespace fastertransformer diff --git a/src/fastertransformer/models/t5/T5DecoderLayerWeight.h b/src/fastertransformer/models/t5/T5DecoderLayerWeight.h index 2dc0cda96..f32086ebc 100644 --- a/src/fastertransformer/models/t5/T5DecoderLayerWeight.h +++ b/src/fastertransformer/models/t5/T5DecoderLayerWeight.h @@ -35,7 +35,8 @@ struct T5DecoderLayerWeight { const size_t mem_d_model, const size_t tensor_para_size, const size_t tensor_para_rank, - const bool t5_with_bias); + const bool t5_with_bias, + const bool use_gated_activation); ~T5DecoderLayerWeight(); T5DecoderLayerWeight(const T5DecoderLayerWeight& other); T5DecoderLayerWeight& operator=(const T5DecoderLayerWeight& other); @@ -45,12 +46,13 @@ struct T5DecoderLayerWeight { LayerNormWeight self_attn_layernorm_weights; AttentionWeight cross_attention_weights; LayerNormWeight cross_attn_layernorm_weights; - FfnWeight ffn_weights; - bool t5_with_bias_; + FfnWeight ffn_weights; + bool t5_with_bias_; + bool use_gated_activation_; void loadModel(std::string dir_path, FtCudaDataType model_file_type); - void setT5WithBias(bool t5_with_bias_para); + void setT5WithBias(bool t5_with_bias_para, bool use_gated_activation_para); private: void setWeightPtr(); @@ -64,13 +66,13 @@ struct T5DecoderLayerWeight { size_t mem_d_model_; size_t tensor_para_size_; size_t tensor_para_rank_; - bool is_maintain_buffer = false; + bool is_maintain_buffer = false; int real_weights_num_; - const static int weights_num_ = 22; - T* weights_ptr[weights_num_]; - size_t weights_size[weights_num_]; + const static int weights_num_ = 24; + T* weights_ptr[weights_num_]; + size_t weights_size[weights_num_]; }; } // namespace fastertransformer diff --git a/src/fastertransformer/models/t5/T5Decoding.cc b/src/fastertransformer/models/t5/T5Decoding.cc index c1ef9e721..dfd6248c1 100644 --- a/src/fastertransformer/models/t5/T5Decoding.cc +++ b/src/fastertransformer/models/t5/T5Decoding.cc @@ -31,6 +31,7 @@ void T5Decoding::initialize() inter_size_, d_model_, num_layer_, + layernorm_eps_, stream_, cublas_wrapper_, allocator_, @@ -42,14 +43,14 @@ void T5Decoding::initialize() custom_all_reduce_comm_, enable_custom_all_reduce_); - dynamic_decode_layer_ = new DynamicDecodeLayer(vocab_size_, - vocab_size_padded_, - 0, // end_id, deprecated - stream_, - cublas_wrapper_, - allocator_, - is_free_buffer_after_forward_, - cuda_device_prop_); + dynamic_decode_layer_ = new DynamicDecodeLayer(vocab_size_, + vocab_size_padded_, + 0, // end_id, deprecated + stream_, + cublas_wrapper_, + allocator_, + is_free_buffer_after_forward_, + cuda_device_prop_); } template @@ -68,7 +69,7 @@ void T5Decoding::allocateBuffer( // use max_seq_len + 
1, but not max_seq_len. // This only affects the buffer size, not affect the performance. - const size_t batchxbeam = batch_size * beam_width; + const size_t batchxbeam = batch_size * beam_width; const size_t self_cache_size = (num_layer_ / pipeline_para_.world_size_) * batchxbeam * (max_seq_len + 1) * (hidden_units_ / tensor_para_.world_size_); const size_t mem_cache_size = (num_layer_ / pipeline_para_.world_size_) * batchxbeam * max_mem_seq_len @@ -86,19 +87,21 @@ void T5Decoding::allocateBuffer( relative_attention_bias_ = (T*)(allocator_->reMalloc( relative_attention_bias_, sizeof(T) * head_num_ * (max_seq_len + 1) * (max_seq_len + 1), false)); - decoder_input_buf_ = (T*)(allocator_->reMalloc(decoder_input_buf_, sizeof(T) * batchxbeam * d_model_, false)); + decoder_input_buf_ = (T*)(allocator_->reMalloc(decoder_input_buf_, sizeof(T) * batchxbeam * d_model_, false)); decoder_output_buf_ = (T*)(allocator_->reMalloc(decoder_output_buf_, sizeof(T) * batchxbeam * d_model_, false)); normed_decoder_output_buf_ = (T*)(allocator_->reMalloc(normed_decoder_output_buf_, sizeof(T) * batchxbeam * d_model_, false)); - logits_buf_ = (T*)(allocator_->reMalloc(logits_buf_, sizeof(T) * batchxbeam * vocab_size_padded_, false)); - nccl_logits_buf_ = (T*)(allocator_->reMalloc(nccl_logits_buf_, sizeof(T) * batchxbeam * vocab_size_padded_, false)); - cum_log_probs_ = (float*)(allocator_->reMalloc(cum_log_probs_, sizeof(float) * batchxbeam, false)); - finished_buf_ = (bool*)(allocator_->reMalloc(finished_buf_, sizeof(bool) * batchxbeam, false)); - h_finished_buf_ = (bool*)realloc(h_finished_buf_, sizeof(bool) * batchxbeam); + logits_buf_ = (DynamicDecodeType*)(allocator_->reMalloc( + logits_buf_, sizeof(DynamicDecodeType) * batchxbeam * vocab_size_padded_, false)); + nccl_logits_buf_ = (DynamicDecodeType*)(allocator_->reMalloc( + nccl_logits_buf_, sizeof(DynamicDecodeType) * batchxbeam * vocab_size_padded_, false)); + cum_log_probs_ = (float*)(allocator_->reMalloc(cum_log_probs_, sizeof(float) * batchxbeam, false)); + finished_buf_ = (bool*)(allocator_->reMalloc(finished_buf_, sizeof(bool) * batchxbeam, false)); + h_finished_buf_ = (bool*)realloc(h_finished_buf_, sizeof(bool) * batchxbeam); key_cache_ = (T*)(allocator_->reMalloc(key_cache_, sizeof(T) * (2 * self_cache_size + 2 * mem_cache_size), false)); - value_cache_ = key_cache_ + self_cache_size; - key_mem_cache_ = value_cache_ + self_cache_size; + value_cache_ = key_cache_ + self_cache_size; + key_mem_cache_ = value_cache_ + self_cache_size; value_mem_cache_ = key_mem_cache_ + mem_cache_size; if (beam_width > 1) { cache_indirections_[0] = (int*)(allocator_->reMalloc( @@ -111,7 +114,7 @@ void T5Decoding::allocateBuffer( (int*)(allocator_->reMalloc(tiled_encoder_sequence_length_, sizeof(int) * batchxbeam, false)); start_ids_buf_ = (int*)(allocator_->reMalloc(start_ids_buf_, sizeof(int) * batch_size, false)); - end_ids_buf_ = (int*)(allocator_->reMalloc(end_ids_buf_, sizeof(int) * batch_size, false)); + end_ids_buf_ = (int*)(allocator_->reMalloc(end_ids_buf_, sizeof(int) * batch_size, false)); output_ids_buf_ = (int*)(allocator_->reMalloc(output_ids_buf_, sizeof(int) * batchxbeam * (max_seq_len + 1), false)); @@ -127,38 +130,42 @@ void T5Decoding::allocateBuffer( template void T5Decoding::freeBuffer() { + FT_LOG_DEBUG(__PRETTY_FUNCTION__); if (is_allocate_buffer_) { if (vocab_size_ != vocab_size_padded_) { padded_embedding_kernel_ptr_ = nullptr; - allocator_->free(padded_embedding_kernel_); + allocator_->free((void**)(&padded_embedding_kernel_)); 
padded_post_decoder_embedding_bias_ptr_ = nullptr; - allocator_->free(padded_post_decoder_embedding_bias_); + allocator_->free((void**)(&padded_post_decoder_embedding_bias_)); } - allocator_->free(relative_attention_bias_); + allocator_->free((void**)(&relative_attention_bias_)); - allocator_->free(decoder_input_buf_); - allocator_->free(decoder_output_buf_); - allocator_->free(normed_decoder_output_buf_); - allocator_->free(logits_buf_); - allocator_->free(nccl_logits_buf_); - allocator_->free(cum_log_probs_); - allocator_->free(finished_buf_); + allocator_->free((void**)(&decoder_input_buf_)); + allocator_->free((void**)(&decoder_output_buf_)); + allocator_->free((void**)(&normed_decoder_output_buf_)); + allocator_->free((void**)(&logits_buf_)); + allocator_->free((void**)(&nccl_logits_buf_)); + allocator_->free((void**)(&cum_log_probs_)); + allocator_->free((void**)(&finished_buf_)); free(h_finished_buf_); - allocator_->free(key_cache_); + allocator_->free((void**)(&key_cache_)); if (cache_indirections_[0] != nullptr) { - allocator_->free(cache_indirections_[0]); + allocator_->free((void**)(&cache_indirections_)[0]); } - allocator_->free(start_ids_buf_); - allocator_->free(end_ids_buf_); + allocator_->free((void**)(&tiled_encoder_output_)); + allocator_->free((void**)(&tiled_encoder_sequence_length_)); + + allocator_->free((void**)(&start_ids_buf_)); + allocator_->free((void**)(&end_ids_buf_)); - allocator_->free(output_ids_buf_); - allocator_->free(parent_ids_buf_); - allocator_->free(output_ids_transpose_buf_); - allocator_->free(output_log_probs_buf_); + allocator_->free((void**)(&output_ids_buf_)); + allocator_->free((void**)(&parent_ids_buf_)); + allocator_->free((void**)(&output_ids_transpose_buf_)); + allocator_->free((void**)(&output_log_probs_buf_)); is_allocate_buffer_ = false; } } @@ -172,37 +179,38 @@ void T5Decoding::setStream(cudaStream_t stream) } template -T5Decoding::T5Decoding(size_t max_batch_size, - size_t max_seq_len, - size_t mem_max_seq_len, - size_t beam_width, - size_t head_num, - size_t size_per_head, - size_t inter_size, - size_t d_model, - size_t num_layer, - size_t vocab_size, - size_t num_bucket, - size_t max_distance, - float q_scaling, - int start_id, - int end_id, - float beam_search_diversity_rate, - size_t top_k, - float top_p, - float temperature, - float len_penalty, - float repetition_penalty, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - cudaDeviceProp* cuda_device_prop, - NcclParam tensor_para, - NcclParam pipeline_para, - ActivationType activation_type, +T5Decoding::T5Decoding(size_t max_batch_size, + size_t max_seq_len, + size_t mem_max_seq_len, + size_t beam_width, + size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t d_model, + size_t num_layer, + size_t vocab_size, + size_t num_bucket, + size_t max_distance, + float q_scaling, + int start_id, + int end_id, + float beam_search_diversity_rate, + size_t top_k, + float top_p, + float temperature, + float len_penalty, + float repetition_penalty, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + cudaDeviceProp* cuda_device_prop, + NcclParam tensor_para, + NcclParam pipeline_para, + ActivationType activation_type, + bool tie_word_embeddings, std::shared_ptr custom_all_reduce_comm, - int enable_custom_all_reduce): + int enable_custom_all_reduce): BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, 
cuda_device_prop), head_num_(head_num), size_per_head_(size_per_head), @@ -225,6 +233,7 @@ T5Decoding::T5Decoding(size_t max_batch_size, tensor_para_(tensor_para), pipeline_para_(pipeline_para), activation_type_(activation_type), + tie_word_embeddings_(tie_word_embeddings), custom_all_reduce_comm_(custom_all_reduce_comm), enable_custom_all_reduce_(enable_custom_all_reduce) { @@ -261,6 +270,7 @@ T5Decoding::T5Decoding(T5Decoding const& decoding): tensor_para_(decoding.tensor_para_), pipeline_para_(decoding.pipeline_para_), activation_type_(decoding.activation_type_), + tie_word_embeddings_(decoding.tie_word_embeddings_), custom_all_reduce_comm_(decoding.custom_all_reduce_comm_), enable_custom_all_reduce_(decoding.enable_custom_all_reduce_) { @@ -276,7 +286,7 @@ T5Decoding::~T5Decoding() } template -void T5Decoding::forward(std::vector* output_tensors, +void T5Decoding::forward(std::vector* output_tensors, const std::vector* input_tensors, const T5DecodingWeight* decoding_weights) { @@ -301,87 +311,32 @@ void T5Decoding::forward(std::vector* output_tensors, } template -bool T5Decoding::hasDiffRuntimeArgs(const std::unordered_map* input_tensors) -{ - // runtime_top_k [1] or [batch_size] on cpu, optional. - // runtime_top_p [1] or [batch_size] on cpu, optional - // beam_search_diversity_rate [1] or [batch_size] on cpu, optional - // temperature [1] or [batch_size] on cpu, optional - // len_penalty [1] or [batch_size] on cpu, optional - // repetition_penalty [1] or [batch_size] on cpu, optional - // random_seed [1] or [batch_size] on cpu, optional - - std::vector check_list = {"runtime_top_k", - "runtime_top_p", - "beam_search_diversity_rate", - "temperature", - "len_penalty", - "repetition_penalty", - "random_seed"}; - - for (int i = 0; i < check_list.size(); i++) { - if (input_tensors->count(check_list[i])) { - auto tensor = input_tensors->at(check_list[i]); - if (tensor.shape.size() > 1) { - FT_CHECK(tensor.shape[1] == 1); - for (int i = 1; i < tensor.shape[0]; i++) { - const void* data = tensor.data; - switch (tensor.type) { - case TYPE_FP32: - if (((const float*)data)[0] != ((const float*)data)[i]) { - return true; - } - break; - case TYPE_INT32: - if (((const int*)data)[0] != ((const int*)data)[i]) { - return true; - } - break; - case TYPE_UINT32: - if (((const uint*)data)[0] != ((const uint*)data)[i]) { - return true; - } - break; - case TYPE_UINT64: - if (((const unsigned long long int*)data)[0] != ((const unsigned long long int*)data)[i]) { - return true; - } - break; - default: - FT_CHECK(false); - break; - } - } - } - } - } - return false; -} - -template -void T5Decoding::forward(std::unordered_map* output_tensors, +void T5Decoding::forward(std::unordered_map* output_tensors, const std::unordered_map* input_tensors, - const T5DecodingWeight* decoding_weights) + const T5DecodingWeight* decoding_weights) { // input_tensors: // encoder_output [batch_size, mem_max_seq_len, memory_hidden_dimension] // encoder_sequence_length [batch_size] // stop_words_list [batch_size, 2, stop_words_length], optional + // bad_words_list [batch_size, 2, stop_words_length], optional // start_id [batch_size] on cpu, optional // end_id [batch_size] on cpu, optional - // runtime_top_k [1] or [batch_size] on cpu, optional. 
- // runtime_top_p [1] or [batch_size] on cpu, optional - // beam_search_diversity_rate [1] or [batch_size] on cpu, optional - // temperature [1] or [batch_size] on cpu, optional - // len_penalty [1] or [batch_size] on cpu, optional - // repetition_penalty [1] or [batch_size] on cpu, optional - // random_seed [1] or [batch_size] on cpu, optional + // runtime_top_k [1] or [batch_size] on cpu, optional, uint. + // runtime_top_p [1] or [batch_size] on cpu, optional, float. + // beam_search_diversity_rate [1] or [batch_size] on cpu, optional, float. + // temperature [1] or [batch_size] on cpu, optional, float. + // len_penalty [1] or [batch_size] on cpu, optional, float. + // repetition_penalty [1] or [batch_size] on cpu, optional, float. + // random_seed [1] or [batch_size] on cpu, optional, unsigned long long int. // output_tensors: // output_ids [batch_size, beam, max_seq_len] // sequence_length [batch_size, beam], record the number of generated token, except the start token // output_log_probs [batch_size, beam, max_seq_len], optional, must be float*. // cum_log_probs [batch_size, beam], optional, must be float*. + // cross_attentions [num_layer / pipeline_para_size, batch_size, beam, + // head_num / tensor_para_size, max_seq_len, mem_max_seq_len], optional, must be float*. // Step is from 1 ~ max_seq_len, // When step = k, we put output ids and caches at step k, and the sequence_length would be k - 1 before @@ -391,22 +346,25 @@ void T5Decoding::forward(std::unordered_map* output_tens FT_CHECK(input_tensors->size() >= 2); FT_CHECK(output_tensors->size() >= 2); FT_CHECK(input_tensors->at("encoder_output").shape.size() == 3); - const size_t batch_size = output_tensors->at("output_ids").shape[0]; - const size_t beam_width = output_tensors->at("output_ids").shape[1]; - const size_t max_seq_len = output_tensors->at("output_ids").shape[2]; + const size_t batch_size = output_tensors->at("output_ids").shape[0]; + const size_t beam_width = output_tensors->at("output_ids").shape[1]; + const size_t max_seq_len = output_tensors->at("output_ids").shape[2]; const size_t mem_max_seq_len = input_tensors->at("encoder_output").shape[1]; allocateBuffer(batch_size, beam_width, max_seq_len, mem_max_seq_len, input_tensors->at("encoder_output").shape[2]); - bool has_diff_runtime_args = hasDiffRuntimeArgs(input_tensors); - + dynamic_decode_layer_->setup(batch_size, beam_width, input_tensors); handleOptArg(input_tensors, "start_id", start_ids_buf_, start_id_, batch_size); handleOptArg(input_tensors, "end_id", end_ids_buf_, end_id_, batch_size); - FT_CHECK(input_tensors->at("encoder_output").shape[2] == d_model_); + FT_CHECK_WITH_INFO(input_tensors->at("encoder_output").shape[2] == d_model_, + fmtstr("expect input_tensors->at(\"encoder_output\").shape[2] == d_model_, " + "but get input_tensors->at(\"encoder_output\").shape[2] = %d, d_model_ = %d", + input_tensors->at("encoder_output").shape[2], + d_model_)); - const int max_input_length = 1; - const DataType data_type = getTensorType(); - int* sequence_lengths = (int*)output_tensors->at("sequence_length").data; + const int max_input_length = 1; + const DataType data_type = getTensorType(); + int* sequence_lengths = (int*)output_tensors->at("sequence_length").data; cudaMemset((int*)output_tensors->at("output_ids").data, 0, sizeof(int) * batch_size * beam_width * max_seq_len); if (beam_width > 1) { @@ -425,11 +383,11 @@ void T5Decoding::forward(std::unordered_map* output_tens d_model_, stream_); sync_check_cuda_error(); - encoder_output_ptr_ = tiled_encoder_output_; 
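
// With hasDiffRuntimeArgs removed, T5Decoding::forward() simply forwards the optional
// sampling arguments (runtime_top_k, runtime_top_p, temperature, len_penalty,
// repetition_penalty, beam_search_diversity_rate, random_seed) to
// dynamic_decode_layer_->setup(); each is a CPU tensor of shape [1] or [batch_size].
// The stand-in below only illustrates that shape contract from the caller's side and does
// not use the project's real Tensor type.
#include <cstddef>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

struct CpuArg {
    std::vector<size_t> shape;  // {1} for a shared value, {batch_size} for per-request values
    const void*         data;   // host pointer owned by the caller
};

inline std::unordered_map<std::string, CpuArg>
makeSamplingArgs(const uint32_t* top_k, const float* top_p, const float* temperature,
                 size_t batch_size, bool per_request)
{
    const std::vector<size_t> shape =
        per_request ? std::vector<size_t>{batch_size} : std::vector<size_t>{1};
    return {
        {"runtime_top_k", {shape, top_k}},
        {"runtime_top_p", {shape, top_p}},
        {"temperature", {shape, temperature}},
        // len_penalty, repetition_penalty, beam_search_diversity_rate and random_seed
        // follow the same [1]-or-[batch_size] convention.
    };
}
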
+ encoder_output_ptr_ = tiled_encoder_output_; encoder_sequence_length_ptr_ = tiled_encoder_sequence_length_; } else { - encoder_output_ptr_ = (const T*)(input_tensors->at("encoder_output").data); + encoder_output_ptr_ = (const T*)(input_tensors->at("encoder_output").data); encoder_sequence_length_ptr_ = (const int*)(input_tensors->at("encoder_sequence_length").data); } @@ -456,7 +414,7 @@ void T5Decoding::forward(std::unordered_map* output_tens sync_check_cuda_error(); if (vocab_size_ == vocab_size_padded_) { - padded_embedding_kernel_ptr_ = decoding_weights->post_decoder_embedding.kernel; + padded_embedding_kernel_ptr_ = decoding_weights->post_decoder_embedding.kernel; padded_post_decoder_embedding_bias_ptr_ = decoding_weights->post_decoder_embedding.bias; } else { @@ -482,10 +440,10 @@ void T5Decoding::forward(std::unordered_map* output_tens head_num_ / tensor_para_.world_size_, (size_t)(max_seq_len + 1), size_per_head_}; - const std::vector mem_cache_shape = {num_layer_ / pipeline_para_.world_size_, - batch_size * beam_width, - mem_max_seq_len, - head_num_ / tensor_para_.world_size_ * size_per_head_}; + const std::vector mem_cache_shape = {num_layer_ / pipeline_para_.world_size_, + batch_size * beam_width, + mem_max_seq_len, + head_num_ / tensor_para_.world_size_ * size_per_head_}; const size_t local_batch_size = getLocalBatchSize(batch_size, 1, pipeline_para_.world_size_); FT_CHECK(batch_size % local_batch_size == 0); @@ -495,27 +453,26 @@ void T5Decoding::forward(std::unordered_map* output_tens const int tgt_indir_idx = 1 - src_indir_idx; for (uint ite = 0; ite < iteration_num; ++ite) { - const int id_offset = ite * local_batch_size * beam_width; - const int d_model_offset = id_offset * d_model_; + const int id_offset = ite * local_batch_size * beam_width; + const int d_model_offset = id_offset * d_model_; const int vocab_size_units_offset = id_offset * vocab_size_padded_; if (pipeline_para_.rank_ == 0) { - invokeEmbeddingLookupPosEncoding(decoder_input_buf_ + d_model_offset, - decoding_weights->pre_decoder_embedding_table, - decoding_weights->position_embedding_type - == PositionEmbeddingType::relative ? - (T*)nullptr : - decoding_weights->absolute_or_relative_position_embedding, - output_ids_buf_ + id_offset, - nullptr, - local_batch_size * beam_width, - d_model_, - (T)1.0f, - step - 1, - 0, - batch_size * beam_width, - 0, - stream_); + invokeEmbeddingLookupPosEncodingPadCount(decoder_input_buf_ + d_model_offset, + decoding_weights->pre_decoder_embedding_table, + decoding_weights->position_embedding_type + == PositionEmbeddingType::relative ? 
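// [Annotation] Offset bookkeeping used by the generation loop in this hunk: the request
// batch is processed in iteration_num = batch_size / local_batch_size micro-batches, and
// micro-batch `ite` addresses the flat buffers with
//
//   id_offset               = ite * local_batch_size * beam_width;    // sequences
//   d_model_offset          = id_offset * d_model_;                   // hidden states
//   vocab_size_units_offset = id_offset * vocab_size_padded_;         // logits
//
// For example, with batch_size = 8, local_batch_size = 4, beam_width = 2 and d_model_ = 512,
// iteration ite = 1 starts at id_offset = 8 and d_model_offset = 4096.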
+ (T*)nullptr : + decoding_weights->absolute_or_relative_position_embedding, + output_ids_buf_ + id_offset, + nullptr, + local_batch_size * beam_width, + d_model_, + (T)1.0f, + step - 1, + batch_size * beam_width, + 0, + stream_); sync_check_cuda_error(); } @@ -558,60 +515,136 @@ void T5Decoding::forward(std::unordered_map* output_tens Tensor{MEMORY_GPU, data_type, self_v_cache_shape, value_cache_}, Tensor{MEMORY_GPU, data_type, mem_cache_shape, key_mem_cache_}, Tensor{MEMORY_GPU, data_type, mem_cache_shape, value_mem_cache_}}; + + if (output_tensors->count("cross_attentions")) { + decoder_output_tensors.push_back(Tensor{ + MEMORY_GPU, + TYPE_FP32, + output_tensors->at("cross_attentions").shape, + output_tensors->at("cross_attentions").data, + {(size_t)id_offset, + batch_size * beam_width * head_num_ / tensor_para_.world_size_ * max_seq_len * mem_max_seq_len}}); + } + decoder_->forward( &decoder_output_tensors, &decoder_input_tensors, &decoding_weights->decoder_layer_weights); bool t5_with_bias = decoding_weights->t5_with_bias; + const cudaDataType_t gemm_data_type = getCudaDataType(); + if (pipeline_para_.rank_ == pipeline_para_.world_size_ - 1) { invokeGeneralT5LayerNorm(normed_decoder_output_buf_ + d_model_offset, decoder_output_buf_ + d_model_offset, decoding_weights->post_decoder_layernorm.gamma, decoding_weights->post_decoder_layernorm.beta, + layernorm_eps_, local_batch_size * beam_width, d_model_, stream_); sync_check_cuda_error(); + DataType logits_data_type = data_type; + + // bf16 logits computation fallback to fp32 if (tensor_para_.world_size_ == 1) { - cublas_wrapper_->Gemm(CUBLAS_OP_T, - CUBLAS_OP_N, - vocab_size_padded_, // n - local_batch_size * beam_width, - d_model_, // k - padded_embedding_kernel_ptr_, - d_model_, // k - normed_decoder_output_buf_ + d_model_offset, - d_model_, // k - logits_buf_ + vocab_size_units_offset, - vocab_size_padded_ /* n */, - t5_with_bias ? 1.0f : 1.0f / sqrt(d_model_), - 0.0f); +#ifdef ENABLE_BF16 + if (std::is_same::value) { + logits_data_type = TYPE_FP32; +#else + if (false) { +#endif + float alpha = t5_with_bias ? 1.0f : (tie_word_embeddings_ ? 1.0f / sqrt(d_model_) : 1.0f); + float beta = 0.0f; + cublas_wrapper_->Gemm(CUBLAS_OP_T, + CUBLAS_OP_N, + vocab_size_padded_, // n + local_batch_size * beam_width, + d_model_, // k + &alpha, + padded_embedding_kernel_ptr_, + gemm_data_type, + d_model_, // k + normed_decoder_output_buf_ + d_model_offset, + gemm_data_type, + d_model_, // k + &beta, + logits_buf_ + vocab_size_units_offset, + CUDA_R_32F, + vocab_size_padded_, /* n */ + CUDA_R_32F, + cublasGemmAlgo_t(-1)); + } + else { + cublas_wrapper_->Gemm(CUBLAS_OP_T, + CUBLAS_OP_N, + vocab_size_padded_, // n + local_batch_size * beam_width, + d_model_, // k + padded_embedding_kernel_ptr_, + d_model_, // k + normed_decoder_output_buf_ + d_model_offset, + d_model_, // k + logits_buf_ + vocab_size_units_offset, + vocab_size_padded_ /* n */, + t5_with_bias ? 1.0f : 1.0f / sqrt(d_model_), + 0.0f); + } } else { const int local_vocab_size = vocab_size_padded_ / tensor_para_.world_size_; - cublas_wrapper_->Gemm(CUBLAS_OP_T, - CUBLAS_OP_N, - local_vocab_size, // n - local_batch_size * beam_width, - d_model_, // k - padded_embedding_kernel_ptr_ - + tensor_para_.rank_ * local_vocab_size * d_model_, - d_model_, // k - normed_decoder_output_buf_ + d_model_offset, - d_model_, // k - nccl_logits_buf_ + vocab_size_units_offset - + tensor_para_.rank_ * local_batch_size * beam_width * local_vocab_size, - local_vocab_size /* n */, - t5_with_bias ? 
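// [Annotation] Sketch of the logits scaling chosen in the GEMM calls above. The alpha
// expression added by this patch is
//   t5_with_bias ? 1.0f : (tie_word_embeddings_ ? 1.0f / sqrt(d_model_) : 1.0f)
// i.e. only bias-free models with tied word embeddings rescale the output projection by
// 1/sqrt(d_model), presumably to match the original T5 checkpoints; models with bias or
// untied embeddings keep alpha = 1. Standalone equivalent of that selection:
#include <cmath>
#include <cstddef>
static inline float t5_logits_alpha(bool t5_with_bias, bool tie_word_embeddings, std::size_t d_model)
{
    return t5_with_bias ? 1.0f : (tie_word_embeddings ? 1.0f / std::sqrt(static_cast<float>(d_model)) : 1.0f);
}
// For bf16 builds the same hunk also writes the logits in FP32 (CUDA_R_32F), since the
// dynamic decode layer falls back to an fp32 instantiation.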
1.0f : 1.0f / sqrt(d_model_), - 0.0f); +#ifdef ENABLE_BF16 + if (std::is_same::value) { + logits_data_type = TYPE_FP32; +#else + if (false) { +#endif + float alpha = t5_with_bias ? 1.0f : (tie_word_embeddings_ ? 1.0f / sqrt(d_model_) : 1.0f); + float beta = 0.0f; + cublas_wrapper_->Gemm( + CUBLAS_OP_T, + CUBLAS_OP_N, + local_vocab_size, // n + local_batch_size * beam_width, + d_model_, // k + &alpha, + padded_embedding_kernel_ptr_ + tensor_para_.rank_ * local_vocab_size * d_model_, + gemm_data_type, + d_model_, // k + normed_decoder_output_buf_ + d_model_offset, + gemm_data_type, + d_model_, // k + &beta, + nccl_logits_buf_ + vocab_size_units_offset + + tensor_para_.rank_ * local_batch_size * beam_width * local_vocab_size, + CUDA_R_32F, + local_vocab_size, /* n */ + CUDA_R_32F, + cublasGemmAlgo_t(-1)); + } + else { + cublas_wrapper_->Gemm( + CUBLAS_OP_T, + CUBLAS_OP_N, + local_vocab_size, // n + local_batch_size * beam_width, + d_model_, // k + padded_embedding_kernel_ptr_ + tensor_para_.rank_ * local_vocab_size * d_model_, + d_model_, // k + normed_decoder_output_buf_ + d_model_offset, + d_model_, // k + nccl_logits_buf_ + vocab_size_units_offset + + tensor_para_.rank_ * local_batch_size * beam_width * local_vocab_size, + local_vocab_size /* n */, + t5_with_bias ? 1.0f : 1.0f / sqrt(d_model_), + 0.0f); + } ftNcclAllGather(nccl_logits_buf_ + vocab_size_units_offset, nccl_logits_buf_ + vocab_size_units_offset, local_batch_size * beam_width * local_vocab_size, tensor_para_.rank_, tensor_para_, stream_); - check_cuda_error(cudaStreamSynchronize(stream_)); invokeTransposeAxis01(logits_buf_ + vocab_size_units_offset, nccl_logits_buf_ + vocab_size_units_offset, tensor_para_.world_size_, @@ -628,18 +661,17 @@ void T5Decoding::forward(std::unordered_map* output_tens stream_); } - int tmp_local_batch_size = local_batch_size; - bool is_initialize_random_table = step == 1; + int tmp_local_batch_size = local_batch_size; + bool is_initialize_random_table = step == 1; std::unordered_map dynamic_decode_input_tensors{ {"logits", - Tensor{MEMORY_GPU, data_type, {batch_size, beam_width, vocab_size_padded_}, logits_buf_}}, + Tensor{MEMORY_GPU, logits_data_type, {batch_size, beam_width, vocab_size_padded_}, logits_buf_}}, {"embedding_bias", Tensor{MEMORY_GPU, data_type, {vocab_size_padded_}, nullptr}}, {"step", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &step}}, {"max_input_length", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &max_input_length}}, {"input_lengths", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size, beam_width}, nullptr}}, {"ite", Tensor{MEMORY_CPU, TYPE_UINT32, {1}, &ite}}, {"end_id", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size}, end_ids_buf_}}, - {"has_diff_runtime_args", Tensor{MEMORY_CPU, TYPE_BOOL, {1}, &has_diff_runtime_args}}, {"src_key_cache", Tensor{MEMORY_GPU, data_type, self_k_cache_shape, key_cache_}}, {"src_value_cache", Tensor{MEMORY_GPU, data_type, self_v_cache_shape, value_cache_}}, {"src_cache_indirection", @@ -698,7 +730,7 @@ void T5Decoding::forward(std::unordered_map* output_tens } if (pipeline_para_.world_size_ > 1) { - NCCLCHECK(ncclGroupStart()); + ftNcclGroupStart(); ftNcclBroadCast(output_ids_buf_ + step * batch_size * beam_width, batch_size * beam_width, pipeline_para_.world_size_ - 1, @@ -718,8 +750,9 @@ void T5Decoding::forward(std::unordered_map* output_tens pipeline_para_, stream_); } - NCCLCHECK(ncclGroupEnd()); - check_cuda_error(cudaStreamSynchronize(stream_)); + ftNcclGroupEnd(); + // throw errors when detected + ftNcclStreamSynchronize(tensor_para_, pipeline_para_, stream_); 
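// [Annotation] Why invokeTransposeAxis01 follows ftNcclAllGather in the tensor-parallel
// branch above: each rank produces a [local_batch * beam, local_vocab] slice, and the
// all-gather concatenates the rank slices along the leading axis, giving
// [world_size, local_batch * beam, local_vocab]. Swapping the first two axes yields
// [local_batch * beam, world_size * local_vocab], i.e. one contiguous vocab_size_padded
// row per sequence. CPU reference of that permutation (standalone, illustration only):
#include <cstddef>
#include <vector>
static std::vector<float>
transpose_axis01(const std::vector<float>& in, std::size_t dim0, std::size_t dim1, std::size_t dim2)
{
    // [dim0, dim1, dim2] -> [dim1, dim0, dim2]
    std::vector<float> out(in.size());
    for (std::size_t i = 0; i < dim0; ++i)          // world_size
        for (std::size_t j = 0; j < dim1; ++j)      // local_batch * beam
            for (std::size_t k = 0; k < dim2; ++k)  // local_vocab
                out[(j * dim0 + i) * dim2 + k] = in[(i * dim1 + j) * dim2 + k];
    return out;
}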
sync_check_cuda_error(); } @@ -782,7 +815,7 @@ void T5Decoding::forward(std::unordered_map* output_tens } if (pipeline_para_.world_size_ > 1) { - NCCLCHECK(ncclGroupStart()); + ftNcclGroupStart(); if (pipeline_para_.rank_ == pipeline_para_.world_size_ - 1) { ftNcclSend(output_tensors->at("output_ids").getPtr(), batch_size * beam_width * max_seq_len, @@ -843,10 +876,12 @@ void T5Decoding::forward(std::unordered_map* output_tens stream_); } } - NCCLCHECK(ncclGroupEnd()); - check_cuda_error(cudaStreamSynchronize(stream_)); + ftNcclGroupEnd(); } + // throw errors when detected + ftNcclStreamSynchronize(tensor_para_, pipeline_para_, stream_); + if (is_free_buffer_after_forward_) { freeBuffer(); } @@ -854,5 +889,8 @@ void T5Decoding::forward(std::unordered_map* output_tens template class T5Decoding; template class T5Decoding; +#ifdef ENABLE_BF16 +template class T5Decoding<__nv_bfloat16>; +#endif } // namespace fastertransformer diff --git a/src/fastertransformer/models/t5/T5Decoding.h b/src/fastertransformer/models/t5/T5Decoding.h index 99ebc22b8..0cf629354 100644 --- a/src/fastertransformer/models/t5/T5Decoding.h +++ b/src/fastertransformer/models/t5/T5Decoding.h @@ -27,38 +27,53 @@ namespace fastertransformer { +// fallback to fp32 dynamic decoder when bf16 specified +template +struct fallBackType { + using Type = float; +}; + +template<> +struct fallBackType { + using Type = half; +}; + template class T5Decoding: public BaseLayer { private: // meta data - const size_t head_num_; - const size_t size_per_head_; - const size_t inter_size_; - const size_t d_model_; - const size_t num_layer_; - const size_t vocab_size_; - const size_t num_bucket_; - const size_t max_distance_; + const size_t head_num_; + const size_t size_per_head_; + const size_t inter_size_; + const size_t d_model_; + const size_t num_layer_; + const size_t vocab_size_; + const size_t num_bucket_; + const size_t max_distance_; const ActivationType activation_type_; - float q_scaling_; + float q_scaling_; + const bool tie_word_embeddings_; const int start_id_; const int end_id_; + constexpr static float layernorm_eps_ = 1e-6f; + // TODO(bhsueh) remove - const float beam_search_diversity_rate_; + const float beam_search_diversity_rate_; const size_t hidden_units_; const size_t top_k_; - const float top_p_; - const float temperature_; - const float len_penalty_; - const float repetition_penalty_; + const float top_p_; + const float temperature_; + const float len_penalty_; + const float repetition_penalty_; // calculated data size_t vocab_size_padded_; T5Decoder* decoder_; - DynamicDecodeLayer* dynamic_decode_layer_; + using DynamicDecodeType = typename fallBackType::Type; + DynamicDecodeLayer* dynamic_decode_layer_; void allocateBuffer() override; void freeBuffer() override; @@ -66,94 +81,94 @@ class T5Decoding: public BaseLayer { size_t batch_size, size_t beam_width, size_t max_seq_len, size_t max_mem_seq_len, size_t encoder_d_model); void initialize(); - bool hasDiffRuntimeArgs(const std::unordered_map* input_tensors); NcclParam tensor_para_; NcclParam pipeline_para_; std::shared_ptr custom_all_reduce_comm_; - int enable_custom_all_reduce_; + int enable_custom_all_reduce_; protected: - T* padded_embedding_kernel_ = nullptr; - const T* padded_embedding_kernel_ptr_ = nullptr; - T* padded_post_decoder_embedding_bias_ = nullptr; + T* padded_embedding_kernel_ = nullptr; + const T* padded_embedding_kernel_ptr_ = nullptr; + T* padded_post_decoder_embedding_bias_ = nullptr; const T* padded_post_decoder_embedding_bias_ptr_ = nullptr; - 
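// [Annotation] The fallBackType trait added to this header, written out with explicit
// template arguments for clarity: bf16 has no DynamicDecodeLayer instantiation, so a bf16
// T5Decoding uses an fp32 dynamic decoder and fp32 logits buffers, while float and half
// keep their own type.
//
//   template<typename T>
//   struct fallBackType {
//       using Type = float;
//   };
//
//   template<>
//   struct fallBackType<half> {
//       using Type = half;
//   };
//
//   // fallBackType<float>::Type         == float
//   // fallBackType<half>::Type          == half
//   // fallBackType<__nv_bfloat16>::Type == float   (the fallback case)
//
// DynamicDecodeType, and hence logits_buf_ / nccl_logits_buf_ below, is then
// typename fallBackType<T>::Type.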
T* relative_attention_bias_ = nullptr; + T* relative_attention_bias_ = nullptr; - T* decoder_input_buf_ = nullptr; - T* decoder_output_buf_ = nullptr; - T* normed_decoder_output_buf_ = nullptr; - T* logits_buf_ = nullptr; - T* nccl_logits_buf_ = nullptr; - float* cum_log_probs_ = nullptr; - bool* finished_buf_ = nullptr; - bool* h_finished_buf_ = nullptr; + T* decoder_input_buf_ = nullptr; + T* decoder_output_buf_ = nullptr; + T* normed_decoder_output_buf_ = nullptr; + DynamicDecodeType* logits_buf_ = nullptr; + DynamicDecodeType* nccl_logits_buf_ = nullptr; + float* cum_log_probs_ = nullptr; + bool* finished_buf_ = nullptr; + bool* h_finished_buf_ = nullptr; int* start_ids_buf_ = nullptr; - int* end_ids_buf_ = nullptr; + int* end_ids_buf_ = nullptr; - T* key_cache_ = nullptr; - T* value_cache_ = nullptr; - T* key_mem_cache_ = nullptr; - T* value_mem_cache_ = nullptr; + T* key_cache_ = nullptr; + T* value_cache_ = nullptr; + T* key_mem_cache_ = nullptr; + T* value_mem_cache_ = nullptr; int* cache_indirections_[2] = {nullptr, nullptr}; - int* output_ids_buf_ = nullptr; - int* parent_ids_buf_ = nullptr; - int* output_ids_transpose_buf_ = nullptr; - float* output_log_probs_buf_ = nullptr; + int* output_ids_buf_ = nullptr; + int* parent_ids_buf_ = nullptr; + int* output_ids_transpose_buf_ = nullptr; + float* output_log_probs_buf_ = nullptr; - T* tiled_encoder_output_ = nullptr; + T* tiled_encoder_output_ = nullptr; int* tiled_encoder_sequence_length_ = nullptr; - const T* encoder_output_ptr_ = nullptr; + const T* encoder_output_ptr_ = nullptr; const int* encoder_sequence_length_ptr_ = nullptr; public: - T5Decoding(size_t max_batch_size, - size_t max_seq_len, - size_t mem_max_seq_len, - size_t beam_width, - size_t head_num, - size_t size_per_head, - size_t inter_size, - size_t d_model, - size_t num_layer, - size_t vocab_size, - size_t num_bucket, - size_t max_distance, - float q_scaling, - int start_id, - int end_id, - float beam_search_diversity_rate, - size_t top_k, - float top_p, - float temperature, - float len_penalty, - float repetition_penalty, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - cudaDeviceProp* cuda_device_prop, - NcclParam tensor_para, - NcclParam pipeline_para, - ActivationType activation_type = ActivationType::Relu, - std::shared_ptr custom_all_reduce_comm = nullptr, - int enable_custom_all_reduce = 0); + T5Decoding(size_t max_batch_size, + size_t max_seq_len, + size_t mem_max_seq_len, + size_t beam_width, + size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t d_model, + size_t num_layer, + size_t vocab_size, + size_t num_bucket, + size_t max_distance, + float q_scaling, + int start_id, + int end_id, + float beam_search_diversity_rate, + size_t top_k, + float top_p, + float temperature, + float len_penalty, + float repetition_penalty, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + cudaDeviceProp* cuda_device_prop, + NcclParam tensor_para, + NcclParam pipeline_para, + ActivationType activation_type = ActivationType::Relu, + bool tie_word_embeddings = true, + std::shared_ptr custom_all_reduce_comm = nullptr, + int enable_custom_all_reduce = 0); T5Decoding(T5Decoding const& T5Decoding); ~T5Decoding(); - void forward(std::vector* output_tensors, + void forward(std::vector* output_tensors, const std::vector* input_tensors, const T5DecodingWeight* Decoding_weights); - void forward(std::unordered_map* 
output_tensors, + void forward(std::unordered_map* output_tensors, const std::unordered_map* input_tensors, - const T5DecodingWeight* Decoding_weights); + const T5DecodingWeight* Decoding_weights); void setStream(cudaStream_t stream) override; }; diff --git a/src/fastertransformer/models/t5/T5DecodingWeight.cc b/src/fastertransformer/models/t5/T5DecodingWeight.cc index 1749583b7..5346a24fd 100644 --- a/src/fastertransformer/models/t5/T5DecodingWeight.cc +++ b/src/fastertransformer/models/t5/T5DecodingWeight.cc @@ -20,19 +20,20 @@ namespace fastertransformer { template -T5DecodingWeight::T5DecodingWeight(const size_t head_num, - const size_t size_per_head, - const size_t d_model, - const size_t inter_size, - const size_t vocab_size, - const size_t num_layer, - const size_t mem_d_model, - const size_t num_bucket_or_max_seq_len, - const size_t tensor_para_size, - const size_t tensor_para_rank, - const size_t pipeline_para_size, - const size_t pipeline_para_rank, - const bool t5_with_bias_para, +T5DecodingWeight::T5DecodingWeight(const size_t head_num, + const size_t size_per_head, + const size_t d_model, + const size_t inter_size, + const size_t vocab_size, + const size_t num_layer, + const size_t mem_d_model, + const size_t num_bucket_or_max_seq_len, + const size_t tensor_para_size, + const size_t tensor_para_rank, + const size_t pipeline_para_size, + const size_t pipeline_para_rank, + const bool t5_with_bias_para, + const bool use_gated_activation_para, const PositionEmbeddingType pe_type): head_num_(head_num), size_per_head_(size_per_head), @@ -47,15 +48,18 @@ T5DecodingWeight::T5DecodingWeight(const size_t head_num, pipeline_para_size_(pipeline_para_size), pipeline_para_rank_(pipeline_para_rank), t5_with_bias(t5_with_bias_para), + use_gated_activation(use_gated_activation_para), position_embedding_type(pe_type), real_weights_num_(t5_with_bias ? 
6 : 4) { FT_LOG_DEBUG("T5DecodingWeight " + std::string(__func__) + " start"); + FT_CHECK(num_layer_ % pipeline_para_size_ == 0); initialize(); mallocWeights(); setWeightPtr(); decoder_layer_weights.clear(); + decoder_layer_weights.reserve(num_layer_); for (int l = 0; l < num_layer_; l++) { if (isValidLayerParallelId(l)) { decoder_layer_weights.push_back(new T5DecoderLayerWeight(head_num_, @@ -65,7 +69,8 @@ T5DecodingWeight::T5DecodingWeight(const size_t head_num, mem_d_model_, tensor_para_size_, tensor_para_rank_, - t5_with_bias)); + t5_with_bias, + use_gated_activation)); } else { decoder_layer_weights.push_back(new T5DecoderLayerWeight()); @@ -102,13 +107,13 @@ T5DecodingWeight::~T5DecodingWeight() deviceFree(weights_ptr[i]); } - pre_decoder_embedding_table = nullptr; + pre_decoder_embedding_table = nullptr; absolute_or_relative_position_embedding = nullptr; - post_decoder_layernorm.gamma = nullptr; - post_decoder_embedding.kernel = nullptr; - post_decoder_embedding.bias = nullptr; - post_decoder_layernorm.beta = nullptr; - is_maintain_buffer = false; + post_decoder_layernorm.gamma = nullptr; + post_decoder_embedding.kernel = nullptr; + post_decoder_embedding.bias = nullptr; + post_decoder_layernorm.beta = nullptr; + is_maintain_buffer = false; } FT_LOG_DEBUG("T5DecodingWeight " + std::string(__func__) + " end"); } @@ -128,6 +133,7 @@ T5DecodingWeight::T5DecodingWeight(const T5DecodingWeight& other): pipeline_para_size_(other.pipeline_para_size_), pipeline_para_rank_(other.pipeline_para_rank_), t5_with_bias(other.t5_with_bias), + use_gated_activation(other.use_gated_activation), position_embedding_type(other.position_embedding_type), real_weights_num_(other.real_weights_num_) { @@ -140,6 +146,7 @@ T5DecodingWeight::T5DecodingWeight(const T5DecodingWeight& other): setWeightPtr(); decoder_layer_weights.clear(); + decoder_layer_weights.reserve(num_layer_); for (int l = 0; l < num_layer_; l++) { decoder_layer_weights.push_back(new T5DecoderLayerWeight(*other.decoder_layer_weights[l])); } @@ -149,21 +156,22 @@ T5DecodingWeight::T5DecodingWeight(const T5DecodingWeight& other): template T5DecodingWeight& T5DecodingWeight::operator=(const T5DecodingWeight& other) { - head_num_ = other.head_num_; - size_per_head_ = other.size_per_head_; - d_model_ = other.d_model_; - inter_size_ = other.inter_size_; - vocab_size_ = other.vocab_size_; - num_layer_ = other.num_layer_; - mem_d_model_ = other.mem_d_model_; + head_num_ = other.head_num_; + size_per_head_ = other.size_per_head_; + d_model_ = other.d_model_; + inter_size_ = other.inter_size_; + vocab_size_ = other.vocab_size_; + num_layer_ = other.num_layer_; + mem_d_model_ = other.mem_d_model_; num_bucket_or_max_seq_len_ = other.num_bucket_or_max_seq_len_; - tensor_para_size_ = other.tensor_para_size_; - tensor_para_rank_ = other.tensor_para_rank_; - pipeline_para_size_ = other.pipeline_para_size_; - pipeline_para_rank_ = other.pipeline_para_rank_; - t5_with_bias = other.t5_with_bias; - position_embedding_type = other.position_embedding_type; - real_weights_num_ = other.real_weights_num_; + tensor_para_size_ = other.tensor_para_size_; + tensor_para_rank_ = other.tensor_para_rank_; + pipeline_para_size_ = other.pipeline_para_size_; + pipeline_para_rank_ = other.pipeline_para_rank_; + t5_with_bias = other.t5_with_bias; + use_gated_activation = other.use_gated_activation; + position_embedding_type = other.position_embedding_type; + real_weights_num_ = other.real_weights_num_; FT_LOG_DEBUG("T5DecodingWeight " + std::string(__func__) + " start"); 
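// [Annotation] The constructor loop above only allocates real T5DecoderLayerWeight objects
// for layers that pass isValidLayerParallelId(l); the remaining slots get empty
// placeholders. The body of that check is not part of this diff; the sketch below shows
// the contiguous partition that the new FT_CHECK(num_layer_ % pipeline_para_size_ == 0)
// assumes (an assumption, following the usual FasterTransformer convention):
static bool example_is_valid_layer_parallel_id(int layer, int num_layer, int pp_size, int pp_rank)
{
    const int local_num_layer = num_layer / pp_size;  // layers owned per pipeline rank
    return layer >= pp_rank * local_num_layer && layer < (pp_rank + 1) * local_num_layer;
}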
initialize(); @@ -174,6 +182,7 @@ T5DecodingWeight& T5DecodingWeight::operator=(const T5DecodingWeight& othe setWeightPtr(); decoder_layer_weights.clear(); + decoder_layer_weights.reserve(num_layer_); for (int l = 0; l < num_layer_; l++) { decoder_layer_weights.push_back(new T5DecoderLayerWeight(*other.decoder_layer_weights[l])); } @@ -195,10 +204,10 @@ void T5DecodingWeight::mallocWeights() template void T5DecodingWeight::setWeightPtr() { - pre_decoder_embedding_table = weights_ptr[0]; + pre_decoder_embedding_table = weights_ptr[0]; absolute_or_relative_position_embedding = weights_ptr[1]; - post_decoder_layernorm.gamma = weights_ptr[2]; - post_decoder_embedding.kernel = weights_ptr[3]; + post_decoder_layernorm.gamma = weights_ptr[2]; + post_decoder_embedding.kernel = weights_ptr[3]; if (t5_with_bias) { post_decoder_layernorm.beta = weights_ptr[4]; post_decoder_embedding.bias = weights_ptr[5]; @@ -212,24 +221,25 @@ void T5DecodingWeight::loadModel(std::string dir_path) FT_CHECK(is_maintain_buffer == true); FtCudaDataType model_file_type = getModelFileType(dir_path + "/config.ini", "decoder"); - loadWeightFromBin(weights_ptr[0], {(int)weights_size[0]}, dir_path + "/shared.weight_T.bin", model_file_type); + loadWeightFromBin(weights_ptr[0], {(size_t)weights_size[0]}, dir_path + "/shared.weight_T.bin", model_file_type); if (position_embedding_type == PositionEmbeddingType::absolute) { - loadWeightFromBin(weights_ptr[1], {(int)weights_size[1]}, dir_path + "/shared.ape.bin", model_file_type); + loadWeightFromBin(weights_ptr[1], {(size_t)weights_size[1]}, dir_path + "/shared.ape.bin", model_file_type); } else { loadWeightFromBin(weights_ptr[1], - {(int)weights_size[1]}, + {(size_t)weights_size[1]}, dir_path + "/decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight." 
+ std::to_string(tensor_para_rank_) + ".bin", model_file_type); } loadWeightFromBin( - weights_ptr[2], {(int)weights_size[2]}, dir_path + "/decoder.final_layer_norm.weight.bin", model_file_type); - loadWeightFromBin(weights_ptr[3], {(int)weights_size[3]}, dir_path + "/shared.weight_T.bin", model_file_type); + weights_ptr[2], {(size_t)weights_size[2]}, dir_path + "/decoder.final_layer_norm.weight.bin", model_file_type); + loadWeightFromBin(weights_ptr[3], {(size_t)weights_size[3]}, dir_path + "/shared.weight_T.bin", model_file_type); if (t5_with_bias) { - loadWeightFromBin(weights_ptr[4], {weights_size[4]}, dir_path + "/decoder.final_layer_norm.bias.bin"); - loadWeightFromBin(weights_ptr[5], {weights_size[5]}, dir_path + "/shared.bias.bin"); + loadWeightFromBin( + weights_ptr[4], {(size_t)weights_size[4]}, dir_path + "/decoder.final_layer_norm.bias.bin"); + loadWeightFromBin(weights_ptr[5], {(size_t)weights_size[5]}, dir_path + "/shared.bias.bin"); } for (int l = 0; l < num_layer_; l++) { @@ -256,6 +266,7 @@ void T5DecodingWeight::resizeLayer(const int num_layer) FT_LOG_DEBUG("T5DecodingWeight " + std::string(__func__) + " start"); decoder_layer_weights.clear(); num_layer_ = num_layer; + decoder_layer_weights.reserve(num_layer_); for (int l = 0; l < num_layer_; l++) { decoder_layer_weights.push_back(new T5DecoderLayerWeight()); } @@ -263,16 +274,22 @@ void T5DecodingWeight::resizeLayer(const int num_layer) } template -void T5DecodingWeight::setT5StructureDiff(bool t5_with_bias_para, PositionEmbeddingType position_embedding_type_para) +void T5DecodingWeight::setT5StructureDiff(bool t5_with_bias_para, + bool use_gated_activation_para, + PositionEmbeddingType position_embedding_type_para) { - t5_with_bias = t5_with_bias_para; + t5_with_bias = t5_with_bias_para; + use_gated_activation = use_gated_activation_para; position_embedding_type = position_embedding_type_para; for (int i = 0; i < num_layer_; i++) { - decoder_layer_weights[i]->setT5WithBias(t5_with_bias_para); + decoder_layer_weights[i]->setT5WithBias(t5_with_bias_para, use_gated_activation_para); } } template struct T5DecodingWeight; template struct T5DecodingWeight; +#ifdef ENABLE_BF16 +template struct T5DecodingWeight<__nv_bfloat16>; +#endif } // namespace fastertransformer diff --git a/src/fastertransformer/models/t5/T5DecodingWeight.h b/src/fastertransformer/models/t5/T5DecodingWeight.h index 4d237031e..9db7dddb4 100644 --- a/src/fastertransformer/models/t5/T5DecodingWeight.h +++ b/src/fastertransformer/models/t5/T5DecodingWeight.h @@ -28,37 +28,41 @@ template struct T5DecodingWeight { T5DecodingWeight() = default; - T5DecodingWeight(const size_t head_num, - const size_t size_per_head, - const size_t d_model, - const size_t inter_size, - const size_t vocab_size, - const size_t num_layer, - const size_t mem_d_model, - const size_t num_bucket_or_max_seq_len, - const size_t tensor_para_size, - const size_t tensor_para_rank, - const size_t pipeline_para_size, - const size_t pipeline_para_rank, - const bool t5_with_bias_para = false, + T5DecodingWeight(const size_t head_num, + const size_t size_per_head, + const size_t d_model, + const size_t inter_size, + const size_t vocab_size, + const size_t num_layer, + const size_t mem_d_model, + const size_t num_bucket_or_max_seq_len, + const size_t tensor_para_size, + const size_t tensor_para_rank, + const size_t pipeline_para_size, + const size_t pipeline_para_rank, + const bool t5_with_bias_para = false, + const bool use_gated_activation_para = false, const PositionEmbeddingType 
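// [Annotation] Checkpoint layout implied by setWeightPtr() and loadModel() above (paths
// relative to the model directory; weights 4 and 5 only exist when t5_with_bias is set):
//
//   weights_ptr[0]  pre_decoder_embedding_table              shared.weight_T.bin
//   weights_ptr[1]  absolute_or_relative_position_embedding  shared.ape.bin (absolute PE) or
//                                                            decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight.<tp_rank>.bin
//   weights_ptr[2]  post_decoder_layernorm.gamma             decoder.final_layer_norm.weight.bin
//   weights_ptr[3]  post_decoder_embedding.kernel            shared.weight_T.bin (same file as [0])
//   weights_ptr[4]  post_decoder_layernorm.beta              decoder.final_layer_norm.bias.bin
//   weights_ptr[5]  post_decoder_embedding.bias              shared.bias.bin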
position_embedding_type_para = PositionEmbeddingType::relative); ~T5DecodingWeight(); T5DecodingWeight(const T5DecodingWeight& other); T5DecodingWeight& operator=(const T5DecodingWeight& other); std::vector*> decoder_layer_weights; - const T* pre_decoder_embedding_table = nullptr; - const T* absolute_or_relative_position_embedding = nullptr; - LayerNormWeight post_decoder_layernorm; - DenseWeight post_decoder_embedding; - bool t5_with_bias = false; + const T* pre_decoder_embedding_table = nullptr; + const T* absolute_or_relative_position_embedding = nullptr; + LayerNormWeight post_decoder_layernorm; + DenseWeight post_decoder_embedding; + bool t5_with_bias = false; + bool use_gated_activation = false; // 0 = relative_position_embedding, 1 = absolute_position_embedding PositionEmbeddingType position_embedding_type = PositionEmbeddingType::relative; void loadModel(std::string dir_path); void resizeLayer(const int num_layer); - void setT5StructureDiff(bool t5_with_bias_para, PositionEmbeddingType position_embedding_type_para); + void setT5StructureDiff(bool t5_with_bias_para, + bool use_gated_activation_para, + PositionEmbeddingType position_embedding_type_para); private: void setWeightPtr(); @@ -74,19 +78,19 @@ struct T5DecodingWeight { size_t num_layer_; size_t mem_d_model_; // refer to num_buckt if using relative position embedding - // refer to max_seq_len if using absoulte position embedding + // refer to max_seq_len if using absolute position embedding size_t num_bucket_or_max_seq_len_; size_t tensor_para_size_; size_t tensor_para_rank_; size_t pipeline_para_size_; size_t pipeline_para_rank_; - bool is_maintain_buffer = false; + bool is_maintain_buffer = false; int real_weights_num_; const static int weights_num_ = 6; - T* weights_ptr[weights_num_]; - size_t weights_size[weights_num_]; + T* weights_ptr[weights_num_]; + size_t weights_size[weights_num_]; }; } // namespace fastertransformer diff --git a/src/fastertransformer/models/t5/T5Encoder.cc b/src/fastertransformer/models/t5/T5Encoder.cc index db989ffc9..a72ce3cb1 100644 --- a/src/fastertransformer/models/t5/T5Encoder.cc +++ b/src/fastertransformer/models/t5/T5Encoder.cc @@ -60,7 +60,9 @@ void T5Encoder::initialize() throw std::runtime_error(std::string("[FT][ERROR] Invalid attention type \n")); } - if (activation_type_ == ActivationType::Gelu) { + bool use_gated_activation = activation_type_ == ActivationType::GeGLU || activation_type_ == ActivationType::ReGLU + || activation_type_ == ActivationType::SiGLU; + if (activation_type_ == ActivationType::Gelu || activation_type_ == ActivationType::GeGLU) { ffn_layer_ = new TensorParallelGeluFfnLayer(max_batch_size_, max_seq_len_, 1, @@ -70,13 +72,15 @@ void T5Encoder::initialize() stream_, cublas_wrapper_, allocator_, + true, is_free_buffer_after_forward_, sparse_, 0, + use_gated_activation, // don't use GeGLU custom_all_reduce_comm_, enable_custom_all_reduce_); } - else if (activation_type_ == ActivationType::Relu) { + else if (activation_type_ == ActivationType::Relu || activation_type_ == ActivationType::ReGLU) { ffn_layer_ = new TensorParallelReluFfnLayer(max_batch_size_, max_seq_len_, 1, @@ -86,37 +90,56 @@ void T5Encoder::initialize() stream_, cublas_wrapper_, allocator_, + true, is_free_buffer_after_forward_, sparse_, + use_gated_activation, + custom_all_reduce_comm_, + enable_custom_all_reduce_); + } + else if (activation_type_ == ActivationType::Silu || activation_type_ == ActivationType::SiGLU) { + ffn_layer_ = new TensorParallelSiluFfnLayer(max_batch_size_, + max_seq_len_, + 
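// [Annotation] FFN dispatch introduced above (the Silu/SiGLU branch continues in the next
// hunk lines):
//
//   use_gated_activation = (activation_type_ == GeGLU || ReGLU || SiGLU)
//
//   Gelu or GeGLU -> TensorParallelGeluFfnLayer
//   Relu or ReGLU -> TensorParallelReluFfnLayer
//   Silu or SiGLU -> TensorParallelSiluFfnLayer
//
// The gated (*GLU) variants reuse the same FFN classes and simply pass
// use_gated_activation = true, which is what enables the second intermediate projection
// (ffn_weights.intermediate_weight2) added elsewhere in this patch.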
1, + d_model_, + inter_size_, + tensor_para_, + stream_, + cublas_wrapper_, + allocator_, + true, + is_free_buffer_after_forward_, + sparse_, + use_gated_activation, custom_all_reduce_comm_, enable_custom_all_reduce_); } } template -T5Encoder::T5Encoder(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t inter_size, - size_t d_model, - size_t num_layer, - size_t num_bucket_or_max_seq_len, - size_t max_distance, - int sm, - float q_scaling, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - AttentionType attention_type, - bool sparse, - ActivationType activation_type, - LayerNormType layernorm_type, - NcclParam tensor_para, - NcclParam pipeline_para, +T5Encoder::T5Encoder(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t d_model, + size_t num_layer, + size_t num_bucket_or_max_seq_len, + size_t max_distance, + int sm, + float q_scaling, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + AttentionType attention_type, + bool sparse, + ActivationType activation_type, + LayerNormType layernorm_type, + NcclParam tensor_para, + NcclParam pipeline_para, std::shared_ptr custom_all_reduce_comm, - int enable_custom_all_reduce): + int enable_custom_all_reduce): BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward), max_batch_size_(max_batch_size), max_seq_len_(max_seq_len), @@ -190,25 +213,35 @@ template void T5Encoder::allocateBuffer() { if (is_allocate_buffer_ == false) { - token_num_ = (size_t*)allocator_->malloc(sizeof(size_t) * 1, false); - padding_offset_ = (int*)allocator_->malloc(sizeof(int) * max_batch_size_ * max_seq_len_, false); - trt_mha_padding_offset_ = (int*)allocator_->malloc(sizeof(int) * (2 * max_batch_size_ + 1), false); - - attention_mask_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * max_seq_len_, false); - relative_attention_bias_ = (T*)allocator_->malloc(sizeof(T) * head_num_ * max_seq_len_ * max_seq_len_, false); - - t5_encoder_emb_buf_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * d_model_, false); - t5_encoder_in_buffer_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * d_model_, false); - attn_out_buf_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * d_model_, false); - t5_encoder_out_buffer_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * d_model_, false); + token_num_ = (size_t*)allocator_->reMalloc(token_num_, sizeof(size_t) * 1, false); + padding_offset_ = + (int*)allocator_->reMalloc(padding_offset_, sizeof(int) * max_batch_size_ * max_seq_len_, false); + trt_mha_padding_offset_ = + (int*)allocator_->reMalloc(trt_mha_padding_offset_, sizeof(int) * (2 * max_batch_size_ + 1), false); + + attention_mask_ = + (T*)allocator_->reMalloc(attention_mask_, sizeof(T) * max_batch_size_ * max_seq_len_ * max_seq_len_, false); + relative_attention_bias_ = (T*)allocator_->reMalloc( + relative_attention_bias_, sizeof(T) * head_num_ * max_seq_len_ * max_seq_len_, false); + + t5_encoder_emb_buf_ = + (T*)allocator_->reMalloc(t5_encoder_emb_buf_, sizeof(T) * max_batch_size_ * max_seq_len_ * d_model_, false); + t5_encoder_in_buffer_ = (T*)allocator_->reMalloc( + t5_encoder_in_buffer_, sizeof(T) * max_batch_size_ * max_seq_len_ * d_model_, false); + attn_out_buf_ = + (T*)allocator_->reMalloc(attn_out_buf_, 
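// [Annotation] allocateBuffer() above now calls reMalloc(ptr, size) instead of
// malloc(size). The allocator implementation is outside this diff; the assumed behaviour,
// consistent with how it is used here and in the size-dependent
// allocateBuffer(batch_size, seq_len) below, is roughly:
//
//   void* reMalloc(void* ptr, size_t size) {
//       if (ptr != nullptr && current_size(ptr) >= size)  // current_size(): hypothetical
//           return ptr;                                    // reuse the existing buffer
//       free_buffer(ptr);                                  // otherwise release ...
//       return allocate(size);                             // ... and grow
//   }
//
// so repeated forward() calls with changing shapes neither leak nor reallocate when the
// existing buffer is already large enough.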
sizeof(T) * max_batch_size_ * max_seq_len_ * d_model_, false); + t5_encoder_out_buffer_ = (T*)allocator_->reMalloc( + t5_encoder_out_buffer_, sizeof(T) * max_batch_size_ * max_seq_len_ * d_model_, false); if (layernorm_type_ == LayerNormType::post_layernorm) { - normed_from_tensor_ = nullptr; + normed_from_tensor_ = nullptr; normed_attn_out_buf_ = nullptr; } else { - normed_from_tensor_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * d_model_, false); - normed_attn_out_buf_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * d_model_, false); + normed_from_tensor_ = (T*)allocator_->reMalloc( + normed_from_tensor_, sizeof(T) * max_batch_size_ * max_seq_len_ * d_model_, false); + normed_attn_out_buf_ = (T*)allocator_->reMalloc( + normed_attn_out_buf_, sizeof(T) * max_batch_size_ * max_seq_len_ * d_model_, false); } is_allocate_buffer_ = true; } @@ -218,7 +251,7 @@ template void T5Encoder::allocateBuffer(size_t batch_size, size_t seq_len) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); - token_num_ = (size_t*)allocator_->reMalloc(token_num_, sizeof(size_t) * 1, false); + token_num_ = (size_t*)allocator_->reMalloc(token_num_, sizeof(size_t) * 1, false); padding_offset_ = (int*)allocator_->reMalloc(padding_offset_, sizeof(int) * batch_size * seq_len, false); trt_mha_padding_offset_ = (int*)allocator_->reMalloc(trt_mha_padding_offset_, sizeof(int) * (2 * batch_size + 1), false); @@ -236,7 +269,7 @@ void T5Encoder::allocateBuffer(size_t batch_size, size_t seq_len) (T*)allocator_->reMalloc(t5_encoder_out_buffer_, sizeof(T) * batch_size * seq_len * d_model_, false); if (layernorm_type_ == LayerNormType::post_layernorm) { - normed_from_tensor_ = nullptr; + normed_from_tensor_ = nullptr; normed_attn_out_buf_ = nullptr; } else { @@ -252,24 +285,24 @@ template void T5Encoder::freeBuffer() { if (is_allocate_buffer_) { - allocator_->free(token_num_); - allocator_->free(padding_offset_); - allocator_->free(trt_mha_padding_offset_); + allocator_->free((void**)(&token_num_)); + allocator_->free((void**)(&padding_offset_)); + allocator_->free((void**)(&trt_mha_padding_offset_)); - allocator_->free(attention_mask_); - allocator_->free(relative_attention_bias_); - allocator_->free(t5_encoder_emb_buf_); - allocator_->free(t5_encoder_in_buffer_); - allocator_->free(attn_out_buf_); - allocator_->free(t5_encoder_out_buffer_); + allocator_->free((void**)(&attention_mask_)); + allocator_->free((void**)(&relative_attention_bias_)); + allocator_->free((void**)(&t5_encoder_emb_buf_)); + allocator_->free((void**)(&t5_encoder_in_buffer_)); + allocator_->free((void**)(&attn_out_buf_)); + allocator_->free((void**)(&t5_encoder_out_buffer_)); if (layernorm_type_ == LayerNormType::post_layernorm) { - normed_from_tensor_ = nullptr; + normed_from_tensor_ = nullptr; normed_attn_out_buf_ = nullptr; } else { - allocator_->free(normed_from_tensor_); - allocator_->free(normed_attn_out_buf_); + allocator_->free((void**)(&normed_from_tensor_)); + allocator_->free((void**)(&normed_attn_out_buf_)); } is_allocate_buffer_ = false; } @@ -305,9 +338,9 @@ int T5Encoder::getFirstLayerParallelId() } template -void T5Encoder::forward(std::vector* output_tensors, +void T5Encoder::forward(std::vector* output_tensors, const std::vector* input_tensors, - const T5EncoderWeight* t5_encoder_weights) + const T5EncoderWeight* t5_encoder_weights) { // input_tensors: // input_ids [batch, seqlen] @@ -323,29 +356,45 @@ void T5Encoder::forward(std::vector* output_tensors, } template -void T5Encoder::forward(std::unordered_map* 
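// [Annotation] freeBuffer() above now passes the address of each member, e.g.
// allocator_->free((void**)(&token_num_)), instead of the raw pointer value. The assumed
// rationale (the allocator is not shown in this diff) is that a void** argument lets the
// allocator null out the member after releasing it, roughly:
//
//   void free(void** ptr) {
//       if (*ptr != nullptr) {
//           release(*ptr);   // release(): hypothetical device free
//           *ptr = nullptr;  // caller's member is reset, avoiding double frees
//       }
//   }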
output_tensors, +void T5Encoder::forward(std::unordered_map* output_tensors, const std::unordered_map* input_tensors, - const T5EncoderWeight* t5_encoder_weights) + const T5EncoderWeight* t5_encoder_weights) { // input_tensors: // input_ids [batch, seqlen] // sequence_length [batch] + // inputs_embeds [batch, seqlen, d_model_] // output tensors: // output_hidden_state [batch, seqlen, d_model_] FT_LOG_DEBUG(__PRETTY_FUNCTION__); - const size_t request_batch_size = input_tensors->at("input_ids").shape[0]; - const size_t request_seq_len = input_tensors->at("input_ids").shape[1]; - FT_CHECK(input_tensors->size() == 2); + const bool use_inputs_embeds = (bool)input_tensors->count("inputs_embeds"); + if (use_inputs_embeds) { + if (input_tensors->count("input_ids")) { + FT_LOG_WARNING("Pass input_ids and inputs_embeds at the same time, using inputs_embeds"); + } + FT_CHECK(input_tensors->at("inputs_embeds").shape.size() == 3); + FT_LOG_INFO("Using inputs embeds instead of input_ids !"); + } + else { + FT_CHECK(input_tensors->at("input_ids").shape.size() == 2); + } + std::string input_tensor_name = use_inputs_embeds ? "inputs_embeds" : "input_ids"; + const size_t request_batch_size = input_tensors->at(input_tensor_name).shape[0]; + const size_t request_seq_len = input_tensors->at(input_tensor_name).shape[1]; + FT_CHECK(input_tensors->size() == 2 || input_tensors->size() == 3); FT_CHECK(request_batch_size == input_tensors->at("sequence_length").shape[0]); - FT_CHECK(input_tensors->at("input_ids").shape.size() == 2); FT_CHECK(input_tensors->at("sequence_length").shape.size() == 1); + allocateBuffer(request_batch_size, request_seq_len); // T5 Structure Difference - bool t5_with_bias = t5_encoder_weights->t5_with_bias; + const bool t5_with_bias = t5_encoder_weights->t5_with_bias; PositionEmbeddingType position_embedding_type = t5_encoder_weights->position_embedding_type; + const bool use_inputs_embeds_buffer = + use_inputs_embeds && position_embedding_type == PositionEmbeddingType::relative; + invokeBuildRelativeAttentionBias(relative_attention_bias_, t5_encoder_weights->absolute_or_relative_position_embedding, head_num_, @@ -362,48 +411,55 @@ void T5Encoder::forward(std::unordered_map* output_tenso sizeof(T) * request_batch_size * request_seq_len * d_model_); } const size_t local_batch_size = getLocalBatchSize(request_batch_size, request_seq_len, pipeline_para_.world_size_); - const size_t iteration_num = request_batch_size / local_batch_size; + const size_t iteration_num = request_batch_size / local_batch_size; + for (uint ite = 0; ite < iteration_num; ite++) { - size_t id_offset = ite * local_batch_size; + size_t id_offset = ite * local_batch_size; size_t d_model_offset = id_offset * request_seq_len * d_model_; const int* sequence_lengths = input_tensors->at("sequence_length").getPtr() + id_offset; if (position_embedding_type == PositionEmbeddingType::absolute) { - invokeInputIdsEmbeddingLookupPosEncoding(t5_encoder_emb_buf_, - nullptr, - t5_encoder_weights->embedding_table, - t5_encoder_weights->absolute_or_relative_position_embedding, - input_tensors->at("input_ids").getPtr() - + id_offset * request_seq_len, - 1, - request_seq_len, - request_seq_len, - local_batch_size, - d_model_, - stream_); + const int prompt_token_start_id = 0; + invokeInputIdsEmbeddingLookupPosEncoding( + t5_encoder_emb_buf_, + nullptr, + use_inputs_embeds ? 
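// [Annotation] Caller-side sketch for the new "inputs_embeds" input documented above.
// Shapes and names follow the comment block and the checks in this hunk; buffer names and
// the surrounding setup are illustrative only:
//
//   // inputs_embeds: [batch, seqlen, d_model] on GPU, same dtype as the model
//   input_tensors.insert({"inputs_embeds",
//                         Tensor{MEMORY_GPU, data_type, {batch, seqlen, d_model}, d_inputs_embeds}});
//   input_tensors.insert({"sequence_length",
//                         Tensor{MEMORY_GPU, TYPE_INT32, {batch}, d_sequence_lengths}});
//
// If both "input_ids" and "inputs_embeds" are supplied, the encoder logs a warning and
// uses "inputs_embeds".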
input_tensors->at("inputs_embeds").getPtr() : + t5_encoder_weights->embedding_table, + t5_encoder_weights->absolute_or_relative_position_embedding, + pPromptTuningParam{}, // p/prompt tuning + use_inputs_embeds ? nullptr : + input_tensors->at("input_ids").getPtrWithOffset(id_offset * request_seq_len), + 1, + request_seq_len, + request_seq_len, + local_batch_size, + d_model_, + stream_); } else { - invokeEmbeddingLookupPosEncoding(t5_encoder_emb_buf_, - t5_encoder_weights->embedding_table, - (const T*)nullptr, - input_tensors->at("input_ids").getPtr() + id_offset * request_seq_len, - nullptr, - local_batch_size * request_seq_len, - d_model_, - (T)1.0f, - 0, - 0, - local_batch_size * request_seq_len, - 0, - stream_); + if (!use_inputs_embeds) { + invokeEmbeddingLookupPosEncodingPadCount( + t5_encoder_emb_buf_, + t5_encoder_weights->embedding_table, + (const T*)nullptr, + input_tensors->at("input_ids").getPtrWithOffset(id_offset * request_seq_len), + nullptr, + local_batch_size * request_seq_len, + d_model_, + (T)1.0f, + 0, + local_batch_size * request_seq_len, + 0, + stream_); + } } sync_check_cuda_error(); - size_t h_token_num; - T* t5_encoder_input_ptr; - T* t5_encoder_output_ptr; + size_t h_token_num; + T* t5_encoder_input_ptr; + T* t5_encoder_output_ptr; Tensor* padding_offset_tensor_ptr; // preprocess (remove padding and build mask) switch (attention_type_) { @@ -422,11 +478,17 @@ void T5Encoder::forward(std::unordered_map* output_tenso sync_check_cuda_error(); if (pipeline_para_.rank_ == 0) { - invokeRemovePadding( - t5_encoder_in_buffer_, t5_encoder_emb_buf_, padding_offset_, h_token_num, d_model_, stream_); + invokeRemovePadding(t5_encoder_in_buffer_, + use_inputs_embeds_buffer ? + input_tensors->at("inputs_embeds").getPtrWithOffset(d_model_offset) : + t5_encoder_emb_buf_, + padding_offset_, + h_token_num, + d_model_, + stream_); sync_check_cuda_error(); } - t5_encoder_input_ptr = t5_encoder_in_buffer_; + t5_encoder_input_ptr = t5_encoder_in_buffer_; t5_encoder_output_ptr = t5_encoder_out_buffer_; padding_offset_tensor_ptr = @@ -437,10 +499,19 @@ void T5Encoder::forward(std::unordered_map* output_tenso invokeBuildEncoderAttentionMask( attention_mask_, sequence_lengths, local_batch_size, request_seq_len, stream_); - sync_check_cuda_error(); h_token_num = local_batch_size * request_seq_len; - t5_encoder_input_ptr = t5_encoder_emb_buf_; - t5_encoder_output_ptr = output_tensors->at("output_hidden_state").getPtr() + d_model_offset; + if (use_inputs_embeds_buffer) { + cudaMemcpyAsync(t5_encoder_emb_buf_, + input_tensors->at("inputs_embeds").getPtrWithOffset(d_model_offset), + sizeof(T) * h_token_num * d_model_, + cudaMemcpyDeviceToDevice, + stream_); + } + + sync_check_cuda_error(); + h_token_num = local_batch_size * request_seq_len; + t5_encoder_input_ptr = t5_encoder_emb_buf_; + t5_encoder_output_ptr = output_tensors->at("output_hidden_state").getPtr() + d_model_offset; padding_offset_tensor_ptr = new Tensor(MEMORY_GPU, TYPE_INT32, std::vector{0}, nullptr); break; } @@ -492,7 +563,7 @@ void T5Encoder::forward(std::unordered_map* output_tenso continue; } T* from_tensor = (i == 0 ? 
t5_encoder_input_ptr : t5_encoder_output_ptr); - T* out_tensor = t5_encoder_output_ptr; + T* out_tensor = t5_encoder_output_ptr; if (isFirstLayerParallelId(i) && pipeline_para_.rank_ != 0) { const int data_size = h_token_num * d_model_ / tensor_para_.world_size_; @@ -510,6 +581,7 @@ void T5Encoder::forward(std::unordered_map* output_tenso from_tensor, t5_encoder_weights->t5_encoder_layer_weights[i]->attn_layernorm_weights.gamma, t5_encoder_weights->t5_encoder_layer_weights[i]->attn_layernorm_weights.beta, + layernorm_eps_, h_token_num, d_model_, stream_); @@ -548,6 +620,7 @@ void T5Encoder::forward(std::unordered_map* output_tenso t5_encoder_weights->t5_encoder_layer_weights[i]->attn_layernorm_weights.gamma, t5_encoder_weights->t5_encoder_layer_weights[i]->attn_layernorm_weights.beta, t5_encoder_weights->t5_encoder_layer_weights[i]->attention_weights.attention_output_weight.bias, + layernorm_eps_, h_token_num, d_model_, stream_); @@ -560,6 +633,7 @@ void T5Encoder::forward(std::unordered_map* output_tenso t5_encoder_weights->t5_encoder_layer_weights[i]->ffn_layernorm_weights.gamma, t5_encoder_weights->t5_encoder_layer_weights[i]->ffn_layernorm_weights.beta, t5_encoder_weights->t5_encoder_layer_weights[i]->attention_weights.attention_output_weight.bias, + layernorm_eps_, h_token_num, d_model_, stream_); @@ -587,6 +661,7 @@ void T5Encoder::forward(std::unordered_map* output_tenso t5_encoder_weights->t5_encoder_layer_weights[i]->ffn_layernorm_weights.gamma, t5_encoder_weights->t5_encoder_layer_weights[i]->ffn_layernorm_weights.beta, t5_encoder_weights->t5_encoder_layer_weights[i]->ffn_weights.output_weight.bias, + layernorm_eps_, h_token_num, d_model_, stream_); @@ -611,12 +686,15 @@ void T5Encoder::forward(std::unordered_map* output_tenso } } + // exit(0); + if (pipeline_para_.rank_ == pipeline_para_.world_size_ - 1) { if (layernorm_type_ == LayerNormType::pre_layernorm) { invokeGeneralT5LayerNorm(t5_encoder_output_ptr, t5_encoder_output_ptr, t5_encoder_weights->post_transformer_layernorm_weights.gamma, t5_encoder_weights->post_transformer_layernorm_weights.beta, + layernorm_eps_, h_token_num, d_model_, stream_); @@ -666,7 +744,7 @@ void T5Encoder::forward(std::unordered_map* output_tenso sync_check_cuda_error(); if (pipeline_para_.world_size_ > 1) { - NCCLCHECK(ncclGroupStart()); + ftNcclGroupStart(); const int data_size = request_batch_size * request_seq_len * d_model_ / tensor_para_.world_size_; ftNcclBroadCast(output_tensors->at("output_hidden_state").getPtr() + data_size * tensor_para_.rank_, data_size, @@ -674,7 +752,7 @@ void T5Encoder::forward(std::unordered_map* output_tenso pipeline_para_, stream_); - NCCLCHECK(ncclGroupEnd()); + ftNcclGroupEnd(); check_cuda_error(cudaStreamSynchronize(stream_)); sync_check_cuda_error(); if (tensor_para_.world_size_ > 1) { @@ -690,5 +768,8 @@ void T5Encoder::forward(std::unordered_map* output_tenso template class T5Encoder; template class T5Encoder; +#ifdef ENABLE_BF16 +template class T5Encoder<__nv_bfloat16>; +#endif } // namespace fastertransformer diff --git a/src/fastertransformer/models/t5/T5Encoder.h b/src/fastertransformer/models/t5/T5Encoder.h index 44039761a..b10ae0f6d 100644 --- a/src/fastertransformer/models/t5/T5Encoder.h +++ b/src/fastertransformer/models/t5/T5Encoder.h @@ -24,6 +24,7 @@ #include "src/fastertransformer/kernels/layernorm_kernels.h" #include "src/fastertransformer/layers/TensorParallelGeluFfnLayer.h" #include "src/fastertransformer/layers/TensorParallelReluFfnLayer.h" +#include 
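// [Annotation] layernorm_eps_ (a new constexpr, 1e-6f) is now threaded through every
// invokeGeneralT5LayerNorm / invoke*T5PreLayerNorm call in this loop. For orientation, a
// CPU reference of a T5-style (RMS) layer norm with that epsilon; this is an illustrative
// sketch, the CUDA kernels themselves are not part of this diff:
#include <cmath>
#include <cstddef>
static void t5_layer_norm_ref(const float* x, const float* gamma, const float* beta,
                              float eps, std::size_t n, float* y)
{
    float sum_sq = 0.f;
    for (std::size_t i = 0; i < n; ++i)
        sum_sq += x[i] * x[i];
    const float inv_rms = 1.f / std::sqrt(sum_sq / static_cast<float>(n) + eps);
    for (std::size_t i = 0; i < n; ++i)
        y[i] = x[i] * inv_rms * gamma[i] + (beta != nullptr ? beta[i] : 0.f);  // beta only for t5_with_bias models
}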
"src/fastertransformer/layers/TensorParallelSiluFfnLayer.h" // #include "src/fastertransformer/layers/attention_layers/FusedAttentionLayer.h" #include "src/fastertransformer/layers/attention_layers/TensorParallelUnfusedAttentionLayer.h" #include "src/fastertransformer/models/t5/T5EncoderWeight.h" @@ -36,24 +37,25 @@ class T5Encoder: public BaseLayer { private: // buffer handling size_t max_batch_size_ = 0; - size_t max_seq_len_ = 0; + size_t max_seq_len_ = 0; // meta data - const size_t head_num_; - const size_t size_per_head_; - const size_t inter_size_; - const size_t hidden_units_; - const size_t d_model_; - const size_t num_layer_; - const size_t num_bucket_or_max_seq_len_; - const size_t max_distance_; - int sm_; - float q_scaling_; - AttentionType attention_type_; - bool sparse_; + const size_t head_num_; + const size_t size_per_head_; + const size_t inter_size_; + const size_t hidden_units_; + const size_t d_model_; + const size_t num_layer_; + const size_t num_bucket_or_max_seq_len_; + const size_t max_distance_; + int sm_; + constexpr static float layernorm_eps_ = 1e-6f; + float q_scaling_; + AttentionType attention_type_; + bool sparse_; BaseAttentionLayer* attention_layer_; - FfnLayer* ffn_layer_; + FfnLayer* ffn_layer_; bool is_allocate_buffer_ = false; @@ -64,68 +66,68 @@ class T5Encoder: public BaseLayer { bool isValidLayerParallelId(uint l); bool isFirstLayerParallelId(uint l); bool isLastLayerParallelId(uint l); - int getFirstLayerParallelId(); + int getFirstLayerParallelId(); const ActivationType activation_type_; - const LayerNormType layernorm_type_; + const LayerNormType layernorm_type_; const NcclParam tensor_para_; const NcclParam pipeline_para_; std::shared_ptr custom_all_reduce_comm_; - int enable_custom_all_reduce_; + int enable_custom_all_reduce_; protected: // model params - size_t* token_num_ = nullptr; - int* padding_offset_ = nullptr; - int* trt_mha_padding_offset_ = nullptr; - T* attention_mask_ = nullptr; - T* relative_attention_bias_ = nullptr; - T* t5_encoder_emb_buf_ = nullptr; - T* t5_encoder_in_buffer_ = nullptr; - T* attn_out_buf_ = nullptr; - T* t5_encoder_out_buffer_ = nullptr; - - T* normed_from_tensor_ = nullptr; + size_t* token_num_ = nullptr; + int* padding_offset_ = nullptr; + int* trt_mha_padding_offset_ = nullptr; + T* attention_mask_ = nullptr; + T* relative_attention_bias_ = nullptr; + T* t5_encoder_emb_buf_ = nullptr; + T* t5_encoder_in_buffer_ = nullptr; + T* attn_out_buf_ = nullptr; + T* t5_encoder_out_buffer_ = nullptr; + + T* normed_from_tensor_ = nullptr; T* normed_attn_out_buf_ = nullptr; public: - T5Encoder(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t inter_size, - size_t d_model, - size_t num_layer, - size_t num_bucket_or_max_seq_len, - size_t max_distance, - int sm, - float q_scaling, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - AttentionType attention_type, - bool sparse, - ActivationType activation_type, - LayerNormType layernorm_type, - NcclParam tensor_para, - NcclParam pipeline_para, - std::shared_ptr custom_all_reduce_comm = nullptr, - int enable_custom_all_reduce = 0); + T5Encoder(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t d_model, + size_t num_layer, + size_t num_bucket_or_max_seq_len, + size_t max_distance, + int sm, + float q_scaling, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + 
bool is_free_buffer_after_forward, + AttentionType attention_type, + bool sparse, + ActivationType activation_type, + LayerNormType layernorm_type, + NcclParam tensor_para, + NcclParam pipeline_para, + std::shared_ptr custom_all_reduce_comm = nullptr, + int enable_custom_all_reduce = 0); T5Encoder(T5Encoder const& t5_layer); ~T5Encoder(); - void forward(std::vector* output_tensors, + void forward(std::vector* output_tensors, const std::vector* input_tensors, - const T5EncoderWeight* t5_weights); + const T5EncoderWeight* t5_weights); - void forward(std::unordered_map* output_tensors, + void forward(std::unordered_map* output_tensors, const std::unordered_map* input_tensors, - const T5EncoderWeight* t5_weights); + const T5EncoderWeight* t5_weights); inline size_t getDModel() { diff --git a/src/fastertransformer/models/t5/T5EncoderLayerWeight.cc b/src/fastertransformer/models/t5/T5EncoderLayerWeight.cc index 750b353cf..d53f23777 100644 --- a/src/fastertransformer/models/t5/T5EncoderLayerWeight.cc +++ b/src/fastertransformer/models/t5/T5EncoderLayerWeight.cc @@ -27,7 +27,8 @@ T5EncoderLayerWeight::T5EncoderLayerWeight(const size_t head_num, const size_t inter_size, const size_t tensor_para_size, const size_t tensor_para_rank, - const bool t5_with_bias): + const bool t5_with_bias, + const bool use_gated_activation): head_num_(head_num), size_per_head_(size_per_head), d_model_(d_model), @@ -35,8 +36,9 @@ T5EncoderLayerWeight::T5EncoderLayerWeight(const size_t head_num, tensor_para_size_(tensor_para_size), tensor_para_rank_(tensor_para_rank), t5_with_bias_(t5_with_bias), - real_weights_num_(t5_with_bias ? 16 : 8) + use_gated_activation_(use_gated_activation) { + real_weights_num_ = (8 + (use_gated_activation_ ? 1 : 0)) * (t5_with_bias_ ? 2 : 1); FT_LOG_DEBUG("T5EncoderLayerWeight " + std::string(__func__) + " start"); initialize(); mallocWeights(); @@ -53,18 +55,39 @@ void T5EncoderLayerWeight::initialize() weights_size[2] = d_model_ * (head_num_ / tensor_para_size_) * size_per_head_; weights_size[3] = (head_num_ / tensor_para_size_) * size_per_head_ * d_model_; weights_size[4] = d_model_; - weights_size[5] = d_model_ * (inter_size_ / tensor_para_size_); - weights_size[6] = (inter_size_ / tensor_para_size_) * d_model_; - weights_size[7] = d_model_; + if (use_gated_activation_) { + weights_size[5] = d_model_ * (inter_size_ / tensor_para_size_); + weights_size[6] = d_model_ * (inter_size_ / tensor_para_size_); // for gated activation + weights_size[7] = (inter_size_ / tensor_para_size_) * d_model_; + weights_size[8] = d_model_; + } + else { + weights_size[5] = d_model_ * (inter_size_ / tensor_para_size_); + weights_size[6] = (inter_size_ / tensor_para_size_) * d_model_; + weights_size[7] = d_model_; + } if (t5_with_bias_) { - weights_size[8] = (head_num_ / tensor_para_size_) * size_per_head_; - weights_size[9] = (head_num_ / tensor_para_size_) * size_per_head_; - weights_size[10] = (head_num_ / tensor_para_size_) * size_per_head_; - weights_size[11] = d_model_; - weights_size[12] = d_model_; - weights_size[13] = (inter_size_ / tensor_para_size_); - weights_size[14] = d_model_; - weights_size[15] = d_model_; + if (use_gated_activation_) { + weights_size[9] = (head_num_ / tensor_para_size_) * size_per_head_; + weights_size[10] = (head_num_ / tensor_para_size_) * size_per_head_; + weights_size[11] = (head_num_ / tensor_para_size_) * size_per_head_; + weights_size[12] = d_model_; + weights_size[13] = d_model_; + weights_size[14] = (inter_size_ / tensor_para_size_); + weights_size[15] = (inter_size_ / 
tensor_para_size_); // for gated activation + weights_size[16] = d_model_; + weights_size[17] = d_model_; + } + else { + weights_size[8] = (head_num_ / tensor_para_size_) * size_per_head_; + weights_size[9] = (head_num_ / tensor_para_size_) * size_per_head_; + weights_size[10] = (head_num_ / tensor_para_size_) * size_per_head_; + weights_size[11] = d_model_; + weights_size[12] = d_model_; + weights_size[13] = (inter_size_ / tensor_para_size_); + weights_size[14] = d_model_; + weights_size[15] = d_model_; + } } FT_LOG_DEBUG("T5EncoderLayerWeight " + std::string(__func__) + " end"); @@ -79,38 +102,38 @@ T5EncoderLayerWeight::~T5EncoderLayerWeight() deviceFree(weights_ptr[i]); } - attention_weights.query_weight.kernel = nullptr; - attention_weights.key_weight.kernel = nullptr; - attention_weights.value_weight.kernel = nullptr; + attention_weights.query_weight.kernel = nullptr; + attention_weights.key_weight.kernel = nullptr; + attention_weights.value_weight.kernel = nullptr; attention_weights.attention_output_weight.kernel = nullptr; - attn_layernorm_weights.gamma = nullptr; - ffn_weights.intermediate_weight.kernel = nullptr; - ffn_weights.output_weight.kernel = nullptr; - ffn_layernorm_weights.gamma = nullptr; - if (t5_with_bias_) { - attention_weights.query_weight.bias = nullptr; - attention_weights.key_weight.bias = nullptr; - attention_weights.value_weight.bias = nullptr; - attention_weights.attention_output_weight.bias = nullptr; - attn_layernorm_weights.beta = nullptr; - ffn_weights.intermediate_weight.bias = nullptr; - ffn_weights.output_weight.bias = nullptr; - ffn_layernorm_weights.beta = nullptr; - is_maintain_buffer = false; - } + attn_layernorm_weights.gamma = nullptr; + ffn_weights.intermediate_weight.kernel = nullptr; + ffn_weights.intermediate_weight2.kernel = nullptr; + ffn_weights.output_weight.kernel = nullptr; + ffn_layernorm_weights.gamma = nullptr; + attention_weights.query_weight.bias = nullptr; + attention_weights.key_weight.bias = nullptr; + attention_weights.value_weight.bias = nullptr; + attention_weights.attention_output_weight.bias = nullptr; + attn_layernorm_weights.beta = nullptr; + ffn_weights.intermediate_weight.bias = nullptr; + ffn_weights.intermediate_weight2.bias = nullptr; + ffn_weights.output_weight.bias = nullptr; + ffn_layernorm_weights.beta = nullptr; + is_maintain_buffer = false; } if (is_maintain_sp_buffer == true) { for (int i = 0; i < 6; i++) { deviceFree(sp_weights_ptr[i]); } - attention_weights.query_weight.sp_kernel = nullptr; - attention_weights.key_weight.sp_kernel = nullptr; - attention_weights.value_weight.sp_kernel = nullptr; + attention_weights.query_weight.sp_kernel = nullptr; + attention_weights.key_weight.sp_kernel = nullptr; + attention_weights.value_weight.sp_kernel = nullptr; attention_weights.attention_output_weight.sp_kernel = nullptr; - ffn_weights.intermediate_weight.sp_kernel = nullptr; - ffn_weights.output_weight.sp_kernel = nullptr; - is_maintain_sp_buffer = false; + ffn_weights.intermediate_weight.sp_kernel = nullptr; + ffn_weights.output_weight.sp_kernel = nullptr; + is_maintain_sp_buffer = false; } FT_LOG_DEBUG("T5EncoderLayerWeight " + std::string(__func__) + " end"); } @@ -141,13 +164,13 @@ T5EncoderLayerWeight& T5EncoderLayerWeight::operator=(const T5EncoderLayer { FT_LOG_DEBUG("T5EncoderLayerWeight " + std::string(__func__) + " start"); - head_num_ = other.head_num_; - size_per_head_ = other.size_per_head_; - d_model_ = other.d_model_; - inter_size_ = other.inter_size_; + head_num_ = other.head_num_; + size_per_head_ = 
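// [Annotation] The encoder-layer weight count is now
//   real_weights_num_ = (8 + (use_gated_activation_ ? 1 : 0)) * (t5_with_bias_ ? 2 : 1);
// which the size tables above spell out per configuration:
//
//   t5_with_bias_ | use_gated_activation_ | real_weights_num_
//   --------------+-----------------------+------------------
//        no       |          no           |   8   (kernels and layernorm gammas)
//        no       |          yes          |   9   (adds the wi2 gated-activation kernel)
//        yes      |          no           |  16   (8 kernels/gammas + 8 biases/betas)
//        yes      |          yes          |  18   (9 + 9, including the wi2 bias)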
other.size_per_head_; + d_model_ = other.d_model_; + inter_size_ = other.inter_size_; tensor_para_size_ = other.tensor_para_size_; tensor_para_rank_ = other.tensor_para_rank_; - t5_with_bias_ = other.t5_with_bias_; + t5_with_bias_ = other.t5_with_bias_; real_weights_num_ = other.real_weights_num_; initialize(); mallocWeights(); @@ -193,13 +216,13 @@ void T5EncoderLayerWeight::compress_weights(cublasMMWrapper& cublas_wrapper, ffn_weights.intermediate_weight.kernel, sp_weights_ptr[4], inter_size / tensor_para_size_, d_model_); cublas_wrapper.compressMatrix( ffn_weights.output_weight.kernel, sp_weights_ptr[5], d_model_, inter_size / tensor_para_size_); - attention_weights.query_weight.sp_kernel = sp_weights_ptr[0]; - attention_weights.key_weight.sp_kernel = sp_weights_ptr[1]; - attention_weights.value_weight.sp_kernel = sp_weights_ptr[2]; + attention_weights.query_weight.sp_kernel = sp_weights_ptr[0]; + attention_weights.key_weight.sp_kernel = sp_weights_ptr[1]; + attention_weights.value_weight.sp_kernel = sp_weights_ptr[2]; attention_weights.attention_output_weight.sp_kernel = sp_weights_ptr[3]; - ffn_weights.intermediate_weight.sp_kernel = sp_weights_ptr[4]; - ffn_weights.output_weight.sp_kernel = sp_weights_ptr[5]; - is_maintain_sp_buffer = true; + ffn_weights.intermediate_weight.sp_kernel = sp_weights_ptr[4]; + ffn_weights.output_weight.sp_kernel = sp_weights_ptr[5]; + is_maintain_sp_buffer = true; FT_LOG_DEBUG("T5EncoderLayerWeight " + std::string(__func__) + " end"); } #endif @@ -208,25 +231,45 @@ template void T5EncoderLayerWeight::setWeightPtr() { FT_LOG_DEBUG("T5EncoderLayerWeight " + std::string(__func__) + " start"); - attention_weights.query_weight.kernel = weights_ptr[0]; - attention_weights.key_weight.kernel = weights_ptr[1]; - attention_weights.value_weight.kernel = weights_ptr[2]; + attention_weights.query_weight.kernel = weights_ptr[0]; + attention_weights.key_weight.kernel = weights_ptr[1]; + attention_weights.value_weight.kernel = weights_ptr[2]; attention_weights.attention_output_weight.kernel = weights_ptr[3]; - attn_layernorm_weights.gamma = weights_ptr[4]; - ffn_weights.intermediate_weight.kernel = weights_ptr[5]; - ffn_weights.output_weight.kernel = weights_ptr[6]; - ffn_layernorm_weights.gamma = weights_ptr[7]; + attn_layernorm_weights.gamma = weights_ptr[4]; + if (use_gated_activation_) { + ffn_weights.intermediate_weight.kernel = weights_ptr[5]; + ffn_weights.intermediate_weight2.kernel = weights_ptr[6]; + ffn_weights.output_weight.kernel = weights_ptr[7]; + ffn_layernorm_weights.gamma = weights_ptr[8]; + } + else { + ffn_weights.intermediate_weight.kernel = weights_ptr[5]; + ffn_weights.output_weight.kernel = weights_ptr[6]; + ffn_layernorm_weights.gamma = weights_ptr[7]; + } if (t5_with_bias_) { - attention_weights.query_weight.bias = weights_ptr[8]; - attention_weights.key_weight.bias = weights_ptr[9]; - attention_weights.value_weight.bias = weights_ptr[10]; - attention_weights.attention_output_weight.bias = weights_ptr[11]; - attn_layernorm_weights.beta = weights_ptr[12]; - ffn_weights.intermediate_weight.bias = weights_ptr[13]; - ffn_weights.output_weight.bias = weights_ptr[14]; - ffn_layernorm_weights.beta = weights_ptr[15]; - is_maintain_buffer = false; + if (use_gated_activation_) { + attention_weights.query_weight.bias = weights_ptr[9]; + attention_weights.key_weight.bias = weights_ptr[10]; + attention_weights.value_weight.bias = weights_ptr[11]; + attention_weights.attention_output_weight.bias = weights_ptr[12]; + attn_layernorm_weights.beta = 
weights_ptr[13]; + ffn_weights.intermediate_weight.bias = weights_ptr[14]; + ffn_weights.intermediate_weight2.bias = weights_ptr[15]; + ffn_weights.output_weight.bias = weights_ptr[16]; + ffn_layernorm_weights.beta = weights_ptr[17]; + } + else { + attention_weights.query_weight.bias = weights_ptr[8]; + attention_weights.key_weight.bias = weights_ptr[9]; + attention_weights.value_weight.bias = weights_ptr[10]; + attention_weights.attention_output_weight.bias = weights_ptr[11]; + attn_layernorm_weights.beta = weights_ptr[12]; + ffn_weights.intermediate_weight.bias = weights_ptr[13]; + ffn_weights.output_weight.bias = weights_ptr[14]; + ffn_layernorm_weights.beta = weights_ptr[15]; + } } is_maintain_buffer = true; @@ -252,70 +295,103 @@ void T5EncoderLayerWeight::loadModel(std::string dir_path, FtCudaDataType mod FT_CHECK(is_maintain_buffer == true); loadWeightFromBin(weights_ptr[0], - {(int)weights_size[0]}, + {weights_size[0]}, dir_path + "layer.0.SelfAttention.q.weight." + std::to_string(tensor_para_rank_) + ".bin", model_file_type); loadWeightFromBin(weights_ptr[1], - {(int)weights_size[1]}, + {weights_size[1]}, dir_path + "layer.0.SelfAttention.k.weight." + std::to_string(tensor_para_rank_) + ".bin", model_file_type); loadWeightFromBin(weights_ptr[2], - {(int)weights_size[2]}, + {weights_size[2]}, dir_path + "layer.0.SelfAttention.v.weight." + std::to_string(tensor_para_rank_) + ".bin", model_file_type); loadWeightFromBin(weights_ptr[3], - {(int)weights_size[3]}, + {weights_size[3]}, dir_path + "layer.0.SelfAttention.o.weight." + std::to_string(tensor_para_rank_) + ".bin", model_file_type); loadWeightFromBin( - weights_ptr[4], {(int)weights_size[4]}, dir_path + "layer.0.layer_norm.weight.bin", model_file_type); + weights_ptr[4], {weights_size[4]}, dir_path + "layer.0.layer_norm.weight.bin", model_file_type); + loadWeightFromBin(weights_ptr[5], - {(int)weights_size[5]}, + {weights_size[5]}, dir_path + "layer.1.DenseReluDense.wi.weight." + std::to_string(tensor_para_rank_) + ".bin", model_file_type); - loadWeightFromBin(weights_ptr[6], - {(int)weights_size[6]}, + const int gated_activation_weight_offset = use_gated_activation_ ? 1 : 0; + if (use_gated_activation_) { + loadWeightFromBin(weights_ptr[6], + {weights_size[6]}, + dir_path + "layer.1.DenseReluDense.wi2.weight." + std::to_string(tensor_para_rank_) + + ".bin", + model_file_type); + } + + loadWeightFromBin(weights_ptr[6 + gated_activation_weight_offset], + {weights_size[6 + gated_activation_weight_offset]}, dir_path + "layer.1.DenseReluDense.wo.weight." + std::to_string(tensor_para_rank_) + ".bin", model_file_type); - loadWeightFromBin( - weights_ptr[7], {(int)weights_size[7]}, dir_path + "layer.1.layer_norm.weight.bin", model_file_type); + loadWeightFromBin(weights_ptr[7 + gated_activation_weight_offset], + {weights_size[7 + gated_activation_weight_offset]}, + dir_path + "layer.1.layer_norm.weight.bin", + model_file_type); if (t5_with_bias_) { - loadWeightFromBin(weights_ptr[8], - {(int)weights_size[8]}, + loadWeightFromBin(weights_ptr[8 + gated_activation_weight_offset], + {weights_size[8 + gated_activation_weight_offset]}, dir_path + "layer.0.SelfAttention.q.bias." + std::to_string(tensor_para_rank_) + ".bin", model_file_type); - loadWeightFromBin(weights_ptr[9], - {(int)weights_size[9]}, + loadWeightFromBin(weights_ptr[9 + gated_activation_weight_offset], + {weights_size[9 + gated_activation_weight_offset]}, dir_path + "layer.0.SelfAttention.k.bias." 
+ std::to_string(tensor_para_rank_) + ".bin", model_file_type); - loadWeightFromBin(weights_ptr[10], - {(int)weights_size[10]}, + loadWeightFromBin(weights_ptr[10 + gated_activation_weight_offset], + {weights_size[10 + gated_activation_weight_offset]}, dir_path + "layer.0.SelfAttention.v.bias." + std::to_string(tensor_para_rank_) + ".bin", model_file_type); - loadWeightFromBin( - weights_ptr[11], {(int)weights_size[11]}, dir_path + "layer.0.SelfAttention.o.bias.bin", model_file_type); - loadWeightFromBin( - weights_ptr[12], {(int)weights_size[12]}, dir_path + "layer.0.layer_norm.bias.bin", model_file_type); - loadWeightFromBin(weights_ptr[13], - {(int)weights_size[13]}, + loadWeightFromBin(weights_ptr[11 + gated_activation_weight_offset], + {weights_size[11 + gated_activation_weight_offset]}, + dir_path + "layer.0.SelfAttention.o.bias.bin", + model_file_type); + loadWeightFromBin(weights_ptr[12 + gated_activation_weight_offset], + {weights_size[12 + gated_activation_weight_offset]}, + dir_path + "layer.0.layer_norm.bias.bin", + model_file_type); + loadWeightFromBin(weights_ptr[13 + gated_activation_weight_offset], + {weights_size[13 + gated_activation_weight_offset]}, dir_path + "layer.1.DenseReluDense.wi.bias." + std::to_string(tensor_para_rank_) + ".bin", model_file_type); - loadWeightFromBin( - weights_ptr[14], {(int)weights_size[14]}, dir_path + "layer.1.DenseReluDense.wo.bias.bin", model_file_type); - loadWeightFromBin( - weights_ptr[15], {(int)weights_size[15]}, dir_path + "layer.1.layer_norm.bias.bin", model_file_type); + if (use_gated_activation_) { + loadWeightFromBin(weights_ptr[15], + {weights_size[15]}, + dir_path + "layer.1.DenseReluDense.wi2.bias." + std::to_string(tensor_para_rank_) + + ".bin", + model_file_type); + loadWeightFromBin( + weights_ptr[16], {weights_size[16]}, dir_path + "layer.1.DenseReluDense.wo.bias.bin", model_file_type); + loadWeightFromBin( + weights_ptr[17], {weights_size[17]}, dir_path + "layer.1.layer_norm.bias.bin", model_file_type); + } + else { + loadWeightFromBin( + weights_ptr[14], {weights_size[14]}, dir_path + "layer.1.DenseReluDense.wo.bias.bin", model_file_type); + loadWeightFromBin( + weights_ptr[15], {weights_size[15]}, dir_path + "layer.1.layer_norm.bias.bin", model_file_type); + } } FT_LOG_DEBUG("T5EncoderLayerWeight " + std::string(__func__) + " end"); } template -void T5EncoderLayerWeight::setT5WithBias(bool t5_with_bias_para) +void T5EncoderLayerWeight::setT5WithBias(bool t5_with_bias_para, bool use_gated_activation_para) { - t5_with_bias_ = t5_with_bias_para; + t5_with_bias_ = t5_with_bias_para; + use_gated_activation_ = use_gated_activation_para; } template struct T5EncoderLayerWeight; template struct T5EncoderLayerWeight; +#ifdef ENABLE_BF16 +template struct T5EncoderLayerWeight<__nv_bfloat16>; +#endif } // namespace fastertransformer diff --git a/src/fastertransformer/models/t5/T5EncoderLayerWeight.h b/src/fastertransformer/models/t5/T5EncoderLayerWeight.h index 1afc71a84..62f5c987e 100644 --- a/src/fastertransformer/models/t5/T5EncoderLayerWeight.h +++ b/src/fastertransformer/models/t5/T5EncoderLayerWeight.h @@ -34,7 +34,8 @@ struct T5EncoderLayerWeight { const size_t inter_size, const size_t tensor_para_size, const size_t tensor_para_rank, - const bool t5_with_bias); + const bool t5_with_bias, + const bool use_gated_activation); ~T5EncoderLayerWeight(); T5EncoderLayerWeight(const T5EncoderLayerWeight& other); T5EncoderLayerWeight& operator=(const T5EncoderLayerWeight& other); @@ -45,13 +46,13 @@ struct T5EncoderLayerWeight { 
AttentionWeight attention_weights; LayerNormWeight attn_layernorm_weights; - FfnWeight ffn_weights; + FfnWeight ffn_weights; LayerNormWeight ffn_layernorm_weights; - bool t5_with_bias_; + bool t5_with_bias_; + bool use_gated_activation_; void loadModel(std::string dir_path, FtCudaDataType model_file_type); - - void setT5WithBias(bool t5_with_bias_para); + void setT5WithBias(bool t5_with_bias_para, bool use_gated_activation_para); private: void setWeightPtr(); @@ -69,12 +70,12 @@ struct T5EncoderLayerWeight { bool is_maintain_buffer = false; - // Assume bias added - const static int weights_num_ = 16; - T* weights_ptr[weights_num_]; - size_t weights_size[weights_num_]; + // Assume bias added, and gated activation used + const static int weights_num_ = 18; + T* weights_ptr[weights_num_]; + size_t weights_size[weights_num_]; - T* sp_weights_ptr[6]; + T* sp_weights_ptr[6]; bool is_maintain_sp_buffer = false; }; diff --git a/src/fastertransformer/models/t5/T5EncoderWeight.cc b/src/fastertransformer/models/t5/T5EncoderWeight.cc index 7995e0d7e..9665b4d16 100644 --- a/src/fastertransformer/models/t5/T5EncoderWeight.cc +++ b/src/fastertransformer/models/t5/T5EncoderWeight.cc @@ -20,18 +20,19 @@ namespace fastertransformer { template -T5EncoderWeight::T5EncoderWeight(const size_t head_num, - const size_t size_per_head, - const size_t d_model, - const size_t inter_size, - const size_t vocab_size, - const size_t num_layer, - const size_t num_bucket_or_max_seq_len, - const size_t tensor_para_size, - const size_t tensor_para_rank, - const size_t pipeline_para_size, - const size_t pipeline_para_rank, - const bool t5_with_bias_para, +T5EncoderWeight::T5EncoderWeight(const size_t head_num, + const size_t size_per_head, + const size_t d_model, + const size_t inter_size, + const size_t vocab_size, + const size_t num_layer, + const size_t num_bucket_or_max_seq_len, + const size_t tensor_para_size, + const size_t tensor_para_rank, + const size_t pipeline_para_size, + const size_t pipeline_para_rank, + const bool t5_with_bias_para, + const bool use_gated_activation_para, const PositionEmbeddingType pe_type): head_num_(head_num), size_per_head_(size_per_head), @@ -45,10 +46,12 @@ T5EncoderWeight::T5EncoderWeight(const size_t head_num, pipeline_para_size_(pipeline_para_size), pipeline_para_rank_(pipeline_para_rank), t5_with_bias(t5_with_bias_para), + use_gated_activation(use_gated_activation_para), position_embedding_type(pe_type), real_weights_num_(t5_with_bias ? 4 : 3) { FT_LOG_DEBUG("T5EncoderWeight " + std::string(__func__) + " start"); + FT_CHECK(num_layer_ % pipeline_para_size_ == 0); initialize(); mallocWeights(); setWeightPtr(); @@ -56,8 +59,14 @@ T5EncoderWeight::T5EncoderWeight(const size_t head_num, t5_encoder_layer_weights.reserve(num_layer_); for (int l = 0; l < num_layer_; l++) { if (isValidLayerParallelId(l)) { - t5_encoder_layer_weights.push_back(new T5EncoderLayerWeight( - head_num_, size_per_head, d_model_, inter_size_, tensor_para_size_, tensor_para_rank_, t5_with_bias)); + t5_encoder_layer_weights.push_back(new T5EncoderLayerWeight(head_num_, + size_per_head, + d_model_, + inter_size_, + tensor_para_size_, + tensor_para_rank_, + t5_with_bias, + use_gated_activation)); } else { // Don't malloc and load these layers since we don't use them. 
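Note on the gated-activation support introduced above: adding the second intermediate projection (wi2) shifts every later slot in weights_ptr/weights_size, which is why setWeightPtr() and loadModel() branch on use_gated_activation_. The standalone C++ sketch below is illustrative only (it is not part of this patch, and the slot names are assumptions read off the hunks above); it simply prints the index layout implied by those branches:

#include <cstdio>

// Sketch of the per-layer weight-slot layout assumed by the patch:
//   kernels/gammas: q, k, v, o, attn_ln.gamma, wi, [wi2], wo, ffn_ln.gamma
//   biases/betas (only when t5_with_bias_): same order, directly after the kernels.
int main()
{
    for (int gated = 0; gated <= 1; ++gated) {
        const int kernel_slots = 8 + gated;        // wi2 adds one kernel slot
        const int bias_base    = kernel_slots;     // biases start right after the kernels
        const int total_slots  = 2 * kernel_slots; // 16 without gating, 18 with it
        printf("use_gated_activation=%d: wi kernel=5%s, wo kernel=%d, q bias=%d, "
               "ffn_ln beta=%d, total slots=%d\n",
               gated,
               gated ? ", wi2 kernel=6" : "",
               6 + gated,
               bias_base,
               total_slots - 1,
               total_slots);
    }
    return 0;
}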
@@ -94,10 +103,10 @@ T5EncoderWeight::~T5EncoderWeight() } post_transformer_layernorm_weights.gamma = nullptr; - absolute_or_relative_position_embedding = nullptr; - embedding_table = nullptr; - post_transformer_layernorm_weights.beta = nullptr; - is_maintain_buffer = false; + absolute_or_relative_position_embedding = nullptr; + embedding_table = nullptr; + post_transformer_layernorm_weights.beta = nullptr; + is_maintain_buffer = false; } for (int i = 0; i < num_layer_; i++) { delete t5_encoder_layer_weights[i]; @@ -119,6 +128,7 @@ T5EncoderWeight::T5EncoderWeight(const T5EncoderWeight& other): pipeline_para_size_(other.pipeline_para_size_), pipeline_para_rank_(other.pipeline_para_rank_), t5_with_bias(other.t5_with_bias), + use_gated_activation(other.use_gated_activation), position_embedding_type(other.position_embedding_type), real_weights_num_(other.real_weights_num_) { @@ -143,20 +153,21 @@ T5EncoderWeight& T5EncoderWeight::operator=(const T5EncoderWeight& other) { FT_LOG_DEBUG("T5EncoderWeight " + std::string(__func__) + " start"); - head_num_ = other.head_num_; - size_per_head_ = other.size_per_head_; - d_model_ = other.d_model_; - inter_size_ = other.inter_size_; - vocab_size_ = other.vocab_size_; - num_layer_ = other.num_layer_; + head_num_ = other.head_num_; + size_per_head_ = other.size_per_head_; + d_model_ = other.d_model_; + inter_size_ = other.inter_size_; + vocab_size_ = other.vocab_size_; + num_layer_ = other.num_layer_; num_bucket_or_max_seq_len_ = other.num_bucket_or_max_seq_len_; - tensor_para_size_ = other.tensor_para_size_; - tensor_para_rank_ = other.tensor_para_rank_; - pipeline_para_size_ = other.pipeline_para_size_; - pipeline_para_rank_ = other.pipeline_para_rank_; - t5_with_bias = other.t5_with_bias; - position_embedding_type = other.position_embedding_type; - real_weights_num_ = other.real_weights_num_; + tensor_para_size_ = other.tensor_para_size_; + tensor_para_rank_ = other.tensor_para_rank_; + pipeline_para_size_ = other.pipeline_para_size_; + pipeline_para_rank_ = other.pipeline_para_rank_; + t5_with_bias = other.t5_with_bias; + use_gated_activation = other.use_gated_activation; + position_embedding_type = other.position_embedding_type; + real_weights_num_ = other.real_weights_num_; initialize(); mallocWeights(); for (int i = 0; i < real_weights_num_; i++) { @@ -179,8 +190,8 @@ void T5EncoderWeight::setWeightPtr() { FT_LOG_DEBUG("T5EncoderWeight " + std::string(__func__) + " start"); post_transformer_layernorm_weights.gamma = weights_ptr[0]; - absolute_or_relative_position_embedding = weights_ptr[1]; - embedding_table = weights_ptr[2]; + absolute_or_relative_position_embedding = weights_ptr[1]; + embedding_table = weights_ptr[2]; if (t5_with_bias) { post_transformer_layernorm_weights.beta = weights_ptr[3]; } @@ -206,21 +217,23 @@ void T5EncoderWeight::loadModel(std::string dir_path) FT_CHECK(is_maintain_buffer == true); loadWeightFromBin( - weights_ptr[0], {(int)weights_size[0]}, dir_path + "/encoder.final_layer_norm.weight.bin", model_file_type); + weights_ptr[0], {(size_t)weights_size[0]}, dir_path + "/encoder.final_layer_norm.weight.bin", model_file_type); if (position_embedding_type == PositionEmbeddingType::absolute) { - loadWeightFromBin(weights_ptr[1], {(int)weights_size[1]}, dir_path + "/shared.ape.bin", model_file_type); + loadWeightFromBin(weights_ptr[1], {(size_t)weights_size[1]}, dir_path + "/shared.ape.bin", model_file_type); } else { loadWeightFromBin(weights_ptr[1], - {(int)weights_size[1]}, + {(size_t)weights_size[1]}, dir_path + 
"/encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight." + std::to_string(tensor_para_rank_) + ".bin", model_file_type); } - loadWeightFromBin(weights_ptr[2], {(int)weights_size[2]}, dir_path + "/shared.weight_T.bin", model_file_type); + loadWeightFromBin(weights_ptr[2], {(size_t)weights_size[2]}, dir_path + "/shared.weight_T.bin", model_file_type); if (t5_with_bias) { - loadWeightFromBin( - weights_ptr[3], {(int)weights_size[3]}, dir_path + "/encoder.final_layer_norm.bias.bin", model_file_type); + loadWeightFromBin(weights_ptr[3], + {(size_t)weights_size[3]}, + dir_path + "/encoder.final_layer_norm.bias.bin", + model_file_type); } for (int l = 0; l < num_layer_; l++) { @@ -247,6 +260,7 @@ void T5EncoderWeight::resizeLayer(const int num_layer) FT_LOG_DEBUG("T5EncoderWeight " + std::string(__func__) + " start"); t5_encoder_layer_weights.clear(); num_layer_ = num_layer; + t5_encoder_layer_weights.reserve(num_layer_); for (int l = 0; l < num_layer_; l++) { t5_encoder_layer_weights.push_back(new T5EncoderLayerWeight()); } @@ -254,16 +268,22 @@ void T5EncoderWeight::resizeLayer(const int num_layer) } template -void T5EncoderWeight::setT5StructureDiff(bool t5_with_bias_para, PositionEmbeddingType position_embedding_type_para) +void T5EncoderWeight::setT5StructureDiff(bool t5_with_bias_para, + bool use_gated_activation_para, + PositionEmbeddingType position_embedding_type_para) { - t5_with_bias = t5_with_bias_para; + t5_with_bias = t5_with_bias_para; position_embedding_type = position_embedding_type_para; + use_gated_activation = use_gated_activation_para; for (int i = 0; i < num_layer_; i++) { - t5_encoder_layer_weights[i]->setT5WithBias(t5_with_bias_para); + t5_encoder_layer_weights[i]->setT5WithBias(t5_with_bias_para, use_gated_activation); } } template struct T5EncoderWeight; template struct T5EncoderWeight; +#ifdef ENABLE_BF16 +template struct T5EncoderWeight<__nv_bfloat16>; +#endif } // namespace fastertransformer diff --git a/src/fastertransformer/models/t5/T5EncoderWeight.h b/src/fastertransformer/models/t5/T5EncoderWeight.h index 0f6b8f394..70a5cf356 100644 --- a/src/fastertransformer/models/t5/T5EncoderWeight.h +++ b/src/fastertransformer/models/t5/T5EncoderWeight.h @@ -25,33 +25,37 @@ template struct T5EncoderWeight { T5EncoderWeight() = default; - T5EncoderWeight(const size_t head_num, - const size_t size_per_head, - const size_t d_model, - const size_t inter_size, - const size_t vocab_size, - const size_t num_layer, - const size_t num_bucket_or_max_seq_len, - const size_t tensor_para_size, - const size_t tensor_para_rank, - const size_t pipeline_para_size, - const size_t pipeline_para_rank, - const bool t5_with_bias_para = false, - const PositionEmbeddingType pe_type = PositionEmbeddingType::relative); + T5EncoderWeight(const size_t head_num, + const size_t size_per_head, + const size_t d_model, + const size_t inter_size, + const size_t vocab_size, + const size_t num_layer, + const size_t num_bucket_or_max_seq_len, + const size_t tensor_para_size, + const size_t tensor_para_rank, + const size_t pipeline_para_size, + const size_t pipeline_para_rank, + const bool t5_with_bias_para = false, + const bool use_gated_activation_para = false, + const PositionEmbeddingType pe_type = PositionEmbeddingType::relative); ~T5EncoderWeight(); T5EncoderWeight(const T5EncoderWeight& other); T5EncoderWeight& operator=(const T5EncoderWeight& other); std::vector*> t5_encoder_layer_weights; - LayerNormWeight post_transformer_layernorm_weights; - T* absolute_or_relative_position_embedding = 
nullptr; - T* embedding_table = nullptr; - bool t5_with_bias = false; - PositionEmbeddingType position_embedding_type = PositionEmbeddingType::relative; + LayerNormWeight post_transformer_layernorm_weights; + T* absolute_or_relative_position_embedding = nullptr; + T* embedding_table = nullptr; + bool t5_with_bias = false; + bool use_gated_activation = false; + PositionEmbeddingType position_embedding_type = PositionEmbeddingType::relative; void loadModel(std::string dir_path); void resizeLayer(const int num_layer); - void setT5StructureDiff(bool t5_with_bias_para, PositionEmbeddingType position_embedding_type_para); + void setT5StructureDiff(bool t5_with_bias_para, + bool use_gated_activation_para, + PositionEmbeddingType position_embedding_type_para); private: void setWeightPtr(); @@ -66,7 +70,7 @@ struct T5EncoderWeight { size_t vocab_size_; size_t num_layer_; // refer to num_buckt if using relative position embedding - // refer to max_seq_len if using absoulte position embedding + // refer to max_seq_len if using absolute position embedding size_t num_bucket_or_max_seq_len_; size_t tensor_para_size_; size_t tensor_para_rank_; @@ -78,8 +82,8 @@ struct T5EncoderWeight { int real_weights_num_; const static int weights_num_ = 4; - T* weights_ptr[weights_num_]; - size_t weights_size[weights_num_]; + T* weights_ptr[weights_num_]; + size_t weights_size[weights_num_]; }; } // namespace fastertransformer \ No newline at end of file diff --git a/src/fastertransformer/models/t5/t5_gemm.cc b/src/fastertransformer/models/t5/t5_gemm.cc index 094c1a73b..c718b9358 100644 --- a/src/fastertransformer/models/t5/t5_gemm.cc +++ b/src/fastertransformer/models/t5/t5_gemm.cc @@ -39,33 +39,33 @@ int main(int argc, char* argv[]) " tensor_para_size \\ \n" " is_fp16_compute_type \\ \n" " is_append"); - FT_LOG_ERROR("e.g. ./bin/t5_gemm 8 4 32 512 8 64 2048 512 8 64 2048 32100 1 2 1 0"); + FT_LOG_ERROR("e.g. ./bin/t5_gemm 8 4 32 512 8 64 2048 512 8 64 2048 32100 1 2 0 0"); return 0; } - const int batch_size = atoi(argv[1]); - const int beam_width = atoi(argv[2]); + const int batch_size = atoi(argv[1]); + const int beam_width = atoi(argv[2]); const int max_mem_seq_len = atoi(argv[3]); - const int encoder_d_model = atoi(argv[4]); - const int encoder_head_num = atoi(argv[5]); + const int encoder_d_model = atoi(argv[4]); + const int encoder_head_num = atoi(argv[5]); const int encoder_size_per_head = atoi(argv[6]); - const int encoder_inter_size = atoi(argv[7]); + const int encoder_inter_size = atoi(argv[7]); - const int decoder_d_model = atoi(argv[8]); - const int decoder_head_num = atoi(argv[9]); + const int decoder_d_model = atoi(argv[8]); + const int decoder_head_num = atoi(argv[9]); const int decoder_size_per_head = atoi(argv[10]); - const int decoder_inter_size = atoi(argv[11]); - const int decoder_vocab_size = atoi(argv[12]); + const int decoder_inter_size = atoi(argv[11]); + const int decoder_vocab_size = atoi(argv[12]); const ft::CublasDataType data_type = static_cast(atoi(argv[13])); // 0 FP32, 1 FP16, 2 BF 16 - const int tensor_para_size = argc <= 15 ? 1 : atoi(argv[14]); - int is_fp16_compute_type = argc <= 16 ? 0 : atoi(argv[15]); + const int tensor_para_size = argc < 15 ? 1 : atoi(argv[14]); + int is_fp16_compute_type = argc < 16 ? 0 : atoi(argv[15]); if (data_type == ft::BFLOAT16_DATATYPE && is_fp16_compute_type != 0) { printf("[ERROR] BFLOAT16_DATATYPE does not support is_fp16_compute_type = True\n"); return 0; } - const bool is_append = argc <= 17 ? 
false : (bool)(atoi(argv[16])); + const bool is_append = argc < 17 ? false : (bool)(atoi(argv[16])); std::cout << "[INFO] arguments: " << std::endl << " batch_size: " << batch_size << std::endl @@ -82,9 +82,9 @@ int main(int argc, char* argv[]) << " decoder_vocab_size: " << decoder_vocab_size << std::endl << " data_type: " << data_type << std::endl << " tensor_para_size: " << tensor_para_size << std::endl - << " is_fp16_compute_type: " << is_fp16_compute_type << std::endl; - - void* gemm_test_buf; + << " is_fp16_compute_type: " << is_fp16_compute_type << std::endl + << " is_append:" << is_append << std::endl; + void* gemm_test_buf; size_t buf_size_in_byte = ft::calT5GemmTestBufSizeInByte(batch_size, beam_width, max_mem_seq_len, @@ -170,8 +170,8 @@ int main(int argc, char* argv[]) } #endif else { - printf("[ERROR] data type only supports fp32(0), fp16(1), bf16(2). \n"); - return -1; + FT_LOG_ERROR("data type %d is invalid, only supports fp32(0), fp16(1), bf16(2).", (int)(data_type)); + ft::FT_CHECK(false); } ft::check_cuda_error(cudaFree(gemm_test_buf)); diff --git a/src/fastertransformer/models/vit/CMakeLists.txt b/src/fastertransformer/models/vit/CMakeLists.txt index 56b66af7c..295decf97 100644 --- a/src/fastertransformer/models/vit/CMakeLists.txt +++ b/src/fastertransformer/models/vit/CMakeLists.txt @@ -19,7 +19,7 @@ set_property(TARGET ViT PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET ViT PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) target_link_libraries(ViT PUBLIC -lcudart -lcublasLt -lcublas cublasMMWrapper UnfusedAttentionLayer FusedAttentionLayer FfnLayer layernorm_kernels - add_residual_kernels activation_kernels vit_kernels bert_preprocess_kernels) + add_residual_kernels activation_kernels vit_kernels bert_preprocess_kernels tensor) add_executable(vit_gemm vit_gemm.cc) -target_link_libraries(vit_gemm PUBLIC -lcublas -lcublasLt -lcudart encoder_gemm_func encoder_igemm_func memory_utils) \ No newline at end of file +target_link_libraries(vit_gemm PUBLIC -lcublas -lcublasLt -lcudart encoder_gemm_func encoder_igemm_func memory_utils tensor) diff --git a/src/fastertransformer/models/vit/ViT.cc b/src/fastertransformer/models/vit/ViT.cc index e785f2bc6..ce78204f3 100644 --- a/src/fastertransformer/models/vit/ViT.cc +++ b/src/fastertransformer/models/vit/ViT.cc @@ -49,7 +49,7 @@ void ViTTransformer::initialize() int(attention_type_)); if (img_size_ % patch_size_ != 0) { std::ostringstream buffer; - buffer << "[FT][ERROR] IMG size & PITCH size missmatch. " << img_size_ << " % " << patch_size_ << " !=0 \n"; + buffer << "[FT][ERROR] IMG size & PITCH size mismatch. 
" << img_size_ << " % " << patch_size_ << " !=0 \n"; throw std::runtime_error(buffer.str()); } @@ -68,6 +68,7 @@ void ViTTransformer::initialize() max_seq_len_, head_num_, head_dim_, + head_num_ * head_dim_, sm_, q_scaling_, stream_, @@ -111,23 +112,23 @@ void ViTTransformer::initialize() } template -ViTTransformer::ViTTransformer(size_t max_batch_size, - size_t img_size, - size_t chn_num, - size_t patch_size, - size_t embed_dim, - size_t head_num, - size_t inter_size, - size_t num_layer, - bool with_cls_token, - int sm, - float q_scaling, - cudaStream_t stream, - cudnnHandle_t cudnn_handle, +ViTTransformer::ViTTransformer(size_t max_batch_size, + size_t img_size, + size_t chn_num, + size_t patch_size, + size_t embed_dim, + size_t head_num, + size_t inter_size, + size_t num_layer, + bool with_cls_token, + int sm, + float q_scaling, + cudaStream_t stream, + cudnnHandle_t cudnn_handle, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - AttentionType attention_type): + IAllocator* allocator, + bool is_free_buffer_after_forward, + AttentionType attention_type): BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward), max_batch_size_(max_batch_size), img_size_(img_size), @@ -184,15 +185,21 @@ template void ViTTransformer::allocateBuffer() { if (is_allocate_buffer_ == false) { - embed_buf_1_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * embed_dim_, false); - embed_buf_2_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * embed_dim_, false); - embed_buf_3_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * embed_dim_, false); - mask_buf_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * max_seq_len_, false); - padding_offset_ = (int*)allocator_->malloc(sizeof(int) * max_batch_size_ * max_seq_len_, false); - token_num_ = (size_t*)allocator_->malloc(sizeof(size_t) * 1, false); - - trt_mha_padding_offset_ = (int*)allocator_->malloc(sizeof(int) * (2 * max_batch_size_ + 1), false); - seq_len_vec_ = (int*)allocator_->malloc(sizeof(int) * max_batch_size_, false); + embed_buf_1_ = + (T*)allocator_->reMalloc(embed_buf_1_, sizeof(T) * max_batch_size_ * max_seq_len_ * embed_dim_, false); + embed_buf_2_ = + (T*)allocator_->reMalloc(embed_buf_2_, sizeof(T) * max_batch_size_ * max_seq_len_ * embed_dim_, false); + embed_buf_3_ = + (T*)allocator_->reMalloc(embed_buf_3_, sizeof(T) * max_batch_size_ * max_seq_len_ * embed_dim_, false); + mask_buf_ = + (T*)allocator_->reMalloc(mask_buf_, sizeof(T) * max_batch_size_ * max_seq_len_ * max_seq_len_, false); + padding_offset_ = + (int*)allocator_->reMalloc(padding_offset_, sizeof(int) * max_batch_size_ * max_seq_len_, false); + token_num_ = (size_t*)allocator_->reMalloc(token_num_, sizeof(size_t) * 1, false); + + trt_mha_padding_offset_ = + (int*)allocator_->reMalloc(trt_mha_padding_offset_, sizeof(int) * (2 * max_batch_size_ + 1), false); + seq_len_vec_ = (int*)allocator_->reMalloc(seq_len_vec_, sizeof(int) * max_batch_size_, false); setSeqLenVec(max_batch_size_); setDefaultMask(max_batch_size_); @@ -235,22 +242,22 @@ void ViTTransformer::allocateBuffer(size_t batch_size) template void ViTTransformer::freeBuffer() { - allocator_->free(embed_buf_1_); - allocator_->free(embed_buf_2_); - allocator_->free(embed_buf_3_); - allocator_->free(mask_buf_); - allocator_->free(trt_mha_padding_offset_); - allocator_->free(seq_len_vec_); - allocator_->free(padding_offset_); - allocator_->free(token_num_); + 
allocator_->free((void**)(&embed_buf_1_)); + allocator_->free((void**)(&embed_buf_2_)); + allocator_->free((void**)(&embed_buf_3_)); + allocator_->free((void**)(&mask_buf_)); + allocator_->free((void**)(&trt_mha_padding_offset_)); + allocator_->free((void**)(&seq_len_vec_)); + allocator_->free((void**)(&padding_offset_)); + allocator_->free((void**)(&token_num_)); is_allocate_buffer_ = false; } template -void ViTTransformer::forward(std::vector* output_tensors, +void ViTTransformer::forward(std::vector* output_tensors, const std::vector* input_tensors, - const ViTWeight* weights) + const ViTWeight* weights) { // input_tensors: // input_img, BCHW [batch, chn_num, img_size, img_size] @@ -258,11 +265,11 @@ void ViTTransformer::forward(std::vector* output_tensors, // output feature_map [batch, seq_len, embed_dim] const size_t input_batch_size = input_tensors->at(0).shape[0]; - const size_t input_chn_num = input_tensors->at(0).shape[1]; - const size_t input_img_size = input_tensors->at(0).shape[2]; - const size_t patch_resol = input_img_size / patch_size_; - size_t seq_len = patch_resol * patch_resol + (with_cls_token_ ? 1 : 0); - const bool need_padding = + const size_t input_chn_num = input_tensors->at(0).shape[1]; + const size_t input_img_size = input_tensors->at(0).shape[2]; + const size_t patch_resol = input_img_size / patch_size_; + size_t seq_len = patch_resol * patch_resol + (with_cls_token_ ? 1 : 0); + const bool need_padding = (attention_type_ == AttentionType::UNFUSED_MHA && seq_len % 8 != 0 && std::is_same::value); FT_CHECK(input_img_size == img_size_); @@ -273,9 +280,9 @@ void ViTTransformer::forward(std::vector* output_tensors, FT_CHECK(output_tensors->at(0).shape.size() == 3); allocateBuffer(input_batch_size); - const T* input = (const T*)input_tensors->at(0).data; - T* output = (T*)output_tensors->at(0).data; - T* encoder_input_ptr = embed_buf_1_; + const T* input = (const T*)input_tensors->at(0).data; + T* output = (T*)output_tensors->at(0).data; + T* encoder_input_ptr = embed_buf_1_; // preprocess (patches embedding, concat class embed and add pos embed) patchEmbed(need_padding ? 
embed_buf_2_ : encoder_input_ptr, @@ -304,7 +311,7 @@ void ViTTransformer::forward(std::vector* output_tensors, else { offset_tensor_ptr = new Tensor(MEMORY_GPU, TYPE_INT32, std::vector{0}, nullptr); if (need_padding) { - seq_len = (seq_len + 7) / 8 * 8; + seq_len = (seq_len + 7) / 8 * 8; h_token_num = seq_len * input_batch_size; cudaMemsetAsync(encoder_input_ptr, 0, sizeof(T) * input_batch_size * seq_len * embed_dim_, stream_); invokeRebuildPadding( @@ -312,9 +319,9 @@ void ViTTransformer::forward(std::vector* output_tensors, } } - T* from_buf = encoder_input_ptr; - T* norm_out_buf = embed_buf_2_; - T* attn_out_buf = embed_buf_3_; + T* from_buf = encoder_input_ptr; + T* norm_out_buf = embed_buf_2_; + T* attn_out_buf = embed_buf_3_; T* encoder_out_buf = from_buf; for (uint i = 0; i < num_layer_; i++) { @@ -323,6 +330,7 @@ void ViTTransformer::forward(std::vector* output_tensors, from_buf, weights->vit_layer_weights[i].attn_layernorm_weights.gamma, weights->vit_layer_weights[i].attn_layernorm_weights.beta, + layernorm_eps_, h_token_num, embed_dim_, stream_); @@ -347,6 +355,7 @@ void ViTTransformer::forward(std::vector* output_tensors, weights->vit_layer_weights[i].ffn_layernorm_weights.gamma, weights->vit_layer_weights[i].ffn_layernorm_weights.beta, weights->vit_layer_weights[i].attention_weights.attention_output_weight.bias, + layernorm_eps_, h_token_num, embed_dim_, stream_); @@ -374,6 +383,7 @@ void ViTTransformer::forward(std::vector* output_tensors, from_buf, weights->post_transformer_layernorm_weights.gamma, weights->post_transformer_layernorm_weights.beta, + layernorm_eps_, h_token_num, embed_dim_, stream_); @@ -426,12 +436,12 @@ void ViTTransformer::setDefaultPaddingOffset(size_t batch_size) } template -void ViTTransformer::patchEmbed(T* output, - const T* input, - const T* kernel, - const T* bias, - const T* cls_embed, - const T* pos_embed, +void ViTTransformer::patchEmbed(T* output, + const T* input, + const T* kernel, + const T* bias, + const T* cls_embed, + const T* pos_embed, const int batch, const int img_size, const int patch_size, diff --git a/src/fastertransformer/models/vit/ViT.h b/src/fastertransformer/models/vit/ViT.h index 27fda8683..2faff9ed6 100644 --- a/src/fastertransformer/models/vit/ViT.h +++ b/src/fastertransformer/models/vit/ViT.h @@ -30,26 +30,27 @@ namespace fastertransformer { template class ViTTransformer: public BaseLayer { private: - size_t max_batch_size_ = 0; - size_t img_size_ = 224; - size_t chn_num_ = 3; - size_t patch_size_ = 16; // preproc patch size - size_t max_seq_len_; - size_t request_seq_len_; - size_t embed_dim_; // patch conv out units, size_per_head = embed_dim / head_num - size_t head_num_; // mha head num - size_t head_dim_; // mha head size - size_t inter_size_; // FF internal size - size_t num_layer_; - size_t nopad_token_num_; - bool with_cls_token_; - int sm_; - float q_scaling_; - AttentionType attention_type_; - cudnnHandle_t cudnn_handle_; + size_t max_batch_size_ = 0; + size_t img_size_ = 224; + size_t chn_num_ = 3; + size_t patch_size_ = 16; // preproc patch size + size_t max_seq_len_; + size_t request_seq_len_; + size_t embed_dim_; // patch conv out units, size_per_head = embed_dim / head_num + size_t head_num_; // mha head num + size_t head_dim_; // mha head size + size_t inter_size_; // FF internal size + size_t num_layer_; + size_t nopad_token_num_; + bool with_cls_token_; + int sm_; + static constexpr float layernorm_eps_ = 1e-6f; + float q_scaling_; + AttentionType attention_type_; + cudnnHandle_t cudnn_handle_; 
BaseAttentionLayer* attention_layer_; - FfnLayer* ffn_layer_; + FfnLayer* ffn_layer_; bool is_allocate_buffer_ = false; @@ -59,12 +60,12 @@ class ViTTransformer: public BaseLayer { bool setSeqLenVec(size_t batch_size); void setDefaultMask(size_t batch_size); void setDefaultPaddingOffset(size_t batch_size); - void patchEmbed(T* output, - const T* input, - const T* kernel, - const T* bias, - const T* cls_embed, - const T* pos_embed, + void patchEmbed(T* output, + const T* input, + const T* kernel, + const T* bias, + const T* cls_embed, + const T* pos_embed, const int batch, const int img_size, const int patch_size, @@ -77,33 +78,33 @@ class ViTTransformer: public BaseLayer { protected: // size_t* token_num_ = nullptr; - T* embed_buf_1_ = nullptr; - T* embed_buf_2_ = nullptr; - T* embed_buf_3_ = nullptr; - T* mask_buf_ = nullptr; - int* trt_mha_padding_offset_ = nullptr; - int* seq_len_vec_ = nullptr; - int* padding_offset_ = nullptr; - size_t* token_num_ = nullptr; + T* embed_buf_1_ = nullptr; + T* embed_buf_2_ = nullptr; + T* embed_buf_3_ = nullptr; + T* mask_buf_ = nullptr; + int* trt_mha_padding_offset_ = nullptr; + int* seq_len_vec_ = nullptr; + int* padding_offset_ = nullptr; + size_t* token_num_ = nullptr; public: - ViTTransformer(size_t max_batch_size, - size_t img_size, - size_t chn_num, - size_t patch_size, - size_t embed_dim, - size_t head_num, - size_t inter_size, - size_t num_layer, - bool with_cls_token, - int sm, - float q_scaling, - cudaStream_t stream, - cudnnHandle_t cudnn_handle, + ViTTransformer(size_t max_batch_size, + size_t img_size, + size_t chn_num, + size_t patch_size, + size_t embed_dim, + size_t head_num, + size_t inter_size, + size_t num_layer, + bool with_cls_token, + int sm, + float q_scaling, + cudaStream_t stream, + cudnnHandle_t cudnn_handle, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - AttentionType attention_type); + IAllocator* allocator, + bool is_free_buffer_after_forward, + AttentionType attention_type); ViTTransformer(ViTTransformer const& vit_layer); diff --git a/src/fastertransformer/models/vit/ViTLayerWeight.h b/src/fastertransformer/models/vit/ViTLayerWeight.h index 69f5d824f..cb5d4c78a 100644 --- a/src/fastertransformer/models/vit/ViTLayerWeight.h +++ b/src/fastertransformer/models/vit/ViTLayerWeight.h @@ -33,16 +33,16 @@ struct ViTLayerWeight { ViTLayerWeight(const int embed_dim, const int inter_size, int layer_idx, const bool hold_buffer): embed_dim_(embed_dim), inter_size_(inter_size), layer_idx_(layer_idx) { - weights_size[0] = embed_dim_ * embed_dim_; - weights_size[1] = embed_dim_; - weights_size[2] = embed_dim_ * embed_dim_; - weights_size[3] = embed_dim_; - weights_size[4] = embed_dim_ * embed_dim_; - weights_size[5] = embed_dim_; - weights_size[6] = embed_dim_ * embed_dim_; - weights_size[7] = embed_dim_; - weights_size[8] = embed_dim_; - weights_size[9] = embed_dim_; + weights_size[0] = embed_dim_ * embed_dim_; + weights_size[1] = embed_dim_; + weights_size[2] = embed_dim_ * embed_dim_; + weights_size[3] = embed_dim_; + weights_size[4] = embed_dim_ * embed_dim_; + weights_size[5] = embed_dim_; + weights_size[6] = embed_dim_ * embed_dim_; + weights_size[7] = embed_dim_; + weights_size[8] = embed_dim_; + weights_size[9] = embed_dim_; weights_size[10] = embed_dim_ * inter_size_; weights_size[11] = inter_size_; weights_size[12] = inter_size_ * embed_dim_; @@ -65,23 +65,23 @@ struct ViTLayerWeight { deviceFree(weights_ptr[i]); } - attention_weights.query_weight.kernel = nullptr; - 
attention_weights.query_weight.bias = nullptr; - attention_weights.key_weight.kernel = nullptr; - attention_weights.key_weight.bias = nullptr; - attention_weights.value_weight.kernel = nullptr; - attention_weights.value_weight.bias = nullptr; + attention_weights.query_weight.kernel = nullptr; + attention_weights.query_weight.bias = nullptr; + attention_weights.key_weight.kernel = nullptr; + attention_weights.key_weight.bias = nullptr; + attention_weights.value_weight.kernel = nullptr; + attention_weights.value_weight.bias = nullptr; attention_weights.attention_output_weight.kernel = nullptr; - attention_weights.attention_output_weight.bias = nullptr; - attn_layernorm_weights.gamma = nullptr; - attn_layernorm_weights.beta = nullptr; - ffn_weights.intermediate_weight.kernel = nullptr; - ffn_weights.intermediate_weight.bias = nullptr; - ffn_weights.output_weight.kernel = nullptr; - ffn_weights.output_weight.bias = nullptr; - ffn_layernorm_weights.gamma = nullptr; - ffn_layernorm_weights.beta = nullptr; - is_maintain_buffer = false; + attention_weights.attention_output_weight.bias = nullptr; + attn_layernorm_weights.gamma = nullptr; + attn_layernorm_weights.beta = nullptr; + ffn_weights.intermediate_weight.kernel = nullptr; + ffn_weights.intermediate_weight.bias = nullptr; + ffn_weights.output_weight.kernel = nullptr; + ffn_weights.output_weight.bias = nullptr; + ffn_layernorm_weights.gamma = nullptr; + ffn_layernorm_weights.beta = nullptr; + is_maintain_buffer = false; } } @@ -103,9 +103,9 @@ struct ViTLayerWeight { ViTLayerWeight& operator=(const ViTLayerWeight& other) { - embed_dim_ = other.embed_dim_; + embed_dim_ = other.embed_dim_; inter_size_ = other.inter_size_; - layer_idx_ = other.layer_idx_; + layer_idx_ = other.layer_idx_; memcpy(weights_size, other.weights_size, sizeof(size_t) * WEIGHT_N); if (other.is_maintain_buffer) { for (int i = 0; i < WEIGHT_N; i++) { @@ -231,78 +231,78 @@ struct ViTLayerWeight { Tensor w1{ MEMORY_GPU, dtype, std::vector{embed_dim_, embed_dim_}, attention_weights.query_weight.kernel}; - w1.save(std::string(buffer.str()) + "_q_kern.npy"); + w1.saveNpy(std::string(buffer.str()) + "_q_kern.npy"); Tensor w2{MEMORY_GPU, dtype, std::vector{embed_dim_}, attention_weights.query_weight.bias}; - w2.save(std::string(buffer.str()) + "_q_bias.npy"); + w2.saveNpy(std::string(buffer.str()) + "_q_bias.npy"); Tensor w3{MEMORY_GPU, dtype, std::vector{embed_dim_, embed_dim_}, attention_weights.key_weight.kernel}; - w3.save(std::string(buffer.str()) + "_k_kern.npy"); + w3.saveNpy(std::string(buffer.str()) + "_k_kern.npy"); Tensor w4{MEMORY_GPU, dtype, std::vector{embed_dim_}, attention_weights.key_weight.bias}; - w4.save(std::string(buffer.str()) + "_k_bias.npy"); + w4.saveNpy(std::string(buffer.str()) + "_k_bias.npy"); Tensor w5{ MEMORY_GPU, dtype, std::vector{embed_dim_, embed_dim_}, attention_weights.value_weight.kernel}; - w5.save(std::string(buffer.str()) + "_v_kern.npy"); + w5.saveNpy(std::string(buffer.str()) + "_v_kern.npy"); Tensor w6{MEMORY_GPU, dtype, std::vector{embed_dim_}, attention_weights.value_weight.bias}; - w6.save(std::string(buffer.str()) + "_v_bias.npy"); + w6.saveNpy(std::string(buffer.str()) + "_v_bias.npy"); Tensor w7{MEMORY_GPU, dtype, std::vector{embed_dim_, embed_dim_}, attention_weights.attention_output_weight.kernel}; - w7.save(std::string(buffer.str()) + "_att_o_kern.npy"); + w7.saveNpy(std::string(buffer.str()) + "_att_o_kern.npy"); Tensor w8{MEMORY_GPU, dtype, std::vector{embed_dim_}, attention_weights.attention_output_weight.bias}; - 
w8.save(std::string(buffer.str()) + "_att_o_bias.npy"); + w8.saveNpy(std::string(buffer.str()) + "_att_o_bias.npy"); Tensor w9{MEMORY_GPU, dtype, std::vector{embed_dim_}, attn_layernorm_weights.gamma}; - w9.save(std::string(buffer.str()) + "_ln0_scale.npy"); + w9.saveNpy(std::string(buffer.str()) + "_ln0_scale.npy"); Tensor w10{MEMORY_GPU, dtype, std::vector{embed_dim_}, attn_layernorm_weights.beta}; - w10.save(std::string(buffer.str()) + "_ln0_bias.npy"); + w10.saveNpy(std::string(buffer.str()) + "_ln0_bias.npy"); Tensor w11{ MEMORY_GPU, dtype, std::vector{embed_dim_, inter_size_}, ffn_weights.intermediate_weight.kernel}; - w11.save(std::string(buffer.str()) + "_ffn_inter_kern.npy"); + w11.saveNpy(std::string(buffer.str()) + "_ffn_inter_kern.npy"); Tensor w12{MEMORY_GPU, dtype, std::vector{inter_size_}, ffn_weights.intermediate_weight.bias}; - w12.save(std::string(buffer.str()) + "_ffn_inter_bias.npy"); + w12.saveNpy(std::string(buffer.str()) + "_ffn_inter_bias.npy"); Tensor w13{MEMORY_GPU, dtype, std::vector{inter_size_, embed_dim_}, ffn_weights.output_weight.kernel}; - w13.save(std::string(buffer.str()) + "_ffn_o_kern.npy"); + w13.saveNpy(std::string(buffer.str()) + "_ffn_o_kern.npy"); Tensor w14{MEMORY_GPU, dtype, std::vector{embed_dim_}, ffn_weights.output_weight.bias}; - w14.save(std::string(buffer.str()) + "_ffn_o_bias.npy"); + w14.saveNpy(std::string(buffer.str()) + "_ffn_o_bias.npy"); Tensor w15{MEMORY_GPU, dtype, std::vector{embed_dim_}, ffn_layernorm_weights.gamma}; - w15.save(std::string(buffer.str()) + "_ln2_scale.npy"); + w15.saveNpy(std::string(buffer.str()) + "_ln2_scale.npy"); Tensor w16{MEMORY_GPU, dtype, std::vector{embed_dim_}, ffn_layernorm_weights.beta}; - w16.save(std::string(buffer.str()) + "_ln2_bias.npy"); + w16.saveNpy(std::string(buffer.str()) + "_ln2_bias.npy"); } AttentionWeight attention_weights; LayerNormWeight attn_layernorm_weights; - FfnWeight ffn_weights; + FfnWeight ffn_weights; LayerNormWeight ffn_layernorm_weights; private: void setWeightPtr() { - attention_weights.query_weight.kernel = weights_ptr[0]; - attention_weights.query_weight.bias = weights_ptr[1]; - attention_weights.key_weight.kernel = weights_ptr[2]; - attention_weights.key_weight.bias = weights_ptr[3]; - attention_weights.value_weight.kernel = weights_ptr[4]; - attention_weights.value_weight.bias = weights_ptr[5]; + attention_weights.query_weight.kernel = weights_ptr[0]; + attention_weights.query_weight.bias = weights_ptr[1]; + attention_weights.key_weight.kernel = weights_ptr[2]; + attention_weights.key_weight.bias = weights_ptr[3]; + attention_weights.value_weight.kernel = weights_ptr[4]; + attention_weights.value_weight.bias = weights_ptr[5]; attention_weights.attention_output_weight.kernel = weights_ptr[6]; - attention_weights.attention_output_weight.bias = weights_ptr[7]; - attn_layernorm_weights.gamma = weights_ptr[8]; - attn_layernorm_weights.beta = weights_ptr[9]; - ffn_weights.intermediate_weight.kernel = weights_ptr[10]; - ffn_weights.intermediate_weight.bias = weights_ptr[11]; - ffn_weights.output_weight.kernel = weights_ptr[12]; - ffn_weights.output_weight.bias = weights_ptr[13]; - ffn_layernorm_weights.gamma = weights_ptr[14]; - ffn_layernorm_weights.beta = weights_ptr[15]; + attention_weights.attention_output_weight.bias = weights_ptr[7]; + attn_layernorm_weights.gamma = weights_ptr[8]; + attn_layernorm_weights.beta = weights_ptr[9]; + ffn_weights.intermediate_weight.kernel = weights_ptr[10]; + ffn_weights.intermediate_weight.bias = weights_ptr[11]; + 
ffn_weights.output_weight.kernel = weights_ptr[12]; + ffn_weights.output_weight.bias = weights_ptr[13]; + ffn_layernorm_weights.gamma = weights_ptr[14]; + ffn_layernorm_weights.beta = weights_ptr[15]; is_maintain_buffer = true; } - int embed_dim_; - int inter_size_; - int layer_idx_; - bool is_maintain_buffer = false; - T* weights_ptr[WEIGHT_N]{nullptr}; + int embed_dim_; + int inter_size_; + int layer_idx_; + bool is_maintain_buffer = false; + T* weights_ptr[WEIGHT_N]{nullptr}; size_t weights_size[WEIGHT_N]; - bool is_maintain_sp_buffer = false; + bool is_maintain_sp_buffer = false; }; #undef WEIGHT_N diff --git a/src/fastertransformer/models/vit/ViTWeight.h b/src/fastertransformer/models/vit/ViTWeight.h index 6d3b7ea66..cdfb64ee3 100644 --- a/src/fastertransformer/models/vit/ViTWeight.h +++ b/src/fastertransformer/models/vit/ViTWeight.h @@ -32,12 +32,12 @@ template struct ViTWeight { ViTWeight() = delete; - ViTWeight(const int embed_dim, - const int inter_size, - const int num_layer, - const int img_size, - const int patch_size, - const int chn_num, + ViTWeight(const int embed_dim, + const int inter_size, + const int num_layer, + const int img_size, + const int patch_size, + const int chn_num, const bool with_cls_token, const bool hold_buffer = true): with_cls_token_(with_cls_token), @@ -85,12 +85,12 @@ struct ViTWeight { } post_transformer_layernorm_weights.gamma = nullptr; - post_transformer_layernorm_weights.beta = nullptr; - pre_transform_embeds.class_embed = nullptr; - pre_transform_embeds.position_embed = nullptr; - pre_encoder_conv_weights.kernel = nullptr; - pre_encoder_conv_weights.bias = nullptr; - is_maintain_buffer = false; + post_transformer_layernorm_weights.beta = nullptr; + pre_transform_embeds.class_embed = nullptr; + pre_transform_embeds.position_embed = nullptr; + pre_encoder_conv_weights.kernel = nullptr; + pre_encoder_conv_weights.bias = nullptr; + is_maintain_buffer = false; } } @@ -124,13 +124,13 @@ struct ViTWeight { ViTWeight& operator=(const ViTWeight& other) { - embed_dim_ = other.embed_dim_; - inter_size_ = other.inter_size_; - num_layer_ = other.num_layer_; - img_size_ = other.img_size_; - patch_size_ = other.patch_size_; - chn_num_ = other.chn_num_; - seq_len_ = other.seq_len_; + embed_dim_ = other.embed_dim_; + inter_size_ = other.inter_size_; + num_layer_ = other.num_layer_; + img_size_ = other.img_size_; + patch_size_ = other.patch_size_; + chn_num_ = other.chn_num_; + seq_len_ = other.seq_len_; with_cls_token_ = other.with_cls_token_; memcpy(weights_size, other.weights_size, sizeof(size_t) * WEIGHT_N); @@ -255,20 +255,20 @@ struct ViTWeight { dtype, std::vector{embed_dim_, chn_num_, patch_size_, patch_size_}, pre_encoder_conv_weights.kernel}; // OIHW - conv_kernel.save("./weights/conv_kernel.npy"); + conv_kernel.saveNpy("./weights/conv_kernel.npy"); Tensor conv_bias{MEMORY_GPU, dtype, std::vector{embed_dim_}, pre_encoder_conv_weights.bias}; // OIHW - conv_bias.save("./weights/conv_bias.npy"); + conv_bias.saveNpy("./weights/conv_bias.npy"); Tensor cls_token{MEMORY_GPU, dtype, std::vector{embed_dim_}, pre_transform_embeds.class_embed}; if (with_cls_token_) { - cls_token.save("./weights/cls_token.npy"); + cls_token.saveNpy("./weights/cls_token.npy"); } Tensor pos_embed{ MEMORY_GPU, dtype, std::vector{1, seq_len_, embed_dim_}, pre_transform_embeds.position_embed}; - pos_embed.save("./weights/pos_embed.npy"); + pos_embed.saveNpy("./weights/pos_embed.npy"); Tensor ln_gamma{MEMORY_GPU, dtype, std::vector{embed_dim_}, post_transformer_layernorm_weights.gamma}; - 
ln_gamma.save("./weights/enc_ln_scale.npy"); + ln_gamma.saveNpy("./weights/enc_ln_scale.npy"); Tensor ln_beta{MEMORY_GPU, dtype, std::vector{embed_dim_}, post_transformer_layernorm_weights.beta}; - ln_beta.save("./weights/enc_ln_bias.npy"); + ln_beta.saveNpy("./weights/enc_ln_bias.npy"); for (int i = 0; i < num_layer_; i++) { vit_layer_weights[i].ExportWeights(i); @@ -276,32 +276,32 @@ struct ViTWeight { } std::vector> vit_layer_weights; - LayerNormWeight post_transformer_layernorm_weights; - ViTEmbeds pre_transform_embeds; - DenseWeight pre_encoder_conv_weights; - bool with_cls_token_; + LayerNormWeight post_transformer_layernorm_weights; + ViTEmbeds pre_transform_embeds; + DenseWeight pre_encoder_conv_weights; + bool with_cls_token_; private: void setWeightPtr() { - pre_encoder_conv_weights.kernel = weights_ptr[0]; - pre_encoder_conv_weights.bias = weights_ptr[1]; - pre_transform_embeds.class_embed = weights_ptr[2]; - pre_transform_embeds.position_embed = weights_ptr[3]; + pre_encoder_conv_weights.kernel = weights_ptr[0]; + pre_encoder_conv_weights.bias = weights_ptr[1]; + pre_transform_embeds.class_embed = weights_ptr[2]; + pre_transform_embeds.position_embed = weights_ptr[3]; post_transformer_layernorm_weights.gamma = weights_ptr[4]; - post_transformer_layernorm_weights.beta = weights_ptr[5]; + post_transformer_layernorm_weights.beta = weights_ptr[5]; is_maintain_buffer = true; } - int embed_dim_; - int inter_size_; - int num_layer_; - int img_size_; - int patch_size_; - int chn_num_; - int seq_len_; - bool is_maintain_buffer = false; - T* weights_ptr[WEIGHT_N]{nullptr}; + int embed_dim_; + int inter_size_; + int num_layer_; + int img_size_; + int patch_size_; + int chn_num_; + int seq_len_; + bool is_maintain_buffer = false; + T* weights_ptr[WEIGHT_N]{nullptr}; size_t weights_size[WEIGHT_N]; }; diff --git a/src/fastertransformer/models/vit/vit_gemm.cc b/src/fastertransformer/models/vit/vit_gemm.cc index 00281561c..7cd74305d 100644 --- a/src/fastertransformer/models/vit/vit_gemm.cc +++ b/src/fastertransformer/models/vit/vit_gemm.cc @@ -29,14 +29,14 @@ int main(int argc, char* argv[]) return 0; } - const int batch_size = atoi(argv[1]); - const int img_size = atoi(argv[2]); - const int patch_size = atoi(argv[3]); - const int embed_dim = atoi(argv[4]); - const int head_num = atoi(argv[5]); - const int with_cls_token = atoi(argv[6]); + const int batch_size = atoi(argv[1]); + const int img_size = atoi(argv[2]); + const int patch_size = atoi(argv[3]); + const int embed_dim = atoi(argv[4]); + const int head_num = atoi(argv[5]); + const int with_cls_token = atoi(argv[6]); const ft::CublasDataType data_type = static_cast(atoi(argv[7])); // 0 FP32, 1 FP16, 2 BF 16 - const int int8_mode = atoi(argv[8]); + const int int8_mode = atoi(argv[8]); printf("[INFO] arguments: \n"); printf(" batch_size: %d \n", batch_size); @@ -58,7 +58,7 @@ int main(int argc, char* argv[]) } const int patch_resol = img_size / patch_size; - int seq_len = patch_resol * patch_resol + (with_cls_token != 0 ? 1 : 0); + int seq_len = patch_resol * patch_resol + (with_cls_token != 0 ? 
1 : 0); if (atoi(argv[7]) == 1 && seq_len > 384 && seq_len % 8 != 0) { seq_len = (seq_len + 7) / 8 * 8; } @@ -68,7 +68,7 @@ int main(int argc, char* argv[]) std::cout << std::endl; - void* gemm_test_buf; + void* gemm_test_buf; size_t buf_size_in_byte = ft::calGemmTestBufSizeInByte(batch_size, seq_len, head_num, size_per_head, inter_size, 0, int8_mode, data_type); size_t total, free; diff --git a/src/fastertransformer/models/vit_int8/CMakeLists.txt b/src/fastertransformer/models/vit_int8/CMakeLists.txt index dffc48d55..02ac9640f 100644 --- a/src/fastertransformer/models/vit_int8/CMakeLists.txt +++ b/src/fastertransformer/models/vit_int8/CMakeLists.txt @@ -20,4 +20,4 @@ set_property(TARGET ViTINT8 PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) target_link_libraries(ViTINT8 PUBLIC -lcudart -lcublasLt -lcublas cublasINT8MMWrapper UnfusedAttentionLayerINT8 FusedAttentionLayerINT8 FfnLayerINT8 layernorm_kernels layernorm_int8_kernels add_residual_kernels activation_kernels layout_transformer_int8_kernels - vit_kernels bert_preprocess_kernels) + vit_kernels bert_preprocess_kernels tensor) diff --git a/src/fastertransformer/models/vit_int8/ViTINT8.cc b/src/fastertransformer/models/vit_int8/ViTINT8.cc index f61078511..575040c8f 100644 --- a/src/fastertransformer/models/vit_int8/ViTINT8.cc +++ b/src/fastertransformer/models/vit_int8/ViTINT8.cc @@ -34,7 +34,7 @@ void ViTTransformerINT8::initialize() { if (img_size_ % patch_size_ != 0) { std::ostringstream buffer; - buffer << "[FT][ERROR] IMG size & PITCH size missmatch. " << img_size_ << " % " << patch_size_ << " !=0 \n"; + buffer << "[FT][ERROR] IMG size & PITCH size mismatch. " << img_size_ << " % " << patch_size_ << " !=0 \n"; throw std::runtime_error(buffer.str()); } @@ -98,24 +98,24 @@ void ViTTransformerINT8::initialize() } template -ViTTransformerINT8::ViTTransformerINT8(size_t max_batch_size, - size_t img_size, - size_t chn_num, - size_t patch_size, - size_t embed_dim, - size_t head_num, - size_t inter_size, - size_t num_layer, - bool with_cls_token, - int sm, - float q_scaling, - int int8_mode, - cudaStream_t stream, - cudnnHandle_t cudnn_handle, +ViTTransformerINT8::ViTTransformerINT8(size_t max_batch_size, + size_t img_size, + size_t chn_num, + size_t patch_size, + size_t embed_dim, + size_t head_num, + size_t inter_size, + size_t num_layer, + bool with_cls_token, + int sm, + float q_scaling, + int int8_mode, + cudaStream_t stream, + cudnnHandle_t cudnn_handle, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - AttentionType attention_type): + IAllocator* allocator, + bool is_free_buffer_after_forward, + AttentionType attention_type): BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward), max_batch_size_(max_batch_size), img_size_(img_size), @@ -174,16 +174,23 @@ template void ViTTransformerINT8::allocateBuffer() { if (is_allocate_buffer_ == false) { - embed_buf_1_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * embed_dim_, false); - embed_buf_2_ = (T*)allocator_->malloc(sizeof(int32_t) * max_batch_size_ * max_seq_len_ * embed_dim_, false); - embed_buf_3_ = (T*)allocator_->malloc(sizeof(int32_t) * max_batch_size_ * max_seq_len_ * embed_dim_, false); - embed_buf_4_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * embed_dim_, false); - mask_buf_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * max_seq_len_, false); - padding_offset_ = (int*)allocator_->malloc(sizeof(int) * max_batch_size_ * max_seq_len_, false); - 
token_num_ = (size_t*)allocator_->malloc(sizeof(size_t) * 1, false); - - trt_mha_padding_offset_ = (int*)allocator_->malloc(sizeof(int) * (2 * max_batch_size_ + 1), false); - seq_len_vec_ = (int*)allocator_->malloc(sizeof(int) * max_batch_size_, false); + embed_buf_1_ = + (T*)allocator_->reMalloc(embed_buf_1_, sizeof(T) * max_batch_size_ * max_seq_len_ * embed_dim_, false); + embed_buf_2_ = (T*)allocator_->reMalloc( + embed_buf_2_, sizeof(int32_t) * max_batch_size_ * max_seq_len_ * embed_dim_, false); + embed_buf_3_ = (T*)allocator_->reMalloc( + embed_buf_3_, sizeof(int32_t) * max_batch_size_ * max_seq_len_ * embed_dim_, false); + embed_buf_4_ = + (T*)allocator_->reMalloc(embed_buf_4_, sizeof(T) * max_batch_size_ * max_seq_len_ * embed_dim_, false); + mask_buf_ = + (T*)allocator_->reMalloc(mask_buf_, sizeof(T) * max_batch_size_ * max_seq_len_ * max_seq_len_, false); + padding_offset_ = + (int*)allocator_->reMalloc(padding_offset_, sizeof(int) * max_batch_size_ * max_seq_len_, false); + token_num_ = (size_t*)allocator_->reMalloc(token_num_, sizeof(size_t) * 1, false); + + trt_mha_padding_offset_ = + (int*)allocator_->reMalloc(trt_mha_padding_offset_, sizeof(int) * (2 * max_batch_size_ + 1), false); + seq_len_vec_ = (int*)allocator_->reMalloc(seq_len_vec_, sizeof(int) * max_batch_size_, false); setSeqLenVec(max_batch_size_); setDefaultMask(max_batch_size_); @@ -206,14 +213,14 @@ void ViTTransformerINT8::allocateBuffer(size_t batch_size) return; } - batch_size = batch_size > max_batch_size_ ? batch_size : max_batch_size_; + batch_size = batch_size > max_batch_size_ ? batch_size : max_batch_size_; embed_buf_1_ = (T*)allocator_->reMalloc(embed_buf_1_, sizeof(T) * batch_size * max_seq_len_ * embed_dim_, false); embed_buf_2_ = (T*)allocator_->reMalloc(embed_buf_2_, sizeof(int32_t) * batch_size * max_seq_len_ * embed_dim_, false); embed_buf_3_ = (T*)allocator_->reMalloc(embed_buf_3_, sizeof(int32_t) * batch_size * max_seq_len_ * embed_dim_, false); embed_buf_4_ = (T*)allocator_->reMalloc(embed_buf_4_, sizeof(T) * batch_size * max_seq_len_ * embed_dim_, false); - mask_buf_ = (T*)allocator_->reMalloc(mask_buf_, sizeof(T) * batch_size * max_seq_len_ * max_seq_len_, false); + mask_buf_ = (T*)allocator_->reMalloc(mask_buf_, sizeof(T) * batch_size * max_seq_len_ * max_seq_len_, false); REMALLOC(padding_offset_, sizeof(int) * batch_size * max_seq_len_); REMALLOC(token_num_, sizeof(size_t) * 1); trt_mha_padding_offset_ = @@ -231,23 +238,23 @@ void ViTTransformerINT8::allocateBuffer(size_t batch_size) template void ViTTransformerINT8::freeBuffer() { - allocator_->free(embed_buf_1_); - allocator_->free(embed_buf_2_); - allocator_->free(embed_buf_3_); - allocator_->free(embed_buf_4_); - allocator_->free(mask_buf_); - allocator_->free(trt_mha_padding_offset_); - allocator_->free(seq_len_vec_); - allocator_->free(padding_offset_); - allocator_->free(token_num_); + allocator_->free((void**)(&embed_buf_1_)); + allocator_->free((void**)(&embed_buf_2_)); + allocator_->free((void**)(&embed_buf_3_)); + allocator_->free((void**)(&embed_buf_4_)); + allocator_->free((void**)(&mask_buf_)); + allocator_->free((void**)(&trt_mha_padding_offset_)); + allocator_->free((void**)(&seq_len_vec_)); + allocator_->free((void**)(&padding_offset_)); + allocator_->free((void**)(&token_num_)); is_allocate_buffer_ = false; } template -void ViTTransformerINT8::forward(std::vector* output_tensors, +void ViTTransformerINT8::forward(std::vector* output_tensors, const std::vector* input_tensors, - const ViTINT8Weight* weights) + const 
ViTINT8Weight* weights) { // input_tensors: // input_img, BCHW [batch, chn_num, img_size, img_size] @@ -255,11 +262,11 @@ void ViTTransformerINT8::forward(std::vector* output_tensors, // output classification [batch, seq_len, embed_dim] const size_t input_batch_size = input_tensors->at(0).shape[0]; - const size_t input_chn_num = input_tensors->at(0).shape[1]; - const size_t input_img_size = input_tensors->at(0).shape[2]; - const size_t patch_resol = input_img_size / patch_size_; - size_t seq_len = patch_resol * patch_resol + (with_cls_token_ ? 1 : 0); - const bool need_padding = + const size_t input_chn_num = input_tensors->at(0).shape[1]; + const size_t input_img_size = input_tensors->at(0).shape[2]; + const size_t patch_resol = input_img_size / patch_size_; + size_t seq_len = patch_resol * patch_resol + (with_cls_token_ ? 1 : 0); + const bool need_padding = (attention_type_ == AttentionType::UNFUSED_MHA && seq_len % 32 != 0 && std::is_same::value); FT_CHECK(input_img_size == img_size_); @@ -270,9 +277,9 @@ void ViTTransformerINT8::forward(std::vector* output_tensors, FT_CHECK(output_tensors->at(0).shape.size() == 3); allocateBuffer(input_batch_size); - const T* input = (const T*)input_tensors->at(0).data; - T* output = (T*)output_tensors->at(0).data; - T* encoder_input_ptr = embed_buf_1_; + const T* input = (const T*)input_tensors->at(0).data; + T* output = (T*)output_tensors->at(0).data; + T* encoder_input_ptr = embed_buf_1_; // preprocess (patches embedding, concat class embed and add pos embed) patchEmbed(need_padding ? embed_buf_2_ : encoder_input_ptr, @@ -288,11 +295,11 @@ void ViTTransformerINT8::forward(std::vector* output_tensors, input_chn_num, embed_dim_); - DataType data_type = getTensorType(); - size_t h_token_num = input_batch_size * seq_len; - T* norm_out_buf = embed_buf_2_; - T* attn_out_buf = embed_buf_3_; - T* encoder_out_buf = embed_buf_1_; + DataType data_type = getTensorType(); + size_t h_token_num = input_batch_size * seq_len; + T* norm_out_buf = embed_buf_2_; + T* attn_out_buf = embed_buf_3_; + T* encoder_out_buf = embed_buf_1_; // get offsets Tensor* offset_tensor_ptr; @@ -304,7 +311,7 @@ void ViTTransformerINT8::forward(std::vector* output_tensors, else { offset_tensor_ptr = new Tensor(MEMORY_GPU, TYPE_INT32, std::vector{0}, nullptr); if (need_padding) { - seq_len = (seq_len + 31) / 32 * 32; + seq_len = (seq_len + 31) / 32 * 32; h_token_num = seq_len * input_batch_size; cudaMemsetAsync(encoder_input_ptr, 0, sizeof(T) * input_batch_size * seq_len * embed_dim_, stream_); invokeRebuildPadding( @@ -411,6 +418,7 @@ void ViTTransformerINT8::forward(std::vector* output_tensors, attn_out_buf, weights->post_transformer_layernorm_weights.gamma, weights->post_transformer_layernorm_weights.beta, + layernorm_eps_, h_token_num, embed_dim_, stream_); @@ -473,12 +481,12 @@ void ViTTransformerINT8::setDefaultPaddingOffset(size_t batch_size) } template -void ViTTransformerINT8::patchEmbed(T* output, - const T* input, - const T* kernel, - const T* bias, - const T* cls_embed, - const T* pos_embed, +void ViTTransformerINT8::patchEmbed(T* output, + const T* input, + const T* kernel, + const T* bias, + const T* cls_embed, + const T* pos_embed, const int batch, const int img_size, const int patch_size, diff --git a/src/fastertransformer/models/vit_int8/ViTINT8.h b/src/fastertransformer/models/vit_int8/ViTINT8.h index ef8b8840f..096587280 100644 --- a/src/fastertransformer/models/vit_int8/ViTINT8.h +++ b/src/fastertransformer/models/vit_int8/ViTINT8.h @@ -30,28 +30,29 @@ namespace 
fastertransformer { template class ViTTransformerINT8: public BaseLayer { private: - size_t max_batch_size_ = 0; - size_t img_size_ = 224; - size_t chn_num_ = 3; - size_t class_num_ = 1000; - size_t patch_size_ = 16; // preproc patch size - size_t max_seq_len_; - size_t request_seq_len_; - size_t embed_dim_; // patch conv out units, size_per_head = embed_dim / head_num - size_t head_num_; // mha head num - size_t head_dim_; // mha head size - size_t inter_size_; // FF internal size - size_t num_layer_; - size_t nopad_token_num_; - bool with_cls_token_; - int sm_; - float q_scaling_; - AttentionType attention_type_; - int int8_mode_; - cudnnHandle_t cudnn_handle_; + size_t max_batch_size_ = 0; + size_t img_size_ = 224; + size_t chn_num_ = 3; + size_t class_num_ = 1000; + size_t patch_size_ = 16; // preproc patch size + size_t max_seq_len_; + size_t request_seq_len_; + size_t embed_dim_; // patch conv out units, size_per_head = embed_dim / head_num + size_t head_num_; // mha head num + size_t head_dim_; // mha head size + size_t inter_size_; // FF internal size + size_t num_layer_; + size_t nopad_token_num_; + bool with_cls_token_; + int sm_; + float q_scaling_; + static constexpr float layernorm_eps_ = 1e-6f; + AttentionType attention_type_; + int int8_mode_; + cudnnHandle_t cudnn_handle_; BaseAttentionLayer* attention_layer_; - FfnLayerINT8* ffn_layer_; + FfnLayerINT8* ffn_layer_; bool is_allocate_buffer_ = false; @@ -62,12 +63,12 @@ class ViTTransformerINT8: public BaseLayer { bool setSeqLenVec(size_t batch_size); void setDefaultMask(size_t batch_size); void setDefaultPaddingOffset(size_t batch_size); - void patchEmbed(T* output, - const T* input, - const T* kernel, - const T* bias, - const T* cls_embed, - const T* pos_embed, + void patchEmbed(T* output, + const T* input, + const T* kernel, + const T* bias, + const T* cls_embed, + const T* pos_embed, const int batch, const int img_size, const int patch_size, @@ -80,43 +81,43 @@ class ViTTransformerINT8: public BaseLayer { protected: // size_t* token_num_ = nullptr; - T* embed_buf_1_ = nullptr; - T* embed_buf_2_ = nullptr; - T* embed_buf_3_ = nullptr; - T* embed_buf_4_ = nullptr; - T* mask_buf_ = nullptr; - int* trt_mha_padding_offset_ = nullptr; - int* seq_len_vec_ = nullptr; - int* padding_offset_ = nullptr; - size_t* token_num_ = nullptr; + T* embed_buf_1_ = nullptr; + T* embed_buf_2_ = nullptr; + T* embed_buf_3_ = nullptr; + T* embed_buf_4_ = nullptr; + T* mask_buf_ = nullptr; + int* trt_mha_padding_offset_ = nullptr; + int* seq_len_vec_ = nullptr; + int* padding_offset_ = nullptr; + size_t* token_num_ = nullptr; public: - ViTTransformerINT8(size_t max_batch_size, - size_t img_size, - size_t chn_num, - size_t patch_size, - size_t embed_dim, - size_t head_num, - size_t inter_size, - size_t num_layer, - bool with_cls_token, - int sm, - float q_scaling, - int int8_mode, - cudaStream_t stream, - cudnnHandle_t cudnn_handle, + ViTTransformerINT8(size_t max_batch_size, + size_t img_size, + size_t chn_num, + size_t patch_size, + size_t embed_dim, + size_t head_num, + size_t inter_size, + size_t num_layer, + bool with_cls_token, + int sm, + float q_scaling, + int int8_mode, + cudaStream_t stream, + cudnnHandle_t cudnn_handle, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - AttentionType attention_type); + IAllocator* allocator, + bool is_free_buffer_after_forward, + AttentionType attention_type); ViTTransformerINT8(ViTTransformerINT8 const& vit_layer); ~ViTTransformerINT8(); - void 
forward(std::vector* output_tensors, + void forward(std::vector* output_tensors, const std::vector* input_tensors, - const ViTINT8Weight* weights); + const ViTINT8Weight* weights); }; } // namespace fastertransformer diff --git a/src/fastertransformer/models/vit_int8/ViTINT8Weight.h b/src/fastertransformer/models/vit_int8/ViTINT8Weight.h index e5451ad4f..cdd9c0230 100644 --- a/src/fastertransformer/models/vit_int8/ViTINT8Weight.h +++ b/src/fastertransformer/models/vit_int8/ViTINT8Weight.h @@ -27,42 +27,50 @@ struct ViTEmbeds { const T* position_embed; }; +#define WEIGHT_N 6 template struct ViTINT8Weight { ViTINT8Weight() = default; - ViTINT8Weight(const int embed_dim, - const int inter_size, - const int num_layer, - const int img_size, - const int patch_size, - const int chn_num, - const bool with_cls_token): + ViTINT8Weight(const int embed_dim, + const int inter_size, + const int num_layer, + const int img_size, + const int patch_size, + const int chn_num, + const bool with_cls_token, + const bool hold_buffer = true): with_cls_token_(with_cls_token), embed_dim_(embed_dim), inter_size_(inter_size), num_layer_(num_layer), img_size_(img_size), patch_size_(patch_size), - chn_num_(chn_num) + chn_num_(chn_num), + seq_len_(img_size_ * img_size_ / (patch_size_ * patch_size_) + (with_cls_token_ ? 1 : 0)) { - deviceMalloc(&weights_ptr[0], embed_dim_); - deviceMalloc(&weights_ptr[1], embed_dim_); - if (with_cls_token) { - deviceMalloc(&weights_ptr[2], embed_dim_); // pre_transform_embeds.class_embed - } - deviceMalloc(&weights_ptr[3], - embed_dim_ - * (img_size_ * img_size_ / (patch_size_ * patch_size_) - + (with_cls_token ? 1 : 0))); // pre_transform_embeds.position_embed - deviceMalloc(&weights_ptr[4], - chn_num_ * patch_size_ * patch_size_ * embed_dim_); // pre_encoder_conv_weights.kernel - deviceMalloc(&weights_ptr[5], embed_dim_); // pre_encoder_conv_weights.bias - - setWeightPtr(); + weights_size[0] = chn_num_ * patch_size_ * patch_size_ * embed_dim_; + weights_size[1] = embed_dim_; + weights_size[2] = with_cls_token_ ? 
embed_dim_ : 0; + weights_size[3] = embed_dim_ * seq_len_; + weights_size[4] = embed_dim_; + weights_size[5] = embed_dim_; + + if (hold_buffer) { + for (int i = 0; i < WEIGHT_N; i++) { + if (weights_size[i] == 0) { + continue; + } + + deviceMalloc(&weights_ptr[i], weights_size[i]); + } + + setWeightPtr(); + } vit_layer_weights.reserve(num_layer_); + for (int i = 0; i < num_layer_; i++) { - vit_layer_weights.push_back(ViTLayerINT8Weight(embed_dim_, inter_size_)); + vit_layer_weights.push_back(ViTLayerINT8Weight(embed_dim_, inter_size_, i, hold_buffer)); } } @@ -70,17 +78,18 @@ struct ViTINT8Weight { { if (is_maintain_buffer == true) { vit_layer_weights.clear(); - for (int i = 0; i < 6; i++) { - deviceFree(weights_ptr[i]); + for (int i = 0; i < WEIGHT_N; i++) { + if (weights_ptr[i] != nullptr) + deviceFree(weights_ptr[i]); } post_transformer_layernorm_weights.gamma = nullptr; - post_transformer_layernorm_weights.beta = nullptr; - pre_transform_embeds.class_embed = nullptr; - pre_transform_embeds.position_embed = nullptr; - pre_encoder_conv_weights.kernel = nullptr; - pre_encoder_conv_weights.bias = nullptr; - is_maintain_buffer = false; + post_transformer_layernorm_weights.beta = nullptr; + pre_transform_embeds.class_embed = nullptr; + pre_transform_embeds.position_embed = nullptr; + pre_encoder_conv_weights.kernel = nullptr; + pre_encoder_conv_weights.bias = nullptr; + is_maintain_buffer = false; } } @@ -92,103 +101,208 @@ struct ViTINT8Weight { img_size_(other.img_size_), patch_size_(other.patch_size_), chn_num_(other.chn_num_), - cls_num_(other.cls_num_) + seq_len_(other.seq_len_) { + memcpy(weights_size, other.weights_size, sizeof(size_t) * WEIGHT_N); + if (other.is_maintain_buffer) { + for (int i = 0; i < WEIGHT_N; i++) { + if (!is_maintain_buffer) { + deviceMalloc(&weights_ptr[i], weights_size[i]); + } + cudaD2Dcpy(weights_ptr[i], other.weights_ptr[i], weights_size[i]); + } + setWeightPtr(); + } + vit_layer_weights.clear(); vit_layer_weights.reserve(num_layer_); for (int i = 0; i < num_layer_; i++) { vit_layer_weights.push_back(other.vit_layer_weights[i]); } - deviceMalloc(&weights_ptr[0], embed_dim_); - cudaD2Dcpy(weights_ptr[0], other.weights_ptr[0], embed_dim_); - deviceMalloc(&weights_ptr[1], embed_dim_); - cudaD2Dcpy(weights_ptr[1], other.weights_ptr[1], embed_dim_); - if (other.weights_ptr[2] != nullptr) { - deviceMalloc(&weights_ptr[2], embed_dim_); // pre_transform_embeds.class_embed - cudaD2Dcpy(weights_ptr[2], other.weights_ptr[2], embed_dim_); - } - deviceMalloc(&weights_ptr[3], - embed_dim_ - * (img_size_ * img_size_ / (patch_size_ * patch_size_) - + (with_cls_token_ ? 1 : 0))); // pre_transform_embeds.position_embed - cudaD2Dcpy(weights_ptr[3], - other.weights_ptr[3], - embed_dim_ * (img_size_ * img_size_ / (patch_size_ * patch_size_) + (with_cls_token_ ? 
1 : 0))); - deviceMalloc(&weights_ptr[4], - chn_num_ * patch_size_ * patch_size_ * embed_dim_); // pre_encoder_conv_weights.kernel - cudaD2Dcpy(weights_ptr[4], other.weights_ptr[4], chn_num_ * patch_size_ * patch_size_ * embed_dim_); - deviceMalloc(&weights_ptr[5], embed_dim_); // pre_encoder_conv_weights.bias - cudaD2Dcpy(weights_ptr[5], other.weights_ptr[5], embed_dim_); - - setWeightPtr(); } ViTINT8Weight& operator=(const ViTINT8Weight& other) { + embed_dim_ = other.embed_dim_; + inter_size_ = other.inter_size_; + num_layer_ = other.num_layer_; + img_size_ = other.img_size_; + patch_size_ = other.patch_size_; + chn_num_ = other.chn_num_; + seq_len_ = other.seq_len_; with_cls_token_ = other.with_cls_token_; - embed_dim_ = other.embed_dim_; - inter_size_ = other.inter_size_; - num_layer_ = other.num_layer_; - img_size_ = other.img_size_; - patch_size_ = other.patch_size_; - chn_num_ = other.chn_num_; - cls_num_ = other.cls_num_; + memcpy(weights_size, other.weights_size, sizeof(size_t) * WEIGHT_N); + + if (other.is_maintain_buffer) { + for (int i = 0; i < WEIGHT_N; i++) { + if (!is_maintain_buffer) { + deviceMalloc(&weights_ptr[i], weights_size[i]); + } + cudaD2Dcpy(weights_ptr[i], other.weights_ptr[i], weights_size[i]); + } + setWeightPtr(); + } + vit_layer_weights.clear(); vit_layer_weights.reserve(num_layer_); for (int i = 0; i < num_layer_; i++) { vit_layer_weights.push_back(other.vit_layer_weights[i]); } - deviceMalloc(&weights_ptr[0], embed_dim_); - cudaD2Dcpy(weights_ptr[0], other.weights_ptr[0], embed_dim_); - deviceMalloc(&weights_ptr[1], embed_dim_); - cudaD2Dcpy(weights_ptr[1], other.weights_ptr[1], embed_dim_); - if (other.weights_ptr[2] != nullptr) { - deviceMalloc(&weights_ptr[2], embed_dim_); // pre_transform_embeds.class_embed - cudaD2Dcpy(weights_ptr[2], other.weights_ptr[2], embed_dim_); - } - deviceMalloc(&weights_ptr[3], - embed_dim_ - * (img_size_ * img_size_ / (patch_size_ * patch_size_) - + (with_cls_token_ ? 1 : 0))); // pre_transform_embeds.position_embed - cudaD2Dcpy(weights_ptr[3], - other.weights_ptr[3], - embed_dim_ * (img_size_ * img_size_ / (patch_size_ * patch_size_) + (with_cls_token_ ? 
1 : 0))); - deviceMalloc(&weights_ptr[4], - chn_num_ * patch_size_ * patch_size_ * embed_dim_); // pre_encoder_conv_weights.kernel - cudaD2Dcpy(weights_ptr[4], other.weights_ptr[4], chn_num_ * patch_size_ * patch_size_ * embed_dim_); - deviceMalloc(&weights_ptr[5], embed_dim_); // pre_encoder_conv_weights.bias - cudaD2Dcpy(weights_ptr[5], other.weights_ptr[5], embed_dim_); - - setWeightPtr(); + + return *this; + } + + size_t GetSerializeSize() + { + size_t count; + for (int i = 0; i < WEIGHT_N; i++) { + count += weights_size[i]; + } + count *= sizeof(T); + + for (auto& lw : vit_layer_weights) { + count += lw.GetSerializeSize(); + } + + return count; + } + + void serialize(void* buffer) + { + char* tmp_buf = (char*)buffer; + for (int i = 0; i < WEIGHT_N; i++) { + cudaMemcpy(tmp_buf, weights_ptr[i], sizeof(T) * weights_size[i], cudaMemcpyDeviceToHost); + tmp_buf += sizeof(T) * weights_size[i]; + } + + for (auto& lw : vit_layer_weights) { + lw.serialize(tmp_buf); + tmp_buf += lw.GetSerializeSize(); + } + } + + void deserialize(const void* buffer) + { + if (!is_maintain_buffer) { + return; + } + + char* tmp_buf = (char*)buffer; + for (int i = 0; i < WEIGHT_N; i++) { + cudaMemcpy(weights_ptr[i], tmp_buf, sizeof(T) * weights_size[i], cudaMemcpyHostToDevice); + tmp_buf += sizeof(T) * weights_size[i]; + } + + for (auto& lw : vit_layer_weights) { + lw.deserialize(tmp_buf); + tmp_buf += lw.GetSerializeSize(); + } + } + + void CopyWeightsFromHostBuffers(const T* const*& w) + { + cudaMemcpy( + const_cast(pre_encoder_conv_weights.kernel), *w++, sizeof(T) * weights_size[0], cudaMemcpyHostToDevice); + cudaMemcpy( + const_cast(pre_encoder_conv_weights.bias), *w++, sizeof(T) * weights_size[1], cudaMemcpyHostToDevice); + if (with_cls_token_) { + cudaMemcpy(const_cast(pre_transform_embeds.class_embed), + *w++, + sizeof(T) * weights_size[2], + cudaMemcpyHostToDevice); + with_cls_token_ = true; + } + cudaMemcpy(const_cast(pre_transform_embeds.position_embed), + *w++, + sizeof(T) * weights_size[3], + cudaMemcpyHostToDevice); + + for (int i = 0; i < num_layer_; i++) { + auto& layer_weight = vit_layer_weights[i]; + layer_weight.CopyWeightsFromHostBuffers(w); + } + + cudaMemcpy(const_cast(post_transformer_layernorm_weights.gamma), + *w++, + sizeof(T) * weights_size[4], + cudaMemcpyHostToDevice); + cudaMemcpy(const_cast(post_transformer_layernorm_weights.beta), + *w++, + sizeof(T) * weights_size[5], + cudaMemcpyHostToDevice); + } + + inline size_t GetWeightCount() + { + size_t weight_count = WEIGHT_N; + weight_count += num_layer_ * vit_layer_weights[0].GetWeightCount(); + + return weight_count; + } + + void ExportWeights() + { + DataType dtype = DataType::TYPE_INVALID; + if (std::is_same::value) { + dtype = DataType::TYPE_FP16; + } + else if (std::is_same::value) { + dtype = DataType::TYPE_FP32; + } + + Tensor conv_kernel{MEMORY_GPU, + dtype, + std::vector{embed_dim_, chn_num_, patch_size_, patch_size_}, + pre_encoder_conv_weights.kernel}; // OIHW + FT_LOG_INFO("exporting conv"); + conv_kernel.saveNpy("./weights/conv_kernel.npy"); + Tensor conv_bias{MEMORY_GPU, dtype, std::vector{embed_dim_}, pre_encoder_conv_weights.bias}; // OIHW + FT_LOG_INFO("exporting conv bias"); + conv_bias.saveNpy("./weights/conv_bias.npy"); + Tensor cls_token{MEMORY_GPU, dtype, std::vector{embed_dim_}, pre_transform_embeds.class_embed}; + if (with_cls_token_) + cls_token.saveNpy("./weights/cls_token.npy"); + Tensor pos_embed{ + MEMORY_GPU, dtype, std::vector{1, seq_len_, embed_dim_}, pre_transform_embeds.position_embed}; + 
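// Illustrative sketch, not part of the patch: GetSerializeSize()/serialize()/
// deserialize() above walk the weights_size[] table, copy every device buffer into
// one contiguous host blob, and copy it back on load. The standalone version below
// uses plain CUDA calls; DeviceBlob and the function names are invented for the
// sketch, and the running byte count starts at zero.
#include <cuda_runtime.h>
#include <cstddef>
#include <vector>

struct DeviceBlob {
    float* ptr;        // device pointer
    size_t num_elems;  // element count
};

size_t serialize_size(const std::vector<DeviceBlob>& blobs)
{
    size_t count = 0;                          // zero-initialized running total
    for (const auto& b : blobs) {
        count += b.num_elems;
    }
    return count * sizeof(float);
}

void serialize(const std::vector<DeviceBlob>& blobs, void* host_buffer)
{
    char* out = static_cast<char*>(host_buffer);
    for (const auto& b : blobs) {
        cudaMemcpy(out, b.ptr, b.num_elems * sizeof(float), cudaMemcpyDeviceToHost);
        out += b.num_elems * sizeof(float);
    }
}

void deserialize(const std::vector<DeviceBlob>& blobs, const void* host_buffer)
{
    const char* in = static_cast<const char*>(host_buffer);
    for (const auto& b : blobs) {
        cudaMemcpy(b.ptr, in, b.num_elems * sizeof(float), cudaMemcpyHostToDevice);
        in += b.num_elems * sizeof(float);
    }
}
// end sketch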
pos_embed.saveNpy("./weights/pos_embed.npy"); + Tensor ln_gamma{MEMORY_GPU, dtype, std::vector{embed_dim_}, post_transformer_layernorm_weights.gamma}; + ln_gamma.saveNpy("./weights/enc_ln_scale.npy"); + Tensor ln_beta{MEMORY_GPU, dtype, std::vector{embed_dim_}, post_transformer_layernorm_weights.beta}; + ln_beta.saveNpy("./weights/enc_ln_bias.npy"); + + for (int i = 0; i < num_layer_; i++) { + vit_layer_weights[i].ExportWeights(i); + } } std::vector> vit_layer_weights; - LayerNormWeight post_transformer_layernorm_weights; - ViTEmbeds pre_transform_embeds; - DenseWeight pre_encoder_conv_weights; - bool with_cls_token_; + LayerNormWeight post_transformer_layernorm_weights; + ViTEmbeds pre_transform_embeds; + DenseWeight pre_encoder_conv_weights; + bool with_cls_token_; private: void setWeightPtr() { - post_transformer_layernorm_weights.gamma = weights_ptr[0]; - post_transformer_layernorm_weights.beta = weights_ptr[1]; - pre_transform_embeds.class_embed = weights_ptr[2]; - pre_transform_embeds.position_embed = weights_ptr[3]; - pre_encoder_conv_weights.kernel = weights_ptr[4]; - pre_encoder_conv_weights.bias = weights_ptr[5]; + pre_encoder_conv_weights.kernel = weights_ptr[0]; + pre_encoder_conv_weights.bias = weights_ptr[1]; + pre_transform_embeds.class_embed = weights_ptr[2]; + pre_transform_embeds.position_embed = weights_ptr[3]; + post_transformer_layernorm_weights.gamma = weights_ptr[4]; + post_transformer_layernorm_weights.beta = weights_ptr[5]; is_maintain_buffer = true; } - int embed_dim_; - int inter_size_; - int num_layer_; - int img_size_; - int patch_size_; - int chn_num_; - int cls_num_; - bool is_maintain_buffer = false; - T* weights_ptr[8]{nullptr}; + int embed_dim_; + int inter_size_; + int num_layer_; + int img_size_; + int patch_size_; + int chn_num_; + int seq_len_; + bool is_maintain_buffer = false; + T* weights_ptr[WEIGHT_N]{nullptr}; + size_t weights_size[WEIGHT_N]; }; } // namespace fastertransformer diff --git a/src/fastertransformer/models/vit_int8/ViTLayerINT8Weight.h b/src/fastertransformer/models/vit_int8/ViTLayerINT8Weight.h index 59f394fc6..c3a89ccf7 100644 --- a/src/fastertransformer/models/vit_int8/ViTLayerINT8Weight.h +++ b/src/fastertransformer/models/vit_int8/ViTLayerINT8Weight.h @@ -24,201 +24,334 @@ namespace fastertransformer { +#define WEIGHT_N 16 + template struct ViTLayerINT8Weight { ViTLayerINT8Weight() = default; - ViTLayerINT8Weight(const int embed_dim, const int inter_size): embed_dim_(embed_dim), inter_size_(inter_size) + ViTLayerINT8Weight(const int embed_dim, const int inter_size, int layer_idx, const bool hold_buffer): + embed_dim_(embed_dim), inter_size_(inter_size), layer_idx_(layer_idx) { - deviceMalloc(&weights_ptr[0], embed_dim_ * embed_dim_); - deviceMalloc(&weights_ptr[1], embed_dim_); - deviceMalloc(&weights_ptr[2], embed_dim_ * embed_dim_); - deviceMalloc(&weights_ptr[3], embed_dim_); - deviceMalloc(&weights_ptr[4], embed_dim_ * embed_dim_); - deviceMalloc(&weights_ptr[5], embed_dim_); - deviceMalloc(&weights_ptr[6], embed_dim_ * embed_dim_); - deviceMalloc(&weights_ptr[7], embed_dim_); - deviceMalloc(&weights_ptr[8], embed_dim_); - deviceMalloc(&weights_ptr[9], embed_dim_); - deviceMalloc(&weights_ptr[10], embed_dim_ * inter_size_); - deviceMalloc(&weights_ptr[11], inter_size_); - deviceMalloc(&weights_ptr[12], inter_size_ * embed_dim_); - deviceMalloc(&weights_ptr[13], embed_dim_); - deviceMalloc(&weights_ptr[14], embed_dim_); - deviceMalloc(&weights_ptr[15], embed_dim_); - - scale_list_.size_ = ACTIVATION_AMAX_NUM + 9 * embed_dim + 
INT8O_GEMM_NUM + TRT_AMAX_NUM + SCALE_RESERVE_NUM; - scale_list_.p3_offset_ = ACTIVATION_AMAX_NUM + 9 * embed_dim; - scale_list_.p4_offset_ = ACTIVATION_AMAX_NUM + 9 * embed_dim + INT8O_GEMM_NUM; - deviceMalloc(&scale_list_ptr[0], scale_list_.size_); - scale_list_ptr[1] = (float*)malloc(sizeof(float) * scale_list_.size_); - - setWeightPtr(); + weights_size[0] = embed_dim_ * embed_dim_; + weights_size[1] = embed_dim_; + weights_size[2] = embed_dim_ * embed_dim_; + weights_size[3] = embed_dim_; + weights_size[4] = embed_dim_ * embed_dim_; + weights_size[5] = embed_dim_; + weights_size[6] = embed_dim_ * embed_dim_; + weights_size[7] = embed_dim_; + weights_size[8] = embed_dim_; + weights_size[9] = embed_dim_; + weights_size[10] = embed_dim_ * inter_size_; + weights_size[11] = inter_size_; + weights_size[12] = inter_size_ * embed_dim_; + weights_size[13] = embed_dim_; + weights_size[14] = embed_dim_; + weights_size[15] = embed_dim_; + if (hold_buffer) { + for (int i = 0; i < WEIGHT_N; i++) { + deviceMalloc(&weights_ptr[i], weights_size[i]); + } + + scale_list_.size_ = ACTIVATION_AMAX_NUM + 9 * embed_dim + INT8O_GEMM_NUM + TRT_AMAX_NUM + SCALE_RESERVE_NUM; + scale_list_.p3_offset_ = ACTIVATION_AMAX_NUM + 9 * embed_dim; + scale_list_.p4_offset_ = ACTIVATION_AMAX_NUM + 9 * embed_dim + INT8O_GEMM_NUM; + deviceMalloc(&scale_list_ptr[0], scale_list_.size_); + scale_list_ptr[1] = (float*)malloc(sizeof(float) * scale_list_.size_); + setWeightPtr(); + } } ~ViTLayerINT8Weight() { if (is_maintain_buffer == true) { - for (int i = 0; i < 16; i++) { + for (int i = 0; i < WEIGHT_N; i++) { deviceFree(weights_ptr[i]); } deviceFree(scale_list_ptr[0]); free(scale_list_ptr[1]); - - attention_weights.query_weight.kernel = nullptr; - attention_weights.query_weight.bias = nullptr; - attention_weights.key_weight.kernel = nullptr; - attention_weights.key_weight.bias = nullptr; - attention_weights.value_weight.kernel = nullptr; - attention_weights.value_weight.bias = nullptr; + attention_weights.query_weight.kernel = nullptr; + attention_weights.query_weight.bias = nullptr; + attention_weights.key_weight.kernel = nullptr; + attention_weights.key_weight.bias = nullptr; + attention_weights.value_weight.kernel = nullptr; + attention_weights.value_weight.bias = nullptr; attention_weights.attention_output_weight.kernel = nullptr; - attention_weights.attention_output_weight.bias = nullptr; - attn_layernorm_weights.gamma = nullptr; - attn_layernorm_weights.beta = nullptr; - ffn_weights.intermediate_weight.kernel = nullptr; - ffn_weights.intermediate_weight.bias = nullptr; - ffn_weights.output_weight.kernel = nullptr; - ffn_weights.output_weight.bias = nullptr; - ffn_layernorm_weights.gamma = nullptr; - ffn_layernorm_weights.beta = nullptr; - is_maintain_buffer = false; + attention_weights.attention_output_weight.bias = nullptr; + attn_layernorm_weights.gamma = nullptr; + attn_layernorm_weights.beta = nullptr; + ffn_weights.intermediate_weight.kernel = nullptr; + ffn_weights.intermediate_weight.bias = nullptr; + ffn_weights.output_weight.kernel = nullptr; + ffn_weights.output_weight.bias = nullptr; + ffn_layernorm_weights.gamma = nullptr; + ffn_layernorm_weights.beta = nullptr; + is_maintain_buffer = false; } } ViTLayerINT8Weight(const ViTLayerINT8Weight& other): embed_dim_(other.embed_dim_), inter_size_(other.inter_size_) { - deviceMalloc(&weights_ptr[0], embed_dim_ * embed_dim_); - cudaD2Dcpy(weights_ptr[0], other.weights_ptr[0], embed_dim_ * embed_dim_); - deviceMalloc(&weights_ptr[1], embed_dim_); - 
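// Illustrative sketch, not part of the patch: the constructors above fill a fixed
// weights_size[] table first and then allocate in a loop, so optional weights (for
// example the class-token embedding when with_cls_token is false) can simply carry
// size 0 and be skipped. Standalone version of that size-table pattern; the names
// WeightTable and kNumWeights are invented here (the patch uses WEIGHT_N).
#include <cuda_runtime.h>
#include <cstddef>

constexpr int kNumWeights = 6;   // assumed count for this sketch only

struct WeightTable {
    size_t sizes[kNumWeights] = {};
    float* ptrs[kNumWeights]  = {};

    void allocate()
    {
        for (int i = 0; i < kNumWeights; ++i) {
            if (sizes[i] == 0) {
                continue;        // optional weight not present
            }
            cudaMalloc(reinterpret_cast<void**>(&ptrs[i]), sizes[i] * sizeof(float));
        }
    }

    void release()
    {
        for (int i = 0; i < kNumWeights; ++i) {
            if (ptrs[i] != nullptr) {
                cudaFree(ptrs[i]);
                ptrs[i] = nullptr;
            }
        }
    }
};
// end sketch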
cudaD2Dcpy(weights_ptr[1], other.weights_ptr[1], embed_dim_); - deviceMalloc(&weights_ptr[2], embed_dim_ * embed_dim_); - cudaD2Dcpy(weights_ptr[2], other.weights_ptr[2], embed_dim_ * embed_dim_); - deviceMalloc(&weights_ptr[3], embed_dim_); - cudaD2Dcpy(weights_ptr[3], other.weights_ptr[3], embed_dim_); - deviceMalloc(&weights_ptr[4], embed_dim_ * embed_dim_); - cudaD2Dcpy(weights_ptr[4], other.weights_ptr[4], embed_dim_ * embed_dim_); - deviceMalloc(&weights_ptr[5], embed_dim_); - cudaD2Dcpy(weights_ptr[5], other.weights_ptr[5], embed_dim_); - deviceMalloc(&weights_ptr[6], embed_dim_ * embed_dim_); - cudaD2Dcpy(weights_ptr[6], other.weights_ptr[6], embed_dim_ * embed_dim_); - deviceMalloc(&weights_ptr[7], embed_dim_); - cudaD2Dcpy(weights_ptr[7], other.weights_ptr[7], embed_dim_); - deviceMalloc(&weights_ptr[8], embed_dim_); - cudaD2Dcpy(weights_ptr[8], other.weights_ptr[8], embed_dim_); - deviceMalloc(&weights_ptr[9], embed_dim_); - cudaD2Dcpy(weights_ptr[9], other.weights_ptr[9], embed_dim_); - deviceMalloc(&weights_ptr[10], embed_dim_ * inter_size_); - cudaD2Dcpy(weights_ptr[10], other.weights_ptr[10], embed_dim_ * inter_size_); - deviceMalloc(&weights_ptr[11], inter_size_); - cudaD2Dcpy(weights_ptr[11], other.weights_ptr[11], inter_size_); - deviceMalloc(&weights_ptr[12], inter_size_ * embed_dim_); - cudaD2Dcpy(weights_ptr[12], other.weights_ptr[12], inter_size_ * embed_dim_); - deviceMalloc(&weights_ptr[13], embed_dim_); - cudaD2Dcpy(weights_ptr[13], other.weights_ptr[13], embed_dim_); - deviceMalloc(&weights_ptr[14], embed_dim_); - cudaD2Dcpy(weights_ptr[14], other.weights_ptr[14], embed_dim_); - deviceMalloc(&weights_ptr[15], embed_dim_); - cudaD2Dcpy(weights_ptr[15], other.weights_ptr[15], embed_dim_); - - scale_list_.size_ = other.scale_list_.size_; - scale_list_.p3_offset_ = other.scale_list_.p3_offset_; - scale_list_.p4_offset_ = other.scale_list_.p4_offset_; - deviceMalloc(&scale_list_ptr[0], scale_list_.size_); - cudaD2Dcpy(scale_list_ptr[0], other.scale_list_ptr[0], scale_list_.size_); - scale_list_ptr[1] = (float*)malloc(sizeof(float) * scale_list_.size_); - memcpy(scale_list_ptr[1], other.scale_list_ptr[1], sizeof(float) * scale_list_.size_); - - setWeightPtr(); + memcpy(weights_size, other.weights_size, sizeof(size_t) * WEIGHT_N); + layer_idx_ = other.layer_idx_; + if (other.is_maintain_buffer) { + for (int i = 0; i < WEIGHT_N; i++) { + if (!is_maintain_buffer) { + deviceMalloc(&weights_ptr[i], weights_size[i]); + } + cudaD2Dcpy(weights_ptr[i], other.weights_ptr[i], weights_size[i]); + } + + scale_list_.size_ = other.scale_list_.size_; + scale_list_.p3_offset_ = other.scale_list_.p3_offset_; + scale_list_.p4_offset_ = other.scale_list_.p4_offset_; + deviceMalloc(&scale_list_ptr[0], scale_list_.size_); + cudaD2Dcpy(scale_list_ptr[0], other.scale_list_ptr[0], scale_list_.size_); + scale_list_ptr[1] = (float*)malloc(sizeof(float) * scale_list_.size_); + memcpy(scale_list_ptr[1], other.scale_list_ptr[1], sizeof(float) * scale_list_.size_); + setWeightPtr(); + } } ViTLayerINT8Weight& operator=(const ViTLayerINT8Weight& other) { - embed_dim_ = other.embed_dim_; + embed_dim_ = other.embed_dim_; inter_size_ = other.inter_size_; - deviceMalloc(&weights_ptr[0], embed_dim_ * embed_dim_); - cudaD2Dcpy(weights_ptr[0], other.weights_ptr[0], embed_dim_ * embed_dim_); - deviceMalloc(&weights_ptr[1], embed_dim_); - cudaD2Dcpy(weights_ptr[1], other.weights_ptr[1], embed_dim_); - deviceMalloc(&weights_ptr[2], embed_dim_ * embed_dim_); - cudaD2Dcpy(weights_ptr[2], other.weights_ptr[2], 
embed_dim_ * embed_dim_); - deviceMalloc(&weights_ptr[3], embed_dim_); - cudaD2Dcpy(weights_ptr[3], other.weights_ptr[3], embed_dim_); - deviceMalloc(&weights_ptr[4], embed_dim_ * embed_dim_); - cudaD2Dcpy(weights_ptr[4], other.weights_ptr[4], embed_dim_ * embed_dim_); - deviceMalloc(&weights_ptr[5], embed_dim_); - cudaD2Dcpy(weights_ptr[5], other.weights_ptr[5], embed_dim_); - deviceMalloc(&weights_ptr[6], embed_dim_ * embed_dim_); - cudaD2Dcpy(weights_ptr[6], other.weights_ptr[6], embed_dim_ * embed_dim_); - deviceMalloc(&weights_ptr[7], embed_dim_); - cudaD2Dcpy(weights_ptr[7], other.weights_ptr[7], embed_dim_); - deviceMalloc(&weights_ptr[8], embed_dim_); - cudaD2Dcpy(weights_ptr[8], other.weights_ptr[8], embed_dim_); - deviceMalloc(&weights_ptr[9], embed_dim_); - cudaD2Dcpy(weights_ptr[9], other.weights_ptr[9], embed_dim_); - deviceMalloc(&weights_ptr[10], embed_dim_ * inter_size_); - cudaD2Dcpy(weights_ptr[10], other.weights_ptr[10], embed_dim_ * inter_size_); - deviceMalloc(&weights_ptr[11], inter_size_); - cudaD2Dcpy(weights_ptr[11], other.weights_ptr[11], inter_size_); - deviceMalloc(&weights_ptr[12], inter_size_ * embed_dim_); - cudaD2Dcpy(weights_ptr[12], other.weights_ptr[12], inter_size_ * embed_dim_); - deviceMalloc(&weights_ptr[13], embed_dim_); - cudaD2Dcpy(weights_ptr[13], other.weights_ptr[13], embed_dim_); - deviceMalloc(&weights_ptr[14], embed_dim_); - cudaD2Dcpy(weights_ptr[14], other.weights_ptr[14], embed_dim_); - deviceMalloc(&weights_ptr[15], embed_dim_); - cudaD2Dcpy(weights_ptr[15], other.weights_ptr[15], embed_dim_); - - scale_list_.size_ = other.scale_list_.size_; - scale_list_.p3_offset_ = other.scale_list_.p3_offset_; - scale_list_.p4_offset_ = other.scale_list_.p4_offset_; - deviceMalloc(&scale_list_ptr[0], scale_list_.size_); - cudaD2Dcpy(scale_list_ptr[0], other.scale_list_ptr[0], scale_list_.size_); - scale_list_ptr[1] = (float*)malloc(sizeof(float) * scale_list_.size_); - memcpy(scale_list_ptr[1], other.scale_list_ptr[1], sizeof(float) * scale_list_.size_); - - setWeightPtr(); + layer_idx_ = other.layer_idx_; + memcpy(weights_size, other.weights_size, sizeof(size_t) * WEIGHT_N); + if (other.is_maintain_buffer) { + for (int i = 0; i < WEIGHT_N; i++) { + if (!is_maintain_buffer) { + deviceMalloc(&weights_ptr[i], weights_size[i]); + } + cudaD2Dcpy(weights_ptr[i], other.weights_ptr[i], weights_size[i]); + } + scale_list_.size_ = other.scale_list_.size_; + scale_list_.p3_offset_ = other.scale_list_.p3_offset_; + scale_list_.p4_offset_ = other.scale_list_.p4_offset_; + deviceMalloc(&scale_list_ptr[0], scale_list_.size_); + cudaD2Dcpy(scale_list_ptr[0], other.scale_list_ptr[0], scale_list_.size_); + scale_list_ptr[1] = (float*)malloc(sizeof(float) * scale_list_.size_); + memcpy(scale_list_ptr[1], other.scale_list_ptr[1], sizeof(float) * scale_list_.size_); + setWeightPtr(); + } + + return *this; + } + + inline size_t GetWeightCount() + { + return WEIGHT_N + 2; + } + + size_t GetSerializeSize() + { + size_t count; + for (int i = 0; i < WEIGHT_N; i++) { + count += weights_size[i]; + } + + return sizeof(T) * count + 2 * scale_list_.size_ * sizeof(float); + } + + void CopyWeightsFromHostBuffers(const T* const*& w) + { + cudaMemcpy( + const_cast(attn_layernorm_weights.gamma), *w++, sizeof(T) * weights_size[8], cudaMemcpyHostToDevice); + cudaMemcpy( + const_cast(attn_layernorm_weights.beta), *w++, sizeof(T) * weights_size[9], cudaMemcpyHostToDevice); + cudaMemcpy(const_cast(attention_weights.query_weight.kernel), + *w++, + sizeof(T) * weights_size[0], + 
cudaMemcpyHostToDevice); + cudaMemcpy(const_cast(attention_weights.query_weight.bias), + *w++, + sizeof(T) * weights_size[1], + cudaMemcpyHostToDevice); + cudaMemcpy(const_cast(attention_weights.key_weight.kernel), + *w++, + sizeof(T) * weights_size[2], + cudaMemcpyHostToDevice); + cudaMemcpy(const_cast(attention_weights.key_weight.bias), + *w++, + sizeof(T) * weights_size[3], + cudaMemcpyHostToDevice); + cudaMemcpy(const_cast(attention_weights.value_weight.kernel), + *w++, + sizeof(T) * weights_size[4], + cudaMemcpyHostToDevice); + cudaMemcpy(const_cast(attention_weights.value_weight.bias), + *w++, + sizeof(T) * weights_size[5], + cudaMemcpyHostToDevice); + cudaMemcpy(const_cast(attention_weights.attention_output_weight.kernel), + *w++, + sizeof(T) * weights_size[6], + cudaMemcpyHostToDevice); + cudaMemcpy(const_cast(attention_weights.attention_output_weight.bias), + *w++, + sizeof(T) * weights_size[7], + cudaMemcpyHostToDevice); + cudaMemcpy( + const_cast(ffn_layernorm_weights.gamma), *w++, sizeof(T) * weights_size[14], cudaMemcpyHostToDevice); + cudaMemcpy( + const_cast(ffn_layernorm_weights.beta), *w++, sizeof(T) * weights_size[15], cudaMemcpyHostToDevice); + cudaMemcpy(const_cast(ffn_weights.intermediate_weight.kernel), + *w++, + sizeof(T) * weights_size[10], + cudaMemcpyHostToDevice); + cudaMemcpy(const_cast(ffn_weights.intermediate_weight.bias), + *w++, + sizeof(T) * weights_size[11], + cudaMemcpyHostToDevice); + cudaMemcpy(const_cast(ffn_weights.output_weight.kernel), + *w++, + sizeof(T) * weights_size[12], + cudaMemcpyHostToDevice); + cudaMemcpy( + const_cast(ffn_weights.output_weight.bias), *w++, sizeof(T) * weights_size[13], cudaMemcpyHostToDevice); + + cudaMemcpy( + const_cast(scale_list_ptr[0]), *w++, sizeof(float) * scale_list_.size_, cudaMemcpyHostToDevice); + cudaMemcpy( + const_cast(scale_list_ptr[1]), *w++, sizeof(float) * scale_list_.size_, cudaMemcpyHostToHost); + } + + void serialize(void* buffer) + { + char* tmp_buf = (char*)buffer; + for (int i = 0; i < WEIGHT_N; i++) { + cudaMemcpy(tmp_buf, weights_ptr[i], sizeof(T) * weights_size[i], cudaMemcpyDeviceToHost); + tmp_buf += sizeof(T) * weights_size[i]; + } + cudaMemcpy(tmp_buf, scale_list_ptr[0], sizeof(float) * scale_list_.size_, cudaMemcpyDeviceToHost); + tmp_buf += sizeof(float) * scale_list_.size_; + cudaMemcpy(tmp_buf, scale_list_ptr[1], sizeof(float) * scale_list_.size_, cudaMemcpyHostToHost); + tmp_buf += sizeof(float) * scale_list_.size_; + } + + void deserialize(const void* buffer) + { + if (!is_maintain_buffer) { + return; + } + + char* tmp_buf = (char*)buffer; + for (int i = 0; i < WEIGHT_N; i++) { + cudaMemcpy(weights_ptr[i], tmp_buf, sizeof(T) * weights_size[i], cudaMemcpyHostToDevice); + tmp_buf += sizeof(T) * weights_size[i]; + } + cudaMemcpy(scale_list_ptr[0], tmp_buf, sizeof(float) * scale_list_.size_, cudaMemcpyHostToDevice); + tmp_buf += sizeof(float) * scale_list_.size_; + cudaMemcpy(scale_list_ptr[1], tmp_buf, sizeof(float) * scale_list_.size_, cudaMemcpyHostToHost); + tmp_buf += sizeof(float) * scale_list_.size_; + } + + void ExportWeights(int layer_idx) + { + FT_LOG_INFO("Exporting layer %d...", layer_idx); + FT_LOG_INFO("embed_dim:%d, inter_size:%d", embed_dim_, inter_size_); + DataType dtype = DataType::TYPE_INVALID; + if (std::is_same::value) { + dtype = DataType::TYPE_FP16; + } + else if (std::is_same::value) { + dtype = DataType::TYPE_FP32; + } + + std::ostringstream buffer; + buffer << "./weights/l" << layer_idx; + DataType wtype = DataType::TYPE_INT8; + Tensor w1{ + MEMORY_GPU, wtype, 
std::vector{embed_dim_, embed_dim_}, attention_weights.query_weight.kernel}; + w1.saveNpy(std::string(buffer.str()) + "_q_kern.npy"); + Tensor w2{MEMORY_GPU, dtype, std::vector{embed_dim_}, attention_weights.query_weight.bias}; + w2.saveNpy(std::string(buffer.str()) + "_q_bias.npy"); + Tensor w3{MEMORY_GPU, wtype, std::vector{embed_dim_, embed_dim_}, attention_weights.key_weight.kernel}; + w3.saveNpy(std::string(buffer.str()) + "_k_kern.npy"); + Tensor w4{MEMORY_GPU, dtype, std::vector{embed_dim_}, attention_weights.key_weight.bias}; + w4.saveNpy(std::string(buffer.str()) + "_k_bias.npy"); + Tensor w5{ + MEMORY_GPU, wtype, std::vector{embed_dim_, embed_dim_}, attention_weights.value_weight.kernel}; + w5.saveNpy(std::string(buffer.str()) + "_v_kern.npy"); + Tensor w6{MEMORY_GPU, dtype, std::vector{embed_dim_}, attention_weights.value_weight.bias}; + w6.saveNpy(std::string(buffer.str()) + "_v_bias.npy"); + Tensor w7{MEMORY_GPU, + wtype, + std::vector{embed_dim_, embed_dim_}, + attention_weights.attention_output_weight.kernel}; + w7.saveNpy(std::string(buffer.str()) + "_att_o_kern.npy"); + Tensor w8{MEMORY_GPU, dtype, std::vector{embed_dim_}, attention_weights.attention_output_weight.bias}; + w8.saveNpy(std::string(buffer.str()) + "_att_o_bias.npy"); + Tensor w9{MEMORY_GPU, dtype, std::vector{embed_dim_}, attn_layernorm_weights.gamma}; + w9.saveNpy(std::string(buffer.str()) + "_ln0_scale.npy"); + Tensor w10{MEMORY_GPU, dtype, std::vector{embed_dim_}, attn_layernorm_weights.beta}; + w10.saveNpy(std::string(buffer.str()) + "_ln0_bias.npy"); + Tensor w11{ + MEMORY_GPU, wtype, std::vector{embed_dim_, inter_size_}, ffn_weights.intermediate_weight.kernel}; + w11.saveNpy(std::string(buffer.str()) + "_ffn_inter_kern.npy"); + Tensor w12{MEMORY_GPU, dtype, std::vector{inter_size_}, ffn_weights.intermediate_weight.bias}; + w12.saveNpy(std::string(buffer.str()) + "_ffn_inter_bias.npy"); + Tensor w13{MEMORY_GPU, wtype, std::vector{inter_size_, embed_dim_}, ffn_weights.output_weight.kernel}; + w13.saveNpy(std::string(buffer.str()) + "_ffn_o_kern.npy"); + Tensor w14{MEMORY_GPU, dtype, std::vector{embed_dim_}, ffn_weights.output_weight.bias}; + w14.saveNpy(std::string(buffer.str()) + "_ffn_o_bias.npy"); + Tensor w15{MEMORY_GPU, dtype, std::vector{embed_dim_}, ffn_layernorm_weights.gamma}; + w15.saveNpy(std::string(buffer.str()) + "_ln2_scale.npy"); + Tensor w16{MEMORY_GPU, dtype, std::vector{embed_dim_}, ffn_layernorm_weights.beta}; + w16.saveNpy(std::string(buffer.str()) + "_ln2_bias.npy"); + + Tensor w17{MEMORY_GPU, DataType::TYPE_FP32, std::vector{scale_list_.size_}, scale_list_.d_scale_list_}; + w17.saveNpy(std::string(buffer.str()) + "_d_scalelist.npy"); + Tensor w18{MEMORY_CPU, DataType::TYPE_FP32, std::vector{scale_list_.size_}, scale_list_.h_scale_list_}; + w18.saveNpy(std::string(buffer.str()) + "_h_scalelist.npy"); } AttentionINT8Weight attention_weights; - LayerNormWeight attn_layernorm_weights; - FfnINT8Weight ffn_weights; - LayerNormWeight ffn_layernorm_weights; - ScaleList scale_list_; + LayerNormWeight attn_layernorm_weights; + FfnINT8Weight ffn_weights; + LayerNormWeight ffn_layernorm_weights; + ScaleList scale_list_; private: void setWeightPtr() { - attention_weights.query_weight.kernel = weights_ptr[0]; - attention_weights.query_weight.bias = weights_ptr[1]; - attention_weights.key_weight.kernel = weights_ptr[2]; - attention_weights.key_weight.bias = weights_ptr[3]; - attention_weights.value_weight.kernel = weights_ptr[4]; - attention_weights.value_weight.bias = weights_ptr[5]; + 
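// Illustrative sketch, not part of the patch: ExportWeights() above wraps each
// device pointer in a Tensor and calls saveNpy() to write a NumPy file per weight.
// A standalone way to inspect a device weight offline is to copy it to the host and
// dump the raw bytes, as below; dump_device_buffer is an invented helper, and unlike
// .npy output it does not record shape or dtype.
#include <cuda_runtime.h>
#include <cstddef>
#include <fstream>
#include <vector>

void dump_device_buffer(const float* d_ptr, size_t num_elems, const char* path)
{
    std::vector<float> host(num_elems);
    cudaMemcpy(host.data(), d_ptr, num_elems * sizeof(float), cudaMemcpyDeviceToHost);

    std::ofstream out(path, std::ios::binary);
    out.write(reinterpret_cast<const char*>(host.data()),
              static_cast<std::streamsize>(num_elems * sizeof(float)));
}
// end sketch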
attention_weights.query_weight.kernel = weights_ptr[0]; + attention_weights.query_weight.bias = weights_ptr[1]; + attention_weights.key_weight.kernel = weights_ptr[2]; + attention_weights.key_weight.bias = weights_ptr[3]; + attention_weights.value_weight.kernel = weights_ptr[4]; + attention_weights.value_weight.bias = weights_ptr[5]; attention_weights.attention_output_weight.kernel = weights_ptr[6]; - attention_weights.attention_output_weight.bias = weights_ptr[7]; - attn_layernorm_weights.gamma = weights_ptr[8]; - attn_layernorm_weights.beta = weights_ptr[9]; - ffn_weights.intermediate_weight.kernel = weights_ptr[10]; - ffn_weights.intermediate_weight.bias = weights_ptr[11]; - ffn_weights.output_weight.kernel = weights_ptr[12]; - ffn_weights.output_weight.bias = weights_ptr[13]; - ffn_layernorm_weights.gamma = weights_ptr[14]; - ffn_layernorm_weights.beta = weights_ptr[15]; - - scale_list_.d_scale_list_ = scale_list_ptr[0]; - scale_list_.h_scale_list_ = scale_list_ptr[1]; + attention_weights.attention_output_weight.bias = weights_ptr[7]; + attn_layernorm_weights.gamma = weights_ptr[8]; + attn_layernorm_weights.beta = weights_ptr[9]; + ffn_weights.intermediate_weight.kernel = weights_ptr[10]; + ffn_weights.intermediate_weight.bias = weights_ptr[11]; + ffn_weights.output_weight.kernel = weights_ptr[12]; + ffn_weights.output_weight.bias = weights_ptr[13]; + ffn_layernorm_weights.gamma = weights_ptr[14]; + ffn_layernorm_weights.beta = weights_ptr[15]; + + scale_list_.d_scale_list_ = scale_list_ptr[0]; + scale_list_.h_scale_list_ = scale_list_ptr[1]; attention_weights.scale_list_ptr = &scale_list_; - ffn_weights.scale_list_ptr = &scale_list_; + ffn_weights.scale_list_ptr = &scale_list_; is_maintain_buffer = true; } - int embed_dim_; - int inter_size_; - bool is_maintain_buffer = false; - T* weights_ptr[16]; + int embed_dim_; + int inter_size_; + int layer_idx_; + bool is_maintain_buffer = false; + T* weights_ptr[WEIGHT_N]{nullptr}; + size_t weights_size[WEIGHT_N]; + bool is_maintain_sp_buffer = false; float* scale_list_ptr[2]; - T* sp_weights_ptr[6]; }; +#undef WEIGHT_N + } // namespace fastertransformer diff --git a/src/fastertransformer/models/xlnet/CMakeLists.txt b/src/fastertransformer/models/xlnet/CMakeLists.txt index f259b3c0b..234e709d1 100644 --- a/src/fastertransformer/models/xlnet/CMakeLists.txt +++ b/src/fastertransformer/models/xlnet/CMakeLists.txt @@ -18,7 +18,7 @@ add_library(Xlnet STATIC Xlnet.cc) set_property(TARGET Xlnet PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET Xlnet PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) target_link_libraries(Xlnet PUBLIC -lcudart xlnet_preprocess_kernels cublasMMWrapper - XlnetAttentionLayer FfnLayer layernorm_kernels) + XlnetAttentionLayer FfnLayer layernorm_kernels tensor) add_executable(xlnet_gemm xlnet_gemm.cc) -target_link_libraries(xlnet_gemm PUBLIC -lcublas -lcublasLt -lcudart xlnet_gemm_func xlnet_gemm_func memory_utils) +target_link_libraries(xlnet_gemm PUBLIC -lcublas -lcublasLt -lcudart xlnet_gemm_func xlnet_gemm_func memory_utils tensor) diff --git a/src/fastertransformer/models/xlnet/Xlnet.cc b/src/fastertransformer/models/xlnet/Xlnet.cc index 958361f90..695362e9a 100644 --- a/src/fastertransformer/models/xlnet/Xlnet.cc +++ b/src/fastertransformer/models/xlnet/Xlnet.cc @@ -31,7 +31,7 @@ void Xlnet::initialize() cublas_wrapper_, allocator_, is_free_buffer_after_forward_); - ffn_layer_ = new GeluFfnLayer(max_batch_size_, + ffn_layer_ = new GeluFfnLayer(max_batch_size_, max_seq_len_, head_num_, size_per_head_, @@ -44,17 +44,17 
@@ void Xlnet::initialize() } template -Xlnet::Xlnet(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t inter_size, - size_t num_layer, - float q_scaling, - cudaStream_t stream, +Xlnet::Xlnet(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + float q_scaling, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward): + IAllocator* allocator, + bool is_free_buffer_after_forward): BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward), max_batch_size_(max_batch_size), max_seq_len_(max_seq_len), @@ -95,11 +95,15 @@ template void Xlnet::allocateBuffer() { if (is_allocate_buffer_ == false) { - attn_mask_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * max_seq_len_, false); - seg_mat_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * max_seq_len_ * 2, false); - attr_k_head_r_ = (T*)allocator_->malloc(sizeof(T) * max_seq_len_ * hidden_units_ * 2, false); - attn_out_buf_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false); - output_fc2_ = (T*)allocator_->malloc(sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false); + attn_mask_ = + (T*)allocator_->reMalloc(attn_mask_, sizeof(T) * max_batch_size_ * max_seq_len_ * max_seq_len_, false); + seg_mat_ = + (T*)allocator_->reMalloc(seg_mat_, sizeof(T) * max_batch_size_ * max_seq_len_ * max_seq_len_ * 2, false); + attr_k_head_r_ = (T*)allocator_->reMalloc(attr_k_head_r_, sizeof(T) * max_seq_len_ * hidden_units_ * 2, false); + attn_out_buf_ = + (T*)allocator_->reMalloc(attn_out_buf_, sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false); + output_fc2_ = + (T*)allocator_->reMalloc(output_fc2_, sizeof(T) * max_batch_size_ * max_seq_len_ * hidden_units_, false); is_allocate_buffer_ = true; } } @@ -108,18 +112,18 @@ template void Xlnet::freeBuffer() { if (is_allocate_buffer_ == true) { - allocator_->free(attn_mask_); - allocator_->free(seg_mat_); - allocator_->free(attr_k_head_r_); - allocator_->free(output_fc2_); + allocator_->free((void**)(&attn_mask_)); + allocator_->free((void**)(&seg_mat_)); + allocator_->free((void**)(&attr_k_head_r_)); + allocator_->free((void**)(&output_fc2_)); is_allocate_buffer_ = false; } } template -void Xlnet::forward(std::vector* output_tensors, - const std::vector* input_tensors, +void Xlnet::forward(std::vector* output_tensors, + const std::vector* input_tensors, const std::vector>* xlnet_layer_weights) { // input_tensors: @@ -131,13 +135,13 @@ void Xlnet::forward(std::vector* output_tensors, // out_tensor [batch_size, seq_len, hidden_units] const size_t request_batch_size = input_tensors->at(0).shape[0]; - const size_t request_seq_len = input_tensors->at(0).shape[1]; + const size_t request_seq_len = input_tensors->at(0).shape[1]; - T* input_ptr = (T*)input_tensors->at(0).data; + T* input_ptr = (T*)input_tensors->at(0).data; T* output_ptr = (T*)output_tensors->at(0).data; float* input_mask = (float*)input_tensors->at(1).data; - int* seg_id = (int*)input_tensors->at(2).data; + int* seg_id = (int*)input_tensors->at(2).data; FT_CHECK(input_tensors->size() == 3); FT_CHECK(isValidBatchSize(request_batch_size)); @@ -163,8 +167,8 @@ void Xlnet::forward(std::vector* output_tensors, DataType data_type = getTensorType(); for (uint i = 0; i < num_layer_; i++) { - const T* in_tensor = (const T*)(i == 0 ? 
input_ptr : output_ptr); - T* out_tensor = output_ptr; + const T* in_tensor = (const T*)(i == 0 ? input_ptr : output_ptr); + T* out_tensor = output_ptr; std::vector attn_input_tensors{ Tensor{MEMORY_GPU, @@ -217,6 +221,7 @@ void Xlnet::forward(std::vector* output_tensors, xlnet_layer_weights->at(i).ffn_weights.output_weight.bias, xlnet_layer_weights->at(i).ffn_layernorm_weights.gamma, xlnet_layer_weights->at(i).ffn_layernorm_weights.beta, + layernorm_eps_, request_batch_size * request_seq_len, hidden_units_, stream_); @@ -249,5 +254,8 @@ bool Xlnet::isValidSeqLen(size_t seq_len) template class Xlnet; template class Xlnet; +#ifdef ENABLE_BF16 +template class Xlnet<__nv_bfloat16>; +#endif } // namespace fastertransformer diff --git a/src/fastertransformer/models/xlnet/Xlnet.h b/src/fastertransformer/models/xlnet/Xlnet.h index 047ec660c..fe70b8fe9 100644 --- a/src/fastertransformer/models/xlnet/Xlnet.h +++ b/src/fastertransformer/models/xlnet/Xlnet.h @@ -34,16 +34,17 @@ class Xlnet: public BaseLayer { private: // buffer handling size_t max_batch_size_ = 0; - size_t max_seq_len_ = 0; + size_t max_seq_len_ = 0; // meta data - size_t head_num_; - size_t size_per_head_; - size_t inter_size_; - size_t hidden_units_; - size_t num_layer_; - float q_scaling_; - - bool is_allocate_buffer_ = false; + size_t head_num_; + size_t size_per_head_; + size_t inter_size_; + size_t hidden_units_; + size_t num_layer_; + float q_scaling_; + constexpr static float layernorm_eps_ = 1e-6f; + + bool is_allocate_buffer_ = false; FfnLayer* ffn_layer_; void allocateBuffer(); @@ -66,24 +67,24 @@ class Xlnet: public BaseLayer { XlnetAttentionLayer* attention_layer_; public: - Xlnet(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t inter_size, - size_t num_layer, - float q_scaling, - cudaStream_t stream, + Xlnet(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + float q_scaling, + cudaStream_t stream, cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward); + IAllocator* allocator, + bool is_free_buffer_after_forward); Xlnet(Xlnet const& xlnet_layer); ~Xlnet(); - void forward(std::vector* output_tensors, - const std::vector* input_tensors, + void forward(std::vector* output_tensors, + const std::vector* input_tensors, const std::vector>* bert_layer_weights); }; diff --git a/src/fastertransformer/models/xlnet/XlnetLayerWeight.h b/src/fastertransformer/models/xlnet/XlnetLayerWeight.h index fc9fdaf85..5b3518d14 100644 --- a/src/fastertransformer/models/xlnet/XlnetLayerWeight.h +++ b/src/fastertransformer/models/xlnet/XlnetLayerWeight.h @@ -27,9 +27,9 @@ namespace fastertransformer { template struct XlnetLayerWeight { XlnetAttentionWeight attention_weights; - LayerNormWeight attn_layernorm_weights; - FfnWeight ffn_weights; - LayerNormWeight ffn_layernorm_weights; + LayerNormWeight attn_layernorm_weights; + FfnWeight ffn_weights; + LayerNormWeight ffn_layernorm_weights; XlnetLayerWeight() = default; XlnetLayerWeight(const int hidden_units, const int inter_size): hidden_units_(hidden_units), inter_size_(inter_size) @@ -49,25 +49,25 @@ struct XlnetLayerWeight { deviceFree(weights_ptr[i]); } - attention_weights.attr_kernel_Q = nullptr; - attention_weights.attr_kernel_K = nullptr; - attention_weights.attr_kernel_V = nullptr; - attention_weights.attr_pos_emb = nullptr; - attention_weights.attr_bias_Q_w = nullptr; - attention_weights.attr_bias_Q_r = nullptr; - 
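// Illustrative sketch, not part of the patch: the Xlnet hunks above thread an
// explicit layernorm_eps_ (a constexpr 1e-6f member) into the bias/residual/layer-
// norm call instead of relying on an epsilon baked into the kernel. A CPU reference
// of layer norm, with invented names, showing where that epsilon enters:
#include <cmath>
#include <cstddef>

void layer_norm_ref(const float* x, const float* gamma, const float* beta,
                    float eps, size_t rows, size_t cols, float* out)
{
    for (size_t r = 0; r < rows; ++r) {
        const float* row = x + r * cols;

        float mean = 0.f;
        for (size_t c = 0; c < cols; ++c) {
            mean += row[c];
        }
        mean /= static_cast<float>(cols);

        float var = 0.f;
        for (size_t c = 0; c < cols; ++c) {
            float d = row[c] - mean;
            var += d * d;
        }
        var /= static_cast<float>(cols);

        // eps (e.g. 1e-6f) keeps the normalization stable when the variance is tiny
        float inv_std = 1.f / std::sqrt(var + eps);
        for (size_t c = 0; c < cols; ++c) {
            out[r * cols + c] = (row[c] - mean) * inv_std * gamma[c] + beta[c];
        }
    }
}
// end sketch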
attention_weights.attr_bias_Q_s = nullptr; + attention_weights.attr_kernel_Q = nullptr; + attention_weights.attr_kernel_K = nullptr; + attention_weights.attr_kernel_V = nullptr; + attention_weights.attr_pos_emb = nullptr; + attention_weights.attr_bias_Q_w = nullptr; + attention_weights.attr_bias_Q_r = nullptr; + attention_weights.attr_bias_Q_s = nullptr; attention_weights.attr_seg_embed = nullptr; - attention_weights.attr_proj_o = nullptr; + attention_weights.attr_proj_o = nullptr; - attn_layernorm_weights.gamma = nullptr; - attn_layernorm_weights.beta = nullptr; + attn_layernorm_weights.gamma = nullptr; + attn_layernorm_weights.beta = nullptr; ffn_weights.intermediate_weight.kernel = nullptr; - ffn_weights.intermediate_weight.bias = nullptr; - ffn_weights.output_weight.kernel = nullptr; - ffn_weights.output_weight.bias = nullptr; - ffn_layernorm_weights.gamma = nullptr; - ffn_layernorm_weights.beta = nullptr; - is_maintain_buffer = false; + ffn_weights.intermediate_weight.bias = nullptr; + ffn_weights.output_weight.kernel = nullptr; + ffn_weights.output_weight.bias = nullptr; + ffn_layernorm_weights.gamma = nullptr; + ffn_layernorm_weights.beta = nullptr; + is_maintain_buffer = false; } } @@ -84,7 +84,7 @@ struct XlnetLayerWeight { XlnetLayerWeight& operator=(const XlnetLayerWeight& other) { hidden_units_ = other.hidden_units_; - inter_size_ = other.inter_size_; + inter_size_ = other.inter_size_; setWeightSize(); for (int i = 0; i < NUM_WEIGHTS; i++) { deviceMalloc(&weights_ptr[i], weights_size[i]); @@ -106,40 +106,40 @@ struct XlnetLayerWeight { private: void setWeightPtr() { - attention_weights.attr_kernel_Q = weights_ptr[0]; - attention_weights.attr_kernel_K = weights_ptr[0] + hidden_units_ * hidden_units_; - attention_weights.attr_kernel_V = weights_ptr[0] + hidden_units_ * hidden_units_ * 2; - attention_weights.attr_pos_emb = weights_ptr[1]; - attention_weights.attr_bias_Q_w = weights_ptr[2]; - attention_weights.attr_bias_Q_r = weights_ptr[3]; - attention_weights.attr_bias_Q_s = weights_ptr[4]; + attention_weights.attr_kernel_Q = weights_ptr[0]; + attention_weights.attr_kernel_K = weights_ptr[0] + hidden_units_ * hidden_units_; + attention_weights.attr_kernel_V = weights_ptr[0] + hidden_units_ * hidden_units_ * 2; + attention_weights.attr_pos_emb = weights_ptr[1]; + attention_weights.attr_bias_Q_w = weights_ptr[2]; + attention_weights.attr_bias_Q_r = weights_ptr[3]; + attention_weights.attr_bias_Q_s = weights_ptr[4]; attention_weights.attr_seg_embed = weights_ptr[5]; - attention_weights.attr_proj_o = weights_ptr[6]; + attention_weights.attr_proj_o = weights_ptr[6]; - attn_layernorm_weights.gamma = weights_ptr[7]; - attn_layernorm_weights.beta = weights_ptr[8]; + attn_layernorm_weights.gamma = weights_ptr[7]; + attn_layernorm_weights.beta = weights_ptr[8]; ffn_weights.intermediate_weight.kernel = weights_ptr[9]; - ffn_weights.intermediate_weight.bias = weights_ptr[10]; - ffn_weights.output_weight.kernel = weights_ptr[11]; - ffn_weights.output_weight.bias = weights_ptr[12]; - ffn_layernorm_weights.gamma = weights_ptr[13]; - ffn_layernorm_weights.beta = weights_ptr[14]; + ffn_weights.intermediate_weight.bias = weights_ptr[10]; + ffn_weights.output_weight.kernel = weights_ptr[11]; + ffn_weights.output_weight.bias = weights_ptr[12]; + ffn_layernorm_weights.gamma = weights_ptr[13]; + ffn_layernorm_weights.beta = weights_ptr[14]; is_maintain_buffer = true; } void setWeightSize() { - weights_size[0] = hidden_units_ * hidden_units_ * 3; - weights_size[1] = hidden_units_ * hidden_units_; - 
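// Illustrative sketch, not part of the patch: setWeightPtr()/setWeightSize() above
// keep the Q, K and V projection kernels in a single contiguous allocation
// (weights_size[0] is 3 * hidden_units * hidden_units) and derive the K and V
// pointers as offsets from the base. Standalone version of that packed layout;
// PackedQkv and alloc_packed_qkv are invented names for the sketch.
#include <cuda_runtime.h>
#include <cstddef>

struct PackedQkv {
    float* q = nullptr;
    float* k = nullptr;
    float* v = nullptr;
};

PackedQkv alloc_packed_qkv(size_t hidden_units)
{
    const size_t per_matrix = hidden_units * hidden_units;
    float* base = nullptr;
    cudaMalloc(reinterpret_cast<void**>(&base), 3 * per_matrix * sizeof(float));

    PackedQkv w;
    w.q = base;                   // [0, H*H)
    w.k = base + per_matrix;      // [H*H, 2*H*H)
    w.v = base + 2 * per_matrix;  // [2*H*H, 3*H*H)
    return w;
}
// end sketch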
weights_size[2] = hidden_units_; - weights_size[3] = hidden_units_; - weights_size[4] = hidden_units_; - weights_size[5] = hidden_units_ * 2; - weights_size[6] = hidden_units_ * hidden_units_; - weights_size[7] = hidden_units_; - weights_size[8] = hidden_units_; - weights_size[9] = hidden_units_ * inter_size_; + weights_size[0] = hidden_units_ * hidden_units_ * 3; + weights_size[1] = hidden_units_ * hidden_units_; + weights_size[2] = hidden_units_; + weights_size[3] = hidden_units_; + weights_size[4] = hidden_units_; + weights_size[5] = hidden_units_ * 2; + weights_size[6] = hidden_units_ * hidden_units_; + weights_size[7] = hidden_units_; + weights_size[8] = hidden_units_; + weights_size[9] = hidden_units_ * inter_size_; weights_size[10] = inter_size_; weights_size[11] = hidden_units_ * inter_size_; weights_size[12] = hidden_units_; @@ -147,11 +147,11 @@ struct XlnetLayerWeight { weights_size[14] = hidden_units_; } - int hidden_units_; - int inter_size_; + int hidden_units_; + int inter_size_; bool is_maintain_buffer = false; - T* weights_ptr[NUM_WEIGHTS]; - int weights_size[NUM_WEIGHTS]; + T* weights_ptr[NUM_WEIGHTS]; + int weights_size[NUM_WEIGHTS]; }; } // namespace fastertransformer diff --git a/src/fastertransformer/models/xlnet/xlnet_gemm.cc b/src/fastertransformer/models/xlnet/xlnet_gemm.cc index bedda97ec..92aa651ce 100644 --- a/src/fastertransformer/models/xlnet/xlnet_gemm.cc +++ b/src/fastertransformer/models/xlnet/xlnet_gemm.cc @@ -28,11 +28,11 @@ int main(int argc, char* argv[]) return 0; } - const int batch_size = atoi(argv[1]); - const int seq_len = atoi(argv[2]); - const int head_num = atoi(argv[3]); - const int size_per_head = atoi(argv[4]); - const ft::CublasDataType data_type = static_cast(atoi(argv[5])); // 0 FP32, 1 FP16, 2 BF 16 + const int batch_size = atoi(argv[1]); + const int seq_len = atoi(argv[2]); + const int head_num = atoi(argv[3]); + const int size_per_head = atoi(argv[4]); + const ft::CublasDataType data_type = static_cast(atoi(argv[5])); // 0 FP32, 1 FP16, 2 BF 16 printf("[INFO] arguments: \n"); printf(" batch_size: %d \n", batch_size); printf(" head_num: %d \n", head_num); @@ -41,9 +41,9 @@ int main(int argc, char* argv[]) std::cout << std::endl; int hidden_units_ = size_per_head * head_num; - int inter_size_ = 4 * hidden_units_; + int inter_size_ = 4 * hidden_units_; - void* gemm_test_buf; + void* gemm_test_buf; size_t buf_size_in_byte = ft::calGemmTestBufSizeInByteXlnet( batch_size, seq_len, head_num, size_per_head, inter_size_, hidden_units_, data_type); size_t total, free; diff --git a/src/fastertransformer/tensorrt_plugin/swin/serialize.hpp b/src/fastertransformer/tensorrt_plugin/swin/serialize.hpp index be49fc26f..9b79e9b3e 100644 --- a/src/fastertransformer/tensorrt_plugin/swin/serialize.hpp +++ b/src/fastertransformer/tensorrt_plugin/swin/serialize.hpp @@ -73,7 +73,7 @@ struct Serializer { } static void deserialize(void const** buffer, size_t* buffer_size, const char** value) { - *value = static_cast(*buffer); + *value = static_cast(*buffer); size_t data_size = strnlen(*value, *buffer_size) + 1; assert(*buffer_size >= data_size); reinterpret_cast(*buffer) += data_size; diff --git a/src/fastertransformer/tensorrt_plugin/swin/swinTransformerINT8Plugin.cpp b/src/fastertransformer/tensorrt_plugin/swin/swinTransformerINT8Plugin.cpp index 0a0fe068b..56ffb7d3c 100644 --- a/src/fastertransformer/tensorrt_plugin/swin/swinTransformerINT8Plugin.cpp +++ b/src/fastertransformer/tensorrt_plugin/swin/swinTransformerINT8Plugin.cpp @@ -32,29 +32,29 @@ using 
namespace std; namespace fastertransformer { // Static class fields initialization -PluginFieldCollection SwinTransformerINT8PluginCreator::mFC{}; +PluginFieldCollection SwinTransformerINT8PluginCreator::mFC{}; std::vector SwinTransformerINT8PluginCreator::mPluginAttributes; REGISTER_TENSORRT_PLUGIN(SwinTransformerINT8PluginCreator); template -SwinTransformerINT8Plugin::SwinTransformerINT8Plugin(const std::string& name, - const int int8_mode, - const int max_batch_size, - const int img_size, - const int patch_size, - const int in_chans, - const int embed_dim, - const int window_size, - int* depths, - int* num_heads, - const bool ape, - const bool patch_norm, - const int layer_num, - const float mlp_ratio, - const bool qkv_bias, - const float qk_scale, - const std::vector& w, +SwinTransformerINT8Plugin::SwinTransformerINT8Plugin(const std::string& name, + const int int8_mode, + const int max_batch_size, + const int img_size, + const int patch_size, + const int in_chans, + const int embed_dim, + const int window_size, + int* depths, + int* num_heads, + const bool ape, + const bool patch_norm, + const int layer_num, + const float mlp_ratio, + const bool qkv_bias, + const float qk_scale, + const std::vector& w, const std::vector& d_amax, const std::vector& h_amax): int8_mode_(int8_mode), @@ -75,7 +75,7 @@ SwinTransformerINT8Plugin::SwinTransformerINT8Plugin(const std::string& name, check_cuda_error(cublasCreate(&cublas_handle_)); check_cuda_error(cublasLtCreate(&cublaslt_handle_)); checkCUDNN(cudnnCreate(&cudnn_handle_)); - sm_ = getSMVersion(); + sm_ = getSMVersion(); bool _use_ORDER_COL32_2R_4R4 = false; #if (CUDART_VERSION >= 11000) if (sm_ >= 80) { @@ -91,7 +91,7 @@ SwinTransformerINT8Plugin::SwinTransformerINT8Plugin(const std::string& name, exit(-1); } - depths_ = (int*)malloc(layer_num * sizeof(int)); + depths_ = (int*)malloc(layer_num * sizeof(int)); num_heads_ = (int*)malloc(layer_num * sizeof(int)); memcpy(depths_, depths, layer_num * sizeof(int)); memcpy(num_heads_, num_heads, layer_num * sizeof(int)); @@ -103,7 +103,7 @@ SwinTransformerINT8Plugin::SwinTransformerINT8Plugin(const std::string& name, weight_size_, layer_num, embed_dim, mlp_ratio, window_size, img_size, patch_size, in_chans, depths, num_heads); int weight_idx = 0; - int amax_idx = 0; + int amax_idx = 0; int hidden_dim = embed_dim; for (int l = 0; l < layer_num; l++) { SwinTransformerINT8BasicLayerWeight bl; @@ -139,10 +139,10 @@ SwinTransformerINT8Plugin::SwinTransformerINT8Plugin(const std::string& name, weight_idx++; p.ffn_layernorm_weights.beta = cudaMallocAndCopy(weights_, w[weight_idx], weight_size_, weight_idx); weight_idx++; - p.scalelist.size_ = ACTIVATION_AMAX_NUM + 5 + INT8O_GEMM_NUM + TRT_AMAX_NUM; - p.scalelist.p2_offset_ = ACTIVATION_AMAX_NUM; - p.scalelist.p3_offset_ = ACTIVATION_AMAX_NUM + 5; - p.scalelist.p4_offset_ = ACTIVATION_AMAX_NUM + 5 + INT8O_GEMM_NUM; + p.scalelist.size_ = ACTIVATION_AMAX_NUM + 5 + INT8O_GEMM_NUM + TRT_AMAX_NUM; + p.scalelist.p2_offset_ = ACTIVATION_AMAX_NUM; + p.scalelist.p3_offset_ = ACTIVATION_AMAX_NUM + 5; + p.scalelist.p4_offset_ = ACTIVATION_AMAX_NUM + 5 + INT8O_GEMM_NUM; p.scalelist.d_scale_list_ = d_amaxCopy(d_amaxlist_, d_amax[amax_idx], 96, amax_idx); p.scalelist.h_scale_list_ = h_amaxCopy(h_amaxlist_, h_amax[amax_idx], 96, amax_idx); amax_idx++; @@ -177,9 +177,9 @@ SwinTransformerINT8Plugin::SwinTransformerINT8Plugin(const std::string& name, params_.norm_weights.beta = cudaMallocAndCopy(weights_, w[weight_idx], weight_size_, weight_idx); weight_idx++; - cublasAlgoMap_ 
= new cublasAlgoMap("igemm.config", ""); + cublasAlgoMap_ = new cublasAlgoMap("igemm.config", ""); cublasWrapperMutex_ = new std::mutex(); - allocator_ = new Allocator(getDevice()); + allocator_ = new Allocator(getDevice()); cublasINT8MMWrapper* cublas_wrapper = new cublasINT8MMWrapper( cublas_handle_, cublaslt_handle_, nullptr, cublasAlgoMap_, cublasWrapperMutex_, _use_ORDER_COL32_2R_4R4); @@ -213,22 +213,22 @@ SwinTransformerINT8Plugin::SwinTransformerINT8Plugin(const std::string& name, } template -SwinTransformerINT8Plugin::SwinTransformerINT8Plugin(const std::string& name, - const int int8_mode, - const int max_batch_size, - const int img_size, - const int patch_size, - const int in_chans, - const int embed_dim, - const int window_size, - int* depths, - int* num_heads, - const bool ape, - const bool patch_norm, - const int layer_num, - const float mlp_ratio, - const bool qkv_bias, - const float qk_scale, +SwinTransformerINT8Plugin::SwinTransformerINT8Plugin(const std::string& name, + const int int8_mode, + const int max_batch_size, + const int img_size, + const int patch_size, + const int in_chans, + const int embed_dim, + const int window_size, + int* depths, + int* num_heads, + const bool ape, + const bool patch_norm, + const int layer_num, + const float mlp_ratio, + const bool qkv_bias, + const float qk_scale, const std::vector& w, const std::vector& d_amax, const std::vector& h_amax): @@ -252,7 +252,7 @@ SwinTransformerINT8Plugin::SwinTransformerINT8Plugin(const std::string& name, check_cuda_error(cublasLtCreate(&cublaslt_handle_)); checkCUDNN(cudnnCreate(&cudnn_handle_)); - sm_ = getSMVersion(); + sm_ = getSMVersion(); bool _use_ORDER_COL32_2R_4R4 = false; #if (CUDART_VERSION >= 11000) if (sm_ >= 80) { @@ -268,7 +268,7 @@ SwinTransformerINT8Plugin::SwinTransformerINT8Plugin(const std::string& name, exit(-1); } - depths_ = (int*)malloc(layer_num * sizeof(int)); + depths_ = (int*)malloc(layer_num * sizeof(int)); num_heads_ = (int*)malloc(layer_num * sizeof(int)); memcpy(depths_, depths, layer_num * sizeof(int)); memcpy(num_heads_, num_heads, layer_num * sizeof(int)); @@ -276,27 +276,27 @@ SwinTransformerINT8Plugin::SwinTransformerINT8Plugin(const std::string& name, output_dim_ = int(pow(2, layer_num - 1)) * embed_dim; int weight_idx = 0; - int amax_idx = 0; + int amax_idx = 0; for (int l = 0; l < layer_num; l++) { SwinTransformerINT8BasicLayerWeight bl; for (int di = 0; di < depths[l]; di++) { SwinTransformerINT8BlockWeight p; - p.attention_weights.query_weight.kernel = cudaMallocAndCopy(weights_, w[weight_idx++]); - p.attention_weights.query_weight.bias = cudaMallocAndCopy(weights_, w[weight_idx++]); + p.attention_weights.query_weight.kernel = cudaMallocAndCopy(weights_, w[weight_idx++]); + p.attention_weights.query_weight.bias = cudaMallocAndCopy(weights_, w[weight_idx++]); p.attention_weights.attention_output_weight.kernel = cudaMallocAndCopy(weights_, w[weight_idx++]); - p.attention_weights.attention_output_weight.bias = cudaMallocAndCopy(weights_, w[weight_idx++]); - p.ffn_weights.intermediate_weight.kernel = cudaMallocAndCopy(weights_, w[weight_idx++]); - p.ffn_weights.intermediate_weight.bias = cudaMallocAndCopy(weights_, w[weight_idx++]); - p.ffn_weights.output_weight.kernel = cudaMallocAndCopy(weights_, w[weight_idx++]); - p.ffn_weights.output_weight.bias = cudaMallocAndCopy(weights_, w[weight_idx++]); - p.attn_layernorm_weights.gamma = cudaMallocAndCopy(weights_, w[weight_idx++]); - p.attn_layernorm_weights.beta = cudaMallocAndCopy(weights_, w[weight_idx++]); - 
p.ffn_layernorm_weights.gamma = cudaMallocAndCopy(weights_, w[weight_idx++]); - p.ffn_layernorm_weights.beta = cudaMallocAndCopy(weights_, w[weight_idx++]); - p.scalelist.size_ = ACTIVATION_AMAX_NUM + 5 + INT8O_GEMM_NUM + TRT_AMAX_NUM; - p.scalelist.p2_offset_ = ACTIVATION_AMAX_NUM; - p.scalelist.p3_offset_ = ACTIVATION_AMAX_NUM + 5; - p.scalelist.p4_offset_ = ACTIVATION_AMAX_NUM + 5 + INT8O_GEMM_NUM; + p.attention_weights.attention_output_weight.bias = cudaMallocAndCopy(weights_, w[weight_idx++]); + p.ffn_weights.intermediate_weight.kernel = cudaMallocAndCopy(weights_, w[weight_idx++]); + p.ffn_weights.intermediate_weight.bias = cudaMallocAndCopy(weights_, w[weight_idx++]); + p.ffn_weights.output_weight.kernel = cudaMallocAndCopy(weights_, w[weight_idx++]); + p.ffn_weights.output_weight.bias = cudaMallocAndCopy(weights_, w[weight_idx++]); + p.attn_layernorm_weights.gamma = cudaMallocAndCopy(weights_, w[weight_idx++]); + p.attn_layernorm_weights.beta = cudaMallocAndCopy(weights_, w[weight_idx++]); + p.ffn_layernorm_weights.gamma = cudaMallocAndCopy(weights_, w[weight_idx++]); + p.ffn_layernorm_weights.beta = cudaMallocAndCopy(weights_, w[weight_idx++]); + p.scalelist.size_ = ACTIVATION_AMAX_NUM + 5 + INT8O_GEMM_NUM + TRT_AMAX_NUM; + p.scalelist.p2_offset_ = ACTIVATION_AMAX_NUM; + p.scalelist.p3_offset_ = ACTIVATION_AMAX_NUM + 5; + p.scalelist.p4_offset_ = ACTIVATION_AMAX_NUM + 5 + INT8O_GEMM_NUM; p.scalelist.d_scale_list_ = d_amaxCopy(d_amaxlist_, d_amax[amax_idx]); p.scalelist.h_scale_list_ = h_amaxCopy(h_amaxlist_, h_amax[amax_idx]); amax_idx++; @@ -304,21 +304,21 @@ SwinTransformerINT8Plugin::SwinTransformerINT8Plugin(const std::string& name, bl.block_weight_list.push_back(p); } bl.merge_layernorm_weights.gamma = cudaMallocAndCopy(weights_, w[weight_idx++]); - bl.merge_layernorm_weights.beta = cudaMallocAndCopy(weights_, w[weight_idx++]); - bl.merge_linear_weights.kernel = cudaMallocAndCopy(weights_, w[weight_idx++]); - bl.attn_mask = cudaMallocAndCopy(weights_, w[weight_idx++]); + bl.merge_layernorm_weights.beta = cudaMallocAndCopy(weights_, w[weight_idx++]); + bl.merge_linear_weights.kernel = cudaMallocAndCopy(weights_, w[weight_idx++]); + bl.attn_mask = cudaMallocAndCopy(weights_, w[weight_idx++]); params_.basic_layer_weight_list.push_back(bl); } params_.patchEmbed_linear_weights.kernel = cudaMallocAndCopy(weights_, w[weight_idx++]); - params_.patchEmbed_linear_weights.bias = cudaMallocAndCopy(weights_, w[weight_idx++]); - params_.patchEmbed_norm_weights.gamma = cudaMallocAndCopy(weights_, w[weight_idx++]); - params_.patchEmbed_norm_weights.beta = cudaMallocAndCopy(weights_, w[weight_idx++]); - params_.norm_weights.gamma = cudaMallocAndCopy(weights_, w[weight_idx++]); - params_.norm_weights.beta = cudaMallocAndCopy(weights_, w[weight_idx++]); + params_.patchEmbed_linear_weights.bias = cudaMallocAndCopy(weights_, w[weight_idx++]); + params_.patchEmbed_norm_weights.gamma = cudaMallocAndCopy(weights_, w[weight_idx++]); + params_.patchEmbed_norm_weights.beta = cudaMallocAndCopy(weights_, w[weight_idx++]); + params_.norm_weights.gamma = cudaMallocAndCopy(weights_, w[weight_idx++]); + params_.norm_weights.beta = cudaMallocAndCopy(weights_, w[weight_idx++]); - cublasAlgoMap_ = new cublasAlgoMap("igemm.config", ""); + cublasAlgoMap_ = new cublasAlgoMap("igemm.config", ""); cublasWrapperMutex_ = new std::mutex(); - allocator_ = new Allocator(getDevice()); + allocator_ = new Allocator(getDevice()); cublasINT8MMWrapper* cublas_wrapper = new cublasINT8MMWrapper( cublas_handle_, cublaslt_handle_, 
nullptr, cublasAlgoMap_, cublasWrapperMutex_, _use_ORDER_COL32_2R_4R4); @@ -359,7 +359,7 @@ SwinTransformerINT8Plugin::SwinTransformerINT8Plugin(const std::string& name, check_cuda_error(cublasLtCreate(&cublaslt_handle_)); checkCUDNN(cudnnCreate(&cudnn_handle_)); - sm_ = getSMVersion(); + sm_ = getSMVersion(); bool _use_ORDER_COL32_2R_4R4 = false; #if (CUDART_VERSION >= 11000) if (sm_ >= 80) { @@ -388,8 +388,8 @@ SwinTransformerINT8Plugin::SwinTransformerINT8Plugin(const std::string& name, weight_size_.push_back(tmp); } - depths_ = (int*)malloc(layer_num_ * sizeof(int)); - num_heads_ = (int*)malloc(layer_num_ * sizeof(int)); + depths_ = (int*)malloc(layer_num_ * sizeof(int)); + num_heads_ = (int*)malloc(layer_num_ * sizeof(int)); const char* d = static_cast(data); memcpy(depths_, d, layer_num_ * sizeof(int)); d = d + layer_num_ * sizeof(int); @@ -428,27 +428,27 @@ SwinTransformerINT8Plugin::SwinTransformerINT8Plugin(const std::string& name, } int weight_idx = 0; - int amax_idx = 0; + int amax_idx = 0; for (int l = 0; l < layer_num_; l++) { SwinTransformerINT8BasicLayerWeight bl; for (int di = 0; di < depths_[l]; di++) { SwinTransformerINT8BlockWeight p; - p.attention_weights.query_weight.kernel = weights_[weight_idx++]; - p.attention_weights.query_weight.bias = weights_[weight_idx++]; + p.attention_weights.query_weight.kernel = weights_[weight_idx++]; + p.attention_weights.query_weight.bias = weights_[weight_idx++]; p.attention_weights.attention_output_weight.kernel = weights_[weight_idx++]; - p.attention_weights.attention_output_weight.bias = weights_[weight_idx++]; - p.ffn_weights.intermediate_weight.kernel = weights_[weight_idx++]; - p.ffn_weights.intermediate_weight.bias = weights_[weight_idx++]; - p.ffn_weights.output_weight.kernel = weights_[weight_idx++]; - p.ffn_weights.output_weight.bias = weights_[weight_idx++]; - p.attn_layernorm_weights.gamma = weights_[weight_idx++]; - p.attn_layernorm_weights.beta = weights_[weight_idx++]; - p.ffn_layernorm_weights.gamma = weights_[weight_idx++]; - p.ffn_layernorm_weights.beta = weights_[weight_idx++]; - p.scalelist.size_ = ACTIVATION_AMAX_NUM + 5 + INT8O_GEMM_NUM + TRT_AMAX_NUM; - p.scalelist.p2_offset_ = ACTIVATION_AMAX_NUM; - p.scalelist.p3_offset_ = ACTIVATION_AMAX_NUM + 5; - p.scalelist.p4_offset_ = ACTIVATION_AMAX_NUM + 5 + INT8O_GEMM_NUM; + p.attention_weights.attention_output_weight.bias = weights_[weight_idx++]; + p.ffn_weights.intermediate_weight.kernel = weights_[weight_idx++]; + p.ffn_weights.intermediate_weight.bias = weights_[weight_idx++]; + p.ffn_weights.output_weight.kernel = weights_[weight_idx++]; + p.ffn_weights.output_weight.bias = weights_[weight_idx++]; + p.attn_layernorm_weights.gamma = weights_[weight_idx++]; + p.attn_layernorm_weights.beta = weights_[weight_idx++]; + p.ffn_layernorm_weights.gamma = weights_[weight_idx++]; + p.ffn_layernorm_weights.beta = weights_[weight_idx++]; + p.scalelist.size_ = ACTIVATION_AMAX_NUM + 5 + INT8O_GEMM_NUM + TRT_AMAX_NUM; + p.scalelist.p2_offset_ = ACTIVATION_AMAX_NUM; + p.scalelist.p3_offset_ = ACTIVATION_AMAX_NUM + 5; + p.scalelist.p4_offset_ = ACTIVATION_AMAX_NUM + 5 + INT8O_GEMM_NUM; p.scalelist.d_scale_list_ = d_amaxlist_[amax_idx]; p.scalelist.h_scale_list_ = h_amaxlist_[amax_idx]; amax_idx++; @@ -456,21 +456,21 @@ SwinTransformerINT8Plugin::SwinTransformerINT8Plugin(const std::string& name, bl.block_weight_list.push_back(p); } bl.merge_layernorm_weights.gamma = weights_[weight_idx++]; - bl.merge_layernorm_weights.beta = weights_[weight_idx++]; - bl.merge_linear_weights.kernel 
= weights_[weight_idx++]; - bl.attn_mask = weights_[weight_idx++]; + bl.merge_layernorm_weights.beta = weights_[weight_idx++]; + bl.merge_linear_weights.kernel = weights_[weight_idx++]; + bl.attn_mask = weights_[weight_idx++]; params_.basic_layer_weight_list.push_back(bl); } params_.patchEmbed_linear_weights.kernel = weights_[weight_idx++]; - params_.patchEmbed_linear_weights.bias = weights_[weight_idx++]; - params_.patchEmbed_norm_weights.gamma = weights_[weight_idx++]; - params_.patchEmbed_norm_weights.beta = weights_[weight_idx++]; - params_.norm_weights.gamma = weights_[weight_idx++]; - params_.norm_weights.beta = weights_[weight_idx++]; + params_.patchEmbed_linear_weights.bias = weights_[weight_idx++]; + params_.patchEmbed_norm_weights.gamma = weights_[weight_idx++]; + params_.patchEmbed_norm_weights.beta = weights_[weight_idx++]; + params_.norm_weights.gamma = weights_[weight_idx++]; + params_.norm_weights.beta = weights_[weight_idx++]; - cublasAlgoMap_ = new cublasAlgoMap(IGEMM_CONFIG, ""); + cublasAlgoMap_ = new cublasAlgoMap(IGEMM_CONFIG, ""); cublasWrapperMutex_ = new std::mutex(); - allocator_ = new Allocator(getDevice()); + allocator_ = new Allocator(getDevice()); cublasINT8MMWrapper* cublas_wrapper = new cublasINT8MMWrapper( cublas_handle_, cublaslt_handle_, nullptr, cublasAlgoMap_, cublasWrapperMutex_, _use_ORDER_COL32_2R_4R4); @@ -516,6 +516,7 @@ SwinTransformerINT8Plugin::~SwinTransformerINT8Plugin() delete h_amaxlist_[i]; } check_cuda_error(cublasDestroy(cublas_handle_)); + check_cuda_error(cublasLtDestroy(cublaslt_handle_)); checkCUDNN(cudnnDestroy(cudnn_handle_)); delete cublasWrapperMutex_; delete cublasAlgoMap_; @@ -555,41 +556,51 @@ nvinfer1::IPluginV2DynamicExt* SwinTransformerINT8Plugin::clone() const noexc } template -DimsExprs SwinTransformerINT8Plugin::getOutputDimensions(int outputIndex, +DimsExprs SwinTransformerINT8Plugin::getOutputDimensions(int outputIndex, const DimsExprs* inputs, - int nbInputs, - IExprBuilder& exprBuilder) noexcept + int nbInputs, + IExprBuilder& exprBuilder) noexcept { // Input is B*in_chans*H*W, output should be B*dim*1*1 for fc layer assert(outputIndex == 0); // Copy over everything DimsExprs output; output.nbDims = 4; - output.d[0] = inputs[0].d[0]; - output.d[1] = exprBuilder.constant(output_dim_); - output.d[2] = exprBuilder.constant(1); - output.d[3] = exprBuilder.constant(1); + output.d[0] = inputs[0].d[0]; + output.d[1] = exprBuilder.constant(output_dim_); + output.d[2] = exprBuilder.constant(1); + output.d[3] = exprBuilder.constant(1); return output; } template -bool SwinTransformerINT8Plugin::supportsFormatCombination(int pos, +bool SwinTransformerINT8Plugin::supportsFormatCombination(int pos, const PluginTensorDesc* inOut, - int nbInputs, - int nbOutputs) noexcept + int nbInputs, + int nbOutputs) noexcept { - assert(pos >= 0); - assert(pos < 2); + bool res = false; + assert(pos >= 0 && pos < 2); assert(nbInputs == 1); + switch (pos) { + case 0: // input + case 1: // output + res = (inOut[pos].type + == (std::is_same::value ? 
nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT)) + && (inOut[pos].format == nvinfer1::TensorFormat::kLINEAR); + break; + default: + break; + } - return true; + return res; } template void SwinTransformerINT8Plugin::configurePlugin(const DynamicPluginTensorDesc* in, - int nbInputs, + int nbInputs, const DynamicPluginTensorDesc* out, - int nbOutputs) noexcept + int nbOutputs) noexcept { assert(nbInputs == 1); assert(nbOutputs == 1); @@ -597,18 +608,18 @@ void SwinTransformerINT8Plugin::configurePlugin(const DynamicPluginTensorDesc template size_t SwinTransformerINT8Plugin::getWorkspaceSize(const PluginTensorDesc* inputs, - int nbInputs, + int nbInputs, const PluginTensorDesc* outputs, - int nbOutputs) const noexcept + int nbOutputs) const noexcept { return 0; } // IPluginV2Ext Methods template -nvinfer1::DataType SwinTransformerINT8Plugin::getOutputDataType(int index, +nvinfer1::DataType SwinTransformerINT8Plugin::getOutputDataType(int index, const nvinfer1::DataType* inputTypes, - int nbInputs) const noexcept + int nbInputs) const noexcept { assert(index == 0); assert(inputTypes[0] == nvinfer1::DataType::kFLOAT || inputTypes[0] == nvinfer1::DataType::kHALF); @@ -702,23 +713,11 @@ void SwinTransformerINT8Plugin::serialize(void* buffer) const noexcept d += weight_size_[i] * sizeof(T); } for (int i = 0; i < d_amaxlist_.size(); i++) { - // float* ptr = (float*)d; - // printf("[[%d]]: ", i); - // for(int j = 0; j < 96; j ++){ - // printf("%.6f ", ptr[j]); - // } - // printf("\n"); check_cuda_error(cudaMemcpy(d, d_amaxlist_[i], 96 * sizeof(float), cudaMemcpyDeviceToHost)); d += 96 * sizeof(float); } for (int i = 0; i < h_amaxlist_.size(); i++) { - // float* ptr = h_amaxlist_[i]; - // printf("[[%d]]: ", i); - // for(int j = 0; j < 96; j ++){ - // printf("%.6f ", ptr[j]); - // } - // printf("\n"); check_cuda_error(cudaMemcpy(d, h_amaxlist_[i], 96 * sizeof(float), cudaMemcpyHostToHost)); d += 96 * sizeof(float); } @@ -745,10 +744,10 @@ const char* SwinTransformerINT8Plugin::getPluginNamespace() const noexcept template int SwinTransformerINT8Plugin::enqueue(const PluginTensorDesc* inputDesc, const PluginTensorDesc* outputDesc, - const void* const* inputs, - void* const* outputs, - void* workspace, - cudaStream_t stream) noexcept + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) noexcept { int batch_size = inputDesc->dims.d[0]; assert(batch_size <= max_batch_size_); @@ -756,7 +755,7 @@ int SwinTransformerINT8Plugin::enqueue(const PluginTensorDesc* inputDesc, assert(img_size_ == inputDesc->dims.d[2]); assert(img_size_ == inputDesc->dims.d[3]); - int sm_ptr[1] = {sm_}; + int sm_ptr[1] = {sm_}; std::vector input_tensors = std::vector{ Tensor{MEMORY_GPU, getTensorType(), @@ -777,7 +776,7 @@ int SwinTransformerINT8Plugin::enqueue(const PluginTensorDesc* inputDesc, SwinTransformerINT8PluginCreator::SwinTransformerINT8PluginCreator() { mFC.nbFields = mPluginAttributes.size(); - mFC.fields = mPluginAttributes.data(); + mFC.fields = mPluginAttributes.data(); } const char* SwinTransformerINT8PluginCreator::getPluginName() const noexcept @@ -797,21 +796,21 @@ const PluginFieldCollection* SwinTransformerINT8PluginCreator::getFieldNames() n IPluginV2* SwinTransformerINT8PluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) noexcept { - int int8_mode; - int max_batch_size; - int img_size; - int patch_size; - int in_chans; - int embed_dim; - int window_size; - int* depths = nullptr; - int* num_heads = nullptr; - bool ape; - bool 
patch_norm; - int layer_num; - float mlp_ratio; - bool qkv_bias; - float qk_scale; + int int8_mode; + int max_batch_size; + int img_size; + int patch_size; + int in_chans; + int embed_dim; + int window_size; + int* depths = nullptr; + int* num_heads = nullptr; + bool ape; + bool patch_norm; + int layer_num; + float mlp_ratio; + bool qkv_bias; + float qk_scale; std::vector w; std::vector d_amax; std::vector h_amax; @@ -995,16 +994,16 @@ IPluginV2* SwinTransformerINT8PluginCreator::createPlugin(const char* name, cons return p; } else { - printf("[ERROR][SwinTransformerINT8PluginCreator::createPlugin] unsupport datatype.\n"); + printf("[ERROR][SwinTransformerINT8PluginCreator::createPlugin] unsupported datatype.\n"); exit(-1); } } IPluginV2* SwinTransformerINT8PluginCreator::deserializePlugin(const char* name, const void* serialData, - size_t serialLength) noexcept + size_t serialLength) noexcept { - int type_id; + int type_id; size_t int_length = sizeof(int); deserialize_value(&serialData, &int_length, &type_id); // This object will be deleted when the network is destroyed, which will @@ -1014,7 +1013,7 @@ IPluginV2* SwinTransformerINT8PluginCreator::deserializePlugin(const char* name, else if (type_id == 1) return new SwinTransformerINT8Plugin(name, serialData, serialLength); else { - printf("[ERROR][SwinTransformerINT8PluginCreator::deserializePlugin] unsupport data type %d\n", type_id); + printf("[ERROR][SwinTransformerINT8PluginCreator::deserializePlugin] unsupported data type %d\n", type_id); exit(-1); } } diff --git a/src/fastertransformer/tensorrt_plugin/swin/swinTransformerINT8Plugin.h b/src/fastertransformer/tensorrt_plugin/swin/swinTransformerINT8Plugin.h index b999c1100..30db5db05 100644 --- a/src/fastertransformer/tensorrt_plugin/swin/swinTransformerINT8Plugin.h +++ b/src/fastertransformer/tensorrt_plugin/swin/swinTransformerINT8Plugin.h @@ -62,76 +62,76 @@ template class SwinTransformerINT8Plugin: public nvinfer1::IPluginV2DynamicExt { private: const std::string layer_name_; - std::string namespace_; - - std::vector weights_; // in device memory - std::vector d_amaxlist_; // in device memory - std::vector h_amaxlist_; // in host memory - std::vector weight_size_; - cublasHandle_t cublas_handle_ = nullptr; - cublasLtHandle_t cublaslt_handle_ = nullptr; - cudnnHandle_t cudnn_handle_ = nullptr; - SwinTransformerINT8Weight params_; - SwinTransformerINT8* swin_transformer_ = nullptr; - std::mutex* cublasWrapperMutex_ = nullptr; - cublasAlgoMap* cublasAlgoMap_ = nullptr; - fastertransformer::Allocator* allocator_ = nullptr; - int int8_mode_; - int output_dim_; - int weight_num_; - int max_batch_size_; - int img_size_; - int patch_size_; - int in_chans_; - int embed_dim_; - int window_size_; - bool ape_; - int patch_norm_; - int layer_num_; - float mlp_ratio_; - bool qkv_bias_; - float qk_scale_; - int* depths_; - int* num_heads_; + std::string namespace_; + + std::vector weights_; // in device memory + std::vector d_amaxlist_; // in device memory + std::vector h_amaxlist_; // in host memory + std::vector weight_size_; + cublasHandle_t cublas_handle_ = nullptr; + cublasLtHandle_t cublaslt_handle_ = nullptr; + cudnnHandle_t cudnn_handle_ = nullptr; + SwinTransformerINT8Weight params_; + SwinTransformerINT8* swin_transformer_ = nullptr; + std::mutex* cublasWrapperMutex_ = nullptr; + cublasAlgoMap* cublasAlgoMap_ = nullptr; + fastertransformer::Allocator* allocator_ = nullptr; + int int8_mode_; + int output_dim_; + int weight_num_; + int max_batch_size_; + int img_size_; + int 
patch_size_; + int in_chans_; + int embed_dim_; + int window_size_; + bool ape_; + int patch_norm_; + int layer_num_; + float mlp_ratio_; + bool qkv_bias_; + float qk_scale_; + int* depths_; + int* num_heads_; public: int sm_; - SwinTransformerINT8Plugin(const std::string& name, - const int int8_mode, - const int max_batch, - const int img_size, - const int patch_size, - const int in_chans, - const int embed_dim, - const int window_size, - int* depths, - int* num_heads, - const bool ape, - const bool patch_norm, - const int layer_num, - const float mlp_ratio, - const bool qkv_bias, - const float qk_scale, - const std::vector& w, + SwinTransformerINT8Plugin(const std::string& name, + const int int8_mode, + const int max_batch, + const int img_size, + const int patch_size, + const int in_chans, + const int embed_dim, + const int window_size, + int* depths, + int* num_heads, + const bool ape, + const bool patch_norm, + const int layer_num, + const float mlp_ratio, + const bool qkv_bias, + const float qk_scale, + const std::vector& w, const std::vector& d_amax, const std::vector& h_amax); - SwinTransformerINT8Plugin(const std::string& name, - const int int8_mode, - const int max_batch, - const int img_size, - const int patch_size, - const int in_chans, - const int embed_dim, - const int window_size, - int* depths, - int* num_heads, - const bool ape, - const bool patch_norm, - const int layer_num, - const float mlp_ratio, - const bool qkv_bias, - const float qk_scale, + SwinTransformerINT8Plugin(const std::string& name, + const int int8_mode, + const int max_batch, + const int img_size, + const int patch_size, + const int in_chans, + const int embed_dim, + const int window_size, + int* depths, + int* num_heads, + const bool ape, + const bool patch_norm, + const int layer_num, + const float mlp_ratio, + const bool qkv_bias, + const float qk_scale, const std::vector& w, const std::vector& d_amax, const std::vector& h_amax); @@ -146,28 +146,28 @@ class SwinTransformerINT8Plugin: public nvinfer1::IPluginV2DynamicExt { // IPluginV2DynamicExt Methods nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, - const nvinfer1::DimsExprs* inputs, - int nbInputs, - nvinfer1::IExprBuilder& exprBuilder) noexcept override; - bool supportsFormatCombination(int pos, - const nvinfer1::PluginTensorDesc* inOut, - int nbInputs, - int nbOutputs) noexcept override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, - int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, - int nbOutputs) noexcept override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, - int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, - int nbOutputs) const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, - const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, - void* const* outputs, - void* workspace, - cudaStream_t stream) noexcept override; + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) noexcept override; + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* inOut, + int nbInputs, + int nbOutputs) noexcept override; + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nbOutputs) noexcept override; + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + 
int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const noexcept override; + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) noexcept override; // IPluginV2Ext Methods nvinfer1::DataType @@ -176,19 +176,19 @@ class SwinTransformerINT8Plugin: public nvinfer1::IPluginV2DynamicExt { // IPluginV2 Methods const char* getPluginType() const noexcept override; const char* getPluginVersion() const noexcept override; - int getNbOutputs() const noexcept override; - int initialize() noexcept override; - void terminate() noexcept override; - size_t getSerializationSize() const noexcept override; - void serialize(void* buffer) const noexcept override; - void destroy() noexcept override; - void setPluginNamespace(const char* pluginNamespace) noexcept override; + int getNbOutputs() const noexcept override; + int initialize() noexcept override; + void terminate() noexcept override; + size_t getSerializationSize() const noexcept override; + void serialize(void* buffer) const noexcept override; + void destroy() noexcept override; + void setPluginNamespace(const char* pluginNamespace) noexcept override; const char* getPluginNamespace() const noexcept override; // Host To Device static T* cudaMallocAndCopy(vector& weights, const nvinfer1::Weights& w) { - T* dpWeight; + T* dpWeight; size_t nValue = w.count; check_cuda_error(cudaMalloc(&dpWeight, nValue * sizeof(T))); check_cuda_error(cudaMemcpy(dpWeight, w.values, nValue * sizeof(T), cudaMemcpyHostToDevice)); @@ -198,13 +198,13 @@ class SwinTransformerINT8Plugin: public nvinfer1::IPluginV2DynamicExt { } // Device to Device - static T* cudaMallocAndCopy(vector& weights, - const T* dpWeightOld, + static T* cudaMallocAndCopy(vector& weights, + const T* dpWeightOld, const vector& weight_size, - const int weight_idx, - bool is_float = false) + const int weight_idx, + bool is_float = false) { - T* dpWeight; + T* dpWeight; size_t nValue = weight_size[weight_idx]; check_cuda_error(cudaMalloc((void**)&dpWeight, nValue * sizeof(T))); check_cuda_error(cudaMemcpy(dpWeight, dpWeightOld, nValue * sizeof(T), cudaMemcpyDeviceToDevice)); @@ -302,8 +302,8 @@ class SwinTransformerINT8PluginCreator: public nvinfer1::IPluginCreator { if (field_name.compare(name) == 0) { nvinfer1::Weights tmp; tmp.values = fc->fields[i].data; - tmp.count = fc->fields[i].length; - tmp.type = fieldTypeToDataType(fc->fields[i].type); + tmp.count = fc->fields[i].length; + tmp.type = fieldTypeToDataType(fc->fields[i].type); w.push_back(tmp); break; } @@ -311,9 +311,9 @@ class SwinTransformerINT8PluginCreator: public nvinfer1::IPluginCreator { } private: - static nvinfer1::PluginFieldCollection mFC; + static nvinfer1::PluginFieldCollection mFC; static std::vector mPluginAttributes; - std::string namespace_; + std::string namespace_; }; } // namespace fastertransformer diff --git a/src/fastertransformer/tensorrt_plugin/swin/swinTransformerPlugin.cpp b/src/fastertransformer/tensorrt_plugin/swin/swinTransformerPlugin.cpp index 3248c6b98..be13eb68d 100644 --- a/src/fastertransformer/tensorrt_plugin/swin/swinTransformerPlugin.cpp +++ b/src/fastertransformer/tensorrt_plugin/swin/swinTransformerPlugin.cpp @@ -32,27 +32,27 @@ using namespace std; namespace fastertransformer { // Static class fields initialization -PluginFieldCollection SwinTransformerPluginCreator::mFC{}; +PluginFieldCollection SwinTransformerPluginCreator::mFC{}; 
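The static `mFC` / `mPluginAttributes` members defined here back the creator that `REGISTER_TENSORRT_PLUGIN` registers just below. As a minimal sketch of how client code would typically fetch that creator from TensorRT's global plugin registry at network-build time — the plugin name and version strings are assumptions for illustration, not values taken from this patch (the real ones are whatever `getPluginName()` / `getPluginVersion()` return):

```cpp
#include "NvInfer.h"

// Sketch only: look up the creator that REGISTER_TENSORRT_PLUGIN placed in the registry
// and build a plugin instance from a caller-provided PluginFieldCollection.
nvinfer1::IPluginV2* createSwinPluginFromRegistry(const nvinfer1::PluginFieldCollection& fc)
{
    auto* creator = getPluginRegistry()->getPluginCreator("CustomSwinTransformerPlugin" /* assumed name */,
                                                          "1" /* assumed version */);
    if (creator == nullptr) {
        return nullptr;  // plugin library not loaded, or name/version mismatch
    }
    // The creator parses the field collection exactly as createPlugin() in this file does.
    return creator->createPlugin("swin_transformer", &fc);
}
```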
std::vector SwinTransformerPluginCreator::mPluginAttributes; REGISTER_TENSORRT_PLUGIN(SwinTransformerPluginCreator); template -SwinTransformerPlugin::SwinTransformerPlugin(const std::string& name, - const int max_batch_size, - const int img_size, - const int patch_size, - const int in_chans, - const int embed_dim, - const int window_size, - int* depths, - int* num_heads, - const bool ape, - const bool patch_norm, - const int layer_num, - const float mlp_ratio, - const bool qkv_bias, - const float qk_scale, +SwinTransformerPlugin::SwinTransformerPlugin(const std::string& name, + const int max_batch_size, + const int img_size, + const int patch_size, + const int in_chans, + const int embed_dim, + const int window_size, + int* depths, + int* num_heads, + const bool ape, + const bool patch_norm, + const int layer_num, + const float mlp_ratio, + const bool qkv_bias, + const float qk_scale, const std::vector& w): layer_name_(name), max_batch_size_(max_batch_size), @@ -81,7 +81,7 @@ SwinTransformerPlugin::SwinTransformerPlugin(const std::string& name, exit(-1); } - depths_ = (int*)malloc(layer_num * sizeof(int)); + depths_ = (int*)malloc(layer_num * sizeof(int)); num_heads_ = (int*)malloc(layer_num * sizeof(int)); memcpy(depths_, depths, layer_num * sizeof(int)); memcpy(num_heads_, num_heads, layer_num * sizeof(int)); @@ -154,9 +154,9 @@ SwinTransformerPlugin::SwinTransformerPlugin(const std::string& name, params_.norm_weights.beta = cudaMallocAndCopy(weights_, w[weight_idx], weight_size_, weight_idx); weight_idx++; - cublasAlgoMap_ = new cublasAlgoMap(GEMM_CONFIG, ""); + cublasAlgoMap_ = new cublasAlgoMap(GEMM_CONFIG, ""); cublasWrapperMutex_ = new std::mutex(); - allocator_ = new Allocator(getDevice()); + allocator_ = new Allocator(getDevice()); cublasMMWrapper* cublas_wrapper = new cublasMMWrapper(cublas_handle_, cublaslt_handle_, nullptr, cublasAlgoMap_, cublasWrapperMutex_, nullptr); @@ -189,21 +189,21 @@ SwinTransformerPlugin::SwinTransformerPlugin(const std::string& name, } template -SwinTransformerPlugin::SwinTransformerPlugin(const std::string& name, - const int max_batch_size, - const int img_size, - const int patch_size, - const int in_chans, - const int embed_dim, - const int window_size, - int* depths, - int* num_heads, - const bool ape, - const bool patch_norm, - const int layer_num, - const float mlp_ratio, - const bool qkv_bias, - const float qk_scale, +SwinTransformerPlugin::SwinTransformerPlugin(const std::string& name, + const int max_batch_size, + const int img_size, + const int patch_size, + const int in_chans, + const int embed_dim, + const int window_size, + int* depths, + int* num_heads, + const bool ape, + const bool patch_norm, + const int layer_num, + const float mlp_ratio, + const bool qkv_bias, + const float qk_scale, const std::vector& w): layer_name_(name), max_batch_size_(max_batch_size), @@ -224,7 +224,7 @@ SwinTransformerPlugin::SwinTransformerPlugin(const std::string& name, check_cuda_error(cublasLtCreate(&cublaslt_handle_)); checkCUDNN(cudnnCreate(&cudnn_handle_)); - sm_ = getSMVersion(); + sm_ = getSMVersion(); weight_num_ = getWeightNum(layer_num, depths); if (weight_num_ != w.size()) { printf("[ERROR][SwinTransformerPlugin] weights number %lu does not match expected number %d!\n", @@ -233,7 +233,7 @@ SwinTransformerPlugin::SwinTransformerPlugin(const std::string& name, exit(-1); } - depths_ = (int*)malloc(layer_num * sizeof(int)); + depths_ = (int*)malloc(layer_num * sizeof(int)); num_heads_ = (int*)malloc(layer_num * sizeof(int)); memcpy(depths_, depths, 
layer_num * sizeof(int)); memcpy(num_heads_, num_heads, layer_num * sizeof(int)); @@ -245,37 +245,37 @@ SwinTransformerPlugin::SwinTransformerPlugin(const std::string& name, SwinTransformerBasicLayerWeight bl; for (int di = 0; di < depths[l]; di++) { SwinTransformerBlockWeight p; - p.attention_weights.query_weight.kernel = cudaMallocAndCopy(weights_, w[weight_idx++]); - p.attention_weights.query_weight.bias = cudaMallocAndCopy(weights_, w[weight_idx++]); + p.attention_weights.query_weight.kernel = cudaMallocAndCopy(weights_, w[weight_idx++]); + p.attention_weights.query_weight.bias = cudaMallocAndCopy(weights_, w[weight_idx++]); p.attention_weights.attention_output_weight.kernel = cudaMallocAndCopy(weights_, w[weight_idx++]); - p.attention_weights.attention_output_weight.bias = cudaMallocAndCopy(weights_, w[weight_idx++]); - p.ffn_weights.intermediate_weight.kernel = cudaMallocAndCopy(weights_, w[weight_idx++]); - p.ffn_weights.intermediate_weight.bias = cudaMallocAndCopy(weights_, w[weight_idx++]); - p.ffn_weights.output_weight.kernel = cudaMallocAndCopy(weights_, w[weight_idx++]); - p.ffn_weights.output_weight.bias = cudaMallocAndCopy(weights_, w[weight_idx++]); - p.attn_layernorm_weights.gamma = cudaMallocAndCopy(weights_, w[weight_idx++]); - p.attn_layernorm_weights.beta = cudaMallocAndCopy(weights_, w[weight_idx++]); - p.ffn_layernorm_weights.gamma = cudaMallocAndCopy(weights_, w[weight_idx++]); - p.ffn_layernorm_weights.beta = cudaMallocAndCopy(weights_, w[weight_idx++]); - p.attention_relative_pos_bias = cudaMallocAndCopy(weights_, w[weight_idx++]); + p.attention_weights.attention_output_weight.bias = cudaMallocAndCopy(weights_, w[weight_idx++]); + p.ffn_weights.intermediate_weight.kernel = cudaMallocAndCopy(weights_, w[weight_idx++]); + p.ffn_weights.intermediate_weight.bias = cudaMallocAndCopy(weights_, w[weight_idx++]); + p.ffn_weights.output_weight.kernel = cudaMallocAndCopy(weights_, w[weight_idx++]); + p.ffn_weights.output_weight.bias = cudaMallocAndCopy(weights_, w[weight_idx++]); + p.attn_layernorm_weights.gamma = cudaMallocAndCopy(weights_, w[weight_idx++]); + p.attn_layernorm_weights.beta = cudaMallocAndCopy(weights_, w[weight_idx++]); + p.ffn_layernorm_weights.gamma = cudaMallocAndCopy(weights_, w[weight_idx++]); + p.ffn_layernorm_weights.beta = cudaMallocAndCopy(weights_, w[weight_idx++]); + p.attention_relative_pos_bias = cudaMallocAndCopy(weights_, w[weight_idx++]); bl.block_weight_list.push_back(p); } bl.merge_layernorm_weights.gamma = cudaMallocAndCopy(weights_, w[weight_idx++]); - bl.merge_layernorm_weights.beta = cudaMallocAndCopy(weights_, w[weight_idx++]); - bl.merge_linear_weights.kernel = cudaMallocAndCopy(weights_, w[weight_idx++]); - bl.attn_mask = cudaMallocAndCopy(weights_, w[weight_idx++]); + bl.merge_layernorm_weights.beta = cudaMallocAndCopy(weights_, w[weight_idx++]); + bl.merge_linear_weights.kernel = cudaMallocAndCopy(weights_, w[weight_idx++]); + bl.attn_mask = cudaMallocAndCopy(weights_, w[weight_idx++]); params_.basic_layer_weight_list.push_back(bl); } params_.patchEmbed_linear_weights.kernel = cudaMallocAndCopy(weights_, w[weight_idx++]); - params_.patchEmbed_linear_weights.bias = cudaMallocAndCopy(weights_, w[weight_idx++]); - params_.patchEmbed_norm_weights.gamma = cudaMallocAndCopy(weights_, w[weight_idx++]); - params_.patchEmbed_norm_weights.beta = cudaMallocAndCopy(weights_, w[weight_idx++]); - params_.norm_weights.gamma = cudaMallocAndCopy(weights_, w[weight_idx++]); - params_.norm_weights.beta = cudaMallocAndCopy(weights_, w[weight_idx++]); 
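The FP16/FP32 constructor above consumes the flat weight vector `w` strictly in the order shown, so its expected length follows directly from the loop structure: 13 tensors per transformer block, 4 per basic layer, plus 6 trailing tensors for the patch embedding and final norm. A minimal sketch of that count, which presumably mirrors what `getWeightNum()` computes (the helper name below is illustrative, not part of the patch):

```cpp
#include <cstddef>

// Expected size of the flat weight vector consumed by the Swin plugin constructor above.
static size_t expectedSwinWeightCount(int layer_num, const int* depths)
{
    size_t count = 0;
    for (int l = 0; l < layer_num; l++) {
        // Per block: query weight kernel/bias, attention-output kernel/bias, FFN intermediate
        // and output kernel/bias, attn/FFN layernorm gamma/beta, relative position bias.
        count += 13 * static_cast<size_t>(depths[l]);
        // Per basic layer: merge layernorm gamma/beta, merge linear kernel, attention mask.
        count += 4;
    }
    // Patch-embedding linear kernel/bias, patch-embedding norm gamma/beta, final norm gamma/beta.
    return count + 6;
}
```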
+ params_.patchEmbed_linear_weights.bias = cudaMallocAndCopy(weights_, w[weight_idx++]); + params_.patchEmbed_norm_weights.gamma = cudaMallocAndCopy(weights_, w[weight_idx++]); + params_.patchEmbed_norm_weights.beta = cudaMallocAndCopy(weights_, w[weight_idx++]); + params_.norm_weights.gamma = cudaMallocAndCopy(weights_, w[weight_idx++]); + params_.norm_weights.beta = cudaMallocAndCopy(weights_, w[weight_idx++]); - cublasAlgoMap_ = new cublasAlgoMap("igemm.config", ""); + cublasAlgoMap_ = new cublasAlgoMap("igemm.config", ""); cublasWrapperMutex_ = new std::mutex(); - allocator_ = new Allocator(getDevice()); + allocator_ = new Allocator(getDevice()); cublasMMWrapper* cublas_wrapper = new cublasMMWrapper(cublas_handle_, cublaslt_handle_, nullptr, cublasAlgoMap_, cublasWrapperMutex_, nullptr); @@ -336,8 +336,8 @@ SwinTransformerPlugin::SwinTransformerPlugin(const std::string& name, const v weight_size_.push_back(tmp); } - depths_ = (int*)malloc(layer_num_ * sizeof(int)); - num_heads_ = (int*)malloc(layer_num_ * sizeof(int)); + depths_ = (int*)malloc(layer_num_ * sizeof(int)); + num_heads_ = (int*)malloc(layer_num_ * sizeof(int)); const char* d = static_cast(data); memcpy(depths_, d, layer_num_ * sizeof(int)); d = d + layer_num_ * sizeof(int); @@ -356,37 +356,37 @@ SwinTransformerPlugin::SwinTransformerPlugin(const std::string& name, const v SwinTransformerBasicLayerWeight bl; for (int di = 0; di < depths_[l]; di++) { SwinTransformerBlockWeight p; - p.attention_weights.query_weight.kernel = weights_[weight_idx++]; - p.attention_weights.query_weight.bias = weights_[weight_idx++]; + p.attention_weights.query_weight.kernel = weights_[weight_idx++]; + p.attention_weights.query_weight.bias = weights_[weight_idx++]; p.attention_weights.attention_output_weight.kernel = weights_[weight_idx++]; - p.attention_weights.attention_output_weight.bias = weights_[weight_idx++]; - p.ffn_weights.intermediate_weight.kernel = weights_[weight_idx++]; - p.ffn_weights.intermediate_weight.bias = weights_[weight_idx++]; - p.ffn_weights.output_weight.kernel = weights_[weight_idx++]; - p.ffn_weights.output_weight.bias = weights_[weight_idx++]; - p.attn_layernorm_weights.gamma = weights_[weight_idx++]; - p.attn_layernorm_weights.beta = weights_[weight_idx++]; - p.ffn_layernorm_weights.gamma = weights_[weight_idx++]; - p.ffn_layernorm_weights.beta = weights_[weight_idx++]; - p.attention_relative_pos_bias = weights_[weight_idx++]; + p.attention_weights.attention_output_weight.bias = weights_[weight_idx++]; + p.ffn_weights.intermediate_weight.kernel = weights_[weight_idx++]; + p.ffn_weights.intermediate_weight.bias = weights_[weight_idx++]; + p.ffn_weights.output_weight.kernel = weights_[weight_idx++]; + p.ffn_weights.output_weight.bias = weights_[weight_idx++]; + p.attn_layernorm_weights.gamma = weights_[weight_idx++]; + p.attn_layernorm_weights.beta = weights_[weight_idx++]; + p.ffn_layernorm_weights.gamma = weights_[weight_idx++]; + p.ffn_layernorm_weights.beta = weights_[weight_idx++]; + p.attention_relative_pos_bias = weights_[weight_idx++]; bl.block_weight_list.push_back(p); } bl.merge_layernorm_weights.gamma = weights_[weight_idx++]; - bl.merge_layernorm_weights.beta = weights_[weight_idx++]; - bl.merge_linear_weights.kernel = weights_[weight_idx++]; - bl.attn_mask = weights_[weight_idx++]; + bl.merge_layernorm_weights.beta = weights_[weight_idx++]; + bl.merge_linear_weights.kernel = weights_[weight_idx++]; + bl.attn_mask = weights_[weight_idx++]; params_.basic_layer_weight_list.push_back(bl); } 
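`output_dim_` — computed as `int(pow(2, layer_num - 1)) * embed_dim` in the INT8 plugin constructor earlier in this patch, and presumably the same here — is what `getOutputDimensions()` below reports as the fixed `B x output_dim_ x 1 x 1` output shape. A quick worked check with hypothetical Swin-T-style values (embed_dim = 96, layer_num = 4; example numbers, not taken from this patch):

```cpp
// Sketch only: verifies the output-channel arithmetic used by getOutputDimensions().
constexpr int embed_dim  = 96;  // assumed example value
constexpr int layer_num  = 4;   // assumed example value
constexpr int output_dim = (1 << (layer_num - 1)) * embed_dim;
static_assert(output_dim == 768, "plugin would report a B x 768 x 1 x 1 output for this config");
```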
params_.patchEmbed_linear_weights.kernel = weights_[weight_idx++]; - params_.patchEmbed_linear_weights.bias = weights_[weight_idx++]; - params_.patchEmbed_norm_weights.gamma = weights_[weight_idx++]; - params_.patchEmbed_norm_weights.beta = weights_[weight_idx++]; - params_.norm_weights.gamma = weights_[weight_idx++]; - params_.norm_weights.beta = weights_[weight_idx++]; + params_.patchEmbed_linear_weights.bias = weights_[weight_idx++]; + params_.patchEmbed_norm_weights.gamma = weights_[weight_idx++]; + params_.patchEmbed_norm_weights.beta = weights_[weight_idx++]; + params_.norm_weights.gamma = weights_[weight_idx++]; + params_.norm_weights.beta = weights_[weight_idx++]; - cublasAlgoMap_ = new cublasAlgoMap("igemm.config", ""); + cublasAlgoMap_ = new cublasAlgoMap("igemm.config", ""); cublasWrapperMutex_ = new std::mutex(); - allocator_ = new Allocator(getDevice()); + allocator_ = new Allocator(getDevice()); cublasMMWrapper* cublas_wrapper = new cublasMMWrapper(cublas_handle_, cublaslt_handle_, nullptr, cublasAlgoMap_, cublasWrapperMutex_, nullptr); @@ -462,41 +462,51 @@ nvinfer1::IPluginV2DynamicExt* SwinTransformerPlugin::clone() const noexcept } template -DimsExprs SwinTransformerPlugin::getOutputDimensions(int outputIndex, +DimsExprs SwinTransformerPlugin::getOutputDimensions(int outputIndex, const DimsExprs* inputs, - int nbInputs, - IExprBuilder& exprBuilder) noexcept + int nbInputs, + IExprBuilder& exprBuilder) noexcept { // Input is B*in_chans*H*W, output should be B*dim*1*1 for fc layer assert(outputIndex == 0); // Copy over everything DimsExprs output; output.nbDims = 4; - output.d[0] = inputs[0].d[0]; - output.d[1] = exprBuilder.constant(output_dim_); - output.d[2] = exprBuilder.constant(1); - output.d[3] = exprBuilder.constant(1); + output.d[0] = inputs[0].d[0]; + output.d[1] = exprBuilder.constant(output_dim_); + output.d[2] = exprBuilder.constant(1); + output.d[3] = exprBuilder.constant(1); return output; } template -bool SwinTransformerPlugin::supportsFormatCombination(int pos, +bool SwinTransformerPlugin::supportsFormatCombination(int pos, const PluginTensorDesc* inOut, - int nbInputs, - int nbOutputs) noexcept + int nbInputs, + int nbOutputs) noexcept { - assert(pos >= 0); - assert(pos < 2); + bool res = false; + assert(pos >= 0 && pos < 2); assert(nbInputs == 1); + switch (pos) { + case 0: // input + case 1: // output + res = (inOut[pos].type + == (std::is_same::value ? 
nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT)) + && (inOut[pos].format == nvinfer1::TensorFormat::kLINEAR); + break; + default: + break; + } - return true; + return res; } template void SwinTransformerPlugin::configurePlugin(const DynamicPluginTensorDesc* in, - int nbInputs, + int nbInputs, const DynamicPluginTensorDesc* out, - int nbOutputs) noexcept + int nbOutputs) noexcept { assert(nbInputs == 1); assert(nbOutputs == 1); @@ -504,18 +514,18 @@ void SwinTransformerPlugin::configurePlugin(const DynamicPluginTensorDesc* in template size_t SwinTransformerPlugin::getWorkspaceSize(const PluginTensorDesc* inputs, - int nbInputs, + int nbInputs, const PluginTensorDesc* outputs, - int nbOutputs) const noexcept + int nbOutputs) const noexcept { return 0; } // IPluginV2Ext Methods template -nvinfer1::DataType SwinTransformerPlugin::getOutputDataType(int index, +nvinfer1::DataType SwinTransformerPlugin::getOutputDataType(int index, const nvinfer1::DataType* inputTypes, - int nbInputs) const noexcept + int nbInputs) const noexcept { assert(index == 0); assert(inputTypes[0] == nvinfer1::DataType::kFLOAT || inputTypes[0] == nvinfer1::DataType::kHALF); @@ -624,10 +634,10 @@ const char* SwinTransformerPlugin::getPluginNamespace() const noexcept template int SwinTransformerPlugin::enqueue(const PluginTensorDesc* inputDesc, const PluginTensorDesc* outputDesc, - const void* const* inputs, - void* const* outputs, - void* workspace, - cudaStream_t stream) noexcept + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) noexcept { int batch_size = inputDesc->dims.d[0]; assert(batch_size <= max_batch_size_); @@ -635,7 +645,7 @@ int SwinTransformerPlugin::enqueue(const PluginTensorDesc* inputDesc, assert(img_size_ == inputDesc->dims.d[2]); assert(img_size_ == inputDesc->dims.d[3]); - int sm_ptr[1] = {sm_}; + int sm_ptr[1] = {sm_}; std::vector input_tensors = std::vector{ Tensor{MEMORY_GPU, getTensorType(), @@ -656,7 +666,7 @@ int SwinTransformerPlugin::enqueue(const PluginTensorDesc* inputDesc, SwinTransformerPluginCreator::SwinTransformerPluginCreator() { mFC.nbFields = mPluginAttributes.size(); - mFC.fields = mPluginAttributes.data(); + mFC.fields = mPluginAttributes.data(); } const char* SwinTransformerPluginCreator::getPluginName() const noexcept @@ -676,20 +686,20 @@ const PluginFieldCollection* SwinTransformerPluginCreator::getFieldNames() noexc IPluginV2* SwinTransformerPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) noexcept { - int max_batch_size; - int img_size; - int patch_size; - int in_chans; - int embed_dim; - int window_size; - int* depths = nullptr; - int* num_heads = nullptr; - bool ape; - bool patch_norm; - int layer_num; - float mlp_ratio; - bool qkv_bias; - float qk_scale; + int max_batch_size; + int img_size; + int patch_size; + int in_chans; + int embed_dim; + int window_size; + int* depths = nullptr; + int* num_heads = nullptr; + bool ape; + bool patch_norm; + int layer_num; + float mlp_ratio; + bool qkv_bias; + float qk_scale; std::vector w; for (int i = 0; i < fc->nbFields; i++) { @@ -857,7 +867,7 @@ IPluginV2* SwinTransformerPluginCreator::createPlugin(const char* name, const Pl return p; } else { - printf("[ERROR][SwinTransformerPluginCreator::createPlugin] unsupport datatype.\n"); + printf("[ERROR][SwinTransformerPluginCreator::createPlugin] unsupported datatype.\n"); exit(-1); } } @@ -865,7 +875,7 @@ IPluginV2* SwinTransformerPluginCreator::createPlugin(const char* name, const Pl IPluginV2* 
SwinTransformerPluginCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength) noexcept { - int type_id; + int type_id; size_t int_length = sizeof(int); deserialize_value(&serialData, &int_length, &type_id); // This object will be deleted when the network is destroyed, which will @@ -875,7 +885,7 @@ SwinTransformerPluginCreator::deserializePlugin(const char* name, const void* se else if (type_id == 1) return new SwinTransformerPlugin(name, serialData, serialLength); else { - printf("[ERROR][SwinTransformerPluginCreator::deserializePlugin] unsupport data type %d\n", type_id); + printf("[ERROR][SwinTransformerPluginCreator::deserializePlugin] unsupported data type %d\n", type_id); exit(-1); } } diff --git a/src/fastertransformer/tensorrt_plugin/swin/swinTransformerPlugin.h b/src/fastertransformer/tensorrt_plugin/swin/swinTransformerPlugin.h index 8f6fb2bd1..7727c4a5c 100644 --- a/src/fastertransformer/tensorrt_plugin/swin/swinTransformerPlugin.h +++ b/src/fastertransformer/tensorrt_plugin/swin/swinTransformerPlugin.h @@ -62,69 +62,69 @@ template class SwinTransformerPlugin: public nvinfer1::IPluginV2DynamicExt { private: const std::string layer_name_; - std::string namespace_; + std::string namespace_; - std::vector weights_; // in device memory - std::vector weight_size_; - cublasHandle_t cublas_handle_ = nullptr; - cublasLtHandle_t cublaslt_handle_ = nullptr; - cudnnHandle_t cudnn_handle_ = nullptr; - SwinTransformerWeight params_; - SwinTransformer* swin_transformer_ = nullptr; - std::mutex* cublasWrapperMutex_ = nullptr; - cublasAlgoMap* cublasAlgoMap_ = nullptr; - fastertransformer::Allocator* allocator_ = nullptr; - int output_dim_; - int weight_num_; - int max_batch_size_; - int img_size_; - int patch_size_; - int in_chans_; - int embed_dim_; - int window_size_; - bool ape_; - int patch_norm_; - int layer_num_; - float mlp_ratio_; - bool qkv_bias_; - float qk_scale_; - int* depths_; - int* num_heads_; + std::vector weights_; // in device memory + std::vector weight_size_; + cublasHandle_t cublas_handle_ = nullptr; + cublasLtHandle_t cublaslt_handle_ = nullptr; + cudnnHandle_t cudnn_handle_ = nullptr; + SwinTransformerWeight params_; + SwinTransformer* swin_transformer_ = nullptr; + std::mutex* cublasWrapperMutex_ = nullptr; + cublasAlgoMap* cublasAlgoMap_ = nullptr; + fastertransformer::Allocator* allocator_ = nullptr; + int output_dim_; + int weight_num_; + int max_batch_size_; + int img_size_; + int patch_size_; + int in_chans_; + int embed_dim_; + int window_size_; + bool ape_; + int patch_norm_; + int layer_num_; + float mlp_ratio_; + bool qkv_bias_; + float qk_scale_; + int* depths_; + int* num_heads_; public: int sm_; - SwinTransformerPlugin(const std::string& name, - const int max_batch, - const int img_size, - const int patch_size, - const int in_chans, - const int embed_dim, - const int window_size, - int* depths, - int* num_heads, - const bool ape, - const bool patch_norm, - const int layer_num, - const float mlp_ratio, - const bool qkv_bias, - const float qk_scale, + SwinTransformerPlugin(const std::string& name, + const int max_batch, + const int img_size, + const int patch_size, + const int in_chans, + const int embed_dim, + const int window_size, + int* depths, + int* num_heads, + const bool ape, + const bool patch_norm, + const int layer_num, + const float mlp_ratio, + const bool qkv_bias, + const float qk_scale, const std::vector& w); - SwinTransformerPlugin(const std::string& name, - const int max_batch, - const int img_size, - const 
int patch_size, - const int in_chans, - const int embed_dim, - const int window_size, - int* depths, - int* num_heads, - const bool ape, - const bool patch_norm, - const int layer_num, - const float mlp_ratio, - const bool qkv_bias, - const float qk_scale, + SwinTransformerPlugin(const std::string& name, + const int max_batch, + const int img_size, + const int patch_size, + const int in_chans, + const int embed_dim, + const int window_size, + int* depths, + int* num_heads, + const bool ape, + const bool patch_norm, + const int layer_num, + const float mlp_ratio, + const bool qkv_bias, + const float qk_scale, const std::vector& w); SwinTransformerPlugin(const std::string& name, const void* data, size_t length); @@ -137,28 +137,28 @@ class SwinTransformerPlugin: public nvinfer1::IPluginV2DynamicExt { // IPluginV2DynamicExt Methods nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, - const nvinfer1::DimsExprs* inputs, - int nbInputs, - nvinfer1::IExprBuilder& exprBuilder) noexcept override; - bool supportsFormatCombination(int pos, - const nvinfer1::PluginTensorDesc* inOut, - int nbInputs, - int nbOutputs) noexcept override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, - int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, - int nbOutputs) noexcept override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, - int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, - int nbOutputs) const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, - const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, - void* const* outputs, - void* workspace, - cudaStream_t stream) noexcept override; + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) noexcept override; + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* inOut, + int nbInputs, + int nbOutputs) noexcept override; + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nbOutputs) noexcept override; + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const noexcept override; + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) noexcept override; // IPluginV2Ext Methods nvinfer1::DataType @@ -167,19 +167,19 @@ class SwinTransformerPlugin: public nvinfer1::IPluginV2DynamicExt { // IPluginV2 Methods const char* getPluginType() const noexcept override; const char* getPluginVersion() const noexcept override; - int getNbOutputs() const noexcept override; - int initialize() noexcept override; - void terminate() noexcept override; - size_t getSerializationSize() const noexcept override; - void serialize(void* buffer) const noexcept override; - void destroy() noexcept override; - void setPluginNamespace(const char* pluginNamespace) noexcept override; + int getNbOutputs() const noexcept override; + int initialize() noexcept override; + void terminate() noexcept override; + size_t getSerializationSize() const noexcept override; + void serialize(void* buffer) const noexcept override; + void destroy() noexcept override; + void setPluginNamespace(const 
char* pluginNamespace) noexcept override; const char* getPluginNamespace() const noexcept override; // Host To Device static T* cudaMallocAndCopy(vector& weights, const nvinfer1::Weights& w) { - T* dpWeight; + T* dpWeight; size_t nValue = w.count; check_cuda_error(cudaMalloc(&dpWeight, nValue * sizeof(T))); check_cuda_error(cudaMemcpy(dpWeight, w.values, nValue * sizeof(T), cudaMemcpyHostToDevice)); @@ -188,12 +188,12 @@ class SwinTransformerPlugin: public nvinfer1::IPluginV2DynamicExt { } // Device to Device - static T* cudaMallocAndCopy(vector& weights, - const T* dpWeightOld, + static T* cudaMallocAndCopy(vector& weights, + const T* dpWeightOld, const vector& weight_size, - const int weight_idx) + const int weight_idx) { - T* dpWeight; + T* dpWeight; size_t nValue = weight_size[weight_idx]; check_cuda_error(cudaMalloc((void**)&dpWeight, nValue * sizeof(T))); check_cuda_error(cudaMemcpy(dpWeight, dpWeightOld, nValue * sizeof(T), cudaMemcpyDeviceToDevice)); @@ -239,8 +239,8 @@ class SwinTransformerPluginCreator: public nvinfer1::IPluginCreator { if (field_name.compare(name) == 0) { nvinfer1::Weights tmp; tmp.values = fc->fields[i].data; - tmp.count = fc->fields[i].length; - tmp.type = fieldTypeToDataType(fc->fields[i].type); + tmp.count = fc->fields[i].length; + tmp.type = fieldTypeToDataType(fc->fields[i].type); w.push_back(tmp); break; } @@ -248,9 +248,9 @@ class SwinTransformerPluginCreator: public nvinfer1::IPluginCreator { } private: - static nvinfer1::PluginFieldCollection mFC; + static nvinfer1::PluginFieldCollection mFC; static std::vector mPluginAttributes; - std::string namespace_; + std::string namespace_; }; } // namespace fastertransformer diff --git a/src/fastertransformer/tensorrt_plugin/t5/README.md b/src/fastertransformer/tensorrt_plugin/t5/README.md index ffccb16b9..0f82f0e83 100644 --- a/src/fastertransformer/tensorrt_plugin/t5/README.md +++ b/src/fastertransformer/tensorrt_plugin/t5/README.md @@ -1,8 +1,7 @@ # T5Plugin in FasterTransformer -+ Original Faster Transformer: [link](https://github.com/NVIDIA/FasterTransformer) -+ This project aims to wapper the encoder and decoding parts of the faster transformer as TensoRT plugins respectively. The most of teh code file are from orignial Faster TRansformer project. ++ This project wraps the encoder and decoding parts of FasterTransformer as separate TensorRT plugins. Most of the code files come from the original FasterTransformer project. -## Envionment +## Environment + **nvcr.io/nvidia/pytorch:21.02-py3** (including CUDA 11.2.0, cudnn 8.1.0.77, PyTorch 1.8.0a0+52ea372, TensorRT 7.2.2.3+cuda11.1.0.024) + The code in this repository is currently compatible with TensorRT 7; it may need several edits before it can be used with TensorRT 8. diff --git a/src/fastertransformer/tensorrt_plugin/t5/T5Plugin.cu b/src/fastertransformer/tensorrt_plugin/t5/T5Plugin.cu index f7e873bb6..16b8f34a3 100644 --- a/src/fastertransformer/tensorrt_plugin/t5/T5Plugin.cu +++ b/src/fastertransformer/tensorrt_plugin/t5/T5Plugin.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,6 +14,7 @@ * limitations under the License.
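The T5EncoderPlugin constructor changes below replace the long hyperparameter argument list with a read from `<ckpt_path>/config.ini` via INIReader. A hedged sketch of the file this implies, using only the section and key names visible in the `reader.GetBoolean` / `reader.GetInteger` calls below; the numeric values are placeholders (roughly T5-small-like) and are not taken from this patch:

```ini
; Illustrative config.ini layout inferred from the reader.Get* calls in T5EncoderPlugin.
[structure]
t5_with_bias = false

[encoder]
num_heads = 8
d_kv = 64
d_ff = 2048
d_model = 512
num_layers = 6
relative_attention_num_buckets_or_max_pos_seq_len = 32
relative_attention_max_distance = 128
```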
*/ +#include "3rdparty/INIReader.h" #include "T5Plugin.h" using namespace fastertransformer; @@ -22,81 +23,86 @@ namespace nvinfer1 { // class T5EncoderPlugin --------------------------------------------------------------------------- T5EncoderPlugin::T5EncoderPlugin(const std::string& name, - size_t max_batch_size, - size_t max_seq_len, - size_t beam_width, - size_t head_num, - size_t size_per_head, - size_t inter_size, - size_t d_model, - size_t num_layer, - size_t num_bucket, - size_t max_distance, - int sm, - float q_scaling, - int useFP16): + size_t max_batch_size, + size_t max_seq_len, + size_t beam_width, + int sm, + int useFP16, + const std::string& ckpt_path, + bool own_weight): name_(name) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); WHERE_AM_I(); + + INIReader reader = INIReader(ckpt_path + "/config.ini"); + if (reader.ParseError() < 0) { + FT_LOG_ERROR("Can't load %s/config.ini", ckpt_path.c_str()); + FT_CHECK(false); + } + + bool t5_with_bias = reader.GetBoolean("structure", "t5_with_bias", false); m_.max_batch_size = max_batch_size; - m_.max_seq_len = max_seq_len; - m_.beam_width = beam_width; - m_.head_num = head_num; - m_.size_per_head = size_per_head; - m_.inter_size = inter_size; - m_.d_model = d_model; - m_.num_layer = num_layer; - m_.num_bucket = num_bucket; - m_.max_distance = max_distance; - m_.sm = sm; - m_.q_scaling = q_scaling; - m_.useFP16 = (bool)useFP16; - m_.batch_size = m_.max_batch_size; - m_.seq_len = m_.max_seq_len; + m_.max_seq_len = max_seq_len; + m_.beam_width = beam_width; + m_.head_num = reader.GetInteger("encoder", "num_heads"); + m_.size_per_head = reader.GetInteger("encoder", "d_kv"); + m_.inter_size = reader.GetInteger("encoder", "d_ff"); + m_.d_model = reader.GetInteger("encoder", "d_model"); + m_.num_layer = reader.GetInteger("encoder", "num_layers"); + m_.num_bucket = reader.GetInteger("encoder", "relative_attention_num_buckets_or_max_pos_seq_len"); + m_.max_distance = reader.GetInteger("encoder", "relative_attention_max_distance"); + m_.sm = sm; + m_.q_scaling = t5_with_bias ? 
1.0f : (1.0f / (sqrt(m_.size_per_head) * 1.0f)); + m_.useFP16 = (bool)useFP16; + m_.batch_size = m_.max_batch_size; + m_.seq_len = m_.max_seq_len; + strcpy(m_.ckpt_path, ckpt_path.c_str()); cublasCreate(&cublasHandle_); cublasLtCreate(&cublasltHandle_); #ifdef SPARSITY_ENABLED - cusparseLtInit(&cusparseltHandle_)); + cusparseLtInit(&cusparseltHandle_); #endif - // T5 EncoderWeight - std::string paraFilePath = "./para"; - if (m_.useFP16) { - m_.attention_type = AttentionType::UNFUSED_MHA; // when use FP16, only this type works till v5.0-dev - pT5EncoderWeightHalf_ = new T5EncoderWeight(m_.head_num, - m_.size_per_head, - m_.d_model, - m_.inter_size, - m_.vocab_size, - m_.num_layer, - m_.num_bucket, - 1, // tensor_para_size - 0, // tensor_para_rank - 1, // pipeline_para_size - 0 // pipeline_para_rank - ); - pT5EncoderWeightHalf_->loadModel(paraFilePath); - } - else { - m_.attention_type = - getAttentionType(m_.size_per_head, getSMVersion(), m_.is_remove_padding, m_.max_seq_len); - pT5EncoderWeightFloat_ = new T5EncoderWeight(m_.head_num, - m_.size_per_head, - m_.d_model, - m_.inter_size, - m_.vocab_size, - m_.num_layer, - m_.num_bucket, - 1, // tensor_para_size - 0, // tensor_para_rank - 1, // pipeline_para_size - 0 // pipeline_para_rank - ); - pT5EncoderWeightFloat_->loadModel(paraFilePath); + is_own_weight = own_weight; + if (is_own_weight) { + // T5 EncoderWeight + if (m_.useFP16) { + m_.attention_type = + getAttentionType(m_.size_per_head, getSMVersion(), m_.is_remove_padding, m_.max_seq_len, false); + pT5EncoderWeightHalf_ = new T5EncoderWeight(m_.head_num, + m_.size_per_head, + m_.d_model, + m_.inter_size, + m_.vocab_size, + m_.num_layer, + m_.num_bucket, + 1, // tensor_para_size + 0, // tensor_para_rank + 1, // pipeline_para_size + 0 // pipeline_para_rank + ); + pT5EncoderWeightHalf_->loadModel(std::string(m_.ckpt_path)); + } + else { + m_.attention_type = + getAttentionType(m_.size_per_head, getSMVersion(), m_.is_remove_padding, m_.max_seq_len); + pT5EncoderWeightFloat_ = new T5EncoderWeight(m_.head_num, + m_.size_per_head, + m_.d_model, + m_.inter_size, + m_.vocab_size, + m_.num_layer, + m_.num_bucket, + 1, // tensor_para_size + 0, // tensor_para_rank + 1, // pipeline_para_size + 0 // pipeline_para_rank + ); + pT5EncoderWeightFloat_->loadModel(std::string(m_.ckpt_path)); + } } - // Gemm file selection std::string gemmFileName = std::string(GEMM_CONFIG).substr(0, 11) + std::string("-SM") + std::to_string(m_.sm) + std::string("-FP") + std::to_string(m_.useFP16 ? 16 : 32) + std::string("-BS") @@ -104,39 +110,39 @@ T5EncoderPlugin::T5EncoderPlugin(const std::string& name, + std::string("-BM") + std::to_string(m_.beam_width) + std::string(".in"); std::ifstream infile(gemmFileName); if (infile.good()) { -#if DEBUG_ENABLE == 1 +#ifdef T5_PLUGIN_DEBUG printf("Gemm file exist!\n"); #endif } else { -#if DEBUG_ENABLE == 1 +#ifdef T5_PLUGIN_DEBUG printf("Gemm file do not exist!\n"); #endif int argv[16] = { 0, - m_.max_batch_size, - m_.beam_width, // useless for encoder - (m_.batch_size == 128 && m_.seq_len == 384) ? 128 : m_.seq_len, // seq_len, in case of OOM - m_.d_model, - m_.head_num, - m_.size_per_head, - m_.inter_size, - m_.d_model, - m_.head_num, - m_.size_per_head, - m_.inter_size, - m_.vocab_size, - m_.useFP16, // is_fp16 - 1, // tensor_para_size - m_.useFP16 // is_fp16_compute_type + (int)m_.max_batch_size, + (int)m_.beam_width, // useless for encoder + (m_.batch_size == 128 && m_.seq_len == 384) ? 
128 : (int)m_.seq_len, // seq_len, in case of OOM + (int)m_.d_model, + (int)m_.head_num, + (int)m_.size_per_head, + (int)m_.inter_size, + (int)m_.d_model, + (int)m_.head_num, + (int)m_.size_per_head, + (int)m_.inter_size, + (int)m_.vocab_size, + m_.useFP16 ? 1 : 0, // is_fp16 + 1, // tensor_para_size + false // always use fp32 compute type }; t5_gemm(argv); rename(std::string(GEMM_CONFIG).c_str(), gemmFileName.c_str()); } - pCublasAlgoMap_ = new cublasAlgoMap(gemmFileName, ""); + pCublasAlgoMap_ = new cublasAlgoMap(gemmFileName, ""); pCublasWrapperMutex_ = new std::mutex(); - pAllocator_ = new Allocator(getDevice()); + pAllocator_ = new Allocator(getDevice()); // cublas wrapper and T5Encoder #ifdef SPARSITY_ENABLED @@ -170,8 +176,8 @@ T5EncoderPlugin::T5EncoderPlugin(const std::string& name, m_.is_sparse, m_.activation_type, m_.layernorm_type, - {0, 1, nullptr}, // tensor_para - {0, 1, nullptr} // pipeline_para + NcclParam(0, 1), // tensor_para + NcclParam(0, 1) // pipeline_para ); } else { @@ -196,8 +202,8 @@ T5EncoderPlugin::T5EncoderPlugin(const std::string& name, m_.is_sparse, m_.activation_type, m_.layernorm_type, - {0, 1, nullptr}, // tensor_para - {0, 1, nullptr} // pipeline_para + NcclParam(0, 1), // tensor_para + NcclParam(0, 1) // pipeline_para ); } PRINT_ENCODER(m_.useFP16) @@ -212,85 +218,88 @@ T5EncoderPlugin::T5EncoderPlugin(const std::string& name, const void* buffer, si cublasCreate(&cublasHandle_); cublasLtCreate(&cublasltHandle_); #ifdef SPARSITY_ENABLED - cusparseLtInit(&cusparseltHandle_)); + cusparseLtInit(&cusparseltHandle_); #endif - // T5 EncoderWeight - std::string paraFilePath = "./para"; - if (m_.useFP16) { - m_.attention_type = AttentionType::UNFUSED_MHA; // when use FP16, only this type works till v5.0-dev - pT5EncoderWeightHalf_ = new T5EncoderWeight(m_.head_num, - m_.size_per_head, - m_.d_model, - m_.inter_size, - m_.vocab_size, - m_.num_layer, - m_.num_bucket, - 1, // tensor_para_size - 0, // tensor_para_rank - 1, // pipeline_para_size - 0 // pipeline_para_rank - ); - pT5EncoderWeightHalf_->loadModel(paraFilePath); - } - else { - m_.attention_type = - getAttentionType(m_.size_per_head, getSMVersion(), m_.is_remove_padding, m_.max_seq_len); - pT5EncoderWeightFloat_ = new T5EncoderWeight(m_.head_num, - m_.size_per_head, - m_.d_model, - m_.inter_size, - m_.vocab_size, - m_.num_layer, - m_.num_bucket, - 1, // tensor_para_size - 0, // tensor_para_rank - 1, // pipeline_para_size - 0 // pipeline_para_rank - ); - pT5EncoderWeightFloat_->loadModel(paraFilePath); + is_own_weight = true; + if (is_own_weight) { + // T5 EncoderWeight + if (m_.useFP16) { + m_.attention_type = + getAttentionType(m_.size_per_head, getSMVersion(), m_.is_remove_padding, m_.max_seq_len, false); + pT5EncoderWeightHalf_ = new T5EncoderWeight(m_.head_num, + m_.size_per_head, + m_.d_model, + m_.inter_size, + m_.vocab_size, + m_.num_layer, + m_.num_bucket, + 1, // tensor_para_size + 0, // tensor_para_rank + 1, // pipeline_para_size + 0 // pipeline_para_rank + ); + pT5EncoderWeightHalf_->loadModel(std::string(m_.ckpt_path)); + } + else { + m_.attention_type = + getAttentionType(m_.size_per_head, getSMVersion(), m_.is_remove_padding, m_.max_seq_len); + pT5EncoderWeightFloat_ = new T5EncoderWeight(m_.head_num, + m_.size_per_head, + m_.d_model, + m_.inter_size, + m_.vocab_size, + m_.num_layer, + m_.num_bucket, + 1, // tensor_para_size + 0, // tensor_para_rank + 1, // pipeline_para_size + 0 // pipeline_para_rank + ); + pT5EncoderWeightFloat_->loadModel(std::string(m_.ckpt_path)); + } } - - // Gemm file 
selection, in constructor, we use max_batch_szie and seq_len as data size + // Gemm file selection, in constructor, we use max_batch_szie and seq_len as + // data size std::string gemmFileName = std::string(GEMM_CONFIG).substr(0, 11) + std::string("-SM") + std::to_string(m_.sm) + std::string("-FP") + std::to_string(m_.useFP16 ? 16 : 32) + std::string("-BS") + std::to_string(m_.batch_size) + std::string("-SL") + std::to_string(m_.seq_len) + std::string("-BM") + std::to_string(m_.beam_width) + std::string(".in"); std::ifstream infile(gemmFileName); if (infile.good()) { -#if DEBUG_ENABLE == 1 +#ifdef T5_PLUGIN_DEBUG printf("Gemm file exist!\n"); #endif } else { -#if DEBUG_ENABLE == 1 +#ifdef T5_PLUGIN_DEBUG printf("Gemm file do not exist!\n"); #endif int argv[16] = { 0, - m_.max_batch_size, - m_.beam_width, // useless for encoder - (m_.batch_size == 128 && m_.seq_len == 384) ? 128 : m_.seq_len, // seq_len, in case of OOM - m_.d_model, - m_.head_num, - m_.size_per_head, - m_.inter_size, - m_.d_model, - m_.head_num, - m_.size_per_head, - m_.inter_size, - m_.vocab_size, - m_.useFP16, // is_fp16 - 1, // tensor_para_size - m_.useFP16 // is_fp16_compute_type + (int)m_.max_batch_size, + (int)m_.beam_width, // useless for encoder + (m_.batch_size == 128 && m_.seq_len == 384) ? 128 : (int)m_.seq_len, // seq_len, in case of OOM + (int)m_.d_model, + (int)m_.head_num, + (int)m_.size_per_head, + (int)m_.inter_size, + (int)m_.d_model, + (int)m_.head_num, + (int)m_.size_per_head, + (int)m_.inter_size, + (int)m_.vocab_size, + m_.useFP16 ? 1 : 0, // is_fp16 + 1, // tensor_para_size + false // always use fp32 compute type }; t5_gemm(argv); rename(std::string(GEMM_CONFIG).c_str(), gemmFileName.c_str()); } - pCublasAlgoMap_ = new cublasAlgoMap(gemmFileName, ""); + pCublasAlgoMap_ = new cublasAlgoMap(gemmFileName, ""); pCublasWrapperMutex_ = new std::mutex(); - pAllocator_ = new Allocator(getDevice()); + pAllocator_ = new Allocator(getDevice()); // cublas wrapper and T5Encoder #ifdef SPARSITY_ENABLED @@ -324,8 +333,8 @@ T5EncoderPlugin::T5EncoderPlugin(const std::string& name, const void* buffer, si m_.is_sparse, m_.activation_type, m_.layernorm_type, - {0, 1, nullptr}, // tensor_para - {0, 1, nullptr} // pipeline_para + NcclParam(0, 1), // tensor_para + NcclParam(0, 1) // pipeline_para ); } else { @@ -350,8 +359,8 @@ T5EncoderPlugin::T5EncoderPlugin(const std::string& name, const void* buffer, si m_.is_sparse, m_.activation_type, m_.layernorm_type, - {0, 1, nullptr}, // tensor_para - {0, 1, nullptr} // pipeline_para + NcclParam(0, 1), // tensor_para + NcclParam(0, 1) // pipeline_para ); } PRINT_ENCODER(m_.useFP16) @@ -360,10 +369,10 @@ T5EncoderPlugin::T5EncoderPlugin(const std::string& name, const void* buffer, si T5EncoderPlugin::~T5EncoderPlugin() { WHERE_AM_I(); - if (pT5EncoderWeightHalf_ != nullptr) { + if (is_own_weight && pT5EncoderWeightHalf_ != nullptr) { delete pT5EncoderWeightHalf_; } - if (pT5EncoderWeightFloat_ != nullptr) { + if (is_own_weight && pT5EncoderWeightFloat_ != nullptr) { delete pT5EncoderWeightFloat_; } if (pT5EncoderHalf_ != nullptr) { @@ -393,8 +402,11 @@ void T5EncoderPlugin::serialize(void* buffer) const noexcept IPluginV2DynamicExt* T5EncoderPlugin::clone() const noexcept { WHERE_AM_I(); - auto p = new T5EncoderPlugin(name_, &m_, sizeof(m_)); + auto p = new T5EncoderPlugin( + name_, m_.max_batch_size, m_.max_seq_len, m_.beam_width, m_.sm, m_.useFP16, std::string(m_.ckpt_path), false); p->setPluginNamespace(namespace_.c_str()); + p->pT5EncoderWeightHalf_ = 
this->pT5EncoderWeightHalf_; + p->pT5EncoderWeightFloat_ = this->pT5EncoderWeightFloat_; return p; } @@ -410,10 +422,10 @@ DataType T5EncoderPlugin::getOutputDataType(int index, const DataType* inputType return m_.useFP16 ? DataType::kHALF : DataType::kFLOAT; } -bool T5EncoderPlugin::supportsFormatCombination(int pos, +bool T5EncoderPlugin::supportsFormatCombination(int pos, const PluginTensorDesc* inOut, - int nbInputs, - int nbOutputs) noexcept + int nbInputs, + int nbOutputs) noexcept { WHERE_AM_I(); bool res = false; @@ -428,9 +440,9 @@ bool T5EncoderPlugin::supportsFormatCombination(int pos, && (inOut[2].format == TensorFormat::kLINEAR); break; default: // should NOT be here! - ; + res = false; } -#if DEBUG_ENABLE == 1 +#ifdef T5_PLUGIN_DEBUG printf("Dim("); for (int i = 0; i < 3; i++) { printf("%d,", inOut[i].dims.nbDims); @@ -456,33 +468,33 @@ bool T5EncoderPlugin::supportsFormatCombination(int pos, return res; } -DimsExprs T5EncoderPlugin::getOutputDimensions(int index, +DimsExprs T5EncoderPlugin::getOutputDimensions(int index, const DimsExprs* pInputDim, - int nInputDim, - IExprBuilder& exprBuilder) noexcept + int nInputDim, + IExprBuilder& exprBuilder) noexcept { WHERE_AM_I(); DimsExprs ret; ret.nbDims = 3; - ret.d[0] = pInputDim[0].d[0]; - ret.d[1] = pInputDim[0].d[1]; - ret.d[2] = exprBuilder.constant(512); + ret.d[0] = pInputDim[0].d[0]; + ret.d[1] = pInputDim[0].d[1]; + ret.d[2] = exprBuilder.constant(512); return ret; } void T5EncoderPlugin::configurePlugin(const DynamicPluginTensorDesc* in, - int nbInput, + int nbInput, const DynamicPluginTensorDesc* out, - int nbOutput) noexcept + int nbOutput) noexcept { WHERE_AM_I(); PRINT_ENCODER(int(out[0].desc.type)) } size_t T5EncoderPlugin::getWorkspaceSize(const PluginTensorDesc* inputs, - int32_t nbInputs, + int32_t nbInputs, const PluginTensorDesc* outputs, - int32_t nbOutputs) const noexcept + int32_t nbOutputs) const noexcept { WHERE_AM_I(); return 0; @@ -550,15 +562,15 @@ void T5EncoderPlugin::detachFromContext() noexcept int T5EncoderPlugin::enqueue(const PluginTensorDesc* inputDesc, const PluginTensorDesc* outputDesc, - const void* const* inputs, - void* const* outputs, - void* workspace, - cudaStream_t stream) noexcept + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) noexcept { FT_LOG_DEBUG(__PRETTY_FUNCTION__); WHERE_AM_I(); m_.batch_size = inputDesc[0].dims.d[0]; - m_.seq_len = inputDesc[0].dims.d[1]; + m_.seq_len = inputDesc[0].dims.d[1]; PRINT_ENCODER(outputDesc[0].type) cublasSetStream(cublasHandle_, stream); @@ -591,14 +603,14 @@ int T5EncoderPlugin::enqueue(const PluginTensorDesc* inputDesc, } // class T5EncoderPluginCreator -------------------------------------------------------------------- -PluginFieldCollection T5EncoderPluginCreator::fc_{}; +PluginFieldCollection T5EncoderPluginCreator::fc_{}; std::vector T5EncoderPluginCreator::attr_; T5EncoderPluginCreator::T5EncoderPluginCreator() { WHERE_AM_I(); fc_.nbFields = attr_.size(); - fc_.fields = attr_.data(); + fc_.fields = attr_.data(); } T5EncoderPluginCreator::~T5EncoderPluginCreator() @@ -610,19 +622,12 @@ IPluginV2* T5EncoderPluginCreator::createPlugin(const char* name, const PluginFi { FT_LOG_DEBUG(__PRETTY_FUNCTION__); WHERE_AM_I(); - int max_batch_size = 128; - int max_seq_len = 384; - int beam_width = 1; - int head_num = 8; - int size_per_head = 512 / 8; - int inter_size = 512 * 4; - int d_model = 512; - int num_layer = 6; - int num_bucket = 32; - int max_distance = 128; - int sm = -1; - float q_scaling = 1.0f 
/ (sqrt(size_per_head) * 1.0f); - int useFP16 = 0; + int max_batch_size = 128; + int max_seq_len = 384; + int beam_width = 1; + int sm = -1; + int useFP16 = 0; + std::string ckpt_path = std::string(""); struct cudaDeviceProp prop; cudaGetDeviceProperties(&prop, 0); @@ -632,38 +637,18 @@ IPluginV2* T5EncoderPluginCreator::createPlugin(const char* name, const PluginFi {"max_batch_size", &max_batch_size}, {"max_seq_len", &max_seq_len}, {"beam_width", &beam_width}, - {"head_num", &head_num}, - {"size_per_head", &size_per_head}, - {"inter_size", &inter_size}, - {"d_model", &d_model}, - {"num_layer", &num_layer}, - {"num_bucket", &num_bucket}, - {"max_distance", &max_distance}, {"sm", &sm}, {"useFP16", &useFP16}, }; for (int i = 0; i < fc->nbFields; i++) { - if (!strcmp(fc->fields[i].name, "q_scaling")) { - q_scaling = *(float*)fc->fields[i].data; - } - else if (name2p.find(fc->fields[i].name) != name2p.end()) { + if (name2p.find(fc->fields[i].name) != name2p.end()) { *name2p[fc->fields[i].name] = *(int*)fc->fields[i].data; } + else if (!strcmp(fc->fields[i].name, "ckpt_path")) { + ckpt_path = std::string((char*)fc->fields[i].data); + } } - return new T5EncoderPlugin(name, - max_batch_size, - max_seq_len, - beam_width, - head_num, - size_per_head, - inter_size, - d_model, - num_layer, - num_bucket, - max_distance, - sm, - q_scaling, - useFP16); + return new T5EncoderPlugin(name, max_batch_size, max_seq_len, beam_width, sm, useFP16, ckpt_path, true); } IPluginV2* @@ -707,95 +692,86 @@ REGISTER_TENSORRT_PLUGIN(T5EncoderPluginCreator); // class T5DecodingPlugin -------------------------------------------------------------------------- T5DecodingPlugin::T5DecodingPlugin(const std::string& name, - size_t max_batch_size, - size_t max_seq_len, - size_t mem_max_seq_len, - size_t beam_width, - size_t head_num, - size_t size_per_head, - size_t inter_size, - size_t d_model, - size_t num_layer, - size_t vocab_size, - size_t num_bucket, - size_t max_distance, - float q_scaling, - int start_id, - int end_id, - float beam_search_diversity_rate, - size_t top_k, - float top_p, - float temperature, - float len_penalty, - float repetition_penalty, - int useFP16): + size_t max_batch_size, + size_t max_seq_len, + size_t mem_max_seq_len, + size_t beam_width, + int useFP16, + const std::string& ckpt_path, + bool own_weight): name_(name) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); WHERE_AM_I(); - m_.max_batch_size = max_batch_size; - m_.max_seq_len = max_seq_len; - m_.mem_max_seq_len = mem_max_seq_len; - m_.beam_width = beam_width; - m_.head_num = head_num; - m_.size_per_head = size_per_head; - m_.inter_size = inter_size; - m_.d_model = d_model; - m_.num_layer = num_layer; - m_.vocab_size = vocab_size; - m_.num_bucket = num_bucket; - m_.max_distance = max_distance; - m_.q_scaling = q_scaling; - m_.start_id = start_id; - m_.end_id = end_id; - m_.beam_search_diversity_rate = beam_search_diversity_rate; - m_.top_k = top_k; - m_.top_p = top_p; - m_.temperature = temperature; - m_.len_penalty = len_penalty; - m_.repetition_penalty = repetition_penalty; - m_.useFP16 = useFP16; - m_.batch_size = m_.max_batch_size; - m_.seq_len = m_.max_seq_len; + + INIReader reader = INIReader(ckpt_path + "/config.ini"); + if (reader.ParseError() < 0) { + FT_LOG_ERROR("Can't load %s/config.ini", ckpt_path.c_str()); + FT_CHECK(false); + } + + bool t5_with_bias = reader.GetBoolean("structure", "t5_with_bias", false); + m_.max_batch_size = max_batch_size; + m_.max_seq_len = max_seq_len; + m_.mem_max_seq_len = mem_max_seq_len; + m_.beam_width = 
beam_width; + m_.head_num = reader.GetInteger("decoder", "num_heads"); + m_.size_per_head = reader.GetInteger("decoder", "d_kv"); + m_.inter_size = reader.GetInteger("decoder", "d_ff"); + m_.d_model = reader.GetInteger("decoder", "d_model"); + m_.num_layer = reader.GetInteger("decoder", "num_layers"); + m_.vocab_size = reader.GetInteger("decoder", "vocab_size"); + m_.num_bucket = reader.GetInteger("decoder", "relative_attention_num_buckets_or_max_pos_seq_len"); + m_.max_distance = reader.GetInteger("decoder", "relative_attention_max_distance"); + m_.q_scaling = t5_with_bias ? 1.0f : (1.0f / (sqrt(m_.size_per_head) * 1.0f)); + m_.start_id = reader.GetInteger("decoder", "decoder_start_token_id"); + m_.end_id = reader.GetInteger("decoder", "eos_token_id"); + m_.useFP16 = (bool)useFP16; + m_.batch_size = m_.max_batch_size; + m_.seq_len = m_.max_seq_len; + m_.mem_hidden_units = reader.GetInteger("encoder", "num_heads") * reader.GetInteger("encoder", "d_kv"); + m_.mem_d_model = reader.GetInteger("encoder", "d_model"); + strcpy(m_.ckpt_path, ckpt_path.c_str()); cublasCreate(&cublasHandle_); cublasLtCreate(&cublasltHandle_); - // T5DecodingWeight - std::string paraFilePath = "./para"; - if (m_.useFP16) { - pT5DecodingWeightHalf_ = new T5DecodingWeight(m_.head_num, - m_.size_per_head, - m_.d_model, - m_.inter_size, - m_.vocab_size, - m_.num_layer, - m_.mem_d_model, - m_.num_bucket, - 1, // tensor_para_size - 0, // tensor_para_rank - 1, // pipeline_para_size - 0 // pipeline_para_rank - ); - pT5DecodingWeightHalf_->loadModel(paraFilePath); - } - else { - pT5DecodingWeightFloat_ = new T5DecodingWeight(m_.head_num, - m_.size_per_head, - m_.d_model, - m_.inter_size, - m_.vocab_size, - m_.num_layer, - m_.mem_d_model, - m_.num_bucket, - 1, // tensor_para_size, - 0, // tensor_para_rank, - 1, // pipeline_para_size, - 0 // pipeline_para_rank - ); - pT5DecodingWeightFloat_->loadModel(paraFilePath); + is_own_weight = own_weight; + if (is_own_weight) { + // T5DecodingWeight + if (m_.useFP16) { + pT5DecodingWeightHalf_ = new T5DecodingWeight(m_.head_num, + m_.size_per_head, + m_.d_model, + m_.inter_size, + m_.vocab_size, + m_.num_layer, + m_.mem_d_model, + m_.num_bucket, + 1, // tensor_para_size + 0, // tensor_para_rank + 1, // pipeline_para_size + 0 // pipeline_para_rank + ); + pT5DecodingWeightHalf_->loadModel(std::string(m_.ckpt_path)); + } + else { + pT5DecodingWeightFloat_ = new T5DecodingWeight(m_.head_num, + m_.size_per_head, + m_.d_model, + m_.inter_size, + m_.vocab_size, + m_.num_layer, + m_.mem_d_model, + m_.num_bucket, + 1, // tensor_para_size, + 0, // tensor_para_rank, + 1, // pipeline_para_size, + 0 // pipeline_para_rank + ); + pT5DecodingWeightFloat_->loadModel(std::string(m_.ckpt_path)); + } } - // Gemm file selection check_cuda_error(cudaGetDeviceProperties(&cuda_device_prop_, 0)); std::string gemmFileName = std::string(GEMM_CONFIG).substr(0, 11) + std::string("-SM") @@ -805,39 +781,39 @@ T5DecodingPlugin::T5DecodingPlugin(const std::string& name, + std::string("-BM") + std::to_string(m_.beam_width) + std::string(".in"); std::ifstream infile(gemmFileName); if (infile.good()) { -#if DEBUG_ENABLE == 1 +#ifdef T5_PLUGIN_DEBUG printf("Gemm file exist!\n"); #endif } else { -#if DEBUG_ENABLE == 1 +#ifdef T5_PLUGIN_DEBUG printf("Gemm file do not exist!\n"); #endif int argv[16] = { 0, - m_.max_batch_size, - m_.beam_width, - (m_.batch_size == 128 && m_.seq_len == 384) ? 
128 : m_.seq_len, // seq_len, in case of OOM - m_.d_model, - m_.head_num, - m_.size_per_head, - m_.inter_size, - m_.d_model, - m_.head_num, - m_.size_per_head, - m_.inter_size, - m_.vocab_size, - m_.useFP16, // is_fp16 - 1, // tensor_para_size - m_.useFP16 // is_fp16_compute_type + (int)m_.max_batch_size, + (int)m_.beam_width, + (m_.batch_size == 128 && m_.seq_len == 384) ? 128 : (int)m_.seq_len, // seq_len, in case of OOM + (int)m_.d_model, + (int)m_.head_num, + (int)m_.size_per_head, + (int)m_.inter_size, + (int)m_.d_model, + (int)m_.head_num, + (int)m_.size_per_head, + (int)m_.inter_size, + (int)m_.vocab_size, + m_.useFP16 ? 1 : 0, // is_fp16 + 1, // tensor_para_size + false // always use fp32 compute type }; t5_gemm(argv); rename(std::string(GEMM_CONFIG).c_str(), gemmFileName.c_str()); } - pCublasAlgoMap_ = new cublasAlgoMap(gemmFileName, ""); + pCublasAlgoMap_ = new cublasAlgoMap(gemmFileName, ""); pCublasWrapperMutex_ = new std::mutex(); - pAllocator_ = new Allocator(getDevice()); + pAllocator_ = new Allocator(getDevice()); // cublas wrapper and T5Decoding pCublasWrapper_ = @@ -861,19 +837,19 @@ T5DecodingPlugin::T5DecodingPlugin(const std::string& name, m_.q_scaling, m_.start_id, m_.end_id, - m_.beam_search_diversity_rate, - m_.top_k, - m_.top_p, - m_.temperature, - m_.len_penalty, - m_.repetition_penalty, - 0, // stream placeholder + 0.0f, // don't need to pass beam_search_diversity_rate in constructor + 0, // don't need to pass top_k in constructor + 0.0f, // don't need to pass top_p in constructor + 0.0f, // don't need to pass temperature in constructor + 0.0f, // don't need to pass len_penalty in constructor + 0.0f, // don't need to pass repetition_penalty in constructor + 0, // stream placeholder pCublasWrapper_, pAllocator_, m_.is_free_buffer_after_forward, &cuda_device_prop_, - {0, 1, nullptr}, // tensor_para - {0, 1, nullptr} // pipeline_para + NcclParam(0, 1), // tensor_para + NcclParam(0, 1) // pipeline_para ); } else { @@ -894,19 +870,19 @@ T5DecodingPlugin::T5DecodingPlugin(const std::string& name, m_.q_scaling, m_.start_id, m_.end_id, - m_.beam_search_diversity_rate, - m_.top_k, - m_.top_p, - m_.temperature, - m_.len_penalty, - m_.repetition_penalty, - 0, // stream placeholder + 0.0f, // don't need to pass beam_search_diversity_rate in constructor + 0, // don't need to pass top_k in constructor + 0.0f, // don't need to pass top_p in constructor + 0.0f, // don't need to pass temperature in constructor + 0.0f, // don't need to pass len_penalty in constructor + 0.0f, // don't need to pass repetition_penalty in constructor + 0, // stream placeholder pCublasWrapper_, pAllocator_, m_.is_free_buffer_after_forward, &cuda_device_prop_, - {0, 1, nullptr}, // tensor_para - {0, 1, nullptr} // pipeline_para + NcclParam(0, 1), // tensor_para + NcclParam(0, 1) // pipeline_para ); } } @@ -919,42 +895,42 @@ T5DecodingPlugin::T5DecodingPlugin(const std::string& name, const void* buffer, cublasCreate(&cublasHandle_); cublasLtCreate(&cublasltHandle_); - - // T5DecodingWeight - std::string paraFilePath = "./para"; - if (m_.useFP16) { - pT5DecodingWeightHalf_ = new T5DecodingWeight(m_.head_num, - m_.size_per_head, - m_.d_model, - m_.inter_size, - m_.vocab_size, - m_.num_layer, - m_.mem_d_model, - m_.num_bucket, - 1, // tensor_para_size - 0, // tensor_para_rank - 1, // pipeline_para_size - 0 // pipeline_para_rank - ); - pT5DecodingWeightHalf_->loadModel(paraFilePath); - } - else { - pT5DecodingWeightFloat_ = new T5DecodingWeight(m_.head_num, - m_.size_per_head, - m_.d_model, - m_.inter_size, 
- m_.vocab_size, - m_.num_layer, - m_.mem_d_model, - m_.num_bucket, - 1, // tensor_para_size, - 0, // tensor_para_rank, - 1, // pipeline_para_size, - 0 // pipeline_para_rank - ); - pT5DecodingWeightFloat_->loadModel(paraFilePath); + is_own_weight = true; + if (is_own_weight) { + // T5DecodingWeight + if (m_.useFP16) { + pT5DecodingWeightHalf_ = new T5DecodingWeight(m_.head_num, + m_.size_per_head, + m_.d_model, + m_.inter_size, + m_.vocab_size, + m_.num_layer, + m_.mem_d_model, + m_.num_bucket, + 1, // tensor_para_size + 0, // tensor_para_rank + 1, // pipeline_para_size + 0 // pipeline_para_rank + ); + pT5DecodingWeightHalf_->loadModel(std::string(m_.ckpt_path)); + } + else { + pT5DecodingWeightFloat_ = new T5DecodingWeight(m_.head_num, + m_.size_per_head, + m_.d_model, + m_.inter_size, + m_.vocab_size, + m_.num_layer, + m_.mem_d_model, + m_.num_bucket, + 1, // tensor_para_size, + 0, // tensor_para_rank, + 1, // pipeline_para_size, + 0 // pipeline_para_rank + ); + pT5DecodingWeightFloat_->loadModel(std::string(m_.ckpt_path)); + } } - // Gemm file selection check_cuda_error(cudaGetDeviceProperties(&cuda_device_prop_, 0)); std::string gemmFileName = std::string(GEMM_CONFIG).substr(0, 11) + std::string("-SM") @@ -964,39 +940,39 @@ T5DecodingPlugin::T5DecodingPlugin(const std::string& name, const void* buffer, + std::string("-BM") + std::to_string(m_.beam_width) + std::string(".in"); std::ifstream infile(gemmFileName); if (infile.good()) { -#if DEBUG_ENABLE == 1 - printf("Gemm file exist!\n"); +#ifdef T5_PLUGIN_DEBUG + FT_LOG_INFO("Gemm file exist!"); #endif } else { -#if DEBUG_ENABLE == 1 - printf("Gemm file do not exist!\n"); +#ifdef T5_PLUGIN_DEBUG + FT_LOG_INFO("Gemm file do not exist!"); #endif int argv[16] = { 0, - m_.max_batch_size, - m_.beam_width, - (m_.batch_size == 128 && m_.seq_len == 384) ? 128 : m_.seq_len, // seq_len, in case of OOM - m_.d_model, - m_.head_num, - m_.size_per_head, - m_.inter_size, - m_.d_model, - m_.head_num, - m_.size_per_head, - m_.inter_size, - m_.vocab_size, - m_.useFP16, // is_fp16 - 1, // tensor_para_size - m_.useFP16 // is_fp16_compute_type + (int)m_.max_batch_size, + (int)m_.beam_width, + (m_.batch_size == 128 && m_.seq_len == 384) ? 128 : (int)m_.seq_len, // seq_len, in case of OOM + (int)m_.d_model, + (int)m_.head_num, + (int)m_.size_per_head, + (int)m_.inter_size, + (int)m_.d_model, + (int)m_.head_num, + (int)m_.size_per_head, + (int)m_.inter_size, + (int)m_.vocab_size, + m_.useFP16 ? 
1 : 0, // is_fp16 + 1, // tensor_para_size + false // always use fp32 compute type }; t5_gemm(argv); rename(std::string(GEMM_CONFIG).c_str(), gemmFileName.c_str()); } - pCublasAlgoMap_ = new cublasAlgoMap(gemmFileName, ""); + pCublasAlgoMap_ = new cublasAlgoMap(gemmFileName, ""); pCublasWrapperMutex_ = new std::mutex(); - pAllocator_ = new Allocator(getDevice()); + pAllocator_ = new Allocator(getDevice()); // cublas wrapper and T5Decoding pCublasWrapper_ = @@ -1031,8 +1007,8 @@ T5DecodingPlugin::T5DecodingPlugin(const std::string& name, const void* buffer, pAllocator_, m_.is_free_buffer_after_forward, &cuda_device_prop_, - {0, 1, nullptr}, // tensor_para - {0, 1, nullptr} // pipeline_para + NcclParam(0, 1), // tensor_para + NcclParam(0, 1) // pipeline_para ); } else { @@ -1064,8 +1040,8 @@ T5DecodingPlugin::T5DecodingPlugin(const std::string& name, const void* buffer, pAllocator_, m_.is_free_buffer_after_forward, &cuda_device_prop_, - {0, 1, nullptr}, // tensor_para - {0, 1, nullptr} // pipeline_para + NcclParam(0, 1), // tensor_para + NcclParam(0, 1) // pipeline_para ); } } @@ -1073,10 +1049,10 @@ T5DecodingPlugin::T5DecodingPlugin(const std::string& name, const void* buffer, T5DecodingPlugin::~T5DecodingPlugin() { WHERE_AM_I(); - if (pT5DecodingWeightHalf_ != nullptr) { + if (is_own_weight && pT5DecodingWeightHalf_ != nullptr) { delete pT5DecodingWeightHalf_; } - if (pT5DecodingWeightFloat_ != nullptr) { + if (is_own_weight && pT5DecodingWeightFloat_ != nullptr) { delete pT5DecodingWeightFloat_; } if (pT5DecodingHalf_ != nullptr) { @@ -1106,8 +1082,17 @@ void T5DecodingPlugin::serialize(void* buffer) const noexcept IPluginV2DynamicExt* T5DecodingPlugin::clone() const noexcept { WHERE_AM_I(); - auto p = new T5DecodingPlugin(name_, &m_, sizeof(m_)); + auto p = new T5DecodingPlugin(name_, + m_.max_batch_size, + m_.max_seq_len, + m_.mem_max_seq_len, + m_.beam_width, + m_.useFP16, + std::string(m_.ckpt_path), + false); p->setPluginNamespace(namespace_.c_str()); + p->pT5DecodingWeightHalf_ = this->pT5DecodingWeightHalf_; + p->pT5DecodingWeightFloat_ = this->pT5DecodingWeightFloat_; return p; } @@ -1123,76 +1108,96 @@ DataType T5DecodingPlugin::getOutputDataType(int index, const DataType* inputTyp return DataType::kINT32; } -bool T5DecodingPlugin::supportsFormatCombination(int pos, +bool T5DecodingPlugin::supportsFormatCombination(int pos, const PluginTensorDesc* inOut, - int nbInputs, - int nbOutputs) noexcept + int nbInputs, + int nbOutputs) noexcept { WHERE_AM_I(); bool res = false; switch (pos) { - case 0: - res = (inOut[0].type == (m_.useFP16 ? DataType::kHALF : DataType::kFLOAT)) - && (inOut[0].format == TensorFormat::kLINEAR); + case 0: // encoder_output + res = (inOut[pos].type == (m_.useFP16 ? 
DataType::kHALF : DataType::kFLOAT)) + && (inOut[pos].format == TensorFormat::kLINEAR); break; - case 1: - case 2: - case 3: - case 4: + case 1: // encoder_sequence_length + case 2: // runtime_top_k + case 3: // runtime_top_p + res = (inOut[pos].type == DataType::kINT32) && (inOut[pos].format == TensorFormat::kLINEAR); + break; + // res = (inOut[pos].type == DataType::kFLOAT) && (inOut[pos].format == TensorFormat::kLINEAR); break; + case 4: // beam_search_diversity_rate + case 5: // temperature + case 6: // len_penalty + case 7: // repetition_penalty + res = (inOut[pos].type == DataType::kFLOAT) && (inOut[pos].format == TensorFormat::kLINEAR); + break; + case 8: // output_ids + case 9: // sequence_length res = (inOut[pos].type == DataType::kINT32) && (inOut[pos].format == TensorFormat::kLINEAR); break; default: // should NOT be here! - ; + res = false; } -#if DEBUG_ENABLE == 1 - printf("Dim("); +#ifdef T5_PLUGIN_DEBUG + FT_LOG_INFO("Dim("); for (int i = 0; i < 5; i++) { - printf("%d,", inOut[i].dims.nbDims); + FT_LOG_INFO("%d,", inOut[i].dims.nbDims); } - printf("),"); - printf("pos=%d,res=%d,format(%d,%d,%d,%d,%d),type(%d,%d,%d,%d,%d),", - pos, - int(res), - int(inOut[0].format), - int(inOut[1].format), - int(inOut[2].format), - int(inOut[3].format), - int(inOut[4].format), - int(inOut[0].type), - int(inOut[1].type), - int(inOut[2].type), - int(inOut[3].type), - int(inOut[4].type)); - printf("kLINEAR=%d,float=%d,half=%d,int8=%d,int32=%d,bool=%d\n", - int(TensorFormat::kLINEAR), - int(DataType::kFLOAT), - int(DataType::kHALF), - int(DataType::kINT8), - int(DataType::kINT32), - int(DataType::kBOOL)); + FT_LOG_INFO("),"); + FT_LOG_INFO("pos=%d,res=%d,format(%d,%d,%d,%d,%d,%d,%d,%d,%d,%d),type(%d,%d,%d,%d,%d,%d,%d,%d,%d,%d),", + pos, + int(res), + int(inOut[0].format), + int(inOut[1].format), + int(inOut[2].format), + int(inOut[3].format), + int(inOut[4].format), + int(inOut[5].format), + int(inOut[6].format), + int(inOut[7].format), + int(inOut[8].format), + int(inOut[9].format), + int(inOut[0].type), + int(inOut[1].type), + int(inOut[2].type), + int(inOut[3].type), + int(inOut[4].type), + int(inOut[5].type), + int(inOut[6].type), + int(inOut[7].type), + int(inOut[8].type), + int(inOut[9].type)); + FT_LOG_INFO("kLINEAR=%d,float=%d,half=%d,int8=%d,int32=%d,bool=%d\n", + int(TensorFormat::kLINEAR), + int(DataType::kFLOAT), + int(DataType::kHALF), + int(DataType::kINT8), + int(DataType::kINT32), + int(DataType::kBOOL)); #endif return res; } -DimsExprs T5DecodingPlugin::getOutputDimensions(int index, +DimsExprs T5DecodingPlugin::getOutputDimensions(int index, const DimsExprs* pInputDim, - int nInputDim, - IExprBuilder& exprBuilder) noexcept + int nInputDim, + IExprBuilder& exprBuilder) noexcept { WHERE_AM_I(); DimsExprs ret; switch (index) { case 0: ret.nbDims = 3; - ret.d[0] = pInputDim[0].d[0]; - ret.d[1] = exprBuilder.constant(m_.beam_width); - ret.d[2] = exprBuilder.constant(m_.max_seq_len); + ret.d[0] = pInputDim[0].d[0]; + ret.d[1] = exprBuilder.constant(m_.beam_width); + ret.d[2] = exprBuilder.constant(m_.max_seq_len); break; case 1: ret.nbDims = 2; - ret.d[0] = pInputDim[0].d[0]; - ret.d[1] = exprBuilder.constant(m_.beam_width); + ret.d[0] = pInputDim[0].d[0]; + ret.d[1] = exprBuilder.constant(m_.beam_width); break; default: // should NOT be here! 
; @@ -1201,9 +1206,9 @@ DimsExprs T5DecodingPlugin::getOutputDimensions(int index, } void T5DecodingPlugin::configurePlugin(const DynamicPluginTensorDesc* in, - int nbInput, + int nbInput, const DynamicPluginTensorDesc* out, - int nbOutput) noexcept + int nbOutput) noexcept { FT_LOG_DEBUG(__PRETTY_FUNCTION__); WHERE_AM_I(); @@ -1211,9 +1216,9 @@ void T5DecodingPlugin::configurePlugin(const DynamicPluginTensorDesc* in, } size_t T5DecodingPlugin::getWorkspaceSize(const PluginTensorDesc* inputs, - int32_t nbInputs, + int32_t nbInputs, const PluginTensorDesc* outputs, - int32_t nbOutputs) const noexcept + int32_t nbOutputs) const noexcept { WHERE_AM_I(); return 0; @@ -1276,20 +1281,45 @@ void T5DecodingPlugin::detachFromContext() noexcept int T5DecodingPlugin::enqueue(const PluginTensorDesc* inputDesc, const PluginTensorDesc* outputDesc, - const void* const* inputs, - void* const* outputs, - void* workspace, - cudaStream_t stream) noexcept + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) noexcept { FT_LOG_DEBUG(__PRETTY_FUNCTION__); WHERE_AM_I(); m_.batch_size = inputDesc[0].dims.d[0]; - m_.seq_len = inputDesc[0].dims.d[1]; + m_.seq_len = inputDesc[0].dims.d[1]; PRINT_DECODING(inputDesc[0].type) cublasSetStream(cublasHandle_, stream); pCublasWrapper_->setStream(stream); + int nTopK = (inputDesc[2].dims.d[0] == 1) ? 1 : m_.batch_size; + int nTopP = (inputDesc[3].dims.d[0] == 1) ? 1 : m_.batch_size; + int nBeam_search_diversity_rate = (inputDesc[4].dims.d[0] == 1) ? 1 : m_.batch_size; + int nTemperature = (inputDesc[5].dims.d[0] == 1) ? 1 : m_.batch_size; + int nLen_penalty = (inputDesc[6].dims.d[0] == 1) ? 1 : m_.batch_size; + int nRepetition_penalty = (inputDesc[7].dims.d[0] == 1) ? 1 : m_.batch_size; + int* pTopK = new int[m_.batch_size]; + float* pTopP = new float[m_.batch_size]; + float* pBeam_search_diversity_rate = new float[m_.batch_size]; + float* pTemperature = new float[m_.batch_size]; + float* pLen_penalty = new float[m_.batch_size]; + float* pRepetition_penalty = new float[m_.batch_size]; + + cudaMemcpyAsync(pTopK, (int*)inputs[2], sizeof(int) * nTopK, cudaMemcpyDeviceToHost, stream); + cudaMemcpyAsync(pTopP, (float*)inputs[3], sizeof(float) * nTopP, cudaMemcpyDeviceToHost, stream); + cudaMemcpyAsync(pBeam_search_diversity_rate, + (float*)inputs[4], + sizeof(float) * nBeam_search_diversity_rate, + cudaMemcpyDeviceToHost, + stream); + cudaMemcpyAsync(pTemperature, (float*)inputs[5], sizeof(float) * nTemperature, cudaMemcpyDeviceToHost, stream); + cudaMemcpyAsync(pLen_penalty, (float*)inputs[6], sizeof(float) * nLen_penalty, cudaMemcpyDeviceToHost, stream); + cudaMemcpyAsync( + pRepetition_penalty, (float*)inputs[7], sizeof(float) * nRepetition_penalty, cudaMemcpyDeviceToHost, stream); + std::unordered_map outputTensor{ {"output_ids", Tensor{MEMORY_GPU, @@ -1306,10 +1336,26 @@ int T5DecodingPlugin::enqueue(const PluginTensorDesc* inputDesc, {"encoder_output", Tensor{MEMORY_GPU, TYPE_FP16, - std::vector{(size_t)m_.batch_size, (size_t)m_.seq_len, (size_t)m_.mem_hidden_units}, + std::vector{(size_t)m_.batch_size, (size_t)m_.seq_len, (size_t)m_.mem_d_model}, (half*)inputs[0]}}, {"encoder_sequence_length", - Tensor{MEMORY_GPU, TYPE_INT32, std::vector{(size_t)m_.batch_size}, (int*)inputs[1]}}}; + Tensor{MEMORY_GPU, TYPE_INT32, std::vector{(size_t)m_.batch_size}, (int*)inputs[1]}}, + {"runtime_top_k", Tensor{MEMORY_CPU, TYPE_UINT32, std::vector{(size_t)nTopK}, (int*)pTopK}}, + {"runtime_top_p", Tensor{MEMORY_CPU, TYPE_FP32, 
std::vector{(size_t)nTopP}, (float*)pTopP}}, + {"beam_search_diversity_rate", + Tensor{MEMORY_CPU, + TYPE_FP32, + std::vector{(size_t)nBeam_search_diversity_rate}, + (float*)pBeam_search_diversity_rate}}, + {"temperature", + Tensor{MEMORY_CPU, TYPE_FP32, std::vector{(size_t)nTemperature}, (float*)pTemperature}}, + {"len_penalty", + Tensor{MEMORY_CPU, TYPE_FP32, std::vector{(size_t)nLen_penalty}, (float*)pLen_penalty}}, + {"repetition_penalty", + Tensor{MEMORY_CPU, + TYPE_FP32, + std::vector{(size_t)nRepetition_penalty}, + (float*)pRepetition_penalty}}}; pT5DecodingHalf_->setStream(stream); pT5DecodingHalf_->forward(&outputTensor, &inputTensor, pT5DecodingWeightHalf_); } @@ -1318,24 +1364,47 @@ int T5DecodingPlugin::enqueue(const PluginTensorDesc* inputDesc, {"encoder_output", Tensor{MEMORY_GPU, TYPE_FP32, - std::vector{(size_t)m_.batch_size, (size_t)m_.seq_len, (size_t)m_.mem_hidden_units}, + std::vector{(size_t)m_.batch_size, (size_t)m_.seq_len, (size_t)m_.mem_d_model}, (float*)inputs[0]}}, {"encoder_sequence_length", - Tensor{MEMORY_GPU, TYPE_INT32, std::vector{(size_t)m_.batch_size}, (int*)inputs[1]}}}; + Tensor{MEMORY_GPU, TYPE_INT32, std::vector{(size_t)m_.batch_size}, (int*)inputs[1]}}, + {"runtime_top_k", Tensor{MEMORY_CPU, TYPE_UINT32, std::vector{(size_t)nTopK}, (int*)pTopK}}, + {"runtime_top_p", Tensor{MEMORY_CPU, TYPE_FP32, std::vector{(size_t)nTopP}, (float*)pTopP}}, + {"beam_search_diversity_rate", + Tensor{MEMORY_CPU, + TYPE_FP32, + std::vector{(size_t)nBeam_search_diversity_rate}, + (float*)pBeam_search_diversity_rate}}, + {"temperature", + Tensor{MEMORY_CPU, TYPE_FP32, std::vector{(size_t)nTemperature}, (float*)pTemperature}}, + {"len_penalty", + Tensor{MEMORY_CPU, TYPE_FP32, std::vector{(size_t)nLen_penalty}, (float*)pLen_penalty}}, + {"repetition_penalty", + Tensor{MEMORY_CPU, + TYPE_FP32, + std::vector{(size_t)nRepetition_penalty}, + (float*)pRepetition_penalty}}}; pT5DecodingFloat_->setStream(stream); pT5DecodingFloat_->forward(&outputTensor, &inputTensor, pT5DecodingWeightFloat_); } + + delete[] pTopK; + delete[] pTopP; + delete[] pBeam_search_diversity_rate; + delete[] pTemperature; + delete[] pLen_penalty; + delete[] pRepetition_penalty; return 0; } // class T5DecodingPluginCreator ------------------------------------------------------------------- -PluginFieldCollection T5DecodingPluginCreator::fc_{}; +PluginFieldCollection T5DecodingPluginCreator::fc_{}; std::vector T5DecodingPluginCreator::attr_; T5DecodingPluginCreator::T5DecodingPluginCreator() { WHERE_AM_I(); fc_.nbFields = attr_.size(); - fc_.fields = attr_.data(); + fc_.fields = attr_.data(); } T5DecodingPluginCreator::~T5DecodingPluginCreator() @@ -1347,86 +1416,30 @@ IPluginV2* T5DecodingPluginCreator::createPlugin(const char* name, const PluginF { FT_LOG_DEBUG(__PRETTY_FUNCTION__); WHERE_AM_I(); - int max_batch_size = 128; - int max_seq_len = 384; - int mem_max_seq_len = max_seq_len; - int beam_width = 4; - int head_num = 8; - int size_per_head = 512 / 8; - int d_model = head_num * size_per_head; - int inter_size = d_model * 4; - int num_layer = 6; - int vocab_size = 32128; - int num_bucket = 32; - int max_distance = 128; - int start_id = 0; - int end_id = 1; - float beam_search_diversity_rate = 0.0f; - int top_k = beam_width; - float top_p = 0.0f; - float temperature = 1.0f; - float len_penalty = 2.0f; - float q_scaling = 1.0f / (sqrt(size_per_head) * 1.0f); - float repetition_penalty = 1.0f; - int useFP16 = 0; + int max_batch_size = 128; + int max_seq_len = 384; + int mem_max_seq_len = max_seq_len; + int 
beam_width = 4; + int useFP16 = 0; + std::string ckpt_path = std::string(""); std::map name2pint{ {"max_batch_size", &max_batch_size}, {"max_seq_len", &max_seq_len}, {"mem_max_seq_len", &mem_max_seq_len}, {"beam_width", &beam_width}, - {"head_num", &head_num}, - {"size_per_head", &size_per_head}, - {"inter_size", &inter_size}, - {"d_model", &d_model}, - {"num_layer", &num_layer}, - {"num_bucket", &num_bucket}, - {"max_distance", &max_distance}, - {"vocab_size", &vocab_size}, - {"start_id", &start_id}, - {"end_id", &end_id}, - {"top_k", &top_k}, {"useFP16", &useFP16}, }; - std::map name2pfloat{ - {"beam_search_diversity_rate", &beam_search_diversity_rate}, - {"top_p", &top_p}, - {"temperature", &temperature}, - {"len_penalty", &len_penalty}, - {"repetition_penalty", &repetition_penalty}, - }; for (int i = 0; i < fc->nbFields; i++) { if (name2pint.find(fc->fields[i].name) != name2pint.end()) { *name2pint[fc->fields[i].name] = *(int*)fc->fields[i].data; } - if (name2pfloat.find(fc->fields[i].name) != name2pfloat.end()) { - *name2pfloat[fc->fields[i].name] = *(float*)fc->fields[i].data; + else if (!strcmp(fc->fields[i].name, "ckpt_path")) { + ckpt_path = std::string((char*)fc->fields[i].data); } } - - return new T5DecodingPlugin(name, - max_batch_size, - max_seq_len, - mem_max_seq_len, - beam_width, - head_num, - size_per_head, - inter_size, - d_model, - num_layer, - vocab_size, - num_bucket, - max_distance, - q_scaling, - start_id, - end_id, - beam_search_diversity_rate, - top_k, - top_p, - temperature, - len_penalty, - repetition_penalty, - useFP16); + return new T5DecodingPlugin( + name, max_batch_size, max_seq_len, mem_max_seq_len, beam_width, useFP16, ckpt_path, true); } IPluginV2* diff --git a/src/fastertransformer/tensorrt_plugin/t5/T5Plugin.h b/src/fastertransformer/tensorrt_plugin/t5/T5Plugin.h index 92aa6baf9..8190a4f64 100644 --- a/src/fastertransformer/tensorrt_plugin/t5/T5Plugin.h +++ b/src/fastertransformer/tensorrt_plugin/t5/T5Plugin.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -20,6 +20,8 @@ #include "src/fastertransformer/models/t5/T5Encoder.h" #include "src/fastertransformer/models/t5/T5EncoderWeight.h" #include "src/fastertransformer/utils/Tensor.h" +#include "src/fastertransformer/utils/logger.h" + #include #include #include @@ -29,62 +31,59 @@ #include #include -#if DEBUG_ENABLE == 1 +// #define T5_PLUGIN_DEBUG +#ifdef T5_PLUGIN_DEBUG #define WHERE_AM_I() printf("[%s]: this->%p\n", __func__, this); #define PRINT_ENCODER(DATA_TYPE) \ - printf("[encoder::%s]Info:\n\tdatatype=%d\n", __func__, DATA_TYPE); \ - printf("\tmax_batch_size=%d\n", m_.max_batch_size); \ - printf("\tmax_seq_len=%d\n", m_.max_seq_len); \ - printf("\tbeam_width=%d\n", m_.beam_width); \ - printf("\thead_num=%d\n", m_.head_num); \ - printf("\tsize_per_head=%d\n", m_.size_per_head); \ - printf("\td_model=%d\n", m_.d_model); \ - printf("\tinter_size=%d\n", m_.inter_size); \ - printf("\tnum_layer=%d\n", m_.num_layer); \ - printf("\tnum_bucket=%d\n", m_.num_bucket); \ - printf("\tmax_distance=%d\n", m_.max_distance); \ + printf("[encoder::%s]Info:\n\tdatatype=%d\n", __func__, (int)DATA_TYPE); \ + printf("\tmax_batch_size=%ld\n", m_.max_batch_size); \ + printf("\tmax_seq_len=%ld\n", m_.max_seq_len); \ + printf("\tbeam_width=%ld\n", m_.beam_width); \ + printf("\thead_num=%ld\n", m_.head_num); \ + printf("\tsize_per_head=%ld\n", m_.size_per_head); \ + printf("\td_model=%ld\n", m_.d_model); \ + printf("\tinter_size=%ld\n", m_.inter_size); \ + printf("\tnum_layer=%ld\n", m_.num_layer); \ + printf("\tnum_bucket=%ld\n", m_.num_bucket); \ + printf("\tmax_distance=%ld\n", m_.max_distance); \ printf("\tsm=%d\n", m_.sm); \ printf("\tq_scaling=%f\n", m_.q_scaling); \ printf("\tuseFP16=%d\n", m_.useFP16); \ - printf("\tvocab_size=%d\n", m_.vocab_size); \ - printf("\tis_remove_padding=%d\n", m_.is_remove_padding); \ - printf("\tis_free_buffer_after_forward=%d\n", m_.is_free_buffer_after_forward); \ - printf("\tis_sparse=%d\n", m_.is_sparse); \ - printf("\tattention_type=%d\n", m_.attention_type); \ - printf("\tactivation_type=%d\n", m_.activation_type); \ - printf("\tlayernorm_type=%d\n", m_.layernorm_type); \ - printf("\tbatch_size=%d\n", m_.batch_size); \ - printf("\tseq_len=%d\n", m_.seq_len); + printf("\tvocab_size=%ld\n", m_.vocab_size); \ + printf("\tis_remove_padding=%d\n", (int)m_.is_remove_padding); \ + printf("\tis_free_buffer_after_forward=%d\n", (int)m_.is_free_buffer_after_forward); \ + printf("\tis_sparse=%d\n", (int)m_.is_sparse); \ + printf("\tattention_type=%d\n", (int)m_.attention_type); \ + printf("\tactivation_type=%d\n", (int)m_.activation_type); \ + printf("\tlayernorm_type=%d\n", (int)m_.layernorm_type); \ + printf("\tbatch_size=%ld\n", m_.batch_size); \ + printf("\tseq_len=%ld\n", m_.seq_len); \ + printf("\tckpt_path=%s\n", m_.ckpt_path); #define PRINT_DECODING(DATA_TYPE) \ - printf("[decoding::%s]Info:\n\tdatatype=%d\n", __func__, DATA_TYPE); \ - printf("\tmax_batch_size=%d\n", m_.max_batch_size); \ - printf("\tmax_seq_len=%d\n", m_.max_seq_len); \ - printf("\tmem_max_seq_len=%d\n", m_.mem_max_seq_len); \ - printf("\tbeam_width=%d\n", m_.beam_width); \ - printf("\thead_num=%d\n", m_.head_num); \ - printf("\tsize_per_head=%d\n", m_.size_per_head); \ - printf("\td_model=%d\n", m_.d_model); \ - printf("\tinter_size=%d\n", m_.inter_size); \ - printf("\tnum_layer=%d\n", m_.num_layer); \ - printf("\tvocab_size=%d\n", m_.vocab_size); \ - printf("\tnum_bucket=%d\n", m_.num_bucket); \ - printf("\tmax_distance=%d\n", m_.max_distance); \ + printf("[decoding::%s]Info:\n\tdatatype=%d\n", 
__func__, (int)DATA_TYPE); \ + printf("\tmax_batch_size=%ld\n", m_.max_batch_size); \ + printf("\tmax_seq_len=%ld\n", m_.max_seq_len); \ + printf("\tmem_max_seq_len=%ld\n", m_.mem_max_seq_len); \ + printf("\tbeam_width=%ld\n", m_.beam_width); \ + printf("\thead_num=%ld\n", m_.head_num); \ + printf("\tsize_per_head=%ld\n", m_.size_per_head); \ + printf("\td_model=%ld\n", m_.d_model); \ + printf("\tinter_size=%ld\n", m_.inter_size); \ + printf("\tnum_layer=%ld\n", m_.num_layer); \ + printf("\tvocab_size=%ld\n", m_.vocab_size); \ + printf("\tnum_bucket=%ld\n", m_.num_bucket); \ + printf("\tmax_distance=%ld\n", m_.max_distance); \ printf("\tq_scaling=%f\n", m_.q_scaling); \ printf("\tstart_id=%d\n", m_.start_id); \ printf("\tend_id=%d\n", m_.end_id); \ - printf("\tbeam_search_diversity_rate=%f\n", m_.beam_search_diversity_rate); \ - printf("\ttop_k=%d\n", m_.top_k); \ - printf("\ttop_p=%f\n", m_.top_p); \ - printf("\ttemperature=%f\n", m_.temperature); \ - printf("\tlen_penalty=%f\n", m_.len_penalty); \ - printf("\trepetition_penalty=%f\n", m_.repetition_penalty); \ printf("\tuseFP16=%d\n", m_.useFP16); \ - printf("\tmem_d_model=%d\n", m_.mem_d_model); \ - printf("\tmem_hidden_units=%d\n", m_.mem_hidden_units); \ + printf("\tmem_d_model=%ld\n", m_.mem_d_model); \ + printf("\tmem_hidden_units=%ld\n", m_.mem_hidden_units); \ printf("\tis_free_buffer_after_forward=%d\n", m_.is_free_buffer_after_forward); \ - printf("\tbatch_size=%d\n", m_.batch_size); \ - printf("\tseq_len=%d\n", m_.seq_len); + printf("\tbatch_size=%ld\n", m_.batch_size); \ + printf("\tseq_len=%ld\n", m_.seq_len); \ + printf("\tckpt_path=%s\n", m_.ckpt_path); #else #define WHERE_AM_I() @@ -102,109 +101,104 @@ static const char* DECODING_VERSION{"1"}; using namespace fastertransformer; namespace nvinfer1 { - // class T5EncoderPlugin --------------------------------------------------------------------------- class T5EncoderPlugin: public IPluginV2DynamicExt { private: - using IPluginV2Ext::configurePlugin; + using IPluginV2::enqueue; using IPluginV2::getOutputDimensions; using IPluginV2::getWorkspaceSize; - using IPluginV2::enqueue; + using IPluginV2Ext::configurePlugin; const std::string name_; - std::string namespace_; - cublasHandle_t cublasHandle_; - cublasLtHandle_t cublasltHandle_; + std::string namespace_; + bool is_own_weight = false; + cublasHandle_t cublasHandle_; + cublasLtHandle_t cublasltHandle_; #ifdef SPARSITY_ENABLED cusparseLtHandle_t cusparseltHandle_; #endif - cublasAlgoMap* pCublasAlgoMap_ = nullptr; - std::mutex* pCublasWrapperMutex_ = nullptr; - T5EncoderWeight* pT5EncoderWeightHalf_ = nullptr; - T5EncoderWeight* pT5EncoderWeightFloat_ = nullptr; - Allocator* pAllocator_ = nullptr; - cublasMMWrapper* pCublasWrapper_ = nullptr; - T5Encoder* pT5EncoderHalf_ = nullptr; - T5Encoder* pT5EncoderFloat_ = nullptr; + cublasAlgoMap* pCublasAlgoMap_ = nullptr; + std::mutex* pCublasWrapperMutex_ = nullptr; + T5EncoderWeight* pT5EncoderWeightHalf_ = nullptr; + T5EncoderWeight* pT5EncoderWeightFloat_ = nullptr; + Allocator* pAllocator_ = nullptr; + cublasMMWrapper* pCublasWrapper_ = nullptr; + T5Encoder* pT5EncoderHalf_ = nullptr; + T5Encoder* pT5EncoderFloat_ = nullptr; struct { // constructor parameter size_t max_batch_size = 128; - size_t max_seq_len = 384; - size_t beam_width = 1; - size_t head_num = 8; - size_t size_per_head = 512 / 8; - size_t d_model = head_num * size_per_head; - size_t inter_size = d_model * 4; - size_t num_layer = 6; - size_t num_bucket = 32; - size_t max_distance = 128; - int sm = -1; // assign 
later - float q_scaling = 1.0f / (1.0f * sqrt(size_per_head)); - bool useFP16 = false; + size_t max_seq_len = 384; + size_t beam_width = 1; + size_t head_num = 8; + size_t size_per_head = 512 / 8; + size_t d_model = head_num * size_per_head; + size_t inter_size = d_model * 4; + size_t num_layer = 6; + size_t num_bucket = 32; + size_t max_distance = 128; + int sm = -1; // assign later + float q_scaling = 1.0f / (1.0f * sqrt(size_per_head)); + bool useFP16 = false; // internal parameter - size_t vocab_size = 32128; - bool is_remove_padding = true; - bool is_free_buffer_after_forward = false; - bool is_sparse = false; - AttentionType attention_type = AttentionType::UNFUSED_MHA; - fastertransformer::ActivationType activation_type = fastertransformer::ActivationType::Relu; - LayerNormType layernorm_type = LayerNormType::pre_layernorm; + size_t vocab_size = 32128; + bool is_remove_padding = true; + bool is_free_buffer_after_forward = false; + bool is_sparse = false; + AttentionType attention_type = AttentionType::UNFUSED_MHA; + fastertransformer::ActivationType activation_type = fastertransformer::ActivationType::Relu; + LayerNormType layernorm_type = LayerNormType::pre_layernorm; // runtime parameter - size_t batch_size = 0; - size_t seq_len = 0; + size_t batch_size = 0; + size_t seq_len = 0; + char ckpt_path[256] = ""; } m_; public: T5EncoderPlugin() = delete; T5EncoderPlugin(const std::string& name, - size_t max_batch_size, - size_t max_seq_len, - size_t beam_width, - size_t head_num, - size_t size_per_head, - size_t inter_size, - size_t d_model, - size_t num_layer, - size_t num_bucket, - size_t max_distance, - int sm, - float q_scaling, - int useFP16); + size_t max_batch_size, + size_t max_seq_len, + size_t beam_width, + int sm, + int useFP16, + const std::string& ckpt_path, + bool own_weight); T5EncoderPlugin(const std::string& name, const void* buffer, size_t length); ~T5EncoderPlugin(); - virtual size_t getSerializationSize() const noexcept override; - virtual void serialize(void* buffer) const noexcept override; + virtual size_t getSerializationSize() const noexcept override; + virtual void serialize(void* buffer) const noexcept override; IPluginV2DynamicExt* clone() const noexcept override; - int getNbOutputs() const noexcept override; - DataType getOutputDataType(int index, const DataType* inputTypes, int nbInputs) const noexcept override; + int getNbOutputs() const noexcept override; + DataType getOutputDataType(int index, const DataType* inputTypes, int nbInputs) const noexcept override; bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept override; - DimsExprs getOutputDimensions(int index, - const DimsExprs* pInputDim, - int nInputDim, - IExprBuilder& exprBuilder) noexcept override; + DimsExprs getOutputDimensions(int index, + const DimsExprs* pInputDim, + int nInputDim, + IExprBuilder& exprBuilder) noexcept override; virtual void configurePlugin(const DynamicPluginTensorDesc* in, - int nbInput, + int nbInput, const DynamicPluginTensorDesc* out, - int nbOutput) noexcept override; - size_t getWorkspaceSize(const PluginTensorDesc* inputs, - int32_t nbInputs, - const PluginTensorDesc* outputs, - int32_t nbOutputs) const noexcept override; - int enqueue(const PluginTensorDesc* inputDesc, - const PluginTensorDesc* outputDesc, - const void* const* inputs, - void* const* outputs, - void* workspace, - cudaStream_t stream) noexcept override; - void setPluginNamespace(const char* szNamespace) noexcept override; - const char* 
getPluginNamespace() const noexcept override; - const char* getPluginType() const noexcept override; - const char* getPluginVersion() const noexcept override; - int initialize() noexcept override; - void terminate() noexcept override; - void destroy() noexcept override; + int nbOutput) noexcept override; + size_t getWorkspaceSize(const PluginTensorDesc* inputs, + int32_t nbInputs, + const PluginTensorDesc* outputs, + int32_t nbOutputs) const noexcept override; + int enqueue(const PluginTensorDesc* inputDesc, + const PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) noexcept override; + void setPluginNamespace(const char* szNamespace) noexcept override; + const char* getPluginNamespace() const noexcept override; + const char* getPluginType() const noexcept override; + const char* getPluginVersion() const noexcept override; + int initialize() noexcept override; + void terminate() noexcept override; + void destroy() noexcept override; void attachToContext(cudnnContext* /*cudnn*/, cublasContext* /*cublas*/, IGpuAllocator* /*allocator*/) noexcept; void detachFromContext() noexcept; }; @@ -212,16 +206,16 @@ class T5EncoderPlugin: public IPluginV2DynamicExt { // class T5EncoderPluginCreator -------------------------------------------------------------------- class T5EncoderPluginCreator: public IPluginCreator { private: - static PluginFieldCollection fc_; + static PluginFieldCollection fc_; static std::vector attr_; - std::string namespace_; + std::string namespace_; public: T5EncoderPluginCreator(); ~T5EncoderPluginCreator(); - IPluginV2* createPlugin(const char* name, const PluginFieldCollection* fc) noexcept override; - IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) noexcept override; - void setPluginNamespace(const char* szNamespace) noexcept override; + IPluginV2* createPlugin(const char* name, const PluginFieldCollection* fc) noexcept override; + IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) noexcept override; + void setPluginNamespace(const char* szNamespace) noexcept override; const char* getPluginNamespace() const noexcept override; const char* getPluginName() const noexcept override; const char* getPluginVersion() const noexcept override; @@ -230,117 +224,104 @@ class T5EncoderPluginCreator: public IPluginCreator { // class T5DecodingPlugin -------------------------------------------------------------------------- class T5DecodingPlugin: public IPluginV2DynamicExt { private: - using IPluginV2Ext::configurePlugin; + using IPluginV2::enqueue; using IPluginV2::getOutputDimensions; using IPluginV2::getWorkspaceSize; - using IPluginV2::enqueue; + using IPluginV2Ext::configurePlugin; - const std::string name_; - std::string namespace_; - cublasHandle_t cublasHandle_; - cublasLtHandle_t cublasltHandle_; - cudaDeviceProp cuda_device_prop_; - cublasAlgoMap* pCublasAlgoMap_ = nullptr; - Allocator* pAllocator_ = nullptr; - std::mutex* pCublasWrapperMutex_ = nullptr; - cublasMMWrapper* pCublasWrapper_ = nullptr; - T5DecodingWeight* pT5DecodingWeightHalf_ = nullptr; - T5DecodingWeight* pT5DecodingWeightFloat_ = nullptr; - T5Decoding* pT5DecodingHalf_ = nullptr; - T5Decoding* pT5DecodingFloat_ = nullptr; + const std::string name_; + std::string namespace_; + bool is_own_weight = false; + cublasHandle_t cublasHandle_; + cublasLtHandle_t cublasltHandle_; + cudaDeviceProp cuda_device_prop_; + cublasAlgoMap* pCublasAlgoMap_ = nullptr; + 
Allocator* pAllocator_ = nullptr; + std::mutex* pCublasWrapperMutex_ = nullptr; + cublasMMWrapper* pCublasWrapper_ = nullptr; + T5DecodingWeight* pT5DecodingWeightHalf_ = nullptr; + T5DecodingWeight* pT5DecodingWeightFloat_ = nullptr; + T5Decoding* pT5DecodingHalf_ = nullptr; + T5Decoding* pT5DecodingFloat_ = nullptr; struct { // constructor parameter - size_t max_batch_size = 128; - size_t max_seq_len = 384; - size_t mem_max_seq_len = max_seq_len; - size_t beam_width = 4; - size_t head_num = 8; - size_t size_per_head = 512 / 8; - size_t d_model = head_num * size_per_head; - size_t inter_size = d_model * 4; - size_t num_layer = 6; - size_t vocab_size = 32128; - size_t num_bucket = 32; - size_t max_distance = 128; - float q_scaling = 1.0f / (1.0f * sqrt(size_per_head)); - int start_id = 0; - int end_id = 1; - float beam_search_diversity_rate = 0.0f; - size_t top_k = beam_width; - float top_p = 0.0f; - float temperature = 1.0f; - float len_penalty = 2.0f; - float repetition_penalty = 1.0f; - bool useFP16 = false; + size_t max_batch_size = 128; + size_t max_seq_len = 384; + size_t mem_max_seq_len = max_seq_len; + size_t beam_width = 4; + size_t head_num = 8; + size_t size_per_head = 512 / 8; + size_t d_model = head_num * size_per_head; + size_t inter_size = d_model * 4; + size_t num_layer = 6; + size_t vocab_size = 32128; + size_t num_bucket = 32; + size_t max_distance = 128; + float q_scaling = 1.0f / (1.0f * sqrt(size_per_head)); + int start_id = 0; + int end_id = 1; + float beam_search_diversity_rate = 0.0f; + int top_k = 0; + float top_p = 0.0f; + float temperature = 1.0f; + float len_penalty = 2.0f; + float repetition_penalty = 1.0f; + bool useFP16 = false; // internal parameter - size_t mem_d_model = d_model; - size_t mem_hidden_units = d_model; - bool is_free_buffer_after_forward = false; + size_t mem_d_model = d_model; + size_t mem_hidden_units = d_model; + bool is_free_buffer_after_forward = false; // runtime parameter - size_t batch_size = 128; - size_t seq_len = 384; + size_t batch_size = 128; + size_t seq_len = 384; + char ckpt_path[256] = ""; } m_; public: T5DecodingPlugin() = delete; T5DecodingPlugin(const std::string& name, - size_t max_batch_size, - size_t max_seq_len, - size_t mem_max_seq_len, - size_t beam_width, - size_t head_num, - size_t size_per_head, - size_t inter_size, - size_t d_model, - size_t num_layer, - size_t vocab_size, - size_t num_bucket, - size_t max_distance, - float q_scaling, - int start_id, - int end_id, - float beam_search_diversity_rate, - size_t top_k, - float top_p, - float temperature, - float len_penalty, - float repetition_penalty, - int useFP16); + size_t max_batch_size, + size_t max_seq_len, + size_t mem_max_seq_len, + size_t beam_width, + int useFP16, + const std::string& ckpt_path, + bool own_weight); T5DecodingPlugin(const std::string& name, const void* buffer, size_t length); ~T5DecodingPlugin(); - virtual size_t getSerializationSize() const noexcept override; - virtual void serialize(void* buffer) const noexcept override; + virtual size_t getSerializationSize() const noexcept override; + virtual void serialize(void* buffer) const noexcept override; IPluginV2DynamicExt* clone() const noexcept override; - int getNbOutputs() const noexcept override; - DataType getOutputDataType(int index, const DataType* inputTypes, int nbInputs) const noexcept override; + int getNbOutputs() const noexcept override; + DataType getOutputDataType(int index, const DataType* inputTypes, int nbInputs) const noexcept override; bool supportsFormatCombination(int pos, 
const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept override; - DimsExprs getOutputDimensions(int index, - const DimsExprs* pInputDim, - int nInputDim, - IExprBuilder& exprBuilder) noexcept override; + DimsExprs getOutputDimensions(int index, + const DimsExprs* pInputDim, + int nInputDim, + IExprBuilder& exprBuilder) noexcept override; virtual void configurePlugin(const DynamicPluginTensorDesc* in, - int nbInput, + int nbInput, const DynamicPluginTensorDesc* out, - int nbOutput) noexcept override; - size_t getWorkspaceSize(const PluginTensorDesc* inputs, - int32_t nbInputs, - const PluginTensorDesc* outputs, - int32_t nbOutputs) const noexcept override; - int enqueue(const PluginTensorDesc* inputDesc, - const PluginTensorDesc* outputDesc, - const void* const* inputs, - void* const* outputs, - void* workspace, - cudaStream_t stream) noexcept override; - void setPluginNamespace(const char* szNamespace) noexcept override; - const char* getPluginNamespace() const noexcept override; - const char* getPluginType() const noexcept override; - const char* getPluginVersion() const noexcept override; - int initialize() noexcept override; - void terminate() noexcept override; - void destroy() noexcept override; + int nbOutput) noexcept override; + size_t getWorkspaceSize(const PluginTensorDesc* inputs, + int32_t nbInputs, + const PluginTensorDesc* outputs, + int32_t nbOutputs) const noexcept override; + int enqueue(const PluginTensorDesc* inputDesc, + const PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) noexcept override; + void setPluginNamespace(const char* szNamespace) noexcept override; + const char* getPluginNamespace() const noexcept override; + const char* getPluginType() const noexcept override; + const char* getPluginVersion() const noexcept override; + int initialize() noexcept override; + void terminate() noexcept override; + void destroy() noexcept override; void attachToContext(cudnnContext* /*cudnn*/, cublasContext* /*cublas*/, IGpuAllocator* /*allocator*/) noexcept; void detachFromContext() noexcept; }; @@ -348,16 +329,16 @@ class T5DecodingPlugin: public IPluginV2DynamicExt { // class T5DecodingPluginCreator ------------------------------------------------------------------- class T5DecodingPluginCreator: public IPluginCreator { private: - static PluginFieldCollection fc_; + static PluginFieldCollection fc_; static std::vector attr_; - std::string namespace_; + std::string namespace_; public: T5DecodingPluginCreator(); ~T5DecodingPluginCreator(); - IPluginV2* createPlugin(const char* name, const PluginFieldCollection* fc) noexcept override; - IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) noexcept override; - void setPluginNamespace(const char* szNamespace) noexcept override; + IPluginV2* createPlugin(const char* name, const PluginFieldCollection* fc) noexcept override; + IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) noexcept override; + void setPluginNamespace(const char* szNamespace) noexcept override; const char* getPluginNamespace() const noexcept override; const char* getPluginName() const noexcept override; const char* getPluginVersion() const noexcept override; diff --git a/src/fastertransformer/tensorrt_plugin/t5/T5PluginGemm.cc b/src/fastertransformer/tensorrt_plugin/t5/T5PluginGemm.cc index dc967b125..088df65a3 100644 --- a/src/fastertransformer/tensorrt_plugin/t5/T5PluginGemm.cc +++ 
b/src/fastertransformer/tensorrt_plugin/t5/T5PluginGemm.cc
@@ -44,24 +44,24 @@ int t5_gemm(int argv[16])
         return 0;
     }
-    const int batch_size = argv[1];
-    const int beam_width = argv[2];
+    const int batch_size = argv[1];
+    const int beam_width = argv[2];
     const int max_mem_seq_len = argv[3];
-    const int encoder_d_model = argv[4];
-    const int encoder_head_num = argv[5];
+    const int encoder_d_model = argv[4];
+    const int encoder_head_num = argv[5];
     const int encoder_size_per_head = argv[6];
-    const int encoder_inter_size = argv[7];
+    const int encoder_inter_size = argv[7];
-    const int decoder_d_model = argv[8];
-    const int decoder_head_num = argv[9];
+    const int decoder_d_model = argv[8];
+    const int decoder_head_num = argv[9];
     const int decoder_size_per_head = argv[10];
-    const int decoder_inter_size = argv[11];
-    const int decoder_vocab_size = argv[12];
+    const int decoder_inter_size = argv[11];
+    const int decoder_vocab_size = argv[12];
-    const ft::CublasDataType data_type = static_cast<ft::CublasDataType>(argv[13]); // 0 FP32, 1 FP16, 2 BF 16
-    const int tensor_para_size = argc <= 14 ? 1 : argv[14];
-    const int is_fp16_compute_type = argc <= 15 ? 1 : argv[15];
+    const ft::CublasDataType data_type = static_cast<ft::CublasDataType>(argv[13]); // 0 FP32, 1 FP16, 2 BF 16
+    const int tensor_para_size = argc <= 14 ? 1 : argv[14];
+    const int is_fp16_compute_type = argc <= 15 ? 1 : argv[15];
     std::cout << "[INFO] arguments: " << std::endl
               << " batch_size: " << batch_size << std::endl
@@ -80,7 +80,7 @@ int t5_gemm(int argv[16])
               << " tensor_para_size: " << tensor_para_size << std::endl
               << " is_fp16_compute_type: " << is_fp16_compute_type << std::endl;
-    void* gemm_test_buf;
+    void* gemm_test_buf;
     size_t buf_size_in_byte = ft::calT5GemmTestBufSizeInByte(batch_size,
                                                              beam_width,
                                                              max_mem_seq_len,
@@ -127,7 +127,7 @@ int t5_gemm(int argv[16])
                                          false,
                                          is_fp16_compute_type);
     }
-    if (data_type == ft::HALF_DATATYPE) {
+    else if (data_type == ft::HALF_DATATYPE) {
         ft::generate_t5_gemm_config<half>(batch_size,
                                           beam_width,
                                           max_mem_seq_len,
@@ -146,7 +146,7 @@ int t5_gemm(int argv[16])
                                           is_fp16_compute_type);
     }
 #ifdef ENABLE_BF16
-    if (data_type == ft::BFLOAT16_DATATYPE) {
+    else if (data_type == ft::BFLOAT16_DATATYPE) {
         ft::generate_t5_gemm_config<__nv_bfloat16>(batch_size,
                                                    beam_width,
                                                    max_mem_seq_len,
@@ -166,8 +166,8 @@ int t5_gemm(int argv[16])
     }
 #endif
     else {
-        printf("[ERROR] data type only supports fp32(0), fp16(1), bf16(2). 
\n"); - return -1; + FT_LOG_ERROR("data type %d is invalid, only supports fp32(0), fp16(1), bf16(2).", (int)(data_type)); + ft::FT_CHECK(false); } ft::check_cuda_error(cudaFree(gemm_test_buf)); diff --git a/src/fastertransformer/tensorrt_plugin/vit/CMakeLists.txt b/src/fastertransformer/tensorrt_plugin/vit/CMakeLists.txt index 8a2fe6d98..6f8a106a8 100644 --- a/src/fastertransformer/tensorrt_plugin/vit/CMakeLists.txt +++ b/src/fastertransformer/tensorrt_plugin/vit/CMakeLists.txt @@ -15,6 +15,7 @@ cmake_minimum_required(VERSION 3.8) set(vit_trt_files ViTPlugin.cpp + ViTINT8Plugin.cpp ) @@ -23,5 +24,5 @@ if(BUILD_TRT) add_library(${LIB_NAME} SHARED ${vit_trt_files}) set_target_properties(${LIB_NAME} PROPERTIES CUDA_RESOLVE_DEVICE_SYMBOLS ON) - target_link_libraries(${LIB_NAME} trt_fused_multi_head_attention ViT -lcudnn -lcublas -lcudart -lnvinfer) + target_link_libraries(${LIB_NAME} trt_fused_multi_head_attention ViT ViTINT8 -lcudnn -lcublas -lcudart -lnvinfer) endif() diff --git a/src/fastertransformer/tensorrt_plugin/vit/ViTINT8Plugin.cpp b/src/fastertransformer/tensorrt_plugin/vit/ViTINT8Plugin.cpp new file mode 100644 index 000000000..a5d0fae88 --- /dev/null +++ b/src/fastertransformer/tensorrt_plugin/vit/ViTINT8Plugin.cpp @@ -0,0 +1,592 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ViTINT8Plugin.h" +#include "NvInfer.h" +#include +#include + +#include +#include +#include +#include +#include + +using namespace nvinfer1; +using namespace std; + +namespace fastertransformer { + +// Static class fields initialization +PluginFieldCollection VisionTransformerINT8PluginCreator::mFC{}; +std::vector VisionTransformerINT8PluginCreator::mPluginAttributes; + +REGISTER_TENSORRT_PLUGIN(VisionTransformerINT8PluginCreator); + +template +VisionTransformerINT8Plugin::VisionTransformerINT8Plugin(const std::string& name, + const int max_batch, + const int img_size, + const int patch_size, + const int in_chans, + const int embed_dim, + const int num_heads, + const int inter_size, + const int layer_num, + const float q_scaling, + const bool with_cls_token, + const int int8_mode, + const std::vector& w): + layer_name_(name) +{ + + settings_.max_batch_size = max_batch; + settings_.img_size = img_size; + settings_.chn_num = in_chans; + settings_.patch_size = patch_size; + settings_.embed_dim = embed_dim; + settings_.head_num = num_heads; + settings_.inter_size = inter_size; + settings_.num_layer = layer_num; + settings_.with_cls_token = with_cls_token; + settings_.sm = getSMVersion(); + settings_.q_scaling = q_scaling; + settings_.int8_mode = int8_mode; + settings_.seq_len = (img_size / patch_size) * (img_size / patch_size) + (with_cls_token ? 
1 : 0); + settings_.attention_type = getAttentionType(embed_dim / num_heads, settings_.sm, true, settings_.seq_len); + + Init(w); +} + +template +VisionTransformerINT8Plugin::VisionTransformerINT8Plugin(const std::string& name, const void* data, size_t length): + layer_name_(name) +{ + ::memcpy(&settings_, data, sizeof(settings_)); + const char* w_buffer = static_cast(data) + sizeof(settings_); + + std::vector dummy; + Init(dummy); + + params_->deserialize(w_buffer); +} + +template +VisionTransformerINT8Plugin::VisionTransformerINT8Plugin(const VisionTransformerINT8Plugin& plugin): + layer_name_(plugin.layer_name_), settings_(plugin.settings_) +{ + std::vector dummy; + Init(dummy); + *params_ = *plugin.params_; +} + +template +void VisionTransformerINT8Plugin::Init(const std::vector& w) +{ + params_ = new ViTINT8Weight(settings_.embed_dim, + settings_.inter_size, + settings_.num_layer, + settings_.img_size, + settings_.patch_size, + settings_.chn_num, + settings_.with_cls_token); + + if (w.size() > 0) { + size_t weight_num = params_->GetWeightCount(); + + if (weight_num != w.size()) { + printf("[ERROR][VisionTransformerINT8Plugin] weights number %lu does not match expected number %lu!\n", + w.size(), + weight_num); + exit(-1); + } + const T* const* pp_buf = &w[0]; + params_->CopyWeightsFromHostBuffers(pp_buf); + } + + check_cuda_error(cublasCreate(&cublas_handle_)); + check_cuda_error(cublasLtCreate(&cublaslt_handle_)); + checkCUDNN(cudnnCreate(&cudnn_handle_)); + + bool _use_ORDER_COL32_2R_4R4 = false; +#if (CUDART_VERSION >= 11000) + if (settings_.sm >= 80) { + _use_ORDER_COL32_2R_4R4 = true; + } +#endif + + cublasAlgoMap_ = new cublasAlgoMap("igemm.config", ""); + cublasWrapperMutex_ = new std::mutex(); + allocator_ = new Allocator(getDevice()); + + cublas_wrapper_ = new cublasINT8MMWrapper( + cublas_handle_, cublaslt_handle_, nullptr, cublasAlgoMap_, cublasWrapperMutex_, _use_ORDER_COL32_2R_4R4); + if (std::is_same::value) { + cublas_wrapper_->setFP16GemmConfig(); + } + else if (std::is_same::value) { + cublas_wrapper_->setFP32GemmConfig(); + } + + vit_transformer_ = new ViTTransformerINT8(settings_.max_batch_size, + settings_.img_size, + settings_.chn_num, + settings_.patch_size, + settings_.embed_dim, + settings_.head_num, + settings_.inter_size, + settings_.num_layer, + settings_.with_cls_token, + settings_.sm, + settings_.q_scaling, + settings_.int8_mode, + 0, + cudnn_handle_, + cublas_wrapper_, + allocator_, + false, + settings_.attention_type); +} + +template +VisionTransformerINT8Plugin::~VisionTransformerINT8Plugin() +{ + check_cuda_error(cublasDestroy(cublas_handle_)); + check_cuda_error(cublasLtDestroy(cublaslt_handle_)); + checkCUDNN(cudnnDestroy(cudnn_handle_)); + delete vit_transformer_; + delete cublas_wrapper_; + delete allocator_; + delete cublasWrapperMutex_; + delete cublasAlgoMap_; + delete params_; +} + +// IPluginV2DynamicExt Methods +template +nvinfer1::IPluginV2DynamicExt* VisionTransformerINT8Plugin::clone() const noexcept +{ + + VisionTransformerINT8Plugin* ret = new VisionTransformerINT8Plugin(*this); + return ret; +} + +template +DimsExprs VisionTransformerINT8Plugin::getOutputDimensions(int outputIndex, + const DimsExprs* inputs, + int nbInputs, + IExprBuilder& exprBuilder) noexcept +{ + // Input is B*in_chans*H*W, output should be B*seq_len*embed_dim*1 + assert(outputIndex == 0); + DimsExprs output; + output.nbDims = 3; + output.d[0] = inputs[0].d[0]; + output.d[1] = exprBuilder.constant(settings_.seq_len); + output.d[2] = 
exprBuilder.constant(settings_.embed_dim); + return output; +} + +template +bool VisionTransformerINT8Plugin::supportsFormatCombination(int pos, + const PluginTensorDesc* inOut, + int nbInputs, + int nbOutputs) noexcept +{ + bool res = false; + assert(pos >= 0 && pos < 2); + assert(nbInputs == 1); + switch (pos) { + case 0: // input + case 1: // output + res = (inOut[pos].type + == (std::is_same::value ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT)) + && (inOut[pos].format == nvinfer1::TensorFormat::kLINEAR); + break; + default: + break; + } + + return res; +} + +template +void VisionTransformerINT8Plugin::configurePlugin(const DynamicPluginTensorDesc* in, + int nbInputs, + const DynamicPluginTensorDesc* out, + int nbOutputs) noexcept +{ + assert(nbInputs == 1); + assert(nbOutputs == 1); +} + +template +size_t VisionTransformerINT8Plugin::getWorkspaceSize(const PluginTensorDesc* inputs, + int nbInputs, + const PluginTensorDesc* outputs, + int nbOutputs) const noexcept +{ + return 0; +} + +// IPluginV2Ext Methods +template +nvinfer1::DataType VisionTransformerINT8Plugin::getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const noexcept +{ + assert(index == 0); + assert(inputTypes[0] == nvinfer1::DataType::kFLOAT || inputTypes[0] == nvinfer1::DataType::kHALF); + return inputTypes[0]; +} + +// IPluginV2 Methods +template +const char* VisionTransformerINT8Plugin::getPluginType() const noexcept +{ + return VIT_PLUGIN_NAME; +} + +template +const char* VisionTransformerINT8Plugin::getPluginVersion() const noexcept +{ + return VIT_PLUGIN_VERSION; +} + +template +int VisionTransformerINT8Plugin::getNbOutputs() const noexcept +{ + return 1; +} + +template +int VisionTransformerINT8Plugin::initialize() noexcept +{ + return 0; +} + +template +void VisionTransformerINT8Plugin::terminate() noexcept +{ +} + +template +size_t VisionTransformerINT8Plugin::getSerializationSize() const noexcept +{ + + size_t size = sizeof(int) + sizeof(settings_); + size += params_->GetSerializeSize(); + return size; +} + +template +void VisionTransformerINT8Plugin::serialize(void* buffer) const noexcept +{ + FT_LOG_INFO("start serialize vit..."); + + int type_id = 0; + if (std::is_same::value) { + type_id = 1; + } + ::memcpy(buffer, &type_id, sizeof(type_id)); + char* serial_buffer = (char*)buffer + sizeof(type_id); + ::memcpy(serial_buffer, &settings_, sizeof(settings_)); + serial_buffer += sizeof(settings_); + params_->serialize(serial_buffer); +} + +template +void VisionTransformerINT8Plugin::destroy() noexcept +{ + delete this; +} + +template +void VisionTransformerINT8Plugin::setPluginNamespace(const char* libNamespace) noexcept +{ + namespace_ = libNamespace; +} + +template +const char* VisionTransformerINT8Plugin::getPluginNamespace() const noexcept +{ + return namespace_.c_str(); +} + +template +int VisionTransformerINT8Plugin::enqueue(const PluginTensorDesc* inputDesc, + const PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) noexcept +{ + int batch_size = inputDesc->dims.d[0]; + assert(batch_size <= settings_.max_batch_size); + assert(settings_.chn_num == inputDesc->dims.d[1]); + assert(settings_.img_size == inputDesc->dims.d[2]); + assert(settings_.img_size == inputDesc->dims.d[3]); + + std::vector input_tensors = std::vector{Tensor{ + MEMORY_GPU, + getTensorType(), + std::vector{ + (size_t)batch_size, (size_t)settings_.chn_num, (size_t)settings_.img_size, (size_t)settings_.img_size}, + (const 
T*)(inputs[0])}};
+
+    std::vector<Tensor> output_tensors = std::vector<Tensor>{
+        Tensor{MEMORY_GPU,
+               getTensorType<T>(),
+               std::vector<size_t>{(size_t)batch_size, (size_t)settings_.seq_len, (size_t)settings_.embed_dim},
+               (T*)(outputs[0])}};
+
+    vit_transformer_->forward(&output_tensors, &input_tensors, params_);
+    return 0;
+}
+
+VisionTransformerINT8PluginCreator::VisionTransformerINT8PluginCreator()
+{
+    mFC.nbFields = mPluginAttributes.size();
+    mFC.fields = mPluginAttributes.data();
+}
+
+const char* VisionTransformerINT8PluginCreator::getPluginName() const noexcept
+{
+    return VIT_PLUGIN_NAME;
+}
+
+const char* VisionTransformerINT8PluginCreator::getPluginVersion() const noexcept
+{
+    return VIT_PLUGIN_VERSION;
+}
+
+const PluginFieldCollection* VisionTransformerINT8PluginCreator::getFieldNames() noexcept
+{
+    return &mFC;
+}
+
+// Creator
+#define L_ROOT "transformer.encoder.layer.%d"
+#define ATT_Q "attn.query"
+#define ATT_K "attn.key"
+#define ATT_V "attn.value"
+#define ATT_OUT "attn.out"
+#define ATT_NORM "attention_norm"
+#define FFN_NORM "ffn_norm"
+#define FFN_IN "ffn.fc1"
+#define FFN_OUT "ffn.fc2"
+
+const std::vector<const char*> layer_weight_names = {L_ROOT "." ATT_NORM ".weight",
+                                                     L_ROOT "." ATT_NORM ".bias",
+                                                     L_ROOT "." ATT_Q ".weight",
+                                                     L_ROOT "." ATT_Q ".bias",
+                                                     L_ROOT "." ATT_K ".weight",
+                                                     L_ROOT "." ATT_K ".bias",
+                                                     L_ROOT "." ATT_V ".weight",
+                                                     L_ROOT "." ATT_V ".bias",
+                                                     L_ROOT "." ATT_OUT ".weight",
+                                                     L_ROOT "." ATT_OUT ".bias",
+                                                     L_ROOT "." FFN_NORM ".weight",
+                                                     L_ROOT "." FFN_NORM ".bias",
+                                                     L_ROOT "." FFN_IN ".weight",
+                                                     L_ROOT "." FFN_IN ".bias",
+                                                     L_ROOT "." FFN_OUT ".weight",
+                                                     L_ROOT "." FFN_OUT ".bias",
+                                                     L_ROOT ".amaxList",
+                                                     L_ROOT ".h_amaxList"};
+
+const std::vector<const char*> pre_layer_weight_names = {"transformer.embeddings.patch_embeddings.weight",
+                                                         "transformer.embeddings.patch_embeddings.bias",
+                                                         "transformer.embeddings.cls_token",
+                                                         "transformer.embeddings.position_embeddings"};
+const std::vector<const char*> post_layer_weight_names = {"transformer.encoder.encoder_norm.weight",
+                                                          "transformer.encoder.encoder_norm.bias"};
+
+nvinfer1::PluginFieldType getFieldCollectionTypeINT8(std::string name, const nvinfer1::PluginFieldCollection* fc)
+{
+    for (int i = 0; i < fc->nbFields; i++) {
+        std::string field_name(fc->fields[i].name);
+        if (field_name.compare(name) == 0) {
+            return fc->fields[i].type;
+        }
+    }
+    return nvinfer1::PluginFieldType::kUNKNOWN;
+}
+
+template<typename T>
+void loadWeightsPtrINT8(std::vector<const T*>& w,
+                        const nvinfer1::PluginFieldCollection* fc,
+                        int layer_num,
+                        bool with_cls_token = true)
+{
+    int idx = 0;
+    for (auto& name : pre_layer_weight_names) {
+        if (!with_cls_token && name == "transformer.embeddings.cls_token") {
+            continue;
+        }
+
+        for (int i = 0; i < fc->nbFields; i++) {
+            std::string field_name(fc->fields[i].name);
+            if (field_name.compare(name) == 0) {
+                w[idx++] = (const T*)fc->fields[i].data;
+            }
+        }
+    }
+
+    for (int i = 0; i < layer_num; i++) {
+        for (auto& name : layer_weight_names) {
+            char str_buf[1024];
+            sprintf(str_buf, name, i);
+
+            for (int j = 0; j < fc->nbFields; j++) {
+                std::string field_name(fc->fields[j].name);
+                if (field_name.compare(str_buf) == 0) {
+                    w[idx++] = (const T*)fc->fields[j].data;
+                }
+            }
+        }
+    }
+
+    for (auto& name : post_layer_weight_names) {
+        for (int i = 0; i < fc->nbFields; i++) {
+            std::string field_name(fc->fields[i].name);
+            if (field_name.compare(name) == 0) {
+                w[idx++] = (const T*)fc->fields[i].data;
+            }
+        }
+    }
+
+    FT_CHECK(idx == w.size());
+}
+
+IPluginV2* VisionTransformerINT8PluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) noexcept
+{
+    int max_batch;
+    int img_size;
+    int patch_size;
+    int in_chans;
+    int embed_dim;
+    int num_heads;
+    int inter_size;
+    int layer_num;
+    int int8_mode;
+    int with_cls_token = true;
+
+    std::map<std::string, int*> name2pint = {{"max_batch", &max_batch},
+                                             {"img_size", &img_size},
+                                             {"patch_size", &patch_size},
+                                             {"in_chans", &in_chans},
+                                             {"embed_dim", &embed_dim},
+                                             {"num_heads", &num_heads},
+                                             {"inter_size", &inter_size},
+                                             {"layer_num", &layer_num},
+                                             {"int8_mode", &int8_mode},
+                                             {"with_cls_token", &with_cls_token}};
+
+    for (int i = 0; i < fc->nbFields; i++) {
+        auto iter = name2pint.find(fc->fields[i].name);
+        if (iter != name2pint.end()) {
+            *(iter->second) = *((int*)fc->fields[i].data);
+            printf("name=%s, value=%d\n", iter->first.c_str(), *((int*)fc->fields[i].data));
+            continue;
+        }
+    }
+
+    size_t weights_num =
+        pre_layer_weight_names.size() + post_layer_weight_names.size() + layer_num * layer_weight_names.size();
+
+    auto weights_type = getFieldCollectionTypeINT8(pre_layer_weight_names[0], fc);
+
+    std::vector<const half*> w_fp16;
+    std::vector<const float*> w_fp32;
+    IPluginV2* p;
+    switch (weights_type) {
+        case nvinfer1::PluginFieldType::kFLOAT16:
+            w_fp16.resize(weights_num);
+            loadWeightsPtrINT8(w_fp16, fc, layer_num);
+            p = new VisionTransformerINT8Plugin<half>(name,
+                                                      max_batch,
+                                                      img_size,
+                                                      patch_size,
+                                                      in_chans,
+                                                      embed_dim,
+                                                      num_heads,
+                                                      inter_size,
+                                                      layer_num,
+                                                      1.0,
+                                                      with_cls_token,
+                                                      int8_mode,
+                                                      w_fp16);
+
+            break;
+        case nvinfer1::PluginFieldType::kFLOAT32:
+            w_fp32.resize(weights_num);
+            loadWeightsPtrINT8(w_fp32, fc, layer_num);
+            p = new VisionTransformerINT8Plugin<float>(name,
+                                                       max_batch,
+                                                       img_size,
+                                                       patch_size,
+                                                       in_chans,
+                                                       embed_dim,
+                                                       num_heads,
+                                                       inter_size,
+                                                       layer_num,
+                                                       1.0,
+                                                       with_cls_token,
+                                                       int8_mode,
+                                                       w_fp32);
+            break;
+        default:
+            FT_CHECK_WITH_INFO(false, "[VisionTransformerINT8PluginCreator::createPlugin] unsupported datatype.");
+    }
+
+    return p;
+}
+
+IPluginV2* VisionTransformerINT8PluginCreator::deserializePlugin(const char* name,
+                                                                 const void* serialData,
+                                                                 size_t serialLength) noexcept
+{
+    int type_id;
+    ::memcpy(&type_id, serialData, sizeof(int));
+    char* modelData = (char*)serialData + sizeof(int);
+
+    // This object will be deleted when the network is destroyed, which will
+    // call VisionTransformerINT8Plugin::destroy()
+    if (type_id == 0)
+        return new VisionTransformerINT8Plugin<float>(name, modelData, serialLength);
+    else if (type_id == 1)
+        return new VisionTransformerINT8Plugin<half>(name, modelData, serialLength);
+    else {
+        FT_LOG_ERROR("[VisionTransformerINT8PluginCreator::deserializePlugin] unsupported data type %d\n", type_id);
+        FT_CHECK(false);
+    }
+}
+
+void VisionTransformerINT8PluginCreator::setPluginNamespace(const char* libNamespace) noexcept
+{
+    namespace_ = libNamespace;
+}
+
+const char* VisionTransformerINT8PluginCreator::getPluginNamespace() const noexcept
+{
+    return namespace_.c_str();
+}
+
+template class VisionTransformerINT8Plugin<half>;
+template class VisionTransformerINT8Plugin<float>;
+
+} // namespace fastertransformer
diff --git a/src/fastertransformer/tensorrt_plugin/vit/ViTINT8Plugin.h b/src/fastertransformer/tensorrt_plugin/vit/ViTINT8Plugin.h
new file mode 100644
index 000000000..60d822dc5
--- /dev/null
+++ b/src/fastertransformer/tensorrt_plugin/vit/ViTINT8Plugin.h
@@ -0,0 +1,169 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#ifndef SWIN_TRANSFORMER_INT8_PLUGIN_H +#define SWIN_TRANSFORMER_INT8_PLUGIN_H + +#include "NvInfer.h" +#include "NvInferPlugin.h" +#include "NvInferRuntime.h" +#include "cublas_v2.h" +#include +#include + +#include "src/fastertransformer/models/vit_int8/ViTINT8.h" +#include "src/fastertransformer/utils/allocator.h" + +namespace fastertransformer { + +namespace { +static const char* VIT_PLUGIN_VERSION{"1"}; +static const char* VIT_PLUGIN_NAME{"CustomVisionTransformerINT8Plugin"}; +} // namespace + +struct ViTINT8Settings { + size_t max_batch_size = 32; + size_t img_size = 224; + size_t chn_num = 3; + size_t patch_size = 16; + size_t embed_dim = 768; + size_t head_num = 12; + size_t inter_size = embed_dim * 4; + size_t num_layer = 12; + bool with_cls_token = true; + bool is_fp16 = false; + int sm = -1; + float q_scaling = 1.0f; + AttentionType attention_type = AttentionType::UNFUSED_MHA; + // runtime param + size_t seq_len = 0; + int int8_mode = 2; +}; + +template +class VisionTransformerINT8Plugin: public nvinfer1::IPluginV2DynamicExt { +private: + const std::string layer_name_; + std::string namespace_; + + cublasHandle_t cublas_handle_ = nullptr; + cublasLtHandle_t cublaslt_handle_ = nullptr; + cudnnHandle_t cudnn_handle_ = nullptr; + ViTINT8Weight* params_ = nullptr; + ViTTransformerINT8* vit_transformer_ = nullptr; + std::mutex* cublasWrapperMutex_ = nullptr; + cublasAlgoMap* cublasAlgoMap_ = nullptr; + fastertransformer::Allocator* allocator_ = nullptr; + cublasINT8MMWrapper* cublas_wrapper_ = nullptr; + + ViTINT8Settings settings_; + +public: + int sm_; + VisionTransformerINT8Plugin(const std::string& name, + const int max_batch, + const int img_size, + const int patch_size, + const int in_chans, + const int embed_dim, + const int num_heads, + const int inter_size, + const int layer_num, + const float q_scaling, + const bool with_cls_token, + const int int8_mode, + const std::vector& w); + + VisionTransformerINT8Plugin(const std::string& name, const void* data, size_t length); + VisionTransformerINT8Plugin(const VisionTransformerINT8Plugin& plugin); + VisionTransformerINT8Plugin() = delete; + + ~VisionTransformerINT8Plugin(); + + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) noexcept override; + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* inOut, + int nbInputs, + int nbOutputs) noexcept override; + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nbOutputs) noexcept override; + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const noexcept override; + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + 
cudaStream_t stream) noexcept override; + + // IPluginV2Ext Methods + nvinfer1::DataType + getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept override; + + // IPluginV2 Methods + const char* getPluginType() const noexcept override; + const char* getPluginVersion() const noexcept override; + int getNbOutputs() const noexcept override; + int initialize() noexcept override; + void terminate() noexcept override; + size_t getSerializationSize() const noexcept override; + void serialize(void* buffer) const noexcept override; + void destroy() noexcept override; + void setPluginNamespace(const char* pluginNamespace) noexcept override; + const char* getPluginNamespace() const noexcept override; + +private: + void Init(const std::vector& w); +}; + +class VisionTransformerINT8PluginCreator: public nvinfer1::IPluginCreator { +public: + VisionTransformerINT8PluginCreator(); + + const char* getPluginName() const noexcept override; + + const char* getPluginVersion() const noexcept override; + + const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override; + + nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) noexcept override; + + nvinfer1::IPluginV2* + deserializePlugin(const char* name, const void* serialData, size_t serialLength) noexcept override; + + void setPluginNamespace(const char* pluginNamespace) noexcept override; + + const char* getPluginNamespace() const noexcept override; + +private: + static nvinfer1::PluginFieldCollection mFC; + static std::vector mPluginAttributes; + std::string namespace_; +}; + +} // namespace fastertransformer +#endif // SWIN_TRANSFORMER_INT8_PLUGIN_H diff --git a/src/fastertransformer/tensorrt_plugin/vit/ViTPlugin.cpp b/src/fastertransformer/tensorrt_plugin/vit/ViTPlugin.cpp index 710b7c150..f42437731 100644 --- a/src/fastertransformer/tensorrt_plugin/vit/ViTPlugin.cpp +++ b/src/fastertransformer/tensorrt_plugin/vit/ViTPlugin.cpp @@ -31,39 +31,39 @@ using namespace std; namespace fastertransformer { // Static class fields initialization -PluginFieldCollection VisionTransformerPluginCreator::mFC{}; +PluginFieldCollection VisionTransformerPluginCreator::mFC{}; std::vector VisionTransformerPluginCreator::mPluginAttributes; REGISTER_TENSORRT_PLUGIN(VisionTransformerPluginCreator); template -VisionTransformerPlugin::VisionTransformerPlugin(const std::string& name, - const int max_batch, - const int img_size, - const int patch_size, - const int in_chans, - const int embed_dim, - const int num_heads, - const int inter_size, - const int layer_num, - const float q_scaling, - const bool with_cls_token, +VisionTransformerPlugin::VisionTransformerPlugin(const std::string& name, + const int max_batch, + const int img_size, + const int patch_size, + const int in_chans, + const int embed_dim, + const int num_heads, + const int inter_size, + const int layer_num, + const float q_scaling, + const bool with_cls_token, const std::vector& w): layer_name_(name) { settings_.max_batch_size = max_batch; - settings_.img_size = img_size; - settings_.chn_num = in_chans; - settings_.patch_size = patch_size; - settings_.embed_dim = embed_dim; - settings_.head_num = num_heads; - settings_.inter_size = inter_size; - settings_.num_layer = layer_num; + settings_.img_size = img_size; + settings_.chn_num = in_chans; + settings_.patch_size = patch_size; + settings_.embed_dim = embed_dim; + settings_.head_num = num_heads; + settings_.inter_size = inter_size; + settings_.num_layer = layer_num; 
settings_.with_cls_token = with_cls_token; - settings_.sm = getSMVersion(); - settings_.q_scaling = q_scaling; - settings_.seq_len = (img_size / patch_size) * (img_size / patch_size) + (with_cls_token ? 1 : 0); + settings_.sm = getSMVersion(); + settings_.q_scaling = q_scaling; + settings_.seq_len = (img_size / patch_size) * (img_size / patch_size) + (with_cls_token ? 1 : 0); settings_.attention_type = getAttentionType(embed_dim / num_heads, settings_.sm, true, settings_.seq_len); Init(w); @@ -119,9 +119,9 @@ void VisionTransformerPlugin::Init(const std::vector& w) check_cuda_error(cublasLtCreate(&cublaslt_handle_)); checkCUDNN(cudnnCreate(&cudnn_handle_)); - cublasAlgoMap_ = new cublasAlgoMap("igemm.config", ""); + cublasAlgoMap_ = new cublasAlgoMap("igemm.config", ""); cublasWrapperMutex_ = new std::mutex(); - allocator_ = new Allocator(getDevice()); + allocator_ = new Allocator(getDevice()); cublas_wrapper_ = new cublasMMWrapper(cublas_handle_, cublaslt_handle_, nullptr, cublasAlgoMap_, cublasWrapperMutex_, allocator_); @@ -174,26 +174,26 @@ nvinfer1::IPluginV2DynamicExt* VisionTransformerPlugin::clone() const noexcep } template -DimsExprs VisionTransformerPlugin::getOutputDimensions(int outputIndex, +DimsExprs VisionTransformerPlugin::getOutputDimensions(int outputIndex, const DimsExprs* inputs, - int nbInputs, - IExprBuilder& exprBuilder) noexcept + int nbInputs, + IExprBuilder& exprBuilder) noexcept { // Input is B*in_chans*H*W, output should be B*seq_len*embed_dim*1 assert(outputIndex == 0); DimsExprs output; output.nbDims = 3; - output.d[0] = inputs[0].d[0]; - output.d[1] = exprBuilder.constant(settings_.seq_len); - output.d[2] = exprBuilder.constant(settings_.embed_dim); + output.d[0] = inputs[0].d[0]; + output.d[1] = exprBuilder.constant(settings_.seq_len); + output.d[2] = exprBuilder.constant(settings_.embed_dim); return output; } template -bool VisionTransformerPlugin::supportsFormatCombination(int pos, +bool VisionTransformerPlugin::supportsFormatCombination(int pos, const PluginTensorDesc* inOut, - int nbInputs, - int nbOutputs) noexcept + int nbInputs, + int nbOutputs) noexcept { bool res = false; assert(pos >= 0 && pos < 2); @@ -214,9 +214,9 @@ bool VisionTransformerPlugin::supportsFormatCombination(int pos, template void VisionTransformerPlugin::configurePlugin(const DynamicPluginTensorDesc* in, - int nbInputs, + int nbInputs, const DynamicPluginTensorDesc* out, - int nbOutputs) noexcept + int nbOutputs) noexcept { assert(nbInputs == 1); assert(nbOutputs == 1); @@ -224,18 +224,18 @@ void VisionTransformerPlugin::configurePlugin(const DynamicPluginTensorDesc* template size_t VisionTransformerPlugin::getWorkspaceSize(const PluginTensorDesc* inputs, - int nbInputs, + int nbInputs, const PluginTensorDesc* outputs, - int nbOutputs) const noexcept + int nbOutputs) const noexcept { return 0; } // IPluginV2Ext Methods template -nvinfer1::DataType VisionTransformerPlugin::getOutputDataType(int index, +nvinfer1::DataType VisionTransformerPlugin::getOutputDataType(int index, const nvinfer1::DataType* inputTypes, - int nbInputs) const noexcept + int nbInputs) const noexcept { assert(index == 0); assert(inputTypes[0] == nvinfer1::DataType::kFLOAT || inputTypes[0] == nvinfer1::DataType::kHALF); @@ -318,10 +318,10 @@ const char* VisionTransformerPlugin::getPluginNamespace() const noexcept template int VisionTransformerPlugin::enqueue(const PluginTensorDesc* inputDesc, const PluginTensorDesc* outputDesc, - const void* const* inputs, - void* const* outputs, - void* workspace, - 
cudaStream_t stream) noexcept + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) noexcept { int batch_size = inputDesc->dims.d[0]; assert(batch_size <= settings_.max_batch_size); @@ -329,7 +329,7 @@ int VisionTransformerPlugin::enqueue(const PluginTensorDesc* inputDesc, assert(settings_.img_size == inputDesc->dims.d[2]); assert(settings_.img_size == inputDesc->dims.d[3]); - int sm_ptr[1] = {sm_}; + int sm_ptr[1] = {sm_}; std::vector input_tensors = std::vector{Tensor{ MEMORY_GPU, getTensorType(), @@ -350,7 +350,7 @@ int VisionTransformerPlugin::enqueue(const PluginTensorDesc* inputDesc, VisionTransformerPluginCreator::VisionTransformerPluginCreator() { mFC.nbFields = mPluginAttributes.size(); - mFC.fields = mPluginAttributes.data(); + mFC.fields = mPluginAttributes.data(); } const char* VisionTransformerPluginCreator::getPluginName() const noexcept @@ -413,10 +413,10 @@ nvinfer1::PluginFieldType getFieldCollectionType(std::string name, const nvinfer } template -void loadWeightsPtr(std::vector& w, +void loadWeightsPtr(std::vector& w, const nvinfer1::PluginFieldCollection* fc, - int layer_num, - bool with_cls_token = true) + int layer_num, + bool with_cls_token = true) { int idx = 0; for (auto& name : pre_layer_weight_names) { @@ -431,7 +431,7 @@ void loadWeightsPtr(std::vector& w, } } - for (int i = 0; i < layer_num; i++){ + for (int i = 0; i < layer_num; i++) { for (auto& name : layer_weight_names) { char str_buf[1024]; sprintf(str_buf, name, i); @@ -493,9 +493,9 @@ IPluginV2* VisionTransformerPluginCreator::createPlugin(const char* name, const auto weights_type = getFieldCollectionType(pre_layer_weight_names[0], fc); - std::vector w_fp16; + std::vector w_fp16; std::vector w_fp32; - IPluginV2* p; + IPluginV2* p; switch (weights_type) { case nvinfer1::PluginFieldType::kFLOAT16: w_fp16.resize(weights_num); @@ -531,7 +531,7 @@ IPluginV2* VisionTransformerPluginCreator::createPlugin(const char* name, const w_fp32); break; default: - printf("[ERROR][VisionTransformerPluginCreator::createPlugin] unsupport datatype.\n"); + printf("[ERROR][VisionTransformerPluginCreator::createPlugin] unsupported datatype.\n"); exit(-1); } @@ -540,7 +540,7 @@ IPluginV2* VisionTransformerPluginCreator::createPlugin(const char* name, const IPluginV2* VisionTransformerPluginCreator::deserializePlugin(const char* name, const void* serialData, - size_t serialLength) noexcept + size_t serialLength) noexcept { int type_id; ::memcpy(&type_id, serialData, sizeof(int)); @@ -553,7 +553,7 @@ IPluginV2* VisionTransformerPluginCreator::deserializePlugin(const char* name, else if (type_id == 1) return new VisionTransformerPlugin(name, modelData, serialLength); else { - printf("[ERROR][VisionTransformerPluginCreator::deserializePlugin] unsupport data type %d\n", type_id); + printf("[ERROR][VisionTransformerPluginCreator::deserializePlugin] unsupported data type %d\n", type_id); exit(-1); } } diff --git a/src/fastertransformer/tensorrt_plugin/vit/ViTPlugin.h b/src/fastertransformer/tensorrt_plugin/vit/ViTPlugin.h index d90fdf187..5ae70ccc8 100644 --- a/src/fastertransformer/tensorrt_plugin/vit/ViTPlugin.h +++ b/src/fastertransformer/tensorrt_plugin/vit/ViTPlugin.h @@ -37,18 +37,18 @@ static const char* VIT_PLUGIN_NAME{"CustomVisionTransformerPlugin"}; } // namespace struct ViTSettings { - size_t max_batch_size = 32; - size_t img_size = 224; - size_t chn_num = 3; - size_t patch_size = 16; - size_t embed_dim = 768; - size_t head_num = 12; - size_t inter_size = embed_dim * 4; - size_t 
num_layer = 12; - bool with_cls_token = true; - bool is_fp16 = false; - int sm = -1; - float q_scaling = 1.0f; + size_t max_batch_size = 32; + size_t img_size = 224; + size_t chn_num = 3; + size_t patch_size = 16; + size_t embed_dim = 768; + size_t head_num = 12; + size_t inter_size = embed_dim * 4; + size_t num_layer = 12; + bool with_cls_token = true; + bool is_fp16 = false; + int sm = -1; + float q_scaling = 1.0f; AttentionType attention_type = AttentionType::UNFUSED_MHA; // runtime param size_t seq_len = 0; @@ -58,33 +58,33 @@ template class VisionTransformerPlugin: public nvinfer1::IPluginV2DynamicExt { private: const std::string layer_name_; - std::string namespace_; - - cublasHandle_t cublas_handle_ = nullptr; - cublasLtHandle_t cublaslt_handle_ = nullptr; - cudnnHandle_t cudnn_handle_ = nullptr; - ViTWeight* params_ = nullptr; - ViTTransformer* vit_transformer_ = nullptr; - std::mutex* cublasWrapperMutex_ = nullptr; - cublasAlgoMap* cublasAlgoMap_ = nullptr; - fastertransformer::Allocator* allocator_ = nullptr; - cublasMMWrapper* cublas_wrapper_ = nullptr; + std::string namespace_; + + cublasHandle_t cublas_handle_ = nullptr; + cublasLtHandle_t cublaslt_handle_ = nullptr; + cudnnHandle_t cudnn_handle_ = nullptr; + ViTWeight* params_ = nullptr; + ViTTransformer* vit_transformer_ = nullptr; + std::mutex* cublasWrapperMutex_ = nullptr; + cublasAlgoMap* cublasAlgoMap_ = nullptr; + fastertransformer::Allocator* allocator_ = nullptr; + cublasMMWrapper* cublas_wrapper_ = nullptr; ViTSettings settings_; public: int sm_; - VisionTransformerPlugin(const std::string& name, - const int max_batch, - const int img_size, - const int patch_size, - const int in_chans, - const int embed_dim, - const int num_heads, - const int inter_size, - const int layer_num, - const float q_scaling, - const bool with_cls_token, + VisionTransformerPlugin(const std::string& name, + const int max_batch, + const int img_size, + const int patch_size, + const int in_chans, + const int embed_dim, + const int num_heads, + const int inter_size, + const int layer_num, + const float q_scaling, + const bool with_cls_token, const std::vector& w); VisionTransformerPlugin(const std::string& name, const void* data, size_t length); @@ -95,28 +95,28 @@ class VisionTransformerPlugin: public nvinfer1::IPluginV2DynamicExt { // IPluginV2DynamicExt Methods nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, - const nvinfer1::DimsExprs* inputs, - int nbInputs, - nvinfer1::IExprBuilder& exprBuilder) noexcept override; - bool supportsFormatCombination(int pos, - const nvinfer1::PluginTensorDesc* inOut, - int nbInputs, - int nbOutputs) noexcept override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, - int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, - int nbOutputs) noexcept override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, - int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, - int nbOutputs) const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, - const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, - void* const* outputs, - void* workspace, - cudaStream_t stream) noexcept override; + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) noexcept override; + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* inOut, + int nbInputs, + 
int nbOutputs) noexcept override; + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nbOutputs) noexcept override; + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const noexcept override; + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) noexcept override; // IPluginV2Ext Methods nvinfer1::DataType @@ -125,13 +125,13 @@ class VisionTransformerPlugin: public nvinfer1::IPluginV2DynamicExt { // IPluginV2 Methods const char* getPluginType() const noexcept override; const char* getPluginVersion() const noexcept override; - int getNbOutputs() const noexcept override; - int initialize() noexcept override; - void terminate() noexcept override; - size_t getSerializationSize() const noexcept override; - void serialize(void* buffer) const noexcept override; - void destroy() noexcept override; - void setPluginNamespace(const char* pluginNamespace) noexcept override; + int getNbOutputs() const noexcept override; + int initialize() noexcept override; + void terminate() noexcept override; + size_t getSerializationSize() const noexcept override; + void serialize(void* buffer) const noexcept override; + void destroy() noexcept override; + void setPluginNamespace(const char* pluginNamespace) noexcept override; const char* getPluginNamespace() const noexcept override; private: @@ -158,9 +158,9 @@ class VisionTransformerPluginCreator: public nvinfer1::IPluginCreator { const char* getPluginNamespace() const noexcept override; private: - static nvinfer1::PluginFieldCollection mFC; + static nvinfer1::PluginFieldCollection mFC; static std::vector mPluginAttributes; - std::string namespace_; + std::string namespace_; }; } // namespace fastertransformer diff --git a/src/fastertransformer/tf_op/BaseOp.h b/src/fastertransformer/tf_op/BaseOp.h index dfa42e412..b0ea6cdcc 100644 --- a/src/fastertransformer/tf_op/BaseOp.h +++ b/src/fastertransformer/tf_op/BaseOp.h @@ -37,8 +37,8 @@ using CPUDevice = Eigen::ThreadPoolDevice; using GPUDevice = Eigen::GpuDevice; -namespace tf = tensorflow; -namespace ft = fastertransformer; +namespace tf = tensorflow; +namespace ft = fastertransformer; template class BaseOp: public tf::OpKernel { @@ -102,7 +102,7 @@ class BaseOp: public tf::OpKernel { ft::MEMORY_GPU, ft::getTensorType(), convert_shape(tensor), (half*)(tensor.flat().data())}; } #ifdef ENABLE_BF16 - if (std::is_same::value == true) { + else if (std::is_same::value == true) { return ft::Tensor{ft::MEMORY_GPU, ft::getTensorType<__nv_bfloat16>(), convert_shape(tensor), @@ -140,9 +140,9 @@ class BaseOp: public tf::OpKernel { } private: - cublasHandle_t cublas_handle_; + cublasHandle_t cublas_handle_; cublasLtHandle_t cublaslt_handle_; - std::mutex* cublas_wrapper_mutex_; + std::mutex* cublas_wrapper_mutex_; }; #endif diff --git a/src/fastertransformer/tf_op/bert/BertINT8Op.cc b/src/fastertransformer/tf_op/bert/BertINT8Op.cc index e022fc0a0..f12d3391e 100644 --- a/src/fastertransformer/tf_op/bert/BertINT8Op.cc +++ b/src/fastertransformer/tf_op/bert/BertINT8Op.cc @@ -91,9 +91,9 @@ class BertINT8Op: public BaseOp { OP_REQUIRES_OK(context, context->GetAttr("int8_mode", &int8_mode_)); OP_REQUIRES_OK(context, context->GetAttr("remove_padding", &remove_padding_)); OP_REQUIRES_OK(context, 
context->GetAttr("q_scaling", &q_scaling_)); - sm_ = ft::getSMVersion(); - cublas_algo_map_ = new ft::cublasAlgoMap("igemm_config.in"); - set_weight_ = false; + sm_ = ft::getSMVersion(); + cublas_algo_map_ = new ft::cublasAlgoMap("igemm_config.in"); + set_weight_ = false; use_ORDER_COL32_2R_4R4_ = false; #if (CUDART_VERSION >= 11000) if (sm_ >= 80) { @@ -112,7 +112,7 @@ class BertINT8Op: public BaseOp { if (set_weight_ && h_scale_list_) { free(h_scale_list_); h_scale_list_ = nullptr; - set_weight_ = false; + set_weight_ = false; } } @@ -122,15 +122,15 @@ class BertINT8Op: public BaseOp { context->num_inputs() == (num_layer_ * 17) + 3, tf::errors::InvalidArgument("[ERROR] More or Less input arguments")); - const size_t batch_size_ = (size_t)context->input(0).dim_size(0); + const size_t batch_size_ = (size_t)context->input(0).dim_size(0); const size_t from_seq_len_ = (size_t)context->input(0).dim_size(1); OP_REQUIRES(context, batch_size_ == (size_t)context->input(2).dim_size(0), tf::errors::InvalidArgument("[ERROR] invalid shape")); - const cudaStream_t& stream = context->eigen_device().stream(); - cublasLtHandle_t cublaslt_handle = this->get_cublaslt_handler(); + const cudaStream_t& stream = context->eigen_device().stream(); + cublasLtHandle_t cublaslt_handle = this->get_cublaslt_handler(); ft::cublasINT8MMWrapper cublas_wrapper = ft::cublasINT8MMWrapper( cublaslt_handle, stream, cublas_algo_map_, this->get_cublas_wrapper_mutex(), use_ORDER_COL32_2R_4R4_); @@ -177,12 +177,12 @@ class BertINT8Op: public BaseOp { // deal with scale list bert_layer_weights_[i].scale_list_.d_scale_list_ = reinterpret_cast(context->input(3 + num_layer_ * 16 + i).flat().data()); - bert_layer_weights_[i].scale_list_.size_ = scale_list_size; + bert_layer_weights_[i].scale_list_.size_ = scale_list_size; bert_layer_weights_[i].scale_list_.p3_offset_ = ACTIVATION_AMAX_NUM + 9 * head_num_ * size_per_head_; bert_layer_weights_[i].scale_list_.p4_offset_ = ACTIVATION_AMAX_NUM + 9 * head_num_ * size_per_head_ + INT8O_GEMM_NUM; bert_layer_weights_[i].attention_weights.scale_list_ptr = &(bert_layer_weights_[i].scale_list_); - bert_layer_weights_[i].ffn_weights.scale_list_ptr = &(bert_layer_weights_[i].scale_list_); + bert_layer_weights_[i].ffn_weights.scale_list_ptr = &(bert_layer_weights_[i].scale_list_); // copy h_scale_list cudaMemcpy(h_scale_list_ + i * scale_list_size, bert_layer_weights_[i].scale_list_.d_scale_list_, @@ -214,7 +214,7 @@ class BertINT8Op: public BaseOp { OP_REQUIRES_OK(context, context->allocate_output(0, context->input(0).shape(), &output)); DataType* out_tensor = reinterpret_cast(output->flat().data()); - ft::DataType data_type = ft::getTensorType(); + ft::DataType data_type = ft::getTensorType(); std::vector input_tensors = std::vector{this->convert_tensor(context->input(0)), this->convert_int_tensor(context->input(2))}; @@ -238,16 +238,16 @@ class BertINT8Op: public BaseOp { } private: - int head_num_ = 0, size_per_head_ = 0, num_layer_ = 0, inter_size_ = 0, int8_mode_ = 1; - bool remove_padding_; - bool use_ORDER_COL32_2R_4R4_; - int sm_; - float q_scaling_ = 1.0f; - ft::cublasAlgoMap* cublas_algo_map_; - float* h_scale_list_ = nullptr; - bool set_weight_ = false; + int head_num_ = 0, size_per_head_ = 0, num_layer_ = 0, inter_size_ = 0, int8_mode_ = 1; + bool remove_padding_; + bool use_ORDER_COL32_2R_4R4_; + int sm_; + float q_scaling_ = 1.0f; + ft::cublasAlgoMap* cublas_algo_map_; + float* h_scale_list_ = nullptr; + bool set_weight_ = false; typedef TFTraits traits_; - typedef typename 
traits_::DataType DataType; + typedef typename traits_::DataType DataType; std::vector> bert_layer_weights_; }; diff --git a/src/fastertransformer/tf_op/bert/BertOp.cc b/src/fastertransformer/tf_op/bert/BertOp.cc index 12b41f329..9f96f33d1 100644 --- a/src/fastertransformer/tf_op/bert/BertOp.cc +++ b/src/fastertransformer/tf_op/bert/BertOp.cc @@ -72,6 +72,14 @@ class TFTraits { typedef half DataType; }; +#ifdef ENABLE_BF16 +template<> +class TFTraits { +public: + typedef __nv_bfloat16 DataType; +}; +#endif + template class BertOp: public BaseOp { public: @@ -84,7 +92,7 @@ class BertOp: public BaseOp { OP_REQUIRES_OK(context, context->GetAttr("num_layer", &num_layer_)); OP_REQUIRES_OK(context, context->GetAttr("remove_padding", &remove_padding_)); OP_REQUIRES_OK(context, context->GetAttr("q_scaling", &q_scaling_)); - sm_ = ft::getSMVersion(); + sm_ = ft::getSMVersion(); cublas_algo_map_ = new ft::cublasAlgoMap("gemm_config.in"); } catch (std::runtime_error& error) { @@ -103,18 +111,18 @@ class BertOp: public BaseOp { context->num_inputs() == (num_layer_ * 16) + 3, tf::errors::InvalidArgument("[ERROR] More or Less input arguments")); - const size_t batch_size_ = (size_t)context->input(0).dim_size(0); + const size_t batch_size_ = (size_t)context->input(0).dim_size(0); const size_t from_seq_len_ = (size_t)context->input(0).dim_size(1); OP_REQUIRES(context, batch_size_ == (size_t)context->input(2).dim_size(0), tf::errors::InvalidArgument("[ERROR] invalid shape")); - const cudaStream_t& stream = context->eigen_device().stream(); - cublasHandle_t cublas_handle = this->get_cublas_handler(); + const cudaStream_t& stream = context->eigen_device().stream(); + cublasHandle_t cublas_handle = this->get_cublas_handler(); cublasSetStream(cublas_handle, stream); ft::Allocator allocator(context, stream); - ft::cublasMMWrapper cublas_wrapper = ft::cublasMMWrapper(cublas_handle, + ft::cublasMMWrapper cublas_wrapper = ft::cublasMMWrapper(cublas_handle, this->get_cublaslt_handler(), stream, cublas_algo_map_, @@ -124,6 +132,11 @@ class BertOp: public BaseOp { if (std::is_same::value) { cublas_wrapper.setFP16GemmConfig(); } +#ifdef ENABLE_BF16 + else if (std::is_same::value) { + cublas_wrapper.setBF16GemmConfig(); + } +#endif else if (std::is_same::value) { cublas_wrapper.setFP32GemmConfig(); } @@ -173,7 +186,7 @@ class BertOp: public BaseOp { context, 3 + num_layer_ * 15 + i, &bert_weights.bert_layer_weights[i].ffn_layernorm_weights.gamma); } bert_weights.post_transformer_layernorm_weights.gamma = nullptr; - bert_weights.post_transformer_layernorm_weights.beta = nullptr; + bert_weights.post_transformer_layernorm_weights.beta = nullptr; ft::AttentionType attention_type = ft::getAttentionType(size_per_head_, sm_, remove_padding_, from_seq_len_); @@ -199,7 +212,7 @@ class BertOp: public BaseOp { OP_REQUIRES_OK(context, context->allocate_output(0, context->input(0).shape(), &output)); DataType* out_tensor = reinterpret_cast(output->flat().data()); - ft::DataType data_type = ft::getTensorType(); + ft::DataType data_type = ft::getTensorType(); std::vector input_tensors = std::vector{this->convert_tensor(context->input(0)), this->convert_int_tensor(context->input(2))}; @@ -223,12 +236,12 @@ class BertOp: public BaseOp { } private: - int head_num_ = 0, size_per_head_ = 0, inter_size_ = 0, num_layer_ = 0; - float q_scaling_ = 1.0f; - bool remove_padding_; - int sm_; - ft::cublasAlgoMap* cublas_algo_map_; - typedef TFTraits traits_; + int head_num_ = 0, size_per_head_ = 0, inter_size_ = 0, num_layer_ = 0; + float 
q_scaling_ = 1.0f; + bool remove_padding_; + int sm_; + ft::cublasAlgoMap* cublas_algo_map_; + typedef TFTraits traits_; typedef typename traits_::DataType DataType; }; diff --git a/src/fastertransformer/tf_op/bert/weight_quantize_op.cc b/src/fastertransformer/tf_op/bert/weight_quantize_op.cc index 93891a111..b6db0cee9 100644 --- a/src/fastertransformer/tf_op/bert/weight_quantize_op.cc +++ b/src/fastertransformer/tf_op/bert/weight_quantize_op.cc @@ -43,10 +43,10 @@ int index_CUBLASLT_ORDER_COL4_4R2_8C(int col_id, int row_id, int m_32) int index_CUBLASLT_ORDER_COL32_2R_4R4(int col_id, int row_id, int m_32) { - int new_col = col_id >> 5; + int new_col = col_id >> 5; int row_in_tile = row_id & 31; int col_in_tile = col_id & 31; - int new_row = // CUBLASLT_ORDER_COL32_2R_4R4 + int new_row = // CUBLASLT_ORDER_COL32_2R_4R4 (((row_id >> 5) << 10) + //(((row%8)/2*4+row/8)*2+row%2)*32+col (((((((row_in_tile & 7) >> 1) << 2) + (row_in_tile >> 3)) << 1) + (row_in_tile & 1)) << 5) + col_in_tile); @@ -54,20 +54,20 @@ int index_CUBLASLT_ORDER_COL32_2R_4R4(int col_id, int row_id, int m_32) } template -void quantization_CUBLASLT_ORDER_COL4_4R2_8C(T* dst, - float* amaxs, - const T* weight, +void quantization_CUBLASLT_ORDER_COL4_4R2_8C(T* dst, + float* amaxs, + const T* weight, const float* quant_max, const float* quant_min, - int n, - int k, - bool per_channel_quantization) + int n, + int k, + bool per_channel_quantization) { // quantization int8_t* int8_dst = (int8_t*)dst; - float element; - float amax; - float amax_in_all = fabs(quant_max[0]); + float element; + float amax; + float amax_in_all = fabs(quant_max[0]); if (per_channel_quantization) { for (int i = 0; i < n; i++) { amaxs[i] = fabs(quant_min[i]); @@ -89,29 +89,29 @@ void quantization_CUBLASLT_ORDER_COL4_4R2_8C(T* dst, for (int col = 0; col < k; col++) { tmp = col * n; for (int row = 0; row < n; row++) { - amax = amaxs[row]; - element = float(weight[tmp + row]); - idx_in_COL4 = index_CUBLASLT_ORDER_COL4_4R2_8C(col, row, 32 * n); + amax = amaxs[row]; + element = float(weight[tmp + row]); + idx_in_COL4 = index_CUBLASLT_ORDER_COL4_4R2_8C(col, row, 32 * n); int8_dst[idx_in_COL4] = float_to_int8_rn_host(element * 127.0 / amax); } } } template -void quantization_CUBLASLT_ORDER_COL32_2R_4R4(T* dst, - float* amaxs, - const T* weight, +void quantization_CUBLASLT_ORDER_COL32_2R_4R4(T* dst, + float* amaxs, + const T* weight, const float* quant_max, const float* quant_min, - int n, - int k, - bool per_channel_quantization) + int n, + int k, + bool per_channel_quantization) { // quantization int8_t* int8_dst = (int8_t*)dst; - float element; - float amax; - float amax_in_all = fabs(quant_max[0]); + float element; + float amax; + float amax_in_all = fabs(quant_max[0]); if (per_channel_quantization) { for (int i = 0; i < n; i++) { amaxs[i] = fabs(quant_min[i]); @@ -133,9 +133,9 @@ void quantization_CUBLASLT_ORDER_COL32_2R_4R4(T* dst, for (int col = 0; col < k; col++) { tmp = col * n; for (int row = 0; row < n; row++) { - amax = amaxs[row]; - element = float(weight[tmp + row]); - idx_in_COL32_2R_4R4 = index_CUBLASLT_ORDER_COL32_2R_4R4(col, row, 32 * n); + amax = amaxs[row]; + element = float(weight[tmp + row]); + idx_in_COL32_2R_4R4 = index_CUBLASLT_ORDER_COL32_2R_4R4(col, row, 32 * n); int8_dst[idx_in_COL32_2R_4R4] = float_to_int8_rn_host(element * 127.0 / amax); } } @@ -197,7 +197,7 @@ class WeightQuantizeOp: public BaseOp { Tensor* output2 = nullptr; OP_REQUIRES_OK(context, context->allocate_output(1, {n}, &output2)); - transform_out = 
reinterpret_cast(output->flat().data()); + transform_out = reinterpret_cast(output->flat().data()); transform_out2 = reinterpret_cast(output2->flat().data()); try { @@ -222,13 +222,13 @@ class WeightQuantizeOp: public BaseOp { } private: - int n, k; - const T* weight_; + int n, k; + const T* weight_; const float *quant_max_, *quant_min_; - T* transform_out; - float* transform_out2; - bool use_ORDER_COL32_2R_4R4; - bool per_channel_quantization_; + T* transform_out; + float* transform_out2; + bool use_ORDER_COL32_2R_4R4; + bool per_channel_quantization_; }; #define REGISTER_CPU(T) \ diff --git a/src/fastertransformer/tf_op/decoder/DecoderOp.cc b/src/fastertransformer/tf_op/decoder/DecoderOp.cc index a6e0f7f32..5f1445651 100644 --- a/src/fastertransformer/tf_op/decoder/DecoderOp.cc +++ b/src/fastertransformer/tf_op/decoder/DecoderOp.cc @@ -91,6 +91,14 @@ class TFTraits { typedef half DataType; }; +#ifdef ENABLE_BF16 +template<> +class TFTraits { +public: + typedef __nv_bfloat16 DataType; +}; +#endif + template class DecoderOp: public BaseOp { public: @@ -122,11 +130,11 @@ class DecoderOp: public BaseOp { const size_t batch_size = (size_t)(context->input(0).dim_size(0)); - const cudaStream_t& stream = context->eigen_device().stream(); - cublasHandle_t cublas_handle = this->get_cublas_handler(); + const cudaStream_t& stream = context->eigen_device().stream(); + cublasHandle_t cublas_handle = this->get_cublas_handler(); cublasSetStream(cublas_handle, stream); ft::Allocator allocator(context, stream); - ft::cublasMMWrapper cublas_wrapper = ft::cublasMMWrapper(cublas_handle, + ft::cublasMMWrapper cublas_wrapper = ft::cublasMMWrapper(cublas_handle, this->get_cublaslt_handler(), stream, cublas_algo_map_, @@ -136,6 +144,11 @@ class DecoderOp: public BaseOp { if (std::is_same::value) { cublas_wrapper.setFP16GemmConfig(); } +#ifdef ENABLE_BF16 + else if (std::is_same::value) { + cublas_wrapper.setBF16GemmConfig(); + } +#endif else if (std::is_same::value) { cublas_wrapper.setFP32GemmConfig(); } @@ -195,11 +208,11 @@ class DecoderOp: public BaseOp { context, 7 + num_layer_ * 21 + i, &decoder_layer_weights[i].ffn_weights.output_weight.bias); } - tf::Tensor self_cache_keys_tensor = context->input(3); - tf::Tensor self_cache_values_tensor = context->input(4); - tf::Tensor memory_cache_keys_tensor = context->input(5); - tf::Tensor memory_cache_values_tensor = context->input(6); - tf::Tensor* output = nullptr; + tf::Tensor self_cache_keys_tensor = context->input(3); + tf::Tensor self_cache_values_tensor = context->input(4); + tf::Tensor memory_cache_keys_tensor = context->input(5); + tf::Tensor memory_cache_values_tensor = context->input(6); + tf::Tensor* output = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, context->input(0).shape(), &output)); DataType* out_tensor = (DataType*)(output->flat().data()); @@ -209,7 +222,7 @@ class DecoderOp: public BaseOp { context->set_output(4, memory_cache_values_tensor); const int* d_step = reinterpret_cast(context->input(7 + num_layer_ * 22).flat().data()); - int step; + int step; cudaMemcpyAsync(&step, d_step, sizeof(int), cudaMemcpyDeviceToHost, stream); step += 1; tf::Tensor sequence_length_tensor = context->input(8 + num_layer_ * 22); @@ -219,7 +232,7 @@ class DecoderOp: public BaseOp { size_t hidden_units = (size_t)(head_num_ * size_per_head_); - ft::DataType data_type = ft::getTensorType(); + ft::DataType data_type = ft::getTensorType(); std::vector input_tensors = std::vector{ this->convert_tensor(context->input(0)), 
this->convert_tensor(context->input(1)), @@ -229,7 +242,7 @@ class DecoderOp: public BaseOp { this->convert_int_tensor(sequence_length_tensor), ft::Tensor{ft::MEMORY_GPU, ft::TYPE_INT32, - {batch_size, 1, step}, + {batch_size, 1, (size_t)step}, nullptr}}; // Since we do gather in the Framework, we don't need id of indirection buffer std::vector output_tensors = @@ -253,9 +266,9 @@ class DecoderOp: public BaseOp { } private: - int head_num_ = 0, size_per_head_ = 0, inter_size_ = 0, num_layer_ = 0; - ft::cublasAlgoMap* cublas_algo_map_; - typedef TFTraits traits_; + int head_num_ = 0, size_per_head_ = 0, inter_size_ = 0, num_layer_ = 0; + ft::cublasAlgoMap* cublas_algo_map_; + typedef TFTraits traits_; typedef typename traits_::DataType DataType; }; diff --git a/src/fastertransformer/tf_op/decoder/FusedSelfAttentionOp.cc b/src/fastertransformer/tf_op/decoder/FusedSelfAttentionOp.cc index bbf73b113..055e93cd6 100644 --- a/src/fastertransformer/tf_op/decoder/FusedSelfAttentionOp.cc +++ b/src/fastertransformer/tf_op/decoder/FusedSelfAttentionOp.cc @@ -40,6 +40,14 @@ class TFTraits { typedef half DataType; }; +#ifdef ENABLE_BF16 +template<> +class TFTraits { +public: + typedef __nv_bfloat16 DataType; +}; +#endif + REGISTER_OP("FusedQkvMultiHeadAttention") .Input("qkv_tensor: T") .Input("qkv_bias: T") @@ -115,8 +123,8 @@ class FusedQkvMultiHeadAttentionOp: public BaseOp { "([seq_len, batch_size, head_num, size_per_head])")); } - const cudaStream_t& stream = context->eigen_device().stream(); - const DataType_* qkv_input = reinterpret_cast(context->input(0).flat().data()); + const cudaStream_t& stream = context->eigen_device().stream(); + const DataType_* qkv_input = reinterpret_cast(context->input(0).flat().data()); OP_REQUIRES(context, qkv_input != nullptr, tf::errors::InvalidArgument("qkv_input is null")); const DataType_* qkv_bias = reinterpret_cast(context->input(1).flat().data()); @@ -137,20 +145,31 @@ class FusedQkvMultiHeadAttentionOp: public BaseOp { try { fastertransformer::fusedQKV_masked_attention_dispatch(qkv_input, qkv_bias, - k_cache, - v_cache, - output_ptr, - nullptr, - nullptr, - batch_size_, - batch_size_, - head_num_, - size_per_head_, - decoder_max_seq_len, - 0, - nullptr, - seq_len_, - stream); + (const DataType_*)nullptr, // relative_attention_bias + k_cache, // key_cache + v_cache, // value_cache + (const int*)nullptr, // cache_indir + output_ptr, // context_buf + (const bool*)nullptr, // finished + (const int*)nullptr, // sequence_lengths + batch_size_, // max_batch_size + batch_size_, // inference_batch_size + 1, // beam_width + head_num_, // head_num + size_per_head_, // size_per_head + 0, // rotary_embedding_dim + false, // neox_rotary_style + decoder_max_seq_len, // max_seq_len + (const int*)nullptr, // prefix_prompt_lengths + 0, // max_prefix_prompt_length + 0, // max_input_len + (const int*)nullptr, // total_padding_tokens + seq_len_, // step + 1.0f, // q_scaling + 0, // relative_attention_bias_stride + (const bool*)nullptr, // masked_tokens + stream // stream + ); } catch (std::runtime_error& error) { std::cout << tf::errors::Internal(error.what()); @@ -163,8 +182,8 @@ class FusedQkvMultiHeadAttentionOp: public BaseOp { } private: - int batch_size_, head_num_, size_per_head_, seq_len_; - typedef TFTraits traits_; + int batch_size_, head_num_, size_per_head_, seq_len_; + typedef TFTraits traits_; typedef typename traits_::DataType DataType_; }; diff --git a/src/fastertransformer/tf_op/decoding/DecodingOp.cc b/src/fastertransformer/tf_op/decoding/DecodingOp.cc index 
993f38f6e..09087f9ff 100644 --- a/src/fastertransformer/tf_op/decoding/DecodingOp.cc +++ b/src/fastertransformer/tf_op/decoding/DecodingOp.cc @@ -84,8 +84,8 @@ REGISTER_OP("Decoding") // calculate batch size tf::shape_inference::DimensionOrConstant max_seq_dim(max_seq_len); tf::shape_inference::DimensionOrConstant beam_width_dim(beam_width); - tf::shape_inference::DimensionHandle batchxbeam_dim = c->Dim(c->input(0), 0); - tf::shape_inference::DimensionHandle batch_dim; + tf::shape_inference::DimensionHandle batchxbeam_dim = c->Dim(c->input(0), 0); + tf::shape_inference::DimensionHandle batch_dim; TF_RETURN_IF_ERROR(c->Divide(batchxbeam_dim, beam_width_dim, true, &batch_dim)); if (beam_width > 1) { @@ -117,6 +117,14 @@ class TFTraits { typedef half DataType; }; +#ifdef ENABLE_BF16 +template<> +class TFTraits { +public: + typedef __nv_bfloat16 DataType; +}; +#endif + template class DecodingOp: public BaseOp { public: @@ -157,15 +165,15 @@ class DecodingOp: public BaseOp { context->num_inputs() == (num_layer_ * 22) + 8, tf::errors::InvalidArgument("[ERROR] More or Less input arguments")); - const size_t batch_size = (size_t)(context->input(0).dim_size(0) / beam_width_); + const size_t batch_size = (size_t)(context->input(0).dim_size(0) / beam_width_); const size_t mem_max_seq_len = (size_t)(context->input(0).dim_size(1)); - const size_t vocab_size = (size_t)(context->input(2 + num_layer_ * 22 + 3).dim_size(0)); + const size_t vocab_size = (size_t)(context->input(2 + num_layer_ * 22 + 3).dim_size(0)); - const cudaStream_t& stream = context->eigen_device().stream(); - cublasHandle_t cublas_handle = this->get_cublas_handler(); + const cudaStream_t& stream = context->eigen_device().stream(); + cublasHandle_t cublas_handle = this->get_cublas_handler(); cublasSetStream(cublas_handle, stream); ft::Allocator allocator(context, stream); - ft::cublasMMWrapper cublas_wrapper = ft::cublasMMWrapper(cublas_handle, + ft::cublasMMWrapper cublas_wrapper = ft::cublasMMWrapper(cublas_handle, this->get_cublaslt_handler(), stream, cublas_algo_map_, @@ -175,6 +183,11 @@ class DecodingOp: public BaseOp { if (std::is_same::value) { cublas_wrapper.setFP16GemmConfig(); } +#ifdef ENABLE_BF16 + else if (std::is_same::value) { + cublas_wrapper.setBF16GemmConfig(); + } +#endif else if (std::is_same::value) { cublas_wrapper.setFP32GemmConfig(); } @@ -263,8 +276,8 @@ class DecodingOp: public BaseOp { this->get_tensor(context, 2 + num_layer_ * 22 + 4, &decoding_weights.post_decoder_embedding.kernel); this->get_tensor(context, 2 + num_layer_ * 22 + 5, &decoding_weights.post_decoder_embedding.bias); - tf::Tensor* output_id_tensor = nullptr; - tf::Tensor* parent_id_tensor = nullptr; + tf::Tensor* output_id_tensor = nullptr; + tf::Tensor* parent_id_tensor = nullptr; tf::Tensor* sequence_length_tensor = nullptr; if (beam_width_ > 1) { OP_REQUIRES_OK(context, @@ -290,8 +303,8 @@ class DecodingOp: public BaseOp { 1, {(long long int)max_seq_len_, (long long int)batch_size}, &parent_id_tensor)); OP_REQUIRES_OK(context, context->allocate_output(2, {(long long int)batch_size}, &sequence_length_tensor)); } - int* output_ids = (int*)(output_id_tensor->flat().data()); - int* parent_ids = (int*)(parent_id_tensor->flat().data()); + int* output_ids = (int*)(output_id_tensor->flat().data()); + int* parent_ids = (int*)(parent_id_tensor->flat().data()); int* sequence_length = (int*)(sequence_length_tensor->flat().data()); ft::Decoding decoding = ft::Decoding(batch_size, @@ -341,18 +354,18 @@ class DecodingOp: public BaseOp { } private: - int 
max_seq_len_ = 0, beam_width_ = 1; - int head_num_ = 0, size_per_head_ = 0, inter_size_; - int num_layer_ = 0, start_id_ = -1, end_id_ = -1; - float beam_search_diversity_rate_ = 1.0; - float temperature_; - float len_penalty_; - float repetition_penalty_; - int top_k_ = 0; - float top_p_ = 0.0f; - ft::cublasAlgoMap* cublas_algo_map_; - cudaDeviceProp prop_; - typedef TFTraits traits_; + int max_seq_len_ = 0, beam_width_ = 1; + int head_num_ = 0, size_per_head_ = 0, inter_size_; + int num_layer_ = 0, start_id_ = -1, end_id_ = -1; + float beam_search_diversity_rate_ = 1.0; + float temperature_; + float len_penalty_; + float repetition_penalty_; + int top_k_ = 0; + float top_p_ = 0.0f; + ft::cublasAlgoMap* cublas_algo_map_; + cudaDeviceProp prop_; + typedef TFTraits traits_; typedef typename traits_::DataType DataType; }; diff --git a/src/fastertransformer/tf_op/encoder/EncoderOp.cc b/src/fastertransformer/tf_op/encoder/EncoderOp.cc index e5786a39b..13296e539 100644 --- a/src/fastertransformer/tf_op/encoder/EncoderOp.cc +++ b/src/fastertransformer/tf_op/encoder/EncoderOp.cc @@ -86,7 +86,7 @@ class EncoderOp: public BaseOp { OP_REQUIRES_OK(context, context->GetAttr("num_layer", &num_layer_)); OP_REQUIRES_OK(context, context->GetAttr("remove_padding", &remove_padding_)); OP_REQUIRES_OK(context, context->GetAttr("q_scaling", &q_scaling_)); - sm_ = ft::getSMVersion(); + sm_ = ft::getSMVersion(); cublas_algo_map_ = new ft::cublasAlgoMap("gemm_config.in"); } catch (std::runtime_error& error) { @@ -105,18 +105,18 @@ class EncoderOp: public BaseOp { context->num_inputs() == (num_layer_ * 16) + 5, tf::errors::InvalidArgument("[ERROR] More or Less input arguments")); - const size_t batch_size_ = (size_t)context->input(0).dim_size(0); + const size_t batch_size_ = (size_t)context->input(0).dim_size(0); const size_t from_seq_len_ = (size_t)context->input(0).dim_size(1); OP_REQUIRES(context, batch_size_ == (size_t)context->input(2).dim_size(0), tf::errors::InvalidArgument("[ERROR] invalid shape")); - const cudaStream_t& stream = context->eigen_device().stream(); - cublasHandle_t cublas_handle = this->get_cublas_handler(); + const cudaStream_t& stream = context->eigen_device().stream(); + cublasHandle_t cublas_handle = this->get_cublas_handler(); cublasSetStream(cublas_handle, stream); ft::Allocator allocator(context, stream); - ft::cublasMMWrapper cublas_wrapper = ft::cublasMMWrapper(cublas_handle, + ft::cublasMMWrapper cublas_wrapper = ft::cublasMMWrapper(cublas_handle, this->get_cublaslt_handler(), stream, cublas_algo_map_, @@ -208,7 +208,7 @@ class EncoderOp: public BaseOp { OP_REQUIRES_OK(context, context->allocate_output(0, context->input(0).shape(), &output)); DataType* out_tensor = reinterpret_cast(output->flat().data()); - ft::DataType data_type = ft::getTensorType(); + ft::DataType data_type = ft::getTensorType(); std::vector input_tensors = std::vector{this->convert_tensor(context->input(0)), this->convert_int_tensor(context->input(2))}; @@ -232,12 +232,12 @@ class EncoderOp: public BaseOp { } private: - int head_num_ = 0, size_per_head_ = 0, inter_size_ = 0, num_layer_ = 0; - float q_scaling_ = 1.0f; - bool remove_padding_; - int sm_; - ft::cublasAlgoMap* cublas_algo_map_; - typedef TFTraits traits_; + int head_num_ = 0, size_per_head_ = 0, inter_size_ = 0, num_layer_ = 0; + float q_scaling_ = 1.0f; + bool remove_padding_; + int sm_; + ft::cublasAlgoMap* cublas_algo_map_; + typedef TFTraits traits_; typedef typename traits_::DataType DataType; }; diff --git 
a/src/fastertransformer/tf_op/gpt/GptOp.cc b/src/fastertransformer/tf_op/gpt/GptOp.cc index 06df93845..019f11322 100644 --- a/src/fastertransformer/tf_op/gpt/GptOp.cc +++ b/src/fastertransformer/tf_op/gpt/GptOp.cc @@ -45,7 +45,6 @@ REGISTER_OP("Gpt") .Input("pre_decoder_embedding_table: T") // 17 .Input("post_decoder_embedding_kernel: T") // 18 .Output("output_ids: int32") - .Output("parent_ids: int32") .Output("sequence_length: int32") .Output("cum_log_probs: float") .Attr("N: int") @@ -79,19 +78,11 @@ REGISTER_OP("Gpt") // calculate batch size tf::shape_inference::DimensionOrConstant max_seq_dim(max_seq_len); tf::shape_inference::DimensionOrConstant beam_dim(beam_width); - tf::shape_inference::DimensionHandle batch_dim = c->Dim(c->input(0), 0); + tf::shape_inference::DimensionHandle batch_dim = c->Dim(c->input(0), 0); - if (beam_width > 1) { - c->set_output(0, c->MakeShape({batch_dim, beam_dim, max_seq_len})); - c->set_output(1, c->MakeShape({max_seq_len, batch_dim, beam_dim})); - c->set_output(2, c->MakeShape({batch_dim, beam_dim})); - } - else { - c->set_output(0, c->MakeShape({batch_dim, max_seq_len})); - c->set_output(1, c->MakeShape({max_seq_len, batch_dim, 1})); - c->set_output(2, c->MakeShape({batch_dim})); - c->set_output(3, c->MakeShape({request_output_length, batch_dim})); - } + c->set_output(0, c->MakeShape({batch_dim, beam_dim, max_seq_len})); + c->set_output(1, c->MakeShape({batch_dim, beam_dim})); + c->set_output(2, c->MakeShape({batch_dim, beam_dim, request_output_length})); return tf::Status::OK(); }); @@ -161,15 +152,15 @@ class GptOp: public BaseOp { context->num_inputs() == (num_layer_ * 12) + 7, tf::errors::InvalidArgument("[ERROR] More or Less input arguments")); - const size_t batch_size = (size_t)context->input(0).dim_size(0); - const size_t vocab_size = (size_t)(context->input(2 + num_layer_ * 12 + 3).dim_size(0)); + const size_t batch_size = (size_t)context->input(0).dim_size(0); + const size_t vocab_size = (size_t)(context->input(2 + num_layer_ * 12 + 3).dim_size(0)); const size_t max_input_length = (size_t)(context->input(0).dim_size(1)); - const cudaStream_t& stream = context->eigen_device().stream(); - cublasHandle_t cublas_handle = this->get_cublas_handler(); + const cudaStream_t& stream = context->eigen_device().stream(); + cublasHandle_t cublas_handle = this->get_cublas_handler(); cublasSetStream(cublas_handle, stream); ft::Allocator allocator(context, stream); - ft::cublasMMWrapper cublas_wrapper = ft::cublasMMWrapper(cublas_handle, + ft::cublasMMWrapper cublas_wrapper = ft::cublasMMWrapper(cublas_handle, this->get_cublaslt_handler(), stream, cublas_algo_map_, @@ -238,59 +229,29 @@ class GptOp: public BaseOp { this->get_tensor(context, 2 + num_layer_ * 12 + 4, &gpt_weidghts.post_decoder_embedding.kernel); int total_output_length = request_output_length_ + (int)max_input_length; - tf::Tensor* output_id_tensor = nullptr; - tf::Tensor* parent_id_tensor = nullptr; + tf::Tensor* output_id_tensor = nullptr; tf::Tensor* sequence_length_tensor = nullptr; - tf::Tensor* cum_log_probs = nullptr; - if (beam_width_ > 1) { + tf::Tensor* cum_log_probs = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output( + 0, + {(long long int)batch_size, (long long int)beam_width_, (long long int)total_output_length}, + &output_id_tensor)); + OP_REQUIRES_OK(context, + context->allocate_output( + 1, {(long long int)batch_size, (long long int)beam_width_}, &sequence_length_tensor)); + if (this->output_log_probs_) { OP_REQUIRES_OK( context, context->allocate_output( - 0, - 
{(long long int)batch_size, (long long int)beam_width_, (long long int)total_output_length}, - &output_id_tensor)); - OP_REQUIRES_OK( - context, - context->allocate_output( - 1, - {(long long int)total_output_length, (long long int)batch_size, (long long int)beam_width_}, - &parent_id_tensor)); - OP_REQUIRES_OK(context, - context->allocate_output( - 2, {(long long int)batch_size, (long long int)beam_width_}, &sequence_length_tensor)); - - if (this->output_log_probs_) { - OP_REQUIRES_OK( - context, - context->allocate_output( - 3, - {(long long int)request_output_length_, (long long int)batch_size, (long long int)beam_width_}, - &cum_log_probs)); - } - else { - OP_REQUIRES_OK(context, context->allocate_output(3, {0}, &cum_log_probs)); - } + 2, + {(long long int)batch_size, (long long int)beam_width_, (long long int)request_output_length_}, + &cum_log_probs)); } else { - OP_REQUIRES_OK(context, - context->allocate_output( - 0, {(long long int)batch_size, (long long int)total_output_length}, &output_id_tensor)); - OP_REQUIRES_OK(context, - context->allocate_output( - 1, {(long long int)total_output_length, (long long int)batch_size}, &parent_id_tensor)); - OP_REQUIRES_OK(context, context->allocate_output(2, {(long long int)batch_size}, &sequence_length_tensor)); - if (this->output_log_probs_) { - OP_REQUIRES_OK( - context, - context->allocate_output( - 3, {(long long int)request_output_length_, (long long int)batch_size}, &cum_log_probs)); - } - else { - OP_REQUIRES_OK(context, context->allocate_output(3, {0}, &cum_log_probs)); - } + OP_REQUIRES_OK(context, context->allocate_output(3, {0}, &cum_log_probs)); } - int* output_ids = (int*)(output_id_tensor->flat().data()); - int* parent_ids = (int*)(parent_id_tensor->flat().data()); + int* output_ids = (int*)(output_id_tensor->flat().data()); int* sequence_length = (int*)(sequence_length_tensor->flat().data()); ft::NcclParam tensor_para; @@ -307,6 +268,9 @@ class GptOp: public BaseOp { vocab_size, start_id_, end_id_, + end_id_ + 1, // p_prompt_tuning token start id + ft::PromptLearningType::no_prompt, + ft::gptVariantParams{}, beam_search_diversity_rate_, top_k_, top_p_, @@ -342,7 +306,7 @@ class GptOp: public BaseOp { } if (top_k_ != 0) { input_tensors.insert( - {"runtime_top_k", ft::Tensor{ft::MEMORY_CPU, ft::TYPE_INT32, std::vector{1}, &top_k_}}); + {"runtime_top_k", ft::Tensor{ft::MEMORY_CPU, ft::TYPE_UINT32, std::vector{1}, &top_k_}}); } } input_tensors.insert( @@ -355,17 +319,18 @@ class GptOp: public BaseOp { input_tensors.insert( {"random_seed", ft::Tensor{ft::MEMORY_CPU, ft::TYPE_UINT64, std::vector{1}, &random_seed}}); + std::vector total_output_length_vec(batch_size, total_output_length); + input_tensors.insert( + {"output_seq_len", + ft::Tensor{ + ft::MEMORY_CPU, ft::TYPE_UINT32, std::vector{batch_size}, total_output_length_vec.data()}}); + std::unordered_map output_tensors = std::unordered_map{ {"output_ids", ft::Tensor{ft::MEMORY_GPU, ft::TYPE_INT32, std::vector{batch_size, (size_t)beam_width_, (size_t)total_output_length}, output_ids}}, - {"parent_ids", - ft::Tensor{ft::MEMORY_GPU, - ft::TYPE_INT32, - std::vector{(size_t)total_output_length, batch_size, (size_t)beam_width_}, - parent_ids}}, {"sequence_length", ft::Tensor{ft::MEMORY_GPU, ft::TYPE_INT32, @@ -374,7 +339,7 @@ class GptOp: public BaseOp { {"output_log_probs", ft::Tensor{ft::MEMORY_GPU, ft::TYPE_FP32, - {(size_t)request_output_length_, (size_t)batch_size, (size_t)beam_width_}, + {(size_t)batch_size, (size_t)beam_width_, (size_t)request_output_length_}, output_log_probs_ 
? reinterpret_cast(cum_log_probs->flat().data()) : nullptr}}}; try { @@ -391,20 +356,20 @@ class GptOp: public BaseOp { } private: - int max_batch_size_ = 0, max_seq_len_ = 0, beam_width_ = 1; - int head_num_ = 0, size_per_head_ = 0, inter_size_ = 0; - int num_layer_ = 0, start_id_ = -1, end_id_ = -1; - float beam_search_diversity_rate_ = 1.0; - float temperature_; - float len_penalty_; - float repetition_penalty_; - int top_k_ = 0; - float top_p_ = 0.0f; - bool output_log_probs_; - int request_output_length_; - ft::cublasAlgoMap* cublas_algo_map_; - struct cudaDeviceProp prop_; - typedef TFTraits traits_; + int max_batch_size_ = 0, max_seq_len_ = 0, beam_width_ = 1; + int head_num_ = 0, size_per_head_ = 0, inter_size_ = 0; + int num_layer_ = 0, start_id_ = -1, end_id_ = -1; + float beam_search_diversity_rate_ = 1.0; + float temperature_; + float len_penalty_; + float repetition_penalty_; + int top_k_ = 0; + float top_p_ = 0.0f; + bool output_log_probs_; + int request_output_length_; + ft::cublasAlgoMap* cublas_algo_map_; + struct cudaDeviceProp prop_; + typedef TFTraits traits_; typedef typename traits_::DataType DataType; }; diff --git a/src/fastertransformer/th_op/CMakeLists.txt b/src/fastertransformer/th_op/CMakeLists.txt index 138b2d0a3..5aac28068 100644 --- a/src/fastertransformer/th_op/CMakeLists.txt +++ b/src/fastertransformer/th_op/CMakeLists.txt @@ -17,9 +17,8 @@ add_definitions(-DTORCH_CUDA=1) add_library(th_utils STATIC th_utils.cu) set_property(TARGET th_utils PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET th_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -target_link_libraries(th_utils PUBLIC "${TORCH_LIBRARIES}" -lcublas -lcudart -lcurand) +target_link_libraries(th_utils PUBLIC "${TORCH_LIBRARIES}" -lcublas -lcudart -lcurand tensor) -add_subdirectory(bert) add_subdirectory(encoder) add_subdirectory(decoder) add_subdirectory(decoding) @@ -28,7 +27,6 @@ add_subdirectory(longformer) add_subdirectory(swin) add_subdirectory(vit) -if(BUILD_MULTI_GPU) - add_subdirectory(multi_gpu_gpt) - add_subdirectory(t5) -endif() +add_subdirectory(multi_gpu_gpt) +add_subdirectory(t5) +add_subdirectory(bert) diff --git a/src/fastertransformer/th_op/bert/BertINT8Op.cc b/src/fastertransformer/th_op/bert/BertINT8Op.cc index cf6de4190..245566ab3 100644 --- a/src/fastertransformer/th_op/bert/BertINT8Op.cc +++ b/src/fastertransformer/th_op/bert/BertINT8Op.cc @@ -38,13 +38,13 @@ FasterTransformerINT8Bert::FasterTransformerINT8Bert(th::Tensor q_kernel, th::Tensor output_layernorm_beta, th::Tensor d_scale_list, th::Tensor h_scale_list, - int64_t head_num, - int64_t head_size, - bool remove_padding, - int64_t layer_num, - int64_t int8_mode, - bool sparse, - double q_scaling): + int64_t head_num, + int64_t head_size, + bool remove_padding, + int64_t layer_num, + int64_t int8_mode, + bool sparse, + double q_scaling): _st(q_kernel.scalar_type()), _remove_padding(remove_padding), weights{q_kernel, @@ -93,14 +93,14 @@ FasterTransformerINT8Bert::FasterTransformerINT8Bert(th::Tensor q_kernel, default: throw std::runtime_error("Wrong Tensor type."); } - head_info = torch::empty({6}, torch::dtype(torch::kInt64)); - head_info[0] = head_num; - head_info[1] = head_size; - head_info[2] = (int64_t)remove_padding; - head_info[3] = layer_num; - head_info[4] = int8_mode; - head_info[5] = (int64_t)sparse; - scaling_info = torch::empty({1}, torch::dtype(torch::kFloat64)); + head_info = torch::empty({6}, torch::dtype(torch::kInt64)); + head_info[0] = head_num; + head_info[1] = head_size; + head_info[2] = 
(int64_t)remove_padding; + head_info[3] = layer_num; + head_info[4] = int8_mode; + head_info[5] = (int64_t)sparse; + scaling_info = torch::empty({1}, torch::dtype(torch::kFloat64)); scaling_info[0] = (double)q_scaling; } @@ -116,7 +116,7 @@ th::Tensor FasterTransformerINT8Bert::forward(th::Tensor input, th::Tensor seque CHECK_CONTIGUOUS(sequence_lengths); TORCH_CHECK(sequence_lengths.dtype() == torch::kInt32, "sequence_lengths dtype should be int32"); int batch_size = input.size(0); - int seq_len = input.size(1); + int seq_len = input.size(1); auto output = torch::empty_like(input); ftbert->forward(batch_size, seq_len, input, sequence_lengths, output, _remove_padding); @@ -169,13 +169,13 @@ static auto fasterTransformerINT8BertTHS = return self->get_pickle_info(); }, [](std::vector state) -> c10::intrusive_ptr { - int64_t head_num = state[18][0].item().to(); - int64_t head_size = state[18][1].item().to(); - bool remove_padding = (bool)(state[18][2].item().to()); - int64_t layer_num = state[18][3].item().to(); - int64_t int8_mode = state[18][4].item().to(); - bool sparse = (bool)(state[18][5].item().to()); - double q_scaling = state[19][0].item().to(); + int64_t head_num = state[18][0].item().to(); + int64_t head_size = state[18][1].item().to(); + bool remove_padding = (bool)(state[18][2].item().to()); + int64_t layer_num = state[18][3].item().to(); + int64_t int8_mode = state[18][4].item().to(); + bool sparse = (bool)(state[18][5].item().to()); + double q_scaling = state[19][0].item().to(); return c10::make_intrusive(state[0], state[1], state[2], diff --git a/src/fastertransformer/th_op/bert/BertINT8Op.h b/src/fastertransformer/th_op/bert/BertINT8Op.h index e440b2768..5d5ec5ad5 100644 --- a/src/fastertransformer/th_op/bert/BertINT8Op.h +++ b/src/fastertransformer/th_op/bert/BertINT8Op.h @@ -28,23 +28,23 @@ namespace torch_ext { class IFBertINT8 { public: virtual ~IFBertINT8() {} - virtual void forward(int batch_size, - int seq_len, + virtual void forward(int batch_size, + int seq_len, th::Tensor& input, th::Tensor& sequence_lengths, th::Tensor& output, - bool removing_padding) = 0; + bool removing_padding) = 0; }; template class FTBertINT8: public IFBertINT8 { public: - FTBertINT8(int head_num, - int head_size, - int layer_num, - const float q_scaling, - int int8_mode, - bool sparse, + FTBertINT8(int head_num, + int head_size, + int layer_num, + const float q_scaling, + int int8_mode, + bool sparse, const std::vector& w): _head_num(head_num), _head_size(head_size), @@ -74,8 +74,8 @@ class FTBertINT8: public IFBertINT8 { } #endif std::string sp_config_fname = sparse ? 
"spigemm_config.in" : ""; - cublas_algo_map_ = new ft::cublasAlgoMap("igemm_config.in", sp_config_fname); - cublas_wrapper_mutex_ = new std::mutex(); + cublas_algo_map_ = new ft::cublasAlgoMap("igemm_config.in", sp_config_fname); + cublas_wrapper_mutex_ = new std::mutex(); bert_layer_weights.clear(); bert_layer_weights.resize(_layer_num); @@ -94,15 +94,15 @@ class FTBertINT8: public IFBertINT8 { bert_layer_weights[i].attention_weights.attention_output_weight.bias = get_ptr(_weights[7]) + hidden_dim * i; bert_layer_weights[i].attn_layernorm_weights.gamma = get_ptr(_weights[8]) + hidden_dim * i; - bert_layer_weights[i].attn_layernorm_weights.beta = get_ptr(_weights[9]) + hidden_dim * i; + bert_layer_weights[i].attn_layernorm_weights.beta = get_ptr(_weights[9]) + hidden_dim * i; bert_layer_weights[i].ffn_weights.intermediate_weight.kernel = get_ptr(_weights[10]) + hidden_dim * hidden_dim * 4 * i; bert_layer_weights[i].ffn_weights.intermediate_weight.bias = get_ptr(_weights[11]) + hidden_dim * 4 * i; bert_layer_weights[i].ffn_weights.output_weight.kernel = get_ptr(_weights[12]) + hidden_dim * hidden_dim * 4 * i; bert_layer_weights[i].ffn_weights.output_weight.bias = get_ptr(_weights[13]) + hidden_dim * i; - bert_layer_weights[i].ffn_layernorm_weights.gamma = get_ptr(_weights[14]) + hidden_dim * i; - bert_layer_weights[i].ffn_layernorm_weights.beta = get_ptr(_weights[15]) + hidden_dim * i; + bert_layer_weights[i].ffn_layernorm_weights.gamma = get_ptr(_weights[14]) + hidden_dim * i; + bert_layer_weights[i].ffn_layernorm_weights.beta = get_ptr(_weights[15]) + hidden_dim * i; // for scale_list bert_layer_weights[i].scale_list_.size_ = @@ -114,7 +114,7 @@ class FTBertINT8: public IFBertINT8 { bert_layer_weights[i].scale_list_.h_scale_list_ = get_ptr(_weights[17]) + i * bert_layer_weights[i].scale_list_.size_; bert_layer_weights[i].attention_weights.scale_list_ptr = &(bert_layer_weights[i].scale_list_); - bert_layer_weights[i].ffn_weights.scale_list_ptr = &(bert_layer_weights[i].scale_list_); + bert_layer_weights[i].ffn_weights.scale_list_ptr = &(bert_layer_weights[i].scale_list_); } if (sparse) { for (int i = 0; i < _layer_num; ++i) { @@ -146,14 +146,14 @@ class FTBertINT8: public IFBertINT8 { delete cublas_wrapper_mutex_; } - void forward(int batch_size, - int seq_len, + void forward(int batch_size, + int seq_len, th::Tensor& input, th::Tensor& sequence_lengths, th::Tensor& output, - bool removing_padding) override + bool removing_padding) override { - auto stream = at::cuda::getCurrentCUDAStream().stream(); + auto stream = at::cuda::getCurrentCUDAStream().stream(); ft::cublasINT8MMWrapper cublas_wrapper = #ifdef SPARSITY_ENABLED ft::cublasINT8MMWrapper(_cublasltHandle, @@ -188,7 +188,7 @@ class FTBertINT8: public IFBertINT8 { attention_type, _sparse); - ft::DataType data_type = ft::getTensorType(); + ft::DataType data_type = ft::getTensorType(); std::vector input_tensors = std::vector{ ft::Tensor{ft::MEMORY_GPU, data_type, @@ -221,21 +221,21 @@ class FTBertINT8: public IFBertINT8 { } private: - const int _head_num; - const int _head_size; + const int _head_num; + const int _head_size; std::vector _weights; - const int _layer_num; - const float _q_scaling; - int _int8_mode; - bool _sparse; - int sm_; - bool _use_ORDER_COL32_2R_4R4; - cublasLtHandle_t _cublasltHandle; + const int _layer_num; + const float _q_scaling; + int _int8_mode; + bool _sparse; + int sm_; + bool _use_ORDER_COL32_2R_4R4; + cublasLtHandle_t _cublasltHandle; #ifdef SPARSITY_ENABLED cusparseLtHandle_t _cusparseLtHandle; #endif - 
std::mutex* cublas_wrapper_mutex_; - ft::cublasAlgoMap* cublas_algo_map_; + std::mutex* cublas_wrapper_mutex_; + ft::cublasAlgoMap* cublas_algo_map_; std::vector> bert_layer_weights; }; @@ -259,13 +259,13 @@ class FasterTransformerINT8Bert: public th::jit::CustomClassHolder { th::Tensor output_layernorm_beta, th::Tensor d_scale_list, th::Tensor h_scale_list, - int64_t head_num, - int64_t head_size, - bool remove_padding, - int64_t layer_num, - int64_t int8_mode, - bool sparse, - double q_scaling); + int64_t head_num, + int64_t head_size, + bool remove_padding, + int64_t layer_num, + int64_t int8_mode, + bool sparse, + double q_scaling); ~FasterTransformerINT8Bert(); th::Tensor forward(th::Tensor input, th::Tensor sequence_lengths); @@ -273,11 +273,11 @@ class FasterTransformerINT8Bert: public th::jit::CustomClassHolder { std::vector get_pickle_info() const; private: - const at::ScalarType _st; - bool _remove_padding; - IFBertINT8* ftbert; - th::Tensor head_info; - th::Tensor scaling_info; + const at::ScalarType _st; + bool _remove_padding; + IFBertINT8* ftbert; + th::Tensor head_info; + th::Tensor scaling_info; std::vector weights; }; diff --git a/src/fastertransformer/th_op/bert/BertOp.cc b/src/fastertransformer/th_op/bert/BertOp.cc index c88a2d5b2..24374c8f9 100644 --- a/src/fastertransformer/th_op/bert/BertOp.cc +++ b/src/fastertransformer/th_op/bert/BertOp.cc @@ -16,6 +16,10 @@ #include "src/fastertransformer/th_op/bert/BertOp.h" +#ifdef USE_NVTX +bool NVTX_ON = true; +#endif + namespace th = torch; namespace torch_ext { @@ -35,13 +39,15 @@ FasterTransformerBert::FasterTransformerBert(th::Tensor q_kernel, th::Tensor output_bias, th::Tensor output_layernorm_gamma, th::Tensor output_layernorm_beta, - int64_t head_num, - int64_t head_size, - int64_t inter_size, - bool remove_padding, - int64_t layer_num, - bool sparse, - double q_scaling): + int64_t head_num, + int64_t head_size, + int64_t inter_size, + bool remove_padding, + int64_t layer_num, + bool sparse, + double q_scaling, + int64_t tensor_para_size, + int64_t pipeline_para_size): _st(q_kernel.scalar_type()), _remove_padding(remove_padding), weights{q_kernel, @@ -80,22 +86,53 @@ FasterTransformerBert::FasterTransformerBert(th::Tensor q_kernel, switch (_st) { case at::ScalarType::Float: - ftbert = new FTBert(head_num, head_size, inter_size, layer_num, sparse, q_scaling, weights); + ftbert = new FTBert(head_num, + head_size, + inter_size, + layer_num, + sparse, + q_scaling, + tensor_para_size, + pipeline_para_size, + weights); break; case at::ScalarType::Half: - ftbert = new FTBert(head_num, head_size, inter_size, layer_num, sparse, q_scaling, weights); + ftbert = new FTBert(head_num, + head_size, + inter_size, + layer_num, + sparse, + q_scaling, + tensor_para_size, + pipeline_para_size, + weights); + break; +#ifdef ENABLE_BF16 + case at::ScalarType::BFloat16: + ftbert = new FTBert<__nv_bfloat16>(head_num, + head_size, + inter_size, + layer_num, + sparse, + q_scaling, + tensor_para_size, + pipeline_para_size, + weights); break; +#endif default: throw std::runtime_error("Wrong Tensor type."); } - head_info = torch::empty({6}, torch::dtype(torch::kInt64)); - head_info[0] = head_num; - head_info[1] = head_size; - head_info[2] = (int64_t)remove_padding; - head_info[3] = layer_num; - head_info[4] = (int64_t)sparse; - head_info[5] = inter_size; - scaling_info = torch::empty({1}, torch::dtype(torch::kFloat64)); + head_info = torch::empty({8}, torch::dtype(torch::kInt64)); + head_info[0] = head_num; + head_info[1] = head_size; + head_info[2] 
= (int64_t)remove_padding; + head_info[3] = layer_num; + head_info[4] = (int64_t)sparse; + head_info[5] = inter_size; + head_info[6] = tensor_para_size; + head_info[7] = pipeline_para_size; + scaling_info = torch::empty({1}, torch::dtype(torch::kFloat64)); scaling_info[0] = (double)q_scaling; } @@ -111,7 +148,7 @@ th::Tensor FasterTransformerBert::forward(th::Tensor input, th::Tensor sequence_ CHECK_CONTIGUOUS(sequence_lengths); TORCH_CHECK(sequence_lengths.dtype() == torch::kInt32, "sequence_lengths dtype should be int32"); size_t batch_size = (size_t)input.size(0); - size_t seq_len = (size_t)input.size(1); + size_t seq_len = (size_t)input.size(1); auto output = torch::empty_like(input); ftbert->forward(batch_size, seq_len, input, sequence_lengths, output, _remove_padding); @@ -156,20 +193,24 @@ static auto fasterTransformerBertTHS = bool, int64_t, bool, - double>()) + double, + int64_t, + int64_t>()) .def("forward", &torch_ext::FasterTransformerBert::forward) .def_pickle( [](const c10::intrusive_ptr& self) -> std::vector { return self->get_pickle_info(); }, [](std::vector state) -> c10::intrusive_ptr { - int64_t head_num = state[16][0].item().to(); - int64_t head_size = state[16][1].item().to(); - bool remove_padding = (bool)(state[16][2].item().to()); - int64_t layer_num = state[16][3].item().to(); - bool sparse = (bool)(state[16][4].item().to()); - int64_t inter_size = state[16][5].item().to(); - double q_scaling = state[17][0].item().to(); + int64_t head_num = state[16][0].item().to(); + int64_t head_size = state[16][1].item().to(); + bool remove_padding = (bool)(state[16][2].item().to()); + int64_t layer_num = state[16][3].item().to(); + bool sparse = (bool)(state[16][4].item().to()); + int64_t inter_size = state[16][5].item().to(); + int64_t tensor_para_size = state[16][6].item().to(); + int64_t pipeline_para_size = state[16][7].item().to(); + double q_scaling = state[17][0].item().to(); return c10::make_intrusive(state[0], state[1], state[2], @@ -192,5 +233,7 @@ static auto fasterTransformerBertTHS = remove_padding, layer_num, sparse, - q_scaling); + q_scaling, + tensor_para_size, + pipeline_para_size); }); \ No newline at end of file diff --git a/src/fastertransformer/th_op/bert/BertOp.h b/src/fastertransformer/th_op/bert/BertOp.h index 2c6aaca57..8c8c482fb 100644 --- a/src/fastertransformer/th_op/bert/BertOp.h +++ b/src/fastertransformer/th_op/bert/BertOp.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
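[Editor's sketch] The BF16 enablement repeated across the TF and Torch ops in this patch follows one pattern: a trait specialization maps the framework dtype to the CUDA storage type, and the matching cuBLAS GEMM configuration is selected at runtime. A condensed illustration with the template parameters written out follows; the framework-side key Eigen::bfloat16 is an assumption (mirroring the Eigen::half case), and FasterTransformer/TensorFlow headers are assumed to be on the include path. This is not a verbatim copy of any one file in the patch.

#ifdef ENABLE_BF16
#include <cuda_bf16.h>   // __nv_bfloat16
#include <cuda_fp16.h>   // half
#include <type_traits>   // std::is_same

// 1) Map the framework dtype to the CUDA storage type used by the kernels.
//    (Eigen::bfloat16 as the key is assumed; the FP16 path specializes on Eigen::half.)
template<typename T>
class TFTraits;

template<>
class TFTraits<Eigen::bfloat16> {
public:
    typedef __nv_bfloat16 DataType;
};

// 2) Pick the matching GEMM configuration, as the compute() bodies above do.
template<typename DataType>
void set_gemm_config(ft::cublasMMWrapper& cublas_wrapper)
{
    if (std::is_same<DataType, half>::value) {
        cublas_wrapper.setFP16GemmConfig();
    }
    else if (std::is_same<DataType, __nv_bfloat16>::value) {
        cublas_wrapper.setBF16GemmConfig();
    }
    else {
        cublas_wrapper.setFP32GemmConfig();
    }
}
#endif  // ENABLE_BF16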
@@ -16,6 +16,7 @@ #include "src/fastertransformer/models/bert/Bert.h" #include "src/fastertransformer/th_op/th_utils.h" +#include "src/fastertransformer/utils/nccl_utils.h" namespace ft = fastertransformer; namespace th = torch; @@ -24,23 +25,25 @@ namespace torch_ext { class IFBert { public: virtual ~IFBert() {} - virtual void forward(size_t batch_size, - size_t seq_len, + virtual void forward(size_t batch_size, + size_t seq_len, th::Tensor& input, th::Tensor& sequence_lengths, th::Tensor& output, - bool removing_padding) = 0; + bool removing_padding) = 0; }; template class FTBert: public IFBert { public: - FTBert(int head_num, - int head_size, - int inter_size, - int layer_num, - bool sparse, - float q_scaling, + FTBert(size_t head_num, + size_t head_size, + size_t inter_size, + size_t layer_num, + bool sparse, + float q_scaling, + const size_t tensor_para_size, + const size_t pipeline_para_size, const std::vector& w): _head_num(head_num), _head_size(head_size), @@ -55,7 +58,9 @@ class FTBert: public IFBert { std::cout << "[WARNING] Sparsity support is not enabled. Will use dense GEMM instead.\n" << std::flush; } #endif - int hidden_dim = _head_num * _head_size; + ft::ftNcclInitialize(tensor_para_, pipeline_para_, tensor_para_size, pipeline_para_size); + const int hidden_dim = _head_num * _head_size; + const int local_hidden_dim = (_head_num / tensor_para_.world_size_) * _head_size; ft::check_cuda_error(cublasLtCreate(&_cublasltHandle)); sm_ = ft::getSMVersion(); #ifdef SPARSITY_ENABLED @@ -64,44 +69,72 @@ class FTBert: public IFBert { } #endif std::string sp_config_fname = sparse ? "spgemm_config.in" : ""; - cublas_algo_map_ = new ft::cublasAlgoMap("gemm_config.in", sp_config_fname); - cublas_wrapper_mutex_ = new std::mutex(); + cublas_algo_map_ = new ft::cublasAlgoMap("gemm_config.in", sp_config_fname); + cublas_wrapper_mutex_ = new std::mutex(); + + + // #define LOAD_MODEL // Used for debug + +#ifdef LOAD_MODEL + bert_weights = ft::BertWeight(_head_num * _head_size, + _inter_size, + _layer_num, + tensor_para_.world_size_, + tensor_para_.rank_, + pipeline_para_.world_size_, + pipeline_para_.rank_); + bert_weights.loadModel("tmp/" + std::to_string(tensor_para_.world_size_) + "-gpu/"); +#else bert_weights.bert_layer_weights.clear(); bert_weights.bert_layer_weights.resize(_layer_num); - for (int i = 0; i < _layer_num; i++) { + int local_num_layer = (int)(ceil(_layer_num * 1.0f / pipeline_para_.world_size_)); + if (!(i < _layer_num && (i >= local_num_layer * pipeline_para_.rank_) + && (i < local_num_layer * (pipeline_para_.rank_ + 1)))) { + continue; + } + const int first_layer_index = local_num_layer * pipeline_para_.rank_; + bert_weights.bert_layer_weights[i].attention_weights.query_weight.kernel = - get_ptr(_weights[0]) + hidden_dim * hidden_dim * i; + get_ptr(_weights[0]) + hidden_dim * local_hidden_dim * (i - first_layer_index); bert_weights.bert_layer_weights[i].attention_weights.query_weight.bias = - get_ptr(_weights[1]) + hidden_dim * i; + get_ptr(_weights[1]) + local_hidden_dim * (i - first_layer_index); bert_weights.bert_layer_weights[i].attention_weights.key_weight.kernel = - get_ptr(_weights[2]) + hidden_dim * hidden_dim * i; + get_ptr(_weights[2]) + hidden_dim * local_hidden_dim * (i - first_layer_index); bert_weights.bert_layer_weights[i].attention_weights.key_weight.bias = - get_ptr(_weights[3]) + hidden_dim * i; + get_ptr(_weights[3]) + local_hidden_dim * (i - first_layer_index); bert_weights.bert_layer_weights[i].attention_weights.value_weight.kernel = - 
get_ptr(_weights[4]) + hidden_dim * hidden_dim * i; + get_ptr(_weights[4]) + hidden_dim * local_hidden_dim * (i - first_layer_index); bert_weights.bert_layer_weights[i].attention_weights.value_weight.bias = - get_ptr(_weights[5]) + hidden_dim * i; + get_ptr(_weights[5]) + local_hidden_dim * (i - first_layer_index); bert_weights.bert_layer_weights[i].attention_weights.attention_output_weight.kernel = - get_ptr(_weights[6]) + hidden_dim * hidden_dim * i; + get_ptr(_weights[6]) + local_hidden_dim * hidden_dim * (i - first_layer_index); bert_weights.bert_layer_weights[i].attention_weights.attention_output_weight.bias = - get_ptr(_weights[7]) + hidden_dim * i; - bert_weights.bert_layer_weights[i].attn_layernorm_weights.gamma = get_ptr(_weights[8]) + hidden_dim * i; - bert_weights.bert_layer_weights[i].attn_layernorm_weights.beta = get_ptr(_weights[9]) + hidden_dim * i; + get_ptr(_weights[7]) + hidden_dim * (i - first_layer_index); + bert_weights.bert_layer_weights[i].attn_layernorm_weights.gamma = + get_ptr(_weights[8]) + hidden_dim * (i - first_layer_index); + bert_weights.bert_layer_weights[i].attn_layernorm_weights.beta = + get_ptr(_weights[9]) + hidden_dim * (i - first_layer_index); bert_weights.bert_layer_weights[i].ffn_weights.intermediate_weight.kernel = - get_ptr(_weights[10]) + hidden_dim * hidden_dim * 4 * i; + get_ptr(_weights[10]) + + hidden_dim * (_inter_size / tensor_para_.world_size_) * (i - first_layer_index); bert_weights.bert_layer_weights[i].ffn_weights.intermediate_weight.bias = - get_ptr(_weights[11]) + hidden_dim * 4 * i; + get_ptr(_weights[11]) + (_inter_size / tensor_para_.world_size_) * (i - first_layer_index); bert_weights.bert_layer_weights[i].ffn_weights.output_weight.kernel = - get_ptr(_weights[12]) + hidden_dim * hidden_dim * 4 * i; + get_ptr(_weights[12]) + + (_inter_size / tensor_para_.world_size_) * hidden_dim * (i - first_layer_index); bert_weights.bert_layer_weights[i].ffn_weights.output_weight.bias = - get_ptr(_weights[13]) + hidden_dim * i; - bert_weights.bert_layer_weights[i].ffn_layernorm_weights.gamma = get_ptr(_weights[14]) + hidden_dim * i; - bert_weights.bert_layer_weights[i].ffn_layernorm_weights.beta = get_ptr(_weights[15]) + hidden_dim * i; + get_ptr(_weights[13]) + hidden_dim * (i - first_layer_index); + bert_weights.bert_layer_weights[i].ffn_layernorm_weights.gamma = + get_ptr(_weights[14]) + hidden_dim * (i - first_layer_index); + bert_weights.bert_layer_weights[i].ffn_layernorm_weights.beta = + get_ptr(_weights[15]) + hidden_dim * (i - first_layer_index); } +#endif + #ifdef SPARSITY_ENABLED if (sparse) { - auto stream = at::cuda::getCurrentCUDAStream().stream(); + auto stream = at::cuda::getCurrentCUDAStream().stream(); cublasHandle_t _cublasHandle = at::cuda::getCurrentCUDABlasHandle(); cublasSetStream(_cublasHandle, stream); ft::cublasMMWrapper cublas_wrapper = ft::cublasMMWrapper(_cublasHandle, @@ -120,6 +153,8 @@ class FTBert: public IFBert { ~FTBert() override { + ft::ftNcclParamDestroy(tensor_para_); + ft::ftNcclParamDestroy(pipeline_para_); cublasLtDestroy(_cublasltHandle); #ifdef SPARSITY_ENABLED if (_sparse) { @@ -130,18 +165,18 @@ class FTBert: public IFBert { delete cublas_wrapper_mutex_; } - void forward(size_t batch_size, - size_t seq_len, + void forward(size_t batch_size, + size_t seq_len, th::Tensor& input, th::Tensor& sequence_lengths, th::Tensor& output, - bool removing_padding) override + bool removing_padding) override { - auto stream = at::cuda::getCurrentCUDAStream().stream(); + auto stream = 
at::cuda::getCurrentCUDAStream().stream(); cublasHandle_t _cublasHandle = at::cuda::getCurrentCUDABlasHandle(); cublasSetStream(_cublasHandle, stream); ft::Allocator* allocator = new ft::Allocator(); - ft::cublasMMWrapper* cublas_wrapper = + ft::cublasMMWrapper* cublas_wrapper = #ifdef SPARSITY_ENABLED new ft::cublasMMWrapper(_cublasHandle, _cublasltHandle, @@ -158,6 +193,11 @@ class FTBert: public IFBert { if (std::is_same::value) { cublas_wrapper->setFP16GemmConfig(); } +#ifdef ENABLE_BF16 + else if (std::is_same::value) { + cublas_wrapper->setBF16GemmConfig(); + } +#endif else if (std::is_same::value) { cublas_wrapper->setFP32GemmConfig(); } @@ -179,9 +219,13 @@ class FTBert: public IFBert { attention_type, _sparse, ft::ActivationType::Gelu, - ft::LayerNormType::post_layernorm); + ft::LayerNormType::post_layernorm, + tensor_para_, + pipeline_para_, + nullptr, + false); - ft::DataType data_type = ft::getTensorType(); + ft::DataType data_type = ft::getTensorType(); std::vector input_tensors = std::vector{ ft::Tensor{ft::MEMORY_GPU, data_type, @@ -213,21 +257,23 @@ class FTBert: public IFBert { } private: - const int _head_num; - const int _head_size; - const int _inter_size; - const int _layer_num; + const size_t _head_num; + const size_t _head_size; + const size_t _inter_size; + const size_t _layer_num; std::vector _weights; - bool _sparse; - const float _q_scaling; - int sm_; - cublasLtHandle_t _cublasltHandle; + bool _sparse; + const float _q_scaling; + ft::NcclParam tensor_para_; + ft::NcclParam pipeline_para_; + int sm_; + cublasLtHandle_t _cublasltHandle; #ifdef SPARSITY_ENABLED cusparseLtHandle_t _cusparseLtHandle; #endif - std::mutex* cublas_wrapper_mutex_; + std::mutex* cublas_wrapper_mutex_; ft::cublasAlgoMap* cublas_algo_map_; - ft::BertWeight bert_weights; + ft::BertWeight bert_weights; }; class FasterTransformerBert: public th::jit::CustomClassHolder { @@ -248,13 +294,15 @@ class FasterTransformerBert: public th::jit::CustomClassHolder { th::Tensor output_bias, th::Tensor output_layernorm_gamma, th::Tensor output_layernorm_beta, - int64_t head_num, - int64_t head_size, - int64_t inter_size, - bool remove_padding, - int64_t layer_num, - bool sparse, - double q_scaling); + int64_t head_num, + int64_t head_size, + int64_t inter_size, + bool remove_padding, + int64_t layer_num, + bool sparse, + double q_scaling, + int64_t tensor_para_size, + int64_t pipeline_para_size); ~FasterTransformerBert(); @@ -263,11 +311,11 @@ class FasterTransformerBert: public th::jit::CustomClassHolder { std::vector get_pickle_info() const; private: - const at::ScalarType _st; - bool _remove_padding; - IFBert* ftbert; - th::Tensor head_info; - th::Tensor scaling_info; + const at::ScalarType _st; + bool _remove_padding; + IFBert* ftbert; + th::Tensor head_info; + th::Tensor scaling_info; std::vector weights; }; diff --git a/src/fastertransformer/th_op/bert/CMakeLists.txt b/src/fastertransformer/th_op/bert/CMakeLists.txt index a2068642d..de7ed5846 100644 --- a/src/fastertransformer/th_op/bert/CMakeLists.txt +++ b/src/fastertransformer/th_op/bert/CMakeLists.txt @@ -13,4 +13,4 @@ # limitations under the License. 
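[Editor's sketch] The FTBert constructor above partitions layers across pipeline-parallel ranks and splits attention/FFN widths across tensor-parallel ranks. Below is a minimal sketch of that arithmetic under the uniform ceil-based split shown in the hunks; the helper names are illustrative, not part of the patch.

#include <cmath>

// Illustrative helper; the patch computes these values inline in the weight loop.
struct LayerPartition {
    int local_num_layer;    // layers handled by each pipeline rank
    int first_layer_index;  // global index of this rank's first layer
};

inline LayerPartition partition_layers(int layer_num, int pipeline_rank, int pipeline_world_size)
{
    LayerPartition p;
    p.local_num_layer   = (int)std::ceil(layer_num * 1.0f / pipeline_world_size);
    p.first_layer_index = p.local_num_layer * pipeline_rank;
    return p;
}

// A global layer i belongs to this pipeline rank iff
//   first_layer_index <= i < first_layer_index + local_num_layer   (and i < layer_num);
// its weight pointers are then offset by (i - first_layer_index) rather than i.
// Per-layer widths are divided across tensor-parallel ranks:
//   local_hidden_dim = (head_num / tensor_para_world_size) * head_size
//   local_inter_size =  inter_size / tensor_para_world_size

With tensor_para_size = pipeline_para_size = 1 this reduces to the previous single-GPU indexing (offset i, full hidden and intermediate sizes).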
add_library(th_bert SHARED BertOp.cc BertINT8Op.cc WeightQuantizeOp.cc) -target_link_libraries(th_bert PRIVATE "${TORCH_LIBRARIES}" Bert BertINT8 th_utils quantize_weight) +target_link_libraries(th_bert PRIVATE "${TORCH_LIBRARIES}" Bert BertINT8 nccl_utils th_utils quantize_weight) diff --git a/src/fastertransformer/th_op/bert/WeightQuantizeOp.cc b/src/fastertransformer/th_op/bert/WeightQuantizeOp.cc index cd065a1fc..9fedd1fd2 100644 --- a/src/fastertransformer/th_op/bert/WeightQuantizeOp.cc +++ b/src/fastertransformer/th_op/bert/WeightQuantizeOp.cc @@ -28,7 +28,7 @@ void compressInt8Matrix(void* output, const void* input, const int m, const int cusparseLtHandle_t _cusparseLtHandle; CHECK_CUSPARSE(cusparseLtInit(&_cusparseLtHandle)); cusparseLtMatDescriptor_t matA; - unsigned alignment = 16; + unsigned alignment = 16; CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit( &_cusparseLtHandle, &matA, m, k, k, alignment, CUDA_R_8I, CUSPARSE_ORDER_ROW, CUSPARSELT_SPARSITY_50_PERCENT)) CHECK_CUSPARSE(cusparseLtSpMMACompress2( @@ -68,18 +68,18 @@ Tensor weight_quantize(Tensor weight, Tensor quant_max, bool sparse) TORCH_CHECK(quant_max.dtype() == torch::kFloat32, "quant_max dtype should be float32"); TORCH_CHECK(quant_max.numel() == n, "quant_max wrong shape"); - const float* weight_ = get_ptr(weight); + const float* weight_ = get_ptr(weight); const float* quant_max_ = get_ptr(quant_max); - auto output = torch::empty({k * n}, torch::dtype(torch::kFloat16).device(torch::kCUDA).requires_grad(false)); + auto output = torch::empty({k * n}, torch::dtype(torch::kFloat16).device(torch::kCUDA).requires_grad(false)); int8_t* transform_out = get_ptr(output); auto stream = at::cuda::getCurrentCUDAStream().stream(); #ifdef SPARSITY_ENABLED if (sparse) { - int format = 0; - auto tmp = torch::empty({k * n}, torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false)); + int format = 0; + auto tmp = torch::empty({k * n}, torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false)); int8_t* tmp_out = get_ptr(tmp); fastertransformer::invokeQuantizeWeight(tmp_out, weight_, quant_max_, n, k, format, stream); compressInt8Matrix(transform_out, tmp_out, n, k, stream); diff --git a/src/fastertransformer/th_op/decoder/DecoderOp.cc b/src/fastertransformer/th_op/decoder/DecoderOp.cc index 07dedf460..565dd7ba6 100644 --- a/src/fastertransformer/th_op/decoder/DecoderOp.cc +++ b/src/fastertransformer/th_op/decoder/DecoderOp.cc @@ -41,11 +41,11 @@ FasterTransformerDecoder::FasterTransformerDecoder(th::Tensor self_layernorm_gam th::Tensor inter_bias, th::Tensor output_kernel, th::Tensor output_bias, - int64_t head_num, - int64_t head_size, - int64_t inter_size, - int64_t layer_num, - int64_t mem_hidden_dim): + int64_t head_num, + int64_t head_size, + int64_t inter_size, + int64_t layer_num, + int64_t mem_hidden_dim): _st(self_kernel_q.scalar_type()), weights{self_layernorm_gamma, self_layernorm_beta, self_kernel_q, self_bias_q, self_output_kernel, self_output_bias, cross_layernorm_gamma, cross_layernorm_beta, @@ -65,10 +65,16 @@ FasterTransformerDecoder::FasterTransformerDecoder(th::Tensor self_layernorm_gam case at::ScalarType::Half: ftdecoder = new FTDecoder(head_num, head_size, inter_size, layer_num, mem_hidden_dim, weights); break; +#ifdef ENABLE_BF16 + case at::ScalarType::BFloat16: + ftdecoder = + new FTDecoder<__nv_bfloat16>(head_num, head_size, inter_size, layer_num, mem_hidden_dim, weights); + break; +#endif default: throw std::runtime_error("Wrong Tensor type."); } - head_info = torch::empty({4}, 
torch::dtype(torch::kInt64)); + head_info = torch::empty({4}, torch::dtype(torch::kInt64)); head_info[0] = head_num; head_info[1] = head_size; head_info[2] = layer_num; @@ -81,7 +87,7 @@ FasterTransformerDecoder::~FasterTransformerDecoder() delete ftdecoder; } -std::vector FasterTransformerDecoder::forward(int64_t step, +std::vector FasterTransformerDecoder::forward(int64_t step, th::Tensor from_tensor, th::Tensor memory_tensor, th::Tensor memory_sequence_length, @@ -176,10 +182,10 @@ static auto fasterTransformerDecoderTHS = return self->get_pickle_info(); }, [](std::vector state) -> c10::intrusive_ptr { - int64_t head_num = state[22][0].item().to(); - int64_t head_size = state[22][1].item().to(); - int64_t layer_num = state[22][2].item().to(); - int64_t inter_size = state[22][3].item().to(); + int64_t head_num = state[22][0].item().to(); + int64_t head_size = state[22][1].item().to(); + int64_t layer_num = state[22][2].item().to(); + int64_t inter_size = state[22][3].item().to(); int64_t mem_hidden_dim = state[22][4].item().to(); return c10::make_intrusive(state[0], state[1], diff --git a/src/fastertransformer/th_op/decoder/DecoderOp.h b/src/fastertransformer/th_op/decoder/DecoderOp.h index 032992596..dd4ade97b 100644 --- a/src/fastertransformer/th_op/decoder/DecoderOp.h +++ b/src/fastertransformer/th_op/decoder/DecoderOp.h @@ -24,8 +24,8 @@ namespace torch_ext { class IFDecoder { public: virtual ~IFDecoder() {} - virtual void forward(size_t batch_size, - size_t step, + virtual void forward(size_t batch_size, + size_t step, th::Tensor& from_tensor, th::Tensor& memory_tensor, th::Tensor& memory_sequence_length, @@ -40,11 +40,11 @@ class IFDecoder { template class FTDecoder: public IFDecoder { public: - FTDecoder(int head_num, - int head_size, - int inter_size, - int layer_num, - int mem_hidden_dim, + FTDecoder(int head_num, + int head_size, + int inter_size, + int layer_num, + int mem_hidden_dim, const std::vector& w): _head_num(head_num), _head_size(head_size), @@ -63,7 +63,7 @@ class FTDecoder: public IFDecoder { for (int i = 0; i < _layer_num; ++i) { decoder_layer_weights[i].pre_layernorm_weights.gamma = get_ptr(_weights[0]) + i * hidden_dim; - decoder_layer_weights[i].pre_layernorm_weights.beta = get_ptr(_weights[1]) + i * hidden_dim; + decoder_layer_weights[i].pre_layernorm_weights.beta = get_ptr(_weights[1]) + i * hidden_dim; decoder_layer_weights[i].self_attention_weights.query_weight.kernel = get_ptr(_weights[2]) + i * hidden_dim * 3 * hidden_dim; decoder_layer_weights[i].self_attention_weights.query_weight.bias = @@ -73,7 +73,7 @@ class FTDecoder: public IFDecoder { decoder_layer_weights[i].self_attention_weights.attention_output_weight.bias = get_ptr(_weights[5]) + i * hidden_dim; decoder_layer_weights[i].self_attn_layernorm_weights.gamma = get_ptr(_weights[6]) + i * hidden_dim; - decoder_layer_weights[i].self_attn_layernorm_weights.beta = get_ptr(_weights[7]) + i * hidden_dim; + decoder_layer_weights[i].self_attn_layernorm_weights.beta = get_ptr(_weights[7]) + i * hidden_dim; decoder_layer_weights[i].cross_attention_weights.query_weight.kernel = get_ptr(_weights[8]) + i * hidden_dim * hidden_dim; decoder_layer_weights[i].cross_attention_weights.key_weight.kernel = @@ -91,7 +91,7 @@ class FTDecoder: public IFDecoder { decoder_layer_weights[i].cross_attention_weights.attention_output_weight.bias = get_ptr(_weights[15]) + i * hidden_dim; decoder_layer_weights[i].cross_attn_layernorm_weights.gamma = get_ptr(_weights[16]) + i * hidden_dim; - 
decoder_layer_weights[i].cross_attn_layernorm_weights.beta = get_ptr(_weights[17]) + i * hidden_dim; + decoder_layer_weights[i].cross_attn_layernorm_weights.beta = get_ptr(_weights[17]) + i * hidden_dim; decoder_layer_weights[i].ffn_weights.intermediate_weight.kernel = get_ptr(_weights[18]) + i * hidden_dim * _inter_size; decoder_layer_weights[i].ffn_weights.intermediate_weight.bias = get_ptr(_weights[19]) + i * _inter_size; @@ -108,8 +108,8 @@ class FTDecoder: public IFDecoder { delete cublas_wrapper_mutex_; } - void forward(size_t batch_size, - size_t step, + void forward(size_t batch_size, + size_t step, th::Tensor& from_tensor, th::Tensor& memory_tensor, th::Tensor& memory_sequence_length, @@ -120,16 +120,21 @@ class FTDecoder: public IFDecoder { th::Tensor& memory_cache_keys_tensor, th::Tensor& memory_cache_values_tensor) override { - auto stream = at::cuda::getCurrentCUDAStream().stream(); + auto stream = at::cuda::getCurrentCUDAStream().stream(); cublasHandle_t _cublasHandle = at::cuda::getCurrentCUDABlasHandle(); cublasSetStream(_cublasHandle, stream); - ft::Allocator* allocator = new ft::Allocator(); - ft::cublasMMWrapper* cublas_wrapper = new ft::cublasMMWrapper( + ft::Allocator* allocator = new ft::Allocator(); + ft::cublasMMWrapper* cublas_wrapper = new ft::cublasMMWrapper( _cublasHandle, _cublasltHandle, stream, cublas_algo_map_, cublas_wrapper_mutex_, allocator); if (std::is_same::value) { cublas_wrapper->setFP16GemmConfig(); } +#ifdef ENABLE_BF16 + else if (std::is_same::value) { + cublas_wrapper->setBF16GemmConfig(); + } +#endif else if (std::is_same::value) { cublas_wrapper->setFP32GemmConfig(); } @@ -137,7 +142,7 @@ class FTDecoder: public IFDecoder { ft::Decoder* decoder = new ft::Decoder( batch_size, _head_num, _head_size, _inter_size, _layer_num, stream, cublas_wrapper, allocator, true); - int tmp_step = step + 1; + int tmp_step = step + 1; std::vector input_tensors = std::vector{ convert_tensor(from_tensor), convert_tensor(memory_tensor), @@ -173,15 +178,15 @@ class FTDecoder: public IFDecoder { } private: - const int _head_num; - const int _head_size; - const int _inter_size; - std::vector _weights; - const int _layer_num; - const int _mem_hidden_dim; - cublasLtHandle_t _cublasltHandle; - std::mutex* cublas_wrapper_mutex_; - ft::cublasAlgoMap* cublas_algo_map_; + const int _head_num; + const int _head_size; + const int _inter_size; + std::vector _weights; + const int _layer_num; + const int _mem_hidden_dim; + cublasLtHandle_t _cublasltHandle; + std::mutex* cublas_wrapper_mutex_; + ft::cublasAlgoMap* cublas_algo_map_; std::vector> decoder_layer_weights; }; @@ -209,15 +214,15 @@ class FasterTransformerDecoder: public th::jit::CustomClassHolder { th::Tensor inter_bias, th::Tensor output_kernel, th::Tensor output_bias, - int64_t head_num, - int64_t head_size, - int64_t inter_size, - int64_t layer_num, - int64_t mem_hidden_dim); + int64_t head_num, + int64_t head_size, + int64_t inter_size, + int64_t layer_num, + int64_t mem_hidden_dim); ~FasterTransformerDecoder(); - std::vector forward(int64_t step, + std::vector forward(int64_t step, th::Tensor from_tensor, th::Tensor memory_tensor, th::Tensor memory_sequence_length, @@ -230,9 +235,9 @@ class FasterTransformerDecoder: public th::jit::CustomClassHolder { std::vector get_pickle_info() const; private: - const at::ScalarType _st; - IFDecoder* ftdecoder; - th::Tensor head_info; + const at::ScalarType _st; + IFDecoder* ftdecoder; + th::Tensor head_info; std::vector weights; }; diff --git 
a/src/fastertransformer/th_op/decoding/DecodingOp.cc b/src/fastertransformer/th_op/decoding/DecodingOp.cc index ba064c0ec..be2b49987 100644 --- a/src/fastertransformer/th_op/decoding/DecodingOp.cc +++ b/src/fastertransformer/th_op/decoding/DecodingOp.cc @@ -21,20 +21,20 @@ namespace th = torch; namespace torch_ext { using torch::Tensor; -FasterTransformerDecoding::FasterTransformerDecoding(int64_t head_num, - int64_t size_per_head, - int64_t inter_size, - int64_t mem_hidden_dim, - int64_t layer_num, - int64_t vocab_size, - int64_t start_id, - int64_t end_id, - double beam_search_diversity_rate, - int64_t top_k, - double top_p, - double temperature, - double len_penalty, - double repetition_penalty, +FasterTransformerDecoding::FasterTransformerDecoding(int64_t head_num, + int64_t size_per_head, + int64_t inter_size, + int64_t mem_hidden_dim, + int64_t layer_num, + int64_t vocab_size, + int64_t start_id, + int64_t end_id, + double beam_search_diversity_rate, + int64_t top_k, + double top_p, + double temperature, + double len_penalty, + double repetition_penalty, th::Tensor self_layernorm_gamma, th::Tensor self_layernorm_beta, th::Tensor self_kernel_q, @@ -134,11 +134,30 @@ FasterTransformerDecoding::FasterTransformerDecoding(int64_t head_num, (float)repetition_penalty, weights); break; +#ifdef ENABLE_BF16 + case at::ScalarType::BFloat16: + ftdecoding = new torch_ext::FTDecoding<__nv_bfloat16>(head_num, + size_per_head, + inter_size, + mem_hidden_dim, + layer_num, + vocab_size, + start_id, + end_id, + (float)beam_search_diversity_rate, + top_k, + (float)top_p, + (float)temperature, + (float)len_penalty, + (float)repetition_penalty, + weights); + break; +#endif default: throw std::runtime_error("Wrong Tensor type."); } - int_info_ = torch::empty({11}, torch::dtype(torch::kInt64)); - float_info_ = torch::empty({5}, torch::dtype(torch::kFloat64)); + int_info_ = torch::empty({11}, torch::dtype(torch::kInt64)); + float_info_ = torch::empty({5}, torch::dtype(torch::kFloat64)); int_info_[0] = head_num; int_info_[1] = size_per_head; int_info_[2] = inter_size; @@ -161,8 +180,8 @@ FasterTransformerDecoding::~FasterTransformerDecoding() delete ftdecoding; } -std::vector FasterTransformerDecoding::forward(int64_t beam_width, - int64_t max_seq_len, +std::vector FasterTransformerDecoding::forward(int64_t beam_width, + int64_t max_seq_len, th::Tensor memory, th::Tensor memory_seq_lens) { @@ -171,7 +190,7 @@ std::vector FasterTransformerDecoding::forward(int64_t beam_width, CHECK_CONTIGUOUS(memory_seq_lens); TORCH_CHECK(memory_seq_lens.dtype() == torch::kInt32, "mem_seq_lens dtype should be int32"); - int batch_size = (int)(memory.size(0) / beam_width); + int batch_size = (int)(memory.size(0) / beam_width); auto output_ids = torch::empty({batch_size * beam_width * max_seq_len}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); auto parent_ids = torch::empty({batch_size * beam_width * max_seq_len}, @@ -247,22 +266,22 @@ static auto fasterTransformerDecodingTHS = return self->get_pickle_info(); }, [](std::vector state) -> c10::intrusive_ptr { - int head_num = state[28][0].item().to(); - int size_per_head = state[28][1].item().to(); - int inter_size = state[28][2].item().to(); + int head_num = state[28][0].item().to(); + int size_per_head = state[28][1].item().to(); + int inter_size = state[28][2].item().to(); int mem_hidden_dim = state[28][3].item().to(); - int layer_num = state[28][4].item().to(); - int vocab_size = state[28][5].item().to(); - int start_id = state[28][6].item().to(); - 
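Alongside the compile-time dispatch inside the templates, the op constructors pick the template instantiation at runtime from the at::ScalarType of the first weight, and the diff adds a BFloat16 case guarded by ENABLE_BF16. A trimmed sketch of that switch shape; IFImpl/Impl stand in for the IFTDecoding/FTDecoding pair:

#include <stdexcept>
#include <ATen/ATen.h>
#include <cuda_fp16.h>
#ifdef ENABLE_BF16
#include <cuda_bf16.h>
#endif

struct IFImpl {
    virtual ~IFImpl() {}
};
template<typename T>
struct Impl: IFImpl {};

IFImpl* make_impl(at::ScalarType st)
{
    switch (st) {
        case at::ScalarType::Float:
            return new Impl<float>();
        case at::ScalarType::Half:
            return new Impl<half>();
#ifdef ENABLE_BF16
        case at::ScalarType::BFloat16:
            return new Impl<__nv_bfloat16>();
#endif
        default:
            throw std::runtime_error("Wrong Tensor type.");
    }
}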
int end_id = state[28][7].item().to(); - int top_k = state[28][8].item().to(); + int layer_num = state[28][4].item().to(); + int vocab_size = state[28][5].item().to(); + int start_id = state[28][6].item().to(); + int end_id = state[28][7].item().to(); + int top_k = state[28][8].item().to(); // TODO(bhsueh) Here may have bugs double beam_search_diversity_rate = state[33][0].item().to(); - double top_p = state[33][1].item().to(); - double temperature = state[33][2].item().to(); - double len_penalty = state[33][3].item().to(); - double repetition_penalty = state[33][4].item().to(); + double top_p = state[33][1].item().to(); + double temperature = state[33][2].item().to(); + double len_penalty = state[33][3].item().to(); + double repetition_penalty = state[33][4].item().to(); return c10::make_intrusive(head_num, size_per_head, diff --git a/src/fastertransformer/th_op/decoding/DecodingOp.h b/src/fastertransformer/th_op/decoding/DecodingOp.h index 5b45581e0..152113721 100644 --- a/src/fastertransformer/th_op/decoding/DecodingOp.h +++ b/src/fastertransformer/th_op/decoding/DecodingOp.h @@ -25,8 +25,8 @@ namespace torch_ext { class IFTDecoding { public: virtual ~IFTDecoding() {} - virtual void forward(size_t beam_width, - size_t max_seq_len, + virtual void forward(size_t beam_width, + size_t max_seq_len, th::Tensor memory, th::Tensor memory_seq_lens, th::Tensor output_ids, @@ -37,20 +37,20 @@ class IFTDecoding { template class FTDecoding: public IFTDecoding { public: - FTDecoding(int head_num, - int size_per_head, - int inter_size, - int mem_hidden_dim, - int layer_num, - int vocab_size, - int start_id, - int end_id, - float beam_search_diversity_rate, - int top_k, - float top_p, - float temperature, - float len_penalty, - float repetition_penalty, + FTDecoding(int head_num, + int size_per_head, + int inter_size, + int mem_hidden_dim, + int layer_num, + int vocab_size, + int start_id, + int end_id, + float beam_search_diversity_rate, + int top_k, + float top_p, + float temperature, + float len_penalty, + float repetition_penalty, const std::vector& w): head_num_(head_num), size_per_head_(size_per_head), @@ -69,7 +69,7 @@ class FTDecoding: public IFTDecoding { _weights(w) { ft::check_cuda_error(cublasLtCreate(&cublasltHandle_)); - cublas_algo_map_ = new ft::cublasAlgoMap("gemm_config.in"); + cublas_algo_map_ = new ft::cublasAlgoMap("gemm_config.in"); cublas_wrapper_mutex_ = new std::mutex(); decoding_weights.decoder_layer_weights.resize(layer_num_); @@ -121,12 +121,12 @@ class FTDecoding: public IFTDecoding { decoding_weights.decoder_layer_weights[i].ffn_weights.output_weight.bias = get_ptr(_weights[21]) + i * hidden_dim; } - decoding_weights.post_decoder_layernorm.gamma = get_ptr(_weights[22]); - decoding_weights.post_decoder_layernorm.beta = get_ptr(_weights[23]); - decoding_weights.pre_decoder_embedding_table = get_ptr(_weights[24]); - decoding_weights.position_encoding_table = get_ptr(_weights[25]); + decoding_weights.post_decoder_layernorm.gamma = get_ptr(_weights[22]); + decoding_weights.post_decoder_layernorm.beta = get_ptr(_weights[23]); + decoding_weights.pre_decoder_embedding_table = get_ptr(_weights[24]); + decoding_weights.position_encoding_table = get_ptr(_weights[25]); decoding_weights.post_decoder_embedding.kernel = get_ptr(_weights[26]); - decoding_weights.post_decoder_embedding.bias = get_ptr(_weights[27]); + decoding_weights.post_decoder_embedding.bias = get_ptr(_weights[27]); ft::check_cuda_error(cudaGetDeviceProperties(&prop_, 0)); } @@ -138,32 +138,37 @@ class FTDecoding: 
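The __getstate__/__setstate__ pair above survives TorchScript serialization by packing the scalar hyper-parameters into small CPU tensors (int_info_, float_info_) at construction and reading them back with .item() when the class is unpickled. A minimal round-trip sketch of that pattern, standalone rather than the FT classes:

#include <torch/torch.h>

// Pack two config scalars the way the op stores them for pickling.
torch::Tensor pack_int_info(int64_t head_num, int64_t size_per_head)
{
    torch::Tensor info = torch::empty({2}, torch::dtype(torch::kInt64));
    info[0] = head_num;
    info[1] = size_per_head;
    return info;
}

// Read them back, __setstate__ style.
void unpack_int_info(const torch::Tensor& info, int64_t& head_num, int64_t& size_per_head)
{
    head_num      = info[0].item().to<int64_t>();
    size_per_head = info[1].item().to<int64_t>();
}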
public IFTDecoding { delete cublas_wrapper_mutex_; } - void forward(size_t beam_width, - size_t max_seq_len, + void forward(size_t beam_width, + size_t max_seq_len, th::Tensor memory, th::Tensor memory_seq_lens, th::Tensor output_ids, th::Tensor parent_ids, th::Tensor sequence_lengths) override { - auto stream = at::cuda::getCurrentCUDAStream().stream(); + auto stream = at::cuda::getCurrentCUDAStream().stream(); cublasHandle_t cublasHandle = at::cuda::getCurrentCUDABlasHandle(); cublasSetStream(cublasHandle, stream); - ft::Allocator allocator = ft::Allocator(); - ft::cublasMMWrapper cublas_wrapper = ft::cublasMMWrapper( + ft::Allocator allocator = ft::Allocator(); + ft::cublasMMWrapper cublas_wrapper = ft::cublasMMWrapper( cublasHandle, cublasltHandle_, stream, cublas_algo_map_, cublas_wrapper_mutex_, &allocator); if (std::is_same::value) { cublas_wrapper.setFP16GemmConfig(); } +#ifdef ENABLE_BF16 + else if (std::is_same::value) { + cublas_wrapper.setBF16GemmConfig(); + } +#endif else if (std::is_same::value) { cublas_wrapper.setFP32GemmConfig(); } - const size_t batch_size = (size_t)memory.size(0) / beam_width; + const size_t batch_size = (size_t)memory.size(0) / beam_width; const size_t mem_max_seq_len = (size_t)memory.size(1); - ft::Decoding decoding = ft::Decoding(batch_size, + ft::Decoding decoding = ft::Decoding(batch_size, max_seq_len, mem_max_seq_len, beam_width, @@ -185,7 +190,7 @@ class FTDecoding: public IFTDecoding { &allocator, false, &prop_); - ft::DataType data_type = ft::getTensorType(); + ft::DataType data_type = ft::getTensorType(); std::vector input_tensors = std::vector{ ft::Tensor{ft::MEMORY_GPU, data_type, @@ -213,45 +218,45 @@ class FTDecoding: public IFTDecoding { } private: - const int head_num_; - const int size_per_head_; - const int inter_size_; - const int mem_hidden_dim_; - const int layer_num_; - const int vocab_size_; - const int start_id_; - const int end_id_; + const int head_num_; + const int size_per_head_; + const int inter_size_; + const int mem_hidden_dim_; + const int layer_num_; + const int vocab_size_; + const int start_id_; + const int end_id_; const float beam_search_diversity_rate_; - const int top_k_; + const int top_k_; const float top_p_; const float temperature_; const float len_penalty_; const float repetition_penalty_; std::vector _weights; - cublasLtHandle_t cublasltHandle_; - std::mutex* cublas_wrapper_mutex_; - ft::cublasAlgoMap* cublas_algo_map_; - struct cudaDeviceProp prop_; - ft::DecodingWeight decoding_weights; + cublasLtHandle_t cublasltHandle_; + std::mutex* cublas_wrapper_mutex_; + ft::cublasAlgoMap* cublas_algo_map_; + struct cudaDeviceProp prop_; + ft::DecodingWeight decoding_weights; }; class FasterTransformerDecoding: public torch::jit::CustomClassHolder { public: - FasterTransformerDecoding(int64_t head_num, - int64_t size_per_head, - int64_t inter_size, - int64_t mem_hidden_dim, - int64_t layer_num, - int64_t vocab_size, - int64_t start_id, - int64_t end_id, - double beam_search_diversity_rate, - int64_t top_k, - double top_p, - double temperature, - double len_penalty, - double repetition_penalty, + FasterTransformerDecoding(int64_t head_num, + int64_t size_per_head, + int64_t inter_size, + int64_t mem_hidden_dim, + int64_t layer_num, + int64_t vocab_size, + int64_t start_id, + int64_t end_id, + double beam_search_diversity_rate, + int64_t top_k, + double top_p, + double temperature, + double len_penalty, + double repetition_penalty, th::Tensor self_layernorm_gamma, th::Tensor self_layernorm_beta, th::Tensor 
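forward() hands raw device pointers to the decoding model as ft::Tensor descriptors: a memory location, a data type from ft::getTensorType<T>(), an explicit shape, and the pointer. A hedged sketch of building the memory descriptor, following the aggregate shape visible in this diff; the include path and the exact shape convention are assumptions:

// Assumes the FasterTransformer Tensor header (path as in the FT source tree).
#include "src/fastertransformer/utils/Tensor.h"

#include <cstddef>
#include <vector>

namespace ft = fastertransformer;

template<typename T>
ft::Tensor wrap_memory(const T* memory_ptr,
                       size_t batch_size,
                       size_t beam_width,
                       size_t mem_max_seq_len,
                       size_t mem_hidden_dim)
{
    const ft::DataType data_type = ft::getTensorType<T>();
    return ft::Tensor{ft::MEMORY_GPU,
                      data_type,
                      std::vector<size_t>{batch_size * beam_width, mem_max_seq_len, mem_hidden_dim},
                      memory_ptr};
}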
self_kernel_q, @@ -289,10 +294,10 @@ class FasterTransformerDecoding: public torch::jit::CustomClassHolder { std::vector get_pickle_info() const; private: - const at::ScalarType _st; + const at::ScalarType _st; torch_ext::IFTDecoding* ftdecoding; - th::Tensor int_info_; - th::Tensor float_info_; + th::Tensor int_info_; + th::Tensor float_info_; std::vector weights; }; diff --git a/src/fastertransformer/th_op/decoding/GatherTreeOp.cc b/src/fastertransformer/th_op/decoding/GatherTreeOp.cc index 24925d48d..4cfe167ef 100644 --- a/src/fastertransformer/th_op/decoding/GatherTreeOp.cc +++ b/src/fastertransformer/th_op/decoding/GatherTreeOp.cc @@ -31,11 +31,11 @@ gather_tree(th::Tensor step_ids, th::Tensor parent_ids, th::Tensor max_sequence_ CHECK_TH_CUDA(max_sequence_lengths); CHECK_CONTIGUOUS(max_sequence_lengths); TORCH_CHECK(max_sequence_lengths.dtype() == th::kInt32, "max_sequence_lengths dtype should be int32"); - int max_step = step_ids.size(0); - int batch_size = step_ids.size(1); - int beam_width = step_ids.size(2); - auto beams = th::empty_like(step_ids); - auto stream = at::cuda::getCurrentCUDAStream().stream(); + int max_step = step_ids.size(0); + int batch_size = step_ids.size(1); + int beam_width = step_ids.size(2); + auto beams = th::empty_like(step_ids); + auto stream = at::cuda::getCurrentCUDAStream().stream(); fastertransformer::invokeGatherTree(torch_ext::get_ptr(beams), torch_ext::get_ptr(max_sequence_lengths), diff --git a/src/fastertransformer/th_op/encoder/EncoderOp.cc b/src/fastertransformer/th_op/encoder/EncoderOp.cc index 3734e2e50..501ec8e9f 100644 --- a/src/fastertransformer/th_op/encoder/EncoderOp.cc +++ b/src/fastertransformer/th_op/encoder/EncoderOp.cc @@ -37,14 +37,14 @@ FasterTransformerEncoder::FasterTransformerEncoder(th::Tensor pre_attr_layernorm th::Tensor output_bias, th::Tensor post_transformer_layernorm_gamma, th::Tensor post_transformer_layernorm_beta, - int64_t head_num, - int64_t head_size, - int64_t inter_size, - bool remove_padding, - int64_t layer_num, - bool allow_gemm_test, - bool sparse, - double q_scaling): + int64_t head_num, + int64_t head_size, + int64_t inter_size, + bool remove_padding, + int64_t layer_num, + bool allow_gemm_test, + bool sparse, + double q_scaling): _st(q_kernel.scalar_type()), _remove_padding(remove_padding), weights{pre_attr_layernorm_gamma, @@ -94,18 +94,24 @@ FasterTransformerEncoder::FasterTransformerEncoder(th::Tensor pre_attr_layernorm ftencoder = new FTEncoder( head_num, head_size, inter_size, layer_num, allow_gemm_test, sparse, q_scaling, weights); break; +#ifdef ENABLE_BF16 + case at::ScalarType::BFloat16: + ftencoder = new FTEncoder<__nv_bfloat16>( + head_num, head_size, inter_size, layer_num, allow_gemm_test, sparse, q_scaling, weights); + break; +#endif default: throw std::runtime_error("Wrong Tensor type."); } - head_info = torch::empty({7}, torch::dtype(torch::kInt64)); - head_info[0] = head_num; - head_info[1] = head_size; - head_info[2] = (int64_t)remove_padding; - head_info[3] = layer_num; - head_info[4] = (int64_t)allow_gemm_test; - head_info[5] = (int64_t)sparse; - head_info[6] = inter_size; - scaling_info = torch::empty({1}, torch::dtype(torch::kFloat64)); + head_info = torch::empty({7}, torch::dtype(torch::kInt64)); + head_info[0] = head_num; + head_info[1] = head_size; + head_info[2] = (int64_t)remove_padding; + head_info[3] = layer_num; + head_info[4] = (int64_t)allow_gemm_test; + head_info[5] = (int64_t)sparse; + head_info[6] = inter_size; + scaling_info = torch::empty({1}, 
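gather_tree infers all of its sizes from step_ids, which is laid out as [max_step, batch_size, beam_width] in int32, and writes the re-ordered beams into a tensor allocated with empty_like. A toy shape-handling sketch of that convention; the actual re-ordering is done by the FT invokeGatherTree kernel, which is not reproduced here:

#include <torch/torch.h>

// step_ids: [max_step, batch_size, beam_width], int32, on the GPU.
torch::Tensor alloc_beams(const torch::Tensor& step_ids)
{
    TORCH_CHECK(step_ids.dim() == 3, "step_ids should be [max_step, batch, beam]");
    TORCH_CHECK(step_ids.dtype() == torch::kInt32, "step_ids dtype should be int32");
    return torch::empty_like(step_ids);  // same shape, dtype and device as step_ids
}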
torch::dtype(torch::kFloat64)); scaling_info[0] = (double)q_scaling; } @@ -121,7 +127,7 @@ th::Tensor FasterTransformerEncoder::forward(th::Tensor input, th::Tensor sequen CHECK_CONTIGUOUS(sequence_lengths); TORCH_CHECK(sequence_lengths.dtype() == torch::kInt32, "sequence_lengths dtype should be int32"); size_t batch_size = (size_t)input.size(0); - size_t seq_len = (size_t)input.size(1); + size_t seq_len = (size_t)input.size(1); auto output = torch::empty_like(input); ftencoder->forward(batch_size, seq_len, input, sequence_lengths, output, _remove_padding); @@ -176,14 +182,14 @@ static auto fasterTransformerEncoderTHS = return self->get_pickle_info(); }, [](std::vector state) -> c10::intrusive_ptr { - int64_t head_num = state[18][0].item().to(); - int64_t head_size = state[18][1].item().to(); - bool remove_padding = (bool)(state[18][2].item().to()); - int64_t layer_num = state[18][3].item().to(); - bool allow_gemm_test = (bool)(state[18][4].item().to()); - bool sparse = (bool)(state[18][5].item().to()); - int64_t inter_size = state[18][6].item().to(); - double q_scaling = state[19][0].item().to(); + int64_t head_num = state[18][0].item().to(); + int64_t head_size = state[18][1].item().to(); + bool remove_padding = (bool)(state[18][2].item().to()); + int64_t layer_num = state[18][3].item().to(); + bool allow_gemm_test = (bool)(state[18][4].item().to()); + bool sparse = (bool)(state[18][5].item().to()); + int64_t inter_size = state[18][6].item().to(); + double q_scaling = state[19][0].item().to(); return c10::make_intrusive(state[0], state[1], state[2], diff --git a/src/fastertransformer/th_op/encoder/EncoderOp.h b/src/fastertransformer/th_op/encoder/EncoderOp.h index 4331c36da..eb65e73ca 100644 --- a/src/fastertransformer/th_op/encoder/EncoderOp.h +++ b/src/fastertransformer/th_op/encoder/EncoderOp.h @@ -24,24 +24,24 @@ namespace torch_ext { class IFEncoder { public: virtual ~IFEncoder() {} - virtual void forward(size_t batch_size, - size_t seq_len, + virtual void forward(size_t batch_size, + size_t seq_len, th::Tensor& input, th::Tensor& sequence_lengths, th::Tensor& output, - bool removing_padding) = 0; + bool removing_padding) = 0; }; template class FTEncoder: public IFEncoder { public: - FTEncoder(int head_num, - int head_size, - int inter_size, - int layer_num, - bool allow_gemm_test, - bool sparse, - float q_scaling, + FTEncoder(int head_num, + int head_size, + int inter_size, + int layer_num, + bool allow_gemm_test, + bool sparse, + float q_scaling, const std::vector& w): _head_num(head_num), _head_size(head_size), @@ -65,8 +65,8 @@ class FTEncoder: public IFEncoder { } #endif std::string sp_config_fname = sparse ? 
"spgemm_config.in" : ""; - cublas_algo_map_ = new ft::cublasAlgoMap("gemm_config.in", sp_config_fname); - cublas_wrapper_mutex_ = new std::mutex(); + cublas_algo_map_ = new ft::cublasAlgoMap("gemm_config.in", sp_config_fname); + cublas_wrapper_mutex_ = new std::mutex(); encoder_weights.bert_layer_weights.clear(); encoder_weights.bert_layer_weights.resize(_layer_num); @@ -105,10 +105,10 @@ class FTEncoder: public IFEncoder { get_ptr(_weights[15]) + hidden_dim * i; } encoder_weights.post_transformer_layernorm_weights.gamma = get_ptr(_weights[16]); - encoder_weights.post_transformer_layernorm_weights.beta = get_ptr(_weights[17]); + encoder_weights.post_transformer_layernorm_weights.beta = get_ptr(_weights[17]); #ifdef SPARSITY_ENABLED if (sparse) { - auto stream = at::cuda::getCurrentCUDAStream().stream(); + auto stream = at::cuda::getCurrentCUDAStream().stream(); cublasHandle_t _cublasHandle = at::cuda::getCurrentCUDABlasHandle(); cublasSetStream(_cublasHandle, stream); ft::cublasMMWrapper cublas_wrapper = ft::cublasMMWrapper(_cublasHandle, @@ -137,18 +137,18 @@ class FTEncoder: public IFEncoder { delete cublas_wrapper_mutex_; } - void forward(size_t batch_size, - size_t seq_len, + void forward(size_t batch_size, + size_t seq_len, th::Tensor& input, th::Tensor& sequence_lengths, th::Tensor& output, - bool removing_padding) override + bool removing_padding) override { - auto stream = at::cuda::getCurrentCUDAStream().stream(); + auto stream = at::cuda::getCurrentCUDAStream().stream(); cublasHandle_t _cublasHandle = at::cuda::getCurrentCUDABlasHandle(); cublasSetStream(_cublasHandle, stream); ft::Allocator* allocator = new ft::Allocator(); - ft::cublasMMWrapper* cublas_wrapper = + ft::cublasMMWrapper* cublas_wrapper = #ifdef SPARSITY_ENABLED new ft::cublasMMWrapper(_cublasHandle, _cublasltHandle, @@ -165,6 +165,11 @@ class FTEncoder: public IFEncoder { if (std::is_same::value) { cublas_wrapper->setFP16GemmConfig(); } +#ifdef ENABLE_BF16 + else if (std::is_same::value) { + cublas_wrapper->setBF16GemmConfig(); + } +#endif else if (std::is_same::value) { cublas_wrapper->setFP32GemmConfig(); } @@ -182,13 +187,13 @@ class FTEncoder: public IFEncoder { stream, cublas_wrapper, allocator, - true, + false, attention_type, _sparse, ft::ActivationType::Relu, ft::LayerNormType::pre_layernorm); - ft::DataType data_type = ft::getTensorType(); + ft::DataType data_type = ft::getTensorType(); std::vector input_tensors = std::vector{ ft::Tensor{ft::MEMORY_GPU, data_type, @@ -220,21 +225,21 @@ class FTEncoder: public IFEncoder { } private: - const int _head_num; - const int _head_size; - const int _inter_size; + const int _head_num; + const int _head_size; + const int _inter_size; std::vector _weights; - const int _layer_num; - bool _sparse; - const float _q_scaling; - int sm_; - cublasLtHandle_t _cublasltHandle; + const int _layer_num; + bool _sparse; + const float _q_scaling; + int sm_; + cublasLtHandle_t _cublasltHandle; #ifdef SPARSITY_ENABLED cusparseLtHandle_t _cusparseLtHandle; #endif - std::mutex* cublas_wrapper_mutex_; + std::mutex* cublas_wrapper_mutex_; ft::cublasAlgoMap* cublas_algo_map_; - ft::BertWeight encoder_weights; + ft::BertWeight encoder_weights; }; class FasterTransformerEncoder: public th::jit::CustomClassHolder { @@ -257,14 +262,14 @@ class FasterTransformerEncoder: public th::jit::CustomClassHolder { th::Tensor output_bias, th::Tensor post_transformer_layernorm_gamma, th::Tensor post_transformer_layernorm_beta, - int64_t head_num, - int64_t head_size, - int64_t inter_size, - bool 
remove_padding, - int64_t layer_num, - bool allow_gemm_test, - bool sparse, - double q_scaling); + int64_t head_num, + int64_t head_size, + int64_t inter_size, + bool remove_padding, + int64_t layer_num, + bool allow_gemm_test, + bool sparse, + double q_scaling); ~FasterTransformerEncoder(); @@ -273,13 +278,13 @@ class FasterTransformerEncoder: public th::jit::CustomClassHolder { std::vector get_pickle_info() const; private: - const at::ScalarType _st; - bool _remove_padding; - IFEncoder* ftencoder; - th::Tensor head_info; - th::Tensor scaling_info; + const at::ScalarType _st; + bool _remove_padding; + IFEncoder* ftencoder; + th::Tensor head_info; + th::Tensor scaling_info; std::vector weights; - bool _allow_gemm_test; + bool _allow_gemm_test; }; } // namespace torch_ext diff --git a/src/fastertransformer/th_op/gpt/GptOp.cc b/src/fastertransformer/th_op/gpt/GptOp.cc index 8690a8b9d..dedb87c2e 100644 --- a/src/fastertransformer/th_op/gpt/GptOp.cc +++ b/src/fastertransformer/th_op/gpt/GptOp.cc @@ -17,16 +17,23 @@ #include "src/fastertransformer/th_op/gpt/GptOp.h" namespace th = torch; +namespace ft = fastertransformer; namespace torch_ext { -GptOp::GptOp(const int64_t head_num, - const int64_t size_per_head, - const int64_t inter_size, - const int64_t layer_num, - const int64_t vocab_size, - const int64_t start_id, - const int64_t end_id, - const bool sparse, +GptOp::GptOp(const int64_t head_num, + const int64_t size_per_head, + const int64_t inter_size, + const int64_t layer_num, + const int64_t vocab_size, + const int64_t start_id, + const int64_t end_id, + const bool sparse, + const double layernorm_eps, + const std::string layernorm_type, + const std::string activation_type, + const bool has_post_decoder_layernorm, + const bool has_adapters, + const int64_t adapter_inter_size, const std::vector weights): st_(weights[0].scalar_type()) { @@ -34,13 +41,21 @@ GptOp::GptOp(const int64_t head_num, CHECK_INPUT(t, st_); } + ft::gptVariantParams gpt_variant_params{(float)layernorm_eps, + ft::getLayerNormType(layernorm_type), + ft::getActivationType(activation_type), + has_post_decoder_layernorm, + has_adapters, + (size_t)adapter_inter_size}; + switch (st_) { case at::ScalarType::Float: ftgpt = new FTGpt((size_t)head_num, (size_t)size_per_head, (size_t)inter_size, (size_t)layer_num, - vocab_size, + (size_t)vocab_size, + gpt_variant_params, start_id, end_id, sparse, @@ -52,6 +67,7 @@ GptOp::GptOp(const int64_t head_num, (size_t)inter_size, (size_t)layer_num, (size_t)vocab_size, + gpt_variant_params, start_id, end_id, sparse, @@ -64,6 +80,7 @@ GptOp::GptOp(const int64_t head_num, (size_t)inter_size, (size_t)layer_num, (size_t)vocab_size, + gpt_variant_params, start_id, end_id, sparse, @@ -80,18 +97,18 @@ GptOp::~GptOp() delete ftgpt; } -std::vector GptOp::forward(th::Tensor input_ids, - th::Tensor input_lengths, - const int64_t output_len, - const int64_t beam_width, - const int64_t top_k, - const double top_p, - const double beam_search_diversity_rate, - const double temperature, - const double len_penalty, - const double repetition_penalty, - const int64_t random_seed, - const int64_t return_cum_log_probs) +std::vector GptOp::forward(th::Tensor input_ids, + th::Tensor input_lengths, + const int64_t output_len, + th::optional beam_width_opt, + th::optional top_k_opt, + th::optional top_p_opt, + th::optional beam_search_diversity_rate_opt, + th::optional temperature_opt, + th::optional len_penalty_opt, + th::optional repetition_penalty_opt, + th::optional random_seed_opt, + th::optional 
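GptOp now folds the six new constructor arguments into a single ft::gptVariantParams, converting the layernorm and activation names from strings with the helpers this diff calls. A hedged sketch of that construction; the field order follows the brace initializer above, and the example strings are illustrative:

// Assumes the FasterTransformer headers already pulled in by GptOp.h
// (ft::gptVariantParams, ft::getLayerNormType, ft::getActivationType).
#include <cstdint>
#include <string>

namespace ft = fastertransformer;

ft::gptVariantParams make_variant_params(double layernorm_eps,
                                         const std::string& layernorm_type,   // e.g. "pre_layernorm"
                                         const std::string& activation_type,  // e.g. "gelu"
                                         bool has_post_decoder_layernorm,
                                         bool has_adapters,
                                         int64_t adapter_inter_size)
{
    return ft::gptVariantParams{(float)layernorm_eps,
                                ft::getLayerNormType(layernorm_type),
                                ft::getActivationType(activation_type),
                                has_post_decoder_layernorm,
                                has_adapters,
                                (size_t)adapter_inter_size};
}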
return_cum_log_probs_opt) { CHECK_TH_CUDA(input_ids); CHECK_CONTIGUOUS(input_ids); @@ -99,18 +116,21 @@ std::vector GptOp::forward(th::Tensor input_ids, CHECK_TH_CUDA(input_lengths); CHECK_CONTIGUOUS(input_lengths); TORCH_CHECK(input_lengths.dtype() == torch::kInt32, "input_lengths dtype should be int32"); - TORCH_CHECK(return_cum_log_probs == 0 || return_cum_log_probs == 1 || return_cum_log_probs == 2, - "return_cum_log_probs should be" - " 0 (no return cum_log_probs), " - " 1 (the cumulative log probs of generated sequences), or" - " 2 (the cumulative log probs of sequences).") + int64_t return_cum_log_probs = return_cum_log_probs_opt.has_value() ? (int64_t)return_cum_log_probs_opt.value() : 0; + if (return_cum_log_probs_opt.has_value()) { + TORCH_CHECK(return_cum_log_probs == 0 || return_cum_log_probs == 1 || return_cum_log_probs == 2, + "return_cum_log_probs should be" + " 0 (no return cum_log_probs), " + " 1 (the cumulative log probs of generated sequences), or" + " 2 (the cumulative log probs of sequences).") + } - const int batch_size = input_ids.size(0); - const int max_input_length = input_ids.size(1); - const int total_request_output_len = max_input_length + output_len; - th::Tensor output_ids = torch::empty({batch_size, beam_width, total_request_output_len}, - torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); - th::Tensor parent_ids = torch::empty({total_request_output_len, batch_size, beam_width}, + const int beam_width = beam_width_opt.has_value() ? (int)beam_width_opt.value() : 1; + + const int batch_size = input_ids.size(0); + const int max_input_length = input_ids.size(1); + const int total_request_output_len = max_input_length + output_len; + th::Tensor output_ids = torch::empty({batch_size, beam_width, total_request_output_len}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); th::Tensor sequence_lengths = torch::empty({batch_size, beam_width}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); @@ -120,19 +140,18 @@ std::vector GptOp::forward(th::Tensor input_ids, ftgpt->forward(input_ids, input_lengths, output_ids, - parent_ids, sequence_lengths, cum_log_probs, (const size_t)output_len, (const size_t)beam_width, - (const size_t)top_k, - (const float)top_p, - (const float)beam_search_diversity_rate, - (const float)temperature, - (const float)len_penalty, - (const float)repetition_penalty, - (const unsigned long long int)random_seed, - return_cum_log_probs); + top_k_opt, + top_p_opt, + beam_search_diversity_rate_opt, + temperature_opt, + len_penalty_opt, + repetition_penalty_opt, + random_seed_opt, + return_cum_log_probs_opt); if (return_cum_log_probs > 0) { return std::vector{output_ids, sequence_lengths, cum_log_probs}; } @@ -142,7 +161,24 @@ std::vector GptOp::forward(th::Tensor input_ids, } // namespace torch_ext static auto fasterTransformerGptTHS = +#ifdef LEGACY_THS + torch::jit::class_("FasterTransformerGptOp") +#else torch::jit::class_("FasterTransformer", "GptOp") - .def(torch::jit:: - init>()) +#endif + .def(torch::jit::init>()) .def("forward", &torch_ext::GptOp::forward); diff --git a/src/fastertransformer/th_op/gpt/GptOp.h b/src/fastertransformer/th_op/gpt/GptOp.h index 2f931cdd9..89b39c734 100644 --- a/src/fastertransformer/th_op/gpt/GptOp.h +++ b/src/fastertransformer/th_op/gpt/GptOp.h @@ -28,41 +28,42 @@ using std::vector; class IFGpt { public: virtual ~IFGpt() {} - virtual void forward(th::Tensor& input_ids, - th::Tensor& input_lengths, - th::Tensor& output_ids, - th::Tensor& parent_ids, - 
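The registration at the bottom of GptOp.cc keeps two spellings behind LEGACY_THS: older Torch releases only accepted a single flat class name, newer ones take a namespace plus a class name. A reduced sketch of that pattern for a hypothetical holder class; MyOp is illustrative, and the init<...> parameter pack has to match the constructor exactly, as it does in the diff:

#include <torch/script.h>

struct MyOp: torch::jit::CustomClassHolder {
    explicit MyOp(int64_t n): n_(n) {}
    int64_t forward(int64_t x) { return x + n_; }
    int64_t n_;
};

static auto myOpRegistration =
#ifdef LEGACY_THS
    torch::jit::class_<MyOp>("MyOp")
#else
    torch::jit::class_<MyOp>("FasterTransformer", "MyOp")
#endif
        .def(torch::jit::init<int64_t>())
        .def("forward", &MyOp::forward);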
th::Tensor& sequence_lengths, - th::Tensor& cum_log_probs, - const size_t request_output_len, - const size_t beam_width, - const size_t top_k, - const float top_p, - const float beam_search_diversity_rate, - const float temperature, - const float len_penalty, - const float repetition_penalty, - const unsigned long long int random_seed, - const int return_cum_log_probs = 0) = 0; + virtual void forward(th::Tensor& input_ids, + th::Tensor& input_lengths, + th::Tensor& output_ids, + th::Tensor& sequence_lengths, + th::Tensor& cum_log_probs, + const size_t request_output_len, + const size_t beam_width, + th::optional top_k_opt, + th::optional top_p_opt, + th::optional beam_search_diversity_rate_opt, + th::optional temperature_opt, + th::optional len_penalty_opt, + th::optional repetition_penalty_opt, + th::optional random_seed_opt, + th::optional return_cum_log_probs_opt) = 0; }; template class FTGpt: public IFGpt { public: - FTGpt(const size_t head_num, - const size_t size_per_head, - const size_t inter_size, - const size_t layer_num, - const size_t vocab_size, - const int start_id, - const int end_id, - const bool sparse, - const vector weights): + FTGpt(const size_t head_num, + const size_t size_per_head, + const size_t inter_size, + const size_t layer_num, + const size_t vocab_size, + const ft::gptVariantParams gpt_variant_params, + const int start_id, + const int end_id, + const bool sparse, + const vector weights): head_num_(head_num), size_per_head_(size_per_head), inter_size_(inter_size), layer_num_(layer_num), vocab_size_(vocab_size), + gpt_variant_params_(gpt_variant_params), start_id_(start_id), end_id_(end_id), #ifndef SPARSITY_ENABLED @@ -84,8 +85,8 @@ class FTGpt: public IFGpt { } std::string sp_config_fname = sparse ? SPGEMM_CONFIG : ""; - cublas_algo_map_ = new ft::cublasAlgoMap(GEMM_CONFIG, sp_config_fname); - cublas_wrapper_mutex_ = new std::mutex(); + cublas_algo_map_ = new ft::cublasAlgoMap(GEMM_CONFIG, sp_config_fname); + cublas_wrapper_mutex_ = new std::mutex(); gpt_weights_.resizeLayer(layer_num_); for (int i = 0; i < (int)layer_num_; i++) { @@ -115,14 +116,39 @@ class FTGpt: public IFGpt { get_ptr(weights_[i + 11 * layer_num_]); } - gpt_weights_.post_decoder_layernorm.gamma = get_ptr(weights_[12 * layer_num_ + 0]); - gpt_weights_.post_decoder_layernorm.beta = get_ptr(weights_[12 * layer_num_ + 1]); - gpt_weights_.position_encoding_table = get_ptr(weights_[12 * layer_num_ + 2]); - gpt_weights_.pre_decoder_embedding_table = get_ptr(weights_[12 * layer_num_ + 3]); - gpt_weights_.post_decoder_embedding.kernel = get_ptr(weights_[12 * layer_num_ + 4]); + size_t weight_offset = gpt_variant_params_.has_post_decoder_layernorm ? 
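The weight loading above relies on a fixed flat layout: the 12 per-layer tensors are packed as weights[kind * layer_num + layer], the global tensors follow at 12 * layer_num, and when the variant has no post-decoder layernorm those two slots are simply absent, so every later index shifts by weight_offset = 2. A small index-computation sketch of that convention; the helper names are illustrative, the arithmetic matches the diff:

#include <cstddef>

// weights[kind * layer_num + layer] for the 12 per-layer kinds (kind in [0, 12)).
size_t per_layer_index(size_t kind, size_t layer, size_t layer_num)
{
    return kind * layer_num + layer;
}

// Global tensors start at 12 * layer_num: slots 0/1 are the post-decoder layernorm
// gamma/beta (only present for that variant), slot 2 the position encoding table,
// slot 3 the pre-decoder embedding table, slot 4 the post-decoder embedding kernel.
size_t global_index(size_t slot, size_t layer_num, bool has_post_decoder_layernorm)
{
    const size_t weight_offset = has_post_decoder_layernorm ? 0 : 2;
    return 12 * layer_num + slot - weight_offset;
}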
0 : 2; + if (gpt_variant_params_.has_post_decoder_layernorm) { + gpt_weights_.post_decoder_layernorm.gamma = get_ptr(weights_[12 * layer_num_ + 0]); + gpt_weights_.post_decoder_layernorm.beta = get_ptr(weights_[12 * layer_num_ + 1]); + } + gpt_weights_.position_encoding_table = get_ptr(weights_[12 * layer_num_ + 2 - weight_offset]); + gpt_weights_.setMaxSeqLen(weights_[12 * layer_num_ + 2 - weight_offset].size(0)); + gpt_weights_.pre_decoder_embedding_table = get_ptr(weights_[12 * layer_num_ + 3 - weight_offset]); + gpt_weights_.post_decoder_embedding.kernel = get_ptr(weights_[12 * layer_num_ + 4 - weight_offset]); + + if (gpt_variant_params_.has_adapters) { + for (int i = 0; i < (int)layer_num_; i++) { + gpt_weights_.decoder_layer_weights[i]->after_attention_adapter_weights.intermediate_weight.kernel = + get_ptr(weights_[12 * layer_num_ + 4 - weight_offset + i + 1]); + gpt_weights_.decoder_layer_weights[i]->after_attention_adapter_weights.intermediate_weight.bias = + get_ptr(weights_[13 * layer_num_ + 4 - weight_offset + i + 1]); + gpt_weights_.decoder_layer_weights[i]->after_attention_adapter_weights.output_weight.kernel = + get_ptr(weights_[14 * layer_num_ + 4 - weight_offset + i + 1]); + gpt_weights_.decoder_layer_weights[i]->after_attention_adapter_weights.output_weight.bias = + get_ptr(weights_[15 * layer_num_ + 4 - weight_offset + i + 1]); + gpt_weights_.decoder_layer_weights[i]->after_ffn_adapter_weights.intermediate_weight.kernel = + get_ptr(weights_[16 * layer_num_ + 4 - weight_offset + i + 1]); + gpt_weights_.decoder_layer_weights[i]->after_ffn_adapter_weights.intermediate_weight.bias = + get_ptr(weights_[17 * layer_num_ + 4 - weight_offset + i + 1]); + gpt_weights_.decoder_layer_weights[i]->after_ffn_adapter_weights.output_weight.kernel = + get_ptr(weights_[18 * layer_num_ + 4 - weight_offset + i + 1]); + gpt_weights_.decoder_layer_weights[i]->after_ffn_adapter_weights.output_weight.bias = + get_ptr(weights_[19 * layer_num_ + 4 - weight_offset + i + 1]); + } + } #ifdef SPARSITY_ENABLED if (sparse_) { - auto stream = at::cuda::getCurrentCUDAStream().stream(); + auto stream = at::cuda::getCurrentCUDAStream().stream(); cublasHandle_t cublas_handle = at::cuda::getCurrentCUDABlasHandle(); cublasSetStream(cublas_handle, stream); ft::cublasMMWrapper cublas_wrapper = ft::cublasMMWrapper(cublas_handle, @@ -134,7 +160,7 @@ class FTGpt: public IFGpt { nullptr); // Here we need to pass hidden_units to compress weights as sparse BERT did, // because GptWeights has no proper attribute value - like num_layer, dummy hidden_units, - // or inter_size. Let me udpate an initalization of GptWeights in future. + // or inter_size. Let me update an initialization of GptWeights in future. 
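When has_adapters is set, eight more per-layer tensors follow the globals: attention-adapter intermediate/output kernel and bias, then the same four for the FFN adapter. The diff addresses them as weights[(12 + k) * layer_num + 4 - weight_offset + i + 1] for k in 0..7. A one-function sketch of that index, with the same arithmetic:

#include <cstddef>

// k: which adapter tensor (0..7), i: layer index.
size_t adapter_index(size_t k, size_t i, size_t layer_num, bool has_post_decoder_layernorm)
{
    const size_t weight_offset = has_post_decoder_layernorm ? 0 : 2;
    return (12 + k) * layer_num + 4 - weight_offset + i + 1;
}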
int hidden_units = head_num_ * size_per_head_; for (size_t i = 0; i < layer_num_; ++i) { gpt_weights_.decoder_layer_weights[i]->compress_weights(cublas_wrapper, hidden_units); @@ -153,28 +179,29 @@ class FTGpt: public IFGpt { delete cublas_wrapper_mutex_; } - void forward(th::Tensor& input_ids, - th::Tensor& input_lengths, - th::Tensor& output_ids, - th::Tensor& parent_ids, - th::Tensor& sequence_lengths, - th::Tensor& cum_log_probs, - const size_t request_output_len, - const size_t beam_width, - const size_t top_k, - const float top_p, - const float beam_search_diversity_rate, - const float temperature, - const float len_penalty, - const float repetition_penalty, - const unsigned long long int random_seed, - const int return_cum_log_probs = 0) override + void forward(th::Tensor& input_ids, + th::Tensor& input_lengths, + th::Tensor& output_ids, + th::Tensor& sequence_lengths, + th::Tensor& cum_log_probs, + const size_t request_output_len, + const size_t beam_width, + th::optional top_k_opt, + th::optional top_p_opt, + th::optional beam_search_diversity_rate_opt, + th::optional temperature_opt, + th::optional len_penalty_opt, + th::optional repetition_penalty_opt, + th::optional random_seed_opt, + th::optional return_cum_log_probs_opt) override { - auto stream = at::cuda::getCurrentCUDAStream().stream(); + int return_cum_log_probs = return_cum_log_probs_opt.has_value() ? (int)return_cum_log_probs_opt.value() : 0; + + auto stream = at::cuda::getCurrentCUDAStream().stream(); cublasHandle_t cublasHandle = at::cuda::getCurrentCUDABlasHandle(); cublasSetStream(cublasHandle, stream); - ft::Allocator allocator = ft::Allocator(); - ft::cublasMMWrapper cublas_wrapper = ft::cublasMMWrapper(cublasHandle, + ft::Allocator allocator = ft::Allocator(); + ft::cublasMMWrapper cublas_wrapper = ft::cublasMMWrapper(cublasHandle, cublasltHandle_, #ifdef SPARSITY_ENABLED cusparseLtHandle_, @@ -197,13 +224,13 @@ class FTGpt: public IFGpt { } const size_t request_batch_size = (size_t)input_ids.size(0); - const size_t max_input_length = (size_t)input_ids.size(1); - const int total_output_len = (int)(max_input_length + request_output_len); + const size_t max_input_length = (size_t)input_ids.size(1); + const int total_output_len = (int)(max_input_length + request_output_len); ft::NcclParam tensor_para; ft::NcclParam pipeline_para; - ft::ParallelGpt gpt = ft::ParallelGpt(request_batch_size, + ft::ParallelGpt gpt = ft::ParallelGpt(request_batch_size, total_output_len, max_input_length, beam_width, @@ -214,13 +241,16 @@ class FTGpt: public IFGpt { vocab_size_, start_id_, end_id_, - beam_search_diversity_rate, - top_k, - top_p, - random_seed, - temperature, - len_penalty, - repetition_penalty, + end_id_ + 1, // p/prompt tuning virtual token start id + ft::PromptLearningType::no_prompt, + gpt_variant_params_, // gpt variant params --> meta opt + 0.0f, // beam_search_diversity_rate, + 1, // top_k, + 0.0, // top_p, + 0, // random_seed, + 1.0f, // temperature, + 1.0f, // len_penalty, + 1.0f, // repetition_penalty, tensor_para, pipeline_para, stream, @@ -230,6 +260,7 @@ class FTGpt: public IFGpt { &prop_, sparse_, 0); + std::vector output_seq_len(request_batch_size, total_output_len); std::unordered_map input_tensors = std::unordered_map{ {"input_ids", @@ -240,32 +271,39 @@ class FTGpt: public IFGpt { {"input_lengths", ft::Tensor{ ft::MEMORY_GPU, ft::TYPE_INT32, std::vector{request_batch_size}, get_ptr(input_lengths)}}, - {"max_output_seq_len", - ft::Tensor{ft::MEMORY_CPU, ft::TYPE_INT32, std::vector{1}, &total_output_len}}}; 
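Instead of a single max_output_seq_len scalar, the rewritten forward() passes a per-request output_seq_len buffer, here filled with the same value max_input_length + request_output_len for every request and exposed to FT as a CPU UINT32 tensor. A sketch of preparing that buffer, assuming uint32_t elements to match the TYPE_UINT32 descriptor:

#include <cstdint>
#include <vector>

std::vector<uint32_t> make_output_seq_len(size_t request_batch_size,
                                          size_t max_input_length,
                                          size_t request_output_len)
{
    // One target length per request; the op uses the same value for all of them.
    return std::vector<uint32_t>(request_batch_size,
                                 (uint32_t)(max_input_length + request_output_len));
}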
- if (top_k == 0 && top_p == 0.0f) { - ft::FT_CHECK(beam_width > 1); + {"output_seq_len", + ft::Tensor{ + ft::MEMORY_CPU, ft::TYPE_UINT32, std::vector{request_batch_size}, output_seq_len.data()}}}; + if (beam_width > 1 && beam_search_diversity_rate_opt.has_value()) { input_tensors.insert( {"beam_search_diversity_rate", - ft::Tensor{ft::MEMORY_CPU, ft::TYPE_FP32, std::vector{1}, &beam_search_diversity_rate}}); + convert_tensor(beam_search_diversity_rate_opt.value(), ft::MemoryType::MEMORY_CPU)}); } - else { - if (top_p != 0.0f) { - input_tensors.insert( - {"runtime_top_p", ft::Tensor{ft::MEMORY_CPU, ft::TYPE_FP32, std::vector{1}, &top_p}}); - } - if (top_k != 0) { - input_tensors.insert( - {"runtime_top_k", ft::Tensor{ft::MEMORY_CPU, ft::TYPE_INT32, std::vector{1}, &top_k}}); - } + if (top_p_opt.has_value()) { + input_tensors.insert( + {"runtime_top_p", convert_tensor(top_p_opt.value(), ft::MemoryType::MEMORY_CPU)}); + } + if (top_k_opt.has_value()) { + input_tensors.insert( + {"runtime_top_k", convert_tensor(top_k_opt.value(), ft::MemoryType::MEMORY_CPU)}); + } + if (temperature_opt.has_value()) { + input_tensors.insert( + {"temperature", convert_tensor(temperature_opt.value(), ft::MemoryType::MEMORY_CPU)}); + } + if (len_penalty_opt.has_value()) { + input_tensors.insert( + {"len_penalty", convert_tensor(len_penalty_opt.value(), ft::MemoryType::MEMORY_CPU)}); + } + if (repetition_penalty_opt.has_value()) { + input_tensors.insert({"repetition_penalty", + convert_tensor(repetition_penalty_opt.value(), ft::MemoryType::MEMORY_CPU)}); + } + if (random_seed_opt.has_value()) { + input_tensors.insert( + {"random_seed", + convert_tensor(random_seed_opt.value(), ft::MemoryType::MEMORY_CPU)}); } - input_tensors.insert( - {"temperature", ft::Tensor{ft::MEMORY_CPU, ft::TYPE_FP32, std::vector{1}, &temperature}}); - input_tensors.insert( - {"len_penalty", ft::Tensor{ft::MEMORY_CPU, ft::TYPE_FP32, std::vector{1}, &len_penalty}}); - input_tensors.insert({"repetition_penalty", - ft::Tensor{ft::MEMORY_CPU, ft::TYPE_FP32, std::vector{1}, &repetition_penalty}}); - input_tensors.insert( - {"random_seed", ft::Tensor{ft::MEMORY_CPU, ft::TYPE_UINT64, std::vector{1}, &random_seed}}); bool return_context_cum_log_probs = false; if (return_cum_log_probs == 2) { @@ -281,11 +319,6 @@ class FTGpt: public IFGpt { ft::TYPE_INT32, std::vector{request_batch_size, beam_width, (size_t)total_output_len}, get_ptr(output_ids)}}, - {"parent_ids", - ft::Tensor{ft::MEMORY_GPU, - ft::TYPE_INT32, - std::vector{(size_t)total_output_len, request_batch_size, beam_width}, - get_ptr(parent_ids)}}, {"sequence_length", ft::Tensor{ft::MEMORY_GPU, ft::TYPE_INT32, @@ -319,52 +352,60 @@ class FTGpt: public IFGpt { const size_t inter_size_; const size_t layer_num_; const size_t vocab_size_; - const int start_id_; - const int end_id_; - const bool sparse_; + const int start_id_; + const int end_id_; + const bool sparse_; + + const ft::gptVariantParams gpt_variant_params_; std::vector weights_; - cublasLtHandle_t cublasltHandle_; + cublasLtHandle_t cublasltHandle_; #ifdef SPARSITY_ENABLED cusparseLtHandle_t cusparseLtHandle_; - bool is_spmm_compressed = false; + bool is_spmm_compressed = false; #endif - std::mutex* cublas_wrapper_mutex_; - ft::cublasAlgoMap* cublas_algo_map_; - struct cudaDeviceProp prop_; + std::mutex* cublas_wrapper_mutex_; + ft::cublasAlgoMap* cublas_algo_map_; + struct cudaDeviceProp prop_; ft::ParallelGptWeight gpt_weights_; }; class GptOp: public th::jit::CustomClassHolder { public: - GptOp(const int64_t head_num, - const 
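Each sampling knob is now forwarded only when the caller actually supplied it, so the defaults live inside ParallelGpt rather than in the op wrapper. A reduced sketch of that conditional-insert pattern using a plain map; convert_tensor and the ft::Tensor input map are replaced by simple stand-ins here:

#include <string>
#include <unordered_map>
#include <c10/util/Optional.h>

// Only knobs the caller provided end up in the request map.
void add_runtime_knobs(std::unordered_map<std::string, double>& inputs,
                       c10::optional<double> top_p,
                       c10::optional<double> temperature,
                       c10::optional<double> repetition_penalty)
{
    if (top_p.has_value()) {
        inputs.insert({"runtime_top_p", top_p.value()});
    }
    if (temperature.has_value()) {
        inputs.insert({"temperature", temperature.value()});
    }
    if (repetition_penalty.has_value()) {
        inputs.insert({"repetition_penalty", repetition_penalty.value()});
    }
}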
int64_t size_per_head, - const int64_t inter_size, - const int64_t layer_num, - const int64_t vocab_size, - const int64_t start_id, - const int64_t end_id, - const bool sparse, + GptOp(const int64_t head_num, + const int64_t size_per_head, + const int64_t inter_size, + const int64_t layer_num, + const int64_t vocab_size, + const int64_t start_id, + const int64_t end_id, + const bool sparse, + const double layernorm_eps, + const std::string layernorm_type, + const std::string activation_type, + const bool has_post_decoder_layernorm, + const bool has_adapters, + const int64_t adapter_inter_size, const vector weights); ~GptOp(); - vector forward(th::Tensor input_ids, - th::Tensor input_lengths, - const int64_t output_len, - const int64_t beam_width, - const int64_t top_k, - const double top_p, - const double beam_search_diversity_rate, - const double temperature, - const double len_penalty, - const double repetition_penalty, - const int64_t random_seed, - const int64_t return_cum_log_probs = 0); + vector forward(th::Tensor input_ids, + th::Tensor input_lengths, + const int64_t output_len, + th::optional beam_width_opt, + th::optional top_k_opt, + th::optional top_p_opt, + th::optional beam_search_diversity_rate_opt, + th::optional temperature_opt, + th::optional len_penalty_opt, + th::optional repetition_penalty_opt, + th::optional random_seed_opt, + th::optional return_cum_log_probs_opt); private: - const at::ScalarType st_; - IFGpt* ftgpt; + const at::ScalarType st_; + IFGpt* ftgpt; std::vector weights; }; diff --git a/src/fastertransformer/th_op/longformer/LongformerEncoderOp.cc b/src/fastertransformer/th_op/longformer/LongformerEncoderOp.cc index 05ade7e88..ccd19d345 100644 --- a/src/fastertransformer/th_op/longformer/LongformerEncoderOp.cc +++ b/src/fastertransformer/th_op/longformer/LongformerEncoderOp.cc @@ -30,7 +30,7 @@ FasterTransformerLongformerEncoder::FasterTransformerLongformerEncoder(int64_t l int64_t max_global_token_num, int64_t max_batch_size, int64_t max_seq_len, - double attn_scaler): + double attn_scaler): layer_num_(layer_num), in_dim_(in_dim), head_num_(head_num), @@ -44,7 +44,7 @@ FasterTransformerLongformerEncoder::FasterTransformerLongformerEncoder(int64_t l hidden_units_(head_num * size_per_head) { ft::check_cuda_error(cublasLtCreate(&_cublasltHandle)); - cublas_algo_map_ = new ft::cublasAlgoMap("gemm_config.in"); + cublas_algo_map_ = new ft::cublasAlgoMap("gemm_config.in"); cublas_wrapper_mutex_ = new std::mutex(); } @@ -66,13 +66,13 @@ th::Tensor FasterTransformerLongformerEncoder::forward( ft::check_cuda_error(cudaSetDevice(device_id)); int batch_size = input.size(0); - int seq_len = input.size(1); - int in_dim_ = input.size(2); + int seq_len = input.size(1); + int in_dim_ = input.size(2); auto options = th::TensorOptions().dtype(scalar_type).device(torch::kCUDA, device_id); - auto output = th::zeros({batch_size, seq_len, (int64_t)hidden_units_}, options); + auto output = th::zeros({batch_size, seq_len, (int64_t)hidden_units_}, options); - auto stream = at::cuda::getCurrentCUDAStream().stream(); + auto stream = at::cuda::getCurrentCUDAStream().stream(); cublasHandle_t _cublasHandle = at::cuda::getCurrentCUDABlasHandle(); cublasSetStream(_cublasHandle, stream); @@ -88,6 +88,11 @@ th::Tensor FasterTransformerLongformerEncoder::forward( else if (scalar_type == at::ScalarType::Half) { data_type = ft::TYPE_FP16; } +#ifdef ENABLE_BF16 + else if (scalar_type == at::ScalarType::BFloat16) { + data_type = ft::TYPE_BF16; + } +#endif else { throw std::runtime_error("Wrong 
Tensor type."); } @@ -156,6 +161,30 @@ th::Tensor FasterTransformerLongformerEncoder::forward( ft::check_cuda_error(cudaStreamSynchronize(stream)); delete encoder; } +#ifdef ENABLE_BF16 + else if (scalar_type == at::ScalarType::BFloat16) { + cublas_wrapper->setBF16GemmConfig(); + auto encoder = new ft::LongformerEncoder<__nv_bfloat16>(layer_num_, + in_dim_, + head_num_, + size_per_head_, + intermediate_size_, + local_attn_window_size_, + max_global_token_num_, + max_batch_size_, + max_seq_len_, + attn_scaler_, + stream, + cublas_wrapper, + allocator, + false); + setWeight<__nv_bfloat16>( + layer_num_, in_dim_, hidden_units_, intermediate_size_, th_weights, encoder->getWeightsPtr()); + encoder->forward(&output_tensors, &input_tensors); + ft::check_cuda_error(cudaStreamSynchronize(stream)); + delete encoder; + } +#endif delete cublas_wrapper; delete allocator; diff --git a/src/fastertransformer/th_op/longformer/LongformerEncoderOp.h b/src/fastertransformer/th_op/longformer/LongformerEncoderOp.h index 68e388d4d..8e376ad26 100644 --- a/src/fastertransformer/th_op/longformer/LongformerEncoderOp.h +++ b/src/fastertransformer/th_op/longformer/LongformerEncoderOp.h @@ -32,12 +32,12 @@ class FasterTransformerLongformerEncoder: public th::jit::CustomClassHolder { size_t max_global_token_num_; size_t max_batch_size_; size_t max_seq_len_; - float attn_scaler_; + float attn_scaler_; size_t hidden_units_; - cublasLtHandle_t _cublasltHandle; + cublasLtHandle_t _cublasltHandle; ft::cublasAlgoMap* cublas_algo_map_; - std::mutex* cublas_wrapper_mutex_; + std::mutex* cublas_wrapper_mutex_; public: FasterTransformerLongformerEncoder(int64_t layer_num, @@ -49,7 +49,7 @@ class FasterTransformerLongformerEncoder: public th::jit::CustomClassHolder { int64_t max_global_token_num, int64_t max_batch_size, int64_t max_seq_len, - double attn_scaler); + double attn_scaler); ~FasterTransformerLongformerEncoder(); @@ -57,22 +57,22 @@ class FasterTransformerLongformerEncoder: public th::jit::CustomClassHolder { th::Tensor local_attn_mask, th::Tensor global_attn_mask, th::Tensor th_weights, - int64_t device_id = 0); + int64_t device_id = 0); template - void setWeight(int layer_num, - int in_dim, - int hidden_units, - int intermediate_size, - th::Tensor th_weights, + void setWeight(int layer_num, + int in_dim, + int hidden_units, + int intermediate_size, + th::Tensor th_weights, std::vector>* weights) { weights->clear(); weights->resize(layer_num); auto weights_ptr = get_ptr(th_weights); - int offside = 0; + int offside = 0; for (int i = 0; i < layer_num; i++) { - // q k v kg vg weights and bias should be continous, required by the ft longformer encoder. + // q k v kg vg weights and bias should be continuous, required by the ft longformer encoder. weights->at(i).query_weights.kernel = weights_ptr + offside; // q offside += (i == 0 ? in_dim : hidden_units) * hidden_units; weights->at(i).key_weights.kernel = weights_ptr + offside; // k diff --git a/src/fastertransformer/th_op/multi_gpu_gpt/CMakeLists.txt b/src/fastertransformer/th_op/multi_gpu_gpt/CMakeLists.txt index 88c4e58fe..d4cceee65 100644 --- a/src/fastertransformer/th_op/multi_gpu_gpt/CMakeLists.txt +++ b/src/fastertransformer/th_op/multi_gpu_gpt/CMakeLists.txt @@ -13,4 +13,5 @@ # limitations under the License. 
add_library(th_parallel_gpt SHARED ParallelGptOp.cc WeightTransposeCalibrateQuantizeOp.cc) -target_link_libraries(th_parallel_gpt PRIVATE "${TORCH_LIBRARIES}" ParallelGpt th_utils calibrate_quantize_weight_kernels) +target_link_libraries(th_parallel_gpt PRIVATE "${TORCH_LIBRARIES}" + ParallelGpt th_utils calibrate_quantize_weight_kernels nccl_utils) diff --git a/src/fastertransformer/th_op/multi_gpu_gpt/ParallelGptOp.cc b/src/fastertransformer/th_op/multi_gpu_gpt/ParallelGptOp.cc index 498959372..1fff8f2b2 100644 --- a/src/fastertransformer/th_op/multi_gpu_gpt/ParallelGptOp.cc +++ b/src/fastertransformer/th_op/multi_gpu_gpt/ParallelGptOp.cc @@ -18,27 +18,42 @@ #include "src/fastertransformer/th_op/multi_gpu_gpt/WeightTransposeCalibrateQuantizeOp.h" namespace th = torch; +namespace ft = fastertransformer; namespace torch_ext { -ParallelGptOp::ParallelGptOp(const int64_t head_num, - const int64_t size_per_head, - const int64_t inter_size, - const int64_t layer_num, - const int64_t vocab_size, - const int64_t start_id, - const int64_t end_id, - const int64_t tensor_para_size, - const int64_t pipeline_para_size, - const int64_t int8_mode, +ParallelGptOp::ParallelGptOp(const int64_t head_num, + const int64_t size_per_head, + const int64_t inter_size, + const int64_t layer_num, + const int64_t vocab_size, + const int64_t start_id, + const int64_t end_id, + const int64_t tensor_para_size, + const int64_t pipeline_para_size, + const int64_t int8_mode, + const double layernorm_eps, + const std::string layernorm_type, + const std::string activation_type, + const bool has_post_decoder_layernorm, + const bool has_adapters, + const int64_t adapter_inter_size, const std::vector weights, const std::vector int8_weights, - const std::vector scale): + const std::vector scale, + const double shared_contexts_ratio): st_(weights[0].scalar_type()) { for (auto t : weights) { CHECK_INPUT(t, st_); } + ft::gptVariantParams gpt_variant_params{(float)layernorm_eps, + ft::getLayerNormType(layernorm_type), + ft::getActivationType(activation_type), + has_post_decoder_layernorm, + has_adapters, + (size_t)adapter_inter_size}; + switch (st_) { case at::ScalarType::Float: ftgpt = new FTGpt((size_t)head_num, @@ -46,6 +61,7 @@ ParallelGptOp::ParallelGptOp(const int64_t head_num, (size_t)inter_size, (size_t)layer_num, (size_t)vocab_size, + gpt_variant_params, start_id, end_id, tensor_para_size, @@ -53,7 +69,8 @@ ParallelGptOp::ParallelGptOp(const int64_t head_num, int8_mode, weights, int8_weights, - scale); + scale, + shared_contexts_ratio); break; case at::ScalarType::Half: ftgpt = new FTGpt((size_t)head_num, @@ -61,6 +78,7 @@ ParallelGptOp::ParallelGptOp(const int64_t head_num, (size_t)inter_size, (size_t)layer_num, (size_t)vocab_size, + gpt_variant_params, start_id, end_id, tensor_para_size, @@ -68,7 +86,8 @@ ParallelGptOp::ParallelGptOp(const int64_t head_num, int8_mode, weights, int8_weights, - scale); + scale, + shared_contexts_ratio); break; #ifdef ENABLE_BF16 case at::ScalarType::BFloat16: @@ -77,6 +96,7 @@ ParallelGptOp::ParallelGptOp(const int64_t head_num, (size_t)inter_size, (size_t)layer_num, (size_t)vocab_size, + gpt_variant_params, start_id, end_id, tensor_para_size, @@ -84,7 +104,8 @@ ParallelGptOp::ParallelGptOp(const int64_t head_num, int8_mode, weights, int8_weights, - scale); + scale, + shared_contexts_ratio); break; #endif default: @@ -97,18 +118,18 @@ ParallelGptOp::~ParallelGptOp() delete ftgpt; } -std::vector ParallelGptOp::forward(th::Tensor input_ids, - th::Tensor input_lengths, - const int64_t 
output_len, - const int64_t beam_width, - const int64_t top_k, - const double top_p, - const double beam_search_diversity_rate, - const double temperature, - const double len_penalty, - const double repetition_penalty, - const int64_t random_seed, - const int64_t return_cum_log_probs) +std::vector ParallelGptOp::forward(th::Tensor input_ids, + th::Tensor input_lengths, + const int64_t output_len, + th::optional beam_width_opt, + th::optional top_k_opt, + th::optional top_p_opt, + th::optional beam_search_diversity_rate_opt, + th::optional temperature_opt, + th::optional len_penalty_opt, + th::optional repetition_penalty_opt, + th::optional random_seed_opt, + th::optional return_cum_log_probs_opt) { CHECK_TH_CUDA(input_ids); CHECK_CONTIGUOUS(input_ids); @@ -116,18 +137,21 @@ std::vector ParallelGptOp::forward(th::Tensor input_ids, CHECK_TH_CUDA(input_lengths); CHECK_CONTIGUOUS(input_lengths); TORCH_CHECK(input_lengths.dtype() == torch::kInt32, "input_lengths dtype should be int32"); - TORCH_CHECK(return_cum_log_probs == 0 || return_cum_log_probs == 1 || return_cum_log_probs == 2, - "return_cum_log_probs should be" - " 0 (no return cum_log_probs), " - " 1 (the cumulative log probs of generated sequences), or" - " 2 (the cumulative log probs of sequences).") + int64_t return_cum_log_probs = return_cum_log_probs_opt.has_value() ? (int64_t)return_cum_log_probs_opt.value() : 0; + if (return_cum_log_probs_opt.has_value()) { + TORCH_CHECK(return_cum_log_probs == 0 || return_cum_log_probs == 1 || return_cum_log_probs == 2, + "return_cum_log_probs should be" + " 0 (no return cum_log_probs), " + " 1 (the cumulative log probs of generated sequences), or" + " 2 (the cumulative log probs of sequences).") + } - const int batch_size = input_ids.size(0) / beam_width; - const int max_input_length = input_ids.size(1); - const int total_request_output_len = max_input_length + output_len; - th::Tensor output_ids = torch::empty({batch_size, beam_width, total_request_output_len}, - torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); - th::Tensor parent_ids = torch::empty({total_request_output_len, batch_size, beam_width}, + const int beam_width = beam_width_opt.has_value() ? 
(int)beam_width_opt.value() : 1; + + const int batch_size = input_ids.size(0); + const int max_input_length = input_ids.size(1); + const int total_request_output_len = max_input_length + output_len; + th::Tensor output_ids = torch::empty({batch_size, beam_width, total_request_output_len}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); th::Tensor sequence_lengths = torch::empty({batch_size, beam_width}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); @@ -137,19 +161,18 @@ std::vector ParallelGptOp::forward(th::Tensor input_ids, ftgpt->forward(input_ids, input_lengths, output_ids, - parent_ids, sequence_lengths, cum_log_probs, (const size_t)output_len, (const size_t)beam_width, - (const size_t)top_k, - (const float)top_p, - (const float)beam_search_diversity_rate, - (const float)temperature, - (const float)len_penalty, - (const float)repetition_penalty, - (const unsigned long long int)random_seed, - return_cum_log_probs); + top_k_opt, + top_p_opt, + beam_search_diversity_rate_opt, + temperature_opt, + len_penalty_opt, + repetition_penalty_opt, + random_seed_opt, + return_cum_log_probs_opt); if (return_cum_log_probs > 0) { return std::vector{output_ids, sequence_lengths, cum_log_probs}; } @@ -158,21 +181,33 @@ std::vector ParallelGptOp::forward(th::Tensor input_ids, } // namespace torch_ext -static auto fasterTransformerGptTHS = torch::jit::class_("FasterTransformer", "ParallelGptOp") - .def(torch::jit::init, - std::vector, - std::vector>()) - .def("forward", &torch_ext::ParallelGptOp::forward); +static auto fasterTransformerGptTHS = +#ifdef LEGACY_THS + torch::jit::class_("FasterTransformerParallelGptOp") +#else + torch::jit::class_("FasterTransformer", "ParallelGptOp") +#endif + .def(torch::jit::init, + std::vector, + std::vector, + double>()) + .def("forward", &torch_ext::ParallelGptOp::forward); static auto weight_transpose_calibrate_quantize = torch::RegisterOperators( "fastertransformer::weight_transpose_calibrate_quantize", &torch_ext::weight_transpose_calibrate_quantize); diff --git a/src/fastertransformer/th_op/multi_gpu_gpt/ParallelGptOp.h b/src/fastertransformer/th_op/multi_gpu_gpt/ParallelGptOp.h index 38a88388c..3bedd2807 100644 --- a/src/fastertransformer/th_op/multi_gpu_gpt/ParallelGptOp.h +++ b/src/fastertransformer/th_op/multi_gpu_gpt/ParallelGptOp.h @@ -18,7 +18,7 @@ #include "src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.h" #include "src/fastertransformer/th_op/th_utils.h" #include "src/fastertransformer/utils/cuda_bf16_wrapper.h" -#include "src/fastertransformer/utils/mpi_utils.h" +#include "src/fastertransformer/utils/nccl_utils.h" namespace ft = fastertransformer; namespace th = torch; @@ -29,44 +29,46 @@ using std::vector; class IFGpt { public: virtual ~IFGpt() {} - virtual void forward(th::Tensor& input_ids, - th::Tensor& input_lengths, - th::Tensor& output_ids, - th::Tensor& parent_ids, - th::Tensor& sequence_lengths, - th::Tensor& cum_log_probs, - const size_t request_output_len, - const size_t beam_width, - const size_t top_k, - const float top_p, - const float beam_search_diversity_rate, - const float temperature, - const float len_penalty, - const float repetition_penalty, - const unsigned long long int random_seed, - const int return_cum_log_probs = 0) = 0; + virtual void forward(th::Tensor& input_ids, + th::Tensor& input_lengths, + th::Tensor& output_ids, + th::Tensor& sequence_lengths, + th::Tensor& cum_log_probs, + const size_t request_output_len, + const size_t beam_width, + th::optional top_k_opt, + 
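With the optional beam_width the caller now passes un-tiled input_ids, so batch_size is simply input_ids.size(0), and the result buffers are allocated as [batch, beam, max_input + output_len] int32 CUDA tensors. A sketch of that allocation, mirroring the lines above (requires a CUDA-enabled libtorch):

#include <torch/torch.h>

torch::Tensor alloc_output_ids(const torch::Tensor& input_ids, int64_t output_len, int64_t beam_width)
{
    const int64_t batch_size       = input_ids.size(0);
    const int64_t max_input_length = input_ids.size(1);
    const int64_t total_len        = max_input_length + output_len;
    return torch::empty({batch_size, beam_width, total_len},
                        torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false));
}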
th::optional top_p_opt, + th::optional beam_search_diversity_rate_opt, + th::optional temperature_opt, + th::optional len_penalty_opt, + th::optional repetition_penalty_opt, + th::optional random_seed_opt, + th::optional return_cum_log_probs_opt) = 0; }; template class FTGpt: public IFGpt { public: - FTGpt(const size_t head_num, - const size_t size_per_head, - const size_t inter_size, - const size_t layer_num, - const size_t vocab_size, - const int start_id, - const int end_id, - const int tensor_para_size, - const int pipeline_para_size, - const int int8_mode, - const vector weights, - const vector int8_weights, - const vector scale): + FTGpt(const size_t head_num, + const size_t size_per_head, + const size_t inter_size, + const size_t layer_num, + const size_t vocab_size, + const ft::gptVariantParams gpt_variant_params, + const int start_id, + const int end_id, + const int tensor_para_size, + const int pipeline_para_size, + const int int8_mode, + const vector weights, + const vector int8_weights, + const vector scale, + const float shared_contexts_ratio): head_num_(head_num), size_per_head_(size_per_head), inter_size_(inter_size), layer_num_(layer_num), + gpt_variant_params_(gpt_variant_params), vocab_size_(vocab_size), start_id_(start_id), end_id_(end_id), @@ -75,13 +77,14 @@ class FTGpt: public IFGpt { int8_mode_(int8_mode), weights_(weights), int8_weights_(int8_weights), - scale_(scale) + scale_(scale), + shared_contexts_ratio_(shared_contexts_ratio) { ft::check_cuda_error(cublasLtCreate(&cublasltHandle_)); - cublas_algo_map_ = new ft::cublasAlgoMap("gemm_config.in"); + cublas_algo_map_ = new ft::cublasAlgoMap("gemm_config.in"); cublas_wrapper_mutex_ = new std::mutex(); - init_nccl_comm(); + ftNcclInitialize(tensor_para_, pipeline_para_, tensor_para_size, pipeline_para_size); gpt_weights_.resizeLayer(layer_num_); @@ -131,11 +134,36 @@ class FTGpt: public IFGpt { } } - gpt_weights_.post_decoder_layernorm.gamma = get_ptr(weights_[12 * layer_num_ + 0]); - gpt_weights_.post_decoder_layernorm.beta = get_ptr(weights_[12 * layer_num_ + 1]); - gpt_weights_.position_encoding_table = get_ptr(weights_[12 * layer_num_ + 2]); - gpt_weights_.pre_decoder_embedding_table = get_ptr(weights_[12 * layer_num_ + 3]); - gpt_weights_.post_decoder_embedding.kernel = get_ptr(weights_[12 * layer_num_ + 4]); + size_t weight_offset = gpt_variant_params_.has_post_decoder_layernorm ? 
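The hand-rolled MPI uid exchange deleted just below is replaced by the two nccl_utils calls visible in this hunk: ftNcclInitialize fills a tensor-parallel and a pipeline-parallel ft::NcclParam, and ftNcclParamDestroy releases them in the destructor. A hedged lifecycle sketch; the header path is the one this diff switches ParallelGptOp.h to, the call signatures are as used in the diff, and the ft:: qualification is added here for a standalone snippet:

#include "src/fastertransformer/utils/nccl_utils.h"

namespace ft = fastertransformer;

struct ParallelComms {
    ft::NcclParam tensor_para;
    ft::NcclParam pipeline_para;
};

ParallelComms init_comms(int tensor_para_size, int pipeline_para_size)
{
    ParallelComms c;
    // Splits the world into tensor-parallel groups of size tensor_para_size and
    // pipeline-parallel groups of size pipeline_para_size.
    ft::ftNcclInitialize(c.tensor_para, c.pipeline_para, tensor_para_size, pipeline_para_size);
    return c;
}

void destroy_comms(ParallelComms& c)
{
    ft::ftNcclParamDestroy(c.tensor_para);
    ft::ftNcclParamDestroy(c.pipeline_para);
}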
0 : 2; + if (gpt_variant_params_.has_post_decoder_layernorm) { + gpt_weights_.post_decoder_layernorm.gamma = get_ptr(weights_[12 * layer_num_ + 0]); + gpt_weights_.post_decoder_layernorm.beta = get_ptr(weights_[12 * layer_num_ + 1]); + } + gpt_weights_.position_encoding_table = get_ptr(weights_[12 * layer_num_ + 2 - weight_offset]); + gpt_weights_.setMaxSeqLen(weights_[12 * layer_num_ + 2 - weight_offset].size(0)); + gpt_weights_.pre_decoder_embedding_table = get_ptr(weights_[12 * layer_num_ + 3 - weight_offset]); + gpt_weights_.post_decoder_embedding.kernel = get_ptr(weights_[12 * layer_num_ + 4 - weight_offset]); + + if (gpt_variant_params_.has_adapters) { + for (int i = 0; i < (int)layer_num_; i++) { + gpt_weights_.decoder_layer_weights[i]->after_attention_adapter_weights.intermediate_weight.kernel = + get_ptr(weights_[12 * layer_num_ + 4 - weight_offset + i + 1]); + gpt_weights_.decoder_layer_weights[i]->after_attention_adapter_weights.intermediate_weight.bias = + get_ptr(weights_[13 * layer_num_ + 4 - weight_offset + i + 1]); + gpt_weights_.decoder_layer_weights[i]->after_attention_adapter_weights.output_weight.kernel = + get_ptr(weights_[14 * layer_num_ + 4 - weight_offset + i + 1]); + gpt_weights_.decoder_layer_weights[i]->after_attention_adapter_weights.output_weight.bias = + get_ptr(weights_[15 * layer_num_ + 4 - weight_offset + i + 1]); + gpt_weights_.decoder_layer_weights[i]->after_ffn_adapter_weights.intermediate_weight.kernel = + get_ptr(weights_[16 * layer_num_ + 4 - weight_offset + i + 1]); + gpt_weights_.decoder_layer_weights[i]->after_ffn_adapter_weights.intermediate_weight.bias = + get_ptr(weights_[17 * layer_num_ + 4 - weight_offset + i + 1]); + gpt_weights_.decoder_layer_weights[i]->after_ffn_adapter_weights.output_weight.kernel = + get_ptr(weights_[18 * layer_num_ + 4 - weight_offset + i + 1]); + gpt_weights_.decoder_layer_weights[i]->after_ffn_adapter_weights.output_weight.bias = + get_ptr(weights_[19 * layer_num_ + 4 - weight_offset + i + 1]); + } + } int device_id = 0; ft::check_cuda_error(cudaGetDevice(&device_id)); @@ -145,120 +173,35 @@ class FTGpt: public IFGpt { ~FTGpt() override { - ncclCommDestroy(tensor_para_comm_); - ncclCommDestroy(pipeline_para_comm_); + ft::ftNcclParamDestroy(tensor_para_); + ft::ftNcclParamDestroy(pipeline_para_); cublasLtDestroy(cublasltHandle_); delete cublas_algo_map_; delete cublas_wrapper_mutex_; } - void init_nccl_comm() + void forward(th::Tensor& input_ids, + th::Tensor& input_lengths, + th::Tensor& output_ids, + th::Tensor& sequence_lengths, + th::Tensor& cum_log_probs, + const size_t request_output_len, + const size_t beam_width, + th::optional top_k_opt, + th::optional top_p_opt, + th::optional beam_search_diversity_rate_opt, + th::optional temperature_opt, + th::optional len_penalty_opt, + th::optional repetition_penalty_opt, + th::optional random_seed_opt, + th::optional return_cum_log_probs_opt) override { - MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &rank_)); - MPICHECK(MPI_Comm_size(MPI_COMM_WORLD, &world_size_)); - - int mpi_initialized; - MPICHECK(MPI_Initialized(&mpi_initialized)); - if (!mpi_initialized) { - FT_LOG_INFO("MPI is not initialized! 
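[Editor's note] The weight loading above becomes index arithmetic over one flat weight list: per-layer weights occupy slots [0, 12*layer_num), the shared weights follow, a variant without the post-decoder layernorm shifts everything after it back by two slots, and adapter weights (when enabled) are appended as eight further blocks of layer_num entries each. A standalone sketch of exactly that arithmetic; the helper names are illustrative, only the formulas are taken from the hunk:

#include <cassert>
#include <cstdio>

// slot 0..4 = post-layernorm gamma, post-layernorm beta, position table,
// pre-decoder embedding, post-decoder embedding kernel (matching the order above).
int shared_weight_index(int layer_num, int slot, bool has_post_decoder_layernorm)
{
    const int offset = has_post_decoder_layernorm ? 0 : 2;
    assert(!(slot < 2 && !has_post_decoder_layernorm));  // gamma/beta absent for this variant
    return 12 * layer_num + slot - offset;
}

// block 0..7 = the eight adapter weights per layer, matching the
// 12*L .. 19*L progression in the hunk.
int adapter_weight_index(int layer_num, int block, int layer, bool has_post_decoder_layernorm)
{
    const int offset = has_post_decoder_layernorm ? 0 : 2;
    return (12 + block) * layer_num + 4 - offset + layer + 1;
}

int main()
{
    std::printf("%d\n", shared_weight_index(24, 2, /*has_post_decoder_layernorm=*/false)); // 288
    std::printf("%d\n", adapter_weight_index(24, 0, 0, /*has_post_decoder_layernorm=*/true)); // 293
}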
Skipped the NCCL communication initialization.\n"); - if (tensor_para_size_ != 1) { - FT_LOG_ERROR("MPI initialization can only be skipped when tensor_para_size=1, but got %zu!\n", - tensor_para_size_); - } - if (pipeline_para_size_ != 1) { - FT_LOG_ERROR("MPI initialization can only be skipped when pipeline_para_size=1, but got %zu!\n", - pipeline_para_size_); - } - return; - } - - int rank = rank_; - tensor_para_rank_ = rank % tensor_para_size_; - pipeline_para_rank_ = rank / tensor_para_size_; - - ncclUniqueId tensor_para_nccl_uid; - ncclUniqueId pipeline_para_nccl_uid; - - // assume gpu_num = n * k, - // tensor parallelism group size is n - // pipeline parallelism group size is k - if (tensor_para_rank_ == 0) { - // get the uid of each tensor parallelism group - // here, 0, 1, ..., n-1 are in group 0, - // n, ..., 2n - 1 are in group 1. - NCCLCHECK(ncclGetUniqueId(&tensor_para_nccl_uid)); - for (int i = 1; i < (int)tensor_para_size_; i++) { - FT_LOG_INFO("rank %d sends tensor_para_nccl_uid to rank %d \n", rank, rank + i); - MPICHECK(MPI_Send( - &tensor_para_nccl_uid, sizeof(tensor_para_nccl_uid), MPI_BYTE, rank + i, 0, MPI_COMM_WORLD)); - } - } - else { - MPI_Status status; - FT_LOG_INFO("rank %d receives tensor_para_nccl_uid from rank %d \n", rank, rank - (int)tensor_para_rank_); - MPICHECK(MPI_Recv(&tensor_para_nccl_uid, - sizeof(tensor_para_nccl_uid), - MPI_BYTE, - rank - tensor_para_rank_, - 0, - MPI_COMM_WORLD, - &status)); - } - - if (pipeline_para_rank_ == 0) { - // get the uid of each pipeline parallelism group - // 0, k, 2k, are in group 0 - // 1, k+1, 2k+1 are in group 1 - NCCLCHECK(ncclGetUniqueId(&pipeline_para_nccl_uid)); - for (int i = 1; i < (int)pipeline_para_size_; i++) { - FT_LOG_INFO( - "rank %d sends pipeline_para_nccl_uid to rank %d \n", rank, rank + i * (int)tensor_para_size_); - MPICHECK(MPI_Send(&pipeline_para_nccl_uid, - sizeof(pipeline_para_nccl_uid), - MPI_BYTE, - rank + i * tensor_para_size_, - 0, - MPI_COMM_WORLD)); - } - } - else { - MPI_Status status; - FT_LOG_INFO("rank %d receives pipeline_para_nccl_uid from rank %d \n", rank, rank % (int)tensor_para_size_); - MPICHECK(MPI_Recv(&pipeline_para_nccl_uid, - sizeof(pipeline_para_nccl_uid), - MPI_BYTE, - rank % tensor_para_size_, - 0, - MPI_COMM_WORLD, - &status)); - } - NCCLCHECK(ncclCommInitRank(&tensor_para_comm_, tensor_para_size_, tensor_para_nccl_uid, tensor_para_rank_)); - NCCLCHECK( - ncclCommInitRank(&pipeline_para_comm_, pipeline_para_size_, pipeline_para_nccl_uid, pipeline_para_rank_)); - } - - void forward(th::Tensor& input_ids, - th::Tensor& input_lengths, - th::Tensor& output_ids, - th::Tensor& parent_ids, - th::Tensor& sequence_lengths, - th::Tensor& cum_log_probs, - const size_t request_output_len, - const size_t beam_width, - const size_t top_k, - const float top_p, - const float beam_search_diversity_rate, - const float temperature, - const float len_penalty, - const float repetition_penalty, - const unsigned long long int query_random_seed, - const int return_cum_log_probs = 0) override - { - auto stream = at::cuda::getCurrentCUDAStream().stream(); + int return_cum_log_probs = return_cum_log_probs_opt.has_value() ? 
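[Editor's note] The deleted init_nccl_comm() decomposed the flat MPI rank into a tensor-parallel rank and a pipeline-parallel rank before exchanging NCCL unique ids; ftNcclInitialize now hides that whole exchange. A standalone sketch of just the rank arithmetic that the removed code (and its comments about "gpu_num = n * k" groups) relied on:

#include <cstdio>

struct GroupRanks {
    int tensor_para_rank;    // position inside the tensor-parallel group
    int pipeline_para_rank;  // which tensor-parallel group this rank belongs to
};

GroupRanks split_rank(int world_rank, int tensor_para_size)
{
    return GroupRanks{world_rank % tensor_para_size, world_rank / tensor_para_size};
}

int main()
{
    // 8 GPUs, tensor_para_size = 4, pipeline_para_size = 2:
    // ranks 0..3 form tensor group 0, ranks 4..7 form tensor group 1,
    // while ranks {0,4}, {1,5}, {2,6}, {3,7} form the pipeline groups.
    for (int r = 0; r < 8; ++r) {
        GroupRanks g = split_rank(r, 4);
        std::printf("rank %d -> tp_rank %d, pp_rank %d\n", r, g.tensor_para_rank, g.pipeline_para_rank);
    }
}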
(int)return_cum_log_probs_opt.value() : 0; + auto stream = at::cuda::getCurrentCUDAStream().stream(); cublasHandle_t cublasHandle = at::cuda::getCurrentCUDABlasHandle(); cublasSetStream(cublasHandle, stream); - ft::Allocator allocator = ft::Allocator(); - ft::cublasMMWrapper cublas_wrapper = ft::cublasMMWrapper( + ft::Allocator allocator = ft::Allocator(); + ft::cublasMMWrapper cublas_wrapper = ft::cublasMMWrapper( cublasHandle, cublasltHandle_, stream, cublas_algo_map_, cublas_wrapper_mutex_, &allocator); if (std::is_same::value) { @@ -274,18 +217,10 @@ class FTGpt: public IFGpt { } const size_t request_batch_size = (size_t)input_ids.size(0) / beam_width; - const size_t max_input_length = (size_t)input_ids.size(1); - const int total_output_len = (int)(max_input_length + request_output_len); - - ft::NcclParam tensor_para(tensor_para_rank_, tensor_para_size_, tensor_para_comm_); - ft::NcclParam pipeline_para(pipeline_para_rank_, pipeline_para_size_, pipeline_para_comm_); - - unsigned long long int random_seed = query_random_seed; - if (world_size_ > 1) { - MPICHECK(MPI_Bcast(&random_seed, 1, MPI_UNSIGNED_LONG_LONG, 0, MPI_COMM_WORLD)); - } + const size_t max_input_length = (size_t)input_ids.size(1); + const int total_output_len = (int)(max_input_length + request_output_len); - ft::ParallelGpt gpt = ft::ParallelGpt(request_batch_size, + ft::ParallelGpt gpt = ft::ParallelGpt(request_batch_size, total_output_len, max_input_length, beam_width, @@ -296,22 +231,30 @@ class FTGpt: public IFGpt { vocab_size_, start_id_, end_id_, - beam_search_diversity_rate, - top_k, - top_p, - random_seed, - temperature, - len_penalty, - repetition_penalty, - tensor_para, - pipeline_para, + end_id_ + 1, // p/prompt tuning virtual token start id + ft::PromptLearningType::no_prompt, + gpt_variant_params_, + 0.0f, // beam_search_diversity_rate, + 1, // top_k, + 0.0, // top_p, + 0, // random_seed, + 1.0f, // temperature, + 1.0f, // len_penalty, + 1.0f, // repetition_penalty, + tensor_para_, + pipeline_para_, stream, &cublas_wrapper, &allocator, false, &prop_, false, - int8_mode_); + int8_mode_, + nullptr, + 0, + true, + shared_contexts_ratio_); + std::vector output_seq_len(request_batch_size, total_output_len); std::unordered_map input_tensors = std::unordered_map{ {"input_ids", @@ -322,32 +265,39 @@ class FTGpt: public IFGpt { {"input_lengths", ft::Tensor{ ft::MEMORY_GPU, ft::TYPE_INT32, std::vector{request_batch_size}, get_ptr(input_lengths)}}, - {"max_output_seq_len", - ft::Tensor{ft::MEMORY_CPU, ft::TYPE_INT32, std::vector{1}, &total_output_len}}}; - if (top_k == 0 && top_p == 0.0f) { - ft::FT_CHECK(beam_width > 1); + {"output_seq_len", + ft::Tensor{ + ft::MEMORY_CPU, ft::TYPE_UINT32, std::vector{request_batch_size}, output_seq_len.data()}}}; + if (beam_width > 1 && beam_search_diversity_rate_opt.has_value()) { input_tensors.insert( {"beam_search_diversity_rate", - ft::Tensor{ft::MEMORY_CPU, ft::TYPE_FP32, std::vector{1}, &beam_search_diversity_rate}}); + convert_tensor(beam_search_diversity_rate_opt.value(), ft::MemoryType::MEMORY_CPU)}); } - else { - if (top_p != 0.0f) { - input_tensors.insert( - {"runtime_top_p", ft::Tensor{ft::MEMORY_CPU, ft::TYPE_FP32, std::vector{1}, &top_p}}); - } - if (top_k != 0) { - input_tensors.insert( - {"runtime_top_k", ft::Tensor{ft::MEMORY_CPU, ft::TYPE_INT32, std::vector{1}, &top_k}}); - } + if (top_p_opt.has_value()) { + input_tensors.insert( + {"runtime_top_p", convert_tensor(top_p_opt.value(), ft::MemoryType::MEMORY_CPU)}); + } + if (top_k_opt.has_value()) { + 
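[Editor's note] In the forward path above, the single CPU scalar "max_output_seq_len" is replaced by a per-request "output_seq_len" tensor of length request_batch_size, while the sampling values hard-coded into the ParallelGpt constructor are only placeholders (the real values arrive as runtime input tensors). A minimal sketch of how that per-request buffer is filled, with a plain std::vector standing in for the CPU-side ft::Tensor:

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

int main()
{
    const std::size_t request_batch_size = 4;
    const std::size_t max_input_length   = 32;
    const std::size_t request_output_len = 96;
    const uint32_t    total_output_len   = static_cast<uint32_t>(max_input_length + request_output_len);

    // Every request gets the same total length here, but the tensor form
    // leaves room for per-request limits later.
    std::vector<uint32_t> output_seq_len(request_batch_size, total_output_len);
    for (std::size_t b = 0; b < output_seq_len.size(); ++b) {
        std::printf("request %zu: output_seq_len = %u\n", b, output_seq_len[b]);
    }
    // The vector's data pointer is what gets wrapped into the CPU-side
    // {"output_seq_len", ...} entry of the input tensor map.
}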
input_tensors.insert( + {"runtime_top_k", convert_tensor(top_k_opt.value(), ft::MemoryType::MEMORY_CPU)}); + } + if (temperature_opt.has_value()) { + input_tensors.insert( + {"temperature", convert_tensor(temperature_opt.value(), ft::MemoryType::MEMORY_CPU)}); + } + if (len_penalty_opt.has_value()) { + input_tensors.insert( + {"len_penalty", convert_tensor(len_penalty_opt.value(), ft::MemoryType::MEMORY_CPU)}); + } + if (repetition_penalty_opt.has_value()) { + input_tensors.insert({"repetition_penalty", + convert_tensor(repetition_penalty_opt.value(), ft::MemoryType::MEMORY_CPU)}); + } + if (random_seed_opt.has_value()) { + input_tensors.insert( + {"random_seed", + convert_tensor(random_seed_opt.value(), ft::MemoryType::MEMORY_CPU)}); } - input_tensors.insert( - {"temperature", ft::Tensor{ft::MEMORY_CPU, ft::TYPE_FP32, std::vector{1}, &temperature}}); - input_tensors.insert( - {"len_penalty", ft::Tensor{ft::MEMORY_CPU, ft::TYPE_FP32, std::vector{1}, &len_penalty}}); - input_tensors.insert({"repetition_penalty", - ft::Tensor{ft::MEMORY_CPU, ft::TYPE_FP32, std::vector{1}, &repetition_penalty}}); - input_tensors.insert( - {"random_seed", ft::Tensor{ft::MEMORY_CPU, ft::TYPE_UINT64, std::vector{1}, &random_seed}}); bool return_context_cum_log_probs = false; if (return_cum_log_probs == 2) { @@ -363,11 +313,6 @@ class FTGpt: public IFGpt { ft::TYPE_INT32, std::vector{request_batch_size, beam_width, (size_t)total_output_len}, get_ptr(output_ids)}}, - {"parent_ids", - ft::Tensor{ft::MEMORY_GPU, - ft::TYPE_INT32, - std::vector{(size_t)total_output_len, request_batch_size, beam_width}, - get_ptr(parent_ids)}}, {"sequence_length", ft::Tensor{ft::MEMORY_GPU, ft::TYPE_INT32, @@ -385,12 +330,16 @@ class FTGpt: public IFGpt { try { gpt.forward(&output_tensors, &input_tensors, &gpt_weights_); } - catch (std::runtime_error& error) { - std::cout << error.what() << std::endl; + catch (const std::runtime_error& error) { + FT_LOG_ERROR(error.what()); + ft::FT_CHECK(false); + } + catch (const std::exception& error) { + FT_LOG_ERROR(error.what()); ft::FT_CHECK(false); } catch (...) 
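[Editor's note] The old code inserted every sampling knob into the input map unconditionally; after this change a key such as "runtime_top_p" or "temperature" is only added when the caller actually supplied a value, so ParallelGpt falls back to its own defaults otherwise. A standalone sketch of that conditional-insert pattern, with std::optional and double standing in for the torch tensors and convert_tensor:

#include <cstdio>
#include <optional>
#include <string>
#include <unordered_map>

int main()
{
    std::unordered_map<std::string, double> input_tensors;

    std::optional<double> top_p_opt              = 0.9;
    std::optional<double> temperature_opt;               // not supplied by the caller
    std::optional<double> repetition_penalty_opt;        // not supplied by the caller

    auto maybe_insert = [&](const char* key, const std::optional<double>& v) {
        if (v.has_value()) {
            input_tensors.emplace(key, v.value());        // only present keys reach the model
        }
    };
    maybe_insert("runtime_top_p", top_p_opt);
    maybe_insert("temperature", temperature_opt);
    maybe_insert("repetition_penalty", repetition_penalty_opt);

    std::printf("keys present: %zu\n", input_tensors.size());  // 1
}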
{ - std::cout << "Runtime error" << std::endl; + FT_LOG_ERROR("Unknown error"); ft::FT_CHECK(false); } } @@ -401,66 +350,74 @@ class FTGpt: public IFGpt { const size_t inter_size_; const size_t layer_num_; const size_t vocab_size_; - const int start_id_; - const int end_id_; + const int start_id_; + const int end_id_; + const float shared_contexts_ratio_; const int int8_mode_ = 0; size_t tensor_para_size_; size_t pipeline_para_size_; + ft::gptVariantParams gpt_variant_params_; + std::vector int8_weights_; std::vector scale_; std::vector weights_; - size_t tensor_para_rank_; - ncclComm_t tensor_para_comm_; - size_t pipeline_para_rank_; - ncclComm_t pipeline_para_comm_; + ft::NcclParam tensor_para_; + ft::NcclParam pipeline_para_; - cublasLtHandle_t cublasltHandle_; - std::mutex* cublas_wrapper_mutex_; - ft::cublasAlgoMap* cublas_algo_map_; - struct cudaDeviceProp prop_; + cublasLtHandle_t cublasltHandle_; + std::mutex* cublas_wrapper_mutex_; + ft::cublasAlgoMap* cublas_algo_map_; + struct cudaDeviceProp prop_; ft::ParallelGptWeight gpt_weights_; - int world_size_ = 1; - int rank_ = 0; + int world_size_ = 1; + int rank_ = 0; }; class ParallelGptOp: public th::jit::CustomClassHolder { public: - ParallelGptOp(const int64_t head_num, - const int64_t size_per_head, - const int64_t inter_size, - const int64_t layer_num, - const int64_t vocab_size, - const int64_t start_id, - const int64_t end_id, - const int64_t tensor_para_size, - const int64_t pipeline_para_size, - const int64_t int8_mode, + ParallelGptOp(const int64_t head_num, + const int64_t size_per_head, + const int64_t inter_size, + const int64_t layer_num, + const int64_t vocab_size, + const int64_t start_id, + const int64_t end_id, + const int64_t tensor_para_size, + const int64_t pipeline_para_size, + const int64_t int8_mode, + const double layernorm_eps, + const std::string layernorm_type, + const std::string activation_type, + const bool has_post_decoder_layernorm, + const bool has_adapters, + const int64_t adapter_inter_size, const vector weights, const vector int8_weights, - const vector scale); + const vector scale, + const double shared_contexts_ratio); ~ParallelGptOp(); - vector forward(th::Tensor input_ids, - th::Tensor input_lengths, - const int64_t output_len, - const int64_t beam_width, - const int64_t top_k, - const double top_p, - const double beam_search_diversity_rate, - const double temperature, - const double len_penalty, - const double repetition_penalty, - const int64_t random_seed, - const int64_t return_cum_log_probs); + vector forward(th::Tensor input_ids, + th::Tensor input_lengths, + const int64_t output_len, + th::optional beam_width_opt, + th::optional top_k_opt, + th::optional top_p_opt, + th::optional beam_search_diversity_rate_opt, + th::optional temperature_opt, + th::optional len_penalty_opt, + th::optional repetition_penalty_opt, + th::optional random_seed_opt, + th::optional return_cum_log_probs_opt); private: - const at::ScalarType st_; - IFGpt* ftgpt; + const at::ScalarType st_; + IFGpt* ftgpt; std::vector weights; }; diff --git a/src/fastertransformer/th_op/multi_gpu_gpt/WeightTransposeCalibrateQuantizeOp.cc b/src/fastertransformer/th_op/multi_gpu_gpt/WeightTransposeCalibrateQuantizeOp.cc index bf187c75b..7ee3e2abd 100644 --- a/src/fastertransformer/th_op/multi_gpu_gpt/WeightTransposeCalibrateQuantizeOp.cc +++ b/src/fastertransformer/th_op/multi_gpu_gpt/WeightTransposeCalibrateQuantizeOp.cc @@ -38,7 +38,7 @@ void ldnTransposeQuantizeWeightPerChannel(int8_t* output, const float* scale, co for (int n_i 
= 0; n_i < n; n_i++) { float scale_val = scale[n_i]; for (int k_i = 0; k_i < k; k_i++) { - float val = weight[k_i * n + n_i]; + float val = weight[k_i * n + n_i]; output[n_i * k + k_i] = float_to_int8_rn_host(val / scale_val); } } @@ -62,8 +62,8 @@ std::vector weight_transpose_calibrate_quantize(Tensor weight) if (weight.device() == torch::kCUDA) { auto int8_weight = torch::empty({k * n}, torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false)); int8_t* int8_weight_out = get_ptr(int8_weight); - auto scale = torch::empty({n}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false)); - float* scale_out = get_ptr(scale); + auto scale = torch::empty({n}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false)); + float* scale_out = get_ptr(scale); auto stream = at::cuda::getCurrentCUDAStream().stream(); @@ -76,8 +76,8 @@ std::vector weight_transpose_calibrate_quantize(Tensor weight) else { auto int8_weight = torch::empty({k * n}, torch::dtype(torch::kInt8).device(torch::kCPU).requires_grad(false)); int8_t* int8_weight_out = get_ptr(int8_weight); - auto scale = torch::empty({n}, torch::dtype(torch::kFloat32).device(torch::kCPU).requires_grad(false)); - float* scale_out = get_ptr(scale); + auto scale = torch::empty({n}, torch::dtype(torch::kFloat32).device(torch::kCPU).requires_grad(false)); + float* scale_out = get_ptr(scale); ldnCalibrateWeightPerChannel(scale_out, weight_, k, n); ldnTransposeQuantizeWeightPerChannel(int8_weight_out, scale_out, weight_, k, n); diff --git a/src/fastertransformer/th_op/swin/SwinINT8Op.cc b/src/fastertransformer/th_op/swin/SwinINT8Op.cc index e6b8a7131..94eee72dd 100644 --- a/src/fastertransformer/th_op/swin/SwinINT8Op.cc +++ b/src/fastertransformer/th_op/swin/SwinINT8Op.cc @@ -20,21 +20,21 @@ namespace th = torch; namespace torch_ext { SwinTransformerINT8Class::SwinTransformerINT8Class(std::vector w, - int64_t int8_mode, - th::Tensor depths, - th::Tensor num_heads, - int64_t max_batch, - int64_t img_size, - int64_t patch_size, - int64_t in_chans, - int64_t embed_dim, - int64_t window_size, - bool ape, - bool patch_norm, - int64_t layer_num, - double mlp_ratio, - bool qkv_bias, - double qk_scale): + int64_t int8_mode, + th::Tensor depths, + th::Tensor num_heads, + int64_t max_batch, + int64_t img_size, + int64_t patch_size, + int64_t in_chans, + int64_t embed_dim, + int64_t window_size, + bool ape, + bool patch_norm, + int64_t layer_num, + double mlp_ratio, + bool qkv_bias, + double qk_scale): st_(w[0].scalar_type()), depths_(depths), num_heads_(num_heads), weights_(w) { @@ -92,19 +92,19 @@ SwinTransformerINT8Class::SwinTransformerINT8Class(std::vector w, default: throw std::runtime_error("Wrong Tensor type."); } - info_int_ = torch::empty({11}, torch::dtype(torch::kInt64)); - info_int_[0] = max_batch; - info_int_[1] = img_size; - info_int_[2] = patch_size; - info_int_[3] = in_chans; - info_int_[4] = embed_dim; - info_int_[5] = window_size; - info_int_[6] = (int64_t)ape; - info_int_[7] = (int64_t)patch_norm; - info_int_[8] = layer_num; - info_int_[9] = (int64_t)qkv_bias; - info_int_[10] = int8_mode; - info_float_ = torch::empty({2}, torch::dtype(torch::kFloat64)); + info_int_ = torch::empty({11}, torch::dtype(torch::kInt64)); + info_int_[0] = max_batch; + info_int_[1] = img_size; + info_int_[2] = patch_size; + info_int_[3] = in_chans; + info_int_[4] = embed_dim; + info_int_[5] = window_size; + info_int_[6] = (int64_t)ape; + info_int_[7] = (int64_t)patch_norm; + info_int_[8] = layer_num; + info_int_[9] = 
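[Editor's note] The loop above is the CPU path of the per-channel INT8 weight quantization: each output channel n_i gets its own scale, and the [k, n] weight is written out transposed as int8 [n, k]. The sketch below mirrors that loop; the calibration rule (per-channel amax / 127) is the usual symmetric choice and is an assumption here, not code copied from ldnCalibrateWeightPerChannel or float_to_int8_rn_host:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

void calibrate_per_channel(std::vector<float>& scale, const std::vector<float>& weight, int k, int n)
{
    for (int n_i = 0; n_i < n; ++n_i) {
        float amax = 0.f;
        for (int k_i = 0; k_i < k; ++k_i) {
            amax = std::max(amax, std::fabs(weight[k_i * n + n_i]));
        }
        scale[n_i] = amax > 0.f ? amax / 127.f : 1.f;   // assumed calibration rule
    }
}

void transpose_quantize_per_channel(std::vector<int8_t>& output,
                                    const std::vector<float>& scale,
                                    const std::vector<float>& weight, int k, int n)
{
    for (int n_i = 0; n_i < n; ++n_i) {
        const float scale_val = scale[n_i];
        for (int k_i = 0; k_i < k; ++k_i) {
            const float val = weight[k_i * n + n_i];              // read [k, n]
            const float q   = std::nearbyint(val / scale_val);    // round-to-nearest int8
            output[n_i * k + k_i] = static_cast<int8_t>(std::clamp(q, -128.f, 127.f));  // write [n, k]
        }
    }
}

int main()
{
    const int k = 2, n = 3;
    std::vector<float>  weight = {0.5f, -1.0f, 2.0f, 1.0f, 0.25f, -2.0f};  // row-major [k, n]
    std::vector<float>  scale(n);
    std::vector<int8_t> output(k * n);
    calibrate_per_channel(scale, weight, k, n);
    transpose_quantize_per_channel(output, scale, weight, k, n);
    std::printf("scale[0]=%f q[0]=%d\n", scale[0], output[0]);
}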
(int64_t)qkv_bias; + info_int_[10] = int8_mode; + info_float_ = torch::empty({2}, torch::dtype(torch::kFloat64)); info_float_[0] = mlp_ratio; info_float_[1] = qk_scale; } @@ -127,7 +127,7 @@ SwinTransformerINT8Class::~SwinTransformerINT8Class() th::Tensor SwinTransformerINT8Class::forward(th::Tensor input) { CHECK_INPUT(input, st_); - int batch_size = input.size(0); + int batch_size = input.size(0); auto output = torch::empty({batch_size, output_dim_}, torch::dtype(input.dtype()).device(torch::kCUDA).requires_grad(false)); swin_transformer_func_->forward(batch_size, input, output); @@ -137,11 +137,11 @@ th::Tensor SwinTransformerINT8Class::forward(th::Tensor input) } // namespace torch_ext static auto swinTransformerINT8THS = - // #ifdef LEGACY_THS - // torch::jit::class_("SwinTransformerINT8Class") - // #else +#ifdef LEGACY_THS + torch::jit::class_("SwinTransformerINT8Class") +#else torch::jit::class_("SwinTransformerINT8", "Class") - // #endif +#endif .def(torch::jit::init, int64_t, th::Tensor, @@ -164,26 +164,26 @@ static auto swinTransformerINT8THS = return self->get_pickle_info(); }, [](std::vector state) -> c10::intrusive_ptr { - int state_size = state.size(); - std::vector::const_iterator first = state.begin(); - std::vector::const_iterator last = state.begin() + (state_size - 4); - std::vector weights(first, last); - int idx = state.size() - 2; - int i = 0; - int64_t max_batch = state[idx][i++].item().to(); - int64_t img_size = state[idx][i++].item().to(); - int64_t patch_size = state[idx][i++].item().to(); - int64_t in_chans = state[idx][i++].item().to(); - int64_t embed_dim = state[idx][i++].item().to(); - int64_t window_size = state[idx][i++].item().to(); - bool ape = state[idx][i++].item().to(); - bool patch_norm = state[idx][i++].item().to(); - int64_t layer_num = state[idx][i++].item().to(); - bool qkv_bias = state[idx][i++].item().to(); - int64_t int8_mode = state[idx][i++].item().to(); - idx = state.size() - 1; - double mlp_ratio = state[idx][0].item().to(); - double qk_scale = state[idx][1].item().to(); + int state_size = state.size(); + std::vector::const_iterator first = state.begin(); + std::vector::const_iterator last = state.begin() + (state_size - 4); + std::vector weights(first, last); + int idx = state.size() - 2; + int i = 0; + int64_t max_batch = state[idx][i++].item().to(); + int64_t img_size = state[idx][i++].item().to(); + int64_t patch_size = state[idx][i++].item().to(); + int64_t in_chans = state[idx][i++].item().to(); + int64_t embed_dim = state[idx][i++].item().to(); + int64_t window_size = state[idx][i++].item().to(); + bool ape = state[idx][i++].item().to(); + bool patch_norm = state[idx][i++].item().to(); + int64_t layer_num = state[idx][i++].item().to(); + bool qkv_bias = state[idx][i++].item().to(); + int64_t int8_mode = state[idx][i++].item().to(); + idx = state.size() - 1; + double mlp_ratio = state[idx][0].item().to(); + double qk_scale = state[idx][1].item().to(); return c10::make_intrusive(weights, int8_mode, state[state_size - 4], diff --git a/src/fastertransformer/th_op/swin/SwinINT8Op.h b/src/fastertransformer/th_op/swin/SwinINT8Op.h index ea920fc0e..af5704afc 100644 --- a/src/fastertransformer/th_op/swin/SwinINT8Op.h +++ b/src/fastertransformer/th_op/swin/SwinINT8Op.h @@ -31,39 +31,39 @@ class ISwinTransformerINT8Func { template class SwinTransformerINT8Func: public ISwinTransformerINT8Func { public: - int sm_; - bool _use_ORDER_COL32_2R_4R4; - int int8_mode_; - int max_batch_; - int img_size_; - int patch_size_; - int in_chans_; - int 
embed_dim_; - int window_size_; - int* depths_; - int* num_heads_; - bool ape_; - bool patch_norm_; - int layer_num_; + int sm_; + bool _use_ORDER_COL32_2R_4R4; + int int8_mode_; + int max_batch_; + int img_size_; + int patch_size_; + int in_chans_; + int embed_dim_; + int window_size_; + int* depths_; + int* num_heads_; + bool ape_; + bool patch_norm_; + int layer_num_; float mlp_ratio_; - bool qkv_bias_; + bool qkv_bias_; float qk_scale_; - SwinTransformerINT8Func(const int int8_mode, - const int max_batch, - const int img_size, - const int patch_size, - const int in_chans, - const int embed_dim, - const int window_size, - int* depths, - int* num_heads, - const bool ape, - const bool patch_norm, - const int layer_num, - const float mlp_ratio, - const bool qkv_bias, - const float qk_scale, + SwinTransformerINT8Func(const int int8_mode, + const int max_batch, + const int img_size, + const int patch_size, + const int in_chans, + const int embed_dim, + const int window_size, + int* depths, + int* num_heads, + const bool ape, + const bool patch_norm, + const int layer_num, + const float mlp_ratio, + const bool qkv_bias, + const float qk_scale, const std::vector& w): weights_(w), int8_mode_(int8_mode), @@ -93,7 +93,7 @@ class SwinTransformerINT8Func: public ISwinTransformerINT8Func { _use_ORDER_COL32_2R_4R4 = true; } - cublas_algo_map_ = new ft::cublasAlgoMap(IGEMM_CONFIG, ""); + cublas_algo_map_ = new ft::cublasAlgoMap(IGEMM_CONFIG, ""); cublas_wrapper_mutex_ = new std::mutex(); // We arrange weights layer by layer and block by block inside each layer; @@ -121,41 +121,41 @@ class SwinTransformerINT8Func: public ISwinTransformerINT8Func { ft::SwinTransformerINT8BasicLayerWeight bl; for (int di = 0; di < depths[l]; di++) { ft::SwinTransformerINT8BlockWeight p; - p.attention_weights.query_weight.kernel = get_ptr(weights_[weight_idx++]); - p.attention_weights.query_weight.bias = get_ptr(weights_[weight_idx++]); + p.attention_weights.query_weight.kernel = get_ptr(weights_[weight_idx++]); + p.attention_weights.query_weight.bias = get_ptr(weights_[weight_idx++]); p.attention_weights.attention_output_weight.kernel = get_ptr(weights_[weight_idx++]); - p.attention_weights.attention_output_weight.bias = get_ptr(weights_[weight_idx++]); - p.ffn_weights.intermediate_weight.kernel = get_ptr(weights_[weight_idx++]); - p.ffn_weights.intermediate_weight.bias = get_ptr(weights_[weight_idx++]); - p.ffn_weights.output_weight.kernel = get_ptr(weights_[weight_idx++]); - p.ffn_weights.output_weight.bias = get_ptr(weights_[weight_idx++]); - p.attn_layernorm_weights.gamma = get_ptr(weights_[weight_idx++]); - p.attn_layernorm_weights.beta = get_ptr(weights_[weight_idx++]); - p.ffn_layernorm_weights.gamma = get_ptr(weights_[weight_idx++]); - p.ffn_layernorm_weights.beta = get_ptr(weights_[weight_idx++]); - p.scalelist.size_ = ACTIVATION_AMAX_NUM + 5 + INT8O_GEMM_NUM + TRT_AMAX_NUM; - p.scalelist.p2_offset_ = ACTIVATION_AMAX_NUM; - p.scalelist.p3_offset_ = ACTIVATION_AMAX_NUM + 5; - p.scalelist.p4_offset_ = ACTIVATION_AMAX_NUM + 5 + INT8O_GEMM_NUM; - p.scalelist.d_scale_list_ = get_ptr(weights_[weight_idx++]); - p.scalelist.h_scale_list_ = get_ptr(weights_[weight_idx++]); + p.attention_weights.attention_output_weight.bias = get_ptr(weights_[weight_idx++]); + p.ffn_weights.intermediate_weight.kernel = get_ptr(weights_[weight_idx++]); + p.ffn_weights.intermediate_weight.bias = get_ptr(weights_[weight_idx++]); + p.ffn_weights.output_weight.kernel = get_ptr(weights_[weight_idx++]); + p.ffn_weights.output_weight.bias = 
get_ptr(weights_[weight_idx++]); + p.attn_layernorm_weights.gamma = get_ptr(weights_[weight_idx++]); + p.attn_layernorm_weights.beta = get_ptr(weights_[weight_idx++]); + p.ffn_layernorm_weights.gamma = get_ptr(weights_[weight_idx++]); + p.ffn_layernorm_weights.beta = get_ptr(weights_[weight_idx++]); + p.scalelist.size_ = ACTIVATION_AMAX_NUM + 5 + INT8O_GEMM_NUM + TRT_AMAX_NUM; + p.scalelist.p2_offset_ = ACTIVATION_AMAX_NUM; + p.scalelist.p3_offset_ = ACTIVATION_AMAX_NUM + 5; + p.scalelist.p4_offset_ = ACTIVATION_AMAX_NUM + 5 + INT8O_GEMM_NUM; + p.scalelist.d_scale_list_ = get_ptr(weights_[weight_idx++]); + p.scalelist.h_scale_list_ = get_ptr(weights_[weight_idx++]); p.attention_relative_pos_bias = get_ptr(weights_[weight_idx++]); bl.block_weight_list.push_back(p); } bl.merge_layernorm_weights.gamma = get_ptr(weights_[weight_idx++]); - bl.merge_layernorm_weights.beta = get_ptr(weights_[weight_idx++]); - bl.merge_linear_weights.kernel = get_ptr(weights_[weight_idx++]); - bl.attn_mask = get_ptr(weights_[weight_idx++]); + bl.merge_layernorm_weights.beta = get_ptr(weights_[weight_idx++]); + bl.merge_linear_weights.kernel = get_ptr(weights_[weight_idx++]); + bl.attn_mask = get_ptr(weights_[weight_idx++]); params_.basic_layer_weight_list.push_back(bl); hidden_dim *= 2; } params_.patchEmbed_linear_weights.kernel = get_ptr(weights_[weight_idx++]); - params_.patchEmbed_linear_weights.bias = get_ptr(weights_[weight_idx++]); - params_.patchEmbed_norm_weights.gamma = get_ptr(weights_[weight_idx++]); - params_.patchEmbed_norm_weights.beta = get_ptr(weights_[weight_idx++]); - params_.norm_weights.gamma = get_ptr(weights_[weight_idx++]); - params_.norm_weights.beta = get_ptr(weights_[weight_idx++]); + params_.patchEmbed_linear_weights.bias = get_ptr(weights_[weight_idx++]); + params_.patchEmbed_norm_weights.gamma = get_ptr(weights_[weight_idx++]); + params_.patchEmbed_norm_weights.beta = get_ptr(weights_[weight_idx++]); + params_.norm_weights.gamma = get_ptr(weights_[weight_idx++]); + params_.norm_weights.beta = get_ptr(weights_[weight_idx++]); } ~SwinTransformerINT8Func() override @@ -204,8 +204,8 @@ class SwinTransformerINT8Func: public ISwinTransformerINT8Func { qkv_bias_, qk_scale_); - ft::DataType data_type = ft::getTensorType(); - int sm_ptr[1] = {sm_}; + ft::DataType data_type = ft::getTensorType(); + int sm_ptr[1] = {sm_}; std::vector input_tensors = std::vector{ ft::Tensor{ft::MEMORY_GPU, data_type, @@ -226,33 +226,33 @@ class SwinTransformerINT8Func: public ISwinTransformerINT8Func { } private: - std::vector weights_; - cublasHandle_t cublas_handle_ = nullptr; - cudnnHandle_t cudnn_handle_ = nullptr; - cublasLtHandle_t cublaslt_handle_ = nullptr; + std::vector weights_; + cublasHandle_t cublas_handle_ = nullptr; + cudnnHandle_t cudnn_handle_ = nullptr; + cublasLtHandle_t cublaslt_handle_ = nullptr; ft::SwinTransformerINT8Weight params_; - std::mutex* cublas_wrapper_mutex_; - ft::cublasAlgoMap* cublas_algo_map_; + std::mutex* cublas_wrapper_mutex_; + ft::cublasAlgoMap* cublas_algo_map_; }; class SwinTransformerINT8Class: public torch::jit::CustomClassHolder { public: SwinTransformerINT8Class(std::vector w, - int64_t int8_mode, - th::Tensor depths, - th::Tensor num_heads, - int64_t max_batch, - int64_t img_size, - int64_t patch_size, - int64_t in_chans, - int64_t embed_dim, - int64_t window_size, - bool ape, - bool patch_norm, - int64_t layer_num, - double mlp_ratio, - bool qkv_bias = true, - double qk_scale = 1.0); + int64_t int8_mode, + th::Tensor depths, + th::Tensor num_heads, + int64_t 
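[Editor's note] The Swin weight loading above consumes a single flat tensor list with a running weight_idx cursor, block by block inside each basic layer, so the Python side only has to keep the ordering consistent with the C++ side. A toy version of that cursor pattern, with floats standing in for the torch tensors and a much shorter slot list:

#include <cstddef>
#include <cstdio>
#include <vector>

struct BlockWeights {
    float qkv_kernel, qkv_bias, proj_kernel, proj_bias;
};

int main()
{
    std::vector<float> flat = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f};
    std::size_t weight_idx  = 0;

    // Two blocks, each pulling its slots off the flat list in a fixed order.
    std::vector<BlockWeights> blocks;
    for (int block = 0; block < 2; ++block) {
        BlockWeights w;
        w.qkv_kernel  = flat[weight_idx++];
        w.qkv_bias    = flat[weight_idx++];
        w.proj_kernel = flat[weight_idx++];
        w.proj_bias   = flat[weight_idx++];
        blocks.push_back(w);
    }
    std::printf("consumed %zu of %zu weights\n", weight_idx, flat.size());
}

The design choice is the same as in the patch: a single ordered list plus a cursor avoids a long constructor signature, at the cost of the ordering contract living in comments on both sides.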
max_batch, + int64_t img_size, + int64_t patch_size, + int64_t in_chans, + int64_t embed_dim, + int64_t window_size, + bool ape, + bool patch_norm, + int64_t layer_num, + double mlp_ratio, + bool qkv_bias = true, + double qk_scale = 1.0); ~SwinTransformerINT8Class(); @@ -261,14 +261,14 @@ class SwinTransformerINT8Class: public torch::jit::CustomClassHolder { std::vector get_pickle_info() const; private: - const at::ScalarType st_; + const at::ScalarType st_; ISwinTransformerINT8Func* swin_transformer_func_; - std::vector weights_; - th::Tensor depths_; - th::Tensor num_heads_; - th::Tensor info_int_; - th::Tensor info_float_; - int output_dim_; + std::vector weights_; + th::Tensor depths_; + th::Tensor num_heads_; + th::Tensor info_int_; + th::Tensor info_float_; + int output_dim_; }; } // namespace torch_ext diff --git a/src/fastertransformer/th_op/swin/SwinOp.cc b/src/fastertransformer/th_op/swin/SwinOp.cc index 7a9e18436..47699516e 100644 --- a/src/fastertransformer/th_op/swin/SwinOp.cc +++ b/src/fastertransformer/th_op/swin/SwinOp.cc @@ -23,20 +23,20 @@ template class SwinTransformerFunc; template class SwinTransformerFunc; SwinTransformerClass::SwinTransformerClass(std::vector w, - th::Tensor depths, - th::Tensor num_heads, - int64_t max_batch, - int64_t img_size, - int64_t patch_size, - int64_t in_chans, - int64_t embed_dim, - int64_t window_size, - bool ape, - bool patch_norm, - int64_t layer_num, - double mlp_ratio, - bool qkv_bias, - double qk_scale): + th::Tensor depths, + th::Tensor num_heads, + int64_t max_batch, + int64_t img_size, + int64_t patch_size, + int64_t in_chans, + int64_t embed_dim, + int64_t window_size, + bool ape, + bool patch_norm, + int64_t layer_num, + double mlp_ratio, + bool qkv_bias, + double qk_scale): st_(w[0].scalar_type()), depths_(depths), num_heads_(num_heads), weights_(w) { @@ -87,21 +87,40 @@ SwinTransformerClass::SwinTransformerClass(std::vector w, qk_scale, weights_); break; +#ifdef ENABLE_BF16 + case at::ScalarType::BFloat16: + swin_transformer_func_ = new SwinTransformerFunc<__nv_bfloat16>(max_batch, + img_size, + patch_size, + in_chans, + embed_dim, + window_size, + get_ptr(depths), + get_ptr(num_heads), + ape, + patch_norm, + layer_num, + mlp_ratio, + qkv_bias, + qk_scale, + weights_); + break; +#endif default: throw std::runtime_error("Wrong th::Tensor type."); } - info_int_ = torch::empty({10}, torch::dtype(torch::kInt64)); - info_int_[0] = max_batch; - info_int_[1] = img_size; - info_int_[2] = patch_size; - info_int_[3] = in_chans; - info_int_[4] = embed_dim; - info_int_[5] = window_size; - info_int_[6] = (int64_t)ape; - info_int_[7] = (int64_t)patch_norm; - info_int_[8] = layer_num; - info_int_[9] = (int64_t)qkv_bias; - info_float_ = torch::empty({2}, torch::dtype(torch::kFloat64)); + info_int_ = torch::empty({10}, torch::dtype(torch::kInt64)); + info_int_[0] = max_batch; + info_int_[1] = img_size; + info_int_[2] = patch_size; + info_int_[3] = in_chans; + info_int_[4] = embed_dim; + info_int_[5] = window_size; + info_int_[6] = (int64_t)ape; + info_int_[7] = (int64_t)patch_norm; + info_int_[8] = layer_num; + info_int_[9] = (int64_t)qkv_bias; + info_float_ = torch::empty({2}, torch::dtype(torch::kFloat64)); info_float_[0] = mlp_ratio; info_float_[1] = qk_scale; } @@ -124,15 +143,15 @@ SwinTransformerClass::~SwinTransformerClass() th::Tensor SwinTransformerClass::forward(th::Tensor input) { CHECK_INPUT(input, st_); - int batch_size = input.size(0); + int batch_size = input.size(0); auto output = torch::empty({batch_size, output_dim_}, 
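[Editor's note] SwinOp.cc gains a BFloat16 branch in its scalar-type switch, guarded by ENABLE_BF16 so builds without bf16 support still compile. A dependency-free sketch of that dispatch shape; ScalarType and the 16-bit stand-in structs below are placeholders, not the torch or CUDA types:

#include <cstdio>

enum class ScalarType { Float, Half, BFloat16 };

template <typename T>
void build_swin() { std::printf("built Swin functor for %zu-byte type\n", sizeof(T)); }

struct half_t { unsigned short bits; };
#ifdef ENABLE_BF16
struct bfloat16_t { unsigned short bits; };
#endif

void dispatch(ScalarType st)
{
    switch (st) {
        case ScalarType::Float: build_swin<float>(); break;
        case ScalarType::Half:  build_swin<half_t>(); break;
#ifdef ENABLE_BF16
        case ScalarType::BFloat16: build_swin<bfloat16_t>(); break;  // only when the build enables bf16
#endif
        default: std::printf("Wrong Tensor type.\n"); break;
    }
}

int main() { dispatch(ScalarType::Half); }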
torch::dtype(input.dtype()).device(torch::kCUDA).requires_grad(false)); swin_transformer_func_->forward(batch_size, input, output); return output; } -th::Tensor gen_relative_pos_bias(th::Tensor relative_position_bias_table, - th::Tensor relative_position_bias_index, +th::Tensor gen_relative_pos_bias(th::Tensor relative_position_bias_table, + th::Tensor relative_position_bias_index, const int64_t window_size, const int64_t head_num) { @@ -179,25 +198,25 @@ static auto swinTransformerTHS = return self->get_pickle_info(); }, [](std::vector state) -> c10::intrusive_ptr { - int state_size = state.size(); - std::vector::const_iterator first = state.begin(); - std::vector::const_iterator last = state.begin() + (state_size - 4); - std::vector weights(first, last); - int idx = state.size() - 2; - int i = 0; - int64_t max_batch = state[idx][i++].item().to(); - int64_t img_size = state[idx][i++].item().to(); - int64_t patch_size = state[idx][i++].item().to(); - int64_t in_chans = state[idx][i++].item().to(); - int64_t embed_dim = state[idx][i++].item().to(); - int64_t window_size = state[idx][i++].item().to(); - bool ape = state[idx][i++].item().to(); - bool patch_norm = state[idx][i++].item().to(); - int64_t layer_num = state[idx][i++].item().to(); - bool qkv_bias = state[idx][i++].item().to(); - idx = state.size() - 1; - double mlp_ratio = state[idx][0].item().to(); - double qk_scale = state[idx][1].item().to(); + int state_size = state.size(); + std::vector::const_iterator first = state.begin(); + std::vector::const_iterator last = state.begin() + (state_size - 4); + std::vector weights(first, last); + int idx = state.size() - 2; + int i = 0; + int64_t max_batch = state[idx][i++].item().to(); + int64_t img_size = state[idx][i++].item().to(); + int64_t patch_size = state[idx][i++].item().to(); + int64_t in_chans = state[idx][i++].item().to(); + int64_t embed_dim = state[idx][i++].item().to(); + int64_t window_size = state[idx][i++].item().to(); + bool ape = state[idx][i++].item().to(); + bool patch_norm = state[idx][i++].item().to(); + int64_t layer_num = state[idx][i++].item().to(); + bool qkv_bias = state[idx][i++].item().to(); + idx = state.size() - 1; + double mlp_ratio = state[idx][0].item().to(); + double qk_scale = state[idx][1].item().to(); return c10::make_intrusive(weights, state[state_size - 4], state[state_size - 3], diff --git a/src/fastertransformer/th_op/swin/SwinOp.h b/src/fastertransformer/th_op/swin/SwinOp.h index 58e805e45..120bcbce9 100644 --- a/src/fastertransformer/th_op/swin/SwinOp.h +++ b/src/fastertransformer/th_op/swin/SwinOp.h @@ -32,36 +32,36 @@ class ISwinTransformerFunc { template class SwinTransformerFunc: public ISwinTransformerFunc { public: - int sm_; - int max_batch_; - int img_size_; - int patch_size_; - int in_chans_; - int embed_dim_; - int window_size_; - int* depths_; - int* num_heads_; - bool ape_; - bool patch_norm_; - int layer_num_; + int sm_; + int max_batch_; + int img_size_; + int patch_size_; + int in_chans_; + int embed_dim_; + int window_size_; + int* depths_; + int* num_heads_; + bool ape_; + bool patch_norm_; + int layer_num_; float mlp_ratio_; - bool qkv_bias_; + bool qkv_bias_; float qk_scale_; - SwinTransformerFunc(const int max_batch, - const int img_size, - const int patch_size, - const int in_chans, - const int embed_dim, - const int window_size, - int* depths, - int* num_heads, - const bool ape, - const bool patch_norm, - const int layer_num, - const float mlp_ratio, - const bool qkv_bias, - const float qk_scale, + SwinTransformerFunc(const 
int max_batch, + const int img_size, + const int patch_size, + const int in_chans, + const int embed_dim, + const int window_size, + int* depths, + int* num_heads, + const bool ape, + const bool patch_norm, + const int layer_num, + const float mlp_ratio, + const bool qkv_bias, + const float qk_scale, const std::vector& w): weights_(w), max_batch_(max_batch), @@ -85,7 +85,7 @@ class SwinTransformerFunc: public ISwinTransformerFunc { sm_ = ft::getSMVersion(); - cublas_algo_map_ = new ft::cublasAlgoMap(GEMM_CONFIG, ""); + cublas_algo_map_ = new ft::cublasAlgoMap(GEMM_CONFIG, ""); cublas_wrapper_mutex_ = new std::mutex(); // We arrange weights layer by layer and block by block inside each layer; @@ -112,33 +112,33 @@ class SwinTransformerFunc: public ISwinTransformerFunc { ft::SwinTransformerBasicLayerWeight bl; for (int di = 0; di < depths[l]; di++) { ft::SwinTransformerBlockWeight p; - p.attention_weights.query_weight.kernel = get_ptr(weights_[weight_idx++]); - p.attention_weights.query_weight.bias = get_ptr(weights_[weight_idx++]); + p.attention_weights.query_weight.kernel = get_ptr(weights_[weight_idx++]); + p.attention_weights.query_weight.bias = get_ptr(weights_[weight_idx++]); p.attention_weights.attention_output_weight.kernel = get_ptr(weights_[weight_idx++]); - p.attention_weights.attention_output_weight.bias = get_ptr(weights_[weight_idx++]); - p.ffn_weights.intermediate_weight.kernel = get_ptr(weights_[weight_idx++]); - p.ffn_weights.intermediate_weight.bias = get_ptr(weights_[weight_idx++]); - p.ffn_weights.output_weight.kernel = get_ptr(weights_[weight_idx++]); - p.ffn_weights.output_weight.bias = get_ptr(weights_[weight_idx++]); - p.attn_layernorm_weights.gamma = get_ptr(weights_[weight_idx++]); - p.attn_layernorm_weights.beta = get_ptr(weights_[weight_idx++]); - p.ffn_layernorm_weights.gamma = get_ptr(weights_[weight_idx++]); - p.ffn_layernorm_weights.beta = get_ptr(weights_[weight_idx++]); - p.attention_relative_pos_bias = get_ptr(weights_[weight_idx++]); + p.attention_weights.attention_output_weight.bias = get_ptr(weights_[weight_idx++]); + p.ffn_weights.intermediate_weight.kernel = get_ptr(weights_[weight_idx++]); + p.ffn_weights.intermediate_weight.bias = get_ptr(weights_[weight_idx++]); + p.ffn_weights.output_weight.kernel = get_ptr(weights_[weight_idx++]); + p.ffn_weights.output_weight.bias = get_ptr(weights_[weight_idx++]); + p.attn_layernorm_weights.gamma = get_ptr(weights_[weight_idx++]); + p.attn_layernorm_weights.beta = get_ptr(weights_[weight_idx++]); + p.ffn_layernorm_weights.gamma = get_ptr(weights_[weight_idx++]); + p.ffn_layernorm_weights.beta = get_ptr(weights_[weight_idx++]); + p.attention_relative_pos_bias = get_ptr(weights_[weight_idx++]); bl.block_weight_list.push_back(p); } bl.merge_layernorm_weights.gamma = get_ptr(weights_[weight_idx++]); - bl.merge_layernorm_weights.beta = get_ptr(weights_[weight_idx++]); - bl.merge_linear_weights.kernel = get_ptr(weights_[weight_idx++]); - bl.attn_mask = get_ptr(weights_[weight_idx++]); + bl.merge_layernorm_weights.beta = get_ptr(weights_[weight_idx++]); + bl.merge_linear_weights.kernel = get_ptr(weights_[weight_idx++]); + bl.attn_mask = get_ptr(weights_[weight_idx++]); params_.basic_layer_weight_list.push_back(bl); } params_.patchEmbed_linear_weights.kernel = get_ptr(weights_[weight_idx++]); - params_.patchEmbed_linear_weights.bias = get_ptr(weights_[weight_idx++]); - params_.patchEmbed_norm_weights.gamma = get_ptr(weights_[weight_idx++]); - params_.patchEmbed_norm_weights.beta = get_ptr(weights_[weight_idx++]); - 
params_.norm_weights.gamma = get_ptr(weights_[weight_idx++]); - params_.norm_weights.beta = get_ptr(weights_[weight_idx++]); + params_.patchEmbed_linear_weights.bias = get_ptr(weights_[weight_idx++]); + params_.patchEmbed_norm_weights.gamma = get_ptr(weights_[weight_idx++]); + params_.patchEmbed_norm_weights.beta = get_ptr(weights_[weight_idx++]); + params_.norm_weights.gamma = get_ptr(weights_[weight_idx++]); + params_.norm_weights.beta = get_ptr(weights_[weight_idx++]); } ~SwinTransformerFunc() override @@ -162,6 +162,11 @@ class SwinTransformerFunc: public ISwinTransformerFunc { if (std::is_same::value) { cublas_wrapper->setFP16GemmConfig(); } +#ifdef ENABLE_BF16 + else if (std::is_same::value) { + cublas_wrapper->setBF16GemmConfig(); + } +#endif else if (std::is_same::value) { cublas_wrapper->setFP32GemmConfig(); } @@ -186,20 +191,20 @@ class SwinTransformerFunc: public ISwinTransformerFunc { qkv_bias_, qk_scale_); - ft::DataType data_type = ft::getTensorType(); - int sm_ptr[1] = {sm_}; + ft::DataType data_type = ft::getTensorType(); + int sm_ptr[1] = {sm_}; std::vector input_tensors = std::vector{ ft::Tensor{ft::MEMORY_GPU, data_type, - std::vector{(size_t)batch_size, (size_t)img_size_ * img_size_, (size_t)in_chans_}, + std::vector{(size_t)batch_size, (size_t)in_chans_, (size_t)img_size_, (size_t)img_size_}, get_ptr(input)}, ft::Tensor{ft::MEMORY_CPU, ft::TYPE_INT8, std::vector{1}, sm_ptr}}; - std::vector output_tensors = std::vector{ - ft::Tensor{ft::MEMORY_GPU, - data_type, - std::vector{(size_t)batch_size, (size_t)img_size_ * img_size_, (size_t)in_chans_}, - get_ptr(output)}}; + std::vector output_tensors = + std::vector{ft::Tensor{ft::MEMORY_GPU, + data_type, + std::vector{(size_t)batch_size, (size_t)output.size(1)}, + get_ptr(output)}}; swin_transformer->forward(&output_tensors, &input_tensors, params_); delete swin_transformer; delete cublas_wrapper; @@ -207,32 +212,32 @@ class SwinTransformerFunc: public ISwinTransformerFunc { } private: - std::vector weights_; - cublasHandle_t cublas_handle_ = nullptr; - cublasLtHandle_t cublaslt_handle_ = nullptr; - cudnnHandle_t cudnn_handle_ = nullptr; + std::vector weights_; + cublasHandle_t cublas_handle_ = nullptr; + cublasLtHandle_t cublaslt_handle_ = nullptr; + cudnnHandle_t cudnn_handle_ = nullptr; ft::SwinTransformerWeight params_; - std::mutex* cublas_wrapper_mutex_; - ft::cublasAlgoMap* cublas_algo_map_; + std::mutex* cublas_wrapper_mutex_; + ft::cublasAlgoMap* cublas_algo_map_; }; class SwinTransformerClass: public torch::jit::CustomClassHolder { public: SwinTransformerClass(std::vector w, - th::Tensor depths, - th::Tensor num_heads, - int64_t max_batch, - int64_t img_size, - int64_t patch_size, - int64_t in_chans, - int64_t embed_dim, - int64_t window_size, - bool ape, - bool patch_norm, - int64_t layer_num, - double mlp_ratio, - bool qkv_bias = true, - double qk_scale = 1.0); + th::Tensor depths, + th::Tensor num_heads, + int64_t max_batch, + int64_t img_size, + int64_t patch_size, + int64_t in_chans, + int64_t embed_dim, + int64_t window_size, + bool ape, + bool patch_norm, + int64_t layer_num, + double mlp_ratio, + bool qkv_bias = true, + double qk_scale = 1.0); ~SwinTransformerClass(); @@ -241,28 +246,28 @@ class SwinTransformerClass: public torch::jit::CustomClassHolder { std::vector get_pickle_info() const; private: - const at::ScalarType st_; - ISwinTransformerFunc* swin_transformer_func_; + const at::ScalarType st_; + ISwinTransformerFunc* swin_transformer_func_; std::vector weights_; - th::Tensor depths_; - th::Tensor 
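[Editor's note] The input descriptor in SwinTransformerFunc changes from {batch, img_size*img_size, in_chans} to NCHW {batch, in_chans, img_size, img_size}, and the output descriptor to {batch, output_dim}. If a caller still holds the old flattened-token layout, a permute like the one below produces the NCHW buffer the new descriptor describes; this is only a CPU illustration of the layout difference (in practice this would be a tensor permute before the op is called):

#include <cstddef>
#include <cstdio>
#include <vector>

std::vector<float> bhwc_tokens_to_nchw(const std::vector<float>& in, int b, int hw_side, int c)
{
    std::vector<float> out(static_cast<std::size_t>(b) * c * hw_side * hw_side);
    for (int bi = 0; bi < b; ++bi)
        for (int p = 0; p < hw_side * hw_side; ++p)      // token index = h * hw_side + w
            for (int ci = 0; ci < c; ++ci)
                out[(bi * c + ci) * hw_side * hw_side + p] =
                    in[(bi * hw_side * hw_side + p) * c + ci];
    return out;
}

int main()
{
    std::vector<float> tokens(1 * 4 * 3, 0.f);            // B=1, 2x2 image, C=3
    auto nchw = bhwc_tokens_to_nchw(tokens, 1, 2, 3);
    std::printf("nchw elements: %zu\n", nchw.size());     // 12
}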
num_heads_; - th::Tensor info_int_; - th::Tensor info_float_; - int output_dim_; + th::Tensor depths_; + th::Tensor num_heads_; + th::Tensor info_int_; + th::Tensor info_float_; + int output_dim_; }; template th::Tensor gen_relative_pos_bias_impl(th::Tensor relative_position_bias_table, th::Tensor relative_position_bias_index, - const int window_size, - const int head_num) + const int window_size, + const int head_num) { - auto stream = at::cuda::getCurrentCUDAStream().stream(); - int window_len = window_size * window_size; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + int window_len = window_size * window_size; const T* relative_position_bias_table_ptr = get_ptr(relative_position_bias_table); CHECK_INPUT(relative_position_bias_index, at::ScalarType::Long); const int64_t* relative_position_bias_index_ptr = get_ptr(relative_position_bias_index); - auto output = + auto output = torch::empty({head_num, window_len, window_len}, torch::dtype(relative_position_bias_table.dtype()).device(torch::kCUDA).requires_grad(false)); T* output_ptr = get_ptr(output); @@ -273,8 +278,8 @@ th::Tensor gen_relative_pos_bias_impl(th::Tensor relative_position_bias_table, return output; } -th::Tensor gen_relative_pos_bias(th::Tensor relative_position_bias_table, - th::Tensor relative_position_bias_index, +th::Tensor gen_relative_pos_bias(th::Tensor relative_position_bias_table, + th::Tensor relative_position_bias_index, const int64_t window_size, const int64_t head_num); diff --git a/src/fastertransformer/th_op/swin/WeightQuantizeOp.cc b/src/fastertransformer/th_op/swin/WeightQuantizeOp.cc index e92b109c3..c564691e7 100644 --- a/src/fastertransformer/th_op/swin/WeightQuantizeOp.cc +++ b/src/fastertransformer/th_op/swin/WeightQuantizeOp.cc @@ -46,15 +46,15 @@ Tensor swin_weight_quantize(Tensor weight, Tensor quant_max) TORCH_CHECK(quant_max.dtype() == torch::kFloat32, "quant_max dtype should be float32"); TORCH_CHECK(quant_max.numel() == 1, "quant_max wrong shape"); - const half* weight_ = get_ptr(weight); + const half* weight_ = get_ptr(weight); const float* quant_max_ = get_ptr(quant_max); - auto output = torch::empty({k * n}, torch::dtype(torch::kFloat16).device(torch::kCUDA).requires_grad(false)); + auto output = torch::empty({k * n}, torch::dtype(torch::kFloat16).device(torch::kCUDA).requires_grad(false)); int8_t* transform_out = get_ptr(output); auto stream = at::cuda::getCurrentCUDAStream().stream(); - int format = use_ORDER_COL32_2R_4R4 ? 1 : 2; + int format = use_ORDER_COL32_2R_4R4 ? 1 : 2; const int scale_is_vector = 0; invokeQuantizeWeight(transform_out, weight_, quant_max_, n, k, format, stream, scale_is_vector); return output; diff --git a/src/fastertransformer/th_op/t5/CMakeLists.txt b/src/fastertransformer/th_op/t5/CMakeLists.txt index ccf7a073d..94b6992eb 100644 --- a/src/fastertransformer/th_op/t5/CMakeLists.txt +++ b/src/fastertransformer/th_op/t5/CMakeLists.txt @@ -13,4 +13,4 @@ # limitations under the License. 
add_library(th_t5 SHARED T5EncoderOp.cc T5DecoderOp.cc T5DecodingOp.cc) -target_link_libraries(th_t5 PRIVATE "${TORCH_LIBRARIES}" T5Encoder T5Decoder T5Decoding th_utils) +target_link_libraries(th_t5 PRIVATE "${TORCH_LIBRARIES}" T5Encoder T5Decoder T5Decoding th_utils nccl_utils) diff --git a/src/fastertransformer/th_op/t5/T5DecoderOp.cc b/src/fastertransformer/th_op/t5/T5DecoderOp.cc index 9ca439c94..7a9f077c1 100644 --- a/src/fastertransformer/th_op/t5/T5DecoderOp.cc +++ b/src/fastertransformer/th_op/t5/T5DecoderOp.cc @@ -29,27 +29,38 @@ FasterTransformerT5Decoder::FasterTransformerT5Decoder(th::Tensor self_layernorm th::Tensor cross_output_kernel, th::Tensor ffn_layernorm_gamma, th::Tensor inter_kernel, + th::Tensor inter_kernel2, th::Tensor output_kernel, - int64_t head_num, - int64_t head_size, - int64_t inter_size, - int64_t d_model, - int64_t layer_num, - int64_t mem_d_model, - int64_t tensor_para_size, - int64_t pipeline_para_size): + th::Tensor self_layernorm_beta, + th::Tensor self_bias_qkv, + th::Tensor self_output_bias, + th::Tensor cross_layernorm_beta, + th::Tensor cross_bias_q, + th::Tensor cross_bias_k, + th::Tensor cross_bias_v, + th::Tensor cross_output_bias, + th::Tensor ffn_layernorm_beta, + th::Tensor inter_bias, + th::Tensor inter_bias2, + th::Tensor output_bias, + int64_t head_num, + int64_t head_size, + int64_t inter_size, + int64_t d_model, + int64_t layer_num, + int64_t mem_d_model, + int64_t tensor_para_size, + int64_t pipeline_para_size, + bool t5_with_bias, + bool use_gated_activation, + int64_t position_embedding_type): _st(self_kernel_q.scalar_type()), - weights{self_layernorm_gamma, - self_kernel_q, - self_output_kernel, - cross_layernorm_gamma, - cross_kernel_q, - cross_kernel_k, - cross_kernel_v, - cross_output_kernel, - ffn_layernorm_gamma, - inter_kernel, - output_kernel} + weights{self_layernorm_gamma, self_kernel_q, self_output_kernel, cross_layernorm_gamma, + cross_kernel_q, cross_kernel_k, cross_kernel_v, cross_output_kernel, + ffn_layernorm_gamma, inter_kernel, inter_kernel2, output_kernel, + self_layernorm_beta, self_bias_qkv, self_output_bias, cross_layernorm_beta, + cross_bias_q, cross_bias_k, cross_bias_v, cross_output_bias, + ffn_layernorm_beta, inter_bias, inter_bias2, output_bias} { for (auto t : weights) { CHECK_INPUT(t, _st); @@ -65,6 +76,9 @@ FasterTransformerT5Decoder::FasterTransformerT5Decoder(th::Tensor self_layernorm mem_d_model, tensor_para_size, pipeline_para_size, + t5_with_bias, + use_gated_activation, + position_embedding_type, weights); break; case at::ScalarType::Half: @@ -76,20 +90,42 @@ FasterTransformerT5Decoder::FasterTransformerT5Decoder(th::Tensor self_layernorm mem_d_model, tensor_para_size, pipeline_para_size, + t5_with_bias, + use_gated_activation, + position_embedding_type, weights); break; +#ifdef ENABLE_BF16 + case at::ScalarType::BFloat16: + ftdecoder = new FTT5Decoder<__nv_bfloat16>(head_num, + head_size, + inter_size, + d_model, + layer_num, + mem_d_model, + tensor_para_size, + pipeline_para_size, + t5_with_bias, + use_gated_activation, + position_embedding_type, + weights); + break; +#endif default: throw std::runtime_error("Wrong Tensor type."); } - head_info = torch::empty({8}, torch::dtype(torch::kInt64)); - head_info[0] = head_num; - head_info[1] = head_size; - head_info[2] = layer_num; - head_info[3] = inter_size; - head_info[4] = d_model; - head_info[5] = mem_d_model; - head_info[6] = tensor_para_size; - head_info[7] = pipeline_para_size; + head_info = torch::empty({11}, torch::dtype(torch::kInt64)); + 
head_info[0] = head_num; + head_info[1] = head_size; + head_info[2] = layer_num; + head_info[3] = inter_size; + head_info[4] = d_model; + head_info[5] = mem_d_model; + head_info[6] = tensor_para_size; + head_info[7] = pipeline_para_size; + head_info[8] = t5_with_bias; + head_info[9] = use_gated_activation; + head_info[10] = position_embedding_type; } FasterTransformerT5Decoder::~FasterTransformerT5Decoder() @@ -97,7 +133,7 @@ FasterTransformerT5Decoder::~FasterTransformerT5Decoder() delete ftdecoder; } -std::vector FasterTransformerT5Decoder::forward(int64_t step, +std::vector FasterTransformerT5Decoder::forward(int64_t step, th::Tensor from_tensor, th::Tensor memory_tensor, th::Tensor memory_sequence_length, @@ -172,6 +208,20 @@ static auto fasterTransformerDecoderTHS = th::Tensor, th::Tensor, th::Tensor, + th::Tensor, + th::Tensor, + th::Tensor, + th::Tensor, + th::Tensor, + th::Tensor, + th::Tensor, + th::Tensor, + th::Tensor, + th::Tensor, + th::Tensor, + th::Tensor, + th::Tensor, + int64_t, int64_t, int64_t, int64_t, @@ -179,6 +229,8 @@ static auto fasterTransformerDecoderTHS = int64_t, int64_t, int64_t, + bool, + bool, int64_t>()) .def("forward", &torch_ext::FasterTransformerT5Decoder::forward) .def_pickle( @@ -186,14 +238,17 @@ static auto fasterTransformerDecoderTHS = return self->get_pickle_info(); }, [](std::vector state) -> c10::intrusive_ptr { - int64_t head_num = state[11][0].item().to(); - int64_t head_size = state[11][1].item().to(); - int64_t layer_num = state[11][2].item().to(); - int64_t inter_size = state[11][3].item().to(); - int64_t d_model = state[11][4].item().to(); - int64_t mem_d_model = state[11][5].item().to(); - int64_t tensor_para_size = state[11][6].item().to(); - int64_t pipeline_para_size = state[11][7].item().to(); + int64_t head_num = state[24][0].item().to(); + int64_t head_size = state[24][1].item().to(); + int64_t layer_num = state[24][2].item().to(); + int64_t inter_size = state[24][3].item().to(); + int64_t d_model = state[24][4].item().to(); + int64_t mem_d_model = state[24][5].item().to(); + int64_t tensor_para_size = state[24][6].item().to(); + int64_t pipeline_para_size = state[24][7].item().to(); + bool t5_with_bias = (bool)state[24][8].item().to(); + bool use_gated_activation = (bool)state[24][9].item().to(); + int64_t position_embedding_type = state[24][10].item().to(); return c10::make_intrusive(state[0], state[1], state[2], @@ -205,6 +260,19 @@ static auto fasterTransformerDecoderTHS = state[8], state[9], state[10], + state[11], + state[12], + state[13], + state[14], + state[15], + state[16], + state[17], + state[18], + state[19], + state[20], + state[21], + state[22], + state[23], head_num, head_size, inter_size, @@ -212,5 +280,8 @@ static auto fasterTransformerDecoderTHS = layer_num, mem_d_model, tensor_para_size, - pipeline_para_size); + pipeline_para_size, + t5_with_bias, + use_gated_activation, + position_embedding_type); }); \ No newline at end of file diff --git a/src/fastertransformer/th_op/t5/T5DecoderOp.h b/src/fastertransformer/th_op/t5/T5DecoderOp.h index 0baf864f6..2f69a0b9c 100644 --- a/src/fastertransformer/th_op/t5/T5DecoderOp.h +++ b/src/fastertransformer/th_op/t5/T5DecoderOp.h @@ -16,7 +16,7 @@ #include "src/fastertransformer/models/t5/T5Decoder.h" #include "src/fastertransformer/th_op/th_utils.h" -#include "src/fastertransformer/utils/mpi_utils.h" +#include "src/fastertransformer/utils/nccl_utils.h" namespace ft = fastertransformer; namespace th = torch; @@ -25,8 +25,8 @@ namespace torch_ext { class IFT5Decoder { public: 
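[Editor's note] head_info grows from 8 to 11 entries (adding t5_with_bias, use_gated_activation and position_embedding_type), and the unpickle lambda now reads the config row at state[24] instead of state[11] because the weight list in front of it grew from 11 to 24 tensors. A small sketch of that pack/unpack round trip, with an int64 vector standing in for the config tensor and illustrative values:

#include <cstdint>
#include <cstdio>
#include <vector>

int main()
{
    std::vector<int64_t> head_info(11);
    head_info[0]  = 16;    // head_num
    head_info[1]  = 64;    // head_size
    head_info[2]  = 24;    // layer_num
    head_info[3]  = 4096;  // inter_size
    head_info[4]  = 1024;  // d_model
    head_info[5]  = 1024;  // mem_d_model
    head_info[6]  = 1;     // tensor_para_size
    head_info[7]  = 1;     // pipeline_para_size
    head_info[8]  = 1;     // t5_with_bias
    head_info[9]  = 0;     // use_gated_activation
    head_info[10] = 1;     // position_embedding_type

    // What the deserializer reads back after the weight tensors.
    const bool t5_with_bias         = static_cast<bool>(head_info[8]);
    const bool use_gated_activation = static_cast<bool>(head_info[9]);
    std::printf("t5_with_bias=%d use_gated_activation=%d pos_emb=%lld\n",
                t5_with_bias, use_gated_activation, static_cast<long long>(head_info[10]));
}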
virtual ~IFT5Decoder() {} - virtual void forward(size_t batch_size, - size_t step, + virtual void forward(size_t batch_size, + size_t step, th::Tensor& from_tensor, th::Tensor& memory_tensor, th::Tensor& memory_sequence_length, @@ -42,26 +42,30 @@ class IFT5Decoder { template class FTT5Decoder: public IFT5Decoder { public: - FTT5Decoder(int head_num, - int head_size, - int inter_size, - int d_model, - int layer_num, - int mem_d_model, - int tensor_para_size, - int pipeline_para_size, + FTT5Decoder(int head_num, + int head_size, + int inter_size, + int d_model, + int layer_num, + int mem_d_model, + int tensor_para_size, + int pipeline_para_size, + bool t5_with_bias, + bool use_gated_activation, + int position_embedding_type, const std::vector& w): _head_num(head_num), _head_size(head_size), _inter_size(inter_size), + _t5_with_bias(t5_with_bias), + _use_gated_activation(use_gated_activation), + _position_embedding_type(position_embedding_type), _d_model(d_model), _weights(w), _layer_num(layer_num), _mem_d_model(mem_d_model) { - tensor_para_.world_size_ = tensor_para_size; - pipeline_para_.world_size_ = pipeline_para_size; - init_nccl_comm(); + ft::ftNcclInitialize(tensor_para_, pipeline_para_, tensor_para_size, pipeline_para_size); int hidden_dim = _head_num * _head_size; ft::check_cuda_error(cublasLtCreate(&_cublasltHandle)); @@ -82,124 +86,76 @@ class FTT5Decoder: public IFT5Decoder { decoder_layer_weights[i]->pre_layernorm_weights.gamma = get_ptr(_weights[0]) + (i - first_layer_index) * _d_model; decoder_layer_weights[i]->self_attention_weights.query_weight.kernel = - get_ptr(_weights[1]) + (i - first_layer_index) * _d_model * 3 * hidden_dim; + get_ptr(_weights[1]) + + (i - first_layer_index) * _d_model * 3 * hidden_dim / tensor_para_.world_size_; decoder_layer_weights[i]->self_attention_weights.attention_output_weight.kernel = - get_ptr(_weights[2]) + (i - first_layer_index) * hidden_dim * _d_model; + get_ptr(_weights[2]) + (i - first_layer_index) * hidden_dim / tensor_para_.world_size_ * _d_model; decoder_layer_weights[i]->self_attn_layernorm_weights.gamma = get_ptr(_weights[3]) + (i - first_layer_index) * _d_model; decoder_layer_weights[i]->cross_attention_weights.query_weight.kernel = - get_ptr(_weights[4]) + (i - first_layer_index) * _d_model * hidden_dim; + get_ptr(_weights[4]) + (i - first_layer_index) * _d_model * hidden_dim / tensor_para_.world_size_; decoder_layer_weights[i]->cross_attention_weights.key_weight.kernel = - get_ptr(_weights[5]) + (i - first_layer_index) * _mem_d_model * hidden_dim; + get_ptr(_weights[5]) + + (i - first_layer_index) * _mem_d_model * hidden_dim / tensor_para_.world_size_; decoder_layer_weights[i]->cross_attention_weights.value_weight.kernel = - get_ptr(_weights[6]) + (i - first_layer_index) * _mem_d_model * hidden_dim; + get_ptr(_weights[6]) + + (i - first_layer_index) * _mem_d_model * hidden_dim / tensor_para_.world_size_; decoder_layer_weights[i]->cross_attention_weights.attention_output_weight.kernel = - get_ptr(_weights[7]) + (i - first_layer_index) * hidden_dim * _d_model; + get_ptr(_weights[7]) + (i - first_layer_index) * hidden_dim / tensor_para_.world_size_ * _d_model; decoder_layer_weights[i]->cross_attn_layernorm_weights.gamma = get_ptr(_weights[8]) + (i - first_layer_index) * _d_model; decoder_layer_weights[i]->ffn_weights.intermediate_weight.kernel = - get_ptr(_weights[9]) + (i - first_layer_index) * _d_model * _inter_size; - decoder_layer_weights[i]->ffn_weights.output_weight.kernel = - get_ptr(_weights[10]) + (i - first_layer_index) * 
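[Editor's note] The FTT5Decoder weight pointers above are now offset by a per-rank shard size: every hidden or intermediate dimension that is split across the tensor-parallel group is divided by tensor_para_.world_size_ before being multiplied into the layer stride. A pointer-free sketch of that arithmetic; the helper name is illustrative, the formula matches the self-attention QKV case in the hunk:

#include <cstddef>
#include <cstdio>

std::size_t qkv_layer_offset(std::size_t layer_in_pipeline_stage,
                             std::size_t d_model,
                             std::size_t hidden_dim,
                             std::size_t tensor_para_size)
{
    // matches: (i - first_layer_index) * d_model * 3 * hidden_dim / tensor_para_.world_size_
    return layer_in_pipeline_stage * d_model * 3 * hidden_dim / tensor_para_size;
}

int main()
{
    // d_model = 1024, hidden_dim = 1024, 2-way tensor parallelism:
    // each rank's per-layer QKV slice holds half of the 3 * 1024 * 1024 elements.
    std::printf("layer 3 offset: %zu elements\n", qkv_layer_offset(3, 1024, 1024, 2));
}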
_inter_size * _d_model; - } - } - - void init_nccl_comm() - { - int mpi_initialized; - MPICHECK(MPI_Initialized(&mpi_initialized)); - if (!mpi_initialized) { - printf("[INFO] MPI is not initialized! Skipped the NCCL communication initialization.\n"); - if (tensor_para_.world_size_ != 1) { - printf("[FATAL] MPI initialization can only be skipped when tensor_para_size=1, but got %d!\n", - tensor_para_.world_size_); - } - if (pipeline_para_.world_size_ != 1) { - printf("[FATAL] MPI initialization can only be skipped when pipeline_para_size=1, but got %d!\n", - pipeline_para_.world_size_); - } - return; - } - - int rank, world_size; - MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &rank)); - MPICHECK(MPI_Comm_size(MPI_COMM_WORLD, &world_size)); - tensor_para_.rank_ = rank % tensor_para_.world_size_; - pipeline_para_.rank_ = rank / tensor_para_.world_size_; - - ncclUniqueId tensor_para_nccl_uid; - ncclUniqueId pipeline_para_nccl_uid; - - // assume gpu_num = n * k, - // tensor parallelism group size is n - // pipeline parallelism group size is k - if (tensor_para_.rank_ == 0) { - // get the uid of each tensor parallelism group - // here, 0, 1, ..., n-1 are in group 0, - // n, ..., 2n - 1 are in group 1. - NCCLCHECK(ncclGetUniqueId(&tensor_para_nccl_uid)); - for (int i = 1; i < (int)tensor_para_.world_size_; i++) { - FT_LOG_INFO("rank %d sends tensor_para_nccl_uid to rank %d \n", rank, rank + i); - MPICHECK(MPI_Send( - &tensor_para_nccl_uid, sizeof(tensor_para_nccl_uid), MPI_BYTE, rank + i, 0, MPI_COMM_WORLD)); + get_ptr(_weights[9]) + (i - first_layer_index) * _d_model * _inter_size / tensor_para_.world_size_; + if (_use_gated_activation) { + decoder_layer_weights[i]->ffn_weights.intermediate_weight2.kernel = + get_ptr(_weights[10]) + + (i - first_layer_index) * _d_model * _inter_size / tensor_para_.world_size_; } - } - else { - MPI_Status status; - FT_LOG_INFO("rank %d receives tensor_para_nccl_uid from rank %d \n", rank, rank - (int)tensor_para_.rank_); - MPICHECK(MPI_Recv(&tensor_para_nccl_uid, - sizeof(tensor_para_nccl_uid), - MPI_BYTE, - rank - tensor_para_.rank_, - 0, - MPI_COMM_WORLD, - &status)); - } + decoder_layer_weights[i]->ffn_weights.output_weight.kernel = + get_ptr(_weights[11]) + (i - first_layer_index) * _inter_size / tensor_para_.world_size_ * _d_model; - if (pipeline_para_.rank_ == 0) { - // get the uid of each pipeline parallelism group - // 0, k, 2k, are in group 0 - // 1, k+1, 2k+1 are in group 1 - NCCLCHECK(ncclGetUniqueId(&pipeline_para_nccl_uid)); - for (int i = 1; i < (int)pipeline_para_.world_size_; i++) { - FT_LOG_INFO("rank %d sends pipeline_para_nccl_uid to rank %d \n", - rank, - rank + i * (int)tensor_para_.world_size_); - MPICHECK(MPI_Send(&pipeline_para_nccl_uid, - sizeof(pipeline_para_nccl_uid), - MPI_BYTE, - rank + i * tensor_para_.world_size_, - 0, - MPI_COMM_WORLD)); + if (_t5_with_bias) { + decoder_layer_weights[i]->pre_layernorm_weights.beta = + get_ptr(_weights[12]) + (i - first_layer_index) * _d_model; + decoder_layer_weights[i]->self_attention_weights.query_weight.bias = + get_ptr(_weights[13]) + (i - first_layer_index) * 3 * hidden_dim / tensor_para_.world_size_; + decoder_layer_weights[i]->self_attention_weights.attention_output_weight.bias = + get_ptr(_weights[14]) + (i - first_layer_index) * _d_model; + decoder_layer_weights[i]->self_attn_layernorm_weights.beta = + get_ptr(_weights[15]) + (i - first_layer_index) * _d_model; + decoder_layer_weights[i]->cross_attention_weights.query_weight.bias = + get_ptr(_weights[16]) + (i - first_layer_index) * hidden_dim / 
tensor_para_.world_size_; + decoder_layer_weights[i]->cross_attention_weights.key_weight.bias = + get_ptr(_weights[17]) + (i - first_layer_index) * hidden_dim / tensor_para_.world_size_; + decoder_layer_weights[i]->cross_attention_weights.value_weight.bias = + get_ptr(_weights[18]) + (i - first_layer_index) * hidden_dim / tensor_para_.world_size_; + decoder_layer_weights[i]->cross_attention_weights.attention_output_weight.bias = + get_ptr(_weights[19]) + (i - first_layer_index) * _d_model; + decoder_layer_weights[i]->cross_attn_layernorm_weights.beta = + get_ptr(_weights[20]) + (i - first_layer_index) * _d_model; + decoder_layer_weights[i]->ffn_weights.intermediate_weight.bias = + get_ptr(_weights[21]) + (i - first_layer_index) * _inter_size / tensor_para_.world_size_; + if (_use_gated_activation) { + decoder_layer_weights[i]->ffn_weights.intermediate_weight2.bias = + get_ptr(_weights[22]) + (i - first_layer_index) * _inter_size / tensor_para_.world_size_; + } + decoder_layer_weights[i]->ffn_weights.output_weight.bias = + get_ptr(_weights[23]) + (i - first_layer_index) * _d_model; } } - else { - MPI_Status status; - FT_LOG_INFO( - "rank %d receives pipeline_para_nccl_uid from rank %d \n", rank, rank % (int)tensor_para_.world_size_); - MPICHECK(MPI_Recv(&pipeline_para_nccl_uid, - sizeof(pipeline_para_nccl_uid), - MPI_BYTE, - rank % tensor_para_.world_size_, - 0, - MPI_COMM_WORLD, - &status)); - } - NCCLCHECK(ncclCommInitRank( - &tensor_para_.nccl_comm_, tensor_para_.world_size_, tensor_para_nccl_uid, tensor_para_.rank_)); - NCCLCHECK(ncclCommInitRank( - &pipeline_para_.nccl_comm_, pipeline_para_.world_size_, pipeline_para_nccl_uid, pipeline_para_.rank_)); } ~FTT5Decoder() override { + ft::ftNcclParamDestroy(tensor_para_); + ft::ftNcclParamDestroy(pipeline_para_); cublasLtDestroy(_cublasltHandle); delete cublas_algo_map_; delete cublas_wrapper_mutex_; } - void forward(size_t batch_size, - size_t step, + void forward(size_t batch_size, + size_t step, th::Tensor& from_tensor, th::Tensor& memory_tensor, th::Tensor& memory_sequence_length, @@ -211,7 +167,7 @@ class FTT5Decoder: public IFT5Decoder { th::Tensor& memory_cache_values_tensor, th::Tensor& relative_attention_bias_tensor) override { - auto stream = at::cuda::getCurrentCUDAStream().stream(); + auto stream = at::cuda::getCurrentCUDAStream().stream(); cublasHandle_t _cublasHandle = at::cuda::getCurrentCUDABlasHandle(); cublasSetStream(_cublasHandle, stream); fastertransformer::Allocator* allocator = @@ -222,25 +178,36 @@ class FTT5Decoder: public IFT5Decoder { if (std::is_same::value) { cublas_wrapper->setFP16GemmConfig(); } +#ifdef ENABLE_BF16 + else if (std::is_same::value) { + cublas_wrapper->setBF16GemmConfig(); + } +#endif else if (std::is_same::value) { cublas_wrapper->setFP32GemmConfig(); } + // NeMo Megatron T5 with bias and use gelu or geglu. + ft::ActivationType activation_type = + _t5_with_bias ? (_use_gated_activation ? 
ft::ActivationType::GeGLU : ft::ActivationType::Gelu) : + ft::ActivationType::Relu; + ft::T5Decoder decoder = ft::T5Decoder(batch_size, _head_num, _head_size, _inter_size, _d_model, _layer_num, + _layernorm_eps, stream, cublas_wrapper, allocator, true, tensor_para_, pipeline_para_, - ft::ActivationType::Relu); + activation_type); - int tmp_step = step + 1; + int tmp_step = step + 1; std::vector input_tensors = std::vector{convert_tensor(from_tensor), convert_tensor(memory_tensor), @@ -272,20 +239,25 @@ class FTT5Decoder: public IFT5Decoder { } private: - const int _head_num; - const int _head_size; - const int _inter_size; - const int _d_model; - std::vector _weights; - const int _layer_num; - const int _mem_d_model; - cublasLtHandle_t _cublasltHandle; - std::mutex* cublas_wrapper_mutex_; - ft::cublasAlgoMap* cublas_algo_map_; + const int _head_num; + const int _head_size; + const int _inter_size; + const int _d_model; + std::vector _weights; + const int _layer_num; + static constexpr float _layernorm_eps = 1e-6f; + const int _mem_d_model; + cublasLtHandle_t _cublasltHandle; + std::mutex* cublas_wrapper_mutex_; + ft::cublasAlgoMap* cublas_algo_map_; std::vector*> decoder_layer_weights; ft::NcclParam tensor_para_; ft::NcclParam pipeline_para_; + + bool _t5_with_bias; + bool _use_gated_activation; + int _position_embedding_type; }; class FasterTransformerT5Decoder: public th::jit::CustomClassHolder { @@ -300,19 +272,35 @@ class FasterTransformerT5Decoder: public th::jit::CustomClassHolder { th::Tensor cross_output_kernel, th::Tensor ffn_layernorm_gamma, th::Tensor inter_kernel, + th::Tensor inter_kernel2, th::Tensor output_kernel, - int64_t head_num, - int64_t head_size, - int64_t inter_size, - int64_t d_model, - int64_t layer_num, - int64_t mem_d_model, - int64_t tensor_para_size, - int64_t pipeline_para_size); + th::Tensor self_layernorm_beta, + th::Tensor self_bias_qkv, + th::Tensor self_output_bias, + th::Tensor cross_layernorm_beta, + th::Tensor cross_bias_q, + th::Tensor cross_bias_k, + th::Tensor cross_bias_v, + th::Tensor cross_output_bias, + th::Tensor ffn_layernorm_beta, + th::Tensor inter_bias, + th::Tensor inter_bias2, + th::Tensor output_bias, + int64_t head_num, + int64_t head_size, + int64_t inter_size, + int64_t d_model, + int64_t layer_num, + int64_t mem_d_model, + int64_t tensor_para_size, + int64_t pipeline_para_size, + bool t5_with_bias, + bool use_gated_activation, + int64_t position_embedding_type); ~FasterTransformerT5Decoder(); - std::vector forward(int64_t step, + std::vector forward(int64_t step, th::Tensor from_tensor, th::Tensor memory_tensor, th::Tensor memory_sequence_length, @@ -326,9 +314,9 @@ class FasterTransformerT5Decoder: public th::jit::CustomClassHolder { std::vector get_pickle_info() const; private: - const at::ScalarType _st; - IFT5Decoder* ftdecoder; - th::Tensor head_info; + const at::ScalarType _st; + IFT5Decoder* ftdecoder; + th::Tensor head_info; std::vector weights; }; diff --git a/src/fastertransformer/th_op/t5/T5DecodingOp.cc b/src/fastertransformer/th_op/t5/T5DecodingOp.cc index 055d66a8b..a1c9b84d9 100644 --- a/src/fastertransformer/th_op/t5/T5DecodingOp.cc +++ b/src/fastertransformer/th_op/t5/T5DecodingOp.cc @@ -20,63 +20,85 @@ namespace th = torch; namespace torch_ext { -FasterTransformerT5Decoding::FasterTransformerT5Decoding(int64_t head_num, - int64_t size_per_head, - int64_t inter_size, - int64_t mem_d_model, - int64_t d_model, - int64_t layer_num, - int64_t vocab_size, - int64_t num_bucket, - int64_t max_distance, - double q_scaling, 
- int64_t start_id, - int64_t end_id, - int64_t tensor_para_size, - int64_t pipeline_para_size, - bool t5_with_bias, - int64_t position_embedding_type, - th::Tensor self_layernorm_gamma, - th::Tensor self_kernel_q, - th::Tensor self_output_kernel, - th::Tensor cross_layernorm_gamma, - th::Tensor cross_kernel_q, - th::Tensor cross_kernel_k, - th::Tensor cross_kernel_v, - th::Tensor cross_output_kernel, - th::Tensor ffn_layernorm_gamma, - th::Tensor inter_kernel, - th::Tensor output_kernel, - th::Tensor decoding_gamma, - th::Tensor embedding_table, - th::Tensor absolute_or_relative_position_embedding, - th::Tensor self_layernorm_beta, - th::Tensor self_bias_qkv, - th::Tensor self_output_bias, - th::Tensor cross_layernorm_beta, - th::Tensor cross_bias_q, - th::Tensor cross_bias_k, - th::Tensor cross_bias_v, - th::Tensor cross_output_bias, - th::Tensor ffn_layernorm_beta, - th::Tensor inter_bias, - th::Tensor output_bias, - th::Tensor decoding_beta, - th::Tensor embedding_bias): - _st(self_layernorm_gamma.scalar_type()), weights{self_layernorm_gamma, self_kernel_q, - self_output_kernel, cross_layernorm_gamma, - cross_kernel_q, cross_kernel_k, - cross_kernel_v, cross_output_kernel, - ffn_layernorm_gamma, inter_kernel, - output_kernel, decoding_gamma, - embedding_table, absolute_or_relative_position_embedding, - self_layernorm_beta, self_bias_qkv, - self_output_bias, cross_layernorm_beta, - cross_bias_q, cross_bias_k, - cross_bias_v, cross_output_bias, - ffn_layernorm_beta, inter_bias, - output_bias, decoding_beta, - embedding_bias} +FasterTransformerT5Decoding::FasterTransformerT5Decoding(int64_t head_num, + int64_t size_per_head, + int64_t inter_size, + int64_t mem_d_model, + int64_t d_model, + int64_t layer_num, + int64_t vocab_size, + int64_t num_bucket, + int64_t max_distance, + double q_scaling, + int64_t start_id, + int64_t end_id, + int64_t tensor_para_size, + int64_t pipeline_para_size, + bool t5_with_bias, + int64_t position_embedding_type, + std::string activation_type, + bool tie_word_embeddings, + th::Tensor self_layernorm_gamma, + th::Tensor self_kernel_q, + th::Tensor self_output_kernel, + th::Tensor cross_layernorm_gamma, + th::Tensor cross_kernel_q, + th::Tensor cross_kernel_k, + th::Tensor cross_kernel_v, + th::Tensor cross_output_kernel, + th::Tensor ffn_layernorm_gamma, + th::Tensor inter_kernel, + th::Tensor inter_kernel2, + th::Tensor output_kernel, + th::Tensor decoding_gamma, + th::Tensor embedding_table, + th::Tensor lm_head, + th::Tensor absolute_or_relative_position_embedding, + th::Tensor self_layernorm_beta, + th::Tensor self_bias_qkv, + th::Tensor self_output_bias, + th::Tensor cross_layernorm_beta, + th::Tensor cross_bias_q, + th::Tensor cross_bias_k, + th::Tensor cross_bias_v, + th::Tensor cross_output_bias, + th::Tensor ffn_layernorm_beta, + th::Tensor inter_bias, + th::Tensor inter_bias2, + th::Tensor output_bias, + th::Tensor decoding_beta, + th::Tensor embedding_bias): + _st(self_layernorm_gamma.scalar_type()), + weights{self_layernorm_gamma, + self_kernel_q, + self_output_kernel, + cross_layernorm_gamma, + cross_kernel_q, + cross_kernel_k, + cross_kernel_v, + cross_output_kernel, + ffn_layernorm_gamma, + inter_kernel, + inter_kernel2, + output_kernel, + decoding_gamma, + embedding_table, + lm_head, + absolute_or_relative_position_embedding, + self_layernorm_beta, + self_bias_qkv, + self_output_bias, + cross_layernorm_beta, + cross_bias_q, + cross_bias_k, + cross_bias_v, + cross_output_bias, + ffn_layernorm_beta, + inter_bias, + inter_bias2, + output_bias, + 
decoding_beta, + embedding_bias} { CHECK_INPUT(self_layernorm_gamma, _st); // layer_num, d_model CHECK_INPUT(self_kernel_q, _st); // layer_num, d_model, 3 * hidden_dim @@ -88,9 +110,11 @@ FasterTransformerT5Decoding::FasterTransformerT5Decoding(int64_t head_num, CHECK_INPUT(cross_output_kernel, _st); // layer_num, hidden_dim, d_model CHECK_INPUT(ffn_layernorm_gamma, _st); // layer_num, d_model CHECK_INPUT(inter_kernel, _st); // layer_num, d_model, inter_size + CHECK_INPUT(inter_kernel2, _st); // layer_num, d_model, inter_size CHECK_INPUT(output_kernel, _st); // layer_num, inter_size, d_model CHECK_INPUT(decoding_gamma, _st); // d_model CHECK_INPUT(embedding_table, _st); // vocab_size, d_model + CHECK_INPUT(lm_head, _st); // d_model, vocab_size CHECK_INPUT(absolute_or_relative_position_embedding, _st); // head_num, num_bucket or max_seq_len, d_model if (t5_with_bias) { CHECK_INPUT(self_layernorm_beta, _st); // layer_num, d_model @@ -103,6 +127,7 @@ FasterTransformerT5Decoding::FasterTransformerT5Decoding(int64_t head_num, CHECK_INPUT(cross_output_bias, _st); // layer_num, d_model CHECK_INPUT(ffn_layernorm_beta, _st); // layer_num, d_model CHECK_INPUT(inter_bias, _st); // layer_num, inter_size + CHECK_INPUT(inter_bias2, _st); // layer_num, inter_size CHECK_INPUT(output_bias, _st); // layer_num, d_model CHECK_INPUT(decoding_beta, _st); // d_model CHECK_INPUT(embedding_bias, _st); // vocab_size @@ -125,6 +150,8 @@ FasterTransformerT5Decoding::FasterTransformerT5Decoding(int64_t head_num, pipeline_para_size, t5_with_bias, ft::PositionEmbeddingType(position_embedding_type), + ft::getActivationType(activation_type), + tie_word_embeddings, weights); break; case at::ScalarType::Half: @@ -144,28 +171,36 @@ FasterTransformerT5Decoding::FasterTransformerT5Decoding(int64_t head_num, pipeline_para_size, t5_with_bias, ft::PositionEmbeddingType(position_embedding_type), + ft::getActivationType(activation_type), + tie_word_embeddings, weights); break; +#ifdef ENABLE_BF16 + case at::ScalarType::BFloat16: + ftdecoding = new torch_ext::FTT5Decoding<__nv_bfloat16>(head_num, + size_per_head, + inter_size, + mem_d_model, + d_model, + layer_num, + vocab_size, + num_bucket, + max_distance, + q_scaling, + start_id, + end_id, + tensor_para_size, + pipeline_para_size, + t5_with_bias, + ft::PositionEmbeddingType(position_embedding_type), + ft::getActivationType(activation_type), + tie_word_embeddings, + weights); + break; +#endif default: throw std::runtime_error("Wrong Tensor type."); } - int_info_ = torch::empty({16}, torch::dtype(torch::kInt64)); - int_info_[0] = head_num; - int_info_[1] = size_per_head; - int_info_[2] = inter_size; - int_info_[3] = mem_d_model; - int_info_[4] = d_model; - int_info_[5] = layer_num; - int_info_[6] = vocab_size; - int_info_[7] = num_bucket; - int_info_[8] = max_distance; - int_info_[9] = q_scaling; - int_info_[10] = start_id; - int_info_[11] = end_id; - int_info_[12] = tensor_para_size; - int_info_[13] = pipeline_para_size; - int_info_[14] = t5_with_bias; - int_info_[15] = position_embedding_type; } FasterTransformerT5Decoding::~FasterTransformerT5Decoding() @@ -173,36 +208,38 @@ FasterTransformerT5Decoding::~FasterTransformerT5Decoding() delete ftdecoding; } -std::vector FasterTransformerT5Decoding::forward(int64_t beam_width, - int64_t max_seq_len, - int64_t top_k, - double top_p, - double beam_search_diversity_rate, - double temperature, - double len_penalty, - double repetition_penalty, - int64_t random_seed, - bool is_return_output_log_probs, - bool is_return_cum_log_probs, - 
th::Tensor memory, - th::Tensor memory_seq_lens) +std::vector FasterTransformerT5Decoding::forward(th::optional beam_width, + int64_t max_seq_len, + th::optional top_k, + th::optional top_p, + th::optional beam_search_diversity_rate, + th::optional temperature, + th::optional len_penalty, + th::optional repetition_penalty, + th::optional random_seed, + th::optional is_return_output_log_probs, + th::optional is_return_cum_log_probs, + th::optional is_return_cross_attentions, + th::Tensor memory, + th::Tensor memory_seq_lens) { CHECK_INPUT(memory, _st); CHECK_TH_CUDA(memory_seq_lens); CHECK_CONTIGUOUS(memory_seq_lens); TORCH_CHECK(memory_seq_lens.dtype() == torch::kInt32, "mem_seq_lens dtype should be int32"); - auto results = ftdecoding->forward((size_t)beam_width, + auto results = ftdecoding->forward(beam_width, (size_t)max_seq_len, - (size_t)top_k, - (float)top_p, - (float)beam_search_diversity_rate, - (float)temperature, - (float)len_penalty, - (float)repetition_penalty, - (unsigned long long)random_seed, + top_k, + top_p, + beam_search_diversity_rate, + temperature, + len_penalty, + repetition_penalty, + random_seed, is_return_output_log_probs, is_return_cum_log_probs, + is_return_cross_attentions, memory, memory_seq_lens); return results; @@ -211,7 +248,6 @@ std::vector FasterTransformerT5Decoding::forward(int64_t beam_width, std::vector FasterTransformerT5Decoding::get_pickle_info() const { std::vector tmp(weights); - tmp.push_back(int_info_); return tmp; } @@ -237,8 +273,13 @@ static auto fasterTransformerT5DecodingTHS = int64_t, int64_t, int64_t, + bool, int64_t, - int64_t, + std::string, + bool, + th::Tensor, + th::Tensor, + th::Tensor, th::Tensor, th::Tensor, th::Tensor, @@ -266,69 +307,4 @@ static auto fasterTransformerT5DecodingTHS = th::Tensor, th::Tensor, th::Tensor>()) - .def("forward", &torch_ext::FasterTransformerT5Decoding::forward) - .def_pickle( - [](const c10::intrusive_ptr& self) -> std::vector { - return self->get_pickle_info(); - }, - [](std::vector state) -> c10::intrusive_ptr { - int head_num = state[27][0].item().to(); - int size_per_head = state[27][1].item().to(); - int inter_size = state[27][2].item().to(); - int mem_d_model = state[27][3].item().to(); - int d_model = state[27][4].item().to(); - int layer_num = state[27][5].item().to(); - int vocab_size = state[27][6].item().to(); - int num_bucket = state[27][7].item().to(); - int max_distance = state[27][8].item().to(); - int start_id = state[27][9].item().to(); - int end_id = state[27][10].item().to(); - int tensor_para_size = state[27][11].item().to(); - int pipeline_para_size = state[27][12].item().to(); - bool t5_with_bias = (bool)state[27][13].item().to(); - int position_embedding_type = state[27][14].item().to(); - double q_scaling = state[28][0].item().to(); - return c10::make_intrusive(head_num, - size_per_head, - inter_size, - mem_d_model, - d_model, - layer_num, - vocab_size, - num_bucket, - max_distance, - q_scaling, - start_id, - end_id, - tensor_para_size, - pipeline_para_size, - t5_with_bias, - position_embedding_type, - state[0], - state[1], - state[2], - state[3], - state[4], - state[5], - state[6], - state[7], - state[8], - state[9], - state[10], - state[11], - state[12], - state[13], - state[14], - state[15], - state[16], - state[17], - state[18], - state[19], - state[20], - state[21], - state[22], - state[23], - state[24], - state[25], - state[26]); - }); \ No newline at end of file + .def("forward", &torch_ext::FasterTransformerT5Decoding::forward); \ No newline at end of file diff --git 
a/src/fastertransformer/th_op/t5/T5DecodingOp.h b/src/fastertransformer/th_op/t5/T5DecodingOp.h index 91114b219..0fd466a7d 100644 --- a/src/fastertransformer/th_op/t5/T5DecodingOp.h +++ b/src/fastertransformer/th_op/t5/T5DecodingOp.h @@ -16,7 +16,7 @@ #include "src/fastertransformer/models/t5/T5Decoding.h" #include "src/fastertransformer/th_op/th_utils.h" -#include "src/fastertransformer/utils/mpi_utils.h" +#include "src/fastertransformer/utils/nccl_utils.h" namespace ft = fastertransformer; namespace th = torch; @@ -26,40 +26,43 @@ namespace torch_ext { class IFTT5Decoding { public: virtual ~IFTT5Decoding() {} - virtual std::vector forward(size_t beam_width, - size_t max_seq_len, - size_t top_k, - float top_p, - float beam_search_diversity_rate, - float temperature, - float len_penalty, - float repetition_penalty, - unsigned long long random_seed, - bool is_return_output_log_probs, - bool is_return_cum_log_probs, - th::Tensor memory, - th::Tensor memory_seq_lens) = 0; + virtual std::vector forward(th::optional beam_width_opt, + size_t max_seq_len, + th::optional top_k_opt, + th::optional top_p_opt, + th::optional beam_search_diversity_rate_opt, + th::optional temperature_opt, + th::optional len_penalty_opt, + th::optional repetition_penalty_opt, + th::optional random_seed_opt, + th::optional is_return_output_log_probs_opt, + th::optional is_return_cum_log_probs_opt, + th::optional is_return_cross_attentions_opt, + th::Tensor memory, + th::Tensor memory_seq_lens) = 0; }; template class FTT5Decoding: public IFTT5Decoding { public: - FTT5Decoding(int head_num, - int size_per_head, - int inter_size, - int mem_d_model, - int d_model, - int layer_num, - int vocab_size, - int num_bucket, - int max_distance, - double q_scaling, - int start_id, - int end_id, - int tensor_para_size, - int pipeline_para_size, - bool t5_with_bias, - ft::PositionEmbeddingType position_embedding_type, + FTT5Decoding(int head_num, + int size_per_head, + int inter_size, + int mem_d_model, + int d_model, + int layer_num, + int vocab_size, + int num_bucket, + int max_distance, + double q_scaling, + int start_id, + int end_id, + int tensor_para_size, + int pipeline_para_size, + bool t5_with_bias, + ft::PositionEmbeddingType position_embedding_type, + ft::ActivationType activation_type, + bool tie_word_embeddings, const std::vector& w): head_num_(head_num), size_per_head_(size_per_head), @@ -75,18 +78,19 @@ class FTT5Decoding: public IFTT5Decoding { end_id_(end_id), t5_with_bias_(t5_with_bias), position_embedding_type_(position_embedding_type), + activation_type_(activation_type), + tie_word_embeddings_(tie_word_embeddings), _weights(w) { - tensor_para_.world_size_ = tensor_para_size; - pipeline_para_.world_size_ = pipeline_para_size; - init_nccl_comm(); + bool use_gated_activation = isGatedActivation(activation_type_); + ft::ftNcclInitialize(tensor_para_, pipeline_para_, tensor_para_size, pipeline_para_size); ft::check_cuda_error(cublasLtCreate(&cublasltHandle_)); - cublas_algo_map_ = new ft::cublasAlgoMap("gemm_config.in"); + cublas_algo_map_ = new ft::cublasAlgoMap("gemm_config.in"); cublas_wrapper_mutex_ = new std::mutex(); decoding_weights.resizeLayer(layer_num_); - decoding_weights.setT5StructureDiff(t5_with_bias, position_embedding_type); + decoding_weights.setT5StructureDiff(t5_with_bias, use_gated_activation, position_embedding_type); const int hidden_dim = head_num_ * size_per_head_; for (int i = 0; i < layer_num_; ++i) { @@ -117,205 +121,150 @@ class FTT5Decoding: public IFTT5Decoding { get_ptr(_weights[8]) + (i - 
first_layer_index) * d_model; decoding_weights.decoder_layer_weights[i]->ffn_weights.intermediate_weight.kernel = get_ptr(_weights[9]) + (i - first_layer_index) * d_model * inter_size_ / tensor_para_.world_size_; + if (use_gated_activation) { + decoding_weights.decoder_layer_weights[i]->ffn_weights.intermediate_weight2.kernel = + get_ptr(_weights[10]) + + (i - first_layer_index) * d_model * inter_size_ / tensor_para_.world_size_; + } decoding_weights.decoder_layer_weights[i]->ffn_weights.output_weight.kernel = - get_ptr(_weights[10]) + (i - first_layer_index) * inter_size_ / tensor_para_.world_size_ * d_model; + get_ptr(_weights[11]) + (i - first_layer_index) * inter_size_ / tensor_para_.world_size_ * d_model; if (t5_with_bias_) { decoding_weights.decoder_layer_weights[i]->pre_layernorm_weights.beta = - get_ptr(_weights[14]) + (i - first_layer_index) * d_model; + get_ptr(_weights[16]) + (i - first_layer_index) * d_model; decoding_weights.decoder_layer_weights[i]->self_attention_weights.query_weight.bias = - get_ptr(_weights[15]) + (i - first_layer_index) * 3 * hidden_dim / tensor_para_.world_size_; + get_ptr(_weights[17]) + (i - first_layer_index) * 3 * hidden_dim / tensor_para_.world_size_; decoding_weights.decoder_layer_weights[i]->self_attention_weights.attention_output_weight.bias = - get_ptr(_weights[16]) + (i - first_layer_index) * d_model; + get_ptr(_weights[18]) + (i - first_layer_index) * d_model; decoding_weights.decoder_layer_weights[i]->self_attn_layernorm_weights.beta = - get_ptr(_weights[17]) + (i - first_layer_index) * d_model; + get_ptr(_weights[19]) + (i - first_layer_index) * d_model; decoding_weights.decoder_layer_weights[i]->cross_attention_weights.query_weight.bias = - get_ptr(_weights[18]) + (i - first_layer_index) * hidden_dim / tensor_para_.world_size_; + get_ptr(_weights[20]) + (i - first_layer_index) * hidden_dim / tensor_para_.world_size_; decoding_weights.decoder_layer_weights[i]->cross_attention_weights.key_weight.bias = - get_ptr(_weights[19]) + (i - first_layer_index) * hidden_dim / tensor_para_.world_size_; + get_ptr(_weights[21]) + (i - first_layer_index) * hidden_dim / tensor_para_.world_size_; decoding_weights.decoder_layer_weights[i]->cross_attention_weights.value_weight.bias = - get_ptr(_weights[20]) + (i - first_layer_index) * hidden_dim / tensor_para_.world_size_; + get_ptr(_weights[22]) + (i - first_layer_index) * hidden_dim / tensor_para_.world_size_; decoding_weights.decoder_layer_weights[i]->cross_attention_weights.attention_output_weight.bias = - get_ptr(_weights[21]) + (i - first_layer_index) * d_model; + get_ptr(_weights[23]) + (i - first_layer_index) * d_model; decoding_weights.decoder_layer_weights[i]->cross_attn_layernorm_weights.beta = - get_ptr(_weights[22]) + (i - first_layer_index) * d_model; + get_ptr(_weights[24]) + (i - first_layer_index) * d_model; decoding_weights.decoder_layer_weights[i]->ffn_weights.intermediate_weight.bias = - get_ptr(_weights[23]) + (i - first_layer_index) * inter_size_ / tensor_para_.world_size_; + get_ptr(_weights[25]) + (i - first_layer_index) * inter_size_ / tensor_para_.world_size_; + if (use_gated_activation) { + decoding_weights.decoder_layer_weights[i]->ffn_weights.intermediate_weight2.bias = + get_ptr(_weights[26]) + (i - first_layer_index) * inter_size_ / tensor_para_.world_size_; + } decoding_weights.decoder_layer_weights[i]->ffn_weights.output_weight.bias = - get_ptr(_weights[24]) + (i - first_layer_index) * d_model; + get_ptr(_weights[27]) + (i - first_layer_index) * d_model; } } - 
decoding_weights.post_decoder_layernorm.gamma = get_ptr(_weights[11]); - decoding_weights.pre_decoder_embedding_table = get_ptr(_weights[12]); - decoding_weights.post_decoder_embedding.kernel = get_ptr(_weights[12]); - decoding_weights.absolute_or_relative_position_embedding = get_ptr(_weights[13]); + decoding_weights.post_decoder_layernorm.gamma = get_ptr(_weights[12]); + decoding_weights.pre_decoder_embedding_table = get_ptr(_weights[13]); + decoding_weights.post_decoder_embedding.kernel = get_ptr(_weights[14]); + decoding_weights.absolute_or_relative_position_embedding = get_ptr(_weights[15]); if (t5_with_bias_) { - decoding_weights.post_decoder_layernorm.beta = get_ptr(_weights[25]); - decoding_weights.post_decoder_embedding.bias = get_ptr(_weights[26]); + decoding_weights.post_decoder_layernorm.beta = get_ptr(_weights[28]); + decoding_weights.post_decoder_embedding.bias = get_ptr(_weights[29]); } int device_id = 0; ft::check_cuda_error(cudaGetDevice(&device_id)); ft::check_cuda_error(cudaGetDeviceProperties(&prop_, device_id)); } - void init_nccl_comm() - { - int mpi_initialized; - MPICHECK(MPI_Initialized(&mpi_initialized)); - if (!mpi_initialized) { - FT_LOG_INFO("MPI is not initialized! Skipped the NCCL communication initialization.\n"); - if (tensor_para_.world_size_ != 1) { - printf("[FATAL] MPI initialization can only be skipped when tensor_para_size=1, but got %d!\n", - tensor_para_.world_size_); - } - if (pipeline_para_.world_size_ != 1) { - printf("[FATAL] MPI initialization can only be skipped when pipeline_para_size=1, but got %d!\n", - pipeline_para_.world_size_); - } - return; - } - - int rank, world_size; - MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &rank)); - MPICHECK(MPI_Comm_size(MPI_COMM_WORLD, &world_size)); - tensor_para_.rank_ = rank % tensor_para_.world_size_; - pipeline_para_.rank_ = rank / tensor_para_.world_size_; - - ncclUniqueId tensor_para_nccl_uid; - ncclUniqueId pipeline_para_nccl_uid; - - // assume gpu_num = n * k, - // tensor parallelism group size is n - // pipeline parallelism group size is k - if (tensor_para_.rank_ == 0) { - // get the uid of each tensor parallelism group - // here, 0, 1, ..., n-1 are in group 0, - // n, ..., 2n - 1 are in group 1. 
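// The removed init_nccl_comm() here splits a flat MPI rank into a tensor-parallel and a
// pipeline-parallel coordinate (tensor rank = rank % tensor_para_size, pipeline rank =
// rank / tensor_para_size) before exchanging NCCL unique IDs; the replacement call
// ft::ftNcclInitialize(tensor_para_, pipeline_para_, tensor_para_size, pipeline_para_size)
// is assumed to encapsulate that same bootstrap. A minimal standalone sketch of the
// rank-to-group arithmetic the removed comment describes, with illustrative values only:
#include <cstdio>

int main()
{
    const int tensor_para_size   = 2;                                      // n
    const int pipeline_para_size = 2;                                      // k
    const int world_size         = tensor_para_size * pipeline_para_size;  // gpu_num = n * k

    for (int rank = 0; rank < world_size; ++rank) {
        const int tensor_rank   = rank % tensor_para_size;  // local rank inside its tensor-parallel group
        const int pipeline_rank = rank / tensor_para_size;  // local rank inside its pipeline-parallel group
        std::printf("rank %d -> tensor group %d (local rank %d), pipeline group %d (local rank %d)\n",
                    rank,
                    rank / tensor_para_size,  // ranks 0..n-1 share tensor group 0, n..2n-1 share group 1, ...
                    tensor_rank,
                    rank % tensor_para_size,  // ranks 0, n, 2n, ... share pipeline group 0
                    pipeline_rank);
    }
    return 0;
}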
- NCCLCHECK(ncclGetUniqueId(&tensor_para_nccl_uid)); - for (int i = 1; i < (int)tensor_para_.world_size_; i++) { - printf("[INFO] rank %d sends tensor_para_nccl_uid to rank %d \n", rank, rank + i); - MPICHECK(MPI_Send( - &tensor_para_nccl_uid, sizeof(tensor_para_nccl_uid), MPI_BYTE, rank + i, 0, MPI_COMM_WORLD)); - } - } - else { - MPI_Status status; - printf( - "[INFO] rank %d receives tensor_para_nccl_uid from rank %d \n", rank, rank - (int)tensor_para_.rank_); - MPICHECK(MPI_Recv(&tensor_para_nccl_uid, - sizeof(tensor_para_nccl_uid), - MPI_BYTE, - rank - tensor_para_.rank_, - 0, - MPI_COMM_WORLD, - &status)); - } - - if (pipeline_para_.rank_ == 0) { - // get the uid of each pipeline parallelism group - // 0, k, 2k, are in group 0 - // 1, k+1, 2k+1 are in group 1 - NCCLCHECK(ncclGetUniqueId(&pipeline_para_nccl_uid)); - for (int i = 1; i < (int)pipeline_para_.world_size_; i++) { - printf("[INFO] rank %d sends pipeline_para_nccl_uid to rank %d \n", - rank, - rank + i * (int)tensor_para_.world_size_); - MPICHECK(MPI_Send(&pipeline_para_nccl_uid, - sizeof(pipeline_para_nccl_uid), - MPI_BYTE, - rank + i * tensor_para_.world_size_, - 0, - MPI_COMM_WORLD)); - } - } - else { - MPI_Status status; - printf("[INFO] rank %d receives pipeline_para_nccl_uid from rank %d \n", - rank, - rank % (int)tensor_para_.world_size_); - MPICHECK(MPI_Recv(&pipeline_para_nccl_uid, - sizeof(pipeline_para_nccl_uid), - MPI_BYTE, - rank % tensor_para_.world_size_, - 0, - MPI_COMM_WORLD, - &status)); - } - NCCLCHECK(ncclCommInitRank( - &tensor_para_.nccl_comm_, tensor_para_.world_size_, tensor_para_nccl_uid, tensor_para_.rank_)); - NCCLCHECK(ncclCommInitRank( - &pipeline_para_.nccl_comm_, pipeline_para_.world_size_, pipeline_para_nccl_uid, pipeline_para_.rank_)); - } - ~FTT5Decoding() override { + ft::ftNcclParamDestroy(tensor_para_); + ft::ftNcclParamDestroy(pipeline_para_); cublasLtDestroy(cublasltHandle_); delete cublas_algo_map_; delete cublas_wrapper_mutex_; } - std::vector forward(size_t beam_width, - size_t max_seq_len, - size_t top_k, - float top_p, - float beam_search_diversity_rate, - float temperature, - float len_penalty, - float repetition_penalty, - unsigned long long random_seed, - bool is_return_output_log_probs, - bool is_return_cum_log_probs, - th::Tensor memory, - th::Tensor memory_seq_lens) override + std::vector forward(th::optional beam_width_opt, + size_t max_seq_len, + th::optional top_k_opt, + th::optional top_p_opt, + th::optional beam_search_diversity_rate_opt, + th::optional temperature_opt, + th::optional len_penalty_opt, + th::optional repetition_penalty_opt, + th::optional random_seed_opt, + th::optional is_return_output_log_probs_opt, + th::optional is_return_cum_log_probs_opt, + th::optional is_return_cross_attentions_opt, + th::Tensor memory, + th::Tensor memory_seq_lens) override { - auto stream = at::cuda::getCurrentCUDAStream().stream(); + // input validation + size_t beam_width = beam_width_opt.has_value() ? (size_t)beam_width_opt.value() : 1; + uint top_k = top_k_opt.has_value() ? (uint)top_k_opt.value() : 1; + float top_p = top_p_opt.has_value() ? (float)top_p_opt.value() : 0.0f; + float beam_search_diversity_rate = + beam_search_diversity_rate_opt.has_value() ? (float)beam_search_diversity_rate_opt.value() : 0.0f; + float temperature = temperature_opt.has_value() ? (float)temperature_opt.value() : 1.0f; + float len_penalty = len_penalty_opt.has_value() ? (float)len_penalty_opt.value() : 0.0f; + float repetition_penalty = repetition_penalty_opt.has_value() ? 
(float)repetition_penalty_opt.value() : 1.0f; + unsigned long long random_seed = random_seed_opt.has_value() ? (unsigned long long)random_seed_opt.value() : 0; + bool is_return_output_log_probs = + is_return_output_log_probs_opt.has_value() ? (bool)is_return_output_log_probs_opt.value() : false; + bool is_return_cum_log_probs = + is_return_cum_log_probs_opt.has_value() ? (bool)is_return_cum_log_probs_opt.value() : false; + bool is_return_cross_attentions = + is_return_cross_attentions_opt.has_value() ? (bool)is_return_cross_attentions_opt.value() : false; + + auto stream = at::cuda::getCurrentCUDAStream().stream(); cublasHandle_t cublasHandle = at::cuda::getCurrentCUDABlasHandle(); cublasSetStream(cublasHandle, stream); - ft::Allocator allocator = ft::Allocator(); - ft::cublasMMWrapper cublas_wrapper = ft::cublasMMWrapper( + ft::Allocator allocator = ft::Allocator(); + ft::cublasMMWrapper cublas_wrapper = ft::cublasMMWrapper( cublasHandle, cublasltHandle_, stream, cublas_algo_map_, cublas_wrapper_mutex_, &allocator); if (std::is_same::value) { cublas_wrapper.setFP16GemmConfig(); } +#ifdef ENABLE_BF16 + else if (std::is_same::value) { + cublas_wrapper.setBF16GemmConfig(); + } +#endif else if (std::is_same::value) { cublas_wrapper.setFP32GemmConfig(); } - const size_t batch_size = (size_t)memory.size(0); + const size_t batch_size = (size_t)memory.size(0); const size_t mem_max_seq_len = (size_t)memory.size(1); - ft::T5Decoding decoding = - ft::T5Decoding(batch_size, - max_seq_len, - mem_max_seq_len, - beam_width, - head_num_, - size_per_head_, - inter_size_, - d_model_, - layer_num_, - vocab_size_, - num_bucket_, - max_distance_, - q_scaling_, - start_id_, - end_id_, - beam_search_diversity_rate, - top_k, - top_p, - temperature, - len_penalty, - repetition_penalty, - stream, - &cublas_wrapper, - &allocator, - false, - &prop_, - tensor_para_, - pipeline_para_, - t5_with_bias_ ? 
ft::ActivationType::Gelu : ft::ActivationType::Relu); - ft::DataType data_type = ft::getTensorType(); + ft::T5Decoding decoding = ft::T5Decoding(batch_size, + max_seq_len, + mem_max_seq_len, + beam_width, + head_num_, + size_per_head_, + inter_size_, + d_model_, + layer_num_, + vocab_size_, + num_bucket_, + max_distance_, + q_scaling_, + start_id_, + end_id_, + beam_search_diversity_rate, + top_k, + top_p, + temperature, + len_penalty, + repetition_penalty, + stream, + &cublas_wrapper, + &allocator, + false, + &prop_, + tensor_para_, + pipeline_para_, + activation_type_, + tie_word_embeddings_); + ft::DataType data_type = ft::getTensorType(); std::unordered_map input_tensors = std::unordered_map{ {"encoder_output", @@ -329,34 +278,40 @@ class FTT5Decoding: public IFTT5Decoding { std::vector{(size_t)memory_seq_lens.size(0)}, get_ptr(memory_seq_lens)}}}; - if (top_k == 0 && top_p == 0.0f) { - ft::FT_CHECK(beam_width > 1); + if (beam_width > 1) { input_tensors.insert( {"beam_search_diversity_rate", ft::Tensor{ft::MEMORY_CPU, ft::TYPE_FP32, std::vector{1}, &beam_search_diversity_rate}}); } - else { - if (top_p != 0.0f) { - input_tensors.insert( - {"runtime_top_p", ft::Tensor{ft::MEMORY_CPU, ft::TYPE_FP32, std::vector{1}, &top_p}}); - } - if (top_k != 0) { - input_tensors.insert( - {"runtime_top_k", ft::Tensor{ft::MEMORY_CPU, ft::TYPE_INT32, std::vector{1}, &top_k}}); - } + if (top_p_opt.has_value()) { + input_tensors.insert( + {"runtime_top_p", ft::Tensor{ft::MEMORY_CPU, ft::TYPE_FP32, std::vector{1}, &top_p}}); + } + if (top_k_opt.has_value()) { + input_tensors.insert( + {"runtime_top_k", ft::Tensor{ft::MEMORY_CPU, ft::TYPE_UINT32, std::vector{1}, &top_k}}); + } + if (temperature_opt.has_value()) { + input_tensors.insert( + {"temperature", ft::Tensor{ft::MEMORY_CPU, ft::TYPE_FP32, std::vector{1}, &temperature}}); + } + if (len_penalty_opt.has_value()) { + input_tensors.insert( + {"len_penalty", ft::Tensor{ft::MEMORY_CPU, ft::TYPE_FP32, std::vector{1}, &len_penalty}}); + } + if (repetition_penalty_opt.has_value()) { + input_tensors.insert( + {"repetition_penalty", + ft::Tensor{ft::MEMORY_CPU, ft::TYPE_FP32, std::vector{1}, &repetition_penalty}}); + } + if (random_seed_opt.has_value()) { + input_tensors.insert( + {"random_seed", ft::Tensor{ft::MEMORY_CPU, ft::TYPE_UINT64, std::vector{1}, &random_seed}}); } - input_tensors.insert( - {"temperature", ft::Tensor{ft::MEMORY_CPU, ft::TYPE_FP32, std::vector{1}, &temperature}}); - input_tensors.insert( - {"len_penalty", ft::Tensor{ft::MEMORY_CPU, ft::TYPE_FP32, std::vector{1}, &len_penalty}}); - input_tensors.insert({"repetition_penalty", - ft::Tensor{ft::MEMORY_CPU, ft::TYPE_FP32, std::vector{1}, &repetition_penalty}}); - input_tensors.insert( - {"random_seed", ft::Tensor{ft::MEMORY_CPU, ft::TYPE_UINT64, std::vector{1}, &random_seed}}); - auto output_ids = torch::empty({batch_size * beam_width * max_seq_len}, + auto output_ids = torch::empty({(long int)(batch_size * beam_width * max_seq_len)}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); - auto sequence_length = torch::empty({batch_size * beam_width}, + auto sequence_length = torch::empty({(long int)(batch_size * beam_width)}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); std::vector th_output_tensors = {output_ids, sequence_length}; @@ -373,7 +328,7 @@ class FTT5Decoding: public IFTT5Decoding { get_ptr(sequence_length)}}}; if (is_return_output_log_probs) { - auto output_log_probs = torch::empty({batch_size * beam_width * max_seq_len}, + auto 
output_log_probs = torch::empty({(long int)(batch_size * beam_width * max_seq_len)}, torch::dtype(torch::kFloat).device(torch::kCUDA).requires_grad(false)); output_tensors.insert({"output_log_probs", ft::Tensor{ft::MEMORY_GPU, @@ -383,39 +338,58 @@ class FTT5Decoding: public IFTT5Decoding { th_output_tensors.push_back(output_log_probs); } if (is_return_cum_log_probs) { - auto cum_log_probs = torch::empty({batch_size * beam_width}, + auto cum_log_probs = torch::empty({(long int)(batch_size * beam_width)}, torch::dtype(torch::kFloat).device(torch::kCUDA).requires_grad(false)); output_tensors.insert( {"cum_log_probs", ft::Tensor{ft::MEMORY_GPU, ft::TYPE_FP32, {batch_size, beam_width}, get_ptr(cum_log_probs)}}); th_output_tensors.push_back(cum_log_probs); } + if (is_return_cross_attentions) { + auto cross_attentions = + torch::empty({(long int)(ceil(layer_num_ * 1.0 / pipeline_para_.world_size_) * batch_size * beam_width + * (head_num_ / tensor_para_.world_size_) * max_seq_len * mem_max_seq_len)}, + torch::dtype(torch::kFloat).device(torch::kCUDA).requires_grad(false)); + output_tensors.insert({"cross_attentions", + ft::Tensor{ft::MEMORY_GPU, + ft::TYPE_FP32, + {(size_t)(layer_num_ / pipeline_para_.world_size_), + (size_t)batch_size, + (size_t)beam_width, + (size_t)(head_num_ / tensor_para_.world_size_), + (size_t)max_seq_len, + (size_t)mem_max_seq_len}, + get_ptr(cross_attentions)}}); + th_output_tensors.push_back(cross_attentions); + } decoding.forward(&output_tensors, &input_tensors, &decoding_weights); return th_output_tensors; } private: - const int head_num_; - const int size_per_head_; - const int inter_size_; - const int mem_d_model_; - const int d_model_; - const int layer_num_; - const int vocab_size_; - const int num_bucket_; - const int max_distance_; - double q_scaling_; - const int start_id_; - const int end_id_; - const bool t5_with_bias_; + const int head_num_; + const int size_per_head_; + const int inter_size_; + const int mem_d_model_; + const int d_model_; + const int layer_num_; + const int vocab_size_; + const int num_bucket_; + const int max_distance_; + double q_scaling_; + const int start_id_; + const int end_id_; + const bool t5_with_bias_; const ft::PositionEmbeddingType position_embedding_type_; + const ft::ActivationType activation_type_; + const bool tie_word_embeddings_; std::vector _weights; - cublasLtHandle_t cublasltHandle_; - std::mutex* cublas_wrapper_mutex_; - ft::cublasAlgoMap* cublas_algo_map_; - struct cudaDeviceProp prop_; + cublasLtHandle_t cublasltHandle_; + std::mutex* cublas_wrapper_mutex_; + ft::cublasAlgoMap* cublas_algo_map_; + struct cudaDeviceProp prop_; ft::T5DecodingWeight decoding_weights; ft::NcclParam tensor_para_; @@ -424,73 +398,78 @@ class FTT5Decoding: public IFTT5Decoding { class FasterTransformerT5Decoding: public torch::jit::CustomClassHolder { public: - FasterTransformerT5Decoding(int64_t head_num, - int64_t size_per_head, - int64_t inter_size, - int64_t mem_d_model, - int64_t d_model, - int64_t layer_num, - int64_t vocab_size, - int64_t num_bucket, - int64_t max_distance, - double q_scaling, - int64_t start_id, - int64_t end_id, - int64_t tensor_para_size, - int64_t pipeline_para_size, - bool t5_with_bias, - int64_t position_embedding_type, - th::Tensor self_layernorm_gamma, - th::Tensor self_kernel_qkv, - th::Tensor self_output_kernel, - th::Tensor cross_layernorm_gamma, - th::Tensor cross_kernel_q, - th::Tensor cross_kernel_k, - th::Tensor cross_kernel_v, - th::Tensor cross_output_kernel, - th::Tensor ffn_layernorm_gamma, - 
th::Tensor inter_kernel, - th::Tensor output_kernel, - th::Tensor decoding_gamma, - th::Tensor embedding_table, - th::Tensor absolute_or_relative_position_embedding, - th::Tensor self_layernorm_beta, - th::Tensor self_bias_qkv, - th::Tensor self_output_bias, - th::Tensor cross_layernorm_beta, - th::Tensor cross_bias_q, - th::Tensor cross_bias_k, - th::Tensor cross_bias_v, - th::Tensor cross_output_bias, - th::Tensor ffn_layernorm_beta, - th::Tensor inter_bias, - th::Tensor output_bias, - th::Tensor decoding_beta, - th::Tensor embedding_bias); + FasterTransformerT5Decoding(int64_t head_num, + int64_t size_per_head, + int64_t inter_size, + int64_t mem_d_model, + int64_t d_model, + int64_t layer_num, + int64_t vocab_size, + int64_t num_bucket, + int64_t max_distance, + double q_scaling, + int64_t start_id, + int64_t end_id, + int64_t tensor_para_size, + int64_t pipeline_para_size, + bool t5_with_bias, + int64_t position_embedding_type, + std::string activaiton_type, + bool tie_word_embeddings, + th::Tensor self_layernorm_gamma, + th::Tensor self_kernel_qkv, + th::Tensor self_output_kernel, + th::Tensor cross_layernorm_gamma, + th::Tensor cross_kernel_q, + th::Tensor cross_kernel_k, + th::Tensor cross_kernel_v, + th::Tensor cross_output_kernel, + th::Tensor ffn_layernorm_gamma, + th::Tensor inter_kernel, + th::Tensor inter_kernel2, + th::Tensor output_kernel, + th::Tensor decoding_gamma, + th::Tensor embedding_table, + th::Tensor lm_head, + th::Tensor absolute_or_relative_position_embedding, + th::Tensor self_layernorm_beta, + th::Tensor self_bias_qkv, + th::Tensor self_output_bias, + th::Tensor cross_layernorm_beta, + th::Tensor cross_bias_q, + th::Tensor cross_bias_k, + th::Tensor cross_bias_v, + th::Tensor cross_output_bias, + th::Tensor ffn_layernorm_beta, + th::Tensor inter_bias, + th::Tensor inter_bias2, + th::Tensor output_bias, + th::Tensor decoding_beta, + th::Tensor embedding_bias); ~FasterTransformerT5Decoding(); - std::vector forward(int64_t beam_width, - int64_t max_seq_len, - int64_t top_k, - double top_p, - double beam_search_diversity_rate, - double temperature, - double len_penalty, - double repetition_penalty, - int64_t random_seed, - bool is_return_output_log_probs, - bool is_return_cum_log_probs, - th::Tensor memory, - th::Tensor memory_seq_lens); + std::vector forward(th::optional beam_width, + int64_t max_seq_len, + th::optional top_k, + th::optional top_p, + th::optional beam_search_diversity_rate, + th::optional temperature, + th::optional len_penalty, + th::optional repetition_penalty, + th::optional random_seed, + th::optional is_return_output_log_probs, + th::optional is_return_cum_log_probs, + th::optional is_return_cross_attentions, + th::Tensor memory, + th::Tensor memory_seq_lens); std::vector get_pickle_info() const; private: - const at::ScalarType _st; + const at::ScalarType _st; torch_ext::IFTT5Decoding* ftdecoding; - th::Tensor int_info_; - std::vector weights; + std::vector weights; }; -} // namespace torch_ext \ No newline at end of file +} // namespace torch_ext diff --git a/src/fastertransformer/th_op/t5/T5EncoderOp.cc b/src/fastertransformer/th_op/t5/T5EncoderOp.cc index d5015878b..d3a882946 100644 --- a/src/fastertransformer/th_op/t5/T5EncoderOp.cc +++ b/src/fastertransformer/th_op/t5/T5EncoderOp.cc @@ -19,40 +19,44 @@ namespace th = torch; namespace torch_ext { -FasterTransformerT5Encoder::FasterTransformerT5Encoder(th::Tensor attr_output_layernorm_gamma, - th::Tensor q_kernel, - th::Tensor k_kernel, - th::Tensor v_kernel, - th::Tensor 
attr_output_kernel, - th::Tensor output_layernorm_gamma, - th::Tensor inter_kernel, - th::Tensor output_kernel, - th::Tensor post_transformer_layernorm_gamma, - th::Tensor absolute_or_relative_position_embedding, - th::Tensor embedding_table, - th::Tensor attr_output_layernorm_beta, - th::Tensor q_bias, - th::Tensor k_bias, - th::Tensor v_bias, - th::Tensor attr_output_bias, - th::Tensor output_layernorm_beta, - th::Tensor inter_bias, - th::Tensor output_bias, - th::Tensor post_transformer_layernorm_beta, - int64_t head_num, - int64_t head_size, - int64_t inter_size, - int64_t d_model, - bool remove_padding, - int64_t layer_num, - int64_t num_bucket, - int64_t max_distance, - bool sparse, - double q_scaling, - int64_t tensor_para_size, - int64_t pipeline_para_size, - bool t5_with_bias, - int64_t position_embedding_type): +FasterTransformerT5Encoder::FasterTransformerT5Encoder(th::Tensor attr_output_layernorm_gamma, + th::Tensor q_kernel, + th::Tensor k_kernel, + th::Tensor v_kernel, + th::Tensor attr_output_kernel, + th::Tensor output_layernorm_gamma, + th::Tensor inter_kernel, + th::Tensor inter_kernel2, + th::Tensor output_kernel, + th::Tensor post_transformer_layernorm_gamma, + th::Tensor absolute_or_relative_position_embedding, + th::Tensor embedding_table, + th::Tensor attr_output_layernorm_beta, + th::Tensor q_bias, + th::Tensor k_bias, + th::Tensor v_bias, + th::Tensor attr_output_bias, + th::Tensor output_layernorm_beta, + th::Tensor inter_bias, + th::Tensor inter_bias2, + th::Tensor output_bias, + th::Tensor post_transformer_layernorm_beta, + int64_t head_num, + int64_t head_size, + int64_t inter_size, + int64_t d_model, + bool remove_padding, + int64_t layer_num, + int64_t num_bucket, + int64_t max_distance, + bool sparse, + double q_scaling, + int64_t tensor_para_size, + int64_t pipeline_para_size, + bool t5_with_bias, + int64_t position_embedding_type, + std::string activation_type): + d_model_(d_model), _st(q_kernel.scalar_type()), _remove_padding(remove_padding), weights{attr_output_layernorm_gamma, @@ -62,6 +66,7 @@ FasterTransformerT5Encoder::FasterTransformerT5Encoder(th::Tensor attr_output_la attr_output_kernel, output_layernorm_gamma, inter_kernel, + inter_kernel2, output_kernel, post_transformer_layernorm_gamma, absolute_or_relative_position_embedding, @@ -73,6 +78,7 @@ FasterTransformerT5Encoder::FasterTransformerT5Encoder(th::Tensor attr_output_la attr_output_bias, output_layernorm_beta, inter_bias, + inter_bias2, output_bias, post_transformer_layernorm_beta} { @@ -82,6 +88,7 @@ FasterTransformerT5Encoder::FasterTransformerT5Encoder(th::Tensor attr_output_la CHECK_INPUT(attr_output_kernel, _st); // hidden_dim, d_model CHECK_INPUT(attr_output_layernorm_gamma, _st); // d_model CHECK_INPUT(inter_kernel, _st); // d_model, inter_size + CHECK_INPUT(inter_kernel2, _st); // d_model, inter_size CHECK_INPUT(output_kernel, _st); // inter_size, d_model CHECK_INPUT(output_layernorm_gamma, _st); // d_model CHECK_INPUT(post_transformer_layernorm_gamma, _st); // d_model @@ -94,6 +101,7 @@ FasterTransformerT5Encoder::FasterTransformerT5Encoder(th::Tensor attr_output_la CHECK_INPUT(attr_output_bias, _st); // d_model CHECK_INPUT(attr_output_layernorm_beta, _st); // d_model CHECK_INPUT(inter_bias, _st); // inter_size + CHECK_INPUT(inter_bias2, _st); // inter_size CHECK_INPUT(output_bias, _st); // d_model CHECK_INPUT(output_layernorm_beta, _st); // d_model CHECK_INPUT(post_transformer_layernorm_beta, _st); // d_model @@ -114,6 +122,7 @@ 
FasterTransformerT5Encoder::FasterTransformerT5Encoder(th::Tensor attr_output_la pipeline_para_size, t5_with_bias, ft::PositionEmbeddingType(position_embedding_type), + ft::getActivationType(activation_type), weights); break; case at::ScalarType::Half: @@ -130,27 +139,31 @@ FasterTransformerT5Encoder::FasterTransformerT5Encoder(th::Tensor attr_output_la pipeline_para_size, t5_with_bias, ft::PositionEmbeddingType(position_embedding_type), + ft::getActivationType(activation_type), weights); break; +#ifdef ENABLE_BF16 + case at::ScalarType::BFloat16: + ft_t5_encoder = new FTT5Encoder<__nv_bfloat16>(head_num, + head_size, + inter_size, + d_model, + layer_num, + num_bucket, + max_distance, + sparse, + q_scaling, + tensor_para_size, + pipeline_para_size, + t5_with_bias, + ft::PositionEmbeddingType(position_embedding_type), + ft::getActivationType(activation_type), + weights); + break; +#endif default: throw std::runtime_error("Wrong Tensor type."); } - head_info = torch::empty({13}, torch::dtype(torch::kInt64)); - head_info[0] = head_num; - head_info[1] = head_size; - head_info[2] = (int64_t)remove_padding; - head_info[3] = layer_num; - head_info[4] = num_bucket; - head_info[5] = max_distance; - head_info[6] = (int64_t)sparse; - head_info[7] = inter_size; - head_info[8] = d_model; - head_info[9] = tensor_para_size; - head_info[10] = pipeline_para_size; - head_info[11] = t5_with_bias; - head_info[12] = position_embedding_type; - scaling_info = torch::empty({1}, torch::dtype(torch::kFloat64)); - scaling_info[0] = (double)q_scaling; } FasterTransformerT5Encoder::~FasterTransformerT5Encoder() @@ -158,29 +171,41 @@ FasterTransformerT5Encoder::~FasterTransformerT5Encoder() delete ft_t5_encoder; } -th::Tensor FasterTransformerT5Encoder::forward(th::Tensor input_ids, th::Tensor sequence_lengths) +th::Tensor FasterTransformerT5Encoder::forward(th::optional input_ids, + th::Tensor sequence_lengths, + th::optional inputs_embeds) { - CHECK_CONTIGUOUS(input_ids); - TORCH_CHECK(input_ids.dtype() == torch::kInt32, "input_ids dtype should be int32"); + if (input_ids.has_value()) { + CHECK_CONTIGUOUS(input_ids.value()); + TORCH_CHECK(input_ids.value().dtype() == torch::kInt32, "input_ids dtype should be int32"); + } CHECK_CONTIGUOUS(sequence_lengths); TORCH_CHECK(sequence_lengths.dtype() == torch::kInt32, "sequence_lengths dtype should be int32"); - size_t batch_size = (size_t)input_ids.size(0); - size_t seq_len = (size_t)input_ids.size(1); - int64_t d_model = head_info[8].item().to(); + if (inputs_embeds.has_value()) { + CHECK_CONTIGUOUS(inputs_embeds.value()); + TORCH_CHECK(inputs_embeds.value().dtype() == torch::kFloat32 + || inputs_embeds.value().dtype() == torch::kFloat16, + "inputs_embeds dtype should be float32 or float16"); + } + + TORCH_CHECK(input_ids.has_value() || inputs_embeds.has_value(), + "input_ids and inputs_embeds should not be empty at the same time."); + + size_t batch_size = inputs_embeds.has_value() ? inputs_embeds.value().size(0) : input_ids.value().size(0); + size_t seq_len = inputs_embeds.has_value() ? 
inputs_embeds.value().size(1) : input_ids.value().size(1); + int64_t d_model = d_model_; auto output = torch::empty({(long int)batch_size, (long int)seq_len, (long int)d_model}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); - ft_t5_encoder->forward(batch_size, seq_len, input_ids, sequence_lengths, output, _remove_padding); + ft_t5_encoder->forward(batch_size, seq_len, input_ids, sequence_lengths, inputs_embeds, output, _remove_padding); return output; } std::vector FasterTransformerT5Encoder::get_pickle_info() const { std::vector tmp(weights); - tmp.push_back(head_info); - tmp.push_back(scaling_info); return tmp; } @@ -212,6 +237,8 @@ static auto fasterTransformerT5EncoderTHS = th::Tensor, th::Tensor, th::Tensor, + th::Tensor, + th::Tensor, int64_t, int64_t, int64_t, @@ -225,59 +252,6 @@ static auto fasterTransformerT5EncoderTHS = int64_t, int64_t, bool, - int64_t>()) - .def("forward", &torch_ext::FasterTransformerT5Encoder::forward) - .def_pickle( - [](const c10::intrusive_ptr& self) -> std::vector { - return self->get_pickle_info(); - }, - [](std::vector state) -> c10::intrusive_ptr { - int64_t head_num = state[20][0].item().to(); - int64_t head_size = state[20][1].item().to(); - bool remove_padding = (bool)(state[20][2].item().to()); - int64_t layer_num = state[20][3].item().to(); - int64_t num_bucket = state[20][4].item().to(); - int64_t max_distance = state[20][5].item().to(); - bool sparse = (bool)(state[20][6].item().to()); - int64_t inter_size = state[20][7].item().to(); - int64_t d_model = state[20][8].item().to(); - int64_t tensor_para_size = state[20][9].item().to(); - int64_t pipeline_para_size = state[20][10].item().to(); - bool t5_with_bias = (bool)(state[20][11].item().to()); - int64_t position_embedding_type = state[20][12].item().to(); - double q_scaling = state[21][0].item().to(); - return c10::make_intrusive(state[0], - state[1], - state[2], - state[3], - state[4], - state[5], - state[6], - state[7], - state[8], - state[9], - state[10], - state[11], - state[12], - state[13], - state[14], - state[15], - state[16], - state[17], - state[18], - state[19], - head_num, - head_size, - inter_size, - d_model, - remove_padding, - layer_num, - num_bucket, - max_distance, - sparse, - q_scaling, - tensor_para_size, - pipeline_para_size, - t5_with_bias, - position_embedding_type); - }); \ No newline at end of file + int64_t, + std::string>()) + .def("forward", &torch_ext::FasterTransformerT5Encoder::forward); diff --git a/src/fastertransformer/th_op/t5/T5EncoderOp.h b/src/fastertransformer/th_op/t5/T5EncoderOp.h index 7de9e2027..2cebba63a 100644 --- a/src/fastertransformer/th_op/t5/T5EncoderOp.h +++ b/src/fastertransformer/th_op/t5/T5EncoderOp.h @@ -16,7 +16,7 @@ #include "src/fastertransformer/models/t5/T5Encoder.h" #include "src/fastertransformer/th_op/th_utils.h" -#include "src/fastertransformer/utils/mpi_utils.h" +#include "src/fastertransformer/utils/nccl_utils.h" namespace ft = fastertransformer; namespace th = torch; @@ -25,30 +25,32 @@ namespace torch_ext { class IFT5Encoder { public: virtual ~IFT5Encoder() {} - virtual void forward(size_t batch_size, - size_t seq_len, - th::Tensor& input, - th::Tensor& sequence_lengths, - th::Tensor& output, - bool removing_padding) = 0; + virtual void forward(size_t batch_size, + size_t seq_len, + th::optional input, + th::Tensor& sequence_lengths, + th::optional inputs_embeds, + th::Tensor& output, + bool removing_padding) = 0; }; template class FTT5Encoder: public IFT5Encoder { public: - FTT5Encoder(int head_num, - int 
head_size, - int inter_size, - int d_model, - int layer_num, - int num_bucket, - int max_distance, - bool sparse, - float q_scaling, - int tensor_para_size, - int pipeline_para_size, - bool t5_with_bias, - ft::PositionEmbeddingType position_embedding_type, + FTT5Encoder(int head_num, + int head_size, + int inter_size, + int d_model, + int layer_num, + int num_bucket, + int max_distance, + bool sparse, + float q_scaling, + int tensor_para_size, + int pipeline_para_size, + bool t5_with_bias, + ft::PositionEmbeddingType position_embedding_type, + ft::ActivationType activation_type, const std::vector& w): _head_num(head_num), _head_size(head_size), @@ -61,11 +63,14 @@ class FTT5Encoder: public IFT5Encoder { _sparse(sparse), _q_scaling(q_scaling), _t5_with_bias(t5_with_bias), - _position_embedding_type(position_embedding_type) + _position_embedding_type(position_embedding_type), + _activation_type(activation_type) { - tensor_para_.world_size_ = tensor_para_size; - pipeline_para_.world_size_ = pipeline_para_size; - init_nccl_comm(); + bool use_gated_activation = isGatedActivation(_activation_type); + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + + ft::ftNcclInitialize(tensor_para_, pipeline_para_, tensor_para_size, pipeline_para_size); + #ifndef SPARSITY_ENABLED if (sparse) { std::cout << "[WARNING] Sparsity support is not enabled. Will use dense GEMM instead.\n" << std::flush; @@ -80,11 +85,11 @@ class FTT5Encoder: public IFT5Encoder { } #endif std::string sp_config_fname = sparse ? "spgemm_config.in" : ""; - cublas_algo_map_ = new ft::cublasAlgoMap("gemm_config.in", sp_config_fname); - cublas_wrapper_mutex_ = new std::mutex(); + cublas_algo_map_ = new ft::cublasAlgoMap("gemm_config.in", sp_config_fname); + cublas_wrapper_mutex_ = new std::mutex(); t5_encoder_weights.resizeLayer(_layer_num); - t5_encoder_weights.setT5StructureDiff(t5_with_bias, position_embedding_type); + t5_encoder_weights.setT5StructureDiff(t5_with_bias, use_gated_activation, position_embedding_type); for (int i = 0; i < _layer_num; i++) { int local_num_layer = (int)(ceil(_layer_num * 1.0f / pipeline_para_.world_size_)); if (!(i < _layer_num && (i >= local_num_layer * pipeline_para_.rank_) @@ -107,37 +112,46 @@ class FTT5Encoder: public IFT5Encoder { get_ptr(_weights[5]) + _d_model * (i - first_layer_index); t5_encoder_weights.t5_encoder_layer_weights[i]->ffn_weights.intermediate_weight.kernel = get_ptr(_weights[6]) + _d_model * _inter_size / tensor_para_.world_size_ * (i - first_layer_index); + if (use_gated_activation) { + t5_encoder_weights.t5_encoder_layer_weights[i]->ffn_weights.intermediate_weight2.kernel = + get_ptr(_weights[7]) + + _d_model * _inter_size / tensor_para_.world_size_ * (i - first_layer_index); + } t5_encoder_weights.t5_encoder_layer_weights[i]->ffn_weights.output_weight.kernel = - get_ptr(_weights[7]) + _inter_size / tensor_para_.world_size_ * _d_model * (i - first_layer_index); + get_ptr(_weights[8]) + _inter_size / tensor_para_.world_size_ * _d_model * (i - first_layer_index); if (_t5_with_bias) { t5_encoder_weights.t5_encoder_layer_weights[i]->attn_layernorm_weights.beta = - get_ptr(_weights[11]) + _d_model * (i - first_layer_index); + get_ptr(_weights[12]) + _d_model * (i - first_layer_index); t5_encoder_weights.t5_encoder_layer_weights[i]->attention_weights.query_weight.bias = - get_ptr(_weights[12]) + hidden_dim / tensor_para_.world_size_ * (i - first_layer_index); - t5_encoder_weights.t5_encoder_layer_weights[i]->attention_weights.key_weight.bias = get_ptr(_weights[13]) + hidden_dim / 
tensor_para_.world_size_ * (i - first_layer_index); - t5_encoder_weights.t5_encoder_layer_weights[i]->attention_weights.value_weight.bias = + t5_encoder_weights.t5_encoder_layer_weights[i]->attention_weights.key_weight.bias = get_ptr(_weights[14]) + hidden_dim / tensor_para_.world_size_ * (i - first_layer_index); + t5_encoder_weights.t5_encoder_layer_weights[i]->attention_weights.value_weight.bias = + get_ptr(_weights[15]) + hidden_dim / tensor_para_.world_size_ * (i - first_layer_index); t5_encoder_weights.t5_encoder_layer_weights[i]->attention_weights.attention_output_weight.bias = - get_ptr(_weights[15]) + _d_model * (i - first_layer_index); - t5_encoder_weights.t5_encoder_layer_weights[i]->ffn_layernorm_weights.beta = get_ptr(_weights[16]) + _d_model * (i - first_layer_index); + t5_encoder_weights.t5_encoder_layer_weights[i]->ffn_layernorm_weights.beta = + get_ptr(_weights[17]) + _d_model * (i - first_layer_index); t5_encoder_weights.t5_encoder_layer_weights[i]->ffn_weights.intermediate_weight.bias = - get_ptr(_weights[17]) + _inter_size / tensor_para_.world_size_ * (i - first_layer_index); + get_ptr(_weights[18]) + _inter_size / tensor_para_.world_size_ * (i - first_layer_index); + if (use_gated_activation) { + t5_encoder_weights.t5_encoder_layer_weights[i]->ffn_weights.intermediate_weight2.bias = + get_ptr(_weights[19]) + _inter_size / tensor_para_.world_size_ * (i - first_layer_index); + } t5_encoder_weights.t5_encoder_layer_weights[i]->ffn_weights.output_weight.bias = - get_ptr(_weights[18]) + _d_model * (i - first_layer_index); + get_ptr(_weights[20]) + _d_model * (i - first_layer_index); } } - t5_encoder_weights.post_transformer_layernorm_weights.gamma = get_ptr(_weights[8]); - t5_encoder_weights.absolute_or_relative_position_embedding = get_ptr(_weights[9]); - t5_encoder_weights.embedding_table = get_ptr(_weights[10]); + t5_encoder_weights.post_transformer_layernorm_weights.gamma = get_ptr(_weights[9]); + t5_encoder_weights.absolute_or_relative_position_embedding = get_ptr(_weights[10]); + t5_encoder_weights.embedding_table = get_ptr(_weights[11]); if (_t5_with_bias) { - t5_encoder_weights.post_transformer_layernorm_weights.beta = get_ptr(_weights[19]); + t5_encoder_weights.post_transformer_layernorm_weights.beta = get_ptr(_weights[21]); } #ifdef SPARSITY_ENABLED if (sparse) { - auto stream = at::cuda::getCurrentCUDAStream().stream(); + auto stream = at::cuda::getCurrentCUDAStream().stream(); cublasHandle_t _cublasHandle = at::cuda::getCurrentCUDABlasHandle(); cublasSetStream(_cublasHandle, stream); ft::cublasMMWrapper cublas_wrapper = ft::cublasMMWrapper(_cublasHandle, @@ -156,6 +170,9 @@ class FTT5Encoder: public IFT5Encoder { ~FTT5Encoder() override { + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + ft::ftNcclParamDestroy(tensor_para_); + ft::ftNcclParamDestroy(pipeline_para_); cublasLtDestroy(_cublasltHandle); #ifdef SPARSITY_ENABLED if (_sparse) { @@ -166,107 +183,20 @@ class FTT5Encoder: public IFT5Encoder { delete cublas_wrapper_mutex_; } - void init_nccl_comm() + void forward(size_t batch_size, + size_t seq_len, + th::optional input_ids, + th::Tensor& sequence_lengths, + th::optional inputs_embeds, + th::Tensor& output, + bool removing_padding) override { - int mpi_initialized; - MPICHECK(MPI_Initialized(&mpi_initialized)); - if (!mpi_initialized) { - printf("[INFO] MPI is not initialized! 
Skipped the NCCL communication initialization.\n"); - if (tensor_para_.world_size_ != 1) { - printf("[FATAL] MPI initialization can only be skipped when tensor_para_size=1, but got %d!\n", - tensor_para_.world_size_); - } - if (pipeline_para_.world_size_ != 1) { - printf("[FATAL] MPI initialization can only be skipped when pipeline_para_size=1, but got %d!\n", - pipeline_para_.world_size_); - } - return; - } - - int rank, world_size; - MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &rank)); - MPICHECK(MPI_Comm_size(MPI_COMM_WORLD, &world_size)); - tensor_para_.rank_ = rank % tensor_para_.world_size_; - pipeline_para_.rank_ = rank / tensor_para_.world_size_; - - ncclUniqueId tensor_para_nccl_uid; - ncclUniqueId pipeline_para_nccl_uid; - - // assume gpu_num = n * k, - // tensor parallelism group size is n - // pipeline parallelism group size is k - if (tensor_para_.rank_ == 0) { - // get the uid of each tensor parallelism group - // here, 0, 1, ..., n-1 are in group 0, - // n, ..., 2n - 1 are in group 1. - NCCLCHECK(ncclGetUniqueId(&tensor_para_nccl_uid)); - for (int i = 1; i < (int)tensor_para_.world_size_; i++) { - printf("[INFO] rank %d sends tensor_para_nccl_uid to rank %d \n", rank, rank + i); - MPICHECK(MPI_Send( - &tensor_para_nccl_uid, sizeof(tensor_para_nccl_uid), MPI_BYTE, rank + i, 0, MPI_COMM_WORLD)); - } - } - else { - MPI_Status status; - printf( - "[INFO] rank %d receives tensor_para_nccl_uid from rank %d \n", rank, rank - (int)tensor_para_.rank_); - MPICHECK(MPI_Recv(&tensor_para_nccl_uid, - sizeof(tensor_para_nccl_uid), - MPI_BYTE, - rank - tensor_para_.rank_, - 0, - MPI_COMM_WORLD, - &status)); - } - - if (pipeline_para_.rank_ == 0) { - // get the uid of each pipeline parallelism group - // 0, k, 2k, are in group 0 - // 1, k+1, 2k+1 are in group 1 - NCCLCHECK(ncclGetUniqueId(&pipeline_para_nccl_uid)); - for (int i = 1; i < (int)pipeline_para_.world_size_; i++) { - printf("[INFO] rank %d sends pipeline_para_nccl_uid to rank %d \n", - rank, - rank + i * (int)tensor_para_.world_size_); - MPICHECK(MPI_Send(&pipeline_para_nccl_uid, - sizeof(pipeline_para_nccl_uid), - MPI_BYTE, - rank + i * tensor_para_.world_size_, - 0, - MPI_COMM_WORLD)); - } - } - else { - MPI_Status status; - printf("[INFO] rank %d receives pipeline_para_nccl_uid from rank %d \n", - rank, - rank % (int)tensor_para_.world_size_); - MPICHECK(MPI_Recv(&pipeline_para_nccl_uid, - sizeof(pipeline_para_nccl_uid), - MPI_BYTE, - rank % tensor_para_.world_size_, - 0, - MPI_COMM_WORLD, - &status)); - } - NCCLCHECK(ncclCommInitRank( - &tensor_para_.nccl_comm_, tensor_para_.world_size_, tensor_para_nccl_uid, tensor_para_.rank_)); - NCCLCHECK(ncclCommInitRank( - &pipeline_para_.nccl_comm_, pipeline_para_.world_size_, pipeline_para_nccl_uid, pipeline_para_.rank_)); - } - - void forward(size_t batch_size, - size_t seq_len, - th::Tensor& input_ids, - th::Tensor& sequence_lengths, - th::Tensor& output, - bool removing_padding) override - { - auto stream = at::cuda::getCurrentCUDAStream().stream(); + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + auto stream = at::cuda::getCurrentCUDAStream().stream(); cublasHandle_t _cublasHandle = at::cuda::getCurrentCUDABlasHandle(); cublasSetStream(_cublasHandle, stream); ft::Allocator* allocator = new ft::Allocator(); - ft::cublasMMWrapper* cublas_wrapper = + ft::cublasMMWrapper* cublas_wrapper = #ifdef SPARSITY_ENABLED new ft::cublasMMWrapper(_cublasHandle, _cublasltHandle, @@ -283,37 +213,55 @@ class FTT5Encoder: public IFT5Encoder { if (std::is_same::value) { cublas_wrapper->setFP16GemmConfig(); } 
+#ifdef ENABLE_BF16 + else if (std::is_same::value) { + cublas_wrapper->setBF16GemmConfig(); + } +#endif else if (std::is_same::value) { cublas_wrapper->setFP32GemmConfig(); } ft::AttentionType attention_type = ft::getAttentionType(_head_size, sm_, removing_padding, seq_len, false); - ft::T5Encoder* t5_encoder = - new ft::T5Encoder(batch_size, - seq_len, - _head_num, - _head_size, - _inter_size, - _d_model, - _layer_num, - _num_bucket, - _max_distance, - sm_, - _q_scaling, - stream, - cublas_wrapper, - allocator, - false, - attention_type, - _sparse, - _t5_with_bias ? ft::ActivationType::Gelu : ft::ActivationType::Relu, - ft::LayerNormType::pre_layernorm, - tensor_para_, - pipeline_para_); + ft::T5Encoder* t5_encoder = new ft::T5Encoder(batch_size, + seq_len, + _head_num, + _head_size, + _inter_size, + _d_model, + _layer_num, + _num_bucket, + _max_distance, + sm_, + _q_scaling, + stream, + cublas_wrapper, + allocator, + false, + attention_type, + _sparse, + _activation_type, + ft::LayerNormType::pre_layernorm, + tensor_para_, + pipeline_para_); - std::unordered_map input_tensors = std::unordered_map{ - {"input_ids", convert_tensor(input_ids)}, {"sequence_length", convert_tensor(sequence_lengths)}}; + std::unordered_map input_tensors = + std::unordered_map{{"sequence_length", convert_tensor(sequence_lengths)}}; + + if (inputs_embeds.has_value()) { + if (std::is_same::value) { + TORCH_CHECK(inputs_embeds.value().dtype() == torch::kFloat32, "inputs_embeds dtype should be float32"); + } + else if (std::is_same::value) { + TORCH_CHECK(inputs_embeds.value().dtype() == torch::kFloat16, "inputs_embeds dtype should be float16"); + } + input_tensors.insert({"inputs_embeds", convert_tensor(inputs_embeds.value())}); + } + else { + // already check that input_ids and input_embeds cannot be empty at the same time + input_tensors.insert({"input_ids", convert_tensor(input_ids.value())}); + } std::unordered_map output_tensors = std::unordered_map{{"output_hidden_state", convert_tensor(output)}}; @@ -335,25 +283,26 @@ class FTT5Encoder: public IFT5Encoder { } private: - const int _head_num; - const int _head_size; - const int _inter_size; - const int _d_model; - const int _layer_num; - const int _num_bucket; - const int _max_distance; - std::vector _weights; - bool _t5_with_bias; + const int _head_num; + const int _head_size; + const int _inter_size; + const int _d_model; + const int _layer_num; + const int _num_bucket; + const int _max_distance; + std::vector _weights; + bool _t5_with_bias; ft::PositionEmbeddingType _position_embedding_type; - bool _sparse; - const float _q_scaling; - int sm_; - cublasLtHandle_t _cublasltHandle; + ft::ActivationType _activation_type; + bool _sparse; + const float _q_scaling; + int sm_; + cublasLtHandle_t _cublasltHandle; #ifdef SPARSITY_ENABLED cusparseLtHandle_t _cusparseLtHandle; #endif - std::mutex* cublas_wrapper_mutex_; - ft::cublasAlgoMap* cublas_algo_map_; + std::mutex* cublas_wrapper_mutex_; + ft::cublasAlgoMap* cublas_algo_map_; ft::T5EncoderWeight t5_encoder_weights; ft::NcclParam tensor_para_; @@ -362,53 +311,56 @@ class FTT5Encoder: public IFT5Encoder { class FasterTransformerT5Encoder: public th::jit::CustomClassHolder { public: - FasterTransformerT5Encoder(th::Tensor attr_output_layernorm_gamma, - th::Tensor q_kernel, - th::Tensor k_kernel, - th::Tensor v_kernel, - th::Tensor attr_output_kernel, - th::Tensor output_layernorm_gamma, - th::Tensor inter_kernel, - th::Tensor output_kernel, - th::Tensor post_transformer_layernorm_gamma, - th::Tensor 
absolute_or_relative_position_embedding, - th::Tensor embedding_table, - th::Tensor attr_output_layernorm_beta, - th::Tensor q_bias, - th::Tensor k_bias, - th::Tensor v_bias, - th::Tensor attr_output_bias, - th::Tensor output_layernorm_beta, - th::Tensor inter_bias, - th::Tensor output_bias, - th::Tensor post_transformer_layernorm_beta, - int64_t head_num, - int64_t head_size, - int64_t inter_size, - int64_t d_model, - bool remove_padding, - int64_t layer_num, - int64_t num_bucket, - int64_t max_distance, - bool sparse, - double q_scaling, - int64_t tensor_para_size, - int64_t pipeline_para_size, - bool t5_with_bias, - int64_t position_embedding_type); + FasterTransformerT5Encoder(th::Tensor attr_output_layernorm_gamma, + th::Tensor q_kernel, + th::Tensor k_kernel, + th::Tensor v_kernel, + th::Tensor attr_output_kernel, + th::Tensor output_layernorm_gamma, + th::Tensor inter_kernel, + th::Tensor inter_kernel2, + th::Tensor output_kernel, + th::Tensor post_transformer_layernorm_gamma, + th::Tensor absolute_or_relative_position_embedding, + th::Tensor embedding_table, + th::Tensor attr_output_layernorm_beta, + th::Tensor q_bias, + th::Tensor k_bias, + th::Tensor v_bias, + th::Tensor attr_output_bias, + th::Tensor output_layernorm_beta, + th::Tensor inter_bias, + th::Tensor inter_bias2, + th::Tensor output_bias, + th::Tensor post_transformer_layernorm_beta, + int64_t head_num, + int64_t head_size, + int64_t inter_size, + int64_t d_model, + bool remove_padding, + int64_t layer_num, + int64_t num_bucket, + int64_t max_distance, + bool sparse, + double q_scaling, + int64_t tensor_para_size, + int64_t pipeline_para_size, + bool t5_with_bias, + int64_t position_embedding_type, + std::string activation_type); ~FasterTransformerT5Encoder(); - th::Tensor forward(th::Tensor input, th::Tensor sequence_lengths); + th::Tensor + forward(th::optional input, th::Tensor sequence_lengths, th::optional input_embeds); std::vector get_pickle_info() const; private: - const at::ScalarType _st; - bool _remove_padding; - IFT5Encoder* ft_t5_encoder; - th::Tensor head_info; - th::Tensor scaling_info; + const at::ScalarType _st; + bool _remove_padding; + int64_t d_model_; + IFT5Encoder* ft_t5_encoder; std::vector weights; }; diff --git a/src/fastertransformer/th_op/th_utils.cu b/src/fastertransformer/th_op/th_utils.cu index 3762e93fd..c70825bfc 100644 --- a/src/fastertransformer/th_op/th_utils.cu +++ b/src/fastertransformer/th_op/th_utils.cu @@ -30,14 +30,37 @@ std::vector convert_shape(torch::Tensor tensor) template fastertransformer::Tensor convert_tensor(torch::Tensor tensor) { - return fastertransformer::Tensor{fastertransformer::MEMORY_GPU, - fastertransformer::getTensorType(), - convert_shape(tensor), - get_ptr(tensor)}; + return convert_tensor(tensor, fastertransformer::MEMORY_GPU); } template fastertransformer::Tensor convert_tensor(torch::Tensor tensor); template fastertransformer::Tensor convert_tensor(torch::Tensor tensor); +#ifdef ENABLE_BF16 +template fastertransformer::Tensor convert_tensor<__nv_bfloat16>(torch::Tensor tensor); +#endif template fastertransformer::Tensor convert_tensor(torch::Tensor tensor); +template fastertransformer::Tensor convert_tensor(torch::Tensor tensor); +template fastertransformer::Tensor convert_tensor(torch::Tensor tensor); + +template +fastertransformer::Tensor convert_tensor(torch::Tensor tensor, fastertransformer::MemoryType memory_type) +{ + return fastertransformer::Tensor{ + memory_type, fastertransformer::getTensorType(), convert_shape(tensor), get_ptr(tensor)}; +} + 
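// A minimal usage sketch for the two-argument convert_tensor overload defined just above,
// assuming only the declarations visible in this hunk (torch_ext::convert_tensor,
// ft::MEMORY_CPU / ft::MEMORY_GPU). The wrapper function names, dtypes and the idea of
// keeping sequence lengths on the host are illustrative, not part of the patch.
#include "src/fastertransformer/th_op/th_utils.h"

namespace ft = fastertransformer;

// Wrap a host-resident int32 torch tensor (e.g. sequence lengths kept on the CPU side)
// as an ft::Tensor view without copying; the memory type must match where the data lives.
ft::Tensor wrap_host_lengths(torch::Tensor sequence_lengths)
{
    return torch_ext::convert_tensor<int>(sequence_lengths, ft::MEMORY_CPU);
}

// The single-argument overload keeps the previous behaviour and tags the view as MEMORY_GPU,
// which is appropriate for CUDA-resident activations such as hidden states.
ft::Tensor wrap_device_hidden_states(torch::Tensor hidden_states)
{
    return torch_ext::convert_tensor<float>(hidden_states);
}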
+template fastertransformer::Tensor convert_tensor(torch::Tensor tensor, + fastertransformer::MemoryType memory_type); +template fastertransformer::Tensor convert_tensor(torch::Tensor tensor, + fastertransformer::MemoryType memory_type); +#ifdef ENABLE_BF16 +template fastertransformer::Tensor convert_tensor<__nv_bfloat16>(torch::Tensor tensor, + fastertransformer::MemoryType memory_type); +#endif +template fastertransformer::Tensor convert_tensor(torch::Tensor tensor, fastertransformer::MemoryType memory_type); +template fastertransformer::Tensor convert_tensor(torch::Tensor tensor, + fastertransformer::MemoryType memory_type); +template fastertransformer::Tensor convert_tensor(torch::Tensor tensor, + fastertransformer::MemoryType memory_type); } // namespace torch_ext diff --git a/src/fastertransformer/th_op/th_utils.h b/src/fastertransformer/th_op/th_utils.h index 7cb6830bc..727b2b211 100644 --- a/src/fastertransformer/th_op/th_utils.h +++ b/src/fastertransformer/th_op/th_utils.h @@ -52,4 +52,7 @@ std::vector convert_shape(torch::Tensor tensor); template fastertransformer::Tensor convert_tensor(torch::Tensor tensor); +template +fastertransformer::Tensor convert_tensor(torch::Tensor tensor, fastertransformer::MemoryType memory_type); + } // namespace torch_ext diff --git a/src/fastertransformer/th_op/vit/ViTINT8Op.cc b/src/fastertransformer/th_op/vit/ViTINT8Op.cc index 5db8925ef..e4490ff6c 100644 --- a/src/fastertransformer/th_op/vit/ViTINT8Op.cc +++ b/src/fastertransformer/th_op/vit/ViTINT8Op.cc @@ -23,16 +23,16 @@ template class VisionTransformerINT8Func; template class VisionTransformerINT8Func; VisionTransformerINT8Class::VisionTransformerINT8Class(std::vector w, - int64_t max_batch, - int64_t img_size, - int64_t patch_size, - int64_t in_chans, - int64_t embed_dim, - int64_t num_heads, - int64_t inter_size, - int64_t layer_num, - int64_t int8_mode, - int64_t with_cls_token): + int64_t max_batch, + int64_t img_size, + int64_t patch_size, + int64_t in_chans, + int64_t embed_dim, + int64_t num_heads, + int64_t inter_size, + int64_t layer_num, + int64_t int8_mode, + int64_t with_cls_token): st_(w[0].scalar_type()), weights_(w) { @@ -76,7 +76,7 @@ VisionTransformerINT8Class::VisionTransformerINT8Class(std::vector w default: throw std::runtime_error("Wrong th::Tensor type."); } - info_int_ = torch::empty({10}, torch::dtype(torch::kInt64)); + info_int_ = torch::empty({10}, torch::dtype(torch::kInt64)); info_int_[0] = max_batch; info_int_[1] = img_size; info_int_[2] = patch_size; @@ -104,8 +104,8 @@ VisionTransformerINT8Class::~VisionTransformerINT8Class() th::Tensor VisionTransformerINT8Class::forward(th::Tensor input) { CHECK_INPUT(input, st_); - int batch_size = input.size(0); - auto output = torch::empty({batch_size, output_seq_len_, output_emb_dim_}, + int batch_size = input.size(0); + auto output = torch::empty({batch_size, output_seq_len_, output_emb_dim_}, torch::dtype(input.dtype()).device(torch::kCUDA).requires_grad(false)); vit_func_->forward(batch_size, input, output); return output; @@ -136,22 +136,22 @@ static auto visionTransformerTHS = return self->get_pickle_info(); }, [](std::vector state) -> c10::intrusive_ptr { - int state_size = state.size(); - std::vector::const_iterator first = state.begin(); - std::vector::const_iterator last = state.begin() + (state_size - 1); - std::vector weights(first, last); - int idx = state.size() - 1; - int i = 0; - int64_t max_batch = state[idx][i++].item().to(); - int64_t img_size = state[idx][i++].item().to(); - int64_t patch_size = 
state[idx][i++].item().to(); - int64_t in_chans = state[idx][i++].item().to(); - int64_t embed_dim = state[idx][i++].item().to(); - int64_t num_heads = state[idx][i++].item().to(); - int64_t inter_size = state[idx][i++].item().to(); - int64_t layer_num = state[idx][i++].item().to(); - int64_t int8_mode = state[idx][i++].item().to(); - int64_t with_cls_token = state[idx][i++].item().to(); + int state_size = state.size(); + std::vector::const_iterator first = state.begin(); + std::vector::const_iterator last = state.begin() + (state_size - 1); + std::vector weights(first, last); + int idx = state.size() - 1; + int i = 0; + int64_t max_batch = state[idx][i++].item().to(); + int64_t img_size = state[idx][i++].item().to(); + int64_t patch_size = state[idx][i++].item().to(); + int64_t in_chans = state[idx][i++].item().to(); + int64_t embed_dim = state[idx][i++].item().to(); + int64_t num_heads = state[idx][i++].item().to(); + int64_t inter_size = state[idx][i++].item().to(); + int64_t layer_num = state[idx][i++].item().to(); + int64_t int8_mode = state[idx][i++].item().to(); + int64_t with_cls_token = state[idx][i++].item().to(); return c10::make_intrusive(weights, max_batch, img_size, diff --git a/src/fastertransformer/th_op/vit/ViTINT8Op.h b/src/fastertransformer/th_op/vit/ViTINT8Op.h index 158af314b..3954a44d2 100644 --- a/src/fastertransformer/th_op/vit/ViTINT8Op.h +++ b/src/fastertransformer/th_op/vit/ViTINT8Op.h @@ -32,32 +32,32 @@ class IViTFunc { template class VisionTransformerINT8Func: public IViTFunc { public: - int sm_; - bool _use_ORDER_COL32_2R_4R4; - int max_batch_; - int img_size_; - int patch_size_; - int in_chans_; - int embed_dim_; - int num_heads_; - int head_dim_; - int inter_size_; - int layer_num_; + int sm_; + bool _use_ORDER_COL32_2R_4R4; + int max_batch_; + int img_size_; + int patch_size_; + int in_chans_; + int embed_dim_; + int num_heads_; + int head_dim_; + int inter_size_; + int layer_num_; float q_scaling_; - int int8_mode_; - bool with_cls_token_; - - VisionTransformerINT8Func(const int max_batch, - const int img_size, - const int patch_size, - const int in_chans, - const int embed_dim, - const int num_heads, - const int inter_size, - const int layer_num, - const float q_scaling, - const int int8_mode, - const bool with_cls_token, + int int8_mode_; + bool with_cls_token_; + + VisionTransformerINT8Func(const int max_batch, + const int img_size, + const int patch_size, + const int in_chans, + const int embed_dim, + const int num_heads, + const int inter_size, + const int layer_num, + const float q_scaling, + const int int8_mode, + const bool with_cls_token, const std::vector& w): weights_(w), max_batch_(max_batch), @@ -71,7 +71,9 @@ class VisionTransformerINT8Func: public IViTFunc { layer_num_(layer_num), q_scaling_(q_scaling), int8_mode_(int8_mode), - with_cls_token_(with_cls_token) + with_cls_token_(with_cls_token), + params_(ft::ViTINT8Weight( + embed_dim, inter_size, layer_num, img_size, patch_size, in_chans, with_cls_token, false)) { ft::check_cuda_error(cublasCreate(&cublas_handle_)); ft::check_cuda_error(cublasLtCreate(&cublaslt_handle_)); @@ -83,51 +85,51 @@ class VisionTransformerINT8Func: public IViTFunc { if (sm_ >= 80) { _use_ORDER_COL32_2R_4R4 = true; } - cublas_algo_map_ = new ft::cublasAlgoMap("gemm_config.in", std::string("")); + cublas_algo_map_ = new ft::cublasAlgoMap("gemm_config.in", std::string("")); cublas_wrapper_mutex_ = new std::mutex(); params_.vit_layer_weights.clear(); params_.vit_layer_weights.resize(layer_num_); - int idx_w = 0; + int 
idx_w = 0; params_.pre_transform_embeds.position_embed = get_ptr(weights_[idx_w++]); if (with_cls_token) { params_.pre_transform_embeds.class_embed = get_ptr(weights_[idx_w++]); } params_.pre_encoder_conv_weights.kernel = get_ptr(weights_[idx_w++]); - params_.pre_encoder_conv_weights.bias = get_ptr(weights_[idx_w++]); + params_.pre_encoder_conv_weights.bias = get_ptr(weights_[idx_w++]); for (int i = 0; i < layer_num_; i++) { - auto& layer_weight = params_.vit_layer_weights[i]; - layer_weight.attn_layernorm_weights.gamma = get_ptr(weights_[idx_w++]); - layer_weight.attn_layernorm_weights.beta = get_ptr(weights_[idx_w++]); - layer_weight.ffn_layernorm_weights.gamma = get_ptr(weights_[idx_w++]); - layer_weight.ffn_layernorm_weights.beta = get_ptr(weights_[idx_w++]); - layer_weight.ffn_weights.intermediate_weight.kernel = get_ptr(weights_[idx_w++]); - layer_weight.ffn_weights.intermediate_weight.bias = get_ptr(weights_[idx_w++]); - layer_weight.ffn_weights.output_weight.kernel = get_ptr(weights_[idx_w++]); - layer_weight.ffn_weights.output_weight.bias = get_ptr(weights_[idx_w++]); - layer_weight.attention_weights.query_weight.kernel = get_ptr(weights_[idx_w++]); - layer_weight.attention_weights.query_weight.bias = get_ptr(weights_[idx_w++]); - layer_weight.attention_weights.key_weight.kernel = get_ptr(weights_[idx_w++]); - layer_weight.attention_weights.key_weight.bias = get_ptr(weights_[idx_w++]); - layer_weight.attention_weights.value_weight.kernel = get_ptr(weights_[idx_w++]); - layer_weight.attention_weights.value_weight.bias = get_ptr(weights_[idx_w++]); + auto& layer_weight = params_.vit_layer_weights[i]; + layer_weight.attn_layernorm_weights.gamma = get_ptr(weights_[idx_w++]); + layer_weight.attn_layernorm_weights.beta = get_ptr(weights_[idx_w++]); + layer_weight.ffn_layernorm_weights.gamma = get_ptr(weights_[idx_w++]); + layer_weight.ffn_layernorm_weights.beta = get_ptr(weights_[idx_w++]); + layer_weight.ffn_weights.intermediate_weight.kernel = get_ptr(weights_[idx_w++]); + layer_weight.ffn_weights.intermediate_weight.bias = get_ptr(weights_[idx_w++]); + layer_weight.ffn_weights.output_weight.kernel = get_ptr(weights_[idx_w++]); + layer_weight.ffn_weights.output_weight.bias = get_ptr(weights_[idx_w++]); + layer_weight.attention_weights.query_weight.kernel = get_ptr(weights_[idx_w++]); + layer_weight.attention_weights.query_weight.bias = get_ptr(weights_[idx_w++]); + layer_weight.attention_weights.key_weight.kernel = get_ptr(weights_[idx_w++]); + layer_weight.attention_weights.key_weight.bias = get_ptr(weights_[idx_w++]); + layer_weight.attention_weights.value_weight.kernel = get_ptr(weights_[idx_w++]); + layer_weight.attention_weights.value_weight.bias = get_ptr(weights_[idx_w++]); layer_weight.attention_weights.attention_output_weight.kernel = get_ptr(weights_[idx_w++]); - layer_weight.attention_weights.attention_output_weight.bias = get_ptr(weights_[idx_w++]); + layer_weight.attention_weights.attention_output_weight.bias = get_ptr(weights_[idx_w++]); } params_.post_transformer_layernorm_weights.gamma = get_ptr(weights_[idx_w++]); - params_.post_transformer_layernorm_weights.beta = get_ptr(weights_[idx_w++]); + params_.post_transformer_layernorm_weights.beta = get_ptr(weights_[idx_w++]); for (int i = 0; i < layer_num_; i++) { - auto& layer_weight = params_.vit_layer_weights[i]; - layer_weight.scale_list_.size_ = ACTIVATION_AMAX_NUM + 9 * embed_dim + INT8O_GEMM_NUM + TRT_AMAX_NUM; + auto& layer_weight = params_.vit_layer_weights[i]; + layer_weight.scale_list_.size_ = 
ACTIVATION_AMAX_NUM + 9 * embed_dim + INT8O_GEMM_NUM + TRT_AMAX_NUM; layer_weight.scale_list_.p2_offset_ = ACTIVATION_AMAX_NUM; layer_weight.scale_list_.p3_offset_ = ACTIVATION_AMAX_NUM + 9 * embed_dim; layer_weight.scale_list_.p4_offset_ = ACTIVATION_AMAX_NUM + 9 * embed_dim + INT8O_GEMM_NUM; - layer_weight.scale_list_.d_scale_list_ = get_ptr(weights_[idx_w++]); - layer_weight.scale_list_.h_scale_list_ = get_ptr(weights_[idx_w++]); + layer_weight.scale_list_.d_scale_list_ = get_ptr(weights_[idx_w++]); + layer_weight.scale_list_.h_scale_list_ = get_ptr(weights_[idx_w++]); layer_weight.attention_weights.scale_list_ptr = &(layer_weight.scale_list_); - layer_weight.ffn_weights.scale_list_ptr = &(layer_weight.scale_list_); + layer_weight.ffn_weights.scale_list_ptr = &(layer_weight.scale_list_); } } @@ -157,7 +159,7 @@ class VisionTransformerINT8Func: public IViTFunc { cublas_wrapper->setFP32GemmConfig(); } - int seq_len = (img_size_ / patch_size_) * (img_size_ / patch_size_) + (with_cls_token_ ? 1 : 0); + int seq_len = (img_size_ / patch_size_) * (img_size_ / patch_size_) + (with_cls_token_ ? 1 : 0); ft::AttentionType attention_type = ft::getAttentionType(head_dim_, sm_, true, seq_len); auto vit = new ft::ViTTransformerINT8(max_batch_, @@ -179,8 +181,8 @@ class VisionTransformerINT8Func: public IViTFunc { true, attention_type); - ft::DataType data_type = ft::getTensorType(); - int sm_ptr[1] = {sm_}; + ft::DataType data_type = ft::getTensorType(); + int sm_ptr[1] = {sm_}; std::vector input_tensors = std::vector{ ft::Tensor{ft::MEMORY_GPU, data_type, @@ -202,27 +204,27 @@ class VisionTransformerINT8Func: public IViTFunc { private: std::vector weights_; - cublasHandle_t cublas_handle_ = nullptr; - cublasLtHandle_t cublaslt_handle_ = nullptr; - cudnnHandle_t cudnn_handle_ = nullptr; - ft::ViTINT8Weight params_; - std::mutex* cublas_wrapper_mutex_; - ft::cublasAlgoMap* cublas_algo_map_; + cublasHandle_t cublas_handle_ = nullptr; + cublasLtHandle_t cublaslt_handle_ = nullptr; + cudnnHandle_t cudnn_handle_ = nullptr; + ft::ViTINT8Weight params_; + std::mutex* cublas_wrapper_mutex_; + ft::cublasAlgoMap* cublas_algo_map_; }; class VisionTransformerINT8Class: public torch::jit::CustomClassHolder { public: VisionTransformerINT8Class(std::vector w, - int64_t max_batch, - int64_t img_size, - int64_t patch_size, - int64_t in_chans, - int64_t embed_dim, - int64_t num_heads, - int64_t inter_size, - int64_t layer_num, - int64_t with_cls_token, - int64_t int8_mode); + int64_t max_batch, + int64_t img_size, + int64_t patch_size, + int64_t in_chans, + int64_t embed_dim, + int64_t num_heads, + int64_t inter_size, + int64_t layer_num, + int64_t with_cls_token, + int64_t int8_mode); ~VisionTransformerINT8Class(); @@ -231,12 +233,12 @@ class VisionTransformerINT8Class: public torch::jit::CustomClassHolder { std::vector get_pickle_info() const; private: - const at::ScalarType st_; - IViTFunc* vit_func_; + const at::ScalarType st_; + IViTFunc* vit_func_; std::vector weights_; - th::Tensor info_int_; - int output_seq_len_; - int output_emb_dim_; + th::Tensor info_int_; + int output_seq_len_; + int output_emb_dim_; }; } // namespace torch_ext diff --git a/src/fastertransformer/th_op/vit/ViTOp.cc b/src/fastertransformer/th_op/vit/ViTOp.cc index ca51ef728..2d431decb 100644 --- a/src/fastertransformer/th_op/vit/ViTOp.cc +++ b/src/fastertransformer/th_op/vit/ViTOp.cc @@ -23,15 +23,15 @@ template class VisionTransformerFunc; template class VisionTransformerFunc; VisionTransformerClass::VisionTransformerClass(std::vector w, - 
int64_t max_batch, - int64_t img_size, - int64_t patch_size, - int64_t in_chans, - int64_t embed_dim, - int64_t num_heads, - int64_t inter_size, - int64_t layer_num, - int64_t with_cls_token): + int64_t max_batch, + int64_t img_size, + int64_t patch_size, + int64_t in_chans, + int64_t embed_dim, + int64_t num_heads, + int64_t inter_size, + int64_t layer_num, + int64_t with_cls_token): st_(w[0].scalar_type()), weights_(w) { @@ -72,7 +72,7 @@ VisionTransformerClass::VisionTransformerClass(std::vector w, default: throw std::runtime_error("Wrong th::Tensor type."); } - info_int_ = torch::empty({9}, torch::dtype(torch::kInt64)); + info_int_ = torch::empty({9}, torch::dtype(torch::kInt64)); info_int_[0] = max_batch; info_int_[1] = img_size; info_int_[2] = patch_size; @@ -99,8 +99,8 @@ VisionTransformerClass::~VisionTransformerClass() th::Tensor VisionTransformerClass::forward(th::Tensor input) { CHECK_INPUT(input, st_); - int batch_size = input.size(0); - auto output = torch::empty({batch_size, output_seq_len_, output_emb_dim_}, + int batch_size = input.size(0); + auto output = torch::empty({batch_size, output_seq_len_, output_emb_dim_}, torch::dtype(input.dtype()).device(torch::kCUDA).requires_grad(false)); vit_func_->forward(batch_size, input, output); return output; @@ -130,21 +130,21 @@ static auto visionTransformerTHS = return self->get_pickle_info(); }, [](std::vector state) -> c10::intrusive_ptr { - int state_size = state.size(); - std::vector::const_iterator first = state.begin(); - std::vector::const_iterator last = state.begin() + (state_size - 1); - std::vector weights(first, last); - int idx = state.size() - 1; - int i = 0; - int64_t max_batch = state[idx][i++].item().to(); - int64_t img_size = state[idx][i++].item().to(); - int64_t patch_size = state[idx][i++].item().to(); - int64_t in_chans = state[idx][i++].item().to(); - int64_t embed_dim = state[idx][i++].item().to(); - int64_t num_heads = state[idx][i++].item().to(); - int64_t inter_size = state[idx][i++].item().to(); - int64_t layer_num = state[idx][i++].item().to(); - int64_t with_cls_token = state[idx][i++].item().to(); + int state_size = state.size(); + std::vector::const_iterator first = state.begin(); + std::vector::const_iterator last = state.begin() + (state_size - 1); + std::vector weights(first, last); + int idx = state.size() - 1; + int i = 0; + int64_t max_batch = state[idx][i++].item().to(); + int64_t img_size = state[idx][i++].item().to(); + int64_t patch_size = state[idx][i++].item().to(); + int64_t in_chans = state[idx][i++].item().to(); + int64_t embed_dim = state[idx][i++].item().to(); + int64_t num_heads = state[idx][i++].item().to(); + int64_t inter_size = state[idx][i++].item().to(); + int64_t layer_num = state[idx][i++].item().to(); + int64_t with_cls_token = state[idx][i++].item().to(); return c10::make_intrusive(weights, max_batch, img_size, diff --git a/src/fastertransformer/th_op/vit/ViTOp.h b/src/fastertransformer/th_op/vit/ViTOp.h index 2a946e8f7..88e22936c 100644 --- a/src/fastertransformer/th_op/vit/ViTOp.h +++ b/src/fastertransformer/th_op/vit/ViTOp.h @@ -31,30 +31,30 @@ class IViTFunc { template class VisionTransformerFunc: public IViTFunc { public: - int sm_; - int max_batch_; - int img_size_; - int patch_size_; - int in_chans_; - int embed_dim_; - int num_heads_; - int head_dim_; - int inter_size_; - int layer_num_; - bool sparse_; + int sm_; + int max_batch_; + int img_size_; + int patch_size_; + int in_chans_; + int embed_dim_; + int num_heads_; + int head_dim_; + int inter_size_; + int 
layer_num_; + bool sparse_; float q_scaling_; - bool with_cls_token_; - - VisionTransformerFunc(const int max_batch, - const int img_size, - const int patch_size, - const int in_chans, - const int embed_dim, - const int num_heads, - const int inter_size, - const int layer_num, - const float q_scaling, - const bool with_cls_token, + bool with_cls_token_; + + VisionTransformerFunc(const int max_batch, + const int img_size, + const int patch_size, + const int in_chans, + const int embed_dim, + const int num_heads, + const int inter_size, + const int layer_num, + const float q_scaling, + const bool with_cls_token, const std::vector& w): weights_(w), max_batch_(max_batch), @@ -90,40 +90,40 @@ class VisionTransformerFunc: public IViTFunc { checkCUDNN(cudnnCreate(&cudnn_handle_)); sm_ = ft::getSMVersion(); - cublas_algo_map_ = new ft::cublasAlgoMap("gemm_config.in", std::string("")); + cublas_algo_map_ = new ft::cublasAlgoMap("gemm_config.in", std::string("")); cublas_wrapper_mutex_ = new std::mutex(); // params_.vit_layer_weights.clear(); // params_.vit_layer_weights.resize(layer_num_); - int idx_w = 0; + int idx_w = 0; params_.pre_encoder_conv_weights.kernel = get_ptr(weights_[idx_w++]); - params_.pre_encoder_conv_weights.bias = get_ptr(weights_[idx_w++]); + params_.pre_encoder_conv_weights.bias = get_ptr(weights_[idx_w++]); if (with_cls_token) { params_.pre_transform_embeds.class_embed = get_ptr(weights_[idx_w++]); - params_.with_cls_token_ = true; + params_.with_cls_token_ = true; } params_.pre_transform_embeds.position_embed = get_ptr(weights_[idx_w++]); for (int i = 0; i < layer_num_; i++) { - auto& layer_weight = params_.vit_layer_weights[i]; - layer_weight.attn_layernorm_weights.gamma = get_ptr(weights_[idx_w++]); - layer_weight.attn_layernorm_weights.beta = get_ptr(weights_[idx_w++]); - layer_weight.attention_weights.query_weight.kernel = get_ptr(weights_[idx_w++]); - layer_weight.attention_weights.query_weight.bias = get_ptr(weights_[idx_w++]); - layer_weight.attention_weights.key_weight.kernel = get_ptr(weights_[idx_w++]); - layer_weight.attention_weights.key_weight.bias = get_ptr(weights_[idx_w++]); - layer_weight.attention_weights.value_weight.kernel = get_ptr(weights_[idx_w++]); - layer_weight.attention_weights.value_weight.bias = get_ptr(weights_[idx_w++]); + auto& layer_weight = params_.vit_layer_weights[i]; + layer_weight.attn_layernorm_weights.gamma = get_ptr(weights_[idx_w++]); + layer_weight.attn_layernorm_weights.beta = get_ptr(weights_[idx_w++]); + layer_weight.attention_weights.query_weight.kernel = get_ptr(weights_[idx_w++]); + layer_weight.attention_weights.query_weight.bias = get_ptr(weights_[idx_w++]); + layer_weight.attention_weights.key_weight.kernel = get_ptr(weights_[idx_w++]); + layer_weight.attention_weights.key_weight.bias = get_ptr(weights_[idx_w++]); + layer_weight.attention_weights.value_weight.kernel = get_ptr(weights_[idx_w++]); + layer_weight.attention_weights.value_weight.bias = get_ptr(weights_[idx_w++]); layer_weight.attention_weights.attention_output_weight.kernel = get_ptr(weights_[idx_w++]); - layer_weight.attention_weights.attention_output_weight.bias = get_ptr(weights_[idx_w++]); - layer_weight.ffn_layernorm_weights.gamma = get_ptr(weights_[idx_w++]); - layer_weight.ffn_layernorm_weights.beta = get_ptr(weights_[idx_w++]); - layer_weight.ffn_weights.intermediate_weight.kernel = get_ptr(weights_[idx_w++]); - layer_weight.ffn_weights.intermediate_weight.bias = get_ptr(weights_[idx_w++]); - layer_weight.ffn_weights.output_weight.kernel = 
get_ptr(weights_[idx_w++]); - layer_weight.ffn_weights.output_weight.bias = get_ptr(weights_[idx_w++]); + layer_weight.attention_weights.attention_output_weight.bias = get_ptr(weights_[idx_w++]); + layer_weight.ffn_layernorm_weights.gamma = get_ptr(weights_[idx_w++]); + layer_weight.ffn_layernorm_weights.beta = get_ptr(weights_[idx_w++]); + layer_weight.ffn_weights.intermediate_weight.kernel = get_ptr(weights_[idx_w++]); + layer_weight.ffn_weights.intermediate_weight.bias = get_ptr(weights_[idx_w++]); + layer_weight.ffn_weights.output_weight.kernel = get_ptr(weights_[idx_w++]); + layer_weight.ffn_weights.output_weight.bias = get_ptr(weights_[idx_w++]); } params_.post_transformer_layernorm_weights.gamma = get_ptr(weights_[idx_w++]); - params_.post_transformer_layernorm_weights.beta = get_ptr(weights_[idx_w++]); + params_.post_transformer_layernorm_weights.beta = get_ptr(weights_[idx_w++]); } ~VisionTransformerFunc() override @@ -136,7 +136,7 @@ class VisionTransformerFunc: public IViTFunc { void forward(int batch_size, th::Tensor& input, th::Tensor& output) override { - auto stream = at::cuda::getCurrentCUDAStream().stream(); + auto stream = at::cuda::getCurrentCUDAStream().stream(); auto cublas_handle = at::cuda::getCurrentCUDABlasHandle(); cublasSetStream(cublas_handle, stream); @@ -151,7 +151,7 @@ class VisionTransformerFunc: public IViTFunc { else if (std::is_same::value) { cublas_wrapper->setFP32GemmConfig(); } - int seq_len = (img_size_ / patch_size_) * (img_size_ / patch_size_) + (with_cls_token_ ? 1 : 0); + int seq_len = (img_size_ / patch_size_) * (img_size_ / patch_size_) + (with_cls_token_ ? 1 : 0); ft::AttentionType attention_type = ft::getAttentionType(head_dim_, sm_, true, seq_len); auto vit = new ft::ViTTransformer(max_batch_, @@ -172,8 +172,8 @@ class VisionTransformerFunc: public IViTFunc { true, attention_type); - ft::DataType data_type = ft::getTensorType(); - int sm_ptr[1] = {sm_}; + ft::DataType data_type = ft::getTensorType(); + int sm_ptr[1] = {sm_}; std::vector input_tensors = std::vector{ ft::Tensor{ft::MEMORY_GPU, data_type, @@ -195,25 +195,25 @@ class VisionTransformerFunc: public IViTFunc { private: std::vector weights_; - cublasLtHandle_t cublaslt_handle_ = nullptr; - cudnnHandle_t cudnn_handle_ = nullptr; - ft::ViTWeight params_; - std::mutex* cublas_wrapper_mutex_; - ft::cublasAlgoMap* cublas_algo_map_; + cublasLtHandle_t cublaslt_handle_ = nullptr; + cudnnHandle_t cudnn_handle_ = nullptr; + ft::ViTWeight params_; + std::mutex* cublas_wrapper_mutex_; + ft::cublasAlgoMap* cublas_algo_map_; }; class VisionTransformerClass: public torch::jit::CustomClassHolder { public: VisionTransformerClass(std::vector w, - int64_t max_batch, - int64_t img_size, - int64_t patch_size, - int64_t in_chans, - int64_t embed_dim, - int64_t num_heads, - int64_t inter_size, - int64_t layer_num, - int64_t with_cls_token); + int64_t max_batch, + int64_t img_size, + int64_t patch_size, + int64_t in_chans, + int64_t embed_dim, + int64_t num_heads, + int64_t inter_size, + int64_t layer_num, + int64_t with_cls_token); ~VisionTransformerClass(); @@ -222,12 +222,12 @@ class VisionTransformerClass: public torch::jit::CustomClassHolder { std::vector get_pickle_info() const; private: - const at::ScalarType st_; - IViTFunc* vit_func_; + const at::ScalarType st_; + IViTFunc* vit_func_; std::vector weights_; - th::Tensor info_int_; - int output_seq_len_; - int output_emb_dim_; + th::Tensor info_int_; + int output_seq_len_; + int output_emb_dim_; }; } // namespace torch_ext diff --git 
a/src/fastertransformer/th_op/vit/WeightQuantizeOp.cc b/src/fastertransformer/th_op/vit/WeightQuantizeOp.cc index 453418426..0bc5f0c4d 100644 --- a/src/fastertransformer/th_op/vit/WeightQuantizeOp.cc +++ b/src/fastertransformer/th_op/vit/WeightQuantizeOp.cc @@ -50,10 +50,10 @@ Tensor vit_weight_quantize(Tensor weight, Tensor quant_max) TORCH_CHECK(quant_max.dtype() == torch::kFloat32, "quant_max dtype should be float32"); TORCH_CHECK(quant_max.numel() == n, "quant_max wrong shape"); - const float* weight_ = get_ptr(weight); + const float* weight_ = get_ptr(weight); const float* quant_max_ = get_ptr(quant_max); - auto output = torch::empty({k * n}, torch::dtype(torch::kFloat16).device(torch::kCUDA).requires_grad(false)); + auto output = torch::empty({k * n}, torch::dtype(torch::kFloat16).device(torch::kCUDA).requires_grad(false)); int8_t* transform_out = get_ptr(output); auto stream = at::cuda::getCurrentCUDAStream().stream(); diff --git a/src/fastertransformer/triton_backend/CMakeLists.txt b/src/fastertransformer/triton_backend/CMakeLists.txt index 42157aa15..34eb7bc55 100644 --- a/src/fastertransformer/triton_backend/CMakeLists.txt +++ b/src/fastertransformer/triton_backend/CMakeLists.txt @@ -14,8 +14,11 @@ cmake_minimum_required(VERSION 3.8) -if(BUILD_MULTI_GPU) - add_subdirectory(gptj) - add_subdirectory(t5) - add_subdirectory(multi_gpu_gpt) -endif() \ No newline at end of file +add_library(TransformerTritonBackend SHARED transformer_triton_backend.cpp) +target_link_libraries(TransformerTritonBackend PRIVATE nccl_utils mpi_utils) + +add_subdirectory(gptj) +add_subdirectory(gptneox) +add_subdirectory(t5) +add_subdirectory(multi_gpu_gpt) +add_subdirectory(bert) diff --git a/src/fastertransformer/triton_backend/bert/BertTritonModel.cc b/src/fastertransformer/triton_backend/bert/BertTritonModel.cc new file mode 100644 index 000000000..1bb89dc02 --- /dev/null +++ b/src/fastertransformer/triton_backend/bert/BertTritonModel.cc @@ -0,0 +1,225 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "3rdparty/INIReader.h" + +#include "src/fastertransformer/triton_backend/bert/BertTritonModel.h" +#include "src/fastertransformer/triton_backend/bert/BertTritonModelInstance.h" + +namespace ft = fastertransformer; + +template +BertTritonModel::BertTritonModel(size_t tensor_para_size, + size_t pipeline_para_size, + bool enable_custom_all_reduce, + std::string model_dir, + int int8_mode, + bool is_sparse, + bool is_remove_padding): + tensor_para_size_(tensor_para_size), + pipeline_para_size_(pipeline_para_size), + shared_weights_(std::vector>>(ft::getDeviceCount())), + enable_custom_all_reduce_(enable_custom_all_reduce), + model_dir_(model_dir), + int8_mode_(int8_mode), + is_sparse_(is_sparse), + is_remove_padding_(is_remove_padding) +{ + ft::FT_CHECK_WITH_INFO(int8_mode_ == 0, "still not support int8 in bert backend"); + ft::FT_CHECK_WITH_INFO(is_sparse == false, "still not support sparse in bert backend"); + + INIReader reader = INIReader(model_dir + "/config.ini"); + if (reader.ParseError() < 0) { + std::cout << "[ERROR] Can't load '" << model_dir << "/config.ini" + << "'\n"; + ft::FT_CHECK(false); + } + + /* Bert base Configuration File Example + [bert] + model_name = bert + position_embedding_type = absolute + hidden_size = 768 + num_hidden_layers = 12 + head_num = 12 + size_per_head = 64 + activation_type = gelu + inter_size = 3072 + max_position_embeddings = 512 + layer_norm_eps = 1e-12 + weight_data_type = fp32 + */ + + model_name_ = reader.Get("bert", "model_name"); + head_num_ = reader.GetInteger("bert", "head_num"); + size_per_head_ = reader.GetInteger("bert", "size_per_head"); + inter_size_ = reader.GetInteger("bert", "inter_size"); + num_layer_ = reader.GetInteger("bert", "num_layer"); + layernorm_type_ = ft::getLayerNormType("post_layernorm"); + activation_type_ = ft::getActivationType(reader.Get("bert", "activation_type", "Gelu")); + q_scaling_ = reader.GetFloat("bert", "q_scaling", 1.0f); +} + +template +std::unique_ptr +BertTritonModel::createModelInstance(int device_id, + int rank, + cudaStream_t stream, + std::pair, std::vector> nccl_params, + std::shared_ptr custom_all_reduce_comm) +{ + ft::check_cuda_error(cudaSetDevice(device_id)); + const int comms_rank = device_id % (tensor_para_size_ * pipeline_para_size_); + const int tensor_para_rank = rank % tensor_para_size_; + const int pipeline_para_rank = rank / tensor_para_size_; + + std::unique_ptr> allocator( + new ft::Allocator(device_id)); + + allocator->setStream(stream); + + cublasHandle_t cublas_handle; + cublasLtHandle_t cublaslt_handle; + + cublasCreate(&cublas_handle); + cublasLtCreate(&cublaslt_handle); + cublasSetStream(cublas_handle, stream); + + std::unique_ptr cublas_algo_map(new ft::cublasAlgoMap("gemm_config.in")); + std::unique_ptr cublas_wrapper_mutex(new std::mutex()); + std::unique_ptr cublas_wrapper(new ft::cublasMMWrapper( + cublas_handle, cublaslt_handle, stream, cublas_algo_map.get(), cublas_wrapper_mutex.get(), allocator.get())); + + std::unique_ptr cuda_device_prop_ptr(new cudaDeviceProp); + ft::check_cuda_error(cudaGetDeviceProperties(cuda_device_prop_ptr.get(), device_id)); + + if (std::is_same::value) { + cublas_wrapper->setGemmConfig(CUDA_R_16F, CUDA_R_16F, CUDA_R_16F, CUDA_R_32F); + } +#ifdef ENABLE_BF16 + else if (std::is_same::value) { + cublas_wrapper->setBF16GemmConfig(); + } +#endif + else if (std::is_same::value) { + cublas_wrapper->setFP32GemmConfig(); + } + + ft::NcclParam tensor_para = nccl_params.first[comms_rank]; + ft::NcclParam pipeline_para = 
nccl_params.second[comms_rank]; + + const int max_seq_len = 384; + ft::AttentionType attention_type = + ft::getAttentionType(size_per_head_, ft::getSMVersion(), is_remove_padding_, max_seq_len); + + auto bert = + std::make_unique>(ft::Bert(0, // max_batch_size, FT will adjust the buffer automatically. + 0, // max_seq_len, FT will adjust the buffer automatically. + head_num_, + size_per_head_, + inter_size_, + num_layer_, + ft::getSMVersion(), + q_scaling_, + stream, + cublas_wrapper.get(), + allocator.get(), + false, + attention_type, + is_sparse_, + activation_type_, + layernorm_type_, + tensor_para, + pipeline_para, + custom_all_reduce_comm, + enable_custom_all_reduce_)); + +#ifdef SPARSITY_ENABLED + if (is_sparse_) { + for (int i = 0; i < num_layer_; ++i) { + shared_weights_[device_id]->bert_layer_weights[i].compress_weights(*(cublas_wrapper.get()), + head_num_ * size_per_head_); + } + } +#endif + + return std::unique_ptr>(new BertTritonModelInstance(std::move(bert), + shared_weights_[device_id], + std::move(allocator), + std::move(cublas_algo_map), + std::move(cublas_wrapper_mutex), + std::move(cublas_wrapper), + std::move(cuda_device_prop_ptr))); +} + +template +void BertTritonModel::createSharedWeights(int device_id, int rank) +{ + ft::check_cuda_error(cudaSetDevice(device_id)); + const int tensor_para_rank = rank % tensor_para_size_; + const int pipeline_para_rank = rank / tensor_para_size_; + shared_weights_[device_id] = std::make_shared>(head_num_ * size_per_head_, + inter_size_, + num_layer_, + tensor_para_size_, + tensor_para_rank, + pipeline_para_size_, + pipeline_para_rank); + + shared_weights_[device_id]->loadModel(model_dir_); + return; +} + +template +std::string BertTritonModel::toString() +{ + std::stringstream ss; + ss << "Model: " << model_name_ << "\nmodel_dir: " << model_dir_ << "\nhead_num: " << head_num_ + << "\nsize_per_head: " << size_per_head_ << "\ninter_size: " << inter_size_ << "\nnum_layer: " << num_layer_ + << "\ntensor_para_size: " << tensor_para_size_ << "\npipeline_para_size: " << pipeline_para_size_ + << "\nq_scaling: " << q_scaling_ << "\nis_remove_padding: " << is_remove_padding_ + << "\nis_sparse: " << is_sparse_ << "\nactivation_type: " << static_cast(activation_type_) + << "\nlayernorm_type: " << static_cast(layernorm_type_) << "\nint8_mode:" << int8_mode_ + << "\nenable_custom_all_reduce:" << enable_custom_all_reduce_ << "\nis_sparse: " << is_sparse << std::endl; + + return ss.str(); +} + +template +void BertTritonModel::createCustomComms( + std::vector>* custom_all_reduce_comms, int world_size) +{ + using commDataType = typename ft::CustomARCommTypeConverter::Type; + ft::initCustomAllReduceComm(custom_all_reduce_comms, enable_custom_all_reduce_, world_size); +} + +template +int BertTritonModel::getTensorParaSize() +{ + return tensor_para_size_; +} + +template +int BertTritonModel::getPipelineParaSize() +{ + return pipeline_para_size_; +} + +template struct BertTritonModel; +template struct BertTritonModel; +#ifdef ENABLE_BF16 +template struct BertTritonModel<__nv_bfloat16>; +#endif diff --git a/src/fastertransformer/triton_backend/bert/BertTritonModel.h b/src/fastertransformer/triton_backend/bert/BertTritonModel.h new file mode 100644 index 000000000..344ffbb9a --- /dev/null +++ b/src/fastertransformer/triton_backend/bert/BertTritonModel.h @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "src/fastertransformer/models/bert/Bert.h" +#include "src/fastertransformer/triton_backend/transformer_triton_backend.hpp" + +namespace ft = fastertransformer; + +template +struct BertTritonModel: public AbstractTransformerModel { + BertTritonModel(size_t tensor_para_size, + size_t pipeline_para_size, + bool enable_custom_all_reduce, + std::string model_dir, + int int8_mode, + bool is_sparse, + bool is_remove_padding); + + virtual std::unique_ptr + createModelInstance(int deviceId, + int rank, + cudaStream_t stream, + std::pair, std::vector> nccl_params, + std::shared_ptr custom_all_reduce_comm = nullptr) override; + + virtual void createSharedWeights(int deviceId, int rank) override; + + virtual void createCustomComms(std::vector>* custom_all_reduce_comms, + int world_size) override; + + virtual std::string toString() override; + virtual int getTensorParaSize() override; + virtual int getPipelineParaSize() override; + +private: + size_t head_num_; + size_t size_per_head_; + size_t inter_size_; + size_t num_layer_; + size_t tensor_para_size_; + size_t pipeline_para_size_; + + float q_scaling_; + bool is_remove_padding_; + bool is_sparse_; + ft::ActivationType activation_type_; + ft::LayerNormType layernorm_type_; + + std::string model_name_; + std::string model_dir_; + int int8_mode_ = 0; + bool enable_custom_all_reduce_ = 0; + bool is_sparse = false; + std::vector>> shared_weights_; +}; diff --git a/src/fastertransformer/triton_backend/bert/BertTritonModelInstance.cc b/src/fastertransformer/triton_backend/bert/BertTritonModelInstance.cc new file mode 100644 index 000000000..8fdd1b80a --- /dev/null +++ b/src/fastertransformer/triton_backend/bert/BertTritonModelInstance.cc @@ -0,0 +1,144 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/fastertransformer/triton_backend/bert/BertTritonModelInstance.h" +#include "src/fastertransformer/triton_backend/triton_utils.hpp" + +namespace ft = fastertransformer; + +template +void triton_stream_callback(std::unordered_map* output_tensors, void* ctx) +{ + BertTritonModelInstance* model = reinterpret_cast*>(ctx); + auto result = BertTritonModelInstance::convert_outputs(*output_tensors); + + model->stream_cb_(result, model->stream_ctx_); +} + +template +BertTritonModelInstance::BertTritonModelInstance(std::unique_ptr> bert, + std::shared_ptr> bert_weight, + std::unique_ptr> allocator, + std::unique_ptr cublas_algo_map, + std::unique_ptr cublas_wrapper_mutex, + std::unique_ptr cublas_wrapper, + std::unique_ptr cuda_device_prop_ptr): + bert_(std::move(bert)), + bert_weight_(bert_weight), + allocator_(std::move(allocator)), + cublas_algo_map_(std::move(cublas_algo_map)), + cublas_wrapper_mutex_(std::move(cublas_wrapper_mutex)), + cublas_wrapper_(std::move(cublas_wrapper)), + cuda_device_prop_ptr_(std::move(cuda_device_prop_ptr)) +{ +} + +template +std::shared_ptr> +BertTritonModelInstance::forward(std::shared_ptr> input_tensors) +{ + ft::FT_CHECK(false); + return nullptr; +} + +template +ft::TensorMap BertTritonModelInstance::convert_inputs( + std::shared_ptr> input_tensors) +{ + move_tensor_H2D(input_tensors->at("input_hidden_state"), d_input_hidden_state_, &allocator_); + move_tensor_H2D(input_tensors->at("sequence_lengths"), d_sequence_lengths_, &allocator_); + + ft::TensorMap ft_input_tensors( + {{"input_hidden_state", as_GPU_tensor(input_tensors->at("input_hidden_state"), d_input_hidden_state_)}, + {"sequence_lengths", as_GPU_tensor(input_tensors->at("sequence_lengths"), d_sequence_lengths_)}}); + + return ft_input_tensors; +} + +template +std::shared_ptr> +BertTritonModelInstance::convert_outputs(ft::TensorMap& output_tensors) +{ + std::unordered_map* outputs_mapping = + new std::unordered_map(); + + for (auto it = output_tensors.begin(); it != output_tensors.end(); it++) { + outputs_mapping->insert({it->first, triton::Tensor::convertFtTensorToTriton(it->second)}); + } + + return std::shared_ptr>(outputs_mapping); +} + +template +std::shared_ptr> +BertTritonModelInstance::forward(std::shared_ptr> input_tensors) +{ + const size_t batch_size = input_tensors->at("input_hidden_state").shape[0]; + const size_t seq_len = input_tensors->at("input_hidden_state").shape[1]; + const size_t hidden_units = input_tensors->at("input_hidden_state").shape[2]; + + allocateBuffer(batch_size, seq_len, hidden_units); + + ft::TensorMap ft_input_tensors = convert_inputs(input_tensors); + + ft::TensorMap output_tensors = ft::TensorMap({{"output_hidden_state", + ft::Tensor{ft::MEMORY_GPU, + ft::getTensorType(), + std::vector{batch_size, seq_len, hidden_units}, + d_output_hidden_state_}}}); + + bert_->forward(&output_tensors, &ft_input_tensors, bert_weight_.get()); + + if (d_input_hidden_state_ != nullptr) { + ft::check_cuda_error(cudaFree(d_input_hidden_state_)); + d_input_hidden_state_ = nullptr; + } + if (d_sequence_lengths_ != nullptr) { + ft::check_cuda_error(cudaFree(d_sequence_lengths_)); + d_sequence_lengths_ = nullptr; + } + + return convert_outputs(output_tensors); +} + +template +BertTritonModelInstance::~BertTritonModelInstance() +{ + freeBuffer(); +} + +template +void BertTritonModelInstance::allocateBuffer(const size_t batch_size, + const size_t seq_len, + const size_t hidden_units) +{ + d_output_hidden_state_ = + (T*)(allocator_->reMalloc(d_output_hidden_state_, sizeof(T) 
* batch_size * seq_len * hidden_units, false)); +} + +template +void BertTritonModelInstance::freeBuffer() +{ + if (d_output_hidden_state_ != nullptr) { + allocator_->free((void**)(&d_output_hidden_state_)); + } +} + +template struct BertTritonModelInstance; +template struct BertTritonModelInstance; +#ifdef ENABLE_BF16 +template struct BertTritonModelInstance<__nv_bfloat16>; +#endif \ No newline at end of file diff --git a/src/fastertransformer/triton_backend/bert/BertTritonModelInstance.h b/src/fastertransformer/triton_backend/bert/BertTritonModelInstance.h new file mode 100644 index 000000000..bb53a33f5 --- /dev/null +++ b/src/fastertransformer/triton_backend/bert/BertTritonModelInstance.h @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "src/fastertransformer/models/bert/Bert.h" +#include "src/fastertransformer/triton_backend/transformer_triton_backend.hpp" + +namespace ft = fastertransformer; + +template +struct BertTritonModelInstance: AbstractTransformerModelInstance { + + BertTritonModelInstance(std::unique_ptr> gpt, + std::shared_ptr> gpt_weight, + std::unique_ptr> allocator, + std::unique_ptr cublas_algo_map, + std::unique_ptr cublas_wrapper_mutex, + std::unique_ptr cublas_wrapper, + std::unique_ptr cuda_device_prop_ptr); + ~BertTritonModelInstance(); + + std::shared_ptr> + forward(std::shared_ptr> input_tensors) override; + + std::shared_ptr> + forward(std::shared_ptr> input_tensors) override; + + static std::shared_ptr> + convert_outputs(ft::TensorMap& output_tensors); + +private: + const std::unique_ptr> bert_; + const std::shared_ptr> bert_weight_; + const std::unique_ptr> allocator_; + const std::unique_ptr cublas_algo_map_; + const std::unique_ptr cublas_wrapper_mutex_; + const std::unique_ptr cublas_wrapper_; + const std::unique_ptr cuda_device_prop_ptr_; + + ft::TensorMap convert_inputs(std::shared_ptr> input_tensors); + + void allocateBuffer(const size_t batch_size, const size_t seq_len, const size_t hidden_units); + void freeBuffer(); + + T* d_input_hidden_state_ = nullptr; + int* d_sequence_lengths_ = nullptr; + T* d_output_hidden_state_ = nullptr; +}; diff --git a/src/fastertransformer/triton_backend/bert/CMakeLists.txt b/src/fastertransformer/triton_backend/bert/CMakeLists.txt new file mode 100644 index 000000000..1fa9b6372 --- /dev/null +++ b/src/fastertransformer/triton_backend/bert/CMakeLists.txt @@ -0,0 +1,24 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +cmake_minimum_required(VERSION 3.8) + +set(bert_triton_backend_files + BertTritonModel.cc + BertTritonModelInstance.cc +) + +add_library(BertTritonBackend SHARED ${bert_triton_backend_files}) +target_link_libraries(BertTritonBackend PRIVATE Bert TransformerTritonBackend) +target_compile_features(BertTritonBackend PRIVATE cxx_std_14) \ No newline at end of file diff --git a/src/fastertransformer/triton_backend/gptj/CMakeLists.txt b/src/fastertransformer/triton_backend/gptj/CMakeLists.txt index 0a4b91b25..def708439 100644 --- a/src/fastertransformer/triton_backend/gptj/CMakeLists.txt +++ b/src/fastertransformer/triton_backend/gptj/CMakeLists.txt @@ -20,5 +20,5 @@ set(parallel_gpt_triton_backend_files ) add_library(GptJTritonBackend SHARED ${parallel_gpt_triton_backend_files}) -target_link_libraries(GptJTritonBackend PRIVATE GptJ) +target_link_libraries(GptJTritonBackend PRIVATE TransformerTritonBackend GptJ) target_compile_features(GptJTritonBackend PRIVATE cxx_std_14) diff --git a/src/fastertransformer/triton_backend/gptj/GptJTritonModel.cc b/src/fastertransformer/triton_backend/gptj/GptJTritonModel.cc index 44bb4c323..7ea01d0b5 100644 --- a/src/fastertransformer/triton_backend/gptj/GptJTritonModel.cc +++ b/src/fastertransformer/triton_backend/gptj/GptJTritonModel.cc @@ -30,13 +30,30 @@ std::shared_ptr AbstractTransformerModel::createGptJMo return nullptr; } - const std::string model_name = reader.Get("ft_instance_hyperparameter", "model_name"); - const int is_half = reader.GetInteger("ft_instance_hyperparameter", "is_half"); - int tensor_para_size = reader.GetInteger("ft_instance_hyperparameter", "tensor_para_size"); - std::string model_dir = reader.Get("ft_instance_hyperparameter", "model_dir"); - model_dir = model_dir + "/" + std::to_string(tensor_para_size) + "-gpu/"; + const std::string model_name = reader.Get("ft_instance_hyperparameter", "model_name"); + const std::string data_type = reader.Get("ft_instance_hyperparameter", "data_type"); + int tensor_para_size = reader.GetInteger("ft_instance_hyperparameter", "tensor_para_size"); + std::string model_dir = reader.Get("ft_instance_hyperparameter", "model_dir"); + model_dir = model_dir + "/" + std::to_string(tensor_para_size) + "-gpu/"; - if (is_half) { + // Prompt Learning Configurations + int end_id = reader.GetInteger(model_name, "end_id"); + int prompt_learning_start_id = reader.GetInteger(model_name, "prompt_learning_start_id", end_id + 1); + ft::PromptLearningType prompt_learning_type = + static_cast(reader.GetInteger(model_name, "prompt_learning_type", 0)); + + std::map> prompt_learning_table_pair; + + // NOTE: get prompt from configuration files + const int num_tasks = reader.GetInteger(model_name, "num_tasks", 0); + for (int task_name_id = 0; task_name_id < num_tasks; task_name_id++) { + std::string config_task_name = model_name + "_task_" + std::to_string(task_name_id); + std::string task_name = reader.Get(config_task_name, "task_name"); + const int prompt_length = reader.GetInteger(config_task_name, "prompt_length", 0); + prompt_learning_table_pair.insert({task_name, {task_name_id, prompt_length}}); + } + + if (data_type == "fp16") { return std::make_shared>( reader.GetInteger("ft_instance_hyperparameter", "max_seq_len"), reader.GetInteger(model_name, "head_num"), @@ -46,14 +63,17 @@ std::shared_ptr AbstractTransformerModel::createGptJMo reader.GetInteger(model_name, "vocab_size"), reader.GetInteger(model_name, 
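/*
 * Illustrative sketch of the config read above (section and key names follow the
 * reader.Get/GetInteger calls in this function; the concrete values are placeholders):
 *
 *   [ft_instance_hyperparameter]
 *   model_name=gptj_6B
 *   data_type=fp16            ; fp16, fp32 or bf16 (bf16 only with ENABLE_BF16)
 *   tensor_para_size=1
 *   pipeline_para_size=1
 *   model_dir=/workspace/models/gptj
 *
 *   [gptj_6B]
 *   head_num=16
 *   end_id=50256
 *   prompt_learning_start_id=50257
 *   prompt_learning_type=3
 *   num_tasks=1
 *
 *   [gptj_6B_task_0]
 *   task_name=sentiment
 *   prompt_length=8
 */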
"rotary_embedding"), reader.GetInteger(model_name, "start_id"), - reader.GetInteger(model_name, "end_id"), + end_id, + prompt_learning_start_id, + prompt_learning_type, + prompt_learning_table_pair, reader.GetInteger("ft_instance_hyperparameter", "tensor_para_size"), reader.GetInteger("ft_instance_hyperparameter", "pipeline_para_size"), reader.GetInteger("ft_instance_hyperparameter", "enable_custom_all_reduce", 0), model_name, model_dir); } - else { + else if (data_type == "fp32") { return std::make_shared>( reader.GetInteger("ft_instance_hyperparameter", "max_seq_len"), reader.GetInteger(model_name, "head_num"), @@ -63,30 +83,126 @@ std::shared_ptr AbstractTransformerModel::createGptJMo reader.GetInteger(model_name, "vocab_size"), reader.GetInteger(model_name, "rotary_embedding"), reader.GetInteger(model_name, "start_id"), - reader.GetInteger(model_name, "end_id"), + end_id, + prompt_learning_start_id, + prompt_learning_type, + prompt_learning_table_pair, + reader.GetInteger("ft_instance_hyperparameter", "tensor_para_size"), + reader.GetInteger("ft_instance_hyperparameter", "pipeline_para_size"), + reader.GetInteger("ft_instance_hyperparameter", "enable_custom_all_reduce", 0), + model_name, + model_dir); + } +#ifdef ENABLE_BF16 + else if (data_type == "bf16") { + return std::make_shared>( + reader.GetInteger("ft_instance_hyperparameter", "max_seq_len"), + reader.GetInteger(model_name, "head_num"), + reader.GetInteger(model_name, "size_per_head"), + reader.GetInteger(model_name, "inter_size"), + reader.GetInteger(model_name, "decoder_layers"), + reader.GetInteger(model_name, "vocab_size"), + reader.GetInteger(model_name, "rotary_embedding"), + reader.GetInteger(model_name, "start_id"), + end_id, + prompt_learning_start_id, + prompt_learning_type, + prompt_learning_table_pair, reader.GetInteger("ft_instance_hyperparameter", "tensor_para_size"), reader.GetInteger("ft_instance_hyperparameter", "pipeline_para_size"), reader.GetInteger("ft_instance_hyperparameter", "enable_custom_all_reduce", 0), model_name, model_dir); } +#endif + else { + FT_LOG_ERROR("Unsupported data type " + data_type); + exit(-1); + } } template -GptJTritonModel::GptJTritonModel(size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t inter_size, - size_t num_layer, - size_t vocab_size, - size_t rotary_embedding_dim, - int start_id, - int end_id, - size_t tensor_para_size, - size_t pipeline_para_size, - int enable_custom_all_reduce, - std::string model_name, +GptJTritonModel::GptJTritonModel(size_t tensor_para_size, + size_t pipeline_para_size, + int enable_custom_all_reduce, std::string model_dir): + tensor_para_size_(tensor_para_size), + pipeline_para_size_(pipeline_para_size), + shared_weights_(std::vector>>(ft::getDeviceCount())), + enable_custom_all_reduce_(enable_custom_all_reduce), + model_dir_(model_dir) +{ + INIReader reader = INIReader(model_dir + "/config.ini"); + if (reader.ParseError() < 0) { + std::cout << "[ERROR] Can't load '" << model_dir << "/config.ini" + << "'\n"; + ft::FT_CHECK(false); + } + + /* GPTJ Model Config Example + [gptj] + model_name=gpt-j-6B + head_num=16 + size_per_head=256 + inter_size=16384 + num_layer=28 + rotary_embedding_dim=64 + vocab_size=50400 + start_id=50256 + end_id=50256 + prompt_learning_start_id=50257 ; only required by p/prompt tuning + prompt_learning_type=3 + num_tasks=2 + + [task_0] + task_name=no_prompt + prompt_length=0 + + [task_1] + task_name=len1_seed100 + prompt_length=1 + */ + model_name_ = reader.Get("gptj", "model_name"); + head_num_ = 
reader.GetInteger("gptj", "head_num"); + size_per_head_ = reader.GetInteger("gptj", "size_per_head"); + inter_size_ = reader.GetInteger("gptj", "inter_size"); + num_layer_ = reader.GetInteger("gptj", "num_layer"); + vocab_size_ = reader.GetInteger("gptj", "vocab_size"); + rotary_embedding_dim_ = reader.GetInteger("gptj", "rotary_embedding"); + start_id_ = reader.GetInteger("gptj", "start_id"); + end_id_ = reader.GetInteger("gptj", "end_id"); + + num_tasks_ = reader.GetInteger("gptj", "num_tasks", 0); + + prompt_learning_start_id_ = reader.GetInteger("gptj", "prompt_learning_start_id", end_id_ + 1); + prompt_learning_type_ = static_cast(reader.GetInteger("gptj", "prompt_learning_type", 0)); + + for (int task_name_id = 0; task_name_id < num_tasks_; task_name_id++) { + std::string config_task_name = "task_" + std::to_string(task_name_id); + std::string task_name = reader.Get(config_task_name, "task_name"); + const int prompt_length = reader.GetInteger(config_task_name, "prompt_length", 0); + prompt_learning_table_pair_.insert({task_name, {task_name_id, prompt_length}}); + } +} + +template +GptJTritonModel::GptJTritonModel(size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + size_t vocab_size, + size_t rotary_embedding_dim, + int start_id, + int end_id, + int prompt_learning_start_id, + ft::PromptLearningType prompt_learning_type, + std::map> prompt_learning_table_pair, + size_t tensor_para_size, + size_t pipeline_para_size, + int enable_custom_all_reduce, + std::string model_name, + std::string model_dir): max_seq_len_(max_seq_len), head_num_(head_num), size_per_head_(size_per_head), @@ -96,8 +212,12 @@ GptJTritonModel::GptJTritonModel(size_t max_seq_len, rotary_embedding_dim_(rotary_embedding_dim), start_id_(start_id), end_id_(end_id), + prompt_learning_start_id_(prompt_learning_start_id), + prompt_learning_type_(prompt_learning_type), + prompt_learning_table_pair_(prompt_learning_table_pair), tensor_para_size_(tensor_para_size), pipeline_para_size_(pipeline_para_size), + shared_weights_(std::vector>>(ft::getDeviceCount())), enable_custom_all_reduce_(enable_custom_all_reduce), model_name_(model_name), model_dir_(model_dir) @@ -106,30 +226,29 @@ GptJTritonModel::GptJTritonModel(size_t max_seq_len, template std::unique_ptr -GptJTritonModel::createModelInstance(int device_id, - int rank, - cudaStream_t stream, - std::pair, std::vector> nccl_comms, +GptJTritonModel::createModelInstance(int device_id, + int rank, + cudaStream_t stream, + std::pair, std::vector> nccl_params, std::shared_ptr custom_all_reduce_comm) { ft::check_cuda_error(cudaSetDevice(device_id)); - const int tensor_para_rank = rank % tensor_para_size_; - const int pipeline_para_rank = rank / tensor_para_size_; + const int comms_rank = device_id % (tensor_para_size_ * pipeline_para_size_); std::unique_ptr> allocator( new ft::Allocator(device_id)); allocator->setStream(stream); - cublasHandle_t cublas_handle; + cublasHandle_t cublas_handle; cublasLtHandle_t cublaslt_handle; cublasCreate(&cublas_handle); cublasLtCreate(&cublaslt_handle); cublasSetStream(cublas_handle, stream); - std::unique_ptr cublas_algo_map(new ft::cublasAlgoMap("gemm_config.in")); - std::unique_ptr cublas_wrapper_mutex(new std::mutex()); + std::unique_ptr cublas_algo_map(new ft::cublasAlgoMap("gemm_config.in")); + std::unique_ptr cublas_wrapper_mutex(new std::mutex()); std::unique_ptr cublas_wrapper(new ft::cublasMMWrapper( cublas_handle, cublaslt_handle, stream, cublas_algo_map.get(), 
cublas_wrapper_mutex.get(), allocator.get())); @@ -139,55 +258,52 @@ GptJTritonModel::createModelInstance(int device_id, if (std::is_same::value) { cublas_wrapper->setGemmConfig(CUDA_R_16F, CUDA_R_16F, CUDA_R_16F, CUDA_R_32F); } +#ifdef ENABLE_BF16 + else if (std::is_same::value) { + cublas_wrapper->setBF16GemmConfig(); + } +#endif else if (std::is_same::value) { cublas_wrapper->setFP32GemmConfig(); } - ft::NcclParam tensor_para(tensor_para_rank, tensor_para_size_, nccl_comms.first[device_id]); - ft::NcclParam pipeline_para(pipeline_para_rank, pipeline_para_size_, nccl_comms.second[device_id]); - - auto gpt = std::make_unique>(ft::GptJ(0, // max_batch_size, FT will adjust the buffer automatically. - 0, // max_seq_len, FT will adjust the buffer automatically. - 0, // max_input_len, FT will adjust the buffer automatically. - 0, - head_num_, - size_per_head_, - inter_size_, - num_layer_, - vocab_size_, - rotary_embedding_dim_, - start_id_, - end_id_, - 0.0f, // beam_search_diversity_rate_, - 0, // top_k_, - 0.0f, // top_p_, - 0, // random seed, note that all gpus should use same seed - 0.0f, // temperature_, - 0.0f, // len_penalty_, - 0.0f, // repetition_penalty_, - tensor_para, - pipeline_para, - stream, - cublas_wrapper.get(), - allocator.get(), - false, - cuda_device_prop_ptr.get(), - custom_all_reduce_comm, - enable_custom_all_reduce_)); - - auto weight = std::unique_ptr>(new ft::GptJWeight(head_num_ * size_per_head_, - inter_size_, - vocab_size_, - num_layer_, - max_seq_len_, - tensor_para_size_, - tensor_para_rank, - pipeline_para_size_, - pipeline_para_rank)); - - weight->loadModel(model_dir_); + ft::NcclParam tensor_para = nccl_params.first[comms_rank]; + ft::NcclParam pipeline_para = nccl_params.second[comms_rank]; + + auto gpt = + std::make_unique>(ft::GptJ(0, // max_batch_size, FT will adjust the buffer automatically. + 0, // max_seq_len, FT will adjust the buffer automatically. + 0, // max_input_len, FT will adjust the buffer automatically. 
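// Illustrative note on the nccl_params lookup above (placeholder sizes):
// comms_rank = device_id % (tensor_para_size_ * pipeline_para_size_), so with
// tensor_para_size_ = 2 and pipeline_para_size_ = 1, device 0 uses
// nccl_params.first[0]/second[0] and device 1 uses nccl_params.first[1]/second[1].
// The matching weight shards are built in createSharedWeights below, where rank 1
// maps to tensor_para_rank 1 % 2 = 1 and pipeline_para_rank 1 / 2 = 0.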
+ 0, + head_num_, + size_per_head_, + inter_size_, + num_layer_, + vocab_size_, + rotary_embedding_dim_, + start_id_, + end_id_, + prompt_learning_start_id_, // p/prompt tuning virtual token start id + prompt_learning_type_, + 0.0f, // beam_search_diversity_rate_, + 0, // top_k_, + 0.0f, // top_p_, + 0, // random seed, note that all gpus should use same seed + 0.0f, // temperature_, + 0.0f, // len_penalty_, + 0.0f, // repetition_penalty_, + tensor_para, + pipeline_para, + stream, + cublas_wrapper.get(), + allocator.get(), + false, + cuda_device_prop_ptr.get(), + custom_all_reduce_comm, + enable_custom_all_reduce_)); + return std::unique_ptr>(new GptJTritonModelInstance(std::move(gpt), - std::move(weight), + shared_weights_[device_id], std::move(allocator), std::move(cublas_algo_map), std::move(cublas_wrapper_mutex), @@ -195,87 +311,40 @@ GptJTritonModel::createModelInstance(int device_id, std::move(cuda_device_prop_ptr))); } +template +void GptJTritonModel::createSharedWeights(int device_id, int rank) +{ + ft::check_cuda_error(cudaSetDevice(device_id)); + const int tensor_para_rank = rank % tensor_para_size_; + const int pipeline_para_rank = rank / tensor_para_size_; + shared_weights_[device_id] = std::make_shared>(head_num_ * size_per_head_, + inter_size_, + vocab_size_, + num_layer_, + max_seq_len_, // not needed + tensor_para_size_, + tensor_para_rank, + pipeline_para_size_, + pipeline_para_rank, + prompt_learning_type_, + prompt_learning_table_pair_); + shared_weights_[device_id]->loadModel(model_dir_); + return; +} + template std::string GptJTritonModel::toString() { std::stringstream ss; ss << "Model: " - << "\nmax_seq_len: " << max_seq_len_ << "\nhead_num: " << head_num_ << "\nsize_per_head: " << size_per_head_ - << "\ninter_size: " << inter_size_ << "\nnum_layer: " << num_layer_ << "\nvocab_size: " << vocab_size_ - << "\nstart_id: " << start_id_ << "\nend_id: " << end_id_ << "\ntensor_para_size: " << tensor_para_size_ + << "\nhead_num: " << head_num_ << "\nsize_per_head: " << size_per_head_ << "\ninter_size: " << inter_size_ + << "\nnum_layer: " << num_layer_ << "\nvocab_size: " << vocab_size_ << "\nstart_id: " << start_id_ + << "\nend_id: " << end_id_ << "\ntensor_para_size: " << tensor_para_size_ << "\npipeline_para_size: " << pipeline_para_size_ << "\nenable_custom_all_reduce: " << enable_custom_all_reduce_ << "\nmodel_name: " << model_name_ << "\nmodel_dir: " << model_dir_ << std::endl; return ss.str(); } -template -std::vector GptJTritonModel::createNcclIds(const uint32_t world_size, bool multi_instances) -{ - std::vector nccl_ids(tensor_para_size_ + pipeline_para_size_); - if (multi_instances) { - if (tensor_para_size_ * pipeline_para_size_ != 1) { - printf( - "[ERROR] Multiple Instances currently only support tensor_para_size_ and pipeline_para_size_ both 1\n"); - ft::FT_CHECK(tensor_para_size_ == 1 && pipeline_para_size_ == 1); - } - nccl_ids.resize(2); - } - else { - if (world_size != tensor_para_size_ * pipeline_para_size_) { - printf( - "[ERROR] world_size (%d) should equal to tensor_para_size_ * pipeline_para_size_ (%ld * %ld here) \n", - world_size, - tensor_para_size_, - pipeline_para_size_); - ft::FT_CHECK(world_size == tensor_para_size_ * pipeline_para_size_); - } - } - - for (uint32_t i = 0; i < nccl_ids.size(); i++) { - NCCLCHECK(ncclGetUniqueId(&nccl_ids[i])); - } - return nccl_ids; -} - -template -std::pair, std::vector> GptJTritonModel::createNcclComms( - std::vector nccl_ids, const int node_id, bool multi_instances, int instance_id) -{ - const int gpu_count = 
ft::getDeviceCount(); - std::vector tensor_para_comms(gpu_count); - std::vector pipeline_para_comms(gpu_count); - if (multi_instances) { - ncclUniqueId tensor_para_nccl_uid = nccl_ids[0]; - ncclUniqueId pipeline_para_nccl_uid = nccl_ids[1]; - size_t tensor_para_rank = 0; - size_t pipeline_para_rank = 0; - - ft::check_cuda_error(cudaSetDevice(instance_id)); - NCCLCHECK(ncclCommInitRank( - &tensor_para_comms[instance_id], tensor_para_size_, tensor_para_nccl_uid, tensor_para_rank)); - NCCLCHECK(ncclCommInitRank( - &pipeline_para_comms[instance_id], pipeline_para_size_, pipeline_para_nccl_uid, pipeline_para_rank)); - } - else { - NCCLCHECK(ncclGroupStart()); - for (int gid = 0; gid < gpu_count; gid++) { - int rank = node_id * gpu_count + gid; - size_t tensor_para_rank = rank % tensor_para_size_; - size_t pipeline_para_rank = rank / tensor_para_size_; - ncclUniqueId tensor_para_nccl_uid = nccl_ids[pipeline_para_rank]; - ncclUniqueId pipeline_para_nccl_uid = nccl_ids[pipeline_para_size_ + tensor_para_rank]; - - ft::check_cuda_error(cudaSetDevice(gid)); - NCCLCHECK( - ncclCommInitRank(&tensor_para_comms[gid], tensor_para_size_, tensor_para_nccl_uid, tensor_para_rank)); - NCCLCHECK(ncclCommInitRank( - &pipeline_para_comms[gid], pipeline_para_size_, pipeline_para_nccl_uid, pipeline_para_rank)); - } - NCCLCHECK(ncclGroupEnd()); - } - return std::pair, std::vector>(tensor_para_comms, pipeline_para_comms); -} - template void GptJTritonModel::createCustomComms( std::vector>* custom_all_reduce_comms, int world_size) @@ -298,3 +367,6 @@ int GptJTritonModel::getPipelineParaSize() template struct GptJTritonModel; template struct GptJTritonModel; +#ifdef ENABLE_BF16 +template struct GptJTritonModel<__nv_bfloat16>; +#endif diff --git a/src/fastertransformer/triton_backend/gptj/GptJTritonModel.h b/src/fastertransformer/triton_backend/gptj/GptJTritonModel.h index 131f1e95d..2a532e023 100644 --- a/src/fastertransformer/triton_backend/gptj/GptJTritonModel.h +++ b/src/fastertransformer/triton_backend/gptj/GptJTritonModel.h @@ -27,61 +27,72 @@ namespace ft = fastertransformer; template struct GptJTritonModel: public AbstractTransformerModel { - GptJTritonModel(size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t inter_size, - size_t num_layer, - size_t vocab_size, - size_t rotary_embedding_dim, - int start_id, - int end_id, - size_t tensor_para_size, - size_t pipeline_para_size, - int enable_custom_all_reduce, - std::string model_name, + GptJTritonModel(size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + size_t vocab_size, + size_t rotary_embedding_dim, + int start_id, + int end_id, + int prompt_learning_start_id, + ft::PromptLearningType prompt_learning_type, + std::map> prompt_learning_table_pair, + size_t tensor_para_size, + size_t pipeline_para_size, + int enable_custom_all_reduce, + std::string model_name, + std::string model_dir); + + GptJTritonModel(size_t tensor_para_size, + size_t pipeline_para_size, + int enable_custom_all_reduce, std::string model_dir); ~GptJTritonModel() = default; virtual std::unique_ptr - createModelInstance(int deviceId, - int rank, - cudaStream_t stream, - std::pair, std::vector> nccl_comms, + createModelInstance(int deviceId, + int rank, + cudaStream_t stream, + std::pair, std::vector> nccl_params, std::shared_ptr custom_all_reduce_comm = nullptr) override; - virtual void createCustomComms(std::vector>* custom_all_reduce_comms, - int world_size) override; - - virtual std::pair, std::vector> - 
createNcclComms(std::vector nccl_ids, - const int node_id, - bool multi_instances = false, - int instance_id = 0) override; + virtual void createSharedWeights(int deviceId, int rank) override; - virtual std::vector createNcclIds(const uint32_t world_size, bool multi_instances = false) override; + virtual void createCustomComms(std::vector>* custom_all_reduce_comms, + int world_size) override; virtual std::string toString() override; - virtual int getTensorParaSize() override; - virtual int getPipelineParaSize() override; + virtual int getTensorParaSize() override; + virtual int getPipelineParaSize() override; private: - const size_t max_seq_len_; - const size_t head_num_; - const size_t size_per_head_; - const size_t inter_size_; - const size_t num_layer_; - const size_t vocab_size_; - const size_t rotary_embedding_dim_; - const int start_id_; - const int end_id_; - const size_t tensor_para_size_; - const size_t pipeline_para_size_; + size_t max_seq_len_ = 0; // optional as FT automatically sets it + size_t head_num_; + size_t size_per_head_; + size_t inter_size_; + size_t num_layer_; + size_t vocab_size_; + size_t rotary_embedding_dim_; + int start_id_; + int end_id_; + size_t tensor_para_size_; + size_t pipeline_para_size_; bool is_fp16_; - int enable_custom_all_reduce_ = 0; + int enable_custom_all_reduce_ = 0; + + // shared weights for each device + std::vector>> shared_weights_; std::string model_name_; std::string model_dir_; -}; \ No newline at end of file + + // number of tasks (for prefix-prompt, p/prompt-tuning) + size_t num_tasks_ = 0; + int prompt_learning_start_id_ = 0; + ft::PromptLearningType prompt_learning_type_ = ft::PromptLearningType::no_prompt; + std::map> prompt_learning_table_pair_ = {}; +}; diff --git a/src/fastertransformer/triton_backend/gptj/GptJTritonModelInstance.cc b/src/fastertransformer/triton_backend/gptj/GptJTritonModelInstance.cc index 61ec89967..51856c9de 100644 --- a/src/fastertransformer/triton_backend/gptj/GptJTritonModelInstance.cc +++ b/src/fastertransformer/triton_backend/gptj/GptJTritonModelInstance.cc @@ -26,15 +26,24 @@ namespace ft = fastertransformer; template -GptJTritonModelInstance::GptJTritonModelInstance(std::unique_ptr> gpt, - std::unique_ptr> gpt_weight, +void triton_stream_callback(std::unordered_map* output_tensors, void* ctx) +{ + GptJTritonModelInstance* model = reinterpret_cast*>(ctx); + auto result = GptJTritonModelInstance::convert_outputs(*output_tensors); + + model->stream_cb_(result, model->stream_ctx_); +} + +template +GptJTritonModelInstance::GptJTritonModelInstance(std::unique_ptr> gpt, + std::shared_ptr> gpt_weight, std::unique_ptr> allocator, - std::unique_ptr cublas_algo_map, - std::unique_ptr cublas_wrapper_mutex, + std::unique_ptr cublas_algo_map, + std::unique_ptr cublas_wrapper_mutex, std::unique_ptr cublas_wrapper, - std::unique_ptr cuda_device_prop_ptr): + std::unique_ptr cuda_device_prop_ptr): gpt_(std::move(gpt)), - gpt_weight_(std::move(gpt_weight)), + gpt_weight_(gpt_weight), allocator_(std::move(allocator)), cublas_algo_map_(std::move(cublas_algo_map)), cublas_wrapper_mutex_(std::move(cublas_wrapper_mutex)), @@ -49,72 +58,55 @@ std::unordered_map GptJTritonModelInstance::convert_ { FT_LOG_DEBUG(__PRETTY_FUNCTION__); - move_tensor_H2D(input_tensors->at("input_ids"), d_input_ids_); - move_tensor_H2D(input_tensors->at("input_lengths"), d_input_lengths_); + move_tensor_H2D(input_tensors->at("input_ids"), d_input_ids_, &allocator_); + move_tensor_H2D(input_tensors->at("input_lengths"), d_input_lengths_, 
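// Illustrative note on the conversion below (placeholder values): FT expects a
// per-request total "output_seq_len", so h_total_output_lengths_[i] =
// request_output_len[i] + input length. For example, a 32-token prompt with
// request_output_len = 96 yields output_seq_len = 128.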
&allocator_); - const int input_data_len = input_tensors->at("input_ids").shape[1]; - size_t size = 1; - for (auto t : input_tensors->at("request_output_len").shape) { - size = size * t; - } - - h_total_output_lengths_ = reinterpret_cast(malloc(size * sizeof(int))); - for (int i = 0; i < size; ++i) { + const size_t request_batch_size = input_tensors->at("input_ids").shape[0]; + const size_t input_data_len = input_tensors->at("input_ids").shape[1]; + h_total_output_lengths_ = reinterpret_cast(malloc(request_batch_size * sizeof(uint32_t))); + for (int i = 0; i < request_batch_size; ++i) { h_total_output_lengths_[i] = - reinterpret_cast(input_tensors->at("request_output_len").data)[i] + input_data_len; + reinterpret_cast(input_tensors->at("request_output_len").data)[i] + input_data_len; } std::unordered_map ft_input_tensors = std::unordered_map{ {"input_ids", as_GPU_tensor(input_tensors->at("input_ids"), d_input_ids_)}, {"input_lengths", as_GPU_tensor(input_tensors->at("input_lengths"), d_input_lengths_)}, - {"max_output_seq_len", - ft::Tensor{ - ft::MEMORY_CPU, ft::TYPE_INT32, input_tensors->at("request_output_len").shape, h_total_output_lengths_}}}; + {"output_seq_len", + ft::Tensor{ft::MEMORY_CPU, + ft::TYPE_UINT32, + {input_tensors->at("request_output_len").shape[0]}, + h_total_output_lengths_}}}; if (input_tensors->find("bad_words_list") != input_tensors->end()) { - move_tensor_H2D(input_tensors->at("bad_words_list"), d_input_bad_words_); + move_tensor_H2D(input_tensors->at("bad_words_list"), d_input_bad_words_, &allocator_); ft_input_tensors.insert( {"bad_words_list", as_GPU_tensor(input_tensors->at("bad_words_list"), d_input_bad_words_)}); } if (input_tensors->find("stop_words_list") != input_tensors->end()) { - move_tensor_H2D(input_tensors->at("stop_words_list"), d_input_stop_words_); + move_tensor_H2D(input_tensors->at("stop_words_list"), d_input_stop_words_, &allocator_); ft_input_tensors.insert( {"stop_words_list", as_GPU_tensor(input_tensors->at("stop_words_list"), d_input_stop_words_)}); } - if (input_tensors->count("prefix_soft_prompt_embedding") && input_tensors->count("prefix_soft_prompt_lengths")) { - triton::Tensor soft_prompt_lengths_tensor = input_tensors->at("prefix_soft_prompt_lengths"); - size_t length_size = std::accumulate(soft_prompt_lengths_tensor.shape.begin(), - soft_prompt_lengths_tensor.shape.end(), - 1, - std::multiplies()); - ft::deviceMalloc(&d_prefix_soft_prompt_lengths_, length_size, false); - ft::cudaH2Dcpy( - d_prefix_soft_prompt_lengths_, reinterpret_cast(soft_prompt_lengths_tensor.data), length_size); + if (input_tensors->count("request_prompt_embedding") && input_tensors->count("request_prompt_lengths") + && input_tensors->count("request_prompt_type")) { + + move_tensor_H2D(input_tensors->at("request_prompt_lengths"), d_request_prompt_lengths_, &allocator_); ft_input_tensors.insert( - {"prefix_soft_prompt_lengths", - ft::Tensor{ - ft::MEMORY_GPU, ft::TYPE_INT32, soft_prompt_lengths_tensor.shape, d_prefix_soft_prompt_lengths_}}); - - triton::Tensor soft_prompt_embedding_tensor = input_tensors->at("prefix_soft_prompt_embedding"); - size_t emb_size = std::accumulate(soft_prompt_embedding_tensor.shape.begin(), - soft_prompt_embedding_tensor.shape.end(), - 1, - std::multiplies()); - ft::deviceMalloc(&d_prefix_soft_prompt_embedding_, emb_size, false); - ft::cudaH2Dcpy(d_prefix_soft_prompt_embedding_, - reinterpret_cast(soft_prompt_embedding_tensor.data), - emb_size); + {"request_prompt_lengths", + as_GPU_tensor(input_tensors->at("request_prompt_lengths"), 
d_request_prompt_lengths_)}); + + move_tensor_H2D(input_tensors->at("request_prompt_embedding"), d_request_prompt_embedding_, &allocator_); ft_input_tensors.insert( - {"prefix_soft_prompt_embedding", - ft::Tensor{ - ft::MEMORY_GPU, ft::TYPE_FP32, soft_prompt_embedding_tensor.shape, d_prefix_soft_prompt_embedding_}}); + {"request_prompt_embedding", + as_GPU_tensor(input_tensors->at("request_prompt_embedding"), d_request_prompt_embedding_)}); } for (auto t = input_tensors->begin(); t != input_tensors->end(); ++t) { if (t->first.find("input_ids") == std::string::npos && t->first.find("input_lengths") == std::string::npos - && t->first.find("max_output_seq_len") == std::string::npos + && t->first.find("output_seq_len") == std::string::npos && t->first.find("prefix_soft_prompt_embedding") == std::string::npos && t->first.find("prefix_soft_prompt_lengths") == std::string::npos) { if (ft_input_tensors.count(t->first) == 0) { @@ -159,15 +151,14 @@ GptJTritonModelInstance::forward(std::shared_ptrat("input_lengths").shape.size() == 1, "input_tensors->at(\"input_lengths\").shape.size() == 1"); - const size_t request_batch_size = input_tensors->at("input_ids").shape[0]; - const size_t max_request_output_len = (size_t)*std::max_element( + const uint32_t request_batch_size = input_tensors->at("input_ids").shape[0]; + const uint32_t max_request_output_len = (size_t)*std::max_element( (int*)input_tensors->at("request_output_len").data, (int*)input_tensors->at("request_output_len").data + input_tensors->at("request_output_len").shape[0]); - const size_t total_output_len = max_request_output_len + input_tensors->at("input_ids").shape[1]; - const size_t beam_width = + const uint32_t total_output_len = max_request_output_len + input_tensors->at("input_ids").shape[1]; + const uint32_t beam_width = input_tensors->count("beam_width") ? 
(size_t)(*(uint*)input_tensors->at("beam_width").data) : 1; - freeBuffer(); // free buffer of previous iteration allocateBuffer(request_batch_size, beam_width, total_output_len, max_request_output_len); std::unordered_map ft_input_tensors = convert_inputs(input_tensors); @@ -196,36 +187,20 @@ GptJTritonModelInstance::forward(std::shared_ptr{request_batch_size, beam_width}, d_cum_log_probs_}}); } - gpt_->forward(&output_tensors, &ft_input_tensors, gpt_weight_.get()); - if (d_input_ids_ != nullptr) { - ft::check_cuda_error(cudaFree(d_input_ids_)); - d_input_ids_ = nullptr; + if (stream_cb_ != nullptr) { + gpt_->registerCallback(triton_stream_callback, this); } - if (d_input_lengths_ != nullptr) { - ft::check_cuda_error(cudaFree(d_input_lengths_)); - d_input_lengths_ = nullptr; + gpt_->forward(&output_tensors, &ft_input_tensors, gpt_weight_.get()); + if (stream_cb_ != nullptr) { + gpt_->unRegisterCallback(); } + if (h_total_output_lengths_ != nullptr) { free(h_total_output_lengths_); h_total_output_lengths_ = nullptr; } - if (d_input_bad_words_ != nullptr) { - ft::check_cuda_error(cudaFree(d_input_bad_words_)); - d_input_bad_words_ = nullptr; - } - if (d_input_stop_words_ != nullptr) { - ft::check_cuda_error(cudaFree(d_input_stop_words_)); - d_input_stop_words_ = nullptr; - } - if (d_prefix_soft_prompt_embedding_ != nullptr) { - ft::check_cuda_error(cudaFree(d_prefix_soft_prompt_embedding_)); - d_prefix_soft_prompt_embedding_ = nullptr; - } - if (d_prefix_soft_prompt_lengths_ != nullptr) { - ft::check_cuda_error(cudaFree(d_prefix_soft_prompt_lengths_)); - d_prefix_soft_prompt_lengths_ = nullptr; - } + return convert_outputs(output_tensors); } @@ -241,20 +216,27 @@ void GptJTritonModelInstance::allocateBuffer(const size_t request_batch_size, const size_t total_output_len, const size_t max_request_output_len) { - ft::deviceMalloc(&d_output_ids_, request_batch_size * beam_width * total_output_len); - ft::deviceMalloc(&d_sequence_lengths_, request_batch_size * beam_width); - ft::deviceMalloc(&d_output_log_probs_, max_request_output_len * request_batch_size * beam_width); - ft::deviceMalloc(&d_cum_log_probs_, request_batch_size * beam_width); + d_output_ids_ = (int*)(allocator_->reMalloc( + d_output_ids_, sizeof(int) * request_batch_size * beam_width * total_output_len, false)); + d_sequence_lengths_ = + (int*)(allocator_->reMalloc(d_sequence_lengths_, sizeof(int) * request_batch_size * beam_width, false)); + d_output_log_probs_ = (float*)(allocator_->reMalloc( + d_output_log_probs_, sizeof(float) * request_batch_size * beam_width * max_request_output_len, false)); + d_cum_log_probs_ = + (float*)(allocator_->reMalloc(d_cum_log_probs_, sizeof(float) * request_batch_size * beam_width, false)); } template void GptJTritonModelInstance::freeBuffer() { - ft::deviceFree(d_output_ids_); - ft::deviceFree(d_sequence_lengths_); - ft::deviceFree(d_output_log_probs_); - ft::deviceFree(d_cum_log_probs_); + allocator_->free((void**)(&d_output_ids_)); + allocator_->free((void**)(&d_sequence_lengths_)); + allocator_->free((void**)(&d_output_log_probs_)); + allocator_->free((void**)(&d_cum_log_probs_)); } template struct GptJTritonModelInstance; template struct GptJTritonModelInstance; +#ifdef ENABLE_BF16 +template struct GptJTritonModelInstance<__nv_bfloat16>; +#endif \ No newline at end of file diff --git a/src/fastertransformer/triton_backend/gptj/GptJTritonModelInstance.h b/src/fastertransformer/triton_backend/gptj/GptJTritonModelInstance.h index da3a038bb..e5fd36a54 100644 --- 
a/src/fastertransformer/triton_backend/gptj/GptJTritonModelInstance.h +++ b/src/fastertransformer/triton_backend/gptj/GptJTritonModelInstance.h @@ -26,13 +26,13 @@ namespace ft = fastertransformer; template struct GptJTritonModelInstance: AbstractTransformerModelInstance { - GptJTritonModelInstance(std::unique_ptr> gpt, - std::unique_ptr> gpt_weight, + GptJTritonModelInstance(std::unique_ptr> gpt, + std::shared_ptr> gpt_weight, std::unique_ptr> allocator, - std::unique_ptr cublas_algo_map, - std::unique_ptr cublas_wrapper_mutex, - std::unique_ptr cublas_wrapper, - std::unique_ptr cuda_device_prop_ptr); + std::unique_ptr cublas_algo_map, + std::unique_ptr cublas_wrapper_mutex, + std::unique_ptr cublas_wrapper, + std::unique_ptr cuda_device_prop_ptr); ~GptJTritonModelInstance(); std::shared_ptr> @@ -41,19 +41,20 @@ struct GptJTritonModelInstance: AbstractTransformerModelInstance { std::shared_ptr> forward(std::shared_ptr> input_tensors) override; + static std::shared_ptr> + convert_outputs(const std::unordered_map& output_tensors); + private: - const std::unique_ptr> gpt_; - const std::unique_ptr> gpt_weight_; + const std::unique_ptr> gpt_; + const std::shared_ptr> gpt_weight_; const std::unique_ptr> allocator_; - const std::unique_ptr cublas_algo_map_; - const std::unique_ptr cublas_wrapper_mutex_; - const std::unique_ptr cublas_wrapper_; - const std::unique_ptr cuda_device_prop_ptr_; + const std::unique_ptr cublas_algo_map_; + const std::unique_ptr cublas_wrapper_mutex_; + const std::unique_ptr cublas_wrapper_; + const std::unique_ptr cuda_device_prop_ptr_; std::unordered_map convert_inputs(std::shared_ptr> input_tensors); - std::shared_ptr> - convert_outputs(const std::unordered_map& output_tensors); void allocateBuffer(const size_t request_batch_size, const size_t beam_width, @@ -61,17 +62,17 @@ struct GptJTritonModelInstance: AbstractTransformerModelInstance { const size_t max_request_output_len); void freeBuffer(); - int* d_input_ids_ = nullptr; - int* d_input_lengths_ = nullptr; - int* d_input_bad_words_ = nullptr; - int* d_input_stop_words_ = nullptr; - int* d_prefix_soft_prompt_lengths_ = nullptr; - float* d_prefix_soft_prompt_embedding_ = nullptr; + int* d_input_ids_ = nullptr; + int* d_input_lengths_ = nullptr; + int* d_input_bad_words_ = nullptr; + int* d_input_stop_words_ = nullptr; + int* d_request_prompt_lengths_ = nullptr; + T* d_request_prompt_embedding_ = nullptr; - int* d_output_ids_ = nullptr; - int* d_sequence_lengths_ = nullptr; + int* d_output_ids_ = nullptr; + int* d_sequence_lengths_ = nullptr; float* d_output_log_probs_ = nullptr; - float* d_cum_log_probs_ = nullptr; + float* d_cum_log_probs_ = nullptr; - int* h_total_output_lengths_ = nullptr; + uint32_t* h_total_output_lengths_ = nullptr; }; diff --git a/src/fastertransformer/triton_backend/gptneox/CMakeLists.txt b/src/fastertransformer/triton_backend/gptneox/CMakeLists.txt new file mode 100644 index 000000000..64c3874a5 --- /dev/null +++ b/src/fastertransformer/triton_backend/gptneox/CMakeLists.txt @@ -0,0 +1,24 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +cmake_minimum_required(VERSION 3.8) + +set(parallel_gpt_triton_backend_files + GptNeoXTritonModel.cc + GptNeoXTritonModelInstance.cc +) + +add_library(GptNeoXTritonBackend SHARED ${parallel_gpt_triton_backend_files}) +target_link_libraries(GptNeoXTritonBackend PRIVATE TransformerTritonBackend GptNeoX) +target_compile_features(GptNeoXTritonBackend PRIVATE cxx_std_14) diff --git a/src/fastertransformer/triton_backend/gptneox/GptNeoXTritonModel.cc b/src/fastertransformer/triton_backend/gptneox/GptNeoXTritonModel.cc new file mode 100644 index 000000000..94b50cf65 --- /dev/null +++ b/src/fastertransformer/triton_backend/gptneox/GptNeoXTritonModel.cc @@ -0,0 +1,234 @@ +/* + * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/fastertransformer/triton_backend/gptneox/GptNeoXTritonModel.h" +#include "3rdparty/INIReader.h" +#include "src/fastertransformer/triton_backend/gptneox/GptNeoXTritonModelInstance.h" +#include "src/fastertransformer/triton_backend/transformer_triton_backend.hpp" +#include "src/fastertransformer/utils/allocator.h" + +namespace ft = fastertransformer; + +std::shared_ptr AbstractTransformerModel::createGptNeoXModel(std::string inifile) +{ + INIReader reader = INIReader(inifile); + if (reader.ParseError() < 0) { + std::cout << "[ERROR] Can't load '" << inifile << "'\n"; + return nullptr; + } + + const std::string data_type = reader.Get("ft_instance_hyperparameter", "data_type"); + int tensor_para_size = reader.GetInteger("ft_instance_hyperparameter", "tensor_para_size"); + std::string model_dir = reader.Get("ft_instance_hyperparameter", "model_dir"); + + if (data_type == "half") { + return std::make_shared>( + reader.GetInteger("ft_instance_hyperparameter", "tensor_para_size"), + reader.GetInteger("ft_instance_hyperparameter", "pipeline_para_size"), + reader.GetInteger("ft_instance_hyperparameter", "enable_custom_all_reduce", 0), + model_dir); + } + else { + return std::make_shared>( + reader.GetInteger("ft_instance_hyperparameter", "tensor_para_size"), + reader.GetInteger("ft_instance_hyperparameter", "pipeline_para_size"), + reader.GetInteger("ft_instance_hyperparameter", "enable_custom_all_reduce", 0), + model_dir); + } +} + +template +GptNeoXTritonModel::GptNeoXTritonModel(size_t tensor_para_size, + size_t pipeline_para_size, + int enable_custom_all_reduce, + std::string model_dir): + tensor_para_size_(tensor_para_size), + pipeline_para_size_(pipeline_para_size), + shared_weights_(std::vector>>(ft::getDeviceCount())), + enable_custom_all_reduce_(enable_custom_all_reduce) 
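/*
 * Illustrative per-model config.ini read in the constructor body below. Key names
 * follow the reader calls; the values are placeholders and are not taken from any
 * shipped GPT-NeoX checkpoint:
 *
 *   [gptneox]
 *   model_name=gptneox_example
 *   head_num=64
 *   size_per_head=96
 *   inter_size=24576
 *   num_layer=44
 *   rotary_embedding=24
 *   vocab_size=50432
 *   start_id=0
 *   end_id=2
 *   use_gptj_residual=1
 *   prompt_learning_start_id=3       ; only required by p/prompt tuning
 *   prompt_learning_type=0
 *   num_tasks=0
 */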
+{ + model_dir_ = model_dir; + const std::string inifile{model_dir + "/config.ini"}; + INIReader reader = INIReader(inifile); + if (reader.ParseError() < 0) { + std::cout << "[ERROR] Can't load '" << inifile << "'\n"; + ft::FT_CHECK(false); + } + + model_name_ = reader.Get("gptneox", "model_name"); + head_num_ = reader.GetInteger("gptneox", "head_num"); + size_per_head_ = reader.GetInteger("gptneox", "size_per_head"); + inter_size_ = reader.GetInteger("gptneox", "inter_size"); + num_layer_ = reader.GetInteger("gptneox", "num_layer"); + vocab_size_ = reader.GetInteger("gptneox", "vocab_size"); + rotary_embedding_dim_ = reader.GetInteger("gptneox", "rotary_embedding"); + start_id_ = reader.GetInteger("gptneox", "start_id"); + end_id_ = reader.GetInteger("gptneox", "end_id"); + use_gptj_residual_ = (bool)reader.GetInteger("gptneox", "use_gptj_residual", 1); + + num_tasks_ = reader.GetInteger("gptneox", "num_tasks", 0); + + prompt_learning_start_id_ = reader.GetInteger("gptneox", "prompt_learning_start_id", end_id_ + 1); + prompt_learning_type_ = + static_cast(reader.GetInteger("gptneox", "prompt_learning_type", 0)); + + for (int task_name_id = 0; task_name_id < num_tasks_; task_name_id++) { + std::string config_task_name = "task_" + std::to_string(task_name_id); + std::string task_name = reader.Get(config_task_name, "task_name"); + const int prompt_length = reader.GetInteger(config_task_name, "prompt_length", 0); + prompt_learning_table_pair_.insert({task_name, {task_name_id, prompt_length}}); + } +} + +template +std::unique_ptr GptNeoXTritonModel::createModelInstance( + int device_id, + int rank, + cudaStream_t stream, + std::pair, std::vector> nccl_params, + std::shared_ptr custom_all_reduce_comm) +{ + ft::check_cuda_error(cudaSetDevice(device_id)); + const int comms_rank = device_id % (tensor_para_size_ * pipeline_para_size_); + + std::unique_ptr> allocator( + new ft::Allocator(device_id)); + + allocator->setStream(stream); + + cublasHandle_t cublas_handle; + cublasLtHandle_t cublaslt_handle; + + cublasCreate(&cublas_handle); + cublasLtCreate(&cublaslt_handle); + cublasSetStream(cublas_handle, stream); + + std::unique_ptr cublas_algo_map(new ft::cublasAlgoMap("gemm_config.in")); + std::unique_ptr cublas_wrapper_mutex(new std::mutex()); + std::unique_ptr cublas_wrapper(new ft::cublasMMWrapper( + cublas_handle, cublaslt_handle, stream, cublas_algo_map.get(), cublas_wrapper_mutex.get(), allocator.get())); + + std::unique_ptr cuda_device_prop_ptr(new cudaDeviceProp); + ft::check_cuda_error(cudaGetDeviceProperties(cuda_device_prop_ptr.get(), device_id)); + + if (std::is_same::value) { + cublas_wrapper->setGemmConfig(CUDA_R_16F, CUDA_R_16F, CUDA_R_16F, CUDA_R_32F); + } + else if (std::is_same::value) { + cublas_wrapper->setFP32GemmConfig(); + } + + ft::NcclParam tensor_para = nccl_params.first[comms_rank]; + ft::NcclParam pipeline_para = nccl_params.second[comms_rank]; + + auto gpt = std::make_unique>( + ft::GptNeoX(head_num_, + size_per_head_, + inter_size_, + num_layer_, + vocab_size_, + rotary_embedding_dim_, + start_id_, + end_id_, + prompt_learning_start_id_, // p/prompt tuning virtual token start id + prompt_learning_type_, + use_gptj_residual_, + 0.0f, // beam_search_diversity_rate_, + 0, // top_k_, + 0.0f, // top_p_, + 0, // random seed, note that all gpus should use same seed + 0.0f, // temperature_, + 0.0f, // len_penalty_, + 0.0f, // repetition_penalty_, + tensor_para, + pipeline_para, + stream, + cublas_wrapper.get(), + allocator.get(), + false, + cuda_device_prop_ptr.get(), + 
custom_all_reduce_comm, + enable_custom_all_reduce_)); + + return std::unique_ptr>( + new GptNeoXTritonModelInstance(std::move(gpt), + shared_weights_[device_id], + std::move(allocator), + std::move(cublas_algo_map), + std::move(cublas_wrapper_mutex), + std::move(cublas_wrapper), + std::move(cuda_device_prop_ptr))); +} + +template +void GptNeoXTritonModel::createSharedWeights(int device_id, int rank) +{ + ft::check_cuda_error(cudaSetDevice(device_id)); + const int tensor_para_rank = rank % tensor_para_size_; + const int pipeline_para_rank = rank / tensor_para_size_; + shared_weights_[device_id] = std::make_shared>(head_num_ * size_per_head_, + inter_size_, + vocab_size_, + num_layer_, + 0, // max_seq_len, deprecated + tensor_para_size_, + tensor_para_rank, + pipeline_para_size_, + pipeline_para_rank, + use_gptj_residual_, + prompt_learning_type_, + prompt_learning_table_pair_); + shared_weights_[device_id]->loadModel(model_dir_); + return; +} + +template +std::string GptNeoXTritonModel::toString() +{ + std::stringstream ss; + ss << "Model: " + << "\nhead_num: " << head_num_ << "\nsize_per_head: " << size_per_head_ << "\ninter_size: " << inter_size_ + << "\nnum_layer: " << num_layer_ << "\nvocab_size: " << vocab_size_ << "\nstart_id: " << start_id_ + << "\nend_id: " << end_id_ << "\nuse_gptj_residual: " << use_gptj_residual_ + << "\nprompt_learning_type_: " << static_cast(prompt_learning_type_) + << "\nprompt_learning_start_id_: " << prompt_learning_start_id_ << "\ntensor_para_size: " << tensor_para_size_ + << "\npipeline_para_size: " << pipeline_para_size_ << "\nenable_custom_all_reduce: " << enable_custom_all_reduce_ + << "\nmodel_name: " << model_name_ << "\nmodel_dir: " << model_dir_ << std::endl; + return ss.str(); +} + +template +void GptNeoXTritonModel::createCustomComms( + std::vector>* custom_all_reduce_comms, int world_size) +{ + using commDataType = typename ft::CustomARCommTypeConverter::Type; + ft::initCustomAllReduceComm(custom_all_reduce_comms, enable_custom_all_reduce_, world_size); +} + +template +int GptNeoXTritonModel::getTensorParaSize() +{ + return tensor_para_size_; +} + +template +int GptNeoXTritonModel::getPipelineParaSize() +{ + return pipeline_para_size_; +} + +template struct GptNeoXTritonModel; +template struct GptNeoXTritonModel; diff --git a/src/fastertransformer/triton_backend/gptneox/GptNeoXTritonModel.h b/src/fastertransformer/triton_backend/gptneox/GptNeoXTritonModel.h new file mode 100644 index 000000000..3a587d143 --- /dev/null +++ b/src/fastertransformer/triton_backend/gptneox/GptNeoXTritonModel.h @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "src/fastertransformer/models/gptneox/GptNeoX.h" +#include "src/fastertransformer/triton_backend/transformer_triton_backend.hpp" +#include "src/fastertransformer/utils/cuda_utils.h" +#include "src/fastertransformer/utils/custom_ar_comm.h" +#include "src/fastertransformer/utils/nccl_utils.h" +#include + +namespace ft = fastertransformer; + +template +struct GptNeoXTritonModel: public AbstractTransformerModel { + GptNeoXTritonModel(size_t tensor_para_size, + size_t pipeline_para_size, + int enable_custom_all_reduce, + std::string model_dir); + + ~GptNeoXTritonModel() = default; + + virtual std::unique_ptr + createModelInstance(int deviceId, + int rank, + cudaStream_t stream, + std::pair, std::vector> nccl_params, + std::shared_ptr custom_all_reduce_comm = nullptr) override; + + virtual void createSharedWeights(int deviceId, int rank) override; + + virtual void createCustomComms(std::vector>* custom_all_reduce_comms, + int world_size) override; + + virtual std::string toString() override; + virtual int getTensorParaSize() override; + virtual int getPipelineParaSize() override; + +private: + size_t head_num_; + size_t size_per_head_; + size_t inter_size_; + size_t num_layer_; + size_t vocab_size_; + size_t rotary_embedding_dim_; + int start_id_; + int end_id_; + size_t tensor_para_size_; + size_t pipeline_para_size_; + + // shared weights for each device + std::vector>> shared_weights_; + + // residual type + bool use_gptj_residual_ = true; + + // number of tasks (for prefix-prompt, p/prompt-tuning) + size_t num_tasks_ = 0; + int prompt_learning_start_id_ = 0; + ft::PromptLearningType prompt_learning_type_ = ft::PromptLearningType::no_prompt; + std::map> prompt_learning_table_pair_ = {}; + + bool is_fp16_; + int enable_custom_all_reduce_ = 0; + + std::string model_name_; + std::string model_dir_; +}; diff --git a/src/fastertransformer/triton_backend/gptneox/GptNeoXTritonModelInstance.cc b/src/fastertransformer/triton_backend/gptneox/GptNeoXTritonModelInstance.cc new file mode 100644 index 000000000..8b9b1f1f2 --- /dev/null +++ b/src/fastertransformer/triton_backend/gptneox/GptNeoXTritonModelInstance.cc @@ -0,0 +1,239 @@ +/* + * Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/fastertransformer/triton_backend/gptneox/GptNeoXTritonModelInstance.h" +#include "src/fastertransformer/triton_backend/transformer_triton_backend.hpp" +#include "src/fastertransformer/triton_backend/triton_utils.hpp" +#include "src/fastertransformer/utils/Tensor.h" +#include +#include +#include +#include + +namespace ft = fastertransformer; + +template +void triton_stream_callback(std::unordered_map* output_tensors, void* ctx) +{ + GptNeoXTritonModelInstance* model = reinterpret_cast*>(ctx); + auto result = GptNeoXTritonModelInstance::convert_outputs(*output_tensors); + + model->stream_cb_(result, model->stream_ctx_); +} + +template +GptNeoXTritonModelInstance::GptNeoXTritonModelInstance( + std::unique_ptr> gpt, + std::shared_ptr> gpt_weight, + std::unique_ptr> allocator, + std::unique_ptr cublas_algo_map, + std::unique_ptr cublas_wrapper_mutex, + std::unique_ptr cublas_wrapper, + std::unique_ptr cuda_device_prop_ptr): + gpt_(std::move(gpt)), + gpt_weight_(gpt_weight), + allocator_(std::move(allocator)), + cublas_algo_map_(std::move(cublas_algo_map)), + cublas_wrapper_mutex_(std::move(cublas_wrapper_mutex)), + cublas_wrapper_(std::move(cublas_wrapper)), + cuda_device_prop_ptr_(std::move(cuda_device_prop_ptr)) +{ +} + +template +std::unordered_map GptNeoXTritonModelInstance::convert_inputs( + std::shared_ptr> input_tensors) +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + + move_tensor_H2D(input_tensors->at("input_ids"), d_input_ids_, &allocator_); + move_tensor_H2D(input_tensors->at("input_lengths"), d_input_lengths_, &allocator_); + + const size_t request_batch_size = input_tensors->at("input_ids").shape[0]; + const size_t input_data_len = input_tensors->at("input_ids").shape[1]; + h_total_output_lengths_ = reinterpret_cast(malloc(request_batch_size * sizeof(uint32_t))); + for (int i = 0; i < request_batch_size; ++i) { + h_total_output_lengths_[i] = + reinterpret_cast(input_tensors->at("request_output_len").data)[i] + input_data_len; + } + + std::unordered_map ft_input_tensors = std::unordered_map{ + {"input_ids", as_GPU_tensor(input_tensors->at("input_ids"), d_input_ids_)}, + {"input_lengths", as_GPU_tensor(input_tensors->at("input_lengths"), d_input_lengths_)}, + {"output_seq_len", + ft::Tensor{ft::MEMORY_CPU, + ft::TYPE_UINT32, + {input_tensors->at("request_output_len").shape[0]}, + h_total_output_lengths_}}}; + + if (input_tensors->find("bad_words_list") != input_tensors->end()) { + move_tensor_H2D(input_tensors->at("bad_words_list"), d_input_bad_words_, &allocator_); + ft_input_tensors.insert( + {"bad_words_list", as_GPU_tensor(input_tensors->at("bad_words_list"), d_input_bad_words_)}); + } + + if (input_tensors->find("stop_words_list") != input_tensors->end()) { + move_tensor_H2D(input_tensors->at("stop_words_list"), d_input_stop_words_, &allocator_); + ft_input_tensors.insert( + {"stop_words_list", as_GPU_tensor(input_tensors->at("stop_words_list"), d_input_stop_words_)}); + } + + if (input_tensors->count("request_prompt_embedding") && input_tensors->count("request_prompt_lengths") + && input_tensors->count("request_prompt_type")) { + + move_tensor_H2D(input_tensors->at("request_prompt_lengths"), d_request_prompt_lengths_, &allocator_); + ft_input_tensors.insert( + {"request_prompt_lengths", + as_GPU_tensor(input_tensors->at("request_prompt_lengths"), d_request_prompt_lengths_)}); + + move_tensor_H2D(input_tensors->at("request_prompt_embedding"), d_request_prompt_embedding_, &allocator_); + ft_input_tensors.insert( + {"request_prompt_embedding", + 
as_GPU_tensor(input_tensors->at("request_prompt_embedding"), d_request_prompt_embedding_)}); + } + + for (auto t = input_tensors->begin(); t != input_tensors->end(); ++t) { + if (t->first.find("input_ids") == std::string::npos && t->first.find("input_lengths") == std::string::npos + && t->first.find("output_seq_len") == std::string::npos + && t->first.find("prefix_soft_prompt_embedding") == std::string::npos + && t->first.find("prefix_soft_prompt_lengths") == std::string::npos) { + if (ft_input_tensors.count(t->first) == 0) { + ft_input_tensors.insert({t->first, t->second.convertTritonTensorToFt()}); + } + } + } + + return ft_input_tensors; +} + +template +std::shared_ptr> +GptNeoXTritonModelInstance::convert_outputs(const std::unordered_map& output_tensors) +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + std::unordered_map* outputs_mapping = + new std::unordered_map(); + + for (auto it = output_tensors.begin(); it != output_tensors.end(); it++) { + outputs_mapping->insert({it->first, triton::Tensor::convertFtTensorToTriton(it->second)}); + } + + return std::shared_ptr>(outputs_mapping); +} + +template +std::shared_ptr> +GptNeoXTritonModelInstance::forward(std::shared_ptr> input_tensors) +{ + ft::FT_CHECK(false); + return nullptr; +} + +template +std::shared_ptr> +GptNeoXTritonModelInstance::forward(std::shared_ptr> input_tensors) +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + ft::FT_CHECK_WITH_INFO(input_tensors->at("input_ids").shape.size() == 2, + "input_tensors->at(\"input_ids\").shape.size() == 2"); + ft::FT_CHECK_WITH_INFO(input_tensors->at("input_lengths").shape.size() == 1, + "input_tensors->at(\"input_lengths\").shape.size() == 1"); + + const uint32_t request_batch_size = input_tensors->at("input_ids").shape[0]; + const uint32_t max_request_output_len = (size_t)*std::max_element( + (int*)input_tensors->at("request_output_len").data, + (int*)input_tensors->at("request_output_len").data + input_tensors->at("request_output_len").shape[0]); + const uint32_t total_output_len = max_request_output_len + input_tensors->at("input_ids").shape[1]; + const uint32_t beam_width = + input_tensors->count("beam_width") ? 
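// Illustrative sizing note for allocateBuffer below (placeholder numbers): the
// optional CPU "beam_width" input defaults to 1 when absent, and d_output_ids_
// holds request_batch_size * beam_width * total_output_len int32 values, e.g.
// 4 * 1 * 128 * 4 B = 2 KiB for a batch of 4 with total_output_len = 128.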
(size_t)(*(uint*)input_tensors->at("beam_width").data) : 1; + + allocateBuffer(request_batch_size, beam_width, total_output_len, max_request_output_len); + + std::unordered_map ft_input_tensors = convert_inputs(input_tensors); + + std::unordered_map output_tensors = std::unordered_map{ + {"output_ids", + ft::Tensor{ft::MEMORY_GPU, + ft::TYPE_UINT32, + std::vector{request_batch_size, beam_width, total_output_len}, + d_output_ids_}}, + {"sequence_length", + ft::Tensor{ft::MEMORY_GPU, + ft::TYPE_INT32, + std::vector{request_batch_size, beam_width}, + d_sequence_lengths_}}}; + + if (input_tensors->count("is_return_log_probs") && *((bool*)input_tensors->at("is_return_log_probs").data)) { + output_tensors.insert({"output_log_probs", + ft::Tensor{ft::MEMORY_GPU, + ft::TYPE_FP32, + std::vector{request_batch_size, beam_width, max_request_output_len}, + d_output_log_probs_}}); + output_tensors.insert({"cum_log_probs", + ft::Tensor{ft::MEMORY_GPU, + ft::TYPE_FP32, + std::vector{request_batch_size, beam_width}, + d_cum_log_probs_}}); + } + if (stream_cb_ != nullptr) { + gpt_->registerCallback(triton_stream_callback, this); + } + gpt_->forward(&output_tensors, &ft_input_tensors, gpt_weight_.get()); + if (stream_cb_ != nullptr) { + gpt_->unRegisterCallback(); + } + + if (h_total_output_lengths_ != nullptr) { + free(h_total_output_lengths_); + h_total_output_lengths_ = nullptr; + } + + return convert_outputs(output_tensors); +} + +template +GptNeoXTritonModelInstance::~GptNeoXTritonModelInstance() +{ + freeBuffer(); +} + +template +void GptNeoXTritonModelInstance::allocateBuffer(const size_t request_batch_size, + const size_t beam_width, + const size_t total_output_len, + const size_t max_request_output_len) +{ + d_output_ids_ = (int*)(allocator_->reMalloc( + d_output_ids_, sizeof(int) * request_batch_size * beam_width * total_output_len, false)); + d_sequence_lengths_ = + (int*)(allocator_->reMalloc(d_sequence_lengths_, sizeof(int) * request_batch_size * beam_width, false)); + d_output_log_probs_ = (float*)(allocator_->reMalloc( + d_output_log_probs_, sizeof(float) * request_batch_size * beam_width * max_request_output_len, false)); + d_cum_log_probs_ = + (float*)(allocator_->reMalloc(d_cum_log_probs_, sizeof(float) * request_batch_size * beam_width, false)); +} + +template +void GptNeoXTritonModelInstance::freeBuffer() +{ + allocator_->free((void**)(&d_output_ids_)); + allocator_->free((void**)(&d_sequence_lengths_)); + allocator_->free((void**)(&d_output_log_probs_)); + allocator_->free((void**)(&d_cum_log_probs_)); +} + +template struct GptNeoXTritonModelInstance; +template struct GptNeoXTritonModelInstance; diff --git a/src/fastertransformer/triton_backend/gptneox/GptNeoXTritonModelInstance.h b/src/fastertransformer/triton_backend/gptneox/GptNeoXTritonModelInstance.h new file mode 100644 index 000000000..adeae040a --- /dev/null +++ b/src/fastertransformer/triton_backend/gptneox/GptNeoXTritonModelInstance.h @@ -0,0 +1,78 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "src/fastertransformer/models/gptneox/GptNeoX.h" +#include "src/fastertransformer/triton_backend/gptneox/GptNeoXTritonModel.h" +#include "src/fastertransformer/triton_backend/transformer_triton_backend.hpp" +#include + +namespace ft = fastertransformer; + +template +struct GptNeoXTritonModelInstance: AbstractTransformerModelInstance { + + GptNeoXTritonModelInstance(std::unique_ptr> gpt, + std::shared_ptr> gpt_weight, + std::unique_ptr> allocator, + std::unique_ptr cublas_algo_map, + std::unique_ptr cublas_wrapper_mutex, + std::unique_ptr cublas_wrapper, + std::unique_ptr cuda_device_prop_ptr); + ~GptNeoXTritonModelInstance(); + + std::shared_ptr> + forward(std::shared_ptr> input_tensors) override; + + std::shared_ptr> + forward(std::shared_ptr> input_tensors) override; + + static std::shared_ptr> + convert_outputs(const std::unordered_map& output_tensors); + +private: + const std::unique_ptr> gpt_; + const std::shared_ptr> gpt_weight_; + const std::unique_ptr> allocator_; + const std::unique_ptr cublas_algo_map_; + const std::unique_ptr cublas_wrapper_mutex_; + const std::unique_ptr cublas_wrapper_; + const std::unique_ptr cuda_device_prop_ptr_; + + std::unordered_map + convert_inputs(std::shared_ptr> input_tensors); + + void allocateBuffer(const size_t request_batch_size, + const size_t beam_width, + const size_t total_output_len, + const size_t max_request_output_len); + void freeBuffer(); + + int* d_input_ids_ = nullptr; + int* d_input_lengths_ = nullptr; + int* d_input_bad_words_ = nullptr; + int* d_input_stop_words_ = nullptr; + int* d_request_prompt_lengths_ = nullptr; + T* d_request_prompt_embedding_ = nullptr; + + int* d_output_ids_ = nullptr; + int* d_sequence_lengths_ = nullptr; + float* d_output_log_probs_ = nullptr; + float* d_cum_log_probs_ = nullptr; + + uint32_t* h_total_output_lengths_ = nullptr; +}; diff --git a/src/fastertransformer/triton_backend/multi_gpu_gpt/CMakeLists.txt b/src/fastertransformer/triton_backend/multi_gpu_gpt/CMakeLists.txt index 871e35a62..8a3d08459 100644 --- a/src/fastertransformer/triton_backend/multi_gpu_gpt/CMakeLists.txt +++ b/src/fastertransformer/triton_backend/multi_gpu_gpt/CMakeLists.txt @@ -20,5 +20,5 @@ set(parallel_gpt_triton_backend_files ) add_library(ParallelGptTritonBackend SHARED ${parallel_gpt_triton_backend_files}) -target_link_libraries(ParallelGptTritonBackend PRIVATE ParallelGpt) +target_link_libraries(ParallelGptTritonBackend PRIVATE TransformerTritonBackend ParallelGpt) target_compile_features(ParallelGptTritonBackend PRIVATE cxx_std_14) \ No newline at end of file diff --git a/src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModel.cc b/src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModel.cc index cf8693ef9..dee27556c 100644 --- a/src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModel.cc +++ b/src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModel.cc @@ -33,7 +33,44 @@ std::shared_ptr AbstractTransformerModel::createGptMod } const std::string model_name = reader.Get("ft_instance_hyperparameter", "model_name"); - const std::string data_type = reader.Get("ft_instance_hyperparameter", "data_type"); + const std::string data_type = reader.Get("ft_instance_hyperparameter", "data_type"); + + // gpt variant parameters + ft::gptVariantParams gpt_variant_params{}; + std::string model_variant = reader.Get(model_name, 
"model_variant", "gpt"); + + if (model_variant == "opt-pre") { + gpt_variant_params.layernorm_eps = 1e-5f; + gpt_variant_params.layernorm_type = ft::LayerNormType::pre_layernorm; + gpt_variant_params.activation_type = ft::ActivationType::Relu; + gpt_variant_params.has_post_decoder_layernorm = false; + } + else if (model_variant == "opt-post") { + gpt_variant_params.layernorm_eps = 1e-5f; + gpt_variant_params.layernorm_type = ft::LayerNormType::post_layernorm; + gpt_variant_params.activation_type = ft::ActivationType::Relu; + gpt_variant_params.has_post_decoder_layernorm = false; + } + + gpt_variant_params.has_adapters = reader.GetBoolean(model_name, "has_adapters", false); + + // Prompt Learning Configurations + int end_id = reader.GetInteger(model_name, "end_id"); + int prompt_learning_start_id = reader.GetInteger(model_name, "prompt_learning_start_id", end_id + 1); + ft::PromptLearningType prompt_learning_type = + static_cast(reader.GetInteger(model_name, "prompt_learning_type", 0)); + + std::map> prompt_learning_table_pair; + + // NOTE: get prompt from configuration files + const int num_tasks = reader.GetInteger(model_name, "num_tasks", 0); + for (int task_name_id = 0; task_name_id < num_tasks; task_name_id++) { + std::string config_task_name = model_name + "_task_" + std::to_string(task_name_id); + std::string task_name = reader.Get(config_task_name, "task_name"); + const int prompt_length = reader.GetInteger(config_task_name, "prompt_length", 0); + prompt_learning_table_pair.insert({task_name, {task_name_id, prompt_length}}); + } + if (data_type == "fp16") { return std::make_shared>( reader.GetInteger("ft_instance_hyperparameter", "max_seq_len"), @@ -43,7 +80,11 @@ std::shared_ptr AbstractTransformerModel::createGptMod reader.GetInteger(model_name, "decoder_layers"), reader.GetInteger(model_name, "vocab_size"), reader.GetInteger(model_name, "start_id"), - reader.GetInteger(model_name, "end_id"), + end_id, + prompt_learning_start_id, + prompt_learning_type, + prompt_learning_table_pair, + gpt_variant_params, reader.GetInteger("ft_instance_hyperparameter", "tensor_para_size"), reader.GetInteger("ft_instance_hyperparameter", "pipeline_para_size"), reader.Get("ft_instance_hyperparameter", "model_name"), @@ -61,7 +102,11 @@ std::shared_ptr AbstractTransformerModel::createGptMod reader.GetInteger(model_name, "decoder_layers"), reader.GetInteger(model_name, "vocab_size"), reader.GetInteger(model_name, "start_id"), - reader.GetInteger(model_name, "end_id"), + end_id, + prompt_learning_start_id, + prompt_learning_type, + prompt_learning_table_pair, + gpt_variant_params, reader.GetInteger("ft_instance_hyperparameter", "tensor_para_size"), reader.GetInteger("ft_instance_hyperparameter", "pipeline_para_size"), reader.Get("ft_instance_hyperparameter", "model_name"), @@ -79,7 +124,11 @@ std::shared_ptr AbstractTransformerModel::createGptMod reader.GetInteger(model_name, "decoder_layers"), reader.GetInteger(model_name, "vocab_size"), reader.GetInteger(model_name, "start_id"), - reader.GetInteger(model_name, "end_id"), + end_id, + prompt_learning_start_id, + prompt_learning_type, + prompt_learning_table_pair, + gpt_variant_params, reader.GetInteger("ft_instance_hyperparameter", "tensor_para_size"), reader.GetInteger("ft_instance_hyperparameter", "pipeline_para_size"), reader.Get("ft_instance_hyperparameter", "model_name"), @@ -88,26 +137,120 @@ std::shared_ptr AbstractTransformerModel::createGptMod reader.GetInteger("ft_instance_hyperparameter", "enable_custom_all_reduce", 0)); } else { - 
FT_LOG_ERROR("Unspported data type " + data_type); + FT_LOG_ERROR("Unsupported data type " + data_type); exit(-1); } } template -ParallelGptTritonModel::ParallelGptTritonModel(size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t inter_size, - size_t num_layer, - size_t vocab_size, - int start_id, - int end_id, - size_t tensor_para_size, - size_t pipeline_para_size, - std::string model_name, +ParallelGptTritonModel::ParallelGptTritonModel(size_t tensor_para_size, + size_t pipeline_para_size, + int enable_custom_all_reduce, std::string model_dir, - int int8_mode, - int enable_custom_all_reduce): + int int8_mode): + tensor_para_size_(tensor_para_size), + pipeline_para_size_(pipeline_para_size), + shared_weights_(std::vector>>(ft::getDeviceCount())), + enable_custom_all_reduce_(enable_custom_all_reduce), + model_dir_(model_dir), + int8_mode_(int8_mode) +{ + INIReader reader = INIReader(model_dir + "/config.ini"); + if (reader.ParseError() < 0) { + std::cout << "[ERROR] Can't load '" << model_dir << "/config.ini" + << "'\n"; + ft::FT_CHECK(false); + } + + /* GPT Configuration File Example + [gpt] + model_name=gpt + max_pos_seq_len=2048 ;for position embedding tables + head_num=12 + size_per_head=64 + inter_size=3072 + num_layer=12 + layernorm_eps=1e-6 # optional for the default gpt + layernorm_type=pre_layernorm # optional for the default gpt + activation_type=Gelu # optional for the default gpt + has_post_decoder_layernorm=1 # optional for the default gpt + vocab_size=50257 + start_id=50256 + end_id=50256 + prompt_learning_start_id=50257 + prompt_learning_type=3 + num_tasks=3 + + [task_0] + task_name=sentiment + prompt_length=10 + + [task_1] + task_name=intent_and_slot + prompt_length=10 + + [task_2] + task_name=squad + prompt_length=16 + */ + + model_name_ = reader.Get("gpt", "model_name"); + max_seq_len_ = reader.GetInteger("gpt", "max_pos_seq_len"); + head_num_ = reader.GetInteger("gpt", "head_num"); + size_per_head_ = reader.GetInteger("gpt", "size_per_head"); + inter_size_ = reader.GetInteger("gpt", "inter_size"); + num_layer_ = reader.GetInteger("gpt", "num_layer"); + vocab_size_ = reader.GetInteger("gpt", "vocab_size"); + /* Meta Opt Examples + layernorm_eps=1e-5 + layernorm_type=pre_layernorm + activation_type=Relu + has_post_decoder_layernorm=0 + */ + gpt_variant_params_.layernorm_eps = reader.GetFloat("gpt", "layernorm_eps", 1e-6f); + gpt_variant_params_.layernorm_type = ft::getLayerNormType(reader.Get("gpt", "layernorm_type", "pre_layernorm")); + gpt_variant_params_.activation_type = ft::getActivationType(reader.Get("gpt", "activation_type", "Gelu")); + gpt_variant_params_.has_post_decoder_layernorm = reader.GetBoolean("gpt", "has_post_decoder_layernorm", "1"); + /* Megatron GPT Adapter Examples + has_adapters=True + adapter_inter_size=1024 + */ + gpt_variant_params_.has_adapters = reader.GetBoolean("gpt", "has_adapters", false); + gpt_variant_params_.adapter_inter_size = reader.GetInteger("gpt", "adapter_inter_size", inter_size_); + start_id_ = reader.GetInteger("gpt", "start_id"); + end_id_ = reader.GetInteger("gpt", "end_id"); + + num_tasks_ = reader.GetInteger("gpt", "num_tasks", 0); + prompt_learning_start_id_ = reader.GetInteger("gpt", "prompt_learning_start_id", end_id_ + 1); + prompt_learning_type_ = static_cast(reader.GetInteger("gpt", "prompt_learning_type", 0)); + + for (int task_name_id = 0; task_name_id < num_tasks_; task_name_id++) { + std::string config_task_name = "task_" + std::to_string(task_name_id); + std::string task_name = 
reader.Get(config_task_name, "task_name"); + const int prompt_length = reader.GetInteger(config_task_name, "prompt_length", 0); + prompt_learning_table_pair_.insert({task_name, {task_name_id, prompt_length}}); + } +} + +template +ParallelGptTritonModel::ParallelGptTritonModel(size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + size_t vocab_size, + int start_id, + int end_id, + int prompt_learning_start_id, + ft::PromptLearningType prompt_learning_type, + std::map> prompt_learning_table_pair, + ft::gptVariantParams gpt_variant_params, + size_t tensor_para_size, + size_t pipeline_para_size, + std::string model_name, + std::string model_dir, + int int8_mode, + int enable_custom_all_reduce): max_seq_len_(max_seq_len), head_num_(head_num), size_per_head_(size_per_head), @@ -116,8 +259,13 @@ ParallelGptTritonModel::ParallelGptTritonModel(size_t max_seq_len, vocab_size_(vocab_size), start_id_(start_id), end_id_(end_id), + prompt_learning_start_id_(prompt_learning_start_id), + prompt_learning_type_(prompt_learning_type), + prompt_learning_table_pair_(prompt_learning_table_pair), + gpt_variant_params_(gpt_variant_params), tensor_para_size_(tensor_para_size), pipeline_para_size_(pipeline_para_size), + shared_weights_(std::vector>>(ft::getDeviceCount())), model_name_(model_name), model_dir_(model_dir), int8_mode_(int8_mode), @@ -126,31 +274,30 @@ ParallelGptTritonModel::ParallelGptTritonModel(size_t max_seq_len, } template -std::unique_ptr -ParallelGptTritonModel::createModelInstance(int device_id, - int rank, - cudaStream_t stream, - std::pair, std::vector> nccl_comms, - std::shared_ptr custom_all_reduce_comm) +std::unique_ptr ParallelGptTritonModel::createModelInstance( + int device_id, + int rank, + cudaStream_t stream, + std::pair, std::vector> nccl_params, + std::shared_ptr custom_all_reduce_comm) { ft::check_cuda_error(cudaSetDevice(device_id)); - const int tensor_para_rank = rank % tensor_para_size_; - const int pipeline_para_rank = rank / tensor_para_size_; + const int comms_rank = device_id % (tensor_para_size_ * pipeline_para_size_); std::unique_ptr> allocator( new ft::Allocator(device_id)); allocator->setStream(stream); - cublasHandle_t cublas_handle; + cublasHandle_t cublas_handle; cublasLtHandle_t cublaslt_handle; cublasCreate(&cublas_handle); cublasLtCreate(&cublaslt_handle); cublasSetStream(cublas_handle, stream); - std::unique_ptr cublas_algo_map(new ft::cublasAlgoMap("gemm_config.in")); - std::unique_ptr cublas_wrapper_mutex(new std::mutex()); + std::unique_ptr cublas_algo_map(new ft::cublasAlgoMap("gemm_config.in")); + std::unique_ptr cublas_wrapper_mutex(new std::mutex()); std::unique_ptr cublas_wrapper(new ft::cublasMMWrapper( cublas_handle, cublaslt_handle, stream, cublas_algo_map.get(), cublas_wrapper_mutex.get(), allocator.get())); @@ -169,55 +316,46 @@ ParallelGptTritonModel::createModelInstance(int device_id, cublas_wrapper->setFP32GemmConfig(); } - ft::NcclParam tensor_para(tensor_para_rank, tensor_para_size_, nccl_comms.first[device_id]); - ft::NcclParam pipeline_para(pipeline_para_rank, pipeline_para_size_, nccl_comms.second[device_id]); - - auto gpt = std::make_unique>( - ft::ParallelGpt(0, // max_batch_size, FT will adjust the buffer automatically. - 0, // max_seq_len, FT will adjust the buffer automatically. - 0, // max_input_len, FT will adjust the buffer automatically. 
- 0, - head_num_, - size_per_head_, - inter_size_, - num_layer_, - vocab_size_, - start_id_, - end_id_, - 0.0f, // beam_search_diversity_rate_, - 1, // top_k_, - 0.0f, // top_p_, - 0, // random seed, note that all gpus should use same seed - 1.0f, // temperature_, - 1.0f, // len_penalty_, - 1.0f, // repetition_penalty_, - tensor_para, - pipeline_para, - stream, - cublas_wrapper.get(), - allocator.get(), - false, - cuda_device_prop_ptr.get(), - false, - int8_mode_, - custom_all_reduce_comm, - enable_custom_all_reduce_)); - - auto weight = std::unique_ptr>(new ft::ParallelGptWeight(head_num_ * size_per_head_, - inter_size_, - vocab_size_, - num_layer_, - max_seq_len_, - tensor_para_size_, - tensor_para_rank, - pipeline_para_size_, - pipeline_para_rank, - int8_mode_)); - - weight->loadModel(model_dir_); + ft::NcclParam tensor_para = nccl_params.first[comms_rank]; + ft::NcclParam pipeline_para = nccl_params.second[comms_rank]; + + auto gpt = + std::make_unique>(0, // max_batch_size, FT will adjust the buffer automatically. + 0, // max_seq_len, FT will adjust the buffer automatically. + 0, // max_input_len, FT will adjust the buffer automatically. + 0, + head_num_, + size_per_head_, + inter_size_, + num_layer_, + vocab_size_, + start_id_, + end_id_, + prompt_learning_start_id_, // p/prompt tuning virtual token start id + prompt_learning_type_, + gpt_variant_params_, + 0.0f, // beam_search_diversity_rate_, + 1, // top_k_, + 0.0f, // top_p_, + 0, // random seed, note that all gpus should use same seed + 1.0f, // temperature_, + 0.0f, // len_penalty_, + 1.0f, // repetition_penalty_, + tensor_para, + pipeline_para, + stream, + cublas_wrapper.get(), + allocator.get(), + false, + cuda_device_prop_ptr.get(), + false, + int8_mode_, + custom_all_reduce_comm, + enable_custom_all_reduce_); + return std::unique_ptr>( new ParallelGptTritonModelInstance(std::move(gpt), - std::move(weight), + shared_weights_[device_id], std::move(allocator), std::move(cublas_algo_map), std::move(cublas_wrapper_mutex), @@ -225,6 +363,29 @@ ParallelGptTritonModel::createModelInstance(int device_id, std::move(cuda_device_prop_ptr))); } +template +void ParallelGptTritonModel::createSharedWeights(int device_id, int rank) +{ + ft::check_cuda_error(cudaSetDevice(device_id)); + const int tensor_para_rank = rank % tensor_para_size_; + const int pipeline_para_rank = rank / tensor_para_size_; + shared_weights_[device_id] = std::make_shared>(head_num_ * size_per_head_, + inter_size_, + vocab_size_, + num_layer_, + max_seq_len_, + tensor_para_size_, + tensor_para_rank, + pipeline_para_size_, + pipeline_para_rank, + int8_mode_, + prompt_learning_type_, + prompt_learning_table_pair_, + gpt_variant_params_); + shared_weights_[device_id]->loadModel(model_dir_); + return; +} + template std::string ParallelGptTritonModel::toString() { @@ -232,79 +393,16 @@ std::string ParallelGptTritonModel::toString() ss << "Model: " << "\nmax_seq_len: " << max_seq_len_ << "\nhead_num: " << head_num_ << "\nsize_per_head: " << size_per_head_ << "\ninter_size: " << inter_size_ << "\nnum_layer: " << num_layer_ << "\nvocab_size: " << vocab_size_ - << "\nstart_id: " << start_id_ << "\nend_id: " << end_id_ << "\ntensor_para_size: " << tensor_para_size_ - << "\npipeline_para_size: " << pipeline_para_size_ << "\nint8_mode: " << int8_mode_ - << "\nenable_custom_all_reduce: " << enable_custom_all_reduce_ << "\nmodel_name: " << model_name_ - << "\nmodel_dir: " << model_dir_ << std::endl; + << "\nlayernorm_eps" << gpt_variant_params_.layernorm_eps << "\nlayernorm_type" + 
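The rank arithmetic in createModelInstance and createSharedWeights above is easy to misread, so here is a worked example for a hypothetical single node with 4 GPUs, tensor_para_size=2 and pipeline_para_size=2 (rank equals device_id in this layout):

    rank   tensor_para_rank   pipeline_para_rank   comms_rank
           (= rank % 2)       (= rank / 2)         (= device_id % 4)
    0      0                  0                    0
    1      1                  0                    1
    2      0                  1                    2
    3      1                  1                    3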
<< static_cast(gpt_variant_params_.layernorm_type) << "\nactivation_type" + << static_cast(gpt_variant_params_.activation_type) << "\nhas_post_decoder_layernorm" + << gpt_variant_params_.has_post_decoder_layernorm << "\nstart_id: " << start_id_ << "\nend_id: " << end_id_ + << "\ntensor_para_size: " << tensor_para_size_ << "\npipeline_para_size: " << pipeline_para_size_ + << "\nint8_mode: " << int8_mode_ << "\nenable_custom_all_reduce: " << enable_custom_all_reduce_ + << "\nmodel_name: " << model_name_ << "\nmodel_dir: " << model_dir_ << std::endl; return ss.str(); } -template -std::vector ParallelGptTritonModel::createNcclIds(const uint32_t world_size, bool multi_instances) -{ - std::vector nccl_ids(tensor_para_size_ + pipeline_para_size_); - if (multi_instances) { - if (tensor_para_size_ * pipeline_para_size_ != 1) { - printf( - "[ERROR] Multiple Instances currently only support tensor_para_size_ and pipeline_para_size_ both 1\n"); - ft::FT_CHECK(tensor_para_size_ == 1 && pipeline_para_size_ == 1); - } - nccl_ids.resize(2); - } - else { - if (world_size != tensor_para_size_ * pipeline_para_size_) { - ft::FT_CHECK_WITH_INFO(world_size == tensor_para_size_ * pipeline_para_size_, - "world_size == tensor_para_size_ * pipeline_para_size_ (" - + std::to_string(world_size) + " != " + std::to_string(tensor_para_size_) + "*" - + std::to_string(pipeline_para_size_) + ")"); - } - } - - for (uint32_t i = 0; i < nccl_ids.size(); i++) { - NCCLCHECK(ncclGetUniqueId(&nccl_ids[i])); - } - return nccl_ids; -} - -template -std::pair, std::vector> ParallelGptTritonModel::createNcclComms( - std::vector nccl_ids, const int node_id, bool multi_instances, int instance_id) -{ - const int gpu_count = ft::getDeviceCount(); - std::vector tensor_para_comms(gpu_count); - std::vector pipeline_para_comms(gpu_count); - if (multi_instances) { - ncclUniqueId tensor_para_nccl_uid = nccl_ids[0]; - ncclUniqueId pipeline_para_nccl_uid = nccl_ids[1]; - size_t tensor_para_rank = 0; - size_t pipeline_para_rank = 0; - - ft::check_cuda_error(cudaSetDevice(instance_id)); - NCCLCHECK(ncclCommInitRank( - &tensor_para_comms[instance_id], tensor_para_size_, tensor_para_nccl_uid, tensor_para_rank)); - NCCLCHECK(ncclCommInitRank( - &pipeline_para_comms[instance_id], pipeline_para_size_, pipeline_para_nccl_uid, pipeline_para_rank)); - } - else { - NCCLCHECK(ncclGroupStart()); - for (int gid = 0; gid < gpu_count; gid++) { - int rank = node_id * gpu_count + gid; - size_t tensor_para_rank = rank % tensor_para_size_; - size_t pipeline_para_rank = rank / tensor_para_size_; - ncclUniqueId tensor_para_nccl_uid = nccl_ids[pipeline_para_rank]; - ncclUniqueId pipeline_para_nccl_uid = nccl_ids[pipeline_para_size_ + tensor_para_rank]; - - ft::check_cuda_error(cudaSetDevice(gid)); - NCCLCHECK( - ncclCommInitRank(&tensor_para_comms[gid], tensor_para_size_, tensor_para_nccl_uid, tensor_para_rank)); - NCCLCHECK(ncclCommInitRank( - &pipeline_para_comms[gid], pipeline_para_size_, pipeline_para_nccl_uid, pipeline_para_rank)); - } - NCCLCHECK(ncclGroupEnd()); - } - return std::pair, std::vector>(tensor_para_comms, pipeline_para_comms); -} - template void ParallelGptTritonModel::createCustomComms( std::vector>* custom_all_reduce_comms, int world_size) @@ -329,4 +427,4 @@ template struct ParallelGptTritonModel; template struct ParallelGptTritonModel; #ifdef ENABLE_BF16 template struct ParallelGptTritonModel<__nv_bfloat16>; -#endif \ No newline at end of file +#endif diff --git a/src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModel.h 
b/src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModel.h index 7192bded9..e7148121b 100644 --- a/src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModel.h +++ b/src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModel.h @@ -28,59 +28,73 @@ namespace ft = fastertransformer; template struct ParallelGptTritonModel: public AbstractTransformerModel { - ParallelGptTritonModel(size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t inter_size, - size_t num_layer, - size_t vocab_size, - int start_id, - int end_id, - size_t tensor_para_size, - size_t pipeline_para_size, - std::string model_name, + ParallelGptTritonModel(size_t max_seq_len, + size_t head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + size_t vocab_size, + int start_id, + int end_id, + int prompt_learning_start_id, + ft::PromptLearningType prompt_learning_type, + std::map> prompt_learning_table_pair, + ft::gptVariantParams gpt_variant_params, + size_t tensor_para_size, + size_t pipeline_para_size, + std::string model_name, + std::string model_dir, + int int8_mode, + int enable_custom_all_reduce); + + ParallelGptTritonModel(size_t tensor_para_size, + size_t pipeline_para_size, + int enable_custom_all_reduce, std::string model_dir, - int int8_mode, - int enable_custom_all_reduce); + int int8_mode); virtual std::unique_ptr - createModelInstance(int deviceId, - int rank, - cudaStream_t stream, - std::pair, std::vector> nccl_comms, + createModelInstance(int deviceId, + int rank, + cudaStream_t stream, + std::pair, std::vector> nccl_params, std::shared_ptr custom_all_reduce_comm = nullptr) override; - virtual std::pair, std::vector> - createNcclComms(std::vector nccl_ids, - const int node_id, - bool multi_instances = false, - int instance_id = 0) override; + virtual void createSharedWeights(int deviceId, int rank) override; virtual void createCustomComms(std::vector>* custom_all_reduce_comms, - int world_size) override; - - virtual std::vector createNcclIds(const uint32_t world_size, bool multi_instances = false) override; + int world_size) override; virtual std::string toString() override; - virtual int getTensorParaSize() override; - virtual int getPipelineParaSize() override; + virtual int getTensorParaSize() override; + virtual int getPipelineParaSize() override; private: - const size_t max_seq_len_; - const size_t head_num_; - const size_t size_per_head_; - const size_t inter_size_; - const size_t num_layer_; - const size_t vocab_size_; - const int start_id_; - const int end_id_; - const size_t tensor_para_size_; - const size_t pipeline_para_size_; + size_t max_seq_len_; // needed for position embedding table + size_t head_num_; + size_t size_per_head_; + size_t inter_size_; + size_t num_layer_; + size_t vocab_size_; + int start_id_; + int end_id_; + size_t tensor_para_size_; + size_t pipeline_para_size_; - bool is_fp16_; + // shared weights for each device + std::vector>> shared_weights_; + + // model variants parameters + ft::gptVariantParams gpt_variant_params_ = {}; std::string model_name_; std::string model_dir_; - int int8_mode_ = 0; - int enable_custom_all_reduce_ = 0; + int int8_mode_ = 0; + int enable_custom_all_reduce_ = 0; + + // number of tasks (for prefix-prompt, p/prompt-tuning) + size_t num_tasks_ = 0; + int prompt_learning_start_id_ = 0; + ft::PromptLearningType prompt_learning_type_ = ft::PromptLearningType::no_prompt; + std::map> prompt_learning_table_pair_ = {}; }; diff --git 
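Because the weights now live on the model object (one shared_weights_ entry per device) instead of inside each instance, the intended call order is: construct the model, load weights once per GPU with createSharedWeights, then create instances that reuse them. The driver sketch below is hypothetical and uses the new five-argument constructor; the nccl_params pair is assumed to be built by a base-class helper that is not part of this diff.

    // Hypothetical sketch; NCCL setup and error handling omitted.
    auto model = std::make_shared<ParallelGptTritonModel<half>>(
        /*tensor_para_size=*/1, /*pipeline_para_size=*/1,
        /*enable_custom_all_reduce=*/0, "/workspace/models/gpt", /*int8_mode=*/0);

    std::pair<std::vector<ft::NcclParam>, std::vector<ft::NcclParam>> nccl_params;  // assumed filled elsewhere
    cudaStream_t stream;
    cudaStreamCreate(&stream);

    const int device_id = 0;
    model->createSharedWeights(device_id, /*rank=*/0);   // loads ParallelGptWeight once for this GPU
    auto instance = model->createModelInstance(device_id, /*rank=*/0, stream, nccl_params);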
a/src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModelInstance.cc b/src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModelInstance.cc index 831079442..174b0fe84 100644 --- a/src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModelInstance.cc +++ b/src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModelInstance.cc @@ -25,17 +25,26 @@ namespace ft = fastertransformer; +template +void triton_stream_callback(std::unordered_map* output_tensors, void* ctx) +{ + ParallelGptTritonModelInstance* model = reinterpret_cast*>(ctx); + auto result = ParallelGptTritonModelInstance::convert_outputs(*output_tensors); + + model->stream_cb_(result, model->stream_ctx_); +} + template ParallelGptTritonModelInstance::ParallelGptTritonModelInstance( - std::unique_ptr> gpt, - std::unique_ptr> gpt_weight, + std::unique_ptr> gpt, + std::shared_ptr> gpt_weight, std::unique_ptr> allocator, - std::unique_ptr cublas_algo_map, - std::unique_ptr cublas_wrapper_mutex, - std::unique_ptr cublas_wrapper, - std::unique_ptr cuda_device_prop_ptr): + std::unique_ptr cublas_algo_map, + std::unique_ptr cublas_wrapper_mutex, + std::unique_ptr cublas_wrapper, + std::unique_ptr cuda_device_prop_ptr): gpt_(std::move(gpt)), - gpt_weight_(std::move(gpt_weight)), + gpt_weight_(gpt_weight), allocator_(std::move(allocator)), cublas_algo_map_(std::move(cublas_algo_map)), cublas_wrapper_mutex_(std::move(cublas_wrapper_mutex)), @@ -56,44 +65,57 @@ template std::unordered_map ParallelGptTritonModelInstance::convert_inputs( std::shared_ptr> input_tensors) { - move_tensor_H2D(input_tensors->at("input_ids"), d_input_ids_); - move_tensor_H2D(input_tensors->at("input_lengths"), d_input_lengths_); + const size_t request_batch_size = input_tensors->at("input_ids").shape[0]; + move_tensor_H2D(input_tensors->at("input_ids"), d_input_ids_, &allocator_); + move_tensor_H2D(input_tensors->at("input_lengths"), d_input_lengths_, &allocator_); + + const int input_data_len = input_tensors->at("input_ids").shape[1]; + h_total_output_lengths_ = reinterpret_cast(malloc(request_batch_size * sizeof(uint32_t))); + for (int i = 0; i < request_batch_size; ++i) { + h_total_output_lengths_[i] = + reinterpret_cast(input_tensors->at("request_output_len").data)[i] + input_data_len; + } std::unordered_map ft_input_tensors{ {"input_ids", as_GPU_tensor(input_tensors->at("input_ids"), d_input_ids_)}, {"input_lengths", as_GPU_tensor(input_tensors->at("input_lengths"), d_input_lengths_)}, - {"max_output_seq_len", ft::Tensor{ft::MEMORY_CPU, ft::TYPE_INT32, {1}, &h_total_output_len_}}}; + {"output_seq_len", + ft::Tensor{ft::MEMORY_CPU, + ft::TYPE_UINT32, + {input_tensors->at("request_output_len").shape[0]}, + h_total_output_lengths_}}}; - if (input_tensors->count("prefix_soft_prompt_embedding") && input_tensors->count("prefix_soft_prompt_lengths")) { + if (input_tensors->count("request_prompt_embedding") && input_tensors->count("request_prompt_lengths") + && input_tensors->count("request_prompt_type")) { - move_tensor_H2D(input_tensors->at("prefix_soft_prompt_lengths"), d_prefix_soft_prompt_lengths_); + move_tensor_H2D(input_tensors->at("request_prompt_lengths"), d_request_prompt_lengths_, &allocator_); ft_input_tensors.insert( - {"prefix_soft_prompt_lengths", - as_GPU_tensor(input_tensors->at("prefix_soft_prompt_lengths"), d_prefix_soft_prompt_lengths_)}); + {"request_prompt_lengths", + as_GPU_tensor(input_tensors->at("request_prompt_lengths"), d_request_prompt_lengths_)}); - 
move_tensor_H2D(input_tensors->at("prefix_soft_prompt_embedding"), d_prefix_soft_prompt_embedding_); + move_tensor_H2D(input_tensors->at("request_prompt_embedding"), d_request_prompt_embedding_, &allocator_); ft_input_tensors.insert( - {"prefix_soft_prompt_embedding", - as_GPU_tensor(input_tensors->at("prefix_soft_prompt_embedding"), d_prefix_soft_prompt_embedding_)}); + {"request_prompt_embedding", + as_GPU_tensor(input_tensors->at("request_prompt_embedding"), d_request_prompt_embedding_)}); } if (input_tensors->find("bad_words_list") != input_tensors->end()) { - move_tensor_H2D(input_tensors->at("bad_words_list"), d_input_bad_words_); + move_tensor_H2D(input_tensors->at("bad_words_list"), d_input_bad_words_, &allocator_); ft_input_tensors.insert( {"bad_words_list", as_GPU_tensor(input_tensors->at("bad_words_list"), d_input_bad_words_)}); } if (input_tensors->find("stop_words_list") != input_tensors->end()) { - move_tensor_H2D(input_tensors->at("stop_words_list"), d_input_stop_words_); + move_tensor_H2D(input_tensors->at("stop_words_list"), d_input_stop_words_, &allocator_); ft_input_tensors.insert( {"stop_words_list", as_GPU_tensor(input_tensors->at("stop_words_list"), d_input_stop_words_)}); } for (auto t = input_tensors->begin(); t != input_tensors->end(); ++t) { if (t->first.find("input_ids") == std::string::npos && t->first.find("input_lengths") == std::string::npos - && t->first.find("max_output_seq_len") == std::string::npos - && t->first.find("prefix_soft_prompt_embedding") == std::string::npos - && t->first.find("prefix_soft_prompt_lengths") == std::string::npos) { + && t->first.find("output_seq_len") == std::string::npos + && t->first.find("request_prompt_embedding") == std::string::npos + && t->first.find("request_prompt_lengths") == std::string::npos) { ft_input_tensors.insert({t->first, t->second.convertTritonTensorToFt()}); } } @@ -123,28 +145,45 @@ std::shared_ptr> ParallelGptTrit "input_tensors->at(\"input_ids\").shape.size() == 2"); ft::FT_CHECK_WITH_INFO(input_tensors->at("input_lengths").shape.size() == 1, "input_tensors->at(\"input_lengths\").shape.size() == 1"); - const size_t request_batch_size = input_tensors->at("input_ids").shape[0]; - const size_t request_output_len = (size_t)(*(uint*)input_tensors->at("request_output_len").data); - const size_t beam_width = - input_tensors->count("beam_width") ? (size_t)(*(uint*)input_tensors->at("beam_width").data) : 1; - h_total_output_len_ = request_output_len + input_tensors->at("input_ids").shape[1]; - - freeBuffer(); // free buffer of previous iteration - allocateBuffer(request_batch_size, beam_width, h_total_output_len_, request_output_len); + const size_t request_batch_size = input_tensors->at("input_ids").shape[0]; + size_t max_request_output_len = (size_t)*std::max_element( + (int*)input_tensors->at("request_output_len").data, + (int*)input_tensors->at("request_output_len").data + input_tensors->at("request_output_len").shape[0]); + size_t beam_width = input_tensors->count("beam_width") ? (size_t)(*(uint*)input_tensors->at("beam_width").data) : 1; + + size_t total_length = max_request_output_len + input_tensors->at("input_ids").shape[1]; + + if (beam_width != 1 && beam_width != 2 && beam_width != 3 && beam_width != 4 && beam_width != 8 && beam_width != 16 + && beam_width != 32) { + FT_LOG_WARNING("beam_width = %ld is invalid. 
Set it to 1 to use sampling by default.", beam_width); + beam_width = 1; + } std::unordered_map ft_input_tensors = convert_inputs(input_tensors); + // If input_tensors don't contain "START" flag, then it is non-interactive generation, allocate buffer directly. + // If input_tensors contains "START" flag, then only allocate buffer when "START == 1". + if (ft_input_tensors.count("START") == 0 + || (ft_input_tensors.count("START") && ft_input_tensors.at("START").getVal() == 1)) { + if (ft_input_tensors.count("session_len")) { + total_length = ft_input_tensors.at("session_len").getVal(); + max_request_output_len = ft_input_tensors.at("session_len").getVal(); + } + size_t max_prefix_soft_prompt_length = 0; + if (input_tensors->count("request_prompt_lengths")) { + max_prefix_soft_prompt_length = + (size_t)*std::max_element((int*)input_tensors->at("request_prompt_lengths").data, + (int*)input_tensors->at("request_prompt_lengths").data + request_batch_size); + } + total_length += max_prefix_soft_prompt_length; + allocateBuffer(request_batch_size, beam_width, total_length, max_request_output_len); + } std::unordered_map output_tensors = std::unordered_map{ {"output_ids", ft::Tensor{ft::MEMORY_GPU, ft::TYPE_UINT32, - std::vector{request_batch_size, beam_width, (size_t)h_total_output_len_}, + std::vector{request_batch_size, beam_width, total_length}, d_output_ids_}}, - {"parent_ids", - ft::Tensor{ft::MEMORY_GPU, - ft::TYPE_INT32, - std::vector{(size_t)h_total_output_len_, request_batch_size, beam_width}, - d_parent_ids_}}, {"sequence_length", ft::Tensor{ft::MEMORY_GPU, ft::TYPE_INT32, @@ -155,7 +194,7 @@ std::shared_ptr> ParallelGptTrit output_tensors.insert({"output_log_probs", ft::Tensor{ft::MEMORY_GPU, ft::TYPE_FP32, - std::vector{request_batch_size, beam_width, request_output_len}, + std::vector{request_batch_size, beam_width, max_request_output_len}, d_output_log_probs_}}); output_tensors.insert({"cum_log_probs", ft::Tensor{ft::MEMORY_GPU, @@ -163,31 +202,15 @@ std::shared_ptr> ParallelGptTrit std::vector{request_batch_size, beam_width}, d_cum_log_probs_}}); } - gpt_->forward(&output_tensors, &ft_input_tensors, gpt_weight_.get()); - if (d_input_ids_ != nullptr) { - ft::check_cuda_error(cudaFree(d_input_ids_)); - d_input_ids_ = nullptr; - } - if (d_input_lengths_ != nullptr) { - ft::check_cuda_error(cudaFree(d_input_lengths_)); - d_input_lengths_ = nullptr; + if (stream_cb_ != nullptr) { + gpt_->registerCallback(triton_stream_callback, this); } - if (d_prefix_soft_prompt_embedding_ != nullptr) { - ft::check_cuda_error(cudaFree(d_prefix_soft_prompt_embedding_)); - d_prefix_soft_prompt_embedding_ = nullptr; - } - if (d_prefix_soft_prompt_lengths_ != nullptr) { - ft::check_cuda_error(cudaFree(d_prefix_soft_prompt_lengths_)); - d_prefix_soft_prompt_lengths_ = nullptr; - } - if (d_input_bad_words_ != nullptr) { - ft::check_cuda_error(cudaFree(d_input_bad_words_)); - d_input_bad_words_ = nullptr; - } - if (d_input_stop_words_ != nullptr) { - ft::check_cuda_error(cudaFree(d_input_stop_words_)); - d_input_stop_words_ = nullptr; + + gpt_->forward(&output_tensors, &ft_input_tensors, gpt_weight_.get()); + + if (stream_cb_ != nullptr) { + gpt_->unRegisterCallback(); } return convert_outputs(output_tensors); @@ -205,25 +228,27 @@ void ParallelGptTritonModelInstance::allocateBuffer(const size_t request_batc const size_t total_output_len, const size_t request_output_len) { - ft::deviceMalloc(&d_output_ids_, request_batch_size * beam_width * total_output_len); - ft::deviceMalloc(&d_parent_ids_, 
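As a sanity check on the sizing above, take a hypothetical request with batch size 2, beam width 4, 20 input tokens, max(request_output_len) = 64 and a 16-token prefix prompt:

    total_length        = 64 + 20 + 16 = 100
    d_output_ids_       : 2 * 4 * 100 int32 values   (output_ids is [batch, beam, total_length])
    d_output_log_probs_ : 2 * 4 * 64  fp32 values    (one per generated token)
    d_cum_log_probs_    : 2 * 4       fp32 values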
request_batch_size * beam_width * total_output_len); - ft::deviceMalloc(&d_sequence_lengths_, request_batch_size * beam_width); - ft::deviceMalloc(&d_output_log_probs_, request_output_len * request_batch_size * beam_width); - ft::deviceMalloc(&d_cum_log_probs_, request_batch_size * beam_width); + d_output_ids_ = (int*)(allocator_->reMalloc( + d_output_ids_, sizeof(int) * request_batch_size * beam_width * total_output_len, false)); + d_sequence_lengths_ = + (int*)(allocator_->reMalloc(d_sequence_lengths_, sizeof(int) * request_batch_size * beam_width, false)); + d_output_log_probs_ = (float*)(allocator_->reMalloc( + d_output_log_probs_, sizeof(float) * request_batch_size * beam_width * request_output_len, false)); + d_cum_log_probs_ = + (float*)(allocator_->reMalloc(d_cum_log_probs_, sizeof(float) * request_batch_size * beam_width, false)); } template void ParallelGptTritonModelInstance::freeBuffer() { - ft::deviceFree(d_output_ids_); - ft::deviceFree(d_parent_ids_); - ft::deviceFree(d_sequence_lengths_); - ft::deviceFree(d_output_log_probs_); - ft::deviceFree(d_cum_log_probs_); + allocator_->free((void**)(&d_output_ids_)); + allocator_->free((void**)(&d_sequence_lengths_)); + allocator_->free((void**)(&d_output_log_probs_)); + allocator_->free((void**)(&d_cum_log_probs_)); } template struct ParallelGptTritonModelInstance; template struct ParallelGptTritonModelInstance; #ifdef ENABLE_BF16 template struct ParallelGptTritonModelInstance<__nv_bfloat16>; -#endif \ No newline at end of file +#endif diff --git a/src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModelInstance.h b/src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModelInstance.h index da7f4a559..3751a9910 100644 --- a/src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModelInstance.h +++ b/src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModelInstance.h @@ -26,13 +26,13 @@ namespace ft = fastertransformer; template struct ParallelGptTritonModelInstance: AbstractTransformerModelInstance { - ParallelGptTritonModelInstance(std::unique_ptr> gpt, - std::unique_ptr> gpt_weight, + ParallelGptTritonModelInstance(std::unique_ptr> gpt, + std::shared_ptr> gpt_weight, std::unique_ptr> allocator, - std::unique_ptr cublas_algo_map, - std::unique_ptr cublas_wrapper_mutex, - std::unique_ptr cublas_wrapper, - std::unique_ptr cuda_device_prop_ptr); + std::unique_ptr cublas_algo_map, + std::unique_ptr cublas_wrapper_mutex, + std::unique_ptr cublas_wrapper, + std::unique_ptr cuda_device_prop_ptr); ~ParallelGptTritonModelInstance(); std::shared_ptr> @@ -41,19 +41,20 @@ struct ParallelGptTritonModelInstance: AbstractTransformerModelInstance { std::shared_ptr> forward(std::shared_ptr> input_tensors) override; + static std::shared_ptr> + convert_outputs(const std::unordered_map& output_tensors); + private: - const std::unique_ptr> gpt_; - const std::unique_ptr> gpt_weight_; + const std::unique_ptr> gpt_; + const std::shared_ptr> gpt_weight_; const std::unique_ptr> allocator_; - const std::unique_ptr cublas_algo_map_; - const std::unique_ptr cublas_wrapper_mutex_; - const std::unique_ptr cublas_wrapper_; - const std::unique_ptr cuda_device_prop_ptr_; + const std::unique_ptr cublas_algo_map_; + const std::unique_ptr cublas_wrapper_mutex_; + const std::unique_ptr cublas_wrapper_; + const std::unique_ptr cuda_device_prop_ptr_; std::unordered_map convert_inputs(std::shared_ptr> input_tensors); - std::shared_ptr> - convert_outputs(const std::unordered_map& output_tensors); void 
allocateBuffer(const size_t request_batch_size, const size_t beam_width, @@ -61,18 +62,17 @@ struct ParallelGptTritonModelInstance: AbstractTransformerModelInstance { const size_t request_output_len); void freeBuffer(); - int* d_input_ids_ = nullptr; - int* d_input_lengths_ = nullptr; - int* d_prefix_soft_prompt_lengths_ = nullptr; - int* d_input_bad_words_ = nullptr; - int* d_input_stop_words_ = nullptr; - float* d_prefix_soft_prompt_embedding_ = nullptr; + int* d_input_ids_ = nullptr; + int* d_input_lengths_ = nullptr; + int* d_request_prompt_lengths_ = nullptr; + int* d_input_bad_words_ = nullptr; + int* d_input_stop_words_ = nullptr; + T* d_request_prompt_embedding_ = nullptr; - int* d_output_ids_ = nullptr; - int* d_parent_ids_ = nullptr; - int* d_sequence_lengths_ = nullptr; + int* d_output_ids_ = nullptr; + int* d_sequence_lengths_ = nullptr; float* d_output_log_probs_ = nullptr; - float* d_cum_log_probs_ = nullptr; + float* d_cum_log_probs_ = nullptr; - int h_total_output_len_; + uint32_t* h_total_output_lengths_ = nullptr; }; diff --git a/src/fastertransformer/triton_backend/t5/CMakeLists.txt b/src/fastertransformer/triton_backend/t5/CMakeLists.txt index f6638868a..2d221539b 100644 --- a/src/fastertransformer/triton_backend/t5/CMakeLists.txt +++ b/src/fastertransformer/triton_backend/t5/CMakeLists.txt @@ -20,5 +20,5 @@ set(t5_triton_backend_files ) add_library(T5TritonBackend SHARED ${t5_triton_backend_files}) -target_link_libraries(T5TritonBackend PRIVATE T5Encoder T5Decoding) +target_link_libraries(T5TritonBackend PRIVATE TransformerTritonBackend T5Encoder T5Decoding) target_compile_features(T5TritonBackend PRIVATE cxx_std_14) diff --git a/src/fastertransformer/triton_backend/t5/T5TritonModel.cc b/src/fastertransformer/triton_backend/t5/T5TritonModel.cc index 29b2ce6b3..85138fa6f 100644 --- a/src/fastertransformer/triton_backend/t5/T5TritonModel.cc +++ b/src/fastertransformer/triton_backend/t5/T5TritonModel.cc @@ -30,56 +30,69 @@ std::shared_ptr AbstractTransformerModel::createT5Mode return nullptr; } - const int is_half = reader.GetInteger("ft_instance_hyperparameter", "is_half"); - if (is_half) { + const std::string data_type = reader.Get("ft_instance_hyperparameter", "data_type"); + if (data_type == "fp16") { return std::make_shared>(reader, model_dir); } - else { +#ifdef ENABLE_BF16 + else if (data_type == "bf16") { + return std::make_shared>(reader, model_dir); + } +#endif + else if (data_type == "fp32") { return std::make_shared>(reader, model_dir); } + else { + FT_LOG_ERROR("Unsupported data type " + data_type); + exit(-1); + } } template T5TritonModel::T5TritonModel(INIReader reader, std::string model_dir): model_dir_(model_dir) { // encoder - encoder_head_num_ = reader.GetInteger("encoder", "num_heads"); + encoder_head_num_ = reader.GetInteger("encoder", "num_heads"); encoder_size_per_head_ = reader.GetInteger("encoder", "d_kv"); - encoder_d_model_ = reader.GetInteger("encoder", "d_model"); - encoder_inter_size_ = reader.GetInteger("encoder", "d_ff"); - encoder_num_layer_ = reader.GetInteger("encoder", "num_layers"); - encoder_vocab_size_ = reader.GetInteger("encoder", "vocab_size"); + encoder_d_model_ = reader.GetInteger("encoder", "d_model"); + encoder_inter_size_ = reader.GetInteger("encoder", "d_ff"); + encoder_num_layer_ = reader.GetInteger("encoder", "num_layers"); + encoder_vocab_size_ = reader.GetInteger("encoder", "vocab_size"); encoder_num_bucket_or_max_pos_seq_len_ = reader.GetInteger("encoder", "relative_attention_num_buckets_or_max_pos_seq_len"); // 
decoding - decoding_head_num_ = reader.GetInteger("decoder", "num_heads"); + decoding_head_num_ = reader.GetInteger("decoder", "num_heads"); decoding_size_per_head_ = reader.GetInteger("decoder", "d_kv"); - decoding_d_model_ = reader.GetInteger("decoder", "d_model"); - decoding_inter_size_ = reader.GetInteger("decoder", "d_ff"); - decoding_num_layer_ = reader.GetInteger("decoder", "num_layers"); - decoding_vocab_size_ = reader.GetInteger("decoder", "vocab_size"); + decoding_d_model_ = reader.GetInteger("decoder", "d_model"); + decoding_inter_size_ = reader.GetInteger("decoder", "d_ff"); + decoding_num_layer_ = reader.GetInteger("decoder", "num_layers"); + decoding_vocab_size_ = reader.GetInteger("decoder", "vocab_size"); decoding_num_bucket_or_max_pos_seq_len_ = reader.GetInteger("decoder", "relative_attention_num_buckets_or_max_pos_seq_len"); - start_id_ = reader.GetInteger("decoder", "decoder_start_token_id"); - end_id_ = reader.GetInteger("decoder", "eos_token_id"); - tensor_para_size_ = reader.GetInteger("ft_instance_hyperparameter", "tensor_para_size"); - pipeline_para_size_ = reader.GetInteger("ft_instance_hyperparameter", "pipeline_para_size"); + start_id_ = reader.GetInteger("decoder", "decoder_start_token_id"); + end_id_ = reader.GetInteger("decoder", "eos_token_id"); + tensor_para_size_ = reader.GetInteger("ft_instance_hyperparameter", "tensor_para_size"); + pipeline_para_size_ = reader.GetInteger("ft_instance_hyperparameter", "pipeline_para_size"); enable_custom_all_reduce_ = reader.GetInteger("ft_instance_hyperparameter", "enable_custom_all_reduce", 0); - t5_with_bias_ = (bool)reader.GetInteger("structure", "t5_with_bias", 0); - position_embedding_type_ = ft::PositionEmbeddingType(reader.GetInteger("structure", "position_embedding_type", 0)); - q_scaling_ = t5_with_bias_ ? 1.0f : (1.0f / (sqrt(encoder_size_per_head_) * 1.0f)); + t5_with_bias_ = reader.GetBoolean("structure", "t5_with_bias", false); + use_gated_activation_ = reader.GetBoolean("structure", "use_gated_activation", false); + position_embedding_type_ = + ft::PositionEmbeddingType(reader.Get("structure", "position_embedding_type", "relative") == "relative" ? 0 : 1); + q_scaling_ = t5_with_bias_ ? 
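For the INI-driven T5 constructor above, data_type is now a string and position_embedding_type is read as "relative" or "absolute" instead of an integer. A minimal config.ini sketch with placeholder, t5-small-like values (only keys read in the code above are shown):

    [ft_instance_hyperparameter]
    data_type=fp16                   ; fp32 / fp16 / bf16 (bf16 requires ENABLE_BF16)
    tensor_para_size=1
    pipeline_para_size=1
    enable_custom_all_reduce=0

    [encoder]
    num_heads=8
    d_kv=64
    d_model=512
    d_ff=2048
    num_layers=6
    vocab_size=32128
    relative_attention_num_buckets_or_max_pos_seq_len=32

    [decoder]
    num_heads=8
    d_kv=64
    d_model=512
    d_ff=2048
    num_layers=6
    vocab_size=32128
    relative_attention_num_buckets_or_max_pos_seq_len=32
    decoder_start_token_id=0
    eos_token_id=1

    [structure]
    t5_with_bias=0
    use_gated_activation=0
    position_embedding_type=relative ; relative or absolute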
1.0f : (1.0f / (sqrt(encoder_size_per_head_) * 1.0f)); max_distance_ = 128; // use default value of huggingface here } template -T5TritonModel::T5TritonModel(size_t tensor_para_size, - size_t pipeline_para_size, - int enable_custom_all_reduce, +T5TritonModel::T5TritonModel(size_t tensor_para_size, + size_t pipeline_para_size, + int enable_custom_all_reduce, std::string model_dir, - int int8_mode): + int int8_mode): tensor_para_size_(tensor_para_size), pipeline_para_size_(pipeline_para_size), + encoder_shared_weights_(std::vector>>(ft::getDeviceCount())), + decoding_shared_weights_(std::vector>>(ft::getDeviceCount())), enable_custom_all_reduce_(enable_custom_all_reduce), model_dir_(model_dir), int8_mode_(int8_mode) @@ -95,29 +108,33 @@ T5TritonModel::T5TritonModel(size_t tensor_para_size, model_name_ = reader.Get("encoder", "_name_or_path"); // encoder - encoder_head_num_ = reader.GetInteger("encoder", "num_heads"); + encoder_head_num_ = reader.GetInteger("encoder", "num_heads"); encoder_size_per_head_ = reader.GetInteger("encoder", "d_kv"); - encoder_d_model_ = reader.GetInteger("encoder", "d_model"); - encoder_inter_size_ = reader.GetInteger("encoder", "d_ff"); - encoder_num_layer_ = reader.GetInteger("encoder", "num_layers"); - encoder_vocab_size_ = reader.GetInteger("encoder", "vocab_size"); + encoder_d_model_ = reader.GetInteger("encoder", "d_model"); + encoder_inter_size_ = reader.GetInteger("encoder", "d_ff"); + encoder_num_layer_ = reader.GetInteger("encoder", "num_layers"); + encoder_vocab_size_ = reader.GetInteger("encoder", "vocab_size"); encoder_num_bucket_or_max_pos_seq_len_ = reader.GetInteger("encoder", "relative_attention_num_buckets_or_max_pos_seq_len"); // decoding - decoding_head_num_ = reader.GetInteger("decoder", "num_heads"); + decoding_head_num_ = reader.GetInteger("decoder", "num_heads"); decoding_size_per_head_ = reader.GetInteger("decoder", "d_kv"); - decoding_d_model_ = reader.GetInteger("decoder", "d_model"); - decoding_inter_size_ = reader.GetInteger("decoder", "d_ff"); - decoding_num_layer_ = reader.GetInteger("decoder", "num_layers"); - decoding_vocab_size_ = reader.GetInteger("decoder", "vocab_size"); + decoding_d_model_ = reader.GetInteger("decoder", "d_model"); + decoding_inter_size_ = reader.GetInteger("decoder", "d_ff"); + decoding_num_layer_ = reader.GetInteger("decoder", "num_layers"); + decoding_vocab_size_ = reader.GetInteger("decoder", "vocab_size"); decoding_num_bucket_or_max_pos_seq_len_ = reader.GetInteger("decoder", "relative_attention_num_buckets_or_max_pos_seq_len"); - start_id_ = reader.GetInteger("decoder", "decoder_start_token_id"); - end_id_ = reader.GetInteger("decoder", "eos_token_id"); - - t5_with_bias_ = (bool)reader.GetInteger("structure", "t5_with_bias", 0); - position_embedding_type_ = ft::PositionEmbeddingType(reader.GetInteger("structure", "position_embedding_type", 0)); + start_id_ = reader.GetInteger("decoder", "decoder_start_token_id"); + end_id_ = reader.GetInteger("decoder", "eos_token_id"); + tie_word_embeddings_ = reader.GetBoolean("decoder", "tie_word_embeddings", true); + + t5_with_bias_ = reader.GetBoolean("structure", "t5_with_bias", false); + use_gated_activation_ = reader.GetBoolean("structure", "use_gated_activation", false); + activation_type_ = ft::getActivationType(reader.Get("encoder", "feed_forward_proj")); + position_embedding_type_ = + ft::PositionEmbeddingType(reader.Get("structure", "position_embedding_type", "relative") == "relative" ? 0 : 1); q_scaling_ = t5_with_bias_ ? 
1.0f : (1.0f / (sqrt(encoder_size_per_head_) * 1.0f)); max_distance_ = 128; // use default value of huggingface here @@ -125,30 +142,29 @@ T5TritonModel::T5TritonModel(size_t tensor_para_size, template std::unique_ptr -T5TritonModel::createModelInstance(int device_id, - int rank, - cudaStream_t stream, - std::pair, std::vector> nccl_comms, +T5TritonModel::createModelInstance(int device_id, + int rank, + cudaStream_t stream, + std::pair, std::vector> nccl_params, std::shared_ptr custom_all_reduce_comm) { ft::check_cuda_error(cudaSetDevice(device_id)); - const int tensor_para_rank = rank % tensor_para_size_; - const int pipeline_para_rank = rank / tensor_para_size_; + const int comms_rank = device_id % (tensor_para_size_ * pipeline_para_size_); std::unique_ptr> allocator( new ft::Allocator(device_id)); allocator->setStream(stream); - cublasHandle_t cublas_handle; + cublasHandle_t cublas_handle; cublasLtHandle_t cublaslt_handle; cublasCreate(&cublas_handle); cublasLtCreate(&cublaslt_handle); cublasSetStream(cublas_handle, stream); - std::unique_ptr cublas_algo_map(new ft::cublasAlgoMap("gemm_config.in")); - std::unique_ptr cublas_wrapper_mutex(new std::mutex()); + std::unique_ptr cublas_algo_map(new ft::cublasAlgoMap("gemm_config.in")); + std::unique_ptr cublas_wrapper_mutex(new std::mutex()); std::unique_ptr cublas_wrapper(new ft::cublasMMWrapper( cublas_handle, cublaslt_handle, stream, cublas_algo_map.get(), cublas_wrapper_mutex.get(), allocator.get())); @@ -158,6 +174,11 @@ T5TritonModel::createModelInstance(int device_id, if (std::is_same::value) { cublas_wrapper->setGemmConfig(CUDA_R_16F, CUDA_R_16F, CUDA_R_16F, CUDA_R_32F); } +#ifdef ENABLE_BF16 + else if (std::is_same::value) { + cublas_wrapper->setBF16GemmConfig(); + } +#endif else if (std::is_same::value) { cublas_wrapper->setFP32GemmConfig(); } @@ -165,116 +186,73 @@ T5TritonModel::createModelInstance(int device_id, const int sm_ = ft::getSMVersion(); // TODO(bhsueh) not support fused mha - // ft::AttentionType attention_type = - // ft::getAttentionType(encoder_size_per_head_, sm_, true, max_encoder_seq_len_, false); - ft::AttentionType attention_type = ft::AttentionType::UNFUSED_MHA; - - ft::NcclParam tensor_para_; - ft::NcclParam pipeline_para_; - - tensor_para_.world_size_ = tensor_para_size_; - tensor_para_.rank_ = tensor_para_rank; - tensor_para_.nccl_comm_ = nccl_comms.first[device_id]; - pipeline_para_.world_size_ = pipeline_para_size_; - pipeline_para_.rank_ = pipeline_para_rank; - pipeline_para_.nccl_comm_ = nccl_comms.second[device_id]; - - auto encoder = std::make_unique>( - ft::T5Encoder(0, - 0, - encoder_head_num_, - encoder_size_per_head_, - encoder_inter_size_, - encoder_d_model_, - encoder_num_layer_, - encoder_num_bucket_or_max_pos_seq_len_, - max_distance_, - sm_, - q_scaling_, - stream, - cublas_wrapper.get(), - allocator.get(), - false, - attention_type, - false, - t5_with_bias_ ? 
ft::ActivationType::Gelu : ft::ActivationType::Relu, - ft::LayerNormType::pre_layernorm, - tensor_para_, - pipeline_para_, - custom_all_reduce_comm, - enable_custom_all_reduce_)); - - auto decoding = std::make_unique>( - ft::T5Decoding(0, - 0, - 0, - 0, - decoding_head_num_, - decoding_size_per_head_, - decoding_inter_size_, - decoding_d_model_, - decoding_num_layer_, - decoding_vocab_size_, - decoding_num_bucket_or_max_pos_seq_len_, - max_distance_, - q_scaling_, - start_id_, - end_id_, - 0.0f, // beam_search_diversity_rate_, - 1, // top_k_, - 0.0f, // top_p_, - 1.0f, // temperature_, - 1.0f, // len_penalty_, - 1.0f, // repetition_penalty_, - stream, - cublas_wrapper.get(), - allocator.get(), - false, - cuda_device_prop_ptr.get(), - tensor_para_, - pipeline_para_, - t5_with_bias_ ? ft::ActivationType::Gelu : ft::ActivationType::Relu, - custom_all_reduce_comm, - enable_custom_all_reduce_)); - - auto encoder_weight = - std::unique_ptr>(new ft::T5EncoderWeight(encoder_head_num_, - encoder_size_per_head_, - encoder_d_model_, - encoder_inter_size_, - encoder_vocab_size_, - encoder_num_layer_, - encoder_num_bucket_or_max_pos_seq_len_, - tensor_para_.world_size_, - tensor_para_.rank_, - pipeline_para_.world_size_, - pipeline_para_.rank_, - t5_with_bias_, - position_embedding_type_)); - - auto decoding_weight = - std::unique_ptr>(new ft::T5DecodingWeight(decoding_head_num_, - decoding_size_per_head_, - decoding_d_model_, - decoding_inter_size_, - decoding_vocab_size_, - decoding_num_layer_, - encoder_d_model_, - decoding_num_bucket_or_max_pos_seq_len_, - tensor_para_.world_size_, - tensor_para_.rank_, - pipeline_para_.world_size_, - pipeline_para_.rank_, - t5_with_bias_, - position_embedding_type_)); - - encoder_weight->loadModel(model_dir_); - decoding_weight->loadModel(model_dir_); + ft::AttentionType attention_type = + ft::getAttentionType(encoder_size_per_head_, sm_, true, encoder_num_bucket_or_max_pos_seq_len_, false); + + ft::NcclParam tensor_para_ = nccl_params.first[comms_rank]; + ft::NcclParam pipeline_para_ = nccl_params.first[comms_rank]; + + auto encoder = std::make_unique>(ft::T5Encoder(0, + 0, + encoder_head_num_, + encoder_size_per_head_, + encoder_inter_size_, + encoder_d_model_, + encoder_num_layer_, + encoder_num_bucket_or_max_pos_seq_len_, + max_distance_, + sm_, + q_scaling_, + stream, + cublas_wrapper.get(), + allocator.get(), + false, + attention_type, + false, + activation_type_, + ft::LayerNormType::pre_layernorm, + tensor_para_, + pipeline_para_, + custom_all_reduce_comm, + enable_custom_all_reduce_)); + + auto decoding = std::make_unique>(ft::T5Decoding(0, + 0, + 0, + 0, + decoding_head_num_, + decoding_size_per_head_, + decoding_inter_size_, + decoding_d_model_, + decoding_num_layer_, + decoding_vocab_size_, + decoding_num_bucket_or_max_pos_seq_len_, + max_distance_, + q_scaling_, + start_id_, + end_id_, + 0.0f, // beam_search_diversity_rate_, + 1, // top_k_, + 0.0f, // top_p_, + 1.0f, // temperature_, + 0.0f, // len_penalty_, + 1.0f, // repetition_penalty_, + stream, + cublas_wrapper.get(), + allocator.get(), + false, + cuda_device_prop_ptr.get(), + tensor_para_, + pipeline_para_, + activation_type_, + tie_word_embeddings_, + custom_all_reduce_comm, + enable_custom_all_reduce_)); return std::unique_ptr>(new T5TritonModelInstance(std::move(encoder), std::move(decoding), - std::move(encoder_weight), - std::move(decoding_weight), + encoder_shared_weights_[device_id], + decoding_shared_weights_[device_id], std::move(allocator), std::move(cublas_algo_map), 
std::move(cublas_wrapper_mutex), @@ -282,11 +260,56 @@ T5TritonModel::createModelInstance(int device_id, std::move(cuda_device_prop_ptr))); } +template +void T5TritonModel::createSharedWeights(int device_id, int rank) +{ + ft::check_cuda_error(cudaSetDevice(device_id)); + const int tensor_para_rank = rank % tensor_para_size_; + const int pipeline_para_rank = rank / tensor_para_size_; + + encoder_shared_weights_[device_id] = + std::make_shared>(encoder_head_num_, + encoder_size_per_head_, + encoder_d_model_, + encoder_inter_size_, + encoder_vocab_size_, + encoder_num_layer_, + encoder_num_bucket_or_max_pos_seq_len_, + tensor_para_size_, + tensor_para_rank, + pipeline_para_size_, + pipeline_para_rank, + t5_with_bias_, + use_gated_activation_, + position_embedding_type_); + + decoding_shared_weights_[device_id] = + std::make_shared>(decoding_head_num_, + decoding_size_per_head_, + decoding_d_model_, + decoding_inter_size_, + decoding_vocab_size_, + decoding_num_layer_, + encoder_d_model_, + decoding_num_bucket_or_max_pos_seq_len_, + tensor_para_size_, + tensor_para_rank, + pipeline_para_size_, + pipeline_para_rank, + t5_with_bias_, + use_gated_activation_, + position_embedding_type_); + + encoder_shared_weights_[device_id]->loadModel(model_dir_); + decoding_shared_weights_[device_id]->loadModel(model_dir_); + return; +} + template std::string T5TritonModel::toString() { std::stringstream ss; - std::string position_embedding_type_string = + std::string position_embedding_type_string = position_embedding_type_ == ft::PositionEmbeddingType::relative ? "relative" : "absolute"; ss << "\nModel: " @@ -299,7 +322,7 @@ std::string T5TritonModel::toString() << "\n decoding_d_model_: " << decoding_d_model_ << "\n decoding_inter_size_: " << decoding_inter_size_ << "\n decoding_num_layer_: " << decoding_num_layer_ << "\n decoding_vocab_size_: " << decoding_vocab_size_ << "\n decoding_num_bucket_or_max_pos_seq_len_: " << decoding_num_bucket_or_max_pos_seq_len_ - << "\n t5_with_bias_: " << t5_with_bias_ + << "\n t5_with_bias_: " << t5_with_bias_ << "\n use_gated_activation_: " << use_gated_activation_ << "\n position_embedding_type_: " << position_embedding_type_string << "\n start_id_: " << start_id_ << "\n end_id_: " << end_id_ << "\n model_name_: " << model_name_ << "\n model_dir_: " << model_dir_ << std::endl; @@ -307,77 +330,9 @@ std::string T5TritonModel::toString() return ss.str(); } -template -std::vector T5TritonModel::createNcclIds(const uint32_t world_size, bool multi_instances) -{ - std::vector nccl_ids(tensor_para_size_ + pipeline_para_size_); - if (multi_instances) { - if (tensor_para_size_ * pipeline_para_size_ != 1) { - printf( - "[ERROR] Multiple Instances currently only support tensor_para_size_ and pipeline_para_size_ both 1\n"); - ft::FT_CHECK(tensor_para_size_ == 1 && pipeline_para_size_ == 1); - } - nccl_ids.resize(2); - } - else { - if (world_size != tensor_para_size_ * pipeline_para_size_) { - printf( - "[ERROR] world_size (%d) should equal to tensor_para_size_ * pipeline_para_size_ (%ld * %ld here) \n", - world_size, - tensor_para_size_, - pipeline_para_size_); - ft::FT_CHECK(world_size == tensor_para_size_ * pipeline_para_size_); - } - } - - for (uint32_t i = 0; i < nccl_ids.size(); i++) { - NCCLCHECK(ncclGetUniqueId(&nccl_ids[i])); - } - return nccl_ids; -} - -template -std::pair, std::vector> T5TritonModel::createNcclComms( - std::vector nccl_ids, const int node_id, bool multi_instances, int instance_id) -{ - const int gpu_count = ft::getDeviceCount(); - std::vector 
tensor_para_comms(gpu_count); - std::vector pipeline_para_comms(gpu_count); - if (multi_instances) { - ncclUniqueId tensor_para_nccl_uid = nccl_ids[0]; - ncclUniqueId pipeline_para_nccl_uid = nccl_ids[1]; - size_t tensor_para_rank = 0; - size_t pipeline_para_rank = 0; - - ft::check_cuda_error(cudaSetDevice(instance_id)); - NCCLCHECK(ncclCommInitRank( - &tensor_para_comms[instance_id], tensor_para_size_, tensor_para_nccl_uid, tensor_para_rank)); - NCCLCHECK(ncclCommInitRank( - &pipeline_para_comms[instance_id], pipeline_para_size_, pipeline_para_nccl_uid, pipeline_para_rank)); - } - else { - NCCLCHECK(ncclGroupStart()); - for (int gid = 0; gid < gpu_count; gid++) { - int rank = node_id * gpu_count + gid; - size_t tensor_para_rank = rank % tensor_para_size_; - size_t pipeline_para_rank = rank / tensor_para_size_; - ncclUniqueId tensor_para_nccl_uid = nccl_ids[pipeline_para_rank]; - ncclUniqueId pipeline_para_nccl_uid = nccl_ids[pipeline_para_size_ + tensor_para_rank]; - - ft::check_cuda_error(cudaSetDevice(gid)); - NCCLCHECK( - ncclCommInitRank(&tensor_para_comms[gid], tensor_para_size_, tensor_para_nccl_uid, tensor_para_rank)); - NCCLCHECK(ncclCommInitRank( - &pipeline_para_comms[gid], pipeline_para_size_, pipeline_para_nccl_uid, pipeline_para_rank)); - } - NCCLCHECK(ncclGroupEnd()); - } - return std::pair, std::vector>(tensor_para_comms, pipeline_para_comms); -} - template void T5TritonModel::createCustomComms(std::vector>* custom_all_reduce_comms, - int world_size) + int world_size) { using commDataType = typename ft::CustomARCommTypeConverter::Type; ft::initCustomAllReduceComm(custom_all_reduce_comms, enable_custom_all_reduce_, world_size); @@ -397,3 +352,6 @@ int T5TritonModel::getPipelineParaSize() template struct T5TritonModel; template struct T5TritonModel; +#ifdef ENABLE_BF16 +template struct T5TritonModel<__nv_bfloat16>; +#endif diff --git a/src/fastertransformer/triton_backend/t5/T5TritonModel.h b/src/fastertransformer/triton_backend/t5/T5TritonModel.h index d5803d8ba..09f444e76 100644 --- a/src/fastertransformer/triton_backend/t5/T5TritonModel.h +++ b/src/fastertransformer/triton_backend/t5/T5TritonModel.h @@ -31,34 +31,29 @@ template struct T5TritonModel: public AbstractTransformerModel { T5TritonModel(INIReader reader, std::string model_dir); - T5TritonModel(size_t tensor_para_size, - size_t pipeline_para_size, - int enable_custom_all_reduce, + T5TritonModel(size_t tensor_para_size, + size_t pipeline_para_size, + int enable_custom_all_reduce, std::string model_dir, - int int8_mode); + int int8_mode); ~T5TritonModel() = default; virtual std::unique_ptr - createModelInstance(int deviceId, - int rank, - cudaStream_t stream, - std::pair, std::vector> nccl_comms, + createModelInstance(int deviceId, + int rank, + cudaStream_t stream, + std::pair, std::vector> nccl_params, std::shared_ptr custom_all_reduce_comm = nullptr); - virtual void createCustomComms(std::vector>* custom_all_reduce_comms, - int world_size) override; + virtual void createSharedWeights(int deviceId, int rank) override; - virtual std::pair, std::vector> - createNcclComms(std::vector nccl_ids, - const int node_id, - bool multi_instances = false, - int instance_id = 0) override; + virtual void createCustomComms(std::vector>* custom_all_reduce_comms, + int world_size) override; - virtual std::vector createNcclIds(const uint32_t world_size, bool multi_instances = false) override; virtual std::string toString() override; - virtual int getTensorParaSize() override; - virtual int getPipelineParaSize() override; + virtual 
int getTensorParaSize() override; + virtual int getPipelineParaSize() override; private: // encoder @@ -82,21 +77,29 @@ struct T5TritonModel: public AbstractTransformerModel { float q_scaling_; size_t max_distance_; - int start_id_; - int end_id_; + int start_id_; + int end_id_; + + bool tie_word_embeddings_; size_t tensor_para_size_; size_t pipeline_para_size_; + // shared weights for each device + std::vector>> encoder_shared_weights_; + std::vector>> decoding_shared_weights_; + // t5 structure difference - bool t5_with_bias_; + bool t5_with_bias_; + bool use_gated_activation_; ft::PositionEmbeddingType position_embedding_type_; + ft::ActivationType activation_type_; bool is_fp16_; - int int8_mode_; + int int8_mode_; int enable_custom_all_reduce_ = 0; std::string model_name_; std::string model_dir_; -}; \ No newline at end of file +}; diff --git a/src/fastertransformer/triton_backend/t5/T5TritonModelInstance.cc b/src/fastertransformer/triton_backend/t5/T5TritonModelInstance.cc index 0024126cb..87f1b3180 100644 --- a/src/fastertransformer/triton_backend/t5/T5TritonModelInstance.cc +++ b/src/fastertransformer/triton_backend/t5/T5TritonModelInstance.cc @@ -23,19 +23,19 @@ namespace ft = fastertransformer; template -T5TritonModelInstance::T5TritonModelInstance(std::unique_ptr> t5_encoder, - std::unique_ptr> t5_decoding, - std::unique_ptr> t5_encoder_weight, - std::unique_ptr> t5_decoding_weight, +T5TritonModelInstance::T5TritonModelInstance(std::unique_ptr> t5_encoder, + std::unique_ptr> t5_decoding, + std::shared_ptr> t5_encoder_weight, + std::shared_ptr> t5_decoding_weight, std::unique_ptr> allocator, - std::unique_ptr cublas_algo_map, - std::unique_ptr cublas_wrapper_mutex, + std::unique_ptr cublas_algo_map, + std::unique_ptr cublas_wrapper_mutex, std::unique_ptr cublas_wrapper, - std::unique_ptr cuda_device_prop_ptr): + std::unique_ptr cuda_device_prop_ptr): t5_encoder_(std::move(t5_encoder)), t5_decoding_(std::move(t5_decoding)), - t5_encoder_weight_(std::move(t5_encoder_weight)), - t5_decoding_weight_(std::move(t5_decoding_weight)), + t5_encoder_weight_(t5_encoder_weight), + t5_decoding_weight_(t5_decoding_weight), allocator_(std::move(allocator)), cublas_algo_map_(std::move(cublas_algo_map)), cublas_wrapper_mutex_(std::move(cublas_wrapper_mutex)), @@ -48,8 +48,8 @@ template std::unordered_map T5TritonModelInstance::convert_inputs(std::shared_ptr> input_tensors) { - move_tensor_H2D(input_tensors->at("input_ids"), d_input_ids_); - move_tensor_H2D(input_tensors->at("sequence_length"), d_input_lengths_); + move_tensor_H2D(input_tensors->at("input_ids"), d_input_ids_, &allocator_); + move_tensor_H2D(input_tensors->at("sequence_length"), d_input_lengths_, &allocator_); std::unordered_map ft_input_tensors{ {"input_ids", as_GPU_tensor(input_tensors->at("input_ids"), d_input_ids_)}, @@ -76,12 +76,11 @@ std::shared_ptr> T5TritonModelInstance::forward(std::shared_ptr> input_tensors) { const size_t request_batch_size = input_tensors->at("input_ids").shape[0]; - const size_t mem_max_seq_len = input_tensors->at("input_ids").shape[1]; - const size_t max_output_len = *((uint*)input_tensors->at("max_output_len").data); + const size_t mem_max_seq_len = input_tensors->at("input_ids").shape[1]; + const size_t max_output_len = *((uint*)input_tensors->at("max_output_len").data); const size_t beam_width = input_tensors->count("beam_width") ? 
(size_t)(*(uint*)input_tensors->at("beam_width").data) : 1; - freeBuffer(); // free buffer of previous iteration allocateBuffer(request_batch_size, beam_width, max_output_len, mem_max_seq_len); std::unordered_map encoder_input_tensors = convert_inputs(input_tensors); @@ -105,13 +104,13 @@ T5TritonModelInstance::forward(std::shared_ptrfind("bad_words_list") != input_tensors->end()) { - move_tensor_H2D(input_tensors->at("bad_words_list"), d_input_bad_words_); + move_tensor_H2D(input_tensors->at("bad_words_list"), d_input_bad_words_, &allocator_); decoding_input_tensors.insert( {"bad_words_list", as_GPU_tensor(input_tensors->at("bad_words_list"), d_input_bad_words_)}); } if (input_tensors->find("stop_words_list") != input_tensors->end()) { - move_tensor_H2D(input_tensors->at("stop_words_list"), d_input_stop_words_); + move_tensor_H2D(input_tensors->at("stop_words_list"), d_input_stop_words_, &allocator_); decoding_input_tensors.insert( {"stop_words_list", as_GPU_tensor(input_tensors->at("stop_words_list"), d_input_stop_words_)}); } @@ -144,23 +143,6 @@ T5TritonModelInstance::forward(std::shared_ptrforward(&encoder_output_tensors, &encoder_input_tensors, t5_encoder_weight_.get()); t5_decoding_->forward(&decoding_output_tensors, &decoding_input_tensors, t5_decoding_weight_.get()); - if (d_input_ids_ != nullptr) { - ft::check_cuda_error(cudaFree(d_input_ids_)); - d_input_ids_ = nullptr; - } - if (d_input_lengths_ != nullptr) { - ft::check_cuda_error(cudaFree(d_input_lengths_)); - d_input_lengths_ = nullptr; - } - if (d_input_bad_words_ != nullptr) { - ft::check_cuda_error(cudaFree(d_input_bad_words_)); - d_input_bad_words_ = nullptr; - } - if (d_input_stop_words_ != nullptr) { - ft::check_cuda_error(cudaFree(d_input_stop_words_)); - d_input_stop_words_ = nullptr; - } - return convert_outputs(decoding_output_tensors); } @@ -176,22 +158,30 @@ void T5TritonModelInstance::allocateBuffer(const size_t request_batch_size, const size_t max_output_len, const size_t mem_max_seq_len) { - ft::deviceMalloc(&d_encoder_outputs_, request_batch_size * mem_max_seq_len * t5_encoder_->getDModel()); - ft::deviceMalloc(&d_output_ids_, request_batch_size * beam_width * max_output_len); - ft::deviceMalloc(&d_sequence_lengths_, request_batch_size * beam_width); - ft::deviceMalloc(&d_output_log_probs_, request_batch_size * beam_width * max_output_len); - ft::deviceMalloc(&d_cum_log_probs_, request_batch_size * beam_width * max_output_len); + d_output_ids_ = (int*)(allocator_->reMalloc( + d_output_ids_, sizeof(int) * request_batch_size * beam_width * max_output_len, false)); + d_encoder_outputs_ = (T*)(allocator_->reMalloc( + d_encoder_outputs_, sizeof(T) * request_batch_size * mem_max_seq_len * t5_encoder_->getDModel(), false)); + d_sequence_lengths_ = + (int*)(allocator_->reMalloc(d_sequence_lengths_, sizeof(int) * request_batch_size * beam_width, false)); + d_output_log_probs_ = (float*)(allocator_->reMalloc( + d_output_log_probs_, sizeof(float) * request_batch_size * beam_width * max_output_len, false)); + d_cum_log_probs_ = (float*)(allocator_->reMalloc( + d_cum_log_probs_, sizeof(float) * request_batch_size * beam_width * max_output_len, false)); } template void T5TritonModelInstance::freeBuffer() { - ft::deviceFree(d_encoder_outputs_); - ft::deviceFree(d_output_ids_); - ft::deviceFree(d_sequence_lengths_); - ft::deviceFree(d_output_log_probs_); - ft::deviceFree(d_cum_log_probs_); + allocator_->free((void**)(&d_encoder_outputs_)); + allocator_->free((void**)(&d_output_ids_)); + 
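With allocateBuffer() now routed through the allocator, each forward() call simply recomputes the required byte counts from the request shape and lets reMalloc decide whether the existing buffers can be reused (the allocator changes appear later in this diff). The sizes follow directly from the request; a standard-library-only sketch of that arithmetic, with illustrative request values:

#include <cstddef>
#include <cstdio>

int main()
{
    // Example request (values are illustrative).
    const size_t request_batch_size = 8;
    const size_t beam_width         = 4;
    const size_t max_output_len     = 32;
    const size_t mem_max_seq_len    = 128;   // encoder input length
    const size_t d_model            = 512;   // t5_encoder_->getDModel()

    // Byte counts matching allocateBuffer() in T5TritonModelInstance (T taken as float here):
    const size_t output_ids_bytes      = sizeof(int)   * request_batch_size * beam_width * max_output_len;
    const size_t encoder_out_bytes     = sizeof(float) * request_batch_size * mem_max_seq_len * d_model;
    const size_t sequence_len_bytes    = sizeof(int)   * request_batch_size * beam_width;
    const size_t output_log_prob_bytes = sizeof(float) * request_batch_size * beam_width * max_output_len;
    const size_t cum_log_prob_bytes    = sizeof(float) * request_batch_size * beam_width * max_output_len;

    std::printf("output_ids: %zu B, encoder_out: %zu B, seq_len: %zu B, log_probs: %zu B + %zu B\n",
                output_ids_bytes, encoder_out_bytes, sequence_len_bytes,
                output_log_prob_bytes, cum_log_prob_bytes);
}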
allocator_->free((void**)(&d_sequence_lengths_)); + allocator_->free((void**)(&d_output_log_probs_)); + allocator_->free((void**)(&d_cum_log_probs_)); } template struct T5TritonModelInstance; template struct T5TritonModelInstance; +#ifdef ENABLE_BF16 +template struct T5TritonModelInstance<__nv_bfloat16>; +#endif \ No newline at end of file diff --git a/src/fastertransformer/triton_backend/t5/T5TritonModelInstance.h b/src/fastertransformer/triton_backend/t5/T5TritonModelInstance.h index a8a9e9c7b..2f30a7418 100644 --- a/src/fastertransformer/triton_backend/t5/T5TritonModelInstance.h +++ b/src/fastertransformer/triton_backend/t5/T5TritonModelInstance.h @@ -27,15 +27,15 @@ namespace ft = fastertransformer; template struct T5TritonModelInstance: AbstractTransformerModelInstance { - T5TritonModelInstance(std::unique_ptr> t5_encoder, - std::unique_ptr> t5_decoding, - std::unique_ptr> t5_encoder_weight, - std::unique_ptr> t5_decoding_weight, + T5TritonModelInstance(std::unique_ptr> t5_encoder, + std::unique_ptr> t5_decoding, + std::shared_ptr> t5_encoder_weight, + std::shared_ptr> t5_decoding_weight, std::unique_ptr> allocator, - std::unique_ptr cublas_algo_map, - std::unique_ptr cublas_wrapper_mutex, - std::unique_ptr cublas_wrapper, - std::unique_ptr cuda_device_prop_ptr); + std::unique_ptr cublas_algo_map, + std::unique_ptr cublas_wrapper_mutex, + std::unique_ptr cublas_wrapper, + std::unique_ptr cuda_device_prop_ptr); ~T5TritonModelInstance(); std::shared_ptr> @@ -49,15 +49,15 @@ struct T5TritonModelInstance: AbstractTransformerModelInstance { forward(std::shared_ptr> input_tensors) override; private: - const std::unique_ptr> t5_encoder_; - const std::unique_ptr> t5_encoder_weight_; - const std::unique_ptr> t5_decoding_; - const std::unique_ptr> t5_decoding_weight_; + const std::unique_ptr> t5_encoder_; + const std::shared_ptr> t5_encoder_weight_; + const std::unique_ptr> t5_decoding_; + const std::shared_ptr> t5_decoding_weight_; const std::unique_ptr> allocator_; - const std::unique_ptr cublas_algo_map_; - const std::unique_ptr cublas_wrapper_mutex_; - const std::unique_ptr cublas_wrapper_; - const std::unique_ptr cuda_device_prop_ptr_; + const std::unique_ptr cublas_algo_map_; + const std::unique_ptr cublas_wrapper_mutex_; + const std::unique_ptr cublas_wrapper_; + const std::unique_ptr cuda_device_prop_ptr_; std::unordered_map convert_inputs(std::shared_ptr> input_tensors); @@ -71,16 +71,16 @@ struct T5TritonModelInstance: AbstractTransformerModelInstance { const size_t mem_max_seq_len); void freeBuffer(); - int* d_input_ids_ = nullptr; - int* d_input_lengths_ = nullptr; - int* d_input_bad_words_ = nullptr; + int* d_input_ids_ = nullptr; + int* d_input_lengths_ = nullptr; + int* d_input_bad_words_ = nullptr; int* d_input_stop_words_ = nullptr; - T* d_encoder_outputs_ = nullptr; - int* d_output_ids_ = nullptr; - int* d_sequence_lengths_ = nullptr; + T* d_encoder_outputs_ = nullptr; + int* d_output_ids_ = nullptr; + int* d_sequence_lengths_ = nullptr; float* d_output_log_probs_ = nullptr; - float* d_cum_log_probs_ = nullptr; + float* d_cum_log_probs_ = nullptr; int h_total_output_len_; }; diff --git a/src/fastertransformer/triton_backend/transformer_triton_backend.cpp b/src/fastertransformer/triton_backend/transformer_triton_backend.cpp new file mode 100644 index 000000000..eb8d7e913 --- /dev/null +++ b/src/fastertransformer/triton_backend/transformer_triton_backend.cpp @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/fastertransformer/triton_backend/transformer_triton_backend.hpp" + +std::pair, std::vector> +AbstractTransformerModel::createNcclParams(const int node_id, const int device_id_start, const bool multi_node) +{ + const int gpu_count = ft::getDeviceCount(); + const int tensor_para_size = getTensorParaSize(); + const int pipeline_para_size = getPipelineParaSize(); + const int local_comm_size = multi_node ? gpu_count : tensor_para_size * pipeline_para_size; + ft::FT_CHECK(tensor_para_size > 0 && pipeline_para_size > 0); + ft::FT_CHECK(device_id_start + (int)local_comm_size <= gpu_count); + + std::vector nccl_ids; + if (tensor_para_size > 1 || pipeline_para_size > 1) { + nccl_ids.resize(tensor_para_size + pipeline_para_size); + if (node_id == 0) { + for (uint32_t i = 0; i < nccl_ids.size(); i++) { + ft::ftNcclGetUniqueId(nccl_ids[i]); + } + } + for (size_t i = 0; i < nccl_ids.size(); i++) { + ft::mpi::bcast(&nccl_ids[i], sizeof(nccl_ids[i]), ft::mpi::MPI_TYPE_BYTE, 0, ft::mpi::COMM_WORLD); + } + } + + std::vector tensor_para_params(local_comm_size); + std::vector pipeline_para_params(local_comm_size); + // Don't init comm when size == 1 + if (tensor_para_size > 1) { + ft::ftNcclGroupStart(); + for (int gid = device_id_start; gid < device_id_start + local_comm_size; gid++) { + int rank = node_id * gpu_count + gid - device_id_start; + int tensor_para_rank = rank % tensor_para_size; + int pipeline_para_rank = rank / tensor_para_size; + + ft::NcclUid tensor_para_nccl_uid = nccl_ids[pipeline_para_rank]; + ft::check_cuda_error(cudaSetDevice(gid)); + ft::ftNcclCommInitRank( + tensor_para_params[gid - device_id_start], tensor_para_rank, tensor_para_size, tensor_para_nccl_uid); + } + ft::ftNcclGroupEnd(); + } + if (pipeline_para_size > 1) { + ft::ftNcclGroupStart(); + for (int gid = device_id_start; gid < device_id_start + local_comm_size; gid++) { + int rank = node_id * gpu_count + gid - device_id_start; + int tensor_para_rank = rank % tensor_para_size; + int pipeline_para_rank = rank / tensor_para_size; + + ft::NcclUid pipeline_para_nccl_uid = nccl_ids[pipeline_para_size + tensor_para_rank]; + ft::check_cuda_error(cudaSetDevice(gid)); + ft::ftNcclCommInitRank(pipeline_para_params[gid - device_id_start], + pipeline_para_rank, + pipeline_para_size, + pipeline_para_nccl_uid); + } + ft::ftNcclGroupEnd(); + } + return std::pair, std::vector>(tensor_para_params, pipeline_para_params); +} diff --git a/src/fastertransformer/triton_backend/transformer_triton_backend.hpp b/src/fastertransformer/triton_backend/transformer_triton_backend.hpp index bb0d2a3bd..0fe8b8f99 100644 --- a/src/fastertransformer/triton_backend/transformer_triton_backend.hpp +++ b/src/fastertransformer/triton_backend/transformer_triton_backend.hpp @@ -23,6 +23,7 @@ #include "src/fastertransformer/utils/Tensor.h" #include "src/fastertransformer/utils/custom_ar_comm.h" +#include "src/fastertransformer/utils/mpi_utils.h" #include 
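The new createNcclParams() above replaces the former createNcclIds/createNcclComms pair: node 0 generates tensor_para_size + pipeline_para_size unique ids, broadcasts them over MPI, and each local GPU then joins one tensor-parallel group and one pipeline-parallel group (groups of size 1 are skipped entirely). Which id a rank picks follows directly from the rank decomposition; a standard-library-only sketch of just that mapping, with example parallel sizes and no NCCL/MPI calls:

#include <cstdio>

int main()
{
    const int tensor_para_size   = 2;   // example sizes (assumed)
    const int pipeline_para_size = 2;
    const int world_size         = tensor_para_size * pipeline_para_size;

    for (int rank = 0; rank < world_size; ++rank) {
        const int tp_rank = rank % tensor_para_size;   // position inside the tensor-parallel group
        const int pp_rank = rank / tensor_para_size;   // position inside the pipeline-parallel group
        // Index of the ncclUniqueId each group uses, as in createNcclParams():
        const int tp_uid_index = pp_rank;                        // nccl_ids[pipeline_para_rank]
        const int pp_uid_index = pipeline_para_size + tp_rank;   // nccl_ids[pipeline_para_size + tensor_para_rank]
        std::printf("rank %d: tp_rank=%d (uid %d), pp_rank=%d (uid %d)\n",
                    rank, tp_rank, tp_uid_index, pp_rank, pp_uid_index);
    }
}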
"src/fastertransformer/utils/nccl_utils.h" namespace ft = fastertransformer; @@ -32,59 +33,59 @@ namespace triton { #include "triton/core/tritonbackend.h" #include "triton/core/tritonserver.h" -typedef TRITONSERVER_DataType DataType; +typedef TRITONSERVER_DataType DataType; typedef TRITONSERVER_MemoryType MemoryType; constexpr TRITONSERVER_DataType TYPE_INVALID = TRITONSERVER_TYPE_INVALID; -constexpr TRITONSERVER_DataType TYPE_BOOL = TRITONSERVER_TYPE_BOOL; -constexpr TRITONSERVER_DataType TYPE_UINT8 = TRITONSERVER_TYPE_UINT8; -constexpr TRITONSERVER_DataType TYPE_UINT16 = TRITONSERVER_TYPE_UINT16; -constexpr TRITONSERVER_DataType TYPE_UINT32 = TRITONSERVER_TYPE_UINT32; -constexpr TRITONSERVER_DataType TYPE_UINT64 = TRITONSERVER_TYPE_UINT64; -constexpr TRITONSERVER_DataType TYPE_INT8 = TRITONSERVER_TYPE_INT8; -constexpr TRITONSERVER_DataType TYPE_INT16 = TRITONSERVER_TYPE_INT16; -constexpr TRITONSERVER_DataType TYPE_INT32 = TRITONSERVER_TYPE_INT32; -constexpr TRITONSERVER_DataType TYPE_INT64 = TRITONSERVER_TYPE_INT64; -constexpr TRITONSERVER_DataType TYPE_FP16 = TRITONSERVER_TYPE_FP16; -constexpr TRITONSERVER_DataType TYPE_FP32 = TRITONSERVER_TYPE_FP32; -constexpr TRITONSERVER_DataType TYPE_FP64 = TRITONSERVER_TYPE_FP64; -constexpr TRITONSERVER_DataType TYPE_BYTES = TRITONSERVER_TYPE_BYTES; +constexpr TRITONSERVER_DataType TYPE_BOOL = TRITONSERVER_TYPE_BOOL; +constexpr TRITONSERVER_DataType TYPE_UINT8 = TRITONSERVER_TYPE_UINT8; +constexpr TRITONSERVER_DataType TYPE_UINT16 = TRITONSERVER_TYPE_UINT16; +constexpr TRITONSERVER_DataType TYPE_UINT32 = TRITONSERVER_TYPE_UINT32; +constexpr TRITONSERVER_DataType TYPE_UINT64 = TRITONSERVER_TYPE_UINT64; +constexpr TRITONSERVER_DataType TYPE_INT8 = TRITONSERVER_TYPE_INT8; +constexpr TRITONSERVER_DataType TYPE_INT16 = TRITONSERVER_TYPE_INT16; +constexpr TRITONSERVER_DataType TYPE_INT32 = TRITONSERVER_TYPE_INT32; +constexpr TRITONSERVER_DataType TYPE_INT64 = TRITONSERVER_TYPE_INT64; +constexpr TRITONSERVER_DataType TYPE_FP16 = TRITONSERVER_TYPE_FP16; +constexpr TRITONSERVER_DataType TYPE_FP32 = TRITONSERVER_TYPE_FP32; +constexpr TRITONSERVER_DataType TYPE_FP64 = TRITONSERVER_TYPE_FP64; +constexpr TRITONSERVER_DataType TYPE_BYTES = TRITONSERVER_TYPE_BYTES; // constexpr TRITONSERVER_DataType TYPE_BF16 = TRITONSERVER_TYPE_BF16; // BF16 is not supported in Triton -constexpr TRITONSERVER_MemoryType MEMORY_CPU = TRITONSERVER_MEMORY_CPU; +constexpr TRITONSERVER_MemoryType MEMORY_CPU = TRITONSERVER_MEMORY_CPU; constexpr TRITONSERVER_MemoryType MEMORY_CPU_PINNED = TRITONSERVER_MEMORY_CPU_PINNED; -constexpr TRITONSERVER_MemoryType MEMORY_GPU = TRITONSERVER_MEMORY_GPU; +constexpr TRITONSERVER_MemoryType MEMORY_GPU = TRITONSERVER_MEMORY_GPU; #else -typedef ft::DataType DataType; +typedef ft::DataType DataType; typedef ft::MemoryType MemoryType; constexpr DataType TYPE_INVALID = ft::TYPE_INVALID; -constexpr DataType TYPE_BOOL = ft::TYPE_BOOL; -constexpr DataType TYPE_UINT8 = ft::TYPE_UINT8; -constexpr DataType TYPE_UINT16 = ft::TYPE_UINT16; -constexpr DataType TYPE_UINT32 = ft::TYPE_UINT32; -constexpr DataType TYPE_UINT64 = ft::TYPE_UINT64; -constexpr DataType TYPE_INT8 = ft::TYPE_INT8; -constexpr DataType TYPE_INT16 = ft::TYPE_INT16; -constexpr DataType TYPE_INT32 = ft::TYPE_INT32; -constexpr DataType TYPE_INT64 = ft::TYPE_INT64; -constexpr DataType TYPE_FP16 = ft::TYPE_FP16; -constexpr DataType TYPE_FP32 = ft::TYPE_FP32; -constexpr DataType TYPE_FP64 = ft::TYPE_FP64; -constexpr DataType TYPE_BYTES = ft::TYPE_BYTES; +constexpr DataType TYPE_BOOL = ft::TYPE_BOOL; 
+constexpr DataType TYPE_UINT8 = ft::TYPE_UINT8; +constexpr DataType TYPE_UINT16 = ft::TYPE_UINT16; +constexpr DataType TYPE_UINT32 = ft::TYPE_UINT32; +constexpr DataType TYPE_UINT64 = ft::TYPE_UINT64; +constexpr DataType TYPE_INT8 = ft::TYPE_INT8; +constexpr DataType TYPE_INT16 = ft::TYPE_INT16; +constexpr DataType TYPE_INT32 = ft::TYPE_INT32; +constexpr DataType TYPE_INT64 = ft::TYPE_INT64; +constexpr DataType TYPE_FP16 = ft::TYPE_FP16; +constexpr DataType TYPE_FP32 = ft::TYPE_FP32; +constexpr DataType TYPE_FP64 = ft::TYPE_FP64; +constexpr DataType TYPE_BYTES = ft::TYPE_BYTES; // constexpr DataType TYPE_BF16 = ft::TYPE_BF16; -constexpr MemoryType MEMORY_CPU = ft::MEMORY_CPU; +constexpr MemoryType MEMORY_CPU = ft::MEMORY_CPU; constexpr MemoryType MEMORY_CPU_PINNED = ft::MEMORY_CPU_PINNED; -constexpr MemoryType MEMORY_GPU = ft::MEMORY_GPU; +constexpr MemoryType MEMORY_GPU = ft::MEMORY_GPU; #endif struct Tensor { - const MemoryType where; - const DataType type; + const MemoryType where; + const DataType type; const std::vector shape; - const void* data; + const void* data; Tensor(const MemoryType _where, const DataType _type, const std::vector _shape, const void* _data): where(_where), type(_type), shape(_shape), data(_data) @@ -146,7 +147,7 @@ struct Tensor { ft::Tensor convertTritonTensorToFt() { - ft::DataType ft_data_type = convertTritonTypeToFt(type); + ft::DataType ft_data_type = convertTritonTypeToFt(type); ft::MemoryType ft_memory_type; switch (where) { case MEMORY_CPU: @@ -230,6 +231,8 @@ struct Tensor { } // namespace triton +using triton_stream_cb_t = void(std::shared_ptr>, void*); + struct AbstractTransformerModel; struct AbstractTransformerModelInstance; @@ -239,28 +242,45 @@ struct AbstractTransformerModelInstance { virtual std::shared_ptr> forward(std::shared_ptr> input_tensors) = 0; + + void registerCallback(triton_stream_cb_t* cb, void* ctx) + { + stream_cb_ = cb; + stream_ctx_ = ctx; + } + + void unRegisterCallback() + { + stream_cb_ = nullptr; + stream_ctx_ = nullptr; + } + + triton_stream_cb_t* stream_cb_ = nullptr; + void* stream_ctx_ = nullptr; }; struct AbstractTransformerModel { static std::shared_ptr createGptModel(std::string inifile); static std::shared_ptr createGptJModel(std::string inifile); + static std::shared_ptr createGptNeoXModel(std::string inifile); static std::shared_ptr createT5Model(std::string model_dir); - virtual std::vector createNcclIds(const uint32_t world_size, bool multi_instances = false) = 0; - virtual std::pair, std::vector> createNcclComms( - std::vector nccl_ids, const int node_id, bool multi_instances = false, int instance_id = 0) = 0; + std::pair, std::vector> + createNcclParams(const int node_id, const int device_id_start = 0, const bool multi_node = false); virtual void createCustomComms(std::vector>* custom_all_reduce_comms, - int world_size) = 0; + int world_size) = 0; virtual std::unique_ptr - createModelInstance(int deviceId, - int rank, - cudaStream_t stream, - std::pair, std::vector> nccl_comms, + createModelInstance(int deviceId, + int rank, + cudaStream_t stream, + std::pair, std::vector> nccl_params, std::shared_ptr custom_all_reduce_comm = nullptr) = 0; - virtual std::string toString() = 0; - virtual int getTensorParaSize() = 0; - virtual int getPipelineParaSize() = 0; + virtual void createSharedWeights(int deviceId, int rank) = 0; + + virtual std::string toString() = 0; + virtual int getTensorParaSize() = 0; + virtual int getPipelineParaSize() = 0; }; diff --git a/src/fastertransformer/triton_backend/triton_utils.hpp 
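The registerCallback()/unRegisterCallback() pair added to AbstractTransformerModelInstance lets a backend hand the instance a C-style function pointer plus an opaque context that the model can invoke with intermediate results. The template arguments of triton_stream_cb_t are elided in this hunk, so the sketch below substitutes a placeholder payload type; only the registration pattern itself is taken from the diff, while ModelInstanceStub, emit(), and on_stream are illustrative:

#include <cstdio>
#include <memory>
#include <string>
#include <unordered_map>

// Placeholder for the output container handed to the callback.
using Payload     = std::unordered_map<std::string, int>;
using stream_cb_t = void(std::shared_ptr<Payload>, void*);

struct ModelInstanceStub {
    // Same registration pattern as AbstractTransformerModelInstance.
    void registerCallback(stream_cb_t* cb, void* ctx) { stream_cb_ = cb; stream_ctx_ = ctx; }
    void unRegisterCallback() { stream_cb_ = nullptr; stream_ctx_ = nullptr; }

    // Called by the model whenever partial results are ready.
    void emit(std::shared_ptr<Payload> out)
    {
        if (stream_cb_ != nullptr) { stream_cb_(out, stream_ctx_); }
    }

    stream_cb_t* stream_cb_  = nullptr;
    void*        stream_ctx_ = nullptr;
};

static void on_stream(std::shared_ptr<Payload> out, void* ctx)
{
    std::printf("streamed %zu tensors, ctx=%p\n", out->size(), ctx);
}

int main()
{
    ModelInstanceStub inst;
    int request_id = 7;   // arbitrary user context
    inst.registerCallback(&on_stream, &request_id);
    inst.emit(std::make_shared<Payload>(Payload{{"output_ids", 1}}));
    inst.unRegisterCallback();
}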
b/src/fastertransformer/triton_backend/triton_utils.hpp index 52d0c3a92..def98a156 100644 --- a/src/fastertransformer/triton_backend/triton_utils.hpp +++ b/src/fastertransformer/triton_backend/triton_utils.hpp @@ -22,7 +22,9 @@ namespace ft = fastertransformer; template -void move_tensor_H2D(const triton::Tensor &tensor, T* &d_ptr) +void move_tensor_H2D(const triton::Tensor& tensor, + T*& d_ptr, + const std::unique_ptr>* allocator) { if (tensor.where == triton::MEMORY_GPU) { return; @@ -32,15 +34,18 @@ void move_tensor_H2D(const triton::Tensor &tensor, T* &d_ptr) for (auto t : tensor.shape) { tensor_size *= t; } - ft::deviceMalloc(&d_ptr, tensor_size, false); - ft::cudaH2Dcpy(d_ptr, (T*) tensor.data, tensor_size); + + cudaStream_t stream = (*allocator)->returnStream(); + + d_ptr = (T*)((*allocator)->reMalloc(d_ptr, sizeof(T) * tensor_size, false)); + ft::check_cuda_error(cudaMemcpyAsync(d_ptr, (T*)tensor.data, sizeof(T) * tensor_size, cudaMemcpyDefault, stream)); } template -ft::Tensor as_GPU_tensor(const triton::Tensor &tensor, T* d_ptr) +ft::Tensor as_GPU_tensor(const triton::Tensor& tensor, T* d_ptr) { - return ft::Tensor {ft::MEMORY_GPU, - triton::Tensor::convertTritonTypeToFt(tensor.type), - tensor.shape, - tensor.where == triton::MEMORY_CPU ? d_ptr : tensor.data}; + return ft::Tensor{ft::MEMORY_GPU, + triton::Tensor::convertTritonTypeToFt(tensor.type), + tensor.shape, + tensor.where == triton::MEMORY_CPU ? d_ptr : tensor.data}; } diff --git a/src/fastertransformer/utils/CMakeLists.txt b/src/fastertransformer/utils/CMakeLists.txt index 3d0f28a6d..9e843063c 100644 --- a/src/fastertransformer/utils/CMakeLists.txt +++ b/src/fastertransformer/utils/CMakeLists.txt @@ -44,10 +44,19 @@ set_property(TARGET memory_utils PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET memory_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) target_link_libraries(memory_utils PUBLIC -lnvToolsExt) +add_library(mpi_utils STATIC mpi_utils.cc) +set_property(TARGET mpi_utils PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET mpi_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) +if (BUILD_MULTI_GPU) + target_link_libraries(mpi_utils PUBLIC -lmpi) +endif() + add_library(nccl_utils STATIC nccl_utils.cc) set_property(TARGET nccl_utils PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET nccl_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -target_link_libraries(nccl_utils PUBLIC -lnccl) +if (BUILD_MULTI_GPU) + target_link_libraries(nccl_utils PUBLIC -lnccl mpi_utils) +endif() add_library(cublasINT8MMWrapper STATIC cublasINT8MMWrapper.cc) set_property(TARGET cublasINT8MMWrapper PROPERTY POSITION_INDEPENDENT_CODE ON) @@ -71,3 +80,7 @@ else() -lcublas -lcublasLt -lcudart -lcurand cublasAlgoMap memory_utils) endif() + +add_library(tensor STATIC Tensor.cc) +set_property(TARGET tensor PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET tensor PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) diff --git a/src/fastertransformer/utils/ScaleList.h b/src/fastertransformer/utils/ScaleList.h index 44a855640..49edf0aaa 100644 --- a/src/fastertransformer/utils/ScaleList.h +++ b/src/fastertransformer/utils/ScaleList.h @@ -42,10 +42,10 @@ struct ScaleList { // Part 5 -- 21: reverse const float* d_scale_list_ = nullptr; const float* h_scale_list_ = nullptr; - size_t size_ = ACTIVATION_AMAX_NUM + 9 * 768 + INT8O_GEMM_NUM + TRT_AMAX_NUM; - size_t p2_offset_ = ACTIVATION_AMAX_NUM; - size_t p3_offset_ = ACTIVATION_AMAX_NUM + 9 * 768; - size_t p4_offset_ = ACTIVATION_AMAX_NUM + 9 * 768 + INT8O_GEMM_NUM; + size_t size_ = 
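move_tensor_H2D() above now stages host tensors through the model's allocator and copies asynchronously on the allocator's stream instead of calling deviceMalloc plus a synchronous H2D copy. Stripped of the FasterTransformer types, the underlying CUDA pattern is sketched below; it assumes CUDA 11.2+ for the stream-ordered allocation, a plain cudaStream_t stands in for allocator->returnStream(), and error checking is omitted for brevity:

#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

int main()
{
    std::vector<int> h_input_ids(64, 1);   // host-side payload (illustrative)
    int*             d_input_ids = nullptr;
    cudaStream_t     stream;
    cudaStreamCreate(&stream);             // stands in for allocator->returnStream()

    const size_t bytes = sizeof(int) * h_input_ids.size();
    cudaMallocAsync(reinterpret_cast<void**>(&d_input_ids), bytes, stream);   // allocator reMalloc path
    // cudaMemcpyDefault lets the runtime infer the direction, as in move_tensor_H2D().
    cudaMemcpyAsync(d_input_ids, h_input_ids.data(), bytes, cudaMemcpyDefault, stream);

    cudaStreamSynchronize(stream);         // only so the host buffer may safely go out of scope here
    cudaFreeAsync(d_input_ids, stream);
    cudaStreamDestroy(stream);
    std::printf("copied %zu bytes on the allocator stream\n", bytes);
}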
ACTIVATION_AMAX_NUM + 9 * 768 + INT8O_GEMM_NUM + TRT_AMAX_NUM; + size_t p2_offset_ = ACTIVATION_AMAX_NUM; + size_t p3_offset_ = ACTIVATION_AMAX_NUM + 9 * 768; + size_t p4_offset_ = ACTIVATION_AMAX_NUM + 9 * 768 + INT8O_GEMM_NUM; }; } // namespace fastertransformer diff --git a/src/fastertransformer/utils/Tensor.cc b/src/fastertransformer/utils/Tensor.cc new file mode 100644 index 000000000..22d543365 --- /dev/null +++ b/src/fastertransformer/utils/Tensor.cc @@ -0,0 +1,422 @@ +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/fastertransformer/utils/Tensor.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include "src/fastertransformer/utils/cuda_utils.h" +#include "src/fastertransformer/utils/string_utils.h" + +#include "stdlib.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace fastertransformer { + +Tensor::Tensor(): + // a none tensor. + where(MEMORY_CPU), + type(TYPE_INVALID), + shape({}), + data(nullptr), + offsets({}) // only a record to record offset +{ +} + +Tensor::Tensor(const MemoryType _where, const DataType _type, const std::vector _shape, const void* _data): + where(_where), type(_type), shape(_shape), data(_data) +{ +} + +Tensor::Tensor(const MemoryType _where, + const DataType _type, + const std::vector _shape, + const void* _data, + const std::vector _offset): + where(_where), type(_type), shape(_shape), data(_data), offsets(_offset) +{ +} + +void Tensor::parseNpyIntro(FILE*& f_ptr, uint32_t& header_len, uint32_t& start_data) +{ + const char magic[] = "\x93" + "NUMPY"; + char magic_test[sizeof(magic)] = "\0"; + + size_t n_elems = fread((void*)magic_test, sizeof(char), sizeof(magic) - 1, f_ptr); + if (n_elems != sizeof(magic) - 1 || std::string(magic) != std::string(magic_test)) { + throw std::runtime_error("Could read magic token in NPY file"); + } + + uint8_t npy_major = 0; + uint8_t npy_minor = 0; + n_elems = fread((void*)&npy_major, sizeof(uint8_t), 1, f_ptr); + n_elems += fread((void*)&npy_minor, sizeof(uint8_t), 1, f_ptr); + + if (npy_major == 1) { + uint16_t header_len_u16 = 0; + n_elems = fread((void*)&header_len_u16, sizeof(uint16_t), 1, f_ptr); + header_len = header_len_u16; + } + else if (npy_major == 2) { + uint32_t header_len_u32 = 0; + n_elems = fread((void*)&header_len_u32, sizeof(uint32_t), 1, f_ptr); + header_len = header_len_u32; + } + else { + throw std::runtime_error("Unsupported npy version: " + std::to_string(npy_major)); + } + + start_data = 8 + 2 * npy_major + header_len; +} + +int Tensor::parseNpyHeader(FILE*& f_ptr, uint32_t header_len, DataType& type, std::vector& shape) +{ + char* header_c = (char*)malloc(header_len * sizeof(char)); + size_t n_elems = fread((void*)header_c, sizeof(char), header_len, f_ptr); + if (n_elems != header_len) { + free(header_c); + return -1; + } + std::string header(header_c, header_len); + free(header_c); + + size_t start, end; + start = 
header.find("'descr'") + 7; + start = header.find("'", start); + end = header.find("'", start + 1); + type = typeFromNumpyDesc(header.substr(start + 1, end - start - 1)); + + start = header.find("'fortran_order'") + 15; + start = header.find(":", start); + end = header.find(",", start + 1); + if (header.substr(start + 1, end - start - 1).find("False") == std::string::npos) { + throw std::runtime_error("Unsupported value for fortran_order while reading npy file"); + } + + start = header.find("'shape'") + 7; + start = header.find("(", start); + end = header.find(")", start + 1); + + std::istringstream shape_stream(header.substr(start + 1, end - start - 1)); + std::string token; + + shape.clear(); + while (std::getline(shape_stream, token, ',')) { + if (token.find_first_not_of(' ') == std::string::npos) { + break; + } + shape.push_back(std::stoul(token)); + } + + return 0; +} + +Tensor Tensor::loadNpy(const std::string& npy_file, const MemoryType where) +{ + DataType type; + std::vector shape; + + FILE* f_ptr = fopen(npy_file.c_str(), "rb"); + if (f_ptr == nullptr) { + throw std::runtime_error("Could not open file " + npy_file); + } + uint32_t header_len, start_data; + parseNpyIntro(f_ptr, header_len, start_data); + parseNpyHeader(f_ptr, header_len, type, shape); + + const size_t size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + void* data_cpu = malloc(size * Tensor::getTypeSize(type)); + void* data = data_cpu; + + size_t n_elems = fread(data_cpu, Tensor::getTypeSize(type), size, f_ptr); + FT_CHECK_WITH_INFO(n_elems == size, "reading tensor failed"); + if (where == MEMORY_GPU) { + cudaMalloc(&data, size * Tensor::getTypeSize(type)); + cudaMemcpy(data, data_cpu, size * Tensor::getTypeSize(type), cudaMemcpyHostToDevice); + free(data_cpu); + } + + fclose(f_ptr); + return Tensor(where, type, shape, data); +} + +size_t Tensor::size() const +{ + if (data == nullptr || shape.size() == 0) { + return 0; + } + return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); +} + +size_t Tensor::sizeBytes() const +{ + return size() * Tensor::getTypeSize(type); +} + +std::string Tensor::whereToString() const +{ + static const std::unordered_map mem_to_string{ + {MEMORY_CPU, "CPU"}, {MEMORY_CPU_PINNED, "CPU_PINNED"}, {MEMORY_GPU, "GPU"}}; + return mem_to_string.at(where); +} + +std::string Tensor::toString() const +{ + std::string memtype_str = whereToString(); + + static const std::unordered_map type_to_string{ + {TYPE_BOOL, "BOOL"}, + {TYPE_UINT8, "UINT8"}, + {TYPE_UINT16, "UINT16"}, + {TYPE_UINT32, "UINT32"}, + {TYPE_UINT64, "UINT64"}, + {TYPE_INT8, "INT8"}, + {TYPE_INT16, "INT16"}, + {TYPE_INT32, "INT32"}, + {TYPE_INT64, "INT64"}, + {TYPE_BF16, "BF16"}, + {TYPE_FP16, "FP16"}, + {TYPE_FP32, "FP32"}, + {TYPE_FP64, "FP64"}, + {TYPE_BYTES, "BYTES"}, + {TYPE_INVALID, "INVALID"}, + }; + return fmtstr("Tensor[where=%s, type=%s, shape=%s]", + memtype_str.c_str(), + type_to_string.at(type).c_str(), + vec2str(shape).c_str()); +} + +DataType Tensor::typeFromNumpyDesc(std::string type) +{ + static const std::unordered_map type_map{{"?", TYPE_BOOL}, + {"u1", TYPE_UINT8}, + {"u2", TYPE_UINT16}, + {"u4", TYPE_UINT32}, + {"u8", TYPE_UINT64}, + {"i1", TYPE_INT8}, + {"i2", TYPE_INT16}, + {"i4", TYPE_INT32}, + {"i8", TYPE_INT64}, + {"f2", TYPE_FP16}, + {"f4", TYPE_FP32}, + {"f8", TYPE_FP64}}; + return type_map.at(type); +} + +size_t Tensor::getTypeSize(DataType type) +{ + static const std::unordered_map type_map{{TYPE_BOOL, sizeof(bool)}, + {TYPE_UINT8, sizeof(uint8_t)}, + 
{TYPE_UINT16, sizeof(uint16_t)}, + {TYPE_UINT32, sizeof(uint32_t)}, + {TYPE_UINT64, sizeof(uint64_t)}, + {TYPE_INT8, sizeof(int8_t)}, + {TYPE_INT16, sizeof(int16_t)}, + {TYPE_INT32, sizeof(int32_t)}, + {TYPE_INT64, sizeof(int64_t)}, + {TYPE_FP16, sizeof(half)}, + {TYPE_FP32, sizeof(float)}, + {TYPE_FP64, sizeof(double)}}; + return type_map.at(type); +} + +std::string Tensor::getNumpyTypeDesc(DataType type) const +{ + static const std::unordered_map type_map{{TYPE_BOOL, "?"}, + {TYPE_UINT8, "u1"}, + {TYPE_UINT16, "u2"}, + {TYPE_UINT32, "u4"}, + {TYPE_UINT64, "u8"}, + {TYPE_INT8, "i1"}, + {TYPE_INT16, "i2"}, + {TYPE_INT32, "i4"}, + {TYPE_INT64, "i8"}, + {TYPE_FP16, "f2"}, + {TYPE_FP32, "f4"}, + {TYPE_FP64, "f8"}}; + return type_map.at(type); +} + +void Tensor::saveNpy(const std::string& filename) const +{ + // Save tensor to NPY 1.0 format (see https://numpy.org/neps/nep-0001-npy-format.html) + void* cpu_data = (void*)data; + bool is_data_temp = false; + size_t tensor_size = size(); + if (where == MemoryType::MEMORY_GPU) { + cpu_data = malloc(tensor_size * Tensor::getTypeSize(type)); + is_data_temp = true; + cudaDeviceSynchronize(); + cudaMemcpy(cpu_data, data, tensor_size * Tensor::getTypeSize(type), cudaMemcpyDeviceToHost); + } + + const char magic[] = "\x93" + "NUMPY"; + const uint8_t npy_major = 1; + const uint8_t npy_minor = 0; + + std::stringstream header_stream; + header_stream << "{'descr': '" << getNumpyTypeDesc(type) << "', 'fortran_order': False, 'shape': ("; + for (size_t i = 0; i < shape.size(); ++i) { + header_stream << shape[i]; + if (i + 1 < shape.size() || shape.size() == 1) { + header_stream << ", "; + } + } + header_stream << ")}"; + int base_length = 6 + 4 + header_stream.str().size(); + int pad_length = 16 * ((base_length + 1 + 15) / 16); // Take ceiling of base_length + 1 (for '\n' ending) + for (int i = 0; i < pad_length - base_length; ++i) { + header_stream << ((i == pad_length - base_length - 1) ? 
"\n" : "\x20"); + } + std::string header = header_stream.str(); + const uint16_t header_len = header.size(); + + FILE* f_ptr = fopen(filename.c_str(), "wb"); + FT_CHECK_WITH_INFO(f_ptr != nullptr, fmtstr("Unable to open %s for writing.\n", filename.c_str())); + + fwrite(magic, sizeof(char), sizeof(magic) - 1, f_ptr); + fwrite(&npy_major, sizeof(uint8_t), 1, f_ptr); + fwrite(&npy_minor, sizeof(uint8_t), 1, f_ptr); + fwrite(&header_len, sizeof(uint16_t), 1, f_ptr); + fwrite(header.c_str(), sizeof(char), header_len, f_ptr); + fwrite(cpu_data, Tensor::getTypeSize(type), tensor_size, f_ptr); + + fclose(f_ptr); + + if (is_data_temp) { + free(cpu_data); + } +} + +Tensor Tensor::slice(std::vector shape, size_t offset) const +{ + if (this->data != nullptr) { + size_t n_elts = this->size(); + size_t n_sliced_elts = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + FT_CHECK_WITH_INFO( + n_sliced_elts + offset <= n_elts, + fmtstr("The number (%ld) of elements of sliced tensor exceeds that (%ld) of the original tensor", + n_sliced_elts + offset, + n_elts)); + } + return Tensor(this->where, this->type, shape, this->getPtrWithOffset(offset)); +} + +TensorMap::TensorMap(const std::unordered_map& tensor_map) +{ + for (auto& kv : tensor_map) { + insert(kv.first, kv.second); + } +} + +TensorMap::TensorMap(const std::vector& tensor_map) +{ + for (size_t i = 0; i < tensor_map.size(); i++) { + insert(std::to_string(i), tensor_map[i]); + } +} + +TensorMap::~TensorMap() +{ + tensor_map_.clear(); +} + +std::vector TensorMap::keys() const +{ + std::vector key_names; + for (auto& kv : tensor_map_) { + key_names.push_back(kv.first); + } + return key_names; +} + +std::string TensorMap::toString() +{ + std::stringstream ss; + ss << "{"; + std::vector key_names = keys(); + for (size_t i = 0; i < tensor_map_.size(); ++i) { + ss << key_names[i] << ": " << at(key_names[i]).toString(); + if (i < tensor_map_.size() - 1) { + ss << ", "; + } + } + ss << "}"; + return ss.str(); +} + +TensorMap TensorMap::fromNpyFolder(const std::string& base_folder) +{ + DIR* dir_p = opendir(base_folder.c_str()); + FT_CHECK_WITH_INFO(dir_p != nullptr, fmtstr("Could not open folder %s. 
", base_folder.c_str())); + struct dirent* dp; + + TensorMap ret_tensor; + while ((dp = readdir(dir_p)) != nullptr) { + std::string filename(dp->d_name); + size_t len = filename.length(); + if (len < 4 || filename.compare(len - 4, 4, ".npy")) { + continue; + } + + size_t pos = filename.find('-'); + FT_CHECK_WITH_INFO(pos != std::string::npos, fmtstr("Invalid filename: %s\n", filename.c_str())); + + MemoryType where; + if (filename.compare(0, pos, "GPU") == 0) { + where = MEMORY_GPU; + } + else if (filename.compare(0, pos, "CPU") == 0) { + where = MEMORY_CPU; + } + else if (filename.compare(0, pos, "CPU_PINNED") == 0) { + where = MEMORY_CPU_PINNED; + } + else { + FT_CHECK_WITH_INFO(false, fmtstr("Invalid filename: %s\n", filename.c_str())); + } + std::string key = filename.substr(pos + 1, len - pos - 5); + + ret_tensor.tensor_map_.insert({key, Tensor::loadNpy(base_folder + "/" + filename, where)}); + } + + closedir(dir_p); + + return ret_tensor; +} + +void TensorMap::saveNpy(const std::string& base_folder) +{ + mode_t mode_0755 = S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH; + int ret = mkdir(base_folder.c_str(), mode_0755); + FT_CHECK_WITH_INFO(ret == 0 || errno == EEXIST, fmtstr("Could not create folder %s.\n", base_folder.c_str())); + + for (const auto& item : tensor_map_) { + item.second.saveNpy(base_folder + "/" + item.second.whereToString() + "-" + item.first + ".npy"); + } +} + +} // namespace fastertransformer diff --git a/src/fastertransformer/utils/Tensor.h b/src/fastertransformer/utils/Tensor.h index 83ae91883..6cacb1da5 100644 --- a/src/fastertransformer/utils/Tensor.h +++ b/src/fastertransformer/utils/Tensor.h @@ -23,8 +23,13 @@ #include "stdlib.h" #include #include +#include +#include #include #include +#include +#include +#include #include namespace fastertransformer { @@ -45,6 +50,7 @@ typedef enum datatype_enum { TYPE_FP64, TYPE_BYTES, TYPE_BF16, + TYPE_STR, } DataType; typedef enum memorytype_enum { @@ -54,93 +60,32 @@ typedef enum memorytype_enum { } MemoryType; struct Tensor { - const MemoryType where; - const DataType type; + const MemoryType where; + const DataType type; const std::vector shape; - const void* data; // TODO(bhseuh) modify from const void* to void* const + const void* data; // TODO(bhseuh) modify from const void* to void* const + const std::vector offsets = std::vector{}; - Tensor(const MemoryType _where, const DataType _type, const std::vector _shape, const void* _data): - where(_where), type(_type), shape(_shape), data(_data) - { - } + Tensor(); + Tensor(const MemoryType _where, const DataType _type, const std::vector _shape, const void* _data); + Tensor(const MemoryType _where, + const DataType _type, + const std::vector _shape, + const void* _data, + const std::vector _offset); - size_t size() const - { - size_t n_elements = 1; - for (size_t s : shape) { - n_elements *= s; - } - return n_elements; - } + size_t size() const; + size_t sizeBytes() const; - std::string toString() const - { - std::string memtype_str; - switch (where) { - case MEMORY_CPU: - memtype_str = "CPU"; - break; - case MEMORY_CPU_PINNED: - memtype_str = "CPU_PINNED"; - break; - case MEMORY_GPU: - memtype_str = "GPU"; - break; - } + std::string whereToString() const; + std::string toString() const; + std::string getNumpyTypeDesc(DataType type) const; - std::string dtype_str = ""; - switch (type) { - case TYPE_BOOL: - dtype_str = "BOOL"; - break; - case TYPE_UINT8: - dtype_str = "UINT8"; - break; - case TYPE_UINT16: - dtype_str = "UINT16"; - break; - case TYPE_UINT32: - dtype_str 
= "UINT32"; - break; - case TYPE_UINT64: - dtype_str = "UINT64"; - break; - case TYPE_INT8: - dtype_str = "INT8"; - break; - case TYPE_INT16: - dtype_str = "INT16"; - break; - case TYPE_INT32: - dtype_str = "INT32"; - break; - case TYPE_INT64: - dtype_str = "INT64"; - break; - case TYPE_BF16: - dtype_str = "BF16"; - break; - case TYPE_FP16: - dtype_str = "FP16"; - break; - case TYPE_FP32: - dtype_str = "FP32"; - break; - case TYPE_FP64: - dtype_str = "FP64"; - break; - case TYPE_BYTES: - dtype_str = "BYTES"; - break; - case TYPE_INVALID: - dtype_str = "INVALID"; - break; - default: - break; - } - return fmtstr( - "Tensor[where=%s, type=%s, shape=%s]", memtype_str.c_str(), dtype_str.c_str(), vec2str(shape).c_str()); - } + void saveNpy(const std::string& filename) const; + static Tensor loadNpy(const std::string& npy_file, const MemoryType where); + + static DataType typeFromNumpyDesc(std::string type); + static size_t getTypeSize(DataType type); template inline T getVal(size_t index) const @@ -170,7 +115,7 @@ struct Tensor { } else { FT_CHECK_WITH_INFO(offset < size(), "offset is larger than buffer size"); - return (void*)((char*)data + offset * getDataTypeByteNum(type)); + return (void*)((char*)data + offset * Tensor::getTypeSize(type)); } } @@ -186,129 +131,75 @@ struct Tensor { } } - std::string getNumpyTypeDesc(DataType type) const + template + T max() const { - switch (type) { - case TYPE_BOOL: - return "?"; - case TYPE_UINT8: - return "u1"; - case TYPE_UINT16: - return "u2"; - case TYPE_UINT32: - return "u4"; - case TYPE_UINT64: - return "u8"; - case TYPE_INT8: - return "i1"; - case TYPE_INT16: - return "i2"; - case TYPE_INT32: - return "i4"; - case TYPE_INT64: - return "i8"; - case TYPE_FP16: - return "f2"; - case TYPE_FP32: - return "f4"; - case TYPE_FP64: - return "f8"; - case TYPE_INVALID: - default:; + FT_CHECK_WITH_INFO(shape.size() > 0 && data != nullptr, "Should be a non-empty tensor."); + FT_CHECK_WITH_INFO(where == MEMORY_CPU || where == MEMORY_CPU_PINNED, + "max() supports MEMORY_CPU or MEMORY_CPU_PINNED tensor."); + size_t max_idx = 0; + T max_val = getVal(max_idx); + for (size_t i = 1; i < size(); ++i) { + T val = getVal(i); + if (val > max_val) { + max_idx = i; + max_val = val; + } } - return ""; + return max_val; } - int getDataTypeByteNum(DataType type) const + template + T min() const { - switch (type) { - case TYPE_BOOL: - return 1; - case TYPE_UINT8: - return 1; - case TYPE_UINT16: - return 2; - case TYPE_UINT32: - return 4; - case TYPE_UINT64: - return 8; - case TYPE_INT8: - return 1; - case TYPE_INT16: - return 2; - case TYPE_INT32: - return 4; - case TYPE_INT64: - return 8; - case TYPE_FP16: - return 2; - case TYPE_BF16: - return 2; - case TYPE_FP32: - return 4; - case TYPE_FP64: - return 8; - case TYPE_INVALID: - FT_CHECK(false); - default: - FT_CHECK(false); + FT_CHECK_WITH_INFO(shape.size() > 0 && data != nullptr, "Should be a non-empty tensor."); + FT_CHECK_WITH_INFO(where == MEMORY_CPU || where == MEMORY_CPU_PINNED, + "min() supports MEMORY_CPU or MEMORY_CPU_PINNED tensor."); + size_t min_idx = 0; + T min_val = getVal(min_idx); + for (size_t i = 1; i < size(); ++i) { + T val = getVal(i); + if (val < min_val) { + min_idx = i; + min_val = val; + } } + return min_val; } template - void save(const std::string& filename) const + T any(T val) const { - // Save tensor to NPY 1.0 format (see https://numpy.org/neps/nep-0001-npy-format.html) - void* cpu_data = (void*)data; - bool is_data_temp = false; - size_t tensor_size = size(); - if (where == 
MemoryType::MEMORY_GPU) { - cpu_data = malloc(tensor_size * sizeof(T)); - is_data_temp = true; - cudaDeviceSynchronize(); - cudaMemcpy(cpu_data, data, tensor_size * sizeof(T), cudaMemcpyDeviceToHost); - } - - const char magic[] = "\x93" - "NUMPY"; - const uint8_t npy_major = 1; - const uint8_t npy_minor = 0; - - std::stringstream header_stream; - header_stream << "{'descr': '" << getNumpyTypeDesc(type) << "', 'fortran_order': False, 'shape': ("; - for (size_t i = 0; i < shape.size(); ++i) { - header_stream << shape[i]; - if (i + 1 < shape.size() || shape.size() == 1) { - header_stream << ", "; + FT_CHECK_WITH_INFO(shape.size() > 0 && data != nullptr, "Should be a non-empty tensor."); + FT_CHECK_WITH_INFO(where == MEMORY_CPU || where == MEMORY_CPU_PINNED, + "any() supports MEMORY_CPU or MEMORY_CPU_PINNED tensor."); + for (size_t i = 0; i < size(); ++i) { + if (getVal(i) == val) { + return true; } } - header_stream << ")}"; - int base_length = 6 + 4 + header_stream.str().size(); - int pad_length = 16 * ((base_length + 1 + 15) / 16); // Take ceiling of base_length + 1 (for '\n' ending) - for (int i = 0; i < pad_length - base_length; ++i) { - header_stream << ((i == pad_length - base_length - 1) ? "\n" : "\x20"); - } - std::string header = header_stream.str(); - const uint16_t header_len = header.size(); + return false; + } - FILE* f_ptr = fopen(filename.c_str(), "wb"); - if (f_ptr == nullptr) { - printf("Unable to open %s for writing.\n", filename.c_str()); - exit(-1); + template + T all(T val) const + { + FT_CHECK_WITH_INFO(shape.size() > 0 && data != nullptr, "Should be a non-empty tensor."); + FT_CHECK_WITH_INFO(where == MEMORY_CPU || where == MEMORY_CPU_PINNED, + "all() supports MEMORY_CPU or MEMORY_CPU_PINNED tensor."); + for (size_t i = 0; i < size(); ++i) { + if (getVal(i) != val) { + return false; + } } - fwrite(magic, sizeof(char), sizeof(magic) - 1, f_ptr); - fwrite(&npy_major, sizeof(uint8_t), 1, f_ptr); - fwrite(&npy_minor, sizeof(uint8_t), 1, f_ptr); - fwrite(&header_len, sizeof(uint16_t), 1, f_ptr); - fwrite(header.c_str(), sizeof(char), header_len, f_ptr); - fwrite(cpu_data, sizeof(T), tensor_size, f_ptr); + return true; + } - fclose(f_ptr); + Tensor slice(std::vector shape, size_t offset = 0) const; - if (is_data_temp) { - free(cpu_data); - } - } +private: + static void parseNpyIntro(FILE*& f_ptr, uint32_t& header_len, uint32_t& start_data); + static int parseNpyHeader(FILE*& f_ptr, uint32_t header_len, DataType& type, std::vector& shape); }; template @@ -331,9 +222,139 @@ DataType getTensorType() else if (std::is_same::value) { return TYPE_INT8; } + else if (std::is_same::value) { + return TYPE_UINT32; + } + else if (std::is_same::value) { + return TYPE_UINT64; + } + else if (std::is_same::value) { + return TYPE_BOOL; + } else { return TYPE_INVALID; } } +class TensorMap { +private: + std::unordered_map tensor_map_; + + inline bool isValid(const Tensor& tensor) + { + return tensor.size() > 0 && tensor.data != nullptr; + } + +public: + TensorMap() = default; + TensorMap(const std::unordered_map& tensor_map); + TensorMap(const std::vector& tensor_map); + ~TensorMap(); + + inline size_t size() const + { + return tensor_map_.size(); + } + + inline bool isExist(const std::string& key) const + { + return tensor_map_.find(key) != tensor_map_.end(); + } + + std::vector keys() const; + + inline void insert(const std::string& key, const Tensor& value) + { + FT_CHECK_WITH_INFO(!isExist(key), fmtstr("Duplicated key %s", key.c_str())); + FT_CHECK_WITH_INFO(isValid(value), "A none tensor or 
nullptr is not allowed"); + tensor_map_.insert({key, value}); + } + + // prevent converting int or size_t to string automatically + Tensor at(int tmp) = delete; + Tensor at(size_t tmp) = delete; + + inline Tensor& at(const std::string& key) + { + FT_CHECK_WITH_INFO(isExist(key), + fmtstr("Cannot find a tensor of name %s in the tensor map (keys: %s)", + key.c_str(), + vec2str(keys()).c_str())); + return tensor_map_.at(key); + } + + inline Tensor& at(const std::string& key, Tensor& default_tensor) + { + if (isExist(key)) { + return tensor_map_.at(key); + } + return default_tensor; + } + + inline Tensor& at(const std::string& key, Tensor&& default_tensor) + { + if (isExist(key)) { + return tensor_map_.at(key); + } + return default_tensor; + } + + template + inline T getVal(const std::string& key) const + { + FT_CHECK_WITH_INFO(isExist(key), + fmtstr("Cannot find a tensor of name %s in the tensor map (keys: %s)", + key.c_str(), + vec2str(keys()).c_str())); + return tensor_map_.at(key).getVal(); + } + + template + inline T getVal(const std::string& key, T default_value) const + { + if (isExist(key)) { + return tensor_map_.at(key).getVal(); + } + return default_value; + } + + template + inline T getValWithOffset(const std::string& key, size_t index) const + { + FT_CHECK_WITH_INFO(isExist(key), + fmtstr("Cannot find a tensor of name %s in the tensor map (keys: %s)", + key.c_str(), + vec2str(keys()).c_str())); + return tensor_map_.at(key).getVal(index); + } + + template + inline T getValWithOffset(const std::string& key, size_t index, T default_value) const + { + if (isExist(key)) { + return tensor_map_.at(key).getVal(index); + } + return default_value; + } + + inline std::unordered_map getMap() const + { + return tensor_map_; + } + + inline std::unordered_map::iterator begin() + { + return tensor_map_.begin(); + } + + inline std::unordered_map::iterator end() + { + return tensor_map_.end(); + } + + std::string toString(); + static TensorMap fromNpyFolder(const std::string& base_folder); + void saveNpy(const std::string& base_folder); +}; + } // namespace fastertransformer diff --git a/src/fastertransformer/utils/allocator.h b/src/fastertransformer/utils/allocator.h index f66f4dda2..755065add 100644 --- a/src/fastertransformer/utils/allocator.h +++ b/src/fastertransformer/utils/allocator.h @@ -43,6 +43,10 @@ #include "src/fastertransformer/utils/logger.h" +#if defined(CUDART_VERSION) && CUDART_VERSION < 11020 +#define CUDA_MEMORY_POOL_DISABLED +#endif + namespace fastertransformer { enum class AllocatorType { @@ -51,24 +55,40 @@ enum class AllocatorType { TH }; +enum class ReallocType { + INCREASE, + REUSE, + DECREASE, +}; + class IAllocator { public: - virtual void* malloc(size_t size, const bool is_set_zero = true) = 0; - virtual void free(void* ptr) const = 0; - virtual void setStream(cudaStream_t stream) = 0; + virtual void* malloc(size_t size, const bool is_set_zero = true) = 0; + virtual void free(void** ptr) const = 0; + virtual void setStream(cudaStream_t stream) = 0; + virtual cudaStream_t returnStream() = 0; template void* reMalloc(T* ptr, size_t size, const bool is_set_zero = true) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); - void* void_ptr = (void*)ptr; + size = ((size + 31) / 32) * 32; // make the buffer align with 32 bytes + void* void_ptr = (void*)ptr; std::string ptr_address = getAddress(void_ptr); if (isExist(ptr_address)) { - if (isReMalloc(ptr_address, size)) { + ReallocType realloc_type = isReMalloc(ptr_address, size); + if (realloc_type == ReallocType::INCREASE) { 
FT_LOG_DEBUG("ReMalloc the buffer %p since it is too small.", void_ptr); - free(void_ptr); + free((void**)(&void_ptr)); + return malloc(size, is_set_zero); + } +#if !defined(CUDA_MEMORY_POOL_DISABLED) + else if (realloc_type == ReallocType::DECREASE) { + FT_LOG_DEBUG("ReMalloc the buffer %p to release unused memory to memory pools.", void_ptr); + free((void**)(&void_ptr)); return malloc(size, is_set_zero); } +#endif else { FT_LOG_DEBUG("Reuse original buffer %p and do nothing for reMalloc.", void_ptr); return void_ptr; @@ -81,8 +101,8 @@ class IAllocator { } protected: - virtual bool isExist(std::string address) const = 0; - virtual bool isReMalloc(std::string address, size_t size) const = 0; + virtual bool isExist(std::string address) const = 0; + virtual ReallocType isReMalloc(std::string address, size_t size) const = 0; std::string getAddress(void* ptr) const { @@ -99,22 +119,25 @@ class Allocator; template<> class Allocator: public IAllocator { private: - const int device_id_; - cudaStream_t stream_ = 0; // initialize as default stream + const int device_id_; + cudaStream_t stream_ = 0; // initialize as default stream std::unordered_map>* pointer_mapping_; bool isExist(std::string address) const { return pointer_mapping_->count(address) > 0; } - bool isReMalloc(std::string address, size_t size) const + ReallocType isReMalloc(std::string address, size_t size) const { FT_CHECK(isExist(address)); if (pointer_mapping_->at(address).second < size) { - return true; + return ReallocType::INCREASE; + } + else if (pointer_mapping_->at(address).second == size) { + return ReallocType::REUSE; } else { - return false; + return ReallocType::DECREASE; } } @@ -123,33 +146,35 @@ class Allocator: public IAllocator { { FT_LOG_DEBUG(__PRETTY_FUNCTION__); pointer_mapping_ = new std::unordered_map>(); -#if defined(CUDART_VERSION) && CUDART_VERSION < 11020 +#if defined(CUDA_MEMORY_POOL_DISABLED) FT_LOG_WARNING( "Async cudaMalloc/Free is not supported before CUDA 11.2. Using Sync cudaMalloc/Free." "Note this may lead to hang with NCCL kernels launched in parallel; if so, try NCCL_LAUNCH_MODE=GROUP"); #else int device_count = 1; - cudaGetDeviceCount(&device_count); + check_cuda_error(cudaGetDeviceCount(&device_count)); cudaMemPool_t mempool; - cudaDeviceGetMemPool(&mempool, device_id); - cudaMemAccessDesc desc = {}; - int peer_access_available = 0; + check_cuda_error(cudaDeviceGetDefaultMemPool(&mempool, device_id)); + cudaMemAccessDesc desc = {}; + int peer_access_available = 0; for (int i = 0; i < device_count; i++) { if (i == device_id) { continue; } - cudaDeviceCanAccessPeer(&peer_access_available, device_id, i); + check_cuda_error(cudaDeviceCanAccessPeer(&peer_access_available, device_id, i)); if (!peer_access_available) { - FT_LOG_WARNING( - "Device " + std::to_string(device_id) + " peer access Device " + std::to_string(i) - + " is not avaiable. 
This may lead to peer access errors when doing tensor/pipeline parallel!"); + FT_LOG_WARNING("Device " + std::to_string(device_id) + " peer access Device " + std::to_string(i) + + " is not available."); continue; } desc.location.type = cudaMemLocationTypeDevice; - desc.location.id = i; - desc.flags = cudaMemAccessFlagsProtReadWrite; - cudaMemPoolSetAccess(mempool, &desc, 1); + desc.location.id = i; + desc.flags = cudaMemAccessFlagsProtReadWrite; + check_cuda_error(cudaMemPoolSetAccess(mempool, &desc, 1)); } + // set memory pool threshold to avoid shrinking the pool + uint64_t setVal = UINT64_MAX; + check_cuda_error(cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrReleaseThreshold, &setVal)); #endif } @@ -157,7 +182,7 @@ class Allocator: public IAllocator { { FT_LOG_DEBUG(__PRETTY_FUNCTION__); while (!pointer_mapping_->empty()) { - free(pointer_mapping_->begin()->second.first); + free((void**)(&pointer_mapping_->begin()->second.first)); } delete pointer_mapping_; } @@ -167,20 +192,25 @@ class Allocator: public IAllocator { stream_ = stream; } + cudaStream_t returnStream() + { + return stream_; + }; + void* malloc(size_t size, const bool is_set_zero = true) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); if (size == 0) { return nullptr; } - void* ptr = nullptr; - int o_device = 0; + void* ptr = nullptr; + int o_device = 0; check_cuda_error(getSetDevice(device_id_, &o_device)); -#if defined(CUDART_VERSION) && CUDART_VERSION >= 11020 - check_cuda_error(cudaMallocAsync(&ptr, (size_t)(ceil(size / 32.)) * 32, stream_)); -#else +#if defined(CUDA_MEMORY_POOL_DISABLED) check_cuda_error(cudaMalloc(&ptr, (size_t)(ceil(size / 32.)) * 32)); +#else + check_cuda_error(cudaMallocAsync(&ptr, (size_t)(ceil(size / 32.)) * 32, stream_)); #endif check_cuda_error(getSetDevice(o_device)); FT_LOG_DEBUG("malloc buffer %p with size %ld", ptr, size); @@ -190,20 +220,19 @@ class Allocator: public IAllocator { return ptr; } - void free(void* ptr) const + void free(void** ptr) const { FT_LOG_DEBUG(__PRETTY_FUNCTION__); - std::string address = getAddress(ptr); - if (ptr != nullptr) { + std::string address = getAddress(*ptr); + if (*ptr != nullptr) { int o_device = 0; - if (pointer_mapping_->count(address)) { FT_LOG_DEBUG("Free buffer %s", address.c_str()); check_cuda_error(getSetDevice(device_id_, &o_device)); -#if defined(CUDART_VERSION) && CUDART_VERSION >= 11020 - check_cuda_error(cudaFreeAsync(ptr, stream_)); +#if defined(CUDA_MEMORY_POOL_DISABLED) + check_cuda_error(cudaFree(*ptr)); #else - check_cuda_error(cudaFree(ptr)); + check_cuda_error(cudaFreeAsync(*ptr, stream_)); #endif check_cuda_error(getSetDevice(o_device)); pointer_mapping_->erase(address); @@ -212,6 +241,7 @@ class Allocator: public IAllocator { FT_LOG_WARNING("pointer_mapping_ does not have information of ptr at %s.", address.c_str()); } } + *ptr = nullptr; return; } }; @@ -220,15 +250,15 @@ class Allocator: public IAllocator { using namespace tensorflow; template<> class Allocator: public IAllocator { - OpKernelContext* context_; + OpKernelContext* context_; std::unordered_map* pointer_mapping_; - cudaStream_t stream_; + cudaStream_t stream_; bool isExist(std::string address) const { return pointer_mapping_->count(address) > 0; } - bool isReMalloc(std::string address, size_t size) const + ReallocType isReMalloc(std::string address, size_t size) const { FT_CHECK(isExist(address)); size_t current_buffer_size = 1; @@ -237,10 +267,13 @@ class Allocator: public IAllocator { } FT_LOG_DEBUG("current_buffer_size: %d, new buffer: %d", current_buffer_size, size); if 
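The other interface change visible above is that IAllocator::free() now takes void** and nulls the caller's pointer after releasing the buffer, so member pointers such as d_output_ids_ cannot dangle and a repeated free becomes a no-op. A standard-library-only sketch of that idiom, with std::free standing in for the CUDA (or framework) deallocation:

#include <cstdio>
#include <cstdlib>

// Mirrors the new free(void**) contract: release the buffer, then null the
// caller's pointer so later reMalloc/free calls on the same member are safe.
void freeBuffer(void** ptr)
{
    if (*ptr != nullptr) {
        std::free(*ptr);   // stands in for cudaFreeAsync/cudaFree in the CUDA allocator
    }
    *ptr = nullptr;
}

int main()
{
    int* d_output_ids = static_cast<int*>(std::malloc(16 * sizeof(int)));
    freeBuffer(reinterpret_cast<void**>(&d_output_ids));   // same cast pattern as the diff
    std::printf("after free: %p\n", static_cast<void*>(d_output_ids));   // prints a null pointer value
    freeBuffer(reinterpret_cast<void**>(&d_output_ids));   // idempotent: pointer is already null
}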
(current_buffer_size < size) { - return true; + return ReallocType::INCREASE; + } + else if (current_buffer_size == size) { + return ReallocType::REUSE; } else { - return false; + return ReallocType::DECREASE; } } @@ -255,19 +288,24 @@ class Allocator: public IAllocator { stream_ = stream; } + cudaStream_t returnStream() + { + return stream_; + }; + void* malloc(size_t size, const bool is_set_zero = true) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); tensorflow::Tensor buf; - long long int buf_size = ((long long int)ceil(size / 32.) * 32); - tensorflow::Status status = context_->allocate_temp(DT_UINT8, TensorShape{buf_size}, &buf); + long long int buf_size = ((long long int)ceil(size / 32.) * 32); + tensorflow::Status status = context_->allocate_temp(DT_UINT8, TensorShape{buf_size}, &buf); if (status != tensorflow::Status::OK()) { throw std::runtime_error("TF error: context->allocate_temp failed"); } - auto flat = buf.flat(); - void* ptr = (void*)flat.data(); + auto flat = buf.flat(); + void* ptr = (void*)flat.data(); if (is_set_zero == true) { cudaMemsetAsync(ptr, 0, buf_size, stream_); } @@ -276,18 +314,20 @@ class Allocator: public IAllocator { return ptr; } - void free(void* ptr) const + void free(void** ptr) const { FT_LOG_DEBUG(__PRETTY_FUNCTION__); - std::string address = getAddress(ptr); + std::string address = getAddress(*ptr); pointer_mapping_->erase(address); + *ptr = nullptr; return; } virtual ~Allocator() { while (!pointer_mapping_->empty()) { - free((void*)pointer_mapping_->begin()->second.flat().data()); + void* ptr = pointer_mapping_->begin()->second.flat().data(); + free((void**)(&ptr)); } pointer_mapping_->clear(); delete pointer_mapping_; @@ -304,7 +344,7 @@ class Allocator: public IAllocator { { return pointer_mapping_->count(address) > 0; } - bool isReMalloc(std::string address, size_t size) const + ReallocType isReMalloc(std::string address, size_t size) const { FT_CHECK(isExist(address)); size_t current_buffer_size = 1; @@ -314,10 +354,13 @@ class Allocator: public IAllocator { FT_LOG_DEBUG( "current_buffer_size: %d, original buffer: %p, new buffer: %d", current_buffer_size, address, size); if (current_buffer_size < size) { - return true; + return ReallocType::INCREASE; + } + else if (current_buffer_size == size) { + return ReallocType::REUSE; } else { - return false; + return ReallocType::DECREASE; } } @@ -332,6 +375,12 @@ class Allocator: public IAllocator { // nothing to do here; } + cudaStream_t returnStream() + { + // nothing to do here; + return 0; + }; + void* malloc(size_t size, const bool is_set_zero = true) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); @@ -341,17 +390,18 @@ class Allocator: public IAllocator { // torch::zeros({buf_size}, torch::dtype(torch::kUInt8).device(torch::kCUDA)) : // torch::empty({buf_size}, torch::dtype(torch::kUInt8).device(torch::kCUDA)); torch::Tensor buf = torch::empty({buf_size}, torch::dtype(torch::kUInt8).device(torch::kCUDA)); - void* ptr = buf.data_ptr(); + void* ptr = buf.data_ptr(); FT_LOG_DEBUG("malloc buffer %p with size %ld", ptr, buf_size); pointer_mapping_->insert({getAddress(ptr), buf}); return ptr; } - void free(void* ptr) const + void free(void** ptr) const { FT_LOG_DEBUG(__PRETTY_FUNCTION__); - std::string address = getAddress(ptr); + std::string address = getAddress(*ptr); pointer_mapping_->erase(address); + *ptr = nullptr; return; } @@ -359,11 +409,12 @@ class Allocator: public IAllocator { { FT_LOG_DEBUG(__PRETTY_FUNCTION__); while (!pointer_mapping_->empty()) { - free(pointer_mapping_->begin()->second.data_ptr()); + void* ptr 
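// Illustrative sketch (not part of the patch): why free() now takes void** and writes
// nullptr back through it. The caller's pointer is cleared in the same call, so a later
// double free or use-after-free degrades into a harmless no-op on nullptr.
// release_buffer is a hypothetical stand-in for the allocator's free().
#include <cstdlib>

void release_buffer(void** ptr)
{
    if (ptr == nullptr || *ptr == nullptr) {
        return;  // nothing to release, and safe to call again
    }
    free(*ptr);      // the real allocators call cudaFreeAsync()/cudaFree() here
    *ptr = nullptr;  // leave the caller with a null pointer, not a dangling one
}

int main()
{
    void* buf = malloc(256);
    release_buffer(&buf);  // buf becomes nullptr
    release_buffer(&buf);  // second call is a no-op instead of a double free
    return 0;
}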
= pointer_mapping_->begin()->second.data_ptr(); + free((void**)(&ptr)); } pointer_mapping_->clear(); delete pointer_mapping_; } }; #endif -} // namespace fastertransformer +} // namespace fastertransformer \ No newline at end of file diff --git a/src/fastertransformer/utils/conv2d.h b/src/fastertransformer/utils/conv2d.h index d0f996913..75955f4b5 100644 --- a/src/fastertransformer/utils/conv2d.h +++ b/src/fastertransformer/utils/conv2d.h @@ -26,34 +26,39 @@ namespace fastertransformer { template -void conv2d(T* output, - const T* input, - const T* kernel, - const int batch, - const int h, - const int w, - const int in_channels, - const int out_channels, - const int kernel_size, - const int stride, +void conv2d(T* output, + const T* input, + const T* kernel, + const int batch, + const int h, + const int w, + const int in_channels, + const int out_channels, + const int kernel_size, + const int stride, cudnnHandle_t& cudnn_handle) { cudnnDataType_t dataType; cudnnDataType_t computeType = CUDNN_DATA_FLOAT; - float alpha = 1.0f; - float beta = 0.0f; + float alpha = 1.0f; + float beta = 0.0f; if (std::is_same::value) { dataType = CUDNN_DATA_HALF; } +#ifdef ENABLE_BF16 + else if (std::is_same::value) { + dataType = CUDNN_DATA_BFLOAT16; + } +#endif else { dataType = CUDNN_DATA_FLOAT; } - cudnnTensorDescriptor_t input_descriptor_; - cudnnTensorDescriptor_t output_descriptor_; - cudnnFilterDescriptor_t kernel_descriptor_; + cudnnTensorDescriptor_t input_descriptor_; + cudnnTensorDescriptor_t output_descriptor_; + cudnnFilterDescriptor_t kernel_descriptor_; cudnnConvolutionDescriptor_t convolution_descriptor_; - cudnnConvolutionFwdAlgo_t convolution_algorithm_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM; + cudnnConvolutionFwdAlgo_t convolution_algorithm_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM; // cudnnConvolutionFwdAlgo_t convolution_algorithm_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; // cudnnConvolutionFwdAlgo_t convolution_algorithm_ = CUDNN_CONVOLUTION_FWD_ALGO_GEMM; // cudnnConvolutionFwdAlgo_t convolution_algorithm_ = CUDNN_CONVOLUTION_FWD_ALGO_DIRECT; diff --git a/src/fastertransformer/utils/convert_data_type.h b/src/fastertransformer/utils/convert_data_type.h index af58254e2..7bbdb0c04 100644 --- a/src/fastertransformer/utils/convert_data_type.h +++ b/src/fastertransformer/utils/convert_data_type.h @@ -21,7 +21,7 @@ // be consistent with FasterTransformer int8_t float_to_int8_rn_host(float x) { - int8_t res; + int8_t res; int32_t tmp; if (x >= 0) { tmp = int(x + 0.5); diff --git a/src/fastertransformer/utils/cublasAlgoMap.cc b/src/fastertransformer/utils/cublasAlgoMap.cc index 37dd1f9b9..13d14bc55 100644 --- a/src/fastertransformer/utils/cublasAlgoMap.cc +++ b/src/fastertransformer/utils/cublasAlgoMap.cc @@ -47,11 +47,11 @@ void cublasAlgoMap::loadGemmConfig() return; } - int batchCount2, m2, n2, k2, algoId, customOption, tile, splitK_val; - int batch_size, seq_len, head_num, size_per_head, dataType; - int swizzle, reductionScheme, workspaceSize, stages; + int batchCount2, m2, n2, k2, algoId, customOption, tile, splitK_val; + int batch_size, seq_len, head_num, size_per_head, dataType; + int swizzle, reductionScheme, workspaceSize, stages; float exec_time; - char tmp[1024]; + char tmp[1024]; if (!fgets(tmp, 1024, fd)) { printf("[ERROR] fgets fail at %s:%d \n", __FILE__, __LINE__); exit(-1); @@ -87,15 +87,15 @@ void cublasAlgoMap::loadGemmConfig() std::string markStr(mark); // workspaceSize should be zero if (algo_map_.find(markStr) == algo_map_.end()) { - 
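// Illustrative sketch (not part of the patch): the compile-time dtype dispatch that
// conv2d() performs, extended with the bfloat16 branch added above. CUDNN_DATA_BFLOAT16
// assumes cuDNN 8.1+; the ENABLE_BF16 guard mirrors the build option used in the patch.
// to_cudnn_dtype is a hypothetical helper name for this example.
#include <cudnn.h>
#include <cuda_fp16.h>
#ifdef ENABLE_BF16
#include <cuda_bf16.h>
#endif
#include <type_traits>

template<typename T>
cudnnDataType_t to_cudnn_dtype()
{
    if (std::is_same<T, half>::value) {
        return CUDNN_DATA_HALF;
    }
#ifdef ENABLE_BF16
    else if (std::is_same<T, __nv_bfloat16>::value) {
        return CUDNN_DATA_BFLOAT16;
    }
#endif
    return CUDNN_DATA_FLOAT;  // float and anything else fall back to FP32
}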
algo_map_[markStr].algoId = algoId; - algo_map_[markStr].customOption = customOption; - algo_map_[markStr].tile = tile; - algo_map_[markStr].splitK_val = splitK_val; - algo_map_[markStr].swizzle = swizzle; + algo_map_[markStr].algoId = algoId; + algo_map_[markStr].customOption = customOption; + algo_map_[markStr].tile = tile; + algo_map_[markStr].splitK_val = splitK_val; + algo_map_[markStr].swizzle = swizzle; algo_map_[markStr].reductionScheme = reductionScheme; - algo_map_[markStr].workspaceSize = workspaceSize; - algo_map_[markStr].stages = stages; - algo_map_[markStr].exec_time = exec_time; + algo_map_[markStr].workspaceSize = workspaceSize; + algo_map_[markStr].stages = stages; + algo_map_[markStr].exec_time = exec_time; } } fclose(fd); @@ -121,14 +121,14 @@ cublasAlgoMap::getAlgo(const int batch_count, const int m, const int n, const in cublasLtMatmulAlgo_info tmp_algo; tmp_algo.algoId = static_cast(data_type == FLOAT_DATATYPE ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP); - tmp_algo.customOption = -1; - tmp_algo.tile = -1; - tmp_algo.splitK_val = -1; - tmp_algo.swizzle = -1; + tmp_algo.customOption = -1; + tmp_algo.tile = -1; + tmp_algo.splitK_val = -1; + tmp_algo.swizzle = -1; tmp_algo.reductionScheme = -1; - tmp_algo.workspaceSize = -1; - tmp_algo.stages = -1; - tmp_algo.exec_time = -1.0f; + tmp_algo.workspaceSize = -1; + tmp_algo.stages = -1; + tmp_algo.exec_time = -1.0f; return tmp_algo; } } @@ -144,10 +144,10 @@ void cublasAlgoMap::loadSpGemmConfig() return; } sp_algo_map_.clear(); - int batch_size, seq_len, head_num, size_per_head, data_type; - int batchCount, m, n, k, algoId; + int batch_size, seq_len, head_num, size_per_head, data_type; + int batchCount, m, n, k, algoId; float exec_time; - char tmp[1024]; + char tmp[1024]; if (!fgets(tmp, 1024, fd)) { printf("[ERROR] fgets fail at %s:%d \n", __FILE__, __LINE__); exit(-1); diff --git a/src/fastertransformer/utils/cublasAlgoMap.h b/src/fastertransformer/utils/cublasAlgoMap.h index 668ae28fa..8fff493aa 100644 --- a/src/fastertransformer/utils/cublasAlgoMap.h +++ b/src/fastertransformer/utils/cublasAlgoMap.h @@ -35,28 +35,28 @@ typedef struct { int algoId, customOption, tile, splitK_val; int swizzle, reductionScheme, workspaceSize; // only used in cublasLt >= 11.0 - int stages; + int stages; float exec_time; } cublasLtMatmulAlgo_info; /* Structure to store information about different run trials */ typedef struct { - cublasLtMatmulAlgo_t algo; - cublasStatus_t status; - float time; - size_t workspaceSize; // actual memory workspace needed - cublasMath_t mathMode; + cublasLtMatmulAlgo_t algo; + cublasStatus_t status; + float time; + size_t workspaceSize; // actual memory workspace needed + cublasMath_t mathMode; cublasLtReductionScheme_t reductionScheme; - int customOption; - float wavesCount; + int customOption; + float wavesCount; } customMatmulPerf_t; class cublasAlgoMap { private: std::map algo_map_; - std::string config_filename_; - std::string sp_config_filename_; - std::map sp_algo_map_; + std::string config_filename_; + std::string sp_config_filename_; + std::map sp_algo_map_; public: explicit cublasAlgoMap(const std::string filename, const std::string sp_config_filename = ""); @@ -64,7 +64,7 @@ class cublasAlgoMap { ~cublasAlgoMap(); void loadGemmConfig(); void loadSpGemmConfig(); - int getSpAlgo(const int batch_count, const int m, const int n, const int k); + int getSpAlgo(const int batch_count, const int m, const int n, const int k); bool isUseSparse(const int batch_count, const int m, const int n, const int 
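// Illustrative sketch (not part of the patch): the lookup-with-fallback behaviour of
// cublasAlgoMap::getAlgo(). The "batch_m_n_k" key format and the defaulted fields follow
// the code above; MiniAlgoMap, AlgoInfo and make_key are hypothetical names used only here.
#include <cublas_v2.h>
#include <cstdio>
#include <map>
#include <string>

struct AlgoInfo {
    int   algoId        = -1;
    int   workspaceSize = -1;
    float exec_time     = -1.0f;
};

struct MiniAlgoMap {
    std::map<std::string, AlgoInfo> algos;  // filled from the GEMM config file

    static std::string make_key(int batch, int m, int n, int k)
    {
        char mark[256];
        snprintf(mark, sizeof(mark), "%d_%d_%d_%d", batch, m, n, k);
        return std::string(mark);
    }

    AlgoInfo getAlgo(int batch, int m, int n, int k, bool is_fp32) const
    {
        auto it = algos.find(make_key(batch, m, n, k));
        if (it != algos.end()) {
            return it->second;  // tuned entry from offline profiling
        }
        AlgoInfo fallback;      // no tuned entry: fall back to the cuBLAS defaults
        fallback.algoId =
            static_cast<int>(is_fp32 ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP);
        return fallback;
    }
};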
k); bool isExist(const int batch_count, const int m, const int n, const int k, const CublasDataType data_type); diff --git a/src/fastertransformer/utils/cublasINT8MMWrapper.cc b/src/fastertransformer/utils/cublasINT8MMWrapper.cc index 580299efa..4bcf8d48b 100644 --- a/src/fastertransformer/utils/cublasINT8MMWrapper.cc +++ b/src/fastertransformer/utils/cublasINT8MMWrapper.cc @@ -22,33 +22,33 @@ namespace fastertransformer { cublasINT8MMWrapper::cublasINT8MMWrapper(cublasLtHandle_t cublaslt_handle, - cudaStream_t stream, - cublasAlgoMap* cublas_algo_map, - std::mutex* mu, - bool use_ORDER_COL32_2R_4R4): + cudaStream_t stream, + cublasAlgoMap* cublas_algo_map, + std::mutex* mu, + bool use_ORDER_COL32_2R_4R4): cublasMMWrapper(nullptr, cublaslt_handle, stream, cublas_algo_map, mu, nullptr), use_ORDER_COL32_2R_4R4_(use_ORDER_COL32_2R_4R4) { } -cublasINT8MMWrapper::cublasINT8MMWrapper(cublasHandle_t cublas_handle, +cublasINT8MMWrapper::cublasINT8MMWrapper(cublasHandle_t cublas_handle, cublasLtHandle_t cublaslt_handle, - cudaStream_t stream, - cublasAlgoMap* cublas_algo_map, - std::mutex* mu, - bool use_ORDER_COL32_2R_4R4): + cudaStream_t stream, + cublasAlgoMap* cublas_algo_map, + std::mutex* mu, + bool use_ORDER_COL32_2R_4R4): cublasMMWrapper(cublas_handle, cublaslt_handle, stream, cublas_algo_map, mu, nullptr), use_ORDER_COL32_2R_4R4_(use_ORDER_COL32_2R_4R4) { } #ifdef SPARSITY_ENABLED -cublasINT8MMWrapper::cublasINT8MMWrapper(cublasLtHandle_t cublaslt_handle, +cublasINT8MMWrapper::cublasINT8MMWrapper(cublasLtHandle_t cublaslt_handle, cusparseLtHandle_t cusparselt_handle, - cudaStream_t stream, - cublasAlgoMap* cublas_algo_map, - std::mutex* mu, - bool use_ORDER_COL32_2R_4R4): + cudaStream_t stream, + cublasAlgoMap* cublas_algo_map, + std::mutex* mu, + bool use_ORDER_COL32_2R_4R4): cublasMMWrapper(nullptr, cublaslt_handle, cusparselt_handle, stream, cublas_algo_map, mu, nullptr), use_ORDER_COL32_2R_4R4_(use_ORDER_COL32_2R_4R4) { @@ -81,14 +81,14 @@ cublasINT8MMWrapper::cublasINT8MMWrapper(const cublasINT8MMWrapper& wrapper): // ATransform should be m*n, CUBLASLT_ORDER_COL32 // kernel should be n*k, CUBLASLT_ORDER_COL4_4R2_8C or CUBLASLT_ORDER_COL32_2R_4R4 // res is m*n, CUBLASLT_ORDER_COL32 -void cublasINT8MMWrapper::Gemm(int* res, - int batchCount, - int m, - int n, - int k, - int64_t stridea, - int64_t strideb, - int64_t stridec, +void cublasINT8MMWrapper::Gemm(int* res, + int batchCount, + int m, + int n, + int k, + int64_t stridea, + int64_t strideb, + int64_t stridec, const int8_t* ATransform, const int8_t* kernel) { @@ -99,11 +99,11 @@ void cublasINT8MMWrapper::Gemm(int* res, #else cudaDataType_t computeType = CUDA_R_32I; #endif - cublasLtMatmulDesc_t matmulDesc; + cublasLtMatmulDesc_t matmulDesc; cublasLtMatrixLayout_t AtransformDesc = NULL; cublasLtMatrixLayout_t BtransformDesc = NULL; cublasLtMatrixLayout_t CtransformDesc = NULL; - cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32; + cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32; cublasLtOrder_t order_matrixB; #if (CUDART_VERSION >= 11000) @@ -114,7 +114,7 @@ void cublasINT8MMWrapper::Gemm(int* res, order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; } #else - order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; + order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; #endif int ldaTransform = 32 * m; @@ -157,11 +157,11 @@ void cublasINT8MMWrapper::Gemm(int* res, } int alphaI = 1; - int betaI = 0; + int betaI = 0; // get algo cublasLtMatmulAlgo_t algo; - int findAlgo = 0; + int findAlgo = 0; if (cublas_algo_map_->isExist(batchCount, m, n, k, INT8_DATATYPE)) 
{ // printf("find algo %s\n", markStr.c_str()); findAlgo = 1; @@ -201,10 +201,10 @@ void cublasINT8MMWrapper::Gemm(int* res, else { algoId = 6; } - int swizzle = 0; - int customOption = 0; - int tile = 20; - int splitK_val = 0; + int swizzle = 0; + int customOption = 0; + int tile = 20; + int splitK_val = 0; int reductionScheme = 0; cublasLtMatmulAlgoInit( cublaslt_handle_, computeType, CUDA_R_32I, CUDA_R_8I, CUDA_R_8I, CUDA_R_32I, CUDA_R_32I, algoId, &algo); @@ -256,15 +256,15 @@ void cublasINT8MMWrapper::Gemm(int* res, // ATransform should be m*k CUBLASLT_ORDER_COL32 // kernel should be n*k CUBLASLT_ORDER_COL4_4R2_8C // res is m*n CUBLASLT_ORDER_COL32 -void cublasINT8MMWrapper::Gemm(int8_t* res, - int batchCount, - int m, - int n, - int k, - int64_t stridea, - int64_t strideb, - int64_t stridec, - const float alpha, +void cublasINT8MMWrapper::Gemm(int8_t* res, + int batchCount, + int m, + int n, + int k, + int64_t stridea, + int64_t strideb, + int64_t stridec, + const float alpha, const int8_t* ATransform, const int8_t* kernel) { @@ -278,11 +278,11 @@ void cublasINT8MMWrapper::Gemm(int8_t* res, #else cudaDataType_t computeType = CUDA_R_32I; #endif - cublasLtMatmulDesc_t matmulDesc; + cublasLtMatmulDesc_t matmulDesc; cublasLtMatrixLayout_t AtransformDesc = NULL; cublasLtMatrixLayout_t BtransformDesc = NULL; cublasLtMatrixLayout_t CtransformDesc = NULL; - cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32; + cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32; cublasLtOrder_t order_matrixB; #if (CUDART_VERSION >= 11000) @@ -293,7 +293,7 @@ void cublasINT8MMWrapper::Gemm(int8_t* res, order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; } #else - order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; + order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; #endif int ldaTransform = 32 * m; @@ -342,7 +342,7 @@ void cublasINT8MMWrapper::Gemm(int8_t* res, // get algo cublasLtMatmulAlgo_t algo; - int findAlgo = 0; + int findAlgo = 0; if (cublas_algo_map_->isExist(batchCount, m, n, k, INT8_DATATYPE)) { findAlgo = 1; @@ -381,10 +381,10 @@ void cublasINT8MMWrapper::Gemm(int8_t* res, else { algoId = 6; } - int swizzle = 0; - int customOption = 0; - int tile = 20; - int splitK_val = 0; + int swizzle = 0; + int customOption = 0; + int tile = 20; + int splitK_val = 0; int reductionScheme = 0; cublasLtMatmulAlgoInit( cublaslt_handle_, computeType, CUDA_R_32F, CUDA_R_8I, CUDA_R_8I, CUDA_R_8I, CUDA_R_8I, algoId, &algo); @@ -437,10 +437,10 @@ template int cublasINT8MMWrapper::getFusedINT8QKVType(const int k, const int n, const AttentionWeight* attention_weights) { - int fusedINT8QKV_type = 0; - const int8_t* Q_weight = (const int8_t*)(attention_weights->query_weight.kernel); - const int8_t* K_weight = (const int8_t*)(attention_weights->key_weight.kernel); - const int8_t* V_weight = (const int8_t*)(attention_weights->value_weight.kernel); + int fusedINT8QKV_type = 0; + const int8_t* Q_weight = (const int8_t*)(attention_weights->query_weight.kernel); + const int8_t* K_weight = (const int8_t*)(attention_weights->key_weight.kernel); + const int8_t* V_weight = (const int8_t*)(attention_weights->value_weight.kernel); // for QKV weight are DataType_ & continue if ((attention_weights->query_weight.kernel + n * k == attention_weights->key_weight.kernel) && (attention_weights->key_weight.kernel + n * k == attention_weights->value_weight.kernel)) { @@ -470,29 +470,29 @@ cublasINT8MMWrapper::getFusedINT8QKVType(const int k, const int n, const Attenti void cublasINT8MMWrapper::SpGemm( const int m, const int n, const int k, const float alpha, const 
void* A, const void* B, void* C) { - cudaDataType_t Atype = CUDA_R_8I; - cudaDataType_t Btype = CUDA_R_8I; - cudaDataType_t Ctype = CUDA_R_8I; - cusparseComputeType compute_type = CUSPARSE_COMPUTE_32I; - cusparseOrder_t col_order = CUSPARSE_ORDER_COL; - cusparseOrder_t row_order = CUSPARSE_ORDER_ROW; - cusparseOperation_t opA = CUSPARSE_OPERATION_NON_TRANSPOSE; - cusparseOperation_t opB = CUSPARSE_OPERATION_NON_TRANSPOSE; - cusparseLtMatmulDescriptor_t matmul; + cudaDataType_t Atype = CUDA_R_8I; + cudaDataType_t Btype = CUDA_R_8I; + cudaDataType_t Ctype = CUDA_R_8I; + cusparseComputeType compute_type = CUSPARSE_COMPUTE_32I; + cusparseOrder_t col_order = CUSPARSE_ORDER_COL; + cusparseOrder_t row_order = CUSPARSE_ORDER_ROW; + cusparseOperation_t opA = CUSPARSE_OPERATION_NON_TRANSPOSE; + cusparseOperation_t opB = CUSPARSE_OPERATION_NON_TRANSPOSE; + cusparseLtMatmulDescriptor_t matmul; cusparseLtMatmulAlgSelection_t alg_sel; - cusparseLtMatmulPlan_t plan; - - auto num_A_rows = m; - auto num_A_cols = k; - auto num_B_rows = k; - auto num_B_cols = n; - auto num_C_rows = m; - auto num_C_cols = n; - unsigned alignment = 16; - auto lda = num_A_cols; - auto ldb = num_B_rows; - auto ldc = num_C_rows; - float _beta(0.0f); + cusparseLtMatmulPlan_t plan; + + auto num_A_rows = m; + auto num_A_cols = k; + auto num_B_rows = k; + auto num_B_cols = n; + auto num_C_rows = m; + auto num_C_cols = n; + unsigned alignment = 16; + auto lda = num_A_cols; + auto ldb = num_B_rows; + auto ldc = num_C_rows; + float _beta(0.0f); char mark[256]; sprintf(mark, "%d_%d_%d_%d", 1, m, n, k); @@ -546,9 +546,9 @@ void cublasINT8MMWrapper::SpGemm( CHECK_CUSPARSE(cusparseLtMatmulGetWorkspace(&cusparselt_handle_, &alg_sel, &workspace_size)) CHECK_CUSPARSE(cusparseLtMatmulPlanInit(&cusparselt_handle_, &plan, &matmul, &alg_sel, workspace_size)) - void* d_workspace = nullptr; - int num_streams = 1; - cudaStream_t streams[1] = {stream_}; + void* d_workspace = nullptr; + int num_streams = 1; + cudaStream_t streams[1] = {stream_}; CHECK_CUSPARSE( cusparseLtMatmul(&cusparselt_handle_, &plan, &alpha, A, B, &_beta, C, C, d_workspace, streams, num_streams)) CHECK_CUSPARSE(cusparseLtMatmulPlanDestroy(&plan)) diff --git a/src/fastertransformer/utils/cublasINT8MMWrapper.h b/src/fastertransformer/utils/cublasINT8MMWrapper.h index 92169c4d0..66bc940e3 100644 --- a/src/fastertransformer/utils/cublasINT8MMWrapper.h +++ b/src/fastertransformer/utils/cublasINT8MMWrapper.h @@ -34,50 +34,50 @@ class cublasINT8MMWrapper: public cublasMMWrapper { public: cublasINT8MMWrapper(cublasLtHandle_t cublaslt_handle_, - cudaStream_t stream, - cublasAlgoMap* map, - std::mutex* mu, - bool use_ORDER_COL32_2R_4R4); + cudaStream_t stream, + cublasAlgoMap* map, + std::mutex* mu, + bool use_ORDER_COL32_2R_4R4); - cublasINT8MMWrapper(cublasHandle_t cublas_handle, + cublasINT8MMWrapper(cublasHandle_t cublas_handle, cublasLtHandle_t cublaslt_handle, - cudaStream_t stream, - cublasAlgoMap* map, - std::mutex* mu, - bool use_ORDER_COL32_2R_4R4); + cudaStream_t stream, + cublasAlgoMap* map, + std::mutex* mu, + bool use_ORDER_COL32_2R_4R4); #ifdef SPARSITY_ENABLED - cublasINT8MMWrapper(cublasLtHandle_t cublaslt_handle_, + cublasINT8MMWrapper(cublasLtHandle_t cublaslt_handle_, cusparseLtHandle_t cusparselt_handle, - cudaStream_t stream, - cublasAlgoMap* map, - std::mutex* mu, - bool use_ORDER_COL32_2R_4R4); + cudaStream_t stream, + cublasAlgoMap* map, + std::mutex* mu, + bool use_ORDER_COL32_2R_4R4); #endif ~cublasINT8MMWrapper(); cublasINT8MMWrapper(const cublasINT8MMWrapper& 
wrapper); - void Gemm(int* res, - int batchCount, - int m, - int n, - int k, - int64_t stridea, - int64_t strideb, - int64_t stridec, + void Gemm(int* res, + int batchCount, + int m, + int n, + int k, + int64_t stridea, + int64_t strideb, + int64_t stridec, const int8_t* ATransform, const int8_t* kernel); - void Gemm(int8_t* res, - int batchCount, - int m, - int n, - int k, - int64_t stridea, - int64_t strideb, - int64_t stridec, - const float alpha, + void Gemm(int8_t* res, + int batchCount, + int m, + int n, + int k, + int64_t stridea, + int64_t strideb, + int64_t stridec, + const float alpha, const int8_t* ATransform, const int8_t* kernel); diff --git a/src/fastertransformer/utils/cublasMMWrapper.cc b/src/fastertransformer/utils/cublasMMWrapper.cc index e291151a2..ea9c10a37 100644 --- a/src/fastertransformer/utils/cublasMMWrapper.cc +++ b/src/fastertransformer/utils/cublasMMWrapper.cc @@ -21,12 +21,12 @@ #endif namespace fastertransformer { -cublasMMWrapper::cublasMMWrapper(cublasHandle_t cublas_handle, +cublasMMWrapper::cublasMMWrapper(cublasHandle_t cublas_handle, cublasLtHandle_t cublaslt_handle, - cudaStream_t stream, - cublasAlgoMap* cublas_algo_map, - std::mutex* mu, - IAllocator* allocator): + cudaStream_t stream, + cublasAlgoMap* cublas_algo_map, + std::mutex* mu, + IAllocator* allocator): cublas_handle_(cublas_handle), cublaslt_handle_(cublaslt_handle), stream_(stream), @@ -34,19 +34,20 @@ cublasMMWrapper::cublasMMWrapper(cublasHandle_t cublas_handle, mu_(mu), allocator_(allocator) { + FT_LOG_DEBUG(__PRETTY_FUNCTION__); if (allocator_ != nullptr) { - cublas_workspace_ = allocator_->malloc(CUBLAS_WORKSPACE_SIZE); + cublas_workspace_ = allocator_->reMalloc(cublas_workspace_, CUBLAS_WORKSPACE_SIZE); } } #ifdef SPARSITY_ENABLED -cublasMMWrapper::cublasMMWrapper(cublasHandle_t cublas_handle, - cublasLtHandle_t cublaslt_handle, +cublasMMWrapper::cublasMMWrapper(cublasHandle_t cublas_handle, + cublasLtHandle_t cublaslt_handle, cusparseLtHandle_t cusparselt_handle, - cudaStream_t stream, - cublasAlgoMap* cublas_algo_map, - std::mutex* mu, - IAllocator* allocator): + cudaStream_t stream, + cublasAlgoMap* cublas_algo_map, + std::mutex* mu, + IAllocator* allocator): cublas_handle_(cublas_handle), cublaslt_handle_(cublaslt_handle), cusparselt_handle_(cusparselt_handle), @@ -55,18 +56,19 @@ cublasMMWrapper::cublasMMWrapper(cublasHandle_t cublas_handle, mu_(mu), allocator_(allocator) { + FT_LOG_DEBUG(__PRETTY_FUNCTION__); if (allocator_ != nullptr) { - cublas_workspace_ = allocator_->malloc(CUBLAS_WORKSPACE_SIZE); + cublas_workspace_ = allocator_->reMalloc(cublas_workspace_, CUBLAS_WORKSPACE_SIZE); } } #endif cublasMMWrapper::~cublasMMWrapper() { + FT_LOG_DEBUG(__PRETTY_FUNCTION__); mu_ = nullptr; if (allocator_ != nullptr) { - allocator_->free(cublas_workspace_); - allocator_ = nullptr; + allocator_->free((void**)(&cublas_workspace_)); } } @@ -85,22 +87,22 @@ cublasMMWrapper::cublasMMWrapper(const cublasMMWrapper& wrapper): void cublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, - const int m, - const int n, - const int k, - const void* alpha, - const void* A, - cudaDataType_t Atype, - int lda, - const void* B, - cudaDataType_t Btype, - int ldb, - const void* beta, - void* C, - cudaDataType_t Ctype, - int ldc, - cudaDataType_t computeType, - cublasGemmAlgo_t algo) + const int m, + const int n, + const int k, + const void* alpha, + const void* A, + cudaDataType_t Atype, + int lda, + const void* B, + cudaDataType_t Btype, + int ldb, + const void* beta, + void* C, + 
cudaDataType_t Ctype, + int ldc, + cudaDataType_t computeType, + cublasGemmAlgo_t algo) { mu_->lock(); check_cuda_error(cublasGemmEx(cublas_handle_, @@ -128,45 +130,45 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa, void cublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, - const int m, - const int n, - const int k, - const void* A, - const int lda, - const void* B, - const int ldb, - void* C, - const int ldc) + const int m, + const int n, + const int k, + const void* A, + const int lda, + const void* B, + const int ldb, + void* C, + const int ldc) { Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f); } void cublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, - const int m, - const int n, - const int k, - const void* A, - const int lda, - const void* B, - const int ldb, - void* C, - const int ldc, - float f_alpha, - float f_beta) + const int m, + const int n, + const int k, + const void* A, + const int lda, + const void* B, + const int ldb, + void* C, + const int ldc, + float f_alpha, + float f_beta) { half h_alpha = (half)(f_alpha); - half h_beta = (half)(f_beta); + half h_beta = (half)(f_beta); mu_->lock(); // TODO: default cublas libs - int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0; - bool using_cublasLt = (Atype_ == CUDA_R_16F) ? true : false; - int batch_count = 1; + int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0; + bool using_cublasLt = (Atype_ == CUDA_R_16F) ? true : false; + int batch_count = 1; // fp32 use cublas as default // fp16 use cublasLt as default const void* alpha = is_fp16_computeType ? reinterpret_cast(&h_alpha) : reinterpret_cast(&f_alpha); - const void* beta = is_fp16_computeType ? reinterpret_cast(&h_beta) : reinterpret_cast(&f_beta); + const void* beta = is_fp16_computeType ? reinterpret_cast(&h_beta) : reinterpret_cast(&f_beta); int findAlgo = cublas_algo_map_->isExist(batch_count, m, n, k, getCublasDataType(Atype_)); @@ -181,9 +183,9 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa, } if (using_cublasLt) { - cublasLtMatmulDesc_t operationDesc = NULL; + cublasLtMatmulDesc_t operationDesc = NULL; cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL; - cudaDataType_t scaleType; + cudaDataType_t scaleType; #if (CUDART_VERSION >= 11000) cublasComputeType_t computeType; #else @@ -222,8 +224,8 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa, cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t)); cublasLtMatmulAlgo_t algo; - void* workSpace = cublas_workspace_; - int workspaceSize = cublas_workspace_ == NULL ? 0 : CUBLAS_WORKSPACE_SIZE; + void* workSpace = cublas_workspace_; + int workspaceSize = cublas_workspace_ == NULL ? 
0 : CUBLAS_WORKSPACE_SIZE; if (findAlgo) { if (info.workspaceSize > workspaceSize) { findAlgo = 0; @@ -299,26 +301,26 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa, void cublasMMWrapper::setFP32GemmConfig() { - Atype_ = CUDA_R_32F; - Btype_ = CUDA_R_32F; - Ctype_ = CUDA_R_32F; + Atype_ = CUDA_R_32F; + Btype_ = CUDA_R_32F; + Ctype_ = CUDA_R_32F; computeType_ = CUDA_R_32F; } void cublasMMWrapper::setFP16GemmConfig() { - Atype_ = CUDA_R_16F; - Btype_ = CUDA_R_16F; - Ctype_ = CUDA_R_16F; + Atype_ = CUDA_R_16F; + Btype_ = CUDA_R_16F; + Ctype_ = CUDA_R_16F; computeType_ = CUDA_R_32F; } #ifdef ENABLE_BF16 void cublasMMWrapper::setBF16GemmConfig() { - Atype_ = CUDA_R_16BF; - Btype_ = CUDA_R_16BF; - Ctype_ = CUDA_R_16BF; + Atype_ = CUDA_R_16BF; + Btype_ = CUDA_R_16BF; + Ctype_ = CUDA_R_16BF; computeType_ = CUDA_R_32F; } #endif @@ -328,9 +330,9 @@ void cublasMMWrapper::setGemmConfig(cudaDataType_t aType, cudaDataType_t cType, cudaDataType_t computeType) { - Atype_ = aType; - Btype_ = bType; - Ctype_ = cType; + Atype_ = aType; + Btype_ = bType; + Ctype_ = cType; computeType_ = computeType; } @@ -355,58 +357,58 @@ CublasDataType cublasMMWrapper::getCublasDataType(cudaDataType_t data_type) // only works for cublas 11.x void cublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, - const int m, - const int n, - const int k, - const void* A, - const int lda, - const void* B, - const int ldb, - const void* bias, - void* C, - const int ldc) + const int m, + const int n, + const int k, + const void* A, + const int lda, + const void* B, + const int ldb, + const void* bias, + void* C, + const int ldc) { - cudaDataType_t Atype, Btype, Ctype; + cudaDataType_t Atype, Btype, Ctype; cublasComputeType_t computeType; - cudaDataType_t scaleType; - float alpha_float = 1.0f; - float beta_float = 0.0f; - half alpha_half = half(1.0f); - half beta_half = half(0.0f); - void *alpha, *beta; + cudaDataType_t scaleType; + float alpha_float = 1.0f; + float beta_float = 0.0f; + half alpha_half = half(1.0f); + half beta_half = half(0.0f); + void * alpha, *beta; // int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0; if (Atype_ == CUDA_R_32F) { computeType = CUBLAS_COMPUTE_32F_FAST_TF32; - Atype = CUDA_R_32F; - Btype = CUDA_R_32F; - Ctype = CUDA_R_32F; - scaleType = CUDA_R_32F; - alpha = &alpha_float; - beta = &beta_float; + Atype = CUDA_R_32F; + Btype = CUDA_R_32F; + Ctype = CUDA_R_32F; + scaleType = CUDA_R_32F; + alpha = &alpha_float; + beta = &beta_float; } else if (Atype_ == CUDA_R_16BF) { computeType = CUBLAS_COMPUTE_32F_FAST_TF32; - Atype = CUDA_R_16BF; - Btype = CUDA_R_16BF; - Ctype = CUDA_R_16BF; - scaleType = CUDA_R_32F; - alpha = &alpha_float; - beta = &beta_float; + Atype = CUDA_R_16BF; + Btype = CUDA_R_16BF; + Ctype = CUDA_R_16BF; + scaleType = CUDA_R_32F; + alpha = &alpha_float; + beta = &beta_float; } else { computeType = CUBLAS_COMPUTE_16F; - Atype = CUDA_R_16F; - Btype = CUDA_R_16F; - Ctype = CUDA_R_16F; - scaleType = CUDA_R_16F; - alpha = &alpha_half; - beta = &beta_half; + Atype = CUDA_R_16F; + Btype = CUDA_R_16F; + Ctype = CUDA_R_16F; + scaleType = CUDA_R_16F; + alpha = &alpha_half; + beta = &beta_half; } - cublasLtMatmulDesc_t operationDesc = NULL; + cublasLtMatmulDesc_t operationDesc = NULL; cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL; - cublasLtEpilogue_t epi = CUBLASLT_EPILOGUE_BIAS; + cublasLtEpilogue_t epi = CUBLASLT_EPILOGUE_BIAS; cublasLtMatrixLayoutCreate(&Adesc, Atype, (transa == CUBLAS_OP_N) ? m : k, (transa == CUBLAS_OP_N) ? 
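// Illustrative sketch (not part of the patch): the plain cublasGemmEx path that
// cublasMMWrapper::Gemm() wraps, shown for FP16 inputs with FP32 accumulation — the same
// Atype_/computeType_ combination setFP16GemmConfig() selects. Assumes CUDA 11;
// fp16_gemm is a hypothetical wrapper and d_A/d_B/d_C are valid column-major device buffers.
#include <cublas_v2.h>
#include <cuda_fp16.h>

cublasStatus_t fp16_gemm(cublasHandle_t handle,
                         int m, int n, int k,
                         const half* d_A, int lda,
                         const half* d_B, int ldb,
                         half*       d_C, int ldc)
{
    // Alpha/beta live in the compute type (FP32 here), not the FP16 storage type.
    const float alpha = 1.0f;
    const float beta  = 0.0f;
    return cublasGemmEx(handle,
                        CUBLAS_OP_N, CUBLAS_OP_N,
                        m, n, k,
                        &alpha,
                        d_A, CUDA_R_16F, lda,
                        d_B, CUDA_R_16F, ldb,
                        &beta,
                        d_C, CUDA_R_16F, ldc,
                        CUBLAS_COMPUTE_32F,
                        CUBLAS_GEMM_DEFAULT);
}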
k : m, lda); cublasLtMatrixLayoutCreate(&Bdesc, Btype, (transb == CUBLAS_OP_N) ? k : n, (transb == CUBLAS_OP_N) ? n : k, ldb); cublasLtMatrixLayoutCreate(&Cdesc, Ctype, m, n, ldc); @@ -431,27 +433,27 @@ void cublasMMWrapper::setStream(cudaStream_t stream) void cublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, - const int m, - const int n, - const int k, - const void* A, - const int lda, - const int64_t strideA, - const void* B, - const int ldb, - const int64_t strideB, - void* C, - const int ldc, - const int64_t strideC, - const int batch_count, - const float f_alpha, - const float f_beta) + const int m, + const int n, + const int k, + const void* A, + const int lda, + const int64_t strideA, + const void* B, + const int ldb, + const int64_t strideB, + void* C, + const int ldc, + const int64_t strideC, + const int batch_count, + const float f_alpha, + const float f_beta) { half h_alpha = (half)f_alpha; - half h_beta = (half)f_beta; + half h_beta = (half)f_beta; mu_->lock(); - int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0; + int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0; const void* alpha = is_fp16_computeType ? reinterpret_cast(&h_alpha) : reinterpret_cast(&f_alpha); const void* beta = is_fp16_computeType ? reinterpret_cast(&h_beta) : reinterpret_cast(&f_beta); @@ -486,31 +488,31 @@ void cublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, void cublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, - const int m, - const int n, - const int k, - const float f_alpha, - const void* A, - cudaDataType_t AType, - const int lda, - const int64_t strideA, - const void* B, - cudaDataType_t BType, - const int ldb, - const int64_t strideB, - const float f_beta, - void* C, - cudaDataType_t CType, - const int ldc, - const int64_t strideC, - const int batch_count, - cudaDataType_t computeType) + const int m, + const int n, + const int k, + const float f_alpha, + const void* A, + cudaDataType_t AType, + const int lda, + const int64_t strideA, + const void* B, + cudaDataType_t BType, + const int ldb, + const int64_t strideB, + const float f_beta, + void* C, + cudaDataType_t CType, + const int ldc, + const int64_t strideC, + const int batch_count, + cudaDataType_t computeType) { half h_alpha = (half)f_alpha; - half h_beta = (half)f_beta; + half h_beta = (half)f_beta; mu_->lock(); - int is_fp16_computeType = computeType == CUDA_R_16F ? 1 : 0; + int is_fp16_computeType = computeType == CUDA_R_16F ? 1 : 0; const void* alpha = is_fp16_computeType ? reinterpret_cast(&h_alpha) : reinterpret_cast(&f_alpha); const void* beta = is_fp16_computeType ? 
reinterpret_cast(&h_beta) : reinterpret_cast(&f_beta); @@ -543,29 +545,29 @@ void cublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, mu_->unlock(); } -void cublasMMWrapper::batchedGemm(cublasOperation_t transa, - cublasOperation_t transb, - const int m, - const int n, - const int k, +void cublasMMWrapper::batchedGemm(cublasOperation_t transa, + cublasOperation_t transb, + const int m, + const int n, + const int k, const void* const* A, - const int lda, + const int lda, const void* const* B, - const int ldb, - void* const* C, - const int ldc, - const int batch_count) + const int ldb, + void* const* C, + const int ldc, + const int batch_count) { float f_alpha = static_cast(1.0f); - float f_beta = static_cast(0.0f); + float f_beta = static_cast(0.0f); half h_alpha = (half)1.0f; - half h_beta = (half)0.0f; + half h_beta = (half)0.0f; mu_->lock(); - int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0; + int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0; const void* alpha = is_fp16_computeType ? reinterpret_cast(&h_alpha) : reinterpret_cast(&f_alpha); - const void* beta = is_fp16_computeType ? reinterpret_cast(&h_beta) : reinterpret_cast(&f_beta); + const void* beta = is_fp16_computeType ? reinterpret_cast(&h_beta) : reinterpret_cast(&f_beta); cublasLtMatmulAlgo_info info = cublas_algo_map_->getAlgo(batch_count, m, n, k, getCublasDataType(Atype_)); check_cuda_error(cublasGemmBatchedEx(cublas_handle_, @@ -608,46 +610,46 @@ bool cublasMMWrapper::isFuseBatchGemm(const int batch_count, const int m, const #ifdef SPARSITY_ENABLED void cublasMMWrapper::SpGemm(cublasOperation_t transa, cublasOperation_t transb, - const int m, - const int n, - const int k, - const void* A, - const void* B, - void* C) + const int m, + const int n, + const int k, + const void* A, + const void* B, + void* C) { if (Atype_ != CUDA_R_16F || Btype_ != CUDA_R_16F || Ctype_ != CUDA_R_16F) { throw std::runtime_error("\n[FT][ERROR] sparse GEMM only supports FP16 data type now."); } static bool not_printed_fp32_accumulation_warning = true; if (computeType_ != CUDA_R_16F && not_printed_fp32_accumulation_warning) { - printf("[FT][WRANING] cublasMMWrapper sets to FP32 compute type, " + printf("[FT][WARNING] cublasMMWrapper sets to FP32 compute type, " "but sparse gemm will use FP16 compute type since cusparselt " "supports FP16 accumulation only.\n"); not_printed_fp32_accumulation_warning = false; } - cusparseOrder_t order = CUSPARSE_ORDER_COL; + cusparseOrder_t order = CUSPARSE_ORDER_COL; cusparseOperation_t opA = (transa == CUBLAS_OP_N) ? CUSPARSE_OPERATION_NON_TRANSPOSE : CUSPARSE_OPERATION_TRANSPOSE; cusparseOperation_t opB = (transb == CUBLAS_OP_N) ? CUSPARSE_OPERATION_NON_TRANSPOSE : CUSPARSE_OPERATION_TRANSPOSE; cusparseComputeType compute_type = CUSPARSE_COMPUTE_16F; - cusparseLtMatmulDescriptor_t matmul; + cusparseLtMatmulDescriptor_t matmul; cusparseLtMatmulAlgSelection_t alg_sel; - cusparseLtMatmulPlan_t plan; - - bool is_rowmajor = (order == CUSPARSE_ORDER_ROW); - bool isA_transposed = (opA != CUSPARSE_OPERATION_NON_TRANSPOSE); - bool isB_transposed = (opB != CUSPARSE_OPERATION_NON_TRANSPOSE); - auto num_A_rows = (isA_transposed) ? k : m; - auto num_A_cols = (isA_transposed) ? m : k; - auto num_B_rows = (isB_transposed) ? n : k; - auto num_B_cols = (isB_transposed) ? k : n; - auto num_C_rows = m; - auto num_C_cols = n; - unsigned alignment = 16; - auto lda = (is_rowmajor) ? num_A_cols : num_A_rows; - auto ldb = (is_rowmajor) ? num_B_cols : num_B_rows; - auto ldc = (is_rowmajor) ? 
num_C_cols : num_C_rows; - float _alpha(1.0f); - float _beta(0.0f); + cusparseLtMatmulPlan_t plan; + + bool is_rowmajor = (order == CUSPARSE_ORDER_ROW); + bool isA_transposed = (opA != CUSPARSE_OPERATION_NON_TRANSPOSE); + bool isB_transposed = (opB != CUSPARSE_OPERATION_NON_TRANSPOSE); + auto num_A_rows = (isA_transposed) ? k : m; + auto num_A_cols = (isA_transposed) ? m : k; + auto num_B_rows = (isB_transposed) ? n : k; + auto num_B_cols = (isB_transposed) ? k : n; + auto num_C_rows = m; + auto num_C_cols = n; + unsigned alignment = 16; + auto lda = (is_rowmajor) ? num_A_cols : num_A_rows; + auto ldb = (is_rowmajor) ? num_B_cols : num_B_rows; + auto ldc = (is_rowmajor) ? num_C_cols : num_C_rows; + float _alpha(1.0f); + float _beta(0.0f); char mark[256]; sprintf(mark, "%d_%d_%d_%d", 1, m, n, k); @@ -701,9 +703,9 @@ void cublasMMWrapper::SpGemm(cublasOperation_t transa, CHECK_CUSPARSE(cusparseLtMatmulGetWorkspace(&cusparselt_handle_, &alg_sel, &workspace_size)) CHECK_CUSPARSE(cusparseLtMatmulPlanInit(&cusparselt_handle_, &plan, &matmul, &alg_sel, workspace_size)) - void* d_workspace = nullptr; - int num_streams = 1; - cudaStream_t streams[1] = {stream_}; + void* d_workspace = nullptr; + int num_streams = 1; + cudaStream_t streams[1] = {stream_}; CHECK_CUSPARSE( cusparseLtMatmul(&cusparselt_handle_, &plan, &_alpha, A, B, &_beta, C, C, d_workspace, streams, num_streams)) CHECK_CUSPARSE(cusparseLtMatmulPlanDestroy(&plan)) @@ -714,12 +716,12 @@ void cublasMMWrapper::SpGemm(cublasOperation_t transa, size_t cublasMMWrapper::getSparseMatrixSize(int m, int k) { // Get a compressed matrix size of shape (m, k) used in cusparselt. - auto Atype_ = CUDA_R_16F; - cusparseOrder_t order = CUSPARSE_ORDER_COL; - unsigned alignment = 16; - int num_A_rows = m; - int num_A_cols = k; - int lda = num_A_rows; + auto Atype_ = CUDA_R_16F; + cusparseOrder_t order = CUSPARSE_ORDER_COL; + unsigned alignment = 16; + int num_A_rows = m; + int num_A_cols = k; + int lda = num_A_rows; cusparseLtMatDescriptor_t matA; CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(&cusparselt_handle_, @@ -738,10 +740,10 @@ size_t cublasMMWrapper::getSparseMatrixSize(int m, int k) void cublasMMWrapper::compressMatrix(const void* input, void* output, const int m, const int k) { - cusparseOrder_t order = CUSPARSE_ORDER_COL; - cusparseOperation_t opA = CUSPARSE_OPERATION_NON_TRANSPOSE; + cusparseOrder_t order = CUSPARSE_ORDER_COL; + cusparseOperation_t opA = CUSPARSE_OPERATION_NON_TRANSPOSE; cusparseLtMatDescriptor_t matA; - unsigned alignment = 16; + unsigned alignment = 16; CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit( &cusparselt_handle_, &matA, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT)) CHECK_CUSPARSE(cusparseLtSpMMACompress2(&cusparselt_handle_, &matA, true, opA, input, output, stream_)) diff --git a/src/fastertransformer/utils/cublasMMWrapper.h b/src/fastertransformer/utils/cublasMMWrapper.h index 6f410ab3f..537102bae 100644 --- a/src/fastertransformer/utils/cublasMMWrapper.h +++ b/src/fastertransformer/utils/cublasMMWrapper.h @@ -29,10 +29,10 @@ namespace fastertransformer { class cublasMMWrapper { private: - cublasHandle_t cublas_handle_; + cublasHandle_t cublas_handle_; cublasLtHandle_t cublaslt_handle_; #ifdef SPARSITY_ENABLED - cusparseLtHandle_t cusparselt_handle_; + cusparseLtHandle_t cusparselt_handle_; std::map sp_mat_A_desc_map_; std::map sp_mat_B_desc_map_; std::map sp_mat_C_desc_map_; @@ -43,31 +43,31 @@ class cublasMMWrapper { cudaDataType_t Ctype_; cudaDataType_t computeType_; - 
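// Illustrative sketch (not part of the patch): how SpGemm() derives matrix extents and
// leading dimensions from the transpose flag and the storage order before building the
// cusparseLt descriptors. MatDims and dims_for_A are hypothetical names used only here.
struct MatDims {
    int rows;
    int cols;
    int ld;
};

MatDims dims_for_A(int m, int k, bool transposed, bool row_major)
{
    MatDims d;
    d.rows = transposed ? k : m;  // op(A) is m x k, so A itself is k x m when transposed
    d.cols = transposed ? m : k;
    d.ld   = row_major ? d.cols : d.rows;  // leading dimension follows the storage order
    return d;
}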
cudaStream_t stream_; + cudaStream_t stream_; cublasAlgoMap* cublas_algo_map_; - std::mutex* mu_; + std::mutex* mu_; - IAllocator* allocator_ = nullptr; - void* cublas_workspace_ = nullptr; + IAllocator* allocator_ = nullptr; + void* cublas_workspace_ = nullptr; friend class cublasINT8MMWrapper; public: - cublasMMWrapper(cublasHandle_t cublas_handle_, + cublasMMWrapper(cublasHandle_t cublas_handle_, cublasLtHandle_t cublaslt_handle_, - cudaStream_t stream, - cublasAlgoMap* map, - std::mutex* mu, - IAllocator* allocator); + cudaStream_t stream, + cublasAlgoMap* map, + std::mutex* mu, + IAllocator* allocator); #ifdef SPARSITY_ENABLED - cublasMMWrapper(cublasHandle_t cublas_handle_, - cublasLtHandle_t cublaslt_handle_, + cublasMMWrapper(cublasHandle_t cublas_handle_, + cublasLtHandle_t cublaslt_handle_, cusparseLtHandle_t cusparselt_handle, - cudaStream_t stream, - cublasAlgoMap* map, - std::mutex* mu, - IAllocator* allocator); + cudaStream_t stream, + cublasAlgoMap* map, + std::mutex* mu, + IAllocator* allocator); #endif ~cublasMMWrapper(); @@ -76,48 +76,48 @@ class cublasMMWrapper { void Gemm(cublasOperation_t transa, cublasOperation_t transb, - const int m, - const int n, - const int k, - const void* alpha, - const void* A, - cudaDataType_t Atype, - int lda, - const void* B, - cudaDataType_t Btype, - int ldb, - const void* beta, - void* C, - cudaDataType_t Ctype, - int ldc, - cudaDataType_t computeType, - cublasGemmAlgo_t algo); + const int m, + const int n, + const int k, + const void* alpha, + const void* A, + cudaDataType_t Atype, + int lda, + const void* B, + cudaDataType_t Btype, + int ldb, + const void* beta, + void* C, + cudaDataType_t Ctype, + int ldc, + cudaDataType_t computeType, + cublasGemmAlgo_t algo); void Gemm(cublasOperation_t transa, cublasOperation_t transb, - const int m, - const int n, - const int k, - const void* A, - const int lda, - const void* B, - const int ldb, - void* C, - const int ldc); + const int m, + const int n, + const int k, + const void* A, + const int lda, + const void* B, + const int ldb, + void* C, + const int ldc); void Gemm(cublasOperation_t transa, cublasOperation_t transb, - const int m, - const int n, - const int k, - const void* A, - const int lda, - const void* B, - const int ldb, - void* C, - const int ldc, - float f_alpha, - float f_beta); + const int m, + const int n, + const int k, + const void* A, + const int lda, + const void* B, + const int ldb, + void* C, + const int ldc, + float f_alpha, + float f_beta); void setFP32GemmConfig(); void setFP16GemmConfig(); @@ -133,85 +133,85 @@ class cublasMMWrapper { #if (CUDART_VERSION >= 11000) void Gemm(cublasOperation_t transa, cublasOperation_t transb, - const int m, - const int n, - const int k, - const void* A, - const int lda, - const void* B, - const int ldb, - const void* bias, - void* C, - const int ldc); + const int m, + const int n, + const int k, + const void* A, + const int lda, + const void* B, + const int ldb, + const void* bias, + void* C, + const int ldc); #endif void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, - const int m, - const int n, - const int k, - const void* A, - const int lda, - const int64_t strideA, - const void* B, - const int ldb, - const int64_t strideB, - void* C, - const int ldc, - const int64_t strideC, - const int batchCount, - const float f_alpha = 1.0f, - const float f_beta = 0.0f); + const int m, + const int n, + const int k, + const void* A, + const int lda, + const int64_t strideA, + const void* B, + const int ldb, + const int64_t 
strideB, + void* C, + const int ldc, + const int64_t strideC, + const int batchCount, + const float f_alpha = 1.0f, + const float f_beta = 0.0f); void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, - const int m, - const int n, - const int k, - const float f_alpha, - const void* A, - cudaDataType_t AType, - const int lda, - const int64_t strideA, - const void* B, - cudaDataType_t BType, - const int ldb, - const int64_t strideB, - const float f_beta, - void* C, - cudaDataType_t CType, - const int ldc, - const int64_t strideC, - const int batch_count, - cudaDataType_t computeType); - - void batchedGemm(cublasOperation_t transa, - cublasOperation_t transb, - const int m, - const int n, - const int k, + const int m, + const int n, + const int k, + const float f_alpha, + const void* A, + cudaDataType_t AType, + const int lda, + const int64_t strideA, + const void* B, + cudaDataType_t BType, + const int ldb, + const int64_t strideB, + const float f_beta, + void* C, + cudaDataType_t CType, + const int ldc, + const int64_t strideC, + const int batch_count, + cudaDataType_t computeType); + + void batchedGemm(cublasOperation_t transa, + cublasOperation_t transb, + const int m, + const int n, + const int k, const void* const* A, - const int lda, + const int lda, const void* const* B, - const int ldb, - void* const* C, - const int ldc, - const int batch_count); + const int ldb, + void* const* C, + const int ldc, + const int batch_count); bool isFuseBatchGemm(const int batch_count, const int m, const int k, const int n); #ifdef SPARSITY_ENABLED void SpGemm(cublasOperation_t transa, cublasOperation_t transb, - const int m, - const int n, - const int k, - const void* A, - const void* B, - void* C); + const int m, + const int n, + const int k, + const void* A, + const void* B, + void* C); size_t getSparseMatrixSize(int m, int k); - void compressMatrix(const void* input, void* output, const int m, const int k); + void compressMatrix(const void* input, void* output, const int m, const int k); bool isUseSparse(const int batch_count, const int m, const int n, const int k); #endif diff --git a/src/fastertransformer/utils/cuda_utils.h b/src/fastertransformer/utils/cuda_utils.h index 5d73c87c5..49f1525ef 100644 --- a/src/fastertransformer/utils/cuda_utils.h +++ b/src/fastertransformer/utils/cuda_utils.h @@ -45,10 +45,10 @@ typedef struct half4 { /* **************************** type definition ***************************** */ enum CublasDataType { - FLOAT_DATATYPE = 0, - HALF_DATATYPE = 1, + FLOAT_DATATYPE = 0, + HALF_DATATYPE = 1, BFLOAT16_DATATYPE = 2, - INT8_DATATYPE = 3 + INT8_DATATYPE = 3 }; enum FtCudaDataType { @@ -147,17 +147,11 @@ void print_to_file(const T* result, const int size, const char* file) cudaDeviceSynchronize(); check_cuda_error(cudaGetLastError()); printf("[INFO] file: %s \n", file); - FILE* fd = fopen(file, "w"); - T* tmp = reinterpret_cast(malloc(sizeof(T) * size)); + FILE* fd = fopen(file, "w"); + T* tmp = reinterpret_cast(malloc(sizeof(T) * size)); check_cuda_error(cudaMemcpy(tmp, result, sizeof(T) * size, cudaMemcpyDeviceToHost)); for (int i = 0; i < size; ++i) { - float val; - if (sizeof(T) == 2) { - val = (T)__half2float(tmp[i]); - } - else { - val = (T)tmp[i]; - } + float val = (float)(tmp[i]); fprintf(fd, "%f\n", val); } free(tmp); @@ -167,10 +161,10 @@ void print_to_file(const T* result, const int size, const char* file) } template -void print_to_file(const T* result, - const int size, - const char* file, - cudaStream_t stream, +void print_to_file(const T* 
result, + const int size, + const char* file, + cudaStream_t stream, std::ios::openmode open_mode = std::ios::out) { cudaDeviceSynchronize(); @@ -212,12 +206,12 @@ void print_abs_mean(const T* buf, uint size, cudaStream_t stream, std::string na cudaMemcpyAsync(h_tmp, buf, sizeof(T) * size, cudaMemcpyDeviceToHost, stream); cudaDeviceSynchronize(); check_cuda_error(cudaGetLastError()); - double sum = 0.0f; + double sum = 0.0f; uint64_t zero_count = 0; - float max_val = -1e10; - bool find_inf = false; + float max_val = -1e10; + bool find_inf = false; for (uint i = 0; i < size; i++) { - if (std::isinf(h_tmp[i])) { + if (std::isinf((float)(h_tmp[i]))) { find_inf = true; continue; } @@ -292,6 +286,117 @@ static inline void printMatrix(T* ptr, int m, int k, int stride, bool is_device_ } } +static inline void printMatrix(unsigned long long* ptr, int m, int k, int stride, bool is_device_ptr) +{ + typedef unsigned long long T; + T* tmp; + if (is_device_ptr) { + // k < stride ; stride = col-dimension. + tmp = reinterpret_cast(malloc(m * stride * sizeof(T))); + check_cuda_error(cudaMemcpy(tmp, ptr, sizeof(T) * m * stride, cudaMemcpyDeviceToHost)); + cudaDeviceSynchronize(); + } + else { + tmp = ptr; + } + + for (int ii = -1; ii < m; ++ii) { + if (ii >= 0) { + printf("%02d ", ii); + } + else { + printf(" "); + } + + for (int jj = 0; jj < k; jj += 1) { + if (ii >= 0) { + printf("%4llu ", tmp[ii * stride + jj]); + } + else { + printf("%4d ", jj); + } + } + printf("\n"); + } + if (is_device_ptr) { + free(tmp); + } +} + +static inline void printMatrix(int* ptr, int m, int k, int stride, bool is_device_ptr) +{ + typedef int T; + T* tmp; + if (is_device_ptr) { + // k < stride ; stride = col-dimension. + tmp = reinterpret_cast(malloc(m * stride * sizeof(T))); + check_cuda_error(cudaMemcpy(tmp, ptr, sizeof(T) * m * stride, cudaMemcpyDeviceToHost)); + cudaDeviceSynchronize(); + } + else { + tmp = ptr; + } + + for (int ii = -1; ii < m; ++ii) { + if (ii >= 0) { + printf("%02d ", ii); + } + else { + printf(" "); + } + + for (int jj = 0; jj < k; jj += 1) { + if (ii >= 0) { + printf("%4d ", tmp[ii * stride + jj]); + } + else { + printf("%4d ", jj); + } + } + printf("\n"); + } + if (is_device_ptr) { + free(tmp); + } +} + +static inline void printMatrix(size_t* ptr, int m, int k, int stride, bool is_device_ptr) +{ + typedef size_t T; + T* tmp; + if (is_device_ptr) { + // k < stride ; stride = col-dimension. 
+ tmp = reinterpret_cast(malloc(m * stride * sizeof(T))); + check_cuda_error(cudaMemcpy(tmp, ptr, sizeof(T) * m * stride, cudaMemcpyDeviceToHost)); + cudaDeviceSynchronize(); + } + else { + tmp = ptr; + } + + for (int ii = -1; ii < m; ++ii) { + if (ii >= 0) { + printf("%02d ", ii); + } + else { + printf(" "); + } + + for (int jj = 0; jj < k; jj += 1) { + if (ii >= 0) { + printf("%4ld ", tmp[ii * stride + jj]); + } + else { + printf("%4d ", jj); + } + } + printf("\n"); + } + if (is_device_ptr) { + free(tmp); + } +} + template void check_max_val(const T* result, const int size) { @@ -352,8 +457,8 @@ inline void myAssert(bool result, const char* const file, int const line, std::s /*************Time Handling**************/ class CudaTimer { private: - cudaEvent_t event_start_; - cudaEvent_t event_stop_; + cudaEvent_t event_start_; + cudaEvent_t event_stop_; cudaStream_t stream_; public: @@ -387,13 +492,14 @@ static double diffTime(timeval start, timeval end) /* ***************************** common utils ****************************** */ -inline void print_mem_usage() +inline void print_mem_usage(std::string time="after allocation") { size_t free_bytes, total_bytes; check_cuda_error(cudaMemGetInfo(&free_bytes, &total_bytes)); - float free = static_cast(free_bytes) / 1024.0 / 1024.0 / 1024.0; + float free = static_cast(free_bytes) / 1024.0 / 1024.0 / 1024.0; float total = static_cast(total_bytes) / 1024.0 / 1024.0 / 1024.0; - printf("after allocation, free %.2f GB total %.2f GB\n", free, total); + float used = total - free; + printf("%-20s: free: %5.2f GB, total: %5.2f GB, used: %5.2f GB\n", time.c_str(), free, total, used); } inline int getSMVersion() @@ -421,8 +527,8 @@ inline int div_up(int a, int n) inline cudaError_t getSetDevice(int i_device, int* o_device = NULL) { - int current_dev_id = 0; - cudaError_t err = cudaSuccess; + int current_dev_id = 0; + cudaError_t err = cudaSuccess; if (o_device != NULL) { err = cudaGetDevice(¤t_dev_id); @@ -524,7 +630,7 @@ struct getTypeFromCudaDataType { inline FtCudaDataType getModelFileType(std::string ini_file, std::string section_name) { FtCudaDataType model_file_type; - INIReader reader = INIReader(ini_file); + INIReader reader = INIReader(ini_file); if (reader.ParseError() < 0) { FT_LOG_WARNING("Can't load %s. 
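// Illustrative sketch (not part of the patch): the cudaMemGetInfo() pattern behind the
// reworked print_mem_usage(), which now also reports used memory and takes a label.
// report_mem is a hypothetical name for this example.
#include <cstdio>
#include <string>
#include <cuda_runtime.h>

void report_mem(const std::string& label)
{
    size_t free_bytes = 0, total_bytes = 0;
    if (cudaMemGetInfo(&free_bytes, &total_bytes) != cudaSuccess) {
        return;  // leave the report out rather than print garbage
    }
    const double gib     = 1024.0 * 1024.0 * 1024.0;
    const double free_gb = free_bytes / gib;
    const double tot_gb  = total_bytes / gib;
    printf("%-20s: free: %5.2f GB, total: %5.2f GB, used: %5.2f GB\n",
           label.c_str(), free_gb, tot_gb, tot_gb - free_gb);
}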
Use FP32 as default", ini_file.c_str()); model_file_type = FtCudaDataType::FP32; diff --git a/src/fastertransformer/utils/custom_ar_comm.cc b/src/fastertransformer/utils/custom_ar_comm.cc index ded1e58c8..b4cbb2b1e 100644 --- a/src/fastertransformer/utils/custom_ar_comm.cc +++ b/src/fastertransformer/utils/custom_ar_comm.cc @@ -23,9 +23,9 @@ CustomAllReduceComm::CustomAllReduceComm(size_t rank_size, size_t rank): rank { param_.barrier_flag = 0; // NOTE: assume All Reduce happens within the node (DGX A100) - param_.rank = rank_; + param_.rank = rank_; param_.local_rank = rank_; - param_.node_id = 0; + param_.node_id = 0; } template @@ -45,7 +45,7 @@ CustomAllReduceComm::~CustomAllReduceComm() template void CustomAllReduceComm::customAllReduce(size_t elts, cudaStream_t stream) { - param_.elts_total = elts; + param_.elts_total = elts; param_.barrier_flag = FLAG(param_.barrier_flag + 1); invokeOneOrTwoShotAllReduceKernel(param_, stream); @@ -69,8 +69,8 @@ void CustomAllReduceComm::allocateAndExchangePeerAccessPointer( cudaMalloc(&(param_.peer_barrier_ptrs[i]), rank_size_ * (MAX_ALL_REDUCE_BLOCKS + 1) * sizeof(uint32_t))); check_cuda_error( cudaMemset(param_.peer_barrier_ptrs[i], 0, rank_size_ * (MAX_ALL_REDUCE_BLOCKS + 1) * sizeof(uint32_t))); - T* current_peer_comm_buffer_ptr = param_.peer_comm_buffer_ptrs[i]; - uint32_t* current_peer_barrier_ptr = param_.peer_barrier_ptrs[i]; + T* current_peer_comm_buffer_ptr = param_.peer_comm_buffer_ptrs[i]; + uint32_t* current_peer_barrier_ptr = param_.peer_barrier_ptrs[i]; // Assume current comm allocates device memory on all ranks (rank_ == 0) for (size_t j = 1; j < rank_size_; j++) { static_cast*>(custom_all_reduce_comms->at(j).get()) @@ -112,9 +112,9 @@ bool CustomAllReduceComm::swapInternalBuffer(std::vector* tensor_buff // If meet, then swap the local comm buffer ptr with output tensor data pointer (avoid additional // memory movement) if (rank_size_ > 1 && elts * sizeof(T) <= CUSTOM_AR_SIZE_THRESHOLD) { - tmp_tensor_data_ = (T*)(tensor_buffer->at(0).data); - output_tensor_ = tensor_buffer; - tensor_buffer->at(0).data = param_.peer_comm_buffer_ptrs[rank_]; + tmp_tensor_data_ = (T*)(tensor_buffer->at(0).data); + output_tensor_ = tensor_buffer; + tensor_buffer->at(0).data = param_.peer_comm_buffer_ptrs[rank_]; param_.local_output_buffer_ptr = tmp_tensor_data_; return true; } @@ -123,8 +123,8 @@ bool CustomAllReduceComm::swapInternalBuffer(std::vector* tensor_buff template void initCustomAllReduceComm(std::vector>* custom_all_reduce_comms, - int enable_custom_all_reduce, - size_t rank_size) + int enable_custom_all_reduce, + size_t rank_size) { if (custom_all_reduce_comms == 0) { // don't use custom all reduce kernels, fall back to NCCL @@ -135,7 +135,17 @@ void initCustomAllReduceComm(std::vector>* c } if (rank_size != RANKS_PER_NODE) { - FT_LOG_WARNING("Custom All Reduce only supports 8 Ranks currently. Using NCCL as Comm."); +#ifdef BUILD_MULTI_GPU + if (rank_size > 1) { + FT_LOG_WARNING("Custom All Reduce only supports 8 Ranks currently. Using NCCL as Comm."); + } +#else + FT_CHECK_WITH_INFO(rank_size == 1, + fmtstr("Custom All Reduce only supports 8 Ranks currently, got rank_size %ld. FT needs " + "the NCCL library to communicate among devices but has built without NCCL. 
" + "Please use the flag -DBUILD_MULTI_GPU=ON when compiling.", + rank_size)); +#endif for (size_t i = 0; i < rank_size; i++) { custom_all_reduce_comms->push_back(nullptr); } @@ -163,17 +173,17 @@ template class CustomAllReduceComm<__nv_bfloat16>; template class CustomAllReduceComm; template void initCustomAllReduceComm(std::vector>* custom_all_reduce_comms, - int enable_custom_all_reduce, - size_t rank_size); + int enable_custom_all_reduce, + size_t rank_size); #ifdef ENABLE_BF16 template void initCustomAllReduceComm<__nv_bfloat16>(std::vector>* custom_all_reduce_comms, - int enable_custom_all_reduce, - size_t rank_size); + int enable_custom_all_reduce, + size_t rank_size); #endif template void initCustomAllReduceComm(std::vector>* custom_all_reduce_comms, - int enable_custom_all_reduce, - size_t rank_size); + int enable_custom_all_reduce, + size_t rank_size); -} // namespace fastertransformer \ No newline at end of file +} // namespace fastertransformer diff --git a/src/fastertransformer/utils/custom_ar_comm.h b/src/fastertransformer/utils/custom_ar_comm.h index 798211e09..a88d8fc16 100644 --- a/src/fastertransformer/utils/custom_ar_comm.h +++ b/src/fastertransformer/utils/custom_ar_comm.h @@ -30,10 +30,10 @@ namespace fastertransformer { class AbstractCustomComm { public: - AbstractCustomComm() = default; - virtual ~AbstractCustomComm() = default; - virtual void customAllReduce(size_t elts, cudaStream_t stream) = 0; - virtual void enableP2P(int ngpus) = 0; + AbstractCustomComm() = default; + virtual ~AbstractCustomComm() = default; + virtual void customAllReduce(size_t elts, cudaStream_t stream) = 0; + virtual void enableP2P(int ngpus) = 0; virtual bool swapInternalBuffer(std::vector* tensor_buffer, size_t elts) = 0; virtual void allocateAndExchangePeerAccessPointer(std::vector>* custom_all_reduce_comms) = 0; @@ -55,17 +55,17 @@ class CustomAllReduceComm: public AbstractCustomComm { void enableP2P(int ngpus) override; private: - AllReduceParams param_; + AllReduceParams param_; std::vector* output_tensor_; - T* tmp_tensor_data_; - size_t rank_size_; - size_t rank_; + T* tmp_tensor_data_; + size_t rank_size_; + size_t rank_; }; template void initCustomAllReduceComm(std::vector>* custom_all_reduce_comms, - int enable_custom_all_reduce, - size_t rank_size); + int enable_custom_all_reduce, + size_t rank_size); template struct CustomARCommTypeConverter { diff --git a/src/fastertransformer/utils/gemm.cc b/src/fastertransformer/utils/gemm.cc index 035d5b3a9..e545ca230 100644 --- a/src/fastertransformer/utils/gemm.cc +++ b/src/fastertransformer/utils/gemm.cc @@ -23,14 +23,14 @@ namespace fastertransformer { Gemm::Gemm(IAllocator* allocator, cudaStream_t stream, std::string config_file) { allocator_ = allocator; - stream_ = stream; - mutex_ = new std::mutex(); // mutex per process + stream_ = stream; + mutex_ = new std::mutex(); // mutex per process check_cuda_error(cublasCreate(&cublas_handle_)); check_cuda_error(cublasLtCreate(&cublaslt_handle_)); check_cuda_error(cublasSetStream(cublas_handle_, stream)); if (allocator_ != nullptr) { - workspace_ = allocator_->malloc(WORKSPACE_SIZE); + workspace_ = allocator_->reMalloc(workspace_, WORKSPACE_SIZE); } loadGemmConfig(config_file); } @@ -38,7 +38,7 @@ Gemm::Gemm(IAllocator* allocator, cudaStream_t stream, std::string config_file) Gemm::~Gemm() { if (allocator_ != nullptr) { - allocator_->free(workspace_); + allocator_->free((void**)(&workspace_)); allocator_ = nullptr; } cublasLtDestroy(cublaslt_handle_); @@ -49,9 +49,9 @@ Gemm::~Gemm() 
std::string Gemm::toString() { - const char* a_type_str = a_type_ == TYPE_FP16 ? "FP16" : "FP32"; - const char* b_type_str = b_type_ == TYPE_FP16 ? "FP16" : "FP32"; - const char* c_type_str = c_type_ == TYPE_FP16 ? "FP16" : "FP32"; + const char* a_type_str = a_type_ == TYPE_FP16 ? "FP16" : "FP32"; + const char* b_type_str = b_type_ == TYPE_FP16 ? "FP16" : "FP32"; + const char* c_type_str = c_type_ == TYPE_FP16 ? "FP16" : "FP32"; const char* compute_type_str = compute_type_ == TYPE_FP16 ? "FP16" : "FP32"; return fmtstr( "Gemm[a_type=%s, b_type=%s, c_type=%s, compute_type=%s]", a_type_str, b_type_str, c_type_str, compute_type_str); @@ -60,11 +60,11 @@ std::string Gemm::toString() void Gemm::setAllocator(IAllocator* allocator) { if (allocator_ != nullptr && workspace_ != nullptr) { - allocator_->free(workspace_); + allocator_->free((void**)(&workspace_)); } allocator_ = allocator; if (allocator_ != nullptr) { - workspace_ = allocator_->malloc(WORKSPACE_SIZE); + workspace_ = allocator_->reMalloc(workspace_, WORKSPACE_SIZE); } } @@ -113,16 +113,16 @@ void Gemm::loadGemmConfig(std::string config_file) cublas_algo_map_ = new cublasAlgoMap(config_file); } -void Gemm::gemm(const GemmOp transa, - const GemmOp transb, - const size_t m, - const size_t n, - const size_t k, - const void* input, +void Gemm::gemm(const GemmOp transa, + const GemmOp transb, + const size_t m, + const size_t n, + const size_t k, + const void* input, const DenseWeight& weight, - void* output, - const float alpha, - const float beta) + void* output, + const float alpha, + const float beta) { gemm(transa, transb, @@ -142,16 +142,16 @@ void Gemm::gemm(const GemmOp transa, beta); } -void Gemm::gemm(const GemmOp transa, - const GemmOp transb, - const size_t m, - const size_t n, - const size_t k, - const void* input, +void Gemm::gemm(const GemmOp transa, + const GemmOp transb, + const size_t m, + const size_t n, + const size_t k, + const void* input, const DenseWeight& weight, - void* output, - const float alpha, - const float beta) + void* output, + const float alpha, + const float beta) { gemm(transa, transb, @@ -176,11 +176,11 @@ void Gemm::gemm(const GemmOp transa, const size_t m, const size_t n, const size_t k, - const void* A, - const void* B, - void* C, - const float alpha, - const float beta) + const void* A, + const void* B, + void* C, + const float alpha, + const float beta) { size_t lda = (transa == GEMM_OP_N) ? k : m; size_t ldb = (transb == GEMM_OP_N) ? 
n : k; @@ -193,34 +193,34 @@ void Gemm::gemm(const GemmOp transa, const size_t m, const size_t n, const size_t k, - const void* A, + const void* A, const size_t lda, - const void* B, + const void* B, const size_t ldb, - void* C, + void* C, const size_t ldc, - const float alpha, - const float beta) + const float alpha, + const float beta) { gemm(transa, transb, m, n, k, A, a_type_, lda, B, b_type_, ldb, C, c_type_, ldc, alpha, beta); } -void Gemm::gemm(const GemmOp transa, - const GemmOp transb, - const size_t m, - const size_t n, - const size_t k, - const void* A, +void Gemm::gemm(const GemmOp transa, + const GemmOp transb, + const size_t m, + const size_t n, + const size_t k, + const void* A, const DataType Atype, - const size_t lda, - const void* B, + const size_t lda, + const void* B, const DataType Btype, - const size_t ldb, - void* C, + const size_t ldb, + void* C, const DataType Ctype, - const size_t ldc, - const float alpha, - const float beta) + const size_t ldc, + const float alpha, + const float beta) { FT_LOG_TRACE("Gemm::gemm [m=%ld, n=%ld, k=%ld, lda=%ld, ldb=%ld, ldc=%ld]", m, n, k, lda, ldb, ldc); @@ -248,11 +248,11 @@ void Gemm::gemm(const GemmOp transa, mutex_->lock(); // Use cublas as default in FP32 and cublasLt as default in FP16 bool is_fp16_compute_type = compute_type_ == TYPE_FP16; - bool using_cublasLt = Atype == TYPE_FP16; - int batch_count = 1; + bool using_cublasLt = Atype == TYPE_FP16; + int batch_count = 1; - half h_alpha = (half)alpha; - half h_beta = (half)beta; + half h_alpha = (half)alpha; + half h_beta = (half)beta; const void* alpha_ptr = is_fp16_compute_type ? reinterpret_cast(&h_alpha) : reinterpret_cast(&alpha); const void* beta_ptr = @@ -273,10 +273,10 @@ void Gemm::gemm(const GemmOp transa, const size_t b_rows = (b_op == getCublasOperation(GEMM_OP_N)) ? k : _n; const size_t b_cols = (b_op == getCublasOperation(GEMM_OP_N)) ? _n : k; - cublasLtMatmulDesc_t matmul_desc = NULL; + cublasLtMatmulDesc_t matmul_desc = NULL; cublasLtMatrixLayout_t a_desc = NULL, b_desc = NULL, c_desc = NULL; - cudaDataType_t scale_type = getCublasDataType(compute_type_); - auto compute_type = getCublasComputeType(compute_type_); + cudaDataType_t scale_type = getCublasDataType(compute_type_); + auto compute_type = getCublasComputeType(compute_type_); // -------------------------------------- // Create descriptors for the original matrices @@ -293,8 +293,8 @@ void Gemm::gemm(const GemmOp transa, cublasLtMatmulDescSetAttribute(matmul_desc, CUBLASLT_MATMUL_DESC_TRANSB, &b_op, sizeof(cublasOperation_t)); cublasLtMatmulAlgo_t algo; - void* workspace = workspace_; - int workspace_size = workspace_ == nullptr ? 0 : CUBLAS_WORKSPACE_SIZE; + void* workspace = workspace_; + int workspace_size = workspace_ == nullptr ? 
0 : CUBLAS_WORKSPACE_SIZE; if (findAlgo) { if (info.workspaceSize > workspace_size) { findAlgo = 0; @@ -344,7 +344,7 @@ void Gemm::gemm(const GemmOp transa, } else { cudaDataType_t compute_type = getCublasDataType(compute_type_); - int cublas_algo = info.algoId; + int cublas_algo = info.algoId; check_cuda_error(cublasGemmEx(cublas_handle_, a_op, b_op, @@ -369,17 +369,17 @@ void Gemm::gemm(const GemmOp transa, mutex_->unlock(); } -void Gemm::batchedGemm(const GemmOp transa, - const GemmOp transb, - const size_t m, - const size_t n, - const size_t k, +void Gemm::batchedGemm(const GemmOp transa, + const GemmOp transb, + const size_t m, + const size_t n, + const size_t k, const void* const* A, const void* const* B, - void* const* C, - const size_t batch_size, - const float alpha, - const float beta) + void* const* C, + const size_t batch_size, + const float alpha, + const float beta) { size_t lda = (transa == GEMM_OP_N) ? k : m; size_t ldb = (transb == GEMM_OP_N) ? n : k; @@ -387,41 +387,41 @@ void Gemm::batchedGemm(const GemmOp transa, batchedGemm(transa, transb, m, n, k, A, a_type_, lda, B, b_type_, ldb, C, c_type_, ldc, batch_size, alpha, beta); } -void Gemm::batchedGemm(const GemmOp transa, - const GemmOp transb, - const size_t m, - const size_t n, - const size_t k, +void Gemm::batchedGemm(const GemmOp transa, + const GemmOp transb, + const size_t m, + const size_t n, + const size_t k, const void* const* A, - const size_t lda, + const size_t lda, const void* const* B, - const size_t ldb, - void* const* C, - const size_t ldc, - const size_t batch_size, - const float alpha, - const float beta) + const size_t ldb, + void* const* C, + const size_t ldc, + const size_t batch_size, + const float alpha, + const float beta) { batchedGemm(transa, transb, m, n, k, A, a_type_, lda, B, b_type_, ldb, C, c_type_, ldc, batch_size, alpha, beta); } -void Gemm::batchedGemm(const GemmOp transa, - const GemmOp transb, - const size_t m, - const size_t n, - const size_t k, +void Gemm::batchedGemm(const GemmOp transa, + const GemmOp transb, + const size_t m, + const size_t n, + const size_t k, const void* const* A, - const DataType Atype, - const size_t lda, + const DataType Atype, + const size_t lda, const void* const* B, - const DataType Btype, - const size_t ldb, - void* const* C, - const DataType Ctype, - const size_t ldc, - const size_t batch_size, - const float alpha, - const float beta) + const DataType Btype, + const size_t ldb, + void* const* C, + const DataType Ctype, + const size_t ldc, + const size_t batch_size, + const float alpha, + const float beta) { FT_LOG_TRACE( "Gemm::batchedGemm [b=%ld m=%ld, n=%ld, k=%ld, lda=%ld, ldb=%ld, ldc=%ld]", batch_size, m, n, k, lda, ldb, ldc); @@ -438,16 +438,16 @@ void Gemm::batchedGemm(const GemmOp transa, cudaDataType_t c_type = getCublasDataType(Ctype); // swap m and n, lda and ldb - const size_t _m = n; - const size_t _n = m; + const size_t _m = n; + const size_t _n = m; const size_t _lda = ldb; const size_t _ldb = lda; half h_alpha = (half)alpha; - half h_beta = (half)beta; + half h_beta = (half)beta; mutex_->lock(); - bool is_fp16_compute_type = compute_type_ == TYPE_FP16; + bool is_fp16_compute_type = compute_type_ == TYPE_FP16; const void* alpha_ptr = is_fp16_compute_type ? 
reinterpret_cast(&h_alpha) : reinterpret_cast(&alpha); const void* beta_ptr = @@ -478,21 +478,21 @@ void Gemm::batchedGemm(const GemmOp transa, mutex_->unlock(); } -void Gemm::stridedBatchedGemm(GemmOp transa, - GemmOp transb, +void Gemm::stridedBatchedGemm(GemmOp transa, + GemmOp transb, const size_t m, const size_t n, const size_t k, - const void* A, - const void* B, - void* C, + const void* A, + const void* B, + void* C, const size_t batch_size, - const float alpha, - const float beta) + const float alpha, + const float beta) { - size_t lda = (transa == GEMM_OP_N) ? k : m; - size_t ldb = (transb == GEMM_OP_N) ? n : k; - size_t ldc = n; + size_t lda = (transa == GEMM_OP_N) ? k : m; + size_t ldb = (transb == GEMM_OP_N) ? n : k; + size_t ldc = n; int64_t stridea = m * k; int64_t strideb = k * n; int64_t stridec = m * n; @@ -520,20 +520,20 @@ void Gemm::stridedBatchedGemm(GemmOp transa, beta); } -void Gemm::stridedBatchedGemm(GemmOp transa, - GemmOp transb, - const size_t m, - const size_t n, - const size_t k, - const void* A, +void Gemm::stridedBatchedGemm(GemmOp transa, + GemmOp transb, + const size_t m, + const size_t n, + const size_t k, + const void* A, const int64_t strideA, - const void* B, + const void* B, const int64_t strideB, - void* C, + void* C, const int64_t strideC, - const size_t batch_size, - const float alpha, - const float beta) + const size_t batch_size, + const float alpha, + const float beta) { size_t lda = (transa == GEMM_OP_N) ? k : m; size_t ldb = (transb == GEMM_OP_N) ? n : k; @@ -561,23 +561,23 @@ void Gemm::stridedBatchedGemm(GemmOp transa, beta); } -void Gemm::stridedBatchedGemm(GemmOp transa, - GemmOp transb, - const size_t m, - const size_t n, - const size_t k, - const void* A, - const size_t lda, +void Gemm::stridedBatchedGemm(GemmOp transa, + GemmOp transb, + const size_t m, + const size_t n, + const size_t k, + const void* A, + const size_t lda, const int64_t strideA, - const void* B, - const size_t ldb, + const void* B, + const size_t ldb, const int64_t strideB, - void* C, - const size_t ldc, + void* C, + const size_t ldc, const int64_t strideC, - const size_t batch_size, - const float alpha, - const float beta) + const size_t batch_size, + const float alpha, + const float beta) { stridedBatchedGemm(transa, transb, @@ -602,27 +602,27 @@ void Gemm::stridedBatchedGemm(GemmOp transa, beta); } -void Gemm::stridedBatchedGemm(GemmOp transa, - GemmOp transb, - const size_t m, - const size_t n, - const size_t k, - const void* A, - DataType Atype, - const size_t lda, +void Gemm::stridedBatchedGemm(GemmOp transa, + GemmOp transb, + const size_t m, + const size_t n, + const size_t k, + const void* A, + DataType Atype, + const size_t lda, const int64_t strideA, - const void* B, - DataType Btype, - const size_t ldb, + const void* B, + DataType Btype, + const size_t ldb, const int64_t strideB, - void* C, - DataType Ctype, - const size_t ldc, + void* C, + DataType Ctype, + const size_t ldc, const int64_t strideC, - const size_t batch_size, - DataType compute_type, - const float alpha, - const float beta) + const size_t batch_size, + DataType compute_type, + const float alpha, + const float beta) { FT_LOG_TRACE("Gemm::stridedBatchedGemm [b=%ld, m=%ld, n=%ld, k=%ld, lda=%ld, ldb=%ld, ldc=%ld]", batch_size, @@ -645,18 +645,18 @@ void Gemm::stridedBatchedGemm(GemmOp transa, cudaDataType_t c_type = getCublasDataType(Ctype); // swap m and n, lda and ldb, stride A and B - const size_t _m = n; - const size_t _n = m; - const size_t _lda = ldb; - const size_t _ldb = lda; + const 
size_t _m = n; + const size_t _n = m; + const size_t _lda = ldb; + const size_t _ldb = lda; const int64_t _stridea = strideB; const int64_t _strideb = strideA; half h_alpha = (half)alpha; - half h_beta = (half)beta; + half h_beta = (half)beta; mutex_->lock(); - bool is_fp16_compute_type = compute_type_ == TYPE_FP16; + bool is_fp16_compute_type = compute_type_ == TYPE_FP16; const void* alpha_ptr = is_fp16_compute_type ? reinterpret_cast(&h_alpha) : reinterpret_cast(&alpha); const void* beta_ptr = @@ -721,9 +721,9 @@ SpGemm::SpGemm(IAllocator* allocator, cudaStream_t stream, std::string config_fi // allowing us to inherit Gemm's constructor. // cublas_algo_map_.loadSpGemmConfig(spconfig_file); // enable this line later. - a_type_ = TYPE_FP16; - b_type_ = TYPE_FP16; - c_type_ = TYPE_FP16; + a_type_ = TYPE_FP16; + b_type_ = TYPE_FP16; + c_type_ = TYPE_FP16; compute_type_ = TYPE_FP16; } @@ -744,9 +744,9 @@ SpGemm::~SpGemm() std::string SpGemm::toString() { - const char* a_type_str = a_type_ == TYPE_FP16 ? "FP16" : "FP32"; - const char* b_type_str = b_type_ == TYPE_FP16 ? "FP16" : "FP32"; - const char* c_type_str = c_type_ == TYPE_FP16 ? "FP16" : "FP32"; + const char* a_type_str = a_type_ == TYPE_FP16 ? "FP16" : "FP32"; + const char* b_type_str = b_type_ == TYPE_FP16 ? "FP16" : "FP32"; + const char* c_type_str = c_type_ == TYPE_FP16 ? "FP16" : "FP32"; const char* compute_type_str = compute_type_ == TYPE_FP16 ? "FP16" : "FP32"; return fmtstr("SpGemm[a_type=%s, b_type=%s, c_type=%s, compute_type=%s]", a_type_str, @@ -777,16 +777,16 @@ bool SpGemm::useBaseGemm(size_t batch_size, size_t m, size_t n, size_t k) // Temporal gemm helper mtehod to use template T. template -void SpGemm::weightGemmHelper(const GemmOp transa, - const GemmOp transb, - const size_t m, - const size_t n, - const size_t k, - const void* input, +void SpGemm::weightGemmHelper(const GemmOp transa, + const GemmOp transb, + const size_t m, + const size_t n, + const size_t k, + const void* input, const DenseWeight& weight, - void* output, - const float alpha, - const float beta) + void* output, + const float alpha, + const float beta) { size_t lda = (transa == GEMM_OP_N) ? k : m; size_t ldb = (transb == GEMM_OP_N) ? 
n : k; @@ -829,49 +829,49 @@ void SpGemm::weightGemmHelper(const GemmOp transa, } } -void SpGemm::gemm(const GemmOp transa, - const GemmOp transb, - const size_t m, - const size_t n, - const size_t k, - const void* input, +void SpGemm::gemm(const GemmOp transa, + const GemmOp transb, + const size_t m, + const size_t n, + const size_t k, + const void* input, const DenseWeight& weight, - void* output, - const float alpha, - const float beta) + void* output, + const float alpha, + const float beta) { weightGemmHelper(transa, transb, m, n, k, input, weight, output, alpha, beta); } -void SpGemm::gemm(const GemmOp transa, - const GemmOp transb, - const size_t m, - const size_t n, - const size_t k, - const void* input, +void SpGemm::gemm(const GemmOp transa, + const GemmOp transb, + const size_t m, + const size_t n, + const size_t k, + const void* input, const DenseWeight& weight, - void* output, - const float alpha, - const float beta) + void* output, + const float alpha, + const float beta) { weightGemmHelper(transa, transb, m, n, k, input, weight, output, alpha, beta); } -void SpGemm::gemm(const GemmOp transa, - const GemmOp transb, - const size_t m, - const size_t n, - const size_t k, - const void* A, +void SpGemm::gemm(const GemmOp transa, + const GemmOp transb, + const size_t m, + const size_t n, + const size_t k, + const void* A, const DataType Atype, - const size_t lda, - const void* B, + const size_t lda, + const void* B, const DataType Btype, - const size_t ldb, - void* C, + const size_t ldb, + void* C, const DataType Ctype, - const size_t ldc, - const float alpha, - const float beta) + const size_t ldc, + const float alpha, + const float beta) { FT_LOG_TRACE("SpGemm::gemm [m=%ld, n=%ld, k=%ld, lda=%ld, ldb=%ld, ldc=%ld]", m, n, k, lda, ldb, ldc); checkDataTypeValidity(Atype); @@ -888,7 +888,7 @@ void SpGemm::gemm(const GemmOp transa, // Switch A/B due to column major layout in computation. // Typical usecase of Gemm family is to compute Y = X * W where X is an // input tensor and W is a kernel weight. Compression takes a lot time - // so only the kerenl weight (which is fixed in inference time) can be + // so only the kernel weight (which is fixed in inference time) can be // sparse. Using B as sparse seems not stable, unfortunately. // (e.g. caching matrix descriptions is not correctly working.) 
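Editor's note: the comment above (continued just below) explains why only the kernel weight, which is fixed at inference time, is made sparse. As a hedged, usage-level sketch of the weight-preparation side, the snippet below uses only the compressMatrixB helper whose full signature appears later in this diff; prepareSparseWeight and its arguments are placeholders, and the weight is assumed to already satisfy 2:4 structured sparsity (e.g. after pruneMatrixB).

    // Hedged usage sketch: build the cusparseLt-compressed copy of a k x n FP16 weight.
    // The wrapper name and arguments are illustrative; the input weight is assumed to be
    // pruned to 2:4 sparsity already.
    void* prepareSparseWeight(const void*  weight_dev,  // pruned FP16 weight, k x n, on device
                              size_t       k,
                              size_t       n,
                              IAllocator&  allocator,
                              cudaStream_t stream)
    {
        void*  compressed       = nullptr;
        size_t compressed_bytes = compressMatrixB(&compressed, allocator, stream,
                                                  weight_dev, k, n, GEMM_OP_N);
        (void)compressed_bytes;  // size in bytes of the compressed buffer, for bookkeeping
        return compressed;       // consumed by the sparse GEMM path in place of the dense kernel
    }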
// Thus, SpGemm considers a column major layout in computation to make @@ -905,8 +905,8 @@ void SpGemm::gemm(const GemmOp transa, cudaDataType_t b_type = getCublasDataType(Atype); cudaDataType_t c_type = getCublasDataType(Ctype); - const size_t _m = n; - const size_t _n = m; + const size_t _m = n; + const size_t _n = m; const size_t _lda = ldb; const size_t _ldb = lda; @@ -917,12 +917,12 @@ void SpGemm::gemm(const GemmOp transa, const size_t c_rows = _m; const size_t c_cols = _n; - const unsigned alignment = 16; + const unsigned alignment = 16; cusparseComputeType compute_type = getCusparseComputeType(compute_type_); - cusparseLtMatmulDescriptor_t matmul; + cusparseLtMatmulDescriptor_t matmul; cusparseLtMatmulAlgSelection_t alg_sel; - cusparseLtMatmulPlan_t plan; + cusparseLtMatmulPlan_t plan; char mark[256]; sprintf(mark, "%d_%ld_%ld_%ld_%s_%s", 1, m, n, k, getGemmOpString(transb).c_str(), getGemmOpString(transa).c_str()); @@ -977,9 +977,9 @@ void SpGemm::gemm(const GemmOp transa, CHECK_CUSPARSE(cusparseLtMatmulGetWorkspace(&cusparselt_handle_, &alg_sel, &workspace_size)); CHECK_CUSPARSE(cusparseLtMatmulPlanInit(&cusparselt_handle_, &plan, &matmul, &alg_sel, workspace_size)); - void* d_workspace = nullptr; // Can we use the workspace of the class? - int num_streams = 1; - cudaStream_t streams[1] = {stream_}; + void* d_workspace = nullptr; // Can we use the workspace of the class? + int num_streams = 1; + cudaStream_t streams[1] = {stream_}; CHECK_CUSPARSE(cusparseLtMatmul( &cusparselt_handle_, &plan, &alpha, a_data, b_data, &beta, C, C, d_workspace, streams, num_streams)) CHECK_CUSPARSE(cusparseLtMatmulPlanDestroy(&plan)) @@ -1108,15 +1108,15 @@ void pruneMatrixB(void* data, const cudaStream_t& stream, const size_t k, const FT_LOG_TRACE("Prune matrix B [k=%ld, n=%ld, op=%s]", k, n, getGemmOpString(trans).c_str()); // Due to A/B switching, the matrix B will be used as a matrix A. - const cusparseOrder_t order = CUSPARSE_ORDER_COL; - const size_t rows = (trans == GEMM_OP_N) ? n : k; - const size_t cols = (trans == GEMM_OP_N) ? k : n; - const size_t ld = rows; - const unsigned alignment = 16; + const cusparseOrder_t order = CUSPARSE_ORDER_COL; + const size_t rows = (trans == GEMM_OP_N) ? n : k; + const size_t cols = (trans == GEMM_OP_N) ? k : n; + const size_t ld = rows; + const unsigned alignment = 16; const cusparseLtPruneAlg_t prune_alg = CUSPARSELT_PRUNE_SPMMA_STRIP; - const cusparseOperation_t op = getCusparseOperation(trans); - const cudaDataType_t dtype = CUDA_R_16F; // fixed under cusparselt == 0.2.0. + const cusparseOperation_t op = getCusparseOperation(trans); + const cudaDataType_t dtype = CUDA_R_16F; // fixed under cusparselt == 0.2.0. // 0: B is sparse, 1: A is sparse // B matrix will be used as A matrix at the SpGemm::gemm. @@ -1133,28 +1133,28 @@ void pruneMatrixB(void* data, const cudaStream_t& stream, const size_t k, const CHECK_CUSPARSE(cusparseLtDestroy(&handle)); } -size_t compressMatrixB(void** output, - IAllocator& allocator, +size_t compressMatrixB(void** output, + IAllocator& allocator, const cudaStream_t& stream, - const void* input, - const size_t k, - const size_t n, - const GemmOp trans) + const void* input, + const size_t k, + const size_t n, + const GemmOp trans) { FT_LOG_TRACE("compressMatrix [k=%ld, n=%ld, dtype=FP16]", k, n); // swap A/B due to column/row major layout mismatch. cusparseOrder_t order = CUSPARSE_ORDER_COL; - const size_t rows = (trans == GEMM_OP_N) ? n : k; - const size_t cols = (trans == GEMM_OP_N) ? 
k : n; - const size_t ld = rows; + const size_t rows = (trans == GEMM_OP_N) ? n : k; + const size_t cols = (trans == GEMM_OP_N) ? k : n; + const size_t ld = rows; - cudaDataType_t dtype = CUDA_R_16F; // fixed under cusparselt == 0.2.0. - cusparseLtSparsity_t sparsity = CUSPARSELT_SPARSITY_50_PERCENT; - cusparseOperation_t op = getCusparseOperation(trans); + cudaDataType_t dtype = CUDA_R_16F; // fixed under cusparselt == 0.2.0. + cusparseLtSparsity_t sparsity = CUSPARSELT_SPARSITY_50_PERCENT; + cusparseOperation_t op = getCusparseOperation(trans); cusparseLtMatDescriptor_t mat_desc; - const unsigned alignment = 16; - const int is_sparse_a = 1; // 0: B is sparse, 1: A is sparse + const unsigned alignment = 16; + const int is_sparse_a = 1; // 0: B is sparse, 1: A is sparse cusparseLtHandle_t handle; CHECK_CUSPARSE(cusparseLtInit(&handle)); diff --git a/src/fastertransformer/utils/gemm.h b/src/fastertransformer/utils/gemm.h index ef2f49929..15de78cd9 100644 --- a/src/fastertransformer/utils/gemm.h +++ b/src/fastertransformer/utils/gemm.h @@ -165,51 +165,51 @@ class Gemm { // We temperally add an interface here for two cases float/half, // but to finialze this function, we need an interface of a weight class // which is not a template class. - virtual void gemm(const GemmOp transa, - const GemmOp transb, - const size_t m, - const size_t n, - const size_t k, - const void* input, + virtual void gemm(const GemmOp transa, + const GemmOp transb, + const size_t m, + const size_t n, + const size_t k, + const void* input, const DenseWeight& weight, - void* output, - const float alpha = 1.0f, - const float beta = 0.0f); - virtual void gemm(const GemmOp transa, - const GemmOp transb, - const size_t m, - const size_t n, - const size_t k, - const void* input, + void* output, + const float alpha = 1.0f, + const float beta = 0.0f); + virtual void gemm(const GemmOp transa, + const GemmOp transb, + const size_t m, + const size_t n, + const size_t k, + const void* input, const DenseWeight& weight, - void* output, - const float alpha = 1.0f, - const float beta = 0.0f); + void* output, + const float alpha = 1.0f, + const float beta = 0.0f); virtual void gemm(const GemmOp transa, const GemmOp transb, const size_t m, const size_t n, const size_t k, - const void* A, - const void* B, - void* C, - const float alpha = 1.0f, - const float beta = 0.0f); + const void* A, + const void* B, + void* C, + const float alpha = 1.0f, + const float beta = 0.0f); virtual void gemm(const GemmOp transa, const GemmOp transb, const size_t m, const size_t n, const size_t k, - const void* A, + const void* A, const size_t lda, - const void* B, + const void* B, const size_t ldb, - void* C, + void* C, const size_t ldc, - const float alpha = 1.0f, - const float beta = 0.0f); + const float alpha = 1.0f, + const float beta = 0.0f); /** * @brief Compute the matrix multiplication `C = \alpha * op(A) * op(B) + \beta * C`. * @@ -233,49 +233,49 @@ class Gemm { * @throw GemmNotSupportedException if a type is not TYPE_FP16 or TYPE_FP32. * @throw std::runtime_error if any exception inside CUDA. 
*/ - virtual void gemm(const GemmOp transa, - const GemmOp transb, - const size_t m, - const size_t n, - const size_t k, - const void* A, + virtual void gemm(const GemmOp transa, + const GemmOp transb, + const size_t m, + const size_t n, + const size_t k, + const void* A, const DataType Atype, - const size_t lda, - const void* B, + const size_t lda, + const void* B, const DataType Btype, - const size_t ldb, - void* C, + const size_t ldb, + void* C, const DataType Ctype, - const size_t ldc, - const float alpha = 1.0f, - const float beta = 0.0f); - - virtual void batchedGemm(const GemmOp transa, - const GemmOp transb, - const size_t m, - const size_t n, - const size_t k, + const size_t ldc, + const float alpha = 1.0f, + const float beta = 0.0f); + + virtual void batchedGemm(const GemmOp transa, + const GemmOp transb, + const size_t m, + const size_t n, + const size_t k, const void* const* A, const void* const* B, - void* const* C, - const size_t batch_size, - const float alpha = 1.0f, - const float beta = 0.0f); - - virtual void batchedGemm(const GemmOp transa, - const GemmOp transb, - const size_t m, - const size_t n, - const size_t k, + void* const* C, + const size_t batch_size, + const float alpha = 1.0f, + const float beta = 0.0f); + + virtual void batchedGemm(const GemmOp transa, + const GemmOp transb, + const size_t m, + const size_t n, + const size_t k, const void* const* A, - const size_t lda, + const size_t lda, const void* const* B, - const size_t ldb, - void* const* C, - const size_t ldc, - const size_t batch_size, - const float alpha = 1.0f, - const float beta = 0.0f); + const size_t ldb, + void* const* C, + const size_t ldc, + const size_t batch_size, + const float alpha = 1.0f, + const float beta = 0.0f); /** * @brief Compute the matrix multiplication of batch of matrices As and Bs @@ -303,68 +303,68 @@ class Gemm { * @throw GemmNotSupportedException if a type is not TYPE_FP16 or TYPE_FP32. * @throw std::runtime_error if any exception inside CUDA. 
*/ - virtual void batchedGemm(const GemmOp transa, - const GemmOp transb, - const size_t m, - const size_t n, - const size_t k, + virtual void batchedGemm(const GemmOp transa, + const GemmOp transb, + const size_t m, + const size_t n, + const size_t k, const void* const* A, - const DataType Atype, - const size_t lda, + const DataType Atype, + const size_t lda, const void* const* B, - const DataType Btype, - const size_t ldb, - void* const* C, - const DataType Ctype, - const size_t ldc, - const size_t batch_size, - const float alpha = 1.0f, - const float beta = 0.0f); - - virtual void stridedBatchedGemm(GemmOp transa, - GemmOp transb, + const DataType Btype, + const size_t ldb, + void* const* C, + const DataType Ctype, + const size_t ldc, + const size_t batch_size, + const float alpha = 1.0f, + const float beta = 0.0f); + + virtual void stridedBatchedGemm(GemmOp transa, + GemmOp transb, const size_t m, const size_t n, const size_t k, - const void* A, - const void* B, - void* C, + const void* A, + const void* B, + void* C, const size_t batch_size, - const float alpha = 1.0f, - const float beta = 0.0f); - - virtual void stridedBatchedGemm(GemmOp transa, - GemmOp transb, - const size_t m, - const size_t n, - const size_t k, - const void* A, + const float alpha = 1.0f, + const float beta = 0.0f); + + virtual void stridedBatchedGemm(GemmOp transa, + GemmOp transb, + const size_t m, + const size_t n, + const size_t k, + const void* A, const int64_t strideA, - const void* B, + const void* B, const int64_t strideB, - void* C, + void* C, const int64_t strideC, - const size_t batch_size, - const float alpha = 1.0f, - const float beta = 0.0f); - - virtual void stridedBatchedGemm(GemmOp transa, - GemmOp transb, - const size_t m, - const size_t n, - const size_t k, - const void* A, - const size_t lda, + const size_t batch_size, + const float alpha = 1.0f, + const float beta = 0.0f); + + virtual void stridedBatchedGemm(GemmOp transa, + GemmOp transb, + const size_t m, + const size_t n, + const size_t k, + const void* A, + const size_t lda, const int64_t strideA, - const void* B, - const size_t ldb, + const void* B, + const size_t ldb, const int64_t strideB, - void* C, - const size_t ldc, + void* C, + const size_t ldc, const int64_t strideC, - const size_t batch_size, - const float alpha = 1.0f, - const float beta = 0.0f); + const size_t batch_size, + const float alpha = 1.0f, + const float beta = 0.0f); /** * @brief Compute the strided matrix multiplication of batch of matrices As and Bs * @@ -395,42 +395,42 @@ class Gemm { * @throw GemmNotSupportedException if a type is not TYPE_FP16 or TYPE_FP32. * @throw std::runtime_error if any exception inside CUDA. 
*/ - virtual void stridedBatchedGemm(GemmOp transa, - GemmOp transb, - const size_t m, - const size_t n, - const size_t k, - const void* A, - DataType Atype, - const size_t lda, + virtual void stridedBatchedGemm(GemmOp transa, + GemmOp transb, + const size_t m, + const size_t n, + const size_t k, + const void* A, + DataType Atype, + const size_t lda, const int64_t strideA, - const void* B, - DataType Btype, - const size_t ldb, + const void* B, + DataType Btype, + const size_t ldb, const int64_t strideB, - void* C, - DataType Ctype, - const size_t ldc, + void* C, + DataType Ctype, + const size_t ldc, const int64_t strideC, - const size_t batch_size, - DataType compute_type, - const float alpha = 1.0f, - const float beta = 0.0f); + const size_t batch_size, + DataType compute_type, + const float alpha = 1.0f, + const float beta = 0.0f); protected: - IAllocator* allocator_ = nullptr; - cudaStream_t stream_; - std::mutex* mutex_ = nullptr; + IAllocator* allocator_ = nullptr; + cudaStream_t stream_; + std::mutex* mutex_ = nullptr; cublasAlgoMap* cublas_algo_map_ = nullptr; - cublasHandle_t cublas_handle_; + cublasHandle_t cublas_handle_; cublasLtHandle_t cublaslt_handle_; - void* workspace_ = nullptr; + void* workspace_ = nullptr; // use FP32 as default - DataType a_type_ = TYPE_FP32; - DataType b_type_ = TYPE_FP32; - DataType c_type_ = TYPE_FP32; + DataType a_type_ = TYPE_FP32; + DataType b_type_ = TYPE_FP32; + DataType c_type_ = TYPE_FP32; DataType compute_type_ = TYPE_FP32; // Check if data and inputs are valid in the Gemm class. @@ -457,11 +457,11 @@ class Gemm { class SpGemm: public Gemm { protected: - cusparseLtHandle_t cusparselt_handle_; + cusparseLtHandle_t cusparselt_handle_; std::map a_desc_map_; std::map b_desc_map_; std::map c_desc_map_; - bool useBaseGemm(size_t batch_size, size_t m, size_t n, size_t k); + bool useBaseGemm(size_t batch_size, size_t m, size_t n, size_t k); public: using Gemm::setComputeType; @@ -479,68 +479,68 @@ class SpGemm: public Gemm { * @param config_file A file path of a GEMM configuration. */ // TODO: Let's unify algo map loading part. - SpGemm(IAllocator* allocator, + SpGemm(IAllocator* allocator, cudaStream_t stream, - std::string config_file = GEMM_CONFIG, - std::string spconfig_file = SPGEMM_CONFIG); + std::string config_file = GEMM_CONFIG, + std::string spconfig_file = SPGEMM_CONFIG); ~SpGemm(); std::string toString() override; - void loadGemmConfig(std::string config_file, std::string spconfig_file); + void loadGemmConfig(std::string config_file, std::string spconfig_file); // Template method cannot be overrided. 
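Editor's note: since the header above keeps FP32 as the default data type for A, B, and C, a short hedged usage sketch of the simplest raw-pointer overload declared earlier in this hunk may help; runPlainGemm, the caller-supplied pointers, and the sizes are illustrative only.

    // Hedged usage sketch of Gemm's raw-pointer overload with the FP32 defaults noted
    // above: C = A * B, where A is m x k and B is k x n (row-major, no transposes).
    // Error handling is omitted for brevity.
    void runPlainGemm(IAllocator*  allocator,
                      cudaStream_t stream,
                      const float* A_dev,
                      const float* B_dev,
                      float*       C_dev,
                      size_t       m,
                      size_t       n,
                      size_t       k)
    {
        Gemm gemm(allocator, stream, GEMM_CONFIG);  // picks up tuned algos from gemm_config if present
        // Defaults alpha = 1.0f, beta = 0.0f apply, matching the declaration above.
        gemm.gemm(GEMM_OP_N, GEMM_OP_N, m, n, k, A_dev, B_dev, C_dev);
    }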
- void gemm(const GemmOp transa, - const GemmOp transb, - const size_t m, - const size_t n, - const size_t k, - const void* input, + void gemm(const GemmOp transa, + const GemmOp transb, + const size_t m, + const size_t n, + const size_t k, + const void* input, const DenseWeight& weight, - void* output, - const float alpha = 1.0f, - const float beta = 0.0f) override; - void gemm(const GemmOp transa, - const GemmOp transb, - const size_t m, - const size_t n, - const size_t k, - const void* input, + void* output, + const float alpha = 1.0f, + const float beta = 0.0f) override; + void gemm(const GemmOp transa, + const GemmOp transb, + const size_t m, + const size_t n, + const size_t k, + const void* input, const DenseWeight& weight, - void* output, - const float alpha = 1.0f, - const float beta = 0.0f) override; - - void gemm(const GemmOp transa, - const GemmOp transb, - const size_t m, - const size_t n, - const size_t k, - const void* A, + void* output, + const float alpha = 1.0f, + const float beta = 0.0f) override; + + void gemm(const GemmOp transa, + const GemmOp transb, + const size_t m, + const size_t n, + const size_t k, + const void* A, const DataType Atype, - const size_t lda, - const void* B, + const size_t lda, + const void* B, const DataType Btype, - const size_t ldb, - void* C, + const size_t ldb, + void* C, const DataType Ctype, - const size_t ldc, - const float alpha = 1.0f, - const float beta = 0.0f) override; + const size_t ldc, + const float alpha = 1.0f, + const float beta = 0.0f) override; private: void checkDataTypeValidity(const DataType& type) override; // Temporal gemm helper mtehod to use template T. template - void weightGemmHelper(const GemmOp transa, - const GemmOp transb, - const size_t m, - const size_t n, - const size_t k, - const void* input, + void weightGemmHelper(const GemmOp transa, + const GemmOp transb, + const size_t m, + const size_t n, + const size_t k, + const void* input, const DenseWeight& weight, - void* output, - const float alpha, - const float beta); + void* output, + const float alpha, + const float beta); }; // class Int8SpGemm : public Int8Gemm, public SpGemm { @@ -627,7 +627,7 @@ cublasComputeType_t getCublasComputeType(DataType dtype); cudaDataType_t getCublasComputeType(DataType dtype); #endif cublasOperation_t getCublasOperation(GemmOp op); -std::string getGemmOpString(const GemmOp& op); +std::string getGemmOpString(const GemmOp& op); #ifdef SPARSITY_ENABLED cusparseOperation_t getCusparseOperation(GemmOp op); @@ -665,13 +665,13 @@ void pruneMatrixB( * or if fail to compute a correct buffer size to store the compressed matrix. * @throw std::runtime_error if any exception inside CUDA. 
*/ -size_t compressMatrixB(void** output, - IAllocator& allocator, +size_t compressMatrixB(void** output, + IAllocator& allocator, const cudaStream_t& stream, - const void* input, - const size_t k, - const size_t n, - const GemmOp trans = GEMM_OP_N); + const void* input, + const size_t k, + const size_t n, + const GemmOp trans = GEMM_OP_N); #endif diff --git a/src/fastertransformer/utils/gemm_test/decoding_gemm_func.cc b/src/fastertransformer/utils/gemm_test/decoding_gemm_func.cc index 5962f2ba3..4d4f8b86a 100644 --- a/src/fastertransformer/utils/gemm_test/decoding_gemm_func.cc +++ b/src/fastertransformer/utils/gemm_test/decoding_gemm_func.cc @@ -19,20 +19,20 @@ namespace fastertransformer { template -void generate_decoding_gemm_config(int batch_size, - int beam_width, - int max_mem_seq_len, - int head_num, - int size_per_head, - int inter_size, - int vocab_size, - int mem_hidden_units, +void generate_decoding_gemm_config(int batch_size, + int beam_width, + int max_mem_seq_len, + int head_num, + int size_per_head, + int inter_size, + int vocab_size, + int mem_hidden_units, void* buffer_in, - bool isAppend) + bool isAppend) { void* cublas_workspace; void* buffer; - int workSpaceSize; + int workSpaceSize; #ifdef ENABLE_BF16 if (std::is_same::value || std::is_same::value) { @@ -42,13 +42,13 @@ void generate_decoding_gemm_config(int batch_size, // cublas_workspace_ should be the start pointer of cudaMalloc() // to ensure 16B alignemnet cublas_workspace = buffer_in; - buffer = (void*)((char*)cublas_workspace + CUBLAS_WORKSPACE_SIZE); - workSpaceSize = CUBLAS_WORKSPACE_SIZE; + buffer = (void*)((char*)cublas_workspace + CUBLAS_WORKSPACE_SIZE); + workSpaceSize = CUBLAS_WORKSPACE_SIZE; } else { cublas_workspace = nullptr; - buffer = buffer_in; - workSpaceSize = 0; + buffer = buffer_in; + workSpaceSize = 0; } struct cudaDeviceProp prop; @@ -57,14 +57,14 @@ void generate_decoding_gemm_config(int batch_size, // check config FILE* fd; - int line_count = 0; + int line_count = 0; if (!isAppend) { fd = fopen(GEMM_CONFIG, "w+"); } else { fd = fopen(GEMM_CONFIG, "a+"); std::vector config; - char line[1024]; + char line[1024]; while (fgets(line, 1024, fd) != NULL) { config.push_back(std::string(line)); } @@ -83,12 +83,12 @@ void generate_decoding_gemm_config(int batch_size, } const int hidden_units = head_num * size_per_head; - const int gemm_num = 6; - int M[gemm_num]; - int N[gemm_num]; - int K[gemm_num]; - int batchCount[gemm_num] = {1, 1, 1, 1, 1, 1}; - char mess[gemm_num][256]; + const int gemm_num = 6; + int M[gemm_num]; + int N[gemm_num]; + int K[gemm_num]; + int batchCount[gemm_num] = {1, 1, 1, 1, 1, 1}; + char mess[gemm_num][256]; // gemm 0 M[0] = batch_size * beam_width; @@ -135,44 +135,44 @@ void generate_decoding_gemm_config(int batch_size, cudaDataType_t BType; cudaDataType_t CType; cudaDataType_t computeType; - int startAlgo, endAlgo; - const int ites = 100; + int startAlgo, endAlgo; + const int ites = 100; struct timeval start, end; CublasDataType data_type; if (std::is_same::value) { - data_type = FLOAT_DATATYPE; - AType = CUDA_R_32F; - BType = CUDA_R_32F; - CType = CUDA_R_32F; + data_type = FLOAT_DATATYPE; + AType = CUDA_R_32F; + BType = CUDA_R_32F; + CType = CUDA_R_32F; computeType = CUDA_R_32F; - startAlgo = (int)CUBLAS_GEMM_DEFAULT; - endAlgo = (int)CUBLAS_GEMM_ALGO23; + startAlgo = (int)CUBLAS_GEMM_DEFAULT; + endAlgo = (int)CUBLAS_GEMM_ALGO23; } else if (std::is_same::value) { - data_type = HALF_DATATYPE; - AType = CUDA_R_16F; - BType = CUDA_R_16F; - CType = CUDA_R_16F; + data_type = 
HALF_DATATYPE; + AType = CUDA_R_16F; + BType = CUDA_R_16F; + CType = CUDA_R_16F; computeType = CUDA_R_32F; - startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; - endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; + startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; + endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; } #ifdef ENABLE_BF16 else if (std::is_same::value) { - data_type = BFLOAT16_DATATYPE; - AType = CUDA_R_16BF; - BType = CUDA_R_16BF; - CType = CUDA_R_16BF; + data_type = BFLOAT16_DATATYPE; + AType = CUDA_R_16BF; + BType = CUDA_R_16BF; + CType = CUDA_R_16BF; computeType = CUDA_R_32F; - startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; - endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; + startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; + endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; } #endif using scaleT = typename ScaleTypeConverter::Type; scaleT alpha = (scaleT)1.0f; - scaleT beta = (scaleT)0.0f; + scaleT beta = (scaleT)0.0f; printf("***Encoder Gemm Testing Begin***\n"); printf("***Cublas Gemm Testing Begin***\n"); @@ -190,8 +190,8 @@ void generate_decoding_gemm_config(int batch_size, T* d_C = d_B + k * n * batchCount[i]; float exec_time = 99999.0f; - int fast_algo = 0; - int seq_len = i == 2 ? max_mem_seq_len : 1; + int fast_algo = 0; + int seq_len = i == 2 ? max_mem_seq_len : 1; for (int algo = startAlgo; algo <= endAlgo; algo++) { cublasStatus_t status; cudaDeviceSynchronize(); @@ -236,7 +236,7 @@ void generate_decoding_gemm_config(int batch_size, if (data_type != FLOAT_DATATYPE) { printf("***cublasLt Gemm Testing Beign***\n"); // Let try a fixed number of combinations - int ALGO_COMBINATIONS = 5000; + int ALGO_COMBINATIONS = 5000; customMatmulPerf_t perfResults[ALGO_COMBINATIONS]; LtHgemmCustomFind(ltHandle, @@ -309,61 +309,61 @@ void generate_decoding_gemm_config(int batch_size, return; } -template void generate_decoding_gemm_config(int batch_size, - int beam_width, - int seq_len, - int head_num, - int size_per_head, - int inter_size, - int vocab_size, - int mem_hidden_units, +template void generate_decoding_gemm_config(int batch_size, + int beam_width, + int seq_len, + int head_num, + int size_per_head, + int inter_size, + int vocab_size, + int mem_hidden_units, void* buffer_in, - bool isAppend); + bool isAppend); -template void generate_decoding_gemm_config(int batch_size, - int beam_width, - int seq_len, - int head_num, - int size_per_head, - int inter_size, - int vocab_size, - int mem_hidden_units, +template void generate_decoding_gemm_config(int batch_size, + int beam_width, + int seq_len, + int head_num, + int size_per_head, + int inter_size, + int vocab_size, + int mem_hidden_units, void* buffer_in, - bool isAppend); + bool isAppend); #ifdef ENABLE_BF16 -template void generate_decoding_gemm_config<__nv_bfloat16>(int batch_size, - int beam_width, - int seq_len, - int head_num, - int size_per_head, - int inter_size, - int vocab_size, - int mem_hidden_units, +template void generate_decoding_gemm_config<__nv_bfloat16>(int batch_size, + int beam_width, + int seq_len, + int head_num, + int size_per_head, + int inter_size, + int vocab_size, + int mem_hidden_units, void* buffer_in, - bool isAppend); + bool isAppend); #endif -size_t calDecodingGemmTestBufSizeInByte(int batch_size, - int beam_width, - int max_mem_seq_len, - int head_num, - int size_per_head, - int inter_size, - int memory_hidden_units, - int vocab_size, +size_t calDecodingGemmTestBufSizeInByte(int batch_size, + int beam_width, + int max_mem_seq_len, + int head_num, + int size_per_head, + int inter_size, + int memory_hidden_units, + 
int vocab_size, CublasDataType data_type) { - size_t buf_size_in_byte = 0; - const size_t tensor_para_size = 1; - const size_t hidden_units = head_num * size_per_head; - const size_t local_head_num = head_num / tensor_para_size; + size_t buf_size_in_byte = 0; + const size_t tensor_para_size = 1; + const size_t hidden_units = head_num * size_per_head; + const size_t local_head_num = head_num / tensor_para_size; const size_t local_hidden_units = local_head_num * size_per_head; // TODO need to add bfloat16 here int wordSize = (data_type == FLOAT_DATATYPE ? sizeof(float) : sizeof(half)); - size_t m = batch_size * beam_width; + size_t m = batch_size * beam_width; std::vector buff_size; // for qkv gemm buff_size.push_back(m * hidden_units + hidden_units * 3 * local_hidden_units + m * 3 * local_hidden_units); diff --git a/src/fastertransformer/utils/gemm_test/decoding_gemm_func.h b/src/fastertransformer/utils/gemm_test/decoding_gemm_func.h index 178b3d9ce..0b073c9c0 100644 --- a/src/fastertransformer/utils/gemm_test/decoding_gemm_func.h +++ b/src/fastertransformer/utils/gemm_test/decoding_gemm_func.h @@ -34,25 +34,25 @@ namespace fastertransformer { template -void generate_decoding_gemm_config(int batch_size, - int beam_width, - int seq_len, - int head_num, - int size_per_head, - int inter_size, - int vocab_size, - int mem_hidden_units, +void generate_decoding_gemm_config(int batch_size, + int beam_width, + int seq_len, + int head_num, + int size_per_head, + int inter_size, + int vocab_size, + int mem_hidden_units, void* buffer_in, - bool isAppend); + bool isAppend); -size_t calDecodingGemmTestBufSizeInByte(int batch_size, - int beam_width, - int max_mem_seq_len, - int head_num, - int size_per_head, - int inter_size, - int memory_hidden_units, - int vocab_size, +size_t calDecodingGemmTestBufSizeInByte(int batch_size, + int beam_width, + int max_mem_seq_len, + int head_num, + int size_per_head, + int inter_size, + int memory_hidden_units, + int vocab_size, CublasDataType data_type); } // namespace fastertransformer diff --git a/src/fastertransformer/utils/gemm_test/encoder_gemm_func.cc b/src/fastertransformer/utils/gemm_test/encoder_gemm_func.cc index 03c6947e0..3c66d8cf5 100644 --- a/src/fastertransformer/utils/gemm_test/encoder_gemm_func.cc +++ b/src/fastertransformer/utils/gemm_test/encoder_gemm_func.cc @@ -20,11 +20,11 @@ namespace fastertransformer { template void generate_encoder_gemm_config( - int batch_size, int seq_len, int head_num, int size_per_head, void* buffer_in, bool isAppend) + int batch_size, int seq_len, int head_num, int size_per_head, void* buffer_in, bool isAppend, int tensor_para_size) { void* cublas_workspace; void* buffer; - int workSpaceSize; + int workSpaceSize; #ifdef ENABLE_BF16 if (std::is_same::value || std::is_same::value) { @@ -34,13 +34,13 @@ void generate_encoder_gemm_config( // cublas_workspace_ should be the start pointer of cudaMalloc() // to ensure 16B alignemnet cublas_workspace = buffer_in; - buffer = (void*)((char*)cublas_workspace + CUBLAS_WORKSPACE_SIZE); - workSpaceSize = CUBLAS_WORKSPACE_SIZE; + buffer = (void*)((char*)cublas_workspace + CUBLAS_WORKSPACE_SIZE); + workSpaceSize = CUBLAS_WORKSPACE_SIZE; } else { cublas_workspace = nullptr; - buffer = buffer_in; - workSpaceSize = 0; + buffer = buffer_in; + workSpaceSize = 0; } struct cudaDeviceProp prop; @@ -49,14 +49,14 @@ void generate_encoder_gemm_config( // check config FILE* fd; - int line_count = 0; + int line_count = 0; if (!isAppend) { fd = fopen(GEMM_CONFIG, "w+"); } else { fd = 
fopen(GEMM_CONFIG, "a+"); std::vector config; - char line[1024]; + char line[1024]; while (fgets(line, 1024, fd) != NULL) { config.push_back(std::string(line)); } @@ -74,49 +74,54 @@ void generate_encoder_gemm_config( } } - const int gemm_num = 6; - int M[gemm_num]; - int N[gemm_num]; - int K[gemm_num]; - int batchCount[gemm_num] = {1, 1, 1, 1, 1, 1}; - char mess[gemm_num][256]; - float exec_times[gemm_num]; + const int gemm_num = 7; + int M[gemm_num]; + int N[gemm_num]; + int K[gemm_num]; + int batchCount[gemm_num] = {1, 1, 1, 1, 1, 1, 1}; + char mess[gemm_num][256]; + float exec_times[gemm_num]; // gemm1 M[0] = batch_size * seq_len; K[0] = head_num * size_per_head; - N[0] = K[0]; - strcpy(mess[0], "from_tensor * weightQ/K/V, attr * output_kernel"); + N[0] = (head_num / tensor_para_size) * size_per_head; + strcpy(mess[0], "from_tensor * weightQ/K/V"); // gemm2 M[1] = M[0]; - K[1] = K[0]; - N[1] = 4 * N[0]; + K[1] = head_num * size_per_head; + N[1] = 4 * head_num * size_per_head / tensor_para_size; strcpy(mess[1], "attr_output * inter_kernel"); // gemm3 M[2] = M[0]; - K[2] = 4 * K[0]; - N[2] = N[0]; + K[2] = 4 * head_num * size_per_head / tensor_para_size; + N[2] = head_num * size_per_head; strcpy(mess[2], "inter_matmul * output_kernel"); - M[3] = seq_len; - N[3] = seq_len; - K[3] = size_per_head; - batchCount[3] = batch_size * head_num; + M[3] = seq_len; + N[3] = seq_len; + K[3] = size_per_head; + batchCount[3] = batch_size * (head_num / tensor_para_size); strcpy(mess[3], "attention batched Gemm1"); - M[4] = seq_len; - N[4] = size_per_head; - K[4] = seq_len; - batchCount[4] = batch_size * head_num; + M[4] = seq_len; + N[4] = size_per_head; + K[4] = seq_len; + batchCount[4] = batch_size * (head_num / tensor_para_size); strcpy(mess[4], "attention batched Gemm2"); - M[5] = batch_size * seq_len; - N[5] = head_num * size_per_head; - K[5] = N[5]; + M[5] = batch_size * seq_len; + N[5] = (head_num / tensor_para_size) * size_per_head; + K[5] = head_num * size_per_head; batchCount[5] = 3; strcpy(mess[5], "from_tensor * weight_QKV in BatchGemm"); + + M[6] = batch_size * seq_len; + K[6] = (head_num / tensor_para_size) * size_per_head; + N[6] = head_num * size_per_head; + strcpy(mess[6], "attr * output_kernel"); cublasHandle_t cublas_handle; check_cuda_error(cublasCreate(&cublas_handle)); @@ -127,44 +132,44 @@ void generate_encoder_gemm_config( cudaDataType_t BType; cudaDataType_t CType; cudaDataType_t computeType; - int startAlgo, endAlgo; - const int ites = 100; + int startAlgo, endAlgo; + const int ites = 100; struct timeval start, end; CublasDataType data_type; if (std::is_same::value) { - data_type = FLOAT_DATATYPE; - AType = CUDA_R_32F; - BType = CUDA_R_32F; - CType = CUDA_R_32F; + data_type = FLOAT_DATATYPE; + AType = CUDA_R_32F; + BType = CUDA_R_32F; + CType = CUDA_R_32F; computeType = CUDA_R_32F; - startAlgo = (int)CUBLAS_GEMM_DEFAULT; - endAlgo = (int)CUBLAS_GEMM_ALGO23; + startAlgo = (int)CUBLAS_GEMM_DEFAULT; + endAlgo = (int)CUBLAS_GEMM_ALGO23; } else if (std::is_same::value) { - data_type = HALF_DATATYPE; - AType = CUDA_R_16F; - BType = CUDA_R_16F; - CType = CUDA_R_16F; + data_type = HALF_DATATYPE; + AType = CUDA_R_16F; + BType = CUDA_R_16F; + CType = CUDA_R_16F; computeType = CUDA_R_32F; - startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; - endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; + startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; + endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; } #ifdef ENABLE_BF16 else if (std::is_same::value) { - data_type = BFLOAT16_DATATYPE; - AType = CUDA_R_16BF; - 
BType = CUDA_R_16BF; - CType = CUDA_R_16BF; + data_type = BFLOAT16_DATATYPE; + AType = CUDA_R_16BF; + BType = CUDA_R_16BF; + CType = CUDA_R_16BF; computeType = CUDA_R_32F; - startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; - endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; + startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; + endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; } #endif using scaleT = typename ScaleTypeConverter::Type; scaleT alpha = (scaleT)1.0f; - scaleT beta = (scaleT)0.0f; + scaleT beta = (scaleT)0.0f; printf("***Encoder Gemm Testing Begin***\n"); printf("***Cublas Gemm Testing Begin***\n"); @@ -185,14 +190,14 @@ void generate_encoder_gemm_config( // array of pointer for batchedGemm T* harray[12]; - harray[0] = (T*)buffer; - harray[1] = (T*)((char*)buffer + sizeof(T) * m * k); - harray[2] = (T*)((char*)buffer + 2 * sizeof(T) * m * k); - harray[4] = (T*)((char*)buffer + 3 * sizeof(T) * m * k); - harray[5] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + sizeof(T) * k * n); - harray[6] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + 2 * sizeof(T) * k * n); - harray[8] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + 3 * sizeof(T) * k * n); - harray[9] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + 3 * sizeof(T) * k * n + sizeof(T) * m * n); + harray[0] = (T*)buffer; + harray[1] = (T*)((char*)buffer + sizeof(T) * m * k); + harray[2] = (T*)((char*)buffer + 2 * sizeof(T) * m * k); + harray[4] = (T*)((char*)buffer + 3 * sizeof(T) * m * k); + harray[5] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + sizeof(T) * k * n); + harray[6] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + 2 * sizeof(T) * k * n); + harray[8] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + 3 * sizeof(T) * k * n); + harray[9] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + 3 * sizeof(T) * k * n + sizeof(T) * m * n); harray[10] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + 3 * sizeof(T) * k * n + 2 * sizeof(T) * m * n); T** darray = 0; @@ -203,7 +208,7 @@ void generate_encoder_gemm_config( T** dCarray = darray + 8; float exec_time = 99999.0f; - int fast_algo = 0; + int fast_algo = 0; for (int algo = startAlgo; algo <= endAlgo; algo++) { cublasStatus_t status; cudaDeviceSynchronize(); @@ -322,7 +327,7 @@ void generate_encoder_gemm_config( if (i < 3 && data_type != FLOAT_DATATYPE) { printf("***cublasLt Gemm Testing Beign***\n"); // Let try a fixed number of combinations - int ALGO_COMBINATIONS = 5000; + int ALGO_COMBINATIONS = 5000; customMatmulPerf_t perfResults[ALGO_COMBINATIONS]; LtHgemmCustomFind(ltHandle, batch_size, @@ -401,7 +406,7 @@ void generate_encoder_gemm_config( else { fd = fopen(SPGEMM_CONFIG, "a+"); std::vector config; - char line[1024]; + char line[1024]; while (fgets(line, 1024, fd) != NULL) { config.push_back(std::string(line)); } @@ -425,16 +430,16 @@ void generate_encoder_gemm_config( } cusparseLtHandle_t handle; CHECK_CUSPARSE(cusparseLtInit(&handle)); - cusparseOrder_t order = CUSPARSE_ORDER_COL; - cusparseOperation_t opA = CUSPARSE_OPERATION_NON_TRANSPOSE; - cusparseOperation_t opB = CUSPARSE_OPERATION_NON_TRANSPOSE; + cusparseOrder_t order = CUSPARSE_ORDER_COL; + cusparseOperation_t opA = CUSPARSE_OPERATION_NON_TRANSPOSE; + cusparseOperation_t opB = CUSPARSE_OPERATION_NON_TRANSPOSE; cusparseComputeType compute_type = CUSPARSE_COMPUTE_16F; - unsigned alignment = 16; - cudaStream_t stream = 0; - float alpha2 = 1.0f; - float beta2 = 0.0f; + unsigned alignment = 16; + cudaStream_t stream = 0; + float alpha2 = 1.0f; + float beta2 = 0.0f; for (int i = 0; i < spgemm_num; ++i) { - // to be compatable with spgemm 
wrapper, we let A be the weight matrix + // to be compatible with spgemm wrapper, we let A be the weight matrix // so m and n are swapped // A: mxk B: kxn C:mxn int m = N[i], n = M[i], k = K[i]; @@ -457,13 +462,13 @@ void generate_encoder_gemm_config( } float exec_time = 99999.0f; - int fast_algo = 0; + int fast_algo = 0; for (int alg = 0; alg < 4; ++alg) { cudaDeviceSynchronize(); cusparseLtMatDescriptor_t matA, matB, matC; - void* d_workspace = nullptr; - int num_streams = 1; - cudaStream_t streams[1] = {stream}; + void* d_workspace = nullptr; + int num_streams = 1; + cudaStream_t streams[1] = {stream}; CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit( &handle, &matA, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT)) CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &matB, k, n, k, alignment, CUDA_R_16F, order)) @@ -473,9 +478,9 @@ void generate_encoder_gemm_config( // initializing MatDesc takes a lot of time // and these descs can be stored to other place // whereas storing MatMulPlan to other place will cause errors - cusparseLtMatmulDescriptor_t matmul; + cusparseLtMatmulDescriptor_t matmul; cusparseLtMatmulAlgSelection_t alg_sel; - cusparseLtMatmulPlan_t plan; + cusparseLtMatmulPlan_t plan; CHECK_CUSPARSE(cusparseLtMatmulDescriptorInit( &handle, &matmul, opA, opB, &matA, &matB, &matC, &matC, compute_type)) CHECK_CUSPARSE( @@ -535,12 +540,12 @@ void generate_encoder_gemm_config( } template void generate_encoder_gemm_config( - int batch_size, int seq_len, int head_num, int size_per_head, void* buffer, bool isAppend); + int batch_size, int seq_len, int head_num, int size_per_head, void* buffer, bool isAppend, int tensor_para_size); template void generate_encoder_gemm_config( - int batch_size, int seq_len, int head_num, int size_per_head, void* buffer, bool isAppend); + int batch_size, int seq_len, int head_num, int size_per_head, void* buffer, bool isAppend, int tensor_para_size); #ifdef ENABLE_BF16 template void generate_encoder_gemm_config<__nv_bfloat16>( - int batch_size, int seq_len, int head_num, int size_per_head, void* buffer, bool isAppend); + int batch_size, int seq_len, int head_num, int size_per_head, void* buffer, bool isAppend, int tensor_para_size); #endif } // namespace fastertransformer diff --git a/src/fastertransformer/utils/gemm_test/encoder_gemm_func.h b/src/fastertransformer/utils/gemm_test/encoder_gemm_func.h index fd067b90b..28cb681ab 100644 --- a/src/fastertransformer/utils/gemm_test/encoder_gemm_func.h +++ b/src/fastertransformer/utils/gemm_test/encoder_gemm_func.h @@ -34,7 +34,12 @@ namespace fastertransformer { template -void generate_encoder_gemm_config( - int batch_size, int seq_len, int head_num, int size_per_head, void* buffer, bool isAppend = true); +void generate_encoder_gemm_config(int batch_size, + int seq_len, + int head_num, + int size_per_head, + void* buffer, + bool isAppend = true, + int tensor_para_size = 1); } // namespace fastertransformer diff --git a/src/fastertransformer/utils/gemm_test/encoder_igemm_func.cc b/src/fastertransformer/utils/gemm_test/encoder_igemm_func.cc index 36b3d904c..670a3156e 100644 --- a/src/fastertransformer/utils/gemm_test/encoder_igemm_func.cc +++ b/src/fastertransformer/utils/gemm_test/encoder_igemm_func.cc @@ -83,7 +83,7 @@ int printPerfStructure(int m, int n, int k, const customMatmulPerf_t& perf, FILE #if (CUDART_VERSION >= 11000) cublasLtMatmulAlgoConfigGetAttribute(matmulAlgo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &stages, sizeof(stages), NULL); #else - stages = 0; + stages = 0; #endif 
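Editor's note: the encoder GEMM test above now profiles per-GPU shapes, dividing head_num by the new tensor_para_size argument. A hedged usage sketch of the updated entry point follows; profileEncoderGemms, the concrete sizes, and the buffer sizing are placeholders, and the half instantiation is assumed to match the explicit instantiations listed above.

    // Hedged usage sketch of the updated generate_encoder_gemm_config signature with the
    // new tensor_para_size argument. device_buffer must already be a sufficiently large
    // device allocation (test matrices plus the cuBLAS workspace).
    void profileEncoderGemms(void* device_buffer)
    {
        const int batch_size    = 8;
        const int seq_len       = 128;
        const int head_num      = 16;
        const int size_per_head = 64;

        // With tensor_para_size = 2, the attention and FFN GEMMs are measured at their
        // per-GPU shapes, e.g. N[0] = (head_num / tensor_para_size) * size_per_head.
        generate_encoder_gemm_config<half>(batch_size, seq_len, head_num, size_per_head,
                                           device_buffer, /*isAppend=*/false,
                                           /*tensor_para_size=*/2);
    }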
printf("algo={ Id=%d, tileIdx=%d (%s) splitK=%d reduc=%d swizzle=%d custom=%d stages=%d} status %d " @@ -149,7 +149,7 @@ int printBatchPerfStructure( #if (CUDART_VERSION >= 11000) cublasLtMatmulAlgoConfigGetAttribute(matmulAlgo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &stages, sizeof(stages), NULL); #else - stages = 0; + stages = 0; #endif printf("algo={ Id=%d, tileIdx=%d (%s) splitK=%d reduc=%d swizzle=%d custom=%d stages=%d} status %d " @@ -202,28 +202,28 @@ static inline bool time_compare(const customMatmulPerf_t& perf_a, const customMa return ((perf_a.status == CUBLAS_STATUS_SUCCESS) && (perf_a.time < perf_b.time)); } -static cublasStatus_t customMatmulRun(cublasLtHandle_t ltHandle, // to get the capabilities (required a GPU) - cublasLtMatmulDesc_t operationDesc, - const void* alpha, /* host or device pointer */ - const void* A, - cublasLtMatrixLayout_t Adesc, - const void* B, - cublasLtMatrixLayout_t Bdesc, - const void* beta, /* host or device pointer */ - const void* C, - cublasLtMatrixLayout_t Cdesc, - void* D, - cublasLtMatrixLayout_t Ddesc, +static cublasStatus_t customMatmulRun(cublasLtHandle_t ltHandle, // to get the capabilities (required a GPU) + cublasLtMatmulDesc_t operationDesc, + const void* alpha, /* host or device pointer */ + const void* A, + cublasLtMatrixLayout_t Adesc, + const void* B, + cublasLtMatrixLayout_t Bdesc, + const void* beta, /* host or device pointer */ + const void* C, + cublasLtMatrixLayout_t Cdesc, + void* D, + cublasLtMatrixLayout_t Ddesc, const cublasLtMatmulAlgo_t& algo, - int kernelRepeats, - void* workSpace, - size_t workSpaceSizeInBytes, - customMatmulPerf_t& perfResults, - cudaStream_t stream) + int kernelRepeats, + void* workSpace, + size_t workSpaceSizeInBytes, + customMatmulPerf_t& perfResults, + cudaStream_t stream) { cublasLtMatmulHeuristicResult_t heurResult; /* Looping over the Algo */ - int repeats = kernelRepeats; + int repeats = kernelRepeats; cublasStatus_t algoStatus = cublasLtMatmulAlgoCheck(ltHandle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, &algo, &heurResult); if (algoStatus == CUBLAS_STATUS_SUCCESS) { @@ -258,10 +258,10 @@ static cublasStatus_t customMatmulRun(cublasLtHandle_t ltHandle, // to get the float time = diffTime(start, end); // For the moment only add successful findings if (algoStatus == CUBLAS_STATUS_SUCCESS) { - perfResults.algo = algo; - perfResults.time = time / repeats; + perfResults.algo = algo; + perfResults.time = time / repeats; perfResults.workspaceSize = heurResult.workspaceSize; - perfResults.wavesCount = heurResult.wavesCount; + perfResults.wavesCount = heurResult.wavesCount; } } else { @@ -279,32 +279,32 @@ static cublasStatus_t customMatmulRun(cublasLtHandle_t ltHandle, // to get the // API template int LtIgemmCustomFind(cublasLtHandle_t ltHandle, - int m, - int n, - int k, - const scaleT* alpha, /* host pointer */ - const int8_t* A, - const int8_t* B, - const scaleT* beta, /* host pointer */ - T* C, - void* workSpace, - size_t workSpaceSize, - FILE* fout) + int m, + int n, + int k, + const scaleT* alpha, /* host pointer */ + const int8_t* A, + const int8_t* B, + const scaleT* beta, /* host pointer */ + T* C, + void* workSpace, + size_t workSpaceSize, + FILE* fout) { cublasStatus_t status = CUBLAS_STATUS_SUCCESS; - cublasLtMatmulDesc_t operationDesc = NULL; + cublasLtMatmulDesc_t operationDesc = NULL; cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL; - cudaStream_t stream = 0; + cudaStream_t stream = 0; // SplitK value that we are going to try when SplitK is supported for a given algo const int 
splitKSequenceA[] = {2, 3, 4, 5, 6, 8, 12, 16, 32}; // Let try a fixed number of combinations #define ALGO_COMBINATIONS 50000 - int AlgoCombinations = ALGO_COMBINATIONS; - int AlgoCount = 0; - int kernelRepeats = 100; // number of time the CUDA kernels will be run back to back + int AlgoCombinations = ALGO_COMBINATIONS; + int AlgoCount = 0; + int kernelRepeats = 100; // number of time the CUDA kernels will be run back to back customMatmulPerf_t perfResults[ALGO_COMBINATIONS]; - int nbAlgoIds = 0; + int nbAlgoIds = 0; #define ALGO_IDS 100 int algoIdA[ALGO_IDS]; @@ -313,11 +313,11 @@ int LtIgemmCustomFind(cublasLtHandle_t ltHandle, Btype = CUDA_R_8I; if (std::is_same::value && std::is_same::value) { - Ctype = CUDA_R_32I; + Ctype = CUDA_R_32I; scaleType = CUDA_R_32I; } else if (std::is_same::value && std::is_same::value) { - Ctype = CUDA_R_8I; + Ctype = CUDA_R_8I; scaleType = CUDA_R_32F; } else { @@ -352,7 +352,7 @@ int LtIgemmCustomFind(cublasLtHandle_t ltHandle, order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; } #else - order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; + order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; #endif int ldaTransform = 32 * m; @@ -369,7 +369,7 @@ int LtIgemmCustomFind(cublasLtHandle_t ltHandle, #if (CUDART_VERSION >= 11000) status = cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType); #else - status = cublasLtMatmulDescCreate(&operationDesc, scaleType); + status = cublasLtMatmulDescCreate(&operationDesc, scaleType); #endif if (status != CUBLAS_STATUS_SUCCESS) { goto CLEANUP; @@ -413,7 +413,7 @@ int LtIgemmCustomFind(cublasLtHandle_t ltHandle, // Loop over the Algo IDs for (int idx = 0; (idx < nbAlgoIds) && (AlgoCount < AlgoCombinations); idx++) { cublasLtMatmulAlgo_t algo; - size_t sizeWritten = 0; + size_t sizeWritten = 0; /* Initialize algo structure with given Algp ID */ status = cublasLtMatmulAlgoInit(ltHandle, computeType, scaleType, Atype, Btype, Ctype, Ctype, algoIdA[idx], &algo); @@ -422,19 +422,19 @@ int LtIgemmCustomFind(cublasLtHandle_t ltHandle, } // Query the tiles enums supported by that algo cublasLtMatmulAlgoCapGetAttribute(&algo, CUBLASLT_ALGO_CAP_TILE_IDS, NULL, 0, &sizeWritten); - int nbTiles = int(sizeWritten / sizeof(int)); - int* tileA = new int[nbTiles == 0 ? 1 : nbTiles]; + int nbTiles = int(sizeWritten / sizeof(int)); + int* tileA = new int[nbTiles == 0 ? 1 : nbTiles]; if (nbTiles == 0) { tileA[0] = CUBLASLT_MATMUL_TILE_UNDEFINED; - nbTiles = 1; + nbTiles = 1; } #if (CUDART_VERSION >= 11000) cublasLtMatmulAlgoCapGetAttribute(&algo, CUBLASLT_ALGO_CAP_STAGES_IDS, NULL, 0, &sizeWritten); - int nbStages = int(sizeWritten / sizeof(int)); + int nbStages = int(sizeWritten / sizeof(int)); std::vector stagesA(nbStages == 0 ? 
1 : nbStages); if (nbStages == 0) { stagesA[0] = CUBLASLT_MATMUL_STAGES_UNDEFINED; - nbStages = 1; + nbStages = 1; } else { cublasLtMatmulAlgoCapGetAttribute( @@ -471,14 +471,14 @@ int LtIgemmCustomFind(cublasLtHandle_t ltHandle, if (splitkSupport) { splitK_trial += sizeof(splitKSequenceA) / sizeof(splitKSequenceA[0]); } - // Loop over the splitK value over a fixed sequence splitKSequenceA in addtion to the case where - // splitK is not enabled + // Loop over the splitK value over a fixed sequence splitKSequenceA in addition to the case + // where splitK is not enabled for (int l = 0; (l < (1 + splitK_trial)) && (AlgoCount < AlgoCombinations); l++) { /* Setup attribute of the algo to run */ cublasLtMatmulAlgoConfigSetAttribute( &algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &tileA[tileIdx], sizeof(tileA[tileIdx])); int splitK_val = 0; - int redScheme = CUBLASLT_REDUCTION_SCHEME_NONE; + int redScheme = CUBLASLT_REDUCTION_SCHEME_NONE; cublasLtMatmulAlgoConfigSetAttribute( &algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &splitK_val, sizeof(splitK_val)); cublasLtMatmulAlgoConfigSetAttribute( @@ -501,7 +501,7 @@ int LtIgemmCustomFind(cublasLtHandle_t ltHandle, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &redScheme, sizeof(redScheme)); - status = customMatmulRun(ltHandle, + status = customMatmulRun(ltHandle, operationDesc, alpha, /* host or device pointer */ A, @@ -529,7 +529,7 @@ int LtIgemmCustomFind(cublasLtHandle_t ltHandle, else { // Non-splitK case /* if user preference is ok with workspace */ if (AlgoCount < AlgoCombinations) { - status = customMatmulRun(ltHandle, + status = customMatmulRun(ltHandle, operationDesc, alpha, /* host or device pointer */ A, @@ -588,60 +588,60 @@ int LtIgemmCustomFind(cublasLtHandle_t ltHandle, } template int LtIgemmCustomFind(cublasLtHandle_t ltHandle, - int m, - int n, - int k, - const int* alpha, /* host pointer */ - const int8_t* A, - const int8_t* B, - const int* beta, /* host pointer */ - int32_t* C, - void* workSpace, - size_t workSpaceSize, - FILE* fout); + int m, + int n, + int k, + const int* alpha, /* host pointer */ + const int8_t* A, + const int8_t* B, + const int* beta, /* host pointer */ + int32_t* C, + void* workSpace, + size_t workSpaceSize, + FILE* fout); template int LtIgemmCustomFind(cublasLtHandle_t ltHandle, - int m, - int n, - int k, - const float* alpha, /* host pointer */ - const int8_t* A, - const int8_t* B, - const float* beta, /* host pointer */ - int8_t* C, - void* workSpace, - size_t workSpaceSize, - FILE* fout); + int m, + int n, + int k, + const float* alpha, /* host pointer */ + const int8_t* A, + const int8_t* B, + const float* beta, /* host pointer */ + int8_t* C, + void* workSpace, + size_t workSpaceSize, + FILE* fout); template int LtBatchIgemmCustomFind(cublasLtHandle_t ltHandle, - int batchCount, - int m, - int n, - int k, - const scaleT* alpha, /* host pointer */ - const int8_t* A, - const int8_t* B, - const scaleT* beta, /* host pointer */ - T* C, - void* workSpace, - size_t workSpaceSize, - FILE* fout) + int batchCount, + int m, + int n, + int k, + const scaleT* alpha, /* host pointer */ + const int8_t* A, + const int8_t* B, + const scaleT* beta, /* host pointer */ + T* C, + void* workSpace, + size_t workSpaceSize, + FILE* fout) { cublasStatus_t status = CUBLAS_STATUS_SUCCESS; - cublasLtMatmulDesc_t operationDesc = NULL; + cublasLtMatmulDesc_t operationDesc = NULL; cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL; - cudaStream_t stream = 0; + cudaStream_t stream = 0; // SplitK value that we are going to try when SplitK is 
supported for a given algo const int splitKSequenceA[] = {2, 3, 4, 5, 6, 8, 12, 16, 32}; // Let try a fixed number of combinations #define ALGO_COMBINATIONS 50000 - int AlgoCombinations = ALGO_COMBINATIONS; - int AlgoCount = 0; - int kernelRepeats = 100; // number of time the CUDA kernels will be run back to back + int AlgoCombinations = ALGO_COMBINATIONS; + int AlgoCount = 0; + int kernelRepeats = 100; // number of time the CUDA kernels will be run back to back customMatmulPerf_t perfResults[ALGO_COMBINATIONS]; - int nbAlgoIds = 0; + int nbAlgoIds = 0; #define ALGO_IDS 100 int algoIdA[ALGO_IDS]; @@ -650,11 +650,11 @@ int LtBatchIgemmCustomFind(cublasLtHandle_t ltHandle, Btype = CUDA_R_8I; if (std::is_same::value && std::is_same::value) { - Ctype = CUDA_R_32I; + Ctype = CUDA_R_32I; scaleType = CUDA_R_32I; } else if (std::is_same::value && std::is_same::value) { - Ctype = CUDA_R_8I; + Ctype = CUDA_R_8I; scaleType = CUDA_R_32F; } else { @@ -689,7 +689,7 @@ int LtBatchIgemmCustomFind(cublasLtHandle_t ltHandle, order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; } #else - order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; + order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; #endif int ldaTransform = 32 * m; @@ -711,7 +711,7 @@ int LtBatchIgemmCustomFind(cublasLtHandle_t ltHandle, #if (CUDART_VERSION >= 11000) status = cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType); #else - status = cublasLtMatmulDescCreate(&operationDesc, scaleType); + status = cublasLtMatmulDescCreate(&operationDesc, scaleType); #endif if (status != CUBLAS_STATUS_SUCCESS) { goto CLEANUP; @@ -763,7 +763,7 @@ int LtBatchIgemmCustomFind(cublasLtHandle_t ltHandle, // Loop over the Algo IDs for (int idx = 0; (idx < nbAlgoIds) && (AlgoCount < AlgoCombinations); idx++) { cublasLtMatmulAlgo_t algo; - size_t sizeWritten = 0; + size_t sizeWritten = 0; /* Initialize algo structure with given Algp ID */ status = cublasLtMatmulAlgoInit(ltHandle, computeType, scaleType, Atype, Btype, Ctype, Ctype, algoIdA[idx], &algo); @@ -772,19 +772,19 @@ int LtBatchIgemmCustomFind(cublasLtHandle_t ltHandle, } // Query the tiles enums supported by that algo cublasLtMatmulAlgoCapGetAttribute(&algo, CUBLASLT_ALGO_CAP_TILE_IDS, NULL, 0, &sizeWritten); - int nbTiles = int(sizeWritten / sizeof(int)); - int* tileA = new int[nbTiles == 0 ? 1 : nbTiles]; + int nbTiles = int(sizeWritten / sizeof(int)); + int* tileA = new int[nbTiles == 0 ? 1 : nbTiles]; if (nbTiles == 0) { tileA[0] = CUBLASLT_MATMUL_TILE_UNDEFINED; - nbTiles = 1; + nbTiles = 1; } #if (CUDART_VERSION >= 11000) cublasLtMatmulAlgoCapGetAttribute(&algo, CUBLASLT_ALGO_CAP_STAGES_IDS, NULL, 0, &sizeWritten); - int nbStages = int(sizeWritten / sizeof(int)); + int nbStages = int(sizeWritten / sizeof(int)); std::vector stagesA(nbStages == 0 ? 
1 : nbStages); if (nbStages == 0) { stagesA[0] = CUBLASLT_MATMUL_STAGES_UNDEFINED; - nbStages = 1; + nbStages = 1; } else { cublasLtMatmulAlgoCapGetAttribute( @@ -821,14 +821,14 @@ int LtBatchIgemmCustomFind(cublasLtHandle_t ltHandle, if (splitkSupport) { splitK_trial += sizeof(splitKSequenceA) / sizeof(splitKSequenceA[0]); } - // Loop over the splitK value over a fixed sequence splitKSequenceA in addtion to the case where - // splitK is not enabled + // Loop over the splitK value over a fixed sequence splitKSequenceA in addition to the case + // where splitK is not enabled for (int l = 0; (l < (1 + splitK_trial)) && (AlgoCount < AlgoCombinations); l++) { /* Setup attribute of the algo to run */ cublasLtMatmulAlgoConfigSetAttribute( &algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &tileA[tileIdx], sizeof(tileA[tileIdx])); int splitK_val = 0; - int redScheme = CUBLASLT_REDUCTION_SCHEME_NONE; + int redScheme = CUBLASLT_REDUCTION_SCHEME_NONE; cublasLtMatmulAlgoConfigSetAttribute( &algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &splitK_val, sizeof(splitK_val)); cublasLtMatmulAlgoConfigSetAttribute( @@ -851,7 +851,7 @@ int LtBatchIgemmCustomFind(cublasLtHandle_t ltHandle, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &redScheme, sizeof(redScheme)); - status = customMatmulRun(ltHandle, + status = customMatmulRun(ltHandle, operationDesc, alpha, /* host or device pointer */ A, @@ -879,7 +879,7 @@ int LtBatchIgemmCustomFind(cublasLtHandle_t ltHandle, else { // Non-splitK case /* if user preference is ok with workspace */ if (AlgoCount < AlgoCombinations) { - status = customMatmulRun(ltHandle, + status = customMatmulRun(ltHandle, operationDesc, alpha, /* host or device pointer */ A, @@ -938,32 +938,32 @@ int LtBatchIgemmCustomFind(cublasLtHandle_t ltHandle, } template int LtBatchIgemmCustomFind(cublasLtHandle_t ltHandle, - int batchCount, - int m, - int n, - int k, - const int* alpha, /* host pointer */ - const int8_t* A, - const int8_t* B, - const int* beta, /* host pointer */ - int32_t* C, - void* workSpace, - size_t workSpaceSize, - FILE* fout); + int batchCount, + int m, + int n, + int k, + const int* alpha, /* host pointer */ + const int8_t* A, + const int8_t* B, + const int* beta, /* host pointer */ + int32_t* C, + void* workSpace, + size_t workSpaceSize, + FILE* fout); template int LtBatchIgemmCustomFind(cublasLtHandle_t ltHandle, - int batchCount, - int m, - int n, - int k, - const float* alpha, /* host pointer */ - const int8_t* A, - const int8_t* B, - const float* beta, /* host pointer */ - int8_t* C, - void* workSpace, - size_t workSpaceSize, - FILE* fout); + int batchCount, + int m, + int n, + int k, + const float* alpha, /* host pointer */ + const int8_t* A, + const int8_t* B, + const float* beta, /* host pointer */ + int8_t* C, + void* workSpace, + size_t workSpaceSize, + FILE* fout); // initialize matrix in column-major void matInit(int rows, int cols, int8_t* p, int ld) @@ -983,10 +983,10 @@ int batch_igemm_config(int batchCount, int m, int n, int k, FILE* fout, void* bu { printf("batchCount %d m %d n %d k %d\n", batchCount, m, n, k); int alpha = 1; - int beta = 0; + int beta = 0; - int8_t* d_A = (int8_t*)buffer; // m * k, stored in column-major - int8_t* d_B = d_A + batchCount * m * k; // k * n, stored in column-major + int8_t* d_A = (int8_t*)buffer; // m * k, stored in column-major + int8_t* d_B = d_A + batchCount * m * k; // k * n, stored in column-major int32_t* d_C = (int32_t*)(d_B + batchCount * k * n); // m * n, stored in column-major cublasLtHandle_t ltHandle; @@ -1014,10 +1014,10 @@ int 
igemm_config(int m, int n, int k, FILE* fout, void* buffer) { printf("batchCount %d m %d n %d k %d\n", 1, m, n, k); int alpha = 1; - int beta = 0; + int beta = 0; - int8_t* d_A = (int8_t*)buffer; // m * k, stored in column-major - int8_t* d_B = d_A + m * k; // k * n, stored in column-major + int8_t* d_A = (int8_t*)buffer; // m * k, stored in column-major + int8_t* d_B = d_A + m * k; // k * n, stored in column-major int32_t* d_C = (int32_t*)(d_B + k * n); // m * n, stored in column-major cublasLtHandle_t ltHandle; @@ -1064,7 +1064,7 @@ int generate_encoder_igemm_config( else { fout = fopen(IGEMM_CONFIG, "a+"); std::vector config; - char line[1024]; + char line[1024]; while (fgets(line, 1024, fout) != NULL) { config.push_back(std::string(line)); } @@ -1078,22 +1078,22 @@ int generate_encoder_igemm_config( } } - batch_size_ = batch_size; - seq_len_ = seq_len; - head_num_ = head_num; + batch_size_ = batch_size; + seq_len_ = seq_len; + head_num_ = head_num; size_per_head_ = size_per_head; - int m = batch_size * seq_len; - int n = head_num * size_per_head; - int k = n; + int m = batch_size * seq_len; + int n = head_num * size_per_head; + int k = n; int batchCount; printf("***Encoder IGemm Testing Begin***\n"); printf("\n-----------------------------\n"); batchCount = 3; - m = batch_size * seq_len; - k = head_num * size_per_head; - n = k; + m = batch_size * seq_len; + k = head_num * size_per_head; + n = k; if (n % 32 != 0 || k % 32 != 0) { printf("[WARNING] For INT8 gemm test, n, k should be multiples of 32 (n = %d, k = %d)\n", n, k); } @@ -1102,9 +1102,9 @@ int generate_encoder_igemm_config( } printf("\n-----------------------------\n"); - m = seq_len; - n = seq_len; - k = size_per_head; + m = seq_len; + n = seq_len; + k = size_per_head; batchCount = batch_size * head_num; if (n % 32 != 0 || k % 32 != 0) { printf("[WARNING] For INT8 gemm test, n, k should be multiples of 32 (n = %d, k = %d)\n", n, k); @@ -1114,9 +1114,9 @@ int generate_encoder_igemm_config( } printf("\n-----------------------------\n"); - m = seq_len; - n = size_per_head; - k = seq_len; + m = seq_len; + n = size_per_head; + k = seq_len; batchCount = batch_size * head_num; if (n % 32 != 0 || k % 32 != 0) { printf("[WARNING] For INT8 gemm test, n, k should be multiples of 32 (n = %d, k = %d)\n", n, k); @@ -1166,10 +1166,10 @@ int generate_encoder_igemm_config( } if (do_sparse_test) { printf("***cusparseLt Gemm Testing Begin***\n"); - const int spgemm_num = 3; - FILE* fd; - int line_count = 0; - const int ites = 100; + const int spgemm_num = 3; + FILE* fd; + int line_count = 0; + const int ites = 100; struct timeval start, end; if (!isAppend) { fd = fopen(SPIGEMM_CONFIG, "w+"); @@ -1177,7 +1177,7 @@ int generate_encoder_igemm_config( else { fd = fopen(SPIGEMM_CONFIG, "a+"); std::vector config; - char line[1024]; + char line[1024]; while (fgets(line, 1024, fd) != NULL) { config.push_back(std::string(line)); } @@ -1218,17 +1218,17 @@ int generate_encoder_igemm_config( cusparseLtHandle_t handle; CHECK_CUSPARSE(cusparseLtInit(&handle)); - cusparseOrder_t col_order = CUSPARSE_ORDER_COL; - cusparseOrder_t row_order = CUSPARSE_ORDER_ROW; - cusparseOperation_t opA = CUSPARSE_OPERATION_NON_TRANSPOSE; - cusparseOperation_t opB = CUSPARSE_OPERATION_NON_TRANSPOSE; + cusparseOrder_t col_order = CUSPARSE_ORDER_COL; + cusparseOrder_t row_order = CUSPARSE_ORDER_ROW; + cusparseOperation_t opA = CUSPARSE_OPERATION_NON_TRANSPOSE; + cusparseOperation_t opB = CUSPARSE_OPERATION_NON_TRANSPOSE; cusparseComputeType compute_type = CUSPARSE_COMPUTE_32I; - 
unsigned alignment = 16; - cudaStream_t stream = 0; - float alpha2 = 1.0f; - float beta2 = 0.0f; + unsigned alignment = 16; + cudaStream_t stream = 0; + float alpha2 = 1.0f; + float beta2 = 0.0f; for (int i = 0; i < spgemm_num; ++i) { - // to be compatable with spgemm wrapper, we let A be the weight matrix + // to be compatible with spgemm wrapper, we let A be the weight matrix // so m and n are swapped // A: mxk B: kxn C:mxn int m = N[i], n = M[i], k = K[i]; @@ -1256,13 +1256,13 @@ int generate_encoder_igemm_config( } float exec_time = 99999.0f; - int fast_algo = 0; + int fast_algo = 0; for (int alg = 0; alg < 4; ++alg) { cudaDeviceSynchronize(); cusparseLtMatDescriptor_t matA, matB, matC; - void* d_workspace = nullptr; - int num_streams = 1; - cudaStream_t streams[1] = {stream}; + void* d_workspace = nullptr; + int num_streams = 1; + cudaStream_t streams[1] = {stream}; CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit( &handle, &matA, m, k, k, alignment, CUDA_R_8I, row_order, CUSPARSELT_SPARSITY_50_PERCENT)) CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &matB, k, n, k, alignment, CUDA_R_8I, col_order)) @@ -1272,9 +1272,9 @@ int generate_encoder_igemm_config( // initializing MatDesc takes a lot of time // and these descs can be stored to other place // whereas storing MatMulPlan to other place will cause errors - cusparseLtMatmulDescriptor_t matmul; + cusparseLtMatmulDescriptor_t matmul; cusparseLtMatmulAlgSelection_t alg_sel; - cusparseLtMatmulPlan_t plan; + cusparseLtMatmulPlan_t plan; CHECK_CUSPARSE(cusparseLtMatmulDescriptorInit( &handle, &matmul, opA, opB, &matA, &matB, &matC, &matC, compute_type)) CHECK_CUSPARSE( diff --git a/src/fastertransformer/utils/gemm_test/encoder_igemm_func.h b/src/fastertransformer/utils/gemm_test/encoder_igemm_func.h index ce23ca582..1a2926077 100644 --- a/src/fastertransformer/utils/gemm_test/encoder_igemm_func.h +++ b/src/fastertransformer/utils/gemm_test/encoder_igemm_func.h @@ -48,32 +48,32 @@ int printBatchPerfStructure( template int LtIgemmCustomFind(cublasLtHandle_t ltHandle, - int m, - int n, - int k, - const scaleT* alpha, /* host pointer */ - const int8_t* A, - const int8_t* B, - const scaleT* beta, /* host pointer */ - T* C, - void* workSpace, - size_t workSpaceSize, - FILE* fout); + int m, + int n, + int k, + const scaleT* alpha, /* host pointer */ + const int8_t* A, + const int8_t* B, + const scaleT* beta, /* host pointer */ + T* C, + void* workSpace, + size_t workSpaceSize, + FILE* fout); template int LtBatchIgemmCustomFind(cublasLtHandle_t ltHandle, - int batchCount, - int m, - int n, - int k, - const scaleT* alpha, /* host pointer */ - const int8_t* A, - const int8_t* B, - const scaleT* beta, /* host pointer */ - T* C, - void* workSpace, - size_t workSpaceSize, - FILE* fout); + int batchCount, + int m, + int n, + int k, + const scaleT* alpha, /* host pointer */ + const int8_t* A, + const int8_t* B, + const scaleT* beta, /* host pointer */ + T* C, + void* workSpace, + size_t workSpaceSize, + FILE* fout); void matInit(int rows, int cols, int8_t* p, int ld); diff --git a/src/fastertransformer/utils/gemm_test/gemm_func.cc b/src/fastertransformer/utils/gemm_test/gemm_func.cc index edbfc40ad..99d95ebdb 100644 --- a/src/fastertransformer/utils/gemm_test/gemm_func.cc +++ b/src/fastertransformer/utils/gemm_test/gemm_func.cc @@ -23,17 +23,17 @@ namespace fastertransformer { // Utility function to print customMatmulPerf_t structure -int printPerfStructure(int batch_size, - int seq_len, - int head_num, - int size_per_head, - int m, - int n, - 
int k, +int printPerfStructure(int batch_size, + int seq_len, + int head_num, + int size_per_head, + int m, + int n, + int k, const customMatmulPerf_t& perf, - FILE* fout, - CublasDataType data_type, - int hasPrint) + FILE* fout, + CublasDataType data_type, + int hasPrint) { int algoId, tile, swizzle, customOption, numSplitsK, reductionScheme, stages; @@ -103,30 +103,30 @@ static inline bool time_compare(const customMatmulPerf_t& perf_a, const customMa return ((perf_a.status == CUBLAS_STATUS_SUCCESS) && (perf_a.time < perf_b.time)); } -static cublasStatus_t customMatmulRun(cublasLtHandle_t ltHandle, // to get the capabilities (required a GPU) - cublasLtMatmulDesc_t operationDesc, - const void* alpha, /* host or device pointer */ - const void* A, - cublasLtMatrixLayout_t Adesc, - const void* B, - cublasLtMatrixLayout_t Bdesc, - const void* beta, /* host or device pointer */ - const void* C, - cublasLtMatrixLayout_t Cdesc, - void* D, - cublasLtMatrixLayout_t Ddesc, +static cublasStatus_t customMatmulRun(cublasLtHandle_t ltHandle, // to get the capabilities (required a GPU) + cublasLtMatmulDesc_t operationDesc, + const void* alpha, /* host or device pointer */ + const void* A, + cublasLtMatrixLayout_t Adesc, + const void* B, + cublasLtMatrixLayout_t Bdesc, + const void* beta, /* host or device pointer */ + const void* C, + cublasLtMatrixLayout_t Cdesc, + void* D, + cublasLtMatrixLayout_t Ddesc, const cublasLtMatmulAlgo_t& algo, - int kernelRepeats, - void* workSpace, - size_t workSpaceSizeInBytes, - customMatmulPerf_t& perfResults, - cudaStream_t stream, - cudaEvent_t& startEvent, - cudaEvent_t& stopEvent) + int kernelRepeats, + void* workSpace, + size_t workSpaceSizeInBytes, + customMatmulPerf_t& perfResults, + cudaStream_t stream, + cudaEvent_t& startEvent, + cudaEvent_t& stopEvent) { cublasLtMatmulHeuristicResult_t heurResult; /* Looping over the Algo */ - int repeats = kernelRepeats; + int repeats = kernelRepeats; cublasStatus_t algoStatus = cublasLtMatmulAlgoCheck(ltHandle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, &algo, &heurResult); @@ -165,10 +165,10 @@ static cublasStatus_t customMatmulRun(cublasLtHandle_t ltHandle, // to get the } // For the moment only add successful findings if (algoStatus == CUBLAS_STATUS_SUCCESS) { - perfResults.algo = algo; - perfResults.time = time / repeats; + perfResults.algo = algo; + perfResults.time = time / repeats; perfResults.workspaceSize = heurResult.workspaceSize; - perfResults.wavesCount = heurResult.wavesCount; + perfResults.wavesCount = heurResult.wavesCount; } } else { @@ -181,31 +181,31 @@ static cublasStatus_t customMatmulRun(cublasLtHandle_t ltHandle, // to get the } template -int LtHgemmCustomFind(cublasLtHandle_t ltHandle, - int batch_size, - int seq_len, - int head_num, - int size_per_head, - int m, - int n, - int k, - const scaleT* alpha, /* host pointer */ - const T* A, - const T* B, - const scaleT* beta, /* host pointer */ - T* C, - void* workSpace, - size_t workSpaceSize, - FILE* fout, +int LtHgemmCustomFind(cublasLtHandle_t ltHandle, + int batch_size, + int seq_len, + int head_num, + int size_per_head, + int m, + int n, + int k, + const scaleT* alpha, /* host pointer */ + const T* A, + const T* B, + const scaleT* beta, /* host pointer */ + T* C, + void* workSpace, + size_t workSpaceSize, + FILE* fout, customMatmulPerf_t perfResults[], - int AlgoCombinations) + int AlgoCombinations) { cublasStatus_t status = CUBLAS_STATUS_SUCCESS; - cudaEvent_t startEvent; - cudaEvent_t stopEvent; + cudaEvent_t startEvent; + cudaEvent_t stopEvent; 
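// ----------------------------------------------------------------------------
// Editor's sketch (illustrative, not part of this patch): the event-based timing
// pattern that customMatmulRun() in this file applies to every algo candidate.
// A candidate is first validated with cublasLtMatmulAlgoCheck(), then launched
// `repeats` times back-to-back between two CUDA events on the same stream, and
// the average per-launch time is what time_compare() later sorts on. The function
// name is hypothetical; it assumes <cublasLt.h> and <cuda_runtime.h> (already
// included by this file) and that all descriptors, the algo, the workspace, and
// the events have been set up by the caller. Error handling is elided.
static float timeOneAlgoCandidateSketch(cublasLtHandle_t            ltHandle,
                                        cublasLtMatmulDesc_t        operationDesc,
                                        const void*                 alpha,
                                        const void*                 A,
                                        cublasLtMatrixLayout_t      Adesc,
                                        const void*                 B,
                                        cublasLtMatrixLayout_t      Bdesc,
                                        const void*                 beta,
                                        void*                       C,
                                        cublasLtMatrixLayout_t      Cdesc,
                                        const cublasLtMatmulAlgo_t& algo,
                                        void*                       workSpace,
                                        size_t                      workSpaceSizeInBytes,
                                        int                         repeats,
                                        cudaStream_t                stream,
                                        cudaEvent_t                 startEvent,
                                        cudaEvent_t                 stopEvent)
{
    cublasLtMatmulHeuristicResult_t heurResult;
    // Skip candidates that are invalid for these layouts or need more workspace than allowed.
    if (cublasLtMatmulAlgoCheck(ltHandle, operationDesc, Adesc, Bdesc, Cdesc, Cdesc, &algo, &heurResult)
            != CUBLAS_STATUS_SUCCESS
        || heurResult.workspaceSize > workSpaceSizeInBytes) {
        return -1.0f;  // caller drops this candidate
    }
    cudaEventRecord(startEvent, stream);
    for (int rep = 0; rep < repeats; rep++) {
        // C is passed for both the C and D operands, matching how this file calls cublasLtMatmul.
        cublasLtMatmul(ltHandle, operationDesc, alpha, A, Adesc, B, Bdesc, beta,
                       C, Cdesc, C, Cdesc, &algo, workSpace, workSpaceSizeInBytes, stream);
    }
    cudaEventRecord(stopEvent, stream);
    cudaEventSynchronize(stopEvent);
    float msTotal = 0.0f;
    cudaEventElapsedTime(&msTotal, startEvent, stopEvent);
    return msTotal / repeats;  // average per-launch time in ms (cf. perfResults.time)
}
// ----------------------------------------------------------------------------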
CublasDataType data_type; - cublasLtMatmulDesc_t operationDesc = NULL; + cublasLtMatmulDesc_t operationDesc = NULL; cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL; cudaStream_t stream = 0; @@ -213,13 +213,13 @@ int LtHgemmCustomFind(cublasLtHandle_t ltHandle, // given algo const int splitKSequenceA[] = {2, 3, 4, 5, 6, 8, 12, 16, 32}; // Let try a fixed number of combinations - int AlgoCount = 0; - int AlgoCountRestrict = 0; // workspace == 0 - int maxNumTraversal = 50; // max number of traversal + int AlgoCount = 0; + int AlgoCountRestrict = 0; // workspace == 0 + int maxNumTraversal = 50; // max number of traversal cublasLtMatmulAlgo_t algos[AlgoCombinations]; // 0 <= workspace <= 32MB cublasLtMatmulAlgo_t algosRestrict[AlgoCombinations]; // workspace == 0 - int kernelRepeats = 100; // number of time the CUDA kernels will be run back to back - int nbAlgoIds = 0; // Number of algorithms actually returned by + int kernelRepeats = 100; // number of time the CUDA kernels will be run back to back + int nbAlgoIds = 0; // Number of algorithms actually returned by // cublasLtMatmulAlgoGetIds function. #define ALGO_IDS 100 // Number of algorithms requested. int algoIdA[ALGO_IDS]; // Array containing the algorithm IDs returned by @@ -310,7 +310,7 @@ int LtHgemmCustomFind(cublasLtHandle_t ltHandle, // Loop over the Algo IDs for (int idx = 0; (idx < nbAlgoIds) && (AlgoCount < AlgoCombinations); idx++) { cublasLtMatmulAlgo_t algo; - size_t sizeWritten = 0; + size_t sizeWritten = 0; /* Initialize algo structure with given Algp ID */ status = cublasLtMatmulAlgoInit(ltHandle, computeType, scaleType, Atype, Btype, Ctype, Ctype, algoIdA[idx], &algo); @@ -319,19 +319,19 @@ int LtHgemmCustomFind(cublasLtHandle_t ltHandle, } // Query the tiles enums supported by that algo cublasLtMatmulAlgoCapGetAttribute(&algo, CUBLASLT_ALGO_CAP_TILE_IDS, NULL, 0, &sizeWritten); - int nbTiles = int(sizeWritten / sizeof(int)); - int* tileA = new int[nbTiles == 0 ? 1 : nbTiles]; + int nbTiles = int(sizeWritten / sizeof(int)); + int* tileA = new int[nbTiles == 0 ? 1 : nbTiles]; if (nbTiles == 0) { tileA[0] = CUBLASLT_MATMUL_TILE_UNDEFINED; - nbTiles = 1; + nbTiles = 1; } #if (CUDART_VERSION >= 11000) cublasLtMatmulAlgoCapGetAttribute(&algo, CUBLASLT_ALGO_CAP_STAGES_IDS, NULL, 0, &sizeWritten); - int nbStages = int(sizeWritten / sizeof(int)); + int nbStages = int(sizeWritten / sizeof(int)); std::vector stagesA(nbStages == 0 ? 
1 : nbStages); if (nbStages == 0) { stagesA[0] = CUBLASLT_MATMUL_STAGES_UNDEFINED; - nbStages = 1; + nbStages = 1; } else { cublasLtMatmulAlgoCapGetAttribute( @@ -371,14 +371,14 @@ int LtHgemmCustomFind(cublasLtHandle_t ltHandle, splitK_trial += sizeof(splitKSequenceA) / sizeof(splitKSequenceA[0]); } // Loop over the splitK value over a fixed sequence - // splitKSequenceA in addtion to the case where splitK + // splitKSequenceA in addition to the case where splitK // is not enabled for (int l = 0; (l < (1 + splitK_trial)) && (AlgoCount < AlgoCombinations); l++) { /* Setup attribute of the algo to run */ cublasLtMatmulAlgoConfigSetAttribute( &algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &tileA[tileIdx], sizeof(tileA[tileIdx])); int splitK_val = 0; - int redScheme = CUBLASLT_REDUCTION_SCHEME_NONE; + int redScheme = CUBLASLT_REDUCTION_SCHEME_NONE; cublasLtMatmulAlgoConfigSetAttribute( &algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &splitK_val, sizeof(splitK_val)); cublasLtMatmulAlgoConfigSetAttribute( @@ -403,7 +403,7 @@ int LtHgemmCustomFind(cublasLtHandle_t ltHandle, sizeof(redScheme)); cublasLtMatmulHeuristicResult_t heurResult; - cublasStatus_t algoStatus = cublasLtMatmulAlgoCheck( + cublasStatus_t algoStatus = cublasLtMatmulAlgoCheck( ltHandle, operationDesc, Adesc, Bdesc, Cdesc, Cdesc, &algo, &heurResult); if (heurResult.workspaceSize > workSpaceSize) { // printf("not enough workspace! @@ -426,7 +426,7 @@ int LtHgemmCustomFind(cublasLtHandle_t ltHandle, /* if user preference is ok with workspace */ if (AlgoCount < AlgoCombinations) { cublasLtMatmulHeuristicResult_t heurResult; - cublasStatus_t algoStatus = cublasLtMatmulAlgoCheck( + cublasStatus_t algoStatus = cublasLtMatmulAlgoCheck( ltHandle, operationDesc, Adesc, Bdesc, Cdesc, Cdesc, &algo, &heurResult); if (heurResult.workspaceSize > workSpaceSize) { // printf("not enough workspace! %ld\n", @@ -459,7 +459,7 @@ int LtHgemmCustomFind(cublasLtHandle_t ltHandle, if (AlgoCount < maxNumTraversal) { // 0 <= workspacesize <= 32MB for (int i = 0; i < AlgoCount; i++) { - status = customMatmulRun(ltHandle, + status = customMatmulRun(ltHandle, operationDesc, alpha, /* host or device pointer */ A, @@ -508,7 +508,7 @@ int LtHgemmCustomFind(cublasLtHandle_t ltHandle, printf("return %d and run heuristic algo\n", nbAlgoIds); for (int i = 0; i < nbAlgoIds; i++) { if (heuristicResultsArray[i].state == CUBLAS_STATUS_SUCCESS) { - status = customMatmulRun(ltHandle, + status = customMatmulRun(ltHandle, operationDesc, alpha, /* host or device pointer */ A, @@ -538,7 +538,7 @@ int LtHgemmCustomFind(cublasLtHandle_t ltHandle, // workspacesize==0 printf("workspacesize==0, run %d algos\n", AlgoCountRestrict); for (int i = 0; i < AlgoCountRestrict && i < (maxNumTraversal - nbAlgoIds); i++) { - status = customMatmulRun(ltHandle, + status = customMatmulRun(ltHandle, operationDesc, alpha, /* host or device pointer */ A, @@ -597,91 +597,91 @@ int LtHgemmCustomFind(cublasLtHandle_t ltHandle, return status == CUBLAS_STATUS_SUCCESS ? 
0 : 1; } -template int LtHgemmCustomFind(cublasLtHandle_t ltHandle, - int batch_size, - int seq_len, - int head_num, - int size_per_head, - int m, - int n, - int k, - const float* alpha, /* host pointer */ - const float* A, - const float* B, - const float* beta, /* host pointer */ - float* C, - void* workSpace, - size_t workSpaceSize, - FILE* fout, +template int LtHgemmCustomFind(cublasLtHandle_t ltHandle, + int batch_size, + int seq_len, + int head_num, + int size_per_head, + int m, + int n, + int k, + const float* alpha, /* host pointer */ + const float* A, + const float* B, + const float* beta, /* host pointer */ + float* C, + void* workSpace, + size_t workSpaceSize, + FILE* fout, customMatmulPerf_t perfResults[], - int AlgoCombinations); - -template int LtHgemmCustomFind(cublasLtHandle_t ltHandle, - int batch_size, - int seq_len, - int head_num, - int size_per_head, - int m, - int n, - int k, - const half* alpha, /* host pointer */ - const half* A, - const half* B, - const half* beta, /* host pointer */ - half* C, - void* workSpace, - size_t workSpaceSize, - FILE* fout, + int AlgoCombinations); + +template int LtHgemmCustomFind(cublasLtHandle_t ltHandle, + int batch_size, + int seq_len, + int head_num, + int size_per_head, + int m, + int n, + int k, + const half* alpha, /* host pointer */ + const half* A, + const half* B, + const half* beta, /* host pointer */ + half* C, + void* workSpace, + size_t workSpaceSize, + FILE* fout, customMatmulPerf_t perfResults[], - int AlgoCombinations); + int AlgoCombinations); #ifdef ENABLE_BF16 -template int LtHgemmCustomFind(cublasLtHandle_t ltHandle, - int batch_size, - int seq_len, - int head_num, - int size_per_head, - int m, - int n, - int k, - const float* alpha, /* host pointer */ +template int LtHgemmCustomFind(cublasLtHandle_t ltHandle, + int batch_size, + int seq_len, + int head_num, + int size_per_head, + int m, + int n, + int k, + const float* alpha, /* host pointer */ const __nv_bfloat16* A, const __nv_bfloat16* B, - const float* beta, /* host pointer */ - __nv_bfloat16* C, - void* workSpace, - size_t workSpaceSize, - FILE* fout, - customMatmulPerf_t perfResults[], - int AlgoCombinations); + const float* beta, /* host pointer */ + __nv_bfloat16* C, + void* workSpace, + size_t workSpaceSize, + FILE* fout, + customMatmulPerf_t perfResults[], + int AlgoCombinations); #endif -template int LtHgemmCustomFind(cublasLtHandle_t ltHandle, - int batch_size, - int seq_len, - int head_num, - int size_per_head, - int m, - int n, - int k, - const float* alpha, /* host pointer */ - const half* A, - const half* B, - const float* beta, /* host pointer */ - half* C, - void* workSpace, - size_t workSpaceSize, - FILE* fout, +template int LtHgemmCustomFind(cublasLtHandle_t ltHandle, + int batch_size, + int seq_len, + int head_num, + int size_per_head, + int m, + int n, + int k, + const float* alpha, /* host pointer */ + const half* A, + const half* B, + const float* beta, /* host pointer */ + half* C, + void* workSpace, + size_t workSpaceSize, + FILE* fout, customMatmulPerf_t perfResults[], - int AlgoCombinations); - -size_t calGemmTestBufSizeInByte(int batch_size, - int seq_len, - int head_num, - int size_per_head, - int inter_size, - int vocab_size, - int int8_mode, + int AlgoCombinations); + +size_t calGemmTestBufSizeInByte(int batch_size, + int seq_len, + int head_num, + int size_per_head, + int inter_size, + int vocab_size, + int int8_mode, CublasDataType data_type) { size_t buf_size_in_byte; @@ -697,24 +697,26 @@ size_t calGemmTestBufSizeInByte(int 
batch_size, size_t size3 = batch_size * head_num * (seq_len * seq_len * sizeof(int8_t) + seq_len * size_per_head * sizeof(int8_t) + seq_len * size_per_head * sizeof(int)); - size_t size4 = m * k * sizeof(int8_t) + k * inter_size * sizeof(int8_t) + m * inter_size * sizeof(int); - size_t size5 = m * k * sizeof(int8_t) + k * vocab_size * sizeof(int8_t) + m * vocab_size * sizeof(int); + size_t size4 = m * k * sizeof(int8_t) + k * inter_size * sizeof(int8_t) + m * inter_size * sizeof(int); + size_t size5 = m * k * sizeof(int8_t) + k * vocab_size * sizeof(int8_t) + m * vocab_size * sizeof(int); buf_size_in_byte = size1 > size2 ? size1 : size2; buf_size_in_byte = buf_size_in_byte > size3 ? buf_size_in_byte : size3; buf_size_in_byte = buf_size_in_byte > size4 ? buf_size_in_byte : size4; buf_size_in_byte = buf_size_in_byte > size5 ? buf_size_in_byte : size5; } else { - int m = batch_size * seq_len; - int n = head_num * size_per_head; - int k = n; + size_t m = batch_size * seq_len; + size_t n = head_num * size_per_head; + size_t k = n; // TODO need to add bfloat16 here - int wordSize = (data_type == FLOAT_DATATYPE ? sizeof(float) : sizeof(half)); - size_t size1 = 3 * (m * k + k * n + m * n) * wordSize; - size_t size2 = - batch_size * head_num * (seq_len * seq_len + seq_len * size_per_head + seq_len * size_per_head) * wordSize; - size_t size3 = (m * k + k * inter_size + m * inter_size) * wordSize; - size_t size4 = (m * k + k * vocab_size + m * vocab_size) * wordSize; + int wordSize = (data_type == FLOAT_DATATYPE ? sizeof(float) : sizeof(half)); + size_t size1 = 3 * (m * k + k * n + m * n) * wordSize; + size_t size2 = (size_t)batch_size * (size_t)head_num + * ((size_t)seq_len * (size_t)seq_len + (size_t)seq_len * (size_t)size_per_head + + (size_t)seq_len * (size_t)size_per_head) + * (size_t)wordSize; + size_t size3 = (m * k + k * inter_size + m * inter_size) * wordSize; + size_t size4 = (m * k + k * vocab_size + m * vocab_size) * wordSize; buf_size_in_byte = size1 > size2 ? size1 : size2; buf_size_in_byte = buf_size_in_byte > size3 ? buf_size_in_byte : size3; buf_size_in_byte = buf_size_in_byte > size4 ? 
buf_size_in_byte : size4; @@ -727,39 +729,39 @@ size_t calGemmTestBufSizeInByte(int batch_size, size_t calGemmTestBufSizeInByteXlnet( int batch_size, int seq_len, int head_num, int size_per_head, int inter_size, int hidden_units, int is_fp16) { - int M[10] = {0}; - int N[10] = {0}; - int K[10] = {0}; + int M[10] = {0}; + int N[10] = {0}; + int K[10] = {0}; int batchCount[10] = {0}; // gemm1 - M[0] = hidden_units; - N[0] = seq_len * batch_size; - K[0] = hidden_units; + M[0] = hidden_units; + N[0] = seq_len * batch_size; + K[0] = hidden_units; batchCount[0] = 3; // gemm2 - M[1] = hidden_units; - N[1] = seq_len * 2; - K[1] = hidden_units; + M[1] = hidden_units; + N[1] = seq_len * 2; + K[1] = hidden_units; batchCount[1] = 1; // gemm3 - M[2] = seq_len; - N[2] = seq_len; - K[2] = size_per_head; + M[2] = seq_len; + N[2] = seq_len; + K[2] = size_per_head; batchCount[2] = batch_size * head_num; // gemm4 - M[3] = seq_len * 2; - N[3] = seq_len; - K[3] = size_per_head; + M[3] = seq_len * 2; + N[3] = seq_len; + K[3] = size_per_head; batchCount[3] = batch_size * head_num; // gemm5 - M[4] = 2; - N[4] = seq_len; - K[4] = size_per_head; + M[4] = 2; + N[4] = seq_len; + K[4] = size_per_head; batchCount[4] = batch_size * head_num; // gemm6 @@ -767,33 +769,33 @@ size_t calGemmTestBufSizeInByteXlnet( N[5] = seq_len; K[5] = 2; // gemm7 - M[6] = size_per_head; - N[6] = seq_len; - K[6] = seq_len; + M[6] = size_per_head; + N[6] = seq_len; + K[6] = seq_len; batchCount[6] = batch_size * head_num; // gemm8 - M[7] = hidden_units; - N[7] = seq_len; - K[7] = hidden_units; + M[7] = hidden_units; + N[7] = seq_len; + K[7] = hidden_units; batchCount[7] = batch_size; // gemm9 - M[8] = inter_size; - N[8] = seq_len; - K[8] = hidden_units; + M[8] = inter_size; + N[8] = seq_len; + K[8] = hidden_units; batchCount[8] = batch_size; // gemm10 - M[9] = hidden_units; - N[9] = seq_len; - K[9] = inter_size; + M[9] = hidden_units; + N[9] = seq_len; + K[9] = inter_size; batchCount[9] = batch_size; size_t max_size = 0; for (int i = 0; i < 10; ++i) { - int m = M[i], n = N[i], k = K[i]; + int m = M[i], n = N[i], k = K[i]; size_t size = (M[i] * N[i] + M[i] * K[i] + N[i] * K[i]) * batchCount[i]; if (size > max_size) { max_size = size; diff --git a/src/fastertransformer/utils/gemm_test/gemm_func.h b/src/fastertransformer/utils/gemm_test/gemm_func.h index 577aa6fac..79d67c7a9 100644 --- a/src/fastertransformer/utils/gemm_test/gemm_func.h +++ b/src/fastertransformer/utils/gemm_test/gemm_func.h @@ -44,47 +44,47 @@ struct ScaleTypeConverter { }; template -int LtHgemmCustomFind(cublasLtHandle_t ltHandle, - int batch_size, - int seq_len, - int head_num, - int size_per_head, - int m, - int n, - int k, - const scaleT* alpha, /* host pointer */ - const T* A, - const T* B, - const scaleT* beta, /* host pointer */ - T* C, - void* workSpace, - size_t workSpaceSize, - FILE* fout, +int LtHgemmCustomFind(cublasLtHandle_t ltHandle, + int batch_size, + int seq_len, + int head_num, + int size_per_head, + int m, + int n, + int k, + const scaleT* alpha, /* host pointer */ + const T* A, + const T* B, + const scaleT* beta, /* host pointer */ + T* C, + void* workSpace, + size_t workSpaceSize, + FILE* fout, customMatmulPerf_t perfResults[], - int AlgoCombinations); + int AlgoCombinations); -size_t calGemmTestBufSizeInByte(int batch_size, - int seq_len, - int head_num, - int size_per_head, - int inter_size, - int vocab_size, - int int8_mode, +size_t calGemmTestBufSizeInByte(int batch_size, + int seq_len, + int head_num, + int size_per_head, + int inter_size, + int 
vocab_size, + int int8_mode, CublasDataType data_type); size_t calGemmTestBufSizeInByteXlnet( int batch_size, int seq_len, int head_num, int size_per_head, int inter_size, int hidden_units, int is_fp16); -int printPerfStructure(int batch_size, - int seq_len, - int head_num, - int size_per_head, - int m, - int n, - int k, +int printPerfStructure(int batch_size, + int seq_len, + int head_num, + int size_per_head, + int m, + int n, + int k, const customMatmulPerf_t& perf, - FILE* fout, - CublasDataType data_type, - int hasPrint); + FILE* fout, + CublasDataType data_type, + int hasPrint); } // namespace fastertransformer diff --git a/src/fastertransformer/utils/gemm_test/gpt_gemm_func.cc b/src/fastertransformer/utils/gemm_test/gpt_gemm_func.cc index 0c3575057..11eb86282 100644 --- a/src/fastertransformer/utils/gemm_test/gpt_gemm_func.cc +++ b/src/fastertransformer/utils/gemm_test/gpt_gemm_func.cc @@ -24,21 +24,21 @@ bool isSparseGemmAvailable(size_t m, size_t n, size_t k) } template -void generate_gpt_gemm_config(int batch_size, - int beam_width, - int max_input_len, - int head_num, - int size_per_head, - int inter_size, - int vocab_size, - int tensor_para_size, +void generate_gpt_gemm_config(int batch_size, + int beam_width, + int max_input_len, + int head_num, + int size_per_head, + int inter_size, + int vocab_size, + int tensor_para_size, void* buffer_in, - bool isAppend) + bool isAppend) { FT_CHECK(head_num % tensor_para_size == 0); void* cublas_workspace; void* buffer; - int workSpaceSize; + int workSpaceSize; #ifdef ENABLE_BF16 if (std::is_same::value || std::is_same::value) { #else @@ -47,13 +47,13 @@ void generate_gpt_gemm_config(int batch_size, // cublas_workspace_ should be the start pointer of cudaMalloc() // to ensure 16B alignemnet cublas_workspace = buffer_in; - buffer = (void*)((char*)cublas_workspace + CUBLAS_WORKSPACE_SIZE); - workSpaceSize = CUBLAS_WORKSPACE_SIZE; + buffer = (void*)((char*)cublas_workspace + CUBLAS_WORKSPACE_SIZE); + workSpaceSize = CUBLAS_WORKSPACE_SIZE; } else { cublas_workspace = nullptr; - buffer = buffer_in; - workSpaceSize = 0; + buffer = buffer_in; + workSpaceSize = 0; } struct cudaDeviceProp prop; @@ -62,14 +62,14 @@ void generate_gpt_gemm_config(int batch_size, // check config FILE* fd; - int line_count = 0; + int line_count = 0; if (!isAppend) { fd = fopen(GEMM_CONFIG, "w+"); } else { fd = fopen(GEMM_CONFIG, "a+"); std::vector config; - char line[1024]; + char line[1024]; while (fgets(line, 1024, fd) != NULL) { config.push_back(std::string(line)); } @@ -87,91 +87,91 @@ void generate_gpt_gemm_config(int batch_size, } } - const int hidden_units = head_num * size_per_head; - const int local_head_num = head_num / tensor_para_size; + const int hidden_units = head_num * size_per_head; + const int local_head_num = head_num / tensor_para_size; const int local_hidden_units = local_head_num * size_per_head; - const int gemm_num = 11; - int M[gemm_num]; - int N[gemm_num]; - int K[gemm_num]; - int batchCount[gemm_num]; - char mess[gemm_num][256]; - float exec_times[gemm_num]; + const int gemm_num = 11; + int M[gemm_num]; + int N[gemm_num]; + int K[gemm_num]; + int batchCount[gemm_num]; + char mess[gemm_num][256]; + float exec_times[gemm_num]; // gemm 0 - M[0] = batch_size * beam_width * max_input_len; - K[0] = hidden_units; - N[0] = 3 * local_hidden_units; + M[0] = batch_size * beam_width * max_input_len; + K[0] = hidden_units; + N[0] = 3 * local_hidden_units; batchCount[0] = 1; strcpy(mess[0], "context from_tensor * weightQKV"); // gemm 1 - M[1] = 
max_input_len; - K[1] = size_per_head; - N[1] = max_input_len; + M[1] = max_input_len; + K[1] = size_per_head; + N[1] = max_input_len; batchCount[1] = batch_size * beam_width * local_head_num; strcpy(mess[1], "context batch gemm Q*K^T"); // gemm 2 - M[2] = max_input_len; - K[2] = max_input_len; - N[2] = size_per_head; + M[2] = max_input_len; + K[2] = max_input_len; + N[2] = size_per_head; batchCount[2] = batch_size * beam_width * local_head_num; strcpy(mess[2], "context batch gemm QK*V^T"); // gemm 3 - M[3] = batch_size * beam_width * max_input_len; - K[3] = local_hidden_units; - N[3] = hidden_units; + M[3] = batch_size * beam_width * max_input_len; + K[3] = local_hidden_units; + N[3] = hidden_units; batchCount[3] = 1; strcpy(mess[3], "context attr * output_kernel"); // gemm 4 - M[4] = batch_size * beam_width * max_input_len; - K[4] = hidden_units; - N[4] = inter_size / tensor_para_size; + M[4] = batch_size * beam_width * max_input_len; + K[4] = hidden_units; + N[4] = inter_size / tensor_para_size; batchCount[4] = 1; strcpy(mess[4], "context ffn gemm 1"); // gemm 5 - M[5] = batch_size * beam_width * max_input_len; - K[5] = inter_size / tensor_para_size; - N[5] = hidden_units; + M[5] = batch_size * beam_width * max_input_len; + K[5] = inter_size / tensor_para_size; + N[5] = hidden_units; batchCount[5] = 1; strcpy(mess[5], "context ffn gemm 2"); // gemm 6 - M[6] = batch_size * beam_width; - K[6] = hidden_units; - N[6] = 3 * local_hidden_units; + M[6] = batch_size * beam_width; + K[6] = hidden_units; + N[6] = 3 * local_hidden_units; batchCount[6] = 1; strcpy(mess[6], "from_tensor * weightQKV"); // gemm 7 - M[7] = batch_size * beam_width; - K[7] = local_hidden_units; - N[7] = hidden_units; + M[7] = batch_size * beam_width; + K[7] = local_hidden_units; + N[7] = hidden_units; batchCount[7] = 1; strcpy(mess[7], "attr * output_kernel"); // gemm 8 - M[8] = batch_size * beam_width; - K[8] = hidden_units; - N[8] = inter_size / tensor_para_size; + M[8] = batch_size * beam_width; + K[8] = hidden_units; + N[8] = inter_size / tensor_para_size; batchCount[8] = 1; strcpy(mess[8], "ffn gemm 1"); // gemm 9 - M[9] = batch_size * beam_width; - K[9] = inter_size / tensor_para_size; - N[9] = hidden_units; + M[9] = batch_size * beam_width; + K[9] = inter_size / tensor_para_size; + N[9] = hidden_units; batchCount[9] = 1; strcpy(mess[9], "ffn gemm 2"); // gemm 10 - M[10] = batch_size * beam_width; - K[10] = hidden_units; - N[10] = ceil(vocab_size / 8.) * 8 / tensor_para_size; + M[10] = batch_size * beam_width; + K[10] = hidden_units; + N[10] = ceil(vocab_size / 8.) 
* 8 / tensor_para_size; batchCount[10] = 1; strcpy(mess[10], "logits gemm"); @@ -184,42 +184,42 @@ void generate_gpt_gemm_config(int batch_size, cudaDataType_t BType; cudaDataType_t CType; cudaDataType_t computeType; - int startAlgo, endAlgo; - const int ites = 100; + int startAlgo, endAlgo; + const int ites = 100; struct timeval start, end; CublasDataType data_type; if (std::is_same::value) { - data_type = FLOAT_DATATYPE; - AType = CUDA_R_32F; - BType = CUDA_R_32F; - CType = CUDA_R_32F; + data_type = FLOAT_DATATYPE; + AType = CUDA_R_32F; + BType = CUDA_R_32F; + CType = CUDA_R_32F; computeType = CUDA_R_32F; - startAlgo = (int)CUBLAS_GEMM_DEFAULT; - endAlgo = (int)CUBLAS_GEMM_ALGO23; + startAlgo = (int)CUBLAS_GEMM_DEFAULT; + endAlgo = (int)CUBLAS_GEMM_ALGO23; } else if (std::is_same::value) { - data_type = HALF_DATATYPE; - AType = CUDA_R_16F; - BType = CUDA_R_16F; - CType = CUDA_R_16F; + data_type = HALF_DATATYPE; + AType = CUDA_R_16F; + BType = CUDA_R_16F; + CType = CUDA_R_16F; computeType = CUDA_R_32F; - startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; - endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; + startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; + endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; } #ifdef ENABLE_BF16 else if (std::is_same::value) { - data_type = BFLOAT16_DATATYPE; - AType = CUDA_R_16BF; - BType = CUDA_R_16BF; - CType = CUDA_R_16BF; + data_type = BFLOAT16_DATATYPE; + AType = CUDA_R_16BF; + BType = CUDA_R_16BF; + CType = CUDA_R_16BF; computeType = CUDA_R_32F; - startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; - endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; + startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; + endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; } #endif float alpha = (float)1.0f; - float beta = (float)0.0f; + float beta = (float)0.0f; printf("***Encoder Gemm Testing Begin***\n"); printf("***Cublas Gemm Testing Begin***\n"); @@ -239,7 +239,7 @@ void generate_gpt_gemm_config(int batch_size, T* d_C = d_B + k * n * batchCount[i]; float exec_time = 99999.0f; - int fast_algo = 0; + int fast_algo = 0; for (int algo = startAlgo; algo <= endAlgo; algo++) { cublasStatus_t status; cudaDeviceSynchronize(); @@ -360,7 +360,7 @@ void generate_gpt_gemm_config(int batch_size, if (data_type != FLOAT_DATATYPE && i != 1 && i != 2 && i != 10) { printf("***cublasLt Gemm Testing Beign***\n"); // Let try a fixed number of combinations - int ALGO_COMBINATIONS = 5000; + int ALGO_COMBINATIONS = 5000; customMatmulPerf_t perfResults[ALGO_COMBINATIONS]; // for gpt, computeType & scaleType should be FP32 @@ -449,7 +449,7 @@ void generate_gpt_gemm_config(int batch_size, else { fd = fopen(SPGEMM_CONFIG, "a+"); std::vector config; - char line[1024]; + char line[1024]; while (fgets(line, 1024, fd) != NULL) { config.push_back(std::string(line)); } @@ -475,15 +475,15 @@ void generate_gpt_gemm_config(int batch_size, cusparseLtHandle_t handle; CHECK_CUSPARSE(cusparseLtInit(&handle)); - cusparseOrder_t order = CUSPARSE_ORDER_COL; - cusparseOperation_t opA = CUSPARSE_OPERATION_NON_TRANSPOSE; - cusparseOperation_t opB = CUSPARSE_OPERATION_NON_TRANSPOSE; + cusparseOrder_t order = CUSPARSE_ORDER_COL; + cusparseOperation_t opA = CUSPARSE_OPERATION_NON_TRANSPOSE; + cusparseOperation_t opB = CUSPARSE_OPERATION_NON_TRANSPOSE; // let's make this optional cusparseComputeType compute_type = CUSPARSE_COMPUTE_16F; - unsigned alignment = 16; - cudaStream_t stream = 0; - float alpha2 = 1.0f; - float beta2 = 0.0f; + unsigned alignment = 16; + cudaStream_t stream = 0; + float alpha2 = 1.0f; + float beta2 = 0.0f; for (int i = 0; 
i < gemm_num; ++i) { // skip qk or attn or logit gemms. if (i == 1 || i == 2 || i == 10) { @@ -493,7 +493,7 @@ void generate_gpt_gemm_config(int batch_size, // seq_len is always 1 except context gemms. int seq_len = i <= 5 ? max_input_len : 1; - // to be compatable with spgemm wrapper, we let A be the weight matrix + // to be compatible with spgemm wrapper, we let A be the weight matrix // so m and n are swapped // A: mxk B: kxn C:mxn int m = N[i], n = M[i], k = K[i]; @@ -521,14 +521,14 @@ void generate_gpt_gemm_config(int batch_size, } float exec_time = 99999.0f; - int fast_algo = 0; + int fast_algo = 0; if (isSparseGemmAvailable(m, n, k)) { for (int alg = 0; alg < 4; ++alg) { cudaDeviceSynchronize(); cusparseLtMatDescriptor_t matA, matB, matC; - void* d_workspace = nullptr; - int num_streams = 1; - cudaStream_t streams[1] = {stream}; + void* d_workspace = nullptr; + int num_streams = 1; + cudaStream_t streams[1] = {stream}; CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit( &handle, &matA, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT)) CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &matB, k, n, k, alignment, CUDA_R_16F, order)) @@ -539,9 +539,9 @@ void generate_gpt_gemm_config(int batch_size, // initializing MatDesc takes a lot of time // and these descs can be stored to other place // whereas storing MatMulPlan to other place will cause errors - cusparseLtMatmulDescriptor_t matmul; + cusparseLtMatmulDescriptor_t matmul; cusparseLtMatmulAlgSelection_t alg_sel; - cusparseLtMatmulPlan_t plan; + cusparseLtMatmulPlan_t plan; CHECK_CUSPARSE(cusparseLtMatmulDescriptorInit( &handle, &matmul, opA, opB, &matA, &matB, &matC, &matC, compute_type)) CHECK_CUSPARSE( @@ -603,60 +603,60 @@ void generate_gpt_gemm_config(int batch_size, return; } -template void generate_gpt_gemm_config(int batch_size, - int beam_width, - int max_input_len, - int head_num, - int size_per_head, - int inter_size, - int vocab_size, - int tensor_para_size, +template void generate_gpt_gemm_config(int batch_size, + int beam_width, + int max_input_len, + int head_num, + int size_per_head, + int inter_size, + int vocab_size, + int tensor_para_size, void* buffer_in, - bool isAppend); - -template void generate_gpt_gemm_config(int batch_size, - int beam_width, - int max_input_len, - int head_num, - int size_per_head, - int inter_size, - int vocab_size, - int tensor_para_size, + bool isAppend); + +template void generate_gpt_gemm_config(int batch_size, + int beam_width, + int max_input_len, + int head_num, + int size_per_head, + int inter_size, + int vocab_size, + int tensor_para_size, void* buffer_in, - bool isAppend); + bool isAppend); #ifdef ENABLE_BF16 -template void generate_gpt_gemm_config<__nv_bfloat16>(int batch_size, - int beam_width, - int max_input_len, - int head_num, - int size_per_head, - int inter_size, - int vocab_size, - int tensor_para_size, +template void generate_gpt_gemm_config<__nv_bfloat16>(int batch_size, + int beam_width, + int max_input_len, + int head_num, + int size_per_head, + int inter_size, + int vocab_size, + int tensor_para_size, void* buffer_in, - bool isAppend); + bool isAppend); #endif -size_t calGptGemmTestBufSizeInByte(int batch_size, - int beam_width, - int max_input_len, - int head_num, - int size_per_head, - int inter_size, - int vocab_size, - int tensor_para_size, +size_t calGptGemmTestBufSizeInByte(int batch_size, + int beam_width, + int max_input_len, + int head_num, + int size_per_head, + int inter_size, + int vocab_size, + int tensor_para_size, CublasDataType 
data_type) { - size_t buf_size_in_byte = 0; - const size_t hidden_units = head_num * size_per_head; - const size_t local_head_num = head_num / tensor_para_size; + size_t buf_size_in_byte = 0; + const size_t hidden_units = head_num * size_per_head; + const size_t local_head_num = head_num / tensor_para_size; const size_t local_hidden_units = local_head_num * size_per_head; // TODO add bfloat16 int wordSize = (data_type == FLOAT_DATATYPE ? sizeof(float) : sizeof(half)); - size_t m = batch_size * beam_width * max_input_len; + size_t m = batch_size * beam_width * max_input_len; std::vector buff_size; // for context qkv gemm buff_size.push_back(m * hidden_units + hidden_units * 3 * local_hidden_units + m * 3 * local_hidden_units); diff --git a/src/fastertransformer/utils/gemm_test/gpt_gemm_func.h b/src/fastertransformer/utils/gemm_test/gpt_gemm_func.h index 08996e1c8..ecfa4f3c3 100644 --- a/src/fastertransformer/utils/gemm_test/gpt_gemm_func.h +++ b/src/fastertransformer/utils/gemm_test/gpt_gemm_func.h @@ -34,25 +34,25 @@ namespace fastertransformer { template -void generate_gpt_gemm_config(int batch_size, - int beam_width, - int seq_len, - int head_num, - int size_per_head, - int inter_size, - int vocab_size, - int tensor_para_size, +void generate_gpt_gemm_config(int batch_size, + int beam_width, + int seq_len, + int head_num, + int size_per_head, + int inter_size, + int vocab_size, + int tensor_para_size, void* buffer_in, - bool isAppend); + bool isAppend); -size_t calGptGemmTestBufSizeInByte(int batch_size, - int beam_width, - int max_input_len, - int head_num, - int size_per_head, - int inter_size, - int vocab_size, - int tensor_para_size, +size_t calGptGemmTestBufSizeInByte(int batch_size, + int beam_width, + int max_input_len, + int head_num, + int size_per_head, + int inter_size, + int vocab_size, + int tensor_para_size, CublasDataType data_type); } // namespace fastertransformer diff --git a/src/fastertransformer/utils/gemm_test/swin_gemm_func.cc b/src/fastertransformer/utils/gemm_test/swin_gemm_func.cc index 227902448..486955177 100644 --- a/src/fastertransformer/utils/gemm_test/swin_gemm_func.cc +++ b/src/fastertransformer/utils/gemm_test/swin_gemm_func.cc @@ -24,7 +24,7 @@ void generate_swin_gemm_config( { void* cublas_workspace; void* buffer; - int workSpaceSize; + int workSpaceSize; #ifdef ENABLE_BF16 if (std::is_same::value || std::is_same::value) { #else @@ -33,13 +33,13 @@ void generate_swin_gemm_config( // cublas_workspace_ should be the start pointer of cudaMalloc() // to ensure 16B alignemnet cublas_workspace = buffer_in; - buffer = (void*)((char*)cublas_workspace + CUBLAS_WORKSPACE_SIZE); - workSpaceSize = CUBLAS_WORKSPACE_SIZE; + buffer = (void*)((char*)cublas_workspace + CUBLAS_WORKSPACE_SIZE); + workSpaceSize = CUBLAS_WORKSPACE_SIZE; } else { cublas_workspace = nullptr; - buffer = buffer_in; - workSpaceSize = 0; + buffer = buffer_in; + workSpaceSize = 0; } struct cudaDeviceProp prop; @@ -48,7 +48,7 @@ void generate_swin_gemm_config( // check config FILE* fd; - int line_count = 0; + int line_count = 0; if (!isAppend) { fd = fopen(GEMM_CONFIG, "w+"); fprintf( @@ -58,7 +58,7 @@ void generate_swin_gemm_config( else { fd = fopen(GEMM_CONFIG, "a+"); std::vector config; - char line[1024]; + char line[1024]; while (fgets(line, 1024, fd) != NULL) { config.push_back(std::string(line)); } @@ -76,14 +76,14 @@ void generate_swin_gemm_config( } } - const int gemm_num = 7; + const int gemm_num = 7; const int NUM_OF_BASIC_LAYERS = 4; - int M[gemm_num]; - int N[gemm_num]; - int 
K[gemm_num]; - int batchCount[gemm_num] = {1, 1, 1, 1, 1, 1, 1}; - char mess[gemm_num][256]; - float exec_times[gemm_num]; + int M[gemm_num]; + int N[gemm_num]; + int K[gemm_num]; + int batchCount[gemm_num] = {1, 1, 1, 1, 1, 1, 1}; + char mess[gemm_num][256]; + float exec_times[gemm_num]; printf("***Encoder Gemm Testing Begin***\n"); printf("***Cublas Gemm Testing Begin***\n"); @@ -117,15 +117,15 @@ void generate_swin_gemm_config( N[4] = 2 * K[0]; strcpy(mess[4], "patchMerge gemm"); - M[5] = seq_len; - N[5] = seq_len; - K[5] = size_per_head; + M[5] = seq_len; + N[5] = seq_len; + K[5] = size_per_head; batchCount[5] = batch_size * head_num; strcpy(mess[5], "attention batched Gemm1"); - M[6] = seq_len; - N[6] = size_per_head; - K[6] = seq_len; + M[6] = seq_len; + N[6] = size_per_head; + K[6] = seq_len; batchCount[6] = batch_size * head_num; strcpy(mess[6], "attention batched Gemm2"); @@ -138,44 +138,44 @@ void generate_swin_gemm_config( cudaDataType_t BType; cudaDataType_t CType; cudaDataType_t computeType; - int startAlgo, endAlgo; - const int ites = 100; + int startAlgo, endAlgo; + const int ites = 100; struct timeval start, end; CublasDataType data_type; if (std::is_same::value) { - data_type = FLOAT_DATATYPE; - AType = CUDA_R_32F; - BType = CUDA_R_32F; - CType = CUDA_R_32F; + data_type = FLOAT_DATATYPE; + AType = CUDA_R_32F; + BType = CUDA_R_32F; + CType = CUDA_R_32F; computeType = CUDA_R_32F; - startAlgo = (int)CUBLAS_GEMM_DEFAULT; - endAlgo = (int)CUBLAS_GEMM_ALGO23; + startAlgo = (int)CUBLAS_GEMM_DEFAULT; + endAlgo = (int)CUBLAS_GEMM_ALGO23; } else if (std::is_same::value) { - data_type = HALF_DATATYPE; - AType = CUDA_R_16F; - BType = CUDA_R_16F; - CType = CUDA_R_16F; + data_type = HALF_DATATYPE; + AType = CUDA_R_16F; + BType = CUDA_R_16F; + CType = CUDA_R_16F; computeType = CUDA_R_32F; - startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; - endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; + startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; + endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; } #ifdef ENABLE_BF16 else if (std::is_same::value) { - data_type = BFLOAT16_DATATYPE; - AType = CUDA_R_16BF; - BType = CUDA_R_16BF; - CType = CUDA_R_16BF; + data_type = BFLOAT16_DATATYPE; + AType = CUDA_R_16BF; + BType = CUDA_R_16BF; + CType = CUDA_R_16BF; computeType = CUDA_R_32F; - startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; - endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; + startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; + endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; } #endif using scaleT = typename ScaleTypeConverter::Type; scaleT alpha = (scaleT)1.0f; - scaleT beta = (scaleT)0.0f; + scaleT beta = (scaleT)0.0f; for (int i = 0; i < gemm_num; ++i) { // if(i != 0 && i != 5) continue; @@ -189,14 +189,14 @@ void generate_swin_gemm_config( // array of pointer for batchedGemm T* harray[12]; - harray[0] = (T*)buffer; - harray[1] = (T*)((char*)buffer + sizeof(T) * m * k); - harray[2] = (T*)((char*)buffer + 2 * sizeof(T) * m * k); - harray[4] = (T*)((char*)buffer + 3 * sizeof(T) * m * k); - harray[5] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + sizeof(T) * k * n); - harray[6] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + 2 * sizeof(T) * k * n); - harray[8] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + 3 * sizeof(T) * k * n); - harray[9] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + 3 * sizeof(T) * k * n + sizeof(T) * m * n); + harray[0] = (T*)buffer; + harray[1] = (T*)((char*)buffer + sizeof(T) * m * k); + harray[2] = (T*)((char*)buffer + 2 * sizeof(T) * m * k); + harray[4] = (T*)((char*)buffer + 3 * sizeof(T) * 
m * k); + harray[5] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + sizeof(T) * k * n); + harray[6] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + 2 * sizeof(T) * k * n); + harray[8] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + 3 * sizeof(T) * k * n); + harray[9] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + 3 * sizeof(T) * k * n + sizeof(T) * m * n); harray[10] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + 3 * sizeof(T) * k * n + 2 * sizeof(T) * m * n); T** darray = 0; @@ -207,7 +207,7 @@ void generate_swin_gemm_config( T** dCarray = darray + 8; float exec_time = 99999.0f; - int fast_algo = 0; + int fast_algo = 0; for (int algo = startAlgo; algo <= endAlgo; algo++) { cublasStatus_t status; cudaDeviceSynchronize(); @@ -304,7 +304,7 @@ void generate_swin_gemm_config( if (i < 5 && data_type != FLOAT_DATATYPE) { printf("***cublasLt Gemm Testing Beign***\n"); // Let try a fixed number of combinations - int ALGO_COMBINATIONS = 5000; + int ALGO_COMBINATIONS = 5000; customMatmulPerf_t perfResults[ALGO_COMBINATIONS]; LtHgemmCustomFind(ltHandle, @@ -368,7 +368,7 @@ void generate_swin_gemm_config( if (basic_layer != NUM_OF_BASIC_LAYERS - 1) { batch_size = batch_size / 4; - head_num = head_num * 2; + head_num = head_num * 2; } } printf("***cublas Gemm Testing End***\n\n"); diff --git a/src/fastertransformer/utils/gemm_test/swin_igemm_func.cc b/src/fastertransformer/utils/gemm_test/swin_igemm_func.cc index 182c2b9a2..9f962dd53 100644 --- a/src/fastertransformer/utils/gemm_test/swin_igemm_func.cc +++ b/src/fastertransformer/utils/gemm_test/swin_igemm_func.cc @@ -60,28 +60,28 @@ static inline bool time_compare(const customMatmulPerf_t& perf_a, const customMa return ((perf_a.status == CUBLAS_STATUS_SUCCESS) && (perf_a.time < perf_b.time)); } -static cublasStatus_t customMatmulRun(cublasLtHandle_t ltHandle, // to get the capabilities (required a GPU) - cublasLtMatmulDesc_t operationDesc, - const void* alpha, /* host or device pointer */ - const void* A, - cublasLtMatrixLayout_t Adesc, - const void* B, - cublasLtMatrixLayout_t Bdesc, - const void* beta, /* host or device pointer */ - const void* C, - cublasLtMatrixLayout_t Cdesc, - void* D, - cublasLtMatrixLayout_t Ddesc, +static cublasStatus_t customMatmulRun(cublasLtHandle_t ltHandle, // to get the capabilities (required a GPU) + cublasLtMatmulDesc_t operationDesc, + const void* alpha, /* host or device pointer */ + const void* A, + cublasLtMatrixLayout_t Adesc, + const void* B, + cublasLtMatrixLayout_t Bdesc, + const void* beta, /* host or device pointer */ + const void* C, + cublasLtMatrixLayout_t Cdesc, + void* D, + cublasLtMatrixLayout_t Ddesc, const cublasLtMatmulAlgo_t& algo, - int kernelRepeats, - void* workSpace, - size_t workSpaceSizeInBytes, - customMatmulPerf_t& perfResults, - cudaStream_t stream) + int kernelRepeats, + void* workSpace, + size_t workSpaceSizeInBytes, + customMatmulPerf_t& perfResults, + cudaStream_t stream) { cublasLtMatmulHeuristicResult_t heurResult; /* Looping over the Algo */ - int repeats = kernelRepeats; + int repeats = kernelRepeats; cublasStatus_t algoStatus = cublasLtMatmulAlgoCheck(ltHandle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, &algo, &heurResult); if (algoStatus == CUBLAS_STATUS_SUCCESS) { @@ -116,10 +116,10 @@ static cublasStatus_t customMatmulRun(cublasLtHandle_t ltHandle, // to get the float time = diffTime(start, end); // For the moment only add successful findings if (algoStatus == CUBLAS_STATUS_SUCCESS) { - perfResults.algo = algo; - perfResults.time = time / repeats; + perfResults.algo = algo; + 
perfResults.time = time / repeats; perfResults.workspaceSize = heurResult.workspaceSize; - perfResults.wavesCount = heurResult.wavesCount; + perfResults.wavesCount = heurResult.wavesCount; } } else { @@ -137,7 +137,7 @@ int igemm_config_INT8IO(int m, int n, int k, FILE* fout, void* buffer) { printf("batchCount %d m %d n %d k %d\n", 1, m, n, k); float alpha = 1.0f; - float beta = 0.0f; + float beta = 0.0f; int8_t* d_A = (int8_t*)buffer; // m * k, stored in column-major int8_t* d_B = d_A + m * k; // k * n, stored in column-major @@ -187,7 +187,7 @@ int generate_swin_igemm_config( else { fout = fopen(IGEMM_CONFIG, "a+"); std::vector config; - char line[1024]; + char line[1024]; while (fgets(line, 1024, fout) != NULL) { config.push_back(std::string(line)); } @@ -201,10 +201,10 @@ int generate_swin_igemm_config( } } - int m = batch_size * seq_len; - int n = head_num * size_per_head; - int k = n; - int batchCount; + int m = batch_size * seq_len; + int n = head_num * size_per_head; + int k = n; + int batchCount; const int NUM_OF_BASIC_LAYERS = 4; printf("***Swin IGemm Testing Begin***\n"); @@ -212,9 +212,9 @@ int generate_swin_igemm_config( for (int basic_layer = 0; basic_layer < NUM_OF_BASIC_LAYERS; basic_layer++) { printf("\n-----------------------------\n"); batchCount = 1; - m = batch_size * seq_len; - k = head_num * size_per_head; - n = 3 * head_num * size_per_head; + m = batch_size * seq_len; + k = head_num * size_per_head; + n = 3 * head_num * size_per_head; if (n % 32 != 0 || k % 32 != 0) { printf("[WARNING] For INT8 gemm test, n, k should be multiples of 32 (n = %d, k = %d)\n", n, k); } @@ -258,10 +258,10 @@ int generate_swin_igemm_config( if (basic_layer != NUM_OF_BASIC_LAYERS - 1) { printf("\n-----------------------------\n"); batch_size = batch_size / 4; - head_num = head_num * 2; - m = batch_size * seq_len; - n = head_num * size_per_head; - k = 2 * head_num * size_per_head; + head_num = head_num * 2; + m = batch_size * seq_len; + n = head_num * size_per_head; + k = 2 * head_num * size_per_head; if (n % 32 != 0 || k % 32 != 0) { printf("[WARNING] For INT8 gemm test, n, k should be multiples of 32 (n = %d, k = %d)\n", n, k); } diff --git a/src/fastertransformer/utils/gemm_test/t5_gemm_func.cc b/src/fastertransformer/utils/gemm_test/t5_gemm_func.cc index ea8856dc7..e0c189eb0 100644 --- a/src/fastertransformer/utils/gemm_test/t5_gemm_func.cc +++ b/src/fastertransformer/utils/gemm_test/t5_gemm_func.cc @@ -24,29 +24,29 @@ bool isSparseGemmAvailable(size_t m, size_t n, size_t k) } template -void generate_t5_gemm_config(int batch_size, - int beam_width, - int max_mem_seq_len, - int encoder_d_model, - int encoder_head_num, - int encoder_size_per_head, - int encoder_inter_size, - int decoder_d_model, - int decoder_head_num, - int decoder_size_per_head, - int decoder_inter_size, - int decoder_vocab_size, - int tensor_para_size, +void generate_t5_gemm_config(int batch_size, + int beam_width, + int max_mem_seq_len, + int encoder_d_model, + int encoder_head_num, + int encoder_size_per_head, + int encoder_inter_size, + int decoder_d_model, + int decoder_head_num, + int decoder_size_per_head, + int decoder_inter_size, + int decoder_vocab_size, + int tensor_para_size, void* buffer_in, - bool isAppend, - bool is_fp16_compute_type) + bool isAppend, + bool is_fp16_compute_type) { FT_CHECK(encoder_head_num % tensor_para_size == 0); FT_CHECK(decoder_head_num % tensor_para_size == 0); void* cublas_workspace; void* buffer; - int workSpaceSize; + int workSpaceSize; #ifdef ENABLE_BF16 if (std::is_same::value 
|| std::is_same::value) { #else @@ -55,13 +55,13 @@ void generate_t5_gemm_config(int batch_size, // cublas_workspace_ should be the start pointer of cudaMalloc() // to ensure 16B alignemnet cublas_workspace = buffer_in; - buffer = (void*)((char*)cublas_workspace + CUBLAS_WORKSPACE_SIZE); - workSpaceSize = CUBLAS_WORKSPACE_SIZE; + buffer = (void*)((char*)cublas_workspace + CUBLAS_WORKSPACE_SIZE); + workSpaceSize = CUBLAS_WORKSPACE_SIZE; } else { cublas_workspace = nullptr; - buffer = buffer_in; - workSpaceSize = 0; + buffer = buffer_in; + workSpaceSize = 0; } struct cudaDeviceProp prop; @@ -70,14 +70,14 @@ void generate_t5_gemm_config(int batch_size, // check config FILE* fd; - int line_count = 0; + int line_count = 0; if (!isAppend) { fd = fopen(GEMM_CONFIG, "w+"); } else { fd = fopen(GEMM_CONFIG, "a+"); std::vector config; - char line[1024]; + char line[1024]; while (fgets(line, 1024, fd) != NULL) { config.push_back(std::string(line)); } @@ -96,80 +96,80 @@ void generate_t5_gemm_config(int batch_size, } const int gemm_num = 12; - int M[gemm_num]; - int N[gemm_num]; - int K[gemm_num]; - int batchCount[gemm_num]; - char mess[gemm_num][256]; - float exec_times[gemm_num]; + int M[gemm_num]; + int N[gemm_num]; + int K[gemm_num]; + int batchCount[gemm_num]; + char mess[gemm_num][256]; + float exec_times[gemm_num]; // gemm 0 - M[0] = batch_size * max_mem_seq_len; - K[0] = encoder_d_model; - N[0] = encoder_head_num / tensor_para_size * encoder_size_per_head; + M[0] = batch_size * max_mem_seq_len; + K[0] = encoder_d_model; + N[0] = encoder_head_num / tensor_para_size * encoder_size_per_head; batchCount[0] = 3; strcpy(mess[0], "encoder from_tensor * batched gemm weightQKV"); // gemm 1 - M[1] = max_mem_seq_len; - K[1] = encoder_size_per_head; - N[1] = max_mem_seq_len; + M[1] = max_mem_seq_len; + K[1] = encoder_size_per_head; + N[1] = max_mem_seq_len; batchCount[1] = batch_size * encoder_head_num / tensor_para_size; strcpy(mess[1], "encoder batch strided gemm Q*K^T"); // gemm 2 - M[2] = max_mem_seq_len; - K[2] = max_mem_seq_len; - N[2] = encoder_size_per_head; + M[2] = max_mem_seq_len; + K[2] = max_mem_seq_len; + N[2] = encoder_size_per_head; batchCount[2] = batch_size * encoder_head_num / tensor_para_size; strcpy(mess[2], "encoder batch strided gemm QK*V^T"); // gemm 3 - M[3] = batch_size * max_mem_seq_len; - K[3] = encoder_head_num / tensor_para_size * encoder_size_per_head; - N[3] = encoder_d_model; + M[3] = batch_size * max_mem_seq_len; + K[3] = encoder_head_num / tensor_para_size * encoder_size_per_head; + N[3] = encoder_d_model; batchCount[3] = 1; strcpy(mess[3], "encoder attr * output_kernel"); // gemm 4 - M[4] = batch_size * max_mem_seq_len; - K[4] = encoder_d_model; - N[4] = encoder_inter_size / tensor_para_size; + M[4] = batch_size * max_mem_seq_len; + K[4] = encoder_d_model; + N[4] = encoder_inter_size / tensor_para_size; batchCount[4] = 1; strcpy(mess[4], "encoder ffn gemm 1"); // gemm 5 - M[5] = batch_size * max_mem_seq_len; - K[5] = encoder_inter_size / tensor_para_size; - N[5] = encoder_d_model; + M[5] = batch_size * max_mem_seq_len; + K[5] = encoder_inter_size / tensor_para_size; + N[5] = encoder_d_model; batchCount[5] = 1; strcpy(mess[5], "encoder ffn gemm 2"); // gemm 6 - M[6] = batch_size * beam_width; - K[6] = decoder_d_model; - N[6] = 3 * decoder_head_num / tensor_para_size * decoder_size_per_head; + M[6] = batch_size * beam_width; + K[6] = decoder_d_model; + N[6] = 3 * decoder_head_num / tensor_para_size * decoder_size_per_head; batchCount[6] = 1; strcpy(mess[6], "from_tensor * 
weightQKV"); // gemm 7 - M[7] = batch_size * beam_width; - K[7] = decoder_head_num / tensor_para_size * decoder_size_per_head; - N[7] = decoder_d_model; + M[7] = batch_size * beam_width; + K[7] = decoder_head_num / tensor_para_size * decoder_size_per_head; + N[7] = decoder_d_model; batchCount[7] = 1; strcpy(mess[7], "attr * output_kernel"); // gemm 8 - M[8] = batch_size * beam_width; - K[8] = decoder_d_model; - N[8] = decoder_inter_size / tensor_para_size; + M[8] = batch_size * beam_width; + K[8] = decoder_d_model; + N[8] = decoder_inter_size / tensor_para_size; batchCount[8] = 1; strcpy(mess[8], "ffn gemm 1"); // gemm 9 - M[9] = batch_size * beam_width; - K[9] = decoder_inter_size / tensor_para_size; - N[9] = decoder_d_model; + M[9] = batch_size * beam_width; + K[9] = decoder_inter_size / tensor_para_size; + N[9] = decoder_d_model; batchCount[9] = 1; strcpy(mess[9], "ffn gemm 2"); @@ -178,16 +178,16 @@ void generate_t5_gemm_config(int batch_size, if (!std::is_same::value) { decoder_vocab_size_padded = ((size_t)ceil(decoder_vocab_size_padded / 8.) * 8); } - M[10] = batch_size * beam_width; - K[10] = decoder_d_model; - N[10] = decoder_vocab_size_padded / tensor_para_size; + M[10] = batch_size * beam_width; + K[10] = decoder_d_model; + N[10] = decoder_vocab_size_padded / tensor_para_size; batchCount[10] = 1; strcpy(mess[10], "logits gemm"); // gemm 11 - M[11] = batch_size * max_mem_seq_len; - K[11] = encoder_d_model; - N[11] = encoder_head_num / tensor_para_size * encoder_size_per_head; + M[11] = batch_size * max_mem_seq_len; + K[11] = encoder_d_model; + N[11] = encoder_head_num / tensor_para_size * encoder_size_per_head; batchCount[11] = 1; strcpy(mess[11], "encoder from_tensor * splited qkv weight"); @@ -200,48 +200,48 @@ void generate_t5_gemm_config(int batch_size, cudaDataType_t BType; cudaDataType_t CType; cudaDataType_t computeType; - int startAlgo, endAlgo; - const int ites = 100; + int startAlgo, endAlgo; + const int ites = 100; struct timeval start, end; CublasDataType data_type; if (std::is_same::value) { - data_type = FLOAT_DATATYPE; - AType = CUDA_R_32F; - BType = CUDA_R_32F; - CType = CUDA_R_32F; + data_type = FLOAT_DATATYPE; + AType = CUDA_R_32F; + BType = CUDA_R_32F; + CType = CUDA_R_32F; computeType = CUDA_R_32F; - startAlgo = (int)CUBLAS_GEMM_DEFAULT; - endAlgo = (int)CUBLAS_GEMM_ALGO23; + startAlgo = (int)CUBLAS_GEMM_DEFAULT; + endAlgo = (int)CUBLAS_GEMM_ALGO23; } else if (std::is_same::value) { - data_type = HALF_DATATYPE; - AType = CUDA_R_16F; - BType = CUDA_R_16F; - CType = CUDA_R_16F; + data_type = HALF_DATATYPE; + AType = CUDA_R_16F; + BType = CUDA_R_16F; + CType = CUDA_R_16F; computeType = CUDA_R_32F; - startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; - endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; + startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; + endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; } #ifdef ENABLE_BF16 else if (std::is_same::value) { - data_type = BFLOAT16_DATATYPE; - AType = CUDA_R_16BF; - BType = CUDA_R_16BF; - CType = CUDA_R_16BF; + data_type = BFLOAT16_DATATYPE; + AType = CUDA_R_16BF; + BType = CUDA_R_16BF; + CType = CUDA_R_16BF; computeType = CUDA_R_32F; - startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; - endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; + startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; + endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; } #endif float f_alpha = (float)1.0f; - float f_beta = (float)0.0f; + float f_beta = (float)0.0f; half h_alpha = (half)(1.0f); - half h_beta = (half)(0.0f); + half h_beta = (half)(0.0f); void* alpha = computeType 
== CUDA_R_16F ? (void*)(&h_alpha) : (void*)(&f_alpha); - void* beta = computeType == CUDA_R_16F ? (void*)(&h_beta) : (void*)(&f_beta); + void* beta = computeType == CUDA_R_16F ? (void*)(&h_beta) : (void*)(&f_beta); printf("***Encoder Gemm Testing Begin***\n"); printf("***Cublas Gemm Testing Begin***\n"); @@ -251,8 +251,8 @@ void generate_t5_gemm_config(int batch_size, "customOption, tile, numSplitsK, swizzle, reductionScheme, workspaceSize, stages, exec_time\n"); } for (int i = 0; i < gemm_num; ++i) { - int seq_len = (i <= 5 || i == 11) ? max_mem_seq_len : 1; - int head_num = ((i <= 5 || i == 11) ? encoder_head_num : decoder_head_num) / tensor_para_size; + int seq_len = (i <= 5 || i == 11) ? max_mem_seq_len : 1; + int head_num = ((i <= 5 || i == 11) ? encoder_head_num : decoder_head_num) / tensor_para_size; int size_per_head = (i <= 5 || i == 11) ? encoder_size_per_head : decoder_size_per_head; int m = M[i], n = N[i], k = K[i]; @@ -264,14 +264,14 @@ void generate_t5_gemm_config(int batch_size, // array of pointer for batchedGemm T* harray[12]; - harray[0] = (T*)buffer; - harray[1] = (T*)((char*)buffer + sizeof(T) * m * k); - harray[2] = (T*)((char*)buffer + 2 * sizeof(T) * m * k); - harray[4] = (T*)((char*)buffer + 3 * sizeof(T) * m * k); - harray[5] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + sizeof(T) * k * n); - harray[6] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + 2 * sizeof(T) * k * n); - harray[8] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + 3 * sizeof(T) * k * n); - harray[9] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + 3 * sizeof(T) * k * n + sizeof(T) * m * n); + harray[0] = (T*)buffer; + harray[1] = (T*)((char*)buffer + sizeof(T) * m * k); + harray[2] = (T*)((char*)buffer + 2 * sizeof(T) * m * k); + harray[4] = (T*)((char*)buffer + 3 * sizeof(T) * m * k); + harray[5] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + sizeof(T) * k * n); + harray[6] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + 2 * sizeof(T) * k * n); + harray[8] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + 3 * sizeof(T) * k * n); + harray[9] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + 3 * sizeof(T) * k * n + sizeof(T) * m * n); harray[10] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + 3 * sizeof(T) * k * n + 2 * sizeof(T) * m * n); T** darray = 0; @@ -282,7 +282,7 @@ void generate_t5_gemm_config(int batch_size, T** dCarray = darray + 8; float exec_time = 99999.0f; - int fast_algo = 0; + int fast_algo = 0; for (int algo = startAlgo; algo <= endAlgo; algo++) { cublasStatus_t status; cudaDeviceSynchronize(); @@ -431,14 +431,14 @@ void generate_t5_gemm_config(int batch_size, if (data_type != FLOAT_DATATYPE && i != 1 && i != 2 && i != 0 && i != 10) { printf("***cublasLt Gemm Testing Beign***\n"); // Let try a fixed number of combinations - int ALGO_COMBINATIONS = 5000; + int ALGO_COMBINATIONS = 5000; customMatmulPerf_t perfResults[ALGO_COMBINATIONS]; // for t5, computeType & scaleType should be FP32 if (is_fp16_compute_type) { - using scaleT = typename ScaleTypeConverter::Type; + using scaleT = typename ScaleTypeConverter::Type; scaleT alpha_scale = (scaleT)1.0f; - scaleT beta_scale = (scaleT)0.0f; + scaleT beta_scale = (scaleT)0.0f; LtHgemmCustomFind(ltHandle, m, @@ -547,7 +547,7 @@ void generate_t5_gemm_config(int batch_size, else { fd = fopen(SPGEMM_CONFIG, "a+"); std::vector config; - char line[1024]; + char line[1024]; while (fgets(line, 1024, fd) != NULL) { config.push_back(std::string(line)); } @@ -573,15 +573,15 @@ void generate_t5_gemm_config(int batch_size, cusparseLtHandle_t handle; 
CHECK_CUSPARSE(cusparseLtInit(&handle)); - cusparseOrder_t order = CUSPARSE_ORDER_COL; - cusparseOperation_t opA = CUSPARSE_OPERATION_NON_TRANSPOSE; - cusparseOperation_t opB = CUSPARSE_OPERATION_NON_TRANSPOSE; + cusparseOrder_t order = CUSPARSE_ORDER_COL; + cusparseOperation_t opA = CUSPARSE_OPERATION_NON_TRANSPOSE; + cusparseOperation_t opB = CUSPARSE_OPERATION_NON_TRANSPOSE; // let's make this optional cusparseComputeType compute_type = CUSPARSE_COMPUTE_16F; - unsigned alignment = 16; - cudaStream_t stream = 0; - float alpha2 = 1.0f; - float beta2 = 0.0f; + unsigned alignment = 16; + cudaStream_t stream = 0; + float alpha2 = 1.0f; + float beta2 = 0.0f; for (int i = 0; i < gemm_num; ++i) { // skip qk or attn or logit gemms. if (i == 1 || i == 2 || i == 10) { @@ -589,11 +589,11 @@ void generate_t5_gemm_config(int batch_size, } // seq_len is always 1 except context gemms. - int seq_len = i <= 5 ? max_mem_seq_len : 1; - int head_num = (i <= 5 ? encoder_head_num : decoder_head_num) / tensor_para_size; + int seq_len = i <= 5 ? max_mem_seq_len : 1; + int head_num = (i <= 5 ? encoder_head_num : decoder_head_num) / tensor_para_size; int size_per_head = i <= 5 ? encoder_size_per_head : decoder_size_per_head; - // to be compatable with spgemm wrapper, we let A be the weight matrix + // to be compatible with spgemm wrapper, we let A be the weight matrix // so m and n are swapped // A: mxk B: kxn C:mxn int m = N[i], n = M[i], k = K[i]; @@ -616,14 +616,14 @@ void generate_t5_gemm_config(int batch_size, } float exec_time = 99999.0f; - int fast_algo = 0; + int fast_algo = 0; if (isSparseGemmAvailable(m, n, k)) { for (int alg = 0; alg < 4; ++alg) { cudaDeviceSynchronize(); cusparseLtMatDescriptor_t matA, matB, matC; - void* d_workspace = nullptr; - int num_streams = 1; - cudaStream_t streams[1] = {stream}; + void* d_workspace = nullptr; + int num_streams = 1; + cudaStream_t streams[1] = {stream}; CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit( &handle, &matA, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT)) CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &matB, k, n, k, alignment, CUDA_R_16F, order)) @@ -634,9 +634,9 @@ void generate_t5_gemm_config(int batch_size, // initializing MatDesc takes a lot of time // and these descs can be stored to other place // whereas storing MatMulPlan to other place will cause errors - cusparseLtMatmulDescriptor_t matmul; + cusparseLtMatmulDescriptor_t matmul; cusparseLtMatmulAlgSelection_t alg_sel; - cusparseLtMatmulPlan_t plan; + cusparseLtMatmulPlan_t plan; CHECK_CUSPARSE(cusparseLtMatmulDescriptorInit( &handle, &matmul, opA, opB, &matA, &matB, &matC, &matC, compute_type)) CHECK_CUSPARSE( @@ -698,82 +698,82 @@ void generate_t5_gemm_config(int batch_size, return; } -template void generate_t5_gemm_config(int batch_size, - int beam_width, - int max_mem_seq_len, - int encoder_d_model, - int encoder_head_num, - int encoder_size_per_head, - int encoder_inter_size, - int decoder_d_model, - int decoder_head_num, - int decoder_size_per_head, - int decoder_inter_size, - int decoder_vocab_size, - int tensor_para_size, +template void generate_t5_gemm_config(int batch_size, + int beam_width, + int max_mem_seq_len, + int encoder_d_model, + int encoder_head_num, + int encoder_size_per_head, + int encoder_inter_size, + int decoder_d_model, + int decoder_head_num, + int decoder_size_per_head, + int decoder_inter_size, + int decoder_vocab_size, + int tensor_para_size, void* buffer_in, - bool isAppend, - bool is_fp16_compute_type); - -template void 
generate_t5_gemm_config(int batch_size, - int beam_width, - int max_mem_seq_len, - int encoder_d_model, - int encoder_head_num, - int encoder_size_per_head, - int encoder_inter_size, - int decoder_d_model, - int decoder_head_num, - int decoder_size_per_head, - int decoder_inter_size, - int decoder_vocab_size, - int tensor_para_size, + bool isAppend, + bool is_fp16_compute_type); + +template void generate_t5_gemm_config(int batch_size, + int beam_width, + int max_mem_seq_len, + int encoder_d_model, + int encoder_head_num, + int encoder_size_per_head, + int encoder_inter_size, + int decoder_d_model, + int decoder_head_num, + int decoder_size_per_head, + int decoder_inter_size, + int decoder_vocab_size, + int tensor_para_size, void* buffer_in, - bool isAppend, - bool is_fp16_compute_type); + bool isAppend, + bool is_fp16_compute_type); #ifdef ENABLE_BF16 -template void generate_t5_gemm_config<__nv_bfloat16>(int batch_size, - int beam_width, - int max_mem_seq_len, - int encoder_d_model, - int encoder_head_num, - int encoder_size_per_head, - int encoder_inter_size, - int decoder_d_model, - int decoder_head_num, - int decoder_size_per_head, - int decoder_inter_size, - int decoder_vocab_size, - int tensor_para_size, +template void generate_t5_gemm_config<__nv_bfloat16>(int batch_size, + int beam_width, + int max_mem_seq_len, + int encoder_d_model, + int encoder_head_num, + int encoder_size_per_head, + int encoder_inter_size, + int decoder_d_model, + int decoder_head_num, + int decoder_size_per_head, + int decoder_inter_size, + int decoder_vocab_size, + int tensor_para_size, void* buffer_in, - bool isAppend, - bool is_fp16_compute_type); + bool isAppend, + bool is_fp16_compute_type); #endif -size_t calT5GemmTestBufSizeInByte(int batch_size, - int beam_width, - int max_mem_seq_len, - int encoder_d_model, - int encoder_head_num, - int encoder_size_per_head, - int encoder_inter_size, - int decoder_d_model, - int decoder_head_num, - int decoder_size_per_head, - int decoder_inter_size, - int decoder_vocab_size, - int tensor_para_size, +size_t calT5GemmTestBufSizeInByte(int batch_size, + int beam_width, + int max_mem_seq_len, + int encoder_d_model, + int encoder_head_num, + int encoder_size_per_head, + int encoder_inter_size, + int decoder_d_model, + int decoder_head_num, + int decoder_size_per_head, + int decoder_inter_size, + int decoder_vocab_size, + int tensor_para_size, CublasDataType data_type) { - const size_t local_encoder_head_num = encoder_head_num / tensor_para_size; + const size_t local_encoder_head_num = encoder_head_num / tensor_para_size; const size_t local_encoder_hidden_units = local_encoder_head_num * encoder_size_per_head; - const size_t local_encoder_inter_size = encoder_inter_size / tensor_para_size; - const size_t local_decoder_head_num = decoder_head_num / tensor_para_size; + const size_t local_encoder_inter_size = encoder_inter_size / tensor_para_size; + const size_t local_decoder_head_num = decoder_head_num / tensor_para_size; const size_t local_decoder_hidden_units = local_decoder_head_num * decoder_size_per_head; - const size_t local_decoder_inter_size = decoder_inter_size / tensor_para_size; + const size_t local_decoder_inter_size = decoder_inter_size / tensor_para_size; - size_t m = batch_size * max_mem_seq_len; + size_t m = batch_size * max_mem_seq_len; std::vector buff_size; // encoder qkv gemm @@ -805,7 +805,7 @@ size_t calT5GemmTestBufSizeInByte(int batch_size, + m * decoder_vocab_size_padded / tensor_para_size); size_t buf_size_in_byte = 0; - int wordSize = (data_type 
== FLOAT_DATATYPE ? sizeof(float) : sizeof(half)); + int wordSize = (data_type == FLOAT_DATATYPE ? sizeof(float) : sizeof(half)); for (auto t : buff_size) { buf_size_in_byte = buf_size_in_byte > t ? buf_size_in_byte : t; } diff --git a/src/fastertransformer/utils/gemm_test/t5_gemm_func.h b/src/fastertransformer/utils/gemm_test/t5_gemm_func.h index 56b8ba18d..67ec12a8c 100644 --- a/src/fastertransformer/utils/gemm_test/t5_gemm_func.h +++ b/src/fastertransformer/utils/gemm_test/t5_gemm_func.h @@ -34,36 +34,36 @@ namespace fastertransformer { template -void generate_t5_gemm_config(int batch_size, - int beam_width, - int max_mem_seq_len, - int encoder_d_model, - int encoder_head_num, - int encoder_size_per_head, - int encoder_inter_size, - int decoder_d_model, - int decoder_head_num, - int decoder_size_per_head, - int decoder_inter_size, - int decoder_vocab_size, - int tensor_para_size, +void generate_t5_gemm_config(int batch_size, + int beam_width, + int max_mem_seq_len, + int encoder_d_model, + int encoder_head_num, + int encoder_size_per_head, + int encoder_inter_size, + int decoder_d_model, + int decoder_head_num, + int decoder_size_per_head, + int decoder_inter_size, + int decoder_vocab_size, + int tensor_para_size, void* buffer_in, - bool isAppend, - bool is_fp16_compute_type); + bool isAppend, + bool is_fp16_compute_type); -size_t calT5GemmTestBufSizeInByte(int batch_size, - int beam_width, - int max_mem_seq_len, - int encoder_d_model, - int encoder_head_num, - int encoder_size_per_head, - int encoder_inter_size, - int decoder_d_model, - int decoder_head_num, - int decoder_size_per_head, - int decoder_inter_size, - int decoder_vocab_size, - int tensor_para_size, +size_t calT5GemmTestBufSizeInByte(int batch_size, + int beam_width, + int max_mem_seq_len, + int encoder_d_model, + int encoder_head_num, + int encoder_size_per_head, + int encoder_inter_size, + int decoder_d_model, + int decoder_head_num, + int decoder_size_per_head, + int decoder_inter_size, + int decoder_vocab_size, + int tensor_para_size, CublasDataType data_type); } // namespace fastertransformer diff --git a/src/fastertransformer/utils/gemm_test/xlnet_gemm_func.cc b/src/fastertransformer/utils/gemm_test/xlnet_gemm_func.cc index 977b108cc..799a95a50 100644 --- a/src/fastertransformer/utils/gemm_test/xlnet_gemm_func.cc +++ b/src/fastertransformer/utils/gemm_test/xlnet_gemm_func.cc @@ -19,18 +19,18 @@ namespace fastertransformer { template -void generate_xlnet_gemm_config(int batch_size, - int seq_len, - int head_num, - int size_per_head, - int hidden_units_, - int inter_size_, +void generate_xlnet_gemm_config(int batch_size, + int seq_len, + int head_num, + int size_per_head, + int hidden_units_, + int inter_size_, void* buffer_in, - bool isAppend) + bool isAppend) { void* cublas_workspace; void* buffer; - int workSpaceSize; + int workSpaceSize; #ifdef ENABLE_BF16 if (std::is_same::value || std::is_same::value) { @@ -40,13 +40,13 @@ void generate_xlnet_gemm_config(int batch_size, // cublas_workspace_ should be the start pointer of cudaMalloc() // to ensure 16B alignemnet cublas_workspace = buffer_in; - buffer = (void*)((char*)cublas_workspace + CUBLAS_WORKSPACE_SIZE); - workSpaceSize = CUBLAS_WORKSPACE_SIZE; + buffer = (void*)((char*)cublas_workspace + CUBLAS_WORKSPACE_SIZE); + workSpaceSize = CUBLAS_WORKSPACE_SIZE; } else { cublas_workspace = nullptr; - buffer = buffer_in; - workSpaceSize = 0; + buffer = buffer_in; + workSpaceSize = 0; } struct cudaDeviceProp prop; @@ -55,14 +55,14 @@ void generate_xlnet_gemm_config(int 
batch_size, // check config FILE* fd; - int line_count = 0; + int line_count = 0; if (!isAppend) { fd = fopen(GEMM_CONFIG, "w+"); } else { fd = fopen(GEMM_CONFIG, "a+"); std::vector config; - char line[1024]; + char line[1024]; while (fgets(line, 1024, fd) != NULL) { config.push_back(std::string(line)); } @@ -80,136 +80,136 @@ void generate_xlnet_gemm_config(int batch_size, } } - const int gemm_num = 10; - int M[gemm_num]; - int N[gemm_num]; - int K[gemm_num]; - int lda[gemm_num]; - int strideA[gemm_num]; - int ldb[gemm_num]; - int strideB[gemm_num]; - int ldc[gemm_num]; - int strideC[gemm_num]; - cublasOperation_t transa[gemm_num] = {CUBLAS_OP_N, - CUBLAS_OP_N, - CUBLAS_OP_T, - CUBLAS_OP_T, - CUBLAS_OP_T, - CUBLAS_OP_T, - CUBLAS_OP_N, - CUBLAS_OP_T, - CUBLAS_OP_N, - CUBLAS_OP_N}; - cublasOperation_t transb[gemm_num] = {CUBLAS_OP_N}; - int batchCount[gemm_num] = {1}; - char mess[gemm_num][256]; + const int gemm_num = 10; + int M[gemm_num]; + int N[gemm_num]; + int K[gemm_num]; + int lda[gemm_num]; + int strideA[gemm_num]; + int ldb[gemm_num]; + int strideB[gemm_num]; + int ldc[gemm_num]; + int strideC[gemm_num]; + cublasOperation_t transa[gemm_num] = {CUBLAS_OP_N, + CUBLAS_OP_N, + CUBLAS_OP_T, + CUBLAS_OP_T, + CUBLAS_OP_T, + CUBLAS_OP_T, + CUBLAS_OP_N, + CUBLAS_OP_T, + CUBLAS_OP_N, + CUBLAS_OP_N}; + cublasOperation_t transb[gemm_num] = {CUBLAS_OP_N}; + int batchCount[gemm_num] = {1}; + char mess[gemm_num][256]; // gemm1 - M[0] = hidden_units_; - N[0] = seq_len * batch_size; - K[0] = hidden_units_; - lda[0] = hidden_units_; - strideA[0] = hidden_units_ * hidden_units_; - ldb[0] = hidden_units_; - strideB[0] = 0; - ldc[0] = hidden_units_; - strideC[0] = seq_len * batch_size * hidden_units_; + M[0] = hidden_units_; + N[0] = seq_len * batch_size; + K[0] = hidden_units_; + lda[0] = hidden_units_; + strideA[0] = hidden_units_ * hidden_units_; + ldb[0] = hidden_units_; + strideB[0] = 0; + ldc[0] = hidden_units_; + strideC[0] = seq_len * batch_size * hidden_units_; batchCount[0] = 3; strcpy(mess[0], "from_tensor * weightQ/K/V"); // gemm2 - M[1] = hidden_units_; - N[1] = seq_len * 2; - K[1] = hidden_units_; + M[1] = hidden_units_; + N[1] = seq_len * 2; + K[1] = hidden_units_; batchCount[1] = 1; strcpy(mess[1], " k_head_r_"); // gemm3 - M[2] = seq_len; - N[2] = seq_len; - K[2] = size_per_head; - lda[2] = size_per_head; - strideA[2] = seq_len * size_per_head; - ldb[2] = size_per_head; - strideB[2] = seq_len * size_per_head; - ldc[2] = seq_len; - strideC[2] = seq_len * seq_len; + M[2] = seq_len; + N[2] = seq_len; + K[2] = size_per_head; + lda[2] = size_per_head; + strideA[2] = seq_len * size_per_head; + ldb[2] = size_per_head; + strideB[2] = seq_len * size_per_head; + ldc[2] = seq_len; + strideC[2] = seq_len * seq_len; batchCount[2] = batch_size * head_num; strcpy(mess[2], "ac"); // gemm4 - M[3] = seq_len * 2; - N[3] = seq_len; - K[3] = size_per_head; - lda[3] = size_per_head; + M[3] = seq_len * 2; + N[3] = seq_len; + K[3] = size_per_head; + lda[3] = size_per_head; strideA[3] = seq_len * 2 * size_per_head; - ldb[3] = size_per_head; + ldb[3] = size_per_head; strideB[3] = seq_len * size_per_head; - ldc[3] = seq_len * 2; + ldc[3] = seq_len * 2; strideC[3] = seq_len * seq_len * 2; batchCount[3] = batch_size * head_num; strcpy(mess[3], "bd"); // gemm5 - M[4] = 2; - N[4] = seq_len; - K[4] = size_per_head; - lda[4] = size_per_head; - strideA[4] = 2 * size_per_head; - ldb[4] = size_per_head; - strideB[4] = seq_len * size_per_head; - ldc[4] = 2; - strideC[4] = seq_len * 2; + M[4] = 2; + N[4] = seq_len; + K[4] 
= size_per_head; + lda[4] = size_per_head; + strideA[4] = 2 * size_per_head; + ldb[4] = size_per_head; + strideB[4] = seq_len * size_per_head; + ldc[4] = 2; + strideC[4] = seq_len * 2; batchCount[4] = batch_size * head_num; strcpy(mess[4], "ef"); // gemm6 - M[5] = head_num; - N[5] = seq_len; - K[5] = 2; - lda[5] = 2; + M[5] = head_num; + N[5] = seq_len; + K[5] = 2; + lda[5] = 2; strideA[5] = 2 * head_num; - ldb[5] = 2; + ldb[5] = 2; strideB[5] = seq_len * 2; - ldc[5] = head_num; + ldc[5] = head_num; strideC[5] = seq_len * head_num; batchCount[5] = batch_size * seq_len; strcpy(mess[5], "seg_mat"); // gemm7 - M[6] = size_per_head; - N[6] = seq_len; - K[6] = seq_len; - lda[6] = size_per_head; + M[6] = size_per_head; + N[6] = seq_len; + K[6] = seq_len; + lda[6] = size_per_head; strideA[6] = seq_len * size_per_head; - ldb[6] = seq_len; + ldb[6] = seq_len; strideB[6] = seq_len * seq_len; - ldc[6] = size_per_head; + ldc[6] = size_per_head; strideC[6] = seq_len * size_per_head; batchCount[6] = batch_size * head_num; strcpy(mess[6], "attn_vec"); // gemm8 - M[7] = hidden_units_; - N[7] = seq_len * batch_size; - K[7] = hidden_units_; - lda[7] = hidden_units_; + M[7] = hidden_units_; + N[7] = seq_len * batch_size; + K[7] = hidden_units_; + lda[7] = hidden_units_; batchCount[7] = 1; strcpy(mess[7], "attn_out"); // gemm9 - M[8] = inter_size_; - N[8] = seq_len * batch_size; - K[8] = hidden_units_; + M[8] = inter_size_; + N[8] = seq_len * batch_size; + K[8] = hidden_units_; batchCount[8] = 1; strcpy(mess[8], "output_fc1_"); // gemm10 - M[9] = hidden_units_; - N[9] = seq_len * batch_size; - K[9] = inter_size_; + M[9] = hidden_units_; + N[9] = seq_len * batch_size; + K[9] = inter_size_; batchCount[9] = 1; strcpy(mess[9], "output_fc2_"); @@ -223,45 +223,45 @@ void generate_xlnet_gemm_config(int batch_size, cudaDataType_t BType; cudaDataType_t CType; cudaDataType_t computeType; - int startAlgo, endAlgo; - const int ites = 100; + int startAlgo, endAlgo; + const int ites = 100; struct timeval start, end; CublasDataType data_type; if (std::is_same::value) { - data_type = FLOAT_DATATYPE; - AType = CUDA_R_32F; - BType = CUDA_R_32F; - CType = CUDA_R_32F; + data_type = FLOAT_DATATYPE; + AType = CUDA_R_32F; + BType = CUDA_R_32F; + CType = CUDA_R_32F; computeType = CUDA_R_32F; - startAlgo = (int)CUBLAS_GEMM_DEFAULT; - endAlgo = (int)CUBLAS_GEMM_ALGO23; + startAlgo = (int)CUBLAS_GEMM_DEFAULT; + endAlgo = (int)CUBLAS_GEMM_ALGO23; } else if (std::is_same::value) { - data_type = HALF_DATATYPE; - AType = CUDA_R_16F; - BType = CUDA_R_16F; - CType = CUDA_R_16F; + data_type = HALF_DATATYPE; + AType = CUDA_R_16F; + BType = CUDA_R_16F; + CType = CUDA_R_16F; computeType = CUDA_R_32F; - startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; - endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; + startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; + endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; } #ifdef ENABLE_BF16 else if (std::is_same::value) { - data_type = BFLOAT16_DATATYPE; - AType = CUDA_R_16BF; - BType = CUDA_R_16BF; - CType = CUDA_R_16BF; + data_type = BFLOAT16_DATATYPE; + AType = CUDA_R_16BF; + BType = CUDA_R_16BF; + CType = CUDA_R_16BF; computeType = CUDA_R_32F; - startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; - endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; + startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; + endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; } #endif using scaleT = typename ScaleTypeConverter::Type; scaleT alpha = (scaleT)1.0f; - scaleT beta = (scaleT)0.0f; + scaleT beta = (scaleT)0.0f; printf("***Xlnet Gemm Testing 
Begin***\n"); printf("***Cublas Gemm Testing Begin***\n"); @@ -281,7 +281,7 @@ void generate_xlnet_gemm_config(int batch_size, T* d_C = d_B + k * n * batchCount[i]; float exec_time = 99999.0f; - int fast_algo = 0; + int fast_algo = 0; for (int algo = startAlgo; algo <= endAlgo; algo++) { cublasStatus_t status; cudaDeviceSynchronize(); @@ -353,7 +353,7 @@ void generate_xlnet_gemm_config(int batch_size, if ((i == 1 || i == 7 || i == 8 || i == 9) && data_type != FLOAT_DATATYPE) { printf("***cublasLt Gemm Testing Beign***\n"); // Let try a fixed number of combinations - int ALGO_COMBINATIONS = 5000; + int ALGO_COMBINATIONS = 5000; customMatmulPerf_t perfResults[ALGO_COMBINATIONS]; LtHgemmCustomFind(ltHandle, @@ -419,31 +419,31 @@ void generate_xlnet_gemm_config(int batch_size, return; } -template void generate_xlnet_gemm_config(int batch_size, - int seq_len, - int head_num, - int size_per_head, - int hidden_units_, - int inter_size_, +template void generate_xlnet_gemm_config(int batch_size, + int seq_len, + int head_num, + int size_per_head, + int hidden_units_, + int inter_size_, void* buffer_in, - bool isAppend); -template void generate_xlnet_gemm_config(int batch_size, - int seq_len, - int head_num, - int size_per_head, - int hidden_units_, - int inter_size_, + bool isAppend); +template void generate_xlnet_gemm_config(int batch_size, + int seq_len, + int head_num, + int size_per_head, + int hidden_units_, + int inter_size_, void* buffer_in, - bool isAppend); + bool isAppend); #ifdef ENABLE_BF16 -template void generate_xlnet_gemm_config<__nv_bfloat16>(int batch_size, - int seq_len, - int head_num, - int size_per_head, - int hidden_units_, - int inter_size_, +template void generate_xlnet_gemm_config<__nv_bfloat16>(int batch_size, + int seq_len, + int head_num, + int size_per_head, + int hidden_units_, + int inter_size_, void* buffer_in, - bool isAppend); + bool isAppend); #endif } // namespace fastertransformer diff --git a/src/fastertransformer/utils/gemm_test/xlnet_gemm_func.h b/src/fastertransformer/utils/gemm_test/xlnet_gemm_func.h index 6e290102f..1e9e8c33a 100644 --- a/src/fastertransformer/utils/gemm_test/xlnet_gemm_func.h +++ b/src/fastertransformer/utils/gemm_test/xlnet_gemm_func.h @@ -34,13 +34,13 @@ namespace fastertransformer { template -void generate_xlnet_gemm_config(int batch_size, - int seq_len, - int head_num, - int size_per_head, - int hidden_units_, - int inter_size_, +void generate_xlnet_gemm_config(int batch_size, + int seq_len, + int head_num, + int size_per_head, + int hidden_units_, + int inter_size_, void* buffer_in, - bool isAppend = true); + bool isAppend = true); } // namespace fastertransformer diff --git a/src/fastertransformer/utils/logger.h b/src/fastertransformer/utils/logger.h index bcdf8fa6c..91974cc15 100644 --- a/src/fastertransformer/utils/logger.h +++ b/src/fastertransformer/utils/logger.h @@ -12,11 +12,11 @@ class Logger { public: enum Level { - TRACE = 0, - DEBUG = 10, - INFO = 20, + TRACE = 0, + DEBUG = 10, + INFO = 20, WARNING = 30, - ERROR = 40 + ERROR = 40 }; static Logger& getLogger() @@ -24,15 +24,15 @@ class Logger { static Logger instance; return instance; } - Logger(Logger const&) = delete; + Logger(Logger const&) = delete; void operator=(Logger const&) = delete; template void log(const Level level, const std::string format, const Args&... args) { if (level_ <= level) { - std::string fmt = getPrefix(level) + format + "\n"; - FILE* out = level_ < WARNING ? 
stdout : stderr; + std::string fmt = getPrefix(level) + format + "\n"; + FILE* out = level_ < WARNING ? stdout : stderr; std::string logstr = fmtstr(fmt, args...); fprintf(out, "%s", logstr.c_str()); } @@ -42,8 +42,8 @@ class Logger { void log(const Level level, const int rank, const std::string format, const Args&... args) { if (level_ <= level) { - std::string fmt = getPrefix(level, rank) + format + "\n"; - FILE* out = level_ < WARNING ? stdout : stderr; + std::string fmt = getPrefix(level, rank) + format + "\n"; + FILE* out = level_ < WARNING ? stdout : stderr; std::string logstr = fmtstr(fmt, args...); fprintf(out, "%s", logstr.c_str()); } @@ -56,7 +56,7 @@ class Logger { } private: - const std::string PREFIX = "[FT]"; + const std::string PREFIX = "[FT]"; std::map level_name_ = { {TRACE, "TRACE"}, {DEBUG, "DEBUG"}, {INFO, "INFO"}, {WARNING, "WARNING"}, {ERROR, "ERROR"}}; diff --git a/src/fastertransformer/utils/memory_utils.cu b/src/fastertransformer/utils/memory_utils.cu index f47b2dd51..ffe46f3b6 100644 --- a/src/fastertransformer/utils/memory_utils.cu +++ b/src/fastertransformer/utils/memory_utils.cu @@ -21,24 +21,25 @@ namespace fastertransformer { template -void deviceMalloc(T** ptr, int size, bool is_random_initialize) +void deviceMalloc(T** ptr, size_t size, bool is_random_initialize) { + FT_CHECK_WITH_INFO(size >= 0, "Ask deviceMalloc size " + std::to_string(size) + "< 0 is invalid."); check_cuda_error(cudaMalloc((void**)(ptr), sizeof(T) * size)); if (is_random_initialize) { cudaRandomUniform(*ptr, size); } } -template void deviceMalloc(float** ptr, int size, bool is_random_initialize); -template void deviceMalloc(half** ptr, int size, bool is_random_initialize); +template void deviceMalloc(float** ptr, size_t size, bool is_random_initialize); +template void deviceMalloc(half** ptr, size_t size, bool is_random_initialize); #ifdef ENABLE_BF16 -template void deviceMalloc(__nv_bfloat16** ptr, int size, bool is_random_initialize); +template void deviceMalloc(__nv_bfloat16** ptr, size_t size, bool is_random_initialize); #endif -template void deviceMalloc(uint16_t** ptr, int size, bool is_random_initialize); -template void deviceMalloc(int** ptr, int size, bool is_random_initialize); -template void deviceMalloc(bool** ptr, int size, bool is_random_initialize); -template void deviceMalloc(char** ptr, int size, bool is_random_initialize); -template void deviceMalloc(int8_t** ptr, int size, bool is_random_initialize); +template void deviceMalloc(uint16_t** ptr, size_t size, bool is_random_initialize); +template void deviceMalloc(int** ptr, size_t size, bool is_random_initialize); +template void deviceMalloc(bool** ptr, size_t size, bool is_random_initialize); +template void deviceMalloc(char** ptr, size_t size, bool is_random_initialize); +template void deviceMalloc(int8_t** ptr, size_t size, bool is_random_initialize); template void deviceMemSetZero(T* ptr, int size) @@ -73,20 +74,21 @@ template void deviceFree(char*& ptr); template void deviceFree(int8_t*& ptr); template -void deviceFill(T* devptr, int size, T value) +void deviceFill(T* devptr, int size, T value, cudaStream_t stream) { T* arr = new T[size]; std::fill(arr, arr + size, value); - check_cuda_error(cudaMemcpy(devptr, arr, sizeof(T) * size, cudaMemcpyHostToDevice)); + check_cuda_error(cudaMemcpyAsync(devptr, arr, sizeof(T) * size, cudaMemcpyHostToDevice, stream)); delete[] arr; } -template void deviceFill(float* devptr, int size, float value); -template void deviceFill(half* devptr, int size, half value); +template void 
deviceFill(float* devptr, int size, float value, cudaStream_t stream); +template void deviceFill(half* devptr, int size, half value, cudaStream_t stream); #ifdef ENABLE_BF16 -template void deviceFill(__nv_bfloat16* devptr, int size, __nv_bfloat16 value); +template void deviceFill(__nv_bfloat16* devptr, int size, __nv_bfloat16 value, cudaStream_t stream); #endif -template void deviceFill(int* devptr, int size, int value); +template void deviceFill(int* devptr, int size, int value, cudaStream_t stream); +template void deviceFill(bool* devptr, int size, bool value, cudaStream_t stream); template void cudaD2Hcpy(T* tgt, const T* src, const int size) @@ -101,6 +103,8 @@ template void cudaD2Hcpy(__nv_bfloat16* tgt, const __nv_bfloat16* src, int size) #endif template void cudaD2Hcpy(int* tgt, const int* src, int size); template void cudaD2Hcpy(bool* tgt, const bool* src, int size); +template void cudaD2Hcpy(unsigned long long* tgt, const unsigned long long* src, int size); +template void cudaD2Hcpy(unsigned int* tgt, const unsigned int* src, int size); template void cudaH2Dcpy(T* tgt, const T* src, const int size) @@ -115,6 +119,8 @@ template void cudaH2Dcpy(__nv_bfloat16* tgt, const __nv_bfloat16* src, int size) #endif template void cudaH2Dcpy(int* tgt, const int* src, int size); template void cudaH2Dcpy(bool* tgt, const bool* src, int size); +template void cudaH2Dcpy(unsigned long long* tgt, const unsigned long long* src, int size); +template void cudaH2Dcpy(unsigned int* tgt, const unsigned int* src, int size); template void cudaD2Dcpy(T* tgt, const T* src, const int size) @@ -130,11 +136,45 @@ template void cudaD2Dcpy(__nv_bfloat16* tgt, const __nv_bfloat16* src, int size) template void cudaD2Dcpy(int* tgt, const int* src, int size); template void cudaD2Dcpy(bool* tgt, const bool* src, int size); template void cudaD2Dcpy(int8_t* tgt, const int8_t* src, int size); +template void cudaD2Dcpy(unsigned long long* tgt, const unsigned long long* src, int size); + +template +void cudaAutoCpy(T* tgt, const T* src, const int size, cudaStream_t stream) +{ + if (stream != NULL) { + check_cuda_error(cudaMemcpyAsync(tgt, src, sizeof(T) * size, cudaMemcpyDefault, stream)); + } + else { + check_cuda_error(cudaMemcpy(tgt, src, sizeof(T) * size, cudaMemcpyDefault)); + } +} + +template void cudaAutoCpy(float* tgt, const float* src, int size, cudaStream_t stream); +template void cudaAutoCpy(half* tgt, const half* src, int size, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void cudaAutoCpy(__nv_bfloat16* tgt, const __nv_bfloat16* src, int size, cudaStream_t stream); +#endif +template void cudaAutoCpy(int* tgt, const int* src, int size, cudaStream_t stream); +template void cudaAutoCpy(bool* tgt, const bool* src, int size, cudaStream_t stream); +template void cudaAutoCpy(int8_t* tgt, const int8_t* src, int size, cudaStream_t stream); +template void cudaAutoCpy(uint* tgt, const uint* src, int size, cudaStream_t stream); +template void cudaAutoCpy(unsigned long long* tgt, const unsigned long long* src, int size, cudaStream_t stream); + +template void cudaAutoCpy(float const** tgt, float const* const* src, int size, cudaStream_t stream); +template void cudaAutoCpy(half const** tgt, half const* const* src, int size, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void cudaAutoCpy(__nv_bfloat16 const** tgt, __nv_bfloat16 const* const* src, int size, cudaStream_t stream); +#endif +template void cudaAutoCpy(int const** tgt, int const* const* src, int size, cudaStream_t stream); +template void cudaAutoCpy(bool 
const** tgt, bool const* const* src, int size, cudaStream_t stream); +template void cudaAutoCpy(int8_t const** tgt, int8_t const* const* src, int size, cudaStream_t stream); +template void +cudaAutoCpy(unsigned long long const** tgt, unsigned long long const* const* src, int size, cudaStream_t stream); template __global__ void cuda_random_uniform_kernel(T* buffer, const int size) { - const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int idx = blockIdx.x * blockDim.x + threadIdx.x; curandState_t local_state; curand_init((unsigned long long int)1337, idx, 0, &local_state); for (int index = idx; index < size; index += blockDim.x * gridDim.x) { @@ -185,39 +225,39 @@ template void cudaRandomUniform(bool* buffer, const int size); template void cudaRandomUniform(char* buffer, const int size); template -inline T_OUT convert_to_type(T_IN val) +__host__ __device__ inline T_OUT convert_to_type(T_IN val) { return (T_OUT)val; } #ifdef ENABLE_BF16 template<> -inline __nv_bfloat16 convert_to_type(float val) +__host__ __device__ inline __nv_bfloat16 convert_to_type(float val) { return __float2bfloat16(val); } template<> -inline __nv_bfloat16 convert_to_type(half val) +__host__ __device__ inline __nv_bfloat16 convert_to_type(half val) { return __float2bfloat16(__half2float(val)); } template<> -inline float convert_to_type<__nv_bfloat16, float>(__nv_bfloat16 val) +__host__ __device__ inline float convert_to_type<__nv_bfloat16, float>(__nv_bfloat16 val) { return __bfloat162float(val); } template<> -inline half convert_to_type<__nv_bfloat16, half>(__nv_bfloat16 val) +__host__ __device__ inline half convert_to_type<__nv_bfloat16, half>(__nv_bfloat16 val) { return __float2half(__bfloat162float(val)); } #endif // ENABLE_BF16 template -int loadWeightFromBinFunc(T* ptr, std::vector shape, std::string filename) +int loadWeightFromBinFunc(T* ptr, std::vector shape, std::string filename) { if (shape.size() > 2) { printf("[ERROR] shape should have less than two dims \n"); @@ -228,8 +268,12 @@ int loadWeightFromBinFunc(T* ptr, std::vector shape, std::string filename) dim1 = shape[1]; } size_t size = dim0 * dim1; + if (size == 0) { + FT_LOG_WARNING("shape is zero, skip loading weight from file %s \n", filename.c_str()); + return 0; + } std::vector host_array(size); - std::ifstream in(filename, std::ios::in | std::ios::binary); + std::ifstream in(filename, std::ios::in | std::ios::binary); if (!in.is_open()) { FT_LOG_WARNING("file %s cannot be opened, loading model fails! 
\n", filename.c_str()); return 0; @@ -255,35 +299,36 @@ int loadWeightFromBinFunc(T* ptr, std::vector shape, std::string filename) cudaH2Dcpy(ptr, (T*)host_array.data(), size); } else { - std::vector host_array_2(size); - for (size_t i = 0; i < size; i++) { - host_array_2[i] = convert_to_type(host_array[i]); - } - cudaH2Dcpy(ptr, (T*)host_array_2.data(), size); + T_IN* ptr_2 = nullptr; + deviceMalloc(&ptr_2, size, false); + cudaH2Dcpy(ptr_2, host_array.data(), size); + invokeCudaD2DcpyConvert(ptr, ptr_2, size); + deviceFree(ptr_2); } in.close(); return 0; } -template int loadWeightFromBinFunc(float* ptr, std::vector shape, std::string filename); -template int loadWeightFromBinFunc(half* ptr, std::vector shape, std::string filename); +template int loadWeightFromBinFunc(float* ptr, std::vector shape, std::string filename); +template int loadWeightFromBinFunc(half* ptr, std::vector shape, std::string filename); #ifdef ENABLE_BF16 template int -loadWeightFromBinFunc<__nv_bfloat16, float>(__nv_bfloat16* ptr, std::vector shape, std::string filename); +loadWeightFromBinFunc<__nv_bfloat16, float>(__nv_bfloat16* ptr, std::vector shape, std::string filename); #endif -template int loadWeightFromBinFunc(float* ptr, std::vector shape, std::string filename); -template int loadWeightFromBinFunc(half* ptr, std::vector shape, std::string filename); +template int loadWeightFromBinFunc(float* ptr, std::vector shape, std::string filename); +template int loadWeightFromBinFunc(half* ptr, std::vector shape, std::string filename); #ifdef ENABLE_BF16 template int -loadWeightFromBinFunc<__nv_bfloat16, half>(__nv_bfloat16* ptr, std::vector shape, std::string filename); -template int loadWeightFromBinFunc(float* ptr, std::vector shape, std::string filename); -template int loadWeightFromBinFunc(half* ptr, std::vector shape, std::string filename); -template int -loadWeightFromBinFunc<__nv_bfloat16, __nv_bfloat16>(__nv_bfloat16* ptr, std::vector shape, std::string filename); +loadWeightFromBinFunc<__nv_bfloat16, half>(__nv_bfloat16* ptr, std::vector shape, std::string filename); +template int loadWeightFromBinFunc(float* ptr, std::vector shape, std::string filename); +template int loadWeightFromBinFunc(half* ptr, std::vector shape, std::string filename); +template int loadWeightFromBinFunc<__nv_bfloat16, __nv_bfloat16>(__nv_bfloat16* ptr, + std::vector shape, + std::string filename); #endif // ENABLE_BF16 template -int loadWeightFromBin(T* ptr, std::vector shape, std::string filename, FtCudaDataType model_file_type) +int loadWeightFromBin(T* ptr, std::vector shape, std::string filename, FtCudaDataType model_file_type) { switch (model_file_type) { case FtCudaDataType::FP32: @@ -305,35 +350,105 @@ int loadWeightFromBin(T* ptr, std::vector shape, std::string filename, FtCu } template int -loadWeightFromBin(float* ptr, std::vector shape, std::string filename, FtCudaDataType model_file_type); -template int loadWeightFromBin(half* ptr, std::vector shape, std::string filename, FtCudaDataType model_file_type); +loadWeightFromBin(float* ptr, std::vector shape, std::string filename, FtCudaDataType model_file_type); +template int +loadWeightFromBin(half* ptr, std::vector shape, std::string filename, FtCudaDataType model_file_type); #ifdef ENABLE_BF16 template int -loadWeightFromBin(__nv_bfloat16* ptr, std::vector shape, std::string filename, FtCudaDataType model_file_type); +loadWeightFromBin(__nv_bfloat16* ptr, std::vector shape, std::string filename, FtCudaDataType model_file_type); #endif -__global__ void 
cudaD2DcpyHalf2Float(float* dst, half* src, const int size) +template +__global__ void cudaD2DcpyConvert(T_OUT* dst, const T_IN* src, const int size) { for (int tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x) { - dst[tid] = __half2float(src[tid]); + dst[tid] = convert_to_type(src[tid]); } } +template +void invokeCudaD2DcpyConvert(T_OUT* tgt, const T_IN* src, const int size, cudaStream_t stream) +{ + cudaD2DcpyConvert<<<256, 256, 0, stream>>>(tgt, src, size); +} + +template void invokeCudaD2DcpyConvert(float* tgt, const float* src, const int size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(half* tgt, const float* src, const int size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(float* tgt, const half* src, const int size, cudaStream_t stream); + +#ifdef ENABLE_BF16 +template void invokeCudaD2DcpyConvert(__nv_bfloat16* tgt, const float* src, const int size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(float* tgt, const __nv_bfloat16* src, const int size, cudaStream_t stream); +#endif // ENABLE_BF16 + void invokeCudaD2DcpyHalf2Float(float* dst, half* src, const int size, cudaStream_t stream) { - cudaD2DcpyHalf2Float<<<256, 256, 0, stream>>>(dst, src, size); + invokeCudaD2DcpyConvert(dst, src, size, stream); } -__global__ void cudaD2DcpyFloat2Half(half* dst, float* src, const int size) +void invokeCudaD2DcpyFloat2Half(half* dst, float* src, const int size, cudaStream_t stream) { - for (int tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x) { - dst[tid] = __float2half(src[tid]); + invokeCudaD2DcpyConvert(dst, src, size, stream); +} + +template +void saveToBinary(const T* ptr, const int size, std::string filename) +{ + + std::vector h_ptr(size); + cudaD2Hcpy(h_ptr.data(), ptr, size); + std::vector float_ptr(size); + for (int i = 0; i < size; i++) { + float_ptr[i] = (float)h_ptr[i]; } + + std::ofstream out(filename, std::ios::out | std::ios::binary); + FT_CHECK_WITH_INFO(out.is_open(), "Fail to open file " + filename); + + out.write((char*)float_ptr.data(), size * sizeof(float)); } -void invokeCudaD2DcpyFloat2Half(half* dst, float* src, const int size, cudaStream_t stream) +template void saveToBinary(const float* ptr, const int size, std::string filename); +template void saveToBinary(const half* ptr, const int size, std::string filename); +#ifdef ENABLE_BF16 +template void saveToBinary(const __nv_bfloat16* ptr, const int size, std::string filename); +#endif // ENABLE_BF16 + +template<> +void saveToBinary(const int* ptr, const int size, std::string filename) +{ + std::vector h_ptr(size); + cudaD2Hcpy(h_ptr.data(), ptr, size); + std::ofstream out(filename, std::ios::out | std::ios::binary); + FT_CHECK_WITH_INFO(out.is_open(), "Fail to open file " + filename); + out.write((char*)h_ptr.data(), size * sizeof(int)); +} + +template +__global__ void fakeCast(T_IN* input_ptr, const size_t size) +{ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) { + T_fake_type tmp_val = (T_fake_type)((float)input_ptr[i]); + tmp_val = tmp_val * (T_fake_type)(1.0f); + input_ptr[i] = (T_IN)((float)tmp_val); + } +} + +template +void invokeFakeCast(T_IN* input_ptr, const size_t size, cudaStream_t stream) { - cudaD2DcpyFloat2Half<<<256, 256, 0, stream>>>(dst, src, size); + dim3 block(256); + dim3 grid((size + 255) / 256); + fakeCast<<>>(input_ptr, size); } -} // namespace fastertransformer \ No newline at end of file +#ifdef ENABLE_BF16 +template void 
invokeFakeCast(float* input_ptr, const size_t size, cudaStream_t stream); +template void +invokeFakeCast<__nv_bfloat16, __nv_bfloat16>(__nv_bfloat16* input_ptr, const size_t size, cudaStream_t stream); +template void invokeFakeCast(half* input_ptr, const size_t size, cudaStream_t stream); +#endif +template void invokeFakeCast(float* input_ptr, const size_t size, cudaStream_t stream); +template void invokeFakeCast(float* input_ptr, const size_t size, cudaStream_t stream); + +} // namespace fastertransformer diff --git a/src/fastertransformer/utils/memory_utils.h b/src/fastertransformer/utils/memory_utils.h index 27b3aaa26..abd92581e 100644 --- a/src/fastertransformer/utils/memory_utils.h +++ b/src/fastertransformer/utils/memory_utils.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ namespace fastertransformer { template -void deviceMalloc(T** ptr, int size, bool is_random_initialize = true); +void deviceMalloc(T** ptr, size_t size, bool is_random_initialize = true); template void deviceMemSetZero(T* ptr, int size); @@ -30,7 +30,7 @@ template void deviceFree(T*& ptr); template -void deviceFill(T* devptr, int size, T value); +void deviceFill(T* devptr, int size, T value, cudaStream_t stream = 0); template void cudaD2Hcpy(T* tgt, const T* src, const int size); @@ -41,16 +41,38 @@ void cudaH2Dcpy(T* tgt, const T* src, const int size); template void cudaD2Dcpy(T* tgt, const T* src, const int size); +template +void cudaAutoCpy(T* tgt, const T* src, const int size, cudaStream_t stream = NULL); + template void cudaRandomUniform(T* buffer, const int size); template -int loadWeightFromBin(T* ptr, - std::vector shape, - std::string filename, - FtCudaDataType model_file_type = FtCudaDataType::FP32); +int loadWeightFromBin(T* ptr, + std::vector shape, + std::string filename, + FtCudaDataType model_file_type = FtCudaDataType::FP32); void invokeCudaD2DcpyHalf2Float(float* dst, half* src, const int size, cudaStream_t stream); void invokeCudaD2DcpyFloat2Half(half* dst, float* src, const int size, cudaStream_t stream); +template +void invokeCudaD2DcpyConvert(T_OUT* tgt, const T_IN* src, const int size, cudaStream_t stream = 0); + +inline bool checkIfFileExist(const std::string& file_path) +{ + std::ifstream in(file_path, std::ios::in | std::ios::binary); + if (in.is_open()) { + in.close(); + return true; + } + return false; +} + +template +void saveToBinary(const T* ptr, const int size, std::string filename); + +template +void invokeFakeCast(T_IN* input_ptr, const size_t size, cudaStream_t stream); + } // namespace fastertransformer diff --git a/src/fastertransformer/utils/mpi_utils.cc b/src/fastertransformer/utils/mpi_utils.cc new file mode 100644 index 000000000..2a70e5a86 --- /dev/null +++ b/src/fastertransformer/utils/mpi_utils.cc @@ -0,0 +1,122 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/fastertransformer/utils/mpi_utils.h" + +namespace fastertransformer { +namespace mpi { + +#ifdef BUILD_MULTI_GPU +MPI_Datatype getMpiDtype(MpiType dtype) +{ + static const std::unordered_map dtype_map{ + {MPI_TYPE_BYTE, MPI_BYTE}, + {MPI_TYPE_CHAR, MPI_CHAR}, + {MPI_TYPE_INT, MPI_INT}, + {MPI_TYPE_INT64_T, MPI_INT64_T}, + {MPI_TYPE_UINT32_T, MPI_UINT32_T}, + {MPI_TYPE_UNSIGNED_LONG_LONG, MPI_UNSIGNED_LONG_LONG}, + }; + return dtype_map.at(dtype); +} +#endif + +void initialize(int* argc, char*** argv) +{ +#ifdef BUILD_MULTI_GPU + MPICHECK(MPI_Init(argc, argv)); +#endif +} + +void finalize() +{ +#ifdef BUILD_MULTI_GPU + MPICHECK(MPI_Finalize()); +#endif +} + +bool isInitialized() +{ + int mpi_initialized = 0; +#ifdef BUILD_MULTI_GPU + MPICHECK(MPI_Initialized(&mpi_initialized)); +#endif + return static_cast(mpi_initialized); +} + +void initThread(int* argc, char*** argv, MpiThreadSupport required, int* provided) +{ +#ifdef BUILD_MULTI_GPU + switch (required) { + case THREAD_SINGLE: + MPICHECK(MPI_Init_thread(argc, argv, MPI_THREAD_SINGLE, provided)); + break; + case THREAD_FUNNELED: + MPICHECK(MPI_Init_thread(argc, argv, MPI_THREAD_FUNNELED, provided)); + break; + case THREAD_SERIALIZED: + MPICHECK(MPI_Init_thread(argc, argv, MPI_THREAD_SERIALIZED, provided)); + break; + case THREAD_MULTIPLE: + MPICHECK(MPI_Init_thread(argc, argv, MPI_THREAD_MULTIPLE, provided)); + break; + default: + break; + } +#endif +} + +int getCommWorldRank() +{ + int rank = 0; +#ifdef BUILD_MULTI_GPU + MPI_Comm_rank(MPI_COMM_WORLD, &rank); +#endif + return rank; +} + +int getCommWorldSize() +{ + int world_size = 1; +#ifdef BUILD_MULTI_GPU + MPI_Comm_size(MPI_COMM_WORLD, &world_size); +#endif + return world_size; +} + +void barrier(MpiComm comm) +{ +#ifdef BUILD_MULTI_GPU + MPICHECK(MPI_Barrier(comm.group)); +#endif +} + +void barrier() +{ +#ifdef BUILD_MULTI_GPU + MPICHECK(MPI_Barrier(MPI_COMM_WORLD)); +#endif +} + +void bcast(void* buffer, size_t size, MpiType dtype, int root, MpiComm comm) +{ +#ifdef BUILD_MULTI_GPU + MPICHECK(MPI_Bcast(buffer, size, getMpiDtype(dtype), root, comm.group)); +#endif +} + +} // namespace mpi +} // namespace fastertransformer diff --git a/src/fastertransformer/utils/mpi_utils.h b/src/fastertransformer/utils/mpi_utils.h index 755e52fb1..65234538b 100644 --- a/src/fastertransformer/utils/mpi_utils.h +++ b/src/fastertransformer/utils/mpi_utils.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,11 +16,17 @@ #pragma once -#include "mpi.h" +#include "src/fastertransformer/utils/logger.h" + +#ifdef BUILD_MULTI_GPU +#include +#endif #include +#include namespace fastertransformer { +#ifdef BUILD_MULTI_GPU #define MPICHECK(cmd) \ do { \ int e = cmd; \ @@ -29,5 +35,60 @@ namespace fastertransformer { exit(EXIT_FAILURE); \ } \ } while (0) +#else +#define MPICHECK(cmd) +#endif + +// A wrapper module of the MPI library. 
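Editor's note (not part of the patch): every function in this wrapper compiles to a no-op when BUILD_MULTI_GPU is off, so callers do not need their own #ifdef guards. A minimal, hypothetical driver using only the functions declared in this header might look like the sketch below; the program itself is invented for illustration.

#include <cstdio>
#include "src/fastertransformer/utils/mpi_utils.h"

int main(int argc, char* argv[])
{
    using namespace fastertransformer;
    mpi::initialize(&argc, &argv);                    // no-op in single-GPU builds
    const int rank       = mpi::getCommWorldRank();   // 0 when MPI is disabled
    const int world_size = mpi::getCommWorldSize();   // 1 when MPI is disabled
    std::printf("process %d of %d\n", rank, world_size);
    mpi::barrier();                                   // synchronize all ranks
    mpi::finalize();
    return 0;
}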
+namespace mpi { + +// A wrapper of MPI data type. MPI_TYPE_{data_type} +enum MpiType { + MPI_TYPE_BYTE, + MPI_TYPE_CHAR, + MPI_TYPE_INT, + MPI_TYPE_INT64_T, + MPI_TYPE_UINT32_T, + MPI_TYPE_UNSIGNED_LONG_LONG, +}; + +// A wrapper of the level of MPI thread support +enum MpiThreadSupport { + THREAD_SINGLE, + THREAD_FUNNELED, + THREAD_SERIALIZED, + THREAD_MULTIPLE +}; + +struct MpiComm { +#ifdef BUILD_MULTI_GPU + MPI_Comm group; + MpiComm(){}; + MpiComm(MPI_Comm g): group(g){}; +#endif +}; + +#ifdef BUILD_MULTI_GPU +#define COMM_WORLD MpiComm(MPI_COMM_WORLD) +#else +#define COMM_WORLD MpiComm() +#endif + +#ifdef BUILD_MULTI_GPU +MPI_Datatype getMpiDtype(MpiType dtype); +#endif + +void initialize(int* argc, char*** argv); +void initThread(int* argc, char*** argv, MpiThreadSupport required, int* provided); +void finalize(); +bool isInitialized(); +void barrier(MpiComm comm); +void barrier(); + +int getCommWorldRank(); +int getCommWorldSize(); + +void bcast(void* buffer, size_t size, MpiType dtype, int root, MpiComm comm); -} // namespace fastertransformer \ No newline at end of file +} // namespace mpi +} // namespace fastertransformer diff --git a/src/fastertransformer/utils/nccl_utils.cc b/src/fastertransformer/utils/nccl_utils.cc index 2bddda1dc..00fcb6a73 100644 --- a/src/fastertransformer/utils/nccl_utils.cc +++ b/src/fastertransformer/utils/nccl_utils.cc @@ -15,10 +15,42 @@ */ #include "src/fastertransformer/utils/nccl_utils.h" -#include "src/fastertransformer/utils/cuda_utils.h" namespace fastertransformer { +#ifdef BUILD_MULTI_GPU +template<typename T> +ncclDataType_t getNcclDataType() +{ + ncclDataType_t nccl_data_type; + if (std::is_same<T, float>::value) { + nccl_data_type = ncclFloat; + } + else if (std::is_same<T, half>::value) { + nccl_data_type = ncclHalf; + } +#if defined(ENABLE_BF16) && defined(ENABLE_BF16_NCCL) + else if (std::is_same<T, __nv_bfloat16>::value) { + nccl_data_type = ncclBfloat16; + } +#endif + else if (std::is_same<T, int>::value) { + nccl_data_type = ncclInt; + } + else if (std::is_same<T, char>::value) { + nccl_data_type = ncclChar; + } + else if (std::is_same<T, bool>::value) { + nccl_data_type = ncclInt8; + } + else { + printf("[ERROR] NCCL only supports float, half, bfloat16, int, char, and bool. 
\n"); + exit(-1); + } + return nccl_data_type; +} +#endif + template void ftNcclAllReduceSum(const T* send_buf, T* recv_buf, const int data_size, NcclParam nccl_param, cudaStream_t stream) { @@ -65,6 +97,8 @@ template void ftNcclSend(const int* send_buf, const int data_size, const int peer, NcclParam nccl_param, cudaStream_t stream); template void ftNcclSend(const bool* send_buf, const int data_size, const int peer, NcclParam nccl_param, cudaStream_t stream); +template void +ftNcclSend(const char* send_buf, const int data_size, const int peer, NcclParam nccl_param, cudaStream_t stream); template void ftNcclRecv(T* recv_buf, const int data_size, const int peer, NcclParam nccl_param, cudaStream_t stream) @@ -86,6 +120,8 @@ ftNcclRecv(__nv_bfloat16* recv_buf, const int data_size, const int peer, NcclPar template void ftNcclRecv(int* recv_buf, const int data_size, const int peer, NcclParam nccl_param, cudaStream_t stream); template void ftNcclRecv(bool* recv_buf, const int data_size, const int peer, NcclParam nccl_param, cudaStream_t stream); +template void +ftNcclRecv(char* recv_buf, const int data_size, const int peer, NcclParam nccl_param, cudaStream_t stream); template void ftNcclBroadCast(T* buff, const int data_size, const int root, NcclParam nccl_param, cudaStream_t stream) @@ -117,62 +153,245 @@ template void ftNcclAllReduceSum( #ifdef ENABLE_BF16 template void ftNcclAllReduceSum(const __nv_bfloat16* send_buf, - __nv_bfloat16* recv_buf, - const int data_size, - NcclParam nccl_param, - cudaStream_t stream); + __nv_bfloat16* recv_buf, + const int data_size, + NcclParam nccl_param, + cudaStream_t stream); #endif template void ftNcclAllGather(const float* send_buf, - float* recv_buf, - const int data_size, - const int rank, - NcclParam nccl_param, + float* recv_buf, + const int data_size, + const int rank, + NcclParam nccl_param, cudaStream_t stream); -template void ftNcclAllGather(const half* send_buf, - half* recv_buf, - const int data_size, - const int rank, - NcclParam nccl_param, +template void ftNcclAllGather(const half* send_buf, + half* recv_buf, + const int data_size, + const int rank, + NcclParam nccl_param, cudaStream_t stream); #ifdef ENABLE_BF16 template void ftNcclAllGather(const __nv_bfloat16* send_buf, - __nv_bfloat16* recv_buf, - const int data_size, - const int rank, - NcclParam nccl_param, - cudaStream_t stream); + __nv_bfloat16* recv_buf, + const int data_size, + const int rank, + NcclParam nccl_param, + cudaStream_t stream); #endif -template -ncclDataType_t getNcclDataType() +void ftNcclGroupStart() { #ifdef BUILD_MULTI_GPU - ncclDataType_t nccl_data_type; - if (std::is_same::value) { - nccl_data_type = ncclFloat; + NCCLCHECK(ncclGroupStart()); +#endif +} + +void ftNcclGroupEnd() +{ +#ifdef BUILD_MULTI_GPU + NCCLCHECK(ncclGroupEnd()); +#endif +} + +void ftNcclStreamSynchronize(NcclParam tensor_para, NcclParam pipeline_para, cudaStream_t stream) +{ +#ifdef BUILD_MULTI_GPU + cudaError_t cudaErr; + ncclResult_t tensor_ncclErr = ncclSuccess, tensor_ncclAsyncErr = ncclSuccess, pipeline_ncclErr = ncclSuccess, + pipeline_ncclAsyncErr = ncclSuccess; + ncclComm_t tensor_comm = tensor_para.nccl_comm_; + ncclComm_t pipeline_comm = pipeline_para.nccl_comm_; + if (tensor_para.world_size_ == 1 && pipeline_para.world_size_ == 1) { + check_cuda_error(cudaStreamSynchronize(stream)); + return; } - else if (std::is_same::value) { - nccl_data_type = ncclHalf; + while (1) { + cudaErr = cudaStreamQuery(stream); + if (cudaErr == cudaSuccess) { + return; + } + + if (cudaErr != 
cudaErrorNotReady) { + std::string error_msg = "CUDA Error : cudaStreamQuery returned " + std::to_string(cudaErr); + throw std::runtime_error(error_msg); + } + if (tensor_para.world_size_ > 1) { + tensor_ncclErr = ncclCommGetAsyncError(tensor_comm, &tensor_ncclAsyncErr); + } + if (pipeline_para.world_size_ > 1) { + pipeline_ncclErr = ncclCommGetAsyncError(pipeline_comm, &pipeline_ncclAsyncErr); + } + + if (tensor_ncclErr != ncclSuccess || pipeline_ncclErr != ncclSuccess) { + std::string error_msg = "NCCL Error : ncclCommGetAsyncError returned " + std::to_string(tensor_ncclErr) + + " (tensor_para) " + std::to_string(pipeline_ncclErr) + " (pipeline_para)"; + throw std::runtime_error(error_msg); + } + + if (tensor_ncclAsyncErr != ncclSuccess) { + // An asynchronous error happened. Stop the operation and destroy + // the communicator + tensor_ncclErr = ncclCommAbort(tensor_comm); + if (tensor_ncclErr != ncclSuccess) { + std::string error_msg = "NCCL Error : ncclCommDestroy returned " + std::to_string(tensor_ncclErr); + throw std::runtime_error(error_msg); + } + } + + if (pipeline_ncclAsyncErr != ncclSuccess) { + // An asynchronous error happened. Stop the operation and destroy + // the communicator + pipeline_ncclErr = ncclCommAbort(pipeline_comm); + if (pipeline_ncclErr != ncclSuccess) { + std::string error_msg = "NCCL Error : ncclCommDestroy returned " + std::to_string(pipeline_ncclErr); + throw std::runtime_error(error_msg); + } + } } -#if defined(ENABLE_BF16) && defined(ENABLE_BF16_NCCL) - else if (std::is_same::value) { - nccl_data_type = ncclBfloat16; +#endif +} + +void ftNcclGetUniqueId(NcclUid& uid) +{ +#ifdef BUILD_MULTI_GPU + NCCLCHECK(ncclGetUniqueId(&uid.nccl_uid_)); +#endif +} + +void ftNcclCommInitRank(NcclParam& param, const int rank, const int world_size, const NcclUid uid) +{ +#ifdef BUILD_MULTI_GPU + // Initialize a nccl communicator. + if (param.nccl_comm_ != nullptr) { + FT_LOG_WARNING("NcclParam is already initialized."); + return; } + param.rank_ = rank; + param.world_size_ = world_size; + param.nccl_uid_ = uid.nccl_uid_; + NCCLCHECK(ncclCommInitRank(¶m.nccl_comm_, param.world_size_, param.nccl_uid_, param.rank_)); #endif - else if (std::is_same::value) { - nccl_data_type = ncclInt; +} + +void ftNcclParamDestroy(NcclParam& param) +{ +#ifdef BUILD_MULTI_GPU + if (param.nccl_comm_ != nullptr) { + ncclCommDestroy(param.nccl_comm_); } - else if (std::is_same::value) { - nccl_data_type = ncclInt8; +#endif +} + +void ftNcclInitialize(NcclParam& tensor_para, + NcclParam& pipeline_para, + const int tensor_para_size, + const int pipeline_para_size) +{ + // Initialize nccl communication grid of tensor and pipeline parallel groups. +#ifndef BUILD_MULTI_GPU + FT_CHECK_WITH_INFO(tensor_para_size == 1, + fmtstr("tensor_para_size=%d although BUILD_MULTI_GPU is disabled. " + "Please use the cmake flag -DBUILD_MULTI_GPU=ON if you want " + "to use tensor/pipeline parallelism.", + tensor_para_size)); + FT_CHECK_WITH_INFO(pipeline_para_size == 1, + fmtstr("pipeline_para_size=%d although BUILD_MULTI_GPU is disabled. " + "Please use the cmake flag -DBUILD_MULTI_GPU=ON if you want " + "to use tensor/pipeline parallelism.", + pipeline_para_size)); + tensor_para.rank_ = 0; + tensor_para.world_size_ = tensor_para_size; + pipeline_para.rank_ = 0; + pipeline_para.world_size_ = pipeline_para_size; +#else + // Initialize a nccl communicator. + if (tensor_para.nccl_comm_ != nullptr && pipeline_para.nccl_comm_ != nullptr) { + FT_LOG_WARNING("NcclParam is already initialized. 
Skip NCCL initialization."); + return; } - else { - printf("[ERROR] NCCL only support float, half, bfloat16, int and bool. \n"); - exit(-1); + FT_CHECK(tensor_para.nccl_comm_ == nullptr); + FT_CHECK(pipeline_para.nccl_comm_ == nullptr); + FT_CHECK(tensor_para_size > 0); + FT_CHECK(pipeline_para_size > 0); + + if (tensor_para_size == 1 && pipeline_para_size == 1) { + FT_LOG_WARNING("Skip NCCL initialization since requested tensor/pipeline parallel sizes are equal to 1."); + tensor_para.rank_ = 0; + tensor_para.world_size_ = tensor_para_size; + pipeline_para.rank_ = 0; + pipeline_para.world_size_ = pipeline_para_size; + return; } - return nccl_data_type; + + int mpi_initialized; + MPICHECK(MPI_Initialized(&mpi_initialized)); + FT_CHECK_WITH_INFO(mpi_initialized, "Failed to initialize NCCL because MPI is not initialized."); + + int rank, world_size; + MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &rank)); + MPICHECK(MPI_Comm_size(MPI_COMM_WORLD, &world_size)); + + FT_CHECK_WITH_INFO(tensor_para_size * pipeline_para_size <= world_size, + fmtstr("tensor_para_size (%d) * pipeline_para_size (%d) should not exceed the world size (%d).", + tensor_para_size, + pipeline_para_size, + world_size)); + + // Convert WORLD communicator into 2D grid (k * n) communicator. + // row = a tensor parallel group, col = a pipeline parallel group. + MPI_Comm grid_comm, tp_comm, pp_comm; + + int dims[2] = {pipeline_para_size, tensor_para_size}; + int periods[2] = {0, 0}; + MPI_Cart_create(MPI_COMM_WORLD, 2, dims, periods, 0, &grid_comm); + + // Split 2D communicator into rows and cols. + int tp_remain_dims[2] = {false, true}; + int pp_remain_dims[2] = {true, false}; + MPI_Cart_sub(grid_comm, tp_remain_dims, &tp_comm); + MPI_Cart_sub(grid_comm, pp_remain_dims, &pp_comm); + + int tp_rank, pp_rank; + MPI_Comm_rank(tp_comm, &tp_rank); + MPI_Comm_rank(pp_comm, &pp_rank); + + ncclUniqueId tp_uid; + ncclUniqueId pp_uid; + // The root of each group creates a nccl uid. + if (tp_rank == 0) { + FT_LOG_DEBUG("rank %d tp rank %d creates nccl uid.", rank, tp_rank); + NCCLCHECK(ncclGetUniqueId(&tp_uid)); + } + if (pp_rank == 0) { + FT_LOG_DEBUG("rank %d pp rank %d creates nccl uid.", rank, pp_rank); + NCCLCHECK(ncclGetUniqueId(&pp_uid)); + } + // Broadcast nccl uid to share the same nccl uid across gpus in the same group. 
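Editor's note (not part of the patch): with reorder fixed to 0 in MPI_Cart_create above, world rank r keeps its position in the row-major 2D grid, so tp_rank equals r % tensor_para_size and pp_rank equals r / tensor_para_size. The standalone sketch below only prints that mapping for hypothetical sizes (tensor_para_size = 2, pipeline_para_size = 2); it is not FasterTransformer code.

#include <cstdio>

int main()
{
    const int tensor_para_size = 2, pipeline_para_size = 2;  // assumed world_size = 4
    for (int rank = 0; rank < tensor_para_size * pipeline_para_size; rank++) {
        const int pp_rank = rank / tensor_para_size;  // row index: which tensor-parallel group this rank sits in
        const int tp_rank = rank % tensor_para_size;  // column index: which pipeline-parallel group this rank sits in
        std::printf("world rank %d -> tp_rank %d in TP group %d, pp_rank %d in PP group %d\n",
                    rank, tp_rank, pp_rank, pp_rank, tp_rank);
    }
    // Expected grouping: TP groups (rows) are {0, 1} and {2, 3};
    // PP groups (columns) are {0, 2} and {1, 3}.
    return 0;
}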
+ FT_LOG_DEBUG("Broadcast nccl uid to the others in the same parallel groups."); + MPI_Bcast(&tp_uid, sizeof(tp_uid), MPI_BYTE, 0, tp_comm); + MPI_Bcast(&pp_uid, sizeof(pp_uid), MPI_BYTE, 0, pp_comm); + + FT_LOG_DEBUG("Initialize NCCL communicators."); + ncclComm_t tp_nccl_comm, pp_nccl_comm; + NCCLCHECK(ncclCommInitRank(&tp_nccl_comm, tensor_para_size, tp_uid, tp_rank)); + NCCLCHECK(ncclCommInitRank(&pp_nccl_comm, pipeline_para_size, pp_uid, pp_rank)); + + tensor_para.world_size_ = tensor_para_size; + tensor_para.rank_ = tp_rank; + tensor_para.nccl_uid_ = tp_uid; + tensor_para.nccl_comm_ = tp_nccl_comm; + pipeline_para.world_size_ = pipeline_para_size; + pipeline_para.rank_ = pp_rank; + pipeline_para.nccl_uid_ = pp_uid; + pipeline_para.nccl_comm_ = pp_nccl_comm; + FT_LOG_INFO("NCCL initialized rank=%d world_size=%d tensor_para=%s pipeline_para=%s", + rank, + world_size, + tensor_para.toString().c_str(), + pipeline_para.toString().c_str()); #endif } @@ -185,7 +404,7 @@ size_t getLocalBatchSize(const size_t batch_size, const size_t seq_len, const si if (local_batch_size % pipeline_para_size == 0) { local_batch_size /= pipeline_para_size; } - while (local_batch_size * seq_len > 8192 && local_batch_size % 2 == 0) { + while (local_batch_size * seq_len > 1024 && local_batch_size % 2 == 0) { local_batch_size /= 2; } return local_batch_size; diff --git a/src/fastertransformer/utils/nccl_utils.h b/src/fastertransformer/utils/nccl_utils.h index 0fa845fcc..31299ac8f 100644 --- a/src/fastertransformer/utils/nccl_utils.h +++ b/src/fastertransformer/utils/nccl_utils.h @@ -16,9 +16,17 @@ #pragma once -#include "cuda_runtime.h" -#include "nccl.h" +#include "src/fastertransformer/utils/cuda_utils.h" +#include "src/fastertransformer/utils/logger.h" +#include "src/fastertransformer/utils/mpi_utils.h" + +#include +#ifdef BUILD_MULTI_GPU +#include +#include +#endif #include +#include #if defined(NCCL_VERSION_CODE) && (NCCL_VERSION_CODE >= 21003) #define ENABLE_BF16_NCCL @@ -26,6 +34,7 @@ namespace fastertransformer { +#ifdef BUILD_MULTI_GPU #define NCCLCHECK(cmd) \ do { \ ncclResult_t r = cmd; \ @@ -34,26 +43,47 @@ namespace fastertransformer { exit(EXIT_FAILURE); \ } \ } while (0) +#else +#define NCCLCHECK(cmd) +#endif + +struct NcclUid { +#ifndef BUILD_MULTI_GPU + NcclUid(){}; + NcclUid(NcclUid const& uid){}; +#else + ncclUniqueId nccl_uid_; + NcclUid(){}; + NcclUid(NcclUid const& uid): nccl_uid_(uid.nccl_uid_){}; +#endif +}; struct NcclParam { int rank_{0}; int world_size_{1}; - ncclComm_t nccl_comm_ = nullptr; +#ifdef BUILD_MULTI_GPU + ncclUniqueId nccl_uid_; + ncclComm_t nccl_comm_ = nullptr; +#endif +#ifdef BUILD_MULTI_GPU NcclParam(): rank_(0), world_size_(1), nccl_comm_(nullptr){}; - NcclParam(int rank, int world_size, ncclComm_t comm): rank_(rank), world_size_(world_size), nccl_comm_(comm){}; + NcclParam(int rank, int world_size): rank_(rank), world_size_(world_size){}; NcclParam(NcclParam const& param): - rank_(param.rank_), world_size_(param.world_size_), nccl_comm_(param.nccl_comm_){}; - - // int layers_per_group{0}; - // bool is_valid(int i) - // { - // if (i >= layers_per_group * rank && i < layers_per_group * (rank + 1)) - // return true; - // else - // return false; - // } - // int local_batch_size{-1}; + rank_(param.rank_), world_size_(param.world_size_), nccl_uid_(param.nccl_uid_), nccl_comm_(param.nccl_comm_){}; + std::string toString() + { + return fmtstr("NcclParam[rank=%d, world_size=%d, nccl_comm=%p]", rank_, world_size_, nccl_comm_); + } +#else + NcclParam(): rank_(0), 
world_size_(1){}; + NcclParam(int rank, int world_size): rank_(rank), world_size_(world_size){}; + NcclParam(NcclParam const& param): rank_(param.rank_), world_size_(param.world_size_){}; + std::string toString() + { + return fmtstr("NcclParam[rank=%d, world_size=%d]", rank_, world_size_); + } +#endif }; // New APIs @@ -73,9 +103,20 @@ void ftNcclRecv(T* recv_buf, const int data_size, const int peer, NcclParam nccl template void ftNcclSend(const T* send_buf, const int data_size, const int peer, NcclParam nccl_param, cudaStream_t stream); -template -ncclDataType_t getNcclDataType(); +// nccl stream synchronize, abort nccl comms and throw errors when nccl async errors detected +void ftNcclStreamSynchronize(NcclParam tensor_para, NcclParam pipeline_para_, cudaStream_t stream); + +void ftNcclGroupStart(); +void ftNcclGroupEnd(); +void ftNcclGetUniqueId(NcclUid& uid); +void ftNcclCommInitRank(NcclParam& param, const int rank, const int world_size, const NcclUid uid); +void ftNcclParamDestroy(NcclParam& param); + +void ftNcclInitialize(NcclParam& tensor_para, + NcclParam& pipeline_para, + const int tensor_para_size, + const int pipeline_para_size); size_t getLocalBatchSize(const size_t batch_size, const size_t seq_len, const size_t pipeline_para_size); -} // namespace fastertransformer \ No newline at end of file +} // namespace fastertransformer diff --git a/src/fastertransformer/utils/nvtx_utils.h b/src/fastertransformer/utils/nvtx_utils.h index 6cda6022e..b98a4e31d 100644 --- a/src/fastertransformer/utils/nvtx_utils.h +++ b/src/fastertransformer/utils/nvtx_utils.h @@ -23,10 +23,10 @@ extern bool NVTX_ON; namespace nvtx { static std::string scope; -std::string getScope(); -void addScope(std::string name); -void setScope(std::string name); -void resetScope(); +std::string getScope(); +void addScope(std::string name); +void setScope(std::string name); +void resetScope(); } // namespace nvtx #ifdef USE_NVTX diff --git a/src/fastertransformer/utils/prompt_learning.h b/src/fastertransformer/utils/prompt_learning.h new file mode 100644 index 000000000..faced8d7a --- /dev/null +++ b/src/fastertransformer/utils/prompt_learning.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2021, NAVER Corp. Authored by CLOVA. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace fastertransformer { + +enum class PromptLearningType { + no_prompt, + soft_prompt, + prefix_prompt, + p_prompt_tuning +}; + +} // namespace fastertransformer diff --git a/src/fastertransformer/utils/string_utils.h b/src/fastertransformer/utils/string_utils.h index 28e6897ec..2beeaac24 100644 --- a/src/fastertransformer/utils/string_utils.h +++ b/src/fastertransformer/utils/string_utils.h @@ -40,7 +40,7 @@ inline std::string fmtstr(const std::string& format, Args... 
args) throw std::runtime_error("Error during formatting."); } auto size = static_cast(size_s); - auto buf = std::make_unique(size); + auto buf = std::make_unique(size); std::snprintf(buf.get(), size, format.c_str(), args...); #if defined(_MSC_VER) #pragma warning(pop) @@ -50,7 +50,8 @@ inline std::string fmtstr(const std::string& format, Args... args) return std::string(buf.get(), buf.get() + size - 1); // We don't want the '\0' inside } -inline std::string vec2str(std::vector vec) +template +inline std::string vec2str(std::vector vec) { std::stringstream ss; ss << "("; diff --git a/src/fastertransformer/utils/word_list.cc b/src/fastertransformer/utils/word_list.cc index cfbeb1496..93666aa15 100644 --- a/src/fastertransformer/utils/word_list.cc +++ b/src/fastertransformer/utils/word_list.cc @@ -26,12 +26,12 @@ int read_word_list(const std::string& filename, std::vector& file_data) std::ifstream word_list_file(filename, std::ios::in); std::string line_buf; - int line_count = 0; - size_t id_counts[2] = {0, 0}; + int line_count = 0; + size_t id_counts[2] = {0, 0}; while (std::getline(word_list_file, line_buf)) { std::stringstream line_stream(line_buf); - std::string vals; + std::string vals; while (std::getline(line_stream, vals, ',')) { file_data.push_back(std::stoi(vals)); id_counts[line_count]++; diff --git a/tests/bert/tf_bert_unit_test.py b/tests/bert/tf_bert_unit_test.py index 922d3ab3c..bfc2197e4 100644 --- a/tests/bert/tf_bert_unit_test.py +++ b/tests/bert/tf_bert_unit_test.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -23,9 +23,9 @@ dir_path = os.path.dirname(os.path.realpath(__file__)) sys.path.append(dir_path + "/../..") from examples.tensorflow.bert.bert_example import bert_example +from examples.tensorflow.encoder.encoder_example import encoder_example class TestEncoder(unittest.TestCase): - common_args_dict = {'batch_size' : 4, 'num_layer' : 12, 'max_seq_len': 32, @@ -40,35 +40,74 @@ class TestEncoder(unittest.TestCase): 'thread_num': 1, 'int8_mode': 0 } - threshold = {'fp32': 3e-5, 'fp16': 4e-2 } + threshold = {'fp32': 3e-5, 'fp16': 4e-2, 'bf16': 5e-2 } + test_level = 1 def test_batch_fp32(self): + if self.test_level >= 3: + print(f"[INFO] test level {self.test_level}, run unit test test_batch_fp32 (level {3})") + else: + print(f"[INFO] test level {self.test_level}, skip unit test test_batch_fp32 (level {3})") + return args_dict = copy.deepcopy(self.common_args_dict) args_dict['data_type'] = 'fp32' for batch in [1, 8, 64, 128]: args_dict['batch_size'] = batch tf.reset_default_graph() - os.system("./bin/bert_gemm {} {} {} {} {} 0 > .tmp.gemm.log && cat .tmp.gemm.log".format(args_dict['batch_size'], args_dict['max_seq_len'], + os.system("./bin/bert_gemm {} {} {} {} {} 0 > .tmp.gemm.log && cat gemm_config.in".format(args_dict['batch_size'], args_dict['max_seq_len'], args_dict['head_number'], args_dict['size_per_head'], args_dict['data_type'] == 'fp16')) max_diff = bert_example(args_dict) self.assertTrue(max_diff < self.threshold[args_dict['data_type']]) + max_diff = encoder_example(args_dict) + self.assertTrue(max_diff < self.threshold[args_dict['data_type']]) def test_batch_fp16(self): + if self.test_level >= 2: + print(f"[INFO] test level {self.test_level}, run unit test test_batch_fp16 (level {2})") + else: + print(f"[INFO] test level {self.test_level}, skip unit test test_batch_fp16 (level {2})") + return args_dict = copy.deepcopy(self.common_args_dict) args_dict['data_type'] = 'fp16' for batch in [1, 8, 64, 128]: args_dict['batch_size'] = batch tf.reset_default_graph() - os.system("./bin/bert_gemm {} {} {} {} {} 0 > .tmp.gemm.log && cat .tmp.gemm.log".format(args_dict['batch_size'], args_dict['max_seq_len'], + os.system("./bin/bert_gemm {} {} {} {} {} 0 > .tmp.gemm.log && cat gemm_config.in".format(args_dict['batch_size'], args_dict['max_seq_len'], args_dict['head_number'], args_dict['size_per_head'], args_dict['data_type'] == 'fp16')) max_diff = bert_example(args_dict) self.assertTrue(max_diff < self.threshold[args_dict['data_type']]) + max_diff = encoder_example(args_dict) + self.assertTrue(max_diff < self.threshold[args_dict['data_type']]) + + def test_batch_bf16(self): + if self.test_level >= 2: + print(f"[INFO] test level {self.test_level}, run unit test test_batch_bf16 (level {2})") + else: + print(f"[INFO] test level {self.test_level}, skip unit test test_batch_bf16 (level {2})") + return + args_dict = copy.deepcopy(self.common_args_dict) + args_dict['data_type'] = 'bf16' + + for batch in [1, 8, 64, 128]: + args_dict['batch_size'] = batch + tf.reset_default_graph() + os.system("./bin/bert_gemm {} {} {} {} 2 0 > .tmp.gemm.log && cat gemm_config.in".format(args_dict['batch_size'], args_dict['max_seq_len'], + args_dict['head_number'], args_dict['size_per_head'])) + max_diff = bert_example(args_dict) + self.assertTrue(max_diff < self.threshold[args_dict['data_type']]) + max_diff = encoder_example(args_dict) + self.assertTrue(max_diff < self.threshold[args_dict['data_type']]) def test_size_fp32(self): + if self.test_level >= 3: + print(f"[INFO] test level {self.test_level}, run unit test 
test_size_fp32 (level {3})") + else: + print(f"[INFO] test level {self.test_level}, skip unit test test_size_fp32 (level {3})") + return args_dict = copy.deepcopy(self.common_args_dict) args_dict['data_type'] = 'fp32' args_dict['head_number'] = 8 @@ -77,13 +116,20 @@ def test_size_fp32(self): args_dict['size_per_head'] = size args_dict['inter_size'] = args_dict['head_number'] * size * 4 tf.reset_default_graph() - os.system("./bin/bert_gemm {} {} {} {} {} 0 > .tmp.gemm.log && cat .tmp.gemm.log".format(args_dict['batch_size'], args_dict['max_seq_len'], + os.system("./bin/bert_gemm {} {} {} {} {} 0 > .tmp.gemm.log && cat gemm_config.in".format(args_dict['batch_size'], args_dict['max_seq_len'], args_dict['head_number'], args_dict['size_per_head'], args_dict['data_type'] == 'fp16')) max_diff = bert_example(args_dict) self.assertTrue(max_diff < self.threshold[args_dict['data_type']]) + max_diff = encoder_example(args_dict) + self.assertTrue(max_diff < self.threshold[args_dict['data_type']]) def test_size_fp16(self): + if self.test_level >= 2: + print(f"[INFO] test level {self.test_level}, run unit test test_size_fp16 (level {2})") + else: + print(f"[INFO] test level {self.test_level}, skip unit test test_size_fp16 (level {2})") + return args_dict = copy.deepcopy(self.common_args_dict) args_dict['data_type'] = 'fp16' args_dict['head_number'] = 12 @@ -92,13 +138,41 @@ def test_size_fp16(self): args_dict['size_per_head'] = size args_dict['inter_size'] = args_dict['head_number'] * size * 4 tf.reset_default_graph() - os.system("./bin/bert_gemm {} {} {} {} {} 0 > .tmp.gemm.log && cat .tmp.gemm.log".format(args_dict['batch_size'], args_dict['max_seq_len'], + os.system("./bin/bert_gemm {} {} {} {} {} 0 > .tmp.gemm.log && cat gemm_config.in".format(args_dict['batch_size'], args_dict['max_seq_len'], args_dict['head_number'], args_dict['size_per_head'], args_dict['data_type'] == 'fp16')) max_diff = bert_example(args_dict) self.assertTrue(max_diff < self.threshold[args_dict['data_type']]) + max_diff = encoder_example(args_dict) + self.assertTrue(max_diff < self.threshold[args_dict['data_type']]) + def test_size_bf16(self): + if self.test_level >= 2: + print(f"[INFO] test level {self.test_level}, run unit test test_size_bf16 (level {2})") + else: + print(f"[INFO] test level {self.test_level}, skip unit test test_size_bf16 (level {2})") + return + args_dict = copy.deepcopy(self.common_args_dict) + args_dict['data_type'] = 'bf16' + args_dict['head_number'] = 12 + + for size in [32, 40, 64, 120, 128]: + args_dict['size_per_head'] = size + args_dict['inter_size'] = args_dict['head_number'] * size * 4 + tf.reset_default_graph() + os.system("./bin/bert_gemm {} {} {} {} 2 0 > .tmp.gemm.log && cat gemm_config.in".format(args_dict['batch_size'], args_dict['max_seq_len'], + args_dict['head_number'], args_dict['size_per_head'])) + max_diff = bert_example(args_dict) + self.assertTrue(max_diff < self.threshold[args_dict['data_type']]) + max_diff = encoder_example(args_dict) + self.assertTrue(max_diff < self.threshold[args_dict['data_type']]) + def test_head_fp32(self): + if self.test_level >= 3: + print(f"[INFO] test level {self.test_level}, run unit test test_head_fp32 (level {3})") + else: + print(f"[INFO] test level {self.test_level}, skip unit test test_head_fp32 (level {3})") + return args_dict = copy.deepcopy(self.common_args_dict) args_dict['data_type'] = 'fp32' args_dict['size_per_head'] = 64 @@ -107,13 +181,20 @@ def test_head_fp32(self): args_dict['head_number'] = h args_dict['inter_size'] = h * 
args_dict['size_per_head'] * 4 tf.reset_default_graph() - os.system("./bin/bert_gemm {} {} {} {} {} 0 > .tmp.gemm.log && cat .tmp.gemm.log".format(args_dict['batch_size'], args_dict['max_seq_len'], + os.system("./bin/bert_gemm {} {} {} {} {} 0 > .tmp.gemm.log && cat gemm_config.in".format(args_dict['batch_size'], args_dict['max_seq_len'], args_dict['head_number'], args_dict['size_per_head'], args_dict['data_type'] == 'fp16')) max_diff = bert_example(args_dict) self.assertTrue(max_diff < self.threshold[args_dict['data_type']]) + max_diff = encoder_example(args_dict) + self.assertTrue(max_diff < self.threshold[args_dict['data_type']]) def test_head_fp16(self): + if self.test_level >= 2: + print(f"[INFO] test level {self.test_level}, run unit test test_head_fp16 (level {2})") + else: + print(f"[INFO] test level {self.test_level}, skip unit test test_head_fp16 (level {2})") + return args_dict = copy.deepcopy(self.common_args_dict) args_dict['data_type'] = 'fp16' args_dict['size_per_head'] = 64 @@ -122,13 +203,41 @@ def test_head_fp16(self): args_dict['head_number'] = h args_dict['inter_size'] = h * args_dict['size_per_head'] * 4 tf.reset_default_graph() - os.system("./bin/bert_gemm {} {} {} {} {} 0 > .tmp.gemm.log && cat .tmp.gemm.log".format(args_dict['batch_size'], args_dict['max_seq_len'], + os.system("./bin/bert_gemm {} {} {} {} {} 0 > .tmp.gemm.log && cat gemm_config.in".format(args_dict['batch_size'], args_dict['max_seq_len'], args_dict['head_number'], args_dict['size_per_head'], args_dict['data_type'] == 'fp16')) max_diff = bert_example(args_dict) self.assertTrue(max_diff < self.threshold[args_dict['data_type']]) + max_diff = encoder_example(args_dict) + self.assertTrue(max_diff < self.threshold[args_dict['data_type']]) + + def test_head_bf16(self): + if self.test_level >= 2: + print(f"[INFO] test level {self.test_level}, run unit test test_head_bf16 (level {2})") + else: + print(f"[INFO] test level {self.test_level}, skip unit test test_head_bf16 (level {2})") + return + args_dict = copy.deepcopy(self.common_args_dict) + args_dict['data_type'] = 'bf16' + args_dict['size_per_head'] = 64 + + for h in [8, 12, 17, 24, 29, 32]: + args_dict['head_number'] = h + args_dict['inter_size'] = h * args_dict['size_per_head'] * 4 + tf.reset_default_graph() + os.system("./bin/bert_gemm {} {} {} {} 2 0 > .tmp.gemm.log && cat gemm_config.in".format(args_dict['batch_size'], args_dict['max_seq_len'], + args_dict['head_number'], args_dict['size_per_head'])) + max_diff = bert_example(args_dict) + self.assertTrue(max_diff < self.threshold[args_dict['data_type']]) + max_diff = encoder_example(args_dict) + self.assertTrue(max_diff < self.threshold[args_dict['data_type']]) def test_hidden_fp32(self): + if self.test_level >= 3: + print(f"[INFO] test level {self.test_level}, run unit test test_hidden_fp32 (level {3})") + else: + print(f"[INFO] test level {self.test_level}, skip unit test test_hidden_fp32 (level {3})") + return args_dict = copy.deepcopy(self.common_args_dict) args_dict['data_type'] = 'fp32' @@ -137,13 +246,20 @@ def test_hidden_fp32(self): args_dict['size_per_head'] = p[1] args_dict['inter_size'] = p[0] * p[1] * 4 tf.reset_default_graph() - os.system("./bin/bert_gemm {} {} {} {} {} 0 > .tmp.gemm.log && cat .tmp.gemm.log".format(args_dict['batch_size'], args_dict['max_seq_len'], + os.system("./bin/bert_gemm {} {} {} {} {} 0 > .tmp.gemm.log && cat gemm_config.in".format(args_dict['batch_size'], args_dict['max_seq_len'], args_dict['head_number'], args_dict['size_per_head'], args_dict['data_type'] == 
'fp16')) max_diff = bert_example(args_dict) self.assertTrue(max_diff < self.threshold[args_dict['data_type']]) + max_diff = encoder_example(args_dict) + self.assertTrue(max_diff < self.threshold[args_dict['data_type']]) def test_hidden_fp16(self): + if self.test_level >= 1: + print(f"[INFO] test level {self.test_level}, run unit test test_hidden_fp16 (level {1})") + else: + print(f"[INFO] test level {self.test_level}, skip unit test test_hidden_fp16 (level {1})") + return args_dict = copy.deepcopy(self.common_args_dict) args_dict['data_type'] = 'fp16' @@ -152,39 +268,100 @@ def test_hidden_fp16(self): args_dict['size_per_head'] = p[1] args_dict['inter_size'] = p[0] * p[1] * 4 tf.reset_default_graph() - os.system("./bin/bert_gemm {} {} {} {} {} 0 > .tmp.gemm.log && cat .tmp.gemm.log".format(args_dict['batch_size'], args_dict['max_seq_len'], + os.system("./bin/bert_gemm {} {} {} {} {} 0 > .tmp.gemm.log && cat gemm_config.in".format(args_dict['batch_size'], args_dict['max_seq_len'], args_dict['head_number'], args_dict['size_per_head'], args_dict['data_type'] == 'fp16')) max_diff = bert_example(args_dict) self.assertTrue(max_diff < self.threshold[args_dict['data_type']]) + max_diff = encoder_example(args_dict) + self.assertTrue(max_diff < self.threshold[args_dict['data_type']]) + + def test_hidden_bf16(self): + if self.test_level >= 1: + print(f"[INFO] test level {self.test_level}, run unit test test_hidden_bf16 (level {1})") + else: + print(f"[INFO] test level {self.test_level}, skip unit test test_hidden_bf16 (level {1})") + return + args_dict = copy.deepcopy(self.common_args_dict) + args_dict['data_type'] = 'bf16' + + for p in [tuple([12, 64]), tuple([16, 64]), tuple([4, 32]), tuple([8, 96]), tuple([12, 120])]: + args_dict['head_number'] = p[0] + args_dict['size_per_head'] = p[1] + args_dict['inter_size'] = p[0] * p[1] * 4 + tf.reset_default_graph() + os.system("./bin/bert_gemm {} {} {} {} 2 0 > .tmp.gemm.log && cat gemm_config.in".format(args_dict['batch_size'], args_dict['max_seq_len'], + args_dict['head_number'], args_dict['size_per_head'])) + max_diff = bert_example(args_dict) + self.assertTrue(max_diff < self.threshold[args_dict['data_type']]) + max_diff = encoder_example(args_dict) + self.assertTrue(max_diff < self.threshold[args_dict['data_type']]) def test_seqlen_fp32(self): + if self.test_level >= 3: + print(f"[INFO] test level {self.test_level}, run unit test test_seqlen_fp32 (level {3})") + else: + print(f"[INFO] test level {self.test_level}, skip unit test test_seqlen_fp32 (level {3})") + return args_dict = copy.deepcopy(self.common_args_dict) args_dict['data_type'] = 'fp32' for seqlen in [32, 130, 511, 1024, 1536]: args_dict['max_seq_len'] = seqlen tf.reset_default_graph() - os.system("./bin/bert_gemm {} {} {} {} {} 0 > .tmp.gemm.log && cat .tmp.gemm.log".format(args_dict['batch_size'], args_dict['max_seq_len'], + os.system("./bin/bert_gemm {} {} {} {} {} 0 > .tmp.gemm.log && cat gemm_config.in".format(args_dict['batch_size'], args_dict['max_seq_len'], args_dict['head_number'], args_dict['size_per_head'], args_dict['data_type'] == 'fp16')) max_diff = bert_example(args_dict) self.assertTrue(max_diff < self.threshold[args_dict['data_type']]) + max_diff = encoder_example(args_dict) + self.assertTrue(max_diff < self.threshold[args_dict['data_type']]) def test_seqlen_fp16(self): + if self.test_level >= 1: + print(f"[INFO] test level {self.test_level}, run unit test test_seqlen_fp16 (level {1})") + else: + print(f"[INFO] test level {self.test_level}, skip unit test test_seqlen_fp16 
(level {1})") + return args_dict = copy.deepcopy(self.common_args_dict) args_dict['data_type'] = 'fp16' for seqlen in [32, 130, 511, 1024, 1536]: args_dict['max_seq_len'] = seqlen tf.reset_default_graph() - os.system("./bin/bert_gemm {} {} {} {} {} 0 > .tmp.gemm.log && cat .tmp.gemm.log".format(args_dict['batch_size'], args_dict['max_seq_len'], + os.system("./bin/bert_gemm {} {} {} {} {} 0 > .tmp.gemm.log && cat gemm_config.in".format(args_dict['batch_size'], args_dict['max_seq_len'], args_dict['head_number'], args_dict['size_per_head'], args_dict['data_type'] == 'fp16')) max_diff = bert_example(args_dict) self.assertTrue(max_diff < self.threshold[args_dict['data_type']]) + max_diff = encoder_example(args_dict) + self.assertTrue(max_diff < self.threshold[args_dict['data_type']]) + + def test_seqlen_bf16(self): + if self.test_level >= 1: + print(f"[INFO] test level {self.test_level}, run unit test test_seqlen_bf16 (level {1})") + else: + print(f"[INFO] test level {self.test_level}, skip unit test test_seqlen_bf16 (level {1})") + return + args_dict = copy.deepcopy(self.common_args_dict) + args_dict['data_type'] = 'bf16' + + for seqlen in [32, 130, 511, 1024, 1536]: + args_dict['max_seq_len'] = seqlen + tf.reset_default_graph() + os.system("./bin/bert_gemm {} {} {} {} 2 0 > .tmp.gemm.log && cat gemm_config.in".format(args_dict['batch_size'], args_dict['max_seq_len'], + args_dict['head_number'], args_dict['size_per_head'])) + max_diff = bert_example(args_dict) + self.assertTrue(max_diff < self.threshold[args_dict['data_type']]) + max_diff = encoder_example(args_dict) + self.assertTrue(max_diff < self.threshold[args_dict['data_type']]) def test_large_model_fp32(self): + if self.test_level >= 3: + print(f"[INFO] test level {self.test_level}, run unit test test_large_model_fp32 (level {3})") + else: + print(f"[INFO] test level {self.test_level}, skip unit test test_large_model_fp32 (level {3})") + return args_dict = copy.deepcopy(self.common_args_dict) args_dict['data_type'] = 'fp32' args_dict['num_layer'] = 4 @@ -194,28 +371,65 @@ def test_large_model_fp32(self): args_dict['size_per_head'] = p[1] args_dict['inter_size'] = p[0] * p[1] * 4 tf.reset_default_graph() - os.system("./bin/bert_gemm {} {} {} {} {} 0 > .tmp.gemm.log && cat .tmp.gemm.log".format(args_dict['batch_size'], args_dict['max_seq_len'], + os.system("./bin/bert_gemm {} {} {} {} {} 0 > .tmp.gemm.log && cat gemm_config.in".format(args_dict['batch_size'], args_dict['max_seq_len'], args_dict['head_number'], args_dict['size_per_head'], args_dict['data_type'] == 'fp16')) max_diff = bert_example(args_dict) self.assertTrue(max_diff < self.threshold[args_dict['data_type']]) + max_diff = encoder_example(args_dict) + self.assertTrue(max_diff < self.threshold[args_dict['data_type']]) def test_large_model_fp16(self): + if self.test_level >= 2: + print(f"[INFO] test level {self.test_level}, run unit test test_large_model_fp16 (level {2})") + else: + print(f"[INFO] test level {self.test_level}, skip unit test test_large_model_fp16 (level {2})") + return args_dict = copy.deepcopy(self.common_args_dict) args_dict['data_type'] = 'fp16' args_dict['num_layer'] = 4 - threshold = 0.08 # Use larger threshold for larger model, need to check it makse sense or not + threshold = 0.08 # Use larger threshold for larger model, need to check it makes sense or not for p in [tuple([32, 64]), tuple([64, 64]), tuple([32, 128])]: args_dict['head_number'] = p[0] args_dict['size_per_head'] = p[1] args_dict['inter_size'] = p[0] * p[1] * 4 tf.reset_default_graph() - 
os.system("./bin/bert_gemm {} {} {} {} {} 0 > .tmp.gemm.log && cat .tmp.gemm.log".format(args_dict['batch_size'], args_dict['max_seq_len'], + os.system("./bin/bert_gemm {} {} {} {} {} 0 > .tmp.gemm.log && cat gemm_config.in".format(args_dict['batch_size'], args_dict['max_seq_len'], args_dict['head_number'], args_dict['size_per_head'], args_dict['data_type'] == 'fp16')) max_diff = bert_example(args_dict) self.assertTrue(max_diff < threshold) + max_diff = encoder_example(args_dict) + self.assertTrue(max_diff < threshold) + + def test_large_model_bf16(self): + if self.test_level >= 2: + print(f"[INFO] test level {self.test_level}, run unit test test_large_model_bf16 (level {2})") + else: + print(f"[INFO] test level {self.test_level}, skip unit test test_large_model_bf16 (level {2})") + return + args_dict = copy.deepcopy(self.common_args_dict) + args_dict['data_type'] = 'bf16' + args_dict['num_layer'] = 4 + threshold = 0.08 # Use larger threshold for larger model, need to check it makes sense or not + + for p in [tuple([32, 64]), tuple([64, 64]), tuple([32, 128])]: + args_dict['head_number'] = p[0] + args_dict['size_per_head'] = p[1] + args_dict['inter_size'] = p[0] * p[1] * 4 + tf.reset_default_graph() + os.system("./bin/bert_gemm {} {} {} {} 2 0 > .tmp.gemm.log && cat gemm_config.in".format(args_dict['batch_size'], args_dict['max_seq_len'], + args_dict['head_number'], args_dict['size_per_head'])) + max_diff = bert_example(args_dict) + self.assertTrue(max_diff < threshold) + max_diff = encoder_example(args_dict) + self.assertTrue(max_diff < threshold) if __name__ == "__main__": + test_level = 1 + if len(sys.argv) > 1: + test_level = sys.argv.pop() + TestEncoder.test_level = int(test_level) unittest.main() + \ No newline at end of file diff --git a/tests/bert/th_bert_unit_test.py b/tests/bert/th_bert_unit_test.py index 571a9b95b..6e65ebe34 100644 --- a/tests/bert/th_bert_unit_test.py +++ b/tests/bert/th_bert_unit_test.py @@ -21,6 +21,7 @@ dir_path = os.path.dirname(os.path.realpath(__file__)) sys.path.append(dir_path + "/../..") from examples.pytorch.bert.bert_example import bert_example +from examples.pytorch.encoder.encoder_example import encoder_example class TestEncoder(unittest.TestCase): @@ -33,15 +34,18 @@ class TestEncoder(unittest.TestCase): 'allow_gemm_test': False, 'sparse': False, 'time': False, - 'fp16': False, + 'data_type': 'fp32', 'remove_padding': False, 'avg_seq_len': -1, 'thread_num': 1, 'ths_path': 'lib/libth_bert.so', 'weight_path': None, - 'int8_mode': 0 + 'int8_mode': 0, + 'tensor_para_size': 1, + 'pipeline_para_size': 1, + 'error_threshold': None, } - threshold = {False: 3e-5, True: 4e-2 } + threshold = {'fp32': 4e-5, 'fp16': 4e-2, 'bf16': 5e-2 } def test_batch_fp32(self): args_dict = copy.deepcopy(self.common_args_dict) @@ -49,24 +53,34 @@ def test_batch_fp32(self): for batch in [1, 8, 64, 128]: args_dict['batch_size'] = batch - os.system("./bin/bert_gemm {} {} {} {} {} 0 > .tmp.gemm.log && cat .tmp.gemm.log".format(args_dict['batch_size'], args_dict['seq_len'], + os.system("./bin/bert_gemm {} {} {} {} {} 0 > .tmp.gemm.log && cat gemm_config.in".format(args_dict['batch_size'], args_dict['seq_len'], args_dict['head_num'], args_dict['head_size'], - args_dict['fp16'] == True)) + args_dict['data_type'] == 'fp16')) max_diff = bert_example(args_dict) - self.assertTrue(max_diff < self.threshold[args_dict['fp16']]) + sys.stdout.flush() + self.assertTrue(max_diff < self.threshold[args_dict['data_type']]) + args_dict['ths_path'] = 'lib/libth_encoder.so' + max_diff = 
encoder_example(args_dict) + sys.stdout.flush() + self.assertTrue(max_diff < self.threshold[args_dict['data_type']]) def test_batch_fp16(self): args_dict = copy.deepcopy(self.common_args_dict) - args_dict['fp16'] = True + args_dict['data_type'] = 'fp16' for batch in [1, 8, 64, 128]: args_dict['batch_size'] = batch - os.system("./bin/bert_gemm {} {} {} {} {} 0 > .tmp.gemm.log && cat .tmp.gemm.log".format(args_dict['batch_size'], args_dict['seq_len'], + os.system("./bin/bert_gemm {} {} {} {} {} 0 > .tmp.gemm.log && cat gemm_config.in".format(args_dict['batch_size'], args_dict['seq_len'], args_dict['head_num'], args_dict['head_size'], - args_dict['fp16'] == True)) + args_dict['data_type'] == 'fp16')) max_diff = bert_example(args_dict) - self.assertTrue(max_diff < self.threshold[args_dict['fp16']]) + sys.stdout.flush() + self.assertTrue(max_diff < self.threshold[args_dict['data_type']]) + args_dict['ths_path'] = 'lib/libth_encoder.so' + max_diff = encoder_example(args_dict) + sys.stdout.flush() + self.assertTrue(max_diff < self.threshold[args_dict['data_type']]) def test_hidden_fp32(self): args_dict = copy.deepcopy(self.common_args_dict) @@ -76,26 +90,36 @@ def test_hidden_fp32(self): args_dict['head_size'] = p[1] args_dict['inter_size'] = p[0] * p[1] * 4 - os.system("./bin/bert_gemm {} {} {} {} {} 0 > .tmp.gemm.log && cat .tmp.gemm.log".format(args_dict['batch_size'], args_dict['seq_len'], + os.system("./bin/bert_gemm {} {} {} {} {} 0 > .tmp.gemm.log && cat gemm_config.in".format(args_dict['batch_size'], args_dict['seq_len'], args_dict['head_num'], args_dict['head_size'], - args_dict['fp16'] == True)) + args_dict['data_type'] == 'fp16')) max_diff = bert_example(args_dict) - self.assertTrue(max_diff < self.threshold[args_dict['fp16']]) + sys.stdout.flush() + self.assertTrue(max_diff < self.threshold[args_dict['data_type']]) + args_dict['ths_path'] = 'lib/libth_encoder.so' + max_diff = encoder_example(args_dict) + sys.stdout.flush() + self.assertTrue(max_diff < self.threshold[args_dict['data_type']]) def test_hidden_fp16(self): args_dict = copy.deepcopy(self.common_args_dict) - args_dict['fp16'] = True + args_dict['data_type'] = 'fp16' for p in [tuple([12, 64]), tuple([16, 64]), tuple([4, 32]), tuple([8, 96])]: args_dict['head_num'] = p[0] args_dict['head_size'] = p[1] args_dict['inter_size'] = p[0] * p[1] * 4 - os.system("./bin/bert_gemm {} {} {} {} {} 0 > .tmp.gemm.log && cat .tmp.gemm.log".format(args_dict['batch_size'], args_dict['seq_len'], + os.system("./bin/bert_gemm {} {} {} {} {} 0 > .tmp.gemm.log && cat gemm_config.in".format(args_dict['batch_size'], args_dict['seq_len'], args_dict['head_num'], args_dict['head_size'], - args_dict['fp16'] == True)) + args_dict['data_type'] == 'fp16')) max_diff = bert_example(args_dict) - self.assertTrue(max_diff < self.threshold[args_dict['fp16']]) + sys.stdout.flush() + self.assertTrue(max_diff < self.threshold[args_dict['data_type']]) + args_dict['ths_path'] = 'lib/libth_encoder.so' + max_diff = encoder_example(args_dict) + sys.stdout.flush() + self.assertTrue(max_diff < self.threshold[args_dict['data_type']]) def test_seqlen_fp32(self): args_dict = copy.deepcopy(self.common_args_dict) @@ -105,26 +129,38 @@ def test_seqlen_fp32(self): if seqlen == 1536: args_dict['layer_num'] = 6 - os.system("./bin/bert_gemm {} {} {} {} {} 0 > .tmp.gemm.log && cat .tmp.gemm.log".format(args_dict['batch_size'], args_dict['seq_len'], + threshold_tmp = {'fp32': 1e-4, 'fp16': 4e-2, 'bf16': 5e-2} # The error of encoder on this test is larger + + os.system("./bin/bert_gemm 
{} {} {} {} {} 0 > .tmp.gemm.log && cat gemm_config.in".format(args_dict['batch_size'], args_dict['seq_len'], args_dict['head_num'], args_dict['head_size'], - args_dict['fp16'] == True)) + args_dict['data_type'] == 'fp16')) max_diff = bert_example(args_dict) - self.assertTrue(max_diff < self.threshold[args_dict['fp16']]) + sys.stdout.flush() + self.assertTrue(max_diff < threshold_tmp[args_dict['data_type']]) + args_dict['ths_path'] = 'lib/libth_encoder.so' + max_diff = encoder_example(args_dict) + sys.stdout.flush() + self.assertTrue(max_diff < threshold_tmp[args_dict['data_type']]) def test_seqlen_fp16(self): args_dict = copy.deepcopy(self.common_args_dict) - args_dict['fp16'] = True + args_dict['data_type'] = 'fp16' for seqlen in [32, 130, 511, 1024, 1536]: args_dict['seq_len'] = seqlen if seqlen == 1536: args_dict['layer_num'] = 6 - os.system("./bin/bert_gemm {} {} {} {} {} 0 > .tmp.gemm.log && cat .tmp.gemm.log".format(args_dict['batch_size'], args_dict['seq_len'], + os.system("./bin/bert_gemm {} {} {} {} {} 0 > .tmp.gemm.log && cat gemm_config.in".format(args_dict['batch_size'], args_dict['seq_len'], args_dict['head_num'], args_dict['head_size'], - args_dict['fp16'] == True)) + args_dict['data_type'] == 'fp16')) max_diff = bert_example(args_dict) - self.assertTrue(max_diff < self.threshold[args_dict['fp16']]) + sys.stdout.flush() + self.assertTrue(max_diff < self.threshold[args_dict['data_type']]) + args_dict['ths_path'] = 'lib/libth_encoder.so' + max_diff = encoder_example(args_dict) + sys.stdout.flush() + self.assertTrue(max_diff < self.threshold[args_dict['data_type']]) if __name__ == "__main__": unittest.main() diff --git a/tests/data/gpt_context_decoder_inputs/GPU-batch_to_compact_idx.npy b/tests/data/gpt_context_decoder_inputs/GPU-batch_to_compact_idx.npy new file mode 100644 index 000000000..9a179d70c Binary files /dev/null and b/tests/data/gpt_context_decoder_inputs/GPU-batch_to_compact_idx.npy differ diff --git a/tests/data/gpt_context_decoder_inputs/GPU-compact_idx.npy b/tests/data/gpt_context_decoder_inputs/GPU-compact_idx.npy new file mode 100644 index 000000000..9a179d70c Binary files /dev/null and b/tests/data/gpt_context_decoder_inputs/GPU-compact_idx.npy differ diff --git a/tests/data/gpt_context_decoder_inputs/GPU-context_decoder_input.npy b/tests/data/gpt_context_decoder_inputs/GPU-context_decoder_input.npy new file mode 100644 index 000000000..a14f041ae Binary files /dev/null and b/tests/data/gpt_context_decoder_inputs/GPU-context_decoder_input.npy differ diff --git a/tests/data/gpt_context_decoder_inputs/GPU-input_attention_mask.npy b/tests/data/gpt_context_decoder_inputs/GPU-input_attention_mask.npy new file mode 100644 index 000000000..8834ceb39 Binary files /dev/null and b/tests/data/gpt_context_decoder_inputs/GPU-input_attention_mask.npy differ diff --git a/tests/data/gpt_context_decoder_inputs/GPU-tiled_input_lengths.npy b/tests/data/gpt_context_decoder_inputs/GPU-tiled_input_lengths.npy new file mode 100644 index 000000000..62d097bf7 Binary files /dev/null and b/tests/data/gpt_context_decoder_inputs/GPU-tiled_input_lengths.npy differ diff --git a/tests/decoding/tf_decoding_unit_test.py b/tests/decoding/tf_decoding_unit_test.py index 654200b46..3cb9a4ecb 100644 --- a/tests/decoding/tf_decoding_unit_test.py +++ b/tests/decoding/tf_decoding_unit_test.py @@ -47,14 +47,15 @@ class TestDecoding(unittest.TestCase): "max_iteration": 10, } - def check_result(self, beam_width, datatype, test_time, topk=4, topp=0.0, batch_size=-1): - result = Value('i', -1) - p = 
Process(target=self.run_translate, args=(beam_width, datatype, test_time, topk, topp, batch_size, result)) + def check_result(self, beam_width, datatype, test_time, topk=4, topp=0.0, batch_size=-1, + decoder_bleu_score_threshold=None, decoding_bleu_score_threshold=None): + p = Process(target=self.run_translate, args=(beam_width, datatype, test_time, topk, topp, + batch_size, decoder_bleu_score_threshold, decoding_bleu_score_threshold)) p.start() p.join() - # self.assertTrue(result.value == 1) - def run_translate(self, beam_width, datatype, test_time, topk=4, topp=0.0, batch_size=-1, result=None): + def run_translate(self, beam_width, datatype, test_time, topk=4, topp=0.0, batch_size=-1, + decoder_bleu_score_threshold=None, decoding_bleu_score_threshold=None): args_dict = copy.deepcopy(self.common_args_dict) args_dict['beam_width'] = beam_width args_dict['data_type'] = datatype @@ -68,63 +69,62 @@ def run_translate(self, beam_width, datatype, test_time, topk=4, topp=0.0, batch tf.reset_default_graph() translation_result_list = translate(args_dict) - tf_bleu_score = translation_result_list[0].bleu_score.score + # translation_result_list[0] is warmup, skip it op_decoder_bleu_score = translation_result_list[1].bleu_score.score op_decoding_bleu_score = translation_result_list[2].bleu_score.score - + if decoder_bleu_score_threshold != None: + self.assertTrue(op_decoder_bleu_score >= decoder_bleu_score_threshold) + if decoding_bleu_score_threshold != None: + self.assertTrue(op_decoding_bleu_score >= decoding_bleu_score_threshold) sys.stdout.flush() - if op_decoder_bleu_score >= tf_bleu_score - 1.0 and op_decoding_bleu_score >= tf_bleu_score - 1.0: - result.value = 1 - else: - result.value = 0 - + def test_decoding_beamsearch_fp32(self): - os.system("./bin/decoding_gemm 32 4 8 64 2048 32001 128 512 0 > .tmp.gemm.log && cat .tmp.gemm.log") - self.check_result(4, 'fp32', '012', batch_size=32) + os.system("./bin/decoding_gemm 32 4 8 64 2048 32001 128 512 0 > .tmp.gemm.log && cat gemm_config.in") + self.check_result(4, 'fp32', '12', batch_size=32, decoder_bleu_score_threshold=37.0, decoding_bleu_score_threshold=37.0) def test_decoding_beamsearch_fp16(self): - os.system("./bin/decoding_gemm 32 4 8 64 2048 32001 128 512 1 > .tmp.gemm.log && cat .tmp.gemm.log") - self.check_result(4, 'fp16', '012', batch_size=32) + os.system("./bin/decoding_gemm 32 4 8 64 2048 32001 128 512 1 > .tmp.gemm.log && cat gemm_config.in") + self.check_result(4, 'fp16', '12', batch_size=32, decoder_bleu_score_threshold=37.0, decoding_bleu_score_threshold=37.0) def test_decoding_beamsearch_fp32_2(self): - os.system("./bin/decoding_gemm 16 32 8 64 2048 32001 128 512 0 > .tmp.gemm.log && cat .tmp.gemm.log") - self.check_result(32, 'fp32', '012', batch_size=16) + os.system("./bin/decoding_gemm 16 32 8 64 2048 32001 128 512 0 > .tmp.gemm.log && cat gemm_config.in") + self.check_result(32, 'fp32', '12', batch_size=16, decoder_bleu_score_threshold=35.0, decoding_bleu_score_threshold=35.0) def test_decoding_beamsearch_fp16_2(self): - os.system("./bin/decoding_gemm 16 32 8 64 2048 32001 128 512 1 > .tmp.gemm.log && cat .tmp.gemm.log") - self.check_result(32, 'fp16', '012', batch_size=16) + os.system("./bin/decoding_gemm 16 32 8 64 2048 32001 128 512 1 > .tmp.gemm.log && cat gemm_config.in") + self.check_result(32, 'fp16', '12', batch_size=16, decoder_bleu_score_threshold=35.0, decoding_bleu_score_threshold=35.0) def test_decoding_topk_sampling_fp32(self): - os.system("./bin/decoding_gemm 128 1 8 64 2048 32001 128 512 0 > .tmp.gemm.log && 
cat .tmp.gemm.log") - self.check_result(1, 'fp32', '345', 4, 0.0) + os.system("./bin/decoding_gemm 128 1 8 64 2048 32001 128 512 0 > .tmp.gemm.log && cat gemm_config.in") + self.check_result(1, 'fp32', '45', 4, 0.0, decoder_bleu_score_threshold=25.0, decoding_bleu_score_threshold=25.0) def test_decoding_topk_sampling_fp16(self): - os.system("./bin/decoding_gemm 128 1 8 64 2048 32001 128 512 1 > .tmp.gemm.log && cat .tmp.gemm.log") - self.check_result(1, 'fp16', '345', 4, 0.0) + os.system("./bin/decoding_gemm 128 1 8 64 2048 32001 128 512 1 > .tmp.gemm.log && cat gemm_config.in") + self.check_result(1, 'fp16', '45', 4, 0.0, decoder_bleu_score_threshold=25.0, decoding_bleu_score_threshold=25.0) def test_decoding_topk_sampling_fp32_2(self): - os.system("./bin/decoding_gemm 128 1 8 64 2048 32001 128 512 0 > .tmp.gemm.log && cat .tmp.gemm.log") - self.check_result(1, 'fp32', '345', 64, 0.0) + os.system("./bin/decoding_gemm 128 1 8 64 2048 32001 128 512 0 > .tmp.gemm.log && cat gemm_config.in") + self.check_result(1, 'fp32', '45', 64, 0.0, decoder_bleu_score_threshold=19.0, decoding_bleu_score_threshold=17.0) def test_decoding_topk_sampling_fp16_2(self): - os.system("./bin/decoding_gemm 128 1 8 64 2048 32001 128 512 1 > .tmp.gemm.log && cat .tmp.gemm.log") - self.check_result(1, 'fp16', '345', 64, 0.0) + os.system("./bin/decoding_gemm 128 1 8 64 2048 32001 128 512 1 > .tmp.gemm.log && cat gemm_config.in") + self.check_result(1, 'fp16', '45', 64, 0.0, decoder_bleu_score_threshold=19.0, decoding_bleu_score_threshold=17.0) def test_decoding_topp_sampling_fp32(self): - os.system("./bin/decoding_gemm 128 1 8 64 2048 32001 128 512 0 > .tmp.gemm.log && cat .tmp.gemm.log") - self.check_result(1, 'fp32', '345', 0, 0.5) + os.system("./bin/decoding_gemm 128 1 8 64 2048 32001 128 512 0 > .tmp.gemm.log && cat gemm_config.in") + self.check_result(1, 'fp32', '45', 0, 0.5, decoder_bleu_score_threshold=30.0, decoding_bleu_score_threshold=29.0) def test_decoding_topp_sampling_fp16(self): - os.system("./bin/decoding_gemm 128 1 8 64 2048 32001 128 512 1 > .tmp.gemm.log && cat .tmp.gemm.log") - self.check_result(1, 'fp16', '345', 0, 0.5) + os.system("./bin/decoding_gemm 128 1 8 64 2048 32001 128 512 1 > .tmp.gemm.log && cat gemm_config.in") + self.check_result(1, 'fp16', '45', 0, 0.5, decoder_bleu_score_threshold=30.0, decoding_bleu_score_threshold=29.0) def test_decoding_topp_sampling_fp32_2(self): - os.system("./bin/decoding_gemm 128 1 8 64 2048 32001 128 512 0 > .tmp.gemm.log && cat .tmp.gemm.log") - self.check_result(1, 'fp32', '345', 0, 0.9) + os.system("./bin/decoding_gemm 128 1 8 64 2048 32001 128 512 0 > .tmp.gemm.log && cat gemm_config.in") + self.check_result(1, 'fp32', '45', 0, 0.9, decoder_bleu_score_threshold=16.0, decoding_bleu_score_threshold=14.5) def test_decoding_topp_sampling_fp16_2(self): - os.system("./bin/decoding_gemm 128 1 8 64 2048 32001 128 512 1 > .tmp.gemm.log && cat .tmp.gemm.log") - self.check_result(1, 'fp16', '345', 0, 0.9) + os.system("./bin/decoding_gemm 128 1 8 64 2048 32001 128 512 1 > .tmp.gemm.log && cat gemm_config.in") + self.check_result(1, 'fp16', '45', 0, 0.9, decoder_bleu_score_threshold=16.0, decoding_bleu_score_threshold=14.5) if __name__ == "__main__": unittest.main() diff --git a/tests/decoding/tf_fused_self_multihead_attention_unit_test.py b/tests/decoding/tf_fused_self_multihead_attention_unit_test.py index b18145a6e..3b6ffd247 100644 --- a/tests/decoding/tf_fused_self_multihead_attention_unit_test.py +++ 
b/tests/decoding/tf_fused_self_multihead_attention_unit_test.py @@ -70,7 +70,7 @@ def run_attn(self, batch_size, seq_len, head_num, size_per_head, data_type): if data_type == tf.float16: threshold = 4e-3 # Inputs: qkv_buf and k/v cache - # Do: update k/v cahce, and compute attention (Q*K, QK*V) + # Do: update k/v cache, and compute attention (Q*K, QK*V) # Output: attention result, new k/v cache # Notes: Only used for decoder, so seqlen of q is always 1. @@ -171,4 +171,4 @@ def run_attn(self, batch_size, seq_len, head_num, size_per_head, data_type): assert(v_cache_max_diff < threshold) if __name__ == "__main__": - unittest.main() + unittest.main() \ No newline at end of file diff --git a/tests/longformer/py_longformer_unit_test.py b/tests/longformer/py_longformer_unit_test.py index a4fe0bfb9..34f76cdd1 100644 --- a/tests/longformer/py_longformer_unit_test.py +++ b/tests/longformer/py_longformer_unit_test.py @@ -61,14 +61,14 @@ def __init__(self, methodName: str) -> None: self.model_dir = "examples/pytorch/longformer/longformer-large-4096-finetuned-triviaqa" self.ft_longformer_lib = os.path.join('build', 'lib', 'libth_longformer.so') - def run_all_qa(self, seq_len, batch_size, ft_longformer, fp16): + def run_all_qa(self, seq_len, batch_size, ft_longformer, data_type): for idx in range(len(self.passage_texts)): passage_text = self.passage_texts[idx] question = self.questions[idx] answer = self.answers[idx] input_ids_b, local_attn_mask_b, global_attn_mask_b, input_ids, actual_seq_len = prepare_input( - question, passage_text, seq_len, batch_size, self.model_dir, fp16) + question, passage_text, seq_len, batch_size, self.model_dir, data_type) with torch.no_grad(): outputs = ft_longformer(input_ids_b, @@ -88,9 +88,9 @@ def test_fp32_with_qa_answer(self): ft_longformer = build_ft_longformer(self.model_dir, layer_num, head_num, size_per_head, intermediate_size, local_attn_window_size, max_global_token_num, batch_size, seq_len, - attn_scaler, self.ft_longformer_lib, fp16=False) + attn_scaler, self.ft_longformer_lib, data_type='fp32') - self.run_all_qa(seq_len, batch_size, ft_longformer, False) + self.run_all_qa(seq_len, batch_size, ft_longformer, 'fp32') def test_fp32_with_qa_answer_2(self): seq_len = 2048 @@ -103,9 +103,9 @@ def test_fp32_with_qa_answer_2(self): ft_longformer = build_ft_longformer(self.model_dir, layer_num, head_num, size_per_head, intermediate_size, local_attn_window_size, max_global_token_num, batch_size, seq_len, - attn_scaler, self.ft_longformer_lib, fp16=False) + attn_scaler, self.ft_longformer_lib, data_type='fp32') - self.run_all_qa(seq_len, batch_size, ft_longformer, False) + self.run_all_qa(seq_len, batch_size, ft_longformer, 'fp32') def test_fp32_with_qa_answer_3(self): seq_len = 4096 @@ -118,9 +118,9 @@ def test_fp32_with_qa_answer_3(self): ft_longformer = build_ft_longformer(self.model_dir, layer_num, head_num, size_per_head, intermediate_size, local_attn_window_size, max_global_token_num, batch_size, seq_len, - attn_scaler, self.ft_longformer_lib, fp16=False) + attn_scaler, self.ft_longformer_lib, data_type='fp32') - self.run_all_qa(seq_len, batch_size, ft_longformer, False) + self.run_all_qa(seq_len, batch_size, ft_longformer, 'fp32') def test_fp16_with_qa_answer(self): seq_len = 1024 @@ -133,9 +133,9 @@ def test_fp16_with_qa_answer(self): ft_longformer = build_ft_longformer(self.model_dir, layer_num, head_num, size_per_head, intermediate_size, local_attn_window_size, max_global_token_num, batch_size, seq_len, - attn_scaler, self.ft_longformer_lib, fp16=True) + 
attn_scaler, self.ft_longformer_lib, data_type='fp16') - self.run_all_qa(seq_len, batch_size, ft_longformer, True) + self.run_all_qa(seq_len, batch_size, ft_longformer, 'fp16') def test_fp16_with_qa_answer_2(self): seq_len = 1536 @@ -148,9 +148,9 @@ def test_fp16_with_qa_answer_2(self): ft_longformer = build_ft_longformer(self.model_dir, layer_num, head_num, size_per_head, intermediate_size, local_attn_window_size, max_global_token_num, batch_size, seq_len, - attn_scaler, self.ft_longformer_lib, fp16=True) + attn_scaler, self.ft_longformer_lib, data_type='fp16') - self.run_all_qa(seq_len, batch_size, ft_longformer, True) + self.run_all_qa(seq_len, batch_size, ft_longformer, 'fp16') def test_fp16_with_qa_answer_3(self): seq_len = 4096 @@ -163,9 +163,54 @@ def test_fp16_with_qa_answer_3(self): ft_longformer = build_ft_longformer(self.model_dir, layer_num, head_num, size_per_head, intermediate_size, local_attn_window_size, max_global_token_num, batch_size, seq_len, - attn_scaler, self.ft_longformer_lib, fp16=True) + attn_scaler, self.ft_longformer_lib, data_type='fp16') - self.run_all_qa(seq_len, batch_size, ft_longformer, True) + self.run_all_qa(seq_len, batch_size, ft_longformer, 'fp16') + + def test_bf16_with_qa_answer(self): + seq_len = 1024 + batch_size = 1 + max_global_token_num = 128 + + (layer_num, _, head_num, size_per_head, + intermediate_size, local_attn_window_size, attn_scaler) = parse_from_config(self.model_dir) + + ft_longformer = build_ft_longformer(self.model_dir, layer_num, head_num, size_per_head, + intermediate_size, local_attn_window_size, + max_global_token_num, batch_size, seq_len, + attn_scaler, self.ft_longformer_lib, data_type='bf16') + + self.run_all_qa(seq_len, batch_size, ft_longformer, 'bf16') + + def test_bf16_with_qa_answer_2(self): + seq_len = 1536 + batch_size = 4 + max_global_token_num = 64 + + (layer_num, _, head_num, size_per_head, + intermediate_size, local_attn_window_size, attn_scaler) = parse_from_config(self.model_dir) + + ft_longformer = build_ft_longformer(self.model_dir, layer_num, head_num, size_per_head, + intermediate_size, local_attn_window_size, + max_global_token_num, batch_size, seq_len, + attn_scaler, self.ft_longformer_lib, data_type='bf16') + + self.run_all_qa(seq_len, batch_size, ft_longformer, 'bf16') + + def test_bf16_with_qa_answer_3(self): + seq_len = 4096 + batch_size = 8 + max_global_token_num = 256 + + (layer_num, _, head_num, size_per_head, + intermediate_size, local_attn_window_size, attn_scaler) = parse_from_config(self.model_dir) + + ft_longformer = build_ft_longformer(self.model_dir, layer_num, head_num, size_per_head, + intermediate_size, local_attn_window_size, + max_global_token_num, batch_size, seq_len, + attn_scaler, self.ft_longformer_lib, data_type='bf16') + + self.run_all_qa(seq_len, batch_size, ft_longformer, 'bf16') if __name__ == "__main__": diff --git a/tests/unittests/CMakeLists.txt b/tests/unittests/CMakeLists.txt index 31d42fb7d..a6bc0fb5a 100644 --- a/tests/unittests/CMakeLists.txt +++ b/tests/unittests/CMakeLists.txt @@ -13,15 +13,40 @@ # limitations under the License. 
add_executable(test_gemm test_gemm.cu) -target_link_libraries(test_gemm PUBLIC -lcublas -lcudart -lcurand gemm cublasMMWrapper) +target_link_libraries(test_gemm PUBLIC -lcublas -lcudart -lcurand gemm cublasMMWrapper tensor) + +add_executable(test_gpt_kernels test_gpt_kernels.cu) +target_link_libraries(test_gpt_kernels PUBLIC + gpt_kernels memory_utils) add_executable(test_sampling test_sampling.cu) target_link_libraries(test_sampling PUBLIC -lcublas -lcublasLt -lcudart cublasMMWrapper memory_utils - DynamicDecodeLayer TopKSamplingLayer TopPSamplingLayer TopKTopPSamplingLayer) + DynamicDecodeLayer TopKSamplingLayer TopPSamplingLayer TopKTopPSamplingLayer tensor) add_executable(test_logprob_kernels test_logprob_kernels.cu) target_link_libraries(test_logprob_kernels PUBLIC -lcublas -lcublasLt -lcudart logprob_kernels memory_utils) + +add_executable(test_penalty_kernels test_penalty_kernels.cu) +target_link_libraries(test_penalty_kernels PUBLIC + -lcublas -lcublasLt -lcudart + sampling_penalty_kernels beam_search_penalty_kernels memory_utils) + +add_executable(test_sampling_kernels test_sampling_kernels.cu) +target_link_libraries(test_sampling_kernels PUBLIC -lcudart + sampling_topk_kernels sampling_topp_kernels memory_utils tensor) +add_executable(test_tensor test_tensor.cu) +target_link_libraries(test_tensor PUBLIC tensor) + +add_executable(test_activation test_activation.cu) +target_link_libraries(test_activation PUBLIC + -lcublas -lcublasLt -lcudart + activation_kernels memory_utils) + +add_executable(test_context_decoder_layer test_context_decoder_layer.cu) +target_link_libraries(test_context_decoder_layer PUBLIC + ParallelGpt -lcublas -lcublasLt -lcudart + memory_utils tensor) diff --git a/tests/unittests/test_activation.cu b/tests/unittests/test_activation.cu new file mode 100644 index 000000000..aa8870bf9 --- /dev/null +++ b/tests/unittests/test_activation.cu @@ -0,0 +1,120 @@ +#include // snprintf +#include // std::string +#include // std::vector + +#include "src/fastertransformer/kernels/activation_kernels.h" +#include "src/fastertransformer/utils/cuda_utils.h" +#include "src/fastertransformer/utils/memory_utils.h" + +#include "unittest_utils.h" + +using namespace fastertransformer; + +#define PRINT_LIMIT 16 +#define EPSILON (1e-20) +#define EPSILON_FP16 (1e-10) + +struct TestCase { + std::string name; + size_t m; + size_t n; + size_t ite; + + std::string toString() + { + char buf[100]; + snprintf(buf, sizeof(buf), "TestCase[name=%s, m=%ld, n=%ld]", name.c_str(), m, n); + return buf; + } + + void print() + { + FT_LOG_INFO(toString()); + } +}; + +template +void testActivationKernel(TestCase tc) +{ + const int m = tc.m; + const int n = tc.n; + cudaStream_t stream; + cudaStreamCreate(&stream); + + T *output_baseline, *output_opt1, *bias; + deviceMalloc(&output_baseline, m * n); + deviceMalloc(&output_opt1, m * n); + deviceMalloc(&bias, n); + cudaD2Dcpy(output_opt1, output_baseline, m * n); + invokeAddBiasGelu(output_baseline, bias, m, n, stream); + invokeAddBiasGeluV2(output_opt1, bias, m, n, stream); + bool passed = checkResult(tc.name, output_baseline, output_opt1, m * n, true, true); + FT_CHECK(passed); + + const int ite = tc.ite; + CudaTimer cuda_timer_baseline(stream); + // warmup + for (int i = 0; i < ite; i++) { + invokeAddBiasGelu(output_baseline, bias, m, n, stream); + } + cuda_timer_baseline.start(); + for (int i = 0; i < ite; i++) { + invokeAddBiasGelu(output_baseline, bias, m, n, stream); + } + float total_time_baseline = cuda_timer_baseline.stop(); + + CudaTimer 
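testActivationKernel in test_activation.cu times invokeAddBiasGelu and invokeAddBiasGeluV2 the usual way: a correctness check, a warmup loop, a timed loop, and then an average per-call latency in microseconds. A rough equivalent of that pattern with plain CUDA events rather than the repo's CudaTimer helper (the harness below, its name, and the assumption that launch() enqueues work on the given stream are illustrative, not project code):

#include <cuda_runtime.h>

// Time a kernel-launching callable on `stream`: warm up, then measure a
// batch of launches with CUDA events and return the mean latency in us.
template <typename F>
float averageLatencyUs(F launch, int iterations, cudaStream_t stream)
{
    for (int i = 0; i < iterations; i++) {  // warmup pass, excluded from timing
        launch();
    }
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);
    cudaEventRecord(start, stream);
    for (int i = 0; i < iterations; i++) {
        launch();
    }
    cudaEventRecord(stop, stream);
    cudaEventSynchronize(stop);
    float total_ms = 0.f;
    cudaEventElapsedTime(&total_ms, start, stop);  // elapsed time in milliseconds
    cudaEventDestroy(start);
    cudaEventDestroy(stop);
    return total_ms / iterations * 1000.f;  // average microseconds per call
}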
cuda_timer_opt(stream); + // warmup + for (int i = 0; i < ite; i++) { + invokeAddBiasGeluV2(output_baseline, bias, m, n, stream); + } + cuda_timer_opt.start(); + for (int i = 0; i < ite; i++) { + invokeAddBiasGeluV2(output_baseline, bias, m, n, stream); + } + float total_time_opt = cuda_timer_opt.stop(); + FT_LOG_INFO("%s baseline_time: %f us, opt_time: %f us, speedup: %f (ite: %d)", + tc.toString().c_str(), + total_time_baseline / ite * 1000.f, + total_time_opt / ite * 1000.f, + total_time_baseline / total_time_opt, + ite); + + deviceFree(output_baseline); + deviceFree(output_opt1); + deviceFree(bias); +} + +int main() +{ + printf("[INFO] Device: %s \n", getDeviceName().c_str()); + std::vector test_cases{ + // TC: name / m / n + TestCase{"addBiasGelu", 32, 1024, 1000}, + TestCase{"addBiasGelu", 128, 1024, 1000}, + TestCase{"addBiasGelu", 2048, 1024, 1000}, + TestCase{"addBiasGelu", 32, 3072, 1000}, + TestCase{"addBiasGelu", 128, 3072, 1000}, + TestCase{"addBiasGelu", 2048, 3072, 1000}, + TestCase{"addBiasGelu", 32, 4096, 1000}, + TestCase{"addBiasGelu", 128, 4096, 1000}, + TestCase{"addBiasGelu", 2048, 4096, 1000}, + TestCase{"addBiasGelu", 32, 8192, 1000}, + TestCase{"addBiasGelu", 128, 8192, 1000}, + TestCase{"addBiasGelu", 2048, 8192, 1000}, + TestCase{"addBiasGelu", 32, 49152, 1000}, + TestCase{"addBiasGelu", 128, 49152, 1000}, + TestCase{"addBiasGelu", 2048, 49152, 1000}, + TestCase{"addBiasGelu", 32, 81920, 1000}, + TestCase{"addBiasGelu", 128, 81920, 1000}, + TestCase{"addBiasGelu", 2048, 81920, 1000}, + }; + + for (auto& tc : test_cases) { + // testActivationKernel(tc); + testActivationKernel(tc); + } + FT_LOG_INFO("testActivationKernel done"); + + return 0; +} diff --git a/tests/unittests/test_context_decoder_layer.cu b/tests/unittests/test_context_decoder_layer.cu new file mode 100644 index 000000000..0a0f08d71 --- /dev/null +++ b/tests/unittests/test_context_decoder_layer.cu @@ -0,0 +1,263 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "src/fastertransformer/layers/DenseWeight.h" +#include "src/fastertransformer/utils/allocator.h" +#include "src/fastertransformer/utils/cublasMMWrapper.h" +#include "src/fastertransformer/utils/cuda_utils.h" +#include "src/fastertransformer/utils/gemm.h" +#include "src/fastertransformer/utils/logger.h" +#include "src/fastertransformer/utils/memory_utils.h" +#include "src/fastertransformer/utils/Tensor.h" + +#include "src/fastertransformer/kernels/gpt_kernels.h" +#include "src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.h" + +using namespace fastertransformer; + +static const char* usage = + "Usage: %s \n" + "Example: $test_context_decoder_layer ../models/megatron_models/c-model/345m/ ../tests/data\n"; + + +template +bool test_context_sharing(const std::string& weight_dir, const std::string& data_dir); +void allocate_tensors(std::vector &tensors); +void free_tensors(std::vector &tensors); +template bool all_close(Tensor &tensor_x, Tensor &tensor_y); +Tensor tensor_to_cpu(Tensor &tensor); + + +int main(int argc, const char* argv[]) +{ + if (argc != 3) { + printf(usage, argv[0]); + return EXIT_FAILURE; + } + + bool result = true; + result &= test_context_sharing( + argv[1], argv[2] + std::string("/gpt_context_decoder_inputs")); + + return result ? 
EXIT_SUCCESS: EXIT_FAILURE; +} + +template +bool test_context_sharing(const std::string& weight_dir, const std::string& data_dir) +{ + const size_t head_num = 16; + const size_t size_per_head = 64; + const size_t hidden_units = head_num * size_per_head; + const size_t inter_size = 4 * hidden_units; + const size_t decoder_layers = 2, num_layer = 2; // Reduce the number of layers for faster loading / processing + const size_t max_seq_len = 1024; + const size_t vocab_size = 50304; + /* start_id = 50256 */ + /* end_id = 50256 */ + /* weight_data_type = fp32 */ + /* tensor_para_size = 1 */ + const DataType data_type = getTensorType(); + + NcclParam tensor_para; + NcclParam pipeline_para; + + cudaStream_t stream; + cublasHandle_t cublas_handle; + cublasLtHandle_t cublaslt_handle; + check_cuda_error(cudaStreamCreate(&stream)); + check_cuda_error(cublasCreate(&cublas_handle)); + check_cuda_error(cublasLtCreate(&cublaslt_handle)); + check_cuda_error(cublasSetStream(cublas_handle, stream)); + + cublasAlgoMap cublas_algo_map(GEMM_CONFIG); + Allocator * allocator = new Allocator(getDevice()); + allocator->setStream(stream); + + std::mutex* cublas_wrapper_mutex = new std::mutex(); + cublasMMWrapper *cublas_wrapper = new cublasMMWrapper(cublas_handle, + cublaslt_handle, + stream, + &cublas_algo_map, + cublas_wrapper_mutex, + allocator); + if (std::is_same::value) { + cublas_wrapper->setGemmConfig(CUDA_R_16F, CUDA_R_16F, CUDA_R_16F, CUDA_R_32F); + } + else if (std::is_same::value) { + cublas_wrapper->setFP32GemmConfig(); + } + + ParallelGptWeight gpt_weights( + hidden_units, inter_size, vocab_size, decoder_layers, max_seq_len, + 1, // tensor_para_size + 0, // tensor_para_rank + 1, // layer_para_size + 0, // layer_para_rank + 0 // int8 + ); + gpt_weights.loadModel((weight_dir + std::string("/1-gpu")).c_str()); + + ParallelGptContextDecoder gpt_context_decoder( + 0, + 0, + head_num, + size_per_head, + inter_size, + num_layer, + 1e-5f, // layernorm_eps + gptVariantParams {}, + tensor_para, + pipeline_para, + stream, + cublas_wrapper, + allocator, + false, // is_free_buffer_after_forward + true, // is_context_qk_buf_float + false, // sparse + nullptr, // custom_all_reduce_comm + false, // enable_custom_all_reduce + false // remove_padding + ); + + /*************************** REFERENCE PART *********************************/ + + auto decoder_inputs_import = TensorMap::fromNpyFolder(data_dir); + + const size_t seq_num = decoder_inputs_import.at("context_decoder_input").shape[0]; + const size_t seq_len = decoder_inputs_import.at("context_decoder_input").shape[1]; + + std::vector decoder_inputs {decoder_inputs_import.at("context_decoder_input"), + decoder_inputs_import.at("input_attention_mask"), + decoder_inputs_import.at("tiled_input_lengths")}; + + const std::vector self_k_cache_shape = {num_layer / 1, + seq_num, + head_num, + size_per_head / (16 / sizeof(T)), + max_seq_len, + 16 / sizeof(T)}; + const std::vector self_v_cache_shape = {num_layer / 1, + seq_num, + head_num, + max_seq_len, + size_per_head}; + std::vector decoder_outputs { + Tensor{MEMORY_GPU, + data_type, + {seq_num, (size_t)seq_len, hidden_units}, + nullptr}, + Tensor{MEMORY_GPU, data_type, self_k_cache_shape, nullptr}, + Tensor{MEMORY_GPU, data_type, self_v_cache_shape, nullptr}, + Tensor{MEMORY_GPU, data_type, {seq_num, hidden_units}, nullptr}}; + + allocate_tensors(decoder_outputs); + cudaMemset((void *) decoder_outputs[1].data, 0, decoder_outputs[1].sizeBytes()); + cudaMemset((void *) decoder_outputs[2].data, 0, 
decoder_outputs[2].sizeBytes()); + + gpt_context_decoder.forward( + &decoder_outputs, + &decoder_inputs, + &gpt_weights.decoder_layer_weights + ); + + /********************************* TEST PART *********************************/ + + decoder_inputs.push_back(decoder_inputs_import.at("compact_idx")); + decoder_inputs.push_back(decoder_inputs_import.at("batch_to_compact_idx")); + + std::vector decoder_outputs_test { + {MEMORY_GPU, + data_type, + {seq_num, (size_t)seq_len, hidden_units}, + nullptr}, + {MEMORY_GPU, data_type, self_k_cache_shape, nullptr}, + {MEMORY_GPU, data_type, self_v_cache_shape, nullptr}, + {MEMORY_GPU, data_type, {seq_num, hidden_units}, nullptr}}; + + allocate_tensors(decoder_outputs_test); + cudaMemset((void *) decoder_outputs_test[1].data, 0, decoder_outputs_test[1].sizeBytes()); + cudaMemset((void *) decoder_outputs_test[2].data, 0, decoder_outputs_test[2].sizeBytes()); + + gpt_context_decoder.forward( + &decoder_outputs_test, + &decoder_inputs, + &gpt_weights.decoder_layer_weights + ); + + all_close(decoder_outputs[0], decoder_outputs_test[0]); printf("."); + all_close(decoder_outputs[3], decoder_outputs_test[3]); printf("."); + all_close(decoder_outputs[1], decoder_outputs_test[1]); printf("."); + all_close(decoder_outputs[2], decoder_outputs_test[2]); printf("."); + puts(""); + + free_tensors(decoder_outputs); + free_tensors(decoder_outputs_test); + free_tensors(decoder_inputs); + + return true; +} + +Tensor tensor_to_cpu(Tensor &tensor) +{ + FT_CHECK(tensor.where == MEMORY_GPU); + void *host_ptr = malloc(tensor.sizeBytes()); + cudaMemcpy(host_ptr, tensor.data, tensor.sizeBytes(), cudaMemcpyDeviceToHost); + + return Tensor {MEMORY_CPU, tensor.type, tensor.shape, host_ptr}; +} + +void allocate_tensors(std::vector &tensors) +{ + for (auto &tensor : tensors) { + auto size = std::accumulate(tensor.shape.begin(), tensor.shape.end(), 1, std::multiplies()); + auto size_bytes = size * Tensor::getTypeSize(tensor.type); + if (tensor.where == MEMORY_GPU) { + cudaMalloc(&tensor.data, size_bytes); + } + else { + tensor.data = malloc(size_bytes); + } + } +} + +void free_tensors(std::vector &tensors) +{ + for (auto &tensor : tensors) { + if (tensor.where == MEMORY_GPU) { + cudaFree((void *) tensor.data); + } + else { + free((void *) tensor.data); + } + tensor.data = nullptr; + } +} + +template +bool all_close(Tensor &tensor_x, Tensor &tensor_y) +{ + Tensor tensor_x_h = tensor_to_cpu(tensor_x); + Tensor tensor_y_h = tensor_to_cpu(tensor_y); + + FT_CHECK(tensor_x.size() == tensor_y.size()); + size_t n_elems = tensor_x.size(); + + const float r_tol = 1e-5; + const float a_tol = 1e-8; + for (size_t idx = 0; idx < n_elems; idx++) { + const float x_value = tensor_x_h.getPtr()[idx]; + const float y_value = tensor_y_h.getPtr()[idx]; + + FT_CHECK(fabsf(x_value - y_value) <= (a_tol + r_tol * fabsf(y_value))); + } + + free((void *) tensor_x_h.data); + free((void *) tensor_y_h.data); + + return true; +} diff --git a/tests/unittests/test_gemm.cu b/tests/unittests/test_gemm.cu index 13719f7ff..d45c447b8 100644 --- a/tests/unittests/test_gemm.cu +++ b/tests/unittests/test_gemm.cu @@ -86,7 +86,7 @@ public: ~TensorWrapper() { delete tensor; - allocator->free(data); + allocator->free((void**)(&data)); } void setInvalidValues() @@ -755,7 +755,7 @@ void testSpGemmCorrectnessMatmul(size_t m, size_t n, size_t k) { c_tensor.data); EXPECT_ALMOST_EQUAL(tc_name + " api4", T, computeType, c_tensor, expected); - allocator.free(b_compressed); + allocator.free((void**)(&b_compressed)); } 
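The element-wise comparison used by all_close above, and by the checkResult/almostEqual helpers the tests now take from unittest_utils.h, follows the numpy.isclose convention: abs(value - reference) <= atol + rtol * abs(reference), with the second argument treated as the reference that scales the relative tolerance, and NaN-vs-NaN counted as a match. A self-contained sketch using the same tolerances as all_close (the function name here is illustrative):

#include <cmath>

// Asymmetric closeness test in the numpy.isclose style; `reference`
// scales the relative tolerance, `value` is the result under test.
inline bool almost_equal(float value, float reference,
                         float atol = 1e-8f, float rtol = 1e-5f)
{
    if (std::isnan(value) && std::isnan(reference)) {
        return true;  // NaN compared against NaN is treated as equal
    }
    return std::fabs(value - reference) <= atol + rtol * std::fabs(reference);
}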
check_cuda_error(cudaStreamDestroy(stream)); } diff --git a/tests/unittests/test_gpt_kernels.cu b/tests/unittests/test_gpt_kernels.cu new file mode 100644 index 000000000..cef959078 --- /dev/null +++ b/tests/unittests/test_gpt_kernels.cu @@ -0,0 +1,338 @@ +#include +#include + +#include "src/fastertransformer/kernels/gpt_kernels.h" +#include "src/fastertransformer/utils/memory_utils.h" + +#include "unittest_utils.h" + +int test_find_context_dups(); +int test_compact(); +int test_uncompact(); + +int main(int argc, char* argv[]) +{ + bool all_passed = true; + bool passed; + + passed = test_find_context_dups() == EXIT_SUCCESS; + all_passed |= passed; + printf("%s", passed ? "." : "X"); + if (!passed) { + puts("\ntest_find_context_dups: FAILED"); + } + + passed = test_compact() == EXIT_SUCCESS; + all_passed |= passed; + printf("%s", passed ? "." : "X"); + if (!passed) { + puts("\ntest_compact: FAILED"); + } + + passed = test_uncompact() == EXIT_SUCCESS; + all_passed |= passed; + printf("%s", passed ? "." : "X"); + if (!passed) { + puts("\ntest_uncompact: FAILED"); + } + + puts(""); + return all_passed ? EXIT_SUCCESS : EXIT_FAILURE; +} + +int test_find_context_dups() +{ + const size_t vec_size = 1234; + const size_t batch_size = 8; + // Reference to the first unique vector + const std::vector shared_contexts_ref {0, 0, 2, 3, 4, 4, 3, 3}; + + // Which compact index belong to what vector + const std::vector batch_idx_to_compact_idx {0, 0, 1, 2, 3, 3, 2, 2}; + std::vector batch_idx_to_compact_idx_test(batch_size); + + // Reverse map of batch_idx_to_compact_idx + const std::vector compact_idx_to_batch_idx {0, 2, 3, 4, -1, -1, -1, -1}; + std::vector compact_idx_to_batch_idx_test(batch_size, -1); + + std::vector input_ids; + std::vector default_vector(vec_size, 0); + + for (size_t i = 0; i < batch_size; ++i) { + default_vector[vec_size - 1] = shared_contexts_ref[i]; + input_ids.insert(input_ids.end(), default_vector.begin(), default_vector.end()); + } + + std::vector shared_contexts_test(batch_size); + + int* d_input_ids; + int* d_shared_contexts_test; + int* d_batch_idx_to_compact_idx; + int* d_compact_to_batch; + int* d_compact_size; + cudaMalloc(&d_input_ids, batch_size * vec_size * sizeof(int)); + cudaMalloc(&d_shared_contexts_test, batch_size * sizeof(int)); + cudaMalloc(&d_batch_idx_to_compact_idx, batch_size * sizeof(int)); + cudaMalloc(&d_compact_to_batch, batch_size * sizeof(int)); + cudaMalloc(&d_compact_size, sizeof(int)); + + cudaH2Dcpy(d_input_ids, input_ids.data(), batch_size * vec_size); + cudaH2Dcpy(d_compact_to_batch, compact_idx_to_batch_idx_test.data(), batch_size); + + invokeFindContextDups(d_shared_contexts_test, + d_batch_idx_to_compact_idx, + d_compact_to_batch, + d_compact_size, + d_input_ids, + batch_size, + vec_size); + + int compact_size; + cudaD2Hcpy(shared_contexts_test.data(), d_shared_contexts_test, batch_size); + cudaD2Hcpy(batch_idx_to_compact_idx_test.data(), d_batch_idx_to_compact_idx, batch_size); + cudaD2Hcpy(compact_idx_to_batch_idx_test.data(), d_compact_to_batch, batch_size); + cudaD2Hcpy(&compact_size, d_compact_size, 1); + + cudaFree(d_input_ids); + cudaFree(d_shared_contexts_test); + + EXPECT_TRUE(shared_contexts_test == shared_contexts_ref); + EXPECT_TRUE(batch_idx_to_compact_idx == batch_idx_to_compact_idx_test); + EXPECT_TRUE(compact_idx_to_batch_idx_test == compact_idx_to_batch_idx); + EXPECT_TRUE(compact_size == 4); + + return EXIT_SUCCESS; +} + +int test_compact() +{ + size_t batch_size = 128; + size_t compact_size = 5; + size_t seq_len = 40; + size_t 
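test_find_context_dups hard-codes its expected maps, and they all follow from a first-occurrence rule: each batch entry points at the earliest batch entry holding an identical context, and each new unique context opens the next compact slot. The host-side reference below reproduces exactly those hard-coded vectors under that assumption; it mirrors what the test data encodes, not the kernel's actual implementation, and the names are illustrative.

#include <vector>

// Host reference producing the three maps the unit test checks:
//   shared_contexts[i]  -> first batch index whose context equals i's
//   batch_to_compact[i] -> compact slot assigned to i's context
//   compact_to_batch[c] -> representative batch index of compact slot c
struct ContextDups {
    std::vector<int> shared_contexts;
    std::vector<int> batch_to_compact;
    std::vector<int> compact_to_batch;
    int compact_size = 0;
};

inline ContextDups findContextDupsRef(const std::vector<int>& input_ids,
                                      size_t batch_size, size_t vec_size)
{
    ContextDups r;
    r.shared_contexts.assign(batch_size, 0);
    r.batch_to_compact.assign(batch_size, 0);
    r.compact_to_batch.assign(batch_size, -1);
    for (size_t i = 0; i < batch_size; ++i) {
        size_t first = i;
        for (size_t j = 0; j < i; ++j) {
            bool same = true;
            for (size_t v = 0; v < vec_size && same; ++v) {
                same = input_ids[j * vec_size + v] == input_ids[i * vec_size + v];
            }
            if (same) { first = j; break; }
        }
        r.shared_contexts[i] = static_cast<int>(first);
        if (first == i) {  // new unique context -> open the next compact slot
            r.compact_to_batch[r.compact_size] = static_cast<int>(i);
            r.batch_to_compact[i] = r.compact_size++;
        } else {           // duplicate -> reuse its representative's slot
            r.batch_to_compact[i] = r.batch_to_compact[first];
        }
    }
    return r;
}

Running this on the test's inputs (eight contexts whose last token is 0, 0, 2, 3, 4, 4, 3, 3) yields shared_contexts {0,0,2,3,4,4,3,3}, batch_to_compact {0,0,1,2,3,3,2,2}, compact_to_batch {0,2,3,4,-1,...} and compact_size 4, matching the expectations above.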
hidden_dimension = 8; + auto generator_f = std::bind(std::uniform_real_distribution(-1.0, 1.0), std::mt19937()); + auto generator_i = std::bind(std::uniform_int_distribution(0, 128), std::mt19937()); + + // decoder_input [batch_size, seq_len, hidden_dimension] -> + // compact_decoder_input [compact_size, seq_len, hidden_dimension] + std::vector decoder_input(batch_size * seq_len * hidden_dimension); + std::vector compact_decoder_input(compact_size * seq_len * hidden_dimension); + std::generate(decoder_input.begin(), decoder_input.end(), generator_f); + float *d_decoder_input, *d_compact_decoder_input; + cudaMalloc(&d_decoder_input, decoder_input.size() * sizeof(float)); + cudaMalloc(&d_compact_decoder_input, compact_decoder_input.size() * sizeof(float)); + cudaH2Dcpy(d_decoder_input, decoder_input.data(), decoder_input.size()); + + // attention_mask [batch_size, seq_len, seq_len] -> + // compact_attention_mask [compact_size, seq_len, seq_len] + std::vector attention_mask(batch_size * seq_len * seq_len); + std::vector compact_attention_mask(compact_size * seq_len * seq_len); + std::generate(attention_mask.begin(), attention_mask.end(), generator_f); + float *d_attention_mask, *d_compact_attention_mask; + cudaMalloc(&d_attention_mask, attention_mask.size() * sizeof(float)); + cudaMalloc(&d_compact_attention_mask, compact_attention_mask.size() * sizeof(float)); + cudaH2Dcpy(d_attention_mask, attention_mask.data(), attention_mask.size()); + + // input_lengths [batch_size] -> compact_input_lengths [compact_size] + std::vector input_lengths(batch_size); + std::vector compact_input_lengths(compact_size); + std::generate(input_lengths.begin(), input_lengths.end(), generator_i); + int *d_input_lengths, *d_compact_input_lengths; + cudaMalloc(&d_input_lengths, input_lengths.size() * sizeof(int)); + cudaMalloc(&d_compact_input_lengths, compact_input_lengths.size() * sizeof(int)); + cudaH2Dcpy(d_input_lengths, input_lengths.data(), input_lengths.size()); + + // compact_idx [compact_size] + /* std::vector compact_idx {0, 3}; */ + std::vector compact_idx {0, 29, 42, 44, 100}; + int *d_compact_idx; + cudaMalloc(&d_compact_idx, compact_idx.size() * sizeof(int)); + cudaH2Dcpy(d_compact_idx, compact_idx.data(), compact_idx.size()); + + invokeCompactInputs(d_compact_decoder_input, + d_compact_attention_mask, + d_compact_input_lengths, + d_decoder_input, + d_attention_mask, + d_input_lengths, + d_compact_idx, + compact_size, + seq_len, + hidden_dimension); + + cudaD2Hcpy(compact_decoder_input.data(), d_compact_decoder_input, compact_decoder_input.size()); + cudaD2Hcpy(compact_attention_mask.data(), d_compact_attention_mask, compact_attention_mask.size()); + cudaD2Hcpy(compact_input_lengths.data(), d_compact_input_lengths, compact_input_lengths.size()); + + for (size_t i = 0; i < compact_size; i++) { + for (size_t t = 0; t < seq_len; t++) { + for (size_t h = 0; h < hidden_dimension; h++) { + EXPECT_TRUE(compact_decoder_input[(i * seq_len + t) * hidden_dimension + h] == + decoder_input[(compact_idx[i] * seq_len + t) * hidden_dimension + h]); + } + } + } + + for (size_t i = 0; i < compact_size; i++) { + for (size_t t1 = 0; t1 < seq_len; t1++) { + for (size_t t2 = 0; t2 < seq_len; t2++) { + EXPECT_TRUE(compact_attention_mask[(i * seq_len + t1) * seq_len + t2] == + attention_mask[(compact_idx[i] * seq_len + t1) * seq_len + t2]); + } + } + } + + for (size_t i = 0; i < compact_size; i++) { + EXPECT_TRUE(compact_input_lengths[i] == input_lengths[compact_idx[i]]); + } + + cudaFree(d_decoder_input); + 
cudaFree(d_compact_decoder_input); + cudaFree(d_attention_mask); + cudaFree(d_compact_attention_mask); + cudaFree(d_input_lengths); + cudaFree(d_compact_input_lengths); + cudaFree(d_compact_idx); + + return EXIT_SUCCESS; +} + +int test_uncompact() +{ + // compact_decoder_outputs [compact_size, seq_len, hidden_dimension] -> + // decoder_outputs [batch_size, seq_len, hidden_dimension] + size_t batch_size = 128; + size_t compact_size = 6; + size_t local_batch_size = compact_size / 2; + size_t seq_len = 40; + size_t max_seq_len = 60; + size_t hidden_dimension = 8; + size_t num_layer = 2; + size_t num_head = 2; + size_t size_per_head = 4; + auto generator_f = std::bind(std::uniform_real_distribution(-1.0, 1.0), std::mt19937()); + auto generator_i = std::bind(std::uniform_int_distribution(0, compact_size - 1), std::mt19937()); + + std::vector compact_decoder_outputs(compact_size * seq_len * hidden_dimension); + std::vector decoder_outputs(batch_size * seq_len * hidden_dimension); + std::vector k_cache_compact(num_layer * compact_size * num_head * size_per_head * seq_len); + std::vector v_cache_compact(num_layer * compact_size * num_head * seq_len * size_per_head); + std::vector k_cache_out(num_layer * batch_size * num_head * size_per_head * max_seq_len); + std::vector v_cache_out(num_layer * batch_size * num_head * max_seq_len * size_per_head); + + std::generate(compact_decoder_outputs.begin(), compact_decoder_outputs.end(), generator_f); + std::generate(k_cache_compact.begin(), k_cache_compact.end(), generator_f); + std::generate(v_cache_compact.begin(), v_cache_compact.end(), generator_f); + + std::vector batch_to_compact_idx(batch_size); + std::generate(batch_to_compact_idx.begin(), batch_to_compact_idx.end(), generator_i); + + float *d_compact_decoder_outputs, *d_decoder_outputs, *d_k_cache, *d_v_cache; + float *d_k_cache_compact, *d_v_cache_compact; + + cudaMalloc(&d_compact_decoder_outputs, compact_decoder_outputs.size() * sizeof(float)); + cudaH2Dcpy(d_compact_decoder_outputs, compact_decoder_outputs.data(), compact_decoder_outputs.size()); + + cudaMalloc(&d_k_cache_compact, k_cache_compact.size() * sizeof(float)); + cudaMalloc(&d_v_cache_compact, v_cache_compact.size() * sizeof(float)); + cudaH2Dcpy(d_k_cache_compact, k_cache_compact.data(), k_cache_compact.size()); + cudaH2Dcpy(d_v_cache_compact, v_cache_compact.data(), v_cache_compact.size()); + + cudaMalloc(&d_k_cache, k_cache_out.size() * sizeof(float)); + cudaMalloc(&d_v_cache, v_cache_out.size() * sizeof(float)); + cudaMemset(d_k_cache, 0, k_cache_out.size() * sizeof(float)); + cudaMemset(d_v_cache, 0, v_cache_out.size() * sizeof(float)); + + cudaMalloc(&d_decoder_outputs, decoder_outputs.size() * sizeof(float)); + + int *d_batch_to_compact_idx; + cudaMalloc(&d_batch_to_compact_idx, batch_to_compact_idx.size() * sizeof(int)); + cudaH2Dcpy(d_batch_to_compact_idx, batch_to_compact_idx.data(), batch_to_compact_idx.size()); + + const size_t cache_stride_dst = max_seq_len * hidden_dimension; + const size_t cache_stride_src = seq_len * hidden_dimension; + for (size_t ite = 0; ite < (batch_size / local_batch_size); ite++) { + for (size_t l = 0; l < num_layer; l++) { + + const float *k_cache_offset = d_k_cache_compact + (l * compact_size + ite * local_batch_size) * cache_stride_src; + const float *v_cache_offset = d_v_cache_compact + (l * compact_size + ite * local_batch_size) * cache_stride_src; + + invokeUnCompactCaches(d_k_cache + l * batch_size * cache_stride_dst, + d_v_cache + l * batch_size * cache_stride_dst, + k_cache_offset, + 
v_cache_offset, + d_batch_to_compact_idx, + batch_size, + num_head, + max_seq_len, + seq_len, + size_per_head, + local_batch_size, + ite); + } + } + + invokeUnCompactOutputs(d_decoder_outputs, + d_compact_decoder_outputs, + d_batch_to_compact_idx, + batch_size, + cache_stride_src); + + cudaD2Hcpy(decoder_outputs.data(), d_decoder_outputs, decoder_outputs.size()); + cudaD2Hcpy(k_cache_out.data(), d_k_cache, k_cache_out.size()); + cudaD2Hcpy(v_cache_out.data(), d_v_cache, v_cache_out.size()); + + for (size_t i = 0; i < batch_size; i++) { + for (size_t t = 0; t < seq_len; t++) { + for (size_t h = 0; h < hidden_dimension; h++) { + EXPECT_TRUE(decoder_outputs[(i * seq_len + t) * hidden_dimension] == + compact_decoder_outputs[(batch_to_compact_idx[i] * seq_len + t) * hidden_dimension]); + } + } + } + + size_t x_size = (16 / sizeof(float)); + for (size_t l = 0; l < num_layer; l++) { + for (size_t i = 0; i < batch_size; i++) { + for (size_t h = 0; h < num_head; h++) { + for (size_t dh = 0; dh < size_per_head / x_size; dh++) { + for (size_t t = 0; t < seq_len; t++) { + for (size_t x = 0; x < x_size; x++) { + auto src = batch_to_compact_idx[i]; + EXPECT_TRUE( + k_cache_out[((((l * batch_size + i ) * num_head + h) * (size_per_head / x_size) + dh) * + max_seq_len + t) * x_size + x] == + k_cache_compact[((((l * compact_size + src) * num_head + h) * (size_per_head / x_size) + dh) * + seq_len + t) * x_size + x]); + } + } + } + } + } + } + + for (size_t l = 0; l < num_layer; l++) { + for (size_t i = 0; i < batch_size; i++) { + for (size_t h = 0; h < num_head; h++) { + for (size_t t = 0; t < seq_len; t++) { + for (size_t dh = 0; dh < size_per_head; dh++) { + auto src = batch_to_compact_idx[i]; + EXPECT_TRUE( + v_cache_out[(((l * batch_size + i ) * num_head + h) * max_seq_len + t) * size_per_head + dh] == + v_cache_compact[(((l * compact_size + src) * num_head + h) * seq_len + t) * size_per_head + dh]); + } + } + } + } + } + + cudaFree(d_compact_decoder_outputs); + cudaFree(d_k_cache_compact); + cudaFree(d_v_cache_compact); + cudaFree(d_k_cache); + cudaFree(d_v_cache); + cudaFree(d_decoder_outputs); + cudaFree(d_batch_to_compact_idx); + + return EXIT_SUCCESS; +} diff --git a/tests/unittests/test_logprob_kernels.cu b/tests/unittests/test_logprob_kernels.cu index 83d6ce672..a357aad71 100644 --- a/tests/unittests/test_logprob_kernels.cu +++ b/tests/unittests/test_logprob_kernels.cu @@ -12,32 +12,9 @@ #include "src/fastertransformer/utils/logger.h" #include "src/fastertransformer/utils/memory_utils.h" -using namespace fastertransformer; - -#define PRINT_LIMIT 16 -#define EPSILON (1e-20) - -// Can be replaced by the function provided by a test framework - -class TestFailureError : public std::exception { -private: - std::string msg_; -public: - explicit TestFailureError() = default; - explicit TestFailureError(std::string name, std::string msg = "") { - msg_ = fmtstr("TEST FAIL [%s] %s", name.c_str(), msg.c_str()); - } - const char* what () const throw () { - return msg_.c_str(); - } -}; +#include "tests/unittests/unittest_utils.h" -#define EXPECT_TRUE(cond) \ - do { if(!(cond)) { \ - FT_LOG_ERROR("TEST FAIL [%s] at %s:%d", \ - __func__, __FILE__, __LINE__); \ - throw TestFailureError(__func__); \ - } } while(false) +using namespace fastertransformer; #define EXPECT_ALMOST_EQUAL(name, dtype, ctype, out, ref) \ do { \ @@ -83,9 +60,15 @@ void computeCumLogProbs(float* cum_log_probs, const size_t vocab_size_padded) { for (size_t step = 0; step < max_input_length; ++step) { - size_t step_offset = step * batch_size * 
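The cache checks in test_uncompact spell out the layouts these tests assume: the K cache is indexed as [layer, batch, head, size_per_head / x, seq, x] with x = 16 / sizeof(T), so 16 bytes of the head dimension stay contiguous, while the V cache uses the plainer [layer, batch, head, seq, size_per_head]. A small sketch of the corresponding flattened offsets, matching the verification loops above (helper names are illustrative):

#include <cstddef>

// Flattened offsets for the cache layouts checked in test_uncompact.
//   K cache: [num_layer, batch, num_head, size_per_head / x, seq, x]
//   V cache: [num_layer, batch, num_head, seq, size_per_head]
inline size_t k_cache_index(size_t l, size_t b, size_t h, size_t dh, size_t t, size_t xi,
                            size_t batch, size_t num_head, size_t dh_blocks, size_t seq, size_t x)
{
    return ((((l * batch + b) * num_head + h) * dh_blocks + dh) * seq + t) * x + xi;
}

inline size_t v_cache_index(size_t l, size_t b, size_t h, size_t t, size_t dh,
                            size_t batch, size_t num_head, size_t seq, size_t size_per_head)
{
    return (((l * batch + b) * num_head + h) * seq + t) * size_per_head + dh;
}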
vocab_size_padded; for (size_t i = 0; i < batch_size; ++i) { - if ((int)step < input_lengths[i]) { + if ((int)step == 0) { + if (log_probs != nullptr) { + log_probs[i] = 0.0f; + } + cum_log_probs[i] = 0.0f; + } + else if ((int)step < input_lengths[i]) { + size_t step_offset = (step - 1) * batch_size * vocab_size_padded; const T* vec = logits + step_offset + i * vocab_size_padded; float max_logits = -FLT_MAX; for (size_t v = 0; v < vocab_size; ++v) { @@ -123,8 +106,14 @@ void computeCumLogProbsBatchFirst(float* cum_log_probs, for (size_t i = 0; i < batch_size; ++i) { size_t batch_offset = i * max_input_length * vocab_size_padded; for (size_t step = 0; step < max_input_length; ++step) { - if ((int)step < input_lengths[i]) { - const T* vec = logits + batch_offset + step * vocab_size_padded; + if ((int)step == 0) { + if (log_probs != nullptr) { + log_probs[i * max_input_length] = 0.0f; + } + cum_log_probs[i] = 0.0f; + } + else if ((int)step < input_lengths[i]) { + const T* vec = logits + batch_offset + (step - 1) * vocab_size_padded; float max_logits = -FLT_MAX; for (size_t v = 0; v < vocab_size; ++v) { float val = static_cast(vec[v]); @@ -147,129 +136,6 @@ void computeCumLogProbsBatchFirst(float* cum_log_probs, } } -bool almostEqual(float a, float b, float atol = 1e-5, float rtol = 1e-8) -{ - // Params: a = value to compare and b = reference - // This function follows implementation of numpy.isclose(), which checks - // abs(a - b) <= (atol + rtol * abs(b)). - // Note that the inequality above is asymmetric where b is considered as - // a reference value. To account into both absolute/relative errors, it - // uses absolute tolerance and relative tolerance at the same time. The - // default values of atol and rtol borrowed from numpy.isclose(). For the - // case of nan value, the result will be true. - if (isnan(a) && isnan(b)) { - return true; - } - return fabs(a - b) <= (atol + rtol * fabs(b)); -} - -template -bool checkResult(std::string name, T* out, T*ref, size_t size, float atol, float rtol) { - size_t failures = 0; - float relative_gap = 0.0f; - - T* h_out = reinterpret_cast(malloc(sizeof(T) * size)); - check_cuda_error(cudaMemcpy(h_out, out, sizeof(T) * size, cudaMemcpyDeviceToHost)); - - for (size_t i = 0; i < size; ++i) { - // The values for the output and the reference. - float a = (float)h_out[i]; - float b = (float)ref[i]; - - bool ok = almostEqual(a, b, atol, rtol); - // Print the error. - if (!ok && failures < 4) { - FT_LOG_ERROR(">> invalid result for i=%lu:", i); - FT_LOG_ERROR(">> found......: %10.6f", a); - FT_LOG_ERROR(">> expected...: %10.6f", b); - FT_LOG_ERROR(">> error......: %.6f", fabsf(a - b)); - FT_LOG_ERROR(">> tol........: %.6f", atol + rtol * fabs(b)); - } - // Update the number of failures. - failures += ok ? 0 : 1; - // Update the relative gap. - relative_gap += fabsf(a - b) / (fabsf(b) + EPSILON); - } - - relative_gap /= size; - - // Allow not matched up to 1% elements. - size_t tol_failures = (size_t)(0.01 * size); - FT_LOG_INFO("check.......%-30s : %s (failures: %.2f%% atol: %.2e rtol: %.2e rel_gap: %.2e%%)", - name.c_str(), failures <= tol_failures ? "OK" : "FAILED", - 100. * failures / size, atol, rtol, 100. * relative_gap); - return failures <= tol_failures; -} - -template -bool checkResult(std::string name, T* out, T* ref, size_t size) { - bool is_fp32 = sizeof(T) == 4; - float atol = is_fp32 ? 1e-6f : 1e-3f; - float rtol = is_fp32 ? 
1e-4f : 1e-1f; - bool is_ok = checkResult(name, out, ref, size, atol, rtol); - return is_ok; -} - -template -void initRandom(T* ptr, size_t size, float minval, float maxval) { - for (size_t i = 0; i < size; ++i) { - float val = static_cast(rand()) / static_cast(RAND_MAX); - val *= (maxval - minval); - ptr[i] = static_cast(minval + val); - } -} - -void initRandomInt(int* ptr, size_t size, int minval, int maxval) { - assert(minval < maxval); - int mod = maxval - minval; - for (size_t i = 0; i < size; ++i) { - ptr[i] = minval + rand() % mod; - } -} - -template -static inline void printMatrixScientificFormat(T* ptr, int m, int k, int stride, bool is_device_ptr) -{ - T* tmp; - if (is_device_ptr) { - // k < stride ; stride = col-dimension. - tmp = reinterpret_cast(malloc(m * stride * sizeof(T))); - check_cuda_error(cudaMemcpy(tmp, ptr, sizeof(T) * m * stride, cudaMemcpyDeviceToHost)); - cudaDeviceSynchronize(); - } - else { - tmp = ptr; - } - - for (int ii = -1; ii < m; ++ii) { - if (ii >= 0) { - printf("%02d ", ii); - } - else { - printf(" "); - } - - for (int jj = 0; jj < k; jj += 1) { - if (ii >= 0) { - printf("%11.4e ", (float)tmp[ii * stride + jj]); - } - else { - printf("%11d ", jj); - } - } - printf("\n"); - } - if (is_device_ptr) { - free(tmp); - } -} - -template -static inline void printMatrixWithLimit(T* ptr, int m, int k, int stride, bool is_device_ptr) { - printMatrixScientificFormat(ptr, std::min(PRINT_LIMIT, m), std::min(PRINT_LIMIT, k), stride, is_device_ptr); -} - - /////////////////////////////////// Unittests ////////////////////////////////////////// template @@ -277,7 +143,7 @@ void testCumLogProbCorrectness(TestCase tc) { size_t max_input_length = tc.max_input_length; size_t batchxbeam = tc.batch_size * tc.beam_width; size_t vocab_size = tc.vocab_size; - // Make mulitple of 8 as GPT does. + // Make multiple of 8 as GPT does. size_t vocab_size_padded = static_cast(ceil(vocab_size / 8.f) * 8); cudaStream_t stream; @@ -285,12 +151,12 @@ void testCumLogProbCorrectness(TestCase tc) { Allocator allocator(getDevice()); // input values - T* h_logits = reinterpret_cast(malloc(sizeof(T) * max_input_length * batchxbeam * vocab_size)); - int* h_input_ids = reinterpret_cast(malloc(sizeof(int) * max_input_length * batchxbeam)); - int* h_input_lengths = reinterpret_cast(malloc(sizeof(int) * batchxbeam)); + T* h_logits = new T[max_input_length * batchxbeam * vocab_size]; + int* h_input_ids = new int[max_input_length * batchxbeam]; + int* h_input_lengths = new int[batchxbeam]; // outupt buffers - float* expected_cum_log_probs = reinterpret_cast(malloc(sizeof(float) * batchxbeam)); + float* expected_cum_log_probs = new float[batchxbeam]; // initialize host buffers initRandom(h_logits, max_input_length * batchxbeam * vocab_size, -10.0f / vocab_size, -1.0f); @@ -333,19 +199,21 @@ void testCumLogProbCorrectness(TestCase tc) { batchxbeam, vocab_size, vocab_size_padded); - checkResult(tc.toString().c_str(), d_cum_log_probs, expected_cum_log_probs, batchxbeam); + std::string tag = tc.toString() + (std::is_same::value ? 
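The reference log-prob computation above was shifted by one position: the token at step t is now scored with the logits produced at step t - 1, and step 0, which has no preceding logits, contributes zero. A compact host-side sketch of that convention for a single unpadded sequence (independent of the repo's buffers and batch layout; the helper name is illustrative):

#include <algorithm>
#include <cmath>
#include <vector>

// Cumulative log probability of a token sequence given per-step logits,
// shifted as in the reference above: token[0] costs 0, token[t] is scored
// with log softmax of logits[t - 1].
inline float cumLogProbRef(const std::vector<std::vector<float>>& logits,  // [steps][vocab]
                           const std::vector<int>& token_ids)
{
    float cum = 0.0f;
    for (size_t t = 1; t < token_ids.size(); ++t) {
        const std::vector<float>& vec = logits[t - 1];
        const float max_logit = *std::max_element(vec.begin(), vec.end());
        float sum_exp = 0.0f;
        for (float v : vec) {
            sum_exp += std::exp(v - max_logit);  // subtract max for stability
        }
        // log softmax(token) = logit(token) - (max + log(sum exp(logit - max)))
        cum += vec[token_ids[t]] - (max_logit + std::log(sum_exp));
    }
    return cum;
}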
" (fp32)" : " (fp16)"); + bool passed = checkResult(tag.c_str(), d_cum_log_probs, expected_cum_log_probs, batchxbeam); + EXPECT_TRUE(passed); FT_LOG_DEBUG("free host buffers"); - free(expected_cum_log_probs); - free(h_input_lengths); - free(h_input_ids); - free(h_logits); + delete[] expected_cum_log_probs; + delete[] h_input_lengths; + delete[] h_input_ids; + delete[] h_logits; FT_LOG_DEBUG("free device buffers"); - allocator.free(d_cum_log_probs); - allocator.free(d_input_lengths); - allocator.free(d_input_ids); - allocator.free(d_logits); + allocator.free((void**)(&d_cum_log_probs)); + allocator.free((void**)(&d_input_lengths)); + allocator.free((void**)(&d_input_ids)); + allocator.free((void**)(&d_logits)); check_cuda_error(cudaStreamDestroy(stream)); } @@ -354,7 +222,7 @@ void testBatchFirstCumLogProbCorrectness(TestCase tc) { size_t max_input_length = tc.max_input_length; size_t batchxbeam = tc.batch_size * tc.beam_width; size_t vocab_size = tc.vocab_size; - // Make mulitple of 8 as GPT does. + // Make multiple of 8 as GPT does. size_t vocab_size_padded = static_cast(ceil(vocab_size / 8.f) * 8); cudaStream_t stream; @@ -362,12 +230,12 @@ void testBatchFirstCumLogProbCorrectness(TestCase tc) { Allocator allocator(getDevice()); // input values - T* h_logits = reinterpret_cast(malloc(sizeof(T) * max_input_length * batchxbeam * vocab_size_padded)); - int* h_input_ids = reinterpret_cast(malloc(sizeof(int) * max_input_length * batchxbeam)); - int* h_input_lengths = reinterpret_cast(malloc(sizeof(int) * batchxbeam)); + T* h_logits = new T[max_input_length * batchxbeam * vocab_size_padded]; + int* h_input_ids = new int[max_input_length * batchxbeam]; + int* h_input_lengths = new int[batchxbeam]; // outupt buffers - float* expected_cum_log_probs = reinterpret_cast(malloc(sizeof(float) * batchxbeam)); + float* expected_cum_log_probs = new float[batchxbeam]; // initialize host buffers initRandom(h_logits, max_input_length * batchxbeam * vocab_size_padded, -10.0f / vocab_size, -1.0f); @@ -411,19 +279,21 @@ void testBatchFirstCumLogProbCorrectness(TestCase tc) { batchxbeam, vocab_size, vocab_size_padded); - checkResult(tc.toString().c_str(), d_cum_log_probs, expected_cum_log_probs, batchxbeam); + std::string tag = tc.toString() + (std::is_same::value ? " (fp32)" : " (fp16)"); + bool passed = checkResult(tag.c_str(), d_cum_log_probs, expected_cum_log_probs, batchxbeam); + EXPECT_TRUE(passed); FT_LOG_DEBUG("free host buffers"); - free(expected_cum_log_probs); - free(h_input_lengths); - free(h_input_ids); - free(h_logits); + delete[] expected_cum_log_probs; + delete[] h_input_lengths; + delete[] h_input_ids; + delete[] h_logits; FT_LOG_DEBUG("free device buffers"); - allocator.free(d_cum_log_probs); - allocator.free(d_input_lengths); - allocator.free(d_input_ids); - allocator.free(d_logits); + allocator.free((void**)(&d_cum_log_probs)); + allocator.free((void**)(&d_input_lengths)); + allocator.free((void**)(&d_input_ids)); + allocator.free((void**)(&d_logits)); check_cuda_error(cudaStreamDestroy(stream)); } diff --git a/tests/unittests/test_penalty_kernels.cu b/tests/unittests/test_penalty_kernels.cu new file mode 100644 index 000000000..efad10d34 --- /dev/null +++ b/tests/unittests/test_penalty_kernels.cu @@ -0,0 +1,872 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include // std::min, std::max +#include // snprintf +#include // expf, log +#include +#include // rand +#include // std::string +#include // std::vector + +#include +#include +#include + +#include "src/fastertransformer/kernels/beam_search_penalty_kernels.h" +#include "src/fastertransformer/kernels/sampling_penalty_kernels.h" +#include "src/fastertransformer/utils/cuda_utils.h" + +#include "tests/unittests/unittest_utils.h" + +using namespace fastertransformer; + +struct TemperatureTestCase { + size_t batch_size; + size_t vocab_size; + float temperature; + + std::string toString() { + char buf[200]; + snprintf(buf, sizeof(buf), + "TemperatureTestCase[batch=%ld, vocab=%ld, temperature=%4.2f]", + batch_size, vocab_size, temperature); + return buf; + } + + void print() { + FT_LOG_INFO(toString()); + } +}; + +struct RepetitionTestCase { + size_t batch_size; + size_t vocab_size; + size_t max_input_length; + float repetition_penalty; + + std::string toString() { + char buf[200]; + snprintf(buf, sizeof(buf), + "RepetitionTestCase[batch=%ld, vocab=%ld, max_input_length=%ld, repetition_penalty=%4.2f]", + batch_size, vocab_size, max_input_length, repetition_penalty); + return buf; + } + + void print() { + FT_LOG_INFO(toString()); + } +}; + +size_t pad_vocab_size(size_t vocab_size, size_t pad = 8) { + return (vocab_size + pad - 1) / pad * pad; +} + +void checkTemperatureValidity(float temperature) { + if (temperature <= 0.0f) { + throw std::domain_error( + fmtstr("temperature should be positive but got %.2f.", temperature)); + } +} + +template +void applyTemperature(T* logits, + const T* bias, + const float temperature, + const size_t batch_size, + const size_t vocab_size, + const size_t vocab_size_padded) +{ + checkTemperatureValidity(temperature); + for (size_t i = 0; i < batch_size; ++i) { + for (size_t j = 0; j < vocab_size; ++j) { + size_t index = i * vocab_size_padded + j; + float logit = static_cast(logits[index]); + if (bias != nullptr) { + logit += static_cast(bias[j]); + } + logits[index] = static_cast(logit / temperature); + } + } +} + +template +void batchApplyTemperature(T* logits, + const T* bias, + const float* temperatures, + const size_t batch_size, + const size_t vocab_size, + const size_t vocab_size_padded) +{ + for (size_t i = 0; i < batch_size; ++i) { + float temperature = temperatures[i]; + checkTemperatureValidity(temperature); + for (size_t j = 0; j < vocab_size; ++j) { + size_t index = i * vocab_size_padded + j; + float logit = static_cast(logits[index]); + if (bias != nullptr) { + logit += static_cast(bias[j]); + } + logits[index] = static_cast(logit / temperature); + } + } +} + +template +void applyRepetitonPenalty(T* logits, + const int* output_ids, + const int* input_lengths, + const float repetition_penalty, + const size_t step, + const size_t max_input_length, + const size_t batch_size, + const size_t vocab_size, + const size_t vocab_size_padded) +{ + bool* penalized = new bool[vocab_size]; + for (size_t i = 0; i < batch_size; ++i) { + std::fill_n(penalized, vocab_size, false); + size_t length = std::min(step, input_lengths[i]); + size_t offset 
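The host references in this file define the temperature penalty the kernels are checked against: each logit, plus its optional bias, is divided by a strictly positive temperature, so values below 1 sharpen the distribution and values above 1 flatten it. A minimal single-row sketch of the same rule (no vocab padding or device buffers; the name is illustrative):

#include <stdexcept>
#include <vector>

// Apply a temperature penalty to one row of logits: (logit + bias) / T,
// with T required to be positive as in the validity check above.
inline void applyTemperatureRef(std::vector<float>& logits,
                                const std::vector<float>& bias,
                                float temperature)
{
    if (temperature <= 0.0f) {
        throw std::domain_error("temperature should be positive");
    }
    for (size_t j = 0; j < logits.size(); ++j) {
        const float b = bias.empty() ? 0.0f : bias[j];
        logits[j] = (logits[j] + b) / temperature;
    }
}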
= i * vocab_size_padded; + for (size_t t = 0; t < step; ++t) { + if (t >= (size_t)input_lengths[i] && t < max_input_length) { + continue; + } + int token_id = output_ids[i + t * batch_size]; + if (!penalized[token_id]) { + float logit = static_cast(logits[offset + token_id]); + logits[offset + token_id] = static_cast(logit < 0.0f ? + logit * repetition_penalty : logit / repetition_penalty); + penalized[token_id] = true; + } + } + } + delete[] penalized; +} + +template +void batchApplyRepetitonPenalty(T* logits, + const int* output_ids, + const int* input_lengths, + const float* repetition_penalties, + const size_t step, + const size_t max_input_length, + const size_t batch_size, + const size_t vocab_size, + const size_t vocab_size_padded) +{ + bool* penalized = new bool[vocab_size]; + for (size_t i = 0; i < batch_size; ++i) { + float repetition_penalty = repetition_penalties[i]; + std::fill_n(penalized, vocab_size, false); + size_t offset = i * vocab_size_padded; + for (size_t t = 0; t < step; ++t) { + if (t >= (size_t)input_lengths[i] && t < max_input_length) { + continue; + } + int token_id = output_ids[i + t * batch_size]; + if (!penalized[token_id]) { + float logit = static_cast(logits[offset + token_id]); + logits[offset + token_id] = static_cast(logit < 0.0f ? + logit * repetition_penalty : logit / repetition_penalty); + penalized[token_id] = true; + } + } + } + delete[] penalized; +} + +template +void initLogitsAndBias(T* logits, + T* bias, + const size_t batch_size, + const size_t vocab_size, + const size_t vocab_size_padded) +{ + initRandom(logits, batch_size * vocab_size_padded, -5.0f, 5.0f); + if (bias != nullptr) { + initRandom(bias, vocab_size, -5.0f, 5.0f); + } + for (size_t i = 0; i < batch_size; ++i) { + for (size_t j = 0; j < vocab_size_padded; ++j) { + if (j >= vocab_size) { + logits[i * vocab_size_padded + j] = static_cast(isHalf() ? -65504.f : -FLT_MAX); + if (bias != nullptr && i == 0) { + bias[j] = (T)0.0f; + } + } + } + } +} + + +/////////////////////////////////// Tests ////////////////////////////////////////// + +template +void testApplyTemperaturePenaltyKernel(TemperatureTestCase tc) { + // Set up test + const size_t batch_size = tc.batch_size; + const size_t vocab_size = tc.vocab_size; + const size_t vocab_size_padded = pad_vocab_size(vocab_size); + + const float temperature = tc.temperature; + T* h_logits = new T[batch_size * vocab_size_padded]; + T* h_bias = new T[vocab_size_padded]; + initLogitsAndBias(h_logits, h_bias, batch_size, vocab_size, vocab_size_padded); + + T* d_logits; + T* d_bias; + check_cuda_error(cudaMalloc(&d_logits, sizeof(T) * batch_size * vocab_size_padded)); + check_cuda_error(cudaMalloc(&d_bias, sizeof(T) * vocab_size_padded)); + check_cuda_error(cudaMemcpy(d_logits, h_logits, sizeof(T) * batch_size * vocab_size_padded, cudaMemcpyHostToDevice)); + check_cuda_error(cudaMemcpy(d_bias, h_bias, sizeof(T) * vocab_size_padded, cudaMemcpyHostToDevice)); + + cudaStream_t stream; + check_cuda_error(cudaStreamCreate(&stream)); + + // Do test + invokeApplyTemperaturePenalty(d_logits, + d_bias, + temperature, + batch_size, + vocab_size, + vocab_size_padded, + stream); + + applyTemperature(h_logits, h_bias, temperature, batch_size, vocab_size, vocab_size_padded); + std::string tag = "Correctness " + tc.toString() + (isHalf() ? 
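The repetition-penalty references above apply the penalty at most once per token that has already appeared, skipping the padding region between input_lengths[i] and max_input_length: a negative logit is multiplied by the penalty, a positive one divided by it. A single-row sketch under the assumption that the caller passes only the tokens that actually occurred (the helper name is illustrative):

#include <set>
#include <vector>

// Repetition penalty on one row of logits: every previously seen token is
// penalized exactly once, multiplying a negative logit or dividing a
// positive one by the penalty factor.
inline void applyRepetitionPenaltyRef(std::vector<float>& logits,
                                      const std::vector<int>& previous_tokens,
                                      float penalty)
{
    std::set<int> penalized;
    for (int token : previous_tokens) {
        if (penalized.insert(token).second) {  // first occurrence only
            const float logit = logits[token];
            logits[token] = logit < 0.0f ? logit * penalty : logit / penalty;
        }
    }
}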
" (FP16)" : " (FP32)"); + bool passed = checkResult(tag, d_logits, h_logits, batch_size * vocab_size_padded); + + // Tear down test + check_cuda_error(cudaStreamDestroy(stream)); + check_cuda_error(cudaFree(d_logits)); + check_cuda_error(cudaFree(d_bias)); + delete[] h_logits; + delete[] h_bias; + + EXPECT_TRUE(passed); +} + +template +void testBatchApplyTemperaturePenaltyKernel(TemperatureTestCase tc) { + // Set up test + const size_t batch_size = tc.batch_size; + const size_t vocab_size = tc.vocab_size; + const size_t vocab_size_padded = pad_vocab_size(vocab_size); + + float* h_temperatures = new float[batch_size]; + for (size_t i = 0; i < batch_size; ++i) { + h_temperatures[i] = i % 2 == 0 ? tc.temperature : 0.1f * tc.temperature; + } + T* h_logits = new T[batch_size * vocab_size_padded]; + T* h_bias = new T[vocab_size_padded]; + initLogitsAndBias(h_logits, h_bias, batch_size, vocab_size, vocab_size_padded); + + float* d_temperatures; + T* d_logits; + T* d_bias; + check_cuda_error(cudaMalloc(&d_temperatures, sizeof(float) * batch_size)); + check_cuda_error(cudaMalloc(&d_logits, sizeof(T) * batch_size * vocab_size_padded)); + check_cuda_error(cudaMalloc(&d_bias, sizeof(T) * vocab_size_padded)); + check_cuda_error(cudaMemcpy(d_temperatures, h_temperatures, sizeof(float) * batch_size, cudaMemcpyHostToDevice)); + check_cuda_error(cudaMemcpy(d_logits, h_logits, sizeof(T) * batch_size * vocab_size_padded, cudaMemcpyHostToDevice)); + check_cuda_error(cudaMemcpy(d_bias, h_bias, sizeof(T) * vocab_size_padded, cudaMemcpyHostToDevice)); + + cudaStream_t stream; + check_cuda_error(cudaStreamCreate(&stream)); + + // Do test + invokeBatchApplyTemperaturePenalty(d_logits, + d_bias, + d_temperatures, + batch_size, + vocab_size, + vocab_size_padded, + stream); + + batchApplyTemperature(h_logits, h_bias, h_temperatures, batch_size, vocab_size, vocab_size_padded); + std::string tag = "Correctness Batch " + tc.toString() + (isHalf() ? 
" (FP16)" : " (FP32)"); + bool passed = checkResult(tag, d_logits, h_logits, batch_size * vocab_size_padded); + + // Tear down test + check_cuda_error(cudaStreamDestroy(stream)); + check_cuda_error(cudaFree(d_logits)); + check_cuda_error(cudaFree(d_bias)); + check_cuda_error(cudaFree(d_temperatures)); + delete[] h_logits; + delete[] h_bias; + delete[] h_temperatures; + + EXPECT_TRUE(passed); +} + +template +void testConsistencyTemperaturePenaltyKernel(TemperatureTestCase tc) { + // Set up test + const size_t batch_size = tc.batch_size; + const size_t vocab_size = tc.vocab_size; + const size_t vocab_size_padded = pad_vocab_size(vocab_size); + + float temperature = tc.temperature; + float* h_temperatures = new float[batch_size]; + for (size_t i = 0; i < batch_size; ++i) { + h_temperatures[i] = temperature; + } + T* h_logits = new T[batch_size * vocab_size_padded]; + T* h_bias = new T[vocab_size_padded]; + initLogitsAndBias(h_logits, h_bias, batch_size, vocab_size, vocab_size_padded); + + float* d_temperatures; + check_cuda_error(cudaMalloc(&d_temperatures, sizeof(float) * batch_size)); + check_cuda_error(cudaMemcpy(d_temperatures, h_temperatures, sizeof(float) * batch_size, cudaMemcpyHostToDevice)); + + T* d_logits_single; + T* d_bias_single; + check_cuda_error(cudaMalloc(&d_logits_single, sizeof(T) * batch_size * vocab_size_padded)); + check_cuda_error(cudaMalloc(&d_bias_single, sizeof(T) * vocab_size_padded)); + check_cuda_error(cudaMemcpy(d_logits_single, h_logits, sizeof(T) * batch_size * vocab_size_padded, cudaMemcpyHostToDevice)); + check_cuda_error(cudaMemcpy(d_bias_single, h_bias, sizeof(T) * vocab_size_padded, cudaMemcpyHostToDevice)); + + T* d_logits_batch; + T* d_bias_batch; + check_cuda_error(cudaMalloc(&d_logits_batch, sizeof(T) * batch_size * vocab_size_padded)); + check_cuda_error(cudaMalloc(&d_bias_batch, sizeof(T) * vocab_size_padded)); + check_cuda_error(cudaMemcpy(d_logits_batch, h_logits, sizeof(T) * batch_size * vocab_size_padded, cudaMemcpyHostToDevice)); + check_cuda_error(cudaMemcpy(d_bias_batch, h_bias, sizeof(T) * vocab_size_padded, cudaMemcpyHostToDevice)); + + cudaStream_t stream; + check_cuda_error(cudaStreamCreate(&stream)); + + // Do test + invokeApplyTemperaturePenalty(d_logits_single, + d_bias_single, + temperature, + batch_size, + vocab_size, + vocab_size_padded, + stream); + invokeBatchApplyTemperaturePenalty(d_logits_batch, + d_bias_batch, + d_temperatures, + batch_size, + vocab_size, + vocab_size_padded, + stream); + std::string tag = "Consistency " + tc.toString() + (isHalf() ? 
" (FP16)" : " (FP32)"); + bool passed = checkResult(tag, d_logits_single, d_logits_batch, batch_size * vocab_size_padded, true, true); + + // Tear down test + check_cuda_error(cudaStreamDestroy(stream)); + check_cuda_error(cudaFree(d_logits_single)); + check_cuda_error(cudaFree(d_bias_single)); + check_cuda_error(cudaFree(d_logits_batch)); + check_cuda_error(cudaFree(d_bias_batch)); + check_cuda_error(cudaFree(d_temperatures)); + delete[] h_logits; + delete[] h_bias; + delete[] h_temperatures; + + EXPECT_TRUE(passed); +} + +template +void testApplyRepetitonPenaltyKernel(RepetitionTestCase tc) { + // Set up test + const size_t batch_size = tc.batch_size; + const size_t vocab_size = tc.vocab_size; + const size_t vocab_size_padded = pad_vocab_size(vocab_size); + const size_t max_input_length = tc.max_input_length; + const size_t sequence_length = 2 * max_input_length; // input + output + const size_t step = sequence_length * 0.5; + const float repetition_penalty = tc.repetition_penalty; + T* h_logits = new T[batch_size * vocab_size_padded]; + int* h_output_ids = new int[sequence_length * batch_size]; + int* h_input_lengths = new int[batch_size]; + initLogitsAndBias(h_logits, (T*)nullptr, batch_size, vocab_size, vocab_size_padded); + initRandomInt(h_output_ids, sequence_length * batch_size, 0, vocab_size); + initRandomInt(h_input_lengths, batch_size, 1, max_input_length); + + T* d_logits; + check_cuda_error(cudaMalloc(&d_logits, sizeof(T) * batch_size * vocab_size_padded)); + check_cuda_error(cudaMemcpy(d_logits, h_logits, sizeof(T) * batch_size * vocab_size_padded, cudaMemcpyHostToDevice)); + int* d_output_ids; + check_cuda_error(cudaMalloc(&d_output_ids, sizeof(int) * sequence_length * batch_size)); + check_cuda_error(cudaMemcpy(d_output_ids, h_output_ids, sizeof(int) * sequence_length * batch_size, cudaMemcpyHostToDevice)); + int* d_input_lengths; + check_cuda_error(cudaMalloc(&d_input_lengths, sizeof(int) * batch_size)); + check_cuda_error(cudaMemcpy(d_input_lengths, h_input_lengths, sizeof(int) * batch_size, cudaMemcpyHostToDevice)); + + cudaStream_t stream; + check_cuda_error(cudaStreamCreate(&stream)); + + // Do test + invokeApplyRepetitionPenalty(d_logits, + repetition_penalty, + nullptr, + d_output_ids, + batch_size, + batch_size, + vocab_size, + vocab_size_padded, + d_input_lengths, + max_input_length, + step, + stream); + + applyRepetitonPenalty(h_logits, + h_output_ids, + h_input_lengths, + repetition_penalty, + step, + max_input_length, + batch_size, + vocab_size, + vocab_size_padded); + + std::string tag = "Correctness " + tc.toString() + (isHalf() ? 
" (FP16)" : " (FP32)"); + bool passed = checkResult(tag, d_logits, h_logits, batch_size * vocab_size_padded); + + // Tear down test + check_cuda_error(cudaStreamDestroy(stream)); + check_cuda_error(cudaFree(d_logits)); + check_cuda_error(cudaFree(d_output_ids)); + check_cuda_error(cudaFree(d_input_lengths)); + delete[] h_logits; + delete[] h_output_ids; + delete[] h_input_lengths; + + EXPECT_TRUE(passed); +} + +template +void testBatchApplyRepetitonPenaltyKernel(RepetitionTestCase tc) { + // Set up test + const size_t batch_size = tc.batch_size; + const size_t vocab_size = tc.vocab_size; + const size_t vocab_size_padded = pad_vocab_size(vocab_size); + const size_t max_input_length = tc.max_input_length; + const size_t sequence_length = 2 * tc.max_input_length; + const size_t step = sequence_length * 0.8; + const float repetition_penalty = tc.repetition_penalty; + float* h_repetition_penalties = new float[batch_size]; + for (size_t i = 0; i < batch_size; ++i) { + h_repetition_penalties[i] = i % 2 == 0 ? repetition_penalty : 0.1f * repetition_penalty; + } + + T* h_logits = new T[batch_size * vocab_size_padded]; + int* h_output_ids = new int[sequence_length * batch_size]; + int* h_input_lengths = new int[batch_size]; + initLogitsAndBias(h_logits, (T*)nullptr, batch_size, vocab_size, vocab_size_padded); + initRandomInt(h_output_ids, sequence_length * batch_size, 0, vocab_size); + initRandomInt(h_input_lengths, batch_size, 1, max_input_length); + + T* d_logits; + check_cuda_error(cudaMalloc(&d_logits, sizeof(T) * batch_size * vocab_size_padded)); + check_cuda_error(cudaMemcpy(d_logits, h_logits, sizeof(T) * batch_size * vocab_size_padded, cudaMemcpyHostToDevice)); + int* d_output_ids; + check_cuda_error(cudaMalloc(&d_output_ids, sizeof(int) * sequence_length * batch_size)); + check_cuda_error(cudaMemcpy(d_output_ids, h_output_ids, sizeof(int) * sequence_length * batch_size, cudaMemcpyHostToDevice)); + int* d_input_lengths; + check_cuda_error(cudaMalloc(&d_input_lengths, sizeof(int) * batch_size)); + check_cuda_error(cudaMemcpy(d_input_lengths, h_input_lengths, sizeof(int) * batch_size, cudaMemcpyHostToDevice)); + float* d_repetition_penalties; + check_cuda_error(cudaMalloc(&d_repetition_penalties, sizeof(float) * batch_size)); + check_cuda_error(cudaMemcpy(d_repetition_penalties, h_repetition_penalties, sizeof(float) * batch_size, cudaMemcpyHostToDevice)); + + cudaStream_t stream; + check_cuda_error(cudaStreamCreate(&stream)); + + // Do test + invokeBatchApplyRepetitionPenalty(d_logits, + d_repetition_penalties, + d_output_ids, + batch_size, + batch_size, + vocab_size_padded, + d_input_lengths, + max_input_length, + step, + stream); + + batchApplyRepetitonPenalty(h_logits, + h_output_ids, + h_input_lengths, + h_repetition_penalties, + step, + max_input_length, + batch_size, + vocab_size, + vocab_size_padded); + + std::string tag = "Correctness Batch " + tc.toString() + (isHalf() ? 
" (FP16)" : " (FP32)"); + bool passed = checkResult(tag, d_logits, h_logits, batch_size * vocab_size_padded); + + // Tear down test + check_cuda_error(cudaStreamDestroy(stream)); + check_cuda_error(cudaFree(d_logits)); + check_cuda_error(cudaFree(d_output_ids)); + check_cuda_error(cudaFree(d_input_lengths)); + check_cuda_error(cudaFree(d_repetition_penalties)); + delete[] h_repetition_penalties; + delete[] h_logits; + delete[] h_output_ids; + delete[] h_input_lengths; + + EXPECT_TRUE(passed); +} + +template +void testBatchApplyRepetitonPenaltyKernelWithLocalBatch(RepetitionTestCase tc) { + // Set up test + const size_t batch_size = tc.batch_size; + if (batch_size % 2 != 0) { + FT_LOG_WARNING("Skip testApplyRepetitonPenaltyKernelWithLocalBatch (batch_size % 2 != 0)."); + return; + } + const size_t local_batch_size = batch_size / 2; + const size_t vocab_size = tc.vocab_size; + const size_t vocab_size_padded = pad_vocab_size(vocab_size); + const size_t max_input_length = tc.max_input_length; + const size_t sequence_length = 2 * tc.max_input_length; // input + output + const size_t step = sequence_length * 0.8; + const float repetition_penalty = tc.repetition_penalty; + float* h_repetition_penalties = new float[batch_size]; + for (size_t i = 0; i < batch_size; ++i) { + h_repetition_penalties[i] = i % 2 == 0 ? repetition_penalty : 0.1f * repetition_penalty; + } + + T* h_logits = new T[batch_size * vocab_size_padded]; + int* h_output_ids = new int[sequence_length * batch_size]; + int* h_input_lengths = new int[batch_size]; + initLogitsAndBias(h_logits, (T*)nullptr, batch_size, vocab_size, vocab_size_padded); + initRandomInt(h_output_ids, sequence_length * batch_size, 0, vocab_size); + initRandomInt(h_input_lengths, batch_size, 1, max_input_length); + + T* d_logits; + check_cuda_error(cudaMalloc(&d_logits, sizeof(T) * batch_size * vocab_size_padded)); + check_cuda_error(cudaMemcpy(d_logits, h_logits, sizeof(T) * batch_size * vocab_size_padded, cudaMemcpyHostToDevice)); + int* d_output_ids; + check_cuda_error(cudaMalloc(&d_output_ids, sizeof(int) * sequence_length * batch_size)); + check_cuda_error(cudaMemcpy(d_output_ids, h_output_ids, sizeof(int) * sequence_length * batch_size, cudaMemcpyHostToDevice)); + int* d_input_lengths; + check_cuda_error(cudaMalloc(&d_input_lengths, sizeof(int) * batch_size)); + check_cuda_error(cudaMemcpy(d_input_lengths, h_input_lengths, sizeof(int) * batch_size, cudaMemcpyHostToDevice)); + float* d_repetition_penalties; + check_cuda_error(cudaMalloc(&d_repetition_penalties, sizeof(float) * batch_size)); + check_cuda_error(cudaMemcpy(d_repetition_penalties, h_repetition_penalties, sizeof(float) * batch_size, cudaMemcpyHostToDevice)); + + cudaStream_t stream; + check_cuda_error(cudaStreamCreate(&stream)); + + // Do test + int ite = 1; + invokeBatchApplyRepetitionPenalty(d_logits + ite * local_batch_size * vocab_size_padded, + d_repetition_penalties + ite * local_batch_size, + d_output_ids + ite * local_batch_size, + batch_size, + local_batch_size, + vocab_size_padded, + d_input_lengths + ite * local_batch_size, + max_input_length, + step, + stream); + batchApplyRepetitonPenalty(h_logits, + h_output_ids, + h_input_lengths, + h_repetition_penalties, + step, + max_input_length, + batch_size, + vocab_size, + vocab_size_padded); + + std::string tag = "Correctness (local batch) " + tc.toString() + (isHalf() ? 
" (FP16)" : " (FP32)"); + bool passed = checkResult(tag, + d_logits + ite * local_batch_size * vocab_size_padded, + h_logits + ite * local_batch_size * vocab_size_padded, + local_batch_size * vocab_size_padded); + + // Tear down test + check_cuda_error(cudaStreamDestroy(stream)); + check_cuda_error(cudaFree(d_logits)); + check_cuda_error(cudaFree(d_output_ids)); + check_cuda_error(cudaFree(d_input_lengths)); + check_cuda_error(cudaFree(d_repetition_penalties)); + delete[] h_repetition_penalties; + delete[] h_logits; + delete[] h_output_ids; + delete[] h_input_lengths; + + EXPECT_TRUE(passed); +} + +template +void testConsistencyRepetitionPenaltyKernel(RepetitionTestCase tc) { + // Set up test + const size_t batch_size = tc.batch_size; + const size_t vocab_size = tc.vocab_size; + const size_t vocab_size_padded = pad_vocab_size(vocab_size); + const size_t max_input_length = tc.max_input_length; + const size_t sequence_length = 2 * max_input_length; + const size_t step = max_input_length * 0.8; + const float repetition_penalty = tc.repetition_penalty; + float* h_repetition_penalties = new float[batch_size]; + for (size_t i = 0; i < batch_size; ++i) { + h_repetition_penalties[i] = repetition_penalty; + } + + T* h_logits = new T[batch_size * vocab_size_padded]; + int* h_output_ids = new int[sequence_length * batch_size]; + int* h_input_lengths = new int[batch_size]; + initLogitsAndBias(h_logits, (T*)nullptr, batch_size, vocab_size, vocab_size_padded); + initRandomInt(h_output_ids, sequence_length * batch_size, 0, vocab_size); + initRandomInt(h_input_lengths, batch_size, 1, max_input_length); + + T* d_logits_single; + check_cuda_error(cudaMalloc(&d_logits_single, sizeof(T) * batch_size * vocab_size_padded)); + check_cuda_error(cudaMemcpy(d_logits_single, h_logits, sizeof(T) * batch_size * vocab_size_padded, cudaMemcpyHostToDevice)); + T* d_logits_batch; + check_cuda_error(cudaMalloc(&d_logits_batch, sizeof(T) * batch_size * vocab_size_padded)); + check_cuda_error(cudaMemcpy(d_logits_batch, h_logits, sizeof(T) * batch_size * vocab_size_padded, cudaMemcpyHostToDevice)); + + int* d_output_ids; + check_cuda_error(cudaMalloc(&d_output_ids, sizeof(int) * sequence_length * batch_size)); + check_cuda_error(cudaMemcpy(d_output_ids, h_output_ids, sizeof(int) * sequence_length * batch_size, cudaMemcpyHostToDevice)); + int* d_input_lengths; + check_cuda_error(cudaMalloc(&d_input_lengths, sizeof(int) * batch_size)); + check_cuda_error(cudaMemcpy(d_input_lengths, h_input_lengths, sizeof(int) * batch_size, cudaMemcpyHostToDevice)); + float* d_repetition_penalties; + check_cuda_error(cudaMalloc(&d_repetition_penalties, sizeof(float) * batch_size)); + check_cuda_error(cudaMemcpy(d_repetition_penalties, h_repetition_penalties, sizeof(float) * batch_size, cudaMemcpyHostToDevice)); + + cudaStream_t stream; + check_cuda_error(cudaStreamCreate(&stream)); + + // Do test + invokeApplyRepetitionPenalty(d_logits_single, + repetition_penalty, + nullptr, + d_output_ids, + batch_size, + batch_size, + vocab_size, + vocab_size_padded, + d_input_lengths, + max_input_length, + step, + stream); + + invokeBatchApplyRepetitionPenalty(d_logits_batch, + d_repetition_penalties, + d_output_ids, + batch_size, + batch_size, + vocab_size_padded, + d_input_lengths, + max_input_length, + step, + stream); + + std::string tag = "Consistency " + tc.toString() + (isHalf() ? 
" (FP16)" : " (FP32)"); + bool passed = checkResult(tag, d_logits_single, d_logits_batch, batch_size * vocab_size_padded, true, true); + + // Tear down test + check_cuda_error(cudaStreamDestroy(stream)); + check_cuda_error(cudaFree(d_logits_single)); + check_cuda_error(cudaFree(d_logits_batch)); + check_cuda_error(cudaFree(d_output_ids)); + check_cuda_error(cudaFree(d_input_lengths)); + check_cuda_error(cudaFree(d_repetition_penalties)); + delete[] h_logits; + delete[] h_output_ids; + delete[] h_repetition_penalties; + delete[] h_input_lengths; + EXPECT_TRUE(passed); +} + +template +void testBeamPenaltyKernelCorrectness() { + // Set up test + const size_t batch_size = 2; + const size_t beam_width = 3; + const size_t batchxbeam = batch_size * beam_width; + const size_t vocab_size = 4; + const size_t vocab_size_padded = 8; + const size_t max_input_length = 2; + const size_t local_batch_size = batch_size; + const int ite = 0; + const int step = 4; + assert(step > max_input_length); + int* h_end_ids = new int[batch_size]{0, 2}; + int* h_input_lengths = new int[batchxbeam]{2, 2, 2, 2, 2, 2}; + const T MASK_VAL = static_cast(isHalf() ? -65504.f : -FLT_MAX); + T* h_logits = new T[batchxbeam * vocab_size_padded]{ + 4.0f, -2.0f, 5.0f, 9.0f, MASK_VAL, MASK_VAL, MASK_VAL, MASK_VAL, + 4.0f, -2.0f, 5.0f, 9.0f, MASK_VAL, MASK_VAL, MASK_VAL, MASK_VAL, + 4.0f, -2.0f, 5.0f, 9.0f, MASK_VAL, MASK_VAL, MASK_VAL, MASK_VAL, + -2.0f, 1.0f, -3.0f, -2.0f, MASK_VAL, MASK_VAL, MASK_VAL, MASK_VAL, + -2.0f, 1.0f, -3.0f, -2.0f, MASK_VAL, MASK_VAL, MASK_VAL, MASK_VAL, + -2.0f, 1.0f, -3.0f, -2.0f, MASK_VAL, MASK_VAL, MASK_VAL, MASK_VAL + }; + T* h_bias = new T[vocab_size_padded]{ + 0.0f, 0.0f, 1.0f, -1.0f, 0.0f, 0.0f, 0.0f, 0.0f + }; + int* h_previous_ids = new int[(step - 1) * batchxbeam]{ + 3, 3, 2, 3, 0, 2, // step 0 [b1 b1 b1 b2 b2 b2] + 1, 2, 1, 1, 1, 2, // step 1 [b1 b1 b1 b2 b2 b2] + 2, 0, 1, 1, 2, 1, // step 2 + }; + int* h_current_ids = new int[batchxbeam]{0, 3, 1, 0, 0, 2}; // step 3. 
+ int* h_parent_ids = new int[(step - 1) * batchxbeam]{ + 0, 1, 2, 0, 1, 1, // step 0 [b1 b1 b1 b2 b2 b2] + 0, 2, 2, 2, 1, 1, // step 1 [b1 b1 b1 b2 b2 b2] + 2, 0, 1, 2, 2, 1, // step 2 + }; + // final output sequence [batch, beam] + // [0, 0]: 2 1 1 0 + // [0, 1]: 3 1 2 3 + // [0, 2]: 2 1 0 1 + // [1, 0]: 0 1 1 0 + // [1, 1]: 0 1 2 0 + // [1, 2]: 0 1 2 2 + + float temperature = 2.0f; + float repetition_penalty = 2.0f; + + T* h_expected = new T[batchxbeam * vocab_size_padded]{ + 1.0f, -2.0f, 1.5f, 4.0f, MASK_VAL, MASK_VAL, MASK_VAL, MASK_VAL, + 2.0f, -2.0f, 1.5f, 2.0f, MASK_VAL, MASK_VAL, MASK_VAL, MASK_VAL, + 1.0f, -2.0f, 1.5f, 4.0f, MASK_VAL, MASK_VAL, MASK_VAL, MASK_VAL, + -2.0f, 0.25f, -1.0f, -1.5f, MASK_VAL, MASK_VAL, MASK_VAL, MASK_VAL, + -2.0f, 0.25f, -1.0f, -1.5f, MASK_VAL, MASK_VAL, MASK_VAL, MASK_VAL, + -2.0f, 0.25f, -2.0f, -1.5f, MASK_VAL, MASK_VAL, MASK_VAL, MASK_VAL + }; + + T *d_logits, *d_bias; + check_cuda_error(cudaMalloc(&d_logits, sizeof(T) * batchxbeam * vocab_size_padded)); + check_cuda_error(cudaMemcpy( + d_logits, h_logits, sizeof(T) * batchxbeam * vocab_size_padded, cudaMemcpyHostToDevice)); + check_cuda_error(cudaMalloc(&d_bias, sizeof(T) * vocab_size_padded)); + check_cuda_error(cudaMemcpy(d_bias, h_bias, sizeof(T) * vocab_size_padded, cudaMemcpyHostToDevice)); + int *d_previous_ids, *d_current_ids, *d_parent_ids; + check_cuda_error(cudaMalloc(&d_previous_ids, sizeof(int) * (step - 1) * batchxbeam)); + check_cuda_error(cudaMemcpy( + d_previous_ids, h_previous_ids, sizeof(int) * (step - 1) * batchxbeam, cudaMemcpyHostToDevice)); + check_cuda_error(cudaMalloc(&d_current_ids, sizeof(int) * batchxbeam)); + check_cuda_error(cudaMemcpy( + d_current_ids, h_current_ids, sizeof(int) * batchxbeam, cudaMemcpyHostToDevice)); + check_cuda_error(cudaMalloc(&d_parent_ids, sizeof(int) * (step - 1) * batchxbeam)); + check_cuda_error(cudaMemcpy( + d_parent_ids, h_parent_ids, sizeof(int) * (step - 1) * batchxbeam, cudaMemcpyHostToDevice)); + int *d_end_ids; + check_cuda_error(cudaMalloc(&d_end_ids, sizeof(int) * batch_size)); + check_cuda_error(cudaMemcpy(d_end_ids, h_end_ids, sizeof(int) * batch_size, cudaMemcpyHostToDevice)); + int* d_input_lengths; + check_cuda_error(cudaMalloc(&d_input_lengths, sizeof(int) * batchxbeam)); + check_cuda_error(cudaMemcpy( + d_input_lengths, h_input_lengths, sizeof(int) * batchxbeam, cudaMemcpyHostToDevice)); + + cudaStream_t stream; + check_cuda_error(cudaStreamCreate(&stream)); + + // Do test + invokeAddBiasApplyPenalties(step, + d_logits + ite * vocab_size_padded, + d_current_ids, + d_previous_ids, + d_parent_ids, // + ite * local_batch_size * beam_width, + d_input_lengths + ite * local_batch_size * beam_width, + d_bias, + ite, + max_input_length, + local_batch_size, + batch_size, + beam_width, + vocab_size, + vocab_size_padded, + d_end_ids, + temperature, + repetition_penalty, + stream); + std::string tag = std::string("Beamsearch Penalty Kernel Correctness") + + (isHalf() ? 
" (FP16)" : " (FP32)"); + bool passed = checkResult(tag, d_logits, h_expected, batchxbeam * vocab_size_padded); + + // Tear down test + check_cuda_error(cudaStreamDestroy(stream)); + check_cuda_error(cudaFree(d_logits)); + check_cuda_error(cudaFree(d_bias)); + check_cuda_error(cudaFree(d_current_ids)); + check_cuda_error(cudaFree(d_previous_ids)); + check_cuda_error(cudaFree(d_parent_ids)); + check_cuda_error(cudaFree(d_input_lengths)); + check_cuda_error(cudaFree(d_end_ids)); + delete[] h_logits; + delete[] h_bias; + delete[] h_current_ids; + delete[] h_previous_ids; + delete[] h_parent_ids; + delete[] h_input_lengths; + delete[] h_end_ids; + EXPECT_TRUE(passed); +} + +int main() { + std::vector temperature_test_cases { + // TC: name / batch / vocab / temperature / repetition + {6, 4, 0.53f}, + {6, 4, 1.0f}, + {6, 4, 2.01f}, + {6, 50001, 2.01f}, + {128, 51200, 2.01f} + }; + + for (auto &tc : temperature_test_cases) { + testApplyTemperaturePenaltyKernel(tc); + testApplyTemperaturePenaltyKernel(tc); + testBatchApplyTemperaturePenaltyKernel(tc); + testBatchApplyTemperaturePenaltyKernel(tc); + testConsistencyTemperaturePenaltyKernel(tc); + testConsistencyTemperaturePenaltyKernel(tc); + } + FT_LOG_INFO("test TemperaturePenaltyKernel done"); + + std::vector repetition_test_cases { + {6, 4, 10, 0.53f}, + {6, 4, 10, 1.0f}, + {6, 4, 10, 2.01f}, + {6, 50001, 10, 2.01f}, + {128, 51200, 1024, 2.01f}, + {128, 51200, 2048, 2.01f} + }; + for (auto& tc : repetition_test_cases) { + testApplyRepetitonPenaltyKernel(tc); + testApplyRepetitonPenaltyKernel(tc); + testBatchApplyRepetitonPenaltyKernel(tc); + testBatchApplyRepetitonPenaltyKernel(tc); + testBatchApplyRepetitonPenaltyKernelWithLocalBatch(tc); + testBatchApplyRepetitonPenaltyKernelWithLocalBatch(tc); + testConsistencyRepetitionPenaltyKernel(tc); + testConsistencyRepetitionPenaltyKernel(tc); + } + FT_LOG_INFO("test RepetitionPenaltyKernel done"); + + testBeamPenaltyKernelCorrectness(); + testBeamPenaltyKernelCorrectness(); + FT_LOG_INFO("test BeamPenaltyKernelCorrectness done"); + + return 0; +} diff --git a/tests/unittests/test_sampling.cu b/tests/unittests/test_sampling.cu index aa8359b54..94513d4a9 100644 --- a/tests/unittests/test_sampling.cu +++ b/tests/unittests/test_sampling.cu @@ -9,6 +9,7 @@ #include #include +#include "src/fastertransformer/kernels/sampling_topk_kernels.h" #include "src/fastertransformer/layers/DynamicDecodeLayer.h" #include "src/fastertransformer/layers/sampling_layers/TopKSamplingLayer.h" #include "src/fastertransformer/utils/cublasMMWrapper.h" @@ -16,13 +17,9 @@ #include "src/fastertransformer/utils/memory_utils.h" #include "src/fastertransformer/utils/Tensor.h" -// namespace ft = fastertransformer; -using namespace fastertransformer; - +#include "tests/unittests/unittest_utils.h" -#define PRINT_LIMIT 16 -#define EPSILON (1e-20) -#define EPSILON_FP16 (1e-10) +using namespace fastertransformer; struct TestCase { std::string name; @@ -46,80 +43,6 @@ struct TestCase { } }; -bool almostEqual(float a, float b, float atol = 1e-5, float rtol = 1e-8) -{ - // Params: a = value to compare and b = reference - // This function follows implementation of numpy.isclose(), which checks - // abs(a - b) <= (atol + rtol * abs(b)). - // Note that the inequality above is asymmetric where b is considered as - // a reference value. To account into both absolute/relative errors, it - // uses absolute tolerance and relative tolerance at the same time. The - // default values of atol and rtol borrowed from numpy.isclose(). 
For the - // case of nan value, the result will be true. - if (isnan(a) && isnan(b)) { - return true; - } - return fabs(a - b) <= (atol + rtol * fabs(b)); -} - -template -bool checkResult(std::string name, T* out, T*ref, size_t size, float atol, float rtol) { - size_t failures = 0; - float relative_gap = 0.0f; - - T* h_out = reinterpret_cast(malloc(sizeof(T) * size)); - cudaMemcpy(h_out, out, sizeof(T) * size, cudaMemcpyDeviceToHost); - - for (size_t i = 0; i < size; ++i) { - // The values for the output and the reference. - float a = (float)h_out[i]; - float b = (float)ref[i]; - - bool ok = almostEqual(a, b, atol, rtol); - // Print the error. - if (!ok && failures < 4) { - FT_LOG_ERROR(">> invalid result for i=%lu:", i); - FT_LOG_ERROR(">> found......: %10.6f", a); - FT_LOG_ERROR(">> expected...: %10.6f", b); - FT_LOG_ERROR(">> error......: %.6f", fabsf(a - b)); - FT_LOG_ERROR(">> tol........: %.6f", atol + rtol * fabs(b)); - } - // Update the number of failures. - failures += ok ? 0 : 1; - // Update the relative gap. - relative_gap += fabsf(a - b) / (fabsf(b) + EPSILON); - } - - relative_gap /= size; - - // Allow not matched up to 1% elements. - size_t tol_failures = (size_t)(0.01 * size); - FT_LOG_INFO("check.......%-30s : %s (failures: %.2f%% atol: %.2e rtol: %.2e rel_gap: %.2e%%)", - name.c_str(), failures <= tol_failures ? "OK" : "FAILED", - 100. * failures / size, atol, rtol, 100. * relative_gap); - return failures <= tol_failures; -} - -template -bool checkResult(std::string name, T* out, T* ref, size_t size) { - bool is_fp32 = sizeof(T) == 4; - // float atol = is_fp32 ? 1e-6f : 1e-3f; - // float rtol = is_fp32 ? 1e-4f : 1e-1f; - float atol = is_fp32 ? 1e-4f : 1e-3f; - float rtol = is_fp32 ? 1e-2f : 1e-1f; - bool is_ok = checkResult(name, out, ref, size, atol, rtol); - return is_ok; -} - -template -void initRandom(T* ptr, size_t size, float minval, float maxval) { - for (size_t i = 0; i < size; ++i) { - float val = static_cast(rand()) / static_cast(RAND_MAX); - val *= (maxval - minval); - ptr[i] = static_cast(minval + val); - } -} - template void computeProb(T* probs, T* logits, int batch_size, int vocab_size) { // Compute the log probability from logits. @@ -154,60 +77,20 @@ void computeLogProb(T* logprobs, T* logits, int batch_size, int vocab_size) { } } -template -static inline void printMatrixHightPrecision(T* ptr, int m, int k, int stride, bool is_device_ptr) -{ - T* tmp; - if (is_device_ptr) { - // k < stride ; stride = col-dimension. 
- tmp = reinterpret_cast(malloc(m * stride * sizeof(T))); - check_cuda_error(cudaMemcpy(tmp, ptr, sizeof(T) * m * stride, cudaMemcpyDeviceToHost)); - cudaDeviceSynchronize(); - } - else { - tmp = ptr; - } - - for (int ii = -1; ii < m; ++ii) { - if (ii >= 0) { - printf("%02d ", ii); - } - else { - printf(" "); - } - - for (int jj = 0; jj < k; jj += 1) { - if (ii >= 0) { - printf("%9.6f ", (float)tmp[ii * stride + jj]); - } - else { - printf("%9d ", jj); - } - } - printf("\n"); - } - if (is_device_ptr) { - free(tmp); - } -} - -template -static inline void printMatrixWithLimit(T* ptr, int m, int k, int stride, bool is_device_ptr) { - printMatrixHightPrecision(ptr, std::min(PRINT_LIMIT, m), std::min(PRINT_LIMIT, k), stride, is_device_ptr); -} +/////////////////////////////////// Tests ////////////////////////////////////////// template -void testDynamicDecoingLayer(TestCase tc) { +void testCumLogProbComputation(TestCase tc) { bool is_fp32 = std::is_same::value; size_t beam_width = tc.beam_width; - size_t top_k = tc.top_k; + uint top_k = tc.top_k; float top_p = tc.top_p; unsigned long long seed = 0; // use default values having no effect. float temperature = 1.0f; - float len_penalty = 1.0f; + float len_penalty = 0.0f; float repetition_penalty = 1.0f; size_t batch_size = tc.batch_size; @@ -234,33 +117,32 @@ void testDynamicDecoingLayer(TestCase tc) { std::mutex* cublas_wrapper_mutex = new std::mutex(); cublasMMWrapper *cublas_wrapper = new cublasMMWrapper(cublas_handle, - cublaslt_handle, - stream, - &cublas_algo_map, - cublas_wrapper_mutex, - allocator); - - DynamicDecodeLayer dynamic_decode_layer(vocab_size, - vocab_size, - end_id, - stream, - cublas_wrapper, - allocator, - false, // is_free_buffer_after_forward - &prop); // cuda_device_prop + cublaslt_handle, + stream, + &cublas_algo_map, + cublas_wrapper_mutex, + allocator); + + DynamicDecodeLayer *dynamic_decode_layer = new DynamicDecodeLayer(vocab_size, + vocab_size, + end_id, + stream, + cublas_wrapper, + allocator, + false, // is_free_buffer_after_forward + &prop); // cuda_device_prop const DataType data_type = getTensorType(); size_t logits_size = batch_size * beam_width * vocab_size; T* logits_buf = reinterpret_cast(allocator->malloc(sizeof(T) * logits_size, true)); // Logit values in the host of shape ((batch_size x beam) x vocab_size) where beam = 1. 
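    // (Editor's note, not part of the patch.) The reference for this renamed test appears to be
    // the cumulative log-probability of the sampled ids: h_log_probs below is log(softmax(h_logits)),
    // and expected_cum_log_probs[b] is assumed to accumulate the entry of h_log_probs corresponding
    // to the id generated for batch entry b at each step. That host-side sum is what the final
    // checkResult compares against the cum_log_probs tensor produced by DynamicDecodeLayer.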
- T* h_logits = reinterpret_cast(malloc(sizeof(T) * batch_size * beam_width * vocab_size)); - T* h_probs = reinterpret_cast(malloc(sizeof(T) * batch_size * beam_width * vocab_size)); - T* h_log_probs = reinterpret_cast(malloc(sizeof(T) * batch_size * beam_width * vocab_size)); - float* h_cum_log_probs = reinterpret_cast(malloc(sizeof(float) * batch_size * beam_width)); - float* h_output_log_probs = reinterpret_cast( - malloc(sizeof(float) * max_output_len * batch_size * beam_width)); - float* expected_cum_log_probs = reinterpret_cast(malloc(sizeof(float) * batch_size * beam_width)); + T* h_logits = new T[batch_size * beam_width * vocab_size]; + T* h_probs = new T[batch_size * beam_width * vocab_size]; + T* h_log_probs = new T[batch_size * beam_width * vocab_size]; + float* h_cum_log_probs = new float[batch_size * beam_width]; + float* h_output_log_probs = new float[max_output_len * batch_size * beam_width]; + float* expected_cum_log_probs = new float[batch_size * beam_width]; initRandom(h_logits, batch_size * beam_width * vocab_size, -10.0f / vocab_size, -1.0f); computeProb(h_probs, h_logits, batch_size * beam_width, vocab_size); computeLogProb(h_log_probs, h_logits, batch_size * beam_width, vocab_size); @@ -279,21 +161,30 @@ void testDynamicDecoingLayer(TestCase tc) { float* cum_log_probs = reinterpret_cast(allocator->malloc(sizeof(float) * batch_size * beam_width)); float* output_log_probs = reinterpret_cast( allocator->malloc(sizeof(float) * max_output_len * batch_size * beam_width)); - bool has_diff_runtime_args = false; - bool is_initialize_random_table = true; int* output_ids = reinterpret_cast(allocator->malloc(sizeof(int) * max_seq_len * batch_size * beam_width)); - int* h_output_ids = reinterpret_cast(malloc(sizeof(int) * batch_size * beam_width)); + int* h_output_ids = new int[batch_size * beam_width]; + + int* end_ids = reinterpret_cast(allocator->malloc(sizeof(int) * batch_size)); + deviceFill(end_ids, batch_size, end_id); // Init by zero. cudaMemset(cum_log_probs, 0, sizeof(float) * batch_size * beam_width); cudaMemset(output_log_probs, 0, sizeof(float) * max_output_len * batch_size * beam_width); cudaMemset(output_ids, 0, sizeof(int) * max_seq_len * batch_size * beam_width); + std::unordered_map input_tensors{ + {"random_seed", {MEMORY_CPU, TYPE_INT32, {1}, &seed}}, + {"runtime_top_k", {MEMORY_CPU, TYPE_UINT32, {1}, &top_k}}, + {"runtime_top_p", {MEMORY_CPU, TYPE_FP32, {1}, &top_p}}, + {"temperature", Tensor{MEMORY_CPU, TYPE_FP32, {1}, &temperature}}, + {"len_penalty", Tensor{MEMORY_CPU, TYPE_FP32, {1}, &len_penalty}}, + {"repetition_penalty", Tensor{MEMORY_CPU, TYPE_FP32, {1}, &repetition_penalty}} + }; + dynamic_decode_layer->setup(batch_size, beam_width, &input_tensors); + for (size_t step = max_input_len; step < max_output_len; ++step) { uint ite = 0; - seed += step; - // Reset by the test value since the sampling layer internally update the logit buffer (making it log-prob). 
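        // (Editor's note, not part of the patch.) The refresh below is needed because the sampling
        // layers normalize the logits buffer in place, turning it into (log-)probabilities, so
        // reusing the previous step's buffer would mean sampling from already-normalized values.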
cudaH2Dcpy(logits_buf, h_logits, logits_size); std::unordered_map dynamic_decode_input_tensors{ @@ -304,12 +195,10 @@ void testDynamicDecoingLayer(TestCase tc) { {"input_lengths", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size, beam_width}, tiled_input_lengths_buf}}, {"ite", Tensor{MEMORY_CPU, TYPE_UINT32, {1}, &ite}}, - {"has_diff_runtime_args", Tensor{MEMORY_CPU, TYPE_BOOL, {1}, &has_diff_runtime_args}}, {"local_batch_size", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &batch_size}}, - {"is_initialize_random_table", Tensor{MEMORY_CPU, TYPE_BOOL, {1}, &is_initialize_random_table}}, - {"end_id", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &end_id}}, - {"random_seed", {MEMORY_CPU, TYPE_INT32, {1}, &seed}}, - {"runtime_top_k", {MEMORY_CPU, TYPE_INT32, {1}, &top_k}}, + {"end_id", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size}, end_ids}}, + {"random_seed", {MEMORY_CPU, TYPE_UINT64, {1}, &seed}}, + {"runtime_top_k", {MEMORY_CPU, TYPE_UINT32, {1}, &top_k}}, {"runtime_top_p", {MEMORY_CPU, TYPE_FP32, {1}, &top_p}}, {"temperature", Tensor{MEMORY_CPU, TYPE_FP32, {1}, &temperature}}, {"len_penalty", Tensor{MEMORY_CPU, TYPE_FP32, {1}, &len_penalty}}, @@ -329,7 +218,7 @@ void testDynamicDecoingLayer(TestCase tc) { {"tgt_cache_indirection", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size, beam_width, max_output_len}, nullptr}}}; - dynamic_decode_layer.forward(&dynamic_decode_output_tensors, + dynamic_decode_layer->forward(&dynamic_decode_output_tensors, &dynamic_decode_input_tensors); FT_LOG_DEBUG("Step %2d generated ids", step); @@ -369,21 +258,903 @@ void testDynamicDecoingLayer(TestCase tc) { #endif } std::string tag = tc.toString() + (is_fp32 ? " (fp32)" : " (fp16)"); - checkResult(tag, cum_log_probs, expected_cum_log_probs, batch_size * beam_width); + bool passed = checkResult(tag, cum_log_probs, expected_cum_log_probs, batch_size * beam_width); + EXPECT_TRUE(passed); + + delete[] expected_cum_log_probs; + delete[] h_output_log_probs; + delete[] h_cum_log_probs; + delete[] h_logits; + delete[] h_log_probs; + delete[] h_probs; + delete[] h_output_ids; + + delete dynamic_decode_layer; + delete cublas_wrapper; + delete allocator; + check_cuda_error(cudaStreamDestroy(stream)); + check_cuda_error(cublasDestroy(cublas_handle)); + check_cuda_error(cublasLtDestroy(cublaslt_handle)); +} + +void printTensors(std::unordered_map* map, size_t limit = 8) { + FT_LOG_INFO("Tensors:"); + for (auto& kv : *map) { + Tensor t = kv.second; + FT_LOG_INFO(" - %-18s : %s", kv.first.c_str(), t.toString().c_str()); + } +} + +template +class SamplingDecodeTest { +private: + unsigned long long seed = 0; + const static unsigned long long max_seed = 30; + const size_t batch_size = 6; + const size_t beam_width = 1; + const size_t batchxbeam = batch_size * beam_width; + const size_t vocab_size = 8; + const size_t max_input_len = 0; // has no effect. 
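+    // (Editor's note, not part of the patch.) Rough shape of this fixture: batch_size = 6 with
+    // beam_width = 1 over a vocabulary of 8, and each sampling configuration is repeated for
+    // max_seed (30) different seeds. The expected_output_ids tables in the individual tests list,
+    // per (step, batch) position, the set of token ids the configured top-k/top-p settings are
+    // allowed to produce; the class-local checkResult fails if any generated id falls outside its set.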
+ const size_t max_output_len = 3; + const size_t max_seq_len = max_input_len + max_output_len; + const int end_id = vocab_size - 1; + const DataType data_type = getTensorType(); + + // vocab size 8 & length 3 + T* test_input_logits; + + Allocator *allocator; + std::mutex *cublas_wrapper_mutex; + cublasMMWrapper *cublas_wrapper; + DynamicDecodeLayer *dynamic_decode_layer; + + int* h_output_ids; + T* h_logits; + T* h_probs; + T* h_log_probs; + float* h_cum_log_probs; + float* h_output_log_probs; + + T* d_logits; + int* d_input_lengths; + float* d_cum_log_probs; + float* d_output_log_probs; + int* d_output_ids; + int* d_end_ids; + + void setup(unsigned long long seed = 0) { + this->seed = seed; + struct cudaDeviceProp prop; + check_cuda_error(cudaGetDeviceProperties(&prop, 0)); + cudaStream_t stream; + cublasHandle_t cublas_handle; + cublasLtHandle_t cublaslt_handle; + check_cuda_error(cudaStreamCreate(&stream)); + check_cuda_error(cublasCreate(&cublas_handle)); + check_cuda_error(cublasLtCreate(&cublaslt_handle)); + check_cuda_error(cublasSetStream(cublas_handle, stream)); + cublasAlgoMap cublas_algo_map(GEMM_CONFIG); + allocator = new Allocator(getDevice()); + allocator->setStream(stream); + cublas_wrapper_mutex = new std::mutex(); + cublas_wrapper = new cublasMMWrapper(cublas_handle, + cublaslt_handle, + stream, + &cublas_algo_map, + cublas_wrapper_mutex, + allocator); + dynamic_decode_layer = new DynamicDecodeLayer(vocab_size, + vocab_size, + end_id, + stream, + cublas_wrapper, + allocator, + false, // is_free_buffer_after_forward + &prop); // cuda_device_prop + + h_output_ids = new int[batchxbeam]; + h_logits = new T[batchxbeam * vocab_size]; + h_probs = new T[batchxbeam * vocab_size]; + h_log_probs = new T[batchxbeam * vocab_size]; + h_cum_log_probs = new float[batchxbeam]; + h_output_log_probs = new float[max_output_len * batchxbeam]; + + // prob = (0.4, 0.3, 0.2, 0.1, ...) + test_input_logits = new T[24]{ + -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, // step 0 + -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, // step 1 + -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX // step 2 + }; + + d_logits = reinterpret_cast(allocator->malloc(sizeof(T) * batchxbeam * vocab_size, true)); + d_input_lengths = reinterpret_cast(allocator->malloc(sizeof(int) * batchxbeam)); + d_cum_log_probs = reinterpret_cast(allocator->malloc(sizeof(float) * batchxbeam)); + d_output_log_probs = reinterpret_cast(allocator->malloc(sizeof(float) * max_output_len * batchxbeam)); + d_output_ids = reinterpret_cast(allocator->malloc(sizeof(int) * max_seq_len * batchxbeam)); + d_end_ids = reinterpret_cast(allocator->malloc(sizeof(int) * batchxbeam)); + + // Init by zero. 
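+        // (Editor's sketch, not part of the patch.) The hand-picked logits above are simply
+        // log(0.4), log(0.3), log(0.2), log(0.1) with -FLT_MAX masking the remaining ids, so the
+        // per-step softmax is roughly (0.4, 0.3, 0.2, 0.1) over ids 0-3 at step 0, ids 4-7 at
+        // step 1 and ids 2-5 at step 2. That is why, for instance, testTopK with top_k = 2 accepts
+        // only {0, 1}, then {4, 5}, then {2, 3}. A quick host-side check of the step-0 row
+        // (step0_sum is an illustrative local, not used by the test):
+        float step0_sum = 0.0f;
+        for (int i = 0; i < 4; ++i) {
+            step0_sum += expf((float)test_input_logits[i]);  // expf(-0.9163f) ~ 0.40, expf(-1.2040f) ~ 0.30, ...
+        }
+        // expf((float)test_input_logits[0]) / step0_sum ~ 0.4; the -FLT_MAX ids contribute ~0.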
+ cudaMemset(d_cum_log_probs, 0, sizeof(float) * batchxbeam); + cudaMemset(d_output_log_probs, 0, sizeof(float) * max_output_len * batchxbeam); + cudaMemset(d_output_ids, 0, sizeof(int) * max_seq_len * batchxbeam); + deviceFill(d_end_ids, batchxbeam, end_id, stream); + } + + void teardown() { + delete[] test_input_logits; + delete[] h_output_ids; + delete[] h_logits; + delete[] h_probs; + delete[] h_log_probs; + delete[] h_cum_log_probs; + delete[] h_output_log_probs; + delete dynamic_decode_layer; + delete cublas_wrapper; + delete cublas_wrapper_mutex; + delete allocator; + } + + std::unordered_map* createInputTensors(int* topk, + size_t topk_size, + float* topp, + size_t topp_size, + float* temperature, + float* repetition_penalty) + { + // construct common input tensors + std::unordered_map* input_tensors = new std::unordered_map(); + if (topk != nullptr) { + input_tensors->insert({"runtime_top_k", {MEMORY_CPU, TYPE_INT32, {topk_size}, topk}}); + } + if (topp != nullptr) { + input_tensors->insert({"runtime_top_p", {MEMORY_CPU, TYPE_FP32, {topp_size}, topp}}); + } + if (temperature != nullptr) { + input_tensors->insert({"temperature", Tensor{MEMORY_CPU, TYPE_FP32, {1}, temperature}}); + } + if (repetition_penalty != nullptr) { + input_tensors->insert({"repetition_penalty", Tensor{MEMORY_CPU, TYPE_FP32, {1}, repetition_penalty}}); + } + input_tensors->insert({"logits", Tensor{MEMORY_GPU, TYPE_FP32, {batch_size, beam_width, vocab_size}, d_logits}}); + input_tensors->insert({"embedding_bias", Tensor{MEMORY_GPU, data_type, {vocab_size}, nullptr}}); + input_tensors->insert({"max_input_length", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &max_input_len}}); + input_tensors->insert({"input_lengths", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size, beam_width}, d_input_lengths}}); + input_tensors->insert({"end_id", Tensor{MEMORY_CPU, TYPE_INT32, {batchxbeam}, &d_end_ids}}); + input_tensors->insert({"random_seed", Tensor{MEMORY_CPU, TYPE_UINT64, {1}, &seed}}); + return input_tensors; + } + + std::unordered_map* createOutputTensors() { + // construct common output tensors + std::unordered_map* output_tensors = new std::unordered_map(); + output_tensors->insert( + {"output_ids", Tensor{MEMORY_GPU, TYPE_INT32, {max_seq_len, batch_size, beam_width}, d_output_ids}}); + output_tensors->insert({"finished", Tensor{MEMORY_GPU, TYPE_BOOL, {batch_size * beam_width}, nullptr}}); + output_tensors->insert( + {"cum_log_probs", Tensor{MEMORY_GPU, TYPE_FP32, {batch_size * beam_width}, d_cum_log_probs}}); + output_tensors->insert( + {"output_log_probs", + Tensor{MEMORY_GPU, TYPE_FP32, {max_seq_len, batch_size, beam_width}, d_output_log_probs}}); + output_tensors->insert( + {"sequence_length", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size * beam_width}, nullptr}}); + return output_tensors; + } + + void batchH2Dcpy(T* dst, T* src, size_t m, size_t n) { + for (size_t i = 0; i < m; ++i) { + cudaH2Dcpy(dst + i * n, src, n); + } + } + + bool checkResult(std::string name, int* d_output_ids, std::vector>& expected_ids) { + assert(expected_ids.size() == max_seq_len * batchxbeam); + int* h_output_ids = new int[max_seq_len * batchxbeam]; + cudaD2Hcpy(h_output_ids, d_output_ids, max_seq_len * batchxbeam); + int failures = 0; + for (size_t i = 0; i < max_seq_len * batchxbeam; ++i) { + size_t s = i / batchxbeam; + size_t b = i % batchxbeam; + std::set expts = expected_ids.at(i); + if (expts.count(h_output_ids[i]) == 0) { + if (failures < 10) { + std::stringstream ss; + ss << " - Fail " << name + << " (step=" << s << ", batch=" << b << ") " + << 
"actual=" << h_output_ids[i] << ", expected"; + for (auto& expt : expts) { + ss << " " << expt; + } + FT_LOG_DEBUG("%s", ss.str().c_str()); + } + ++failures; + } + } + delete[] h_output_ids; + FT_LOG_DEBUG("check...%6s : %s (failures: %d / %d)", + failures == 0 ? "....OK" : "FAILED", name.c_str(), + failures, max_seq_len * batchxbeam); + return failures == 0; + } + + bool testSampling(std::string name, + std::vector> expected_output_ids, + int* top_ks, + size_t top_k_size, + float* top_ps, + size_t top_p_size, + float* temperature, + float* repetition_penalty) + { + FT_LOG_INFO("Test %s", name.c_str()); + std::string tag = fmtstr( + "Test %s T=%s", name.c_str(), std::is_same::value ? "fp32" : "fp16"); + bool passed = true; + for (unsigned long long seed = 0; seed < max_seed; ++seed) { + this->setup(seed); + size_t step = max_input_len; + uint ite = 0; + std::unordered_map* input_tensors = createInputTensors( + top_ks, top_k_size, top_ps, top_p_size, temperature, repetition_penalty); + input_tensors->insert({"step", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &step}}); + input_tensors->insert({"ite", Tensor{MEMORY_CPU, TYPE_UINT32, {1}, &ite}}); + input_tensors->insert({"local_batch_size", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &batch_size}}); + std::unordered_map* output_tensors = createOutputTensors(); + + dynamic_decode_layer->setup(batch_size, beam_width, input_tensors); + for (step = max_input_len; step < max_output_len; ++step) { + // Reset by the test value since the sampling layer internally update the logit buffer. + batchH2Dcpy(input_tensors->at("logits").getPtr(), + test_input_logits + step * vocab_size, + batchxbeam, + vocab_size); + dynamic_decode_layer->forward(output_tensors, input_tensors); + } + bool is_ok = checkResult(tag + fmtstr(" seed=%lld", seed), d_output_ids, expected_output_ids); + passed &= is_ok; +#ifndef NDEBUG + if (!is_ok) { + FT_LOG_ERROR("actual output ids"); + printMatrix(d_output_ids, max_seq_len, batch_size, batch_size, true); + } +#endif + delete output_tensors; + delete input_tensors; + this->teardown(); + } + FT_LOG_INFO("check...%6s : %s", passed ? "....OK" : "FAILED", tag.c_str()); + return passed; + } + + bool testSamplingWithLocalBatch(std::string name, + std::vector> expected_output_ids, + int* top_ks, + size_t top_k_size, + float* top_ps, + size_t top_p_size, + float* temperature, + float* repetition_penalty) + { + FT_LOG_INFO("Test %s", name.c_str()); + std::string tag = fmtstr( + "Test %s T=%s", name.c_str(), std::is_same::value ? "fp32" : "fp16"); + bool passed = true; + size_t local_batch_size = 2; + uint ite = 1; + for (unsigned long long seed = 0; seed < max_seed; ++seed) { + this->setup(seed); + size_t step = max_input_len; + std::unordered_map* input_tensors = createInputTensors( + top_ks, top_k_size, top_ps, top_p_size, temperature, repetition_penalty); + input_tensors->insert({"step", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &step}}); + input_tensors->insert({"ite", Tensor{MEMORY_CPU, TYPE_UINT32, {1}, &ite}}); + input_tensors->insert({"local_batch_size", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &local_batch_size}}); + std::unordered_map* output_tensors = createOutputTensors(); + + dynamic_decode_layer->setup(batch_size, beam_width, input_tensors); + for (step = max_input_len; step < max_output_len; ++step) { + // Reset by the test value since the sampling layer internally update the logit buffer. 
+ batchH2Dcpy(input_tensors->at("logits").getPtr(), + test_input_logits + step * vocab_size, + batchxbeam, + vocab_size); + dynamic_decode_layer->forward(output_tensors, input_tensors); + } + bool is_ok = checkResult(tag + fmtstr(" seed=%lld", seed), d_output_ids, expected_output_ids); + passed &= is_ok; +#ifndef NDEBUG + if (!is_ok) { + FT_LOG_ERROR("actual output ids"); + printMatrix(d_output_ids, max_seq_len, batch_size, batch_size, true); + } +#endif + delete output_tensors; + delete input_tensors; + this->teardown(); + } + FT_LOG_INFO("check...%6s : %s", passed ? "....OK" : "FAILED", tag.c_str()); + return passed; + } + +public: + + void testTopK() { + int top_k = 2; + std::vector> expected_output_ids { + // batch + // 0 1 2 3 4 5 + {0, 1}, {0, 1}, {0, 1}, {0, 1}, {0, 1}, {0, 1}, // step 0 + {4, 5}, {4, 5}, {4, 5}, {4, 5}, {4, 5}, {4, 5}, // step 1 + {2, 3}, {2, 3}, {2, 3}, {2, 3}, {2, 3}, {2, 3} // step 2 + }; + bool passed = this->testSampling( + "TopK", expected_output_ids, &top_k, 1, nullptr, 0, nullptr, nullptr); + EXPECT_TRUE(true); + } + + void testBatchTopK() { + int* top_ks = new int[batch_size]{2, 1, 1, 2, 1, 1}; + std::vector> expected_output_ids { + // batch + // 0 1 2 3 4 5 + {0, 1}, {0}, {0}, {0, 1}, {0}, {0}, // step 0 + {4, 5}, {4}, {4}, {4, 5}, {4}, {4}, // step 1 + {2, 3}, {2}, {2}, {2, 3}, {2}, {2} // step 2 + }; + bool passed = this->testSampling( + "BatchTopK", expected_output_ids, top_ks, batch_size, nullptr, 0, nullptr, nullptr); + delete[] top_ks; + EXPECT_TRUE(passed); + } + + void testTopP() { + float top_p = 0.3; + std::vector> expected_output_ids { + // batch + {0}, {0}, {0}, {0}, {0}, {0}, // step 0 + {4}, {4}, {4}, {4}, {4}, {4}, // step 1 + {2}, {2}, {2}, {2}, {2}, {2} // step 2 + }; + bool passed = this->testSampling( + "TopP", expected_output_ids, nullptr, 0, &top_p, 1, nullptr, nullptr); + EXPECT_TRUE(true); + } + + void testBatchTopP() { + float* top_ps = new float[batch_size]{0.3f, 0.5f, 0.5f, 0.3f, 0.5f, 0.5f}; + std::vector> expected_output_ids { + {0}, {0, 1}, {0, 1}, {0}, {0, 1}, {0, 1}, // step 0 + {4}, {4, 5}, {4, 5}, {4}, {4, 5}, {4, 5}, // step 1 + {2}, {2, 3}, {2, 3}, {2}, {2, 3}, {2, 3} // step 2 + }; + bool passed = this->testSampling( + "BatchTopP", expected_output_ids, nullptr, 0, top_ps, batch_size, nullptr, nullptr); + delete[] top_ps; + EXPECT_TRUE(passed); + } + + void testTopKTopP() { + int top_k = 2; + float top_p = 0.3; + std::vector> expected_output_ids { + // batch + {0}, {0}, {0}, {0}, {0}, {0}, // step 0 + {4}, {4}, {4}, {4}, {4}, {4}, // step 1 + {2}, {2}, {2}, {2}, {2}, {2} // step 2 + }; + bool passed = this->testSampling( + "TopP", expected_output_ids, &top_k, 1, &top_p, 1, nullptr, nullptr); + EXPECT_TRUE(true); + } + + void testBatchTopKTopP() { + std::string name = "BatchTopKTopP"; + int* top_ks = new int[batch_size]{2, 2, 1, 2, 2, 1}; + float top_p = 0.3; + std::vector> expected_output_ids { + // batch + {0}, {0}, {0}, {0}, {0}, {0}, // step 0 + {4}, {4}, {4}, {4}, {4}, {4}, // step 1 + {2}, {2}, {2}, {2}, {2}, {2} // step 2 + }; + bool passed = this->testSampling( + name, expected_output_ids, top_ks, batch_size, &top_p, 1, nullptr, nullptr); + delete[] top_ks; + EXPECT_TRUE(passed); + } + + void testTopKBatchTopP() { + std::string name = "TopKBatchTopP"; + int top_k = 2; + float* top_ps = new float[batch_size]{0.5, 0.3, 0.5, 0.5, 0.3, 0.5}; + std::vector> expected_output_ids { + // batch + {0, 1}, {0}, {0, 1}, {0, 1}, {0}, {0, 1}, // step 0 + {4, 5}, {4}, {4, 5}, {4, 5}, {4}, {4, 5}, // step 1 + {2, 3}, {2}, {2, 3}, 
{2, 3}, {2}, {2, 3} // step 2 + }; + bool passed = this->testSampling( + name, expected_output_ids, &top_k, 1, top_ps, batch_size, nullptr, nullptr); + delete[] top_ps; + EXPECT_TRUE(passed); + } + + void testBatchTopKBatchTopP() { + std::string name = "BatchTopKBatchTopP"; + int* top_ks = new int[batch_size]{2, 2, 0, 2, 2, 0}; + float* top_ps = new float[batch_size]{0.0, 0.3, 0.5, 0.0, 0.3, 0.5}; + std::vector> expected_output_ids { + // batch + {0, 1}, {0}, {0, 1}, {0, 1}, {0}, {0, 1}, // step 0 + {4, 5}, {4}, {4, 5}, {4, 5}, {4}, {4, 5}, // step 1 + {2, 3}, {2}, {2, 3}, {2, 3}, {2}, {2, 3} // step 2 + }; + bool passed = this->testSampling( + name, expected_output_ids, top_ks, batch_size, top_ps, batch_size, nullptr, nullptr); + delete[] top_ks; + delete[] top_ps; + EXPECT_TRUE(passed); + } + + void testInvalidArgsZeroTopK() { + std::string name = "InvalidArgsZeroTopK"; + int top_k = 0; + std::vector> expected_output_ids { + // batch + {0}, {0}, {0}, {0}, {0}, {0}, // step 0 + {4}, {4}, {4}, {4}, {4}, {4}, // step 1 + {2}, {2}, {2}, {2}, {2}, {2} // step 2 + }; + bool passed = this->testSampling( + name, expected_output_ids, &top_k, 1, nullptr, 0, nullptr, nullptr); + EXPECT_TRUE(passed); + } + + void testInvalidArgsZeroTopP() { + std::string name = "InvalidArgsZeroTopP"; + float top_p = 0; + std::vector> expected_output_ids { + // batch + {0}, {0}, {0}, {0}, {0}, {0}, // step 0 + {4}, {4}, {4}, {4}, {4}, {4}, // step 1 + {2}, {2}, {2}, {2}, {2}, {2} // step 2 + }; + bool passed = this->testSampling( + name, expected_output_ids, nullptr, 0, &top_p, 1, nullptr, nullptr); + EXPECT_TRUE(passed); + } + + void testInvalidArgsZeroTopKTopP() { + std::string name = "InvalidArgsZeroTopKTopP"; + int top_k = 0; + float top_p = 0; + std::vector> expected_output_ids { + // batch + {0}, {0}, {0}, {0}, {0}, {0}, // step 0 + {4}, {4}, {4}, {4}, {4}, {4}, // step 1 + {2}, {2}, {2}, {2}, {2}, {2} // step 2 + }; + bool passed = this->testSampling( + name, expected_output_ids, &top_k, 1, &top_p, 1, nullptr, nullptr); + EXPECT_TRUE(passed); + } + + void testInvalidArgsZeroBatchTopKTopP() { + std::string name = "InvalidArgsZeroBatchTopKTopP"; + int* top_ks = new int[batch_size]{0, 0, 0, 0, 0, 0}; + float top_p = 0; + std::vector> expected_output_ids { + // batch + {0}, {0}, {0}, {0}, {0}, {0}, // step 0 + {4}, {4}, {4}, {4}, {4}, {4}, // step 1 + {2}, {2}, {2}, {2}, {2}, {2} // step 2 + }; + bool passed = this->testSampling( + name, expected_output_ids, top_ks, batch_size, &top_p, 1, nullptr, nullptr); + delete[] top_ks; + EXPECT_TRUE(passed); + } + + void testInvalidArgsZeroTopKBatchTopP() { + std::string name = "InvalidArgsZeroTopKBatchTopP"; + int top_k = 0; + float* top_ps = new float[batch_size]{0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; + std::vector> expected_output_ids { + // batch + {0}, {0}, {0}, {0}, {0}, {0}, // step 0 + {4}, {4}, {4}, {4}, {4}, {4}, // step 1 + {2}, {2}, {2}, {2}, {2}, {2} // step 2 + }; + bool passed = this->testSampling( + name, expected_output_ids, &top_k, 1, top_ps, batch_size, nullptr, nullptr); + delete[] top_ps; + EXPECT_TRUE(passed); + } + + void testInvalidArgsBatchTopKContainZero() { + std::string name = "InvalidArgsBatchTopKContainZero"; + int* top_ks = new int[batch_size]{2, 1, 0, 0, 2, 1}; + std::vector> expected_output_ids { + // batch + {0, 1}, {0}, {0}, {0}, {0, 1}, {0}, // step 0 + {4, 5}, {4}, {4}, {4}, {4, 5}, {4}, // step 1 + {2, 3}, {2}, {2}, {2}, {2, 3}, {2} // step 2 + }; + bool passed = this->testSampling( + name, expected_output_ids, top_ks, batch_size, 
nullptr, 0, nullptr, nullptr); + delete[] top_ks; + EXPECT_TRUE(passed); + } + + void testInvalidArgsBatchTopPContainZero() { + std::string name = "InvalidArgsBatchTopPContainZero"; + float* top_ps = new float[batch_size]{0.5f, 0.5f, 0.0f, 0.5f, 0.0f, 0.3f}; + std::vector> expected_output_ids { + // batch + {0, 1}, {0, 1}, {0}, {0, 1}, {0}, {0}, // step 0 + {4, 5}, {4, 5}, {4}, {4, 5}, {4}, {4}, // step 1 + {2, 3}, {2, 3}, {2}, {2, 3}, {2}, {2} // step 2 + }; + bool passed = this->testSampling( + name, expected_output_ids, nullptr, 0, top_ps, batch_size, nullptr, nullptr); + delete[] top_ps; + EXPECT_TRUE(passed); + } + + void testInvalidArgsBatchTopKTopPContainZero() { + std::string name = "InvalidArgsBatchTopKTopPContainZero"; + int* top_ks = new int[batch_size]{2, 2, 1, 0, 2, 0}; + float top_p = 0.0; + std::vector> expected_output_ids { + // batch + {0, 1}, {0, 1}, {0}, {0}, {0, 1}, {0}, // step 0 + {4, 5}, {4, 5}, {4}, {4}, {4, 5}, {4}, // step 1 + {2, 3}, {2, 3}, {2}, {2}, {2, 3}, {2} // step 2 + }; + bool passed = this->testSampling( + name, expected_output_ids, top_ks, batch_size, &top_p, 1, nullptr, nullptr); + delete[] top_ks; + EXPECT_TRUE(passed); + } + + void testInvalidArgsTopKBatchTopPContainZero() { + std::string name = "InvalidArgsTopKBatchTopPContainZero"; + int top_k = 0; + float* top_ps = new float[batch_size]{0.0, 0.3, 0.5, 0.0, 0.3, 0.5}; + std::vector> expected_output_ids { + // batch + {0}, {0}, {0, 1}, {0}, {0}, {0, 1}, // step 0 + {4}, {4}, {4, 5}, {4}, {4}, {4, 5}, // step 1 + {2}, {2}, {2, 3}, {2}, {2}, {2, 3} // step 2 + }; + bool passed = this->testSampling( + name, expected_output_ids, &top_k, 1, top_ps, batch_size, nullptr, nullptr); + delete[] top_ps; + EXPECT_TRUE(passed); + } + + void testInvalidArgsBatchTopKBatchTopPContainZero() { + std::string name = "InvalidArgsBatchTopKBatchTopPContainZero"; + int* top_ks = new int[batch_size]{0, 2, 1, 2, 2, 0}; + float* top_ps = new float[batch_size]{0.0, 0.3, 0.9, 0.0, 0.3, 0.5}; + std::vector> expected_output_ids { + // batch + {0}, {0}, {0}, {0, 1}, {0}, {0, 1}, // step 0 + {4}, {4}, {4}, {4, 5}, {4}, {4, 5}, // step 1 + {2}, {2}, {2}, {2, 3}, {2}, {2, 3} // step 2 + }; + bool passed = this->testSampling( + name, expected_output_ids, top_ks, batch_size, top_ps, batch_size, nullptr, nullptr); + delete[] top_ks; + delete[] top_ps; + EXPECT_TRUE(passed); + } + + void testLocalBatchBatchTopP() { + std::string name = "LocalBatch_BatchTopP"; + float* top_ps = new float[batch_size]{0.3f, 0.5f, 0.5f, 0.3f, 0.5f, 0.5f}; + std::vector> expected_output_ids { + {0}, {0}, {0, 1}, {0}, {0}, {0}, // step 0 + {0}, {0}, {4, 5}, {4}, {0}, {0}, // step 1 + {0}, {0}, {2, 3}, {2}, {0}, {0} // step 2 + }; + bool passed = this->testSamplingWithLocalBatch( + name, expected_output_ids, nullptr, 0, top_ps, batch_size, nullptr, nullptr); + delete[] top_ps; + EXPECT_TRUE(passed); + } + + + void testLocalBatchBatchTopKBatchTopP() { + std::string name = "LocalBatch_BatchTopKBatchTopP"; + int* top_ks = new int[batch_size]{2, 2, 0, 2, 2, 0}; + float* top_ps = new float[batch_size]{0.0, 0.3, 0.5, 0.0, 0.3, 0.5}; + std::vector> expected_output_ids { + // batch + {0}, {0}, {0, 1}, {0, 1}, {0}, {0}, // step 0 + {0}, {0}, {4, 5}, {4, 5}, {0}, {0}, // step 1 + {0}, {0}, {2, 3}, {2, 3}, {0}, {0} // step 2 + }; + bool passed = this->testSamplingWithLocalBatch( + name, expected_output_ids, top_ks, batch_size, top_ps, batch_size, nullptr, nullptr); + delete[] top_ks; + delete[] top_ps; + EXPECT_TRUE(passed); + } + + void testAll() { + this->testTopK(); + 
this->testTopP(); + this->testTopKTopP(); + this->testBatchTopK(); + this->testBatchTopP(); + this->testBatchTopKTopP(); + this->testTopKBatchTopP(); + this->testBatchTopKBatchTopP(); + this->testInvalidArgsZeroTopK(); + this->testInvalidArgsZeroTopP(); + this->testInvalidArgsZeroBatchTopKTopP(); + this->testInvalidArgsZeroTopKBatchTopP(); + this->testInvalidArgsZeroTopKTopP(); + this->testInvalidArgsBatchTopKContainZero(); + this->testInvalidArgsBatchTopPContainZero(); + this->testInvalidArgsBatchTopKTopPContainZero(); + this->testInvalidArgsTopKBatchTopPContainZero(); + this->testInvalidArgsBatchTopKBatchTopPContainZero(); + this->testLocalBatchBatchTopP(); + this->testLocalBatchBatchTopKBatchTopP(); + } +}; + +__global__ +void generateRandomNumber(unsigned int *vals, curandState_t *states, const int batch_size) { + int idx = threadIdx.x; + if (idx < batch_size) { + vals[idx] = curand(states + idx); + } +} + +template +static inline bool isEqualInPeriod(T* vals, size_t size, size_t period_size) { + // The same seed produces the same random number. + for (size_t i = 0; i + period_size - 1 < size; i += period_size) { + for (size_t j = 1; j < period_size; ++j) { + if (vals[i] != vals[i + j]) { + FT_LOG_INFO(" **** *** ** * [%d] %d <> [%d] %d", i, vals[i], i + j, vals[i + j]); + return false; + } + } + } + return true; +} + +template +static inline bool isEqualInPeriod(T* vals, size_t size, size_t period_size, size_t except) { + // The same seed produces the same random number. + for (size_t i = 0; i + period_size - 1 < size; i += period_size) { + for (size_t j = 1; j < period_size; ++j) { + if (j != except && vals[i] != vals[i + j]) { + FT_LOG_INFO(" **** *** ** * [%d] %d <> [%d] %d", i, vals[i], i + j, vals[i + j]); + return false; + } + } + } + return true; +} + +void testCuandBatchInitialize(const size_t batch_size) { + cudaStream_t stream; + cudaStreamCreate(&stream); + + curandState_t* curand_states; + check_cuda_error(cudaMalloc(&curand_states, sizeof(curandState_t) * batch_size)); + unsigned long long* h_random_seeds = new unsigned long long[batch_size]; + const size_t period_size = 3; + for (size_t i = 0; i < batch_size; ++i) { + h_random_seeds[i] = i / period_size; + } + unsigned long long* d_random_seeds; + check_cuda_error(cudaMalloc(&d_random_seeds, sizeof(unsigned long long) * batch_size)); + check_cuda_error(cudaMemcpy(d_random_seeds, h_random_seeds, + sizeof(unsigned long long) * batch_size, cudaMemcpyHostToDevice)); + + // Initialize curand states. + invokeCurandBatchInitialize(curand_states, batch_size, d_random_seeds, stream); + sync_check_cuda_error(); + + // Generate random numbers using initialized curand states. + unsigned int* d_rand_vals; + unsigned int* h_rand_vals = new unsigned int[batch_size]; + check_cuda_error(cudaMalloc(&d_rand_vals, sizeof(unsigned int) * batch_size)); + generateRandomNumber<<<1, batch_size, 0, stream>>>(d_rand_vals, curand_states, batch_size); + check_cuda_error(cudaMemcpyAsync( + h_rand_vals, d_rand_vals, sizeof(unsigned int) * batch_size, cudaMemcpyDeviceToHost, stream)); + check_cuda_error(cudaStreamSynchronize(stream)); + + // The same seed produces the same random number. + bool passed = isEqualInPeriod(h_rand_vals, batch_size, period_size); + FT_LOG_INFO("CuandBatchInitTest check....... : %s", passed ? 
"OK" : "FAILED"); + EXPECT_TRUE(passed); + + delete h_rand_vals; + delete h_random_seeds; + + check_cuda_error(cudaFree(d_rand_vals)); + check_cuda_error(cudaFree(d_random_seeds)); + check_cuda_error(cudaFree(curand_states)); + check_cuda_error(cudaStreamDestroy(stream)); +} + +template +void testSamplingLayerCurandInit(TestCase tc) { + FT_LOG_DEBUG("testSamplingLayerCurandInit %s", tc.toString().c_str()); + const DataType data_type = getTensorType(); + + const size_t beam_width = 1; + const uint top_k = tc.top_k; + const float top_p = tc.top_p; + // use default values having no effect. + const float temperature = 1.0f; + const float len_penalty = 0.0f; + const float repetition_penalty = 1.0f; + const int end_id = 3; + + const size_t batch_size = tc.batch_size; + const size_t batchxbeam = batch_size * beam_width; + const size_t local_batch_size = USE_LOCAL_BATCH ? 2 : batch_size; + assert(batch_size % local_batch_size == 0); + const size_t vocab_size = tc.vocab_size; + const size_t max_input_len = 0; // has no effect. + const size_t max_output_len = tc.output_len; + const size_t max_seq_len = max_input_len + max_output_len; + + struct cudaDeviceProp prop; + check_cuda_error(cudaGetDeviceProperties(&prop, 0)); + + cudaStream_t stream; + cublasHandle_t cublas_handle; + cublasLtHandle_t cublaslt_handle; + check_cuda_error(cudaStreamCreate(&stream)); + check_cuda_error(cublasCreate(&cublas_handle)); + check_cuda_error(cublasLtCreate(&cublaslt_handle)); + check_cuda_error(cublasSetStream(cublas_handle, stream)); + cublasAlgoMap cublas_algo_map(GEMM_CONFIG); + std::mutex* cublas_wrapper_mutex = new std::mutex(); + Allocator * allocator = new Allocator(getDevice()); + allocator->setStream(stream); + cublasMMWrapper *cublas_wrapper = new cublasMMWrapper(cublas_handle, + cublaslt_handle, + stream, + &cublas_algo_map, + cublas_wrapper_mutex, + allocator); + DynamicDecodeLayer *dynamic_decode_layer = new DynamicDecodeLayer(vocab_size, + vocab_size, + end_id, + stream, + cublas_wrapper, + allocator, + false, // is_free_buffer_after_forward + &prop); // cuda_device_prop + + T* h_logits = reinterpret_cast(malloc(sizeof(T) * batchxbeam * vocab_size)); + int* h_output_ids = reinterpret_cast(malloc(sizeof(int) * batchxbeam)); + + T* d_logits_buf = reinterpret_cast(allocator->malloc(sizeof(T) * batchxbeam * vocab_size)); + int* d_input_lengths = reinterpret_cast(allocator->malloc(sizeof(int) * batchxbeam)); + int* d_output_ids = reinterpret_cast(allocator->malloc(sizeof(int) * max_seq_len * batchxbeam)); + + // Init by zero. + cudaMemset(d_input_lengths, 0, sizeof(int) * batchxbeam); + cudaMemset(d_output_ids, 0, sizeof(int) * max_seq_len * batchxbeam); + + // Prepare decoding arguments + const size_t random_seed_size = SINGLE_RANDOM_SEED ? 1 : batch_size; + const size_t period_size = 3; + unsigned long long* random_seed = new unsigned long long[random_seed_size]; + for (size_t i = 0; i < random_seed_size; ++i) { + random_seed[i] = i / period_size; + } + const bool has_diff_runtime_args = HAS_DIFF_ARGS; + const size_t runtime_args_size = has_diff_runtime_args ? batch_size : 1; + uint* runtime_top_k = new uint[runtime_args_size]; + float* runtime_top_p = new float[runtime_args_size]; + const size_t except_idx = 1; + for (size_t i = 0; i < runtime_args_size; ++i) { + runtime_top_k[i] = (top_k > 1) && (i % period_size == except_idx) ? 1 : top_k; + runtime_top_p[i] = (i % period_size == except_idx) ? 
top_p * 0.1f : top_p; + } + int* d_end_id_buf = reinterpret_cast(allocator->malloc(sizeof(int) * batch_size)); + deviceFill(d_end_id_buf, batch_size, end_id); + +#ifndef NDEBUG + FT_LOG_DEBUG("Random Seeds"); + printMatrixWithLimit(random_seed, 1, random_seed_size, random_seed_size, false); +#endif + + bool passed = true; + + std::unordered_map runtime_args; + runtime_args.insert({"has_diff_runtime_args", Tensor(MEMORY_CPU, TYPE_BOOL, {1}, &has_diff_runtime_args)}); + runtime_args.insert({"random_seed", Tensor(MEMORY_CPU, TYPE_UINT64, {random_seed_size}, random_seed)}); + runtime_args.insert({"runtime_top_k", Tensor(MEMORY_CPU, TYPE_INT32, {runtime_args_size}, runtime_top_k)}); + runtime_args.insert({"runtime_top_p", Tensor(MEMORY_CPU, TYPE_FP32, {runtime_args_size}, runtime_top_p)}); + runtime_args.insert({"temperature", Tensor(MEMORY_CPU, TYPE_FP32, {1}, &temperature)}); + runtime_args.insert({"len_penalty", Tensor(MEMORY_CPU, TYPE_FP32, {1}, &len_penalty)}); + runtime_args.insert({"repetition_penalty", Tensor(MEMORY_CPU, TYPE_FP32, {1}, &repetition_penalty)}); + dynamic_decode_layer->setup(batch_size, beam_width, &runtime_args); + + for (size_t step = max_input_len; step < max_output_len; ++step) { + const size_t iteration_num = batch_size / local_batch_size; + + initRandom(h_logits, beam_width * vocab_size, -10.0f / vocab_size, -1.0f); + tile(h_logits, batch_size, beam_width * vocab_size); + cudaH2Dcpy(d_logits_buf, h_logits, batchxbeam * vocab_size); + +#ifndef NDEBUG + FT_LOG_DEBUG("logit values"); + printMatrixWithLimit(h_logits, batchxbeam, vocab_size, vocab_size, false); +#endif + for (uint ite = 0; ite < iteration_num; ++ite) { + std::unordered_map dynamic_decode_input_tensors{ + {"logits", Tensor{MEMORY_GPU, data_type, {batch_size, beam_width, vocab_size}, d_logits_buf}}, + {"embedding_bias", Tensor{MEMORY_GPU, data_type, {vocab_size}, nullptr}}, + {"step", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &step}}, + {"max_input_length", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &max_input_len}}, + {"input_lengths", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size, beam_width}, d_input_lengths}}, + {"ite", Tensor{MEMORY_CPU, TYPE_UINT32, {1}, &ite}}, + {"has_diff_runtime_args", Tensor{MEMORY_CPU, TYPE_BOOL, {1}, &has_diff_runtime_args}}, + {"local_batch_size", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &local_batch_size}}, + {"end_id", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size}, d_end_id_buf}}, + {"random_seed", {MEMORY_CPU, TYPE_UINT64, {random_seed_size}, random_seed}}, + {"runtime_top_k", {MEMORY_CPU, TYPE_UINT32, {runtime_args_size}, runtime_top_k}}, + {"runtime_top_p", {MEMORY_CPU, TYPE_FP32, {runtime_args_size}, runtime_top_p}}, + {"temperature", Tensor{MEMORY_CPU, TYPE_FP32, {1}, &temperature}}, + {"len_penalty", Tensor{MEMORY_CPU, TYPE_FP32, {1}, &len_penalty}}, + {"repetition_penalty", Tensor{MEMORY_CPU, TYPE_FP32, {1}, &repetition_penalty}} + }; + + // common outputs + std::unordered_map dynamic_decode_output_tensors{ + {"output_ids", Tensor{MEMORY_GPU, TYPE_INT32, {max_seq_len, batch_size, beam_width}, d_output_ids}}, + {"finished", Tensor{MEMORY_GPU, TYPE_BOOL, {batch_size * beam_width}, nullptr}}, + {"parent_ids", Tensor{MEMORY_GPU, TYPE_INT32, {max_seq_len, batch_size, beam_width}, nullptr}}, + {"sequence_length", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size * beam_width}, nullptr}}, + // necessary for beam search. 
+ {"tgt_cache_indirection", + Tensor{MEMORY_GPU, TYPE_INT32, {batch_size, beam_width, max_output_len}, nullptr}}}; + + dynamic_decode_layer->forward(&dynamic_decode_output_tensors, + &dynamic_decode_input_tensors); + sync_check_cuda_error(); +#ifndef NDEBUG + FT_LOG_DEBUG("Step %2d generated ids", step); + printMatrix(d_output_ids, max_seq_len, batchxbeam, batchxbeam, true); + FT_LOG_DEBUG(""); +#endif + // check results. + cudaD2Hcpy(h_output_ids, + (int*)dynamic_decode_output_tensors.at("output_ids").getPtrWithOffset(step * batchxbeam), + batchxbeam); + } + bool is_ok = isEqualInPeriod(h_output_ids, batchxbeam, period_size, except_idx); + passed &= is_ok; + } + std::string tag = fmtstr("%s (seed=%-6s has_diff_args=%-5s local_batch=%-5s T=%s)", + tc.toString().c_str(), + SINGLE_RANDOM_SEED ? "single" : "multi", + HAS_DIFF_ARGS ? "true" : "false", + USE_LOCAL_BATCH ? "true" : "false", + (std::is_same::value ? " fp32" : " fp16")); + FT_LOG_INFO("check...%s SamplingLayerCurandInitTest %-30s", passed ? "....OK" : "FAILED", tag.c_str()); + EXPECT_TRUE(passed); - free(expected_cum_log_probs); - free(h_output_log_probs); - free(h_cum_log_probs); free(h_logits); - free(h_log_probs); - free(h_probs); free(h_output_ids); - allocator->free(tiled_input_lengths_buf); - allocator->free(cum_log_probs); - allocator->free(output_ids); - allocator->free(logits_buf); - allocator->free(output_log_probs); + delete dynamic_decode_layer; + delete runtime_top_k; + delete runtime_top_p; + delete random_seed; delete cublas_wrapper; delete allocator; check_cuda_error(cudaStreamDestroy(stream)); @@ -399,6 +1170,7 @@ int main() { TestCase{"topk", 6, 51200, 1, 31, 0.0f, 16}, TestCase{"topk", 32, 51200, 1, 63, 0.0f, 16}, TestCase{"topk", 32, 51200, 1, 64, 0.0f, 16}, + TestCase{"topp", 6, 4, 1, 0, 0.2f, 4}, TestCase{"topp", 6, 4, 1, 0, 0.8f, 4}, TestCase{"topp", 6, 4, 1, 0, 1.0f, 4}, TestCase{"topp", 6, 51200, 1, 0, 0.8f, 16}, @@ -412,9 +1184,30 @@ int main() { }; for (auto &tc : test_cases) { - testDynamicDecoingLayer(tc); - testDynamicDecoingLayer(tc); // T5 model uses DynamicDecodingLayer. 
+ testCumLogProbComputation(tc); + testCumLogProbComputation(tc); } - FT_LOG_INFO("Test Done"); + FT_LOG_INFO("testCumLogProbComputation done"); + + SamplingDecodeTest sampling_decode_test; + sampling_decode_test.testAll(); + + testCuandBatchInitialize(127); + FT_LOG_INFO("testCuandBatchInitialize done"); + + #define LAUNCH_VARIANTS(T, tc, local_batch) \ + testSamplingLayerCurandInit(tc); \ + testSamplingLayerCurandInit(tc); \ + testSamplingLayerCurandInit(tc); \ + testSamplingLayerCurandInit(tc); + for (auto &tc : test_cases) { + LAUNCH_VARIANTS(float, tc, false); // without local batch + LAUNCH_VARIANTS(half, tc, false); + LAUNCH_VARIANTS(float, tc, true); // with local batch + LAUNCH_VARIANTS(half, tc, true); + } + #undef LAUNCH_VARIANTS + FT_LOG_INFO("testSamplingLayerCurandInit done"); + return 0; } diff --git a/tests/unittests/test_sampling_kernels.cu b/tests/unittests/test_sampling_kernels.cu new file mode 100644 index 000000000..068ec6559 --- /dev/null +++ b/tests/unittests/test_sampling_kernels.cu @@ -0,0 +1,908 @@ +#include // std::min, std::max +#include // snprintf +#include // expf, log +#include // rand +#include // std::string +#include // std::vector + +#include +#include +#include + +#include "src/fastertransformer/kernels/sampling_topk_kernels.h" +#include "src/fastertransformer/kernels/sampling_topp_kernels.h" +#include "src/fastertransformer/layers/DynamicDecodeLayer.h" +#include "src/fastertransformer/layers/sampling_layers/TopKSamplingLayer.h" +#include "src/fastertransformer/utils/cublasMMWrapper.h" +#include "src/fastertransformer/utils/cuda_utils.h" +#include "src/fastertransformer/utils/memory_utils.h" +#include "src/fastertransformer/utils/Tensor.h" + +#include "tests/unittests/unittest_utils.h" + +using namespace fastertransformer; + +struct TestCase { + std::string name; + size_t batch_size; + size_t vocab_size; + size_t beam_width; + size_t top_k; + float top_p; + size_t output_len; + + std::string toString() { + char buf[100]; + snprintf(buf, sizeof(buf), + "TestCase[name=%s, batch=%ld, vocab=%ld, beam=%ld, k=%ld, p=%3.1f, output_len=%ld]", + name.c_str(), batch_size, vocab_size, beam_width, top_k, top_p, output_len); + return buf; + } + + void print() { + FT_LOG_INFO(toString()); + } +}; + +template +void computeProb(T* probs, T* logits, int batch_size, int vocab_size) { + // Compute the log probability from logits. + // logits = batch_size x vocab_size vector. + // logprobs = log(softmax(logits)) (softmax along with vocab dimension) + for (int bidx = 0; bidx < batch_size; ++bidx) { + float sum = 0.0f; + for (int i = 0; i < vocab_size; ++i) { + sum += expf((float)logits[bidx * vocab_size + i]); + } + for (int i = 0; i < vocab_size; ++i) { + int idx = bidx * vocab_size + i; + probs[idx] = static_cast(expf((float)logits[idx]) / (sum + EPSILON)); + } + } +} + +template +void computeLogProb(T* logprobs, T* logits, int batch_size, int vocab_size) { + // Compute the log probability from logits. + // logits = batch_size x vocab_size vector. + // logprobs = log(softmax(logits)) (softmax along with vocab dimension) + for (int bidx = 0; bidx < batch_size; ++bidx) { + float sum = 0.0f; + for (int i = 0; i < vocab_size; ++i) { + sum += expf(logits[bidx * vocab_size + i]); + } + for (int i = 0; i < vocab_size; ++i) { + int idx = bidx * vocab_size + i; + logprobs[idx] = static_cast(logf(expf(logits[idx]) / (sum + EPSILON) + EPSILON)); + } + } +} + +std::string toTestTag(std::string name, TestCase tc, bool is_fp32) { + return name + " " + tc.toString() + (is_fp32 ? 
" (fp32)" : " (fp16)"); +} + +/////////////////////////////////// Tests ////////////////////////////////////////// + +template +void testTopKSamplingKernel(TestCase tc) { + + bool is_fp32 = std::is_same::value; + + size_t top_k = tc.top_k; + unsigned long long seed = 0; + + size_t batch_size = tc.batch_size; + size_t vocab_size = tc.vocab_size; + + int end_id = 3; + size_t max_output_len = tc.output_len; + size_t max_seq_len = max_output_len; + + cudaStream_t stream; + check_cuda_error(cudaStreamCreate(&stream)); + Allocator* allocator = new Allocator(getDevice()); + allocator->setStream(stream); + + // Logit values in the host of shape (batch_size x vocab_size). + T* h_logits = new T[batch_size * vocab_size]; + T* h_probs = new T[batch_size * vocab_size]; + T* h_log_probs = new T[batch_size * vocab_size]; + float* h_cum_log_probs = new float[batch_size]; + float* h_output_log_probs = new float[batch_size]; + float* expected_cum_log_probs = new float[batch_size]; + int* h_output_ids = new int[batch_size]; + int* h_seq_lengths = new int[batch_size]; + bool* h_finished = new bool[batch_size]; + + initRandom(h_logits, batch_size * vocab_size, -10.0f, -1.0f); + memset(expected_cum_log_probs, 0, sizeof(float) * batch_size); + + curandState_t* curand_states = reinterpret_cast( + allocator->malloc(sizeof(curandState_t) * batch_size, false)); + invokeCurandInitialize(curand_states, batch_size, seed, stream); + + size_t workspace_size = 0; + // retrieve the workspace size of the top-k sampling kernel. + invokeTopKSampling(nullptr, + workspace_size, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + top_k, + 1.0f, + vocab_size, + nullptr, + stream, + batch_size, + nullptr); + void* workspace = allocator->malloc(workspace_size, false); + int* sequence_lengths = reinterpret_cast(allocator->malloc(sizeof(int) * batch_size)); + bool* finished = reinterpret_cast(allocator->malloc(sizeof(bool) * batch_size)); + int* end_ids = reinterpret_cast(allocator->malloc(sizeof(int) * batch_size, false)); + + T* probs = reinterpret_cast(allocator->malloc(sizeof(T) * batch_size * vocab_size, true)); + float* cum_log_probs = reinterpret_cast(allocator->malloc(sizeof(float) * batch_size)); + float* output_log_probs = reinterpret_cast( + allocator->malloc(sizeof(float) * max_output_len * batch_size)); + int* output_ids = reinterpret_cast(allocator->malloc(sizeof(int) * max_seq_len * batch_size)); + + // Init by zero. + deviceFill(sequence_lengths, batch_size, 0); + deviceFill(finished, batch_size, false); + deviceFill(end_ids, batch_size, end_id); + + deviceFill(cum_log_probs, batch_size, 0.0f); + deviceFill(output_log_probs, max_output_len * batch_size, 0.0f); + deviceFill(output_ids, max_seq_len * batch_size, 0); + + void* h_worksapce = malloc(workspace_size); + + for (size_t step = 0; step < max_output_len; ++step) { + initRandom(h_logits, batch_size * vocab_size, -10.0f, -1.0f); + computeProb(h_probs, h_logits, batch_size, vocab_size); + cudaH2Dcpy(probs, h_probs, batch_size * vocab_size); + invokeTopKSampling(workspace, + workspace_size, + // Note that the kernel needs vocab probs instead of + // log-prob if cum_log_probs or output_log_probs are + // provided. It's because the sampling layer already + // preprocesses log_prob_buf when those are provided. 
+ probs, + output_ids + step * batch_size, + sequence_lengths, + finished, + cum_log_probs, + output_log_probs + step * batch_size, + curand_states, + top_k, + 1.0f, + vocab_size, + end_ids, + stream, + batch_size, + nullptr); + + // Compute reference. + cudaD2Hcpy(h_output_ids, output_ids + step * batch_size, batch_size); + cudaD2Hcpy(h_output_log_probs, output_log_probs + step * batch_size, batch_size); + cudaD2Hcpy(h_cum_log_probs, cum_log_probs, batch_size); + cudaD2Hcpy(h_seq_lengths, sequence_lengths, batch_size); + cudaD2Hcpy(h_finished, finished, batch_size); + computeLogProb(h_log_probs, h_logits, batch_size, vocab_size); + for (size_t i = 0; i < batch_size; ++i) { + int idx = i * vocab_size + h_output_ids[i]; + bool expected_finished = h_output_ids[i] == end_id; + float expected_log_prob = (int)step < h_seq_lengths[i] ? (float)h_log_probs[idx] : 0.0f; + expected_cum_log_probs[i] += expected_log_prob; + EXPECT_TRUE(h_finished[i] == expected_finished); + } + } + std::string tag = toTestTag("TestTopKSamplingKernel", tc, is_fp32); + bool passed = checkResult(tag, cum_log_probs, expected_cum_log_probs, batch_size); + EXPECT_TRUE(passed); + + delete[] expected_cum_log_probs; + delete[] h_seq_lengths; + delete[] h_output_log_probs; + delete[] h_cum_log_probs; + delete[] h_logits; + delete[] h_log_probs; + delete[] h_probs; + delete[] h_output_ids; + delete allocator; + check_cuda_error(cudaStreamDestroy(stream)); +} + +template +void testBatchTopKSamplingKernel(TestCase tc, bool has_diff_runtime_args) { + + bool is_fp32 = std::is_same::value; + + unsigned long long seed = 0; + + size_t batch_size = tc.batch_size; + size_t vocab_size = tc.vocab_size; + + int top_k = (int)tc.top_k; + int* h_top_ks = new int[batch_size]; + // Initialize runtime top k values. + for (size_t i = 0; i < batch_size; ++i) { + h_top_ks[i] = has_diff_runtime_args ? std::max(1, top_k - int(i % 3)) : top_k; + } + int max_top_k = *std::max_element(h_top_ks, h_top_ks + batch_size); + int end_id = 3; + size_t max_output_len = tc.output_len; + size_t max_seq_len = max_output_len; + + cudaStream_t stream; + check_cuda_error(cudaStreamCreate(&stream)); + Allocator* allocator = new Allocator(getDevice()); + allocator->setStream(stream); + + // Logit values in the host of shape (batch_size x vocab_size). + T* h_logits = new T[batch_size * vocab_size]; + T* h_probs = new T[batch_size * vocab_size]; + T* h_log_probs = new T[batch_size * vocab_size]; + float* h_cum_log_probs = new float[batch_size]; + float* h_output_log_probs = new float[batch_size]; + float* expected_cum_log_probs = new float[batch_size]; + int* h_output_ids = new int[batch_size]; + int* h_seq_lengths = new int[batch_size]; + bool* h_finished = new bool[batch_size]; + + initRandom(h_logits, batch_size * vocab_size, -10.0f, -1.0f); + memset(expected_cum_log_probs, 0, sizeof(float) * batch_size); + + curandState_t* curand_states = reinterpret_cast( + allocator->malloc(sizeof(curandState_t) * batch_size, false)); + invokeCurandInitialize(curand_states, batch_size, seed, stream); + + size_t workspace_size = 0; + // retrieve the workspace size of the top-k sampling kernel. 
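+    // Calling the kernel with a nullptr workspace only writes the required byte
+    // count into workspace_size; the buffer is allocated right below and the same
+    // function is invoked with real arguments inside the generation loop.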
+ invokeBatchTopKSampling(nullptr, // workspace + workspace_size, + nullptr, // log_probs + nullptr, // ids + nullptr, // sequence_lengths + nullptr, // finished + nullptr, // cum_log_probs + nullptr, // output_log_probs + nullptr, // curandstates + max_top_k, + nullptr, // top_ks + 1.0f, + nullptr, + vocab_size, + nullptr, // end_ids + stream, + batch_size, + nullptr); + + void* workspace = allocator->malloc(workspace_size, false); + int* top_ks = reinterpret_cast(allocator->malloc(sizeof(int) * batch_size)); + int* end_ids = reinterpret_cast(allocator->malloc(sizeof(int) * batch_size)); + int* sequence_lengths = reinterpret_cast(allocator->malloc(sizeof(int) * batch_size)); + bool* finished = reinterpret_cast(allocator->malloc(sizeof(bool) * batch_size)); + T* probs = reinterpret_cast(allocator->malloc(sizeof(T) * batch_size * vocab_size, true)); + float* cum_log_probs = reinterpret_cast(allocator->malloc(sizeof(float) * batch_size)); + float* output_log_probs = reinterpret_cast( + allocator->malloc(sizeof(float) * max_output_len * batch_size)); + int* output_ids = reinterpret_cast(allocator->malloc(sizeof(int) * max_seq_len * batch_size)); + + // Initialize. + cudaH2Dcpy(top_ks, h_top_ks, batch_size); + deviceFill(end_ids, batch_size, end_id); + deviceFill(sequence_lengths, batch_size, 0); + deviceFill(finished, batch_size, false); + deviceFill(cum_log_probs, batch_size, 0.0f); + deviceFill(output_log_probs, max_output_len * batch_size, 0.0f); + deviceFill(output_ids, max_seq_len * batch_size, 0); + + for (size_t step = 0; step < max_output_len; ++step) { + initRandom(h_logits, batch_size * vocab_size, -10.0f, -1.0f); + computeProb(h_probs, h_logits, batch_size, vocab_size); + cudaH2Dcpy(probs, h_probs, batch_size * vocab_size); + + invokeBatchTopKSampling(workspace, + workspace_size, + // Note that the kernel needs vocab probs instead of + // log-prob if cum_log_probs or output_log_probs are + // provided. It's because the sampling layer already + // preprocesses log_prob_buf when those are provided. + probs, + output_ids + step * batch_size, + sequence_lengths, + finished, + cum_log_probs, + output_log_probs + step * batch_size, + curand_states, + max_top_k, + top_ks, + 1.0f, + nullptr, + vocab_size, + end_ids, + stream, + batch_size, + nullptr); + + // Compute reference. + cudaD2Hcpy(h_output_ids, output_ids + step * batch_size, batch_size); + cudaD2Hcpy(h_output_log_probs, output_log_probs + step * batch_size, batch_size); + cudaD2Hcpy(h_cum_log_probs, cum_log_probs, batch_size); + cudaD2Hcpy(h_seq_lengths, sequence_lengths, batch_size); + cudaD2Hcpy(h_finished, finished, batch_size); + computeLogProb(h_log_probs, h_logits, batch_size, vocab_size); + for (size_t i = 0; i < batch_size; ++i) { + int idx = i * vocab_size + h_output_ids[i]; + bool expected_finished = h_output_ids[i] == end_id; + float expected_log_prob = (int)step < h_seq_lengths[i] ? (float)h_log_probs[idx] : 0.0f; + expected_cum_log_probs[i] += expected_log_prob; + EXPECT_TRUE(h_finished[i] == expected_finished); + } + } + std::string tag = toTestTag("TestBatchTopKSamplingKernel", tc, is_fp32) + + (has_diff_runtime_args ? 
" (diff_args)" : ""); + bool passed = checkResult(tag, cum_log_probs, expected_cum_log_probs, batch_size); + EXPECT_TRUE(passed); + + delete[] expected_cum_log_probs; + delete[] h_seq_lengths; + delete[] h_output_log_probs; + delete[] h_cum_log_probs; + delete[] h_logits; + delete[] h_log_probs; + delete[] h_probs; + delete[] h_output_ids; + delete[] h_top_ks; + delete allocator; + check_cuda_error(cudaStreamDestroy(stream)); +} + +template +void testBatchTopKSamplingWithSkipDecode(TestCase tc) { + + bool is_fp32 = std::is_same::value; + + unsigned long long seed = 0; + + size_t batch_size = tc.batch_size; + size_t vocab_size = tc.vocab_size; + + int top_k = (int)tc.top_k; + int* h_top_ks = new int[batch_size]; + // Initialize runtime top k values. + for (size_t i = 0; i < batch_size; ++i) { + h_top_ks[i] = i % 3 == 0 ? top_k : 1; + } + int max_top_k = *std::max_element(h_top_ks, h_top_ks + batch_size); + int end_id = 0; + size_t max_output_len = tc.output_len; + size_t max_seq_len = max_output_len; + + cudaStream_t stream; + check_cuda_error(cudaStreamCreate(&stream)); + Allocator* allocator = new Allocator(getDevice()); + allocator->setStream(stream); + + // Logit values in the host of shape (batch_size x vocab_size). + T* h_logits = new T[batch_size * vocab_size]; + T* h_probs = new T[batch_size * vocab_size]; + T* h_log_probs = new T[batch_size * vocab_size]; + float* h_cum_log_probs = new float[batch_size]; + float* h_output_log_probs = new float[batch_size]; + float* expected_cum_log_probs = new float[batch_size]; + int* h_output_ids = new int[batch_size]; + int* h_seq_lengths = new int[batch_size]; + bool* h_finished = new bool[batch_size]; + bool* h_skip_decode = new bool[batch_size]; + for (size_t i = 0; i < batch_size; ++i) { + h_skip_decode[i] = i % 2 == 0; + } + + initRandom(h_logits, batch_size * vocab_size, -3.0f, 3.0f); + memset(expected_cum_log_probs, 0, sizeof(float) * batch_size); + + curandState_t* curand_states = reinterpret_cast( + allocator->malloc(sizeof(curandState_t) * batch_size, false)); + invokeCurandInitialize(curand_states, batch_size, seed, stream); + + size_t workspace_size = 0; + // retrieve the workspace size of the top-k sampling kernel. + invokeBatchTopKSampling(nullptr, // workspace + workspace_size, + nullptr, // log_probs + nullptr, // ids + nullptr, // sequence_lengths + nullptr, // finished + nullptr, // cum_log_probs + nullptr, // output_log_probs + nullptr, // curandstates + max_top_k, + nullptr, // top_ks + 1.0f, + nullptr, + vocab_size, + nullptr, // end_ids + stream, + batch_size, + nullptr); + + void* workspace = allocator->malloc(workspace_size, false); + int* top_ks = reinterpret_cast(allocator->malloc(sizeof(int) * batch_size)); + int* end_ids = reinterpret_cast(allocator->malloc(sizeof(int) * batch_size)); + int* sequence_lengths = reinterpret_cast(allocator->malloc(sizeof(int) * batch_size)); + bool* finished = reinterpret_cast(allocator->malloc(sizeof(bool) * batch_size)); + T* probs = reinterpret_cast(allocator->malloc(sizeof(T) * batch_size * vocab_size, true)); + float* cum_log_probs = reinterpret_cast(allocator->malloc(sizeof(float) * batch_size)); + float* output_log_probs = reinterpret_cast( + allocator->malloc(sizeof(float) * max_output_len * batch_size)); + int* output_ids = reinterpret_cast(allocator->malloc(sizeof(int) * max_seq_len * batch_size)); + bool* skip_decode = reinterpret_cast(allocator->malloc(sizeof(bool) * batch_size)); + cudaH2Dcpy(skip_decode, h_skip_decode, batch_size); + + // Initialize. 
+    cudaH2Dcpy(top_ks, h_top_ks, batch_size);
+    deviceFill(end_ids, batch_size, end_id);
+    deviceFill(sequence_lengths, batch_size, 0);
+    deviceFill(finished, batch_size, false);
+    deviceFill(cum_log_probs, batch_size, 0.0f);
+    deviceFill(output_log_probs, max_output_len * batch_size, 0.0f);
+    deviceFill(output_ids, max_seq_len * batch_size, 0);
+
+    for (size_t step = 0; step < max_output_len; ++step) {
+        initRandom(h_logits, batch_size * vocab_size, -10.0f, -1.0f);
+        computeProb(h_probs, h_logits, batch_size, vocab_size);
+        cudaH2Dcpy(probs, h_probs, batch_size * vocab_size);
+
+        invokeBatchTopKSampling(workspace,
+                                workspace_size,
+                                // Note that the kernel needs vocab probs instead of
+                                // log-prob if cum_log_probs or output_log_probs are
+                                // provided. It's because the sampling layer already
+                                // preprocesses log_prob_buf when those are provided.
+                                probs,
+                                output_ids + step * batch_size,
+                                sequence_lengths,
+                                finished,
+                                cum_log_probs,
+                                output_log_probs + step * batch_size,
+                                curand_states,
+                                max_top_k,
+                                top_ks,
+                                1.0f,
+                                nullptr,
+                                vocab_size,
+                                end_ids,
+                                stream,
+                                batch_size,
+                                skip_decode);
+
+        // Compute reference.
+        cudaD2Hcpy(h_output_ids, output_ids + step * batch_size, batch_size);
+        cudaD2Hcpy(h_output_log_probs, output_log_probs + step * batch_size, batch_size);
+        cudaD2Hcpy(h_cum_log_probs, cum_log_probs, batch_size);
+        cudaD2Hcpy(h_seq_lengths, sequence_lengths, batch_size);
+        cudaD2Hcpy(h_finished, finished, batch_size);
+        computeLogProb(h_log_probs, h_logits, batch_size, vocab_size);
+        for (size_t i = 0; i < batch_size; ++i) {
+            if (!h_skip_decode[i]) {
+                int idx = i * vocab_size + h_output_ids[i];
+                bool expected_finished = h_output_ids[i] == end_id;
+                float expected_log_prob = (int)step < h_seq_lengths[i] ? (float)h_log_probs[idx] : 0.0f;
+                expected_cum_log_probs[i] += expected_log_prob;
+                EXPECT_TRUE(h_finished[i] == expected_finished);
+            }
+        }
+    }
+    std::string tag = toTestTag("TestBatchTopKSamplingWithSkip", tc, is_fp32);
+    bool passed = checkResult(tag, cum_log_probs, expected_cum_log_probs, batch_size);
+    EXPECT_TRUE(passed);
+
+    delete[] expected_cum_log_probs;
+    delete[] h_seq_lengths;
+    delete[] h_output_log_probs;
+    delete[] h_cum_log_probs;
+    delete[] h_logits;
+    delete[] h_log_probs;
+    delete[] h_probs;
+    delete[] h_output_ids;
+    delete[] h_top_ks;
+    delete allocator;
+    check_cuda_error(cudaStreamDestroy(stream));
+}
+
+template<typename T>
+inline T clip(T val, T minval, T maxval) {
+    if (val < minval) return minval;
+    if (val > maxval) return maxval;
+    return val;
+}
+
+template<typename T>
+void testBatchTopPSamplingKernel(TestCase tc, bool has_diff_runtime_args) {
+    unsigned long long seed = 0;
+
+    size_t batch_size = tc.batch_size;
+    size_t vocab_size = tc.vocab_size;
+
+    float top_p = tc.top_p;
+    float* h_top_ps = new float[batch_size];
+    // Initialize runtime top-p values.
+    for (size_t i = 0; i < batch_size; ++i) {
+        h_top_ps[i] = top_p;
+        if (has_diff_runtime_args) {
+            h_top_ps[i] = clip<float>(h_top_ps[i] + ((i % 2 == 0) ? -0.1 : 0.1), 0.1f, 0.9f);
+        }
+    }
+    float max_top_p = *std::max_element(h_top_ps, h_top_ps + batch_size);
+    int end_id = 3;
+    size_t max_output_len = tc.output_len;
+    size_t max_seq_len = max_output_len;
+
+    cudaStream_t stream;
+    check_cuda_error(cudaStreamCreate(&stream));
+    Allocator<AllocatorType::CUDA>* allocator = new Allocator<AllocatorType::CUDA>(getDevice());
+    allocator->setStream(stream);
+
+    // Logit values in the host of shape (batch_size x vocab_size).
+ T* h_logits = new T[batch_size * vocab_size]; + T* h_probs = new T[batch_size * vocab_size]; + T* h_log_probs = new T[batch_size * vocab_size]; + float* h_cum_log_probs = new float[batch_size]; + float* h_output_log_probs = new float[batch_size]; + float* expected_cum_log_probs = new float[batch_size]; + int* h_output_ids = new int[batch_size]; + int* h_seq_lengths = new int[batch_size]; + bool* h_finished = new bool[batch_size]; + + initRandom(h_logits, batch_size * vocab_size, -10.0f, -1.0f); + memset(expected_cum_log_probs, 0, sizeof(float) * batch_size); + + int device; + cudaGetDevice(&device); + struct cudaDeviceProp device_prop; + cudaGetDeviceProperties(&device_prop, device); + + curandState_t* curand_states = reinterpret_cast( + allocator->malloc(sizeof(curandState_t) * batch_size, false)); + invokeCurandInitialize(curand_states, batch_size, seed, stream); + + float* top_ps = reinterpret_cast(allocator->malloc(sizeof(float) * batch_size)); + int* end_ids = reinterpret_cast(allocator->malloc(sizeof(int) * batch_size)); + int* sequence_lengths = reinterpret_cast(allocator->malloc(sizeof(int) * batch_size)); + bool* finished = reinterpret_cast(allocator->malloc(sizeof(bool) * batch_size)); + T* probs = reinterpret_cast(allocator->malloc(sizeof(T) * batch_size * vocab_size, true)); + float* cum_log_probs = reinterpret_cast(allocator->malloc(sizeof(float) * batch_size)); + float* output_log_probs = reinterpret_cast( + allocator->malloc(sizeof(float) * max_output_len * batch_size)); + int* output_ids = reinterpret_cast(allocator->malloc(sizeof(int) * max_seq_len * batch_size)); + + int* begin_offsets = reinterpret_cast(allocator->malloc(sizeof(int) * (batch_size + 1))); + int* end_offsets = reinterpret_cast(allocator->malloc(sizeof(int) * (batch_size + 1))); + int* topp_id_vals_buf = reinterpret_cast(allocator->malloc(sizeof(int) * batch_size * vocab_size)); + + size_t workspace_size = 0; + size_t cub_temp_storage_size = 0; + // retrieve the workspace size of the top-k sampling kernel. + invokeBatchTopPSampling(nullptr, // workspace + workspace_size, + cub_temp_storage_size, + nullptr, // output_ids + nullptr, // sequence_length + nullptr, // finished_buffer + nullptr, // cum_log_probs + nullptr, // output_log_probs + (T*)nullptr, // log_probs + topp_id_vals_buf, + end_offsets, + begin_offsets, + curand_states, + batch_size, + vocab_size, + nullptr, + max_top_p, + top_ps, + stream, + &device_prop, + nullptr); + void* workspace = allocator->malloc(workspace_size, false); + + // Initialize. 
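+    // As in the top-k tests, the preceding nullptr call only computed the sampling
+    // workspace and CUB temporary-storage sizes; the actual sampling call is issued
+    // inside the generation loop once the decoding state below has been reset.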
+    cudaH2Dcpy(top_ps, h_top_ps, batch_size);
+    deviceFill(end_ids, batch_size, end_id);
+    deviceFill(sequence_lengths, batch_size, 0);
+    deviceFill(finished, batch_size, false);
+    deviceFill(cum_log_probs, batch_size, 0.0f);
+    deviceFill(output_log_probs, max_output_len * batch_size, 0.0f);
+    deviceFill(output_ids, max_seq_len * batch_size, 0);
+
+    for (size_t step = 0; step < max_output_len; ++step) {
+        initRandom(h_logits, batch_size * vocab_size, -10.0f, -1.0f);
+        computeProb(h_probs, h_logits, batch_size, vocab_size);
+        cudaH2Dcpy(probs, h_probs, batch_size * vocab_size);
+
+        invokeTopPInitialize(topp_id_vals_buf,
+                             end_offsets,
+                             begin_offsets,
+                             batch_size,
+                             vocab_size,
+                             stream);
+
+        invokeBatchTopPSampling(workspace,
+                                workspace_size,
+                                cub_temp_storage_size,
+                                output_ids + step * batch_size,
+                                sequence_lengths,
+                                finished,
+                                cum_log_probs,
+                                output_log_probs + step * batch_size,
+                                // Note that the kernel needs vocab probs instead of
+                                // log-prob if cum_log_probs or output_log_probs are
+                                // provided. It's because the sampling layer already
+                                // preprocesses log_prob_buf when those are provided.
+                                probs,
+                                topp_id_vals_buf,
+                                end_offsets,
+                                begin_offsets,
+                                curand_states,
+                                batch_size,
+                                vocab_size,
+                                end_ids,
+                                max_top_p,
+                                top_ps,
+                                stream,
+                                &device_prop,
+                                nullptr);
+
+        // Compute reference.
+        cudaD2Hcpy(h_output_ids, output_ids + step * batch_size, batch_size);
+        cudaD2Hcpy(h_output_log_probs, output_log_probs + step * batch_size, batch_size);
+        cudaD2Hcpy(h_cum_log_probs, cum_log_probs, batch_size);
+        cudaD2Hcpy(h_seq_lengths, sequence_lengths, batch_size);
+        cudaD2Hcpy(h_finished, finished, batch_size);
+        computeLogProb(h_log_probs, h_logits, batch_size, vocab_size);
+        for (size_t i = 0; i < batch_size; ++i) {
+            int idx = i * vocab_size + h_output_ids[i];
+            bool expected_finished = h_output_ids[i] == end_id;
+            float expected_log_prob = (int)step < h_seq_lengths[i] ? (float)h_log_probs[idx] : 0.0f;
+            expected_cum_log_probs[i] += expected_log_prob;
+            EXPECT_TRUE(h_finished[i] == expected_finished);
+        }
+    }
+    std::string tag = toTestTag("TestBatchTopPSamplingKernel", tc, std::is_same<T, float>::value);
+    bool passed = checkResult(tag, cum_log_probs, expected_cum_log_probs, batch_size);
+    EXPECT_TRUE(passed);
+
+    delete[] expected_cum_log_probs;
+    delete[] h_seq_lengths;
+    delete[] h_output_log_probs;
+    delete[] h_cum_log_probs;
+    delete[] h_logits;
+    delete[] h_log_probs;
+    delete[] h_probs;
+    delete[] h_output_ids;
+    delete[] h_top_ps;
+    delete allocator;
+    check_cuda_error(cudaStreamDestroy(stream));
+}
+
+template<typename T>
+void testBatchTopPSamplingWithSkipDecode(TestCase tc) {
+    unsigned long long seed = 0;
+
+    size_t batch_size = tc.batch_size;
+    size_t vocab_size = tc.vocab_size;
+
+    float top_p = tc.top_p;
+    float* h_top_ps = new float[batch_size];
+    // Initialize runtime top-p values.
+    for (size_t i = 0; i < batch_size; ++i) {
+        h_top_ps[i] = i % 2 == 0 ? top_p : 0.3f * top_p;
+    }
+    float max_top_p = *std::max_element(h_top_ps, h_top_ps + batch_size);
+    int end_id = 3;
+    size_t max_output_len = tc.output_len;
+    size_t max_seq_len = max_output_len;
+
+    cudaStream_t stream;
+    check_cuda_error(cudaStreamCreate(&stream));
+    Allocator<AllocatorType::CUDA>* allocator = new Allocator<AllocatorType::CUDA>(getDevice());
+    allocator->setStream(stream);
+
+    // Logit values in the host of shape (batch_size x vocab_size).
+ T* h_logits = new T[batch_size * vocab_size]; + T* h_probs = new T[batch_size * vocab_size]; + T* h_log_probs = new T[batch_size * vocab_size]; + float* h_cum_log_probs = new float[batch_size]; + float* h_output_log_probs = new float[batch_size]; + float* expected_cum_log_probs = new float[batch_size]; + int* h_output_ids = new int[batch_size]; + int* h_seq_lengths = new int[batch_size]; + bool* h_finished = new bool[batch_size]; + bool* h_skip_decode = new bool[batch_size]; + for (size_t i = 0; i < batch_size; ++i) { + h_skip_decode[i] = i % 2 == 0; + } + + initRandom(h_logits, batch_size * vocab_size, -3.0f, -3.0f); + memset(expected_cum_log_probs, 0, sizeof(float) * batch_size); + + int device; + cudaGetDevice(&device); + struct cudaDeviceProp device_prop; + cudaGetDeviceProperties(&device_prop, device); + + curandState_t* curand_states = reinterpret_cast( + allocator->malloc(sizeof(curandState_t) * batch_size, false)); + invokeCurandInitialize(curand_states, batch_size, seed, stream); + + float* top_ps = reinterpret_cast(allocator->malloc(sizeof(float) * batch_size)); + int* end_ids = reinterpret_cast(allocator->malloc(sizeof(int) * batch_size)); + int* sequence_lengths = reinterpret_cast(allocator->malloc(sizeof(int) * batch_size)); + bool* finished = reinterpret_cast(allocator->malloc(sizeof(bool) * batch_size)); + T* probs = reinterpret_cast(allocator->malloc(sizeof(T) * batch_size * vocab_size, true)); + float* cum_log_probs = reinterpret_cast(allocator->malloc(sizeof(float) * batch_size)); + float* output_log_probs = reinterpret_cast( + allocator->malloc(sizeof(float) * max_output_len * batch_size)); + int* output_ids = reinterpret_cast(allocator->malloc(sizeof(int) * max_seq_len * batch_size)); + + int* begin_offsets = reinterpret_cast(allocator->malloc(sizeof(int) * (batch_size + 1))); + int* end_offsets = reinterpret_cast(allocator->malloc(sizeof(int) * (batch_size + 1))); + int* topp_id_vals_buf = reinterpret_cast(allocator->malloc(sizeof(int) * batch_size * vocab_size)); + + bool* skip_decode = reinterpret_cast(allocator->malloc(sizeof(bool) * batch_size)); + cudaH2Dcpy(skip_decode, h_skip_decode, batch_size); + + size_t workspace_size = 0; + size_t cub_temp_storage_size = 0; + // retrieve the workspace size of the top-k sampling kernel. + invokeBatchTopPSampling(nullptr, // workspace + workspace_size, + cub_temp_storage_size, + nullptr, // output_ids + nullptr, // sequence_length + nullptr, // finished_buffer + nullptr, // cum_log_probs + nullptr, // output_log_probs + (T*)nullptr, // log_probs + topp_id_vals_buf, + end_offsets, + begin_offsets, + curand_states, + batch_size, + vocab_size, + nullptr, + max_top_p, + top_ps, + stream, + &device_prop, + nullptr); + void* workspace = allocator->malloc(workspace_size, false); + + // Initialize. 
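+    // Batch entries with skip_decode[i] == true are expected to be left untouched
+    // by the sampling kernel, so the reference below accumulates log-probs only
+    // for the non-skipped indices.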
+ cudaH2Dcpy(top_ps, h_top_ps, batch_size); + deviceFill(end_ids, batch_size, end_id); + deviceFill(sequence_lengths, batch_size, 0); + deviceFill(finished, batch_size, false); + deviceFill(cum_log_probs, batch_size, 0.0f); + deviceFill(output_log_probs, max_output_len * batch_size, 0.0f); + deviceFill(output_ids, max_seq_len * batch_size, 0); + + for (size_t step = 0; step < max_output_len; ++step) { + initRandom(h_logits, batch_size * vocab_size, -3.0f, 3.0f); + computeProb(h_probs, h_logits, batch_size, vocab_size); + cudaH2Dcpy(probs, h_probs, batch_size * vocab_size); + + invokeTopPInitialize(topp_id_vals_buf, + end_offsets, + begin_offsets, + batch_size, + vocab_size, + stream); + + invokeBatchTopPSampling(workspace, + workspace_size, + cub_temp_storage_size, + output_ids + step * batch_size, + sequence_lengths, + finished, + cum_log_probs, + output_log_probs + step * batch_size, + // Note that the kernel needs vocab probs instead of + // log-prob if cum_log_probs or output_log_probs are + // provided. It's because the sampling layer already + // preprocesses log_prob_buf when those are provided. + probs, + topp_id_vals_buf, + end_offsets, + begin_offsets, + curand_states, + batch_size, + vocab_size, + end_ids, + max_top_p, + top_ps, + stream, + &device_prop, + skip_decode); + + // Compute reference. + cudaD2Hcpy(h_output_ids, output_ids + step * batch_size, batch_size); + cudaD2Hcpy(h_output_log_probs, output_log_probs + step * batch_size, batch_size); + cudaD2Hcpy(h_cum_log_probs, cum_log_probs, batch_size); + cudaD2Hcpy(h_seq_lengths, sequence_lengths, batch_size); + cudaD2Hcpy(h_finished, finished, batch_size); + computeLogProb(h_log_probs, h_logits, batch_size, vocab_size); + for (size_t i = 0; i < batch_size; ++i) { + if (!h_skip_decode[i]) { + int idx = i * vocab_size + h_output_ids[i]; + bool expected_finished = h_output_ids[i] == end_id; + float expected_log_prob = (int)step < h_seq_lengths[i] ? 
(float)h_log_probs[idx] : 0.0f; + expected_cum_log_probs[i] += expected_log_prob; + EXPECT_TRUE(h_finished[i] == expected_finished); + } + } + } + std::string tag = toTestTag("TestBatchTopPSamplingWithSkipDecode", tc, std::is_same::value); + bool passed = checkResult(tag, cum_log_probs, expected_cum_log_probs, batch_size); + + delete[] expected_cum_log_probs; + delete[] h_seq_lengths; + delete[] h_output_log_probs; + delete[] h_cum_log_probs; + delete[] h_logits; + delete[] h_log_probs; + delete[] h_probs; + delete[] h_output_ids; + delete[] h_top_ps; + delete allocator; + check_cuda_error(cudaStreamDestroy(stream)); + EXPECT_TRUE(passed); +} + +int main() { + std::vector topk_test_cases { + // TC: name / batch / vocab / beam / k / p / outlen + TestCase{"topk", 6, 4, 1, 1, 0.0f, 1}, + TestCase{"topk", 6, 4, 1, 4, 0.0f, 1}, + TestCase{"topk", 128, 51200, 1, 1, 0.0f, 8}, + TestCase{"topk", 128, 51200, 1, 63, 0.0f, 8} + }; + for (auto &tc : topk_test_cases) { + testTopKSamplingKernel(tc); + testTopKSamplingKernel(tc); + testBatchTopKSamplingKernel(tc, false); + testBatchTopKSamplingKernel(tc, false); + testBatchTopKSamplingKernel(tc, true); + testBatchTopKSamplingKernel(tc, true); + testBatchTopKSamplingWithSkipDecode(tc); + testBatchTopKSamplingWithSkipDecode(tc); + } + + std::vector topp_test_cases { + // TC: name / batch / vocab / beam / k / p / outlen + TestCase{"topp", 6, 4, 1, 0, 0.2f, 1}, + TestCase{"topp", 6, 4, 1, 0, 0.9f, 1}, + TestCase{"topp", 6, 4, 1, 0, 1.0f, 1}, + TestCase{"topp", 128, 51200, 1, 0, 0.8f, 16}, + TestCase{"topp", 128, 51200, 1, 0, 1.0f, 16} + }; + + for (auto &tc : topp_test_cases) { + testBatchTopPSamplingKernel(tc, false); + testBatchTopPSamplingKernel(tc, false); + testBatchTopPSamplingKernel(tc, true); + testBatchTopPSamplingKernel(tc, true); + testBatchTopPSamplingWithSkipDecode(tc); + testBatchTopPSamplingWithSkipDecode(tc); + } + + FT_LOG_INFO("testTopKSamplingKernel done"); + return 0; +} diff --git a/tests/unittests/test_tensor.cu b/tests/unittests/test_tensor.cu new file mode 100644 index 000000000..d386486d6 --- /dev/null +++ b/tests/unittests/test_tensor.cu @@ -0,0 +1,262 @@ +#include +#include +#include + +#include "src/fastertransformer/utils/Tensor.h" + +using namespace fastertransformer; + +class TestFailureError : public std::exception { +private: + std::string msg_; +public: + explicit TestFailureError() = default; + explicit TestFailureError(std::string name, std::string msg = "") { + msg_ = fmtstr("TEST FAIL [%s] %s", name.c_str(), msg.c_str()); + } + const char* what () const throw () { + return msg_.c_str(); + } +}; + +#define EXPECT_TRUE(cond) \ + do { if(!(cond)) { \ + FT_LOG_ERROR("TEST FAIL [%s] at %s:%d", \ + __func__, __FILE__, __LINE__); \ + throw TestFailureError(__func__); \ + } } while(false) + +#define EXPECT_FALSE(cond) \ + do { if(cond) { \ + FT_LOG_ERROR("TEST FAIL [%s] at %s:%d", \ + __func__, __FILE__, __LINE__); \ + throw TestFailureError(__func__); \ + } } while(false) + +#define EXPECT_EQUAL_TENSORS(t1, t2) \ + do { \ + EXPECT_TRUE(t1.where == t2.where); \ + EXPECT_TRUE(t1.type == t2.type); \ + EXPECT_TRUE(t1.shape == t2.shape); \ + EXPECT_TRUE(t1.data == t2.data); \ + } while(false) + +void testTensorMapHasKey() { + bool* v1 = new bool(true); + float* v2 = new float[6]{1.0f, 1.1f, 1.2f, 1.3f, 1.4f, 1.5f}; + Tensor t1 = Tensor{MEMORY_CPU, TYPE_BOOL, {1}, v1}; + Tensor t2 = Tensor{MEMORY_CPU, TYPE_FP32, {3, 2}, v2}; + + TensorMap map({{"t1", t1}, {"t2", t2}}); + EXPECT_TRUE(map.isExist("t1")); + 
EXPECT_TRUE(map.isExist("t2")); + EXPECT_FALSE(map.isExist("t3")); + + delete v1; + delete[] v2; +} + +void testTensorMapInsert() { + int* v1 = new int[4]{1, 10, 20, 30}; + float* v2 = new float[2]{1.0f, 2.0f}; + Tensor t1 = Tensor(MEMORY_CPU, TYPE_INT32, {4}, v1); + Tensor t2 = Tensor(MEMORY_CPU, TYPE_INT32, {2}, v2); + + TensorMap map({{"t1", t1}}); + EXPECT_TRUE(map.size() == 1); + EXPECT_TRUE(map.isExist("t1")); + EXPECT_EQUAL_TENSORS(map.at("t1"), t1); + EXPECT_FALSE(map.isExist("t2")); + + // forbid a none tensor. + try { + map.insert("none", {}); + map.insert("empty", Tensor(MEMORY_CPU, TYPE_INT32, {}, nullptr)); + EXPECT_TRUE(false); + } catch (std::runtime_error& e) { + EXPECT_TRUE(true); + } + EXPECT_TRUE(map.size() == 1); + + // forbid a duplicated key. + try { + map.insert("t1", t2); + EXPECT_TRUE(false); + } catch (std::runtime_error& e) { + EXPECT_TRUE(true); + } + EXPECT_TRUE(map.size() == 1); + + map.insert("t2", t2); + EXPECT_TRUE(map.size() == 2); + EXPECT_EQUAL_TENSORS(map.at("t2"), t2); + + delete[] v1; + delete[] v2; +} + +void testTensorMapGetVal() { + int* v1 = new int[4]{1, 10, 20, 30}; + Tensor t1 = Tensor(MEMORY_CPU, TYPE_INT32, {4}, v1); + + TensorMap map({{"t1", t1}}); + EXPECT_TRUE(map.size() == 1); + + try { + int val = map.getVal("t3"); + EXPECT_TRUE(false); + } catch(std::runtime_error& e) { + EXPECT_TRUE(true); + } + EXPECT_TRUE(map.getVal("t1") == 1); + EXPECT_TRUE(map.getVal("t1", 3) == 1); + EXPECT_TRUE(map.getVal("t2", 3) == 3); + + v1[0] += 1; // update value. + EXPECT_TRUE(map.getVal("t1") == 2); + EXPECT_TRUE(map.getVal("t1", 3) == 2); + + size_t index = 2; + EXPECT_TRUE(map.getValWithOffset("t1", index) == 20); + EXPECT_TRUE(map.getValWithOffset("t1", index, 3) == 20); + EXPECT_TRUE(map.getValWithOffset("t2", index, 3) == 3); + delete[] v1; +} + +void testTensorMapGetTensor() { + bool* t1_val = new bool(true); + float* t2_val = new float[6]{1.0f, 1.1f, 1.2f, 1.3f, 1.4f, 1.5f}; + Tensor t1 = Tensor{MEMORY_CPU, TYPE_BOOL, {1}, t1_val}; + Tensor t2 = Tensor{MEMORY_CPU, TYPE_FP32, {3, 2}, t2_val}; + + int* default_val = new int[4]{0, 1, 2, 3}; + Tensor default_tensor = Tensor{MEMORY_CPU, TYPE_INT32, {4}, default_val}; + + TensorMap map({{"t1", t1}, {"t2", t2}}); + + try { + Tensor t = map.at("t3"); + EXPECT_TRUE(false); + } catch(std::runtime_error& e) { + EXPECT_TRUE(true); + } + EXPECT_EQUAL_TENSORS(map.at("t1", default_tensor), t1); + EXPECT_EQUAL_TENSORS(map.at("t2", default_tensor), t2); + EXPECT_EQUAL_TENSORS(map.at("t3", default_tensor), default_tensor); + EXPECT_EQUAL_TENSORS(map.at("t3", {}), Tensor()); + + delete[] default_val; + delete[] t2_val; + delete[] t1_val; +} + +void testEmptyTensorMinMaxRaiseError() { + Tensor t1; + try { + int minval = t1.min(); + int maxval = t1.max(); + EXPECT_TRUE(false); + } catch (std::runtime_error& e) { + EXPECT_TRUE(true); + } + + Tensor t2 = Tensor{MEMORY_CPU, TYPE_INT32, {1}, nullptr}; + try { + int minval = t2.min(); + int maxval = t2.max(); + EXPECT_TRUE(false); + } catch (std::runtime_error& e) { + EXPECT_TRUE(true); + } +} + +template +void testTensorMinMax() { + constexpr int SIZE = 4; + constexpr T MAX_VAL = T(4); + constexpr T MIN_VAL = T(1); + + T* v1 = new T[SIZE]{T(1), T(2), T(3), T(4)}; + T* v2 = new T[SIZE]{T(4), T(3), T(2), T(1)}; + T* v3 = new T[SIZE]{T(1), T(2), T(4), T(3)}; + Tensor t1 = Tensor{MEMORY_CPU, getTensorType(), {SIZE}, v1}; + Tensor t2 = Tensor{MEMORY_CPU, getTensorType(), {SIZE}, v2}; + Tensor t3 = Tensor{MEMORY_CPU, getTensorType(), {SIZE}, v3}; + + EXPECT_TRUE(t1.max() == 
MAX_VAL); + EXPECT_TRUE(t2.max() == MAX_VAL); + EXPECT_TRUE(t3.max() == MAX_VAL); + EXPECT_TRUE(t1.min() == MIN_VAL); + EXPECT_TRUE(t2.min() == MIN_VAL); + EXPECT_TRUE(t3.min() == MIN_VAL); + + delete[] v1; + delete[] v2; + delete[] v3; +} + +template +void testTensorAny() { + constexpr int SIZE = 4; + T* v = new T[SIZE]{T(1), T(2), T(3), T(4)}; + Tensor t = Tensor{MEMORY_CPU, getTensorType(), {SIZE}, v}; + EXPECT_TRUE(t.any(T(1))); + EXPECT_FALSE(t.any(T(5))); + delete[] v; +} + +template +void testTensorAll() { + constexpr int SIZE = 4; + T* v1 = new T[SIZE]{T(1), T(1), T(1), T(1)}; + T* v2 = new T[SIZE]{T(1), T(1), T(1), T(2)}; + Tensor t1 = Tensor{MEMORY_CPU, getTensorType(), {SIZE}, v1}; + Tensor t2 = Tensor{MEMORY_CPU, getTensorType(), {SIZE}, v2}; + EXPECT_TRUE(t1.all(T(1))); + EXPECT_FALSE(t2.all(T(2))); + delete[] v1; + delete[] v2; +} + + +template +void testTensorSlice() { + constexpr int SIZE = 12; + T* v = new T[SIZE]; + for (int i = 0; i < SIZE; ++i) { + v[i] = i; + } + DataType dtype = getTensorType(); + Tensor t1 = Tensor(MEMORY_CPU, dtype, {3, 4}, v); + Tensor t2 = t1.slice({2, 4}, 4); + EXPECT_EQUAL_TENSORS(t2, Tensor(MEMORY_CPU, dtype, {2, 4}, &v[4])); + try { + Tensor overflowed_tensor = t1.slice({2, 4}, 5); + EXPECT_TRUE(false); + } catch (std::runtime_error& e) { + EXPECT_TRUE(true); + } + delete[] v; +} + +int main() { + testTensorMapHasKey(); + testTensorMapInsert(); + testTensorMapGetVal(); + testTensorMapGetTensor(); + testEmptyTensorMinMaxRaiseError(); + testTensorMinMax(); + testTensorMinMax(); + testTensorMinMax(); + testTensorAny(); + testTensorAny(); + testTensorAny(); + testTensorAll(); + testTensorAll(); + testTensorAll(); + testTensorSlice(); + testTensorSlice(); + testTensorSlice(); + FT_LOG_INFO("Test Done"); + return 0; +} diff --git a/tests/unittests/unittest_utils.h b/tests/unittests/unittest_utils.h new file mode 100644 index 000000000..f6a4c32f9 --- /dev/null +++ b/tests/unittests/unittest_utils.h @@ -0,0 +1,192 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include // min, max +#include // assert +#include // FLT_MAX +#include // snprintf +#include // expf, log +#include // numeric_limits +#include // rand +#include // string +#include // vector + +#include "src/fastertransformer/utils/cuda_utils.h" +#include "src/fastertransformer/utils/memory_utils.h" +#include "src/fastertransformer/utils/string_utils.h" + +#define PRINT_LIMIT 16 +#define EPSILON (1e-20) +#define EPSILON_FP16 (1e-10) + +using namespace fastertransformer; + +class TestFailureError : public std::exception { +private: + std::string msg_; +public: + explicit TestFailureError() = default; + explicit TestFailureError(std::string name, std::string msg = "") { + msg_ = fmtstr("TEST FAIL [%s] %s", name.c_str(), msg.c_str()); + } + const char* what () const throw () { + return msg_.c_str(); + } +}; + +#define EXPECT_TRUE(cond) \ + do { if(!(cond)) { \ + FT_LOG_ERROR("TEST FAIL [%s]: %s at %s:%d", \ + __func__, #cond, __FILE__, __LINE__); \ + throw TestFailureError(__func__); \ + } } while(false) + +#define EXPECT_FALSE(cond) \ + do { if(cond) { \ + FT_LOG_ERROR("TEST FAIL [%s]: %s at %s:%d", \ + __func__, #cond, __FILE__, __LINE__); \ + throw TestFailureError(__func__); \ + } } while(false) + +bool almostEqual(float a, float b, float atol = 1e-5, float rtol = 1e-8) +{ + // Params: a = value to compare and b = reference + // This function follows implementation of numpy.isclose(), which checks + // abs(a - b) <= (atol + rtol * abs(b)). + // Note that the inequality above is asymmetric where b is considered as + // a reference value. To account into both absolute/relative errors, it + // uses absolute tolerance and relative tolerance at the same time. The + // default values of atol and rtol borrowed from numpy.isclose(). For the + // case of nan value, the result will be true. + if (isnan(a) && isnan(b)) { + return true; + } + return fabs(a - b) <= (atol + rtol * fabs(b)); +} + +template +bool checkResult(std::string name, T* out, T*ref, size_t size, float atol, float rtol) { + size_t failures = 0; + float relative_gap = 0.0f;; + + for (size_t i = 0; i < size; ++i) { + // The values for the output and the reference. + float a = (float)out[i]; + float b = (float)ref[i]; + + bool ok = almostEqual(a, b, atol, rtol); + // Print the error. + if (!ok && failures < 4) { + FT_LOG_ERROR(">> invalid result for i=%lu:", i); + FT_LOG_ERROR(">> found......: %10.6f", a); + FT_LOG_ERROR(">> expected...: %10.6f", b); + FT_LOG_ERROR(">> error......: %.6f", fabsf(a - b)); + FT_LOG_ERROR(">> tol........: %.6f", atol + rtol * fabs(b)); + } + // Update the number of failures. + failures += ok ? 0 : 1; + // Update the relative gap. + relative_gap += fabsf(a - b) / (fabsf(b) + EPSILON); + } + + relative_gap /= size; + + // Allow not matched up to 1% elements. + size_t tol_failures = (size_t)(0.01 * size); + FT_LOG_INFO("check...%6s : %-50s (failures: %.2f%% atol: %.2e rtol: %.2e rel_gap: %.2e%%)", + failures <= tol_failures ? "....OK" : "FAILED", name.c_str(), + 100. * failures / size, atol, rtol, 100. * relative_gap); + return failures <= tol_failures; +} + +template +bool checkResult(std::string name, T* out, T* ref, size_t size, + bool device_out = true, bool device_ref = false) +{ + bool is_fp32 = sizeof(T) == 4; + float atol = is_fp32 ? 1e-4f : 1e-3f; + float rtol = is_fp32 ? 
1e-2f : 1e-1f; + + T* h_out = nullptr; + if (device_out) { + h_out = new T[size]; + cudaMemcpy(h_out, out, sizeof(T) * size, cudaMemcpyDeviceToHost); + out = h_out; + } + T* h_ref = nullptr; + if (device_ref) { + h_ref = new T[size]; + cudaMemcpy(h_ref, ref, sizeof(T) * size, cudaMemcpyDeviceToHost); + ref = h_ref; + } + bool is_ok = checkResult(name, out, ref, size, atol, rtol); + if (h_out != nullptr){ + delete[] h_out; + } + if (h_ref != nullptr) { + delete[] h_ref; + } + return is_ok; +} + +template +void initRandom(T* ptr, size_t size, float minval, float maxval) { + for (size_t i = 0; i < size; ++i) { + float val = static_cast(rand()) / static_cast(RAND_MAX); + val *= (maxval - minval); + ptr[i] = static_cast(minval + val); + } +} + +void initRandomInt(int* ptr, size_t size, int minval, int maxval) { + assert(minval < maxval); + int mod = maxval - minval; + for (size_t i = 0; i < size; ++i) { + ptr[i] = minval + rand() % mod; + } +} + +template +void tile(T* x, int m, int n) { + for (int i = 1; i < m; ++i) { + for (int j = 0; j < n; ++j) { + x[i * n + j] = x[j]; + } + } +} + +template +void tile(T* dst, T* src, int m, int n) { + for (int i = 1; i < m; ++i) { + for (int j = 0; j < n; ++j) { + dst[i * n + j] = src[j]; + } + } +} + +#define HALF_FLT_MAX 65504.0f + +template +bool isHalf() { + return std::is_same::value; +} + +template +static inline void printMatrixWithLimit(T* ptr, int m, int k, int stride, bool is_device_ptr) { + printMatrix(ptr, std::min(PRINT_LIMIT, m), std::min(PRINT_LIMIT, k), stride, is_device_ptr); +}
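+
+// Illustrative usage of the helpers above (not part of any test): fill a host
+// buffer with random values, mirror it on the device, and compare the device
+// copy against the host reference using the default fp32 tolerances.
+//
+//   const size_t n = 1024;
+//   std::vector<float> h_ref(n);
+//   initRandom(h_ref.data(), n, -1.0f, 1.0f);
+//   float* d_out = nullptr;
+//   cudaMalloc(&d_out, sizeof(float) * n);
+//   cudaMemcpy(d_out, h_ref.data(), sizeof(float) * n, cudaMemcpyHostToDevice);
+//   EXPECT_TRUE(checkResult("identity check", d_out, h_ref.data(), n));
+//   cudaFree(d_out);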