Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: close runner up sockets in the event there are multiple winners #143

Merged
merged 5 commits into from
Mar 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 49 additions & 11 deletions src/aiohappyeyeballs/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

import asyncio
import collections
import contextlib
import functools
import itertools
import socket
from typing import List, Optional, Sequence, Union
from typing import List, Optional, Sequence, Set, Union

from . import _staggered
from .types import AddrInfoType
Expand Down Expand Up @@ -75,15 +76,36 @@ async def start_connection(
except (RuntimeError, OSError):
continue
else: # using happy eyeballs
sock, _, _ = await _staggered.staggered_race(
(
functools.partial(
_connect_sock, current_loop, exceptions, addrinfo, local_addr_infos
)
for addrinfo in addr_infos
),
happy_eyeballs_delay,
)
open_sockets: Set[socket.socket] = set()
try:
sock, _, _ = await _staggered.staggered_race(
(
functools.partial(
_connect_sock,
current_loop,
exceptions,
addrinfo,
local_addr_infos,
open_sockets,
)
for addrinfo in addr_infos
),
happy_eyeballs_delay,
)
finally:
# If we have a winner, staggered_race will
# cancel the other tasks, however there is a
# small race window where any of the other tasks
# can be done before they are cancelled which
# will leave the socket open. To avoid this problem
# we pass a set to _connect_sock to keep track of
# the open sockets and close them here if there
# are any "runner up" sockets.
for s in open_sockets:
if s is not sock:
with contextlib.suppress(OSError):
s.close()
open_sockets = None # type: ignore[assignment]

if sock is None:
all_exceptions = [exc for sub in exceptions for exc in sub]
Expand Down Expand Up @@ -130,14 +152,26 @@ async def _connect_sock(
exceptions: List[List[Union[OSError, RuntimeError]]],
addr_info: AddrInfoType,
local_addr_infos: Optional[Sequence[AddrInfoType]] = None,
open_sockets: Optional[Set[socket.socket]] = None,
) -> socket.socket:
"""Create, bind and connect one socket."""
"""
Create, bind and connect one socket.

If open_sockets is passed, add the socket to the set of open sockets.
Any failure caught here will remove the socket from the set and close it.

Callers can use this set to close any sockets that are not the winner
of all staggered tasks in the result there are runner up sockets aka
multiple winners.
"""
my_exceptions: List[Union[OSError, RuntimeError]] = []
exceptions.append(my_exceptions)
family, type_, proto, _, address = addr_info
sock = None
try:
sock = socket.socket(family=family, type=type_, proto=proto)
if open_sockets is not None:
open_sockets.add(sock)
sock.setblocking(False)
if local_addr_infos is not None:
for lfamily, _, _, _, laddr in local_addr_infos:
Expand Down Expand Up @@ -165,6 +199,8 @@ async def _connect_sock(
except (RuntimeError, OSError) as exc:
my_exceptions.append(exc)
if sock is not None:
if open_sockets is not None:
open_sockets.remove(sock)
try:
sock.close()
except OSError as e:
Expand All @@ -173,6 +209,8 @@ async def _connect_sock(
raise
except:
if sock is not None:
if open_sockets is not None:
open_sockets.remove(sock)
try:
sock.close()
except OSError as e:
Expand Down
73 changes: 71 additions & 2 deletions tests/test_impl.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import asyncio
import socket
from types import ModuleType
from typing import Tuple
from typing import List, Optional, Sequence, Set, Tuple, Union
from unittest import mock

import pytest

from aiohappyeyeballs import _staggered, start_connection
from aiohappyeyeballs import AddrInfoType, _staggered, impl, start_connection


def mock_socket_module():
Expand Down Expand Up @@ -179,6 +179,75 @@ def _socket(*args, **kw):
assert await start_connection(addr_info) == mock_socket


@pytest.mark.asyncio
@patch_socket
async def test_multiple_winners_cleaned_up(
m_socket: ModuleType,
) -> None:
loop = asyncio.get_running_loop()
finish = loop.create_future()

def _socket(*args, **kw):
return mock.MagicMock(
family=socket.AF_INET,
type=socket.SOCK_STREAM,
proto=socket.IPPROTO_TCP,
fileno=mock.MagicMock(return_value=1),
)

async def _connect_sock(
loop: asyncio.AbstractEventLoop,
exceptions: List[List[Union[OSError, RuntimeError]]],
addr_info: AddrInfoType,
local_addr_infos: Optional[Sequence[AddrInfoType]] = None,
sockets: Optional[Set[socket.socket]] = None,
) -> socket.socket:
await finish
sock = _socket()
assert sockets is not None
sockets.add(sock)
return sock

m_socket.socket = _socket # type: ignore
addr_info = [
(
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("107.6.106.82", 80),
),
(
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("107.6.106.83", 80),
),
(
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("107.6.106.84", 80),
),
(
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("107.6.106.85", 80),
),
]
with mock.patch.object(impl, "_connect_sock", _connect_sock):
task = loop.create_task(
start_connection(addr_info, happy_eyeballs_delay=0.0001, interleave=0)
)
await asyncio.sleep(0.1)
loop.call_soon(finish.set_result, None)
await task


@pytest.mark.asyncio
@patch_socket
async def test_multiple_addr_success_second_one_happy_eyeballs(
Expand Down
Loading