From b8864a764e7b5bb8b825d3c5887252c7a69bc666 Mon Sep 17 00:00:00 2001 From: Cedar Date: Mon, 25 Nov 2024 16:44:03 -0800 Subject: [PATCH 01/18] initial PageAllocation handle implementation --- .../kvcache/base_attention_cache.py | 111 ++++++++++++++---- .../shortfin_apps/llm/components/messages.py | 39 ++---- .../shortfin_apps/llm/components/service.py | 59 +++++----- 3 files changed, 134 insertions(+), 75 deletions(-) diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py index 0007000bc..40d1b1197 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py @@ -8,9 +8,91 @@ Base class for kv caches. """ -from typing import List +from typing import List, Iterable, Protocol from .page_pool import PageInfo import math +from abc import ABC, abstractmethod +from .page_pool import PagePool + +# logging +import logging + +logger = logging.getLogger(__name__) + +# exception for when cache allocation failed +class CacheAllocationFailure(Exception): + pass + + +class PageAllocation(ABC): + """ + Abstract base class for page allocations in the cache. + Subclasses only need to implement the core allocation methods. + """ + + @abstractmethod + def get_page_list(self) -> List[PageInfo]: + """Returns the list of pages that were allocated.""" + pass + + @abstractmethod + def publish_pages(self, up_to_page_index) -> None: + """ + Makes self.get_page_list()[0:up_to_page_index] available to other requests after writing is complete. + Associates tokens with pages and marks them as ready for reading. + """ + pass + + @abstractmethod + def release_pages(self) -> None: + """ + Releases the allocation's reference to pages. + Pages become eligible for eviction when their reference count reaches zero. 
+ """ + pass + + +class BasePageAttentionCacheAllocation(PageAllocation): + """ + Represents a page allocation in the cache, implementing the PageAllocation protocol. + """ + + def __init__(self, pages: Iterable[PageInfo], cache: "BasePagedAttentionCache"): + # this should only be called by the associated attention cache & + self._pages = tuple(pages) + self._cache = cache + self._is_released = False + + def get_page_list(self) -> List[PageInfo]: + return list(self._pages) # return a list, as expected by service.py + + def publish_pages(self, up_to_page_index) -> None: + """ + Given a list of tokens and pages containing KV corresponding to these tokens, make these pages available to other requests. + + Associates the tokens with the pages, and mark them as done writing. + + It is assumed that hereafter, the calling request will not modify these pages, at least not the positions [0:len(tokens)]. + + This should be called when the request has finished writing to the pages. + """ + pass # the base implementation doesn't cache unfinished requests. + + def release_pages(self) -> None: + """ + Decrement reference count for these pages. When reference count is zero, they will be elegible for eviction. + + This should be called when the request has finished reading from the pages. 
+ """ + # in the base implementation, the pages can be owned by 1 request max, so they can be instantly release + if self._is_released: + logger.warning("Releasing already-released allocation") + return + self._cache.page_pool.release_pages(self._pages) + self._is_released = True + + def __repr__(self): + return f"BasePageAttentionCacheAllocation(pages={self._pages}, cache={self._cache})" class BasePagedAttentionCache: @@ -33,13 +115,13 @@ class BasePagedAttentionCache: - Reference counting prevents eviction of in-use pages """ - def __init__(self, page_pool, tokens_per_page): + def __init__(self, page_pool: PagePool, tokens_per_page: int): self.page_pool = page_pool self.tokens_per_page = tokens_per_page def acquire_pages_for_tokens( self, tokens: List[int], extra_token_slots: int = 1 - ) -> tuple[list[PageInfo], int]: + ) -> PageAllocation: """ Given a list of tokens, return a list of pages and a start position to continue generation from. @@ -57,24 +139,9 @@ def acquire_pages_for_tokens( pages_needed = math.ceil(token_count / self.tokens_per_page) pages = self.page_pool.acquire_free_pages(pages_needed) - n_cached_tokens = 0 - - return pages, n_cached_tokens - - def publish_pages(self, tokens, pages) -> None: - """ - Given a list of tokens and pages containing KV corresponding to these tokens, make these pages available to other requests. - - Associates the tokens with the pages, and mark them as done writing. - - It is assumed that hereafter, the calling request will not modify these pages, at least not the positions [0:len(tokens)]. - """ + if pages is None: + raise CacheAllocationFailure() - pass # the base implementation doesn't cache unfinished requests. + n_cached_tokens = 0 - def release_pages(self, tokens, pages): - """ - Decrement reference count for these pages. When reference count is zero, they will be elegible for eviction. 
- """ - # in the base implementation, the pages can be owned by 1 request max, so they can be instantly release - self.page_pool.release_pages(pages) + return BasePageAttentionCacheAllocation(pages, cache=self) diff --git a/shortfin/python/shortfin_apps/llm/components/messages.py b/shortfin/python/shortfin_apps/llm/components/messages.py index c3e6fe34b..d049b3229 100644 --- a/shortfin/python/shortfin_apps/llm/components/messages.py +++ b/shortfin/python/shortfin_apps/llm/components/messages.py @@ -9,7 +9,7 @@ import shortfin as sf import shortfin.array as sfnp -from .kvcache.base_attention_cache import BasePagedAttentionCache +from .kvcache.base_attention_cache import BasePagedAttentionCache, PageAllocation from .kvcache.page_pool import PageInfo @@ -43,7 +43,7 @@ def __init__(self, phase: InferencePhase, input_token_ids: list[int]): # Cache pages that have been locked for this request. self._cache: BasePagedAttentionCache | None = None - self.locked_pages: list[PageInfo] | None = None + self.allocation: PageAllocation | None = None def reset(self, phase: InferencePhase): """Resets all per request state in preparation for an subsequent execution.""" @@ -52,35 +52,22 @@ def reset(self, phase: InferencePhase): self.return_all_logits = False self.return_host_array = True self.result_logits = None + self.allocation.release_pages() + self.allocation = None def cache_page_indices(self, max_len: int) -> list[int]: - if not self.locked_pages: + if not self.allocation: return [] - indices = [p.index for p in self.locked_pages] - if len(indices) > max_len: - return indices[0:max_len] - return indices + indices = [p.index for p in self.allocation.get_page_list()] + return indices[:max_len] + + def publish_allocated_pages(self, up_to_page_index: int): + assert self.allocation + self.allocation.publish_pages(up_to_page_index) def free_cache_pages(self): - cache = self._cache - if cache: - pages = self.locked_pages - self._cache = None - self.locked_pages = None - 
cache.release_pages(self.input_token_ids, pages) - - def lock_initial_cache_pages( - self, cache: BasePagedAttentionCache, pages: list[PageInfo] - ): - assert not self._cache - self._cache = cache - self.locked_pages = pages - - def lock_new_cache_pages( - self, cache: BasePagedAttentionCache, pages: list[PageInfo] - ): - assert self._cache is cache - self.locked_pages.extend(pages) + if self.allocation: + self.allocation.release_pages() class StrobeMessage(sf.Message): diff --git a/shortfin/python/shortfin_apps/llm/components/service.py b/shortfin/python/shortfin_apps/llm/components/service.py index 8d3cc1424..2f942aec7 100644 --- a/shortfin/python/shortfin_apps/llm/components/service.py +++ b/shortfin/python/shortfin_apps/llm/components/service.py @@ -11,8 +11,12 @@ import shortfin as sf import shortfin.array as sfnp -from .kvcache.base_attention_cache import BasePagedAttentionCache -from .kvcache.page_pool import PagePoolConfig, PagePool +from .kvcache.base_attention_cache import ( + BasePagedAttentionCache, + CacheAllocationFailure, + PageAllocation, +) +from .kvcache.page_pool import PagePoolConfig, PagePool, PageInfo from .config_struct import ModelParams from .manager import SystemManager from .messages import InferenceExecRequest, InferencePhase, StrobeMessage @@ -229,16 +233,17 @@ def board_prefills(self, cache: BasePagedAttentionCache): len(prefill_request.input_token_ids) / self.page_seq_stride ) # allocate kv cache pages - pages, cache_hit_prefix_length = cache.acquire_pages_for_tokens( - prefill_request.input_token_ids, - extra_token_slots=0, # prefill needs no extra kvcache slots to write to - ) - if pages is None: + try: + allocation = cache.acquire_pages_for_tokens( + prefill_request.input_token_ids, + extra_token_slots=0, # prefill needs no extra kvcache slots to write to + ) + except CacheAllocationFailure: logger.debug("Cannot fulfill request for %d pages", needed_pages) continue - else: - logger.debug("Allocated %d cache pages to request", 
len(pages)) - prefill_request.lock_initial_cache_pages(cache, pages) + logger.debug(f"Successfully acquired allocation: {allocation}") + prefill_request.free_cache_pages() + prefill_request.allocation = allocation # Can flight this request. exec_process.exec_requests.append(prefill_request) @@ -266,26 +271,20 @@ def board_decodes(self, cache: BasePagedAttentionCache): if len(exec_process.exec_requests) >= self.ideal_batch_size: break incoming_token_count = len(decode_request.input_token_ids) - needed_pages = math.ceil( - (decode_request.start_position + incoming_token_count) - / self.page_seq_stride - ) - if needed_pages > len(decode_request.locked_pages): - # allocate kv cache pages - pages, cache_hit_prefix_length = cache.acquire_pages_for_tokens( + + try: + allocation = cache.acquire_pages_for_tokens( decode_request.input_token_ids, extra_token_slots=1, # need 1 extra slot to write result. ) - if pages is None: - logger.debug( - "Cannot fulfill decode request for %d pages", needed_pages - ) - continue - else: - logger.debug( - "Allocated %d cache pages to decode request", len(pages) - ) - decode_request.lock_new_cache_pages(cache, pages) + except CacheAllocationFailure: + logger.debug( + "Cannot fulfill request for %d tokens", + len(decode_request.input_token_ids), + ) + + decode_request.free_cache_pages() + decode_request.allocation = allocation # Can flight this request. exec_process.exec_requests.append(decode_request) @@ -438,6 +437,12 @@ async def run(self): # Invoke. Logits are of shape [bs, bsl, d]. (logits,) = await fn(*args, fiber=self.fiber) + # publish cache pages + for r in self.exec_requests: + total_tokens = r.start_position + len(r.input_token_ids) + number_of_complete_pages = total_tokens // seq_stride + r.publish_allocated_pages(number_of_complete_pages) + # Return results. 
 for i in range(req_count): req = self.exec_requests[i] From 947a7b5a6450f33c23f201d120228591140b9bd6 Mon Sep 17 00:00:00 2001 From: Cedar Date: Tue, 26 Nov 2024 11:28:39 -0800 Subject: [PATCH 02/18] docstring updates --- .../llm/components/kvcache/base_attention_cache.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py index 40d1b1197..e8520b44a 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py @@ -68,13 +68,9 @@ def get_page_list(self) -> List[PageInfo]: def publish_pages(self, up_to_page_index) -> None: """ - Given a list of tokens and pages containing KV corresponding to these tokens, make these pages available to other requests. + Release self.get_pages_list()[0:up_to_page_index] for reading by other requests. - Associates the tokens with the pages, and mark them as done writing. - - It is assumed that hereafter, the calling request will not modify these pages, at least not the positions [0:len(tokens)]. - - This should be called when the request has finished writing to the pages. + This should be called when writing completes, after each kernel invocation. """ pass # the base implementation doesn't cache unfinished requests. @@ -82,7 +78,11 @@ def release_pages(self) -> None: """ Decrement reference count for these pages. When reference count is zero, they will be elegible for eviction. - This should be called when the request has finished reading from the pages. + This should be called when the request has finished reading from the pages, and they are no longer needed. + + This does not immediately release the pages, but decrements the reference count. 
