Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add expose headers #230

Merged
merged 4 commits into from
Feb 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions launcher/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,10 +325,13 @@ struct Args {
cors_allow_origin: Vec<String>,

#[clap(long, env)]
cors_allow_headers: Vec<String>,
cors_allow_header: Vec<String>,

#[clap(long, env)]
cors_allow_methods: Vec<String>,
cors_expose_header: Vec<String>,

#[clap(long, env)]
cors_allow_method: Vec<String>,

#[clap(long, env)]
cors_allow_credentials: Option<bool>,
Expand Down Expand Up @@ -1030,14 +1033,20 @@ fn spawn_webserver(
}

// CORS methods
for origin in args.cors_allow_methods.into_iter() {
router_args.push("--cors-allow-methods".to_string());
for origin in args.cors_allow_method.into_iter() {
router_args.push("--cors-allow-method".to_string());
router_args.push(origin);
}

// CORS Allow headers
for origin in args.cors_allow_header.into_iter() {
router_args.push("--cors-allow-header".to_string());
router_args.push(origin);
}

// CORS headers
for origin in args.cors_allow_headers.into_iter() {
router_args.push("--cors-allow-headers".to_string());
// CORS expose headers
for origin in args.cors_expose_header.into_iter() {
router_args.push("--cors-expose-header".to_string());
router_args.push(origin);
}

Expand Down
43 changes: 28 additions & 15 deletions router/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use std::path::Path;
use std::time::Duration;
use thiserror::Error;
use tokenizers::{FromPretrainedParameters, Tokenizer};
use tower_http::cors::{AllowCredentials, AllowHeaders, AllowMethods, AllowOrigin};
use tower_http::cors::{AllowCredentials, AllowHeaders, AllowMethods, AllowOrigin, ExposeHeaders};
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{EnvFilter, Layer};
Expand Down Expand Up @@ -64,11 +64,13 @@ struct Args {
#[clap(long, env)]
cors_allow_origin: Option<Vec<String>>,
#[clap(long, env)]
cors_allow_methods: Option<Vec<String>>,
cors_allow_method: Option<Vec<String>>,
#[clap(long, env)]
cors_allow_credentials: Option<bool>,
cors_allow_header: Option<Vec<String>>,
#[clap(long, env)]
cors_expose_header: Option<Vec<String>>,
#[clap(long, env)]
cors_allow_headers: Option<Vec<String>>,
cors_allow_credentials: Option<bool>,
#[clap(long, env)]
ngrok: bool,
#[clap(long, env)]
Expand Down Expand Up @@ -104,9 +106,10 @@ fn main() -> Result<(), RouterError> {
json_output,
otlp_endpoint,
cors_allow_origin,
cors_allow_methods,
cors_allow_method,
cors_allow_header,
cors_expose_header,
cors_allow_credentials,
cors_allow_headers,
ngrok,
ngrok_authtoken,
ngrok_edge,
Expand Down Expand Up @@ -152,30 +155,39 @@ fn main() -> Result<(), RouterError> {
// CORS allowed methods
// map to go inside the option and then map to parse from String to HeaderValue
// Finally, convert to AllowMethods.
let cors_allow_methods: Option<AllowMethods> = cors_allow_methods.map(|cors_allow_methods| {
let cors_allow_method: Option<AllowMethods> = cors_allow_method.map(|cors_allow_methods| {
AllowMethods::list(
cors_allow_methods
.iter()
.map(|method| method.parse::<axum::http::Method>().unwrap()),
)
});

// CORS allow credentials
// Parse bool into AllowCredentials
let cors_allow_credentials: Option<AllowCredentials> = cors_allow_credentials
.map(|cors_allow_credentials| AllowCredentials::from(cors_allow_credentials));

// CORS allowed headers
// map to go inside the option and then map to parse from String to HeaderValue
// Finally, convert to AllowHeaders.
let cors_allow_headers: Option<AllowHeaders> = cors_allow_headers.map(|cors_allow_headers| {
let cors_allow_header: Option<AllowHeaders> = cors_allow_header.map(|cors_allow_headers| {
AllowHeaders::list(
cors_allow_headers
.iter()
.map(|header| header.parse::<HeaderName>().unwrap()),
)
});

// CORS expose headers
let cors_expose_header: Option<ExposeHeaders> = cors_expose_header.map(|cors_expose_headers| {
ExposeHeaders::list(
cors_expose_headers
.iter()
.map(|header| header.parse::<HeaderName>().unwrap()),
)
});

// CORS allow credentials
// Parse bool into AllowCredentials
let cors_allow_credentials: Option<AllowCredentials> = cors_allow_credentials
.map(|cors_allow_credentials| AllowCredentials::from(cors_allow_credentials));

// Parse Huggingface hub token
let authorization_token = std::env::var("HUGGING_FACE_HUB_TOKEN").ok();

Expand Down Expand Up @@ -320,9 +332,10 @@ fn main() -> Result<(), RouterError> {
validation_workers,
addr,
cors_allow_origin,
cors_allow_methods,
cors_allow_method,
cors_allow_credentials,
cors_allow_headers,
cors_allow_header,
cors_expose_header,
ngrok,
ngrok_authtoken,
ngrok_edge,
Expand Down
14 changes: 11 additions & 3 deletions router/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ use std::sync::Arc;
use tokenizers::Tokenizer;
use tokio::signal;
use tokio::time::Instant;
use tower_http::cors::{AllowCredentials, AllowHeaders, AllowMethods, AllowOrigin, CorsLayer};
use tower_http::cors::{
AllowCredentials, AllowHeaders, AllowMethods, AllowOrigin, CorsLayer, ExposeHeaders,
};
use tracing::{info_span, instrument, Instrument};
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;
Expand Down Expand Up @@ -706,10 +708,11 @@ pub async fn run(
tokenizer: Option<Tokenizer>,
validation_workers: usize,
addr: SocketAddr,
cors_allow_origin: Option<AllowOrigin>,
cors_allow_origin: Option<AllowOrigin>, // exact match
cors_allow_methods: Option<AllowMethods>,
cors_allow_credentials: Option<AllowCredentials>,
cors_allow_headers: Option<AllowHeaders>,
cors_expose_headers: Option<ExposeHeaders>,
ngrok: bool,
ngrok_authtoken: Option<String>,
ngrok_edge: Option<String>,
Expand Down Expand Up @@ -839,19 +842,24 @@ pub async fn run(
let cors_allow_headers =
cors_allow_headers.unwrap_or(AllowHeaders::list(vec![http::header::CONTENT_TYPE]));

let cors_expose_headers = cors_expose_headers.unwrap_or(ExposeHeaders::default());
let cors_allow_credentials = cors_allow_credentials.unwrap_or(AllowCredentials::default());

// log cors stuff
tracing::info!(
"CORS: origin: {cors_allow_origin:?}, methods: {cors_allow_methods:?}, headers: {cors_allow_headers:?}, credentials: {cors_allow_credentials:?}",
"CORS: origin: {cors_allow_origin:?}, methods: {cors_allow_methods:?}, headers: {cors_allow_headers:?}, expose-headers: {cors_expose_headers:?} credentials: {cors_allow_credentials:?}",
);

let cors_layer = CorsLayer::new()
.allow_methods(cors_allow_methods)
.allow_headers(cors_allow_headers)
.allow_credentials(cors_allow_credentials)
.expose_headers(cors_expose_headers)
.allow_origin(cors_allow_origin);

// log all the cors layer
tracing::info!("CORS: {cors_layer:?}");

// Endpoint info
let info = Info {
model_id: model_info.model_id,
Expand Down
Loading