Skip to content

Commit

Permalink
Cpu tgi (huggingface#1936)
Browse files Browse the repository at this point in the history
* add CPU tgi support

Signed-off-by: Wang, Yi A <[email protected]>

* ipex distributed ops support

Signed-off-by: Wang, Yi A <[email protected]>

---------

Signed-off-by: Wang, Yi A <[email protected]>
Co-authored-by: Funtowicz Morgan <[email protected]>
  • Loading branch information
2 people authored and yuanwu2017 committed Sep 24, 2024
1 parent a9faabc commit 0d879fe
Show file tree
Hide file tree
Showing 17 changed files with 221 additions and 77 deletions.
87 changes: 83 additions & 4 deletions Dockerfile_intel
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
ARG PLATFORM=xpu

FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
WORKDIR /usr/src

Expand Down Expand Up @@ -37,7 +39,8 @@ RUN cargo build --profile release-opt


# Text Generation Inference base image for Intel
FROM intel/intel-extension-for-pytorch:2.1.30-xpu as base

FROM intel/intel-extension-for-pytorch:2.1.30-xpu as xpu

USER root
# libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it
Expand All @@ -59,7 +62,7 @@ ENV HUGGINGFACE_HUB_CACHE=/data \

WORKDIR /usr/src
RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl && pip install torch-2.1.0.post1+cxx11.abi-cp310-cp310-linux_x86_64.whl
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b group_rope origin/dev/gqa_rope
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b distributed origin/dev/distributed

# Install server
COPY proto proto
Expand Down Expand Up @@ -89,8 +92,84 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
# Install launcher
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher

# Final image
FROM base

# Text Generation Inference base image for Intel-cpu
FROM ubuntu:22.04 as cpu

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
curl \
ca-certificates \
make \
g++ \
git \
wget \
cmake

ENV HUGGINGFACE_HUB_CACHE=/data \
HF_HUB_ENABLE_HF_TRANSFER=1 \
PORT=80

ARG MAMBA_VERSION=23.1.0-1
ARG PYTHON_VERSION='3.10.10'
# Automatically set by buildx
ARG TARGETPLATFORM
ENV PATH /opt/conda/bin:$PATH

# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
# Install mamba
# translating Docker's TARGETPLATFORM into mamba arches
RUN case ${TARGETPLATFORM} in \
"linux/arm64") MAMBA_ARCH=aarch64 ;; \
*) MAMBA_ARCH=x86_64 ;; \
esac && \
curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
RUN chmod +x ~/mambaforge.sh && \
bash ~/mambaforge.sh -b -p /opt/conda && \
rm ~/mambaforge.sh

RUN conda install -c conda-forge gperftools mkl

RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl

WORKDIR /usr/src

RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout eda7a7c42df6f9a64e0de9c2b69304ee02f2c32a

RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout ccl_torch_dev_0131

RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update --init --recursive && python setup.py install

RUN cd torch-ccl && git submodule sync && git submodule update --init --recursive && pip install .

ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so:/opt/conda/lib/libiomp5.so
ENV CCL_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch
ENV I_MPI_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch
ENV FI_PROVIDER_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib/prov:/usr/lib64/libfabric
ENV LD_LIBRARY_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib:/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/lib
ENV KMP_BLOCKTIME=1
ENV KMP_TPAUSE=0
ENV KMP_FORKJOIN_BARRIER_PATTERN=dist,dist
ENV KMP_PLAIN_BARRIER_PATTERN=dist,dist
ENV KMP_REDUCTION_BARRIER_PATTERN=dist,dist

# Install server
COPY proto proto
COPY server server
COPY server/Makefile server/Makefile
RUN cd server && \
make gen-server && \
pip install -r requirements_intel.txt && \
pip install ".[accelerate, peft, outlines]" --no-cache-dir

# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher

FROM ${PLATFORM} as final
ENTRYPOINT ["text-generation-launcher"]
CMD ["--json-output"]
4 changes: 2 additions & 2 deletions server/text_generation_server/layers/attention/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
import os

if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
Expand All @@ -7,7 +7,7 @@
from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
elif SYSTEM == "rocm":
from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
elif SYSTEM == "xpu":
elif IPEX_AVAIL:
from .xpu import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
else:
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
5 changes: 2 additions & 3 deletions server/text_generation_server/layers/attention/xpu.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import intel_extension_for_pytorch as ipex
import torch
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE

SUPPORTS_WINDOWING = False

Expand Down Expand Up @@ -56,8 +57,6 @@ def paged_attention(
input_lengths: torch.Tensor,
max_s: int,
):
query = query.contiguous()
block_size = value_cache.shape[3]
return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
out,
query,
Expand All @@ -67,7 +66,7 @@ def paged_attention(
softmax_scale,
block_tables,
input_lengths,
block_size,
BLOCK_SIZE,
max_s,
None,
)
24 changes: 12 additions & 12 deletions server/text_generation_server/layers/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from accelerate import init_empty_weights
from text_generation_server.utils.import_utils import (
SYSTEM,
IPEX_AVAIL,
)


Expand Down Expand Up @@ -82,18 +83,20 @@ def forward(self, hidden_states, residual=None):

return super().forward(hidden_states), residual

elif SYSTEM == "xpu":
elif IPEX_AVAIL:
import intel_extension_for_pytorch as ipex

class FastLayerNorm(nn.LayerNorm):
def forward(self, hidden_states, residual=None):
res_out = hidden_states
out = ipex.llm.functional.add_layer_norm(
residual, hidden_states, self.weight, self.bias, self.eps, True
residual,
hidden_states,
self.weight,
self.bias,
self.eps,
residual is not None,
)
if residual is not None:
res_out = residual
return out, res_out
return out, residual if residual is not None else hidden_states


