Skip to content

Commit

Permalink
Lorax NER (#531)
Browse files Browse the repository at this point in the history
  • Loading branch information
magdyksaleh authored Jul 9, 2024
1 parent 24cb494 commit a3ad209
Show file tree
Hide file tree
Showing 19 changed files with 723 additions and 50 deletions.
9 changes: 4 additions & 5 deletions .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ jobs:
uses: lerentis/[email protected]
with:
soci-release: 'v0.4.0'

- name: Set up Docker Buildx
uses: docker/[email protected]

Expand All @@ -51,7 +51,7 @@ jobs:
with:
config-inline: |
version = 2
# persistent data location
root = "/runner/build/containerd"
Expand All @@ -66,7 +66,7 @@ jobs:
type=semver,pattern={{major}}.{{minor}}
type=sha,prefix=,suffix=,format=short
type=raw,value=main,enable=${{ github.ref == 'refs/heads/main' }}
- name: Create a hash from tags
env:
tags: ${{ steps.meta.outputs.tags }}
Expand Down Expand Up @@ -124,7 +124,7 @@ jobs:
echo "Pushing $tag to GHCR"
sudo ctr i push --user "${{ github.repository_owner }}:${{ secrets.GHCR_PAT }}" $tag
done
- name: Create and push soci index
env:
tags: ${{ steps.meta.outputs.tags }}
Expand All @@ -151,4 +151,3 @@ jobs:
# Delete the SHA image(s) from containerd store
sudo ctr i rm $(sudo ctr i ls -q)
59 changes: 57 additions & 2 deletions clients/python/lorax/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
Parameters,
MergedAdapters,
ResponseFormat,
EmbedResponse
EmbedResponse,
ClassifyResponse
)
from lorax.errors import parse_error

