diff --git a/example/auth.py b/example/auth.py index 3c64987..6d14dba 100755 --- a/example/auth.py +++ b/example/auth.py @@ -6,9 +6,11 @@ import sys import pyrad.packet -srv = Client(server="localhost", secret=b"Kah3choteereethiejeimaeziecumi", dict=Dictionary("dictionary")) +srv = Client(server="localhost", secret=b"Kah3choteereethiejeimaeziecumi", + dict=Dictionary("dictionary")) -req = srv.CreateAuthPacket(code=pyrad.packet.AccessRequest, User_Name="wichert") +req = srv.CreateAuthPacket(code=pyrad.packet.AccessRequest, + User_Name="wichert") req["NAS-IP-Address"] = "192.168.1.10" req["NAS-Port"] = 0 diff --git a/example/auth_async.py b/example/auth_async.py index 9ce4a41..5ae6c24 100644 --- a/example/auth_async.py +++ b/example/auth_async.py @@ -12,7 +12,7 @@ format="%(asctime)s [%(levelname)-8s] %(message)s") client = ClientAsync(server="localhost", secret=b"Kah3choteereethiejeimaeziecumi", - timeout=4, + timeout=3, debug=True, dict=Dictionary("dictionary")) loop = asyncio.get_event_loop() @@ -50,8 +50,8 @@ def test_auth1(): loop.run_until_complete( asyncio.ensure_future( client.initialize_transports(enable_auth=True, - local_addr='127.0.0.1', - local_auth_port=8000, + #local_addr='127.0.0.1', + #local_auth_port=8000, enable_acct=True, enable_coa=True))) @@ -117,15 +117,16 @@ def test_multi_auth(): asyncio.ensure_future( client.initialize_transports(enable_auth=True, local_addr='127.0.0.1', - local_auth_port=8000, + #local_auth_port=8000, enable_acct=True, enable_coa=True))) reqs = [] - for i in range(255): + for i in range(150): req = create_request(client, "user%s" % i) + print('CREATE REQUEST with id %d' % req.id) future = client.SendPacket(req) reqs.append(future) @@ -145,6 +146,7 @@ def test_multi_auth(): reply = future.result() print_reply(reply) + print('INVALID RESPONSE:', client.protocol_auth.errors) # Close transports loop.run_until_complete(asyncio.ensure_future( client.deinitialize_transports())) @@ -160,5 +162,84 @@ def test_multi_auth(): loop.close() +def test_multi_client(): + + clients = [] + n_clients = 73 + n_req4client = 50 + reqs = [] + + global loop + + try: + for i in range(n_clients): + client = ClientAsync(server="localhost", + secret=b"Kah3choteereethiejeimaeziecumi", + timeout=4, debug=True, + dict=Dictionary("dictionary"), + loop=loop) + + clients.append(client) + + # Initialize transports + loop.run_until_complete( + asyncio.ensure_future( + client.initialize_transports(enable_auth=True, + enable_acct=False, + enable_coa=False))) + + # Send + for i in range(n_req4client): + req = create_request(client, "user%s" % i) + print('CREATE REQUEST with id %d' % req.id) + future = client.SendPacket(req) + reqs.append(future) + + # loop.run_until_complete(future) + loop.run_until_complete(asyncio.ensure_future( + asyncio.gather( + *reqs, + return_exceptions=True + ) + + )) + + for future in reqs: + if future.exception(): + print('EXCEPTION ', future.exception()) + else: + reply = future.result() + print_reply(reply) + + client = clients.pop() + while client: + + print('INVALID RESPONSE:', client.protocol_auth.errors) + print('RETRIES:', client.protocol_auth.retries_counter) + + loop.run_until_complete(asyncio.ensure_future( + client.deinitialize_transports())) + + del client + if len(clients) > 0: + client = clients.pop() + else: + client = None + + print('END') + except Exception as exc: + + print('Error: ', exc) + print('\n'.join(traceback.format_exc().splitlines())) + + for client in clients: + # Close transports + loop.run_until_complete(asyncio.ensure_future( + client.deinitialize_transports())) + + loop.close() + + #test_multi_auth() -test_auth1() +#test_auth1() +test_multi_client() diff --git a/example/server_async.py b/example/server_async.py index 3b893da..d1f63ac 100644 --- a/example/server_async.py +++ b/example/server_async.py @@ -10,12 +10,13 @@ from pyrad.server import RemoteHost try: + # If available i try to use uvloop import uvloop asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) except: pass -logging.basicConfig(level="DEBUG", +logging.basicConfig(level="INFO", format="%(asctime)s [%(levelname)-8s] %(message)s") class FakeServer(ServerAsync): @@ -103,6 +104,8 @@ def handle_disconnect_packet(self, protocol, pkt, addr): except KeyboardInterrupt as k: pass + print('STATS', server.stats()) + # Close transports loop.run_until_complete(asyncio.ensure_future( server.deinitialize_transports())) diff --git a/pyrad/client_async.py b/pyrad/client_async.py index de08917..c7b4c02 100644 --- a/pyrad/client_async.py +++ b/pyrad/client_async.py @@ -9,6 +9,7 @@ import six import logging import random +import traceback from pyrad.packet import Packet, AuthPacket, AcctPacket, CoAPacket @@ -24,6 +25,8 @@ def __init__(self, server, port, logger, self.retries = retries self.timeout = timeout self.client = client + self.errors = 0 + self.retries_counter = 0 # Map of pending requests self.pending_requests = {} @@ -58,7 +61,14 @@ async def __timeout_handler__(self): # Send again packet req['send_date'] = now req['retries'] += 1 - self.logger.debug('[%s:%d] For request %d execute retry %d', self.server, self.port, id, req['retries']) + self.retries_counter += 1 + self.logger.debug( + '[%s:%d] For request %d execute retry %d.' % ( + self.server, self.port, id, + req['retries'] + ) + ) + self.transport.sendto(req['packet'].RequestPacket()) elif next_weak_up > secs: next_weak_up = secs @@ -94,9 +104,9 @@ def connection_made(self, transport): socket = transport.get_extra_info('socket') self.logger.info( '[%s:%d] Transport created with binding in %s:%d', - self.server, self.port, - socket.getsockname()[0], - socket.getsockname()[1] + self.server, self.port, + socket.getsockname()[0], + socket.getsockname()[1] ) pre_loop = asyncio.get_event_loop() @@ -119,13 +129,13 @@ def connection_lost(self, exc): # noinspection PyUnusedLocal def datagram_received(self, data, addr): try: + reply = Packet(packet=data, dict=self.client.dict) if reply and reply.id in self.pending_requests: req = self.pending_requests[reply.id] packet = req['packet'] - reply.dict = packet.dict reply.secret = packet.secret if packet.VerifyReply(reply, data): @@ -133,12 +143,32 @@ def datagram_received(self, data, addr): # Remove request for map del self.pending_requests[reply.id] else: - self.logger.warn('[%s:%d] Ignore invalid reply for id %d. %s', self.server, self.port, reply.id) + self.logger.warn( + '[%s:%d] Received invalid reply for id %d. %s' % ( + self.server, self.port, reply.id, + 'Ignoring it.' + ) + ) + self.errors += 1 else: - self.logger.warn('[%s:%d] Ignore invalid reply: %d', self.server, self.port, data) + self.logger.warn( + '[%s:%d] Received invalid reply with id %d: %s.\nIgnoring it.' % ( + self.server, self.port, + (-1, reply.id)[reply is not None], + data.hex(), + ) + ) + self.errors += 1 except Exception as exc: - self.logger.error('[%s:%d] Error on decode packet: %s', self.server, self.port, exc) + self.logger.error( + '[%s:%d] Error on decode packet: %s.' % ( + self.server, self.port, + (exc, '\n'.join(traceback.format_exc().splitlines()))[ + self.client.debug + ] + ) + ) async def close_transport(self): if self.transport: @@ -177,7 +207,7 @@ class ClientAsync: def __init__(self, server, auth_port=1812, acct_port=1813, coa_port=3799, secret=six.b(''), dict=None, loop=None, retries=3, timeout=30, - logger_name='pyrad'): + logger_name='pyrad', debug=False): """Constructor. @@ -217,6 +247,8 @@ def __init__(self, server, auth_port=1812, acct_port=1813, self.protocol_coa = None self.coa_port = coa_port + self.debug = debug + async def initialize_transports(self, enable_acct=False, enable_auth=False, enable_coa=False, local_addr=None, local_auth_port=None, diff --git a/pyrad/server_async.py b/pyrad/server_async.py index 070754d..381f285 100644 --- a/pyrad/server_async.py +++ b/pyrad/server_async.py @@ -36,6 +36,7 @@ def __init__(self, ip, port, logger, server, server_type, hosts, self.hosts = hosts self.server_type = server_type self.request_callback = request_callback + self.requests = 0 def connection_made(self, transport): self.transport = transport @@ -48,73 +49,123 @@ def connection_lost(self, exc): self.logger.info('[%s:%d] Transport closed', self.ip, self.port) def send_response(self, reply, addr): + if self.server.debug: + self.logger.info( + '[%s:%d] Sending Response to %s packet: %s' % ( + self.ip, self.port, addr, reply.ReplyPacket().hex() + ) + ) self.transport.sendto(reply.ReplyPacket(), addr) + def __get_remote_host__(self, addr): + ans = None + if addr in self.hosts.keys(): + ans = self.hosts[addr] + return ans + def datagram_received(self, data, addr): self.logger.debug('[%s:%d] Received %d bytes from %s', self.ip, self.port, len(data), addr) receive_date = datetime.utcnow() - if addr[0] in self.hosts: - remote_host = self.hosts[addr[0]] - elif '0.0.0.0' in self.hosts: - remote_host = self.hosts['0.0.0.0'].secret - else: - self.logger.warn('[%s:%d] Drop package from unknown source %s', self.ip, self.port, addr) - return + remote_host = self.__get_remote_host__(addr[0]) - try: - self.logger.debug('[%s:%d] Received from %s packet: %s', self.ip, self.port, addr, data.hex()) - req = Packet(packet=data, dict=self.server.dict) - except Exception as exc: - self.logger.error('[%s:%d] Error on decode packet: %s', self.ip, self.port, exc) - return + if remote_host: - try: - if req.code in (AccountingResponse, AccessAccept, AccessReject, CoANAK, CoAACK, DisconnectNAK, DisconnectACK): - raise ServerPacketError('Invalid response packet %d' % req.code) - - elif self.server_type == ServerType.Auth: - if req.code != AccessRequest: - raise ServerPacketError('Received non-auth packet on auth port') - req = AuthPacket(secret=remote_host.secret, - dict=self.server.dict, - packet=data) - if self.server.enable_pkt_verify: - if req.VerifyAuthRequest(): - raise PacketError('Packet verification failed') - - elif self.server_type == ServerType.Coa: - if req.code != DisconnectRequest and req.code != CoARequest: - raise ServerPacketError('Received non-coa packet on coa port') - req = CoAPacket(secret=remote_host.secret, - dict=self.server.dict, - packet=data) - if self.server.enable_pkt_verify: - if req.VerifyCoARequest(): - raise PacketError('Packet verification failed') - - elif self.server_type == ServerType.Acct: - - if req.code != AccountingRequest: - raise ServerPacketError('Received non-acct packet on acct port') - req = AcctPacket(secret=remote_host.secret, - dict=self.server.dict, - packet=data) - if self.server.enable_pkt_verify: - if req.VerifyAcctRequest(): - raise PacketError('Packet verification failed') - - # Call request callback - self.request_callback(self, req, addr) - except Exception as exc: - if self.server.debug: - self.logger.exception('[%s:%d] Error for packet from %s', self.ip, self.port, addr) - else: - self.logger.error('[%s:%d] Error for packet from %s: %s', self.ip, self.port, addr, exc) + try: + if self.server.debug: + self.logger.info( + '[%s:%d] Received from %s packet: %s.' % ( + self.ip, self.port, addr, data.hex() + ) + ) + req = Packet(packet=data, dict=self.server.dict) + + except Exception as exc: + self.logger.error( + '[%s:%d] Error on decode packet: %s. Ignore it.' % ( + self.ip, self.port, exc + ) + ) + req = None + + if not req: + return + + try: + if req.code in ( + AccountingResponse, + AccessAccept, + AccessReject, + CoANAK, + CoAACK, + DisconnectNAK, + DisconnectACK): + raise ServerPacketError('Invalid response packet %d' % + req.code) + + elif self.server_type == ServerType.Auth: + + if req.code != AccessRequest: + raise ServerPacketError( + 'Received not-authentication packet ' + 'on authentication port') + req = AuthPacket(secret=remote_host.secret, + dict=self.server.dict, + packet=data) + + elif self.server_type == ServerType.Coa: + + if req.code != DisconnectRequest and \ + req.code != CoARequest: + raise ServerPacketError( + 'Received not-coa packet on coa port' + ) + req = CoAPacket(secret=remote_host.secret, + dict=self.server.dict, + packet=data) + if self.server.enable_pkt_verify: + if not req.VerifyCoARequest(): + raise PacketError('Packet verification failed') + + elif self.server_type == ServerType.Acct: + + if req.code != AccountingRequest: + raise ServerPacketError( + 'Received not-accounting packet on ' + 'accounting port' + ) + req = AcctPacket(secret=remote_host.secret, + dict=self.server.dict, + packet=data) + + if self.server.enable_pkt_verify: + if not req.VerifyAcctRequest(): + raise PacketError('Packet verification failed') + + # Call request callback + self.request_callback(self, req, addr) + + self.requests += 1 + + except Exception as e: + self.logger.error( + '[%s:%d] Unexpected error for packet from %s: %s' % ( + self.ip, self.port, addr, + (e, '\n'.join(traceback.format_exc().splitlines()))[ + self.server.debug + ] + ) + ) + + else: + self.logger.error('[%s:%d] Drop package from unknown source %s', + self.ip, self.port, addr) process_date = datetime.utcnow() - self.logger.debug('[%s:%d] Request from %s processed in %d ms', self.ip, self.port, addr, (process_date-receive_date).microseconds/1000) + self.logger.debug('[%s:%d] Request from %s processed in %d ms', + self.ip, self.port, addr, + (process_date-receive_date).microseconds/1000) def error_received(self, exc): self.logger.error('[%s:%d] Error received: %s', self.ip, self.port, exc) @@ -295,6 +346,18 @@ async def initialize_transports(self, enable_acct=False, loop=self.loop ) + def stats(self): + ans = {} + + for proto in self.coa_protocols: + ans['%s-%s' % (proto.ip, proto.port)] = proto.requests + for proto in self.auth_protocols: + ans['%s-%s' % (proto.ip, proto.port)] = proto.requests + for proto in self.acct_protocols: + ans['%s-%s' % (proto.ip, proto.port)] = proto.requests + + return ans + # noinspection SpellCheckingInspection async def deinitialize_transports(self, deinit_coa=True, deinit_auth=True, deinit_acct=True):