Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert to Triton Punica kernels #658

Merged
merged 77 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
77 commits
Select commit Hold shift + click to select a range
832d905
Collect timings
tgaddair Oct 23, 2024
697bf4d
Profiler
tgaddair Oct 23, 2024
d155163
Allow max batch prefill tokens < max input length
tgaddair Oct 23, 2024
ca3280c
Fix fallback
tgaddair Oct 23, 2024
830ce3d
Vectorize test
tgaddair Oct 23, 2024
7f250fe
Triton punica kernels
tgaddair Oct 24, 2024
e4fb765
Use triton punica
tgaddair Oct 24, 2024
634c8e2
Fix format
tgaddair Oct 24, 2024
7870729
Plumb weights
tgaddair Oct 24, 2024
0e057f0
Fixed issues
tgaddair Oct 24, 2024
c8ad4cb
Fixed cuda graphs
tgaddair Oct 24, 2024
a82eb64
Remove debug
tgaddair Oct 24, 2024
f68d2c0
Remove debug
tgaddair Oct 24, 2024
2ffc1db
Move init to warmup
tgaddair Oct 24, 2024
ea6c86d
Fix preloaded and speculators
tgaddair Oct 24, 2024
0497a76
Docker test
tgaddair Oct 24, 2024
9e2a29d
Profiling docs
tgaddair Oct 24, 2024
94e3742
Revert timings
tgaddair Oct 25, 2024
0abeccc
Fixed merge
tgaddair Oct 25, 2024
6f5a976
Added LORAX_SPECULATION_MAX_BATCH_SIZE
tgaddair Oct 26, 2024
f89ee87
Try separate trees per adapter
tgaddair Oct 27, 2024
23a77d2
Allow refcount==0
tgaddair Oct 27, 2024
22ed54d
Message
tgaddair Oct 28, 2024
327bb91
Docker test
tgaddair Oct 28, 2024
fbb2b3f
Cleanup
tgaddair Oct 28, 2024
f0693e9
Padding
tgaddair Oct 28, 2024
e62e0f8
Fixed turbo lora + compile
tgaddair Oct 28, 2024
66d8676
Fix
tgaddair Oct 28, 2024
55e5c41
Fix adapter root node id
tgaddair Oct 30, 2024
a6f3a17
More tests
tgaddair Oct 30, 2024
352c92a
Docker test
tgaddair Oct 30, 2024
1ea8d6e
Bump flashinfer
tgaddair Oct 30, 2024
c0640f2
Added logprobs fix
tgaddair Oct 31, 2024
54c36c9
Fix slots
tgaddair Oct 31, 2024
88cd932
No debugging
tgaddair Oct 31, 2024
3505b52
Docker test
tgaddair Oct 31, 2024
cf3d2d9
Fixed slot filtering
tgaddair Oct 31, 2024
d1ff7b4
Triton kernels
tgaddair Oct 31, 2024
57c33d7
Fix ragged
tgaddair Oct 31, 2024
ece47f7
More fixes
tgaddair Oct 31, 2024
779bff3
Merge
tgaddair Oct 31, 2024
cb99320
Revert docker
tgaddair Oct 31, 2024
466ea37
Renamed sgmv -> punica
tgaddair Oct 31, 2024
2f80c6a
Refactor PunicaWrapper
tgaddair Oct 31, 2024
47bfd0c
More configuration
tgaddair Oct 31, 2024
2343d78
More logs
tgaddair Oct 31, 2024
f915abe
Fixes
tgaddair Oct 31, 2024
ad460c0
Guard init
tgaddair Nov 1, 2024
43c129b
Guard model has lm_head
tgaddair Nov 1, 2024
1c70ec6
Determine trace set from preloaded adapter set
tgaddair Nov 1, 2024
3ebcbea
Plumb skip_lm_head
tgaddair Nov 1, 2024
922c5d6
Cleanup comments
tgaddair Nov 1, 2024
b2de54f
Fixed orient for rank
tgaddair Nov 1, 2024
35c7de2
Format
tgaddair Nov 1, 2024
295829f
Fixed tests
tgaddair Nov 1, 2024
ef86071
Fixed CausalLM and embedding model
tgaddair Nov 1, 2024
0d78a0a
Replace flume
tgaddair Nov 1, 2024
8cb79b2
Remove unused dep
tgaddair Nov 1, 2024
045a45a
Update axum
tgaddair Nov 1, 2024
20cf752
Client debug mode, fixed /
tgaddair Nov 1, 2024
2868acc
Docker test
tgaddair Nov 1, 2024
2131dc1
Fixed unused imports
tgaddair Nov 1, 2024
b727a94
Revert docker
tgaddair Nov 1, 2024
cc17d47
Add back tracing
tgaddair Nov 1, 2024
68991ba
Debug
tgaddair Nov 1, 2024
5380426
Docker test
tgaddair Nov 1, 2024
89abd51
Debug registration
tgaddair Nov 1, 2024
3c7b69b
Update tag
tgaddair Nov 1, 2024
d52f530
Don't skip filter
tgaddair Nov 4, 2024
45c6c53
Docker test
tgaddair Nov 4, 2024
3ad4d66
Remove register
tgaddair Nov 4, 2024
b45c219
Revert docker
tgaddair Nov 4, 2024
a4a2d5f
Fixed tests
tgaddair Nov 4, 2024
4a264bc
ruff
tgaddair Nov 4, 2024
e1067a0
Fix tests
tgaddair Nov 4, 2024
848b4c7
Clear cache
tgaddair Nov 4, 2024
107be9a
Check for key in lora weights
tgaddair Nov 5, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
660 changes: 471 additions & 189 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-31
RUN pip install einops --no-cache-dir

