diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 21fc0d84f..56632abf4 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -5,6 +5,7 @@ on: push: branches: - 'main' + - 'snowflake-endpoint' tags: - 'v*' diff --git a/router/src/lib.rs b/router/src/lib.rs index c3ca10f87..4757deaec 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -14,9 +14,9 @@ use infer::Infer; use loader::AdapterLoader; use queue::Entry; use serde::{Deserialize, Serialize}; -use serde_json::json; use utoipa::ToSchema; use validation::Validation; +use std::collections::HashMap; /// Hub type #[derive(Clone, Debug, Deserialize)] @@ -301,6 +301,19 @@ pub(crate) struct GenerateRequest { pub parameters: GenerateParameters, } +#[derive(Clone, Debug, Deserialize, ToSchema)] +pub(crate) struct SnowflakeGenerateRequest { + #[schema(example = "data: [[row_index, value]]")] + pub data: Vec<(i32, String, GenerateParameters)>, + // 'data: [[0, "'{inputs: "blah"}'"], [1, "next row data" ]] +} + +#[derive(Serialize, ToSchema, Clone, Debug, Deserialize)] +pub(crate) struct SnowflakeGenerateResponse { + #[schema(example = "data: [[row_index, value]]")] + pub data: Vec<(i32, GenerateResponse)>, + // 'data: [[0, "'{inputs: "blah"}'"], [1, "next row data" ]] +} #[derive(Clone, Debug, Deserialize, ToSchema)] pub(crate) struct CompatGenerateRequest { #[schema(example = "My name is Olivier and I")] @@ -321,7 +334,28 @@ impl From for GenerateRequest { } } -#[derive(Debug, Serialize, ToSchema)] +impl From for GenerateRequest { + fn from(req: SnowflakeGenerateRequest) -> Self { + let input_row = req.data.into_iter().nth(0).unwrap(); + Self { + inputs: input_row.1, + parameters: input_row.2 + } + } +} + +impl From for SnowflakeGenerateResponse { + fn from(req: GenerateResponse) -> Self { + let mut vec: Vec<(i32, GenerateResponse)> = Vec::new(); + let data_tuple: (i32, GenerateResponse) = (0, req); + vec.push(data_tuple); + Self { + data: vec, + } + } +} + +#[derive(Debug, Serialize, ToSchema, Clone, Deserialize)] pub struct PrefillToken { #[schema(example = 0)] id: u32, @@ -331,7 +365,7 @@ pub struct PrefillToken { logprob: f32, } -#[derive(Debug, Serialize, ToSchema)] +#[derive(Debug, Serialize, ToSchema, Clone, Deserialize)] pub struct AlternativeToken { #[schema(example = 0)] id: u32, @@ -341,7 +375,7 @@ pub struct AlternativeToken { logprob: f32, } -#[derive(Debug, Serialize, ToSchema)] +#[derive(Debug, Serialize, ToSchema, Clone, Deserialize)] pub struct Token { #[schema(example = 0)] id: u32, @@ -356,7 +390,7 @@ pub struct Token { alternative_tokens: Option>, } -#[derive(Serialize, ToSchema)] +#[derive(Serialize, ToSchema, Debug, Clone, Deserialize)] #[serde(rename_all(serialize = "snake_case"))] pub(crate) enum FinishReason { #[schema(rename = "length")] @@ -368,7 +402,7 @@ pub(crate) enum FinishReason { StopSequence, } -#[derive(Serialize, ToSchema)] +#[derive(Serialize, ToSchema, Debug, Clone, Deserialize)] pub(crate) struct BestOfSequence { #[schema(example = "test")] pub generated_text: String, @@ -382,7 +416,7 @@ pub(crate) struct BestOfSequence { pub tokens: Vec, } -#[derive(Serialize, ToSchema)] +#[derive(Serialize, ToSchema, Clone, Debug, Deserialize)] pub(crate) struct Details { #[schema(example = "length")] pub finish_reason: FinishReason, @@ -398,7 +432,7 @@ pub(crate) struct Details { pub best_of_sequences: Option>, } -#[derive(Serialize, ToSchema)] +#[derive(Serialize, ToSchema, Clone, Debug, Deserialize)] pub(crate) struct GenerateResponse { #[schema(example = "test")] pub generated_text: String, diff --git a/router/src/server.rs b/router/src/server.rs index c55ed1552..e8b780237 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -4,10 +4,7 @@ use crate::health::Health; use crate::infer::{InferError, InferResponse, InferStreamResponse}; use crate::validation::ValidationError; use crate::{ - BestOfSequence, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse, - CompatGenerateRequest, CompletionRequest, CompletionResponse, CompletionStreamResponse, - Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, - HubModelInfo, Infer, Info, PrefillToken, StreamDetails, StreamResponse, Token, Validation, + BestOfSequence, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse, CompatGenerateRequest, CompletionRequest, CompletionResponse, CompletionStreamResponse, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, PrefillToken, SnowflakeGenerateRequest, SnowflakeGenerateResponse, StreamDetails, StreamResponse, Token, Validation }; use axum::extract::Extension; use axum::http::{HeaderMap, Method, StatusCode}; @@ -85,6 +82,40 @@ async fn compat_generate( Ok((headers, Json(vec![generation.0])).into_response()) } } +/// Snowflake compatible generation endpoint +#[utoipa::path( +post, +tag = "LoRAX", +path = "/snowflake/generate", +request_body = SnowflakeGenerateRequest, +responses( +(status = 200, description = "Generated Text", body = SnowflakeGenerateResponse), +(status = 424, description = "Generation Error", body = ErrorResponse, +example = json ! ({"error": "Request failed during generation"})), +(status = 429, description = "Model is overloaded", body = ErrorResponse, +example = json ! ({"error": "Model is overloaded"})), +(status = 422, description = "Input validation error", body = ErrorResponse, +example = json ! ({"error": "Input validation error"})), +(status = 500, description = "Incomplete generation", body = ErrorResponse, +example = json ! ({"error": "Incomplete generation"})), +) +)] +#[instrument(skip(infer, req))] +async fn snowflake_generate( + default_return_full_text: Extension, + infer: Extension, + req_headers: HeaderMap, + req: Json, +) -> Result<(HeaderMap, Json), (StatusCode, Json)> { + let req = req.0; + let gen_req = GenerateRequest::from(req); + let (headers, generation) = generate(infer, req_headers, Json(gen_req.into())).await?; + let details = generation.details.clone(); + let generated_text = generation.generated_text.clone(); + let response = SnowflakeGenerateResponse::from(GenerateResponse{generated_text, details}); + Ok((headers, Json(response))) +} + /// OpenAI compatible completions endpoint #[utoipa::path( @@ -945,6 +976,7 @@ pub async fn run( .route("/generate", post(generate)) .route("/generate_stream", post(generate_stream)) .route("/v1/completions", post(completions_v1)) + .route("/snowflake/generate", post(snowflake_generate)) .route("/v1/chat/completions", post(chat_completions_v1)) // AWS Sagemaker route .route("/invocations", post(compat_generate))