Skip to content

Commit

Permalink
Fix #50
Browse files Browse the repository at this point in the history
  • Loading branch information
juanbc committed Aug 30, 2024
1 parent 4bad2bf commit 5d67b62
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 2 deletions.
43 changes: 41 additions & 2 deletions skcriteria/cmp/ranks_cmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,13 @@ class RanksComparator(Sequence, DiffEqualityMixin):

def __init__(self, ranks):
ranks = list(ranks)
self._validate_ranks(ranks)
self._ranks = ranks
self._validate_ranks()

# INTERNALS ===============================================================
def _validate_ranks(self, ranks):
def _validate_ranks(self):
ranks = self._ranks

if len(ranks) <= 1:
raise ValueError("Please provide more than one ranking")

Expand Down Expand Up @@ -228,6 +230,8 @@ def to_dataframe(self, *, untied=False):

return df

# STATISTICALS ============================================================

def corr(self, *, untied=False, **kwargs):
"""Compute pairwise correlation of rankings, excluding NA/null values.
Expand Down Expand Up @@ -371,6 +375,41 @@ def distance(self, *, untied=False, metric="hamming", **kwargs):
)
return dis_df

def extra_get(self, key, default=None):
    """Look up ``key`` in the ``extra_`` of every rank.

    Every rank held by the comparator is asked for
    ``rank.extra_.get(key, default)``. A rank whose ``extra_`` lacks the
    key is not skipped: it contributes ``default`` to the result.

    Parameters
    ----------
    key : hashable
        The key to look up in each rank's ``extra_``.
    default : any, optional
        The value to use when the key is not found in a rank.
        Defaults to None.

    Returns
    -------
    dict
        A dictionary where each key is the name of a rank, and each value
        is the result of calling ``get(key, default)`` on that rank's
        ``extra_``.

    Notes
    -----
    The returned dictionary has an entry for every rank name, even when
    the key was not found and the default value was used.

    """
    # ``self._ranks`` is iterated as ``(name, rank)`` pairs.
    return {
        rank_name: rank.extra_.get(key, default)
        for rank_name, rank in self._ranks
    }

eget = extra_get  # shortcut

# ACCESSORS (YES, WE USE CACHED PROPERTIES; IT IS THE EASIEST WAY) ============

@methodtools.lru_cache(maxsize=None)
Expand Down
8 changes: 8 additions & 0 deletions skcriteria/utils/bunch.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,11 @@ def __repr__(self):
def __dir__(self):
    """x.__dir__() <==> dir(x)."""
    # Append whatever iterating ``self._data`` yields (its keys, when it
    # is a mapping) to the default attribute listing from the parent.
    return super().__dir__() + list(self._data)

def __setstate__(self, state):
    """Restore the instance attributes from ``state``.

    Needed for multiprocessing environment.

    Parameters
    ----------
    state : dict
        Mapping of attribute names to values; it is merged into the
        instance ``__dict__`` as-is.

    """
    # Writing through ``__dict__`` does not go through attribute
    # assignment on the instance.
    self.__dict__.update(state)

def get(self, key, default=None):
    """Get item from bunch.

    Parameters
    ----------
    key : hashable
        The key to look up in the underlying ``_data``.
    default : any, optional
        The value returned when ``key`` is not present.
        Defaults to None.

    Returns
    -------
    any
        The result of ``self._data.get(key, default)``.

    """
    return self._data.get(key, default)
18 changes: 18 additions & 0 deletions tests/cmp/test_ranks_cmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,24 @@ def test_RanksComparator_hash():
assert id(rcmp) == hash(rcmp)


def test_RanksComparator_extra_get():
    """``extra_get`` returns one value per rank, defaulting when absent."""
    rank0 = agg.RankResult(
        "test", ["a", "b"], [1, 1], {"alpha": 1, "bravo": 2}
    )
    rank1 = agg.RankResult(
        "test", ["a", "b"], [1, 1], {"alpha": 1, "delta": 3}
    )
    rcmp = ranks_cmp.mkrank_cmp(rank0, rank1)

    # Key present in both ranks.
    assert rcmp.extra_get("alpha") == {"test_1": 1, "test_2": 1}
    # Key present in only one rank: the implicit default (None) fills in.
    assert rcmp.extra_get("bravo") == {"test_1": 2, "test_2": None}
    # Key present in only one rank, with an explicit default.
    assert rcmp.extra_get("delta", "foo") == {"test_1": "foo", "test_2": 3}
    # Key present in no rank: every entry is the default.
    assert rcmp.extra_get("charly", "foo") == {
        "test_1": "foo",
        "test_2": "foo",
    }


def test_RanksComparator_plot():
rank0 = agg.RankResult("test", ["a", "b"], [1, 1], {})
rank1 = agg.RankResult("test", ["a", "b"], [1, 1], {})
Expand Down
17 changes: 17 additions & 0 deletions tests/utils/test_bunch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# =============================================================================

import copy
import pickle

import pytest

Expand Down Expand Up @@ -97,3 +98,19 @@ def test_Bunch_assign_fails():
foo_bunch = bunch.Bunch("foo", {})
with pytest.raises(AttributeError, match="Bunch 'foo' is read-only"):
foo_bunch.some_key = 1


def test_Bunch_setstate():
    """A pickle round trip yields an equal but distinct Bunch."""
    md = bunch.Bunch("foo", {"alfa": 1})
    md_c = pickle.loads(pickle.dumps(md))

    assert md is not md_c
    # Strings are immutable, so only equality (not identity) is checked.
    assert md._name == md_c._name
    # Checked separately so a failure says whether content or identity
    # is at fault.
    assert md._data == md_c._data
    assert md._data is not md_c._data


def test_Bunch_get():
    """``Bunch.get`` returns the stored value or the given default."""
    md = bunch.Bunch("foo", {"alfa": 1})
    assert md.get("alfa") == 1
    # Missing key: the implicit default is None.
    assert md.get("bravo") is None
    # Missing key with an explicit default.
    assert md.get("bravo", 2) == 2

0 comments on commit 5d67b62

Please sign in to comment.