distributed tracing
AlexCheema committed Dec 18, 2024
1 parent 023ddc2 commit db010d5
Showing 11 changed files with 791 additions and 263 deletions.
32 changes: 13 additions & 19 deletions exo/api/chatgpt_api.py
@@ -314,13 +314,13 @@ async def handle_get_download_progress(self, request):

async def handle_post_chat_completions(self, request):
data = await request.json()
if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
if DEBUG >= 2: print(f"[ChatGPTAPI] Handling chat completions request from {request.remote}: {data}")
stream = data.get("stream", False)
chat_request = parse_chat_request(data, self.default_model)
if chat_request.model and chat_request.model.startswith("gpt-"): # to be compatible with ChatGPT tools, point all gpt- model requests to default model
chat_request.model = self.default_model
if not chat_request.model or chat_request.model not in model_cards:
if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
if DEBUG >= 1: print(f"[ChatGPTAPI] Invalid model: {chat_request.model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
chat_request.model = self.default_model
shard = build_base_shard(chat_request.model, self.inference_engine_classname)
if not shard:
@@ -331,7 +331,7 @@ async def handle_post_chat_completions(self, request):
)

tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
if DEBUG >= 4: print(f"[ChatGPTAPI] Resolved tokenizer: {tokenizer}")

prompt = build_prompt(tokenizer, chat_request.messages)
request_id = str(uuid.uuid4())
@@ -340,25 +340,13 @@ async def handle_post_chat_completions(self, request):
self.on_chat_completion_request(request_id, chat_request, prompt)
except Exception as e:
if DEBUG >= 2: traceback.print_exc()
# request_id = None
# match = self.prompts.find_longest_prefix(prompt)
# if match and len(prompt) > len(match[1].prompt):
# if DEBUG >= 2:
# print(f"Prompt for request starts with previous prompt {len(match[1].prompt)} of {len(prompt)}: {match[1].prompt}")
# request_id = match[1].request_id
# self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt))
# # remove the matching prefix from the prompt
# prompt = prompt[len(match[1].prompt):]
# else:
# request_id = str(uuid.uuid4())
# self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt))

if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}")

if DEBUG >= 2: print(f"[ChatGPTAPI] Processing prompt: {request_id=} {shard=} {prompt=}")

try:
await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id))), timeout=self.response_timeout)

if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout}s")
if DEBUG >= 2: print(f"[ChatGPTAPI] Waiting for response to finish. timeout={self.response_timeout}s")

if stream:
response = web.StreamResponse(
@@ -374,10 +362,12 @@ async def handle_post_chat_completions(self, request):
try:
# Stream tokens while waiting for inference to complete
while True:
if DEBUG >= 2: print(f"[ChatGPTAPI] Waiting for token from queue: {request_id=}")
token, is_finished = await asyncio.wait_for(
self.token_queues[request_id].get(),
timeout=self.response_timeout
)
if DEBUG >= 2: print(f"[ChatGPTAPI] Got token from queue: {request_id=} {token=} {is_finished=}")

finish_reason = None
eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") else getattr(tokenizer, "eos_token_id", None)
@@ -408,10 +398,13 @@ async def handle_post_chat_completions(self, request):
return response

except asyncio.TimeoutError:
if DEBUG >= 2: print(f"[ChatGPTAPI] Timeout waiting for token: {request_id=}")
return web.json_response({"detail": "Response generation timed out"}, status=408)

except Exception as e:
if DEBUG >= 2: traceback.print_exc()
if DEBUG >= 2:
print(f"[ChatGPTAPI] Error processing prompt: {e}")
traceback.print_exc()
return web.json_response(
{"detail": f"Error processing prompt: {str(e)}"},
status=500
@@ -420,6 +413,7 @@ async def handle_post_chat_completions(self, request):
finally:
# Clean up the queue for this request
if request_id in self.token_queues:
if DEBUG >= 2: print(f"[ChatGPTAPI] Cleaning up token queue: {request_id=}")
del self.token_queues[request_id]
else:
tokens = []
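Note on the streaming path above: each request gets its own asyncio queue (self.token_queues[request_id]) that the handler drains with a timeout and deletes in the finally block. A minimal, self-contained sketch of that producer/consumer pattern — the names here (token_queues, produce_tokens, stream_tokens) are illustrative, not part of the commit:

import asyncio
from typing import Dict, List

# Hypothetical per-request queues, mirroring self.token_queues in the handler above.
token_queues: Dict[str, asyncio.Queue] = {}

async def produce_tokens(request_id: str) -> None:
    # Stand-in for the inference side: push (token, is_finished) pairs onto the queue.
    queue = token_queues.setdefault(request_id, asyncio.Queue())
    for token in (101, 102, 103):
        await queue.put((token, False))
    await queue.put((0, True))  # final item signals completion

async def stream_tokens(request_id: str, timeout: float = 5.0) -> List[int]:
    # Stand-in for the streaming response loop: drain the queue until is_finished,
    # bailing out if no token arrives within the timeout.
    queue = token_queues.setdefault(request_id, asyncio.Queue())
    received: List[int] = []
    try:
        while True:
            token, is_finished = await asyncio.wait_for(queue.get(), timeout=timeout)
            received.append(token)
            if is_finished:
                return received
    finally:
        token_queues.pop(request_id, None)  # clean up, as in the finally block above

async def demo() -> None:
    request_id = "example-request"
    producer = asyncio.create_task(produce_tokens(request_id))
    print(await stream_tokens(request_id))
    await producer

asyncio.run(demo())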
2 changes: 1 addition & 1 deletion exo/download/hf/hf_helpers.py
@@ -437,7 +437,7 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
shard_specific_patterns.add(sorted_file_names[-1])
else:
shard_specific_patterns = set(["*.safetensors"])
if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
if DEBUG >= 4: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
return list(default_patterns | shard_specific_patterns)

async def get_file_download_percentage(
56 changes: 56 additions & 0 deletions exo/main.py
@@ -38,6 +38,7 @@
import socket
import resource
import psutil
import grpc

# Configure uvloop for maximum performance
def configure_uvloop():
@@ -308,6 +309,61 @@ async def train_model_cli(node: Node, inference_engine: InferenceEngine, model_n
async def main():
loop = asyncio.get_running_loop()

# Set up OpenTelemetry
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor, SimpleSpanProcessor
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import Resource

# Check if Jaeger is available
def check_jaeger_connection():
try:
# Try to connect to the OTLP gRPC port
sock = socket.create_connection(("localhost", 4317), timeout=1)
sock.close()
return True
except (socket.timeout, socket.error):
return False

# Create and configure the tracer
resource = Resource.create({
"service.name": "exo-distributed",
"service.instance.id": args.node_id
})

tracer_provider = TracerProvider(resource=resource)

if check_jaeger_connection():
print("Jaeger connection successful, setting up tracing...")
# Configure the OTLP exporter with better defaults for high throughput
otlp_exporter = OTLPSpanExporter(
endpoint="http://localhost:4317",
# Increase timeout to handle larger batches
timeout=30.0,
)

# Configure the BatchSpanProcessor with appropriate batch settings
span_processor = BatchSpanProcessor(
otlp_exporter,
# Reduce export frequency
schedule_delay_millis=5000,
# Increase max batch size
max_export_batch_size=512,
# Limit queue size to prevent memory issues
max_queue_size=2048,
)

tracer_provider.add_span_processor(span_processor)
else:
print("Warning: Could not connect to Jaeger, tracing will be disabled")
# Use a no-op span processor if Jaeger is not available
from opentelemetry.sdk.trace.export import ConsoleSpanExporter
tracer_provider.add_span_processor(SimpleSpanProcessor(ConsoleSpanExporter()))

# Set the tracer provider
trace.set_tracer_provider(tracer_provider)

# Check HuggingFace directory permissions
hf_home, has_read, has_write = get_hf_home(), await has_hf_home_read_access(), await has_hf_home_write_access()
if DEBUG >= 1: print(f"Model storage directory: {hf_home}")
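The block above only configures and installs the tracer provider; spans still have to be started around the actual work, and, for the cross-node tracing this commit adds, their context has to be serialized so a peer can continue the trace. A minimal sketch using standard OpenTelemetry APIs — the span name, attribute, and traced_send helper are illustrative, not taken from the commit:

from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

# Stand-in for the provider configured in main() above.
trace.set_tracer_provider(TracerProvider())
tracer = trace.get_tracer("exo.example")

def traced_send(prompt: str) -> str:
    # Start a span around the outgoing call and serialize its context into a
    # W3C "traceparent" value, which is what a trace_parent field could carry.
    with tracer.start_as_current_span("send_prompt") as span:
        span.set_attribute("prompt.length", len(prompt))
        carrier: dict = {}
        TraceContextTextMapPropagator().inject(carrier)
        # carrier now looks like {"traceparent": "00-<trace_id>-<span_id>-01"}
        return carrier.get("traceparent", "")

print(traced_send("hello"))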
137 changes: 111 additions & 26 deletions exo/networking/grpc/grpc_peer_handle.py
@@ -90,34 +90,66 @@ async def health_check(self) -> bool:
traceback.print_exc()
return False

async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> None:
request = node_service_pb2.PromptRequest(
prompt=prompt,
async def send_prompt(
self,
shard: Shard,
prompt: str,
request_id: Optional[str] = None,
sequence_number: Optional[int] = None,
trace_parent: Optional[str] = None
) -> None:
request = node_service_pb2.SendPromptRequest(
shard=node_service_pb2.Shard(
model_id=shard.model_id,
start_layer=shard.start_layer,
end_layer=shard.end_layer,
n_layers=shard.n_layers,
),
prompt=prompt,
request_id=request_id,
sequence_number=sequence_number,
trace_parent=trace_parent
)
await self.stub.SendPrompt(request)

async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> None:
request = node_service_pb2.TensorRequest(
async def send_tensor(
self,
shard: Shard,
tensor: np.ndarray,
request_id: Optional[str] = None,
sequence_number: Optional[int] = None,
trace_parent: Optional[str] = None
) -> None:
request = node_service_pb2.SendTensorRequest(
shard=node_service_pb2.Shard(
model_id=shard.model_id,
start_layer=shard.start_layer,
end_layer=shard.end_layer,
n_layers=shard.n_layers,
),
tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)),
tensor=node_service_pb2.Tensor(
tensor_data=tensor.tobytes(),
shape=tensor.shape,
dtype=str(tensor.dtype)
),
request_id=request_id,
sequence_number=sequence_number,
trace_parent=trace_parent
)
await self.stub.SendTensor(request)

async def send_example(self, shard: Shard, example: np.ndarray, target: np.ndarray, length: np.ndarray, train: bool, request_id: Optional[str] = None) -> Optional[np.array]:
request = node_service_pb2.ExampleRequest(

async def send_example(
self,
shard: Shard,
example: np.ndarray,
target: np.ndarray,
length: np.ndarray,
train: bool,
request_id: Optional[str] = None,
sequence_number: Optional[int] = None,
trace_parent: Optional[str] = None
) -> Optional[np.array]:
request = node_service_pb2.SendExampleRequest(
shard=node_service_pb2.Shard(
model_id=shard.model_id,
start_layer=shard.start_layer,
@@ -129,6 +161,8 @@ async def send_example(self, shard: Shard, example: np.ndarray, target: np.ndarr
length=node_service_pb2.Tensor(tensor_data=length.tobytes(), shape=length.shape, dtype=str(length.dtype)),
train=train,
request_id=request_id,
sequence_number=sequence_number,
trace_parent=trace_parent
)
response = await self.stub.SendExample(request)
loss = response.loss
@@ -137,7 +171,7 @@ async def send_example(self, shard: Shard, example: np.ndarray, target: np.ndarr
return loss, grads
else:
return loss

async def send_loss(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.array]:
request = node_service_pb2.TensorRequest(
shard=node_service_pb2.Shard(
@@ -156,27 +190,78 @@ async def send_loss(self, shard: Shard, tensor: np.ndarray, request_id: Optional

return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)

async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
request = node_service_pb2.CollectTopologyRequest(visited=visited, max_depth=max_depth)
async def collect_topology(self, visited: set[str], max_depth: int = 4) -> Topology:
if DEBUG >= 2: print(f"[GRPCPeerHandle] Collecting topology from {self.id()} with {visited=} {max_depth=}")

# Convert set to list for GRPC request
request = node_service_pb2.CollectTopologyRequest(
visited=list(visited),
max_depth=max_depth
)

# Make GRPC call
response = await self.stub.CollectTopology(request)
if DEBUG >= 2: print(f"[GRPCPeerHandle] Got topology response from {self.id()}")

# Convert proto topology to Topology object
topology = Topology()
for node_id, capabilities in response.nodes.items():
device_capabilities = DeviceCapabilities(
model=capabilities.model,
chip=capabilities.chip,
memory=capabilities.memory,
flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8)
proto_topology = response.topology

# Convert nodes and their capabilities
for node in proto_topology.nodes:
# Convert DeviceCapabilities
flops = DeviceFlops(
fp32=node.capabilities.flops.fp32,
fp16=node.capabilities.flops.fp16,
int8=node.capabilities.flops.int8
)
capabilities = DeviceCapabilities(
model=node.capabilities.model,
chip=node.capabilities.chip,
memory=node.capabilities.memory,
flops=flops
)
topology.update_node(node_id, device_capabilities)
for node_id, peer_connections in response.peer_graph.items():
for conn in peer_connections.connections:
topology.add_edge(node_id, conn.to_id, conn.description)

# Add node to topology
topology.update_node(node.id, capabilities)

# Add connections
for conn in node.connections:
topology.add_edge(node.id, conn.to_id, conn.description if conn.HasField("description") else None)

# Set active node
if proto_topology.HasField("active_node_id"):
topology.active_node_id = proto_topology.active_node_id

if DEBUG >= 2: print(f"[GRPCPeerHandle] Converted topology from {self.id()} with {len(topology.nodes)} nodes")
return topology

async def send_new_token(self, request_id: str, token: int, is_finished: bool) -> None:
request = node_service_pb2.SendNewTokenRequest(request_id=request_id, token=token, is_finished=is_finished)
async def send_new_token(
self,
request_id: str,
token: int,
is_finished: bool,
sequence_number: Optional[int] = None,
trace_parent: Optional[str] = None
) -> None:
request = node_service_pb2.SendNewTokenRequest(
request_id=request_id,
token=token,
is_finished=is_finished,
sequence_number=sequence_number,
trace_parent=trace_parent
)
await self.stub.SendNewToken(request)

async def send_opaque_status(self, request_id: str, status: str) -> None:
request = node_service_pb2.SendOpaqueStatusRequest(request_id=request_id, status=status)
async def send_opaque_status(
self,
request_id: str,
status: str,
trace_parent: Optional[str] = None
) -> None:
request = node_service_pb2.SendOpaqueStatusRequest(
request_id=request_id,
status=status,
trace_parent=trace_parent
)
await self.stub.SendOpaqueStatus(request)
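
On the receiving side, the trace_parent string carried by these requests would typically be turned back into a context so the handler's span joins the caller's trace rather than starting a new root. A minimal sketch of that extraction, again assuming standard OpenTelemetry propagation — the handler name and attribute are illustrative:

from typing import Optional

from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

trace.set_tracer_provider(TracerProvider())  # stand-in for the provider set up in main()
tracer = trace.get_tracer("exo.example")

def handle_send_prompt(prompt: str, trace_parent: Optional[str]) -> None:
    # Rebuild the caller's context from the W3C traceparent string (if one was sent)
    # so this span becomes a child of the sender's span.
    ctx = None
    if trace_parent:
        ctx = TraceContextTextMapPropagator().extract({"traceparent": trace_parent})
    with tracer.start_as_current_span("handle_send_prompt", context=ctx) as span:
        span.set_attribute("prompt.length", len(prompt))
        # ... hand the prompt to the local node here ...

handle_send_prompt("hello", "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01")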