Skip to content

Fix blockwise sharding #632

Fix blockwise sharding

Fix blockwise sharding #632

Workflow file for this run

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This workflow installs the Python dependencies, then runs lint/format checks,
# the unit tests, and an interactive smoke run on a single Python version (3.10).
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
name: Unit Tests

# Triggers: every pull request, pushes to main, manual dispatch, and a
# recurring schedule.
on:
  pull_request:
  push:
    branches: ["main"]
  workflow_dispatch:
  schedule:
    # Run the job every 12 hours
    - cron: '0 */12 * * *'

jobs:
  # Static checks: pylint and pyink formatting (the pytype step is disabled).
  py:
    name: "Python type/lint/format checks"
    strategy:
      matrix:
        # GitHub retired the ubuntu-20.04 hosted runner image, so jobs pinned
        # to it can no longer be scheduled; use the next LTS image instead.
        os: [ubuntu-22.04]
        # Quoted so YAML does not read the version as the float 3.1.
        python-version: ['3.10']
    runs-on: ${{ matrix.os }}
    steps:
      - name: Checkout
        uses: actions/checkout@v4
      - name: Setup Python
        # v5 runs on a supported Node runtime; v4 relied on deprecated Node 16.
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install Dependencies
        run: |
          pip install pytype
          pip install pylint
          pip install pyink
          source install_everything.sh
      # - name: Typecheck the code with pytype
      #   run: |
      #     pytype --jobs auto --disable import-error --disable module-attr jetstream_pt/
      - name: Analysing the code with pylint
        # The indent string is two spaces, matching the 2-space indentation
        # that the pyink check below enforces.
        run: |
          pylint --indent-string='  ' jetstream_pt/ benchmarks/
      - name: Format check with pyink
        run: |
          pyink --pyink-indentation 2 --line-length 80 --check --verbose --extend-exclude=deps .

  # Unit tests run under coverage, with JAX restricted to the CPU platform.
  cpu:
    name: "jetstream_pt unit tests"
    strategy:
      matrix:
        # See the note on the `py` job: ubuntu-20.04 runners are retired.
        os: [ubuntu-22.04]
        python-version: ['3.10']
    runs-on: ${{ matrix.os }}
    steps:
      - name: Checkout
        uses: actions/checkout@v4
      - name: Setup Python
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install Dependencies
        run: |
          source install_everything.sh
      - name: Run all unit tests for jetstream_pt (/tests)
        run: |
          JAX_PLATFORMS=cpu coverage run -m unittest -v
      - name: Create test coverage report
        run: |
          coverage report -m

  # Smoke test: runs run_interactive.py on CPU with the tiny llama-2 settings,
  # once without quantization and once with int8 weights and a quantized
  # KV cache.
  interactive:
    name: "jetstream_pt run interactive"
    strategy:
      matrix:
        # See the note on the `py` job: ubuntu-20.04 runners are retired.
        os: [ubuntu-22.04]
        python-version: ['3.10']
    runs-on: ${{ matrix.os }}
    steps:
      - name: Checkout
        uses: actions/checkout@v4
      - name: Setup Python
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install Dependencies
        run: |
          source install_everything.sh
      - name: Run interactive (bf16)
        # Backslash continuations keep this a single shell command.
        run: |
          JAX_PLATFORMS=cpu python run_interactive.py \
            --size=tiny \
            --batch_size=1 \
            --max_cache_length=2048 \
            --tokenizer_path=jetstream_pt/third_party/llama/tokenizer.model \
            --model_name=llama-2 \
            --sharding_config=default_shardings/llama.yaml \
            --quantize_weights=0 \
            --quantize_kv_cache=0
      - name: Run interactive (int8)
        run: |
          JAX_PLATFORMS=cpu python run_interactive.py \
            --size=tiny \
            --batch_size=1 \
            --max_cache_length=2048 \
            --tokenizer_path=jetstream_pt/third_party/llama/tokenizer.model \
            --model_name=llama-2 \
            --sharding_config=default_shardings/llama.yaml \
            --quantize_weights=1 \
            --quantize_type="int8_per_channel" \
            --quantize_kv_cache=1