From db010d51fb1d119c02b1f55966034c62296e9eba Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Wed, 18 Dec 2024 00:01:14 +0000 Subject: [PATCH] distributed tracing --- exo/api/chatgpt_api.py | 32 +- exo/download/hf/hf_helpers.py | 2 +- exo/main.py | 56 ++++ exo/networking/grpc/grpc_peer_handle.py | 137 ++++++-- exo/networking/grpc/grpc_server.py | 147 +++++++-- exo/networking/grpc/node_service.proto | 84 +++-- exo/networking/grpc/node_service_pb2.py | 80 +++-- exo/networking/grpc/node_service_pb2_grpc.py | 30 +- exo/orchestration/node.py | 316 +++++++++++++------ exo/orchestration/tracing.py | 166 ++++++++++ setup.py | 4 + 11 files changed, 791 insertions(+), 263 deletions(-) create mode 100644 exo/orchestration/tracing.py diff --git a/exo/api/chatgpt_api.py b/exo/api/chatgpt_api.py index 5bc9fb963..4a4fb9e6b 100644 --- a/exo/api/chatgpt_api.py +++ b/exo/api/chatgpt_api.py @@ -314,13 +314,13 @@ async def handle_get_download_progress(self, request): async def handle_post_chat_completions(self, request): data = await request.json() - if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}") + if DEBUG >= 2: print(f"[ChatGPTAPI] Handling chat completions request from {request.remote}: {data}") stream = data.get("stream", False) chat_request = parse_chat_request(data, self.default_model) if chat_request.model and chat_request.model.startswith("gpt-"): # to be compatible with ChatGPT tools, point all gpt- model requests to default model chat_request.model = self.default_model if not chat_request.model or chat_request.model not in model_cards: - if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}") + if DEBUG >= 1: print(f"[ChatGPTAPI] Invalid model: {chat_request.model}. Supported: {list(model_cards.keys())}. 
Defaulting to {self.default_model}") chat_request.model = self.default_model shard = build_base_shard(chat_request.model, self.inference_engine_classname) if not shard: @@ -331,7 +331,7 @@ async def handle_post_chat_completions(self, request): ) tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname)) - if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}") + if DEBUG >= 4: print(f"[ChatGPTAPI] Resolved tokenizer: {tokenizer}") prompt = build_prompt(tokenizer, chat_request.messages) request_id = str(uuid.uuid4()) @@ -340,25 +340,13 @@ async def handle_post_chat_completions(self, request): self.on_chat_completion_request(request_id, chat_request, prompt) except Exception as e: if DEBUG >= 2: traceback.print_exc() - # request_id = None - # match = self.prompts.find_longest_prefix(prompt) - # if match and len(prompt) > len(match[1].prompt): - # if DEBUG >= 2: - # print(f"Prompt for request starts with previous prompt {len(match[1].prompt)} of {len(prompt)}: {match[1].prompt}") - # request_id = match[1].request_id - # self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt)) - # # remove the matching prefix from the prompt - # prompt = prompt[len(match[1].prompt):] - # else: - # request_id = str(uuid.uuid4()) - # self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt)) - - if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}") + + if DEBUG >= 2: print(f"[ChatGPTAPI] Processing prompt: {request_id=} {shard=} {prompt=}") try: await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id))), timeout=self.response_timeout) - if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout}s") + if DEBUG >= 2: print(f"[ChatGPTAPI] Waiting for response to finish. 
timeout={self.response_timeout}s") if stream: response = web.StreamResponse( @@ -374,10 +362,12 @@ async def handle_post_chat_completions(self, request): try: # Stream tokens while waiting for inference to complete while True: + if DEBUG >= 2: print(f"[ChatGPTAPI] Waiting for token from queue: {request_id=}") token, is_finished = await asyncio.wait_for( self.token_queues[request_id].get(), timeout=self.response_timeout ) + if DEBUG >= 2: print(f"[ChatGPTAPI] Got token from queue: {request_id=} {token=} {is_finished=}") finish_reason = None eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") else getattr(tokenizer, "eos_token_id", None) @@ -408,10 +398,13 @@ async def handle_post_chat_completions(self, request): return response except asyncio.TimeoutError: + if DEBUG >= 2: print(f"[ChatGPTAPI] Timeout waiting for token: {request_id=}") return web.json_response({"detail": "Response generation timed out"}, status=408) except Exception as e: - if DEBUG >= 2: traceback.print_exc() + if DEBUG >= 2: + print(f"[ChatGPTAPI] Error processing prompt: {e}") + traceback.print_exc() return web.json_response( {"detail": f"Error processing prompt: {str(e)}"}, status=500 @@ -420,6 +413,7 @@ async def handle_post_chat_completions(self, request): finally: # Clean up the queue for this request if request_id in self.token_queues: + if DEBUG >= 2: print(f"[ChatGPTAPI] Cleaning up token queue: {request_id=}") del self.token_queues[request_id] else: tokens = [] diff --git a/exo/download/hf/hf_helpers.py b/exo/download/hf/hf_helpers.py index d248dd373..b73cea616 100644 --- a/exo/download/hf/hf_helpers.py +++ b/exo/download/hf/hf_helpers.py @@ -437,7 +437,7 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]: shard_specific_patterns.add(sorted_file_names[-1]) else: shard_specific_patterns = set(["*.safetensors"]) - if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}") + if DEBUG >= 4: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}") return list(default_patterns | shard_specific_patterns) async def get_file_download_percentage( diff --git a/exo/main.py b/exo/main.py index 677fb294c..9d4110b6a 100644 --- a/exo/main.py +++ b/exo/main.py @@ -38,6 +38,7 @@ import socket import resource import psutil +import grpc # Configure uvloop for maximum performance def configure_uvloop(): @@ -308,6 +309,61 @@ async def train_model_cli(node: Node, inference_engine: InferenceEngine, model_n async def main(): loop = asyncio.get_running_loop() + # Set up OpenTelemetry + from opentelemetry import trace + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export import BatchSpanProcessor, SimpleSpanProcessor + from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter + from opentelemetry.sdk.resources import Resource + + # Check if Jaeger is available + def check_jaeger_connection(): + try: + # Try to connect to the OTLP gRPC port + sock = socket.create_connection(("localhost", 4317), timeout=1) + sock.close() + return True + except (socket.timeout, socket.error): + return False + + # Create and configure the tracer + resource = Resource.create({ + "service.name": "exo-distributed", + "service.instance.id": args.node_id + }) + + tracer_provider = TracerProvider(resource=resource) + + if check_jaeger_connection(): + print("Jaeger connection successful, setting up tracing...") + # Configure the OTLP exporter with better defaults for high 
throughput + otlp_exporter = OTLPSpanExporter( + endpoint="http://localhost:4317", + # Increase timeout to handle larger batches + timeout=30.0, + ) + + # Configure the BatchSpanProcessor with appropriate batch settings + span_processor = BatchSpanProcessor( + otlp_exporter, + # Reduce export frequency + schedule_delay_millis=5000, + # Increase max batch size + max_export_batch_size=512, + # Limit queue size to prevent memory issues + max_queue_size=2048, + ) + + tracer_provider.add_span_processor(span_processor) + else: + print("Warning: Could not connect to Jaeger, tracing will be disabled") + # Use a no-op span processor if Jaeger is not available + from opentelemetry.sdk.trace.export import ConsoleSpanExporter + tracer_provider.add_span_processor(SimpleSpanProcessor(ConsoleSpanExporter())) + + # Set the tracer provider + trace.set_tracer_provider(tracer_provider) + # Check HuggingFace directory permissions hf_home, has_read, has_write = get_hf_home(), await has_hf_home_read_access(), await has_hf_home_write_access() if DEBUG >= 1: print(f"Model storage directory: {hf_home}") diff --git a/exo/networking/grpc/grpc_peer_handle.py b/exo/networking/grpc/grpc_peer_handle.py index f0ef31db3..e94a25eb6 100644 --- a/exo/networking/grpc/grpc_peer_handle.py +++ b/exo/networking/grpc/grpc_peer_handle.py @@ -90,34 +90,66 @@ async def health_check(self) -> bool: traceback.print_exc() return False - async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> None: - request = node_service_pb2.PromptRequest( - prompt=prompt, + async def send_prompt( + self, + shard: Shard, + prompt: str, + request_id: Optional[str] = None, + sequence_number: Optional[int] = None, + trace_parent: Optional[str] = None + ) -> None: + request = node_service_pb2.SendPromptRequest( shard=node_service_pb2.Shard( model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers, ), + prompt=prompt, request_id=request_id, + sequence_number=sequence_number, + trace_parent=trace_parent ) await self.stub.SendPrompt(request) - async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> None: - request = node_service_pb2.TensorRequest( + async def send_tensor( + self, + shard: Shard, + tensor: np.ndarray, + request_id: Optional[str] = None, + sequence_number: Optional[int] = None, + trace_parent: Optional[str] = None + ) -> None: + request = node_service_pb2.SendTensorRequest( shard=node_service_pb2.Shard( model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers, ), - tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)), + tensor=node_service_pb2.Tensor( + tensor_data=tensor.tobytes(), + shape=tensor.shape, + dtype=str(tensor.dtype) + ), request_id=request_id, + sequence_number=sequence_number, + trace_parent=trace_parent ) await self.stub.SendTensor(request) - - async def send_example(self, shard: Shard, example: np.ndarray, target: np.ndarray, length: np.ndarray, train: bool, request_id: Optional[str] = None) -> Optional[np.array]: - request = node_service_pb2.ExampleRequest( + + async def send_example( + self, + shard: Shard, + example: np.ndarray, + target: np.ndarray, + length: np.ndarray, + train: bool, + request_id: Optional[str] = None, + sequence_number: Optional[int] = None, + trace_parent: Optional[str] = None + ) -> Optional[np.array]: + request = node_service_pb2.SendExampleRequest( 
shard=node_service_pb2.Shard( model_id=shard.model_id, start_layer=shard.start_layer, @@ -129,6 +161,8 @@ async def send_example(self, shard: Shard, example: np.ndarray, target: np.ndarr length=node_service_pb2.Tensor(tensor_data=length.tobytes(), shape=length.shape, dtype=str(length.dtype)), train=train, request_id=request_id, + sequence_number=sequence_number, + trace_parent=trace_parent ) response = await self.stub.SendExample(request) loss = response.loss @@ -137,7 +171,7 @@ async def send_example(self, shard: Shard, example: np.ndarray, target: np.ndarr return loss, grads else: return loss - + async def send_loss(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.array]: request = node_service_pb2.TensorRequest( shard=node_service_pb2.Shard( @@ -156,27 +190,78 @@ async def send_loss(self, shard: Shard, tensor: np.ndarray, request_id: Optional return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape) - async def collect_topology(self, visited: set[str], max_depth: int) -> Topology: - request = node_service_pb2.CollectTopologyRequest(visited=visited, max_depth=max_depth) + async def collect_topology(self, visited: set[str], max_depth: int = 4) -> Topology: + if DEBUG >= 2: print(f"[GRPCPeerHandle] Collecting topology from {self.id()} with {visited=} {max_depth=}") + + # Convert set to list for GRPC request + request = node_service_pb2.CollectTopologyRequest( + visited=list(visited), + max_depth=max_depth + ) + + # Make GRPC call response = await self.stub.CollectTopology(request) + if DEBUG >= 2: print(f"[GRPCPeerHandle] Got topology response from {self.id()}") + + # Convert proto topology to Topology object topology = Topology() - for node_id, capabilities in response.nodes.items(): - device_capabilities = DeviceCapabilities( - model=capabilities.model, - chip=capabilities.chip, - memory=capabilities.memory, - flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8) + proto_topology = response.topology + + # Convert nodes and their capabilities + for node in proto_topology.nodes: + # Convert DeviceCapabilities + flops = DeviceFlops( + fp32=node.capabilities.flops.fp32, + fp16=node.capabilities.flops.fp16, + int8=node.capabilities.flops.int8 + ) + capabilities = DeviceCapabilities( + model=node.capabilities.model, + chip=node.capabilities.chip, + memory=node.capabilities.memory, + flops=flops ) - topology.update_node(node_id, device_capabilities) - for node_id, peer_connections in response.peer_graph.items(): - for conn in peer_connections.connections: - topology.add_edge(node_id, conn.to_id, conn.description) + + # Add node to topology + topology.update_node(node.id, capabilities) + + # Add connections + for conn in node.connections: + topology.add_edge(node.id, conn.to_id, conn.description if conn.HasField("description") else None) + + # Set active node + if proto_topology.HasField("active_node_id"): + topology.active_node_id = proto_topology.active_node_id + + if DEBUG >= 2: print(f"[GRPCPeerHandle] Converted topology from {self.id()} with {len(topology.nodes)} nodes") return topology - async def send_new_token(self, request_id: str, token: int, is_finished: bool) -> None: - request = node_service_pb2.SendNewTokenRequest(request_id=request_id, token=token, is_finished=is_finished) + async def send_new_token( + self, + request_id: str, + token: int, + is_finished: bool, + sequence_number: Optional[int] = None, + trace_parent: Optional[str] = None + ) -> 
None: + request = node_service_pb2.SendNewTokenRequest( + request_id=request_id, + token=token, + is_finished=is_finished, + sequence_number=sequence_number, + trace_parent=trace_parent + ) await self.stub.SendNewToken(request) - async def send_opaque_status(self, request_id: str, status: str) -> None: - request = node_service_pb2.SendOpaqueStatusRequest(request_id=request_id, status=status) + async def send_opaque_status( + self, + request_id: str, + status: str, + trace_parent: Optional[str] = None + ) -> None: + request = node_service_pb2.SendOpaqueStatusRequest( + request_id=request_id, + status=status, + trace_parent=trace_parent + ) await self.stub.SendOpaqueStatus(request) diff --git a/exo/networking/grpc/grpc_server.py b/exo/networking/grpc/grpc_server.py index ec37d768d..72e24542b 100644 --- a/exo/networking/grpc/grpc_server.py +++ b/exo/networking/grpc/grpc_server.py @@ -58,8 +58,21 @@ async def SendPrompt(self, request, context): ) prompt = request.prompt request_id = request.request_id + sequence_number = request.sequence_number if hasattr(request, 'sequence_number') else None + trace_parent = request.trace_parent if hasattr(request, 'trace_parent') else None + + # Update trace context if sequence number or trace parent is provided + if sequence_number is not None or trace_parent is not None: + from exo.orchestration.tracing import tracer, TraceContext + context = TraceContext( + request_id=request_id, + sequence_number=sequence_number or 0, + trace_parent=trace_parent + ) + tracer.set_context(request_id, context) + await self.node.process_prompt(shard, prompt, request_id) - if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=}") + if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=} {sequence_number=}") return node_service_pb2.Empty() async def SendTensor(self, request, context): @@ -71,8 +84,21 @@ async def SendTensor(self, request, context): ) tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape) request_id = request.request_id + sequence_number = request.sequence_number if hasattr(request, 'sequence_number') else None + trace_parent = request.trace_parent if hasattr(request, 'trace_parent') else None + + # Update trace context if sequence number or trace parent is provided + if sequence_number is not None or trace_parent is not None: + from exo.orchestration.tracing import tracer, TraceContext + context = TraceContext( + request_id=request_id, + sequence_number=sequence_number or 0, + trace_parent=trace_parent + ) + tracer.set_context(request_id, context) + await self.node.process_tensor(shard, tensor, request_id) - if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=}") + if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} {sequence_number=}") return node_service_pb2.Empty() async def SendExample(self, request, context): @@ -87,6 +113,18 @@ async def SendExample(self, request, context): length = np.frombuffer(request.length.tensor_data, dtype=np.dtype(request.length.dtype)).reshape(request.length.shape) train = request.train request_id = request.request_id + sequence_number = request.sequence_number if hasattr(request, 'sequence_number') else None + trace_parent = request.trace_parent if hasattr(request, 'trace_parent') else None + + # Update trace context if sequence number or trace parent is provided + if sequence_number is not None or trace_parent is not None: + from exo.orchestration.tracing import tracer, TraceContext + context = 
TraceContext( + request_id=request_id, + sequence_number=sequence_number or 0, + trace_parent=trace_parent + ) + tracer.set_context(request_id, context) if train and not shard.is_first_layer(): loss, grad = await self.node.process_example(shard, example, target, length, train, request_id) @@ -97,43 +135,100 @@ async def SendExample(self, request, context): loss = await self.node.process_example(shard, example, target, length, train, request_id) return node_service_pb2.Loss(loss=loss, grads=None) - async def CollectTopology(self, request, context): - max_depth = request.max_depth + async def CollectTopology( + self, + request: node_service_pb2.CollectTopologyRequest, + context: grpc.aio.ServicerContext, + ) -> node_service_pb2.CollectTopologyResponse: + # Convert visited list back to set visited = set(request.visited) - topology = self.node.current_topology - nodes = { - node_id: - node_service_pb2.DeviceCapabilities( - model=cap.model, - chip=cap.chip, - memory=cap.memory, - flops=node_service_pb2.DeviceFlops(fp32=cap.flops.fp32, fp16=cap.flops.fp16, int8=cap.flops.int8), - ) - for node_id, cap in topology.nodes.items() - } - peer_graph = { - node_id: node_service_pb2.PeerConnections( - connections=[ - node_service_pb2.PeerConnection(to_id=conn.to_id, description=conn.description) - for conn in connections - ] + if DEBUG >= 2: print(f"[GRPCServer] CollectTopology request with {visited=} {request.max_depth=}") + + # Get topology from node + topology = await self.node.collect_topology(visited, request.max_depth) + if DEBUG >= 2: print(f"[GRPCServer] Got topology: {topology}") + + # Convert Topology to proto message + proto_topology = node_service_pb2.CollectTopologyResponse.Topology() + + # Convert nodes and their capabilities + for node_id, capabilities in topology.nodes.items(): + # Create DeviceFlops + flops = node_service_pb2.CollectTopologyResponse.DeviceFlops( + fp32=capabilities.flops.fp32, + fp16=capabilities.flops.fp16, + int8=capabilities.flops.int8 ) - for node_id, connections in topology.peer_graph.items() - } - if DEBUG >= 5: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}") - return node_service_pb2.Topology(nodes=nodes, peer_graph=peer_graph) + + # Create DeviceCapabilities + device_caps = node_service_pb2.CollectTopologyResponse.DeviceCapabilities( + model=capabilities.model, + chip=capabilities.chip, + memory=capabilities.memory, + flops=flops + ) + + # Get connections for this node + connections = [] + if node_id in topology.peer_graph: + for conn in topology.peer_graph[node_id]: + proto_conn = node_service_pb2.CollectTopologyResponse.PeerConnection( + to_id=conn.to_id, + description=conn.description if conn.description else None + ) + connections.append(proto_conn) + + # Create Node with its connections + node = node_service_pb2.CollectTopologyResponse.Node( + id=node_id, + capabilities=device_caps, + connections=connections + ) + proto_topology.nodes.append(node) + + # Set active node if present + if topology.active_node_id: + proto_topology.active_node_id = topology.active_node_id + + if DEBUG >= 2: print(f"[GRPCServer] Sending topology response with {len(proto_topology.nodes)} nodes") + return node_service_pb2.CollectTopologyResponse(topology=proto_topology) async def SendNewToken(self, request, context): request_id = request.request_id token = request.token is_finished = request.is_finished - if DEBUG >= 5: print(f"Received SendNewToken request: {request_id=} {token=} {is_finished=}") + sequence_number = request.sequence_number if 
hasattr(request, 'sequence_number') else None + trace_parent = request.trace_parent if hasattr(request, 'trace_parent') else None + + # Update trace context if sequence number or trace parent is provided + if sequence_number is not None or trace_parent is not None: + from exo.orchestration.tracing import tracer, TraceContext + context = TraceContext( + request_id=request_id, + sequence_number=sequence_number or 0, + trace_parent=trace_parent + ) + tracer.set_context(request_id, context) + + if DEBUG >= 5: print(f"Received SendNewToken request: {request_id=} {token=} {is_finished=} {sequence_number=}") self.node.on_token.trigger_all(request_id, token, is_finished) return node_service_pb2.Empty() async def SendOpaqueStatus(self, request, context): request_id = request.request_id status = request.status + trace_parent = request.trace_parent if hasattr(request, 'trace_parent') else None + + # Update trace context if trace parent is provided + if trace_parent is not None: + from exo.orchestration.tracing import tracer, TraceContext + context = TraceContext( + request_id=request_id, + sequence_number=0, + trace_parent=trace_parent + ) + tracer.set_context(request_id, context) + if DEBUG >= 8: print(f"Received SendOpaqueStatus request: {request_id=} {status=}") self.node.on_opaque_status.trigger_all(request_id, status) return node_service_pb2.Empty() diff --git a/exo/networking/grpc/node_service.proto b/exo/networking/grpc/node_service.proto index b99f5e665..8c5a8207e 100644 --- a/exo/networking/grpc/node_service.proto +++ b/exo/networking/grpc/node_service.proto @@ -3,10 +3,10 @@ syntax = "proto3"; package node_service; service NodeService { - rpc SendPrompt (PromptRequest) returns (Empty) {} - rpc SendTensor (TensorRequest) returns (Empty) {} - rpc SendExample (ExampleRequest) returns (Loss) {} - rpc CollectTopology (CollectTopologyRequest) returns (Topology) {} + rpc SendPrompt (SendPromptRequest) returns (Empty) {} + rpc SendTensor (SendTensorRequest) returns (Empty) {} + rpc SendExample (SendExampleRequest) returns (Empty) {} + rpc CollectTopology (CollectTopologyRequest) returns (CollectTopologyResponse) {} rpc SendNewToken (SendNewTokenRequest) returns (Empty) {} rpc SendOpaqueStatus (SendOpaqueStatusRequest) returns (Empty) {} rpc HealthCheck (HealthCheckRequest) returns (HealthCheckResponse) {} @@ -19,25 +19,30 @@ message Shard { int32 n_layers = 4; } -message PromptRequest { +message SendPromptRequest { Shard shard = 1; string prompt = 2; - optional string request_id = 3; + string request_id = 3; + int32 sequence_number = 4; + string trace_parent = 5; } -message TensorRequest { +message SendTensorRequest { Shard shard = 1; Tensor tensor = 2; - optional string request_id = 3; + string request_id = 3; + int32 sequence_number = 4; + string trace_parent = 5; } -message ExampleRequest { +message SendExampleRequest { Shard shard = 1; - Tensor example = 2; - Tensor target = 3; - Tensor length = 4; - bool train = 5; - optional string request_id = 6; + bytes example = 2; + bytes target = 3; + bytes length = 4; + string request_id = 5; + bool train = 6; + string trace_parent = 7; } message Loss { @@ -56,42 +61,51 @@ message CollectTopologyRequest { int32 max_depth = 2; } -message Topology { - map nodes = 1; - map peer_graph = 2; -} +message CollectTopologyResponse { + message DeviceFlops { + double fp32 = 1; + double fp16 = 2; + double int8 = 3; + } -message PeerConnection { - string to_id = 1; - optional string description = 2; -} + message DeviceCapabilities { + string model = 1; + string chip 
= 2; + int32 memory = 3; + DeviceFlops flops = 4; + } -message PeerConnections { - repeated PeerConnection connections = 1; -} + message PeerConnection { + string to_id = 1; + optional string description = 2; + } -message DeviceFlops { - double fp32 = 1; - double fp16 = 2; - double int8 = 3; -} + message Node { + string id = 1; + DeviceCapabilities capabilities = 2; + repeated PeerConnection connections = 3; + } + + message Topology { + repeated Node nodes = 1; + optional string active_node_id = 2; + } -message DeviceCapabilities { - string model = 1; - string chip = 2; - int32 memory = 3; - DeviceFlops flops = 4; + Topology topology = 1; } message SendNewTokenRequest { string request_id = 1; int32 token = 2; bool is_finished = 3; + int32 sequence_number = 4; + string trace_parent = 5; } message SendOpaqueStatusRequest { string request_id = 1; string status = 2; + string trace_parent = 3; } message HealthCheckRequest {} diff --git a/exo/networking/grpc/node_service_pb2.py b/exo/networking/grpc/node_service_pb2.py index 7379eb69c..6d2c7d1a3 100644 --- a/exo/networking/grpc/node_service_pb2.py +++ b/exo/networking/grpc/node_service_pb2.py @@ -24,55 +24,49 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"k\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\r\n\x0b_request_id\"\x81\x01\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\r\n\x0b_request_id\"\xde\x01\n\x0e\x45xampleRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12%\n\x07\x65xample\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12$\n\x06target\x18\x03 \x01(\x0b\x32\x14.node_service.Tensor\x12$\n\x06length\x18\x04 \x01(\x0b\x32\x14.node_service.Tensor\x12\r\n\x05train\x18\x05 \x01(\x08\x12\x17\n\nrequest_id\x18\x06 \x01(\tH\x00\x88\x01\x01\x42\r\n\x0b_request_id\"H\n\x04Loss\x12\x0c\n\x04loss\x18\x01 \x01(\x02\x12(\n\x05grads\x18\x02 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x42\x08\n\x06_grads\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"<\n\x16\x43ollectTopologyRequest\x12\x0f\n\x07visited\x18\x01 \x03(\t\x12\x11\n\tmax_depth\x18\x02 \x01(\x05\"\x98\x02\n\x08Topology\x12\x30\n\x05nodes\x18\x01 \x03(\x0b\x32!.node_service.Topology.NodesEntry\x12\x39\n\npeer_graph\x18\x02 \x03(\x0b\x32%.node_service.Topology.PeerGraphEntry\x1aN\n\nNodesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .node_service.DeviceCapabilities:\x02\x38\x01\x1aO\n\x0ePeerGraphEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12,\n\x05value\x18\x02 \x01(\x0b\x32\x1d.node_service.PeerConnections:\x02\x38\x01\"I\n\x0ePeerConnection\x12\r\n\x05to_id\x18\x01 \x01(\t\x12\x18\n\x0b\x64\x65scription\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x0e\n\x0c_description\"D\n\x0fPeerConnections\x12\x31\n\x0b\x63onnections\x18\x01 \x03(\x0b\x32\x1c.node_service.PeerConnection\"7\n\x0b\x44\x65viceFlops\x12\x0c\n\x04\x66p32\x18\x01 \x01(\x01\x12\x0c\n\x04\x66p16\x18\x02 \x01(\x01\x12\x0c\n\x04int8\x18\x03 
\x01(\x01\"k\n\x12\x44\x65viceCapabilities\x12\r\n\x05model\x18\x01 \x01(\t\x12\x0c\n\x04\x63hip\x18\x02 \x01(\t\x12\x0e\n\x06memory\x18\x03 \x01(\x05\x12(\n\x05\x66lops\x18\x04 \x01(\x0b\x32\x19.node_service.DeviceFlops\"M\n\x13SendNewTokenRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\r\n\x05token\x18\x02 \x01(\x05\x12\x13\n\x0bis_finished\x18\x03 \x01(\x08\"=\n\x17SendOpaqueStatusRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\t\"\x14\n\x12HealthCheckRequest\")\n\x13HealthCheckResponse\x12\x12\n\nis_healthy\x18\x01 \x01(\x08\"\x07\n\x05\x45mpty2\x99\x04\n\x0bNodeService\x12@\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x13.node_service.Empty\"\x00\x12@\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x13.node_service.Empty\"\x00\x12\x41\n\x0bSendExample\x12\x1c.node_service.ExampleRequest\x1a\x12.node_service.Loss\"\x00\x12Q\n\x0f\x43ollectTopology\x12$.node_service.CollectTopologyRequest\x1a\x16.node_service.Topology\"\x00\x12H\n\x0cSendNewToken\x12!.node_service.SendNewTokenRequest\x1a\x13.node_service.Empty\"\x00\x12P\n\x10SendOpaqueStatus\x12%.node_service.SendOpaqueStatusRequest\x1a\x13.node_service.Empty\"\x00\x12T\n\x0bHealthCheck\x12 .node_service.HealthCheckRequest\x1a!.node_service.HealthCheckResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"\x8a\x01\n\x11SendPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x12\n\nrequest_id\x18\x03 \x01(\t\x12\x17\n\x0fsequence_number\x18\x04 \x01(\x05\x12\x14\n\x0ctrace_parent\x18\x05 \x01(\t\"\xa0\x01\n\x11SendTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12\x12\n\nrequest_id\x18\x03 \x01(\t\x12\x17\n\x0fsequence_number\x18\x04 \x01(\x05\x12\x14\n\x0ctrace_parent\x18\x05 \x01(\t\"\xa2\x01\n\x12SendExampleRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0f\n\x07\x65xample\x18\x02 \x01(\x0c\x12\x0e\n\x06target\x18\x03 \x01(\x0c\x12\x0e\n\x06length\x18\x04 \x01(\x0c\x12\x12\n\nrequest_id\x18\x05 \x01(\t\x12\r\n\x05train\x18\x06 \x01(\x08\x12\x14\n\x0ctrace_parent\x18\x07 \x01(\t\"H\n\x04Loss\x12\x0c\n\x04loss\x18\x01 \x01(\x02\x12(\n\x05grads\x18\x02 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x42\x08\n\x06_grads\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"<\n\x16\x43ollectTopologyRequest\x12\x0f\n\x07visited\x18\x01 \x03(\t\x12\x11\n\tmax_depth\x18\x02 \x01(\x05\"\x8c\x05\n\x17\x43ollectTopologyResponse\x12@\n\x08topology\x18\x01 \x01(\x0b\x32..node_service.CollectTopologyResponse.Topology\x1a\x37\n\x0b\x44\x65viceFlops\x12\x0c\n\x04\x66p32\x18\x01 \x01(\x01\x12\x0c\n\x04\x66p16\x18\x02 \x01(\x01\x12\x0c\n\x04int8\x18\x03 \x01(\x01\x1a\x83\x01\n\x12\x44\x65viceCapabilities\x12\r\n\x05model\x18\x01 \x01(\t\x12\x0c\n\x04\x63hip\x18\x02 \x01(\t\x12\x0e\n\x06memory\x18\x03 \x01(\x05\x12@\n\x05\x66lops\x18\x04 \x01(\x0b\x32\x31.node_service.CollectTopologyResponse.DeviceFlops\x1aI\n\x0ePeerConnection\x12\r\n\x05to_id\x18\x01 \x01(\t\x12\x18\n\x0b\x64\x65scription\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x0e\n\x0c_description\x1a\xad\x01\n\x04Node\x12\n\n\x02id\x18\x01 
\x01(\t\x12N\n\x0c\x63\x61pabilities\x18\x02 \x01(\x0b\x32\x38.node_service.CollectTopologyResponse.DeviceCapabilities\x12I\n\x0b\x63onnections\x18\x03 \x03(\x0b\x32\x34.node_service.CollectTopologyResponse.PeerConnection\x1au\n\x08Topology\x12\x39\n\x05nodes\x18\x01 \x03(\x0b\x32*.node_service.CollectTopologyResponse.Node\x12\x1b\n\x0e\x61\x63tive_node_id\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x11\n\x0f_active_node_id\"|\n\x13SendNewTokenRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\r\n\x05token\x18\x02 \x01(\x05\x12\x13\n\x0bis_finished\x18\x03 \x01(\x08\x12\x17\n\x0fsequence_number\x18\x04 \x01(\x05\x12\x14\n\x0ctrace_parent\x18\x05 \x01(\t\"S\n\x17SendOpaqueStatusRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\t\x12\x14\n\x0ctrace_parent\x18\x03 \x01(\t\"\x14\n\x12HealthCheckRequest\")\n\x13HealthCheckResponse\x12\x12\n\nis_healthy\x18\x01 \x01(\x08\"\x07\n\x05\x45mpty2\xb5\x04\n\x0bNodeService\x12\x44\n\nSendPrompt\x12\x1f.node_service.SendPromptRequest\x1a\x13.node_service.Empty\"\x00\x12\x44\n\nSendTensor\x12\x1f.node_service.SendTensorRequest\x1a\x13.node_service.Empty\"\x00\x12\x46\n\x0bSendExample\x12 .node_service.SendExampleRequest\x1a\x13.node_service.Empty\"\x00\x12`\n\x0f\x43ollectTopology\x12$.node_service.CollectTopologyRequest\x1a%.node_service.CollectTopologyResponse\"\x00\x12H\n\x0cSendNewToken\x12!.node_service.SendNewTokenRequest\x1a\x13.node_service.Empty\"\x00\x12P\n\x10SendOpaqueStatus\x12%.node_service.SendOpaqueStatusRequest\x1a\x13.node_service.Empty\"\x00\x12T\n\x0bHealthCheck\x12 .node_service.HealthCheckRequest\x1a!.node_service.HealthCheckResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'node_service_pb2', _globals) if not _descriptor._USE_C_DESCRIPTORS: DESCRIPTOR._loaded_options = None - _globals['_TOPOLOGY_NODESENTRY']._loaded_options = None - _globals['_TOPOLOGY_NODESENTRY']._serialized_options = b'8\001' - _globals['_TOPOLOGY_PEERGRAPHENTRY']._loaded_options = None - _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_options = b'8\001' _globals['_SHARD']._serialized_start=36 _globals['_SHARD']._serialized_end=119 - _globals['_PROMPTREQUEST']._serialized_start=121 - _globals['_PROMPTREQUEST']._serialized_end=228 - _globals['_TENSORREQUEST']._serialized_start=231 - _globals['_TENSORREQUEST']._serialized_end=360 - _globals['_EXAMPLEREQUEST']._serialized_start=363 - _globals['_EXAMPLEREQUEST']._serialized_end=585 - _globals['_LOSS']._serialized_start=587 - _globals['_LOSS']._serialized_end=659 - _globals['_TENSOR']._serialized_start=661 - _globals['_TENSOR']._serialized_end=720 - _globals['_COLLECTTOPOLOGYREQUEST']._serialized_start=722 - _globals['_COLLECTTOPOLOGYREQUEST']._serialized_end=782 - _globals['_TOPOLOGY']._serialized_start=785 - _globals['_TOPOLOGY']._serialized_end=1065 - _globals['_TOPOLOGY_NODESENTRY']._serialized_start=906 - _globals['_TOPOLOGY_NODESENTRY']._serialized_end=984 - _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_start=986 - _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_end=1065 - _globals['_PEERCONNECTION']._serialized_start=1067 - _globals['_PEERCONNECTION']._serialized_end=1140 - _globals['_PEERCONNECTIONS']._serialized_start=1142 - _globals['_PEERCONNECTIONS']._serialized_end=1210 - _globals['_DEVICEFLOPS']._serialized_start=1212 - _globals['_DEVICEFLOPS']._serialized_end=1267 - _globals['_DEVICECAPABILITIES']._serialized_start=1269 - 
_globals['_DEVICECAPABILITIES']._serialized_end=1376 - _globals['_SENDNEWTOKENREQUEST']._serialized_start=1378 - _globals['_SENDNEWTOKENREQUEST']._serialized_end=1455 - _globals['_SENDOPAQUESTATUSREQUEST']._serialized_start=1457 - _globals['_SENDOPAQUESTATUSREQUEST']._serialized_end=1518 - _globals['_HEALTHCHECKREQUEST']._serialized_start=1520 - _globals['_HEALTHCHECKREQUEST']._serialized_end=1540 - _globals['_HEALTHCHECKRESPONSE']._serialized_start=1542 - _globals['_HEALTHCHECKRESPONSE']._serialized_end=1583 - _globals['_EMPTY']._serialized_start=1585 - _globals['_EMPTY']._serialized_end=1592 - _globals['_NODESERVICE']._serialized_start=1595 - _globals['_NODESERVICE']._serialized_end=2132 + _globals['_SENDPROMPTREQUEST']._serialized_start=122 + _globals['_SENDPROMPTREQUEST']._serialized_end=260 + _globals['_SENDTENSORREQUEST']._serialized_start=263 + _globals['_SENDTENSORREQUEST']._serialized_end=423 + _globals['_SENDEXAMPLEREQUEST']._serialized_start=426 + _globals['_SENDEXAMPLEREQUEST']._serialized_end=588 + _globals['_LOSS']._serialized_start=590 + _globals['_LOSS']._serialized_end=662 + _globals['_TENSOR']._serialized_start=664 + _globals['_TENSOR']._serialized_end=723 + _globals['_COLLECTTOPOLOGYREQUEST']._serialized_start=725 + _globals['_COLLECTTOPOLOGYREQUEST']._serialized_end=785 + _globals['_COLLECTTOPOLOGYRESPONSE']._serialized_start=788 + _globals['_COLLECTTOPOLOGYRESPONSE']._serialized_end=1440 + _globals['_COLLECTTOPOLOGYRESPONSE_DEVICEFLOPS']._serialized_start=881 + _globals['_COLLECTTOPOLOGYRESPONSE_DEVICEFLOPS']._serialized_end=936 + _globals['_COLLECTTOPOLOGYRESPONSE_DEVICECAPABILITIES']._serialized_start=939 + _globals['_COLLECTTOPOLOGYRESPONSE_DEVICECAPABILITIES']._serialized_end=1070 + _globals['_COLLECTTOPOLOGYRESPONSE_PEERCONNECTION']._serialized_start=1072 + _globals['_COLLECTTOPOLOGYRESPONSE_PEERCONNECTION']._serialized_end=1145 + _globals['_COLLECTTOPOLOGYRESPONSE_NODE']._serialized_start=1148 + _globals['_COLLECTTOPOLOGYRESPONSE_NODE']._serialized_end=1321 + _globals['_COLLECTTOPOLOGYRESPONSE_TOPOLOGY']._serialized_start=1323 + _globals['_COLLECTTOPOLOGYRESPONSE_TOPOLOGY']._serialized_end=1440 + _globals['_SENDNEWTOKENREQUEST']._serialized_start=1442 + _globals['_SENDNEWTOKENREQUEST']._serialized_end=1566 + _globals['_SENDOPAQUESTATUSREQUEST']._serialized_start=1568 + _globals['_SENDOPAQUESTATUSREQUEST']._serialized_end=1651 + _globals['_HEALTHCHECKREQUEST']._serialized_start=1653 + _globals['_HEALTHCHECKREQUEST']._serialized_end=1673 + _globals['_HEALTHCHECKRESPONSE']._serialized_start=1675 + _globals['_HEALTHCHECKRESPONSE']._serialized_end=1716 + _globals['_EMPTY']._serialized_start=1718 + _globals['_EMPTY']._serialized_end=1725 + _globals['_NODESERVICE']._serialized_start=1728 + _globals['_NODESERVICE']._serialized_end=2293 # @@protoc_insertion_point(module_scope) diff --git a/exo/networking/grpc/node_service_pb2_grpc.py b/exo/networking/grpc/node_service_pb2_grpc.py index 306e9cae4..9ece9a9cb 100644 --- a/exo/networking/grpc/node_service_pb2_grpc.py +++ b/exo/networking/grpc/node_service_pb2_grpc.py @@ -36,23 +36,23 @@ def __init__(self, channel): """ self.SendPrompt = channel.unary_unary( '/node_service.NodeService/SendPrompt', - request_serializer=node__service__pb2.PromptRequest.SerializeToString, + request_serializer=node__service__pb2.SendPromptRequest.SerializeToString, response_deserializer=node__service__pb2.Empty.FromString, _registered_method=True) self.SendTensor = channel.unary_unary( '/node_service.NodeService/SendTensor', - 
request_serializer=node__service__pb2.TensorRequest.SerializeToString, + request_serializer=node__service__pb2.SendTensorRequest.SerializeToString, response_deserializer=node__service__pb2.Empty.FromString, _registered_method=True) self.SendExample = channel.unary_unary( '/node_service.NodeService/SendExample', - request_serializer=node__service__pb2.ExampleRequest.SerializeToString, - response_deserializer=node__service__pb2.Loss.FromString, + request_serializer=node__service__pb2.SendExampleRequest.SerializeToString, + response_deserializer=node__service__pb2.Empty.FromString, _registered_method=True) self.CollectTopology = channel.unary_unary( '/node_service.NodeService/CollectTopology', request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString, - response_deserializer=node__service__pb2.Topology.FromString, + response_deserializer=node__service__pb2.CollectTopologyResponse.FromString, _registered_method=True) self.SendNewToken = channel.unary_unary( '/node_service.NodeService/SendNewToken', @@ -121,23 +121,23 @@ def add_NodeServiceServicer_to_server(servicer, server): rpc_method_handlers = { 'SendPrompt': grpc.unary_unary_rpc_method_handler( servicer.SendPrompt, - request_deserializer=node__service__pb2.PromptRequest.FromString, + request_deserializer=node__service__pb2.SendPromptRequest.FromString, response_serializer=node__service__pb2.Empty.SerializeToString, ), 'SendTensor': grpc.unary_unary_rpc_method_handler( servicer.SendTensor, - request_deserializer=node__service__pb2.TensorRequest.FromString, + request_deserializer=node__service__pb2.SendTensorRequest.FromString, response_serializer=node__service__pb2.Empty.SerializeToString, ), 'SendExample': grpc.unary_unary_rpc_method_handler( servicer.SendExample, - request_deserializer=node__service__pb2.ExampleRequest.FromString, - response_serializer=node__service__pb2.Loss.SerializeToString, + request_deserializer=node__service__pb2.SendExampleRequest.FromString, + response_serializer=node__service__pb2.Empty.SerializeToString, ), 'CollectTopology': grpc.unary_unary_rpc_method_handler( servicer.CollectTopology, request_deserializer=node__service__pb2.CollectTopologyRequest.FromString, - response_serializer=node__service__pb2.Topology.SerializeToString, + response_serializer=node__service__pb2.CollectTopologyResponse.SerializeToString, ), 'SendNewToken': grpc.unary_unary_rpc_method_handler( servicer.SendNewToken, @@ -180,7 +180,7 @@ def SendPrompt(request, request, target, '/node_service.NodeService/SendPrompt', - node__service__pb2.PromptRequest.SerializeToString, + node__service__pb2.SendPromptRequest.SerializeToString, node__service__pb2.Empty.FromString, options, channel_credentials, @@ -207,7 +207,7 @@ def SendTensor(request, request, target, '/node_service.NodeService/SendTensor', - node__service__pb2.TensorRequest.SerializeToString, + node__service__pb2.SendTensorRequest.SerializeToString, node__service__pb2.Empty.FromString, options, channel_credentials, @@ -234,8 +234,8 @@ def SendExample(request, request, target, '/node_service.NodeService/SendExample', - node__service__pb2.ExampleRequest.SerializeToString, - node__service__pb2.Loss.FromString, + node__service__pb2.SendExampleRequest.SerializeToString, + node__service__pb2.Empty.FromString, options, channel_credentials, insecure, @@ -262,7 +262,7 @@ def CollectTopology(request, target, '/node_service.NodeService/CollectTopology', node__service__pb2.CollectTopologyRequest.SerializeToString, - node__service__pb2.Topology.FromString, + 
node__service__pb2.CollectTopologyResponse.FromString, options, channel_credentials, insecure, diff --git a/exo/orchestration/node.py b/exo/orchestration/node.py index ebf9b6734..a9dd2d3cf 100644 --- a/exo/orchestration/node.py +++ b/exo/orchestration/node.py @@ -16,6 +16,7 @@ from exo.download.hf.hf_helpers import RepoProgressEvent from exo.inference.inference_engine import get_inference_engine, InferenceEngine from exo.download.hf.hf_shard_download import HFShardDownloader +from exo.orchestration.tracing import tracer, TraceContext class Node: def __init__( @@ -111,44 +112,79 @@ async def broadcast_supported_engines(self, supported_engines_names: List[str]): def get_topology_inference_engines(self) -> List[List[str]]: return self.topology_inference_engines_pool - token_count = 0 - first_token_time = 0 async def process_inference_result( self, shard, result: np.ndarray, request_id: Optional[str] = None, ): - if request_id not in self.buffered_token_output: - self.buffered_token_output[request_id] = ([], False) - is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens - - if shard.is_last_layer() and not is_finished: - self.token_count += 1 - if self.token_count == 1: - self.first_token_time = time.perf_counter_ns() - if self.token_count % 20 == 0: - print(f"[{request_id}] TPS: {self.token_count / ((time.perf_counter_ns() - self.first_token_time) / 1e9)}") - - token = await self.inference_engine.sample(result, temp=self.default_sample_temperature) - await self.inference_engine.ensure_shard(shard) - self.buffered_token_output[request_id][0].append(token.item()) - is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens - if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}") - forward = token.reshape(1, -1) - self.trigger_on_token_callbacks(request_id, token.item(), is_finished) - asyncio.create_task(self.broadcast_new_token(request_id, token.item(), is_finished)) - else: - forward = result - - if is_finished: - self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True) - self.outstanding_requests.pop(request_id) - else: - self.outstanding_requests[request_id] = "waiting" - asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1))) + context = tracer.get_context(request_id) + if not context: + context = TraceContext(request_id=request_id or str(uuid.uuid4()), sequence_number=0) + tracer.set_context(request_id, context) - return np.array(self.buffered_token_output[request_id][0]) + try: + with tracer.start_span( + f"process_inference_result.{self.get_partition_index()}", + context, + extra_attributes={ + "partition_index": self.get_partition_index(), + "node_id": self.id, + "start_layer": shard.start_layer, + "end_layer": shard.end_layer + } + ): + if request_id not in self.buffered_token_output: + self.buffered_token_output[request_id] = ([], False) + is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens + + if shard.is_last_layer() and not is_finished: + token = await self.inference_engine.sample(result, temp=self.default_sample_temperature) + forward = token.reshape(1, -1) + + # Increment sequence number for next forward pass + next_sequence = context.sequence_number + 1 + # Create new context but preserve request span + 
next_context = TraceContext( + request_id=context.request_id, + sequence_number=next_sequence, + request_span=context.request_span # Preserve request span + ) + tracer.set_context(request_id, next_context) + + self.buffered_token_output[request_id][0].append(token.item()) + is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished + self.trigger_on_token_callbacks(request_id, token.item(), is_finished) + await self.broadcast_new_token(request_id, token.item(), is_finished) + + if not is_finished: + self.outstanding_requests[request_id] = "waiting" + asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1))) + else: + forward = result + if not is_finished: + self.outstanding_requests[request_id] = "waiting" + asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1))) + + if is_finished: + # End the request span when generation is complete + if context.request_span: + context.request_span.set_attribute("total_tokens", len(self.buffered_token_output[request_id][0])) + context.request_span.end() + context.request_span = None + self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True) + self.outstanding_requests.pop(request_id) + + return np.array(self.buffered_token_output[request_id][0]) + except Exception as e: + if request_id in self.outstanding_requests: + self.outstanding_requests.pop(request_id) + # End request span on error + if context and context.request_span: + context.request_span.set_status(Status(StatusCode.ERROR, str(e))) + context.request_span.end() + context.request_span = None + raise async def process_prompt( self, @@ -195,18 +231,46 @@ async def process_prompt( async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]: if request_id is None: request_id = str(uuid.uuid4()) + + # Create or get trace context + context = tracer.get_context(request_id) + if not context: + # Create new context with request span + request_span = tracer.tracer.start_span( + "request", + attributes={ + "request_id": request_id, + "prompt": prompt, + "node_id": self.id, + "request_type": "process_prompt" + } + ) + context = TraceContext( + request_id=request_id, + sequence_number=0, + request_span=request_span, + current_span=request_span, + trace_parent=tracer.inject_context(request_span) + ) + tracer.set_context(request_id, context) + shard = self.get_current_shard(base_shard) if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=}") - if not shard.is_first_layer(): - if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}") - self.outstanding_requests[request_id] = "waiting" - await self.forward_prompt(shard, prompt, request_id, 0) - return None + try: + if not shard.is_first_layer(): + if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}") + self.outstanding_requests[request_id] = "waiting" + await self.forward_prompt(shard, prompt, request_id, 0) + return None - self.outstanding_requests[request_id] = "processing" - result = await self.inference_engine.infer_prompt(request_id, shard, prompt) - await self.process_inference_result(shard, result, request_id) + self.outstanding_requests[request_id] = "processing" + result = await self.inference_engine.infer_prompt(request_id, shard, prompt) + await self.process_inference_result(shard, result, request_id) + except 
Exception as e: + if context.request_span: + context.request_span.set_status(Status(StatusCode.ERROR, str(e))) + raise async def enqueue_example( self, @@ -350,33 +414,36 @@ async def process_tensor( base_shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, - ) -> None: - shard = self.get_current_shard(base_shard) - start_time = time.perf_counter_ns() - await self._process_tensor(shard, tensor, request_id) - end_time = time.perf_counter_ns() - elapsed_time_ns = end_time - start_time - if DEBUG >= 2: print(f"[{request_id}] process_tensor: {base_shard=} {shard=} {tensor.size=} {tensor.shape=} {elapsed_time_ns=}") - - async def _process_tensor( - self, - base_shard: Shard, - tensor: np.ndarray, - request_id: Optional[str] = None, - ) -> None: - if request_id is None: - request_id = str(uuid.uuid4()) - shard = self.get_current_shard(base_shard) + ): + context = tracer.get_context(request_id) + if not context: + context = TraceContext(request_id=request_id or str(uuid.uuid4()), sequence_number=0) + tracer.set_context(request_id, context) try: self.outstanding_requests[request_id] = "processing" - result = await self.inference_engine.infer_tensor(request_id, shard, tensor) - await self.process_inference_result(shard, result, request_id) + with tracer.start_span( + f"process_tensor.{self.get_partition_index()}", + context, + extra_attributes={ + "partition_index": self.get_partition_index(), + "node_id": self.id, + "start_layer": base_shard.start_layer, + "end_layer": base_shard.end_layer, + "tensor_shape": str(tensor.shape) + } + ): + result = await self.inference_engine.infer_tensor(request_id, base_shard, tensor) + await self.process_inference_result(base_shard, result, request_id) except Exception as e: - self.outstanding_requests.pop(request_id) - print(f"Error processing tensor for shard {shard}: {e}") + if request_id in self.outstanding_requests: + self.outstanding_requests.pop(request_id) + if context and context.request_span: + context.request_span.set_status(Status(StatusCode.ERROR, str(e))) + print(f"Error processing tensor for shard {base_shard}: {e}") traceback.print_exc() - + raise + async def forward_example( self, base_shard: Shard, @@ -405,18 +472,39 @@ async def forward_prompt( request_id: str, target_index: int, ) -> None: - if DEBUG >= 1: print(f"target partition index: {target_index}") - target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id - next_shard = self.get_current_shard(base_shard, target_index) - if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. 
next shard: {next_shard}") - if target_id == self.id: - await self.process_prompt(next_shard, prompt, request_id) - else: - target_peer = next((p for p in self.peers if p.id() == target_id), None) - if not target_peer: - raise ValueError(f"Peer for {target_index} not found") - if DEBUG >= 1: print(f"Sending prompt to {target_peer.id()}: {prompt}") - await target_peer.send_prompt(next_shard, prompt, request_id=request_id) + context = tracer.get_context(request_id) + if not context: + context = TraceContext(request_id=request_id, sequence_number=0) + tracer.set_context(request_id, context) + + with tracer.start_span( + "forward_prompt", + context, + extra_attributes={ + "source_node": self.id, + "target_index": target_index, + "prompt": prompt + } + ) as span: + if DEBUG >= 1: print(f"target partition index: {target_index}") + target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id + next_shard = self.get_current_shard(base_shard, target_index) + span.set_attribute("target_node", target_id) + + # Get trace context for propagation + trace_parent = tracer.inject_context(span) + + if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. next shard: {next_shard}") + if target_id == self.id: + # Update local context with trace parent + context.trace_parent = trace_parent + await self.process_prompt(next_shard, prompt, request_id) + else: + target_peer = next((p for p in self.peers if p.id() == target_id), None) + if not target_peer: + raise ValueError(f"Peer for {target_index} not found") + if DEBUG >= 1: print(f"Sending prompt to {target_peer.id()}: {prompt}") + await target_peer.send_prompt(next_shard, prompt, request_id=request_id, sequence_number=context.sequence_number, trace_parent=trace_parent) async def forward_tensor( self, @@ -424,19 +512,39 @@ async def forward_tensor( tensor: np.ndarray, request_id: str, target_index: int, - ) -> None: - if DEBUG >= 1: print(f"target partition index: {target_index}") - target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id - next_shard = self.get_current_shard(base_shard, target_index) - if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. 
target shard: {next_shard}") - if target_id == self.id: - await self.process_tensor(next_shard, tensor, request_id) - else: - target_peer = next((p for p in self.peers if p.id() == target_id), None) - if not target_peer: - raise ValueError(f"Peer for {target_index} not found") - if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor}") - await target_peer.send_tensor(next_shard, tensor, request_id=request_id) + ): + context = tracer.get_context(request_id) + if not context: + context = TraceContext(request_id=request_id, sequence_number=0) + tracer.set_context(request_id, context) + + with tracer.start_span( + "forward_tensor", + context, + extra_attributes={ + "source_node": self.id, + "target_index": target_index, + "tensor_shape": str(tensor.shape) + } + ) as span: + target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id + next_shard = self.get_current_shard(base_shard, target_index) + span.set_attribute("target_node", target_id) + + # Get trace context for propagation + trace_parent = tracer.inject_context(context.request_span or span) + + if target_id == self.id: + # Update local context with trace parent + context.trace_parent = trace_parent + await self.process_tensor(next_shard, tensor, request_id) + else: + target_peer = next((p for p in self.peers if p.id() == target_id), None) + if not target_peer: + raise ValueError(f"Peer for {target_index} not found") + + if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor}") + await target_peer.send_tensor(next_shard, tensor, request_id=request_id, sequence_number=context.sequence_number, trace_parent=trace_parent) def get_partition_index(self, offset: int = 0): if not self.partitioning_strategy: @@ -570,20 +678,32 @@ def on_opaque_status(self) -> AsyncCallbackSystem[str, Tuple[str, str]]: return self._on_opaque_status def trigger_on_token_callbacks(self, request_id: str, token: int, is_finished: bool) -> None: - if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} {token=} {is_finished=}") + if DEBUG >= 2: print(f"[Node] Triggering token callbacks: {request_id=} {token=} {is_finished=}") self.on_token.trigger_all(request_id, token, is_finished) - async def broadcast_new_token(self, request_id: str, token: int, is_finished: bool) -> None: - async def send_new_token_to_peer(peer): - try: - await asyncio.wait_for(peer.send_new_token(request_id, token, is_finished), timeout=15.0) - except asyncio.TimeoutError: - print(f"Timeout broadcasting new token to {peer.id()}") - except Exception as e: - print(f"Error broadcasting new token to {peer.id()}: {e}") - traceback.print_exc() - - await asyncio.gather(*[send_new_token_to_peer(peer) for peer in self.peers], return_exceptions=True) + async def broadcast_new_token(self, request_id: str, token: int, is_finished: bool): + """Broadcast a new token to all peers.""" + context = tracer.get_context(request_id) + if context: + # Handle token in tracer for grouping + tracer.handle_token(context, token, is_finished) + # Get current trace context for propagation + trace_parent = "" + if context.current_span: + trace_parent = tracer.inject_context(context.current_span) + + tasks = [] + for peer in self.peers: + tasks.append( + peer.send_new_token( + request_id, + token, + is_finished, + context.sequence_number if context else 0, + trace_parent if context else "" + ) + ) + await asyncio.gather(*tasks) async def broadcast_opaque_status(self, request_id: str, status: str) -> None: if DEBUG >= 8: print(f"Broadcasting opaque status: 
{request_id=} {status=}") diff --git a/exo/orchestration/tracing.py b/exo/orchestration/tracing.py new file mode 100644 index 000000000..4466fc7d0 --- /dev/null +++ b/exo/orchestration/tracing.py @@ -0,0 +1,166 @@ +from dataclasses import dataclass +from typing import Dict, Optional, Any +from opentelemetry import trace, context +from opentelemetry.trace import Status, StatusCode, SpanContext +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator +from contextlib import contextmanager +import time +from threading import Lock + +@dataclass +class TraceContext: + request_id: str + sequence_number: int + current_span: Optional[trace.Span] = None + trace_parent: Optional[str] = None + token_group_span: Optional[trace.Span] = None + token_count: int = 0 + token_group_size: int = 10 # Default group size + request_span: Optional[trace.Span] = None # Track the main request span + +class Tracer: + def __init__(self): + self.tracer = trace.get_tracer("exo") + self.contexts: Dict[str, TraceContext] = {} + self._lock = Lock() + self.propagator = TraceContextTextMapPropagator() + + def get_context(self, request_id: str) -> Optional[TraceContext]: + with self._lock: + return self.contexts.get(request_id) + + def set_context(self, request_id: str, context: TraceContext): + with self._lock: + self.contexts[request_id] = context + + def inject_context(self, span: trace.Span) -> str: + """Inject current span context into carrier for propagation""" + carrier = {} + ctx = trace.set_span_in_context(span) + self.propagator.inject(carrier, context=ctx) + return carrier.get("traceparent", "") + + def extract_context(self, trace_parent: str) -> Optional[context.Context]: + """Extract span context from carrier""" + if not trace_parent: + return None + carrier = {"traceparent": trace_parent} + return self.propagator.extract(carrier) + + def create_context_from_parent(self, request_id: str, trace_parent: str, sequence_number: int = 0) -> TraceContext: + """Create a new context with the given trace parent""" + parent_ctx = self.extract_context(trace_parent) + if parent_ctx: + # Create a new request span that links to the parent context + request_span = self.tracer.start_span( + "request", + context=parent_ctx, + attributes={ + "request_id": request_id, + "sequence_number": sequence_number + } + ) + return TraceContext( + request_id=request_id, + sequence_number=sequence_number, + request_span=request_span, + current_span=request_span, + trace_parent=trace_parent + ) + return TraceContext(request_id=request_id, sequence_number=sequence_number) + + def handle_token(self, context: TraceContext, token: int, is_finished: bool = False): + """Handle token generation and manage token group spans""" + context.token_count += 1 + + # Start a new token group span if needed + if not context.token_group_span and context.request_span: + group_number = (context.token_count - 1) // context.token_group_size + 1 + + # Create token group span as child of request span + parent_ctx = trace.set_span_in_context(context.request_span) + context.token_group_span = self.tracer.start_span( + f"token_group_{group_number}", + context=parent_ctx, + attributes={ + "request_id": context.request_id, + "group.number": group_number, + "group.start_token": context.token_count, + "group.max_tokens": context.token_group_size + } + ) + + # Add token to current group span + if context.token_group_span: + relative_pos = ((context.token_count - 1) % context.token_group_size) + 1 + 
context.token_group_span.set_attribute(f"token.{relative_pos}", token)
+      context.token_group_span.set_attribute("token.count", relative_pos)
+
+      # End current group span if we've reached the group size or if generation is finished
+      if context.token_count % context.token_group_size == 0 or is_finished:
+        context.token_group_span.set_attribute("token.final_count", relative_pos)
+        context.token_group_span.end()
+        context.token_group_span = None
+
+  @contextmanager
+  def start_span(self, name: str, context: TraceContext, extra_attributes: Optional[Dict[str, Any]] = None):
+    """Start a new span with proper parent context"""
+    attributes = {
+      "request_id": context.request_id,
+      "sequence_number": context.sequence_number
+    }
+    if extra_attributes:
+      attributes.update(extra_attributes)
+
+    # Use request span as parent if available
+    parent_ctx = None
+    if context.request_span:
+      parent_ctx = trace.set_span_in_context(context.request_span)
+    elif context.trace_parent:
+      parent_ctx = self.extract_context(context.trace_parent)
+      if parent_ctx and not context.request_span:
+        # Create a new request span that links to the parent context
+        context.request_span = self.tracer.start_span(
+          "request",
+          context=parent_ctx,
+          attributes={
+            "request_id": context.request_id,
+            "sequence_number": context.sequence_number
+          }
+        )
+        parent_ctx = trace.set_span_in_context(context.request_span)
+    elif context.current_span:
+      parent_ctx = trace.set_span_in_context(context.current_span)
+
+    # Create span with parent context if it exists
+    if parent_ctx:
+      span = self.tracer.start_span(
+        name,
+        context=parent_ctx,
+        attributes=attributes
+      )
+    else:
+      span = self.tracer.start_span(
+        name,
+        attributes=attributes
+      )
+
+    # Update context with current span
+    prev_span = context.current_span
+    context.current_span = span
+
+    try:
+      start_time = time.perf_counter()
+      yield span
+      duration = time.perf_counter() - start_time
+      span.set_attribute("duration_s", duration)
+      span.set_status(Status(StatusCode.OK))
+    except Exception as e:
+      span.set_status(Status(StatusCode.ERROR, str(e)))
+      raise
+    finally:
+      span.end()
+      context.current_span = prev_span
+
+# Global tracer instance
+tracer = Tracer()
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 4b3720a28..2e18c44a4 100644
--- a/setup.py
+++ b/setup.py
@@ -16,6 +16,10 @@
     "nuitka==2.5.1",
     "nvidia-ml-py==12.560.30",
     "opencv-python==4.10.0.84",
+    "opentelemetry-api==1.29.0",
+    "opentelemetry-sdk==1.29.0",
+    "opentelemetry-exporter-otlp==1.29.0",
+    "opentelemetry-instrumentation==0.50b0",
     "pillow==10.4.0",
     "prometheus-client==0.20.0",
     "protobuf==5.28.1",
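
A minimal usage sketch, not part of this change: it shows how the new exo.orchestration.tracing module could be wired to an OTLP backend and used directly, using the exporter packages added to setup.py above. The service name, collector endpoint, request id, and "node-a" attribute below are illustrative assumptions, not values taken from the patch.

from opentelemetry import trace
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter

from exo.orchestration.tracing import tracer, TraceContext

# Configure the OpenTelemetry SDK once at process start-up, before any spans are recorded.
# The endpoint and service name are assumptions for this sketch.
provider = TracerProvider(resource=Resource.create({"service.name": "exo-node"}))
provider.add_span_processor(BatchSpanProcessor(OTLPSpanExporter(endpoint="http://localhost:4317")))
trace.set_tracer_provider(provider)

# Register a per-request TraceContext, then wrap work in spans the way node.py does.
request_id = "example-request"  # hypothetical id; exo uses uuid4 strings
ctx = TraceContext(request_id=request_id, sequence_number=0)
tracer.set_context(request_id, ctx)

with tracer.start_span("process_prompt", ctx, extra_attributes={"source_node": "node-a"}) as span:
  span.set_attribute("prompt_length", 42)  # work done inside the block is timed and recorded as duration_s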
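
A second sketch, also not part of this change, showing how the traceparent helpers are intended to link spans across nodes: the sending side mirrors Node.forward_tensor and the receiving side mirrors the gRPC handlers, with the trace_parent string travelling in the new trace_parent field of the gRPC messages. It assumes the SDK provider from the previous sketch is already configured; the request id and node name are illustrative.

from exo.orchestration.tracing import tracer, TraceContext

# Sending node (cf. Node.forward_tensor): serialize the active span as a W3C traceparent string.
send_ctx = TraceContext(request_id="example-request", sequence_number=3)
tracer.set_context("example-request", send_ctx)
with tracer.start_span("forward_tensor", send_ctx) as span:
  trace_parent = tracer.inject_context(send_ctx.request_span or span)
  # trace_parent now looks like "00-<trace_id>-<span_id>-<flags>" and is shipped alongside the tensor

# Receiving node: rebuild a child context from the propagated header so the remote
# "request" span (and everything started under it) nests beneath the sender's trace.
recv_ctx = tracer.create_context_from_parent("example-request", trace_parent, sequence_number=3)
tracer.set_context("example-request", recv_ctx)
with tracer.start_span("process_tensor", recv_ctx) as span:
  span.set_attribute("target_node", "node-b")  # hypothetical node id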