diff --git a/src/turbomind/models/llama/LlamaV2.cc b/src/turbomind/models/llama/LlamaV2.cc index 8a1de364b..8768e7fd0 100644 --- a/src/turbomind/models/llama/LlamaV2.cc +++ b/src/turbomind/models/llama/LlamaV2.cc @@ -256,7 +256,7 @@ void LlamaV2::contextDecode(T* deocder_output, }; std::unordered_map decoder_output_tensors{ - {"decoder_output", {MEMORY_GPU, dtype, {bsz, max_input_len, hidden_units_}, context_decoder_output_buf}}, + {"decoder_output", {MEMORY_GPU, dtype, {token_num, hidden_units_}, context_decoder_output_buf}}, {"key_cache", {MEMORY_GPU, TYPE_UINT64, {bsz}, k_cache_ptr}}, {"value_cache", {MEMORY_GPU, TYPE_UINT64, {bsz}, v_cache_ptr}}, {"last_token_hidden_units", {MEMORY_GPU, dtype, {bsz, hidden_units_}, deocder_output}}};