Skip to content

Commit

Permalink
Format using ruff
Browse files: browse the repository at this point in the history
  • Loading branch information
johnarevalo committed Oct 22, 2024
1 parent 0eb53b3 commit c3fd0c6
Show file tree
Hide file tree
Showing 16 changed files with 59 additions and 37 deletions.
3 changes: 3 additions & 0 deletions src/copairs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
"""
Package to create pairwise lists based on sameby and diffby criteria
"""

from .matching import Matcher, MatcherMultilabel

__all__ = ["Matcher", "MatcherMultilabel"]
6 changes: 4 additions & 2 deletions src/copairs/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,15 @@ def get_distance_fn(distance):

if isinstance(distance, str):
if distance not in distance_metrics:
raise ValueError(f"Unsupported distance metric: {distance}. Supported metrics are: {list(distance_metrics.keys())}")
raise ValueError(
f"Unsupported distance metric: {distance}. Supported metrics are: {list(distance_metrics.keys())}"
)
distance_fn = distance_metrics[distance]
elif callable(distance):
distance_fn = distance
else:
raise ValueError("Distance must be either a string or a callable object.")

return batch_processing(distance_fn)


Expand Down
4 changes: 3 additions & 1 deletion src/copairs/map/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from .map import mean_average_precision
from . import multilabel
from .average_precision import average_precision
from .map import mean_average_precision

__all__ = ["mean_average_precision", "multilabel", "average_precision"]
9 changes: 8 additions & 1 deletion src/copairs/map/average_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,14 @@ def build_rank_lists(pos_pairs, neg_pairs, pos_sims, neg_sims):


def average_precision(
meta, feats, pos_sameby, pos_diffby, neg_sameby, neg_diffby, batch_size=20000, distance="cosine"
meta,
feats,
pos_sameby,
pos_diffby,
neg_sameby,
neg_diffby,
batch_size=20000,
distance="cosine",
) -> pd.DataFrame:
columns = flatten_str_list(pos_sameby, pos_diffby, neg_sameby, neg_diffby)
meta, columns = evaluate_and_filter(meta, columns)
Expand Down
12 changes: 7 additions & 5 deletions src/copairs/map/filter.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Tuple, List
import itertools
import re
from typing import List, Tuple

import pandas as pd
import numpy as np
import pandas as pd


def validate_pipeline_input(meta, feats, columns):
Expand Down Expand Up @@ -45,12 +45,12 @@ def extract_filters(columns, df_columns) -> Tuple[List[str], List[str]]:
if col in df_columns:
parsed_cols.append(col)
continue
column_names = re.findall(r'(\w+)\s*[=<>!]+', col)
column_names = re.findall(r"(\w+)\s*[=<>!]+", col)

valid_column_names = [col for col in column_names if col in df_columns]
if not valid_column_names:
raise ValueError(f"Invalid query or column name: {col}")

queries_to_eval.append(col)
parsed_cols.extend(valid_column_names)

Expand All @@ -71,6 +71,8 @@ def apply_filters(df, query_list):
if df_filtered.empty:
raise ValueError(f"No data matched the query: {combined_query}")
except Exception as e:
raise ValueError(f"Invalid combined query expression: {combined_query}. Error: {e}")
raise ValueError(
f"Invalid combined query expression: {combined_query}. Error: {e}"
)

