Skip to content

Commit

Permalink
add round tripper for client
Browse files Browse the repository at this point in the history
  • Loading branch information
ElNiak committed Jan 4, 2024
1 parent 2969333 commit f13bd96
Show file tree
Hide file tree
Showing 9 changed files with 241 additions and 81 deletions.
34 changes: 24 additions & 10 deletions py-ssh3/client_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,10 @@ async def main():
configuration.keylog_file = key_log

log.info(f"TLS configuration is {configuration}")

tls_config = None # TODO

round_tripper = RoundTripper(quic_config=configuration,tls_config=tls_config)

ssh_auth_sock = os.getenv('SSH_AUTH_SOCK')
log.debug(f"SSH_AUTH_SOCK is {ssh_auth_sock}")
Expand Down Expand Up @@ -340,8 +344,8 @@ async def establish_client_connection(client):
new_url = URL(url_from_param.replace("https","ssh3")) # TODO -> should replace Proto
log.info(f"New URL is {new_url}")
req = HttpRequest(method="CONNECT", url=new_url)
req.headers['user-agent'] = get_current_version()
req.headers['protocol'] = "ssh3" # TODO -> should replace Proto
# req.headers[b'user-agent'] = get_current_version()
req.headers[':protocol'] = "ssh3" # TODO -> should replace Proto
log.info(f"Request is {req}")
# TODO seems not totally correct and secure
log.info(f"Request is {req}")
Expand Down Expand Up @@ -453,7 +457,7 @@ async def establish_client_connection(client):
exit(-1)


