Support direct quantization for FP8 matmul #13902
Workflow file for this run
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build workflow: installs the Python dependencies, then lints and tests the
# project across several Python versions.
# Background reading: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: Build

# Triggered by pushes to main or any `test_*` branch, and by pull requests
# that target main.
on:
  push:
    branches: [main, 'test_*']
  pull_request:
    branches: [main]
jobs: | |
cancel-previous:
  name: Cancel Previous Runs
  runs-on: ubuntu-latest
  steps:
    # Cancels earlier runs of this workflow for every ref except main.
    - name: Cancel previous
      # NOTE(review): the action reference was unreadable ("[email protected]");
      # restored to styfle/cancel-workflow-action. The version tag is a
      # best guess -- confirm it against the intended pin.
      uses: styfle/cancel-workflow-action@0.12.1
      # Branch refs are spelled `refs/heads/<name>`. The previous comparison
      # against 'refs/head/main' could never match, so the step also ran
      # (and cancelled earlier runs) on main.
      if: ${{ github.ref != 'refs/heads/main' }}
      with:
        access_token: ${{ github.token }}
pre-commit:
  name: Test pre-commit hooks
  runs-on: ubuntu-latest
  steps:
    - uses: actions/checkout@v3
    - name: Set up Python
      uses: actions/setup-python@v4
      with:
        # Quoted so YAML keeps the string "3.10" instead of the float 3.1.
        python-version: '3.10'
    # NOTE(review): the action reference was unreadable ("[email protected]");
    # restored to pre-commit/action. The version tag is a best guess --
    # confirm it against the intended pin.
    - uses: pre-commit/action@v3.0.1
commit-count:
  name: Check commit count
  runs-on: ubuntu-latest
  steps:
    - uses: actions/checkout@v3
    # A branch may carry at most 5 commits, to ensure our CI doesn't break.
    - name: Check commit count in PR
      if: always()
      shell: bash
      run: |
        set -x
        # Fetch $GITHUB_REF into a local ref named `commit-count` so it can
        # be compared with main below. The fetch is --unshallow because the
        # checkout step only retrieves a shallow copy, and every commit is
        # needed for the count.
        git fetch origin --unshallow $GITHUB_REF:commit-count
        git fetch origin main
        num_commits=$(git rev-list --count origin/main...commit-count)
        # When run as a GitHub Action, $GITHUB_REF contributes one extra
        # commit to the tree, so the count is one too high; the limit of 5
        # is therefore enforced as "> 6".
        if (( $num_commits > 6)); then
          echo "ERROR! More than 5 commits in PR -- please squash your commits."
          help_url=https://flax.readthedocs.io/en/latest/contributing.html#too-many-commits-in-a-pull-request
          echo "See $help_url for help on how to resolve this."
          exit 1
        fi
test-import:
  name: Test import standalone
  runs-on: ubuntu-latest
  strategy:
    matrix:
      # Versions are quoted so '3.10' is not parsed as the float 3.1.
      python-version:
        - '3.10'
        - '3.11'
  steps:
    - uses: actions/checkout@v3
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v4
      with:
        python-version: ${{ matrix.python-version }}
    # Editable install with the `all` extra and nothing else.
    - name: Install standalone dependencies only
      run: pip install -e .[all]
    # Smoke test: the package must import after the install above.
    - name: Test importing Flax
      run: python -c "import flax"
