Skip to content

Commit

Permalink
重构了 Pool
Browse files Browse the repository at this point in the history
  • Loading branch information
carefree0910 committed Jan 6, 2024
1 parent 1f636dc commit 6193ea3
Showing 1 changed file with 57 additions and 41 deletions.
98 changes: 57 additions & 41 deletions cftool/data_structures.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import io
import gc
import time

from typing import Any
from typing import Set
Expand All @@ -14,7 +13,9 @@
from typing import Callable
from typing import Iterator
from typing import Optional
from typing import Protocol
from typing import NamedTuple
from datetime import datetime
from pydantic import Field
from pydantic import BaseModel

Expand All @@ -41,9 +42,10 @@
TImage = None


TItemData = TypeVar("TItemData")
TTypes = TypeVar("TTypes")
TItem = TypeVar("TItem")
TItemData = TypeVar("TItemData")
TPoolItem = TypeVar("TPoolItem", bound="IPoolItem")
PItemInit = Callable[[], TPoolItem]


class Item(Generic[TItemData]):
Expand Down Expand Up @@ -238,41 +240,61 @@ def values(self) -> Iterator[Type[TTypes]]:
return self._types.values() # type: ignore


class ILoadableItem(Generic[TItem]):
_item: Optional[TItem]
class IPoolItem:
def unload(self) -> None:
pass


class PoolItemManager(Generic[TPoolItem]):
_item: Optional[TPoolItem]

def __init__(
self,
init_fn: Callable[[], TItem],
init_fn: PItemInit,
*,
init: bool = False,
force_keep: bool = False,
):
self.init_fn = init_fn
self.load_time = time.time()
self.load_time = datetime.now()
self.force_keep = force_keep
self._item = init_fn() if init or force_keep else None

def load(self, **kwargs: Any) -> TItem:
self.load_time = time.time()
@property
def item(self) -> TPoolItem:
if self._item is None:
raise ValueError("item is not loaded")
return self._item

@property
def loaded(self) -> bool:
return self._item is not None

def load(self, **kwargs: Any) -> TPoolItem:
self.load_time = datetime.now()
if self._item is None:
self._item = self.init_fn()
return self._item

def unload(self) -> None:
unload_fn = getattr(self._item, "unload", None)
if unload_fn is not None:
unload_fn()
del self._item
self._item = None
gc.collect()


class ILoadablePool(Generic[TItem]):
pool: Dict[str, ILoadableItem]
activated: Dict[str, ILoadableItem]
class Pool(Generic[TPoolItem]):
t_manager = PoolItemManager

pool: Dict[str, PoolItemManager[TPoolItem]]

# set `limit` to negative values to indicate 'no limit'
def __init__(self, limit: int = -1):
def __init__(self, limit: int = -1, *, allow_duplicate: bool = False):
self.pool = {}
self.activated = {}
self.limit = limit
self.allow_duplicate = allow_duplicate
if limit == 0:
raise ValueError(
"limit should either be negative "
Expand All @@ -283,39 +305,33 @@ def __contains__(self, key: str) -> bool:
return key in self.pool

@property
def num_activated(self) -> int:
return len([v for v in self.activated.values() if not v.force_keep])
def activated(self) -> Dict[str, PoolItemManager[TPoolItem]]:
return {k: m for k, m in self.pool.items() if m.loaded and not m.force_keep}

def register(self, key: str, init_fn: Callable[[bool], ILoadableItem]) -> None:
def register(self, key: str, init_fn: PItemInit, **kwargs: Any) -> None:
if key in self.pool:
if self.allow_duplicate:
return
raise ValueError(f"key '{key}' already exists")
init = self.limit < 0 or self.num_activated < self.limit
loadable_item = init_fn(init)
self.pool[key] = loadable_item
if init or loadable_item.force_keep:
self.activated[key] = loadable_item

def get(self, key: str, **kwargs: Any) -> TItem:
loadable_item = self.pool.get(key)
if loadable_item is None:
init = self.limit < 0 or len(self.activated) < self.limit
manager = self.t_manager(init_fn, init=init, **kwargs)
self.pool[key] = manager

def get(self, key: str, **kwargs: Any) -> TPoolItem:
target = self.pool.get(key)
if target is None:
raise ValueError(f"key '{key}' does not exist")
item = loadable_item.load(**kwargs)
if key in self.activated:
return item
load_times = {
key: item.load_time
for key, item in self.activated.items()
if not item.force_keep
}
print("> activated", self.activated)
print("> load_times", load_times)
if target.loaded:
return target.item
item = target.load(**kwargs)
load_times = {k: m.load_time for k, m in self.activated.items()}
earliest_key = list(sort_dict_by_value(load_times).keys())[0]
self.activated.pop(earliest_key).unload()
self.activated[key] = loadable_item
time_format = "-".join(TIME_FORMAT.split("-")[:-1])
earliest = self.pool[earliest_key]
earliest.unload()
get_time_str = lambda m: datetime.strftime(m.load_time, TIME_FORMAT)
print_info(
f"'{earliest_key}' is unloaded to make room for '{key}' "
f"(last updated: {time.strftime(time_format, time.localtime(loadable_item.load_time))})"
f"'{earliest_key}' (last updated: {get_time_str(earliest)}) is unloaded "
f"to make room for '{key}' (last updated: {get_time_str(target)})"
)
return item

Expand Down

0 comments on commit 6193ea3

Please sign in to comment.