Skip to content

Commit

Permalink
Moved np.trapz -> scipy.integrate.trapezoid. Updated version requirements. Some linting. (#96)
Browse files Browse the repository at this point in the history

* Moved np.trapz -> scipy.integrate.trapezoid. Updated version requirements. Some linting.

* Bumped python version array in github action.

* Github action version numbers now strings.
  • Loading branch information
jfcrenshaw authored Feb 26, 2024
1 parent 9d37be4 commit f306c4b
Show file tree
Hide file tree
Showing 14 changed files with 2,463 additions and 1,048 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
fail-fast: true
matrix:
os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9]
python-version: ["3.9", "3.10", "3.11"]
runs-on: ${{ matrix.os }}

steps:
Expand All @@ -37,7 +37,7 @@ jobs:
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
cache: 'poetry'
cache: "poetry"

# ----------------------------------------------
# install root project, if required
Expand Down
3,404 changes: 2,411 additions & 993 deletions poetry.lock

Large diffs are not rendered by default.

60 changes: 30 additions & 30 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pzflow"
version = "3.1.2"
version = "3.1.3"
description = "Probabilistic modeling of tabular data with normalizing flows."
authors = ["John Franklin Crenshaw <[email protected]>"]
license = "MIT"
Expand All @@ -9,38 +9,38 @@ readme = "README.md"


[tool.poetry.dependencies]
python = "^3.8.1"
jax = "^0.4.2"
jaxlib = "^0.4.2"
optax = "^0.1.4"
python = "^3.9"
jax = ">=0.4.16"
jaxlib = ">=0.4.16"
optax = ">=0.1.4"
pandas = ">=1.1"
dill = "^0.3.6"
tqdm = "^4.64.1"
dill = ">=0.3.6"
tqdm = ">=4.64.1"

[tool.poetry.dev-dependencies]
pre-commit = "^3.0.3"
black = "^23.1.0"
mypy = "^0.931"
isort = "^5.12.0"
flake8 = "^6.0.0"
flake8-bugbear = "^23.1.20"
flake8-builtins = "^2.1.0"
flake8-comprehensions = "^3.10.1"
flake8-docstrings = "^1.7.0"
flake8-isort = "^6.0.0"
flake8-markdown = "^0.4.0"
flake8-print = "^5.0.0"
flake8-pytest-style = "^1.6.0"
flake8-simplify = "^0.19.3"
flake8-tidy-imports = "^4.8.0"
pandas-vet = "^0.2.3"
pep8-naming = "^0.13.3"
pytest = "^7.2.1"
pytest-cov = "^4.0.0"
pytest-xdist = "^3.1.0"
jupyter = "^1.0.0"
matplotlib = "^3.6.3"
toml = "^0.10.2"
pre-commit = ">=3.0.3"
black = ">=23.1.0"
mypy = ">=0.931"
isort = ">=5.12.0"
flake8 = ">=6.0.0"
flake8-bugbear = ">=23.1.20"
flake8-builtins = ">=2.1.0"
flake8-comprehensions = ">=3.10.1"
flake8-docstrings = ">=1.7.0"
flake8-isort = ">=6.0.0"
flake8-markdown = ">=0.4.0"
flake8-print = ">=5.0.0"
flake8-pytest-style = ">=1.6.0"
flake8-simplify = ">=0.19.3"
flake8-tidy-imports = ">=4.8.0"
pandas-vet = ">=0.2.3"
pep8-naming = ">=0.13.3"
pytest = ">=7.2.1"
pytest-cov = ">=4.0.0"
pytest-xdist = ">=3.1.0"
jupyter = ">=1.0.0"
matplotlib = ">=3.6.3"
toml = ">=0.10.2"

[build-system]
requires = ["poetry-core>=1.0.0"]
Expand Down
2 changes: 1 addition & 1 deletion pzflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
from pzflow.flow import Flow
from pzflow.flowEnsemble import FlowEnsemble

