diff --git a/.github/workflows/test-pytorch-xla-tpu-tgi-jetstream.yml b/.github/workflows/test-pytorch-xla-tpu-tgi-jetstream.yml
index fb18d865..eb8f8929 100644
--- a/.github/workflows/test-pytorch-xla-tpu-tgi-jetstream.yml
+++ b/.github/workflows/test-pytorch-xla-tpu-tgi-jetstream.yml
@@ -22,7 +22,7 @@ jobs:
     runs-on:
       group: gcp-ct5lp-hightpu-8t
     container:
-      image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_tpuvm
+      image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.1_3.10_tpuvm
       options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache
     env:
       PJRT_DEVICE: TPU
diff --git a/.github/workflows/test-pytorch-xla-tpu-tgi-nightly-jetstream.yml b/.github/workflows/test-pytorch-xla-tpu-tgi-nightly-jetstream.yml
index 3e4859e1..d103811c 100644
--- a/.github/workflows/test-pytorch-xla-tpu-tgi-nightly-jetstream.yml
+++ b/.github/workflows/test-pytorch-xla-tpu-tgi-nightly-jetstream.yml
@@ -16,7 +16,7 @@ jobs:
     runs-on:
       group: gcp-ct5lp-hightpu-8t
     container:
-      image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_tpuvm
+      image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.1_3.10_tpuvm
       options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache
     env:
       PJRT_DEVICE: TPU
diff --git a/.github/workflows/test-pytorch-xla-tpu-tgi-nightly.yml b/.github/workflows/test-pytorch-xla-tpu-tgi-nightly.yml
index 1b14f127..0ba23d5e 100644
--- a/.github/workflows/test-pytorch-xla-tpu-tgi-nightly.yml
+++ b/.github/workflows/test-pytorch-xla-tpu-tgi-nightly.yml
@@ -17,7 +17,7 @@ jobs:
     runs-on:
       group: gcp-ct5lp-hightpu-8t
     container:
-      image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_tpuvm
+      image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.1_3.10_tpuvm
       options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache
     env:
       PJRT_DEVICE: TPU
diff --git a/.github/workflows/test-pytorch-xla-tpu-tgi.yml b/.github/workflows/test-pytorch-xla-tpu-tgi.yml
index 78492caf..ff57d648 100644
--- a/.github/workflows/test-pytorch-xla-tpu-tgi.yml
+++ b/.github/workflows/test-pytorch-xla-tpu-tgi.yml
@@ -20,7 +20,7 @@ jobs:
     runs-on:
       group: gcp-ct5lp-hightpu-8t
     container:
-      image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_tpuvm
+      image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.1_3.10_tpuvm
       options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache
     env:
       PJRT_DEVICE: TPU
diff --git a/.github/workflows/test-pytorch-xla-tpu.yml b/.github/workflows/test-pytorch-xla-tpu.yml
index efa2e354..11e66c8d 100644
--- a/.github/workflows/test-pytorch-xla-tpu.yml
+++ b/.github/workflows/test-pytorch-xla-tpu.yml
@@ -20,7 +20,7 @@ jobs:
     runs-on:
       group: gcp-ct5lp-hightpu-8t
     container:
-      image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_tpuvm
+      image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.1_3.10_tpuvm
       options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache
     env:
       PJRT_DEVICE: TPU
diff --git a/.github/workflows/tpu-tgi-release.yml b/.github/workflows/tpu-tgi-release.yml
index 2d81a471..400a5238 100644
--- a/.github/workflows/tpu-tgi-release.yml
+++ b/.github/workflows/tpu-tgi-release.yml
@@ -74,7 +74,7 @@ jobs:
           labels: ${{ steps.meta.outputs.labels }}
           build-args: |
             VERSION=${{ steps.version.outputs.version }}
-            TGI_VERSION=0ff6ff60ada291840beed63d8bf458d6f9606f7f
+            TGI_VERSION="v2.4.1"
 
       - name: Generate artifact attestation for TGI
@@ -95,7 +95,7 @@ jobs:
           labels: ${{ steps.meta-ie.outputs.labels }}
          build-args: |
             VERSION=${{ steps.version.outputs.version }}
-            TGI_VERSION=0ff6ff60ada291840beed63d8bf458d6f9606f7f
+            TGI_VERSION="v2.4.1"
           target: inference-endpoint
diff --git a/Makefile b/Makefile
index 7448a46e..7091ac00 100644
--- a/Makefile
+++ b/Makefile
@@ -19,8 +19,7 @@ REAL_CLONE_URL = $(if $(CLONE_URL),$(CLONE_URL),$(DEFAULT_CLONE_URL))
 
 .PHONY: build_dist style style_check clean
 
-# Ths is essentially v2.3.0 plus a fix to support v2 proto interface
-TGI_VERSION ?= 0ff6ff60ada291840beed63d8bf458d6f9606f7f
+TGI_VERSION ?= 690702b1ce9a27ce5bdf2a9dd3a80277ecea12cd
 
 rwildcard=$(wildcard $1) $(foreach d,$1,$(call rwildcard,$(addsuffix /$(notdir $d),$(wildcard $(dir $d)*))))
 
@@ -47,6 +46,7 @@ tpu-tgi:
 	docker build --rm -f text-generation-inference/docker/Dockerfile \
 		--build-arg VERSION=$(VERSION) \
 		--build-arg TGI_VERSION=$(TGI_VERSION) \
+		--ulimit nofile=100000:100000 \
 		-t huggingface/optimum-tpu:$(VERSION)-tgi .
 	docker tag huggingface/optimum-tpu:$(VERSION)-tgi huggingface/optimum-tpu:latest
 
@@ -55,6 +55,7 @@ tpu-tgi-ie:
 		--target inference-endpoint \
 		--build-arg VERSION=$(VERSION) \
 		--build-arg TGI_VERSION=$(TGI_VERSION) \
+		--ulimit nofile=100000:100000 \
 		-t huggingface/optimum-tpu:$(VERSION)-tgi .
 	docker tag huggingface/optimum-tpu:$(VERSION)-tgi huggingface/optimum-tpu:latest-ie
diff --git a/optimum/tpu/__init__.py b/optimum/tpu/__init__.py
index 4f14dfc8..848946e5 100644
--- a/optimum/tpu/__init__.py
+++ b/optimum/tpu/__init__.py
@@ -14,5 +14,4 @@
 
 from .jetstream_pt_support import jetstream_pt_available  # isort:skip
 from .fsdp_v2 import get_fsdp_config, use_fsdp_v2
-from .modeling import AutoModelForCausalLM
 from .version import VERSION, __version__
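Note on the optimum/tpu/__init__.py hunk above: AutoModelForCausalLM is no longer re-exported from the package root, so callers import it from the optimum.tpu.modeling submodule instead (the generator.py hunk further down in this diff does exactly that). A minimal sketch of the updated import; the checkpoint name is hypothetical and only there for illustration:

    # Old import path, removed by this change:
    #   from optimum.tpu import AutoModelForCausalLM
    # New import path:
    from optimum.tpu.modeling import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")  # hypothetical checkpoint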
diff --git a/optimum/tpu/cli.py b/optimum/tpu/cli.py
index 303eed8c..069ff965 100644
--- a/optimum/tpu/cli.py
+++ b/optimum/tpu/cli.py
@@ -9,8 +9,8 @@
 import typer
 
 
-TORCH_VER = "2.4.0"
-JETSTREAM_PT_VER = "02927c9f563082421abe8eedceabe8aedd7ec2f9"
+TORCH_VER = "2.5.1"
+JETSTREAM_PT_VER = "jetstream-v0.2.4"
 DEFAULT_DEPS_PATH = os.path.join(Path.home(), ".jetstream-deps")
 
 app = typer.Typer()
diff --git a/optimum/tpu/fsdp_v2.py b/optimum/tpu/fsdp_v2.py
index 9f4a5ad1..d303e44a 100644
--- a/optimum/tpu/fsdp_v2.py
+++ b/optimum/tpu/fsdp_v2.py
@@ -17,8 +17,6 @@
 """
 from typing import Any, Dict, List, Union
 
-from transformers.utils import logging
-
 
 PreTrainedModel = Any
 # NOTE: instead of the above, modeling_utils.PreTrainedModel should be used, but since the usage is only for type
@@ -92,15 +90,6 @@ def get_fsdp_training_args(model: PreTrainedModel) -> Dict:
         from .modeling_gemma import GemmaForCausalLM
 
         if isinstance(model, GemmaForCausalLM) or isinstance(model, HFGemmaForCausalLLM):
-            logger = logging.get_logger(__name__)
-            from torch_xla import __version__ as xla_version
-
-            if xla_version == "2.3.0":
-                logger.warning_once(
-                    "Fine-tuning Gemma on Pytorch XLA 2.3.0 might raise some issues. In case of any "
-                    "issues consider using the nightly version, and report the issue on the optimum-tpu "
-                    "GitHub repository: https://github.com/huggingface/optimum-tpu/issues/new."
-                )
             cls_to_wrap = "GemmaDecoderLayer"
             matched_model = True
     elif model_type == "llama":
diff --git a/pyproject.toml b/pyproject.toml
index b9d4c9d4..449c909b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -42,9 +42,9 @@ keywords = [
 ]
 
 dependencies = [
-    "transformers == 4.41.1",
-    "torch == 2.4.0",
-    "torch-xla[tpu] == 2.4.0",
+    "transformers == 4.46.3",
+    "torch == 2.5.1",
+    "torch-xla[tpu] == 2.5.1",
     'typer == 0.6.1',
     "loguru == 0.6.0",
     "sentencepiece == 0.2.0",
@@ -63,7 +63,7 @@ quality = ["black", "ruff", "isort"]
 # Pallas is pulled because it will install a compatible version of jax[tpu].
 jetstream-pt = [
     "jetstream-pt",
-    "torch-xla[pallas] == 2.4.0"
+    "torch-xla[pallas] == 2.5.1"
 ]
 
 [project.urls]
diff --git a/requirements.txt b/requirements.txt
index 76a52096..64a4a636 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,3 @@
 # This is not a complete list of dependencies, but it allows to install torch without CUDA support
 --index-url https://download.pytorch.org/whl/cpu
-torch==2.4.0
+torch==2.5.1
diff --git a/text-generation-inference/docker/Dockerfile b/text-generation-inference/docker/Dockerfile
index 218561dc..319ae9e8 100644
--- a/text-generation-inference/docker/Dockerfile
+++ b/text-generation-inference/docker/Dockerfile
@@ -8,12 +8,12 @@ RUN tar -C /tgi -xf /tgi/sources.tar.gz --strip-components=1
 
 # Build cargo components (adapted from TGI original Dockerfile)
 # Note that the build image is aligned on the same Linux version as the base image (Debian bookworm/ Ubuntu 22.04)
-FROM lukemathwalker/cargo-chef:latest-rust-1.79-bookworm AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.80.1-bookworm AS chef
 WORKDIR /usr/src
 
 ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
 
-FROM chef as planner
+FROM chef AS planner
 COPY --from=tgi /tgi/Cargo.toml Cargo.toml
 COPY --from=tgi /tgi/Cargo.lock Cargo.lock
 COPY --from=tgi /tgi/rust-toolchain.toml rust-toolchain.toml
@@ -101,12 +101,12 @@ RUN apt-get update -y \
 RUN pip install --upgrade pip
 
 # Install HuggingFace packages
-ARG TRANSFORMERS_VERSION='4.41.1'
-ARG ACCELERATE_VERSION='0.27.2'
-ARG SAFETENSORS_VERSION='0.4.2'
+ARG TRANSFORMERS_VERSION='4.46.3'
+ARG ACCELERATE_VERSION='1.1.1'
+ARG SAFETENSORS_VERSION='0.4.5'
 
 # TGI base env
-ENV HUGGINGFACE_HUB_CACHE=/data \
+ENV HF_HOME=/data \
     HF_HUB_ENABLE_HF_TRANSFER=1 \
     PORT=80 \
     VERSION=${VERSION}
@@ -134,7 +134,7 @@ RUN pip install dist/text_generation_server*.tar.gz
 
 
 # TPU compatible image for Inference Endpoints
-FROM tpu_base as inference-endpoint
+FROM tpu_base AS inference-endpoint
 
 COPY text-generation-inference/docker/entrypoint.sh entrypoint.sh
 RUN chmod +x entrypoint.sh
@@ -145,4 +145,5 @@ ENTRYPOINT ["./entrypoint.sh"]
 FROM tpu_base
 
 ENTRYPOINT ["text-generation-launcher"]
-CMD ["--json-output"]
+# This is commented out in the original TGI Dockerfile
+# CMD ["--json-output"]
diff --git a/text-generation-inference/server/Makefile b/text-generation-inference/server/Makefile
index d513e9b5..56e481b0 100644
--- a/text-generation-inference/server/Makefile
+++ b/text-generation-inference/server/Makefile
@@ -2,7 +2,7 @@ pkg_name := text_generation_server
 BUILDDIR ?= $(CURDIR)/build
 VERSION ?= 0.0.1
-TGI_VERSION ?= 0ff6ff60ada291840beed63d8bf458d6f9606f7f
+TGI_VERSION ?= "v2.4.1"
 mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST)))
 mkfile_dir := $(dir $(mkfile_path))
 pkg_dir := $(BUILDDIR)/$(pkg_name)
diff --git a/text-generation-inference/server/pyproject.toml b/text-generation-inference/server/pyproject.toml
index a10727b8..ab00ebc9 100644
--- a/text-generation-inference/server/pyproject.toml
+++ b/text-generation-inference/server/pyproject.toml
@@ -14,8 +14,8 @@ dependencies = [
     'grpcio-reflection == 1.62.1',
     'grpc-interceptor == 0.15.2',
     'typer == 0.6.1',
-    'safetensors == 0.4.2',
-    'transformers == 4.41.1',
+    'safetensors == 0.4.5',
+    'transformers == 4.46.3',
     'loguru == 0.6.0',
     "sentencepiece == 0.2.0",
     "numpy<2.0",
diff --git a/text-generation-inference/server/text_generation_server/generator.py b/text-generation-inference/server/text_generation_server/generator.py
index ab7c174b..cf7e1f3b 100644
--- a/text-generation-inference/server/text_generation_server/generator.py
+++ b/text-generation-inference/server/text_generation_server/generator.py
@@ -16,23 +16,23 @@ from transformers.generation import GenerationConfig
 
 import optimum.tpu.xla_logger as logger
-from optimum.tpu import AutoModelForCausalLM
 from optimum.tpu.generation import TokenSelector
+from optimum.tpu.modeling import AutoModelForCausalLM
 from optimum.tpu.static_cache_xla import StaticCacheXla
 from optimum.tpu.xla_mp_comm import AgentMailbox, RootMailbox
 
 from .generator_base import Generator
 from .pb.generate_pb2 import (
-    Batch,
-    CachedBatch,
-    FinishReason,
-    GeneratedText,
-    Generation,
-    InfoResponse,
-    NextTokenChooserParameters,
-    Request,
-    StoppingCriteriaParameters,
-    Tokens,
+    Batch,
+    CachedBatch,
+    FinishReason,
+    GeneratedText,
+    Generation,
+    InfoResponse,
+    NextTokenChooserParameters,
+    Request,
+    StoppingCriteriaParameters,
+    Tokens,
 )
@@ -314,6 +314,9 @@ def __init__(
         tokenizer.truncation_side = "left"
         self.tokenizer = tokenizer
         self.special_tokens = self.tokenizer.all_special_ids
+        # The token selector will use the model's generation mixin internal variables to select the next token, and it
+        # expects special tokens to be initialized in the model.
+        model._prepare_special_tokens(generation_config=model.generation_config, device=model.device)
         # Slots are empty to begin with, they will be populated as new batches arrive
         self.slots = []
         self.batch_id = 0
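Note on the _prepare_special_tokens call added above (and in the Jetstream generator below): it is a private helper of transformers' GenerationMixin that materializes the bos/eos/pad special tokens as tensors on the generation config, which the token selector later reads. A rough standalone sketch of the same call, assuming an ordinary transformers causal LM (the checkpoint name is illustrative only):

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative checkpoint
    # Same call as in the generators: initializes the special-token tensors that the
    # token selection logic reads from model.generation_config.
    model._prepare_special_tokens(generation_config=model.generation_config, device=model.device)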
diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py
index 97061421..45f0a549 100644
--- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py
+++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py
@@ -262,6 +262,10 @@ def __init__(
         tokenizer.truncation_side = "left"
         self.tokenizer = tokenizer
         self.special_tokens = self.tokenizer.all_special_ids
+        # The token selector will use the model's generation mixin internal variables to select the next token, and it
+        # expects special tokens to be initialized in the model.
+        model = self.engine.pt_model
+        model._prepare_special_tokens(generation_config=model.generation_config, device='cpu')
         # Slots number is static, it cannot grow over the size of the batch
         self.slots = [Slot(i, tokenizer) for i in range(self.model.config.batch_size)]
         self.batch_id = 0
diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/gemma_model_hf.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/gemma_model_hf.py
index e61788d8..f0c0476b 100644
--- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/gemma_model_hf.py
+++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/gemma_model_hf.py
@@ -4,6 +4,15 @@
 from transformers import GemmaConfig, GenerationConfig, GenerationMixin
 
 
+class GemmaConfigHf(GemmaConfig, gemma_config.GemmaConfig):
+    """This class is used to support both the HF GemmaConfig and the Jetstream Pytorch GemmaConfig at the same time.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.tokenizer = None
+
+
 class GemmaModelHf(GemmaModel, GenerationMixin):
     """Transformer module that uses HF GemmaConfig instead of Jetstream Pytorch GemmaConfig + device.
 
@@ -16,24 +25,8 @@ def __init__(
         device,
         env,
     ):
-        self.config = config
         self.generation_config = GenerationConfig.from_model_config(config)
-
-        args = gemma_config.GemmaConfig(
-            vocab_size=config.vocab_size,
-            max_position_embeddings=config.max_position_embeddings,
-            num_hidden_layers=config.num_hidden_layers,
-            num_attention_heads=config.num_attention_heads,
-            num_key_value_heads=config.num_key_value_heads,
-            hidden_size=config.hidden_size,
-            intermediate_size=config.intermediate_size,
-            head_dim=config.head_dim,
-            rms_norm_eps=config.rms_norm_eps,
-            dtype="bfloat16",
-            quant=False,  # No quantization support for now
-            tokenizer=None,
-        )
-
+        args = GemmaConfigHf(**config.to_dict())
         args.device = device
         super().__init__(args, env)
diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/mixtral_model_hf.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/mixtral_model_hf.py
index 0e476a9a..25a93595 100644
--- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/mixtral_model_hf.py
+++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/mixtral_model_hf.py
@@ -1,12 +1,42 @@
-
 from jetstream_pt.third_party.mixtral import config as mixtral_config
 from jetstream_pt.third_party.mixtral.model import Transformer
 from transformers import GenerationConfig, GenerationMixin, MixtralConfig
 
 
+class MixtralConfigHf(MixtralConfig, mixtral_config.ModelArgs):
+    """This class is used to support both the HF MixtralConfig and the Jetstream Pytorch ModelArgs at the same time."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.__post_init__()
+
+    @property
+    def block_size(self):
+        return self.max_position_embeddings
+
+    @property
+    def n_layer(self):
+        return self.num_hidden_layers
+
+    @property
+    def n_head(self):
+        return self.num_attention_heads
+
+    @property
+    def dim(self):
+        return self.hidden_size
+
+    @property
+    def n_local_heads(self):
+        return self.num_local_experts or self.num_attention_heads
+
+    @property
+    def num_activated_experts(self):
+        return self.num_experts_per_tok
+
+
 class MixtralModelHf(Transformer, GenerationMixin):
-    """Transformer module that uses HF MixtralConfig instead of Jetstream Pytorch MixtralConfig + device.
-    """
+    """Transformer module that uses HF MixtralConfig instead of Jetstream Pytorch MixtralConfig + device."""
 
     def __init__(
         self,
@@ -14,20 +44,9 @@ def __init__(
         device,
         env,
     ):
-        self.config = config
         self.generation_config = GenerationConfig.from_model_config(config)
-
-        args = mixtral_config.ModelArgs(
-            block_size=config.max_position_embeddings,
-            vocab_size=config.vocab_size,
-            n_layer=config.num_hidden_layers,
-            n_head=config.num_attention_heads,
-            dim=config.hidden_size,
-            intermediate_size=config.intermediate_size,
-            n_local_heads=config.num_local_experts or config.num_attention_heads,
-            num_activated_experts=config.num_experts_per_tok,
-            device=device,
-        )
+        args = MixtralConfigHf(**config.to_dict())
+        args.device = device
         super().__init__(args, env)
diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/token_selector.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/token_selector.py
index ce0820c4..472b601b 100644
--- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/token_selector.py
+++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/token_selector.py
@@ -4,6 +4,7 @@
 
 import jax
 import jax.numpy as jnp
+import torch_xla2
 from jetstream.engine import sampling_utils
 from transformers.generation import (
     GenerationConfig,
@@ -173,7 +174,12 @@ def select(self, input_ids: jnp.ndarray, logits: jnp.ndarray) -> jnp.ndarray:
         Return:
             `jnp.ndarray`: A `jnp.ndarray` containing the selected tokens.
         """
-        scores = self.logits_processor(input_ids, logits)
+        # Logits processors is written in pytorch, so parameters are cast to float32 and converted to pytorch and back
+        # to jax with j2t/t2j (that is a bit expensive, it does copies), otherwise some operations are not supported.
+        logits_t = torch_xla2.tensor.j2t(logits.astype(jnp.float32))
+        scores = self.logits_processor(input_ids, logits_t)
+        scores = torch_xla2.tensor.t2j(scores).to_device(logits.device)
+
         if self.mode == GenerationMode.SAMPLE:
             # split the key to avoid reusing the same key for multiple samples
             subkey, self.key = jax.random.split(self.key)
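Note on the select() change above: the Hugging Face logits processors run in PyTorch, so the JAX logits are cast to float32, copied to a torch tensor with torch_xla2.tensor.j2t, processed, and copied back to JAX with t2j. A standalone sketch of the same round-trip, assuming torch_xla2 and a recent JAX are installed (the processor, shapes, and vocabulary size are illustrative only):

    import jax.numpy as jnp
    import torch_xla2
    from transformers import LogitsProcessorList, TemperatureLogitsWarper

    processors = LogitsProcessorList([TemperatureLogitsWarper(0.7)])  # illustrative processor
    input_ids = jnp.zeros((1, 8), dtype=jnp.int32)      # illustrative prompt ids
    logits = jnp.zeros((1, 32000), dtype=jnp.bfloat16)  # illustrative (batch, vocab) logits

    # Cast to float32 and copy JAX -> torch so the PyTorch-based processors can run,
    logits_t = torch_xla2.tensor.j2t(logits.astype(jnp.float32))
    scores = processors(input_ids, logits_t)
    # then copy torch -> JAX and place the result back on the device of the original logits.
    scores = torch_xla2.tensor.t2j(scores).to_device(logits.device)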