diff --git a/src/api/openapi.rs b/src/api/openapi.rs index ecbbf30..cf36d97 100644 --- a/src/api/openapi.rs +++ b/src/api/openapi.rs @@ -14,6 +14,7 @@ use utoipa::OpenApi; #[openapi( // List of API endpoints to be included in the documentation. paths( + super::routes::generate::generate_handler, super::routes::generate_text::generate_text_handler, super::routes::generate_stream::generate_stream_handler, super::routes::health::get_health_handler, @@ -50,6 +51,7 @@ mod tests { fn api_doc_contains_all_endpoints() { let api_doc = ApiDoc::openapi(); let paths = api_doc.paths.paths; + assert!(paths.contains_key("/")); assert!(paths.contains_key("/generate")); assert!(paths.contains_key("/generate_stream")); assert!(paths.contains_key("/health")); diff --git a/src/api/routes/generate.rs b/src/api/routes/generate.rs index dfe2d0e..4eb7cc8 100644 --- a/src/api/routes/generate.rs +++ b/src/api/routes/generate.rs @@ -1,56 +1,77 @@ -use axum::{extract::State, http::StatusCode, response::IntoResponse, Json}; +use axum::{ + extract::State, + http::StatusCode, + response::{IntoResponse, Response}, + Json, +}; use crate::{ api::model::{CompatGenerateRequest, ErrorResponse, GenerateRequest}, config::Config, }; -use super::generate_stream::generate_stream_handler; +use super::{generate_stream::generate_stream_handler, generate_text_handler}; /// Handler for generating text tokens. /// -/// This endpoint accepts a `CompatGenerateRequest` and returns a stream of generated text. -/// It requires the `stream` field in the request to be true. If `stream` is false, -/// the handler will return a `StatusCode::NOT_IMPLEMENTED` error. +/// This endpoint accepts a `CompatGenerateRequest` and returns a stream of generated text +/// or a single text response based on the `stream` field in the request. If `stream` is true, +/// it returns a stream of `StreamResponse`. If `stream` is false, it returns `GenerateResponse`. 
/// # Arguments /// * `config` - State containing the application configuration. /// * `payload` - JSON payload containing the input text and optional parameters. /// /// # Responses -/// * `200 OK` - Successful generation of text, returns a stream of `StreamResponse`. -/// * `501 Not Implemented` - Returned if `stream` field in request is false. +/// * `200 OK` - Successful generation of text. +/// * `422`, `424`, `429`, `500` - Error statuses documented in the `responses` list below, each carrying an `ErrorResponse` body. #[utoipa::path( post, + tag = "Text Generation Inference", path = "/", request_body = CompatGenerateRequest, responses( - (status = 200, description = "Generated Text", body = StreamResponse), - (status = 501, description = "Streaming not enabled", body = ErrorResponse), - ), - tag = "Text Generation Inference" + (status = 200, description = "Generated Text", + content( + ("application/json" = GenerateResponse), + ("text/event-stream" = StreamResponse), + ) + ), + (status = 424, description = "Generation Error", body = ErrorResponse, + example = json!({"error": "Request failed during generation"})), + (status = 429, description = "Model is overloaded", body = ErrorResponse, + example = json!({"error": "Model is overloaded"})), + (status = 422, description = "Input validation error", body = ErrorResponse, + example = json!({"error": "Input validation error"})), + (status = 500, description = "Incomplete generation", body = ErrorResponse, + example = json!({"error": "Incomplete generation"})), + ) )] pub async fn generate_handler( config: State, Json(payload): Json, -) -> impl IntoResponse { - if !payload.stream { - return Err(( - StatusCode::NOT_IMPLEMENTED, - Json(ErrorResponse { - error: "Use /generate endpoint if not streaming".to_string(), - error_type: None, +) -> Result)> { + if payload.stream { + Ok(generate_stream_handler( + config, + Json(GenerateRequest { + inputs: payload.inputs, + parameters: payload.parameters, + }), + ) + .await + .into_response()) + } else { + Ok(generate_text_handler( 
+ config, + Json(GenerateRequest { + inputs: payload.inputs, + parameters: payload.parameters, }), - )); + ) + .await + .into_response()) } - Ok(generate_stream_handler( - config, - Json(GenerateRequest { - inputs: payload.inputs, - parameters: payload.parameters, - }), - ) - .await) } #[cfg(test)] @@ -97,6 +118,7 @@ mod tests { /// Test the generate_handler function for streaming disabled. #[tokio::test] + #[ignore = "Will download model from HuggingFace"] async fn test_generate_handler_stream_disabled() { let app = Router::new() .route("/", post(generate_handler)) @@ -120,6 +142,6 @@ mod tests { .await .unwrap(); - assert_eq!(response.status(), StatusCode::NOT_IMPLEMENTED); + assert_eq!(response.status(), StatusCode::OK); } } diff --git a/src/api/routes/generate_stream.rs b/src/api/routes/generate_stream.rs index 30a8c6d..ae47fac 100644 --- a/src/api/routes/generate_stream.rs +++ b/src/api/routes/generate_stream.rs @@ -10,7 +10,24 @@ use futures::stream::StreamExt; use log::debug; use std::vec; -/// Generate tokens +/// Asynchronous handler for generating text through a streaming API. +/// +/// This function handles POST requests to the `/generate_stream` endpoint. It takes a JSON payload +/// representing a `GenerateRequest` and uses the configuration and parameters specified to +/// generate text using a streaming approach. The response is a stream of Server-Sent Events (SSE), +/// allowing clients to receive generated text in real-time as it is produced. +/// +/// # Parameters +/// - `config`: Application state holding the global configuration. +/// - `Json(payload)`: JSON payload containing the input text and generation parameters. +/// +/// # Responses +/// - `200 OK`: Stream of generated text as `StreamResponse` events. +/// - Error responses: Descriptive error messages if any issues occur. 
+/// +/// # Usage +/// This endpoint is suitable for scenarios where real-time text generation is required, +/// such as interactive chatbots or live content creation tools. #[utoipa::path( post, path = "/generate_stream", diff --git a/src/api/routes/generate_text.rs b/src/api/routes/generate_text.rs index d5dce53..6c95b68 100644 --- a/src/api/routes/generate_text.rs +++ b/src/api/routes/generate_text.rs @@ -5,7 +5,27 @@ use crate::{ }; use axum::{extract::State, http::StatusCode, response::IntoResponse, Json}; -/// Generate tokens +/// Asynchronous handler for generating text. +/// +/// This function handles POST requests to the `/generate` endpoint. It takes a JSON payload +/// representing a `GenerateRequest` and uses the configuration and parameters specified to +/// generate text. The generated text is returned in a `GenerateResponse` if successful. +/// +/// # Parameters +/// - `config`: Application state holding the global configuration. +/// - `Json(payload)`: JSON payload containing the input text and generation parameters. +/// +/// # Responses +/// - `200 OK`: Successful text generation with `GenerateResponse`. +/// - `422 Unprocessable Entity`: Input validation error with `ErrorResponse`. +/// - `424 Failed Dependency`: Generation error with `ErrorResponse`. +/// - `429 Too Many Requests`: Model is overloaded with `ErrorResponse`. +/// - `500 Internal Server Error`: Incomplete generation with `ErrorResponse`. +/// +/// # Usage +/// This endpoint is suitable for generating text based on given prompts and parameters. +/// It can be used in scenarios where batch text generation is required, such as content +/// creation, language modeling, or any application needing on-demand text generation. 
#[utoipa::path( post, path = "/generate", diff --git a/src/api/routes/mod.rs b/src/api/routes/mod.rs index 28a5617..8fb8ba1 100644 --- a/src/api/routes/mod.rs +++ b/src/api/routes/mod.rs @@ -1,8 +1,21 @@ -pub mod generate; -pub mod generate_stream; -pub mod generate_text; -pub mod health; -pub mod info; +//! Module containing all route handlers. +//! +//! This module organizes the different API endpoints and their associated handlers. +//! Each route corresponds to a specific functionality of the text generation inference API. +//! +//! # Modules +//! * `generate` - Handles `/` requests, dispatching to streaming or non-streaming generation based on the request's `stream` field. +//! * `generate_stream` - Handles streaming requests for text generation. +//! * `generate_text` - Handles requests for generating text without streaming. +//! * `health` - Provides a health check endpoint. +//! * `info` - Provides information about the text generation inference service. +pub mod generate; // Module for the `/` endpoint, which dispatches to streaming or non-streaming generation. +pub mod generate_stream; // Module for handling streaming text generation requests. +pub mod generate_text; // Module for handling text generation requests. +pub mod health; // Module for the health check endpoint. +pub mod info; // Module for the service information endpoint. + +// Public exports of route handlers for ease of access. pub use generate::generate_handler; pub use generate_stream::generate_stream_handler; pub use generate_text::generate_text_handler;