core: cleanup itertools-style helpers
- deprecate group_by_key; use more_itertools.bucket instead (see the sketch below)
- move make_dict and ensure_unique to my.core.utils.itertools
karlicoss committed Aug 13, 2024
1 parent 918e8ee commit 029fa3a
Showing 7 changed files with 119 additions and 93 deletions.
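
For reference, a minimal standalone sketch (not part of the diff) of the migration suggested above: more_itertools.bucket replaces the deprecated group_by_key. Iterating the bucket object yields its keys; indexing it yields each group.

    from more_itertools import bucket

    items = ['apple', 'avocado', 'banana', 'cherry']
    b = bucket(items, key=lambda s: s[0])  # group by first letter

    # materialize into the Dict[K, List[T]] shape that group_by_key used to return
    grouped = {k: list(b[k]) for k in b}
    assert grouped == {'a': ['apple', 'avocado'], 'b': ['banana'], 'c': ['cherry']}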
94 changes: 16 additions & 78 deletions my/core/common.py
@@ -65,84 +65,6 @@ def import_dir(path: PathIsh, extra: str='') -> types.ModuleType:
     return import_from(p.parent, p.name + extra)
 
 
-T = TypeVar('T')
-K = TypeVar('K')
-V = TypeVar('V')
-
-# TODO more_itertools.bucket?
-def group_by_key(l: Iterable[T], key: Callable[[T], K]) -> Dict[K, List[T]]:
-    res: Dict[K, List[T]] = {}
-    for i in l:
-        kk = key(i)
-        lst = res.get(kk, [])
-        lst.append(i)
-        res[kk] = lst
-    return res
-
-
-def _identity(v: T) -> V: # type: ignore[type-var]
-    return cast(V, v)
-
-
-# ugh. nothing in more_itertools?
-def ensure_unique(
-    it: Iterable[T],
-    *,
-    key: Callable[[T], K],
-    value: Callable[[T], V]=_identity,
-    key2value: Optional[Dict[K, V]]=None
-) -> Iterable[T]:
-    if key2value is None:
-        key2value = {}
-    for i in it:
-        k = key(i)
-        v = value(i)
-        pv = key2value.get(k, None)
-        if pv is not None:
-            raise RuntimeError(f"Duplicate key: {k}. Previous value: {pv}, new value: {v}")
-        key2value[k] = v
-        yield i
-
-
-def test_ensure_unique() -> None:
-    import pytest
-    assert list(ensure_unique([1, 2, 3], key=lambda i: i)) == [1, 2, 3]
-
-    dups = [1, 2, 1, 4]
-    # this works because it's lazy
-    it = ensure_unique(dups, key=lambda i: i)
-
-    # but forcing throws
-    with pytest.raises(RuntimeError, match='Duplicate key'):
-        list(it)
-
-    # hacky way to force distinct objects?
-    list(ensure_unique(dups, key=lambda i: object()))
-
-
-def make_dict(
-    it: Iterable[T],
-    *,
-    key: Callable[[T], K],
-    value: Callable[[T], V]=_identity
-) -> Dict[K, V]:
-    res: Dict[K, V] = {}
-    uniques = ensure_unique(it, key=key, value=value, key2value=res)
-    for _ in uniques:
-        pass # force the iterator
-    return res
-
-
-def test_make_dict() -> None:
-    it = range(5)
-    d = make_dict(it, key=lambda i: i, value=lambda i: i % 2)
-    assert d == {0: 0, 1: 1, 2: 0, 3: 1, 4: 0}
-
-    # check type inference
-    d2: Dict[str, int ] = make_dict(it, key=lambda i: str(i))
-    d3: Dict[str, bool] = make_dict(it, key=lambda i: str(i), value=lambda i: i % 2 == 0)
-
-
 # https://stackoverflow.com/a/12377059/706389
 def listify(fn=None, wrapper=list):
     """
@@ -696,6 +618,22 @@ def cproperty(*args, **kwargs):
 
     return functools.cached_property(*args, **kwargs)
 
+@deprecated('use more_itertools.bucket instead')
+def group_by_key(l, key):
+    res = {}
+    for i in l:
+        kk = key(i)
+        lst = res.get(kk, [])
+        lst.append(i)
+        res[kk] = lst
+    return res
+
+@deprecated('use my.core.utils.itertools.make_dict instead')
+def make_dict(*args, **kwargs):
+    from .utils import itertools as UI
+
+    return UI.make_dict(*args, **kwargs)
+
 # todo wrap these in deprecated decorator as well?
 from .cachew import mcachew # noqa: F401
 
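The deprecated decorator used above comes from elsewhere in common.py and isn't shown in this hunk. As a rough, hypothetical sketch of what such a shim typically does, assuming a plain warnings-based implementation:

    import functools
    import warnings

    def deprecated(msg: str):
        # hypothetical stand-in for the real decorator, which this diff doesn't show
        def decorator(fn):
            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                warnings.warn(f'{fn.__name__} is deprecated: {msg}', DeprecationWarning, stacklevel=2)
                return fn(*args, **kwargs)
            return wrapper
        return decorator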
77 changes: 77 additions & 0 deletions my/core/utils/itertools.py
@@ -0,0 +1,77 @@
+"""
+Various helpers/transforms of iterators
+Ideally this should be as small as possible and we should rely on stdlib itertools or more_itertools
+"""
+
+from typing import Callable, Dict, Iterable, TypeVar, cast
+
+
+T = TypeVar('T')
+K = TypeVar('K')
+V = TypeVar('V')
+
+
+def _identity(v: T) -> V: # type: ignore[type-var]
+    return cast(V, v)
+
+
+# ugh. nothing in more_itertools?
+# perhaps duplicates_everseen? but it doesn't yield non-unique elements?
+def ensure_unique(it: Iterable[T], *, key: Callable[[T], K]) -> Iterable[T]:
+    key2item: Dict[K, T] = {}
+    for i in it:
+        k = key(i)
+        pi = key2item.get(k, None)
+        if pi is not None:
+            raise RuntimeError(f"Duplicate key: {k}. Previous value: {pi}, new value: {i}")
+        key2item[k] = i
+        yield i
+
+
+def test_ensure_unique() -> None:
+    import pytest
+
+    assert list(ensure_unique([1, 2, 3], key=lambda i: i)) == [1, 2, 3]
+
+    dups = [1, 2, 1, 4]
+    # this works because it's lazy
+    it = ensure_unique(dups, key=lambda i: i)
+
+    # but forcing throws
+    with pytest.raises(RuntimeError, match='Duplicate key'):
+        list(it)
+
+    # hacky way to force distinct objects?
+    list(ensure_unique(dups, key=lambda i: object()))
+
+
+def make_dict(
+    it: Iterable[T],
+    *,
+    key: Callable[[T], K],
+    # TODO make value optional instead? but then will need a typing override for it?
+    value: Callable[[T], V] = _identity,
+) -> Dict[K, V]:
+    with_keys = ((key(i), i) for i in it)
+    uniques = ensure_unique(with_keys, key=lambda p: p[0])
+    res: Dict[K, V] = {}
+    for k, i in uniques:
+        res[k] = i if value is None else value(i)
+    return res
+
+
+def test_make_dict() -> None:
+    import pytest
+
+    it = range(5)
+    d = make_dict(it, key=lambda i: i, value=lambda i: i % 2)
+    assert d == {0: 0, 1: 1, 2: 0, 3: 1, 4: 0}
+
+    it = range(5)
+    with pytest.raises(RuntimeError, match='Duplicate key'):
+        d = make_dict(it, key=lambda i: i % 2, value=lambda i: i)
+
+    # check type inference
+    d2: Dict[str, int] = make_dict(it, key=lambda i: str(i))
+    d3: Dict[str, bool] = make_dict(it, key=lambda i: str(i), value=lambda i: i % 2 == 0)
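
A quick usage note on make_dict (standalone sketch, not from the diff): unlike a plain dict comprehension, which silently keeps the last value for a repeated key, make_dict surfaces the collision:

    from my.core.utils.itertools import make_dict

    pairs = [('a', 1), ('b', 2), ('a', 3)]
    merged = {k: v for k, v in pairs}  # {'a': 3, 'b': 2} -- the first 'a' is silently lost
    # make_dict(pairs, key=lambda p: p[0], value=lambda p: p[1])  # would raise RuntimeError: Duplicate key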
5 changes: 2 additions & 3 deletions my/github/ghexport.py
@@ -65,11 +65,10 @@ def _dal() -> dal.DAL:
 
 @mcachew(depends_on=inputs)
 def events() -> Results:
-    from my.core.common import ensure_unique
-    key = lambda e: object() if isinstance(e, Exception) else e.eid
+    # key = lambda e: object() if isinstance(e, Exception) else e.eid
     # crap. sometimes API events can be repeated with exactly the same payload and different id
     # yield from ensure_unique(_events(), key=key)
-    yield from _events()
+    return _events()
 
 
 def _events() -> Results:
12 changes: 9 additions & 3 deletions my/jawbone/__init__.py
@@ -108,16 +108,22 @@ def load_sleeps() -> List[SleepEntry]:
 from ..core.error import Res, set_error_datetime, extract_error_datetime
 
 def pre_dataframe() -> Iterable[Res[SleepEntry]]:
+    from more_itertools import bucket
+
     sleeps = load_sleeps()
     # todo emit error if graph doesn't exist??
     sleeps = [s for s in sleeps if s.graph.exists()] # TODO careful..
-    from ..core.common import group_by_key
-    for dd, group in group_by_key(sleeps, key=lambda s: s.date_).items():
+
+    bucketed = bucket(sleeps, key=lambda s: s.date_)
+
+    for dd in bucketed:
+        group = list(bucketed[dd])
         if len(group) == 1:
             yield group[0]
         else:
             err = RuntimeError(f'Multiple sleeps per night, not supported yet: {group}')
-            set_error_datetime(err, dt=dd) # type: ignore[arg-type]
+            dt = datetime.combine(dd, time.min)
+            set_error_datetime(err, dt=dt)
             logger.exception(err)
             yield err
 
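The error path here also drops a type: ignore[arg-type]: set_error_datetime evidently expects a datetime while s.date_ is a date, so the group key is upgraded with the standard datetime.combine idiom (standalone sketch):

    from datetime import date, datetime, time

    dd = date(2024, 8, 13)
    dt = datetime.combine(dd, time.min)  # midnight of that day
    assert dt == datetime(2024, 8, 13, 0, 0)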
6 changes: 4 additions & 2 deletions my/pdfs.py
@@ -17,10 +17,10 @@
 from my.core import LazyLogger, get_files, Paths, PathIsh
 from my.core.cachew import mcachew
 from my.core.cfg import Attrs, make_config
-from my.core.common import group_by_key
 from my.core.error import Res, split_errors
 
 
+from more_itertools import bucket
 import pdfannots
 
 
@@ -169,7 +169,9 @@ def annotated_pdfs(*, filelist: Optional[Sequence[PathIsh]]=None) -> Iterator[Re
     ait = annotations()
     vit, eit = split_errors(ait, ET=Exception)
 
-    for k, g in group_by_key(vit, key=lambda a: a.path).items():
+    bucketed = bucket(vit, key=lambda a: a.path)
+    for k in bucketed:
+        g = list(bucketed[k])
         yield Pdf(path=Path(k), annotations=g)
     yield from eit
 
11 changes: 7 additions & 4 deletions my/rtm.py
@@ -11,13 +11,15 @@
 import re
 from typing import Dict, List, Iterator
 
-from .core.common import LazyLogger, get_files, group_by_key, make_dict
+from my.core.common import LazyLogger, get_files
+from my.core.utils.itertools import make_dict
 
 from my.config import rtm as config
 
 
-import icalendar # type: ignore
-from icalendar.cal import Todo # type: ignore
+from more_itertools import bucket
+import icalendar  # type: ignore
+from icalendar.cal import Todo  # type: ignore
 
 
 logger = LazyLogger(__name__)
@@ -96,7 +98,8 @@ def get_todos_by_uid(self) -> Dict[str, MyTodo]:
 
     def get_todos_by_title(self) -> Dict[str, List[MyTodo]]:
         todos = self.all_todos()
-        return group_by_key(todos, lambda todo: todo.title)
+        bucketed = bucket(todos, lambda todo: todo.title)
+        return {k: list(bucketed[k]) for k in bucketed}
 
 
 def dal():
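One caveat with the bucket pattern used here (general more_itertools behaviour, not specific to this commit): listing the keys has to consume and cache the entire source, so on a finite list like all_todos() it is equivalent to the old group_by_key, but it would buffer an unbounded stream. A minimal sketch:

    from more_itertools import bucket

    b = bucket(range(6), key=lambda n: n % 2)
    by_parity = {k: list(b[k]) for k in b}  # iterating the keys consumes and caches everything
    assert by_parity == {0: [0, 2, 4], 1: [1, 3, 5]}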
7 changes: 4 additions & 3 deletions my/tests/reddit.py
@@ -1,9 +1,10 @@
 from my.core.cfg import tmp_config
-from my.core.common import make_dict
+from my.core.utils.itertools import ensure_unique
 
 # todo ugh, it's discovered as a test???
 from .common import testdata
 
+from more_itertools import consume
 import pytest
 
 # deliberately use mixed style imports on the top level and inside the methods to test tmp_config stuff
@@ -36,8 +37,8 @@ def test_saves() -> None:
     saves = list(saved())
     assert len(saves) > 0
 
-    # just check that they are unique (makedict will throw)
-    make_dict(saves, key=lambda s: s.sid)
+    # will throw if not unique
+    consume(ensure_unique(saves, key=lambda s: s.sid))
 
 
 def test_preserves_extra_attr() -> None:
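Since ensure_unique is lazy, something has to drive the iterator for the duplicate check to actually run; more_itertools.consume does that without building a throwaway dict, which is what the updated test relies on. A standalone sketch:

    from more_itertools import consume
    from my.core.utils.itertools import ensure_unique

    consume(ensure_unique([1, 2, 3], key=lambda i: i))  # completes silently: keys are unique
    # consume(ensure_unique([1, 2, 1], key=lambda i: i))  # would raise RuntimeError: Duplicate key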
