Commit 00f4163

Merge pull request #650 from vlovich/fix-chat-template
Cleanup chat template API
2 parents 5c8e81b + 72c1255 commit 00f4163

File tree (4 files changed: +128 -42 lines)

examples/simple/src/main.rs
llama-cpp-2/src/lib.rs
llama-cpp-2/src/log.rs
llama-cpp-2/src/model.rs
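For quick reference, the public chat-template API changes like this. The listing below is purely illustrative and reconstructed from the diff that follows; in the crate these are inherent methods on `LlamaModel`, written here as trait declarations only so the before/after signatures can sit side by side:

```rust
use llama_cpp_2::model::{LlamaChatMessage, LlamaChatTemplate};
use llama_cpp_2::{ApplyChatTemplateError, ChatTemplateError};

// Before: the caller guessed a buffer size and handled BuffSizeError itself.
trait ChatTemplateApiBefore {
    fn get_chat_template(&self, buf_size: usize) -> Result<String, ChatTemplateError>;
    fn apply_chat_template(
        &self,
        tmpl: Option<String>,
        chat: Vec<LlamaChatMessage>,
        add_ass: bool,
    ) -> Result<String, ApplyChatTemplateError>;
}

// After: buffer sizing is retried internally and the template gets its own type.
trait ChatTemplateApiAfter {
    fn get_chat_template(&self) -> Result<LlamaChatTemplate, ChatTemplateError>;
    fn apply_chat_template(
        &self,
        tmpl: &LlamaChatTemplate,
        chat: &[LlamaChatMessage],
        add_ass: bool,
    ) -> Result<String, ApplyChatTemplateError>;
}
```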

examples/simple/src/main.rs

Lines changed: 2 additions & 6 deletions
@@ -10,14 +10,14 @@ use anyhow::{anyhow, bail, Context, Result};
 use clap::Parser;
 use hf_hub::api::sync::ApiBuilder;
 use llama_cpp_2::context::params::LlamaContextParams;
-use llama_cpp_2::{ggml_time_us, send_logs_to_tracing, LogOptions};
 use llama_cpp_2::llama_backend::LlamaBackend;
 use llama_cpp_2::llama_batch::LlamaBatch;
 use llama_cpp_2::model::params::kv_overrides::ParamOverrideValue;
 use llama_cpp_2::model::params::LlamaModelParams;
 use llama_cpp_2::model::LlamaModel;
 use llama_cpp_2::model::{AddBos, Special};
 use llama_cpp_2::sampling::LlamaSampler;
+use llama_cpp_2::{ggml_time_us, send_logs_to_tracing, LogOptions};
 
 use std::ffi::CString;
 use std::io::Write;
@@ -67,11 +67,7 @@ struct Args {
         help = "size of the prompt context (default: loaded from themodel)"
     )]
     ctx_size: Option<NonZeroU32>,
-    #[arg(
-        short = 'v',
-        long,
-        help = "enable verbose llama.cpp logs",
-    )]
+    #[arg(short = 'v', long, help = "enable verbose llama.cpp logs")]
     verbose: bool,
 }

llama-cpp-2/src/lib.rs

Lines changed: 6 additions & 3 deletions
@@ -69,9 +69,6 @@ pub enum LLamaCppError {
 /// There was an error while getting the chat template from a model.
 #[derive(Debug, Eq, PartialEq, thiserror::Error)]
 pub enum ChatTemplateError {
-    /// the buffer was too small.
-    #[error("The buffer was too small. However, a buffer size of {0} would be just large enough.")]
-    BuffSizeError(usize),
     /// gguf has no chat template
     #[error("the model has no meta val - returned code {0}")]
     MissingTemplate(i32),
@@ -80,6 +77,12 @@ pub enum ChatTemplateError {
     Utf8Error(#[from] std::str::Utf8Error),
 }
 
+enum InternalChatTemplateError {
+    Permanent(ChatTemplateError),
+    /// the buffer was too small.
+    RetryWithLargerBuffer(usize),
+}
+
 /// Failed to Load context
 #[derive(Debug, Eq, PartialEq, thiserror::Error)]
 pub enum LlamaContextLoadError {
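The removed `BuffSizeError` variant means buffer sizing is no longer the caller's problem; the new crate-private `InternalChatTemplateError` only drives the retry inside `get_chat_template`. A minimal sketch of what downstream handling of the public error might look like after this change (the helper function is hypothetical, not part of the crate):

```rust
use llama_cpp_2::ChatTemplateError;

// Hypothetical downstream helper: only the two remaining public variants
// need handling now that the buffer-size retry happens inside the crate.
fn describe_template_error(err: &ChatTemplateError) -> String {
    match err {
        // The model's GGUF metadata has no `tokenizer.chat_template` entry.
        ChatTemplateError::MissingTemplate(code) => {
            format!("model has no chat template (returned code {code})")
        }
        // The stored template bytes were not valid UTF-8.
        ChatTemplateError::Utf8Error(e) => {
            format!("chat template is not valid UTF-8: {e}")
        }
    }
}
```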

llama-cpp-2/src/log.rs

Lines changed: 2 additions & 1 deletion
@@ -171,7 +171,8 @@ impl State {
         } else {
             let level = self
                 .previous_level
-                .load(std::sync::atomic::Ordering::Acquire) as llama_cpp_sys_2::ggml_log_level;
+                .load(std::sync::atomic::Ordering::Acquire)
+                as llama_cpp_sys_2::ggml_log_level;
             tracing::warn!(
                 inferred_level = level,
                 text = text,

llama-cpp-2/src/model.rs

Lines changed: 118 additions & 32 deletions
@@ -1,9 +1,10 @@
 //! A safe wrapper around `llama_model`.
-use std::ffi::{c_char, CString};
+use std::ffi::{c_char, CStr, CString};
 use std::num::NonZeroU16;
 use std::os::raw::c_int;
 use std::path::Path;
 use std::ptr::NonNull;
+use std::str::{FromStr, Utf8Error};
 
 use crate::context::params::LlamaContextParams;
 use crate::context::LlamaContext;
@@ -12,8 +13,9 @@ use crate::model::params::LlamaModelParams;
 use crate::token::LlamaToken;
 use crate::token_type::{LlamaTokenAttr, LlamaTokenAttrs};
 use crate::{
-    ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaLoraAdapterInitError,
-    LlamaModelLoadError, NewLlamaChatMessageError, StringToTokenError, TokenToStringError,
+    ApplyChatTemplateError, ChatTemplateError, InternalChatTemplateError, LlamaContextLoadError,
+    LlamaLoraAdapterInitError, LlamaModelLoadError, NewLlamaChatMessageError, StringToTokenError,
+    TokenToStringError,
 };
 
 pub mod params;
@@ -34,6 +36,42 @@ pub struct LlamaLoraAdapter {
     pub(crate) lora_adapter: NonNull<llama_cpp_sys_2::llama_adapter_lora>,
 }
 
+/// A performance-friendly wrapper around [LlamaModel::get_chat_template] which is then
+/// fed into [LlamaModel::apply_chat_template] to convert a list of messages into an LLM
+/// prompt. Internally the template is stored as a CString to avoid round-trip conversions
+/// within the FFI.
+#[derive(Eq, PartialEq, Clone, PartialOrd, Ord, Hash)]
+pub struct LlamaChatTemplate(CString);
+
+impl LlamaChatTemplate {
+    /// Create a new template from a string. This can either be the name of a llama.cpp [chat template](https://github.com/ggerganov/llama.cpp/blob/8a8c4ceb6050bd9392609114ca56ae6d26f5b8f5/src/llama-chat.cpp#L27-L61)
+    /// like "chatml" or "llama3" or an actual Jinja template for llama.cpp to interpret.
+    pub fn new(template: &str) -> Result<Self, std::ffi::NulError> {
+        Ok(Self(CString::from_str(template)?))
+    }
+
+    /// Accesses the template as a c string reference.
+    pub fn as_c_str(&self) -> &CStr {
+        &self.0
+    }
+
+    /// Attempts to convert the CString into a Rust str reference.
+    pub fn to_str(&self) -> Result<&str, Utf8Error> {
+        self.0.to_str()
+    }
+
+    /// Convenience method to create an owned String.
+    pub fn to_string(&self) -> Result<String, Utf8Error> {
+        self.to_str().map(str::to_string)
+    }
+}
+
+impl std::fmt::Debug for LlamaChatTemplate {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        self.0.fmt(f)
+    }
+}
+
 /// A Safe wrapper around `llama_chat_message`
 #[derive(Debug, Eq, PartialEq, Clone)]
 pub struct LlamaChatMessage {
@@ -408,41 +446,84 @@ impl LlamaModel {
         unsafe { llama_cpp_sys_2::llama_n_embd(self.model.as_ptr()) }
     }
 
-    /// Get chat template from model.
-    ///
-    /// # Errors
-    ///
-    /// * If the model has no chat template
-    /// * If the chat template is not a valid [`CString`].
-    #[allow(clippy::missing_panics_doc)] // we statically know this will not panic as
-    pub fn get_chat_template(&self, buf_size: usize) -> Result<String, ChatTemplateError> {
+    fn get_chat_template_impl(
+        &self,
+        capacity: usize,
+    ) -> Result<LlamaChatTemplate, InternalChatTemplateError> {
         // longest known template is about 1200 bytes from llama.cpp
-        let chat_temp = CString::new(vec![b'*'; buf_size]).expect("no null");
-        let chat_ptr = chat_temp.into_raw();
-        let chat_name = CString::new("tokenizer.chat_template").expect("no null bytes");
+        // TODO: Once MaybeUninit support is better, this can be converted to use that instead of dummy initializing such a large array.
+        let mut chat_temp = vec![b'*' as u8; capacity];
+        let chat_name =
+            CStr::from_bytes_with_nul(b"tokenizer.chat_template\0").expect("should have null byte");
 
         let ret = unsafe {
             llama_cpp_sys_2::llama_model_meta_val_str(
                 self.model.as_ptr(),
                 chat_name.as_ptr(),
-                chat_ptr,
-                buf_size,
+                chat_temp.as_mut_ptr() as *mut c_char,
+                chat_temp.len(),
             )
         };
 
         if ret < 0 {
-            return Err(ChatTemplateError::MissingTemplate(ret));
+            return Err(InternalChatTemplateError::Permanent(
+                ChatTemplateError::MissingTemplate(ret),
+            ));
         }
 
-        let template_c = unsafe { CString::from_raw(chat_ptr) };
-        let template = template_c.to_str()?;
+        let returned_len = ret as usize;
 
-        let ret: usize = ret.try_into().unwrap();
-        if template.len() < ret {
-            return Err(ChatTemplateError::BuffSizeError(ret + 1));
+        if ret as usize >= capacity {
+            // >= is important because if the returned length is equal to capacity, it means we're missing a trailing null
+            // since the returned length doesn't count the trailing null.
+            return Err(InternalChatTemplateError::RetryWithLargerBuffer(
+                returned_len,
+            ));
         }
 
-        Ok(template.to_owned())
+        assert_eq!(
+            chat_temp.get(returned_len),
+            Some(&0),
+            "should end with null byte"
+        );
+
+        chat_temp.resize(returned_len + 1, 0);
+
+        Ok(LlamaChatTemplate(unsafe {
+            CString::from_vec_with_nul_unchecked(chat_temp)
+        }))
+    }
+
+    /// Get chat template from model. If this fails, you may either want to fail the chat or fall back
+    /// to one of the specific shortcodes that llama.cpp has baked directly into its codebase for models
+    /// that don't ship a template. NOTE: If you don't specify a chat template, llama.cpp uses chatml
+    /// by default, which is unlikely to be the correct template for your model, and you'll get weird
+    /// results back.
+    ///
+    /// You pass this to [Self::apply_chat_template] to get back a string with the appropriate template
+    /// substitution applied to convert a list of messages into a prompt the LLM can use to complete
+    /// the chat.
+    ///
+    /// # Errors
+    ///
+    /// * If the model has no chat template
+    /// * If the chat template is not a valid [`CString`].
+    #[allow(clippy::missing_panics_doc)] // we statically know this will not panic as
+    pub fn get_chat_template(&self) -> Result<LlamaChatTemplate, ChatTemplateError> {
+        // Typical chat templates are quite small. Let's start with a small allocation likely to succeed.
+        // Ideally the performance of this would be negligible but uninitialized arrays in Rust are currently
+        // still not well supported so we end up initializing the chat template buffer twice. One idea might
+        // be to use a very small value here that will likely fail (like 0 or 1) and then use that to initialize.
+        // Not sure which approach is the most optimal but in practice this should work well.
+        match self.get_chat_template_impl(200) {
+            Ok(t) => Ok(t),
+            Err(InternalChatTemplateError::Permanent(e)) => Err(e),
+            Err(InternalChatTemplateError::RetryWithLargerBuffer(actual_len)) => match self.get_chat_template_impl(actual_len + 1) {
+                Ok(t) => Ok(t),
+                Err(InternalChatTemplateError::Permanent(e)) => Err(e),
+                Err(InternalChatTemplateError::RetryWithLargerBuffer(unexpected_len)) => panic!("Was told that the template length was {actual_len} but now it's {unexpected_len}"),
+            }
+        }
     }
 
     /// Loads a model from a file.
@@ -526,15 +607,25 @@ impl LlamaModel {
     /// Apply the models chat template to some messages.
     /// See https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
     ///
-    /// `tmpl` of None means to use the default template provided by llama.cpp for the model
+    /// Unlike the llama.cpp apply_chat_template which just randomly uses the ChatML template when given
+    /// a null pointer for the template, this requires an explicit template to be specified. If you want to
+    /// use "chatml", then just do `LlamaChatTemplate::new("chatml")` or any other model name or template
+    /// string.
+    ///
+    /// Use [Self::get_chat_template] to retrieve the template baked into the model (this is the preferred
+    /// mechanism as using the wrong chat template can result in really unexpected responses from the LLM).
+    ///
+    /// You probably want to set `add_ass` to true so that the generated template string ends with the
+    /// opening tag of the assistant. If you fail to leave a hanging chat tag, the model will likely generate
+    /// one into the output, and the output may contain other unexpected content as well.
     ///
     /// # Errors
     /// There are many ways this can fail. See [`ApplyChatTemplateError`] for more information.
     #[tracing::instrument(skip_all)]
     pub fn apply_chat_template(
         &self,
-        tmpl: Option<String>,
-        chat: Vec<LlamaChatMessage>,
+        tmpl: &LlamaChatTemplate,
+        chat: &[LlamaChatMessage],
         add_ass: bool,
     ) -> Result<String, ApplyChatTemplateError> {
         // Buffer is twice the length of messages per their recommendation
@@ -552,12 +643,7 @@ impl LlamaModel {
             })
             .collect();
 
-        // Set the tmpl pointer
-        let tmpl = tmpl.map(CString::new);
-        let tmpl_ptr = match &tmpl {
-            Some(str) => str.as_ref().map_err(Clone::clone)?.as_ptr(),
-            None => std::ptr::null(),
-        };
+        let tmpl_ptr = tmpl.0.as_ptr();
 
         let res = unsafe {
             llama_cpp_sys_2::llama_chat_apply_template(
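Taken together, a call site against the reworked API might look roughly like the sketch below. This is not code from the PR: `LlamaChatMessage::new(role, content)` and the use of `anyhow` are assumed from the rest of the crate and its examples, and the fallback to the "chatml" shortcode mirrors the advice in the new doc comments:

```rust
use llama_cpp_2::model::{LlamaChatMessage, LlamaChatTemplate, LlamaModel};

// Sketch: turn a short conversation into a prompt using the reworked API.
// `model` is assumed to be an already-loaded `LlamaModel`.
fn build_prompt(model: &LlamaModel) -> anyhow::Result<String> {
    // Prefer the template stored in the model's GGUF metadata; fall back to the
    // "chatml" shortcode only if the model ships without one.
    let template = model
        .get_chat_template()
        .or_else(|_| LlamaChatTemplate::new("chatml"))?;

    let messages = vec![
        LlamaChatMessage::new("system".into(), "You are a helpful assistant.".into())?,
        LlamaChatMessage::new("user".into(), "Write a haiku about Rust.".into())?,
    ];

    // `add_ass = true` leaves the assistant's opening tag hanging so the model
    // completes the conversation instead of inventing its own chat tags.
    Ok(model.apply_chat_template(&template, &messages, true)?)
}
```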
