Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TriePagedAttentionCache - 2 #628

Closed
wants to merge 18 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,68 @@
Base class for kv caches.
"""

from typing import List
from typing import List, Iterable, Protocol
from .page_pool import PageInfo
import math
from abc import ABC, abstractmethod
from .page_pool import PagePool

# logging
import logging

logger = logging.getLogger(__name__)

# exception for when cache allocation failed
class CacheAllocationFailure(Exception):
pass


class PageAllocation(ABC):
"""Abstract base class for page allocations in the cache."""

@property
@abstractmethod
def pages(self) -> List[PageInfo]:
"""Returns the list of pages that were allocated."""
pass

@abstractmethod
def publish_pages_for_tokens(self, tokens, publish_incomplete_pages=False) -> None:
"""
Makes pages available to other requests. For details, reference the derived class in trie_attention_cache.py.
"""
pass

@abstractmethod
def release_pages(self) -> None:
"""Releases the allocation's reference to pages."""
pass


class BasePageAttentionCacheAllocation(PageAllocation):
"""Represents a page allocation in the cache."""

def __init__(self, pages: Iterable[PageInfo], cache: "BasePagedAttentionCache"):
self._pages = tuple(pages)
self._cache = cache
self._is_released = False

@property
def pages(self) -> List[PageInfo]:
return list(self._pages)

def publish_pages_for_tokens(self, tokens, publish_incomplete_pages=False) -> None:
pass

def release_pages(self) -> None:
if self._is_released:
logger.warning("Releasing already-released allocation")
return
self._cache.page_pool.free_pages(self._pages)
self._is_released = True

def __rerp__(self) -> str:
return f"BasePageAttentionCacheAllocation(pages={self._pages}, cache={self._cache})"


class BasePagedAttentionCache:
Expand All @@ -33,13 +92,13 @@ class BasePagedAttentionCache:
- Reference counting prevents eviction of in-use pages
"""

def __init__(self, page_pool, tokens_per_page):
def __init__(self, page_pool: PagePool, tokens_per_page: int):
self.page_pool = page_pool
self.tokens_per_page = tokens_per_page

def acquire_pages_for_tokens(
self, tokens: List[int], extra_token_slots: int = 1
) -> tuple[list[PageInfo], int]:
) -> PageAllocation:
"""
Given a list of tokens, return a list of pages and a start position to continue generation from.

Expand All @@ -57,24 +116,7 @@ def acquire_pages_for_tokens(
pages_needed = math.ceil(token_count / self.tokens_per_page)
pages = self.page_pool.acquire_free_pages(pages_needed)

n_cached_tokens = 0

return pages, n_cached_tokens

def publish_pages(self, tokens, pages) -> None:
"""
Given a list of tokens and pages containing KV corresponding to these tokens, make these pages available to other requests.
if pages is None:
raise CacheAllocationFailure()

Associates the tokens with the pages, and mark them as done writing.

It is assumed that hereafter, the calling request will not modify these pages, at least not the positions [0:len(tokens)].
"""

pass # the base implementation doesn't cache unfinished requests.

def release_pages(self, tokens, pages):
"""
Decrement reference count for these pages. When reference count is zero, they will be elegible for eviction.
"""
# in the base implementation, the pages can be owned by 1 request max, so they can be instantly release
self.page_pool.release_pages(pages)
return BasePageAttentionCacheAllocation(pages, cache=self)
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __init__(self, *, devices: Sequence[sf.ScopedDevice], config: PagePoolConfig
for i in range(self.config.alloc_page_count)
]

self.attn_page_free = list(self.attn_page_entries)
self.available_pages = list(self.attn_page_entries)

# Initialize a page table on each device.
page_table_shape = [
Expand All @@ -108,14 +108,14 @@ def __init__(self, *, devices: Sequence[sf.ScopedDevice], config: PagePoolConfig

def acquire_free_pages(self, count: int) -> list[PageInfo] | None:
with self._lock:
available = len(self.attn_page_free)
available = len(self.available_pages)
if count > available:
return None
return [self.attn_page_free.pop() for _ in range(count)]
return [self.available_pages.pop() for _ in range(count)]

def release_pages(self, pages: list[PageInfo]):
def free_pages(self, pages: list[PageInfo]):
with self._lock:
self.attn_page_free.extend(pages)
self.available_pages.extend(pages)

def copy_page(self, src_page: PageInfo) -> PageInfo:
"""
Expand Down Expand Up @@ -148,7 +148,7 @@ def copy_page(self, src_page: PageInfo) -> PageInfo:

def __repr__(self):
# No need to lock for repr (list is internally synchronized).
free_pages = len(self.attn_page_free)
free_pages = len(self.available_pages)
total_pages = len(self.attn_page_entries)
return (
f"PagePool({total_pages - free_pages}/{total_pages} pages in use: "
Expand Down
Loading
Loading