OpenAI embedding endpoint (#702)
magdyksaleh authored Dec 4, 2024
1 parent 5e0b003 commit 07b77db
Showing 10 changed files with 324 additions and 34 deletions.
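
This commit adds an OpenAI-compatible embeddings request/response shape to the router (see CompatEmbedRequest and CompatEmbedResponse in router/src/lib.rs below), a batched embedding path in the inference engine, and a classification integration test. As a minimal sketch, assuming the compatibility route is served at /v1/embeddings on a local LoRAX deployment (the route registration itself is in a file that is not rendered on this page), a client call might look like this:

import requests

# Hypothetical path and port; the body mirrors the new CompatEmbedRequest struct,
# whose "input" field accepts either a single string or a list of strings.
response = requests.post(
    "http://localhost:8080/v1/embeddings",
    json={"model": "WhereIsAI/UAE-Large-V1", "input": ["Hello, world!", "How are you?"]},
)
response.raise_for_status()
print(response.json())
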
6 changes: 6 additions & 0 deletions .github/workflows/integration-tests/action.yaml
@@ -66,6 +66,12 @@ runs:
cd integration-tests
pytest test_embeddings.py -vv --capture=tee-sys --log-cli-level=INFO
- name: Run Classification tests
shell: bash
run: |
cd integration-tests
pytest test_classifications.py -vv --capture=tee-sys --log-cli-level=INFO
- name: Run LLM tests
shell: bash
run: |
12 changes: 12 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

23 changes: 23 additions & 0 deletions integration-tests/test_classifications.py
@@ -0,0 +1,23 @@
import requests
from utils.docker_runner import run_lorax_container


def test_distilbert_ner():
config = {
"name": "distilbert-ner",
"model_id": "dslim/distilbert-NER",
"docker_args": {
"max_input_length": 512,
"max_batch_prefill_tokens": 512,
"max_batch_total_tokens": 512,
"max_total_tokens": 512,
},
}
with run_lorax_container(config):
response = requests.post(
"http://localhost:8080/classify",
json={"inputs": "Johnny supports the Golden State Warriors. He lives in London."},
)
response.raise_for_status()
print("RESPONSE FROM CLASSIFICATION: ", response.json())
assert len(response.json()["predictions"]) > 0
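
The assertion above only checks that the /classify response returns a non-empty predictions list. As a small sketch of inspecting it, assuming the entries ultimately carry Entity objects with an entity_group field as added to router/src/lib.rs below (the exact nesting of the response is not shown in this diff):

for prediction in response.json()["predictions"]:
    # For the sample sentence, a NER model would be expected to tag spans such as
    # "Johnny", "Golden State Warriors", and "London"; the field layout is an assumption.
    print(prediction)
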
18 changes: 18 additions & 0 deletions integration-tests/test_embeddings.py
@@ -13,3 +13,21 @@ def test_stella_1_5b():
response.raise_for_status()
print("RESPONSE FROM EMBEDDING: ", response.json())
assert len(response.json()["embeddings"]) > 0


def test_uae_large_v1_1_5b():
config = {
"name": "UAE-Large-V1-1.5b",
"model_id": "WhereIsAI/UAE-Large-V1",
"docker_args": {
"max_input_length": 512,
"max_batch_prefill_tokens": 512,
"max_batch_total_tokens": 512,
"max_total_tokens": 512,
},
}
with run_lorax_container(config):
response = requests.post("http://localhost:8080/embed", json={"inputs": "Hello, world!"})
response.raise_for_status()
print("RESPONSE FROM EMBEDDING: ", response.json())
assert len(response.json()["embeddings"]) > 0
7 changes: 0 additions & 7 deletions launcher/src/main.rs
@@ -1675,13 +1675,6 @@ fn main() -> Result<(), LauncherError> {
}
};

// Validate args
if max_input_tokens >= max_total_tokens {
return Err(LauncherError::ArgumentValidation(
"`max_input_length` must be < `max_total_tokens`".to_string(),
));
}

if args.validation_workers == 0 {
return Err(LauncherError::ArgumentValidation(
"`validation_workers` must be > 0".to_string(),
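Note that this launcher change and the matching change to router/src/main.rs further down both drop the check that max_input_length be strictly less than max_total_tokens, presumably because embedding and classification deployments legitimately set the two equal, as the new integration tests above do with 512 for both.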
2 changes: 1 addition & 1 deletion router/Cargo.toml
@@ -16,7 +16,7 @@ path = "src/main.rs"

[dependencies]
async-stream = "0.3.3"
axum = { version = "0.7", features = ["json"] }
axum = { version = "0.7", features = ["json", "macros"] }
axum-tracing-opentelemetry = "0.16"
clap = { version = "4.1.4", features = ["derive", "env"] }
futures = "0.3.26"
148 changes: 139 additions & 9 deletions router/src/infer.rs
@@ -4,9 +4,10 @@ use crate::queue::AdapterEvent;
use crate::scheduler::AdapterScheduler;
use crate::validation::{Validation, ValidationError};
use crate::{
AdapterParameters, AlternativeToken, BatchClassifyRequest, ChatTemplateVersions,
ClassifyRequest, EmbedRequest, EmbedResponse, Entity, Entry, HubTokenizerConfig, Message,
MessageChunk, MessageContent, TextMessage, Token, TokenizerConfigToken, Tool,
AdapterParameters, AlternativeToken, BatchClassifyRequest, BatchEmbedRequest,
ChatTemplateVersions, ClassifyRequest, EmbedRequest, EmbedResponse, Entity, Entry,
HubTokenizerConfig, Message, MessageChunk, MessageContent, TextMessage, Token,
TokenizerConfigToken, Tool,
};
use crate::{GenerateRequest, PrefillToken};
use futures::future::try_join_all;
@@ -596,6 +597,7 @@ impl Infer {
embedding,
start: _,
queued: _,
id: _,
} => {
return_embeddings = Some(embedding.values);
}
@@ -857,6 +859,137 @@
}
}

#[instrument(skip(self))]
pub(crate) async fn embed_batch(
&self,
request: BatchEmbedRequest,
) -> Result<Vec<EmbedResponse>, InferError> {
// Limit concurrent requests by acquiring a permit from the semaphore
let _permit = self
.clone()
.limit_concurrent_requests
.try_acquire_owned()
.map_err(|err| {
metrics::increment_counter!("lorax_request_failure", "err" => "overloaded");
tracing::error!("{err}");
err
})?;

let (adapter_source, adapter_parameters) = extract_adapter_params(
request.parameters.adapter_id.clone(),
request.parameters.adapter_source.clone(),
request.parameters.adapter_parameters.clone(),
);

let adapter_idx;
{
// TODO(travis): can optimize concurrency here using RWLock
let mut adapter_to_index = self.adapter_to_index.lock().await;
let adapter_key = adapter_parameters.clone();
if adapter_to_index.contains_key(&adapter_key) {
adapter_idx = *adapter_to_index.get(&adapter_key).unwrap();
} else {
adapter_idx = adapter_to_index.len() as u32;
adapter_to_index.insert(adapter_key, adapter_idx);
}
}

let api_token = request.parameters.api_token.clone();
let adapter = Adapter::new(
adapter_parameters,
adapter_source.unwrap(),
adapter_idx,
api_token,
);

// MPSC channel to communicate with the background batching task
let (response_tx, response_rx) = mpsc::unbounded_channel();

// Call validate_input on every input in the request and await the results
let futures: Vec<_> = request
.inputs
.iter()
.map(|input| {
self.validation
.validate_input(input.clone(), true, None, Some(1))
})
.collect();

let all_tokenized_inputs = try_join_all(futures).await?;

for ((id, r_inputs), (tokenized_inputs, input_length)) in
request.inputs.iter().enumerate().zip(all_tokenized_inputs)
{
let inputs = r_inputs.to_string();
let valid_request = ValidEmbedRequest {
inputs,
tokenized_inputs,
input_length: input_length as u32,
adapter: adapter.clone(),
};

// Process the request by sending it to the queue associated with `adapter`
self.adapter_scheduler.process(
adapter.clone(),
Entry {
request: Arc::new(valid_request),
response_tx: response_tx.clone(),
span: Span::current(),
temp_span: None,
queue_time: Instant::now(),
batch_time: None,
block_allocation: None,
id: Some(id as u64),
},
);
}

drop(response_tx); // Close the sending end

// Return values

let mut all_embeddings = HashMap::new();
let mut stream = UnboundedReceiverStream::new(response_rx);
while let Some(response) = stream.next().await {
match response? {
InferStreamResponse::Embed {
embedding,
start: _,
queued: _,
id,
} => {
all_embeddings.insert(
id.unwrap(),
EmbedResponse {
embeddings: embedding.values,
},
);
}
_ => {
tracing::error!(
"Received unexpected message type in classify_batch. This is a bug."
);
}
}
}
if all_embeddings.is_empty() {
let err = InferError::EmbeddingFailure;
metrics::increment_counter!("lorax_request_failure", "err" => "embedding_failure");
tracing::error!("{err}");
Err(err)
} else {
let mut sorted_responses: Vec<_> = all_embeddings.into_iter().collect();
sorted_responses.sort_by_key(|&(id, _)| id);

let sorted_responses: Vec<EmbedResponse> = sorted_responses
.into_iter()
.map(|(_, response)| response)
.collect();

Ok(sorted_responses)
}
}

/// Add best_of new requests to the queue and return a InferResponse of the sequence with
/// the highest log probability per token
#[instrument(skip(self))]
@@ -1478,6 +1611,7 @@ fn send_embeddings(
embedding: embedding.clone(),
queued: entry.queue_time,
start: entry.batch_time.unwrap(),
id: entry.id,
}))?;

// TODO(travis): redundant as we always return true, just make it return nothing
@@ -1534,22 +1668,18 @@ pub(crate) enum InferStreamResponse {
// Intermediate messages
Token(Token),
// Embeddings
// TODO: add tracing for embedding
Embed {
embedding: Embedding,
// For now allow this field even though it is unused.
// TODO:(magdy) enable tracing for these requests
#[allow(dead_code)]
start: Instant,
#[allow(dead_code)]
queued: Instant,
id: Option<u64>, // to support batching
},
Classify {
predictions: ClassifyPredictionList,
// For now allow this field even though it is unused.
// TODO:(magdy) enable tracing for these requests
#[allow(dead_code)]
start: Instant,
#[allow(dead_code)]
queued: Instant,
id: Option<u64>, // to support batching
},
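
The new embed_batch path above validates every input concurrently, schedules each one as its own Entry tagged with its position in the request, and then reassembles the streamed responses by sorting on that id, so the returned embeddings follow the input order. A rough sketch of exercising it, assuming the router exposes the BatchEmbedRequest shape from router/src/lib.rs below at an /embed_batch route and serializes the Vec<EmbedResponse> directly (neither is shown in the rendered portion of this diff):

import requests

inputs = ["first sentence", "second sentence", "third sentence"]
# Hypothetical endpoint path; the body matches BatchEmbedRequest ("inputs" is a list of strings).
response = requests.post("http://localhost:8080/embed_batch", json={"inputs": inputs})
response.raise_for_status()
# One embedding object per input, expected in the same order as the request.
for text, item in zip(inputs, response.json()):
    print(text, len(item["embeddings"]))
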
50 changes: 50 additions & 0 deletions router/src/lib.rs
@@ -1165,6 +1165,49 @@ struct EmbedResponse {
embeddings: Vec<f32>,
}

#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
enum StringOrVec {
String(String),
Vec(Vec<String>),
}

impl std::fmt::Display for StringOrVec {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
StringOrVec::String(s) => write!(f, "{}", s),
StringOrVec::Vec(v) => write!(f, "{}", v.join(", ")),
}
}
}

#[derive(Clone, Debug, Deserialize, ToSchema)]
struct CompatEmbedRequest {
input: StringOrVec,
#[allow(dead_code)]
model: String,
#[allow(dead_code)]
encoding_format: Option<String>,
#[allow(dead_code)]
dimensions: Option<i32>,
#[allow(dead_code)]
user: Option<String>,
#[serde(default = "default_embed_parameters")]
parameters: EmbedParameters,
}

#[derive(Serialize, ToSchema)]
struct CompatEmbedResponse {
embeddings: Vec<CompatEmbedding>,
}

#[derive(Serialize, ToSchema)]
struct CompatEmbedding {
index: i32,
embedding: Vec<f32>,
object: String,
}

#[derive(Clone, Debug, Deserialize, ToSchema)]
struct ClassifyRequest {
inputs: String,
@@ -1175,6 +1218,13 @@ struct BatchClassifyRequest {
inputs: Vec<String>,
}

#[derive(Clone, Debug, Deserialize, ToSchema)]
struct BatchEmbedRequest {
inputs: Vec<String>,
#[serde(default = "default_embed_parameters")]
parameters: EmbedParameters,
}

#[derive(Debug, Serialize, Deserialize)]
struct Entity {
entity_group: String,
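
For reference, a minimal sketch (as Python literals) of the response body the compatibility structs above would serialize, with the object value assumed to follow the OpenAI convention of "embedding":

compat_embed_response = {
    "embeddings": [
        # One CompatEmbedding per input, indexed in request order; values are illustrative.
        {"index": 0, "embedding": [0.012, -0.034], "object": "embedding"},
    ],
}
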
7 changes: 0 additions & 7 deletions router/src/main.rs
@@ -135,13 +135,6 @@ async fn main() -> Result<(), RouterError> {

init_logging(otlp_endpoint, json_output);

// Validate args
if max_input_length >= max_total_tokens {
return Err(RouterError::ArgumentValidation(
"`max_input_length` must be < `max_total_tokens`".to_string(),
));
}

if validation_workers == 0 {
return Err(RouterError::ArgumentValidation(
"`validation_workers` must be > 0".to_string(),
(The remaining changed file was not rendered on the page.)