return df_filtered
4 changes: 1 addition & 3 deletions src/copairs/map/multilabel.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,7 @@ def average_precision(
meta = meta.reset_index(drop=True).copy()

logger.info("Indexing metadata...")
matcher = MatcherMultilabel(
meta, columns, multilabel_col=multilabel_col, seed=0
)
matcher = MatcherMultilabel(meta, columns, multilabel_col=multilabel_col, seed=0)

logger.info("Finding positive pairs...")
pos_pairs = matcher.get_all_pairs(sameby=pos_sameby, diffby=pos_diffby)
Expand Down
10 changes: 7 additions & 3 deletions src/copairs/matching.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""
Sample pairs with given column restrictions
"""
from collections import namedtuple

import itertools
import logging
from math import comb
import re
from collections import namedtuple
from math import comb
from typing import Dict, Sequence, Set, Union

import numpy as np
Expand Down Expand Up @@ -442,5 +443,8 @@ def _only_diffby_multi(self):
pairs = itertools.chain.from_iterable(pairs.values())
pairs = set(map(frozenset, pairs))
all_pairs = itertools.combinations(range(self.size), 2)
filter_fn = lambda x: set(x) not in pairs

def filter_fn(x):
return set(x) not in pairs

return {None: list(filter(filter_fn, all_pairs))}
1 change: 1 addition & 0 deletions src/copairs/plot.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Optional

from plotly import graph_objects as go
from plotly.subplots import make_subplots

Expand Down
2 changes: 2 additions & 0 deletions src/copairs/replicating.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Class for getting Percent replicating metric"""

from typing import List, Literal

import numpy as np
import pandas as pd

from copairs.compute import get_distance_fn

from .matching import Matcher


Expand Down
2 changes: 1 addition & 1 deletion tests/helpers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from itertools import product
from typing import Dict

import pandas as pd
import numpy as np
import pandas as pd

from copairs.matching import ColumnList

Expand Down
2 changes: 0 additions & 2 deletions tests/test_compute.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pytest
import numpy as np

from copairs import compute
Expand Down Expand Up @@ -90,4 +89,3 @@ def test_abs_cosine():
abs_cosine_fn = compute.get_distance_fn("abs_cosine")
abs_cosine = abs_cosine_fn(feats, pairs, batch_size)
assert np.allclose(abs_cosine_gt, abs_cosine)

26 changes: 16 additions & 10 deletions tests/test_map.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import pandas as pd
import pytest
from sklearn.metrics import average_precision_score
import numpy as np

from copairs import compute
from copairs.map import average_precision
Expand Down Expand Up @@ -140,8 +140,8 @@ def test_raise_no_pairs():
average_precision(meta, feats, pos_sameby, pos_diffby, neg_sameby, neg_diffby)
with pytest.raises(UnpairedException, match="Unable to find negative pairs."):
average_precision(meta, feats, pos_diffby, [], pos_sameby, [])


def test_raise_nan_error():
length = 10
vocab_size = {"p": 5, "w": 3, "l": 4}
Expand All @@ -154,14 +154,20 @@ def test_raise_nan_error():
meta = simulate_random_dframe(length, vocab_size, pos_sameby, pos_diffby, rng)
length = len(meta)
feats = rng.uniform(size=(length, n_feats))

# add null values
feats_nan = feats.copy()
feats_nan[2,2] = None
feats_nan[2, 2] = None
meta_nan = meta.copy()
meta_nan.loc[1,"p"] = None
meta_nan.loc[1, "p"] = None

with pytest.raises(ValueError, match="features should not have null values."):
average_precision(meta, feats_nan, pos_sameby, pos_diffby, neg_sameby, neg_diffby)
with pytest.raises(ValueError, match="metadata columns should not have null values."):
average_precision(meta_nan, feats, pos_sameby, pos_diffby, neg_sameby, neg_diffby)
average_precision(
meta, feats_nan, pos_sameby, pos_diffby, neg_sameby, neg_diffby
)
with pytest.raises(
ValueError, match="metadata columns should not have null values."
):
average_precision(
meta_nan, feats, pos_sameby, pos_diffby, neg_sameby, neg_diffby
)
10 changes: 3 additions & 7 deletions tests/test_map_filter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import pytest
import pandas as pd
import numpy as np
import pytest

from copairs.map.filter import evaluate_and_filter
from tests.helpers import simulate_random_dframe
Expand All @@ -23,8 +22,8 @@ def mock_dataframe():
def test_correct(mock_dataframe):
df, parsed_cols = evaluate_and_filter(mock_dataframe, ["p == 'p1'", "w > 'w2'"])
assert not df.empty
assert 'p' in parsed_cols and 'w' in parsed_cols
assert all(df['w'].str.extract(r'(\d+)')[0].astype(int) > 2)
assert "p" in parsed_cols and "w" in parsed_cols
assert all(df["w"].str.extract(r"(\d+)")[0].astype(int) > 2)


def test_invalid_query(mock_dataframe):
Expand All @@ -44,6 +43,3 @@ def test_empty_result_from_valid_query(mock_dataframe):
with pytest.raises(ValueError) as excinfo:
evaluate_and_filter(mock_dataframe, ['p == "p4"'])
assert "No data matched the query" in str(excinfo.value)



1 change: 1 addition & 0 deletions tests/test_matching.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Test functions for Matcher"""

from string import ascii_letters

import numpy as np
Expand Down
1 change: 1 addition & 0 deletions tests/test_matching_multilabel.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pandas as pd

from copairs.matching import MatcherMultilabel
from tests.helpers import simulate_random_plates

Expand Down
3 changes: 1 addition & 2 deletions tests/test_replicating.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
from copairs import Matcher
from copairs.replicating import (
corr_between_replicates,
correlation_test,
corr_from_pairs,
correlation_test,
)

from tests.helpers import create_dframe

SEED = 0
Expand Down

0 comments on commit c3fd0c6

Please sign in to comment.