__version__ = "3.1.2"
__version__ = "3.1.3"
4 changes: 0 additions & 4 deletions pzflow/bijectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,6 @@ def Chain(

@InitFunction
def init_fun(rng, input_dim, **kwargs):

all_params, forward_funs, inverse_funs = [], [], []
for init_f in init_funs:
rng, layer_rng = random.split(rng)
Expand Down Expand Up @@ -237,7 +236,6 @@ def ColorTransform(

@InitFunction
def init_fun(rng, input_dim, **kwargs):

# array of all the indices
all_idx = jnp.arange(input_dim)
# indices for columns to stick at the front
Expand Down Expand Up @@ -462,7 +460,6 @@ def NeuralSplineCoupling(

@InitFunction
def init_fun(rng, input_dim, **kwargs):

if transformed_dim is None:
upper_dim = input_dim // 2 # variables that determine NN params
lower_dim = (
Expand Down Expand Up @@ -797,7 +794,6 @@ def Shuffle() -> Tuple[InitFunction, Bijector_Info]:

@InitFunction
def init_fun(rng, input_dim, **kwargs):

perm = random.permutation(rng, jnp.arange(input_dim))
inv_perm = jnp.argsort(perm)

Expand Down
2 changes: 1 addition & 1 deletion pzflow/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def sample(
)
return 2 * self.B * (samples - 0.5)


class CentBeta13(LatentDist):
"""A centered Beta distribution with alpha, beta = 13.
Expand Down Expand Up @@ -228,7 +229,6 @@ def sample(
return 2 * self.B * (samples - 0.5)



class Normal(LatentDist):
"""A multivariate Gaussian distribution with mean zero and unit variance.
Expand Down
4 changes: 3 additions & 1 deletion pzflow/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ def get_example_flow() -> Flow:
For more info: `print(example_flow().info)`.
"""
this_dir, _ = os.path.split(__file__)
flow_path = os.path.join(this_dir, f"{EXAMPLE_FILE_DIR}/example-flow.pzflow.pkl")
flow_path = os.path.join(
this_dir, f"{EXAMPLE_FILE_DIR}/example-flow.pzflow.pkl"
)
flow = Flow(file=flow_path)
return flow
3 changes: 2 additions & 1 deletion pzflow/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import optax
import pandas as pd
from jax import grad, jit, random
from jax.scipy.integrate import trapezoid
from tqdm import tqdm

from pzflow import distributions
Expand Down Expand Up @@ -731,7 +732,7 @@ def check_flags(data):

if normalize:
# normalize so they integrate to one
pdfs = pdfs / jnp.trapz(y=pdfs, x=grid).reshape(-1, 1)
pdfs = pdfs / trapezoid(y=pdfs, x=grid).reshape(-1, 1)
if nan_to_zero:
# set NaN's equal to zero probability
pdfs = jnp.nan_to_num(pdfs, nan=0.0)
Expand Down
6 changes: 4 additions & 2 deletions pzflow/flowEnsemble.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""Define FlowEnsemble object that holds an ensemble of normalizing flows."""

from typing import Any, Callable, Sequence, Tuple

import dill as pickle
import jax.numpy as jnp
import numpy as np
import pandas as pd
from jax import random
from jax.scipy.integrate import trapezoid

from pzflow import Flow, distributions
from pzflow.bijectors import Bijector_Info, InitFunction
Expand Down Expand Up @@ -336,7 +338,7 @@ def posterior(
# return the ensemble of posteriors
if normalize:
ensemble = ensemble.reshape(-1, grid.size)
ensemble = ensemble / jnp.trapz(y=ensemble, x=grid).reshape(
ensemble = ensemble / trapezoid(y=ensemble, x=grid).reshape(
-1, 1
)
ensemble = ensemble.reshape(inputs.shape[0], -1, grid.size)
Expand All @@ -345,7 +347,7 @@ def posterior(
# return mean over ensemble
pdfs = ensemble.mean(axis=1)
if normalize:
pdfs = pdfs / jnp.trapz(y=pdfs, x=grid).reshape(-1, 1)
pdfs = pdfs / trapezoid(y=pdfs, x=grid).reshape(-1, 1)
return pdfs

def sample(
Expand Down
4 changes: 1 addition & 3 deletions tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@
(Joint, (Normal(1), Uniform(1, 4)), ((), ())),
(Joint, (Normal(1), Tdist(1)), ((), jnp.log(30.0))),
(Joint, (Joint(Normal(1), Uniform(1)).info[1]), ((), ())),
(CentBeta13, (2, 4), ())
(CentBeta13, (2, 4), ()),
],
)
class TestDistributions:
def test_returns_correct_shapes(self, distribution, inputs, params):

dist = distribution(*inputs)

nsamples = 8
Expand All @@ -30,7 +29,6 @@ def test_returns_correct_shapes(self, distribution, inputs, params):
assert log_prob.shape == (nsamples,)

def test_control_sample_randomness(self, distribution, inputs, params):

dist = distribution(*inputs)

nsamples = 8
Expand Down
9 changes: 2 additions & 7 deletions tests/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pandas as pd
import pytest
from jax import random
from jax.scipy.integrate import trapezoid

from pzflow import Flow, FlowEnsemble
from pzflow.bijectors import Reverse, RollingSplineCoupling
Expand All @@ -17,7 +18,6 @@


def test_log_prob():

lpEns = flowEns.log_prob(x, returnEnsemble=True)
assert lpEns.shape == (3, 2)

Expand All @@ -36,7 +36,6 @@ def test_log_prob():


def test_posterior():

grid = jnp.linspace(-1, 1, 5)

pEns = flowEns.posterior(x, "x", grid, returnEnsemble=True)
Expand All @@ -53,12 +52,11 @@ def test_posterior():
p0 = flow0.posterior(x, "x", grid, normalize=False)
p1 = flow1.posterior(x, "x", grid, normalize=False)
manualMean = (p0 + p1) / 2
manualMean = manualMean / jnp.trapz(y=manualMean, x=grid).reshape(-1, 1)
manualMean = manualMean / trapezoid(y=manualMean, x=grid).reshape(-1, 1)
assert jnp.allclose(pEnsMean, manualMean)


def test_sample():

# first test everything with returnEnsemble=False
sEns = flowEns.sample(10, seed=0).values
assert sEns.shape == (10, 2)
Expand All @@ -81,7 +79,6 @@ def test_sample():


def test_conditional_sample():

cEns = FlowEnsemble(
("x", "y"),
RollingSplineCoupling(nlayers=2, n_conditions=2),
Expand Down Expand Up @@ -122,7 +119,6 @@ def test_conditional_sample():


def test_train():

data = random.normal(random.PRNGKey(0), shape=(100, 2))
data = pd.DataFrame(np.array(data), columns=("x", "y"))

Expand All @@ -138,7 +134,6 @@ def test_train():


def test_load_ensemble(tmp_path):

flowEns = FlowEnsemble(("x", "y"), RollingSplineCoupling(nlayers=2), N=2)

preSave = flowEns.sample(10, seed=0)
Expand Down
1 change: 1 addition & 0 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def test_get_city_data():
assert isinstance(data, pd.DataFrame)
assert data.shape == (47_966, 5)


def test_get_checkerboard_data():
data = examples.get_checkerboard_data()
assert isinstance(data, pd.DataFrame)
Expand Down
1 change: 0 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@


def test_build_bijector_from_info():

x = jnp.array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]])

init_fun, info1 = Chain(
Expand Down
7 changes: 5 additions & 2 deletions tests/test_version.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import toml
from pathlib import Path

import toml

import pzflow


def test_versions_are_in_sync():
"""Checks if the pyproject.toml and pzflow.__init__.py __version__ are in sync."""

Expand All @@ -10,5 +13,5 @@ def test_versions_are_in_sync():
pyproject_version = pyproject["tool"]["poetry"]["version"]

package_init_version = pzflow.__version__

assert package_init_version == pyproject_version

0 comments on commit f306c4b

Please sign in to comment.