Skip to content

Commit

Permalink
Embedder Service v0 with FlashBert (#385)
Browse files Browse the repository at this point in the history
  • Loading branch information
magdyksaleh authored May 25, 2024
1 parent feb69c4 commit e37549e
Show file tree
Hide file tree
Showing 16 changed files with 644 additions and 138 deletions.
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,5 @@ repos:
rev: 7.0.0
hooks:
- id: flake8
name: flake8
args: ['--max-line-length=120']
57 changes: 56 additions & 1 deletion clients/python/lorax/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
Response,
Request,
Parameters,
MergedAdapters, ResponseFormat,
MergedAdapters,
ResponseFormat,
EmbedResponse
)
from lorax.errors import parse_error

Expand Down Expand Up @@ -58,6 +60,7 @@ def __init__(
HTTP requests session object to reuse
"""
self.base_url = base_url
self.embed_endpoint = f"{base_url}/embed"
self.headers = headers
self.cookies = cookies
self.timeout = timeout
Expand Down Expand Up @@ -345,6 +348,34 @@ def generate_stream(
raise parse_error(resp.status_code, json_payload)
yield response


def embed(self, inputs: str) -> EmbedResponse:
    """
    Given inputs, embed the text using the model

    Args:
        inputs (`str`):
            Input text

    Returns:
        EmbedResponse: computed embeddings

    Raises:
        Exception produced by `parse_error` when the server responds
        with a non-200 status code.
    """
    request = Request(inputs=inputs)

    resp = requests.post(
        self.embed_endpoint,
        json=request.dict(by_alias=True),
        headers=self.headers,
        cookies=self.cookies,
        timeout=self.timeout,
    )

    # Parse the body once and reuse it on both paths — the original
    # called resp.json() a second time when building the error.
    payload = resp.json()
    if resp.status_code != 200:
        raise parse_error(resp.status_code, payload)

    return EmbedResponse(**payload)


class AsyncClient:
"""Asynchronous Client to make calls to a LoRAX instance
Expand Down Expand Up @@ -387,6 +418,7 @@ def __init__(
Timeout in seconds
"""
self.base_url = base_url
self.embed_endpoint = f"{base_url}/embed"
self.headers = headers
self.cookies = cookies
self.timeout = ClientTimeout(timeout * 60)
Expand Down Expand Up @@ -661,3 +693,26 @@ async def generate_stream(
# If we failed to parse the payload, then it is an error payload
raise parse_error(resp.status, json_payload)
yield response


async def embed(self, inputs: str) -> EmbedResponse:
    """
    Given inputs, embed the text using the model

    Args:
        inputs (`str`):
            Input text

    Returns:
        EmbedResponse: computed embeddings
    """
    body = Request(inputs=inputs).dict(by_alias=True)
    async with ClientSession(
        headers=self.headers, cookies=self.cookies, timeout=self.timeout
    ) as session:
        async with session.post(self.embed_endpoint, json=body) as resp:
            payload = await resp.json()
            if resp.status != 200:
                raise parse_error(resp.status, payload)
            return EmbedResponse(**payload)
8 changes: 6 additions & 2 deletions clients/python/lorax/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,10 @@ class StreamResponse(BaseModel):
class DeployedModel(BaseModel):
    # Identifier of the deployed model (e.g. a HuggingFace hub id —
    # presumably; confirm against the server's /info payload)
    model_id: str
    # Revision sha of the deployed model
    sha: str

    # Suppress pydantic warning over `model_id` field
    # (pydantic v2 reserves the `model_` name prefix by default).
    model_config = ConfigDict(protected_namespaces=())


class EmbedResponse(BaseModel):
    # Embedding vector computed for the input text.
    # Optional: the server may return null (e.g. on an error payload —
    # NOTE(review): confirm against the router's response schema).
    embeddings: Optional[List[float]]
10 changes: 10 additions & 0 deletions launcher/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,12 @@ struct Args {
#[clap(long, env)]
sharded: Option<bool>,

/// Whether this model is meant for embeddings or text generation.
/// By default models are for text generation.
/// Setting it to `true` will enable the embedding endpoints and disable the generation ones.
#[clap(long, env)]
embedding_model: Option<bool>,

/// The number of shards to use if you don't want to use all GPUs on a given machine.
/// You can use `CUDA_VISIBLE_DEVICES=0,1 lorax-launcher... --num_shard 2`
/// and `CUDA_VISIBLE_DEVICES=2,3 lorax-launcher... --num_shard 2` to
Expand Down Expand Up @@ -1119,6 +1125,10 @@ fn spawn_webserver(
router_args.push(origin.to_string());
}

if args.embedding_model.unwrap_or(false) {
router_args.push("--embedding-model".to_string());
}

// Ngrok
if args.ngrok {
router_args.push("--ngrok".to_string());
Expand Down
17 changes: 16 additions & 1 deletion proto/generate.proto
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ service LoraxService {
rpc Warmup (WarmupRequest) returns (WarmupResponse);
/// Prefill batch and decode first token
rpc Prefill (PrefillRequest) returns (PrefillResponse);
/// Embed
rpc Embed (EmbedRequest) returns (EmbedResponse);
/// Decode token for a list of prefilled batches
rpc Decode (DecodeRequest) returns (DecodeResponse);
/// Health check
Expand Down Expand Up @@ -230,6 +232,19 @@ message DecodeResponse {
optional CachedBatch batch = 2;
}

/// Request to embed a single input string
message EmbedRequest {
    /// Text to embed
    string inputs = 1;
}

/// A single embedding vector
message Embedding {
    /// Vector components
    repeated float values = 1;
}

message EmbedResponse {
    /// Embedding computed for the request's inputs
    Embedding embeddings = 1;
    /// Error message — presumably empty on success; confirm with the
    /// model-server implementation.
    /// NOTE(review): proto style guide prefers snake_case (`error_msg`);
    /// renaming would change the generated accessors, so left as-is.
    string errorMsg = 2;
}

message WarmupRequest {
/// Batch to warmup on
Batch batch = 1;
Expand Down Expand Up @@ -309,7 +324,7 @@ message DownloadAdapterResponse {

/// Fraction of the adapter memory limit consumed by the adapter.
/// If no limit is set, will return 0.
/// When the total across all loaded adapters exceeds
/// When the total across all loaded adapters exceeds
/// the adapter_memory_fraction limit, no more adapters
/// will be loaded to GPU and LoRAX will begin swapping.
float memory_fraction = 2;
Expand Down
8 changes: 8 additions & 0 deletions router/client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,14 @@ impl Client {
Ok(response)
}

/// Embed the given inputs on this shard, returning the raw
/// `EmbedResponse` produced by the model server.
#[instrument(skip(self))]
pub async fn embed(&mut self, inputs: String) -> Result<EmbedResponse> {
    // Propagate the tracing context to the server alongside the request.
    let request = tonic::Request::new(EmbedRequest { inputs }).inject_context();
    let response = self.stub.embed(request).await?.into_inner();
    Ok(response)
}

/// Get model health
#[instrument(skip(self))]
pub async fn health(&mut self) -> Result<HealthResponse> {
Expand Down
12 changes: 12 additions & 0 deletions router/client/src/sharded_client.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::pb::generate::v1::EmbedResponse;
/// Multi shard Client
use crate::{
AdapterParameters, Batch, CachedBatch, Client, DownloadAdapterResponse, Generation,
Expand Down Expand Up @@ -153,6 +154,17 @@ impl ShardedClient {
merge_generations(results?)
}

/// Embed the inputs on every shard, collecting one `EmbedResponse`
/// per shard. (Original comment said "Get the model info" — copy-paste
/// error from a neighboring method.)
#[instrument(skip(self))]
pub async fn embed(&mut self, inputs: String) -> Result<Vec<EmbedResponse>> {
    let futures: Vec<_> = self
        .clients
        .iter_mut()
        .map(|client| Box::pin(client.embed(inputs.clone())))
        .collect();
    // Await all shards, then `collect` over the Results short-circuits
    // to the first shard error, if any.
    join_all(futures).await.into_iter().collect()
}

pub async fn download_adapter(
&mut self,
adapter_parameters: AdapterParameters,
Expand Down
12 changes: 12 additions & 0 deletions router/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ pub struct Info {
pub docker_label: Option<&'static str>,
#[schema(nullable = true, example = "http://localhost:8899")]
pub request_logger_url: Option<String>,
#[schema(example = false)]
pub embedding_model: bool,
}

#[derive(Clone, Debug, Deserialize, ToSchema, Default)]
Expand Down Expand Up @@ -633,6 +635,16 @@ pub(crate) enum CompletionFinishReason {
ToolCalls,
}

/// Body of the router's /embed HTTP endpoint
#[derive(Clone, Debug, Deserialize, ToSchema)]
struct EmbedRequest {
    /// Text to embed
    inputs: String,
}

/// Response of the router's /embed HTTP endpoint
#[derive(Serialize, ToSchema)]
struct EmbedResponse {
    /// Embedding vector computed for the request's inputs
    embeddings: Vec<f32>,
}

impl From<CompletionRequest> for CompatGenerateRequest {
fn from(req: CompletionRequest) -> Self {
CompatGenerateRequest {
Expand Down
4 changes: 4 additions & 0 deletions router/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ struct Args {
#[clap(long, env)]
json_output: bool,
#[clap(long, env)]
embedding_model: bool,
#[clap(long, env)]
otlp_endpoint: Option<String>,
#[clap(long, env)]
cors_allow_origin: Option<Vec<String>>,
Expand Down Expand Up @@ -109,6 +111,7 @@ async fn main() -> Result<(), RouterError> {
revision,
validation_workers,
json_output,
embedding_model,
otlp_endpoint,
cors_allow_origin,
cors_allow_method,
Expand Down Expand Up @@ -372,6 +375,7 @@ async fn main() -> Result<(), RouterError> {
ngrok_authtoken,
ngrok_edge,
adapter_source,
embedding_model,
)
.await?;
Ok(())
Expand Down
Loading

0 comments on commit e37549e

Please sign in to comment.