Expand Down Expand Up @@ -65,6 +66,7 @@ def __init__(
"""
self.base_url = base_url
self.embed_endpoint = f"{base_url}/embed"
self.classify_endpoint = f"{base_url}/classify"
self.headers = headers
self.cookies = cookies
self.timeout = timeout
Expand Down Expand Up @@ -441,6 +443,35 @@ def embed(self, inputs: str) -> EmbedResponse:
return EmbedResponse(**payload)


def classify(self, inputs: str) -> ClassifyResponse:
    """
    Given inputs, run token classification on the text using the model.

    Args:
        inputs (`str`):
            Input text

    Returns:
        ClassifyResponse: Entities found in the input text

    Raises:
        Exception from `parse_error` when the server responds with a
        non-200 status code.
    """
    request = Request(inputs=inputs)

    resp = requests.post(
        self.classify_endpoint,
        json=request.dict(by_alias=True),
        headers=self.headers,
        cookies=self.cookies,
        timeout=self.timeout,
    )

    # Parse the body once and reuse it for both the error path and the
    # success path (previously the error path re-parsed resp.json()).
    payload = resp.json()
    if resp.status_code != 200:
        raise parse_error(resp.status_code, payload)

    return ClassifyResponse(**payload)


class AsyncClient:
"""Asynchronous Client to make calls to a LoRAX instance
Expand Down Expand Up @@ -483,6 +514,7 @@ def __init__(
"""
self.base_url = base_url
self.embed_endpoint = f"{base_url}/embed"
self.classify_endpoint = f"{base_url}/classify"
self.headers = headers
self.cookies = cookies
self.timeout = ClientTimeout(timeout * 60)
Expand Down Expand Up @@ -779,4 +811,27 @@ async def embed(self, inputs: str) -> EmbedResponse:

if resp.status != 200:
raise parse_error(resp.status, payload)
return EmbedResponse(**payload)
return EmbedResponse(**payload)


async def classify(self, inputs: str) -> ClassifyResponse:
    """
    Given inputs, run token classification on the text using the model.

    Args:
        inputs (`str`):
            Input text

    Returns:
        ClassifyResponse: Entities found in the input text
    """
    body = Request(inputs=inputs).dict(by_alias=True)
    async with ClientSession(
        headers=self.headers, cookies=self.cookies, timeout=self.timeout
    ) as session:
        async with session.post(self.classify_endpoint, json=body) as resp:
            payload = await resp.json()
            if resp.status != 200:
                raise parse_error(resp.status, payload)
            return ClassifyResponse(**payload)
4 changes: 4 additions & 0 deletions clients/python/lorax/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,3 +343,7 @@ class DeployedModel(BaseModel):
class EmbedResponse(BaseModel):
    # Embedding vector returned by the /embed endpoint; None if absent.
    embeddings: Optional[List[float]]

class ClassifyResponse(BaseModel):
    # Token-classification results from the /classify endpoint: one dict
    # per detected entity (schema defined by the server); None if absent.
    entities: Optional[List[dict]]
30 changes: 30 additions & 0 deletions proto/generate.proto
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ service LoraxService {
rpc Prefill (PrefillRequest) returns (PrefillResponse);
/// Embed
rpc Embed (EmbedRequest) returns (EmbedResponse);
/// Classify
rpc Classify (ClassifyRequest) returns (ClassifyResponse);
/// Decode token for a list of prefilled batches
rpc Decode (DecodeRequest) returns (DecodeResponse);
/// Health check
Expand Down Expand Up @@ -253,6 +255,34 @@ message EmbedResponse {
string errorMsg = 2;
}

message Entity {
    /// Predicted label for this token
    string entity = 1;
    /// Confidence score of the prediction
    float score = 2;
    /// Token index within the input
    uint32 index = 3;
    /// Surface text of the token
    string word = 4;
    /// Start offset of the token in the input text (assumed character offset — confirm against server)
    uint32 start = 5;
    /// End offset of the token in the input text
    uint32 end = 6;
}

message EntityList {
    /// Request ID this entity list belongs to
    uint64 request_id = 1;
    /// Entities detected for this request
    repeated Entity entities = 2;
}

message ClassifyRequest {
    /// Batch of requests to classify
    Batch batch = 1;
}

message ClassifyResponse {
    /// Per-request classification results
    repeated EntityList entity_lists = 1;
    /// Error message on failure (field name matches EmbedResponse.errorMsg for consistency)
    string errorMsg = 2;
}

message WarmupRequest {
/// Batch to warmup on
Batch batch = 1;
Expand Down
9 changes: 9 additions & 0 deletions router/client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::Result;
use grpc_metadata::InjectTelemetryContext;
use std::cmp::min;
use tonic::transport::{Channel, Uri};
use tonic::Response;
use tracing::instrument;

/// LoRAX gRPC client
Expand Down Expand Up @@ -196,6 +197,14 @@ impl Client {
Ok(response.embeddings)
}

/// Run token classification on `batch`, returning one entity list per request.
#[instrument(skip(self))]
pub async fn classify(&mut self, batch: Batch) -> Result<Vec<EntityList>> {
    // Wrap the batch in a gRPC request and propagate the telemetry context.
    let request = tonic::Request::new(ClassifyRequest { batch: Some(batch) }).inject_context();
    let reply = self.stub.classify(request).await?;
    Ok(reply.into_inner().entity_lists)
}

/// Downloads the weights for an adapter.
pub async fn download_adapter(
&mut self,
Expand Down
2 changes: 1 addition & 1 deletion router/client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub use pb::generate::v1::HealthResponse;
pub use pb::generate::v1::InfoResponse as ShardInfo;
pub use pb::generate::v1::{
AdapterParameters, AlternativeTokens, Batch, CachedBatch, DownloadAdapterResponse, Embedding,
FinishReason, GeneratedText, Generation, MajoritySignMethod, MergeStrategy,
Entity, EntityList, FinishReason, GeneratedText, Generation, MajoritySignMethod, MergeStrategy,
NextTokenChooserParameters, NextTokens, PrefillTokens, Request, StoppingCriteriaParameters,
};
pub use sharded_client::ShardedClient;
Expand Down
14 changes: 13 additions & 1 deletion router/client/src/sharded_client.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::pb::generate::v1::{EmbedResponse, Embedding};
use crate::pb::generate::v1::{EmbedResponse, Embedding, EntityList};
/// Multi shard Client
use crate::{
AdapterParameters, Batch, CachedBatch, Client, DownloadAdapterResponse, Generation,
Expand Down Expand Up @@ -166,6 +166,18 @@ impl ShardedClient {
Ok(results?.into_iter().flatten().collect())
}

/// Classify `batch` on every shard concurrently and merge the results.
#[instrument(skip(self))]
pub async fn classify(&mut self, batch: Batch) -> Result<Vec<EntityList>> {
    // Fan the batch out to all shard clients at once.
    let requests: Vec<_> = self
        .clients
        .iter_mut()
        .map(|client| Box::pin(client.classify(batch.clone())))
        .collect();
    // Short-circuit on the first shard error; otherwise flatten per-shard lists.
    let per_shard: Vec<Vec<EntityList>> =
        join_all(requests).await.into_iter().collect::<Result<_>>()?;
    Ok(per_shard.into_iter().flatten().collect())
}

pub async fn download_adapter(
&mut self,
adapter_parameters: AdapterParameters,
Expand Down
Loading

0 comments on commit a3ad209

Please sign in to comment.