Skip to content

Commit

Permalink
Add predibase as a source for adapters (#125)
Browse files Browse the repository at this point in the history
  • Loading branch information
magdyksaleh authored Dec 14, 2023
1 parent 3ad7aae commit 143f303
Show file tree
Hide file tree
Showing 15 changed files with 103 additions and 49 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@
target
router/tokenizer.json
*__pycache__*
run.sh
run.sh
data/
36 changes: 1 addition & 35 deletions build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ COMMIT_SHA=$(git rev-parse --short HEAD)
TAG="${COMMIT_SHA}${DIRTY}"

# Name of the Docker image
IMAGE_NAME="kubellm"
IMAGE_NAME="lorax"

# ECR Repository URL (replace with your actual ECR repository URL)
ECR_REPO="474375891613.dkr.ecr.us-west-2.amazonaws.com"
Expand All @@ -28,37 +28,3 @@ echo "Building ${IMAGE_NAME}:${TAG}"
docker build -t ${IMAGE_NAME}:${TAG} .
docker tag ${IMAGE_NAME}:${TAG} ${IMAGE_NAME}:latest

# Tag the Docker image for ECR repository
docker tag ${IMAGE_NAME}:${TAG} ${ECR_REPO}/${IMAGE_NAME}:${TAG}

# Log in to the ECR registry (assumes AWS CLI and permissions are set up)
aws ecr get-login-password --region us-west-2 | docker login --username AWS --password-stdin ${ECR_REPO}

# Push to ECR
docker push ${ECR_REPO}/${IMAGE_NAME}:${TAG}


latest_flag=false

# Parse command line arguments
while [[ $# -gt 0 ]]; do
case "$1" in
--latest)
latest_flag=true
shift
;;
*)
echo "Unknown option: $1"
exit 1
;;
esac
done

# Check if the --latest flag has been passed
if $latest_flag; then
# Tag and push as 'latest'
docker tag ${IMAGE_NAME}:${TAG} ${ECR_REPO}/${IMAGE_NAME}:latest
docker push ${ECR_REPO}/${IMAGE_NAME}:latest
else
echo "The --latest flag has not been passed. Skipping push to ECR as latest."
fi
4 changes: 4 additions & 0 deletions docs/reference/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,10 @@
"adapter_source": {
"type": "string",
"nullable": true
},
"api_token": {
"type": "string",
"nullable": true
}
}
},
Expand Down
3 changes: 2 additions & 1 deletion launcher/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,10 @@ struct Args {
source: String,

/// The source of the model to load.
/// Can be `hub` or `s3`.
/// Can be `hub`, `s3`, or `pbase`.
/// `hub` will load the model from the huggingface hub.
/// `s3` will load the model from the predibase S3 bucket.
/// `pbase` will load the model from S3, but resolve its metadata from a Predibase server.
#[clap(default_value = "hub", long, env)]
adapter_source: String,

Expand Down
8 changes: 7 additions & 1 deletion proto/generate.proto
Original file line number Diff line number Diff line change
Expand Up @@ -224,13 +224,17 @@ enum AdapterSource {
S3 = 1;
/// Adapters loaded via local filesystem path
LOCAL = 2;
/// Adapters loaded via predibase
PBASE = 3;
}

/// Request payload for downloading an adapter
message DownloadAdapterRequest {
/// Adapter ID
string adapter_id = 1;
/// Adapter source
AdapterSource adapter_source = 2;
/// Optional token for the external API (Predibase / HuggingFace)
optional string api_token = 3;
}

message DownloadAdapterResponse {
Expand All @@ -247,6 +251,8 @@ message LoadAdapterRequest {
AdapterSource adapter_source = 2;
/// Adapter index
uint32 adapter_index = 3;
/// Token for external API (predibase / HuggingFace)
optional string api_token = 4;
}

message LoadAdapterResponse {
Expand Down Expand Up @@ -274,4 +280,4 @@ message OffloadAdapterResponse {
AdapterSource adapter_source = 2;
/// Adapter index
uint32 adapter_index = 3;
}
}
4 changes: 4 additions & 0 deletions router/client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,13 +183,15 @@ impl Client {
&mut self,
adapter_id: String,
adapter_source: String,
api_token: Option<String>,
) -> Result<String> {
if let Some(adapter_source_enum) =
AdapterSource::from_str_name(adapter_source.to_uppercase().as_str())
{
let request = tonic::Request::new(DownloadAdapterRequest {
adapter_id,
adapter_source: adapter_source_enum.into(),
api_token: api_token,
})
.inject_context();
let response = self.stub.download_adapter(request).await?.into_inner();
Expand All @@ -210,6 +212,7 @@ impl Client {
adapter_id: String,
adapter_source: String,
adapter_index: u32,
api_token: Option<String>,
) -> Result<String> {
if let Some(adapter_source_enum) =
AdapterSource::from_str_name(adapter_source.to_uppercase().as_str())
Expand All @@ -218,6 +221,7 @@ impl Client {
adapter_id,
adapter_source: adapter_source_enum.into(),
adapter_index,
api_token: api_token,
})
.inject_context();
let response = self.stub.load_adapter(request).await?.into_inner();
Expand Down
5 changes: 4 additions & 1 deletion router/client/src/sharded_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,11 @@ impl ShardedClient {
&mut self,
adapter_id: String,
adapter_source: String,
api_token: Option<String>,
) -> Result<String> {
// Only download the adapter with one client, since they share a single disk
self.clients[0]
.download_adapter(adapter_id, adapter_source)
.download_adapter(adapter_id, adapter_source, api_token)
.await
}

Expand All @@ -163,6 +164,7 @@ impl ShardedClient {
adapter_id: String,
adapter_source: String,
adapter_index: u32,
api_token: Option<String>,
) -> Result<String> {
// Load the adapter in all clients since there is sharding done between them
let futures: Vec<_> = self
Expand All @@ -173,6 +175,7 @@ impl ShardedClient {
adapter_id.clone(),
adapter_source.clone(),
adapter_index,
api_token.clone(),
))
})
.collect();
Expand Down
15 changes: 13 additions & 2 deletions router/src/adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,18 @@ pub(crate) struct Adapter {
source: String,
/// index of the adapter
index: u32,
/// Optional token for an external API
api_token: Option<String>,
}

impl Adapter {
pub(crate) fn new(id: String, source: String, index: u32) -> Self {
Self { id, source, index }
pub(crate) fn new(id: String, source: String, index: u32, api_token: Option<String>) -> Self {
Self {
id,
source,
index,
api_token,
}
}

pub(crate) fn id(&self) -> &str {
Expand All @@ -32,6 +39,10 @@ impl Adapter {
&self.source
}

/// Returns the optional external API token stored on this adapter.
pub(crate) fn api_token(&self) -> &Option<String> {
    &self.api_token
}

/// Returns the index of the adapter.
pub(crate) fn index(&self) -> u32 {
self.index
}
Expand Down
8 changes: 7 additions & 1 deletion router/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,13 @@ impl Infer {
}
}

let adapter = Adapter::new(adapter_id.unwrap(), adapter_source.unwrap(), adapter_idx);
let api_token = request.parameters.api_token.clone();
let adapter = Adapter::new(
adapter_id.unwrap(),
adapter_source.unwrap(),
adapter_idx,
api_token,
);

// Validate request
let valid_request = self
Expand Down
4 changes: 4 additions & 0 deletions router/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ pub(crate) struct GenerateParameters {
#[schema(nullable = true, default = "null", example = "hub")]
pub adapter_source: Option<String>,
#[serde(default)]
#[schema(nullable = true, default = "null", example = "<token from predibase>")]
pub api_token: Option<String>,
#[serde(default)]
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 1)]
pub best_of: Option<usize>,
#[serde(default)]
Expand Down Expand Up @@ -159,6 +162,7 @@ fn default_parameters() -> GenerateParameters {
GenerateParameters {
adapter_id: None,
adapter_source: None,
api_token: None,
best_of: None,
temperature: None,
repetition_penalty: None,
Expand Down
7 changes: 6 additions & 1 deletion router/src/loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,11 @@ async fn loader_task(mut client: ShardedClient, receiver: flume::Receiver<Adapte
}

match client
.download_adapter(adapter.id().to_string(), adapter.source().to_string())
.download_adapter(
adapter.id().to_string(),
adapter.source().to_string(),
adapter.api_token().clone(),
)
.await
{
Ok(_) => {
Expand Down Expand Up @@ -185,6 +189,7 @@ async fn loader_task(mut client: ShardedClient, receiver: flume::Receiver<Adapte
adapter.id().to_string(),
adapter.source().to_string(),
adapter.index(),
adapter.api_token().clone(),
)
.await
{
Expand Down
8 changes: 4 additions & 4 deletions router/src/validation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ mod tests {
..default_parameters()
},
},
Adapter::new("".to_string(), "hf".to_string(), 0),
Adapter::new("".to_string(), "hf".to_string(), 0, None),
)
.await
{
Expand Down Expand Up @@ -508,7 +508,7 @@ mod tests {
..default_parameters()
},
},
Adapter::new("".to_string(), "hf".to_string(), 0),
Adapter::new("".to_string(), "hf".to_string(), 0, None),
)
.await
{
Expand All @@ -526,7 +526,7 @@ mod tests {
..default_parameters()
},
},
Adapter::new("".to_string(), "hf".to_string(), 0),
Adapter::new("".to_string(), "hf".to_string(), 0, None),
)
.await
{
Expand All @@ -544,7 +544,7 @@ mod tests {
..default_parameters()
},
},
Adapter::new("".to_string(), "hf".to_string(), 0),
Adapter::new("".to_string(), "hf".to_string(), 0, None),
)
.await
.unwrap();
Expand Down
10 changes: 9 additions & 1 deletion server/lorax_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from lorax_server.models import Model, get_model
from lorax_server.pb import generate_pb2_grpc, generate_pb2
from lorax_server.tracing import UDSOpenTelemetryAioServerInterceptor
from lorax_server.utils import HUB, LOCAL, S3, get_config_path, get_local_dir
from lorax_server.utils import HUB, LOCAL, S3, PBASE, get_config_path, get_local_dir, map_pbase_model_id_to_s3
from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID


Expand Down Expand Up @@ -127,6 +127,9 @@ async def DownloadAdapter(self, request, context):
)

adapter_source = _adapter_source_enum_to_string(request.adapter_source)
if adapter_source == PBASE:
adapter_id = map_pbase_model_id_to_s3(adapter_id, request.api_token)
adapter_source = S3
try:
# fail fast if ID is not an adapter (i.e. it is a full model)
# TODO(geoffrey): do this for S3– can't do it this way because the
Expand Down Expand Up @@ -159,6 +162,9 @@ async def LoadAdapter(self, request, context):
adapter_id = request.adapter_id
adapter_source = _adapter_source_enum_to_string(request.adapter_source)
adapter_index = request.adapter_index
if adapter_source == PBASE:
adapter_id = map_pbase_model_id_to_s3(adapter_id, request.api_token)
adapter_source = S3
self.model.load_adapter(adapter_id, adapter_source, adapter_index)

return generate_pb2.LoadAdapterResponse(
Expand Down Expand Up @@ -281,5 +287,7 @@ def _adapter_source_enum_to_string(adapter_source: int) -> str:
return S3
elif adapter_source == generate_pb2.AdapterSource.LOCAL:
return LOCAL
elif adapter_source == generate_pb2.AdapterSource.PBASE:
return PBASE
else:
raise ValueError(f"Unknown adapter source {adapter_source}")
6 changes: 5 additions & 1 deletion server/lorax_server/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
get_config_path,
get_local_dir,
download_weights,
map_pbase_model_id_to_s3,
weight_hub_files,
weight_files,
EntryNotFoundError,
HUB,
PBASE,
LOCAL,
LocalEntryNotFoundError,
RevisionNotFoundError,
Expand All @@ -41,17 +43,19 @@
"get_local_dir",
"get_start_stop_idxs_for_rank",
"initialize_torch_distributed",
"map_pbase_model_id_to_s3",
"download_weights",
"weight_hub_files",
"EntryNotFoundError",
"HeterogeneousNextTokenChooser",
"HUB",
"LOCAL",
"PBASE",
"S3",
"LocalEntryNotFoundError",
"RevisionNotFoundError",
"Greedy",
"NextTokenChooser",
"S3",
"Sampling",
"StoppingCriteria",
"StopSequenceCriteria",
Expand Down
Loading

0 comments on commit 143f303

Please sign in to comment.