Skip to content

Commit

Permalink
Merge branch 'benchmark_v1_workflow' of github.com:pawel-czyz/bmi int…
Browse files Browse the repository at this point in the history
…o benchmark_v1_workflow
  • Loading branch information
pawel-czyz committed Mar 5, 2024
2 parents 0d46b8d + 6701a9d commit 931f299
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 16 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
strategy:
fail-fast: true
matrix:
python-version: ["3.9", "3.10"]
python-version: ["3.11", "3.12"]
poetry-version: ["1.3.2"]

steps:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.9"
python-version: "3.10"
cache: "pip"
- name: Install Poetry
uses: snok/install-poetry@v1
Expand Down
7 changes: 2 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,14 @@ readme = "README.md"
packages = [{include = "bmi", from = "src"}]

[tool.poetry.dependencies]
# <3.11 because of PyType. Update when it's resolved
# <3.12 because of SciPy. Update when it's resolved
python = ">=3.9,<3.11"
python = ">=3.9,<3.13"
equinox = "^0.10.2"
jax = "^0.4.8"
jaxlib = "^0.4.7"
numpy = "^1.24.2"
scikit-learn = "^1.2.2"
optax = "^0.1.4"
# Pandas <2.0.0 to ensure compatibility. Update when we validate that it works
pandas = "<2.0.0"
pandas = "^1.5.3"
pydantic = "^1.10.7"
pyyaml = "^6.0"
scipy = "^1.10.1"
Expand Down
14 changes: 10 additions & 4 deletions workflows/benchmark/_common_benchmark_rules.smk
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,19 @@ TASKS_DICT = {
}


rule benchmark_table:
rule benchmark_tables:
input: 'results.csv'
output: 'benchmark.html'
output: 'benchmark.html', 'benchmark_converged_only.html', 'benchmark_convergance.html'
run:
results = utils.read_results(str(input))
table = utils.create_benchmark_table(results)
table.to_html(str(output))
table = utils.create_benchmark_table(results, converged_only=False)
table.to_html('benchmark.html')

table = utils.create_benchmark_table(results, converged_only=True)
table.to_html('benchmark_converged_only.html')

table = utils.create_convergance_table(results)
table.to_html('benchmark_convergance.html')


# Gather all results into one CSV file
Expand Down
81 changes: 76 additions & 5 deletions workflows/benchmark/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,35 @@ def read_results(
return results


def create_benchmark_table(results, n_samples=None):
def add_col_log_relative_error(results):
    """Append a ``log_relative_error`` column: log2(mi_estimate / mi_true).

    The frame is modified in place and also returned for chaining.
    """
    # Suppress divide/invalid warnings: mi_true may be 0 (ratio -> inf/NaN).
    with np.errstate(all="ignore"):
        ratio = results["mi_estimate"] / results["mi_true"]
        results["log_relative_error"] = np.log2(ratio)
    return results


def add_col_converged(results):
    """Append a boolean ``converged`` column.

    An estimate counts as converged when it is strictly greater than 10%
    of the largest estimate within its (estimator_id, task_id, n_samples)
    group. The frame is modified in place and also returned.
    """
    group_cols = ["estimator_id", "task_id", "n_samples"]
    group_max = results.groupby(group_cols)["mi_estimate"].transform("max")
    threshold = 0.1 * group_max
    results["converged"] = results["mi_estimate"] > threshold
    return results


def create_benchmark_table(results, n_samples=None, converged_only=False):
if n_samples is None:
n_samples = results["n_samples"].max()

data = results[results["n_samples"] == n_samples]
data = results[["estimator_id", "task_id", "mi_estimate", "mi_true"]].copy()
data = results[["estimator_id", "task_id", "n_samples", "mi_estimate", "mi_true"]].copy()

# relative_error
with np.errstate(all="ignore"):
data["log_relative_error"] = np.log2(data["mi_estimate"] / data["mi_true"])
data = add_col_log_relative_error(data)

# TODO(frdrc): filter out convergence failures
# convergence
if converged_only:
data = add_col_converged(data)
data = data[data["converged"]]

# mean over seeds
data = (
Expand Down Expand Up @@ -86,10 +103,64 @@ def make_pretty(styler):
gmap=table_err,
axis=None,
)
if converged_only:
styler.set_caption(
"Estimates smaller than 10% of the maximal estimate are excluded."
"This can help neural estimators which sometimes fail to converge."
)
return styler

table_pretty = table_mi.style.pipe(make_pretty)
table_pretty.index.name = None
table_pretty.columns.name = None

return table_pretty


def create_convergance_table(results, n_samples=None):
    """Build a styled table with the fraction of converged runs per task/estimator.

    Parameters
    ----------
    results:
        DataFrame with one row per run, containing at least the columns
        ``estimator_id``, ``task_id``, ``n_samples`` and ``mi_estimate``.
    n_samples:
        Sample size to report on; defaults to the largest value present.

    Returns
    -------
    A pandas Styler whose rows are task ids, columns are estimator ids,
    and values are the fraction of seeds whose estimate converged
    (as defined by ``add_col_converged``), rendered as percentages.
    """
    if n_samples is None:
        n_samples = results["n_samples"].max()

    # Keep only the requested sample size.
    # Bug fix: the filtered frame was previously discarded because the next
    # line re-selected columns from the unfiltered `results`.
    data = results[results["n_samples"] == n_samples]
    data = data[["estimator_id", "task_id", "n_samples", "mi_estimate"]].copy()

    data = add_col_converged(data)
    data["not_converged"] = ~data["converged"]

    # sum over seeds
    data = (
        data.groupby(["estimator_id", "task_id"])[["converged", "not_converged"]]
        .sum()
        .reset_index()
    )

    # compute fraction of converged runs; 0/0 yields NaN, rendered as "?" below
    with np.errstate(all="ignore"):
        data["converged_fraction"] = data["converged"] / (
            data["converged"] + data["not_converged"]
        )

    # create table of results
    table = data.pivot(index="task_id", columns="estimator_id", values="converged_fraction")

    def make_pretty(styler):
        styler.format(lambda x: f"{x:.1%}" if not pd.isna(x) else "?")
        styler.set_table_styles(
            [{"selector": "td", "props": "text-align: center; min-width: 5em;"}]
        )
        styler.background_gradient(
            vmin=0.0,
            vmax=1.0,
            cmap="gray",
        )
        # Bug fix: a space was missing between the two adjacent string
        # literals ("considered" + "converged" rendered as one word).
        styler.set_caption(
            "Estimates higher than 10% of the maximal estimate are considered "
            "converged. This table shows the percentage of converged estimates."
        )
        return styler

    table_pretty = table.style.pipe(make_pretty)
    table_pretty.index.name = None
    table_pretty.columns.name = None

    return table_pretty

0 comments on commit 931f299

Please sign in to comment.