v0 add autoquant #402

Open · wants to merge 4 commits into base: main · Changes from all commits
1 change: 0 additions & 1 deletion README.md
@@ -20,7 +20,6 @@
[![DOI](https://zenodo.org/badge/703686617.svg)](https://zenodo.org/doi/10.5281/zenodo.11406462)
![Docker pulls](https://img.shields.io/docker/pulls/michaelf34/infinity)


Infinity is a high-throughput, low-latency REST API for serving text-embeddings, reranking models and clip. Infinity is developed under [MIT License](https://github.com/michaelfeil/infinity/blob/main/LICENSE).

## Why Infinity
3 changes: 3 additions & 0 deletions docs/docs/cli_v2.md
@@ -75,6 +75,9 @@ $ infinity_emb v2 --help
│ `INFINITY_LENGTHS_VIA_TOKENIZ… │
│ [default: │
│ lengths-via-tokenize] │
│ --dtype [float32|float16|int8|fp8|aut dtype for the model weights. │
│ oquant|auto] [env var: `INFINITY_DTYPE`] │

│ --dtype [float32|float16|bfloat16|int dtype for the model weights. │
│ 8|fp8|auto] [env var: `INFINITY_DTYPE`] │
│ [default: auto] │
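For orientation, a minimal sketch of the configuration the new `--dtype autoquant` flag (or `INFINITY_DTYPE=autoquant`) maps to; the argument values mirror the test added at the bottom of this PR, and the model name is only an example:

```python
from infinity_emb.args import EngineArgs

# same settings the CLI would assemble for `--dtype autoquant`
args = EngineArgs(
    model_name_or_path="michaelfeil/bge-small-en-v1.5",  # example model
    dtype="autoquant",
    engine="torch",
    bettertransformer=False,
)
```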
4 changes: 2 additions & 2 deletions libs/infinity_emb/Makefile
@@ -22,7 +22,7 @@ test tests:
poetry run pytest

openapi:
./../../docs/assets/create_openapi_with_server_hook.sh
poetry run ./../../docs/assets/create_openapi_with_server_hook.sh

######################
# LINTING AND FORMATTING
@@ -60,7 +60,7 @@ benchmark_embed: tests/data/benchmark/benchmark_embed.json

# Generate CLI v2 documentation
cli_v2_docs:
./../../docs/assets/create_cli_v2_docs.sh
poetry run ./../../docs/assets/create_cli_v2_docs.sh

######################
# HELP
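Both documentation scripts are now invoked through `poetry run`, so they execute inside the Poetry-managed virtualenv and see the project's locked dependencies instead of whatever happens to be on the system path.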
1 change: 1 addition & 0 deletions libs/infinity_emb/infinity_emb/primitives.py
@@ -128,6 +128,7 @@ class Dtype(EnumType):
bfloat16: str = "bfloat16"
int8: str = "int8"
fp8: str = "fp8"
autoquant: str = "autoquant"
auto: str = "auto"

@staticmethod
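A small sanity sketch of what the new member enables, assuming `EnumType` behaves like a standard string-valued `Enum`:

```python
from infinity_emb.primitives import Dtype

# the new member participates in the same containment checks the
# model loaders use to decide whether to call quant_interface()
assert Dtype.autoquant in (Dtype.int8, Dtype.fp8, Dtype.autoquant)
# value-based lookup, assuming standard Enum construction semantics
assert Dtype("autoquant") == Dtype.autoquant
```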
libs/infinity_emb/infinity_emb/transformer/embedder/sentence_transformer.py (path inferred from the new test's import)
@@ -25,7 +25,6 @@
if TYPE_CHECKING:
from torch import Tensor


if CHECK_SENTENCE_TRANSFORMERS.is_available:
from sentence_transformers import SentenceTransformer, util # type: ignore
else:
@@ -88,7 +87,7 @@ def __init__(self, *, engine_args=EngineArgs):
]:
fm.auto_model.to(torch.bfloat16)

if engine_args.dtype in (Dtype.int8, Dtype.fp8):
if engine_args.dtype in (Dtype.int8, Dtype.fp8, Dtype.autoquant):
fm.auto_model = quant_interface(
fm.auto_model, engine_args.dtype, device=Device[self.device.type]
)
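The `Device[self.device.type]` lookup above leans on the fact that torch reports the device kind as a plain string; a short sketch of that assumption:

```python
import torch

from infinity_emb.primitives import Device

# torch.device("cpu").type == "cpu", and indexing an Enum by member
# name converts the string, assuming members are named cpu/cuda
assert Device[torch.device("cpu").type] is Device.cpu
```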
libs/infinity_emb/infinity_emb/transformer/quantization/interface.py (path inferred from the new test's import)
@@ -7,6 +7,7 @@

import numpy as np
import requests # type: ignore
import torch.ao.quantization
Contributor (review comment): style: This import is unused in the current file. Consider removing it if not needed.


from infinity_emb._optional_imports import CHECK_SENTENCE_TRANSFORMERS, CHECK_TORCH
from infinity_emb.env import MANAGER
@@ -34,14 +35,22 @@ def quant_interface(model: Any, dtype: Dtype = Dtype.int8, device: Device = Device.cpu):
Defaults to Device.cpu.
"""
device_orig = model.device
if device == Device.cpu and dtype in [Dtype.int8, Dtype.auto]:
if dtype == Dtype.autoquant:
import torchao # type: ignore

model = torchao.autoquant(model)
logger.info("using dtype=autoquant")
elif device == Device.cpu and dtype in [Dtype.int8, Dtype.auto]:
logger.info("using torch.quantization.quantize_dynamic()")
# TODO: verify if cpu requires quantization with torch.quantization.quantize_dynamic()
model = torch.quantization.quantize_dynamic(
model.to("cpu"), # the original model
{torch.nn.Linear}, # a set of layers to dynamically quantize
dtype=torch.qint8,
)
model = torch.ao.quantization.quantize_dynamic(
model, {torch.nn.Linear}, dtype=torch.qint8
)
Comment on lines 46 to +53
Contributor (review comment): logic: Two quantization methods are applied sequentially. This might lead to unexpected behavior or reduced model performance. Consider using only one method or clarify why both are necessary.

elif device == Device.cuda and dtype in [Dtype.int8, Dtype.auto]:
logger.info(f"using quantize() for {dtype.value}")
quant_handler, state_dict = quantize(model, mode=dtype.value)
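A hedged sketch of the reviewer's suggestion above: collapse the CPU branch to one dynamic-quantization call. `torch.ao.quantization.quantize_dynamic` and `torch.quantization.quantize_dynamic` are the same routine (the latter is the deprecated alias), and the second pass likely finds no remaining `torch.nn.Linear` modules anyway, so a single call states the intent:

```python
import torch
import torch.ao.quantization


def dynamic_quantize_cpu(model: torch.nn.Module) -> torch.nn.Module:
    """One dynamic-quantization pass for the CPU branch."""
    return torch.ao.quantization.quantize_dynamic(
        model.to("cpu"),    # dynamic quantization executes on CPU
        {torch.nn.Linear},  # layer types to replace with quantized variants
        dtype=torch.qint8,  # 8-bit integer weights
    )
```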
@@ -506,7 +506,10 @@ def create_quantized_state_dict(self):
cur_state_dict = self.mod.state_dict()
for fqn, mod in self.mod.named_modules():
if isinstance(mod, torch.nn.Linear):
assert not mod.bias
if mod.bias is not None:
raise ValueError(
"int4 quantization requires all layers to have bias=False. This model is not compatible."
)
out_features = mod.out_features
in_features = mod.in_features
assert out_features % 8 == 0, "require out_features % 8 == 0"
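Worth noting: the old `assert not mod.bias` relied on tensor truthiness, which raises `RuntimeError: Boolean value of Tensor with more than one element is ambiguous` whenever a multi-element bias is present; the explicit `is not None` check fails cleanly with an actionable message instead.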
@@ -710,7 +713,10 @@ def quantize(
quantized_state_dict = quant_handler.create_quantized_state_dict()

new_base_name = base_name.replace(".pth", f"{label}int8.pth")
elif mode == "autoquant":
import torchao

model = torchao.autoquant(torch.compile(model))
elif mode == "int4":
logger.info(
"Quantizing model weights for int4 weight-only affine per-channel groupwise quantization"
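For reference, a minimal sketch of what this branch does, assuming `torchao>=0.5` as pinned in pyproject.toml below. Note that, unlike the int8 path above, no quantized checkpoint name is produced here; `example` is a placeholder input:

```python
import torch
import torchao  # optional extra added in pyproject.toml


def apply_autoquant(model: torch.nn.Module, example: torch.Tensor) -> torch.nn.Module:
    # wrap the compiled model; torchao tracks candidate quantized kernels per layer
    model = torchao.autoquant(torch.compile(model))
    with torch.inference_mode():
        model(example)  # the first forward pass benchmarks and finalizes kernels
    return model
```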
38 changes: 29 additions & 9 deletions libs/infinity_emb/poetry.lock

Some generated files are not rendered by default.

5 changes: 3 additions & 2 deletions libs/infinity_emb/pyproject.toml
@@ -42,6 +42,7 @@ diskcache = {version = "*", optional=true}
onnxruntime-gpu = {version = "*", optional=true}
tensorrt = {version = "^8.6.1", optional=true}
soundfile = {version="^0.12.1", optional=true}
torchao = {version="^0.5.0", optional=true}

[tool.poetry.scripts]
infinity_emb = "infinity_emb.infinity_server:cli"
@@ -82,9 +83,9 @@ types-chardet = "^5.0.4.6"
mypy-protobuf = "^3.0.0"

[tool.poetry.extras]
ct2=["ctranslate2","sentence-transformers","torch","transformers"]
ct2=["ctranslate2","sentence-transformers","torch","torchao","transformers"]
optimum=["optimum"]
torch=["sentence-transformers","torch"]
torch=["sentence-transformers","torch","torchao"]
einops=["einops"]
logging=["rich"]
cache=["diskcache"]
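With these extras updated, an install such as `pip install infinity-emb[torch]` now also pulls in `torchao`, the dependency the new `autoquant` dtype needs at runtime.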
@@ -4,7 +4,11 @@
import torch
from transformers import AutoTokenizer, BertModel # type: ignore

from infinity_emb.args import EngineArgs
from infinity_emb.primitives import Device, Dtype
from infinity_emb.transformer.embedder.sentence_transformer import (
SentenceTransformerPatched,
)
from infinity_emb.transformer.quantization.interface import quant_interface

devices = [Device.cpu]
@@ -49,3 +53,45 @@ def test_quantize_bert(device: Device, dtype: Dtype):
out_quant = model.forward(**tokens_encoded)["last_hidden_state"].mean(dim=1)

assert torch.cosine_similarity(out_default, out_quant) > 0.95


def test_autoquant_quantization():
model_st = SentenceTransformerPatched(
engine_args=EngineArgs(
model_name_or_path="michaelfeil/bge-small-en-v1.5",
dtype="autoquant",
engine="torch",
bettertransformer=False,
)
)
model_default = SentenceTransformerPatched(
engine_args=EngineArgs(
model_name_or_path="michaelfeil/bge-small-en-v1.5",
dtype="float32",
engine="torch",
bettertransformer=False,
)
)
sentence = "This is a test sentence."
Contributor (review comment): style: This line is unused and can be removed.

for sentence in [
"This is a test sentence.",
"This is another sentence, that should be embedded. " * 10,
"1",
]:
embedding_st = model_st.encode_post(
model_st.encode_core(model_st.encode_pre([sentence]))
)
embedding_default = model_default.encode_post(
model_default.encode_core(model_default.encode_pre([sentence]))
)
assert embedding_st.shape == embedding_default.shape

# cosine similarity
sim = torch.nn.functional.cosine_similarity(
torch.tensor(embedding_st), torch.tensor(embedding_default)
)
assert sim > 0.95


if __name__ == "__main__":
test_autoquant_quantization()
Comment on lines +96 to +97
Contributor (review comment): style: Running a single test function in main might not be ideal. Consider using a test runner or removing this block if not necessary.

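A hedged alternative to the `__main__` block, following the reviewer's note: let pytest collect the file while keeping it runnable as a script:

```python
import pytest

if __name__ == "__main__":
    # delegate to the test runner instead of calling the test directly
    raise SystemExit(pytest.main([__file__, "-k", "autoquant"]))
```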