diff --git a/src/copairs/compute.py b/src/copairs/compute.py
index 8bc7dce..7e977f4 100644
--- a/src/copairs/compute.py
+++ b/src/copairs/compute.py
@@ -41,7 +41,6 @@ def par_func(i):
     return batched_fn
 
 
-@batch_processing
 def pairwise_corr(x_sample: np.ndarray, y_sample: np.ndarray) -> np.ndarray:
     """
     Compute pearson correlation between two matrices in a paired row-wise
@@ -62,7 +61,6 @@ def pairwise_corr(x_sample: np.ndarray, y_sample: np.ndarray) -> np.ndarray:
     return corrs
 
 
-@batch_processing
 def pairwise_cosine(x_sample: np.ndarray, y_sample: np.ndarray) -> np.ndarray:
     x_norm = x_sample / np.linalg.norm(x_sample, axis=1)[:, np.newaxis]
     y_norm = y_sample / np.linalg.norm(y_sample, axis=1)[:, np.newaxis]
@@ -70,10 +68,33 @@ def pairwise_cosine(x_sample: np.ndarray, y_sample: np.ndarray) -> np.ndarray:
     return c_sim
 
 
-@batch_processing
+def pairwise_abs_cosine(x_sample: np.ndarray, y_sample: np.ndarray) -> np.ndarray:
+    return np.abs(pairwise_cosine(x_sample, y_sample))
+
+
 def pairwise_euclidean(x_sample: np.ndarray, y_sample: np.ndarray) -> np.ndarray:
     e_dist = np.sqrt(np.sum((x_sample - y_sample) ** 2, axis=1))
-    return 1 - e_dist
+    return 1 / (1 + e_dist)
+
+
+def get_distance_fn(distance):
+    distance_metrics = {
+        "abs_cosine": pairwise_abs_cosine,
+        "cosine": pairwise_cosine,
+        "correlation": pairwise_corr,
+        "euclidean": pairwise_euclidean,
+    }
+
+    if isinstance(distance, str):
+        if distance not in distance_metrics:
+            raise ValueError(f"Unsupported distance metric: {distance}. Supported metrics are: {list(distance_metrics.keys())}")
+        distance_fn = distance_metrics[distance]
+    elif callable(distance):
+        distance_fn = distance
+    else:
+        raise ValueError("Distance must be either a string or a callable object.")
+
+    return batch_processing(distance_fn)
 
 
 def random_binary_matrix(n, m, k, rng):
diff --git a/src/copairs/map/average_precision.py b/src/copairs/map/average_precision.py
index fc886e8..7084134 100644
--- a/src/copairs/map/average_precision.py
+++ b/src/copairs/map/average_precision.py
@@ -28,11 +28,12 @@ def build_rank_lists(pos_pairs, neg_pairs, pos_sims, neg_sims):
 
 
 def average_precision(
-    meta, feats, pos_sameby, pos_diffby, neg_sameby, neg_diffby, batch_size=20000
+    meta, feats, pos_sameby, pos_diffby, neg_sameby, neg_diffby, batch_size=20000, distance="cosine"
 ) -> pd.DataFrame:
     columns = flatten_str_list(pos_sameby, pos_diffby, neg_sameby, neg_diffby)
     meta, columns = evaluate_and_filter(meta, columns)
     validate_pipeline_input(meta, feats, columns)
+    distance_fn = compute.get_distance_fn(distance)
 
     # Critical!, otherwise the indexing wont work
     meta = meta.reset_index(drop=True).copy()
@@ -62,10 +63,10 @@ def average_precision(
     )
 
     logger.info("Computing positive similarities...")
-    pos_sims = compute.pairwise_cosine(feats, pos_pairs, batch_size)
+    pos_sims = distance_fn(feats, pos_pairs, batch_size)
 
     logger.info("Computing negative similarities...")
-    neg_sims = compute.pairwise_cosine(feats, neg_pairs, batch_size)
+    neg_sims = distance_fn(feats, neg_pairs, batch_size)
 
     logger.info("Building rank lists...")
     paired_ix, rel_k_list, counts = build_rank_lists(
diff --git a/src/copairs/map/multilabel.py b/src/copairs/map/multilabel.py
index 25e5b9a..da8efde 100644
--- a/src/copairs/map/multilabel.py
+++ b/src/copairs/map/multilabel.py
@@ -74,10 +74,12 @@ def average_precision(
     neg_diffby,
     multilabel_col,
     batch_size=20000,
+    distance="cosine",
 ) -> pd.DataFrame:
     columns = flatten_str_list(pos_sameby, pos_diffby, neg_sameby, neg_diffby)
     meta, columns = evaluate_and_filter(meta, columns)
     validate_pipeline_input(meta, feats, columns)
+    distance_fn = compute.get_distance_fn(distance)
 
     # Critical!, otherwise the indexing wont work
     meta = meta.reset_index(drop=True).copy()
@@ -114,10 +116,10 @@ def average_precision(
     neg_pairs = np.unique(neg_pairs, axis=0)
 
     logger.info("Computing positive similarities...")
-    pos_sims = compute.pairwise_cosine(feats, pos_pairs, batch_size)
+    pos_sims = distance_fn(feats, pos_pairs, batch_size)
 
     logger.info("Computing negative similarities...")
-    neg_sims = compute.pairwise_cosine(feats, neg_pairs, batch_size)
+    neg_sims = distance_fn(feats, neg_pairs, batch_size)
 
     logger.info("Computing AP per label...")
     negs_for = create_neg_query_solver(neg_pairs, neg_sims)
diff --git a/src/copairs/replicating.py b/src/copairs/replicating.py
index c1311e0..1957b40 100644
--- a/src/copairs/replicating.py
+++ b/src/copairs/replicating.py
@@ -4,14 +4,15 @@
 import numpy as np
 import pandas as pd
 
-from copairs.compute import pairwise_corr
+from copairs.compute import get_distance_fn
 from .matching import Matcher
 
 
 def corr_from_null_pairs(X: np.ndarray, null_pairs, n_replicates):
     """Correlation from a given list of unnamed pairs."""
     null_pairs = np.asarray(null_pairs, int)
-    corrs = pairwise_corr(X, null_pairs, batch_size=20000)
+    corr_fn = get_distance_fn("correlation")
+    corrs = corr_fn(X, null_pairs, batch_size=20000)
     corrs = corrs.reshape(-1, n_replicates)
     null_dist = np.nanmedian(corrs, axis=1)
     return pd.Series(null_dist)
@@ -56,7 +57,8 @@ def corr_from_pairs(X: np.ndarray, pairs: dict, sameby: List[str]):
     list-like of correlation values and median of number of replicates
     """
     pair_ix = np.vstack(list(pairs.values()))
-    corrs = pairwise_corr(X, pair_ix, batch_size=20000)
+    corr_fn = get_distance_fn("correlation")
+    corrs = corr_fn(X, pair_ix, batch_size=20000)
     counts = [len(v) for v in pairs.values()]
 
     if len(sameby) == 1:
diff --git a/tests/test_compute.py b/tests/test_compute.py
index 696d965..ca18cf1 100644
--- a/tests/test_compute.py
+++ b/tests/test_compute.py
@@ -24,6 +24,18 @@ def cosine_naive(feats, pairs):
     return cosine
 
 
+def euclidean_naive(feats, pairs):
+    euclidean_sim = np.empty((len(pairs),))
+    for pos, (i, j) in enumerate(pairs):
+        dist = np.linalg.norm(feats[i] - feats[j])
+        euclidean_sim[pos] = 1 / (1 + dist)
+    return euclidean_sim
+
+
+def abs_cosine_naive(feats, pairs):
+    return np.abs(cosine_naive(feats, pairs))
+
+
 def test_corrcoef():
     n_samples = 10
     n_pairs = 20
@@ -33,7 +45,8 @@ def test_corrcoef():
     pairs = rng.integers(0, n_samples - 1, [n_pairs, 2])
 
     corr_gt = corrcoef_naive(feats, pairs)
-    corr = compute.pairwise_corr(feats, pairs, batch_size)
+    corr_fn = compute.get_distance_fn("correlation")
+    corr = corr_fn(feats, pairs, batch_size)
     assert np.allclose(corr_gt, corr)
 
 
@@ -46,5 +59,35 @@ def test_cosine():
     pairs = rng.integers(0, n_samples - 1, [n_pairs, 2])
 
     cosine_gt = cosine_naive(feats, pairs)
-    cosine = compute.pairwise_cosine(feats, pairs, batch_size)
+    cosine_fn = compute.get_distance_fn("cosine")
+    cosine = cosine_fn(feats, pairs, batch_size)
     assert np.allclose(cosine_gt, cosine)
+
+
+def test_euclidean():
+    n_samples = 10
+    n_pairs = 20
+    n_feats = 5
+    batch_size = 4
+    feats = rng.uniform(0, 1, [n_samples, n_feats])
+    pairs = rng.integers(0, n_samples - 1, [n_pairs, 2])
+
+    euclidean_gt = euclidean_naive(feats, pairs)
+    euclidean_fn = compute.get_distance_fn("euclidean")
+    euclidean = euclidean_fn(feats, pairs, batch_size)
+    assert np.allclose(euclidean_gt, euclidean)
+
+
+def test_abs_cosine():
+    n_samples = 10
+    n_pairs = 20
+    n_feats = 5
+    batch_size = 4
+    feats = rng.uniform(0, 1, [n_samples, n_feats])
+    pairs = rng.integers(0, n_samples - 1, [n_pairs, 2])
+
+    abs_cosine_gt = abs_cosine_naive(feats, pairs)
+    abs_cosine_fn = compute.get_distance_fn("abs_cosine")
+    abs_cosine = abs_cosine_fn(feats, pairs, batch_size)
+    assert np.allclose(abs_cosine_gt, abs_cosine)
+
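Usage sketch for the new `distance` option (not part of the patch; the toy arrays, the `neg_manhattan` callable, and all variable names below are illustrative only). A string selects one of the built-in metrics, while any callable that maps two `(n, d)` sample batches to `n` similarity scores is wrapped by `batch_processing`; the same string or callable can also be passed to the mAP pipelines via the new `distance=` parameter.

```python
import numpy as np

from copairs import compute

# Toy inputs: 4 profiles with 3 features, and two index pairs to compare.
feats = np.arange(12, dtype=float).reshape(4, 3)
pairs = np.array([[0, 1], [2, 3]])

# Built-in metric selected by name; the returned function is batched.
cosine_fn = compute.get_distance_fn("cosine")
sims = cosine_fn(feats, pairs, batch_size=2)

# Custom callable: two (n, d) batches in, n similarity scores out.
def neg_manhattan(x_sample, y_sample):
    return -np.abs(x_sample - y_sample).sum(axis=1)

custom_fn = compute.get_distance_fn(neg_manhattan)
custom_sims = custom_fn(feats, pairs, batch_size=2)

# In the pipelines, e.g.:
# average_precision(meta, feats, pos_sameby, pos_diffby,
#                   neg_sameby, neg_diffby, distance="abs_cosine")
```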