diff --git a/torcharrow/dispatcher.py b/torcharrow/dispatcher.py index 942da1a82..dcaafe4df 100644 --- a/torcharrow/dispatcher.py +++ b/torcharrow/dispatcher.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Callable, ClassVar, Dict, Tuple +from typing import Callable, ClassVar, Dict, Generic, Tuple, TypeVar # --------------------------------------------------------------------------- # column factory (class methods only!) @@ -13,21 +13,24 @@ Typecode = str # one of dtype.typecode -class Dispatcher: +T = TypeVar("T") - # singelton, append only, registering is idempotent - _calls: ClassVar[Dict[Tuple[Typecode, Device], Callable]] = {} - @classmethod - def register(cls, key: Tuple[Typecode, Device], call: Callable): - # key is tuple: (device,typecode) - if key in Dispatcher._calls: - if call == Dispatcher._calls[key]: +class _Dispatcher(Generic[T]): + def __init__(self): + # append only, registering is idempotent + self._calls: Dict[T, Callable] = {} + + def register(self, key: T, call: Callable): + if key in self._calls: + if call == self._calls[key]: return else: raise ValueError("keys for calls can only be registered once") - Dispatcher._calls[key] = call + self._calls[key] = call + + def lookup(self, key: T) -> Callable: + return self._calls[key] + - @classmethod - def lookup(cls, key: Tuple[Typecode, Device]) -> Callable: - return Dispatcher._calls[key] +Dispatcher: _Dispatcher[Tuple[Typecode, Device]] = _Dispatcher()