diff --git a/crates/llm-chain-gemma-sys/src/bindings.cc b/crates/llm-chain-gemma-sys/src/bindings.cc
index 3fecf508..329353b4 100644
--- a/crates/llm-chain-gemma-sys/src/bindings.cc
+++ b/crates/llm-chain-gemma-sys/src/bindings.cc
@@ -22,32 +22,32 @@ gcpp::ModelTraining gcpp_LoaderArgs_ModelTraining(const gcpp::LoaderArgs* args)
     return args->ModelTraining();
 }
 
-void gcpp_LoaderArgs_SetTokenizer(gcpp::LoaderArgs* args, char* path) {
-    args->tokenizer.path = std::string(path);
+void gcpp_LoaderArgs_SetTokenizer(gcpp::LoaderArgs* args, char* path, size_t n) {
+    args->tokenizer.path = std::string(path, n);
 }
 
 const char* gcpp_LoaderArgs_Tokenizer(gcpp::LoaderArgs* args) {
     return args->tokenizer.path.c_str();
 }
 
-void gcpp_LoaderArgs_SetModel(gcpp::LoaderArgs* args, char* path) {
-    args->model.path = std::string(path);
+void gcpp_LoaderArgs_SetModel(gcpp::LoaderArgs* args, char* path, size_t n) {
+    args->model.path = std::string(path, n);
 }
 
 const char* gcpp_LoaderArgs_Model(gcpp::LoaderArgs* args) {
     return args->model.path.c_str();
 }
 
-void gcpp_LoaderArgs_SetCache(gcpp::LoaderArgs* args, char* path) {
-    args->cache.path = std::string(path);
+void gcpp_LoaderArgs_SetCache(gcpp::LoaderArgs* args, char* path, size_t n) {
+    args->cache.path = std::string(path, n);
 }
 
 const char* gcpp_LoaderArgs_Cache(gcpp::LoaderArgs* args) {
     return args->cache.path.c_str();
 }
 
-void gcpp_LoaderArgs_SetModelTypeValue(gcpp::LoaderArgs* args, char* v) {
-    args->model_type = std::string(v);
+void gcpp_LoaderArgs_SetModelTypeValue(gcpp::LoaderArgs* args, char* v, size_t n) {
+    args->model_type = std::string(v, n);
 }
 
 const char* gcpp_LoaderArgs_ModelTypeValue(gcpp::LoaderArgs* args) {
diff --git a/crates/llm-chain-gemma-sys/src/bindings.rs b/crates/llm-chain-gemma-sys/src/bindings.rs
index ebb4428f..e0922ecc 100644
--- a/crates/llm-chain-gemma-sys/src/bindings.rs
+++ b/crates/llm-chain-gemma-sys/src/bindings.rs
@@ -23,13 +23,13 @@ extern "C" {
     pub fn gcpp_LoaderArgs_Validate(largs: *mut gcpp_LoaderArgs) -> *const ffi::c_char;
     pub fn gcpp_LoaderArgs_ModelType(largs: *const gcpp_LoaderArgs) -> gcpp_Model;
     pub fn gcpp_LoaderArgs_ModelTraining(largs: *const gcpp_LoaderArgs) -> gcpp_ModelTraining;
-    pub fn gcpp_LoaderArgs_SetTokenizer(largs: *mut gcpp_LoaderArgs, path: *const ffi::c_char);
+    pub fn gcpp_LoaderArgs_SetTokenizer(largs: *mut gcpp_LoaderArgs, path: *const ffi::c_char, n: ffi::c_uint);
     pub fn gcpp_LoaderArgs_Tokenizer(largs: *const gcpp_LoaderArgs) -> *mut ffi::c_char;
-    pub fn gcpp_LoaderArgs_SetModel(largs: *mut gcpp_LoaderArgs, path: *const ffi::c_char);
+    pub fn gcpp_LoaderArgs_SetModel(largs: *mut gcpp_LoaderArgs, path: *const ffi::c_char, n: ffi::c_uint);
     pub fn gcpp_LoaderArgs_Model(largs: *const gcpp_LoaderArgs) -> *mut ffi::c_char;
-    pub fn gcpp_LoaderArgs_SetCache(largs: *mut gcpp_LoaderArgs, path: *const ffi::c_char);
+    pub fn gcpp_LoaderArgs_SetCache(largs: *mut gcpp_LoaderArgs, path: *const ffi::c_char, n: ffi::c_uint);
     pub fn gcpp_LoaderArgs_Cache(largs: *const gcpp_LoaderArgs) -> *mut ffi::c_char;
-    pub fn gcpp_LoaderArgs_SetModelTypeValue(largs: *mut gcpp_LoaderArgs, s: *const ffi::c_char);
+    pub fn gcpp_LoaderArgs_SetModelTypeValue(largs: *mut gcpp_LoaderArgs, s: *const ffi::c_char, n: ffi::c_uint);
     pub fn gcpp_LoaderArgs_ModelTypeValue(largs: *const gcpp_LoaderArgs) -> *mut ffi::c_char;
 }
 
@@ -207,9 +207,9 @@ mod test {
         let model = "2b-pt";
         unsafe {
             let largs = gcpp_LoaderArgs_LoaderArgs(0, std::ptr::null_mut());
-            gcpp_LoaderArgs_SetTokenizer(largs, ffi::CString::new(tokenizer_path).unwrap().as_ptr());
-            gcpp_LoaderArgs_SetCache(largs, ffi::CString::new(compressed_weights).unwrap().as_ptr());
-            gcpp_LoaderArgs_SetModelTypeValue(largs, ffi::CString::new(model).unwrap().as_ptr());
+            gcpp_LoaderArgs_SetTokenizer(largs, tokenizer_path.as_ptr() as *const i8, tokenizer_path.len() as ffi::c_uint);
+            gcpp_LoaderArgs_SetCache(largs, compressed_weights.as_ptr() as *const i8, compressed_weights.len() as ffi::c_uint);
+            gcpp_LoaderArgs_SetModelTypeValue(largs, model.as_ptr() as *const i8, model.len() as ffi::c_uint);
             let err = gcpp_LoaderArgs_Validate(largs);
             if err != std::ptr::null_mut() {
                 println!("{}", ffi::CStr::from_ptr(err).to_str().unwrap());
diff --git a/crates/llm-chain-gemma/src/context.rs b/crates/llm-chain-gemma/src/context.rs
index 0944c397..b9105576 100644
--- a/crates/llm-chain-gemma/src/context.rs
+++ b/crates/llm-chain-gemma/src/context.rs
@@ -38,13 +38,14 @@ impl GemmaContext {
                 gcpp_LoaderArgs_SetModelTypeValue(
                     largs,
                     mt.clone().into_bytes().as_ptr() as *const i8,
+                    mt.len() as ffi::c_uint,
                 );
             }
             if let Some(Opt::Model(m)) = options.get(OptDiscriminants::Model) {
                 // Typically the downloaded model data is compressed and set as cache.
                 // TODO: consider the case of non-compressed one?
                 let path = m.to_path();
-                gcpp_LoaderArgs_SetCache(largs, path.as_ptr() as *const i8);
+                gcpp_LoaderArgs_SetCache(largs, path.as_ptr() as *const i8, path.len() as ffi::c_uint);
                 // TODO: consider adding the option for tokenizer file.
                 let parent = Path::new(&path).parent();
                 if parent.is_none() {
@@ -53,7 +54,7 @@ impl GemmaContext {
                     )));
                 }
                 if let Some(tokenizer_path) = parent.unwrap().join("tokenizer.spm").to_str() {
-                    gcpp_LoaderArgs_SetTokenizer(largs, tokenizer_path.as_ptr() as *const i8);
+                    gcpp_LoaderArgs_SetTokenizer(largs, tokenizer_path.as_ptr() as *const i8, tokenizer_path.len() as ffi::c_uint);
                 } else {
                     return Err(ExecutorCreationError::InvalidValue(String::from(
                         "conversion from path to str for tokenizer",
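
As a minimal illustration of the calling pattern these pointer-plus-length setters allow (the `set_tokenizer` helper and the opaque handle declaration below are hypothetical; the real declarations live in `bindings.rs`): because the C++ side now builds `std::string(path, n)` from a pointer and an explicit byte count, a Rust `&str` can be passed as-is, with no intermediate `CString` and no trailing NUL.

```rust
use std::ffi;

// Hypothetical opaque handle standing in for gcpp::LoaderArgs; the actual
// type is declared in bindings.rs.
#[allow(non_camel_case_types)]
#[repr(C)]
pub struct gcpp_LoaderArgs {
    _private: [u8; 0],
}

extern "C" {
    // Mirrors the updated signature above: pointer plus explicit byte length.
    fn gcpp_LoaderArgs_SetTokenizer(
        largs: *mut gcpp_LoaderArgs,
        path: *const ffi::c_char,
        n: ffi::c_uint,
    );
}

/// Sketch of a caller-side helper: passes the &str's bytes and length
/// directly, so the string does not need to be NUL-terminated. The caller
/// must still guarantee that `largs` is a valid LoaderArgs handle.
pub unsafe fn set_tokenizer(largs: *mut gcpp_LoaderArgs, path: &str) {
    gcpp_LoaderArgs_SetTokenizer(
        largs,
        path.as_ptr() as *const ffi::c_char,
        path.len() as ffi::c_uint,
    );
}
```

This mirrors what the updated test and `context.rs` code do above; compared with routing through `CString`, it avoids an extra allocation and copy and also works for strings that are not NUL-terminated.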