Skip to content

Commit

Permalink
paper, rename loaders, readme
Browse files Browse the repository at this point in the history
  • Loading branch information
homerjed committed Oct 7, 2024
1 parent 4f2a557 commit 9bdbc69
Show file tree
Hide file tree
Showing 12 changed files with 327 additions and 137 deletions.
35 changes: 35 additions & 0 deletions .github/workflows/joss.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
name: Draft JOSS paper PDF

on:
  push:
    paths:
      - paper/**
      # NOTE: this file is joss.yml; the previous filter pointed at
      # draft-pdf.yml, so edits to this workflow never retriggered it.
      - .github/workflows/joss.yml

jobs:
  paper:
    runs-on: ubuntu-latest
    name: Paper Draft
    steps:
      - name: Checkout
        uses: actions/checkout@v4

      - name: Build draft PDF
        uses: openjournals/openjournals-draft-action@master
        with:
          journal: joss
          # Path to the paper within the repo. The push trigger above only
          # watches paper/**, so the paper is assumed to live in paper/ —
          # a root-level paper.md would never trigger this workflow.
          paper-path: paper/paper.md

      - name: Upload
        uses: actions/upload-artifact@v4
        with:
          name: paper
          # The compiled PDF is written next to the input paper.md.
          path: paper/paper.pdf

      - name: Commit PDF to repository
        uses: EndBug/add-and-commit@v9
        with:
          message: '(auto) Paper PDF Draft'
          # Commit the compiled PDF; use 'paper/*.pdf' to commit all PDFs
          # in the paper directory.
          add: 'paper/paper.pdf'
13 changes: 9 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,22 @@ The Stein score of the marginal probability distributions over $t$ is approximat
For each SDE there exists a deterministic ODE with marginal likelihoods $p_t(\boldsymbol{x})$ that match the SDE for all time $t$

$$
\text{d}\boldsymbol{x} = [f(\boldsymbol{x}, t)\text{d}t - \frac{1}{2}g(t)^2\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x})]\text{d}t = F(\boldsymbol{x}(t), t).
\text{d}\boldsymbol{x} = [f(\boldsymbol{x}, t) - \frac{1}{2}g(t)^2\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x})]\text{d}t = f'(\boldsymbol{x}(t), t)\text{d}t.
$$

The continuous normalizing flow formalism allows the ODE to be expressed as

$$
\frac{\partial}{\partial t} \log p(\boldsymbol{x}(t)) = -\text{Tr}\bigg [ \frac{\partial}{\partial \boldsymbol{x}(t)} F(\boldsymbol{x}(t), t) \bigg ],
\frac{\partial}{\partial t} \log p(\boldsymbol{x}(t)) = -\nabla_{\boldsymbol{x}} \cdot f'(\boldsymbol{x}(t), t),
$$

but note that maximum-likelihood training is prohibitively expensive for SDE based diffusion models.
which gives the log-likelihood of a datapoint $\boldsymbol{x}$ as

$$
\log p(\boldsymbol{x}(0)) = \log p(\boldsymbol{x}(T)) + \int_{t=0}^{t=T}\text{d}t \; \nabla_{\boldsymbol{x}}\cdot f'(\boldsymbol{x}, t).
$$

Note that maximum-likelihood training is prohibitively expensive for SDE-based diffusion models.

### Usage

Expand Down Expand Up @@ -112,7 +118,6 @@ model = sbgm.train.train(
sde,
dataset,
config,
reload_opt_state=False,
sharding=sharding,
save_dir=root_dir
)
Expand Down
2 changes: 1 addition & 1 deletion data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .moons import moons
from .grfs import grfs
from .quijote import quijote
from .utils import Scaler, ScalerDataset, _InMemoryDataLoader, _TorchDataLoader
from .utils import Scaler, ScalerDataset, InMemoryDataLoader, TorchDataLoader


def get_dataset(
Expand Down
6 changes: 3 additions & 3 deletions data/cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from jaxtyping import Key
from torchvision import transforms, datasets

from .utils import Scaler, ScalerDataset, _TorchDataLoader
from .utils import Scaler, ScalerDataset, TorchDataLoader


def cifar10(path: str, key: Key) -> ScalerDataset:
Expand Down Expand Up @@ -44,10 +44,10 @@ def cifar10(path: str, key: Key) -> ScalerDataset:
transform=valid_transform
)

train_dataloader = _TorchDataLoader(
train_dataloader = TorchDataLoader(
train_dataset, data_shape, context_shape=None, parameter_dim=parameter_dim, key=key_train
)
valid_dataloader = _TorchDataLoader(
valid_dataloader = TorchDataLoader(
valid_dataset, data_shape, context_shape=None, parameter_dim=parameter_dim, key=key_valid
)

Expand Down
6 changes: 3 additions & 3 deletions data/flowers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from jaxtyping import Key
from torchvision import transforms, datasets

from .utils import Scaler, ScalerDataset, _TorchDataLoader
from .utils import Scaler, ScalerDataset, TorchDataLoader


def flowers(key: Key, n_pix: int) -> ScalerDataset:
Expand Down Expand Up @@ -46,8 +46,8 @@ def flowers(key: Key, n_pix: int) -> ScalerDataset:
transform=valid_transform
)

train_dataloader = _TorchDataLoader(train_dataset, key=key_train)
valid_dataloader = _TorchDataLoader(valid_dataset, key=key_valid)
train_dataloader = TorchDataLoader(train_dataset, key=key_train)
valid_dataloader = TorchDataLoader(valid_dataset, key=key_valid)

def label_fn(key, n):
Q = None
Expand Down
6 changes: 3 additions & 3 deletions data/grfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from torchvision import transforms
import powerbox

from .utils import Scaler, ScalerDataset, _TorchDataLoader
from .utils import Scaler, ScalerDataset, TorchDataLoader

data_dir = "/project/ls-gruen/users/jed.homer/data/fields/"

Expand Down Expand Up @@ -135,8 +135,8 @@ def grfs(key, n_pix, split=0.5):
valid_dataset = MapDataset(
(X[n_train:], Q[n_train:], A[n_train:]), transform=valid_transform
)
train_dataloader = _TorchDataLoader(train_dataset, key=key_train)
valid_dataloader = _TorchDataLoader(valid_dataset, key=key_valid)
train_dataloader = TorchDataLoader(train_dataset, key=key_train)
valid_dataloader = TorchDataLoader(valid_dataset, key=key_valid)

# Don't have many maps
# train_dataloader = _InMemoryDataLoader(
Expand Down
6 changes: 3 additions & 3 deletions data/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torch import Tensor
from torchvision import datasets

from .utils import Scaler, ScalerDataset, _InMemoryDataLoader
from .utils import Scaler, ScalerDataset, InMemoryDataLoader


def tensor_to_array(tensor: Tensor) -> Array:
Expand Down Expand Up @@ -46,10 +46,10 @@ def mnist(path:str, key: Key) -> ScalerDataset:

# scaler = Normer(train_data.mean(), train_data.std())

train_dataloader = _InMemoryDataLoader(
train_dataloader = InMemoryDataLoader(
train_data, Q=None, A=train_targets, key=key_train
)
valid_dataloader = _InMemoryDataLoader(
valid_dataloader = InMemoryDataLoader(
valid_data, Q=None, A=valid_targets, key=key_valid
)

Expand Down
6 changes: 3 additions & 3 deletions data/moons.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.random as jr
from sklearn.datasets import make_moons

from .utils import ScalerDataset, _InMemoryDataLoader
from .utils import ScalerDataset, InMemoryDataLoader


def key_to_seed(key):
Expand Down Expand Up @@ -32,12 +32,12 @@ def moons(key):
train_data = (Xt - mean) / std
valid_data = (Xv - mean) / std

train_dataloader = _InMemoryDataLoader(
train_dataloader = InMemoryDataLoader(
X=jnp.asarray(train_data),
A=jnp.asarray(Yt)[:, jnp.newaxis],
key=key_train
)
valid_dataloader = _InMemoryDataLoader(
valid_dataloader = InMemoryDataLoader(
X=jnp.asarray(valid_data),
A=jnp.asarray(Yv)[:, jnp.newaxis],
key=key_valid
Expand Down
18 changes: 9 additions & 9 deletions data/quijote.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
import torch
from torchvision import transforms

from .utils import Scaler, ScalerDataset, _TorchDataLoader, _InMemoryDataLoader
from .utils import Scaler, ScalerDataset, TorchDataLoader, InMemoryDataLoader

data_dir = "/project/ls-gruen/users/jed.homer/data/fields/"
DATA_DIR = "/project/ls-gruen/users/jed.homer/data/fields/"


class MapDataset(torch.utils.data.Dataset):
Expand Down Expand Up @@ -39,16 +39,16 @@ def __len__(self):


def get_quijote_data(n_pix: int) -> Tuple[Array, Array]:
X = np.load(os.path.join(data_dir, "quijote_fields.npy"))[:, np.newaxis, ...]
A = np.load(os.path.join(data_dir, "quijote_parameters.npy"))
X = np.load(os.path.join(DATA_DIR, "quijote_fields.npy"))[:, np.newaxis, ...]
A = np.load(os.path.join(DATA_DIR, "quijote_parameters.npy"))

dx = int(256 / n_pix)
X = X.reshape((-1, 1, n_pix, dx, n_pix, dx)).mean(axis=(3, 5))
return X, A


def get_quijote_labels() -> Array:
Q = np.load(os.path.join(data_dir, "quijote_parameters.npy"))
Q = np.load(os.path.join(DATA_DIR, "quijote_parameters.npy"))
return Q


Expand Down Expand Up @@ -91,14 +91,14 @@ def quijote(key, n_pix, split=0.5):
valid_dataset = MapDataset(
(X[n_train:], A[n_train:]), transform=valid_transform
)
# train_dataloader = _TorchDataLoader(
# train_dataloader = TorchDataLoader(
# train_dataset,
# data_shape=data_shape,
# context_shape=None,
# parameter_dim=parameter_dim,
# key=key_train
# )
# valid_dataloader = _TorchDataLoader(
# valid_dataloader = TorchDataLoader(
# valid_dataset,
# data_shape=data_shape,
# context_shape=None,
Expand All @@ -107,10 +107,10 @@ def quijote(key, n_pix, split=0.5):
# )

# Don't have many maps
train_dataloader = _InMemoryDataLoader(
train_dataloader = InMemoryDataLoader(
X=X[:n_train], A=A[:n_train], key=key_train
)
valid_dataloader = _InMemoryDataLoader(
valid_dataloader = InMemoryDataLoader(
X=X[n_train:], A=A[n_train:], key=key_valid
)

Expand Down
10 changes: 5 additions & 5 deletions data/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import abc
from typing import Tuple, Union, NamedTuple, Callable
from typing import Tuple, Union, Callable
from dataclasses import dataclass
import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -27,7 +27,7 @@ def loop(self, batch_size):
pass


class _InMemoryDataLoader(_AbstractDataLoader):
class InMemoryDataLoader(_AbstractDataLoader):
def __init__(self, X, Q=None, A=None, *, key):
self.X = X
self.Q = Q
Expand Down Expand Up @@ -58,7 +58,7 @@ def loop(self, batch_size):
end = start + batch_size


class _TorchDataLoader(_AbstractDataLoader):
class TorchDataLoader(_AbstractDataLoader):
def __init__(
self,
dataset,
Expand Down Expand Up @@ -132,8 +132,8 @@ def __init__(self, x_mean=0., x_std=1.):
@dataclass
class ScalerDataset:
name: str
train_dataloader: Union[_TorchDataLoader | _InMemoryDataLoader]
valid_dataloader: Union[_TorchDataLoader | _InMemoryDataLoader]
train_dataloader: Union[TorchDataLoader | InMemoryDataLoader]
valid_dataloader: Union[TorchDataLoader | InMemoryDataLoader]
data_shape: Tuple[int]
context_shape: Tuple[int]
parameter_dim: int
Expand Down
Loading

0 comments on commit 9bdbc69

Please sign in to comment.