diff --git a/my/core/common.py b/my/core/common.py
index 460a6587..0be4daea 100644
--- a/my/core/common.py
+++ b/my/core/common.py
@@ -65,84 +65,6 @@ def import_dir(path: PathIsh, extra: str='') -> types.ModuleType:
     return import_from(p.parent, p.name + extra)
 
 
-T = TypeVar('T')
-K = TypeVar('K')
-V = TypeVar('V')
-
-# TODO more_itertools.bucket?
-def group_by_key(l: Iterable[T], key: Callable[[T], K]) -> Dict[K, List[T]]:
-    res: Dict[K, List[T]] = {}
-    for i in l:
-        kk = key(i)
-        lst = res.get(kk, [])
-        lst.append(i)
-        res[kk] = lst
-    return res
-
-
-def _identity(v: T) -> V:  # type: ignore[type-var]
-    return cast(V, v)
-
-
-# ugh. nothing in more_itertools?
-def ensure_unique(
-    it: Iterable[T],
-    *,
-    key: Callable[[T], K],
-    value: Callable[[T], V]=_identity,
-    key2value: Optional[Dict[K, V]]=None
-) -> Iterable[T]:
-    if key2value is None:
-        key2value = {}
-    for i in it:
-        k = key(i)
-        v = value(i)
-        pv = key2value.get(k, None)
-        if pv is not None:
-            raise RuntimeError(f"Duplicate key: {k}. Previous value: {pv}, new value: {v}")
-        key2value[k] = v
-        yield i
-
-
-def test_ensure_unique() -> None:
-    import pytest
-    assert list(ensure_unique([1, 2, 3], key=lambda i: i)) == [1, 2, 3]
-
-    dups = [1, 2, 1, 4]
-    # this works because it's lazy
-    it = ensure_unique(dups, key=lambda i: i)
-
-    # but forcing throws
-    with pytest.raises(RuntimeError, match='Duplicate key'):
-        list(it)
-
-    # hacky way to force distinct objects?
-    list(ensure_unique(dups, key=lambda i: object()))
-
-
-def make_dict(
-    it: Iterable[T],
-    *,
-    key: Callable[[T], K],
-    value: Callable[[T], V]=_identity
-) -> Dict[K, V]:
-    res: Dict[K, V] = {}
-    uniques = ensure_unique(it, key=key, value=value, key2value=res)
-    for _ in uniques:
-        pass  # force the iterator
-    return res
-
-
-def test_make_dict() -> None:
-    it = range(5)
-    d = make_dict(it, key=lambda i: i, value=lambda i: i % 2)
-    assert d == {0: 0, 1: 1, 2: 0, 3: 1, 4: 0}
-
-    # check type inference
-    d2: Dict[str, int ] = make_dict(it, key=lambda i: str(i))
-    d3: Dict[str, bool] = make_dict(it, key=lambda i: str(i), value=lambda i: i % 2 == 0)
-
-
 # https://stackoverflow.com/a/12377059/706389
 def listify(fn=None, wrapper=list):
     """
@@ -696,6 +618,22 @@ def cproperty(*args, **kwargs):
     return functools.cached_property(*args, **kwargs)
 
 
+    @deprecated('use more_itertools.bucket instead')
+    def group_by_key(l, key):
+        res = {}
+        for i in l:
+            kk = key(i)
+            lst = res.get(kk, [])
+            lst.append(i)
+            res[kk] = lst
+        return res
+
+    @deprecated('use my.core.utils.itertools.make_dict instead')
+    def make_dict(*args, **kwargs):
+        from .utils import itertools as UI
+
+        return UI.make_dict(*args, **kwargs)
+
     # todo wrap these in deprecated decorator as well?
     from .cachew import mcachew  # noqa: F401
 
diff --git a/my/core/utils/itertools.py b/my/core/utils/itertools.py
new file mode 100644
index 00000000..78b91de2
--- /dev/null
+++ b/my/core/utils/itertools.py
@@ -0,0 +1,80 @@
+"""
+Various helpers/transforms of iterators
+
+Ideally this should be as small as possible and we should rely on stdlib itertools or more_itertools
+"""
+
+from typing import Callable, Dict, Iterable, TypeVar, cast
+
+
+T = TypeVar('T')
+K = TypeVar('K')
+V = TypeVar('V')
+
+
+def _identity(v: T) -> V:  # type: ignore[type-var]
+    return cast(V, v)
+
+
+# ugh. nothing in more_itertools?
+# perhaps duplicates_everseen? but it doesn't yield non-unique elements?
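+# (an aside, hedging: duplicates_everseen -- available in recent more_itertools -- yields
+# only the repeated elements, e.g. list(duplicates_everseen([1, 2, 1])) == [1]; here we
+# want every element passed through unchanged, raising on the first duplicate key)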
+def ensure_unique(it: Iterable[T], *, key: Callable[[T], K]) -> Iterable[T]:
+    key2item: Dict[K, T] = {}
+    for i in it:
+        k = key(i)
+        pi = key2item.get(k, None)
+        if pi is not None:
+            raise RuntimeError(f"Duplicate key: {k}. Previous value: {pi}, new value: {i}")
+        key2item[k] = i
+        yield i
+
+
+def test_ensure_unique() -> None:
+    import pytest
+
+    assert list(ensure_unique([1, 2, 3], key=lambda i: i)) == [1, 2, 3]
+
+    dups = [1, 2, 1, 4]
+    # this works because it's lazy
+    it = ensure_unique(dups, key=lambda i: i)
+
+    # but forcing throws
+    with pytest.raises(RuntimeError, match='Duplicate key'):
+        list(it)
+
+    # hacky way to force distinct objects?
+    list(ensure_unique(dups, key=lambda i: object()))
+
+
+def make_dict(
+    it: Iterable[T],
+    *,
+    key: Callable[[T], K],
+    # TODO make value optional instead? but then will need a typing override for it?
+    value: Callable[[T], V] = _identity,
+) -> Dict[K, V]:
+    with_keys = ((key(i), i) for i in it)
+    uniques = ensure_unique(with_keys, key=lambda p: p[0])
+    res: Dict[K, V] = {}
+    for k, i in uniques:
+        res[k] = value(i)
+    return res
+
+
+def test_make_dict() -> None:
+    import pytest
+
+    it = range(5)
+    d = make_dict(it, key=lambda i: i, value=lambda i: i % 2)
+    assert d == {0: 0, 1: 1, 2: 0, 3: 1, 4: 0}
+
+    it = range(5)
+    with pytest.raises(RuntimeError, match='Duplicate key'):
+        make_dict(it, key=lambda i: i % 2, value=lambda i: i)
+
+    # check type inference
+    d2: Dict[str, int] = make_dict(it, key=lambda i: str(i))
+    d3: Dict[str, bool] = make_dict(it, key=lambda i: str(i), value=lambda i: i % 2 == 0)
diff --git a/my/github/ghexport.py b/my/github/ghexport.py
index 9dc8fd5f..80106a51 100644
--- a/my/github/ghexport.py
+++ b/my/github/ghexport.py
@@ -65,11 +65,10 @@ def _dal() -> dal.DAL:
 
 @mcachew(depends_on=inputs)
 def events() -> Results:
-    from my.core.common import ensure_unique
-    key = lambda e: object() if isinstance(e, Exception) else e.eid
+    # key = lambda e: object() if isinstance(e, Exception) else e.eid
     # crap. sometimes API events can be repeated with exactly the same payload and different id
     # yield from ensure_unique(_events(), key=key)
-    yield from _events()
+    return _events()
 
 
 def _events() -> Results:
diff --git a/my/jawbone/__init__.py b/my/jawbone/__init__.py
index 7f4d6bdb..0659bc66 100644
--- a/my/jawbone/__init__.py
+++ b/my/jawbone/__init__.py
@@ -108,16 +108,22 @@ def load_sleeps() -> List[SleepEntry]:
 from ..core.error import Res, set_error_datetime, extract_error_datetime
 
 def pre_dataframe() -> Iterable[Res[SleepEntry]]:
+    from more_itertools import bucket
+
     sleeps = load_sleeps()
     # todo emit error if graph doesn't exist??
     sleeps = [s for s in sleeps if s.graph.exists()]  # TODO careful..
- from ..core.common import group_by_key - for dd, group in group_by_key(sleeps, key=lambda s: s.date_).items(): + + bucketed = bucket(sleeps, key=lambda s: s.date_) + + for dd in bucketed: + group = list(bucketed[dd]) if len(group) == 1: yield group[0] else: err = RuntimeError(f'Multiple sleeps per night, not supported yet: {group}') - set_error_datetime(err, dt=dd) # type: ignore[arg-type] + dt = datetime.combine(dd, time.min) + set_error_datetime(err, dt=dt) logger.exception(err) yield err diff --git a/my/pdfs.py b/my/pdfs.py index 3305ecaf..b3ef85d2 100644 --- a/my/pdfs.py +++ b/my/pdfs.py @@ -17,10 +17,10 @@ from my.core import LazyLogger, get_files, Paths, PathIsh from my.core.cachew import mcachew from my.core.cfg import Attrs, make_config -from my.core.common import group_by_key from my.core.error import Res, split_errors +from more_itertools import bucket import pdfannots @@ -169,7 +169,9 @@ def annotated_pdfs(*, filelist: Optional[Sequence[PathIsh]]=None) -> Iterator[Re ait = annotations() vit, eit = split_errors(ait, ET=Exception) - for k, g in group_by_key(vit, key=lambda a: a.path).items(): + bucketed = bucket(vit, key=lambda a: a.path) + for k in bucketed: + g = list(bucketed[k]) yield Pdf(path=Path(k), annotations=g) yield from eit diff --git a/my/rtm.py b/my/rtm.py index 8d41e7ad..56f4d076 100644 --- a/my/rtm.py +++ b/my/rtm.py @@ -11,13 +11,15 @@ import re from typing import Dict, List, Iterator -from .core.common import LazyLogger, get_files, group_by_key, make_dict +from my.core.common import LazyLogger, get_files +from my.core.utils.itertools import make_dict from my.config import rtm as config -import icalendar # type: ignore -from icalendar.cal import Todo # type: ignore +from more_itertools import bucket +import icalendar # type: ignore +from icalendar.cal import Todo # type: ignore logger = LazyLogger(__name__) @@ -96,7 +98,8 @@ def get_todos_by_uid(self) -> Dict[str, MyTodo]: def get_todos_by_title(self) -> Dict[str, List[MyTodo]]: todos = self.all_todos() - return group_by_key(todos, lambda todo: todo.title) + bucketed = bucket(todos, lambda todo: todo.title) + return {k: list(bucketed[k]) for k in bucketed} def dal(): diff --git a/my/tests/reddit.py b/my/tests/reddit.py index 4af95ae9..fb8d6d2f 100644 --- a/my/tests/reddit.py +++ b/my/tests/reddit.py @@ -1,9 +1,10 @@ from my.core.cfg import tmp_config -from my.core.common import make_dict +from my.core.utils.itertools import ensure_unique # todo ugh, it's discovered as a test??? from .common import testdata +from more_itertools import consume import pytest # deliberately use mixed style imports on the top level and inside the methods to test tmp_config stuff @@ -36,8 +37,8 @@ def test_saves() -> None: saves = list(saved()) assert len(saves) > 0 - # just check that they are unique (makedict will throw) - make_dict(saves, key=lambda s: s.sid) + # will throw if not unique + consume(ensure_unique(saves, key=lambda s: s.sid)) def test_preserves_extra_attr() -> None:
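
A standalone sketch (not part of the patch) of the replacement idiom used throughout this change, assuming more_itertools is installed; the names words, grouped and by_len are illustrative. The dict-building pattern mirrors the rtm.py and pdfs.py hunks, and make_dict comes from the my/core/utils/itertools.py module added above:

    from more_itertools import bucket

    from my.core.utils.itertools import make_dict

    words = ['apple', 'avocado', 'banana']

    # bucket() is the lazy stand-in for the removed group_by_key:
    # iterating the bucket object yields the distinct keys,
    # and indexing by a key yields that key's group
    bucketed = bucket(words, key=lambda w: w[0])
    grouped = {k: list(bucketed[k]) for k in bucketed}
    assert grouped == {'a': ['apple', 'avocado'], 'b': ['banana']}

    # unlike plain dict(), make_dict raises on duplicate keys instead of overwriting
    by_len = make_dict(words, key=len)
    assert by_len == {5: 'apple', 7: 'avocado', 6: 'banana'}
    # make_dict(words, key=lambda w: w[0]) would raise RuntimeError: Duplicate key: a ...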