Skip to content

Commit

Permalink
ready for testing
Browse files Browse the repository at this point in the history
  • Loading branch information
renxida committed Nov 22, 2024
1 parent ee5af50 commit 40f335b
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,21 @@ def __init__(self, page_pool, tokens_per_page):
self.tokens_per_page = tokens_per_page


def acquire_pages_for_tokens(self, tokens: List[int]) -> tuple[List[PageInfo], int)]:
def acquire_pages_for_tokens(self, tokens: List[int], extra_token_slots: int = 1) -> tuple[List[PageInfo], int)]:
"""
Given a list of tokens, return a list of pages and a start position to continue generation from.
Parameters:
- tokens: all the known tokens for this generation request
- extra_token_slots: number of kvcache slots needed in addition to the ones needed to hold the given tokens.
In the base implementation, this will just allocate all new pages, but in shared-kv implementations, we will fetch cached pages if applicable.
The pages are returned in order.
No token at idx < n_cached_tokens should be written to. TODO: consider enforcing this.
"""
pages_needed = math.ceil(len(tokens) / self.tokens_per_page)
pages_needed = math.ceil(len(tokens + extra_token_slots) / self.tokens_per_page)
pages = self.page_pool.acquire_free_pages(pages_needed)

n_cached_tokens = 0
Expand Down
12 changes: 10 additions & 2 deletions shortfin/python/shortfin_apps/llm/components/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,11 @@ def board_prefills(self, cache: AttnPageCache):
needed_pages = math.ceil(
len(prefill_request.input_token_ids) / self.page_seq_stride
)
pages = cache.acquire_free_pages(needed_pages)
# allocate kv cache pages
pages, cache_hit_prefix_length = cache.acquire_pages_for_tokens(
prefill_request.input_token_ids,
extra_token_slots=0, # prefill needs no extra kvcache slots to write to
)
if pages is None:
logger.debug("Cannot fulfill request for %d pages", needed_pages)
continue
Expand Down Expand Up @@ -254,7 +258,11 @@ def board_decodes(self, cache: AttnPageCache):
/ self.page_seq_stride
)
if needed_pages > len(decode_request.locked_pages):
pages = cache.acquire_free_pages(needed_pages)
# allocate kv cache pages
pages, cache_hit_prefix_length = cache.acquire_pages_for_tokens(
decode_request.input_token_ids,
extra_token_slots=1, # need 1 extra slot to write result.
)
if pages is None:
logger.debug(
"Cannot fulfill decode request for %d pages", needed_pages
Expand Down

0 comments on commit 40f335b

Please sign in to comment.