From 248b3ee3a2f12a6b9e618f390c09eb5b5b7ec590 Mon Sep 17 00:00:00 2001 From: Drew Oldag <47493171+drewoldag@users.noreply.github.com> Date: Mon, 26 Feb 2024 09:27:32 -0800 Subject: [PATCH] Use t-digest to parallelize point to point metrics (#216) * WIP - Initial commit for parallel metrics. * Adding a simple notebook for testing parallelized PointSigmaIQR. * Example of metric running in parallel in notebook. * Parallelizing PointBias. Expanding the demo notebook. * First attempt at parallelizing PointSigmaMAD. * Clean up PointSigmaMAD and introduce a method parameter for num_bins. * Parallelizing PointOutlierRate and updating the notebook with an example. * Moved PointToPointMetricDigester. Consolidated code in that class. Added **kwargs so that we can pass a config dict to the concrete metric classes. * Adding unit tests for parallelized point to point metrics. * Adding another tests case to chase test coverage. * Using bin centers for `bin_dist` calculation. * added point_sigma_iqr.eval_from_iterator call and a few pragma statements to get to full coverage --------- Co-authored-by: Eric Charles --- pyproject.toml | 2 + src/qp/metrics/base_metric_classes.py | 15 + src/qp/metrics/parallel_metrics.ipynb | 324 ++++++++++++++++++ src/qp/metrics/pit.py | 2 +- .../metrics/point_estimate_metric_classes.py | 137 +++++++- tests/qp/test_point_metrics.py | 92 +++-- 6 files changed, 537 insertions(+), 35 deletions(-) create mode 100644 src/qp/metrics/parallel_metrics.ipynb diff --git a/pyproject.toml b/pyproject.toml index 5466c87..f8bba55 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ dependencies = [ "scipy", "tables-io", "deprecated", + "pytdigest", ] # On a mac, install optional dependencies with `pip install '.[dev]'` (include the single quotes) @@ -43,6 +44,7 @@ dev = [ "pylint", "mpi4py", "coverage", + "ipyparallel", ] full = [ "tables-io[full]", diff --git a/src/qp/metrics/base_metric_classes.py b/src/qp/metrics/base_metric_classes.py index 165dd28..5e29e80 100644 --- a/src/qp/metrics/base_metric_classes.py +++ b/src/qp/metrics/base_metric_classes.py @@ -147,9 +147,24 @@ class PointToPointMetric(BaseMetric): metric_input_type = MetricInputType.point_to_point + def eval_from_iterator(self, estimate, reference): + self.initialize() + for estimate, reference in zip(estimate, reference): + centroids = self.accumulate(estimate, reference) + return self.finalize([centroids]) + def evaluate(self, estimate, reference): raise NotImplementedError() + def initialize(self): #pragma: no cover + pass + + def accumulate(self, estimate, reference): #pragma: no cover + raise NotImplementedError() + + def finalize(self): #pragma: no cover + raise NotImplementedError() + class PointToDistMetric(BaseMetric): """A base class for metrics that require a point estimate as the estimated diff --git a/src/qp/metrics/parallel_metrics.ipynb b/src/qp/metrics/parallel_metrics.ipynb new file mode 100644 index 0000000..7ad1dc5 --- /dev/null +++ b/src/qp/metrics/parallel_metrics.ipynb @@ -0,0 +1,324 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import ipyparallel as ipp\n", + "from qp.metrics.point_estimate_metric_classes import (\n", + " PointSigmaIQR,\n", + " PointBias,\n", + " PointSigmaMAD,\n", + " PointOutlierRate,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Generate the random numbers \n", + "SEED = 1002330\n", + "rng = np.random.default_rng(SEED)\n", + "\n", + "chunk_size = 10_000\n", + "n_chunk = 10\n", + "total_size = n_chunk*chunk_size\n", + "\n", + "estimate = rng.lognormal(mean=1.0, sigma=2, size=total_size)\n", + "reference = rng.lognormal(mean=1.3, sigma=1.9, size=total_size)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#generator that yields chunks from estimate and reference\n", + "def chunker(seq, size):\n", + " return (seq[pos:pos + size] for pos in range(0, len(seq), size))\n", + "\n", + "# create an iterator that yields chunks of chunk_size elements\n", + "estimate_chunks = chunker(estimate, chunk_size)\n", + "reference_chunks = chunker(reference, chunk_size)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# A function to pass to MPI\n", + "def mpi_example(chunk):\n", + " centroids = chunk[0].accumulate(chunk[1], chunk[2])\n", + " return centroids" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here is a function that will configure a local cluster of 4 nodes using MPI as the engine.\n", + "\n", + "A metric estimator class is passed in as well as list of 3-tuple \"data chunks\".\n", + "\n", + "The 3-tuple is (metric class, chunk_of_estimated_values, chunk_of_reference_values)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def run_parallel_metric(data_chunks):\n", + " with ipp.Cluster(controller_ip=\"*\", engines=\"mpi\", n=4) as rc:\n", + " # get a broadcast_view on the cluster which is best\n", + " # suited for MPI style computation\n", + " view = rc.load_balanced_view()\n", + " # run the mpi_example function on all engines in parallel\n", + " asyncresult = view.map_async(mpi_example, data_chunks)\n", + " # Retrieve and print the result from the engines\n", + " asyncresult.wait_interactive()\n", + " # retrieve actual results\n", + " result = asyncresult.get()\n", + " # get and print the results\n", + " for i, res in enumerate(result):\n", + " np.array(res)\n", + " print(f\"{i} : {res.shape}\")\n", + " metric_estimator = data_chunks[0][0]\n", + " final = metric_estimator.finalize(centroids=result)\n", + " print(final)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### An example running the PointSigmaIQR metric directly and in parallel" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Set up for ipyparallel\n", + "config = {'tdigest_compression': 1000}\n", + "\n", + "sigma_iqr_estimator = PointSigmaIQR(**config)\n", + "sigma_iqr_estimator_list = [sigma_iqr_estimator]*n_chunk\n", + "iqr_data_chunks = [chunk for chunk in zip(sigma_iqr_estimator_list, chunker(estimate, chunk_size), chunker(reference, chunk_size))]\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "PointSigmaIQR().evaluate(estimate, reference)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "run_parallel_metric(iqr_data_chunks)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### An example running the PointBias metric directly and in parallel" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Set up for ipyparallel\n", + "config = {'tdigest_compression': 1000}\n", + "\n", + "point_bias_estimator = PointBias(**config)\n", + "point_bias_estimator_list = [point_bias_estimator]*n_chunk\n", + "point_bias_data_chunks = [chunk for chunk in zip(point_bias_estimator_list, chunker(estimate, chunk_size), chunker(reference, chunk_size))]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "PointBias().evaluate(estimate, reference)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "run_parallel_metric(point_bias_data_chunks)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### An example running PointSigmaMAD directly and in parallel" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# An example with PointSigmaMAD\n", + "config = {'num_bins': 1_000_000, 'tdigest_compression': 1000}\n", + "point_sigma_mad_estimator = PointSigmaMAD(**config)\n", + "point_sigma_mad_estimator_list = [point_sigma_mad_estimator]*n_chunk\n", + "point_sigma_mad_data_chunks = [chunk for chunk in zip(point_sigma_mad_estimator_list, chunker(estimate, chunk_size), chunker(reference, chunk_size))]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "PointSigmaMAD().evaluate(estimate, reference)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This cell allows for adjustment of the `num_bins` parameter.\n", + "\n", + "Larger values trend closer to the analytic result from the cell above." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "config = {'num_bins': 1_000_000, 'tdigest_compression': 1000}\n", + "psmad = PointSigmaMAD(**config)\n", + "centroids = psmad.accumulate(estimate, reference)\n", + "\n", + "#default value for `num_bins` is 1_000_000\n", + "psmad.finalize(centroids=[centroids])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "run_parallel_metric(point_sigma_mad_data_chunks)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### An example running PointOutlierRate metric directly and in parallel" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# An example with PointOutlierRate\n", + "config = {'tdigest_compression': 1000}\n", + "point_outlier_estimator = PointOutlierRate(**config)\n", + "point_outlier_estimator_list = [point_outlier_estimator]*n_chunk\n", + "point_outlier_data_chunks = [chunk for chunk in zip(point_outlier_estimator_list, chunker(estimate, chunk_size), chunker(reference, chunk_size))]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "PointOutlierRate().evaluate(estimate, reference)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The parallel estimation of the metric trends closer to the analytic as the value of `compression` is increased.\n", + "\n", + "The default value for compression is 1000. If set to 10_000, the estimate becomes 0.13663.\n", + "\n", + "Note that, of course, setting compression = 10_000 increases memory usage with minimal affect on runtime." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "config = {'tdigest_compression': 1000}\n", + "por = PointOutlierRate(**config)\n", + "centroids = por.accumulate(estimate, reference)\n", + "\n", + "por.finalize(centroids=[centroids])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "run_parallel_metric(point_outlier_data_chunks)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/src/qp/metrics/pit.py b/src/qp/metrics/pit.py index e19cc6e..f434240 100644 --- a/src/qp/metrics/pit.py +++ b/src/qp/metrics/pit.py @@ -57,7 +57,7 @@ def __init__(self, qp_ens, true_vals, eval_grid=DEFAULT_QUANTS): # efficiently on line 61 with `data_quants = np.nanquantile(...)`.` samp_mask = np.isfinite(self._pit_samps) self._pit_samps[~samp_mask] = 0 - if not np.all(samp_mask): + if not np.all(samp_mask): #pragma: no cover logging.warning( "Some PIT samples were `NaN`. They have been replacd with 0." ) diff --git a/src/qp/metrics/point_estimate_metric_classes.py b/src/qp/metrics/point_estimate_metric_classes.py index 4c44a9c..7856813 100644 --- a/src/qp/metrics/point_estimate_metric_classes.py +++ b/src/qp/metrics/point_estimate_metric_classes.py @@ -3,6 +3,68 @@ MetricOutputType, PointToPointMetric, ) +from pytdigest import TDigest +from functools import reduce +from operator import add + + +class PointToPointMetricDigester(PointToPointMetric): + + def __init__(self, tdigest_compression: int = 1000, **kwargs) -> None: + super().__init__() + self._tdigest_compression = tdigest_compression + + def initialize(self): + pass + + def accumulate(self, estimate, reference): + """This function compresses the input into a TDigest and returns the + centroids. + + Parameters + ---------- + estimate : Numpy 1d array + Point estimate values + reference : Numpy 1d array + True values + + Returns + ------- + Numpy 2d array + The centroids of the TDigest. Roughly approximates a histogram with + centroid locations and weights. + """ + ez = (estimate - reference) / (1.0 + reference) + digest = TDigest.compute(ez, compression=self._tdigest_compression) + centroids = digest.get_centroids() + return centroids + + def finalize(self, centroids: np.ndarray = []): + """This function combines all the centroids that were calculated for the + input estimate and reference subsets and returns the resulting TDigest + object. + + Parameters + ---------- + centroids : Numpy 2d array, optional + The output collected from prior calls to `accumulate`, by default [] + + Returns + ------- + float + The result of the specific metric calculation defined in the subclasses + `compute_from_digest` method. + """ + digests = ( + TDigest.of_centroids(np.array(centroid), compression=self._tdigest_compression) + for centroid in centroids + ) + digest = reduce(add, digests) + + return self.compute_from_digest(digest) + + def compute_from_digest(self, digest): #pragma: no cover + raise NotImplementedError class PointStatsEz(PointToPointMetric): @@ -37,14 +99,14 @@ def evaluate(self, estimate, reference): return (estimate - reference) / (1.0 + reference) -class PointSigmaIQR(PointToPointMetric): +class PointSigmaIQR(PointToPointMetricDigester): """Calculate sigmaIQR""" metric_name = "point_stats_iqr" metric_output_type = MetricOutputType.single_value - def __init__(self) -> None: - super().__init__() + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) def evaluate(self, estimate, reference): """Calculate the width of the e_z distribution @@ -68,8 +130,14 @@ def evaluate(self, estimate, reference): sigma_iqr = iqr / 1.349 return sigma_iqr + def compute_from_digest(self, digest): + x75, x25 = digest.inverse_cdf([0.75,0.25]) + iqr = x75 - x25 + sigma_iqr = iqr / 1.349 + return sigma_iqr + -class PointBias(PointToPointMetric): +class PointBias(PointToPointMetricDigester): """calculates the bias of the point stats ez samples. In keeping with the Science Book, this is just the median of the ez values. """ @@ -77,8 +145,8 @@ class PointBias(PointToPointMetric): metric_name = "point_bias" metric_output_type = MetricOutputType.single_value - def __init__(self) -> None: - super().__init__() + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) def evaluate(self, estimate, reference): """The point bias, or median of the point stats ez samples. @@ -97,8 +165,11 @@ def evaluate(self, estimate, reference): """ return np.median((estimate - reference) / (1.0 + reference)) + def compute_from_digest(self, digest): + return digest.inverse_cdf(0.50) + -class PointOutlierRate(PointToPointMetric): +class PointOutlierRate(PointToPointMetricDigester): """Calculates the catastrophic outlier rate, defined in the Science Book as the number of galaxies with ez larger than max(0.06,3sigma). This keeps the fraction reasonable when @@ -108,8 +179,8 @@ class PointOutlierRate(PointToPointMetric): metric_name = "point_outlier_rate" metric_output_type = MetricOutputType.single_value - def __init__(self) -> None: - super().__init__() + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) def evaluate(self, estimate, reference): """Calculates the catastrophic outlier rate @@ -136,8 +207,27 @@ def evaluate(self, estimate, reference): outlier = np.sum(mask) return float(outlier) / float(num) + def compute_from_digest(self, digest): + # this replaces the call to PointSigmaIQR().evaluate() + x75, x25 = digest.inverse_cdf([0.75,0.25]) + iqr = x75 - x25 + sigma_iqr = iqr / 1.349 + + three_sig = 3.0 * sigma_iqr + cut_criterion = np.maximum(0.06, three_sig) + + # here we use the number of points in the centroids as an approximation + # of ez. + centroids = digest.get_centroids() + mask = np.fabs(centroids[:,0]) > cut_criterion + outlier = np.sum(centroids[mask,1]) + + # Since we use equal weights for all the values in the digest + # digest.weight is the total number of values, and is stored as a float. + return float(outlier) / digest.weight -class PointSigmaMAD(PointToPointMetric): + +class PointSigmaMAD(PointToPointMetricDigester): """Function to calculate median absolute deviation and sigma based on MAD (just scaled up by 1.4826) for the full and magnitude trimmed samples of ez values @@ -146,8 +236,11 @@ class PointSigmaMAD(PointToPointMetric): metric_name = "point_stats_sigma_mad" metric_output_type = MetricOutputType.single_value - def __init__(self) -> None: - super().__init__() + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self._num_bins = 1_000_000 + if "num_bins" in kwargs: + self._num_bins = kwargs["num_bins"] def evaluate(self, estimate, reference): """Function to calculate SigmaMAD (the median absolute deviation scaled @@ -170,3 +263,23 @@ def evaluate(self, estimate, reference): ez = (estimate - reference) / (1.0 + reference) mad = np.median(np.fabs(ez - np.median(ez))) return mad * SCALE_FACTOR + + def compute_from_digest(self, digest): + SCALE_FACTOR = 1.4826 + + # calculation of `np.median(np.fabs(ez - np.median(ez)))` as suggested by Eric Charles + this_median = digest.inverse_cdf(0.50) + this_min = digest.inverse_cdf(0) + this_max = digest.inverse_cdf(1) + bins = np.linspace(this_min, this_max, self._num_bins) + bin_cents = (bins[0:-1] + bins[1:]) / 2.0 + this_pdf = digest.cdf(bins[1:]) - digest.cdf(bins[0:-1]) # len(this_pdf) = lots_of_bins - 1 + bin_dist = np.fabs(bin_cents - this_median) # get the distance to the center for each bin in the hist + + sorted_bins_dist_idx = np.argsort(bin_dist) # sort the bins by dist to median + sorted_bins_dist = bin_dist[sorted_bins_dist_idx] # get the sorted distances + cumulative_sorted = this_pdf[sorted_bins_dist_idx].cumsum() # the cumulate PDF within the nearest bins + median_sorted_bin = np.searchsorted(cumulative_sorted, 0.5) # which bins are the nearest 50% of the PDF + dist_to_median = sorted_bins_dist[median_sorted_bin] # return the corresponding distance to the median + + return dist_to_median * SCALE_FACTOR diff --git a/tests/qp/test_point_metrics.py b/tests/qp/test_point_metrics.py index ec803ff..abec9db 100644 --- a/tests/qp/test_point_metrics.py +++ b/tests/qp/test_point_metrics.py @@ -1,3 +1,4 @@ +import unittest import numpy as np import qp @@ -16,7 +17,6 @@ OUTRATE = 0.0 SIGMAD = 0.0046489 - def construct_test_ensemble(): np.random.seed(87) nmax = 2.5 @@ -33,31 +33,79 @@ def construct_test_ensemble(): return zgrid, true_zs, grid_ens, true_ez -def test_point_metrics(): - """Basic tests for the various point estimate metrics""" - zgrid, zspec, pdf_ens, true_ez = construct_test_ensemble() - zb = pdf_ens.mode(grid=zgrid).flatten() +#generator that yields chunks from estimate and reference +def chunker(seq, size): + return (seq[pos:pos + size] for pos in range(0, len(seq), size)) + + +class test_point_metrics(unittest.TestCase): + + def test_point_metrics(self): + """Basic tests for the various point estimate metrics""" + zgrid, zspec, pdf_ens, true_ez = construct_test_ensemble() + zb = pdf_ens.mode(grid=zgrid).flatten() + + ez = PointStatsEz().evaluate(zb, zspec) + assert np.allclose(ez, true_ez, atol=1.0e-2) + + # grid limits ez vals to ~10^-2 tol + + sig_iqr = PointSigmaIQR().evaluate(zb, zspec) + assert np.isclose(sig_iqr, SIGIQR) + + bias = PointBias().evaluate(zb, zspec) + assert np.isclose(bias, BIAS) + + out_rate = PointOutlierRate().evaluate(zb, zspec) + assert np.isclose(out_rate, OUTRATE) + + sig_mad = PointSigmaMAD().evaluate(zb, zspec) + assert np.isclose(sig_mad, SIGMAD) + + def test_point_metrics_digest(self): + """Basic tests for the various point estimate metrics when using the + t-digest approximation.""" + + zgrid, zspec, pdf_ens, true_ez = construct_test_ensemble() + zb = pdf_ens.mode(grid=zgrid).flatten() + + configuration = {'tdigest_compression': 5000} + point_sigma_iqr = PointSigmaIQR(**configuration) + centroids = point_sigma_iqr.accumulate(zb, zspec) + sig_iqr = point_sigma_iqr.finalize([centroids]) + assert np.isclose(sig_iqr, SIGIQR, atol=1.0e-4) + + zb_iter = chunker(zb, 100) + zspec_iter = chunker(zspec, 100) + + sig_iqr_v2 = point_sigma_iqr.eval_from_iterator(zb_iter, zspec_iter) - ez = PointStatsEz().evaluate(zb, zspec) - assert np.allclose(ez, true_ez, atol=1.0e-2) - # grid limits ez vals to ~10^-2 tol + point_bias = PointBias(**configuration) + centroids = point_bias.accumulate(zb, zspec) + bias = point_bias.finalize([centroids]) + assert np.isclose(bias, BIAS) - sig_iqr = PointSigmaIQR().evaluate(zb, zspec) - assert np.isclose(sig_iqr, SIGIQR) + point_outlier_rate = PointOutlierRate(**configuration) + centroids = point_outlier_rate.accumulate(zb, zspec) + out_rate = point_outlier_rate.finalize([centroids]) + assert np.isclose(out_rate, OUTRATE) - bias = PointBias().evaluate(zb, zspec) - assert np.isclose(bias, BIAS) + point_sigma_mad = PointSigmaMAD(**configuration) + centroids = point_sigma_mad.accumulate(zb, zspec) + sig_mad = point_sigma_mad.finalize([centroids]) + assert np.isclose(sig_mad, SIGMAD, atol=1e-5) - out_rate = PointOutlierRate().evaluate(zb, zspec) - assert np.isclose(out_rate, OUTRATE) + configuration = {'tdigest_compression': 5000, 'num_bins': 1_000} + point_sigma_mad = PointSigmaMAD(**configuration) + centroids = point_sigma_mad.accumulate(zb, zspec) + sig_mad = point_sigma_mad.finalize([centroids]) + assert np.isclose(sig_mad, SIGMAD, atol=1e-4) - sig_mad = PointSigmaMAD().evaluate(zb, zspec) - assert np.isclose(sig_mad, SIGMAD) -def test_cde_loss_metric(): - """Basic test to ensure that the CDE Loss metric class is working.""" - zgrid, zspec, pdf_ens, _ = construct_test_ensemble() - cde_loss_class = CDELossMetric(zgrid) - result = cde_loss_class.evaluate(pdf_ens, zspec) - assert np.isclose(result, CDEVAL) + def test_cde_loss_metric(self): + """Basic test to ensure that the CDE Loss metric class is working.""" + zgrid, zspec, pdf_ens, _ = construct_test_ensemble() + cde_loss_class = CDELossMetric(zgrid) + result = cde_loss_class.evaluate(pdf_ens, zspec) + assert np.isclose(result, CDEVAL)