diff --git a/skcriteria/cmp/ranks_cmp.py b/skcriteria/cmp/ranks_cmp.py index 53a86fa..0a9cc4e 100644 --- a/skcriteria/cmp/ranks_cmp.py +++ b/skcriteria/cmp/ranks_cmp.py @@ -88,11 +88,13 @@ class RanksComparator(Sequence, DiffEqualityMixin): def __init__(self, ranks): ranks = list(ranks) - self._validate_ranks(ranks) self._ranks = ranks + self._validate_ranks() # INTERNALS =============================================================== - def _validate_ranks(self, ranks): + def _validate_ranks(self): + ranks = self._ranks + if len(ranks) <= 1: raise ValueError("Please provide more than one ranking") @@ -228,6 +230,8 @@ def to_dataframe(self, *, untied=False): return df + # STATISTICALS ============================================================ + def corr(self, *, untied=False, **kwargs): """Compute pairwise correlation of rankings, excluding NA/null values. @@ -371,6 +375,41 @@ def distance(self, *, untied=False, metric="hamming", **kwargs): ) return dis_df + def extra_get(self, key, default=None): + """Retrieve a specific key from each rank, returning a \ + dictionary of results. + + This method iterates through all ranks and attempts to get the value + associated with the specified key. If the key is not found in a rank, + the default value is used. + + Parameters + ---------- + key : hashable + The key to look up in each rank. + default : any, optional + The value to return if the key is not found in a rank. + Defaults to None. + + Returns + ------- + dict + A dictionary where each key is the name of a rank, and each value + is the result of calling `get(key, default)` on that rank. + + Notes + ----- + The returned dictionary will have an entry for every rank, even if the + key was not found and the default value was used. + + """ + return { + rank_name: rank.extra_.get(key, default) + for rank_name, rank in self._ranks + } + + eget = extra_get # shortcut + # ACCESSORS (YES, WE USE CACHED PROPERTIES IS THE EASIEST WAY) ============ @methodtools.lru_cache(maxsize=None) diff --git a/skcriteria/utils/bunch.py b/skcriteria/utils/bunch.py index f9aedd9..9ff444b 100644 --- a/skcriteria/utils/bunch.py +++ b/skcriteria/utils/bunch.py @@ -108,3 +108,11 @@ def __repr__(self): def __dir__(self): """x.__dir__() <==> dir(x).""" return super().__dir__() + list(self._data) + + def __setstate__(self, state): + """Needed for multiprocessing environment.""" + self.__dict__.update(state) + + def get(self, key, default=None): + """Get item from bunch.""" + return self._data.get(key, default) diff --git a/tests/cmp/test_ranks_cmp.py b/tests/cmp/test_ranks_cmp.py index 5766a74..2d90668 100644 --- a/tests/cmp/test_ranks_cmp.py +++ b/tests/cmp/test_ranks_cmp.py @@ -271,6 +271,24 @@ def test_RanksComparator_hash(): assert id(rcmp) == hash(rcmp) +def test_RanksComparator_extra_get(): + rank0 = agg.RankResult( + "test", ["a", "b"], [1, 1], {"alpha": 1, "bravo": 2} + ) + rank1 = agg.RankResult( + "test", ["a", "b"], [1, 1], {"alpha": 1, "delta": 3} + ) + rcmp = ranks_cmp.mkrank_cmp(rank0, rank1) + + assert rcmp.extra_get("alpha") == {"test_1": 1, "test_2": 1} + assert rcmp.extra_get("bravo") == {"test_1": 2, "test_2": None} + assert rcmp.extra_get("delta", "foo") == {"test_1": "foo", "test_2": 3} + assert rcmp.extra_get("charly", "foo") == { + "test_1": "foo", + "test_2": "foo", + } + + def test_RanksComparator_plot(): rank0 = agg.RankResult("test", ["a", "b"], [1, 1], {}) rank1 = agg.RankResult("test", ["a", "b"], [1, 1], {}) diff --git a/tests/utils/test_bunch.py b/tests/utils/test_bunch.py index a5fa8cf..563452f 100644 --- a/tests/utils/test_bunch.py +++ b/tests/utils/test_bunch.py @@ -19,6 +19,7 @@ # ============================================================================= import copy +import pickle import pytest @@ -97,3 +98,19 @@ def test_Bunch_assign_fails(): foo_bunch = bunch.Bunch("foo", {}) with pytest.raises(AttributeError, match="Bunch 'foo' is read-only"): foo_bunch.some_key = 1 + + +def test_Bunch_setstate(): + md = bunch.Bunch("foo", {"alfa": 1}) + md_c = pickle.loads(pickle.dumps(md)) + + assert md is not md_c + assert md._name == md_c._name # string are inmutable never deep copy + assert md._data == md_c._data and md._data is not md_c._data + + +def test_Bunch_get(): + md = bunch.Bunch("foo", {"alfa": 1}) + assert md.get("alfa") == 1 + assert md.get("bravo") is None + assert md.get("bravo", 2) == 2