Skip to content

Commit

Permalink
feat: Add option to input model name that is passed to OpenAI API end…
Browse files Browse the repository at this point in the history
…point rather than using the tokenizer name.
  • Loading branch information
Hugoch committed Oct 17, 2024
1 parent 038c30c commit 6be66d4
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ pub struct RunConfiguration {
pub dataset_file: String,
pub hf_token: Option<String>,
pub extra_metadata: Option<HashMap<String, String>>,
pub model_name: String
}

pub async fn run(run_config: RunConfiguration, stop_sender: Sender<()>) -> anyhow::Result<()> {
Expand All @@ -67,7 +68,7 @@ pub async fn run(run_config: RunConfiguration, stop_sender: Sender<()>) -> anyho
let backend = OpenAITextGenerationBackend::try_new(
"".to_string(),
run_config.url.clone(),
run_config.tokenizer_name.clone(),
run_config.model_name.clone(),
tokenizer,
run_config.duration,
)?;
Expand Down
7 changes: 7 additions & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ struct Args {
/// The name of the tokenizer to use
#[clap(short, long, env)]
tokenizer_name: String,

/// The name of the model to use. If not provided, the same name as the tokenizer will be used.
#[clap(long, env)]
model_name: Option<String>,

/// The maximum number of virtual users to use
#[clap(default_value = "128", short, long, env)]
max_vus: u64,
Expand Down Expand Up @@ -166,6 +171,7 @@ async fn main() {
Some(token) => Some(token),
None => cache.token(),
};
let model_name = args.model_name.clone().unwrap_or(args.tokenizer_name.clone());
let run_config = RunConfiguration {
url: args.url.clone(),
tokenizer_name: args.tokenizer_name.clone(),
Expand All @@ -182,6 +188,7 @@ async fn main() {
dataset_file: args.dataset_file.clone(),
hf_token,
extra_metadata: args.extra_meta.clone(),
model_name,
};
let main_thread = tokio::spawn(async move {
match run(run_config, stop_sender_clone).await {
Expand Down

0 comments on commit 6be66d4

Please sign in to comment.