async def dial_quic_host(hostname, port, quic_config, known_hosts_path, establish_client_connection):
async def dial_quic_host(hostname, port, quic_config, known_hosts_path):
try:
# Check if hostname is an IP address and format it appropriately
try:
Expand All @@ -474,7 +478,8 @@ async def dial_quic_host(hostname, port, quic_config, known_hosts_path, establis
# Connection established
client = cast(HttpClient, client)
log.info(f"Connected to {hostname}:{port} with client {client}")
await establish_client_connection(client)
return client
# await establish_client_connection(client)
# coros = [
# perform_http_request(
# client=client,
Expand All @@ -487,9 +492,9 @@ async def dial_quic_host(hostname, port, quic_config, known_hosts_path, establis
# ]
# await asyncio.gather(*coros)

log.info(f"Push HTTP event{client}")
process_http_pushes(client=client,include=False,output_dir=None)
client._quic.close(error_code=ErrorCode.H3_NO_ERROR)
# log.info(f"Push HTTP event{client}")
# process_http_pushes(client=client,include=False,output_dir=None)
# client._quic.close(error_code=ErrorCode.H3_NO_ERROR)

except ssl.SSLError as e:
logging.error("TLS error: %s", e)
Expand Down Expand Up @@ -535,14 +540,23 @@ async def dial_quic_host(hostname, port, quic_config, known_hosts_path, establis


log.info(f"Starting client to {url_from_param}")
await dial_quic_host(
client = await dial_quic_host(
hostname=hostname,
port=port,
quic_config=configuration,
known_hosts_path=known_hosts_path,
establish_client_connection=establish_client_connection
known_hosts_path=known_hosts_path
)

if client == -1 or client == 0:
return

# // TODO: could be nice ?? dirty hack: ensure only one QUIC connection is used
def dial(addr:str, tls_config, quic_config):
return client, None
round_tripper.dial = dial

await establish_client_connection(client)

try:
channel = conv.open_channel("session", 30000, 0)
except Exception as e:
Expand Down
104 changes: 99 additions & 5 deletions py-ssh3/http3/http3_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import ssl
import time
from collections import deque
from typing import BinaryIO, Callable, Deque, Dict, List, Optional, Union, cast
from typing import BinaryIO, Callable, Deque, Dict, List, Optional, Union, cast, Tuple
from urllib.parse import urlparse

import aioquic
Expand All @@ -21,6 +21,7 @@
H3Event,
HeadersReceived,
PushPromiseReceived,
DatagramReceived
)
from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.events import QuicEvent
Expand Down Expand Up @@ -230,6 +231,8 @@ def http_event_received(self, event: H3Event) -> None:
def quic_event_received(self, event: QuicEvent) -> None:
#  pass event to the HTTP layer
logger.debug(f"HttpClient received QUIC event: {event}")
if isinstance(event, DatagramReceived):
self.handle_datagram(event.data)
if self._http is not None:
for http_event in self._http.handle_event(event):
self.http_event_received(http_event)
Expand All @@ -242,10 +245,10 @@ async def _request(self, request: HttpRequest) -> Deque[H3Event]:
(b":method", request.method.encode()),
(b":scheme", request.url.scheme.encode()),
(b":authority", request.url.authority.encode()),
(b":path", request.url.full_path.encode()),
(b"user-agent", USER_AGENT.encode()),
(b":path", request.url.full_path.encode())
]
+ [(k.encode(), v.encode()) for (k, v) in request.headers.items()],
+ [(k.encode(), v.encode()) for (k, v) in request.headers.items()]
+ [(b"user-agent", USER_AGENT.encode())],
end_stream=not request.content,
)
if request.content:
Expand All @@ -257,10 +260,101 @@ async def _request(self, request: HttpRequest) -> Deque[H3Event]:
self._request_events[stream_id] = deque()
self._request_waiter[stream_id] = waiter
self.transmit()
logger.debug(f"HttpClient _request called with request: {request}")
head = [ # For debug
(b":method", request.method.encode()),
(b":scheme", request.url.scheme.encode()),
(b":authority", request.url.authority.encode()),
(b":path", request.url.full_path.encode())
] + [(k.encode(), v.encode()) for (k, v) in request.headers.items()] + [(b"user-agent", USER_AGENT.encode())]
logger.debug(f"HttpClient _request called with request: {request} and header: {head}")
return await asyncio.shield(waiter)



class RoundTripper:
def __init__(self, quic_config: Optional[QuicConfiguration] = None,
tls_config: Optional[ssl.SSLContext] = None,
dial: Optional[Callable] = None,
save_session_ticket: Optional[Callable[[SessionTicket], None]] = None,
hijack_stream: Optional[Callable] = None):
self.quic_config = quic_config or QuicConfiguration(is_client=True, alpn_protocols=H3_ALPN)
self.tls_config = tls_config or ssl.create_default_context()
self.dial = dial
self.save_session_ticket = save_session_ticket or self._default_save_session_ticket
self.hijack_stream = hijack_stream or self._default_hijack_stream
self.connections: Dict[Tuple[str, int], QuicConnectionProtocol] = {}
self.last_used = {}

async def _get_or_create_client(self, host, port):
key = (host, port)
if key not in self.connections:
connection = await connect(
host=host,
port=port,
configuration=self.quic_config,
create_protocol=HttpClient,
session_ticket_handler=self._save_session_ticket,
local_port=0, # Replace with desired local port if needed
wait_connected=True,
)
self.connections[key] = connection
self.last_used[key] = time.time()
return self.connections[key]

def _cleanup_connections(self):
"""Close connections that have been idle for a certain threshold."""
idle_threshold = 60 # seconds
current_time = time.time()
for key, last_used_time in list(self.last_used.items()):
if current_time - last_used_time > idle_threshold:
self.connections[key]._quic.close()
del self.connections[key]
del self.last_used[key]

async def round_trip(self, request: HttpRequest) -> Deque[H3Event]:
self._cleanup_connections() # Clean up idle connections
url = request.url
parsed = urlparse(url)
assert parsed.scheme == "https", "Only https:// URLs are supported."
host = parsed.hostname
port = parsed.port or 443

# Get or create a QUIC client for the given host and port
client = await self._get_or_create_client(host, port)
client = cast(HttpClient, client)

# Use the client to perform an HTTP request
if request.method == "GET":
return await client.get(url)
elif request.method == "POST":
return await client.post(url, request.content, request.headers)
else:
raise ValueError("Unsupported HTTP method")

def _default_save_session_ticket(self, ticket: SessionTicket) -> None:
# Implement session ticket saving logic if needed
# TODO
logger.info("New session ticket received - TODO")
# if args.session_ticket:
# with open(args.session_ticket, "wb") as fp:
# pickle.dump(ticket, fp)

async def send_datagram(self, data: bytes, host: str, port: int):
client = await self._get_or_create_client(host, port)
client = cast(HttpClient, client)
client._quic.send_datagram_frame(data)

async def _default_hijack_stream(self, stream_id: int, host: str, port: int):
client = await self._get_or_create_client(host, port)
client = cast(HttpClient, client)

# Example: reading directly from a stream
stream = client._quic._get_or_create_stream_for_receive(stream_id)
data = await stream.receive_some()
# Process data...

# Similar methods can be implemented for writing to a stream.

async def perform_http_request(
client: HttpClient,
url: str,
Expand Down
5 changes: 3 additions & 2 deletions py-ssh3/http3/http3_hijacker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class HTTPStreamer:
def __init__(self, stream_reader, stream_writer):
self.stream_reader = stream_reader
self.stream_writer = stream_writer
self.stream_id = stream_writer._transport.stream_id

async def read(self, size):
return await self.stream_reader.read(size)
Expand All @@ -31,11 +32,11 @@ def __init__(self, protocol: QuicConnectionProtocol):
self.protocol = protocol

async def open_stream(self) -> HTTPStreamer:
reader, writer = await self.protocol._quic.create_stream()
reader, writer = await self.protocol.create_stream()
return HTTPStreamer(reader, writer)

async def open_uni_stream(self) -> HTTPStreamer:
reader, writer = await self.protocol._quic.create_unidirectional_stream()
reader, writer = await self.protocol.create_stream(is_unidirectional=True)
return HTTPStreamer(reader, writer)

def local_addr(self):
Expand Down
14 changes: 10 additions & 4 deletions py-ssh3/http3/http3_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,13 @@

log = logging.getLogger(__name__)

from ssh3.version import get_current_version


AsgiApplication = Callable
HttpConnection = Union[H0Connection, H3Connection]

SERVER_NAME = "aioquic/" + aioquic.__version__
SERVER_NAME = get_current_version()


class HttpRequestHandler:
Expand All @@ -65,6 +67,7 @@ def __init__(
self.queue.put_nowait({"type": "http.request"})

def http_event_received(self, event: H3Event) -> None:
log.debug("HTTP event received: %s", event)
if isinstance(event, DataReceived):
self.queue.put_nowait(
{
Expand Down Expand Up @@ -332,7 +335,7 @@ def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._handlers: Dict[int, Handler] = {}
self._http: Optional[HttpConnection] = None
self.hijacker = Hijacker(self)


def http_event_received(self, event: H3Event) -> None:
log.debug("HTTP event received: %s", event)
Expand Down Expand Up @@ -361,7 +364,7 @@ def http_event_received(self, event: H3Event) -> None:
else:
path_bytes, query_string = raw_path, b""
path = path_bytes.decode()
self._quic._logger.info("HTTP request %s %s", method, path)
self._quic._logger.info("HTTP request %s %s %s", method, path, protocol)

# FIXME: add a public API to retrieve peer address
client_addr = self._http._quic._network_paths[0].addr
Expand Down Expand Up @@ -413,6 +416,7 @@ def http_event_received(self, event: H3Event) -> None:
transmit=self.transmit,
)
else:
scheme = protocol
extensions: Dict[str, Dict] = {}
if isinstance(self._http, H3Connection):
extensions["http.response.push"] = {}
Expand All @@ -426,8 +430,10 @@ def http_event_received(self, event: H3Event) -> None:
"query_string": query_string,
"raw_path": raw_path,
"root_path": "",
"scheme": "https",
"scheme": scheme,
"type": "http",
"stream_ended":event.stream_ended,
"stream_id":event.stream_id,
}
handler = HttpRequestHandler(
authority=authority,
Expand Down
45 changes: 32 additions & 13 deletions py-ssh3/linux_server/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
WebTransportStreamDataReceived,
)
from aioquic.quic.events import DatagramFrameReceived, ProtocolNegotiated, QuicEvent
from aioquic.quic.connection import *
from ssh3.version import *
from starlette.responses import PlainTextResponse, Response
from aioquic.tls import *
Expand Down Expand Up @@ -57,11 +58,14 @@ async def handle_auths(
handler_func: callable,
quic_server: QuicServer
"""

logger.info(f"Auth - Received request {request}")
logger.info(f"Auth - Received request headers {request.headers}")
# Set response server header
content = ""
status = 200
header = [(b"Server", SERVER_NAME)]
header = {
b"Server": SERVER_NAME
}

# Check SSH3 version
user_agent = b""
Expand All @@ -82,22 +86,37 @@ async def handle_auths(
headers=header,
status_code=status)

# For the response
protocols_keys = list(glob.QUIC_SERVER._protocols.keys())
tls_state = glob.QUIC_SERVER._protocols[protocols_keys[-1]]._quic.tls.state # TODO should be more modular, if if there is multiple protocols
prot = glob.QUIC_SERVER._protocols[protocols_keys[-1]]
hijacker = prot.hijacker
if not hijacker:
logger.debug(f"failed to hijack")
status = 400
return Response(content=b"failed to hijack",
headers=header,
status_code=status)
stream_creator = hijacker.stream_creator()
tls_state = stream_creator.connection_state()
logger.info(f"TLS state is {tls_state}")
# Check if connection is complete
if not tls_state == State.SERVER_POST_HANDSHAKE:
status = 425
return Response(content="",
headers=header,
status_code=status)

if tls_state != QuicConnectionState.CONNECTED:
logger.debug(f"Too early connection")
status = 400
return Response(content=b"Too early connection",
headers=header,
status_code=status)
# Create a new conversation
# Implement NewServerConversation based on your protocol's specifics
# From the request TODO
stream = await stream_creator.open_stream()
logger.info(f"Received stream {stream}")
conv = await new_server_conversation(
max_packet_size=glob.DEFAULT_MAX_PACKET_SIZE,
queue_size=10,
tls_state= tls_state
tls_state= tls_state,
control_stream=stream,
stream_creator=stream_creator,
)
logger.info(f"Created new conversation {conv}")
# Handle authentication
Expand All @@ -116,14 +135,14 @@ async def handle_auths(
if glob.ENABLE_PASSWORD_LOGIN and authorization.startswith("Basic "):
logger.info("Handling basic auth")
return await handle_basic_auth(request=request, conv=conv)
elif authorization.startswith("Bearer "):
elif authorization.startswith("Bearer "): # TODO
logger.info("Handling bearer auth")
username = request.headers.get(b":path").decode().split("?", 1)[0].lstrip("/")
conv_id = base64.b64encode(conv.id).decode()
return await handle_bearer_auth(username, conv_id)
else:
logger.info("Handling no auth")
header.append((b"www-authenticate", b"Basic"))
header[b"www-authenticate"] = b"Basic"
status = 401
return Response(content=content,
headers=header,
Expand Down
Loading

0 comments on commit f13bd96

Please sign in to comment.