Skip to content

Commit

Permalink
add rust error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
magdyksaleh committed Apr 19, 2024
1 parent 39b2066 commit e6be53d
Show file tree
Hide file tree
Showing 5 changed files with 199 additions and 129 deletions.
10 changes: 10 additions & 0 deletions launcher/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,12 @@ struct Args {
#[clap(long, env)]
sharded: Option<bool>,

/// Whether this model is mean for embeddings or text generation.
/// By default models are for text generation.
/// Setting it to `true` will enable the embedding endpoints and disable the generation ones.
#[clap(long, env)]
embedding_model: Option<bool>,

/// The number of shards to use if you don't want to use all GPUs on a given machine.
/// You can use `CUDA_VISIBLE_DEVICES=0,1 lorax-launcher... --num_shard 2`
/// and `CUDA_VISIBLE_DEVICES=2,3 lorax-launcher... --num_shard 2` to
Expand Down Expand Up @@ -1097,6 +1103,10 @@ fn spawn_webserver(
router_args.push(origin.to_string());
}

if args.embedding_model.unwrap_or(false) {
router_args.push("--embedding-model".to_string());
}

// Ngrok
if args.ngrok {
router_args.push("--ngrok".to_string());
Expand Down
2 changes: 2 additions & 0 deletions router/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ pub struct Info {
pub docker_label: Option<&'static str>,
#[schema(nullable = true, example = "http://localhost:8899")]
pub request_logger_url: Option<String>,
#[schema(example = false)]
pub embedding_model: bool,
}

#[derive(Clone, Debug, Deserialize, ToSchema, Default)]
Expand Down
4 changes: 4 additions & 0 deletions router/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ struct Args {
#[clap(long, env)]
json_output: bool,
#[clap(long, env)]
embedding_model: bool,
#[clap(long, env)]
otlp_endpoint: Option<String>,
#[clap(long, env)]
cors_allow_origin: Option<Vec<String>>,
Expand Down Expand Up @@ -109,6 +111,7 @@ async fn main() -> Result<(), RouterError> {
revision,
validation_workers,
json_output,
embedding_model,
otlp_endpoint,
cors_allow_origin,
cors_allow_method,
Expand Down Expand Up @@ -372,6 +375,7 @@ async fn main() -> Result<(), RouterError> {
ngrok_authtoken,
ngrok_edge,
adapter_source,
embedding_model,
)
.await?;
Ok(())
Expand Down
Loading

0 comments on commit e6be53d

Please sign in to comment.