+ + Pages should become available for eviction when their reference count reaches zero & the pool runs out of free pages. """ # in the base implementation, the pages can be owned by 1 request max, so they can be instantly release if self._is_released: From 19aa8fd17b291ffaf3a802899c330af3f74063b4 Mon Sep 17 00:00:00 2001 From: Cedar Date: Wed, 27 Nov 2024 09:52:52 -0800 Subject: [PATCH 03/18] make pages a property --- .../kvcache/base_attention_cache.py | 49 +++++-------------- .../shortfin_apps/llm/components/messages.py | 2 +- 2 files changed, 12 insertions(+), 39 deletions(-) diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py index e8520b44a..73134903c 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py @@ -25,73 +25,48 @@ class CacheAllocationFailure(Exception): class PageAllocation(ABC): - """ - Abstract base class for page allocations in the cache. - Subclasses only need to implement the core allocation methods. - """ + """Abstract base class for page allocations in the cache.""" + @property @abstractmethod - def get_page_list(self) -> List[PageInfo]: + def pages(self) -> List[PageInfo]: """Returns the list of pages that were allocated.""" pass @abstractmethod def publish_pages(self, up_to_page_index) -> None: - """ - Makes self.get_page_list()[0:up_to_page_index] available to other requests after writing is complete. - Associates tokens with pages and marks them as ready for reading. - """ + """Makes pages[0:up_to_page_index] available to other requests.""" pass @abstractmethod def release_pages(self) -> None: - """ - Releases the allocation's reference to pages. - Pages become eligible for eviction when their reference count reaches zero. 
- """ + """Releases the allocation's reference to pages.""" pass class BasePageAttentionCacheAllocation(PageAllocation): - """ - Represents a page allocation in the cache, implementing the PageAllocation protocol. - """ + """Represents a page allocation in the cache.""" def __init__(self, pages: Iterable[PageInfo], cache: "BasePagedAttentionCache"): - # this should only be called by the associated attention cache & self._pages = tuple(pages) self._cache = cache self._is_released = False - def get_page_list(self) -> List[PageInfo]: - return list(self._pages) # return a list, as expected by service.py + @property + def pages(self) -> List[PageInfo]: + return list(self._pages) def publish_pages(self, up_to_page_index) -> None: - """ - Release self.get_pages_list()[0:up_to_page_index] for reading by other requests. - - This should be called when writing completes, after each kernel invocation. - """ - pass # the base implementation doesn't cache unfinished requests. + pass def release_pages(self) -> None: - """ - Decrement reference count for these pages. When reference count is zero, they will be elegible for eviction. - - This should be called when the request has finished reading from the pages, and they are no longer needed. - - This does not immediately release the pages, but decrements the reference count. - - Pages should become available for eviction when their reference count reaches zero & the pool runs out of free pages. 
- """ - # in the base implementation, the pages can be owned by 1 request max, so they can be instantly release if self._is_released: logger.warning("Releasing already-released allocation") return self._cache.page_pool.release_pages(self._pages) self._is_released = True - def __repr__(self): + def __rerp__(self) -> str: return f"BasePageAttentionCacheAllocation(pages={self._pages}, cache={self._cache})" @@ -142,6 +117,4 @@ def acquire_pages_for_tokens( if pages is None: raise CacheAllocationFailure() - n_cached_tokens = 0 - return BasePageAttentionCacheAllocation(pages, cache=self) diff --git a/shortfin/python/shortfin_apps/llm/components/messages.py b/shortfin/python/shortfin_apps/llm/components/messages.py index d049b3229..148feea99 100644 --- a/shortfin/python/shortfin_apps/llm/components/messages.py +++ b/shortfin/python/shortfin_apps/llm/components/messages.py @@ -58,7 +58,7 @@ def reset(self, phase: InferencePhase): def cache_page_indices(self, max_len: int) -> list[int]: if not self.allocation: return [] - indices = [p.index for p in self.allocation.get_page_list()] + indices = [p.index for p in self.allocation.pages] return indices[:max_len] def publish_allocated_pages(self, up_to_page_index: int): From af4d84b21e5a9a89cf1b8b0da7e1567d707afa3e Mon Sep 17 00:00:00 2001 From: Cedar Date: Wed, 27 Nov 2024 10:34:19 -0800 Subject: [PATCH 04/18] another small fix for efficiency. Thanks stephen! 
--- shortfin/python/shortfin_apps/llm/components/messages.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/shortfin/python/shortfin_apps/llm/components/messages.py b/shortfin/python/shortfin_apps/llm/components/messages.py index 148feea99..c03900782 100644 --- a/shortfin/python/shortfin_apps/llm/components/messages.py +++ b/shortfin/python/shortfin_apps/llm/components/messages.py @@ -58,8 +58,8 @@ def reset(self, phase: InferencePhase): def cache_page_indices(self, max_len: int) -> list[int]: if not self.allocation: return [] - indices = [p.index for p in self.allocation.pages] - return indices[:max_len] + indices = [p.index for p in self.allocation.pages[:max_len]] + return indices def publish_allocated_pages(self, up_to_page_index: int): assert self.allocation From f184c3231c14d6e830709a677ae9ade0fbd360ec Mon Sep 17 00:00:00 2001 From: Cedar Date: Wed, 27 Nov 2024 11:25:08 -0800 Subject: [PATCH 05/18] add unit tests --- .../shortfin_apps/llm/components/__init__.py | 0 .../llm/components/kvcache/__init__.py | 0 .../kvcache/base_attention_cache_test.py | 89 +++++++++++++++++++ 3 files changed, 89 insertions(+) create mode 100644 shortfin/python/shortfin_apps/llm/components/__init__.py create mode 100644 shortfin/python/shortfin_apps/llm/components/kvcache/__init__.py create mode 100644 shortfin/tests/apps/llm/components/kvcache/base_attention_cache_test.py diff --git a/shortfin/python/shortfin_apps/llm/components/__init__.py b/shortfin/python/shortfin_apps/llm/components/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/__init__.py b/shortfin/python/shortfin_apps/llm/components/kvcache/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/shortfin/tests/apps/llm/components/kvcache/base_attention_cache_test.py b/shortfin/tests/apps/llm/components/kvcache/base_attention_cache_test.py new file mode 100644 index 000000000..84ca36d59 --- /dev/null +++ 
b/shortfin/tests/apps/llm/components/kvcache/base_attention_cache_test.py @@ -0,0 +1,89 @@ +import pytest +import threading +import queue +import random +import time +from unittest.mock import Mock +from dataclasses import dataclass +from typing import List, Optional + +from shortfin_apps.llm.components.kvcache.base_attention_cache import ( + BasePagedAttentionCache, + BasePageAttentionCacheAllocation, + CacheAllocationFailure, +) +from shortfin_apps.llm.components.kvcache.page_pool import PagePool, PageInfo + + +class MockPagePool(PagePool): + def __init__(self, total_pages: int = 100): + self._queue = queue.Queue() + + for i in range(total_pages): + page = PageInfo(index=i, pool=self, token_offset=0, token_count=0) + self._queue.put(page) + + def acquire_free_pages(self, count: int) -> List[PageInfo]: + try: + return [self._queue.get_nowait() for _ in range(count)] + except queue.Empty: + return None + + def release_pages(self, pages): + for page in pages: + self._queue.put(page) + + +@pytest.fixture +def page_pool(): + return MockPagePool(total_pages=10) + + +@pytest.fixture +def cache(page_pool): + return BasePagedAttentionCache(page_pool=page_pool, tokens_per_page=16) + + +@pytest.fixture +def page_pool(): + return MockPagePool(total_pages=10) + + +@pytest.fixture +def cache(page_pool): + """Create cache with 16 tokens per page""" + return BasePagedAttentionCache(page_pool=page_pool, tokens_per_page=16) + + +def test_allocation_sizes(cache): + test_cases = [ + ([], 0), # Empty token list + (list(range(8)), 1), # Partial page + (list(range(16)), 1), # Exact page + (list(range(17)), 2), # Just over one page + (list(range(32)), 2), # Multiple exact pages + (list(range(33)), 3), # Multiple pages with remainder + ] + + for tokens, expected_pages in test_cases: + allocation = cache.acquire_pages_for_tokens(tokens) + pages = allocation.pages + assert len(pages) == expected_pages + allocation.release_pages() + + +def test_concurrent_access(cache): + def 
worker(results: List): + allocation = cache.acquire_pages_for_tokens(list(range(16))) + results.append(len(allocation.pages)) + allocation.release_pages() + + results = [] + threads = [threading.Thread(target=worker, args=(results,)) for _ in range(5)] + + for t in threads: + t.start() + for t in threads: + t.join() + + assert all(r == 1 for r in results) From 13af91371d2d482cb8581ee412add2220fc173c4 Mon Sep 17 00:00:00 2001 From: Cedar Date: Wed, 27 Nov 2024 11:28:17 -0800 Subject: [PATCH 06/18] avoid magic numbers --- .../kvcache/base_attention_cache_test.py | 27 ++++++++++--------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/shortfin/tests/apps/llm/components/kvcache/base_attention_cache_test.py b/shortfin/tests/apps/llm/components/kvcache/base_attention_cache_test.py index 84ca36d59..b3e2f5d77 100644 --- a/shortfin/tests/apps/llm/components/kvcache/base_attention_cache_test.py +++ b/shortfin/tests/apps/llm/components/kvcache/base_attention_cache_test.py @@ -14,9 +14,12 @@ ) from shortfin_apps.llm.components.kvcache.page_pool import PagePool, PageInfo +TEST_PAGE_SIZE = 16 +TEST_POOL_CAPACITY = 10 + class MockPagePool(PagePool): - def __init__(self, total_pages: int = 100): + def __init__(self, total_pages: int): self._queue = queue.Queue() for i in range(total_pages): @@ -36,33 +39,33 @@ def release_pages(self, pages): @pytest.fixture def page_pool(): - return MockPagePool(total_pages=10) + return MockPagePool(total_pages=TEST_POOL_CAPACITY) @pytest.fixture def cache(page_pool): - return BasePagedAttentionCache(page_pool=page_pool, tokens_per_page=16) + return BasePagedAttentionCache(page_pool=page_pool, tokens_per_page=TEST_PAGE_SIZE) @pytest.fixture def page_pool(): - return MockPagePool(total_pages=10) + return MockPagePool(total_pages=TEST_POOL_CAPACITY) @pytest.fixture def cache(page_pool): - """Create cache with 16 tokens per page""" - return BasePagedAttentionCache(page_pool=page_pool, tokens_per_page=16) + """Create cache with 
TEST_PAGE_SIZE tokens per page""" + return BasePagedAttentionCache(page_pool=page_pool, tokens_per_page=TEST_PAGE_SIZE) def test_allocation_sizes(cache): test_cases = [ ([], 0), # Empty token list - (list(range(8)), 1), # Partial page - (list(range(16)), 1), # Exact page - (list(range(17)), 2), # Just over one page - (list(range(32)), 2), # Multiple exact pages - (list(range(33)), 3), # Multiple pages with remainder + (list(range(TEST_PAGE_SIZE // 2)), 1), # Partial page + (list(range(TEST_PAGE_SIZE)), 1), # Exact page + (list(range(TEST_PAGE_SIZE + 1)), 2), # Just over one page + (list(range(TEST_PAGE_SIZE * 2)), 2), # Multiple exact pages + (list(range(TEST_PAGE_SIZE * 2 + 1)), 3), # Multiple pages with remainder ] for tokens, expected_pages in test_cases: @@ -74,7 +77,7 @@ def test_allocation_sizes(cache): def test_concurrent_access(cache): def worker(results: List): - allocation = cache.acquire_pages_for_tokens(list(range(16))) + allocation = cache.acquire_pages_for_tokens(list(range(TEST_PAGE_SIZE))) results.append(len(allocation.pages)) allocation.release_pages() From c3c52d4c0f7343d80ad52ae4e435dcfcf7dfbc54 Mon Sep 17 00:00:00 2001 From: Cedar Date: Wed, 27 Nov 2024 11:36:38 -0800 Subject: [PATCH 07/18] very extensive concurrency testing --- .../kvcache/base_attention_cache_test.py | 137 +++++++++++++----- 1 file changed, 98 insertions(+), 39 deletions(-) diff --git a/shortfin/tests/apps/llm/components/kvcache/base_attention_cache_test.py b/shortfin/tests/apps/llm/components/kvcache/base_attention_cache_test.py index b3e2f5d77..b679dbf6f 100644 --- a/shortfin/tests/apps/llm/components/kvcache/base_attention_cache_test.py +++ b/shortfin/tests/apps/llm/components/kvcache/base_attention_cache_test.py @@ -3,9 +3,10 @@ import queue import random import time +from collections import defaultdict from unittest.mock import Mock from dataclasses import dataclass -from typing import List, Optional +from typing import List, Optional, Set from 
shortfin_apps.llm.components.kvcache.base_attention_cache import ( BasePagedAttentionCache, @@ -21,7 +22,6 @@ class MockPagePool(PagePool): def __init__(self, total_pages: int): self._queue = queue.Queue() - for i in range(total_pages): page = PageInfo(index=i, pool=self, token_offset=0, token_count=0) self._queue.put(page) @@ -47,46 +47,105 @@ def cache(page_pool): return BasePagedAttentionCache(page_pool=page_pool, tokens_per_page=TEST_PAGE_SIZE) -@pytest.fixture -def page_pool(): - return MockPagePool(total_pages=TEST_POOL_CAPACITY) - - -@pytest.fixture -def cache(page_pool): - """Create cache with TEST_PAGE_SIZE tokens per page""" - return BasePagedAttentionCache(page_pool=page_pool, tokens_per_page=TEST_PAGE_SIZE) - - -def test_allocation_sizes(cache): - test_cases = [ - ([], 0), # Empty token list - (list(range(TEST_PAGE_SIZE // 2)), 1), # Partial page - (list(range(TEST_PAGE_SIZE)), 1), # Exact page - (list(range(TEST_PAGE_SIZE + 1)), 2), # Just over one page - (list(range(TEST_PAGE_SIZE * 2)), 2), # Multiple exact pages - (list(range(TEST_PAGE_SIZE * 2 + 1)), 3), # Multiple pages with remainder - ] - - for tokens, expected_pages in test_cases: - allocation = cache.acquire_pages_for_tokens(tokens) - pages = allocation.pages - assert len(pages) == expected_pages - allocation.release_pages() - - -def test_concurrent_access(cache): - def worker(results: List): - allocation = cache.acquire_pages_for_tokens(list(range(TEST_PAGE_SIZE))) - results.append(len(allocation.pages)) - allocation.release_pages() - - results = [] - threads = [threading.Thread(target=worker, args=(results,)) for _ in range(5)] +@pytest.mark.parametrize( + "tokens,expected_pages,test_name", + [ + ([], 0, "empty_token_list"), + (list(range(TEST_PAGE_SIZE // 2)), 1, "partial_page"), + (list(range(TEST_PAGE_SIZE)), 1, "exact_page"), + (list(range(TEST_PAGE_SIZE + 1)), 2, "just_over_one_page"), + (list(range(TEST_PAGE_SIZE * 2)), 2, "multiple_exact_pages"), + (list(range(TEST_PAGE_SIZE * 2 + 
1)), 3, "multiple_pages_with_remainder"), + (list(range(TEST_PAGE_SIZE * 3)), 3, "three_exact_pages"), + (list(range(1)), 1, "single_token"), + (list(range(TEST_PAGE_SIZE - 1)), 1, "almost_full_page"), + ], +) +def test_allocation_sizes(cache, tokens, expected_pages, test_name): + allocation = cache.acquire_pages_for_tokens(tokens) + pages = allocation.pages + assert len(pages) == expected_pages, f"Failed for case: {test_name}" + allocation.release_pages() + + +@pytest.mark.parametrize( + "num_workers,pages_per_worker,expect_failure", + [ + (2, 1, False), # Basic concurrent access + (5, 1, False), # Higher concurrency, single page + (3, 2, False), # Multiple pages per worker + (2, 3, False), # More pages than workers, but within capacity + (TEST_POOL_CAPACITY, 1, False), # Max capacity single pages + (TEST_POOL_CAPACITY // 2, 2, False), # Max capacity multiple pages + (4, 3, True), # 12 pages needed, exceeds capacity + (TEST_POOL_CAPACITY + 1, 1, True), # More workers than capacity + (TEST_POOL_CAPACITY // 2, 3, True), # Exceeds capacity with multiple pages + ], +) +def test_concurrent_page_allocation( + cache, num_workers, pages_per_worker, expect_failure +): + allocated_pages = defaultdict(set) + errors = [] + allocations = [] + + def worker(worker_id: int): + try: + tokens = list(range(TEST_PAGE_SIZE * pages_per_worker)) + allocation = cache.acquire_pages_for_tokens(tokens) + allocations.append(allocation) + allocated_pages[worker_id] = {page.index for page in allocation.pages} + time.sleep(random.uniform(0.001, 0.01)) + except CacheAllocationFailure as e: + errors.append(e) + except Exception as e: + pytest.fail(f"Unexpected error: {e}") + + threads = [threading.Thread(target=worker, args=(i,)) for i in range(num_workers)] for t in threads: t.start() for t in threads: t.join() - assert all(r == 1 for r in results) + if expect_failure: + assert len(errors) > 0, "Expected at least one CacheAllocationFailure" + else: + assert not errors, f"Workers encountered 
errors: {errors}" + for worker_id, pages in allocated_pages.items(): + assert ( + len(pages) == pages_per_worker + ), f"Worker {worker_id} got {len(pages)} pages, expected {pages_per_worker}" + + all_pages = set() + for pages in allocated_pages.values(): + assert not ( + pages & all_pages + ), f"Found duplicate page allocation: {pages & all_pages}" + all_pages.update(pages) + + for allocation in allocations: + allocation.release_pages() + + +@pytest.mark.parametrize( + "total_pages_needed", + [ + TEST_POOL_CAPACITY + 1, # Just over capacity + TEST_POOL_CAPACITY * 2, # Double capacity + ], +) +def test_allocation_failure_when_exhausted(cache, total_pages_needed): + successful_allocations = [] + + try: + tokens = list(range(TEST_PAGE_SIZE * total_pages_needed)) + allocation = cache.acquire_pages_for_tokens(tokens) + successful_allocations.append(allocation) + except CacheAllocationFailure as e: + pass + else: + pytest.fail("Expected CacheAllocationFailure was not raised") + finally: + for alloc in successful_allocations: + alloc.release_pages() From cd923b5371c7f4a790c7a1f75911816c4a125842 Mon Sep 17 00:00:00 2001 From: Cedar Date: Wed, 27 Nov 2024 11:43:57 -0800 Subject: [PATCH 08/18] prettify the other part too --- .../kvcache/base_attention_cache_test.py | 58 +++++++++++-------- 1 file changed, 33 insertions(+), 25 deletions(-) diff --git a/shortfin/tests/apps/llm/components/kvcache/base_attention_cache_test.py b/shortfin/tests/apps/llm/components/kvcache/base_attention_cache_test.py index b679dbf6f..3fc749668 100644 --- a/shortfin/tests/apps/llm/components/kvcache/base_attention_cache_test.py +++ b/shortfin/tests/apps/llm/components/kvcache/base_attention_cache_test.py @@ -47,20 +47,22 @@ def cache(page_pool): return BasePagedAttentionCache(page_pool=page_pool, tokens_per_page=TEST_PAGE_SIZE) +# fmt: off @pytest.mark.parametrize( - "tokens,expected_pages,test_name", - [ - ([], 0, "empty_token_list"), - (list(range(TEST_PAGE_SIZE // 2)), 1, "partial_page"), - 
(list(range(TEST_PAGE_SIZE)), 1, "exact_page"), - (list(range(TEST_PAGE_SIZE + 1)), 2, "just_over_one_page"), - (list(range(TEST_PAGE_SIZE * 2)), 2, "multiple_exact_pages"), - (list(range(TEST_PAGE_SIZE * 2 + 1)), 3, "multiple_pages_with_remainder"), - (list(range(TEST_PAGE_SIZE * 3)), 3, "three_exact_pages"), - (list(range(1)), 1, "single_token"), - (list(range(TEST_PAGE_SIZE - 1)), 1, "almost_full_page"), - ], + "tokens,expected_pages,test_name", + [ # Tokens Pages Name + ([], 0, "empty_token_list"), + (list(range(TEST_PAGE_SIZE // 2)), 1, "partial_page"), + (list(range(TEST_PAGE_SIZE)), 1, "exact_page"), + (list(range(TEST_PAGE_SIZE + 1)), 2, "just_over_one_page"), + (list(range(TEST_PAGE_SIZE * 2)), 2, "multiple_exact_pages"), + (list(range(TEST_PAGE_SIZE * 2 + 1)), 3, "multiple_pages_with_remainder"), + (list(range(TEST_PAGE_SIZE * 3)), 3, "three_exact_pages"), + (list(range(1)), 1, "single_token"), + (list(range(TEST_PAGE_SIZE - 1)), 1, "almost_full_page"), + ], ) +# fmt: on def test_allocation_sizes(cache, tokens, expected_pages, test_name): allocation = cache.acquire_pages_for_tokens(tokens) pages = allocation.pages @@ -68,22 +70,28 @@ def test_allocation_sizes(cache, tokens, expected_pages, test_name): allocation.release_pages() +# fmt: off @pytest.mark.parametrize( - "num_workers,pages_per_worker,expect_failure", - [ - (2, 1, False), # Basic concurrent access - (5, 1, False), # Higher concurrency, single page - (3, 2, False), # Multiple pages per worker - (2, 3, False), # More pages than workers, but within capacity - (TEST_POOL_CAPACITY, 1, False), # Max capacity single pages - (TEST_POOL_CAPACITY // 2, 2, False), # Max capacity multiple pages - (4, 3, True), # 12 pages needed, exceeds capacity - (TEST_POOL_CAPACITY + 1, 1, True), # More workers than capacity - (TEST_POOL_CAPACITY // 2, 3, True), # Exceeds capacity with multiple pages - ], + "num_workers,pages_per_worker,expect_failure,case_name", + [ # Workers Pages Failure Case name + (2, 1, False, 
"basic_concurrent"), # Basic concurrent access + (5, 1, False, "high_concurrency"), # Higher concurrency, single page + (3, 2, False, "multi_page"), # Multiple pages per worker + (2, 3, False, "more_pages"), # More pages than workers, within capacity + (TEST_POOL_CAPACITY, 1, False, "max_capacity"), # Max capacity single pages + (TEST_POOL_CAPACITY // 2, 2, False, "max_capacity_multi"), # Max capacity multiple pages + (4, 3, True , "exceeds_total"), # 12 pages needed, exceeds capacity + (TEST_POOL_CAPACITY + 1, 1, True , "exceeds_workers"), # More workers than capacity + (TEST_POOL_CAPACITY // 2, 3, True , "exceeds_with_multi"), # Exceeds capacity with multiple pages + ], ) +# fmt: on def test_concurrent_page_allocation( - cache, num_workers, pages_per_worker, expect_failure + cache, + num_workers, + pages_per_worker, + expect_failure, + case_name, ): allocated_pages = defaultdict(set) errors = [] From 09d83e61d7512481480027aaf9677e8a447aa169 Mon Sep 17 00:00:00 2001 From: Cedar Date: Wed, 27 Nov 2024 11:44:37 -0800 Subject: [PATCH 09/18] consistent naming for case_name --- .../llm/components/kvcache/base_attention_cache_test.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/shortfin/tests/apps/llm/components/kvcache/base_attention_cache_test.py b/shortfin/tests/apps/llm/components/kvcache/base_attention_cache_test.py index 3fc749668..113da6912 100644 --- a/shortfin/tests/apps/llm/components/kvcache/base_attention_cache_test.py +++ b/shortfin/tests/apps/llm/components/kvcache/base_attention_cache_test.py @@ -49,8 +49,8 @@ def cache(page_pool): # fmt: off @pytest.mark.parametrize( - "tokens,expected_pages,test_name", - [ # Tokens Pages Name + "tokens,expected_pages,case_name", + [ # Tokens Pages Case Name ([], 0, "empty_token_list"), (list(range(TEST_PAGE_SIZE // 2)), 1, "partial_page"), (list(range(TEST_PAGE_SIZE)), 1, "exact_page"), @@ -63,10 +63,10 @@ def cache(page_pool): ], ) # fmt: on -def test_allocation_sizes(cache, tokens, 
expected_pages, test_name): +def test_allocation_sizes(cache, tokens, expected_pages, case_name): allocation = cache.acquire_pages_for_tokens(tokens) pages = allocation.pages - assert len(pages) == expected_pages, f"Failed for case: {test_name}" + assert len(pages) == expected_pages, f"Failed for case: {case_name}" allocation.release_pages() From 970dbc112e0ab194c5c0acb2b715a357e253bf95 Mon Sep 17 00:00:00 2001 From: Cedar Date: Wed, 27 Nov 2024 15:14:08 -0800 Subject: [PATCH 10/18] add trie implementation --- .../kvcache/trie_attention_cache.py | 321 ++++++++++++++++++ 1 file changed, 321 insertions(+) create mode 100644 shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py new file mode 100644 index 000000000..7280232b2 --- /dev/null +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py @@ -0,0 +1,321 @@ +from typing import Dict, Set, List, Tuple, Optional +from dataclasses import dataclass +import time +import math +import heapq +from .page_pool import PagePool, PageInfo +from .base_attention_cache import ( + BasePagedAttentionCache, + PageAllocation, + CacheAllocationFailure, +) + + +@dataclass +class TrieNode: + """Node of the block trie for paged attention cache. + + Each node represents a page of tokens in the cache, with edges representing + token sequences that can follow. This allows prefix sharing between sequences + that have common prefixes. + + Attributes: + tokens: Tuple of tokens stored in this node's page + page: PageInfo object containing the actual cache page + children: Dict mapping token sequences to child nodes + parent: Parent node in the trie (None for root) + ref_count: Number of active references to this node + access_time: Last access timestamp for LRU eviction + """ + + tokens: Tuple[int, ...] 
+ page: PageInfo + children: Optional[Dict[Tuple[int, ...], "TrieNode"]] = None + parent: Optional["TrieNode"] = None + ref_count: int = 0 + access_time: float = 0.0 + + def __post_init__(self) -> None: + """Initialize children dict and access time if not provided.""" + if self.children is None: + self.children = {} + self.access_time = time.monotonic() + + def create_child(self, tokens: Tuple[int, ...], page: PageInfo) -> "TrieNode": + """Create a new child node with the given tokens and page. + + Args: + tokens: Sequence of tokens for the new node + page: PageInfo for the new node's cache page + + Returns: + The newly created child node + """ + new_node = TrieNode(tokens=tokens, page=page, parent=self) + self.children[tokens] = new_node + return new_node + + def unlink(self) -> None: + """Remove this node from its parent's children.""" + if self.parent is not None: + del self.parent.children[self.tokens] + self.parent = None + + def __hash__(self) -> int: + """Nodes are uniquely identified by their memory address.""" + return id(self) + + def __eq__(self, other: object) -> bool: + """Nodes are equal only if they are the same object.""" + return self is other + + +class TriePageAttentionCacheAllocation(PageAllocation): + """Represents a page allocation in the trie-based cache. + + Tracks both previously cached pages and newly allocated pages, + implementing the PageAllocation protocol for the trie cache. 
+ + Attributes: + cache: The parent cache this allocation belongs to + tokens: Complete sequence of tokens this allocation represents + last_cached_node: Last matched node in the trie + cached_pages: List of pages already in cache + newly_acquired_pages: List of newly allocated pages + start_index: Index where cached tokens end and new tokens begin + """ + + def __init__( + self, + cache: "TriePagedAttentionCache", + tokens: List[int], + last_cached_node: TrieNode, + cached_pages: List[PageInfo], + newly_acquired_pages: List[PageInfo], + start_index: int, + ): + self.cache = cache + self.tokens = tokens + self.last_cached_node = last_cached_node + self.cached_pages = cached_pages + self.newly_acquired_pages = newly_acquired_pages + self.start_index = start_index + self._is_released = False + + @property + def pages(self) -> List[PageInfo]: + """List all pages in this allocation, both cached and new. + + Returns: + Combined list of cached and newly acquired pages + """ + return self.cached_pages + self.newly_acquired_pages + + def publish_pages(self, up_to_page_index: int) -> None: + """Make pages available in the cache up to the specified index. 
+ + Args: + up_to_page_index: Number of pages to publish, starting from the beginning + """ + tokens_per_page = self.cache.tokens_per_page + + publish_token_count = min(len(self.tokens), up_to_page_index * tokens_per_page) + + cur_node = self.last_cached_node + first_uncached_page_index = self.start_index // tokens_per_page + + uncached_tokens = [ + tuple(self.tokens[i : i + tokens_per_page]) + for i in range( + first_uncached_page_index * tokens_per_page, + publish_token_count, + tokens_per_page, + ) + ] + + uncached_pages = self.newly_acquired_pages[: len(uncached_tokens)] + + for token_block, page in zip(uncached_tokens, uncached_pages): + new_node = cur_node.create_child(token_block, page) + cur_node = new_node + + if cur_node is not self.cache.root: + self.cache.leaves.add(cur_node) + + def release_pages(self) -> None: + """Release the allocation's reference to its pages. + + Decrements reference count of the last cached node. When count + reaches zero, the node becomes eligible for eviction. + """ + if self._is_released: + return + + self.last_cached_node.ref_count -= 1 + self._is_released = True + + +class TriePagedAttentionCache(BasePagedAttentionCache): + """Trie-based paged attention cache implementation. + + Implements prefix sharing through a trie structure where each node + represents a page of tokens. Common prefixes between sequences share + the same nodes/pages, reducing memory usage. + + Attributes: + root: Root node of the trie + leaves: Set of leaf nodes for efficient eviction + page_pool: Pool providing page allocations + tokens_per_page: Number of tokens that fit in each page + """ + + def __init__(self, page_pool: PagePool, tokens_per_page: int): + """Initialize the trie cache. 
+ + Args: + page_pool: Pool to allocate pages from + tokens_per_page: Number of tokens per page + + Raises: + ValueError: If tokens_per_page <= 0 + """ + if tokens_per_page <= 0: + raise ValueError("tokens_per_page must be positive") + + super().__init__(page_pool, tokens_per_page) + + # Create root node with dummy page + dummy_page = PageInfo( + index=0, # Root uses reserved index 0 + pool=self.page_pool, + token_offset=0, + token_count=0, + ) + self.root = TrieNode(tokens=tuple(), page=dummy_page) + self.leaves: Set[TrieNode] = set() + + def _match(self, tokens: List[int]) -> Tuple[TrieNode, List[PageInfo]]: + """ + Find the longest prefix match in the trie. + + Walks the trie following the token sequence as far as possible, + collecting matched pages along the way. + + Args: + tokens: Sequence of tokens to match + + Returns: + Tuple of (last matched node, list of matched pages) + """ + tokens = tuple(tokens) + matched_pages = [] + cur = self.root + + for i in range(0, len(tokens), self.tokens_per_page): + token_block = tokens[i : i + self.tokens_per_page] + + if token_block not in cur.children: + break + cur = cur.children[token_block] + cur.access_time = time.monotonic() + matched_pages.append(cur.page) + + return cur, matched_pages + + def acquire_pages_for_tokens( + self, + tokens: List[int], + extra_token_slots: int = 1, + ) -> PageAllocation: + """Acquire pages for a sequence of tokens. + + Attempts to reuse existing cached pages where possible through + prefix matching, allocating new pages only for the uncached suffix. 
+ + Args: + tokens: Sequence of tokens needing pages + extra_token_slots: Additional token slots to allocate beyond tokens + + Returns: + PageAllocation containing both cached and newly allocated pages + + Raises: + CacheAllocationFailure: If unable to allocate required pages + """ + tokens = tuple(tokens) + + cur_node, matched_pages = self._match(tokens) + cur_node.ref_count += 1 + + n_cached_tokens = len(matched_pages) * self.tokens_per_page + remaining_length = len(tokens) - n_cached_tokens + extra_token_slots + n_empty_pages = math.ceil(remaining_length / self.tokens_per_page) + + new_pages = self.page_pool.acquire_free_pages(n_empty_pages) + + if new_pages is not None: + return TriePageAttentionCacheAllocation( + cache=self, + tokens=tokens, + last_cached_node=cur_node, + cached_pages=matched_pages, + newly_acquired_pages=new_pages, + start_index=n_cached_tokens, + ) + + # Try eviction + self._evict_pages(n_empty_pages - len(self.page_pool.free_pages)) + new_pages = self.page_pool.acquire_free_pages(n_empty_pages) + + if new_pages is None: + raise CacheAllocationFailure( + "Failed to acquire pages even after attempting eviction from LRU leaves" + ) + + return TriePageAttentionCacheAllocation( + cache=self, + tokens=tokens, + last_cached_node=cur_node, + cached_pages=matched_pages, + newly_acquired_pages=new_pages, + start_index=n_cached_tokens, + ) + + def _evict_pages(self, max_pages: int) -> int: + """Evict up to max_pages pages using LRU strategy. + + Evicts from unreferenced leaf nodes first, working up the trie + as nodes become childless. 
+ + Args: + max_pages: Maximum number of pages to evict + + Returns: + Number of pages actually evicted + """ + pages_to_evict = [] + + # Initialize heap with unreferenced leaves + unused_leaf_heap = [ + (leaf.access_time, leaf) for leaf in self.leaves if leaf.ref_count == 0 + ] + heapq.heapify(unused_leaf_heap) + + # Evict least recently used nodes + while unused_leaf_heap and len(pages_to_evict) < max_pages: + _, leaf = heapq.heappop(unused_leaf_heap) + pages_to_evict.append(leaf.page) + parent = leaf.parent + leaf.unlink() + self.leaves.remove(leaf) + + # If parent becomes childless, it becomes a leaf + if parent is not self.root and not parent.children: + self.leaves.add(parent) + if parent.ref_count == 0: + heapq.heappush(unused_leaf_heap, (parent.access_time, parent)) + + if pages_to_evict: + self.page_pool.release_pages(pages_to_evict) + + return len(pages_to_evict) From 05034492242b2412fc8585ddf65f2ce372489855 Mon Sep 17 00:00:00 2001 From: Cedar Date: Wed, 27 Nov 2024 16:13:33 -0800 Subject: [PATCH 11/18] add unit tests and fix numerous small problems --- .../kvcache/base_attention_cache.py | 2 +- .../llm/components/kvcache/page_pool.py | 12 +- .../kvcache/trie_attention_cache.py | 10 +- .../kvcache/trie_attention_cache_test.py | 395 ++++++++++++++++++ 4 files changed, 409 insertions(+), 10 deletions(-) create mode 100644 shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py index 73134903c..c86379368 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py @@ -63,7 +63,7 @@ def release_pages(self) -> None: if self._is_released: logger.warning("Releasing already-released allocation") return - self._cache.page_pool.release_pages(self._pages) + 
self._cache.page_pool.free_pages(self._pages) self._is_released = True def __rerp__(self) -> str: diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py b/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py index 1686370c0..0c2cb3f67 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py @@ -86,7 +86,7 @@ def __init__(self, *, devices: Sequence[sf.ScopedDevice], config: PagePoolConfig for i in range(self.config.alloc_page_count) ] - self.attn_page_free = list(self.attn_page_entries) + self.available_pages = list(self.attn_page_entries) # Initialize a page table on each device. page_table_shape = [ @@ -108,14 +108,14 @@ def __init__(self, *, devices: Sequence[sf.ScopedDevice], config: PagePoolConfig def acquire_free_pages(self, count: int) -> list[PageInfo] | None: with self._lock: - available = len(self.attn_page_free) + available = len(self.available_pages) if count > available: return None - return [self.attn_page_free.pop() for _ in range(count)] + return [self.available_pages.pop() for _ in range(count)] - def release_pages(self, pages: list[PageInfo]): + def free_pages(self, pages: list[PageInfo]): with self._lock: - self.attn_page_free.extend(pages) + self.available_pages.extend(pages) def copy_page(self, src_page: PageInfo) -> PageInfo: """ @@ -148,7 +148,7 @@ def copy_page(self, src_page: PageInfo) -> PageInfo: def __repr__(self): # No need to lock for repr (list is internally synchronized). 
- free_pages = len(self.attn_page_free) + free_pages = len(self.available_pages) total_pages = len(self.attn_page_entries) return ( f"PagePool({total_pages - free_pages}/{total_pages} pages in use: " diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py index 7280232b2..f4614c8d8 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py @@ -142,6 +142,10 @@ def publish_pages(self, up_to_page_index: int) -> None: if cur_node is not self.cache.root: self.cache.leaves.add(cur_node) + cur_node.ref_count += 1 + self.last_cached_node.ref_count -= 1 + self.last_cached_node = cur_node + def release_pages(self) -> None: """Release the allocation's reference to its pages. @@ -225,7 +229,7 @@ def _match(self, tokens: List[int]) -> Tuple[TrieNode, List[PageInfo]]: def acquire_pages_for_tokens( self, tokens: List[int], - extra_token_slots: int = 1, + extra_token_slots: int = 0, ) -> PageAllocation: """Acquire pages for a sequence of tokens. 
@@ -264,7 +268,7 @@ def acquire_pages_for_tokens( ) # Try eviction - self._evict_pages(n_empty_pages - len(self.page_pool.free_pages)) + self._evict_pages(n_empty_pages - len(self.page_pool.available_pages)) new_pages = self.page_pool.acquire_free_pages(n_empty_pages) if new_pages is None: @@ -316,6 +320,6 @@ def _evict_pages(self, max_pages: int) -> int: heapq.heappush(unused_leaf_heap, (parent.access_time, parent)) if pages_to_evict: - self.page_pool.release_pages(pages_to_evict) + self.page_pool.free_pages(pages_to_evict) return len(pages_to_evict) diff --git a/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py b/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py new file mode 100644 index 000000000..72f1d666e --- /dev/null +++ b/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py @@ -0,0 +1,395 @@ +import pytest +from typing import List, Tuple +import shortfin as sf +import shortfin.array as sfnp +from unittest.mock import Mock, MagicMock +import threading +import time +from dataclasses import dataclass + +from shortfin_apps.llm.components.kvcache.trie_attention_cache import ( + TriePagedAttentionCache, +) +from shortfin_apps.llm.components.kvcache.base_attention_cache import ( + CacheAllocationFailure, +) +from shortfin_apps.llm.components.kvcache.page_pool import ( + PagePool, + PageInfo, + PagePoolConfig, +) + + +# Test constants +TEST_PAGE_SIZE = 16 # Tokens per page +TEST_POOL_CAPACITY = 10 + + +@dataclass +class TokenSequence: + """Helper class for test parameterization""" + + tokens: List[int] + description: str + expected_pages: int + expected_cached: int = 0 + + def __str__(self): + return self.description + + +class MockScopedDevice: + """A proper mock for ScopedDevice that implements required interface""" + + def __init__(self): + self._mock = Mock(spec=sf.ScopedDevice) + # Add any necessary attributes/methods the real ScopedDevice has + self._mock.device_id = 0 + self._mock.device_type = 
"CPU" + + def __repr__(self): + return f"MockScopedDevice(device_id={self._mock.device_id})" + + +@pytest.fixture +def mock_device_array(): + """Create mock device array with proper interface implementation""" + + class MockDeviceArray: + def __init__(self): + self.shape = None + self.dtype = None + + def view(self, *args): + return Mock() + + def copy_from(self, src): + pass + + return MockDeviceArray() + + +@pytest.fixture +def mock_device(): + """Create properly structured mock device""" + return MockScopedDevice() + + +@pytest.fixture +def page_pool(mock_device, mock_device_array): + """Create PagePool with properly structured mock components""" + # Mock the device array creation + original_for_device = sf.array.device_array.for_device + + def mock_for_device(device, shape, dtype): + mock_array = mock_device_array + mock_array.shape = shape + mock_array.dtype = dtype + return mock_array + + sf.array.device_array.for_device = mock_for_device + + try: + config = PagePoolConfig( + dtype=sfnp.float16, + alloc_page_count=TEST_POOL_CAPACITY, + paged_kv_block_size_elements=128, + ) + + pool = PagePool(devices=[mock_device], config=config) + pool.page_tables = [mock_device_array] + return pool + finally: + # Restore original function + sf.array.device_array.for_device = original_for_device + + +@pytest.fixture +def trie_cache(page_pool): + """Create TriePagedAttentionCache instance""" + return TriePagedAttentionCache(page_pool=page_pool, tokens_per_page=TEST_PAGE_SIZE) + + +@pytest.fixture +def published_sequence(trie_cache): + """Helper fixture that returns a function to publish token sequences""" + + def _publish_sequence(tokens: List[int]) -> None: + alloc = trie_cache.acquire_pages_for_tokens(tokens, extra_token_slots=0) + alloc.publish_pages(len(alloc.pages)) + alloc.release_pages() + + return _publish_sequence + + +def print_tree_state(cache, prefix=""): + """Helper function to print current tree state in a readable format""" + if not hasattr(cache, "root"): + 
print(f"{prefix}Unable to access trie structure") + return + + def node_info(node): + token_str = f"tokens={list(node.tokens) if node.tokens else 'root'}" + return f"{token_str}, ref_count={node.ref_count}, page_index={node.page.index}" + + def print_node(node, depth=0): + indent = " " * depth + print(f"{prefix}{indent}- {node_info(node)}") + if node.children: + for child in node.children.values(): + print_node(child, depth + 1) + + print(f"{prefix}Tree state:") + print_node(cache.root) + + +# Test sequences for parameterization +basic_sequences = [ + TokenSequence(tokens=[], description="empty_sequence", expected_pages=0), + TokenSequence( + tokens=list(range(TEST_PAGE_SIZE // 2)), + description="partial_page", + expected_pages=1, + ), + TokenSequence( + tokens=list(range(TEST_PAGE_SIZE)), description="exact_page", expected_pages=1 + ), + TokenSequence( + tokens=list(range(TEST_PAGE_SIZE + 1)), + description="overflow_page", + expected_pages=2, + ), + TokenSequence( + tokens=list(range(TEST_PAGE_SIZE * 2)), + description="multiple_pages", + expected_pages=2, + ), +] + +reuse_sequences = [ + (list(range(TEST_PAGE_SIZE)), list(range(TEST_PAGE_SIZE)), "exact_match", 1, 1), + ( + list(range(TEST_PAGE_SIZE * 2)), + list(range(TEST_PAGE_SIZE * 2)), + "multi_page_match", + 2, + 2, + ), + ( + list(range(TEST_PAGE_SIZE * 2)), + list(range(TEST_PAGE_SIZE)) + list(range(100, 100 + TEST_PAGE_SIZE)), + "prefix_match", + 2, + 1, + ), + ( + list(range(TEST_PAGE_SIZE)), + list(range(50, 50 + TEST_PAGE_SIZE)), + "no_match", + 1, + 0, + ), +] + + +@pytest.mark.parametrize("seq", basic_sequences) +def test_basic_allocation(trie_cache, seq): + """Test basic page allocation without reuse""" + allocation = trie_cache.acquire_pages_for_tokens(seq.tokens, extra_token_slots=0) + assert len(allocation.pages) == seq.expected_pages + assert len(allocation.cached_pages) == 0 + assert len(allocation.newly_acquired_pages) == seq.expected_pages + allocation.release_pages() + + 
+@pytest.mark.parametrize( + "initial_tokens,reuse_tokens,description,total_pages,expected_cached", + reuse_sequences, +) +def test_page_reuse( + trie_cache, + published_sequence, + initial_tokens, + reuse_tokens, + description, + total_pages, + expected_cached, +): + """Test page reuse scenarios""" + # Publish initial sequence + published_sequence(initial_tokens) + + # Try to reuse + allocation = trie_cache.acquire_pages_for_tokens(reuse_tokens, extra_token_slots=0) + assert len(allocation.pages) == total_pages + assert len(allocation.cached_pages) == expected_cached + assert len(allocation.newly_acquired_pages) == total_pages - expected_cached + allocation.release_pages() + + +@pytest.fixture +def filled_cache(trie_cache, published_sequence): + """Fixture that fills cache with numbered sequences""" + sequences = [] + for i in range(TEST_POOL_CAPACITY): + tokens = list(range(i * 100, i * 100 + TEST_PAGE_SIZE)) + published_sequence(tokens) + sequences.append(tokens) + return sequences + + +@pytest.mark.parametrize( + "access_count", [1, TEST_POOL_CAPACITY // 2, TEST_POOL_CAPACITY - 1] +) +def test_lru_eviction(trie_cache, access_count): + """Test LRU eviction with different access patterns""" + print(f"\nStarting test_lru_eviction with access_count={access_count}") + + # Create mix of published and unpublished sequences + keep_published = 3 # Number of sequences to keep published + sequences = [] + + # First add some sequences we'll keep published + print("\nPublishing sequences to keep active:") + for i in range(keep_published): + tokens = list(range(i * 100, i * 100 + TEST_PAGE_SIZE)) + alloc = trie_cache.acquire_pages_for_tokens(tokens, extra_token_slots=0) + alloc.publish_pages(1) # Don't release these - they should stay in cache + sequences.append(tokens) + print(f"Published sequence {i} (keeping active)") + print_tree_state(trie_cache, " ") + + # Then add sequences we'll publish but release (evictable) + print("\nAdding releasable sequences:") + for i in 
range(keep_published, TEST_POOL_CAPACITY): + tokens = list(range(i * 100, i * 100 + TEST_PAGE_SIZE)) + alloc = trie_cache.acquire_pages_for_tokens(tokens, extra_token_slots=0) + alloc.publish_pages(1) + alloc.release_pages() # These can be evicted + sequences.append(tokens) + print(f"Added releasable sequence {i}") + print_tree_state(trie_cache, " ") + + print("\nCache state before accessing sequences:") + print_tree_state(trie_cache, " ") + + # Access some sequences to update their LRU status + print(f"\nAccessing {access_count} sequences to update LRU order:") + for i in range(access_count): + print(f"\nAccessing sequence {i}:") + alloc = trie_cache.acquire_pages_for_tokens(sequences[i], extra_token_slots=0) + print_tree_state(trie_cache, " ") + alloc.release_pages() + print(f"After releasing allocation {i}:") + print_tree_state(trie_cache, " ") + + print("\nCache state before attempting new allocation:") + print_tree_state(trie_cache, " ") + print("\nAvailable pages in pool:", len(trie_cache.page_pool.available_pages)) + + # Try to allocate new sequence - should evict least recently used unpublished sequence + new_tokens = list(range(1000, 1000 + TEST_PAGE_SIZE)) + print(f"\nAttempting to allocate new sequence: {new_tokens}") + new_alloc = trie_cache.acquire_pages_for_tokens(new_tokens, extra_token_slots=0) + print("\nNew allocation succeeded:") + print(f"- Allocated {len(new_alloc.pages)} new pages") + print(f"- Cached pages: {len(new_alloc.cached_pages)}") + print(f"- Newly acquired pages: {len(new_alloc.newly_acquired_pages)}") + print("\nCache state after new allocation:") + print_tree_state(trie_cache, " ") + new_alloc.release_pages() + + # Verify recently accessed sequences AND published sequences weren't evicted + print("\nVerifying preserved sequences:") + for i in range(max(access_count, keep_published)): + print(f"\nChecking sequence {i}:") + recheck = trie_cache.acquire_pages_for_tokens(sequences[i], extra_token_slots=0) + cached_pages = 
len(recheck.cached_pages) + print(f"- Cached pages found: {cached_pages}") + assert ( + cached_pages == 1 + ), f"Sequence {i} was evicted but should have been preserved" + recheck.release_pages() + + +@pytest.mark.parametrize("publish_steps", [1, 2, 3]) +def test_progressive_publish(trie_cache, publish_steps): + """Test publishing pages progressively""" + tokens = list(range(TEST_PAGE_SIZE * 3)) # Three pages + alloc = trie_cache.acquire_pages_for_tokens(tokens) + + for step in range(publish_steps): + # Publish next page + alloc.publish_pages(step + 1) + + # Verify reuse up to published point + reuse_tokens = tokens[: (step + 1) * TEST_PAGE_SIZE] + reuse_alloc = trie_cache.acquire_pages_for_tokens(reuse_tokens) + assert len(reuse_alloc.cached_pages) == step + 1 + reuse_alloc.release_pages() + + alloc.release_pages() + + +@pytest.mark.parametrize("ref_count", [1, 2, 5]) +def test_reference_counting(trie_cache, ref_count): + """Test reference counting with different counts""" + tokens = list(range(TEST_PAGE_SIZE)) + allocations = [] + + # Create initial allocation and publish + first_alloc = trie_cache.acquire_pages_for_tokens(tokens, extra_token_slots=0) + first_alloc.publish_pages(1) + allocations.append(first_alloc) + print("\nInitial allocation created") + print_tree_state(trie_cache, " ") + + # Create additional references + for i in range(ref_count - 1): + alloc = trie_cache.acquire_pages_for_tokens(tokens, extra_token_slots=0) + allocations.append(alloc) + print(f"\nCreated reference {i+1}") + print_tree_state(trie_cache, " ") + + # Fill remaining cache + remaining = TEST_POOL_CAPACITY - 1 + fill_allocations = [] + for i in range(remaining): + fill_tokens = list( + range(100 + i * TEST_PAGE_SIZE, 100 + (i + 1) * TEST_PAGE_SIZE) + ) + alloc = trie_cache.acquire_pages_for_tokens(fill_tokens, extra_token_slots=0) + alloc.publish_pages(1) + fill_allocations.append(alloc) + print(f"\nFilled cache slot {i+1}/{remaining}") + print_tree_state(trie_cache, " ") + + 
print("\nAttempting allocation that should fail...") + try: + new_tokens = list(range(1000, 1000 + TEST_PAGE_SIZE)) + new_alloc = trie_cache.acquire_pages_for_tokens(new_tokens, extra_token_slots=0) + print("ERROR: Allocation succeeded when it should have failed!") + print(f"- Allocated {len(new_alloc.pages)} new pages") + print(f"- Cached pages: {len(new_alloc.cached_pages)}") + print( + f"- Number of newly acquired pages: {len(new_alloc.newly_acquired_pages)}" + ) + print(f"- Newly acquired pages: {new_alloc.newly_acquired_pages}") + print("\nPost-allocation state:") + print_tree_state(trie_cache, " ") + new_alloc.release_pages() + pytest.fail("Expected CacheAllocationFailure was not raised") + except CacheAllocationFailure: + print("Success: CacheAllocationFailure raised as expected") + + # Cleanup + print("\nCleaning up allocations...") + for alloc in allocations + fill_allocations: + alloc.release_pages() + + +@pytest.mark.parametrize("tokens_per_page", [0, -1, -100]) +def test_invalid_init(page_pool, tokens_per_page): + """Test validation in __init__""" + with pytest.raises(ValueError): + TriePagedAttentionCache(page_pool=page_pool, tokens_per_page=tokens_per_page) From 558da42f94e64c339fca574b7c704759aa851aa5 Mon Sep 17 00:00:00 2001 From: Cedar Date: Wed, 27 Nov 2024 16:26:31 -0800 Subject: [PATCH 12/18] all tests passing --- .../kvcache/trie_attention_cache.py | 5 +- .../kvcache/trie_attention_cache_test.py | 60 +++++++++++++++++-- 2 files changed, 59 insertions(+), 6 deletions(-) diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py index f4614c8d8..ca009ed3a 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py @@ -122,7 +122,7 @@ def publish_pages(self, up_to_page_index: int) -> None: publish_token_count = 
min(len(self.tokens), up_to_page_index * tokens_per_page) cur_node = self.last_cached_node - first_uncached_page_index = self.start_index // tokens_per_page + first_uncached_page_index = len(self.cached_pages) uncached_tokens = [ tuple(self.tokens[i : i + tokens_per_page]) @@ -139,6 +139,9 @@ def publish_pages(self, up_to_page_index: int) -> None: new_node = cur_node.create_child(token_block, page) cur_node = new_node + self.cached_pages.extend(uncached_pages) + self.newly_acquired_pages = self.newly_acquired_pages[len(uncached_pages) :] + if cur_node is not self.cache.root: self.cache.leaves.add(cur_node) diff --git a/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py b/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py index 72f1d666e..aa271dcec 100644 --- a/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py +++ b/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py @@ -315,20 +315,70 @@ def test_lru_eviction(trie_cache, access_count): @pytest.mark.parametrize("publish_steps", [1, 2, 3]) def test_progressive_publish(trie_cache, publish_steps): """Test publishing pages progressively""" - tokens = list(range(TEST_PAGE_SIZE * 3)) # Three pages + print(f"\nStarting test_progressive_publish with publish_steps={publish_steps}") + + tokens = tuple(range(TEST_PAGE_SIZE * 3)) # Three pages + print(f"\nInitial tokens: {tokens}") + print(f"Tokens per page: {TEST_PAGE_SIZE}") + print( + f"Expected total pages: {len(tokens) // TEST_PAGE_SIZE + (1 if len(tokens) % TEST_PAGE_SIZE else 0)}" + ) + + print("\nInitial cache state:") + print_tree_state(trie_cache) + + print("\nAcquiring initial allocation...") alloc = trie_cache.acquire_pages_for_tokens(tokens) + print(f"Initial allocation pages: {[p.index for p in alloc.pages]}") + print("\nCache state after initial allocation:") + print_tree_state(trie_cache) + + for step in range(1, publish_steps + 1): + print(f"\n--- Step {step} of {publish_steps} 
---") - for step in range(publish_steps): # Publish next page - alloc.publish_pages(step + 1) + print(f"Publishing up to page {step}") + alloc.publish_pages(step) + print("\nCache state after publish:") + print_tree_state(trie_cache) # Verify reuse up to published point - reuse_tokens = tokens[: (step + 1) * TEST_PAGE_SIZE] + reuse_tokens = tokens[: (step) * TEST_PAGE_SIZE] + print(f"\nAttempting to reuse tokens: {reuse_tokens}") + print(f"Expected cached pages: {step}") + reuse_alloc = trie_cache.acquire_pages_for_tokens(reuse_tokens) - assert len(reuse_alloc.cached_pages) == step + 1 + print(f"Reuse allocation total pages: {len(reuse_alloc.pages)}") + print(f"Reuse allocation cached pages: {len(reuse_alloc.cached_pages)}") + print(f"Cached page indices: {[p.index for p in reuse_alloc.cached_pages]}") + print( + f"New page indices: {[p.index for p in reuse_alloc.newly_acquired_pages]}" + ) + + print("\nCache state after reuse attempt:") + print_tree_state(trie_cache) + + try: + assert len(reuse_alloc.cached_pages) == step + except AssertionError: + print("\nASSERTION FAILED!") + print( + f"Expected {step} cached pages but got {len(reuse_alloc.cached_pages)}" + ) + print("Cached pages details:") + for i, page in enumerate(reuse_alloc.cached_pages): + print( + f"Page {i}: index={page.index}, token_offset={page.token_offset}, token_count={page.token_count}" + ) + raise + reuse_alloc.release_pages() + print("\nCache state after releasing reuse allocation:") + print_tree_state(trie_cache) alloc.release_pages() + print("\nFinal cache state after releasing initial allocation:") + print_tree_state(trie_cache) @pytest.mark.parametrize("ref_count", [1, 2, 5]) From 99ff2c697196f08cfa3078c5a8a347cdf463175b Mon Sep 17 00:00:00 2001 From: Cedar Date: Wed, 27 Nov 2024 16:56:05 -0800 Subject: [PATCH 13/18] all but publishing working --- .../kvcache/trie_attention_cache.py | 66 ++++++++----------- .../shortfin_apps/llm/components/messages.py | 2 +- 
.../kvcache/trie_attention_cache_test.py | 48 ++++++-------- 3 files changed, 51 insertions(+), 65 deletions(-) diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py index ca009ed3a..bcfe0f04f 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py @@ -73,16 +73,15 @@ def __eq__(self, other: object) -> bool: class TriePageAttentionCacheAllocation(PageAllocation): """Represents a page allocation in the trie-based cache. - Tracks both previously cached pages and newly allocated pages, + Tracks sequence of pages and which ones are already published to the cache, implementing the PageAllocation protocol for the trie cache. Attributes: cache: The parent cache this allocation belongs to tokens: Complete sequence of tokens this allocation represents last_cached_node: Last matched node in the trie - cached_pages: List of pages already in cache - newly_acquired_pages: List of newly allocated pages - start_index: Index where cached tokens end and new tokens begin + pages: List of all pages in allocation + number_of_published_pages: Number of pages that are published to the cache """ def __init__( @@ -92,62 +91,57 @@ def __init__( last_cached_node: TrieNode, cached_pages: List[PageInfo], newly_acquired_pages: List[PageInfo], - start_index: int, ): self.cache = cache self.tokens = tokens self.last_cached_node = last_cached_node - self.cached_pages = cached_pages - self.newly_acquired_pages = newly_acquired_pages - self.start_index = start_index + self._pages = cached_pages + newly_acquired_pages + self.number_of_published_pages = len(cached_pages) self._is_released = False @property def pages(self) -> List[PageInfo]: - """List all pages in this allocation, both cached and new. 
+ return self._pages - Returns: - Combined list of cached and newly acquired pages - """ - return self.cached_pages + self.newly_acquired_pages - - def publish_pages(self, up_to_page_index: int) -> None: - """Make pages available in the cache up to the specified index. + def publish_pages(self, up_to_page_index) -> None: + """Make pages available in the cache for the specified tokens. Args: - up_to_page_index: Number of pages to publish, starting from the beginning + tokens_to_publish: Tokens to publish to the cache + + Raises: + ValueError: If tokens don't match allocation or exceed available pages """ tokens_per_page = self.cache.tokens_per_page - publish_token_count = min(len(self.tokens), up_to_page_index * tokens_per_page) - - cur_node = self.last_cached_node - first_uncached_page_index = len(self.cached_pages) + # Create token blocks for unpublished pages + start_token = self.number_of_published_pages * tokens_per_page - uncached_tokens = [ + unpublished_tokens = [ tuple(self.tokens[i : i + tokens_per_page]) - for i in range( - first_uncached_page_index * tokens_per_page, - publish_token_count, - tokens_per_page, - ) + for i in range(start_token, tokens_per_page) ] - uncached_pages = self.newly_acquired_pages[: len(uncached_tokens)] + unpublished_pages = self._pages[ + self.number_of_published_pages : up_to_page_index + ] - for token_block, page in zip(uncached_tokens, uncached_pages): + # Add unpublished pages to trie + cur_node = self.last_cached_node + for token_block, page in zip(unpublished_tokens, unpublished_pages): new_node = cur_node.create_child(token_block, page) cur_node = new_node - self.cached_pages.extend(uncached_pages) - self.newly_acquired_pages = self.newly_acquired_pages[len(uncached_pages) :] - if cur_node is not self.cache.root: self.cache.leaves.add(cur_node) - cur_node.ref_count += 1 - self.last_cached_node.ref_count -= 1 - self.last_cached_node = cur_node + # Update reference counts + if unpublished_tokens: + cur_node.ref_count += 1 + 
self.last_cached_node.ref_count -= 1 + self.last_cached_node = cur_node + + self.number_of_published_pages = up_to_page_index def release_pages(self) -> None: """Release the allocation's reference to its pages. @@ -267,7 +261,6 @@ def acquire_pages_for_tokens( last_cached_node=cur_node, cached_pages=matched_pages, newly_acquired_pages=new_pages, - start_index=n_cached_tokens, ) # Try eviction @@ -285,7 +278,6 @@ def acquire_pages_for_tokens( last_cached_node=cur_node, cached_pages=matched_pages, newly_acquired_pages=new_pages, - start_index=n_cached_tokens, ) def _evict_pages(self, max_pages: int) -> int: diff --git a/shortfin/python/shortfin_apps/llm/components/messages.py b/shortfin/python/shortfin_apps/llm/components/messages.py index c03900782..bc1f851e2 100644 --- a/shortfin/python/shortfin_apps/llm/components/messages.py +++ b/shortfin/python/shortfin_apps/llm/components/messages.py @@ -58,7 +58,7 @@ def reset(self, phase: InferencePhase): def cache_page_indices(self, max_len: int) -> list[int]: if not self.allocation: return [] - indices = [p.index for p in self.allocation.pages[:max_len]] + indices = [p.index for p in self.allocation._pages[:max_len]] return indices def publish_allocated_pages(self, up_to_page_index: int): diff --git a/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py b/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py index aa271dcec..4ce71a9a7 100644 --- a/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py +++ b/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py @@ -197,8 +197,13 @@ def test_basic_allocation(trie_cache, seq): """Test basic page allocation without reuse""" allocation = trie_cache.acquire_pages_for_tokens(seq.tokens, extra_token_slots=0) assert len(allocation.pages) == seq.expected_pages - assert len(allocation.cached_pages) == 0 - assert len(allocation.newly_acquired_pages) == seq.expected_pages + assert 
allocation.number_of_published_pages == 0 + assert ( + len(allocation.pages) - allocation.number_of_published_pages + == seq.expected_pages + ) + # Replace publishing with tokens + allocation.publish_pages(len(allocation.pages)) allocation.release_pages() @@ -222,8 +227,13 @@ def test_page_reuse( # Try to reuse allocation = trie_cache.acquire_pages_for_tokens(reuse_tokens, extra_token_slots=0) assert len(allocation.pages) == total_pages - assert len(allocation.cached_pages) == expected_cached - assert len(allocation.newly_acquired_pages) == total_pages - expected_cached + assert allocation.number_of_published_pages == expected_cached + assert ( + len(allocation.pages) - allocation.number_of_published_pages + == total_pages - expected_cached + ) + # Replace publishing with tokens + allocation.publish_pages(len(allocation.pages)) allocation.release_pages() @@ -292,9 +302,6 @@ def test_lru_eviction(trie_cache, access_count): print(f"\nAttempting to allocate new sequence: {new_tokens}") new_alloc = trie_cache.acquire_pages_for_tokens(new_tokens, extra_token_slots=0) print("\nNew allocation succeeded:") - print(f"- Allocated {len(new_alloc.pages)} new pages") - print(f"- Cached pages: {len(new_alloc.cached_pages)}") - print(f"- Newly acquired pages: {len(new_alloc.newly_acquired_pages)}") print("\nCache state after new allocation:") print_tree_state(trie_cache, " ") new_alloc.release_pages() @@ -304,7 +311,7 @@ def test_lru_eviction(trie_cache, access_count): for i in range(max(access_count, keep_published)): print(f"\nChecking sequence {i}:") recheck = trie_cache.acquire_pages_for_tokens(sequences[i], extra_token_slots=0) - cached_pages = len(recheck.cached_pages) + cached_pages = recheck.number_of_published_pages print(f"- Cached pages found: {cached_pages}") assert ( cached_pages == 1 @@ -338,6 +345,7 @@ def test_progressive_publish(trie_cache, publish_steps): # Publish next page print(f"Publishing up to page {step}") + # Replace publishing with tokens 
alloc.publish_pages(step) print("\nCache state after publish:") print_tree_state(trie_cache) @@ -349,27 +357,18 @@ def test_progressive_publish(trie_cache, publish_steps): reuse_alloc = trie_cache.acquire_pages_for_tokens(reuse_tokens) print(f"Reuse allocation total pages: {len(reuse_alloc.pages)}") - print(f"Reuse allocation cached pages: {len(reuse_alloc.cached_pages)}") - print(f"Cached page indices: {[p.index for p in reuse_alloc.cached_pages]}") - print( - f"New page indices: {[p.index for p in reuse_alloc.newly_acquired_pages]}" - ) + print(f"Reuse allocation cached pages: {reuse_alloc.number_of_published_pages}") print("\nCache state after reuse attempt:") print_tree_state(trie_cache) try: - assert len(reuse_alloc.cached_pages) == step + assert reuse_alloc.number_of_published_pages == step except AssertionError: print("\nASSERTION FAILED!") print( - f"Expected {step} cached pages but got {len(reuse_alloc.cached_pages)}" + f"Expected {step} cached pages but got {reuse_alloc.number_of_published_pages}" ) - print("Cached pages details:") - for i, page in enumerate(reuse_alloc.cached_pages): - print( - f"Page {i}: index={page.index}, token_offset={page.token_offset}, token_count={page.token_count}" - ) raise reuse_alloc.release_pages() @@ -389,7 +388,8 @@ def test_reference_counting(trie_cache, ref_count): # Create initial allocation and publish first_alloc = trie_cache.acquire_pages_for_tokens(tokens, extra_token_slots=0) - first_alloc.publish_pages(1) + # Replace publishing with tokens + first_alloc.publish_pages(len(first_alloc.pages)) allocations.append(first_alloc) print("\nInitial allocation created") print_tree_state(trie_cache, " ") @@ -419,12 +419,6 @@ def test_reference_counting(trie_cache, ref_count): new_tokens = list(range(1000, 1000 + TEST_PAGE_SIZE)) new_alloc = trie_cache.acquire_pages_for_tokens(new_tokens, extra_token_slots=0) print("ERROR: Allocation succeeded when it should have failed!") - print(f"- Allocated {len(new_alloc.pages)} new 
pages") - print(f"- Cached pages: {len(new_alloc.cached_pages)}") - print( - f"- Number of newly acquired pages: {len(new_alloc.newly_acquired_pages)}" - ) - print(f"- Newly acquired pages: {new_alloc.newly_acquired_pages}") print("\nPost-allocation state:") print_tree_state(trie_cache, " ") new_alloc.release_pages() From 9e1f9c74202e80e85acb14d04107391334dfe79b Mon Sep 17 00:00:00 2001 From: Cedar Date: Wed, 27 Nov 2024 16:57:44 -0800 Subject: [PATCH 14/18] naming consistency Page -> Paged --- .../llm/components/kvcache/trie_attention_cache.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py index bcfe0f04f..2295f6455 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py @@ -70,7 +70,7 @@ def __eq__(self, other: object) -> bool: return self is other -class TriePageAttentionCacheAllocation(PageAllocation): +class TriePagedAttentionCacheAllocation(PageAllocation): """Represents a page allocation in the trie-based cache. 
Tracks sequence of pages and which ones are already published to the cache, @@ -255,7 +255,7 @@ def acquire_pages_for_tokens( new_pages = self.page_pool.acquire_free_pages(n_empty_pages) if new_pages is not None: - return TriePageAttentionCacheAllocation( + return TriePagedAttentionCacheAllocation( cache=self, tokens=tokens, last_cached_node=cur_node, @@ -272,7 +272,7 @@ def acquire_pages_for_tokens( "Failed to acquire pages even after attempting eviction from LRU leaves" ) - return TriePageAttentionCacheAllocation( + return TriePagedAttentionCacheAllocation( cache=self, tokens=tokens, last_cached_node=cur_node, From bf5c8ab4665859303e8994de2ae3ebcff54ec664 Mon Sep 17 00:00:00 2001 From: Cedar Date: Wed, 27 Nov 2024 17:13:57 -0800 Subject: [PATCH 15/18] passing all tests now --- .../kvcache/base_attention_cache.py | 4 +- .../kvcache/trie_attention_cache.py | 38 ++++++++++++++++--- .../shortfin_apps/llm/components/messages.py | 4 +- .../kvcache/trie_attention_cache_test.py | 18 ++++----- 4 files changed, 45 insertions(+), 19 deletions(-) diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py index c86379368..f571907de 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py @@ -34,7 +34,7 @@ def pages(self) -> List[PageInfo]: pass @abstractmethod - def publish_pages(self, up_to_page_index) -> None: + def publish_pages(self, tokens, publish_incomplete_pages=False) -> None: """Makes pages[0:up_to_page_index] available to other requests.""" pass @@ -56,7 +56,7 @@ def __init__(self, pages: Iterable[PageInfo], cache: "BasePagedAttentionCache"): def pages(self) -> List[PageInfo]: return list(self._pages) - def publish_pages(self, up_to_page_index) -> None: + def publish_pages(self, tokens, publish_incomplete_pages=False) -> None: pass def 
release_pages(self) -> None: diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py index 2295f6455..283076c33 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py @@ -103,7 +103,7 @@ def __init__( def pages(self) -> List[PageInfo]: return self._pages - def publish_pages(self, up_to_page_index) -> None: + def publish_pages(self, tokens, publish_incomplete_page=False) -> None: """Make pages available in the cache for the specified tokens. Args: @@ -112,21 +112,47 @@ def publish_pages(self, up_to_page_index) -> None: Raises: ValueError: If tokens don't match allocation or exceed available pages """ + # If we have more tokens, publish pages up to the incoming tokens. + # If incoming has more tokens, replace our tokens with incoming tokens and publish pages up to the incoming tokens. 
+ + def has_common_prefix(tokens1, tokens2): + for t1, t2 in zip(tokens1, tokens2): + if t1 != t2: + return False + return True + + if not has_common_prefix(self.tokens, tokens): + raise ValueError( + "Tokens provided in publish_pages do not match tokens in allocation" + ) + + if len(tokens) > len(self.tokens): + self.tokens = tokens + tokens_per_page = self.cache.tokens_per_page - # Create token blocks for unpublished pages - start_token = self.number_of_published_pages * tokens_per_page + number_of_pages_to_publish = len(tokens) / tokens_per_page + if publish_incomplete_page: + number_of_pages_to_publish = math.ceil(number_of_pages_to_publish) + else: + number_of_pages_to_publish = math.floor(number_of_pages_to_publish) + # Create token blocks for unpublished pages + start_token_index = self.number_of_published_pages * tokens_per_page unpublished_tokens = [ tuple(self.tokens[i : i + tokens_per_page]) - for i in range(start_token, tokens_per_page) + for i in range(start_token_index, len(self.tokens), tokens_per_page) ] unpublished_pages = self._pages[ - self.number_of_published_pages : up_to_page_index + self.number_of_published_pages : number_of_pages_to_publish ] # Add unpublished pages to trie + if publish_incomplete_page: + raise NotImplementedError( + "Additional work needed here to support publishing incomplete pages to ensure that we finish up a page before attaching child nodes to it." + ) cur_node = self.last_cached_node for token_block, page in zip(unpublished_tokens, unpublished_pages): new_node = cur_node.create_child(token_block, page) @@ -141,7 +167,7 @@ def publish_pages(self, up_to_page_index) -> None: self.last_cached_node.ref_count -= 1 self.last_cached_node = cur_node - self.number_of_published_pages = up_to_page_index + self.number_of_published_pages = number_of_pages_to_publish def release_pages(self) -> None: """Release the allocation's reference to its pages. 
diff --git a/shortfin/python/shortfin_apps/llm/components/messages.py b/shortfin/python/shortfin_apps/llm/components/messages.py index bc1f851e2..2f41c9834 100644 --- a/shortfin/python/shortfin_apps/llm/components/messages.py +++ b/shortfin/python/shortfin_apps/llm/components/messages.py @@ -63,7 +63,9 @@ def cache_page_indices(self, max_len: int) -> list[int]: def publish_allocated_pages(self, up_to_page_index: int): assert self.allocation - self.allocation.publish_pages(up_to_page_index) + self.allocation.publish_pages( + self.input_token_ids, publish_incomplete_pages=False + ) def free_cache_pages(self): if self.allocation: diff --git a/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py b/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py index 4ce71a9a7..b06ea5d4e 100644 --- a/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py +++ b/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py @@ -116,7 +116,7 @@ def published_sequence(trie_cache): def _publish_sequence(tokens: List[int]) -> None: alloc = trie_cache.acquire_pages_for_tokens(tokens, extra_token_slots=0) - alloc.publish_pages(len(alloc.pages)) + alloc.publish_pages(alloc.tokens) alloc.release_pages() return _publish_sequence @@ -202,8 +202,7 @@ def test_basic_allocation(trie_cache, seq): len(allocation.pages) - allocation.number_of_published_pages == seq.expected_pages ) - # Replace publishing with tokens - allocation.publish_pages(len(allocation.pages)) + allocation.publish_pages(allocation.tokens) allocation.release_pages() @@ -232,8 +231,7 @@ def test_page_reuse( len(allocation.pages) - allocation.number_of_published_pages == total_pages - expected_cached ) - # Replace publishing with tokens - allocation.publish_pages(len(allocation.pages)) + allocation.publish_pages(allocation.tokens) allocation.release_pages() @@ -264,7 +262,7 @@ def test_lru_eviction(trie_cache, access_count): for i in range(keep_published): tokens = 
list(range(i * 100, i * 100 + TEST_PAGE_SIZE)) alloc = trie_cache.acquire_pages_for_tokens(tokens, extra_token_slots=0) - alloc.publish_pages(1) # Don't release these - they should stay in cache + alloc.publish_pages(alloc.tokens[:TEST_PAGE_SIZE]) sequences.append(tokens) print(f"Published sequence {i} (keeping active)") print_tree_state(trie_cache, " ") @@ -274,7 +272,7 @@ def test_lru_eviction(trie_cache, access_count): for i in range(keep_published, TEST_POOL_CAPACITY): tokens = list(range(i * 100, i * 100 + TEST_PAGE_SIZE)) alloc = trie_cache.acquire_pages_for_tokens(tokens, extra_token_slots=0) - alloc.publish_pages(1) + alloc.publish_pages(alloc.tokens[:TEST_PAGE_SIZE]) alloc.release_pages() # These can be evicted sequences.append(tokens) print(f"Added releasable sequence {i}") @@ -346,7 +344,7 @@ def test_progressive_publish(trie_cache, publish_steps): # Publish next page print(f"Publishing up to page {step}") # Replace publishing with tokens - alloc.publish_pages(step) + alloc.publish_pages(alloc.tokens[: (step) * TEST_PAGE_SIZE]) print("\nCache state after publish:") print_tree_state(trie_cache) @@ -389,7 +387,7 @@ def test_reference_counting(trie_cache, ref_count): # Create initial allocation and publish first_alloc = trie_cache.acquire_pages_for_tokens(tokens, extra_token_slots=0) # Replace publishing with tokens - first_alloc.publish_pages(len(first_alloc.pages)) + first_alloc.publish_pages(first_alloc.tokens) allocations.append(first_alloc) print("\nInitial allocation created") print_tree_state(trie_cache, " ") @@ -409,7 +407,7 @@ def test_reference_counting(trie_cache, ref_count): range(100 + i * TEST_PAGE_SIZE, 100 + (i + 1) * TEST_PAGE_SIZE) ) alloc = trie_cache.acquire_pages_for_tokens(fill_tokens, extra_token_slots=0) - alloc.publish_pages(1) + alloc.publish_pages(alloc.tokens[:TEST_PAGE_SIZE]) fill_allocations.append(alloc) print(f"\nFilled cache slot {i+1}/{remaining}") print_tree_state(trie_cache, " ") From 
895e4d5296337ee70606fab978b61d37417f23c8 Mon Sep 17 00:00:00 2001 From: Cedar Date: Wed, 27 Nov 2024 17:18:38 -0800 Subject: [PATCH 16/18] name and documentation update publish_pages -> publish_pages_for_tokens --- .../components/kvcache/base_attention_cache.py | 8 +++++--- .../components/kvcache/trie_attention_cache.py | 2 +- .../shortfin_apps/llm/components/messages.py | 2 +- .../kvcache/trie_attention_cache_test.py | 16 ++++++++-------- 4 files changed, 15 insertions(+), 13 deletions(-) diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py index f571907de..a7ee9c369 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py @@ -34,8 +34,10 @@ def pages(self) -> List[PageInfo]: pass @abstractmethod - def publish_pages(self, tokens, publish_incomplete_pages=False) -> None: - """Makes pages[0:up_to_page_index] available to other requests.""" + def publish_pages_for_tokens(self, tokens, publish_incomplete_pages=False) -> None: + """ + Makes pages available to other requests. For details, reference the derived class in trie_attention_cache.py. 
+ """ pass @abstractmethod @@ -56,7 +58,7 @@ def __init__(self, pages: Iterable[PageInfo], cache: "BasePagedAttentionCache"): def pages(self) -> List[PageInfo]: return list(self._pages) - def publish_pages(self, tokens, publish_incomplete_pages=False) -> None: + def publish_pages_for_tokens(self, tokens, publish_incomplete_pages=False) -> None: pass def release_pages(self) -> None: diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py index 283076c33..1c9872cda 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py @@ -103,7 +103,7 @@ def __init__( def pages(self) -> List[PageInfo]: return self._pages - def publish_pages(self, tokens, publish_incomplete_page=False) -> None: + def publish_pages_for_tokens(self, tokens, publish_incomplete_page=False) -> None: """Make pages available in the cache for the specified tokens. 
Args: diff --git a/shortfin/python/shortfin_apps/llm/components/messages.py b/shortfin/python/shortfin_apps/llm/components/messages.py index 2f41c9834..118ae2225 100644 --- a/shortfin/python/shortfin_apps/llm/components/messages.py +++ b/shortfin/python/shortfin_apps/llm/components/messages.py @@ -63,7 +63,7 @@ def cache_page_indices(self, max_len: int) -> list[int]: def publish_allocated_pages(self, up_to_page_index: int): assert self.allocation - self.allocation.publish_pages( + self.allocation.publish_pages_for_tokens( self.input_token_ids, publish_incomplete_pages=False ) diff --git a/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py b/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py index b06ea5d4e..6d216c790 100644 --- a/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py +++ b/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py @@ -116,7 +116,7 @@ def published_sequence(trie_cache): def _publish_sequence(tokens: List[int]) -> None: alloc = trie_cache.acquire_pages_for_tokens(tokens, extra_token_slots=0) - alloc.publish_pages(alloc.tokens) + alloc.publish_pages_for_tokens(alloc.tokens) alloc.release_pages() return _publish_sequence @@ -202,7 +202,7 @@ def test_basic_allocation(trie_cache, seq): len(allocation.pages) - allocation.number_of_published_pages == seq.expected_pages ) - allocation.publish_pages(allocation.tokens) + allocation.publish_pages_for_tokens(allocation.tokens) allocation.release_pages() @@ -231,7 +231,7 @@ def test_page_reuse( len(allocation.pages) - allocation.number_of_published_pages == total_pages - expected_cached ) - allocation.publish_pages(allocation.tokens) + allocation.publish_pages_for_tokens(allocation.tokens) allocation.release_pages() @@ -262,7 +262,7 @@ def test_lru_eviction(trie_cache, access_count): for i in range(keep_published): tokens = list(range(i * 100, i * 100 + TEST_PAGE_SIZE)) alloc = trie_cache.acquire_pages_for_tokens(tokens, 
extra_token_slots=0) - alloc.publish_pages(alloc.tokens[:TEST_PAGE_SIZE]) + alloc.publish_pages_for_tokens(alloc.tokens[:TEST_PAGE_SIZE]) sequences.append(tokens) print(f"Published sequence {i} (keeping active)") print_tree_state(trie_cache, " ") @@ -272,7 +272,7 @@ def test_lru_eviction(trie_cache, access_count): for i in range(keep_published, TEST_POOL_CAPACITY): tokens = list(range(i * 100, i * 100 + TEST_PAGE_SIZE)) alloc = trie_cache.acquire_pages_for_tokens(tokens, extra_token_slots=0) - alloc.publish_pages(alloc.tokens[:TEST_PAGE_SIZE]) + alloc.publish_pages_for_tokens(alloc.tokens[:TEST_PAGE_SIZE]) alloc.release_pages() # These can be evicted sequences.append(tokens) print(f"Added releasable sequence {i}") @@ -344,7 +344,7 @@ def test_progressive_publish(trie_cache, publish_steps): # Publish next page print(f"Publishing up to page {step}") # Replace publishing with tokens - alloc.publish_pages(alloc.tokens[: (step) * TEST_PAGE_SIZE]) + alloc.publish_pages_for_tokens(alloc.tokens[: (step) * TEST_PAGE_SIZE]) print("\nCache state after publish:") print_tree_state(trie_cache) @@ -387,7 +387,7 @@ def test_reference_counting(trie_cache, ref_count): # Create initial allocation and publish first_alloc = trie_cache.acquire_pages_for_tokens(tokens, extra_token_slots=0) # Replace publishing with tokens - first_alloc.publish_pages(first_alloc.tokens) + first_alloc.publish_pages_for_tokens(first_alloc.tokens) allocations.append(first_alloc) print("\nInitial allocation created") print_tree_state(trie_cache, " ") @@ -407,7 +407,7 @@ def test_reference_counting(trie_cache, ref_count): range(100 + i * TEST_PAGE_SIZE, 100 + (i + 1) * TEST_PAGE_SIZE) ) alloc = trie_cache.acquire_pages_for_tokens(fill_tokens, extra_token_slots=0) - alloc.publish_pages(alloc.tokens[:TEST_PAGE_SIZE]) + alloc.publish_pages_for_tokens(alloc.tokens[:TEST_PAGE_SIZE]) fill_allocations.append(alloc) print(f"\nFilled cache slot {i+1}/{remaining}") print_tree_state(trie_cache, " ") From 
0caeb23473c809814036be407a95b0240020adbd Mon Sep 17 00:00:00 2001 From: Cedar Date: Wed, 27 Nov 2024 17:23:27 -0800 Subject: [PATCH 17/18] undo accidental edit of messages.py --- shortfin/python/shortfin_apps/llm/components/messages.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/shortfin/python/shortfin_apps/llm/components/messages.py b/shortfin/python/shortfin_apps/llm/components/messages.py index 118ae2225..a8d0f871b 100644 --- a/shortfin/python/shortfin_apps/llm/components/messages.py +++ b/shortfin/python/shortfin_apps/llm/components/messages.py @@ -58,7 +58,7 @@ def reset(self, phase: InferencePhase): def cache_page_indices(self, max_len: int) -> list[int]: if not self.allocation: return [] - indices = [p.index for p in self.allocation._pages[:max_len]] + indices = [p.index for p in self.allocation.pages[:max_len]] return indices def publish_allocated_pages(self, up_to_page_index: int): From 74adaab1b25ee5ab07179c310595f50888c8dc04 Mon Sep 17 00:00:00 2001 From: Cedar Date: Wed, 27 Nov 2024 17:36:38 -0800 Subject: [PATCH 18/18] remove not-very-useful test case --- .../llm/components/kvcache/trie_attention_cache_test.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py b/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py index 6d216c790..ce0025419 100644 --- a/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py +++ b/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py @@ -428,10 +428,3 @@ def test_reference_counting(trie_cache, ref_count): print("\nCleaning up allocations...") for alloc in allocations + fill_allocations: alloc.release_pages() - - -@pytest.mark.parametrize("tokens_per_page", [0, -1, -100]) -def test_invalid_init(page_pool, tokens_per_page): - """Test validation in __init__""" - with pytest.raises(ValueError): - TriePagedAttentionCache(page_pool=page_pool, tokens_per_page=tokens_per_page)