[Misc] Clean up and consolidate LRUCache (vllm-project#11339)
Signed-off-by: DarkLight1337 <[email protected]>
DarkLight1337 authored Dec 19, 2024
1 parent e24113a commit cdf22af
Showing 5 changed files with 34 additions and 67 deletions.
9 changes: 4 additions & 5 deletions vllm/adapter_commons/models.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, Hashable, Optional, TypeVar
+from typing import Any, Callable, Dict, Optional, TypeVar
 
 from torch import nn
 
@@ -24,14 +24,13 @@ def from_local_checkpoint(cls, model_dir, model_id=None, **kwargs):
 T = TypeVar('T')
 
 
-class AdapterLRUCache(LRUCache[T]):
+class AdapterLRUCache(LRUCache[int, T]):
 
-    def __init__(self, capacity: int, deactivate_fn: Callable[[Hashable],
-                                                               None]):
+    def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]):
         super().__init__(capacity)
         self.deactivate_fn = deactivate_fn
 
-    def _on_remove(self, key: Hashable, value: Optional[T]):
+    def _on_remove(self, key: int, value: Optional[T]):
         logger.debug("Removing adapter int id: %d", key)
         self.deactivate_fn(key)
         return super()._on_remove(key, value)
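
Note: with LRUCache now generic over both key and value, AdapterLRUCache can state that its keys are integer adapter ids rather than an opaque Hashable, while the eviction hook stays an _on_remove override. A minimal self-contained sketch of that pattern (not vllm's actual classes; SimpleLRUCache, DeactivatingCache and deactivate_fn below are illustrative names):

# Minimal sketch of the pattern above: an LRU cache generic over key and value
# whose subclass hooks eviction by overriding _on_remove. All names here
# (SimpleLRUCache, DeactivatingCache, deactivate_fn) are illustrative, not vllm's.
from collections import OrderedDict
from typing import Callable, Generic, Optional, TypeVar

_K = TypeVar("_K")
_V = TypeVar("_V")


class SimpleLRUCache(Generic[_K, _V]):

    def __init__(self, capacity: int) -> None:
        self.cache = OrderedDict[_K, _V]()
        self.capacity = capacity

    def put(self, key: _K, value: _V) -> None:
        self.cache[key] = value
        self.cache.move_to_end(key)
        while len(self.cache) > self.capacity:
            old_key, old_value = self.cache.popitem(last=False)
            self._on_remove(old_key, old_value)

    def _on_remove(self, key: _K, value: Optional[_V]) -> None:
        pass  # subclasses hook eviction here


class DeactivatingCache(SimpleLRUCache[int, str]):
    """Frees per-key resources on eviction, like AdapterLRUCache above."""

    def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]):
        super().__init__(capacity)
        self.deactivate_fn = deactivate_fn

    def _on_remove(self, key: int, value: Optional[str]) -> None:
        self.deactivate_fn(key)


cache = DeactivatingCache(capacity=1, deactivate_fn=lambda i: print("deactivate", i))
cache.put(1, "adapter-1")
cache.put(2, "adapter-2")  # evicts key 1 and calls deactivate_fn(1)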
2 changes: 1 addition & 1 deletion vllm/transformers_utils/tokenizer_group/tokenizer_group.py
@@ -22,7 +22,7 @@ def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
         self.max_input_length = max_input_length
         self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
         max_loras = tokenizer_config.get("max_loras", 0)
-        self.lora_tokenizers = LRUCache[AnyTokenizer](
+        self.lora_tokenizers = LRUCache[int, AnyTokenizer](
            capacity=max(max_loras, max_num_seqs) if enable_lora else 0)
 
     @classmethod
59 changes: 26 additions & 33 deletions vllm/utils.py
@@ -21,14 +21,13 @@
 import warnings
 import weakref
 from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
-from collections import UserDict, defaultdict
+from collections import OrderedDict, UserDict, defaultdict
 from collections.abc import Iterable, Mapping
 from dataclasses import dataclass, field
 from functools import lru_cache, partial, wraps
 from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
                     Dict, Generator, Generic, Hashable, List, Literal,
-                    Optional, OrderedDict, Set, Tuple, Type, TypeVar, Union,
-                    overload)
+                    Optional, Tuple, Type, TypeVar, Union, overload)
 from uuid import uuid4
 
 import numpy as np
@@ -154,10 +153,12 @@
 }
 
 P = ParamSpec('P')
-K = TypeVar("K")
 T = TypeVar("T")
 U = TypeVar("U")
 
+_K = TypeVar("_K", bound=Hashable)
+_V = TypeVar("_V")
+
 
 class _Sentinel:
     ...
@@ -190,50 +191,48 @@ def reset(self) -> None:
         self.counter = 0
 
 
-class LRUCache(Generic[T]):
+class LRUCache(Generic[_K, _V]):
 
-    def __init__(self, capacity: int):
-        self.cache: OrderedDict[Hashable, T] = OrderedDict()
-        self.pinned_items: Set[Hashable] = set()
+    def __init__(self, capacity: int) -> None:
+        self.cache = OrderedDict[_K, _V]()
+        self.pinned_items = set[_K]()
         self.capacity = capacity
 
-    def __contains__(self, key: Hashable) -> bool:
+    def __contains__(self, key: _K) -> bool:
         return key in self.cache
 
     def __len__(self) -> int:
         return len(self.cache)
 
-    def __getitem__(self, key: Hashable) -> T:
+    def __getitem__(self, key: _K) -> _V:
         value = self.cache[key]  # Raise KeyError if not exists
         self.cache.move_to_end(key)
         return value
 
-    def __setitem__(self, key: Hashable, value: T) -> None:
+    def __setitem__(self, key: _K, value: _V) -> None:
         self.put(key, value)
 
-    def __delitem__(self, key: Hashable) -> None:
+    def __delitem__(self, key: _K) -> None:
         self.pop(key)
 
-    def touch(self, key: Hashable) -> None:
+    def touch(self, key: _K) -> None:
         self.cache.move_to_end(key)
 
-    def get(self,
-            key: Hashable,
-            default_value: Optional[T] = None) -> Optional[T]:
-        value: Optional[T]
+    def get(self, key: _K, default: Optional[_V] = None) -> Optional[_V]:
+        value: Optional[_V]
         if key in self.cache:
             value = self.cache[key]
             self.cache.move_to_end(key)
         else:
-            value = default_value
+            value = default
         return value
 
-    def put(self, key: Hashable, value: T) -> None:
+    def put(self, key: _K, value: _V) -> None:
         self.cache[key] = value
         self.cache.move_to_end(key)
         self._remove_old_if_needed()
 
-    def pin(self, key: Hashable) -> None:
+    def pin(self, key: _K) -> None:
         """
         Pins a key in the cache preventing it from being
         evicted in the LRU order.
@@ -242,13 +241,13 @@ def pin(self, key: Hashable) -> None:
             raise ValueError(f"Cannot pin key: {key} not in cache.")
         self.pinned_items.add(key)
 
-    def _unpin(self, key: Hashable) -> None:
+    def _unpin(self, key: _K) -> None:
         self.pinned_items.remove(key)
 
-    def _on_remove(self, key: Hashable, value: Optional[T]):
+    def _on_remove(self, key: _K, value: Optional[_V]) -> None:
         pass
 
-    def remove_oldest(self, remove_pinned=False):
+    def remove_oldest(self, *, remove_pinned: bool = False) -> None:
         if not self.cache:
             return
 
@@ -262,25 +261,23 @@ def remove_oldest(self, remove_pinned=False):
                                "cannot remove oldest from the cache.")
         else:
             lru_key = next(iter(self.cache))
-        self.pop(lru_key)
+        self.pop(lru_key)  # type: ignore
 
     def _remove_old_if_needed(self) -> None:
         while len(self.cache) > self.capacity:
             self.remove_oldest()
 
-    def pop(self,
-            key: Hashable,
-            default_value: Optional[T] = None) -> Optional[T]:
+    def pop(self, key: _K, default: Optional[_V] = None) -> Optional[_V]:
         run_on_remove = key in self.cache
-        value: Optional[T] = self.cache.pop(key, default_value)
+        value = self.cache.pop(key, default)
         # remove from pinned items
         if key in self.pinned_items:
             self._unpin(key)
         if run_on_remove:
             self._on_remove(key, value)
         return value
 
-    def clear(self):
+    def clear(self) -> None:
         while len(self.cache) > 0:
             self.remove_oldest(remove_pinned=True)
         self.cache.clear()
@@ -843,10 +840,6 @@ def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
     return [item for sublist in lists for item in sublist]
 
 
-_K = TypeVar("_K", bound=Hashable)
-_V = TypeVar("_V")
-
-
 def full_groupby(values: Iterable[_V], *, key: Callable[[_V], _K]):
     """
     Unlike :class:`itertools.groupby`, groups are not broken by
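
The consolidated cache keeps the pinning-aware API intact. A usage sketch based on the methods shown above, assuming a vllm install that contains this commit so `from vllm.utils import LRUCache` resolves:

# Usage sketch of the consolidated LRUCache, exercising the methods shown in
# the diff above; assumes a vllm install that contains this commit.
from vllm.utils import LRUCache

cache = LRUCache[str, int](capacity=2)
cache.put("a", 1)
cache["b"] = 2                    # __setitem__ delegates to put()
assert cache.get("a") == 1        # get() also refreshes recency
assert "b" in cache and len(cache) == 2

cache.pin("a")                    # pinned keys are skipped by LRU eviction
cache.put("c", 3)                 # over capacity: evicts "b", the oldest unpinned key
assert "a" in cache and "b" not in cache

assert cache.pop("c") == 3        # pop() also runs the _on_remove() hook
cache.clear()                     # clear() evicts everything, pinned or not
assert len(cache) == 0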
6 changes: 3 additions & 3 deletions vllm/v1/engine/mm_input_mapper.py
@@ -8,7 +8,7 @@
 from vllm.logger import init_logger
 from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
                              MultiModalKwargs, MultiModalRegistry)
-from vllm.v1.utils import LRUDictCache
+from vllm.utils import LRUCache
 
 logger = init_logger(__name__)
 
@@ -44,7 +44,7 @@ def __init__(
 
         # Init cache
         self.use_cache = not model_config.disable_mm_preprocessor_cache
-        self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)
+        self.mm_cache = LRUCache[str, MultiModalKwargs](MM_CACHE_SIZE)
 
         # DEBUG: Set to None to disable
         self.mm_debug_cache_hit_ratio_steps = None
@@ -120,7 +120,7 @@ class MMInputMapperServer:
 
     def __init__(self, model_config):
         self.use_cache = not model_config.disable_mm_preprocessor_cache
-        self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)
+        self.mm_cache = LRUCache[str, MultiModalKwargs](MM_CACHE_SIZE)
 
     def process_inputs(
         self,
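
With LRUDictCache gone, both the client- and server-side mappers share vllm.utils.LRUCache keyed by string hashes of the multimodal input. A hedged sketch of that lookup-or-compute pattern; the cache size, key format and value contents below are placeholders, not the real values used in mm_input_mapper.py:

# Sketch of the lookup-or-compute pattern the mapper uses after switching to
# LRUCache; the cache size, key format and value dict here are placeholders.
from vllm.utils import LRUCache

MM_CACHE_SIZE = 256  # placeholder; the real MM_CACHE_SIZE is not shown in this diff

mm_cache = LRUCache[str, dict](MM_CACHE_SIZE)


def map_input(mm_hash: str) -> dict:
    cached = mm_cache.get(mm_hash)
    if cached is not None:
        return cached                     # cache hit: reuse the processed kwargs
    kwargs = {"pixel_values": "..."}      # stand-in for the real mapper output
    mm_cache.put(mm_hash, kwargs)
    return kwargs


print(map_input("hash-of-image-0") is map_input("hash-of-image-0"))  # True: second call hits the cache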
25 changes: 0 additions & 25 deletions vllm/v1/utils.py
@@ -1,4 +1,3 @@
-from collections import OrderedDict
 from collections.abc import Sequence
 from contextlib import contextmanager
 from typing import (Any, Generic, Iterator, List, Optional, TypeVar, Union,
@@ -102,27 +101,3 @@ def make_zmq_socket(
 
     finally:
         ctx.destroy(linger=0)
-
-
-K = TypeVar('K')
-V = TypeVar('V')
-
-
-class LRUDictCache(Generic[K, V]):
-
-    def __init__(self, size: int):
-        self.cache: OrderedDict[K, V] = OrderedDict()
-        self.size = size
-
-    def get(self, key: K, default=None) -> V:
-        if key not in self.cache:
-            return default
-
-        self.cache.move_to_end(key)
-        return self.cache[key]
-
-    def put(self, key: K, value: V):
-        self.cache[key] = value
-        self.cache.move_to_end(key)
-        if len(self.cache) > self.size:
-            self.cache.popitem(last=False)
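
The deleted LRUDictCache only offered bounded get/put with eviction of the least-recently-used entry, which the consolidated LRUCache already covers; a quick check of that equivalence, assuming vllm at this commit is importable:

# The removed LRUDictCache behavior (bounded size, least-recently-used eviction)
# is covered by vllm.utils.LRUCache; a quick equivalence check.
from vllm.utils import LRUCache

cache = LRUCache[str, int](capacity=2)
cache.put("x", 1)
cache.put("y", 2)
cache.get("x")     # refresh "x", leaving "y" as the least recently used entry
cache.put("z", 3)  # exceeds capacity, so "y" is evicted
assert cache.get("y") is None and cache.get("x") == 1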
