Skip to content

Commit

Permalink
update server side
Browse files Browse the repository at this point in the history
  • Loading branch information
ElNiak committed Dec 24, 2023
1 parent 513a9d0 commit 7f1d9a7
Show file tree
Hide file tree
Showing 12 changed files with 493 additions and 227 deletions.
8 changes: 8 additions & 0 deletions py-ssh3/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,11 @@ def auth_hint(self) -> str:

def __str__(self):
return "raw-bearer-identity"

def build_jwt_bearer_token(key, username: str, conversation) -> str:
# Implement JWT token generation logic
pass

def get_config_for_host(host: str) -> Tuple[str, int, str, List]:
# Parse SSH config for the given host
pass
33 changes: 0 additions & 33 deletions py-ssh3/http3/quic_client.py

This file was deleted.

80 changes: 45 additions & 35 deletions py-ssh3/linux_server/auth.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,50 @@
import base64
import logging
from functools import wraps

def handle_auths(enable_password_login, default_max_packet_size, authenticated_handler_func):
def auth_decorator(f):
@wraps(f)
def decorated_function(*args, **kwargs):
# Version checking logic (placeholder)
user_agent = request.headers.get('User-Agent')
logging.debug(f"Received request from User-Agent: {user_agent}")

# Add more version checking and QUIC connection logic here

authorization = request.headers.get('Authorization')
if enable_password_login and authorization.startswith("Basic "):
return handle_basic_auth(authenticated_handler_func)(*args, **kwargs)
elif authorization.startswith("Bearer "):
# Additional logic for Bearer token
# Placeholder for bearer token handling
pass
else:
return Response(status=401) # Unauthorized

return decorated_function
return auth_decorator

def handle_basic_auth(handler_func):
def basic_auth_decorator(f):
@wraps(f)
def decorated_function(*args, **kwargs):
auth = request.authorization
if not auth or not check_credentials(auth.username, auth.password):
return Response(status=401) # Unauthorized
return handler_func(auth.username, *args, **kwargs)
return decorated_function
return basic_auth_decorator
from aioquic.asyncio import serve
from aioquic.asyncio.protocol import QuicConnectionProtocol
from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.events import HandshakeCompleted
from aioquic.asyncio.server import HttpRequestHandler, HttpServerProtocol, Route
from aioquic.quic.events import ProtocolNegotiated
from typing import Callable
import util.linux_util as linux_util

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


def handle_auths(enablePasswordLogin: bool, defaultMaxPacketSize: int) -> Callable:
async def handle_request(handler: HttpRequestHandler, event: ProtocolNegotiated):
request = handler._http_request_received
logger.debug(f"Received request from User-Agent {request.headers.get('user-agent')}")

# Add your version check and logic here

if not handler._quic._is_handshake_complete:
handler._quic.send_response(status_code=425) # 425 Too Early
return

# Process the request and perform authentication
authorization = request.headers.get('authorization')
if enablePasswordLogin and authorization.startswith('Basic '):
await handle_basic_auth(handler, request)
elif authorization.startswith('Bearer '):
# Handle bearer authentication
pass
else:
handler._quic.send_response(status_code=401) # 401 Unauthorized

return handle_request


def handle_basic_auth(handler: HttpRequestHandler, request):
auth = request.headers.get('authorization')
username, password = base64.b64decode(auth.split(' ')[1]).decode().split(':')
if not linux_util.UserPasswordAuthentication(username, password):
handler._quic.send_response(status_code=401) # 401 Unauthorized
return

# Continue with the authenticated request processing

def check_credentials(username, password):
# Placeholder for checking username and password
Expand Down
32 changes: 1 addition & 31 deletions py-ssh3/message/channel_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,37 +5,7 @@
from message import message
from typing import Tuple
import ipaddress

class ChannelRequestMessage:
def __init__(self, want_reply, channel_request):
self.want_reply = want_reply
self.channel_request = channel_request

def length(self):
# msg type + request type + wantReply + request content
return len(util.var_int_len(message.SSH_MSG_CHANNEL_REQUEST)) + \
util.ssh_string_len(self.channel_request.request_type_str()) + 1 + \
self.channel_request.length()

def write(self, buf):
if len(buf) < self.length():
raise ValueError(f"Buffer too small to write message for channel request of type {type(self.channel_request)}: {len(buf)} < {self.length()}")

consumed = 0
msg_type_buf = util.append_var_int(None, message.SSH_MSG_CHANNEL_REQUEST)
buf[consumed:consumed+len(msg_type_buf)] = msg_type_buf
consumed += len(msg_type_buf)

n = util.write_ssh_string(buf[consumed:], self.channel_request.request_type_str())
consumed += n

buf[consumed] = 1 if self.want_reply else 0
consumed += 1

n = self.channel_request.write(buf[consumed:])
consumed += n

return consumed
from message.message import ChannelRequestMessage

def parse_request_message(buf):
request_type, err = util.parse_ssh_string(buf)
Expand Down
34 changes: 33 additions & 1 deletion py-ssh3/message/message.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import io
import struct
from util import parse_ssh_string
from util.util import parse_ssh_string, ssh_string_len, var_int_len, append_var_int, write_ssh_string

# Constants for SSH message types
SSH_MSG_DISCONNECT = 1
Expand Down Expand Up @@ -42,6 +42,38 @@ def write(self, buf):
def length(self):
pass

class ChannelRequestMessage(Message):
def __init__(self, want_reply, channel_request):
self.want_reply = want_reply
self.channel_request = channel_request

def length(self):
# msg type + request type + wantReply + request content
return len(var_int_len(SSH_MSG_CHANNEL_REQUEST)) + \
ssh_string_len(self.channel_request.request_type_str()) + 1 + \
self.channel_request.length()

def write(self, buf):
if len(buf) < self.length():
raise ValueError(f"Buffer too small to write message for channel request of type {type(self.channel_request)}: {len(buf)} < {self.length()}")

consumed = 0
msg_type_buf = append_var_int(None, SSH_MSG_CHANNEL_REQUEST)
buf[consumed:consumed+len(msg_type_buf)] = msg_type_buf
consumed += len(msg_type_buf)

n = write_ssh_string(buf[consumed:], self.channel_request.request_type_str())
consumed += n

buf[consumed] = 1 if self.want_reply else 0
consumed += 1

n = self.channel_request.write(buf[consumed:])
consumed += n

return consumed


class ChannelOpenConfirmationMessage(Message):
def __init__(self, max_packet_size):
self.max_packet_size = max_packet_size
Expand Down
13 changes: 8 additions & 5 deletions py-ssh3/server.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
import asyncio
from aiohttp import web
from aioquic.asyncio import QuicConnectionProtocol
import logging


class Server:
def __init__(self, max_packet_size, default_datagram_queue_size, h3_server, conversation_handler):
class SSH3Server:
def __init__(self, max_packet_size,
default_datagram_queue_size,
h3_server,
conversation_handler):
self.max_packet_size = max_packet_size
self.h3_server = h3_server
self.conversations = {} # Map of StreamCreator to ConversationManager
self.conversation_handler = conversation_handler
self.lock = asyncio.Lock()



async def get_conversations_manager(self, stream_creator):
async with self.lock:
Expand All @@ -26,7 +29,7 @@ async def remove_connection(self, stream_creator):
async with self.lock:
self.conversations.pop(stream_creator, None)

def get_http_handler_func(self, context):
async def get_http_handler_func(self):
async def handler(request):
logging.info(f"Got request: method: {request.method}, URL: {request.url}")
# Handle the request logic here
Expand Down
Loading

0 comments on commit 7f1d9a7

Please sign in to comment.