[Droid] DBRX Truss Implementation #283

Open · wants to merge 3 commits into base: main
1 change: 0 additions & 1 deletion .github/workflows/truss_deploy.yml
@@ -27,7 +27,6 @@ jobs:
      run: |
        python -m pip install --upgrade pip
        pip install git+https://github.com/basetenlabs/truss.git requests tenacity --upgrade

    - name: Run tests
      env:
        BASETEN_API_KEY: ${{ secrets.BASETEN_API_KEY }}
43 changes: 43 additions & 0 deletions dbrx_truss/README.md
@@ -0,0 +1,43 @@
# DBRX Truss

This truss makes the [DBRX](https://huggingface.co/databricks/dbrx-instruct) model available on the Baseten platform for efficient inference. DBRX is an open large language model trained by Databricks: a 132B-parameter mixture-of-experts model (roughly 36B parameters active on any given input) capable of instruction following and general language tasks.

## Setup

This truss requires Python 3.11 and the dependencies listed in `requirements.txt`. Because DBRX's weights occupy roughly 264 GB in 16-bit precision, the truss is configured to run on a multi-GPU instance (`A100:8`) rather than a single small GPU.
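Deployment should follow the standard Truss workflow; a minimal sketch, assuming the `truss` CLI is installed and a Baseten API key is available:

```sh
pip install --upgrade truss
cd dbrx_truss
truss push  # uploads the truss and deploys it on Baseten
```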

## Usage

Once deployed on Baseten, the truss exposes an endpoint for making prediction requests to the model.

### Request Format

Requests should be made with a JSON payload in the following format:

```json
{
"prompt": "What is machine learning?"
}
```
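For example, a deployed model can be invoked over HTTP. The sketch below uses hypothetical `YOUR_MODEL_ID` and `YOUR_API_KEY` placeholders and assumes Baseten's standard model-endpoint URL pattern:

```python
import requests

# Hypothetical placeholders: substitute your own model ID and API key.
resp = requests.post(
    "https://model-YOUR_MODEL_ID.api.baseten.co/production/predict",
    headers={"Authorization": "Api-Key YOUR_API_KEY"},
    json={"prompt": "What is machine learning?"},
)
print(resp.json()["result"])
```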

### Parameters

The following inference parameters can be set as top-level keys in `config.yaml` (they are read via `self.config.get` in `model.py`); a sample snippet follows the list:

- `max_new_tokens`: Maximum number of tokens to generate in the response (default: 100)
- `temperature`: Controls randomness of output (default: 0.7)
- `top_p`: Nucleus sampling probability threshold (default: 0.95)
- `top_k`: Number of highest probability vocabulary tokens to keep (default: 50)
- `repetition_penalty`: Penalty for repeated tokens (default: 1.01)
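For example, to make these settings explicit, the keys could be added to `config.yaml` as follows (a sketch mirroring how `model.py` reads them; the values shown are its built-in defaults):

```yaml
max_new_tokens: 100
temperature: 0.7
top_p: 0.95
top_k: 50
repetition_penalty: 1.01
```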

## Original Model

DBRX was developed and open-sourced by Databricks. For more information, see:

- [DBRX Model Card](https://github.com/databricks/dbrx/blob/master/MODEL_CARD_dbrx_instruct.md)
- [Databricks Blog Post](https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm)
- [HuggingFace Model Page](https://huggingface.co/databricks/dbrx-instruct)

## About Baseten

This truss was created by [Baseten](https://www.baseten.co/) to enable easy deployment and serving of the open-source DBRX model at scale. Baseten is a platform for building powerful AI apps.
1 change: 1 addition & 0 deletions dbrx_truss/__init__.py
@@ -0,0 +1 @@
# Empty file
13 changes: 13 additions & 0 deletions dbrx_truss/config.yaml
@@ -0,0 +1,13 @@
python_version: py311
requirements_file: requirements.txt

resources:
  accelerator: A100:8
  use_gpu: true

model_metadata:
  example_model_input: |
    {
      "prompt": "What is machine learning?"
    }
  repo_id: databricks/dbrx-instruct
1 change: 1 addition & 0 deletions dbrx_truss/model/__init__.py
@@ -0,0 +1 @@
# Empty file
60 changes: 60 additions & 0 deletions dbrx_truss/model/model.py
@@ -0,0 +1,60 @@
from typing import Dict

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class Model:
    def __init__(self, data_dir: str, config: Dict, **kwargs):
        self.data_dir = data_dir
        self.config = config
        self.cuda_available = torch.cuda.is_available()
        self.tokenizer = None
        self.model = None

    def load(self):
        self.tokenizer = AutoTokenizer.from_pretrained(
            "databricks/dbrx-instruct", trust_remote_code=True, token=True
        )

        if self.cuda_available:
            # Use flash attention 2 only when the flash_attn package is
            # actually installed; otherwise fall back to eager attention.
            try:
                import flash_attn  # noqa: F401

                attn_implementation = "flash_attention_2"
            except ImportError:
                attn_implementation = "eager"

            self.model = AutoModelForCausalLM.from_pretrained(
                "databricks/dbrx-instruct",
                trust_remote_code=True,
                token=True,
                torch_dtype=(
                    torch.bfloat16
                    if torch.cuda.is_bf16_supported()
                    else torch.float16
                ),
                device_map="auto",
                attn_implementation=attn_implementation,
            )
        else:
            self.model = AutoModelForCausalLM.from_pretrained(
                "databricks/dbrx-instruct", trust_remote_code=True, token=True
            )

    def predict(self, request: Dict) -> Dict:
        # `load()` is invoked once by the serving framework at startup;
        # reloading the 132B-parameter model per request would be prohibitive.
        prompt = request["prompt"]
        messages = [{"role": "user", "content": prompt}]

        tokenized_input = self.tokenizer.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
        )
        tokenized_input = tokenized_input.to(self.model.device)

        generated = self.model.generate(
            input_ids=tokenized_input,
            do_sample=True,  # required for temperature/top_p/top_k to take effect
            max_new_tokens=self.config.get("max_new_tokens", 100),
            temperature=self.config.get("temperature", 0.7),
            top_p=self.config.get("top_p", 0.95),
            top_k=self.config.get("top_k", 50),
            repetition_penalty=self.config.get("repetition_penalty", 1.01),
            # Fall back to EOS as the pad token if the tokenizer defines none.
            pad_token_id=self.tokenizer.pad_token_id or self.tokenizer.eos_token_id,
        )

        # Decode only the newly generated tokens, skipping the echoed prompt.
        response_text = self.tokenizer.decode(
            generated[0][tokenized_input.shape[-1]:], skip_special_tokens=True
        )

        return {"result": response_text}
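For local verification of the prediction flow before deploying, a minimal smoke test might look like the following sketch (hypothetical usage; it assumes a machine with enough GPU memory for the full weights and passes an empty config so the defaults in `model.py` apply):

```python
from dbrx_truss.model.model import Model

# Empty config: each generation parameter falls back to the default
# hard-coded in model.py (e.g., max_new_tokens=100).
model = Model(data_dir=".", config={})
model.load()
print(model.predict({"prompt": "What is machine learning?"})["result"])
```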
4 changes: 4 additions & 0 deletions dbrx_truss/requirements.txt
@@ -0,0 +1,4 @@
torch>=2.1.0
transformers>=4.39.2
accelerate==0.28.0
tiktoken>=0.6.0