Commit 78519ea

import

flozi00 committed May 22, 2024
1 parent 629986e
Showing 2 changed files with 4 additions and 2 deletions.
server/lorax_server/utils/layers.py (1 addition, 1 deletion)

@@ -10,7 +10,7 @@
 
 from lorax_server.adapters.types import LORA, MEDUSA
 from lorax_server.layers.linear import get_linear
-from lorax_server.layers.tensor_parallel import SuperLayer
+from lorax_server.layers.tensor_parallel import SuperLayer, TensorParallelColumnLinear, TensorParallelHead  # noqa: F401
 from lorax_server.utils.sgmv import (
     add_lora_a_bgmv,
     add_lora_b_bgmv,
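
The one-line change above re-exports TensorParallelColumnLinear and TensorParallelHead through lorax_server.utils.layers, and the `# noqa: F401` marker suppresses flake8's "imported but unused" warning, which an import kept purely for re-export would otherwise trigger. A minimal runnable sketch of the pattern (module and names are illustrative, not from the lorax codebase):

    # compat.py: keep names importable from here by re-exporting them.
    # flake8 flags module-level imports that are never used (F401);
    # the trailing noqa marker records that the "unused" import is intentional.
    from os.path import join, splitext  # noqa: F401

    # Downstream code can now write:
    #   from compat import join, splitext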
server/lorax_server/utils/paged_attention.py (3 additions, 1 deletion)

@@ -14,6 +14,8 @@
         f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
     )
 
+fp8_supported = torch.cuda.get_device_capability()[0] >= 9 or (torch.cuda.get_device_capability()[0] == 8 and torch.cuda.get_device_capability()[1] >= 9)
+
 
 def reshape_and_cache(
     key: torch.Tensor,
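
A note on the capability check added in this hunk: torch.cuda.get_device_capability() returns the GPU's compute capability as a (major, minor) tuple, so the expression enables FP8 on capability 9.x (Hopper) and 8.9 (Ada Lovelace). Because Python compares tuples lexicographically, a single comparison is equivalent; the following is a sketch under that reading (with an availability guard added so it runs on CPU-only machines), not the committed code:

    import torch

    # (9, 0) >= (8, 9) and (8, 9) >= (8, 9) are True; (8, 6) >= (8, 9) is False,
    # matching the major/minor logic in the committed expression.
    fp8_supported = (
        torch.cuda.is_available()
        and torch.cuda.get_device_capability() >= (8, 9)
    )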
@@ -28,7 +30,7 @@ def reshape_and_cache(
         )
     else:
         cache_ops.reshape_and_cache(
-            key, value, key_cache, value_cache, slots, "auto", 1.0
+            key, value, key_cache, value_cache, slots, "fp8" if fp8_supported else "auto", 1.0
         )
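
In the modified call, the string argument selects the KV-cache data type for vllm's reshape_and_cache kernel: "auto" keeps the cache in the model's own dtype, while "fp8" stores it as 8-bit floats on GPUs that support it; the trailing 1.0 appears to be the KV scaling factor. A sketch of the selection logic in isolation (the kernel call itself is shown as a comment, since it needs a CUDA build of vllm and real cache tensors):

    # Pick the cache dtype once, from the capability flag computed above.
    kv_cache_dtype = "fp8" if fp8_supported else "auto"

    # The committed call then reads:
    # cache_ops.reshape_and_cache(
    #     key, value, key_cache, value_cache, slots, kv_cache_dtype, 1.0
    # )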
