
Commit

Merge pull request #25 from ai2es/torch
First commit of updated torch code
djgagne authored Aug 26, 2024
2 parents ea65f40 + 2966bd0 commit 23216cd
Showing 53 changed files with 8,774 additions and 443 deletions.
15 changes: 8 additions & 7 deletions .github/workflows/python-package-conda.yml
@@ -14,25 +14,26 @@ jobs:
      - uses: actions/checkout@v2
      - uses: mamba-org/setup-micromamba@v1
        with:
-         environment-file: environment.yml
-         activate-environment: test
+         environment-file: environment_torch.yml
      - shell: bash -l {0}
        run: |
-         pip install --upgrade keras
          conda info
          conda list
          conda config --show-sources
          conda config --show
          printenv | sort
-     - name: Lint with flake8
+     - name: Lint with ruff
        shell: bash -l {0}
        run: |
-         micromamba install flake8
+         micromamba install ruff
          # stop the build if there are Python syntax errors or undefined names
-         flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
+         ruff check --select=E9,F63,F7,F82
          # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
-         flake8 . --count --exit-zero --max-complexity=100 --max-line-length=127 --statistics
+         ruff check --output-format concise --exit-zero
+         # Checking documentation errors
+         ruff check --select=D --exit-zero --statistics
      - name: Test with pytest
        shell: bash -l {0}
        run: |
+         export KERAS_BACKEND="torch"
          pytest
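
For context on the added `export KERAS_BACKEND="torch"` line: Keras 3 chooses its compute backend from the `KERAS_BACKEND` environment variable, which must be set before `keras` is imported. A minimal illustration, not part of the repository:

```python
import os

# Keras 3 reads KERAS_BACKEND once, at import time, so set it first.
os.environ["KERAS_BACKEND"] = "torch"

import keras

print(keras.backend.backend())  # "torch"
```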
17 changes: 10 additions & 7 deletions README.md
@@ -49,23 +49,26 @@ conda activate guess

## Using miles-guess

- The package contains three scripts for training three regression models, and one for training categorical models.
- The regression examples are trained on our surface layer ("SL") dataset for predicting latent heat and other quantities,
- and the categorical example is trained on a precipitation dataset ("p-type").

The law of total variance for each model prediction target may be computed as

$$LoTV = E[\sigma^2] + Var[\mu]$$

- which is the sum of aleatoric and epistemic contributions, respectively.
+ which is the sum of aleatoric and epistemic contributions, respectively. The MILES-GUESS package contains options for using either Keras or PyTorch to compute quantities according to the LoTV, as well as for utilizing Dempster-Shafer theory uncertainty in the classifier case.
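
As a concrete illustration of the LoTV formula, here is a minimal sketch (not code from the package) of computing it from an ensemble of probabilistic predictions, where each member outputs a mean and a variance:

```python
import numpy as np

def law_of_total_variance(mu, sigma2):
    """LoTV = E[sigma^2] + Var[mu], computed across ensemble members (axis 0)."""
    aleatoric = sigma2.mean(axis=0)  # E[sigma^2]: average predicted variance
    epistemic = mu.var(axis=0)       # Var[mu]: spread of the predicted means
    return aleatoric + epistemic

# Hypothetical ensemble of 10 members predicting 5 targets.
rng = np.random.default_rng(0)
mu = rng.normal(size=(10, 5))
sigma2 = rng.uniform(0.1, 0.5, size=(10, 5))
print(law_of_total_variance(mu, sigma2))
```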

+ For detailed information about training with Keras, refer to [the Keras training details README](docs/source/keras.md). There are three scripts for training three regression models, and one for training categorical models. The regression examples are trained on our surface layer ("SL") dataset for predicting latent heat and other quantities,
+ and the categorical example is trained on a precipitation dataset ("p-type").

+ For PyTorch, please visit [the PyTorch training details README](docs/source/torch.md), which details the training scripts for both evidential and standard classification tasks. Torch examples use the same datasets as the Keras models. The torch training code will also scale on GPUs, and is compatible with DDP and FSDP.
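
The actual torch training loops live in the scripts described in that README; the snippet below is only a generic sketch of how a model gets wrapped for DDP (module and script names are hypothetical), included to illustrate the kind of multi-GPU scaling the torch code supports:

```python
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    # torchrun sets RANK, WORLD_SIZE, and LOCAL_RANK for each process.
    dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo")
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")

    # Stand-in model; the real scripts build their networks from a YAML config.
    model = torch.nn.Sequential(
        torch.nn.Linear(8, 64), torch.nn.ReLU(), torch.nn.Linear(64, 1)
    ).to(device)
    model = DDP(model, device_ids=[local_rank] if torch.cuda.is_available() else None)

    # ... optimizer, loss, and training loop would go here ...

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```

Such a sketch would be launched with something like `torchrun --nproc_per_node=2 ddp_sketch.py` (file name hypothetical).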

<!--
### 1a. Train/evaluate a deterministic multi-layer perceptron (MLP) on the SL dataset:
```bash
python3 applications/train_mlp_SL.py -c config/model_mlp_SL.yml
```
### 1b. Train/evaluate a parametric "Gaussian" MLP on the SL dataset:
```bash
python applications/train_gaussian_SL.py -c config/model_gaussian_SL.yml
```
@@ -271,4 +274,4 @@ Depending on the problem, a data field is customized and also present in the con
## ECHO hyperparameter optimization
- Configuration files are also supplied for use with the Earth Computing Hyperparameter Optimization (ECHO) package. See the echo package https://github.com/NCAR/echo-opt/tree/main/echo for more details on the configuration fields.
+ Configuration files are also supplied for use with the Earth Computing Hyperparameter Optimization (ECHO) package. See the echo package https://github.com/NCAR/echo-opt/tree/main/echo for more details on the configuration fields. -->
4 changes: 2 additions & 2 deletions applications/evaluate_ptype.py
@@ -10,12 +10,12 @@
import pandas as pd
from sklearn.metrics import precision_recall_fscore_support

- from ptype.reliability import (
+ from mlguess.reliability import (
compute_calibration,
reliability_diagram,
reliability_diagrams,
)
- from ptype.plotting import (
+ from mlguess.plotting import (
plot_confusion_matrix,
coverage_figures,
)
