From 43ae78abfaa13a920ac1930c23615fe28c0e9819 Mon Sep 17 00:00:00 2001
From: prnake
Date: Mon, 25 Jul 2022 07:31:49 +0800
Subject: [PATCH] Reduce max memory usage in init stage (#246)

Use vector::reserve to avoid the memory spike caused by the C++ vector's
expansion behavior when initializing weights.
---
 src/fastertransformer/models/bert/BertWeight.h | 3 +++
 src/fastertransformer/models/gpt/GptWeight.cc | 2 ++
 src/fastertransformer/models/gptj/GptJWeight.cc | 2 ++
 .../models/multi_gpu_gpt/ParallelGptWeight.cc | 2 ++
 src/fastertransformer/models/t5/T5EncoderWeight.cc | 3 +++
 src/fastertransformer/models/vit/ViTWeight.h | 3 +++
 src/fastertransformer/models/vit_int8/ViTINT8Weight.h | 3 +++
 7 files changed, 18 insertions(+)

diff --git a/src/fastertransformer/models/bert/BertWeight.h b/src/fastertransformer/models/bert/BertWeight.h
index 67ead51d3..ec9341361 100644
--- a/src/fastertransformer/models/bert/BertWeight.h
+++ b/src/fastertransformer/models/bert/BertWeight.h
@@ -31,6 +31,7 @@ struct BertWeight {
         deviceMalloc(&weights_ptr[1], hidden_units_);
         setWeightPtr();
 
+        bert_layer_weights.reserve(num_layer_);
         for (int i = 0; i < num_layer_; i++) {
             bert_layer_weights.push_back(BertLayerWeight(hidden_units_, inter_size_));
         }
@@ -54,6 +55,7 @@ struct BertWeight {
         hidden_units_(other.hidden_units_), inter_size_(other.inter_size_), num_layer_(other.num_layer_)
     {
         bert_layer_weights.clear();
+        bert_layer_weights.reserve(num_layer_);
         for (int i = 0; i < num_layer_; i++) {
             bert_layer_weights.push_back(other.bert_layer_weights[i]);
         }
@@ -71,6 +73,7 @@ struct BertWeight {
         inter_size_ = other.inter_size_;
         num_layer_ = other.num_layer_;
         bert_layer_weights.clear();
+        bert_layer_weights.reserve(num_layer_);
         for (int i = 0; i < num_layer_; i++) {
             bert_layer_weights.push_back(other.bert_layer_weights[i]);
         }
diff --git a/src/fastertransformer/models/gpt/GptWeight.cc b/src/fastertransformer/models/gpt/GptWeight.cc
index c4a7b6d63..31ccc06be 100644
--- a/src/fastertransformer/models/gpt/GptWeight.cc
+++ b/src/fastertransformer/models/gpt/GptWeight.cc
@@ -27,6 +27,7 @@ GptWeight::GptWeight(
     num_layer_(num_layer),
     max_seq_len_(max_seq_len)
 {
+    decoder_layer_weights.reserve(num_layer_);
     for (int l = 0; l < num_layer_; l++) {
         decoder_layer_weights.push_back(GptDecoderLayerWeight(hidden_units_, inter_size_));
     }
@@ -70,6 +71,7 @@ GptWeight::GptWeight(const GptWeight& other):
     setWeightPtr();
 
     decoder_layer_weights.clear();
+    decoder_layer_weights.reserve(num_layer_);
     for (int l = 0; l < num_layer_; l++) {
         decoder_layer_weights.push_back(other.decoder_layer_weights[l]);
     }
diff --git a/src/fastertransformer/models/gptj/GptJWeight.cc b/src/fastertransformer/models/gptj/GptJWeight.cc
index 241bbf5a4..ce8441db2 100644
--- a/src/fastertransformer/models/gptj/GptJWeight.cc
+++ b/src/fastertransformer/models/gptj/GptJWeight.cc
@@ -38,6 +38,7 @@ GptJWeight::GptJWeight(const int hidden_units,
     layer_para_size_(layer_para_size),
     layer_para_rank_(layer_para_rank)
 {
+    decoder_layer_weights.reserve(num_layer_);
     for (int l = 0; l < num_layer_; l++) {
         if (isValidLayerParallelId(l)) {
             decoder_layer_weights.push_back(
@@ -92,6 +93,7 @@ GptJWeight::GptJWeight(const GptJWeight& other):
     setWeightPtr();
 
     decoder_layer_weights.clear();
+    decoder_layer_weights.reserve(num_layer_);
     for (int l = 0; l < num_layer_; l++) {
         decoder_layer_weights.push_back(other.decoder_layer_weights[l]);
     }
diff --git a/src/fastertransformer/models/multi_gpu_gpt/ParallelGptWeight.cc b/src/fastertransformer/models/multi_gpu_gpt/ParallelGptWeight.cc
index 489b07a66..ee3a82b57 100644
--- a/src/fastertransformer/models/multi_gpu_gpt/ParallelGptWeight.cc
+++ b/src/fastertransformer/models/multi_gpu_gpt/ParallelGptWeight.cc
@@ -41,6 +41,7 @@ ParallelGptWeight::ParallelGptWeight(const int hidden_units,
     int8_mode_(int8_mode)
 {
     decoder_layer_weights.clear();
+    decoder_layer_weights.reserve(num_layer_);
     for (int l = 0; l < num_layer_; l++) {
         if (isValidLayerParallelId(l)) {
             decoder_layer_weights.push_back(new ParallelGptDecoderLayerWeight(
@@ -99,6 +100,7 @@ ParallelGptWeight::ParallelGptWeight(const ParallelGptWeight& other):
     setWeightPtr();
 
     decoder_layer_weights.clear();
+    decoder_layer_weights.reserve(num_layer_);
     for (int l = 0; l < num_layer_; l++) {
         decoder_layer_weights.push_back(other.decoder_layer_weights[l]);
     }
diff --git a/src/fastertransformer/models/t5/T5EncoderWeight.cc b/src/fastertransformer/models/t5/T5EncoderWeight.cc
index 8057bd173..7995e0d7e 100644
--- a/src/fastertransformer/models/t5/T5EncoderWeight.cc
+++ b/src/fastertransformer/models/t5/T5EncoderWeight.cc
@@ -53,6 +53,7 @@ T5EncoderWeight::T5EncoderWeight(const size_t head_num,
     mallocWeights();
     setWeightPtr();
     t5_encoder_layer_weights.clear();
+    t5_encoder_layer_weights.reserve(num_layer_);
     for (int l = 0; l < num_layer_; l++) {
         if (isValidLayerParallelId(l)) {
             t5_encoder_layer_weights.push_back(new T5EncoderLayerWeight(
@@ -130,6 +131,7 @@ T5EncoderWeight::T5EncoderWeight(const T5EncoderWeight& other):
     setWeightPtr();
 
     t5_encoder_layer_weights.clear();
+    t5_encoder_layer_weights.reserve(num_layer_);
     for (int i = 0; i < num_layer_; i++) {
         t5_encoder_layer_weights.push_back(new T5EncoderLayerWeight(*other.t5_encoder_layer_weights[i]));
     }
@@ -163,6 +165,7 @@ T5EncoderWeight& T5EncoderWeight::operator=(const T5EncoderWeight& other)
     setWeightPtr();
 
     t5_encoder_layer_weights.clear();
+    t5_encoder_layer_weights.reserve(num_layer_);
     for (int i = 0; i < num_layer_; i++) {
         t5_encoder_layer_weights.push_back(new T5EncoderLayerWeight(*other.t5_encoder_layer_weights[i]));
     }
diff --git a/src/fastertransformer/models/vit/ViTWeight.h b/src/fastertransformer/models/vit/ViTWeight.h
index 356a6f6d9..6d3b7ea66 100644
--- a/src/fastertransformer/models/vit/ViTWeight.h
+++ b/src/fastertransformer/models/vit/ViTWeight.h
@@ -68,6 +68,7 @@ struct ViTWeight {
             setWeightPtr();
         }
 
+        vit_layer_weights.reserve(num_layer_);
         for (int i = 0; i < num_layer_; i++) {
             vit_layer_weights.push_back(ViTLayerWeight(embed_dim_, inter_size_, i, hold_buffer));
         }
@@ -115,6 +116,7 @@ struct ViTWeight {
         }
 
         vit_layer_weights.clear();
+        vit_layer_weights.reserve(num_layer_);
         for (int i = 0; i < num_layer_; i++) {
             vit_layer_weights.push_back(other.vit_layer_weights[i]);
         }
@@ -143,6 +145,7 @@ struct ViTWeight {
         }
 
         vit_layer_weights.clear();
+        vit_layer_weights.reserve(num_layer_);
        for (int i = 0; i < num_layer_; i++) {
             vit_layer_weights.push_back(other.vit_layer_weights[i]);
         }
diff --git a/src/fastertransformer/models/vit_int8/ViTINT8Weight.h b/src/fastertransformer/models/vit_int8/ViTINT8Weight.h
index 2a5b54def..e5451ad4f 100644
--- a/src/fastertransformer/models/vit_int8/ViTINT8Weight.h
+++ b/src/fastertransformer/models/vit_int8/ViTINT8Weight.h
@@ -60,6 +60,7 @@ struct ViTINT8Weight {
         deviceMalloc(&weights_ptr[5], embed_dim_); // pre_encoder_conv_weights.bias
         setWeightPtr();
 
+        vit_layer_weights.reserve(num_layer_);
         for (int i = 0; i < num_layer_; i++) {
             vit_layer_weights.push_back(ViTLayerINT8Weight(embed_dim_, inter_size_));
         }
@@ -94,6 +95,7 @@ struct ViTINT8Weight {
         cls_num_(other.cls_num_)
     {
         vit_layer_weights.clear();
+        vit_layer_weights.reserve(num_layer_);
         for (int i = 0; i < num_layer_; i++) {
             vit_layer_weights.push_back(other.vit_layer_weights[i]);
         }
@@ -132,6 +134,7 @@ struct ViTINT8Weight {
         chn_num_ = other.chn_num_;
         cls_num_ = other.cls_num_;
         vit_layer_weights.clear();
+        vit_layer_weights.reserve(num_layer_);
         for (int i = 0; i < num_layer_; i++) {
             vit_layer_weights.push_back(other.vit_layer_weights[i]);
         }
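
Why reserve helps here: std::vector grows its capacity geometrically, and each reallocation copy-constructs every element already in the vector into the new buffer before the old buffer is released. Where these vectors hold the layer-weight structs by value, the copy constructors themselves allocate weight buffers, so during init the peak footprint can briefly approach twice the size of the layers pushed so far; where the vectors hold raw pointers (ParallelGptWeight, T5EncoderWeight), reserve simply avoids repeated reallocation of the pointer array. Calling reserve(num_layer_) before the push_back loop fixes the capacity once, so no reallocation and no transient duplicate allocation occurs. The snippet below is a minimal standalone sketch of this effect, not FasterTransformer code: LayerWeight and fill_weights are illustrative names, and a plain heap buffer stands in for the (device) allocations done by the real weight structs.

    // Sketch: why push_back without reserve causes a transient memory/copy spike.
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // Stand-in for a per-layer weight struct whose copy constructor allocates a
    // fresh buffer (as the real layer-weight copy constructors do).
    struct LayerWeight {
        explicit LayerWeight(size_t n): size(n), buf(new float[n]) {}
        LayerWeight(const LayerWeight& other): size(other.size), buf(new float[other.size])
        {
            ++copies;  // every vector reallocation copy-constructs all existing elements
        }
        ~LayerWeight() { delete[] buf; }
        size_t      size;
        float*      buf;
        static int  copies;
    };
    int LayerWeight::copies = 0;

    static void fill_weights(std::vector<LayerWeight>& v, int num_layer, bool reserve_first)
    {
        LayerWeight::copies = 0;
        if (reserve_first) {
            v.reserve(num_layer);  // capacity fixed up front -> no reallocation while pushing
        }
        for (int i = 0; i < num_layer; i++) {
            v.push_back(LayerWeight(1 << 16));  // each push_back copies once from the temporary
        }
        printf("reserve=%d  capacity=%zu  element copies=%d\n",
               reserve_first ? 1 : 0, v.capacity(), LayerWeight::copies);
    }

    int main()
    {
        const int num_layer = 48;
        { std::vector<LayerWeight> v; fill_weights(v, num_layer, false); }  // geometric growth
        { std::vector<LayerWeight> v; fill_weights(v, num_layer, true);  }  // single allocation
        return 0;
    }

With a typical doubling growth policy, the first run ends with a capacity larger than num_layer and substantially more copy constructions (and thus transient allocations) than the second run, which reserves up front and copies each element only once, from the push_back temporary. That is the same effect the patch removes at weight-init time.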