# Main test job: one run per (python-version, test-type, jax-version) matrix
# entry, started only after the cheap gate jobs listed in `needs` succeed.
tests:
  name: Run Tests
  needs: [cancel-previous, pre-commit, commit-count, test-import]
  runs-on: ubuntu-20.04-16core
  strategy:
    matrix:
      python-version: ['3.10', '3.11']
      test-type: [doctest, pytest, pytype, mypy]
      jax-version: [newest]
      # pytype runs only on 3.11 and mypy only on 3.10.
      exclude:
        - test-type: pytype
          python-version: '3.10'
        - test-type: mypy
          python-version: '3.11'
      # One additional pytest run against the pinned JAX release.
      include:
        - python-version: '3.10'
          test-type: pytest
          jax-version: '0.4.27'  # keep in sync with jax pin in pyproject.toml
  steps:
    - uses: actions/checkout@v3
    - name: Set up Python ${{ matrix.python-version }}
      id: setup_python
      uses: actions/setup-python@v4
      with:
        python-version: ${{ matrix.python-version }}
    # `date +%j` prints the day of the year (001-366); it is used in the cache
    # keys below.
    - name: Get day of year
      id: date_key
      run: echo "DATE=$(date +%j)" >> "$GITHUB_OUTPUT"
    # TODO(cgarciae): caching is breaking the install step, disabling for now.
    # - name: Cached virtual environment
    #   id: venv_cache
    #   uses: actions/cache@v3
    #   with:
    #     path: venv
    #     key: pip-${{ steps.setup_python.outputs.python-version }}-${{ steps.date_key.outputs.DATE }}-${{ hashFiles('**/requirements.txt', 'pyproject.toml') }}
    #
    # While the cache step above is disabled, no step has the id `venv_cache`,
    # so the guard below always evaluated to true. Restore it on the next step
    # when the cache step is re-enabled:
    #   if: steps.venv_cache.outputs.cache-hit != 'true'
    - name: Install Dependencies for cache
      run: |
        # Always start from a fresh virtual environment.
        if [ -d "venv" ]; then rm -rf venv; fi
        python3 -m venv venv
        venv/bin/python3 -m pip install .[all,testing]
        venv/bin/python3 -m pip install tensorflow_datasets[dev]
        venv/bin/python3 -m pip install -r docs/requirements.txt
    - name: Install Flax
      run: |
        venv/bin/python3 -m pip install -e .[all,testing]
    - name: Install JAX
      run: |
        # "newest" upgrades to the latest release; any other value is
        # installed as an exact pin for both jax and jaxlib.
        if [[ "${{ matrix.jax-version }}" == "newest" ]]; then
          venv/bin/python3 -m pip install -U jax jaxlib
        else
          venv/bin/python3 -m pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}"
        fi
    - name: Cached mypy cache
      id: mypy_cache
      uses: actions/cache@v3
      if: matrix.test-type == 'mypy'
      with:
        path: .mypy_cache
        key: mypy-${{ steps.setup_python.outputs.python-version }}-${{ steps.date_key.outputs.DATE }}
    - name: Test with ${{ matrix.test-type }}
      run: |
        # Each test type runs the shared script with the other three test
        # types switched off.
        case "${{ matrix.test-type }}" in
          doctest)
            tests/run_all_tests.sh --no-pytest --no-pytype --no-mypy --use-venv
            ;;
          pytest)
            tests/run_all_tests.sh --no-doctest --no-pytype --no-mypy --with-cov --use-venv
            ;;
          pytype)
            tests/run_all_tests.sh --no-doctest --no-pytest --no-mypy --use-venv
            ;;
          mypy)
            tests/run_all_tests.sh --no-doctest --no-pytest --no-pytype --use-venv
            ;;
          *)
            echo "Unknown test type: ${{ matrix.test-type }}"
            exit 1
            ;;
        esac
    - name: Upload coverage to Codecov
      if: matrix.test-type == 'pytest'
      uses: codecov/codecov-action@v4
      env:
        CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
      with:
        file: ./coverage.xml
    # The below step just reports the success or failure of tests as a "commit status".
    # This is needed for copybara integration.
    - name: Report success or failure as github status
      if: always()
      shell: bash
      run: |
        status="${{ job.status }}"
        # The commit-status API accepts only error, failure, pending and
        # success. job.status can also be "cancelled", which the API
        # rejects, so anything other than success/failure is sent as "error".
        case "$status" in
          success|failure) state="$status" ;;
          *) state="error" ;;
        esac
        curl -sS --request POST \
          --url "https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }}" \
          --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \
          --header 'content-type: application/json' \
          --data '{
            "state": "'"$state"'",
            "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}",
            "description": "'"$status"'",
            "context": "github-actions/Build"
          }'