Commit

more granular tracing

AlexCheema committed Dec 18, 2024
1 parent db010d5 commit 165a9e1
Showing 1 changed file with 74 additions and 10 deletions.
exo/orchestration/node.py (84 changes: 74 additions & 10 deletions)
@@ -139,8 +139,27 @@ async def process_inference_result(
     is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
 
     if shard.is_last_layer() and not is_finished:
-      token = await self.inference_engine.sample(result, temp=self.default_sample_temperature)
-      forward = token.reshape(1, -1)
+      # Add span for sampling
+      with tracer.start_span(
+        "sample_token",
+        context,
+        extra_attributes={
+          "temperature": self.default_sample_temperature,
+          "result_shape": str(result.shape)
+        }
+      ):
+        token = await self.inference_engine.sample(result, temp=self.default_sample_temperature)
+
+      # Add span for tensor reshaping
+      with tracer.start_span(
+        "reshape_token",
+        context,
+        extra_attributes={
+          "input_shape": str(token.shape),
+          "target_shape": "(1, -1)"
+        }
+      ):
+        forward = token.reshape(1, -1)
 
       # Increment sequence number for next forward pass
       next_sequence = context.sequence_number + 1
@@ -152,10 +171,19 @@ async def process_inference_result(
       )
       tracer.set_context(request_id, next_context)
 
-      self.buffered_token_output[request_id][0].append(token.item())
-      is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished
-      self.trigger_on_token_callbacks(request_id, token.item(), is_finished)
-      await self.broadcast_new_token(request_id, token.item(), is_finished)
+      # Add span for token processing
+      with tracer.start_span(
+        "process_token",
+        context,
+        extra_attributes={
+          "token_value": token.item(),
+          "sequence_number": context.sequence_number
+        }
+      ):
+        self.buffered_token_output[request_id][0].append(token.item())
+        is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished
+        self.trigger_on_token_callbacks(request_id, token.item(), is_finished)
+        await self.broadcast_new_token(request_id, token.item(), is_finished)
 
     if not is_finished:
       self.outstanding_requests[request_id] = "waiting"
@@ -265,8 +293,26 @@ async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Opti
         return None
 
       self.outstanding_requests[request_id] = "processing"
-      result = await self.inference_engine.infer_prompt(request_id, shard, prompt)
-      await self.process_inference_result(shard, result, request_id)
+      # Add span for prompt inference
+      with tracer.start_span(
+        "infer_prompt",
+        context,
+        extra_attributes={
+          "prompt_length": len(prompt),
+          "shard_layers": f"{shard.start_layer}-{shard.end_layer}"
+        }
+      ):
+        result = await self.inference_engine.infer_prompt(request_id, shard, prompt)
+
+      # Add span for prompt result processing
+      with tracer.start_span(
+        "process_prompt_result",
+        context,
+        extra_attributes={
+          "result_shape": str(result.shape)
+        }
+      ):
+        await self.process_inference_result(shard, result, request_id)
     except Exception as e:
       if context.request_span:
         context.request_span.set_status(Status(StatusCode.ERROR, str(e)))
@@ -433,8 +479,26 @@ async def process_tensor(
           "tensor_shape": str(tensor.shape)
         }
       ):
-        result = await self.inference_engine.infer_tensor(request_id, base_shard, tensor)
-        await self.process_inference_result(base_shard, result, request_id)
+        # Add span for tensor inference
+        with tracer.start_span(
+          "infer_tensor",
+          context,
+          extra_attributes={
+            "input_shape": str(tensor.shape),
+            "shard_layers": f"{base_shard.start_layer}-{base_shard.end_layer}"
+          }
+        ):
+          result = await self.inference_engine.infer_tensor(request_id, base_shard, tensor)
+
+        # Add span for inference result processing
+        with tracer.start_span(
+          "process_result",
+          context,
+          extra_attributes={
+            "result_shape": str(result.shape)
+          }
+        ):
+          await self.process_inference_result(base_shard, result, request_id)
     except Exception as e:
       if request_id in self.outstanding_requests:
         self.outstanding_requests.pop(request_id)
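
Note: the diff relies on a module-level tracer whose start_span works as a context manager and which keeps a per-request TraceContext. That tracer is defined elsewhere in exo and is not part of this commit; the sketch below is a hypothetical reconstruction of a compatible API, assuming OpenTelemetry underneath (the except blocks above already use OpenTelemetry's Status and StatusCode). The TraceContext fields and the "exo" tracer name are inferred from the call sites, not taken from the actual source.

  # Hypothetical sketch, not exo's actual implementation.
  from contextlib import contextmanager
  from dataclasses import dataclass
  from typing import Optional

  from opentelemetry import trace
  from opentelemetry.trace import Span

  @dataclass
  class TraceContext:
    # Assumed fields, inferred from the call sites in the diff above.
    request_id: str
    sequence_number: int
    request_span: Optional[Span] = None

  class Tracer:
    def __init__(self):
      self.contexts: dict[str, TraceContext] = {}
      self.otel_tracer = trace.get_tracer("exo")  # tracer name is an assumption

    def set_context(self, request_id: str, context: TraceContext) -> None:
      # Remember the context so the next forward pass continues the same trace.
      self.contexts[request_id] = context

    @contextmanager
    def start_span(self, name: str, context: TraceContext, extra_attributes: Optional[dict] = None):
      # Parent each stage span on the request-level span so that sampling,
      # reshaping, and token processing all nest under one request trace.
      parent = trace.set_span_in_context(context.request_span) if context.request_span else None
      with self.otel_tracer.start_as_current_span(name, context=parent) as span:
        span.set_attribute("request_id", context.request_id)
        span.set_attribute("sequence_number", context.sequence_number)
        for key, value in (extra_attributes or {}).items():
          span.set_attribute(key, value)
        yield span

  tracer = Tracer()

With a tracer shaped like this, each with tracer.start_span(...) block in the diff opens a child span that closes automatically when the block exits, which is what makes the tracing "more granular": the request trace breaks down into one span per stage (sample_token, reshape_token, process_token, infer_prompt, and so on) instead of a single opaque request span.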
