Skip to content

Commit

Permalink
add a bunch of stuff and now we're passing pytest
Browse files Browse the repository at this point in the history
  • Loading branch information
renxida committed Nov 24, 2024
1 parent 60f51ac commit ed0aa30
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from typing import Dict, Set, List, Tuple, TYPE_CHECKING

if TYPE_CHECKING:
from .page_pool import PageInfo
from typing import Dict, Set, List, Tuple
from .page_pool import PageInfo
import heapq
from dataclasses import dataclass
import time
Expand Down Expand Up @@ -36,8 +34,12 @@ def parent(self, val: "TrieNode"):
val.children[hash(tuple(self.tokens))] = self
self._parent = val

def __lt__(self, other):
return True
# nodes are uniquely identified by their memory address
def __hash__(self):
return id(self)

def __eq__(self, other):
return self is other


class TriePagedAttentionCache(BasePagedAttentionCache):
Expand Down
4 changes: 2 additions & 2 deletions shortfin/python/shortfin_apps/llm/components/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def board_flights(self):
self.pending_prefills.clear()
logger.debug("Post boarding cache state: %r", cache)

def board_prefills(self, cache: BasePagedAttentionCache):
def board_prefills(self, cache: TriePagedAttentionCache):
# Fill prefill flights.
pending_prefills = self.pending_prefills
if len(pending_prefills) == 0:
Expand Down Expand Up @@ -252,7 +252,7 @@ def board_prefills(self, cache: BasePagedAttentionCache):
# And takeoff.
exec_process.launch()

def board_decodes(self, cache: BasePagedAttentionCache):
def board_decodes(self, cache: TriePagedAttentionCache):
# Fill decode flights.
pending_decodes = self.pending_decodes
if len(pending_decodes) == 0:
Expand Down

0 comments on commit ed0aa30

Please sign in to comment.