Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace CSVs with chi2.ppf for GLRT test #456

Merged
merged 4 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 9 additions & 54 deletions src/dolphin/shp/_glrt.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
from __future__ import annotations

import csv
from functools import lru_cache
from math import log
from pathlib import Path
from typing import Optional

import numba
import numpy as np
from numpy.typing import ArrayLike
from scipy import stats

from dolphin._types import Strides
from dolphin.utils import _get_slices, compute_out_shape
Expand Down Expand Up @@ -78,7 +76,8 @@ def estimate_neighbors(
half_row, half_col = halfwin_rowcol
rows, cols = mean.shape

threshold = get_cutoff(alpha=alpha, N=nslc)
# 1 Degree of freedom, regardless of N
threshold = stats.chi2.ppf(1 - alpha, df=1)

strides_rowcol = (strides["y"], strides["x"])
out_rows, out_cols = compute_out_shape((rows, cols), Strides(*strides_rowcol))
Expand All @@ -88,6 +87,7 @@ def estimate_neighbors(
return _loop_over_pixels(
mean,
var,
nslc,
halfwin_rowcol,
strides_rowcol,
threshold,
Expand All @@ -96,47 +96,18 @@ def estimate_neighbors(
)


def get_cutoff(alpha: float, N: int) -> float:
    r"""Compute the upper cutoff for the GLRT test statistic.

    Statistic is

    \[
    2\log(\sigma_{pooled}) - \log(\sigma_{p}) - \log(\sigma_{q})
    \]

    The cutoff is not computed analytically: it is looked up in the table
    returned by `_read_cutoff_csv`, keyed on the exact pair ``(N, alpha)``.
    No interpolation is performed, so `alpha` must compare equal (as a
    float) to one of the tabulated significance levels.

    Parameters
    ----------
    alpha: float
        Significance level (0 < alpha < 1).
    N: int
        Number of samples.

    Returns
    -------
    float
        Cutoff value for the GLRT test statistic.

    Raises
    ------
    NotImplementedError
        If the table has no entry for the requested ``(N, alpha)`` pair.
        The original `KeyError` is chained as the cause.

    """
    n_alpha_to_cutoff = _read_cutoff_csv()
    try:
        return n_alpha_to_cutoff[(N, alpha)]
    except KeyError as e:
        # Surface a missing table entry as "unsupported configuration"
        # rather than leaking the lookup's KeyError to the caller.
        msg = f"Not implemented for {N = }, {alpha = }"
        raise NotImplementedError(msg) from e


# NOTE(review): this span shows two versions of the same function interleaved,
# as rendered by a diff view. The two-argument signature and the un-scaled
# return appear to be the removed side; the three-argument signature and the
# N-scaled return appear to be the added side -- confirm against the merged
# file before relying on either.
@numba.njit(nogil=True)
# Version taking the two scale values only.
def _compute_glrt_test_stat(scale_1, scale_2):
# Version taking the two scale values plus N, which multiplies the result.
def _compute_glrt_test_stat(scale_sq_1, scale_sq_2, N):
"""Compute the GLRT test statistic."""
# Two-argument body: pooled value is the arithmetic mean of the two inputs;
# the statistic is 2*log(pooled) - log(first) - log(second).
scale_pooled = (scale_1 + scale_2) / 2
return 2 * log(scale_pooled) - log(scale_1) - log(scale_2)
# Three-argument body: same pooled mean and log-difference, multiplied by N.
# The caller compares the result with a chi-squared quantile (df=1) threshold.
scale_pooled = (scale_sq_1 + scale_sq_2) / 2
return N * (2 * log(scale_pooled) - log(scale_sq_1) - log(scale_sq_2))


@numba.njit(nogil=True, parallel=True)
def _loop_over_pixels(
mean: ArrayLike,
var: ArrayLike,
N: int,
halfwin_rowcol: tuple[int, int],
strides_rowcol: tuple[int, int],
threshold: float,
Expand Down Expand Up @@ -180,27 +151,11 @@ def _loop_over_pixels(
continue
scale_2 = scale_squared[in_r2, in_c2]

T = _compute_glrt_test_stat(scale_1, scale_2)
T = _compute_glrt_test_stat(scale_1, scale_2, N)

is_shp[out_r, out_c, r_off, c_off] = threshold > T
if prune_disconnected:
# For this pixel, prune the groups not connected to the center
remove_unconnected(is_shp[out_r, out_c], inplace=True)

return is_shp


@lru_cache
def _read_cutoff_csv():
    """Load the GLRT cutoff table stored next to this module.

    Reads ``glrt_cutoffs.csv`` from the directory containing this file. The
    CSV must have a header row with the columns ``N``, ``alpha`` and
    ``cutoff``.

    The result is memoized with `lru_cache`, so the file is parsed at most
    once per process and every caller receives the same dictionary object;
    callers should treat it as read-only.

    Returns
    -------
    dict
        Mapping of ``(N, alpha)`` -> ``cutoff``, where ``N`` is parsed as
        `int` and ``alpha`` and ``cutoff`` are parsed as `float`.

    Raises
    ------
    FileNotFoundError
        If ``glrt_cutoffs.csv`` does not exist beside this module.
    KeyError
        If a required column is missing from the CSV.
    ValueError
        If a field cannot be parsed as a number.

    """
    filename = Path(__file__).parent / "glrt_cutoffs.csv"

    # `newline=""` lets the csv module handle line endings itself, and the
    # explicit encoding avoids depending on the platform's default locale.
    with open(filename, newline="", encoding="utf-8") as file:
        return {
            (int(row["N"]), float(row["alpha"])): float(row["cutoff"])
            for row in csv.DictReader(file)
        }
Loading