OpenAI embedding endpoint #702

Merged: 11 commits, Dec 4, 2024
6 changes: 6 additions & 0 deletions .github/workflows/integration-tests/action.yaml
@@ -66,6 +66,12 @@ runs:
cd integration-tests
pytest test_embeddings.py -vv --capture=tee-sys --log-cli-level=INFO

- name: Run Classification tests
shell: bash
run: |
cd integration-tests
pytest test_classifications.py -vv --capture=tee-sys --log-cli-level=INFO

- name: Run LLM tests
shell: bash
run: |
12 changes: 12 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

23 changes: 23 additions & 0 deletions integration-tests/test_classifications.py
@@ -0,0 +1,23 @@
import requests
from utils.docker_runner import run_lorax_container


def test_distilbert_ner():
config = {
"name": "distilbert-ner",
"model_id": "dslim/distilbert-NER",
"docker_args": {
"max_input_length": 512,
"max_batch_prefill_tokens": 512,
"max_batch_total_tokens": 512,
"max_total_tokens": 512,
},
}
with run_lorax_container(config):
response = requests.post(
"http://localhost:8080/classify",
json={"inputs": "Johnny supports the Golden State Warriors. He lives in London."},
)
response.raise_for_status()
print("RESPONSE FROM CLASSIFICATION: ", response.json())
assert len(response.json()["predictions"]) > 0
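
For reference, a minimal sketch of inspecting the classification output outside the test harness, assuming a LoRAX container is already serving dslim/distilbert-NER on localhost:8080 (as in the test above) and that each prediction exposes the `entity_group` field declared on the `Entity` struct in `router/src/lib.rs`; the exact nesting of `predictions` may differ:

```python
import requests

# Assumes a running LoRAX server with dslim/distilbert-NER loaded (see test above).
resp = requests.post(
    "http://localhost:8080/classify",
    json={"inputs": "Johnny supports the Golden State Warriors. He lives in London."},
)
resp.raise_for_status()

# The test only checks that `predictions` is non-empty; here we also print the
# entity group of each prediction, falling back to the raw item since the
# exact response shape is not pinned down by this diff.
for prediction in resp.json()["predictions"]:
    if isinstance(prediction, dict) and "entity_group" in prediction:
        print(prediction["entity_group"])
    else:
        print(prediction)
```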
18 changes: 18 additions & 0 deletions integration-tests/test_embeddings.py
@@ -13,3 +13,21 @@ def test_stella_1_5b():
response.raise_for_status()
print("RESPONSE FROM EMBEDDING: ", response.json())
assert len(response.json()["embeddings"]) > 0


def test_uae_large_v1_1_5b():
config = {
"name": "UAE-Large-V1-1.5b",
"model_id": "WhereIsAI/UAE-Large-V1",
"docker_args": {
"max_input_length": 512,
"max_batch_prefill_tokens": 512,
"max_batch_total_tokens": 512,
"max_total_tokens": 512,
},
}
with run_lorax_container(config):
response = requests.post("http://localhost:8080/embed", json={"inputs": "Hello, world!"})
response.raise_for_status()
print("RESPONSE FROM EMBEDDING: ", response.json())
assert len(response.json()["embeddings"]) > 0
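
The PR title points at an OpenAI-style embedding endpoint; a hypothetical request shaped after the new `CompatEmbedRequest` and `CompatEmbedResponse` structs added to `router/src/lib.rs` could look as follows. The `/v1/embeddings` path is an assumption (the route registration is not part of this diff), and note that the compat response nests vectors under `embeddings`, each item carrying `index`, `embedding`, and `object`:

```python
import requests

# Assumed route; the server-side route registration is not shown in this diff.
COMPAT_EMBED_URL = "http://localhost:8080/v1/embeddings"

payload = {
    # `input` accepts either a single string or a list of strings
    # (StringOrVec in router/src/lib.rs).
    "input": ["Hello, world!", "LoRAX supports embeddings."],
    "model": "WhereIsAI/UAE-Large-V1",
}

resp = requests.post(COMPAT_EMBED_URL, json=payload)
resp.raise_for_status()

# CompatEmbedResponse: {"embeddings": [{"index": 0, "embedding": [...], "object": ...}, ...]}
for item in resp.json()["embeddings"]:
    print(item["index"], len(item["embedding"]))
```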
7 changes: 0 additions & 7 deletions launcher/src/main.rs
@@ -1675,13 +1675,6 @@ fn main() -> Result<(), LauncherError> {
}
};

// Validate args
if max_input_tokens >= max_total_tokens {
return Err(LauncherError::ArgumentValidation(
"`max_input_length` must be < `max_total_tokens`".to_string(),
));
}

if args.validation_workers == 0 {
return Err(LauncherError::ArgumentValidation(
"`validation_workers` must be > 0".to_string(),
2 changes: 1 addition & 1 deletion router/Cargo.toml
@@ -16,7 +16,7 @@ path = "src/main.rs"

[dependencies]
async-stream = "0.3.3"
axum = { version = "0.7", features = ["json"] }
axum = { version = "0.7", features = ["json", "macros"] }
axum-tracing-opentelemetry = "0.16"
clap = { version = "4.1.4", features = ["derive", "env"] }
futures = "0.3.26"
148 changes: 139 additions & 9 deletions router/src/infer.rs
@@ -4,9 +4,10 @@ use crate::queue::AdapterEvent;
use crate::scheduler::AdapterScheduler;
use crate::validation::{Validation, ValidationError};
use crate::{
AdapterParameters, AlternativeToken, BatchClassifyRequest, ChatTemplateVersions,
ClassifyRequest, EmbedRequest, EmbedResponse, Entity, Entry, HubTokenizerConfig, Message,
MessageChunk, MessageContent, TextMessage, Token, TokenizerConfigToken, Tool,
AdapterParameters, AlternativeToken, BatchClassifyRequest, BatchEmbedRequest,
ChatTemplateVersions, ClassifyRequest, EmbedRequest, EmbedResponse, Entity, Entry,
HubTokenizerConfig, Message, MessageChunk, MessageContent, TextMessage, Token,
TokenizerConfigToken, Tool,
};
use crate::{GenerateRequest, PrefillToken};
use futures::future::try_join_all;
@@ -596,6 +597,7 @@ impl Infer {
embedding,
start: _,
queued: _,
id: _,
} => {
return_embeddings = Some(embedding.values);
}
@@ -857,6 +859,137 @@ impl Infer {
}
}

#[instrument(skip(self))]
pub(crate) async fn embed_batch(
&self,
request: BatchEmbedRequest,
) -> Result<Vec<EmbedResponse>, InferError> {
// Limit concurrent requests by acquiring a permit from the semaphore
let _permit = self
.clone()
.limit_concurrent_requests
.try_acquire_owned()
.map_err(|err| {
metrics::increment_counter!("lorax_request_failure", "err" => "overloaded");
tracing::error!("{err}");
err
})?;

let (adapter_source, adapter_parameters) = extract_adapter_params(
request.parameters.adapter_id.clone(),
request.parameters.adapter_source.clone(),
request.parameters.adapter_parameters.clone(),
);

let adapter_idx;
{
// TODO(travis): can optimize concurrency here using RWLock
let mut adapter_to_index = self.adapter_to_index.lock().await;
let adapter_key = adapter_parameters.clone();
if adapter_to_index.contains_key(&adapter_key) {
adapter_idx = *adapter_to_index.get(&adapter_key).unwrap();
} else {
adapter_idx = adapter_to_index.len() as u32;
adapter_to_index.insert(adapter_key, adapter_idx);
}
}

let api_token = request.parameters.api_token.clone();
let adapter = Adapter::new(
adapter_parameters,
adapter_source.unwrap(),
adapter_idx,
api_token,
);

// MPSC channel to communicate with the background batching task
let (response_tx, response_rx) = mpsc::unbounded_channel();

// Call validate_input on every input in the request and await the results
let futures: Vec<_> = request
.inputs
.iter()
.map(|input| {
self.validation
.validate_input(input.clone(), true, None, Some(1))
})
.collect();

let all_tokenized_inputs = try_join_all(futures).await?;

for ((id, r_inputs), (tokenized_inputs, input_length)) in
request.inputs.iter().enumerate().zip(all_tokenized_inputs)
{
let inputs = r_inputs.to_string().clone();
let valid_request = ValidEmbedRequest {
inputs,
tokenized_inputs,
input_length: input_length as u32,
adapter: adapter.clone(),
};

// Process the request by sending it to the queue associated with `adapter`
self.adapter_scheduler.process(
adapter.clone(),
Entry {
request: Arc::new(valid_request),
response_tx: response_tx.clone(),
span: Span::current(),
temp_span: None,
queue_time: Instant::now(),
batch_time: None,
block_allocation: None,
id: Some(id as u64),
},
);
}

drop(response_tx); // Close the sending end

// Return values

let mut all_embeddings = HashMap::new();
let mut stream = UnboundedReceiverStream::new(response_rx);
while let Some(response) = stream.next().await {
match response? {
InferStreamResponse::Embed {
embedding,
start: _,
queued: _,
id,
} => {
all_embeddings.insert(
id.unwrap(),
EmbedResponse {
embeddings: embedding.values,
},
);
}
_ => {
tracing::error!(
"Received unexpected message type in classify_batch. This is a bug."
);
}
}
}
if all_embeddings.is_empty() {
let err = InferError::EmbeddingFailure;
metrics::increment_counter!("lorax_request_failure", "err" => "embedding_failure");
tracing::error!("{err}");
Err(err)
} else {
let mut sorted_responses: Vec<_> = all_embeddings.into_iter().collect();
sorted_responses.sort_by_key(|&(id, _)| id);

let sorted_responses: Vec<EmbedResponse> = sorted_responses
.into_iter()
.map(|(_, response)| response)
.collect();

Ok(sorted_responses)
}
}

/// Add best_of new requests to the queue and return an InferResponse of the sequence with
/// the highest log probability per token
#[instrument(skip(self))]
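
To illustrate why `embed_batch` tags every queued request with an `id` and then sorts on it: responses come back from the background batching task in completion order rather than submission order, so the ids are used to restore the original input order before returning. A small standalone Python sketch of the same bookkeeping, with purely illustrative names:

```python
import random

def collect_in_order(inputs):
    """Mimics embed_batch: tag each input with its index, gather results in
    arbitrary completion order, then sort by id to restore input order."""
    # Tag each request with its position, as embed_batch does with `id: Some(id as u64)`.
    tagged = list(enumerate(inputs))
    random.shuffle(tagged)  # stand-in for out-of-order completion

    results = {}
    for idx, text in tagged:
        results[idx] = f"embedding({text})"  # placeholder for the real vector

    # Sort by id, then drop the ids, mirroring the final step of embed_batch.
    return [emb for _, emb in sorted(results.items())]

print(collect_in_order(["a", "b", "c"]))  # ['embedding(a)', 'embedding(b)', 'embedding(c)']
```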
@@ -1478,6 +1611,7 @@ fn send_embeddings(
embedding: embedding.clone(),
queued: entry.queue_time,
start: entry.batch_time.unwrap(),
id: entry.id,
}))?;

// TODO(travis): redundant as we always return true, just make it return nothing
@@ -1534,22 +1668,18 @@ pub(crate) enum InferStreamResponse {
// Intermediate messages
Token(Token),
// Embeddings
// TODO: add tracing for embedding
Embed {
embedding: Embedding,
// For now allow this field even though it is unused.
// TODO:(magdy) enable tracing for these requests
#[allow(dead_code)]
start: Instant,
#[allow(dead_code)]
queued: Instant,
id: Option<u64>, // to support batching
},
Classify {
predictions: ClassifyPredictionList,
// For now allow this field even though it is unused.
// TODO:(magdy) enable tracing for these requests
#[allow(dead_code)]
start: Instant,
#[allow(dead_code)]
queued: Instant,
id: Option<u64>, // to support batching
},
50 changes: 50 additions & 0 deletions router/src/lib.rs
@@ -1165,6 +1165,49 @@ struct EmbedResponse {
embeddings: Vec<f32>,
}

#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
enum StringOrVec {
String(String),
Vec(Vec<String>),
}

impl std::fmt::Display for StringOrVec {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
StringOrVec::String(s) => write!(f, "{}", s),
StringOrVec::Vec(v) => write!(f, "{}", v.join(", ")),
}
}
}

#[derive(Clone, Debug, Deserialize, ToSchema)]
struct CompatEmbedRequest {
input: StringOrVec,
#[allow(dead_code)]
model: String,
#[allow(dead_code)]
encoding_format: Option<String>,
#[allow(dead_code)]
dimensions: Option<i32>,
#[allow(dead_code)]
user: Option<String>,
#[serde(default = "default_embed_parameters")]
parameters: EmbedParameters,
}

#[derive(Serialize, ToSchema)]
struct CompatEmbedResponse {
embeddings: Vec<CompatEmbedding>,
}

#[derive(Serialize, ToSchema)]
struct CompatEmbedding {
index: i32,
embedding: Vec<f32>,
object: String,
}

#[derive(Clone, Debug, Deserialize, ToSchema)]
struct ClassifyRequest {
inputs: String,
@@ -1175,6 +1218,13 @@ struct BatchClassifyRequest {
inputs: Vec<String>,
}

#[derive(Clone, Debug, Deserialize, ToSchema)]
struct BatchEmbedRequest {
inputs: Vec<String>,
#[serde(default = "default_embed_parameters")]
parameters: EmbedParameters,
}

#[derive(Debug, Serialize, Deserialize)]
struct Entity {
entity_group: String,
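
For contrast with the single-input `/embed` route exercised in the tests, a sketch of a payload shaped after the new `BatchEmbedRequest` struct: `inputs` becomes a list of strings and `parameters` falls back to `default_embed_parameters` when omitted. The `/embed_batch` path below is an assumption, since the route that deserializes `BatchEmbedRequest` is registered in server code not shown in this diff:

```python
import requests

# Single-input request (existing EmbedRequest): `inputs` is one string.
single = {"inputs": "Hello, world!"}
requests.post("http://localhost:8080/embed", json=single).raise_for_status()

# Batched request shaped after the new BatchEmbedRequest struct:
# `inputs` is a list of strings; `parameters` is optional and defaulted.
batch = {"inputs": ["Hello, world!", "Goodbye, world!"]}

# Assumed path; the actual route registration is not part of this diff.
resp = requests.post("http://localhost:8080/embed_batch", json=batch)
resp.raise_for_status()
print(resp.json())
```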
7 changes: 0 additions & 7 deletions router/src/main.rs
@@ -135,13 +135,6 @@ async fn main() -> Result<(), RouterError> {

init_logging(otlp_endpoint, json_output);

// Validate args
if max_input_length >= max_total_tokens {
return Err(RouterError::ArgumentValidation(
"`max_input_length` must be < `max_total_tokens`".to_string(),
));
}

if validation_workers == 0 {
return Err(RouterError::ArgumentValidation(
"`validation_workers` must be > 0".to_string(),