diff --git a/boiling_learning/preprocessing/experiment_video.py b/boiling_learning/preprocessing/experiment_video.py index 6236cf63..d2309ddb 100644 --- a/boiling_learning/preprocessing/experiment_video.py +++ b/boiling_learning/preprocessing/experiment_video.py @@ -10,9 +10,10 @@ from boiling_learning.datasets.sliceable import SliceableDataset from boiling_learning.io import json from boiling_learning.preprocessing.video import Video, VideoFrame, convert_video +from boiling_learning.utils.collections import merge_dicts from boiling_learning.utils.dataclasses import dataclass, field from boiling_learning.utils.descriptions import describe -from boiling_learning.utils.utils import PathLike, merge_dicts, resolve +from boiling_learning.utils.utils import PathLike, resolve class ExperimentVideo: diff --git a/boiling_learning/utils/collections.py b/boiling_learning/utils/collections.py index 4a1fb333..c71f9aa7 100644 --- a/boiling_learning/utils/collections.py +++ b/boiling_learning/utils/collections.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections import ChainMap from contextlib import suppress from itertools import chain from typing import ( @@ -12,6 +13,7 @@ Iterable, Iterator, KeysView, + Mapping, MutableSet, TypeVar, Union, @@ -23,6 +25,16 @@ _Value = TypeVar('_Value') +def merge_dicts( + *mappings: Mapping[_Key, _Value], + latter_precedence: bool = True, +) -> Dict[_Key, _Value]: + if latter_precedence: + return merge_dicts(*reversed(mappings), latter_precedence=False) + + return dict(ChainMap(*mappings)) + + class KeyedDefaultDict(DefaultDict[_Key, _Value]): ''' Source: https://stackoverflow.com/a/2912455/5811400 diff --git a/boiling_learning/utils/utils.py b/boiling_learning/utils/utils.py index f3f99e88..2e99e1ea 100644 --- a/boiling_learning/utils/utils.py +++ b/boiling_learning/utils/utils.py @@ -4,11 +4,10 @@ import os import random import string -from collections import ChainMap from contextlib import contextmanager from pathlib import Path from tempfile import TemporaryDirectory -from typing import Iterable, Iterator, Mapping, Optional, Sequence, TypeVar, Union +from typing import Any, Iterable, Iterator, Optional, Sequence, TypeVar, Union import funcy @@ -24,25 +23,18 @@ def reorder(seq: Sequence[_T], indices: Iterable[int]) -> Iterable[_T]: return map(seq.__getitem__, indices) -def argmin(iterable: Iterable) -> int: +def argmin(iterable: Iterable[Any]) -> int: return min(enumerate(iterable), key=operator.itemgetter(1))[0] -def argmax(iterable: Iterable) -> int: +def argmax(iterable: Iterable[Any]) -> int: return max(enumerate(iterable), key=operator.itemgetter(1))[0] -def argsorted(iterable: Iterable) -> Iterable[int]: +def argsorted(iterable: Iterable[Any]) -> Iterable[int]: return funcy.pluck(0, sorted(enumerate(iterable), key=operator.itemgetter(1))) -def merge_dicts(*dict_args: Mapping, latter_precedence: bool = True) -> dict: - if latter_precedence: - dict_args = reversed(dict_args) - - return dict(ChainMap(*dict_args)) - - def one_factor_at_a_time( iterables: Iterable[Iterable], default_indices: Iterable[int] = tuple(),