Skip to content

Commit

Permalink
allow distance fn selection; add euclidean, abs_cosine
Browse files Browse the repository at this point in the history
  • Loading branch information
alxndrkalinin committed Sep 17, 2024
1 parent 2673e55 commit fac5985
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 14 deletions.
29 changes: 25 additions & 4 deletions src/copairs/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def par_func(i):
return batched_fn


@batch_processing
def pairwise_corr(x_sample: np.ndarray, y_sample: np.ndarray) -> np.ndarray:
"""
Compute pearson correlation between two matrices in a paired row-wise
Expand All @@ -62,18 +61,40 @@ def pairwise_corr(x_sample: np.ndarray, y_sample: np.ndarray) -> np.ndarray:
return corrs


@batch_processing
def pairwise_cosine(x_sample: np.ndarray, y_sample: np.ndarray) -> np.ndarray:
x_norm = x_sample / np.linalg.norm(x_sample, axis=1)[:, np.newaxis]
y_norm = y_sample / np.linalg.norm(y_sample, axis=1)[:, np.newaxis]
c_sim = np.sum(x_norm * y_norm, axis=1)
return c_sim


@batch_processing
def pairwise_abs_cosine(x_sample: np.ndarray, y_sample: np.ndarray) -> np.ndarray:
return np.abs(pairwise_cosine(x_sample, y_sample))


def pairwise_euclidean(x_sample: np.ndarray, y_sample: np.ndarray) -> np.ndarray:
e_dist = np.sqrt(np.sum((x_sample - y_sample) ** 2, axis=1))
return 1 - e_dist
return 1 / (1 + e_dist)


def get_distance_fn(distance):
distance_metrics = {
"abs_cosine": pairwise_abs_cosine,
"cosine": pairwise_cosine,
"correlation": pairwise_corr,
"euclidean": pairwise_euclidean,
}

if isinstance(distance, str):
if distance not in distance_metrics:
raise ValueError(f"Unsupported distance metric: {distance}. Supported metrics are: {list(distance_metrics.keys())}")
distance_fn = distance_metrics[distance]
elif callable(distance):
distance_fn = distance
else:
raise ValueError("Distance must be either a string or a callable object.")

return batch_processing(distance_fn)


def random_binary_matrix(n, m, k, rng):
Expand Down
7 changes: 4 additions & 3 deletions src/copairs/map/average_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@ def build_rank_lists(pos_pairs, neg_pairs, pos_sims, neg_sims):


def average_precision(
meta, feats, pos_sameby, pos_diffby, neg_sameby, neg_diffby, batch_size=20000
meta, feats, pos_sameby, pos_diffby, neg_sameby, neg_diffby, batch_size=20000, distance="cosine"
) -> pd.DataFrame:
columns = flatten_str_list(pos_sameby, pos_diffby, neg_sameby, neg_diffby)
meta, columns = evaluate_and_filter(meta, columns)
validate_pipeline_input(meta, feats, columns)
distance_fn = compute.get_distance_fn(distance)

# Critical!, otherwise the indexing wont work
meta = meta.reset_index(drop=True).copy()
Expand Down Expand Up @@ -62,10 +63,10 @@ def average_precision(
)

logger.info("Computing positive similarities...")
pos_sims = compute.pairwise_cosine(feats, pos_pairs, batch_size)
pos_sims = distance_fn(feats, pos_pairs, batch_size)

logger.info("Computing negative similarities...")
neg_sims = compute.pairwise_cosine(feats, neg_pairs, batch_size)
neg_sims = distance_fn(feats, neg_pairs, batch_size)

logger.info("Building rank lists...")
paired_ix, rel_k_list, counts = build_rank_lists(
Expand Down
6 changes: 4 additions & 2 deletions src/copairs/map/multilabel.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,12 @@ def average_precision(
neg_diffby,
multilabel_col,
batch_size=20000,
distance="cosine",
) -> pd.DataFrame:
columns = flatten_str_list(pos_sameby, pos_diffby, neg_sameby, neg_diffby)
meta, columns = evaluate_and_filter(meta, columns)
validate_pipeline_input(meta, feats, columns)
distance_fn = compute.get_distance_fn(distance)
# Critical!, otherwise the indexing wont work
meta = meta.reset_index(drop=True).copy()

Expand Down Expand Up @@ -114,10 +116,10 @@ def average_precision(
neg_pairs = np.unique(neg_pairs, axis=0)

logger.info("Computing positive similarities...")
pos_sims = compute.pairwise_cosine(feats, pos_pairs, batch_size)
pos_sims = distance_fn(feats, pos_pairs, batch_size)

logger.info("Computing negative similarities...")
neg_sims = compute.pairwise_cosine(feats, neg_pairs, batch_size)
neg_sims = distance_fn(feats, neg_pairs, batch_size)

logger.info("Computing AP per label...")
negs_for = create_neg_query_solver(neg_pairs, neg_sims)
Expand Down
8 changes: 5 additions & 3 deletions src/copairs/replicating.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
import numpy as np
import pandas as pd

from copairs.compute import pairwise_corr
from copairs.compute import get_distance_fn
from .matching import Matcher


def corr_from_null_pairs(X: np.ndarray, null_pairs, n_replicates):
"""Correlation from a given list of unnamed pairs."""
null_pairs = np.asarray(null_pairs, int)
corrs = pairwise_corr(X, null_pairs, batch_size=20000)
corr_fn = get_distance_fn("correlation")
corrs = corr_fn(X, null_pairs, batch_size=20000)
corrs = corrs.reshape(-1, n_replicates)
null_dist = np.nanmedian(corrs, axis=1)
return pd.Series(null_dist)
Expand Down Expand Up @@ -56,7 +57,8 @@ def corr_from_pairs(X: np.ndarray, pairs: dict, sameby: List[str]):
list-like of correlation values and median of number of replicates
"""
pair_ix = np.vstack(list(pairs.values()))
corrs = pairwise_corr(X, pair_ix, batch_size=20000)
corr_fn = get_distance_fn("correlation")
corrs = corr_fn(X, pair_ix, batch_size=20000)
counts = [len(v) for v in pairs.values()]

