diff --git a/.github/workflows/base.yml b/.github/workflows/base.yml
index 9101fc2bea..c94093bc97 100644
--- a/.github/workflows/base.yml
+++ b/.github/workflows/base.yml
@@ -28,7 +28,19 @@ jobs:
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
- pytorch: 2.4.0
+ pytorch: 2.4.1
+ torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+ - cuda: "124"
+ cuda_version: 12.4.1
+ cudnn_version: ""
+ python_version: "3.11"
+ pytorch: 2.4.1
+ torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+ - cuda: "124"
+ cuda_version: 12.4.1
+ cudnn_version: ""
+ python_version: "3.11"
+ pytorch: 2.5.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
steps:
- name: Checkout
diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index 671be4b652..919cfd6545 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -6,7 +6,7 @@ on:
- '**.py'
- 'requirements.txt'
- '.github/workflows/*.yml'
- - "*.md"
+ - "*.[q]md"
- "examples/**/*.y[a]?ml"
workflow_dispatch:
diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 5a972f5f08..47a4c7f114 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -27,7 +27,12 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
- pytorch: 2.4.0
+ pytorch: 2.4.1
+ axolotl_extras:
+ - cuda: 124
+ cuda_version: 12.4.1
+ python_version: "3.11"
+ pytorch: 2.5.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
@@ -84,7 +89,12 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
- pytorch: 2.4.0
+ pytorch: 2.4.1
+ axolotl_extras:
+ - cuda: 124
+ cuda_version: 12.4.1
+ python_version: "3.11"
+ pytorch: 2.5.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
diff --git a/.github/workflows/multi-gpu-e2e.yml b/.github/workflows/multi-gpu-e2e.yml
index c854af9abe..d9f0ce7e6c 100644
--- a/.github/workflows/multi-gpu-e2e.yml
+++ b/.github/workflows/multi-gpu-e2e.yml
@@ -1,6 +1,9 @@
name: docker-multigpu-tests-biweekly
on:
+ pull_request:
+ paths:
+ - 'tests/e2e/multigpu/*.py'
workflow_dispatch:
schedule:
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
@@ -18,6 +21,20 @@ jobs:
pytorch: 2.3.1
axolotl_extras:
num_gpus: 2
+ - cuda: 124
+ cuda_version: 12.4.1
+ python_version: "3.11"
+ pytorch: 2.4.1
+ axolotl_extras:
+ num_gpus: 2
+ nightly_build: "true"
+ - cuda: 124
+ cuda_version: 12.4.1
+ python_version: "3.11"
+ pytorch: 2.5.0
+ axolotl_extras:
+ num_gpus: 2
+ nightly_build: "true"
runs-on: [self-hosted, modal]
timeout-minutes: 120
steps:
@@ -39,6 +56,7 @@ jobs:
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
+ echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.multigpu
diff --git a/.github/workflows/nightlies.yml b/.github/workflows/nightlies.yml
index 1d95a0983f..55123a9026 100644
--- a/.github/workflows/nightlies.yml
+++ b/.github/workflows/nightlies.yml
@@ -26,7 +26,12 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
- pytorch: 2.4.0
+ pytorch: 2.4.1
+ axolotl_extras:
+ - cuda: 124
+ cuda_version: 12.4.1
+ python_version: "3.11"
+ pytorch: 2.5.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
@@ -83,7 +88,12 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
- pytorch: 2.4.0
+ pytorch: 2.4.1
+ axolotl_extras:
+ - cuda: 124
+ cuda_version: 12.4.1
+ python_version: "3.11"
+ pytorch: 2.5.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml
index 885239d185..04dbc6385c 100644
--- a/.github/workflows/pypi.yml
+++ b/.github/workflows/pypi.yml
@@ -27,7 +27,7 @@ jobs:
run: |
pip3 install wheel packaging
pip3 install -e .
- pip3 install -r requirements-tests.txt
+ pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Extract tag name
id: tag
diff --git a/.github/workflows/tests-nightly.yml b/.github/workflows/tests-nightly.yml
new file mode 100644
index 0000000000..90b1e23cd2
--- /dev/null
+++ b/.github/workflows/tests-nightly.yml
@@ -0,0 +1,128 @@
+name: Tests Nightly against upstream main
+on:
+ workflow_dispatch:
+ schedule:
+ - cron: '0 0 * * *' # Runs at 00:00 UTC every day
+
+jobs:
+ pre-commit:
+ name: pre-commit
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v3
+ - uses: actions/setup-python@v4
+ with:
+ python-version: "3.10"
+ cache: 'pip' # caching pip dependencies
+ - uses: pre-commit/action@v3.0.0
+ env:
+ SKIP: no-commit-to-branch
+
+ pytest:
+ name: PyTest
+ runs-on: ubuntu-latest
+ strategy:
+ fail-fast: false
+ matrix:
+ python_version: ["3.10", "3.11"]
+ pytorch_version: ["2.3.1", "2.4.1", "2.5.0"]
+ timeout-minutes: 20
+
+ steps:
+ - name: Check out repository code
+ uses: actions/checkout@v3
+
+ - name: Setup Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python_version }}
+ cache: 'pip' # caching pip dependencies
+
+ - name: Install PyTorch
+ run: |
+ pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu
+
+ - name: Update requirements.txt
+ run: |
+ sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt
+ sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt
+ sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt
+ sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt
+
+ - name: Install dependencies
+ run: |
+ pip3 install --upgrade pip
+ pip3 install --upgrade packaging
+ pip3 install -U -e .
+ pip3 install -r requirements-dev.txt -r requirements-tests.txt
+
+ - name: Run tests
+ run: |
+ pytest --ignore=tests/e2e/ tests/
+
+ - name: cleanup pip cache
+ run: |
+ find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
+
+ docker-e2e-tests:
+ if: github.repository_owner == 'axolotl-ai-cloud'
+ # this job needs to be run on self-hosted GPU runners...
+ runs-on: [self-hosted, modal]
+ timeout-minutes: 60
+ needs: [pre-commit, pytest]
+
+ strategy:
+ fail-fast: false
+ matrix:
+ include:
+ - cuda: 121
+ cuda_version: 12.1.1
+ python_version: "3.10"
+ pytorch: 2.3.1
+ num_gpus: 1
+ axolotl_extras: mamba-ssm
+ nightly_build: "true"
+ - cuda: 121
+ cuda_version: 12.1.1
+ python_version: "3.11"
+ pytorch: 2.3.1
+ num_gpus: 1
+ axolotl_extras: mamba-ssm
+ nightly_build: "true"
+ - cuda: 124
+ cuda_version: 12.4.1
+ python_version: "3.11"
+ pytorch: 2.4.1
+ num_gpus: 1
+ axolotl_extras:
+ nightly_build: "true"
+ - cuda: 124
+ cuda_version: 12.4.1
+ python_version: "3.11"
+ pytorch: 2.5.0
+ num_gpus: 1
+ axolotl_extras:
+ nightly_build: "true"
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v4
+ - name: Install Python
+ uses: actions/setup-python@v5
+ with:
+ python-version: "3.10"
+ - name: Install Modal
+ run: |
+ python -m pip install --upgrade pip
+ pip install modal==0.63.64 jinja2
+ - name: Update env vars
+ run: |
+ echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
+ echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
+ echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
+ echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
+ echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
+ echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
+ echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
+ - name: Run tests job on Modal
+ run: |
+ modal run cicd.tests
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 74b4bcfbdb..ba50adfd35 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -36,6 +36,7 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.10", "3.11"]
+ pytorch_version: ["2.3.1", "2.4.1", "2.5.0"]
timeout-minutes: 20
steps:
@@ -48,12 +49,20 @@ jobs:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
- - name: Install dependencies
+ - name: upgrade pip
run: |
pip3 install --upgrade pip
- pip3 install --upgrade packaging
+ pip3 install --upgrade packaging setuptools wheel
+
+ - name: Install PyTorch
+ run: |
+ pip3 install torch==${{ matrix.pytorch_version }}
+
+ - name: Install dependencies
+ run: |
+ pip3 show torch
pip3 install -U -e .
- pip3 install -r requirements-tests.txt
+ pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Run tests
run: |
@@ -67,7 +76,7 @@ jobs:
if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
- timeout-minutes: 60
+ timeout-minutes: 90
needs: [pre-commit, pytest]
strategy:
@@ -89,7 +98,13 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
- pytorch: 2.4.0
+ pytorch: 2.4.1
+ num_gpus: 1
+ axolotl_extras:
+ - cuda: 124
+ cuda_version: 12.4.1
+ python_version: "3.11"
+ pytorch: 2.5.0
num_gpus: 1
axolotl_extras:
steps:
diff --git a/.isort.cfg b/.isort.cfg
index 79067a7c91..e487797321 100644
--- a/.isort.cfg
+++ b/.isort.cfg
@@ -1,3 +1,3 @@
[settings]
profile=black
-known_third_party=wandb
+known_third_party=wandb,comet_ml
diff --git a/.mypy.ini b/.mypy.ini
index ede9fef887..c6d837d3f2 100644
--- a/.mypy.ini
+++ b/.mypy.ini
@@ -11,6 +11,9 @@ ignore_errors = True
[mypy-axolotl.models.mixtral.*]
ignore_errors = True
+[mypy-axolotl.integrations.liger.models.*]
+ignore_errors = True
+
[mypy-axolotl.models.phi.*]
ignore_errors = True
diff --git a/README.md b/README.md
index a626635dc8..c12aa3bba0 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,9 @@
# Axolotl
+![tests](https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests.yml/badge.svg)
+![tests-nightly](https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg)
+![multigpu-semi-weekly tests](https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg)
+
Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.
Features:
@@ -7,10 +11,10 @@ Features:
- Supports fullfinetune, lora, qlora, relora, and gptq
- Customize configurations using a simple yaml file or CLI overwrite
- Load different dataset formats, use custom formats, or bring your own tokenized datasets
-- Integrated with xformer, flash attention, rope scaling, and multipacking
+- Integrated with xformer, flash attention, [liger kernel](https://github.com/linkedin/Liger-Kernel), rope scaling, and multipacking
- Works with single GPU or multiple GPUs via FSDP or Deepspeed
- Easily run with Docker locally or on the cloud
-- Log results and optionally checkpoints to wandb or mlflow
+- Log results and optionally checkpoints to wandb, mlflow or Comet
- And more!
@@ -22,39 +26,50 @@ Features:
## Table of Contents
-- [Introduction](#axolotl)
-- [Supported Features](#axolotl-supports)
-- [Quickstart](#quickstart-)
-- [Environment](#environment)
- - [Docker](#docker)
- - [Conda/Pip venv](#condapip-venv)
- - [Cloud GPU](#cloud-gpu) - Latitude.sh, JarvisLabs, RunPod
- - [Bare Metal Cloud GPU](#bare-metal-cloud-gpu)
- - [Windows](#windows)
- - [Mac](#mac)
- - [Google Colab](#google-colab)
- - [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
- - [Launching on public clouds via dstack](#launching-on-public-clouds-via-dstack)
-- [Dataset](#dataset)
-- [Config](#config)
- - [Train](#train)
- - [Inference](#inference-playground)
- - [Merge LORA to Base](#merge-lora-to-base)
- - [Special Tokens](#special-tokens)
- - [All Config Options](#all-config-options)
-- Advanced Topics
- - [Multipack](./docs/multipack.qmd)
- - [RLHF & DPO](./docs/rlhf.qmd)
- - [Dataset Pre-Processing](./docs/dataset_preprocessing.qmd)
- - [Unsloth](./docs/unsloth.qmd)
-- [Common Errors](#common-errors-)
- - [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training)
-- [Debugging Axolotl](#debugging-axolotl)
-- [Need Help?](#need-help-)
-- [Badge](#badge-)
-- [Community Showcase](#community-showcase)
-- [Contributing](#contributing-)
-- [Sponsors](#sponsors-)
+- [Axolotl](#axolotl)
+ - [Table of Contents](#table-of-contents)
+ - [Axolotl supports](#axolotl-supports)
+ - [Quickstart ⚡](#quickstart-)
+ - [Usage](#usage)
+ - [Advanced Setup](#advanced-setup)
+ - [Environment](#environment)
+ - [Docker](#docker)
+ - [Conda/Pip venv](#condapip-venv)
+ - [Cloud GPU](#cloud-gpu)
+ - [Bare Metal Cloud GPU](#bare-metal-cloud-gpu)
+ - [LambdaLabs](#lambdalabs)
+ - [GCP](#gcp)
+ - [Windows](#windows)
+ - [Mac](#mac)
+ - [Google Colab](#google-colab)
+ - [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
+ - [Launching on public clouds via dstack](#launching-on-public-clouds-via-dstack)
+ - [Dataset](#dataset)
+ - [Config](#config)
+ - [All Config Options](#all-config-options)
+ - [Train](#train)
+ - [Preprocess dataset](#preprocess-dataset)
+ - [Multi-GPU](#multi-gpu)
+ - [DeepSpeed](#deepspeed)
+ - [FSDP](#fsdp)
+ - [FSDP + QLoRA](#fsdp--qlora)
+ - [Weights \& Biases Logging](#weights--biases-logging)
+ - [Special Tokens](#special-tokens)
+ - [Liger Kernel](#liger-kernel)
+ - [Inference Playground](#inference-playground)
+ - [Merge LORA to base](#merge-lora-to-base)
+ - [Common Errors 🧰](#common-errors-)
+ - [Tokenization Mismatch b/w Inference \& Training](#tokenization-mismatch-bw-inference--training)
+ - [Debugging Axolotl](#debugging-axolotl)
+ - [Need help? 🙋](#need-help-)
+ - [Badge ❤🏷️](#badge-️)
+ - [Community Showcase](#community-showcase)
+ - [Contributing 🤝](#contributing-)
+ - [Sponsors 🤝❤](#sponsors-)
+ - [💎 Diamond Sponsors - Contact directly](#-diamond-sponsors---contact-directly)
+ - [🥇 Gold Sponsors - $5000/mo](#-gold-sponsors---5000mo)
+ - [🥈 Silver Sponsors - $1000/mo](#-silver-sponsors---1000mo)
+ - [🥉 Bronze Sponsors - $500/mo](#-bronze-sponsors---500mo)
|
@@ -96,6 +111,7 @@ Features:
| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
| Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| Gemma | ✅ | ✅ | ✅ | ❓ | ❓ | ✅ | ❓ |
+| Jamba | ✅ | ✅ | ✅ | ❓ | ❓ | ✅ | ❓ |
✅: supported
❌: not supported
@@ -105,7 +121,7 @@ Features:
Get started with Axolotl in just a few steps! This quickstart guide will walk you through setting up and running a basic fine-tuning task.
-**Requirements**: Python >=3.10 and Pytorch >=2.1.1.
+**Requirements**: Nvidia GPU (Ampere architecture or newer for `bf16` and Flash Attention), Python >=3.10 and PyTorch >=2.3.1.
```bash
git clone https://github.com/axolotl-ai-cloud/axolotl
@@ -367,7 +383,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- typescript
type: ... # unimplemented custom format
- # fastchat conversation
+ # fastchat conversation (deprecation soon, use chat_template https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template)
# See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
- path: ...
type: sharegpt
@@ -499,6 +515,22 @@ wandb_name:
wandb_log_model:
```
+##### Comet Logging
+
+Make sure your `COMET_API_KEY` environment variable is set (recommended) or you login to wandb with `comet login`.
+
+- wandb options
+```yaml
+use_comet:
+comet_api_key:
+comet_workspace:
+comet_project_name:
+comet_experiment_key:
+comet_mode:
+comet_online:
+comet_experiment_config:
+```
+
##### Special Tokens
It is important to have special tokens like delimiters, end-of-sequence, beginning-of-sequence in your tokenizer's vocabulary. This will help you avoid tokenization issues and help your model train better. You can do this in axolotl like this:
@@ -515,6 +547,25 @@ tokens: # these are delimiters
When you include these tokens in your axolotl config, axolotl adds these tokens to the tokenizer's vocabulary.
+##### Liger Kernel
+
+Liger Kernel: Efficient Triton Kernels for LLM Training
+
+https://github.com/linkedin/Liger-Kernel
+
+Liger (LinkedIn GPU Efficient Runtime) Kernel is a collection of Triton kernels designed specifically for LLM training.
+It can effectively increase multi-GPU training throughput by 20% and reduces memory usage by 60%. The Liger Kernel
+composes well and is compatible with both FSDP and Deepspeed.
+
+```yaml
+plugins:
+ - axolotl.integrations.liger.LigerPlugin
+liger_rope: true
+liger_rms_norm: true
+liger_swiglu: true
+liger_fused_linear_cross_entropy: true
+```
+
### Inference Playground
Axolotl allows you to load your model in an interactive terminal playground for quick experimentation.
diff --git a/_quarto.yml b/_quarto.yml
index 6b2eed971b..acb4872589 100644
--- a/_quarto.yml
+++ b/_quarto.yml
@@ -37,6 +37,7 @@ website:
- docs/mac.qmd
- docs/multi-node.qmd
- docs/unsloth.qmd
+ - docs/amd_hpc.qmd
- section: "Dataset Formats"
contents: docs/dataset-formats/*
- section: "Reference"
diff --git a/cicd/Dockerfile.jinja b/cicd/Dockerfile.jinja
index 3a79883667..8ce6550056 100644
--- a/cicd/Dockerfile.jinja
+++ b/cicd/Dockerfile.jinja
@@ -8,6 +8,7 @@ ENV BNB_CUDA_VERSION="{{ CUDA }}"
ENV PYTORCH_VERSION="{{ PYTORCH_VERSION }}"
ENV GITHUB_REF="{{ GITHUB_REF }}"
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
+ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
@@ -22,7 +23,13 @@ RUN git fetch origin +$GITHUB_REF && \
git checkout FETCH_HEAD
# If AXOLOTL_EXTRAS is set, append it in brackets
-RUN pip install causal_conv1d
+RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
+ sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
+ sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
+ sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
+ sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
+ fi
+
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
@@ -30,7 +37,7 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
fi
# So we can test the Docker image
-RUN pip install -r requirements-tests.txt
+RUN pip install -r requirements-dev.txt -r requirements-tests.txt
# fix so that git fetch/pull from remote works
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
diff --git a/cicd/cicd.sh b/cicd/cicd.sh
index eceda9b375..483d62a7ad 100755
--- a/cicd/cicd.sh
+++ b/cicd/cicd.sh
@@ -1,6 +1,6 @@
#!/bin/bash
set -e
-pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
-pytest -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/
-pytest --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ /workspace/axolotl/tests/e2e/
+pytest -n4 --ignore=tests/e2e/ /workspace/axolotl/tests/
+pytest -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
+pytest --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
diff --git a/cicd/multigpu.py b/cicd/multigpu.py
index be10fbc73a..da726b4731 100644
--- a/cicd/multigpu.py
+++ b/cicd/multigpu.py
@@ -64,7 +64,7 @@ def run_cmd(cmd: str, run_folder: str):
@stub.function(
image=cicd_image,
gpu=GPU_CONFIG,
- timeout=45 * 60,
+ timeout=60 * 60,
cpu=8.0,
memory=131072 * N_GPUS,
)
diff --git a/cicd/tests.py b/cicd/tests.py
index c214676378..9ebce9815f 100644
--- a/cicd/tests.py
+++ b/cicd/tests.py
@@ -28,6 +28,7 @@
"CUDA": os.environ.get("CUDA", "121"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
+ "NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
}
dockerfile_contents = df_template.render(**df_args)
@@ -64,7 +65,7 @@ def run_cmd(cmd: str, run_folder: str):
@stub.function(
image=cicd_image,
gpu=GPU_CONFIG,
- timeout=45 * 60,
+ timeout=60 * 60,
cpu=8.0,
memory=131072,
)
diff --git a/deepspeed_configs/zero3_bf16.json b/deepspeed_configs/zero3_bf16.json
index 16e64d76b4..49fb757552 100644
--- a/deepspeed_configs/zero3_bf16.json
+++ b/deepspeed_configs/zero3_bf16.json
@@ -14,15 +14,6 @@
"bf16": {
"enabled": true
},
- "fp16": {
- "enabled": "auto",
- "auto_cast": false,
- "loss_scale": 0,
- "initial_scale_power": 32,
- "loss_scale_window": 1000,
- "hysteresis": 2,
- "min_loss_scale": 1
- },
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",
diff --git a/deepspeed_configs/zero3_bf16_cpuoffload_all.json b/deepspeed_configs/zero3_bf16_cpuoffload_all.json
index 09ca6785b2..3ccc66db48 100644
--- a/deepspeed_configs/zero3_bf16_cpuoffload_all.json
+++ b/deepspeed_configs/zero3_bf16_cpuoffload_all.json
@@ -24,15 +24,6 @@
"bf16": {
"enabled": true
},
- "fp16": {
- "enabled": "auto",
- "auto_cast": false,
- "loss_scale": 0,
- "initial_scale_power": 32,
- "loss_scale_window": 1000,
- "hysteresis": 2,
- "min_loss_scale": 1
- },
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",
diff --git a/deepspeed_configs/zero3_bf16_cpuoffload_params.json b/deepspeed_configs/zero3_bf16_cpuoffload_params.json
index 41d4a21323..fe21d35f88 100644
--- a/deepspeed_configs/zero3_bf16_cpuoffload_params.json
+++ b/deepspeed_configs/zero3_bf16_cpuoffload_params.json
@@ -20,15 +20,6 @@
"bf16": {
"enabled": true
},
- "fp16": {
- "enabled": "auto",
- "auto_cast": false,
- "loss_scale": 0,
- "initial_scale_power": 32,
- "loss_scale_window": 1000,
- "hysteresis": 2,
- "min_loss_scale": 1
- },
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",
diff --git a/devtools/dev_sharegpt.yml b/devtools/dev_chat_template.yml
similarity index 92%
rename from devtools/dev_sharegpt.yml
rename to devtools/dev_chat_template.yml
index 9c65b49dcd..9697da4b33 100644
--- a/devtools/dev_sharegpt.yml
+++ b/devtools/dev_chat_template.yml
@@ -7,8 +7,8 @@ load_in_8bit: true
load_in_4bit: false
datasets:
- - path: philschmid/guanaco-sharegpt-style
- type: sharegpt
+ - path: fozziethebeat/alpaca_messages_2k_test
+ type: chat_template
shards: 10
val_set_size: 0
output_dir: temp_debug/axolotl_outputs/model
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 2b106f1ed8..4872b3907c 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -20,7 +20,6 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets
-RUN pip install causal_conv1d
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
diff --git a/docs/amd_hpc.qmd b/docs/amd_hpc.qmd
new file mode 100644
index 0000000000..d1c274e15a
--- /dev/null
+++ b/docs/amd_hpc.qmd
@@ -0,0 +1,108 @@
+---
+title: Training with AMD GPUs on HPC Systems
+description: A comprehensive guide for using Axolotl on distributed systems with AMD GPUs
+---
+
+This guide provides step-by-step instructions for installing and configuring Axolotl on a High-Performance Computing (HPC) environment equipped with AMD GPUs.
+
+## Setup
+
+### 1. Install Python
+
+We recommend using Miniforge, a minimal conda-based Python distribution:
+
+```bash
+curl -L -O "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh"
+bash Miniforge3-$(uname)-$(uname -m).sh
+```
+
+### 2. Configure Python Environment
+Add Python to your PATH and ensure it's available at login:
+
+```bash
+echo 'export PATH=~/miniforge3/bin:$PATH' >> ~/.bashrc
+echo 'if [ -f ~/.bashrc ]; then . ~/.bashrc; fi' >> ~/.bash_profile
+```
+
+### 3. Load AMD GPU Software
+
+Load the ROCm module:
+
+```bash
+module load rocm/5.7.1
+```
+
+Note: The specific module name and version may vary depending on your HPC system. Consult your system documentation for the correct module name.
+
+### 4. Install PyTorch
+
+Install PyTorch with ROCm support:
+
+```bash
+pip install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.7 --force-reinstall
+```
+
+### 5. Install Flash Attention
+
+Clone and install the Flash Attention repository:
+
+```bash
+git clone --recursive https://github.com/ROCmSoftwarePlatform/flash-attention.git
+export GPU_ARCHS="gfx90a"
+cd flash-attention
+export PYTHON_SITE_PACKAGES=$(python -c 'import site; print(site.getsitepackages()[0])')
+patch "${PYTHON_SITE_PACKAGES}/torch/utils/hipify/hipify_python.py" hipify_patch.patch
+pip install .
+```
+
+### 6. Install Axolotl
+
+Clone and install Axolotl:
+
+```bash
+git clone https://github.com/axolotl-ai-cloud/axolotl
+cd axolotl
+pip install packaging ninja
+pip install -e .
+```
+
+### 7. Apply xformers Workaround
+
+xformers appears to be incompatible with ROCm. Apply the following workarounds:
+ - Edit $HOME/packages/axolotl/src/axolotl/monkeypatch/llama_attn_hijack_flash.py modifying the code to always return `False` for SwiGLU availability from xformers.
+ - Edit $HOME/miniforge3/lib/python3.10/site-packages/xformers/ops/swiglu_op.py replacing the "SwiGLU" function with a pass statement.
+
+### 8. Prepare Job Submission Script
+
+Create a script for job submission using your HPC's particular software (e.g. Slurm, PBS). Include necessary environment setup and the command to run Axolotl training. If the compute node(s) do(es) not have internet access, it is recommended to include
+
+```bash
+export TRANSFORMERS_OFFLINE=1
+export HF_DATASETS_OFFLINE=1
+```
+
+### 9. Download Base Model
+
+Download a base model using the Hugging Face CLI:
+
+```bash
+huggingface-cli download meta-llama/Meta-Llama-3.1-8B --local-dir ~/hfdata/llama3.1-8B
+```
+
+### 10. Create Axolotl Configuration
+
+Create an Axolotl configuration file (YAML format) tailored to your specific training requirements and dataset. Use FSDP for multi-node training.
+
+Note: Deepspeed did not work at the time of testing. However, if anyone managed to get it working, please let us know.
+
+### 11. Preprocess Data
+
+Run preprocessing on the login node:
+
+```bash
+CUDA_VISIBLE_DEVICES="" python -m axolotl.cli.preprocess /path/to/your/config.yaml
+```
+
+### 12. Train
+
+You are now ready to submit your previously prepared job script. 🚂
diff --git a/docs/config.qmd b/docs/config.qmd
index e859999787..a7bf9080bf 100644
--- a/docs/config.qmd
+++ b/docs/config.qmd
@@ -83,13 +83,14 @@ lora_on_cpu: true
datasets:
# HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
- path: vicgalle/alpaca-gpt4
- # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
+ # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
type: alpaca # format | format: (chat/instruct) | .load_
ds_type: # Optional[str] (json|arrow|parquet|text|csv) defines the datatype when path is a file
data_files: # Optional[str] path to source data files
shards: # Optional[int] number of shards to split data into
name: # Optional[str] name of dataset configuration to load
train_on_split: train # Optional[str] name of dataset split to load from
+ revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.
# Optional[str] fastchat conversation type, only used with type: sharegpt
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
@@ -123,6 +124,48 @@ datasets:
# For `completion` datsets only, uses the provided field instead of `text` column
field:
+ # Using chat template
+ - path: ...
+ # Set type to `chat_template` to use this strategy
+ type: chat_template
+ # Specify the name of the chat template to use
+ # The name of the chat template to use for training, following values are supported:
+ # - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default.
+ # - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py
+ # - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to if the tokenizer does not have a chat template else default to tokenizer. E.g. tokenizer_default_fallback_chatml.
+ # - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field.
+ chat_template: tokenizer_default
+ # Custom jinja template for chat template. This will be only used if `chat_template` is set to `jinja` or empty (in which case chat_template is automatically set to `jinja`).
+ chat_template_jinja:
+ # The key in the data example that contains the messages. Default is "messages".
+ field_messages: messages
+ # The key in the message turn that contains the role. Default is "role".
+ message_field_role: role
+ # The key in the message turn that contains the content. Default is "content".
+ message_field_content: content
+ # Optional[Dict[str, List]]. Roles mapping for the messages.
+ roles:
+ user: ["human", "user"]
+ assistant: ["gpt", "assistant", "ai"]
+ system: ["system"]
+
+ ## NOTE: Leaving the below empty will default to using the simple legacy tokenization strategy where only last message is trained on.
+
+ # Optional[List[str]]. Roles to train on. The tokens from these roles will be considered for the loss.
+ roles_to_train: ["gpt", "assistant"]
+ # Optional[str]. Which EOS tokens to train on in the conversation. Possible values are:
+ # - all: train on all EOS tokens
+ # - turn: train on the EOS token at the end of each trainable turn
+ # - last: train on the last EOS token in the conversation
+ train_on_eos: last
+ # The key in the message turn that indicates via boolean whether tokens of a turn should be considered for training. Useful to selectively train on certain turns besides the `roles_to_train`.
+ message_field_training: training
+ # The key in the message turn that contains the training details. Useful to selectively train on certain tokens in a turn.
+ # The value of the key is a List[Dict] containing `begin_offset` (start character index in content), `end_offset` (end character index in content), and `train` (boolean whether to train).
+ # See example at `docs/dataset-formats/conversation.qmd`
+ message_field_training_detail: train_detail
+
+
# If false, the datasets will not be shuffled and will keep their original order in `datasets`.
# The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true.
shuffle_merged_datasets: true
@@ -141,9 +184,16 @@ test_datasets:
# use RL training: 'dpo', 'ipo', 'kto'
rl:
-# Saves the desired chat template to the tokenizer_config.json for easier inferencing
-# Currently supports chatml and inst (mistral/mixtral)
-chat_template: chatml
+# The name of the chat template to use for training, following values are supported:
+# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value.
+# - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py
+# - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to. E.g. tokenizer_default_fallback_chatml. This is useful when the chat template is not available in the tokenizer.
+# - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field.
+# The selected chat template will be saved to the tokenizer_config.json for easier inferencing
+# Note: It is recommended to set train_on_inputs to true when using a chat template that is different from the model's default chat template.
+chat_template: tokenizer_default
+# custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null.
+chat_template_jinja: null
# Changes the default system message
default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml.
# Axolotl attempts to save the dataset as an arrow after packing the data together so
@@ -265,8 +315,21 @@ wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_step
# mlflow configuration if you're using it
mlflow_tracking_uri: # URI to mlflow
mlflow_experiment_name: # Your experiment name
+mlflow_run_name: # Your run name
hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry
+# Comet configuration if you're using it
+# Make sure your `COMET_API_KEY` environment variable is set (recommended) or you login to Comet with `comet login`.
+# Check out our documentation for more details https://www.comet.com/docs/v2/api-and-sdk/python-sdk/reference/Experiment-Creation/#comet_ml.start
+use_comet: # Enable or disable Comet integration.
+comet_api_key: # API key for Comet. Recommended to set via `comet login`.
+comet_workspace: # Workspace name in Comet. Defaults to the user's default workspace.
+comet_project_name: # Project name in Comet. Defaults to Uncategorized.
+comet_experiment_key: # Identifier for the experiment. Used to append data to an existing experiment or control the key of new experiments. Default to a random key.
+comet_mode: # Create a new experiment ("create") or log to an existing one ("get"). Default ("get_or_create") auto-selects based on configuration.
+comet_online: # Set to True to log data to Comet server, or False for offline storage. Default is True.
+comet_experiment_config: # Dictionary for additional configuration settings, see the doc for more details.
+
# Where to save the full-finetuned model to
output_dir: ./completed-model
@@ -301,7 +364,7 @@ max_steps:
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
-eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", chrf]
+eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"]
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
diff --git a/docs/dataset-formats/conversation.qmd b/docs/dataset-formats/conversation.qmd
index 28d13c987c..c7273c5be5 100644
--- a/docs/dataset-formats/conversation.qmd
+++ b/docs/dataset-formats/conversation.qmd
@@ -6,6 +6,8 @@ order: 3
## sharegpt
+UPDATE: ShareGPT is being deprecated in the next release. Please see `chat_template` section below.
+
conversations where `from` is `human`/`gpt`. (optional: first row with role `system` to override default system prompt)
```{.json filename="data.jsonl"}
@@ -69,3 +71,138 @@ creates a chat where bot is asked to tell a joke, then explain why the joke is f
```{.json filename="data.jsonl"}
{"conversations": [{"title": "...", "text": "...", "explanation": "..."}]}
```
+
+
+## chat_template
+
+Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. Support using tokenizer's template, a supported template, or custom jinja2.
+
+```{.json filename="data.jsonl"}
+{"conversations": [{"role": "...", "content": "..."}]}
+```
+
+See `config.qmd` for full configs and supported templates.
+
+### Migrating from sharegpt
+
+Most configs can be adapted as follows:
+
+```yaml
+# old
+chat_template: chatml
+datasets:
+ - path: ...
+ type: sharegpt
+ conversation: chatml
+
+# new (if using tokenizer's chat_template)
+datasets:
+ - path: ...
+ type: chat_template
+
+ field_messages: conversations
+ message_field_role: from
+ message_field_content: value
+
+# new (if setting a new chat_template like chatml, gemma, etc)
+chat_template: chatml
+datasets:
+ - path: ...
+ type: chat_template
+
+ field_messages: conversations
+ message_field_role: from
+ message_field_content: value
+```
+
+We recommend checking the below examples for other usecases.
+
+### Examples
+
+1. Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message.
+
+```yaml
+datasets:
+ - path: ...
+ type: chat_template
+```
+
+2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
+
+```yaml
+chat_template: gemma # this overwrites the tokenizer's chat_template
+datasets:
+ - path: ...
+ type: chat_template
+ roles_to_train: ["assistant"]
+```
+
+3. Using the tokenizer_config.json's chat template or `chatml` as fallback if the former's chat template does not exist, on OpenAI messages format, training on all assistant messages.
+
+```yaml
+chat_template: tokenizer_default_fallback_chatml # this overwrites the tokenizer's chat_template
+datasets:
+ - path: ...
+ type: chat_template
+ roles_to_train: ["assistant"]
+```
+
+4. Using a custom jinja template on OpenAI messages format, training on all assistant messages.
+
+```yaml
+# chat_template: jinja # `jinja` will be implied if the `chat_template_jinja` is set and this field is empty
+chat_template_jinja: "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}"
+
+datasets:
+ - path: ...
+ type: chat_template
+ roles_to_train: ["assistant"]
+```
+
+5. (Advanced) Using fine-grained control over tokens and turns to train in a conversation
+
+For a data sample that looks like:
+
+```{.json filename="data.jsonl"}
+{
+ "conversations": [
+ {"from": "system", "value": "You are an AI assistant.", "train": false},
+ {"from": "human", "value": "Hello", "train": false},
+ {"from": "assistant", "value": "Hello", "train": true},
+ {"from": "human", "value": "How are you?", "train": true},
+ {
+ "from": "assistant",
+ "value": "I'm doing very well, thank you!",
+ "train_detail": [
+ {"begin_offset": 0, "end_offset": 8, "train": false},
+ {"begin_offset": 9, "end_offset": 18, "train": true},
+ {"begin_offset": 19, "end_offset": 30, "train": false},
+ ],
+ },
+ {
+ "from": "human",
+ "value": "I'm doing very well, thank you!",
+ "train": true,
+ },
+ {"from": "assistant", "value": "Hi there!", "train": true}
+ ]
+}
+```
+
+The configuration would look like:
+
+```yaml
+datasets:
+ - path: ...
+ type: chat_template
+ chat_template: tokenizer_default
+ field_messages: conversations
+ message_field_role: from
+ message_field_content: value
+ roles_to_train: []
+ train_on_eos: turn
+ message_field_training: train
+ message_field_training_detail: train_detail
+```
+
+Tip: It is not necessary to use both `message_field_training` and `message_field_training_detail` at a time.
diff --git a/docs/dataset-formats/tokenized.qmd b/docs/dataset-formats/tokenized.qmd
index b2ea003c02..61028cae7f 100644
--- a/docs/dataset-formats/tokenized.qmd
+++ b/docs/dataset-formats/tokenized.qmd
@@ -7,7 +7,7 @@ order: 5
- Pass an empty `type:` in your axolotl config.
- Columns in Dataset must be exactly `input_ids`, `attention_mask`, `labels`
- To indicate that a token should be ignored during training, set its corresponding label to `-100`.
-- Do not add BOS/EOS. Axolotl will add them for you based on the default tokenizer for the model you're using.
+- You must add BOS and EOS, and make sure that you are training on EOS by not setting its label to -100.
- For pretraining, do not truncate/pad documents to the context window length.
- For instruction training, documents must be truncated/padded as desired.
diff --git a/docs/debugging.qmd b/docs/debugging.qmd
index 1d0779b073..029549d85b 100644
--- a/docs/debugging.qmd
+++ b/docs/debugging.qmd
@@ -51,12 +51,12 @@ While debugging it's helpful to simplify your test scenario as much as possible.
### Background
-The below example shows how to configure VSCode to debug data preprocessing of the `sharegpt` format. This is the format used when you have the following in your axolotl config:
+The below example shows how to configure VSCode to debug data preprocessing of the `chat_template` format. This is the format used when you have the following in your axolotl config:
```yaml
datasets:
- - path: # example on HF Hub: philschmid/guanaco-sharegpt-style
- type: sharegpt
+ - path: # example on HF Hub: fozziethebeat/alpaca_messages_2k_test
+ type: chat_template
```
>[!Important]
@@ -83,7 +83,7 @@ If you developing on a remote host, you can easily use VSCode to debug remotely.
The easiest way to get started is to modify the [.vscode/launch.json](../.vscode/launch.json) file in this project. This is just an example configuration, so you may need to modify or copy it to suit your needs.
-For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train dev_sharegpt.yml`, you would use the below configuration[^1]. Note that we add additional flags that override the axolotl config and incorporate the tips above (see the comments). We also set the working directory to `devtools` and set the `env` variable `HF_HOME` to a temporary folder that is later partially deleted. This is because we want to delete the HF dataset cache before each run in order to ensure that the data preprocessing code is run from scratch.
+For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train dev_chat_template.yml`, you would use the below configuration[^1]. Note that we add additional flags that override the axolotl config and incorporate the tips above (see the comments). We also set the working directory to `devtools` and set the `env` variable `HF_HOME` to a temporary folder that is later partially deleted. This is because we want to delete the HF dataset cache before each run in order to ensure that the data preprocessing code is run from scratch.
```jsonc
// .vscode/launch.json
@@ -91,12 +91,12 @@ For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 acceler
"version": "0.2.0",
"configurations": [
{
- "name": "Debug axolotl prompt - sharegpt",
+ "name": "Debug axolotl prompt - chat_template",
"type": "python",
"module": "accelerate.commands.launch",
"request": "launch",
"args": [
- "-m", "axolotl.cli.train", "dev_sharegpt.yml",
+ "-m", "axolotl.cli.train", "dev_chat_template.yml",
// The flags below simplify debugging by overriding the axolotl config
// with the debugging tips above. Modify as needed.
"--dataset_processes=1", // limits data preprocessing to one process
@@ -240,6 +240,6 @@ style="border-radius: 10px; display: block; margin: auto;" width="560" height="3
-[^1]: The config actually mimics the command `CUDA_VISIBLE_DEVICES=0 python -m accelerate.commands.launch -m axolotl.cli.train devtools/sharegpt.yml`, but this is the same thing.
+[^1]: The config actually mimics the command `CUDA_VISIBLE_DEVICES=0 python -m accelerate.commands.launch -m axolotl.cli.train devtools/chat_template.yml`, but this is the same thing.
[^2]: Many of the below flags are recommended best practices by Nvidia when using nvidia-container-toolkit. You can read more about these flags [here](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html).
diff --git a/docs/input_output.qmd b/docs/input_output.qmd
index 7715dd250d..6559578d18 100644
--- a/docs/input_output.qmd
+++ b/docs/input_output.qmd
@@ -205,7 +205,7 @@ ds = load_from_disk(f'last_run_prepared/{directory[0]}/')
hi there!. goodbye farewell
```
-We can check that the right tokens are ingored by comparing the labels
+We can check that the right tokens are ignored by comparing the labels
to each token:
```python
diff --git a/docs/multimodal.qmd b/docs/multimodal.qmd
new file mode 100644
index 0000000000..2381566adb
--- /dev/null
+++ b/docs/multimodal.qmd
@@ -0,0 +1,28 @@
+# MultiModal / Vision Language Models (BETA)
+
+### Supported Models
+
+- Mllama, i.e. llama with vision models
+
+### Usage
+
+Currently multimodal support is limited and doesn't have full feature parity. To finetune a multimodal Llama w/ LoRA,
+you'll need to use the following in YAML in combination with the rest of the required hyperparams.
+
+```yaml
+base_model: alpindale/Llama-3.2-11B-Vision-Instruct
+processor_type: AutoProcessor
+skip_prepare_dataset: true
+
+chat_template: llama3_2_vision
+datasets:
+ - path: HuggingFaceH4/llava-instruct-mix-vsft
+ type: chat_template
+ split: train[:1%]
+ field_messages: messages
+remove_unused_columns: false
+sample_packing: false
+
+# only finetune the Language model, leave the vision model and vision tower frozen
+lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
+```
diff --git a/docs/unsloth.qmd b/docs/unsloth.qmd
index 390609fd33..90cb49bafa 100644
--- a/docs/unsloth.qmd
+++ b/docs/unsloth.qmd
@@ -34,7 +34,7 @@ unsloth_lora_o: true
```
These options are composable and can be used with multi-gpu finetuning
-```
+```yaml
unsloth_cross_entropy_loss: true
unsloth_rms_norm: true
unsloth_rope: true
diff --git a/examples/deepseek-v2/fft-fsdp-16b.yaml b/examples/deepseek-v2/fft-fsdp-16b.yaml
new file mode 100644
index 0000000000..b55646df7f
--- /dev/null
+++ b/examples/deepseek-v2/fft-fsdp-16b.yaml
@@ -0,0 +1,67 @@
+base_model: deepseek-ai/DeepSeek-V2-Lite
+trust_remote_code: true
+
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+
+datasets:
+ - path: tatsu-lab/alpaca
+ type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.0
+output_dir: ./outputs/out
+
+sequence_len: 2048
+sample_packing: true
+pad_to_sequence_len: true
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 8
+micro_batch_size: 1
+num_epochs: 1
+optimizer: adamw_torch
+lr_scheduler: cosine
+learning_rate: 2e-5
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+ use_reentrant: false
+early_stopping_patience:
+resume_from_checkpoint:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_steps: 100
+evals_per_epoch: 2
+eval_table_size:
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+special_tokens:
+fsdp:
+ - full_shard
+ - auto_wrap
+fsdp_config:
+ fsdp_limit_all_gathers: true
+ fsdp_sync_module_states: true
+ fsdp_offload_params: true
+ fsdp_use_orig_params: false
+ fsdp_cpu_ram_efficient_loading: true
+ fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+ fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer
+ fsdp_state_dict_type: FULL_STATE_DICT
+ fsdp_sharding_strategy: FULL_SHARD
diff --git a/examples/deepseek-v2/qlora-fsdp-2_5.yaml b/examples/deepseek-v2/qlora-fsdp-2_5.yaml
new file mode 100644
index 0000000000..0320e02138
--- /dev/null
+++ b/examples/deepseek-v2/qlora-fsdp-2_5.yaml
@@ -0,0 +1,86 @@
+base_model: axolotl-quants/DeepSeek-V2.5-bnb-nf4-bf16
+trust_remote_code: true
+
+load_in_8bit: false
+load_in_4bit: true
+strict: false
+
+
+plugins:
+ - axolotl.integrations.liger.LigerPlugin
+liger_rms_norm: true
+liger_swiglu: true
+liger_fused_linear_cross_entropy: true
+
+chat_template: deepseek_v2
+datasets:
+ - path: mlabonne/FineTome-100k
+ type: chat_template
+ split: train[:20%]
+ field_messages: conversations
+ message_field_role: from
+ message_field_content: value
+
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.0
+output_dir: ./outputs/out
+
+sequence_len: 4096
+sample_packing: true
+pad_to_sequence_len: true
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+adapter: qlora
+lora_r: 256
+lora_alpha: 256
+lora_target_linear: true
+peft_use_rslora: true
+
+gradient_accumulation_steps: 1
+micro_batch_size: 8
+num_epochs: 1
+optimizer: adamw_torch
+lr_scheduler: cosine
+learning_rate: 2e-5
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+ use_reentrant: false
+early_stopping_patience:
+resume_from_checkpoint:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_steps: 100
+evals_per_epoch: 2
+eval_table_size:
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+special_tokens:
+fsdp:
+ - full_shard
+ - auto_wrap
+fsdp_config:
+ fsdp_limit_all_gathers: true
+ fsdp_sync_module_states: true
+ fsdp_offload_params: true
+ fsdp_use_orig_params: false
+ fsdp_cpu_ram_efficient_loading: true
+ fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+ fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer
+ fsdp_state_dict_type: FULL_STATE_DICT
+ fsdp_sharding_strategy: FULL_SHARD
diff --git a/examples/gemma2/qlora.yml b/examples/gemma2/qlora.yml
index b6dd653750..00e6d84e0d 100644
--- a/examples/gemma2/qlora.yml
+++ b/examples/gemma2/qlora.yml
@@ -11,8 +11,11 @@ chat_template: gemma
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
- chat_template: gemma
drop_system_message: true
+ field_messages: conversations
+ message_field_role: from
+ message_field_content: value
+
val_set_size: 0.0
output_dir: ./outputs/out
diff --git a/examples/gemma2/reward-model.yaml b/examples/gemma2/reward-model.yaml
new file mode 100644
index 0000000000..c1f993c3ae
--- /dev/null
+++ b/examples/gemma2/reward-model.yaml
@@ -0,0 +1,63 @@
+base_model: google/gemma-2-2b
+model_type: AutoModelForSequenceClassification
+tokenizer_type: AutoTokenizer
+
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+
+reward_model: true
+chat_template: gemma
+datasets:
+ - path: argilla/distilabel-intel-orca-dpo-pairs
+ type: bradley_terry.chat_template
+val_set_size: 0.0
+output_dir: ./outputs/out
+remove_unused_columns: false
+
+sequence_len: 2048
+sample_packing: false
+eval_sample_packing: false
+pad_to_sequence_len: true
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 4
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: true
+fp16:
+tf32: true
+
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+ use_reentrant: false
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_ratio: 0.1
+evals_per_epoch:
+eval_table_size:
+eval_max_new_tokens: 128
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
diff --git a/examples/jamba/README.md b/examples/jamba/README.md
index 54f5d1da9c..4c9dc85a06 100644
--- a/examples/jamba/README.md
+++ b/examples/jamba/README.md
@@ -6,5 +6,5 @@
- ✅ qlora w/ deepspeed Zero-3 needs at least 2x GPUs and 67GiB VRAM (wtf?)
- ✅ qlora single-gpu, ~51GiB VRAM
- ✅ multipack
-- ❓ FSDP
+- ✅ FSDP
- ❓ 8-bit LoRA
diff --git a/examples/jamba/qlora_fsdp_large.yaml b/examples/jamba/qlora_fsdp_large.yaml
new file mode 100644
index 0000000000..84cf906422
--- /dev/null
+++ b/examples/jamba/qlora_fsdp_large.yaml
@@ -0,0 +1,65 @@
+base_model: ai21labs/AI21-Jamba-1.5-Large
+tokenizer_type: AutoTokenizer
+
+load_in_4bit: true
+strict: false
+use_tensorboard: true
+chat_template: jamba
+datasets:
+ - path: cgato/SlimOrcaDedupCleaned
+ type: chat_template
+ drop_system_message: true
+ field_messages: conversations
+ message_field_role: from
+ message_field_content: value
+
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.0
+output_dir: jamba-large-fsdp-qlora-ft
+save_safetensors: true
+adapter: qlora
+sequence_len: 2048
+sample_packing: true
+pad_to_sequence_len: true
+
+lora_r: 16
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules: [down_proj,gate_proj,in_proj,k_proj,o_proj,out_proj,q_proj,up_proj,v_proj,x_proj]
+lora_target_linear: false
+
+gradient_accumulation_steps: 4
+micro_batch_size: 1
+num_epochs: 2
+optimizer: adamw_torch
+lr_scheduler: cosine
+learning_rate: 0.00001
+
+train_on_inputs: false
+group_by_length: false
+bf16: true
+tf32: true
+
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+ use_reentrant: true
+logging_steps: 1
+flash_attention: true
+
+warmup_steps: 10
+evals_per_epoch: 1
+saves_per_epoch: 1
+weight_decay: 0.0
+fsdp:
+ - full_shard
+ - auto_wrap
+fsdp_config:
+ fsdp_limit_all_gathers: true
+ fsdp_sync_module_states: true
+ fsdp_offload_params: false
+ fsdp_use_orig_params: false
+ fsdp_cpu_ram_efficient_loading: true
+ fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+ fsdp_transformer_layer_cls_to_wrap: JambaAttentionDecoderLayer,JambaMambaDecoderLayer
+ fsdp_state_dict_type: FULL_STATE_DICT
+ fsdp_sharding_strategy: FULL_SHARD
diff --git a/examples/llama-3-vision/lora-11b.yaml b/examples/llama-3-vision/lora-11b.yaml
new file mode 100644
index 0000000000..b2e4946418
--- /dev/null
+++ b/examples/llama-3-vision/lora-11b.yaml
@@ -0,0 +1,63 @@
+base_model: alpindale/Llama-3.2-11B-Vision-Instruct
+processor_type: AutoProcessor
+strict: false
+
+# these 3 lines are needed for now to handle vision chat templates w images
+skip_prepare_dataset: true
+remove_unused_columns: false
+sample_packing: false
+
+chat_template: llama3_2_vision
+datasets:
+ - path: HuggingFaceH4/llava-instruct-mix-vsft
+ type: chat_template
+ split: train[:1%]
+ field_messages: messages
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.0
+output_dir: ./outputs/out
+
+adapter: lora
+lora_model_dir:
+
+sequence_len: 8192
+pad_to_sequence_len: false
+
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 1
+num_epochs: 1
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: true
+fp16:
+tf32: true
+
+gradient_checkpointing: true
+local_rank:
+logging_steps: 1
+flash_attention: true
+eager_attention:
+
+warmup_ratio: 0.1
+evals_per_epoch: 1
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
diff --git a/examples/llama-3/fft-8b-liger-fsdp.yaml b/examples/llama-3/fft-8b-liger-fsdp.yaml
new file mode 100644
index 0000000000..99ba63fcc6
--- /dev/null
+++ b/examples/llama-3/fft-8b-liger-fsdp.yaml
@@ -0,0 +1,80 @@
+base_model: NousResearch/Meta-Llama-3.1-8B
+
+plugins:
+ - axolotl.integrations.liger.LigerPlugin
+liger_rope: true
+liger_rms_norm: true
+liger_swiglu: true
+liger_fused_linear_cross_entropy: true
+
+strict: false
+
+chat_template: llama3
+datasets:
+ - path: mlabonne/FineTome-100k
+ type: chat_template
+ split: train[:20%]
+ field_messages: conversations
+ message_field_role: from
+ message_field_content: value
+
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.02
+output_dir: ./outputs/out
+
+sequence_len: 4096
+sample_packing: true
+pad_to_sequence_len: true
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 1
+optimizer: adamw_torch
+lr_scheduler: cosine
+learning_rate: 2e-5
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+ use_reentrant: false
+early_stopping_patience:
+resume_from_checkpoint:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_steps: 100
+evals_per_epoch: 2
+eval_table_size:
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+ - full_shard
+ - auto_wrap
+fsdp_config:
+ fsdp_limit_all_gathers: true
+ fsdp_sync_module_states: true
+ fsdp_offload_params: true
+ fsdp_use_orig_params: false
+ fsdp_cpu_ram_efficient_loading: true
+ fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+ fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
+ fsdp_state_dict_type: FULL_STATE_DICT
+ fsdp_sharding_strategy: FULL_SHARD
+ fsdp_backward_prefetch: BACKWARD_PRE
+special_tokens:
+ pad_token: <|finetune_right_pad_id|>
+ eos_token: <|eot_id|>
diff --git a/examples/llama-3/fft-8b.yaml b/examples/llama-3/fft-8b.yaml
index 908ef6e035..335902aac7 100644
--- a/examples/llama-3/fft-8b.yaml
+++ b/examples/llama-3/fft-8b.yaml
@@ -1,6 +1,4 @@
-base_model: NousResearch/Meta-Llama-3-8B
-model_type: LlamaForCausalLM
-tokenizer_type: AutoTokenizer
+base_model: NousResearch/Meta-Llama-3.1-8B
load_in_8bit: false
load_in_4bit: false
diff --git a/examples/llama-3/instruct-dpo-lora-8b.yml b/examples/llama-3/instruct-dpo-lora-8b.yml
index 14febb810a..dc88350358 100644
--- a/examples/llama-3/instruct-dpo-lora-8b.yml
+++ b/examples/llama-3/instruct-dpo-lora-8b.yml
@@ -11,7 +11,6 @@ rl: dpo
datasets:
- path: fozziethebeat/alpaca_messages_2k_dpo_test
type: chat_template.default
- chat_template: llama3
field_messages: conversation
field_chosen: chosen
field_rejected: rejected
diff --git a/examples/llama-3/instruct-lora-8b.yml b/examples/llama-3/instruct-lora-8b.yml
index 4acad59999..ae9a8088c3 100644
--- a/examples/llama-3/instruct-lora-8b.yml
+++ b/examples/llama-3/instruct-lora-8b.yml
@@ -10,7 +10,6 @@ chat_template: llama3
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
- chat_template: llama3
field_messages: messages
message_field_role: role
message_field_content: content
diff --git a/examples/llama-3/qlora-1b.yml b/examples/llama-3/qlora-1b.yml
new file mode 100644
index 0000000000..fdfe4aa7c8
--- /dev/null
+++ b/examples/llama-3/qlora-1b.yml
@@ -0,0 +1,77 @@
+base_model: meta-llama/Llama-3.2-1B
+
+load_in_8bit: false
+load_in_4bit: true
+strict: false
+
+datasets:
+ - path: teknium/GPT4-LLM-Cleaned
+ type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.1
+output_dir: ./outputs/qlora-out
+
+adapter: qlora
+lora_model_dir:
+
+sequence_len: 2048
+sample_packing: true
+eval_sample_packing: true
+pad_to_sequence_len: true
+
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+lora_fan_in_fan_out:
+lora_target_modules:
+ - gate_proj
+ - down_proj
+ - up_proj
+ - q_proj
+ - v_proj
+ - k_proj
+ - o_proj
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 1
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+loss_watchdog_threshold: 5.0
+loss_watchdog_patience: 3
+
+warmup_steps: 10
+evals_per_epoch: 4
+eval_table_size:
+eval_max_new_tokens: 128
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
+ pad_token: "<|end_of_text|>"
diff --git a/examples/phi/lora-3.5.yaml b/examples/phi/lora-3.5.yaml
new file mode 100644
index 0000000000..246701148c
--- /dev/null
+++ b/examples/phi/lora-3.5.yaml
@@ -0,0 +1,75 @@
+base_model: microsoft/Phi-3.5-mini-instruct
+model_type: AutoModelForCausalLM
+tokenizer_type: AutoTokenizer
+
+load_in_8bit: true
+load_in_4bit: false
+strict: false
+
+chat_template: phi_3
+datasets:
+ - path: fozziethebeat/alpaca_messages_2k_test
+ type: chat_template
+ field_messages: messages
+ message_field_role: role
+ message_field_content: content
+ roles:
+ user:
+ - user
+ assistant:
+ - assistant
+
+dataset_prepared_path:
+val_set_size: 0.05
+output_dir: ./outputs/lora-out
+
+sequence_len: 4096
+sample_packing: false
+pad_to_sequence_len: true
+
+adapter: lora
+lora_model_dir:
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+lora_fan_in_fan_out:
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 4
+num_epochs: 2
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bfloat16: true
+bf16: true
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+s2_attention:
+
+warmup_steps: 10
+evals_per_epoch: 4
+eval_table_size:
+eval_max_new_tokens: 128
+saves_per_epoch: 4
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
diff --git a/examples/qwen2/qlora-fsdp.yaml b/examples/qwen2/qlora-fsdp.yaml
index 44f9c7e495..d61c72a378 100644
--- a/examples/qwen2/qlora-fsdp.yaml
+++ b/examples/qwen2/qlora-fsdp.yaml
@@ -72,4 +72,5 @@ fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
+ fsdp_sharding_strategy: FULL_SHARD
special_tokens:
diff --git a/examples/tiny-llama/pretrain.yml b/examples/tiny-llama/pretrain.yml
index e501dcb8e5..010a1608a3 100644
--- a/examples/tiny-llama/pretrain.yml
+++ b/examples/tiny-llama/pretrain.yml
@@ -9,9 +9,9 @@ strict: false
max_steps: 200
pretraining_dataset:
- path: c4
- name: en
- type: pretrain
+ - path: allenai/c4
+ name: en
+ type: pretrain
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./outputs/model-out
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 4b5df167b6..dcc729d1b2 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -2,3 +2,4 @@ pre-commit
black
mypy
types-requests
+tbparse
diff --git a/requirements.txt b/requirements.txt
index dc74b916f8..6bb1aa6848 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,12 +1,12 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
-peft==0.12.0
-transformers==4.44.0
-tokenizers>=0.19.1
-bitsandbytes==0.43.3
-accelerate==0.33.0
-datasets==2.20.0
-deepspeed==0.14.4
+peft==0.13.2
+transformers==4.46.0
+tokenizers>=0.20.1
+bitsandbytes==0.44.1
+accelerate==1.0.1
+datasets==3.0.1
+deepspeed==0.15.3
pydantic==2.6.3
addict
fire
@@ -16,12 +16,12 @@ flash-attn==2.6.3
sentencepiece
wandb
einops
-xformers==0.0.27
+xformers>=0.0.23.post1
optimum==1.16.2
hf_transfer
colorama
numba
-numpy>=1.24.4
+numpy>=1.24.4,<=2.0.1
# qlora things
evaluate==0.4.1
scipy
@@ -33,6 +33,8 @@ gradio==3.50.2
tensorboard
python-dotenv==1.0.1
autoawq>=0.2.5
+triton>=2.3.0
+liger-kernel==0.3.0
mamba-ssm==1.2.0.post1
@@ -41,6 +43,14 @@ s3fs>=2024.5.0
gcsfs>=2024.5.0
# adlfs
-trl==0.9.6
+trl @ git+https://github.com/huggingface/trl.git@31d02cfb795284591a084416b9dcb7bef5d08924
zstandard==0.22.0
fastcore
+
+# lm eval harness
+lm_eval==0.4.4
+langdetect==1.0.9
+immutabledict==4.2.0
+antlr4-python3-runtime==4.13.2
+
+torchao==0.5.0
diff --git a/requirements_env.txt b/requirements_env.txt
new file mode 100644
index 0000000000..f8acbf73c2
--- /dev/null
+++ b/requirements_env.txt
@@ -0,0 +1,315 @@
+accelerate==0.34.1
+addict==2.4.0
+aiofiles==23.2.1
+aiohttp==3.9.0
+aiosignal==1.3.1
+aiostream==0.5.2
+alembic==1.13.1
+annotated-types==0.6.0
+annoy==1.17.3
+ansible==6.7.0
+ansible-core==2.13.13
+ansible-vault==2.1.0
+anyio==3.7.1
+appdirs==1.4.4
+art==6.0
+asgiref==3.7.2
+async-timeout==4.0.2
+attrdict==2.0.1
+attrs==22.2.0
+awscli==1.32.75
+-e git+ssh://git@github.com/OpenAccess-AI-Collective/axolotl.git@6e354682e3c1735d3f7fb9e362280c38e922260f#egg=axolotl
+backoff==2.2.1
+base58==2.1.1
+beartype==0.17.2
+bitnet==0.2.1
+bitsandbytes==0.42.0
+bittensor==6.7.0
+black==23.7.0
+blinker==1.7.0
+boto3==1.34.75
+botocore==1.34.75
+cachetools==5.3.3
+cachy==0.1.1
+certifi==2023.7.22
+cffi==1.16.0
+cfgv==3.3.1
+chai-guanaco==1.2.4
+charset-normalizer==3.2.0
+cleo==0.6.8
+click==8.1.7
+cloudpickle==2.0.0
+cohere==4.11.2
+colorama==0.4.4
+coloredlogs==15.0.1
+CoLT5-attention==0.10.20
+contextlib2==21.6.0
+contourpy==1.2.0
+cryptography==41.0.3
+cycler==0.12.1
+cytoolz==0.12.3
+databricks-cli==0.18.0
+dataclasses-json==0.5.7
+datasets==2.11.0
+ddt==1.6.0
+decorator==5.1.1
+deepspeed==0.15.0
+# Editable Git install with no remote (dialogpt==0.1)
+-e /Users/wing/Projects/ml/dialogpt/src
+dill==0.3.6
+distlib==0.3.6
+docker==7.0.0
+docker-pycreds==0.4.0
+docstring-parser==0.15
+docutils==0.16
+ecdsa==0.18.0
+einops==0.7.0
+einops-exts==0.0.4
+einx==0.1.3
+entrypoints==0.4
+eth-hash==0.6.0
+eth-keys==0.5.0
+eth-typing==4.0.0
+eth-utils==2.3.1
+evaluate==0.4.0
+exceptiongroup==1.1.1
+fastapi==0.109.2
+fastcore==1.5.29
+ffmpy==0.4.0
+filelock==3.12.2
+-e git+https://github.com/NousResearch/finetuning-subnet.git@24e9407d6b4430a7ca39d344692f89ce5a97d27e#egg=finetuning_subnet
+fire==0.5.0
+first==2.0.2
+flake8==7.0.0
+Flask==3.0.1
+fonttools==4.47.2
+frozendict==2.4.1
+frozenlist==1.3.3
+fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe
+fsspec==2023.6.0
+fuzzywuzzy==0.18.0
+gitdb==4.0.10
+GitPython==3.1.31
+google-pasta==0.2.0
+gradio==4.42.0
+gradio_client==1.3.0
+greenlet==2.0.2
+grpclib==0.4.7
+gunicorn==21.2.0
+h11==0.14.0
+h2==4.1.0
+hpack==4.0.0
+httpcore==0.17.3
+httpx==0.24.1
+huggingface-hub==0.23.4
+humanfriendly==10.0
+hyperframe==6.0.1
+identify==2.5.24
+idna==3.4
+immutables==0.20
+importlib-metadata==6.7.0
+importlib-resources==6.1.1
+inflection==0.5.1
+iniconfig==2.0.0
+itsdangerous==2.1.2
+Jinja2==3.1.2
+jmespath==1.0.1
+joblib==1.3.2
+jsonlines==3.1.0
+jsonschema==2.6.0
+kiwisolver==1.4.5
+langchain==0.0.144
+Levenshtein==0.24.0
+libcst==1.1.0
+liger-kernel==0.0.0
+lion-pytorch==0.1.2
+llama-cpp-python==0.1.36
+llvmlite==0.40.1
+local-attention==1.9.0
+loguru==0.7.0
+Mako==1.3.2
+Markdown==3.5.2
+markdown-it-py==3.0.0
+markdown2==2.4.10
+MarkupSafe==2.1.2
+marshmallow==3.19.0
+marshmallow-enum==1.5.1
+matplotlib==3.8.2
+mccabe==0.7.0
+mdurl==0.1.2
+MEGABYTE-pytorch==0.0.7
+-e git+https://github.com/cg123/mergekit.git@53c5f414774a0558b8d84858fb6374bc93a8f1c1#egg=mergekit
+mlflow==2.10.0
+modal==0.62.77
+more-itertools==10.2.0
+mpmath==1.2.1
+msgpack==1.0.7
+msgpack-numpy-opentensor==0.5.0
+multidict==6.0.4
+multiprocess==0.70.14
+munch==2.5.0
+mypy==1.3.0
+mypy-extensions==1.0.0
+nest-asyncio==1.6.0
+netaddr==0.10.1
+networkx==3.0rc1
+nh3==0.2.14
+nodeenv==1.8.0
+nomic==2.0.2
+numba==0.57.1
+numexpr==2.8.4
+numpy==1.24.4
+oauthlib==3.2.2
+openai==0.27.4
+openapi==1.1.0
+openapi-schema-pydantic==1.2.4
+optimum==1.8.6
+orjson==3.10.7
+packaging==23.1
+pandas==2.0.0
+parameterized==0.9.0
+password-strength==0.0.3.post2
+pastel==0.1.1
+pathos==0.3.0
+pathspec==0.11.1
+pathtools==0.1.2
+peft==0.11.1
+pendulum==3.0.0
+Pillow==9.5.0
+pip-tools==1.11.0
+platformdirs==3.2.0
+pluggy==1.4.0
+poetry==0.7.1
+pox==0.3.2
+ppft==1.7.6.6
+pre-commit==3.3.2
+prettytable==3.10.0
+prompt-toolkit==3.0.39
+protobuf==3.20.2
+protobuf3-to-dict==0.1.5
+psutil==5.9.5
+psycopg==3.1.18
+PuLP==2.8.0
+py==1.11.0
+py-bip39-bindings==0.1.11
+py-cpuinfo==9.0.0
+py-ed25519-zebra-bindings==1.0.1
+py-sr25519-bindings==0.2.0
+pyarrow==11.0.0
+pyasn1==0.6.0
+pycodestyle==2.11.1
+pycparser==2.21
+pycryptodome==3.20.0
+pydantic==2.5.3
+pydantic_core==2.14.6
+pydub==0.25.1
+pyfiglet==0.8.post1
+pyflakes==3.2.0
+Pygments==2.15.1
+PyJWT==2.8.0
+pylev==1.4.0
+PyNaCl==1.5.0
+pynvml==11.5.0
+pyparsing==2.4.7
+pyrsistent==0.14.11
+pytest==8.0.2
+pytest-asyncio==0.23.4
+python-dateutil==2.8.2
+python-dotenv==1.0.1
+python-Levenshtein==0.24.0
+python-multipart==0.0.9
+pytz==2023.3
+PyYAML==6.0.1
+querystring-parser==1.2.4
+rapidfuzz==3.6.1
+regex==2023.6.3
+requests==2.31.0
+requests-toolbelt==0.8.0
+resolvelib==0.8.1
+responses==0.18.0
+retry==0.9.2
+rich==13.7.0
+rsa==4.7.2
+ruff==0.6.3
+s3transfer==0.10.1
+safetensors==0.4.5
+sagemaker==2.148.0
+scalecodec==1.2.7
+schedulefree==1.2.1
+schema==0.7.5
+scikit-learn==1.4.0
+scipy==1.9.3
+seaborn==0.13.2
+semantic-version==2.10.0
+sentencepiece==0.2.0
+sentry-sdk==1.19.1
+setproctitle==1.3.2
+shellingham==1.5.4
+shortuuid==1.0.11
+shtab==1.6.5
+sigtools==4.0.1
+six==1.16.0
+skypilot==0.4.1
+smdebug-rulesconfig==1.0.1
+smmap==5.0.0
+sniffio==1.3.0
+SQLAlchemy==1.4.47
+sqlparse==0.4.4
+starlette==0.36.3
+substrate-interface==1.5.2
+svgwrite==1.4.3
+sympy==1.11.1
+synchronicity==0.6.7
+tabulate==0.9.0
+tblib==1.7.0
+tenacity==8.2.2
+tensor-parallel==2.0.0
+termcolor==2.2.0
+text2art==0.2.0
+threadpoolctl==3.2.0
+tiktoken==0.6.0
+time-machine==2.14.1
+timm==0.9.16
+tokenizers==0.19.1
+tokenmonster==1.1.12
+toml==0.9.6
+tomli==2.0.1
+tomlkit==0.12.0
+toolz==0.12.1
+torch==2.2.0
+torchdata==0.6.1
+torchdiffeq==0.2.3
+TorchFix==0.4.0
+torchtext==0.15.2
+torchvision==0.17.0
+tqdm==4.66.2
+transformers==4.44.2
+trl==0.9.6
+typer==0.12.5
+types-certifi==2021.10.8.3
+types-requests==2.31.0.20240125
+types-setuptools==69.0.0.20240125
+types-toml==0.10.8.7
+typing==3.7.4.3
+typing-inspect==0.8.0
+typing_extensions==4.9.0
+tyro==0.5.18
+tzdata==2023.3
+unique-names-generator==1.0.2
+urllib3==2.2.2
+uvicorn==0.22.0
+vector_quantize_pytorch==1.14.1
+virtualenv==20.23.0
+voyager==2.0.2
+wandb==0.16.2
+watchfiles==0.21.0
+wavedrom==2.0.3.post3
+wcwidth==0.2.6
+websocket-client==1.7.0
+websockets==12.0
+Werkzeug==3.0.1
+wonderwords==2.2.0
+xxhash==3.2.0
+yarl==1.8.2
+zetascale==2.2.7
+zipp==3.15.0
diff --git a/scripts/chat_datasets.py b/scripts/chat_datasets.py
new file mode 100644
index 0000000000..5eb5bde1e2
--- /dev/null
+++ b/scripts/chat_datasets.py
@@ -0,0 +1,60 @@
+"""
+helper script to parse chat datasets into a usable yaml
+"""
+import click
+import yaml
+from datasets import load_dataset
+
+
+@click.command()
+@click.argument("dataset", type=str)
+@click.option("--split", type=str, default="train")
+def parse_dataset(dataset=None, split="train"):
+ ds_cfg = {}
+ ds_cfg["path"] = dataset
+ ds_cfg["split"] = split
+ ds_cfg["type"] = "chat_template"
+ ds_cfg["chat_template"] = "<<>>"
+
+ dataset = load_dataset(dataset, split=split)
+ features = dataset.features
+ feature_keys = features.keys()
+ field_messages = None
+ for key in ["conversation", "conversations", "messages"]:
+ if key in feature_keys:
+ field_messages = key
+ break
+ if not field_messages:
+ raise ValueError(
+ f'No conversation field found in dataset: {", ".join(feature_keys)}'
+ )
+ ds_cfg["field_messages"] = field_messages
+
+ message_fields = features["conversations"][0].keys()
+ message_field_role = None
+ for key in ["from", "role"]:
+ if key in message_fields:
+ message_field_role = key
+ break
+ if not message_field_role:
+ raise ValueError(
+ f'No role field found in messages: {", ".join(message_fields)}'
+ )
+ ds_cfg["message_field_role"] = message_field_role
+
+ message_field_content = None
+ for key in ["content", "text", "value"]:
+ if key in message_fields:
+ message_field_content = key
+ break
+ if not message_field_content:
+ raise ValueError(
+ f'No content field found in messages: {", ".join(message_fields)}'
+ )
+ ds_cfg["message_field_content"] = message_field_content
+
+ print(yaml.dump({"datasets": [ds_cfg]}))
+
+
+if __name__ == "__main__":
+ parse_dataset()
diff --git a/setup.py b/setup.py
index 1d164e0a18..17347f0632 100644
--- a/setup.py
+++ b/setup.py
@@ -30,6 +30,9 @@ def parse_requirements():
try:
xformers_version = [req for req in _install_requires if "xformers" in req][0]
+ torchao_version = [req for req in _install_requires if "torchao" in req][0]
+ autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
+
if "Darwin" in platform.system():
# don't install xformers on MacOS
_install_requires.pop(_install_requires.index(xformers_version))
@@ -49,20 +52,35 @@ def parse_requirements():
else:
raise ValueError("Invalid version format")
- if (major, minor) >= (2, 3):
+ if (major, minor) >= (2, 5):
+ _install_requires.pop(_install_requires.index(xformers_version))
+ _install_requires.pop(_install_requires.index(autoawq_version))
+ elif (major, minor) >= (2, 4):
+ if patch == 0:
+ _install_requires.pop(_install_requires.index(xformers_version))
+ _install_requires.append("xformers>=0.0.27")
+ else:
+ _install_requires.pop(_install_requires.index(xformers_version))
+ _install_requires.append("xformers==0.0.28.post1")
+ elif (major, minor) >= (2, 3):
+ _install_requires.pop(_install_requires.index(torchao_version))
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.26.post1")
+ else:
+ _install_requires.pop(_install_requires.index(xformers_version))
+ _install_requires.append("xformers>=0.0.27")
elif (major, minor) >= (2, 2):
+ _install_requires.pop(_install_requires.index(torchao_version))
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.25.post1")
else:
+ _install_requires.pop(_install_requires.index(torchao_version))
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.23.post1")
except PackageNotFoundError:
pass
-
return _install_requires, _dependency_links
@@ -80,7 +98,7 @@ def parse_requirements():
dependency_links=dependency_links,
extras_require={
"flash-attn": [
- "flash-attn==2.6.2",
+ "flash-attn==2.6.3",
],
"fused-dense-lib": [
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.6.2#subdirectory=csrc/fused_dense_lib",
@@ -91,6 +109,7 @@ def parse_requirements():
],
"mamba-ssm": [
"mamba-ssm==1.2.0.post1",
+ "causal_conv1d",
],
"auto-gptq": [
"auto-gptq==0.5.1",
diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py
index a05ee84e97..52765a9b58 100644
--- a/src/axolotl/cli/__init__.py
+++ b/src/axolotl/cli/__init__.py
@@ -27,8 +27,11 @@
from transformers.utils.import_utils import _is_package_available
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
+from axolotl.integrations.base import PluginManager
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
+from axolotl.utils.chat_templates import get_chat_template
+from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import (
normalize_cfg_datasets,
normalize_config,
@@ -38,7 +41,7 @@
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
-from axolotl.utils.models import load_tokenizer
+from axolotl.utils.models import load_processor, load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars
@@ -52,8 +55,22 @@
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
-
-def print_axolotl_text_art(suffix=None):
+AXOLOTL_LOGO = """
+ #@@ #@@ @@# @@#
+ @@ @@ @@ @@ =@@# @@ #@ =@@#.
+ @@ #@@@@@@@@@ @@ #@#@= @@ #@ .=@@
+ #@@@@@@@@@@@@@@@@@ =@# @# ##= ## =####=+ @@ =#####+ =#@@###. @@
+ @@@@@@@@@@/ +@@/ +@@ #@ =@= #@= @@ =@#+ +#@# @@ =@#+ +#@# #@. @@
+ @@@@@@@@@@ ##@@ ##@@ =@# @# =@# @# @@ @@ @@ @@ #@ #@ @@
+ @@@@@@@@@@@@@@@@@@@@ #@=+++#@= =@@# @@ @@ @@ @@ #@ #@ @@
+ =@#=====@@ =@# @# @@ @@ @@ @@ #@ #@ @@
+ @@@@@@@@@@@@@@@@ @@@@ #@ #@= #@= +@@ #@# =@# @@. =@# =@# #@. @@
+ =@# @# #@= #@ =#@@@@#= +#@@= +#@@@@#= .##@@+ @@
+ @@@@ @@@@@@@@@@@@@@@@
+"""
+
+
+def print_legacy_axolotl_text_art(suffix=None):
font = "nancyj"
ascii_text = " axolotl"
if suffix:
@@ -66,6 +83,13 @@ def print_axolotl_text_art(suffix=None):
print_dep_versions()
+def print_axolotl_text_art(
+ **kwargs, # pylint: disable=unused-argument
+):
+ if is_main_process():
+ print(AXOLOTL_LOGO)
+
+
def print_dep_versions():
packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"]
max_len = max(len(pkg) for pkg in packages)
@@ -233,7 +257,8 @@ def do_inference_gradio(
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
prompter = cli_args.prompter
- default_tokens = {"unk_token": "", "bos_token": "", "eos_token": ""}
+ # default_tokens = {"unk_token": "", "bos_token": "", "eos_token": ""}
+ default_tokens: Dict[str, str] = {}
for token, symbol in default_tokens.items():
# If the token isn't already specified in the config, add it
@@ -241,10 +266,13 @@ def do_inference_gradio(
tokenizer.add_special_tokens({token: symbol})
prompter_module = None
+ chat_template_str = None
if prompter:
prompter_module = getattr(
importlib.import_module("axolotl.prompters"), prompter
)
+ elif cfg.chat_template:
+ chat_template_str = get_chat_template(cfg.chat_template)
model = model.to(cfg.device, dtype=cfg.torch_dtype)
@@ -258,7 +286,24 @@ def generate(instruction):
)
else:
prompt = instruction.strip()
- batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
+
+ if chat_template_str:
+ batch = tokenizer.apply_chat_template(
+ [
+ {
+ "role": "user",
+ "content": prompt,
+ }
+ ],
+ return_tensors="pt",
+ add_special_tokens=True,
+ add_generation_prompt=True,
+ chat_template=chat_template_str,
+ tokenize=True,
+ return_dict=True,
+ )
+ else:
+ batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
model.eval()
with torch.no_grad():
@@ -281,6 +326,7 @@ def generate(instruction):
streamer = TextIteratorStreamer(tokenizer)
generation_kwargs = {
"inputs": batch["input_ids"].to(cfg.device),
+ "attention_mask": batch["attention_mask"].to(cfg.device),
"generation_config": generation_config,
"streamer": streamer,
}
@@ -365,6 +411,11 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
cfg.axolotl_config_path = config
+ if cfg.get("plugins"):
+ plugin_manager = PluginManager.get_instance()
+ for plugin_name in cfg["plugins"]:
+ plugin_manager.register(plugin_name)
+
try:
device_props = torch.cuda.get_device_properties("cuda")
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
@@ -392,6 +443,8 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
setup_mlflow_env_vars(cfg)
+ setup_comet_env_vars(cfg)
+
return cfg
@@ -401,12 +454,20 @@ def load_datasets(
cli_args: TrainerCliArgs,
) -> TrainDatasetMeta:
tokenizer = load_tokenizer(cfg)
+ processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
- cfg, tokenizer
+ cfg,
+ tokenizer,
+ processor=processor,
)
- if cli_args.debug or cfg.debug:
+ if (
+ cli_args.debug
+ or cfg.debug
+ or cli_args.debug_text_only
+ or int(cli_args.debug_num_examples) > 0
+ ):
LOG.info("check_dataset_labels...")
check_dataset_labels(
train_dataset.select(
diff --git a/src/axolotl/cli/merge_sharded_fsdp_weights.py b/src/axolotl/cli/merge_sharded_fsdp_weights.py
new file mode 100644
index 0000000000..25408fd57e
--- /dev/null
+++ b/src/axolotl/cli/merge_sharded_fsdp_weights.py
@@ -0,0 +1,204 @@
+"""
+This module provides a CLI to merge sharded FSDP model checkpoints into a single combined checkpoint
+"""
+import json
+import logging
+import os
+import shutil
+from pathlib import Path
+from typing import Dict, Union
+
+import fire
+import torch
+import torch.distributed.checkpoint as dist_cp
+import torch.distributed.checkpoint.format_utils as dist_cp_format_utils
+import transformers
+from accelerate.utils import (
+ SAFE_WEIGHTS_INDEX_NAME,
+ SAFE_WEIGHTS_NAME,
+ WEIGHTS_INDEX_NAME,
+ WEIGHTS_NAME,
+ is_torch_version,
+)
+from dotenv import load_dotenv
+from huggingface_hub import split_torch_state_dict_into_shards
+from safetensors.torch import save_file as safe_save_file
+from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
+
+from axolotl.cli import load_cfg, print_axolotl_text_art
+from axolotl.common.cli import TrainerCliArgs
+
+LOG = logging.getLogger("axolotl.cli.merge_sharded_fsdp_weights")
+
+
+class BFloat16CastPlanner(_EmptyStateDictLoadPlanner):
+ """
+ A custom planner to cast tensors to bfloat16 on the fly during loading.
+ """
+
+ def commit_tensor(self, read_item, tensor): # pylint: disable=unused-argument
+ tensor.copy_(tensor.to(torch.bfloat16))
+
+
+def _distributed_checkpoint_to_merged_weights(
+ checkpoint_dir: Union[str, Path],
+ save_path: str,
+ safe_serialization: bool = False,
+ max_shard_size: str = "5GB",
+):
+ """
+ Passthrough to `torch.distributed.checkpoint.format_utils.dcp_to_torch_save`
+
+ Will save under `save_path` as either `model.safetensors` or `pytorch_model.bin`.
+ """
+
+ state_dict: Dict = {}
+ save_path_ = Path(save_path)
+ save_path_.mkdir(exist_ok=True)
+ dist_cp_format_utils._load_state_dict( # pylint: disable=protected-access
+ state_dict,
+ storage_reader=dist_cp.FileSystemReader(checkpoint_dir),
+ planner=BFloat16CastPlanner(), # pylint: disable=protected-access
+ no_dist=True,
+ )
+
+ # To handle if state is a dict like {model: {...}}
+ if len(state_dict.keys()) == 1:
+ state_dict = state_dict[list(state_dict)[0]]
+
+ # Ensure all tensors are in bfloat16
+ for key, value in state_dict.items():
+ if isinstance(value, torch.Tensor) and value.dtype != torch.bfloat16:
+ state_dict[key] = value.to(torch.bfloat16)
+
+ weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
+
+ filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
+ ".safetensors", "{suffix}.safetensors"
+ )
+ state_dict_split = split_torch_state_dict_into_shards(
+ state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
+ )
+ # Save index if sharded
+ index = None
+ if state_dict_split.is_sharded:
+ index = {
+ "metadata": state_dict_split.metadata,
+ "weight_map": state_dict_split.tensor_to_filename,
+ }
+
+ # Save the model
+ filename_to_tensors = state_dict_split.filename_to_tensors.items()
+
+ for shard_file, tensors in filename_to_tensors:
+ shard = {tensor: state_dict[tensor] for tensor in tensors}
+
+ if safe_serialization:
+ safe_save_file(
+ shard, os.path.join(save_path_, shard_file), metadata={"format": "pt"}
+ )
+ else:
+ torch.save(shard, os.path.join(save_path_, shard_file))
+
+ if index is not None:
+ save_index_file = (
+ SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
+ )
+ save_index_file = os.path.join(save_path_, save_index_file)
+ # Save the index as well
+ with open(save_index_file, "w", encoding="utf-8") as fout:
+ content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+ fout.write(content)
+
+ return save_path_
+
+
+def merge_fsdp_weights(
+ checkpoint_dir: str,
+ output_path: str,
+ safe_serialization: bool = False,
+ remove_checkpoint_dir: bool = False,
+):
+ """
+ Merge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Should be used if
+ `SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}/model.safetensors` if
+ `safe_serialization` else `pytorch_model.bin`.
+
+ Note: this is a CPU-bound process.
+
+ Args:
+ checkpoint_dir (`str`):
+ The directory containing the FSDP checkpoints (can be either the model or optimizer).
+ output_path (`str`):
+ The path to save the merged checkpoint.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the merged weights with safetensors (recommended).
+ remove_checkpoint_dir (`bool`, *optional*, defaults to `False`):
+ Whether to remove the checkpoint directory after merging.
+ """
+ checkpoint_dir_ = Path(checkpoint_dir)
+ from accelerate.state import PartialState
+
+ if not is_torch_version(">=", "2.3.0"):
+ raise ValueError("`merge_fsdp_weights` requires PyTorch >= 2.3.0`")
+
+ # Verify that the checkpoint directory exists
+ if not checkpoint_dir_.exists():
+ model_path_exists = (checkpoint_dir_ / "pytorch_model_fsdp_0").exists()
+ optimizer_path_exists = (checkpoint_dir_ / "optimizer_0").exists()
+ err = f"Tried to load from {checkpoint_dir_} but couldn't find a valid metadata file."
+ if model_path_exists and optimizer_path_exists:
+ err += (
+ " However, potential model and optimizer checkpoint directories exist."
+ )
+ err += f"Please pass in either {checkpoint_dir_}/pytorch_model_fsdp_0 or {checkpoint_dir_}/optimizer_0"
+ err += "instead."
+ elif model_path_exists:
+ err += " However, a potential model checkpoint directory exists."
+ err += (
+ f"Please try passing in {checkpoint_dir_}/pytorch_model_fsdp_0 instead."
+ )
+ elif optimizer_path_exists:
+ err += " However, a potential optimizer checkpoint directory exists."
+ err += f"Please try passing in {checkpoint_dir_}/optimizer_0 instead."
+ raise ValueError(err)
+
+ # To setup `save` to work
+ state = PartialState()
+ if state.is_main_process:
+ LOG.info(f"Merging FSDP weights from {checkpoint_dir_}")
+ save_path = _distributed_checkpoint_to_merged_weights(
+ checkpoint_dir_, output_path, safe_serialization
+ )
+ LOG.info(f"Successfully merged FSDP weights and saved to {save_path}")
+ if remove_checkpoint_dir:
+ LOG.info(f"Removing old checkpoint directory {checkpoint_dir_}")
+ shutil.rmtree(checkpoint_dir_)
+ state.wait_for_everyone()
+
+
+def do_cli(config: Path = Path("examples/"), **kwargs):
+ # pylint: disable=duplicate-code
+ print_axolotl_text_art()
+ parser = transformers.HfArgumentParser((TrainerCliArgs))
+ parsed_cli_args, _ = parser.parse_args_into_dataclasses(
+ return_remaining_strings=True
+ )
+ parsed_cli_args.merge_lora = True
+
+ parsed_cfg = load_cfg(
+ config,
+ **kwargs,
+ )
+
+ fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"
+ merge_fsdp_weights(
+ checkpoint_dir=str(fsdp_dir),
+ output_path=str(Path(parsed_cfg.output_dir) / "merged"),
+ safe_serialization=True,
+ )
+
+
+if __name__ == "__main__":
+ load_dotenv()
+ fire.Fire(do_cli)
diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py
index e12462c000..aab29e2670 100644
--- a/src/axolotl/cli/preprocess.py
+++ b/src/axolotl/cli/preprocess.py
@@ -27,6 +27,7 @@
register_chatml_template,
register_llama3_template,
)
+from axolotl.utils.trainer import disable_datasets_caching
LOG = logging.getLogger("axolotl.cli.preprocess")
@@ -70,10 +71,11 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
LOG.warning(msg)
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
- if parsed_cfg.rl: # and parsed_cfg.rl != "orpo":
- load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
- else:
- load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
+ with disable_datasets_caching():
+ if parsed_cfg.rl: # and parsed_cfg.rl != "orpo":
+ load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
+ else:
+ load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
if parsed_cli_args.download:
model_name = parsed_cfg.base_model
diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py
index 050f18a054..16d66a82f0 100644
--- a/src/axolotl/cli/train.py
+++ b/src/axolotl/cli/train.py
@@ -3,13 +3,11 @@
"""
import logging
from pathlib import Path
-from typing import Tuple, Union
+from typing import Union
import fire
from dotenv import load_dotenv
from transformers.hf_argparser import HfArgumentParser
-from transformers.modeling_utils import PreTrainedModel
-from transformers.tokenization_utils import PreTrainedTokenizer
from axolotl.cli import (
check_accelerate_default_config,
@@ -20,6 +18,7 @@
print_axolotl_text_art,
)
from axolotl.common.cli import TrainerCliArgs
+from axolotl.integrations.base import PluginManager
from axolotl.prompt_strategies.sharegpt import (
register_chatml_template,
register_llama3_template,
@@ -39,7 +38,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
return do_train(parsed_cfg, parsed_cli_args)
-def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
+def do_train(cfg, cli_args) -> None:
print_axolotl_text_art()
check_accelerate_default_config()
check_user_token()
@@ -64,7 +63,13 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
- return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+ model, tokenizer = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+ plugin_manager = PluginManager.get_instance()
+
+ del model
+ del tokenizer
+
+ plugin_manager.post_train_unload(cfg)
if __name__ == "__main__":
diff --git a/src/axolotl/common/cli.py b/src/axolotl/common/cli.py
index c96f8f81ff..6a3a22e637 100644
--- a/src/axolotl/common/cli.py
+++ b/src/axolotl/common/cli.py
@@ -23,7 +23,7 @@ class TrainerCliArgs:
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
- debug_num_examples: int = field(default=5)
+ debug_num_examples: int = field(default=0)
inference: bool = field(default=False)
merge_lora: bool = field(default=False)
prompter: Optional[str] = field(default=None)
diff --git a/src/axolotl/core/chat/__init__.py b/src/axolotl/core/chat/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/axolotl/core/chat/format/__init__.py b/src/axolotl/core/chat/format/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/axolotl/core/chat/format/chatml.py b/src/axolotl/core/chat/format/chatml.py
new file mode 100644
index 0000000000..315d101a86
--- /dev/null
+++ b/src/axolotl/core/chat/format/chatml.py
@@ -0,0 +1,34 @@
+"""
+ChatML transformation functions for MessageContents
+"""
+from typing import Optional
+
+from ..messages import MessageContents, Messages
+from .shared import wrap_tools
+
+
+def format_message(
+ message: Messages,
+ message_index: Optional[int] = None, # pylint: disable=unused-argument
+) -> Messages:
+ if message.is_chat_formatted:
+ return message
+
+ # prepend the role prefix within a MessageContents to message.content
+ message.content.insert(
+ 0,
+ MessageContents(
+ type="text",
+ value=f"<|im_start|>{message.role}\n",
+ weight=0,
+ ),
+ )
+ message.content.append(
+ MessageContents(type="text", value="<|im_end|>", weight=message.weight)
+ )
+ message.content.append(MessageContents(type="text", value="\n", weight=0))
+
+ message = wrap_tools(message)
+
+ message.is_chat_formatted = True
+ return message
diff --git a/src/axolotl/core/chat/format/llama3x.py b/src/axolotl/core/chat/format/llama3x.py
new file mode 100644
index 0000000000..17fa7aa8d4
--- /dev/null
+++ b/src/axolotl/core/chat/format/llama3x.py
@@ -0,0 +1,45 @@
+"""
+Llama 3.x chat formatting functions for MessageContents
+"""
+from typing import Optional
+
+from ..messages import MessageContents, Messages
+from .shared import wrap_tools
+
+
+def format_message(message: Messages, message_index: Optional[int] = None) -> Messages:
+ if message.is_chat_formatted:
+ return message
+
+ message_role = message.role
+ if message.role == "tool":
+ message_role = "ipython"
+
+ # prepend the role prefix within a MessageContents to message.content
+ message.content.insert(
+ 0,
+ MessageContents(
+ type="text",
+ value=f"<|start_header_id|>{message_role}<|end_header_id|>\n\n",
+ weight=0,
+ ),
+ )
+
+ message.content.append(
+ MessageContents(type="text", value="<|eot_id|>", weight=message.weight)
+ )
+
+ message = wrap_tools(message)
+
+ if message_index == 0:
+ message.content.insert(
+ 0,
+ MessageContents(
+ type="text",
+ value="<|begin_of_text|>",
+ weight=0,
+ ),
+ )
+
+ message.is_chat_formatted = True
+ return message
diff --git a/src/axolotl/core/chat/format/shared.py b/src/axolotl/core/chat/format/shared.py
new file mode 100644
index 0000000000..9efa2353db
--- /dev/null
+++ b/src/axolotl/core/chat/format/shared.py
@@ -0,0 +1,47 @@
+"""
+shared functions for format transforms
+"""
+from axolotl.core.chat.messages import MessageContents, Messages
+
+
+def wrap_tools(message: Messages):
+ # loop over message.content by index to find tool calls, we need to wrap each with tags,
+ # so be wary of indexing issues when changing the list while iterating.
+ # iterate over the range in reverse order to avoid index shifting
+ for i in range(len(message.content) - 1, -1, -1):
+ if message.content[i].type == "tool_call":
+ # append a MessageContents text tag after
+ message.content.insert(
+ i + 1,
+ MessageContents(
+ type="text", value="\n", weight=message.weight
+ ),
+ )
+ # make sure the actual tool call content ends with a newline
+ message.content[i].has_newline = True
+ # prepend a MessageContents text tag before
+ message.content.insert(
+ i,
+ MessageContents(
+ type="text", value="\n", weight=message.weight
+ ),
+ )
+ elif message.content[i].type == "tool_response":
+ # append a MessageContents text tag after
+ message.content.insert(
+ i + 1,
+ MessageContents(
+ type="text", value="\n", weight=message.weight
+ ),
+ )
+ # make sure the actual tool response content ends with a newline
+ message.content[i].has_newline = True
+ # prepend a MessageContents text tag before
+ message.content.insert(
+ i,
+ MessageContents(
+ type="text", value="\n", weight=message.weight
+ ),
+ )
+
+ return message
diff --git a/src/axolotl/core/chat/messages.py b/src/axolotl/core/chat/messages.py
new file mode 100644
index 0000000000..c879bf477b
--- /dev/null
+++ b/src/axolotl/core/chat/messages.py
@@ -0,0 +1,230 @@
+"""
+internal message representations of chat messages
+"""
+import json
+from enum import Enum
+from typing import Any, Callable, List, Optional, Union
+
+from pydantic import BaseModel
+from transformers import PreTrainedTokenizer
+
+
+class MessageRoles(str, Enum):
+ """
+ Message roles for the system, user, assistant, and tools
+ """
+
+ system = "system" # pylint: disable=invalid-name
+ user = "user" # pylint: disable=invalid-name
+ assistant = "assistant" # pylint: disable=invalid-name
+ tool = "tool" # pylint: disable=invalid-name
+ ipython = ( # pylint: disable=invalid-name
+ # for responses from builtin tools
+ "ipython"
+ )
+
+
+class MessageContentTypes(str, Enum):
+ """
+ Message content types for text, image, audio, tool calls, and tool responses
+ """
+
+ special_token = "special_token" # pylint: disable=invalid-name # nosec B105
+ text = "text" # pylint: disable=invalid-name
+ image = "image" # pylint: disable=invalid-name
+ audio = "audio" # pylint: disable=invalid-name
+ tool_call = "tool_call" # pylint: disable=invalid-name # to differentiate regular responses from tool calls from the assistant
+ tool_response = "tool_response" # pylint: disable=invalid-name
+
+
+class SpecialToken(str, Enum):
+ """
+ Special tokens for beginning of string and end of string
+ """
+
+ bos_token = "bos_token" # pylint: disable=invalid-name # nosec B105
+ eos_token = "eos_token" # pylint: disable=invalid-name # nosec B105
+
+
+class ToolCallFunction(BaseModel):
+ """
+ Tool call function with name and arguments
+ """
+
+ name: str
+ arguments: dict[str, str]
+
+
+class Tool(BaseModel):
+ """
+ Tool with description, function, and parameters
+ """
+
+ description: str
+ function: ToolCallFunction
+ parameters: dict[str, str] # .properties
+
+
+class ToolCallContents(BaseModel):
+ """
+ Tool call contents with name, arguments, and optional id
+ """
+
+ name: str
+ arguments: dict[str, Union[str, int]]
+ id: Optional[str] = None # pylint: disable=invalid-name
+
+ def __str__(self) -> str:
+ data = {"name": self.name, "arguments": self.arguments}
+ if self.id is not None:
+ data["id"] = self.id
+ return json.dumps(data)
+
+
+class ToolResponseContents(BaseModel):
+ """
+ Tool response contents with name, content, and optional id
+ """
+
+ name: str
+ content: Union[str, dict[str, Union[str, int, float]]]
+ id: Optional[str] = None # pylint: disable=invalid-name
+
+ def __str__(self) -> str:
+ data = {"name": self.name, "content": self.content}
+ if self.id is not None:
+ data["id"] = self.id
+ return json.dumps(data)
+
+
+class MessageContents(BaseModel):
+ """
+ Message contents with type, value, metadata, weight, newline, and end of contents
+ """
+
+ type: Union[str, MessageContentTypes]
+ value: Union[str, ToolCallContents, ToolResponseContents, SpecialToken]
+ meta: Optional[dict[str, Any]] = None # support additional arbitrary metadata
+ weight: Optional[Union[int, float]] = None
+ has_newline: bool = False
+ eoc: bool = False # end of contents
+
+ def __str__(self) -> str:
+ str_val = str(self.value)
+ if self.has_newline and not str_val.endswith("\n"):
+ str_val += "\n"
+ return str_val
+
+
+class Messages(BaseModel):
+ """
+ Messages with role, content, metadata, weight, and chat formatting
+ """
+
+ role: Union[MessageRoles, str] # allows for arbitrary roles
+ content: List["MessageContents"]
+ meta: Optional[dict[str, Any]] = None # support additional arbitrary metadata
+ weight: Optional[Union[int, float]] = None
+ is_chat_formatted: bool = False
+
+ def __str__(self) -> str:
+ return "".join(str(c) for c in self.content)
+
+ def tokenized(
+ self, tokenizer: PreTrainedTokenizer, ignore_index=-100
+ ) -> dict[str, List[int]]:
+ # iterate over the contents, tokenizing the concatenated string values up to the current MessageContents
+ # returns a dictionary mapping w input_ids, attention_mask, and labels
+ input_ids: List[int] = []
+ labels: List[int] = []
+ pending_input_ids: List[int] = []
+ pending_weight = self.weight
+ running_content = ""
+ for _, msg_content in enumerate(self.content):
+ # TODO also handle non-text content types
+ if msg_content.type in [
+ MessageContentTypes.text.value,
+ MessageContentTypes.tool_call.value,
+ MessageContentTypes.tool_response.value,
+ ]:
+ running_content += str(msg_content)
+ tok_results = tokenizer(running_content, add_special_tokens=False)
+ tok_input_ids = tok_results["input_ids"]
+ if pending_input_ids:
+ new_pending_inputs = tok_input_ids[
+ len(input_ids) : len(input_ids) + len(pending_input_ids)
+ ]
+ if new_pending_inputs != pending_input_ids:
+ # logging.warning("tokenization mismatch from concatenation.")
+ pending_input_ids = new_pending_inputs
+ input_ids.extend(pending_input_ids)
+ if pending_weight:
+ labels.extend(pending_input_ids)
+ else:
+ labels.extend([ignore_index] * len(pending_input_ids))
+ pending_input_ids = tok_results["input_ids"][len(input_ids) :]
+ pending_weight = self.weight and msg_content.weight not in [0, 0.0]
+ input_ids.extend(pending_input_ids)
+ if pending_weight:
+ labels.extend(pending_input_ids)
+ else:
+ labels.extend([ignore_index] * len(pending_input_ids))
+ attention_mask = [1] * len(input_ids)
+ return {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "labels": labels,
+ }
+
+
+class Chats(BaseModel):
+ """
+ top level data structure for chat conversations
+ """
+
+ conversation: List[Messages]
+
+ def __str__(self) -> str:
+ return "".join(str(c) for c in self.conversation)
+
+ def tokenized(
+ self, tokenizer: Callable[[str], dict[str, List[int]]], ignore_index=-100
+ ) -> dict[str, List[int]]:
+ input_ids = []
+ attention_mask = []
+ labels = []
+ for msg in self.conversation:
+ msg_results = msg.tokenized(tokenizer, ignore_index)
+ input_ids.extend(msg_results["input_ids"])
+ attention_mask.extend(msg_results["attention_mask"])
+ labels.extend(msg_results["labels"])
+ return {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "labels": labels,
+ }
+
+
+class ChatFormattedChats(Chats):
+ """
+ Chat formatted chats with formatter and optional train on inputs
+ """
+
+ formatter: Callable # [[Union[dict, Chats]], Chats]
+ train_on_inputs: bool = False
+
+ def model_post_init(self, __context):
+ for i, msg in enumerate(self.conversation):
+ self.conversation[i] = self.formatter(msg, message_index=i)
+ if self.train_on_inputs:
+ self.conversation[i].weight = 1
+
+
+class PreferenceChats(BaseModel):
+ """
+ representation for preference data for chat
+ """
+
+ prompt: List[Messages]
+ chosen: Messages
+ rejected: Messages
diff --git a/src/axolotl/core/datasets/__init__.py b/src/axolotl/core/datasets/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/axolotl/core/datasets/chat.py b/src/axolotl/core/datasets/chat.py
new file mode 100644
index 0000000000..e74c247d2c
--- /dev/null
+++ b/src/axolotl/core/datasets/chat.py
@@ -0,0 +1,55 @@
+"""
+chat dataset module
+"""
+import os
+from typing import Callable, Optional, Union
+
+from datasets import Dataset
+from transformers import PreTrainedTokenizer
+
+from axolotl.core.chat.messages import ChatFormattedChats
+
+
+class TokenizedChatDataset(Dataset):
+ """
+ Tokenized chat dataset
+ """
+
+ def __init__(
+ self,
+ data: Dataset,
+ model_transform: Union[PreTrainedTokenizer, Callable],
+ *args,
+ message_transform: Optional[Callable] = None,
+ formatter=None,
+ process_count: Optional[int] = None,
+ keep_in_memory: Optional[bool] = False,
+ **kwargs,
+ ):
+ def map_fn(ex):
+ if message_transform is not None:
+ ex = message_transform(ex)
+ if formatter is not None:
+ ex = ChatFormattedChats(
+ formatter=formatter,
+ **ex,
+ )
+ else:
+ ex = ChatFormattedChats(
+ **ex,
+ )
+ return ex.tokenized(model_transform)
+
+ process_or_cpu_count: int = (
+ process_count or os.cpu_count() # type: ignore[assignment]
+ )
+ num_proc = min(64, process_or_cpu_count)
+ features = data.features.keys()
+ tokenized_data = data.map(
+ map_fn,
+ num_proc=num_proc,
+ keep_in_memory=keep_in_memory,
+ remove_columns=features,
+ desc="Tokenizing Chats",
+ )
+ super().__init__(tokenized_data.data, *args, **kwargs)
diff --git a/src/axolotl/core/datasets/transforms/__init__.py b/src/axolotl/core/datasets/transforms/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/axolotl/core/datasets/transforms/chat_builder.py b/src/axolotl/core/datasets/transforms/chat_builder.py
new file mode 100644
index 0000000000..98d5f171a7
--- /dev/null
+++ b/src/axolotl/core/datasets/transforms/chat_builder.py
@@ -0,0 +1,150 @@
+"""
+This module contains a function that builds a transform that takes a row from the dataset and converts it to a Chat.
+"""
+from typing import Any, Mapping, Union
+
+
+def chat_message_transform_builder( # pylint: disable=dangerous-default-value
+ train_on_inputs=False,
+ conversations_field: str = "conversations",
+ message_field_role: Union[str, list[str]] = ["role", "from"], # commonly "role"
+ message_field_content: Union[str, list[str]] = [
+ "value",
+ "text",
+ "content",
+ ], # commonly "content"
+ message_field_training: Union[str, list[str]] = [
+ "train",
+ "weight",
+ ], # commonly "weight"
+):
+ """Builds a transform that takes a row from the dataset and converts it to a Chat
+
+ Args:
+ train_on_inputs (bool, optional):
+ If True, the transform will train on the inputs. If False, the transform will train on the targets.
+ Defaults to False.
+ conversations_field (str, optional):
+ The field name of the conversations. Defaults to "conversations".
+ message_field_role (str | list[str], optional):
+ The field name of the role. Defaults to "role".
+ message_field_content (str | list[str], optional):
+ The field name of the message content. Defaults to "content".
+ message_field_training (str | list[str], optional):
+ The field name of the train/weight. Defaults to "weight".
+
+ Returns:
+ Callable:
+ A function that takes a list of conversations and returns a list of messages.
+ """
+
+ message_field_role = (
+ [message_field_role]
+ if isinstance(message_field_role, str)
+ else message_field_role
+ )
+ message_field_content = (
+ [message_field_content]
+ if isinstance(message_field_content, str)
+ else message_field_content
+ )
+ message_weight_fields = (
+ [message_field_training]
+ if isinstance(message_field_training, str)
+ else message_field_training
+ )
+
+ role_value_mappings = {
+ "system": "system",
+ "user": "user",
+ "human": "user",
+ "assistant": "assistant",
+ "gpt": "assistant",
+ "tool": "tool",
+ "ipython": "ipython",
+ }
+ if train_on_inputs:
+ role_default_weights_mappings = {
+ "system": 1,
+ "user": 1,
+ "assistant": 1,
+ "tool": 1,
+ "ipython": 1,
+ }
+ else:
+ role_default_weights_mappings = {
+ "system": 0,
+ "user": 0,
+ "assistant": 1,
+ "tool": 0,
+ "ipython": 0,
+ }
+
+ def transform_builder(sample: Mapping[str, Any]):
+ if conversations_field not in sample:
+ raise ValueError(f"Field '{conversations_field}' not found in sample.")
+ # if none of the role fields are in the message, raise an error
+ if not any(
+ role in sample[conversations_field][0] for role in message_field_role
+ ):
+ raise ValueError("No role field found in message.")
+ role_field = next(
+ role
+ for role in message_field_role
+ if role in sample[conversations_field][0]
+ )
+ if not any(
+ field in sample[conversations_field][0] for field in message_field_content
+ ):
+ raise ValueError("No message_content field found in message.")
+ message_content_field = next(
+ field
+ for field in message_field_content
+ if field in sample[conversations_field][0]
+ )
+ if not any(
+ field in sample[conversations_field][0] for field in message_field_training
+ ):
+ message_weight_field = None
+ else:
+ message_weight_field = next(
+ field
+ for field in message_weight_fields
+ if field in sample[conversations_field][0]
+ )
+
+ messages = []
+ for message in sample[conversations_field]:
+ role = role_value_mappings[message[role_field]]
+ weight = (
+ int(message[message_weight_field])
+ if message_weight_field
+ else role_default_weights_mappings[role]
+ )
+
+ # TODO if "tool_calls" in message[message_content_field]: then convert tool call to ToolCallContents
+ if isinstance(message[message_content_field], str):
+ messages.append(
+ {
+ "role": role,
+ "content": [
+ {
+ "type": "text",
+ "value": message[message_content_field],
+ }
+ ],
+ "weight": weight,
+ }
+ )
+ else:
+ messages.append(
+ {
+ "role": role,
+ "content": message[message_content_field],
+ "weight": weight,
+ }
+ )
+
+ return {"conversation": messages}
+
+ return transform_builder
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 4e8b369052..d125f838d3 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -4,8 +4,10 @@
"""
import abc
+import gc
import importlib
import importlib.util
+import inspect
import logging
import math
import os
@@ -15,16 +17,17 @@
from dataclasses import dataclass, field
from functools import wraps
from pathlib import Path
-from typing import Dict, List, Literal, Optional, Type, Union
+from typing import Any, Dict, List, Literal, Optional, Type, Union
import torch
import transformers
from datasets import Dataset
+from peft.optimizers import create_loraplus_optimizer
+from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import (
EarlyStoppingCallback,
- PreTrainedModel,
Trainer,
TrainerCallback,
TrainingArguments,
@@ -40,13 +43,14 @@
KTOTrainer,
ORPOConfig,
ORPOTrainer,
+ RewardConfig,
+ RewardTrainer,
)
-from trl.trainer.utils import pad_to_length
+from trl.trainer.utils import RewardDataCollatorWithPadding, pad_to_length
-from axolotl.loraplus import create_loraplus_optimizer
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
-from axolotl.utils import is_mlflow_available
+from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
GPUStatsCallback,
@@ -59,12 +63,14 @@
log_prediction_callback_factory,
)
from axolotl.utils.callbacks.lisa import lisa_callback_factory
+from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
MambaDataCollator,
V2BatchSamplerDataCollatorForSeq2Seq,
)
+from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.models import ensure_dtype
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.schedulers import (
@@ -248,6 +254,10 @@ class AxolotlTrainingMixins:
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
},
)
+ chat_template: Optional[str] = field(
+ default=None,
+ metadata={"help": "Chat template converting chat messages to text"},
+ )
@dataclass
@@ -293,6 +303,13 @@ class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig):
)
+@dataclass
+class AxolotlRewardConfig(AxolotlTrainingMixins, RewardConfig):
+ """
+ Reward config for Reward training
+ """
+
+
class SchedulerMixin(Trainer):
"""
Mixin class for scheduler setup in CausalTrainer.
@@ -390,12 +407,10 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
def __init__(
self,
*_args,
- num_epochs=1,
bench_data_collator=None,
eval_data_collator=None,
**kwargs,
):
- self.num_epochs = num_epochs
self.bench_data_collator = bench_data_collator
self.eval_data_collator = eval_data_collator
super().__init__(*_args, **kwargs)
@@ -454,14 +469,14 @@ def create_optimizer(self):
if self.args.loraplus_lr_ratio is not None:
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
loraplus_lr_embedding = getattr(
- self.args, "loraplus_lr_embedding", None
+ self.args, "loraplus_lr_embedding", 1e-6
)
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
- optimizer_kwargs,
- loraplus_lr_ratio,
- loraplus_lr_embedding,
+ loraplus_lr_ratio=loraplus_lr_ratio,
+ loraplus_lr_embedding=loraplus_lr_embedding,
+ **optimizer_kwargs,
)
elif self.args.alternate_optimizer == "optimi_adamw":
from optimi import AdamW
@@ -504,9 +519,10 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
batch_max_len = self.args.max_seq_length
else:
batch_size = 1
- batch_max_len = (
- self.args.per_device_train_batch_size * self.args.max_seq_length
+ train_batch_size = (
+ self.state.train_batch_size or self.args.per_device_train_batch_size
)
+ batch_max_len = train_batch_size * self.args.max_seq_length
return MultipackBatchSampler(
RandomSampler(self.train_dataset),
lengths=get_dataset_lengths(self.train_dataset),
@@ -650,7 +666,9 @@ def get_bench_dataloader(
return DataLoader(bench_dataset, **dataloader_params)
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
- def compute_loss(self, model, inputs, return_outputs=False):
+ def compute_loss(
+ self, model, inputs, return_outputs=False, num_items_in_batch=None
+ ):
# use one's weighted cross entropy loss calc
# if self.args.sample_packing:
# labels = inputs.pop("labels")
@@ -658,8 +676,18 @@ def compute_loss(self, model, inputs, return_outputs=False):
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
# return (loss, outputs) if return_outputs else loss
if self.args.orpo_alpha:
- return self.orpo_compute_loss(model, inputs, return_outputs=return_outputs)
- return super().compute_loss(model, inputs, return_outputs=return_outputs)
+ return self.orpo_compute_loss(
+ model,
+ inputs,
+ return_outputs=return_outputs,
+ num_items_in_batch=num_items_in_batch,
+ )
+ return super().compute_loss(
+ model,
+ inputs,
+ return_outputs=return_outputs,
+ num_items_in_batch=num_items_in_batch,
+ )
@staticmethod
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
@@ -755,7 +783,13 @@ def orpo_compute_logps(
).squeeze(2)
return torch.mul(per_token_logps, mask).sum(dim=1) / mask.sum(dim=1)
- def orpo_compute_loss(self, model, inputs, return_outputs=False):
+ def orpo_compute_loss(
+ self,
+ model,
+ inputs,
+ return_outputs=False,
+ num_items_in_batch=None, # pylint: disable=unused-argument
+ ):
concat_inputs = AxolotlTrainer.orpo_concatenate_inputs(
inputs,
label_pad_token=-100,
@@ -882,6 +916,7 @@ def compute_loss(
model,
inputs,
return_outputs=False, # pylint: disable=unused-argument
+ num_items_in_batch=None, # pylint: disable=unused-argument
):
input_ids = inputs.pop("input_ids")
lm_logits = model(input_ids).logits
@@ -966,9 +1001,9 @@ def create_optimizer(self):
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
- optimizer_kwargs,
- loraplus_lr_ratio,
- loraplus_lr_embedding,
+ loraplus_lr_ratio=loraplus_lr_ratio,
+ loraplus_lr_embedding=loraplus_lr_embedding,
+ **optimizer_kwargs,
)
if is_sagemaker_mp_enabled():
@@ -989,14 +1024,36 @@ def push_to_hub(self, *args, **kwargs) -> str:
return super().push_to_hub(*args, **kwargs)
def tokenize_row(
- self, feature, model: Optional[Union[PreTrainedModel, torch.nn.Module]] = None
+ self,
+ features,
+ processing_class,
+ max_prompt_length,
+ max_completion_length,
+ add_special_tokens,
) -> Dict:
- res = super().tokenize_row(feature, model=model)
- if self.tokenizer.bos_token_id is None and res["prompt_input_ids"][0] is None:
+ res = super().tokenize_row(
+ features,
+ processing_class,
+ max_prompt_length,
+ max_completion_length,
+ add_special_tokens,
+ )
+ if processing_class.bos_token_id is None and res["prompt_input_ids"][0] is None:
for key in res.keys():
res[key] = res[key][1:]
return res
+ def training_step(
+ self,
+ model: nn.Module,
+ inputs: Dict[str, Union[torch.Tensor, Any]],
+ num_items_in_batch=None,
+ ) -> torch.Tensor:
+ loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
+ gc.collect()
+ torch.cuda.empty_cache()
+ return loss
+
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
"""
@@ -1022,6 +1079,14 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
tag_names = ["axolotl", "cpo"]
+class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
+ """
+ Extend the base RewardTrainer for axolotl helpers
+ """
+
+ tag_names = ["axolotl", "reward"]
+
+
class TrainerBuilderBase(abc.ABC):
"""
Base class for trainer builder
@@ -1032,10 +1097,11 @@ class TrainerBuilderBase(abc.ABC):
_model_ref = None
_peft_config = None
- def __init__(self, cfg, model, tokenizer):
+ def __init__(self, cfg, model, tokenizer, processor=None):
self.cfg = cfg
self.model = model
self.tokenizer = tokenizer
+ self.processor = processor
# in case the model supports tagging, add the axolotl tag.
# This makes sure the tag is correctly pushed even if a user calls
@@ -1086,12 +1152,23 @@ def get_callbacks(self) -> List[TrainerCallback]:
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_mlflow and is_mlflow_available():
+ from transformers.integrations.integration_utils import MLflowCallback
+
from axolotl.utils.callbacks.mlflow_ import (
SaveAxolotlConfigtoMlflowCallback,
)
+ callbacks.extend(
+ [
+ SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path),
+ MLflowCallback,
+ ]
+ )
+ if self.cfg.use_comet and is_comet_available():
+ from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback
+
callbacks.append(
- SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
+ SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
)
return callbacks
@@ -1161,6 +1238,11 @@ def get_post_trainer_create_callbacks(self, trainer):
trainer, self.tokenizer, "mlflow"
)
callbacks.append(LogPredictionCallback(self.cfg))
+ if self.cfg.use_comet and is_comet_available() and self.cfg.eval_table_size > 0:
+ LogPredictionCallback = log_prediction_callback_factory(
+ trainer, self.tokenizer, "comet_ml"
+ )
+ callbacks.append(LogPredictionCallback(self.cfg))
if self.cfg.do_bench_eval:
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
@@ -1185,6 +1267,8 @@ def _get_trainer_cls(self):
return ReLoRATrainer
if self.cfg.model_config_type == "mamba":
return AxolotlMambaTrainer
+ if self.cfg.reward_model:
+ return AxolotlRewardTrainer
return AxolotlTrainer
def build(self, total_num_steps):
@@ -1369,6 +1453,10 @@ def build(self, total_num_steps):
training_arguments_kwargs[
"per_device_eval_batch_size"
] = self.cfg.eval_batch_size
+ if self.cfg.auto_find_batch_size is not None:
+ training_arguments_kwargs[
+ "auto_find_batch_size"
+ ] = self.cfg.auto_find_batch_size
training_arguments_kwargs[
"gradient_accumulation_steps"
] = self.cfg.gradient_accumulation_steps
@@ -1402,15 +1490,22 @@ def build(self, total_num_steps):
report_to = []
if self.cfg.use_wandb:
report_to.append("wandb")
+ if self.cfg.wandb_name:
+ training_arguments_kwargs["run_name"] = self.cfg.wandb_name
if self.cfg.use_mlflow:
report_to.append("mlflow")
if self.cfg.use_tensorboard:
report_to.append("tensorboard")
+ if self.cfg.use_comet:
+ report_to.append("comet_ml")
training_arguments_kwargs["report_to"] = report_to
- training_arguments_kwargs["run_name"] = (
- self.cfg.wandb_name if self.cfg.use_wandb else None
- )
+ if self.cfg.use_wandb:
+ training_arguments_kwargs["run_name"] = self.cfg.wandb_name
+ elif self.cfg.use_mlflow:
+ training_arguments_kwargs["run_name"] = self.cfg.mlflow_run_name
+ else:
+ training_arguments_kwargs["run_name"] = None
training_arguments_kwargs["optim"] = (
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
)
@@ -1451,9 +1546,9 @@ def build(self, total_num_steps):
)
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
- training_arguments_kwargs[
- "multipack_real_batches"
- ] = not self.cfg.flash_attention
+ training_arguments_kwargs["multipack_real_batches"] = (
+ not self.cfg.flash_attention or self.cfg.multipack_real_batches
+ )
training_arguments_kwargs["eval_sample_packing"] = bool(
self.cfg.eval_sample_packing
)
@@ -1498,6 +1593,10 @@ def build(self, total_num_steps):
)
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
+ if self.cfg.chat_template:
+ training_arguments_kwargs["chat_template"] = get_chat_template(
+ self.cfg.chat_template
+ )
if self.cfg.rl == "orpo":
training_arguments_kwargs["orpo_alpha"] = self.cfg.orpo_alpha
@@ -1509,6 +1608,9 @@ def build(self, total_num_steps):
trainer_kwargs = {}
+ if self.cfg.reward_model:
+ trainer_kwargs["max_length"] = self.cfg.sequence_len
+
if self.cfg.optimizer in [
"optimi_adamw",
"ao_adamw_4bit",
@@ -1552,13 +1654,22 @@ def build(self, total_num_steps):
"accelerator_config"
] = self.cfg.accelerator_config
- training_args = (
- AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
- **training_arguments_kwargs,
- )
+ training_args_cls = (
+ AxolotlTrainingArguments
+ if not self.cfg.reward_model
+ else AxolotlRewardConfig
+ )
+ training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
+ **training_arguments_kwargs,
)
training_args = self.hook_post_create_training_args(training_args)
+ # unset run_name so wandb sets up experiment names
+ if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
+ training_args.run_name = ( # pylint: disable=attribute-defined-outside-init
+ None
+ )
+
data_collator_kwargs = {
"padding": True, # True/"longest" is the default
}
@@ -1571,27 +1682,37 @@ def build(self, total_num_steps):
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
data_collator_kwargs["pad_to_multiple_of"] = 64
+ if self.cfg.reward_model:
+ data_collator_kwargs["max_length"] = self.cfg.sequence_len
+
trainer_cls = self._get_trainer_cls()
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
trainer_kwargs, trainer_cls
)
+ if eval_data_collator := self.build_collator(
+ training_args, is_eval=True, **data_collator_kwargs
+ ):
+ if not self.cfg.reward_model:
+ trainer_kwargs["eval_data_collator"] = eval_data_collator
+ if not self.cfg.reward_model:
+ trainer_kwargs["bench_data_collator"] = transformers.DataCollatorForSeq2Seq(
+ self.tokenizer,
+ return_tensors="pt",
+ **data_collator_kwargs,
+ )
+ sig = inspect.signature(trainer_cls)
+ if "processing_class" in sig.parameters.keys():
+ trainer_kwargs["processing_class"] = self.tokenizer
+ else:
+ trainer_kwargs["tokenizer"] = self.tokenizer
+
trainer = trainer_cls(
model=self.model,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
args=training_args,
- tokenizer=self.tokenizer,
data_collator=self.build_collator(training_args, **data_collator_kwargs),
- eval_data_collator=self.build_collator(
- training_args, is_eval=True, **data_collator_kwargs
- ),
- bench_data_collator=transformers.DataCollatorForSeq2Seq(
- self.tokenizer,
- return_tensors="pt",
- **data_collator_kwargs,
- ),
callbacks=self.get_callbacks(),
- num_epochs=self.cfg.num_epochs,
**trainer_kwargs,
)
trainer = self.hook_post_create_trainer(trainer)
@@ -1625,9 +1746,14 @@ def build_collator(
V2BatchSamplerDataCollatorForSeq2Seq,
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
+ RewardDataCollatorWithPadding,
]
]
- if use_batch_sampler_collator:
+ if self.cfg.reward_model:
+ collator = RewardDataCollatorWithPadding
+ if "max_length" in kwargs:
+ kwargs.pop("max_length")
+ elif use_batch_sampler_collator:
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
collator = V2BatchSamplerDataCollatorForSeq2Seq
elif (
@@ -1638,7 +1764,12 @@ def build_collator(
else:
collator = BatchSamplerDataCollatorForSeq2Seq
else:
- collator = DataCollatorForSeq2Seq
+ if self.cfg.processor_type and self.processor:
+ collator = MultiModalChatDataCollator
+ kwargs["processor"] = self.processor
+ kwargs["chat_template"] = training_args.chat_template
+ else:
+ collator = DataCollatorForSeq2Seq
return collator(
self.tokenizer,
@@ -1824,7 +1955,7 @@ def build(self, total_num_steps):
dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
dpo_trainer_kwargs["max_target_length"] = None
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
- dpo_trainer_kwargs["generate_during_eval"] = True
+ dpo_trainer_kwargs["generate_during_eval"] = self.cfg.use_wandb
elif self.cfg.rl == "orpo":
trainer_cls = AxolotlORPOTrainer
trainer_cls_args = [self.model]
@@ -1836,16 +1967,24 @@ def build(self, total_num_steps):
trainer_cls_args = [self.model]
else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
+
+ sig = inspect.signature(trainer_cls)
+ if "processing_class" in sig.parameters.keys():
+ dpo_trainer_kwargs["processing_class"] = self.tokenizer
+ else:
+ dpo_trainer_kwargs["tokenizer"] = self.tokenizer
+
dpo_trainer = trainer_cls(
*trainer_cls_args,
args=training_args,
train_dataset=self.train_dataset,
- tokenizer=self.tokenizer,
callbacks=self.get_callbacks(),
**dpo_trainer_kwargs,
)
if self.cfg.fsdp:
ensure_dtype(dpo_trainer.model, dtype=self.cfg.torch_dtype)
+ if self.cfg.rl in ["dpo", "ipo"] and dpo_trainer.ref_model:
+ ensure_dtype(dpo_trainer.ref_model, dtype=self.cfg.torch_dtype)
dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
diff --git a/src/axolotl/integrations/LICENSE.md b/src/axolotl/integrations/LICENSE.md
new file mode 100644
index 0000000000..435d36d75b
--- /dev/null
+++ b/src/axolotl/integrations/LICENSE.md
@@ -0,0 +1,58 @@
+### AXOLOTL COMMUNITY LICENSE AGREEMENT
+
+This Axolotl Community License Agreement (“Agreement”) is entered into by and between Axolotl AI Corp. (“Axolotl”) and
+any individual or entity (“Licensee”) who wishes to use the Software (as defined below) in accordance with the terms
+and conditions set forth in this Agreement.
+
+1. Definitions
+ 1.1 “Licensee” refers to any individual or entity who has obtained a copy of the Software under this Agreement.
+ 1.2 “Plugin Integration” means independent integration software modules which may or may not be offered by Axolotl,
+ which may be licensed separately by their respective authors and/or licensors.
+ 1.3 “Software” refers to the specific sub-directory of the Axolotl, Inc. software located at
+ https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations and its subdirectories which
+ permits Plugin Integrations to integrate with the Axolotl service.
+2. Grant of License
+ 2.1 Axolotl hereby grants Licensee a worldwide, non-exclusive, royalty-free, license to use, copy, modify, merge,
+ publish, distribute, sublicense, and/or otherwise exploit the Software, subject to the following conditions:
+ - Licensee must comply with all the terms and conditions of this Agreement.
+ - Licensee must include the original copyright notice and disclaimer of warranty in all copies or substantial
+ portions of the Software.
+ 2.2 Licensee may use the Software for any lawful purpose, except as restricted in Section 3.
+3. Restrictions
+ 3.1 Licensee shall not use the Software for any activity that constitutes a commercial activity of offering for
+ free or for sale any services, platform, or equivalent to third parties for the purposes of allowing such
+ third parties to fine-tune artificial intelligence models.
+ 3.2 Licensee shall not:
+ - Use the Software for any illegal or unauthorized purpose.
+ - Reverse engineer, decompile, or disassemble the Software.
+ - Remove or modify any copyright, trademark, or other proprietary notices contained in the Software.
+ - Use the Software in a way that could damage, disable, overburden, or impair the functionality of the
+ Software or interfere with any third-party use of the Software.
+ 3.3 Axolotl reserves the right to restrict certain Plugin Integrations for use with the Software. To the extent Licensee integrates a permitted, applicable Plugin Integration with the Software, Licensee shall comply with any additional terms and conditions imposed by the licensors of such Plugin Integration for use of such Plugin Integrations. Licensee shall contact Axolotl if it has questions about whether its use of the Software falls beyond the scope of this Agreement.
+4. Intellectual Property Rights
+ 4.1 Axolotl and its contributors retain all intellectual property rights in and to the Software. Licensee
+ acknowledges that this Agreement does not transfer any ownership rights or intellectual property rights to
+ Licensee.
+5. Disclaimer of Warranty
+ 5.1 THE SOFTWARE IS PROVIDED “AS IS,” WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
+ TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT. IN NO EVENT SHALL
+ THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN ACTION OF
+ CONTRACT, TORT, OR OTHERWISE, ARISING FROM, OUT OF, OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+ DEALINGS IN THE SOFTWARE.
+6. Termination
+ 6.1 Axolotl may terminate this Agreement at any time if Licensee fails to comply with any of the terms and
+ conditions set forth herein. Upon termination, Licensee shall cease all use of the Software and destroy any
+ copies in its possession.
+7. Governing Law
+ 7.1 This Agreement shall be governed by and construed in accordance with the laws of the State of California,
+ without regards to conflicts of laws provisions thereof.
+8. Entire Agreement
+ 8.1 This Agreement constitutes the entire agreement between Axolotl and Licensee with respect to the subject matter
+ hereof and supersedes all prior or contemporaneous understandings or agreements between the parties concerning
+ the Software, whether written or oral. Axolotl may update the terms of this Agreement from time to time, and
+ Licensee’s continued use of the Software after any such updates shall constitute acceptance of updated terms
+ on a go-forward basis. Axolotl will use commercially reasonable efforts to provide Licensee notice of any
+ material updates. By using the Software, Licensee acknowledges that it has read, understood, and agrees to be
+ bound by the terms and conditions of this Agreement.
+
+This Agreement was last updated on August 23, 2024.
diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py
new file mode 100644
index 0000000000..e2bd79bc4d
--- /dev/null
+++ b/src/axolotl/integrations/base.py
@@ -0,0 +1,420 @@
+# Copyright 2024 Axolotl AI. All rights reserved.
+#
+# This software may be used and distributed according to
+# the terms of the Axolotl Community License Agreement (the "License");
+# you may not use this file except in compliance with the License.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+
+"""
+Base class for all plugins.
+
+A plugin is a reusable, modular, and self-contained piece of code that extends the functionality of Axolotl.
+Plugins can be used to integrate third-party models, modify the training process, or add new features.
+
+To create a new plugin, you need to inherit from the BasePlugin class and implement the required methods.
+"""
+import importlib
+import logging
+from typing import List
+
+
+class BasePlugin:
+ """
+ Base class for all plugins. Defines the interface for plugin methods.
+
+ Attributes:
+ None
+
+ Methods:
+ register(cfg): Registers the plugin with the given configuration.
+ pre_model_load(cfg): Performs actions before the model is loaded.
+ post_model_load(cfg, model): Performs actions after the model is loaded.
+ pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.
+ post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
+ create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
+ create_lr_scheduler(cfg, trainer, optimizer): Creates and returns a learning rate scheduler.
+ add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before training.
+ add_callbacks_post_trainer(cfg, trainer): Adds callbacks to the trainer after training.
+ """
+
+ def __init__(self):
+ """
+ Initializes the BasePlugin.
+ """
+
+ def register(self, cfg):
+ """
+ Registers the plugin with the given configuration.
+
+ Parameters:
+ cfg (dict): The configuration for the plugin.
+
+ Returns:
+ None
+ """
+
+ def get_input_args(self):
+ """
+ Returns a pydantic model for the plugin's input arguments.
+ """
+
+ def pre_model_load(self, cfg):
+ """
+ Performs actions before the model is loaded.
+
+ Parameters:
+ cfg (dict): The configuration for the plugin.
+
+ Returns:
+ None
+ """
+
+ def post_model_load(self, cfg, model):
+ """
+ Performs actions after the model is loaded.
+
+ Parameters:
+ cfg (dict): The configuration for the plugin.
+ model (object): The loaded model.
+
+ Returns:
+ None
+ """
+
+ def pre_lora_load(self, cfg, model):
+ """
+ Performs actions before LoRA weights are loaded.
+
+ Parameters:
+ cfg (dict): The configuration for the plugin.
+ model (object): The loaded model.
+
+ Returns:
+ None
+ """
+
+ def post_lora_load(self, cfg, model):
+ """
+ Performs actions after LoRA weights are loaded.
+
+ Parameters:
+ cfg (dict): The configuration for the plugin.
+ model (object): The loaded model.
+
+ Returns:
+ None
+ """
+
+ def create_optimizer(self, cfg, trainer):
+ """
+ Creates and returns an optimizer for training.
+
+ Parameters:
+ cfg (dict): The configuration for the plugin.
+ trainer (object): The trainer object for training.
+
+ Returns:
+ object: The created optimizer.
+ """
+
+ def create_lr_scheduler(self, cfg, trainer, optimizer):
+ """
+ Creates and returns a learning rate scheduler.
+
+ Parameters:
+ cfg (dict): The configuration for the plugin.
+ trainer (object): The trainer object for training.
+ optimizer (object): The optimizer for training.
+
+ Returns:
+ object: The created learning rate scheduler.
+ """
+
+ def add_callbacks_pre_trainer(self, cfg, model):
+ """
+ Adds callbacks to the trainer before training.
+
+ Parameters:
+ cfg (dict): The configuration for the plugin.
+ model (object): The loaded model.
+
+ Returns:
+ List[callable]: A list of callback functions to be added to the TrainingArgs
+ """
+
+ def add_callbacks_post_trainer(self, cfg, trainer):
+ """
+ Adds callbacks to the trainer after training.
+
+ Parameters:
+ cfg (dict): The configuration for the plugin.
+ trainer (object): The trainer object for training.
+
+ Returns:
+ List[callable]: A list of callback functions to be added to the TrainingArgs
+ """
+
+ def post_train(self, cfg, model):
+ """
+ Performs actions after training is complete.
+
+ Parameters:
+ cfg (dict): The axolotl configuration
+ model (object): The loaded model.
+
+ Returns:
+ None
+ """
+
+ def post_train_unload(self, cfg):
+ """
+ Performs actions after training is complete and the model is unloaded.
+
+ Parameters:
+ cfg (dict): The configuration for the plugin.
+
+ Returns:
+ None
+ """
+
+
+def load_plugin(plugin_name: str) -> BasePlugin:
+ """
+ Loads a plugin based on the given plugin name.
+
+ The plugin name should be in the format "module_name.class_name".
+ This function splits the plugin name into module and class, imports the module,
+ retrieves the class from the module, and creates an instance of the class.
+
+ Parameters:
+ plugin_name (str): The name of the plugin to be loaded. The name should be in the format "module_name.class_name".
+
+ Returns:
+ BasePlugin: An instance of the loaded plugin.
+
+ Raises:
+ ImportError: If the plugin module cannot be imported.
+ """
+ # split the plugin name into module and class
+ module_name, class_name = plugin_name.rsplit(".", 1)
+
+ # import the module
+ module = importlib.import_module(module_name)
+ # instantiate the class
+ plugin_class = getattr(module, class_name)
+ # create an instance of the class
+ plugin = plugin_class()
+
+ return plugin
+
+
+class PluginManager:
+ """
+ The PluginManager class is responsible for loading and managing plugins.
+ It should be a singleton so it can be accessed from anywhere in the codebase.
+
+ Attributes:
+ plugins (List[BasePlugin]): A list of loaded plugins.
+
+ Methods:
+ get_instance(): Static method to get the singleton instance of PluginManager.
+ register(plugin_name: str): Registers a new plugin by its name.
+ pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.
+ """
+
+ plugins: List[BasePlugin] = []
+
+ _instance = None
+
+ def __new__(cls):
+ """
+ Creates a new instance of PluginManager if it doesn't exist yet.
+ """
+ if cls._instance is None:
+ cls._instance = super(PluginManager, cls).__new__(cls)
+ cls._instance.plugins: List[BasePlugin] = []
+ return cls._instance
+
+ @staticmethod
+ def get_instance() -> "PluginManager":
+ """
+ Returns the singleton instance of PluginManager.
+ If the instance doesn't exist, it creates a new one.
+ """
+ if PluginManager._instance is None:
+ PluginManager()
+ return PluginManager._instance # type: ignore
+
+ def register(self, plugin_name: str):
+ """
+ Registers a new plugin by its name.
+
+ Parameters:
+ plugin_name (str): The name of the plugin to be registered.
+
+ Returns:
+ None
+
+ Raises:
+ ImportError: If the plugin module cannot be imported.
+ """
+ try:
+ plugin = load_plugin(plugin_name)
+ self.plugins.append(plugin)
+ except ImportError:
+ logging.error(f"Failed to load plugin: {plugin_name}")
+
+ def get_input_args(self):
+ """
+ Returns a list of Pydantic classes for all registered plugins' input arguments.'
+
+ Returns:
+ list[str]: A list of Pydantic classes for all registered plugins' input arguments.'
+ """
+ input_args = []
+ for plugin in self.plugins:
+ input_args_from_plugin = plugin.get_input_args()
+ if input_args_from_plugin is not None:
+ input_args.append(input_args_from_plugin)
+ return input_args
+
+ def pre_model_load(self, cfg):
+ """
+ Calls the pre_model_load method of all registered plugins.
+
+ Parameters:
+ cfg (dict): The configuration for the plugins.
+
+ Returns:
+ None
+ """
+ for plugin in self.plugins:
+ plugin.pre_model_load(cfg)
+
+ def post_model_load(self, cfg, model):
+ """
+ Calls the post_model_load method of all registered plugins.
+
+ Parameters:
+ cfg (dict): The configuration for the plugins.
+ model (object): The loaded model.
+
+ Returns:
+ None
+ """
+ for plugin in self.plugins:
+ plugin.post_model_load(cfg, model)
+
+ def pre_lora_load(self, cfg, model):
+ """
+ Calls the pre_lora_load method of all registered plugins.
+
+ Parameters:
+ cfg (dict): The configuration for the plugins.
+ model (object): The loaded model.
+
+ Returns:
+ None
+ """
+ for plugin in self.plugins:
+ plugin.pre_lora_load(cfg, model)
+
+ def post_lora_load(self, cfg, model):
+ """
+ Calls the post_lora_load method of all registered plugins.
+
+ Parameters:
+ cfg (dict): The configuration for the plugins.
+ model (object): The loaded model.
+
+ Returns:
+ None
+ """
+ for plugin in self.plugins:
+ plugin.post_lora_load(cfg, model)
+
+ def create_optimizer(self, cfg, trainer):
+ """
+ Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer.
+
+ Parameters:
+ cfg (dict): The configuration for the plugins.
+ trainer (object): The trainer object for training.
+
+ Returns:
+ object: The created optimizer, or None if none was found.
+ """
+ for plugin in self.plugins:
+ optimizer = plugin.create_optimizer(cfg, trainer)
+ if optimizer is not None:
+ return optimizer
+ return None
+
+ def create_lr_scheduler(self, cfg, trainer, optimizer):
+ """
+ Calls the create_lr_scheduler method of all registered plugins and returns the first non-None scheduler.
+
+ Parameters:
+ cfg (dict): The configuration for the plugins.
+ trainer (object): The trainer object for training.
+ optimizer (object): The optimizer for training.
+
+ Returns:
+ object: The created learning rate scheduler, or None if none was found.
+ """
+ for plugin in self.plugins:
+ scheduler = plugin.create_lr_scheduler(cfg, trainer, optimizer)
+ if scheduler is not None:
+ return scheduler
+ return None
+
+ def add_callbacks_pre_trainer(self, cfg, model):
+ """
+ Calls the add_callbacks_pre_trainer method of all registered plugins.
+
+ Parameters:
+ cfg (dict): The configuration for the plugins.
+ model (object): The loaded model.
+
+ Returns:
+ List[callable]: A list of callback functions to be added to the TrainingArgs.
+ """
+ callbacks = []
+ for plugin in self.plugins:
+ callbacks.extend(plugin.add_callbacks_pre_trainer(cfg, model))
+ return callbacks
+
+ def add_callbacks_post_trainer(self, cfg, trainer):
+ """
+ Calls the add_callbacks_post_trainer method of all registered plugins.
+
+ Parameters:
+ cfg (dict): The configuration for the plugins.
+ trainer (object): The trainer object for training.
+
+ Returns:
+ List[callable]: A list of callback functions to be added to the TrainingArgs.
+ """
+ callbacks = []
+ for plugin in self.plugins:
+ callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
+ return callbacks
+
+ def post_train_unload(self, cfg):
+ """
+ Calls the post_train_unload method of all registered plugins.
+
+ Parameters:
+ cfg (dict): The configuration for the plugins.
+ model (object): The loaded model.
+
+ Returns:
+ None
+ """
+ for plugin in self.plugins:
+ plugin.post_train_unload(cfg)
diff --git a/src/axolotl/integrations/config.py b/src/axolotl/integrations/config.py
new file mode 100644
index 0000000000..b4ffd6758f
--- /dev/null
+++ b/src/axolotl/integrations/config.py
@@ -0,0 +1,65 @@
+# Copyright 2024 Axolotl AI. All rights reserved.
+#
+# This software may be used and distributed according to
+# the terms of the Axolotl Community License Agreement (the "License");
+# you may not use this file except in compliance with the License.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+
+"""
+module to handle merging the plugins' input arguments with the base configurations.
+
+this was moved here to prevent circular imports
+"""
+
+from typing import Any, Dict, List
+
+from axolotl.utils.config.models.input.v0_4_1 import (
+ AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
+)
+from axolotl.utils.config.models.input.v0_4_1 import (
+ AxolotlInputConfig as AxolotlInputConfigBase,
+)
+
+
+def merge_input_args():
+ """
+ Merges input arguments from registered plugins with the base configurations.
+
+ This function retrieves the input arguments from registered plugins using the PluginManager.
+ It then dynamically creates new classes, AxolotlConfigWCapabilities and AxolotlInputConfig,
+ that inherit from the base configurations and include the input arguments from the plugins.
+
+ Returns:
+ tuple: A tuple containing the newly created classes, AxolotlConfigWCapabilities and AxolotlInputConfig.
+ """
+ from axolotl.integrations.base import PluginManager
+
+ plugin_manager = PluginManager.get_instance()
+ input_args: List[str] = plugin_manager.get_input_args()
+ plugin_classes = []
+ dynamic_input = ""
+ for plugin_args in input_args:
+ plugin_module, plugin_cls = plugin_args.rsplit(".", 1)
+ dynamic_input += f"from {plugin_module} import {plugin_cls}\n"
+ plugin_classes.append(plugin_cls)
+ if dynamic_input:
+ dynamic_input += f"class AxolotlConfigWCapabilities(AxolotlConfigWCapabilitiesBase, {', '.join(plugin_classes)}):\n pass\n"
+ dynamic_input += f"class AxolotlInputConfig(AxolotlInputConfigBase, {', '.join(plugin_classes)}):\n pass\n"
+
+ namespace: Dict[Any, Any] = {}
+ exec( # pylint: disable=exec-used # nosec B102
+ dynamic_input, globals(), namespace
+ )
+ AxolotlInputConfig = namespace[ # pylint: disable=invalid-name
+ "AxolotlInputConfig"
+ ]
+ AxolotlConfigWCapabilities = namespace[ # pylint: disable=invalid-name
+ "AxolotlConfigWCapabilities"
+ ]
+ return AxolotlConfigWCapabilities, AxolotlInputConfig
+ return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase
diff --git a/src/axolotl/integrations/liger/LICENSE b/src/axolotl/integrations/liger/LICENSE
new file mode 100644
index 0000000000..d645695673
--- /dev/null
+++ b/src/axolotl/integrations/liger/LICENSE
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/src/axolotl/integrations/liger/__init__.py b/src/axolotl/integrations/liger/__init__.py
new file mode 100644
index 0000000000..2047f3815d
--- /dev/null
+++ b/src/axolotl/integrations/liger/__init__.py
@@ -0,0 +1,189 @@
+# Copyright 2024 Axolotl AI. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Module for the Plugin for LIGER integraton with Axolotl.
+
+Liger Kernel is the collection of Triton-native kernels for LLM Training.
+It is designed to be performant, correct, and light-weight.
+"""
+import logging
+import sys
+from functools import partial
+
+from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
+from liger_kernel.transformers.geglu import LigerGEGLUMLP
+from liger_kernel.transformers.rms_norm import LigerRMSNorm
+from liger_kernel.transformers.rope import liger_rotary_pos_emb
+from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
+
+from axolotl.integrations.base import BasePlugin
+
+from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
+
+
+class LigerPlugin(BasePlugin):
+ """
+ Plugin for LIGER integraton with Axolotl.
+ """
+
+ def get_input_args(self):
+ return "axolotl.integrations.liger.LigerArgs"
+
+ def pre_model_load(self, cfg):
+ if cfg.model_config_type == "llama":
+ from liger_kernel.transformers.model.llama import (
+ lce_forward as llama_lce_forward,
+ )
+ from transformers.models.llama import modeling_llama
+
+ if cfg.liger_rope:
+ modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb
+ if cfg.liger_rms_norm:
+ modeling_llama.LlamaRMSNorm = LigerRMSNorm
+ if cfg.liger_swiglu:
+ modeling_llama.LlamaMLP = LigerSwiGLUMLP
+ if cfg.liger_cross_entropy:
+ modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
+ elif cfg.liger_fused_linear_cross_entropy:
+ modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
+
+ elif cfg.model_config_type == "mistral":
+ from liger_kernel.transformers.model.mistral import (
+ lce_forward as mistral_lce_forward,
+ )
+ from transformers.models.mistral import modeling_mistral
+
+ if cfg.liger_rope:
+ modeling_mistral.apply_rotary_pos_emb = liger_rotary_pos_emb
+ if cfg.liger_rms_norm:
+ modeling_mistral.MistralRMSNorm = LigerRMSNorm
+ if cfg.liger_swiglu:
+ modeling_mistral.MistralMLP = LigerSwiGLUMLP
+ if cfg.liger_cross_entropy:
+ modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
+ if cfg.liger_fused_linear_cross_entropy:
+ modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
+
+ elif cfg.model_config_type == "gemma":
+ from liger_kernel.transformers.model.gemma import (
+ lce_forward as gemma_lce_forward,
+ )
+ from transformers.models.gemma import modeling_gemma
+
+ if cfg.liger_rope:
+ modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
+ if cfg.liger_rms_norm:
+ modeling_gemma.GemmaRMSNorm = partial(
+ LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
+ )
+ if cfg.liger_swiglu:
+ modeling_gemma.GemmaMLP = LigerGEGLUMLP
+ if cfg.liger_cross_entropy:
+ modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
+ if cfg.liger_fused_linear_cross_entropy:
+ modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
+
+ elif cfg.model_config_type == "jamba":
+ from transformers.models.jamba import modeling_jamba
+
+ from .models.jamba import lce_forward as jamba_lce_forward
+
+ if cfg.liger_rope:
+ modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb
+ if cfg.liger_rms_norm:
+ modeling_jamba.JambaRMSNorm = LigerRMSNorm
+ if cfg.liger_swiglu:
+ modeling_jamba.JambaMLP = LigerSwiGLUMLP
+ if cfg.liger_cross_entropy:
+ modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss
+ if cfg.liger_fused_linear_cross_entropy:
+ modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
+
+ elif cfg.model_config_type == "qwen2":
+ from liger_kernel.transformers.model.qwen2 import (
+ lce_forward as qwen2_lce_forward,
+ )
+ from transformers.models.qwen2 import modeling_qwen2
+
+ if cfg.liger_rope:
+ modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
+ if cfg.liger_rms_norm:
+ modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
+ if cfg.liger_swiglu:
+ modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
+ if cfg.liger_cross_entropy:
+ modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
+ if cfg.liger_fused_linear_cross_entropy:
+ modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
+
+ elif cfg.model_config_type == "deepseek_v2":
+ from accelerate import init_empty_weights
+ from transformers import AutoModelForCausalLM
+
+ with init_empty_weights():
+ model = AutoModelForCausalLM.from_pretrained(
+ cfg.base_model, trust_remote_code=cfg.trust_remote_code or False
+ )
+ modeling_mod = sys.modules[model.__class__.__module__]
+
+ from .models.deepseekv2 import lce_forward as deepseekv2_lce_forward
+
+ if cfg.liger_rope:
+ # The DeepseekV2 version of RoPE is different than upstream LLaMA.
+ # See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528
+ logging.warning("Fused liger_rope is not supported for DeepseekV2.")
+ if cfg.liger_rms_norm:
+ modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
+ if cfg.liger_swiglu:
+ modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
+ if cfg.liger_cross_entropy:
+ modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
+ if cfg.liger_fused_linear_cross_entropy:
+ modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
+
+ elif cfg.model_config_type == "gemma2":
+ from transformers.models.gemma2 import modeling_gemma2
+
+ if cfg.liger_rope:
+ modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
+ if cfg.liger_rms_norm:
+ modeling_gemma2.Gemma2RMSNorm = partial(
+ LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
+ )
+ if cfg.liger_swiglu:
+ modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
+ if cfg.liger_cross_entropy:
+ modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
+ if cfg.liger_fused_linear_cross_entropy:
+ logging.warning(
+ "Fused linear cross entropy is not supported for Gemma 2."
+ )
+
+ elif cfg.model_config_type == "phi3":
+ from liger_kernel.transformers.model.phi3 import (
+ lce_forward as phi3_lce_forward,
+ )
+ from transformers.models.phi3 import modeling_phi3
+
+ if cfg.liger_rope:
+ modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb
+ if cfg.liger_rms_norm:
+ modeling_phi3.Phi3RMSNorm = LigerRMSNorm
+ if cfg.liger_swiglu:
+ modeling_phi3.Phi3MLP = LigerSwiGLUMLP
+ if cfg.liger_cross_entropy:
+ modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
+ if cfg.liger_fused_linear_cross_entropy:
+ modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
diff --git a/src/axolotl/integrations/liger/args.py b/src/axolotl/integrations/liger/args.py
new file mode 100644
index 0000000000..decdb37750
--- /dev/null
+++ b/src/axolotl/integrations/liger/args.py
@@ -0,0 +1,32 @@
+# Copyright 2024 Axolotl AI. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Module for handling LIGER input arguments.
+"""
+from typing import Optional
+
+from pydantic import BaseModel
+
+
+class LigerArgs(BaseModel):
+ """
+ Input args for LIGER.
+ """
+
+ liger_rope: Optional[bool] = None
+ liger_rms_norm: Optional[bool] = None
+ liger_swiglu: Optional[bool] = None
+ liger_cross_entropy: Optional[bool] = None
+ liger_fused_linear_cross_entropy: Optional[bool] = None
diff --git a/src/axolotl/integrations/liger/models/deepseekv2.py b/src/axolotl/integrations/liger/models/deepseekv2.py
new file mode 100644
index 0000000000..79fb274360
--- /dev/null
+++ b/src/axolotl/integrations/liger/models/deepseekv2.py
@@ -0,0 +1,127 @@
+"""
+DeepseekV2 model with LigerFusedLinearCrossEntropyLoss
+"""
+# pylint: disable=duplicate-code
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+from liger_kernel.transformers.fused_linear_cross_entropy import (
+ LigerFusedLinearCrossEntropyLoss,
+)
+from torch.nn import CrossEntropyLoss
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+
+# @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)
+# @replace_return_docstrings(
+# output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
+# )
+def lce_forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, transformers.,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, transformers., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, DeepseekV2ForCausalLM
+
+ >>> model = DeepseekV2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+
+ loss = None
+ logits = None
+
+ if self.training:
+ shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+
+ # flatten tokens
+ shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+ shift_labels = shift_labels.view(-1)
+
+ lce = LigerFusedLinearCrossEntropyLoss()
+ loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+ else:
+ logits = self.lm_head(hidden_states)
+ logits = logits.float()
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
diff --git a/src/axolotl/integrations/liger/models/jamba.py b/src/axolotl/integrations/liger/models/jamba.py
new file mode 100644
index 0000000000..40cec63a4f
--- /dev/null
+++ b/src/axolotl/integrations/liger/models/jamba.py
@@ -0,0 +1,173 @@
+"""
+Jamba model with LigerFusedLinearCrossEntropyLoss
+"""
+# pylint: disable=duplicate-code
+
+from typing import Optional, Tuple, Union
+
+import torch
+from liger_kernel.transformers.fused_linear_cross_entropy import (
+ LigerFusedLinearCrossEntropyLoss,
+)
+from torch.nn import CrossEntropyLoss
+from transformers.modeling_outputs import MoeCausalLMOutputWithPast
+from transformers.models.jamba.modeling_jamba import (
+ _CONFIG_FOR_DOC,
+ JAMBA_INPUTS_DOCSTRING,
+ HybridMambaAttentionDynamicCache,
+ load_balancing_loss_func,
+)
+from transformers.utils import (
+ add_start_docstrings_to_model_forward,
+ replace_return_docstrings,
+)
+
+
+@add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING)
+@replace_return_docstrings(
+ output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
+)
+def lce_forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ output_router_logits: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ num_logits_to_keep: Optional[Union[int, None]] = None,
+) -> Union[Tuple, MoeCausalLMOutputWithPast]:
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ num_logits_to_keep (`int` or `None`, *optional*):
+ Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all
+ `input_ids`. Only last token logits are needed for generation, and calculating them only for that token
+ can save memory, which becomes pretty significant for long sequences.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, JambaForCausalLM
+
+ >>> model = JambaForCausalLM.from_pretrained("ai21labs/Jamba-v0.1")
+ >>> tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_router_logits = (
+ output_router_logits
+ if output_router_logits is not None
+ else self.config.output_router_logits
+ )
+
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ output_router_logits=output_router_logits,
+ cache_position=cache_position,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+
+ loss = None
+ logits = None
+
+ if self.training:
+ shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+
+ # flatten tokens
+ shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+ shift_labels = shift_labels.view(-1)
+
+ lce = LigerFusedLinearCrossEntropyLoss()
+ loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+ else:
+ if num_logits_to_keep is None:
+ logits = self.lm_head(hidden_states)
+ else:
+ logits = self.lm_head(hidden_states[..., -num_logits_to_keep:, :])
+ logits = logits.float()
+
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+
+ aux_loss = None
+ if output_router_logits:
+ aux_loss = load_balancing_loss_func(
+ outputs.router_logits if return_dict else outputs[-1],
+ self.num_experts,
+ self.num_experts_per_tok,
+ attention_mask,
+ )
+ if labels is not None:
+ loss += self.router_aux_loss_coef * aux_loss.to(
+ loss.device
+ ) # make sure to reside in the same device
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ if output_router_logits:
+ output = (aux_loss,) + output
+ return (loss,) + output if loss is not None else output
+
+ return MoeCausalLMOutputWithPast(
+ loss=loss,
+ aux_loss=aux_loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ router_logits=outputs.router_logits,
+ )
diff --git a/src/axolotl/integrations/lm_eval/README.md b/src/axolotl/integrations/lm_eval/README.md
new file mode 100644
index 0000000000..3724c49ccf
--- /dev/null
+++ b/src/axolotl/integrations/lm_eval/README.md
@@ -0,0 +1,13 @@
+# LM Eval Harness
+
+### Usage
+
+```yaml
+plugins:
+ - axolotl.integrations.lm_eval.LMEvalPlugin
+
+lm_eval_tasks:
+ - gsm8k
+ - hellaswag
+ - arc_easy
+```
diff --git a/src/axolotl/integrations/lm_eval/__init__.py b/src/axolotl/integrations/lm_eval/__init__.py
new file mode 100644
index 0000000000..f1daa20000
--- /dev/null
+++ b/src/axolotl/integrations/lm_eval/__init__.py
@@ -0,0 +1,42 @@
+"""
+Module for the Plugin for LM Eval Harness
+"""
+import subprocess # nosec
+from datetime import datetime
+
+from axolotl.integrations.base import BasePlugin
+
+from .args import LMEvalArgs # pylint: disable=unused-import. # noqa: F401
+
+
+class LMEvalPlugin(BasePlugin):
+ """
+ Plugin for LM Evaluation Harness integraton with Axolotl.
+ """
+
+ def get_input_args(self):
+ return "axolotl.integrations.lm_eval.LMEvalArgs"
+
+ def post_train_unload(self, cfg):
+ tasks = ",".join(cfg.lm_eval_tasks)
+ fa2 = ",attn_implementation=flash_attention_2" if cfg.flash_attention else ""
+ dtype = ",dtype=bfloat16" if cfg.bf16 else ",dtype=float16"
+ output_path = cfg.output_dir
+ output_path += "" if cfg.output_dir.endswith("/") else "/"
+ output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S")
+ subprocess.run( # nosec
+ [
+ "lm_eval",
+ "--model",
+ "hf",
+ "--model_args",
+ f"pretrained={cfg.output_dir}{fa2}{dtype}",
+ "--tasks",
+ tasks,
+ "--batch_size",
+ str(cfg.lm_eval_batch_size),
+ "--output_path",
+ output_path,
+ ],
+ check=True,
+ )
diff --git a/src/axolotl/integrations/lm_eval/args.py b/src/axolotl/integrations/lm_eval/args.py
new file mode 100644
index 0000000000..f58e6a6e38
--- /dev/null
+++ b/src/axolotl/integrations/lm_eval/args.py
@@ -0,0 +1,15 @@
+"""
+Module for handling lm eval harness input arguments.
+"""
+from typing import List, Optional
+
+from pydantic import BaseModel
+
+
+class LMEvalArgs(BaseModel):
+ """
+ Input args for lm eval harness
+ """
+
+ lm_eval_tasks: List[str] = []
+ lm_eval_batch_size: Optional[int] = 8
diff --git a/src/axolotl/integrations/spectrum/LICENSE b/src/axolotl/integrations/spectrum/LICENSE
new file mode 100644
index 0000000000..d645695673
--- /dev/null
+++ b/src/axolotl/integrations/spectrum/LICENSE
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/src/axolotl/integrations/spectrum/README.md b/src/axolotl/integrations/spectrum/README.md
new file mode 100644
index 0000000000..192918060e
--- /dev/null
+++ b/src/axolotl/integrations/spectrum/README.md
@@ -0,0 +1,21 @@
+## Spectrum: Targeted Training on Signal to Noise Ratio
+
+by Eric Hartford, Lucas Atkins, Fernando Fernandes, David Golchinfar
+
+This plugin contains code to freeze the bottom fraction of modules in a model, based on the Signal-to-Noise Ratio (SNR).
+
+### Overview
+
+Spectrum is a tool for scanning and evaluating the Signal-to-Noise Ratio (SNR) of layers in large language models.
+By identifying the top n% of layers with the highest SNR, you can optimize training efficiency.
+
+### Usage
+
+```yaml
+plugins:
+ - axolotl.integrations.spectrum.SpectrumPlugin
+
+spectrum_top_fraction: 0.5
+# Optional if using a pre-scanned model as your base_model. Useful if using a model mirror
+spectrum_model_name: meta-llama/Meta-Llama-3.1-8B
+```
diff --git a/src/axolotl/integrations/spectrum/__init__.py b/src/axolotl/integrations/spectrum/__init__.py
new file mode 100644
index 0000000000..6059e7951c
--- /dev/null
+++ b/src/axolotl/integrations/spectrum/__init__.py
@@ -0,0 +1,102 @@
+# Copyright 2024 Axolotl AI. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Spectrum Plugin to automatically generate unfrozen parameters based on SNR data.
+"""
+
+import json
+import logging
+
+import requests
+
+from axolotl.integrations.base import BasePlugin
+
+from .args import SpectrumArgs # pylint: disable=unused-import. # noqa: F401
+
+
+def _generate_unfrozen_params_yaml(snr_data, top_fraction=0.5):
+ unfrozen_parameters = {}
+ for layer_name, info in snr_data.items():
+ layer_type = info["type"]
+ if layer_type not in unfrozen_parameters:
+ unfrozen_parameters[layer_type] = []
+ unfrozen_parameters[layer_type].append((layer_name, info["snr"]))
+ top_layers_by_type = {}
+ for layer_type, layers in unfrozen_parameters.items():
+ layers_sorted = sorted(layers, key=lambda x: x[1], reverse=True)
+ num_top_layers = int(len(layers) * top_fraction)
+ top_layers_by_type[layer_type] = [
+ layer[0] for layer in layers_sorted[:num_top_layers]
+ ]
+ unfrozen_parameters = [
+ "^lm_head.weight$",
+ "^model.embed_tokens.weight$",
+ ]
+ for layer_type, layer_names in top_layers_by_type.items():
+ for layer_name in layer_names:
+ unfrozen_parameters.append(layer_name)
+ return unfrozen_parameters
+
+
+class SpectrumPlugin(BasePlugin):
+ """
+ Spectrum Plugin to automatically generate unfrozen parameters based on SNR data.
+ """
+
+ base_url = "https://raw.githubusercontent.com/cognitivecomputations/spectrum/main/model_snr_results/"
+ base_path = "./model_snr_results/"
+ snr_file_template = "snr_results_{model_name_slug}.json"
+
+ def get_input_args(self):
+ return "axolotl.integrations.spectrum.SpectrumArgs"
+
+ def pre_model_load(self, cfg):
+ if cfg.get("spectrum_model_name"):
+ model_name = cfg["spectrum_model_name"]
+ else:
+ model_name = cfg["base_model"]
+ top_fraction = cfg.get("spectrum_top_fraction", 50)
+ model_slug = model_name.replace("/", "-").replace("_", "-")
+ snr_url = self.base_url + self.snr_file_template.format(
+ model_name_slug=model_slug
+ )
+ snr_path = self.base_path + self.snr_file_template.format(
+ model_name_slug=model_slug
+ )
+ # first check if the files exist locally and read the json
+ snr_data = None
+ try:
+ with open(snr_path, "r", encoding="utf-8") as fin:
+ snr_data = json.load(fin)
+ except FileNotFoundError:
+ pass
+ except Exception as exc: # pylint: disable=broad-exception-caught
+ logging.warning(f"Failed to read SNR data from {snr_path}: {exc}")
+
+ if not snr_data:
+ try:
+ snr_data = requests.get(snr_url, timeout=60).json()
+ except requests.exceptions.RequestException as exc:
+ logging.warning(f"Failed to fetch SNR data from {snr_url}: {exc}")
+ return
+ # also catch json parsing errors
+ except json.JSONDecodeError as exc:
+ logging.warning(f"Failed to parse SNR data from {snr_url}: {exc}")
+ return
+
+ unfrozen_parameters = _generate_unfrozen_params_yaml(
+ snr_data, top_fraction=top_fraction
+ )
+ cfg["unfrozen_parameters"] = unfrozen_parameters
diff --git a/src/axolotl/integrations/spectrum/args.py b/src/axolotl/integrations/spectrum/args.py
new file mode 100644
index 0000000000..03426d8413
--- /dev/null
+++ b/src/axolotl/integrations/spectrum/args.py
@@ -0,0 +1,29 @@
+# Copyright 2024 Axolotl AI. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Module for handling Spectrum input arguments.
+"""
+from typing import Optional
+
+from pydantic import BaseModel
+
+
+class SpectrumArgs(BaseModel):
+ """
+ Input args for Spectrum.
+ """
+
+ spectrum_top_fraction: Optional[float] = 0.5
+ spectrum_model_name: Optional[str] = None
diff --git a/src/axolotl/loraplus.py b/src/axolotl/loraplus.py
deleted file mode 100644
index b4abec55ad..0000000000
--- a/src/axolotl/loraplus.py
+++ /dev/null
@@ -1,133 +0,0 @@
-"""Module for LoRA+"""
-
-# MIT License
-#
-# Copyright (c) 2024 nikhil-ghosh-berkeley
-# https://github.com/nikhil-ghosh-berkeley/loraplus
-
-import logging
-from functools import reduce
-
-from peft.tuners import lora
-from torch import nn
-from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
-from transformers.trainer_pt_utils import get_parameter_names
-
-LOG = logging.getLogger("axolotl.loraplus")
-
-
-def get_module(name, opt_model):
- """
- Retrieve a module from a model using its parameter name.
- Args:
- name (str): Full name of the parameter, typically including module path.
- opt_model (torch.nn.Module): The model from which to retrieve the module.
-
- Returns:
- Module corresponding to the given name.
- """
- parent_idx = 2 if "lora" in name else 1
- module_names = name.split(sep=".")[:-parent_idx]
- module = reduce(getattr, module_names, opt_model)
- return module
-
-
-def create_loraplus_optimizer(
- opt_model,
- optimizer_cls,
- optimizer_kwargs,
- loraplus_lr_ratio,
- loraplus_lr_embedding=None,
-):
- """
- Creates an optimizer for the given model, applying LoRA-specific learning rate adjustments to different parameter groups.
-
- Args:
- opt_model (torch.nn.Module): The model for which the optimizer is being created.
- optimizer_cls (class): The class of the optimizer to be used (e.g., torch.optim.Adam).
- optimizer_kwargs (dict): A dictionary of keyword arguments for the optimizer's initialization.
- loraplus_lr_ratio (float): The learning rate ratio to be applied to LoRA parameters.
- loraplus_lr_embedding (float, optional): A specific learning rate for embedding parameters, with a default value if not provided.
-
- Returns:
- An instance of the specified optimizer class configured with the model's parameters organized into groups with custom learning rates.
- """
-
- assert loraplus_lr_ratio is not None, "loraplus_lr_ratio must be provided."
-
- if loraplus_lr_embedding is None:
- loraplus_lr_embedding = 1e-6
-
- decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
- decay_parameters = [name for name in decay_parameters if "bias" not in name]
- param_groups = {
- "groupA": {},
- "groupB": {},
- "groupB_no_decay": {},
- "embedding": {},
- }
-
- for name, param in opt_model.named_parameters():
- if not param.requires_grad:
- continue
-
- module = get_module(name, opt_model)
- if isinstance(module, lora.Embedding):
- param_groups["embedding"][name] = param
- elif "lora_B" in name or param.ndim == 1:
- if name in decay_parameters:
- param_groups["groupB"][name] = param
- else:
- param_groups["groupB_no_decay"][name] = param
- else:
- param_groups["groupA"][name] = param
-
- assigned_param_groups = ""
- for group, group_params in param_groups.items():
- assigned_param_groups += f"{group}\n {list(group_params.keys())}\n\n"
- LOG.info(assigned_param_groups)
-
- lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
- weight_decay = optimizer_kwargs.get("weight_decay", 0.0)
-
- optimizer_grouped_parameters = [
- {
- "params": list(param_groups["groupA"].values()),
- "weight_decay": weight_decay,
- "lr": lr,
- },
- {
- "params": list(param_groups["embedding"].values()),
- "weight_decay": weight_decay,
- "lr": loraplus_lr_embedding,
- },
- {
- "params": list(param_groups["groupB"].values()),
- "weight_decay": weight_decay,
- "lr": lr * loraplus_lr_ratio,
- },
- {
- "params": list(param_groups["groupB_no_decay"].values()),
- "weight_decay": 0.0,
- "lr": lr * loraplus_lr_ratio,
- },
- ]
-
- optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
- if optimizer_cls.__name__ == "Adam8bit":
- import bitsandbytes
-
- manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
-
- skipped = 0
- for module in opt_model.modules():
- if isinstance(module, nn.Embedding):
- skipped += sum(
- {p.data_ptr(): p.numel() for p in module.parameters()}.values()
- )
- LOG.info(f"skipped {module}: {skipped/2**20}M params")
- manager.register_module_override(module, "weight", {"optim_bits": 32})
- LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
- LOG.info(f"skipped: {skipped/2**20}M params")
-
- return optimizer
diff --git a/src/axolotl/monkeypatch/attention/mllama.py b/src/axolotl/monkeypatch/attention/mllama.py
new file mode 100644
index 0000000000..0b18b716d5
--- /dev/null
+++ b/src/axolotl/monkeypatch/attention/mllama.py
@@ -0,0 +1,229 @@
+"""
+Monkeypatch for Vision Llama for FA2 support
+"""
+# pylint: disable=duplicate-code
+
+from typing import Optional, Tuple
+
+import torch
+from flash_attn.flash_attn_interface import flash_attn_func
+from transformers.cache_utils import Cache
+from transformers.modeling_flash_attention_utils import _flash_attention_forward
+from transformers.models.mllama.configuration_mllama import MllamaTextConfig
+from transformers.models.mllama.modeling_mllama import (
+ MllamaTextCrossAttention,
+ MllamaTextSelfAttention,
+ apply_rotary_pos_emb,
+ repeat_kv,
+)
+from transformers.utils import is_flash_attn_greater_or_equal_2_10
+
+
+class MllamaTextCrossFlashAttention2(MllamaTextCrossAttention):
+ """
+ Mllama flash cross-attention module. This module inherits from `MllamaTextCrossAttention` and
+ implements the forward pass using Flash Attention for improved performance.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # Check if flash attention version is greater or equal to 2.1
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cross_attention_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Cache] = None,
+ attention_mask: Optional[ # pylint: disable=unused-argument
+ torch.Tensor
+ ] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False, # pylint: disable=unused-argument
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ query_states = query_states.view(
+ bsz, q_len, self.num_heads, self.head_dim
+ ).transpose(1, 2)
+ query_states = self.q_norm(query_states)
+
+ if cross_attention_states is not None:
+ key_states = self.k_proj(cross_attention_states)
+ value_states = self.v_proj(cross_attention_states)
+ key_states = key_states.view(
+ bsz, -1, self.num_key_value_heads, self.head_dim
+ ).transpose(1, 2)
+ value_states = value_states.view(
+ bsz, -1, self.num_key_value_heads, self.head_dim
+ ).transpose(1, 2)
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ key_states = self.k_norm(key_states)
+ if past_key_value is not None:
+ key_states, value_states = past_key_value.update(
+ key_states,
+ value_states,
+ self.layer_idx,
+ {"cache_position": cache_position},
+ )
+ elif cache_position[0] != 0:
+ key_states, value_states = (
+ past_key_value.key_cache[self.layer_idx],
+ past_key_value.value_cache[self.layer_idx],
+ )
+ else:
+ raise ValueError(
+ "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!"
+ )
+
+ # Transpose to get the expected layout for flash attention
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ # Apply Flash Attention
+ dropout_rate = self.dropout if self.training else 0.0
+ output = flash_attn_func(
+ query_states,
+ key_states,
+ value_states,
+ dropout_p=dropout_rate,
+ softmax_scale=None,
+ causal=False,
+ return_attn_probs=output_attentions,
+ )
+
+ attn_output = output.contiguous().view(bsz, q_len, -1)
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class MllamaTextSelfFlashAttention2(MllamaTextSelfAttention):
+ """
+ Mllama flash self-attention module. This module inherits from `MllamaTextSelfAttention` and
+ implements the forward pass using Flash Attention for improved performance.
+ """
+
+ def __init__(self, config: MllamaTextConfig, layer_idx: int, *args, **kwargs):
+ super().__init__(config, layer_idx, *args, **kwargs)
+
+ # Check if flash attention version is greater or equal to 2.1
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False, # pylint: disable=unused-argument
+ past_key_value=None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs, # pylint: disable=unused-argument
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x num_heads x head_dim
+ query_states = query_states.view(
+ bsz, q_len, self.num_heads, self.head_dim
+ ).transpose(1, 2)
+ key_states = key_states.view(
+ bsz, q_len, self.num_key_value_heads, self.head_dim
+ ).transpose(1, 2)
+ value_states = value_states.view(
+ bsz, q_len, self.num_key_value_heads, self.head_dim
+ ).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(
+ query_states, key_states, cos, sin
+ )
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(
+ key_states, value_states, self.layer_idx, cache_kwargs
+ )
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ # Transpose to get the expected layout for flash attention
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ dropout_rate = self.dropout if self.training else 0.0
+
+ # Handle potential silent casting to float32
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = (
+ self.config._pre_quantization_dtype # pylint: disable=protected-access
+ )
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ dropout=dropout_rate,
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ is_causal=True,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+def patch_mllama():
+ from transformers.models.mllama.modeling_mllama import (
+ MLLAMA_TEXT_ATTENTION_CLASSES,
+ MLLAMA_TEXT_CROSS_ATTENTION_CLASSES,
+ MLLAMA_VISION_ATTENTION_CLASSES,
+ MllamaPreTrainedModel,
+ )
+
+ MllamaPreTrainedModel._supports_flash_attn_2 = ( # pylint: disable=protected-access
+ True
+ )
+ MLLAMA_TEXT_ATTENTION_CLASSES["flash_attention_2"] = MllamaTextSelfFlashAttention2
+ MLLAMA_TEXT_CROSS_ATTENTION_CLASSES[
+ "flash_attention_2"
+ ] = MllamaTextCrossFlashAttention2
+ # fallback to SDPA
+ MLLAMA_VISION_ATTENTION_CLASSES[
+ "flash_attention_2"
+ ] = MLLAMA_VISION_ATTENTION_CLASSES["sdpa"]
diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
index 4c3571ea4f..c804d0c6b9 100644
--- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -22,7 +22,6 @@
apply_rotary_pos_emb,
repeat_kv,
)
-from xformers.ops import SwiGLU
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
@@ -44,7 +43,19 @@
LOG = logging.getLogger("axolotl")
+def is_xformers_available() -> bool:
+ try:
+ import xformers # pylint: disable=unused-import # noqa: F401
+
+ return True
+ except ImportError:
+ return False
+
+
def is_xformers_swiglu_available() -> bool:
+ if not is_xformers_available():
+ return False
+
from xformers.ops.common import get_xformers_operator
try:
@@ -57,6 +68,11 @@ def is_xformers_swiglu_available() -> bool:
def replace_llama_mlp_with_swiglu(model):
+ if is_xformers_swiglu_available():
+ from axolotl.monkeypatch.xformers_ import FusedMLP
+ else:
+ raise RuntimeError("xformers SwiGLU not available for this environment")
+
for name, module in model.named_modules():
if isinstance(module, LlamaMLP):
mlp = FusedMLP(
@@ -181,49 +197,6 @@ def _post_training(self, model, name):
set_module_name(model, name, new_attn)
-class FusedMLP(torch.nn.Module):
- """
- Fused MLP layer for incrementally improved training efficiency
- """
-
- def __init__(
- self,
- config,
- gate_proj: torch.nn.Linear,
- up_proj: torch.nn.Linear,
- down_proj: torch.nn.Linear,
- ):
- super().__init__()
- self.config = config
- self.swiglu = SwiGLU(
- in_features=config.hidden_size,
- hidden_features=config.intermediate_size,
- bias=False,
- _pack_weights=True,
- )
- # overwrite initialized weights with pretrained weights
- self.swiglu.w12.weight.data = torch.cat(
- (gate_proj.weight.data, up_proj.weight.data), dim=0
- )
- self.swiglu.w3.weight.data = down_proj.weight.data
-
- def _post_training(self, model, name):
- w1, w2 = torch.split( # pylint: disable=invalid-name
- self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
- )
-
- # Assign the split weights back to the original layers
- new_mlp = LlamaMLP(self.config)
- new_mlp.gate_proj.weight.data = w1
- new_mlp.up_proj.weight.data = w2
- new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
-
- set_module_name(model, name, new_mlp)
-
- def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
- return self.swiglu(x)
-
-
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
diff --git a/src/axolotl/monkeypatch/llama_patch_multipack.py b/src/axolotl/monkeypatch/llama_patch_multipack.py
index 540c5577a0..cfd525367e 100644
--- a/src/axolotl/monkeypatch/llama_patch_multipack.py
+++ b/src/axolotl/monkeypatch/llama_patch_multipack.py
@@ -9,18 +9,18 @@
def hijack_llama_prepare_4d_mask():
- import transformers.modeling_attn_mask_utils
- import transformers.models.llama.modeling_llama
+ from transformers import modeling_attn_mask_utils
+ from transformers.models.llama import modeling_llama
- transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access
+ modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access
patched_prepare_4d_causal_attention_mask_for_sdpa
)
- transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access
+ modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access
patched_prepare_4d_causal_attention_mask_for_sdpa
)
- transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access
+ modeling_llama._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access
patched_prepare_4d_causal_attention_mask
)
- transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access
+ modeling_attn_mask_utils._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access
patched_prepare_4d_causal_attention_mask
)
diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py
index 9043520108..85101cd3c4 100644
--- a/src/axolotl/monkeypatch/multipack.py
+++ b/src/axolotl/monkeypatch/multipack.py
@@ -10,6 +10,7 @@
from axolotl.monkeypatch.utils import get_unpad_data
SUPPORTED_MULTIPACK_MODEL_TYPES = [
+ "mllama_text_model",
"llama",
"mistral",
"mixtral",
@@ -17,6 +18,7 @@
"qwen2_moe",
"falcon",
"phi",
+ "phi3",
"gemma",
"gemma2",
"gemmoe",
diff --git a/src/axolotl/monkeypatch/relora.py b/src/axolotl/monkeypatch/relora.py
index e4352cbe3d..9d246cb17f 100644
--- a/src/axolotl/monkeypatch/relora.py
+++ b/src/axolotl/monkeypatch/relora.py
@@ -44,8 +44,8 @@ def magnitude_pruning_(tensor, prune_ratio):
def reset_optimizer(
optimizer: torch.optim.Optimizer,
*,
- reset_params: list[str], # where str is the key to a torch.nn.Parameter
- optimizer_state_keys: list[str],
+ reset_params: List[str], # where str is the key to a torch.nn.Parameter
+ optimizer_state_keys: List[str],
prune_ratio: float = 0.9,
):
pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)
diff --git a/src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py b/src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py
index 0269f90157..67e9337e36 100644
--- a/src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py
@@ -16,6 +16,7 @@
# This code is based off the following work:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
+# pylint: disable=duplicate-code
""" PyTorch StableLM Epoch model. """
import importlib
import math
diff --git a/src/axolotl/monkeypatch/unsloth_.py b/src/axolotl/monkeypatch/unsloth_.py
index 5b1f0061de..c8272ac735 100644
--- a/src/axolotl/monkeypatch/unsloth_.py
+++ b/src/axolotl/monkeypatch/unsloth_.py
@@ -16,28 +16,6 @@
LOG = get_logger("axolotl.monkeypatch.unsloth")
-ORIGINAL_CEL_CODE = """ if labels is not None:
- # Shift so that tokens < n predict n
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = CrossEntropyLoss()
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
- shift_labels = shift_labels.view(-1)
- # Enable model parallelism
- shift_labels = shift_labels.to(shift_logits.device)
- loss = loss_fct(shift_logits, shift_labels)
-"""
-
-PATCHED_CEL_CODE = """ if labels is not None:
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- loss = fast_cross_entropy_loss(
- logits = shift_logits,
- labels = shift_labels,
- )
-"""
-
ORIGINAL_QKV_CODE = """
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
@@ -82,12 +60,6 @@ def get_forward_code() -> str:
return forward
-def check_cel_is_patchable() -> bool:
- forward = get_forward_code()
- forward, _ = detab_code(forward)
- return ORIGINAL_CEL_CODE in forward
-
-
def get_self_attn_code() -> str:
forward = inspect.getsource(LlamaFlashAttention2.forward)
return forward
@@ -100,48 +72,31 @@ def check_self_attn_is_patchable() -> bool:
def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
- if model_type == "llama":
- forward = get_forward_code()
- LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
- forward, _ = detab_code(forward)
- assert ORIGINAL_CEL_CODE in forward, "Original forward code not found"
+ from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss
+
+ def UnslothForCausalLMLoss( # pylint: disable=invalid-name
+ logits,
+ labels,
+ vocab_size: int, # pylint: disable=unused-argument
+ num_items_in_batch: int = None,
+ ignore_index: int = -100, # pylint: disable=unused-argument
+ **kwargs, # pylint: disable=unused-argument
+ ):
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
+ logits = logits.float()
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
- forward = forward.replace(
- "@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", ""
- )
- forward = forward.replace(
- "@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)",
- "",
- )
- forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE)
- forward = forward.replace(
- "def forward(",
- "def fast_cross_entropy_loss_forward(",
- 1,
+ loss = fast_cross_entropy_loss(
+ logits=shift_logits, labels=shift_labels, n_items=num_items_in_batch
)
+ return loss
- # load imports necessary
- import transformers.models.llama.modeling_llama
-
- items_to_import = []
- for item in dir(transformers.models.llama.modeling_llama):
- if item in forward:
- items_to_import.append(item)
-
- exec( # pylint: disable=exec-used # nosec B102
- "from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss",
- globals(),
- )
+ if model_type == "llama":
+ from transformers.loss import loss_utils
- exec( # pylint: disable=exec-used # nosec B102
- "from transformers.models.llama.modeling_llama import ("
- + ", ".join(x for x in items_to_import)
- + ")",
- globals(),
- )
- exec(forward, globals()) # pylint: disable=exec-used # nosec B102
- LOG.info("patching unsloth fast_cross_entropy_loss", main_process_only=True)
- LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821
+ loss_utils.ForCausalLMLoss = UnslothForCausalLMLoss # type: ignore[assignment]
else:
raise ValueError("Unsupported model type")
diff --git a/src/axolotl/monkeypatch/utils.py b/src/axolotl/monkeypatch/utils.py
index e43c58650a..f29f21be77 100644
--- a/src/axolotl/monkeypatch/utils.py
+++ b/src/axolotl/monkeypatch/utils.py
@@ -17,11 +17,9 @@ def get_max_seqlen_in_batch(attention_mask: torch.Tensor) -> torch.Tensor:
max_num = int(torch.max(attention_mask).item())
batch_size, _ = attention_mask.shape
counts = torch.zeros((batch_size, max_num), dtype=torch.int32)
-
for i in range(1, max_num + 1):
mask = attention_mask == i
counts[:, i - 1] = torch.sum(mask, dim=-1).to(dtype=torch.int32)
-
result = counts.flatten()
nonzero_indices = torch.nonzero(result).squeeze(-1)
return result[nonzero_indices]
diff --git a/src/axolotl/monkeypatch/xformers_/__init__.py b/src/axolotl/monkeypatch/xformers_/__init__.py
new file mode 100644
index 0000000000..bddc036b24
--- /dev/null
+++ b/src/axolotl/monkeypatch/xformers_/__init__.py
@@ -0,0 +1,51 @@
+"""
+Fused MLP layer for incrementally improved training efficiency
+"""
+import torch
+from transformers.models.llama.modeling_llama import LlamaMLP
+from xformers.ops import SwiGLU
+
+from axolotl.monkeypatch.utils import set_module_name
+
+
+class FusedMLP(torch.nn.Module):
+ """
+ Fused MLP layer for incrementally improved training efficiency
+ """
+
+ def __init__(
+ self,
+ config,
+ gate_proj: torch.nn.Linear,
+ up_proj: torch.nn.Linear,
+ down_proj: torch.nn.Linear,
+ ):
+ super().__init__()
+ self.config = config
+ self.swiglu = SwiGLU(
+ in_features=config.hidden_size,
+ hidden_features=config.intermediate_size,
+ bias=False,
+ _pack_weights=True,
+ )
+ # overwrite initialized weights with pretrained weights
+ self.swiglu.w12.weight.data = torch.cat(
+ (gate_proj.weight.data, up_proj.weight.data), dim=0
+ )
+ self.swiglu.w3.weight.data = down_proj.weight.data
+
+ def _post_training(self, model, name):
+ w1, w2 = torch.split( # pylint: disable=invalid-name
+ self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
+ )
+
+ # Assign the split weights back to the original layers
+ new_mlp = LlamaMLP(self.config)
+ new_mlp.gate_proj.weight.data = w1
+ new_mlp.up_proj.weight.data = w2
+ new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
+
+ set_module_name(model, name, new_mlp)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
+ return self.swiglu(x)
diff --git a/src/axolotl/prompt_strategies/__init__.py b/src/axolotl/prompt_strategies/__init__.py
index f5699a0871..74da20c5e1 100644
--- a/src/axolotl/prompt_strategies/__init__.py
+++ b/src/axolotl/prompt_strategies/__init__.py
@@ -9,8 +9,12 @@
LOG = logging.getLogger("axolotl.prompt_strategies")
-def load(strategy, tokenizer, cfg, ds_cfg):
+def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
try:
+ if strategy == "messages":
+ from .messages import load as messages_load
+
+ return messages_load(tokenizer, cfg, ds_cfg, processor=processor)
load_fn = "load"
if strategy.split(".")[-1].startswith("load_"):
load_fn = strategy.split(".")[-1]
@@ -24,9 +28,12 @@ def load(strategy, tokenizer, cfg, ds_cfg):
sig = inspect.signature(func)
if "ds_cfg" in sig.parameters:
load_kwargs["ds_cfg"] = ds_cfg
+ if "processor" in sig.parameters:
+ load_kwargs["processor"] = processor
return func(tokenizer, cfg, **load_kwargs)
except ModuleNotFoundError:
return None
except Exception as exc: # pylint: disable=broad-exception-caught
LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
- return None
+ raise exc
+ return None
diff --git a/src/axolotl/prompt_strategies/bradley_terry/README.md b/src/axolotl/prompt_strategies/bradley_terry/README.md
new file mode 100644
index 0000000000..39cd16137c
--- /dev/null
+++ b/src/axolotl/prompt_strategies/bradley_terry/README.md
@@ -0,0 +1,10 @@
+### example yaml
+
+```yaml
+chat_template: gemma
+datasets:
+ - path: argilla/distilabel-intel-orca-dpo-pairs
+ type: bradley_terry.chat_template
+val_set_size: 0.0
+output_dir: ./outputs/out
+```
diff --git a/src/axolotl/prompt_strategies/bradley_terry/__init__.py b/src/axolotl/prompt_strategies/bradley_terry/__init__.py
new file mode 100644
index 0000000000..4457c50be5
--- /dev/null
+++ b/src/axolotl/prompt_strategies/bradley_terry/__init__.py
@@ -0,0 +1,35 @@
+"""Module to load prompt strategies."""
+
+import importlib
+import inspect
+import logging
+
+from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
+
+LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry")
+
+
+def load(strategy, tokenizer, cfg, ds_cfg):
+ # pylint: disable=duplicate-code
+ try:
+ load_fn = "load"
+ if strategy.split(".")[-1].startswith("load_"):
+ load_fn = strategy.split(".")[-1]
+ strategy = ".".join(strategy.split(".")[:-1])
+ mod = importlib.import_module(
+ f".{strategy}", "axolotl.prompt_strategies.bradley_terry"
+ )
+ func = getattr(mod, load_fn)
+ load_kwargs = {}
+ if strategy == "user_defined":
+ load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg)
+ else:
+ sig = inspect.signature(func)
+ if "ds_cfg" in sig.parameters:
+ load_kwargs["ds_cfg"] = ds_cfg
+ return func(tokenizer, cfg, **load_kwargs)
+ except ModuleNotFoundError:
+ return None
+ except Exception as exc: # pylint: disable=broad-exception-caught
+ LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
+ return None
diff --git a/src/axolotl/prompt_strategies/bradley_terry/chat_template.py b/src/axolotl/prompt_strategies/bradley_terry/chat_template.py
new file mode 100644
index 0000000000..fa85cdcb26
--- /dev/null
+++ b/src/axolotl/prompt_strategies/bradley_terry/chat_template.py
@@ -0,0 +1,102 @@
+"""
+Bradley-Terry model with chat template prompt strategy.
+"""
+
+import logging
+from typing import Any, Dict, Optional
+
+from axolotl.prompt_strategies.chat_template import (
+ ChatTemplatePrompter,
+ ChatTemplateStrategy,
+)
+from axolotl.utils.chat_templates import get_chat_template_from_config
+
+# Configure the logger
+LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry.chat_template")
+LOG.setLevel(logging.INFO)
+
+
+class BTChatTemplateStrategy(ChatTemplateStrategy):
+ """
+ Bradley-Terry reward model pairwise chat template prompt strategy.
+ """
+
+ def tokenize_prompt(self, prompt):
+ """
+
+ :param prompt: the actual row of data from the underlying dataset
+ :return:
+ """
+
+ self.messages = "chosen_messages"
+ # pylint: disable=duplicate-code
+ prompt[self.messages] = []
+ if prompt["system"]:
+ prompt[self.messages].append(
+ {"role": "system", "content": prompt["system"]}
+ )
+ prompt[self.messages].append({"role": "user", "content": prompt["input"]})
+ prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]})
+ chosen_tokenized = super().tokenize_prompt(prompt)
+
+ self.messages = "rejected_messages"
+ # pylint: disable=duplicate-code
+ prompt[self.messages] = []
+ if prompt["system"]:
+ prompt[self.messages].append(
+ {"role": "system", "content": prompt["system"]}
+ )
+ prompt[self.messages].append({"role": "user", "content": prompt["input"]})
+ prompt[self.messages].append(
+ {"role": "assistant", "content": prompt["rejected"]}
+ )
+ rejected_tokenized = super().tokenize_prompt(prompt)
+
+ return {
+ "input_ids_chosen": chosen_tokenized["input_ids"],
+ "attention_mask_chosen": chosen_tokenized["attention_mask"],
+ "labels_chosen": 1.0,
+ "input_ids_rejected": rejected_tokenized["input_ids"],
+ "attention_mask_rejected": rejected_tokenized["attention_mask"],
+ "labels_rejected": 0.0,
+ }
+
+
+def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
+ ds_cfg = ds_cfg or {}
+ chat_template_string = get_chat_template_from_config(
+ cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
+ )
+
+ prompter_params = {
+ "tokenizer": tokenizer,
+ "chat_template": chat_template_string,
+ "message_field_role": ds_cfg.get("message_field_role", "role"),
+ "message_field_content": ds_cfg.get("message_field_content", "content"),
+ "message_field_training": ds_cfg.get("message_field_training", None),
+ "message_field_training_detail": ds_cfg.get(
+ "message_field_training_detail", None
+ ),
+ "roles": ds_cfg.get("roles"),
+ "drop_system_message": ds_cfg.get("drop_system_message", False),
+ # we need to add one for detecting sequences with exceeding the `sequence_len` limit.
+ "max_length": cfg.sequence_len + 1
+ if not cfg.reward_model
+ else cfg.sequence_len,
+ }
+
+ strategy_params = {
+ "train_on_inputs": cfg.train_on_inputs,
+ "sequence_len": cfg.sequence_len,
+ "roles_to_train": ds_cfg.get("roles_to_train", []),
+ "train_on_eos": ds_cfg.get("train_on_eos", None),
+ }
+
+ strategy = BTChatTemplateStrategy(
+ ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
+ )
+
+ if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
+ strategy.messages = ds_cfg["field_messages"]
+
+ return strategy
diff --git a/src/axolotl/prompt_strategies/bradley_terry/llama3.py b/src/axolotl/prompt_strategies/bradley_terry/llama3.py
new file mode 100644
index 0000000000..1d586fd5f4
--- /dev/null
+++ b/src/axolotl/prompt_strategies/bradley_terry/llama3.py
@@ -0,0 +1,27 @@
+"""
+chatml transforms for datasets with system, input, chosen, rejected to match llama3 chat template
+"""
+
+
+def icr(
+ cfg,
+ **kwargs,
+): # pylint: disable=possibly-unused-variable,unused-argument
+ """
+ chatml transforms for datasets with system, input, chosen, rejected
+ ex. https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs
+ """
+
+ def transform_fn(sample):
+ if "system" in sample and sample["system"]:
+ prompt = (
+ f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
+ f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+ )
+ else:
+ prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+ sample["chosen"] = prompt + f"{sample['chosen']}<|eot_id|>"
+ sample["rejected"] = prompt + f"{sample['rejected']}<|eot_id|>"
+ return sample
+
+ return transform_fn
diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py
index d0fad4483a..0946a4b8c7 100644
--- a/src/axolotl/prompt_strategies/chat_template.py
+++ b/src/axolotl/prompt_strategies/chat_template.py
@@ -5,9 +5,11 @@
import logging
from typing import Any, Dict, List, Optional
+from transformers import ProcessorMixin
+
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
-from axolotl.utils.chat_templates import chat_templates
+from axolotl.utils.chat_templates import get_chat_template_from_config
# Configure the logger
LOG = logging.getLogger("axolotl")
@@ -20,12 +22,13 @@ class ChatTemplatePrompter(Prompter):
def __init__(
self,
tokenizer,
+ processor=None,
chat_template=None,
max_length=2048,
message_field_role: str = "from",
message_field_content: str = "value",
- message_field_training: str = "train",
- message_field_training_detail: str = "train_detail",
+ message_field_training: Optional[str] = None,
+ message_field_training_detail: Optional[str] = None,
roles: Optional[Dict[str, List[str]]] = None,
drop_system_message: bool = False,
):
@@ -44,11 +47,12 @@ def __init__(
self.message_field_training = message_field_training
self.message_field_training_detail = message_field_training_detail
self.tokenizer = tokenizer
+ self.processor: ProcessorMixin = processor
self.chat_template = chat_template
self.max_length = max_length
self.drop_system_message = drop_system_message
- def build_prompt(self, conversation, add_generation_prompt=False):
+ def build_prompt(self, conversation, add_generation_prompt=False, images=None):
turns = [
{
"role": self.roles[t[self.message_field_role]],
@@ -61,6 +65,28 @@ def build_prompt(self, conversation, add_generation_prompt=False):
if self.drop_system_message and turns[0]["role"] == "system":
turns = turns[1:]
+ if self.processor:
+ text = self.processor.apply_chat_template(
+ turns,
+ chat_template=self.chat_template,
+ tokenize=False,
+ add_generation_prompt=add_generation_prompt,
+ )
+ batch = self.processor(
+ text=text,
+ images=images,
+ return_tensors="pt",
+ truncation=True,
+ max_length=self.max_length,
+ )
+ # workaround since processor works in batches instead of single examples
+ for k, val in batch.items():
+ if k in ["pixel_values"]:
+ batch[k] = val.tolist()
+ else:
+ batch[k] = val.squeeze().tolist()
+ return batch
+
return self.tokenizer.apply_chat_template(
turns,
truncation=True,
@@ -186,11 +212,12 @@ def __init__(
train_on_inputs,
sequence_len,
roles_to_train=None,
- train_on_eos="last",
+ train_on_eos=None,
):
super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
self.roles_to_train = roles_to_train if roles_to_train is not None else []
self.train_on_eos = train_on_eos
+ self.images = "images"
@property
def messages(self):
@@ -201,6 +228,40 @@ def messages(self, messages):
self._messages = messages
def tokenize_prompt(self, prompt):
+ # Old simple legacy behavior that works reliably.
+ if (
+ not self.roles_to_train
+ and not self.train_on_eos
+ and not self.prompter.message_field_training
+ and not self.prompter.message_field_training_detail
+ ):
+ turns = self.get_conversation_thread(prompt)
+ images = self.get_images(prompt)
+ prompt_ids = self.prompter.build_prompt(
+ turns[:-1],
+ add_generation_prompt=True,
+ images=images,
+ )
+ tokenized_res = self.prompter.build_prompt(turns, images=images)
+ tokenized_prompt = {}
+ if isinstance(tokenized_res, list):
+ input_ids = prompt_ids + tokenized_res[len(prompt_ids) :]
+ tokenized_prompt["input_ids"] = input_ids
+ tokenized_prompt["attention_mask"] = [1] * len(input_ids)
+ else:
+ input_ids = tokenized_res["input_ids"]
+ tokenized_prompt = tokenized_res
+
+ if not self.train_on_inputs:
+ user_prompt_len = len(prompt_ids)
+ labels = [-100] * user_prompt_len + input_ids[user_prompt_len:]
+ else:
+ labels = input_ids
+
+ tokenized_prompt["labels"] = labels
+
+ return tokenized_prompt
+
turns = prompt[self.messages]
input_ids = self.prompter.build_prompt(turns)
labels = [IGNORE_TOKEN_ID] * len(input_ids)
@@ -219,9 +280,11 @@ def tokenize_prompt(self, prompt):
should_train = (
train_turn
if train_turn is not None
- else bool(train_detail is not None)
- if train_detail is not None
- else self.train_on_inputs or role in self.roles_to_train
+ else (
+ bool(train_detail is not None)
+ if train_detail is not None
+ else self.train_on_inputs or role in self.roles_to_train
+ )
)
LOG.debug(f"Should train: {should_train}")
@@ -335,29 +398,40 @@ def find_turn(self, conversation_ids, turn, turn_content):
def get_conversation_thread(self, prompt):
return prompt[self.messages]
+ def get_images(self, prompt):
+ return prompt.get(self.images, None)
+
-def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
+def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None):
+ # pylint: disable=duplicate-code
ds_cfg = ds_cfg or {}
+ chat_template_string = get_chat_template_from_config(
+ cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
+ )
+ LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")
prompter_params = {
"tokenizer": tokenizer,
- "chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")),
- "message_field_role": ds_cfg.get("message_field_role", "from"),
- "message_field_content": ds_cfg.get("message_field_content", "value"),
- "message_field_training": ds_cfg.get("message_field_training", "training"),
+ "chat_template": chat_template_string,
+ "message_field_role": ds_cfg.get("message_field_role", "role"),
+ "message_field_content": ds_cfg.get("message_field_content", "content"),
+ "message_field_training": ds_cfg.get("message_field_training", None),
"message_field_training_detail": ds_cfg.get(
- "message_field_training_detail", "train_detail"
+ "message_field_training_detail",
+ None,
),
"roles": ds_cfg.get("roles"),
"drop_system_message": ds_cfg.get("drop_system_message", False),
- "max_length": cfg.sequence_len,
+ # we need to add one for detecting sequences with exceeding the `sequence_len` limit.
+ "max_length": cfg.sequence_len + 1,
+ "processor": processor,
}
strategy_params = {
"train_on_inputs": cfg.train_on_inputs,
"sequence_len": cfg.sequence_len,
- "roles_to_train": ds_cfg.get("roles_to_train", ["gpt", "assistant"]),
- "train_on_eos": ds_cfg.get("train_on_eos", "last"),
+ "roles_to_train": ds_cfg.get("roles_to_train", []),
+ "train_on_eos": ds_cfg.get("train_on_eos", None),
}
strategy = ChatTemplateStrategy(
diff --git a/src/axolotl/prompt_strategies/dpo/chat_template.py b/src/axolotl/prompt_strategies/dpo/chat_template.py
index e0e5eb1294..489b864851 100644
--- a/src/axolotl/prompt_strategies/dpo/chat_template.py
+++ b/src/axolotl/prompt_strategies/dpo/chat_template.py
@@ -2,15 +2,16 @@
DPO prompt strategies for using tokenizer chat templates.
"""
-from axolotl.utils.chat_templates import chat_templates
+from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template
def default(
cfg, dataset_idx=0, **kwargs
): # pylint: disable=possibly-unused-variable,unused-argument
ds_cfg = cfg["datasets"][dataset_idx]
- chat_template_str = chat_templates(cfg.chat_template)
-
+ chat_template_choice, chat_template_jinja = extract_chat_template_args(
+ cfg=cfg, ds_cfg=ds_cfg
+ )
field_messages = ds_cfg.get("field_messages", "messages")
field_chosen = ds_cfg.get("field_chosen", "chosen")
field_rejected = ds_cfg.get("field_rejected", "rejected")
@@ -30,6 +31,12 @@ def default(
role_map[source] = target
def transform_fn(sample, tokenizer=None):
+ chat_template_string = get_chat_template(
+ user_choice=chat_template_choice,
+ jinja_template=chat_template_jinja,
+ tokenizer=tokenizer,
+ )
+
messages = sample[field_messages]
messages = [
{
@@ -46,28 +53,29 @@ def transform_fn(sample, tokenizer=None):
"role": role_map[sample[field_rejected][field_message_role]],
"content": sample[field_rejected][field_message_content],
}
+ dummy_user_message = {"role": "user", "content": "[[dummy_message]]"}
result = {}
result["prompt"] = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
- chat_template=chat_template_str,
+ chat_template=chat_template_string,
tokenize=False,
)
result["chosen"] = tokenizer.apply_chat_template(
- [chosen],
+ [dummy_user_message, chosen],
add_generation_prompt=False,
- chat_template=chat_template_str,
+ chat_template=chat_template_string,
tokenize=False,
)
chosen_strip_index = result["chosen"].find(chosen["content"])
result["chosen"] = result["chosen"][chosen_strip_index:].rstrip()
result["rejected"] = tokenizer.apply_chat_template(
- [rejected],
+ [dummy_user_message, rejected],
add_generation_prompt=False,
- chat_template=chat_template_str,
+ chat_template=chat_template_string,
tokenize=False,
)
rejected_strip_index = result["rejected"].find(rejected["content"])
diff --git a/src/axolotl/prompt_strategies/messages/__init__.py b/src/axolotl/prompt_strategies/messages/__init__.py
new file mode 100644
index 0000000000..d014d93a6b
--- /dev/null
+++ b/src/axolotl/prompt_strategies/messages/__init__.py
@@ -0,0 +1,34 @@
+"""Module to load message prompt strategies."""
+
+import importlib
+import inspect
+import logging
+
+LOG = logging.getLogger("axolotl.prompt_strategies.messages")
+
+
+def load(tokenizer, cfg, ds_cfg, processor=None):
+ try:
+ strategy = ds_cfg.get("input_transform", "chat")
+ # pylint: disable=duplicate-code
+ load_fn = "load"
+ if strategy.split(".")[-1].startswith("load_"):
+ load_fn = strategy.split(".")[-1]
+ strategy = ".".join(strategy.split(".")[:-1])
+ mod = importlib.import_module(
+ f".{strategy}", "axolotl.prompt_strategies.messages"
+ )
+ func = getattr(mod, load_fn)
+ load_kwargs = {}
+ sig = inspect.signature(func)
+ if "ds_cfg" in sig.parameters:
+ load_kwargs["ds_cfg"] = ds_cfg
+ if "processor" in sig.parameters:
+ load_kwargs["processor"] = processor
+ return func(tokenizer, cfg, **load_kwargs)
+ except ModuleNotFoundError:
+ return None
+ except Exception as exc: # pylint: disable=broad-exception-caught
+ LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
+ raise exc
+ return None
diff --git a/src/axolotl/prompt_strategies/messages/chat.py b/src/axolotl/prompt_strategies/messages/chat.py
new file mode 100644
index 0000000000..35d7649026
--- /dev/null
+++ b/src/axolotl/prompt_strategies/messages/chat.py
@@ -0,0 +1,84 @@
+"""
+Chat dataset wrapping strategy for new internal messages representations
+"""
+from typing import Any, Callable, Dict, Optional
+
+from axolotl.core.datasets.chat import TokenizedChatDataset
+from axolotl.core.datasets.transforms.chat_builder import chat_message_transform_builder
+from axolotl.prompt_tokenizers import DatasetWrappingStrategy
+
+
+class ChatMessageDatasetWrappingStrategy(DatasetWrappingStrategy):
+ """
+ Chat dataset wrapping strategy for new internal messages representations
+ """
+
+ def __init__(
+ self,
+ processor,
+ message_transform=None,
+ formatter=None,
+ **kwargs, # pylint: disable=unused-argument
+ ):
+ """
+ :param processor: tokenizer or image processor
+ :param kwargs:
+ """
+ self.processor = processor
+ self.dataset = None
+ self.message_transform = message_transform
+ self.formatter = formatter
+
+ def wrap_dataset(
+ self,
+ dataset,
+ process_count: Optional[int] = None,
+ keep_in_memory: Optional[bool] = False,
+ **kwargs, # pylint: disable=unused-argument
+ ):
+ self.dataset = TokenizedChatDataset(
+ dataset,
+ message_transform=self.message_transform,
+ model_transform=self.processor,
+ formatter=self.formatter,
+ process_count=process_count,
+ keep_in_memory=keep_in_memory,
+ )
+ return self.dataset
+
+
+def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
+ ds_cfg = ds_cfg or {}
+
+ field_messages = ds_cfg.get("field_messages")
+ message_field_role = ds_cfg.get("message_field_role")
+ message_field_content = ds_cfg.get("message_field_content")
+ message_field_training = ds_cfg.get("message_field_training")
+
+ builder_kwargs = {}
+ if field_messages:
+ builder_kwargs["conversations_field"] = field_messages
+ if message_field_role:
+ builder_kwargs["message_field_role"] = message_field_role
+ if message_field_content:
+ builder_kwargs["message_field_content"] = message_field_content
+ if message_field_training:
+ builder_kwargs["message_field_training"] = message_field_training
+
+ chat_template = ds_cfg.get("chat_template", cfg.get("chat_template", "chatml"))
+ format_message = (
+ lambda x: x # noqa E731 # pylint: disable=unnecessary-lambda-assignment
+ )
+ if chat_template == "chatml":
+ from axolotl.core.chat.format.chatml import format_message # noqa F811
+ if chat_template.startswith("llama3"):
+ from axolotl.core.chat.format.llama3x import format_message # noqa F811
+ message_transform: Callable = chat_message_transform_builder(
+ train_on_inputs=ds_cfg.get("train_on_inputs", False),
+ **builder_kwargs,
+ )
+ strategy = ChatMessageDatasetWrappingStrategy(
+ tokenizer, message_transform=message_transform, formatter=format_message
+ )
+
+ return strategy
diff --git a/src/axolotl/prompt_strategies/orpo/chat_template.py b/src/axolotl/prompt_strategies/orpo/chat_template.py
index bba6948568..e53a547483 100644
--- a/src/axolotl/prompt_strategies/orpo/chat_template.py
+++ b/src/axolotl/prompt_strategies/orpo/chat_template.py
@@ -5,7 +5,7 @@
from axolotl.prompt_tokenizers import IGNORE_INDEX, PromptTokenizingStrategy
from axolotl.prompters import Prompter
-from axolotl.utils.chat_templates import chat_templates
+from axolotl.utils.chat_templates import get_chat_template_from_config
class Message(BaseModel):
@@ -28,18 +28,13 @@ def load(
"""
chatml transforms for datasets with system, input, chosen, rejected
"""
-
- chat_template = chat_templates("chatml")
- if ds_cfg and "chat_template" in ds_cfg:
- chat_template = ds_cfg["chat_template"]
- try:
- chat_template = chat_templates(chat_template)
- except ValueError:
- pass
- tokenizer.chat_template = chat_template
+ chat_template_string = get_chat_template_from_config(
+ cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
+ )
+ tokenizer.chat_template = chat_template_string
return ORPOTokenizingStrategy(
- ORPOPrompter(chat_template, tokenizer),
+ ORPOPrompter(chat_template_string, tokenizer),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
@@ -248,28 +243,30 @@ def build_prompt(
def argilla(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
dataset_parser = ORPODatasetParsingStrategy()
- chat_template_str = chat_templates(cfg.chat_template)
-
def transform_fn(sample, tokenizer=None):
res = {}
+ chat_template_string = get_chat_template_from_config(
+ cfg=cfg, tokenizer=tokenizer
+ )
+
res["prompt"] = tokenizer.apply_chat_template(
[msg.model_dump() for msg in dataset_parser.get_prompt(sample).messages],
add_generation_prompt=True,
- chat_template=chat_template_str,
+ chat_template=chat_template_string,
tokenize=False,
)
prompt_str_len = len(res["prompt"])
res["chosen"] = tokenizer.apply_chat_template(
[msg.model_dump() for msg in dataset_parser.get_chosen(sample).messages],
add_generation_prompt=False,
- chat_template=chat_template_str,
+ chat_template=chat_template_string,
tokenize=False,
)[prompt_str_len:]
res["rejected"] = tokenizer.apply_chat_template(
[msg.model_dump() for msg in dataset_parser.get_rejected(sample).messages],
add_generation_prompt=False,
- chat_template=chat_template_str,
+ chat_template=chat_template_string,
tokenize=False,
)[prompt_str_len:]
diff --git a/src/axolotl/prompt_strategies/sharegpt.py b/src/axolotl/prompt_strategies/sharegpt.py
index 321f19554b..069d243f52 100644
--- a/src/axolotl/prompt_strategies/sharegpt.py
+++ b/src/axolotl/prompt_strategies/sharegpt.py
@@ -61,6 +61,9 @@ def build_loader(
default_conversation: Optional[str] = None,
):
def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
+ LOG.warning(
+ "sharegpt type support will be deprecated in the next release of Axolotl. Please use chat_template instead. https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template",
+ )
conversation = (
ds_cfg["conversation"]
if ds_cfg and "conversation" in ds_cfg
diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py
index 11dd084a85..51d497a23c 100644
--- a/src/axolotl/prompt_tokenizers.py
+++ b/src/axolotl/prompt_tokenizers.py
@@ -30,6 +30,12 @@ class InvalidDataException(Exception):
"""
+class DatasetWrappingStrategy(abc.ABC):
+ """
+ Abstract class for wrapping datasets for Chat Messages
+ """
+
+
class PromptTokenizingStrategy(abc.ABC):
"""
Abstract class for tokenizing strategies
diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py
index 0ffa3e55fd..18b73e725e 100644
--- a/src/axolotl/prompters.py
+++ b/src/axolotl/prompters.py
@@ -65,8 +65,10 @@ def match_prompt_style(self):
self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
elif self.prompt_style == PromptStyle.PHI.value:
self.turn_format = "<|user|>\n{instruction}<|end|>{input}<|assistant|>"
- self.turn_no_input_format = "<|user|>\n{instruction}<|end|><|assistant|>"
- self.system_format = "<|system|>{system}\n"
+ self.turn_no_input_format = (
+ "<|user|>\n{instruction}<|end|>\n<|assistant|>\n"
+ )
+ self.system_format = "<|system|>\n{system}<|end|>\n"
def _build_result(self, instruction, input_text, output):
# returns the full prompt from instruction and optional input
@@ -350,9 +352,12 @@ def _build_result(self, source):
"Please help us by creating an Issue to add support for this conversation type."
)
- role = CONVERSATION_ROLE_FORMAT[self._conversation.name].format(
- ROLE=from_role
- )
+ if self._conversation.name in ["llama3"]:
+ role = from_role
+ else:
+ role = CONVERSATION_ROLE_FORMAT[self._conversation.name].format(
+ ROLE=from_role
+ )
if len(conv.messages) > 0 and ((role == conv.messages[-1][0])):
if (
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index b8890d4f7a..5fde4d3848 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -10,8 +10,8 @@
import torch
import transformers.modelcard
-from accelerate import Accelerator
from accelerate.logging import get_logger
+from accelerate.utils import save_fsdp_model
from datasets import Dataset
from peft import PeftModel
from pkg_resources import get_distribution # type: ignore
@@ -23,7 +23,7 @@
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.freeze import freeze_layers_except
-from axolotl.utils.models import load_model, load_tokenizer
+from axolotl.utils.models import load_model, load_processor, load_tokenizer
from axolotl.utils.trainer import setup_trainer
try:
@@ -68,6 +68,9 @@ def train(
main_process_only=True,
)
tokenizer = load_tokenizer(cfg)
+ processor = None
+ if cfg.is_multimodal:
+ processor = load_processor(cfg, tokenizer)
train_dataset = dataset_meta.train_dataset
eval_dataset = dataset_meta.eval_dataset
@@ -93,10 +96,11 @@ def train(
if cfg.adapter:
msg += " and peft_config..."
LOG.debug(msg)
- # we wait unitl the last possible moment to setup Accelerator
- Accelerator()
- model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
- model.generation_config.do_sample = True
+ model, peft_config = load_model(
+ cfg, tokenizer, processor=processor, inference=cli_args.inference
+ )
+ if model.generation_config is not None:
+ model.generation_config.do_sample = True
model_ref = None
if cfg.rl and cfg.rl != "orpo":
@@ -121,6 +125,7 @@ def train(
eval_dataset,
(model, model_ref, peft_config),
tokenizer,
+ processor,
total_num_steps,
)
@@ -194,9 +199,12 @@ def terminate_handler(_, __, model_weakref):
if hasattr(module, "_post_training"):
module._post_training(model, name) # pylint: disable=protected-access
+ state_dict_type = "FULL_STATE_DICT"
if trainer.is_fsdp_enabled:
- trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
- LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.")
+ if cfg.fsdp_final_state_dict_type:
+ state_dict_type = cfg.fsdp_final_state_dict_type
+ trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type)
+ LOG.info(f"Set FSDP state dict type to {state_dict_type} for saving.")
if cfg.relora_steps:
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
@@ -208,7 +216,18 @@ def terminate_handler(_, __, model_weakref):
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.fsdp:
- trainer.save_model(cfg.output_dir)
+ if (
+ state_dict_type == "SHARDED_STATE_DICT"
+ and cfg.fsdp_config.fsdp_state_dict_type == "SHARDED_STATE_DICT"
+ ):
+ save_fsdp_model(
+ trainer.accelerator.state.fsdp_plugin,
+ trainer.accelerator,
+ trainer.model,
+ cfg.output_dir,
+ )
+ elif state_dict_type == "FULL_STATE_DICT":
+ trainer.save_model(cfg.output_dir)
elif cfg.deepspeed and is_deepspeed_zero3_enabled():
# Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
trainer.accelerator.wait_for_everyone()
@@ -241,8 +260,10 @@ def terminate_handler(_, __, model_weakref):
if not cfg.hub_model_id:
try:
- trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
- except AttributeError:
+ trainer.create_model_card(
+ model_name=cfg.output_dir.lstrip("./").encode("utf-8").decode("utf-8")
+ )
+ except (AttributeError, UnicodeDecodeError):
pass
elif cfg.hub_model_id:
# defensively push to the hub to ensure the model card is updated
diff --git a/src/axolotl/utils/__init__.py b/src/axolotl/utils/__init__.py
index 99dec79f1b..91545009ad 100644
--- a/src/axolotl/utils/__init__.py
+++ b/src/axolotl/utils/__init__.py
@@ -1,8 +1,12 @@
"""
Basic utils for Axolotl
"""
-import importlib
+import importlib.util
def is_mlflow_available():
return importlib.util.find_spec("mlflow") is not None
+
+
+def is_comet_available():
+ return importlib.util.find_spec("comet_ml") is not None
diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py
index 73715b06ab..0bc781fcb4 100644
--- a/src/axolotl/utils/callbacks/__init__.py
+++ b/src/axolotl/utils/callbacks/__init__.py
@@ -29,7 +29,7 @@
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
-from axolotl.utils import is_mlflow_available
+from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.callbacks.perplexity import Perplexity
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
@@ -462,7 +462,7 @@ def evaluate_preds(sources, predictions, references):
references=[[r] for r in references],
predictions=predictions,
)
- scores[metric_name] = score
+ scores["eval_" + metric_name] = score
return scores
def predict_with_generate():
@@ -747,6 +747,15 @@ def log_table_from_dataloader(name: str, table_dataloader):
artifact_file="PredictionsVsGroundTruth.json",
tracking_uri=tracking_uri,
)
+ elif logger == "comet_ml" and is_comet_available():
+ import comet_ml
+
+ experiment = comet_ml.get_running_experiment()
+ if experiment:
+ experiment.log_table(
+ f"{name} - Predictions vs Ground Truth.csv",
+ pd.DataFrame(table_data),
+ )
if is_main_process():
log_table_from_dataloader("Eval", eval_dataloader)
diff --git a/src/axolotl/utils/callbacks/comet_.py b/src/axolotl/utils/callbacks/comet_.py
new file mode 100644
index 0000000000..b29f997a86
--- /dev/null
+++ b/src/axolotl/utils/callbacks/comet_.py
@@ -0,0 +1,43 @@
+"""Comet module for trainer callbacks"""
+
+import logging
+from typing import TYPE_CHECKING
+
+import comet_ml
+from transformers import TrainerCallback, TrainerControl, TrainerState
+
+from axolotl.utils.distributed import is_main_process
+
+if TYPE_CHECKING:
+ from axolotl.core.trainer_builder import AxolotlTrainingArguments
+
+LOG = logging.getLogger("axolotl.callbacks")
+
+
+class SaveAxolotlConfigtoCometCallback(TrainerCallback):
+ """Callback to save axolotl config to comet"""
+
+ def __init__(self, axolotl_config_path):
+ self.axolotl_config_path = axolotl_config_path
+
+ def on_train_begin(
+ self,
+ args: "AxolotlTrainingArguments", # pylint: disable=unused-argument
+ state: TrainerState, # pylint: disable=unused-argument
+ control: TrainerControl,
+ **kwargs, # pylint: disable=unused-argument
+ ):
+ if is_main_process():
+ try:
+ comet_experiment = comet_ml.start(source="axolotl")
+ comet_experiment.log_other("Created from", "axolotl")
+ comet_experiment.log_asset(
+ self.axolotl_config_path,
+ file_name="axolotl-config",
+ )
+ LOG.info(
+ "The Axolotl config has been saved to the Comet Experiment under assets."
+ )
+ except (FileNotFoundError, ConnectionError) as err:
+ LOG.warning(f"Error while saving Axolotl config to Comet: {err}")
+ return control
diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py
index ca4334d75a..dfb3fef21a 100644
--- a/src/axolotl/utils/chat_templates.py
+++ b/src/axolotl/utils/chat_templates.py
@@ -2,14 +2,48 @@
This module provides functionality for selecting chat templates based on user choices.
These templates are used for formatting messages in a conversation.
"""
+import logging
+from typing import TYPE_CHECKING, Any, Dict, Optional
+if TYPE_CHECKING:
+ from transformers import PreTrainedTokenizerBase
-def chat_templates(user_choice: str):
+LOG = logging.getLogger("axolotl.utils.chat_templates")
+
+_JINJA_TEMPALTE_CHOICE = "jinja"
+_DEFAULT_TEMPLATE_CHOICE = "tokenizer_default"
+_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX = "tokenizer_default_fallback_"
+
+_CHAT_TEMPLATES = {
+ "alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
+ "mistral_v1": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # Mistral 7B V1, Mistral 7B V2, Mixtral 8x7B V1...
+ "mistral_v2v3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + '[/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # V3: Mistral 7B V3, Small, Large...
+ "mistral_v3_tekken": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST]' + message['content'] + '[/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # V3-Tekken: Nemo, Pixtral...
+ "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
+ "gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\n' + message['content'] | trim + '\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\n'}}{% endif %}",
+ "cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
+ "llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
+ "llama3_2_vision": '{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now("%d %b %Y") %}\n {%- else %}\n {%- set date_string = "26 Jul 2024" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0][\'role\'] == \'system\' %}\n {%- set system_message = messages[0][\'content\']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = "" %}\n{%- endif %}\n\n{#- Find out if there are any images #}\n{% set image_ns = namespace(has_images=false) %} \n{%- for message in messages %}\n {%- for content in message[\'content\'] %}\n {%- if content[\'type\'] == \'image\' %}\n {%- set image_ns.has_images = true %}\n {%- endif %}\n {%- endfor %}\n{%- endfor %}\n\n{#- Error out if there are images and system message #}\n{%- if image_ns.has_images and not system_message == "" %}\n {{- raise_exception("Prompting with images is incompatible with system messages.") }}\n{%- endif %}\n\n{#- System message if there are no images #}\n{%- if not image_ns.has_images %}\n {{- "<|start_header_id|>system<|end_header_id|>\\n\\n" }}\n {%- if tools is not none %}\n {{- "Environment: ipython\\n" }}\n {%- endif %}\n {{- "Cutting Knowledge Date: December 2023\\n" }}\n {{- "Today Date: " + date_string + "\\n\\n" }}\n {%- if tools is not none and not tools_in_user_message %}\n {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }}\n {{- \'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.\' }}\n {{- "Do not use variables.\\n\\n" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- "\\n\\n" }}\n {%- endfor %}\n {%- endif %}\n {{- system_message }}\n {{- "<|eot_id|>" }}\n{%- endif %}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0][\'content\']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception("Cannot put tools in the first user message when there\'s no first user message!") }}\n{%- endif %}\n {{- \'<|start_header_id|>user<|end_header_id|>\\n\\n\' -}}\n {{- "Given the following functions, please respond with a JSON for a function call " }}\n {{- "with its proper arguments that best answers the given prompt.\\n\\n" }}\n {{- \'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.\' }}\n {{- "Do not use variables.\\n\\n" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- "\\n\\n" }}\n {%- endfor %}\n {{- first_user_message + "<|eot_id|>"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == \'ipython\' or message.role == \'tool\' or \'tool_calls\' in message) %}\n {{- \'<|start_header_id|>\' + message[\'role\'] + \'<|end_header_id|>\\n\\n\' }}\n {%- if message[\'content\'] is string %}\n {{- message[\'content\'] }}\n {%- else %}\n {%- for content in message[\'content\'] %}\n {%- if content[\'type\'] == \'image\' %}\n {{- \'<|image|>\' }}\n {%- elif content[\'type\'] == \'text\' %}\n {{- content[\'text\'] }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- \'<|eot_id|>\' }}\n {%- elif \'tool_calls\' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception("This model only supports single tool-calls at once!") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- \'<|start_header_id|>assistant<|end_header_id|>\\n\\n\' -}}\n {{- \'{"name": "\' + tool_call.name + \'", \' }}\n {{- \'"parameters": \' }}\n {{- tool_call.arguments | tojson }}\n {{- "}" }}\n {{- "<|eot_id|>" }}\n {%- elif message.role == "tool" or message.role == "ipython" %}\n {{- "<|start_header_id|>ipython<|end_header_id|>\\n\\n" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- "<|eot_id|>" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- \'<|start_header_id|>assistant<|end_header_id|>\\n\\n\' }}\n{%- endif %}\n',
+ "phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
+ "phi_35": "{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'user' %}{{'<|user|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}",
+ "deepseek_v2": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<|User|>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<|Assistant|>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|Assistant|>' }}{% endif %}",
+ "jamba": '{# Variables #}\n{% set ns = namespace(message_count=0, is_last_checked_defined=False) %}\n{##}\n{% set bom_str = bom_str or "<|bom|>" %}\n{% set eom_str = eom_str or "<|eom|>" %}\n{% set default_system_message = "" %}\n{##}\n{% set documents_prefix = "" %}\n{% set documents_suffix = "" %}\n{% set tool_definitions_prefix = "" %}\n{% set tool_definitions_suffix = "" %}\n{% set active_modes_prefix = "" %}\n{% set active_modes_suffix = "" %}\n{##}\n{% set tool_calls_prefix = "" %}\n{% set tool_calls_suffix = "" %}\n{% set citations_prefix = "" %}\n{% set citations_suffix = "" %}\n{##}\n{% if add_generation_prompt is not defined %}\n {% set add_generation_prompt = True %}\n{% endif %}\n{% set role_to_predict = role_to_predict or "assistant" %}\n{% if messages|length > 0 and messages[0].role == "system" %}\n {% set system_message = messages[0].content %}\n {% set loop_messages = messages[1:] %}\n{% else %}\n {% set system_message = default_system_message %}\n {% set loop_messages = messages %}\n{% endif %}\n{##}\n{##}\n{# Macros #}\n{% macro handle_tool_definitions(tools) %}\n {{- tool_definitions_prefix -}}\n {{- "\\n# Tools" -}}\n {{- "\\n\\n## Functions" -}}\n {% for tool in tools %}\n {% set _ = is_param_set(tool, field="type") %}\n {% set is_tool_type_set = ns.is_last_checked_defined %}\n {% if is_tool_type_set %}\n {% if tool.type == "function" %}\n {% set tool = tool.function %}\n {% else %}\n {{ raise_exception("Currently, the only supported tool type is `function`") }}\n {% endif %}\n {% endif %}\n {{- "\\n\\n" + (tool|tojson(indent=2)) -}}\n {% endfor %}\n {{- "\\n" + tool_definitions_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_first_system_message(system_message, tools) %}\n {{- bom_str + handle_role("system") -}}\n {% set _ = is_param_set(system_message) %}\n {% set is_system_message_set = ns.is_last_checked_defined %}\n {% if is_system_message_set %}\n {{- system_message -}}\n {% endif %}\n {% set _ = is_param_set(tools, is_list=True) %}\n {% set is_tools_set = ns.is_last_checked_defined %}\n {% if is_tools_set %}\n {% if system_message %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- handle_tool_definitions(tools) -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_tool_calls(tool_calls) %}\n {{- tool_calls_prefix + "[\\n" -}}\n {% for tool_call in tool_calls %}\n {% set _ = is_param_set(tool_call, field="function") %}\n {% set is_tool_call_function_set = ns.is_last_checked_defined %}\n {% if is_tool_call_function_set %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {% set arguments = tool_call.arguments %}\n {% if arguments is not string %}\n {%- set arguments = arguments|tojson -%}\n {%- endif %}\n {{ "{\\"name\\": \\"" + tool_call.name + "\\", \\"arguments\\": " + arguments + "}" -}}\n {% if not loop.last %}\n {{- "," }}\n {% endif %}\n {% endfor %}\n {{- "\\n]" + tool_calls_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_documents(documents) %}\n {{- documents_prefix -}}\n {{- "\\n# Documents" -}}\n {{- "\\n\\nYou can use the following documents for reference:" -}}\n {% for doc in documents %}\n {{- "\\n\\n## Document ID: " + loop.index0|string -}}\n {% set _ = is_param_set(doc, field="title") %}\n {% set is_doc_title_set = ns.is_last_checked_defined %}\n {% if is_doc_title_set %}\n {{- "\\nTitle: " + doc.title -}}\n {% endif %}\n {% for key, value in doc.items() %}\n {% if key not in ["title", "text"] %}\n {{- "\\n" + key|title + ": " + value|string -}}\n {% endif %}\n {% endfor %}\n {{- "\\nText: " + doc.text -}}\n {% endfor %}\n {{- "\\n" + documents_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_knobs(knobs) %}\n {{- active_modes_prefix -}}\n {{- "\\n# Active Modes" -}}\n {{ "\\n\\nThe following modes configure the format or style of your responses. You should adhere to all currently" -}}\n {{ " active modes simultaneously." -}}\n {% if knobs.citation_mode == "fast" %}\n {{- "\\n\\n## Citation Mode" -}}\n {{- "\\n\\nProvide a list of references only for the documents you base your response on. Format your response" -}}\n {{ " with the original answer followed by a citation section. Use this template:" -}}\n {{ " `{answer}" + citations_prefix + "DOCUMENT_IDS" + citations_suffix + "`, where DOCUMENT_IDS are the relevant document numbers" -}}\n {{ " (e.g. [2, 5, 9]), or [] if the answer cannot be supported by the provided documents." -}}\n {% endif %}\n {% if knobs.response_format == "json_object" %}\n {{- "\\n\\n## JSON Mode" -}}\n {{ "\\n\\nProvide your response in JSON format. Adhere strictly to any schema given by the user." -}}\n {{ " If an appropriate JSON format exists, use it without modification." -}}\n {% endif %}\n {{- "\\n" + active_modes_suffix -}}\n{% endmacro %}\n{##}\n{% macro get_last_user_index(messages) %}\n {% set ns.last_user_index = 0 %}\n {% for message in messages %}\n {% if message.role == \'user\' %}\n {% set ns.last_user_index = loop.index0 %}\n {% endif %}\n {% endfor %}\n {{- ns.last_user_index -}}\n{% endmacro %}\n{##}\n{% macro handle_last_system_message(documents, knobs, use_documents, use_knobs) %}\n {{- bom_str + handle_role("system") -}}\n {% set macros_to_call = [] %}\n {% set params_for_macros = [] %}\n {% if use_documents %}\n {% set macros_to_call = macros_to_call + [handle_documents] %}\n {% set params_for_macros = params_for_macros + [[documents]] %}\n {% endif %}\n {% if use_knobs %}\n {% set macros_to_call = macros_to_call + [handle_knobs] %}\n {% set params_for_macros = params_for_macros + [[knobs]] %}\n {% endif %}\n {% for i in range(macros_to_call|length) %}\n {% if i > 0 %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- macros_to_call[i](*params_for_macros[i]) -}}\n {% endfor %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_role(role, add_space=True) %}\n {{- "<|" + role + "|>" -}}\n {% if add_space %}\n {{- " " -}}\n {% endif %}\n{% endmacro %}\n{##}\n{% macro is_param_set(param, field=none, is_list=False) %}\n {% if field is not none %}\n {% if field in param %}\n {% set param = param[field] %}\n {% else %}\n {% set param = none %}\n {% endif %}\n {% endif %}\n {% set is_defined = param is defined and param is not none %}\n {% if is_list %}\n {% set ns.is_last_checked_defined = is_defined and param|length > 0 %}\n {% else %}\n {% set ns.is_last_checked_defined = is_defined %}\n {% endif %}\n{% endmacro %}\n{##}\n{##}\n{# Template #}\n{{- "<|startoftext|>" -}}\n{% set _ = is_param_set(system_message) %}\n{% set is_system_message_set = ns.is_last_checked_defined %}\n{% set _ = is_param_set(tools, is_list=True) %}\n{% set is_tools_set = ns.is_last_checked_defined %}\n{% set has_system_message = (is_system_message_set or is_tools_set) %}\n{% if has_system_message %}\n {{- handle_first_system_message(system_message, tools) -}}\n{% endif %}\n{% set last_user_index = get_last_user_index(loop_messages)|int %}\n{% for message in loop_messages %}\n {% if loop.index0 == last_user_index %}\n {% set _ = is_param_set(documents, is_list=True) %}\n {% set use_documents = ns.is_last_checked_defined %}\n {% set _ = is_param_set(knobs) %}\n {% set use_knobs = ns.is_last_checked_defined and knobs.is_set %}\n {% set add_last_system_message = use_documents or use_knobs %}\n {% if add_last_system_message %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- handle_last_system_message(documents, knobs, use_documents, use_knobs) -}}\n {% endif %}\n {% endif %}\n {% set role = message.role %}\n {% set _ = is_param_set(message, field="name") %}\n {% set is_message_name_set = ns.is_last_checked_defined %}\n {% if is_message_name_set %}\n {% set message_prefix = handle_role(role) + "(" + message.name + ")" %}\n {% else %}\n {% set message_prefix = handle_role(role) %}\n {% endif %}\n {% set content = (message.content or "") %}\n {% if content is not string %}\n {% set content = content|tojson %}\n {% endif %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + message_prefix + content -}}\n {% set _ = is_param_set(message, field="tool_calls", is_list=True) %}\n {% set is_tool_calls_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_tool_calls_set %}\n {{- handle_tool_calls(message.tool_calls) -}}\n {% endif %}\n {% set _ = is_param_set(message, field="citations", is_list=True) %}\n {% set is_citations_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_citations_set %}\n {{- citations_prefix + message.citations|map(attribute="document_id")|list|string + citations_suffix -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endfor %}\n{% if add_generation_prompt %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + handle_role(role_to_predict, add_space=False) -}}\n {% set _ = is_param_set(generation_preamble) %}\n {% set is_generation_preamble_set = ns.is_last_checked_defined %}\n {% if is_generation_preamble_set and generation_preamble.strip() != "" %}\n {{- " " + generation_preamble -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% else %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n{% endif %}\n',
+ "qwen_25": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- message.content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
+}
+
+
+def get_chat_template(
+ user_choice: str,
+ jinja_template: Optional[str] = None,
+ tokenizer: Optional["PreTrainedTokenizerBase"] = None,
+):
"""
- Finds the correct chat_template for the tokenizer_config.
+ Finds the correct chat_template based on the user's choice, jinja_template, and tokenizer.
Args:
user_choice (str): The user's choice of template.
+ jinja_template (Optional[str], optional): The jinja template string. Defaults to None.
+ tokenizer (Optional[PreTrainedTokenizerBase], optional): The tokenizer. Defaults to None.
Returns:
str: The chosen template string.
@@ -17,19 +51,81 @@ def chat_templates(user_choice: str):
Raises:
ValueError: If the user_choice is not found in the templates.
"""
+ if user_choice == _JINJA_TEMPALTE_CHOICE:
+ if not jinja_template:
+ raise ValueError(
+ f"`jinja_template` cannot be None when `chat_template` choice is {_JINJA_TEMPALTE_CHOICE}"
+ )
+ return jinja_template
+
+ if user_choice == _DEFAULT_TEMPLATE_CHOICE:
+ if not tokenizer:
+ raise ValueError(
+ f"`tokenizer` cannot be None when chat_template choice is {_DEFAULT_TEMPLATE_CHOICE}"
+ )
+ if not tokenizer.chat_template:
+ raise ValueError(
+ f"`chat_template choice is {_DEFAULT_TEMPLATE_CHOICE} but tokenizer's chat_template is null. "
+ f"Please add a chat_template in tokenizer config"
+ )
+ return tokenizer.chat_template
- templates = {
- "alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
- "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
- "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
- "gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\n' + message['content'] | trim + '\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\n'}}{% endif %}",
- "cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
- "llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
- "phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
- "deepseek_v2": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<|User|>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<|Assistant|>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|Assistant|>' }}{% endif %}",
- }
+ if user_choice.startswith(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX):
+ if not tokenizer:
+ raise ValueError(
+ f"`tokenizer` cannot be None when chat_template choice starts with {_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX}"
+ )
+ if tokenizer.chat_template:
+ return tokenizer.chat_template
- if user_choice in templates:
- return templates[user_choice]
+ user_choice = user_choice[
+ len(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX) :
+ ]
+ LOG.warning(
+ f"No chat template found on tokenizer, falling back to {user_choice}. It is recommended to set --train_on_inputs to True for the model to learn this chat template."
+ )
+
+ if user_choice in _CHAT_TEMPLATES:
+ return _CHAT_TEMPLATES[user_choice]
raise ValueError(f"Template '{user_choice}' not found.")
+
+
+def extract_chat_template_args(cfg, ds_cfg: Optional[Dict[str, Any]] = None):
+ if ds_cfg and ds_cfg.get("chat_template"):
+ chat_template_choice = ds_cfg.get("chat_template") or _DEFAULT_TEMPLATE_CHOICE
+ chat_template_jinja = ds_cfg.get("chat_template_jinja")
+ else:
+ chat_template_choice = cfg.get("chat_template") or _DEFAULT_TEMPLATE_CHOICE
+ chat_template_jinja = cfg.get("chat_template_jinja")
+ return chat_template_choice, chat_template_jinja
+
+
+def get_chat_template_from_config(
+ cfg,
+ ds_cfg: Optional[Dict[str, Any]] = None,
+ tokenizer: Optional["PreTrainedTokenizerBase"] = None,
+) -> str:
+ chat_template_choice, chat_template_jinja = extract_chat_template_args(
+ cfg=cfg, ds_cfg=ds_cfg
+ )
+ return get_chat_template(
+ user_choice=chat_template_choice,
+ jinja_template=chat_template_jinja,
+ tokenizer=tokenizer,
+ )
+
+
+def register_chat_template(template_name: str, chat_template: str):
+ """
+ Registers chat templates.
+
+ Args:
+ template_name (str): The name of the template.
+ chat_template (str): The template string.
+ """
+
+ if template_name in _CHAT_TEMPLATES:
+ raise ValueError(f"Template '{template_name}' already exists.")
+
+ _CHAT_TEMPLATES[template_name] = chat_template
diff --git a/src/axolotl/utils/collators/__init__.py b/src/axolotl/utils/collators/__init__.py
new file mode 100644
index 0000000000..93502b67d7
--- /dev/null
+++ b/src/axolotl/utils/collators/__init__.py
@@ -0,0 +1,10 @@
+"""
+shared axolotl collators for multipack, mamba, multimodal
+"""
+from .batching import ( # noqa: F401
+ BatchSamplerDataCollatorForSeq2Seq,
+ DataCollatorForSeq2Seq,
+ PretrainingBatchSamplerDataCollatorForSeq2Seq,
+ V2BatchSamplerDataCollatorForSeq2Seq,
+)
+from .mamba import MambaDataCollator # noqa: F401
diff --git a/src/axolotl/utils/collators.py b/src/axolotl/utils/collators/batching.py
similarity index 90%
rename from src/axolotl/utils/collators.py
rename to src/axolotl/utils/collators/batching.py
index 26c7fa9f3c..7cf771421c 100644
--- a/src/axolotl/utils/collators.py
+++ b/src/axolotl/utils/collators/batching.py
@@ -1,17 +1,14 @@
"""
DataCollator for axolotl to pad labels and position_ids for packed sequences
"""
+
from dataclasses import dataclass
-from typing import Any, Dict, Optional, Sequence, Union
+from typing import Any, Optional, Union
import numpy as np
-import torch
-import transformers
from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
-IGNORE_INDEX = -100
-
@dataclass
class DataCollatorForSeq2Seq:
@@ -183,34 +180,6 @@ def __call__(self, features, return_tensors=None):
return super().__call__(out_features, return_tensors=return_tensors)
-@dataclass
-class MambaDataCollator:
- """
- Collator for State Space Models (Mamba)
- """
-
- tokenizer: transformers.PreTrainedTokenizer
-
- def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
- input_ids, labels = tuple(
- [torch.LongTensor(instance[key]) for instance in instances]
- for key in ("input_ids", "labels")
- )
- input_ids = torch.nn.utils.rnn.pad_sequence(
- input_ids,
- batch_first=True,
- padding_value=self.tokenizer.pad_token_id,
- )
- labels = torch.nn.utils.rnn.pad_sequence(
- labels, batch_first=True, padding_value=IGNORE_INDEX
- )
-
- return {
- "input_ids": input_ids,
- "labels": labels,
- }
-
-
@dataclass
class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
"""
diff --git a/src/axolotl/utils/collators/core.py b/src/axolotl/utils/collators/core.py
new file mode 100644
index 0000000000..0eae0c3bda
--- /dev/null
+++ b/src/axolotl/utils/collators/core.py
@@ -0,0 +1,4 @@
+"""
+basic shared collator constants
+"""
+IGNORE_INDEX = -100
diff --git a/src/axolotl/utils/collators/mamba.py b/src/axolotl/utils/collators/mamba.py
new file mode 100644
index 0000000000..0c4a22fcc0
--- /dev/null
+++ b/src/axolotl/utils/collators/mamba.py
@@ -0,0 +1,38 @@
+"""
+collators for Mamba
+"""
+from dataclasses import dataclass
+from typing import Dict, Sequence
+
+import torch
+import transformers
+
+from axolotl.utils.collators.core import IGNORE_INDEX
+
+
+@dataclass
+class MambaDataCollator:
+ """
+ Collator for State Space Models (Mamba)
+ """
+
+ tokenizer: transformers.PreTrainedTokenizer
+
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
+ input_ids, labels = tuple(
+ [torch.LongTensor(instance[key]) for instance in instances]
+ for key in ("input_ids", "labels")
+ )
+ input_ids = torch.nn.utils.rnn.pad_sequence(
+ input_ids,
+ batch_first=True,
+ padding_value=self.tokenizer.pad_token_id,
+ )
+ labels = torch.nn.utils.rnn.pad_sequence(
+ labels, batch_first=True, padding_value=IGNORE_INDEX
+ )
+
+ return {
+ "input_ids": input_ids,
+ "labels": labels,
+ }
diff --git a/src/axolotl/utils/collators/mm_chat.py b/src/axolotl/utils/collators/mm_chat.py
new file mode 100644
index 0000000000..b9b67f8750
--- /dev/null
+++ b/src/axolotl/utils/collators/mm_chat.py
@@ -0,0 +1,83 @@
+"""
+Collators for multi-modal chat messages and packing
+"""
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Union
+
+from PIL import Image
+from transformers import PreTrainedTokenizerBase, ProcessorMixin
+from transformers.data.data_collator import DataCollatorMixin
+from transformers.utils import PaddingStrategy
+
+
+@dataclass
+class MultiModalChatDataCollator(DataCollatorMixin):
+ """
+ Collator for multi-modal chat messages
+ """
+
+ tokenizer: PreTrainedTokenizerBase
+ processor: ProcessorMixin
+ return_tensors: str = "pt"
+ chat_template: Optional[str] = None
+ packing: bool = False
+ max_images: int = -1
+ padding: Union[bool, str, PaddingStrategy] = True
+ pad_to_multiple_of: Optional[int] = None
+
+ def __post_init__(self):
+ if self.packing:
+ raise ValueError("Packing is currently not supported.")
+
+ def torch_call(
+ self, examples: List[Union[List[int], Any, Dict[str, Any]]]
+ ) -> Dict[str, Any]:
+ # Handle dict or lists with proper padding and conversion to tensor.
+
+ return self.__class__.process_rows(
+ examples, self.processor, self.chat_template, self.max_images
+ )
+
+ @staticmethod
+ def process_rows(examples, processor, chat_template, max_images, length_only=False):
+ # HINT: use `_torch_collate_batch` to stack and pad tensors
+ # see also DataCollatorWithFlattening and DefaultDataCollator
+
+ # *** This is COPIED from the trl example sft_vlm.py code ***
+ # use this as a starting point
+
+ # Get the texts and images, and apply the chat template
+ texts = [
+ processor.apply_chat_template(
+ example["messages"], chat_template=chat_template, tokenize=False
+ )
+ for example in examples
+ ]
+ images = [
+ Image.open(example["images"])
+ if isinstance(example["images"], str)
+ else example["images"]
+ for example in examples
+ ]
+
+ if max_images > 0:
+ images = [img_batch[:max_images] for img_batch in images]
+
+ # Tokenize the texts and process the images
+ batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
+
+ # The labels are the input_ids, and we mask the padding tokens in the loss computation
+ labels = batch["input_ids"].clone()
+ labels[labels == processor.tokenizer.pad_token_id] = -100 #
+ # Ignore the image token index in the loss computation (model specific)
+ image_token_id = processor.tokenizer.convert_tokens_to_ids(
+ processor.image_token
+ )
+ labels[labels == image_token_id] = -100
+ batch["labels"] = labels
+
+ if length_only:
+ return {
+ "length": [len(sample["input_ids"]) for sample in batch["input_ids"]]
+ }
+ return batch
diff --git a/src/axolotl/utils/comet_.py b/src/axolotl/utils/comet_.py
new file mode 100644
index 0000000000..b4ecc80ad9
--- /dev/null
+++ b/src/axolotl/utils/comet_.py
@@ -0,0 +1,93 @@
+"""Module for wandb utilities"""
+
+import logging
+import os
+
+from axolotl.utils.dict import DictDefault
+
+LOG = logging.getLogger("axolotl.utils.comet_")
+
+COMET_ENV_MAPPING_OVERRIDE = {
+ "comet_mode": "COMET_START_MODE",
+ "comet_online": "COMET_START_ONLINE",
+}
+COMET_EXPERIMENT_CONFIG_ENV_MAPPING_OVERRIDE = {
+ "auto_histogram_activation_logging": "COMET_AUTO_LOG_HISTOGRAM_ACTIVATIONS",
+ "auto_histogram_epoch_rate": "COMET_AUTO_LOG_HISTOGRAM_EPOCH_RATE",
+ "auto_histogram_gradient_logging": "COMET_AUTO_LOG_HISTOGRAM_GRADIENTS",
+ "auto_histogram_tensorboard_logging": "COMET_AUTO_LOG_HISTOGRAM_TENSORBOARD",
+ "auto_histogram_weight_logging": "COMET_AUTO_LOG_HISTOGRAM_WEIGHTS",
+ "auto_log_co2": "COMET_AUTO_LOG_CO2",
+ "auto_metric_logging": "COMET_AUTO_LOG_METRICS",
+ "auto_metric_step_rate": "COMET_AUTO_LOG_METRIC_STEP_RATE",
+ "auto_output_logging": "COMET_AUTO_LOG_OUTPUT_LOGGER",
+ "auto_param_logging": "COMET_AUTO_LOG_PARAMETERS",
+ "comet_disabled": "COMET_AUTO_LOG_DISABLE",
+ "display_summary_level": "COMET_DISPLAY_SUMMARY_LEVEL",
+ "distributed_node_identifier": "COMET_DISTRIBUTED_NODE_IDENTIFIER",
+ "log_code": "COMET_AUTO_LOG_CODE",
+ "log_env_cpu": "COMET_AUTO_LOG_ENV_CPU",
+ "log_env_details": "COMET_AUTO_LOG_ENV_DETAILS",
+ "log_env_disk": "COMET_AUTO_LOG_ENV_DISK",
+ "log_env_gpu": "COMET_AUTO_LOG_ENV_GPU",
+ "log_env_host": "COMET_AUTO_LOG_ENV_HOST",
+ "log_env_network": "COMET_AUTO_LOG_ENV_NETWORK",
+ "log_git_metadata": "COMET_AUTO_LOG_GIT_METADATA",
+ "log_git_patch": "COMET_AUTO_LOG_GIT_PATCH",
+ "log_graph": "COMET_AUTO_LOG_GRAPH",
+ "name": "COMET_START_EXPERIMENT_NAME",
+ "offline_directory": "COMET_OFFLINE_DIRECTORY",
+ "parse_args": "COMET_AUTO_LOG_CLI_ARGUMENTS",
+ "tags": "COMET_START_EXPERIMENT_TAGS",
+}
+
+
+def python_value_to_environ_value(python_value):
+ if isinstance(python_value, bool):
+ if python_value is True:
+ return "true"
+
+ return "false"
+
+ if isinstance(python_value, int):
+ return str(python_value)
+
+ if isinstance(python_value, list): # Comet only have one list of string parameter
+ return ",".join(map(str, python_value))
+
+ return python_value
+
+
+def setup_comet_env_vars(cfg: DictDefault):
+ # TODO, we need to convert Axolotl configuration to environment variables
+ # as Transformers integration are call first and would create an
+ # Experiment first
+
+ for key in cfg.keys():
+ if key.startswith("comet_") and key != "comet_experiment_config":
+ value = cfg.get(key, "")
+
+ if value is not None and value != "":
+ env_variable_name = COMET_ENV_MAPPING_OVERRIDE.get(key, key.upper())
+ final_value = python_value_to_environ_value(value)
+ os.environ[env_variable_name] = final_value
+
+ if cfg.comet_experiment_config:
+ for key, value in cfg.comet_experiment_config.items():
+ if value is not None and value != "":
+ config_env_variable_name = (
+ COMET_EXPERIMENT_CONFIG_ENV_MAPPING_OVERRIDE.get(key)
+ )
+
+ if config_env_variable_name is None:
+ LOG.warning(
+ f"Unknown Comet Experiment Config name {key}, ignoring it"
+ )
+ continue
+
+ final_value = python_value_to_environ_value(value)
+ os.environ[config_env_variable_name] = final_value
+
+ # Enable comet if project name is present
+ if cfg.comet_project_name and len(cfg.comet_project_name) > 0:
+ cfg.use_comet = True
diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py
index ed165e89ca..afc8c4fc41 100644
--- a/src/axolotl/utils/config/__init__.py
+++ b/src/axolotl/utils/config/__init__.py
@@ -8,11 +8,14 @@
import torch
from transformers.utils import is_torch_bf16_gpu_available
+from axolotl.integrations.config import merge_input_args
from axolotl.utils.bench import log_gpu_memory_usage
+from axolotl.utils.config.models.input.v0_4_1 import SUPPORTED_METRICS
from axolotl.utils.config.models.input.v0_4_1 import (
- SUPPORTED_METRICS,
- AxolotlConfigWCapabilities,
- AxolotlInputConfig,
+ AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
+)
+from axolotl.utils.config.models.input.v0_4_1 import (
+ AxolotlInputConfig as AxolotlInputConfigBase,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model_config
@@ -118,15 +121,36 @@ def normalize_config(cfg):
cfg.base_model_config = cfg.base_model
model_config = load_model_config(cfg)
- cfg.model_config_type = model_config.model_type
cfg.tokenizer_config = (
cfg.tokenizer_config or cfg.base_model_config or cfg.base_model
)
+ cfg.is_multimodal = (
+ hasattr(model_config, "model_type")
+ and model_config.model_type in ["llava", "mllama"]
+ or any(
+ multimodal_name in cfg.base_model.lower()
+ for multimodal_name in [
+ "pixtral",
+ ]
+ )
+ or cfg.is_multimodal
+ )
+ if cfg.is_multimodal:
+ cfg.processor_config = (
+ cfg.processor_config or cfg.base_model_config or cfg.base_model
+ )
+ model_config = model_config.text_config
+
+ cfg.model_config_type = model_config.model_type
+
# figure out if the model is llama
cfg.is_llama_derived_model = (
- (hasattr(model_config, "model_type") and model_config.model_type == "llama")
+ (
+ hasattr(model_config, "model_type")
+ and model_config.model_type == ["llama", "mllama_text_model"]
+ )
or cfg.is_llama_derived_model
or "llama" in cfg.base_model.lower()
or (cfg.type_of_model and "llama" in cfg.type_of_model.lower())
@@ -204,9 +228,19 @@ def normalize_cfg_datasets(cfg):
f"updating dataset {ds_cfg.path} with `chat_template: {cfg.chat_template}` to match your chat_template"
)
cfg.datasets[idx].chat_template = cfg.chat_template
+ cfg.datasets[idx].chat_template_jinja = cfg.chat_template_jinja
def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
+ AxolotlConfigWCapabilities = AxolotlConfigWCapabilitiesBase
+ AxolotlInputConfig = AxolotlInputConfigBase
+
+ if cfg.plugins:
+ (
+ AxolotlConfigWCapabilities, # pylint: disable=invalid-name
+ AxolotlInputConfig, # pylint: disable=invalid-name
+ ) = merge_input_args()
+
if capabilities:
return DictDefault(
dict(
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 5e690bb88e..96e5330005 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -8,9 +8,16 @@
import os
from enum import Enum
from importlib.metadata import version
-from typing import Any, Dict, List, Literal, Optional, Tuple, Union
+from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union
-from pydantic import BaseModel, Field, conlist, field_validator, model_validator
+from pydantic import (
+ BaseModel,
+ Field,
+ StringConstraints,
+ conlist,
+ field_validator,
+ model_validator,
+)
from transformers import SchedulerType
from transformers.training_args import OptimizerNames
@@ -21,6 +28,37 @@
SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
+class RLType(str, Enum):
+ """RL trainer type configuration subset"""
+
+ dpo = "dpo" # pylint: disable=invalid-name
+ ipo = "ipo" # pylint: disable=invalid-name
+ orpo = "orpo" # pylint: disable=invalid-name
+ kto = "kto" # pylint: disable=invalid-name
+ simpo = "simpo" # pylint: disable=invalid-name
+
+
+class ChatTemplate(str, Enum):
+ """Chat templates configuration subset"""
+
+ alpaca = "alpaca" # pylint: disable=invalid-name
+ chatml = "chatml" # pylint: disable=invalid-name
+ mistral_v1 = "mistral_v1" # pylint: disable=invalid-name
+ mistral_v2v3 = "mistral_v2v3" # pylint: disable=invalid-name
+ mistral_v3_tekken = "mistral_v3_tekken" # pylint: disable=invalid-name
+ gemma = "gemma" # pylint: disable=invalid-name
+ cohere = "cohere" # pylint: disable=invalid-name
+ llama3 = "llama3" # pylint: disable=invalid-name
+ llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name
+ phi_3 = "phi_3" # pylint: disable=invalid-name
+ phi_35 = "phi_35" # pylint: disable=invalid-name
+ deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
+ jamba = "jamba" # pylint: disable=invalid-name
+ jinja = "jinja" # pylint: disable=invalid-name
+ qwen_25 = "qwen_25" # pylint: disable=invalid-name
+ tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
+
+
class DeprecatedParameters(BaseModel):
"""configurations that are deprecated"""
@@ -102,14 +140,22 @@ class SFTDataset(BaseModel):
path: Optional[str] = None
split: Optional[str] = None
type: Optional[Union[str, UserDefinedPrompterType]] = None
+ input_transform: Optional[str] = None
shards: Optional[int] = None
conversation: Optional[str] = None
- chat_template: Optional[str] = None
+ # Do not make this too strict or it will break the validator to choose different dataset class
+ chat_template: Optional[
+ Union[
+ ChatTemplate,
+ str,
+ ]
+ ] = None
+ chat_template_jinja: Optional[str] = None
data_files: Optional[Union[str, List[str]]] = None
+ input_format: Optional[str] = None
name: Optional[str] = None
ds_type: Optional[str] = None
train_on_split: Optional[str] = None
-
field: Optional[str] = None
field_human: Optional[str] = None
field_model: Optional[str] = None
@@ -120,11 +166,31 @@ class SFTDataset(BaseModel):
message_field_training_detail: Optional[str] = None
roles_to_train: Optional[List[str]] = None
train_on_eos: Optional[str] = None
-
roles: Optional[Dict[str, List[str]]] = None
drop_system_message: Optional[bool] = None
-
trust_remote_code: Optional[bool] = False
+ revision: Optional[str] = None
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_chat_template_config(cls, data):
+ # Set chat_template to tokenizer_default if not set
+ if data.get("type") == "chat_template" and not data.get("chat_template"):
+ data["chat_template"] = ChatTemplate.tokenizer_default
+
+ # if chat_template is set to jinja, chat_template_jinja is required
+ if data.get("chat_template") == ChatTemplate.jinja and not data.get(
+ "chat_template_jinja"
+ ):
+ raise ValueError(
+ "chat_template_jinja is required when chat_template is set to jinja"
+ )
+
+ # If chat_template_jinja is set, set chat_template to jinja
+ if data.get("chat_template_jinja") and not data.get("chat_template"):
+ data["chat_template"] = ChatTemplate.jinja
+
+ return data
class UserDefinedDPOType(BaseModel):
@@ -146,6 +212,7 @@ class DPODataset(BaseModel):
split: Optional[str] = None
type: Optional[Union[UserDefinedDPOType, str]] = None
data_files: Optional[List[str]] = None
+ revision: Optional[str] = None
class UserDefinedKTOType(BaseModel):
@@ -167,29 +234,7 @@ class KTODataset(BaseModel):
type: Optional[Union[UserDefinedKTOType, str]] = None
data_files: Optional[List[str]] = None
trust_remote_code: Optional[bool] = False
-
-
-class RLType(str, Enum):
- """RL trainer type configuration subset"""
-
- dpo = "dpo" # pylint: disable=invalid-name
- ipo = "ipo" # pylint: disable=invalid-name
- orpo = "orpo" # pylint: disable=invalid-name
- kto = "kto" # pylint: disable=invalid-name
- simpo = "simpo" # pylint: disable=invalid-name
-
-
-class ChatTemplate(str, Enum):
- """Chat templates configuration subset"""
-
- alpaca = "alpaca" # pylint: disable=invalid-name
- chatml = "chatml" # pylint: disable=invalid-name
- inst = "inst" # pylint: disable=invalid-name
- gemma = "gemma" # pylint: disable=invalid-name
- cohere = "cohere" # pylint: disable=invalid-name
- llama3 = "llama3" # pylint: disable=invalid-name
- phi_3 = "phi_3" # pylint: disable=invalid-name
- deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
+ revision: Optional[str] = None
class LoftQConfig(BaseModel):
@@ -226,11 +271,12 @@ class LoraConfig(BaseModel):
lora_r: Optional[int] = None
lora_alpha: Optional[int] = None
lora_fan_in_fan_out: Optional[bool] = None
- lora_target_modules: Optional[List[str]] = None
+ lora_target_modules: Optional[Union[str, List[str]]] = None
lora_target_linear: Optional[bool] = None
lora_modules_to_save: Optional[List[str]] = None
lora_dropout: Optional[float] = 0.0
peft_layers_to_transform: Optional[List[int]] = None
+ peft_layers_pattern: Optional[List[str]] = None
peft: Optional[PeftConfig] = None
peft_use_dora: Optional[bool] = None
peft_use_rslora: Optional[bool] = None
@@ -296,6 +342,13 @@ def validate_qlora(self):
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
return self
+ @field_validator("loraplus_lr_embedding")
+ @classmethod
+ def convert_loraplus_lr_embedding(cls, loraplus_lr_embedding):
+ if loraplus_lr_embedding and isinstance(loraplus_lr_embedding, str):
+ loraplus_lr_embedding = float(loraplus_lr_embedding)
+ return loraplus_lr_embedding
+
class ReLoRAConfig(BaseModel):
"""ReLoRA configuration subset"""
@@ -319,6 +372,9 @@ class ModelInputConfig(BaseModel):
tokenizer_type: Optional[str] = Field(
default=None, metadata={"help": "transformers tokenizer class"}
)
+ processor_type: Optional[str] = Field(
+ default=None, metadata={"help": "transformers processor class"}
+ )
trust_remote_code: Optional[bool] = None
model_kwargs: Optional[Dict[str, Any]] = None
@@ -354,6 +410,8 @@ class HyperparametersConfig(BaseModel):
},
)
+ auto_find_batch_size: Optional[bool] = None
+
train_on_inputs: Optional[bool] = False
group_by_length: Optional[bool] = None
@@ -428,6 +486,7 @@ class MLFlowConfig(BaseModel):
use_mlflow: Optional[bool] = None
mlflow_tracking_uri: Optional[str] = None
mlflow_experiment_name: Optional[str] = None
+ mlflow_run_name: Optional[str] = None
hf_mlflow_log_artifacts: Optional[bool] = None
@@ -473,6 +532,19 @@ def check_wandb_run(cls, data):
return data
+class CometConfig(BaseModel):
+ """Comet configuration subset"""
+
+ use_comet: Optional[bool] = None
+ comet_api_key: Optional[str] = None
+ comet_workspace: Optional[str] = None
+ comet_project_name: Optional[str] = None
+ comet_experiment_key: Optional[str] = None
+ comet_mode: Optional[str] = None
+ comet_online: Optional[bool] = None
+ comet_experiment_config: Optional[Dict[str, Any]] = None
+
+
class GradioConfig(BaseModel):
"""Gradio configuration subset"""
@@ -493,6 +565,7 @@ class AxolotlInputConfig(
HyperparametersConfig,
WandbConfig,
MLFlowConfig,
+ CometConfig,
LISAConfig,
GradioConfig,
RemappedParameters,
@@ -510,8 +583,10 @@ class Config:
resume_from_checkpoint: Optional[str] = None
auto_resume_from_checkpoints: Optional[bool] = None
resize_token_embeddings_to_32x: Optional[bool] = None
+ mean_resizing_embeddings: Optional[bool] = False
rl: Optional[RLType] = None
+ reward_model: Optional[bool] = None
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
@@ -519,6 +594,7 @@ class Config:
dataset_prepared_path: Optional[str] = None
dataset_shard_num: Optional[int] = None
dataset_shard_idx: Optional[int] = None
+ skip_prepare_dataset: Optional[bool] = False
pretraining_dataset: Optional[ # type: ignore
conlist(Union[PretrainingDataset, SFTDataset], min_length=1)
@@ -591,6 +667,7 @@ class Config:
eval_sample_packing: Optional[bool] = None
pad_to_sequence_len: Optional[bool] = None
curriculum_sampling: Optional[bool] = None
+ multipack_real_batches: Optional[bool] = None
# for PoSE context length extension
use_pose: Optional[bool] = None
@@ -628,6 +705,9 @@ class Config:
deepspeed: Optional[Union[str, Dict[str, Any]]] = None
fsdp: Optional[List[str]] = None
fsdp_config: Optional[Dict[str, Any]] = None
+ fsdp_final_state_dict_type: Optional[
+ Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"]
+ ] = None
val_set_size: Optional[float] = Field(default=0.0)
@@ -673,7 +753,13 @@ class Config:
gpu_memory_limit: Optional[Union[int, str]] = None
low_cpu_mem_usage: Optional[bool] = None
- chat_template: Optional[ChatTemplate] = None
+ chat_template: Optional[
+ Union[
+ ChatTemplate,
+ Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")],
+ ]
+ ] = None
+ chat_template_jinja: Optional[str] = None
default_system_message: Optional[str] = None
fix_untrained_tokens: Optional[bool] = None
@@ -782,6 +868,23 @@ def check_sample_packing_w_xformers(cls, data):
return data
+ @model_validator(mode="before")
+ @classmethod
+ def check_chat_template_config(cls, data):
+ # if chat_template is set to jinja, chat_template_jinja is required
+ if data.get("chat_template") == ChatTemplate.jinja and not data.get(
+ "chat_template_jinja"
+ ):
+ raise ValueError(
+ "chat_template_jinja is required when chat_template is set to jinja"
+ )
+
+ # If chat_template_jinja is set, set chat_template to jinja
+ if data.get("chat_template_jinja") and not data.get("chat_template"):
+ data["chat_template"] = ChatTemplate.jinja
+
+ return data
+
@model_validator(mode="before")
@classmethod
def check_sample_packing_wo_flash(cls, data):
@@ -812,6 +915,17 @@ def hint_sample_packing_padding(cls, data):
)
return data
+ @model_validator(mode="before")
+ @classmethod
+ def hint_reward_model_pad(cls, data):
+ if data.get("reward_model") and not data.get("pad_to_sequence_len"):
+ LOG.warning(
+ "`pad_to_sequence_len: true` is recommended when using reward_model"
+ )
+ if data.get("pad_to_sequence_len") is None:
+ data["pad_to_sequence_len"] = True
+ return data
+
@model_validator(mode="before")
@classmethod
def check_gas_bsz(cls, data):
@@ -945,6 +1059,26 @@ def check_evals(cls, data):
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
)
+ if data.get("do_bench_eval") and not (
+ data.get("evals_per_epoch") or data.get("eval_steps")
+ ):
+ raise ValueError(
+ "do_bench_eval requires evals_per_epoch or eval_steps to be set."
+ )
+ return data
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_test_datasets_bench(cls, data):
+ if (
+ data.get("do_bench_eval")
+ and not data.get("test_datasets")
+ and not data.get("val_set_size")
+ ):
+ LOG.warning(
+ "`do_bench_eval` needs a test dataset to run evals, adding an empty test_dataset."
+ )
+ data["test_datasets"] = [{"path": "axolotl-ai-co/empty-test-ds"}]
return data
@model_validator(mode="before")
@@ -982,6 +1116,18 @@ def check_eval_packing(cls, data):
return data
+ @model_validator(mode="before")
+ @classmethod
+ def check_mm_prepare(cls, data):
+ if data.get("skip_prepare_dataset"):
+ if data.get("remove_unused_columns") is None:
+ LOG.info(
+ "setting `remove_unused_columns: false` for skip_prepare_dataset"
+ )
+ data["remove_unused_columns"] = False
+
+ return data
+
@model_validator(mode="before")
@classmethod
def check_warmup(cls, data):
@@ -1009,12 +1155,20 @@ def validate_neftune_noise_alpha(cls, neftune_noise_alpha):
return neftune_noise_alpha
@model_validator(mode="after")
- def check(self):
+ def check_rl_beta(self):
if self.dpo_beta and not self.rl_beta:
self.rl_beta = self.dpo_beta
del self.dpo_beta
return self
+ @model_validator(mode="after")
+ def check_simpo_warmup(self):
+ if self.rl == "simpo" and self.warmup_ratio:
+ raise ValueError(
+ "warmup_ratio is not supported with the simpo trainer. Please use `warmup_steps` instead"
+ )
+ return self
+
@model_validator(mode="before")
@classmethod
def check_frozen(cls, data):
@@ -1029,6 +1183,15 @@ def check_frozen(cls, data):
return data
+ @model_validator(mode="before")
+ @classmethod
+ def check_peft_layers_pattern(cls, data):
+ if data.get("peft_layers_pattern") and not data.get("peft_layers_to_transform"):
+ raise ValueError(
+ "peft_layers_pattern requires peft_layers_to_transform to be set"
+ )
+ return data
+
@model_validator(mode="after")
def check_fft_possible_bad_config(self):
if (
@@ -1148,6 +1311,20 @@ def check_fsdp_offload_w_8bit_optimizer(cls, data):
)
return data
+ @model_validator(mode="before")
+ @classmethod
+ def check_fsdp_sharded_state_dict_w_safetensors(cls, data):
+ if (
+ data.get("fsdp")
+ and data.get("save_safetensors")
+ and data.get("fsdp_config")
+ and data["fsdp_config"].get("fsdp_state_dict_type") == "SHARDED_STATE_DICT"
+ ):
+ raise ValueError(
+ "FSDP SHARDED_STATE_DICT not compatible with save_safetensors"
+ )
+ return data
+
@model_validator(mode="before")
@classmethod
def check_causal_lm_evals(cls, data):
diff --git a/src/axolotl/utils/data/pretraining.py b/src/axolotl/utils/data/pretraining.py
index e056c7f509..16f38218cd 100644
--- a/src/axolotl/utils/data/pretraining.py
+++ b/src/axolotl/utils/data/pretraining.py
@@ -18,10 +18,10 @@
def encode_pretraining(
- tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
+ tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: Dict[str, List]
) -> Dict[str, List]:
res = tokenizer(
- examples,
+ examples["text"],
truncation=True,
max_length=max_tokens - 2,
add_special_tokens=True,
diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py
index d0324e1ebd..35bd5fcbb7 100644
--- a/src/axolotl/utils/data/rl.py
+++ b/src/axolotl/utils/data/rl.py
@@ -90,6 +90,7 @@ def load_split(dataset_cfgs, _cfg):
ds = load_dataset( # pylint: disable=invalid-name
ds_cfg["path"],
split=ds_cfg["split"],
+ revision=ds_cfg.get("revision", None),
)
split_datasets.insert(i, ds)
diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py
index 1b6df1cded..ce01b44098 100644
--- a/src/axolotl/utils/data/sft.py
+++ b/src/axolotl/utils/data/sft.py
@@ -19,10 +19,12 @@
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_strategies import load
+from axolotl.prompt_strategies.bradley_terry import load as bradley_terry_load
from axolotl.prompt_tokenizers import (
AlpacaMultipleChoicePromptTokenizingStrategy,
AlpacaPromptTokenizingStrategy,
AlpacaReflectionPTStrategy,
+ DatasetWrappingStrategy,
GPTeacherPromptTokenizingStrategy,
JeopardyPromptTokenizingStrategy,
OpenAssistantPromptTokenizingStrategy,
@@ -51,20 +53,31 @@
LOG = logging.getLogger("axolotl")
-def prepare_dataset(cfg, tokenizer):
+def prepare_dataset(cfg, tokenizer, processor=None):
prompters = []
if not cfg.pretraining_dataset:
with zero_first(is_local_main_process()):
if cfg.test_datasets:
train_dataset, _, prompters = load_prepare_datasets(
- tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="train"
+ tokenizer,
+ cfg,
+ DEFAULT_DATASET_PREPARED_PATH,
+ split="train",
+ processor=processor,
)
_, eval_dataset, _ = load_prepare_datasets(
- tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="test"
+ tokenizer,
+ cfg,
+ DEFAULT_DATASET_PREPARED_PATH,
+ split="test",
+ processor=processor,
)
else:
train_dataset, eval_dataset, prompters = load_prepare_datasets(
- tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
+ tokenizer,
+ cfg,
+ DEFAULT_DATASET_PREPARED_PATH,
+ processor=processor,
)
else:
path = cfg.pretraining_dataset
@@ -123,6 +136,7 @@ def load_tokenized_prepared_datasets(
cfg,
default_dataset_prepared_path,
split="train",
+ processor=None,
) -> Tuple[DatasetDict, List[Prompter]]:
cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
tokenizer_name = cfg.tokenizer_config
@@ -180,6 +194,7 @@ def load_tokenized_prepared_datasets(
cfg.dataset_prepared_path
and any(prepared_ds_path.glob("*"))
and not cfg.is_preprocess
+ and not cfg.skip_prepare_dataset
):
LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
dataset = load_from_disk(str(prepared_ds_path))
@@ -229,6 +244,7 @@ def for_d_in_datasets(dataset_configs):
name=config_dataset.name,
streaming=True,
token=use_auth_token,
+ revision=config_dataset.revision,
)
ds_from_hub = True
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
@@ -333,6 +349,7 @@ def for_d_in_datasets(dataset_configs):
streaming=False,
data_files=config_dataset.data_files,
token=use_auth_token,
+ revision=config_dataset.revision,
**load_ds_kwargs,
)
elif ds_from_cloud and remote_file_system:
@@ -367,6 +384,7 @@ def for_d_in_datasets(dataset_configs):
repo_id=config_dataset.path,
repo_type="dataset",
filename=config_dataset.data_files,
+ revision=config_dataset.revision,
)
elif isinstance(config_dataset.data_files, list):
fp = []
@@ -376,6 +394,7 @@ def for_d_in_datasets(dataset_configs):
repo_id=config_dataset.path,
repo_type="dataset",
filename=file,
+ revision=config_dataset.revision,
)
)
else:
@@ -420,15 +439,19 @@ def for_d_in_datasets(dataset_configs):
config_dataset=config_dataset,
tokenizer=tokenizer,
cfg=cfg,
- dataset=ds,
d_base_type=d_base_type,
+ dataset=ds,
d_prompt_style=d_prompt_style,
+ processor=processor,
)
datasets.append(dataset_wrapper)
prompters.append(dataset_prompter)
- LOG.info("merging datasets")
- dataset = concatenate_datasets(datasets)
+ if len(datasets) == 1:
+ dataset = datasets[0]
+ else:
+ LOG.info("merging datasets")
+ dataset = concatenate_datasets(datasets)
if len(datasets) > 1:
if cfg.shuffle_merged_datasets:
@@ -437,9 +460,10 @@ def for_d_in_datasets(dataset_configs):
else:
LOG.debug("NOT shuffling merged datasets")
- dataset, _ = process_datasets_for_packing(cfg, dataset, None)
+ if cfg.sample_packing and not cfg.skip_prepare_dataset:
+ dataset, _ = process_datasets_for_packing(cfg, dataset, None)
- if cfg.local_rank == 0:
+ if cfg.local_rank == 0 and not cfg.skip_prepare_dataset:
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
dataset.save_to_disk(str(prepared_ds_path))
if cfg.push_dataset_to_hub:
@@ -478,9 +502,14 @@ def load_prepare_datasets(
cfg,
default_dataset_prepared_path,
split="train",
+ processor=None,
) -> Tuple[Dataset, Dataset, List[Prompter]]:
dataset, prompters = load_tokenized_prepared_datasets(
- tokenizer, cfg, default_dataset_prepared_path, split=split
+ tokenizer,
+ cfg,
+ default_dataset_prepared_path,
+ split=split,
+ processor=processor,
)
if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
@@ -546,6 +575,7 @@ def get_dataset_wrapper(
d_base_type,
dataset,
d_prompt_style=None,
+ processor=None, # pylint: disable=unused-argument
):
dataset_wrapper = None
dataset_prompter = None
@@ -578,13 +608,31 @@ def get_dataset_wrapper(
dataset,
**ds_kwargs,
)
- elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
+ elif cfg.skip_prepare_dataset:
+ dataset_wrapper = dataset
+ elif ds_strategy := config_dataset.type.startswith(
+ "bradley_terry"
+ ) and bradley_terry_load(
+ config_dataset.type.split(".", 1)[1], tokenizer, cfg, config_dataset
+ ):
dataset_prompter = UnsupportedPrompter()
dataset_wrapper = TokenizedPromptDataset(
ds_strategy,
dataset,
**ds_kwargs,
)
+ elif ds_strategy := load(
+ config_dataset.type, tokenizer, cfg, config_dataset, processor=processor
+ ):
+ if isinstance(ds_strategy, DatasetWrappingStrategy):
+ dataset_wrapper = ds_strategy.wrap_dataset(dataset, **ds_kwargs)
+ else:
+ dataset_prompter = UnsupportedPrompter()
+ dataset_wrapper = TokenizedPromptDataset(
+ ds_strategy,
+ dataset,
+ **ds_kwargs,
+ )
elif d_base_type == "alpaca":
dataset_prompter = AlpacaPrompter(d_prompt_style)
ds_strategy = AlpacaPromptTokenizingStrategy(
diff --git a/src/axolotl/utils/mlflow_.py b/src/axolotl/utils/mlflow_.py
index ce77390342..8710b07d06 100644
--- a/src/axolotl/utils/mlflow_.py
+++ b/src/axolotl/utils/mlflow_.py
@@ -16,3 +16,7 @@ def setup_mlflow_env_vars(cfg: DictDefault):
# Enable mlflow if experiment name is present
if cfg.mlflow_experiment_name and len(cfg.mlflow_experiment_name) > 0:
cfg.use_mlflow = True
+
+ # Enable logging hf artifacts in mlflow if value is truthy
+ if cfg.hf_mlflow_log_artifacts is True:
+ os.environ["HF_MLFLOW_LOG_ARTIFACTS"] = "true"
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 5ac66260a7..f3386cccfa 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -28,14 +28,22 @@
AddedToken,
AutoConfig,
AutoModelForCausalLM,
+ AutoModelForVision2Seq,
+ AutoProcessor,
AutoTokenizer,
AwqConfig,
BitsAndBytesConfig,
GPTQConfig,
+ LlavaForConditionalGeneration,
+ MllamaForConditionalGeneration,
PreTrainedModel,
PreTrainedTokenizerBase,
+ ProcessorMixin,
+)
+from transformers.integrations.deepspeed import (
+ HfTrainerDeepSpeedConfig,
+ is_deepspeed_zero3_enabled,
)
-from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from axolotl.common.architectures import MOE_ARCH_BLOCK
from axolotl.models.mamba import fix_mamba_attn_for_loss
@@ -45,7 +53,7 @@
)
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.utils.bench import log_gpu_memory_usage
-from axolotl.utils.chat_templates import chat_templates
+from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import zero_only
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
@@ -80,6 +88,9 @@ def get_module_class_from_name(module, name):
def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
+ if cfg.is_multimodal:
+ model_config = model_config.text_config
+
quant_config_exists = (
hasattr(model_config, "quantization_config")
and model_config.quantization_config
@@ -285,7 +296,10 @@ def load_tokenizer(cfg):
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
if cfg.chat_template:
- chat_template_string = chat_templates(cfg.chat_template)
+ chat_template_string = get_chat_template_from_config(
+ cfg=cfg,
+ tokenizer=tokenizer,
+ )
if cfg.default_system_message and cfg.chat_template == "chatml":
chat_template_string = chat_template_string.replace(
"You are a helpful assistant.", cfg.default_system_message
@@ -299,636 +313,879 @@ def load_tokenizer(cfg):
return tokenizer
-def load_model(
- cfg: DictDefault,
- tokenizer: PreTrainedTokenizerBase,
- inference: bool = False,
- reference_model: bool = False,
-) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
+def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
+ processor_kwargs: Dict[str, Any] = {} # do we actually need this?
+
+ processor_cls = AutoProcessor
+ if cfg.processor_type:
+ processor_cls = getattr(transformers, cfg.processor_type)
+
+ processor = processor_cls.from_pretrained(
+ cfg.processor_config,
+ trust_remote_code=cfg.trust_remote_code or False,
+ tokenizer=tokenizer,
+ **processor_kwargs,
+ )
+
+ return processor
+
+
+class ModelLoader:
"""
- Load a model for a given configuration and tokenizer.
+ ModelLoader: managing all the config and monkey patches while loading model
"""
- base_model = cfg.base_model
- model_type = cfg.type_of_model
- model_config = load_model_config(cfg)
- # TODO refactor as a kwarg
- load_in_8bit = cfg.load_in_8bit
+ def __init__(
+ self,
+ cfg: DictDefault,
+ tokenizer: PreTrainedTokenizerBase,
+ *,
+ processor: ProcessorMixin = None, # pylint: disable=unused-argument
+ inference: bool = False,
+ reference_model: bool = False,
+ **kwargs, # pylint: disable=unused-argument
+ ) -> None:
+ self.cfg = cfg
+ self.tokenizer = tokenizer
+ self.inference: bool = inference
+ self.reference_model: bool = reference_model
+
+ # init model kwargs
+ self.model_kwargs: Dict[str, Any] = {}
+ if cfg.model_kwargs:
+ for key, val in cfg.model_kwargs.items():
+ self.model_kwargs[key] = val
+
+ # init model
+ self.model: PreTrainedModel
+ self.base_model = cfg.base_model
+ self.model_type = cfg.type_of_model
+
+ # init model config
+ self.model_config = load_model_config(cfg)
+ if cfg.is_multimodal:
+ self.text_model_config = self.model_config.text_config
+ else:
+ self.text_model_config = self.model_config
- if cfg.gradient_checkpointing == "unsloth":
- transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
+ self.AutoModelLoader = AutoModelForCausalLM # pylint: disable=invalid-name
- if hasattr(model_config, "model_type") and model_config.model_type == "btlm":
- if cfg.flash_attention:
- from axolotl.monkeypatch.btlm_attn_hijack_flash import (
- replace_btlm_attn_with_flash_attn,
- )
+ def apply_patches(self) -> None:
+ # load any patches from plugins
+ from axolotl.integrations.base import PluginManager
- replace_btlm_attn_with_flash_attn(cfg.base_model)
+ plugin_manager = PluginManager.get_instance()
+ plugin_manager.pre_model_load(self.cfg)
- if (
- hasattr(model_config, "model_type")
- and model_config.model_type == "stablelm_epoch"
- ):
- if cfg.flash_attention and cfg.sample_packing:
- from axolotl.monkeypatch.stablelm_attn_hijack_flash import (
- replace_stablelm_attn_with_flash_attn,
+ if self.cfg.gradient_checkpointing == "unsloth":
+ transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
+
+ if self.cfg.flash_attention:
+ self.patch_attention()
+
+ if self.cfg.sample_packing and self.cfg.s2_attention:
+ raise ValueError(
+ "Received `sample_packing=true` and `s2_attention=true`; however, \
+ shifted-sparse attention does not currently support sample packing."
)
- replace_stablelm_attn_with_flash_attn(cfg.base_model)
+ if (
+ self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
+ and self.cfg.flash_attention
+ and self.cfg.sample_packing
+ ):
+ patch_for_multipack(
+ self.cfg.model_config_type,
+ model_name=self.cfg.base_model,
+ is_remote_code=self.cfg.trust_remote_code,
+ )
- if cfg.sample_packing and cfg.s2_attention:
- raise ValueError(
- "Received `sample_packing=true` and `s2_attention=true`; however, \
- shifted-sparse attention does not currently support sample packing."
- )
+ if self.cfg.is_llama_derived_model:
+ self.patch_loss()
+ if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
+ from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
- if (
- cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
- and cfg.flash_attention
- and cfg.sample_packing
- ):
- patch_for_multipack(
- cfg.model_config_type,
- model_name=cfg.base_model,
- is_remote_code=cfg.trust_remote_code,
- )
+ patch_self_attn_lora()
+ elif self.cfg.is_llama_derived_model:
+ self.patch_llama_derived_model()
- if cfg.is_llama_derived_model:
- from axolotl.monkeypatch.llama_attn_hijack_flash import (
- patch_llama_cross_entropy,
- patch_llama_rms_norm,
+ if (
+ self.cfg.model_config_type == "mistral"
+ and self.cfg.flash_attn_cross_entropy_loss
+ ):
+ from axolotl.monkeypatch.mistral_attn_hijack_flash import (
+ patch_mistral_cross_entropy,
)
- if cfg.flash_attn_cross_entropy:
- patch_llama_cross_entropy()
- if cfg.flash_attn_rms_norm:
- patch_llama_rms_norm()
- elif cfg.unsloth_rms_norm:
- from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm
-
- patch_unsloth_layernorm()
- if cfg.unsloth_cross_entropy_loss:
- from axolotl.monkeypatch.unsloth_ import (
- integrate_cross_entropy_loss_patch,
+ patch_mistral_cross_entropy()
+
+ def patch_attention(self) -> None:
+ if hasattr(self.model_config, "model_type"):
+ if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
+ from axolotl.monkeypatch.attention.mllama import patch_mllama
+
+ patch_mllama()
+
+ if self.model_config.model_type == "btlm":
+ from axolotl.monkeypatch.btlm_attn_hijack_flash import (
+ replace_btlm_attn_with_flash_attn,
)
- integrate_cross_entropy_loss_patch(model_type="llama")
- if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
- from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
+ replace_btlm_attn_with_flash_attn(self.cfg.base_model)
- patch_self_attn_lora()
- elif cfg.is_llama_derived_model:
- # Modify all llama derived models in one block
+ if (
+ self.model_config.model_type == "stablelm_epoch"
+ and self.cfg.sample_packing
+ ):
+ from axolotl.monkeypatch.stablelm_attn_hijack_flash import (
+ replace_stablelm_attn_with_flash_attn,
+ )
- if cfg.flash_attention:
+ replace_stablelm_attn_with_flash_attn(self.cfg.base_model)
+
+ def patch_loss(self) -> None:
+ """
+ Patch loss functions
+ """
+ from axolotl.monkeypatch.llama_attn_hijack_flash import (
+ patch_llama_cross_entropy,
+ patch_llama_rms_norm,
+ )
+
+ if self.cfg.flash_attn_cross_entropy:
+ patch_llama_cross_entropy()
+ if self.cfg.flash_attn_rms_norm:
+ patch_llama_rms_norm()
+ elif self.cfg.unsloth_rms_norm:
+ from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm
+
+ patch_unsloth_layernorm()
+ if self.cfg.unsloth_cross_entropy_loss:
+ from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
+
+ integrate_cross_entropy_loss_patch(model_type="llama")
+ if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
+ from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
+
+ patch_self_attn_lora()
+
+ def patch_llama_derived_model(self) -> None:
+ """
+ Modify all llama derived models in one block
+ """
+
+ if self.cfg.flash_attention:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
replace_llama_attn_with_flash_attn,
)
- if cfg.sample_packing:
- if cfg.device not in ["mps", "cpu"] and not inference:
+ if self.cfg.sample_packing:
+ if self.cfg.device not in ["mps", "cpu"] and not self.inference:
LOG.info("patching with flash attention for sample packing")
replace_llama_attn_with_flash_attn(
packed=True,
- cross_entropy=cfg.flash_attn_cross_entropy,
- rms_norm=cfg.flash_attn_rms_norm,
+ cross_entropy=self.cfg.flash_attn_cross_entropy,
+ rms_norm=self.cfg.flash_attn_rms_norm,
)
- elif cfg.s2_attention:
+ elif self.cfg.s2_attention:
LOG.info("patching w/ flash-enabled, shifted-sparse attention")
replace_llama_attn_with_flash_attn(
packed=False,
- cross_entropy=cfg.flash_attn_cross_entropy,
- rms_norm=cfg.flash_attn_rms_norm,
+ cross_entropy=self.cfg.flash_attn_cross_entropy,
+ rms_norm=self.cfg.flash_attn_rms_norm,
use_shifted_sparse_attn=True,
)
- elif cfg.flash_attn_cross_entropy or cfg.flash_attn_rms_norm:
+ elif self.cfg.flash_attn_cross_entropy or self.cfg.flash_attn_rms_norm:
replace_llama_attn_with_flash_attn(
packed=False,
- cross_entropy=cfg.flash_attn_cross_entropy,
- rms_norm=cfg.flash_attn_rms_norm,
+ cross_entropy=self.cfg.flash_attn_cross_entropy,
+ rms_norm=self.cfg.flash_attn_rms_norm,
)
- elif cfg.xformers_attention:
+ elif self.cfg.xformers_attention:
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
hijack_llama_attention,
)
LOG.info("patching with xformers attention")
hijack_llama_attention()
- elif cfg.sample_packing:
+ elif self.cfg.sample_packing:
from axolotl.monkeypatch.llama_patch_multipack import (
hijack_llama_prepare_4d_mask,
)
LOG.info("patching llama _prepare_4d_causal_attention_mask*")
hijack_llama_prepare_4d_mask()
- elif cfg.s2_attention:
+ elif self.cfg.s2_attention:
raise NotImplementedError(
"Shifted-sparse attention not currently implemented without flash attention."
)
- if cfg.unsloth_cross_entropy_loss:
+ if self.cfg.unsloth_cross_entropy_loss:
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
integrate_cross_entropy_loss_patch(model_type="llama")
- if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
+ if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
patch_self_attn_lora()
- # Modify mistral derived models
- if cfg.model_config_type == "mistral" and cfg.flash_attn_cross_entropy_loss:
- from axolotl.monkeypatch.mistral_attn_hijack_flash import (
- patch_mistral_cross_entropy,
- )
-
- patch_mistral_cross_entropy()
-
- model_kwargs: Dict[str, Any] = {}
-
- if cfg.model_kwargs:
- for key, val in cfg.model_kwargs.items():
- model_kwargs[key] = val
-
- max_memory = cfg.max_memory
- device_map = cfg.device_map
-
- if cfg.gpu_memory_limit:
- gpu_memory_limit = (
- str(cfg.gpu_memory_limit) + "GiB"
- if isinstance(cfg.gpu_memory_limit, int)
- else cfg.gpu_memory_limit
- )
-
- max_memory = {}
- for i in range(torch.cuda.device_count()):
- max_memory[i] = gpu_memory_limit
- max_memory["cpu"] = "256GiB" # something sufficiently large to fit anything
+ def set_auto_model_loader(self) -> None:
+ """set self.AutoModelLoader
+ - default value: AutoModelForCausalLM (set at __init__)
+ - when using a multi modality model, self.AutoModelLoader should
+ be set according to model type of the model
+ """
+ if self.cfg.is_multimodal:
+ if self.model_config.model_type == "llava":
+ self.AutoModelLoader = ( # pylint: disable=invalid-name
+ LlavaForConditionalGeneration
+ )
+ elif self.model_config.model_type == "mllama":
+ self.AutoModelLoader = ( # pylint: disable=invalid-name
+ MllamaForConditionalGeneration
+ )
+ else:
+ self.AutoModelLoader = (
+ AutoModelForVision2Seq # pylint: disable=invalid-name
+ )
- if max_memory is not None:
- # Based on https://github.com/togethercomputer/OpenChatKit/blob/main/inference/bot.py
- from accelerate import infer_auto_device_map
+ def set_device_map_config(self) -> None:
+ device_map = self.cfg.device_map
+ max_memory = self.cfg.max_memory
- with init_empty_weights():
- model_canvas = AutoModelForCausalLM.from_config(
- model_config, trust_remote_code=cfg.trust_remote_code or False
+ if self.cfg.gpu_memory_limit:
+ gpu_memory_limit = (
+ str(self.cfg.gpu_memory_limit) + "GiB"
+ if isinstance(self.cfg.gpu_memory_limit, int)
+ else self.cfg.gpu_memory_limit
)
- model_canvas.tie_weights()
- device_map = infer_auto_device_map(
- model_canvas,
- max_memory=max_memory,
- dtype=cfg.torch_dtype,
- )
- # We can discard max_memory now as we have a device map set up for us
- max_memory = None
-
- model_kwargs["device_map"] = device_map
- model_kwargs["torch_dtype"] = cfg.torch_dtype
- if torch.backends.mps.is_available():
- model_kwargs["device_map"] = "mps:0"
+ max_memory = {}
+ for i in range(torch.cuda.device_count()):
+ max_memory[i] = gpu_memory_limit
+ max_memory["cpu"] = "256GiB" # something sufficiently large to fit anything
- # TODO can we put the reference model on it's own gpu? I think we have to move logits around to calculate loss
- # if cfg.rl:
- # if torch.cuda.device_count() > 1:
- # if reference_model:
- # model_kwargs["device_map"] = "cuda:" + str(
- # torch.cuda.current_device() + 1
- # )
- # else:
- # model_kwargs["device_map"] = "cuda:" + str(torch.cuda.current_device())
+ if max_memory is not None:
+ # Based on https://github.com/togethercomputer/OpenChatKit/blob/main/inference/bot.py
+ from accelerate import infer_auto_device_map
- if is_deepspeed_zero3_enabled():
- del model_kwargs["device_map"]
+ with init_empty_weights():
+ model_canvas = self.AutoModelLoader.from_config(
+ self.model_config,
+ trust_remote_code=self.cfg.trust_remote_code or False,
+ )
+ model_canvas.tie_weights()
+ device_map = infer_auto_device_map(
+ model_canvas,
+ max_memory=max_memory,
+ dtype=self.cfg.torch_dtype,
+ )
+ # We can discard max_memory now as we have a device map set up for us
+ max_memory = None
+
+ self.model_kwargs["device_map"] = device_map
+ self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype
+
+ if torch.backends.mps.is_available():
+ self.model_kwargs["device_map"] = "mps:0"
+
+ # TODO can we put the reference model on it's own gpu? I think we have to move logits around to calculate loss
+ # if cfg.rl:
+ # if torch.cuda.device_count() > 1:
+ # if reference_model:
+ # model_kwargs["device_map"] = "cuda:" + str(
+ # torch.cuda.current_device() + 1
+ # )
+ # else:
+ # model_kwargs["device_map"] = "cuda:" + str(torch.cuda.current_device())
+
+ if is_deepspeed_zero3_enabled():
+ del self.model_kwargs["device_map"]
+
+ def set_quantization_config(self) -> None:
+ self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
+ self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
+
+ if self.cfg.gptq:
+ if not hasattr(self.model_config, "quantization_config"):
+ LOG.warning(
+ "model config does not contain quantization_config information"
+ )
+ else:
+ if self.cfg.gptq_disable_exllama is not None:
+ self.model_config.quantization_config[
+ "disable_exllama"
+ ] = self.cfg.gptq_disable_exllama
+ self.model_kwargs["quantization_config"] = GPTQConfig(
+ **self.model_config.quantization_config
+ )
+ if (
+ self.cfg.adapter in ["qlora", "lora"]
+ and hasattr(self.model_config, "quantization_config")
+ and self.model_config.quantization_config["quant_method"]
+ in ["gptq", "awq", "bitsandbytes"]
+ ):
+ if self.model_config.quantization_config["quant_method"] == "gptq":
+ self.model_kwargs["quantization_config"] = GPTQConfig(
+ **self.model_config.quantization_config
+ )
+ elif self.model_config.quantization_config["quant_method"] == "awq":
+ self.model_kwargs["quantization_config"] = AwqConfig(
+ **self.model_config.quantization_config
+ )
+ elif (
+ self.model_config.quantization_config["quant_method"] == "bitsandbytes"
+ ):
+ self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
+ **self.model_config.quantization_config
+ )
+ elif self.cfg.adapter == "qlora" and (
+ "load_in_4bit" in self.model_kwargs and self.model_kwargs["load_in_4bit"]
+ ):
+ bnb_config = {
+ "load_in_4bit": True,
+ "llm_int8_threshold": 6.0,
+ "llm_int8_has_fp16_weight": False,
+ "bnb_4bit_compute_dtype": self.cfg.torch_dtype,
+ "bnb_4bit_use_double_quant": True,
+ "bnb_4bit_quant_type": "nf4",
+ "bnb_4bit_quant_storage": torch.bfloat16,
+ }
+ if self.cfg.model_config_type in ["jamba", "qwen2_moe"] and not (
+ self.cfg.deepspeed or self.cfg.fsdp
+ ):
+ # for some reason, this causes the loss to be off by an order of magnitude
+ # but deepspeed needs this still in bfloat16
+ bnb_config["bnb_4bit_quant_storage"] = torch.float32
- if cfg.revision_of_model:
- model_kwargs["revision"] = cfg.revision_of_model
+ if self.cfg.bnb_config_kwargs:
+ bnb_config.update(self.cfg.bnb_config_kwargs)
- if cfg.gptq:
- if not hasattr(model_config, "quantization_config"):
- LOG.warning("model config does not contain quantization_config information")
- else:
- if cfg.gptq_disable_exllama is not None:
- model_config.quantization_config[
- "disable_exllama"
- ] = cfg.gptq_disable_exllama
- model_kwargs["quantization_config"] = GPTQConfig(
- **model_config.quantization_config
- )
- if (
- cfg.adapter in ["qlora", "lora"]
- and hasattr(model_config, "quantization_config")
- and model_config.quantization_config["quant_method"]
- in ["gptq", "awq", "bitsandbytes"]
- ):
- if model_config.quantization_config["quant_method"] == "gptq":
- model_kwargs["quantization_config"] = GPTQConfig(
- **model_config.quantization_config
+ self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
+ **bnb_config,
)
- elif model_config.quantization_config["quant_method"] == "awq":
- model_kwargs["quantization_config"] = AwqConfig(
- **model_config.quantization_config
- )
- elif model_config.quantization_config["quant_method"] == "bitsandbytes":
- model_kwargs["quantization_config"] = BitsAndBytesConfig(
- **model_config.quantization_config
+ elif self.cfg.adapter == "lora" and (
+ "load_in_8bit" in self.model_kwargs and self.model_kwargs["load_in_8bit"]
+ ):
+ bnb_config = {
+ "load_in_8bit": True,
+ }
+ # Exclude mamba blocks from int8 quantization for jamba
+ if self.cfg.model_config_type == "jamba":
+ bnb_config["llm_int8_skip_modules"] = ["mamba"]
+ self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
+ **bnb_config,
)
- elif cfg.adapter == "qlora" and cfg.load_in_4bit:
- bnb_config = {
- "load_in_4bit": True,
- "llm_int8_threshold": 6.0,
- "llm_int8_has_fp16_weight": False,
- "bnb_4bit_compute_dtype": cfg.torch_dtype,
- "bnb_4bit_use_double_quant": True,
- "bnb_4bit_quant_type": "nf4",
- "bnb_4bit_quant_storage": torch.bfloat16,
- }
- if cfg.model_config_type in ["jamba", "qwen2_moe"] and not cfg.deepspeed:
- # for some reason, this causes the loss to be off by an order of magnitude
- # but deepspeed needs this still in bfloat16
- bnb_config["bnb_4bit_quant_storage"] = torch.float32
-
- if cfg.bnb_config_kwargs:
- bnb_config.update(cfg.bnb_config_kwargs)
-
- model_kwargs["quantization_config"] = BitsAndBytesConfig(
- **bnb_config,
- )
- elif cfg.adapter == "lora" and cfg.load_in_8bit:
- bnb_config = {
- "load_in_8bit": True,
- }
- # Exclude mamba blocks from int8 quantization for jamba
- if cfg.model_config_type == "jamba":
- bnb_config["llm_int8_skip_modules"] = ["mamba"]
- model_kwargs["quantization_config"] = BitsAndBytesConfig(
- **bnb_config,
- )
- if cfg.load_in_8bit and cfg.adapter is not None:
- model_kwargs["load_in_8bit"] = True
- if cfg.load_in_4bit and cfg.adapter is not None:
- model_kwargs["load_in_4bit"] = True
-
- # no longer needed per https://github.com/huggingface/transformers/pull/26610
- if "quantization_config" in model_kwargs or cfg.gptq:
- if "load_in_8bit" in model_kwargs:
- del model_kwargs["load_in_8bit"]
- if "load_in_4bit" in model_kwargs:
- del model_kwargs["load_in_4bit"]
-
- # sample packing uses custom FA2 patch
- if cfg.flash_attention:
- if not cfg.sample_packing:
- if cfg.s2_attention:
+ # no longer needed per https://github.com/huggingface/transformers/pull/26610
+ if "quantization_config" in self.model_kwargs or self.cfg.gptq:
+ if "load_in_8bit" in self.model_kwargs:
+ del self.model_kwargs["load_in_8bit"]
+ if "load_in_4bit" in self.model_kwargs:
+ del self.model_kwargs["load_in_4bit"]
+
+ def set_attention_config(self) -> None:
+ """
+ sample packing uses custom FA2 patch
+ """
+ if self.cfg.flash_attention:
+ if not self.cfg.sample_packing and self.cfg.s2_attention:
pass
- # most other models support flash attention, we can define exceptions as they come up
- model_kwargs["attn_implementation"] = "flash_attention_2"
- model_config._attn_implementation = ( # pylint: disable=protected-access
+ self.model_kwargs["attn_implementation"] = "flash_attention_2"
+ self.model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"
)
- else:
- if model_config.model_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
- model_kwargs["attn_implementation"] = "flash_attention_2"
- model_config._attn_implementation = ( # pylint: disable=protected-access
- "flash_attention_2"
+ elif self.cfg.sdp_attention:
+ self.model_kwargs["attn_implementation"] = "sdpa"
+ self.model_config._attn_implementation = ( # pylint: disable=protected-access
+ "sdpa"
+ )
+ elif self.cfg.eager_attention:
+ self.model_kwargs["attn_implementation"] = "eager"
+ self.model_config._attn_implementation = ( # pylint: disable=protected-access
+ "eager"
+ )
+
+ if self.cfg.low_cpu_mem_usage:
+ self.model_kwargs["low_cpu_mem_usage"] = True
+
+ def build_model(self, qlora_fsdp) -> bool:
+ def _configure_zero3_memory_efficient_loading():
+ """
+ Set the deepspeed config to load the model into RAM first before moving to VRAM.
+
+ We need to return hf_ds_cfg as it needs to exist before model loading.
+ """
+ hf_ds_cfg = None
+
+ if os.getenv("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3":
+ hf_ds_cfg = HfTrainerDeepSpeedConfig(self.cfg.deepspeed)
+ hf_ds_cfg.fill_match(
+ "train_micro_batch_size_per_gpu", self.cfg.micro_batch_size
)
- else:
- model_kwargs["attn_implementation"] = "eager"
- model_config._attn_implementation = ( # pylint: disable=protected-access
- "eager"
+ hf_ds_cfg.fill_match(
+ "gradient_accumulation_steps", self.cfg.gradient_accumulation_steps
+ )
+ hf_ds_cfg.fill_match(
+ "train_batch_size",
+ int(os.getenv("WORLD_SIZE", "1"))
+ * self.cfg.micro_batch_size
+ * self.cfg.gradient_accumulation_steps,
)
- elif cfg.sdp_attention:
- model_kwargs["attn_implementation"] = "sdpa"
- model_config._attn_implementation = "sdpa" # pylint: disable=protected-access
- elif cfg.eager_attention:
- model_kwargs["attn_implementation"] = "eager"
- model_config._attn_implementation = "eager" # pylint: disable=protected-access
+ if "device_map" in self.model_kwargs:
+ del self.model_kwargs["device_map"]
- if cfg.low_cpu_mem_usage:
- model_kwargs["low_cpu_mem_usage"] = True
+ transformers.modeling_utils.is_deepspeed_zero3_enabled = lambda: True
+ transformers.integrations.deepspeed.is_deepspeed_zero3_enabled = (
+ lambda: True
+ )
- qlora_fsdp = cfg.fsdp and cfg.adapter == "qlora"
+ return hf_ds_cfg
- try:
skip_move_to_device = False
if ( # pylint: disable=condition-evals-to-constant)
- (cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading)
+ (self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading)
and not qlora_fsdp
and False
):
- model = load_sharded_model(
- base_model,
- model_config,
- cfg,
- torch_dtype=cfg.torch_dtype,
+ self.model = load_sharded_model(
+ self.base_model,
+ self.model_config,
+ self.cfg,
+ torch_dtype=self.cfg.torch_dtype,
)
skip_move_to_device = True
elif (
qlora_fsdp
- and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
- and (cfg.model_config_type == "dbrx" or cfg.qlora_sharded_model_loading)
+ and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
+ and (
+ self.cfg.model_config_type == "dbrx"
+ or self.cfg.qlora_sharded_model_loading
+ )
):
- quant_storage = cfg.torch_dtype
+ quant_storage = self.cfg.torch_dtype
quantization_config = hasattr(
- model_config, "quantization_config"
- ) and getattr(model_config, "quantization_config")
+ self.model_config, "quantization_config"
+ ) and getattr(self.model_config, "quantization_config")
quantization_config = (
- quantization_config or model_kwargs["quantization_config"]
+ quantization_config or self.model_kwargs["quantization_config"]
)
- model = load_sharded_model_quant(
- base_model,
- model_config,
- cfg,
+ if self.cfg.is_multimodal:
+ self.model_config.text_config = self.text_model_config
+ self.model = load_sharded_model_quant(
+ self.base_model,
+ self.model_config,
+ self.cfg,
quant_storage=quant_storage,
quantization_config=quantization_config,
)
skip_move_to_device = True
elif (
- model_config.model_type == "llama"
- and not cfg.trust_remote_code
- and not cfg.gptq
+ self.model_config.model_type == "llama"
+ and not self.cfg.trust_remote_code
+ and not self.cfg.gptq
):
- if cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
+ if self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
skip_move_to_device = True
- if "device_map" in model_kwargs:
- del model_kwargs["device_map"]
+ if "device_map" in self.model_kwargs:
+ del self.model_kwargs["device_map"]
+
+ _ = _configure_zero3_memory_efficient_loading()
- model = AutoModelForCausalLM.from_pretrained(
- base_model,
- config=model_config,
- **model_kwargs,
+ if self.cfg.is_multimodal:
+ self.model_config.text_config = self.text_model_config
+ self.model = self.AutoModelLoader.from_pretrained(
+ self.base_model,
+ config=self.model_config,
+ **self.model_kwargs,
)
- if cfg.flash_attention and not inference:
+ # TODO (MengqingCao) split these patches seperately
+ if self.cfg.flash_attention and not self.inference:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
is_xformers_swiglu_available,
replace_llama_mlp_with_swiglu,
replace_llama_qkv_with_fused,
)
- if cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
+ if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
LOG.info("patching with SwiGLU")
- replace_llama_mlp_with_swiglu(model)
+ replace_llama_mlp_with_swiglu(self.model)
- if cfg.flash_attn_fuse_qkv:
+ if self.cfg.flash_attn_fuse_qkv:
LOG.info("patching with fused QKV")
- replace_llama_qkv_with_fused(model)
- elif model_type == "MambaLMHeadModel":
+ replace_llama_qkv_with_fused(self.model)
+ elif self.model_type == "MambaLMHeadModel":
# FIXME this is janky at best and hacked together to make it work
MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name
- model_kwargs["dtype"] = model_kwargs["torch_dtype"]
- model_kwargs["device"] = torch.cuda.current_device()
- del model_kwargs["torch_dtype"]
- del model_kwargs["device_map"]
+ self.model_kwargs["dtype"] = self.model_kwargs["torch_dtype"]
+ self.model_kwargs["device"] = torch.cuda.current_device()
+ del self.model_kwargs["torch_dtype"]
+ del self.model_kwargs["device_map"]
- model = MambaLMHeadModel.from_pretrained(
- base_model,
- **model_kwargs,
+ self.model = MambaLMHeadModel.from_pretrained(
+ self.base_model,
+ **self.model_kwargs,
)
elif (
- model_type
- and model_type != "AutoModelForCausalLM"
- and not cfg.trust_remote_code
+ self.model_type
+ and self.model_type != "AutoModelForCausalLM"
+ and not self.cfg.trust_remote_code
):
- if cfg.gptq:
- model = AutoModelForCausalLM.from_pretrained(
- base_model,
- config=model_config,
- trust_remote_code=cfg.trust_remote_code or False,
- **model_kwargs,
+ if self.cfg.is_multimodal:
+ self.model_config.text_config = self.text_model_config
+ if self.cfg.gptq:
+ self.model = self.AutoModelLoader.from_pretrained(
+ self.base_model,
+ config=self.model_config,
+ trust_remote_code=self.cfg.trust_remote_code or False,
+ **self.model_kwargs,
)
else:
- model = getattr(transformers, model_type).from_pretrained(
- base_model,
- config=model_config,
- trust_remote_code=cfg.trust_remote_code or False,
- **model_kwargs,
+ self.model = getattr(transformers, self.model_type).from_pretrained(
+ self.base_model,
+ config=self.model_config,
+ trust_remote_code=self.cfg.trust_remote_code or False,
+ **self.model_kwargs,
)
else:
# Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
# when training starts
if (
- hasattr(model_config, "max_seq_len")
- and model_config.max_seq_len
- and cfg.sequence_len > model_config.max_seq_len
+ hasattr(self.text_model_config, "max_seq_len")
+ and self.text_model_config.max_seq_len
+ and self.cfg.sequence_len > self.text_model_config.max_seq_len
):
- model_config.max_seq_len = cfg.sequence_len
- LOG.warning(f"increasing context length to {cfg.sequence_len}")
+ self.text_model_config.max_seq_len = self.cfg.sequence_len
+ LOG.warning(f"increasing context length to {self.cfg.sequence_len}")
elif (
- hasattr(model_config, "max_sequence_length")
- and model_config.max_sequence_length
- and cfg.sequence_len > model_config.max_sequence_length
+ hasattr(self.text_model_config, "max_sequence_length")
+ and self.text_model_config.max_sequence_length
+ and self.cfg.sequence_len > self.text_model_config.max_sequence_length
):
- model_config.max_sequence_length = cfg.sequence_len
- LOG.warning(f"increasing context length to {cfg.sequence_len}")
- if cfg.gptq:
- model = AutoModelForCausalLM.from_pretrained(
- base_model,
- config=model_config,
- trust_remote_code=cfg.trust_remote_code or False,
- **model_kwargs,
+ self.text_model_config.max_sequence_length = self.cfg.sequence_len
+ LOG.warning(f"increasing context length to {self.cfg.sequence_len}")
+ if self.cfg.gptq:
+ if self.cfg.is_multimodal:
+ self.model_config.text_config = self.text_model_config
+ self.model = self.AutoModelLoader.from_pretrained(
+ self.base_model,
+ config=self.model_config,
+ trust_remote_code=self.cfg.trust_remote_code or False,
+ **self.model_kwargs,
)
else:
- if cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
+ if (
+ self.cfg.fsdp
+ and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
+ ):
# disabling either of these two still leads to VRAM spike before setting back down
skip_move_to_device = True
- if "device_map" in model_kwargs:
- del model_kwargs["device_map"]
-
- model = AutoModelForCausalLM.from_pretrained(
- base_model,
- config=model_config,
- trust_remote_code=cfg.trust_remote_code or False,
- **model_kwargs,
+ if "device_map" in self.model_kwargs:
+ del self.model_kwargs["device_map"]
+
+ _ = _configure_zero3_memory_efficient_loading()
+
+ if self.cfg.is_multimodal:
+ self.model_config.text_config = self.text_model_config
+ self.model = self.AutoModelLoader.from_pretrained(
+ self.base_model,
+ config=self.model_config,
+ trust_remote_code=self.cfg.trust_remote_code or False,
+ **self.model_kwargs,
)
- except Exception as err: # pylint: disable=broad-exception-caught
- LOG.exception(err)
- raise err
-
- if isinstance(model, (PeftModel, PeftModelForCausalLM)) and not qlora_fsdp:
- model = model.merge_and_unload()
-
- embeddings_len = (
- math.ceil(len(tokenizer) / 32) * 32
- if cfg.resize_token_embeddings_to_32x
- else len(tokenizer)
- )
- if (
- hasattr(model, "get_input_embeddings")
- and model.get_input_embeddings().num_embeddings < embeddings_len
- ):
- model.resize_token_embeddings(embeddings_len)
- else:
- model.tie_weights()
+ if is_deepspeed_zero3_enabled():
+ skip_move_to_device = True
- if (
- hasattr(model, "config")
- and hasattr(model.config, "max_position_embeddings")
- and model.config.max_position_embeddings
- and cfg.sequence_len > model.config.max_position_embeddings
- ):
- LOG.warning(
- f"increasing model.config.max_position_embeddings from {model.config.max_position_embeddings} to {cfg.sequence_len}"
- )
- model.config.max_position_embeddings = cfg.sequence_len
+ return skip_move_to_device
- if (
- hasattr(model, "config")
- and hasattr(model.config, "bos_token_id")
- and model.config.bos_token_id
- and model.config.bos_token_id != tokenizer.bos_token_id
- ):
- model.config.bos_token_id = tokenizer.bos_token_id
+ def ajust_model_config(self) -> None:
+ if (
+ hasattr(self.model, "config")
+ and hasattr(self.model.config, "max_position_embeddings")
+ and self.model.config.max_position_embeddings
+ and self.cfg.sequence_len > self.model.config.max_position_embeddings
+ ):
+ LOG.warning(
+ f"increasing model.config.max_position_embeddings from {self.model.config.max_position_embeddings} to {self.cfg.sequence_len}"
+ )
+ self.model.config.max_position_embeddings = self.cfg.sequence_len
- if (
- hasattr(model, "config")
- and hasattr(model.config, "eos_token_id")
- and model.config.eos_token_id
- and model.config.eos_token_id != tokenizer.eos_token_id
- ):
- model.config.eos_token_id = tokenizer.eos_token_id
-
- if hasattr(model, "device") and model.device.type in ("cuda", "mps"):
- log_gpu_memory_usage(LOG, "after model load", model.device)
-
- # make sure these are fp32 per Ramesh et al. (2021)
- embedding_modules = get_linear_embedding_layers(cfg.model_config_type)
- if not cfg.fsdp:
- # FSDP doesn't like mixed Float and BFloat16
- for name, module in model.named_modules():
- if "norm" in name or name.endswith(".gate"):
- module.to(torch.float32)
- if model_config.model_type == "btlm":
- # don't upcast lm_head for btlm
- continue
- if any(m in name for m in embedding_modules):
- if hasattr(module, "weight"):
- module.to(torch.float32)
+ if (
+ hasattr(self.model, "config")
+ and hasattr(self.model.config, "bos_token_id")
+ and self.model.config.bos_token_id
+ and self.model.config.bos_token_id != self.tokenizer.bos_token_id
+ ):
+ self.model.config.bos_token_id = self.tokenizer.bos_token_id
- needs_fa2_dtype = cfg.adapter or cfg.fsdp
- skip_prepare_model_for_kbit_training = False
+ if (
+ hasattr(self.model, "config")
+ and hasattr(self.model.config, "eos_token_id")
+ and self.model.config.eos_token_id
+ and self.model.config.eos_token_id != self.tokenizer.eos_token_id
+ ):
+ self.model.config.eos_token_id = self.tokenizer.eos_token_id
- if is_deepspeed_zero3_enabled():
+ def set_z3_leaf_modules(self) -> None:
from deepspeed.utils import ( # pylint: disable=no-name-in-module
set_z3_leaf_modules,
)
- if cfg.model_config_type in MOE_ARCH_BLOCK:
- moe_blocks = MOE_ARCH_BLOCK[cfg.model_config_type]
+ if self.cfg.model_config_type in MOE_ARCH_BLOCK:
+ moe_blocks = MOE_ARCH_BLOCK[self.cfg.model_config_type]
moe_blocks = [moe_blocks] if isinstance(moe_blocks, str) else moe_blocks
set_z3_leaf_modules(
- model,
+ self.model,
[
- get_module_class_from_name(model, module_name)
+ get_module_class_from_name(self.model, module_name)
for module_name in moe_blocks
],
)
- if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
- # Qwen doesn't play nicely with LoRA if this is enabled
- skip_prepare_model_for_kbit_training = True
+ def prepare_model(self, qlora_fsdp) -> None:
+ skip_prepare_model_for_kbit_training = False
+ if self.cfg.model_config_type == "qwen" and self.cfg.adapter == "lora":
+ # Qwen doesn't play nicely with LoRA if this is enabled
+ skip_prepare_model_for_kbit_training = True
- loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
- if cfg.adapter == "lora" and loftq_bits:
- skip_prepare_model_for_kbit_training = True
+ loftq_bits = (
+ self.cfg.peft
+ and self.cfg.peft.loftq_config
+ and self.cfg.peft.loftq_config.loftq_bits
+ )
+ if self.cfg.adapter == "lora" and loftq_bits:
+ skip_prepare_model_for_kbit_training = True
- if qlora_fsdp or (cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading):
- # make sure everything is in the same dtype
- skip_prepare_model_for_kbit_training = True
+ if qlora_fsdp or (
+ self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
+ ):
+ # make sure everything is in the same dtype
+ skip_prepare_model_for_kbit_training = True
- if is_deepspeed_zero3_enabled():
- skip_prepare_model_for_kbit_training = True
+ if is_deepspeed_zero3_enabled():
+ skip_prepare_model_for_kbit_training = True
+
+ is_load_in_8bit = (
+ "load_in_8bit" in self.model_kwargs and self.model_kwargs["load_in_8bit"]
+ )
+ is_load_in_4bit = (
+ "load_in_4bit" in self.model_kwargs and self.model_kwargs["load_in_4bit"]
+ )
- if cfg.adapter in ["lora", "qlora"]:
- if cfg.gradient_checkpointing:
- model.gradient_checkpointing_enable(
- gradient_checkpointing_kwargs=cfg.gradient_checkpointing_kwargs
- )
if (
- cfg.load_in_8bit or cfg.load_in_4bit
- ) and not skip_prepare_model_for_kbit_training:
+ not skip_prepare_model_for_kbit_training
+ and self.cfg.adapter in ["lora", "qlora"]
+ and (is_load_in_8bit or is_load_in_4bit)
+ ):
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
- model = prepare_model_for_kbit_training(
- model, use_gradient_checkpointing=cfg.gradient_checkpointing
+ self.model = prepare_model_for_kbit_training(
+ self.model, use_gradient_checkpointing=self.cfg.gradient_checkpointing
)
- needs_fa2_dtype = True
- # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
- # convert them back to fp16/bf16 for flash-attn compatibility.
- if (needs_fa2_dtype or cfg.flash_attention) and not qlora_fsdp:
- LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
- for name, module in model.named_modules():
+ def convert_embedding_modules_dtype(
+ self, embedding_modules, dist_dtype, before_kbit_train_or_finetune
+ ) -> None:
+ for name, module in self.model.named_modules():
if "norm" in name:
- module.to(cfg.torch_dtype)
+ module.to(dist_dtype)
+ if before_kbit_train_or_finetune:
+ if name.endswith(".gate"):
+ module.to(dist_dtype)
+ if self.model_config.model_type == "btlm":
+ # don't upcast lm_head for btlm
+ continue
if any(m in name for m in embedding_modules):
if hasattr(module, "weight"):
- module.to(cfg.torch_dtype)
-
- lora_config = None
- if not reference_model or cfg.lora_model_dir:
- # if we're not loading the reference model, then we're loading the model for training
- # then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
- if cfg.adapter and cfg.rl in ["dpo", "ipo", "kto"] and not cfg.merge_lora:
- _, lora_config = load_lora(model, cfg, inference=False, config_only=True)
+ module.to(dist_dtype)
+
+ def apply_lora_patch(self) -> None:
+ if self.cfg.unsloth_lora_mlp:
+ from axolotl.monkeypatch.unsloth_ import integrate_lora_mlp_patch
+
+ integrate_lora_mlp_patch(self.model)
+ if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
+ from axolotl.monkeypatch.unsloth_ import integrate_lora_patch
+
+ integrate_lora_patch(self.model, self.cfg)
+ if self.cfg.unsloth_rope:
+ from axolotl.monkeypatch.unsloth_ import integrate_rope_embeddings
+
+ integrate_rope_embeddings()
+
+ def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
+ self.apply_patches()
+ self.set_auto_model_loader()
+ self.set_device_map_config()
+ if self.cfg.revision_of_model:
+ self.model_kwargs["revision"] = self.cfg.revision_of_model
+ self.set_quantization_config()
+ self.set_attention_config()
+
+ qlora_fsdp = self.cfg.fsdp and self.cfg.adapter == "qlora"
+ skip_move_to_device = False
+
+ try:
+ skip_move_to_device = self.build_model(qlora_fsdp)
+ except Exception as err: # pylint: disable=broad-exception-caught
+ LOG.exception(err)
+ raise err
+
+ if isinstance(self.model, (PeftModel, PeftModelForCausalLM)) and not qlora_fsdp:
+ self.model = self.model.merge_and_unload()
+
+ embeddings_len = (
+ math.ceil(len(self.tokenizer) / 32) * 32
+ if self.cfg.resize_token_embeddings_to_32x
+ else len(self.tokenizer)
+ )
+ if (
+ hasattr(self.model, "get_input_embeddings")
+ and self.model.get_input_embeddings().num_embeddings < embeddings_len
+ ):
+ resize_kwargs = {}
+ if self.cfg.mean_resizing_embeddings is not None:
+ resize_kwargs["mean_resizing"] = self.cfg.mean_resizing_embeddings
+ self.model.resize_token_embeddings(embeddings_len, **resize_kwargs)
else:
- model, lora_config = load_adapter(model, cfg, cfg.adapter)
+ self.model.tie_weights()
+
+ self.ajust_model_config()
+
+ # log device memory usage
+ if hasattr(self.model, "device") and self.model.device.type in ("cuda", "mps"):
+ log_gpu_memory_usage(LOG, "after model load", self.model.device)
+
+ # make sure these are fp32 per Ramesh et al. (2021)
+ embedding_modules = get_linear_embedding_layers(self.cfg.model_config_type)
+ if not self.cfg.fsdp:
+ # FSDP doesn't like mixed Float and BFloat16
+ self.convert_embedding_modules_dtype(
+ embedding_modules,
+ dist_dtype=torch.float32,
+ before_kbit_train_or_finetune=True,
+ )
- if is_deepspeed_zero3_enabled():
- skip_move_to_device = True
+ if is_deepspeed_zero3_enabled():
+ self.set_z3_leaf_modules()
- if (
- cfg.ddp
- and not load_in_8bit
- and not (cfg.rl and cfg.load_in_4bit)
- and not skip_move_to_device
- ):
- # TODO revaldate this conditional
- model.to(f"cuda:{cfg.local_rank}")
+ needs_fa2_dtype = self.cfg.adapter or self.cfg.fsdp
+ if self.cfg.adapter in ["lora", "qlora"]:
+ needs_fa2_dtype = True
+ if self.cfg.gradient_checkpointing:
+ self.model.gradient_checkpointing_enable(
+ gradient_checkpointing_kwargs=self.cfg.gradient_checkpointing_kwargs
+ )
- if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
- setattr(model, "is_parallelizable", True)
- setattr(model, "model_parallel", True)
+ self.prepare_model(qlora_fsdp)
- requires_grad = []
- for name, param in model.named_parameters(recurse=True):
- if param.requires_grad:
- requires_grad.append(f"{name}: {param.requires_grad}")
- if len(requires_grad) == 0:
- LOG.warning("there are no parameters that require gradient updates")
- if hasattr(model, "config"):
- model.config.use_cache = False
+ # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
+ # convert them back to fp16/bf16 for flash-attn compatibility.
+ if (needs_fa2_dtype or self.cfg.flash_attention) and not qlora_fsdp:
+ LOG.info(
+ "converting modules to %s for flash attention", self.cfg.torch_dtype
+ )
+ self.convert_embedding_modules_dtype(
+ embedding_modules,
+ dist_dtype=self.cfg.torch_dtype,
+ before_kbit_train_or_finetune=False,
+ )
- if cfg.flash_optimum:
- from optimum.bettertransformer import BetterTransformer
+ # ---------------------------------------------------------
+ # load lora or adapter
+ # ---------------------------------------------------------
+ lora_config = None
+ if not self.reference_model or self.cfg.lora_model_dir:
+ # if we're not loading the reference model, then we're loading the model for training
+ # then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
+ if (
+ self.cfg.adapter
+ and self.cfg.rl in ["dpo", "ipo", "kto"]
+ and not self.cfg.merge_lora
+ ):
+ _, lora_config = load_lora(
+ self.model, self.cfg, inference=False, config_only=True
+ )
+ else:
+ self.model, lora_config = load_adapter(
+ self.model, self.cfg, self.cfg.adapter
+ )
- model = BetterTransformer.transform(model)
+ # ---------------------------------------------------------
+ # put model to accelerator
+ # ---------------------------------------------------------
+ is_load_in_8bit = (
+ "load_in_8bit" in self.model_kwargs and self.model_kwargs["load_in_8bit"]
+ )
+ is_load_in_4bit = (
+ "load_in_4bit" in self.model_kwargs and self.model_kwargs["load_in_4bit"]
+ )
+ if (
+ self.cfg.ddp
+ and not is_load_in_8bit
+ and not (self.cfg.rl and is_load_in_4bit)
+ and not skip_move_to_device
+ ):
+ # TODO revaldate this conditional
+ self.model.to(f"cuda:{self.cfg.local_rank}")
- if cfg.adapter is not None:
- log_gpu_memory_usage(LOG, "after adapters", model.device)
+ if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
+ setattr(self.model, "is_parallelizable", True)
+ setattr(self.model, "model_parallel", True)
- if cfg.unsloth_lora_mlp:
- from axolotl.monkeypatch.unsloth_ import integrate_lora_mlp_patch
+ # ---------------------------------------------------------
+ # parameters that require gradient updates
+ # ---------------------------------------------------------
+ requires_grad = []
+ for name, param in self.model.named_parameters(recurse=True):
+ if param.requires_grad:
+ requires_grad.append(f"{name}: {param.requires_grad}")
+ if len(requires_grad) == 0:
+ LOG.warning("there are no parameters that require gradient updates")
+ if hasattr(self.model, "config"):
+ self.model.config.use_cache = False
- integrate_lora_mlp_patch(model)
- if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
- from axolotl.monkeypatch.unsloth_ import integrate_lora_patch
+ if self.cfg.flash_optimum:
+ from optimum.bettertransformer import BetterTransformer
- integrate_lora_patch(model, cfg)
+ self.model = BetterTransformer.transform(self.model)
- if cfg.unsloth_rope:
- from axolotl.monkeypatch.unsloth_ import integrate_rope_embeddings
+ if self.cfg.adapter is not None:
+ log_gpu_memory_usage(LOG, "after adapters", self.model.device)
- integrate_rope_embeddings()
+ self.apply_lora_patch()
- for _ in range(3):
- gc.collect()
- torch.cuda.empty_cache()
+ for _ in range(3):
+ gc.collect()
+ torch.cuda.empty_cache()
- # TODO resume_from_checkpoint handling
- return model, lora_config
+ # TODO resume_from_checkpoint handling
+ return self.model, lora_config
+
+
+def load_model(
+ cfg: DictDefault,
+ tokenizer: PreTrainedTokenizerBase,
+ *,
+ processor: ProcessorMixin = None, # pylint: disable=unused-argument
+ inference: bool = False,
+ reference_model: bool = False,
+ **kwargs, # pylint: disable=unused-argument
+) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
+ """
+ Load a model for a given configuration and tokenizer.
+ """
+ loader = ModelLoader(
+ cfg,
+ tokenizer,
+ processor=processor,
+ inference=inference,
+ reference_model=reference_model,
+ **kwargs,
+ )
+ return loader.load_model()
def load_adapter(model, cfg, adapter, inference=False):
@@ -1020,12 +1277,17 @@ def load_lora(model, cfg, inference=False, config_only=False):
from peft import LoraConfig, get_peft_model
- lora_target_modules = list(cfg.lora_target_modules or [])
+ lora_target_modules = cfg.lora_target_modules or []
if cfg.lora_target_linear:
linear_names = find_all_linear_names(model)
LOG.info(f"found linear modules: {repr(sorted(linear_names))}")
- lora_target_modules = list(set(lora_target_modules + linear_names))
+ lora_target_modules_as_list = (
+ lora_target_modules
+ if isinstance(lora_target_modules, list)
+ else [lora_target_modules]
+ )
+ lora_target_modules = list(set(lora_target_modules_as_list + linear_names))
lora_config_kwargs = {}
loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
@@ -1044,6 +1306,7 @@ def load_lora(model, cfg, inference=False, config_only=False):
lora_alpha=cfg.lora_alpha,
target_modules=lora_target_modules,
layers_to_transform=cfg.peft_layers_to_transform,
+ layers_pattern=cfg.peft_layers_pattern,
lora_dropout=cfg.lora_dropout,
fan_in_fan_out=cfg.lora_fan_in_fan_out,
modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
@@ -1100,9 +1363,20 @@ def load_lora(model, cfg, inference=False, config_only=False):
def ensure_dtype(model, dtype=torch.bfloat16):
for name, module in model.named_modules():
+ weight_mismatch = False
+ bias_mismatch = False
try:
- if module.weight.dtype != dtype:
- print(f"Converting module {name}: {module.weight.dtype} -> {dtype}")
- module.to(dtype)
+ weight_mismatch = module.weight.dtype != dtype
except AttributeError:
pass
+ try:
+ bias_mismatch = module.bias.dtype != dtype
+ except AttributeError:
+ pass
+
+ if weight_mismatch:
+ print(f"Converting module {name}.weight: {module.weight.dtype} -> {dtype}")
+ if bias_mismatch:
+ print(f"Converting module {name}.bias: {module.bias.dtype} -> {dtype}")
+ if weight_mismatch or bias_mismatch:
+ module.to(dtype)
diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py
index 957ca57464..db14a6819e 100644
--- a/src/axolotl/utils/samplers/multipack.py
+++ b/src/axolotl/utils/samplers/multipack.py
@@ -11,6 +11,8 @@
import numpy as np
from torch.utils.data import BatchSampler, Sampler
+from axolotl.utils.distributed import reduce_and_broadcast
+
LOG = logging.getLogger("axolotl.utils.samplers.multipack")
@@ -131,6 +133,8 @@ def __init__(
self.eff_total_used = 0
self.eff_total_slots = 0
+ self.len_across_ranks = None
+
def set_epoch(self, epoch: int):
self.epoch = epoch
@@ -174,16 +178,45 @@ def num_batches(self):
def efficiency(self):
return self.eff_total_used / self.eff_total_slots
+ def gather_efficiency(self):
+ def calc_sample_packing_eff_est(estimates: List[float]):
+ LOG.debug(f"sample_packing_eff_est across ranks: {repr(estimates)}")
+ return math.floor(0.997 * max(estimates))
+
+ sample_packing_actual_eff_all = reduce_and_broadcast(
+ lambda: self.efficiency(), # pylint: disable=unnecessary-lambda
+ calc_sample_packing_eff_est,
+ )
+ sample_packing_eff_est = (
+ math.ceil(sample_packing_actual_eff_all * 200.0) / 200.0
+ )
+ return sample_packing_eff_est
+
+ def gather_len_batches(self, num):
+ def calc_min_len(estimates: list[(int, float)]):
+ LOG.info(f"gather_len_batches: {repr(estimates)}")
+ return math.floor(0.998 * min(estimates))
+
+ min_len_batches = reduce_and_broadcast(lambda: num, calc_min_len)
+ return min_len_batches
+
def __len__(self):
- self.num_batches()
- return self._len_est()
+ if not self.len_across_ranks:
+ len_batches = self.num_batches()
+ self.len_across_ranks = self.gather_len_batches(len_batches)
+ return self.len_across_ranks
def _len_est(self):
+ efficiency = (
+ self.packing_efficiency_estimate
+ if self.packing_efficiency_estimate
+ else self.gather_efficiency()
+ )
world_size = int(os.getenv("WORLD_SIZE", "1"))
lengths_sum = np.sum(self.lengths)
lengths_sum_per_device = lengths_sum // world_size
LOG.info(
- f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
+ f"packing_efficiency_estimate: {efficiency} "
f"total_num_tokens per device: {lengths_sum_per_device}"
)
@@ -195,7 +228,7 @@ def _len_est(self):
* math.floor(
0.99
* lengths_sum_per_device
- / self.packing_efficiency_estimate
+ / efficiency
// (self.batch_max_len * self.batch_size)
)
- 1
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 26796f2e53..7ebf384aff 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -11,7 +11,7 @@
import torch
import torch.cuda
from accelerate.logging import get_logger
-from datasets import set_caching_enabled
+from datasets import disable_caching, enable_caching
from torch.utils.data import DataLoader, RandomSampler
from transformers.utils import is_torch_bf16_gpu_available
@@ -87,10 +87,10 @@ def trainer_weighted_loss(model_output, labels, shift_labels=True):
@contextmanager
def disable_datasets_caching():
try:
- set_caching_enabled(False)
+ disable_caching()
yield
finally:
- set_caching_enabled(True)
+ enable_caching()
def add_position_ids(sample):
@@ -217,6 +217,24 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
desc="Dropping Long Sequences",
)
+ # drop samples with where the number of elements with labels not equal to -100 is zero
+ def drop_no_trainable_tokens(sample):
+ return np.sum(np.array(sample["labels"]) != -100) > 0
+
+ train_dataset = train_dataset.filter(
+ drop_no_trainable_tokens,
+ num_proc=cfg.dataset_processes,
+ load_from_cache_file=not cfg.is_preprocess,
+ desc="Drop Samples with Zero Trainable Tokens",
+ )
+ if eval_dataset:
+ eval_dataset = eval_dataset.filter(
+ drop_no_trainable_tokens,
+ num_proc=cfg.dataset_processes,
+ load_from_cache_file=not cfg.is_preprocess,
+ desc="Drop Samples with Zero Trainable Tokens",
+ )
+
if cfg.group_by_length:
train_dataset = train_dataset.map(
add_length,
@@ -288,7 +306,11 @@ def process_pretraining_datasets_for_packing(
def calculate_total_num_steps(cfg, train_dataset, update=True):
- if not cfg.total_num_tokens:
+ if (
+ not cfg.total_num_tokens
+ and not cfg.skip_prepare_dataset
+ and not cfg.reward_model
+ ):
total_num_tokens = np.sum(
train_dataset.data.column("input_ids")
.to_pandas()
@@ -301,7 +323,12 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
skip_estimates = cfg.model_config_type == "mamba"
- if not skip_estimates and not cfg.total_supervised_tokens:
+ if (
+ not skip_estimates
+ and not cfg.total_supervised_tokens
+ and not cfg.skip_prepare_dataset
+ and not cfg.reward_model
+ ):
total_supervised_tokens = (
train_dataset.data.column("labels")
.to_pandas()
@@ -339,7 +366,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
main_process_only=True,
)
else:
- if cfg.flash_attention:
+ if cfg.flash_attention and not cfg.multipack_real_batches:
sampler_batch_size = 1
batch_max_len = cfg.micro_batch_size * cfg.sequence_len
else:
@@ -399,12 +426,16 @@ def setup_torch_compile_env(cfg):
def setup_deepspeed_env(cfg, stage=None):
+ from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
+
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
if stage:
os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
if stage == 3:
os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true"
+ # If we don't assign this, it doesn't actually get set in the accelerate weakref
+ _ = HfTrainerDeepSpeedConfig(cfg.deepspeed)
def setup_fsdp_envs(cfg):
@@ -456,13 +487,15 @@ def prepare_opinionated_env(cfg):
os.environ["TOKENIZERS_PARALLELISM"] = "false"
-def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
+def setup_trainer(
+ cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
+):
if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]:
- trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
+ trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
trainer_builder.model_ref = model[1]
trainer_builder.peft_config = model[2]
else:
- trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer)
+ trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer, processor)
trainer_builder.train_dataset = train_dataset
trainer_builder.eval_dataset = eval_dataset
diff --git a/tests/core/chat/__init__.py b/tests/core/chat/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/core/chat/format/__init__.py b/tests/core/chat/format/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/core/chat/test_messages.py b/tests/core/chat/test_messages.py
new file mode 100644
index 0000000000..b3be56c590
--- /dev/null
+++ b/tests/core/chat/test_messages.py
@@ -0,0 +1,197 @@
+"""
+Tests for the chat messages module
+"""
+import unittest
+
+import pytest
+from transformers import AddedToken, AutoTokenizer
+
+from axolotl.core.chat.format.chatml import format_message
+from axolotl.core.chat.messages import ChatFormattedChats, Chats
+
+
+@pytest.fixture(scope="session", name="llama_tokenizer")
+def llama_tokenizer_fixture():
+ return AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3.1-8B")
+
+
+@pytest.fixture(scope="session", name="chatml_tokenizer")
+def llama_tokenizer_w_chatml(llama_tokenizer):
+ llama_tokenizer.add_special_tokens(
+ {
+ "eos_token": AddedToken(
+ "<|im_end|>", rstrip=False, lstrip=False, normalized=False
+ )
+ }
+ )
+ llama_tokenizer.add_tokens(
+ [
+ AddedToken("<|im_start|>", rstrip=False, lstrip=False, normalized=False),
+ ]
+ )
+
+ return llama_tokenizer
+
+
+@pytest.fixture(scope="session", name="chat_msgs")
+def chat_msgs_fixture():
+ return {
+ "conversation": [
+ {
+ "role": "system",
+ "content": [
+ {"type": "text", "value": "You are a helpful assistant."},
+ ],
+ },
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "value": "What is today's stock price of Apple?"},
+ ],
+ },
+ {
+ "role": "assistant",
+ "content": [
+ {
+ "type": "tool_call",
+ "value": {
+ "name": "get_date",
+ "arguments": {},
+ },
+ },
+ {
+ "type": "tool_call",
+ "value": {
+ "name": "get_stock_price",
+ "arguments": {"symbol": "AAPL"},
+ },
+ },
+ ],
+ "weight": 1,
+ },
+ {
+ "role": "tool",
+ "content": [
+ {
+ "type": "tool_response",
+ "value": {
+ "name": "get_date",
+ "content": {"date": "2024-09-09"},
+ },
+ },
+ {
+ "type": "tool_response",
+ "value": {
+ "name": "get_stock_price",
+ "content": {"symbol": "AAPL", "price": 123.45},
+ },
+ },
+ ],
+ },
+ {
+ "role": "assistant",
+ "content": [
+ {
+ "type": "text",
+ "value": "The stock price of Apple is $123.45.\n",
+ "weight": 0,
+ },
+ {
+ "type": "text",
+ "value": "The original query asked for today's stock price of Apple. This implies they also wanted the date included in the response.",
+ },
+ {
+ "type": "text",
+ "value": "The stock price of Apple on September 9, 2024 is $123.45.",
+ },
+ ],
+ "weight": 1,
+ },
+ ]
+ }
+
+
+class TestMessagesCase:
+ """
+ Test cases for the chat messages module
+ """
+
+ def test_tool_call_stringify(self, chat_msgs):
+ chat_msgs_as_obj = Chats(**chat_msgs)
+ assert '{"name": "get_stock_price", "arguments": {"symbol": "AAPL"}}' == str(
+ chat_msgs_as_obj.conversation[2].content[1].value
+ )
+
+ def test_chatml_formatted_wrapper(self, chat_msgs):
+ chat_msg_formatted = ChatFormattedChats(**chat_msgs, formatter=format_message)
+ target_chatml = """<|im_start|>system
+You are a helpful assistant.<|im_end|>
+<|im_start|>user
+What is today's stock price of Apple?<|im_end|>
+<|im_start|>assistant
+
+{"name": "get_date", "arguments": {}}
+
+
+{"name": "get_stock_price", "arguments": {"symbol": "AAPL"}}
+
+<|im_end|>
+<|im_start|>tool
+
+{"name": "get_date", "content": {"date": "2024-09-09"}}
+
+
+{"name": "get_stock_price", "content": {"symbol": "AAPL", "price": 123.45}}
+
+<|im_end|>
+<|im_start|>assistant
+The stock price of Apple is $123.45.
+The original query asked for today's stock price of Apple. This implies they also wanted the date included in the response.The stock price of Apple on September 9, 2024 is $123.45.<|im_end|>\n"""
+ assert target_chatml == str(chat_msg_formatted)
+
+ def test_chatml_formatting_tool_call(self, chat_msgs):
+ chat_msgs_as_obj = Chats(**chat_msgs)
+ target_chatml_turn2 = """<|im_start|>assistant\n\n{"name": "get_date", "arguments": {}}\n\n\n{"name": "get_stock_price", "arguments": {"symbol": "AAPL"}}\n\n<|im_end|>\n"""
+ assert target_chatml_turn2 == str(
+ format_message(chat_msgs_as_obj.conversation[2])
+ )
+
+ def test_train_labels(self, chatml_tokenizer, chat_msgs):
+ chat_msg_formatted = ChatFormattedChats(**chat_msgs, formatter=format_message)
+ tokenized = chat_msg_formatted.conversation[2].tokenized(chatml_tokenizer)
+ # fmt: off
+ target_labels = [
+ -100, -100, -100, # role
+ 27, 14506, 13735, 397, 5018, 609, 794,
+ 330, 456, 4257, 498, 330, 16774, 794, 4792, 534, 524,
+ 14506, 13735, 397, 27, 14506, 13735, 397, 5018, 609, 794,
+ 330, 456, 31641, 9217, 498, 330, 16774, 794, 5324, 19314,
+ 794, 330, 84016, 43, 96742, 524, 14506, 13735, 397,
+ 128256, # <|im_end|>
+ -100 # trailing newline
+ ]
+ # fmt: on
+ assert tokenized["labels"] == target_labels
+
+ def test_train_labels_2(self, chatml_tokenizer, chat_msgs):
+ # also test if indivudal contents are set not to train
+ chat_msg_formatted = ChatFormattedChats(**chat_msgs, formatter=format_message)
+ tokenized = chat_msg_formatted.conversation[4].tokenized(chatml_tokenizer)
+ # fmt: off
+ target_labels = [
+ -100, -100, -100, # role
+ -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # initial response
+ 27, 78098, 16761, 4113, 3319, 4691, 369, 3432, 596, 5708, 3430,
+ 315, 8325, 13, 1115, 24897, 814, 1101, 4934, 279, 2457,
+ 5343, 304, 279, 2077, 4005, 78098, 16761, 5708, 3430, 315,
+ 8325, 389, 6250, 220, 24, 11, 220, 2366, 19, 374, 400,
+ 4513, 13, 1774, 13,
+ 128256, # <|im_end|>
+ -100, # trailing newline
+ ]
+ # fmt: on
+ assert tokenized["labels"] == target_labels
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/e2e/integrations/__init__.py b/tests/e2e/integrations/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/e2e/integrations/liger.py b/tests/e2e/integrations/liger.py
new file mode 100644
index 0000000000..4497cebe32
--- /dev/null
+++ b/tests/e2e/integrations/liger.py
@@ -0,0 +1,110 @@
+"""
+Simple end-to-end test for Liger integration
+"""
+
+import unittest
+from pathlib import Path
+
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+from ..utils import with_temp_dir
+
+
+class LigerIntegrationTestCase(unittest.TestCase):
+ """
+ e2e tests for liger integration with Axolotl
+ """
+
+ @with_temp_dir
+ def test_llama_wo_flce(self, temp_dir):
+ cfg = DictDefault(
+ {
+ "base_model": "JackFram/llama-68m",
+ "tokenizer_type": "LlamaTokenizer",
+ "plugins": [
+ "axolotl.integrations.liger.LigerPlugin",
+ ],
+ "liger_rope": True,
+ "liger_rms_norm": True,
+ "liger_swiglu": True,
+ "liger_cross_entropy": True,
+ "liger_fused_linear_cross_entropy": False,
+ "sequence_len": 1024,
+ "val_set_size": 0.1,
+ "special_tokens": {
+ "unk_token": "",
+ "bos_token": "",
+ "eos_token": "",
+ },
+ "datasets": [
+ {
+ "path": "mhenrichsen/alpaca_2k_test",
+ "type": "alpaca",
+ },
+ ],
+ "num_epochs": 1,
+ "micro_batch_size": 8,
+ "gradient_accumulation_steps": 1,
+ "output_dir": temp_dir,
+ "learning_rate": 0.00001,
+ "optimizer": "adamw_torch",
+ "lr_scheduler": "cosine",
+ "save_safetensors": True,
+ "bf16": "auto",
+ }
+ )
+ normalize_config(cfg)
+ cli_args = TrainerCliArgs()
+ dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+ train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+ assert (Path(temp_dir) / "model.safetensors").exists()
+
+ @with_temp_dir
+ def test_llama_w_flce(self, temp_dir):
+ cfg = DictDefault(
+ {
+ "base_model": "JackFram/llama-68m",
+ "tokenizer_type": "LlamaTokenizer",
+ "plugins": [
+ "axolotl.integrations.liger.LigerPlugin",
+ ],
+ "liger_rope": True,
+ "liger_rms_norm": True,
+ "liger_swiglu": True,
+ "liger_cross_entropy": False,
+ "liger_fused_linear_cross_entropy": True,
+ "sequence_len": 1024,
+ "val_set_size": 0.1,
+ "special_tokens": {
+ "unk_token": "",
+ "bos_token": "",
+ "eos_token": "",
+ },
+ "datasets": [
+ {
+ "path": "mhenrichsen/alpaca_2k_test",
+ "type": "alpaca",
+ },
+ ],
+ "num_epochs": 1,
+ "micro_batch_size": 8,
+ "gradient_accumulation_steps": 1,
+ "output_dir": temp_dir,
+ "learning_rate": 0.00001,
+ "optimizer": "adamw_torch",
+ "lr_scheduler": "cosine",
+ "save_safetensors": True,
+ "bf16": "auto",
+ }
+ )
+ normalize_config(cfg)
+ cli_args = TrainerCliArgs()
+ dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+ train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+ assert (Path(temp_dir) / "model.safetensors").exists()
diff --git a/tests/e2e/multigpu/test_eval.py b/tests/e2e/multigpu/test_eval.py
new file mode 100644
index 0000000000..65d26bb824
--- /dev/null
+++ b/tests/e2e/multigpu/test_eval.py
@@ -0,0 +1,155 @@
+"""
+E2E tests for multigpu eval
+"""
+import logging
+import os
+import unittest
+from pathlib import Path
+
+import yaml
+from accelerate.test_utils import execute_subprocess_async
+
+from axolotl.utils.dict import DictDefault
+
+from ..utils import with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
+os.environ["WANDB_DISABLED"] = "true"
+
+AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
+
+
+class TestMultiGPUEval(unittest.TestCase):
+ """
+ Test case for MultiGPU Eval Sample Packing
+ """
+
+ @with_temp_dir
+ def test_eval_sample_packing(self, temp_dir):
+ # pylint: disable=duplicate-code
+ cfg = DictDefault(
+ {
+ "base_model": "JackFram/llama-68m",
+ "load_in_8bit": False,
+ "load_in_4bit": True,
+ "strict": False,
+ "sequence_len": 2048,
+ "adapter": "qlora",
+ "sample_packing": True,
+ "eval_sample_packing": True,
+ "pad_to_sequence_len": True,
+ "lora_r": 8,
+ "lora_alpha": 16,
+ "lora_dropout": 0.05,
+ "lora_target_linear": True,
+ "lora_modules_to_save": ["embed_tokens", "lm_head"],
+ "val_set_size": 0.1,
+ "special_tokens": {"pad_token": "<|end_of_text|>"},
+ "datasets": [
+ {
+ "path": "teknium/GPT4-LLM-Cleaned",
+ "type": "alpaca",
+ },
+ ],
+ "num_epochs": 1,
+ "max_steps": 5,
+ "micro_batch_size": 2,
+ "gradient_accumulation_steps": 4,
+ "output_dir": temp_dir,
+ "learning_rate": 0.00001,
+ "optimizer": "adamw_8bit",
+ "lr_scheduler": "cosine",
+ "flash_attention": True,
+ "loss_watchdog_threshold": 5.0,
+ "loss_watchdog_patience": 3,
+ "bf16": "auto",
+ "warmup_steps": 1,
+ "evals_per_epoch": 2,
+ "eval_max_new_tokens": 128,
+ "saves_per_epoch": 1,
+ "logging_steps": 1,
+ "weight_decay": 0.0,
+ }
+ )
+
+ # write cfg to yaml file
+ Path(temp_dir).mkdir(parents=True, exist_ok=True)
+ with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
+ fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+
+ execute_subprocess_async(
+ [
+ "accelerate",
+ "launch",
+ "--num-processes",
+ "2",
+ "-m",
+ "axolotl.cli.train",
+ str(Path(temp_dir) / "config.yaml"),
+ ]
+ )
+
+ @with_temp_dir
+ def test_eval(self, temp_dir):
+ # pylint: disable=duplicate-code
+ cfg = DictDefault(
+ {
+ "base_model": "JackFram/llama-68m",
+ "load_in_8bit": False,
+ "load_in_4bit": True,
+ "strict": False,
+ "sequence_len": 2048,
+ "adapter": "qlora",
+ "sample_packing": True,
+ "eval_sample_packing": False,
+ "pad_to_sequence_len": True,
+ "lora_r": 8,
+ "lora_alpha": 16,
+ "lora_dropout": 0.05,
+ "lora_target_linear": True,
+ "lora_modules_to_save": ["embed_tokens", "lm_head"],
+ "val_set_size": 0.1,
+ "special_tokens": {"pad_token": "<|end_of_text|>"},
+ "datasets": [
+ {
+ "path": "teknium/GPT4-LLM-Cleaned",
+ "type": "alpaca",
+ },
+ ],
+ "num_epochs": 1,
+ "max_steps": 5,
+ "micro_batch_size": 2,
+ "gradient_accumulation_steps": 4,
+ "output_dir": temp_dir,
+ "learning_rate": 0.00001,
+ "optimizer": "adamw_8bit",
+ "lr_scheduler": "cosine",
+ "flash_attention": True,
+ "loss_watchdog_threshold": 5.0,
+ "loss_watchdog_patience": 3,
+ "bf16": "auto",
+ "warmup_steps": 1,
+ "evals_per_epoch": 2,
+ "eval_max_new_tokens": 128,
+ "saves_per_epoch": 1,
+ "logging_steps": 1,
+ "weight_decay": 0.0,
+ }
+ )
+
+ # write cfg to yaml file
+ Path(temp_dir).mkdir(parents=True, exist_ok=True)
+ with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
+ fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+
+ execute_subprocess_async(
+ [
+ "accelerate",
+ "launch",
+ "--num-processes",
+ "2",
+ "-m",
+ "axolotl.cli.train",
+ str(Path(temp_dir) / "config.yaml"),
+ ]
+ )
diff --git a/tests/e2e/multigpu/test_llama.py b/tests/e2e/multigpu/test_llama.py
index 344c57fb85..957a6a9e36 100644
--- a/tests/e2e/multigpu/test_llama.py
+++ b/tests/e2e/multigpu/test_llama.py
@@ -10,6 +10,7 @@
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async
+from huggingface_hub import snapshot_download
from axolotl.utils.dict import DictDefault
@@ -18,6 +19,14 @@
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
+AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
+
+
+@pytest.fixture(scope="session", autouse=True)
+def download_model():
+ # download the model
+ snapshot_download("TinyLlama/TinyLlama_v1.1")
+
class TestMultiGPULlama(unittest.TestCase):
"""
@@ -339,3 +348,115 @@ def test_fsdp_qlora_prequant_packed(self, temp_dir):
str(Path(temp_dir) / "config.yaml"),
]
)
+
+ @with_temp_dir
+ def test_ds_zero3_packed(self, temp_dir):
+ # pylint: disable=duplicate-code
+ cfg = DictDefault(
+ {
+ "base_model": "TinyLlama/TinyLlama_v1.1",
+ "tokenizer_type": "LlamaTokenizer",
+ "sample_packing": True,
+ "eval_sample_packing": False,
+ "pad_to_sequence_len": True,
+ "sequence_len": 2048,
+ "val_set_size": 0.05,
+ "special_tokens": {
+ "unk_token": "",
+ "bos_token": "",
+ "eos_token": "",
+ },
+ "datasets": [
+ {
+ "path": "tatsu-lab/alpaca",
+ "type": "alpaca",
+ },
+ ],
+ "num_epochs": 1,
+ "max_steps": 100,
+ "micro_batch_size": 4,
+ "gradient_accumulation_steps": 4,
+ "output_dir": temp_dir,
+ "learning_rate": 0.00001,
+ "optimizer": "adamw_torch",
+ "lr_scheduler": "cosine",
+ "flash_attention": True,
+ "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.json"),
+ }
+ )
+
+ # write cfg to yaml file
+ Path(temp_dir).mkdir(parents=True, exist_ok=True)
+ with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
+ fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+
+ execute_subprocess_async(
+ [
+ "accelerate",
+ "launch",
+ "--num-processes",
+ "2",
+ "-m",
+ "axolotl.cli.train",
+ str(Path(temp_dir) / "config.yaml"),
+ ]
+ )
+
+ @with_temp_dir
+ def test_ds_zero3_qlora_packed(self, temp_dir):
+ # pylint: disable=duplicate-code
+ cfg = DictDefault(
+ {
+ "base_model": "TinyLlama/TinyLlama_v1.1",
+ "tokenizer_type": "LlamaTokenizer",
+ "load_in_4bit": True,
+ "adapter": "qlora",
+ "lora_r": 8,
+ "lora_alpha": 16,
+ "lora_dropout": 0.05,
+ "lora_target_linear": True,
+ "sample_packing": True,
+ "eval_sample_packing": False,
+ "pad_to_sequence_len": True,
+ "sequence_len": 2048,
+ "val_set_size": 0.05,
+ "special_tokens": {
+ "unk_token": "",
+ "bos_token": "",
+ "eos_token": "",
+ },
+ "datasets": [
+ {
+ "path": "tatsu-lab/alpaca",
+ "type": "alpaca",
+ },
+ ],
+ "num_epochs": 1,
+ "max_steps": 100,
+ "micro_batch_size": 4,
+ "gradient_accumulation_steps": 4,
+ "output_dir": temp_dir,
+ "learning_rate": 0.0001,
+ "optimizer": "adamw_torch",
+ "lr_scheduler": "cosine",
+ "flash_attention": True,
+ "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.json"),
+ }
+ )
+
+ # write cfg to yaml file
+ Path(temp_dir).mkdir(parents=True, exist_ok=True)
+ with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
+ fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+
+ execute_subprocess_async(
+ [
+ "accelerate",
+ "launch",
+ "--num-processes",
+ "2",
+ "-m",
+ "axolotl.cli.train",
+ str(Path(temp_dir) / "config.yaml"),
+ ]
+ )
diff --git a/tests/e2e/multigpu/test_qwen2.py b/tests/e2e/multigpu/test_qwen2.py
new file mode 100644
index 0000000000..2513be69e5
--- /dev/null
+++ b/tests/e2e/multigpu/test_qwen2.py
@@ -0,0 +1,98 @@
+"""
+E2E tests for multigpu qwen2
+"""
+
+import logging
+import os
+import unittest
+from pathlib import Path
+
+import yaml
+from accelerate.test_utils import execute_subprocess_async
+
+from axolotl.utils.dict import DictDefault
+
+from ..utils import with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestMultiGPUQwen2(unittest.TestCase):
+ """
+ Test case for Llama models using LoRA
+ """
+
+ @with_temp_dir
+ def test_qlora_fsdp_dpo(self, temp_dir):
+ # pylint: disable=duplicate-code
+ cfg = DictDefault(
+ {
+ "base_model": "Qwen/Qwen2-1.5B",
+ "load_in_4bit": True,
+ "rl": "dpo",
+ "chat_template": "chatml",
+ "sequence_len": 2048,
+ "adapter": "qlora",
+ "lora_r": 8,
+ "lora_alpha": 16,
+ "lora_dropout": 0.05,
+ "lora_target_linear": True,
+ "val_set_size": 0.05,
+ "datasets": [
+ {
+ "path": "Intel/orca_dpo_pairs",
+ "split": "train",
+ "type": "chatml.intel",
+ },
+ ],
+ "num_epochs": 1,
+ "max_steps": 100,
+ "warmup_steps": 20,
+ "micro_batch_size": 4,
+ "gradient_accumulation_steps": 2,
+ "output_dir": temp_dir,
+ "learning_rate": 0.00001,
+ "optimizer": "adamw_torch",
+ "lr_scheduler": "cosine",
+ "flash_attention": True,
+ "bf16": "auto",
+ "tf32": True,
+ "gradient_checkpointing": True,
+ "gradient_checkpointing_kwargs": {
+ "use_reentrant": False,
+ },
+ "fsdp": [
+ "full_shard",
+ "auto_wrap",
+ ],
+ "fsdp_config": {
+ "fsdp_limit_all_gathers": True,
+ "fsdp_offload_params": False,
+ "fsdp_sync_module_states": True,
+ "fsdp_use_orig_params": False,
+ "fsdp_cpu_ram_efficient_loading": False,
+ "fsdp_transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
+ "fsdp_state_dict_type": "FULL_STATE_DICT",
+ "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
+ "fsdp_sharding_strategy": "FULL_SHARD",
+ },
+ }
+ )
+
+ # write cfg to yaml file
+ Path(temp_dir).mkdir(parents=True, exist_ok=True)
+ with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
+ fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+
+ execute_subprocess_async(
+ [
+ "accelerate",
+ "launch",
+ "--num-processes",
+ "2",
+ "-m",
+ "axolotl.cli.train",
+ str(Path(temp_dir) / "config.yaml"),
+ ]
+ )
diff --git a/tests/e2e/patched/test_unsloth_integration.py b/tests/e2e/patched/test_unsloth_integration.py
index 39c7abb1c1..8882742861 100644
--- a/tests/e2e/patched/test_unsloth_integration.py
+++ b/tests/e2e/patched/test_unsloth_integration.py
@@ -1,22 +1,12 @@
"""Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected."""
import unittest
-from axolotl.monkeypatch.unsloth_ import (
- check_cel_is_patchable,
- check_self_attn_is_patchable,
-)
+from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable
class TestUnslothIntegration(unittest.TestCase):
"""Unsloth monkeypatch integration tests."""
- def test_is_cel_patchable(self):
- # ensures the current version of transformers has loss code that matches our patching code
- self.assertTrue(
- check_cel_is_patchable(),
- "HF transformers loss code has changed and isn't patchable",
- )
-
def test_is_self_attn_patchable(self):
# ensures the current version of transformers has loss code that matches our patching code
self.assertTrue(
diff --git a/tests/e2e/test_load_model.py b/tests/e2e/test_load_model.py
new file mode 100644
index 0000000000..31a9b1a878
--- /dev/null
+++ b/tests/e2e/test_load_model.py
@@ -0,0 +1,95 @@
+"""Module for testing ModelLoader."""
+
+import shutil
+import tempfile
+
+import pytest
+import torch
+
+from axolotl.utils.dict import DictDefault
+from axolotl.utils.models import ModelLoader, load_model, load_tokenizer
+
+
+@pytest.fixture(name="temp_dir")
+def fixture_temp_dir():
+ temp_dir = tempfile.mkdtemp()
+ yield temp_dir
+ shutil.rmtree(temp_dir)
+
+
+class TestLoadModelUtils:
+ """
+ Testing module testing ModelLoader.
+ """
+
+ def setup_method(self):
+ # load config
+ self.cfg = DictDefault(
+ {
+ "base_model": "JackFram/llama-68m",
+ "tokenizer_type": "LlamaTokenizer",
+ "tokenizer_config": "JackFram/llama-68m",
+ "sequence_len": 1024,
+ "load_in_8bit": False,
+ "adapter": "lora",
+ "lora_r": 8,
+ "lora_alpha": 16,
+ "lora_dropout": 0.05,
+ "lora_target_linear": True,
+ "val_set_size": 0.1,
+ "special_tokens": {
+ "unk_token": "",
+ "bos_token": "",
+ "eos_token": "",
+ },
+ "datasets": [
+ {
+ "path": "mhenrichsen/alpaca_2k_test",
+ "type": "alpaca",
+ },
+ ],
+ "num_epochs": 1,
+ "micro_batch_size": 8,
+ "gradient_accumulation_steps": 1,
+ "learning_rate": 0.00001,
+ "optimizer": "adamw_torch",
+ "lr_scheduler": "cosine",
+ }
+ )
+ self.model_loader = ( # pylint: disable=attribute-defined-outside-init
+ ModelLoader(
+ cfg=self.cfg,
+ tokenizer="",
+ )
+ )
+
+ @pytest.mark.parametrize("embedding_modules", ["embed_tokens", "lm_head"])
+ @pytest.mark.parametrize(
+ "dist_dtype", [torch.bfloat16, torch.float16, torch.float32]
+ )
+ @pytest.mark.parametrize("before_kbit_train_or_finetune", [True, False])
+ def test_convert_embedding_modules_dtype(
+ self, temp_dir, embedding_modules, dist_dtype, before_kbit_train_or_finetune
+ ):
+ self.cfg.output_dir = temp_dir
+ self.model_loader.tokenizer = load_tokenizer(self.cfg) # pylint: disable=all
+ self.model_loader.model, _ = load_model(
+ self.cfg,
+ self.model_loader.tokenizer,
+ inference=False,
+ reference_model=True,
+ )
+ self.model_loader.convert_embedding_modules_dtype(
+ embedding_modules, dist_dtype, before_kbit_train_or_finetune
+ )
+ for name, module in self.model_loader.model.named_modules():
+ if (
+ "norm" in name
+ or (before_kbit_train_or_finetune and name.endswith(".gate"))
+ or (
+ any(m in name for m in embedding_modules)
+ and hasattr(module, "weight")
+ )
+ ):
+ for _, param in module.named_parameters():
+ assert param.dtype == dist_dtype
diff --git a/tests/e2e/test_packing_loss.py b/tests/e2e/test_packing_loss.py
new file mode 100644
index 0000000000..73f9e60bac
--- /dev/null
+++ b/tests/e2e/test_packing_loss.py
@@ -0,0 +1,74 @@
+"""
+E2E tests for packed training
+"""
+
+import logging
+import os
+import unittest
+
+from tbparse import SummaryReader
+from transformers.utils import is_torch_bf16_gpu_available
+
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+from .utils import most_recent_subdir, with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestPackedLlama(unittest.TestCase):
+ """
+ Test case for Packed training of llama models
+ """
+
+ @with_temp_dir
+ def test_loss_packed(self, temp_dir):
+ # pylint: disable=duplicate-code
+ cfg = DictDefault(
+ {
+ "base_model": "HuggingFaceTB/SmolLM-135M",
+ "sequence_len": 1024,
+ "sample_packing": True,
+ "flash_attention": True,
+ "val_set_size": 0.0,
+ "special_tokens": {
+ "pad_token": "<|endoftext|>",
+ },
+ "datasets": [
+ {
+ "path": "vicgalle/alpaca-gpt4",
+ "type": "alpaca",
+ },
+ ],
+ "num_epochs": 1,
+ "micro_batch_size": 2,
+ "gradient_accumulation_steps": 4,
+ "output_dir": temp_dir,
+ "learning_rate": 0.00001,
+ "optimizer": "adamw_torch",
+ "lr_scheduler": "cosine",
+ "max_steps": 5,
+ "use_tensorboard": True,
+ }
+ )
+ if is_torch_bf16_gpu_available():
+ cfg.bf16 = True
+ else:
+ cfg.fp16 = True
+ normalize_config(cfg)
+ cli_args = TrainerCliArgs()
+ dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+ train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+
+ tb_log_path = most_recent_subdir(temp_dir + "/runs")
+ event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
+ reader = SummaryReader(event_file)
+ df = reader.scalars # pylint: disable=invalid-name
+ df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name
+ assert df.value.values[-1] < 2.0, "Loss is too high"
diff --git a/tests/e2e/test_reward_model_llama.py b/tests/e2e/test_reward_model_llama.py
new file mode 100644
index 0000000000..27ac3e25f1
--- /dev/null
+++ b/tests/e2e/test_reward_model_llama.py
@@ -0,0 +1,74 @@
+"""
+E2E tests for reward model lora llama
+"""
+
+import logging
+import os
+import unittest
+from pathlib import Path
+
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+from .utils import with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestRewardModelLoraLlama(unittest.TestCase):
+ """
+ Test case for Llama reward models using LoRA
+ """
+
+ @with_temp_dir
+ def test_rm_fft(self, temp_dir):
+ # pylint: disable=duplicate-code
+ cfg = DictDefault(
+ {
+ "base_model": "JackFram/llama-68m",
+ "model_type": "AutoModelForSequenceClassification",
+ "tokenizer_type": "LlamaTokenizer",
+ "chat_template": "alpaca",
+ "reward_model": True,
+ "sequence_len": 1024,
+ "pad_to_sequence_len": True,
+ "adapter": "lora",
+ "lora_r": 8,
+ "lora_alpha": 16,
+ "lora_dropout": 0.05,
+ "lora_target_linear": True,
+ "val_set_size": 0.0,
+ "special_tokens": {
+ "unk_token": "",
+ "bos_token": "",
+ "eos_token": "",
+ },
+ "datasets": [
+ {
+ "path": "argilla/distilabel-intel-orca-dpo-pairs",
+ "type": "bradley_terry.chat_template",
+ },
+ ],
+ "remove_unused_columns": False,
+ "max_steps": 10,
+ "num_epochs": 1,
+ "micro_batch_size": 4,
+ "gradient_accumulation_steps": 1,
+ "output_dir": temp_dir,
+ "learning_rate": 0.00001,
+ "optimizer": "adamw_bnb_8bit",
+ "lr_scheduler": "cosine",
+ "gradient_checkpointing": True,
+ "warmup_ratio": 0.1,
+ }
+ )
+ normalize_config(cfg)
+ cli_args = TrainerCliArgs()
+ dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+ train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+ assert (Path(temp_dir) / "adapter_model.bin").exists()
diff --git a/tests/prompt_strategies/conftest.py b/tests/prompt_strategies/conftest.py
new file mode 100644
index 0000000000..43423f7255
--- /dev/null
+++ b/tests/prompt_strategies/conftest.py
@@ -0,0 +1,71 @@
+"""
+shared fixtures for prompt strategies tests
+"""
+
+import pytest
+from datasets import Dataset
+from transformers import AutoTokenizer
+
+
+@pytest.fixture(name="assistant_dataset")
+def fixture_assistant_dataset():
+ return Dataset.from_list(
+ [
+ {
+ "messages": [
+ {"role": "user", "content": "hello"},
+ {"role": "assistant", "content": "hello"},
+ {"role": "user", "content": "goodbye"},
+ {"role": "assistant", "content": "goodbye"},
+ ]
+ }
+ ]
+ )
+
+
+@pytest.fixture(name="sharegpt_dataset")
+def fixture_sharegpt_dataset():
+ # pylint: disable=duplicate-code
+ return Dataset.from_list(
+ [
+ {
+ "conversations": [
+ {"from": "human", "value": "hello"},
+ {"from": "gpt", "value": "hello"},
+ {"from": "human", "value": "goodbye"},
+ {"from": "gpt", "value": "goodbye"},
+ ]
+ }
+ ]
+ )
+
+
+@pytest.fixture(name="basic_dataset")
+def fixture_basic_dataset():
+ # pylint: disable=duplicate-code
+ return Dataset.from_list(
+ [
+ {
+ "conversations": [
+ {"from": "system", "value": "You are an AI assistant."},
+ {"from": "human", "value": "Hello"},
+ {"from": "assistant", "value": "Hi there!"},
+ {"from": "human", "value": "How are you?"},
+ {"from": "assistant", "value": "I'm doing well, thank you!"},
+ ]
+ }
+ ]
+ )
+
+
+@pytest.fixture(name="llama3_tokenizer")
+def fixture_llama3_tokenizer():
+ tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
+
+ return tokenizer
+
+
+@pytest.fixture(name="phi35_tokenizer")
+def fixture_phi35_tokenizer():
+ tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-mini-instruct")
+ return tokenizer
diff --git a/tests/prompt_strategies/messages/__init__.py b/tests/prompt_strategies/messages/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/prompt_strategies/messages/test_chat.py b/tests/prompt_strategies/messages/test_chat.py
new file mode 100644
index 0000000000..96c4b6cbbf
--- /dev/null
+++ b/tests/prompt_strategies/messages/test_chat.py
@@ -0,0 +1,62 @@
+"""
+tests for chat_template prompt strategy
+"""
+# pylint: disable=duplicate-code
+import logging
+import unittest
+
+from axolotl.prompt_strategies.messages.chat import load
+from axolotl.utils.dict import DictDefault
+
+logging.basicConfig(level=logging.DEBUG)
+LOG = logging.getLogger("axolotl")
+
+
+class TestMessagesChatLlama3:
+ """
+ Test class for assistant style datasets with llama-3 prompts using the messages chat llama3 strategy.
+ """
+
+ def test_llama3_load(self, llama3_tokenizer, assistant_dataset):
+ LOG.info("Loading llama-3 tokenizer with assistant dataset")
+ strategy = load(
+ llama3_tokenizer,
+ DictDefault(
+ {
+ "train_on_inputs": False,
+ "sequence_len": 512,
+ }
+ ),
+ DictDefault(
+ {
+ "chat_template": "llama3",
+ "message_field_role": "role",
+ "message_field_content": "content",
+ "field_messages": "messages",
+ }
+ ),
+ )
+ res = strategy.wrap_dataset(assistant_dataset)
+ input_ids = res[0]["input_ids"]
+ # fmt: off
+ expected_input_ids = [
+ 128000, # bos
+ 128006, 882, 128007, # user header
+ 271, 15339, 128009, # user prompt eot
+ 128006, 78191, 128007, # assistant header
+ 271, 15339, 128009, # assistant response eot
+ 128006, 882, 128007,
+ 271, 19045, 29474, 128009,
+ 128006, 78191, 128007,
+ 271, 19045, 29474, 128009,
+ ]
+ # fmt: on
+ LOG.debug(f"Expected input_ids: {expected_input_ids}")
+ LOG.debug(f"Actual input_ids: {input_ids}")
+ assert (
+ input_ids == expected_input_ids
+ ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/prompt_strategies/test_chat_template_utils.py b/tests/prompt_strategies/test_chat_template_utils.py
new file mode 100644
index 0000000000..b63c9aa179
--- /dev/null
+++ b/tests/prompt_strategies/test_chat_template_utils.py
@@ -0,0 +1,125 @@
+"""
+Tests for utils in axolotl.utils.chat_templates
+"""
+import unittest
+
+import pytest
+from transformers import AutoTokenizer
+
+from axolotl.utils.chat_templates import (
+ _CHAT_TEMPLATES,
+ extract_chat_template_args,
+ get_chat_template,
+)
+
+
+@pytest.fixture(name="llama3_tokenizer")
+def fixture_llama3_tokenizer():
+ tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
+
+ return tokenizer
+
+
+class TestGetChatTemplateUtils:
+ """
+ Tests the get_chat_template function.
+ """
+
+ def test_known_chat_template(self):
+ chat_template_str = get_chat_template("llama3")
+ assert chat_template_str == _CHAT_TEMPLATES["llama3"]
+
+ def test_invalid_chat_template(self):
+ with pytest.raises(ValueError) as exc:
+ get_chat_template("invalid_template")
+ assert str(exc) == "Template 'invalid_template' not found."
+
+ def test_tokenizer_default_no_tokenizer(self):
+ with pytest.raises(ValueError):
+ get_chat_template("tokenizer_default", tokenizer=None)
+
+ def test_tokenizer_default_no_chat_template_on_tokenizer(self, llama3_tokenizer):
+ with pytest.raises(ValueError):
+ get_chat_template("tokenizer_default", tokenizer=llama3_tokenizer)
+
+ def test_tokenizer_default_with_chat_template_on_tokenizer(self, llama3_tokenizer):
+ llama3_tokenizer.chat_template = "test_template"
+ chat_template_str = get_chat_template(
+ "tokenizer_default", tokenizer=llama3_tokenizer
+ )
+ assert chat_template_str == "test_template"
+
+ def test_tokenizer_default_fallback_no_tokenizer(self):
+ with pytest.raises(ValueError):
+ get_chat_template("tokenizer_default_fallback_test", tokenizer=None)
+
+ def test_tokenizer_default_fallback_no_chat_template_on_tokenizer(
+ self, llama3_tokenizer
+ ):
+ chat_template_str = get_chat_template(
+ "tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer
+ )
+ assert chat_template_str == get_chat_template("chatml")
+
+ def test_tokenizer_default_fallback_with_chat_template_on_tokenizer(
+ self, llama3_tokenizer
+ ):
+ llama3_tokenizer.chat_template = "test_template"
+ chat_template_str = get_chat_template(
+ "tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer
+ )
+ assert chat_template_str == "test_template"
+
+ def test_jinja_template_mode(self):
+ jinja_template = "example_jinja_template"
+ chat_template_str = get_chat_template("jinja", jinja_template=jinja_template)
+ assert chat_template_str == jinja_template
+
+ def test_jinja_template_mode_no_jinja_template(self):
+ with pytest.raises(ValueError):
+ get_chat_template("jinja", jinja_template=None)
+
+ def test_extract_chat_template_args(self):
+ # No ds_cfg
+ chat_template_choice, chat_template_jinja = extract_chat_template_args(
+ cfg={"chat_template": "chatml"},
+ )
+ assert chat_template_choice == "chatml"
+ assert chat_template_jinja is None
+
+ # ds_cfg provided
+ chat_template_choice, chat_template_jinja = extract_chat_template_args(
+ cfg={
+ "chat_template": "jinja",
+ "chat_template_jinja": "global_jinja_template",
+ },
+ ds_cfg={"chat_template": "llama3", "chat_template_jinja": None},
+ )
+ assert chat_template_choice == "llama3"
+ assert chat_template_jinja is None
+
+ # ds_cfg provided with jinja template
+ chat_template_choice, chat_template_jinja = extract_chat_template_args(
+ cfg={"chat_template": "chatml", "chat_template_jinja": None},
+ ds_cfg={
+ "chat_template": "jinja",
+ "chat_template_jinja": "ds_jinja_template",
+ },
+ )
+ assert chat_template_choice == "jinja"
+ assert chat_template_jinja == "ds_jinja_template"
+
+ # ds_cfg provided with no chat_template
+ chat_template_choice, chat_template_jinja = extract_chat_template_args(
+ cfg={
+ "chat_template": "jinja",
+ "chat_template_jinja": "global_jinja_template",
+ },
+ ds_cfg={"chat_template": None, "chat_template_jinja": "ds_jinja_template"},
+ )
+ assert chat_template_choice == "jinja"
+ assert chat_template_jinja == "global_jinja_template"
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/prompt_strategies/test_chat_templates.py b/tests/prompt_strategies/test_chat_templates.py
index e2fc0f6a52..4ec12b82cb 100644
--- a/tests/prompt_strategies/test_chat_templates.py
+++ b/tests/prompt_strategies/test_chat_templates.py
@@ -5,674 +5,19 @@
import logging
import unittest
-import pytest
-from datasets import Dataset
-from transformers import AutoTokenizer
-
from axolotl.prompt_strategies.chat_template import (
ChatTemplatePrompter,
ChatTemplateStrategy,
load,
)
from axolotl.prompters import IGNORE_TOKEN_ID
-from axolotl.utils.chat_templates import chat_templates
+from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.dict import DictDefault
logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger("axolotl")
-@pytest.fixture(name="assistant_dataset")
-def fixture_assistant_dataset():
- return Dataset.from_list(
- [
- {
- "messages": [
- {"role": "user", "content": "hello"},
- {"role": "assistant", "content": "hello"},
- {"role": "user", "content": "goodbye"},
- {"role": "assistant", "content": "goodbye"},
- ]
- }
- ]
- )
-
-
-@pytest.fixture(name="sharegpt_dataset")
-def fixture_sharegpt_dataset():
- # pylint: disable=duplicate-code
- return Dataset.from_list(
- [
- {
- "conversations": [
- {"from": "human", "value": "hello"},
- {"from": "gpt", "value": "hello"},
- {"from": "human", "value": "goodbye"},
- {"from": "gpt", "value": "goodbye"},
- ]
- }
- ]
- )
-
-
-@pytest.fixture(name="basic_dataset")
-def fixture_basic_dataset():
- # pylint: disable=duplicate-code
- return Dataset.from_list(
- [
- {
- "conversations": [
- {"from": "system", "value": "You are an AI assistant."},
- {"from": "human", "value": "Hello"},
- {"from": "assistant", "value": "Hi there!"},
- {"from": "human", "value": "How are you?"},
- {"from": "assistant", "value": "I'm doing well, thank you!"},
- ]
- }
- ]
- )
-
-
-@pytest.fixture(name="llama3_tokenizer")
-def fixture_llama3_tokenizer():
- tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
-
- return tokenizer
-
-
-class TestChatTemplateConfigurations:
- """
- Test class for various configurations of ChatTemplateStrategy.
- """
-
- @staticmethod
- def find_sublist(full_list, sub_list):
- token_count = len(sub_list)
- for index in range(len(full_list) - token_count + 1):
- if full_list[index : index + token_count] == sub_list:
- return index
- return -1
-
- def test_train_on_inputs_true(self, llama3_tokenizer, basic_dataset):
- LOG.info("Testing with train_on_inputs=True")
- strategy = ChatTemplateStrategy(
- ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
- tokenizer=llama3_tokenizer,
- train_on_inputs=True,
- sequence_len=512,
- roles_to_train=["assistant"],
- )
- res = strategy.tokenize_prompt(basic_dataset[0])
- labels = res["labels"]
- input_ids = res["input_ids"]
-
- # Verify that assistant responses are labeled
- assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
- for response in assistant_responses:
- response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
- start_idx = self.find_sublist(input_ids, response_ids)
- LOG.debug(
- f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
- )
- assert start_idx != -1, f"Could not find '{response}' in input_ids"
- assert all(
- label != IGNORE_TOKEN_ID
- for label in labels[start_idx : start_idx + len(response_ids)]
- ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
-
- # Check the behavior of human inputs
- human_inputs = ["Hello", "How are you?"]
- for input_text in human_inputs:
- input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False)
- start_idx = self.find_sublist(input_ids, input_ids)
- labeled = all(
- label != IGNORE_TOKEN_ID
- for label in labels[start_idx : start_idx + len(input_ids)]
- )
- LOG.debug(
- f"Human input '{input_text}' is {'labeled' if labeled else 'not labeled'}, expected IDs: {input_ids}, found at: {start_idx}"
- )
-
- LOG.debug("Full labels: %s", labels)
- LOG.debug("Full input_ids: %s", input_ids)
-
- def test_train_on_inputs_false(self, llama3_tokenizer, basic_dataset):
- LOG.info("Testing with train_on_inputs=False")
- strategy = ChatTemplateStrategy(
- ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
- tokenizer=llama3_tokenizer,
- train_on_inputs=False,
- sequence_len=512,
- roles_to_train=["assistant"],
- )
- res = strategy.tokenize_prompt(basic_dataset[0])
- labels = res["labels"]
- input_ids = res["input_ids"]
-
- # Verify that only assistant responses are labeled
- assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
- for response in assistant_responses:
- response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
- start_idx = self.find_sublist(input_ids, response_ids)
- LOG.debug(
- f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
- )
- assert start_idx != -1, f"Could not find '{response}' in input_ids"
- assert all(
- label != IGNORE_TOKEN_ID
- for label in labels[start_idx : start_idx + len(response_ids)]
- ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
-
- # Verify that human inputs are not labeled
- human_inputs = ["Hello", "How are you?"]
- for input_text in human_inputs:
- input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False)
- start_idx = self.find_sublist(input_ids, input_ids)
- LOG.debug(
- f"Human input '{input_text}' expected IDs: {input_ids}, found at: {start_idx}"
- )
- assert start_idx != -1, f"Could not find '{input_text}' in input_ids"
- assert all(
- label == IGNORE_TOKEN_ID
- for label in labels[start_idx : start_idx + len(input_ids)]
- ), f"Expected labels for human input '{input_text}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:start_idx+len(input_ids)]}"
-
- def test_roles_to_train_assistant_only(self, llama3_tokenizer, basic_dataset):
- LOG.info("Testing roles_to_train with assistant only")
- strategy = ChatTemplateStrategy(
- ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
- tokenizer=llama3_tokenizer,
- train_on_inputs=False,
- sequence_len=512,
- roles_to_train=["assistant"],
- )
- res = strategy.tokenize_prompt(basic_dataset[0])
- labels = res["labels"]
- input_ids = res["input_ids"]
-
- # Verify that only assistant responses are labeled
- assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
- for response in assistant_responses:
- response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
- start_idx = self.find_sublist(input_ids, response_ids)
- LOG.debug(
- f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
- )
- assert all(
- label != IGNORE_TOKEN_ID
- for label in labels[start_idx : start_idx + len(response_ids)]
- ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
-
- def test_roles_to_train_all(self, llama3_tokenizer, basic_dataset):
- LOG.info("Testing roles_to_train with all roles")
- strategy = ChatTemplateStrategy(
- ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
- tokenizer=llama3_tokenizer,
- train_on_inputs=True,
- sequence_len=512,
- roles_to_train=["human", "assistant"],
- )
- res = strategy.tokenize_prompt(basic_dataset[0])
- labels = res["labels"]
- input_ids = res["input_ids"]
-
- # Verify that all responses are labeled (except for special tokens)
- all_responses = [
- "Hello",
- "Hi there!",
- "How are you?",
- "I'm doing well, thank you!",
- ]
- for response in all_responses:
- response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
- start_idx = self.find_sublist(input_ids, response_ids)
- LOG.debug(
- f"Response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
- )
- assert all(
- label != IGNORE_TOKEN_ID
- for label in labels[start_idx : start_idx + len(response_ids)]
- ), f"Expected labels for response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
-
- def test_empty_roles_to_train(self, llama3_tokenizer, basic_dataset):
- LOG.info("Testing with empty roles_to_train")
- strategy = ChatTemplateStrategy(
- ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
- tokenizer=llama3_tokenizer,
- train_on_inputs=False,
- sequence_len=512,
- roles_to_train=[],
- train_on_eos="none", # Add this line
- )
- res = strategy.tokenize_prompt(basic_dataset[0])
- labels = res["labels"]
-
- # Verify that no labels are set when roles_to_train is empty
- LOG.debug("Full labels: %s", labels)
- assert all(
- label == IGNORE_TOKEN_ID for label in labels
- ), "Expected all labels to be IGNORE_TOKEN_ID when roles_to_train is empty"
-
- def test_train_on_eos_all(self, llama3_tokenizer, basic_dataset):
- LOG.info("Testing with train_on_eos='all'")
- strategy = ChatTemplateStrategy(
- ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
- tokenizer=llama3_tokenizer,
- train_on_inputs=False,
- sequence_len=512,
- roles_to_train=["assistant"],
- train_on_eos="all",
- )
- res = strategy.tokenize_prompt(basic_dataset[0])
- labels = res["labels"]
- input_ids = res["input_ids"]
-
- eos_token_id = llama3_tokenizer.eos_token_id
- eos_indices = [
- i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
- ]
-
- assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
- for eos_idx in eos_indices:
- assert (
- labels[eos_idx] != IGNORE_TOKEN_ID
- ), f"Expected EOS token at index {eos_idx} to be labeled"
-
- def test_train_on_eos_turn(self, llama3_tokenizer, basic_dataset):
- LOG.info("Testing with train_on_eos='turn'")
- strategy = ChatTemplateStrategy(
- ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
- tokenizer=llama3_tokenizer,
- train_on_inputs=False,
- sequence_len=512,
- roles_to_train=["assistant"],
- train_on_eos="turn",
- )
- res = strategy.tokenize_prompt(basic_dataset[0])
- labels = res["labels"]
- input_ids = res["input_ids"]
-
- eos_token_id = llama3_tokenizer.eos_token_id
- assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
-
- for response in assistant_responses:
- response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
- start_idx = self.find_sublist(input_ids, response_ids)
- assert start_idx != -1, f"Could not find '{response}' in input_ids"
-
- eos_idx = start_idx + len(response_ids)
- while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id:
- eos_idx += 1
-
- assert eos_idx < len(
- input_ids
- ), f"Could not find EOS token after '{response}'"
- assert (
- labels[eos_idx] != IGNORE_TOKEN_ID
- ), f"Expected EOS token after assistant response '{response}' to be labeled"
-
- # Check that EOS tokens after human inputs are not labeled
- human_inputs = ["Hello", "How are you?"]
- for input_text in human_inputs:
- input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False)
- start_idx = self.find_sublist(input_ids, input_ids)
- assert start_idx != -1, f"Could not find '{input_text}' in input_ids"
-
- eos_idx = start_idx + len(input_ids)
- while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id:
- eos_idx += 1
-
- assert (
- labels[eos_idx] == IGNORE_TOKEN_ID
- ), f"Expected EOS token after human input '{input_text}' to not be labeled"
-
- def test_train_on_eos_last(self, llama3_tokenizer, basic_dataset):
- LOG.info("Testing with train_on_eos='last'")
- strategy = ChatTemplateStrategy(
- ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
- tokenizer=llama3_tokenizer,
- train_on_inputs=False,
- sequence_len=512,
- roles_to_train=["assistant"],
- train_on_eos="last",
- )
- res = strategy.tokenize_prompt(basic_dataset[0])
- labels = res["labels"]
- input_ids = res["input_ids"]
-
- eos_token_id = llama3_tokenizer.eos_token_id
- eos_indices = [
- i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
- ]
-
- assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
- last_eos_idx = eos_indices[-1]
-
- # Check that only the last EOS token is labeled
- for idx in eos_indices[:-1]:
- assert (
- labels[idx] == IGNORE_TOKEN_ID
- ), f"Expected EOS token at index {idx} to not be labeled"
- assert (
- labels[last_eos_idx] != IGNORE_TOKEN_ID
- ), f"Expected last EOS token at index {last_eos_idx} to be labeled"
-
- def test_train_on_eos_none(self, llama3_tokenizer, basic_dataset):
- LOG.info("Testing with train_on_eos='none'")
- strategy = ChatTemplateStrategy(
- ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
- tokenizer=llama3_tokenizer,
- train_on_inputs=False,
- sequence_len=512,
- roles_to_train=["assistant"],
- train_on_eos="none",
- )
- res = strategy.tokenize_prompt(basic_dataset[0])
- labels = res["labels"]
- input_ids = res["input_ids"]
-
- eos_token_id = llama3_tokenizer.eos_token_id
- eos_indices = [
- i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
- ]
-
- assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
- for eos_idx in eos_indices:
- assert (
- labels[eos_idx] == IGNORE_TOKEN_ID
- ), f"Expected EOS token at index {eos_idx} to not be labeled"
-
- def test_drop_system_message(self, llama3_tokenizer, basic_dataset):
- LOG.info("Testing with drop_system_message=True")
- strategy = ChatTemplateStrategy(
- ChatTemplatePrompter(
- llama3_tokenizer, chat_templates("llama3"), drop_system_message=True
- ),
- tokenizer=llama3_tokenizer,
- train_on_inputs=False,
- sequence_len=512,
- roles_to_train=["assistant"],
- )
- res = strategy.tokenize_prompt(basic_dataset[0])
- input_ids = res["input_ids"]
-
- # Check if system message is not present in input_ids
- system_message = "You are an AI assistant."
- system_ids = llama3_tokenizer.encode(system_message, add_special_tokens=False)
- assert (
- self.find_sublist(input_ids, system_ids) == -1
- ), "Expected system message to be dropped"
-
- def test_custom_roles(self, llama3_tokenizer):
- LOG.info("Testing with custom roles mapping")
- custom_roles = {
- "user": ["human", "user"],
- "assistant": ["ai", "assistant"],
- "system": ["context"],
- }
- strategy = ChatTemplateStrategy(
- ChatTemplatePrompter(
- llama3_tokenizer, chat_templates("llama3"), roles=custom_roles
- ),
- tokenizer=llama3_tokenizer,
- train_on_inputs=False,
- sequence_len=512,
- roles_to_train=["ai"],
- )
-
- # Create a new dataset with modified role names
- modified_conversations = [
- {"from": "context", "value": "You are an AI assistant."},
- {"from": "human", "value": "Hello"},
- {"from": "ai", "value": "Hi there!"},
- {"from": "human", "value": "How are you?"},
- {"from": "ai", "value": "I'm doing well, thank you!"},
- ]
-
- modified_dataset = Dataset.from_dict(
- {"conversations": [modified_conversations]}
- )
-
- res = strategy.tokenize_prompt(modified_dataset[0])
- labels = res["labels"]
- input_ids = res["input_ids"]
-
- # Check if AI responses are labeled correctly
- ai_responses = ["Hi there!", "I'm doing well, thank you!"]
- for response in ai_responses:
- response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
- start_idx = self.find_sublist(input_ids, response_ids)
- assert start_idx != -1, f"Could not find response '{response}' in input_ids"
- assert all(
- label != IGNORE_TOKEN_ID
- for label in labels[start_idx : start_idx + len(response_ids)]
- ), f"Expected labels for AI response '{response}' to be set"
-
- # Check if human messages are not labeled
- human_messages = ["Hello", "How are you?"]
- for message in human_messages:
- message_ids = llama3_tokenizer.encode(message, add_special_tokens=False)
- start_idx = self.find_sublist(input_ids, message_ids)
- assert start_idx != -1, f"Could not find message '{message}' in input_ids"
- assert all(
- label == IGNORE_TOKEN_ID
- for label in labels[start_idx : start_idx + len(message_ids)]
- ), f"Expected labels for human message '{message}' to be IGNORE_TOKEN_ID"
-
- def test_message_field_training(self, llama3_tokenizer):
- LOG.info("Testing with message_field_training")
- strategy = ChatTemplateStrategy(
- ChatTemplatePrompter(
- llama3_tokenizer,
- chat_templates("llama3"),
- message_field_training="train",
- message_field_training_detail="train_detail",
- ),
- tokenizer=llama3_tokenizer,
- train_on_inputs=False,
- sequence_len=512,
- roles_to_train=[],
- )
-
- # Create a new dataset with the train and train_detail fields
- modified_conversation = [
- {"from": "system", "value": "You are an AI assistant.", "train": False},
- {"from": "human", "value": "Hello", "train": False},
- {"from": "assistant", "value": "Hello", "train": True},
- {"from": "human", "value": "How are you?", "train": True},
- {
- "from": "assistant",
- "value": "I'm doing very well, thank you!",
- "train_detail": [
- {"begin_offset": 0, "end_offset": 8, "train": False},
- {"begin_offset": 9, "end_offset": 18, "train": True},
- {"begin_offset": 19, "end_offset": 30, "train": False},
- ],
- },
- {
- "from": "human",
- "value": "I'm doing very well, thank you!",
- "train": False,
- },
- {"from": "assistant", "value": "Hi there!", "train": True},
- ]
-
- modified_dataset = Dataset.from_dict({"conversations": [modified_conversation]})
-
- res = strategy.tokenize_prompt(modified_dataset[0])
- labels = res["labels"]
- input_ids = res["input_ids"]
-
- # Function to find all occurrences of a sublist
- def find_all_sublists(full_list, sub_list):
- indices = []
- for index in range(len(full_list) - len(sub_list) + 1):
- if full_list[index : index + len(sub_list)] == sub_list:
- indices.append(index)
- return indices
-
- # Keep track of which occurrences we've processed
- processed_occurrences = {}
- # Check if messages are labeled correctly based on train or train_detail
- for i, turn in enumerate(modified_conversation):
- turn_tokens = llama3_tokenizer.encode(
- turn["value"], add_special_tokens=False
- )
- occurrences = find_all_sublists(input_ids, turn_tokens)
- turn_key = turn["value"]
- if turn_key not in processed_occurrences:
- processed_occurrences[turn_key] = 0
- current_occurrence = processed_occurrences[turn_key]
-
- if current_occurrence >= len(occurrences):
- assert (
- False
- ), f"Not enough occurrences found for message: {turn['value']}"
-
- start_idx = occurrences[current_occurrence]
- processed_occurrences[turn_key] += 1
- end_idx = start_idx + len(turn_tokens)
-
- LOG.debug(
- f"Processing turn {i}: role={turn['from']}, content='{turn['value']}', start_idx={start_idx}, end_idx={end_idx}"
- )
-
- if "train_detail" in turn:
- # Get token offsets
- tokenized_output = llama3_tokenizer(
- turn["value"], return_offsets_mapping=True, add_special_tokens=False
- )
- token_offsets = tokenized_output["offset_mapping"]
-
- # Adjust token offsets as done in the implementation
- for i in range(len(token_offsets) - 1):
- token_offsets[i] = (
- token_offsets[i][0],
- token_offsets[i + 1][0] - 1,
- )
- token_offsets[-1] = (token_offsets[-1][0], len(turn["value"]) - 1)
-
- # Adjust train_details
- adjusted_train_details = strategy.prompter.adjust_train_details(
- turn["train_detail"], token_offsets
- )
-
- LOG.debug(f"Original train_details: {turn['train_detail']}")
- LOG.debug(f"Adjusted train_details: {adjusted_train_details}")
-
- # Handle train_detail
- token_offsets = strategy.prompter.get_offsets_for_train_detail(
- text=turn["value"],
- train_details=adjusted_train_details,
- mask_untrainable=False,
- )
- token_offsets_masked = strategy.prompter.get_offsets_for_train_detail(
- text=turn["value"],
- train_details=adjusted_train_details,
- mask_untrainable=True,
- )
- LOG.debug(f"Token offsets: {token_offsets_masked}")
-
- expected_labels = [IGNORE_TOKEN_ID] * len(turn_tokens)
- for i, offset in enumerate(token_offsets_masked):
- if offset != IGNORE_TOKEN_ID:
- expected_labels[i] = turn_tokens[i]
- actual_labels = labels[
- start_idx : start_idx + len(token_offsets_masked)
- ]
- assert (
- actual_labels == expected_labels
- ), f"Labels mismatch for turn: {turn['value']}\nExpected: {expected_labels}\nActual: {actual_labels}"
-
- for detail in adjusted_train_details:
- # Find the token indices that correspond to the character offsets
- detail_start = start_idx + next(
- i
- for i, offset in enumerate(token_offsets)
- if offset >= detail["begin_offset"]
- )
- detail_end = start_idx + next(
- (
- i
- for i, offset in enumerate(token_offsets)
- if offset > detail["end_offset"]
- ),
- len(token_offsets),
- )
-
- detail_text = turn["value"][
- detail["begin_offset"] : detail["end_offset"] + 1
- ]
- detail_labels = labels[detail_start:detail_end]
- detail_input_ids = input_ids[detail_start:detail_end]
-
- LOG.debug(
- f"Detail: '{detail_text}', Start: {detail_start}, End: {detail_end}"
- )
- LOG.debug(f"Detail input_ids: {detail_input_ids}")
- LOG.debug(f"Detail labels: {detail_labels}")
- LOG.debug(
- f"Decoded detail: {llama3_tokenizer.decode(detail_input_ids)}"
- )
- LOG.debug(
- f"Token offsets for this detail: {token_offsets[detail_start-start_idx:detail_end-start_idx]}"
- )
-
- if detail["train"]:
- assert all(
- label != IGNORE_TOKEN_ID for label in detail_labels
- ), (
- f"Expected labels for trainable detail '{detail_text}' to be set, but some were IGNORE_TOKEN_ID. "
- f"Labels({detail_start}:{detail_end}): {detail_labels}, "
- f"InputIDs: {detail_input_ids}, "
- f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'"
- )
- else:
- assert all(
- label == IGNORE_TOKEN_ID for label in detail_labels
- ), (
- f"Expected all labels for non-trainable detail '{detail_text}' to be IGNORE_TOKEN_ID, but some were not. "
- f"Labels({detail_start}:{detail_end}): {detail_labels}, "
- f"InputIDs: {detail_input_ids}, "
- f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'"
- )
- else:
- should_train = turn.get("train", False)
- turn_labels = labels[start_idx:end_idx]
-
- LOG.debug(f"Should train: {should_train}")
- LOG.debug(f"Turn indices: start={start_idx}, end={end_idx}")
- LOG.debug(f"Turn labels: {turn_labels}")
- LOG.debug(f"Turn input IDs: {input_ids[start_idx:end_idx]}")
- LOG.debug(
- f"Decoded turn: {llama3_tokenizer.decode(input_ids[start_idx:end_idx])}"
- )
-
- if should_train:
- assert all(label != IGNORE_TOKEN_ID for label in turn_labels), (
- f"Expected all labels for '{turn['value']}' to be set\n"
- f"Labels({start_idx}:{end_idx}): {turn_labels}, "
- f"InputIDs: {input_ids[start_idx:end_idx]}, "
- f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'"
- )
- else:
- assert all(label == IGNORE_TOKEN_ID for label in turn_labels), (
- f"Expected all labels for '{turn['value']}' to be IGNORE_TOKEN_ID\n"
- f"Labels({start_idx}:{end_idx}): {turn_labels}, "
- f"InputIDs: {input_ids[start_idx:end_idx]}, "
- f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'"
- )
-
- LOG.debug(
- f"Processed turn: {turn['from']}, content: '{turn['value']}', "
- f"start_idx: {start_idx}, end_idx: {end_idx}, "
- f"labels: {labels[start_idx:end_idx]}"
- )
-
- LOG.debug(f"Final labels: {labels}")
- LOG.debug(f"Final input_ids: {input_ids}")
-
-
class TestAssistantChatTemplateLlama3:
"""
Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
@@ -728,7 +73,7 @@ def test_llama3(self, llama3_tokenizer, assistant_dataset):
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
- chat_templates("llama3"),
+ chat_template=get_chat_template("llama3"),
message_field_role="role",
message_field_content="content",
roles={
@@ -740,7 +85,6 @@ def test_llama3(self, llama3_tokenizer, assistant_dataset):
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
- roles_to_train=["assistant"],
)
strategy.messages = "messages"
res = strategy.tokenize_prompt(assistant_dataset[0])
@@ -764,12 +108,70 @@ def test_llama3(self, llama3_tokenizer, assistant_dataset):
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
+ def test_phi35(self, phi35_tokenizer, assistant_dataset):
+ LOG.info("Testing phi-3.5 with assistant dataset")
+ strategy = ChatTemplateStrategy(
+ ChatTemplatePrompter(
+ phi35_tokenizer,
+ chat_template=get_chat_template("phi_35"),
+ message_field_role="role",
+ message_field_content="content",
+ roles={
+ "user": ["user"],
+ "assistant": ["assistant"],
+ "system": ["system"],
+ },
+ ),
+ tokenizer=phi35_tokenizer,
+ train_on_inputs=False,
+ sequence_len=512,
+ )
+ strategy.messages = "messages"
+ res = strategy.tokenize_prompt(assistant_dataset[0])
+ input_ids = res["input_ids"]
+ labels = res["labels"]
+ # fmt: off
+ expected_input_ids = [
+ 32010, # user
+ 22172, 32007, # user eot
+ 32001, # assistant
+ 22172, 32007, # assistant eot
+ 32010, # user
+ 1781, 26966, 32007, # user eot
+ 32001, # assistant
+ 1781, 26966, 32007, # assistant eot
+ 32000, # eos
+ ]
+ expected_labels = [
+ -100, # user
+ -100, -100, # user eot
+ -100, # assistant
+ -100, -100, # assistant eot,
+ -100, # user
+ -100, -100, -100, # user eot
+ -100, # assistant
+ 1781, 26966, 32007, # assistant eot
+ 32000, # eos
+ ]
+ # fmt: on
+ LOG.debug(f"Expected input_ids: {expected_input_ids}")
+ LOG.debug(f"Actual input_ids: {input_ids}")
+ assert (
+ input_ids == expected_input_ids
+ ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
+
+ LOG.debug(f"Expected labels : {expected_labels}")
+ LOG.debug(f"Actual labels : {labels}")
+ assert (
+ labels == expected_labels
+ ), f"Input IDs mismatch: {labels} != {expected_labels}"
+
def test_llama3_with_training_data(self, llama3_tokenizer, assistant_dataset):
LOG.info("Testing llama-3 with assistant dataset including training data")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
- chat_templates("llama3"),
+ chat_template=get_chat_template("llama3"),
message_field_role="role",
message_field_content="content",
message_field_training="training",
@@ -825,8 +227,11 @@ class TestSharegptChatTemplateLlama3:
def test_llama3_assistant(self, llama3_tokenizer, sharegpt_dataset):
LOG.info("Testing ShareGPT style datasets with llama-3 assistant prompts")
+ # pylint: disable=duplicate-code
strategy = ChatTemplateStrategy(
- ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
+ ChatTemplatePrompter(
+ llama3_tokenizer, chat_template=get_chat_template("llama3")
+ ),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
train_on_eos="none",
@@ -875,8 +280,11 @@ def test_llama3_assistant(self, llama3_tokenizer, sharegpt_dataset):
def test_llama3_human(self, llama3_tokenizer, sharegpt_dataset):
LOG.info("Testing ShareGPT style datasets with llama-3 human prompts")
+ # pylint: disable=duplicate-code
strategy = ChatTemplateStrategy(
- ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
+ ChatTemplatePrompter(
+ llama3_tokenizer, chat_template=get_chat_template("llama3")
+ ),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
train_on_eos="none",
@@ -925,8 +333,11 @@ def test_llama3_human(self, llama3_tokenizer, sharegpt_dataset):
def test_llama3_system_human(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing ShareGPT style datasets with llama-3 system/human prompts")
+ # pylint: disable=duplicate-code
strategy = ChatTemplateStrategy(
- ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
+ ChatTemplatePrompter(
+ llama3_tokenizer, chat_template=get_chat_template("llama3")
+ ),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
train_on_eos="none",
diff --git a/tests/prompt_strategies/test_chat_templates_advanced.py b/tests/prompt_strategies/test_chat_templates_advanced.py
new file mode 100644
index 0000000000..be8e3ccdf9
--- /dev/null
+++ b/tests/prompt_strategies/test_chat_templates_advanced.py
@@ -0,0 +1,637 @@
+"""
+tests for chat_template prompt strategy
+"""
+
+import logging
+import unittest
+
+from datasets import Dataset
+
+from axolotl.prompt_strategies.chat_template import (
+ ChatTemplatePrompter,
+ ChatTemplateStrategy,
+)
+from axolotl.prompters import IGNORE_TOKEN_ID
+from axolotl.utils.chat_templates import get_chat_template
+
+logging.basicConfig(level=logging.DEBUG)
+LOG = logging.getLogger("axolotl")
+
+
+class TestChatTemplateConfigurations:
+ """
+ Test class for various configurations of ChatTemplateStrategy.
+ """
+
+ @staticmethod
+ def find_sublist(full_list, sub_list):
+ token_count = len(sub_list)
+ for index in range(len(full_list) - token_count + 1):
+ if full_list[index : index + token_count] == sub_list:
+ return index
+ return -1
+
+ def test_train_on_inputs_true(self, llama3_tokenizer, basic_dataset):
+ LOG.info("Testing with train_on_inputs=True")
+ strategy = ChatTemplateStrategy(
+ ChatTemplatePrompter(
+ llama3_tokenizer, chat_template=get_chat_template("llama3")
+ ),
+ tokenizer=llama3_tokenizer,
+ train_on_inputs=True,
+ sequence_len=512,
+ roles_to_train=["assistant"],
+ )
+ res = strategy.tokenize_prompt(basic_dataset[0])
+ labels = res["labels"]
+ input_ids = res["input_ids"]
+
+ # Verify that assistant responses are labeled
+ assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
+ for response in assistant_responses:
+ response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
+ start_idx = self.find_sublist(input_ids, response_ids)
+ LOG.debug(
+ f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
+ )
+ assert start_idx != -1, f"Could not find '{response}' in input_ids"
+ assert all(
+ label != IGNORE_TOKEN_ID
+ for label in labels[start_idx : start_idx + len(response_ids)]
+ ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
+
+ # Check the behavior of human inputs
+ human_inputs = ["Hello", "How are you?"]
+ for input_text in human_inputs:
+ input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False)
+ start_idx = self.find_sublist(input_ids, input_ids)
+ labeled = all(
+ label != IGNORE_TOKEN_ID
+ for label in labels[start_idx : start_idx + len(input_ids)]
+ )
+ LOG.debug(
+ f"Human input '{input_text}' is {'labeled' if labeled else 'not labeled'}, expected IDs: {input_ids}, found at: {start_idx}"
+ )
+
+ LOG.debug("Full labels: %s", labels)
+ LOG.debug("Full input_ids: %s", input_ids)
+
+ def test_train_on_inputs_false(self, llama3_tokenizer, basic_dataset):
+ LOG.info("Testing with train_on_inputs=False")
+ strategy = ChatTemplateStrategy(
+ ChatTemplatePrompter(
+ llama3_tokenizer, chat_template=get_chat_template("llama3")
+ ),
+ tokenizer=llama3_tokenizer,
+ train_on_inputs=False,
+ sequence_len=512,
+ roles_to_train=["assistant"],
+ )
+ res = strategy.tokenize_prompt(basic_dataset[0])
+ labels = res["labels"]
+ input_ids = res["input_ids"]
+
+ # Verify that only assistant responses are labeled
+ assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
+ for response in assistant_responses:
+ response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
+ start_idx = self.find_sublist(input_ids, response_ids)
+ LOG.debug(
+ f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
+ )
+ assert start_idx != -1, f"Could not find '{response}' in input_ids"
+ assert all(
+ label != IGNORE_TOKEN_ID
+ for label in labels[start_idx : start_idx + len(response_ids)]
+ ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
+
+ # Verify that human inputs are not labeled
+ human_inputs = ["Hello", "How are you?"]
+ for input_text in human_inputs:
+ input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False)
+ start_idx = self.find_sublist(input_ids, input_ids)
+ LOG.debug(
+ f"Human input '{input_text}' expected IDs: {input_ids}, found at: {start_idx}"
+ )
+ assert start_idx != -1, f"Could not find '{input_text}' in input_ids"
+ assert all(
+ label == IGNORE_TOKEN_ID
+ for label in labels[start_idx : start_idx + len(input_ids)]
+ ), f"Expected labels for human input '{input_text}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:start_idx+len(input_ids)]}"
+
+ def test_roles_to_train_assistant_only(self, llama3_tokenizer, basic_dataset):
+ LOG.info("Testing roles_to_train with assistant only")
+ strategy = ChatTemplateStrategy(
+ ChatTemplatePrompter(
+ llama3_tokenizer, chat_template=get_chat_template("llama3")
+ ),
+ tokenizer=llama3_tokenizer,
+ train_on_inputs=False,
+ sequence_len=512,
+ roles_to_train=["assistant"],
+ )
+ res = strategy.tokenize_prompt(basic_dataset[0])
+ labels = res["labels"]
+ input_ids = res["input_ids"]
+
+ # Verify that only assistant responses are labeled
+ assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
+ for response in assistant_responses:
+ response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
+ start_idx = self.find_sublist(input_ids, response_ids)
+ LOG.debug(
+ f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
+ )
+ assert all(
+ label != IGNORE_TOKEN_ID
+ for label in labels[start_idx : start_idx + len(response_ids)]
+ ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
+
+ def test_roles_to_train_all(self, llama3_tokenizer, basic_dataset):
+ LOG.info("Testing roles_to_train with all roles")
+ strategy = ChatTemplateStrategy(
+ ChatTemplatePrompter(
+ llama3_tokenizer, chat_template=get_chat_template("llama3")
+ ),
+ tokenizer=llama3_tokenizer,
+ train_on_inputs=True,
+ sequence_len=512,
+ roles_to_train=["human", "assistant"],
+ )
+ res = strategy.tokenize_prompt(basic_dataset[0])
+ labels = res["labels"]
+ input_ids = res["input_ids"]
+
+ # Verify that all responses are labeled (except for special tokens)
+ all_responses = [
+ "Hello",
+ "Hi there!",
+ "How are you?",
+ "I'm doing well, thank you!",
+ ]
+ for response in all_responses:
+ response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
+ start_idx = self.find_sublist(input_ids, response_ids)
+ LOG.debug(
+ f"Response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
+ )
+ assert all(
+ label != IGNORE_TOKEN_ID
+ for label in labels[start_idx : start_idx + len(response_ids)]
+ ), f"Expected labels for response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
+
+ def test_empty_roles_to_train(self, llama3_tokenizer, basic_dataset):
+ LOG.info("Testing with empty roles_to_train")
+ strategy = ChatTemplateStrategy(
+ ChatTemplatePrompter(
+ llama3_tokenizer, chat_template=get_chat_template("llama3")
+ ),
+ tokenizer=llama3_tokenizer,
+ train_on_inputs=False,
+ sequence_len=512,
+ roles_to_train=[],
+ train_on_eos="none", # Add this line
+ )
+ res = strategy.tokenize_prompt(basic_dataset[0])
+ labels = res["labels"]
+
+ # Verify that no labels are set when roles_to_train is empty
+ LOG.debug("Full labels: %s", labels)
+ assert all(
+ label == IGNORE_TOKEN_ID for label in labels
+ ), "Expected all labels to be IGNORE_TOKEN_ID when roles_to_train is empty"
+
+ def test_train_on_eos_all(self, llama3_tokenizer, basic_dataset):
+ LOG.info("Testing with train_on_eos='all'")
+ strategy = ChatTemplateStrategy(
+ ChatTemplatePrompter(
+ llama3_tokenizer, chat_template=get_chat_template("llama3")
+ ),
+ tokenizer=llama3_tokenizer,
+ train_on_inputs=False,
+ sequence_len=512,
+ roles_to_train=["assistant"],
+ train_on_eos="all",
+ )
+ res = strategy.tokenize_prompt(basic_dataset[0])
+ labels = res["labels"]
+ input_ids = res["input_ids"]
+
+ eos_token_id = llama3_tokenizer.eos_token_id
+ eos_indices = [
+ i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
+ ]
+
+ assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
+ for eos_idx in eos_indices:
+ assert (
+ labels[eos_idx] != IGNORE_TOKEN_ID
+ ), f"Expected EOS token at index {eos_idx} to be labeled"
+
+ def test_train_on_eos_turn(self, llama3_tokenizer, basic_dataset):
+ LOG.info("Testing with train_on_eos='turn'")
+ strategy = ChatTemplateStrategy(
+ ChatTemplatePrompter(
+ llama3_tokenizer, chat_template=get_chat_template("llama3")
+ ),
+ tokenizer=llama3_tokenizer,
+ train_on_inputs=False,
+ sequence_len=512,
+ roles_to_train=["assistant"],
+ train_on_eos="turn",
+ )
+ res = strategy.tokenize_prompt(basic_dataset[0])
+ labels = res["labels"]
+ input_ids = res["input_ids"]
+
+ eos_token_id = llama3_tokenizer.eos_token_id
+ assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
+
+ for response in assistant_responses:
+ response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
+ start_idx = self.find_sublist(input_ids, response_ids)
+ assert start_idx != -1, f"Could not find '{response}' in input_ids"
+
+ eos_idx = start_idx + len(response_ids)
+ while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id:
+ eos_idx += 1
+
+ assert eos_idx < len(
+ input_ids
+ ), f"Could not find EOS token after '{response}'"
+ assert (
+ labels[eos_idx] != IGNORE_TOKEN_ID
+ ), f"Expected EOS token after assistant response '{response}' to be labeled"
+
+ # Check that EOS tokens after human inputs are not labeled
+ human_inputs = ["Hello", "How are you?"]
+ for input_text in human_inputs:
+ input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False)
+ start_idx = self.find_sublist(input_ids, input_ids)
+ assert start_idx != -1, f"Could not find '{input_text}' in input_ids"
+
+ eos_idx = start_idx + len(input_ids)
+ while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id:
+ eos_idx += 1
+
+ assert (
+ labels[eos_idx] == IGNORE_TOKEN_ID
+ ), f"Expected EOS token after human input '{input_text}' to not be labeled"
+
+ def test_train_on_eos_last(self, llama3_tokenizer, basic_dataset):
+ LOG.info("Testing with train_on_eos='last'")
+ strategy = ChatTemplateStrategy(
+ ChatTemplatePrompter(
+ llama3_tokenizer, chat_template=get_chat_template("llama3")
+ ),
+ tokenizer=llama3_tokenizer,
+ train_on_inputs=False,
+ sequence_len=512,
+ roles_to_train=["assistant"],
+ train_on_eos="last",
+ )
+ res = strategy.tokenize_prompt(basic_dataset[0])
+ labels = res["labels"]
+ input_ids = res["input_ids"]
+
+ eos_token_id = llama3_tokenizer.eos_token_id
+ eos_indices = [
+ i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
+ ]
+
+ assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
+ last_eos_idx = eos_indices[-1]
+
+ # Check that only the last EOS token is labeled
+ for idx in eos_indices[:-1]:
+ assert (
+ labels[idx] == IGNORE_TOKEN_ID
+ ), f"Expected EOS token at index {idx} to not be labeled"
+ assert (
+ labels[last_eos_idx] != IGNORE_TOKEN_ID
+ ), f"Expected last EOS token at index {last_eos_idx} to be labeled"
+
+ def test_train_on_eos_none(self, llama3_tokenizer, basic_dataset):
+ LOG.info("Testing with train_on_eos='none'")
+ strategy = ChatTemplateStrategy(
+ ChatTemplatePrompter(
+ llama3_tokenizer, chat_template=get_chat_template("llama3")
+ ),
+ tokenizer=llama3_tokenizer,
+ train_on_inputs=False,
+ sequence_len=512,
+ roles_to_train=["assistant"],
+ train_on_eos="none",
+ )
+ res = strategy.tokenize_prompt(basic_dataset[0])
+ labels = res["labels"]
+ input_ids = res["input_ids"]
+
+ eos_token_id = llama3_tokenizer.eos_token_id
+ eos_indices = [
+ i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
+ ]
+
+ assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
+ for eos_idx in eos_indices:
+ assert (
+ labels[eos_idx] == IGNORE_TOKEN_ID
+ ), f"Expected EOS token at index {eos_idx} to not be labeled"
+
+ def test_drop_system_message(self, llama3_tokenizer, basic_dataset):
+ LOG.info("Testing with drop_system_message=True")
+ strategy = ChatTemplateStrategy(
+ ChatTemplatePrompter(
+ llama3_tokenizer,
+ chat_template=get_chat_template("llama3"),
+ drop_system_message=True,
+ ),
+ tokenizer=llama3_tokenizer,
+ train_on_inputs=False,
+ sequence_len=512,
+ roles_to_train=["assistant"],
+ )
+ res = strategy.tokenize_prompt(basic_dataset[0])
+ input_ids = res["input_ids"]
+
+ # Check if system message is not present in input_ids
+ system_message = "You are an AI assistant."
+ system_ids = llama3_tokenizer.encode(system_message, add_special_tokens=False)
+ assert (
+ self.find_sublist(input_ids, system_ids) == -1
+ ), "Expected system message to be dropped"
+
+ def test_custom_roles(self, llama3_tokenizer):
+ LOG.info("Testing with custom roles mapping")
+ custom_roles = {
+ "user": ["human", "user"],
+ "assistant": ["ai", "assistant"],
+ "system": ["context"],
+ }
+ strategy = ChatTemplateStrategy(
+ ChatTemplatePrompter(
+ llama3_tokenizer,
+ chat_template=get_chat_template("llama3"),
+ roles=custom_roles,
+ ),
+ tokenizer=llama3_tokenizer,
+ train_on_inputs=False,
+ sequence_len=512,
+ roles_to_train=["ai"],
+ )
+
+ # Create a new dataset with modified role names
+ modified_conversations = [
+ {"from": "context", "value": "You are an AI assistant."},
+ {"from": "human", "value": "Hello"},
+ {"from": "ai", "value": "Hi there!"},
+ {"from": "human", "value": "How are you?"},
+ {"from": "ai", "value": "I'm doing well, thank you!"},
+ ]
+
+ modified_dataset = Dataset.from_dict(
+ {"conversations": [modified_conversations]}
+ )
+
+ res = strategy.tokenize_prompt(modified_dataset[0])
+ labels = res["labels"]
+ input_ids = res["input_ids"]
+
+ # Check if AI responses are labeled correctly
+ ai_responses = ["Hi there!", "I'm doing well, thank you!"]
+ for response in ai_responses:
+ response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
+ start_idx = self.find_sublist(input_ids, response_ids)
+ assert start_idx != -1, f"Could not find response '{response}' in input_ids"
+ assert all(
+ label != IGNORE_TOKEN_ID
+ for label in labels[start_idx : start_idx + len(response_ids)]
+ ), f"Expected labels for AI response '{response}' to be set"
+
+ # Check if human messages are not labeled
+ human_messages = ["Hello", "How are you?"]
+ for message in human_messages:
+ message_ids = llama3_tokenizer.encode(message, add_special_tokens=False)
+ start_idx = self.find_sublist(input_ids, message_ids)
+ assert start_idx != -1, f"Could not find message '{message}' in input_ids"
+ assert all(
+ label == IGNORE_TOKEN_ID
+ for label in labels[start_idx : start_idx + len(message_ids)]
+ ), f"Expected labels for human message '{message}' to be IGNORE_TOKEN_ID"
+
+ def test_message_field_training(self, llama3_tokenizer):
+ LOG.info("Testing with message_field_training")
+ strategy = ChatTemplateStrategy(
+ ChatTemplatePrompter(
+ llama3_tokenizer,
+ chat_template=get_chat_template("llama3"),
+ message_field_training="train",
+ message_field_training_detail="train_detail",
+ ),
+ tokenizer=llama3_tokenizer,
+ train_on_inputs=False,
+ sequence_len=512,
+ roles_to_train=[],
+ )
+
+ # Create a new dataset with the train and train_detail fields
+ modified_conversation = [
+ {"from": "system", "value": "You are an AI assistant.", "train": False},
+ {"from": "human", "value": "Hello", "train": False},
+ {"from": "assistant", "value": "Hello", "train": True},
+ {"from": "human", "value": "How are you?", "train": True},
+ {
+ "from": "assistant",
+ "value": "I'm doing very well, thank you!",
+ "train_detail": [
+ {"begin_offset": 0, "end_offset": 8, "train": False},
+ {"begin_offset": 9, "end_offset": 18, "train": True},
+ {"begin_offset": 19, "end_offset": 30, "train": False},
+ ],
+ },
+ {
+ "from": "human",
+ "value": "I'm doing very well, thank you!",
+ "train": False,
+ },
+ {"from": "assistant", "value": "Hi there!", "train": True},
+ ]
+
+ modified_dataset = Dataset.from_dict({"conversations": [modified_conversation]})
+
+ res = strategy.tokenize_prompt(modified_dataset[0])
+ labels = res["labels"]
+ input_ids = res["input_ids"]
+
+ # Function to find all occurrences of a sublist
+ def find_all_sublists(full_list, sub_list):
+ indices = []
+ for index in range(len(full_list) - len(sub_list) + 1):
+ if full_list[index : index + len(sub_list)] == sub_list:
+ indices.append(index)
+ return indices
+
+ # Keep track of which occurrences we've processed
+ processed_occurrences = {}
+ # Check if messages are labeled correctly based on train or train_detail
+ for i, turn in enumerate(modified_conversation):
+ turn_tokens = llama3_tokenizer.encode(
+ turn["value"], add_special_tokens=False
+ )
+ occurrences = find_all_sublists(input_ids, turn_tokens)
+ turn_key = turn["value"]
+ if turn_key not in processed_occurrences:
+ processed_occurrences[turn_key] = 0
+ current_occurrence = processed_occurrences[turn_key]
+
+ if current_occurrence >= len(occurrences):
+ assert (
+ False
+ ), f"Not enough occurrences found for message: {turn['value']}"
+
+ start_idx = occurrences[current_occurrence]
+ processed_occurrences[turn_key] += 1
+ end_idx = start_idx + len(turn_tokens)
+
+ LOG.debug(
+ f"Processing turn {i}: role={turn['from']}, content='{turn['value']}', start_idx={start_idx}, end_idx={end_idx}"
+ )
+
+ if "train_detail" in turn:
+ # Get token offsets
+ tokenized_output = llama3_tokenizer(
+ turn["value"], return_offsets_mapping=True, add_special_tokens=False
+ )
+ token_offsets = tokenized_output["offset_mapping"]
+
+ # Adjust token offsets as done in the implementation
+ for i in range(len(token_offsets) - 1):
+ token_offsets[i] = (
+ token_offsets[i][0],
+ token_offsets[i + 1][0] - 1,
+ )
+ token_offsets[-1] = (token_offsets[-1][0], len(turn["value"]) - 1)
+
+ # Adjust train_details
+ adjusted_train_details = strategy.prompter.adjust_train_details(
+ turn["train_detail"], token_offsets
+ )
+
+ LOG.debug(f"Original train_details: {turn['train_detail']}")
+ LOG.debug(f"Adjusted train_details: {adjusted_train_details}")
+
+ # Handle train_detail
+ token_offsets = strategy.prompter.get_offsets_for_train_detail(
+ text=turn["value"],
+ train_details=adjusted_train_details,
+ mask_untrainable=False,
+ )
+ token_offsets_masked = strategy.prompter.get_offsets_for_train_detail(
+ text=turn["value"],
+ train_details=adjusted_train_details,
+ mask_untrainable=True,
+ )
+ LOG.debug(f"Token offsets: {token_offsets_masked}")
+
+ expected_labels = [IGNORE_TOKEN_ID] * len(turn_tokens)
+ for i, offset in enumerate(token_offsets_masked):
+ if offset != IGNORE_TOKEN_ID:
+ expected_labels[i] = turn_tokens[i]
+ actual_labels = labels[
+ start_idx : start_idx + len(token_offsets_masked)
+ ]
+ assert (
+ actual_labels == expected_labels
+ ), f"Labels mismatch for turn: {turn['value']}\nExpected: {expected_labels}\nActual: {actual_labels}"
+
+ for detail in adjusted_train_details:
+ # Find the token indices that correspond to the character offsets
+ detail_start = start_idx + next(
+ i
+ for i, offset in enumerate(token_offsets)
+ if offset >= detail["begin_offset"]
+ )
+ detail_end = start_idx + next(
+ (
+ i
+ for i, offset in enumerate(token_offsets)
+ if offset > detail["end_offset"]
+ ),
+ len(token_offsets),
+ )
+
+ detail_text = turn["value"][
+ detail["begin_offset"] : detail["end_offset"] + 1
+ ]
+ detail_labels = labels[detail_start:detail_end]
+ detail_input_ids = input_ids[detail_start:detail_end]
+
+ LOG.debug(
+ f"Detail: '{detail_text}', Start: {detail_start}, End: {detail_end}"
+ )
+ LOG.debug(f"Detail input_ids: {detail_input_ids}")
+ LOG.debug(f"Detail labels: {detail_labels}")
+ LOG.debug(
+ f"Decoded detail: {llama3_tokenizer.decode(detail_input_ids)}"
+ )
+ LOG.debug(
+ f"Token offsets for this detail: {token_offsets[detail_start-start_idx:detail_end-start_idx]}"
+ )
+
+ if detail["train"]:
+ assert all(
+ label != IGNORE_TOKEN_ID for label in detail_labels
+ ), (
+ f"Expected labels for trainable detail '{detail_text}' to be set, but some were IGNORE_TOKEN_ID. "
+ f"Labels({detail_start}:{detail_end}): {detail_labels}, "
+ f"InputIDs: {detail_input_ids}, "
+ f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'"
+ )
+ else:
+ assert all(
+ label == IGNORE_TOKEN_ID for label in detail_labels
+ ), (
+ f"Expected all labels for non-trainable detail '{detail_text}' to be IGNORE_TOKEN_ID, but some were not. "
+ f"Labels({detail_start}:{detail_end}): {detail_labels}, "
+ f"InputIDs: {detail_input_ids}, "
+ f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'"
+ )
+ else:
+ should_train = turn.get("train", False)
+ turn_labels = labels[start_idx:end_idx]
+
+ LOG.debug(f"Should train: {should_train}")
+ LOG.debug(f"Turn indices: start={start_idx}, end={end_idx}")
+ LOG.debug(f"Turn labels: {turn_labels}")
+ LOG.debug(f"Turn input IDs: {input_ids[start_idx:end_idx]}")
+ LOG.debug(
+ f"Decoded turn: {llama3_tokenizer.decode(input_ids[start_idx:end_idx])}"
+ )
+
+ if should_train:
+ assert all(label != IGNORE_TOKEN_ID for label in turn_labels), (
+ f"Expected all labels for '{turn['value']}' to be set\n"
+ f"Labels({start_idx}:{end_idx}): {turn_labels}, "
+ f"InputIDs: {input_ids[start_idx:end_idx]}, "
+ f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'"
+ )
+ else:
+ assert all(label == IGNORE_TOKEN_ID for label in turn_labels), (
+ f"Expected all labels for '{turn['value']}' to be IGNORE_TOKEN_ID\n"
+ f"Labels({start_idx}:{end_idx}): {turn_labels}, "
+ f"InputIDs: {input_ids[start_idx:end_idx]}, "
+ f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'"
+ )
+
+ LOG.debug(
+ f"Processed turn: {turn['from']}, content: '{turn['value']}', "
+ f"start_idx: {start_idx}, end_idx: {end_idx}, "
+ f"labels: {labels[start_idx:end_idx]}"
+ )
+
+ LOG.debug(f"Final labels: {labels}")
+ LOG.debug(f"Final input_ids: {input_ids}")
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/prompt_strategies/test_dpo_chat_templates.py b/tests/prompt_strategies/test_dpo_chat_templates.py
index cca48b1cf3..740edc22f2 100644
--- a/tests/prompt_strategies/test_dpo_chat_templates.py
+++ b/tests/prompt_strategies/test_dpo_chat_templates.py
@@ -86,6 +86,20 @@ def fixture_llama3_tokenizer():
return tokenizer
+@pytest.fixture(name="phi3_tokenizer")
+def fixture_phi3_tokenizer():
+ tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-medium-128k-instruct")
+
+ return tokenizer
+
+
+@pytest.fixture(name="gemma_tokenizer")
+def fixture_gemma_tokenizer():
+ tokenizer = AutoTokenizer.from_pretrained("unsloth/gemma-2b-it", revision="703fb4a")
+
+ return tokenizer
+
+
class TestAssistantDPOChatTemplateLlama3:
"""
Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
@@ -99,7 +113,7 @@ def test_llama3_defaults(self, llama3_tokenizer, assistant_dataset):
"chat_template": "llama3",
"datasets": [
{
- "chat_template": "llama3",
+ "type": "chat_template",
}
],
}
@@ -124,7 +138,7 @@ def test_llama3_configured(self, llama3_tokenizer, custom_assistant_dataset):
"chat_template": "llama3",
"datasets": [
{
- "chat_template": "llama3",
+ "type": "chat_template",
"field_messages": "conversation",
"field_chosen": "better",
"field_rejected": "worse",
@@ -152,5 +166,65 @@ def test_llama3_configured(self, llama3_tokenizer, custom_assistant_dataset):
assert result["rejected"] == "party on<|eot_id|>"
+class TestAssistantDPOChatTemplatePhi3:
+ """
+ Test class for assistant style datasets with phi-3 prompts using the tokenizer's chat_template strategy.
+ """
+
+ def test_phi3_defaults(self, phi3_tokenizer, assistant_dataset):
+ # pylint: disable=duplicate-code
+ transform_fn = default(
+ DictDefault(
+ {
+ "chat_template": "tokenizer_default",
+ "datasets": [
+ {
+ "type": "chat_template",
+ }
+ ],
+ }
+ )
+ )
+ result = transform_fn(assistant_dataset[0], tokenizer=phi3_tokenizer)
+ assert result["prompt"] == (
+ "<|user|>\nhello<|end|>\n"
+ + "<|assistant|>\nhello<|end|>\n"
+ + "<|user|>\ngoodbye<|end|>\n"
+ + "<|assistant|>\n"
+ )
+ assert result["chosen"] == "goodbye<|end|>"
+ assert result["rejected"] == "party on<|end|>"
+
+
+class TestAssistantDPOChatTemplateGemma:
+ """
+ Test class for assistant style datasets with gemma prompts using the tokenizer's chat_template strategy.
+ """
+
+ def test_gemma_defaults(self, gemma_tokenizer, assistant_dataset):
+ # pylint: disable=duplicate-code
+ transform_fn = default(
+ DictDefault(
+ {
+ "chat_template": "tokenizer_default",
+ "datasets": [
+ {
+ "type": "chat_template",
+ }
+ ],
+ }
+ )
+ )
+ result = transform_fn(assistant_dataset[0], tokenizer=gemma_tokenizer)
+ assert result["prompt"] == (
+ "user\nhello\n"
+ + "model\nhello\n"
+ + "user\ngoodbye\n"
+ + "model\n"
+ )
+ assert result["chosen"] == "goodbye"
+ assert result["rejected"] == "party on"
+
+
if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_data.py b/tests/test_data.py
index 16af089a06..9d7f5a0412 100644
--- a/tests/test_data.py
+++ b/tests/test_data.py
@@ -35,7 +35,7 @@ def test_encode_pretraining(self):
"hello, hello",
]
}
- result = encode_pretraining(self.tokenizer, self.max_tokens, examples["text"])
+ result = encode_pretraining(self.tokenizer, self.max_tokens, examples)
self.assertEqual(len(result["input_ids"]), 3)
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index a274b7b894..f8b463a03e 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -12,6 +12,7 @@
from transformers import AutoTokenizer
from axolotl.utils.data import load_tokenized_prepared_datasets
+from axolotl.utils.data.rl import load_prepare_dpo_datasets
from axolotl.utils.dict import DictDefault
@@ -267,6 +268,143 @@ def test_load_from_single_json(self):
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
+ def test_load_hub_with_dpo(self):
+ """Verify that processing dpo data from the hub works"""
+
+ cfg = DictDefault(
+ {
+ "tokenizer_config": "huggyllama/llama-7b",
+ "sequence_len": 1024,
+ "rl": "dpo",
+ "chat_template": "llama3",
+ "datasets": [
+ {
+ "path": "fozziethebeat/alpaca_messages_2k_dpo_test",
+ "type": "chat_template.default",
+ "chat_template": "llama3",
+ "field_messages": "conversation",
+ "field_chosen": "chosen",
+ "field_rejected": "rejected",
+ "message_field_role": "role",
+ "message_field_content": "content",
+ "roles": {
+ "system": ["system"],
+ "user": ["user"],
+ "assistant": ["assistant"],
+ },
+ }
+ ],
+ }
+ )
+
+ train_dataset, _ = load_prepare_dpo_datasets(cfg)
+
+ assert len(train_dataset) == 1800
+ assert "conversation" in train_dataset.features
+
+ def test_load_hub_with_revision(self):
+ """Verify that processing data from the hub works with a specific revision"""
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ prepared_path = Path(tmp_dir) / "prepared"
+ cfg = DictDefault(
+ {
+ "tokenizer_config": "huggyllama/llama-7b",
+ "sequence_len": 1024,
+ "datasets": [
+ {
+ "path": "mhenrichsen/alpaca_2k_test",
+ "type": "alpaca",
+ "revision": "d05c1cb",
+ },
+ ],
+ }
+ )
+
+ dataset, _ = load_tokenized_prepared_datasets(
+ self.tokenizer, cfg, prepared_path
+ )
+
+ assert len(dataset) == 2000
+ assert "input_ids" in dataset.features
+ assert "attention_mask" in dataset.features
+ assert "labels" in dataset.features
+
+ def test_load_hub_with_revision_with_dpo(self):
+ """Verify that processing dpo data from the hub works with a specific revision"""
+
+ cfg = DictDefault(
+ {
+ "tokenizer_config": "huggyllama/llama-7b",
+ "sequence_len": 1024,
+ "rl": "dpo",
+ "chat_template": "llama3",
+ "datasets": [
+ {
+ "path": "fozziethebeat/alpaca_messages_2k_dpo_test",
+ "type": "chat_template.default",
+ "chat_template": "llama3",
+ "revision": "ea82cff",
+ "field_messages": "conversation",
+ "field_chosen": "chosen",
+ "field_rejected": "rejected",
+ "message_field_role": "role",
+ "message_field_content": "content",
+ "roles": {
+ "system": ["system"],
+ "user": ["user"],
+ "assistant": ["assistant"],
+ },
+ }
+ ],
+ }
+ )
+
+ train_dataset, _ = load_prepare_dpo_datasets(cfg)
+
+ assert len(train_dataset) == 1800
+ assert "conversation" in train_dataset.features
+
+ def test_load_local_hub_with_revision(self):
+ """Verify that a local copy of a hub dataset can be loaded with a specific revision"""
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ tmp_ds_path = Path("mhenrichsen/alpaca_2k_test")
+ tmp_ds_path.mkdir(parents=True, exist_ok=True)
+ snapshot_download(
+ repo_id="mhenrichsen/alpaca_2k_test",
+ repo_type="dataset",
+ local_dir=tmp_ds_path,
+ revision="d05c1cb",
+ )
+
+ prepared_path = Path(tmp_dir) / "prepared"
+ cfg = DictDefault(
+ {
+ "tokenizer_config": "huggyllama/llama-7b",
+ "sequence_len": 1024,
+ "datasets": [
+ {
+ "path": "mhenrichsen/alpaca_2k_test",
+ "ds_type": "parquet",
+ "type": "alpaca",
+ "data_files": [
+ "mhenrichsen/alpaca_2k_test/alpaca_2000.parquet",
+ ],
+ "revision": "d05c1cb",
+ },
+ ],
+ }
+ )
+
+ dataset, _ = load_tokenized_prepared_datasets(
+ self.tokenizer, cfg, prepared_path
+ )
+
+ assert len(dataset) == 2000
+ assert "input_ids" in dataset.features
+ assert "attention_mask" in dataset.features
+ assert "labels" in dataset.features
+ shutil.rmtree(tmp_ds_path)
+
if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_prompters.py b/tests/test_prompters.py
index 6c5b8f27c2..3d61398e04 100644
--- a/tests/test_prompters.py
+++ b/tests/test_prompters.py
@@ -42,6 +42,19 @@ def test_prompt_style_w_instruct(self):
assert "USER:" not in res
assert "ASSISTANT:" not in res
+ def test_prompt_style_w_phi(self):
+ prompter = AlpacaPrompter(prompt_style=PromptStyle.PHI.value)
+ res = next(prompter.build_prompt("tell me a joke about the following"))
+ assert (
+ """<|system|>
+Below is an instruction that describes a task. Write a response that appropriately completes the request.<|end|>
+<|user|>
+tell me a joke about the following<|end|>
+<|assistant|>
+"""
+ == res
+ )
+
def test_prompt_style_w_chat(self):
prompter = AlpacaPrompter(prompt_style=PromptStyle.CHAT.value)
res = next(
diff --git a/tests/test_validation.py b/tests/test_validation.py
index 35d0e265e7..fb63977f5c 100644
--- a/tests/test_validation.py
+++ b/tests/test_validation.py
@@ -9,9 +9,11 @@
import pytest
from pydantic import ValidationError
+from axolotl.utils import is_comet_available
from axolotl.utils.config import validate_config
from axolotl.utils.config.models.input.v0_4_1 import AxolotlConfigWCapabilities
from axolotl.utils.dict import DictDefault
+from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.models import check_model_config
from axolotl.utils.wandb_ import setup_wandb_env_vars
@@ -1329,3 +1331,160 @@ def test_wandb_set_disabled(self, minimal_cfg):
os.environ.pop("WANDB_PROJECT", None)
os.environ.pop("WANDB_DISABLED", None)
+
+
+@pytest.mark.skipif(is_comet_available() is False, reason="comet_ml is not installed")
+class TestValidationComet(BaseValidation):
+ """
+ Validation test for comet
+ """
+
+ def test_comet_sets_env(self, minimal_cfg):
+ from axolotl.utils.comet_ import setup_comet_env_vars
+
+ comet_config = {
+ "comet_api_key": "foo",
+ "comet_workspace": "some_workspace",
+ "comet_project_name": "some_project",
+ "comet_experiment_key": "some_experiment_key",
+ "comet_mode": "get_or_create",
+ "comet_online": False,
+ "comet_experiment_config": {
+ "auto_histogram_activation_logging": False,
+ "auto_histogram_epoch_rate": 2,
+ "auto_histogram_gradient_logging": True,
+ "auto_histogram_tensorboard_logging": False,
+ "auto_histogram_weight_logging": True,
+ "auto_log_co2": False,
+ "auto_metric_logging": True,
+ "auto_metric_step_rate": 15,
+ "auto_output_logging": False,
+ "auto_param_logging": True,
+ "comet_disabled": False,
+ "display_summary_level": 2,
+ "distributed_node_identifier": "some_distributed_node_identifier",
+ "log_code": True,
+ "log_env_cpu": False,
+ "log_env_details": True,
+ "log_env_disk": False,
+ "log_env_gpu": True,
+ "log_env_host": False,
+ "log_env_network": True,
+ "log_git_metadata": False,
+ "log_git_patch": True,
+ "log_graph": False,
+ "name": "some_name",
+ "offline_directory": "some_offline_directory",
+ "parse_args": True,
+ "tags": ["tag1", "tag2"],
+ },
+ }
+
+ cfg = DictDefault(comet_config) | minimal_cfg
+
+ new_cfg = validate_config(cfg)
+
+ setup_comet_env_vars(new_cfg)
+
+ comet_env = {
+ key: value for key, value in os.environ.items() if key.startswith("COMET_")
+ }
+
+ assert (
+ len(comet_env)
+ == len(comet_config) + len(comet_config["comet_experiment_config"]) - 1
+ )
+
+ assert comet_env == {
+ "COMET_API_KEY": "foo",
+ "COMET_AUTO_LOG_CLI_ARGUMENTS": "true",
+ "COMET_AUTO_LOG_CO2": "false",
+ "COMET_AUTO_LOG_CODE": "true",
+ "COMET_AUTO_LOG_DISABLE": "false",
+ "COMET_AUTO_LOG_ENV_CPU": "false",
+ "COMET_AUTO_LOG_ENV_DETAILS": "true",
+ "COMET_AUTO_LOG_ENV_DISK": "false",
+ "COMET_AUTO_LOG_ENV_GPU": "true",
+ "COMET_AUTO_LOG_ENV_HOST": "false",
+ "COMET_AUTO_LOG_ENV_NETWORK": "true",
+ "COMET_AUTO_LOG_GIT_METADATA": "false",
+ "COMET_AUTO_LOG_GIT_PATCH": "true",
+ "COMET_AUTO_LOG_GRAPH": "false",
+ "COMET_AUTO_LOG_HISTOGRAM_ACTIVATIONS": "false",
+ "COMET_AUTO_LOG_HISTOGRAM_EPOCH_RATE": "2",
+ "COMET_AUTO_LOG_HISTOGRAM_GRADIENTS": "true",
+ "COMET_AUTO_LOG_HISTOGRAM_TENSORBOARD": "false",
+ "COMET_AUTO_LOG_HISTOGRAM_WEIGHTS": "true",
+ "COMET_AUTO_LOG_METRIC_STEP_RATE": "15",
+ "COMET_AUTO_LOG_METRICS": "true",
+ "COMET_AUTO_LOG_OUTPUT_LOGGER": "false",
+ "COMET_AUTO_LOG_PARAMETERS": "true",
+ "COMET_DISPLAY_SUMMARY_LEVEL": "2",
+ "COMET_DISTRIBUTED_NODE_IDENTIFIER": "some_distributed_node_identifier",
+ "COMET_EXPERIMENT_KEY": "some_experiment_key",
+ "COMET_OFFLINE_DIRECTORY": "some_offline_directory",
+ "COMET_PROJECT_NAME": "some_project",
+ "COMET_START_EXPERIMENT_NAME": "some_name",
+ "COMET_START_EXPERIMENT_TAGS": "tag1,tag2",
+ "COMET_START_MODE": "get_or_create",
+ "COMET_START_ONLINE": "false",
+ "COMET_WORKSPACE": "some_workspace",
+ }
+
+ for key in comet_env.keys():
+ os.environ.pop(key, None)
+
+
+class TestValidationMLflow(BaseValidation):
+ """
+ Validation test for MLflow
+ """
+
+ def test_hf_mlflow_artifacts_config_sets_env(self, minimal_cfg):
+ cfg = (
+ DictDefault(
+ {
+ "hf_mlflow_log_artifacts": True,
+ }
+ )
+ | minimal_cfg
+ )
+
+ new_cfg = validate_config(cfg)
+
+ assert new_cfg.hf_mlflow_log_artifacts is True
+
+ # Check it's not already present in env
+ assert "HF_MLFLOW_LOG_ARTIFACTS" not in os.environ
+
+ setup_mlflow_env_vars(new_cfg)
+
+ assert os.environ.get("HF_MLFLOW_LOG_ARTIFACTS") == "true"
+
+ os.environ.pop("HF_MLFLOW_LOG_ARTIFACTS", None)
+
+ def test_mlflow_not_used_by_default(self, minimal_cfg):
+ cfg = DictDefault({}) | minimal_cfg
+
+ new_cfg = validate_config(cfg)
+
+ setup_mlflow_env_vars(new_cfg)
+
+ assert cfg.use_mlflow is not True
+
+ cfg = (
+ DictDefault(
+ {
+ "mlflow_experiment_name": "foo",
+ }
+ )
+ | minimal_cfg
+ )
+
+ new_cfg = validate_config(cfg)
+
+ setup_mlflow_env_vars(new_cfg)
+
+ assert new_cfg.use_mlflow is True
+
+ os.environ.pop("MLFLOW_EXPERIMENT_NAME", None)
diff --git a/tests/test_validation_dataset.py b/tests/test_validation_dataset.py
new file mode 100644
index 0000000000..389424217b
--- /dev/null
+++ b/tests/test_validation_dataset.py
@@ -0,0 +1,238 @@
+"""Module for testing the validation module for the dataset config"""
+
+import warnings
+from typing import Optional
+
+import pytest
+
+from axolotl.utils.config import validate_config
+from axolotl.utils.config.models.input.v0_4_1 import ChatTemplate
+from axolotl.utils.dict import DictDefault
+
+warnings.filterwarnings("error")
+
+
+@pytest.fixture(name="minimal_cfg")
+def fixture_cfg():
+ return DictDefault(
+ {
+ "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
+ "learning_rate": 0.000001,
+ "micro_batch_size": 1,
+ "gradient_accumulation_steps": 1,
+ }
+ )
+
+
+# pylint: disable=too-many-public-methods (duplicate-code)
+class BaseValidation:
+ """
+ Base validation module to setup the log capture
+ """
+
+ _caplog: Optional[pytest.LogCaptureFixture] = None
+
+ @pytest.fixture(autouse=True)
+ def inject_fixtures(self, caplog):
+ self._caplog = caplog
+
+
+class TestValidationCheckDatasetConfig(BaseValidation):
+ """
+ Test the validation for the dataset config to ensure no correct parameters are dropped
+ """
+
+ def test_dataset_config_no_drop_param(self, minimal_cfg):
+ cfg = DictDefault(
+ minimal_cfg
+ | {
+ "datasets": [
+ {
+ "path": "LDJnr/Puffin",
+ "type": "sharegpt",
+ "conversation": "chatml",
+ "shards": 10,
+ }
+ ]
+ }
+ )
+
+ checked_cfg = validate_config(cfg)
+
+ def _check_config():
+ assert checked_cfg.datasets[0].path == cfg.datasets[0].path
+ assert checked_cfg.datasets[0].type == cfg.datasets[0].type
+ assert checked_cfg.datasets[0].conversation == cfg.datasets[0].conversation
+ assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards
+
+ _check_config()
+
+ checked_cfg = validate_config(
+ cfg,
+ capabilities={
+ "bf16": "false",
+ "n_gpu": 1,
+ "compute_capability": "8.0",
+ },
+ )
+
+ _check_config()
+
+ def test_dataset_default_chat_template_no_drop_param(self, minimal_cfg):
+ cfg = DictDefault(
+ minimal_cfg
+ | {
+ "datasets": [
+ {
+ "path": "LDJnr/Puffin",
+ "type": "chat_template",
+ "field_messages": "conversations",
+ "shards": 10,
+ "message_field_role": "from",
+ "message_field_content": "value",
+ }
+ ],
+ }
+ )
+
+ checked_cfg = validate_config(cfg)
+
+ def _check_config():
+ assert checked_cfg.datasets[0].path == cfg.datasets[0].path
+ assert checked_cfg.datasets[0].type == cfg.datasets[0].type
+ assert checked_cfg.chat_template is None
+ assert (
+ checked_cfg.datasets[0].chat_template == ChatTemplate.tokenizer_default
+ )
+ assert (
+ checked_cfg.datasets[0].field_messages == cfg.datasets[0].field_messages
+ )
+ assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards
+ assert (
+ checked_cfg.datasets[0].message_field_role
+ == cfg.datasets[0].message_field_role
+ )
+ assert (
+ checked_cfg.datasets[0].message_field_content
+ == cfg.datasets[0].message_field_content
+ )
+
+ _check_config()
+
+ checked_cfg = validate_config(
+ cfg,
+ capabilities={
+ "bf16": "false",
+ "n_gpu": 1,
+ "compute_capability": "8.0",
+ },
+ )
+
+ _check_config()
+
+ def test_dataset_partial_default_chat_template_no_drop_param(self, minimal_cfg):
+ cfg = DictDefault(
+ minimal_cfg
+ | {
+ "chat_template": "chatml",
+ "datasets": [
+ {
+ "path": "LDJnr/Puffin",
+ "type": "chat_template",
+ "field_messages": "conversations",
+ "shards": 10,
+ "message_field_role": "from",
+ "message_field_content": "value",
+ }
+ ],
+ }
+ )
+
+ checked_cfg = validate_config(cfg)
+
+ def _check_config():
+ assert checked_cfg.datasets[0].path == cfg.datasets[0].path
+ assert checked_cfg.datasets[0].type == cfg.datasets[0].type
+ assert checked_cfg.chat_template == ChatTemplate.chatml
+ assert (
+ checked_cfg.datasets[0].chat_template == ChatTemplate.tokenizer_default
+ )
+ assert (
+ checked_cfg.datasets[0].field_messages == cfg.datasets[0].field_messages
+ )
+ assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards
+ assert (
+ checked_cfg.datasets[0].message_field_role
+ == cfg.datasets[0].message_field_role
+ )
+ assert (
+ checked_cfg.datasets[0].message_field_content
+ == cfg.datasets[0].message_field_content
+ )
+
+ _check_config()
+
+ checked_cfg = validate_config(
+ cfg,
+ capabilities={
+ "bf16": "false",
+ "n_gpu": 1,
+ "compute_capability": "8.0",
+ },
+ )
+
+ _check_config()
+
+ def test_dataset_chatml_chat_template_no_drop_param(self, minimal_cfg):
+ cfg = DictDefault(
+ minimal_cfg
+ | {
+ "chat_template": "chatml",
+ "datasets": [
+ {
+ "path": "LDJnr/Puffin",
+ "type": "chat_template",
+ "chat_template": "gemma",
+ "field_messages": "conversations",
+ "shards": 10,
+ "message_field_role": "from",
+ "message_field_content": "value",
+ }
+ ],
+ }
+ )
+
+ checked_cfg = validate_config(cfg)
+
+ def _check_config():
+ assert checked_cfg.datasets[0].path == cfg.datasets[0].path
+ assert checked_cfg.datasets[0].type == cfg.datasets[0].type
+ assert checked_cfg.chat_template == cfg.chat_template
+ assert (
+ checked_cfg.datasets[0].chat_template == cfg.datasets[0].chat_template
+ )
+ assert (
+ checked_cfg.datasets[0].field_messages == cfg.datasets[0].field_messages
+ )
+ assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards
+ assert (
+ checked_cfg.datasets[0].message_field_role
+ == cfg.datasets[0].message_field_role
+ )
+ assert (
+ checked_cfg.datasets[0].message_field_content
+ == cfg.datasets[0].message_field_content
+ )
+
+ _check_config()
+
+ checked_cfg = validate_config(
+ cfg,
+ capabilities={
+ "bf16": "false",
+ "n_gpu": 1,
+ "compute_capability": "8.0",
+ },
+ )
+
+ _check_config()
diff --git a/tests/utils/test_models.py b/tests/utils/test_models.py
index e06bb6c250..31698f05fb 100644
--- a/tests/utils/test_models.py
+++ b/tests/utils/test_models.py
@@ -1,18 +1,64 @@
"""Module for testing models utils file."""
-
-import unittest
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
import pytest
+from transformers import BitsAndBytesConfig, PreTrainedTokenizerBase
+from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
+from transformers.utils.import_utils import is_torch_mps_available
from axolotl.utils.dict import DictDefault
-from axolotl.utils.models import load_model
+from axolotl.utils.models import ModelLoader, load_model
-class ModelsUtilsTest(unittest.TestCase):
+class TestModelsUtils:
"""Testing module for models utils."""
+ def setup_method(self) -> None:
+ # load config
+ self.cfg = DictDefault( # pylint: disable=attribute-defined-outside-init
+ {
+ "base_model": "JackFram/llama-68m",
+ "model_type": "LlamaForCausalLM",
+ "tokenizer_type": "LlamaTokenizer",
+ "load_in_8bit": True,
+ "load_in_4bit": False,
+ "adapter": "lora",
+ "flash_attention": False,
+ "sample_packing": True,
+ "device_map": "auto",
+ }
+ )
+ self.tokenizer = MagicMock( # pylint: disable=attribute-defined-outside-init
+ spec=PreTrainedTokenizerBase
+ )
+ self.inference = False # pylint: disable=attribute-defined-outside-init
+ self.reference_model = True # pylint: disable=attribute-defined-outside-init
+
+ # init ModelLoader
+ self.model_loader = ( # pylint: disable=attribute-defined-outside-init
+ ModelLoader(
+ cfg=self.cfg,
+ tokenizer=self.tokenizer,
+ inference=self.inference,
+ reference_model=self.reference_model,
+ )
+ )
+
+ def test_set_device_map_config(self):
+ # check device_map
+ device_map = self.cfg.device_map
+ if is_torch_mps_available():
+ device_map = "mps"
+ self.model_loader.set_device_map_config()
+ if is_deepspeed_zero3_enabled():
+ assert "device_map" not in self.model_loader.model_kwargs
+ else:
+ assert device_map in self.model_loader.model_kwargs["device_map"]
+
+ # check torch_dtype
+ assert self.cfg.torch_dtype == self.model_loader.model_kwargs["torch_dtype"]
+
def test_cfg_throws_error_with_s2_attention_and_sample_packing(self):
cfg = DictDefault(
{
@@ -35,3 +81,38 @@ def test_cfg_throws_error_with_s2_attention_and_sample_packing(self):
"shifted-sparse attention does not currently support sample packing"
in str(exc.value)
)
+
+ @pytest.mark.parametrize("adapter", ["lora", "qlora", None])
+ @pytest.mark.parametrize("load_in_8bit", [True, False])
+ @pytest.mark.parametrize("load_in_4bit", [True, False])
+ @pytest.mark.parametrize("gptq", [True, False])
+ def test_set_quantization_config(
+ self,
+ adapter,
+ load_in_8bit,
+ load_in_4bit,
+ gptq,
+ ):
+ # init cfg as args
+ self.cfg.load_in_8bit = load_in_8bit
+ self.cfg.load_in_4bit = load_in_4bit
+ self.cfg.gptq = gptq
+ self.cfg.adapter = adapter
+
+ self.model_loader.set_quantization_config()
+ if "quantization_config" in self.model_loader.model_kwargs or self.cfg.gptq:
+ assert not (
+ hasattr(self.model_loader.model_kwargs, "load_in_8bit")
+ and hasattr(self.model_loader.model_kwargs, "load_in_4bit")
+ )
+ elif load_in_8bit and self.cfg.adapter is not None:
+ assert self.model_loader.model_kwargs["load_in_8bit"]
+ elif load_in_4bit and self.cfg.adapter is not None:
+ assert self.model_loader.model_kwargs["load_in_4bit"]
+
+ if (self.cfg.adapter == "qlora" and load_in_4bit) or (
+ self.cfg.adapter == "lora" and load_in_8bit
+ ):
+ assert self.model_loader.model_kwargs.get(
+ "quantization_config", BitsAndBytesConfig
+ )
|