Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(WIP) snowflake endpoint POC #338

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ on:
push:
branches:
- 'main'
- 'snowflake-endpoint'
tags:
- 'v*'

Expand Down
50 changes: 42 additions & 8 deletions router/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ use infer::Infer;
use loader::AdapterLoader;
use queue::Entry;
use serde::{Deserialize, Serialize};
use serde_json::json;
use utoipa::ToSchema;
use validation::Validation;
use std::collections::HashMap;

/// Hub type
#[derive(Clone, Debug, Deserialize)]
Expand Down Expand Up @@ -301,6 +301,19 @@ pub(crate) struct GenerateRequest {
pub parameters: GenerateParameters,
}

#[derive(Clone, Debug, Deserialize, ToSchema)]
pub(crate) struct SnowflakeGenerateRequest {
#[schema(example = "data: [[row_index, value]]")]
pub data: Vec<(i32, String, GenerateParameters)>,
// 'data: [[0, "'{inputs: "blah"}'"], [1, "next row data" ]]
}

#[derive(Serialize, ToSchema, Clone, Debug, Deserialize)]
pub(crate) struct SnowflakeGenerateResponse {
#[schema(example = "data: [[row_index, value]]")]
pub data: Vec<(i32, GenerateResponse)>,
// 'data: [[0, "'{inputs: "blah"}'"], [1, "next row data" ]]
}
#[derive(Clone, Debug, Deserialize, ToSchema)]
pub(crate) struct CompatGenerateRequest {
#[schema(example = "My name is Olivier and I")]
Expand All @@ -321,7 +334,28 @@ impl From<CompatGenerateRequest> for GenerateRequest {
}
}

#[derive(Debug, Serialize, ToSchema)]
impl From<SnowflakeGenerateRequest> for GenerateRequest {
fn from(req: SnowflakeGenerateRequest) -> Self {
let input_row = req.data.into_iter().nth(0).unwrap();
Self {
inputs: input_row.1,
parameters: input_row.2
}
}
}

impl From<GenerateResponse> for SnowflakeGenerateResponse {
fn from(req: GenerateResponse) -> Self {
let mut vec: Vec<(i32, GenerateResponse)> = Vec::new();
let data_tuple: (i32, GenerateResponse) = (0, req);
vec.push(data_tuple);
Self {
data: vec,
}
}
}

#[derive(Debug, Serialize, ToSchema, Clone, Deserialize)]
pub struct PrefillToken {
#[schema(example = 0)]
id: u32,
Expand All @@ -331,7 +365,7 @@ pub struct PrefillToken {
logprob: f32,
}

#[derive(Debug, Serialize, ToSchema)]
#[derive(Debug, Serialize, ToSchema, Clone, Deserialize)]
pub struct AlternativeToken {
#[schema(example = 0)]
id: u32,
Expand All @@ -341,7 +375,7 @@ pub struct AlternativeToken {
logprob: f32,
}

#[derive(Debug, Serialize, ToSchema)]
#[derive(Debug, Serialize, ToSchema, Clone, Deserialize)]
pub struct Token {
#[schema(example = 0)]
id: u32,
Expand All @@ -356,7 +390,7 @@ pub struct Token {
alternative_tokens: Option<Vec<AlternativeToken>>,
}

#[derive(Serialize, ToSchema)]
#[derive(Serialize, ToSchema, Debug, Clone, Deserialize)]
#[serde(rename_all(serialize = "snake_case"))]
pub(crate) enum FinishReason {
#[schema(rename = "length")]
Expand All @@ -368,7 +402,7 @@ pub(crate) enum FinishReason {
StopSequence,
}

#[derive(Serialize, ToSchema)]
#[derive(Serialize, ToSchema, Debug, Clone, Deserialize)]
pub(crate) struct BestOfSequence {
#[schema(example = "test")]
pub generated_text: String,
Expand All @@ -382,7 +416,7 @@ pub(crate) struct BestOfSequence {
pub tokens: Vec<Token>,
}

#[derive(Serialize, ToSchema)]
#[derive(Serialize, ToSchema, Clone, Debug, Deserialize)]
pub(crate) struct Details {
#[schema(example = "length")]
pub finish_reason: FinishReason,
Expand All @@ -398,7 +432,7 @@ pub(crate) struct Details {
pub best_of_sequences: Option<Vec<BestOfSequence>>,
}

#[derive(Serialize, ToSchema)]
#[derive(Serialize, ToSchema, Clone, Debug, Deserialize)]
pub(crate) struct GenerateResponse {
#[schema(example = "test")]
pub generated_text: String,
Expand Down
40 changes: 36 additions & 4 deletions router/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@ use crate::health::Health;
use crate::infer::{InferError, InferResponse, InferStreamResponse};
use crate::validation::ValidationError;
use crate::{
BestOfSequence, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse,
CompatGenerateRequest, CompletionRequest, CompletionResponse, CompletionStreamResponse,
Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
HubModelInfo, Infer, Info, PrefillToken, StreamDetails, StreamResponse, Token, Validation,
BestOfSequence, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse, CompatGenerateRequest, CompletionRequest, CompletionResponse, CompletionStreamResponse, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, PrefillToken, SnowflakeGenerateRequest, SnowflakeGenerateResponse, StreamDetails, StreamResponse, Token, Validation
};
use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode};
Expand Down Expand Up @@ -85,6 +82,40 @@ async fn compat_generate(
Ok((headers, Json(vec![generation.0])).into_response())
}
}
/// Snowflake compatible generation endpoint
#[utoipa::path(
post,
tag = "LoRAX",
path = "/snowflake/generate",
request_body = SnowflakeGenerateRequest,
responses(
(status = 200, description = "Generated Text", body = SnowflakeGenerateResponse),
(status = 424, description = "Generation Error", body = ErrorResponse,
example = json ! ({"error": "Request failed during generation"})),
(status = 429, description = "Model is overloaded", body = ErrorResponse,
example = json ! ({"error": "Model is overloaded"})),
(status = 422, description = "Input validation error", body = ErrorResponse,
example = json ! ({"error": "Input validation error"})),
(status = 500, description = "Incomplete generation", body = ErrorResponse,
example = json ! ({"error": "Incomplete generation"})),
)
)]
#[instrument(skip(infer, req))]
async fn snowflake_generate(
default_return_full_text: Extension<bool>,
infer: Extension<Infer>,
req_headers: HeaderMap,
req: Json<SnowflakeGenerateRequest>,
) -> Result<(HeaderMap, Json<SnowflakeGenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
let req = req.0;
let gen_req = GenerateRequest::from(req);
let (headers, generation) = generate(infer, req_headers, Json(gen_req.into())).await?;
let details = generation.details.clone();
let generated_text = generation.generated_text.clone();
let response = SnowflakeGenerateResponse::from(GenerateResponse{generated_text, details});
Ok((headers, Json(response)))
}


/// OpenAI compatible completions endpoint
#[utoipa::path(
Expand Down Expand Up @@ -945,6 +976,7 @@ pub async fn run(
.route("/generate", post(generate))
.route("/generate_stream", post(generate_stream))
.route("/v1/completions", post(completions_v1))
.route("/snowflake/generate", post(snowflake_generate))
.route("/v1/chat/completions", post(chat_completions_v1))
// AWS Sagemaker route
.route("/invocations", post(compat_generate))
Expand Down
Loading