2424from collections .abc import Iterator
2525from contextlib import contextmanager , suppress
2626from pathlib import Path
27- from typing import Optional
27+ from types import TracebackType
28+ from typing import Literal , Optional
2829
2930from lightning .fabric .utilities .file_lock import create_file_lock
3031from lightning .fabric .utilities .port_state import PortState
@@ -76,11 +77,16 @@ def _get_lock_dir() -> Path:
7677 except OSError as e :
7778 log .debug (f"Port manager probe file removal failed with { e } ; scheduling cleanup" )
7879
79- atexit .register (lambda p = test_file : p . unlink ( missing_ok = True ) )
80+ atexit .register (_cleanup_probe_file , test_file )
8081
8182 return lock_path
8283
8384
85+ def _cleanup_probe_file (path : Path ) -> None :
86+ """Best-effort removal of a temporary probe file at exit."""
87+ path .unlink (missing_ok = True )
88+
89+
8490def _get_lock_file () -> Path :
8591 """Get path to the port manager lock file.
8692
@@ -218,38 +224,13 @@ def allocate_port(self, preferred_port: Optional[int] = None, max_attempts: int
218224 with self ._lock : # Thread-safety
219225 try :
220226 with self ._file_lock : # Process-safety
221- # Read current state from file
222227 state = self ._read_state ()
223-
224- # Try preferred port if specified
225- if preferred_port is not None and self ._is_port_available (preferred_port , state ):
226- port = preferred_port
227- else :
228- # Find a free port
229- port = None
230- for _ in range (max_attempts ):
231- candidate = self ._find_free_port ()
232- if self ._is_port_available (candidate , state ):
233- port = candidate
234- break
235-
236- if port is None :
237- # Provide detailed diagnostics
238- allocated_count = len (state .allocated_ports )
239- queue_count = len (state .recently_released )
240- raise RuntimeError (
241- f"Failed to allocate a free port after { max_attempts } attempts. "
242- f"Diagnostics: allocated={ allocated_count } , recently_released={ queue_count } "
243- )
244-
245- # Allocate in shared state
228+ port = self ._select_port (state , preferred_port , max_attempts )
246229 state .allocate_port (port , pid = os .getpid ())
247230 self ._write_state (state )
248231
249- # Update in-memory cache
250232 self ._allocated_ports .add (port )
251233
252- # Log diagnostics if queue utilization is high
253234 queue_count = len (state .recently_released )
254235 if queue_count > 800 : # >78% of typical 1024 capacity
255236 log .warning (
@@ -261,7 +242,6 @@ def allocate_port(self, preferred_port: Optional[int] = None, max_attempts: int
261242 return port
262243
263244 except TimeoutError as e :
264- # File lock timeout - fail fast to prevent state divergence
265245 log .error (
266246 "Failed to acquire file lock for port allocation. "
267247 "Remediation: (1) Retry the operation after a short delay, "
@@ -274,6 +254,30 @@ def allocate_port(self, preferred_port: Optional[int] = None, max_attempts: int
274254 "Check if another process is holding the lock or if the lock file is inaccessible."
275255 ) from e
276256
257+ raise RuntimeError ("Unexpected error allocating port" )
258+
259+ def _select_port (
260+ self ,
261+ state : PortState ,
262+ preferred_port : Optional [int ],
263+ max_attempts : int ,
264+ ) -> int :
265+ """Choose an available port based on preference and state."""
266+ if preferred_port is not None and self ._is_port_available (preferred_port , state ):
267+ return preferred_port
268+
269+ for _ in range (max_attempts ):
270+ candidate = self ._find_free_port ()
271+ if self ._is_port_available (candidate , state ):
272+ return candidate
273+
274+ allocated_count = len (state .allocated_ports )
275+ queue_count = len (state .recently_released )
276+ raise RuntimeError (
277+ f"Failed to allocate a free port after { max_attempts } attempts. "
278+ f"Diagnostics: allocated={ allocated_count } , recently_released={ queue_count } "
279+ )
280+
277281 def _is_port_available (self , port : int , state : PortState ) -> bool :
278282 """Check if a port is available for allocation.
279283
@@ -480,7 +484,12 @@ def __enter__(self) -> "PortManager":
480484 """
481485 return self
482486
483- def __exit__ (self , exc_type , exc_val , exc_tb ) -> bool :
487+ def __exit__ (
488+ self ,
489+ exc_type : Optional [type [BaseException ]],
490+ exc_val : Optional [BaseException ],
491+ exc_tb : Optional [TracebackType ],
492+ ) -> Literal [False ]:
484493 """Exit context manager - cleanup ports from this process."""
485494 self .release_all ()
486495 return False # Don't suppress exceptions
0 commit comments