# Install flashinfer
RUN pip install --no-cache-dir flashinfer==0.1.5+cu124torch2.4 -i https://flashinfer.ai/whl/cu124/torch2.4
RUN pip install --no-cache-dir flashinfer==0.1.6 -i https://flashinfer.ai/whl/cu124/torch2.4

# Install server
COPY proto proto
Expand Down
18 changes: 17 additions & 1 deletion clients/python/lorax/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import logging
import requests
from requests.adapters import HTTPAdapter, Retry

Expand All @@ -20,7 +21,22 @@
from lorax.errors import parse_error
import os

LORAX_DEBUG_MODE = os.getenv("LORAD_DEBUG_MODE", None) is not None
LORAX_DEBUG_MODE = os.getenv("LORAX_DEBUG_MODE", None) is not None
if LORAX_DEBUG_MODE:
# https://stackoverflow.com/a/16630836/1869739
# These two lines enable debugging at httplib level (requests->urllib3->http.client)
# You will see the REQUEST, including HEADERS and DATA, and RESPONSE with HEADERS but without DATA.
# The only thing missing will be the response.body which is not logged.
import http.client as http_client
http_client.HTTPConnection.debuglevel = 1

# You must initialize logging, otherwise you'll not see debug output.
logging.basicConfig()
logging.getLogger().setLevel(logging.DEBUG)
requests_log = logging.getLogger("requests.packages.urllib3")
requests_log.setLevel(logging.DEBUG)
requests_log.propagate = True


class Client:
"""Client to make calls to a LoRAX instance
Expand Down
13 changes: 6 additions & 7 deletions docs/guides/contributing/development_env.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ We'll be working out of three different terminals during development, each servi
Install development dependencies:

```shell
DEBIAN_FRONTEND=noninteractive apt install pkg-config rsync tmux rust-gdb git -y
DEBIAN_FRONTEND=noninteractive apt install pkg-config rsync tmux rust-gdb git -y && \
PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
rm -f $PROTOC_ZIP
rm -f $PROTOC_ZIP && \
hash -r
```

Expand All @@ -71,8 +71,7 @@ tmux new -s server
From within the `tmux` session, move into the LoRAX `server` directory within the repo (assumed to be in `/data/lorax`) and install dependencies:

```shell
cd /data/lorax/server
pip install -e .
cd /data/lorax/server && pip install -e .
make gen-server
```

Expand All @@ -95,9 +94,9 @@ tmux new -s router
Now move into the `router` directory within the repo and install dependencies:

```shell
cd /data/lorax/router
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
export PATH=$PATH:$HOME/.cargo/bin
cd /data/lorax/router && \
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y && \
export PATH=$PATH:$HOME/.cargo/bin && \
touch ../proto/generate.proto
```

Expand Down
19 changes: 19 additions & 0 deletions docs/guides/contributing/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,22 @@ make export-requirements
```

Never modify `requirements.txt` directly, as it may introduce dependency conflicts.

## Profiling

LoRAX supports the [PyTorch Profiler](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html) to measure performance of LoRAX.

You can enable profiling when launching LoRAX by setting the `LORAX_PROFILER_DIR` environment variable to the directory
you wish to output the Tensorboard traces to.

Once initialized, LoRAX will begin recording traces for every request to the server. Because traces can get very large,
we record only the first 10 prefill requests (plus any decode requests between them), then stop recording and write
out the results. A summary will be printed to stdout when this occurs.

Once you have your traces written to the profiler directory, you can visualize them in Tensorboard using the
[PyTorch Profiler Tensorboard Plugin](https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html).

```bash
pip install torch_tb_profiler
tensorboard --logdir=$LORAX_PROFILER_DIR
```
1 change: 1 addition & 0 deletions launcher/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ clap = { version = "4.1.4", features = ["derive", "env"] }
ctrlc = { version = "3.2.5", features = ["termination"] }
nix = "0.26.2"
openssl = "0.10.66"
hf-hub = { version = "0.3.0", features = ["tokio"] }
h2 = "0.3.26"
rustix = "0.37.25"
serde = { version = "1.0.152", features = ["derive"] }
Expand Down
Loading
Loading