Commit
add missing use_dynamic_ntk
irexyc committed Dec 3, 2024
1 parent 0513e12 commit df0b1ff
Showing 2 changed files with 7 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/turbomind/models/llama/unified_attention_layer.cc
@@ -187,7 +187,7 @@ inline void UnifiedAttentionLayer<T>::forward(TensorMap* outputs, const TensorMa
     bool* is_finished = inputs->getPtr<bool>("finished");
     float* rope_theta = inputs->getPtr<float>("rope_theta");
 
-    float* cos_sin = inputs->at("cos_sin", Tensor{MEMORY_GPU, TYPE_INVALID, {}, nullptr}).getPtr<float>();
+    float* cos_sin = inputs->getPtr<float>("cos_sin");
 
     void** block_ptrs = outputs->getPtr<void*>("block_ptrs");
     int* cu_block_count = inputs->getPtr<int>("cu_block_counts");
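
The two `cos_sin` lines in this hunk differ only in how the tensor is looked up: `at("cos_sin", Tensor{...})` falls back to a caller-supplied empty tensor when the key is absent, while `getPtr<float>("cos_sin")` presumably treats the key as required and fails loudly if it is missing. Below is a minimal standalone sketch of those two lookup styles, using a plain std::map stand-in rather than TurboMind's TensorMap, so every name in it is illustrative only:

    #include <cstdio>
    #include <map>
    #include <stdexcept>
    #include <string>

    // Illustrative stand-in for a tensor handle; not TurboMind's Tensor type.
    struct FakeTensor {
        float* data = nullptr;
    };

    using FakeTensorMap = std::map<std::string, FakeTensor>;

    // Optional lookup: return the caller-supplied fallback when the key is absent
    // (the at("cos_sin", Tensor{...}) style seen in the diff).
    float* get_optional(const FakeTensorMap& m, const std::string& key, const FakeTensor& fallback)
    {
        auto it = m.find(key);
        return it != m.end() ? it->second.data : fallback.data;
    }

    // Required lookup: fail when the key is absent
    // (the getPtr<float>("cos_sin") style seen in the diff).
    float* get_required(const FakeTensorMap& m, const std::string& key)
    {
        auto it = m.find(key);
        if (it == m.end()) {
            throw std::runtime_error("missing tensor: " + key);
        }
        return it->second.data;
    }

    int main()
    {
        float         buf[4] = {};
        FakeTensorMap inputs{{"cos_sin", FakeTensor{buf}}};

        // Both succeed when "cos_sin" is present; only get_optional tolerates a missing key.
        std::printf("optional: %p\n", static_cast<void*>(get_optional(inputs, "cos_sin", FakeTensor{})));
        std::printf("required: %p\n", static_cast<void*>(get_required(inputs, "cos_sin")));
        std::printf("fallback: %p\n", static_cast<void*>(get_optional(inputs, "absent_key", FakeTensor{})));
        return 0;
    }
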
7 changes: 6 additions & 1 deletion src/turbomind/triton_backend/llama/LlamaTritonModel.cc
@@ -280,7 +280,12 @@ LlamaTritonModel<T>::LlamaTritonModel(size_t tensor_para_size,
     attn_param_.softmax_scale = attention_reader["softmax_scale"].as<float>(0);
     attn_param_.use_logn_attn = attention_reader["use_logn_attn"].as<int>(0);
     // rotary embedding parameters
-    attn_param_.rope.type = GetRoPEType(attention_reader["rope_scaling_type"].as<std::string>(""));
+    if (attention_reader["use_dynamic_ntk"].as<int>(0) == 1) {
+        attn_param_.rope.type = RopeType::kDynamic;
+    }
+    else {
+        attn_param_.rope.type = GetRoPEType(attention_reader["rope_scaling_type"].as<std::string>(""));
+    }
     attn_param_.rope.dim = attention_reader["rotary_embedding"].as<int>();
     attn_param_.rope.base = attention_reader["rope_theta"].as<float>(10000.0f);
     attn_param_.rope.max_position_embeddings = attention_reader["max_position_embeddings"].as<int>(0);
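
This hunk reads the `use_dynamic_ntk` flag from the model config and lets it take precedence over `rope_scaling_type` when choosing the RoPE variant. A minimal sketch of that selection logic follows, with a hypothetical enum and string mapping standing in for TurboMind's `RopeType` and `GetRoPEType` (whose definitions are not part of this diff):

    #include <iostream>
    #include <string>

    // Hypothetical stand-in for TurboMind's RopeType enum; only kDynamic appears in the
    // diff above, the other members here are illustrative guesses.
    enum class FakeRopeType { kDefault, kLinear, kDynamic, kYarn, kLlama3 };

    // Hypothetical stand-in for GetRoPEType(): the real string-to-type mapping is not
    // shown in this diff, so the cases below are assumptions.
    FakeRopeType rope_type_from_string(const std::string& s)
    {
        if (s == "linear")  return FakeRopeType::kLinear;
        if (s == "dynamic") return FakeRopeType::kDynamic;
        if (s == "yarn")    return FakeRopeType::kYarn;
        if (s == "llama3")  return FakeRopeType::kLlama3;
        return FakeRopeType::kDefault;
    }

    // Mirrors the control flow added by the commit: an explicit use_dynamic_ntk == 1
    // selects the dynamic-NTK variant, otherwise rope_scaling_type decides.
    FakeRopeType select_rope_type(int use_dynamic_ntk, const std::string& rope_scaling_type)
    {
        if (use_dynamic_ntk == 1) {
            return FakeRopeType::kDynamic;
        }
        return rope_type_from_string(rope_scaling_type);
    }

    int main()
    {
        // With the flag set, an empty rope_scaling_type still selects the dynamic-NTK path.
        std::cout << (select_rope_type(1, "") == FakeRopeType::kDynamic) << "\n";      // 1
        // Without the flag, the rope_scaling_type string decides.
        std::cout << (select_rope_type(0, "linear") == FakeRopeType::kLinear) << "\n"; // 1
        return 0;
    }
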
