Reduce max memory usage in init stage (#246)
Use vector::reserve to eliminate the memory spike caused by std::vector's reallocate-and-copy growth when initializing weights.
prnake authored Jul 24, 2022
1 parent 0313c99 commit 43ae78a
Showing 7 changed files with 18 additions and 0 deletions.
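
Every file in this change applies the same pattern: call reserve(num_layer_) on the layer-weight vector before the push_back loop, so the vector allocates its final capacity once instead of growing by reallocating and copying. The sketch below is a minimal, self-contained example, not taken from the repository; FakeLayerWeight, kNumLayers, and kElemsPerLayer are hypothetical stand-ins for the per-layer weight types (BertLayerWeight, GptDecoderLayerWeight, ...) and model sizes, written to mirror how those types behave: copyable via a deep copy and without a move constructor, so every reallocation copies all existing layers.

#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a per-layer weight struct. Like the per-layer
// weight types above, it is copyable (deep copy of its buffer) and declares
// no move constructor, so std::vector must copy elements when it reallocates.
struct FakeLayerWeight {
    std::size_t size;
    float* buf;

    explicit FakeLayerWeight(std::size_t n) : size(n), buf(new float[n]()) {}

    FakeLayerWeight(const FakeLayerWeight& other) : size(other.size), buf(new float[other.size]) {
        std::copy(other.buf, other.buf + other.size, buf);  // deep copy of the buffer
    }

    FakeLayerWeight& operator=(const FakeLayerWeight&) = delete;

    ~FakeLayerWeight() { delete[] buf; }
};

int main() {
    const int kNumLayers = 48;                    // hypothetical layer count
    const std::size_t kElemsPerLayer = 1u << 22;  // ~16 MB of floats per layer

    std::vector<FakeLayerWeight> weights;

    // The pattern added by this commit: size the vector once, up front. Without
    // this call, each push_back that exceeds the current capacity allocates a
    // larger block and copies every existing layer into it before freeing the
    // old block, so peak memory transiently approaches two copies of the weights.
    weights.reserve(kNumLayers);

    for (int l = 0; l < kNumLayers; ++l) {
        weights.push_back(FakeLayerWeight(kElemsPerLayer));
    }
    return 0;
}

With the reserve call in place, the loop performs a single allocation for the vector's storage, and peak memory during initialization stays at one copy of the weights plus the single temporary layer being pushed, rather than transiently approaching two copies.
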
3 changes: 3 additions & 0 deletions src/fastertransformer/models/bert/BertWeight.h
@@ -31,6 +31,7 @@ struct BertWeight {
deviceMalloc(&weights_ptr[1], hidden_units_);

setWeightPtr();
bert_layer_weights.reserve(num_layer_);
for (int i = 0; i < num_layer_; i++) {
bert_layer_weights.push_back(BertLayerWeight<T>(hidden_units_, inter_size_));
}
@@ -54,6 +55,7 @@ struct BertWeight {
hidden_units_(other.hidden_units_), inter_size_(other.inter_size_), num_layer_(other.num_layer_)
{
bert_layer_weights.clear();
bert_layer_weights.reserve(num_layer_);
for (int i = 0; i < num_layer_; i++) {
bert_layer_weights.push_back(other.bert_layer_weights[i]);
}
@@ -71,6 +73,7 @@ struct BertWeight {
inter_size_ = other.inter_size_;
num_layer_ = other.num_layer_;
bert_layer_weights.clear();
bert_layer_weights.reserve(num_layer_);
for (int i = 0; i < num_layer_; i++) {
bert_layer_weights.push_back(other.bert_layer_weights[i]);
}
2 changes: 2 additions & 0 deletions src/fastertransformer/models/gpt/GptWeight.cc
@@ -27,6 +27,7 @@ GptWeight<T>::GptWeight(
num_layer_(num_layer),
max_seq_len_(max_seq_len)
{
decoder_layer_weights.reserve(num_layer_);
for (int l = 0; l < num_layer_; l++) {
decoder_layer_weights.push_back(GptDecoderLayerWeight<T>(hidden_units_, inter_size_));
}
@@ -70,6 +71,7 @@ GptWeight<T>::GptWeight(const GptWeight& other):
setWeightPtr();

decoder_layer_weights.clear();
decoder_layer_weights.reserve(num_layer_);
for (int l = 0; l < num_layer_; l++) {
decoder_layer_weights.push_back(other.decoder_layer_weights[l]);
}
2 changes: 2 additions & 0 deletions src/fastertransformer/models/gptj/GptJWeight.cc
@@ -38,6 +38,7 @@ GptJWeight<T>::GptJWeight(const int hidden_units,
layer_para_size_(layer_para_size),
layer_para_rank_(layer_para_rank)
{
decoder_layer_weights.reserve(num_layer_);
for (int l = 0; l < num_layer_; l++) {
if (isValidLayerParallelId(l)) {
decoder_layer_weights.push_back(
@@ -92,6 +93,7 @@ GptJWeight<T>::GptJWeight(const GptJWeight& other):
setWeightPtr();

decoder_layer_weights.clear();
decoder_layer_weights.reserve(num_layer_);
for (int l = 0; l < num_layer_; l++) {
decoder_layer_weights.push_back(other.decoder_layer_weights[l]);
}
2 changes: 2 additions & 0 deletions src/fastertransformer/models/multi_gpu_gpt/ParallelGptWeight.cc
@@ -41,6 +41,7 @@ ParallelGptWeight<T>::ParallelGptWeight(const int hidden_units,
int8_mode_(int8_mode)
{
decoder_layer_weights.clear();
decoder_layer_weights.reserve(num_layer_);
for (int l = 0; l < num_layer_; l++) {
if (isValidLayerParallelId(l)) {
decoder_layer_weights.push_back(new ParallelGptDecoderLayerWeight<T>(
@@ -99,6 +100,7 @@ ParallelGptWeight<T>::ParallelGptWeight(const ParallelGptWeight& other):
setWeightPtr();

decoder_layer_weights.clear();
decoder_layer_weights.reserve(num_layer_);
for (int l = 0; l < num_layer_; l++) {
decoder_layer_weights.push_back(other.decoder_layer_weights[l]);
}
3 changes: 3 additions & 0 deletions src/fastertransformer/models/t5/T5EncoderWeight.cc
@@ -53,6 +53,7 @@ T5EncoderWeight<T>::T5EncoderWeight(const size_t head_num,
mallocWeights();
setWeightPtr();
t5_encoder_layer_weights.clear();
t5_encoder_layer_weights.reserve(num_layer_);
for (int l = 0; l < num_layer_; l++) {
if (isValidLayerParallelId(l)) {
t5_encoder_layer_weights.push_back(new T5EncoderLayerWeight<T>(
@@ -130,6 +131,7 @@ T5EncoderWeight<T>::T5EncoderWeight(const T5EncoderWeight& other):
setWeightPtr();

t5_encoder_layer_weights.clear();
t5_encoder_layer_weights.reserve(num_layer_);
for (int i = 0; i < num_layer_; i++) {
t5_encoder_layer_weights.push_back(new T5EncoderLayerWeight<T>(*other.t5_encoder_layer_weights[i]));
}
@@ -163,6 +165,7 @@ T5EncoderWeight<T>& T5EncoderWeight<T>::operator=(const T5EncoderWeight& other)
setWeightPtr();

t5_encoder_layer_weights.clear();
t5_encoder_layer_weights.reserve(num_layer_);
for (int i = 0; i < num_layer_; i++) {
t5_encoder_layer_weights.push_back(new T5EncoderLayerWeight<T>(*other.t5_encoder_layer_weights[i]));
}
3 changes: 3 additions & 0 deletions src/fastertransformer/models/vit/ViTWeight.h
@@ -68,6 +68,7 @@ struct ViTWeight {
setWeightPtr();
}

vit_layer_weights.reserve(num_layer_);
for (int i = 0; i < num_layer_; i++) {
vit_layer_weights.push_back(ViTLayerWeight<T>(embed_dim_, inter_size_, i, hold_buffer));
}
@@ -115,6 +116,7 @@ struct ViTWeight {
}

vit_layer_weights.clear();
vit_layer_weights.reserve(num_layer_);
for (int i = 0; i < num_layer_; i++) {
vit_layer_weights.push_back(other.vit_layer_weights[i]);
}
@@ -143,6 +145,7 @@ struct ViTWeight {
}

vit_layer_weights.clear();
vit_layer_weights.reserve(num_layer_);
for (int i = 0; i < num_layer_; i++) {
vit_layer_weights.push_back(other.vit_layer_weights[i]);
}
3 changes: 3 additions & 0 deletions src/fastertransformer/models/vit_int8/ViTINT8Weight.h
@@ -60,6 +60,7 @@ struct ViTINT8Weight {
deviceMalloc(&weights_ptr[5], embed_dim_); // pre_encoder_conv_weights.bias

setWeightPtr();
vit_layer_weights.reserve(num_layer_);
for (int i = 0; i < num_layer_; i++) {
vit_layer_weights.push_back(ViTLayerINT8Weight<T>(embed_dim_, inter_size_));
}
@@ -94,6 +95,7 @@ struct ViTINT8Weight {
cls_num_(other.cls_num_)
{
vit_layer_weights.clear();
vit_layer_weights.reserve(num_layer_);
for (int i = 0; i < num_layer_; i++) {
vit_layer_weights.push_back(other.vit_layer_weights[i]);
}
@@ -132,6 +134,7 @@ struct ViTINT8Weight {
chn_num_ = other.chn_num_;
cls_num_ = other.cls_num_;
vit_layer_weights.clear();
vit_layer_weights.reserve(num_layer_);
for (int i = 0; i < num_layer_; i++) {
vit_layer_weights.push_back(other.vit_layer_weights[i]);
}
