
JetStream changes for JAX-based implementation of unified_lora_params for decoding batches of multiple different LoRA adapters. #222


Open
wants to merge 26 commits into main
Conversation

aman2930
Collaborator

@aman2930 aman2930 commented Mar 7, 2025

JetStream changes for JAX-based implementation of unified_lora_params for decoding batches of multiple different LoRA adapters.

…tAdapters, LoadAdapter and UnloadAdapter. 2) The Driver, which holds the list of all loaded base parameters, now also stores the LoRA-updated parameters for each loaded LoRA adapter. Implemented methods for loading, unloading, and listing LoRA adapters on the Driver object. The original base model params are kept intact and saved into the params dictionary with key . 3) Created a proxy client to make MultiAdapterManager service requests to the JetStream server.
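The Driver-side bookkeeping described above can be sketched roughly as follows. All class and method names here are illustrative, not the actual JetStream code; the point is only that base params stay intact while per-adapter params live alongside them:

```python
class AdapterRegistry:
    """Minimal sketch of Driver-side adapter bookkeeping (hypothetical names)."""

    def __init__(self, base_params):
        # Original base-model parameters are kept intact under a reserved key.
        self.params = {"base": base_params}

    def load_adapter(self, adapter_id, lora_params):
        # Store the LoRA-updated parameters for this adapter.
        self.params[adapter_id] = lora_params

    def unload_adapter(self, adapter_id):
        # Never evict the base parameters.
        if adapter_id != "base":
            self.params.pop(adapter_id, None)

    def list_adapters(self):
        # Report loaded adapters, excluding the base entry.
        return [k for k in self.params if k != "base"]
```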
…pters. Its functionality includes loading and unloading adapters between CPU RAM and HBM. It also follows an LRU policy to evict an adapter when a new load_adapter request arrives. Currently it stores each adapter as separate tensors (lora_a and lora_b); the lora_b x lora_a product is computed in prefill() and generate() during a decode request. The adapter_tensorstore can be configured with a maximum limit on HBM and RAM usage.
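The two behaviors described above, LRU eviction of adapters and forming the lora_b x lora_a product only at decode time, can be sketched as follows. Plain NumPy stands in for jax.numpy (the matmul API is the same); the class, capacities, and tensor shapes are illustrative, not the real adapter_tensorstore:

```python
from collections import OrderedDict

import numpy as np


def lora_delta(lora_a, lora_b):
    # lora_a: (r, d_in), lora_b: (d_out, r) -- low-rank factors kept as
    # separate tensors; the full-rank delta is only formed when needed.
    return lora_b @ lora_a


class LRUTensorStore:
    """Toy LRU cache standing in for the HBM tier of adapter_tensorstore."""

    def __init__(self, max_adapters):
        self.max_adapters = max_adapters
        self.hbm = OrderedDict()  # adapter_id -> (lora_a, lora_b)

    def load_adapter(self, adapter_id, lora_a, lora_b):
        if adapter_id in self.hbm:
            self.hbm.move_to_end(adapter_id)  # mark most recently used
            return
        if len(self.hbm) >= self.max_adapters:
            self.hbm.popitem(last=False)  # evict least recently used
        self.hbm[adapter_id] = (lora_a, lora_b)

    def get_delta(self, adapter_id):
        # Called from prefill()/generate() in the sketch's mental model.
        lora_a, lora_b = self.hbm[adapter_id]
        self.hbm.move_to_end(adapter_id)
        return lora_delta(lora_a, lora_b)
```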

2) Added functionality to load from a catalog file at server start. If no file is given, only the base params are loaded. Loading from the catalog file happens in CPU RAM; after that, based on incoming requests, those params are moved to or evicted from HBM.
3) Some proto updates so that only a single path is passed for each adapter; that path is expected to contain an adapter_config.json and Orbax-format weights in a 0/items folder.
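A minimal sketch of this startup path, assuming a simple JSON catalog mapping adapter ids to their single path. The catalog format shown here is an assumption for illustration; only the per-adapter layout (adapter_config.json plus Orbax weights under 0/items) comes from the PR description, and actual weight loading is elided:

```python
import json
import os


def load_catalog(catalog_path):
    """Resolve adapter metadata into CPU RAM at server start (sketch).

    Assumed catalog shape: {"adapters": {"<adapter_id>": "<path>"}}.
    """
    if catalog_path is None:
        return {}  # no catalog file: serve only the base params
    with open(catalog_path) as f:
        catalog = json.load(f)
    adapters = {}
    for adapter_id, path in catalog.get("adapters", {}).items():
        adapters[adapter_id] = {
            # Each adapter path holds a config and Orbax weights in 0/items.
            "config": os.path.join(path, "adapter_config.json"),
            "weights": os.path.join(path, "0", "items"),
        }
    return adapters
```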
…n API (https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/main/docs/proposals/003-model-server-protocol/README.md#inference-api-protocol),  & .

2) Added a flag to explicitly run the JetStream server with these APIs when . Otherwise only the older Decode() and HealthCheck() APIs of the JetStream server are exposed.
3) Fixed a bug in the adapter_tensorstore when converting between jnp arrays and np arrays.
4) Added a  which makes requests to the new APIs (v1/load_lora_adapter, v1/unload_lora_adapter, v1/models, v1/completions).
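For illustration, a request to one of the new endpoints could be built as below. The endpoint path comes from the list above, but the payload field names ("lora_name"/"lora_path") are an assumption borrowed from the vLLM-style dynamic LoRA convention in the linked protocol, not confirmed from this PR:

```python
import json
import urllib.request


def build_load_adapter_request(base_url, adapter_id, adapter_path):
    # Hypothetical payload; field names are assumptions, not from this PR.
    body = json.dumps({"lora_name": adapter_id,
                       "lora_path": adapter_path}).encode("utf-8")
    return urllib.request.Request(
        base_url.rstrip("/") + "/v1/load_lora_adapter",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )
```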
1) kv_cache_utilization: the percentage of memory in the allocated KV cache on TPU HBM that is actually used during decode, based on the percentage of slots in use.
2) num_requests_waiting: total number of requests waiting to be decoded.
3) lora_requests_info: list of LoRA adapters currently loaded into TPU HBM for serving requests.
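The first and third metrics reduce to simple bookkeeping; a sketch with illustrative names (the exact reporting format in the PR may differ):

```python
def kv_cache_utilization(slots_used, total_slots):
    """Percent of the allocated KV cache actually in use, based on slots."""
    if total_slots == 0:
        return 0.0
    return 100.0 * slots_used / total_slots


def lora_requests_info(loaded_adapter_ids):
    # Comma-separated list of adapters currently resident in HBM
    # (the separator and ordering here are assumptions).
    return ",".join(sorted(loaded_adapter_ids))
```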
2) Fixed model_ckpt_conversion.sh after refactoring and merging from main.
… for decoding batch of multiple different lora adapters.
@aman2930 aman2930 requested a review from vipannalla as a code owner March 7, 2025 16:53
@aman2930 aman2930 requested a review from yixinshi March 7, 2025 16:58
Base automatically changed from amangu-lora to main April 14, 2025 18:58