class FastRMSNorm(nn.Module):
Expand All @@ -109,19 +112,16 @@ def load(cls, prefix, weights, eps=1e-6):
return cls(weight, eps)

def forward(self, hidden_states, residual=None):
if SYSTEM == "xpu":
residual_out = hidden_states
if IPEX_AVAIL:
out = ipex.llm.functional.add_rms_norm(
residual,
hidden_states,
self.weight,
None,
self.variance_epsilon,
True,
residual is not None,
)
if residual is not None:
residual_out = residual
return out, residual_out
return out, residual if residual is not None else hidden_states
elif hidden_states.shape[-1] > 8192:
if residual is not None:
hidden_states += residual
Expand Down
6 changes: 3 additions & 3 deletions server/text_generation_server/layers/rotary.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
import torch
from torch import nn

from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL

if SYSTEM == "cuda":
from flash_attn.layers.rotary import RotaryEmbedding
import rotary_emb
elif SYSTEM == "rocm":
from vllm._C import ops
elif SYSTEM == "xpu":
elif IPEX_AVAIL:
import intel_extension_for_pytorch as ipex


Expand Down Expand Up @@ -69,7 +69,7 @@ def forward(

# Inplace operation, updating query and key.
ops.rotary_embedding(query, key, head_size, cos, sin, True)
elif SYSTEM == "xpu":
elif IPEX_AVAIL:
ipex.llm.functional.rotary_embedding(
query, key, sin, cos, query.size(-1), True
)
Expand Down
31 changes: 24 additions & 7 deletions server/text_generation_server/layers/tensor_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
from typing import Iterable, List
from text_generation_server.layers.linear import get_linear, FastLinear
from text_generation_server.layers.exl2 import Exl2Weight
from text_generation_server.utils.import_utils import IPEX_AVAIL

if IPEX_AVAIL:
import intel_extension_for_pytorch as ipex


class LayerConcat(torch.nn.Module):
Expand Down Expand Up @@ -96,10 +100,14 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
local_out = gather_input.T

torch.mm(input, self.linear.weight.T, out=local_out)

torch.distributed.all_gather_into_tensor(
world_out, gather_input, group=self.process_group
)
if IPEX_AVAIL:
ipex.distributed.all_gather_into_tensor(
world_out, gather_input, group=self.process_group
)
else:
torch.distributed.all_gather_into_tensor(
world_out, gather_input, group=self.process_group
)

if input.shape[0] == 1:
return world_out
Expand All @@ -109,7 +117,10 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
world_output = [
torch.empty_like(output) for _ in range(self.process_group.size())
]
torch.distributed.all_gather(world_output, output, group=self.process_group)
if IPEX_AVAIL:
ipex.distributed.all_gather(world_output, output, group=self.process_group)
else:
torch.distributed.all_gather(world_output, output, group=self.process_group)
world_output = torch.cat(world_output, dim=-1)
return world_output

Expand Down Expand Up @@ -206,7 +217,10 @@ def load(cls, config, prefix: str, weights, bias: bool):
def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
out = super().forward(input)
if self.process_group.size() > 1 and reduce:
torch.distributed.all_reduce(out, group=self.process_group)
if IPEX_AVAIL:
ipex.distributed.all_reduce(out, group=self.process_group)
else:
torch.distributed.all_reduce(out, group=self.process_group)
return out


Expand Down Expand Up @@ -243,5 +257,8 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
)
out = torch.nn.functional.embedding(input, self.weight)
if self.reduce and self.process_group.size() > 1:
torch.distributed.all_reduce(out, group=self.process_group)
if IPEX_AVAIL:
ipex.distributed.all_reduce(out, group=self.process_group)
else:
torch.distributed.all_reduce(out, group=self.process_group)
return out
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple, Any
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.import_utils import IPEX_AVAIL

if SYSTEM != "xpu":
if not IPEX_AVAIL:
from vllm.model_executor.layers.fused_moe import fused_moe

from text_generation_server.layers.attention import (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
import numpy as np

from torch import nn
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.import_utils import IPEX_AVAIL

if SYSTEM != "xpu":
if not IPEX_AVAIL:
from vllm.model_executor.layers.fused_moe import fused_moe
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
Expand Down
49 changes: 33 additions & 16 deletions server/text_generation_server/models/flash_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
from text_generation_server.models import Model
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.utils.dist import RANK
Expand Down Expand Up @@ -773,21 +773,38 @@ def init_kv_cache(
else:
x = BLOCK_SIZE // element_size

self.kv_cache = [
(
torch.empty(
(num_blocks, num_heads, head_size // x, BLOCK_SIZE, x),
dtype=dtype,
device=device,
),
torch.empty(
(num_blocks, num_heads, head_size, BLOCK_SIZE),
dtype=dtype,
device=device,
),
)
for _ in range(num_layers)
]
if IPEX_AVAIL and SYSTEM == "cpu":
self.kv_cache = [
(
torch.empty(
(num_blocks, num_heads, BLOCK_SIZE, head_size),
dtype=dtype,
device=device,
),
torch.empty(
(num_blocks, num_heads, BLOCK_SIZE, head_size),
dtype=dtype,
device=device,
),
)
for _ in range(num_layers)
]
else:
self.kv_cache = [
(
torch.empty(
(num_blocks, num_heads, head_size // x, BLOCK_SIZE, x),
dtype=dtype,
device=device,
),
torch.empty(
(num_blocks, num_heads, head_size, BLOCK_SIZE),
dtype=dtype,
device=device,
),
)
for _ in range(num_layers)
]

def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
Expand Down
Loading

0 comments on commit 0d879fe

Please sign in to comment.