if len(sameby) == 1:
Expand Down
47 changes: 45 additions & 2 deletions tests/test_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,18 @@ def cosine_naive(feats, pairs):
return cosine


def euclidean_naive(feats, pairs):
euclidean_sim = np.empty((len(pairs),))
for pos, (i, j) in enumerate(pairs):
dist = np.linalg.norm(feats[i] - feats[j])
euclidean_sim[pos] = 1 / (1 + dist)
return euclidean_sim


def abs_cosine_naive(feats, pairs):
return np.abs(cosine_naive(feats, pairs))


def test_corrcoef():
n_samples = 10
n_pairs = 20
Expand All @@ -33,7 +45,8 @@ def test_corrcoef():
pairs = rng.integers(0, n_samples - 1, [n_pairs, 2])

corr_gt = corrcoef_naive(feats, pairs)
corr = compute.pairwise_corr(feats, pairs, batch_size)
corr_fn = compute.get_distance_fn("correlation")
corr = corr_fn(feats, pairs, batch_size)
assert np.allclose(corr_gt, corr)


Expand All @@ -46,5 +59,35 @@ def test_cosine():
pairs = rng.integers(0, n_samples - 1, [n_pairs, 2])

cosine_gt = cosine_naive(feats, pairs)
cosine = compute.pairwise_cosine(feats, pairs, batch_size)
cosine_fn = compute.get_distance_fn("cosine")
cosine = cosine_fn(feats, pairs, batch_size)
assert np.allclose(cosine_gt, cosine)


def test_euclidean():
n_samples = 10
n_pairs = 20
n_feats = 5
batch_size = 4
feats = rng.uniform(0, 1, [n_samples, n_feats])
pairs = rng.integers(0, n_samples - 1, [n_pairs, 2])

euclidean_gt = euclidean_naive(feats, pairs)
euclidean_fn = compute.get_distance_fn("euclidean")
euclidean = euclidean_fn(feats, pairs, batch_size)
assert np.allclose(euclidean_gt, euclidean)


def test_abs_cosine():
n_samples = 10
n_pairs = 20
n_feats = 5
batch_size = 4
feats = rng.uniform(0, 1, [n_samples, n_feats])
pairs = rng.integers(0, n_samples - 1, [n_pairs, 2])

abs_cosine_gt = abs_cosine_naive(feats, pairs)
abs_cosine_fn = compute.get_distance_fn("abs_cosine")
abs_cosine = abs_cosine_fn(feats, pairs, batch_size)
assert np.allclose(abs_cosine_gt, abs_cosine)

0 comments on commit fac5985

Please sign in to comment.