WIP

flozi00 committed Dec 18, 2023
1 parent af59e54 commit 6226aaf
Showing 52 changed files with 1,796 additions and 914 deletions.
1 change: 0 additions & 1 deletion clients/python/lorax/client.py
@@ -498,7 +498,6 @@ async def generate_stream(
headers=self.headers, cookies=self.cookies, timeout=self.timeout
) as session:
async with session.post(self.base_url, json=request.dict()) as resp:

if resp.status != 200:
raise parse_error(resp.status, await resp.json())

4 changes: 1 addition & 3 deletions clients/python/tests/conftest.py
@@ -46,6 +46,4 @@ def unsupported_url(base_url, unsupported_model):

@pytest.fixture(scope="session")
def hf_headers():
return build_hf_headers(
library_name="lorax-tests", library_version=__version__
)
return build_hf_headers(library_name="lorax-tests", library_version=__version__)
29 changes: 15 additions & 14 deletions integration-tests/scripts/dynamic_adapter_loading.py
@@ -54,8 +54,8 @@ def query_lorax(args):
if adapter_id is not None:
# request_params["adapter_source"] = "local"
request_params["adapter_id"] = adapter_id
print("request_params", request_params)

print("request_params", request_params)
url = "http://localhost:8080/generate"
headers = {
"Content-Type": "application/json",
@@ -67,7 +67,7 @@
},
).encode("utf-8")
request = Request(url, headers=headers, data=data)

try:
with urlopen(request) as response:
response_body = json.loads(response.read().decode("utf-8"))
@@ -78,19 +78,22 @@
print(f"exception in request: {adapter_id}")
return adapter_id, 0, None

print("adapter_id: {}\nCompleted {} in {:3f} seconds ({:3f} tokens / s)\n----".format(
adapter_id,
ntokens,
duration_s,
(ntokens / duration_s),
))
print(
"adapter_id: {}\nCompleted {} in {:3f} seconds ({:3f} tokens / s)\n----".format(
adapter_id,
ntokens,
duration_s,
(ntokens / duration_s),
)
)
return adapter_id, ntokens, duration_s, response_body["generated_text"]


def get_local_path(model_id):
model_id = model_id.replace("/", "--")
return f"/data/models--{model_id}/snapshots/834b33af35ff5965ea3e4bc18b51ad5d65da7466"

return (
f"/data/models--{model_id}/snapshots/834b33af35ff5965ea3e4bc18b51ad5d65da7466"
)


def main():
@@ -145,7 +148,6 @@ def main():
# # # "hessertaboada/ludwig-webinar",
# # # "AmlanSamanta/ludwig-webinar",


# # # None,

# # # # download error: bad adapter name
@@ -197,6 +199,5 @@ def main():
# print({k: np.mean(v) for k, v in d.items()})


if __name__ == '__main__':
if __name__ == "__main__":
main()

38 changes: 22 additions & 16 deletions server/awq_kernels/setup.py
@@ -3,16 +3,17 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

common_setup_kwargs = {
}
common_setup_kwargs = {}


def get_generator_flag():
generator_flag = []
torch_dir = torch.__path__[0]
if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
if os.path.exists(
os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")
):
generator_flag = ["-DOLD_GENERATOR_PATH"]

return generator_flag


@@ -23,7 +24,9 @@ def get_compute_capabilities():
cc = major * 10 + minor

if cc < 75:
raise RuntimeError("GPUs with compute capability less than 7.5 are not supported.")
raise RuntimeError(
"GPUs with compute capability less than 7.5 are not supported."
)

# figure out compute capability
compute_capabilities = {75, 80, 86, 89, 90}
@@ -34,14 +37,15 @@ def get_compute_capabilities():

return capability_flags


generator_flags = get_generator_flag()
arch_flags = get_compute_capabilities()


extra_compile_args={
extra_compile_args = {
"cxx": ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"],
"nvcc": [
"-O3",
"-O3",
"-std=c++17",
"-DENABLE_BF16",
"-U__CUDA_NO_HALF_OPERATORS__",
@@ -53,7 +57,9 @@ def get_compute_capabilities():
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--use_fast_math",
] + arch_flags + generator_flags
]
+ arch_flags
+ generator_flags,
}

extensions = [
@@ -64,8 +70,9 @@ def get_compute_capabilities():
"awq_cuda/quantization/gemm_cuda_gen.cu",
"awq_cuda/layernorm/layernorm.cu",
"awq_cuda/position_embedding/pos_encoding_kernels.cu",
"awq_cuda/quantization/gemv_cuda.cu"
], extra_compile_args=extra_compile_args
"awq_cuda/quantization/gemv_cuda.cu",
],
extra_compile_args=extra_compile_args,
)
]

@@ -75,18 +82,17 @@ def get_compute_capabilities():
[
"awq_cuda/pybind_ft.cpp",
"awq_cuda/attention/ft_attention.cpp",
"awq_cuda/attention/decoder_masked_multihead_attention.cu"
], extra_compile_args=extra_compile_args
"awq_cuda/attention/decoder_masked_multihead_attention.cu",
],
extra_compile_args=extra_compile_args,
)
)

additional_setup_kwargs = {
"ext_modules": extensions,
"cmdclass": {'build_ext': BuildExtension}
"cmdclass": {"build_ext": BuildExtension},
}

common_setup_kwargs.update(additional_setup_kwargs)

