Add T5 LM v1.1 encoder
The encoder shares much of the underlying stack with the decoder. Here
only the encoder is presented as a class.
I have not gone out of my way to strip all decoder-related pieces from
the stack, but things like checkpointing and dropout are stripped.

The author attribution is added to the license header of the T5 model
file, as this seems like a derivative work. Both are Apache 2.0.

There are a few tests of the various components and two tests for the
entire encoder, covering the small and XXL variants. They rely on
Hugging Face, and the models are downloaded on the fly into the cache.
The tests expect the corresponding GGUF files to already be present and
available on the file system.
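As a usage sketch, the new encoder tests can be run locally with the same flags the CI workflow below passes; the paths here are placeholders and should point at wherever the GGUF files live:

pytest --longrun \
    --google-t5-v1-1-small-fp32-model-path=/path/to/google__t5-v1_1-small_fp32.gguf \
    --google-t5-v1-1-xxl-fp32-model-path=/path/to/google__t5-v1_1-xxl_fp32.gguf \
    sharktank/tests/models/t5/t5_test.py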
sogartar committed Nov 15, 2024
1 parent 8d9a923 commit 512356f
Showing 12 changed files with 1,621 additions and 63 deletions.
62 changes: 62 additions & 0 deletions .github/workflows/ci_eval.yaml
@@ -20,6 +20,8 @@ concurrency:
group: ${{ github.workflow }}-${{ github.event.number || github.sha }}
cancel-in-progress: true

# TODO: refactor out common steps into actions/scripts.

jobs:
test_perplexity_vmfb:
timeout-minutes: 1000
@@ -123,3 +125,63 @@ jobs:
- name: Run perplexity test in eager mode
run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_torch_test.py --longrun --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json

test_torch:
timeout-minutes: 1000
name: "Torch/Eager mode"
strategy:
matrix:
version: [3.11]
runs-on: [llama-mi300x-3]
fail-fast: false
runs-on: ${{matrix.runs-on}}
defaults:
run:
shell: bash
env:
PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache"
HF_HOME: "/data/huggingface"
SHARK_PLATFORM_REPO_ROOT: ${{ github.workspace }}
steps:
- name: "Setting up Python"
id: setup_python
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{matrix.version}}

- name: "Checkout Code"
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

- name: Cache Pip Packages
uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
id: cache-pip
with:
path: ${{ env.PIP_CACHE_DIR }}
key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements*.txt','sharktank/requirements*.txt') }}

- name: Install sharktank deps
run: |
python -m pip install --no-compile --upgrade pip
# Note: We install in three steps in order to satisfy requirements
# from non default locations first. Installing the PyTorch CPU
# wheels saves multiple minutes and a lot of bandwidth on runner setup.
pip install --no-compile -r pytorch-cpu-requirements.txt
pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/
# Install latest iree-turbine.
pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \
-e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"
# Try with the latest IREE nightly releases, not what iree-turbine pins.
# We could also pin to a known working or stable version.
# This should eventually stabilize. Do the best we can for now.
pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \
iree-base-compiler \
iree-base-runtime
- name: Run long running tests
run: |
pytest --longrun \
--google-t5-v1-1-small-fp32-model-path=/data/t5/small/google__t5-v1_1-small_fp32.gguf \
--google-t5-v1-1-xxl-fp32-model-path=/data/t5/xxl/google__t5-v1_1-xxl_fp32.gguf \
sharktank/tests/models/t5/t5_test.py
26 changes: 26 additions & 0 deletions sharktank/conftest.py
@@ -133,6 +133,22 @@ def pytest_addoption(parser):
help="Llama3.1 405b fp8 model path",
)

parser.addoption(
"--google-t5-v1-1-small-fp32-model-path",
type=Path,
action="store",
default=None,
help="Google T5 v1.1 small fp32 model path",
)

parser.addoption(
"--google-t5-v1-1-xxl-fp32-model-path",
type=Path,
action="store",
default=None,
help="Google T5 v1.1 XXL fp32 model path",
)

parser.addoption(
"--baseline-perplexity-scores",
type=Path,
@@ -256,6 +272,16 @@ def get_model_artifacts(request: FixtureRequest):
model_path["llama3_405b_fp8_model_path"] = set_fixture_from_cli_option(
request, "--llama3-405b-fp8-model-path", "llama3_405b_fp8_model"
)
model_path["google__t5_v1_1_small_fp32_model_path"] = set_fixture_from_cli_option(
request,
"--google-t5-v1-1-small-fp32-model-path",
"google__t5_v1_1_small_fp32_model",
)
model_path["google__t5_v1_1_xxl_fp32_model_path"] = set_fixture_from_cli_option(
request,
"--google-t5-v1-1-xxl-fp32-model-path",
"google__t5_v1_1_xxl_fp32_model",
)
return model_path
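For illustration only, here is one way a test might consume these new paths, assuming get_model_artifacts is exposed as a pytest fixture returning the model_path dictionary built above (only the dictionary construction is visible in this hunk); the test name and skip logic are hypothetical.

import pytest

def test_t5_v1_1_small_encoder_loads(get_model_artifacts):
    # Hypothetical consumer: the key matches the dictionary entry added above.
    gguf_path = get_model_artifacts["google__t5_v1_1_small_fp32_model_path"]
    if gguf_path is None:
        pytest.skip("pass --google-t5-v1-1-small-fp32-model-path to run this test")
    assert gguf_path.exists()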


28 changes: 20 additions & 8 deletions sharktank/sharktank/layers/ffn_block.py
@@ -4,11 +4,12 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Optional
from typing import Optional, Callable

import torch
import torch.nn.functional as F
from .. import ops
from ..types import AnyTensor

from .base import Theta, ThetaLayer
from .linear import LinearLayer
@@ -22,18 +23,29 @@ class FFN(ThetaLayer):
def __init__(
self,
theta: Theta,
is_gated: bool = None,
activation_fn: Optional[Callable[[AnyTensor], AnyTensor]] = None,
):
super().__init__(theta)

self.add_module("ffn_gate", LinearLayer(theta("ffn_gate")))
self.is_gated = is_gated is None or is_gated
self.activation_fn = activation_fn or F.silu
if self.is_gated:
self.add_module("ffn_gate", LinearLayer(theta("ffn_gate")))
self.add_module("ffn_up", LinearLayer(theta("ffn_up")))
self.add_module("ffn_down", LinearLayer(theta("ffn_down")))

def forward(
self,
h: torch.Tensor,
):
ffn_gate = ops.elementwise(F.silu, self.ffn_gate(h))
ffn_up = self.ffn_up(h)
ffn_down = self.ffn_down(ffn_gate * ffn_up)
return ffn_down
h: AnyTensor,
) -> AnyTensor:
if self.is_gated:
ffn_gate = ops.elementwise(self.activation_fn, self.ffn_gate(h))
ffn_up = self.ffn_up(h)
ffn_down = self.ffn_down(ffn_gate * ffn_up)
return ffn_down
else:
h = self.ffn_up(h)
h = ops.elementwise(self.activation_fn, h)
h = self.ffn_down(h)
return h
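For orientation, the original T5 feed-forward is a non-gated ReLU MLP while v1.1 uses a gated GELU, which is what the new is_gated and activation_fn options accommodate. Below is a minimal pure-PyTorch sketch of the two dataflows the updated forward computes; plain nn.Linear stands in for sharktank's LinearLayer/Theta wiring, and the dimensions are made up.

import torch
import torch.nn.functional as F

d_model, d_ff = 8, 32
up = torch.nn.Linear(d_model, d_ff, bias=False)
gate = torch.nn.Linear(d_model, d_ff, bias=False)
down = torch.nn.Linear(d_ff, d_model, bias=False)
h = torch.randn(2, d_model)

# Gated path (is_gated=True), e.g. T5 v1.1's gated GELU: down(act(gate(h)) * up(h))
gated = down(F.gelu(gate(h)) * up(h))

# Non-gated path (is_gated=False), e.g. the original T5's ReLU FFN: down(act(up(h)))
ungated = down(F.relu(up(h)))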
