
Commit

Refactor CMAP into RankBasedMetrics and add generalized mean rank metric

PiperOrigin-RevId: 520093902
vdumoulin authored and copybara-github committed Mar 28, 2023
1 parent 8af934a commit 7412704
Showing 7 changed files with 130 additions and 41 deletions.
45 changes: 45 additions & 0 deletions chirp/models/metrics.py
@@ -185,6 +185,51 @@ def average_precision(
return mask * raw_av_prec


def generalized_mean_rank(
scores: jnp.ndarray,
labels: jnp.ndarray,
label_mask: jnp.ndarray | None = None,
sort_descending: bool = True,
) -> tuple[jnp.ndarray, jnp.ndarray]:
"""Computes the generalized mean rank and its variance over the last axis.
The generalized mean rank can be expressed as
(sum_i #FP ranked above TP_i) / (#FP * #TP).
We treat all labels as either true positives (if the label is 1) or false
positives (if the label is zero).
Args:
scores: A score for each label which can be ranked.
labels: A multi-hot encoding of the ground truth positives. Must match the
shape of scores.
label_mask: A mask indicating which labels to involve in the calculation.
sort_descending: Whether to sort the scores in descending order before
computing ranks (e.g. for evaluating similarity metrics where higher scores
are preferred). If False, the generalized mean rank is computed on
ascendingly sorted inputs.
Returns:
The generalized mean rank and its variance.
"""
# TODO(vdumoulin): add support for `label_mask`.
if label_mask is not None:
raise NotImplementedError

idx = jnp.argsort(scores, axis=-1)
if sort_descending:
idx = jnp.flip(idx, axis=-1)
labels = jnp.take_along_axis(labels, idx, axis=-1)

num_fp = (labels == 0).sum(axis=-1)
num_fp_above = jnp.cumsum(labels == 0, axis=-1)

gmr = num_fp_above.mean(axis=-1, where=(labels > 0)) / num_fp
gmr_var = num_fp_above.var(axis=-1, where=(labels > 0)) / num_fp
return gmr, gmr_var


def least_squares_solve_mix(matrix, rhs, diag_loading=1e-3):
# Assumes a real-valued matrix, with zero mean.
adj_matrix = jnp.conjugate(jnp.swapaxes(matrix, -1, -2))
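For orientation, a minimal sketch (not part of the commit) of how the new generalized_mean_rank helper can be exercised on its own; the inputs below are made up for illustration and follow the definition above.

import jax.numpy as jnp
from chirp.models import metrics

# One "class" with four examples; a higher score means a higher-ranked result.
scores = jnp.array([[0.9, 0.1, 0.8, 0.3]])
labels = jnp.array([[0, 0, 1, 1]])  # two true positives, two false positives
gmr, gmr_var = metrics.generalized_mean_rank(scores, labels)
# Descending order of scores: 0.9 (FP), 0.8 (TP), 0.3 (TP), 0.1 (FP).
# One FP is ranked above each TP, so GMR = (1 + 1) / (2 * 2) = 0.5.
print(gmr)  # [0.5]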
36 changes: 27 additions & 9 deletions chirp/models/cmap.py → chirp/models/rank_based_metrics.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Metric for Class Mean Average Precision (CMAP)."""
"""Rank-based metrics, including cmAP and generalized mean rank."""
from chirp.models import metrics
from clu import metrics as clu_metrics
import flax.struct
@@ -22,16 +22,19 @@


@flax.struct.dataclass
class CMAP(
class RankBasedMetrics(
# TODO(bartvm): Create class factory for calculating over different outputs
clu_metrics.CollectingMetric.from_outputs(("label", "label_logits"))
):
"""(Class-wise) mean average precision.
"""(Class-wise) rank-based metrics.
This metric calculates the average precision score of each class, and also
returns the average of those values. This is sometimes referred to as the
macro-averaged average precision, or the class-wise mean average precision
(CmAP).
It also calculates the generalized mean rank for each class and aggregates
those values as one minus the geometric mean of (1 - GMR) across classes.
"""

def compute(self, sample_threshold: int = 0):
@@ -40,28 +40,43 @@ def compute(self, sample_threshold: int = 0):
with jax.default_device(jax.devices("cpu")[0]):
mask = jnp.sum(values["label"], axis=0) > sample_threshold
if jnp.sum(mask) == 0:
return {"macro": 0.0}
return {"macro_cmap": 0.0, "macro_gmr": 0.0}
# Same as sklearn's average_precision_score(label, logits, average=None)
# but that implementation doesn't scale to 10k+ classes
class_aps = metrics.average_precision(
values["label_logits"].T, values["label"].T
)
class_aps = jnp.where(mask, class_aps, jnp.nan)

class_gmr, class_gmr_var = metrics.generalized_mean_rank(
values["label_logits"].T, values["label"].T
)
class_gmr = jnp.where(mask, class_gmr, jnp.nan)
class_gmr_var = jnp.where(mask, class_gmr_var, jnp.nan)

return {
"macro": jnp.mean(class_aps, where=mask),
"individual": class_aps,
"macro_cmap": jnp.mean(class_aps, where=mask),
"individual_cmap": class_aps,
# If the GMR is 0.0 for at least one class, then the geometric average
# goes to zero. Instead, we take the geometric average of 1 - GMR and
# then take 1 - geometric_average.
"macro_gmr": 1.0 - jnp.exp(
jnp.mean(jnp.log(1.0 - class_gmr), where=mask)
),
"individual_gmr": class_gmr,
"individual_gmr_var": class_gmr_var,
}


def add_cmap_to_metrics_collection(name: str, metrics_collection):
"""Adds a CMAP instance to an existing CLU metrics collection."""
def add_rank_based_metrics_to_metrics_collection(name: str, metrics_collection):
"""Adds a RankBasedMetric instance to an existing CLU metrics collection."""
new_collection = flax.struct.dataclass(
type(
"_ValidCollection",
(metrics_collection,),
{
"__annotations__": {
f"{name}_cmap": CMAP,
f"{name}_rank_based": RankBasedMetrics,
**metrics_collection.__annotations__,
}
},
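A rough usage sketch (not part of the commit), mirroring the tests further down: compute() on the renamed metric now reports both cmAP and GMR aggregates under the new key names. The example logits and labels are made up for illustration.

import jax.numpy as jnp
from chirp.models import rank_based_metrics

logits = jnp.array([[0.9, 0.2], [0.6, 0.8], [0.3, 0.5], [0.7, 0.1]])
labels = jnp.array([[1, 0], [0, 1], [1, 1], [0, 0]])
results = rank_based_metrics.RankBasedMetrics.from_model_output(
    label_logits=logits, label=labels
).compute()
# Available keys: macro_cmap, individual_cmap, macro_gmr, individual_gmr,
# individual_gmr_var (previously just "macro" and "individual").
print(results["macro_cmap"], results["macro_gmr"])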
6 changes: 3 additions & 3 deletions chirp/projects/multicluster/classify.py
Expand Up @@ -17,7 +17,7 @@

import dataclasses

from chirp.models import cmap
from chirp.models import rank_based_metrics
from chirp.projects.multicluster import data_lib
import tensorflow as tf

@@ -83,8 +83,8 @@ def train_embedding_model(
# Manually compute per-class mAP and CmAP scores.
test_logits = model.predict(test_ds, verbose=0)
test_labels = merged.data['label_hot'][test_locs]
maps = cmap.CMAP.from_model_output(
maps = rank_based_metrics.RankBasedMetrics.from_model_output(
label_logits=test_logits, label=test_labels
).compute()
cmap_value = maps.pop('macro')
cmap_value = maps.pop('macro_cmap')
return ClassifierMetrics(acc, auc_roc, recall, cmap_value, maps)
52 changes: 40 additions & 12 deletions chirp/tests/metrics_test.py
@@ -17,9 +17,9 @@

import functools
import os
from chirp.models import cmap
from chirp.models import cwt
from chirp.models import metrics
from chirp.models import rank_based_metrics
from clu import metrics as clu_metrics
import flax
import jax
@@ -150,33 +150,61 @@ def test_cmap(self):
[0, 0, 0],
[1, 0, 0],
])
full_cmap_value = cmap.CMAP.from_model_output(
full_cmap_value = rank_based_metrics.RankBasedMetrics.from_model_output(
label=labels, label_logits=scores
).compute()["macro"]
).compute()["macro_cmap"]
# Check against the manually verified outcome.
self.assertAlmostEqual(full_cmap_value, 0.49687502)

batched_cmap_metric = cmap.CMAP.empty()
batched_cmap_metric = rank_based_metrics.RankBasedMetrics.empty()
batched_cmap_metric = batched_cmap_metric.merge(
cmap.CMAP.from_model_output(label_logits=scores[:5], label=labels[:5])
rank_based_metrics.RankBasedMetrics.from_model_output(
label_logits=scores[:5], label=labels[:5]
)
)
batched_cmap_metric = batched_cmap_metric.merge(
cmap.CMAP.from_model_output(label_logits=scores[5:], label=labels[5:])
rank_based_metrics.RankBasedMetrics.from_model_output(
label_logits=scores[5:], label=labels[5:]
)
)
batched_cmap_value = batched_cmap_metric.compute()["macro"]
batched_cmap_value = batched_cmap_metric.compute()["macro_cmap"]
self.assertEqual(batched_cmap_value, full_cmap_value)

# Check that when setting a threshold to 3, the cmap is only computed
# taking into account column 1 (the only one with >3 samples).
self.assertEqual(
cmap.CMAP.from_model_output(label_logits=scores, label=labels).compute(
sample_threshold=3
)["macro"],
cmap.CMAP.from_model_output(
rank_based_metrics.RankBasedMetrics.from_model_output(
label_logits=scores, label=labels
).compute(sample_threshold=3)["macro_cmap"],
rank_based_metrics.RankBasedMetrics.from_model_output(
label_logits=scores[:, 1:2], label=labels[:, 1:2]
).compute()["macro"],
).compute()["macro_cmap"],
)

def test_gmr(self):
# The following example was worked out manually and verified.
scores = jnp.array([[0.9, 0.2, 0.3, 0.6, 0.5, 0.7, 0.1, 0.4, 0.8]]).T
labels = jnp.array([[0, 0, 0, 0, 0, 1, 0, 1, 1]]).T
full_gmr_value = rank_based_metrics.RankBasedMetrics.from_model_output(
label=labels, label_logits=scores
).compute()["macro_gmr"]
# Check against the manually verified outcome.
self.assertAlmostEqual(full_gmr_value, 5.0 / 18.0)
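# (Derivation, spelled out: sorted by descending score the labels become
# [0, 1, 1, 0, 0, 1, 0, 0, 0], so the numbers of false positives ranked
# above the three true positives are 1, 1 and 3, giving
# GMR = (1 + 1 + 3) / (3 * 6) = 5 / 18.)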

batched_gmr_metric = rank_based_metrics.RankBasedMetrics.empty()
batched_gmr_metric = batched_gmr_metric.merge(
rank_based_metrics.RankBasedMetrics.from_model_output(
label_logits=scores[:5], label=labels[:5]
)
)
batched_gmr_metric = batched_gmr_metric.merge(
rank_based_metrics.RankBasedMetrics.from_model_output(
label_logits=scores[5:], label=labels[5:]
)
)
batched_gmr_value = batched_gmr_metric.compute()["macro_gmr"]
self.assertEqual(batched_gmr_value, full_gmr_value)


if __name__ == "__main__":
absltest.main()
10 changes: 6 additions & 4 deletions chirp/train/classifier.py
@@ -21,9 +21,9 @@
from absl import logging
from chirp import export_utils
from chirp.data import pipeline
from chirp.models import cmap
from chirp.models import metrics
from chirp.models import output
from chirp.models import rank_based_metrics
from chirp.models import taxonomy_model
from chirp.taxonomy import class_utils
from chirp.train import utils
@@ -277,12 +277,14 @@ def evaluate(
if taxonomy_loss_weight != 0.0:
taxonomy_keys += utils.TAXONOMY_KEYS

# The metrics are the same as for training, but with CmAP added
# The metrics are the same as for training, but with rank-based metrics added.
base_metrics_collection = make_metrics_collection(
name, taxonomy_keys, model_bundle.model.num_classes
)
valid_metrics_collection = cmap.add_cmap_to_metrics_collection(
name, base_metrics_collection
valid_metrics_collection = (
rank_based_metrics.add_rank_based_metrics_to_metrics_collection(
name, base_metrics_collection
)
)

@functools.partial(jax.pmap, axis_name="batch")
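For context, a rough sketch (not part of the commit) of what the wrapper call above produces; BaseCollection and its loss field are hypothetical stand-ins for the collection built by make_metrics_collection.

import flax.struct
from clu import metrics as clu_metrics
from chirp.models import rank_based_metrics

@flax.struct.dataclass
class BaseCollection(clu_metrics.Collection):  # hypothetical stand-in
  loss: clu_metrics.Average.from_output("loss")

# Returns a new Collection with an extra "valid_rank_based" RankBasedMetrics
# field, so the existing gather/compute plumbing keeps working unchanged.
ValidCollection = (
    rank_based_metrics.add_rank_based_metrics_to_metrics_collection(
        "valid", BaseCollection
    )
)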
6 changes: 4 additions & 2 deletions chirp/train/hubert.py
@@ -23,7 +23,7 @@

from absl import logging
from chirp.data import pipeline
from chirp.models import cmap
from chirp.models import rank_based_metrics
from chirp.models import frontend as frontend_models
from chirp.models import hubert
from chirp.models import layers
@@ -912,7 +912,9 @@ def evaluate(
(base_metrics_collection,),
{
"__annotations__": {
f"{name}_cmap": cmap.CMAP,
f"{name}_rank_based_metrics": (
rank_based_metrics.RankBasedMetrics
),
**base_metrics_collection.__annotations__,
}
},
16 changes: 5 additions & 11 deletions chirp/train/separator.py
@@ -22,9 +22,9 @@
from absl import logging
from chirp import export_utils
from chirp.data import pipeline
from chirp.models import cmap
from chirp.models import metrics
from chirp.models import output
from chirp.models import rank_based_metrics
from chirp.models import separation_model
from chirp.taxonomy import class_utils
from chirp.train import utils
@@ -82,14 +82,6 @@ def p_log_sisnr_loss(
)


@flax.struct.dataclass
class ValidationMetrics(clu_metrics.Collection):
valid_loss: clu_metrics.Average.from_fun(p_log_snr_loss)
valid_mixit_log_mse: clu_metrics.Average.from_fun(p_log_mse_loss)
valid_mixit_neg_snr: clu_metrics.Average.from_fun(p_log_snr_loss)
valid_cmap: cmap.CMAP


def keyed_cross_entropy(
key: str,
outputs: separation_model.SeparatorOutput,
@@ -288,8 +280,10 @@ def evaluate(
):
"""Run evaluation."""
base_metrics_collection = make_metrics_collection('valid__')
valid_metrics_collection = cmap.add_cmap_to_metrics_collection(
'valid', base_metrics_collection
valid_metrics_collection = (
rank_based_metrics.add_rank_based_metrics_to_metrics_collection(
'valid', base_metrics_collection
)
)

@functools.partial(jax.pmap, axis_name='batch')
