Skip to content

Commit

Permalink
Add test sharding and enable more tests
Browse files Browse the repository at this point in the history
`sim_custom_sources.py` is disabled until it runs under pytest.

PiperOrigin-RevId: 618284210
  • Loading branch information
Torax team committed Apr 10, 2024
1 parent aaa9d4e commit 74edd37
Showing 1 changed file with 89 additions and 14 deletions.
103 changes: 89 additions & 14 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,46 +3,121 @@ name: Unittests
# Allow to trigger the workflow manually (e.g. when deps changes)
on: [push, workflow_dispatch]

# Concurrency config borrowed from tensorflow_datasets.
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.ref != 'refs/heads/master' || github.run_number }}
# Cancel only PR intermediate builds
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}

# Set the correct env variables for all the unit tests
env:
TORAX_ERRORS_ENABLED: 1
PYTEST_NUM_SHARDS: 4 # Controls tests sharding enabled by `pytest-shard`

jobs:
pytest-job:
shards-job:
name: Generate shards
runs-on: ubuntu-latest
timeout-minutes: 80

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
steps:
- name: Create variables
id: create-vars
run: |
echo "num-shards=$(jq -n -c '[${{ env.PYTEST_NUM_SHARDS }}]')" >> $GITHUB_OUTPUT
echo "shard-ids=$(jq -n -c '[range(1;${{ env.PYTEST_NUM_SHARDS }}+1)]')" >> $GITHUB_OUTPUT
outputs:
num-shards: ${{ steps.create-vars.outputs.num-shards }}
shard-ids: ${{ steps.create-vars.outputs.shard-ids }}

pytest-job-shards:
needs: shards-job

name: '[${{ matrix.os-version }}][Python ${{ matrix.python-version }}][${{ matrix.shard-id }}/${{ matrix.num-shards }}] Core TORAX tests'
runs-on: ${{ matrix.os-version }}
timeout-minutes: 30
strategy:
# Do not cancel in-progress jobs if any matrix job fails.
fail-fast: false
matrix:
# Can't reference env variables in matrix
num-shards: ${{ fromJson(needs.shards-job.outputs.num-shards) }}
shard-id: ${{ fromJson(needs.shards-job.outputs.shard-ids) }}
python-version: ['3.10']
os-version: [ubuntu-latest]

steps:
- uses: actions/checkout@v4

# Install deps
- uses: actions/setup-python@v5
with:
python-version: "3.10"
python-version: ${{ matrix.python-version }}

- run: git clone https://gitlab.com/qualikiz-group/qlknn-hyper.git
- run: echo "TORAX_QLKNN_MODEL_PATH=$PWD/qlknn-hyper" >> "$GITHUB_ENV"

- run: pip --version
- run: pip install -e .[dev]
# TODO(b/323504363): [dev] should install these
- run: pip install pytest pytest-xdist
- run: pip install pytest pytest-xdist pytest-shard
- run: pip freeze

# Run tests (in parallel)
# TODO(b/323504363): tests should be discovered automatically
- name: Run core tests
run: pytest -vv -n auto torax/tests/{boundary_conditions,config,config_slice,geometry,interpolated_param,jax_utils,math_utils,opt,sim_time_dependence}.py
run: |
pytest \
torax/fvm/tests/fvm.py \
torax/sources/tests/bootstrap_current_source.py \
torax/sources/tests/current_density_sources.py \
torax/sources/tests/electron_density_sources.py \
torax/sources/tests/external_current_source.py \
torax/sources/tests/fusion_heat_source.py \
torax/sources/tests/generic_ion_el_heat_source.py \
torax/sources/tests/ion_el_heat_sources.py \
torax/sources/tests/qei_source.py \
torax/sources/tests/source_config.py \
torax/sources/tests/source_models.py \
torax/sources/tests/source.py \
torax/spectators/tests/plotting.py \
torax/spectators/tests/spectator.py \
torax/tests/boundary_conditions.py \
torax/tests/config_slice.py \
torax/tests/config.py \
torax/tests/geometry.py \
torax/tests/interpolated_param.py \
torax/tests/jax_utils.py \
torax/tests/math_utils.py \
torax/tests/opt.py \
torax/tests/physics.py \
torax/tests/sim_time_dependence.py \
torax/tests/state.py \
torax/transport_model/tests/qlknn_wrapper.py \
torax/transport_model/tests/transport_model.py \
-vv -n auto \
--shard-id=$((${{ matrix.shard-id }} - 1)) --num-shards=${{ env.PYTEST_NUM_SHARDS }}
# # TODO(b/323504363): fix UnparsedFlagAccessError and run under pytest
# - name: "Run torax/tests/sim_custom_sources.py"
# run: python torax/tests/sim_custom_sources.py

# TODO(b/323504363): these tests should also run under pytest
- name: "Run sim_custom_sources.py"
run: python torax/tests/sim_custom_sources.py
# TODO(b/323504363): these tests should be parallelized with multiple workers
# - name: "Run sim_no_compile.py"
# # TODO(b/323504363): fix UnparsedFlagAccessError and run under pytest
# - name: "Run torax/sources/tests/formulas.py"
# run: python torax/sources/tests/formulas.py

# # TODO(b/323504363): fix UnparsedFlagAccessError, pass env variable, and run under pytest
# - name: "Run torax/tests/sim_no_compile.py"
# run: TORAX_COMPILATION_ENABLED=False python torax/tests/sim_no_compile.py
# # - name: "Run sim.py"

# # TODO(b/323504363): fix UnparsedFlagAccessError, parallelize, and run under pytest
# - name: "Run torax/tests/sim.py"
# run: python torax/tests/sim.py

pytest-job:
# Dummy job to enable smooth transition of Copybara configs.
# TODO(b/323504363): remove once Copybara configs are updated.
runs-on: ubuntu-latest
timeout-minutes: 1
steps:
- run: echo OK

0 comments on commit 74edd37

Please sign in to comment.