setup(
**common_setup_kwargs
)
setup(**common_setup_kwargs)
15 changes: 13 additions & 2 deletions server/lorax_server/cli.py
@@ -82,9 +82,19 @@ def serve(
"Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
)
server.serve(
model_id, adapter_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path, source, adapter_source
model_id,
adapter_id,
revision,
sharded,
quantize,
dtype,
trust_remote_code,
uds_path,
source,
adapter_source,
)


def _download_weights(
model_id: str,
revision: Optional[str] = None,
@@ -95,6 +105,7 @@ def _download_weights(
# Import here after the logger is added to log potential import exceptions
from lorax_server import utils
from lorax_server.utils import sources

model_source = sources.get_model_source(source, model_id, revision, extension)

# Test if files were already download
@@ -168,7 +179,7 @@ def _download_weights(
discard_names = getattr(class_, "_tied_weights_keys", [])
discard_names.extend(getattr(class_, "_keys_to_ignore_on_load_missing", []))

except Exception as e:
except Exception:
discard_names = []
# Convert pytorch weights to safetensors
utils.convert_files(local_pt_files, local_st_files, discard_names)
40 changes: 24 additions & 16 deletions server/lorax_server/models/__init__.py
@@ -1,4 +1,3 @@
import os
import torch

from loguru import logger
@@ -19,6 +18,7 @@
from lorax_server.models.t5 import T5Sharded
from lorax_server.models.gpt_neox import GPTNeoxSharded
from lorax_server.utils.sources import get_s3_model_local_dir
from lorax_server.utils.globals import set_speculation_num

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
@@ -69,7 +69,7 @@
__all__.append(FlashGPT2)
__all__.append(FlashQwen)
__all__.append(FlashPhi)

MISTRAL = True
try:
from lorax_server.models.flash_mistral import FlashMistral
@@ -103,10 +103,11 @@ def get_model(
adapter_source: str,
) -> Model:
config_dict = None
medusa_id = None
if source == "s3":
# change the model id to be the local path to the folder so
# we can load the config_dict locally
logger.info(f"Using the local files since we are coming from s3")
logger.info("Using the local files since we are coming from s3")
model_path = get_s3_model_local_dir(model_id)
logger.info(f"model_path: {model_path}")
config_dict, _ = PretrainedConfig.get_config_dict(
@@ -118,11 +119,18 @@ def get_model(
config_dict, _ = PretrainedConfig.get_config_dict(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
else:
else:
raise ValueError(f"Unknown source {source}")

model_type = config_dict["model_type"]

if "medusa_num_heads" in config_dict:
medusa_id = model_id
model_id = config_dict["base_model_name_or_path"]
revision = "main"
speculate_medusa = config_dict["medusa_num_heads"]
set_speculation_num(speculate_medusa)

if dtype is None:
dtype = torch.float16
elif dtype == "float16":
@@ -234,6 +242,7 @@ def get_model(
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
medusa_id=medusa_id,
)
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
@@ -271,7 +280,7 @@ def get_model(
dtype=dtype,
trust_remote_code=trust_remote_code,
)
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon"))
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon"))
else:
if FLASH_ATTENTION and not config_dict.get("alibi", False):
return FlashRWSharded(
@@ -300,9 +309,10 @@ def get_model(
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
medusa_id=medusa_id,
)
raise NotImplementedError("Mistral model requires flash attention v2")

if model_type == "mixtral":
if MIXTRAL:
return FlashMixtral(
@@ -314,8 +324,10 @@ def get_model(
dtype=dtype,
trust_remote_code=trust_remote_code,
)
raise NotImplementedError("Mixtral models requires flash attention v2, stk and megablocks")

raise NotImplementedError(
"Mixtral models requires flash attention v2, stk and megablocks"
)

if model_type == "qwen":
if FLASH_ATTENTION:
return FlashQwen(
@@ -328,7 +340,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
raise NotImplementedError("Qwen model requires flash attention v2")

if model_type in ["phi-msft", "phi"]:
if FLASH_ATTENTION:
return FlashPhi(
@@ -367,13 +379,9 @@ def get_model(
"gptq quantization is not supported for AutoModel, you can try to quantize it with `lorax-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
raise ValueError(
"4bit quantization is not supported for AutoModel"
)
raise ValueError("4bit quantization is not supported for AutoModel")
if quantize == "awq":
raise ValueError(
"awq quantization is not supported for AutoModel"
)
raise ValueError("awq quantization is not supported for AutoModel")

if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
return CausalLM(
2 changes: 1 addition & 1 deletion server/lorax_server/models/cache_manager.py
@@ -132,4 +132,4 @@ def get_cache_manager() -> CacheManager:
if CACHE_MANAGER is None:
raise RuntimeError("cache manager was not initialized")

return CACHE_MANAGER
return CACHE_MANAGER
5 changes: 3 additions & 2 deletions server/lorax_server/models/causal_lm.py
@@ -1,5 +1,4 @@
import torch
import inspect

from dataclasses import dataclass
from opentelemetry import trace
@@ -99,7 +98,9 @@ def from_pb(
)
adapter_indices_list.append(r.adapter_index)

adapter_indices = torch.tensor(adapter_indices_list, dtype=torch.int64, device=device)
adapter_indices = torch.tensor(
adapter_indices_list, dtype=torch.int64, device=device
)

tokenized_inputs = tokenizer(
inputs,