Skip to content

Commit

Permalink
minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
jmuk committed Mar 13, 2024
1 parent 575af56 commit 5e6deaa
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 17 deletions.
16 changes: 8 additions & 8 deletions crates/llm-chain-gemma-sys/src/bindings.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,32 +22,32 @@ gcpp::ModelTraining gcpp_LoaderArgs_ModelTraining(const gcpp::LoaderArgs* args)
return args->ModelTraining();
}

void gcpp_LoaderArgs_SetTokenizer(gcpp::LoaderArgs* args, char* path) {
args->tokenizer.path = std::string(path);
void gcpp_LoaderArgs_SetTokenizer(gcpp::LoaderArgs* args, char* path, size_t n) {
args->tokenizer.path = std::string(path, n);
}

const char* gcpp_LoaderArgs_Tokenizer(gcpp::LoaderArgs* args) {
return args->tokenizer.path.c_str();
}

void gcpp_LoaderArgs_SetModel(gcpp::LoaderArgs* args, char* path) {
args->model.path = std::string(path);
void gcpp_LoaderArgs_SetModel(gcpp::LoaderArgs* args, char* path, size_t n) {
args->model.path = std::string(path, n);
}

const char* gcpp_LoaderArgs_Model(gcpp::LoaderArgs* args) {
return args->model.path.c_str();
}

void gcpp_LoaderArgs_SetCache(gcpp::LoaderArgs* args, char* path) {
args->cache.path = std::string(path);
void gcpp_LoaderArgs_SetCache(gcpp::LoaderArgs* args, char* path, size_t n) {
args->cache.path = std::string(path, n);
}

const char* gcpp_LoaderArgs_Cache(gcpp::LoaderArgs* args) {
return args->cache.path.c_str();
}

void gcpp_LoaderArgs_SetModelTypeValue(gcpp::LoaderArgs* args, char* v) {
args->model_type = std::string(v);
void gcpp_LoaderArgs_SetModelTypeValue(gcpp::LoaderArgs* args, char* v, size_t n) {
args->model_type = std::string(v, n);
}

const char* gcpp_LoaderArgs_ModelTypeValue(gcpp::LoaderArgs* args) {
Expand Down
14 changes: 7 additions & 7 deletions crates/llm-chain-gemma-sys/src/bindings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ extern "C" {
pub fn gcpp_LoaderArgs_Validate(largs: *mut gcpp_LoaderArgs) -> *const ffi::c_char;
pub fn gcpp_LoaderArgs_ModelType(largs: *const gcpp_LoaderArgs) -> gcpp_Model;
pub fn gcpp_LoaderArgs_ModelTraining(largs: *const gcpp_LoaderArgs) -> gcpp_ModelTraining;
pub fn gcpp_LoaderArgs_SetTokenizer(largs: *mut gcpp_LoaderArgs, path: *const ffi::c_char);
pub fn gcpp_LoaderArgs_SetTokenizer(largs: *mut gcpp_LoaderArgs, path: *const ffi::c_char, n: ffi::c_uint);
pub fn gcpp_LoaderArgs_Tokenizer(largs: *const gcpp_LoaderArgs) -> *mut ffi::c_char;
pub fn gcpp_LoaderArgs_SetModel(largs: *mut gcpp_LoaderArgs, path: *const ffi::c_char);
pub fn gcpp_LoaderArgs_SetModel(largs: *mut gcpp_LoaderArgs, path: *const ffi::c_char, n: ffi::c_uint);
pub fn gcpp_LoaderArgs_Model(largs: *const gcpp_LoaderArgs) -> *mut ffi::c_char;
pub fn gcpp_LoaderArgs_SetCache(largs: *mut gcpp_LoaderArgs, path: *const ffi::c_char);
pub fn gcpp_LoaderArgs_SetCache(largs: *mut gcpp_LoaderArgs, path: *const ffi::c_char, n: ffi::c_uint);
pub fn gcpp_LoaderArgs_Cache(largs: *const gcpp_LoaderArgs) -> *mut ffi::c_char;
pub fn gcpp_LoaderArgs_SetModelTypeValue(largs: *mut gcpp_LoaderArgs, s: *const ffi::c_char);
pub fn gcpp_LoaderArgs_SetModelTypeValue(largs: *mut gcpp_LoaderArgs, s: *const ffi::c_char, n: ffi::c_uint);
pub fn gcpp_LoaderArgs_ModelTypeValue(largs: *const gcpp_LoaderArgs) -> *mut ffi::c_char;
}

Expand Down Expand Up @@ -207,9 +207,9 @@ mod test {
let model = "2b-pt";
unsafe {
let largs = gcpp_LoaderArgs_LoaderArgs(0, std::ptr::null_mut());
gcpp_LoaderArgs_SetTokenizer(largs, ffi::CString::new(tokenizer_path).unwrap().as_ptr());
gcpp_LoaderArgs_SetCache(largs, ffi::CString::new(compressed_weights).unwrap().as_ptr());
gcpp_LoaderArgs_SetModelTypeValue(largs, ffi::CString::new(model).unwrap().as_ptr());
gcpp_LoaderArgs_SetTokenizer(largs, tokenizer_path.as_ptr() as *const i8, tokenizer_path.len() as ffi::c_uint);
gcpp_LoaderArgs_SetCache(largs, compressed_weights.as_ptr() as *const i8, compressed_weights.len() as ffi::c_uint);
gcpp_LoaderArgs_SetModelTypeValue(largs, model.as_ptr() as *const i8, model.len() as ffi::c_uint);
let err = gcpp_LoaderArgs_Validate(largs);
if err != std::ptr::null_mut() {
println!("{}", ffi::CStr::from_ptr(err).to_str().unwrap());
Expand Down
5 changes: 3 additions & 2 deletions crates/llm-chain-gemma/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,14 @@ impl GemmaContext {
gcpp_LoaderArgs_SetModelTypeValue(
largs,
mt.clone().into_bytes().as_ptr() as *const i8,
mt.len() as ffi::c_uint,
);
}
if let Some(Opt::Model(m)) = options.get(OptDiscriminants::Model) {
// Typically the downloaded model data is compressed and set as cache.
// TODO: consider the case of non-compressed one?
let path = m.to_path();
gcpp_LoaderArgs_SetCache(largs, path.as_ptr() as *const i8);
gcpp_LoaderArgs_SetCache(largs, path.as_ptr() as *const i8, path.len() as ffi::c_uint);
// TODO: consider adding the option for tokenizer file.
let parent = Path::new(&path).parent();
if parent.is_none() {
Expand All @@ -53,7 +54,7 @@ impl GemmaContext {
)));
}
if let Some(tokenizer_path) = parent.unwrap().join("tokenizer.spm").to_str() {
gcpp_LoaderArgs_SetTokenizer(largs, tokenizer_path.as_ptr() as *const i8);
gcpp_LoaderArgs_SetTokenizer(largs, tokenizer_path.as_ptr() as *const i8, tokenizer_path.len() as ffi::c_uint);
} else {
return Err(ExecutorCreationError::InvalidValue(String::from(
"conversion from path to str for tokenizer",
Expand Down

0 comments on commit 5e6deaa

Please sign in to comment.