Add comprehensive debugging messages.
Lihao Ran committed Feb 6, 2025
1 parent 7e66063 commit f3db75d
Showing 1 changed file with 114 additions and 59 deletions.
173 changes: 114 additions & 59 deletions jetstream/core/orchestrator.py
@@ -110,6 +110,10 @@
root.addHandler(handler)


def ThreadDebugLog(thread_name: str, message: str, *args) -> None:
  """Logs a debug message prefixed with the calling thread's name."""
  # *args are forwarded so call sites can use lazy %-formatting.
  logging.debug(f"[{thread_name}] {message}", *args)
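
For context, a minimal standalone sketch of how the helper surfaces messages (the root logger must be configured at DEBUG level for them to appear; the thread name and backlog size below are illustrative, not taken from the commit):

import logging

logging.basicConfig(level=logging.DEBUG)  # debug records are dropped otherwise

def ThreadDebugLog(thread_name: str, message: str, *args) -> None:
  logging.debug(f"[{thread_name}] {message}", *args)

ThreadDebugLog("Prefill thread 0", "Current prefill backlog size: %d", 12)
# -> DEBUG:root:[Prefill thread 0] Current prefill backlog size: 12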


@dataclasses.dataclass
class ActiveRequestMetadata:
"""Inference request metadata."""
@@ -246,19 +250,24 @@ def __init__(
is_ray_backend: bool = False,
):
if prefill_engines is None:
logging.warning("No prefill engines provided.")
prefill_engines = []
if generate_engines is None:
logging.warning("No generate engines provided.")
generate_engines = []
if prefill_params is None:
logging.warning("No prefill parameters provided.")
prefill_params = []
if generate_params is None:
logging.warning("No generate parameters provided.")
generate_params = []

logging.warning(
"Initialising driver with %d prefill engines and %d generate engines.",
logging.info(
"Initializing the driver with %d prefill engines and %d generate engines in %s mode",
len(prefill_engines),
len(generate_engines),
)
"interleaved" if interleaved_mode else "disaggregated")

self._prefill_engines = prefill_engines
self._generate_engines = generate_engines
self._prefill_params = prefill_params
@@ -353,6 +362,13 @@ def __init__(
for idx, engine in enumerate(self._generate_engines)
]

logging.debug(
"Initializing the driver with 1 prefill backlogs, %d transfer backlogs, \n"
"%d generate backlogs and %d detokenize backlogs.",
len(self._transfer_backlogs),
len(self._generate_backlogs),
len(self._detokenize_backlogs))
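
For context, the counts logged above come from a fan-out of queues along these lines (a hedged sketch; the engine counts and queue bounds are illustrative, not JetStream's actual values):

import queue

num_prefill_engines = 2
num_generate_engines = 2

prefill_backlog = queue.Queue()  # one shared backlog feeds all prefill threads
transfer_backlogs = [queue.Queue(4) for _ in range(num_prefill_engines)]
generate_backlogs = [queue.Queue(4) for _ in range(num_generate_engines)]
detokenize_backlogs = [queue.Queue(4) for _ in range(num_generate_engines)]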

self._jax_padding = jax_padding

# Create all threads
@@ -410,8 +426,19 @@ def __init__(
for t in self._all_threads:
t.start()

logging.debug(
"Started %d prefill threads, %d transfer threads, \n"
"%d generate threads, and %d detokenize threads.",
len(self._prefill_threads),
len(self._transfer_threads),
len(self._generate_threads),
len(self.detokenize_threads))

logging.info("Driver initialized.")

def stop(self):
"""Stops the driver and all background threads."""
logging.info("Stopping the driver and all background threads...")
# Signal to all threads that they should stop.
self.live = False

@@ -452,6 +479,8 @@ def stop(self):
for t in self._all_threads:
t.join()

logging.info("Driver stopped.")
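
The shutdown sequence here follows a common flag-plus-sentinel pattern: flip the liveness flag, unblock queue consumers with None sentinels (mirroring the `if ... is None: break` checks in the worker loops below), then join. A minimal sketch, with names loosely modeled on the diff:

import queue
import threading

class MiniDriver:
  def __init__(self, num_workers: int = 2):
    self.live = True
    self._backlog = queue.Queue()
    self._threads = [
        threading.Thread(target=self._worker) for _ in range(num_workers)
    ]
    for t in self._threads:
      t.start()

  def _worker(self):
    while self.live:
      item = self._backlog.get(block=True)
      if item is None:  # sentinel: exit cleanly
        break

  def stop(self):
    self.live = False
    for _ in self._threads:  # one sentinel per potentially blocked consumer
      self._backlog.put(None)
    for t in self._threads:
      t.join()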

def get_total_concurrent_requests(self) -> int:
"""Gets the total number of concurrent requests the driver can handle."""
# We don't support filling all backlogs at once because it can cause GIL
@@ -498,12 +527,13 @@ def _process_prefill_content(

def _prefill_thread(self, idx: int):
"""Thread which runs in the background performing prefills."""
logging.info("---------Spinning up prefill thread %d.---------", idx)
logging.info("Spinning up prefill thread %d.", idx)
prefill_engine = self._prefill_engines[idx]
prefill_params = self._prefill_params[idx]
metadata = prefill_engine.get_tokenizer()
tokenizer = prefill_engine.build_tokenizer(metadata)
logging.info("---------Prefill params %d loaded.---------", idx)
thread_name = "Prefill thread %d" % idx
ThreadDebugLog(thread_name, "Prefill params %d loaded." % idx)

while self.live:
my_transfer_backlog = self._transfer_backlogs[idx]
@@ -514,13 +544,11 @@ def _prefill_thread(self, idx: int):
break
request.metadata.prefill_dequeue_time = time.perf_counter()
is_bos = True
logging.info(
"Prefilling on prefill engine %d : prefill queue size, %d,"
" is_bos: %s",
idx,
self._prefill_backlog.qsize(),
is_bos,
)
ThreadDebugLog(
    thread_name,
    "Executing prefill for one ActiveRequest. Current prefill backlog"
    " size: %d, is_bos: %s",
    self._prefill_backlog.qsize(),
    is_bos)
# Tokenize and padding the text or token input.
padded_tokens, true_length = self._process_prefill_content(
request, tokenizer, is_bos, prefill_engine.max_prefill_length
@@ -543,14 +571,16 @@ def _prefill_thread(self, idx: int):
block=True,
)

# Once prefill is complete, place it on the generation queue and block if
ThreadDebugLog(
thread_name,
"Completed prefilling for one ActiveRequest.")
# Once prefill is complete, place it on the transfer queue and block if
# full.
my_transfer_backlog.put(request, block=True)
logging.info(
"Placed request on transfer queue %d, %d queued requests.",
idx,
my_transfer_backlog.qsize(),
)
ThreadDebugLog(
thread_name,
"Placed request on transfer backlog %d. Current transfer backlog size: %d.",
idx, my_transfer_backlog.qsize())
if self._metrics_collector:
self._metrics_collector.get_request_input_length().observe(true_length)

Expand All @@ -566,6 +596,8 @@ def _prefill_thread(self, idx: int):
del prefill_result
del request

logging.info("Prefill thread %d stopped.", idx)
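
The "tokenize and padding" step in the loop above fixes the input length before prefill. A hedged sketch of that shaping (pad_tokens and pad_id are illustrative names, not JetStream's API):

import numpy as np

def pad_tokens(tokens: list[int], max_prefill_length: int, pad_id: int = 0):
  tokens = tokens[:max_prefill_length]  # truncate overly long inputs
  true_length = len(tokens)
  padded = np.full((max_prefill_length,), pad_id, dtype=np.int32)
  padded[:true_length] = tokens
  return padded, true_length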

def _jax_transfer_prefill_result(
self, new_request: ActiveRequest, target_idx: int
):
@@ -592,6 +624,8 @@ def _transfer_prefill_result(
def _transfer_thread(self, idx: int):
"""Transfers the kv cache on an active request to the least full
generate backlog."""
logging.info("Spinning up transfer thread %d.", idx)
thread_name = "Transfer thread %d" % idx
transfer_backlog = self._transfer_backlogs[idx]

while self.live:
@@ -606,29 +640,30 @@ def _transfer_thread(self, idx: int):
# Only transfer the KVCache for the disaggregated serving.
# TODO: Remove the conditional after fixing the compatibility.
if not self._interleaved_mode:
logging.info(
"Transferring prefill from prefill engine %d "
ThreadDebugLog(
thread_name,
"Transferring prefill result from prefill engine %d "
"to generate engine %d.",
idx,
target_idx,
)
target_idx)
# Transfer the info to the relevant generate slice.
self._transfer_prefill_result(new_request, target_idx)
# Place the request on the correct generate backlog and block if full.
new_request.metadata.generate_enqueue_time = time.perf_counter()
self._generate_backlogs[target_idx].put(new_request, block=True)
logging.info(
"Successfully transferred prefill "
"from prefill engine %d to generate engine %d "
"(%d requests now in backlog).",
ThreadDebugLog(
thread_name,
"Transferred ActiveRequest from prefill engine %d to generate backlog %d. "
"Current generate backlog size: %d.",
idx,
target_idx,
self._generate_backlogs[target_idx].qsize(),
)
self._generate_backlogs[target_idx].qsize())

logging.info("Transfer thread %d stopped.", idx)
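
Picking the "least full generate backlog" named in the docstring above can be as simple as a qsize() scan (a sketch; qsize() is advisory under concurrency, which is acceptable for load balancing):

import queue

def least_full_backlog(backlogs: list["queue.Queue"]) -> int:
  return min(range(len(backlogs)), key=lambda i: backlogs[i].qsize())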

def _generate_thread(self, idx: int):
"""Step token generation and insert prefills from backlog."""
logging.info("---------Spinning up generate thread %d.---------", idx)
logging.info("Spinning up generate thread %d.", idx)
generate_engine = self._generate_engines[idx]
my_slots = self._generate_slots[idx]
my_generate_backlog = self._generate_backlogs[idx]
@@ -640,18 +675,18 @@ def _generate_thread(self, idx: int):
decode_state = generate_engine.init_decode_state()

generate_params = self._generate_params[idx]
logging.info("---------Generate params %d loaded.---------", idx)
thread_name = "Generate thread %d" % idx
ThreadDebugLog(thread_name, "Generate params %d loaded." % idx)
time_of_last_generate = time.time()
time_of_last_print = time.time()
while self.live:
if (time.time() - time_of_last_print) > 1:
logging.info(
"Generate thread making a decision with:"
" prefill_backlog=%d"
" generate_free_slots=%d",
self._prefill_backlog.qsize(),
my_slots.qsize(),
)
ThreadDebugLog(thread_name,
"Generate thread making a decision with:"
" prefill_backlog=%d"
" generate_free_slots=%d",
self._prefill_backlog.qsize(),
my_slots.qsize())
time_of_last_print = time.time()

max_concurrent_decodes = generate_engine.max_concurrent_decodes
@@ -674,8 +709,11 @@ def _generate_thread(self, idx: int):
# Found a slot, now see if we can fill it.
except queue.Empty:
# Exit this while loop as we have no free slots to insert into.
ThreadDebugLog(thread_name, "All slots are occupied.")
break

ThreadDebugLog(thread_name, "Got an available slot.")

# We block when the decode slots are all free since we need to get a
# prefilled request to insert. We add timeout for the block to handle
# the case when the prefill backlog is cancelled and we end up with no
@@ -691,6 +729,10 @@ def _generate_thread(self, idx: int):
new_request = my_generate_backlog.get(block=block, timeout=1.0)
if new_request is None:
break
ThreadDebugLog(
thread_name,
"Got a new ActiveRequest from generate backlog %d." %
idx)
new_request.metadata.generate_dequeue_time = time.perf_counter()
if (
self._metrics_collector
@@ -711,34 +753,38 @@ def _generate_thread(self, idx: int):
except queue.Empty:
# No new requests, we can't insert, so put back slot.
my_slots.put(slot, block=False)
ThreadDebugLog(
thread_name,
"No new ActiveRequest from generate backlog %d. Put back the slot." % idx)
# If we were blocking and hit the timeout, then retry the loop.
# Otherwise, we can exit and proceed to generation.
if block:
continue
else:
break

# Signal to kill the thread.
if new_request is None:
return

logging.info(
"Generate slice %d filling slot %d at step %d.",
idx,
slot,
generate_timestep,
)

decode_state = generate_engine.insert(
new_request.prefill_result, decode_state, slot=slot
)
ThreadDebugLog(
thread_name,
"Generate slice %d filled slot %d at step %d.",
idx,
slot,
generate_timestep)

del new_request.prefill_result
new_request.generate_timestep_added = generate_timestep
new_request.complete = np.zeros(
(generate_engine.samples_per_slot,), dtype=np.bool_
)
# Respond to detokenization backpressure.

my_detokenize_backlog.put((slot, new_request), block=True)
ThreadDebugLog(
thread_name,
"Put the ActiveRequest into detokenize backlog %d. Current detokenize backlog size: %d.",
idx, my_detokenize_backlog.qsize())

# At this point, we know that we have at least some slots filled.
assert (
@@ -753,21 +799,23 @@ def _generate_thread(self, idx: int):
# Respond to detokenization backpressure.
my_detokenize_backlog.put((generate_timestep, sampled_tokens), block=True)
generate_timestep += 1
logging.info(
"Generate engine %d step %d - slots free : %d / %d, took %.2fms",
idx,
ThreadDebugLog(
thread_name,
"Step %d - slots free : %d / %d, took %.2fms",
generate_timestep,
my_slots_size,
max_concurrent_decodes,
(time.time() - time_of_last_generate) * 10**3,
)
(time.time() - time_of_last_generate) * 10**3)
time_of_last_generate = time.time()

logging.info("Generate thread %d stopped.", idx)
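
The slot bookkeeping visible in this hunk treats free decode slots as integers in a Queue: get(block=False) claims one, put() releases it when a sequence finishes. A minimal sketch (the slot count is illustrative):

import queue

max_concurrent_decodes = 4
my_slots = queue.Queue()
for i in range(max_concurrent_decodes):
  my_slots.put(i)

try:
  slot = my_slots.get(block=False)  # claim a free slot if one exists
except queue.Empty:
  slot = None  # all slots occupied; keep generating on current ones

if slot is not None:
  # ... insert a prefilled request at `slot`, then on completion:
  my_slots.put(slot, block=False)  # release; the pool always has space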

def _detokenize_thread(self, idx: int):
"""Detokenize sampled tokens and returns them to the user."""
# One of these per generate engine.
# For all filled my_slots, pop the sampled token onto the relevant
# request's return channel. If it is done, place it back onto free slots.
logging.info("Spinning up detokenize thread %d.", idx)
my_detokenize_backlog = self._detokenize_backlogs[idx]
my_generate_engine = self._generate_engines[idx]
my_slots = self._generate_slots[idx]
@@ -777,7 +825,9 @@ def _detokenize_thread(self, idx: int):
my_live_requests = {
i: None for i in range(my_generate_engine.max_concurrent_decodes)
}
thread_name = "Detokenize thread %d" % idx
while self.live:
ThreadDebugLog(thread_name, "Waiting for a detokenization task.")
data = my_detokenize_backlog.get(block=True)
if data is None:
break
@@ -787,6 +837,8 @@ def _detokenize_thread(self, idx: int):
request_first_token, request, _ = data
request_first_token = request_first_token.convert_to_numpy()

ThreadDebugLog(
thread_name, "Detokenizing the first token of a sequence.")
results, complete = token_utils.process_result_tokens(
tokenizer=tokenizer,
slot=0, # always 0 as prefill only run 1 sample
@@ -804,11 +856,12 @@ def _detokenize_thread(self, idx: int):
self._metrics_collector.get_time_to_first_token().observe(
first_token_return_time - request.metadata.prefill_dequeue_time
)
logging.info(

ThreadDebugLog(
thread_name,
"TTFT duration: %fms",
(first_token_return_time - request.metadata.prefill_dequeue_time)
* 1000,
)
* 1000)
# generate step tokens
elif isinstance(data[1], engine_api.ResultTokens):
# We want to detokenize them.
@@ -870,16 +923,18 @@ def _detokenize_thread(self, idx: int):
my_live_requests[slot] = None
my_slots.put(slot, block=False) # This should always have space.
my_generate_engine.free_resource(slot)
logging.info(
ThreadDebugLog(
thread_name,
"Detokenizing generate step %d took %.2fms",
generate_timestep_added,
(time.time() - start_detokenize_time) * 10**3,
)
(time.time() - start_detokenize_time) * 10**3)
else:
# We want to update a slot with the new channel.
slot, active_request = data
my_live_requests[slot] = active_request

logging.info("Detokenize thread %d stopped.", idx)
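
The TTFT figure logged in this hunk is plain wall time between two perf_counter() stamps recorded on the request metadata, converted to milliseconds. A sketch of the arithmetic:

import time

prefill_dequeue_time = time.perf_counter()
# ... prefill, transfer, insert, and the first generate step happen here ...
first_token_return_time = time.perf_counter()
ttft_ms = (first_token_return_time - prefill_dequeue_time) * 1000
print(f"TTFT duration: {ttft_ms:f}ms")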


class LLMOrchestrator(jetstream_pb2_grpc.OrchestratorServicer):
"""Coordinates a set of prefill and generate slices for LLM decoding."""