✨ adds model endpoint with model as path parameter
chriamue committed Jan 10, 2024
1 parent 669b80f commit 55e5d2c
Showing 6 changed files with 131 additions and 4 deletions.
6 changes: 4 additions & 2 deletions src/api/openapi.rs
@@ -2,7 +2,7 @@ use super::model::{
CompatGenerateRequest, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
Info, StreamDetails, StreamResponse, Token,
};
use crate::api::model::ErrorResponse;
use crate::{api::model::ErrorResponse, llm::models::Models};
use utoipa::OpenApi;

/// Represents the API documentation for the text generation inference service.
@@ -17,6 +17,7 @@ use utoipa::OpenApi;
super::routes::generate::generate_handler,
super::routes::generate_text::generate_text_handler,
super::routes::generate_stream::generate_stream_handler,
super::routes::model::generate_model_handler,
super::routes::health::get_health_handler,
super::routes::info::get_info_handler
),
@@ -32,7 +33,8 @@ use utoipa::OpenApi;
StreamDetails,
Token,
FinishReason,
Info
Info,
Models
)
),
// Metadata and description of the API tags.
2 changes: 2 additions & 0 deletions src/api/routes/mod.rs
@@ -14,10 +14,12 @@ pub mod generate_stream; // Module for handling streaming text generation requests.
pub mod generate_text; // Module for handling text generation requests.
pub mod health; // Module for the health check endpoint.
pub mod info; // Module for the service information endpoint.
pub mod model; // Module for handling generation requests with the model selected via a path parameter.

// Public exports of route handlers for ease of access.
pub use generate::generate_handler;
pub use generate_stream::generate_stream_handler;
pub use generate_text::generate_text_handler;
pub use health::get_health_handler;
pub use info::get_info_handler;
pub use model::generate_model_handler;
84 changes: 84 additions & 0 deletions src/api/routes/model.rs
@@ -0,0 +1,84 @@
use axum::{
extract::{Path, State},
http::StatusCode,
response::{IntoResponse, Response},
Json,
};

use crate::{
api::model::{CompatGenerateRequest, ErrorResponse, GenerateRequest},
config::Config,
llm::models::Models,
};

use super::{generate_stream::generate_stream_handler, generate_text_handler};

/// Handler for generating text with the model selected via the request path.
///
/// This endpoint accepts a `CompatGenerateRequest` and returns a stream of generated text
/// or a single text response, depending on the `stream` field in the request. If `stream` is
/// true, it returns a stream of `StreamResponse`; otherwise it returns a `GenerateResponse`.
/// The `model` path parameter overrides the model from the application configuration.
///
/// # Arguments
/// * `model` - Path parameter selecting the model to use for generation.
/// * `config` - State containing the application configuration.
/// * `payload` - JSON payload containing the input text and optional parameters.
///
/// # Responses
/// * `200 OK` - Successful generation of text.
/// * `422 Unprocessable Entity` - Input validation error.
/// * `424 Failed Dependency` - Request failed during generation.
/// * `429 Too Many Requests` - Model is overloaded.
/// * `500 Internal Server Error` - Incomplete generation.
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/model/{model}",
params(
("model" = Models, Path, description = "Model to use for generation"),
),
request_body = CompatGenerateRequest,
responses(
(status = 200, description = "Generated Text",
content(
("application/json" = GenerateResponse),
("text/event-stream" = StreamResponse),
)
),
(status = 424, description = "Generation Error", body = ErrorResponse,
example = json!({"error": "Request failed during generation"})),
(status = 429, description = "Model is overloaded", body = ErrorResponse,
example = json!({"error": "Model is overloaded"})),
(status = 422, description = "Input validation error", body = ErrorResponse,
example = json!({"error": "Input validation error"})),
(status = 500, description = "Incomplete generation", body = ErrorResponse,
example = json!({"error": "Incomplete generation"})),
)
)]

pub async fn generate_model_handler(
Path(model): Path<Models>,
config: State<Config>,
Json(payload): Json<CompatGenerateRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let mut config = config.clone();
config.model = model;

if payload.stream {
Ok(generate_stream_handler(
config,
Json(GenerateRequest {
inputs: payload.inputs,
parameters: payload.parameters,
}),
)
.await
.into_response())
} else {
Ok(generate_text_handler(
config,
Json(GenerateRequest {
inputs: payload.inputs,
parameters: payload.parameters,
}),
)
.await
.into_response())
}
}
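
For reference, a minimal client-side sketch of calling the new endpoint. This is a sketch only: it assumes the server listens on http://localhost:8080, that the reqwest crate (with its json feature) and tokio are available, and it reuses the "phi-v2" model name and payload fields from the test further down.

use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    // "phi-v2" is resolved to a Models variant via its serde rename
    // (see src/llm/models/mod.rs below).
    let response = reqwest::Client::new()
        .post("http://localhost:8080/model/phi-v2") // base URL is an assumption
        .json(&json!({
            "inputs": "write hello world in rust",
            "parameters": { "max_new_tokens": 50 },
            "stream": false
        }))
        .send()
        .await?;

    println!("status: {}", response.status());
    println!("body: {}", response.text().await?);
    Ok(())
}
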
12 changes: 11 additions & 1 deletion src/llm/models/mod.rs
@@ -1,8 +1,9 @@
// source: https://github.com/huggingface/candle/blob/main/candle-examples/examples/quantized/main.rs
use serde::Deserialize;
use std::str::FromStr;
use utoipa::ToSchema;

#[derive(Default, Deserialize, Clone, Debug, Copy, PartialEq, Eq)]
#[derive(Default, Deserialize, Clone, Debug, Copy, PartialEq, Eq, ToSchema)]
pub enum Models {
#[serde(rename = "7b")]
L7b,
@@ -193,3 +194,12 @@ impl Models {
}
}
}

#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_from_str() {
let model = Models::from_str("7b-open-chat-3.5").unwrap();
assert_eq!(model, Models::OpenChat35);
}
}
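
The path segment is mapped to a Models variant through the #[serde(rename = ...)] attributes, which is what lets axum's Path<Models> extractor accept names such as "7b" or "phi-v2". A minimal sketch of that mapping, using a trimmed stand-in enum (only two variants are reproduced here as an assumption for brevity; the real Models enum defines many more):

use serde::Deserialize;

// Trimmed stand-in for the real Models enum.
#[derive(Deserialize, Debug, PartialEq)]
enum Model {
    #[serde(rename = "7b")]
    L7b,
    #[serde(rename = "phi-v2")]
    PhiV2,
}

fn main() {
    // The same string-to-variant mapping that the Path extractor relies on:
    let model: Model = serde_json::from_str("\"phi-v2\"").unwrap();
    assert_eq!(model, Model::PhiV2);

    let model: Model = serde_json::from_str("\"7b\"").unwrap();
    assert_eq!(model, Model::L7b);
}
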
6 changes: 5 additions & 1 deletion src/server.rs
@@ -9,7 +9,10 @@ use utoipa_swagger_ui::SwaggerUi;
use crate::{
api::{
openapi::ApiDoc,
routes::{generate_handler, generate_stream_handler, generate_text_handler},
routes::{
generate_handler, generate_model_handler, generate_stream_handler,
generate_text_handler,
},
routes::{get_health_handler, get_info_handler},
},
config::Config,
@@ -35,6 +38,7 @@ pub fn server(config: Config) -> Router {
.route("/health", get(get_health_handler))
.route("/info", get(get_info_handler))
.route("/generate_stream", post(generate_stream_handler))
.route("/model/:model", post(generate_model_handler))
.with_state(config);

let swagger_ui = SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi());
25 changes: 25 additions & 0 deletions tests/server_test.rs
@@ -52,6 +52,31 @@ async fn test_generate_text_handler() {
assert_eq!(response.status_code(), 200);
}

//#[ignore = "ignore until mocked"]
#[tokio::test]
async fn test_generate_text_model_handler() {
let config = Config::default();
let app = server(config);

let server = TestServer::new(app).unwrap();
let response = server
.post("/model/phi-v2")
.json(&serde_json::json!({
"inputs": "write hello world in rust",
"parameters": {
"temperature": 0.9,
"top_p": 0.9,
"repetition_penalty": 1.1,
"top_n_tokens": 64,
"max_new_tokens": 50,
"stop": ["</s>"]
}
}))
.await;

assert_eq!(response.status_code(), 200);
}

#[tokio::test]
async fn test_get_health_handler() {
let config = Config::default();
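
When stream is set to true, the handler responds with text/event-stream instead of a single JSON body. A rough client-side sketch of consuming that stream, assuming reqwest is built with its stream feature, futures-util is available, and the same local base URL as in the sketch above:

use futures_util::StreamExt;
use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let response = reqwest::Client::new()
        .post("http://localhost:8080/model/phi-v2") // base URL is an assumption
        .json(&json!({
            "inputs": "write hello world in rust",
            "parameters": { "max_new_tokens": 50 },
            "stream": true
        }))
        .send()
        .await?;

    // Each chunk carries server-sent-event data wrapping a StreamResponse.
    let mut body = response.bytes_stream();
    while let Some(chunk) = body.next().await {
        print!("{}", String::from_utf8_lossy(&chunk?));
    }
    Ok(())
}
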
