Skip to content

Commit

Permalink
allowed interface types
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexCheema committed Dec 8, 2024
1 parent b216819 commit 571b26c
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/bench_job.yml
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ jobs:
export PATH="/usr/local/bin:/opt/homebrew/bin:$PATH"
echo "Starting exo daemon..."
DEBUG=6 DEBUG_DISCOVERY=6 exo --node-id="${MY_NODE_ID}" --node-id-filter="${ALL_NODE_IDS}" --chatgpt-api-port 52415 > output1.log 2>&1 &
DEBUG=6 DEBUG_DISCOVERY=6 exo --node-id="${MY_NODE_ID}" --node-id-filter="${ALL_NODE_IDS}" --interface-type-filter="Ethernet" --chatgpt-api-port 52415 > output1.log 2>&1 &
PID1=$!
echo "Exo process started with PID: $PID1"
tail -f output1.log &
Expand Down
7 changes: 5 additions & 2 deletions exo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
parser.add_argument("--node-id-filter", type=str, default=None, help="Comma separated list of allowed node IDs (only for UDP and Tailscale discovery)")
parser.add_argument("--interface-type-filter", type=str, default=None, help="Comma separated list of allowed interface types (only for UDP discovery)")
args = parser.parse_args()
print(f"Selected inference engine: {args.inference_engine}")

Expand Down Expand Up @@ -90,8 +91,9 @@
for chatgpt_api_endpoint in chatgpt_api_endpoints:
print(f" - {terminal_link(chatgpt_api_endpoint)}")

# Convert node-id-filter to list if provided
# Convert node-id-filter and interface-type-filter to lists if provided
allowed_node_ids = args.node_id_filter.split(',') if args.node_id_filter else None
allowed_interface_types = args.interface_type_filter.split(',') if args.interface_type_filter else None

if args.discovery_module == "udp":
discovery = UDPDiscovery(
Expand All @@ -101,7 +103,8 @@
args.broadcast_port,
lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
discovery_timeout=args.discovery_timeout,
allowed_node_ids=allowed_node_ids
allowed_node_ids=allowed_node_ids,
allowed_interface_types=allowed_interface_types
)
elif args.discovery_module == "tailscale":
discovery = TailscaleDiscovery(
Expand Down
12 changes: 10 additions & 2 deletions exo/networking/udp/udp_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import socket
import time
import traceback
from typing import List, Dict, Callable, Tuple, Coroutine
from typing import List, Dict, Callable, Tuple, Coroutine, Optional
from exo.networking.discovery import Discovery
from exo.networking.peer_handle import PeerHandle
from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
Expand Down Expand Up @@ -45,7 +45,8 @@ def __init__(
broadcast_interval: int = 2.5,
discovery_timeout: int = 30,
device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES,
allowed_node_ids: List[str] = None,
allowed_node_ids: Optional[List[str]] = None,
allowed_interface_types: Optional[List[str]] = None,
):
self.node_id = node_id
self.node_port = node_port
Expand All @@ -56,6 +57,7 @@ def __init__(
self.discovery_timeout = discovery_timeout
self.device_capabilities = device_capabilities
self.allowed_node_ids = allowed_node_ids
self.allowed_interface_types = allowed_interface_types
self.known_peers: Dict[str, Tuple[PeerHandle, float, float, int]] = {}
self.broadcast_task = None
self.listen_task = None
Expand Down Expand Up @@ -147,6 +149,12 @@ async def on_listen_message(self, data, addr):
peer_prio = message["priority"]
peer_interface_name = message["interface_name"]
peer_interface_type = message["interface_type"]

# Skip if interface type is not in allowed list
if self.allowed_interface_types and peer_interface_type not in self.allowed_interface_types:
if DEBUG_DISCOVERY >= 2: print(f"Ignoring peer {peer_id} as its interface type {peer_interface_type} is not in the allowed interface types list")
return

device_capabilities = DeviceCapabilities(**message["device_capabilities"])

if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":
Expand Down

0 comments on commit 571b26c

Please sign in to comment.