From dea137c1d87157cd95ed9ab37042fa14e476ec99 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Magalh=C3=A3es?= Date: Mon, 22 Apr 2024 15:56:24 +0100 Subject: [PATCH] refactor: re-formats code according to black rules Also removed extra comments in the file headers. Should make merging of existing PRs messy. --- examples/basic/future.py | 28 +- examples/basic/future_neo.py | 28 +- examples/basic/future_old.py | 31 +- examples/basic/loop.py | 14 +- examples/basic/loop_asyncio.py | 14 +- examples/basic/loop_neo.py | 14 +- examples/basic/loop_old.py | 20 +- examples/basic/sum.py | 13 +- examples/basic/sum_gen.py | 14 +- examples/basic/sum_mix.py | 16 +- examples/basic/sum_neo.py | 13 +- examples/echo/echoc_udp.py | 16 +- examples/echo/echos_udp.py | 16 +- examples/http/http_aiohttp.py | 15 +- examples/http/http_aiohttp_neo.py | 16 +- examples/http/http_asyncio.py | 25 +- examples/http/http_players.py | 45 +- setup.py | 58 +- src/netius/__init__.py | 9 - src/netius/adapters/__init__.py | 9 - src/netius/adapters/base.py | 54 +- src/netius/adapters/fs.py | 53 +- src/netius/adapters/memory.py | 25 +- src/netius/adapters/mongo.py | 12 +- src/netius/adapters/null.py | 10 +- src/netius/auth/__init__.py | 9 - src/netius/auth/address.py | 30 +- src/netius/auth/allow.py | 10 +- src/netius/auth/base.py | 78 +- src/netius/auth/deny.py | 10 +- src/netius/auth/dummy.py | 20 +- src/netius/auth/memory.py | 34 +- src/netius/auth/passwd.py | 41 +- src/netius/auth/simple.py | 29 +- src/netius/base/__init__.py | 106 +- src/netius/base/agent.py | 20 +- src/netius/base/async_neo.py | 87 +- src/netius/base/async_old.py | 168 +-- src/netius/base/asynchronous.py | 14 +- src/netius/base/client.py | 464 ++++--- src/netius/base/common.py | 1605 +++++++++++++----------- src/netius/base/compat.py | 320 +++-- src/netius/base/config.py | 168 +-- src/netius/base/conn.py | 345 +++-- src/netius/base/container.py | 57 +- src/netius/base/diag.py | 42 +- src/netius/base/errors.py | 25 +- src/netius/base/legacy.py | 553 +++++--- src/netius/base/log.py | 65 +- src/netius/base/observer.py | 47 +- src/netius/base/poll.py | 268 ++-- src/netius/base/protocol.py | 158 ++- src/netius/base/request.py | 17 +- src/netius/base/server.py | 558 ++++---- src/netius/base/stream.py | 24 +- src/netius/base/tls.py | 84 +- src/netius/base/transport.py | 111 +- src/netius/base/util.py | 18 +- src/netius/clients/__init__.py | 9 - src/netius/clients/apn.py | 94 +- src/netius/clients/dht.py | 80 +- src/netius/clients/dns.py | 220 ++-- src/netius/clients/http.py | 982 +++++++-------- src/netius/clients/mjpg.py | 42 +- src/netius/clients/raw.py | 28 +- src/netius/clients/smtp.py | 277 ++-- src/netius/clients/ssdp.py | 106 +- src/netius/clients/torrent.py | 84 +- src/netius/clients/ws.py | 86 +- src/netius/common/__init__.py | 157 ++- src/netius/common/asn.py | 83 +- src/netius/common/calc.py | 53 +- src/netius/common/dhcp.py | 93 +- src/netius/common/dkim.py | 78 +- src/netius/common/ftp.py | 32 +- src/netius/common/geo.py | 107 +- src/netius/common/http.py | 290 +++-- src/netius/common/http2.py | 743 +++++------ src/netius/common/mime.py | 84 +- src/netius/common/parser.py | 10 +- src/netius/common/pop.py | 24 +- src/netius/common/rsa.py | 251 ++-- src/netius/common/setup.py | 44 +- src/netius/common/smtp.py | 32 +- src/netius/common/socks.py | 95 +- src/netius/common/stream.py | 42 +- src/netius/common/structures.py | 19 +- src/netius/common/style.py | 9 - src/netius/common/tftp.py | 19 +- src/netius/common/tls.py | 26 +- src/netius/common/torrent.py | 84 +- src/netius/common/util.py | 192 +-- src/netius/common/ws.py | 33 +- src/netius/examples/__init__.py | 9 - src/netius/examples/http.py | 16 +- src/netius/examples/upnp.py | 33 +- src/netius/extra/__init__.py | 9 - src/netius/extra/desktop.py | 22 +- src/netius/extra/dhcp_s.py | 70 +- src/netius/extra/file.py | 537 ++++---- src/netius/extra/filea.py | 37 +- src/netius/extra/hello.py | 45 +- src/netius/extra/hello_w.py | 17 +- src/netius/extra/proxy_d.py | 87 +- src/netius/extra/proxy_f.py | 66 +- src/netius/extra/proxy_r.py | 506 ++++---- src/netius/extra/smtp_r.py | 107 +- src/netius/middleware/__init__.py | 9 - src/netius/middleware/annoyer.py | 22 +- src/netius/middleware/base.py | 10 +- src/netius/middleware/blacklist.py | 22 +- src/netius/middleware/dummy.py | 10 +- src/netius/middleware/flood.py | 28 +- src/netius/middleware/proxy.py | 98 +- src/netius/mock/__init__.py | 9 - src/netius/mock/appier.py | 11 +- src/netius/pool/__init__.py | 18 +- src/netius/pool/common.py | 154 ++- src/netius/pool/file.py | 43 +- src/netius/pool/notify.py | 12 +- src/netius/pool/task.py | 25 +- src/netius/servers/__init__.py | 9 - src/netius/servers/dhcp.py | 76 +- src/netius/servers/echo.py | 14 +- src/netius/servers/echo_ws.py | 13 +- src/netius/servers/ftp.py | 318 +++-- src/netius/servers/http.py | 634 ++++------ src/netius/servers/http2.py | 884 ++++++------- src/netius/servers/mjpg.py | 43 +- src/netius/servers/pop.py | 142 ++- src/netius/servers/proxy.py | 318 +++-- src/netius/servers/smtp.py | 176 ++- src/netius/servers/socks.py | 112 +- src/netius/servers/tftp.py | 89 +- src/netius/servers/torrent.py | 325 ++--- src/netius/servers/ws.py | 72 +- src/netius/servers/wsgi.py | 162 +-- src/netius/sh/__init__.py | 9 - src/netius/sh/auth.py | 15 +- src/netius/sh/base.py | 12 +- src/netius/sh/dkim.py | 29 +- src/netius/sh/rsa.py | 13 +- src/netius/sh/smtp.py | 38 +- src/netius/test/__init__.py | 9 - src/netius/test/auth/__init__.py | 9 - src/netius/test/auth/allow.py | 10 +- src/netius/test/auth/deny.py | 10 +- src/netius/test/auth/simple.py | 10 +- src/netius/test/base/__init__.py | 9 - src/netius/test/base/asynchronous.py | 18 +- src/netius/test/base/common.py | 10 +- src/netius/test/base/config.py | 24 +- src/netius/test/base/tls.py | 22 +- src/netius/test/base/transport.py | 10 +- src/netius/test/clients/__init__.py | 9 - src/netius/test/clients/http.py | 44 +- src/netius/test/common/__init__.py | 9 - src/netius/test/common/calc.py | 10 +- src/netius/test/common/dkim.py | 18 +- src/netius/test/common/http.py | 102 +- src/netius/test/common/mime.py | 10 +- src/netius/test/common/rsa.py | 10 +- src/netius/test/common/setup.py | 12 +- src/netius/test/common/util.py | 70 +- src/netius/test/extra/__init__.py | 9 - src/netius/test/extra/proxy_r.py | 19 +- src/netius/test/extra/smtp_r.py | 20 +- src/netius/test/middleware/__init__.py | 9 - src/netius/test/middleware/proxy.py | 51 +- src/netius/test/pool/__init__.py | 9 - src/netius/test/pool/common.py | 10 +- src/netius/test/servers/__init__.py | 9 - src/netius/test/servers/http.py | 58 +- 173 files changed, 8472 insertions(+), 8453 deletions(-) diff --git a/examples/basic/future.py b/examples/basic/future.py index f7f00fe39..ba9b790ed 100644 --- a/examples/basic/future.py +++ b/examples/basic/future.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -41,24 +32,25 @@ import netius -def set(future, raise_e = True): - if raise_e: future.set_exception(Exception("Awaiting error")) - else: future.set_result(42) + +def set(future, raise_e=True): + if raise_e: + future.set_exception(Exception("Awaiting error")) + else: + future.set_result(42) + @netius.coroutine def await_forever(): print("Awaiting forever") future = netius.build_future() - thread = threading.Thread( - target = set, - args = (future,), - kwargs = dict(raise_e = True) - ) + thread = threading.Thread(target=set, args=(future,), kwargs=dict(raise_e=True)) thread.start() result = yield from future return result -loop = netius.get_loop(_compat = True) + +loop = netius.get_loop(_compat=True) result = loop.run_until_complete(await_forever()) loop.close() diff --git a/examples/basic/future_neo.py b/examples/basic/future_neo.py index 84373e208..9e680704d 100644 --- a/examples/basic/future_neo.py +++ b/examples/basic/future_neo.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -41,22 +32,23 @@ import netius -def set(future, raise_e = True): - if raise_e: future.set_exception(Exception("Awaiting error")) - else: future.set_result(42) + +def set(future, raise_e=True): + if raise_e: + future.set_exception(Exception("Awaiting error")) + else: + future.set_result(42) + async def await_forever(): print("Awaiting forever") future = netius.build_future() - thread = threading.Thread( - target = set, - args = (future,), - kwargs = dict(raise_e = True) - ) + thread = threading.Thread(target=set, args=(future,), kwargs=dict(raise_e=True)) thread.start() return await future -loop = netius.get_loop(_compat = True) + +loop = netius.get_loop(_compat=True) result = loop.run_until_complete(await_forever()) loop.close() diff --git a/examples/basic/future_old.py b/examples/basic/future_old.py index c07040636..c63263d35 100644 --- a/examples/basic/future_old.py +++ b/examples/basic/future_old.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -41,23 +32,25 @@ import netius -def set(future, raise_e = True): - if raise_e: future.set_exception(Exception("Awaiting error")) - else: future.set_result(42) + +def set(future, raise_e=True): + if raise_e: + future.set_exception(Exception("Awaiting error")) + else: + future.set_result(42) + @netius.coroutine def await_forever(): print("Awaiting forever") future = netius.build_future() - thread = threading.Thread( - target = set, - args = (future,), - kwargs = dict(raise_e = True) - ) + thread = threading.Thread(target=set, args=(future,), kwargs=dict(raise_e=True)) thread.start() - for value in future: yield value + for value in future: + yield value + -loop = netius.get_loop(_compat = True) +loop = netius.get_loop(_compat=True) result = loop.run_until_complete(await_forever()) loop.close() diff --git a/examples/basic/loop.py b/examples/basic/loop.py index 7227fdfd0..99901a680 100644 --- a/examples/basic/loop.py +++ b/examples/basic/loop.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -39,17 +30,20 @@ import netius + @netius.coroutine def compute(x, y): print("Compute %s + %s ..." % (x, y)) yield from netius.sleep(1.0) return x + y + @netius.coroutine def print_sum(x, y): result = yield from compute(x, y) print("%s + %s = %s" % (x, y, result)) -loop = netius.get_loop(_compat = True) + +loop = netius.get_loop(_compat=True) loop.run_until_complete(print_sum(1, 2)) loop.close() diff --git a/examples/basic/loop_asyncio.py b/examples/basic/loop_asyncio.py index 82eb7f39a..7632d5e67 100644 --- a/examples/basic/loop_asyncio.py +++ b/examples/basic/loop_asyncio.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -41,17 +32,20 @@ import netius + @asyncio.coroutine def compute(x, y): print("Compute %s + %s ..." % (x, y)) yield from asyncio.sleep(1.0) return x + y + @asyncio.coroutine def print_sum(x, y): result = yield from compute(x, y) print("%s + %s = %s" % (x, y, result)) -loop = netius.get_loop(_compat = True) + +loop = netius.get_loop(_compat=True) loop.run_until_complete(print_sum(1, 2)) loop.close() diff --git a/examples/basic/loop_neo.py b/examples/basic/loop_neo.py index 04e447413..cf2c9a35f 100644 --- a/examples/basic/loop_neo.py +++ b/examples/basic/loop_neo.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -39,15 +30,18 @@ import netius + async def compute(x, y): print("Compute %s + %s ..." % (x, y)) await netius.sleep(1.0) return x + y + async def print_sum(x, y): result = await compute(x, y) print("%s + %s = %s" % (x, y, result)) -loop = netius.get_loop(_compat = True) + +loop = netius.get_loop(_compat=True) loop.run_until_complete(print_sum(1, 2)) loop.close() diff --git a/examples/basic/loop_old.py b/examples/basic/loop_old.py index af4e64b4d..d72bce7ba 100644 --- a/examples/basic/loop_old.py +++ b/examples/basic/loop_old.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -39,19 +30,24 @@ import netius + @netius.coroutine def compute(future, x, y): print("Compute %s + %s ..." % (x, y)) - for value in netius.sleep(1.0): yield value + for value in netius.sleep(1.0): + yield value future.set_result(x + y) + @netius.coroutine def print_sum(x, y): future = netius.build_future() - for value in compute(future, x, y): yield value + for value in compute(future, x, y): + yield value result = future.result() print("%s + %s = %s" % (x, y, result)) -loop = netius.get_loop(_compat = True) + +loop = netius.get_loop(_compat=True) loop.run_until_complete(print_sum(1, 2)) loop.close() diff --git a/examples/basic/sum.py b/examples/basic/sum.py index c2e45243a..f3e98c8f7 100644 --- a/examples/basic/sum.py +++ b/examples/basic/sum.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -39,13 +30,15 @@ import netius + @netius.coroutine def compute(x, y): print("Compute %s + %s ..." % (x, y)) yield from netius.sleep(1.0) return x + y -loop = netius.get_loop(_compat = True) + +loop = netius.get_loop(_compat=True) result = loop.run_until_complete(compute(1, 2)) loop.close() diff --git a/examples/basic/sum_gen.py b/examples/basic/sum_gen.py index efdd8b172..32b98f048 100644 --- a/examples/basic/sum_gen.py +++ b/examples/basic/sum_gen.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -39,18 +30,21 @@ import netius + async def compute(x, y): result = None async for value in _compute(x, y): result = value return result + async def _compute(x, y): print("Compute %s + %s ..." % (x, y)) await netius.sleep(1.0) yield x + y -loop = netius.get_loop(_compat = True) + +loop = netius.get_loop(_compat=True) result = loop.run_until_complete(compute(1, 2)) loop.close() diff --git a/examples/basic/sum_mix.py b/examples/basic/sum_mix.py index c157e7bca..bd7f9c597 100644 --- a/examples/basic/sum_mix.py +++ b/examples/basic/sum_mix.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -39,8 +30,10 @@ import netius + async def compute(x, y): - return (await _compute(x, y)) + return await _compute(x, y) + @netius.coroutine def _compute(x, y): @@ -48,7 +41,8 @@ def _compute(x, y): yield from netius.sleep(1.0) return x + y -loop = netius.get_loop(_compat = True) + +loop = netius.get_loop(_compat=True) result = loop.run_until_complete(compute(1, 2)) loop.close() diff --git a/examples/basic/sum_neo.py b/examples/basic/sum_neo.py index 786861330..cb36f6caa 100644 --- a/examples/basic/sum_neo.py +++ b/examples/basic/sum_neo.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -39,12 +30,14 @@ import netius + async def compute(x, y): print("Compute %s + %s ..." % (x, y)) await netius.sleep(1.0) return x + y -loop = netius.get_loop(_compat = True) + +loop = netius.get_loop(_compat=True) result = loop.run_until_complete(compute(1, 2)) loop.close() diff --git a/examples/echo/echoc_udp.py b/examples/echo/echoc_udp.py index a5551f2ac..495d3750c 100644 --- a/examples/echo/echoc_udp.py +++ b/examples/echo/echoc_udp.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -41,6 +32,7 @@ import netius + class EchoClientProtocol(object): def __init__(self, message, loop): @@ -66,12 +58,12 @@ def connection_lost(self, exc): loop = asyncio.get_event_loop() loop.stop() + message = "Hello World!" -loop = netius.get_loop(_compat = True) +loop = netius.get_loop(_compat=True) connect = loop.create_datagram_endpoint( - lambda: EchoClientProtocol(message, loop), - remote_addr = ("127.0.0.1", 9999) + lambda: EchoClientProtocol(message, loop), remote_addr=("127.0.0.1", 9999) ) transport, protocol = loop.run_until_complete(connect) loop.run_forever() diff --git a/examples/echo/echos_udp.py b/examples/echo/echos_udp.py index 41ccefc7f..1dcba7da2 100644 --- a/examples/echo/echos_udp.py +++ b/examples/echo/echos_udp.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -39,6 +30,7 @@ import netius + class EchoServerProtocol(object): def connection_made(self, transport): @@ -51,12 +43,12 @@ def datagram_received(self, data, addr): print("Send %r to %s" % (message, addr)) self.transport.sendto(data, addr) + print("Starting UDP server") -loop = netius.get_loop(_compat = True) +loop = netius.get_loop(_compat=True) listen = loop.create_datagram_endpoint( - EchoServerProtocol, - local_addr = ("127.0.0.1", 9999) + EchoServerProtocol, local_addr=("127.0.0.1", 9999) ) transport, protocol = loop.run_until_complete(listen) diff --git a/examples/http/http_aiohttp.py b/examples/http/http_aiohttp.py index 9028464a5..1ecd36e6c 100644 --- a/examples/http/http_aiohttp.py +++ b/examples/http/http_aiohttp.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -44,9 +35,10 @@ netius.verify( int(aiohttp.__version__[0]) < 3, - message = "Requires legacy (2.x.x or older) version of aiohttp" + message="Requires legacy (2.x.x or older) version of aiohttp", ) + @asyncio.coroutine def print_http(url): response = yield from aiohttp.request("GET", url) @@ -54,6 +46,7 @@ def print_http(url): data = yield from response.read() print(data) -loop = netius.get_loop(_compat = True) + +loop = netius.get_loop(_compat=True) loop.run_until_complete(print_http("https://www.flickr.com/")) loop.close() diff --git a/examples/http/http_aiohttp_neo.py b/examples/http/http_aiohttp_neo.py index a4fc83502..428399883 100644 --- a/examples/http/http_aiohttp_neo.py +++ b/examples/http/http_aiohttp_neo.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -41,16 +32,19 @@ import netius + async def print_http(session, url): async with session.get(url) as response: print(response.status) data = await response.read() print(data) + async def go(loop, url): - async with aiohttp.ClientSession(loop = loop) as session: + async with aiohttp.ClientSession(loop=loop) as session: await print_http(session, url) -loop = netius.get_loop(_compat = True) + +loop = netius.get_loop(_compat=True) loop.run_until_complete(go(loop, "https://www.flickr.com/")) loop.close() diff --git a/examples/http/http_asyncio.py b/examples/http/http_asyncio.py index 7197125dc..f8967603a 100644 --- a/examples/http/http_asyncio.py +++ b/examples/http/http_asyncio.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -43,29 +34,33 @@ import netius + @asyncio.coroutine -def print_http_headers(url, encoding = "utf-8"): +def print_http_headers(url, encoding="utf-8"): url = urllib.parse.urlsplit(url) if url.scheme == "https": - connect = asyncio.open_connection(url.hostname, 443, ssl = True) + connect = asyncio.open_connection(url.hostname, 443, ssl=True) else: connect = asyncio.open_connection(url.hostname, 80) reader, writer = yield from connect query = "HEAD {path} HTTP/1.0\r\n" + "Host: {hostname}\r\n" + "\r\n" - query = query.format(path = url.path or "/", hostname = url.hostname) + query = query.format(path=url.path or "/", hostname=url.hostname) writer.write(query.encode(encoding)) while True: line = yield from reader.readline() - if not line: break + if not line: + break line = line.decode(encoding).rstrip() - if line: print(line) + if line: + print(line) writer.close() -loop = netius.get_loop(_compat = True) + +loop = netius.get_loop(_compat=True) task = asyncio.ensure_future(print_http_headers("https://www.flickr.com/")) loop.run_until_complete(task) loop.close() diff --git a/examples/http/http_players.py b/examples/http/http_players.py index a3307348d..dc39ac879 100644 --- a/examples/http/http_players.py +++ b/examples/http/http_players.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -52,54 +43,48 @@ "AppleWebKit/537.36 (KHTML, like Gecko) " "Chrome/45.0.2454.101 Safari/537.36" ), - "connection": ("keep-alive") + "connection": ("keep-alive"), } -async def get_players(player_args, season = "2016-17"): + +async def get_players(player_args, season="2016-17"): endpoint = "/commonallplayers" - params = dict( - season = season, - leagueid = "00", - isonlycurrentseason = "1" - ) + params = dict(season=season, leagueid="00", isonlycurrentseason="1") url = BASE_URL + endpoint print("Getting all players for season %s ..." % season) async with aiohttp.ClientSession() as session: - async with session.get(url, headers = HEADERS, params = params) as resp: + async with session.get(url, headers=HEADERS, params=params) as resp: data = await resp.json() - player_args.extend( - [(item[0], item[2]) for item in data["resultSets"][0]["rowSet"]]) + player_args.extend([(item[0], item[2]) for item in data["resultSets"][0]["rowSet"]]) + async def get_player(player_id, player_name): endpoint = "/commonplayerinfo" - params = dict(playerid = player_id) + params = dict(playerid=player_id) url = BASE_URL + endpoint print("Getting player %s" % player_name) async with aiohttp.ClientSession() as session: - async with session.get(url, headers = HEADERS, params = params) as resp: + async with session.get(url, headers=HEADERS, params=params) as resp: data = await resp.text() print(data) async with aiofiles.open( - "players/%s.json" % player_name.replace(" ", "_"), "w" - ) as file: + "players/%s.json" % player_name.replace(" ", "_"), "w" + ) as file: await file.write(data) -loop = netius.get_loop(_compat = True) -os.makedirs("players", exist_ok = True) +loop = netius.get_loop(_compat=True) + +os.makedirs("players", exist_ok=True) player_args = [] loop.run_until_complete(get_players(player_args)) -loop.run_until_complete( - asyncio.gather( - *(get_player(*args) for args in player_args) - ) -) +loop.run_until_complete(asyncio.gather(*(get_player(*args) for args in player_args))) loop.close() diff --git a/setup.py b/setup.py index 9bff2e092..ed23af63c 100644 --- a/setup.py +++ b/setup.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -50,24 +41,29 @@ import netius.common + def read_file(path): - if not os.path.exists(path): return None + if not os.path.exists(path): + return None file = open(path, "r") - try: return file.read() - finally: file.close() + try: + return file.read() + finally: + file.close() + netius.common.ensure_setup() setuptools.setup( - name = "netius", - version = "1.19.3", - author = "Hive Solutions Lda.", - author_email = "development@hive.pt", - description = "Netius System", - license = "Apache License, Version 2.0", - keywords = "netius net infrastructure wsgi", - url = "http://netius.hive.pt", - zip_safe = False, - packages = [ + name="netius", + version="1.19.3", + author="Hive Solutions Lda.", + author_email="development@hive.pt", + description="Netius System", + license="Apache License, Version 2.0", + keywords="netius net infrastructure wsgi", + url="http://netius.hive.pt", + zip_safe=False, + packages=[ "netius", "netius.adapters", "netius.auth", @@ -81,16 +77,12 @@ def read_file(path): "netius.pool", "netius.servers", "netius.sh", - "netius.test" + "netius.test", ], - test_suite = "netius.test", - package_dir = { - "" : os.path.normpath("src") - }, - package_data = { - "netius" : ["base/extras/*", "extra/extras/*", "servers/extras/*"] - }, - classifiers = [ + test_suite="netius.test", + package_dir={"": os.path.normpath("src")}, + package_data={"netius": ["base/extras/*", "extra/extras/*", "servers/extras/*"]}, + classifiers=[ "Development Status :: 5 - Production/Stable", "Topic :: Utilities", "License :: OSI Approved :: Apache Software License", @@ -110,7 +102,7 @@ def read_file(path): "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12" + "Programming Language :: Python :: 3.12", ], - long_description = read_file("README.rst") + long_description=read_file("README.rst"), ) diff --git a/src/netius/__init__.py b/src/netius/__init__.py index 725547ba5..877573b33 100644 --- a/src/netius/__init__.py +++ b/src/netius/__init__.py @@ -19,15 +19,6 @@ # You should have received a copy of the Apache License along with # Hive Netius System. If not, see . -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ diff --git a/src/netius/adapters/__init__.py b/src/netius/adapters/__init__.py index 5bbd19532..4a9c70194 100644 --- a/src/netius/adapters/__init__.py +++ b/src/netius/adapters/__init__.py @@ -19,15 +19,6 @@ # You should have received a copy of the Apache License along with # Hive Netius System. If not, see . -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ diff --git a/src/netius/adapters/base.py b/src/netius/adapters/base.py index 0629b12d9..9650845e4 100644 --- a/src/netius/adapters/base.py +++ b/src/netius/adapters/base.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -43,6 +34,7 @@ import netius + class BaseAdapter(object): """ Top level abstract representation of a netius adapter. @@ -52,29 +44,34 @@ class BaseAdapter(object): storage, torrent hash table storage, sessions, etc.) """ - def set(self, value, owner = "nobody"): + def set(self, value, owner="nobody"): pass def get(self, key): file = self.get_file(key) - if not file: return file - try: value = file.read() - finally: file.close() + if not file: + return file + try: + value = file.read() + finally: + file.close() return value - def get_file(self, key, mode = "rb"): + def get_file(self, key, mode="rb"): return netius.legacy.StringIO() - def delete(self, key, owner = "nobody"): + def delete(self, key, owner="nobody"): pass def append(self, key, value): - file = self.get_file(key, mode = "ab") - try: file.write(value) - finally: file.close() + file = self.get_file(key, mode="ab") + try: + file.write(value) + finally: + file.close() def truncate(self, key, count): - file = self.get_file(key, mode = "rb+") + file = self.get_file(key, mode="rb+") try: offset = count * -1 file.seek(offset, os.SEEK_END) @@ -85,24 +82,25 @@ def truncate(self, key, count): def size(self, key): return 0 - def sizes(self, owner = None): - list = self.list(owner = owner) + def sizes(self, owner=None): + list = self.list(owner=owner) sizes = [self.size(key) for key in list] return sizes - def total(self, owner = None): + def total(self, owner=None): total = 0 - list = self.list(owner = owner) - for key in list: total += self.size(key) + list = self.list(owner=owner) + for key in list: + total += self.size(key) return total - def reserve(self, owner = "nobody"): - return self.set("", owner = owner) + def reserve(self, owner="nobody"): + return self.set("", owner=owner) - def count(self, owner = None): + def count(self, owner=None): return 0 - def list(self, owner = None): + def list(self, owner=None): return () def generate(self): diff --git a/src/netius/adapters/fs.py b/src/netius/adapters/fs.py index c71271fb5..48ff22d57 100644 --- a/src/netius/adapters/fs.py +++ b/src/netius/adapters/fs.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -44,9 +35,10 @@ from . import base + class FsAdapter(base.BaseAdapter): - def __init__(self, base_path = None): + def __init__(self, base_path=None): base.BaseAdapter.__init__(self) self.base_path = base_path or "fs.data" self.base_path = os.path.abspath(self.base_path) @@ -54,24 +46,26 @@ def __init__(self, base_path = None): if not os.path.exists(self.base_path): os.makedirs(self.base_path) - def set(self, value, owner = "nobody"): + def set(self, value, owner="nobody"): key = self.generate() owner_path = self._ensure(owner) file_path = os.path.join(self.base_path, key) link_path = os.path.join(owner_path, key) value = netius.legacy.bytes(value) file = open(file_path, "wb") - try: file.write(value) - finally: file.close() + try: + file.write(value) + finally: + file.close() self._symlink(file_path, link_path) return key - def get_file(self, key, mode = "rb"): + def get_file(self, key, mode="rb"): file_path = os.path.join(self.base_path, key) file = open(file_path, mode) return file - def delete(self, key, owner = "nobody"): + def delete(self, key, owner="nobody"): owner_path = self._ensure(owner) file_path = os.path.join(self.base_path, key) link_path = os.path.join(owner_path, key) @@ -82,37 +76,36 @@ def size(self, key): file_path = os.path.join(self.base_path, key) return os.path.getsize(file_path) - def count(self, owner = None): - list = self.list(owner = owner) + def count(self, owner=None): + list = self.list(owner=owner) return len(list) - def list(self, owner = None): - path = self._path(owner = owner) + def list(self, owner=None): + path = self._path(owner=owner) exists = os.path.exists(path) files = os.listdir(path) if exists else [] return files - def _path(self, owner = None): - if not owner: return self.base_path + def _path(self, owner=None): + if not owner: + return self.base_path return os.path.join(self.base_path, owner) def _ensure(self, owner): owner_path = os.path.join(self.base_path, owner) - if os.path.exists(owner_path): return owner_path + if os.path.exists(owner_path): + return owner_path os.makedirs(owner_path) return owner_path def _symlink(self, source, target): if os.name == "nt": - symlink = ctypes.windll.kernel32.CreateSymbolicLinkW #@UndefinedVariable - symlink.argtypes = ( - ctypes.c_wchar_p, - ctypes.c_wchar_p, - ctypes.c_uint32 - ) + symlink = ctypes.windll.kernel32.CreateSymbolicLinkW # @UndefinedVariable + symlink.argtypes = (ctypes.c_wchar_p, ctypes.c_wchar_p, ctypes.c_uint32) symlink.restype = ctypes.c_ubyte flags = 1 if os.path.isdir(source) else 0 result = symlink(target, source, flags) - if result == 0: raise ctypes.WinError() + if result == 0: + raise ctypes.WinError() else: - os.symlink(source, target) #@UndefinedVariable + os.symlink(source, target) # @UndefinedVariable diff --git a/src/netius/adapters/memory.py b/src/netius/adapters/memory.py index 0c1cd944e..3fc368f23 100644 --- a/src/netius/adapters/memory.py +++ b/src/netius/adapters/memory.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -41,6 +32,7 @@ from . import base + class MemoryAdapter(base.BaseAdapter): def __init__(self): @@ -48,16 +40,17 @@ def __init__(self): self.map = dict() self.owners = dict() - def set(self, value, owner = "nobody"): + def set(self, value, owner="nobody"): map_o = self._ensure(owner) key = self.generate() - item = dict(value = value, owner = owner) + item = dict(value=value, owner=owner) self.map[key] = item map_o[key] = item return key - def get_file(self, key, mode = "rb"): - if not key in self.map: netius.NetiusError("Key not found") + def get_file(self, key, mode="rb"): + if not key in self.map: + netius.NetiusError("Key not found") item = self.map[key] value = item["value"] file = netius.legacy.StringIO(value) @@ -66,7 +59,7 @@ def get_file(self, key, mode = "rb"): file.close = close return file - def delete(self, key, owner = "nobody"): + def delete(self, key, owner="nobody"): item = self.map[key] owner = item["owner"] map_o = self._ensure(owner) @@ -91,11 +84,11 @@ def size(self, key): _value = item["value"] return len(_value) - def count(self, owner = None): + def count(self, owner=None): map = self._ensure(owner) if owner else self.map return len(map) - def list(self, owner = None): + def list(self, owner=None): map = self._ensure(owner) if owner else self.map return map.keys() diff --git a/src/netius/adapters/mongo.py b/src/netius/adapters/mongo.py index e69be3d07..448d3067b 100644 --- a/src/netius/adapters/mongo.py +++ b/src/netius/adapters/mongo.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -39,9 +30,10 @@ from . import base + class MongoAdapter(base.BaseAdapter): - def set(self, value, owner = "nobody"): + def set(self, value, owner="nobody"): pass def get(self, key): diff --git a/src/netius/adapters/null.py b/src/netius/adapters/null.py index a0c3e7186..3a9991b41 100644 --- a/src/netius/adapters/null.py +++ b/src/netius/adapters/null.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -39,5 +30,6 @@ from . import base + class NullAdapter(base.BaseAdapter): pass diff --git a/src/netius/auth/__init__.py b/src/netius/auth/__init__.py index 0435e6836..4b86032aa 100644 --- a/src/netius/auth/__init__.py +++ b/src/netius/auth/__init__.py @@ -19,15 +19,6 @@ # You should have received a copy of the Apache License along with # Hive Netius System. If not, see . -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ diff --git a/src/netius/auth/address.py b/src/netius/auth/address.py index faf8f488e..be3acd930 100644 --- a/src/netius/auth/address.py +++ b/src/netius/auth/address.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -39,35 +30,30 @@ from . import base + class AddressAuth(base.Auth): - def __init__(self, allowed = [], *args, **kwargs): + def __init__(self, allowed=[], *args, **kwargs): base.Auth.__init__(self, *args, **kwargs) self.allowed = allowed @classmethod - def auth(cls, allowed = [], *args, **kwargs): + def auth(cls, allowed=[], *args, **kwargs): import netius.common + host = kwargs.get("host", None) headers = kwargs.get("headers", {}) - if not host and not headers: return False + if not host and not headers: + return False address = headers.get("X-Forwarded-For", host) address = headers.get("X-Client-IP", address) address = headers.get("X-Real-IP", address) address = address.split(",", 1)[0].strip() - return netius.common.assert_ip4( - address, - allowed, - default = False - ) + return netius.common.assert_ip4(address, allowed, default=False) @classmethod def is_simple(cls): return True def auth_i(self, *args, **kwargs): - return self.__class__.auth( - allowed = self.allowed, - *args, - **kwargs - ) + return self.__class__.auth(allowed=self.allowed, *args, **kwargs) diff --git a/src/netius/auth/allow.py b/src/netius/auth/allow.py index dd21a99a4..094bbc088 100644 --- a/src/netius/auth/allow.py +++ b/src/netius/auth/allow.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -39,6 +30,7 @@ from . import base + class AllowAuth(base.Auth): @classmethod diff --git a/src/netius/auth/base.py b/src/netius/auth/base.py index 8a7b4ca85..493f7b4dc 100644 --- a/src/netius/auth/base.py +++ b/src/netius/auth/base.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -43,6 +34,7 @@ import netius + class Auth(object): """ The top level base authentication handler, should define @@ -68,13 +60,16 @@ def meta(cls, *args, **kwargs): @classmethod def auth_assert(cls, *args, **kwargs): result = cls.auth(*args, **kwargs) - if not result: raise netius.SecurityError("Invalid authentication") + if not result: + raise netius.SecurityError("Invalid authentication") @classmethod def verify(cls, encoded, decoded): type, salt, digest, plain = cls.unpack(encoded) - if plain: return encoded == decoded - if salt: decoded += salt + if plain: + return encoded == decoded + if salt: + decoded += salt type = type.lower() decoded = netius.legacy.bytes(decoded) hash = hashlib.new(type, decoded) @@ -82,13 +77,16 @@ def verify(cls, encoded, decoded): return _digest == digest @classmethod - def generate(cls, password, type = "sha256", salt = "netius"): - if type == "plain" : return password - if salt: password += salt + def generate(cls, password, type="sha256", salt="netius"): + if type == "plain": + return password + if salt: + password += salt password = netius.legacy.bytes(password) hash = hashlib.new(type, password) digest = hash.hexdigest() - if not salt: return "%s:%s" % (type, digest) + if not salt: + return "%s:%s" % (type, digest) salt = netius.legacy.bytes(salt) salt = binascii.hexlify(salt) salt = netius.legacy.str(salt) @@ -97,17 +95,28 @@ def generate(cls, password, type = "sha256", salt = "netius"): @classmethod def unpack(cls, password): count = password.count(":") - if count == 2: type, salt, digest = password.split(":") - elif count == 1: type, digest = password.split(":"); salt = None - else: plain = password; type = "plain"; salt = None; digest = None - if not type == "plain": plain = None - if salt: salt = netius.legacy.bytes(salt) - if salt: salt = binascii.unhexlify(salt) - if salt: salt = netius.legacy.str(salt) + if count == 2: + type, salt, digest = password.split(":") + elif count == 1: + type, digest = password.split(":") + salt = None + else: + plain = password + type = "plain" + salt = None + digest = None + if not type == "plain": + plain = None + if salt: + salt = netius.legacy.bytes(salt) + if salt: + salt = binascii.unhexlify(salt) + if salt: + salt = netius.legacy.str(salt) return (type, salt, digest, plain) @classmethod - def get_file(cls, path, cache = False, encoding = None): + def get_file(cls, path, cache=False, encoding=None): """ Retrieves the (file) contents for the file located "under" the provided path, these contents are returned as a normal @@ -142,26 +151,32 @@ def get_file(cls, path, cache = False, encoding = None): # verifies if the cache attribute already exists under the current class # and in case it does not creates the initial cache dictionary - if not hasattr(cls, "_cache"): cls._cache = dict() + if not hasattr(cls, "_cache"): + cls._cache = dict() # tries to retrieve the contents of the file using a caches approach # and returns such value in case the cache flag is enabled result = cls._cache.get(path, None) - if cache and not result == None: return result + if cache and not result == None: + return result # as the cache retrieval has not been successful there's a need to # load the file from the secondary storage (file system) file = open(path, "rb") - try: contents = file.read() - finally: file.close() + try: + contents = file.read() + finally: + file.close() # in case an encoding value has been passed the contents must be properly # decoded so that the "final" contents string is defined - if encoding: contents = contents.decode(encoding) + if encoding: + contents = contents.decode(encoding) # verifies if the cache mode/flag is enabled and if that's the case # store the complete file contents in memory under the dictionary - if cache: cls._cache[path] = contents + if cache: + cls._cache[path] = contents return contents @classmethod @@ -173,7 +188,8 @@ def auth_i(self, *args, **kwargs): def auth_assert_i(self, *args, **kwargs): result = self.auth_i(*args, **kwargs) - if not result: raise netius.SecurityError("Invalid authentication") + if not result: + raise netius.SecurityError("Invalid authentication") def is_simple_i(self): return self.__class__.is_simple() diff --git a/src/netius/auth/deny.py b/src/netius/auth/deny.py index 568692cbd..e5ec97785 100644 --- a/src/netius/auth/deny.py +++ b/src/netius/auth/deny.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -39,6 +30,7 @@ from . import base + class DenyAuth(base.Auth): @classmethod diff --git a/src/netius/auth/dummy.py b/src/netius/auth/dummy.py index b7e2f7272..b26329acb 100644 --- a/src/netius/auth/dummy.py +++ b/src/netius/auth/dummy.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -39,14 +30,15 @@ from . import base + class DummyAuth(base.Auth): - def __init__(self, value = True, *args, **kwargs): + def __init__(self, value=True, *args, **kwargs): base.Auth.__init__(self, *args, **kwargs) self.value = value @classmethod - def auth(cls, value = True, *args, **kwargs): + def auth(cls, value=True, *args, **kwargs): return value @classmethod @@ -54,8 +46,4 @@ def is_simple(cls): return True def auth_i(self, *args, **kwargs): - return self.__class__.auth( - value = self.value, - *args, - **kwargs - ) + return self.__class__.auth(value=self.value, *args, **kwargs) diff --git a/src/netius/auth/memory.py b/src/netius/auth/memory.py index add856c43..d15be1f6c 100644 --- a/src/netius/auth/memory.py +++ b/src/netius/auth/memory.py @@ -19,15 +19,6 @@ # You should have received a copy of the Apache License along with # Hive Netius System. If not, see . -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -38,31 +29,36 @@ from . import base + class MemoryAuth(base.Auth): - def __init__(self, registry = None, *args, **kwargs): + def __init__(self, registry=None, *args, **kwargs): base.Auth.__init__(self, *args, **kwargs) self.registry = registry @classmethod - def auth(cls, username, password, registry = None, *args, **kwargs): + def auth(cls, username, password, registry=None, *args, **kwargs): registry = registry or cls.get_registry() - if not registry: return False + if not registry: + return False register = registry.get(username, None) - if not register: return False + if not register: + return False _password = register.get("password") return cls.verify(_password, password) @classmethod - def meta(cls, username, registry = None, *args, **kwargs): + def meta(cls, username, registry=None, *args, **kwargs): registry = registry or cls.get_registry() - if not registry: return {} + if not registry: + return {} register = registry.get(username, {}) return register @classmethod def get_registry(cls): - if hasattr(cls, "registry"): return cls.registry + if hasattr(cls, "registry"): + return cls.registry cls.registry = cls.load_registry() return cls.registry @@ -72,9 +68,5 @@ def load_registry(cls): def auth_i(self, username, password, *args, **kwargs): return self.__class__.auth( - username, - password, - registry = self.registry, - *args, - **kwargs + username, password, registry=self.registry, *args, **kwargs ) diff --git a/src/netius/auth/passwd.py b/src/netius/auth/passwd.py index 76700be2c..73d5d3cad 100644 --- a/src/netius/auth/passwd.py +++ b/src/netius/auth/passwd.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -41,46 +32,46 @@ from . import base + class PasswdAuth(base.Auth): - def __init__(self, path = None, *args, **kwargs): + def __init__(self, path=None, *args, **kwargs): base.Auth.__init__(self, *args, **kwargs) self.path = path @classmethod - def auth(cls, username, password, path = "passwd", *args, **kwargs): + def auth(cls, username, password, path="passwd", *args, **kwargs): passwd = cls.get_passwd(path) _password = passwd.get(username, None) - if not _password: return False + if not _password: + return False return cls.verify(_password, password) @classmethod - def get_passwd(cls, path, cache = True): + def get_passwd(cls, path, cache=True): path = os.path.expanduser(path) path = os.path.abspath(path) path = os.path.normpath(path) - if not hasattr(cls, "_pwcache"): cls._pwcache = dict() + if not hasattr(cls, "_pwcache"): + cls._pwcache = dict() result = cls._pwcache.get(path, None) if hasattr(cls, "_pwcache") else None - if cache and not result == None: return result + if cache and not result == None: + return result htpasswd = dict() - contents = cls.get_file(path, cache = cache, encoding = "utf-8") + contents = cls.get_file(path, cache=cache, encoding="utf-8") for line in contents.split("\n"): line = line.strip() - if not line: continue + if not line: + continue username, password = line.split(":", 1) htpasswd[username] = password - if cache: cls._pwcache[path] = htpasswd + if cache: + cls._pwcache[path] = htpasswd return htpasswd def auth_i(self, username, password, *args, **kwargs): - return self.__class__.auth( - username, - password, - path = self.path, - *args, - **kwargs - ) + return self.__class__.auth(username, password, path=self.path, *args, **kwargs) diff --git a/src/netius/auth/simple.py b/src/netius/auth/simple.py index 1ebc4e0ee..e12d5deb2 100644 --- a/src/netius/auth/simple.py +++ b/src/netius/auth/simple.py @@ -19,15 +19,6 @@ # You should have received a copy of the Apache License along with # Hive Netius System. If not, see . -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -36,26 +27,26 @@ from . import base + class SimpleAuth(base.Auth): - def __init__(self, username = None, password = None, *args, **kwargs): + def __init__(self, username=None, password=None, *args, **kwargs): base.Auth.__init__(self, *args, **kwargs) self.username = username self.password = password @classmethod - def auth(cls, username, password, target = None, *args, **kwargs): - if not target: return False + def auth(cls, username, password, target=None, *args, **kwargs): + if not target: + return False _username, _password = target - if _username and not username == _username: return False - if not password == _password: return False + if _username and not username == _username: + return False + if not password == _password: + return False return True def auth_i(self, username, password, *args, **kwargs): return self.__class__.auth( - username, - password, - target = (self.username, self.password), - *args, - **kwargs + username, password, target=(self.username, self.password), *args, **kwargs ) diff --git a/src/netius/base/__init__.py b/src/netius/base/__init__.py index a2291fc48..a63168285 100644 --- a/src/netius/base/__init__.py +++ b/src/netius/base/__init__.py @@ -19,15 +19,6 @@ # You should have received a copy of the Apache License along with # Hive Netius System. If not, see . -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -56,23 +47,85 @@ from . import util from .agent import Agent, ClientAgent, ServerAgent -from .asynchronous import Future, Task, Handle, Executor, ThreadPoolExecutor, coroutine,\ - async_test_all, async_test, ensure_generator, get_asyncio, is_coroutine,\ - is_coroutine_object, is_coroutine_native, is_future, is_neo, is_asynclib, is_await,\ - wakeup, sleep, wait, notify, coroutine_return +from .asynchronous import ( + Future, + Task, + Handle, + Executor, + ThreadPoolExecutor, + coroutine, + async_test_all, + async_test, + ensure_generator, + get_asyncio, + is_coroutine, + is_coroutine_object, + is_coroutine_native, + is_future, + is_neo, + is_asynclib, + is_await, + wakeup, + sleep, + wait, + notify, + coroutine_return, +) from .client import Client, DatagramClient, StreamClient -from .common import NAME, VERSION, IDENTIFIER_SHORT, IDENTIFIER_LONG,\ - IDENTIFIER, TCP_TYPE, UDP_TYPE, SSL_KEY_PATH, SSL_CER_PATH, SSL_CA_PATH,\ - SSL_DH_PATH, Base, BaseThread, new_loop_main, new_loop_asyncio, new_loop,\ - ensure_main, ensure_asyncio, ensure_loop, get_main, get_loop, get_event_loop,\ - stop_loop, compat_loop, get_poll, build_future, ensure, ensure_pool -from .compat import BaseLoop, CompatLoop, is_compat, is_asyncio, build_datagram,\ - connect_stream +from .common import ( + NAME, + VERSION, + IDENTIFIER_SHORT, + IDENTIFIER_LONG, + IDENTIFIER, + TCP_TYPE, + UDP_TYPE, + SSL_KEY_PATH, + SSL_CER_PATH, + SSL_CA_PATH, + SSL_DH_PATH, + Base, + BaseThread, + new_loop_main, + new_loop_asyncio, + new_loop, + ensure_main, + ensure_asyncio, + ensure_loop, + get_main, + get_loop, + get_event_loop, + stop_loop, + compat_loop, + get_poll, + build_future, + ensure, + ensure_pool, +) +from .compat import ( + BaseLoop, + CompatLoop, + is_compat, + is_asyncio, + build_datagram, + connect_stream, +) from .config import conf, conf_prefix, conf_suffix, conf_s, conf_r, conf_d, conf_ctx from .conn import OPEN, CLOSED, PENDING, CHUNK_SIZE, Connection from .container import Container, ContainerServer -from .errors import NetiusError, RuntimeError, StopError, PauseError, WakeupError,\ - DataError, ParserError, GeneratorError, SecurityError, NotImplemented, AssertionError +from .errors import ( + NetiusError, + RuntimeError, + StopError, + PauseError, + WakeupError, + DataError, + ParserError, + GeneratorError, + SecurityError, + NotImplemented, + AssertionError, +) from .log import SILENT, rotating_handler, smtp_handler from .observer import Observable from .poll import Poll, EpollPoll, KqueuePoll, PollPoll, SelectPoll @@ -80,7 +133,12 @@ from .request import Request, Response from .server import Server, DatagramServer, StreamServer from .stream import Stream -from .tls import fingerprint, match_fingerprint, match_hostname, dnsname_match,\ - dump_certificate +from .tls import ( + fingerprint, + match_fingerprint, + match_hostname, + dnsname_match, + dump_certificate, +) from .transport import Transport, TransportDatagram, TransportStream from .util import camel_to_underscore, verify diff --git a/src/netius/base/agent.py b/src/netius/base/agent.py index e264adc50..2f0357b44 100644 --- a/src/netius/base/agent.py +++ b/src/netius/base/agent.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -42,6 +33,7 @@ from . import legacy from . import observer + class Agent(observer.Observable): """ Top level class for the entry point classes of the multiple @@ -63,12 +55,14 @@ class Agent(observer.Observable): def cleanup_s(cls): pass - def cleanup(self, destroy = True): - if destroy: self.destroy() + def cleanup(self, destroy=True): + if destroy: + self.destroy() def destroy(self): observer.Observable.destroy(self) + class ClientAgent(Agent): _clients = dict() @@ -89,10 +83,12 @@ def cleanup_s(cls): def get_client_s(cls, *args, **kwargs): tid = threading.current_thread().ident client = cls._clients.get(tid, None) - if client: return client + if client: + return client client = cls(*args, **kwargs) cls._clients[tid] = client return client + class ServerAgent(Agent): pass diff --git a/src/netius/base/async_neo.py b/src/netius/base/async_neo.py index 7df01b693..7afa1a912 100644 --- a/src/netius/base/async_neo.py +++ b/src/netius/base/async_neo.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -44,8 +35,11 @@ from . import legacy from . import async_old -try: import asyncio -except ImportError: asyncio = None +try: + import asyncio +except ImportError: + asyncio = None + class Future(async_old.Future): """ @@ -77,6 +71,7 @@ def __await__(self): raise self.exception() return self.result() + class AwaitWrapper(object): """ Wrapper class meant to be used to encapsulate "old" @@ -92,14 +87,17 @@ class AwaitWrapper(object): infra-structure to know that this type is considered to be generator compliant. """ - def __init__(self, generator, generate = False): - if generate: generator = self.generate(generator) + def __init__(self, generator, generate=False): + if generate: + generator = self.generate(generator) self.generator = generator self.is_generator = legacy.is_generator(generator) def __await__(self): - if self.is_generator: return self._await_generator() - else: return self._await_basic() + if self.is_generator: + return self._await_generator() + else: + return self._await_basic() def __iter__(self): return self @@ -121,6 +119,7 @@ def _await_basic(self): return self.generator yield + class CoroutineWrapper(object): """ Wrapper class meant to encapsulate a coroutine object @@ -139,21 +138,25 @@ def __iter__(self): return self def __next__(self): - if self._buffer: return self._buffer.pop(0) + if self._buffer: + return self._buffer.pop(0) return self.coroutine.send(None) def next(self): return self.__next__() def restore(self, value): - if self._buffer == None: self._buffer = [] + if self._buffer == None: + self._buffer = [] self._buffer.append(value) + def coroutine(function): if inspect.isgeneratorfunction(function): routine = function else: + @functools.wraps(function) def routine(*args, **kwargs): # calls the underlying function with the expected arguments @@ -184,77 +187,92 @@ def wrapper(*args, **kwargs): wrapper._is_coroutine = True return wrapper + def ensure_generator(value): if legacy.is_generator(value): return True, value - if hasattr(inspect, "iscoroutine") and\ - inspect.iscoroutine(value): #@UndefinedVariable + if hasattr(inspect, "iscoroutine") and inspect.iscoroutine( + value + ): # @UndefinedVariable return True, CoroutineWrapper(value) return False, value + def get_asyncio(): return asyncio + def is_coroutine(callable): if hasattr(callable, "_is_coroutine"): return True - if hasattr(inspect, "iscoroutinefunction") and\ - inspect.iscoroutinefunction(callable): #@UndefinedVariable + if hasattr(inspect, "iscoroutinefunction") and inspect.iscoroutinefunction( + callable + ): # @UndefinedVariable return True return False + def is_coroutine_object(generator): if legacy.is_generator(generator): return True - if hasattr(inspect, "iscoroutine") and\ - inspect.iscoroutine(generator): #@UndefinedVariable + if hasattr(inspect, "iscoroutine") and inspect.iscoroutine( + generator + ): # @UndefinedVariable return True return False + def is_coroutine_native(generator): - if hasattr(inspect, "iscoroutine") and\ - inspect.iscoroutine(generator): #@UndefinedVariable + if hasattr(inspect, "iscoroutine") and inspect.iscoroutine( + generator + ): # @UndefinedVariable return True return False + def is_future(future): - if isinstance(future, async_old.Future): return True - if asyncio and isinstance(future, asyncio.Future): return True + if isinstance(future, async_old.Future): + return True + if asyncio and isinstance(future, asyncio.Future): + return True return False -def _sleep(timeout, compat = True): + +def _sleep(timeout, compat=True): from .common import get_loop + loop = get_loop() compat &= hasattr(loop, "_sleep") sleep = loop._sleep if compat else loop.sleep result = yield from sleep(timeout) return result -def _wait(event, timeout = None, future = None): + +def _wait(event, timeout=None, future=None): from .common import get_loop + loop = get_loop() - result = yield from loop.wait( - event, - timeout = timeout, - future = future - ) + result = yield from loop.wait(event, timeout=timeout, future=future) return result + def sleep(*args, **kwargs): generator = _sleep(*args, **kwargs) return AwaitWrapper(generator) + def wait(*args, **kwargs): generator = _wait(*args, **kwargs) return AwaitWrapper(generator) + def coroutine_return(coroutine): """ Allows for the abstraction of the return value of a coroutine @@ -272,6 +290,7 @@ def coroutine_return(coroutine): generator = _coroutine_return(coroutine) return AwaitWrapper(generator) + def _coroutine_return(coroutine): for value in coroutine: yield value diff --git a/src/netius/base/async_old.py b/src/netius/base/async_old.py index 76e781442..75be5929f 100644 --- a/src/netius/base/async_old.py +++ b/src/netius/base/async_old.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -45,6 +36,7 @@ from . import errors from . import legacy + class Future(object): """ The base future object that represents a promise that a @@ -59,7 +51,7 @@ class Future(object): :see: https://en.wikipedia.org/wiki/Futures_and_promises """ - def __init__(self, loop = None): + def __init__(self, loop=None): self.status = 0 self._loop = loop self._blocking = False @@ -68,7 +60,8 @@ def __init__(self, loop = None): self.cleanup() def __iter__(self): - while not self.done(): yield self + while not self.done(): + yield self if self.cancelled(): raise errors.RuntimeError("Future canceled") if self.exception(): @@ -97,7 +90,7 @@ def result(self): raise errors.RuntimeError("Already canceled") return self._result - def exception(self, timeout = None): + def exception(self, timeout=None): return self._exception def partial(self, value): @@ -105,7 +98,8 @@ def partial(self, value): def add_done_callback(self, function): self.done_callbacks.append(function) - if not self.finished(): return + if not self.finished(): + return self._done_callbacks() def add_partial_callback(self, function): @@ -117,59 +111,63 @@ def add_ready_callback(self, function): def add_closed_callback(self, function): self.closed_callbacks.append(function) - def approve(self, cleanup = True): - self.set_result(None, cleanup = cleanup) + def approve(self, cleanup=True): + self.set_result(None, cleanup=cleanup) - def cancel(self, cleanup = True, force = False): - if not force and not self.running(): return False + def cancel(self, cleanup=True, force=False): + if not force and not self.running(): + return False self.status = 2 - self._done_callbacks(cleanup = cleanup) + self._done_callbacks(cleanup=cleanup) return True - def set_result(self, result, cleanup = True, force = False): + def set_result(self, result, cleanup=True, force=False): if not force and not self.running(): raise errors.AssertionError("Future not running") self.status = 1 self._result = result - self._done_callbacks(cleanup = cleanup) + self._done_callbacks(cleanup=cleanup) - def set_exception(self, exception, cleanup = True, force = False): + def set_exception(self, exception, cleanup=True, force=False): if not force and not self.running(): raise errors.AssertionError("Future not running") self.status = 1 self._exception = exception - self._done_callbacks(cleanup = cleanup) + self._done_callbacks(cleanup=cleanup) @property def ready(self): ready = True - for callback in self.ready_callbacks: ready &= callback() + for callback in self.ready_callbacks: + ready &= callback() return ready @property def closed(self): closed = False - for callback in self.closed_callbacks: closed |= callback() + for callback in self.closed_callbacks: + closed |= callback() return closed - def _done_callbacks(self, cleanup = True, delayed = True): - if not self.done_callbacks: return + def _done_callbacks(self, cleanup=True, delayed=True): + if not self.done_callbacks: + return if delayed and self._loop: return self._delay( - lambda: self._done_callbacks( - cleanup = cleanup, - delayed = False - ) + lambda: self._done_callbacks(cleanup=cleanup, delayed=False) ) - for callback in self.done_callbacks: callback(self) - if cleanup: self.cleanup() - - def _partial_callbacks(self, value, delayed = True): - if not self.partial_callbacks: return - if delayed and self._loop: return self._delay( - lambda: self._partial_callbacks(value, delayed = False) - ) - for callback in self.partial_callbacks: callback(self, value) + for callback in self.done_callbacks: + callback(self) + if cleanup: + self.cleanup() + + def _partial_callbacks(self, value, delayed=True): + if not self.partial_callbacks: + return + if delayed and self._loop: + return self._delay(lambda: self._partial_callbacks(value, delayed=False)) + for callback in self.partial_callbacks: + callback(self, value) def _wrap(self, future): self.status = future.status @@ -184,32 +182,39 @@ def _wrap(self, future): def _delay(self, callable): has_delay = hasattr(self._loop, "delay") - if has_delay: return self._loop.delay(callable, immediately = True) + if has_delay: + return self._loop.delay(callable, immediately=True) return self._loop.call_soon(callable) + class Task(Future): - def __init__(self, future = None): + def __init__(self, future=None): Future.__init__(self) self._future = future self._source_traceback = None - if future: self._wrap(future) + if future: + self._wrap(future) + class Handle(object): - def __init__(self, callable_t = None): + def __init__(self, callable_t=None): self._callable_t = callable_t def cancel(self): - if not self._callable_t: return + if not self._callable_t: + return options = self._callable_t[4] options[0] = False + class Executor(object): def submit(self, callable, *args, **kwargs): raise errors.NotImplemented("Missing implementation") + class ThreadPoolExecutor(Executor): def __init__(self, owner): @@ -217,22 +222,17 @@ def __init__(self, owner): def submit(self, callable, *args, **kwargs): future = self.owner.build_future() - callback = lambda result: self.owner.delay_s( - lambda: future.set_result(result) - ) - self.owner.texecute( - callable, - args = args, - kwargs = kwargs, - callback = callback - ) + callback = lambda result: self.owner.delay_s(lambda: future.set_result(result)) + self.owner.texecute(callable, args=args, kwargs=kwargs, callback=callback) return future + def coroutine(function): if inspect.isgeneratorfunction(function): routine = function else: + @functools.wraps(function) def routine(*args, **kwargs): # calls the underlying function with the expected arguments @@ -248,7 +248,8 @@ def routine(*args, **kwargs): # the complete set of values is properly yield # to the caller method as expected if is_future_ or is_generator: - for value in result: yield value + for value in result: + yield value # otherwise, the single resulting value is yield # to the caller method (simple propagation) @@ -258,7 +259,8 @@ def routine(*args, **kwargs): routine._is_coroutine = True return routine -def async_test_all(factory = None, close = True): + +def async_test_all(factory=None, close=True): def decorator(function): @@ -269,69 +271,94 @@ def decorator(function): def wrapper(*args, **kwargs): function_c = asynchronous.coroutine(function) future = function_c(*args, **kwargs) - loop = common.get_main(factory = factory) - return loop.run_coroutine(future, close = close) + loop = common.get_main(factory=factory) + return loop.run_coroutine(future, close=close) return wrapper return decorator + def async_test(function): decorator = async_test_all() return decorator(function) + def ensure_generator(value): - if legacy.is_generator(value): return True, value + if legacy.is_generator(value): + return True, value return False, value + def get_asyncio(): return None + def is_coroutine(callable): - if hasattr(callable, "_is_coroutine"): return True + if hasattr(callable, "_is_coroutine"): + return True return False + def is_coroutine_object(generator): - if legacy.is_generator(generator): return True + if legacy.is_generator(generator): + return True return False + def is_coroutine_native(generator): return False + def is_future(future): - if isinstance(future, Future): return True + if isinstance(future, Future): + return True return False + def is_neo(): return sys.version_info[0] >= 3 and sys.version_info[1] >= 3 + def is_asynclib(): return sys.version_info[0] >= 3 and sys.version_info[1] >= 4 + def is_await(): return sys.version_info[0] >= 3 and sys.version_info[1] >= 6 -def wakeup(force = False): + +def wakeup(force=False): from .common import get_loop + loop = get_loop() - return loop.wakeup(force = force) + return loop.wakeup(force=force) + -def sleep(timeout, compat = True, future = None): +def sleep(timeout, compat=True, future=None): from .common import get_loop + loop = get_loop() compat &= hasattr(loop, "_sleep") sleep = loop._sleep if compat else loop.sleep - for value in sleep(timeout, future = future): yield value + for value in sleep(timeout, future=future): + yield value + -def wait(event, timeout = None, future = None): +def wait(event, timeout=None, future=None): from .common import get_loop + loop = get_loop() - for value in loop.wait(event, timeout = timeout, future = future): yield value + for value in loop.wait(event, timeout=timeout, future=future): + yield value + -def notify(event, data = None): +def notify(event, data=None): from .common import get_loop + loop = get_loop() - return loop.notify(event, data = data) + return loop.notify(event, data=data) + def coroutine_return(coroutine): """ @@ -347,4 +374,5 @@ def coroutine_return(coroutine): and have its last future result returned from the generator. """ - for value in coroutine: yield value + for value in coroutine: + yield value diff --git a/src/netius/base/asynchronous.py b/src/netius/base/asynchronous.py index e11590a1f..a925774f5 100644 --- a/src/netius/base/asynchronous.py +++ b/src/netius/base/asynchronous.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -40,10 +31,11 @@ # imports the base (old) version of the async implementation # that should be compatible with all the available python # interpreters, base collection of async library -from .async_old import * #@UnusedWildImport +from .async_old import * # @UnusedWildImport # verifies if the current python interpreter version supports # the new version of the async implementation and if that's the # case runs the additional import of symbols, this should override # most of the symbols that have just been created -if is_neo(): from .async_neo import * #@UnusedWildImport +if is_neo(): + from .async_neo import * # @UnusedWildImport diff --git a/src/netius/base/client.py b/src/netius/base/client.py index 9d1ec86c9..d3e1312ef 100644 --- a/src/netius/base/client.py +++ b/src/netius/base/client.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -39,8 +30,8 @@ from . import request -from .conn import * #@UnusedWildImport -from .common import * #@UnusedWildImport +from .conn import * # @UnusedWildImport +from .common import * # @UnusedWildImport BUFFER_SIZE = None """ The size of the buffer that is going to be used in the @@ -52,6 +43,7 @@ collector of pending request in a datagram client, this value will be used n the delay operation of the action """ + class Client(Base): """ Abstract client implementation, should provide the required @@ -67,7 +59,7 @@ class Client(Base): system from exiting correctly, in order to prevent that the cleanup method should be called """ - def __init__(self, thread = True, daemon = False, *args, **kwargs): + def __init__(self, thread=True, daemon=False, *args, **kwargs): Base.__init__(self, *args, **kwargs) self.receive_buffer = kwargs.get("receive_buffer", BUFFER_SIZE) self.send_buffer = kwargs.get("send_buffer", BUFFER_SIZE) @@ -77,16 +69,18 @@ def __init__(self, thread = True, daemon = False, *args, **kwargs): @classmethod def get_client_s(cls, *args, **kwargs): - if cls._client: return cls._client + if cls._client: + return cls._client cls._client = cls(*args, **kwargs) return cls._client @classmethod def cleanup_s(cls): - if not cls._client: return + if not cls._client: + return cls._client.close() - def ensure_loop(self, env = True): + def ensure_loop(self, env=True): """ Ensures that the proper main loop thread requested in the building of the entity is started if that was requested. @@ -106,21 +100,26 @@ def ensure_loop(self, env = True): # verifies if the (run in) thread flag is set and that the there's # not thread currently created for the client in case any of these # conditions fails the control flow is returned immediately to caller - if not self.thread: return - if self._thread: return + if not self.thread: + return + if self._thread: + return # runs the various extra variable initialization taking into # account if the environment variable is currently set or not # please note that some side effects may arise from this set - if env: self.level = self.get_env("LEVEL", self.level) - if env: self.diag = self.get_env("CLIENT_DIAG", self.diag, cast = bool) - if env: self.logging = self.get_env("LOGGING", self.logging) - if env: self.poll_name = self.get_env("POLL", self.poll_name) - if env: self.poll_timeout = self.get_env( - "POLL_TIMEOUT", - self.poll_timeout, - cast = float - ) + if env: + self.level = self.get_env("LEVEL", self.level) + if env: + self.diag = self.get_env("CLIENT_DIAG", self.diag, cast=bool) + if env: + self.logging = self.get_env("LOGGING", self.logging) + if env: + self.poll_name = self.get_env("POLL", self.poll_name) + if env: + self.poll_timeout = self.get_env( + "POLL_TIMEOUT", self.poll_timeout, cast=float + ) # prints a debug message about the new thread to be created for # the client infra-structure (required for execution) @@ -129,17 +128,14 @@ def ensure_loop(self, env = True): # in case the thread flag is set a new thread must be constructed # for the running of the client's main loop then, these thread # may or may not be constructed using a daemon approach - self._thread = BaseThread( - owner = self, - daemon = self.daemon, - name = "Loop" - ) + self._thread = BaseThread(owner=self, daemon=self.daemon, name="Loop") self._thread.start() - def join(self, timeout = None): + def join(self, timeout=None): # runs the join operation in the thread associated with the client # so that the current thread blocks until the other ends execution - self._thread.join(timeout = timeout) + self._thread.join(timeout=timeout) + class DatagramClient(Client): @@ -156,7 +152,7 @@ def __init__(self, *args, **kwargs): def boot(self): Client.boot(self) - self.keep_gc(timeout = GC_TIMEOUT, run = False) + self.keep_gc(timeout=GC_TIMEOUT, run=False) def cleanup(self): Client.cleanup(self) @@ -166,12 +162,14 @@ def cleanup(self): def on_read(self, _socket): callbacks = self.callbacks_m.get(_socket, None) if callbacks: - for callback in callbacks: callback("read", _socket) + for callback in callbacks: + callback("read", _socket) # verifies if the provided socket for reading is the same # as the one registered in the client if that's not the case # return immediately to avoid unwanted operations - if not _socket == self.socket: return + if not _socket == self.socket: + return try: # iterates continuously trying to read as much data as possible @@ -185,8 +183,7 @@ def on_read(self, _socket): error_m = error.reason if hasattr(error, "reason") else None if error_v in SSL_SILENT_ERRORS: self.on_expected(error) - elif not error_v in SSL_VALID_ERRORS and\ - not error_m in SSL_VALID_REASONS: + elif not error_v in SSL_VALID_ERRORS and not error_m in SSL_VALID_REASONS: self.on_exception(error) except socket.error as error: error_v = error.args[0] if error.args else None @@ -202,12 +199,14 @@ def on_read(self, _socket): def on_write(self, _socket): callbacks = self.callbacks_m.get(_socket, None) if callbacks: - for callback in callbacks: callback("write", _socket) + for callback in callbacks: + callback("write", _socket) # verifies if the provided socket for writing is the same # as the one registered in the client if that's not the case # return immediately to avoid unwanted operations - if not _socket == self.socket: return + if not _socket == self.socket: + return try: self._send(_socket) @@ -216,8 +215,7 @@ def on_write(self, _socket): error_m = error.reason if hasattr(error, "reason") else None if error_v in SSL_SILENT_ERRORS: self.on_expected(error) - elif not error_v in SSL_VALID_ERRORS and\ - not error_m in SSL_VALID_REASONS: + elif not error_v in SSL_VALID_ERRORS and not error_m in SSL_VALID_REASONS: self.on_exception(error) except socket.error as error: error_v = error.args[0] if error.args else None @@ -233,12 +231,14 @@ def on_write(self, _socket): def on_error(self, _socket): callbacks = self.callbacks_m.get(_socket, None) if callbacks: - for callback in callbacks: callback("error", _socket) + for callback in callbacks: + callback("error", _socket) # verifies if the provided socket for error is the same # as the one registered in the client if that's not the case # return immediately to avoid unwanted operations - if not _socket == self.socket: return + if not _socket == self.socket: + return def on_exception(self, exception): self.warning(exception) @@ -250,15 +250,17 @@ def on_expected(self, exception): def on_data(self, connection, data): pass - def keep_gc(self, timeout = GC_TIMEOUT, run = True): - if run: self.gc() + def keep_gc(self, timeout=GC_TIMEOUT, run=True): + if run: + self.gc() self.delay(self.keep_gc, timeout) - def gc(self, callbacks = True): + def gc(self, callbacks=True): # in case there're no requests pending in the current client # there's no need to start the garbage collection logic, as # this would required some (minimal) resources - if not self.requests: return + if not self.requests: + return # prints a message (for debug) about the garbage collection # operation that is going to be run @@ -272,20 +274,23 @@ def gc(self, callbacks = True): # verifies if the requests structure (list) is empty and # if that's the case break the loop, nothing more remains # to be processed for the current garbage collection operation - if not self.requests: break + if not self.requests: + break # retrieves the top level request (peek operation) and # verifies if the timeout value of it has exceed the # current time if that's the case removes it as it # should no longer be handled (time out) request = self.requests[0] - if request.timeout > current: break + if request.timeout > current: + break self.remove_request(request) # in case the (call) callbacks flag is not set continues # the current loop, meaning that the associated callbacks # are not going to be called (as expected) - if not callbacks: continue + if not callbacks: + continue # extracts the callback method from the request and in # case it is defined and valid calls it with an invalid @@ -293,7 +298,7 @@ def gc(self, callbacks = True): # the call is encapsulated in a safe manner meaning that # any exception raised will be gracefully caught callback = request.callback - callback and self.call_safe(callback, args = [None]) + callback and self.call_safe(callback, args=[None]) def add_request(self, request): # adds the current request object to the list of requests @@ -309,13 +314,15 @@ def remove_request(self, request): def get_request(self, id): is_response = isinstance(id, request.Response) - if is_response: id = id.get_id() + if is_response: + id = id.get_id() return self.requests_m.get(id, None) def ensure_socket(self): # in case the socket is already created and valid returns immediately # as nothing else remain to be done in the current method - if self.socket: return + if self.socket: + return # prints a small debug message about the UDP socket that is going # to be created for the client's connection @@ -350,7 +357,8 @@ def ensure_write(self): # safe and so it must be delayed to be executed in the # next loop of the thread cycle, must return immediately # to avoid extra subscription operations - if not is_safe: return self.delay(self.ensure_write, safe = True) + if not is_safe: + return self.delay(self.ensure_write, safe=True) # adds the current socket to the list of write operations # so that it's going to be available for writing as soon @@ -361,29 +369,26 @@ def remove_write(self): self.unsub_write(self.socket) def enable_read(self): - if not self.renable == False: return + if not self.renable == False: + return self.renable = True self.sub_read(self.socket) def disable_read(self): - if not self.renable == True: return + if not self.renable == True: + return self.renable = False self.unsub_read(self.socket) - def send( - self, - data, - address, - delay = True, - ensure_loop = True, - callback = None - ): - if ensure_loop: self.ensure_loop() + def send(self, data, address, delay=True, ensure_loop=True, callback=None): + if ensure_loop: + self.ensure_loop() data = legacy.bytes(data) data_l = len(data) - if callback: data = (data, callback) + if callback: + data = (data, callback) data = (data, address) cthread = threading.current_thread() @@ -391,19 +396,18 @@ def send( is_safe = tid == self.tid self.pending_lock.acquire() - try: self.pending.appendleft(data) - finally: self.pending_lock.release() + try: + self.pending.appendleft(data) + finally: + self.pending_lock.release() self.pending_s += data_l if self.wready: - if is_safe and not delay: self._flush_write() - else: self.delay( - self._flush_write, - immediately = True, - verify = True, - safe = True - ) + if is_safe and not delay: + self._flush_write() + else: + self.delay(self._flush_write, immediately=True, verify=True, safe=True) else: self.ensure_write() @@ -414,7 +418,8 @@ def _send(self, _socket): while True: # in case there's no pending data to be sent to the # server side breaks the current loop (queue empty) - if not self.pending: break + if not self.pending: + break # retrieves the current data from the pending list # of data to be sent and then saves the original data @@ -428,7 +433,8 @@ def _send(self, _socket): # verifies if the data type of the data is a tuple and # if that's the case unpacks it as data and callback is_tuple = type(data) == tuple - if is_tuple: data, callback = data + if is_tuple: + data, callback = data # retrieves the length (in bytes) of the data that is # going to be sent to the server @@ -440,8 +446,10 @@ def _send(self, _socket): # sent through the socket, this number may not be # the same as the size of the data in case only # part of the data has been sent - if data: count = _socket.sendto(data, address) - else: count = 0 + if data: + count = _socket.sendto(data, address) + else: + count = 0 # verifies if the current situation is that of a non # closed socket and valid data, and if that's the case @@ -449,7 +457,8 @@ def _send(self, _socket): # be in a would block situation and and such an error # is raised indicating the issue (is going to be caught # as a normal would block exception) - if data and count == 0: raise socket.error(errno.EWOULDBLOCK) + if data and count == 0: + raise socket.error(errno.EWOULDBLOCK) except: # sets the current connection write ready flag to false # so that a new level notification must be received @@ -478,7 +487,8 @@ def _send(self, _socket): # sent latter (only then the callback is called) is_valid = count == data_l if is_valid: - if callback: callback(self) + if callback: + callback(self) else: data_o = ((data[count:], callback), address) self.pending.append(data_o) @@ -495,7 +505,8 @@ def _flush_write(self): """ self.ensure_socket() - self.writes((self.socket,), state = False) + self.writes((self.socket,), state=False) + class StreamClient(Client): @@ -514,26 +525,28 @@ def cleanup(self): def ticks(self): self.set_state(STATE_TICK) self._lid = (self._lid + 1) % 2147483647 - if self.pendings: self._connects() + if self.pendings: + self._connects() self._delays() - def info_dict(self, full = False): - info = Client.info_dict(self, full = full) - if full: info.update( - pendings = len(self.pendings), - free_conn = sum([len(value) for value in legacy.values(self.free_map)]) - ) + def info_dict(self, full=False): + info = Client.info_dict(self, full=full) + if full: + info.update( + pendings=len(self.pendings), + free_conn=sum([len(value) for value in legacy.values(self.free_map)]), + ) return info def acquire_c( self, host, port, - ssl = False, - key_file = None, - cer_file = None, - validate = False, - callback = None + ssl=False, + key_file=None, + cer_file=None, + validate=False, + callback=None, ): # sets the initial value of the connection instance variable # to invalid, this is going to be populated with a valid @@ -555,13 +568,15 @@ def acquire_c( # is still valid (open and ready), if that's not the case # unsets the connection variable connection = connection_l.pop() - if validate and not self.validate_c(connection): connection = None + if validate and not self.validate_c(connection): + connection = None # in case the connection has been invalidated (possible # disconnect) the current loop iteration is skipped and # a new connection from the list of connections in pool # is going to be searched and validated - if not connection: continue + if not connection: + continue # runs the connection acquire operation that should take # care of the proper acquisition notification process and @@ -575,11 +590,7 @@ def acquire_c( # the next execution cycle (delayed execution) if not connection: connection = self.connect( - host, - port, - ssl = ssl, - key_file = key_file, - cer_file = cer_file + host, port, ssl=ssl, key_file=key_file, cer_file=cer_file ) connection.tuple = connection_t @@ -588,7 +599,8 @@ def acquire_c( return connection def release_c(self, connection): - if not hasattr(connection, "tuple"): return + if not hasattr(connection, "tuple"): + return connection_t = connection.tuple connection_l = self.free_map.get(connection_t, []) connection_l.append(connection) @@ -596,12 +608,14 @@ def release_c(self, connection): self.on_release(connection) def remove_c(self, connection): - if not hasattr(connection, "tuple"): return + if not hasattr(connection, "tuple"): + return connection_t = connection.tuple connection_l = self.free_map.get(connection_t, []) - if connection in connection_l: connection_l.remove(connection) + if connection in connection_l: + connection_l.remove(connection) - def validate_c(self, connection, close = True): + def validate_c(self, connection, close=True): # sets the original valid flag value as true so that the # basic/default assumption on the connection is that it's # valid (per basis a connection is valid) @@ -610,7 +624,8 @@ def validate_c(self, connection, close = True): # tries to retrieve the value of the error options value of # the socket in case it's currently set unsets the valid flag error = connection.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) - if error: valid = False + if error: + valid = False # iterates continuously trying to read any pending data from # the connection, some of this data may indicate that the @@ -618,7 +633,8 @@ def validate_c(self, connection, close = True): while True: # verifies if the value of the valid flag is false and # if that's the case breaks the current loop immediately - if not valid: break + if not valid: + break # tries to read/receive any set of pending data from # the connection in case there's an exception and it's @@ -627,22 +643,28 @@ def validate_c(self, connection, close = True): # connection and sets it as invalid try: data = connection.recv() - if not data: raise errors.NetiusError("EOF received") + if not data: + raise errors.NetiusError("EOF received") connection.send(b"") except ssl.SSLError as error: error_v = error.args[0] if error.args else None - if error_v in SSL_VALID_ERRORS: break - if close: connection.close() + if error_v in SSL_VALID_ERRORS: + break + if close: + connection.close() valid = False except socket.error as error: error_v = error.args[0] if error.args else None - if error_v in VALID_ERRORS: break - if close: connection.close() + if error_v in VALID_ERRORS: + break + if close: + connection.close() valid = False except (KeyboardInterrupt, SystemExit): raise except: - if close: connection.close() + if close: + connection.close() valid = False # returns the final value on the connection validity test @@ -653,21 +675,23 @@ def connect( self, host, port, - ssl = False, - key_file = None, - cer_file = None, - ca_file = None, - ca_root = True, - ssl_verify = False, - family = socket.AF_INET, - type = socket.SOCK_STREAM, - ensure_loop = True, - env = True + ssl=False, + key_file=None, + cer_file=None, + ca_file=None, + ca_root=True, + ssl_verify=False, + family=socket.AF_INET, + type=socket.SOCK_STREAM, + ensure_loop=True, + env=True, ): # runs a series of pre-validations on the provided parameters, raising # exceptions in case they do not comply with expected values - if not host: raise errors.NetiusError("Invalid host for connect operation") - if not port: raise errors.NetiusError("Invalid port for connect operation") + if not host: + raise errors.NetiusError("Invalid host for connect operation") + if not port: + raise errors.NetiusError("Invalid port for connect operation") # tries to retrieve some of the environment variable related values # so that some of these values are accessible via an external environment @@ -675,16 +699,19 @@ def connect( key_file = self.get_env("KEY_FILE", key_file) if env else key_file cer_file = self.get_env("CER_FILE", cer_file) if env else cer_file ca_file = self.get_env("CA_FILE", ca_file) if env else ca_file - ca_root = self.get_env("CA_ROOT", ca_root, cast = bool) if env else ca_root - ssl_verify = self.get_env("SSL_VERIFY", ssl_verify, cast = bool) if env else ssl_verify - key_file = self.get_env("KEY_DATA", key_file, expand = True) if env else key_file - cer_file = self.get_env("CER_DATA", cer_file, expand = True) if env else cer_file - ca_file = self.get_env("CA_DATA", ca_file, expand = True) if env else ca_file + ca_root = self.get_env("CA_ROOT", ca_root, cast=bool) if env else ca_root + ssl_verify = ( + self.get_env("SSL_VERIFY", ssl_verify, cast=bool) if env else ssl_verify + ) + key_file = self.get_env("KEY_DATA", key_file, expand=True) if env else key_file + cer_file = self.get_env("CER_DATA", cer_file, expand=True) if env else cer_file + ca_file = self.get_env("CA_DATA", ca_file, expand=True) if env else ca_file # ensures that a proper loop cycle is available for the current # client, otherwise the connection operation would become stalled # because there's no listening of events for it - if ensure_loop: self.ensure_loop() + if ensure_loop: + self.ensure_loop() # ensures that the proper socket family is defined in case the # requested host value is unix socket oriented, this step greatly @@ -715,36 +742,28 @@ def connect( # in case the SSL flag is set re-creates the socket by wrapping it into # an SSL based one with the provided set of keys and certificates - if ssl: _socket = self._ssl_wrap( - _socket, - key_file = key_file, - cer_file = cer_file, - ca_file = ca_file, - ca_root = ca_root, - server = False, - ssl_verify = ssl_verify, - server_hostname = host - ) + if ssl: + _socket = self._ssl_wrap( + _socket, + key_file=key_file, + cer_file=cer_file, + ca_file=ca_file, + ca_root=ca_root, + server=False, + ssl_verify=ssl_verify, + server_hostname=host, + ) # sets the appropriate socket options enable it for port re-usage and # for proper keep alive notification, among others _socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) _socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) - if is_inet: _socket.setsockopt( - socket.IPPROTO_TCP, - socket.TCP_NODELAY, - 1 - ) - if self.receive_buffer: _socket.setsockopt( - socket.SOL_SOCKET, - socket.SO_RCVBUF, - self.receive_buffer - ) - if self.send_buffer: _socket.setsockopt( - socket.SOL_SOCKET, - socket.SO_SNDBUF, - self.send_buffer - ) + if is_inet: + _socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + if self.receive_buffer: + _socket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, self.receive_buffer) + if self.send_buffer: + _socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, self.send_buffer) self._socket_keepalive(_socket) # constructs the address tuple taking into account if the @@ -755,15 +774,18 @@ def connect( # creates the connection object using the typical constructor # and then sets the SSL host (for verification) if the verify # SSL option is defined (secured and verified connection) - connection = self.build_connection(_socket, address, ssl = ssl) - if ssl_verify: connection.ssl_host = host + connection = self.build_connection(_socket, address, ssl=ssl) + if ssl_verify: + connection.ssl_host = host # acquires the pending lock so that it's safe to add an element # to the list of pending connection for connect, this lock is # then released in the final part of the operation self._pending_lock.acquire() - try: self.pendings.append(connection) - finally: self._pending_lock.release() + try: + self.pendings.append(connection) + finally: + self._pending_lock.release() # returns the "final" connection, that is now scheduled for connect # to the caller method, it may now be used for operations @@ -779,47 +801,58 @@ def on_read(self, _socket): # to the execution of the read operation in the socket callbacks = self.callbacks_m.get(_socket, None) if callbacks: - for callback in callbacks: callback("read", _socket) + for callback in callbacks: + callback("read", _socket) # retrieves the connection object associated with the # current socket that is going to be read in case there's # no connection available or the status is not open # must return the control flow immediately to the caller connection = self.connections_m.get(_socket, None) - if not connection: return - if not connection.status == OPEN: return - if not connection.renable == True: return + if not connection: + return + if not connection.status == OPEN: + return + if not connection.renable == True: + return try: # in case the connection is under the connecting state # the socket must be verified for errors and in case # there's none the connection must proceed, for example # the SSL connection handshake must be performed/retried - if connection.connecting: self._connectf(connection) + if connection.connecting: + self._connectf(connection) # verifies if there's any pending operations in the # connection (eg: SSL handshaking) and performs it trying # to finish them, if they are still pending at the current # state returns immediately (waits for next loop) - if self._pending(connection): return + if self._pending(connection): + return # iterates continuously trying to read as much data as possible # when there's a failure to read more data it should raise an # exception that should be handled properly while True: data = connection.recv(CHUNK_SIZE) - if data: self.on_data(connection, data) - else: connection.close(); break - if not connection.status == OPEN: break - if not connection.renable == True: break - if not connection.socket == _socket: break + if data: + self.on_data(connection, data) + else: + connection.close() + break + if not connection.status == OPEN: + break + if not connection.renable == True: + break + if not connection.socket == _socket: + break except ssl.SSLError as error: error_v = error.args[0] if error.args else None error_m = error.reason if hasattr(error, "reason") else None if error_v in SSL_SILENT_ERRORS: self.on_expected(error, connection) - elif not error_v in SSL_VALID_ERRORS and\ - not error_m in SSL_VALID_REASONS: + elif not error_v in SSL_VALID_ERRORS and not error_m in SSL_VALID_REASONS: self.on_exception(error, connection) except socket.error as error: error_v = error.args[0] if error.args else None @@ -838,20 +871,24 @@ def on_write(self, _socket): # to the execution of the read operation in the socket callbacks = self.callbacks_m.get(_socket, None) if callbacks: - for callback in callbacks: callback("write", _socket) + for callback in callbacks: + callback("write", _socket) # retrieves the connection associated with the socket that # is ready for the write operation and verifies that it # exists and the current status of it is open (required) connection = self.connections_m.get(_socket, None) - if not connection: return - if not connection.status == OPEN: return + if not connection: + return + if not connection.status == OPEN: + return # in case the connection is under the connecting state # the socket must be verified for errors and in case # there's none the connection must proceed, for example # the SSL connection handshake must be performed/retried - if connection.connecting: self._connectf(connection) + if connection.connecting: + self._connectf(connection) try: connection._send() @@ -860,8 +897,7 @@ def on_write(self, _socket): error_m = error.reason if hasattr(error, "reason") else None if error_v in SSL_SILENT_ERRORS: self.on_expected(error, connection) - elif not error_v in SSL_VALID_ERRORS and\ - not error_m in SSL_VALID_REASONS: + elif not error_v in SSL_VALID_ERRORS and not error_m in SSL_VALID_REASONS: self.on_exception(error, connection) except socket.error as error: error_v = error.args[0] if error.args else None @@ -877,11 +913,14 @@ def on_write(self, _socket): def on_error(self, _socket): callbacks = self.callbacks_m.get(_socket, None) if callbacks: - for callback in callbacks: callback("error", _socket) + for callback in callbacks: + callback("error", _socket) connection = self.connections_m.get(_socket, None) - if not connection: return - if not connection.status == OPEN: return + if not connection: + return + if not connection.status == OPEN: + return connection.close() @@ -897,8 +936,8 @@ def on_expected(self, exception, connection): def on_connect(self, connection): self.debug( - "Connection '%s' %s from '%s' connected" % - (connection.id, connection.address, connection.owner.name) + "Connection '%s' %s from '%s' connected" + % (connection.id, connection.address, connection.owner.name) ) connection.set_connected() if hasattr(connection, "tuple"): @@ -906,8 +945,8 @@ def on_connect(self, connection): def on_upgrade(self, connection): self.debug( - "Connection '%s' %s from '%s' upgraded" % - (connection.id, connection.address, connection.owner.name) + "Connection '%s' %s from '%s' upgraded" + % (connection.id, connection.address, connection.owner.name) ) connection.set_upgraded() @@ -926,8 +965,10 @@ def on_ssl(self, connection): # and calls the proper event handler for each event, this is # required because the connection workflow is probably dependent # on the calling of these event handlers to proceed - if connection.connecting: self.on_connect(connection) - elif connection.upgrading: self.on_upgrade(connection) + if connection.connecting: + self.on_connect(connection) + elif connection.upgrading: + self.on_upgrade(connection) def on_acquire(self, connection): pass @@ -954,19 +995,24 @@ def _connectf(self, connection): # in case the SSL connection is still undergoing the handshaking # procedures (marked as connecting) ignores the call as this must # be a duplicated call to this method (to be ignored) - if connection.ssl_connecting: return + if connection.ssl_connecting: + return # verifies if there was an error in the middle of the connection # operation and if that's the case calls the proper callback and # returns the control flow to the caller method error = connection.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) - if error: self.on_error(connection.socket); return + if error: + self.on_error(connection.socket) + return # checks if the current connection is SSL based and if that's the # case starts the handshaking process (async non blocking) otherwise # calls the on connect callback with the newly created connection - if connection.ssl: connection.add_starter(self._ssl_handshake) - else: self.on_connect(connection) + if connection.ssl: + connection.add_starter(self._ssl_handshake) + else: + self.on_connect(connection) # runs the starter process (initial kick-off) so that all the starters # registered for the connection may start to be executed, note that if @@ -987,18 +1033,20 @@ def _connect(self, connection): # in case the current connection has been closed meanwhile # the current connection is meant to be avoided and so the # method must return immediately to the caller method - if connection.status == CLOSED: return + if connection.status == CLOSED: + return # retrieves the socket associated with the connection # and calls the open method of the connection to proceed # with the correct operations for the connection _socket = connection.socket - connection.open(connect = True) + connection.open(connect=True) # tries to run the non blocking connection it should # fail and the connection should only be considered as # open when a write event is raised for the connection - try: _socket.connect(connection.address) + try: + _socket.connect(connection.address) except ssl.SSLError as error: error_v = error.args[0] if error.args else None if not error_v in SSL_VALID_ERRORS: @@ -1033,7 +1081,8 @@ def _connect(self, connection): # in case the connection is not of type SSL the method # may return as there's nothing left to be done, as the # rest of the method is dedicated to SSL tricks - if not connection.ssl: return + if not connection.ssl: + return # verifies if the current SSL object is a context oriented one # (newest versions) or a legacy oriented one, that does not uses @@ -1046,12 +1095,11 @@ def _connect(self, connection): # destroyed by the underlying SSL library (as an error) because # the socket is of type non blocking and raises an error, note # that the creation of the socket varies between SSL versions - if _socket._sslobj: return + if _socket._sslobj: + return if has_context: _socket._sslobj = _socket.context._wrap_socket( - _socket, - _socket.server_side, - _socket.server_hostname + _socket, _socket.server_side, _socket.server_hostname ) else: _socket._sslobj = ssl._ssl.sslwrap( @@ -1061,7 +1109,7 @@ def _connect(self, connection): _socket.certfile, _socket.cert_reqs, _socket.ssl_version, - _socket.ca_certs + _socket.ca_certs, ) # verifies if the SSL object class is defined in the SSL module @@ -1069,9 +1117,12 @@ def _connect(self, connection): # in order to comply with new indirection/abstraction method, under # some circumstances this operations fails with an exception because # the wrapping operation is not allowed for every Python environment - if not hasattr(ssl, "SSLObject"): return - try: _socket._sslobj = ssl.SSLObject(_socket._sslobj, owner = _socket) - except TypeError: pass + if not hasattr(ssl, "SSLObject"): + return + try: + _socket._sslobj = ssl.SSLObject(_socket._sslobj, owner=_socket) + except TypeError: + pass def _ssl_handshake(self, connection): Client._ssl_handshake(self, connection) @@ -1079,7 +1130,8 @@ def _ssl_handshake(self, connection): # verifies if the socket still has finished the SSL handshaking # process (by verifying the appropriate flag) and then if that's # not the case returns immediately (nothing done) - if not connection.ssl_handshake: return + if not connection.ssl_handshake: + return # prints a debug information notifying the developer about # the finishing of the handshaking process for the connection diff --git a/src/netius/base/common.py b/src/netius/base/common.py index 216cba57f..a495d4b0e 100644 --- a/src/netius/base/common.py +++ b/src/netius/base/common.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -57,9 +48,9 @@ from .. import middleware -from .conn import * #@UnusedWildImport -from .poll import * #@UnusedWildImport -from .asynchronous import * #@UnusedWildImport +from .conn import * # @UnusedWildImport +from .poll import * # @UnusedWildImport +from .asynchronous import * # @UnusedWildImport NAME = "netius" """ The global infra-structure name to be used in the @@ -77,7 +68,7 @@ sys.version_info[1], sys.version_info[2], sys.version_info[3], - sys.platform + sys.platform, ) """ Extra system information containing some of the details of the technical platform that is running the system, this @@ -99,8 +90,11 @@ development like environment as it shows critical information about the system internals that may expose the system """ -IDENTIFIER = IDENTIFIER_LONG if config._is_devel() else\ - IDENTIFIER_TINY if config._is_secure() else IDENTIFIER_SHORT +IDENTIFIER = ( + IDENTIFIER_LONG + if config._is_devel() + else IDENTIFIER_TINY if config._is_secure() else IDENTIFIER_SHORT +) """ The identifier that may be used to identify an user agent or service running under the current platform, this string should comply with the typical structure for such values, @@ -130,12 +124,7 @@ already present in the hash table, this error may be safely ignored as it does not represent a threat """ -POLL_ORDER = ( - EpollPoll, - KqueuePoll, - PollPoll, - SelectPoll -) +POLL_ORDER = (EpollPoll, KqueuePoll, PollPoll, SelectPoll) """ The order from which the poll methods are going to be selected from the fastest to the slowest, in case no explicit poll method is defined for a base service they are selected @@ -146,7 +135,7 @@ errno.ECONNRESET, errno.EPIPE, WSAECONNABORTED, - WSAECONNRESET + WSAECONNRESET, ) """ List that contain the various connection error states that should not raise any extra logging information because even though @@ -158,15 +147,12 @@ errno.EPERM, errno.ENOENT, errno.EINPROGRESS, - WSAEWOULDBLOCK + WSAEWOULDBLOCK, ) """ List containing the complete set of error that represent non ready operations in a non blocking socket """ -SSL_SILENT_ERRORS = ( - ssl.SSL_ERROR_EOF, - ssl.SSL_ERROR_ZERO_RETURN -) +SSL_SILENT_ERRORS = (ssl.SSL_ERROR_EOF, ssl.SSL_ERROR_ZERO_RETURN) """ The list containing the errors that should be silenced while still making the connection dropped as they are expected to occur and should not be considered an exception """ @@ -174,22 +160,20 @@ SSL_VALID_ERRORS = ( ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE, - SSL_ERROR_CERT_ALREADY_IN_HASH_TABLE + SSL_ERROR_CERT_ALREADY_IN_HASH_TABLE, ) """ The list containing the valid errors for the handshake operation of the SSL connection establishment """ SSL_ERROR_NAMES = { - ssl.SSL_ERROR_WANT_READ : "SSL_ERROR_WANT_READ", - ssl.SSL_ERROR_WANT_WRITE : "SSL_ERROR_WANT_WRITE", - SSL_ERROR_CERT_ALREADY_IN_HASH_TABLE : "SSL_ERROR_CERT_ALREADY_IN_HASH_TABLE" + ssl.SSL_ERROR_WANT_READ: "SSL_ERROR_WANT_READ", + ssl.SSL_ERROR_WANT_WRITE: "SSL_ERROR_WANT_WRITE", + SSL_ERROR_CERT_ALREADY_IN_HASH_TABLE: "SSL_ERROR_CERT_ALREADY_IN_HASH_TABLE", } """ The dictionary containing the association between the various SSL errors and their string representation """ -SSL_VALID_REASONS = ( - "CERT_ALREADY_IN_HASH_TABLE", -) +SSL_VALID_REASONS = ("CERT_ALREADY_IN_HASH_TABLE",) """ The list containing the valid reasons for the handshake operation of the SSL connection establishment """ @@ -254,7 +238,7 @@ "TICK", "READ", "WRITE", - "ERROR" + "ERROR", ) """ Sequence that contains the various strings associated with the various states for the base service, this may be used to @@ -294,8 +278,11 @@ SSL_CER_PATH = os.path.join(EXTRAS_PATH, "net.cer") SSL_CA_PATH = os.path.join(EXTRAS_PATH, "net.ca") SSL_DH_PATH = os.path.join(EXTRAS_PATH, "dh.pem") -if not os.path.exists(SSL_CA_PATH): SSL_CA_PATH = None -if not os.path.exists(SSL_DH_PATH): SSL_DH_PATH = None +if not os.path.exists(SSL_CA_PATH): + SSL_CA_PATH = None +if not os.path.exists(SSL_DH_PATH): + SSL_DH_PATH = None + class AbstractBase(observer.Observable): """ @@ -317,7 +304,7 @@ class AbstractBase(observer.Observable): should be used to provide compatibility with protocol and transports used by the new API """ - def __init__(self, name = None, handlers = None, *args, **kwargs): + def __init__(self, name=None, handlers=None, *args, **kwargs): observer.Observable.__init__(self, *args, **kwargs) cls = self.__class__ poll = cls.test_poll() @@ -374,7 +361,7 @@ def __init__(self, name = None, handlers = None, *args, **kwargs): self.set_state(STATE_STOP) @classmethod - def test_poll(cls, preferred = None): + def test_poll(cls, preferred=None): # sets the initial selected variable with the unselected # (invalid) value so that at lease one selection must be # done in order for this method to succeed @@ -385,50 +372,57 @@ def test_poll(cls, preferred = None): # the current situation, either the preferred poll method or # the most performant one in case it's not possible for poll in POLL_ORDER: - if not poll.test(): continue - if not selected: selected = poll - if not preferred: break + if not poll.test(): + continue + if not selected: + selected = poll + if not preferred: + break name = poll.name() - if not name == preferred: continue + if not name == preferred: + continue selected = poll break # in case no polling method was selected must raise an exception # indicating that no valid polling mechanism is available - if not selected: raise errors.NetiusError( - "No valid poll mechanism available" - ) + if not selected: + raise errors.NetiusError("No valid poll mechanism available") # returns the selected polling mechanism class to the caller method # as expected by the current method return selected @classmethod - def get_loop(cls, compat = False, asyncio = False): + def get_loop(cls, compat=False, asyncio=False): loop = cls.get_asyncio() if asyncio else None - loop = loop or cls.get_main(compat = compat) + loop = loop or cls.get_main(compat=compat) return loop @classmethod - def get_main(cls, compat = False): + def get_main(cls, compat=False): return cls._MAIN_C if compat else cls._MAIN @classmethod def get_asyncio(cls): asyncio = asynchronous.get_asyncio() - if not asyncio: return None + if not asyncio: + return None policy = asyncio.get_event_loop_policy() - if not policy._local._loop: return None + if not policy._local._loop: + return None return asyncio.get_event_loop() @classmethod - def set_main(cls, instance, set_compat = True): + def set_main(cls, instance, set_compat=True): compat = compat_loop(instance) cls._MAIN = instance cls._MAIN_C = compat - if not set_compat: return + if not set_compat: + return asyncio = asynchronous.get_asyncio() - if not asyncio: return + if not asyncio: + return # runs a series of patches in the current asyncio # infra-structure to make sure that it's ready to @@ -441,17 +435,19 @@ def set_main(cls, instance, set_compat = True): asyncio.set_event_loop(compat) @classmethod - def unset_main(cls, set_compat = True): - cls.set_main(None, set_compat = set_compat) + def unset_main(cls, set_compat=True): + cls.set_main(None, set_compat=set_compat) @classmethod def patch_asyncio(cls): asyncio = asynchronous.get_asyncio() - if not asyncio: return - if hasattr(asyncio, "_patched"): return + if not asyncio: + return + if hasattr(asyncio, "_patched"): + return if hasattr(asyncio.tasks, "_PyTask"): - asyncio.Task = asyncio.tasks._PyTask #@UndefinedVariable - asyncio.tasks.Task = asyncio.tasks._PyTask #@UndefinedVariable + asyncio.Task = asyncio.tasks._PyTask # @UndefinedVariable + asyncio.tasks.Task = asyncio.tasks._PyTask # @UndefinedVariable asyncio._patched = True @classmethod @@ -464,7 +460,8 @@ def waitpid(cls, pid): # needs to verify if an os error is raised with # the value 3 (interrupted system call) as python # does not handle these errors correctly - if error.errno == 4: continue + if error.errno == 4: + continue raise def destroy(self): @@ -472,14 +469,15 @@ def destroy(self): # iterates over the complete set of sockets in the connections # map to properly close them (avoids any leak of resources) - for _socket in self.connections_m: _socket.close() + for _socket in self.connections_m: + _socket.close() # clears some of the internal structure so that they don't # get called any longer (as expected) self.connections_m.clear() self.callbacks_m.clear() - def call_safe(self, callable, args = [], kwargs = {}): + def call_safe(self, callable, args=[], kwargs={}): """ Calls the provided callable object using a safe strategy meaning that in case there's an exception raised in the @@ -511,13 +509,14 @@ def call_safe(self, callable, args = [], kwargs = {}): self.warning(exception) self.log_stack() - def wait_event(self, callable, name = None): + def wait_event(self, callable, name=None): # tries to retrieve the list of binds for the event # to be "waited" for, this list should contain the # complete list of callables to be called upon the # event notification/trigger binds = self._events.get(name, []) - if callable in binds: return + if callable in binds: + return # adds the callable to the list of binds for the event # the complete set of callables will be called whenever @@ -525,12 +524,13 @@ def wait_event(self, callable, name = None): binds.append(callable) self._events[name] = binds - def unwait_event(self, callable, name = None): + def unwait_event(self, callable, name=None): # tries to retrieve the list of binds for the event # and verifies that the callable is present on them # and if that's not the case ignores the operation binds = self._events.get(name, None) - if not binds or not callable in binds: return + if not binds or not callable in binds: + return # removes the callable from the binds list so that # it's no longer going to be called @@ -538,39 +538,35 @@ def unwait_event(self, callable, name = None): # verifies if the binds list is still valid deleting # it from the map of events otherwise - if binds: self._events[name] = binds - else: del self._events[name] + if binds: + self._events[name] = binds + else: + del self._events[name] def delay( - self, - callable, - timeout = None, - immediately = False, - verify = False, - safe = False + self, callable, timeout=None, immediately=False, verify=False, safe=False ): # in case the safe flag is set and the thread trying to add # delayed elements is not the main the proper (safe) method # is used meaning a safe execution is targeted if safe and not self.is_main(): return self.delay_s( - callable, - timeout = timeout, - immediately = immediately, - verify = verify + callable, timeout=timeout, immediately=immediately, verify=verify ) # in case the legacy module is no longer defined (probably # at exit execution) then returns immediately as it's not # possible to proceed with this execution - if not legacy: return + if not legacy: + return # creates the original target value with a zero value (forced # execution in next tick) in case the timeout value is set the # value is incremented to the current time, then created the # callable original tuple with the target (time) and the callable target = -1 if immediately else 0 - if timeout: target = time.time() + timeout + if timeout: + target = time.time() + timeout callable_o = (target, callable) callable_o = legacy.orderable(callable_o) @@ -578,7 +574,8 @@ def delay( # is already inserted in the list of delayed operations in # case it does returns immediately to avoid duplicated values is_duplicate = verify and callable_o in self._delayed_o - if is_duplicate: return + if is_duplicate: + return # creates the list that is going to be used to populate the # options to be used by the calling tuple @@ -602,12 +599,7 @@ def delay( return callable_t def delay_s( - self, - callable, - timeout = None, - immediately = True, - verify = False, - wakeup = True + self, callable, timeout=None, immediately=True, verify=False, wakeup=True ): """ Safe version of the delay operation to be used to insert a callable @@ -644,13 +636,16 @@ def delay_s( # the delayed (next) list is only going to be joined/merged with delay # operations and list on the next tick (through the merge operation) self._delayed_l.acquire() - try: self._delayed_n.append(next) - finally: self._delayed_l.release() + try: + self._delayed_n.append(next) + finally: + self._delayed_l.release() # in case the wakeup flag is set this delay operation should have # been called from a different thread and the event loop should # awaken as soon as possible to handle the event - if wakeup: self.wakeup() + if wakeup: + self.wakeup() def delay_m(self): """ @@ -661,17 +656,15 @@ def delay_m(self): # verifies if the delay next list is not valid or empty and if that's # the case returns immediately as there's nothing to be merged - if not self._delayed_n: return + if not self._delayed_n: + return # iterates over the complete set of next elements in the delay next list # and schedules them as delay for the next tick execution for next in self._delayed_n: callable, timeout, immediately, verify = next self.delay( - callable, - timeout = timeout, - immediately = immediately, - verify = verify + callable, timeout=timeout, immediately=immediately, verify=verify ) # deletes the complete set of elements present in the delay next list, this @@ -679,13 +672,7 @@ def delay_m(self): del self._delayed_n[:] def ensure( - self, - coroutine, - args = [], - kwargs = {}, - thread = None, - future = None, - immediately = True + self, coroutine, args=[], kwargs={}, thread=None, future=None, immediately=True ): """ Main method for the queuing/startup of an asynchronous coroutine @@ -736,7 +723,8 @@ def ensure( is_coroutine = asynchronous.is_coroutine(coroutine) is_coroutine_object = asynchronous.is_coroutine_object(coroutine) is_defined = is_coroutine or is_coroutine_object - if thread == None: thread = False if is_defined else True + if thread == None: + thread = False if is_defined else True # verifies if a future variable is meant to be re-used # or if instead a new one should be created for the new @@ -764,9 +752,12 @@ def coroutine(future, *args, **kwargs): # operation from the "parent" future to the child one, this # should also close the associated generator def cleanup(future): - if not future.cancelled(): return - if not hasattr(future, "child"): return - if not future.child: return + if not future.cancelled(): + return + if not hasattr(future, "child"): + return + if not future.child: + return future.child.cancel() # adds the cleanup function as a done callback so that whenever @@ -795,22 +786,27 @@ def cleanup(future): # will be used for the control of the execution, notice that # the future is only passed in case the coroutine has been # determined to be receiving the future as first argument - if is_future: sequence = coroutine(future, *args, **kwargs) - else: sequence = coroutine(*args, **kwargs) + if is_future: + sequence = coroutine(future, *args, **kwargs) + else: + sequence = coroutine(*args, **kwargs) # calls the ensure generator method so that the provided sequence # gets properly "normalized" into the expected generator structure # in case the normalization is not possible a proper exception is # raised indicating the "critical" problem is_generator, sequence = asynchronous.ensure_generator(sequence) - if not is_generator: raise errors.AssertionError("Expected generator") + if not is_generator: + raise errors.AssertionError("Expected generator") # creates the callable that is going to be used to call # the coroutine with the proper future variable as argument # note that in case the thread mode execution is enabled the # callable is going to be executed on a different thread - if thread: callable = lambda f = future: self.texecute(step, [f]) - else: callable = lambda f = future: step(f) + if thread: + callable = lambda f=future: self.texecute(step, [f]) + else: + callable = lambda f=future: step(f) # creates the function that will be used to step through the # various elements in the sequence created from the calling of @@ -864,10 +860,12 @@ def step(_future): # and then breaks the loop, notice that if there's an # exception raised in the middle of the generator iteration # it's set on the future (indirect notification) - try: value = next(sequence) + try: + value = next(sequence) except StopIteration as exception: result = exception.args[0] if exception.args else None - if future.running(): future.set_result(result) + if future.running(): + future.set_result(result) break except BaseException as exception: future.set_exception(exception) @@ -894,6 +892,7 @@ def step(_future): # so that the event loop on the main thread gets unblocked and # the proper partial value handling is performed (always on main thread) if thread: + def handler(): future.partial(value) callable() @@ -908,10 +907,10 @@ def handler(): # delays the execution of the callable so that it is executed # immediately if possible (event on the same iteration) - self.delay(callable, immediately = immediately) + self.delay(callable, immediately=immediately) return future - def resolve_hostname(self, hostname, type = "a"): + def resolve_hostname(self, hostname, type="a"): """ Resolve the provided hostname according to the provided type resolution. The resolution process itself is asynchronous and @@ -941,27 +940,16 @@ def handler(response): future.set_result(address) - netius.clients.DNSClient.query_s( - hostname, - type = type, - callback = handler - ) + netius.clients.DNSClient.query_s(hostname, type=type, callback=handler) return future def run_forever(self): # starts the current event loop, this is a blocking operation until # the the stop method is called to unblock the loop - self.forever(env = False) + self.forever(env=False) - def run_coroutine( - self, - coroutine, - args = [], - kwargs = {}, - thread = None, - close = None - ): + def run_coroutine(self, coroutine, args=[], kwargs={}, thread=None, close=None): # creates the callback function that is going to be called when # the future associated with the provided ensure context gets # finished (on done callback) @@ -978,12 +966,11 @@ def cleanup(future): # ensures that the provided coroutine get executed under a new # context and retrieves the resulting future - future = self.ensure( - coroutine, - args = args, - kwargs = kwargs, - thread = thread - ) if is_coroutine else coroutine + future = ( + self.ensure(coroutine, args=args, kwargs=kwargs, thread=thread) + if is_coroutine + else coroutine + ) # defines the cleanup operation (loop stop) as the target for the # done operation on the future (allows cleanup) @@ -998,27 +985,31 @@ def cleanup(future): # execution and returns the control flow immediately with # the future's result, to be used by the caller exception = future.exception() - if not exception: return future.result() + if not exception: + return future.result() # raises the exception to the upper layers so that it's properly # handled by them, this is the expected behaviour by this sync # execution mode of the coroutine inside an event loop raise exception - def wakeup(self, force = False): + def wakeup(self, force=False): # verifies if this is the main thread and if that's not the case # and the force flag is not set ignore the wakeup operation, avoiding # extra usage of resources (not required) - if self.is_main() and not force: return + if self.is_main() and not force: + return # makes sure that the the notify pool is started (required for proper # event notification) and then runs the notification process, should # "wake" the main event loop as soon as possible - if force: self.nensure() - if not self.npool: return + if force: + self.nensure() + if not self.npool: + return self.npool.notify() - def sleep(self, timeout, future = None): + def sleep(self, timeout, future=None): # verifies if a future variable is meant to be re-used # or if instead a new one should be created for the new # sleep operation to be executed @@ -1031,10 +1022,10 @@ def sleep(self, timeout, future = None): # delays the execution of the callable so that it is executed # after the requested amount of timeout, note that the resolution # of the event loop will condition the precision of the timeout - self.delay(callable, timeout = timeout) + self.delay(callable, timeout=timeout) return future - def wait(self, event, timeout = None, future = None): + def wait(self, event, timeout=None, future=None): # verifies if a future variable is meant to be re-used # or if instead a new one should be created for the new # sleep operation to be executed @@ -1044,21 +1035,23 @@ def wait(self, event, timeout = None, future = None): # the final value of the future variable, the result # set in the future represents the payload of the event def callable(data): - if future.cancelled(): return + if future.cancelled(): + return future.set_result(data) # creates the callable that is going to be called in case # the timeout has been reached, this avoids constant waiting # for an event to happen (dead lock) def canceler(): - if future.done(): return + if future.done(): + return future.cancel() # creates the callback function that is going to be called # whenever the future is completed (either error or success) # this should run the series of cleanup operations def cleanup(future): - self.unwait_event(callable, name = event) + self.unwait_event(callable, name=event) # registers the cleanup function for the done operation so that # the waiting for the event is canceled whenever the future is @@ -1068,30 +1061,32 @@ def cleanup(future): # waits the execution of the callable until the event with the # provided name is notified/triggered, the execution should be # triggered on the same event loop tick as the notification - self.wait_event(callable, name = event) + self.wait_event(callable, name=event) # in case a valid timeout is set schedules the canceler operation # to be performed (to unblock the waiting element) - if timeout: self.delay(canceler, timeout = timeout) + if timeout: + self.delay(canceler, timeout=timeout) # returns the provided future or a new one in case none has been # provided, this will be used for proper event registration return future - def notify(self, event, data = None): + def notify(self, event, data=None): # adds the event with the provided name to the list of notifications # that are going to be processed in the current tick operation self._notified.append((event, data)) # in case this is considered to be the main thread there no need to # proceed with the task pool notification process (expensive) - if self.is_main(): return + if self.is_main(): + return # runs the wakeup operation making sure that as soon as possible the # main event loop gets unblocked for event processing self.wakeup() - def load(self, full = False): + def load(self, full=False): """ Starts the loading process for the current engine, this should be a singleton (run once) operation to be executed once per instance. @@ -1109,7 +1104,8 @@ def load(self, full = False): # in case the current structure is considered/marked as already loaded # there's no need to continue with the loading execution (returns immediately) - if self._loaded: return + if self._loaded: + return # calls the boot hook responsible for the initialization of the various # structures of the base system, note that is going to be called once @@ -1142,7 +1138,7 @@ def load(self, full = False): # will be done after this first call to the loading (no duplicates) self._loaded = True - def unload(self, full = True): + def unload(self, full=True): """ Unloads the structures associated with the current engine, so that the state of the current engine is reversed to the original one. @@ -1160,11 +1156,13 @@ def unload(self, full = True): # verifies if the current structure is considered/marked as already # "unloaded", if that's the case returns the control flow immediately # as there's nothing pending to be (undone) - if not self._loaded: return + if not self._loaded: + return # triggers the operation that will start the unloading process of the # logging infra-structure of the current system - if full: self.unload_logging() + if full: + self.unload_logging() # runs the unbind operation for the signals so that no side effects # occur while the unloading is going to take place @@ -1184,10 +1182,11 @@ def boot(self): def welcome(self): pass - def load_logging(self, level = logging.DEBUG, format = LOG_FORMAT, unique = False): + def load_logging(self, level=logging.DEBUG, format=LOG_FORMAT, unique=False): # verifies if there's a logger already set in the current service # if that's the case ignores the call no double reloading allowed - if self.logger: return + if self.logger: + return # normalizes the provided level value so that it represents # a proper and understandable value, then starts the formatter @@ -1195,7 +1194,7 @@ def load_logging(self, level = logging.DEBUG, format = LOG_FORMAT, unique = Fals # identifier to be used in the logger retrieval/identification level = self._level(level) formatter = logging.Formatter(format) - identifier = self.get_id(unique = unique) + identifier = self.get_id(unique=unique) # retrieves the logger that is going to be according to the # decided identifier and then verifies that the counter value @@ -1206,7 +1205,8 @@ def load_logging(self, level = logging.DEBUG, format = LOG_FORMAT, unique = Fals counter = self.logger._counter if hasattr(self.logger, "_counter") else 0 is_new = counter == 0 self.logger._counter = counter + 1 - if not is_new: return + if not is_new: + return # start the extra logging infrastructure (extra handlers) # and initializes the stream handlers with the proper level @@ -1221,14 +1221,16 @@ def load_logging(self, level = logging.DEBUG, format = LOG_FORMAT, unique = Fals self.logger.parent = None self.logger.setLevel(level) for handler in self.handlers: - if not handler: continue + if not handler: + continue self.logger.addHandler(handler) - def unload_logging(self, safe = True): + def unload_logging(self, safe=True): # verifies if there's a valid logger instance set in the # current service, in case there's not returns immediately # as there's nothing remaining to be done here - if not self.logger: return + if not self.logger: + return # updates the counter value for the logger and validates # that no more "clients" are using the logger so that it @@ -1236,12 +1238,14 @@ def unload_logging(self, safe = True): counter = self.logger._counter is_old = counter == 1 self.logger._counter = counter - 1 - if not is_old: return + if not is_old: + return # iterates over the complete set of handlers in the current # base element and removes them from the current logger for handler in self.handlers: - if not handler: continue + if not handler: + continue self.logger.removeHandler(handler) # in case the safe flag is set, iterates over the complete @@ -1249,7 +1253,8 @@ def unload_logging(self, safe = True): # from the current logger, this is required so that proper # handler unregistration is ensured even for complex scenarios for handler in self.logger.handlers if safe else (): - if not handler: continue + if not handler: + continue self.logger.removeHandler(handler) # closes the base stream handler as it's no longer going to @@ -1259,7 +1264,8 @@ def unload_logging(self, safe = True): # iterates over the complete set of (built) extra handlers # and runs the close operation for each of them, as they are # no longer considered required for logging purposes - for handler in self._extra_handlers: handler.close() + for handler in self._extra_handlers: + handler.close() # unset the logger reference in the current service so that # it's not possible to use it any longer @@ -1289,7 +1295,8 @@ def extra_logging(self, level, formatter): # defined and in case it's not returns immediately, otherwise # starts by converting the currently defined set of handlers into # a list so that it may be correctly manipulated (add handlers) - if not self.logging: return + if not self.logging: + return self.handlers = list(self.handlers) # iterates over the complete set of handler configuration in the @@ -1304,13 +1311,16 @@ def extra_logging(self, level, formatter): # "clones" the configuration dictionary and then removes the base # values so that they do not interfere with the building config = dict(config) - if "level" in config: del config["level"] - if "name" in config: del config["name"] + if "level" in config: + del config["level"] + if "name" in config: + del config["name"] # retrieves the proper building, skipping the current loop in case # it does not exits and then builds the new handler instance, setting # the proper level and formatter and then adding it to the set - if not hasattr(log, name + "_handler"): continue + if not hasattr(log, name + "_handler"): + continue builder = getattr(log, name + "_handler") handler = builder(**config) handler.setLevel(_level) @@ -1347,13 +1357,15 @@ def level_logging(self, level): # iterates over the complete set of attached handlers to # update their respective logging level - for handler in self.handlers: handler.setLevel(level) + for handler in self.handlers: + handler.setLevel(level) - def load_diag(self, env = True): + def load_diag(self, env=True): # verifies if the diagnostics "feature" has been requested # for the current infra-structure and if that's not the case # returns the control flow immediately to the caller - if not self.diag: return + if not self.diag: + return # runs the import operations for the diag module, note that # this must be performed locally no avoid any unwanted behavior @@ -1371,7 +1383,7 @@ def load_diag(self, env = True): # server, taking into account if the env flag is set server = self.get_env("DIAG_SERVER", "netius") if env else "netius" host = self.get_env("DIAG_HOST", "127.0.0.1") if env else "127.0.0.1" - port = self.get_env("DIAG_PORT", 5050, cast = int) if env else 5050 + port = self.get_env("DIAG_PORT", 5050, cast=int) if env else 5050 # creates the application object that is going to be # used for serving the diagnostics app @@ -1384,15 +1396,10 @@ def load_diag(self, env = True): # starts the "serving" procedure of it under a new thread # to avoid blocking the current context of execution self.diag_app.serve( - server = server, - host = host, - port = port, - diag = False, - threaded = True, - conf = False + server=server, host=host, port=port, diag=False, threaded=True, conf=False ) - def load_middleware(self, suffix = "Middleware"): + def load_middleware(self, suffix="Middleware"): # iterates over the complete set of string that define the middleware # that is going to be loaded and executes the loading process for name in self.middleware: @@ -1414,7 +1421,8 @@ def unload_middleware(self): # iterates over the complete set of middleware instance to stop # them (close internal structures) and then removes the middleware # list so that they don't get used any longer - for middleware_i in self.middleware_l: middleware_i.stop() + for middleware_i in self.middleware_l: + middleware_i.stop() del self.middleware_l[:] def register_middleware(self, middleware_c, *args, **kwargs): @@ -1441,33 +1449,42 @@ def call_middleware(self, name, *args, **kwargs): def bind_signals( self, - signals = ( + signals=( signal.SIGINT, signal.SIGTERM, - signal.SIGHUP if hasattr(signal, "SIGHUP") else None, #@UndefinedVariable - signal.SIGQUIT if hasattr(signal, "SIGQUIT") else None #@UndefinedVariable + signal.SIGHUP if hasattr(signal, "SIGHUP") else None, # @UndefinedVariable + ( + signal.SIGQUIT if hasattr(signal, "SIGQUIT") else None + ), # @UndefinedVariable ), - handler = None + handler=None, ): # creates the signal handler function that propagates the raising # of the system exit exception (proper logic is executed) and then # registers such handler for the (typical) sigterm signal - def base_handler(signum = None, frame = None): raise SystemExit() + def base_handler(signum=None, frame=None): + raise SystemExit() + for signum in signals: - if signum == None: continue - try: signal.signal(signum, handler or base_handler) - except Exception: self.debug("Failed to register %d handler" % signum) + if signum == None: + continue + try: + signal.signal(signum, handler or base_handler) + except Exception: + self.debug("Failed to register %d handler" % signum) def unbind_signals( self, - signals = ( + signals=( signal.SIGINT, signal.SIGTERM, - signal.SIGHUP if hasattr(signal, "SIGHUP") else None, #@UndefinedVariable - signal.SIGQUIT if hasattr(signal, "SIGQUIT") else None #@UndefinedVariable - ) + signal.SIGHUP if hasattr(signal, "SIGHUP") else None, # @UndefinedVariable + ( + signal.SIGQUIT if hasattr(signal, "SIGQUIT") else None + ), # @UndefinedVariable + ), ): - self.bind_signals(signals = signals, handler = signal.SIG_IGN) + self.bind_signals(signals=signals, handler=signal.SIG_IGN) def bind_env(self): """ @@ -1476,26 +1493,29 @@ def bind_env(self): """ self.level = self.get_env("LEVEL", self.level) - self.diag = self.get_env("DIAG", self.diag, cast = bool) - self.middleware = self.get_env("MIDDLEWARE", self.middleware, cast = list) - self.children = self.get_env("CHILD", self.children, cast = int) - self.children = self.get_env("CHILDREN", self.children, cast = int) + self.diag = self.get_env("DIAG", self.diag, cast=bool) + self.middleware = self.get_env("MIDDLEWARE", self.middleware, cast=list) + self.children = self.get_env("CHILD", self.children, cast=int) + self.children = self.get_env("CHILDREN", self.children, cast=int) self.logging = self.get_env("LOGGING", self.logging) self.poll_name = self.get_env("POLL", self.poll_name) - def forever(self, env = True): - if env: self.bind_env() + def forever(self, env=True): + if env: + self.bind_env() return self.start() def start(self): # in case the current instance is currently paused runs the # resume operation instead as that's the expected operation - if self.is_paused(): return self.resume() + if self.is_paused(): + return self.resume() # in case the event loop is already running then a new sub- # context based loop should be created in order to block the # current execution stack (as expected) - if self.is_running(): return self.block() + if self.is_running(): + return self.block() # re-builds the polling structure with the new name this # is required so that it's possible to change the polling @@ -1515,7 +1535,7 @@ def start(self): # opens the polling mechanism so that its internal structures # become ready for the polling cycle, the inverse operation # (close) should be performed as part of the cleanup - self.poll.open(timeout = self.poll_timeout) + self.poll.open(timeout=self.poll_timeout) # makes sure that the notify pool is created so that the event # notification (required for multi threaded environments) is created @@ -1539,7 +1559,10 @@ def start(self): # enters the main loop operation by printing a message # to the logger indicating this start, this stage # should block the thread until a stop call is made - self.debug("Starting '%s' service main loop (%.2fs) ..." % (self.name, self.poll_timeout)) + self.debug( + "Starting '%s' service main loop (%.2fs) ..." + % (self.name, self.poll_timeout) + ) self.debug("Using thread '%s' with TID '%d'" % (self.tname, self.tid)) self.debug("Using '%s' as polling mechanism" % poll_name) @@ -1550,26 +1573,33 @@ def start(self): def stop(self): # in case the current process is neither running nor # paused there's nothing pending to be done on stop - if not self.is_running() and not self.is_paused(): return + if not self.is_running() and not self.is_paused(): + return # in case the current loop is in pause state calls only # the finish operation otherwise sets the running flag # to false meaning that on the next event loop tick the # unloading process will be triggered - if self.is_paused(): self.finish() - else: self._running = False + if self.is_paused(): + self.finish() + else: + self._running = False # in case the current process is the parent in a pre-fork # environment raises the stop error to wakeup the process # from its current infinite loop for stop handling - if self.is_parent: raise errors.StopError() + if self.is_parent: + raise errors.StopError() def pause(self): self._running = False self._pausing = True def resume(self): - self.debug("Resuming '%s' service main loop (%.2fs) ..." % (self.name, self.poll_timeout)) + self.debug( + "Resuming '%s' service main loop (%.2fs) ..." + % (self.name, self.poll_timeout) + ) self.on_resume() self.main() @@ -1606,24 +1636,27 @@ def main(self): # external HTTP client) then this exception must be re-raised # to the upper layer (main event loop) so that it can be # properly handled to be able to exit the environment - if not self == Base.get_main(): raise + if not self == Base.get_main(): + raise except errors.PauseError: self.debug("Pausing '%s' service main loop" % self.name) self.set_state(STATE_PAUSE) self.on_pause() except BaseException as exception: self.error(exception) - self.log_stack(method = self.warning) + self.log_stack(method=self.warning) except: self.critical("Critical level loop exception raised") - self.log_stack(method = self.error) + self.log_stack(method=self.error) finally: - if self.is_paused(): return + if self.is_paused(): + return self.stop() self.finish() def is_main(self): - if not self.tid: return True + if not self.tid: + return True return threading.current_thread().ident == self.tid def is_running(self): @@ -1654,19 +1687,19 @@ def is_sub_error(self, socket): return self.poll.is_sub_error(socket) def sub_all(self, socket): - return self.poll.sub_all(socket, owner = self) + return self.poll.sub_all(socket, owner=self) def unsub_all(self, socket): return self.poll.unsub_all(socket) def sub_read(self, socket): - return self.poll.sub_read(socket, owner = self) + return self.poll.sub_read(socket, owner=self) def sub_write(self, socket): - return self.poll.sub_write(socket, owner = self) + return self.poll.sub_write(socket, owner=self) def sub_error(self, socket): - return self.poll.sub_error(socket, owner = self) + return self.poll.sub_error(socket, owner=self) def unsub_read(self, socket): return self.poll.unsub_read(socket) @@ -1677,7 +1710,7 @@ def unsub_write(self, socket): def unsub_error(self, socket): return self.poll.unsub_error(socket) - def cleanup(self, destroy = True): + def cleanup(self, destroy=True): # runs the unload operation for the current base container this should # unset/unload some of the components for this base infra-structure self.unload() @@ -1706,17 +1739,20 @@ def cleanup(self, destroy = True): # verifies if there's a valid (and open) notify pool, if that's # the case starts the stop process for it so that there's no # leaking of task descriptors and other structures - if self.npool: self.nstop() + if self.npool: + self.nstop() # verifies if there's a valid (and open) task pool, if that's # the case starts the stop process for it so that there's no # leaking of task descriptors and other structures - if self.tpool: self.tstop() + if self.tpool: + self.tstop() # verifies if there's a valid (and open) file pool, if that's # the case starts the stop process for it so that there's no # leaking of file descriptors and other structures - if self.fpool: self.fstop() + if self.fpool: + self.fstop() # creates a copy of the connections list because this structure # is going to be changed in the closing of the connection object @@ -1725,11 +1761,13 @@ def cleanup(self, destroy = True): # iterates over the complete set of connections currently # registered in the base structure and closes them so that # can no longer be used and are gracefully disconnected - for connection in connections: connection.close() + for connection in connections: + connection.close() # iterates over the complete set of sockets in the connections # map to properly close them (avoids any leak of resources) - for _socket in self.connections_m: _socket.close() + for _socket in self.connections_m: + _socket.close() # in case the current thread is the main one then in case the # instance set as global main is this one unsets the value @@ -1741,7 +1779,8 @@ def cleanup(self, destroy = True): # from an open poll system (memory leaks, etc.), note that this is # only performed in case the current base instance is the owner of # the poll that is going to be closed (works with containers) - if self.poll_owner: self.poll.close() + if self.poll_owner: + self.poll.close() # deletes some of the internal data structures created for the instance # and that are considered as they are considered to be no longer required @@ -1752,7 +1791,8 @@ def cleanup(self, destroy = True): # runs the destroy operation for the current instance, this should remove # the most obscure parts of the current instance - if destroy: self.destroy() + if destroy: + self.destroy() def loop(self): # iterates continuously while the running flag is set, once @@ -1767,7 +1807,8 @@ def loop(self): # in case running flag is disabled it's time to break the # cycle (just before the possible block) as it would imply # extra time before we could stop the event loop - if not self._running: break + if not self._running: + break # updates the current state to poll to indicate # that the base service is selecting the connections @@ -1804,10 +1845,12 @@ def block(self): # saves the current running state and then runs the loop again # restoring the same state at the end of the execution _running = self._running - try: self.loop() - finally: self._running = _running + try: + self.loop() + finally: + self._running = _running - def fork(self, timeout = 10): + def fork(self, timeout=10): # retrieves the reference to the parent class object # to be used for the class level operations cls = self.__class__ @@ -1818,17 +1861,21 @@ def fork(self, timeout = 10): # runs a series of validations to be able to verify # if the fork operation should really be performed - if not self.children: return True - if not self.children > 0: return True - if not hasattr(os, "fork"): return True - if self._forked: return True + if not self.children: + return True + if not self.children > 0: + return True + if not hasattr(os, "fork"): + return True + if self._forked: + return True # makes sure that no signal handlers exist for the parent # process, this is relevant to avoid immediate destruction # of the current process on premature signal self.unbind_signals() if hasattr(signal, "SIGUSR1"): - self.unbind_signals(signals = (signal.SIGUSR1,)) + self.unbind_signals(signals=(signal.SIGUSR1,)) # sets the initial PID value to the value of the current # master process as this is going to be used for child @@ -1846,10 +1893,11 @@ def fork(self, timeout = 10): # builds the inline function that takes the message to be # sent to the output pipe, normalizes it and sends it def pipe_send(message): - if not hasattr(signal, "SIGUSR1"): return + if not hasattr(signal, "SIGUSR1"): + return frame = legacy.bytes(message) + b"\n" os.write(pipeout, frame) - os.kill(ppid, signal.SIGUSR1) #@UndefinedVariable + os.kill(ppid, signal.SIGUSR1) # @UndefinedVariable # prints a debug operation about the operation that is # going to be performed for the forking @@ -1862,10 +1910,12 @@ def pipe_send(message): # iterates of the requested (number of children) to run # the concrete fork operation and fork the logic for _index in range(self.children): - pid = os.fork() #@UndefinedVariable + pid = os.fork() # @UndefinedVariable self._child = pid == 0 - if self._child: self.on_child(pipe = pipe_send) - if self._child: break + if self._child: + self.on_child(pipe=pipe_send) + if self._child: + break self._childs.append(pid) # sets the forked flag, meaning that the current process @@ -1874,7 +1924,8 @@ def pipe_send(message): # in case the current process is a child one an immediate # valid value should be returned (force logic continuation) - if self._child: return True + if self._child: + return True # prints a debug operation the finished forking operation self.debug("Finished forking children") @@ -1886,8 +1937,10 @@ def pipe_send(message): # registers for some of the common signals to be able to start # the process of stopping and joining with the child processes # in case there's a request to do so - def handler(signum = None, frame = None): self.stop() - self.bind_signals(handler = handler) + def handler(signum=None, frame=None): + self.stop() + + self.bind_signals(handler=handler) def callback(): # reads a line from the input pipe considering it to be a @@ -1898,17 +1951,18 @@ def callback(): # creates the pipe signal handler that is responsible for the # reading of the pipe information from the child process to # the parent process (as expected) - def pipe_handler(signum = None, frame = None): + def pipe_handler(signum=None, frame=None): # in case the current process is considered to be not # running then returns the control flow immediately # not possible to handle any command - if not self._running: return + if not self._running: + return # schedules the current clojure to be executed as soon as # possible and then forces the wakeup, because although we're # running on the main thread we're possible under a blocking # statement and so we need to wakeup the parent loop - self.delay_s(callback, immediately = True) + self.delay_s(callback, immediately=True) if hasattr(self, "_awaken") and not self._awaken: self._awaken = True raise errors.WakeupError() @@ -1916,7 +1970,7 @@ def pipe_handler(signum = None, frame = None): # in case the user signal is defined registers for it so that it's # possible to establish a communication between child and parent if hasattr(signal, "SIGUSR1"): - signal.signal(signal.SIGUSR1, pipe_handler) #@UndefinedVariable + signal.signal(signal.SIGUSR1, pipe_handler) # @UndefinedVariable # prints a debug operation the finished forking operation self.debug("Entering wait forever loop") @@ -1953,8 +2007,10 @@ def pipe_handler(signum = None, frame = None): # creates the catcher for the alarm signal so that a wakeup # can happen that kills the (possibly) stuck children - def catcher(signal, frame): raise errors.WakeupError() - signal.signal(signal.SIGALRM, catcher) #@UndefinedVariable + def catcher(signal, frame): + raise errors.WakeupError() + + signal.signal(signal.SIGALRM, catcher) # @UndefinedVariable # iterates over the complete set of child processes to join # them (master process responsibility) @@ -1965,7 +2021,9 @@ def catcher(signal, frame): raise errors.WakeupError() # registers the alarm for the remaining time until # the child process should be forcibly killed - signal.setitimer(signal.ITIMER_REAL, max(timeout, 0.15)) #@UndefinedVariable + signal.setitimer( + signal.ITIMER_REAL, max(timeout, 0.15) + ) # @UndefinedVariable try: # runs the waiting for the children PID (process to finish) @@ -1974,7 +2032,7 @@ def catcher(signal, frame): raise errors.WakeupError() cls.waitpid(pid) except errors.WakeupError: self.warning("Timeout reached killing PID '%d' with SIGKILL ..." % pid) - os.kill(pid, signal.SIGKILL) #@UndefinedVariable + os.kill(pid, signal.SIGKILL) # @UndefinedVariable cls.waitpid(pid) # decrements the timeout value by the time that was @@ -1983,7 +2041,7 @@ def catcher(signal, frame): raise errors.WakeupError() # resets the alarm as we've finished waiting for all of the # children processes, some may have been killed forcibly - signal.setitimer(signal.ITIMER_REAL, 0) #@UndefinedVariable + signal.setitimer(signal.ITIMER_REAL, 0) # @UndefinedVariable # calls the final (on) join method indicating that the complete # set of child processes have been join and that now only the @@ -2005,8 +2063,10 @@ def catcher(signal, frame): raise errors.WakeupError() def finalize(self): # verifies a series of conditions and raises a proper error in case # any of them is verified under the current state - if self._pausing: raise errors.PauseError("Pause state expected") - if self._running: raise errors.AssertionError("Not expected running") + if self._pausing: + raise errors.PauseError("Pause state expected") + if self._running: + raise errors.AssertionError("Not expected running") def ticks(self): # updates the current state value to the tick state indicating @@ -2027,57 +2087,66 @@ def ticks(self): # calls are called if the correct time has been reached self._delays() - def reads(self, reads, state = True): + def reads(self, reads, state=True): # in case the update state is requested updates the current loop # instance into the read state (debugging purposes) - if state: self.set_state(STATE_READ) + if state: + self.set_state(STATE_READ) # in case the concrete flag is set return immediately as the # concrete instance (eg: client, server) should implement the # concrete handling specifics for this event - if self._concrete: return + if self._concrete: + return # iterates over all of the read events and calls the proper on # read method handler to properly handle each event - for read in reads: self.on_read(read) + for read in reads: + self.on_read(read) - def writes(self, writes, state = True): + def writes(self, writes, state=True): # in case the update state is requested updates the current loop # instance into the write state (debugging purposes) - if state: self.set_state(STATE_WRITE) + if state: + self.set_state(STATE_WRITE) # in case the concrete flag is set return immediately as the # concrete instance (eg: client, server) should implement the # concrete handling specifics for this event - if self._concrete: return + if self._concrete: + return # iterates over all of the write events and calls the proper on # write method handler to properly handle each event - for write in writes: self.on_write(write) + for write in writes: + self.on_write(write) - def errors(self, errors, state = True): + def errors(self, errors, state=True): # in case the update state is requested updates the current loop # instance into the error state (debugging purposes) - if state: self.set_state(STATE_ERRROR) + if state: + self.set_state(STATE_ERRROR) # in case the concrete flag is set return immediately as the # concrete instance (eg: client, server) should implement the # concrete handling specifics for this event - if self._concrete: return + if self._concrete: + return # iterates over all of the error events and calls the proper on # error method handler to properly handle each event - for error in errors: self.on_error(error) + for error in errors: + self.on_error(error) def datagram( self, - family = socket.AF_INET, - type = socket.SOCK_DGRAM, - local_host = None, - local_port = None, - remote_host = None, - remote_port = None, - callback = None + family=socket.AF_INET, + type=socket.SOCK_DGRAM, + local_host=None, + local_port=None, + remote_host=None, + remote_port=None, + callback=None, ): """ Builds a datagram based connection for the provided family and @@ -2128,23 +2197,26 @@ def datagram( # in case both the local host and port are defined runs the bind # operation so that the current socket is st to listen for new # datagrams on the associated host and port - if local_host and local_port: _socket.bind((local_host, local_port)) + if local_host and local_port: + _socket.bind((local_host, local_port)) # verifies if both the host and the port are set and if that's the # case runs the connect (send bind) operation in the datagram socket # notice that this is not a "real" remote connection - if remote_host and remote_port: _socket.connect((remote_host, remote_port)) + if remote_host and remote_port: + _socket.connect((remote_host, remote_port)) # creates a new connection object representing the datagram socket # that has just been created to be used for upper level operations # and then immediately sets it as connected - connection = self.base_connection(_socket, datagram = True) + connection = self.base_connection(_socket, datagram=True) connection.open() connection.set_connected() # in case a callback is defined schedules its execution for the next # tick to avoid possible issues with same tick registration - if callback: self.delay(lambda: callback(connection, True), immediately = True) + if callback: + self.delay(lambda: callback(connection, True), immediately=True) # returns the connection to the caller method so that it may be used # for operation from now on (latter usage) @@ -2154,23 +2226,25 @@ def connect( self, host, port, - receive_buffer = None, - send_buffer = None, - ssl = False, - key_file = None, - cer_file = None, - ca_file = None, - ca_root = True, - ssl_verify = False, - family = socket.AF_INET, - type = socket.SOCK_STREAM, - callback = None, - env = True + receive_buffer=None, + send_buffer=None, + ssl=False, + key_file=None, + cer_file=None, + ca_file=None, + ca_root=True, + ssl_verify=False, + family=socket.AF_INET, + type=socket.SOCK_STREAM, + callback=None, + env=True, ): # runs a series of pre-validations on the provided parameters, raising # exceptions in case they do not comply with expected values - if not host: raise errors.NetiusError("Invalid host for connect operation") - if not port: raise errors.NetiusError("Invalid port for connect operation") + if not host: + raise errors.NetiusError("Invalid host for connect operation") + if not port: + raise errors.NetiusError("Invalid port for connect operation") # tries to retrieve some of the environment variable related values # so that some of these values are accessible via an external environment @@ -2178,11 +2252,13 @@ def connect( key_file = self.get_env("KEY_FILE", key_file) if env else key_file cer_file = self.get_env("CER_FILE", cer_file) if env else cer_file ca_file = self.get_env("CA_FILE", ca_file) if env else ca_file - ca_root = self.get_env("CA_ROOT", ca_root, cast = bool) if env else ca_root - ssl_verify = self.get_env("SSL_VERIFY", ssl_verify, cast = bool) if env else ssl_verify - key_file = self.get_env("KEY_DATA", key_file, expand = True) if env else key_file - cer_file = self.get_env("CER_DATA", cer_file, expand = True) if env else cer_file - ca_file = self.get_env("CA_DATA", ca_file, expand = True) if env else ca_file + ca_root = self.get_env("CA_ROOT", ca_root, cast=bool) if env else ca_root + ssl_verify = ( + self.get_env("SSL_VERIFY", ssl_verify, cast=bool) if env else ssl_verify + ) + key_file = self.get_env("KEY_DATA", key_file, expand=True) if env else key_file + cer_file = self.get_env("CER_DATA", cer_file, expand=True) if env else cer_file + ca_file = self.get_env("CA_DATA", ca_file, expand=True) if env else ca_file # ensures that the proper socket family is defined in case the # requested host value is unix socket oriented, this step greatly @@ -2213,36 +2289,28 @@ def connect( # in case the SSL option is enabled the socket should be wrapped into # a proper SSL socket interface so that it may be operated accordingly - if ssl: _socket = self._ssl_wrap( - _socket, - key_file = key_file, - cer_file = cer_file, - ca_file = ca_file, - ca_root = ca_root, - server = False, - ssl_verify = ssl_verify, - server_hostname = host - ) + if ssl: + _socket = self._ssl_wrap( + _socket, + key_file=key_file, + cer_file=cer_file, + ca_file=ca_file, + ca_root=ca_root, + server=False, + ssl_verify=ssl_verify, + server_hostname=host, + ) # sets a series of options in the socket to ensure that it's # prepared for the client operations to be performed _socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) _socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) - if is_inet: _socket.setsockopt( - socket.IPPROTO_TCP, - socket.TCP_NODELAY, - 1 - ) - if receive_buffer: _socket.setsockopt( - socket.SOL_SOCKET, - socket.SO_RCVBUF, - receive_buffer - ) - if send_buffer: _socket.setsockopt( - socket.SOL_SOCKET, - socket.SO_SNDBUF, - send_buffer - ) + if is_inet: + _socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + if receive_buffer: + _socket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, receive_buffer) + if send_buffer: + _socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, send_buffer) self._socket_keepalive(_socket) # constructs the address tuple taking into account if the @@ -2253,16 +2321,14 @@ def connect( # creates the connection object using the typical constructor # and then sets the SSL host (for verification) if the verify # SSL option is defined (secured and verified connection) - connection = self.base_connection(_socket, address, ssl = ssl) - if ssl_verify: connection.ssl_host = host + connection = self.base_connection(_socket, address, ssl=ssl) + if ssl_verify: + connection.ssl_host = host # schedules the underlying non blocking connect operation to # be executed as soon as possible to start the process of # connecting for the current connection - self.delay( - lambda: self._connect(connection), - immediately = True - ) + self.delay(lambda: self._connect(connection), immediately=True) def on_close(conection): callback and callback(connection, False) @@ -2273,8 +2339,10 @@ def on_connect(conection): # in case there's a callback defined for the connection establishment # then registers such callback for the connect event in the connection - if callback: connection.bind("connect", on_connect, oneshot = True) - if callback: connection.bind("close", on_close, oneshot = True) + if callback: + connection.bind("connect", on_connect, oneshot=True) + if callback: + connection.bind("close", on_close, oneshot=True) # returns the "final" connection, that is now scheduled for connect # to the caller method, it may now be used for operations @@ -2293,9 +2361,12 @@ def pregister(self, pool): # object that is notified for each operation associated with # the pool, (primary communication mechanism) eventfd = pool.eventfd() - if not eventfd: self.warning("Starting pool without eventfd") - if not eventfd: return - if not self.poll: return + if not eventfd: + self.warning("Starting pool without eventfd") + if not eventfd: + return + if not self.poll: + return self.sub_read(eventfd) # creates the callback clojure around the current context @@ -2326,14 +2397,18 @@ def punregister(self, pool): # the pool an in case it exists unsubscribes # from it under the current polling system eventfd = pool.eventfd() - if not eventfd: self.warning("Stopping pool without eventfd") - if not eventfd: return - if not self.poll: return + if not eventfd: + self.warning("Stopping pool without eventfd") + if not eventfd: + return + if not self.poll: + return self.unsub_read(eventfd) # verifies if the callback operation in the event fd is defined # for the pool and if that's not the case returns immediately - if not hasattr(pool, "_callback"): return + if not hasattr(pool, "_callback"): + return # unregisters from a callback operation in the event fd so that # no more events are handled by the notifier, this is expected @@ -2351,58 +2426,63 @@ def punregister(self, pool): def pcallback(self, event, socket, pool): # runs a series of pre-validations on the callback so that # no operations is performed for such conditions - if not pool: return - if not event == "read": return + if not pool: + return + if not event == "read": + return # runs the de-notify operation clearing the pool from any # possible extra notification (avoid extra counter) pool.denotify() def nensure(self): - if self.npool: return + if self.npool: + return self.nstart() def nstart(self): - if self.npool: return + if self.npool: + return self.npool = netius.pool.NotifyPool() self.npool.start() self.pregister(self.npool) def nstop(self): - if not self.npool: return + if not self.npool: + return self.punregister(self.npool) self.npool.stop() def tensure(self): - if self.tpool: return + if self.tpool: + return self.tstart() def tstart(self): - if self.tpool: return + if self.tpool: + return self.tpool = netius.pool.TaskPool() self.tpool.start() self.pregister(self.tpool) def tstop(self): - if not self.tpool: return + if not self.tpool: + return self.punregister(self.tpool) self.tpool.stop() - def texecute(self, callable, args = [], kwargs = {}, callback = None): + def texecute(self, callable, args=[], kwargs={}, callback=None): self.tensure() - self.tpool.execute( - callable, - args = args, - kwargs = kwargs, - callback = callback - ) + self.tpool.execute(callable, args=args, kwargs=kwargs, callback=callback) def files(self): - if not self.fpool: return + if not self.fpool: + return events = self.fpool.pop_all() for event in events: callback = event[-1] - if not callback: continue + if not callback: + continue callback(*event[1:-1]) def fopen(self, *args, **kwargs): @@ -2422,14 +2502,16 @@ def fwrite(self, *args, **kwargs): return self.fpool.write(*args, **kwargs) def fensure(self): - if self.fpool: return + if self.fpool: + return self.fstart() def fstart(self): # verifies if there's an already open file pool for # the current system and if that's not the case creates # a new one and starts it's thread cycle - if self.fpool: return + if self.fpool: + return self.fpool = netius.pool.FilePool() self.fpool.start() self.pregister(self.fpool) @@ -2439,7 +2521,8 @@ def fstop(self): # if that's the case initializes the stopping of # such system, note that this is blocking call as # all of the thread will be joined under it - if not self.fpool: return + if not self.fpool: + return self.punregister(self.fpool) self.fpool.stop() @@ -2447,12 +2530,12 @@ def on_connection_c(self, connection): # prints some debug information about the connection that has # just been created (for possible debugging purposes) self.debug( - "Connection '%s' %s from '%s' created" % - (connection.id, connection.address, connection.owner.name) + "Connection '%s' %s from '%s' created" + % (connection.id, connection.address, connection.owner.name) ) self.debug( - "There are %d connections for '%s'" % - (len(connection.owner.connections), connection.owner.name) + "There are %d connections for '%s'" + % (len(connection.owner.connections), connection.owner.name) ) # triggers the event notifying any listener about the new connection @@ -2463,12 +2546,12 @@ def on_connection_d(self, connection): # prints some debug information about the connection # that has just been scheduled for destruction self.debug( - "Connection '%s' %s from '%s' deleted" % - (connection.id, connection.address, connection.owner.name) + "Connection '%s' %s from '%s' deleted" + % (connection.id, connection.address, connection.owner.name) ) self.debug( - "There are %d connections for '%s'" % - (len(connection.owner.connections), connection.owner.name) + "There are %d connections for '%s'" + % (len(connection.owner.connections), connection.owner.name) ) # triggers the event notifying any listener about the @@ -2483,8 +2566,7 @@ def on_stream_c(self, stream): # prints some debug information on the stream that has just been # created (may be used for debugging purposes) self.debug( - "Stream '%s' from '%s' created" % - (stream.identifier, connection.owner.name) + "Stream '%s' from '%s' created" % (stream.identifier, connection.owner.name) ) # notifies any listener of the stream created event about the @@ -2499,8 +2581,7 @@ def on_stream_d(self, stream): # prints some debug information on the stream that has just been # deleted (may be used for debugging purposes) self.debug( - "Stream '%s' from '%s' deleted" % - (stream.identifier, connection.owner.name) + "Stream '%s' from '%s' deleted" % (stream.identifier, connection.owner.name) ) # notifies any listener of the stream deleted event about the @@ -2513,10 +2594,10 @@ def on_fork(self): def on_join(self): self.trigger("join", self) - def on_child(self, pipe = None): + def on_child(self, pipe=None): # triggers the child event indicating that a new child has been # created and than any callback operation may now be performed - self.trigger("child", self, pipe = pipe) + self.trigger("child", self, pipe=pipe) # creates a new seed value from a pseudo random value and # then adds this new value as the base for randomness in the @@ -2528,8 +2609,8 @@ def on_child(self, pipe = None): # ignores the complete set of signals (avoids signal duplication) # and registers for the exit on the term signal that should be # sent from the parent process (proper exit/termination) - self.bind_signals(handler = signal.SIG_IGN) - self.bind_signals(signals = (signal.SIGTERM,)) + self.bind_signals(handler=signal.SIG_IGN) + self.bind_signals(signals=(signal.SIGTERM,)) def on_command(self, command): self.trigger("command", self, command) @@ -2555,47 +2636,58 @@ def on_read(self, _socket): # to the execution of the read operation in the socket callbacks = self.callbacks_m.get(_socket, None) if callbacks: - for callback in callbacks: callback("read", _socket) + for callback in callbacks: + callback("read", _socket) # retrieves the connection object associated with the # current socket that is going to be read in case there's # no connection available or the status is not open # must return the control flow immediately to the caller connection = self.connections_m.get(_socket, None) - if not connection: return - if not connection.status == OPEN: return - if not connection.renable == True: return + if not connection: + return + if not connection.status == OPEN: + return + if not connection.renable == True: + return try: # in case the connection is under the connecting state # the socket must be verified for errors and in case # there's none the connection must proceed, for example # the SSL connection handshake must be performed/retried - if connection.connecting: self._connectf(connection) + if connection.connecting: + self._connectf(connection) # verifies if there's any pending operations in the # connection (eg: SSL handshaking) and performs it trying # to finish them, if they are still pending at the current # state returns immediately (waits for next loop) - if self._pending(connection): return + if self._pending(connection): + return # iterates continuously trying to read as much data as possible # when there's a failure to read more data it should raise an # exception that should be handled properly while True: data = connection.recv(CHUNK_SIZE) - if data: self.on_data_base(connection, data) - else: connection.close(); break - if not connection.status == OPEN: break - if not connection.renable == True: break - if not connection.socket == _socket: break + if data: + self.on_data_base(connection, data) + else: + connection.close() + break + if not connection.status == OPEN: + break + if not connection.renable == True: + break + if not connection.socket == _socket: + break except ssl.SSLError as error: error_v = error.args[0] if error.args else None error_m = error.reason if hasattr(error, "reason") else None if error_v in SSL_SILENT_ERRORS: self.on_expected(error, connection) - elif not error_v in SSL_VALID_ERRORS and\ - not error_m in SSL_VALID_REASONS: + elif not error_v in SSL_VALID_ERRORS and not error_m in SSL_VALID_REASONS: self.on_exception(error, connection) except socket.error as error: error_v = error.args[0] if error.args else None @@ -2614,20 +2706,24 @@ def on_write(self, _socket): # to the execution of the read operation in the socket callbacks = self.callbacks_m.get(_socket, None) if callbacks: - for callback in callbacks: callback("write", _socket) + for callback in callbacks: + callback("write", _socket) # retrieves the connection associated with the socket that # is ready for the write operation and verifies that it # exists and the current status of it is open (required) connection = self.connections_m.get(_socket, None) - if not connection: return - if not connection.status == OPEN: return + if not connection: + return + if not connection.status == OPEN: + return # in case the connection is under the connecting state # the socket must be verified for errors and in case # there's none the connection must proceed, for example # the SSL connection handshake must be performed/retried - if connection.connecting: self._connectf(connection) + if connection.connecting: + self._connectf(connection) try: connection._send() @@ -2636,8 +2732,7 @@ def on_write(self, _socket): error_m = error.reason if hasattr(error, "reason") else None if error_v in SSL_SILENT_ERRORS: self.on_expected(error, connection) - elif not error_v in SSL_VALID_ERRORS and\ - not error_m in SSL_VALID_REASONS: + elif not error_v in SSL_VALID_ERRORS and not error_m in SSL_VALID_REASONS: self.on_exception(error, connection) except socket.error as error: error_v = error.args[0] if error.args else None @@ -2653,11 +2748,14 @@ def on_write(self, _socket): def on_error(self, _socket): callbacks = self.callbacks_m.get(_socket, None) if callbacks: - for callback in callbacks: callback("error", _socket) + for callback in callbacks: + callback("error", _socket) connection = self.connections_m.get(_socket, None) - if not connection: return - if not connection.status == OPEN: return + if not connection: + return + if not connection.status == OPEN: + return connection.close() @@ -2672,7 +2770,8 @@ def on_expected(self, exception, connection): def on_connect(self, connection): connection.set_connected() - if not hasattr(connection, "tuple"): return + if not hasattr(connection, "tuple"): + return self.on_acquire_base(connection) def on_upgrade(self, connection): @@ -2693,8 +2792,10 @@ def on_client_ssl(self, connection): # and calls the proper event handler for each event, this is # required because the connection workflow is probably dependent # on the calling of these event handlers to proceed - if connection.connecting: self.on_connect(connection) - elif connection.upgrading: self.on_upgrade(connection) + if connection.connecting: + self.on_connect(connection) + elif connection.upgrading: + self.on_upgrade(connection) def on_acquire(self, connection): pass @@ -2714,54 +2815,46 @@ def on_data(self, connection, data): def on_data_base(self, connection, data): connection.set_data(data) - def info_dict(self, full = False): + def info_dict(self, full=False): info = dict( - loaded = self._loaded, - connections = len(self.connections), - state = self.get_state_s(), - poll = self.get_poll_name() - ) - if full: info.update( - name = self.name, - _lid = self._lid + loaded=self._loaded, + connections=len(self.connections), + state=self.get_state_s(), + poll=self.get_poll_name(), ) + if full: + info.update(name=self.name, _lid=self._lid) return info - def info_string(self, full = False, safe = True): - try: info = self.info_dict(full = full) - except Exception: info = dict() + def info_string(self, full=False, safe=True): + try: + info = self.info_dict(full=full) + except Exception: + info = dict() info_s = json.dumps( - info, - ensure_ascii = False, - indent = 4, - separators = (",", " : "), - sort_keys = True + info, ensure_ascii=False, indent=4, separators=(",", " : "), sort_keys=True ) return info_s - def connections_dict(self, full = False): + def connections_dict(self, full=False): connections = [] for connection in self.connections: - info = connection.info_dict(full = full) + info = connection.info_dict(full=full) connections.append(info) return connections - def connection_dict(self, id, full = False): + def connection_dict(self, id, full=False): connection = None for _connection in self.connections: - if not _connection.id == id: continue + if not _connection.id == id: + continue connection = _connection break - if not connection: return None - return connection.info_dict(full = full) + if not connection: + return None + return connection.info_dict(full=full) - def build_connection( - self, - socket, - address = None, - datagram = False, - ssl = False - ): + def build_connection(self, socket, address=None, datagram=False, ssl=False): """ Creates a new connection for the provided socket object and string based address, the returned @@ -2785,11 +2878,7 @@ def build_connection( """ return Connection( - owner = self, - socket = socket, - address = address, - datagram = datagram, - ssl = ssl + owner=self, socket=socket, address=address, datagram=datagram, ssl=ssl ) def base_connection(self, *args, **kwargs): @@ -2809,31 +2898,37 @@ def del_connection(self, connection): def add_callback(self, socket, callback): callbacks = self.callbacks_m.get(socket, []) - if callback in callbacks: return + if callback in callbacks: + return callbacks.append(callback) self.callbacks_m[socket] = callbacks def remove_callback(self, socket, callback): callbacks = self.callbacks_m.get(socket, []) - if not callback in callbacks: return + if not callback in callbacks: + return callbacks.remove(callback) - if callbacks: return + if callbacks: + return del self.callbacks_m[socket] - def load_config(self, path = "config.json", **kwargs): + def load_config(self, path="config.json", **kwargs): kwargs = self.apply_config(path, kwargs) for key, value in legacy.iteritems(kwargs): setattr(self, key, value) def apply_config(self, path, kwargs): - if not os.path.exists(path): return kwargs + if not os.path.exists(path): + return kwargs self.info("Applying configuration file '%s' ..." % path) kwargs = copy.copy(kwargs) file = open(path, "rb") - try: contents = json.load(file) - finally: file.close() + try: + contents = json.load(file) + finally: + file.close() for key, value in legacy.iteritems(contents): kwargs[key] = value @@ -2848,8 +2943,7 @@ def exec_safe(self, connection, callable, *args, **kwargs): error_m = error.reason if hasattr(error, "reason") else None if error_v in SSL_SILENT_ERRORS: self.on_expected(error, connection) - elif not error_v in SSL_VALID_ERRORS and\ - not error_m in SSL_VALID_REASONS: + elif not error_v in SSL_VALID_ERRORS and not error_m in SSL_VALID_REASONS: self.on_exception(error, connection) except socket.error as error: error_v = error.args[0] if error.args else None @@ -2884,72 +2978,95 @@ def is_devel(self): return self.is_debug() def is_debug(self): - if not self.logger: return False + if not self.logger: + return False return self.logger.isEnabledFor(logging.DEBUG) def is_info(self): - if not self.logger: return False + if not self.logger: + return False return self.logger.isEnabledFor(logging.INFO) def is_warning(self): - if not self.logger: return False + if not self.logger: + return False return self.logger.isEnabledFor(logging.WARNING) def is_error(self): - if not self.logger: return False + if not self.logger: + return False return self.logger.isEnabledFor(logging.ERROR) def is_critical(self): - if not self.logger: return False + if not self.logger: + return False return self.logger.isEnabledFor(logging.CRITICAL) def debug(self, object): - if not logging: return - self.log(object, level = logging.DEBUG) + if not logging: + return + self.log(object, level=logging.DEBUG) def info(self, object): - if not logging: return - self.log(object, level = logging.INFO) + if not logging: + return + self.log(object, level=logging.INFO) def warning(self, object): - if not logging: return - self.log(object, level = logging.WARNING) + if not logging: + return + self.log(object, level=logging.WARNING) def error(self, object): - if not logging: return - self.log(object, level = logging.ERROR) + if not logging: + return + self.log(object, level=logging.ERROR) def critical(self, object): - if not logging: return - self.log(object, level = logging.CRITICAL) + if not logging: + return + self.log(object, level=logging.CRITICAL) - def log_stack(self, method = None, info = True): - if not method: method = self.info + def log_stack(self, method=None, info=True): + if not method: + method = self.info lines = traceback.format_exc().splitlines() - for line in lines: method(line) - if info: self.log_info(method = method) - - def log_info(self, method = None): - if not method: method = self.info - info_string = self.info_string(full = True) - for line in info_string.split("\n"): method(line) + for line in lines: + method(line) + if info: + self.log_info(method=method) + + def log_info(self, method=None): + if not method: + method = self.info + info_string = self.info_string(full=True) + for line in info_string.split("\n"): + method(line) def log(self, *args, **kwargs): - if legacy.PYTHON_3: return self.log_python_3(*args, **kwargs) - else: return self.log_python_2(*args, **kwargs) + if legacy.PYTHON_3: + return self.log_python_3(*args, **kwargs) + else: + return self.log_python_2(*args, **kwargs) - def log_python_3(self, object, level = logging.INFO): + def log_python_3(self, object, level=logging.INFO): is_str = isinstance(object, legacy.STRINGS) - try: message = str(object) if not is_str else object - except Exception: message = str(object) - if not self.logger: return + try: + message = str(object) if not is_str else object + except Exception: + message = str(object) + if not self.logger: + return self.logger.log(level, message) - def log_python_2(self, object, level = logging.INFO): + def log_python_2(self, object, level=logging.INFO): is_str = isinstance(object, legacy.STRINGS) - try: message = unicode(object) if not is_str else object #@UndefinedVariable - except Exception: message = str(object).decode("utf-8", "ignore") - if not self.logger: return + try: + message = unicode(object) if not is_str else object # @UndefinedVariable + except Exception: + message = str(object).decode("utf-8", "ignore") + if not self.logger: + return self.logger.log(level, message) def build_poll(self): @@ -2961,13 +3078,15 @@ def build_poll(self): # case it's ther's no need to re-build the polling mechanism # otherwise rebuilds the polling mechanism with the current # name and returns the new poll object to the caller method - if self.poll and self.poll.is_open(): return self.poll + if self.poll and self.poll.is_open(): + return self.poll # runs the testing of the poll again and verifies if the polling # class has changed in case it did not returns the current poll # instance as expected by the current infra-structure - poll_c = cls.test_poll(preferred = self.poll_name) - if poll_c == self.poll_c: return self.poll + poll_c = cls.test_poll(preferred=self.poll_name) + if poll_c == self.poll_c: + return self.poll # updates the polling class with the new value and re-creates # the polling instance with the new polling class returning this @@ -2976,7 +3095,7 @@ def build_poll(self): self.poll = self.poll_c() return self.poll - def build_future(self, compat = True, asyncio = True): + def build_future(self, compat=True, asyncio=True): """ Creates a future object that is bound to the current event loop context, this allows for latter access to the owning loop. @@ -2994,13 +3113,14 @@ def build_future(self, compat = True, asyncio = True): # creates a normal future object, setting the current loop (global) as # the loop, then returns the future to the caller method - loop = self.get_loop(compat = compat, asyncio = asyncio) - future = asynchronous.Future(loop = loop) + loop = self.get_loop(compat=compat, asyncio=asyncio) + future = asynchronous.Future(loop=loop) return future - def get_id(self, unique = True): + def get_id(self, unique=True): base = NAME + "-" + util.camel_to_underscore(self.name) - if not unique: return base + if not unique: + return base return base + "-" + str(self._uuid) def get_poll(self): @@ -3017,7 +3137,7 @@ def get_state(self): def set_state(self, state): self._state = state - def get_state_s(self, lower = True): + def get_state_s(self, lower=True): """ Retrieves a string describing the current state of the system, this string should be as descriptive @@ -3038,7 +3158,7 @@ def get_state_s(self, lower = True): state_s = state_s.lower() if lower else state_s return state_s - def get_env(self, name, default = None, cast = None, expand = False): + def get_env(self, name, default=None, cast=None, expand=False): """ Retrieves the value of the environment variable with the requested name, defaulting to the provided value in case @@ -3073,14 +3193,17 @@ def get_env(self, name, default = None, cast = None, expand = False): properly casted into the target value. """ - if not name in config.CONFIGS: return default + if not name in config.CONFIGS: + return default value = config.CONFIGS.get(name, default) - if expand: value = self.expand(value) + if expand: + value = self.expand(value) cast = config.CASTS.get(cast, cast) - if cast and not value == None: value = cast(value) + if cast and not value == None: + value = cast(value) return value - def expand(self, value, encoding = "utf-8", force = False): + def expand(self, value, encoding="utf-8", force=False): """ Expands the provided string/bytes value into a file in the current file system so that it may be correctly used by interfaces @@ -3107,15 +3230,19 @@ def expand(self, value, encoding = "utf-8", force = False): for the expansion of the provided value. """ - if not value and not force: return value + if not value and not force: + return value is_bytes = legacy.is_bytes(value) - if not is_bytes: value = value.encode(encoding) + if not is_bytes: + value = value.encode(encoding) value = value.replace(b"\\n", b"\n") fd, file_path = tempfile.mkstemp() os.close(fd) file = open(file_path, "wb") - try: file.write(value) - finally: file.close() + try: + file.write(value) + finally: + file.close() self._expanded.append(file_path) return file_path @@ -3132,7 +3259,7 @@ def get_protocols(self): return None - def get_adapter(self, name = "memory", *args, **kwargs): + def get_adapter(self, name="memory", *args, **kwargs): """ Retrieves an instance of a storage adapter described by the provided name, note that the dynamic (extra) @@ -3152,7 +3279,7 @@ def get_adapter(self, name = "memory", *args, **kwargs): adapter = adapter_c(*args, **kwargs) return adapter - def get_auth(self, name = "memory", *args, **kwargs): + def get_auth(self, name="memory", *args, **kwargs): """ Gathers the proper authentication handler that is being requested with the provided name. The retrieved auth @@ -3252,7 +3379,8 @@ def _notifies(self): while self._notified: event, data = self._notified.pop(0) binds = self._events.pop(event, []) - for callable in binds: callable(data) + for callable in binds: + callable(data) count += 1 # returns the number of processed notifications to the @@ -3284,7 +3412,8 @@ def _delays(self): # in case there's no delayed items to be called returns the control # flow immediately, note that the notified elements (pending process) # are also going to be verified for presence - if not self._delayed and not self._notified: return + if not self._delayed and not self._notified: + return # retrieves the value for the current timestamp, to be used in # comparisons against the target timestamps of the callables @@ -3304,7 +3433,8 @@ def _delays(self): # runs the notifies verification cycle and if there's at # least one processed event continues the loop meaning that # the if test evaluations must be re-processed - if self._notifies(): continue + if self._notifies(): + continue # "pops" the current item from the delayed list to be used # in the execution of the current iteration cycle @@ -3320,7 +3450,8 @@ def _delays(self): # for the comparison against the current time reference # this is performed by defaulting the value against negative # ensuring immediate execution of the associated callable - if target == None: target = -1 + if target == None: + target = -1 # tests if the current target is valid (less than or # equals to the current time value) and in case it's @@ -3344,11 +3475,12 @@ def _delays(self): # unpacks the multiple options so that it's possible to determine # the way the delayed operation is going to be executed - run, = options + (run,) = options # in case the method is not meant to be run, probably canceled # the execution of it should be properly ignored - if not run: continue + if not run: + continue # calls the callback method as the delayed operation is # now meant to be run, this is an operation that may change @@ -3356,12 +3488,13 @@ def _delays(self): # must be implemented with the proper precautions, note that # proper exception is set so that proper top level handling # is defined and logging is performed - try: method() + try: + method() except (KeyboardInterrupt, SystemExit, errors.StopError): raise except BaseException as exception: self.error(exception) - self.log_stack(method = self.warning) + self.log_stack(method=self.warning) # iterates over all the pending callable tuple values and adds # them back to the delayed heap list so that they are called @@ -3372,9 +3505,10 @@ def _delays(self): # in case the delayed list is empty resets the delay id so that # it never gets into a very large number, would break performance - if not self._delayed: self._did = 0 + if not self._delayed: + self._did = 0 - def _generate(self, hashed = True): + def _generate(self, hashed=True): """ Generates a random unique identifier that may be used to uniquely identify a certain object or operation. @@ -3392,7 +3526,8 @@ def _generate(self, hashed = True): identifier = str(uuid.uuid4()) identifier = identifier.upper() - if not hashed: return identifier + if not hashed: + return identifier identifier = legacy.bytes(identifier) hash = hashlib.sha256(identifier) identifier = hash.hexdigest() @@ -3403,18 +3538,20 @@ def _connect(self, connection): # in case the current connection has been closed meanwhile # the current connection is meant to be avoided and so the # method must return immediately to the caller method - if connection.status == CLOSED: return + if connection.status == CLOSED: + return # retrieves the socket associated with the connection # and calls the open method of the connection to proceed # with the correct operations for the connection _socket = connection.socket - connection.open(connect = True) + connection.open(connect=True) # tries to run the non blocking connection it should # fail and the connection should only be considered as # open when a write event is raised for the connection - try: _socket.connect(connection.address) + try: + _socket.connect(connection.address) except ssl.SSLError as error: error_v = error.args[0] if error.args else None if not error_v in SSL_VALID_ERRORS: @@ -3449,7 +3586,8 @@ def _connect(self, connection): # in case the connection is not of type SSL the method # may return as there's nothing left to be done, as the # rest of the method is dedicated to SSL tricks - if not connection.ssl: return + if not connection.ssl: + return # verifies if the current SSL object is a context oriented one # (newest versions) or a legacy oriented one, that does not uses @@ -3462,12 +3600,11 @@ def _connect(self, connection): # destroyed by the underlying SSL library (as an error) because # the socket is of type non blocking and raises an error, note # that the creation of the socket varies between SSL versions - if _socket._sslobj: return + if _socket._sslobj: + return if has_context: _socket._sslobj = _socket.context._wrap_socket( - _socket, - _socket.server_side, - _socket.server_hostname + _socket, _socket.server_side, _socket.server_hostname ) else: _socket._sslobj = ssl._ssl.sslwrap( @@ -3477,7 +3614,7 @@ def _connect(self, connection): _socket.certfile, _socket.cert_reqs, _socket.ssl_version, - _socket.ca_certs + _socket.ca_certs, ) # verifies if the SSL object class is defined in the SSL module @@ -3485,9 +3622,12 @@ def _connect(self, connection): # in order to comply with new indirection/abstraction method, under # some circumstances this operations fails with an exception because # the wrapping operation is not allowed for every Python environment - if not hasattr(ssl, "SSLObject"): return - try: _socket._sslobj = ssl.SSLObject(_socket._sslobj, owner = _socket) - except TypeError: pass + if not hasattr(ssl, "SSLObject"): + return + try: + _socket._sslobj = ssl.SSLObject(_socket._sslobj, owner=_socket) + except TypeError: + pass def _connectf(self, connection): """ @@ -3505,19 +3645,24 @@ def _connectf(self, connection): # in case the SSL connection is still undergoing the handshaking # procedures (marked as connecting) ignores the call as this must # be a duplicated call to this method (to be ignored) - if connection.ssl_connecting: return + if connection.ssl_connecting: + return # verifies if there was an error in the middle of the connection # operation and if that's the case calls the proper callback and # returns the control flow to the caller method error = connection.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) - if error: self.on_error(connection.socket); return + if error: + self.on_error(connection.socket) + return # checks if the current connection is SSL based and if that's the # case starts the handshaking process (async non blocking) otherwise # calls the on connect callback with the newly created connection - if connection.ssl: connection.add_starter(self._ssl_client_handshake) - else: self.on_connect(connection) + if connection.ssl: + connection.add_starter(self._ssl_client_handshake) + else: + self.on_connect(connection) # runs the starter process (initial kick-off) so that all the starters # registered for the connection may start to be executed, note that if @@ -3525,43 +3670,28 @@ def _connectf(self, connection): # going to be triggered by this call connection.run_starter() - def _socket_keepalive( - self, - _socket, - timeout = None, - interval = None, - count = None - ): - if timeout == None: timeout = self.keepalive_timeout - if interval == None: interval = self.keepalive_interval - if count == None: count = self.keepalive_count + def _socket_keepalive(self, _socket, timeout=None, interval=None, count=None): + if timeout == None: + timeout = self.keepalive_timeout + if interval == None: + interval = self.keepalive_interval + if count == None: + count = self.keepalive_count is_inet = _socket.family in (socket.AF_INET, socket.AF_INET6) - is_inet and hasattr(_socket, "TCP_KEEPIDLE") and\ - self.socket.setsockopt( - socket.IPPROTO_TCP, - socket.TCP_KEEPIDLE, #@UndefinedVariable - timeout - ) - is_inet and hasattr(_socket, "TCP_KEEPINTVL") and\ - self.socket.setsockopt( - socket.IPPROTO_TCP, - socket.TCP_KEEPINTVL, #@UndefinedVariable - interval - ) - is_inet and hasattr(_socket, "TCP_KEEPCNT") and\ - self.socket.setsockopt( - socket.IPPROTO_TCP, - socket.TCP_KEEPCNT, #@UndefinedVariable - count - ) - hasattr(_socket, "SO_REUSEPORT") and\ - self.socket.setsockopt( - socket.SOL_SOCKET, - socket.SO_REUSEPORT, #@UndefinedVariable - 1 - ) + is_inet and hasattr(_socket, "TCP_KEEPIDLE") and self.socket.setsockopt( + socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, timeout # @UndefinedVariable + ) + is_inet and hasattr(_socket, "TCP_KEEPINTVL") and self.socket.setsockopt( + socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, interval # @UndefinedVariable + ) + is_inet and hasattr(_socket, "TCP_KEEPCNT") and self.socket.setsockopt( + socket.IPPROTO_TCP, socket.TCP_KEEPCNT, count # @UndefinedVariable + ) + hasattr(_socket, "SO_REUSEPORT") and self.socket.setsockopt( + socket.SOL_SOCKET, socket.SO_REUSEPORT, 1 # @UndefinedVariable + ) - def _ssl_init(self, strict = True, env = True): + def _ssl_init(self, strict=True, env=True): # initializes the values of both the "main" context for SSL # and the map that associated an hostname and a context, both # are going to be used (if possible) at runtime for proper @@ -3574,15 +3704,18 @@ def _ssl_init(self, strict = True, env = True): # returned to the caller method as it's not possible to created # any kind of context information for SSL has_context = hasattr(ssl, "SSLContext") - if not has_context: return + if not has_context: + return # retrieves the reference to the environment variables that are going # to be used in the construction of the various SSL contexts, note that # the secure variable is extremely important to ensure that a proper and # secure SSL connection is established with the peer - secure = self.get_env("SSL_SECURE", 1, cast = int) if env else 0 - context_options = self.get_env("SSL_CONTEXT_OPTIONS", [], cast = list) if env else [] - contexts = self.get_env("SSL_CONTEXTS", {}, cast = dict) if env else {} + secure = self.get_env("SSL_SECURE", 1, cast=int) if env else 0 + context_options = ( + self.get_env("SSL_CONTEXT_OPTIONS", [], cast=list) if env else [] + ) + contexts = self.get_env("SSL_CONTEXTS", {}, cast=dict) if env else {} # creates the main/default SSL context setting the default key # and certificate information in such context, then verifies @@ -3593,24 +3726,22 @@ def _ssl_init(self, strict = True, env = True): # is unset for situation where no callback registration is possible self._ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) self._ssl_ctx_base( - self._ssl_context, - secure = secure, - context_options = context_options + self._ssl_context, secure=secure, context_options=context_options ) self._ssl_ctx_protocols(self._ssl_context) self._ssl_certs(self._ssl_context) has_callback = hasattr(self._ssl_context, "set_servername_callback") - if has_callback: self._ssl_context.set_servername_callback(self._ssl_callback) - elif strict: self._ssl_context = None + if has_callback: + self._ssl_context.set_servername_callback(self._ssl_callback) + elif strict: + self._ssl_context = None # retrieves the reference to the map containing the various key # and certificate paths for the various defined host names and # uses it to create the complete set of SSL context objects for hostname, values in legacy.iteritems(contexts): context = self._ssl_ctx( - values, - secure = secure, - context_options = context_options + values, secure=secure, context_options=context_options ) self._ssl_contexts[hostname] = (context, values) @@ -3622,18 +3753,21 @@ def _ssl_callback(self, socket, hostname, context): context, values = self._ssl_contexts.get(hostname, (context, None)) self._ssl_ctx_protocols(context) socket.context = context - if not values: return + if not values: + return ssl_host = values.get("ssl_host", None) ssl_fingerprint = values.get("ssl_fingerprint", None) - if not ssl_host and not ssl_fingerprint: return + if not ssl_host and not ssl_fingerprint: + return connection = self.connections_m.get(socket, None) - if not connection: return + if not connection: + return connection.ssl_host = ssl_host connection.ssl_fingerprint = ssl_fingerprint - def _ssl_ctx(self, values, context = None, secure = 1, context_options = []): + def _ssl_ctx(self, values, context=None, secure=1, context_options=[]): context = context or ssl.SSLContext(ssl.PROTOCOL_SSLv23) - self._ssl_ctx_base(context, secure = secure, context_options = context_options) + self._ssl_ctx_base(context, secure=secure, context_options=context_options) self._ssl_ctx_protocols(context) key_file = values.get("key_file", None) cer_file = values.get("cer_file", None) @@ -3643,15 +3777,15 @@ def _ssl_ctx(self, values, context = None, secure = 1, context_options = []): cert_reqs = ssl.CERT_REQUIRED if ssl_verify else ssl.CERT_NONE self._ssl_certs( context, - key_file = key_file, - cer_file = cer_file, - ca_file = ca_file, - ca_root = ca_root, - verify_mode = cert_reqs + key_file=key_file, + cer_file=cer_file, + ca_file=ca_file, + ca_root=ca_root, + verify_mode=cert_reqs, ) return context - def _ssl_ctx_base(self, context, secure = 1, context_options = []): + def _ssl_ctx_base(self, context, secure=1, context_options=[]): if secure >= 1 and hasattr(ssl, "OP_NO_SSLv2"): context.options |= ssl.OP_NO_SSLv2 if secure >= 1 and hasattr(ssl, "OP_NO_SSLv3"): @@ -3667,7 +3801,8 @@ def _ssl_ctx_base(self, context, secure = 1, context_options = []): if secure >= 1 and hasattr(ssl, "OP_CIPHER_SERVER_PREFERENCE"): context.options |= ssl.OP_CIPHER_SERVER_PREFERENCE for context_option in context_options: - if not hasattr(ssl, context_option): continue + if not hasattr(ssl, context_option): + continue context.options |= getattr(ssl, context_option) if secure >= 2 and hasattr(context, "set_ecdh_curve"): context.set_ecdh_curve("prime256v1") @@ -3679,28 +3814,34 @@ def _ssl_ctx_protocols(self, context): self._ssl_ctx_npn(context) def _ssl_ctx_alpn(self, context): - if not hasattr(ssl, "HAS_ALPN"): return - if not ssl.HAS_ALPN: return + if not hasattr(ssl, "HAS_ALPN"): + return + if not ssl.HAS_ALPN: + return if hasattr(context, "set_alpn_protocols"): protocols = self.get_protocols() - if protocols: context.set_alpn_protocols(protocols) + if protocols: + context.set_alpn_protocols(protocols) def _ssl_ctx_npn(self, context): - if not hasattr(ssl, "HAS_NPN"): return - if not ssl.HAS_NPN: return + if not hasattr(ssl, "HAS_NPN"): + return + if not ssl.HAS_NPN: + return if hasattr(context, "set_npn_protocols"): protocols = self.get_protocols() - if protocols: context.set_npn_protocols(protocols) + if protocols: + context.set_npn_protocols(protocols) def _ssl_certs( self, context, - key_file = None, - cer_file = None, - ca_file = None, - ca_root = False, - verify_mode = ssl.CERT_NONE, - check_hostname = False + key_file=None, + cer_file=None, + ca_file=None, + ca_root=False, + verify_mode=ssl.CERT_NONE, + check_hostname=False, ): dir_path = os.path.dirname(__file__) root_path = os.path.join(dir_path, "../") @@ -3709,49 +3850,50 @@ def _ssl_certs( extras_path = os.path.join(base_path, "extras") key_file = key_file or os.path.join(extras_path, "net.key") cer_file = cer_file or os.path.join(extras_path, "net.cer") - context.load_cert_chain(cer_file, keyfile = key_file) + context.load_cert_chain(cer_file, keyfile=key_file) context.verify_mode = verify_mode - if hasattr(context, "check_hostname"): context.check_hostname = check_hostname + if hasattr(context, "check_hostname"): + context.check_hostname = check_hostname if ca_file: - context.load_verify_locations(cafile = ca_file) + context.load_verify_locations(cafile=ca_file) if ca_root and hasattr(context, "load_default_certs"): - context.load_default_certs(purpose = ssl.Purpose.SERVER_AUTH) + context.load_default_certs(purpose=ssl.Purpose.SERVER_AUTH) if ca_root and SSL_CA_PATH: - context.load_verify_locations(cafile = SSL_CA_PATH) + context.load_verify_locations(cafile=SSL_CA_PATH) def _ssl_upgrade( self, _socket, - key_file = None, - cer_file = None, - ca_file = None, - ca_root = True, - server = True, - ssl_verify = False, - server_hostname = None + key_file=None, + cer_file=None, + ca_file=None, + ca_root=True, + server=True, + ssl_verify=False, + server_hostname=None, ): socket_ssl = self._ssl_wrap( _socket, - key_file = key_file, - cer_file = cer_file, - ca_file = ca_file, - ca_root = ca_root, - server = server, - ssl_verify = ssl_verify, - server_hostname = server_hostname + key_file=key_file, + cer_file=cer_file, + ca_file=ca_file, + ca_root=ca_root, + server=server, + ssl_verify=ssl_verify, + server_hostname=server_hostname, ) return socket_ssl def _ssl_wrap( self, _socket, - key_file = None, - cer_file = None, - ca_file = None, - ca_root = True, - server = True, - ssl_verify = False, - server_hostname = None + key_file=None, + cer_file=None, + ca_file=None, + ca_root=True, + server=True, + ssl_verify=False, + server_hostname=None, ): # tries to determine the value for the check hostname flag to be # passed to the wrap function by ensuring that both the SSL verify @@ -3777,30 +3919,30 @@ def _ssl_wrap( if not self._ssl_context: return ssl.wrap_socket( _socket, - keyfile = key_file, - certfile = cer_file, - server_side = server, - cert_reqs = cert_reqs, - ca_certs = ca_file, - ssl_version = ssl.PROTOCOL_SSLv23, - do_handshake_on_connect = False + keyfile=key_file, + certfile=cer_file, + server_side=server, + cert_reqs=cert_reqs, + ca_certs=ca_file, + ssl_version=ssl.PROTOCOL_SSLv23, + do_handshake_on_connect=False, ) self._ssl_certs( self._ssl_context, - key_file = key_file, - cer_file = cer_file, - ca_file = ca_file, - ca_root = ca_root, - verify_mode = cert_reqs, - check_hostname = check_hostname + key_file=key_file, + cer_file=cer_file, + ca_file=ca_file, + ca_root=ca_root, + verify_mode=cert_reqs, + check_hostname=check_hostname, ) return self._ssl_context.wrap_socket( _socket, - server_side = server, - do_handshake_on_connect = False, - server_hostname = server_hostname + server_side=server, + do_handshake_on_connect=False, + server_hostname=server_hostname, ) def _ssl_handshake(self, connection): @@ -3852,12 +3994,14 @@ def _ssl_handshake(self, connection): # or read operation is available (retry process) error_v = error.args[0] if error.args else None if error_v in SSL_VALID_ERRORS: - if error_v == ssl.SSL_ERROR_WANT_WRITE and\ - not self.is_sub_write(_socket): + if error_v == ssl.SSL_ERROR_WANT_WRITE and not self.is_sub_write( + _socket + ): self.sub_write(_socket) elif self.is_sub_write(_socket): self.unsub_write(_socket) - else: raise + else: + raise def _ssl_client_handshake(self, connection): """ @@ -3920,12 +4064,14 @@ def _ssl_client_handshake(self, connection): # or read operation is available (retry process) error_v = error.args[0] if error.args else None if error_v in SSL_VALID_ERRORS: - if error_v == ssl.SSL_ERROR_WANT_WRITE and\ - not self.is_sub_write(_socket): + if error_v == ssl.SSL_ERROR_WANT_WRITE and not self.is_sub_write( + _socket + ): self.sub_write(_socket) elif self.is_sub_write(_socket): self.unsub_write(_socket) - else: raise + else: + raise def _expand_destroy(self): """ @@ -3939,8 +4085,10 @@ def _expand_destroy(self): # iterates over the complete list of expanded file paths to remove # their corresponding files (graceful error handling) for expanded in self._expanded: - try: os.remove(expanded) - except OSError: pass + try: + os.remove(expanded) + except OSError: + pass # deletes the complete set of path references from the expanded # list so that it is not going to be used any longer @@ -3965,14 +4113,17 @@ def _level(self, level): """ level_t = type(level) - if level_t == int: return level - if level == None: return level - if level == "SILENT": return log.SILENT + if level_t == int: + return level + if level == None: + return level + if level == "SILENT": + return log.SILENT if hasattr(logging, "_checkLevel"): return logging._checkLevel(level) return logging.getLevelName(level) - def _format_delta(self, time_delta, count = 2): + def _format_delta(self, time_delta, count=2): days = time_delta.days hours, remainder = divmod(time_delta.seconds, 3600) minutes, seconds = divmod(remainder, 60) @@ -3980,19 +4131,22 @@ def _format_delta(self, time_delta, count = 2): if days > 0: delta_s += "%dd " % days count -= 1 - if count == 0: return delta_s.strip() + if count == 0: + return delta_s.strip() if hours > 0: delta_s += "%dh " % hours count -= 1 - if count == 0: return delta_s.strip() + if count == 0: + return delta_s.strip() if minutes > 0: delta_s += "%dm " % minutes count -= 1 - if count == 0: return delta_s.strip() + if count == 0: + return delta_s.strip() delta_s += "%ds" % seconds return delta_s.strip() - def _wait_forever(self, sleep = 60): + def _wait_forever(self, sleep=60): """ Runs a simple event loop that sleeps for a certain amount of time and then processes a series of pending events. @@ -4020,18 +4174,27 @@ def _wait_forever(self, sleep = 60): while self._running: try: self._awaken = True - try: self._delays() - finally: self._awaken = False + try: + self._delays() + finally: + self._awaken = False time.sleep(sleep) except errors.WakeupError: continue - except (KeyboardInterrupt, SystemExit, errors.StopError, errors.PauseError): + except ( + KeyboardInterrupt, + SystemExit, + errors.StopError, + errors.PauseError, + ): raise except BaseException as exception: self.error(exception) - self.log_stack(method = self.warning) + self.log_stack(method=self.warning) finally: - if hasattr(self, "_awaken"): del self._awaken + if hasattr(self, "_awaken"): + del self._awaken + class DiagBase(AbstractBase): @@ -4053,15 +4216,14 @@ def errors(self, *args, **kwargs): AbstractBase.errors(self, *args, **kwargs) self.errors_c += 1 - def info_dict(self, full = False): - info = AbstractBase.info_dict(self, full = full) + def info_dict(self, full=False): + info = AbstractBase.info_dict(self, full=full) info.update( - reads_c = self.reads_c, - writes_c = self.writes_c, - errors_c = self.errors_c + reads_c=self.reads_c, writes_c=self.writes_c, errors_c=self.errors_c ) return info + class BaseThread(threading.Thread): """ The top level thread class that is meant to encapsulate @@ -4071,14 +4233,15 @@ class BaseThread(threading.Thread): a main thread to continue with execution logic. """ - def __init__(self, owner = None, daemon = False, *args, **kwargs): + def __init__(self, owner=None, daemon=False, *args, **kwargs): threading.Thread.__init__(self, *args, **kwargs) self.owner = owner self.daemon = daemon def run(self): threading.Thread.run(self) - if not self.owner: return + if not self.owner: + return self.owner._thread = self try: self.owner.start() @@ -4086,57 +4249,69 @@ def run(self): self.owner._thread = None self.owner = None -def new_loop_main(factory = None, _compat = None, **kwargs): + +def new_loop_main(factory=None, _compat=None, **kwargs): factory = factory or Base kwargs["_slave"] = kwargs.pop("_slave", True) instance = factory(**kwargs) return compat_loop(instance) if _compat else instance + def new_loop_asyncio(**kwargs): asyncio = asynchronous.get_asyncio() - if not asyncio: return None + if not asyncio: + return None return asyncio.new_event_loop() -def new_loop(factory = None, _compat = None, asyncio = None, **kwargs): + +def new_loop(factory=None, _compat=None, asyncio=None, **kwargs): _compat = compat.is_compat() if _compat == None else _compat asyncio = compat.is_asyncio() if asyncio == None else asyncio - if asyncio: return new_loop_asyncio(**kwargs) - else: return new_loop_main(factory = factory, _compat = _compat, **kwargs) + if asyncio: + return new_loop_asyncio(**kwargs) + else: + return new_loop_main(factory=factory, _compat=_compat, **kwargs) -def ensure_main(factory = None, **kwargs): - if Base.get_main(): return + +def ensure_main(factory=None, **kwargs): + if Base.get_main(): + return factory = factory or Base instance = factory(**kwargs) Base.set_main(instance) + def ensure_asyncio(**kwargs): asyncio = asynchronous.get_asyncio() - if not asyncio: return None + if not asyncio: + return None return asyncio.get_event_loop() -def ensure_loop(factory = None, asyncio = None, **kwargs): + +def ensure_loop(factory=None, asyncio=None, **kwargs): asyncio = compat.is_asyncio() if asyncio == None else asyncio - if asyncio: ensure_asyncio() - else: ensure_main(factory = factory, **kwargs) + if asyncio: + ensure_asyncio() + else: + ensure_main(factory=factory, **kwargs) + -def get_main(factory = None, ensure = True, **kwargs): - if ensure: ensure_main(factory = factory, **kwargs) +def get_main(factory=None, ensure=True, **kwargs): + if ensure: + ensure_main(factory=factory, **kwargs) return Base.get_main() -def get_loop( - factory = None, - ensure = True, - _compat = None, - asyncio = None, - **kwargs -): + +def get_loop(factory=None, ensure=True, _compat=None, asyncio=None, **kwargs): _compat = compat.is_compat() if _compat == None else _compat asyncio = compat.is_asyncio() if asyncio == None else asyncio - if ensure: ensure_loop(factory = factory, asyncio = asyncio) - loop = Base.get_loop(compat = _compat, asyncio = asyncio) - loop = loop or get_main(factory = factory, **kwargs) + if ensure: + ensure_loop(factory=factory, asyncio=asyncio) + loop = Base.get_loop(compat=_compat, asyncio=asyncio) + loop = loop or get_main(factory=factory, **kwargs) return loop + def get_event_loop(*args, **kwargs): """ Compatibility alias function with the ``get_loop()`` function @@ -4149,11 +4324,14 @@ def get_event_loop(*args, **kwargs): return get_loop(*args, **kwargs) -def stop_loop(compat = True, asyncio = True): - loop = get_loop(ensure = False, _compat = compat, asyncio = asyncio) - if not loop: return + +def stop_loop(compat=True, asyncio=True): + loop = get_loop(ensure=False, _compat=compat, asyncio=asyncio) + if not loop: + return loop.stop() + def compat_loop(loop): """ Retrieves the asyncio API compatible version of the provided @@ -4169,33 +4347,32 @@ def compat_loop(loop): return loop._compat if hasattr(loop, "_compat") else loop + def get_poll(): main = get_main() - if not main: return None + if not main: + return None return main.poll -def build_future(compat = True, asyncio = True): + +def build_future(compat=True, asyncio=True): main = get_main() - if not main: return None - return main.build_future(compat = compat, asyncio = asyncio) + if not main: + return None + return main.build_future(compat=compat, asyncio=asyncio) + -def ensure(coroutine, args = [], kwargs = {}, thread = None): +def ensure(coroutine, args=[], kwargs={}, thread=None): loop = get_loop() - return loop.ensure( - coroutine, - args = args, - kwargs = kwargs, - thread = thread - ) - -def ensure_pool(coroutine, args = [], kwargs = {}): - return ensure( - coroutine, - args = args, - kwargs = kwargs, - thread = True - ) - -is_diag = config.conf("DIAG", False, cast = bool) -if is_diag: Base = DiagBase -else: Base = AbstractBase + return loop.ensure(coroutine, args=args, kwargs=kwargs, thread=thread) + + +def ensure_pool(coroutine, args=[], kwargs={}): + return ensure(coroutine, args=args, kwargs=kwargs, thread=True) + + +is_diag = config.conf("DIAG", False, cast=bool) +if is_diag: + Base = DiagBase +else: + Base = AbstractBase diff --git a/src/netius/base/compat.py b/src/netius/base/compat.py index e933204b3..458d667fd 100644 --- a/src/netius/base/compat.py +++ b/src/netius/base/compat.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -52,6 +43,7 @@ asyncio = asynchronous.get_asyncio() if asynchronous.is_neo() else None BaseLoop = asyncio.AbstractEventLoop if asyncio else object + class CompatLoop(BaseLoop): """ Top level compatibility class that adds compatibility support @@ -79,14 +71,14 @@ def time(self): return time.time() def call_soon(self, callback, *args): - return self._call_delay(callback, args, immediately = True) + return self._call_delay(callback, args, immediately=True) def call_soon_threadsafe(self, callback, *args): - return self._call_delay(callback, args, immediately = True, safe = True) + return self._call_delay(callback, args, immediately=True, safe=True) def call_at(self, when, callback, *args): delay = when - self.time() - return self._call_delay(callback, args, timeout = delay) + return self._call_delay(callback, args, timeout=delay) def call_later(self, delay, callback, *args): """ @@ -102,7 +94,7 @@ def call_later(self, delay, callback, *args): :return: The handle object to the operation, that may be used to cancel it. """ - return self._call_delay(callback, args, timeout = delay) + return self._call_delay(callback, args, timeout=delay) def create_future(self): return self._loop.build_future() @@ -134,8 +126,10 @@ def getnameinfo(self, *args, **kwargs): def run_until_complete(self, future): self._set_current_task(future) - try: return self._loop.run_coroutine(future) - finally: self._unset_current_task() + try: + return self._loop.run_coroutine(future) + finally: + self._unset_current_task() def run_forever(self): return self._loop.run_forever() @@ -160,7 +154,8 @@ def default_exception_handler(self, context): return self._default_handler(context) def call_exception_handler(self, context): - if not self._handler: return + if not self._handler: + return return self._handler(context) def get_debug(self): @@ -184,28 +179,13 @@ def is_running(self): def is_closed(self): return self._loop.is_stopped() - def _getaddrinfo( - self, - host, - port, - family = 0, - type = 0, - proto = 0, - flags = 0 - ): + def _getaddrinfo(self, host, port, family=0, type=0, proto=0, flags=0): future = self.create_future() - result = socket.getaddrinfo( - host, - port, - family, - type, - proto, - flags = flags - ) - self._loop.delay(lambda: future.set_result(result), immediately = True) + result = socket.getaddrinfo(host, port, family, type, proto, flags=flags) + self._loop.delay(lambda: future.set_result(result), immediately=True) yield future - def _getnameinfo(self, sockaddr, flags = 0): + def _getnameinfo(self, sockaddr, flags=0): raise errors.NotImplemented("Missing implementation") def _run_in_executor(self, executor, func, *args): @@ -216,15 +196,15 @@ def _run_in_executor(self, executor, func, *args): def _create_connection( self, protocol_factory, - host = None, - port = None, - ssl = None, - family = 0, - proto = 0, - flags = 0, - sock = None, - local_addr = None, - server_hostname = None, + host=None, + port=None, + ssl=None, + family=0, + proto=0, + flags=0, + sock=None, + local_addr=None, + server_hostname=None, *args, **kwargs ): @@ -234,8 +214,10 @@ def _create_connection( future = self.create_future() def on_complete(connection, success): - if success: on_connect(connection) - else: on_error(connection) + if success: + on_connect(connection) + else: + on_error(connection) def on_connect(connection): protocol = protocol_factory() @@ -244,32 +226,24 @@ def on_connect(connection): future.set_result((_transport, protocol)) def on_error(connection): - future.set_exception( - errors.RuntimeError("Connection issue") - ) + future.set_exception(errors.RuntimeError("Connection issue")) - self._loop.connect( - host, - port, - ssl = ssl, - family = family, - callback = on_complete - ) + self._loop.connect(host, port, ssl=ssl, family=family, callback=on_complete) yield future def _create_datagram_endpoint( self, protocol_factory, - local_addr = None, - remote_addr = None, - family = 0, - proto = 0, - flags = 0, - reuse_address = None, - reuse_port = None, - allow_broadcast = None, - sock = None, + local_addr=None, + remote_addr=None, + family=0, + proto=0, + flags=0, + reuse_address=None, + reuse_port=None, + allow_broadcast=None, + sock=None, *args, **kwargs ): @@ -279,8 +253,10 @@ def _create_datagram_endpoint( future = self.create_future() def on_complete(connection, success): - if success: on_connect(connection) - else: on_error(connection) + if success: + on_connect(connection) + else: + on_error(connection) def on_connect(connection): protocol = protocol_factory() @@ -289,17 +265,15 @@ def on_connect(connection): future.set_result((_transport, protocol)) def on_error(connection): - future.set_exception( - errors.RuntimeError("Connection issue") - ) + future.set_exception(errors.RuntimeError("Connection issue")) connection = self._loop.datagram( - family = family, - type = proto, - local_host = local_addr[0] if local_addr else None, - local_port = local_addr[1] if local_addr else None, - remote_host = remote_addr[0] if remote_addr else None, - remote_port = remote_addr[1] if remote_addr else None + family=family, + type=proto, + local_host=local_addr[0] if local_addr else None, + local_port=local_addr[1] if local_addr else None, + remote_host=remote_addr[0] if remote_addr else None, + remote_port=remote_addr[1] if remote_addr else None, ) self._loop.delay(lambda: on_complete(connection, True)) @@ -307,22 +281,18 @@ def on_error(connection): def _set_current_task(self, task): asyncio = asynchronous.get_asyncio() - if not asyncio: return + if not asyncio: + return asyncio.Task._current_tasks[self] = task def _unset_current_task(self): asyncio = asynchronous.get_asyncio() - if not asyncio: return + if not asyncio: + return asyncio.Task._current_tasks.pop(self, None) def _call_delay( - self, - callback, - args, - timeout = None, - immediately = False, - verify = False, - safe = False + self, callback, args, timeout=None, immediately=False, verify=False, safe=False ): # creates the callable to be called after the timeout, note the # clojure around the "normal" arguments (allows proper propagation) @@ -332,19 +302,15 @@ def _call_delay( # the provided set of options expected by the delay operation the # callback tuple is returned so that a proper handle may be created callable_t = self._loop.delay( - callable, - timeout = timeout, - immediately = immediately, - verify = verify, - safe = safe + callable, timeout=timeout, immediately=immediately, verify=verify, safe=safe ) # creates the handle to control the operation and then returns the # object to the caller method, allowing operation cancellation - handle = asynchronous.Handle(callable_t = callable_t) + handle = asynchronous.Handle(callable_t=callable_t) return handle - def _sleep(self, timeout, future = None): + def _sleep(self, timeout, future=None): # verifies if a future variable is meant to be re-used # or if instead a new one should be created for the new # sleep operation to be executed @@ -370,6 +336,7 @@ def _default_handler(self, context): def _thread_id(self): return self._loop.tid + def is_compat(): """ Determines if the compatibility mode for the netius @@ -385,10 +352,11 @@ def is_compat(): the compatibility mode. """ - compat = config.conf("COMPAT", False, cast = bool) + compat = config.conf("COMPAT", False, cast=bool) compat |= is_asyncio() return compat and asynchronous.is_neo() + def is_asyncio(): """ Checks if the asyncio mode of execution (external event @@ -402,25 +370,32 @@ def is_asyncio(): proper library support available. """ - asyncio = config.conf("ASYNCIO", False, cast = bool) + asyncio = config.conf("ASYNCIO", False, cast=bool) return asyncio and asynchronous.is_asynclib() + def build_datagram(*args, **kwargs): - if is_compat(): return _build_datagram_compat(*args, **kwargs) - else: return _build_datagram_native(*args, **kwargs) + if is_compat(): + return _build_datagram_compat(*args, **kwargs) + else: + return _build_datagram_native(*args, **kwargs) + def connect_stream(*args, **kwargs): - if is_compat(): return _connect_stream_compat(*args, **kwargs) - else: return _connect_stream_native(*args, **kwargs) + if is_compat(): + return _connect_stream_compat(*args, **kwargs) + else: + return _connect_stream_native(*args, **kwargs) + def _build_datagram_native( protocol_factory, - family = socket.AF_INET, - type = socket.SOCK_DGRAM, - remote_host = None, - remote_port = None, - callback = None, - loop = None, + family=socket.AF_INET, + type=socket.SOCK_DGRAM, + remote_host=None, + remote_port=None, + callback=None, + loop=None, *args, **kwargs ): @@ -430,25 +405,29 @@ def _build_datagram_native( protocol = protocol_factory() has_loop_set = hasattr(protocol, "loop_set") - if has_loop_set: protocol.loop_set(loop) + if has_loop_set: + protocol.loop_set(loop) def on_ready(): loop.datagram( - family = family, - type = type, - remote_host = remote_host, - remote_port = remote_port, - callback = on_complete + family=family, + type=type, + remote_host=remote_host, + remote_port=remote_port, + callback=on_complete, ) def on_complete(connection, success): - if success: on_connect(connection) - else: on_error(connection) + if success: + on_connect(connection) + else: + on_error(connection) def on_connect(connection): _transport = transport.TransportDatagram(loop, connection) _transport._set_compat(protocol) - if not callback: return + if not callback: + return callback((_transport, protocol)) def on_error(connection): @@ -458,14 +437,15 @@ def on_error(connection): return loop + def _build_datagram_compat( protocol_factory, - family = socket.AF_INET, - type = socket.SOCK_DGRAM, - remote_host = None, - remote_port = None, - callback = None, - loop = None, + family=socket.AF_INET, + type=socket.SOCK_DGRAM, + remote_host=None, + remote_port=None, + callback=None, + loop=None, *args, **kwargs ): @@ -475,7 +455,8 @@ def _build_datagram_compat( protocol = protocol_factory() has_loop_set = hasattr(protocol, "loop_set") - if has_loop_set: protocol.loop_set(loop) + if has_loop_set: + protocol.loop_set(loop) def build_protocol(): return protocol @@ -487,15 +468,14 @@ def on_connect(future): result = future.result() callback and callback(result) - remote_addr = (remote_host, remote_port) if\ - remote_host and remote_port else kwargs.pop("remote_addr", None) + remote_addr = ( + (remote_host, remote_port) + if remote_host and remote_port + else kwargs.pop("remote_addr", None) + ) connect = loop.create_datagram_endpoint( - build_protocol, - family = family, - remote_addr = remote_addr, - *args, - **kwargs + build_protocol, family=family, remote_addr=remote_addr, *args, **kwargs ) future = loop.create_task(connect) @@ -503,20 +483,21 @@ def on_connect(future): return loop + def _connect_stream_native( protocol_factory, host, port, - ssl = False, - key_file = None, - cer_file = None, - ca_file = None, - ca_root = True, - ssl_verify = False, - family = socket.AF_INET, - type = socket.SOCK_STREAM, - callback = None, - loop = None, + ssl=False, + key_file=None, + cer_file=None, + ca_file=None, + ca_root=True, + ssl_verify=False, + family=socket.AF_INET, + type=socket.SOCK_STREAM, + callback=None, + loop=None, *args, **kwargs ): @@ -526,31 +507,35 @@ def _connect_stream_native( protocol = protocol_factory() has_loop_set = hasattr(protocol, "loop_set") - if has_loop_set: protocol.loop_set(loop) + if has_loop_set: + protocol.loop_set(loop) def on_ready(): loop.connect( host, port, - ssl = ssl, - key_file = key_file, - cer_file = cer_file, - ca_file = ca_file, - ca_root = ca_root, - ssl_verify = ssl_verify, - family = family, - type = type, - callback = on_complete + ssl=ssl, + key_file=key_file, + cer_file=cer_file, + ca_file=ca_file, + ca_root=ca_root, + ssl_verify=ssl_verify, + family=family, + type=type, + callback=on_complete, ) def on_complete(connection, success): - if success: on_connect(connection) - else: on_error(connection) + if success: + on_connect(connection) + else: + on_error(connection) def on_connect(connection): _transport = transport.TransportStream(loop, connection) _transport._set_compat(protocol) - if not callback: return + if not callback: + return callback((_transport, protocol)) def on_error(connection): @@ -560,20 +545,21 @@ def on_error(connection): return loop + def _connect_stream_compat( protocol_factory, host, port, - ssl = False, - key_file = None, - cer_file = None, - ca_file = None, - ca_root = True, - ssl_verify = False, - family = socket.AF_INET, - type = socket.SOCK_STREAM, - callback = None, - loop = None, + ssl=False, + key_file=None, + cer_file=None, + ca_file=None, + ca_root=True, + ssl_verify=False, + family=socket.AF_INET, + type=socket.SOCK_STREAM, + callback=None, + loop=None, *args, **kwargs ): @@ -583,7 +569,8 @@ def _connect_stream_compat( protocol = protocol_factory() has_loop_set = hasattr(protocol, "loop_set") - if has_loop_set: protocol.loop_set(loop) + if has_loop_set: + protocol.loop_set(loop) def build_protocol(): return protocol @@ -597,18 +584,13 @@ def on_connect(future): if ssl and cer_file and key_file: import ssl as _ssl + ssl_context = _ssl.SSLContext() - ssl_context.load_cert_chain(cer_file, keyfile = key_file) + ssl_context.load_cert_chain(cer_file, keyfile=key_file) ssl = ssl_context connect = loop.create_connection( - build_protocol, - host = host, - port = port, - ssl = ssl, - family = family, - *args, - **kwargs + build_protocol, host=host, port=port, ssl=ssl, family=family, *args, **kwargs ) future = loop.create_task(connect) diff --git a/src/netius/base/config.py b/src/netius/base/config.py index 5a795959e..443aaceba 100644 --- a/src/netius/base/config.py +++ b/src/netius/base/config.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -60,19 +51,15 @@ name that references a list of include files to be loaded """ CASTS = { - bool : lambda v: v if isinstance(v, bool) else v in ("1", "true", "True"), - list : lambda v: v if isinstance(v, list) else v.split(";") if v else [], - tuple : lambda v: v if isinstance(v, tuple) else tuple(v.split(";") if v else []) + bool: lambda v: v if isinstance(v, bool) else v in ("1", "true", "True"), + list: lambda v: v if isinstance(v, list) else v.split(";") if v else [], + tuple: lambda v: v if isinstance(v, tuple) else tuple(v.split(";") if v else []), } """ The map containing the various cast method operation associated with the various data types, they provide a different type of casting strategy """ -ENV_ENCODINGS = ( - "utf-8", - sys.getdefaultencoding(), - sys.getfilesystemencoding() -) +ENV_ENCODINGS = ("utf-8", sys.getdefaultencoding(), sys.getfilesystemencoding()) """ The sequence of encodings that are going to be used to try to decode possible byte based strings for the various environment variable values """ @@ -92,12 +79,12 @@ to be the home on in terms of configuration, this value should be set on the initial loading of the ".home" file """ -__builtins__ = __builtins__ if isinstance(__builtins__, dict) else\ - __builtins__.__dict__ +__builtins__ = __builtins__ if isinstance(__builtins__, dict) else __builtins__.__dict__ """ The global builtins reference created by the proper redefinition of the variable if that's required by python implementation """ -def conf(name, default = None, cast = None, ctx = None): + +def conf(name, default=None, cast=None, ctx=None): """ Retrieves the configuration value for the provided value defaulting to the provided default value in case no value @@ -127,42 +114,53 @@ def conf(name, default = None, cast = None, ctx = None): configs = ctx["configs"] if ctx else CONFIGS cast = _cast_r(cast) value = configs.get(name, default) - if cast and not value == None: value = cast(value) + if cast and not value == None: + value = cast(value) return value -def conf_prefix(prefix, ctx = None): + +def conf_prefix(prefix, ctx=None): configs = ctx["configs"] if ctx else CONFIGS configs_prefix = dict() for name, value in configs.items(): - if not name.startswith(prefix): continue + if not name.startswith(prefix): + continue configs_prefix[name] = value return configs_prefix -def conf_suffix(suffix, ctx = None): + +def conf_suffix(suffix, ctx=None): configs = ctx["configs"] if ctx else CONFIGS configs_suffix = dict() for name, value in configs.items(): - if not name.endswith(suffix): continue + if not name.endswith(suffix): + continue configs_suffix[name] = value return configs_suffix -def conf_s(name, value, ctx = None): + +def conf_s(name, value, ctx=None): configs = ctx["configs"] if ctx else CONFIGS configs[name] = value -def conf_r(name, ctx = None): + +def conf_r(name, ctx=None): configs = ctx["configs"] if ctx else CONFIGS - if not name in configs: return + if not name in configs: + return del configs[name] -def conf_d(ctx = None): + +def conf_d(ctx=None): configs = ctx["configs"] if ctx else CONFIGS return configs + def conf_ctx(): - return dict(configs = dict(), config_f = dict()) + return dict(configs=dict(), config_f=dict()) + -def load(names = (FILE_NAME,), path = None, encoding = "utf-8", ctx = None): +def load(names=(FILE_NAME,), path=None, encoding="utf-8", ctx=None): paths = [] homes = get_homes() for home in homes: @@ -174,43 +172,54 @@ def load(names = (FILE_NAME,), path = None, encoding = "utf-8", ctx = None): paths.append(path) for path in paths: for name in names: - load_file(name = name, path = path, encoding = encoding, ctx = ctx) - load_env(ctx = ctx) + load_file(name=name, path=path, encoding=encoding, ctx=ctx) + load_env(ctx=ctx) -def load_file(name = FILE_NAME, path = None, encoding = "utf-8", ctx = None): + +def load_file(name=FILE_NAME, path=None, encoding="utf-8", ctx=None): configs = ctx["configs"] if ctx else CONFIGS config_f = ctx["config_f"] if ctx else CONFIG_F - if path: path = os.path.normpath(path) - if path: file_path = os.path.join(path, name) - else: file_path = name + if path: + path = os.path.normpath(path) + if path: + file_path = os.path.join(path, name) + else: + file_path = name file_path = os.path.abspath(file_path) file_path = os.path.normpath(file_path) base_path = os.path.dirname(file_path) exists = os.path.exists(file_path) - if not exists: return + if not exists: + return exists = file_path in config_f - if exists: config_f.remove(file_path) + if exists: + config_f.remove(file_path) config_f.append(file_path) file = open(file_path, "rb") - try: data = file.read() - finally: file.close() - if not data: return + try: + data = file.read() + finally: + file.close() + if not data: + return data = data.decode(encoding) data_j = json.loads(data) - _load_includes(base_path, data_j, encoding = encoding) + _load_includes(base_path, data_j, encoding=encoding) for key, value in data_j.items(): - if not _is_valid(key): continue + if not _is_valid(key): + continue configs[key] = value -def load_env(ctx = None): + +def load_env(ctx=None): configs = ctx["configs"] if ctx else CONFIGS config = dict(os.environ) @@ -220,28 +229,31 @@ def load_env(ctx = None): _load_includes(home, config) for key, value in legacy.iteritems(config): - if not _is_valid(key): continue + if not _is_valid(key): + continue configs[key] = value is_bytes = legacy.is_bytes(value) - if not is_bytes: continue + if not is_bytes: + continue for encoding in ENV_ENCODINGS: - try: value = value.decode(encoding) - except UnicodeDecodeError: pass - else: break + try: + value = value.decode(encoding) + except UnicodeDecodeError: + pass + else: + break configs[key] = value -def get_homes( - file_path = HOME_FILE, - default = "~", - encoding = "utf-8", - force_default = False -): + +def get_homes(file_path=HOME_FILE, default="~", encoding="utf-8", force_default=False): global HOMES - if HOMES: return HOMES + if HOMES: + return HOMES HOMES = os.environ.get("HOMES", None) HOMES = HOMES.split(";") if HOMES else HOMES - if not HOMES == None: return HOMES + if not HOMES == None: + return HOMES default = os.path.expanduser(default) default = os.path.abspath(default) @@ -251,13 +263,17 @@ def get_homes( file_path = os.path.expanduser(file_path) file_path = os.path.normpath(file_path) exists = os.path.exists(file_path) - if not exists: return HOMES + if not exists: + return HOMES - if not force_default: del HOMES[:] + if not force_default: + del HOMES[:] file = open(file_path, "rb") - try: data = file.read() - finally: file.close() + try: + data = file.read() + finally: + file.close() data = data.decode("utf-8") data = data.strip() @@ -266,7 +282,8 @@ def get_homes( for path in paths: path = path.strip() - if not path: continue + if not path: + continue path = os.path.expanduser(path) path = os.path.abspath(path) path = os.path.normpath(path) @@ -274,13 +291,17 @@ def get_homes( return HOMES + def _cast_r(cast): is_string = type(cast) in legacy.STRINGS - if is_string: cast = __builtins__.get(cast, None) - if not cast: return None + if is_string: + cast = __builtins__.get(cast, None) + if not cast: + return None return CASTS.get(cast, cast) -def _load_includes(base_path, config, encoding = "utf-8"): + +def _load_includes(base_path, config, encoding="utf-8"): includes = () for alias in IMPORT_NAMES: @@ -290,16 +311,15 @@ def _load_includes(base_path, config, encoding = "utf-8"): includes = includes.split(";") for include in includes: - load_file( - name = include, - path = base_path, - encoding = encoding - ) + load_file(name=include, path=base_path, encoding=encoding) + def _is_valid(key): - if key in IMPORT_NAMES: return False + if key in IMPORT_NAMES: + return False return True + def _is_devel(): """ Simple debug/development level detection mechanism to be @@ -317,6 +337,7 @@ def _is_devel(): return conf("LEVEL", "INFO") in ("DEBUG",) + def _is_secure(): """ Simple secure variable that should be overridden only under @@ -329,6 +350,7 @@ def _is_secure(): secured type level of traceability. """ - return conf("SECURE", True, cast = bool) + return conf("SECURE", True, cast=bool) + load() diff --git a/src/netius/base/conn.py b/src/netius/base/conn.py index 7a6ea2bda..dbd6d1072 100644 --- a/src/netius/base/conn.py +++ b/src/netius/base/conn.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -67,6 +58,7 @@ """ The size of the chunk to be used while received data from the service socket """ + class BaseConnection(observer.Observable): """ Abstract connection object that should encapsulate @@ -81,13 +73,13 @@ class BaseConnection(observer.Observable): def __init__( self, - owner = None, - socket = None, - address = None, - datagram = False, - ssl = False, - max_pending = -1, - min_pending = -1 + owner=None, + socket=None, + address=None, + datagram=False, + ssl=False, + max_pending=-1, + min_pending=-1, ): observer.Observable.__init__(self) self.status = PENDING @@ -124,11 +116,12 @@ def destroy(self): self.pending.clear() self.restored.clear() - def open(self, connect = False): + def open(self, connect=False): # in case the current status of the connection is already open # it does not make sense to proceed with the opening of the # connection as the connection is already open - if self.status == OPEN: return + if self.status == OPEN: + return # retrieves the reference to the owner object from the # current instance to be used to add the socket to the @@ -154,7 +147,8 @@ def open(self, connect = False): # in case the connect flag is set, must set the current # connection as connecting, indicating that some extra # steps are still required to complete the connection - if connect: self.set_connecting() + if connect: + self.set_connecting() # calls the top level of connection creation handler so that the owner # object gets notified about the creation of the connection (open) @@ -165,11 +159,12 @@ def open(self, connect = False): # the current netius specification and strategy self.trigger("open", self) - def close(self, flush = False, destroy = True): + def close(self, flush=False, destroy=True): # in case the current status of the connection is closed it does # nor make sense to proceed with the closing as the connection # is already in the closed state (nothing to be done) - if self.status == CLOSED: return + if self.status == CLOSED: + return # in case the flush flag is set, a different approach is taken # where all the pending data is flushed (as possible) before @@ -220,14 +215,18 @@ def close(self, flush = False, destroy = True): # removes the current connection from the list of connections in the # owner and also from the map that associates the socket with the # proper connection (also in the owner) - if self in owner.connections: owner.connections.remove(self) - if self.socket in owner.connections_m: del owner.connections_m[self.socket] + if self in owner.connections: + owner.connections.remove(self) + if self.socket in owner.connections_m: + del owner.connections_m[self.socket] # closes the socket, using the proper grace way so that # operations are no longer allowed in the socket in case there are # an error in the operation fails silently (on purpose) - try: self.socket.close() - except Exception: pass + try: + self.socket.close() + except Exception: + pass # calls the top level of the connection delete handler so that the owner # object gets notified about the deletion of the connection (closed) @@ -242,15 +241,17 @@ def close(self, flush = False, destroy = True): # for instance, the current event registered handlers will no longer be available # this is important to avoid any memory leak from circular references from # this moment on the connection is considered disabled (not ready for usage) - if destroy: self.destroy() + if destroy: + self.destroy() def close_flush(self): - self.send(None, callback = self._close_callback) + self.send(None, callback=self._close_callback) - def upgrade(self, key_file = None, cer_file = None, ca_file = None, server = True): + def upgrade(self, key_file=None, cer_file=None, ca_file=None, server=True): # in case the current connection is already an SSL-oriented one there's # nothing to be done here, and the method returns immediately to the caller - if self.ssl: return + if self.ssl: + return # prints a debug message about the upgrading of the connection that is # going to be performed for the current connection @@ -266,15 +267,21 @@ def upgrade(self, key_file = None, cer_file = None, ca_file = None, server = Tru # determines if the arguments-based certificate and key values should be used # of if instead the owner values should be used as a fallback process, these # values are going to be used as part of the SSL upgrade process - if hasattr(self.owner, "key_file"): key_file = key_file or self.owner.key_file - if hasattr(self.owner, "cer_file"): cer_file = cer_file or self.owner.cer_file - if hasattr(self.owner, "ca_file"): ca_file = ca_file or self.owner.ca_file + if hasattr(self.owner, "key_file"): + key_file = key_file or self.owner.key_file + if hasattr(self.owner, "cer_file"): + cer_file = cer_file or self.owner.cer_file + if hasattr(self.owner, "ca_file"): + ca_file = ca_file or self.owner.ca_file # prints some debug information about the files that are going to be used for # the SSL based connection upgrade, this is mainly for debugging purposes - if key_file: self.owner.debug("Using '%s' as key file" % key_file) - if cer_file: self.owner.debug("Using '%s' as certificate file" % cer_file) - if ca_file: self.owner.debug("Using '%s' as certificate authority file" % ca_file) + if key_file: + self.owner.debug("Using '%s' as key file" % key_file) + if cer_file: + self.owner.debug("Using '%s' as certificate file" % cer_file) + if ca_file: + self.owner.debug("Using '%s' as certificate authority file" % ca_file) # removes the "old" association socket association for the connection and # unsubscribes the "old" socket from the complete set of events, this should @@ -288,10 +295,10 @@ def upgrade(self, key_file = None, cer_file = None, ca_file = None, server = Tru # encapsulated SSL socket is then set as the current connection's socket self.socket = self.owner._ssl_upgrade( self.socket, - key_file = key_file, - cer_file = cer_file, - ca_file = ca_file, - server = server + key_file=key_file, + cer_file=cer_file, + ca_file=ca_file, + server=server, ) # updates the current socket in the connection resolution map with the new SSL one @@ -319,14 +326,16 @@ def set_upgraded(self): self.upgrading = False self.trigger("upgrade", self) - def set_data(self, data, address = None): - if address: self.trigger("data", self, data, address) - else: self.trigger("data", self, data) + def set_data(self, data, address=None): + if address: + self.trigger("data", self, data, address) + else: + self.trigger("data", self, data) def set_exception(self, exception): self.trigger("exception", self, exception) - def ensure_write(self, flush = True): + def ensure_write(self, flush=True): # retrieves the identifier of the current thread and # checks if it's the same as the one defined in the # owner in case it's not then the operation is not @@ -340,24 +349,28 @@ def ensure_write(self, flush = True): # safe and so it must be delayed to be executed in the # next loop of the thread cycle, must return immediately # to avoid extra subscription operations - if not is_safe: return self.owner.delay(self.ensure_write, safe = True) + if not is_safe: + return self.owner.delay(self.ensure_write, safe=True) # verifies if the status of the connection is open and # in case it's not returned immediately as there's no reason # to so it for writing - if not self.status == OPEN: return + if not self.status == OPEN: + return # in case the write ready flag is enabled (writes allowed) # and the flush parameter is set to the ensure operation performs # the flush of the write operations (instead of subscription) # this may be done because it's safe to flush the write operations # when the write ready flag for the connection is set - if self.wready and flush: return self._flush_write() + if self.wready and flush: + return self._flush_write() # verifies if the owner object is already subscribed for the # write operation in case it is returned immediately in order # avoid any extra subscription operation - if self.owner.is_sub_write(self.socket): return + if self.owner.is_sub_write(self.socket): + return # adds the current socket to the list of write operations # so that it's going to be available for writing as soon @@ -365,7 +378,8 @@ def ensure_write(self, flush = True): self.owner.sub_write(self.socket) def remove_write(self): - if not self.status == OPEN: return + if not self.status == OPEN: + return self.owner.unsub_write(self.socket) def enable_read(self): @@ -379,8 +393,10 @@ def enable_read(self): to stall if misused. """ - if not self.status == OPEN: return - if not self.renable == False: return + if not self.status == OPEN: + return + if not self.renable == False: + return self.renable = True self.owner.sub_read(self.socket) @@ -395,13 +411,15 @@ def disable_read(self): of the event poll is required to avoid stalling. """ - if not self.status == OPEN: return - if not self.renable == True: return + if not self.status == OPEN: + return + if not self.renable == True: + return self.renable = False self.owner.unsub_read(self.socket) - def send(self, data, address = None, delay = True, force = False, callback = None): + def send(self, data, address=None, delay=True, force=False, callback=None): """ The main send call is to be used by a proxy connection and from different threads. @@ -457,7 +475,8 @@ def send(self, data, address = None, delay = True, force = False, callback = Non # verifies that the connection is currently in the open # state and then verifies if that's not the case returns # immediately, not possible to send data - if not self.status == OPEN and not force: return 0 + if not self.status == OPEN and not force: + return 0 # creates the tuple that is going to represent the data # to be sent, this tuple should contain the data itself @@ -486,33 +505,36 @@ def send(self, data, address = None, delay = True, force = False, callback = Non # the next tick operation (delayed execution), note that # running the flush operation immediately may lead to # typical stack overflow errors (due to recursion limit) - if is_safe and not delay: self._flush_write() - else: self.owner.delay( - self._flush_write, - immediately = True, - verify = True, - safe = True - ) + if is_safe and not delay: + self._flush_write() + else: + self.owner.delay( + self._flush_write, immediately=True, verify=True, safe=True + ) # otherwise the write stream is not ready and so the # connection must be ensured to write ready, should # subscribe to the write events as soon as possible - else: self.ensure_write() + else: + self.ensure_write() # returns the final number of bytes (length of data) # that has been submitted to be sent (as soon as possible) return data_l - def recv(self, size = CHUNK_SIZE, force = False): - if not self.status == OPEN and not force: return b"" - return self._recv(size = size) + def recv(self, size=CHUNK_SIZE, force=False): + if not self.status == OPEN and not force: + return b"" + return self._recv(size=size) - def pend(self, data, back = True): + def pend(self, data, back=True): # verifies if the provided data is a tuple and if that's # the case unpacks the callback value from it, required is_tuple = type(data) == tuple - if is_tuple: data_b, _address, _callback = data - else: data_b = data + if is_tuple: + data_b, _address, _callback = data + else: + data_b = data # calculates the size in bytes of the provided data so # that it may be used later for the incrementing of @@ -526,8 +548,10 @@ def pend(self, data, back = True): # that the FIFO strategy is maintained self.pending_lock.acquire() try: - if back: self.pending.appendleft(data) - else: self.pending.append(data) + if back: + self.pending.appendleft(data) + else: + self.pending.append(data) finally: self.pending_lock.release() @@ -540,7 +564,7 @@ def pend(self, data, back = True): # some of the flow controlling operation may have to be performed self.trigger("pend", self) - def restore(self, data, back = True): + def restore(self, data, back=True): """ Restore data to the pending (to receive) so that they are going to be "received" in the next receive operation. @@ -564,8 +588,10 @@ def restore(self, data, back = True): # going to be used in the next receive operation self.restored_lock.acquire() try: - if back: self.restored.appendleft(data) - else: self.restored.append(data) + if back: + self.restored.appendleft(data) + else: + self.restored.append(data) finally: self.restored_lock.release() @@ -584,14 +610,18 @@ def run_starter(self): # used in the current iteration, either from the currently # pending operation, next in line or as fallback an invalid # one, that is going to invalidate the iteration - if self._starter: starter = self._starter - elif self.starters: starter = self.starters.pop() - else: starter = None + if self._starter: + starter = self._starter + elif self.starters: + starter = self.starters.pop() + else: + starter = None # in case there's no starter pending and no other is set # in the queue for the connection (to be executed next) # breaks the current loop (nothing left to be done) - if not starter: break + if not starter: + break # sets the current starter as the starter currently selected # by the loop set of operations @@ -602,7 +632,8 @@ def run_starter(self): # that's note the case breaks the loop, as the stater should # finish on the next loop tick finished = self._starter == None - if not finished: return True + if not finished: + return True # returns the default invalid value, meaning that no more starter # operations are pending for the the current connection @@ -611,62 +642,72 @@ def run_starter(self): def end_starter(self): self._starter = None - def add_starter(self, starter, back = True): - if back: self.starters.appendleft(starter) - else: self.starters.append(starter) + def add_starter(self, starter, back=True): + if back: + self.starters.appendleft(starter) + else: + self.starters.append(starter) def remove_starter(self, starter): self.starters.remove(starter) - def info_dict(self, full = False): + def info_dict(self, full=False): info = dict( - status = self.status, - id = self.id, - connecting = self.connecting, - upgrading = self.upgrading, - address = self.address, - ssl = self.ssl, - renable = self.renable, - wready = self.wready, - pending_s = self.pending_s, - restored_s = self.restored_s + status=self.status, + id=self.id, + connecting=self.connecting, + upgrading=self.upgrading, + address=self.address, + ssl=self.ssl, + renable=self.renable, + wready=self.wready, + pending_s=self.pending_s, + restored_s=self.restored_s, ) return info - def ssl_certificate(self, binary = False): - if not self.ssl: return None - return self.socket.getpeercert(binary_form = binary) + def ssl_certificate(self, binary=False): + if not self.ssl: + return None + return self.socket.getpeercert(binary_form=binary) - def ssl_verify_host(self, host = None): + def ssl_verify_host(self, host=None): host = host or self.ssl_host - if not host: return + if not host: + return certificate = self.ssl_certificate() tls.match_hostname(certificate, host) - def ssl_verify_fingerprint(self, fingerprint = None): + def ssl_verify_fingerprint(self, fingerprint=None): fingerprint = fingerprint or self.ssl_fingerprint - if not fingerprint: return - certificate = self.ssl_certificate(binary = True) + if not fingerprint: + return + certificate = self.ssl_certificate(binary=True) tls.match_fingerprint(certificate, fingerprint) - def ssl_dump_certificate(self, dump = False): + def ssl_dump_certificate(self, dump=False): dump = dump or self.ssl_dump - if not dump: return + if not dump: + return certificate = self.ssl_certificate() - certificate_binary = self.ssl_certificate(binary = True) + certificate_binary = self.ssl_certificate(binary=True) tls.dump_certificate(certificate, certificate_binary) def ssl_protocol(self): return self.ssl_alpn_protocol() or self.ssl_npn_protocol() def ssl_alpn_protocol(self): - if not self.socket: return None - if not hasattr(self.socket, "selected_alpn_protocol"): return None + if not self.socket: + return None + if not hasattr(self.socket, "selected_alpn_protocol"): + return None return self.socket.selected_alpn_protocol() def ssl_npn_protocol(self): - if not self.socket: return None - if not hasattr(self.socket, "selected_npn_protocol"): return None + if not self.socket: + return None + if not hasattr(self.socket, "selected_npn_protocol"): + return None return self.socket.selected_npn_protocol() def is_open(self): @@ -708,7 +749,8 @@ def _send(self): # it's not possible to perform a sending operation and the # send operation is ignored, note that the write-ready flag # is still set as it may be used later for flushing operations - if self.connecting: return + if self.connecting: + return # acquires the pending lock so that no other access to the # the pending structure is made from a different thread @@ -721,7 +763,8 @@ def _send(self): # verifies if there's data pending to be sent in case # there's not returns immediately, because there is # nothing pending to be done for such a case - if not self.pending: break + if not self.pending: + break # retrieves the current data chunk to be sent from the # list of pending things and then saves the data chunk @@ -741,10 +784,15 @@ def _send(self): # part of the data has been sent, note that if no # data is provided the shutdown operation is performed # instead of closing the stream between both sockets - if is_close: self._shutdown(); count = 0 - elif address: count = self.socket.sendto(data, address) - elif data: count = self.socket.send(data) - else: count = 0 + if is_close: + self._shutdown() + count = 0 + elif address: + count = self.socket.sendto(data, address) + elif data: + count = self.socket.send(data) + else: + count = 0 # verifies if the current situation is that of a non # closed socket and valid data, and if that's the case @@ -796,7 +844,8 @@ def _send(self): # in case the is valid flag is set (all of the data for # the current write operation has been sent) calls the # the associated callback (in case it exists) - if is_valid and callback: callback(self) + if is_valid and callback: + callback(self) finally: # releases the pending access lock so that no leaks # exists and no access to the pending is prevented @@ -808,31 +857,40 @@ def _send(self): def _recv(self, size): data = self._recv_restored(size) - if data: return data - if self.datagram: return self.socket.recvfrom(size) - else: return self.socket.recv(size) + if data: + return data + if self.datagram: + return self.socket.recvfrom(size) + else: + return self.socket.recv(size) def _recv_ssl(self, size): data = self._recv_restored(size) - if data: return data + if data: + return data has_socket = hasattr(self.socket, "_sock") - if has_socket: return self.socket._sock.recv(size) - else: return socket.socket.recv(self.socket, size) + if has_socket: + return self.socket._sock.recv(size) + else: + return socket.socket.recv(self.socket, size) def _recv_restored(self, size): - if not self.restored_s: return b"" + if not self.restored_s: + return b"" data = self.restored.pop() data = data[:size] remaining = data[size:] self.restored_s -= len(data) - if remaining: self.restore(remaining, back = False) + if remaining: + self.restore(remaining, back=False) return data - def _shutdown(self, close = False, force = False, ignore = True): + def _shutdown(self, close=False, force=False, ignore=True): # in case the status of the current connection is # already closed returns immediately as it's not # possible to shutdown a closed connection - if self.status == CLOSED: return + if self.status == CLOSED: + return try: # verifies the type of connection and takes that @@ -843,19 +901,22 @@ def _shutdown(self, close = False, force = False, ignore = True): # normal shutdown operation for the socket if self.ssl and hasattr(self.socket._sslobj, "shutdown"): self.socket._sslobj.shutdown() - if force: self.socket.shutdown(socket.SHUT_RDWR) + if force: + self.socket.shutdown(socket.SHUT_RDWR) except (IOError, socket.error, ssl.SSLError): # ignores the IO/SSL error that has just been raised, this # assumes that the problem that has just occurred is not # relevant as the socket is shutting down and if a problem # occurs that must be related to the socket being closed # on the other side of the connection - if not ignore: raise + if not ignore: + raise # in case the close (connection) flag is active the # current connection should be closed immediately # following the successful shutdown of the socket - if close: self.close() + if close: + self.close() def _close_callback(self, connection): """ @@ -877,7 +938,8 @@ def _flush_write(self): pending for the current connection's socket. """ - self.owner.writes((self.socket,), state = False) + self.owner.writes((self.socket,), state=False) + class DiagConnection(BaseConnection): @@ -901,17 +963,18 @@ def send(self, data, *args, **kwargs): self.sends += 1 return result - def info_dict(self, full = False): - info = BaseConnection.info_dict(self, full = full) + def info_dict(self, full=False): + info = BaseConnection.info_dict(self, full=full) info.update( - uptime = self._uptime(), - recvs = self.recvs, - sends = self.sends, - in_bytes = self.in_bytes, - out_bytes = self.out_bytes + uptime=self._uptime(), + recvs=self.recvs, + sends=self.sends, + in_bytes=self.in_bytes, + out_bytes=self.out_bytes, ) geo = self._resolve(self.address) - if geo: info["geo"] = geo + if geo: + info["geo"] = geo return info def _uptime(self): @@ -922,9 +985,13 @@ def _uptime(self): def _resolve(self, address): import netius.common + ip, _port = address return netius.common.GeoResolver.resolve(ip) -is_diag = config.conf("DIAG", False, cast = bool) -if is_diag: Connection = DiagConnection -else: Connection = BaseConnection + +is_diag = config.conf("DIAG", False, cast=bool) +if is_diag: + Connection = DiagConnection +else: + Connection = BaseConnection diff --git a/src/netius/base/container.py b/src/netius/base/container.py index 12f8b4a7e..d3d3338af 100644 --- a/src/netius/base/container.py +++ b/src/netius/base/container.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -39,7 +30,8 @@ from . import server -from .common import * #@UnusedWildImport +from .common import * # @UnusedWildImport + class Container(Base): @@ -84,7 +76,8 @@ def cleanup(self): # iterates over all the bases registered and propagates the cleanup operation # over them, deleting the list of bases afterwards (no more usage for them) - for base in self.bases: base.cleanup() + for base in self.bases: + base.cleanup() del self.bases[:] # unbinds the start operation from the on start event, as this is no longer @@ -119,29 +112,33 @@ def loop(self): def ticks(self): self.set_state(STATE_TICK) self._lid = (self._lid + 1) % 2147483647 - for base in self.bases: base.ticks() + for base in self.bases: + base.ticks() - def connections_dict(self, full = False): + def connections_dict(self, full=False): all = dict() for base in self.bases: is_owner = base == self.owner - if is_owner: connections = base.connections_dict( - full = full, parent = True - ) - else: connections = base.connections_dict(full = full) + if is_owner: + connections = base.connections_dict(full=full, parent=True) + else: + connections = base.connections_dict(full=full) all[base.name] = connections return all - def connection_dict(self, id, full = False): + def connection_dict(self, id, full=False): connection = None for base in self.bases: for _connection in base.connections: - if not _connection.id == id: continue + if not _connection.id == id: + continue connection = _connection break - if connection: break - if not connection: return None - return connection.info_dict(full = full) + if connection: + break + if not connection: + return None + return connection.info_dict(full=full) def on_start(self): Base.on_start(self) @@ -165,10 +162,12 @@ def start_base(self, base): base.load() def start_all(self): - for base in self.bases: self.start_base(base) + for base in self.bases: + self.start_base(base) def apply_all(self): - for base in self.bases: self.apply_base(base) + for base in self.bases: + self.apply_base(base) def apply_base(self, base): base.tid = self.tid @@ -183,7 +182,9 @@ def call_all(self, name, *args, **kwargs): method(*args, **kwargs) def trigger_all(self, name, *args, **kwargs): - for base in self.bases: base.trigger(name, base, *args, **kwargs) + for base in self.bases: + base.trigger(name, base, *args, **kwargs) + class ContainerServer(server.StreamServer): @@ -202,7 +203,8 @@ def stop(self): # verifies if there's a container object currently defined in # the object and in case it does exist propagates the stop call # to the container so that the proper stop operation is performed - if not self.container: return + if not self.container: + return self.container.stop() def cleanup(self): @@ -215,7 +217,8 @@ def cleanup(self): # verifies if the container is valid and if that's not the case # returns the control flow immediately (as expected) - if not container: return + if not container: + return # runs the cleanup operation on the cleanup, this should properly # propagate the operation to the owner container (as expected) diff --git a/src/netius/base/diag.py b/src/netius/base/diag.py index 560b09315..9b146e3f7 100644 --- a/src/netius/base/diag.py +++ b/src/netius/base/diag.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -42,29 +33,26 @@ try: import appier + loaded = True except ImportError: import netius.mock + appier = netius.mock.appier loaded = False + class DiagApp(appier.APIApp): def __init__(self, system, *args, **kwargs): - appier.APIApp.__init__( - self, - name = "diag", - *args, **kwargs - ) + appier.APIApp.__init__(self, name="diag", *args, **kwargs) self.system = system @appier.route("/logger", "GET") def show_logger(self): level = self.system.logger.level level = logging.getLevelName(level) - return dict( - level = level - ) + return dict(level=level) @appier.route("/logger/set", ("GET", "POST")) def set_logger(self): @@ -74,22 +62,22 @@ def set_logger(self): @appier.route("/environ", "GET") def show_environ(self): - return self.json(dict(os.environ), sort_keys = True) + return self.json(dict(os.environ), sort_keys=True) @appier.route("/info", "GET") def system_info(self): - full = self.field("full", True, cast = bool) - info = self.system.info_dict(full = full) - return self.json(info, sort_keys = True) + full = self.field("full", True, cast=bool) + info = self.system.info_dict(full=full) + return self.json(info, sort_keys=True) @appier.route("/connections", "GET") def list_connections(self): - full = self.field("full", True, cast = bool) - info = self.system.connections_dict(full = full) - return self.json(info, sort_keys = True) + full = self.field("full", True, cast=bool) + info = self.system.connections_dict(full=full) + return self.json(info, sort_keys=True) @appier.route("/connections/", "GET") def show_connection(self, id): - full = self.field("full", True, cast = bool) - info = self.system.connection_dict(id, full = full) - return self.json(info, sort_keys = True) + full = self.field("full", True, cast=bool) + info = self.system.connection_dict(id, full=full) + return self.json(info, sort_keys=True) diff --git a/src/netius/base/errors.py b/src/netius/base/errors.py index 553f9b2ed..e9dab4e73 100644 --- a/src/netius/base/errors.py +++ b/src/netius/base/errors.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -39,6 +30,7 @@ import uuid + class NetiusError(Exception): """ The top level base error to be used in the @@ -61,15 +53,17 @@ def __init__(self, *args, **kwargs): self.details = kwargs["details"] self._uid = None - def get_kwarg(self, name, default = None): + def get_kwarg(self, name, default=None): return self.kwargs.get(name, default) @property def uid(self): - if self._uid: return self._uid + if self._uid: + return self._uid self._uid = uuid.uuid4() return self._uid + class RuntimeError(NetiusError): """ Error to be used for situations where an exception @@ -81,6 +75,7 @@ class in every exception raised during normal execution. pass + class StopError(RuntimeError): """ Error to be used for situations where a stop @@ -92,6 +87,7 @@ class StopError(RuntimeError): pass + class PauseError(RuntimeError): """ Error to be used for situations where a pause @@ -103,6 +99,7 @@ class PauseError(RuntimeError): pass + class WakeupError(RuntimeError): """ Error used to send a wakeup intent from one context @@ -114,6 +111,7 @@ class WakeupError(RuntimeError): pass + class DataError(RuntimeError): """ Error to be used for situations where the @@ -126,6 +124,7 @@ class DataError(RuntimeError): pass + class ParserError(RuntimeError): """ Error caused by a malformed data that invalidated @@ -140,6 +139,7 @@ def __init__(self, *args, **kwargs): kwargs["code"] = kwargs.get("code", 400) RuntimeError.__init__(self, *args, **kwargs) + class GeneratorError(RuntimeError): """ Error generated by a problem in the generation of @@ -151,6 +151,7 @@ class GeneratorError(RuntimeError): pass + class SecurityError(RuntimeError): """ Error caused by a failed security verification this @@ -163,6 +164,7 @@ class SecurityError(RuntimeError): pass + class NotImplemented(RuntimeError): """ Error caused by the non implementation of a certain @@ -176,6 +178,7 @@ class NotImplemented(RuntimeError): pass + class AssertionError(RuntimeError): """ Error raised for failure to meet any pre-condition or diff --git a/src/netius/base/legacy.py b/src/netius/base/legacy.py index 6140206da..bf15ee4dc 100644 --- a/src/netius/base/legacy.py +++ b/src/netius/base/legacy.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -47,71 +38,116 @@ import contextlib import collections -import urllib #@UnusedImport +import urllib # @UnusedImport + +ArgSpec = collections.namedtuple("ArgSpec", ["args", "varargs", "keywords", "defaults"]) -ArgSpec = collections.namedtuple( - "ArgSpec", - ["args", "varargs", "keywords", "defaults"] -) @contextlib.contextmanager def ctx_absolute(): root = sys.path.pop(0) - try: yield - finally: sys.path.insert(0, root) + try: + yield + finally: + sys.path.insert(0, root) + with ctx_absolute(): - try: import urllib2 - except ImportError: urllib2 = None + try: + import urllib2 + except ImportError: + urllib2 = None with ctx_absolute(): - try: import httplib - except ImportError: httplib = None + try: + import httplib + except ImportError: + httplib = None with ctx_absolute(): - try: import http - except ImportError: http = None + try: + import http + except ImportError: + http = None with ctx_absolute(): - try: import types - except ImportError: types = None + try: + import types + except ImportError: + types = None with ctx_absolute(): - try: import urllib.error - except ImportError: pass + try: + import urllib.error + except ImportError: + pass with ctx_absolute(): - try: import urllib.request - except ImportError: pass + try: + import urllib.request + except ImportError: + pass with ctx_absolute(): - try: import http.client - except ImportError: pass + try: + import http.client + except ImportError: + pass with ctx_absolute(): - try: import importlib.util - except ImportError: pass + try: + import importlib.util + except ImportError: + pass + +try: + import HTMLParser +except ImportError: + import html.parser + + HTMLParser = html.parser + +try: + import cPickle +except ImportError: + import pickle -try: import HTMLParser -except ImportError: import html.parser; HTMLParser = html.parser + cPickle = pickle -try: import cPickle -except ImportError: import pickle; cPickle = pickle +try: + import imp +except ImportError: + import importlib -try: import imp -except ImportError: import importlib; imp = importlib + imp = importlib -try: import importlib -except ImportError: import imp; importlib = imp +try: + import importlib +except ImportError: + import imp -try: import cStringIO -except ImportError: import io; cStringIO = io + importlib = imp -try: import StringIO as _StringIO -except ImportError: import io; _StringIO = io +try: + import cStringIO +except ImportError: + import io -try: import urlparse as _urlparse -except ImportError: import urllib.parse; _urlparse = urllib.parse + cStringIO = io + +try: + import StringIO as _StringIO +except ImportError: + import io + + _StringIO = io + +try: + import urlparse as _urlparse +except ImportError: + import urllib.parse + + _urlparse = urllib.parse PYTHON_3 = sys.version_info[0] >= 3 """ Global variable that defines if the current Python @@ -148,26 +184,40 @@ def ctx_absolute(): """ The Python version integer describing the version of a the interpreter as a set of three integer digits """ -if PYTHON_3: LONG = int -else: LONG = long #@UndefinedVariable - -if PYTHON_3: BYTES = bytes -else: BYTES = str #@UndefinedVariable - -if PYTHON_3: UNICODE = str -else: UNICODE = unicode #@UndefinedVariable - -if PYTHON_3: OLD_UNICODE = None -else: OLD_UNICODE = unicode #@UndefinedVariable - -if PYTHON_3: STRINGS = (str,) -else: STRINGS = (str, unicode) #@UndefinedVariable - -if PYTHON_3: ALL_STRINGS = (bytes, str) -else: ALL_STRINGS = (bytes, str, unicode) #@UndefinedVariable - -if PYTHON_3: INTEGERS = (int,) -else: INTEGERS = (int, long) #@UndefinedVariable +if PYTHON_3: + LONG = int +else: + LONG = long # @UndefinedVariable + +if PYTHON_3: + BYTES = bytes +else: + BYTES = str # @UndefinedVariable + +if PYTHON_3: + UNICODE = str +else: + UNICODE = unicode # @UndefinedVariable + +if PYTHON_3: + OLD_UNICODE = None +else: + OLD_UNICODE = unicode # @UndefinedVariable + +if PYTHON_3: + STRINGS = (str,) +else: + STRINGS = (str, unicode) # @UndefinedVariable + +if PYTHON_3: + ALL_STRINGS = (bytes, str) +else: + ALL_STRINGS = (bytes, str, unicode) # @UndefinedVariable + +if PYTHON_3: + INTEGERS = (int,) +else: + INTEGERS = (int, long) # @UndefinedVariable # saves a series of global symbols that are going to be # used latter for some of the legacy operations @@ -177,161 +227,247 @@ def ctx_absolute(): _bytes = bytes _range = range -try: _xrange = xrange #@UndefinedVariable -except Exception: _xrange = None - -if PYTHON_3: Request = urllib.request.Request -else: Request = urllib2.Request - -if PYTHON_3: HTTPHandler = urllib.request.HTTPHandler -else: HTTPHandler = urllib2.HTTPHandler - -if PYTHON_3: HTTPError = urllib.error.HTTPError -else: HTTPError = urllib2.HTTPError - -if PYTHON_3: HTTPConnection = http.client.HTTPConnection #@UndefinedVariable -else: HTTPConnection = httplib.HTTPConnection - -if PYTHON_3: HTTPSConnection = http.client.HTTPSConnection #@UndefinedVariable -else: HTTPSConnection = httplib.HTTPSConnection - -try: _execfile = execfile #@UndefinedVariable -except Exception: _execfile = None +try: + _xrange = xrange # @UndefinedVariable +except Exception: + _xrange = None + +if PYTHON_3: + Request = urllib.request.Request +else: + Request = urllib2.Request + +if PYTHON_3: + HTTPHandler = urllib.request.HTTPHandler +else: + HTTPHandler = urllib2.HTTPHandler + +if PYTHON_3: + HTTPError = urllib.error.HTTPError +else: + HTTPError = urllib2.HTTPError + +if PYTHON_3: + HTTPConnection = http.client.HTTPConnection # @UndefinedVariable +else: + HTTPConnection = httplib.HTTPConnection + +if PYTHON_3: + HTTPSConnection = http.client.HTTPSConnection # @UndefinedVariable +else: + HTTPSConnection = httplib.HTTPSConnection + +try: + _execfile = execfile # @UndefinedVariable +except Exception: + _execfile = None + +try: + _reduce = reduce # @UndefinedVariable +except Exception: + _reduce = None + +try: + _reload = reload # @UndefinedVariable +except Exception: + _reload = None + +try: + _unichr = unichr # @UndefinedVariable +except Exception: + _unichr = None -try: _reduce = reduce #@UndefinedVariable -except Exception: _reduce = None - -try: _reload = reload #@UndefinedVariable -except Exception: _reload = None - -try: _unichr = unichr #@UndefinedVariable -except Exception: _unichr = None def with_meta(meta, *bases): return meta("Class", bases, {}) + def eager(iterable): - if PYTHON_3: return list(iterable) + if PYTHON_3: + return list(iterable) return iterable + def iteritems(associative): - if PYTHON_3: return associative.items() + if PYTHON_3: + return associative.items() return associative.iteritems() + def iterkeys(associative): - if PYTHON_3: return associative.keys() + if PYTHON_3: + return associative.keys() return associative.iterkeys() + def itervalues(associative): - if PYTHON_3: return associative.values() + if PYTHON_3: + return associative.values() return associative.itervalues() + def items(associative): - if PYTHON_3: return eager(associative.items()) + if PYTHON_3: + return eager(associative.items()) return associative.items() + def keys(associative): - if PYTHON_3: return eager(associative.keys()) + if PYTHON_3: + return eager(associative.keys()) return associative.keys() + def values(associative): - if PYTHON_3: return eager(associative.values()) + if PYTHON_3: + return eager(associative.values()) return associative.values() -def xrange(start, stop = None, step = 1): - if PYTHON_3: return _range(start, stop, step) if stop else _range(start) + +def xrange(start, stop=None, step=1): + if PYTHON_3: + return _range(start, stop, step) if stop else _range(start) return _xrange(start, stop, step) if stop else _range(start) -def range(start, stop = None, step = None): - if PYTHON_3: return eager(_range(start, stop, step)) if stop else eager(_range(start)) + +def range(start, stop=None, step=None): + if PYTHON_3: + return eager(_range(start, stop, step)) if stop else eager(_range(start)) return _range(start, stop, step) if stop else _range(start) + def ord(value): - if PYTHON_3 and type(value) == int: return value + if PYTHON_3 and type(value) == int: + return value return _ord(value) + def chr(value): - if PYTHON_3: return _bytes([value]) - if type(value) in INTEGERS: return _chr(value) + if PYTHON_3: + return _bytes([value]) + if type(value) in INTEGERS: + return _chr(value) return value + def chri(value): - if PYTHON_3: return value - if type(value) in INTEGERS: return _chr(value) + if PYTHON_3: + return value + if type(value) in INTEGERS: + return _chr(value) return value -def bytes(value, encoding = "latin-1", errors = "strict", force = False): - if not PYTHON_3 and not force: return value - if value == None: return value - if type(value) == _bytes: return value + +def bytes(value, encoding="latin-1", errors="strict", force=False): + if not PYTHON_3 and not force: + return value + if value == None: + return value + if type(value) == _bytes: + return value return value.encode(encoding, errors) -def str(value, encoding = "latin-1", errors = "strict", force = False): - if not PYTHON_3 and not force: return value - if value == None: return value - if type(value) in STRINGS: return value + +def str(value, encoding="latin-1", errors="strict", force=False): + if not PYTHON_3 and not force: + return value + if value == None: + return value + if type(value) in STRINGS: + return value return value.decode(encoding, errors) -def u(value, encoding = "utf-8", errors = "strict", force = False): - if PYTHON_3 and not force: return value - if value == None: return value - if type(value) == UNICODE: return value + +def u(value, encoding="utf-8", errors="strict", force=False): + if PYTHON_3 and not force: + return value + if value == None: + return value + if type(value) == UNICODE: + return value return value.decode(encoding, errors) -def ascii(value, encoding = "utf-8", errors = "replace"): - if is_bytes(value): value = value.decode(encoding, errors) - else: value = UNICODE(value) + +def ascii(value, encoding="utf-8", errors="replace"): + if is_bytes(value): + value = value.decode(encoding, errors) + else: + value = UNICODE(value) value = value.encode("ascii", errors) value = str(value) return value + def orderable(value): - if not PYTHON_3: return value + if not PYTHON_3: + return value return Orderable(value) + def is_str(value): return type(value) == _str + def is_unicode(value): - if PYTHON_3: return type(value) == _str - else: return type(value) == unicode #@UndefinedVariable + if PYTHON_3: + return type(value) == _str + else: + return type(value) == unicode # @UndefinedVariable + def is_bytes(value): - if PYTHON_3: return type(value) == _bytes - else: return type(value) == _str #@UndefinedVariable + if PYTHON_3: + return type(value) == _bytes + else: + return type(value) == _str # @UndefinedVariable + -def is_string(value, all = False): +def is_string(value, all=False): target = ALL_STRINGS if all else STRINGS return type(value) in target + def is_generator(value): - if inspect.isgenerator(value): return True - if type(value) in (itertools.chain,): return True - if hasattr(value, "_is_generator"): return True + if inspect.isgenerator(value): + return True + if type(value) in (itertools.chain,): + return True + if hasattr(value, "_is_generator"): + return True return False + def is_async_generator(value): - if not hasattr(inspect, "isasyncgen"): return False + if not hasattr(inspect, "isasyncgen"): + return False return inspect.isasyncgen(value) -def is_unittest(name = "unittest"): + +def is_unittest(name="unittest"): current_stack = inspect.stack() for stack_frame in current_stack: for program_line in stack_frame[4]: is_unittest = not name in program_line - if is_unittest: continue + if is_unittest: + continue return True return False -def execfile(path, global_vars, local_vars = None, encoding = "utf-8"): - if local_vars == None: local_vars = global_vars - if not PYTHON_3: return _execfile(path, global_vars, local_vars) + +def execfile(path, global_vars, local_vars=None, encoding="utf-8"): + if local_vars == None: + local_vars = global_vars + if not PYTHON_3: + return _execfile(path, global_vars, local_vars) file = open(path, "rb") - try: data = file.read() - finally: file.close() + try: + data = file.read() + finally: + file.close() data = data.decode(encoding) code = compile(data, path, "exec") - exec(code, global_vars, local_vars) #@UndefinedVariable + exec(code, global_vars, local_vars) # @UndefinedVariable + def walk(path, visit, arg): for root, dirs, _files in os.walk(path): @@ -341,22 +477,33 @@ def walk(path, visit, arg): exists = dir in names not exists and dirs.remove(dir) + def getargspec(func): has_full = hasattr(inspect, "getfullargspec") - if has_full: return ArgSpec(*inspect.getfullargspec(func)[:4]) - else: return inspect.getargspec(func) + if has_full: + return ArgSpec(*inspect.getfullargspec(func)[:4]) + else: + return inspect.getargspec(func) + def has_module(name): if PYTHON_3: - try: spec = importlib.util.find_spec(name) - except ImportError: return False - if spec == None: return False + try: + spec = importlib.util.find_spec(name) + except ImportError: + return False + if spec == None: + return False return True - try: file, _path, _description = imp.find_module(name) - except ImportError: return False - if file: file.close() + try: + file, _path, _description = imp.find_module(name) + except ImportError: + return False + if file: + file.close() return True + def new_module(name): if hasattr(types, "ModuleType"): return types.ModuleType(name) @@ -364,25 +511,38 @@ def new_module(name): return imp.new_module(name) raise ValueError("No module build method available") + def reduce(*args, **kwargs): - if PYTHON_3: return functools.reduce(*args, **kwargs) + if PYTHON_3: + return functools.reduce(*args, **kwargs) return _reduce(*args, **kwargs) + def reload(*args, **kwargs): - if PYTHON_3: return importlib.reload(*args, **kwargs) + if PYTHON_3: + return importlib.reload(*args, **kwargs) return _reload(*args, **kwargs) + def unichr(*args, **kwargs): - if PYTHON_3: return _chr(*args, **kwargs) + if PYTHON_3: + return _chr(*args, **kwargs) return _unichr(*args, **kwargs) + def urlopen(*args, **kwargs): - if PYTHON_3: return urllib.request.urlopen(*args, **kwargs) - else: return urllib2.urlopen(*args, **kwargs) #@UndefinedVariable + if PYTHON_3: + return urllib.request.urlopen(*args, **kwargs) + else: + return urllib2.urlopen(*args, **kwargs) # @UndefinedVariable + def build_opener(*args, **kwargs): - if PYTHON_3: return urllib.request.build_opener(*args, **kwargs) - else: return urllib2.build_opener(*args, **kwargs) #@UndefinedVariable + if PYTHON_3: + return urllib.request.build_opener(*args, **kwargs) + else: + return urllib2.build_opener(*args, **kwargs) # @UndefinedVariable + def to_timestamp(date_time): if PYTHON_33: @@ -390,6 +550,7 @@ def to_timestamp(date_time): else: return calendar.timegm(date_time.utctimetuple()) + def to_datetime(timestamp): if PYTHON_33: return datetime.datetime.fromtimestamp( @@ -398,63 +559,99 @@ def to_datetime(timestamp): else: return datetime.datetime.utcfromtimestamp(timestamp) + def utcfromtimestamp(timestamp): return to_datetime(timestamp) + def utc_now(): if PYTHON_33: return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) else: return datetime.datetime.utcnow() + def urlparse(*args, **kwargs): return _urlparse.urlparse(*args, **kwargs) + def urlunparse(*args, **kwargs): return _urlparse.urlunparse(*args, **kwargs) + def parse_qs(*args, **kwargs): return _urlparse.parse_qs(*args, **kwargs) + def urlencode(*args, **kwargs): - if PYTHON_3: return urllib.parse.urlencode(*args, **kwargs) - else: return urllib.urlencode(*args, **kwargs) #@UndefinedVariable + if PYTHON_3: + return urllib.parse.urlencode(*args, **kwargs) + else: + return urllib.urlencode(*args, **kwargs) # @UndefinedVariable + def quote(*args, **kwargs): - if PYTHON_3: return urllib.parse.quote(*args, **kwargs) - else: return urllib.quote(*args, **kwargs) #@UndefinedVariable + if PYTHON_3: + return urllib.parse.quote(*args, **kwargs) + else: + return urllib.quote(*args, **kwargs) # @UndefinedVariable + def quote_plus(*args, **kwargs): - if PYTHON_3: return urllib.parse.quote_plus(*args, **kwargs) - else: return urllib.quote_plus(*args, **kwargs) #@UndefinedVariable + if PYTHON_3: + return urllib.parse.quote_plus(*args, **kwargs) + else: + return urllib.quote_plus(*args, **kwargs) # @UndefinedVariable + def unquote(*args, **kwargs): - if PYTHON_3: return urllib.parse.unquote(*args, **kwargs) - else: return urllib.unquote(*args, **kwargs) #@UndefinedVariable + if PYTHON_3: + return urllib.parse.unquote(*args, **kwargs) + else: + return urllib.unquote(*args, **kwargs) # @UndefinedVariable + def unquote_plus(*args, **kwargs): - if PYTHON_3: return urllib.parse.unquote_plus(*args, **kwargs) - else: return urllib.unquote_plus(*args, **kwargs) #@UndefinedVariable + if PYTHON_3: + return urllib.parse.unquote_plus(*args, **kwargs) + else: + return urllib.unquote_plus(*args, **kwargs) # @UndefinedVariable + def cmp_to_key(*args, **kwargs): - if PYTHON_3: return dict(key = functools.cmp_to_key(*args, **kwargs)) #@UndefinedVariable - else: return dict(cmp = args[0]) + if PYTHON_3: + return dict(key=functools.cmp_to_key(*args, **kwargs)) # @UndefinedVariable + else: + return dict(cmp=args[0]) + def tobytes(self, *args, **kwargs): - if PYTHON_3: return self.tobytes(*args, **kwargs) - else: return self.tostring(*args, **kwargs) + if PYTHON_3: + return self.tobytes(*args, **kwargs) + else: + return self.tostring(*args, **kwargs) + def tostring(self, *args, **kwargs): - if PYTHON_3: return self.tobytes(*args, **kwargs) - else: return self.tostring(*args, **kwargs) + if PYTHON_3: + return self.tobytes(*args, **kwargs) + else: + return self.tostring(*args, **kwargs) + def StringIO(*args, **kwargs): - if PYTHON_3: return cStringIO.StringIO(*args, **kwargs) - else: return _StringIO.StringIO(*args, **kwargs) + if PYTHON_3: + return cStringIO.StringIO(*args, **kwargs) + else: + return _StringIO.StringIO(*args, **kwargs) + def BytesIO(*args, **kwargs): - if PYTHON_3: return cStringIO.BytesIO(*args, **kwargs) - else: return cStringIO.StringIO(*args, **kwargs) + if PYTHON_3: + return cStringIO.BytesIO(*args, **kwargs) + else: + return cStringIO.StringIO(*args, **kwargs) + class Orderable(tuple): """ diff --git a/src/netius/base/log.py b/src/netius/base/log.py index b1c67f71f..5b276fb3a 100644 --- a/src/netius/base/log.py +++ b/src/netius/base/log.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -46,49 +37,45 @@ or an handler, this is used as an utility for debugging purposes more that a real feature for production systems """ + def rotating_handler( - path = "netius.log", - max_bytes = 1048576, - max_log = 5, - encoding = None, - delay = False + path="netius.log", max_bytes=1048576, max_log=5, encoding=None, delay=False ): return logging.handlers.RotatingFileHandler( - path, - maxBytes = max_bytes, - backupCount = max_log, - encoding = encoding, - delay = delay + path, maxBytes=max_bytes, backupCount=max_log, encoding=encoding, delay=delay ) + def smtp_handler( - host = "localhost", - port = 25, - sender = "no-reply@netius.com", - receivers = [], - subject = "Netius logging", - username = None, - password = None, - stls = False + host="localhost", + port=25, + sender="no-reply@netius.com", + receivers=[], + subject="Netius logging", + username=None, + password=None, + stls=False, ): address = (host, port) - if username and password: credentials = (username, password) - else: credentials = None + if username and password: + credentials = (username, password) + else: + credentials = None has_secure = in_signature(logging.handlers.SMTPHandler.__init__, "secure") - if has_secure: kwargs = dict(secure = () if stls else None) - else: kwargs = dict() + if has_secure: + kwargs = dict(secure=() if stls else None) + else: + kwargs = dict() return logging.handlers.SMTPHandler( - address, - sender, - receivers, - subject, - credentials = credentials, - **kwargs + address, sender, receivers, subject, credentials=credentials, **kwargs ) + def in_signature(callable, name): has_full = hasattr(inspect, "getfullargspec") - if has_full: spec = inspect.getfullargspec(callable) - else: spec = inspect.getargspec(callable) + if has_full: + spec = inspect.getfullargspec(callable) + else: + spec = inspect.getargspec(callable) args, _varargs, kwargs = spec[:3] return (args and name in args) or (kwargs and "secure" in kwargs) diff --git a/src/netius/base/observer.py b/src/netius/base/observer.py index 54523e23f..be51a83b9 100644 --- a/src/netius/base/observer.py +++ b/src/netius/base/observer.py @@ -22,21 +22,13 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ __license__ = "Apache License, Version 2.0" """ The license for the module """ + class Observable(object): """ The base class that implements the observable @@ -56,32 +48,43 @@ def build(self): def destroy(self): self.unbind_all() - def bind(self, name, method, oneshot = False): - if oneshot: method.oneshot = oneshot + def bind(self, name, method, oneshot=False): + if oneshot: + method.oneshot = oneshot methods = self.events.get(name, []) methods.append(method) self.events[name] = methods - def unbind(self, name, method = None): + def unbind(self, name, method=None): methods = self.events.get(name, None) - if not methods: return - if method: methods.remove(method) - else: del methods[:] + if not methods: + return + if method: + methods.remove(method) + else: + del methods[:] def unbind_all(self): - if not hasattr(self, "events"): return - for methods in self.events.values(): del methods[:] + if not hasattr(self, "events"): + return + for methods in self.events.values(): + del methods[:] self.events.clear() def trigger(self, name, *args, **kwargs): methods = self.events.get(name, None) - if not methods: return + if not methods: + return oneshots = None for method in methods: method(*args, **kwargs) - if not hasattr(method, "oneshot"): continue - if not method.oneshot: continue + if not hasattr(method, "oneshot"): + continue + if not method.oneshot: + continue oneshots = [] if oneshots == None else oneshots oneshots.append(method) - if not oneshots: return - for oneshot in oneshots: self.unbind(name, oneshot) + if not oneshots: + return + for oneshot in oneshots: + self.unbind(name, oneshot) diff --git a/src/netius/base/poll.py b/src/netius/base/poll.py index cd7885a9c..7ad3ac3eb 100644 --- a/src/netius/base/poll.py +++ b/src/netius/base/poll.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -45,6 +36,7 @@ this should be considered the maximum amount of time a thread waits for a poll request """ + class Poll(object): """ The top level abstract implementation of a poll object @@ -70,8 +62,9 @@ def name(cls): def test(cls): return True - def open(self, timeout = POLL_TIMEOUT): - if self._open: return + def open(self, timeout=POLL_TIMEOUT): + if self._open: + return self._open = True self.timeout = timeout @@ -80,7 +73,8 @@ def open(self, timeout = POLL_TIMEOUT): self.error_o.clear() def close(self): - if not self._open: return + if not self._open: + return self._open = False self.read_o.clear() @@ -130,10 +124,10 @@ def is_edge(self): def is_empty(self): return not self.read_o and not self.write_o and not self.error_o - def sub_all(self, socket, owner = None): - self.sub_read(socket, owner = owner) - self.sub_write(socket, owner = owner) - self.sub_error(socket, owner = owner) + def sub_all(self, socket, owner=None): + self.sub_read(socket, owner=owner) + self.sub_write(socket, owner=owner) + self.sub_error(socket, owner=owner) def unsub_all(self, socket): self.unsub_error(socket) @@ -149,30 +143,37 @@ def is_sub_write(self, socket): def is_sub_error(self, socket): return socket in self.error_o - def sub_read(self, socket, owner = None): - if socket in self.read_o: return + def sub_read(self, socket, owner=None): + if socket in self.read_o: + return self.read_o[socket] = owner - def sub_write(self, socket, owner = None): - if socket in self.write_o: return + def sub_write(self, socket, owner=None): + if socket in self.write_o: + return self.write_o[socket] = owner - def sub_error(self, socket, owner = None): - if socket in self.error_o: return + def sub_error(self, socket, owner=None): + if socket in self.error_o: + return self.error_o[socket] = owner def unsub_read(self, socket): - if not socket in self.read_o: return + if not socket in self.read_o: + return del self.read_o[socket] def unsub_write(self, socket): - if not socket in self.write_o: return + if not socket in self.write_o: + return del self.write_o[socket] def unsub_error(self, socket): - if not socket in self.error_o: return + if not socket in self.error_o: + return del self.error_o[socket] + class EpollPoll(Poll): def __init__(self, *args, **kwargs): @@ -183,12 +184,13 @@ def __init__(self, *args, **kwargs): def test(cls): return hasattr(select, "epoll") - def open(self, timeout = POLL_TIMEOUT): - if self._open: return + def open(self, timeout=POLL_TIMEOUT): + if self._open: + return self._open = True self.timeout = timeout - self.epoll = select.epoll() #@UndefinedVariable + self.epoll = select.epoll() # @UndefinedVariable self.fd_m = {} @@ -197,10 +199,12 @@ def open(self, timeout = POLL_TIMEOUT): self.error_o = {} def close(self): - if not self._open: return + if not self._open: + return self._open = False - for fd in self.fd_m: self.epoll.unregister(fd) + for fd in self.fd_m: + self.epoll.unregister(fd) self.epoll.close() self.epoll = None @@ -215,13 +219,13 @@ def poll(self): events = self.epoll.poll(self.timeout) for fd, event in events: - if event & select.EPOLLIN: #@UndefinedVariable + if event & select.EPOLLIN: # @UndefinedVariable socket = self.fd_m.get(fd, None) socket and result[0].append(socket) - if event & select.EPOLLOUT: #@UndefinedVariable + if event & select.EPOLLOUT: # @UndefinedVariable socket = self.fd_m.get(fd, None) socket and result[1].append(socket) - if event & select.EPOLLERR or event & select.EPOLLHUP: #@UndefinedVariable + if event & select.EPOLLERR or event & select.EPOLLHUP: # @UndefinedVariable socket = self.fd_m.get(fd, None) socket and result[2].append(socket) @@ -230,30 +234,34 @@ def poll(self): def is_edge(self): return True - def sub_read(self, socket, owner = None): - if socket in self.read_o: return + def sub_read(self, socket, owner=None): + if socket in self.read_o: + return socket_fd = socket.fileno() self.fd_m[socket_fd] = socket self.read_o[socket] = owner self.write_o[socket] = owner self.error_o[socket] = owner - self.epoll.register( #@UndefinedVariable + self.epoll.register( # @UndefinedVariable socket_fd, - select.EPOLLIN | select.EPOLLOUT | select.EPOLLERR | select.EPOLLHUP | select.EPOLLET #@UndefinedVariable + select.EPOLLIN + | select.EPOLLOUT + | select.EPOLLERR + | select.EPOLLHUP + | select.EPOLLET, # @UndefinedVariable ) - def sub_write(self, socket, owner = None): + def sub_write(self, socket, owner=None): pass - def sub_error(self, socket, owner = None): + def sub_error(self, socket, owner=None): pass def unsub_read(self, socket): - if not socket in self.read_o: return + if not socket in self.read_o: + return socket_fd = socket.fileno() - self.epoll.unregister( #@UndefinedVariable - socket_fd - ) + self.epoll.unregister(socket_fd) # @UndefinedVariable del self.fd_m[socket_fd] del self.read_o[socket] del self.write_o[socket] @@ -265,6 +273,7 @@ def unsub_write(self, socket): def unsub_error(self, socket): pass + class KqueuePoll(Poll): def __init__(self, *args, **kwargs): @@ -275,13 +284,15 @@ def __init__(self, *args, **kwargs): def test(cls): return hasattr(select, "kqueue") - def open(self, timeout = POLL_TIMEOUT): - if self._open: return + def open(self, timeout=POLL_TIMEOUT): + if self._open: + return self._open = True self.timeout = timeout - if self.timeout < 0: self.timeout = None + if self.timeout < 0: + self.timeout = None - self.kqueue = select.kqueue() #@UndefinedVariable + self.kqueue = select.kqueue() # @UndefinedVariable self.fd_m = {} @@ -290,7 +301,8 @@ def open(self, timeout = POLL_TIMEOUT): self.error_o = {} def close(self): - if not self._open: return + if not self._open: + return self._open = False self.kqueue.close() @@ -307,16 +319,16 @@ def poll(self): events = self.kqueue.control(None, 32, self.timeout) for event in events: - if event.flags & select.KQ_EV_ERROR: #@UndefinedVariable + if event.flags & select.KQ_EV_ERROR: # @UndefinedVariable socket = self.fd_m.get(event.udata, None) socket and result[2].append(socket) - elif event.filter == select.KQ_FILTER_READ: #@UndefinedVariable + elif event.filter == select.KQ_FILTER_READ: # @UndefinedVariable socket = self.fd_m.get(event.udata, None) - index = 2 if event.flags & select.KQ_EV_EOF else 0 #@UndefinedVariable + index = 2 if event.flags & select.KQ_EV_EOF else 0 # @UndefinedVariable socket and result[index].append(socket) - elif event.filter == select.KQ_FILTER_WRITE: #@UndefinedVariable + elif event.filter == select.KQ_FILTER_WRITE: # @UndefinedVariable socket = self.fd_m.get(event.udata, None) - index = 2 if event.flags & select.KQ_EV_EOF else 1 #@UndefinedVariable + index = 2 if event.flags & select.KQ_EV_EOF else 1 # @UndefinedVariable socket and result[index].append(socket) return result @@ -324,47 +336,49 @@ def poll(self): def is_edge(self): return True - def sub_read(self, socket, owner = None): - if socket in self.read_o: return + def sub_read(self, socket, owner=None): + if socket in self.read_o: + return socket_fd = socket.fileno() self.fd_m[socket_fd] = socket self.read_o[socket] = owner self.write_o[socket] = owner self.error_o[socket] = owner - event = select.kevent( #@UndefinedVariable + event = select.kevent( # @UndefinedVariable socket_fd, - filter = select.KQ_FILTER_READ, #@UndefinedVariable - flags = select.KQ_EV_ADD | select.KQ_EV_CLEAR, #@UndefinedVariable - udata = socket_fd + filter=select.KQ_FILTER_READ, # @UndefinedVariable + flags=select.KQ_EV_ADD | select.KQ_EV_CLEAR, # @UndefinedVariable + udata=socket_fd, ) self.kqueue.control([event], 0) - event = select.kevent( #@UndefinedVariable + event = select.kevent( # @UndefinedVariable socket_fd, - filter = select.KQ_FILTER_WRITE, #@UndefinedVariable - flags = select.KQ_EV_ADD | select.KQ_EV_CLEAR, #@UndefinedVariable - udata = socket_fd + filter=select.KQ_FILTER_WRITE, # @UndefinedVariable + flags=select.KQ_EV_ADD | select.KQ_EV_CLEAR, # @UndefinedVariable + udata=socket_fd, ) self.kqueue.control([event], 0) - def sub_write(self, socket, owner = None): + def sub_write(self, socket, owner=None): pass - def sub_error(self, socket, owner = None): + def sub_error(self, socket, owner=None): pass def unsub_read(self, socket): - if not socket in self.read_o: return + if not socket in self.read_o: + return socket_fd = socket.fileno() - event = select.kevent( #@UndefinedVariable + event = select.kevent( # @UndefinedVariable socket_fd, - filter = select.KQ_FILTER_READ, #@UndefinedVariable - flags = select.KQ_EV_DELETE #@UndefinedVariable + filter=select.KQ_FILTER_READ, # @UndefinedVariable + flags=select.KQ_EV_DELETE, # @UndefinedVariable ) self.kqueue.control([event], 0) - event = select.kevent( #@UndefinedVariable + event = select.kevent( # @UndefinedVariable socket_fd, - filter = select.KQ_FILTER_WRITE, #@UndefinedVariable - flags = select.KQ_EV_DELETE #@UndefinedVariable + filter=select.KQ_FILTER_WRITE, # @UndefinedVariable + flags=select.KQ_EV_DELETE, # @UndefinedVariable ) self.kqueue.control([event], 0) del self.fd_m[socket_fd] @@ -378,6 +392,7 @@ def unsub_write(self, socket): def unsub_error(self, socket): pass + class PollPoll(Poll): def __init__(self, *args, **kwargs): @@ -388,12 +403,13 @@ def __init__(self, *args, **kwargs): def test(cls): return hasattr(select, "poll") - def open(self, timeout = POLL_TIMEOUT): - if self._open: return + def open(self, timeout=POLL_TIMEOUT): + if self._open: + return self._open = True self.timeout = timeout - self._poll = select.poll() #@UndefinedVariable + self._poll = select.poll() # @UndefinedVariable self.read_fd = {} self.write_fd = {} @@ -403,10 +419,12 @@ def open(self, timeout = POLL_TIMEOUT): self.error_o = {} def close(self): - if not self._open: return + if not self._open: + return self._open = False - for fd in self.read_fd: self._poll.unregister(fd) + for fd in self.read_fd: + self._poll.unregister(fd) self._poll = None self.read_fd.clear() @@ -421,13 +439,13 @@ def poll(self): events = self._poll.poll(self.timeout * 1000) for fd, event in events: - if event & select.POLLIN: #@UndefinedVariable + if event & select.POLLIN: # @UndefinedVariable socket = self.read_fd.get(fd, None) socket and result[0].append(socket) - if event & select.POLLOUT: #@UndefinedVariable + if event & select.POLLOUT: # @UndefinedVariable socket = self.write_fd.get(fd, None) socket and result[1].append(socket) - if event & select.POLLERR or event & select.POLLHUP: #@UndefinedVariable + if event & select.POLLERR or event & select.POLLHUP: # @UndefinedVariable socket = self.read_fd.get(fd, None) socket and result[2].append(socket) @@ -436,64 +454,68 @@ def poll(self): def is_edge(self): return False - def sub_read(self, socket, owner = None): - if socket in self.read_o: return + def sub_read(self, socket, owner=None): + if socket in self.read_o: + return socket_fd = socket.fileno() self.read_fd[socket_fd] = socket self.read_o[socket] = owner - self._poll.register( #@UndefinedVariable - socket_fd, - select.POLLIN #@UndefinedVariable + self._poll.register( # @UndefinedVariable + socket_fd, select.POLLIN # @UndefinedVariable ) - def sub_write(self, socket, owner = None): - if socket in self.write_o: return + def sub_write(self, socket, owner=None): + if socket in self.write_o: + return socket_fd = socket.fileno() self.write_fd[socket_fd] = socket self.write_o[socket] = owner - self._poll.modify( #@UndefinedVariable - socket_fd, - select.POLLIN | select.POLLOUT #@UndefinedVariable + self._poll.modify( # @UndefinedVariable + socket_fd, select.POLLIN | select.POLLOUT # @UndefinedVariable ) - def sub_error(self, socket, owner = None): - if socket in self.error_o: return + def sub_error(self, socket, owner=None): + if socket in self.error_o: + return self.error_o[socket] = owner def unsub_read(self, socket): - if not socket in self.read_o: return + if not socket in self.read_o: + return socket_fd = socket.fileno() - self._poll.unregister( #@UndefinedVariable - socket_fd - ) + self._poll.unregister(socket_fd) # @UndefinedVariable del self.read_fd[socket_fd] del self.read_o[socket] def unsub_write(self, socket): - if not socket in self.write_o: return + if not socket in self.write_o: + return socket_fd = socket.fileno() - self._poll.modify( #@UndefinedVariable - socket_fd, - select.POLLIN #@UndefinedVariable + self._poll.modify( # @UndefinedVariable + socket_fd, select.POLLIN # @UndefinedVariable ) del self.write_fd[socket_fd] del self.write_o[socket] def unsub_error(self, socket): - if not socket in self.error_o: return + if not socket in self.error_o: + return del self.error_o[socket] + class SelectPoll(Poll): def __init__(self, *args, **kwargs): Poll.__init__(self, *args, **kwargs) self._open = False - def open(self, timeout = POLL_TIMEOUT): - if self._open: return + def open(self, timeout=POLL_TIMEOUT): + if self._open: + return self._open = True self.timeout = timeout - if self.timeout < 0: self.timeout = None + if self.timeout < 0: + self.timeout = None self.read_l = [] self.write_l = [] @@ -504,7 +526,8 @@ def open(self, timeout = POLL_TIMEOUT): self.error_o = {} def close(self): - if not self._open: return + if not self._open: + return self._open = False # removes the contents of all of the loop related structures @@ -529,47 +552,50 @@ def poll(self): # in case it's sleeps for a while and then continues # the loop (this avoids error in empty selection) is_empty = self.is_empty() - if is_empty: time.sleep(sleep_timeout); return ([], [], []) + if is_empty: + time.sleep(sleep_timeout) + return ([], [], []) # runs the proper select statement waiting for the desired # amount of time as timeout at the end a tuple with three # list for the different operations should be returned - return select.select( - self.read_l, - self.write_l, - self.error_l, - self.timeout - ) + return select.select(self.read_l, self.write_l, self.error_l, self.timeout) def is_edge(self): return False - def sub_read(self, socket, owner = None): - if socket in self.read_o: return + def sub_read(self, socket, owner=None): + if socket in self.read_o: + return self.read_o[socket] = owner self.read_l.append(socket) - def sub_write(self, socket, owner = None): - if socket in self.write_o: return + def sub_write(self, socket, owner=None): + if socket in self.write_o: + return self.write_o[socket] = owner self.write_l.append(socket) - def sub_error(self, socket, owner = None): - if socket in self.error_o: return + def sub_error(self, socket, owner=None): + if socket in self.error_o: + return self.error_o[socket] = owner self.error_l.append(socket) def unsub_read(self, socket): - if not socket in self.read_o: return + if not socket in self.read_o: + return self.read_l.remove(socket) del self.read_o[socket] def unsub_write(self, socket): - if not socket in self.write_o: return + if not socket in self.write_o: + return self.write_l.remove(socket) del self.write_o[socket] def unsub_error(self, socket): - if not socket in self.error_o: return + if not socket in self.error_o: + return self.error_l.remove(socket) del self.error_o[socket] diff --git a/src/netius/base/protocol.py b/src/netius/base/protocol.py index 08777785a..ca26f98eb 100644 --- a/src/netius/base/protocol.py +++ b/src/netius/base/protocol.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -41,6 +32,7 @@ from . import request from . import observer + class Protocol(observer.Observable): """ Abstract class from which concrete implementation of @@ -51,7 +43,7 @@ class Protocol(observer.Observable): of processed data (send). """ - def __init__(self, owner = None): + def __init__(self, owner=None): observer.Observable.__init__(self) self.owner = owner self._transport = None @@ -66,7 +58,8 @@ def __init__(self, owner = None): def open(self): # in case the protocol is already open, ignores the current # call as it's considered a double opening - if self.is_open(): return + if self.is_open(): + return # calls the concrete implementation of the open operation # allowing an extra level of indirection @@ -77,7 +70,8 @@ def open(self): def close(self): # in case the protocol is already closed, ignores the current # call considering it a double closing operation - if self.is_closed() or self.is_closing(): return + if self.is_closed() or self.is_closing(): + return # calls the concrete implementation of the close operation # allowing an extra level of indirection @@ -88,8 +82,10 @@ def close(self): def finish(self): # in case the current protocol is already (completely) closed # or is not in the state of closing, nothing should be done - if self.is_closed(): return - if not self.is_closing(): return + if self.is_closed(): + return + if not self.is_closing(): + return # calls the concrete implementation of the finish operation # allowing an extra level of indirection @@ -137,9 +133,10 @@ def finish_c(self): self._closed = True self._closing = False - def info_dict(self, full = False): - if not self._transport: return dict() - info = self._transport.info_dict(full = full) + def info_dict(self, full=False): + if not self._transport: + return dict() + info = self._transport.info_dict(full=full) return info def connection_made(self, transport): @@ -176,51 +173,60 @@ def resume_writing(self): self._flush_callbacks() self._flush_send() - def delay(self, callable, timeout = None): + def delay(self, callable, timeout=None): # in case there's no event loop defined for the protocol # it's not possible to delay this execution so the # callable is called immediately - if not self._loop: return callable() + if not self._loop: + return callable() # verifies if the assigned loop contains the non-standard # delay method and if that's the case calls it instead of # the base asyncio API ones (compatibility) if hasattr(self._loop, "delay"): immediately = timeout == None - return self._loop.delay( - callable, - timeout = timeout, - immediately = immediately - ) + return self._loop.delay(callable, timeout=timeout, immediately=immediately) # calls the proper call method taking into account if a timeout # value exists or not (soon against later) - if timeout: return self._loop.call_later(timeout, callable) - else: return self._loop.call_soon(callable) + if timeout: + return self._loop.call_later(timeout, callable) + else: + return self._loop.call_soon(callable) def debug(self, object): - if not self._loop: return - if not hasattr(self._loop, "debug"): return + if not self._loop: + return + if not hasattr(self._loop, "debug"): + return self._loop.debug(object) def info(self, object): - if not self._loop: return - if not hasattr(self._loop, "info"): return + if not self._loop: + return + if not hasattr(self._loop, "info"): + return self._loop.info(object) def warning(self, object): - if not self._loop: return - if not hasattr(self._loop, "warning"): return + if not self._loop: + return + if not hasattr(self._loop, "warning"): + return self._loop.warning(object) def error(self, object): - if not self._loop: return - if not hasattr(self._loop, "error"): return + if not self._loop: + return + if not hasattr(self._loop, "error"): + return self._loop.error(object) def critical(self, object): - if not self._loop: return - if not hasattr(self._loop, "critical"): return + if not self._loop: + return + if not hasattr(self._loop, "critical"): + return self._loop.critical(object) def is_pending(self): @@ -239,15 +245,18 @@ def is_closed_or_closing(self): return self._closed or self._closing def is_devel(self): - if not self._loop: return False - if not hasattr(self._loop, "is_devel"): return False + if not self._loop: + return False + if not hasattr(self._loop, "is_devel"): + return False return self._loop.is_devel() - def _close_transport(self, force = False): - if not self._transport: return + def _close_transport(self, force=False): + if not self._transport: + return self._transport.abort() - def _delay_send(self, data, address = None, callback = None): + def _delay_send(self, data, address=None, callback=None): item = (data, address, callback) self._delayed.append(item) return len(data) @@ -259,11 +268,16 @@ def _flush_callbacks(self): def _flush_send(self): while True: - if not self._delayed: break - if not self._writing: break + if not self._delayed: + break + if not self._writing: + break data, address, callback = self._delayed.pop(0) - if address: self.send(data, address, callback = callback) - else: self.send(data, callback = callback) + if address: + self.send(data, address, callback=callback) + else: + self.send(data, callback=callback) + class DatagramProtocol(Protocol): @@ -281,30 +295,10 @@ def error_received(self, exception): def on_data(self, address, data): self.trigger("data", self, data) - def send( - self, - data, - address, - delay = True, - force = False, - callback = None - ): - return self.send_to( - data, - address, - delay = delay, - force = force, - callback = callback - ) - - def send_to( - self, - data, - address, - delay = True, - force = False, - callback = None - ): + def send(self, data, address, delay=True, force=False, callback=None): + return self.send_to(data, address, delay=delay, force=force, callback=callback) + + def send_to(self, data, address, delay=True, force=False, callback=None): # ensures that the provided data value is a bytes sequence # so that its format is compliant with what's expected by # the underlying transport send to operation @@ -314,11 +308,7 @@ def send_to( # (paused mode) the writing of the data is delayed until the # writing is again enabled (resume writing) if not self._writing: - return self._delay_send( - data, - address = address, - callback = callback - ) + return self._delay_send(data, address=address, callback=callback) # pushes the write data down to the transport layer immediately # as writing is still allowed for the current protocol @@ -330,8 +320,10 @@ def send_to( # to be called on the next tick, otherwise adds it to the # callbacks to be called upon the next write resume operation if callback: - if self._writing: self.delay(lambda: callback(self._transport)) - else: self._callbacks.append(callback) + if self._writing: + self.delay(lambda: callback(self._transport)) + else: + self._callbacks.append(callback) # returns the size (in bytes) of the data that has just been # explicitly sent through the associated transport @@ -351,9 +343,11 @@ def remove_request(self, request): def get_request(self, id): is_response = isinstance(id, request.Response) - if is_response: id = id.get_id() + if is_response: + id = id.get_id() return self.requests_m.get(id, None) + class StreamProtocol(Protocol): def data_received(self, data): @@ -365,7 +359,7 @@ def eof_received(self): def on_data(self, data): self.trigger("data", self, data) - def send(self, data, delay = True, force = False, callback = None): + def send(self, data, delay=True, force=False, callback=None): # ensures that the provided data value is a bytes sequence # so that its format is compliant with what's expected by # the underlying transport write operation @@ -375,7 +369,7 @@ def send(self, data, delay = True, force = False, callback = None): # (paused mode) the writing of the data is delayed until the # writing is again enabled (resume writing) if not self._writing: - return self._delay_send(data, callback = callback) + return self._delay_send(data, callback=callback) # pushes the write data down to the transport layer immediately # as writing is still allowed for the current protocol @@ -387,8 +381,10 @@ def send(self, data, delay = True, force = False, callback = None): # to be called on the next tick otherwise adds it to the # callbacks to be called upon the next write resume operation if callback: - if self._writing: self.delay(lambda: callback(self._transport)) - else: self._callbacks.append(callback) + if self._writing: + self.delay(lambda: callback(self._transport)) + else: + self._callbacks.append(callback) # returns the size (in bytes) of the data that has just been # explicitly sent through the associated transport diff --git a/src/netius/base/request.py b/src/netius/base/request.py index 97e8db5d9..9726f0f0d 100644 --- a/src/netius/base/request.py +++ b/src/netius/base/request.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -46,6 +37,7 @@ expired and is discarded from the request related structures, this is crucial to avoid memory leaks """ + class Request(object): """ Abstract request structure used to represent @@ -57,16 +49,17 @@ class Request(object): """ The global class identifier value that is going to be used when assigning new values to the request """ - def __init__(self, timeout = REQUEST_TIMEOUT, callback = None): + def __init__(self, timeout=REQUEST_TIMEOUT, callback=None): self.id = self.__class__._generate_id() self.timeout = time.time() + timeout self.callback = callback @classmethod def _generate_id(cls): - cls.IDENTIFIER = (cls.IDENTIFIER + 1) & 0xffff + cls.IDENTIFIER = (cls.IDENTIFIER + 1) & 0xFFFF return cls.IDENTIFIER + class Response(object): """ Top level abstract representation of a response to @@ -79,7 +72,7 @@ class Response(object): generated identifier. """ - def __init__(self, data, request = None): + def __init__(self, data, request=None): self.data = data self.request = request diff --git a/src/netius/base/server.py b/src/netius/base/server.py index e77e1924c..2fc39976c 100644 --- a/src/netius/base/server.py +++ b/src/netius/base/server.py @@ -22,23 +22,14 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ __license__ = "Apache License, Version 2.0" """ The license for the module """ -from .conn import * #@UnusedWildImport -from .common import * #@UnusedWildImport +from .conn import * # @UnusedWildImport +from .common import * # @UnusedWildImport BUFFER_SIZE_S = None """ The size of both the send and receive buffers for @@ -51,6 +42,7 @@ the server (client sockets), this is critical for a good performance of the server (large value) """ + class Server(Base): def __init__(self, *args, **kwargs): @@ -85,98 +77,103 @@ def cleanup(self): # tries to close the service socket, as this is the one that # has no connection associated and is independent - try: self.socket and self.socket.close() - except Exception: pass + try: + self.socket and self.socket.close() + except Exception: + pass # unsets the socket attribute as the socket should now be closed # and not able to be used for any kind of communication self.socket = None - def info_dict(self, full = False): - info = Base.info_dict(self, full = full) - info.update( - host = self.host, - port = self.port, - type = self.type, - ssl = self.ssl - ) + def info_dict(self, full=False): + info = Base.info_dict(self, full=full) + info.update(host=self.host, port=self.port, type=self.type, ssl=self.ssl) return info def serve( self, - host = None, - port = 9090, - type = TCP_TYPE, - ipv6 = False, - ssl = False, - key_file = None, - cer_file = None, - ca_file = None, - ca_root = True, - ssl_verify = False, - ssl_host = None, - ssl_fingerprint = None, - ssl_dump = False, - setuid = None, - backlog = socket.SOMAXCONN, - load = True, - start = True, - env = False + host=None, + port=9090, + type=TCP_TYPE, + ipv6=False, + ssl=False, + key_file=None, + cer_file=None, + ca_file=None, + ca_root=True, + ssl_verify=False, + ssl_host=None, + ssl_fingerprint=None, + ssl_dump=False, + setuid=None, + backlog=socket.SOMAXCONN, + load=True, + start=True, + env=False, ): # processes the various default values taking into account if # the environment variables are meant to be processed for the # current context (default values are processed accordingly) host = self.get_env("HOST", host) if env else host - port = self.get_env("PORT", port, cast = int) if env else port - type = self.get_env("TYPE", type, cast = int) if env else type - ipv6 = self.get_env("IPV6", ipv6, cast = bool) if env else ipv6 - ssl = self.get_env("SSL", ssl, cast = bool) if env else ssl + port = self.get_env("PORT", port, cast=int) if env else port + type = self.get_env("TYPE", type, cast=int) if env else type + ipv6 = self.get_env("IPV6", ipv6, cast=bool) if env else ipv6 + ssl = self.get_env("SSL", ssl, cast=bool) if env else ssl port = self.get_env("UNIX_PATH", port) if env else port key_file = self.get_env("KEY_FILE", key_file) if env else key_file cer_file = self.get_env("CER_FILE", cer_file) if env else cer_file ca_file = self.get_env("CA_FILE", ca_file) if env else ca_file - ca_root = self.get_env("CA_ROOT", ca_root, cast = bool) if env else ca_root - ssl_verify = self.get_env("SSL_VERIFY", ssl_verify, cast = bool) if env else ssl_verify + ca_root = self.get_env("CA_ROOT", ca_root, cast=bool) if env else ca_root + ssl_verify = ( + self.get_env("SSL_VERIFY", ssl_verify, cast=bool) if env else ssl_verify + ) ssl_host = self.get_env("SSL_HOST", ssl_host) if env else ssl_host - ssl_fingerprint = self.get_env("SSL_FINGERPRINT", ssl_fingerprint) if env else ssl_fingerprint + ssl_fingerprint = ( + self.get_env("SSL_FINGERPRINT", ssl_fingerprint) if env else ssl_fingerprint + ) ssl_dump = self.get_env("SSL_DUMP", ssl_dump) if env else ssl_dump - key_file = self.get_env("KEY_DATA", key_file, expand = True) if env else key_file - cer_file = self.get_env("CER_DATA", cer_file, expand = True) if env else cer_file - ca_file = self.get_env("CA_DATA", ca_file, expand = True) if env else ca_file - setuid = self.get_env("SETUID", setuid, cast = int) if env else setuid - backlog = self.get_env("BACKLOG", backlog, cast = int) if env else backlog + key_file = self.get_env("KEY_DATA", key_file, expand=True) if env else key_file + cer_file = self.get_env("CER_DATA", cer_file, expand=True) if env else cer_file + ca_file = self.get_env("CA_DATA", ca_file, expand=True) if env else ca_file + setuid = self.get_env("SETUID", setuid, cast=int) if env else setuid + backlog = self.get_env("BACKLOG", backlog, cast=int) if env else backlog # runs the various extra variable initialization taking into # account if the environment variable is currently set or not # please note that some side effects may arise from this set - if env: self.level = self.get_env("LEVEL", self.level) - if env: self.diag = self.get_env("DIAG", self.diag, cast = bool) - if env: self.middleware = self.get_env("MIDDLEWARE", self.middleware, cast = list) - if env: self.children = self.get_env("CHILD", self.children, cast = int) - if env: self.children = self.get_env("CHILDREN", self.children, cast = int) - if env: self.logging = self.get_env("LOGGING", self.logging) - if env: self.poll_name = self.get_env("POLL", self.poll_name) - if env: self.poll_timeout = self.get_env( - "POLL_TIMEOUT", - self.poll_timeout, - cast = float - ) - if env: self.keepalive_timeout = self.get_env( - "KEEPALIVE_TIMEOUT", - self.keepalive_timeout, - cast = int - ) - if env: self.keepalive_interval = self.get_env( - "KEEPALIVE_INTERVAL", - self.keepalive_interval, - cast = int - ) - if env: self.keepalive_count = self.get_env( - "KEEPALIVE_COUNT", - self.keepalive_count, - cast = int - ) - if env: self.allowed = self.get_env("ALLOWED", self.allowed, cast = list) + if env: + self.level = self.get_env("LEVEL", self.level) + if env: + self.diag = self.get_env("DIAG", self.diag, cast=bool) + if env: + self.middleware = self.get_env("MIDDLEWARE", self.middleware, cast=list) + if env: + self.children = self.get_env("CHILD", self.children, cast=int) + if env: + self.children = self.get_env("CHILDREN", self.children, cast=int) + if env: + self.logging = self.get_env("LOGGING", self.logging) + if env: + self.poll_name = self.get_env("POLL", self.poll_name) + if env: + self.poll_timeout = self.get_env( + "POLL_TIMEOUT", self.poll_timeout, cast=float + ) + if env: + self.keepalive_timeout = self.get_env( + "KEEPALIVE_TIMEOUT", self.keepalive_timeout, cast=int + ) + if env: + self.keepalive_interval = self.get_env( + "KEEPALIVE_INTERVAL", self.keepalive_interval, cast=int + ) + if env: + self.keepalive_count = self.get_env( + "KEEPALIVE_COUNT", self.keepalive_count, cast=int + ) + if env: + self.allowed = self.get_env("ALLOWED", self.allowed, cast=list) # updates the current service status to the configuration # stage as the next steps is to configure the service socket @@ -184,12 +181,14 @@ def serve( # starts the loading process of the base system so that the system should # be able to log some information that is going to be output - if load: self.load() + if load: + self.load() # ensures the proper default address value, taking into account # the type of connection that is currently being used, this avoids # problems with multiple stack based servers (ipv4 and ipv6) - if host == None: host = "::1" if ipv6 else "127.0.0.1" + if host == None: + host = "::1" if ipv6 else "127.0.0.1" # defaults the provided SSL key and certificate paths to the # ones statically defined (dummy certificates), please beware @@ -231,46 +230,56 @@ def serve( # creates a service socket according to the defined service family = socket.AF_INET6 if ipv6 else socket.AF_INET family = socket.AF_UNIX if is_unix else family - if type == TCP_TYPE: self.socket = self.socket_tcp( - ssl, - key_file = key_file, - cer_file = cer_file, - ca_file = ca_file, - ca_root = ca_root, - ssl_verify = ssl_verify, - family = family - ) - elif type == UDP_TYPE: self.socket = self.socket_udp() - else: raise errors.NetiusError("Invalid server type provided '%d'" % type) + if type == TCP_TYPE: + self.socket = self.socket_tcp( + ssl, + key_file=key_file, + cer_file=cer_file, + ca_file=ca_file, + ca_root=ca_root, + ssl_verify=ssl_verify, + family=family, + ) + elif type == UDP_TYPE: + self.socket = self.socket_udp() + else: + raise errors.NetiusError("Invalid server type provided '%d'" % type) # "calculates" the address "bind target", taking into account that this # server may be running under a unix based socket infra-structure and # if that's the case the target (file path) is also removed, avoiding # a duplicated usage of the socket (required for address re-usage) address = port if is_unix else (host, port) - if is_unix and os.path.exists(address): os.remove(address) + if is_unix and os.path.exists(address): + os.remove(address) # binds the socket to the provided address value (per spec) and then # starts the listening in the socket with the provided backlog value # defaulting to the typical maximum backlog as possible if not provided self.socket.bind(address) - if type == TCP_TYPE: self.socket.listen(backlog) + if type == TCP_TYPE: + self.socket.listen(backlog) # in case the set user id value the user of the current process should # be changed so that it represents the new (possibly unprivileged user) - if setuid: os.setuid(setuid) + if setuid: + os.setuid(setuid) # in case the selected port is zero based, meaning that a randomly selected # port has been assigned by the bind operation the new port must be retrieved # and set for the current server instance as the new port (for future reference) - if self.port == 0: self.port = self.socket.getsockname()[1] + if self.port == 0: + self.port = self.socket.getsockname()[1] # creates the string that identifies it the current service connection # is using a secure channel (SSL) and then prints an info message about # the service that is going to be started ipv6_s = " on IPv6" if ipv6 else "" ssl_s = " using SSL" if ssl else "" - self.info("Serving '%s' service on %s:%s%s%s ..." % (self.name, host, port, ipv6_s, ssl_s)) + self.info( + "Serving '%s' service on %s:%s%s%s ..." + % (self.name, host, port, ipv6_s, ssl_s) + ) # runs the fork operation responsible for the forking of the # current process into the various child processes for multiple @@ -279,13 +288,14 @@ def serve( # in case the result is not valid an immediate return is performed # as this represents a master based process (not meant to serve) result = self.fork() - if not result: return + if not result: + return # ensures that the current polling mechanism is correctly open as the # service socket is going to be added to it next, this overrides the # default behavior of the common infra-structure (on start) self.poll = self.build_poll() - self.poll.open(timeout = self.poll_timeout) + self.poll.open(timeout=self.poll_timeout) # adds the socket to all of the pool lists so that it's ready to read # write and handle error, this is the expected behavior of a service @@ -299,18 +309,19 @@ def serve( # starts the base system so that the event loop gets started and the # the servers get ready to accept new connections (starts service) - if start: self.start() + if start: + self.start() def socket_tcp( self, - ssl = False, - key_file = None, - cer_file = None, - ca_file = None, - ca_root = True, - ssl_verify = False, - family = socket.AF_INET, - type = socket.SOCK_STREAM + ssl=False, + key_file=None, + cer_file=None, + ca_file=None, + ca_root=True, + ssl_verify=False, + family=socket.AF_INET, + type=socket.SOCK_STREAM, ): # verifies if the provided family is of type internet and if that's # the case the associated flag is set to valid for usage @@ -320,10 +331,14 @@ def socket_tcp( # and the prints a series of log message about the socket to be created type_s = " SSL" if ssl else "" self.debug("Creating server's TCP%s socket ..." % type_s) - if ssl: self.debug("Loading '%s' as key file" % key_file) - if ssl: self.debug("Loading '%s' as certificate file" % cer_file) - if ssl and ca_file: self.debug("Loading '%s' as certificate authority file" % ca_file) - if ssl and ssl_verify: self.debug("Loading with client SSL verification") + if ssl: + self.debug("Loading '%s' as key file" % key_file) + if ssl: + self.debug("Loading '%s' as certificate file" % cer_file) + if ssl and ca_file: + self.debug("Loading '%s' as certificate authority file" % ca_file) + if ssl and ssl_verify: + self.debug("Loading with client SSL verification") # creates the socket that it's going to be used for the listening # of new connections (server socket) and sets it as non blocking @@ -332,15 +347,16 @@ def socket_tcp( # in case the server is meant to be used as SSL wraps the socket # in suck fashion so that it becomes "secured" - if ssl: _socket = self._ssl_wrap( - _socket, - key_file = key_file, - cer_file = cer_file, - ca_file = ca_file, - ca_root = ca_root, - server = True, - ssl_verify = ssl_verify - ) + if ssl: + _socket = self._ssl_wrap( + _socket, + key_file=key_file, + cer_file=cer_file, + ca_file=ca_file, + ca_root=ca_root, + server=True, + ssl_verify=ssl_verify, + ) # sets the various options in the service socket so that it becomes # ready for the operation with the highest possible performance, these @@ -349,28 +365,21 @@ def socket_tcp( # avoiding the leak of connections (operative system managed) _socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) _socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) - if is_inet: _socket.setsockopt( - socket.IPPROTO_TCP, - socket.TCP_NODELAY, - 1 - ) - if self.receive_buffer_s: _socket.setsockopt( - socket.SOL_SOCKET, - socket.SO_RCVBUF, - self.receive_buffer_s - ) - if self.send_buffer_s: _socket.setsockopt( - socket.SOL_SOCKET, - socket.SO_SNDBUF, - self.send_buffer_s - ) + if is_inet: + _socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + if self.receive_buffer_s: + _socket.setsockopt( + socket.SOL_SOCKET, socket.SO_RCVBUF, self.receive_buffer_s + ) + if self.send_buffer_s: + _socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, self.send_buffer_s) self._socket_keepalive(_socket) # returns the created TCP socket to the calling method so that it # may be used from this point on return _socket - def socket_udp(self, family = socket.AF_INET, type = socket.SOCK_DGRAM): + def socket_udp(self, family=socket.AF_INET, type=socket.SOCK_DGRAM): # prints a small debug message about the udp socket that is going # to be created for the server's connection self.debug("Creating server's udp socket ...") @@ -392,6 +401,7 @@ def socket_udp(self, family = socket.AF_INET, type = socket.SOCK_DGRAM): def on_serve(self): pass + class DatagramServer(Server): def __init__(self, *args, **kwargs): @@ -402,23 +412,23 @@ def __init__(self, *args, **kwargs): self.pending = collections.deque() self.pending_lock = threading.RLock() - def reads(self, reads, state = True): - Server.reads(self, reads, state = state) + def reads(self, reads, state=True): + Server.reads(self, reads, state=state) for read in reads: self.on_read(read) - def writes(self, writes, state = True): - Server.writes(self, writes, state = state) + def writes(self, writes, state=True): + Server.writes(self, writes, state=state) for write in writes: self.on_write(write) - def errors(self, errors, state = True): - Server.errors(self, errors, state = state) + def errors(self, errors, state=True): + Server.errors(self, errors, state=state) for error in errors: self.on_error(error) - def serve(self, type = UDP_TYPE, *args, **kwargs): - Server.serve(self, type = type, *args, **kwargs) + def serve(self, type=UDP_TYPE, *args, **kwargs): + Server.serve(self, type=type, *args, **kwargs) def on_read(self, _socket): # tries to retrieve a proper callback for the socket @@ -426,17 +436,20 @@ def on_read(self, _socket): # proper callback as expected callbacks = self.callbacks_m.get(_socket, None) if callbacks: - for callback in callbacks: callback("read", _socket) + for callback in callbacks: + callback("read", _socket) # in case the read enabled flag is not currently set # must return immediately because the read operation # is not currently being allowed - if not self.renable == True: return + if not self.renable == True: + return # verifies if the provided socket for reading is the same # as the one registered in the client if that's not the case # return immediately to avoid unwanted operations - if not _socket == self.socket: return + if not _socket == self.socket: + return try: # iterates continuously trying to read as much data as possible @@ -447,14 +460,14 @@ def on_read(self, _socket): while True: data, address = _socket.recvfrom(CHUNK_SIZE) self.on_data(address, data) - if not self.renable == True: break + if not self.renable == True: + break except ssl.SSLError as error: error_v = error.args[0] if error.args else None error_m = error.reason if hasattr(error, "reason") else None if error_v in SSL_SILENT_ERRORS: self.on_expected(error) - elif not error_v in SSL_VALID_ERRORS and\ - not error_m in SSL_VALID_REASONS: + elif not error_v in SSL_VALID_ERRORS and not error_m in SSL_VALID_REASONS: self.on_exception(error) except socket.error as error: error_v = error.args[0] if error.args else None @@ -470,12 +483,14 @@ def on_read(self, _socket): def on_write(self, _socket): callbacks = self.callbacks_m.get(_socket, None) if callbacks: - for callback in callbacks: callback("write", _socket) + for callback in callbacks: + callback("write", _socket) # verifies if the provided socket for writing is the same # as the one registered in the client if that's not the case # return immediately to avoid unwanted operations - if not _socket == self.socket: return + if not _socket == self.socket: + return try: self._send(_socket) @@ -484,8 +499,7 @@ def on_write(self, _socket): error_m = error.reason if hasattr(error, "reason") else None if error_v in SSL_SILENT_ERRORS: self.on_expected(error) - elif not error_v in SSL_VALID_ERRORS and\ - not error_m in SSL_VALID_REASONS: + elif not error_v in SSL_VALID_ERRORS and not error_m in SSL_VALID_REASONS: self.on_exception(error) except socket.error as error: error_v = error.args[0] if error.args else None @@ -501,12 +515,14 @@ def on_write(self, _socket): def on_error(self, _socket): callbacks = self.callbacks_m.get(_socket, None) if callbacks: - for callback in callbacks: callback("error", _socket) + for callback in callbacks: + callback("error", _socket) # verifies if the provided socket for error is the same # as the one registered in the client if that's not the case # return immediately to avoid unwanted operations - if not _socket == self.socket: return + if not _socket == self.socket: + return def on_exception(self, exception): self.warning(exception) @@ -532,7 +548,8 @@ def ensure_write(self): # safe and so it must be delayed to be executed in the # next loop of the thread cycle, must return immediately # to avoid extra subscription operations - if not is_safe: return self.delay(self.ensure_write, safe = True) + if not is_safe: + return self.delay(self.ensure_write, safe=True) # adds the current socket to the list of write operations # so that it's going to be available for writing as soon @@ -543,20 +560,23 @@ def remove_write(self): self.unsub_write(self.socket) def enable_read(self): - if not self.renable == False: return + if not self.renable == False: + return self.renable = True self.sub_read(self.socket) def disable_read(self): - if not self.renable == True: return + if not self.renable == True: + return self.renable = False self.unsub_read(self.socket) - def send(self, data, address, delay = True, callback = None): + def send(self, data, address, delay=True, callback=None): data = legacy.bytes(data) data_l = len(data) - if callback: data = (data, callback) + if callback: + data = (data, callback) data = (data, address) cthread = threading.current_thread() @@ -564,19 +584,18 @@ def send(self, data, address, delay = True, callback = None): is_safe = tid == self.tid self.pending_lock.acquire() - try: self.pending.appendleft(data) - finally: self.pending_lock.release() + try: + self.pending.appendleft(data) + finally: + self.pending_lock.release() self.pending_s += data_l if self.wready: - if is_safe and not delay: self._flush_write() - else: self.delay( - self._flush_write, - immediately = True, - verify = True, - safe = True - ) + if is_safe and not delay: + self._flush_write() + else: + self.delay(self._flush_write, immediately=True, verify=True, safe=True) else: self.ensure_write() @@ -587,7 +606,8 @@ def _send(self, _socket): while True: # in case there's no pending data to be sent to the # client side breaks the current loop (queue empty) - if not self.pending: break + if not self.pending: + break # retrieves the current data from the pending list # of data to be sent and then saves the original data @@ -601,7 +621,8 @@ def _send(self, _socket): # verifies if the data type of the data is a tuple and # if that's the case unpacks it as data and callback is_tuple = type(data) == tuple - if is_tuple: data, callback = data + if is_tuple: + data, callback = data # retrieves the length (in bytes) of the data that is # going to be sent to the client @@ -613,8 +634,10 @@ def _send(self, _socket): # sent through the socket, this number may not be # the same as the size of the data in case only # part of the data has been sent - if data: count = _socket.sendto(data, address) - else: count = 0 + if data: + count = _socket.sendto(data, address) + else: + count = 0 # verifies if the current situation is that of a non # closed socket and valid data, and if that's the case @@ -622,7 +645,8 @@ def _send(self, _socket): # be in a would block situation and and such an error # is raised indicating the issue (is going to be caught # as a normal would block exception) - if data and count == 0: raise socket.error(errno.EWOULDBLOCK) + if data and count == 0: + raise socket.error(errno.EWOULDBLOCK) except: # sets the current connection write ready flag to false # so that a new level notification must be received @@ -651,7 +675,8 @@ def _send(self, _socket): # sent latter (only then the callback is called) is_valid = count == data_l if is_valid: - if callback: callback(self) + if callback: + callback(self) else: data_o = ((data[count:], callback), address) self.pending.append(data_o) @@ -667,44 +692,53 @@ def _flush_write(self): pending for the current connection's socket. """ - self.writes((self.socket,), state = False) + self.writes((self.socket,), state=False) + class StreamServer(Server): - def reads(self, reads, state = True): - Server.reads(self, reads, state = state) + def reads(self, reads, state=True): + Server.reads(self, reads, state=state) for read in reads: - if read == self.socket: self.on_read_s(read) - else: self.on_read(read) + if read == self.socket: + self.on_read_s(read) + else: + self.on_read(read) - def writes(self, writes, state = True): - Server.writes(self, writes, state = state) + def writes(self, writes, state=True): + Server.writes(self, writes, state=state) for write in writes: - if write == self.socket: self.on_write_s(write) - else: self.on_write(write) + if write == self.socket: + self.on_write_s(write) + else: + self.on_write(write) - def errors(self, errors, state = True): - Server.errors(self, errors, state = state) + def errors(self, errors, state=True): + Server.errors(self, errors, state=state) for error in errors: - if error == self.socket: self.on_error_s(error) - else: self.on_error(error) + if error == self.socket: + self.on_error_s(error) + else: + self.on_error(error) - def serve(self, type = TCP_TYPE, *args, **kwargs): - Server.serve(self, type = type, *args, **kwargs) + def serve(self, type=TCP_TYPE, *args, **kwargs): + Server.serve(self, type=type, *args, **kwargs) def on_read_s(self, _socket): try: while True: socket_c, address = _socket.accept() - try: self.on_socket_c(socket_c, address) - except Exception: socket_c.close(); raise + try: + self.on_socket_c(socket_c, address) + except Exception: + socket_c.close() + raise except ssl.SSLError as error: error_v = error.args[0] if error.args else None error_m = error.reason if hasattr(error, "reason") else None if error_v in SSL_SILENT_ERRORS: self.on_expected_s(error) - elif not error_v in SSL_VALID_ERRORS and\ - not error_m in SSL_VALID_REASONS: + elif not error_v in SSL_VALID_ERRORS and not error_m in SSL_VALID_REASONS: self.on_exception_s(error) except socket.error as error: error_v = error.args[0] if error.args else None @@ -740,39 +774,49 @@ def on_read(self, _socket): # to the execution of the read operation in the socket callbacks = self.callbacks_m.get(_socket, None) if callbacks: - for callback in callbacks: callback("read", _socket) + for callback in callbacks: + callback("read", _socket) # run a series of validations, in case there no connection # or the connection is not ready for return the control flow is # returned to the caller method (nothing to be done) - if not connection: return - if not connection.status == OPEN: return - if not connection.renable == True: return + if not connection: + return + if not connection.status == OPEN: + return + if not connection.renable == True: + return try: # verifies if there's any pending operations in the # connection (eg: SSL handshaking) and performs it trying # to finish them, if they are still pending at the current # state returns immediately (waits for next loop) - if self._pending(connection): return + if self._pending(connection): + return # iterates continuously trying to read as much data as possible # when there's a failure to read more data it should raise an # exception that should be handled properly while True: data = connection.recv(CHUNK_SIZE) - if data: self.on_data(connection, data) - else: connection.close(); break - if not connection.status == OPEN: break - if not connection.renable == True: break - if not connection.socket == _socket: break + if data: + self.on_data(connection, data) + else: + connection.close() + break + if not connection.status == OPEN: + break + if not connection.renable == True: + break + if not connection.socket == _socket: + break except ssl.SSLError as error: error_v = error.args[0] if error.args else None error_m = error.reason if hasattr(error, "reason") else None if error_v in SSL_SILENT_ERRORS: self.on_expected(error, connection) - elif not error_v in SSL_VALID_ERRORS and\ - not error_m in SSL_VALID_REASONS: + elif not error_v in SSL_VALID_ERRORS and not error_m in SSL_VALID_REASONS: self.on_exception(error, connection) except socket.error as error: error_v = error.args[0] if error.args else None @@ -796,10 +840,13 @@ def on_write(self, _socket): callbacks = self.callbacks_m.get(_socket, None) if callbacks: - for callback in callbacks: callback("write", _socket) + for callback in callbacks: + callback("write", _socket) - if not connection: return - if not connection.status == OPEN: return + if not connection: + return + if not connection.status == OPEN: + return try: connection._send() @@ -808,8 +855,7 @@ def on_write(self, _socket): error_m = error.reason if hasattr(error, "reason") else None if error_v in SSL_SILENT_ERRORS: self.on_expected(error, connection) - elif not error_v in SSL_VALID_ERRORS and\ - not error_m in SSL_VALID_REASONS: + elif not error_v in SSL_VALID_ERRORS and not error_m in SSL_VALID_REASONS: self.on_exception(error, connection) except socket.error as error: error_v = error.args[0] if error.args else None @@ -833,10 +879,13 @@ def on_error(self, _socket): callbacks = self.callbacks_m.get(_socket, None) if callbacks: - for callback in callbacks: callback("error", _socket) + for callback in callbacks: + callback("error", _socket) - if not connection: return - if not connection.status == OPEN: return + if not connection: + return + if not connection.status == OPEN: + return connection.close() @@ -866,26 +915,33 @@ def on_ssl(self, connection): # as a fallback the SSL verification process is performed with no # value defined, meaning that a possible (SSL) host value set in the # connection is going to be used instead for the verification - if self.ssl_host: connection.ssl_verify_host(self.ssl_host) - else: connection.ssl_verify_host() + if self.ssl_host: + connection.ssl_verify_host(self.ssl_host) + else: + connection.ssl_verify_host() # in case the SSL fingerprint verification process is enabled for the # current server the client certificates are going to be verified for # their integrity using this technique, otherwise the default verification # process is going to be run instead - if self.ssl_fingerprint: connection.ssl_verify_fingerprint(self.ssl_fingerprint) - else: connection.ssl_verify_fingerprint() + if self.ssl_fingerprint: + connection.ssl_verify_fingerprint(self.ssl_fingerprint) + else: + connection.ssl_verify_fingerprint() # in case the SSL dump flag is set the dump operation is performed according # to that flag, otherwise the default operation is performed, that in most # of the cases should prevent the dump of the information - if self.ssl_dump: connection.ssl_dump_certificate(self.ssl_dump) - else: connection.ssl_dump_certificate() + if self.ssl_dump: + connection.ssl_dump_certificate(self.ssl_dump) + else: + connection.ssl_dump_certificate() # in case the current connection is under the upgrade # status calls the proper event handler so that the # connection workflow may proceed accordingly - if connection.upgrading: self.on_upgrade(connection) + if connection.upgrading: + self.on_upgrade(connection) def on_data(self, connection, data): connection.set_data(data) @@ -896,19 +952,21 @@ def on_socket_c(self, socket_c, address): # the case raises an exception indicating the issue host = address[0] if address else "" result = netius.common.assert_ip4(host, self.allowed) - if not result: raise errors.NetiusError( - "Address '%s' not present in allowed list" % host - ) + if not result: + raise errors.NetiusError("Address '%s' not present in allowed list" % host) # verifies a series of pre-conditions on the socket so # that it's ensured to be in a valid state before it's # set as a new connection for the server (validation) - if self.ssl and not socket_c._sslobj: socket_c.close(); return + if self.ssl and not socket_c._sslobj: + socket_c.close() + return # in case the SSL mode is enabled, "patches" the socket # object with an extra pending reference, that is going # to be to store pending callable operations in it - if self.ssl: socket_c.pending = None + if self.ssl: + socket_c.pending = None # verifies if the socket is of type internet (either ipv4 # of ipv6), this is going to be used for conditional setting @@ -920,48 +978,47 @@ def on_socket_c(self, socket_c, address): # socket if of type internet (timeout values) socket_c.setblocking(0) socket_c.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) - if is_inet: socket_c.setsockopt( - socket.IPPROTO_TCP, - socket.TCP_NODELAY, - 1 - ) - if self.receive_buffer_c: socket_c.setsockopt( - socket.SOL_SOCKET, - socket.SO_RCVBUF, - self.receive_buffer_c - ) - if self.send_buffer_c: socket_c.setsockopt( - socket.SOL_SOCKET, - socket.SO_SNDBUF, - self.send_buffer_c - ) + if is_inet: + socket_c.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + if self.receive_buffer_c: + socket_c.setsockopt( + socket.SOL_SOCKET, socket.SO_RCVBUF, self.receive_buffer_c + ) + if self.send_buffer_c: + socket_c.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, self.send_buffer_c) # the process creation is considered completed and a new # connection is created for it and opened, from this time # on a new connection is considered accepted/created for server - connection = self.build_connection(socket_c, address, ssl = self.ssl) + connection = self.build_connection(socket_c, address, ssl=self.ssl) connection.open() # registers the SSL handshake method as a starter method # for the connection, so that the handshake is properly # performed on the initial stage of the connection (as expected) - if self.ssl: connection.add_starter(self._ssl_handshake) + if self.ssl: + connection.add_starter(self._ssl_handshake) # runs the initial try for the handshaking process, note that # this is an async process and further tries to the handshake # may come after this one (async operation) in case an exception # is raises the connection is closed (avoids possible errors) - try: connection.run_starter() - except Exception: connection.close(); raise + try: + connection.run_starter() + except Exception: + connection.close() + raise # in case there's extraneous data pending to be read from the # current connection's internal receive buffer it must be properly # handled on the risk of blocking the newly created connection - if connection.is_pending_data(): self.on_read(connection.socket) + if connection.is_pending_data(): + self.on_read(connection.socket) def on_socket_d(self, socket_c): connection = self.connections_m.get(socket_c, None) - if not connection: return + if not connection: + return def _ssl_handshake(self, connection): Server._ssl_handshake(self, connection) @@ -969,7 +1026,8 @@ def _ssl_handshake(self, connection): # verifies if the socket still has finished the SSL handshaking # process (by verifying the appropriate flag) and then if that's # not the case returns immediately (nothing done) - if not connection.ssl_handshake: return + if not connection.ssl_handshake: + return # prints a debug information notifying the developer about # the finishing of the handshaking process for the connection diff --git a/src/netius/base/stream.py b/src/netius/base/stream.py index 48ae3e847..c0d01f4b7 100644 --- a/src/netius/base/stream.py +++ b/src/netius/base/stream.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -51,6 +42,7 @@ """ The pending status used for transient states (eg: created) connections under this state must be used carefully """ + class Stream(observer.Observable): """ Abstract stream class responsible for the representation of @@ -65,7 +57,7 @@ class Stream(observer.Observable): allowing huge performance improvements. """ - def __init__(self, owner = None): + def __init__(self, owner=None): observer.Observable.__init__(self) self.status = PENDING self.owner = owner @@ -75,19 +67,19 @@ def reset(self): pass def open(self): - if self.status == OPEN: return + if self.status == OPEN: + return self.status = OPEN self.connection.owner.on_stream_c(self) def close(self): - if self.status == CLOSED: return + if self.status == CLOSED: + return self.status = CLOSED self.connection.owner.on_stream_d(self) - def info_dict(self, full = False): - info = dict( - status = self.status - ) + def info_dict(self, full=False): + info = dict(status=self.status) return info def is_open(self): diff --git a/src/netius/base/tls.py b/src/netius/base/tls.py index f98b3f155..3f6e38f57 100644 --- a/src/netius/base/tls.py +++ b/src/netius/base/tls.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -45,21 +36,22 @@ from . import config from . import errors -def fingerprint(certificate, hash = "sha1"): + +def fingerprint(certificate, hash="sha1"): digest = hashlib.new(hash, certificate) return digest.hexdigest() -def match_fingerprint(certificate, exp_fingerprint, hash = "sha1"): - cert_fingerprint = fingerprint(certificate, hash = hash) - if cert_fingerprint == exp_fingerprint: return + +def match_fingerprint(certificate, exp_fingerprint, hash="sha1"): + cert_fingerprint = fingerprint(certificate, hash=hash) + if cert_fingerprint == exp_fingerprint: + return if config._is_devel(): - extra = ", expected '%s' got '%s'" %\ - (exp_fingerprint, cert_fingerprint) + extra = ", expected '%s' got '%s'" % (exp_fingerprint, cert_fingerprint) else: extra = "" - raise errors.SecurityError( - "Missmatch in certificate fingerprint" + extra - ) + raise errors.SecurityError("Missmatch in certificate fingerprint" + extra) + def match_hostname(certificate, hostname): if hasattr(ssl, "match_hostname"): @@ -69,40 +61,45 @@ def match_hostname(certificate, hostname): subject_alt_name = certificate.get("subjectAltName", ()) for key, value in subject_alt_name: - if not key == "DNS": continue - if dnsname_match(value, hostname): return + if not key == "DNS": + continue + if dnsname_match(value, hostname): + return dns_names.append(value) if not dns_names: for subject in certificate.get("subject", ()): for key, value in subject: - if not key == "commonName": continue - if dnsname_match(value, hostname): return + if not key == "commonName": + continue + if dnsname_match(value, hostname): + return dns_names.append(value) if len(dns_names) > 1: raise errors.SecurityError( - "Hostname %s doesn't match either of %s" %\ - (hostname, ", ".join(map(str, dns_names))) + "Hostname %s doesn't match either of %s" + % (hostname, ", ".join(map(str, dns_names))) ) elif len(dns_names) == 1: raise errors.SecurityError( - "Hostname %s doesn't match %s" %\ - (hostname, dns_names[0]) + "Hostname %s doesn't match %s" % (hostname, dns_names[0]) ) else: raise errors.SecurityError( "No appropriate commonName or subjectAltName fields were found" ) -def dnsname_match(domain, hostname, max_wildcards = 1): + +def dnsname_match(domain, hostname, max_wildcards=1): # creates the initial list of pats that are going to be used in # the final match operation for wildcard matching pats = [] # in case no valid domain is passed an invalid result is returned # immediately indicating that no match was possible - if not domain: return False + if not domain: + return False # splits the provided domain value around its components, taking # into account the typical dot separator @@ -114,9 +111,10 @@ def dnsname_match(domain, hostname, max_wildcards = 1): # base value for discovery in case this value overflow the maximum # number of wildcards allowed raises an error wildcards = base.count("*") - if wildcards > max_wildcards: raise errors.SecurityError( - "Too many wildcards in certificate DNS name: " + str(domain) - ) + if wildcards > max_wildcards: + raise errors.SecurityError( + "Too many wildcards in certificate DNS name: " + str(domain) + ) # in case there are no wildcards in the domain name runs the # "normal" hostname validation process against the domain name @@ -136,18 +134,23 @@ def dnsname_match(domain, hostname, max_wildcards = 1): pat = re.compile(r"\A" + r"\.".join(pats) + r"\Z", re.IGNORECASE) return True if pat.match(hostname) else False -def dump_certificate(certificate, certificate_binary, name = None): + +def dump_certificate(certificate, certificate_binary, name=None): # runs some pre-validation operations so that the dump parameters # that are required are considered valid - if not certificate: return - if not certificate_binary: return + if not certificate: + return + if not certificate_binary: + return # tries to retrieve the main subject name from the subject # alternative names, there may be no value and if that's the # case a default value is used instead subject_alt_name = certificate.get("subjectAltName", ()) - if subject_alt_name: subject_name = subject_alt_name[0][1] - else: subject_name = "certificate" + if subject_alt_name: + subject_name = subject_alt_name[0][1] + else: + subject_name = "certificate" # "calculates" the final name for the certificate, taking # into account the provided parameter and subject name, the @@ -159,10 +162,13 @@ def dump_certificate(certificate, certificate_binary, name = None): # stored) and creates such directory (if required) ssl_path = config.conf("SSL_PATH", "/tmp/ssl") file_path = os.path.join(ssl_path, file_name) - if not os.path.exists(ssl_path): os.makedirs(ssl_path) + if not os.path.exists(ssl_path): + os.makedirs(ssl_path) # opens the file for writing and then dumps the certificate # binary information into the file, closing it afterwards file = open(file_path, "wb") - try: file.write(certificate_binary) - finally: file.close() + try: + file.write(certificate_binary) + finally: + file.close() diff --git a/src/netius/base/transport.py b/src/netius/base/transport.py index 187c7fecd..8a3aaa5c7 100644 --- a/src/netius/base/transport.py +++ b/src/netius/base/transport.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2017 Hive Solutions Lda." """ The copyright for the module """ @@ -40,6 +31,7 @@ from . import errors from . import observer + class Transport(observer.Observable): """ Decorator class to be used to add the functionality of a @@ -53,13 +45,14 @@ class Transport(observer.Observable): compatible interface. """ - def __init__(self, loop, connection, open = True): + def __init__(self, loop, connection, open=True): self._loop = loop self._connection = connection self._protocol = None self._extra_dict = None self._exhausted = False - if open: self.open() + if open: + self.open() def open(self): self.set_handlers() @@ -69,12 +62,15 @@ def open(self): def close(self): # in case the current transport is already closed or in the # process of closing returns immediately (avoids duplication) - if self.is_closing(): return + if self.is_closing(): + return # in case there's a connection object set schedules its closing # otherwise unsets the protocol object immediately - if self._connection: self._connection.close(flush = True) - else: self._protocol = None + if self._connection: + self._connection.close(flush=True) + else: + self._protocol = None # removes the reference to the underlying connection object # and unsets the exhausted flag (reset of values) @@ -84,14 +80,17 @@ def close(self): def abort(self): # in case the current transport is already closed or in the # process of closing returns immediately (avoids duplication) - if self.is_closing(): return + if self.is_closing(): + return # in case there's a connection set runs the (forced) close # operation so that no more interaction exists otherwise # unsets the protocol (notice that if the connection exists # the close operation will trigger the close protocol callback) - if self._connection: self._connection.close() - else: self._protocol = None + if self._connection: + self._connection.close() + else: + self._protocol = None # unsets the connection object (as it's no longer eligible # to be used) and unsets the current transport for exhausted @@ -102,44 +101,45 @@ def abort(self): def write(self, data): # verifies if the current connection is closing or in the process # of closing and if that's the case returns immediately (graceful) - if self.is_closing(): return + if self.is_closing(): + return # runs the send operation on the underlying (and concrete) # connection object, notice that the delay flag is unset so # that the send flushing operation runs immediately (to provide # behaviour level compatibility with the asyncio library) - self._connection.send(data, delay = False) + self._connection.send(data, delay=False) - def sendto(self, data, addr = None): + def sendto(self, data, addr=None): # verifies if the current connection is closing or in the process # of closing and if that's the case returns immediately (graceful) - if self.is_closing(): return + if self.is_closing(): + return # runs the send operation on the underlying (and concrete) # connection object, notice that the delay flag is unset so # that the send flushing operation runs immediately (to provide # behaviour level compatibility with the asyncio library) - self._connection.send(data, address = addr, delay = False) + self._connection.send(data, address=addr, delay=False) - def get_extra_info(self, name, default = None): + def get_extra_info(self, name, default=None): callable = self._extra_dict.get(name, None) - if callable: return callable() - else: return default + if callable: + return callable() + else: + return default def get_write_buffer_size(self): return self._connection.pending_s def get_write_buffer_limits(self): - return ( - self._connection.min_pending, - self._connection.max_pending - ) + return (self._connection.min_pending, self._connection.max_pending) def set_handlers(self): self._connection.bind("pend", self._buffer_touched) self._connection.bind("unpend", self._buffer_touched) - def set_write_buffer_limits(self, high = None, low = None): + def set_write_buffer_limits(self, high=None, low=None): """ Sets the write buffer limits in the underlying connection object using the provided values. @@ -157,9 +157,12 @@ def set_write_buffer_limits(self, high = None, low = None): """ if high is None: - if low == None: high = 65536 - else: high = 4 * low - if low == None: low = high // 4 + if low == None: + high = 65536 + else: + high = 4 * low + if low == None: + low = high // 4 if not high >= low >= 0: raise errors.RuntimeError("High must be larger than low") @@ -168,21 +171,21 @@ def set_write_buffer_limits(self, high = None, low = None): def set_extra_dict(self): self._extra_dict = dict( - socket = lambda: self._connection.socket, - peername = lambda: self._connection.socket.getpeername(), - sockname = lambda: self._connection.socket.getsockname(), - compression = lambda: self._connection.socket.compression(), - cipher = lambda: self._connection.socket.cipher(), - peercert = lambda: self._connection.socket.getpeercert(), - sslcontext = lambda: self._connection.socket.context, - ssl_object = lambda: self._connection.socket + socket=lambda: self._connection.socket, + peername=lambda: self._connection.socket.getpeername(), + sockname=lambda: self._connection.socket.getsockname(), + compression=lambda: self._connection.socket.compression(), + cipher=lambda: self._connection.socket.cipher(), + peercert=lambda: self._connection.socket.getpeercert(), + sslcontext=lambda: self._connection.socket.context, + ssl_object=lambda: self._connection.socket, ) def get_protocol(self): return self._protocol def set_protocol(self, protocol): - self._set_protocol(protocol, mark = False) + self._set_protocol(protocol, mark=False) def is_closing(self): """ @@ -198,8 +201,10 @@ def is_closing(self): it's also considered to be closing. """ - if not self._connection: return True - if self._connection.is_closed(): return True + if not self._connection: + return True + if self._connection.is_closed(): + return True return False def _on_data(self, connection, data): @@ -216,24 +221,28 @@ def _set_binds(self): self._connection.bind("data", self._on_data) self._connection.bind("close", self._on_close) - def _set_protocol(self, protocol, mark = True): + def _set_protocol(self, protocol, mark=True): self._protocol = protocol - if mark: self._protocol.connection_made(self) + if mark: + self._protocol.connection_made(self) def _buffer_touched(self, connection): self._handle_flow() def _handle_flow(self): - if not self._connection: return + if not self._connection: + return if self._exhausted: is_restored = self._connection.is_restored() - if not is_restored: return + if not is_restored: + return self._exhausted = False self._protocol.resume_writing() else: is_exhausted = self._connection.is_exhausted() - if not is_exhausted: return + if not is_exhausted: + return self._exhausted = True self._protocol.pause_writing() @@ -279,7 +288,8 @@ def _call_soon(self, callback, *args): self._loop.call_soon(callback, *args) else: callable = lambda: callback(*args) - self._loop.delay(callable, immediately = True) + self._loop.delay(callable, immediately=True) + class TransportDatagram(Transport): """ @@ -297,6 +307,7 @@ def _on_data(self, connection, data): def _on_close(self, connection): self._cleanup() + class TransportStream(Transport): """ Abstract class to be used when creating a stream based diff --git a/src/netius/base/util.py b/src/netius/base/util.py index 8dc28d401..2a1900b16 100644 --- a/src/netius/base/util.py +++ b/src/netius/base/util.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -50,7 +41,8 @@ upper case letter regex that will provide a way of putting the underscore in the middle of the transition """ -def camel_to_underscore(camel, separator = "_"): + +def camel_to_underscore(camel, separator="_"): """ Converts the provided camel cased based value into a normalized underscore based string. @@ -74,7 +66,8 @@ def camel_to_underscore(camel, separator = "_"): value = value.lower() return value -def verify(condition, message = None, exception = None): + +def verify(condition, message=None, exception=None): """ Ensures that the requested condition returns a valid value and if that's no the case an exception raised breaking the @@ -92,6 +85,7 @@ def verify(condition, message = None, exception = None): verification operation fails. """ - if condition: return + if condition: + return exception = exception or errors.AssertionError raise exception(message or "Assertion Error") diff --git a/src/netius/clients/__init__.py b/src/netius/clients/__init__.py index 3e8510e3c..066f9fbf0 100644 --- a/src/netius/clients/__init__.py +++ b/src/netius/clients/__init__.py @@ -19,15 +19,6 @@ # You should have received a copy of the Apache License along with # Hive Netius System. If not, see . -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ diff --git a/src/netius/clients/apn.py b/src/netius/clients/apn.py index 88e0ec3aa..aecbd4ddf 100644 --- a/src/netius/clients/apn.py +++ b/src/netius/clients/apn.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -43,6 +34,7 @@ import netius + class APNProtocol(netius.StreamProtocol): """ Protocol class that defines the interface to operate @@ -86,24 +78,18 @@ def connection_made(self, transport): self.send_notification( self.token, self.message, - sound = self.sound, - badge = self.badge, - close = self._close + sound=self.sound, + badge=self.badge, + close=self._close, ) - def send_notification( - self, - token, - message, - sound = "default", - badge = 0, - close = False - ): + def send_notification(self, token, message, sound="default", badge=0, close=False): # creates the callback handler that closes the current # client infra-structure after sending, this will close # the connection using a graceful approach to avoid any # of the typical problems with the connection shutdown - def callback(transport): self.close() + def callback(transport): + self.close() # converts the current token (in hexadecimal) to a set # of binary string elements and uses that value to get @@ -114,20 +100,15 @@ def callback(transport): self.close() # creates the message structure using with the # message (string) as the alert and then converts # it into a JSON format (payload) - message_s = dict( - aps = dict( - alert = message, - sound = sound, - badge = badge - ) - ) + message_s = dict(aps=dict(alert=message, sound=sound, badge=badge)) payload = json.dumps(message_s) # verifies if the resulting payload object is unicode based # and in case it is encodes it into a string representation # so that it may be used for the packing of structure is_unicode = netius.legacy.is_unicode(payload) - if is_unicode: payload = payload.encode("utf-8") + if is_unicode: + payload = payload.encode("utf-8") # sets the command with the zero value (simplified) # then calculates the token and payload lengths @@ -140,20 +121,22 @@ def callback(transport): self.close() # applies the various components of the message and packs # them according to the generated template template = "!BH%dsH%ds" % (token_length, payload_length) - message = struct.pack(template, command, token_length, token, payload_length, payload) + message = struct.pack( + template, command, token_length, token, payload_length, payload + ) callback = callback if close else None - self.send(message, callback = callback) + self.send(message, callback=callback) def set( self, token, message, - sound = "default", - badge = 0, - sandbox = True, - key_file = None, - cer_file = None, - _close = True + sound="default", + badge=0, + sandbox=True, + key_file=None, + cer_file=None, + _close=True, ): self.token = token self.message = message @@ -164,7 +147,7 @@ def set( self.cer_file = cer_file self._close = _close - def notify(self, token, loop = None, **kwargs): + def notify(self, token, loop=None, **kwargs): # retrieves the intance's parent class object to be # used to global class operations cls = self.__class__ @@ -193,40 +176,43 @@ def notify(self, token, loop = None, **kwargs): self.set( token, message, - sound = sound, - badge = badge, - sandbox = sandbox, - key_file = key_file, - cer_file = cer_file, - _close = _close + sound=sound, + badge=badge, + sandbox=sandbox, + key_file=key_file, + cer_file=cer_file, + _close=_close, ) # establishes the connection to the target host and port # and using the provided key and certificate files loop = netius.connect_stream( lambda: self, - host = self.host, - port = self.port, - ssl = True, - key_file = key_file, - cer_file = cer_file, - loop = loop + host=self.host, + port=self.port, + ssl=True, + key_file=key_file, + cer_file=cer_file, + loop=loop, ) # returns both the current associated loop and the current # instance to the protocol defined by the current instance return loop, self + class APNClient(netius.ClientAgent): protocol = APNProtocol @classmethod - def notify_s(cls, token, loop = None, **kwargs): + def notify_s(cls, token, loop=None, **kwargs): protocol = cls.protocol() - return protocol.notify(token, loop = loop, **kwargs) + return protocol.notify(token, loop=loop, **kwargs) + if __name__ == "__main__": + def on_finish(protocol): netius.compat_loop(loop).stop() @@ -234,9 +220,7 @@ def on_finish(protocol): key_file = netius.conf("APN_KEY_FILE", None) cer_file = netius.conf("APN_CER_FILE", None) - loop, protocol = APNClient.notify_s( - token, key_file = key_file, cer_file = cer_file - ) + loop, protocol = APNClient.notify_s(token, key_file=key_file, cer_file=cer_file) protocol.bind("finish", on_finish) diff --git a/src/netius/clients/dht.py b/src/netius/clients/dht.py index ac55a7eaf..6048a76c7 100644 --- a/src/netius/clients/dht.py +++ b/src/netius/clients/dht.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -41,24 +32,25 @@ import netius.common + class DHTRequest(netius.Request): def __init__( self, peer_id, - host = "127.0.0.1", - port = 9090, - type = "ping", - callback = None, + host="127.0.0.1", + port=9090, + type="ping", + callback=None, *args, **kwargs ): - netius.Request.__init__(self, callback = callback) + netius.Request.__init__(self, callback=callback) self.peer_id = peer_id self.host = host self.port = port self.type = type - self.args = args, + self.args = (args,) self.kwargs = kwargs self._peer_id = self._get_peer_id() @@ -72,36 +64,25 @@ def request(self): raise netius.ParserError("Invalid type '%s'" % self.type) method = getattr(self, self.type) query = method() - request = dict( - t = str(self.id), - y = "q", - q = self.type, - a = query - ) + request = dict(t=str(self.id), y="q", q=self.type, a=query) return netius.common.bencode(request) def ping(self): - return dict(id = self._peer_id) + return dict(id=self._peer_id) def find_node(self): - return dict( - id = self._peer_id, - target = self.kwargs["target"] - ) + return dict(id=self._peer_id, target=self.kwargs["target"]) def get_peers(self): - return dict( - id = self._peer_id, - info_hash = self.kwargs["info_hash"] - ) + return dict(id=self._peer_id, info_hash=self.kwargs["info_hash"]) def announce_peer(self): return dict( - id = self._peer_id, - implied_port = self.kwargs["implied_port"], - info_hash = self.kwargs["info_hash"], - port = self.kwargs["port"], - token = self.kwargs["token"] + id=self._peer_id, + implied_port=self.kwargs["implied_port"], + info_hash=self.kwargs["info_hash"], + port=self.kwargs["port"], + token=self.kwargs["token"], ) def _get_peer_id(self): @@ -109,6 +90,7 @@ def _get_peer_id(self): peer_id = netius.legacy.bytes(self.peer_id) return peer_id + contact + class DHTResponse(netius.Response): def __init__(self, data): @@ -131,6 +113,7 @@ def is_error(self): def is_response(self): return self.info("r", True) + class DHTClient(netius.DatagramClient): """ Implementation of the DHT (Distributed hash table) for the torrent @@ -143,32 +126,26 @@ class DHTClient(netius.DatagramClient): """ def ping(self, host, port, peer_id, *args, **kwargs): - return self.query(type = "ping", *args, **kwargs) + return self.query(type="ping", *args, **kwargs) def find_node(self, *args, **kwargs): - return self.query(type = "find_node", *args, **kwargs) + return self.query(type="find_node", *args, **kwargs) def get_peers(self, *args, **kwargs): - return self.query(type = "get_peers", *args, **kwargs) + return self.query(type="get_peers", *args, **kwargs) def query( self, - host = "127.0.0.1", - port = 9090, - peer_id = None, - type = "ping", - callback = None, + host="127.0.0.1", + port=9090, + peer_id=None, + type="ping", + callback=None, *args, **kwargs ): request = DHTRequest( - peer_id, - host = host, - port = port, - type = type, - callback = callback, - *args, - **kwargs + peer_id, host=host, port=port, type=type, callback=callback, *args, **kwargs ) data = request.request() @@ -190,5 +167,6 @@ def on_data_dht(self, address, response): response.request = request self.remove_request(request) - if not request.callback: return + if not request.callback: + return request.callback(response) diff --git a/src/netius/clients/dns.py b/src/netius/clients/dns.py index ca275f946..d17a146d5 100644 --- a/src/netius/clients/dns.py +++ b/src/netius/clients/dns.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -59,36 +50,35 @@ DNS_RD = 0x01 DNS_TYPES = dict( - A = 0x01, - NS = 0x02, - MD = 0x03, - MF = 0x04, - CNAME = 0x05, - SOA = 0x06, - MB = 0x07, - MG = 0x08, - MR = 0x09, - NULL = 0x0a, - WKS = 0x0b, - PTR = 0x0c, - HINFO = 0x0d, - MINFO = 0x0e, - MX = 0x0f, - TXT = 0x10, - AAAA = 0x1c + A=0x01, + NS=0x02, + MD=0x03, + MF=0x04, + CNAME=0x05, + SOA=0x06, + MB=0x07, + MG=0x08, + MR=0x09, + NULL=0x0A, + WKS=0x0B, + PTR=0x0C, + HINFO=0x0D, + MINFO=0x0E, + MX=0x0F, + TXT=0x10, + AAAA=0x1C, ) -DNS_CLASSES = dict( - IN = 0x01 -) +DNS_CLASSES = dict(IN=0x01) DNS_TYPES_R = dict(zip(DNS_TYPES.values(), DNS_TYPES.keys())) DNS_CLASSES_R = dict(zip(DNS_CLASSES.values(), DNS_CLASSES.keys())) + class DNSRequest(netius.Request): - def __init__(self, name, type = "a", cls = "in", callback = None): - netius.Request.__init__(self, callback = callback) + def __init__(self, name, type="a", cls="in", callback=None): + netius.Request.__init__(self, callback=callback) self.name = name self.type = type self.cls = cls @@ -119,18 +109,14 @@ def request(self): data = struct.pack(format, *result) buffer.append(data) - query = self._query( - self.name, - type = self.type, - cls = self.cls - ) + query = self._query(self.name, type=self.type, cls=self.cls) buffer.append(query) data = b"".join(buffer) return data - def _query(self, name, type = "a", cls = "in"): + def _query(self, name, type="a", cls="in"): type_i = DNS_TYPES.get(type.upper(), 0x00) clsi = DNS_CLASSES.get(cls.upper(), 0x00) @@ -157,6 +143,7 @@ def _label(self, value): data = b"".join(buffer) return data + class DNSResponse(netius.Response): def parse(self): @@ -247,8 +234,10 @@ def parse_label(self, data, index): initial = data[index] initial_i = netius.legacy.ord(initial) - if initial_i == 0: index += 1; break - is_pointer = initial_i & 0xc0 + if initial_i == 0: + index += 1 + break + is_pointer = initial_i & 0xC0 if is_pointer: index, _data = self.parse_pointer(data, index) @@ -256,7 +245,7 @@ def parse_label(self, data, index): data = b".".join(buffer) return (index, data) - _data = data[index + 1:index + initial_i + 1] + _data = data[index + 1 : index + initial_i + 1] buffer.append(_data) index += initial_i + 1 @@ -265,10 +254,10 @@ def parse_label(self, data, index): return (index, data) def parse_pointer(self, data, index): - slice = data[index:index + 2] + slice = data[index : index + 2] - offset, = struct.unpack("!H", slice) - offset &= 0x3fff + (offset,) = struct.unpack("!H", slice) + offset &= 0x3FFF _index, label = self.parse_label(data, offset) @@ -287,59 +276,68 @@ def parse_ip6(self, data, index): return (index, address) def parse_byte(self, data, index): - _data = data[index:index + 1] - short, = struct.unpack("!B", _data) + _data = data[index : index + 1] + (short,) = struct.unpack("!B", _data) return (index + 1, short) def parse_short(self, data, index): - _data = data[index:index + 2] - short, = struct.unpack("!H", _data) + _data = data[index : index + 2] + (short,) = struct.unpack("!H", _data) return (index + 2, short) def parse_long(self, data, index): - _data = data[index:index + 4] - long, = struct.unpack("!L", _data) + _data = data[index : index + 4] + (long,) = struct.unpack("!L", _data) return (index + 4, long) def parse_long_long(self, data, index): - _data = data[index:index + 8] - long_long, = struct.unpack("!Q", _data) + _data = data[index : index + 8] + (long_long,) = struct.unpack("!Q", _data) return (index + 8, long_long) + class DNSProtocol(netius.DatagramProtocol): ns_file_l = None @classmethod - def ns_system(cls, type = "ip4"): - ns = cls.ns_conf(type = type) - if ns: return ns[0] - ns = cls.ns_file(type = type) - if ns: return ns[0] - ns = cls.ns_google(type = type) - if ns: return ns[0] - ns = cls.ns_cloudfare(type = type) - if ns: return ns[0] + def ns_system(cls, type="ip4"): + ns = cls.ns_conf(type=type) + if ns: + return ns[0] + ns = cls.ns_file(type=type) + if ns: + return ns[0] + ns = cls.ns_google(type=type) + if ns: + return ns[0] + ns = cls.ns_cloudfare(type=type) + if ns: + return ns[0] return None @classmethod - def ns_conf(cls, type = "ip4", force = False): - ns = netius.conf("NAMESERVERS_%s" % type.upper(), [], cast = list) - if ns: return ns - ns = netius.conf("NAMESERVERS", [], cast = list) - if ns: return ns + def ns_conf(cls, type="ip4", force=False): + ns = netius.conf("NAMESERVERS_%s" % type.upper(), [], cast=list) + if ns: + return ns + ns = netius.conf("NAMESERVERS", [], cast=list) + if ns: + return ns return [] @classmethod - def ns_file(cls, type = "ip4", force = False): + def ns_file(cls, type="ip4", force=False): # verifies if the list value for the file based name server # retrieval value is defined and if that's the case and the # force flag is not set returns it immediately - if not cls.ns_file_l == None and not force: return cls.ns_file_l + if not cls.ns_file_l == None and not force: + return cls.ns_file_l # verifies if the resolve file exists and if it does not returns # immediately indicating the impossibility to resolve the value - if not os.path.exists("/etc/resolv.conf"): return None + if not os.path.exists("/etc/resolv.conf"): + return None # retrieves the reference to the function that is going to validate # if the provided name server complies with the proper (address) type @@ -348,8 +346,10 @@ def ns_file(cls, type = "ip4", force = False): # opens the resolve file and reads the complete set of contents # from it, closing the file afterwards file = open("/etc/resolv.conf", "rb") - try: data = file.read() - finally: file.close() + try: + data = file.read() + finally: + file.close() # starts the list that is going to store the various name server # values, this is going to be populated with the file contents @@ -359,12 +359,14 @@ def ns_file(cls, type = "ip4", force = False): # to find the name servers defined in it to be added to the list for line in data.split(b"\n"): line = line.strip() - if not line.startswith(b"nameserver"): continue + if not line.startswith(b"nameserver"): + continue _header, ns = line.split(b" ", 1) ns = ns.strip() ns = netius.legacy.str(ns) is_valid = validator(ns) - if not is_valid: continue + if not is_valid: + continue cls.ns_file_l.append(ns) # returns the final value of the list of name servers loaded from @@ -372,24 +374,22 @@ def ns_file(cls, type = "ip4", force = False): return cls.ns_file_l @classmethod - def ns_google(cls, type = "ip4"): - if type == "ip4": return ["8.8.8.8", "8.8.4.4"] - if type == "ip6": return [ - "2001:4860:4860::8888", - "2001:4860:4860::8844" - ] + def ns_google(cls, type="ip4"): + if type == "ip4": + return ["8.8.8.8", "8.8.4.4"] + if type == "ip6": + return ["2001:4860:4860::8888", "2001:4860:4860::8844"] return [] @classmethod - def ns_cloudfare(cls, type = "ip4"): - if type == "ip4": return ["1.1.1.1", "1.0.0.1"] - if type == "ip6": return [ - "2606:4700:4700::1111", - "2606:4700:4700::1001" - ] + def ns_cloudfare(cls, type="ip4"): + if type == "ip4": + return ["1.1.1.1", "1.0.0.1"] + if type == "ip6": + return ["2606:4700:4700::1111", "2606:4700:4700::1001"] return [] - def query(self, name, type = "a", cls = "in", ns = None, callback = None): + def query(self, name, type="a", cls="in", ns=None, callback=None): # retrieves the reference to the class associated with the # current instance to be used to access class operations _cls = self.__class__ @@ -402,12 +402,7 @@ def query(self, name, type = "a", cls = "in", ns = None, callback = None): # creates a new DNS request object describing the query that was # just sent and then generates the request stream code that is # going to be used for sending the request through network - request = DNSRequest( - name, - type = type, - cls = cls, - callback = callback - ) + request = DNSRequest(name, type=type, cls=cls, callback=callback) data = request.request() # prints some debug information about the DNS query that is going @@ -442,7 +437,8 @@ def on_data_dns(self, address, response): # response and in case none is found returns immediately as # there's nothing remaining to be done request = self.get_request(response) - if not request: return + if not request: + return # removes the request being handled from the current request # structures so that a callback is no longer answered @@ -451,47 +447,34 @@ def on_data_dns(self, address, response): # in case no callback is not defined for the request returns # immediately as there's nothing else remaining to be done, # otherwise calls the proper callback with the response - if not request.callback: return + if not request.callback: + return request.callback(response) + class DNSClient(netius.ClientAgent): protocol = DNSProtocol @classmethod - def query_s( - cls, - name, - type = "a", - cls_ = "in", - ns = None, - callback = None, - loop = None - ): + def query_s(cls, name, type="a", cls_="in", ns=None, callback=None, loop=None): ns = ns or cls.protocol.ns_system() address = (ns, 53) protocol = cls.protocol() def on_connect(result): _transport, protocol = result - protocol.query( - name, - type = type, - cls = cls_, - ns = ns, - callback = callback - ) + protocol.query(name, type=type, cls=cls_, ns=ns, callback=callback) loop = netius.build_datagram( - lambda: protocol, - callback = on_connect, - loop = loop, - remote_addr = address + lambda: protocol, callback=on_connect, loop=loop, remote_addr=address ) return loop, protocol + if __name__ == "__main__": + def handler(response): # closes the current protocol to correctly close # all of the underlying structures @@ -505,7 +488,9 @@ def handler(response): # in case the provided response is not valid # a timeout message is printed to indicate the # problem with the resolution - if not response: print("Timeout in resolution"); return + if not response: + print("Timeout in resolution") + return # unpacks the complete set of contents from # the various answers so that only the address @@ -529,12 +514,7 @@ def handler(response): # runs the static version of a DNS query, note that # the daemon flag is unset so that the global client # runs in foreground avoiding the exit of the process - loop, protocol = DNSClient.query_s( - name, - type = type, - ns = ns, - callback = handler - ) + loop, protocol = DNSClient.query_s(name, type=type, ns=ns, callback=handler) loop.run_forever() loop.close() else: diff --git a/src/netius/clients/http.py b/src/netius/clients/http.py index 69690beef..80a147e7a 100644 --- a/src/netius/clients/http.py +++ b/src/netius/clients/http.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -44,14 +35,19 @@ import netius.common -from netius.common import PLAIN_ENCODING, CHUNKED_ENCODING,\ - GZIP_ENCODING, DEFLATE_ENCODING +from netius.common import ( + PLAIN_ENCODING, + CHUNKED_ENCODING, + GZIP_ENCODING, + DEFLATE_ENCODING, +) Z_PARTIAL_FLUSH = 1 """ The zlib constant value representing the partial flush of the current zlib stream, this value has to be defined locally as it is not defines under the zlib module """ + class HTTPProtocol(netius.StreamProtocol): """ Implementation of the HTTP protocol to be used by a client @@ -59,9 +55,7 @@ class HTTPProtocol(netius.StreamProtocol): responses. """ - BASE_HEADERS = { - "user-agent" : netius.IDENTIFIER - } + BASE_HEADERS = {"user-agent": netius.IDENTIFIER} """ The map containing the complete set of headers that are meant to be applied to all the requests """ @@ -69,24 +63,24 @@ def __init__( self, method, url, - params = None, - headers = None, - data = None, - version = "HTTP/1.1", - encoding = PLAIN_ENCODING, - encodings = "gzip, deflate", - safe = False, - request = False, - asynchronous = True, - timeout = None, - use_file = False, - callback = None, - on_init = None, - on_open = None, - on_close = None, - on_headers = None, - on_data = None, - on_result = None, + params=None, + headers=None, + data=None, + version="HTTP/1.1", + encoding=PLAIN_ENCODING, + encodings="gzip, deflate", + safe=False, + request=False, + asynchronous=True, + timeout=None, + use_file=False, + callback=None, + on_init=None, + on_open=None, + on_close=None, + on_headers=None, + on_data=None, + on_result=None, *args, **kwargs ): @@ -95,24 +89,24 @@ def __init__( self.set( method, url, - params = params, - headers = headers, - data = data, - version = version, - encoding = encoding, - encodings = encodings, - safe = safe, - request = request, - asynchronous = asynchronous, - timeout = timeout, - use_file = use_file, - callback = callback, - on_init = on_init, - on_open = on_open, - on_close = on_close, - on_headers = on_headers, - on_data = on_data, - on_result = on_result + params=params, + headers=headers, + data=data, + version=version, + encoding=encoding, + encodings=encodings, + safe=safe, + request=request, + asynchronous=asynchronous, + timeout=timeout, + use_file=use_file, + callback=callback, + on_init=on_init, + on_open=on_open, + on_close=on_close, + on_headers=on_headers, + on_data=on_data, + on_result=on_result, ) @classmethod @@ -125,27 +119,28 @@ def key_g(cls, url): @classmethod def decode_gzip(cls, data): - if not data: return data + if not data: + return data return zlib.decompress(data, zlib.MAX_WBITS | 16) @classmethod def decode_deflate(cls, data): - if not data: return data - try: return zlib.decompress(data) - except Exception: return zlib.decompress(data, -zlib.MAX_WBITS) + if not data: + return data + try: + return zlib.decompress(data) + except Exception: + return zlib.decompress(data, -zlib.MAX_WBITS) @classmethod def decode_zlib_file( - cls, - input, - output, - buffer_size = 16384, - wbits = zlib.MAX_WBITS | 16 + cls, input, output, buffer_size=16384, wbits=zlib.MAX_WBITS | 16 ): decompressor = zlib.decompressobj(wbits) while True: data = input.read(buffer_size) - if not data: break + if not data: + break raw_data = decompressor.decompress(data) output.write(raw_data) raw_data = decompressor.flush() @@ -154,42 +149,26 @@ def decode_zlib_file( @classmethod def decode_gzip_file( - cls, - input, - output, - buffer_size = 16384, - wbits = zlib.MAX_WBITS | 16 + cls, input, output, buffer_size=16384, wbits=zlib.MAX_WBITS | 16 ): - return cls.decode_zlib_file( - input, - output, - buffer_size = buffer_size, - wbits = wbits - ) + return cls.decode_zlib_file(input, output, buffer_size=buffer_size, wbits=wbits) @classmethod def decode_deflate_file( - cls, - input, - output, - buffer_size = 16384, - wbits = -zlib.MAX_WBITS + cls, input, output, buffer_size=16384, wbits=-zlib.MAX_WBITS ): - return cls.decode_zlib_file( - input, - output, - buffer_size = buffer_size, - wbits = wbits - ) + return cls.decode_zlib_file(input, output, buffer_size=buffer_size, wbits=wbits) @classmethod - def set_request(cls, parser, buffer, request = None): - if request == None: request = dict() + def set_request(cls, parser, buffer, request=None): + if request == None: + request = dict() headers = parser.get_headers() data = b"".join(buffer) encoding = headers.get("Content-Encoding", None) decoder = getattr(cls, "decode_%s" % encoding) if encoding else None - if decoder and data: data = decoder(data) + if decoder and data: + data = decoder(data) request["code"] = parser.code request["status"] = parser.status request["headers"] = headers @@ -198,16 +177,12 @@ def set_request(cls, parser, buffer, request = None): @classmethod def set_request_file( - cls, - parser, - input, - request = None, - output = None, - buffer_size = 16384 + cls, parser, input, request=None, output=None, buffer_size=16384 ): # verifies if a request object has been passes to the current # method and if that's not the case creates a new one (as a map) - if request == None: request = dict() + if request == None: + request = dict() # retrieves the complete set of headers and tries discover the # encoding of it and the associated decoder (if any) @@ -218,14 +193,11 @@ def set_request_file( # in case there's a decoder and an input (file) then runs the decoding # process setting the data as the resulting (decoded object) if decoder and input: - if output == None: output = tempfile.NamedTemporaryFile(mode = "w+b") + if output == None: + output = tempfile.NamedTemporaryFile(mode="w+b") input.seek(0) try: - data = decoder( - input, - output, - buffer_size = buffer_size - ) + data = decoder(input, output, buffer_size=buffer_size) finally: input.close() @@ -251,9 +223,11 @@ def set_request_file( return request @classmethod - def set_error(cls, error, message = None, request = None, force = False): - if request == None: request = dict() - if "error" in request and not force: return + def set_error(cls, error, message=None, request=None, force=False): + if request == None: + request = dict() + if "error" in request and not force: + return request["error"] = error request["message"] = message return request @@ -263,7 +237,7 @@ def open_c(self, *args, **kwargs): # creates a new HTTP parser instance and set the correct event # handlers so that the data parsing is properly handled - self.parser = netius.common.HTTPParser(self, type = netius.common.RESPONSE) + self.parser = netius.common.HTTPParser(self, type=netius.common.RESPONSE) self.parser.bind("on_data", self._on_data) self.parser.bind("on_partial", self.on_partial) self.parser.bind("on_headers", self.on_headers) @@ -272,27 +246,30 @@ def open_c(self, *args, **kwargs): def close_c(self, *args, **kwargs): netius.StreamProtocol.close_c(self, *args, **kwargs) - if self.parser: self.parser.destroy() - if self.parsed: self.parsed = None - if self.gzip: self._close_gzip(safe = True) - if self.gzip_c: self.gzip_c = None - - def info_dict(self, full = False): - info = netius.StreamProtocol.info_dict(self, full = full) + if self.parser: + self.parser.destroy() + if self.parsed: + self.parsed = None + if self.gzip: + self._close_gzip(safe=True) + if self.gzip_c: + self.gzip_c = None + + def info_dict(self, full=False): + info = netius.StreamProtocol.info_dict(self, full=full) info.update( - version = self.version, - method = self.method, - encoding = self.encodings, - url = self.url, - parsed = self.parsed, - host = self.host, - port = self.port, - path = self.path, - headers = self.headers - ) - if full: info.update( - parser = self.parser.info_dict() + version=self.version, + method=self.method, + encoding=self.encodings, + url=self.url, + parsed=self.parsed, + host=self.host, + port=self.port, + path=self.path, + headers=self.headers, ) + if full: + info.update(parser=self.parser.info_dict()) return info def connection_made(self, transport): @@ -306,101 +283,79 @@ def loop_set(self, loop): netius.StreamProtocol.loop_set(self, loop) self.set_dynamic() - def flush(self, force = False, callback = None): + def flush(self, force=False, callback=None): if self.current == DEFLATE_ENCODING: - self._flush_gzip(force = force, callback = callback) + self._flush_gzip(force=force, callback=callback) elif self.current == GZIP_ENCODING: - self._flush_gzip(force = force, callback = callback) + self._flush_gzip(force=force, callback=callback) elif self.current == CHUNKED_ENCODING: - self._flush_chunked(force = force, callback = callback) + self._flush_chunked(force=force, callback=callback) elif self.current == PLAIN_ENCODING: - self._flush_plain(force = force, callback = callback) + self._flush_plain(force=force, callback=callback) self.current = self.encoding def send_base( - self, - data, - stream = None, - final = True, - delay = True, - force = False, - callback = None + self, data, stream=None, final=True, delay=True, force=False, callback=None ): data = netius.legacy.bytes(data) if data else data if self.current == PLAIN_ENCODING: return self.send_plain( data, - stream = stream, - final = final, - delay = delay, - force = force, - callback = callback + stream=stream, + final=final, + delay=delay, + force=force, + callback=callback, ) elif self.current == CHUNKED_ENCODING: return self.send_chunked( data, - stream = stream, - final = final, - delay = delay, - force = force, - callback = callback + stream=stream, + final=final, + delay=delay, + force=force, + callback=callback, ) elif self.current == GZIP_ENCODING: return self.send_gzip( data, - stream = stream, - final = final, - delay = delay, - force = force, - callback = callback + stream=stream, + final=final, + delay=delay, + force=force, + callback=callback, ) elif self.current == DEFLATE_ENCODING: return self.send_gzip( data, - stream = stream, - final = final, - delay = delay, - force = force, - callback = callback + stream=stream, + final=final, + delay=delay, + force=force, + callback=callback, ) def send_plain( - self, - data, - stream = None, - final = True, - delay = True, - force = False, - callback = None + self, data, stream=None, final=True, delay=True, force=False, callback=None ): - return self.send( - data, - delay = delay, - force = force, - callback = callback - ) + return self.send(data, delay=delay, force=force, callback=callback) def send_chunked( - self, - data, - stream = None, - final = True, - delay = True, - force = False, - callback = None + self, data, stream=None, final=True, delay=True, force=False, callback=None ): # in case there's no valid data to be sent uses the plain # send method to send the empty string and returns immediately # to the caller method, to avoid any problems - if not data: return self.send_plain( - data, - stream = stream, - final = final, - delay = delay, - force = force, - callback = callback - ) + if not data: + return self.send_plain( + data, + stream=stream, + final=final, + delay=delay, + force=force, + callback=callback, + ) # creates the new list that is going to be used to store # the various parts of the chunk and then calculates the @@ -420,34 +375,35 @@ def send_chunked( buffer_s = b"".join(buffer) return self.send_plain( buffer_s, - stream = stream, - final = final, - delay = delay, - force = force, - callback = callback + stream=stream, + final=final, + delay=delay, + force=force, + callback=callback, ) def send_gzip( self, data, - stream = None, - final = True, - delay = True, - force = False, - callback = None, - level = 6 + stream=None, + final=True, + delay=True, + force=False, + callback=None, + level=6, ): # verifies if the provided data buffer is valid and in # in case it's not propagates the sending to the upper # layer (chunked sending) for proper processing - if not data: return self.send_chunked( - data, - stream = stream, - final = final, - delay = delay, - force = force, - callback = callback - ) + if not data: + return self.send_chunked( + data, + stream=stream, + final=final, + delay=delay, + force=force, + callback=callback, + ) # "calculates" if the current sending of gzip data is # the first one by verifying if the gzip object is set @@ -467,41 +423,42 @@ def send_gzip( # that in case the resulting of the compress operation # is not valid a sync flush operation is performed data_c = self.gzip.compress(data) - if not data_c: data_c = self.gzip.flush(Z_PARTIAL_FLUSH) + if not data_c: + data_c = self.gzip.flush(Z_PARTIAL_FLUSH) # sends the compressed data to the client endpoint setting # the correct callback values as requested return self.send_chunked( data_c, - stream = stream, - final = final, - delay = delay, - force = force, - callback = callback + stream=stream, + final=final, + delay=delay, + force=force, + callback=callback, ) def set( self, method, url, - params = None, - headers = None, - data = None, - version = "HTTP/1.1", - encoding = PLAIN_ENCODING, - encodings = "gzip, deflate", - safe = False, - request = False, - asynchronous = True, - timeout = None, - use_file = False, - callback = None, - on_init = None, - on_open = None, - on_close = None, - on_headers = None, - on_data = None, - on_result = None, + params=None, + headers=None, + data=None, + version="HTTP/1.1", + encoding=PLAIN_ENCODING, + encodings="gzip, deflate", + safe=False, + request=False, + asynchronous=True, + timeout=None, + use_file=False, + callback=None, + on_init=None, + on_open=None, + on_close=None, + on_headers=None, + on_data=None, + on_result=None, ): cls = self.__class__ @@ -553,7 +510,8 @@ def set( # in case there's an HTTP parser already set for the protocol runs # the reset operation so that its state is guaranteed to be clean - if self.parser: self.parser.clear() + if self.parser: + self.parser.clear() # tries to determine if the protocol response should be request # wrapped, meaning that a map based object is going to be populated @@ -565,22 +523,28 @@ def set( # are met) the protocol is called to run the wrapping operation if wrap_request: _request, on_close, on_data, callback = self.wrap_request( - use_file = use_file, - callback = callback, - on_close = on_close, - on_data = on_data, - on_result = on_result + use_file=use_file, + callback=callback, + on_close=on_close, + on_data=on_data, + on_result=on_result, ) # registers for the proper event handlers according to the # provided parameters, note that these are considered to be # the lower level infra-structure of the event handling - if on_init: self.bind("loop_set", on_init) - if on_open: self.bind("open", on_open) - if on_close: self.bind("close", on_close) - if on_headers: self.bind("headers", on_headers) - if on_data: self.bind("partial", on_data) - if callback: self.bind("message", callback) + if on_init: + self.bind("loop_set", on_init) + if on_open: + self.bind("open", on_open) + if on_close: + self.bind("close", on_close) + if on_headers: + self.bind("headers", on_headers) + if on_data: + self.bind("partial", on_data) + if callback: + self.bind("message", callback) # sets the static part of the protocol internal (no loop is required) # so that the required initials fields are properly populated @@ -610,7 +574,8 @@ def set_static(self): # adds these parameters to the end of the provided URL, these # values are commonly named get parameters query = netius.legacy.urlencode(self.params) - if query: self.url = self.url + "?" + query + if query: + self.url = self.url + "?" + query # parses the provided URL and retrieves the various parts of the # URL that are going to be used in the creation of the connection @@ -646,18 +611,17 @@ def set_dynamic(self): # the connection operation is exceeded an error is set int # the connection and the connection is properly closed def connect_timeout(): - if self.is_open(): return + if self.is_open(): + return self.request and cls.set_error( - "timeout", - message = "Timeout on connect", - request = self.request + "timeout", message="Timeout on connect", request=self.request ) self.close() # schedules a delay operation to run the timeout handler for # both connect operation (this is considered the initial # triggers for the such verifiers) - self.delay(connect_timeout, timeout = self.timeout) + self.delay(connect_timeout, timeout=self.timeout) def run_request(self): # retrieves the reference to the top level class to be used @@ -675,10 +639,14 @@ def receive_timeout(): # try to validate if the requirements for proper request # validations are defined, if any of them is not the control # full is returned immediately avoiding re-schedule of handler - if not self.request: return - if not self.is_open(): return - if self.request["code"]: return - if not id(request) == id(self.request): return + if not self.request: + return + if not self.is_open(): + return + if self.request["code"]: + return + if not id(request) == id(self.request): + return # retrieves the current time and the time of the last data # receive operation and using that calculates the delta @@ -696,24 +664,25 @@ def receive_timeout(): # receive operations is valid or there's data still pending # to be sent to the server side, and if that's the case delays # the timeout verification according to the timeout value - if not self.is_open() or delta < self.timeout or\ - not self.transport().get_write_buffer_size() == 0: - self.delay(receive_timeout, timeout = self.timeout) + if ( + not self.is_open() + or delta < self.timeout + or not self.transport().get_write_buffer_size() == 0 + ): + self.delay(receive_timeout, timeout=self.timeout) return # tries to determine the proper message that is going to be # set in the request error, this value should take into account # the current development mode flag value - if self.is_devel(): message = "Timeout on receive (received %d bytes)" % received - else: message = "Timeout on receive" + if self.is_devel(): + message = "Timeout on receive (received %d bytes)" % received + else: + message = "Timeout on receive" # sets the error information in the request so that the # request handler is properly "notified" about the error - cls.set_error( - "timeout", - message = message, - request = self.request - ) + cls.set_error("timeout", message=message, request=self.request) # closes the protocol (it's no longer considered valid) # and then verifies the various auto closing values @@ -721,11 +690,11 @@ def receive_timeout(): # sends the request effectively triggering a chain of event # that should end with the complete receiving of the response - self.send_request(callback = lambda c: self.delay( - receive_timeout, timeout = self.timeout - )) + self.send_request( + callback=lambda c: self.delay(receive_timeout, timeout=self.timeout) + ) - def send_request(self, callback = None): + def send_request(self, callback=None): method = self.method path = self.path version = self.version @@ -734,51 +703,53 @@ def send_request(self, callback = None): parsed = self.parsed safe = self.safe - if parsed.query: path += "?" + parsed.query + if parsed.query: + path += "?" + parsed.query headers = dict(headers) self._apply_base(headers) self._apply_dynamic(headers) self._apply_connection(headers) - if safe: self._headers_normalize(headers) + if safe: + self._headers_normalize(headers) buffer = [] buffer.append("%s %s %s\r\n" % (method, path, version)) for key, value in netius.legacy.iteritems(headers): key = netius.common.header_up(key) - if not isinstance(value, list): value = (value,) + if not isinstance(value, list): + value = (value,) for _value in value: _value = netius.legacy.ascii(_value) buffer.append("%s: %s\r\n" % (key, _value)) buffer.append("\r\n") buffer_data = "".join(buffer) - if data: count = self.send_plain(buffer_data, force = True) - else: count = self.send_plain(buffer_data, force = True, callback = callback) + if data: + count = self.send_plain(buffer_data, force=True) + else: + count = self.send_plain(buffer_data, force=True, callback=callback) - if not data: return count + if not data: + return count - def send_part(transport = None): + def send_part(transport=None): try: _data = next(data) except StopIteration: - if hasattr(data, "close"): data.close() + if hasattr(data, "close"): + data.close() callback and callback(transport) return - self.send_base(_data, force = True, callback = send_part) + self.send_base(_data, force=True, callback=send_part) send_part() return count def wrap_request( - self, - use_file = False, - callback = None, - on_close = None, - on_data = None, - on_result = None + self, use_file=False, callback=None, on_close=None, on_data=None, on_result=None ): """ Wraps the current set of operations for the protocol so that @@ -823,40 +794,46 @@ def wrap_request( # they may be used for the correct construction of the request # structure that is going to be send in the callback, then sets # the identifier (memory address) of the request in the connection - buffer = tempfile.NamedTemporaryFile(mode = "w+b") if use_file else [] - self.request = dict(code = None, data = None) + buffer = tempfile.NamedTemporaryFile(mode="w+b") if use_file else [] + self.request = dict(code=None, data=None) def on_close(protocol): - if _on_close: _on_close(protocol) + if _on_close: + _on_close(protocol) protocol._request = None - if self.request["code"]: return - cls.set_error( - "closed", - message = "Connection closed", - request = self.request - ) + if self.request["code"]: + return + cls.set_error("closed", message="Connection closed", request=self.request) def on_data(protocol, parser, data): - if _on_data: _on_data(protocol, parser, data) - if use_file: buffer.write(data) - else: buffer.append(data) + if _on_data: + _on_data(protocol, parser, data) + if use_file: + buffer.write(data) + else: + buffer.append(data) received = self.request.get("received", 0) self.request["received"] = received + len(data) self.request["last"] = time.time() def callback(protocol, parser, message): - if _callback: _callback(protocol, parser, message) - if use_file: cls.set_request_file(parser, buffer, request = self.request) - else: cls.set_request(parser, buffer, request = self.request) - if on_result: on_result(protocol, parser, self.request) + if _callback: + _callback(protocol, parser, message) + if use_file: + cls.set_request_file(parser, buffer, request=self.request) + else: + cls.set_request(parser, buffer, request=self.request) + if on_result: + on_result(protocol, parser, self.request) # returns the request object that is going to be properly # populated over the life-cycle of the protocol return self.request, on_close, on_data, callback - def set_headers(self, headers, normalize = True): + def set_headers(self, headers, normalize=True): self.headers = headers - if normalize: self.normalize_headers() + if normalize: + self.normalize_headers() def normalize_headers(self): for key, value in netius.legacy.items(self.headers): @@ -881,7 +858,8 @@ def raw_data(self, data): """ encoding = self.parser.headers.get("content-encoding", None) - if not encoding: return data + if not encoding: + return data if not self.gzip_c: is_deflate = encoding == "deflate" wbits = zlib.MAX_WBITS if is_deflate else zlib.MAX_WBITS | 16 @@ -909,9 +887,11 @@ def is_uncompressed(self): def is_flushed(self): return self.current > PLAIN_ENCODING - def is_measurable(self, strict = True): - if self.is_compressed(): return False - if strict and self.is_chunked(): return False + def is_measurable(self, strict=True): + if self.is_compressed(): + return False + if strict and self.is_chunked(): + return False return True def on_data(self, data): @@ -933,26 +913,27 @@ def on_headers(self): def on_chunk(self, range): self.trigger("chunk", self, self.parser, range) - def _flush_plain(self, force = False, callback = None): - if not callback: return - self.send_plain(b"", force = force, callback = callback) + def _flush_plain(self, force=False, callback=None): + if not callback: + return + self.send_plain(b"", force=force, callback=callback) - def _flush_chunked(self, force = False, callback = None): - self.send_plain(b"0\r\n\r\n", force = force, callback = callback) + def _flush_chunked(self, force=False, callback=None): + self.send_plain(b"0\r\n\r\n", force=force, callback=callback) - def _flush_gzip(self, force = False, callback = None): + def _flush_gzip(self, force=False, callback=None): # in case the gzip structure has not been initialized # (no data sent) no need to run the flushing of the # gzip data, so only the chunked part is flushed if not self.gzip: - self._flush_chunked(force = force, callback = callback) + self._flush_chunked(force=force, callback=callback) return # flushes the internal zlib buffers to be able to retrieve # the data pending to be sent to the client and then sends # it using the chunked encoding strategy data_c = self.gzip.flush(zlib.Z_FINISH) - self.send_chunked(data_c, force = force, final = False) + self.send_chunked(data_c, force=force, final=False) # resets the gzip values to the original ones so that new # requests will starts the information from the beginning @@ -961,12 +942,13 @@ def _flush_gzip(self, force = False, callback = None): # runs the flush operation for the underlying chunked encoding # layer so that the client is correctly notified about the # end of the current request (normal operation) - self._flush_chunked(force = force, callback = callback) + self._flush_chunked(force=force, callback=callback) - def _close_gzip(self, safe = True): + def _close_gzip(self, safe=True): # in case the gzip object is not defined returns the control # to the caller method immediately (nothing to be done) - if not self.gzip: return + if not self.gzip: + return try: # runs the flush operation for the the final finish stage @@ -977,12 +959,14 @@ def _close_gzip(self, safe = True): except Exception: # in case the safe flag is not set re-raises the exception # to the caller stack (as expected by the callers) - if not safe: raise + if not safe: + raise - def _apply_base(self, headers, replace = False): + def _apply_base(self, headers, replace=False): cls = self.__class__ for key, value in netius.legacy.iteritems(cls.BASE_HEADERS): - if not replace and key in headers: continue + if not replace and key in headers: + continue headers[key] = value def _apply_dynamic(self, headers): @@ -994,20 +978,25 @@ def _apply_dynamic(self, headers): # determines the proper strategy for data payload length, taking into # account if there's a payload and if it exists if it's a byte stream # or instead an iterator/generator - if not data: length = 0 - elif netius.legacy.is_bytes(data): length = len(data) - else: length = next(data) + if not data: + length = 0 + elif netius.legacy.is_bytes(data): + length = len(data) + else: + length = next(data) # ensures that if the content encoding is plain the content length # for the payload is defined otherwise it would be impossible to the # server side to determine when the content sending is finished netius.verify( not is_plain or not length == -1, - message = "The content length must be defined for plain HTTP encoding" + message="The content length must be defined for plain HTTP encoding", ) - if port in (80, 443): host_s = host - else: host_s = "%s:%d" % (host, port) + if port in (80, 443): + host_s = host + else: + host_s = "%s:%d" % (host, port) if not "connection" in headers: headers["connection"] = "keep-alive" @@ -1018,27 +1007,34 @@ def _apply_dynamic(self, headers): if not "accept-encoding" in headers and self.encodings: headers["accept-encoding"] = self.encodings - def _apply_connection(self, headers, strict = True): + def _apply_connection(self, headers, strict=True): is_chunked = self.is_chunked() is_gzip = self.is_gzip() is_deflate = self.is_deflate() is_compressed = self.is_compressed() - is_measurable = self.is_measurable(strict = strict) + is_measurable = self.is_measurable(strict=strict) has_length = "content-length" in headers has_ranges = "accept-ranges" in headers - if is_chunked: headers["transfer-encoding"] = "chunked" - if is_gzip: headers["content-encoding"] = "gzip" - if is_deflate: headers["content-encoding"] = "deflate" + if is_chunked: + headers["transfer-encoding"] = "chunked" + if is_gzip: + headers["content-encoding"] = "gzip" + if is_deflate: + headers["content-encoding"] = "deflate" - if not is_measurable and has_length: del headers["content-length"] - if is_compressed and has_ranges: del headers["accept-ranges"] + if not is_measurable and has_length: + del headers["content-length"] + if is_compressed and has_ranges: + del headers["accept-ranges"] def _headers_normalize(self, headers): for key, value in netius.legacy.items(headers): - if not type(value) in (list, tuple): continue + if not type(value) in (list, tuple): + continue headers[key] = ";".join(value) + class HTTPClient(netius.ClientAgent): """ Simple test of an HTTP client, supports a series of basic @@ -1051,118 +1047,62 @@ class HTTPClient(netius.ClientAgent): protocol = HTTPProtocol - def __init__( - self, - auto_release = True, - *args, - **kwargs - ): + def __init__(self, auto_release=True, *args, **kwargs): netius.ClientAgent.__init__(self, *args, **kwargs) self.auto_release = auto_release self.available = dict() self._loop = None @classmethod - def get_s( - cls, - url, - params = {}, - headers = {}, - **kwargs - ): - return cls.method_s( - "GET", - url, - params = params, - headers = headers, - **kwargs - ) + def get_s(cls, url, params={}, headers={}, **kwargs): + return cls.method_s("GET", url, params=params, headers=headers, **kwargs) @classmethod - def post_s( - cls, - url, - params = {}, - headers = {}, - data = None, - **kwargs - ): + def post_s(cls, url, params={}, headers={}, data=None, **kwargs): return cls.method_s( - "POST", - url, - params = params, - headers = headers, - data = data, - **kwargs + "POST", url, params=params, headers=headers, data=data, **kwargs ) @classmethod - def put_s( - cls, - url, - params = {}, - headers = {}, - data = None, - **kwargs - ): + def put_s(cls, url, params={}, headers={}, data=None, **kwargs): return cls.method_s( - "PUT", - url, - params = params, - headers = headers, - data = data, - **kwargs + "PUT", url, params=params, headers=headers, data=data, **kwargs ) @classmethod - def delete_s( - cls, - url, - params = {}, - headers = {}, - **kwargs - ): - return cls.method_s( - "DELETE", - url, - params = params, - headers = headers, - **kwargs - ) + def delete_s(cls, url, params={}, headers={}, **kwargs): + return cls.method_s("DELETE", url, params=params, headers=headers, **kwargs) @classmethod def method_s( cls, method, url, - params = {}, - headers = {}, - data = None, - version = "HTTP/1.1", - safe = False, - asynchronous = True, - daemon = True, - timeout = None, - ssl_verify = False, - use_file = False, - callback = None, - on_init = None, - on_open = None, - on_close = None, - on_headers = None, - on_data = None, - on_result = None, - http_client = None, + params={}, + headers={}, + data=None, + version="HTTP/1.1", + safe=False, + asynchronous=True, + daemon=True, + timeout=None, + ssl_verify=False, + use_file=False, + callback=None, + on_init=None, + on_open=None, + on_close=None, + on_headers=None, + on_data=None, + on_result=None, + http_client=None, **kwargs ): # in case no HTTP client instance is provided tries to # retrieve a static global one (singleton) to be used # for the current request operation if not http_client: - http_client = cls.get_client_s( - daemon = daemon, - **kwargs - ) + http_client = cls.get_client_s(daemon=daemon, **kwargs) # calls the underlying method on the current HTTP client # propagating most of the arguments, and retrieves the resulting @@ -1170,22 +1110,22 @@ def method_s( result = http_client.method( method, url, - params = params, - headers = headers, - data = data, - version = version, - safe = safe, - asynchronous = asynchronous, - timeout = timeout, - ssl_verify = ssl_verify, - use_file = use_file, - callback = callback, - on_init = on_init, - on_open = on_open, - on_close = on_close, - on_headers = on_headers, - on_data = on_data, - on_result = on_result, + params=params, + headers=headers, + data=data, + version=version, + safe=safe, + asynchronous=asynchronous, + timeout=timeout, + ssl_verify=ssl_verify, + use_file=use_file, + callback=callback, + on_init=on_init, + on_open=on_open, + on_close=on_close, + on_headers=on_headers, + on_data=on_data, + on_result=on_result, **kwargs ) @@ -1194,7 +1134,7 @@ def method_s( return result @classmethod - def to_response(cls, map, raise_e = True): + def to_response(cls, map, raise_e=True): """ Simple utility method that takes the classic dictionary based request and converts it into a simple HTTP response @@ -1214,14 +1154,16 @@ def to_response(cls, map, raise_e = True): message = map.get("message", None) exception = map.get("exception", None) is_error = True if error and raise_e else False - if not is_error: return netius.common.HTTPResponse( - data = map.get("data", None), - code = map.get("code", 500), - status = map.get("status", None), - headers = map.get("headers", None) - ) + if not is_error: + return netius.common.HTTPResponse( + data=map.get("data", None), + code=map.get("code", 500), + status=map.get("status", None), + headers=map.get("headers", None), + ) message = message or "Undefined error (%s)" % error - if exception: raise exception + if exception: + raise exception raise netius.NetiusError(message) def cleanup(self): @@ -1238,95 +1180,47 @@ def cleanup(self): # not allowing any further re-usage of it (as expected) self._close_loop() - def get( - self, - url, - params = {}, - headers = {}, - **kwargs - ): - return self.method( - "GET", - url, - params = params, - headers = headers, - **kwargs - ) + def get(self, url, params={}, headers={}, **kwargs): + return self.method("GET", url, params=params, headers=headers, **kwargs) - def post( - self, - url, - params = {}, - headers = {}, - data = None, - **kwargs - ): + def post(self, url, params={}, headers={}, data=None, **kwargs): return self.method( - "POST", - url, - params = params, - headers = headers, - data = data, - **kwargs + "POST", url, params=params, headers=headers, data=data, **kwargs ) - def put( - self, - url, - params = {}, - headers = {}, - data = None, - **kwargs - ): + def put(self, url, params={}, headers={}, data=None, **kwargs): return self.method( - "PUT", - url, - params = params, - headers = headers, - data = data, - **kwargs + "PUT", url, params=params, headers=headers, data=data, **kwargs ) - def delete( - self, - url, - params = {}, - headers = {}, - **kwargs - ): - return self.method( - "DELETE", - url, - params = params, - headers = headers, - **kwargs - ) + def delete(self, url, params={}, headers={}, **kwargs): + return self.method("DELETE", url, params=params, headers=headers, **kwargs) def method( self, method, url, - params = None, - headers = None, - data = None, - version = "HTTP/1.1", - encoding = PLAIN_ENCODING, - encodings = "gzip, deflate", - safe = False, - request = False, - close = True, - asynchronous = True, - timeout = None, - ssl_verify = False, - use_file = False, - callback = None, - on_init = None, - on_open = None, - on_close = None, - on_headers = None, - on_data = None, - on_result = None, - loop = None, + params=None, + headers=None, + data=None, + version="HTTP/1.1", + encoding=PLAIN_ENCODING, + encodings="gzip, deflate", + safe=False, + request=False, + close=True, + asynchronous=True, + timeout=None, + ssl_verify=False, + use_file=False, + callback=None, + on_init=None, + on_open=None, + on_close=None, + on_headers=None, + on_data=None, + on_result=None, + loop=None, **kwargs ): # extracts the reference to the upper class element associated @@ -1339,9 +1233,10 @@ def method( # notice that the event loop is also re-used accordingly key = cls.protocol.key_g(url) protocol = self.available.pop(key, None) - if protocol and (not protocol.is_open() or\ - protocol.transport().is_closing()): protocol = None - if protocol: loop = loop or protocol.loop() + if protocol and (not protocol.is_open() or protocol.transport().is_closing()): + protocol = None + if protocol: + loop = loop or protocol.loop() # determines if the loop instance was provided by the user so # that latter on we can determine if it should be closed (garbage @@ -1351,7 +1246,8 @@ def method( # in case the current execution model is not asynchronous a new # loop context must exist otherwise it may collide with the global # event loop execution creating unwanted behaviour - if not asynchronous: loop = loop or self._get_loop(**kwargs) + if not asynchronous: + loop = loop or self._get_loop(**kwargs) # creates the new protocol instance that is going to be used to # handle this new request, a new protocol represents also a new @@ -1360,24 +1256,24 @@ def method( protocol = callable( method, url, - params = params, - headers = headers, - data = data, - version = version, - encoding = encoding, - encodings = encodings, - safe = safe, - request = request, - asynchronous = asynchronous, - timeout = timeout, - use_file = use_file, - callback = callback, - on_init = on_init, - on_open = on_open, - on_close = on_close, - on_headers = on_headers, - on_data = on_data, - on_result = on_result + params=params, + headers=headers, + data=data, + version=version, + encoding=encoding, + encodings=encodings, + safe=safe, + request=request, + asynchronous=asynchronous, + timeout=timeout, + use_file=use_file, + callback=callback, + on_init=on_init, + on_open=on_open, + on_close=on_close, + on_headers=on_headers, + on_data=on_data, + on_result=on_result, ) # verifies if the current protocol is already open and if that's the @@ -1393,14 +1289,15 @@ def method( lambda: protocol, protocol.host, protocol.port, - ssl = protocol.ssl, - ssl_verify = ssl_verify, - loop = loop + ssl=protocol.ssl, + ssl_verify=ssl_verify, + loop=loop, ) # in case the asynchronous mode is enabled returns the loop and the protocol # immediately so that it can be properly used by the caller - if asynchronous: return loop, protocol + if asynchronous: + return loop, protocol def on_message(protocol, parser, message): # in case the auto release (no connection re-usage) mode is @@ -1435,7 +1332,8 @@ def on_close(protocol): # in case the protocol that is being closed is not the one # in usage returns immediately (no need to stop the event # loop for a protocol from the available pool) - if from_pool: return + if from_pool: + return # tries to retrieve the loop compatible value and if it's # successful runs the stop operation on the loop @@ -1451,21 +1349,25 @@ def on_close(protocol): # used is not the HTTP client's static loop and also not a user's # provided loop it's closed immediately (garbage collection) loop.run_forever() - if not loop == self._loop and not user_loop: loop.close() + if not loop == self._loop and not user_loop: + loop.close() # returns the final request object (that should be populated by this # time) to the called method, so that a simple interface is provided return protocol.request def _get_loop(self, **kwargs): - if not self._loop: self._loop = netius.new_loop(**kwargs) + if not self._loop: + self._loop = netius.new_loop(**kwargs) return self._loop def _close_loop(self): - if not self._loop: return + if not self._loop: + return self._loop.close() self._loop = None + if __name__ == "__main__": buffer = [] diff --git a/src/netius/clients/mjpg.py b/src/netius/clients/mjpg.py index 3ad172211..e1aa7dddc 100644 --- a/src/netius/clients/mjpg.py +++ b/src/netius/clients/mjpg.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -43,6 +34,7 @@ from . import http + class MJPGProtocol(http.HTTPProtocol): MAGIC_JPEG = b"\xff\xd8\xff\xe0" @@ -62,10 +54,12 @@ def __init__(self, *args, **kwargs): def add_buffer(self, data): self.buffer_l.append(data) - def get_buffer(self, delete = True): - if not self.buffer_l: return b"" + def get_buffer(self, delete=True): + if not self.buffer_l: + return b"" buffer = b"".join(self.buffer_l) - if delete: del self.buffer_l[:] + if delete: + del self.buffer_l[:] return buffer def on_partial(self, data): @@ -79,7 +73,9 @@ def on_partial(self, data): # received data, and in case it's not found add the (partial) # data to the current buffer, to be latter processed eoi_index = data.find(cls.EOI_JPEG) - if eoi_index == -1: self.buffer_l.append(data); return + if eoi_index == -1: + self.buffer_l.append(data) + return # calculates the size of the end of image (EOI) token so that # this value will be used for the calculus of the image data @@ -88,15 +84,15 @@ def on_partial(self, data): # adds the partial valid data of the current chunk to the buffer # and then joins the current buffer as the frame data, removing # the multipart header from it (to become a valid image) - self.buffer_l.append(data[:eoi_index + eoi_size]) + self.buffer_l.append(data[: eoi_index + eoi_size]) frame = b"".join(self.buffer_l) multipart_index = frame.find(b"\r\n\r\n") - frame = frame[multipart_index + 4:] + frame = frame[multipart_index + 4 :] # clears the current buffer and adds the remaining part of the # current chunk, that may be already part of a new image del self.buffer_l[:] - self.buffer_l.append(data[eoi_index + eoi_size:]) + self.buffer_l.append(data[eoi_index + eoi_size :]) # calls the proper event handler for the new frame data that has # just been received, triggering the processing of the frame @@ -105,10 +101,12 @@ def on_partial(self, data): def on_frame_mjpg(self, data): self.trigger("frame", self, data) + class MJPGClient(http.HTTPClient): protocol = MJPGProtocol + if __name__ == "__main__": index = 0 limit = 30 @@ -116,15 +114,19 @@ class MJPGClient(http.HTTPClient): def on_frame(protocol, data): global index index += 1 - if index >= limit: return protocol.close() + if index >= limit: + return protocol.close() base_path = netius.conf("IMAGES_PATH", "images") base_path = os.path.abspath(base_path) base_path = os.path.normpath(base_path) - if not os.path.exists(base_path): os.makedirs(base_path) + if not os.path.exists(base_path): + os.makedirs(base_path) path = os.path.join(base_path, "%08d.jpg" % index) file = open(path, "wb") - try: file.write(data) - finally: file.close() + try: + file.write(data) + finally: + file.close() print("Saved frame %08d of %d bytes" % (index, len(data))) def on_finish(protocol): diff --git a/src/netius/clients/raw.py b/src/netius/clients/raw.py index 2bbebec09..59f3be192 100644 --- a/src/netius/clients/raw.py +++ b/src/netius/clients/raw.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -39,6 +30,7 @@ import netius + class RawProtocol(netius.StreamProtocol): def send_basic(self): @@ -49,32 +41,22 @@ def send_basic(self): self.send("GET / HTTP/1.0\r\n\r\n") + class RawClient(netius.ClientAgent): protocol = RawProtocol @classmethod - def run_s( - cls, - host, - port = 8080, - loop = None, - *args, - **kwargs - ): + def run_s(cls, host, port=8080, loop=None, *args, **kwargs): protocol = cls.protocol() loop = netius.connect_stream( - lambda: protocol, - host = host, - port = port, - loop = loop, - *args, - **kwargs + lambda: protocol, host=host, port=port, loop=loop, *args, **kwargs ) return loop, protocol + if __name__ == "__main__": def on_open(protocol): diff --git a/src/netius/clients/smtp.py b/src/netius/clients/smtp.py index ad688426a..c1d14f35e 100644 --- a/src/netius/clients/smtp.py +++ b/src/netius/clients/smtp.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -74,16 +65,14 @@ FINAL_STATE = 14 + class SMTPConnection(netius.Connection): - AUTH_METHODS = ( - "plain", - "login" - ) + AUTH_METHODS = ("plain", "login") """ The sequence that defined the multiple allowed methods for this SMTP protocol implementation """ - def __init__(self, host = "smtp.localhost", *args, **kwargs): + def __init__(self, host="smtp.localhost", *args, **kwargs): netius.Connection.__init__(self, *args, **kwargs) self.parser = None self.host = host @@ -103,15 +92,18 @@ def __init__(self, host = "smtp.localhost", *args, **kwargs): def open(self, *args, **kwargs): netius.Connection.open(self, *args, **kwargs) - if not self.is_open(): return + if not self.is_open(): + return self.parser = netius.common.SMTPParser(self) self.parser.bind("on_line", self.on_line) self.build() def close(self, *args, **kwargs): netius.Connection.close(self, *args, **kwargs) - if not self.is_closed(): return - if self.parser: self.parser.destroy() + if not self.is_closed(): + return + if self.parser: + self.parser.destroy() self.destroy() def build(self): @@ -136,7 +128,7 @@ def build(self): self.data_t, self.contents_t, self.quit_t, - self.close_t + self.close_t, ) self.state_l = len(self.states) @@ -151,20 +143,21 @@ def destroy(self): self.states = () self.state_l = 0 - def set_smtp(self, froms, tos, contents, username = None, password = None): + def set_smtp(self, froms, tos, contents, username=None, password=None): self.froms = froms self.tos = tos self.contents = contents self.username = username self.password = password - def set_sequence(self, sequence, safe = True): - if safe and self.sequence == sequence: return + def set_sequence(self, sequence, safe=True): + if safe and self.sequence == sequence: + return self.sindex = 0 self.sequence = sequence self.state = sequence[0] - def set_message_seq(self, ehlo = True): + def set_message_seq(self, ehlo=True): sequence = ( EHLO_STATE if ehlo else HELO_STATE, CAPA_STATE, @@ -176,11 +169,11 @@ def set_message_seq(self, ehlo = True): DATA_STATE, CONTENTS_STATE, QUIT_STATE, - FINAL_STATE + FINAL_STATE, ) self.set_sequence(sequence) - def set_message_stls_seq(self, ehlo = True): + def set_message_stls_seq(self, ehlo=True): sequence = ( EHLO_STATE if ehlo else HELO_STATE, CAPA_STATE, @@ -196,12 +189,13 @@ def set_message_stls_seq(self, ehlo = True): DATA_STATE, CONTENTS_STATE, QUIT_STATE, - FINAL_STATE + FINAL_STATE, ) self.set_sequence(sequence) - def set_capabilities(self, capabilities, force = True): - if not force and self.capabilities: return + def set_capabilities(self, capabilities, force=True): + if not force and self.capabilities: + return capabilities = [value.strip().lower() for value in capabilities] self.capabilities = capabilities @@ -214,14 +208,14 @@ def next_sequence(self): def parse(self, data): return self.parser.parse(data) - def send_smtp(self, code, message = "", delay = True, callback = None): + def send_smtp(self, code, message="", delay=True, callback=None): base = "%s %s" % (code, message) data = base + "\r\n" - count = self.send(data, delay = delay, callback = callback) + count = self.send(data, delay=delay, callback=callback) self.owner.debug(base) return count - def on_line(self, code, message, is_final = True): + def on_line(self, code, message, is_final=True): # creates the base string from the provided code value and the # message associated with it, then logs the values into the # current debug logger support (for traceability) @@ -237,7 +231,8 @@ def on_line(self, code, message, is_final = True): # immediately to continue the processing of information for the # current response, the various message should be accumulated under # the message buffer to avoid any problem - if not is_final: return + if not is_final: + return # runs the code based assertion so that if there's an expected # value set for the current connection it's correctly validated @@ -247,7 +242,7 @@ def on_line(self, code, message, is_final = True): # according to the ones that have "generate" handling methods, otherwise # raises a parser error indicating the problem if self.state > self.state_l: - raise netius.ParserError("Invalid state", details = self.messages) + raise netius.ParserError("Invalid state", details=self.messages) # runs the calling of the next state based method according to the # currently defined state, this is the increments in calling @@ -289,7 +284,8 @@ def stls_t(self): def upgrade_t(self): def callback(connection): - connection.upgrade(server = False) + connection.upgrade(server=False) + self.next_sequence() callback(self) @@ -301,7 +297,7 @@ def auth_t(self): return method = self.best_auth() - self.auth(self.username, self.password, method = method) + self.auth(self.username, self.password, method=method) self.next_sequence() def username_t(self): @@ -320,7 +316,8 @@ def rcpt_t(self): is_final = self.to_index == len(self.tos) - 1 self.rcpt(self.tos[self.to_index]) self.to_index += 1 - if is_final: self.next_sequence() + if is_final: + self.next_sequence() def data_t(self): self.data() @@ -338,7 +335,7 @@ def quit_t(self): self.next_sequence() def close_t(self): - self.close(flush = True) + self.close(flush=True) def pass_t(self): pass @@ -360,11 +357,12 @@ def starttls(self): self.send_smtp("starttls") self.set_expected(220) - def auth(self, username, password, method = "plain"): + def auth(self, username, password, method="plain"): self.assert_s(AUTH_STATE) method_name = "auth_%s" % method has_method = hasattr(self, method_name) - if not has_method: raise netius.NotImplemented("Method not implemented") + if not has_method: + raise netius.NotImplemented("Method not implemented") method = getattr(self, method_name) method(username, password) @@ -423,141 +421,129 @@ def set_expected(self, expected): self.expected = expected def assert_c(self, code): - if not self.expected: return + if not self.expected: + return expected = self.expected code_i = int(code) self.expected = None valid = expected == code_i - if valid: return + if valid: + return raise netius.ParserError( - "Invalid response code expected '%d' received '%d'" % - (expected, code_i), - details = self.messages + "Invalid response code expected '%d' received '%d'" % (expected, code_i), + details=self.messages, ) def assert_s(self, expected): - if self.state == expected: return - raise netius.ParserError("Invalid state", details = self.messages) + if self.state == expected: + return + raise netius.ParserError("Invalid state", details=self.messages) def best_auth(self): cls = self.__class__ methods = [] for capability in self.capabilities: is_auth = capability.startswith("auth ") - if not is_auth: continue + if not is_auth: + continue parts = capability.split(" ") parts = [part.strip() for part in parts] methods.extend(parts[1:]) usable = [method for method in methods if method in cls.AUTH_METHODS] return usable[0] if usable else "plain" + class SMTPClient(netius.StreamClient): - def __init__( - self, - host = None, - auto_close = False, - *args, - **kwargs - ): + def __init__(self, host=None, auto_close=False, *args, **kwargs): netius.StreamClient.__init__(self, *args, **kwargs) self.host = host if host else "[" + netius.common.host() + "]" self.auto_close = auto_close @classmethod def message_s( - cls, - froms, - tos, - contents, - daemon = True, - host = None, - mark = True, - callback = None + cls, froms, tos, contents, daemon=True, host=None, mark=True, callback=None ): - smtp_client = cls.get_client_s(thread = True, daemon = daemon, host = host) - smtp_client.message( - froms, - tos, - contents, - mark = mark, - callback = callback - ) + smtp_client = cls.get_client_s(thread=True, daemon=daemon, host=host) + smtp_client.message(froms, tos, contents, mark=mark, callback=callback) def message( self, froms, tos, contents, - message_id = None, - host = None, - port = 25, - username = None, - password = None, - ehlo = True, - stls = False, - mark = True, - comply = False, - ensure_loop = True, - callback = None, - callback_error = None + message_id=None, + host=None, + port=25, + username=None, + password=None, + ehlo=True, + stls=False, + mark=True, + comply=False, + ensure_loop=True, + callback=None, + callback_error=None, ): # in case the comply flag is set then ensure that a series # of mandatory fields are present in the contents if comply: contents = self.comply( - contents, - froms = froms, - tos = tos, - message_id = message_id + contents, froms=froms, tos=tos, message_id=message_id ) # in case the mark flag is set the contents data is modified # and "marked" with the pre-defined header values of the client # this should provide some extra information on the agent - if mark: contents = self.mark(contents) + if mark: + contents = self.mark(contents) # creates the method that is able to generate handler for a # certain sequence of to based (email) addresses - def build_handler(tos, domain = None, tos_map = None): + def build_handler(tos, domain=None, tos_map=None): # creates the context object that will be used to pass # contextual information to the callbacks context = dict( - froms = froms, - tos = tos, - contents = contents, - mark = mark, - comply = comply, - ensure_loop = ensure_loop, - domain = domain, - tos_map = tos_map + froms=froms, + tos=tos, + contents=contents, + mark=mark, + comply=comply, + ensure_loop=ensure_loop, + domain=domain, + tos_map=tos_map, ) - def on_close(connection = None): + def on_close(connection=None): # verifies if the current handler has been build with a # domain based clojure and if that's the case removes the # reference of it from the map of tos, then verifies if the # map is still valid and if that's the case returns and this # is not considered the last remaining SMTP session for the # current send operation (still some open) - if domain: del tos_map[domain] - if tos_map: return + if domain: + del tos_map[domain] + if tos_map: + return # verifies if the callback method is defined and if that's # the case calls the callback indicating the end of the send # operation (note that this may represent multiple SMTP sessions) - if callback: callback(self) + if callback: + callback(self) - def on_exception(connection = None, exception = None): - if callback_error: callback_error(self, context, exception) + def on_exception(connection=None, exception=None): + if callback_error: + callback_error(self, context, exception) - def handler(response = None): + def handler(response=None): # in case the provided response value is invalid returns # immediately, as this should represent a resolution error, # this is only done in case the host is also not defined # as for such situations an address is not retrievable - if response == None and host == None: return + if response == None and host == None: + return # in case there's a valid response provided must parse it # to try to "recover" the final address that is going to be @@ -569,11 +555,13 @@ def handler(response = None): # fallback for this connections is handled if not response.answers: on_close() - if self.auto_close: self.close() + if self.auto_close: + self.close() exception = netius.NetiusError( "Not possible to resolve MX for '%s'" % domain ) - if callback_error: callback_error(self, context, exception) + if callback_error: + callback_error(self, context, exception) raise exception # retrieves the first answer (probably the most accurate) @@ -584,7 +572,8 @@ def handler(response = None): # otherwise the host should have been provided and as such the # address value is set with the provided host - else: address = host + else: + address = host # sets the proper address (host) and port values that are # going to be used to establish the connection, notice that @@ -602,14 +591,12 @@ def handler(response = None): # sets the SMTP information in the current connection, after # the connections is completed the SMTP session should start connection = self.connect(_host, _port) - if stls: connection.set_message_stls_seq(ehlo = ehlo) - else: connection.set_message_seq(ehlo = ehlo) + if stls: + connection.set_message_stls_seq(ehlo=ehlo) + else: + connection.set_message_seq(ehlo=ehlo) connection.set_smtp( - froms, - tos, - contents, - username = username, - password = password + froms, tos, contents, username=username, password=password ) connection.bind("close", on_close) connection.bind("exception", on_exception) @@ -631,7 +618,8 @@ def handler(response = None): # SMTP client does not become orphan as no connection has been # established as of this moment (as expected) and the dns client # is going to be run as a daemon (avoids process exit) - if ensure_loop: self.ensure_loop() + if ensure_loop: + self.ensure_loop() # creates the map that is going to be used to associate each of # the domains with the proper to (email) addresses, this is going @@ -649,7 +637,7 @@ def handler(response = None): for domain, tos in netius.legacy.items(tos_map): # creates a new handler method bound to the to addresses # associated with the current domain in iteration - handler = build_handler(tos, domain = domain, tos_map = tos_map) + handler = build_handler(tos, domain=domain, tos_map=tos_map) # prints a small debug message about the resolution of the # domain for the current message (debugging purposes) @@ -658,7 +646,7 @@ def handler(response = None): # runs the dns query to be able to retrieve the proper # mail exchange host for the target email address and then # sets the proper callback for sending - dns.DNSClient.query_s(domain, type = "mx", callback = handler) + dns.DNSClient.query_s(domain, type="mx", callback=handler) def on_connect(self, connection): netius.StreamClient.on_connect(self, connection) @@ -673,25 +661,26 @@ def on_data(self, connection, data): def on_connection_d(self, connection): netius.StreamClient.on_connection_d(self, connection) - if not self.auto_close: return - if self.connections: return + if not self.auto_close: + return + if self.connections: + return self.close() - def build_connection(self, socket, address, ssl = False): + def build_connection(self, socket, address, ssl=False): return SMTPConnection( - owner = self, - socket = socket, - address = address, - ssl = ssl, - host = self.host + owner=self, socket=socket, address=address, ssl=ssl, host=self.host ) - def comply(self, contents, froms = None, tos = None, message_id = None): + def comply(self, contents, froms=None, tos=None, message_id=None): parser = email.parser.Parser() message = parser.parsestr(contents) - if froms: self.from_(message, froms[0]) - if tos: self.to(message, ",".join(tos)) - if message_id: self.message_id(message, message_id) + if froms: + self.from_(message, froms[0]) + if tos: + self.to(message, ",".join(tos)) + if message_id: + self.message_id(message, message_id) return message.as_string() def mark(self, contents): @@ -703,41 +692,47 @@ def mark(self, contents): def from_(self, message, value): from_ = message.get("From", None) - if from_: return + if from_: + return message["From"] = value def to(self, message, value): to = message.get("To", None) - if to: return + if to: + return message["To"] = value def message_id(self, message, value): message_id = message.get("Message-Id", None) message_id = message.get("Message-ID", message_id) - if message_id: return + if message_id: + return message["Message-ID"] = value def date(self, message): date = message.get("Date", None) - if date: return + if date: + return date_time = datetime.datetime.utcnow() message["Date"] = date_time.strftime("%a, %d %b %Y %H:%M:%S +0000") def user_agent(self, message): user_agent = message.get("User-Agent", None) - if user_agent: return + if user_agent: + return message["User-Agent"] = netius.IDENTIFIER + if __name__ == "__main__": import email.mime.text sender = netius.conf("SMTP_SENDER", "hello@bemisc.com") receiver = netius.conf("SMTP_RECEIVER", "hello@bemisc.com") host = netius.conf("SMTP_HOST", None) - port = netius.conf("SMTP_PORT", 25, cast = int) + port = netius.conf("SMTP_PORT", 25, cast=int) username = netius.conf("SMTP_USER", None) password = netius.conf("SMTP_PASSWORD", None) - stls = netius.conf("SMTP_STARTTLS", False, cast = bool) + stls = netius.conf("SMTP_STARTTLS", False, cast=bool) mime = email.mime.text.MIMEText("Hello World") mime["Subject"] = "Hello World" @@ -745,16 +740,16 @@ def user_agent(self, message): mime["To"] = receiver contents = mime.as_string() - client = SMTPClient(auto_close = True) + client = SMTPClient(auto_close=True) client.message( [sender], [receiver], contents, - host = host, - port = port, - username = username, - password = password, - stls = stls + host=host, + port=port, + username=username, + password=password, + stls=stls, ) else: __path__ = [] diff --git a/src/netius/clients/ssdp.py b/src/netius/clients/ssdp.py index 3f2f4c79d..136e8ec50 100644 --- a/src/netius/clients/ssdp.py +++ b/src/netius/clients/ssdp.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -39,6 +30,7 @@ import netius.common + class SSDPProtocol(netius.DatagramProtocol): """ Protocol implementation of the SSDP protocol meant to be @@ -53,7 +45,7 @@ class SSDPProtocol(netius.DatagramProtocol): def on_data(self, address, data): netius.DatagramProtocol.on_data(self, address, data) - self.parser = netius.common.HTTPParser(self, type = netius.common.RESPONSE) + self.parser = netius.common.HTTPParser(self, type=netius.common.RESPONSE) self.parser.bind("on_headers", self.on_headers_parser) self.parser.parse(data) self.parser.destroy() @@ -63,34 +55,28 @@ def on_headers_parser(self): self.trigger("headers", self, self.parser, headers) def discover(self, target, *args, **kwargs): - return self.method( - "M-SEARCH", - target, - "ssdp:discover", - *args, - **kwargs - ) + return self.method("M-SEARCH", target, "ssdp:discover", *args, **kwargs) def method( self, method, target, namespace, - mx = 3, - path = "*", - params = None, - headers = None, - data = None, - host = "239.255.255.250", - port = 1900, - version = "HTTP/1.1", - callback = None + mx=3, + path="*", + params=None, + headers=None, + data=None, + host="239.255.255.250", + port=1900, + version="HTTP/1.1", + callback=None, ): address = (host, port) headers = headers or dict() headers["ST"] = target - headers["Man"] = "\"" + namespace + "\"" + headers["Man"] = '"' + namespace + '"' headers["MX"] = str(mx) headers["Host"] = "%s:%d" % address @@ -98,13 +84,16 @@ def method( buffer.append("%s %s %s\r\n" % (method, path, version)) for key, value in netius.legacy.iteritems(headers): key = netius.common.header_up(key) - if not isinstance(value, list): value = (value,) - for _value in value: buffer.append("%s: %s\r\n" % (key, _value)) + if not isinstance(value, list): + value = (value,) + for _value in value: + buffer.append("%s: %s\r\n" % (key, _value)) buffer.append("\r\n") buffer_data = "".join(buffer) self.send(buffer_data, address) - data and self.send(data, address, callback = callback) + data and self.send(data, address, callback=callback) + class SSDPClient(netius.ClientAgent): @@ -112,13 +101,7 @@ class SSDPClient(netius.ClientAgent): @classmethod def discover_s(cls, target, *args, **kwargs): - return cls.method_s( - "M-SEARCH", - target, - "ssdp:discover", - *args, - **kwargs - ) + return cls.method_s("M-SEARCH", target, "ssdp:discover", *args, **kwargs) @classmethod def method_s( @@ -126,16 +109,16 @@ def method_s( method, target, namespace, - mx = 3, - path = "*", - params = None, - headers = None, - data = None, - host = "239.255.255.250", - port = 1900, - version = "HTTP/1.1", - callback = None, - loop = None + mx=3, + path="*", + params=None, + headers=None, + data=None, + host="239.255.255.250", + port=1900, + version="HTTP/1.1", + callback=None, + loop=None, ): address = (host, port) protocol = cls.protocol() @@ -146,27 +129,26 @@ def on_connect(result): method, target, namespace, - mx = mx, - path = path, - params = params, - headers = headers, - data = data, - host = host, - port = port, - version = version, - callback = callback + mx=mx, + path=path, + params=params, + headers=headers, + data=data, + host=host, + port=port, + version=version, + callback=callback, ) loop = netius.build_datagram( - lambda: protocol, - callback = on_connect, - loop = loop, - remote_addr = address + lambda: protocol, callback=on_connect, loop=loop, remote_addr=address ) return loop, protocol + if __name__ == "__main__": + def on_headers(client, parser, headers): print(headers) @@ -175,7 +157,9 @@ def on_headers(client, parser, headers): # stop operation on the next tick end netius.compat_loop(loop).stop() - target = netius.conf("SSDP_TARGET", "urn:schemas-upnp-org:device:InternetGatewayDevice:1") + target = netius.conf( + "SSDP_TARGET", "urn:schemas-upnp-org:device:InternetGatewayDevice:1" + ) loop, protocol = SSDPClient.discover_s(target) diff --git a/src/netius/clients/torrent.py b/src/netius/clients/torrent.py index 49f493863..e87cf2c01 100644 --- a/src/netius/clients/torrent.py +++ b/src/netius/clients/torrent.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -66,9 +57,10 @@ using the current torrent infra-structure, this value conditions most of the torrent operations and should be defined carefully """ + class TorrentConnection(netius.Connection): - def __init__(self, max_requests = 50, *args, **kwargs): + def __init__(self, max_requests=50, *args, **kwargs): netius.Connection.__init__(self, *args, **kwargs) self.parser = None self.max_requests = max_requests @@ -85,17 +77,20 @@ def __init__(self, max_requests = 50, *args, **kwargs): def open(self, *args, **kwargs): netius.Connection.open(self, *args, **kwargs) - if not self.is_open(): return + if not self.is_open(): + return self.parser = netius.common.TorrentParser(self) self.bind("close", self.on_close) self.parser.bind("on_handshake", self.on_handshake) self.parser.bind("on_message", self.on_message) - self.is_alive(timeout = ALIVE_TIMEOUT, schedule = True) + self.is_alive(timeout=ALIVE_TIMEOUT, schedule=True) def close(self, *args, **kwargs): netius.Connection.close(self, *args, **kwargs) - if not self.is_closed(): return - if self.parser: self.parser.destroy() + if not self.is_closed(): + return + if self.parser: + self.parser.destroy() def on_close(self, connection): self.release() @@ -118,7 +113,8 @@ def handle(self, type, data): # for the handle of the message from the provided type and verifies # if the method exists under the current instance method_name = "%s_t" % type - if not hasattr(self, method_name): return + if not hasattr(self, method_name): + return # tries to retrieve the method for the current state in iteration # and then calls the retrieve method with (handler method) @@ -130,13 +126,15 @@ def bitfield_t(self, data): self.bitfield = [True if value == "1" else False for value in bitfield] def choke_t(self, data): - if self.choked == CHOKED: return + if self.choked == CHOKED: + return self.choked = CHOKED self.release() self.trigger("choked", self) def unchoke_t(self, data): - if self.choked == UNCHOKED: return + if self.choked == UNCHOKED: + return self.choked = UNCHOKED self.reset() self.next() @@ -153,17 +151,20 @@ def piece_t(self, data): self.trigger("piece", self, data, index, begin) def port_t(self, data): - port, = struct.unpack("!H", data[:8]) + (port,) = struct.unpack("!H", data[:8]) self.task.set_dht(self.address, port) - def next(self, count = None): - if not self.choked == UNCHOKED: return - if count == None: count = self.max_requests - self.pend_requests + def next(self, count=None): + if not self.choked == UNCHOKED: + return + if count == None: + count = self.max_requests - self.pend_requests for _index in range(count): block = self.task.pop_block(self.bitfield) - if not block: return + if not block: + return index, begin, length = block - self.request(index, begin = begin, length = length) + self.request(index, begin=begin, length=length) block_t = (index, begin) self.add_request(block_t) @@ -172,7 +173,8 @@ def add_request(self, block): self.pend_requests += 1 def remove_request(self, block): - if not block in self.requests: return + if not block in self.requests: + return self.requests.remove(block) self.pend_requests -= 1 @@ -192,7 +194,7 @@ def handshake(self): b"BitTorrent protocol", 1, self.task.info_hash, - netius.legacy.bytes(self.task.owner.peer_id) + netius.legacy.bytes(self.task.owner.peer_id), ) data and self.send(data) @@ -220,26 +222,33 @@ def have(self, index): data = struct.pack("!LBL", 5, 4, index) data and self.send(data) - def request(self, index, begin = 0, length = BLOCK_SIZE): + def request(self, index, begin=0, length=BLOCK_SIZE): data = struct.pack("!LBLLL", 13, 6, index, begin, length) data and self.send(data) - def is_alive(self, timeout = ALIVE_TIMEOUT, schedule = False): + def is_alive(self, timeout=ALIVE_TIMEOUT, schedule=False): messages = self.messages downloaded = self.downloaded def clojure(): - if not self.is_open(): return + if not self.is_open(): + return delta = self.downloaded - downloaded rate = float(delta) / float(timeout) - if self.messages == messages: self.close(flush = True); return - if rate < SPEED_LIMIT: self.close(flush = True); return + if self.messages == messages: + self.close(flush=True) + return + if rate < SPEED_LIMIT: + self.close(flush=True) + return callable = self.is_alive() self.owner.delay(callable, timeout) - if schedule: self.owner.delay(clojure, timeout) + if schedule: + self.owner.delay(clojure, timeout) return clojure + class TorrentClient(netius.StreamClient): """ Implementation of the torrent protocol, able to download @@ -255,8 +264,8 @@ class TorrentClient(netius.StreamClient): :see: http://www.bittorrent.org/beps/bep_0003.html """ - def peer(self, task, host, port, ssl = False, connection = None): - connection = connection or self.acquire_c(host, port, ssl = ssl) + def peer(self, task, host, port, ssl=False, connection=None): + connection = connection or self.acquire_c(host, port, ssl=ssl) connection.task = task return connection @@ -273,10 +282,5 @@ def on_data(self, connection, data): netius.StreamClient.on_data(self, connection, data) connection.parse(data) - def build_connection(self, socket, address, ssl = False): - return TorrentConnection( - owner = self, - socket = socket, - address = address, - ssl = ssl - ) + def build_connection(self, socket, address, ssl=False): + return TorrentConnection(owner=self, socket=socket, address=address, ssl=ssl) diff --git a/src/netius/clients/ws.py b/src/netius/clients/ws.py index 10745706d..e9841c2ce 100644 --- a/src/netius/clients/ws.py +++ b/src/netius/clients/ws.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -43,6 +34,7 @@ import netius.common + class WSProtocol(netius.StreamProtocol): """ Abstract WebSockets protocol to be used for real-time bidirectional @@ -56,7 +48,7 @@ class WSProtocol(netius.StreamProtocol): of the key generation process in the handshake """ @classmethod - def _key(cls, size = 16): + def _key(cls, size=16): seed = str(uuid.uuid4()) seed = netius.legacy.bytes(seed)[:size] seed = base64.b64encode(seed) @@ -78,13 +70,15 @@ def __init__(self, *args, **kwargs): def connection_made(self, transport): netius.StreamProtocol.connection_made(self, transport) - data = "GET %s HTTP/1.1\r\n" % self.path +\ - "Upgrade: websocket\r\n" +\ - "Connection: Upgrade\r\n" +\ - "Host: %s\r\n" % self.host +\ - "Origin: http://%s\r\n" % self.host +\ - "Sec-WebSocket-Key: %s\r\n" % self.key +\ - "Sec-WebSocket-Version: 13\r\n\r\n" + data = ( + "GET %s HTTP/1.1\r\n" % self.path + + "Upgrade: websocket\r\n" + + "Connection: Upgrade\r\n" + + "Host: %s\r\n" % self.host + + "Origin: http://%s\r\n" % self.host + + "Sec-WebSocket-Key: %s\r\n" % self.key + + "Sec-WebSocket-Version: 13\r\n\r\n" + ) self.send(data) def on_data(self, data): @@ -100,8 +94,11 @@ def on_data(self, data): # a problem the (pending) data is added to the buffer buffer = self.get_buffer() data = buffer + data - try: decoded, data = netius.common.decode_ws(data) - except netius.DataError: self.add_buffer(data); break + try: + decoded, data = netius.common.decode_ws(data) + except netius.DataError: + self.add_buffer(data) + break # calls the callback method in the protocol notifying # it about the new (decoded) data that has been received @@ -120,8 +117,10 @@ def on_data(self, data): # current protocol in case it fails due to an # handshake error must delay the execution to the # next iteration (not enough data) - try: self.do_handshake() - except netius.DataError: return + try: + self.do_handshake() + except netius.DataError: + return # validates (and computes) the accept key value according # to the provided value, in case there's an error an exception @@ -143,7 +142,7 @@ def on_data_ws(self, data): def on_handshake(self): self.trigger("handshake", self) - def connect_ws(self, url, callback = None, loop = None): + def connect_ws(self, url, callback=None, loop=None): cls = self.__class__ parsed = netius.legacy.urlparse(url) @@ -153,22 +152,19 @@ def connect_ws(self, url, callback = None, loop = None): self.path = parsed.path or "/" loop = netius.connect_stream( - lambda: self, - host = self.host, - port = self.port, - ssl = self.ssl, - loop = loop + lambda: self, host=self.host, port=self.port, ssl=self.ssl, loop=loop ) self.key = cls._key() - if callback: self.bind("handshake", callback, oneshot = True) + if callback: + self.bind("handshake", callback, oneshot=True) return loop, self - def send_ws(self, data, callback = None): - encoded = netius.common.encode_ws(data, mask = True) - return self.send(encoded, callback = callback) + def send_ws(self, data, callback=None): + encoded = netius.common.encode_ws(data, mask=True) + return self.send(encoded, callback=callback) def receive_ws(self, decoded): pass @@ -176,10 +172,12 @@ def receive_ws(self, decoded): def add_buffer(self, data): self.buffer_l.append(data) - def get_buffer(self, delete = True): - if not self.buffer_l: return b"" + def get_buffer(self, delete=True): + if not self.buffer_l: + return b"" buffer = b"".join(self.buffer_l) - if delete: del self.buffer_l[:] + if delete: + del self.buffer_l[:] return buffer def do_handshake(self): @@ -188,16 +186,18 @@ def do_handshake(self): buffer = b"".join(self.buffer_l) end_index = buffer.find(b"\r\n\r\n") - if end_index == -1: raise netius.DataError("Missing data for handshake") + if end_index == -1: + raise netius.DataError("Missing data for handshake") - data = buffer[:end_index + 4] - remaining = buffer[end_index + 4:] + data = buffer[: end_index + 4] + remaining = buffer[end_index + 4 :] lines = data.split(b"\r\n") for line in lines[1:]: values = line.split(b":", 1) values_l = len(values) - if not values_l == 2: continue + if not values_l == 2: + continue key, value = values key = key.strip() @@ -213,7 +213,8 @@ def do_handshake(self): del self.buffer_l[:] self.handshake = True - if remaining: self.add_buffer(remaining) + if remaining: + self.add_buffer(remaining) def validate_key(self): accept_key = self.headers.get("Sec-WebSocket-Accept", None) @@ -229,16 +230,19 @@ def validate_key(self): if not _accept_key == accept_key: raise netius.SecurityError("Invalid accept key provided") + class WSClient(netius.ClientAgent): protocol = WSProtocol @classmethod - def connect_ws_s(cls, url, callback = None, loop = None): + def connect_ws_s(cls, url, callback=None, loop=None): protocol = cls.protocol() - return protocol.connect_ws(url, callback = callback, loop = loop) + return protocol.connect_ws(url, callback=callback, loop=loop) + if __name__ == "__main__": + def on_handshake(protocol): protocol.send_ws("Hello World") diff --git a/src/netius/common/__init__.py b/src/netius/common/__init__.py index ec598a661..d9b6fea47 100644 --- a/src/netius/common/__init__.py +++ b/src/netius/common/__init__.py @@ -19,15 +19,6 @@ # You should have received a copy of the Apache License along with # Hive Netius System. If not, see . -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -59,30 +50,113 @@ from . import ws from .asn import asn1_parse, asn1_length, asn1_gen, asn1_build -from .calc import prime, is_prime, relatively_prime, gcd, egcd, modinv,\ - random_integer_interval, random_primality, jacobi_witness, jacobi, ceil_integer -from .dhcp import SUBNET_DHCP, ROUTER_DHCP, DNS_DHCP, NAME_DHCP, BROADCAST_DHCP,\ - REQUESTED_DHCP, LEASE_DHCP, DISCOVER_DHCP, OFFER_DHCP, REQUEST_DHCP, DECLINE_DHCP,\ - ACK_DHCP, NAK_DHCP, IDENTIFIER_DHCP, RENEWAL_DHCP, REBIND_DHCP, PROXY_DHCP,\ - END_DHCP, OPTIONS_DHCP, TYPES_DHCP, VERBS_DHCP, AddressPool +from .calc import ( + prime, + is_prime, + relatively_prime, + gcd, + egcd, + modinv, + random_integer_interval, + random_primality, + jacobi_witness, + jacobi, + ceil_integer, +) +from .dhcp import ( + SUBNET_DHCP, + ROUTER_DHCP, + DNS_DHCP, + NAME_DHCP, + BROADCAST_DHCP, + REQUESTED_DHCP, + LEASE_DHCP, + DISCOVER_DHCP, + OFFER_DHCP, + REQUEST_DHCP, + DECLINE_DHCP, + ACK_DHCP, + NAK_DHCP, + IDENTIFIER_DHCP, + RENEWAL_DHCP, + REBIND_DHCP, + PROXY_DHCP, + END_DHCP, + OPTIONS_DHCP, + TYPES_DHCP, + VERBS_DHCP, + AddressPool, +) from .dkim import dkim_sign, dkim_headers, dkim_body, dkim_fold, dkim_generate from .ftp import FTPParser from .geo import GeoResolver -from .http import REQUEST, RESPONSE, PLAIN_ENCODING, CHUNKED_ENCODING, GZIP_ENCODING,\ - DEFLATE_ENCODING, HTTP_09, HTTP_10, HTTP_11, VERSIONS_MAP, CODE_STRINGS, HTTPParser,\ - HTTPResponse -from .http2 import DATA, HEADERS, PRIORITY, RST_STREAM, SETTINGS, PUSH_PROMISE,\ - PING, GOAWAY, WINDOW_UPDATE, CONTINUATION, HTTP2_WINDOW, HTTP2_PREFACE,\ - HTTP2_TUPLES, HTTP2_NAMES, HTTP2_SETTINGS, HTTP2_SETTINGS_OPTIMAL, HTTP2_SETTINGS_T,\ - HTTP2_SETTINGS_OPTIMAL_T, HTTP2Parser, HTTP2Stream +from .http import ( + REQUEST, + RESPONSE, + PLAIN_ENCODING, + CHUNKED_ENCODING, + GZIP_ENCODING, + DEFLATE_ENCODING, + HTTP_09, + HTTP_10, + HTTP_11, + VERSIONS_MAP, + CODE_STRINGS, + HTTPParser, + HTTPResponse, +) +from .http2 import ( + DATA, + HEADERS, + PRIORITY, + RST_STREAM, + SETTINGS, + PUSH_PROMISE, + PING, + GOAWAY, + WINDOW_UPDATE, + CONTINUATION, + HTTP2_WINDOW, + HTTP2_PREFACE, + HTTP2_TUPLES, + HTTP2_NAMES, + HTTP2_SETTINGS, + HTTP2_SETTINGS_OPTIMAL, + HTTP2_SETTINGS_T, + HTTP2_SETTINGS_OPTIMAL_T, + HTTP2Parser, + HTTP2Stream, +) from .mime import rfc822_parse, rfc822_join, mime_register from .parser import Parser from .pop import POPParser -from .rsa import open_pem_key, open_pem_data, write_pem_key, open_private_key, open_private_key_b64,\ - open_private_key_data, open_public_key, open_public_key_b64, open_public_key_data,\ - write_private_key, write_public_key, asn_private_key, asn_public_key, pem_to_der,\ - pem_limiters, private_to_public, assert_private, rsa_private, rsa_primes, rsa_exponents,\ - rsa_bits, rsa_sign, rsa_verify, rsa_crypt_s, rsa_crypt +from .rsa import ( + open_pem_key, + open_pem_data, + write_pem_key, + open_private_key, + open_private_key_b64, + open_private_key_data, + open_public_key, + open_public_key_b64, + open_public_key_data, + write_private_key, + write_public_key, + asn_private_key, + asn_public_key, + pem_to_der, + pem_limiters, + private_to_public, + assert_private, + rsa_private, + rsa_primes, + rsa_exponents, + rsa_bits, + rsa_sign, + rsa_verify, + rsa_crypt_s, + rsa_crypt, +) from .setup import ensure_setup, ensure_ca from .smtp import SMTPParser from .socks import SOCKSParser @@ -92,8 +166,29 @@ from .style import BASE_STYLE from .tftp import RRQ_TFTP, WRQ_TFTP, DATA_TFTP, ACK_TFTP, ERROR_TFTP, TYPES_TFTP from .torrent import info_hash, bencode, bdecode, chunk, dechunk, TorrentParser -from .util import cstring, chunks, header_down, header_up, is_ip4, is_ip6, assert_ip4,\ - in_subnet_ip4, addr_to_ip4, addr_to_ip6, ip4_to_addr, string_to_bits, integer_to_bytes,\ - bytes_to_integer, random_integer, host, hostname, size_round_unit, verify, verify_equal,\ - verify_not_equal, verify_type, verify_many +from .util import ( + cstring, + chunks, + header_down, + header_up, + is_ip4, + is_ip6, + assert_ip4, + in_subnet_ip4, + addr_to_ip4, + addr_to_ip6, + ip4_to_addr, + string_to_bits, + integer_to_bytes, + bytes_to_integer, + random_integer, + host, + hostname, + size_round_unit, + verify, + verify_equal, + verify_not_equal, + verify_type, + verify_many, +) from .ws import encode_ws, decode_ws, assert_ws diff --git a/src/netius/common/asn.py b/src/netius/common/asn.py index 43e1e5c87..8d7dad0d5 100644 --- a/src/netius/common/asn.py +++ b/src/netius/common/asn.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -48,41 +39,32 @@ OBJECT_IDENTIFIER = 0x06 SEQUENCE = 0x30 -ASN1_OBJECT = [ - (SEQUENCE, [ - (SEQUENCE, [ - OBJECT_IDENTIFIER, - NULL - ]), - BIT_STRING - ]) -] +ASN1_OBJECT = [(SEQUENCE, [(SEQUENCE, [OBJECT_IDENTIFIER, NULL]), BIT_STRING])] -ASN1_RSA_PUBLIC_KEY = [ - (SEQUENCE, [ - INTEGER, - INTEGER - ]) -] +ASN1_RSA_PUBLIC_KEY = [(SEQUENCE, [INTEGER, INTEGER])] ASN1_RSA_PRIVATE_KEY = [ - (SEQUENCE, [ - INTEGER, - INTEGER, - INTEGER, - INTEGER, - INTEGER, - INTEGER, - INTEGER, - INTEGER, - INTEGER - ]) + ( + SEQUENCE, + [ + INTEGER, + INTEGER, + INTEGER, + INTEGER, + INTEGER, + INTEGER, + INTEGER, + INTEGER, + INTEGER, + ], + ) ] RSAID_PKCS1 = b"\x2a\x86\x48\x86\xf7\x0d\x01\x01\x01" HASHID_SHA1 = b"\x2b\x0e\x03\x02\x1a" HASHID_SHA256 = b"\x60\x86\x48\x01\x65\x03\x04\x02\x01" + def asn1_parse(template, data): """ Parse a data structure according to asn.1 template, @@ -111,8 +93,11 @@ def asn1_parse(template, data): # item to be parser is tuple and based on that defined # the current expected data type and children values is_tuple = type(item) == tuple - if is_tuple: dtype, children = item - else: dtype = item; children = None + if is_tuple: + dtype, children = item + else: + dtype = item + children = None # retrieves the value (as an ordinal) for the current # byte and increments the index for the parser @@ -123,7 +108,9 @@ def asn1_parse(template, data): # must raise an exception indicating the problem to # the top level layers (should be properly handled) if not tag == dtype: - raise netius.ParserError("Unexpected tag (got 0x%02x, expecting 0x%02x)" % (tag, dtype)) + raise netius.ParserError( + "Unexpected tag (got 0x%02x, expecting 0x%02x)" % (tag, dtype) + ) # retrieves the ordinal value of the current byte as # the length of the value to be parsed and then increments @@ -135,20 +122,20 @@ def asn1_parse(template, data): # the byte designates the length of the byte sequence that # defines the length of the current value to be read instead if length & 0x80: - number = length & 0x7f - length = util.bytes_to_integer(data[index:index + number]) + number = length & 0x7F + length = util.bytes_to_integer(data[index : index + number]) index += number if tag == BIT_STRING: - result.append(data[index:index + length]) + result.append(data[index : index + length]) index += length elif tag == OCTET_STRING: - result.append(data[index:index + length]) + result.append(data[index : index + length]) index += length elif tag == INTEGER: - number = util.bytes_to_integer(data[index:index + length]) + number = util.bytes_to_integer(data[index : index + length]) index += length result.append(number) @@ -157,11 +144,11 @@ def asn1_parse(template, data): result.append(None) elif tag == OBJECT_IDENTIFIER: - result.append(data[index:index + length]) + result.append(data[index : index + length]) index += length elif tag == SEQUENCE: - part = asn1_parse(children, data[index:index + length]) + part = asn1_parse(children, data[index : index + length]) result.append(part) index += length @@ -170,6 +157,7 @@ def asn1_parse(template, data): return result + def asn1_length(length): """ Returns a string representing a field length in asn.1 format. @@ -185,17 +173,20 @@ def asn1_length(length): """ netius.verify(length >= 0) - if length < 0x7f: return netius.legacy.chr(length) + if length < 0x7F: + return netius.legacy.chr(length) result = util.integer_to_bytes(length) number = len(result) result = netius.legacy.chr(number | 0x80) + result return result + def asn1_gen(node): generator = asn1_build(node) return b"".join(generator) + def asn1_build(node): """ Builds an asn.1 data structure based on pairs of (type, data), diff --git a/src/netius/common/calc.py b/src/netius/common/calc.py index 6c08212a1..ca29f2069 100644 --- a/src/netius/common/calc.py +++ b/src/netius/common/calc.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -44,6 +35,7 @@ from . import util + def prime(number_bits): """ Generates a prime number with the given number of bits @@ -70,15 +62,18 @@ def prime(number_bits): # using the primality testing strategy, and in case # it's breaks the current loop as a prime has been # found with the pre-defined number of bits - if is_prime(integer): break + if is_prime(integer): + break # returns the (generated) and verified prime integer # to the caller method, may be used for exponent return integer + def is_prime(number): return random_primality(number, 6) + def relatively_prime(first, second): # retrieves the greatest common divisor between the # two values and verifies if the value is one, for @@ -86,6 +81,7 @@ def relatively_prime(first, second): divisor = gcd(first, second) return divisor == 1 + def gcd(first, second): """ Calculates the greatest common divisor of p value and q value. @@ -104,7 +100,8 @@ def gcd(first, second): # in case the p value is smaller than the q value # reverses the order of the arguments and re-computes - if first < second: return gcd(second, first) + if first < second: + return gcd(second, first) # in case the q value is zero if second == 0: @@ -118,6 +115,7 @@ def gcd(first, second): next = abs(first % second) return gcd(second, next) + def egcd(first, second): """ Extended version of the greatest common divisor created @@ -136,7 +134,8 @@ def egcd(first, second): :see: http://en.wikipedia.org/wiki/Extended_Euclidean_algorithm """ - if second == 0: return (first, 1, 0) + if second == 0: + return (first, 1, 0) q = abs(first % second) r = first // second @@ -144,6 +143,7 @@ def egcd(first, second): return (d, l, k - l * r) + def modinv(first, second): """ Uses the extended greatest common divisor algorithm to compute @@ -164,8 +164,11 @@ def modinv(first, second): """ d, l, _e = egcd(first, second) - if d != 1: raise netius.DataError("Modular inverse does not exist") - else: return l % second + if d != 1: + raise netius.DataError("Modular inverse does not exist") + else: + return l % second + def random_integer_interval(min_value, max_value): # sets the default minimum number of bits, even if the @@ -195,6 +198,7 @@ def random_integer_interval(min_value, max_value): random_base_value = util.random_integer(number_bits) % range return random_base_value + min_value + def random_primality(number, k): """ Uses a probabilistic approach to the testing of primality @@ -225,12 +229,14 @@ def random_primality(number, k): # is going to be verified random_number = random_integer_interval(1, number - 1) is_witness = jacobi_witness(random_number, number) - if is_witness: return False + if is_witness: + return False # returns valid as no jacobi witness has been found # for the current number that is being verified return True + def jacobi_witness(x, n): """ Checks if the given x value is witness to n value @@ -251,8 +257,11 @@ def jacobi_witness(x, n): j = jacobi(x, n) % n f = pow(x, (n - 1) // 2, n) - if j == f: return False - else: return True + if j == f: + return False + else: + return True + def jacobi(a, b): """ @@ -272,20 +281,24 @@ def jacobi(a, b): :see: http://en.wikipedia.org/wiki/Jacobi_symbol """ - if a % b == 0: return 0 + if a % b == 0: + return 0 result = 1 while a > 1: if a & 1: - if ((a - 1) * (b - 1) >> 2) & 1: result = -result + if ((a - 1) * (b - 1) >> 2) & 1: + result = -result b, a = a, b % a else: - if ((b ** 2 - 1) >> 3) & 1: result = -result + if ((b**2 - 1) >> 3) & 1: + result = -result a >>= 1 return result + def ceil_integer(value): """ Retrieves the ceil of a value and then converts it diff --git a/src/netius/common/dhcp.py b/src/netius/common/dhcp.py index 220805190..302ea8977 100644 --- a/src/netius/common/dhcp.py +++ b/src/netius/common/dhcp.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -62,46 +53,47 @@ END_DHCP = 18 OPTIONS_DHCP = dict( - subnet = SUBNET_DHCP, - router = ROUTER_DHCP, - dns = DNS_DHCP, - name = NAME_DHCP, - broadcast = BROADCAST_DHCP, - lease = LEASE_DHCP, - discover = DISCOVER_DHCP, - offer = OFFER_DHCP, - request = REQUEST_DHCP, - decline = DECLINE_DHCP, - ack = ACK_DHCP, - nak = NAK_DHCP, - identifier = IDENTIFIER_DHCP, - renewal = RENEWAL_DHCP, - rebind = REBIND_DHCP, - proxy = PROXY_DHCP, - end = END_DHCP + subnet=SUBNET_DHCP, + router=ROUTER_DHCP, + dns=DNS_DHCP, + name=NAME_DHCP, + broadcast=BROADCAST_DHCP, + lease=LEASE_DHCP, + discover=DISCOVER_DHCP, + offer=OFFER_DHCP, + request=REQUEST_DHCP, + decline=DECLINE_DHCP, + ack=ACK_DHCP, + nak=NAK_DHCP, + identifier=IDENTIFIER_DHCP, + renewal=RENEWAL_DHCP, + rebind=REBIND_DHCP, + proxy=PROXY_DHCP, + end=END_DHCP, ) """ The map of option names that associates a string based name with the integer based counter-part for resolution """ TYPES_DHCP = { - 0x01 : "discover", - 0x02 : "offer", - 0x03 : "request", - 0x04 : "decline", - 0x05 : "ack", - 0x06 : "nak" + 0x01: "discover", + 0x02: "offer", + 0x03: "request", + 0x04: "decline", + 0x05: "ack", + 0x06: "nak", } VERBS_DHCP = { - 0x01 : "discovering", - 0x02 : "offering", - 0x03 : "requesting", - 0x04 : "declining", - 0x05 : "acknowledging", - 0x06 : "not acknowledging" + 0x01: "discovering", + 0x02: "offering", + 0x03: "requesting", + 0x04: "declining", + 0x05: "acknowledging", + 0x06: "not acknowledging", } + class AddressPool(object): def __init__(self, start_addr, end_addr): @@ -121,8 +113,11 @@ def get_next(cls, current): current_l = [int(value) for value in current_l] for index, value in enumerate(current_l): - if value == 255: current_l[index] = 0 - else: current_l[index] = value + 1; break + if value == 255: + current_l[index] = 0 + else: + current_l[index] = value + 1 + break current_l.reverse() @@ -137,7 +132,8 @@ def peek(self): target, addr = heapq.heappop(self.addrs) _target = self.map.get(addr, 0) - if not target == _target: continue + if not target == _target: + continue if target > current: heapq.heappush(self.addrs, (target, addr)) @@ -147,7 +143,7 @@ def peek(self): return addr - def reserve(self, owner = None, lease = 3600): + def reserve(self, owner=None, lease=3600): current = time.time() target = int(current + lease) addr = self.peek() @@ -157,11 +153,10 @@ def reserve(self, owner = None, lease = 3600): heapq.heappush(self.addrs, (target, addr)) return addr - def touch(self, addr, lease = 3600): + def touch(self, addr, lease=3600): is_valid = self.is_valid(addr) - if not is_valid: raise netius.NetiusError( - "Not possible to touch address" - ) + if not is_valid: + raise netius.NetiusError("Not possible to touch address") current = time.time() target = int(current + lease) @@ -181,7 +176,8 @@ def is_valid(self, addr): def is_owner(self, owner, addr): is_valid = self.is_valid(addr) - if not is_valid: return False + if not is_valid: + return False _owner = self.owners.get(addr, None) return owner == _owner @@ -192,5 +188,6 @@ def _populate(self): self.map[addr] = 0 self.owners[addr] = None heapq.heappush(self.addrs, (0, addr)) - if addr == self.end_addr: break + if addr == self.end_addr: + break addr = AddressPool.get_next(addr) diff --git a/src/netius/common/dkim.py b/src/netius/common/dkim.py index f23cdd153..8974b7149 100644 --- a/src/netius/common/dkim.py +++ b/src/netius/common/dkim.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -50,19 +41,14 @@ from . import util from . import mime + def dkim_sign( - message, - selector, - domain, - private_key, - identity = None, - separator = ":", - creation = None + message, selector, domain, private_key, identity=None, separator=":", creation=None ): separator = netius.legacy.bytes(separator) identity = identity or "@" + domain - headers, body = mime.rfc822_parse(message, strip = False) + headers, body = mime.rfc822_parse(message, strip=False) if not identity.endswith(domain): raise netius.GeneratorError("Identity must end with domain") @@ -71,7 +57,9 @@ def dkim_sign( body = dkim_body(body) include_headers = [name.lower() for name, _value in headers] - sign_headers = [header for header in headers if header[0].lower() in include_headers] + sign_headers = [ + header for header in headers if header[0].lower() in include_headers + ] sign_names = [name for name, _value in sign_headers] hash = hashlib.sha256() @@ -118,13 +106,19 @@ def dkim_sign( digest = hash.digest() digest_info = asn.asn1_gen( - (asn.SEQUENCE, [ - (asn.SEQUENCE, [ - (asn.OBJECT_IDENTIFIER, asn.HASHID_SHA256), - (asn.NULL, None), - ]), - (asn.OCTET_STRING, digest), - ]) + ( + asn.SEQUENCE, + [ + ( + asn.SEQUENCE, + [ + (asn.OBJECT_IDENTIFIER, asn.HASHID_SHA256), + (asn.NULL, None), + ], + ), + (asn.OCTET_STRING, digest), + ], + ) ) modulus = private_key["modulus"] @@ -143,22 +137,25 @@ def dkim_sign( base_i = util.bytes_to_integer(base) signature_i = rsa.rsa_crypt(base_i, exponent, modulus) - signature_s = util.integer_to_bytes(signature_i, length = modulus_l) + signature_s = util.integer_to_bytes(signature_i, length=modulus_l) signature += base64.b64encode(signature_s) + b"\r\n" return signature + def dkim_headers(headers): # returns the headers exactly the way they were parsed # as this is the simple strategy approach return headers + def dkim_body(body): # remove the complete set of empty lines in the body # and adds only one line to the end of it as requested return re.sub(b"(\r\n)*$", b"\r\n", body) -def dkim_fold(header, length = 72): + +def dkim_fold(header, length=72): """ Folds a header line into multiple line feed separated lines at column length defined (defaults to 72). @@ -178,7 +175,8 @@ def dkim_fold(header, length = 72): """ index = header.rfind(b"\r\n ") - if index == -1: pre = b"" + if index == -1: + pre = b"" else: index += 3 pre = header[:index] @@ -186,23 +184,27 @@ def dkim_fold(header, length = 72): while len(header) > length: index = header[:length].rfind(b" ") - if index == -1: _index = index - else: _index = index + 1 + if index == -1: + _index = index + else: + _index = index + 1 pre += header[:index] + b"\r\n " header = header[_index:] return pre + header -def dkim_generate(domain, suffix = None, number_bits = 1024): + +def dkim_generate(domain, suffix=None, number_bits=1024): date_time = datetime.datetime.utcnow() selector = date_time.strftime("%Y%m%d%H%M%S") - if suffix: selector += "." + suffix + if suffix: + selector += "." + suffix selector_full = "%s._domainkey.%s." % (selector, domain) private_key = rsa.rsa_private(number_bits) - rsa.assert_private(private_key, number_bits = number_bits) + rsa.assert_private(private_key, number_bits=number_bits) public_key = rsa.private_to_public(private_key) buffer = netius.legacy.BytesIO() @@ -217,11 +219,11 @@ def dkim_generate(domain, suffix = None, number_bits = 1024): public_b64 = base64.b64encode(public_data) public_b64 = netius.legacy.str(public_b64) - dns_txt = "%s IN TXT \"k=rsa; p=%s\"" % (selector_full, public_b64) + dns_txt = '%s IN TXT "k=rsa; p=%s"' % (selector_full, public_b64) return dict( - selector = selector, - selector_full = selector_full, - private_pem = private_pem, - dns_txt = dns_txt + selector=selector, + selector_full=selector_full, + private_pem=private_pem, + dns_txt=dns_txt, ) diff --git a/src/netius/common/ftp.py b/src/netius/common/ftp.py index f849b10e5..c0efd97e7 100644 --- a/src/netius/common/ftp.py +++ b/src/netius/common/ftp.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -48,9 +39,10 @@ the various response lines between the code and the message part, it should handle both normal and continuation lines """ + class FTPParser(parser.Parser): - def __init__(self, owner, store = False): + def __init__(self, owner, store=False): parser.Parser.__init__(self, owner) self.buffer = [] @@ -85,7 +77,8 @@ def parse(self, data): # zero the parsing iteration is broken method = self._parse_line count = method(data) - if count == 0: break + if count == 0: + break # decrements the size of the data buffer by the # size of the parsed bytes and then retrieves the @@ -96,7 +89,8 @@ def parse(self, data): # in case not all of the data has been processed # must add it to the buffer so that it may be used # latter in the next parsing of the message - if size > 0: self.buffer.append(data) + if size > 0: + self.buffer.append(data) # returns the number of read (processed) bytes of the # data that has been sent to the parser @@ -106,7 +100,8 @@ def _parse_line(self, data): # tries to find the new line character in the currently received # data in case it's not found returns immediately with no data processed index = data.find(b"\n") - if index == -1: return 0 + if index == -1: + return 0 # adds the partial data (until new line) to the current buffer and # then joins it retrieving the current line, then deletes the buffer @@ -120,7 +115,8 @@ def _parse_line(self, data): # the split is not successful (not enough information) then an extra # value is added to the sequence of values for compatibility values = SEPARATOR_REGEX.split(self.line_s, 1) - if not len(values) > 1: values.append("") + if not len(values) > 1: + values.append("") # unpacks the set of values that have just been parsed into the code # and the message items as expected by the ftp specification @@ -133,13 +129,15 @@ def _parse_line(self, data): line_l = len(self.line_s) space_index = self.line_s.find(" ") token_index = self.line_s.find("-") - if space_index == -1: space_index = line_l - if token_index == -1: token_index = line_l + if space_index == -1: + space_index = line_l + if token_index == -1: + token_index = line_l is_continuation = token_index < space_index is_final = not is_continuation # triggers the on line event so that the listeners are notified # about the end of the parsing of the ftp line and then # returns the count of the parsed bytes of the message - self.trigger("on_line", code, message, is_final = is_final) + self.trigger("on_line", code, message, is_final=is_final) return index + 1 diff --git a/src/netius/common/geo.py b/src/netius/common/geo.py index 2e323ac49..c40299cb3 100644 --- a/src/netius/common/geo.py +++ b/src/netius/common/geo.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -43,6 +34,7 @@ import netius + class GeoResolver(object): DB_NAME = "GeoLite2-City.mmdb" @@ -67,19 +59,24 @@ class GeoResolver(object): that is going to be used in the GeoIP resolution """ @classmethod - def resolve(cls, address, simplified = True): + def resolve(cls, address, simplified=True): db = cls._get_db() - if not db: return None + if not db: + return None result = db.get(address) - if simplified: result = cls._simplify(result) + if simplified: + result = cls._simplify(result) return result @classmethod - def _simplify(cls, result, locale = "en", valid = VALID): - if not result: return result + def _simplify(cls, result, locale="en", valid=VALID): + if not result: + return result for name, value in netius.legacy.items(result): - if not name in valid: del result[name] - if not "names" in value: continue + if not name in valid: + del result[name] + if not "names" in value: + continue names = value["names"] value["name"] = names.get(locale, None) del value["names"] @@ -87,69 +84,83 @@ def _simplify(cls, result, locale = "en", valid = VALID): @classmethod def _get_db(cls): - if cls._db: return cls._db - try: import maxminddb - except ImportError: return None + if cls._db: + return cls._db + try: + import maxminddb + except ImportError: + return None path = cls._try_all() - if not path: return None + if not path: + return None cls._db = maxminddb.open_database(path) return cls._db @classmethod - def _try_all(cls, prefixes = PREFIXES): + def _try_all(cls, prefixes=PREFIXES): for prefix in cls.PREFIXES: - path = cls._try_db(path = prefix + cls.DB_NAME) - if path: return path - path = cls._try_db(path = cls.DB_NAME, download = True) - if path: return path + path = cls._try_db(path=prefix + cls.DB_NAME) + if path: + return path + path = cls._try_db(path=cls.DB_NAME, download=True) + if path: + return path return None @classmethod - def _try_db(cls, path = DB_NAME, download = False): + def _try_db(cls, path=DB_NAME, download=False): path = os.path.expanduser(path) path = os.path.normpath(path) exists = os.path.exists(path) - if exists: return path - if not download: return None - cls._download_db(path = path) + if exists: + return path + if not download: + return None + cls._download_db(path=path) exists = not os.path.exists(path) - if not exists: return None + if not exists: + return None return path @classmethod - def _download_db(cls, path = DB_NAME): + def _download_db(cls, path=DB_NAME): import netius.clients + result = netius.clients.HTTPClient.method_s( - "GET", - cls.DOWNLOAD_URL, - asynchronous = False + "GET", cls.DOWNLOAD_URL, asynchronous=False ) response = netius.clients.HTTPClient.to_response(result) contents = response.read() - cls._store_db(contents, path = path) + cls._store_db(contents, path=path) @classmethod - def _store_db(cls, contents, path = DB_NAME): + def _store_db(cls, contents, path=DB_NAME): path_gz = path + ".gz" file = open(path_gz, "wb") - try: file.write(contents) - finally: file.close() + try: + file.write(contents) + finally: + file.close() file = gzip.open(path_gz, "rb") - try: contents = file.read() - finally: file.close() + try: + contents = file.read() + finally: + file.close() file = open(path, "wb") - try: file.write(contents) - finally: file.close() + try: + file.write(contents) + finally: + file.close() os.remove(path_gz) return path + if __name__ == "__main__": prefix = "~/" - if len(sys.argv) > 1: prefix = sys.argv[1] - if not prefix.endswith("/"): prefix += "/" - GeoResolver._try_db( - path = prefix + GeoResolver.DB_NAME, - download = True - ) + if len(sys.argv) > 1: + prefix = sys.argv[1] + if not prefix.endswith("/"): + prefix += "/" + GeoResolver._try_db(path=prefix + GeoResolver.DB_NAME, download=True) else: __path__ = [] diff --git a/src/netius/common/http.py b/src/netius/common/http.py index 9a0f949c9..50167ade7 100644 --- a/src/netius/common/http.py +++ b/src/netius/common/http.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -114,56 +105,52 @@ most commonly used nowadays, connection running under this version of the protocol should keep connections open """ -VERSIONS_MAP = { - "HTTP/0.9" : HTTP_09, - "HTTP/1.0" : HTTP_10, - "HTTP/1.1" : HTTP_11 -} +VERSIONS_MAP = {"HTTP/0.9": HTTP_09, "HTTP/1.0": HTTP_10, "HTTP/1.1": HTTP_11} """ Maps associating the standard HTTP version string with the corresponding enumeration based values for each of them """ CODE_STRINGS = { - 100 : "Continue", - 101 : "Switching Protocols", - 200 : "OK", - 201 : "Created", - 202 : "Accepted", - 203 : "Non-Authoritative Information", - 204 : "No Content", - 205 : "Reset Content", - 206 : "Partial Content", - 207 : "Multi-Status", - 301 : "Moved permanently", - 302 : "Found", - 303 : "See Other", - 304 : "Not Modified", - 305 : "Use Proxy", - 306 : "(Unused)", - 307 : "Temporary Redirect", - 400 : "Bad Request", - 401 : "Unauthorized", - 402 : "Payment Required", - 403 : "Forbidden", - 404 : "Not Found", - 405 : "Method Not Allowed", - 406 : "Not Acceptable", - 407 : "Proxy Authentication Required", - 408 : "Request Timeout", - 409 : "Conflict", - 410 : "Gone", - 411 : "Length Required", - 412 : "Precondition Failed", - 413 : "Request Entity Too Large", - 414 : "Request-URI Too Long", - 415 : "Unsupported Media Type", - 416 : "Requested Range Not Satisfiable", - 417 : "Expectation Failed", - 500 : "Internal Server Error", - 501 : "Not Implemented", - 502 : "Bad Gateway", - 503 : "Service Unavailable", - 504 : "Gateway Timeout", - 505 : "HTTP Version Not Supported" + 100: "Continue", + 101: "Switching Protocols", + 200: "OK", + 201: "Created", + 202: "Accepted", + 203: "Non-Authoritative Information", + 204: "No Content", + 205: "Reset Content", + 206: "Partial Content", + 207: "Multi-Status", + 301: "Moved permanently", + 302: "Found", + 303: "See Other", + 304: "Not Modified", + 305: "Use Proxy", + 306: "(Unused)", + 307: "Temporary Redirect", + 400: "Bad Request", + 401: "Unauthorized", + 402: "Payment Required", + 403: "Forbidden", + 404: "Not Found", + 405: "Method Not Allowed", + 406: "Not Acceptable", + 407: "Proxy Authentication Required", + 408: "Request Timeout", + 409: "Conflict", + 410: "Gone", + 411: "Length Required", + 412: "Precondition Failed", + 413: "Request Entity Too Large", + 414: "Request-URI Too Long", + 415: "Unsupported Media Type", + 416: "Requested Range Not Satisfiable", + 417: "Expectation Failed", + 500: "Internal Server Error", + 501: "Not Implemented", + 502: "Bad Gateway", + 503: "Service Unavailable", + 504: "Gateway Timeout", + 505: "HTTP Version Not Supported", } """ Dictionary associating the error code as integers with the official descriptive message for it """ @@ -173,6 +160,7 @@ header naming tokens, so that only the valid names are captured avoiding possible security issues, should be compliant with RFC 7230 """ + class HTTPParser(parser.Parser): """ Parser object for the HTTP format, should be able to @@ -214,20 +202,14 @@ class HTTPParser(parser.Parser): "chunk_d", "chunk_l", "chunk_s", - "chunk_e" + "chunk_e", ) - def __init__( - self, - owner, - type = REQUEST, - store = False, - file_limit = FILE_LIMIT - ): + def __init__(self, owner, type=REQUEST, store=False, file_limit=FILE_LIMIT): parser.Parser.__init__(self, owner) self.build() - self.reset(type = type, store = store, file_limit = file_limit) + self.reset(type=type, store=store, file_limit=file_limit) def build(self): """ @@ -240,11 +222,7 @@ def build(self): self.connection = self.owner - self.states = ( - self._parse_line, - self._parse_headers, - self._parse_message - ) + self.states = (self._parse_line, self._parse_headers, self._parse_message) self.state_l = len(self.states) def destroy(self): @@ -263,7 +241,7 @@ def destroy(self): self.states = () self.state_l = 0 - def reset(self, type = REQUEST, store = False, file_limit = FILE_LIMIT): + def reset(self, type=REQUEST, store=False, file_limit=FILE_LIMIT): """ Initializes the state of the parser setting the values for the various internal structures to the original value. @@ -314,19 +292,18 @@ def reset(self, type = REQUEST, store = False, file_limit = FILE_LIMIT): self.chunk_s = 0 self.chunk_e = 0 - def clear(self, force = False): - if not force and self.state == LINE_STATE: return - self.reset( - type = self.type, - store = self.store, - file_limit = self.file_limit - ) + def clear(self, force=False): + if not force and self.state == LINE_STATE: + return + self.reset(type=self.type, store=self.store, file_limit=self.file_limit) def close(self): - if hasattr(self, "message") and self.message: self.message = [] - if hasattr(self, "message_f") and self.message_f: self.message_f.close() + if hasattr(self, "message") and self.message: + self.message = [] + if hasattr(self, "message_f") and self.message_f: + self.message_f.close() - def get_path(self, normalize = False): + def get_path(self, normalize=False): """ Retrieves the path associated with the request, this value should be interpreted from the HTTP status line. @@ -344,8 +321,10 @@ def get_path(self, normalize = False): split = self.path_s.split("?", 1) path = split[0] - if not normalize: return path - if not path.startswith(("http://", "https://")): return path + if not normalize: + return path + if not path.startswith(("http://", "https://")): + return path return netius.legacy.urlparse(path).path def get_query(self): @@ -362,8 +341,10 @@ def get_query(self): """ split = self.path_s.split("?", 1) - if len(split) == 1: return "" - else: return split[1] + if len(split) == 1: + return "" + else: + return split[1] def get_message(self): """ @@ -382,16 +363,19 @@ def get_message(self): string value that may be used as a simple buffer. """ - if self.message_s: return self.message_s - if self.message_f: self.message_s = self.get_message_f() - else: self.message_s = b"".join(self.message) + if self.message_s: + return self.message_s + if self.message_f: + self.message_s = self.get_message_f() + else: + self.message_s = b"".join(self.message) return self.message_s def get_message_f(self): self.message_f.seek(0) return self.message_f.read() - def get_message_b(self, copy = False, size = 40960): + def get_message_b(self, copy=False, size=40960): """ Retrieves a new buffer associated with the currently loaded message, the first time this method is called a @@ -424,26 +408,31 @@ def get_message_b(self, copy = False, size = 40960): # and writes the value of the message into it if not self.message_f: self.message_f = netius.legacy.BytesIO() - for value in self.message: self.message_f.write(value) + for value in self.message: + self.message_f.write(value) # restores the message file to the original/initial position and # then in case there's no copy required returns it immediately self.message_f.seek(0) - if not copy: return self.message_f + if not copy: + return self.message_f # determines if the file limit for a temporary file has been # surpassed and if that's the case creates a named temporary # file, otherwise created a memory based buffer use_file = self.store and self.content_l >= self.file_limit - if use_file: message_f = tempfile.NamedTemporaryFile(mode = "w+b") - else: message_f = netius.legacy.BytesIO() + if use_file: + message_f = tempfile.NamedTemporaryFile(mode="w+b") + else: + message_f = netius.legacy.BytesIO() try: # iterates continuously reading the contents from the message # file and writing them back to the output (copy) file while True: data = self.message_f.read(size) - if not data: break + if not data: + break message_f.write(data) finally: # resets both of the message file (output and input) to the @@ -464,7 +453,8 @@ def get_headers(self): return headers def get_encodings(self): - if not self.encodings == None: return self.encodings + if not self.encodings == None: + return self.encodings accept_encoding_s = self.headers.get("accept-encoding", "") self.encodings = [value.strip() for value in accept_encoding_s.split(",")] return self.encodings @@ -488,7 +478,8 @@ def parse(self, data): # in case the current state of the parser is finished, must # reset the state to the start position as the parser is # re-starting (probably a new data sequence) - if self.state == FINISH_STATE: self.clear() + if self.state == FINISH_STATE: + self.clear() # retrieves the size of the data that has been sent for parsing # and saves it under the size original variable @@ -510,7 +501,8 @@ def parse(self, data): # zero the parsing iteration is broken method = self.states[self.state - 1] count = method(data) - if count == 0: break + if count == 0: + break # decrements the size of the data buffer by the # size of the parsed bytes and then retrieves the @@ -537,7 +529,8 @@ def parse(self, data): # in case not all of the data has been processed # must add it to the buffer so that it may be used # latter in the next parsing of the message - if size > 0: self.buffer.append(data) + if size > 0: + self.buffer.append(data) # returns the number of read (processed) bytes of the # data that has been sent to the parser @@ -548,7 +541,8 @@ def _parse_line(self, data): # data in case there's one it's considered that the the # initial line must have been found index = data.find(b"\n") - if index == -1: return 0 + if index == -1: + return 0 # adds the partial data (until line ending) to the buffer # and then joins the buffer as the initial line, this value @@ -570,10 +564,10 @@ def _parse_line(self, data): # that for responses the parsing is relaxed as the status string # can be an empty string (no message to be presented) values = self.line_s.split(" ", 2) - if self.type == RESPONSE and len(values) == 2: values.append("") - if not len(values) == 3: raise netius.ParserError( - "Invalid status line '%s'" % self.line_s - ) + if self.type == RESPONSE and len(values) == 2: + values.append("") + if not len(values) == 3: + raise netius.ParserError("Invalid status line '%s'" % self.line_s) # determines if the current type of parsing is request based # and if that's the case unpacks the status line as a request @@ -613,7 +607,8 @@ def _parse_headers(self, data): # it's not found returns the zero value meaning that # the no bytes have been processed (delays parsing) index = buffer_s.find(b"\r\n\r\n") - if index == -1: return 0 + if index == -1: + return 0 # retrieves the partial headers string from the buffer # string and then deletes the current buffer so that @@ -636,7 +631,8 @@ def _parse_headers(self, data): # verifies if the line contains any information if # that's not the case the current cycle must be # skipped as this may be an extra empty line - if not line: continue + if not line: + continue # tries to split the line around the key to value # separator in case there's no valid split (two @@ -661,7 +657,7 @@ def _parse_headers(self, data): # both the beginning and the end of it, then makes sure that # no extra "space like" character exist in it value = value.strip(b" ") - value = netius.legacy.str(value, errors = "replace") + value = netius.legacy.str(value, errors="replace") if not value == value.strip(): raise netius.ParserError("Invalid header value") @@ -672,7 +668,8 @@ def _parse_headers(self, data): if exists: sequence = self.headers[key] is_list = type(sequence) == list - if not is_list: sequence = [sequence] + if not is_list: + sequence = [sequence] sequence.append(value) value = sequence @@ -690,7 +687,8 @@ def _parse_headers(self, data): # the file contents, this is done by checking the store flag # and verifying that the file limit value has been reached use_file = self.store and self.content_l >= self.file_limit - if use_file: self.message_f = tempfile.NamedTemporaryFile(mode = "w+b") + if use_file: + self.message_f = tempfile.NamedTemporaryFile(mode="w+b") # retrieves the type of transfer encoding that is going to be # used in the processing of this request in case it's of type @@ -712,22 +710,23 @@ def _parse_headers(self, data): # in case the current response in parsing has the no content # code (no payload present) the content length is set to the # zero value in case it has not already been populated - if self.type == RESPONSE and self.code in (204, 304) and\ - self.content_l == -1: self.content_l = 0 + if self.type == RESPONSE and self.code in (204, 304) and self.content_l == -1: + self.content_l = 0 # in case the current request is not chunked and the content length # header is not defined the content length is set to zero because # for normal requests with payload the content length is required # and if it's omitted it means there's no payload present - if self.type == REQUEST and not self.chunked and\ - self.content_l == -1: self.content_l = 0 + if self.type == REQUEST and not self.chunked and self.content_l == -1: + self.content_l = 0 # verifies if the connection is meant to be kept alive by # verifying the current value of the connection header against # the expected keep alive string value, note that the verification # takes into account a possible list value in connection self.connection_s = self.headers.get("connection", None) - if type(self.connection_s) == list: self.connection_s = self.connection_s[0] + if type(self.connection_s) == list: + self.connection_s = self.connection_s[0] self.connection_s = self.connection_s and self.connection_s.lower() self.keep_alive = self.connection_s == "keep-alive" self.keep_alive |= self.connection_s == None and self.version >= HTTP_11 @@ -738,19 +737,24 @@ def _parse_headers(self, data): # updates the current state of parsing to the message state # as that the headers are followed by the message - if has_finished: self.state = FINISH_STATE - else: self.state = MESSAGE_STATE + if has_finished: + self.state = FINISH_STATE + else: + self.state = MESSAGE_STATE # triggers the on headers event so that the listener object # is notified about the parsing of the headers and than returns # the parsed amount of information (bytes) to the caller self.trigger("on_headers") - if has_finished: self.trigger("on_data") + if has_finished: + self.trigger("on_data") return base_index + 4 def _parse_message(self, data): - if self.chunked: return self._parse_chunked(data) - else: return self._parse_normal(data) + if self.chunked: + return self._parse_chunked(data) + else: + return self._parse_normal(data) def _parse_normal(self, data): # retrieves the size of the data that has just been @@ -758,21 +762,22 @@ def _parse_normal(self, data): # stores the data in the proper buffer and increments # the message length counter with the size of the data data_l = len(data) - if self.store: self._store_data(data) + if self.store: + self._store_data(data) self.message_l += data_l # verifies if the complete message has already been # received, that occurs if the content length is # defined and the value is the is the same as the # currently defined message length - has_finished = not self.content_l == -1 and\ - self.message_l == self.content_l + has_finished = not self.content_l == -1 and self.message_l == self.content_l # triggers the partial data received event and then # in case the complete message has not been received # returns immediately the length of processed data self.trigger("on_partial", data) - if not has_finished: return data_l + if not has_finished: + return data_l # updates the current state to the finish state and then # triggers the on data event (indicating the end of the @@ -832,7 +837,8 @@ def _parse_chunked(self, data): # case the file storage mode is active (spares memory), # deletes the contents of the message buffer as they're # not going to be used to access request's data as a whole - if not self.store or self.message_f: del self.message[:] + if not self.store or self.message_f: + del self.message[:] # returns the number of bytes that have been parsed by # the current end of chunk operation to the caller method @@ -846,7 +852,8 @@ def _parse_chunked(self, data): # tries to find the separator of the initial value for # the chunk in case it's not found returns immediately index = data.find(b"\n") - if index == -1: return 0 + if index == -1: + return 0 # some of the current data to the buffer and then re-joins # it as the header value, then removes the complete set of @@ -857,14 +864,14 @@ def _parse_chunked(self, data): # sets the new data buffer as the partial buffer of the data # except the extra newline character (not required) - data = data[index + 1:] + data = data[index + 1 :] # splits the header value so that additional chunk information # is removed and then parsed the value as the original chunk # size (dimension) adding the two extra bytes to the length header_s = header.split(b";", 1) size = header_s[0] - self.chunk_d = int(size.strip(), base = 16) + self.chunk_d = int(size.strip(), base=16) self.chunk_l = self.chunk_d + 2 self.chunk_s = len(self.message) @@ -876,7 +883,7 @@ def _parse_chunked(self, data): # retrieves the partial data that is valid according to the # calculated chunk length and then calculates the size of # "that" partial data string value - data = data[:self.chunk_l - 2] + data = data[: self.chunk_l - 2] data_s = len(data) # adds the partial data to the message list and runs the store operation @@ -885,28 +892,34 @@ def _parse_chunked(self, data): # the message buffer is used even if the store flag is not set, so that # it's possible to refer the chunk as a tuple of start and end indexes when # triggering the chunk parsed (on chunk) event (performance gains) - if data: self.message.append(data) - if data and self.store: self._store_data(data, memory = False) + if data: + self.message.append(data) + if data and self.store: + self._store_data(data, memory=False) self.chunk_l -= data_s # in case there's data parsed the partial data event # is triggered to notify handlers about the new data - if data: self.trigger("on_partial", data) + if data: + self.trigger("on_partial", data) # increments the byte counter value by the size of the data # and then returns the same counter to the caller method count += data_s return count - def _store_data(self, data, memory = True): - if not self.store: raise netius.ParserError("Store is not possible") - if self.message_f: self.message_f.write(data) - elif memory: self.message.append(data) + def _store_data(self, data, memory=True): + if not self.store: + raise netius.ParserError("Store is not possible") + if self.message_f: + self.message_f.write(data) + elif memory: + self.message.append(data) def _parse_query(self, query): # runs the "default" parsing of the query string from the system # and then decodes the complete set of parameters properly - params = netius.legacy.parse_qs(query, keep_blank_values = True) + params = netius.legacy.parse_qs(query, keep_blank_values=True) return self._decode_params(params) def _decode_params(self, params): @@ -916,17 +929,20 @@ def _decode_params(self, params): items = [] for item in value: is_bytes = netius.legacy.is_bytes(item) - if is_bytes: item = item.decode("utf-8") + if is_bytes: + item = item.decode("utf-8") items.append(item) is_bytes = netius.legacy.is_bytes(key) - if is_bytes: key = key.decode("utf-8") + if is_bytes: + key = key.decode("utf-8") _params[key] = items return _params + class HTTPResponse(object): - def __init__(self, data = None, code = 200, status = None, headers = None): + def __init__(self, data=None, code=200, status=None, headers=None): self.data = data self.code = code self.status = status diff --git a/src/netius/common/http2.py b/src/netius/common/http2.py index 9995bceb9..179bc1b78 100644 --- a/src/netius/common/http2.py +++ b/src/netius/common/http2.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -70,10 +61,10 @@ REFUSED_STREAM = 0x07 CANCEL = 0x08 COMPRESSION_ERROR = 0x09 -CONNECT_ERROR = 0x0a -ENHANCE_YOUR_CALM = 0x0b -INADEQUATE_SECURITY = 0x0c -HTTP_1_1_REQUIRED = 0x0d +CONNECT_ERROR = 0x0A +ENHANCE_YOUR_CALM = 0x0B +INADEQUATE_SECURITY = 0x0C +HTTP_1_1_REQUIRED = 0x0D SETTINGS_HEADER_TABLE_SIZE = 0x01 SETTINGS_ENABLE_PUSH = 0x02 @@ -121,43 +112,43 @@ (SETTINGS_MAX_CONCURRENT_STREAMS, "SETTINGS_MAX_CONCURRENT_STREAMS"), (SETTINGS_INITIAL_WINDOW_SIZE, "SETTINGS_INITIAL_WINDOW_SIZE"), (SETTINGS_MAX_FRAME_SIZE, "SETTINGS_MAX_FRAME_SIZE"), - (SETTINGS_MAX_HEADER_LIST_SIZE, "SETTINGS_MAX_HEADER_LIST_SIZE") + (SETTINGS_MAX_HEADER_LIST_SIZE, "SETTINGS_MAX_HEADER_LIST_SIZE"), ) """ The sequence of tuple that associate the constant value of the setting with the proper string representation for it """ HTTP2_NAMES = { - DATA : "DATA", - HEADERS : "HEADERS", - PRIORITY : "PRIORITY", - RST_STREAM : "RST_STREAM", - SETTINGS : "SETTINGS", - PUSH_PROMISE : "PUSH_PROMISE", - PING : "PING", - GOAWAY : "GOAWAY", - WINDOW_UPDATE : "WINDOW_UPDATE", - CONTINUATION : "CONTINUATION" + DATA: "DATA", + HEADERS: "HEADERS", + PRIORITY: "PRIORITY", + RST_STREAM: "RST_STREAM", + SETTINGS: "SETTINGS", + PUSH_PROMISE: "PUSH_PROMISE", + PING: "PING", + GOAWAY: "GOAWAY", + WINDOW_UPDATE: "WINDOW_UPDATE", + CONTINUATION: "CONTINUATION", } """ The association between the various types of frames described as integers and their representation as strings """ HTTP2_SETTINGS = { - SETTINGS_HEADER_TABLE_SIZE : 4096, - SETTINGS_ENABLE_PUSH : 1, - SETTINGS_MAX_CONCURRENT_STREAMS : 128, - SETTINGS_INITIAL_WINDOW_SIZE : 65535, - SETTINGS_MAX_FRAME_SIZE : 16384, - SETTINGS_MAX_HEADER_LIST_SIZE : 16384 + SETTINGS_HEADER_TABLE_SIZE: 4096, + SETTINGS_ENABLE_PUSH: 1, + SETTINGS_MAX_CONCURRENT_STREAMS: 128, + SETTINGS_INITIAL_WINDOW_SIZE: 65535, + SETTINGS_MAX_FRAME_SIZE: 16384, + SETTINGS_MAX_HEADER_LIST_SIZE: 16384, } """ The default values to be used for settings of a newly created connection, this should be defined according to specification """ HTTP2_SETTINGS_OPTIMAL = { - SETTINGS_HEADER_TABLE_SIZE : 4096, - SETTINGS_MAX_CONCURRENT_STREAMS : 512, - SETTINGS_INITIAL_WINDOW_SIZE : 1048576, - SETTINGS_MAX_FRAME_SIZE : 131072, - SETTINGS_MAX_HEADER_LIST_SIZE : 16384 + SETTINGS_HEADER_TABLE_SIZE: 4096, + SETTINGS_MAX_CONCURRENT_STREAMS: 512, + SETTINGS_INITIAL_WINDOW_SIZE: 1048576, + SETTINGS_MAX_FRAME_SIZE: 131072, + SETTINGS_MAX_HEADER_LIST_SIZE: 16384, } """ The optimal settings meant to be used by an infra-structure deployed in a production environment """ @@ -168,6 +159,7 @@ HTTP2_SETTINGS_OPTIMAL_T = netius.legacy.items(HTTP2_SETTINGS_OPTIMAL) """ The tuple sequence version of the settings optimal """ + class HTTP2Parser(parser.Parser): FIELDS = ( @@ -183,22 +175,14 @@ class HTTP2Parser(parser.Parser): "end_headers", "last_type", "last_stream", - "last_end_headers" + "last_end_headers", ) - def __init__( - self, - owner, - store = False, - file_limit = http.FILE_LIMIT - ): + def __init__(self, owner, store=False, file_limit=http.FILE_LIMIT): parser.Parser.__init__(self, owner) self.build() - self.reset( - store = store, - file_limit = file_limit - ) + self.reset(store=store, file_limit=file_limit) def build(self): """ @@ -211,10 +195,7 @@ def build(self): self.connection = self.owner - self.states = ( - self._parse_header, - self._parse_payload - ) + self.states = (self._parse_header, self._parse_payload) self.state_l = len(self.states) self.parsers = ( @@ -227,7 +208,7 @@ def build(self): self._parse_ping, self._parse_goaway, self._parse_window_update, - self._parse_continuation + self._parse_continuation, ) self.streams = {} @@ -248,7 +229,8 @@ def destroy(self): # them as the parser is now going to be destroyed and they cannot # be reached any longer (invalidated state) streams = netius.legacy.values(self.streams) - for stream in streams: stream.close() + for stream in streams: + stream.close() self.connection = None self.states = () @@ -261,9 +243,7 @@ def destroy(self): def info_dict(self): info = parser.Parser.info_dict(self) - info.update( - streams = self.info_streams() - ) + info.update(streams=self.info_streams()) return info def info_streams(self): @@ -276,11 +256,7 @@ def info_streams(self): info.append(item) return info - def reset( - self, - store = False, - file_limit = http.FILE_LIMIT - ): + def reset(self, store=False, file_limit=http.FILE_LIMIT): self.store = store self.file_limit = file_limit self.state = HEADER_STATE @@ -297,16 +273,15 @@ def reset( self.last_stream = 0 self.last_end_headers = False - def clear(self, force = False, save = True): - if not force and self.state == HEADER_STATE: return + def clear(self, force=False, save=True): + if not force and self.state == HEADER_STATE: + return type = self.type stream = self.stream end_headers = self.end_headers - self.reset( - store = self.store, - file_limit = self.file_limit - ) - if not save: return + self.reset(store=self.store, file_limit=self.file_limit) + if not save: + return self.last_type = type self.last_stream = stream self.last_end_headers = end_headers @@ -333,7 +308,8 @@ def parse(self, data): # in case the current state of the parser is finished, must # reset the state to the start position as the parser is # re-starting (probably a new data sequence) - if self.state == FINISH_STATE: self.clear() + if self.state == FINISH_STATE: + self.clear() # retrieves the size of the data that has been sent for parsing # and saves it under the size original variable @@ -347,8 +323,10 @@ def parse(self, data): if self.state <= self.state_l: method = self.states[self.state - 1] count = method(data) - if count == -1: break - if count == 0: continue + if count == -1: + break + if count == 0: + continue size -= count data = data[count:] @@ -366,7 +344,8 @@ def parse(self, data): # in case not all of the data has been processed # must add it to the buffer so that it may be used # latter in the next parsing of the message - if size > 0: self.buffer.append(data) + if size > 0: + self.buffer.append(data) # returns the number of read (processed) bytes of the # data that has been sent to the parser @@ -396,188 +375,174 @@ def assert_header(self): if self.length > self.owner.settings[SETTINGS_MAX_FRAME_SIZE]: raise netius.ParserError( "Headers are greater than SETTINGS_MAX_FRAME_SIZE", - stream = self.stream, - error_code = FRAME_SIZE_ERROR + stream=self.stream, + error_code=FRAME_SIZE_ERROR, ) - if self.last_type in (HEADERS, CONTINUATION) and not\ - self.last_end_headers and not self.last_stream == self.stream: + if ( + self.last_type in (HEADERS, CONTINUATION) + and not self.last_end_headers + and not self.last_stream == self.stream + ): raise netius.ParserError( "Cannot send frame from a different stream in middle of headers", - error_code = PROTOCOL_ERROR + error_code=PROTOCOL_ERROR, ) def assert_stream(self, stream): if not stream.identifier % 2 == 1: raise netius.ParserError( - "Stream identifiers must be odd", - error_code = PROTOCOL_ERROR + "Stream identifiers must be odd", error_code=PROTOCOL_ERROR ) if stream.dependency == stream.identifier: raise netius.ParserError( - "Stream cannot depend on itself", - error_code = PROTOCOL_ERROR + "Stream cannot depend on itself", error_code=PROTOCOL_ERROR ) if len(self.streams) >= self.owner.settings[SETTINGS_MAX_CONCURRENT_STREAMS]: raise netius.ParserError( "Too many streams (greater than SETTINGS_MAX_CONCURRENT_STREAMS)", - stream = self.stream, - error_code = PROTOCOL_ERROR + stream=self.stream, + error_code=PROTOCOL_ERROR, ) def assert_data(self, stream, end_stream): if self.stream == 0x00: raise netius.ParserError( - "Stream cannot be set to 0x00 for DATA", - error_code = PROTOCOL_ERROR + "Stream cannot be set to 0x00 for DATA", error_code=PROTOCOL_ERROR ) if not stream.end_headers: raise netius.ParserError( "Not ready to receive DATA open", - stream = self.stream, - error_code = PROTOCOL_ERROR + stream=self.stream, + error_code=PROTOCOL_ERROR, ) if stream.end_stream and stream.end_headers: raise netius.ParserError( "Not ready to receive DATA half closed (remote)", - stream = self.stream, - error_code = STREAM_CLOSED + stream=self.stream, + error_code=STREAM_CLOSED, ) def assert_headers(self, stream, end_stream): if stream.end_stream and stream.end_headers: raise netius.ParserError( "Not ready to receive HEADERS half closed (remote)", - stream = self.stream, - error_code = STREAM_CLOSED + stream=self.stream, + error_code=STREAM_CLOSED, ) if not end_stream: raise netius.ParserError( "Second HEADERS without END_STREAM flag", - stream = self.stream, - error_code = PROTOCOL_ERROR + stream=self.stream, + error_code=PROTOCOL_ERROR, ) def assert_priority(self, stream, dependency): if self.stream == 0x00: raise netius.ParserError( - "Stream cannot be set to 0x00 for PRIORITY", - error_code = PROTOCOL_ERROR + "Stream cannot be set to 0x00 for PRIORITY", error_code=PROTOCOL_ERROR ) if dependency == self.stream: raise netius.ParserError( - "Stream cannot depend on current stream", - error_code = PROTOCOL_ERROR + "Stream cannot depend on current stream", error_code=PROTOCOL_ERROR ) if stream and dependency == stream.identifier: raise netius.ParserError( - "Stream cannot depend on itself", - error_code = PROTOCOL_ERROR + "Stream cannot depend on itself", error_code=PROTOCOL_ERROR ) def assert_rst_stream(self, stream): if self.stream == 0x00: raise netius.ParserError( - "Stream cannot be set to 0x00 for RST_STREAM", - error_code = PROTOCOL_ERROR + "Stream cannot be set to 0x00 for RST_STREAM", error_code=PROTOCOL_ERROR ) if self.stream > self._max_stream: raise netius.ParserError( - "Stream has not been created for RST_STREAM", - error_code = PROTOCOL_ERROR + "Stream has not been created for RST_STREAM", error_code=PROTOCOL_ERROR ) - def assert_settings(self, settings, ack, extended = True): + def assert_settings(self, settings, ack, extended=True): if not self.stream == 0x00: raise netius.ParserError( - "Stream must be set to 0x00 for SETTINGS", - error_code = PROTOCOL_ERROR + "Stream must be set to 0x00 for SETTINGS", error_code=PROTOCOL_ERROR ) if ack and not self.length == 0: raise netius.ParserError( - "SETTINGS with ACK must be zero length", - error_code = FRAME_SIZE_ERROR + "SETTINGS with ACK must be zero length", error_code=FRAME_SIZE_ERROR ) if not self.length % 6 == 0: raise netius.ParserError( "Size of SETTINGS frame must be a multiple of 6", - error_code = FRAME_SIZE_ERROR + error_code=FRAME_SIZE_ERROR, ) - if not extended: return + if not extended: + return settings = dict(settings) if not settings.get(SETTINGS_ENABLE_PUSH, 0) in (0, 1): raise netius.ParserError( "Value of SETTINGS_ENABLE_PUSH different from 0 or 1", - error_code = PROTOCOL_ERROR + error_code=PROTOCOL_ERROR, ) if settings.get(SETTINGS_INITIAL_WINDOW_SIZE, 0) > 2147483647: raise netius.ParserError( "Value of SETTINGS_INITIAL_WINDOW_SIZE too large", - error_code = FLOW_CONTROL_ERROR + error_code=FLOW_CONTROL_ERROR, ) if settings.get(SETTINGS_MAX_FRAME_SIZE, 16384) < 16384: raise netius.ParserError( - "Value of SETTINGS_MAX_FRAME_SIZE too small", - error_code = PROTOCOL_ERROR + "Value of SETTINGS_MAX_FRAME_SIZE too small", error_code=PROTOCOL_ERROR ) if settings.get(SETTINGS_MAX_FRAME_SIZE, 16384) > 16777215: raise netius.ParserError( - "Value of SETTINGS_MAX_FRAME_SIZE too large", - error_code = PROTOCOL_ERROR + "Value of SETTINGS_MAX_FRAME_SIZE too large", error_code=PROTOCOL_ERROR ) def assert_push_promise(self, promised_stream): raise netius.ParserError( - "PUSH_PROMISE not allowed for server", - error_code = PROTOCOL_ERROR + "PUSH_PROMISE not allowed for server", error_code=PROTOCOL_ERROR ) def assert_ping(self): if not self.stream == 0x00: raise netius.ParserError( - "Stream must be set to 0x00 for PING", - error_code = PROTOCOL_ERROR + "Stream must be set to 0x00 for PING", error_code=PROTOCOL_ERROR ) if not self.length == 8: raise netius.ParserError( - "Size of PING frame must be 8", - error_code = FRAME_SIZE_ERROR + "Size of PING frame must be 8", error_code=FRAME_SIZE_ERROR ) def assert_goaway(self): if not self.stream == 0x00: raise netius.ParserError( - "Stream must be set to 0x00 for GOAWAY", - error_code = PROTOCOL_ERROR + "Stream must be set to 0x00 for GOAWAY", error_code=PROTOCOL_ERROR ) def assert_window_update(self, stream, increment): if increment == 0: raise netius.ParserError( - "WINDOW_UPDATE increment must not be zero", - error_code = PROTOCOL_ERROR + "WINDOW_UPDATE increment must not be zero", error_code=PROTOCOL_ERROR ) if self.owner.window + increment > 2147483647: raise netius.ParserError( "Window value for the connection too large", - error_code = FLOW_CONTROL_ERROR + error_code=FLOW_CONTROL_ERROR, ) if stream and stream.window + increment > 2147483647: raise netius.ParserError( - "Window value for the stream too large", - error_code = FLOW_CONTROL_ERROR + "Window value for the stream too large", error_code=FLOW_CONTROL_ERROR ) def assert_continuation(self, stream): if stream.end_stream and stream.end_headers: raise netius.ParserError( "Not ready to receive CONTINUATION half closed (remote)", - stream = self.stream, - error_code = PROTOCOL_ERROR + stream=self.stream, + error_code=PROTOCOL_ERROR, ) if not self.last_type in (HEADERS, PUSH_PROMISE, CONTINUATION): raise netius.ParserError( "CONTINUATION without HEADERS, PUSH_PROMISE or CONTINUATION before", - error_code = PROTOCOL_ERROR + error_code=PROTOCOL_ERROR, ) @property @@ -585,7 +550,8 @@ def type_s(self): return self.get_type_s(self.type) def _parse_header(self, data): - if len(data) + self.buffer_size < HEADER_SIZE: return -1 + if len(data) + self.buffer_size < HEADER_SIZE: + return -1 size = HEADER_SIZE - self.buffer_size data = self.buffer_data + data[:size] @@ -602,13 +568,15 @@ def _parse_header(self, data): return size def _parse_payload(self, data): - if len(data) + self.buffer_size < self.length: return -1 + if len(data) + self.buffer_size < self.length: + return -1 size = self.length - self.buffer_size data = self.buffer_data + data[:size] valid_type = self.type < len(self.parsers) - if not valid_type: self._invalid_type() + if not valid_type: + self._invalid_type() self.payload = data self.trigger("on_payload") @@ -631,10 +599,10 @@ def _parse_data(self, data): padded_l = 0 if padded: - padded_l, = struct.unpack("!B", data[index:index + 1]) + (padded_l,) = struct.unpack("!B", data[index : index + 1]) index += 1 - contents = data[index:data_l - padded_l] + contents = data[index : data_l - padded_l] stream = self._get_stream(self.stream) self.assert_data(stream, end_stream) @@ -645,7 +613,8 @@ def _parse_data(self, data): self.trigger("on_data_h2", stream, contents) self.trigger("on_partial", contents) - if stream.is_ready: self.trigger("on_data") + if stream.is_ready: + self.trigger("on_data") def _parse_headers(self, data): data_l = len(data) @@ -662,18 +631,18 @@ def _parse_headers(self, data): exclusive = 0 if padded: - padded_l, = struct.unpack("!B", data[index:index + 1]) + (padded_l,) = struct.unpack("!B", data[index : index + 1]) index += 1 if priority: - dependency, weight = struct.unpack("!IB", data[index:index + 5]) + dependency, weight = struct.unpack("!IB", data[index : index + 5]) exclusive = True if dependency & 0x80000000 else False - dependency = dependency & 0x7fffffff + dependency = dependency & 0x7FFFFFFF index += 5 # retrieves the (headers) fragment part of the payload, this is # going to be used as the basis for the header decoding - fragment = data[index:data_l - padded_l] + fragment = data[index : data_l - padded_l] # retrieves the value of the window initial size from the owner # connection this is the value to be set in the new stream and @@ -685,35 +654,40 @@ def _parse_headers(self, data): # tries to retrieve a previously opened stream and, this may be # the case it has been opened by a previous frame operation - stream = self._get_stream(self.stream, strict = False, closed_s = True) + stream = self._get_stream(self.stream, strict=False, closed_s=True) if stream: # runs the headers assertion operation and then updated the # various elements in the currently opened stream accordingly self.assert_headers(stream, end_stream) stream.extend_headers(fragment) - if dependency: stream.dependency = dependency - if weight: stream.weight = weight - if exclusive: stream.exclusive = exclusive - if end_headers: stream.end_headers = end_headers - if end_stream: stream.end_stream = end_stream + if dependency: + stream.dependency = dependency + if weight: + stream.weight = weight + if exclusive: + stream.exclusive = exclusive + if end_headers: + stream.end_headers = end_headers + if end_stream: + stream.end_stream = end_stream else: # constructs the stream structure for the current stream that # is being open/created using the current owner, headers and # other information as the basis for such construction stream = HTTP2Stream( - owner = self, - identifier = self.stream, - header_b = fragment, - dependency = dependency, - weight = weight, - exclusive = exclusive, - end_headers = end_headers, - end_stream = end_stream, - store = self.store, - file_limit = self.file_limit, - window = window, - frame_size = frame_size + owner=self, + identifier=self.stream, + header_b=fragment, + dependency=dependency, + weight=weight, + exclusive=exclusive, + end_headers=end_headers, + end_stream=end_stream, + store=self.store, + file_limit=self.file_limit, + window=window, + frame_size=frame_size, ) # ensures that the stream object is properly open, this should @@ -735,13 +709,16 @@ def _parse_headers(self, data): self.trigger("on_headers_h2", stream) - if stream.end_headers: stream._calculate() - if stream.end_headers: self.trigger("on_headers") - if stream.is_ready: self.trigger("on_data") + if stream.end_headers: + stream._calculate() + if stream.end_headers: + self.trigger("on_headers") + if stream.is_ready: + self.trigger("on_data") def _parse_priority(self, data): dependency, weight = struct.unpack("!IB", data) - stream = self._get_stream(self.stream, strict = False) + stream = self._get_stream(self.stream, strict=False) if stream: stream.dependency = dependency stream.weight = weight @@ -749,8 +726,8 @@ def _parse_priority(self, data): self.trigger("on_priority", stream, dependency, weight) def _parse_rst_stream(self, data): - error_code, = struct.unpack("!I", data) - stream = self._get_stream(self.stream, strict = False) + (error_code,) = struct.unpack("!I", data) + stream = self._get_stream(self.stream, strict=False) self.assert_rst_stream(stream) self.trigger("on_rst_stream", stream, error_code) @@ -762,7 +739,7 @@ def _parse_settings(self, data): for index in netius.legacy.xrange(count): base = index * SETTING_SIZE - part = data[base:base + SETTING_SIZE] + part = data[base : base + SETTING_SIZE] setting = struct.unpack("!HI", part) settings.append(setting) @@ -780,12 +757,12 @@ def _parse_push_promise(self, data): padded_l = 0 if padded: - padded_l, = struct.unpack("!B", data[index:index + 1]) + (padded_l,) = struct.unpack("!B", data[index : index + 1]) index += 1 - promised_stream, = struct.unpack("!I", data[index:index + 4]) + (promised_stream,) = struct.unpack("!I", data[index : index + 4]) - fragment = data[index:data_l - padded_l] + fragment = data[index : data_l - padded_l] self.assert_push_promise(promised_stream) @@ -803,14 +780,11 @@ def _parse_goaway(self, data): self.trigger("on_goaway", last_stream, error_code, extra) def _parse_window_update(self, data): - increment, = struct.unpack("!I", data) - stream = self._get_stream( - self.stream, - strict = False, - unopened_s = True - ) + (increment,) = struct.unpack("!I", data) + stream = self._get_stream(self.stream, strict=False, unopened_s=True) self.assert_window_update(stream, increment) - if self.stream and not stream: return + if self.stream and not stream: + return self.trigger("on_window_update", stream, increment) def _parse_continuation(self, data): @@ -827,8 +801,10 @@ def _parse_continuation(self, data): self.trigger("on_continuation", stream) - if stream.end_headers: stream._calculate() - if stream.end_headers: self.trigger("on_headers") + if stream.end_headers: + stream._calculate() + if stream.end_headers: + self.trigger("on_headers") if stream.end_headers and stream.end_stream: self.trigger("on_data") @@ -837,34 +813,39 @@ def _has_stream(self, stream): def _get_stream( self, - stream = None, - default = None, - strict = True, - closed_s = False, - unopened_s = False, - exists_s = False + stream=None, + default=None, + strict=True, + closed_s=False, + unopened_s=False, + exists_s=False, ): - if stream == None: stream = self.stream - if stream == 0: return default - if strict: closed_s = True; unopened_s = True; exists_s = True + if stream == None: + stream = self.stream + if stream == 0: + return default + if strict: + closed_s = True + unopened_s = True + exists_s = True exists = stream in self.streams if closed_s and not exists and stream <= self._max_stream: raise netius.ParserError( "Invalid or closed stream '%d'" % stream, - stream = self.stream, - error_code = STREAM_CLOSED + stream=self.stream, + error_code=STREAM_CLOSED, ) if unopened_s and not exists and stream > self._max_stream: raise netius.ParserError( "Invalid or unopened stream '%d'" % stream, - stream = self.stream, - error_code = PROTOCOL_ERROR + stream=self.stream, + error_code=PROTOCOL_ERROR, ) if exists_s and not exists: raise netius.ParserError( "Invalid stream '%d'" % stream, - stream = self.stream, - error_code = PROTOCOL_ERROR + stream=self.stream, + error_code=PROTOCOL_ERROR, ) self.stream_o = self.streams.get(stream, default) return self.stream_o @@ -875,39 +856,47 @@ def _set_stream(self, stream): self._max_stream = max(self._max_stream, stream.identifier) def _del_stream(self, stream): - if not stream in self.streams: return + if not stream in self.streams: + return del self.streams[stream] self.stream_o = None def _invalid_type(self): ignore = False if self.last_type == HEADERS else True - if ignore: raise netius.ParserError("Invalid frame type", ignore = True) - raise netius.ParserError("Invalid frame type", error_code = PROTOCOL_ERROR) + if ignore: + raise netius.ParserError("Invalid frame type", ignore=True) + raise netius.ParserError("Invalid frame type", error_code=PROTOCOL_ERROR) @property def buffer_size(self): return sum(len(data) for data in self.buffer) @property - def buffer_data(self, empty = True): + def buffer_data(self, empty=True): data = b"".join(self.buffer) - if empty: del self.buffer[:] + if empty: + del self.buffer[:] return data @property def encoder(self): - if self._encoder: return self._encoder + if self._encoder: + return self._encoder import hpack + self._encoder = hpack.hpack.Encoder() return self._encoder @property def decoder(self): - if self._decoder: return self._decoder + if self._decoder: + return self._decoder import hpack + self._decoder = hpack.hpack.Decoder() return self._decoder + class HTTP2Stream(netius.Stream): """ Object representing a stream of data interchanged between two @@ -924,18 +913,18 @@ class HTTP2Stream(netius.Stream): def __init__( self, - identifier = None, - header_b = None, - dependency = 0x00, - weight = 1, - exclusive = False, - end_headers = False, - end_stream = False, - end_stream_l = False, - store = False, - file_limit = http.FILE_LIMIT, - window = HTTP2_WINDOW, - frame_size = HTTP2_FRAME_SIZE, + identifier=None, + header_b=None, + dependency=0x00, + weight=1, + exclusive=False, + end_headers=False, + end_stream=False, + end_stream_l=False, + store=False, + file_limit=http.FILE_LIMIT, + window=HTTP2_WINDOW, + frame_size=HTTP2_FRAME_SIZE, *args, **kwargs ): @@ -949,10 +938,7 @@ def __init__( self.end_stream = end_stream self.end_stream_l = end_stream_l self.reset( - store = store, - file_limit = file_limit, - window = window, - frame_size = frame_size + store=store, file_limit=file_limit, window=window, frame_size=frame_size ) def __getattr__(self, name): @@ -962,10 +948,10 @@ def __getattr__(self, name): def reset( self, - store = False, - file_limit = http.FILE_LIMIT, - window = HTTP2_WINDOW, - frame_size = HTTP2_FRAME_SIZE + store=False, + file_limit=http.FILE_LIMIT, + window=HTTP2_WINDOW, + frame_size=HTTP2_FRAME_SIZE, ): netius.Stream.reset(self) self.store = store @@ -994,7 +980,8 @@ def reset( def open(self): # check if the current stream is currently in (already) in # the open state and if that's the case returns immediately - if self.status == netius.OPEN: return + if self.status == netius.OPEN: + return # calls the parent open operation for upper operations, this # should take care of some callback calling @@ -1005,10 +992,11 @@ def open(self): # data is not currently available (continuation frames pending) self.decode_headers() - def close(self, flush = False, destroy = True, reset = True): + def close(self, flush=False, destroy=True, reset=True): # verifies if the current stream is already closed and # if that's the case returns immediately, avoiding duplicate - if self.status == netius.CLOSED: return + if self.status == netius.CLOSED: + return # in case the reset flag is set sends the final, tries to determine # the way of resetting the stream, in case the flush flag is set @@ -1018,8 +1006,10 @@ def close(self, flush = False, destroy = True, reset = True): # graceful approach is requested) the reset operation is performed if reset: graceful = flush and self.is_ready - if graceful: self.send_part(b"") - else: self.send_reset() + if graceful: + self.send_part(b"") + else: + self.send_reset() # calls the parent close method so that the upper layer # instructions are correctly processed/handled @@ -1028,45 +1018,46 @@ def close(self, flush = False, destroy = True, reset = True): # verifies if a stream structure exists in the parser for # the provided identifier and if that's not the case returns # immediately otherwise removes it from the parent - if not self.owner._has_stream(self.identifier): return + if not self.owner._has_stream(self.identifier): + return self.owner._del_stream(self.identifier) # runs the reset operation in the stream clearing all of its # internal structures may avoid some memory leaks self.reset() - def info_dict(self, full = False): - info = netius.Stream.info_dict(self, full = full) + def info_dict(self, full=False): + info = netius.Stream.info_dict(self, full=full) info.update( - identifier = self.identifier, - dependency = self.dependency, - weight = self.weight, - exclusive = self.exclusive, - end_headers = self.end_headers, - end_stream = self.end_stream, - end_stream_l = self.end_stream_l, - store = self.store, - file_limit = self.file_limit, - window = self.window, - window_m = self.window_m, - window_o = self.window_o, - window_l = self.window_l, - window_t = self.window_t, - pending_s = self.pending_s, - headers = self.headers, - method = self.method, - path_s = self.path_s, - version = self.version, - version_s = self.version_s, - encodings = self.encodings, - chunked = self.chunked, - keep_alive = self.keep_alive, - content_l = self.content_l, - frames = self.frames, - available = self.connection.available_stream(self.identifier, 1), - exhausted = self.is_exhausted(), - restored = self.is_restored(), - _available = self._available + identifier=self.identifier, + dependency=self.dependency, + weight=self.weight, + exclusive=self.exclusive, + end_headers=self.end_headers, + end_stream=self.end_stream, + end_stream_l=self.end_stream_l, + store=self.store, + file_limit=self.file_limit, + window=self.window, + window_m=self.window_m, + window_o=self.window_o, + window_l=self.window_l, + window_t=self.window_t, + pending_s=self.pending_s, + headers=self.headers, + method=self.method, + path_s=self.path_s, + version=self.version, + version_s=self.version_s, + encodings=self.encodings, + chunked=self.chunked, + keep_alive=self.keep_alive, + content_l=self.content_l, + frames=self.frames, + available=self.connection.available_stream(self.identifier, 1), + exhausted=self.is_exhausted(), + restored=self.is_restored(), + _available=self._available, ) return info @@ -1096,7 +1087,8 @@ def set_encoding(self, encoding): def set_uncompressed(self): if self.current >= http.CHUNKED_ENCODING: self.current = http.CHUNKED_ENCODING - else: self.current = http.PLAIN_ENCODING + else: + self.current = http.PLAIN_ENCODING def set_plain(self): self.set_encoding(http.PLAIN_ENCODING) @@ -1131,29 +1123,38 @@ def is_uncompressed(self): def is_flushed(self): return self.current > http.PLAIN_ENCODING - def is_measurable(self, strict = True): - if self.is_compressed(): return False + def is_measurable(self, strict=True): + if self.is_compressed(): + return False return True def is_exhausted(self): - if self.pending_s > self.connection.max_pending: return True - if not self._available: return True + if self.pending_s > self.connection.max_pending: + return True + if not self._available: + return True return False def is_restored(self): - if self.pending_s > self.connection.min_pending: return False - if not self._available: return False + if self.pending_s > self.connection.min_pending: + return False + if not self._available: + return False return True - def decode_headers(self, force = False, assert_h = True): - if not self.end_headers and not force: return - if self.headers_l and not force: return - if not self.header_b: return + def decode_headers(self, force=False, assert_h=True): + if not self.end_headers and not force: + return + if self.headers_l and not force: + return + if not self.header_b: + return is_joinable = len(self.header_b) > 1 block = b"".join(self.header_b) if is_joinable else self.header_b[0] self.headers_l = self.owner.decoder.decode(block) self.header_b = [] - if assert_h: self.assert_headers() + if assert_h: + self.assert_headers() def extend_headers(self, fragment): """ @@ -1183,7 +1184,8 @@ def extend_data(self, data): """ self._data_l += len(data) - if not self.store: return + if not self.store: + return self._data_b.write(data) def remote_update(self, increment): @@ -1218,14 +1220,14 @@ def local_update(self, increment): """ self.window_l += increment - if self.window_l >= self.window_t: return + if self.window_l >= self.window_t: + return self.connection.send_window_update( - increment = self.window_o - self.window_l, - stream = self.identifier + increment=self.window_o - self.window_l, stream=self.identifier ) self.window_l = self.window_o - def get_path(self, normalize = False): + def get_path(self, normalize=False): """ Retrieves the path associated with the request, this value should be interpreted from the HTTP status line. @@ -1243,8 +1245,10 @@ def get_path(self, normalize = False): split = self.path_s.split("?", 1) path = split[0] - if not normalize: return path - if not path.startswith(("http://", "https://")): return path + if not normalize: + return path + if not path.startswith(("http://", "https://")): + return path return netius.legacy.urlparse(path).path def get_query(self): @@ -1261,10 +1265,12 @@ def get_query(self): """ split = self.path_s.split("?", 1) - if len(split) == 1: return "" - else: return split[1] + if len(split) == 1: + return "" + else: + return split[1] - def get_message_b(self, copy = False, size = 40960): + def get_message_b(self, copy=False, size=40960): """ Retrieves a new buffer associated with the currently loaded message. @@ -1295,21 +1301,25 @@ def get_message_b(self, copy = False, size = 40960): # restores the message file to the original/initial position and # then in case there's no copy required returns it immediately self._data_b.seek(0) - if not copy: return self._data_b + if not copy: + return self._data_b # determines if the file limit for a temporary file has been # surpassed and if that's the case creates a named temporary # file, otherwise created a memory based buffer use_file = self.store and self.content_l >= self.file_limit - if use_file: message_f = tempfile.NamedTemporaryFile(mode = "w+b") - else: message_f = netius.legacy.BytesIO() + if use_file: + message_f = tempfile.NamedTemporaryFile(mode="w+b") + else: + message_f = netius.legacy.BytesIO() try: # iterates continuously reading the contents from the message # file and writing them back to the output (copy) file while True: data = self._data_b.read(size) - if not data: break + if not data: + break message_f.write(data) finally: # resets both of the message file (output and input) to the @@ -1322,57 +1332,62 @@ def get_message_b(self, copy = False, size = 40960): return message_f def get_encodings(self): - if not self.encodings == None: return self.encodings + if not self.encodings == None: + return self.encodings accept_encoding_s = self.headers.get("accept-encoding", "") self.encodings = [value.strip() for value in accept_encoding_s.split(",")] return self.encodings def fragment(self, data): - reference = min( - self.connection.window, - self.window, - self.window_m - ) + reference = min(self.connection.window, self.window, self.window_m) yield data[:reference] data = data[reference:] while data: - yield data[:self.window_m] - data = data[self.window_m:] + yield data[: self.window_m] + data = data[self.window_m :] def fragmentable(self, data): - if not data: return False - if self.window_m == 0: return False - if len(data) <= self.window_m and\ - len(data) <= self.window: return False + if not data: + return False + if self.window_m == 0: + return False + if len(data) <= self.window_m and len(data) <= self.window: + return False return True def flush(self, *args, **kwargs): - if not self.is_open(): return 0 + if not self.is_open(): + return 0 with self.ctx_request(args, kwargs): return self.connection.flush(*args, **kwargs) def flush_s(self, *args, **kwargs): - if not self.is_open(): return 0 + if not self.is_open(): + return 0 with self.ctx_request(args, kwargs): return self.connection.flush_s(*args, **kwargs) def send_response(self, *args, **kwargs): - if not self.is_open(): return 0 + if not self.is_open(): + return 0 with self.ctx_request(args, kwargs): return self.connection.send_response(*args, **kwargs) def send_header(self, *args, **kwargs): - if not self.is_open(): return 0 + if not self.is_open(): + return 0 with self.ctx_request(args, kwargs): return self.connection.send_header(*args, **kwargs) def send_part(self, *args, **kwargs): - if not self.is_open(): return 0 + if not self.is_open(): + return 0 with self.ctx_request(args, kwargs): return self.connection.send_part(*args, **kwargs) def send_reset(self, *args, **kwargs): - if not self.is_open(): return 0 + if not self.is_open(): + return 0 with self.ctx_request(args, kwargs): return self.connection.send_rst_stream(*args, **kwargs) @@ -1381,73 +1396,79 @@ def assert_headers(self): pseudos = dict() for name, value in self.headers_l: is_pseudo = name.startswith(":") - if not is_pseudo: pseudo = False + if not is_pseudo: + pseudo = False if not name.lower() == name: raise netius.ParserError( "Headers must be lower cased", - stream = self.identifier, - error_code = PROTOCOL_ERROR + stream=self.identifier, + error_code=PROTOCOL_ERROR, ) if name in (":status",): raise netius.ParserError( "Response pseudo-header present", - stream = self.identifier, - error_code = PROTOCOL_ERROR + stream=self.identifier, + error_code=PROTOCOL_ERROR, ) if name in ("connection",): raise netius.ParserError( "Invalid header present", - stream = self.identifier, - error_code = PROTOCOL_ERROR + stream=self.identifier, + error_code=PROTOCOL_ERROR, ) if name == "te" and not value == "trailers": raise netius.ParserError( "Invalid value for TE header", - stream = self.identifier, - error_code = PROTOCOL_ERROR + stream=self.identifier, + error_code=PROTOCOL_ERROR, ) if is_pseudo and name in pseudos: raise netius.ParserError( "Duplicated pseudo-header value", - stream = self.identifier, - error_code = PROTOCOL_ERROR + stream=self.identifier, + error_code=PROTOCOL_ERROR, ) if pseudo and not name in HTTP2_PSEUDO: raise netius.ParserError( "Invalid pseudo-header", - stream = self.identifier, - error_code = PROTOCOL_ERROR + stream=self.identifier, + error_code=PROTOCOL_ERROR, ) if not pseudo and is_pseudo: raise netius.ParserError( "Pseudo-header positioned after normal header", - stream = self.identifier, - error_code = PROTOCOL_ERROR + stream=self.identifier, + error_code=PROTOCOL_ERROR, ) - if is_pseudo: pseudos[name] = True + if is_pseudo: + pseudos[name] = True for name in (":method", ":scheme", ":path"): if not name in pseudos: raise netius.ParserError( "Missing pseudo-header in request", - stream = self.identifier, - error_code = PROTOCOL_ERROR + stream=self.identifier, + error_code=PROTOCOL_ERROR, ) def assert_ready(self): - if not self.content_l == -1 and not self._data_l == 0 and\ - not self._data_l == self.content_l: + if ( + not self.content_l == -1 + and not self._data_l == 0 + and not self._data_l == self.content_l + ): raise netius.ParserError( "Invalid content-length header value (missmatch)", - stream = self.identifier, - error_code = PROTOCOL_ERROR + stream=self.identifier, + error_code=PROTOCOL_ERROR, ) @contextlib.contextmanager - def ctx_request(self, args = None, kwargs = None): + def ctx_request(self, args=None, kwargs=None): # in case there's no valid set of keyword arguments # a valid and empty one must be created (avoids error) - if kwargs == None: kwargs = dict() + if kwargs == None: + kwargs = dict() # sets the stream keyword argument with the current # stream's identifier (provides identification support) @@ -1457,7 +1478,8 @@ def ctx_request(self, args = None, kwargs = None): # and in case it exits uses it to create a new one that # calls this one at the end (connection to stream clojure) callback = kwargs.get("callback", None) - if callback: kwargs["callback"] = self._build_c(callback) + if callback: + kwargs["callback"] = self._build_c(callback) # retrieves the references to the "original" # values of the current and stream objects @@ -1485,7 +1507,7 @@ def parser(self): return self @property - def is_ready(self, calculate = True, assert_r = True): + def is_ready(self, calculate=True, assert_r=True): """ Determines if the stream is ready, meaning that the complete set of headers and data have been passed to peer and the request @@ -1502,11 +1524,16 @@ def is_ready(self, calculate = True, assert_r = True): :return: The final value on the is ready (for processing). """ - if not self.is_open(): return False - if calculate: self._calculate() - if not self.end_headers: return False - if not self.end_stream: return False - if assert_r: self.assert_ready() + if not self.is_open(): + return False + if calculate: + self._calculate() + if not self.end_headers: + return False + if not self.end_stream: + return False + if assert_r: + self.assert_ready() return True @property @@ -1514,9 +1541,12 @@ def is_headers(self): return self.end_headers def _calculate(self): - if not self._data_b == None: return - if not self._data_l == -1: return - if not self.is_headers: return + if not self._data_b == None: + return + if not self._data_l == -1: + return + if not self.is_headers: + return self._calculate_headers() self.content_l = self.headers.get("content-length", -1) self.content_l = self.content_l and int(self.content_l) @@ -1532,27 +1562,35 @@ def _calculate_headers(self): for header in self.headers_l: key, value = header - if not type(key) == str: key = str(key) - if not type(value) == str: value = str(value) + if not type(key) == str: + key = str(key) + if not type(value) == str: + value = str(value) is_special = key.startswith(":") exists = key in headers_m if exists: sequence = headers_m[key] is_list = type(sequence) == list - if not is_list: sequence = [sequence] + if not is_list: + sequence = [sequence] sequence.append(value) value = sequence - if is_special: headers_s[key] = value - else: headers_m[key] = value + if is_special: + headers_s[key] = value + else: + headers_m[key] = value host = headers_s.get(":authority", None) - if host: headers_m["host"] = host + if host: + headers_m["host"] = host self.headers = headers_m self.method = headers_s.get(":method", None) self.path_s = headers_s.get(":path", None) - if self.method: self.method = str(self.method) - if self.path_s: self.path_s = str(self.path_s) + if self.method: + self.method = str(self.method) + if self.path_s: + self.path_s = str(self.path_s) def _build_b(self): """ @@ -1569,10 +1607,12 @@ def _build_b(self): """ use_file = self.store and self.content_l >= self.file_limit - if use_file: return tempfile.NamedTemporaryFile(mode = "w+b") - else: return netius.legacy.BytesIO() + if use_file: + return tempfile.NamedTemporaryFile(mode="w+b") + else: + return netius.legacy.BytesIO() - def _build_c(self, callback, validate = True): + def _build_c(self, callback, validate=True): """ Builds the final callback function to be used with a clojure around the current stream for proper validation and passing @@ -1591,7 +1631,8 @@ def _build_c(self, callback, validate = True): """ def inner(connection): - if validate and not self.is_open(): return + if validate and not self.is_open(): + return callback(self) return inner @@ -1599,7 +1640,7 @@ def inner(connection): def _parse_query(self, query): # runs the "default" parsing of the query string from the system # and then decodes the complete set of parameters properly - params = netius.legacy.parse_qs(query, keep_blank_values = True) + params = netius.legacy.parse_qs(query, keep_blank_values=True) return self._decode_params(params) def _decode_params(self, params): @@ -1609,10 +1650,12 @@ def _decode_params(self, params): items = [] for item in value: is_bytes = netius.legacy.is_bytes(item) - if is_bytes: item = item.decode("utf-8") + if is_bytes: + item = item.decode("utf-8") items.append(item) is_bytes = netius.legacy.is_bytes(key) - if is_bytes: key = key.decode("utf-8") + if is_bytes: + key = key.decode("utf-8") _params[key] = items return _params diff --git a/src/netius/common/mime.py b/src/netius/common/mime.py index 33b0bdf00..86b0a51e4 100644 --- a/src/netius/common/mime.py +++ b/src/netius/common/mime.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -64,7 +55,7 @@ (".log", "text/plain"), (".mka", "audio/x-matroska"), (".mkv", "video/x-matroska"), - (".woff", "application/font-woff") + (".woff", "application/font-woff"), ) """ The sequence containing tuple associating the extension with the mime type or content type string """ @@ -74,6 +65,7 @@ been performed, avoiding possible duplicated registration that would spend unnecessary resources """ + class Headers(list): """ Mutable structure that allow the access to header tuples @@ -86,9 +78,11 @@ class Headers(list): def __getitem__(self, key): is_integer = isinstance(key, int) - if is_integer: return list.__getitem__(self, key) + if is_integer: + return list.__getitem__(self, key) for _key, value in self: - if not _key == key: continue + if not _key == key: + continue return value raise KeyError("not found") @@ -96,56 +90,68 @@ def __setitem__(self, key, value): key = self._normalize(key) value = self._normalize(value) is_integer = isinstance(key, int) - if is_integer: return list.__setitem__(self, key, value) + if is_integer: + return list.__setitem__(self, key, value) self.append([key, value]) def __delitem__(self, key): is_integer = isinstance(key, int) - if is_integer: return list.__delitem__(self, key) + if is_integer: + return list.__delitem__(self, key) value = self.__getitem__(key) self.remove([key, value]) def __contains__(self, item): is_string = isinstance(item, netius.legacy.ALL_STRINGS) - if not is_string: return list.__contains__(self, item) + if not is_string: + return list.__contains__(self, item) for key, _value in self: - if not key == item: continue + if not key == item: + continue return True return False def item(self, key): for item in self: - if not item[0] == key: continue + if not item[0] == key: + continue return item raise KeyError("not found") - def get(self, key, default = None): - if not key in self: return default + def get(self, key, default=None): + if not key in self: + return default return self[key] - def set(self, key, value, append = False): + def set(self, key, value, append=False): key = self._normalize(key) value = self._normalize(value) - if key in self and not append: self.item(key)[1] = value - else: self[key] = value + if key in self and not append: + self.item(key)[1] = value + else: + self[key] = value - def pop(self, key, default = None): - if not key in self: return default + def pop(self, key, default=None): + if not key in self: + return default value = self[key] del self[key] return value - def join(self, separator = "\r\n"): + def join(self, separator="\r\n"): separator = netius.legacy.bytes(separator) return separator.join([key + b": " + value for key, value in self]) def _normalize(self, value): value_t = type(value) - if value_t == netius.legacy.BYTES: return value - if value_t == netius.legacy.UNICODE: return value.encode("utf-8") + if value_t == netius.legacy.BYTES: + return value + if value_t == netius.legacy.UNICODE: + return value.encode("utf-8") return netius.legacy.bytes(str(value)) -def rfc822_parse(message, strip = True): + +def rfc822_parse(message, strip=True): """ Parse a message in rfc822 format. This format is similar to the mime one with only some small changes. The returning value @@ -185,7 +191,8 @@ def rfc822_parse(message, strip = True): for line in lines: # in case an empty/invalid line has been reached the # end of headers have been found (must break the loop) - if not line: break + if not line: + break # retrieves the value for the current byte so that it's # possible to try to match it against the various regular @@ -211,17 +218,20 @@ def rfc822_parse(message, strip = True): # creating the proper header tuple adding it to the list of headers if match: name = match.group(1) - value = line[match.end(0):] - if strip: value = value.lstrip() + value = line[match.end(0) :] + if strip: + value = value.lstrip() headers.append([name, value]) # otherwise in case the line is a from line formatted # using an old fashion strategy tolerates it (compatibility) - elif line.startswith(b"From "): pass + elif line.startswith(b"From "): + pass # as a fallback raises a parser error as no parsing of header # was possible for the message (major problem) - else: raise netius.ParserError("Unexpected header value") + else: + raise netius.ParserError("Unexpected header value") # increments the current line index counter, as one more # line has been processed by the parser @@ -230,19 +240,23 @@ def rfc822_parse(message, strip = True): # joins the complete set of "remaining" body lines creating the string # representing the body, and uses it to create the headers and body # tuple that is going to be returned to the caller method - body_lines = lines[index + 1:] + body_lines = lines[index + 1 :] body = b"\r\n".join(body_lines) return (headers, body) + def rfc822_join(headers, body): headers_s = headers.join() return headers_s + b"\r\n\r\n" + body + def mime_register(): global MIME_REGISTERED - if MIME_REGISTERED: return + if MIME_REGISTERED: + return for extension, mime_type in MIME_TYPES: mimetypes.add_type(mime_type, extension) MIME_REGISTERED = True + mime_register() diff --git a/src/netius/common/parser.py b/src/netius/common/parser.py index 5e62861d0..c95df3cf5 100644 --- a/src/netius/common/parser.py +++ b/src/netius/common/parser.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -39,6 +30,7 @@ import netius + class Parser(netius.Observable): FIELDS = ("_pid",) diff --git a/src/netius/common/pop.py b/src/netius/common/pop.py index 7ed2f0a9b..2225a0f33 100644 --- a/src/netius/common/pop.py +++ b/src/netius/common/pop.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -41,9 +32,10 @@ from . import parser + class POPParser(parser.Parser): - def __init__(self, owner, store = False): + def __init__(self, owner, store=False): parser.Parser.__init__(self, owner) self.buffer = [] @@ -78,7 +70,8 @@ def parse(self, data): # zero the parsing iteration is broken method = self._parse_line count = method(data) - if count == 0: break + if count == 0: + break # decrements the size of the data buffer by the # size of the parsed bytes and then retrieves the @@ -89,7 +82,8 @@ def parse(self, data): # in case not all of the data has been processed # must add it to the buffer so that it may be used # latter in the next parsing of the message - if size > 0: self.buffer.append(data) + if size > 0: + self.buffer.append(data) # returns the number of read (processed) bytes of the # data that has been sent to the parser @@ -99,7 +93,8 @@ def _parse_line(self, data): # tries to find the new line character in the currently received # data in case it's not found returns immediately with no data processed index = data.find(b"\n") - if index == -1: return 0 + if index == -1: + return 0 # adds the partial data (until new line) to the current buffer and # then joins it retrieving the current line, then deletes the buffer @@ -113,7 +108,8 @@ def _parse_line(self, data): # the split is not successful (not enough information) then an extra # value is added to the sequence of values for compatibility values = self.line_s.split(" ", 1) - if not len(values) > 1: values.append(b"") + if not len(values) > 1: + values.append(b"") # unpacks the set of values that have just been parsed into the code # and the message items as expected by the pop specification diff --git a/src/netius/common/rsa.py b/src/netius/common/rsa.py index bea254ed1..3e3f866a3 100644 --- a/src/netius/common/rsa.py +++ b/src/netius/common/rsa.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -49,24 +40,31 @@ PRIVATE_TOKEN = "RSA PRIVATE KEY" PUBLIC_TOKEN = "PUBLIC KEY" -def open_pem_key(path, token = PRIVATE_TOKEN): + +def open_pem_key(path, token=PRIVATE_TOKEN): is_file = not isinstance(path, netius.legacy.STRINGS) - if is_file: file = path - else: file = open(path, "rb") + if is_file: + file = path + else: + file = open(path, "rb") try: data = file.read() finally: - if not is_file: file.close() - return open_pem_data(data, token = token) + if not is_file: + file.close() + return open_pem_data(data, token=token) + -def open_pem_data(data, token = PRIVATE_TOKEN): +def open_pem_data(data, token=PRIVATE_TOKEN): begin, end = pem_limiters(token) begin_index = data.find(begin) end_index = data.find(end) - if begin_index == -1: raise netius.ParserError("Invalid key format") - if end_index == -1: raise netius.ParserError("Invalid key format") + if begin_index == -1: + raise netius.ParserError("Invalid key format") + if end_index == -1: + raise netius.ParserError("Invalid key format") begin_index += len(begin) @@ -74,12 +72,8 @@ def open_pem_data(data, token = PRIVATE_TOKEN): data = data.strip() return base64.b64decode(data) -def write_pem_key( - path, - data, - token = PRIVATE_TOKEN, - width = 64 -): + +def write_pem_key(path, data, token=PRIVATE_TOKEN, width=64): begin, end = pem_limiters(token) data = base64.b64encode(data) @@ -97,123 +91,130 @@ def write_pem_key( file.write(end) file.write(b"\n") finally: - if not is_file: file.close() + if not is_file: + file.close() + def open_private_key(path): - data = open_pem_key( - path, - token = PRIVATE_TOKEN - ) + data = open_pem_key(path, token=PRIVATE_TOKEN) return open_private_key_data(data) + def open_private_key_b64(data_b64): data = base64.b64decode(data_b64) return open_private_key_data(data) + def open_private_key_data(data): asn1 = asn.asn1_parse(asn.ASN1_RSA_PRIVATE_KEY, data)[0] private_key = dict( - version = asn1[0], - modulus = asn1[1], - public_exponent = asn1[2], - private_exponent = asn1[3], - prime_1 = asn1[4], - prime_2 = asn1[5], - exponent_1 = asn1[6], - exponent_2 = asn1[7], - coefficient = asn1[8], - bits = rsa_bits(asn1[1]) + version=asn1[0], + modulus=asn1[1], + public_exponent=asn1[2], + private_exponent=asn1[3], + prime_1=asn1[4], + prime_2=asn1[5], + exponent_1=asn1[6], + exponent_2=asn1[7], + coefficient=asn1[8], + bits=rsa_bits(asn1[1]), ) return private_key + def open_public_key(path): - data = open_pem_key( - path, - token = PUBLIC_TOKEN - ) + data = open_pem_key(path, token=PUBLIC_TOKEN) return open_public_key_data(data) + def open_public_key_b64(data_b64): data = base64.b64decode(data_b64) return open_public_key_data(data) + def open_public_key_data(data): asn1 = asn.asn1_parse(asn.ASN1_OBJECT, data)[0] asn1 = asn.asn1_parse(asn.ASN1_RSA_PUBLIC_KEY, asn1[1][1:])[0] - public_key = dict( - modulus = asn1[0], - public_exponent = asn1[1], - bits = rsa_bits(asn1[0]) - ) + public_key = dict(modulus=asn1[0], public_exponent=asn1[1], bits=rsa_bits(asn1[0])) return public_key + def write_private_key(path, private_key): data = asn_private_key(private_key) - write_pem_key( - path, - data, - token = PRIVATE_TOKEN - ) + write_pem_key(path, data, token=PRIVATE_TOKEN) + def write_public_key(path, public_key): data = asn_public_key(public_key) - write_pem_key( - path, - data, - token = PUBLIC_TOKEN - ) + write_pem_key(path, data, token=PUBLIC_TOKEN) + def asn_private_key(private_key): return asn.asn1_gen( - (asn.SEQUENCE, [ - (asn.INTEGER, private_key["version"]), - (asn.INTEGER, private_key["modulus"]), - (asn.INTEGER, private_key["public_exponent"]), - (asn.INTEGER, private_key["private_exponent"]), - (asn.INTEGER, private_key["prime_1"]), - (asn.INTEGER, private_key["prime_2"]), - (asn.INTEGER, private_key["exponent_1"]), - (asn.INTEGER, private_key["exponent_2"]), - (asn.INTEGER, private_key["coefficient"]) - ]) + ( + asn.SEQUENCE, + [ + (asn.INTEGER, private_key["version"]), + (asn.INTEGER, private_key["modulus"]), + (asn.INTEGER, private_key["public_exponent"]), + (asn.INTEGER, private_key["private_exponent"]), + (asn.INTEGER, private_key["prime_1"]), + (asn.INTEGER, private_key["prime_2"]), + (asn.INTEGER, private_key["exponent_1"]), + (asn.INTEGER, private_key["exponent_2"]), + (asn.INTEGER, private_key["coefficient"]), + ], + ) ) + def asn_public_key(public_key): data = b"\x00" + asn.asn1_gen( - (asn.SEQUENCE, [ - (asn.INTEGER, public_key["modulus"]), - (asn.INTEGER, public_key["public_exponent"]) - ]) + ( + asn.SEQUENCE, + [ + (asn.INTEGER, public_key["modulus"]), + (asn.INTEGER, public_key["public_exponent"]), + ], + ) ) return asn.asn1_gen( - (asn.SEQUENCE, [ - (asn.SEQUENCE, [ - (asn.OBJECT_IDENTIFIER, asn.RSAID_PKCS1), - (asn.NULL, None) - ]), - (asn.BIT_STRING, data) - ]) + ( + asn.SEQUENCE, + [ + ( + asn.SEQUENCE, + [(asn.OBJECT_IDENTIFIER, asn.RSAID_PKCS1), (asn.NULL, None)], + ), + (asn.BIT_STRING, data), + ], + ) ) -def pem_to_der(in_path, out_path, token = PRIVATE_TOKEN): - data = open_pem_key(in_path, token = token) + +def pem_to_der(in_path, out_path, token=PRIVATE_TOKEN): + data = open_pem_key(in_path, token=token) file = open(out_path, "wb") - try: file.write(data) - finally: file.close() + try: + file.write(data) + finally: + file.close() + def pem_limiters(token): begin = netius.legacy.bytes("-----BEGIN " + token + "-----") end = netius.legacy.bytes("-----END " + token + "-----") return (begin, end) + def private_to_public(private_key): public_key = dict( - modulus = private_key["modulus"], - public_exponent = private_key["public_exponent"] + modulus=private_key["modulus"], public_exponent=private_key["public_exponent"] ) return public_key -def assert_private(private_key, number_bits = None): + +def assert_private(private_key, number_bits=None): prime_1 = private_key["prime_1"] prime_2 = private_key["prime_2"] private_exponent = private_key["private_exponent"] @@ -240,6 +241,7 @@ def assert_private(private_key, number_bits = None): netius.verify(result == message) + def rsa_private(number_bits): """ Generates a new "random" private with the requested number @@ -258,8 +260,11 @@ def rsa_private(number_bits): while True: prime_1, prime_2 = rsa_primes(number_bits // 2) - public_exponent, private_exponent = rsa_exponents(prime_1, prime_2, number_bits // 2) - if private_exponent > 0: break + public_exponent, private_exponent = rsa_exponents( + prime_1, prime_2, number_bits // 2 + ) + if private_exponent > 0: + break modulus = prime_1 * prime_2 exponent_1 = private_exponent % (prime_1 - 1) @@ -268,20 +273,21 @@ def rsa_private(number_bits): bits = rsa_bits(modulus) private_key = dict( - version = 0, - modulus = modulus, - public_exponent = public_exponent, - private_exponent = private_exponent, - prime_1 = prime_1, - prime_2 = prime_2, - exponent_1 = exponent_1, - exponent_2 = exponent_2, - coefficient = coefficient, - bits = bits + version=0, + modulus=modulus, + public_exponent=public_exponent, + private_exponent=private_exponent, + prime_1=prime_1, + prime_2=prime_2, + exponent_1=exponent_1, + exponent_2=exponent_2, + coefficient=coefficient, + bits=bits, ) return private_key + def rsa_primes(number_bits): """ Generates two different prime numbers (p and q values) @@ -307,7 +313,8 @@ def rsa_primes(number_bits): # used for a trial an error based approach for the generation # of the primes to be used in the private key def rsa_acceptable(prime_1, prime_2): - if prime_1 == prime_2: return False + if prime_1 == prime_2: + return False modulus_bits = rsa_bits(prime_1 * prime_2) return modulus_bits == total_bits @@ -321,15 +328,19 @@ def rsa_acceptable(prime_1, prime_2): # iterates continuously trying to find a combination # of prime numbers that is acceptable and valid while True: - if rsa_acceptable(prime_1, prime_2): break - if is_odd: prime_1 = calc.prime(number_bits) - else: prime_2 = calc.prime(number_bits) + if rsa_acceptable(prime_1, prime_2): + break + if is_odd: + prime_1 = calc.prime(number_bits) + else: + prime_2 = calc.prime(number_bits) # returns a tuple containing both of the generated # primes and returns it to the caller method return (prime_1, prime_2) -def rsa_exponents(prime_1, prime_2, number_bits, basic = True): + +def rsa_exponents(prime_1, prime_2, number_bits, basic=True): """ Generates both the public and the private exponents for the RSA cryptography system taking as base the provided @@ -368,7 +379,9 @@ def rsa_exponents(prime_1, prime_2, number_bits, basic = True): # to create a public exponent and the basic mode is active the # number chosen is the "magic" number (compatibility) public_exponent = calc.prime(max(8, number_bits // 2)) - if is_first and basic: public_exponent = 65537; is_first = False + if is_first and basic: + public_exponent = 65537 + is_first = False # checks if the exponent and the modulus are relative primes # and also checks if the exponent and the phi modulus are relative @@ -376,7 +389,8 @@ def rsa_exponents(prime_1, prime_2, number_bits, basic = True): # and the cycle may be broken is_relative = calc.relatively_prime(public_exponent, modulus) is_relative_phi = calc.relatively_prime(public_exponent, phi_modulus) - if is_relative and is_relative_phi: break + if is_relative and is_relative_phi: + break # retrieves the result of the extended euclid greatest common divisor, # this value is going to be used as the basis for the calculus of the @@ -386,39 +400,45 @@ def rsa_exponents(prime_1, prime_2, number_bits, basic = True): # in case the greatest common divisor between both is not one, the values # are not relative primes and an exception must be raised - if not d == 1: raise netius.GeneratorError( - "The public exponent '%d' and the phi modulus '%d' are not relative primes" % - (public_exponent, phi_modulus) - ) + if not d == 1: + raise netius.GeneratorError( + "The public exponent '%d' and the phi modulus '%d' are not relative primes" + % (public_exponent, phi_modulus) + ) # calculates the inverse modulus for both exponent and in case it's not one # an exception is raised about the problem inverse_modulus = (public_exponent * private_exponent) % phi_modulus - if not inverse_modulus == 1: netius.GeneratorError( - "The public exponent '%d' and private exponent '%d' are not multiplicative inverse modulus of phi modulus '%d'" % - (public_exponent, private_exponent, phi_modulus) - ) + if not inverse_modulus == 1: + netius.GeneratorError( + "The public exponent '%d' and private exponent '%d' are not multiplicative inverse modulus of phi modulus '%d'" + % (public_exponent, private_exponent, phi_modulus) + ) # creates the tuple that contains both the public and the private # exponent values that may be used for RSA based cryptography return (public_exponent, private_exponent) + def rsa_bits(modulus): bits = math.log(modulus, 2) return calc.ceil_integer(bits) + def rsa_sign(message, private_key): message = netius.legacy.bytes(message) modulus = private_key["modulus"] private_exponent = private_key["private_exponent"] return rsa_crypt_s(message, private_exponent, modulus) + def rsa_verify(signature, public_key): signature = netius.legacy.bytes(signature) modulus = public_key["modulus"] public_exponent = public_key["public_exponent"] return rsa_crypt_s(signature, public_exponent, modulus) + def rsa_crypt_s(message, exponent, modulus): modulus_l = calc.ceil_integer(math.log(modulus, 256)) @@ -428,11 +448,14 @@ def rsa_crypt_s(message, exponent, modulus): return message_crypt_s + def rsa_crypt(number, exponent, modulus): if not isinstance(number, netius.legacy.INTEGERS): raise TypeError("you must pass a long or an int") - if number > 0 and math.floor(math.log(number, 2)) > math.floor(math.log(modulus, 2)): + if number > 0 and math.floor(math.log(number, 2)) > math.floor( + math.log(modulus, 2) + ): raise OverflowError("the message is too long") return pow(number, exponent, modulus) diff --git a/src/netius/common/setup.py b/src/netius/common/setup.py index 639460e05..9b5daedd0 100644 --- a/src/netius/common/setup.py +++ b/src/netius/common/setup.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -46,36 +37,43 @@ EXTRAS_PATH = os.path.join(BASE_PATH, "extras") SSL_CA_PATH = os.path.join(EXTRAS_PATH, "net.ca") + def ensure_setup(): ensure_ca() -def ensure_ca(path = SSL_CA_PATH): - if os.path.exists(path): return - _download_ca(path = path) -def _download_ca(path = SSL_CA_PATH, raise_e = True): +def ensure_ca(path=SSL_CA_PATH): + if os.path.exists(path): + return + _download_ca(path=path) + + +def _download_ca(path=SSL_CA_PATH, raise_e=True): import netius.clients + ca_url = CA_URL while True: - result = netius.clients.HTTPClient.method_s( - "GET", - ca_url, - asynchronous = False - ) - if not result["code"] in (301, 302, 303): break + result = netius.clients.HTTPClient.method_s("GET", ca_url, asynchronous=False) + if not result["code"] in (301, 302, 303): + break headers = result.get("headers", {}) location = headers.get("Location", None) - if not location: break + if not location: + break ca_url = location if not result["code"] == 200: - if not raise_e: return + if not raise_e: + return raise Exception("Error while downloading CA file from '%s'" % CA_URL) response = netius.clients.HTTPClient.to_response(result) contents = response.read() _store_contents(contents, path) + def _store_contents(contents, path): file = open(path, "wb") - try: file.write(contents) - finally: file.close() + try: + file.write(contents) + finally: + file.close() return path diff --git a/src/netius/common/smtp.py b/src/netius/common/smtp.py index 2f72d2e03..82c9bd910 100644 --- a/src/netius/common/smtp.py +++ b/src/netius/common/smtp.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -48,9 +39,10 @@ the various response lines between the code and the message part, it should handle both normal and continuation lines """ + class SMTPParser(parser.Parser): - def __init__(self, owner, store = False): + def __init__(self, owner, store=False): parser.Parser.__init__(self, owner) self.buffer = [] @@ -85,7 +77,8 @@ def parse(self, data): # zero the parsing iteration is broken method = self._parse_line count = method(data) - if count == 0: break + if count == 0: + break # decrements the size of the data buffer by the # size of the parsed bytes and then retrieves the @@ -96,7 +89,8 @@ def parse(self, data): # in case not all of the data has been processed # must add it to the buffer so that it may be used # latter in the next parsing of the message - if size > 0: self.buffer.append(data) + if size > 0: + self.buffer.append(data) # returns the number of read (processed) bytes of the # data that has been sent to the parser @@ -106,7 +100,8 @@ def _parse_line(self, data): # tries to find the new line character in the currently received # data in case it's not found returns immediately with no data processed index = data.find(b"\n") - if index == -1: return 0 + if index == -1: + return 0 # adds the partial data (until new line) to the current buffer and # then joins it retrieving the current line, then deletes the buffer @@ -120,7 +115,8 @@ def _parse_line(self, data): # the split is not successful (not enough information) then an extra # value is added to the sequence of values for compatibility values = SEPARATOR_REGEX.split(self.line_s, 1) - if not len(values) > 1: values.append("") + if not len(values) > 1: + values.append("") # unpacks the set of values that have just been parsed into the code # and the message items as expected by the smtp specification @@ -133,13 +129,15 @@ def _parse_line(self, data): line_l = len(self.line_s) space_index = self.line_s.find(" ") token_index = self.line_s.find("-") - if space_index == -1: space_index = line_l - if token_index == -1: token_index = line_l + if space_index == -1: + space_index = line_l + if token_index == -1: + token_index = line_l is_continuation = token_index < space_index is_final = not is_continuation # triggers the on line event so that the listeners are notified # about the end of the parsing of the smtp line and then # returns the count of the parsed bytes of the message - self.trigger("on_line", code, message, is_final = is_final) + self.trigger("on_line", code, message, is_final=is_final) return index + 1 diff --git a/src/netius/common/socks.py b/src/netius/common/socks.py index 2520057fe..747a0e24b 100644 --- a/src/netius/common/socks.py +++ b/src/netius/common/socks.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -72,6 +63,7 @@ FINISH_STATE = 11 + class SOCKSParser(parser.Parser): def __init__(self, owner): @@ -99,7 +91,7 @@ def build(self): self._parse_header_extra, self._parse_size, self._parse_address, - self._parse_port + self._parse_port, ) self.state_l = len(self.states) @@ -131,8 +123,9 @@ def reset(self): self.auth_count = 0 self.auth_methods = None - def clear(self, force = False): - if not force and self.state == VERSION_STATE: return + def clear(self, force=False): + if not force and self.state == VERSION_STATE: + return self.reset() def parse(self, data): @@ -154,7 +147,8 @@ def parse(self, data): # in case the current state of the parser is finished, must # reset the state to the start position as the parser is # re-starting (probably a new data sequence) - if self.state == FINISH_STATE: self.clear() + if self.state == FINISH_STATE: + self.clear() # retrieves the size of the data that has been sent for parsing # and saves it under the size original variable @@ -168,7 +162,8 @@ def parse(self, data): if self.state <= self.state_l: method = self.states[self.state - 1] count = method(data) - if count == 0: break + if count == 0: + break size -= count data = data[count:] @@ -184,7 +179,8 @@ def parse(self, data): # in case not all of the data has been processed # must add it to the buffer so that it may be used # latter in the next parsing of the message - if size > 0: self.buffer.append(data) + if size > 0: + self.buffer.append(data) # returns the number of read (processed) bytes of the # data that has been sent to the parser @@ -194,11 +190,15 @@ def get_host(self): return self.domain or self.address_s def get_address(self): - if self.type == None: return None + if self.type == None: + return None - if self.type == IPV4: address = struct.pack("!I", self.address) - elif self.type == IPV6: address = struct.pack("!QQ", self.address) - else: address = struct.pack("!B", self.size) + self.address + if self.type == IPV4: + address = struct.pack("!I", self.address) + elif self.type == IPV6: + address = struct.pack("!QQ", self.address) + else: + address = struct.pack("!B", self.size) + self.address return address @@ -207,11 +207,14 @@ def _parse_version(self, data): raise netius.ParserError("Invalid request (too short)") request = data[:1] - self.version, = struct.unpack("!B", request) + (self.version,) = struct.unpack("!B", request) - if self.version == 4: self.state = HEADER_STATE - elif self.version == 5: self.state = AUTH_COUNT_STATE - else: raise netius.ParserError("Invalid version '%d'" % self.version) + if self.version == 4: + self.state = HEADER_STATE + elif self.version == 5: + self.state = AUTH_COUNT_STATE + else: + raise netius.ParserError("Invalid version '%d'" % self.version) return 1 @@ -231,22 +234,27 @@ def _parse_header(self, data): def _parse_user_id(self, data): index = data.find(b"\0") - if index == -1: return 0 + if index == -1: + return 0 self.buffer.append(data[:index]) self.user_id = b"".join(self.buffer) self.user_id = netius.legacy.str(self.user_id) del self.buffer[:] - if self.is_extended: self.state = DOMAIN_STATE - else: self.state = FINISH_STATE + if self.is_extended: + self.state = DOMAIN_STATE + else: + self.state = FINISH_STATE - if not self.is_extended: self.trigger("on_data") + if not self.is_extended: + self.trigger("on_data") return index + 1 def _parse_domain(self, data): index = data.find(b"\0") - if index == -1: return 0 + if index == -1: + return 0 self.buffer.append(data[:index]) self.domain = b"".join(self.buffer) @@ -263,7 +271,7 @@ def _parse_auth_count(self, data): raise netius.ParserError("Invalid request (too short)") request = data[:1] - self.auth_count, = struct.unpack("!B", request) + (self.auth_count,) = struct.unpack("!B", request) self.state = AUTH_METHODS_STATE @@ -271,7 +279,8 @@ def _parse_auth_count(self, data): def _parse_auth_methods(self, data): is_ready = len(data) + len(self.buffer) >= self.auth_count - if not is_ready: return 0 + if not is_ready: + return 0 remaining = self.auth_count - len(self.buffer) self.buffer.append(data[:remaining]) @@ -291,14 +300,19 @@ def _parse_header_extra(self, data): raise netius.ParserError("Invalid request (too short)") request = data[:4] - self.version, self.command, _reserved, self.type =\ - struct.unpack("!BBBB", request) + self.version, self.command, _reserved, self.type = struct.unpack( + "!BBBB", request + ) - if self.type == IPV4: self.size = 4 - elif self.type == IPV6: self.size = 16 + if self.type == IPV4: + self.size = 4 + elif self.type == IPV6: + self.size = 16 - if self.type == DOMAIN: self.state = SIZE_STATE - else: self.state = ADDRESS_STATE + if self.type == DOMAIN: + self.state = SIZE_STATE + else: + self.state = ADDRESS_STATE return 4 @@ -307,7 +321,7 @@ def _parse_size(self, data): raise netius.ParserError("Invalid request (too short)") request = data[:1] - self.size, = struct.unpack("!B", request) + (self.size,) = struct.unpack("!B", request) self.state = ADDRESS_STATE @@ -315,14 +329,15 @@ def _parse_size(self, data): def _parse_address(self, data): is_ready = len(data) + len(self.buffer) >= self.size - if not is_ready: return 0 + if not is_ready: + return 0 remaining = self.size - len(self.buffer) self.buffer.append(data[:remaining]) data = b"".join(self.buffer) if self.type == IPV4: - self.address, = struct.unpack("!I", data) + (self.address,) = struct.unpack("!I", data) self.address_s = util.addr_to_ip4(self.address) elif self.type == IPV6: address_t = struct.unpack("!QQ", data) @@ -341,7 +356,7 @@ def _parse_port(self, data): raise netius.ParserError("Invalid request (too short)") request = data[:2] - self.port, = struct.unpack("!H", request) + (self.port,) = struct.unpack("!H", request) self.state = FINISH_STATE diff --git a/src/netius/common/stream.py b/src/netius/common/stream.py index 2dde8a550..c20e6099d 100644 --- a/src/netius/common/stream.py +++ b/src/netius/common/stream.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -41,9 +32,10 @@ import netius + class Stream(object): - def open(self, mode = "r+b"): + def open(self, mode="r+b"): raise netius.NotImplemented("Missing implementation") def close(self): @@ -61,6 +53,7 @@ def write(self, data): def flish(self): raise netius.NotImplemented("Missing implementation") + class FileStream(Stream): def __init__(self, path, size): @@ -69,15 +62,17 @@ def __init__(self, path, size): self.size = size self.file = None - def open(self, mode = "w+b", allocate = True): + def open(self, mode="w+b", allocate=True): self.file = open(self.path, mode) - if not allocate: return + if not allocate: + return self.file.seek(self.size - 1) self.file.write(b"\0") self.file.flush() def close(self): - if not self.file: return + if not self.file: + return self.file.close() self.file = None @@ -93,6 +88,7 @@ def write(self, data): def flush(self): self.file.flush() + class FilesStream(Stream): def __init__(self, dir_path, size, files_m): @@ -103,7 +99,7 @@ def __init__(self, dir_path, size, files_m): self.files = [] self._offset = 0 - def open(self, mode = "w+b", allocate = True): + def open(self, mode="w+b", allocate=True): for file_m in self.files_m: file_path = file_m["path"] file_size = file_m["length"] @@ -111,13 +107,15 @@ def open(self, mode = "w+b", allocate = True): file = open(file_path, mode) file_t = (file, file_m) self.files.append(file_t) - if not allocate: continue + if not allocate: + continue file.seek(file_size - 1) file.write(b"\0") file.flush() def close(self): - if not self.files: return + if not self.files: + return for file_t in self.files: file, _file_m = file_t file.close() @@ -151,7 +149,8 @@ def read(self, size): # iteration, must go further start = offset - file_offset file_offset += file_size - if start >= file_size: continue + if start >= file_size: + continue # calculates the end internal offset value as the # minimum value between the file size and the start @@ -176,7 +175,8 @@ def read(self, size): # verifies if there's no more data pending and if # that's the case break the current loop as no more # files are going to be affected - if pending == 0: break + if pending == 0: + break # updates the current offset of the (virtual) file stream # with length of the data that has been read, then avoids @@ -213,7 +213,8 @@ def write(self, data): # iteration, must go further start = offset - file_offset file_offset += file_size - if start >= file_size: continue + if start >= file_size: + continue # calculates the end internal offset value as the # minimum value between the file size and the start @@ -240,7 +241,8 @@ def write(self, data): # verifies if there's no more data pending and if # that's the case break the current loop as no more # files are going to be affected - if pending == 0: break + if pending == 0: + break # updates the current offset of the (virtual) file stream # with length of the data that has just been written, then diff --git a/src/netius/common/structures.py b/src/netius/common/structures.py index 4997f8d20..01b63d9d1 100644 --- a/src/netius/common/structures.py +++ b/src/netius/common/structures.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -40,6 +31,7 @@ import os import heapq + class PriorityDict(dict): def __init__(self, *args, **kwargs): @@ -81,18 +73,21 @@ def update(self, *args, **kwargs): self._rebuild_heap() def sorted_iter(self): - while self: yield self.pop_smallest() + while self: + yield self.pop_smallest() def _rebuild_heap(self): self._heap = [(v, k) for k, v in self.items()] heapq.heapify(self._heap) -def file_iterator(file_object, chunk_size = 40960): + +def file_iterator(file_object, chunk_size=40960): file_object.seek(0, os.SEEK_END) size = file_object.tell() file_object.seek(0, os.SEEK_SET) yield size while True: data = file_object.read(chunk_size) - if not data: break + if not data: + break yield data diff --git a/src/netius/common/style.py b/src/netius/common/style.py index 4b49ad53f..b8c5e6c70 100644 --- a/src/netius/common/style.py +++ b/src/netius/common/style.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ diff --git a/src/netius/common/tftp.py b/src/netius/common/tftp.py index 26ed359a8..8c45654b5 100644 --- a/src/netius/common/tftp.py +++ b/src/netius/common/tftp.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -44,9 +35,9 @@ ERROR_TFTP = 0x05 TYPES_TFTP = { - RRQ_TFTP : "rrq", - WRQ_TFTP : "wrq", - DATA_TFTP : "data", - ACK_TFTP : "ack", - ERROR_TFTP : "error" + RRQ_TFTP: "rrq", + WRQ_TFTP: "wrq", + DATA_TFTP: "data", + ACK_TFTP: "ack", + ERROR_TFTP: "error", } diff --git a/src/netius/common/tls.py b/src/netius/common/tls.py index bf4bfcd25..11643d930 100644 --- a/src/netius/common/tls.py +++ b/src/netius/common/tls.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -41,6 +32,7 @@ import netius + class TLSContextDict(dict): def __init__(self, owner, domains, *args, **kwargs): @@ -49,20 +41,23 @@ def __init__(self, owner, domains, *args, **kwargs): self.load(domains) def load(self, domains): - secure = self.owner.get_env("SSL_SECURE", 1, cast = int) + secure = self.owner.get_env("SSL_SECURE", 1, cast=int) for domain in domains: - if not self.has_definition(domain): continue + if not self.has_definition(domain): + continue cer_path = self.cer_path(domain) key_path = self.key_path(domain) - values = dict(cer_file = cer_path, key_file = key_path) - context = self.owner._ssl_ctx(values, secure = secure) + values = dict(cer_file=cer_path, key_file=key_path) + context = self.owner._ssl_ctx(values, secure=secure) self[domain] = (context, values) def has_definition(self, domain): cer_path = self.cer_path(domain) key_path = self.key_path(domain) - if not os.path.exists(cer_path): return False - if not os.path.exists(key_path): return False + if not os.path.exists(cer_path): + return False + if not os.path.exists(key_path): + return False return True def cer_path(self, domain): @@ -71,6 +66,7 @@ def cer_path(self, domain): def key_path(self, domain): raise netius.NotImplemented("Missing implementation") + class LetsEncryptDict(TLSContextDict): def __init__(self, owner, domains, *args, **kwargs): diff --git a/src/netius/common/torrent.py b/src/netius/common/torrent.py index 07653b561..c36fb86b4 100644 --- a/src/netius/common/torrent.py +++ b/src/netius/common/torrent.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -55,27 +46,29 @@ character and expression with decimal digits """ TORRENT_TYPES = { - -1 : "keep-alive", - 0 : "choke", - 1 : "unchoke", - 2 : "interested", - 3 : "not interested", - 4 : "have", - 5 : "bitfield", - 6 : "request", - 7 : "piece", - 8 : "cancel", - 9 : "port" + -1: "keep-alive", + 0: "choke", + 1: "unchoke", + 2: "interested", + 3: "not interested", + 4: "have", + 5: "bitfield", + 6: "request", + 7: "piece", + 8: "cancel", + 9: "port", } """ The map that associates the various message type identifiers with their internal string representations """ + def info_hash(root): info = root["info"] data = bencode(info) info_hash = hashlib.sha1(data) return info_hash.digest() + def bencode(root): # joins the complete set of values created by # generator that has been returned from the chunk @@ -83,6 +76,7 @@ def bencode(root): data = b"".join([value for value in chunk(root)]) return data + def bdecode(data): # converts the provide (string) data into a list # of chunks (characters) reversing it so that the @@ -98,6 +92,7 @@ def bdecode(data): root = dechunk(chunks) return root + def chunk(item): chunk_t = type(item) @@ -112,14 +107,17 @@ def chunk(item): keys.sort() for key in keys: value = item[key] - for part in chunk(key): yield part - for part in chunk(value): yield part + for part in chunk(key): + yield part + for part in chunk(value): + yield part yield b"e" elif chunk_t == list: yield b"l" for value in item: - for part in chunk(value): yield part + for part in chunk(value): + yield part yield b"e" elif chunk_t in netius.legacy.INTEGERS: @@ -131,6 +129,7 @@ def chunk(item): else: raise netius.ParserError("Not possible to encode") + def dechunk(chunks): item = chunks.pop() @@ -187,9 +186,10 @@ def dechunk(chunks): raise netius.ParserError("Invalid input: '%s'" % item) + class TorrentParser(parser.Parser): - def __init__(self, owner, store = False): + def __init__(self, owner, store=False): parser.Parser.__init__(self, owner) self.length = None @@ -207,10 +207,7 @@ def build(self): parser.Parser.build(self) - self.states = ( - self._parse_handshake, - self._parse_message - ) + self.states = (self._parse_handshake, self._parse_message) self.state_l = len(self.states) def destroy(self): @@ -252,7 +249,8 @@ def parse(self, data): # in case there's no owner associated with the # current parser must break the loop because # there's no way to continue with parsing - if not self.owner: break + if not self.owner: + break # retrieves the parsing method for the current # state and then runs it retrieving the number @@ -260,7 +258,8 @@ def parse(self, data): # zero the parsing iteration is broken method = self.states[self.owner.state - 1] count = method(data) - if count == 0: break + if count == 0: + break # decrements the size of the data buffer by the # size of the parsed bytes and then retrieves the @@ -271,7 +270,9 @@ def parse(self, data): # in case not all of the data has been processed # must add it to the buffer so that it may be used # latter in the next parsing of the message - if size > 0: self.buffer.append(data); self.buffer_l += size + if size > 0: + self.buffer.append(data) + self.buffer_l += size # returns the number of read (processed) bytes of the # data that has been sent to the parser @@ -287,12 +288,15 @@ def _join(self, data): def _parse_handshake(self, data): total = len(data) + self.buffer_l - if total < HANDSHAKE_SIZE: return 0 + if total < HANDSHAKE_SIZE: + return 0 diff = HANDSHAKE_SIZE - self.buffer_l if self.buffer_l < HANDSHAKE_SIZE else 0 result = self._join(data[:diff]) - _length, protocol, reserved, info_hash, peer_id = struct.unpack("!B19sQ20s20s", result) + _length, protocol, reserved, info_hash, peer_id = struct.unpack( + "!B19sQ20s20s", result + ) self.trigger("on_handshake", protocol, reserved, info_hash, peer_id) self.length = None @@ -312,11 +316,12 @@ def _parse_message(self, data): # inside the current message so that it may than be re-used # to check if the complete message has been received if self.length == None: - if total < 4: return 0 + if total < 4: + return 0 diff = 4 - self.buffer_l if self.buffer_l < 4 else 0 result = self._join(data[:diff]) data = data[diff:] - self.length, = struct.unpack("!L", result[:4]) + (self.length,) = struct.unpack("!L", result[:4]) count += diff # calculates the "target" total message length and verifies @@ -324,7 +329,8 @@ def _parse_message(self, data): # have already reached that value otherwise returns count # to the caller method (delayed execution) message_length = self.length + 4 - if total < message_length: return count + if total < message_length: + return count # calculates the difference between meaning the amount of # data from the current chunk that is going to be processed @@ -336,8 +342,10 @@ def _parse_message(self, data): # greater than zero) for such situations the type must # be loaded, otherwise the type is assumed to be the keep # alive one (only message with no payload available) - if self.length > 0: type, = struct.unpack("!B", result[4:5]) - else: type = -1 + if self.length > 0: + (type,) = struct.unpack("!B", result[4:5]) + else: + type = -1 # resolves the current type integer based type into the proper # string based type values so that it may be used from now on diff --git a/src/netius/common/util.py b/src/netius/common/util.py index 887c58341..09c5b1a8d 100644 --- a/src/netius/common/util.py +++ b/src/netius/common/util.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -44,15 +35,11 @@ import netius -SIZE_UNITS_LIST = ( - "B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB" -) +SIZE_UNITS_LIST = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB") """ The size units list that contains the complete set of units indexed by the depth they represent """ -SIZE_UNITS_LIST_S = ( - "B", "K", "M", "G", "T", "P", "E", "Z", "Y" -) +SIZE_UNITS_LIST_S = ("B", "K", "M", "G", "T", "P", "E", "Z", "Y") """ The simplified size units list that contains the complete set of units indexed by the depth they represent """ @@ -75,49 +62,68 @@ this value is used to avoid an excessive blocking in the get host by name call, as it is a blocking call """ + def cstring(value): index = value.index("\0") - if index == -1: return value + if index == -1: + return value return value[:index] + def chunks(sequence, count): for index in range(0, len(sequence), count): - yield sequence[index:index + count] + yield sequence[index : index + count] + def header_down(name): values = name.split("-") values = [value.lower() for value in values] return "-".join(values) + def header_up(name): values = name.split("-") values = [value.title() for value in values] return "-".join(values) + def is_ip4(address): address_p = address.split(".", 4) - if not len(address_p) == 4: return False + if not len(address_p) == 4: + return False for part in address_p: - try: part_i = int(part) - except ValueError: return False - if part_i < 0: return False - if part_i > 255: return False + try: + part_i = int(part) + except ValueError: + return False + if part_i < 0: + return False + if part_i > 255: + return False return True + def is_ip6(address): - if is_ip4(address): return False + if is_ip4(address): + return False return True -def assert_ip4(address, allowed, default = True): - if not allowed: return default + +def assert_ip4(address, allowed, default=True): + if not allowed: + return default for item in allowed: is_subnet = "/" in item - if is_subnet: valid = in_subnet_ip4(address, item) - else: valid = address == item - if not valid: continue + if is_subnet: + valid = in_subnet_ip4(address, item) + else: + valid = address == item + if not valid: + continue return True return False + def in_subnet_ip4(address, subnet): subnet, length = subnet.split("/", 1) size_i = 32 - int(length) @@ -128,6 +134,7 @@ def in_subnet_ip4(address, subnet): in_subnet &= address_a < limit_a return in_subnet + def addr_to_ip4(number): first = int(number / 16777216) % 256 second = int(number / 65536) % 256 @@ -135,15 +142,17 @@ def addr_to_ip4(number): fourth = int(number) % 256 return "%s.%s.%s.%s" % (first, second, third, fourth) + def addr_to_ip6(number): buffer = collections.deque() for index in range(8): offset = index * 2 - first = number >> (8 * offset) & 0xff - second = number >> (8 * (offset + 1)) & 0xff + first = number >> (8 * offset) & 0xFF + second = number >> (8 * (offset + 1)) & 0xFF buffer.appendleft("%02x%02x" % (second, first)) return ":".join(buffer) + def ip4_to_addr(value): first, second, third, fourth = value.split(".", 3) first_a = int(first) * 16777216 @@ -152,10 +161,16 @@ def ip4_to_addr(value): fourth_a = int(fourth) return first_a + second_a + third_a + fourth_a + def string_to_bits(value): - return bin(netius.legacy.reduce(lambda x, y : (x << 8) + y, (netius.legacy.ord(c) for c in value), 1))[3:] + return bin( + netius.legacy.reduce( + lambda x, y: (x << 8) + y, (netius.legacy.ord(c) for c in value), 1 + ) + )[3:] -def integer_to_bytes(number, length = 0): + +def integer_to_bytes(number, length=0): if not isinstance(number, netius.legacy.INTEGERS): raise netius.DataError("Invalid data type") @@ -163,12 +178,13 @@ def integer_to_bytes(number, length = 0): number = abs(number) while number > 0: - bytes.append(chr(number & 0xff)) + bytes.append(chr(number & 0xFF)) number >>= 8 remaining = length - len(bytes) remaining = 0 if remaining < 0 else remaining - for _index in range(remaining): bytes.append("\x00") + for _index in range(remaining): + bytes.append("\x00") bytes = reversed(bytes) bytes_s = "".join(bytes) @@ -176,14 +192,17 @@ def integer_to_bytes(number, length = 0): return bytes_s + def bytes_to_integer(bytes): if not type(bytes) == netius.legacy.BYTES: raise netius.DataError("Invalid data type") number = 0 - for byte in bytes: number = (number << 8) | netius.legacy.ord(byte) + for byte in bytes: + number = (number << 8) | netius.legacy.ord(byte) return number + def random_integer(number_bits): """ Generates a random integer of approximately the @@ -216,7 +235,8 @@ def random_integer(number_bits): random_integer |= 1 << (number_bits - 1) return random_integer -def host(default = "127.0.0.1"): + +def host(default="127.0.0.1"): """ Retrieves the host for the current machine, typically this would be the ipv4 address of @@ -238,14 +258,19 @@ def host(default = "127.0.0.1"): """ global _HOST - if _HOST: return _HOST + if _HOST: + return _HOST hostname = socket.gethostname() - try: _HOST = socket.gethostbyname(hostname) - except socket.gaierror: _HOST = default + try: + _HOST = socket.gethostbyname(hostname) + except socket.gaierror: + _HOST = default is_unicode = type(_HOST) == netius.legacy.OLD_UNICODE - if is_unicode: _HOST = _HOST.encode("utf-8") + if is_unicode: + _HOST = _HOST.encode("utf-8") return _HOST + def hostname(): """ The name as a simple string o the name of the current @@ -262,15 +287,16 @@ def hostname(): return socket.gethostname() + def size_round_unit( size_value, - minimum = DEFAULT_MINIMUM, - places = DEFAULT_PLACES, - reduce = True, - space = False, - justify = False, - simplified = False, - depth = 0 + minimum=DEFAULT_MINIMUM, + places=DEFAULT_PLACES, + reduce=True, + space=False, + justify=False, + simplified=False, + depth=0, ): """ Rounds the size unit, returning a string representation @@ -344,24 +370,30 @@ def size_round_unit( # in case the dot value is not present in the size value # string adds it to the end otherwise an issue may occur # while removing extra padding characters for reduce - if reduce and not "." in size_value_s: size_value_s += "." + if reduce and not "." in size_value_s: + size_value_s += "." # strips the value from zero appended to the right and # then strips the value also from a possible decimal # point value that may be included in it, this is only # performed in case the reduce flag is enabled - if reduce: size_value_s = size_value_s.rstrip("0") - if reduce: size_value_s = size_value_s.rstrip(".") + if reduce: + size_value_s = size_value_s.rstrip("0") + if reduce: + size_value_s = size_value_s.rstrip(".") # in case the justify flag is set runs the justification # process on the size value taking into account the maximum # size of the associated size string - if justify: size_value_s = size_value_s.rjust(size_s) + if justify: + size_value_s = size_value_s.rjust(size_s) # retrieves the size unit (string mode) for the current # depth according to the provided map - if simplified: size_unit = SIZE_UNITS_LIST_S[depth] - else: size_unit = SIZE_UNITS_LIST[depth] + if simplified: + size_unit = SIZE_UNITS_LIST_S[depth] + else: + size_unit = SIZE_UNITS_LIST[depth] # retrieves the appropriate separator based # on the value of the space flag @@ -382,16 +414,17 @@ def size_round_unit( new_depth = depth + 1 return size_round_unit( new_size_value, - minimum = minimum, - places = places, - reduce = reduce, - space = space, - justify = justify, - simplified = simplified, - depth = new_depth + minimum=minimum, + places=places, + reduce=reduce, + space=space, + justify=justify, + simplified=simplified, + depth=new_depth, ) -def verify(condition, message = None, exception = None): + +def verify(condition, message=None, exception=None): """ Ensures that the requested condition returns a valid value and if that's no the case an exception raised breaking the @@ -409,39 +442,32 @@ def verify(condition, message = None, exception = None): verification operation fails. """ - if condition: return + if condition: + return exception = exception or netius.AssertionError raise exception(message or "Assertion Error") -def verify_equal(first, second, message = None, exception = None): + +def verify_equal(first, second, message=None, exception=None): message = message or "Expected %s got %s" % (repr(second), repr(first)) - return verify( - first == second, - message = message, - exception = exception - ) + return verify(first == second, message=message, exception=exception) -def verify_not_equal(first, second, message = None, exception = None): + +def verify_not_equal(first, second, message=None, exception=None): message = message or "Expected %s not equal to %s" % (repr(first), repr(second)) - return verify( - not first == second, - message = message, - exception = exception - ) + return verify(not first == second, message=message, exception=exception) + -def verify_type(value, types, null = True, message = None, exception = None, **kwargs): +def verify_type(value, types, null=True, message=None, exception=None, **kwargs): message = message or "Expected %s to have type %s" % (repr(value), repr(types)) return verify( (null and value == None) or isinstance(value, types), - message = message, - exception = exception, + message=message, + exception=exception, **kwargs ) -def verify_many(sequence, message = None, exception = None): + +def verify_many(sequence, message=None, exception=None): for condition in sequence: - verify( - condition, - message = message, - exception = exception - ) + verify(condition, message=message, exception=exception) diff --git a/src/netius/common/ws.py b/src/netius/common/ws.py index 24a742787..1befa9bba 100644 --- a/src/netius/common/ws.py +++ b/src/netius/common/ws.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -42,7 +33,8 @@ import netius -def encode_ws(data, final = True, opcode = 0x01, mask = True): + +def encode_ws(data, final=True, opcode=0x01, mask=True): # converts the boolean based values of the frame into the # bit based partials that are going to be used in the build # of the final frame container element (as expected) @@ -81,13 +73,16 @@ def encode_ws(data, final = True, opcode = 0x01, mask = True): encoded_l.append(mask_bytes) encoded_a = bytearray(data_l) for i in range(data_l): - encoded_a[i] = netius.legacy.chri(netius.legacy.ord(data[i]) ^ netius.legacy.ord(mask_bytes[i % 4])) + encoded_a[i] = netius.legacy.chri( + netius.legacy.ord(data[i]) ^ netius.legacy.ord(mask_bytes[i % 4]) + ) data = bytes(encoded_a) encoded_l.append(data) encoded = b"".join(encoded_l) return encoded + def decode_ws(data): # calculates the length of the data and runs the initial # verification ensuring that such data is larger than the @@ -141,16 +136,18 @@ def decode_ws(data): # immediately indicating that there's not enough data to complete # the decoding of the data (should be re-trying again latter) raw_size = data_l - index_mask_f - mask_bytes - if raw_size < length: raise netius.DataError("Not enough data") + if raw_size < length: + raise netius.DataError("Not enough data") # in case the frame data is not masked the complete set of contents # may be returned immediately to the caller as there's no issue with # avoiding the unmasking operation (as the data is not masked) - if not has_mask: return data[index_mask_f:], b"" + if not has_mask: + return data[index_mask_f:], b"" # retrieves the mask part of the data that are going to be # used in the decoding part of the process - mask = data[index_mask_f:index_mask_f + mask_bytes] + mask = data[index_mask_f : index_mask_f + mask_bytes] # allocates the array that is going to be used # for the decoding of the data with the length @@ -162,7 +159,9 @@ def decode_ws(data): # (decoding it consequently) to the created decoded array i = index_mask_f + 4 for j in range(length): - decoded_a[j] = netius.legacy.chri(netius.legacy.ord(data[i]) ^ netius.legacy.ord(mask[j % 4])) + decoded_a[j] = netius.legacy.chri( + netius.legacy.ord(data[i]) ^ netius.legacy.ord(mask[j % 4]) + ) i += 1 # converts the decoded array of data into a string and @@ -171,5 +170,7 @@ def decode_ws(data): decoded = bytes(decoded_a) return decoded, data[i:] + def assert_ws(data_l, size): - if data_l < size: raise netius.DataError("Not enough data") + if data_l < size: + raise netius.DataError("Not enough data") diff --git a/src/netius/examples/__init__.py b/src/netius/examples/__init__.py index 45c8d2ae1..03d10d641 100644 --- a/src/netius/examples/__init__.py +++ b/src/netius/examples/__init__.py @@ -19,15 +19,6 @@ # You should have received a copy of the Apache License along with # Hive Netius System. If not, see . -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ diff --git a/src/netius/examples/http.py b/src/netius/examples/http.py index 4a3156ca4..dd3b681ff 100644 --- a/src/netius/examples/http.py +++ b/src/netius/examples/http.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -39,13 +30,14 @@ import netius.clients + def http_static(): request = netius.clients.HTTPClient.get_s( - "https://www.flickr.com/", - asynchronous = False + "https://www.flickr.com/", asynchronous=False ) print(request["data"]) + def http_callback(): def callback(connection, parser, request): print(request["data"]) @@ -53,4 +45,4 @@ def callback(connection, parser, request): client.close() client = netius.clients.HTTPClient() - client.get("https://www.flickr.com/", on_result = callback) + client.get("https://www.flickr.com/", on_result=callback) diff --git a/src/netius/examples/upnp.py b/src/netius/examples/upnp.py index 499c0a149..91dd7cb44 100644 --- a/src/netius/examples/upnp.py +++ b/src/netius/examples/upnp.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -41,7 +32,8 @@ import netius.clients -def upnp_map(ext_port, int_port, host, protocol = "TCP", description = "netius"): + +def upnp_map(ext_port, int_port, host, protocol="TCP", description="netius"): """ Defines a router port forwarding rule using an UPnP based request that tries to find the first available router. @@ -67,7 +59,13 @@ def upnp_map(ext_port, int_port, host, protocol = "TCP", description = "netius") 0 - """ % (ext_port, protocol, int_port, host, description) + """ % ( + ext_port, + protocol, + int_port, + host, + description, + ) def on_location(connection, parser, request): data = request["data"] @@ -78,20 +76,21 @@ def on_location(connection, parser, request): url = base_url + path netius.clients.HTTPClient.post_s( url, - headers = dict( - SOAPACTION = "\"urn:schemas-upnp-org:service:WANIPConnection:1#AddPortMapping\"" + headers=dict( + SOAPACTION='"urn:schemas-upnp-org:service:WANIPConnection:1#AddPortMapping"' ), - data = message, - asynchronous = False + data=message, + asynchronous=False, ) client = connection.owner client.close() def on_headers(client, parser, headers): location = headers.get("Location", None) - if not location: raise netius.DataError("No location found") + if not location: + raise netius.DataError("No location found") http_client = netius.clients.HTTPClient() - http_client.get(location, on_result = on_location) + http_client.get(location, on_result=on_location) client.close() client = netius.clients.SSDPClient() diff --git a/src/netius/extra/__init__.py b/src/netius/extra/__init__.py index fbd209da3..c1a75eab8 100644 --- a/src/netius/extra/__init__.py +++ b/src/netius/extra/__init__.py @@ -19,15 +19,6 @@ # You should have received a copy of the Apache License along with # Hive Netius System. If not, see . -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ diff --git a/src/netius/extra/desktop.py b/src/netius/extra/desktop.py index fb14a7c21..3a0028e6d 100644 --- a/src/netius/extra/desktop.py +++ b/src/netius/extra/desktop.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -39,8 +30,11 @@ import netius.servers -try: import PIL.ImageGrab -except ImportError: PIL = None +try: + import PIL.ImageGrab +except ImportError: + PIL = None + class DesktopServer(netius.servers.MJPGServer): @@ -48,7 +42,8 @@ def get_delay(self, connection): return 1 def get_image(self, connection): - if not PIL: return None + if not PIL: + return None image = PIL.ImageGrab.grab() buffer = netius.legacy.BytesIO() try: @@ -58,8 +53,9 @@ def get_image(self, connection): buffer.close() return data + if __name__ == "__main__": server = DesktopServer() - server.serve(env = True) + server.serve(env=True) else: __path__ = [] diff --git a/src/netius/extra/dhcp_s.py b/src/netius/extra/dhcp_s.py index c2f0039db..ba79f2f17 100644 --- a/src/netius/extra/dhcp_s.py +++ b/src/netius/extra/dhcp_s.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -42,9 +33,10 @@ import netius.common import netius.servers + class DHCPServerS(netius.servers.DHCPServer): - def __init__(self, pool = None, options = {}, *args, **kwargs): + def __init__(self, pool=None, options={}, *args, **kwargs): netius.servers.DHCPServer.__init__(self, *args, **kwargs) self.pool = pool or netius.common.AddressPool("192.168.0.61", "192.168.0.69") @@ -58,12 +50,15 @@ def get_type(self, request): requested = request.get_requested() mac = request.get_mac() - if type == 0x01: result = netius.common.OFFER_DHCP + if type == 0x01: + result = netius.common.OFFER_DHCP elif type == 0x03: current = self.pool.assigned(mac) or requested is_owner = self.pool.is_owner(mac, current) - if is_owner: result = netius.common.ACK_DHCP - else: result = netius.common.NAK_DHCP + if is_owner: + result = netius.common.ACK_DHCP + else: + result = netius.common.NAK_DHCP return result @@ -73,8 +68,10 @@ def get_options(self, request): def get_yiaddr(self, request): type = request.get_type() - if type == 0x01: yiaddr = self._reserve(request) - elif type == 0x03: yiaddr = self._confirm(request) + if type == 0x01: + yiaddr = self._reserve(request) + elif type == 0x03: + yiaddr = self._confirm(request) return yiaddr def _build(self, options): @@ -83,47 +80,42 @@ def _build(self, options): for key, value in netius.legacy.iteritems(options): key_i = netius.common.OPTIONS_DHCP.get(key, None) - if not key_i: continue + if not key_i: + continue self.options[key_i] = value def _reserve(self, request): mac = request.get_mac() - return self.pool.reserve( - owner = mac, - lease = self.lease - ) + return self.pool.reserve(owner=mac, lease=self.lease) def _confirm(self, request): requested = request.get_requested() mac = request.get_mac() current = self.pool.assigned(mac) or requested is_valid = self.pool.is_valid(current) - if is_valid: self.pool.touch(current, self.lease) + if is_valid: + self.pool.touch(current, self.lease) return current + if __name__ == "__main__": import logging + host = netius.common.host() pool = netius.common.AddressPool("172.16.0.80", "172.16.0.89") options = dict( - router = dict(routers = ["172.16.0.6"]), - subnet = dict(subnet = "255.255.0.0"), - dns = dict( - servers = ["172.16.0.11", "172.16.0.12"] - ), - identifier = dict(identifier = host), - broadcast = dict(broadcast = "172.16.255.255"), - name = dict(name = "hive"), - lease = dict(time = 3600), - renewal = dict(time = 1800), - rebind = dict(time = 2700), - proxy = dict(url = "http://172.16.0.25:8080/proxy.pac") - ) - server = DHCPServerS( - pool = pool, - options = options, - level = logging.INFO + router=dict(routers=["172.16.0.6"]), + subnet=dict(subnet="255.255.0.0"), + dns=dict(servers=["172.16.0.11", "172.16.0.12"]), + identifier=dict(identifier=host), + broadcast=dict(broadcast="172.16.255.255"), + name=dict(name="hive"), + lease=dict(time=3600), + renewal=dict(time=1800), + rebind=dict(time=2700), + proxy=dict(url="http://172.16.0.25:8080/proxy.pac"), ) - server.serve(env = True) + server = DHCPServerS(pool=pool, options=options, level=logging.INFO) + server.serve(env=True) else: __path__ = [] diff --git a/src/netius/extra/file.py b/src/netius/extra/file.py index bd4188f94..cfb4815db 100644 --- a/src/netius/extra/file.py +++ b/src/netius/extra/file.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -50,11 +41,11 @@ sending the file to the client, this should not be neither to big nor to small (as both situations would create problems) """ -FOLDER_SVG = "" +FOLDER_SVG = '' """ The vector code to be used for the icon that represents a folder under the directory listing """ -FILE_SVG = "" +FILE_SVG = '' """ The vector code to be used for the icon that represents a plain file under the directory listing """ @@ -62,6 +53,7 @@ """ Simple base 64 encoded empty gif to avoid possible image corruption while rendering empty images on browser """ + class FileServer(netius.servers.HTTP2Server): """ Simple implementation of a file server that is able to list files @@ -76,14 +68,14 @@ class FileServer(netius.servers.HTTP2Server): def __init__( self, - base_path = "", - style_urls = [], - index_files = [], - path_regex = [], - list_dirs = True, - list_engine = "base", - cors = False, - cache = 0, + base_path="", + style_urls=[], + index_files=[], + path_regex=[], + list_dirs=True, + list_engine="base", + cors=False, + cache=0, *args, **kwargs ): @@ -98,7 +90,7 @@ def __init__( self.cache = cache @classmethod - def _sorter_build(cls, name = None): + def _sorter_build(cls, name=None): def sorter(item): is_dir = item["is_dir"] @@ -106,32 +98,32 @@ def sorter(item): is_dir_v = 0 if is_dir else 1 is_dir_v = -1 if is_top else is_dir_v - if name == "name": return (item["name"], is_dir_v) - if name == "modified": return (item["modified"], is_dir_v) - if name == "size": return (item["size"], is_dir_v) - if name == "type": return (item["type"], is_dir_v) + if name == "name": + return (item["name"], is_dir_v) + if name == "modified": + return (item["modified"], is_dir_v) + if name == "size": + return (item["size"], is_dir_v) + if name == "type": + return (item["type"], is_dir_v) return (is_dir_v, item["name"]) return sorter @classmethod - def _items_normalize( - cls, - items, - path, - pad = False, - space = True, - simplified = False - ): + def _items_normalize(cls, items, path, pad=False, space=True, simplified=False): _items = [] for item in items: - if netius.legacy.PYTHON_3: item_s = item - else: item_s = item.encode("utf-8") + if netius.legacy.PYTHON_3: + item_s = item + else: + item_s = item.encode("utf-8") path_f = os.path.join(path, item) - if not os.path.exists(path_f): continue + if not os.path.exists(path_f): + continue is_dir = os.path.isdir(path_f) item_s = item_s + "/" if is_dir and pad else item_s @@ -143,30 +135,28 @@ def _items_normalize( size = 0 if is_dir else os.path.getsize(path_f) size_s = netius.common.size_round_unit( - size, - space = space, - simplified = simplified + size, space=space, simplified=simplified ) size_s = "-" if is_dir else size_s - type_s, _encoding = mimetypes.guess_type(path_f, strict = True) + type_s, _encoding = mimetypes.guess_type(path_f, strict=True) type_s = type_s or "-" type_s = "Directory" if is_dir else type_s icon = FOLDER_SVG if is_dir else FILE_SVG _item = dict( - name = item, - name_s = item_s, - name_q = item_q, - is_dir = is_dir, - path = path_f, - modified = time_s, - size = size, - size_s = size_s, - type = type_s, - type_s = type_s, - icon = icon + name=item, + name_s=item_s, + name_q=item_q, + is_dir=is_dir, + path=path_f, + modified=time_s, + size=size, + size_s=size_s, + type=type_s, + type_s=type_s, + icon=icon, ) _items.append(_item) @@ -174,19 +164,16 @@ def _items_normalize( return _items @classmethod - def _gen_dir(cls, engine, path, path_v, query_m, style = True, style_urls = [], **kwargs): + def _gen_dir( + cls, engine, path, path_v, query_m, style=True, style_urls=[], **kwargs + ): gen_dir_method = getattr(cls, "_gen_dir_" + engine) return gen_dir_method( - path, - path_v, - query_m, - style = style, - style_urls = style_urls, - **kwargs + path, path_v, query_m, style=style, style_urls=style_urls, **kwargs ) @classmethod - def _gen_dir_base(cls, path, path_v, query_m, style = True, style_urls = [], **kwargs): + def _gen_dir_base(cls, path, path_v, query_m, style=True, style_urls=[], **kwargs): sort = query_m.get("sort", []) direction = query_m.get("direction", []) @@ -199,14 +186,12 @@ def _gen_dir_base(cls, path, path_v, query_m, style = True, style_urls = [], **k items = os.listdir(path) is_root = path_v == "" or path_v == "/" - if not is_root: items.insert(0, "..") + if not is_root: + items.insert(0, "..") - items = cls._items_normalize(items, path, pad = not style) - items.sort(key = lambda v: v["name"]) - items.sort( - key = cls._sorter_build(name = sort), - reverse = reverse - ) + items = cls._items_normalize(items, path, pad=not style) + items.sort(key=lambda v: v["name"]) + items.sort(key=cls._sorter_build(name=sort), reverse=reverse) path_n = path_v.rstrip("/") @@ -216,8 +201,9 @@ def _gen_dir_base(cls, path, path_v, query_m, style = True, style_urls = [], **k for item in paths[:-1]: current += item + "/" - path_b.append(" %s " % (current, item or "/")) - if not item: continue + path_b.append(' %s ' % (current, item or "/")) + if not item: + continue path_b.append("/") path_b.append(" %s" % (paths[-1] or "/")) @@ -225,33 +211,39 @@ def _gen_dir_base(cls, path, path_v, query_m, style = True, style_urls = [], **k path_s = path_s.strip() for value in cls._gen_header( - "Index of %s" % (path_n or "/"), - style = style, - style_urls = style_urls + "Index of %s" % (path_n or "/"), style=style, style_urls=style_urls ): yield value yield "" - yield "

Index of %s

" % path_s + yield '

Index of %s

' % path_s yield "
" yield "" yield "" yield "" - yield "" - yield "" - yield "" %\ - (_direction, "selected" if sort == "size" else "") + yield '' % ( + _direction, + "selected" if sort == "size" else "", + ) yield "" - yield "" %\ - (_direction, "selected" if sort == "type" else "") + yield '' % ( + _direction, + "selected" if sort == "type" else "", + ) yield "" yield "" yield "" @@ -259,8 +251,9 @@ def _gen_dir_base(cls, path, path_v, query_m, style = True, style_urls = [], **k for item in items: yield "" yield "" yield "" % item["modified"] yield "" % item["size_s"] @@ -274,7 +267,8 @@ def _gen_dir_base(cls, path, path_v, query_m, style = True, style_urls = [], **k yield "" yield "" - for value in cls._gen_footer(): yield value + for value in cls._gen_footer(): + yield value @classmethod def _gen_dir_apache(cls, path, path_v, query_m, **kwargs): @@ -292,66 +286,78 @@ def _gen_dir_apache(cls, path, path_v, query_m, **kwargs): items.insert(0, "..") items = cls._items_normalize( - items, - path, - pad = True, - space = False, - simplified = True - ) - items.sort(key = lambda v: v["name"]) - items.sort( - key = cls._sorter_build(name = sort), - reverse = reverse + items, path, pad=True, space=False, simplified=True ) + items.sort(key=lambda v: v["name"]) + items.sort(key=cls._sorter_build(name=sort), reverse=reverse) path_n = path_v.rstrip("/") - for value in cls._gen_header("Index of %s" % (path_n or "/"), style = False, meta = False): + for value in cls._gen_header( + "Index of %s" % (path_n or "/"), style=False, meta=False + ): yield value yield "" - yield "

Index of %s

" % (path_n or "/") + yield '

Index of %s

' % (path_n or "/") yield "
" - yield "Name" %\ - (_direction, "selected" if sort == "name" else "") + yield '' + yield 'Name' % ( + _direction, + "selected" if sort == "name" else "", + ) yield "" - yield "Last Modified" %\ - (_direction, "selected" if sort == "modified" else "") + yield '' + yield 'Last Modified' % ( + _direction, + "selected" if sort == "modified" else "", + ) yield "" - yield "Size' + yield 'Size" - yield "Type' + yield 'Type
" - if style: yield item["icon"] - yield "%s" % (item["name_q"], item["name_s"]) + if style: + yield item["icon"] + yield '%s' % (item["name_q"], item["name_s"]) yield "%s%s
" yield "" - yield "" % EMPTY_GIF + yield '' % EMPTY_GIF yield "" yield "" yield "" %\ - (_direction, "selected" if sort == "size" else "") + yield 'Size' % ( + _direction, + "selected" if sort == "size" else "", + ) yield "" yield "" %\ - (_direction, "selected" if sort == "description" else "") + yield 'Description' % ( + _direction, + "selected" if sort == "description" else "", + ) yield "" yield "" - yield "" + yield '' for item in items: - if item["name_s"] == "../": type_s = "PARENTDIR" - elif item["is_dir"]: type_s = "DIR" - else: type_s = "ARC" - if item["name_s"] == "../": name_s = "Parent Directory" - else: name_s = item["name_s"] + if item["name_s"] == "../": + type_s = "PARENTDIR" + elif item["is_dir"]: + type_s = "DIR" + else: + type_s = "ARC" + if item["name_s"] == "../": + name_s = "Parent Directory" + else: + name_s = item["name_s"] yield "" - yield "" % (EMPTY_GIF, type_s) - yield "" % (item["name_q"], name_s) + yield '' % ( + EMPTY_GIF, + type_s, + ) + yield '' % (item["name_q"], name_s) yield "" % item["modified"] - yield "" % item["size_s"] + yield '' % item["size_s"] yield "" % item["type_s"] yield "" yield "\n" - yield "" + yield '' yield "
\"[ICO]\"[ICO]" - yield "Name" %\ - (_direction, "selected" if sort == "name" else "") + yield 'Name' % ( + _direction, + "selected" if sort == "name" else "", + ) yield "" - yield "Last modified" %\ - (_direction, "selected" if sort == "modified" else "") + yield 'Last modified' % ( + _direction, + "selected" if sort == "modified" else "", + ) yield "" - yield "Size" - yield "Description


\"[%s]\"%s[%s]%s%s%s%s%s


" yield "
%s
" % netius.IDENTIFIER yield "" - for value in cls._gen_footer(): yield value + for value in cls._gen_footer(): + yield value @classmethod def _gen_dir_legacy(cls, path, path_v, query_m, **kwargs): @@ -372,17 +378,10 @@ def _gen_dir_legacy(cls, path, path_v, query_m, **kwargs): items.insert(0, "..") items = cls._items_normalize( - items, - path, - pad = True, - space = False, - simplified = True - ) - items.sort(key = lambda v: v["name"]) - items.sort( - key = cls._sorter_build(name = sort), - reverse = reverse + items, path, pad=True, space=False, simplified=True ) + items.sort(key=lambda v: v["name"]) + items.sort(key=cls._sorter_build(name=sort), reverse=reverse) max_length = max([len(item["name_s"]) for item in items] + [max_length]) padding_s = (max_length + spacing - 4) * " " @@ -390,34 +389,47 @@ def _gen_dir_legacy(cls, path, path_v, query_m, **kwargs): path_n = path_v.rstrip("/") - for value in cls._gen_header("Index of %s" % (path_n or "/"), style = False, meta = False): + for value in cls._gen_header( + "Index of %s" % (path_n or "/"), style=False, meta=False + ): yield value yield "" - yield "

Index of %s

" % (path_n or "/") + yield '

Index of %s

' % (path_n or "/") yield "
" yield "
"
-        yield "\"Icon" % EMPTY_GIF
-        yield "Name" %\
-            (_direction, "selected" if sort == "name" else "")
+        yield 'Icon ' % EMPTY_GIF
+        yield 'Name' % (
+            _direction,
+            "selected" if sort == "name" else "",
+        )
         yield padding_s
-        yield "Last modified" %\
-            (_direction, "selected" if sort == "modified" else "")
+        yield 'Last modified' % (
+            _direction,
+            "selected" if sort == "modified" else "",
+        )
         yield "   "
         yield spacing_s
-        yield "Size" %\
-            (_direction, "selected" if sort == "size" else "")
+        yield 'Size' % (
+            _direction,
+            "selected" if sort == "size" else "",
+        )
         yield "
" for item in items: - if item["name_s"] == "../": type_s = "PARENTDIR" - elif item["is_dir"]: type_s = "DIR" - else: type_s = "ARC" - if item["name_s"] == "../": name_s = "Parent Directory" - else: name_s = item["name_s"] + if item["name_s"] == "../": + type_s = "PARENTDIR" + elif item["is_dir"]: + type_s = "DIR" + else: + type_s = "ARC" + if item["name_s"] == "../": + name_s = "Parent Directory" + else: + name_s = item["name_s"] name_s = name_s[:max_length] padding_r = max_length - len(name_s) - yield "\"[%s]\"" % (EMPTY_GIF, type_s) - yield "%s" % (item["name_q"], name_s) + yield '[%s]' % (EMPTY_GIF, type_s) + yield '%s' % (item["name_q"], name_s) yield " " * padding_r yield spacing_s yield "%s%s%s" % (item["modified"], spacing_s, item["size_s"].ljust(5)) @@ -428,13 +440,15 @@ def _gen_dir_legacy(cls, path, path_v, query_m, **kwargs): yield "
%s
" % netius.IDENTIFIER yield "" - for value in cls._gen_footer(): yield value + for value in cls._gen_footer(): + yield value def on_connection_d(self, connection): netius.servers.HTTP2Server.on_connection_d(self, connection) file = hasattr(connection, "file") and connection.file - if file: file.close() + if file: + file.close() setattr(connection, "file", None) setattr(connection, "range", None) setattr(connection, "bytes_p", None) @@ -442,7 +456,8 @@ def on_connection_d(self, connection): def on_stream_d(self, stream): file = hasattr(stream, "file") and stream.file - if file: file.close() + if file: + file.close() setattr(stream, "file", None) setattr(stream, "range", None) setattr(stream, "bytes_p", None) @@ -450,26 +465,39 @@ def on_stream_d(self, stream): def on_serve(self): netius.servers.HTTP2Server.on_serve(self) - if self.env: self.base_path = self.get_env("BASE_PATH", self.base_path) - if self.env: self.style_urls = self.get_env("STYLE_URLS", self.style_urls, cast = list) - if self.env: self.index_files = self.get_env("INDEX_FILES", self.index_files, cast = list) - if self.env: self.path_regex = self.get_env( - "PATH_REGEX", - self.path_regex, - cast = lambda v: [i.split(":") for i in v.split(";")] - ) - if self.env: self.list_dirs = self.get_env("LIST_DIRS", self.list_dirs, cast = bool) - if self.env: self.list_engine = self.get_env("LIST_ENGINE", self.list_engine) - if self.env: self.cors = self.get_env("CORS", self.cors, cast = bool) - if self.env: self.cache = self.get_env("CACHE", self.cache, cast = int) + if self.env: + self.base_path = self.get_env("BASE_PATH", self.base_path) + if self.env: + self.style_urls = self.get_env("STYLE_URLS", self.style_urls, cast=list) + if self.env: + self.index_files = self.get_env("INDEX_FILES", self.index_files, cast=list) + if self.env: + self.path_regex = self.get_env( + "PATH_REGEX", + self.path_regex, + cast=lambda v: [i.split(":") for i in v.split(";")], + ) + if self.env: + self.list_dirs = self.get_env("LIST_DIRS", self.list_dirs, cast=bool) + if self.env: + self.list_engine = self.get_env("LIST_ENGINE", self.list_engine) + if self.env: + self.cors = self.get_env("CORS", self.cors, cast=bool) + if self.env: + self.cache = self.get_env("CACHE", self.cache, cast=int) self._build_regex() self.base_path = os.path.abspath(self.base_path) - self.cache_d = datetime.timedelta(seconds = self.cache) - self.base_path = netius.legacy.u(self.base_path, force = True) - self.info("Defining '%s' as the root of the file server ..." % (self.base_path or ".")) - if self.list_dirs: self.info("Listing directories with '%s' engine ..." % self.list_engine) - if self.cors: self.info("Cross origin resource sharing is enabled") - if self.cache: self.info("Resource cache set with %d seconds" % self.cache) + self.cache_d = datetime.timedelta(seconds=self.cache) + self.base_path = netius.legacy.u(self.base_path, force=True) + self.info( + "Defining '%s' as the root of the file server ..." % (self.base_path or ".") + ) + if self.list_dirs: + self.info("Listing directories with '%s' engine ..." % self.list_engine) + if self.cors: + self.info("Cross origin resource sharing is enabled") + if self.cache: + self.info("Resource cache set with %d seconds" % self.cache) def on_data_http(self, connection, parser): netius.servers.HTTP2Server.on_data_http(self, connection, parser) @@ -479,7 +507,8 @@ def on_data_http(self, connection, parser): # handled by the connection and so the current data processing # must be delayed until the file is processed (inserted in queue) if hasattr(connection, "file") and connection.file: - if not hasattr(connection, "queue"): connection.queue = [] + if not hasattr(connection, "queue"): + connection.queue = [] state = parser.get_state() connection.queue.append(state) return @@ -488,10 +517,10 @@ def on_data_http(self, connection, parser): # retrieves the requested path from the parser and the constructs # the correct file name/path to be used in the reading from the # current file system, so that it's possible to handle the data - path = parser.get_path(normalize = True) + path = parser.get_path(normalize=True) path = netius.legacy.unquote(path) path = path.lstrip("/") - path = netius.legacy.u(path, force = True) + path = netius.legacy.u(path, force=True) path = self._resolve(path) path_f = os.path.join(self.base_path, path) path_f = os.path.abspath(path_f) @@ -501,31 +530,36 @@ def on_data_http(self, connection, parser): # it's required to decode the path into an unicode string, if that's # the case the normal decoding process is used using the currently # defined file system encoding as defined in the specification - path_f = netius.legacy.u(path_f, encoding = "utf-8", force = True) + path_f = netius.legacy.u(path_f, encoding="utf-8", force=True) # verifies if the provided path starts with the contents of the # base path in case it does not it's a security issue and a proper # exception must be raised indicating the issue is_sub = path_f.startswith(self.base_path) - if not is_sub: raise netius.SecurityError("Invalid path") + if not is_sub: + raise netius.SecurityError("Invalid path") # verifies if the requested file exists in case it does not # raises an error indicating the problem so that the user is # notified about the failure to find the appropriate file - if not os.path.exists(path_f): self.on_no_file(connection); return + if not os.path.exists(path_f): + self.on_no_file(connection) + return # verifies if the currently resolved path refers an directory or # instead a normal file and handles each of the cases properly by # redirecting the request to the proper handlers is_dir = os.path.isdir(path_f) - if is_dir: self.on_dir_file(connection, parser, path_f) - else: self.on_normal_file(connection, parser, path_f) + if is_dir: + self.on_dir_file(connection, parser, path_f) + else: + self.on_normal_file(connection, parser, path_f) except BaseException as exception: # handles the exception gracefully by sending the contents of # it to the client and identifying the problem correctly self.on_exception_file(connection, exception) - def on_dir_file(self, connection, parser, path, style = True): + def on_dir_file(self, connection, parser, path, style=True): cls = self.__class__ path_v = parser.get_path() @@ -538,43 +572,44 @@ def on_dir_file(self, connection, parser, path, style = True): if not is_valid: path_q = netius.legacy.quote(path_v) connection.send_response( - data = "Permanent redirect", - headers = dict( - location = path_q + "/" - ), - code = 301, - apply = True + data="Permanent redirect", + headers=dict(location=path_q + "/"), + code=301, + apply=True, ) return for index_file in self.index_files: index_path = os.path.join(path, index_file) - if not os.path.exists(index_path): continue + if not os.path.exists(index_path): + continue return self.on_normal_file(connection, parser, index_path) if not self.list_dirs: self.on_no_file(connection) return - data = "".join(cls._gen_dir( - self.list_engine, - path, - path_v, - query_m, - style = style, - style_urls = self.style_urls - )) - data = netius.legacy.bytes(data, encoding = "utf-8", force = True) + data = "".join( + cls._gen_dir( + self.list_engine, + path, + path_v, + query_m, + style=style, + style_urls=self.style_urls, + ) + ) + data = netius.legacy.bytes(data, encoding="utf-8", force=True) headers = dict() headers["content-type"] = "text/html" connection.send_response( - data = data, - headers = headers, - code = 200, - apply = True, - callback = self._file_check_close + data=data, + headers=headers, + code=200, + apply=True, + callback=self._file_check_close, ) def on_normal_file(self, connection, parser, path): @@ -605,13 +640,15 @@ def on_normal_file(self, connection, parser, path): # in case the file did not change in the mean time the not modified # callback must be called to correctly handled the file no change - if not_modified: self.on_not_modified(connection, path); return + if not_modified: + self.on_not_modified(connection, path) + return # tries to guess the mime type of the file present in the target # file path that is going to be returned, this may fail as it's not # always possible to determine the correct mime type for a file # for suck situations the default mime type is used - type, _encoding = mimetypes.guess_type(path, strict = True) + type, _encoding = mimetypes.guess_type(path, strict=True) type = type or "application/octet-stream" # retrieves the size of the file that has just be resolved using @@ -630,7 +667,8 @@ def on_normal_file(self, connection, parser, path): start = int(start_s) if start_s else 0 end = int(end_s) if end_s else file_size - 1 range = (start, end) - else: range = (0, file_size - 1) + else: + range = (0, file_size - 1) # calculates the real data size of the chunk that is going to be # sent to the client this must use the normal range approach @@ -655,10 +693,14 @@ def on_normal_file(self, connection, parser, path): headers = dict() headers["etag"] = etag headers["content-length"] = "%d" % data_size - if self.cors: headers["access-control-allow-origin"] = "*" - if type: headers["content-type"] = type - if is_partial: headers["content-range"] = content_range_s - if not is_partial: headers["accept-ranges"] = "bytes" + if self.cors: + headers["access-control-allow-origin"] = "*" + if type: + headers["content-type"] = type + if is_partial: + headers["content-range"] = content_range_s + if not is_partial: + headers["accept-ranges"] = "bytes" # in case there's a valid cache defined must populate the proper header # fields so that cache is applied to the request @@ -679,68 +721,61 @@ def on_normal_file(self, connection, parser, path): # operation is the initial sending of the file contents so that the # sending of the proper file contents starts with success connection.send_response( - headers = headers, - code = code, - apply = True, - final = False, - flush = False, - callback = self._file_send + headers=headers, + code=code, + apply=True, + final=False, + flush=False, + callback=self._file_send, ) def on_no_file(self, connection): cls = self.__class__ connection.send_response( - data = cls.build_text( - "File not found", - style_urls = self.style_urls - ), - headers = dict( - connection = "close" - ), - code = 404, - apply = True, - callback = self._file_close + data=cls.build_text("File not found", style_urls=self.style_urls), + headers=dict(connection="close"), + code=404, + apply=True, + callback=self._file_close, ) def on_exception_file(self, connection, exception): cls = self.__class__ connection.send_response( - data = cls.build_text( + data=cls.build_text( "Problem handling request - %s" % str(exception), - trace = self.is_devel(), - style_urls = self.style_urls - ), - headers = dict( - connection = "close" + trace=self.is_devel(), + style_urls=self.style_urls, ), - code = 500, - apply = True, - callback = self._file_close + headers=dict(connection="close"), + code=500, + apply=True, + callback=self._file_close, ) def on_not_modified(self, connection, path): connection.set_encoding(netius.common.PLAIN_ENCODING) - connection.send_response( - data = "", - code = 304, - apply = True - ) + connection.send_response(data="", code=304, apply=True) def _next_queue(self, connection): # verifies if the current connection already contains a reference to # the queue structure that handles the queuing/pipelining of requests # if it does not or the queue is empty returns immediately, as there's # nothing currently pending to be done/processed - if not hasattr(connection, "queue"): return - if not connection.queue: return + if not hasattr(connection, "queue"): + return + if not connection.queue: + return # retrieves the state (of the parser) as the payload of the next element # in the queue and then uses it to construct a mock parser object that is # going to be used to simulate an on data call to the file server state = connection.queue.pop(0) parser = netius.common.HTTPParser.mock(connection.parser.owner, state) - try: self.on_data_http(connection, parser) - finally: parser.destroy() + try: + self.on_data_http(connection, parser) + finally: + parser.destroy() def _file_send(self, connection): file = connection.file @@ -752,11 +787,7 @@ def _file_send(self, connection): connection.bytes_p -= data_l is_final = not data or connection.bytes_p == 0 callback = self._file_finish if is_final else self._file_send - connection.send_part( - data, - final = False, - callback = callback - ) + connection.send_part(data, final=False, callback=callback) def _file_finish(self, connection): connection.file.close() @@ -765,34 +796,40 @@ def _file_finish(self, connection): connection.bytes_p = None is_keep_alive = connection.parser.keep_alive callback = None if is_keep_alive else self._file_close - connection.flush_s(callback = callback) + connection.flush_s(callback=callback) self._next_queue(connection) def _file_close(self, connection): - connection.close(flush = True) + connection.close(flush=True) def _file_check_close(self, connection): - if connection.parser.keep_alive: return - connection.close(flush = True) + if connection.parser.keep_alive: + return + connection.close(flush=True) def _resolve(self, path): path, result = self._resolve_regex(path) - if result: return path + if result: + return path return path def _build_regex(self): - self.path_regex = [(re.compile(regex), value) for regex, value in self.path_regex] + self.path_regex = [ + (re.compile(regex), value) for regex, value in self.path_regex + ] def _resolve_regex(self, path): for regex, value in self.path_regex: - if not regex.match(path): continue + if not regex.match(path): + continue return (value, True) return (path, False) + if __name__ == "__main__": import logging - server = FileServer(level = logging.INFO) - server.serve(env = True) + server = FileServer(level=logging.INFO) + server.serve(env=True) else: __path__ = [] diff --git a/src/netius/extra/filea.py b/src/netius/extra/filea.py index 19eb92da8..b460ea7c6 100644 --- a/src/netius/extra/filea.py +++ b/src/netius/extra/filea.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -44,6 +35,7 @@ handles more data for each chunk, this is required to avoid extreme amounts of overhead in the file pool """ + class FileAsyncServer(_file.FileServer): """ Simple implementation of a file server that uses the async @@ -60,12 +52,16 @@ class FileAsyncServer(_file.FileServer): def on_connection_d(self, connection): file = hasattr(connection, "file") and connection.file - if file: self.fclose(file); connection.file = None + if file: + self.fclose(file) + connection.file = None _file.FileServer.on_connection_d(self, connection) def on_stream_d(self, stream): file = hasattr(stream, "file") and stream.file - if file: self.fclose(file); stream.file = None + if file: + self.fclose(file) + stream.file = None _file.FileServer.on_stream_d(self, stream) def _file_send(self, connection): @@ -75,24 +71,23 @@ def _file_send(self, connection): buffer_s = connection.bytes_p if is_larger else BUFFER_SIZE def callback(data, *args, **kwargs): - if connection.file == None: return - if isinstance(data, BaseException): return + if connection.file == None: + return + if isinstance(data, BaseException): + return data_l = len(data) if data else 0 connection.bytes_p -= data_l is_final = not data or connection.bytes_p == 0 callback = self._file_finish if is_final else self._file_send - connection.send_part( - data, - final = False, - callback = callback - ) + connection.send_part(data, final=False, callback=callback) + + self.fread(file, buffer_s, data=callback) - self.fread(file, buffer_s, data = callback) if __name__ == "__main__": import logging - server = FileAsyncServer(level = logging.INFO) - server.serve(env = True) + server = FileAsyncServer(level=logging.INFO) + server.serve(env=True) else: __path__ = [] diff --git a/src/netius/extra/hello.py b/src/netius/extra/hello.py index 500c62972..c39ece690 100644 --- a/src/netius/extra/hello.py +++ b/src/netius/extra/hello.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -39,6 +30,7 @@ import netius.servers + class HelloServer(netius.servers.HTTP2Server): """ Simple Hello (World) HTTP server meant to be used for benchmarks @@ -51,36 +43,33 @@ class HelloServer(netius.servers.HTTP2Server): or adding new features to this server implementation. """ - def __init__(self, message = "Hello World", *args, **kwargs): + def __init__(self, message="Hello World", *args, **kwargs): netius.servers.HTTP2Server.__init__(self, *args, **kwargs) self.message = message def on_serve(self): netius.servers.HTTP2Server.on_serve(self) - if self.env: self.message = self.get_env("MESSAGE", self.message, cast = str) - if self.env: self.keep_alive = self.get_env("KEEP_ALIVE", True, cast = bool) + if self.env: + self.message = self.get_env("MESSAGE", self.message, cast=str) + if self.env: + self.keep_alive = self.get_env("KEEP_ALIVE", True, cast=bool) self.info("Serving '%s' as welcome message ..." % self.message) def on_data_http(self, connection, parser): - netius.servers.HTTP2Server.on_data_http( - self, connection, parser - ) + netius.servers.HTTP2Server.on_data_http(self, connection, parser) keep_alive = self.keep_alive and parser.keep_alive callback = self._hello_keep if keep_alive else self._hello_close connection_s = "keep-alive" if keep_alive else "close" - headers = { - "Connection" : connection_s, - "Content-Type" : "text/plain" - } + headers = {"Connection": connection_s, "Content-Type": "text/plain"} connection.send_response( - data = self.message, - headers = headers, - code = 200, - code_s = "OK", - apply = True, - callback = callback + data=self.message, + headers=headers, + code=200, + code_s="OK", + apply=True, + callback=callback, ) def _hello_close(self, connection): @@ -89,9 +78,11 @@ def _hello_close(self, connection): def _hello_keep(self, connection): pass + if __name__ == "__main__": import logging - server = HelloServer(level = logging.INFO) - server.serve(env = True) + + server = HelloServer(level=logging.INFO) + server.serve(env=True) else: __path__ = [] diff --git a/src/netius/extra/hello_w.py b/src/netius/extra/hello_w.py index 04e2b9122..a111aa20e 100644 --- a/src/netius/extra/hello_w.py +++ b/src/netius/extra/hello_w.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -39,6 +30,7 @@ import netius.servers + def app(environ, start_response): status = "200 OK" contents = "Hello World" @@ -46,13 +38,14 @@ def app(environ, start_response): headers = ( ("Content-Length", content_l), ("Content-type", "text/plain"), - ("Connection", "keep-alive") + ("Connection", "keep-alive"), ) start_response(status, headers) yield contents + if __name__ == "__main__": - server = netius.servers.WSGIServer(app = app) - server.serve(env = True) + server = netius.servers.WSGIServer(app=app) + server.serve(env=True) else: __path__ = [] diff --git a/src/netius/extra/proxy_d.py b/src/netius/extra/proxy_d.py index e4bfc082d..27df91c3d 100644 --- a/src/netius/extra/proxy_d.py +++ b/src/netius/extra/proxy_d.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -43,6 +34,7 @@ from . import proxy_r + class DockerProxyServer(proxy_r.ReverseProxyServer): """ Specialized reverse proxy server that handles many of the @@ -55,18 +47,17 @@ class DockerProxyServer(proxy_r.ReverseProxyServer): of the Docker proxy servers is not limited to such use cases. """ - def __init__(self, host_suffixes = [], *args, **kwargs): + def __init__(self, host_suffixes=[], *args, **kwargs): proxy_r.ReverseProxyServer.__init__(self, *args, **kwargs) - self.load_config(host_suffixes = host_suffixes) + self.load_config(host_suffixes=host_suffixes) self._build_docker() def on_serve(self): proxy_r.ReverseProxyServer.on_serve(self) - if self.env: self.host_suffixes = self.get_env( - "HOST_SUFFIXES", - self.host_suffixes, - cast = list - ) + if self.env: + self.host_suffixes = self.get_env( + "HOST_SUFFIXES", self.host_suffixes, cast=list + ) self._build_suffixes() self._build_redirect_ssl() @@ -79,7 +70,7 @@ def _build_docker(self): self._build_error_urls() self._build_redirect_ssl() - def _build_regex(self, token = "$", sort = True): + def _build_regex(self, token="$", sort=True): # retrieves the complete set of configuration values with the # regex suffix so that they are going to be used for the creation # of the regex rules (as expected) @@ -88,20 +79,23 @@ def _build_regex(self, token = "$", sort = True): # retrieves the complete set of names from the linked items and then # in case the sort flag is set sorts their values (proper order) names = netius.legacy.keys(linked) - if sort: names.sort() + if sort: + names.sort() # iterates over the complete set of linked regex values splitting # the values around the proper token and adding them to the regex for name in names: value = linked[name] value_s = value.split(token, 1) - if not len(value_s) == 2: continue + if not len(value_s) == 2: + continue regex, target = value_s - if not self._valid_url(target): continue + if not self._valid_url(target): + continue rule = (re.compile(regex), target) self.regex.append(rule) - def _build_hosts(self, alias = True): + def _build_hosts(self, alias=True): # tries to retrieve the complete set of configuration # values associated with the port suffix, this represents # the possible linked container addresses @@ -122,19 +116,24 @@ def _build_hosts(self, alias = True): # in case this port value represent a service name_ref = base.upper() + "_NAME" name_value = netius.conf(name_ref, None) - if not name_value: continue + if not name_value: + continue # runs a series of validation on both the base and name # value to make sure that this value represents a valid # linked service/container # linked service/container (valid name reference found) - if name.endswith("_ENV_PORT"): continue - if not name.find("_ENV_") == -1: continue - if base[-1].isdigit() and name_value[-1].isdigit(): continue + if name.endswith("_ENV_PORT"): + continue + if not name.find("_ENV_") == -1: + continue + if base[-1].isdigit() and name_value[-1].isdigit(): + continue # validates that the provided host is a valid URL value and # if that's not the case continues the loop (ignores) - if not self._valid_url(host): continue + if not self._valid_url(host): + continue # replaces the prefix of the reference (assumes HTTP) and # then adds the base value to the registered hosts @@ -145,13 +144,16 @@ def _build_hosts(self, alias = True): # validates that the dashed version of the name is not the # same as the base one (at least one underscore) and if that's # not the case skips the current iteration - if base == base_dash: continue + if base == base_dash: + continue # checks if the alias based registration is enabled and adds # the dashed version as an alias for such case or as an host # otherwise (static registration) - if alias: self.alias[base_dash] = base - else: self.hosts[base_dash] = host + if alias: + self.alias[base_dash] = base + else: + self.hosts[base_dash] = host def _build_alias(self): linked = netius.conf_suffix("_ALIAS") @@ -166,7 +168,7 @@ def _build_passwords(self): for name, password in netius.legacy.iteritems(linked): base = name[:-9].lower() base_dash = base.replace("_", "-") - simple_auth = netius.SimpleAuth(password = password) + simple_auth = netius.SimpleAuth(password=password) self.auth[base] = simple_auth self.auth[base_dash] = simple_auth @@ -186,20 +188,22 @@ def _build_error_urls(self): self.error_urls[base] = error_url self.error_urls[base_dash] = error_url - def _build_redirect_ssl(self, alias = True): + def _build_redirect_ssl(self, alias=True): linked = netius.conf_suffix("_REDIRECT_SSL") for name, _force in netius.legacy.iteritems(linked): base = name[:-13].lower() base_dash = base.replace("_", "-") self.redirect[base] = (base, "https") self.redirect[base_dash] = (base_dash, "https") - if not alias: continue + if not alias: + continue for key, value in netius.legacy.iteritems(self.alias): is_match = value in (base, base_dash) - if not is_match: continue + if not is_match: + continue self.redirect[key] = (key, "https") - def _build_suffixes(self, alias = True, redirect = True): + def _build_suffixes(self, alias=True, redirect=True): for host_suffix in self.host_suffixes: self.info("Registering %s host suffix" % host_suffix) for alias, value in netius.legacy.items(self.alias): @@ -207,8 +211,10 @@ def _build_suffixes(self, alias = True, redirect = True): self.alias[fqn] = value for name, value in netius.legacy.items(self.hosts): fqn = name + "." + str(host_suffix) - if alias: self.alias[fqn] = name - else: self.hosts[fqn] = value + if alias: + self.alias[fqn] = name + else: + self.hosts[fqn] = value def _valid_url(self, value): """ @@ -227,12 +233,15 @@ def _valid_url(self, value): value = str(value) result = netius.legacy.urlparse(value) - if not result.scheme: return False - if not result.hostname: return False + if not result.scheme: + return False + if not result.hostname: + return False return True + if __name__ == "__main__": server = DockerProxyServer() - server.serve(env = True) + server.serve(env=True) else: __path__ = [] diff --git a/src/netius/extra/proxy_f.py b/src/netius/extra/proxy_f.py index f6d4c8810..03f06fecd 100644 --- a/src/netius/extra/proxy_f.py +++ b/src/netius/extra/proxy_f.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -42,11 +33,12 @@ import netius.common import netius.servers + class ForwardProxyServer(netius.servers.ProxyServer): - def __init__(self, config = "proxy.json", rules = {}, *args, **kwargs): + def __init__(self, config="proxy.json", rules={}, *args, **kwargs): netius.servers.ProxyServer.__init__(self, *args, **kwargs) - self.load_config(path = config, rules = rules) + self.load_config(path=config, rules=rules) self.compile() def on_headers(self, connection, parser): @@ -62,20 +54,19 @@ def on_headers(self, connection, parser): rejected = False for rule in self.rules.values(): rejected = rule.match(path) - if rejected: break + if rejected: + break if rejected: self.debug("This connection is not allowed") connection.send_response( - data = cls.build_text("This connection is not allowed"), - headers = dict( - connection = "close" - ), - version = version_s, - code = 403, - code_s = "Forbidden", - apply = True, - callback = self._prx_close + data=cls.build_text("This connection is not allowed"), + headers=dict(connection="close"), + version=version_s, + code=403, + code_s="Forbidden", + apply=True, + callback=self._prx_close, ) return @@ -91,21 +82,25 @@ def on_headers(self, connection, parser): proxy_c = hasattr(connection, "proxy_c") and connection.proxy_c proxy_c = proxy_c or None connection.proxy_c = None - if proxy_c in self.conn_map: del self.conn_map[proxy_c] + if proxy_c in self.conn_map: + del self.conn_map[proxy_c] encoding = headers.get("transfer-encoding", None) is_chunked = encoding == "chunked" - encoding = netius.common.CHUNKED_ENCODING if is_chunked else\ - netius.common.PLAIN_ENCODING + encoding = ( + netius.common.CHUNKED_ENCODING + if is_chunked + else netius.common.PLAIN_ENCODING + ) _connection = self.http_client.method( method, path, - headers = headers, - encoding = encoding, - encodings = None, - safe = True, - connection = proxy_c + headers=headers, + encoding=encoding, + encodings=None, + safe=True, + connection=proxy_c, ) self.debug("Setting connection as waiting, proxy connection loading ...") @@ -120,15 +115,12 @@ def compile(self): for key, rule in netius.legacy.items(self.rules): self.rules[key] = re.compile(rule) + if __name__ == "__main__": import logging - rules = dict( - facebook = ".*facebook.com.*" - ) - server = ForwardProxyServer( - rules = rules, - level = logging.INFO - ) - server.serve(env = True) + + rules = dict(facebook=".*facebook.com.*") + server = ForwardProxyServer(rules=rules, level=logging.INFO) + server.serve(env=True) else: __path__ = [] diff --git a/src/netius/extra/proxy_r.py b/src/netius/extra/proxy_r.py index fb214feaa..3a57cb13e 100644 --- a/src/netius/extra/proxy_r.py +++ b/src/netius/extra/proxy_r.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -49,6 +40,7 @@ to resolve host based map information, should not be changed as it represents the default way of setting values """ + class ReverseProxyServer(netius.servers.ProxyServer): """ Reverse HTTP proxy implementation based on the more generalized @@ -67,93 +59,107 @@ class ReverseProxyServer(netius.servers.ProxyServer): def __init__( self, - config = "proxy.json", - regex = {}, - hosts = {}, - alias = {}, - auth = {}, - auth_regex = {}, - redirect = {}, - redirect_regex = {}, - error_urls = {}, - forward = None, - strategy = "robin", - reuse = True, - sts = 0, - resolve = True, - resolve_t = 120.0, - host_f = False, - echo = False, + config="proxy.json", + regex={}, + hosts={}, + alias={}, + auth={}, + auth_regex={}, + redirect={}, + redirect_regex={}, + error_urls={}, + forward=None, + strategy="robin", + reuse=True, + sts=0, + resolve=True, + resolve_t=120.0, + host_f=False, + echo=False, *args, **kwargs ): netius.servers.ProxyServer.__init__(self, *args, **kwargs) - if isinstance(regex, dict): regex = netius.legacy.items(regex) - if isinstance(auth_regex, dict): auth_regex = netius.legacy.items(auth_regex) - if isinstance(redirect_regex, dict): redirect_regex = netius.legacy.items(redirect_regex) - if not isinstance(hosts, dict): hosts = dict(hosts) + if isinstance(regex, dict): + regex = netius.legacy.items(regex) + if isinstance(auth_regex, dict): + auth_regex = netius.legacy.items(auth_regex) + if isinstance(redirect_regex, dict): + redirect_regex = netius.legacy.items(redirect_regex) + if not isinstance(hosts, dict): + hosts = dict(hosts) self.load_config( - path = config, - regex = regex, - hosts = hosts, - alias = alias, - auth = auth, - auth_regex = auth_regex, - redirect = redirect, - redirect_regex = redirect_regex, - error_urls = error_urls, - forward = forward, - strategy = strategy, - reuse = reuse, - sts = sts, - resolve = resolve, - resolve_t = resolve_t, - host_f = host_f, - echo = echo, - robin = dict(), - smart = netius.common.PriorityDict() + path=config, + regex=regex, + hosts=hosts, + alias=alias, + auth=auth, + auth_regex=auth_regex, + redirect=redirect, + redirect_regex=redirect_regex, + error_urls=error_urls, + forward=forward, + strategy=strategy, + reuse=reuse, + sts=sts, + resolve=resolve, + resolve_t=resolve_t, + host_f=host_f, + echo=echo, + robin=dict(), + smart=netius.common.PriorityDict(), ) self.hosts_o = None self.busy_conn = 0 self._set_strategy() - def info_dict(self, full = False): - info = netius.servers.ProxyServer.info_dict(self, full = full) - info.update( - reuse = self.reuse, - strategy = self.strategy, - busy_conn = self.busy_conn - ) + def info_dict(self, full=False): + info = netius.servers.ProxyServer.info_dict(self, full=full) + info.update(reuse=self.reuse, strategy=self.strategy, busy_conn=self.busy_conn) return info def proxy_r_dict(self): - return dict( - hosts = self.hosts, - hosts_o = self.hosts_o - ) + return dict(hosts=self.hosts, hosts_o=self.hosts_o) def on_diag(self): self.diag_app.add_route("GET", "/proxy_r", self.proxy_r_dict) def on_start(self): netius.servers.ProxyServer.on_start(self) - if self.resolve: self.dns_start(timeout = self.resolve_t) + if self.resolve: + self.dns_start(timeout=self.resolve_t) def on_serve(self): netius.servers.ProxyServer.on_serve(self) - if self.env: self.sts = self.get_env("STS", self.sts, cast = int) - if self.env: self.echo = self.get_env("ECHO", self.echo, cast = bool) - if self.env: self.resolve = self.get_env("RESOLVE", self.resolve, cast = bool) - if self.env: self.resolve_t = self.get_env("RESOLVE_TIMEOUT", self.resolve_t, cast = float) - if self.env: self.host_f = self.get_env("HOST_FORWARD", self.host_f, cast = float) - if self.env: self.reuse = self.get_env("REUSE", self.reuse, cast = bool) - if self.env: self.strategy = self.get_env("STRATEGY", self.strategy) - if self.env: self.x_forwarded_port = self.get_env("X_FORWARDED_PORT", None) - if self.env: self.x_forwarded_proto = self.get_env("X_FORWARDED_PROTO", None) - if self.sts: self.info("Strict transport security set to %d seconds" % self.sts) - if self.resolve: self.info("DNS based resolution enabled in proxy with %.2fs timeout" % self.resolve_t) - if self.strategy: self.info("Using '%s' as load balancing strategy" % self.strategy) - if self.echo: self._echo() + if self.env: + self.sts = self.get_env("STS", self.sts, cast=int) + if self.env: + self.echo = self.get_env("ECHO", self.echo, cast=bool) + if self.env: + self.resolve = self.get_env("RESOLVE", self.resolve, cast=bool) + if self.env: + self.resolve_t = self.get_env("RESOLVE_TIMEOUT", self.resolve_t, cast=float) + if self.env: + self.host_f = self.get_env("HOST_FORWARD", self.host_f, cast=float) + if self.env: + self.reuse = self.get_env("REUSE", self.reuse, cast=bool) + if self.env: + self.strategy = self.get_env("STRATEGY", self.strategy) + if self.env: + self.x_forwarded_port = self.get_env("X_FORWARDED_PORT", None) + if self.env: + self.x_forwarded_proto = self.get_env("X_FORWARDED_PROTO", None) + if self.sts: + self.info("Strict transport security set to %d seconds" % self.sts) + if self.resolve: + self.info( + "DNS based resolution enabled in proxy with %.2fs timeout" + % self.resolve_t + ) + if self.strategy: + self.info("Using '%s' as load balancing strategy" % self.strategy) + if self.echo: + self._echo() self._set_strategy() def on_headers(self, connection, parser): @@ -194,13 +200,15 @@ def on_headers(self, connection, parser): # the port of the server bind port and the host based port and falling # back to the forwarded port in case the "origin" is considered "trustable" port = port_s or str(self.port) - if self.trust_origin: port = headers.get("x-forwarded-port", port) + if self.trust_origin: + port = headers.get("x-forwarded-port", port) # tries to discover the protocol representation of the current # connections, note that the forwarded for header is only used in case # the current "origin" is considered "trustable" protocol = "https" if is_secure else "http" - if self.trust_origin: protocol = headers.get("x-forwarded-proto", protocol) + if self.trust_origin: + protocol = headers.get("x-forwarded-proto", protocol) # constructs the URL that is going to be used by the rule engine and any # other internal resolution process as the canonical URL of the request @@ -213,12 +221,15 @@ def on_headers(self, connection, parser): redirect = self.redirect.get(host, redirect) redirect = self.redirect.get(host_s, redirect) redirect = self.redirect.get(host_o, redirect) - redirect, _match = self._resolve_regex(url, self.redirect_regex, default = redirect) + redirect, _match = self._resolve_regex( + url, self.redirect_regex, default=redirect + ) if redirect: # verifies if the redirect value is a sequence and if that's # not the case converts the value into a tuple value is_sequence = isinstance(redirect, (list, tuple)) - if not is_sequence: redirect = (redirect,) + if not is_sequence: + redirect = (redirect,) # converts the possible tuple value into a list so that it # may be changed (mutable sequence is required) @@ -226,8 +237,10 @@ def on_headers(self, connection, parser): # adds the proper trail values to the redirect sequence so that # both the protocol and the path are ensured to exist - if len(redirect) == 1: redirect += [protocol, path] - if len(redirect) == 2: redirect += [path] + if len(redirect) == 1: + redirect += [protocol, path] + if len(redirect) == 2: + redirect += [path] # unpacks the redirect sequence and builds the new location # value taking into account the proper values @@ -237,8 +250,9 @@ def on_headers(self, connection, parser): # verifies if the current request already matched the redirection # rule and if that't the case ignores the, note that the host value # is verified against all the possible combinations - is_match = host_o == redirect_t or host_s == redirect_t or\ - host == redirect_t + is_match = ( + host_o == redirect_t or host_s == redirect_t or host == redirect_t + ) is_match &= protocol == protocol_t is_match &= path == path_t redirect = not is_match @@ -248,13 +262,11 @@ def on_headers(self, connection, parser): # if the redirection is really required if redirect: return connection.send_response( - headers = dict( - location = location - ), - version = version_s, - code = 303, - code_s = "See Other", - apply = True + headers=dict(location=location), + version=version_s, + code=303, + code_s="See Other", + apply=True, ) # tries to extract the various attributes of the current connection @@ -272,7 +284,8 @@ def on_headers(self, connection, parser): # in order to obtain the possible prefix value for URL reconstruction, # a state value is also retrieved, this value will be used latter for # the acquiring and releasing parts of the balancing strategy operation - if not self.reuse or not reusable: prefix, state = self.rules(url, parser) + if not self.reuse or not reusable: + prefix, state = self.rules(url, parser) # in case no prefix is defined at this stage there's no matching # against the currently defined rules and so an error must be raised @@ -280,15 +293,13 @@ def on_headers(self, connection, parser): if not prefix: self.debug("No proxy endpoint found") return connection.send_response( - data = cls.build_text("No proxy endpoint found"), - headers = dict( - connection = "close" - ), - version = version_s, - code = 404, - code_s = "Not Found", - apply = True, - callback = self._prx_close + data=cls.build_text("No proxy endpoint found"), + headers=dict(connection="close"), + version=version_s, + code=404, + code_s="Not Found", + apply=True, + callback=self._prx_close, ) # verifies if the current host requires some kind of special authorization @@ -297,7 +308,7 @@ def on_headers(self, connection, parser): auth = self.auth.get(host, auth) auth = self.auth.get(host_s, auth) auth = self.auth.get(host_o, auth) - auth, _match = self._resolve_regex(url, self.auth_regex, default = auth) + auth, _match = self._resolve_regex(url, self.auth_regex, default=auth) if auth: # determines if the provided authentication method is a sequence # and if that't not the case casts it as one (iterative validation) @@ -310,8 +321,9 @@ def on_headers(self, connection, parser): # situation/request, note that this is considered an or operation # and should be used carefully to avoid unexpected behavior for auth in auths: - result = self.authorize(connection, parser, auth = auth) - if result: break + result = self.authorize(connection, parser, auth=auth) + if result: + break # in case the result of the authentication chain is not valid (none # of the authentication methods was successful) sends the invalid @@ -319,16 +331,16 @@ def on_headers(self, connection, parser): if not result: self.debug("Not authorized") return connection.send_response( - data = cls.build_text("Not authorized"), - headers = { - "connection" : "close", - "wWW-authenticate" : "Basic realm=\"default\"" + data=cls.build_text("Not authorized"), + headers={ + "connection": "close", + "wWW-authenticate": 'Basic realm="default"', }, - version = version_s, - code = 401, - code_s = "Not Authorized", - apply = True, - callback = self._prx_close + version=version_s, + code=401, + code_s="Not Authorized", + apply=True, + callback=self._prx_close, ) # tries to use all the possible strategies to retrieve the best possible @@ -362,14 +374,17 @@ def on_headers(self, connection, parser): if self.host_f: target = self.hosts.get(host_s, None) target = self.hosts.get(host, target) - if host_s in self.hosts_o: target = self.hosts_o[host_s][0] - if host in self.hosts_o: target = self.hosts_o[host][0] + if host_s in self.hosts_o: + target = self.hosts_o[host_s][0] + if host in self.hosts_o: + target = self.hosts_o[host][0] if target and len(target) == 1: domain = netius.legacy.urlparse(target[0]).netloc # in case the domain value was resolved then it's set in the host # header to simulate proper back-end HTTP connection - if domain: headers["host"] = domain + if domain: + headers["host"] = domain # updates the various headers that are related with the reverse # proxy operation this is required so that the request gets @@ -378,8 +393,12 @@ def on_headers(self, connection, parser): headers["x-real-ip"] = address headers["x-client-ip"] = address headers["x-forwarded-for"] = address - headers["x-forwarded-proto"] = self.x_forwarded_proto if self.x_forwarded_proto else protocol - headers["x-forwarded-port"] = self.x_forwarded_port if self.x_forwarded_port else port + headers["x-forwarded-proto"] = ( + self.x_forwarded_proto if self.x_forwarded_proto else protocol + ) + headers["x-forwarded-port"] = ( + self.x_forwarded_port if self.x_forwarded_port else port + ) headers["x-forwarded-host"] = host_o # verifies if the current connection already contains a valid @@ -389,14 +408,18 @@ def on_headers(self, connection, parser): proxy_c = hasattr(connection, "proxy_c") and connection.proxy_c proxy_c = proxy_c or None connection.proxy_c = None - if proxy_c in self.conn_map: del self.conn_map[proxy_c] + if proxy_c in self.conn_map: + del self.conn_map[proxy_c] # tries to determine the transfer encoding of the received request # and by using that determines the proper encoding to be applied encoding = headers.pop("transfer-encoding", None) is_chunked = encoding == "chunked" - encoding = netius.common.CHUNKED_ENCODING if is_chunked else\ - netius.common.PLAIN_ENCODING + encoding = ( + netius.common.CHUNKED_ENCODING + if is_chunked + else netius.common.PLAIN_ENCODING + ) # calls the proper (HTTP) method in the client this should acquire # a new connection and start the process of sending the request @@ -404,11 +427,11 @@ def on_headers(self, connection, parser): _connection = self.http_client.method( method, url, - headers = headers, - encoding = encoding, - encodings = None, - safe = True, - connection = proxy_c + headers=headers, + encoding=encoding, + encodings=None, + safe=True, + connection=proxy_c, ) # sets the state attribute in the connection so that it's possible @@ -445,11 +468,14 @@ def on_headers(self, connection, parser): def rules(self, url, parser): resolved = self.rules_regex(url, parser) - if resolved[0]: return resolved + if resolved[0]: + return resolved resolved = self.rules_host(url, parser) - if resolved[0]: return resolved + if resolved[0]: + return resolved resolved = self.rules_forward(url, parser) - if resolved[0]: return resolved + if resolved[0]: + return resolved return None, None def rules_regex(self, url, parser): @@ -462,7 +488,8 @@ def rules_regex(self, url, parser): # sequence of regex values, this is an iterative process in case # there's no match the default value is returned immediately _prefix, match = self._resolve_regex(url, self.regex) - if not _prefix: return prefix, state + if not _prefix: + return prefix, state # prints a debug message about the matching that has just occurred # so that proper debugging may take place if required @@ -472,7 +499,8 @@ def rules_regex(self, url, parser): # proper final prefix and its associated state _prefix, _state = self.balancer(_prefix) groups = match.groups() - if groups: _prefix = _prefix.format(*groups) + if groups: + _prefix = _prefix.format(*groups) prefix = _prefix state = _state @@ -523,7 +551,8 @@ def rules_forward(self, url, parser): def balancer(self, values): is_sequence = isinstance(values, (list, tuple)) - if not is_sequence: return values, None + if not is_sequence: + return values, None return self.balancer_m(values) def balancer_robin(self, values): @@ -537,7 +566,8 @@ def balancer_smart(self, values): queue = self.smart.get(values, None) if not queue: queue = netius.common.PriorityDict() - for value in values: queue[value] = [0, 0] + for value in values: + queue[value] = [0, 0] self.smart[values] = queue prefix = queue.smallest() @@ -551,7 +581,8 @@ def acquirer_robin(self, state): pass def acquirer_smart(self, state): - if not state: return + if not state: + return prefix, queue = state sorter = queue[prefix] sorter[0] += 1 @@ -564,14 +595,16 @@ def releaser_robin(self, state): pass def releaser_smart(self, state): - if not state: return + if not state: + return prefix, queue = state sorter = queue[prefix] sorter[0] -= 1 - if sorter[0] == 0: sorter[1] = time.time() * -1 + if sorter[0] == 0: + sorter[1] = time.time() * -1 queue[prefix] = sorter - def dns_start(self, timeout = 120.0): + def dns_start(self, timeout=120.0): # creates the dictionary that is going to hold the original # values for the hosts, this is going to be required for later # computation of the resolved values @@ -582,16 +615,17 @@ def dns_start(self, timeout = 120.0): # this is considered the bootstrap of the hosts for host, values in netius.legacy.items(self.hosts): is_sequence = isinstance(values, (list, tuple)) - if not is_sequence: values = (values,) + if not is_sequence: + values = (values,) resolved = [[value] for value in values] self.hosts[host] = tuple(values) self.hosts_o[host] = (values, resolved) # runs the initial tick for the DNS execution, this should start # the DNS resolution process - self.dns_tick(timeout = timeout) + self.dns_tick(timeout=timeout) - def dns_tick(self, timeout = 120.0): + def dns_tick(self, timeout=120.0): # iterates over the complete set of original hosts to run the tick # operation, this should perform DNS queries for all values using # the default netius DNS client @@ -614,46 +648,33 @@ def dns_tick(self, timeout = 120.0): # validates that the parsed hostname is valid and ready # to be queried as DNS value, if that's not the case # continues the loop skipping current iteration - if not hostname: continue + if not hostname: + continue # creates the callback function that is going to be called # after the DNS resolution for proper hosts setting callback = self.dns_callback( - host, - value, - parsed, - index = index, - resolved = resolved + host, value, parsed, index=index, resolved=resolved ) # runs the DNS query execution for the hostname associated # with the current load balancing URL, the callback of this # call should handle the addition of the value to hosts - netius.clients.DNSClient.query_s( - hostname, - type = "a", - callback = callback - ) + netius.clients.DNSClient.query_s(hostname, type="a", callback=callback) # verifies if the requested timeout is zero and if that's the # case only one execution of the DNS tick is pretended, returns # the control flow back to caller immediately - if timeout == 0: return + if timeout == 0: + return # schedules a delayed execution taking into account the timeout # that has been provided, this is going to update the various # servers that are defined for the registered domains - tick = lambda: self.dns_tick(timeout = timeout) - self.delay(tick, timeout = timeout) + tick = lambda: self.dns_tick(timeout=timeout) + self.delay(tick, timeout=timeout) - def dns_callback( - self, - host, - hostname, - parsed, - index = 0, - resolved = [] - ): + def dns_callback(self, host, hostname, parsed, index=0, resolved=[]): # constructs both the port string based value and extracts # the path of the base URL that gave origin to this resolution port_s = ":" + str(parsed.port) if parsed.port else "" @@ -662,8 +683,10 @@ def dns_callback( def callback(response): # in case there's no valid DNS response there's nothing to be # done, control flow is returned immediately - if not response: return - if not response.answers: return + if not response: + return + if not response.answers: + return # creates the list that is going to be used t store the complete # set of resolved URL for the current host value in resolution @@ -676,7 +699,8 @@ def callback(response): for answer in response.answers: type_s = answer[1] address = answer[4] - if not type_s in ("A", "AAAA"): continue + if not type_s in ("A", "AAAA"): + continue url = "%s://%s%s%s" % (parsed.scheme, address, port_s, path) target.append(url) @@ -696,156 +720,156 @@ def _on_prx_message(self, client, parser, message): busy = _connection.busy if hasattr(_connection, "busy") else 0 state = _connection.state if hasattr(_connection, "state") else None error_url = _connection.state if hasattr(_connection, "error_url") else None - if busy: self.busy_conn -= 1; _connection.busy -= 1 - if state: self.releaser(state); _connection.state = None - if error_url: _connection.error_url = None + if busy: + self.busy_conn -= 1 + _connection.busy -= 1 + if state: + self.releaser(state) + _connection.state = None + if error_url: + _connection.error_url = None netius.servers.ProxyServer._on_prx_message(self, client, parser, message) def _on_prx_close(self, client, _connection): busy = _connection.busy if hasattr(_connection, "busy") else 0 state = _connection.state if hasattr(_connection, "state") else None error_url = _connection.state if hasattr(_connection, "error_url") else None - if busy: self.busy_conn -= busy; _connection.busy -= busy - if state: self.releaser(state); _connection.state = None - if error_url: _connection.error_url = None + if busy: + self.busy_conn -= busy + _connection.busy -= busy + if state: + self.releaser(state) + _connection.state = None + if error_url: + _connection.error_url = None netius.servers.ProxyServer._on_prx_close(self, client, _connection) def _apply_all( - self, - parser, - connection, - headers, - upper = True, - normalize = False, - replace = False + self, parser, connection, headers, upper=True, normalize=False, replace=False ): netius.servers.ProxyServer._apply_all( self, parser, connection, headers, - upper = upper, - normalize = normalize, - replace = replace + upper=upper, + normalize=normalize, + replace=replace, ) # in case a strict transport security value (number) is defined it # is going to be used as the max age value to be applied for such # behavior, note that this is considered dangerous at it may corrupt # the serving of assets through non secure (no SSL) connections - if self.sts: headers["Strict-Transport-Security"] = "max-age=%d" % self.sts + if self.sts: + headers["Strict-Transport-Security"] = "max-age=%d" % self.sts - def _apply_headers(self, parser, connection, parser_prx, headers, upper = True): + def _apply_headers(self, parser, connection, parser_prx, headers, upper=True): netius.servers.ProxyServer._apply_headers( - self, - parser, - connection, - parser_prx, - headers, - upper = upper + self, parser, connection, parser_prx, headers, upper=upper ) # in case a strict transport security value (number) is defined it # is going to be used as the max age value to be applied for such # behavior, note that this is considered dangerous at it may corrupt # the serving of assets through non secure (no SSL) connections - if self.sts: headers["Strict-Transport-Security"] = "max-age=%d" % self.sts + if self.sts: + headers["Strict-Transport-Security"] = "max-age=%d" % self.sts # in case the parser has determined that the current connection is # meant to be kept alive the connection header is forced to be keep # alive this avoids issues where in HTTP 1.1 the connection header # is omitted and an ambiguous situation may be created raising the # level of incompatibility with user agents - if parser_prx.keep_alive: parser_prx.headers["Connection"] = "keep-alive" + if parser_prx.keep_alive: + parser_prx.headers["Connection"] = "keep-alive" def _set_strategy(self): self.balancer_m = getattr(self, "balancer_" + self.strategy) self.acquirer_m = getattr(self, "acquirer_" + self.strategy) self.releaser_m = getattr(self, "releaser_" + self.strategy) - def _resolve_regex(self, value, regexes, default = None): + def _resolve_regex(self, value, regexes, default=None): for regex, result in regexes: match = regex.match(value) - if not match: continue + if not match: + continue return result, match return default, None - def _echo(self, sort = True): - self._echo_regex(sort = sort) - self._echo_hosts(sort = sort) - self._echo_alias(sort = sort) - self._echo_redirect(sort = sort) - self._echo_error_urls(sort = sort) + def _echo(self, sort=True): + self._echo_regex(sort=sort) + self._echo_hosts(sort=sort) + self._echo_alias(sort=sort) + self._echo_redirect(sort=sort) + self._echo_error_urls(sort=sort) - def _echo_regex(self, sort = True): + def _echo_regex(self, sort=True): self.info("Regex registration information") - for key, value in self.regex: self.info("%s => %s" % (key, value)) + for key, value in self.regex: + self.info("%s => %s" % (key, value)) - def _echo_hosts(self, sort = True): + def _echo_hosts(self, sort=True): keys = netius.legacy.keys(self.hosts) - if sort: keys.sort() + if sort: + keys.sort() self.info("Host registration information") - for key in keys: self.info("%s => %s" % (key, self.hosts[key])) + for key in keys: + self.info("%s => %s" % (key, self.hosts[key])) - def _echo_alias(self, sort = True): + def _echo_alias(self, sort=True): keys = netius.legacy.keys(self.alias) - if sort: keys.sort() + if sort: + keys.sort() self.info("Alias registration information") - for key in keys: self.info("%s => %s" % (key, self.alias[key])) + for key in keys: + self.info("%s => %s" % (key, self.alias[key])) - def _echo_redirect(self, sort = True): + def _echo_redirect(self, sort=True): keys = netius.legacy.keys(self.redirect) - if sort: keys.sort() + if sort: + keys.sort() self.info("Redirect registration information") - for key in keys: self.info("%s => %s" % (key, self.redirect[key])) + for key in keys: + self.info("%s => %s" % (key, self.redirect[key])) - def _echo_error_urls(self, sort = True): + def _echo_error_urls(self, sort=True): keys = netius.legacy.keys(self.error_urls) - if sort: keys.sort() + if sort: + keys.sort() self.info("Error URLs registration information") - for key in keys: self.info("%s => %s" % (key, self.error_urls[key])) + for key in keys: + self.info("%s => %s" % (key, self.error_urls[key])) + if __name__ == "__main__": import logging + regex = ( (re.compile(r"https://host\.com"), "http://localhost"), - (re.compile(r"https://([a-zA-Z]*)\.host\.com"), "http://localhost/{0}") + (re.compile(r"https://([a-zA-Z]*)\.host\.com"), "http://localhost/{0}"), ) - hosts = { - "default" : "http://default.host.com", - "host.com" : "http://host.com" - } - alias = { - "alias.host.com" : "host.com" - } - auth = { - "host.com" : netius.SimpleAuth("root", "root") - } + hosts = {"default": "http://default.host.com", "host.com": "http://host.com"} + alias = {"alias.host.com": "host.com"} + auth = {"host.com": netius.SimpleAuth("root", "root")} auth_regex = ( ( re.compile(r"https://host\.com:9090"), - ( - netius.SimpleAuth("root", "root"), - netius.AddressAuth(["127.0.0.1"]) - ) + (netius.SimpleAuth("root", "root"), netius.AddressAuth(["127.0.0.1"])), ), ) - redirect = { - "host.com" : "other.host.com" - } - error_urls = { - "host.com" : "http://host.com/error" - } + redirect = {"host.com": "other.host.com"} + error_urls = {"host.com": "http://host.com/error"} server = ReverseProxyServer( - regex = regex, - hosts = hosts, - alias = alias, - auth = auth, - auth_regex = auth_regex, - redirect = redirect, - error_urls = error_urls, - level = logging.INFO + regex=regex, + hosts=hosts, + alias=alias, + auth=auth, + auth_regex=auth_regex, + redirect=redirect, + error_urls=error_urls, + level=logging.INFO, ) - server.serve(env = True) + server.serve(env=True) else: __path__ = [] diff --git a/src/netius/extra/smtp_r.py b/src/netius/extra/smtp_r.py index f32330b67..e0e0525b8 100644 --- a/src/netius/extra/smtp_r.py +++ b/src/netius/extra/smtp_r.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -45,6 +36,7 @@ import netius.clients import netius.servers + class RelaySMTPServer(netius.servers.SMTPServer): """ Relay version of the smtp server that relays messages @@ -54,7 +46,7 @@ class RelaySMTPServer(netius.servers.SMTPServer): to relay the messages. """ - def __init__(self, postmaster = None, *args, **kwargs): + def __init__(self, postmaster=None, *args, **kwargs): netius.servers.SMTPServer.__init__(self, *args, **kwargs) self.postmaster = postmaster self.dkim = {} @@ -65,7 +57,8 @@ def on_serve(self): self.dkim = self.get_env("DKIM", self.dkim) dkim_l = len(self.dkim) self.info("Starting Relay SMTP server with %d DKIM registers ..." % dkim_l) - if self.postmaster: self.info("Using '%s' as the Postmaster email sender ..." % self.postmaster) + if self.postmaster: + self.info("Using '%s' as the Postmaster email sender ..." % self.postmaster) def on_header_smtp(self, connection, from_l, to_l): netius.servers.SMTPServer.on_header_smtp(self, connection, from_l, to_l) @@ -87,7 +80,8 @@ def on_data_smtp(self, connection, data): # verifies if there're remote addresses in the current # connection's message and if there is adds the received # data to the current relay buffer that is used - if not connection.remotes: return + if not connection.remotes: + return connection.relay.append(data) def on_message_smtp(self, connection): @@ -95,17 +89,18 @@ def on_message_smtp(self, connection): # in case there's no remotes list in the current connection # there's no need to proceed as no relay is required - if not connection.remotes: return + if not connection.remotes: + return # joins the current relay buffer to create the full message # data and then removes the (non required) termination value # from it to avoid any possible problems with extra size data_s = b"".join(connection.relay) - data_s = data_s[:netius.servers.TERMINATION_SIZE * -1] + data_s = data_s[: netius.servers.TERMINATION_SIZE * -1] # retrieves the list of "froms" for the connection and then # sends the message for relay to all of the current remotes - froms = self._emails(connection.from_l, prefix = "from") + froms = self._emails(connection.from_l, prefix="from") self.relay(connection, froms, connection.remotes, data_s) def relay(self, connection, froms, tos, contents): @@ -134,7 +129,7 @@ def relay(self, connection, froms, tos, contents): # the one that is going to be used for message id generation # and then generates a new "temporary" message id first = froms[0] - message_id = self.message_id(connection = connection, email = first) + message_id = self.message_id(connection=connection, email=first) # the default reply to value is the first from value and it # should serve as a way to reply with errors in case they @@ -157,7 +152,7 @@ def relay(self, connection, froms, tos, contents): # search the current registry, trying to find a registry for the # domain of the sender and if it finds one signs the message using # the information provided by the registry - contents = self.dkim_contents(contents, email = first) + contents = self.dkim_contents(contents, email=first) # creates the callback that will close the client once the message # is sent to all the recipients (better auto close support), note @@ -169,20 +164,21 @@ def relay(self, connection, froms, tos, contents): # postmaster email to the reply to address found in the message, # note that this is only performed in case there's a valid email # address defined as postmaster for this SMTP server - callback_error = lambda smtp_client, context, exception:\ - self.relay_postmaster(reply_to, context, exception) + callback_error = lambda smtp_client, context, exception: self.relay_postmaster( + reply_to, context, exception + ) # generates a new smtp client for the sending of the message, # uses the current host for identification and then triggers # the message event to send the message to the target host - smtp_client = netius.clients.SMTPClient(host = self.host) + smtp_client = netius.clients.SMTPClient(host=self.host) smtp_client.message( froms, tos, contents, - mark = False, - callback = callback, - callback_error = callback_error + mark=False, + callback=callback, + callback_error=callback_error, ) def relay_postmaster(self, reply_to, context, exception): @@ -190,14 +186,22 @@ def relay_postmaster(self, reply_to, context, exception): # postmaster processing is available, meaning that # a reply to address is present and the postmaster # email was defined for the server - if not reply_to: return - if not self.postmaster: return + if not reply_to: + return + if not self.postmaster: + return # tries to extract both the main message and the details # from the information available from the exception message = exception.message if hasattr(exception, "message") else str(exception) - details = exception.details if (hasattr(exception, "details") and\ - isinstance(exception.details, (list, tuple))) else [] + details = ( + exception.details + if ( + hasattr(exception, "details") + and isinstance(exception.details, (list, tuple)) + ) + else [] + ) # builds the base sender and receiver information for # the postmaster email to be sent @@ -215,34 +219,28 @@ def relay_postmaster(self, reply_to, context, exception): contents_l.append("%s\r\n" % subject) contents_l.append("Message: %s\r\n" % message) contents_l.append("Details: %s\r\n\r\n" % ("\n".join(details) or "-")) - contents_l.append("----- Original message -----\r\n\r\n%s" % netius.legacy.str(contents_o)) + contents_l.append( + "----- Original message -----\r\n\r\n%s" % netius.legacy.str(contents_o) + ) contents = "".join(contents_l) # builds a new SMTP client that is going to be used # for the postmaster operation, ensures that the # correct contents are set in the message (including DKIM) - smtp_client = netius.clients.SMTPClient(host = self.host) + smtp_client = netius.clients.SMTPClient(host=self.host) contents = smtp_client.comply( - contents, - froms = froms, - tos = tos, - message_id = self.message_id(email = first) + contents, froms=froms, tos=tos, message_id=self.message_id(email=first) ) contents = smtp_client.mark(contents) contents = netius.legacy.bytes(contents) - contents = self.dkim_contents(contents, email = first) - smtp_client.message( - froms, - tos, - contents, - mark = False - ) + contents = self.dkim_contents(contents, email=first) + smtp_client.message(froms, tos, contents, mark=False) def date(self): date_time = datetime.datetime.utcnow() return date_time.strftime("%a, %d %b %Y %H:%M:%S +0000") - def message_id(self, connection = None, email = "user@localhost"): + def message_id(self, connection=None, email="user@localhost"): _user, domain = email.split("@", 1) domain = self.host or domain identifier = str(uuid.uuid4()) @@ -253,10 +251,11 @@ def message_id(self, connection = None, email = "user@localhost"): identifier = connection.identifier return "<%s@%s>" % (identifier, domain) - def dkim_contents(self, contents, email = "user@localhost", creation = None): + def dkim_contents(self, contents, email="user@localhost", creation=None): _user, domain = email.split("@", 1) register = self.dkim.get(domain, None) - if not register: return contents + if not register: + return contents key_path = register.get("key", None) key_b64 = register.get("key_b64", None) @@ -265,24 +264,24 @@ def dkim_contents(self, contents, email = "user@localhost", creation = None): contents = contents.lstrip() - if key_path: private_key = netius.common.open_private_key(key_path) - elif key_b64: private_key = netius.common.open_private_key_b64(key_b64) - else: raise netius.SecurityError("No private key provided") + if key_path: + private_key = netius.common.open_private_key(key_path) + elif key_b64: + private_key = netius.common.open_private_key_b64(key_b64) + else: + raise netius.SecurityError("No private key provided") signature = netius.common.dkim_sign( - contents, - selector, - domain, - private_key, - identity = email, - creation = creation + contents, selector, domain, private_key, identity=email, creation=creation ) return signature + contents + if __name__ == "__main__": import logging - server = RelaySMTPServer(level = logging.DEBUG) - server.serve(env = True) + + server = RelaySMTPServer(level=logging.DEBUG) + server.serve(env=True) else: __path__ = [] diff --git a/src/netius/middleware/__init__.py b/src/netius/middleware/__init__.py index 38dc71df5..3f175e1e6 100644 --- a/src/netius/middleware/__init__.py +++ b/src/netius/middleware/__init__.py @@ -19,15 +19,6 @@ # You should have received a copy of the Apache License along with # Hive Netius System. If not, see . -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ diff --git a/src/netius/middleware/annoyer.py b/src/netius/middleware/annoyer.py index bca5c7512..dfd94e729 100644 --- a/src/netius/middleware/annoyer.py +++ b/src/netius/middleware/annoyer.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -46,6 +37,7 @@ from .base import Middleware + class AnnoyerMiddleware(Middleware): """ Simple middleware that prints an "annoying" status message @@ -53,7 +45,7 @@ class AnnoyerMiddleware(Middleware): a simple diagnostics strategy. """ - def __init__(self, owner, period = 10.0): + def __init__(self, owner, period=10.0): Middleware.__init__(self, owner) self.period = period self._initial = None @@ -62,8 +54,8 @@ def __init__(self, owner, period = 10.0): def start(self): Middleware.start(self) - self.period = netius.conf("ANNOYER_PERIOD", self.period, cast = float) - self._thread = threading.Thread(target = self._run) + self.period = netius.conf("ANNOYER_PERIOD", self.period, cast=float) + self._thread = threading.Thread(target=self._run) self._thread.start() def stop(self): @@ -79,8 +71,10 @@ def _run(self): while self._running: delta = datetime.datetime.utcnow() - self._initial delta_s = self.owner._format_delta(delta) - message = "Uptime => %s | Connections => %d\n" %\ - (delta_s, len(self.owner.connections)) + message = "Uptime => %s | Connections => %d\n" % ( + delta_s, + len(self.owner.connections), + ) sys.stdout.write(message) sys.stdout.flush() time.sleep(self.period) diff --git a/src/netius/middleware/base.py b/src/netius/middleware/base.py index ad5cb7edd..37963346c 100644 --- a/src/netius/middleware/base.py +++ b/src/netius/middleware/base.py @@ -22,21 +22,13 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ __license__ = "Apache License, Version 2.0" """ The license for the module """ + class Middleware(object): def __init__(self, owner): diff --git a/src/netius/middleware/blacklist.py b/src/netius/middleware/blacklist.py index 3e641aac4..523061af0 100644 --- a/src/netius/middleware/blacklist.py +++ b/src/netius/middleware/blacklist.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -41,21 +32,22 @@ from .base import Middleware + class BlacklistMiddleware(Middleware): """ Simple middleware implementation for blacklisting of IP addresses using a very minimalistic approach. """ - def __init__(self, owner, blacklist = None, whitelist = None): + def __init__(self, owner, blacklist=None, whitelist=None): Middleware.__init__(self, owner) self.blacklist = blacklist or [] self.whitelist = whitelist or [] def start(self): Middleware.start(self) - self.blacklist = netius.conf("BLACKLIST", self.blacklist, cast = list) - self.whitelist = netius.conf("WHITELIST", self.whitelist, cast = list) + self.blacklist = netius.conf("BLACKLIST", self.blacklist, cast=list) + self.whitelist = netius.conf("WHITELIST", self.whitelist, cast=list) self.owner.bind("connection_c", self.on_connection_c) def stop(self): @@ -64,7 +56,9 @@ def stop(self): def on_connection_c(self, owner, connection): host = connection.address[0] - if not host in self.blacklist and not "*" in self.blacklist: return - if host in self.whitelist: return + if not host in self.blacklist and not "*" in self.blacklist: + return + if host in self.whitelist: + return self.owner.warning("Connection from '%s' dropped (blacklisted)" % host) connection.close() diff --git a/src/netius/middleware/dummy.py b/src/netius/middleware/dummy.py index 6a2624d53..edde5a9f8 100644 --- a/src/netius/middleware/dummy.py +++ b/src/netius/middleware/dummy.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -39,6 +30,7 @@ from .base import Middleware + class DummyMiddleware(Middleware): """ Simple middleware implementation for testing/debugging diff --git a/src/netius/middleware/flood.py b/src/netius/middleware/flood.py index d1878530c..0f3d1ce32 100644 --- a/src/netius/middleware/flood.py +++ b/src/netius/middleware/flood.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -43,21 +34,22 @@ from .base import Middleware + class FloodMiddleware(Middleware): """ Simple middleware implementation for avoiding flooding of connection creation from a certain address. """ - def __init__(self, owner, conns_per_min = 600, whitelist = None): + def __init__(self, owner, conns_per_min=600, whitelist=None): Middleware.__init__(self, owner) self.blacklist = conns_per_min self.whitelist = whitelist or [] def start(self): Middleware.start(self) - self.conns_per_min = netius.conf("CONNS_PER_MIN", self.conns_per_min, cast = int) - self.whitelist = netius.conf("WHITELIST", self.whitelist, cast = list) + self.conns_per_min = netius.conf("CONNS_PER_MIN", self.conns_per_min, cast=int) + self.whitelist = netius.conf("WHITELIST", self.whitelist, cast=list) self.blacklist = [] self.conn_map = dict() self.minute = int(time.time() // 60) @@ -70,16 +62,20 @@ def stop(self): def on_connection_c(self, owner, connection): host = connection.address[0] self._update_flood(host) - if not host in self.blacklist and not "*" in self.blacklist: return - if host in self.whitelist: return + if not host in self.blacklist and not "*" in self.blacklist: + return + if host in self.whitelist: + return self.owner.warning("Connection from '%s' dropped (flooding avoidance)" % host) connection.close() def _update_flood(self, host): minute = int(time.time() // 60) - if minute == self.minute: self.conn_map.clear() + if minute == self.minute: + self.conn_map.clear() self.minute = minute count = self.conn_map.get(host, 0) count += 1 self.conn_map[host] = count - if count > self.conns_per_min: self.blacklist.append(host) + if count > self.conns_per_min: + self.blacklist.append(host) diff --git a/src/netius/middleware/proxy.py b/src/netius/middleware/proxy.py index f27684592..ddfe4f92a 100644 --- a/src/netius/middleware/proxy.py +++ b/src/netius/middleware/proxy.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -43,6 +34,7 @@ from .base import Middleware + class ProxyMiddleware(Middleware): """ Middleware that implements the PROXY protocol on creation @@ -78,13 +70,13 @@ class ProxyMiddleware(Middleware): PROTO_STREAM_v2 = 0x1 PROTO_DGRAM_v2 = 0x2 - def __init__(self, owner, version = 1): + def __init__(self, owner, version=1): Middleware.__init__(self, owner) self.version = version def start(self): Middleware.start(self) - self.version = netius.conf("PROXY_VERSION", self.version, cast = int) + self.version = netius.conf("PROXY_VERSION", self.version, cast=int) self.owner.bind("connection_c", self.on_connection_c) def stop(self): @@ -92,9 +84,12 @@ def stop(self): self.owner.unbind("connection_c", self.on_connection_c) def on_connection_c(self, owner, connection): - if self.version == 1: connection.add_starter(self._proxy_handshake_v1) - elif self.version == 2: connection.add_starter(self._proxy_handshake_v2) - else: raise netius.RuntimeError("Invalid PROXY version") + if self.version == 1: + connection.add_starter(self._proxy_handshake_v1) + elif self.version == 2: + connection.add_starter(self._proxy_handshake_v2) + else: + raise netius.RuntimeError("Invalid PROXY version") def _proxy_handshake_v1(self, connection): cls = self.__class__ @@ -118,8 +113,10 @@ def _proxy_handshake_v1(self, connection): # for the connection and if that's the case uses it otherwise # starts a new empty buffer from scratch has_buffer = hasattr(connection, "_proxy_buffer") - if has_buffer: buffer = connection._proxy_buffer - else: buffer = bytearray() + if has_buffer: + buffer = connection._proxy_buffer + else: + buffer = bytearray() # saves the "newly" created buffer as the PROXY buffer for the # current connection (may be used latter) @@ -134,11 +131,14 @@ def _proxy_handshake_v1(self, connection): # in case the received data represents that of a closed connection # the connection is closed and the control flow returned - if data == b"": connection.close(); return + if data == b"": + connection.close() + return # in case the received value is false, that indicates that the # execution has failed due to an exception (expected or unexpected) - if data == False: return + if data == False: + return # updates the "initial" buffer length taking into account # the current buffer and then appends the new data to it @@ -151,7 +151,8 @@ def _proxy_handshake_v1(self, connection): # in case the ready state has been reached, the complete set of # data is ready to be parsed and the loop is stopped - if is_ready: break + if is_ready: + break # removes the PROXY buffer reference from the connection as # its no longer going to be used @@ -171,7 +172,8 @@ def _proxy_handshake_v1(self, connection): # in case there's valid extra data to be restored to the connection # performs the operation, effectively restoring it for latter # receiving operations (just like adding it back to the socket) - if extra: connection.restore(extra) + if extra: + connection.restore(extra) # forces the "conversion" of the line into a string so that it may # be properly split into its components, note that first the value @@ -186,8 +188,8 @@ def _proxy_handshake_v1(self, connection): # prints a debug message about the PROXY header received, so that runtime # debugging is possible (and expected for this is a sensible part) self.owner.debug( - "Received header %s %s %s:%s => %s:%s" % - (header, protocol, source, source_p, destination, destination_p) + "Received header %s %s %s:%s => %s:%s" + % (header, protocol, source, source_p, destination, destination_p) ) # re-constructs the source address from the provided information, this is @@ -207,8 +209,10 @@ def _proxy_handshake_v2(self, connection): # for the connection and if that's the case uses it otherwise # starts a new empty buffer from scratch has_buffer = hasattr(connection, "_proxy_buffer") - if has_buffer: buffer = connection._proxy_buffer - else: buffer = bytearray() + if has_buffer: + buffer = connection._proxy_buffer + else: + buffer = bytearray() # saves the "newly" created buffer as the PROXY buffer for the # current connection (may be used latter) @@ -216,12 +220,15 @@ def _proxy_handshake_v2(self, connection): # verifies if a PROXY header was already parsed from the current connection # and if that was not the case runs its parsing - header = connection._proxy_header if hasattr(connection, "_proxy_header") else None + header = ( + connection._proxy_header if hasattr(connection, "_proxy_header") else None + ) if not header: # tries to read the PROXY v2 header bytes to be able to parse # the body parts taking that into account header = self._read_safe(connection, buffer, cls.HEADER_LENGTH_V2) - if not header: return + if not header: + return # updates the reference to the proxy header in the connection # and clears the buffer as it's now going to be used to load @@ -231,16 +238,18 @@ def _proxy_handshake_v2(self, connection): # unpacks the PROXY v2 header into its components, notice that some of them # contain multiple values on higher and lower bits - magic, version_type, address_protocol, body_size = struct.unpack("!12sBBH", header) + magic, version_type, address_protocol, body_size = struct.unpack( + "!12sBBH", header + ) # unpacks both the version (of the protocol) and the type (of message) by # unpacking the higher and the lower bits version = version_type >> 4 - type = version_type & 0x0f + type = version_type & 0x0F # unpacks the type of address to be communicated and the protocol family address = address_protocol >> 4 - protocol = address_protocol & 0x0f + protocol = address_protocol & 0x0F # runs a series of assertions on some of the basic promises of the protocol # (if they failed connection will be dropped) @@ -250,19 +259,22 @@ def _proxy_handshake_v2(self, connection): # reads the body part of the PROXY message taking into account the advertised # size of the body (from header component) body = self._read_safe(connection, buffer, body_size) - if not body: return + if not body: + return if address == cls.AF_INET_v2: source, destination, source_p, destination_p = struct.unpack("!IIHH", body) source = netius.common.addr_to_ip4(source) destination = netius.common.addr_to_ip4(destination) elif address == cls.AF_INET6_v2: - source_high,\ - source_low,\ - destination_high,\ - destination_low,\ - source_p,\ - destination_p = struct.unpack("!QQQQHH", body) + ( + source_high, + source_low, + destination_high, + destination_low, + source_p, + destination_p, + ) = struct.unpack("!QQQQHH", body) source = (source_high << 64) + source_low destination = (destination_high << 64) + destination_low source = netius.common.addr_to_ip6(source) @@ -278,8 +290,8 @@ def _proxy_handshake_v2(self, connection): # prints a debug message about the PROXY header received, so that runtime # debugging is possible (and expected for this is a sensible part) self.owner.debug( - "Received header v2 %d %s:%s => %s:%s" % - (protocol, source, source_p, destination, destination_p) + "Received header v2 %d %s:%s => %s:%s" + % (protocol, source, source_p, destination, destination_p) ) # re-constructs the source address from the provided information, this is @@ -328,7 +340,8 @@ def _read_safe(self, connection, buffer, count): # in the buffer and if that's less or equal to zero breaks the # current loop (nothing pending to be read) pending = count - len(buffer) - if pending <= 0: break + if pending <= 0: + break # tries to receive the maximum size of data that is required # for the handling of the PROXY information @@ -336,11 +349,14 @@ def _read_safe(self, connection, buffer, count): # in case the received data represents that of a closed connection # the connection is closed and the control flow returned - if data == b"": connection.close(); return None + if data == b"": + connection.close() + return None # in case the received value is false, that indicates that the # execution has failed due to an exception (expected or unexpected) - if data == False: return None + if data == False: + return None # adds the newly read data to the current buffer buffer += data diff --git a/src/netius/mock/__init__.py b/src/netius/mock/__init__.py index 135f4d941..ef68ebcfe 100644 --- a/src/netius/mock/__init__.py +++ b/src/netius/mock/__init__.py @@ -19,15 +19,6 @@ # You should have received a copy of the Apache License along with # Hive Netius System. If not, see . -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ diff --git a/src/netius/mock/appier.py b/src/netius/mock/appier.py index 52abecaac..c2b14711f 100644 --- a/src/netius/mock/appier.py +++ b/src/netius/mock/appier.py @@ -22,24 +22,17 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ __license__ = "Apache License, Version 2.0" """ The license for the module """ + class APIApp(object): pass + def route(*args, **kwargs): def decorator(*args, **kwargs): diff --git a/src/netius/pool/__init__.py b/src/netius/pool/__init__.py index 06a711683..7eb25deb9 100644 --- a/src/netius/pool/__init__.py +++ b/src/netius/pool/__init__.py @@ -19,15 +19,6 @@ # You should have received a copy of the Apache License along with # Hive Netius System. If not, see . -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -39,7 +30,14 @@ from . import notify from . import task -from .common import Thread, ThreadPool, EventPool, EventFile, UnixEventFile, SocketEventFile +from .common import ( + Thread, + ThreadPool, + EventPool, + EventFile, + UnixEventFile, + SocketEventFile, +) from .file import FileThread, FilePool from .notify import NotifyPool from .task import TaskThread, TaskPool diff --git a/src/netius/pool/common.py b/src/netius/pool/common.py index 0f5f1806d..95b4238c8 100644 --- a/src/netius/pool/common.py +++ b/src/netius/pool/common.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -46,9 +37,10 @@ CALLABLE_WORK = 1 + class Thread(threading.Thread): - def __init__(self, identifier, owner = None, *args, **kwargs): + def __init__(self, identifier, owner=None, *args, **kwargs): threading.Thread.__init__(self, *args, **kwargs) self.identifier = identifier self.owner = owner @@ -64,14 +56,16 @@ def stop(self): def run(self): threading.Thread.run(self) self._run = True - while self._run: self.tick() + while self._run: + self.tick() def tick(self): self.owner.condition.acquire() while not self.owner.peek() and self._run: self.owner.condition.wait() try: - if not self._run: return + if not self._run: + return work = self.owner.pop() finally: self.owner.condition.release() @@ -79,12 +73,15 @@ def tick(self): def execute(self, work): type = work[0] - if type == CALLABLE_WORK: work[1]() - else: raise netius.NotImplemented("Cannot execute type '%d'" % type) + if type == CALLABLE_WORK: + work[1]() + else: + raise netius.NotImplemented("Cannot execute type '%d'" % type) + class ThreadPool(object): - def __init__(self, base = Thread, count = 32): + def __init__(self, base=Thread, count=32): self.base = base self.count = count self.instances = [] @@ -97,32 +94,41 @@ def start(self): for instance in self.instances: instance.start() - def stop(self, join = True): - for instance in self.instances: instance.stop() + def stop(self, join=True): + for instance in self.instances: + instance.stop() self.condition.acquire() - try: self.condition.notify_all() - finally: self.condition.release() - if not join: return - for instance in self.instances: instance.join() + try: + self.condition.notify_all() + finally: + self.condition.release() + if not join: + return + for instance in self.instances: + instance.join() def build(self): - if self._built: return + if self._built: + return for index in range(self.count): - instance = self.base(index, owner = self) + instance = self.base(index, owner=self) self.instances.append(instance) self._built = True def peek(self): - if not self.queue: return None + if not self.queue: + return None return self.queue[0] - def pop(self, lock = True): + def pop(self, lock=True): lock and self.condition.acquire() - try: value = self.queue.pop(0) - finally: lock and self.condition.release() + try: + value = self.queue.pop(0) + finally: + lock and self.condition.release() return value - def push(self, work, lock = True): + def push(self, work, lock=True): lock and self.condition.acquire() try: value = self.queue.append(work) @@ -135,52 +141,62 @@ def push_callable(self, callable): work = (CALLABLE_WORK, callable) self.push(work) + class EventPool(ThreadPool): - def __init__(self, base = Thread, count = 32): - ThreadPool.__init__(self, base = base, count = count) + def __init__(self, base=Thread, count=32): + ThreadPool.__init__(self, base=base, count=count) self.events = [] self.event_lock = threading.RLock() self._eventfd = None - def stop(self, join = True): - ThreadPool.stop(self, join = join) - if not self._eventfd: return + def stop(self, join=True): + ThreadPool.stop(self, join=join) + if not self._eventfd: + return self._eventfd.close() self._eventfd = None def push_event(self, event): self.event_lock.acquire() - try: self.events.append(event) - finally: self.event_lock.release() + try: + self.events.append(event) + finally: + self.event_lock.release() self.notify() def pop_event(self): self.event_lock.acquire() - try: event = self.events.pop(0) - finally: self.event_lock.release() + try: + event = self.events.pop(0) + finally: + self.event_lock.release() return event - def pop_all(self, denotify = False): + def pop_all(self, denotify=False): self.event_lock.acquire() try: events = list(self.events) del self.events[:] - if events and denotify: self.denotify() + if events and denotify: + self.denotify() finally: self.event_lock.release() return events def notify(self): - if not self._eventfd: return + if not self._eventfd: + return self._eventfd.notify() def denotify(self): - if not self._eventfd: return + if not self._eventfd: + return self._eventfd.denotify() def eventfd(self): - if self._eventfd: return self._eventfd + if self._eventfd: + return self._eventfd if UnixEventFile.available(): self._eventfd = UnixEventFile() elif PipeEventFile.available(): @@ -189,6 +205,7 @@ def eventfd(self): self._eventfd = SocketEventFile() return self._eventfd + class EventFile(object): def __init__(self, *args, **kwargs): @@ -214,6 +231,7 @@ def notify(self): def denotify(self): raise netius.NotImplemented("Missing implementation") + class UnixEventFile(EventFile): _LIBC = None @@ -229,14 +247,18 @@ def __init__(self, *args, **kwargs): @classmethod def available(cls): - if not os.name == "posix": return False + if not os.name == "posix": + return False return True if cls.libc() else False @classmethod def libc(cls): - if cls._LIBC: return cls._LIBC - try: cls._LIBC = ctypes.cdll.LoadLibrary("libc.so.6") - except Exception: return None + if cls._LIBC: + return cls._LIBC + try: + cls._LIBC = ctypes.cdll.LoadLibrary("libc.so.6") + except Exception: + return None return cls._LIBC def close(self): @@ -249,22 +271,26 @@ def notify(self): def denotify(self): self._read() - def _read(self, length = 8): - if self.closed: return None + def _read(self, length=8): + if self.closed: + return None return os.read(self._rfileno, length) def _write(self, value): - if self.closed: return + if self.closed: + return os.write(self._wfileno, ctypes.c_ulonglong(value)) + class PipeEventFile(EventFile): def __init__(self, *args, **kwargs): import fcntl + EventFile.__init__(self, *args, **kwargs) self._rfileno, self._wfileno = os.pipe() - fcntl.fcntl(self._rfileno, fcntl.F_SETFL, os.O_NONBLOCK) #@UndefinedVariable - fcntl.fcntl(self._wfileno, fcntl.F_SETFL, os.O_NONBLOCK) #@UndefinedVariable + fcntl.fcntl(self._rfileno, fcntl.F_SETFL, os.O_NONBLOCK) # @UndefinedVariable + fcntl.fcntl(self._wfileno, fcntl.F_SETFL, os.O_NONBLOCK) # @UndefinedVariable self._read_file = os.fdopen(self._rfileno, "rb", 0) self._write_file = os.fdopen(self._wfileno, "wb", 0) self._lock = threading.RLock() @@ -272,8 +298,10 @@ def __init__(self, *args, **kwargs): @classmethod def available(cls): - if not os.name == "posix": return False - if not hasattr(os, "pipe"): return False + if not os.name == "posix": + return False + if not hasattr(os, "pipe"): + return False return True def close(self): @@ -297,14 +325,17 @@ def denotify(self): finally: self._lock.release() - def _read(self, length = 4096): - if self.closed: return None + def _read(self, length=4096): + if self.closed: + return None return self._read_file.read(length) def _write(self, data): - if self.closed: return + if self.closed: + return self._write_file.write(data) + class SocketEventFile(EventFile): def __init__(self, *args, **kwargs): @@ -337,16 +368,19 @@ def notify(self): def denotify(self): self._lock.acquire() try: - if self._count == 0: return + if self._count == 0: + return data = self._read() self._count -= len(data) finally: self._lock.release() - def _read(self, length = 4096): - if self.closed: return None + def _read(self, length=4096): + if self.closed: + return None return self._read_socket.recv(length) def _write(self, data): - if self.closed: return + if self.closed: + return self._write_socket.send(data) diff --git a/src/netius/pool/file.py b/src/netius/pool/file.py index 5511a6b3d..22a6d27ad 100644 --- a/src/netius/pool/file.py +++ b/src/netius/pool/file.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -49,13 +40,13 @@ READ_ACTION = 3 WRITE_ACTION = 4 + class FileThread(common.Thread): def execute(self, work): type = work[0] - if not type == FILE_WORK: netius.NotImplemented( - "Cannot execute type '%d'" % type - ) + if not type == FILE_WORK: + netius.NotImplemented("Cannot execute type '%d'" % type) try: self._execute(work) @@ -80,29 +71,35 @@ def write(self, file, buffer, data): def _execute(self, work): action = work[1] - if action == OPEN_ACTION: self.open(*work[2:]) - elif action == CLOSE_ACTION: self.close(*work[2:]) - elif action == READ_ACTION: self.read(*work[2:]) - elif action == WRITE_ACTION: self.read(*work[2:]) - else: netius.NotImplemented("Undefined file action '%d'" % action) + if action == OPEN_ACTION: + self.open(*work[2:]) + elif action == CLOSE_ACTION: + self.close(*work[2:]) + elif action == READ_ACTION: + self.read(*work[2:]) + elif action == WRITE_ACTION: + self.read(*work[2:]) + else: + netius.NotImplemented("Undefined file action '%d'" % action) + class FilePool(common.EventPool): - def __init__(self, base = FileThread, count = 10): - common.EventPool.__init__(self, base = base, count = count) + def __init__(self, base=FileThread, count=10): + common.EventPool.__init__(self, base=base, count=count) - def open(self, path, mode = "r", data = None): + def open(self, path, mode="r", data=None): work = (FILE_WORK, OPEN_ACTION, path, mode, data) self.push(work) - def close(self, file, data = None): + def close(self, file, data=None): work = (FILE_WORK, CLOSE_ACTION, file, data) self.push(work) - def read(self, file, count = -1, data = None): + def read(self, file, count=-1, data=None): work = (FILE_WORK, READ_ACTION, file, count, data) self.push(work) - def write(self, file, buffer, data = None): + def write(self, file, buffer, data=None): work = (FILE_WORK, WRITE_ACTION, file, buffer, data) self.push(work) diff --git a/src/netius/pool/notify.py b/src/netius/pool/notify.py index f42976110..19a746809 100644 --- a/src/netius/pool/notify.py +++ b/src/netius/pool/notify.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -39,7 +30,8 @@ from . import common + class NotifyPool(common.EventPool): def __init__(self): - common.EventPool.__init__(self, count = 0) + common.EventPool.__init__(self, count=0) diff --git a/src/netius/pool/task.py b/src/netius/pool/task.py index cc9acb829..43f7f77f0 100644 --- a/src/netius/pool/task.py +++ b/src/netius/pool/task.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -43,23 +34,25 @@ TASK_WORK = 10 + class TaskThread(common.Thread): def execute(self, work): type = work[0] - if not type == TASK_WORK: netius.NotImplemented( - "Cannot execute type '%d'" % type - ) + if not type == TASK_WORK: + netius.NotImplemented("Cannot execute type '%d'" % type) callable, args, kwargs, callback = work[1:] result = callable(*args, **kwargs) - if callback: callback(result) + if callback: + callback(result) + class TaskPool(common.EventPool): - def __init__(self, base = TaskThread, count = 10): - common.EventPool.__init__(self, base = base, count = count) + def __init__(self, base=TaskThread, count=10): + common.EventPool.__init__(self, base=base, count=count) - def execute(self, callable, args = [], kwargs = {}, callback = None): + def execute(self, callable, args=[], kwargs={}, callback=None): work = (TASK_WORK, callable, args, kwargs, callback) self.push(work) diff --git a/src/netius/servers/__init__.py b/src/netius/servers/__init__.py index b38c6a5ca..38982d53f 100644 --- a/src/netius/servers/__init__.py +++ b/src/netius/servers/__init__.py @@ -19,15 +19,6 @@ # You should have received a copy of the Apache License along with # Hive Netius System. If not, see . -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ diff --git a/src/netius/servers/dhcp.py b/src/netius/servers/dhcp.py index 07fa2c116..184a58a18 100644 --- a/src/netius/servers/dhcp.py +++ b/src/netius/servers/dhcp.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -42,6 +33,7 @@ import netius.common + class DHCPRequest(object): options_m = None @@ -55,7 +47,8 @@ def __init__(self, data): @classmethod def generate(cls): - if cls.options_m: return + if cls.options_m: + return cls.options_m = ( cls._option_subnet, cls._option_router, @@ -74,7 +67,7 @@ def generate(cls): cls._option_renewal, cls._option_rebind, cls._option_proxy, - cls._option_end + cls._option_end, ) cls.options_l = len(cls.options_m) @@ -140,12 +133,13 @@ def unpack_options(self): index = 0 while True: byte = self.options[index] - if netius.legacy.ord(byte) == 0xff: break + if netius.legacy.ord(byte) == 0xFF: + break type = byte type_i = netius.legacy.ord(type) length = netius.legacy.ord(self.options[index + 1]) - payload = self.options[index + 2:index + length + 2] + payload = self.options[index + 2 : index + length + 2] self.options_p[type_i] = payload @@ -153,14 +147,16 @@ def unpack_options(self): def get_requested(self): payload = self.options_p.get(50, None) - if not payload: return "0.0.0.0" - value, = struct.unpack("!I", payload) + if not payload: + return "0.0.0.0" + (value,) = struct.unpack("!I", payload) requested = netius.common.addr_to_ip4(value) return requested def get_type(self): payload = self.options_p.get(53, None) - if not payload: return 0x00 + if not payload: + return 0x00 type = netius.legacy.ord(payload) return type @@ -177,7 +173,7 @@ def get_mac(self): mac_addr = ":".join(addr_l) return mac_addr - def response(self, yiaddr, options = {}): + def response(self, yiaddr, options={}): cls = self.__class__ host = netius.common.host() @@ -226,8 +222,10 @@ def response(self, yiaddr, options = {}): for option, values in netius.legacy.iteritems(options): method = cls.options_m[option - 1] - if values: option_s = method(**values) - else: option_s = method() + if values: + option_s = method(**values) + else: + option_s = method() buffer.append(option_s) buffer.append(end) @@ -252,47 +250,47 @@ def _pack_m(cls, sequence, format): return b"".join(result) @classmethod - def _option_subnet(cls, subnet = "255.255.255.0"): + def _option_subnet(cls, subnet="255.255.255.0"): subnet_a = netius.common.ip4_to_addr(subnet) subnet_s = struct.pack("!I", subnet_a) payload = cls._str(subnet_s) return b"\x01" + payload @classmethod - def _option_router(cls, routers = ["192.168.0.1"]): + def _option_router(cls, routers=["192.168.0.1"]): routers_a = [netius.common.ip4_to_addr(router) for router in routers] routers_s = cls._pack_m(routers_a, "!I") payload = cls._str(routers_s) return b"\x03" + payload @classmethod - def _option_dns(cls, servers = ["192.168.0.1", "192.168.0.2"]): + def _option_dns(cls, servers=["192.168.0.1", "192.168.0.2"]): servers_a = [netius.common.ip4_to_addr(server) for server in servers] servers_s = cls._pack_m(servers_a, "!I") payload = cls._str(servers_s) return b"\x06" + payload @classmethod - def _option_name(cls, name = "server.com"): + def _option_name(cls, name="server.com"): payload = cls._str(name) return b"\x0f" + payload @classmethod - def _option_broadcast(cls, broadcast = "192.168.0.255"): + def _option_broadcast(cls, broadcast="192.168.0.255"): subnet_a = netius.common.ip4_to_addr(broadcast) subnet_s = struct.pack("!I", subnet_a) payload = cls._str(subnet_s) return b"\x1c" + payload @classmethod - def _option_requested(cls, ip = "192.168.0.11"): + def _option_requested(cls, ip="192.168.0.11"): ip_a = netius.common.ip4_to_addr(ip) ip_s = struct.pack("!I", ip_a) payload = cls._str(ip_s) return b"\x32" + payload @classmethod - def _option_lease(cls, time = 3600): + def _option_lease(cls, time=3600): time_s = struct.pack("!I", time) payload = cls._str(time_s) return b"\x33" + payload @@ -322,26 +320,26 @@ def _option_nak(cls): return b"\x35\x01\x06" @classmethod - def _option_identifier(cls, identifier = "192.168.0.1"): + def _option_identifier(cls, identifier="192.168.0.1"): subnet_a = netius.common.ip4_to_addr(identifier) subnet_s = struct.pack("!I", subnet_a) payload = cls._str(subnet_s) return b"\x36" + payload @classmethod - def _option_renewal(cls, time = 3600): + def _option_renewal(cls, time=3600): time_s = struct.pack("!I", time) payload = cls._str(time_s) return b"\x3a" + payload @classmethod - def _option_rebind(cls, time = 3600): + def _option_rebind(cls, time=3600): time_s = struct.pack("!I", time) payload = cls._str(time_s) return b"\x3b" + payload @classmethod - def _option_proxy(cls, url = "http://localhost/proxy.pac"): + def _option_proxy(cls, url="http://localhost/proxy.pac"): length = len(url) length_o = netius.legacy.chr(length) return b"\xfc" + length_o + netius.legacy.bytes(url) @@ -350,10 +348,11 @@ def _option_proxy(cls, url = "http://localhost/proxy.pac"): def _option_end(cls): return b"\xff" + class DHCPServer(netius.DatagramServer): - def serve(self, port = 67, type = netius.UDP_TYPE, *args, **kwargs): - netius.DatagramServer.serve(self, port = port, type = type, *args, **kwargs) + def serve(self, port=67, type=netius.UDP_TYPE, *args, **kwargs): + netius.DatagramServer.serve(self, port=port, type=type, *args, **kwargs) def on_data(self, address, data): netius.DatagramServer.on_data(self, address, data) @@ -370,9 +369,8 @@ def on_data_dhcp(self, address, request): self.debug("Received %s message from '%s'" % (type_s, mac)) - if not type in (0x01, 0x03): raise netius.NetiusError( - "Invalid operation type '%d'", type - ) + if not type in (0x01, 0x03): + raise netius.NetiusError("Invalid operation type '%d'", type) type_r = self.get_type(request) options = self.get_options(request) @@ -384,7 +382,7 @@ def on_data_dhcp(self, address, request): self.debug("%s address '%s' ..." % (verb, yiaddr)) - response = request.response(yiaddr, options = options) + response = request.response(yiaddr, options=options) self.send_dhcp(response) def get_verb(self, type_r): @@ -403,9 +401,11 @@ def get_options(self, request): def get_yiaddr(self, request): raise netius.NetiusError("Not implemented") + if __name__ == "__main__": import logging - server = DHCPServer(level = logging.INFO) - server.serve(env = True) + + server = DHCPServer(level=logging.INFO) + server.serve(env=True) else: __path__ = [] diff --git a/src/netius/servers/echo.py b/src/netius/servers/echo.py index 56ddbf125..b29f23037 100644 --- a/src/netius/servers/echo.py +++ b/src/netius/servers/echo.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -39,15 +30,18 @@ import netius + class EchoProtocol(netius.StreamProtocol): pass + class EchoServer(netius.ServerAgent): protocol = EchoProtocol + if __name__ == "__main__": server = EchoServer() - server.serve(env = True) + server.serve(env=True) else: __path__ = [] diff --git a/src/netius/servers/echo_ws.py b/src/netius/servers/echo_ws.py index f40323edd..53c3193d1 100644 --- a/src/netius/servers/echo_ws.py +++ b/src/netius/servers/echo_ws.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -39,14 +30,16 @@ from . import ws + class EchoWSServer(ws.WSServer): def on_data_ws(self, connection, data): ws.WSServer.on_data_ws(self, connection, data) connection.send_ws(data) + if __name__ == "__main__": server = EchoWSServer() - server.serve(env = True) + server.serve(env=True) else: __path__ = [] diff --git a/src/netius/servers/ftp.py b/src/netius/servers/ftp.py index c75e93775..4515546d8 100644 --- a/src/netius/servers/ftp.py +++ b/src/netius/servers/ftp.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -48,35 +39,24 @@ sending the file to the client, this should not be neither to big nor to small (as both situations would create problems) """ -CAPABILITIES = ( - "PASV", - "UTF8" -) +CAPABILITIES = ("PASV", "UTF8") """ The sequence defining the complete set of capabilities that are available under the current ftp server implementation """ -PERMISSIONS = { - 7 : "rwx", - 6 : "rw-", - 5 : "r-x", - 4 : "r--", - 0 : "---" -} +PERMISSIONS = {7: "rwx", 6: "rw-", 5: "r-x", 4: "r--", 0: "---"} """ Map that defines the association between the octal based values for the permissions and the associated string values """ -TYPES = { - "A" : "ascii", - "E" : "ebcdic", - "I" : "binary", - "L" : "local" -} +TYPES = {"A": "ascii", "E": "ebcdic", "I": "binary", "L": "local"} """ The map that associated the various type command arguments with the more rich data mode transfer types """ + class FTPConnection(netius.Connection): - def __init__(self, base_path = "", host = "ftp.localhost", mode = "ascii", *args, **kwargs): + def __init__( + self, base_path="", host="ftp.localhost", mode="ascii", *args, **kwargs + ): netius.Connection.__init__(self, *args, **kwargs) self.parser = None self.base_path = os.path.abspath(base_path) @@ -90,56 +70,67 @@ def __init__(self, base_path = "", host = "ftp.localhost", mode = "ascii", *args def open(self, *args, **kwargs): netius.Connection.open(self, *args, **kwargs) - if not self.is_open(): return + if not self.is_open(): + return self.parser = netius.common.FTPParser(self) self.parser.bind("on_line", self.on_line) def close(self, *args, **kwargs): netius.Connection.close(self, *args, **kwargs) - if not self.is_closed(): return - if self.parser: self.parser.destroy() - if self.data_server: self.data_server.close_ftp() + if not self.is_closed(): + return + if self.parser: + self.parser.destroy() + if self.data_server: + self.data_server.close_ftp() file = hasattr(self, "file") and self.file - if file: file.close() + if file: + file.close() def parse(self, data): return self.parser.parse(data) - def send_ftp(self, code, message = "", lines = (), simple = False, delay = True, callback = None): - if lines: return self.send_ftp_lines( - code, - message = message, - lines = lines, - simple = simple, - delay = delay, - callback = callback - ) - else: return self.send_ftp_base( - code, - message, - delay, - callback - ) - - def send_ftp_base(self, code, message = "", delay = True, callback = None): + def send_ftp( + self, code, message="", lines=(), simple=False, delay=True, callback=None + ): + if lines: + return self.send_ftp_lines( + code, + message=message, + lines=lines, + simple=simple, + delay=delay, + callback=callback, + ) + else: + return self.send_ftp_base(code, message, delay, callback) + + def send_ftp_base(self, code, message="", delay=True, callback=None): base = "%d %s" % (code, message) data = base + "\r\n" - count = self.send(data, delay = delay, callback = callback) + count = self.send(data, delay=delay, callback=callback) self.owner.debug(base) return count - def send_ftp_lines(self, code, message = "", lines = (), simple = False, delay = True, callback = None): + def send_ftp_lines( + self, code, message="", lines=(), simple=False, delay=True, callback=None + ): lines = list(lines) - if not simple: lines.insert(0, message) + if not simple: + lines.insert(0, message) body = lines[:-1] tail = lines[-1] base = "%d-%s" % (code, message) if simple else "%d %s" % (code, message) - lines_s = [" %s" % line for line in body] if simple else\ - ["%d-%s" % (code, line) for line in body] + lines_s = ( + [" %s" % line for line in body] + if simple + else ["%d-%s" % (code, line) for line in body] + ) lines_s.append("%d %s" % (code, tail)) - if simple: lines_s.insert(0, base) + if simple: + lines_s.insert(0, base) data = "\r\n".join(lines_s) + "\r\n" - count = self.send(data, delay = delay, callback = callback) + count = self.send(data, delay=delay, callback=callback) self.owner.debug(base) return count @@ -156,17 +147,21 @@ def not_ok(self): self.send_ftp(500, message) def flush_ftp(self): - if not self.remaining: return + if not self.remaining: + return method = getattr(self, "flush_" + self.remaining) - try: method() - finally: self.remaining = None + try: + method() + finally: + self.remaining = None def data_ftp(self, data): self.file.write(data) def closed_ftp(self): has_file = hasattr(self, "file") and not self.file == None - if not has_file: return + if not has_file: + return self.file.close() self.file = None self.send_ftp(226, "file receive ok") @@ -174,7 +169,7 @@ def closed_ftp(self): def flush_list(self): self.send_ftp(150, "directory list sending") list_data = self._list() - self.data_server.send_ftp(list_data, callback = self.on_flush_list) + self.data_server.send_ftp(list_data, callback=self.on_flush_list) def flush_retr(self): self.send_ftp(150, "file sending") @@ -197,7 +192,7 @@ def on_flush_retr(self, connection): self._data_close() self.send_ftp(226, "file send ok") - def on_line(self, code, message, is_final = True): + def on_line(self, code, message, is_final=True): # "joins" the code and the message part of the message into the base # string and then uses this value to print some debug information base = "%s %s" % (code, message) @@ -217,7 +212,8 @@ def on_line(self, code, message, is_final = True): # does not raises an exception indicating the problem with the # code that has just been received (probably erroneous) exists = hasattr(self, method_n) - if not exists: raise netius.ParserError("Invalid code '%s'" % code) + if not exists: + raise netius.ParserError("Invalid code '%s'" % code) # retrieves the reference to the method that is going to be called # for the handling of the current line from the current instance and @@ -230,16 +226,16 @@ def on_user(self, message): self.ok() def on_syst(self, message): - self.send_ftp(215, message = "UNIX Type: L8 (%s)" % netius.VERSION) + self.send_ftp(215, message="UNIX Type: L8 (%s)" % netius.VERSION) def on_feat(self, message): - self.send_ftp(211, "features", lines = list(CAPABILITIES) + ["end"], simple = True) + self.send_ftp(211, "features", lines=list(CAPABILITIES) + ["end"], simple=True) def on_opts(self, message): self.ok() def on_pwd(self, message): - self.send_ftp(257, "\"%s\"" % self.cwd) + self.send_ftp(257, '"%s"' % self.cwd) def on_type(self, message): self.mode = TYPES.get("message", "ascii") @@ -247,8 +243,8 @@ def on_type(self, message): def on_pasv(self, message): data_server = self._data_open() - port_h = (data_server.port & 0xff00) >> 8 - port_l = data_server.port & 0x00ff + port_h = (data_server.port & 0xFF00) >> 8 + port_l = data_server.port & 0x00FF address = self.socket.getsockname()[0] address = address.replace(".", ",") address_s = "%s,%d,%d" % (address, port_h, port_l) @@ -258,61 +254,83 @@ def on_port(self, message): self.ok() def on_dele(self, message): - full_path = self._get_path(extra = message) - try: os.remove(full_path) - except Exception: self.not_ok() - else: self.ok() + full_path = self._get_path(extra=message) + try: + os.remove(full_path) + except Exception: + self.not_ok() + else: + self.ok() def on_mkd(self, message): - full_path = self._get_path(extra = message) - try: os.makedirs(full_path) - except Exception: self.not_ok() - else: self.ok() + full_path = self._get_path(extra=message) + try: + os.makedirs(full_path) + except Exception: + self.not_ok() + else: + self.ok() def on_rmd(self, message): - full_path = self._get_path(extra = message) - try: os.rmdir(full_path) - except Exception: self.not_ok() - else: self.ok() + full_path = self._get_path(extra=message) + try: + os.rmdir(full_path) + except Exception: + self.not_ok() + else: + self.ok() def on_rnfr(self, message): - self.source_path = self._get_path(extra = message) + self.source_path = self._get_path(extra=message) self.ok() def on_rnto(self, message): - self.target_path = self._get_path(extra = message) - try: os.rename(self.source_path, self.target_path) - except Exception: self.not_ok() - else: self.ok() - finally: self.source_path = self.target_path = None + self.target_path = self._get_path(extra=message) + try: + os.rename(self.source_path, self.target_path) + except Exception: + self.not_ok() + else: + self.ok() + finally: + self.source_path = self.target_path = None def on_cdup(self, message): self.cwd = self.cwd.rsplit("/", 1)[0] - if not self.cwd: self.cwd = "/" + if not self.cwd: + self.cwd = "/" self.ok() def on_cwd(self, message): is_absolute = message.startswith("/") - if is_absolute: cwd = message - else: cwd = self.cwd + (message if self.cwd.endswith("/") else "/" + message) + if is_absolute: + cwd = message + else: + cwd = self.cwd + (message if self.cwd.endswith("/") else "/" + message) - full_path = self._get_path(extra = message) + full_path = self._get_path(extra=message) is_dir = os.path.isdir(full_path) - if not is_dir: self.send_ftp(550, "failed to change directory"); return + if not is_dir: + self.send_ftp(550, "failed to change directory") + return self.cwd = cwd self.ok() def on_size(self, message): - full_path = self._get_path(extra = message) - if os.path.isdir(full_path): size = 0 - else: size = os.path.getsize(full_path) + full_path = self._get_path(extra=message) + if os.path.isdir(full_path): + size = 0 + else: + size = os.path.getsize(full_path) self.send_ftp(213, "%d" % size) def on_mdtm(self, message): - full_path = self._get_path(extra = message) - if os.path.isdir(full_path): modified = 0 - else: modified = os.path.getmtime(full_path) + full_path = self._get_path(extra=message) + if os.path.isdir(full_path): + modified = 0 + else: + modified = os.path.getmtime(full_path) modified_d = datetime.datetime.utcfromtimestamp(modified) modified_s = modified_d.strftime("%Y%m%d%H%M%S") self.send_ftp(213, modified_s) @@ -346,7 +364,7 @@ def _file_send(self, connection): self.bytes_p -= data_l is_final = not data or self.bytes_p == 0 callback = self._file_finish if is_final else self._file_send - self.data_server.send_ftp(data, callback = callback) + self.data_server.send_ftp(data, callback=callback) def _file_finish(self, connection): self.file.close() @@ -355,18 +373,15 @@ def _file_finish(self, connection): self.on_flush_retr(connection) def _data_open(self): - if self.data_server: self._data_close() + if self.data_server: + self._data_close() self.data_server = FTPDataServer(self, self.owner) - self.data_server.serve( - host = self.owner.host, - port = 0, - load = False, - start = False - ) + self.data_server.serve(host=self.owner.host, port=0, load=False, start=False) return self.data_server def _data_close(self): - if not self.data_server: return + if not self.data_server: + return self.data_server.close_ftp() self.data_server = None @@ -378,8 +393,10 @@ def _list(self): # lists the directory for the current relative path, this # should get a list of files contained in it, in case there's # an error in such listing an empty string is returned - try: entries = os.listdir(relative_path) - except Exception: return "" + try: + entries = os.listdir(relative_path) + except Exception: + return "" # allocates space for the list that will hold the various lines # for the complete set of tiles in the directory @@ -389,14 +406,22 @@ def _list(self): # working directory to create their respective listing line for entry in entries: file_path = os.path.join(relative_path, entry) - try: mode = os.stat(file_path) - except Exception: continue + try: + mode = os.stat(file_path) + except Exception: + continue permissions = self._to_unix(mode) timestamp = mode.st_mtime date_time = datetime.datetime.utcfromtimestamp(timestamp) date_s = date_time.strftime("%b %d %Y") - line = "%s 1 %-8s %-8s %8lu %s %s\r\n" %\ - (permissions, "ftp", "ftp", mode.st_size, date_s, entry) + line = "%s 1 %-8s %-8s %8lu %s %s\r\n" % ( + permissions, + "ftp", + "ftp", + mode.st_size, + date_s, + entry, + ) lines.append(line) # returns the final list string result as the joining of the @@ -406,9 +431,11 @@ def _list(self): def _to_unix(self, mode): is_dir = "d" if stat.S_ISDIR(mode.st_mode) else "-" permissions = str(oct(mode.st_mode)[-3:]) - return is_dir + "".join([PERMISSIONS.get(int(item), item) for item in permissions]) + return is_dir + "".join( + [PERMISSIONS.get(int(item), item) for item in permissions] + ) - def _get_path(self, extra = None): + def _get_path(self, extra=None): # tries to decide on own to resolve the base and extra parts # of the path taking into account a possible absolute extra # value, the current working directory is only used in case @@ -427,6 +454,7 @@ def _get_path(self, extra = None): relative_path = os.path.normpath(relative_path) return relative_path + class FTPDataServer(netius.StreamServer): def __init__(self, connection, container, *args, **kwargs): @@ -438,7 +466,9 @@ def __init__(self, connection, container, *args, **kwargs): def on_connection_c(self, connection): netius.StreamServer.on_connection_c(self, connection) - if self.accepted: connection.close(); return + if self.accepted: + connection.close() + return self.accepted = connection self.flush_ftp() @@ -450,18 +480,23 @@ def on_data(self, connection, data): netius.StreamServer.on_data(self, connection, data) self.connection.data_ftp(data) - def send_ftp(self, data, delay = True, force = False, callback = None): - if not self.accepted: raise netius.DataError("No connection accepted") - return self.accepted.send(data, delay = delay, force = force, callback = callback) + def send_ftp(self, data, delay=True, force=False, callback=None): + if not self.accepted: + raise netius.DataError("No connection accepted") + return self.accepted.send(data, delay=delay, force=force, callback=callback) def flush_ftp(self): - if not self.accepted: return + if not self.accepted: + return self.connection.flush_ftp() def close_ftp(self): - if self.accepted: self.accepted.close(); self.accepted = None + if self.accepted: + self.accepted.close() + self.accepted = None self.cleanup() + class FTPServer(netius.ContainerServer): """ Abstract ftp server implementation that handles authentication @@ -473,13 +508,13 @@ class FTPServer(netius.ContainerServer): :see: http://tools.ietf.org/html/rfc959 """ - def __init__(self, base_path = "", auth_s = "dummy", *args, **kwargs): + def __init__(self, base_path="", auth_s="dummy", *args, **kwargs): netius.ContainerServer.__init__(self, *args, **kwargs) self.base_path = base_path self.auth_s = auth_s - def serve(self, host = "ftp.localhost", port = 21, *args, **kwargs): - netius.ContainerServer.serve(self, port = port, *args, **kwargs) + def serve(self, host="ftp.localhost", port=21, *args, **kwargs): + netius.ContainerServer.serve(self, port=port, *args, **kwargs) self.host = host def on_connection_c(self, connection): @@ -492,29 +527,38 @@ def on_data(self, connection, data): def on_serve(self): netius.ContainerServer.on_serve(self) - if self.env: self.base_path = self.get_env("BASE_PATH", self.base_path) - if self.env: self.host = self.get_env("FTP_HOST", self.host) - if self.env: self.auth_s = self.get_env("FTP_AUTH", self.auth_s) + if self.env: + self.base_path = self.get_env("BASE_PATH", self.base_path) + if self.env: + self.host = self.get_env("FTP_HOST", self.host) + if self.env: + self.auth_s = self.get_env("FTP_AUTH", self.auth_s) self.auth = self.get_auth(self.auth_s) - self.info("Starting FTP server on '%s' using '%s' ..." % (self.host, self.auth_s)) - self.info("Defining '%s' as the root of the file server ..." % (self.base_path or ".")) + self.info( + "Starting FTP server on '%s' using '%s' ..." % (self.host, self.auth_s) + ) + self.info( + "Defining '%s' as the root of the file server ..." % (self.base_path or ".") + ) - def build_connection(self, socket, address, ssl = False): + def build_connection(self, socket, address, ssl=False): return FTPConnection( - owner = self, - socket = socket, - address = address, - ssl = ssl, - base_path = self.base_path, - host = self.host + owner=self, + socket=socket, + address=address, + ssl=ssl, + base_path=self.base_path, + host=self.host, ) def on_line_ftp(self, connection, code, message): pass + if __name__ == "__main__": import logging - server = FTPServer(level = logging.DEBUG) - server.serve(env = True) + + server = FTPServer(level=logging.DEBUG) + server.serve(env=True) else: __path__ = [] diff --git a/src/netius/servers/http.py b/src/netius/servers/http.py index 0945a60a0..771b9bef8 100644 --- a/src/netius/servers/http.py +++ b/src/netius/servers/http.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -45,8 +36,12 @@ import netius.common -from netius.common import PLAIN_ENCODING, CHUNKED_ENCODING,\ - GZIP_ENCODING, DEFLATE_ENCODING +from netius.common import ( + PLAIN_ENCODING, + CHUNKED_ENCODING, + GZIP_ENCODING, + DEFLATE_ENCODING, +) Z_PARTIAL_FLUSH = 1 """ The zlib constant value representing the partial flush @@ -54,18 +49,19 @@ locally as it is not defines under the zlib module """ ENCODING_MAP = dict( - plain = PLAIN_ENCODING, - chunked = CHUNKED_ENCODING, - gzip = GZIP_ENCODING, - deflate = DEFLATE_ENCODING + plain=PLAIN_ENCODING, + chunked=CHUNKED_ENCODING, + gzip=GZIP_ENCODING, + deflate=DEFLATE_ENCODING, ) """ The map associating the various types of encoding with the corresponding integer value for each of them this is used in the initial construction of the server """ + class HTTPConnection(netius.Connection): - def __init__(self, encoding = PLAIN_ENCODING, *args, **kwargs): + def __init__(self, encoding=PLAIN_ENCODING, *args, **kwargs): netius.Connection.__init__(self, *args, **kwargs) self.encoding = encoding self.current = encoding @@ -75,132 +71,86 @@ def __init__(self, encoding = PLAIN_ENCODING, *args, **kwargs): def open(self, *args, **kwargs): netius.Connection.open(self, *args, **kwargs) - if not self.is_open(): return + if not self.is_open(): + return self.parser = netius.common.HTTPParser( - self, - type = netius.common.REQUEST, - store = True + self, type=netius.common.REQUEST, store=True ) self.parser.bind("on_data", self.on_data) def close(self, *args, **kwargs): netius.Connection.close(self, *args, **kwargs) - if not self.is_closed(): return - if self.parser: self.parser.destroy() - if self.gzip_m: self._close_gzip(safe = True) - - def info_dict(self, full = False): - info = netius.Connection.info_dict(self, full = full) - info.update( - encoding = self.encoding, - current = self.current - ) - if full: info.update( - parser = self.parser.info_dict() - ) + if not self.is_closed(): + return + if self.parser: + self.parser.destroy() + if self.gzip_m: + self._close_gzip(safe=True) + + def info_dict(self, full=False): + info = netius.Connection.info_dict(self, full=full) + info.update(encoding=self.encoding, current=self.current) + if full: + info.update(parser=self.parser.info_dict()) return info - def flush(self, stream = None, callback = None): - encoding = min(self.current, CHUNKED_ENCODING) if\ - self.owner.dynamic else self.current + def flush(self, stream=None, callback=None): + encoding = ( + min(self.current, CHUNKED_ENCODING) if self.owner.dynamic else self.current + ) if encoding == DEFLATE_ENCODING: - self._flush_gzip(stream = stream, callback = callback) + self._flush_gzip(stream=stream, callback=callback) elif encoding == GZIP_ENCODING: - self._flush_gzip(stream = stream, callback = callback) + self._flush_gzip(stream=stream, callback=callback) elif encoding == CHUNKED_ENCODING: - self._flush_chunked(stream = stream, callback = callback) + self._flush_chunked(stream=stream, callback=callback) elif encoding == PLAIN_ENCODING: - self._flush_plain(stream = stream, callback = callback) + self._flush_plain(stream=stream, callback=callback) self.current = self.encoding self.owner.on_flush_http( - self.connection_ctx, - self.parser_ctx, - encoding = encoding + self.connection_ctx, self.parser_ctx, encoding=encoding ) - def flush_s(self, stream = None, callback = None): - return self.flush(stream = stream, callback = callback) + def flush_s(self, stream=None, callback=None): + return self.flush(stream=stream, callback=callback) - def send_base( - self, - data, - stream = None, - final = True, - delay = True, - callback = None - ): + def send_base(self, data, stream=None, final=True, delay=True, callback=None): data = netius.legacy.bytes(data) if data else data - encoding = min(self.current, CHUNKED_ENCODING) if\ - self.owner.dynamic else self.current + encoding = ( + min(self.current, CHUNKED_ENCODING) if self.owner.dynamic else self.current + ) if encoding == PLAIN_ENCODING: return self.send_plain( - data, - stream = stream, - final = final, - delay = delay, - callback = callback + data, stream=stream, final=final, delay=delay, callback=callback ) elif encoding == CHUNKED_ENCODING: return self.send_chunked( - data, - stream = stream, - final = final, - delay = delay, - callback = callback + data, stream=stream, final=final, delay=delay, callback=callback ) elif encoding == GZIP_ENCODING: return self.send_gzip( - data, - stream = stream, - final = final, - delay = delay, - callback = callback + data, stream=stream, final=final, delay=delay, callback=callback ) elif encoding == DEFLATE_ENCODING: return self.send_gzip( - data, - stream = stream, - final = final, - delay = delay, - callback = callback + data, stream=stream, final=final, delay=delay, callback=callback ) - def send_plain( - self, - data, - stream = None, - final = True, - delay = True, - callback = None - ): - return self.send( - data, - delay = delay, - callback = callback - ) + def send_plain(self, data, stream=None, final=True, delay=True, callback=None): + return self.send(data, delay=delay, callback=callback) - def send_chunked( - self, - data, - stream = None, - final = True, - delay = True, - callback = None - ): + def send_chunked(self, data, stream=None, final=True, delay=True, callback=None): # in case there's no valid data to be sent uses the plain # send method to send the empty string and returns immediately # to the caller method, to avoid any problems - if not data: return self.send_plain( - data, - stream = stream, - final = final, - delay = delay, - callback = callback - ) + if not data: + return self.send_plain( + data, stream=stream, final=final, delay=delay, callback=callback + ) # creates the new list that is going to be used to store # the various parts of the chunk and then calculates the @@ -219,37 +169,24 @@ def send_chunked( # sends it to the connection using the plain method buffer_s = b"".join(buffer) return self.send_plain( - buffer_s, - stream = stream, - final = final, - delay = delay, - callback = callback + buffer_s, stream=stream, final=final, delay=delay, callback=callback ) def send_gzip( - self, - data, - stream = None, - final = True, - delay = True, - callback = None, - level = 6 + self, data, stream=None, final=True, delay=True, callback=None, level=6 ): # verifies if the provided data buffer is valid and in # in case it's not propagates the sending to the upper # layer (chunked sending) for proper processing - if not data: return self.send_chunked( - data, - stream = stream, - final = final, - delay = delay, - callback = callback - ) + if not data: + return self.send_chunked( + data, stream=stream, final=final, delay=delay, callback=callback + ) # tries to retrieve the gzip object for the current stream # in case it's a new connection one will be created with the # appropriate compression level (as expected) - gzip = self._get_gzip(stream, level = level) + gzip = self._get_gzip(stream, level=level) # compresses the provided data string and removes the # initial data contents of the compressed data because @@ -257,31 +194,28 @@ def send_gzip( # that in case the resulting of the compress operation # is not valid a sync flush operation is performed data_c = gzip.compress(data) - if not data_c: data_c = gzip.flush(Z_PARTIAL_FLUSH) + if not data_c: + data_c = gzip.flush(Z_PARTIAL_FLUSH) # sends the compressed data to the client endpoint setting # the correct callback values as requested return self.send_chunked( - data_c, - stream = stream, - final = final, - delay = delay, - callback = callback + data_c, stream=stream, final=final, delay=delay, callback=callback ) def send_response( self, - data = None, - headers = None, - version = None, - code = 200, - code_s = None, - apply = False, - stream = None, - final = True, - flush = True, - delay = True, - callback = None + data=None, + headers=None, + version=None, + code=200, + code_s=None, + apply=False, + stream=None, + final=True, + flush=True, + delay=True, + callback=None, ): # retrieves the various parts that define the response # and runs a series of normalization processes to retrieve @@ -306,40 +240,37 @@ def send_response( # in case the apply flag is set the apply all operation is performed # so that a series of headers are applied to the current context # (things like the name of the server connection, etc) - if apply: self.owner._apply_all(self.parser, self, headers) + if apply: + self.owner._apply_all(self.parser, self, headers) # sends the initial headers data (including status line), this should # trigger the initial data sent to the peer/client count = self.send_header( - headers = headers, - version = version, - code = code, - code_s = code_s, - stream = stream + headers=headers, version=version, code=code, code_s=code_s, stream=stream ) # sends the part/payload information (data) to the client and optionally # flushes the current internal buffers to enforce sending of the value count += self.send_part( data, - stream = stream, - final = final, - flush = flush, - delay = delay, - callback = callback + stream=stream, + final=final, + flush=flush, + delay=delay, + callback=callback, ) return count def send_header( self, - headers = None, - version = None, - code = 200, - code_s = None, - stream = None, - final = False, - delay = True, - callback = None + headers=None, + version=None, + code=200, + code_s=None, + stream=None, + final=False, + delay=True, + callback=None, ): # retrieves the various parts that define the response # and runs a series of normalization processes to retrieve @@ -355,18 +286,17 @@ def send_header( buffer.append("%s %d %s\r\n" % (version, code, code_s)) for key, value in netius.legacy.iteritems(headers): key = netius.common.header_up(key) - if not isinstance(value, list): value = (value,) - for _value in value: buffer.append("%s: %s\r\n" % (key, _value)) + if not isinstance(value, list): + value = (value,) + for _value in value: + buffer.append("%s: %s\r\n" % (key, _value)) buffer.append("\r\n") buffer_data = "".join(buffer) # sends the buffer data to the connection peer so that it gets notified # about the headers for the current communication/message count = self.send_plain( - buffer_data, - stream = stream, - delay = delay, - callback = callback + buffer_data, stream=stream, delay=delay, callback=callback ) # "notifies" the owner of the connection that the headers have been @@ -374,10 +304,10 @@ def send_header( self.owner.on_send_http( self.connection_ctx, self.parser_ctx, - headers = headers, - version = version, - code = code, - code_s = code_s + headers=headers, + version=version, + code=code, + code_s=code_s, ) # returns the final number of bytes that have been sent during the current @@ -385,26 +315,20 @@ def send_header( return count def send_part( - self, - data, - stream = None, - final = True, - flush = False, - delay = True, - callback = None + self, data, stream=None, final=True, flush=False, delay=True, callback=None ): - if flush: count = self.send_base(data); self.flush(callback = callback) - else: count = self.send_base(data, delay = delay, callback = callback) + if flush: + count = self.send_base(data) + self.flush(callback=callback) + else: + count = self.send_base(data, delay=delay, callback=callback) return count def parse(self, data): try: return self.parser.parse(data) except netius.ParserError as error: - self.send_response( - code = error.code, - apply = True - ) + self.send_response(code=error.code, apply=True) def resolve_encoding(self, parser): # in case the "target" encoding is the plain one nothing @@ -450,7 +374,8 @@ def set_encoding(self, encoding): def set_uncompressed(self): if self.current >= CHUNKED_ENCODING: self.current = CHUNKED_ENCODING - else: self.current = PLAIN_ENCODING + else: + self.current = PLAIN_ENCODING def set_plain(self): self.set_encoding(PLAIN_ENCODING) @@ -485,16 +410,18 @@ def is_uncompressed(self): def is_flushed(self): return self.current > PLAIN_ENCODING - def is_measurable(self, strict = True): - if self.is_compressed(): return False - if strict and self.is_chunked(): return False + def is_measurable(self, strict=True): + if self.is_compressed(): + return False + if strict and self.is_chunked(): + return False return True def on_data(self): self.owner.on_data_http(self.connection_ctx, self.parser_ctx) @contextlib.contextmanager - def ctx_request(self, args = None, kwargs = None): + def ctx_request(self, args=None, kwargs=None): yield @property @@ -505,31 +432,32 @@ def connection_ctx(self): def parser_ctx(self): return self.parser - def _flush_plain(self, stream = None, callback = None): - if not callback: return - self.send_plain(b"", stream = stream, callback = callback) + def _flush_plain(self, stream=None, callback=None): + if not callback: + return + self.send_plain(b"", stream=stream, callback=callback) - def _flush_chunked(self, stream = None, callback = None): - self.send_plain(b"0\r\n\r\n", stream = stream, callback = callback) + def _flush_chunked(self, stream=None, callback=None): + self.send_plain(b"0\r\n\r\n", stream=stream, callback=callback) - def _flush_gzip(self, stream = None, callback = None): + def _flush_gzip(self, stream=None, callback=None): # tries to retrieve a possible gzip object for the current # stream, note that no gzip object is going to be created # in case there's none defined for the current stream - gzip = self._get_gzip(stream, ensure = False) + gzip = self._get_gzip(stream, ensure=False) # in case the gzip structure has not been initialized # (no data sent) no need to run the flushing of the # gzip data, so only the chunked part is flushed if not gzip: - self._flush_chunked(stream = stream, callback = callback) + self._flush_chunked(stream=stream, callback=callback) return # flushes the internal zlib buffers to be able to retrieve # the data pending to be sent to the client and then sends # it using the chunked encoding strategy data_c = gzip.flush(zlib.Z_FINISH) - self.send_chunked(data_c, stream = stream, final = False) + self.send_chunked(data_c, stream=stream, final=False) # resets the gzip values to the original ones so that new # requests will starts the information from the beginning @@ -538,14 +466,15 @@ def _flush_gzip(self, stream = None, callback = None): # runs the flush operation for the underlying chunked encoding # layer so that the client is correctly notified about the # end of the current request (normal operation) - self._flush_chunked(stream = stream, callback = callback) + self._flush_chunked(stream=stream, callback=callback) - def _get_gzip(self, stream, level = 6, ensure = True): + def _get_gzip(self, stream, level=6, ensure=True): # tries to retrieve the proper gzip object for the requested # stream and in case there's one or if the ensure flag is set # the retrieved value is returned to the caller method gzip = self.gzip_m.get(stream, None) - if gzip or not ensure: return gzip + if gzip or not ensure: + return gzip # in case this is the first sending a new compress object # is created with the requested compress level, notice that @@ -565,10 +494,11 @@ def _set_gzip(self, stream, gzip): def _unset_gzip(self, stream): del self.gzip_m[stream] - def _close_gzip(self, safe = True): + def _close_gzip(self, safe=True): # in case the gzip object is not defined returns the control # to the caller method immediately (nothing to be done) - if not self.gzip_m: return + if not self.gzip_m: + return # saves the current gzip object map locally (releasing reference) # and then recreates a new empty dictionary on the other object @@ -587,7 +517,9 @@ def _close_gzip(self, safe = True): except Exception: # in case the safe flag is not set re-raises the exception # to the caller stack (as expected by the callers) - if not safe: raise + if not safe: + raise + class HTTPServer(netius.StreamServer): """ @@ -596,13 +528,11 @@ class HTTPServer(netius.StreamServer): headers and read of data. """ - BASE_HEADERS = { - "Server" : netius.IDENTIFIER - } + BASE_HEADERS = {"Server": netius.IDENTIFIER} """ The map containing the complete set of headers that are meant to be applied to all the responses """ - def __init__(self, encoding = "plain", common_log = None, *args, **kwargs): + def __init__(self, encoding="plain", common_log=None, *args, **kwargs): netius.StreamServer.__init__(self, *args, **kwargs) self.encoding_s = encoding self.common_log = common_log @@ -613,81 +543,45 @@ def __init__(self, encoding = "plain", common_log = None, *args, **kwargs): def build_data( cls, text, - url = None, - trace = False, - style = True, - style_urls = [], - encode = True, - encoding = "utf-8" + url=None, + trace=False, + style=True, + style_urls=[], + encode=True, + encoding="utf-8", ): - if url: return cls.build_iframe( - text, - url, - style = style, - encode = encode, - encoding = encoding - ) - else: return cls.build_text( - text, - trace = trace, - style = style, - encode = encode, - encoding = encoding - ) + if url: + return cls.build_iframe( + text, url, style=style, encode=encode, encoding=encoding + ) + else: + return cls.build_text( + text, trace=trace, style=style, encode=encode, encoding=encoding + ) @classmethod def build_text( - cls, - text, - trace = False, - style = True, - style_urls = [], - encode = True, - encoding = "utf-8" + cls, text, trace=False, style=True, style_urls=[], encode=True, encoding="utf-8" ): data = "".join( - cls._gen_text( - text, - trace = trace, - style = style, - style_urls = style_urls - ) - ) - if encode: data = netius.legacy.bytes( - data, - encoding = encoding, - force = True + cls._gen_text(text, trace=trace, style=style, style_urls=style_urls) ) + if encode: + data = netius.legacy.bytes(data, encoding=encoding, force=True) return data @classmethod def build_iframe( - cls, - text, - url, - style = True, - style_urls = [], - encode = True, - encoding = "utf-8" + cls, text, url, style=True, style_urls=[], encode=True, encoding="utf-8" ): - data = "".join( - cls._gen_iframe( - text, - url, - style = style, - style_urls = style_urls - ) - ) - if encode: data = netius.legacy.bytes( - data, - encoding = encoding, - force = True - ) + data = "".join(cls._gen_iframe(text, url, style=style, style_urls=style_urls)) + if encode: + data = netius.legacy.bytes(data, encoding=encoding, force=True) return data @classmethod - def _gen_text(cls, text, trace = False, style = True, style_urls = []): - for value in cls._gen_header(text, style = style, style_urls = style_urls): + def _gen_text(cls, text, trace=False, style=True, style_urls=[]): + for value in cls._gen_header(text, style=style, style_urls=style_urls): yield value yield "" @@ -695,8 +589,9 @@ def _gen_text(cls, text, trace = False, style = True, style_urls = []): if trace: lines = traceback.format_exc().splitlines() yield "
" - yield "
" - for line in lines: yield "
%s
" % line + yield '
' + for line in lines: + yield "
%s
" % line yield "
" yield "
" yield "" @@ -704,32 +599,35 @@ def _gen_text(cls, text, trace = False, style = True, style_urls = []): yield "" yield "" - for value in cls._gen_footer(): yield value + for value in cls._gen_footer(): + yield value @classmethod - def _gen_iframe(cls, text, url, style = True, style_urls = []): - for value in cls._gen_header(text, style = style, style_urls = style_urls): + def _gen_iframe(cls, text, url, style=True, style_urls=[]): + for value in cls._gen_header(text, style=style, style_urls=style_urls): yield value yield "" - yield "" % url + yield '' % url yield "" - for value in cls._gen_footer(): yield value + for value in cls._gen_footer(): + yield value @classmethod - def _gen_header(cls, title, meta = True, style = True, style_urls = []): + def _gen_header(cls, title, meta=True, style=True, style_urls=[]): yield "" yield "" yield "" if meta: - yield "" - yield "" + yield '' + yield '' yield "%s" % title if style: - for value in cls._gen_style(): yield value + for value in cls._gen_style(): + yield value for style_url in style_urls: - yield "" % style_url + yield '' % style_url yield "" @classmethod @@ -743,11 +641,12 @@ def _gen_style(cls): def cleanup(self): netius.StreamServer.cleanup(self) - if self.common_file: self.common_file.close() + if self.common_file: + self.common_file.close() - def info_dict(self, full = False): - info = netius.StreamServer.info_dict(self, full = full) - info.update(encoding_s = self.encoding_s) + def info_dict(self, full=False): + info = netius.StreamServer.info_dict(self, full=full) + info.update(encoding_s=self.encoding_s) return info def on_data(self, connection, data): @@ -756,20 +655,20 @@ def on_data(self, connection, data): def on_serve(self): netius.StreamServer.on_serve(self) - if self.env: self.encoding_s = self.get_env("ENCODING", self.encoding_s) - if self.env: self.common_log = self.get_env("COMMON_LOG", self.common_log) - if self.common_log: self.common_file = open(self.common_log, "wb+") + if self.env: + self.encoding_s = self.get_env("ENCODING", self.encoding_s) + if self.env: + self.common_log = self.get_env("COMMON_LOG", self.common_log) + if self.common_log: + self.common_file = open(self.common_log, "wb+") self.encoding = ENCODING_MAP.get(self.encoding_s, PLAIN_ENCODING) self.info("Starting HTTP server with '%s' encoding ..." % self.encoding_s) - if self.common_log: self.info("Logging with Common Log Format to '%s' ..." % self.common_log) + if self.common_log: + self.info("Logging with Common Log Format to '%s' ..." % self.common_log) - def build_connection(self, socket, address, ssl = False): + def build_connection(self, socket, address, ssl=False): return HTTPConnection( - owner = self, - socket = socket, - address = address, - ssl = ssl, - encoding = self.encoding + owner=self, socket=socket, address=address, ssl=ssl, encoding=self.encoding ) def on_data_http(self, connection, parser): @@ -778,108 +677,113 @@ def on_data_http(self, connection, parser): connection.resolve_encoding(parser) def on_send_http( - self, - connection, - parser, - headers = None, - version = None, - code = 200, - code_s = None + self, connection, parser, headers=None, version=None, code=200, code_s=None ): self.common_file and self._log_request( connection, parser, - headers = headers, - version = version, - code = code, - code_s = code_s, - output = self._write_common, - mode = "common" + headers=headers, + version=version, + code=code, + code_s=code_s, + output=self._write_common, + mode="common", ) - def on_flush_http(self, connection, parser, encoding = None): + def on_flush_http(self, connection, parser, encoding=None): self.debug( - "Connection '%s' %s from '%s' flushed" %\ - (connection.id, connection.address, self.name) + "Connection '%s' %s from '%s' flushed" + % (connection.id, connection.address, self.name) ) - def authorize(self, connection, parser, auth = None, **kwargs): + def authorize(self, connection, parser, auth=None, **kwargs): # determines the proper authorization method to be used # taking into account either the provided method or the # default one in case none is provided auth = auth or netius.PasswdAuth - if hasattr(auth, "auth"): auth_method = auth.auth - else: auth_method = auth.auth_i - if hasattr(auth, "is_simple"): is_simple = auth.is_simple() - else: is_simple = auth.is_simple_i() + if hasattr(auth, "auth"): + auth_method = auth.auth + else: + auth_method = auth.auth_i + if hasattr(auth, "is_simple"): + is_simple = auth.is_simple() + else: + is_simple = auth.is_simple_i() # constructs a dictionary that contains extra information # about the current connection/request that may be used # to further determine if the request is authorized kwargs = dict( - connection = connection, - parser = parser, - host = connection.address[0], - port = connection.address[1], - headers = parser.headers + connection=connection, + parser=parser, + host=connection.address[0], + port=connection.address[1], + headers=parser.headers, ) # in case the current authentication method is considered # simples (no "classic" username and password) the named # arguments dictionary is provided as the only input - if is_simple: return auth_method(**kwargs) + if is_simple: + return auth_method(**kwargs) # retrieves the authorization tuple (username and password) # using the current parser and verifies if at least one of # them is defined in case it's not returns an invalid result username, password = self._authorization(parser) - if not username and not password: return False + if not username and not password: + return False # uses the provided username and password to run the authentication # process using the method associated with the authorization structure return auth_method(username, password, **kwargs) def _apply_all( - self, - parser, - connection, - headers, - upper = True, - normalize = False, - replace = False + self, parser, connection, headers, upper=True, normalize=False, replace=False ): - if upper: self._headers_upper(headers) - if normalize: self._headers_normalize(headers) - self._apply_base(headers, replace = replace) - self._apply_parser(parser, headers, replace = replace) + if upper: + self._headers_upper(headers) + if normalize: + self._headers_normalize(headers) + self._apply_base(headers, replace=replace) + self._apply_parser(parser, headers, replace=replace) self._apply_connection(connection, headers) - def _apply_base(self, headers, replace = False): + def _apply_base(self, headers, replace=False): cls = self.__class__ for key, value in netius.legacy.iteritems(cls.BASE_HEADERS): - if not replace and key in headers: continue + if not replace and key in headers: + continue headers[key] = value - def _apply_parser(self, parser, headers, replace = False): - if not replace and "Connection" in headers: return - if parser.keep_alive: headers["Connection"] = "keep-alive" - else: headers["Connection"] = "close" + def _apply_parser(self, parser, headers, replace=False): + if not replace and "Connection" in headers: + return + if parser.keep_alive: + headers["Connection"] = "keep-alive" + else: + headers["Connection"] = "close" - def _apply_connection(self, connection, headers, strict = True): + def _apply_connection(self, connection, headers, strict=True): is_chunked = connection.is_chunked() is_gzip = connection.is_gzip() is_deflate = connection.is_deflate() is_compressed = connection.is_compressed() - is_measurable = connection.is_measurable(strict = strict) + is_measurable = connection.is_measurable(strict=strict) has_length = "Content-Length" in headers has_ranges = "Accept-Ranges" in headers - if is_chunked: headers["Transfer-Encoding"] = "chunked" - if is_gzip: headers["Content-Encoding"] = "gzip" - if is_deflate: headers["Content-Encoding"] = "deflate" + if is_chunked: + headers["Transfer-Encoding"] = "chunked" + if is_gzip: + headers["Content-Encoding"] = "gzip" + if is_deflate: + headers["Content-Encoding"] = "deflate" - if not is_measurable and has_length: del headers["Content-Length"] - if is_compressed and has_ranges: del headers["Accept-Ranges"] + if not is_measurable and has_length: + del headers["Content-Length"] + if is_compressed and has_ranges: + del headers["Accept-Ranges"] def _headers_upper(self, headers): for key, value in netius.legacy.items(headers): @@ -889,7 +793,8 @@ def _headers_upper(self, headers): def _headers_normalize(self, headers): for key, value in netius.legacy.items(headers): - if not type(value) in (list, tuple): continue + if not type(value) in (list, tuple): + continue headers[key] = ";".join(value) def _authorization(self, parser): @@ -898,7 +803,8 @@ def _authorization(self, parser): # an invalid value in case no header is defined headers = parser.headers authorization = headers.get("authorization", None) - if not authorization: return None, None + if not authorization: + return None, None # splits the authorization token between the realm and the # token value (decoding it afterwards) and then unpacks the @@ -912,12 +818,8 @@ def _authorization(self, parser): # and the password associated with the authorization return username, password - def _write_common(self, message, encoding = "utf-8"): - message = netius.legacy.bytes( - message, - encoding = encoding, - force = True - ) + def _write_common(self, message, encoding="utf-8"): + message = netius.legacy.bytes(message, encoding=encoding, force=True) self.common_file.write(message + b"\n") self.common_file.flush() @@ -926,7 +828,7 @@ def _log_request(self, connection, parser, *args, **kwargs): method = getattr(self, "_log_request_" + mode) return method(connection, parser, *args, **kwargs) - def _log_request_basic(self, connection, parser, output = None): + def _log_request_basic(self, connection, parser, output=None): # runs the defaulting operation on the logger output # method so that the default logger output is used instead output = output or self.debug @@ -949,13 +851,13 @@ def _log_request_common( self, connection, parser, - headers = None, - version = None, - code = 200, - code_s = None, - size_s = None, - username = "frank", - output = None + headers=None, + version=None, + code=200, + code_s=None, + size_s=None, + username="frank", + output=None, ): # runs the defaulting operation on the logger output # method so that the default logger output is used instead @@ -977,7 +879,15 @@ def _log_request_common( # creates the complete message in the "Common Log Format" and then # prints a debug message with that same contents - message = "%s %s %s [%s] \"%s %s %s\" %d %s" % ( - ip_address, "-", username, date_s, method, path, version_s, code, size_s + message = '%s %s %s [%s] "%s %s %s" %d %s' % ( + ip_address, + "-", + username, + date_s, + method, + path, + version_s, + code, + size_s, ) output(message) diff --git a/src/netius/servers/http2.py b/src/netius/servers/http2.py index 255215603..60a9a6c5b 100644 --- a/src/netius/servers/http2.py +++ b/src/netius/servers/http2.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -44,14 +35,15 @@ from . import http + class HTTP2Connection(http.HTTPConnection): def __init__( self, - legacy = True, - window = netius.common.HTTP2_WINDOW, - settings = netius.common.HTTP2_SETTINGS_OPTIMAL, - settings_r = netius.common.HTTP2_SETTINGS, + legacy=True, + window=netius.common.HTTP2_WINDOW, + settings=netius.common.HTTP2_SETTINGS_OPTIMAL, + settings_r=netius.common.HTTP2_SETTINGS, *args, **kwargs ): @@ -70,34 +62,33 @@ def __init__( def open(self, *args, **kwargs): http.HTTPConnection.open(self, *args, **kwargs) - if not self.is_open(): return - if not self.legacy: self.set_h2() + if not self.is_open(): + return + if not self.legacy: + self.set_h2() - def info_dict(self, full = False): - info = http.HTTPConnection.info_dict(self, full = full) + def info_dict(self, full=False): + info = http.HTTPConnection.info_dict(self, full=full) info.update( - legacy = self.legacy, - window = self.window, - window_o = self.window_o, - window_l = self.window_l, - window_t = self.window_t, - frames = len(self.frames) + legacy=self.legacy, + window=self.window, + window_o=self.window_o, + window_l=self.window_l, + window_t=self.window_t, + frames=len(self.frames), ) return info - def flush_s(self, stream = None, callback = None): + def flush_s(self, stream=None, callback=None): return self.send_part( - b"", - stream = stream, - final = True, - flush = True, - callback = callback + b"", stream=stream, final=True, flush=True, callback=callback ) def set_h2(self): self.legacy = False - if self.parser: self.parser.destroy() - self.parser = netius.common.HTTP2Parser(self, store = True) + if self.parser: + self.parser.destroy() + self.parser = netius.common.HTTP2Parser(self, store=True) self.parser.bind("on_data", self.on_data) self.parser.bind("on_header", self.on_header) self.parser.bind("on_payload", self.on_payload) @@ -114,15 +105,14 @@ def set_h2(self): def parse(self, data): if not self.legacy and not self.preface: data = self.parse_preface(data) - if not data: return + if not data: + return try: return self.parser.parse(data) except netius.ParserError as error: - if not self.legacy: raise - self.send_response( - code = error.code, - apply = True - ) + if not self.legacy: + raise + self.send_response(code=error.code, apply=True) def parse_preface(self, data): """ @@ -150,7 +140,8 @@ def parse_preface(self, data): self.preface_b += data preface_l = len(netius.common.HTTP2_PREFACE) is_size = len(self.preface_b) >= preface_l - if not is_size: return None + if not is_size: + return None # retrieves the preface string from the buffer (according to size) # and runs the string based verification, raising an exception in @@ -176,33 +167,18 @@ def parse_preface(self, data): # parsed by any extra operation return data - def send_plain( - self, - data, - stream = None, - final = True, - delay = True, - callback = None - ): - if self.legacy: return http.HTTPConnection.send_plain( - self, - data, - stream = stream, - final = final, - delay = delay, - callback = callback - ) + def send_plain(self, data, stream=None, final=True, delay=True, callback=None): + if self.legacy: + return http.HTTPConnection.send_plain( + self, data, stream=stream, final=final, delay=delay, callback=callback + ) # verifies if the data should be fragmented for the provided # stream and if that's not required send the required data # straight away with any required splitting/fragmentation of it if not self.fragmentable_stream(stream, data): return self.send_data( - data, - stream = stream, - end_stream = final, - delay = delay, - callback = callback + data, stream=stream, end_stream=final, delay=delay, callback=callback ) # sends the same data but using a fragmented approach where the @@ -210,45 +186,19 @@ def send_plain( # frame size, this is required to overcome limitations in the connection # that has been established with the other peer return self.send_fragmented( - data, - stream = stream, - final = final, - delay = delay, - callback = callback + data, stream=stream, final=final, delay=delay, callback=callback ) - def send_chunked( - self, - data, - stream = None, - final = True, - delay = True, - callback = None - ): - if self.legacy: return http.HTTPConnection.send_chunked( - self, - data, - stream = stream, - final = final, - delay = delay, - callback = callback - ) + def send_chunked(self, data, stream=None, final=True, delay=True, callback=None): + if self.legacy: + return http.HTTPConnection.send_chunked( + self, data, stream=stream, final=final, delay=delay, callback=callback + ) return self.send_plain( - data, - stream = stream, - final = final, - delay = delay, - callback = callback + data, stream=stream, final=final, delay=delay, callback=callback ) - def send_fragmented( - self, - data, - stream = None, - final = True, - delay = True, - callback = None - ): + def send_fragmented(self, data, stream=None, final=True, delay=True, callback=None): count = 0 fragments = self.fragment_stream(stream, data) fragments = list(fragments) @@ -260,51 +210,49 @@ def send_fragmented( if is_last: count += self.send_data( fragment, - stream = stream, - end_stream = final, - delay = delay, - callback = callback + stream=stream, + end_stream=final, + delay=delay, + callback=callback, ) else: count += self.send_data( - fragment, - stream = stream, - end_stream = False, - delay = delay + fragment, stream=stream, end_stream=False, delay=delay ) return count def send_response( self, - data = None, - headers = None, - version = None, - code = 200, - code_s = None, - apply = False, - stream = None, - final = True, - flush = True, - delay = True, - callback = None + data=None, + headers=None, + version=None, + code=200, + code_s=None, + apply=False, + stream=None, + final=True, + flush=True, + delay=True, + callback=None, ): # in case the legacy mode is enabled the send response call is # forwarded to the upper layers so that it's handled properly - if self.legacy: return http.HTTPConnection.send_response( - self, - data = data, - headers = headers, - version = version, - code = code, - code_s = code_s, - apply = apply, - stream = stream, - final = final, - flush = flush, - delay = delay, - callback = callback - ) + if self.legacy: + return http.HTTPConnection.send_response( + self, + data=data, + headers=headers, + version=version, + code=code, + code_s=code_s, + apply=apply, + stream=stream, + final=final, + flush=flush, + delay=delay, + callback=callback, + ) # retrieves the various parts that define the response # and runs a series of normalization processes to retrieve @@ -329,53 +277,51 @@ def send_response( # in case the apply flag is set the apply all operation is performed # so that a series of headers are applied to the current context # (things like the name of the server connection, etc) - if apply: self.owner._apply_all(self.parser, self, headers) + if apply: + self.owner._apply_all(self.parser, self, headers) # sends the initial headers data (including status line), this should # trigger the initial data sent to the peer/client count = self.send_header( - headers = headers, - version = version, - code = code, - code_s = code_s, - stream = stream + headers=headers, version=version, code=code, code_s=code_s, stream=stream ) # sends the part/payload information (data) to the client and optionally # flushes the current internal buffers to enforce sending of the value count += self.send_part( data, - stream = stream, - final = final, - flush = flush, - delay = delay, - callback = callback + stream=stream, + final=final, + flush=flush, + delay=delay, + callback=callback, ) return count def send_header( self, - headers = None, - version = None, - code = 200, - code_s = None, - stream = None, - final = False, - delay = True, - callback = None + headers=None, + version=None, + code=200, + code_s=None, + stream=None, + final=False, + delay=True, + callback=None, ): # in case the legacy mode is enabled the send header call is # forwarded to the upper layers so that it's handled properly - if self.legacy: return http.HTTPConnection.send_header( - self, - headers = headers, - version = version, - code = code, - code_s = code_s, - stream = stream, - delay = delay, - callback = callback - ) + if self.legacy: + return http.HTTPConnection.send_header( + self, + headers=headers, + version=version, + code=code, + code_s=code_s, + stream=stream, + delay=delay, + callback=callback, + ) # verifies if the headers value has been provided and in case it # has not creates a new empty dictionary (runtime compatibility) @@ -394,9 +340,12 @@ def send_header( # them and add them to the currently defined base list for key, value in netius.legacy.iteritems(headers): key = netius.common.header_down(key) - if key in ("connection", "transfer-encoding"): continue - if not isinstance(value, list): value = (value,) - for _value in value: headers_b.append((key, _value)) + if key in ("connection", "transfer-encoding"): + continue + if not isinstance(value, list): + value = (value,) + for _value in value: + headers_b.append((key, _value)) # verifies if this is considered to be the final operation in the stream # and if that's the case creates a new callback for the closing of the @@ -405,17 +354,13 @@ def send_header( old_callback = callback def callback(connection): - self.close_stream(stream, final = final) + self.close_stream(stream, final=final) old_callback and old_callback(connection) # runs the send headers operations that should send the headers list # to the other peer and returns the number of bytes sent count = self.send_headers( - headers_b, - end_stream = final, - stream = stream, - delay = delay, - callback = callback + headers_b, end_stream=final, stream=stream, delay=delay, callback=callback ) # "notifies" the owner of the connection that the headers have been @@ -423,10 +368,10 @@ def callback(connection): self.owner.on_send_http( self.connection_ctx, self.parser_ctx, - headers = headers, - version = version, - code = code, - code_s = code_s + headers=headers, + version=version, + code=code, + code_s=code_s, ) # returns the final number of bytes that have been sent during the current @@ -434,23 +379,18 @@ def callback(connection): return count def send_part( - self, - data, - stream = None, - final = True, - flush = False, - delay = True, - callback = None + self, data, stream=None, final=True, flush=False, delay=True, callback=None ): - if self.legacy: return http.HTTPConnection.send_part( - self, - data, - stream = stream, - final = final, - flush = flush, - delay = delay, - callback = callback - ) + if self.legacy: + return http.HTTPConnection.send_part( + self, + data, + stream=stream, + final=final, + flush=flush, + delay=delay, + callback=callback, + ) # verifies if this is considered to be the final operation in the stream # and if that's the case creates a new callback for the closing of the @@ -459,7 +399,7 @@ def send_part( old_callback = callback def callback(connection): - self.close_stream(stream, final = final) + self.close_stream(stream, final=final) old_callback and old_callback(connection) # verifies if the current connection/stream is flushed meaning that it requires @@ -469,52 +409,34 @@ def callback(connection): flush = flush and self.is_flushed() if flush: - count = self.send_base( - data, - stream = stream, - final = False - ) - self.flush(stream = stream, callback = callback) + count = self.send_base(data, stream=stream, final=False) + self.flush(stream=stream, callback=callback) else: count = self.send_base( - data, - stream = stream, - final = final, - delay = delay, - callback = callback + data, stream=stream, final=final, delay=delay, callback=callback ) return count def send_frame( - self, - type = 0x01, - flags = 0x00, - payload = b"", - stream = 0x00, - delay = True, - callback = None + self, type=0x01, flags=0x00, payload=b"", stream=0x00, delay=True, callback=None ): size = len(payload) size_h = size >> 16 - size_l = size & 0xffff + size_l = size & 0xFFFF header = struct.pack("!BHBBI", size_h, size_l, type, flags, stream) message = header + payload self.owner.on_send_http2(self, self.parser, type, flags, payload, stream) - return self.send(message, delay = delay, callback = callback) + return self.send(message, delay=delay, callback=callback) def send_data( - self, - data = b"", - end_stream = True, - stream = None, - delay = True, - callback = None + self, data=b"", end_stream=True, stream=None, delay=True, callback=None ): # builds the flags byte taking into account the various # options that have been passed to the sending of data flags = 0x00 data_l = len(data) - if end_stream: flags |= 0x01 + if end_stream: + flags |= 0x01 # builds the callback clojure so that the connection state # is properly updated upon the sending of data @@ -525,29 +447,29 @@ def send_data( # the sending of the frame to when the stream becomes available if not self.available_stream(stream, data_l): count = self.delay_frame( - type = netius.common.DATA, - flags = flags, - payload = data, - stream = stream, - delay = delay, - callback = callback + type=netius.common.DATA, + flags=flags, + payload=data, + stream=stream, + delay=delay, + callback=callback, ) self.try_unavailable(stream) return count # runs the increments remove window value, decrementing the window # by the size of the data being sent - self.increment_remote(stream, data_l * -1, all = True) + self.increment_remote(stream, data_l * -1, all=True) # runs the "proper" sending of the data frame, registering the callback # with the expected clojure count = self.send_frame( - type = netius.common.DATA, - flags = flags, - payload = data, - stream = stream, - delay = delay, - callback = callback + type=netius.common.DATA, + flags=flags, + payload=data, + stream=stream, + delay=delay, + callback=callback, ) # runs the try unavailable method to verify if the stream did became @@ -560,89 +482,77 @@ def send_data( def send_headers( self, - headers = [], - end_stream = False, - end_headers = True, - stream = None, - delay = True, - callback = None + headers=[], + end_stream=False, + end_headers=True, + stream=None, + delay=True, + callback=None, ): flags = 0x00 - if end_stream: flags |= 0x01 - if end_headers: flags |= 0x04 + if end_stream: + flags |= 0x01 + if end_headers: + flags |= 0x04 payload = self.parser.encoder.encode(headers) return self.send_frame( - type = netius.common.HEADERS, - flags = flags, - payload = payload, - stream = stream, - delay = delay, - callback = callback + type=netius.common.HEADERS, + flags=flags, + payload=payload, + stream=stream, + delay=delay, + callback=callback, ) - def send_rst_stream( - self, - error_code = 0x00, - stream = None, - delay = True, - callback = None - ): + def send_rst_stream(self, error_code=0x00, stream=None, delay=True, callback=None): payload = struct.pack("!I", error_code) return self.send_frame( - type = netius.common.RST_STREAM, - payload = payload, - stream = stream, - delay = delay, - callback = callback + type=netius.common.RST_STREAM, + payload=payload, + stream=stream, + delay=delay, + callback=callback, ) - def send_settings( - self, - settings = (), - ack = False, - delay = True, - callback = None - ): + def send_settings(self, settings=(), ack=False, delay=True, callback=None): flags = 0x00 - if ack: flags |= 0x01 + if ack: + flags |= 0x01 buffer = [] for ident, value in settings: setting_s = struct.pack("!HI", ident, value) buffer.append(setting_s) payload = b"".join(buffer) return self.send_frame( - type = netius.common.SETTINGS, - flags = flags, - payload = payload, - delay = delay, - callback = callback + type=netius.common.SETTINGS, + flags=flags, + payload=payload, + delay=delay, + callback=callback, ) def send_ping( - self, - opaque = b"\0\0\0\0\0\0\0\0", - ack = False, - delay = True, - callback = None + self, opaque=b"\0\0\0\0\0\0\0\0", ack=False, delay=True, callback=None ): flags = 0x00 - if ack: flags |= 0x01 + if ack: + flags |= 0x01 return self.send_frame( - type = netius.common.PING, - flags = flags, - payload = opaque, - delay = delay, - callback = callback + type=netius.common.PING, + flags=flags, + payload=opaque, + delay=delay, + callback=callback, ) def send_goaway( self, - last_stream = 0x00, - error_code = 0x00, - message = "", - close = True, - delay = True, - callback = None + last_stream=0x00, + error_code=0x00, + message="", + close=True, + delay=True, + callback=None, ): if close: old_callback = callback @@ -655,33 +565,29 @@ def callback(connection): payload = struct.pack("!II", last_stream, error_code) payload += message return self.send_frame( - type = netius.common.GOAWAY, - payload = payload, - delay = delay, - callback = callback + type=netius.common.GOAWAY, payload=payload, delay=delay, callback=callback ) - def send_window_update( - self, - increment = 0, - stream = None, - delay = True, - callback = None - ): + def send_window_update(self, increment=0, stream=None, delay=True, callback=None): payload = struct.pack("!I", increment) return self.send_frame( - type = netius.common.WINDOW_UPDATE, - payload = payload, - stream = stream, - delay = delay, - callback = callback + type=netius.common.WINDOW_UPDATE, + payload=payload, + stream=stream, + delay=delay, + callback=callback, ) def send_delta(self): - delta = self.window_l -\ - netius.common.HTTP2_SETTINGS[netius.common.http2.SETTINGS_INITIAL_WINDOW_SIZE] - if delta == 0: return - self.send_window_update(increment = delta, stream = 0x00) + delta = ( + self.window_l + - netius.common.HTTP2_SETTINGS[ + netius.common.http2.SETTINGS_INITIAL_WINDOW_SIZE + ] + ) + if delta == 0: + return + self.send_window_update(increment=delta, stream=0x00) def delay_frame(self, *args, **kwargs): # retrieves the reference to the stream identifier for which @@ -700,7 +606,7 @@ def delay_frame(self, *args, **kwargs): # "immediately" by this method return 0 - def flush_frames(self, all = True): + def flush_frames(self, all=True): """ Runs the flush operation on the delayed/pending frames, meaning that the window/availability tests are going to be run, checking @@ -753,14 +659,15 @@ def flush_frames(self, all = True): # retrieves the reference to the stream object from the # identifier of the stream, this may an invalid/unset value - _stream = self.parser._get_stream(stream, strict = False) + _stream = self.parser._get_stream(stream, strict=False) # verifies if the current stream to be flushed is still # open and if that's not the case removes the frame from # the frames queue and skips the current iteration if not _stream or not _stream.is_open(): self.frames.pop(offset) - if _stream: _stream.frames -= 1 + if _stream: + _stream.frames -= 1 continue # makes sure that the stream is currently marked as not available @@ -776,8 +683,9 @@ def flush_frames(self, all = True): # all flush operation is enabled in which the stream is marked # as starved and the current iteration is skipped trying to # flush frames from different streams - available = self.available_stream(stream, payload_l, strict = False) - if not available and not all: return False + available = self.available_stream(stream, payload_l, strict=False) + if not available and not all: + return False if not available and all: starved[stream] = True offset += 1 @@ -790,7 +698,7 @@ def flush_frames(self, all = True): # decrements the current stream window by the size of the payload # and then runs the send frame operation for the pending frame - self.increment_remote(stream, payload_l * -1, all = True) + self.increment_remote(stream, payload_l * -1, all=True) self.send_frame(*args, **kwargs) # returns the final result with a valid value meaning that all of the @@ -816,21 +724,29 @@ def flush_available(self): def set_settings(self, settings): self.settings_r.update(settings) - def close_stream(self, stream, final = False, flush = False, reset = False): - if not self.parser._has_stream(stream): return + def close_stream(self, stream, final=False, flush=False, reset=False): + if not self.parser._has_stream(stream): + return stream = self.parser._get_stream(stream) - if not stream: return + if not stream: + return stream.end_stream_l = final - stream.close(flush = flush, reset = reset) + stream.close(flush=flush, reset=reset) - def available_stream(self, stream, length, strict = True): - if self.window == 0: return False - if self.window < length: return False + def available_stream(self, stream, length, strict=True): + if self.window == 0: + return False + if self.window < length: + return False stream = self.parser._get_stream(stream) - if not stream: return True - if stream.window == 0: return False - if stream.window < length: return False - if strict and stream.frames: return False + if not stream: + return True + if stream.window == 0: + return False + if stream.window < length: + return False + if strict and stream.frames: + return False return True def fragment_stream(self, stream, data): @@ -842,11 +758,12 @@ def fragmentable_stream(self, stream, data): return stream.fragmentable(data) def open_stream(self, stream): - stream = self.parser._get_stream(stream, strict = False) - if not stream : return False + stream = self.parser._get_stream(stream, strict=False) + if not stream: + return False return True if stream and stream.is_open() else False - def try_available(self, stream, strict = True): + def try_available(self, stream, strict=True): """ Tries to determine if the stream with the provided identifier has just became available (unblocked from blocked state), this @@ -864,12 +781,13 @@ def try_available(self, stream, strict = True): # verifies if the stream is currently present in the map of unavailable # or blocked streams and if that's the case returns immediately as # the connection is not blocked - if not stream in self.unavailable: return + if not stream in self.unavailable: + return # tries to retrieve the stream object reference from the identifier and # in case none is retrieved (probably stream closed) returns immediately # and removes the stream from the map of unavailability - _stream = self.parser._get_stream(stream, strict = False) + _stream = self.parser._get_stream(stream, strict=False) if not _stream: del self.unavailable[stream] return @@ -877,14 +795,15 @@ def try_available(self, stream, strict = True): # tries to determine if the stream is available for the sending of at # least one byte and if that's not the case returns immediately, not # setting the stream as available - if not self.available_stream(stream, 1, strict = strict): return + if not self.available_stream(stream, 1, strict=strict): + return # removes the stream from the map of unavailable stream and "notifies" # the stream about the state changing operation to available/unblocked del self.unavailable[stream] _stream.available() - def try_unavailable(self, stream, strict = True): + def try_unavailable(self, stream, strict=True): """ Runs the unavailability test on the stream with the provided identifier meaning that a series of validation will be performed to try to determine @@ -902,25 +821,28 @@ def try_unavailable(self, stream, strict = True): # in case the stream identifier is already present in the unavailable # map it cannot be marked as unavailable again - if stream in self.unavailable: return + if stream in self.unavailable: + return # tries to retrieve the reference to the stream object to be tested # an in case none is found (connection closed) returns immediately - _stream = self.parser._get_stream(stream, strict = False) - if not _stream: return + _stream = self.parser._get_stream(stream, strict=False) + if not _stream: + return # runs the proper availability verification by testing the capacity # of the stream to send one byte and in case there's capacity to send # that byte the stream is considered available or unblocked, so the # control flow must be returned (stream not marked) - if self.available_stream(stream, 1, strict = strict): return + if self.available_stream(stream, 1, strict=strict): + return # marks the stream as unavailable and "notifies" the stream object # about the changing to the unavailable/blocked state self.unavailable[stream] = True _stream.unavailable() - def increment_remote(self, stream, increment, all = False): + def increment_remote(self, stream, increment, all=False): """ Increments the size of the remove window associated with the stream passed by argument by the size defined in the @@ -941,10 +863,13 @@ def increment_remote(self, stream, increment, all = False): should be updated by the increment operation. """ - if not stream or all: self.window += increment - if not stream: return + if not stream or all: + self.window += increment + if not stream: + return stream = self.parser._get_stream(stream) - if not stream: return + if not stream: + return stream.remote_update(increment) def increment_local(self, stream, increment): @@ -956,8 +881,7 @@ def increment_local(self, stream, increment): self.window_l += increment if self.window_l < self.window_t: self.send_window_update( - increment = self.window_o - self.window_l, - stream = 0x00 + increment=self.window_o - self.window_l, stream=0x00 ) self.window_l = self.window_o @@ -965,44 +889,40 @@ def increment_local(self, stream, increment): # provided identifier and then runs the local update # operation in it (may trigger window update flushing) stream = self.parser._get_stream(stream) - if not stream: return + if not stream: + return stream.local_update(increment) def error_connection( - self, - last_stream = 0x00, - error_code = 0x00, - message = "", - close = True, - callback = None + self, last_stream=0x00, error_code=0x00, message="", close=True, callback=None ): self.send_goaway( - last_stream = last_stream, - error_code = error_code, - message = message, - close = close, - callback = callback + last_stream=last_stream, + error_code=error_code, + message=message, + close=close, + callback=callback, ) def error_stream( self, stream, - last_stream = 0x00, - error_code = 0x00, - message = "", - close = True, - callback = None + last_stream=0x00, + error_code=0x00, + message="", + close=True, + callback=None, ): self.send_rst_stream( - error_code = error_code, - stream = stream, - callback = lambda c: self.error_connection( - last_stream = last_stream, - error_code = error_code, - message = message, - close = close, - callback = callback - ) + error_code=error_code, + stream=stream, + callback=lambda c: self.error_connection( + last_stream=last_stream, + error_code=error_code, + message=message, + close=close, + callback=callback, + ), ) def on_header(self, header): @@ -1015,10 +935,7 @@ def on_frame(self): self.owner.on_frame_http2(self, self.parser) def on_data_h2(self, stream, contents): - self.increment_local( - stream and stream.identifier, - increment = len(contents) * -1 - ) + self.increment_local(stream and stream.identifier, increment=len(contents) * -1) self.owner.on_data_http2(self, self.parser, stream, contents) def on_headers_h2(self, stream): @@ -1046,55 +963,64 @@ def on_continuation(self, stream): self.owner.on_continuation_http2(self, self.parser, stream) def is_throttleable(self): - if self.legacy: return http.HTTPConnection.is_throttleable(self) + if self.legacy: + return http.HTTPConnection.is_throttleable(self) return False @property def connection_ctx(self): - if self.legacy: return super(HTTP2Connection, self).connection_ctx - if not self.parser: return self - if not self.parser.stream_o: return self + if self.legacy: + return super(HTTP2Connection, self).connection_ctx + if not self.parser: + return self + if not self.parser.stream_o: + return self return self.parser.stream_o @property def parser_ctx(self): - if self.legacy: return super(HTTP2Connection, self).parser_ctx - if not self.parser: return None - if not self.parser.stream_o: return self.parser + if self.legacy: + return super(HTTP2Connection, self).parser_ctx + if not self.parser: + return None + if not self.parser.stream_o: + return self.parser return self.parser.stream_o def _build_c(self, callback, stream, data_l): - stream = self.parser._get_stream(stream, strict = False) - if not stream: return callback + stream = self.parser._get_stream(stream, strict=False) + if not stream: + return callback stream.pending_s += data_l old_callback = callback def callback(connection): stream.pending_s -= data_l - if not old_callback: return + if not old_callback: + return return old_callback(connection) return callback - def _flush_plain(self, stream = None, callback = None): - self.send_part(b"", stream = stream, callback = callback) + def _flush_plain(self, stream=None, callback=None): + self.send_part(b"", stream=stream, callback=callback) + + def _flush_chunked(self, stream=None, callback=None): + if self.legacy: + return http.HTTPConnection._flush_chunked( + self, stream=stream, callback=callback + ) + self._flush_plain(stream=stream, callback=callback) - def _flush_chunked(self, stream = None, callback = None): - if self.legacy: return http.HTTPConnection._flush_chunked( - self, - stream = stream, - callback = callback - ) - self._flush_plain(stream = stream, callback = callback) class HTTP2Server(http.HTTPServer): def __init__( self, - legacy = True, - safe = False, - settings = netius.common.HTTP2_SETTINGS_OPTIMAL, + legacy=True, + safe=False, + settings=netius.common.HTTP2_SETTINGS_OPTIMAL, *args, **kwargs ): @@ -1105,13 +1031,15 @@ def __init__( self.has_h2 = self._has_h2() self.has_all_h2 = self._has_all_h2() self._protocols = [] - self.safe = self.get_env("SAFE", self.safe, cast = bool) + self.safe = self.get_env("SAFE", self.safe, cast=bool) http.HTTPServer.__init__(self, *args, **kwargs) @classmethod def _has_hpack(cls): - try: import hpack #@UnusedImport - except ImportError: return False + try: + import hpack # @UnusedImport + except ImportError: + return False return True @classmethod @@ -1122,31 +1050,34 @@ def _has_alpn(cls): def _has_npn(cls): return ssl.HAS_NPN - def info_dict(self, full = False): - info = http.HTTPServer.info_dict(self, full = full) + def info_dict(self, full=False): + info = http.HTTPServer.info_dict(self, full=full) info.update( - legacy = self.legacy, - safe = self.safe, - has_h2 = self.has_h2, - has_all_h2 = self.has_all_h2 + legacy=self.legacy, + safe=self.safe, + has_h2=self.has_h2, + has_all_h2=self.has_all_h2, ) return info def get_protocols(self): - if self._protocols: return self._protocols - if not self.safe and self.has_h2: self._protocols.extend(["h2"]) - if self.legacy: self._protocols.extend(["http/1.1", "http/1.0"]) + if self._protocols: + return self._protocols + if not self.safe and self.has_h2: + self._protocols.extend(["h2"]) + if self.legacy: + self._protocols.extend(["http/1.1", "http/1.0"]) return self._protocols - def build_connection(self, socket, address, ssl = False): + def build_connection(self, socket, address, ssl=False): return HTTP2Connection( - owner = self, - socket = socket, - address = address, - ssl = ssl, - encoding = self.encoding, - legacy = self.legacy, - settings = self.settings + owner=self, + socket=socket, + address=address, + ssl=ssl, + encoding=self.encoding, + legacy=self.legacy, + settings=self.settings, ) def on_exception(self, exception, connection): @@ -1154,32 +1085,40 @@ def on_exception(self, exception, connection): return http.HTTPServer.on_exception(self, exception, connection) if not isinstance(exception, netius.NetiusError): return http.HTTPServer.on_exception(self, exception, connection) - try: self._handle_exception(exception, connection) - except Exception: connection.close() + try: + self._handle_exception(exception, connection) + except Exception: + connection.close() def on_ssl(self, connection): http.HTTPServer.on_ssl(self, connection) - if self.safe or not self.has_h2: return + if self.safe or not self.has_h2: + return protocol = connection.ssl_protocol() - if not protocol == "h2": return + if not protocol == "h2": + return connection.set_h2() def on_serve(self): http.HTTPServer.on_serve(self) safe_s = "with" if self.safe else "without" self.info("Starting HTTP2 server %s safe mode ..." % safe_s) - if not self.has_h2: self.info("No support for HTTP2 is available ...") - elif not self.has_all_h2: self.info("Limited support for HTTP2 is available ...") + if not self.has_h2: + self.info("No support for HTTP2 is available ...") + elif not self.has_all_h2: + self.info("Limited support for HTTP2 is available ...") for setting, name in netius.common.HTTP2_TUPLES: - if not self.env: continue - value = self.get_env(name, None, cast = int) - if value == None: continue + if not self.env: + continue + value = self.get_env(name, None, cast=int) + if value == None: + continue self.settings[setting] = value self.info("Setting HTTP2 setting %s with value '%d' ..." % (name, value)) self.settings_t = netius.legacy.items(self.settings) def on_preface_http2(self, connection, parser): - connection.send_settings(settings = self.settings_t) + connection.send_settings(settings=self.settings_t) connection.send_delta() def on_header_http2(self, connection, parser, header): @@ -1199,23 +1138,27 @@ def on_headers_http2(self, connection, parser, stream): pass def on_rst_stream_http2(self, connection, parser, stream, error_code): - if not stream: return + if not stream: + return stream.end_stream = True stream.end_stream_l = True - stream.close(reset = False) + stream.close(reset=False) def on_settings_http2(self, connection, parser, settings, ack): - if ack: return + if ack: + return self.debug("Received settings %s for connection" % str(settings)) connection.set_settings(dict(settings)) - connection.send_settings(ack = True) + connection.send_settings(ack=True) def on_ping_http2(self, connection, parser, opaque, ack): - if ack: return - connection.send_ping(opaque = opaque, ack = True) + if ack: + return + connection.send_ping(opaque=opaque, ack=True) def on_goaway_http2(self, connection, parser, last_stream, error_code, extra): - if error_code == 0x00: return + if error_code == 0x00: + return self._log_error(error_code, extra) def on_window_update_http2(self, connection, parser, stream, increment): @@ -1230,14 +1173,18 @@ def on_send_http2(self, connection, parser, type, flags, payload, stream): def _has_h2(self): cls = self.__class__ - if not cls._has_hpack(): return False + if not cls._has_hpack(): + return False return True def _has_all_h2(self): cls = self.__class__ - if not cls._has_hpack(): return False - if not cls._has_alpn(): return False - if not cls._has_npn(): return False + if not cls._has_hpack(): + return False + if not cls._has_alpn(): + return False + if not cls._has_npn(): + return False return True def _handle_exception(self, exception, connection): @@ -1247,82 +1194,76 @@ def _handle_exception(self, exception, connection): ignore = exception.get_kwarg("ignore", False) self.warning(exception) self.log_stack() - if ignore: return connection.send_ping(ack = True) - if stream: return connection.error_stream( - stream, - error_code = error_code, - message = message - ) - return connection.error_connection( - error_code = error_code, - message = message - ) + if ignore: + return connection.send_ping(ack=True) + if stream: + return connection.error_stream( + stream, error_code=error_code, message=message + ) + return connection.error_connection(error_code=error_code, message=message) def _log_frame(self, connection, parser): self.debug( - "Received frame 0x%02x (%s) for stream %d with length %d bytes" %\ - (parser.type, parser.type_s, parser.stream, parser.length) + "Received frame 0x%02x (%s) for stream %d with length %d bytes" + % (parser.type, parser.type_s, parser.stream, parser.length) ) self._log_frame_details( - parser, - parser.type_s, - parser.flags, - parser.payload, - parser.stream, - False + parser, parser.type_s, parser.flags, parser.payload, parser.stream, False ) def _log_error(self, error_code, extra): message = netius.legacy.str(extra) - self.warning( - "Received error 0x%02x with message '%s'" %\ - (error_code, message) - ) + self.warning("Received error 0x%02x with message '%s'" % (error_code, message)) def _log_send(self, connection, parser, type, flags, payload, stream): length = len(payload) type_s = parser.get_type_s(type) self.debug( - "Sent frame 0x%02x (%s) for stream %d with length %d bytes" %\ - (type, type_s, stream, length) + "Sent frame 0x%02x (%s) for stream %d with length %d bytes" + % (type, type_s, stream, length) ) self._log_frame_details(parser, type_s, flags, payload, stream, True) - def _log_window(self, parser, stream, remote = False): + def _log_window(self, parser, stream, remote=False): name = "SEND" if remote else "RECV" connection = parser.connection window = connection.window if remote else connection.window_l self.debug("Connection %s window size is %d bytes" % (name, window)) - stream = parser._get_stream(stream, strict = False) - if not stream: return + stream = parser._get_stream(stream, strict=False) + if not stream: + return window = stream.window if remote else stream.window_l self.debug( - "Stream %d (dependency = %d, weight = %d) %s window size is %d bytes" %\ - (stream.identifier, stream.dependency, stream.weight, name, window) + "Stream %d (dependency = %d, weight = %d) %s window size is %d bytes" + % (stream.identifier, stream.dependency, stream.weight, name, window) ) def _log_frame_details(self, parser, type_s, flags, payload, stream, out): type_l = type_s.lower() method_s = "_log_frame_" + type_l - if not hasattr(self, method_s): return + if not hasattr(self, method_s): + return method = getattr(self, method_s) method(parser, flags, payload, stream, out) def _log_frame_flags(self, type_s, *args): flags = ", ".join(args) pluralized = "flags" if len(args) > 1 else "flag" - if flags: self.debug("%s with %s %s active" % (type_s, pluralized, flags)) - else: self.debug("Frame %s with no flags active" % type_s) + if flags: + self.debug("%s with %s %s active" % (type_s, pluralized, flags)) + else: + self.debug("Frame %s with no flags active" % type_s) def _log_frame_data(self, parser, flags, payload, stream, out): - _stream = parser._get_stream(stream, strict = False) + _stream = parser._get_stream(stream, strict=False) flags_l = self._flags_l(flags, (("END_STREAM", 0x01),)) self._log_frame_flags("DATA", *flags_l) - if _stream: self.debug("Frame DATA for path '%s'" % _stream.path_s) - self._log_window(parser, stream, remote = out) + if _stream: + self.debug("Frame DATA for path '%s'" % _stream.path_s) + self._log_window(parser, stream, remote=out) def _log_frame_headers(self, parser, flags, payload, stream, out): flags_l = self._flags_l( @@ -1331,32 +1272,33 @@ def _log_frame_headers(self, parser, flags, payload, stream, out): ("END_STREAM", 0x01), ("END_HEADERS", 0x04), ("PADDED", 0x08), - ("PRIORITY", 0x20) - ) + ("PRIORITY", 0x20), + ), ) self._log_frame_flags("HEADERS", *flags_l) def _log_frame_rst_stream(self, parser, flags, payload, stream, out): - error_code, = struct.unpack("!I", payload) + (error_code,) = struct.unpack("!I", payload) self.debug("Frame RST_STREAM with error code %d" % error_code) def _log_frame_goaway(self, parser, flags, payload, stream, out): last_stream, error_code = struct.unpack("!II", payload[:8]) extra = payload[8:] self.debug( - "Frame GOAWAY with last stream %d, error code %d and message %s" %\ - (last_stream, error_code, extra) + "Frame GOAWAY with last stream %d, error code %d and message %s" + % (last_stream, error_code, extra) ) def _log_frame_window_update(self, parser, flags, payload, stream, out): - increment, = struct.unpack("!I", payload) + (increment,) = struct.unpack("!I", payload) self.debug("Frame WINDOW_UPDATE with increment %d" % increment) - self._log_window(parser, stream, remote = not out) + self._log_window(parser, stream, remote=not out) def _flags_l(self, flags, definition): flags_l = [] for name, value in definition: valid = True if flags & value else False - if not valid: continue + if not valid: + continue flags_l.append(name) return flags_l diff --git a/src/netius/servers/mjpg.py b/src/netius/servers/mjpg.py index 9908518f2..ac7bece5c 100644 --- a/src/netius/servers/mjpg.py +++ b/src/netius/servers/mjpg.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -47,6 +38,7 @@ """ The defualt boundary string value to be used in case no boundary is provided to the app """ + class MJPGServer(http2.HTTP2Server): """ Server class for the creation of an HTTP server for @@ -56,7 +48,7 @@ class MJPGServer(http2.HTTP2Server): proper implementation should be made from this. """ - def __init__(self, boundary = BOUNDARY, *args, **kwargs): + def __init__(self, boundary=BOUNDARY, *args, **kwargs): http2.HTTP2Server.__init__(self, *args, **kwargs) self.boundary = boundary @@ -67,17 +59,14 @@ def on_data_http(self, connection, parser): ("Content-type", "multipart/x-mixed-replace; boundary=%s" % self.boundary), ("Cache-Control", "no-cache"), ("Connection", "close"), - ("Pragma", "no-cache") + ("Pragma", "no-cache"), ] version_s = parser.version_s headers = dict(headers) connection.send_header( - headers = headers, - version = version_s, - code = 200, - code_s = "OK" + headers=headers, version=version_s, code=200, code_s="OK" ) def send(connection): @@ -86,8 +75,10 @@ def send(connection): delay = self.get_delay(connection) data = self.get_image(connection) - if not data: self.warning("No image retrieved from provider") - if not data: data = b"" + if not data: + self.warning("No image retrieved from provider") + if not data: + data = b"" data_l = len(data) @@ -102,10 +93,12 @@ def send(connection): buffer_d = b"".join(buffer) def next(connection): - def callable(): send(connection) + def callable(): + send(connection) + self.delay(callable, delay) - connection.send_part(buffer_d, final = False, callback = next) + connection.send_part(buffer_d, final=False, callback=next) send(connection) @@ -117,7 +110,8 @@ def get_delay(self, connection): def get_image(self, connection): has_index = hasattr(connection, "index") - if not has_index: connection.index = 0 + if not has_index: + connection.index = 0 target = connection.index % 2 connection.index += 1 @@ -126,13 +120,16 @@ def get_image(self, connection): file_path = os.path.join(extras_path, "boy_%d.jpg" % target) file = open(file_path, "rb") - try: data = file.read() - finally: file.close() + try: + data = file.read() + finally: + file.close() return data + if __name__ == "__main__": server = MJPGServer() - server.serve(env = True) + server.serve(env=True) else: __path__ = [] diff --git a/src/netius/servers/pop.py b/src/netius/servers/pop.py index b57173174..fd8f37639 100644 --- a/src/netius/servers/pop.py +++ b/src/netius/servers/pop.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -54,24 +45,19 @@ of the message to the client, this value will affect the memory used by the server and its network performance """ -CAPABILITIES = ( - "TOP", - "USER", - "STLS" -) +CAPABILITIES = ("TOP", "USER", "STLS") """ The capabilities that are going to be exposed to the client as the ones handled by the server, should only expose the ones that are properly handled by the server """ -AUTH_METHODS = ( - "PLAIN", -) +AUTH_METHODS = ("PLAIN",) """ Authentication methods that are available to be "used" by the client, should be mapped into the proper auth handlers """ + class POPConnection(netius.Connection): - def __init__(self, host = "pop.localhost", *args, **kwargs): + def __init__(self, host="pop.localhost", *args, **kwargs): netius.Connection.__init__(self, *args, **kwargs) self.parser = None self.host = host @@ -87,26 +73,34 @@ def __init__(self, host = "pop.localhost", *args, **kwargs): def open(self, *args, **kwargs): netius.Connection.open(self, *args, **kwargs) - if not self.is_open(): return + if not self.is_open(): + return self.parser = netius.common.POPParser(self) self.parser.bind("on_line", self.on_line) def close(self, *args, **kwargs): netius.Connection.close(self, *args, **kwargs) - if not self.is_closed(): return - if self.file: self.file.close(); self.file = None - if self.parser: self.parser.destroy() + if not self.is_closed(): + return + if self.file: + self.file.close() + self.file = None + if self.parser: + self.parser.destroy() def parse(self, data): - if self.state == AUTH_STATE: self.on_user(data) - else: return self.parser.parse(data) + if self.state == AUTH_STATE: + self.on_user(data) + else: + return self.parser.parse(data) - def send_pop(self, message = "", lines = (), status = "OK", delay = True, callback = None): + def send_pop(self, message="", lines=(), status="OK", delay=True, callback=None): status_s = "+" + status if status == "OK" else "-" + status base = "%s %s" % (status_s, message) data = base + "\r\n" - if lines: data += "\r\n".join(lines) + "\r\n.\r\n" - count = self.send(data, delay = delay, callback = callback) + if lines: + data += "\r\n".join(lines) + "\r\n.\r\n" + count = self.send(data, delay=delay, callback=callback) self.owner.debug(base) return count @@ -118,21 +112,22 @@ def ready(self): def starttls(self): def callback(connection): - connection.upgrade(server = True) + connection.upgrade(server=True) + message = "go ahead" - self.send_pop(message, callback = callback) + self.send_pop(message, callback=callback) self.state = HELO_STATE def capa(self): self.assert_s(HELO_STATE) message = "list follows" - self.send_pop(message, lines = CAPABILITIES) + self.send_pop(message, lines=CAPABILITIES) self.state = HELO_STATE def auth(self): self.assert_s(HELO_STATE) message = "list follows" - self.send_pop(message, lines = AUTH_METHODS) + self.send_pop(message, lines=AUTH_METHODS) self.state = HELO_STATE def accept(self): @@ -154,7 +149,7 @@ def list(self): size = self.sizes[index] line = "%d %d" % (index, size) lines.append(line) - self.send_pop(message, lines = lines) + self.send_pop(message, lines=lines) def uidl(self): self.owner.on_uidl_pop(self) @@ -164,18 +159,24 @@ def uidl(self): key = self.keys[index] line = "%d %s" % (index, key) lines.append(line) - self.send_pop(message, lines = lines) + self.send_pop(message, lines=lines) def retr(self, index): def callback(connection): - if not connection.file: return + if not connection.file: + return file = connection.file contents = file.read(CHUNK_SIZE) - if contents: self.send(contents, callback = callback) - else: self.send("\r\n.\r\n"); file.close(); connection.file = None + if contents: + self.send(contents, callback=callback) + else: + self.send("\r\n.\r\n") + file.close() + connection.file = None + self.owner.on_retr_pop(self, index) message = "%d octets" % self.size - self.send_pop(message, callback = callback) + self.send_pop(message, callback=callback) callback(self) def dele(self, index): @@ -193,7 +194,7 @@ def ok(self): def not_implemented(self): message = "not implemented" - self.send_pop(message, status = "ERR") + self.send_pop(message, status="ERR") def on_line(self, code, message): # "joins" the code and the message part of the message into the base @@ -215,7 +216,8 @@ def on_line(self, code, message): # does not raises an exception indicating the problem with the # code that has just been received (probably erroneous) exists = hasattr(self, method_n) - if not exists: raise netius.ParserError("Invalid code '%s'" % code) + if not exists: + raise netius.ParserError("Invalid code '%s'" % code) # retrieves the reference to the method that is going to be called # for the handling of the current line from the current instance and @@ -230,8 +232,10 @@ def on_capa(self, message): self.capa() def on_auth(self, message): - if message: self.accept() - else: self.auth() + if message: + self.accept() + else: + self.auth() def on_stat(self, message): self.stat() @@ -252,7 +256,7 @@ def on_dele(self, message): def on_quit(self, message): self.bye() - self.close(flush = True) + self.close(flush=True) def on_user(self, token): # adds the partial token value to the token buffer and @@ -260,7 +264,8 @@ def on_user(self, token): # case continues the parsing otherwise returns immediately self.token_buf.append(token) index = token.find(b"\n") - if index == -1: return + if index == -1: + return # removes the extra characters from the token so that no # extra value is considered to be part of the token @@ -287,18 +292,20 @@ def on_user(self, token): self.state = SESSION_STATE def assert_s(self, expected): - if self.state == expected: return + if self.state == expected: + return raise netius.ParserError("Invalid state") + class POPServer(netius.StreamServer): - def __init__(self, adapter_s = "memory", auth_s = "dummy", *args, **kwargs): + def __init__(self, adapter_s="memory", auth_s="dummy", *args, **kwargs): netius.StreamServer.__init__(self, *args, **kwargs) self.adapter_s = adapter_s self.auth_s = auth_s - def serve(self, host = "pop.localhost", port = 110, *args, **kwargs): - netius.StreamServer.serve(self, port = port, *args, **kwargs) + def serve(self, host="pop.localhost", port=110, *args, **kwargs): + netius.StreamServer.serve(self, port=port, *args, **kwargs) self.host = host def on_connection_c(self, connection): @@ -307,7 +314,9 @@ def on_connection_c(self, connection): def on_connection_d(self, connection): netius.StreamServer.on_connection_d(self, connection) - if connection.file: connection.file.close(); connection.file = None + if connection.file: + connection.file.close() + connection.file = None def on_data(self, connection, data): netius.StreamServer.on_data(self, connection, data) @@ -315,23 +324,22 @@ def on_data(self, connection, data): def on_serve(self): netius.StreamServer.on_serve(self) - if self.env: self.host = self.get_env("POP_HOST", self.host) - if self.env: self.adapter_s = self.get_env("POP_ADAPTER", self.adapter_s) - if self.env: self.auth_s = self.get_env("POP_AUTH", self.auth_s) + if self.env: + self.host = self.get_env("POP_HOST", self.host) + if self.env: + self.adapter_s = self.get_env("POP_ADAPTER", self.adapter_s) + if self.env: + self.auth_s = self.get_env("POP_AUTH", self.auth_s) self.adapter = self.get_adapter(self.adapter_s) self.auth = self.get_auth(self.auth_s) self.info( - "Starting POP server on '%s' using '%s' and '%s' ..." % - (self.host, self.adapter_s, self.auth_s) + "Starting POP server on '%s' using '%s' and '%s' ..." + % (self.host, self.adapter_s, self.auth_s) ) - def build_connection(self, socket, address, ssl = False): + def build_connection(self, socket, address, ssl=False): return POPConnection( - owner = self, - socket = socket, - address = address, - ssl = ssl, - host = self.host + owner=self, socket=socket, address=address, ssl=ssl, host=self.host ) def on_line_pop(self, connection, code, message): @@ -343,19 +351,19 @@ def on_auth_pop(self, connection, username, password): def on_stat_pop(self, connection): username = connection.username - count = self.adapter.count(owner = username) - total = self.adapter.total(owner = username) + count = self.adapter.count(owner=username) + total = self.adapter.total(owner=username) connection.count = count connection.byte_c = total def on_list_pop(self, connection): username = connection.username - sizes = self.adapter.sizes(owner = username) + sizes = self.adapter.sizes(owner=username) connection.sizes = sizes def on_uidl_pop(self, connection): username = connection.username - connection.keys = self.adapter.list(owner = username) + connection.keys = self.adapter.list(owner=username) def on_retr_pop(self, connection, index): key = connection.keys[index] @@ -365,11 +373,13 @@ def on_retr_pop(self, connection, index): def on_dele_pop(self, connection, index): username = connection.username key = connection.keys[index] - self.adapter.delete(key, owner = username) + self.adapter.delete(key, owner=username) + if __name__ == "__main__": import logging - server = POPServer(level = logging.DEBUG) - server.serve(env = True) + + server = POPServer(level=logging.DEBUG) + server.serve(env=True) else: __path__ = [] diff --git a/src/netius/servers/proxy.py b/src/netius/servers/proxy.py index 8085d49f9..5a4cd288f 100644 --- a/src/netius/servers/proxy.py +++ b/src/netius/servers/proxy.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -58,11 +49,13 @@ avoids the starvation of the producer to consumer relation that could cause memory problems """ + class ProxyConnection(http2.HTTP2Connection): def open(self, *args, **kwargs): http2.HTTP2Connection.open(self, *args, **kwargs) - if not self.is_open(): return + if not self.is_open(): + return self.parser.store = False self.parser.bind("on_headers", self.on_headers) self.parser.bind("on_partial", self.on_partial) @@ -91,21 +84,22 @@ def on_available(self): def on_unavailable(self): self.owner.on_unavailable(self.connection_ctx, self.parser_ctx) + class ProxyServer(http2.HTTP2Server): def __init__( self, - dynamic = True, - throttle = True, - trust_origin = False, - max_pending = MAX_PENDING, + dynamic=True, + throttle=True, + trust_origin=False, + max_pending=MAX_PENDING, *args, **kwargs ): http2.HTTP2Server.__init__( self, - receive_buffer_c = int(max_pending * BUFFER_RATIO), - send_buffer_c = int(max_pending * BUFFER_RATIO), + receive_buffer_c=int(max_pending * BUFFER_RATIO), + send_buffer_c=int(max_pending * BUFFER_RATIO), *args, **kwargs ) @@ -117,10 +111,10 @@ def __init__( self.conn_map = {} self.http_client = netius.clients.HTTPClient( - thread = False, - auto_release = False, - receive_buffer = max_pending, - send_buffer = max_pending, + thread=False, + auto_release=False, + receive_buffer=max_pending, + send_buffer=max_pending, *args, **kwargs ) @@ -133,9 +127,9 @@ def __init__( self.http_client.bind("error", self._on_prx_error) self.raw_client = netius.clients.RawClient( - thread = False, - receive_buffer = int(max_pending * BUFFER_RATIO), - send_buffer = int(max_pending * BUFFER_RATIO), + thread=False, + receive_buffer=int(max_pending * BUFFER_RATIO), + send_buffer=int(max_pending * BUFFER_RATIO), *args, **kwargs ) @@ -158,7 +152,8 @@ def stop(self): # verifies if there's a container object currently defined in # the object and in case it does exist propagates the stop call # to the container so that the proper stop operation is performed - if not self.container: return + if not self.container: + return self.container.stop() def cleanup(self): @@ -171,7 +166,8 @@ def cleanup(self): # verifies if the container is valid and if that's not the case # returns the control flow immediately (as expected) - if not container: return + if not container: + return # runs the cleanup operation on the cleanup, this should properly # propagate the operation to the owner container (as expected) @@ -182,24 +178,25 @@ def cleanup(self): self.http_client = None self.raw_client = None - def info_dict(self, full = False): - info = http2.HTTP2Server.info_dict(self, full = full) + def info_dict(self, full=False): + info = http2.HTTP2Server.info_dict(self, full=full) info.update( - dynamic = self.dynamic, - throttle = self.throttle, - max_pending = self.max_pending, - min_pending = self.min_pending, - http_client = self.http_client.info_dict(full = full), - raw_client = self.raw_client.info_dict(full = full) + dynamic=self.dynamic, + throttle=self.throttle, + max_pending=self.max_pending, + min_pending=self.min_pending, + http_client=self.http_client.info_dict(full=full), + raw_client=self.raw_client.info_dict(full=full), ) return info - def connections_dict(self, full = False, parent = False): - if parent: return http2.HTTP2Server.connections_dict(self, full = full) - return self.container.connections_dict(full = full) + def connections_dict(self, full=False, parent=False): + if parent: + return http2.HTTP2Server.connections_dict(self, full=full) + return self.container.connections_dict(full=full) - def connection_dict(self, id, full = False): - return self.container.connection_dict(id, full = full) + def connection_dict(self, id, full=False): + return self.container.connection_dict(id, full=full) def on_data(self, connection, data): netius.StreamServer.on_data(self, connection, data) @@ -209,19 +206,22 @@ def on_data(self, connection, data): # (initial handshake or HTTP client proxy) runs the parse # step on the data and then returns immediately tunnel_c = hasattr(connection, "tunnel_c") and connection.tunnel_c - if not tunnel_c: connection.parse(data); return + if not tunnel_c: + connection.parse(data) + return # verifies that the current size of the pending buffer is greater # than the maximum size for the pending buffer the read operations # if that the case the read operations must be disabled should_throttle = self.throttle and connection.is_throttleable() should_disable = should_throttle and tunnel_c.is_exhausted() - if should_disable: connection.disable_read() + if should_disable: + connection.disable_read() # performs the sending operation on the data but uses the throttle # callback so that the connection read operations may be resumed if # the buffer has reached certain (minimum) levels - tunnel_c.send(data, callback = self._throttle) + tunnel_c.send(data, callback=self._throttle) def on_connection_d(self, connection): http2.HTTP2Server.on_connection_d(self, connection) @@ -229,8 +229,10 @@ def on_connection_d(self, connection): tunnel_c = hasattr(connection, "tunnel_c") and connection.tunnel_c proxy_c = hasattr(connection, "proxy_c") and connection.proxy_c - if tunnel_c: tunnel_c.close() - if proxy_c: proxy_c.close() + if tunnel_c: + tunnel_c.close() + if proxy_c: + proxy_c.close() setattr(connection, "tunnel_c", None) setattr(connection, "proxy_c", None) @@ -241,105 +243,132 @@ def on_stream_d(self, stream): tunnel_c = hasattr(stream, "tunnel_c") and stream.tunnel_c proxy_c = hasattr(stream, "proxy_c") and stream.proxy_c - if tunnel_c: tunnel_c.close() - if proxy_c: proxy_c.close() + if tunnel_c: + tunnel_c.close() + if proxy_c: + proxy_c.close() setattr(stream, "tunnel_c", None) setattr(stream, "proxy_c", None) def on_serve(self): http2.HTTP2Server.on_serve(self) - if self.env: self.dynamic = self.get_env("DYNAMIC", self.dynamic, cast = bool) - if self.env: self.throttle = self.get_env("THROTTLE", self.throttle, cast = bool) - if self.env: self.trust_origin = self.get_env("TRUST_ORIGIN", self.trust_origin, cast = bool) - if self.dynamic: self.info("Using dynamic encoding (no content re-encoding) in proxy ...") - if self.throttle: self.info("Throttling connections in proxy ...") - else: self.info("Not throttling connections in proxy ...") - if self.trust_origin: self.info("Origin is considered \"trustable\" by proxy") + if self.env: + self.dynamic = self.get_env("DYNAMIC", self.dynamic, cast=bool) + if self.env: + self.throttle = self.get_env("THROTTLE", self.throttle, cast=bool) + if self.env: + self.trust_origin = self.get_env( + "TRUST_ORIGIN", self.trust_origin, cast=bool + ) + if self.dynamic: + self.info("Using dynamic encoding (no content re-encoding) in proxy ...") + if self.throttle: + self.info("Throttling connections in proxy ...") + else: + self.info("Not throttling connections in proxy ...") + if self.trust_origin: + self.info('Origin is considered "trustable" by proxy') def on_data_http(self, connection, parser): http2.HTTP2Server.on_data_http(self, connection, parser) - if not hasattr(connection, "proxy_c"): return + if not hasattr(connection, "proxy_c"): + return proxy_c = connection.proxy_c should_throttle = self.throttle and connection.is_throttleable() should_disable = should_throttle and proxy_c.is_exhausted() - if should_disable: connection.disable_read() - proxy_c.flush(force = True, callback = self._throttle) + if should_disable: + connection.disable_read() + proxy_c.flush(force=True, callback=self._throttle) def on_headers(self, connection, parser): pass def on_partial(self, connection, parser, data): - if not hasattr(connection, "proxy_c"): return + if not hasattr(connection, "proxy_c"): + return proxy_c = connection.proxy_c should_throttle = self.throttle and connection.is_throttleable() should_disable = should_throttle and proxy_c.is_exhausted() - if should_disable: connection.disable_read() - proxy_c.send_base(data, force = True, callback = self._throttle) + if should_disable: + connection.disable_read() + proxy_c.send_base(data, force=True, callback=self._throttle) def on_available(self, connection, parser): proxy_c = connection.proxy_c - if not proxy_c.renable == False: return - if not connection.is_restored(): return + if not proxy_c.renable == False: + return + if not connection.is_restored(): + return proxy_c.enable_read() - self.reads((proxy_c.socket,), state = False) + self.reads((proxy_c.socket,), state=False) def on_unavailable(self, connection, parser): proxy_c = connection.proxy_c - if proxy_c.renable == False: return + if proxy_c.renable == False: + return should_throttle = self.throttle and proxy_c.is_throttleable() should_disable = should_throttle and connection.is_exhausted() - if not should_disable: return + if not should_disable: + return proxy_c.disable_read() - def build_connection(self, socket, address, ssl = False): + def build_connection(self, socket, address, ssl=False): return ProxyConnection( - owner = self, - socket = socket, - address = address, - ssl = ssl, - encoding = self.encoding, - max_pending = self.max_pending, - min_pending = self.min_pending + owner=self, + socket=socket, + address=address, + ssl=ssl, + encoding=self.encoding, + max_pending=self.max_pending, + min_pending=self.min_pending, ) def _throttle(self, _connection): - if not _connection.is_restored(): return + if not _connection.is_restored(): + return connection = self.conn_map[_connection] - if not connection.renable == False: return + if not connection.renable == False: + return connection.enable_read() - self.reads((connection.socket,), state = False) + self.reads((connection.socket,), state=False) def _prx_close(self, connection): - connection.close(flush = True) + connection.close(flush=True) def _prx_keep(self, connection): pass def _prx_throttle(self, connection): - if not connection.is_restored(): return + if not connection.is_restored(): + return proxy_c = hasattr(connection, "proxy_c") and connection.proxy_c - if not proxy_c: return - if not proxy_c.renable == False: return + if not proxy_c: + return + if not proxy_c.renable == False: + return proxy_c.enable_read() - self.http_client.reads((proxy_c.socket,), state = False) + self.http_client.reads((proxy_c.socket,), state=False) def _raw_throttle(self, connection): - if not connection.is_restored(): return + if not connection.is_restored(): + return tunnel_c = hasattr(connection, "tunnel_c") and connection.tunnel_c - if not tunnel_c: return - if not tunnel_c.renable == False: return + if not tunnel_c: + return + if not tunnel_c.renable == False: + return tunnel_c.enable_read() - self.raw_client.reads((tunnel_c.socket,), state = False) + self.raw_client.reads((tunnel_c.socket,), state=False) def _on_prx_headers(self, client, parser, headers): # retrieves the owner of the parser as the client connection @@ -365,16 +394,21 @@ def _on_prx_headers(self, client, parser, headers): # the length values of the connection are considered unreliable and # some extra operation must be defined, note that in case the dynamic # (no re-encoding) support is enabled the length is always reliable - unreliable_length = _connection.current > http.CHUNKED_ENCODING or\ - connection.current > http.CHUNKED_ENCODING or parser.content_l == -1 + unreliable_length = ( + _connection.current > http.CHUNKED_ENCODING + or connection.current > http.CHUNKED_ENCODING + or parser.content_l == -1 + ) unreliable_length &= not self.dynamic # in case the content length is unreliable some of the headers defined # must be removed so that no extra connection error occurs, as the size # of the content from one end point to the other may change if unreliable_length: - if "content-length" in headers: del headers["content-length"] - if "accept-ranges" in headers: del headers["accept-ranges"] + if "content-length" in headers: + del headers["content-length"] + if "accept-ranges" in headers: + del headers["accept-ranges"] # in case the length of the data is not reliable and the current connection # is plain encoded a proper set of operation must be properly handled including @@ -390,7 +424,9 @@ def _on_prx_headers(self, client, parser, headers): # that allows the content encoding to be kept (compatibility support), this # heuristic is only applied in case the dynamic option is enabled content_encoding_c = http.ENCODING_MAP.get(content_encoding, connection.current) - content_encoding_t = http.ENCODING_MAP.get(transfer_encoding, connection.current) + content_encoding_t = http.ENCODING_MAP.get( + transfer_encoding, connection.current + ) target_encoding = max(content_encoding_c, content_encoding_t) if self.dynamic and target_encoding > connection.current: connection.set_encoding(target_encoding) @@ -408,10 +444,7 @@ def _on_prx_headers(self, client, parser, headers): # semantics of the transmission should dependent on the version of # the protocol that is going to be used in the transmission connection.send_header( - headers = headers, - version = version_s, - code = int(code_s), - code_s = status_s + headers=headers, version=version_s, code=int(code_s), code_s=status_s ) def _on_prx_message(self, client, parser, message): @@ -428,7 +461,8 @@ def _on_prx_message(self, client, parser, message): # creates the clojure function that will be used to close the # current client connection and that may or may not close the # corresponding back-end connection (as defined in specification) - def close(connection): connection.close(flush = True) + def close(connection): + connection.close(flush=True) # verifies that the connection is meant to be kept alive, the # connection is meant to be kept alive when both the client and @@ -438,13 +472,15 @@ def close(connection): connection.close(flush = True) # defines the proper callback function to be called at the end # of the flushing of the connection according to the result of # the keep alive evaluation (as defined in specification) - if keep_alive: callback = None - else: callback = close + if keep_alive: + callback = None + else: + callback = close # runs the final flush operation in the connection making sure that # every data that is pending is properly flushed, this is especially # important for chunked or compressed connections - connection.flush_s(callback = callback) + connection.flush_s(callback=callback) def _on_prx_partial(self, client, parser, data): # retrieves the owner of the proxy parser as the proxy connection @@ -459,8 +495,9 @@ def _on_prx_partial(self, client, parser, data): connection = self.conn_map[_connection] should_throttle = self.throttle and _connection.is_throttleable() should_disable = should_throttle and connection.is_exhausted() - if should_disable: _connection.disable_read() - connection.send_part(data, final = False, callback = self._prx_throttle) + if should_disable: + _connection.disable_read() + connection.send_part(data, final=False, callback=self._prx_throttle) def _on_prx_connect(self, client, _connection): _connection.waiting = False @@ -478,26 +515,30 @@ def _on_prx_close(self, client, _connection): # no connection is retrieved returns the control flow # to the caller method immediately (nothing done) connection = self.conn_map.get(_connection, None) - if not connection: return + if not connection: + return # in case the connection is under the waiting state # the forbidden response is set to the client otherwise # the front-end connection is closed immediately - if _connection.waiting: connection.send_response( - data = cls.build_data( - "Forbidden", - url = _connection.error_url if\ - hasattr(_connection, "error_url") else None - ), - headers = dict( - connection = "close" - ), - code = 403, - code_s = "Forbidden", - apply = True, - callback = self._prx_close - ) - else: connection.close(flush = True) + if _connection.waiting: + connection.send_response( + data=cls.build_data( + "Forbidden", + url=( + _connection.error_url + if hasattr(_connection, "error_url") + else None + ), + ), + headers=dict(connection="close"), + code=403, + code_s="Forbidden", + apply=True, + callback=self._prx_close, + ) + else: + connection.close(flush=True) # removes the waiting state from the connection and # the removes the back-end to front-end connection @@ -514,7 +555,8 @@ def _on_prx_error(self, client, _connection, error): # the proxy connection, this value is going to be # if sending the message to the final client connection = self.conn_map.get(_connection, None) - if not connection: return + if not connection: + return # constructs the message string that is going to be # sent as part of the response from the proxy indicating @@ -525,15 +567,14 @@ def _on_prx_error(self, client, _connection, error): # be handled by the error manager, and that should imply # a closing operation on the original/proxy connection) error_m = str(error) or "Unknown proxy relay error" - if _connection.waiting: connection.send_response( - data = cls.build_text(error_m), - headers = dict( - connection = "close" - ), - code = 500, - code_s = "Internal Error", - apply = True - ) + if _connection.waiting: + connection.send_response( + data=cls.build_text(error_m), + headers=dict(connection="close"), + code=500, + code_s="Internal Error", + apply=True, + ) # sets the connection as not waiting, so that no more # messages are sent as part of the closing chain @@ -541,28 +582,26 @@ def _on_prx_error(self, client, _connection, error): def _on_raw_connect(self, client, _connection): connection = self.conn_map[_connection] - connection.send_response( - code = 200, - code_s = "Connection established", - apply = True - ) + connection.send_response(code=200, code_s="Connection established", apply=True) def _on_raw_data(self, client, _connection, data): connection = self.conn_map[_connection] should_throttle = self.throttle and _connection.is_throttleable() should_disable = should_throttle and connection.is_exhausted() - if should_disable: _connection.disable_read() - connection.send(data, callback = self._raw_throttle) + if should_disable: + _connection.disable_read() + connection.send(data, callback=self._raw_throttle) def _on_raw_close(self, client, _connection): connection = self.conn_map[_connection] - connection.close(flush = True) + connection.close(flush=True) del self.conn_map[_connection] - def _apply_headers(self, parser, connection, parser_prx, headers, upper = True): - if upper: self._headers_upper(headers) + def _apply_headers(self, parser, connection, parser_prx, headers, upper=True): + if upper: + self._headers_upper(headers) self._apply_via(parser_prx, headers) - self._apply_all(parser_prx, connection, headers, replace = True) + self._apply_all(parser_prx, connection, headers, replace=True) def _apply_via(self, parser_prx, headers): # retrieves the various elements of the parser that are going @@ -583,13 +622,16 @@ def _apply_via(self, parser_prx, headers): # creates the via string value taking into account if the server # part of the string exists or not (different template) - if server: via_s = "%s %s (%s)" % (version_s, host, server) - else: via_s = "%s %s" % (version_s, host) + if server: + via_s = "%s %s (%s)" % (version_s, host, server) + else: + via_s = "%s %s" % (version_s, host) # tries to retrieve the current via string (may already exits) # and appends the created string to the base string or creates # a new one (as defined in the HTTP specification) via = headers.get("Via", "") - if via: via += ", " + if via: + via += ", " via += via_s headers["Via"] = via diff --git a/src/netius/servers/smtp.py b/src/netius/servers/smtp.py index d0d197bc7..25fa204d1 100644 --- a/src/netius/servers/smtp.py +++ b/src/netius/servers/smtp.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -77,17 +68,15 @@ this is going to be used in some parsing calculus, this value should be exposed so that it may be re-used by other modules """ -CAPABILITIES = ( - "AUTH PLAIN LOGIN", - "STARTTLS" -) +CAPABILITIES = ("AUTH PLAIN LOGIN", "STARTTLS") """ The sequence defining the various capabilities that are available under the current smtp server implementation, the description of these capabilities should conform with the rfp """ + class SMTPConnection(netius.Connection): - def __init__(self, host = "smtp.localhost", *args, **kwargs): + def __init__(self, host="smtp.localhost", *args, **kwargs): netius.Connection.__init__(self, *args, **kwargs) self.parser = None self.host = host @@ -102,44 +91,44 @@ def __init__(self, host = "smtp.localhost", *args, **kwargs): def open(self, *args, **kwargs): netius.Connection.open(self, *args, **kwargs) - if not self.is_open(): return + if not self.is_open(): + return self.parser = netius.common.SMTPParser(self) self.parser.bind("on_line", self.on_line) def close(self, *args, **kwargs): netius.Connection.close(self, *args, **kwargs) - if not self.is_closed(): return - if self.parser: self.parser.destroy() + if not self.is_closed(): + return + if self.parser: + self.parser.destroy() def parse(self, data): - if self.state == DATA_STATE: self.on_raw_data(data) - elif self.state == USERNAME_STATE: self.on_username(data) - elif self.state == PASSWORD_STATE: self.on_password(data) - else: return self.parser.parse(data) - - def send_smtp(self, code, message = "", lines = (), delay = True, callback = None): - if lines: return self.send_smtp_lines( - code, - message = message, - lines = lines, - delay = delay, - callback = callback - ) - else: return self.send_smtp_base( - code, - message, - delay, - callback - ) + if self.state == DATA_STATE: + self.on_raw_data(data) + elif self.state == USERNAME_STATE: + self.on_username(data) + elif self.state == PASSWORD_STATE: + self.on_password(data) + else: + return self.parser.parse(data) - def send_smtp_base(self, code, message = "", delay = True, callback = None): + def send_smtp(self, code, message="", lines=(), delay=True, callback=None): + if lines: + return self.send_smtp_lines( + code, message=message, lines=lines, delay=delay, callback=callback + ) + else: + return self.send_smtp_base(code, message, delay, callback) + + def send_smtp_base(self, code, message="", delay=True, callback=None): base = "%d %s" % (code, message) data = base + "\r\n" - count = self.send(data, delay = delay, callback = callback) + count = self.send(data, delay=delay, callback=callback) self.owner.debug(base) return count - def send_smtp_lines(self, code, message = "", lines = (), delay = True, callback = None): + def send_smtp_lines(self, code, message="", lines=(), delay=True, callback=None): lines = list(lines) lines.insert(0, message) body = lines[:-1] @@ -148,7 +137,7 @@ def send_smtp_lines(self, code, message = "", lines = (), delay = True, callback lines_s = ["%d-%s" % (code, line) for line in body] lines_s.append("%d %s" % (code, tail)) data = "\r\n".join(lines_s) + "\r\n" - count = self.send(data, delay = delay, callback = callback) + count = self.send(data, delay=delay, callback=callback) self.owner.debug(base) return count @@ -169,20 +158,22 @@ def ehlo(self, host): self.assert_s(HELO_STATE) self.chost = host message = "ehlo %s" % host - self.send_smtp(250, message, lines = CAPABILITIES) + self.send_smtp(250, message, lines=CAPABILITIES) self.state = HEADER_STATE def starttls(self): def callback(connection): - connection.upgrade(server = True) + connection.upgrade(server=True) + message = "go ahead" - self.send_smtp(220, message, callback = callback) + self.send_smtp(220, message, callback=callback) self.state = HELO_STATE def auth(self, method, data): method_name = "auth_%s" % method has_method = hasattr(self, method_name) - if not has_method: raise netius.NotImplemented("Method not implemented") + if not has_method: + raise netius.NotImplemented("Method not implemented") method = getattr(self, method_name) method(data) @@ -203,11 +194,11 @@ def auth_login(self, data): data_s = netius.legacy.str(data_s) self._username = data_s message = "UGFzc3dvcmQ6" - self.send_smtp(334 , message) + self.send_smtp(334, message) self.state = PASSWORD_STATE else: message = "VXNlcm5hbWU6" - self.send_smtp(334 , message) + self.send_smtp(334, message) self.state = USERNAME_STATE def data(self): @@ -218,7 +209,7 @@ def data(self): self.previous = bytes() self.state = DATA_STATE - def queued(self, index = -1): + def queued(self, index=-1): self.assert_s(DATA_STATE) self.owner.on_message_smtp(self) identifier = self.identifier or index @@ -243,7 +234,7 @@ def on_username(self, data): data_s = netius.legacy.str(data_s) self._username = data_s message = "UGFzc3dvcmQ6" - self.send_smtp(334 , message) + self.send_smtp(334, message) self.state = PASSWORD_STATE def on_password(self, data): @@ -269,8 +260,8 @@ def on_raw_data(self, data): # find the termination string in the final concatenated string data_l = len(data) remaining = TERMINATION_SIZE - data_l if TERMINATION_SIZE > data_l else 0 - previous_v = self.previous[remaining * -1:] if remaining > 0 else b"" - buffer = previous_v + data[TERMINATION_SIZE * -1:] + previous_v = self.previous[remaining * -1 :] if remaining > 0 else b"" + buffer = previous_v + data[TERMINATION_SIZE * -1 :] is_final = not buffer.find(b"\r\n.\r\n") == -1 # updates the previous value string with the current buffer used for finding @@ -280,14 +271,15 @@ def on_raw_data(self, data): # verifies if this is the final part of the message as # pre-defined before the data configuration, if that's not # the case must return the control flow immediately - if not is_final: return + if not is_final: + return # runs the queued command indicating that the message has # been queued for sending and that the connection may now # be closed if there's nothing remaining to be done self.queued() - def on_line(self, code, message, is_final = True): + def on_line(self, code, message, is_final=True): # "joins" the code and the message part of the message into the base # string and then uses this value to print some debug information base = "%s %s" % (code, message) @@ -307,7 +299,8 @@ def on_line(self, code, message, is_final = True): # does not raises an exception indicating the problem with the # code that has just been received (probably erroneous) exists = hasattr(self, method_n) - if not exists: raise netius.ParserError("Invalid code '%s'" % code) + if not exists: + raise netius.ParserError("Invalid code '%s'" % code) # retrieves the reference to the method that is going to be called # for the handling of the current line from the current instance and @@ -329,8 +322,11 @@ def on_starttls(self, message): def on_auth(self, message): message_s = message.split(" ", 1) is_tuple = len(message_s) == 2 - if is_tuple: method, data = message_s - else: method = message; data = "" + if is_tuple: + method, data = message_s + else: + method = message + data = "" method = method.lower() self.auth(method, data) @@ -347,41 +343,40 @@ def on_data(self, message): def on_quit(self, message): self.bye() - self.close(flush = True) + self.close(flush=True) def assert_s(self, expected): - if self.state == expected: return + if self.state == expected: + return raise netius.ParserError("Invalid state") def to_s(self): return ", ".join(["<%s>" % email[3:].strip()[1:-1] for email in self.to_l]) - def received_s(self, for_s = False): + def received_s(self, for_s=False): to_s = self.to_s() date_time = datetime.datetime.utcfromtimestamp(self.time) date_s = date_time.strftime("%a, %d %b %Y %H:%M:%S +0000") - return "from %s " % self.chost +\ - "by %s (netius) with ESMTP id %s" % (self.host, self.identifier) +\ - (" for %s" % to_s if for_s else "") +\ - "; %s" % date_s + return ( + "from %s " % self.chost + + "by %s (netius) with ESMTP id %s" % (self.host, self.identifier) + + (" for %s" % to_s if for_s else "") + + "; %s" % date_s + ) + class SMTPServer(netius.StreamServer): def __init__( - self, - adapter_s = "memory", - auth_s = "dummy", - locals = ("localhost",), - *args, - **kwargs + self, adapter_s="memory", auth_s="dummy", locals=("localhost",), *args, **kwargs ): netius.StreamServer.__init__(self, *args, **kwargs) self.adapter_s = adapter_s self.auth_s = auth_s self.locals = locals - def serve(self, host = "smtp.localhost", port = 25, *args, **kwargs): - netius.StreamServer.serve(self, port = port, *args, **kwargs) + def serve(self, host="smtp.localhost", port=25, *args, **kwargs): + netius.StreamServer.serve(self, port=port, *args, **kwargs) self.host = host def on_connection_c(self, connection): @@ -394,23 +389,22 @@ def on_data(self, connection, data): def on_serve(self): netius.StreamServer.on_serve(self) - if self.env: self.host = self.get_env("SMTP_HOST", self.host) - if self.env: self.adapter_s = self.get_env("SMTP_ADAPTER", self.adapter_s) - if self.env: self.auth_s = self.get_env("SMTP_AUTH", self.auth_s) + if self.env: + self.host = self.get_env("SMTP_HOST", self.host) + if self.env: + self.adapter_s = self.get_env("SMTP_ADAPTER", self.adapter_s) + if self.env: + self.auth_s = self.get_env("SMTP_AUTH", self.auth_s) self.adapter = self.get_adapter(self.adapter_s) self.auth = self.get_auth(self.auth_s) self.info( - "Starting SMTP server on '%s' using '%s' and '%s' ..." % - (self.host, self.adapter_s, self.auth_s) + "Starting SMTP server on '%s' using '%s' and '%s' ..." + % (self.host, self.adapter_s, self.auth_s) ) - def build_connection(self, socket, address, ssl = False): + def build_connection(self, socket, address, ssl=False): return SMTPConnection( - owner = self, - socket = socket, - address = address, - ssl = ssl, - host = self.host + owner=self, socket=socket, address=address, ssl=ssl, host=self.host ) def on_line_smtp(self, connection, code, message): @@ -435,14 +429,14 @@ def on_header_smtp(self, connection, from_l, to_l): # iterates over the complete set of users to reserve # new keys for the various items to be delivered for user in users: - key = self.adapter.reserve(owner = user) + key = self.adapter.reserve(owner=user) keys.append(key) # sets the list of reserved keys in the connection # and then generates a new identifier for the current # message that is going to be delivered/queued connection.keys = keys - connection.identifier = self._generate(hashed = False) + connection.identifier = self._generate(hashed=False) def on_data_smtp(self, connection, data): for key in connection.keys: @@ -452,17 +446,17 @@ def on_message_smtp(self, connection): for key in connection.keys: self.adapter.truncate(key, TERMINATION_SIZE) - def _locals(self, sequence, prefix = "to"): - emails = self._emails(sequence, prefix = prefix) + def _locals(self, sequence, prefix="to"): + emails = self._emails(sequence, prefix=prefix) emails = [email for email in emails if self._is_local(email)] return emails - def _remotes(self, sequence, prefix = "to"): - emails = self._emails(sequence, prefix = prefix) + def _remotes(self, sequence, prefix="to"): + emails = self._emails(sequence, prefix=prefix) emails = [email for email in emails if not self._is_local(email)] return emails - def _emails(self, sequence, prefix = "to"): + def _emails(self, sequence, prefix="to"): prefix_l = len(prefix) base = prefix_l + 1 emails = [item[base:].strip()[1:-1] for item in sequence] @@ -476,9 +470,11 @@ def _is_local(self, email): domain = email.split("@", 1)[1] return domain in self.locals + if __name__ == "__main__": import logging - server = SMTPServer(level = logging.DEBUG) - server.serve(env = True) + + server = SMTPServer(level=logging.DEBUG) + server.serve(env=True) else: __path__ = [] diff --git a/src/netius/servers/socks.py b/src/netius/servers/socks.py index 7f7e0c087..9affa858a 100644 --- a/src/netius/servers/socks.py +++ b/src/netius/servers/socks.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -42,10 +33,10 @@ import netius.common import netius.clients -GRANTED = 0x5a -REJECTED = 0x5b -FAILED_CLIENT = 0x5c -FAILED_AUTH = 0x5d +GRANTED = 0x5A +REJECTED = 0x5B +FAILED_CLIENT = 0x5C +FAILED_AUTH = 0x5D GRANTED_EXTRA = 0x00 @@ -64,6 +55,7 @@ avoids the starvation of the producer to consumer relation that could cause memory problems """ + class SOCKSConnection(netius.Connection): def __init__(self, *args, **kwargs): @@ -72,21 +64,24 @@ def __init__(self, *args, **kwargs): def open(self, *args, **kwargs): netius.Connection.open(self, *args, **kwargs) - if not self.is_open(): return + if not self.is_open(): + return self.parser = netius.common.SOCKSParser(self) self.parser.bind("on_data", self.on_data) self.parser.bind("on_auth", self.on_auth) def close(self, *args, **kwargs): netius.Connection.close(self, *args, **kwargs) - if not self.is_closed(): return - if self.parser: self.parser.destroy() + if not self.is_closed(): + return + if self.parser: + self.parser.destroy() - def send_response(self, status = GRANTED): + def send_response(self, status=GRANTED): data = struct.pack("!BBHI", 0, status, 0, 0) return self.send(data) - def send_response_extra(self, status = GRANTED_EXTRA): + def send_response_extra(self, status=GRANTED_EXTRA): version = self.parser.version type = self.parser.type port = self.parser.port @@ -95,7 +90,7 @@ def send_response_extra(self, status = GRANTED_EXTRA): data = struct.pack(format, version, status, 0, type, address, port) return self.send(data) - def send_auth(self, version = None, method = 0x00): + def send_auth(self, version=None, method=0x00): version = version or self.parser.version data = struct.pack("!BB", version, method) return self.send(data) @@ -112,6 +107,7 @@ def on_data(self): def on_auth(self): self.owner.on_auth_socks(self, self.parser) + class SOCKSServer(netius.ServerAgent): """ SOCKS base server class to be used as an implementation of the @@ -121,14 +117,16 @@ class SOCKSServer(netius.ServerAgent): performant driven for readability purposes. """ - def __init__(self, rules = {}, throttle = True, max_pending = MAX_PENDING, *args, **kwargs): + def __init__( + self, rules={}, throttle=True, max_pending=MAX_PENDING, *args, **kwargs + ): netius.ContainerServer.__init__( self, - receive_buffer_c = int(max_pending * BUFFER_RATIO), - send_buffer_c = int(max_pending * BUFFER_RATIO), + receive_buffer_c=int(max_pending * BUFFER_RATIO), + send_buffer_c=int(max_pending * BUFFER_RATIO), *args, **kwargs - ) # @todo how is this going to work (receive buffer control) + ) # @todo how is this going to work (receive buffer control) self.rules = rules self.throttle = throttle self.max_pending = max_pending @@ -140,7 +138,7 @@ def __init__(self, rules = {}, throttle = True, max_pending = MAX_PENDING, *args self.raw_protocol.bind("data", self._on_raw_data) self.raw_protocol.bind("close", self._on_raw_close) - #@todo this does not make sense + # @todo this does not make sense self.add_base(self) self.add_base(self.raw_client) @@ -156,18 +154,21 @@ def on_data(self, connection, data): # (initial handshake) runs the parse step on the data and then # returns immediately (not going to send it back) tunnel_c = hasattr(connection, "tunnel_c") and connection.tunnel_c - if not tunnel_c: connection.parse(data); return + if not tunnel_c: + connection.parse(data) + return # verifies that the current size of the pending buffer is greater # than the maximum size for the pending buffer the read operations # if that the case the read operations must be disabled should_disable = self.throttle and tunnel_c.is_exhausted() - if should_disable: connection.disable_read() + if should_disable: + connection.disable_read() # performs the sending operation on the data but uses the throttle # callback so that the connection read operations may be resumed if # the buffer has reached certain (minimum) levels - tunnel_c.send(data, callback = self._throttle) + tunnel_c.send(data, callback=self._throttle) def on_data_socks(self, connection, parser): host = parser.get_host() @@ -185,64 +186,75 @@ def on_auth_socks(self, connection, parser): if not 0 in auth_methods: raise netius.ParserError("Authentication is not supported") - connection.send_auth(method = 0) + connection.send_auth(method=0) def on_connection_d(self, connection): netius.ContainerServer.on_connection_d(self, connection) tunnel_c = hasattr(connection, "tunnel_c") and connection.tunnel_c - if tunnel_c: tunnel_c.close() + if tunnel_c: + tunnel_c.close() setattr(connection, "tunnel_c", None) - def build_connection(self, socket, address, ssl = False): + def build_connection(self, socket, address, ssl=False): return SOCKSConnection( - owner = self, - socket = socket, - address = address, - ssl = ssl, - max_pending = self.max_pending, - min_pending = self.min_pending + owner=self, + socket=socket, + address=address, + ssl=ssl, + max_pending=self.max_pending, + min_pending=self.min_pending, ) def _throttle(self, _connection): - if not _connection.is_restored(): return + if not _connection.is_restored(): + return connection = self.conn_map[_connection] - if not connection.renable == False: return + if not connection.renable == False: + return connection.enable_read() - self.reads((connection.socket,), state = False) + self.reads((connection.socket,), state=False) def _raw_throttle(self, connection): - if not connection.is_restored(): return + if not connection.is_restored(): + return tunnel_c = hasattr(connection, "tunnel_c") and connection.tunnel_c - if not tunnel_c: return - if not tunnel_c.renable == False: return + if not tunnel_c: + return + if not tunnel_c.renable == False: + return tunnel_c.enable_read() - self.raw_client.reads((tunnel_c.socket,), state = False) + self.raw_client.reads((tunnel_c.socket,), state=False) def _on_raw_connect(self, client, _connection): connection = self.conn_map[_connection] version = connection.get_version() - if version == 0x04: connection.send_response(status = GRANTED) - elif version == 0x05: connection.send_response_extra(status = GRANTED_EXTRA) + if version == 0x04: + connection.send_response(status=GRANTED) + elif version == 0x05: + connection.send_response_extra(status=GRANTED_EXTRA) def _on_raw_data(self, client, _connection, data): connection = self.conn_map[_connection] should_disable = self.throttle and connection.is_exhausted() - if should_disable: _connection.disable_read() - connection.send(data, callback = self._raw_throttle) + if should_disable: + _connection.disable_read() + connection.send(data, callback=self._raw_throttle) def _on_raw_close(self, client, _connection): connection = self.conn_map[_connection] - connection.close(flush = True) + connection.close(flush=True) del self.conn_map[_connection] + if __name__ == "__main__": import logging - server = SOCKSServer(level = logging.INFO) - server.serve(env = True) + + server = SOCKSServer(level=logging.INFO) + server.serve(env=True) else: __path__ = [] diff --git a/src/netius/servers/tftp.py b/src/netius/servers/tftp.py index 9799ffe96..6615b68b9 100644 --- a/src/netius/servers/tftp.py +++ b/src/netius/servers/tftp.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -42,9 +33,10 @@ import netius.common + class TFTPSession(object): - def __init__(self, owner, name = None, mode = None): + def __init__(self, owner, name=None, mode=None): self.owner = owner self.name = name self.mode = mode @@ -56,25 +48,29 @@ def close(self): self.reset() def reset(self): - if self.file: self.file.close() + if self.file: + self.file.close() self.name = None self.mode = None self.file = None self.completed = False self.sequence = 0 - def next(self, size = 512, increment = True): - if self.completed: return None + def next(self, size=512, increment=True): + if self.completed: + return None file = self._get_file() data = file.read(size) self.completed = len(data) < size - if increment: self.increment() + if increment: + self.increment() header = struct.pack("!HH", netius.common.DATA_TFTP, self.sequence) return header + data - def ack(self, size = 512, increment = True): - if self.sequence == 0: return None - return self.next(size = size, increment = increment) + def ack(self, size=512, increment=True): + if self.sequence == 0: + return None + return self.next(size=size, increment=increment) def increment(self): self.sequence += 1 @@ -93,13 +89,16 @@ def print_info(self): info = self.get_info() print(info) - def _get_file(self, allow_absolute = False): - if self.file: return self.file - if not allow_absolute: name = self.name.lstrip("/") + def _get_file(self, allow_absolute=False): + if self.file: + return self.file + if not allow_absolute: + name = self.name.lstrip("/") path = os.path.join(self.owner.base_path, name) self.file = open(path, "rb") return self.file + class TFTPRequest(object): parsers_m = None @@ -114,13 +113,14 @@ def __init__(self, data, session): @classmethod def generate(cls): - if cls.parsers_m: return + if cls.parsers_m: + return cls.parsers_m = ( cls._parse_rrq, cls._parse_wrq, cls._parse_data, cls._parse_ack, - cls._parse_error + cls._parse_error, ) cls.parsers_l = len(cls.parsers_m) @@ -129,7 +129,8 @@ def get_info(self): buffer = netius.legacy.StringIO() buffer.write("op := %d\n" % self.op) buffer.write("payload := %s" % repr(self.payload)) - if session_info: buffer.write("\n" + session_info) + if session_info: + buffer.write("\n" + session_info) buffer.seek(0) info = buffer.read() return info @@ -159,8 +160,9 @@ def get_type_s(self): type_s = netius.common.TYPES_TFTP.get(type, None) return type_s - def response(self, options = {}): - if self.op == netius.common.ACK_TFTP: return self.session.ack() + def response(self, options={}): + if self.op == netius.common.ACK_TFTP: + return self.session.ack() return self.session.next() @classmethod @@ -189,10 +191,11 @@ def _parse_error(cls, self): @classmethod def _str(cls, data): index = data.index(b"\x00") - value, remaining = data[:index], data[index + 1:] + value, remaining = data[:index], data[index + 1 :] value = netius.legacy.str(value) return value, remaining + class TFTPServer(netius.DatagramServer): """ Abstract trivial ftp server implementation that handles simple @@ -201,25 +204,23 @@ class TFTPServer(netius.DatagramServer): :see: http://tools.ietf.org/html/rfc1350 """ - ALLOWED_OPERATIONS = ( - netius.common.RRQ_TFTP, - netius.common.ACK_TFTP - ) + ALLOWED_OPERATIONS = (netius.common.RRQ_TFTP, netius.common.ACK_TFTP) - def __init__(self, base_path = "", *args, **kwargs): + def __init__(self, base_path="", *args, **kwargs): netius.DatagramServer.__init__(self, *args, **kwargs) self.base_path = base_path self.sessions = dict() - def serve(self, port = 69, *args, **kwargs): - netius.DatagramServer.serve(self, port = port, *args, **kwargs) + def serve(self, port=69, *args, **kwargs): + netius.DatagramServer.serve(self, port=port, *args, **kwargs) def on_data(self, address, data): netius.DatagramServer.on_data(self, address, data) try: session = self.sessions.get(address, None) - if not session: session = TFTPSession(self) + if not session: + session = TFTPSession(self) self.sessions[address] = session request = TFTPRequest(data, session) @@ -231,9 +232,12 @@ def on_data(self, address, data): def on_serve(self): netius.DatagramServer.on_serve(self) - if self.env: self.base_path = self.get_env("BASE_PATH", self.base_path) + if self.env: + self.base_path = self.get_env("BASE_PATH", self.base_path) self.info("Starting TFTP server ...") - self.info("Defining '%s' as the root of the file server ..." % (self.base_path or ".")) + self.info( + "Defining '%s' as the root of the file server ..." % (self.base_path or ".") + ) def on_data_tftp(self, address, request): cls = self.__class__ @@ -244,12 +248,11 @@ def on_data_tftp(self, address, request): self.debug("Received %s message from '%s'" % (type_s, address)) if not type in cls.ALLOWED_OPERATIONS: - raise netius.NetiusError( - "Invalid operation type '%d'", type - ) + raise netius.NetiusError("Invalid operation type '%d'", type) response = request.response() - if not response: return + if not response: + return self.send(response, address) @@ -261,9 +264,11 @@ def on_error_tftp(self, address, exception): self.send(response, address) self.info("Sent error message '%s' to '%s'" % (message, address)) + if __name__ == "__main__": import logging - server = TFTPServer(level = logging.DEBUG) - server.serve(env = True) + + server = TFTPServer(level=logging.DEBUG) + server.serve(env=True) else: __path__ = [] diff --git a/src/netius/servers/torrent.py b/src/netius/servers/torrent.py index 7a6e9e168..01e02e411 100644 --- a/src/netius/servers/torrent.py +++ b/src/netius/servers/torrent.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -78,6 +69,7 @@ """ The sequence defining the various paths that are going to be search trying to find the (static) peers file with format host:ip in each line """ + class Pieces(netius.Observable): """ Class that represents the logical structure of a file that is @@ -106,16 +98,19 @@ def piece(self, index): def piece_blocks(self, index): is_last = index == self.number_pieces - 1 - if not is_last: return self.number_blocks + if not is_last: + return self.number_blocks piece_size = self.piece_size(index) number_blocks = math.ceil(piece_size / float(BLOCK_SIZE)) return int(number_blocks) def piece_size(self, index): is_last = index == self.number_pieces - 1 - if not is_last: return self.number_blocks * BLOCK_SIZE + if not is_last: + return self.number_blocks * BLOCK_SIZE modulus = self.length % self.piece_length - if modulus == 0: return self.piece_length + if modulus == 0: + return self.piece_length return modulus def block(self, index, begin): @@ -126,31 +121,36 @@ def block(self, index, begin): def block_size(self, index, begin): block_index = begin // BLOCK_SIZE is_last_piece = index == self.number_pieces - 1 - if not is_last_piece: return BLOCK_SIZE + if not is_last_piece: + return BLOCK_SIZE is_last_block = block_index == self.final_blocks - 1 - if not is_last_block: return BLOCK_SIZE + if not is_last_block: + return BLOCK_SIZE piece_size = self.piece_size(index) modulus = piece_size % BLOCK_SIZE - if modulus == 0: return BLOCK_SIZE + if modulus == 0: + return BLOCK_SIZE return modulus - def pop_block(self, bitfield, mark = True): + def pop_block(self, bitfield, mark=True): index = 0 result = self._and(bitfield, self.bitfield) for bit in result: - if bit == True: break + if bit == True: + break index += 1 - if index == len(result): return None + if index == len(result): + return None - begin = self.update_block(index, mark = mark) + begin = self.update_block(index, mark=mark) length = self.block_size(index, begin) return (index, begin, length) def push_block(self, index, begin): - self.mark_block(index, begin, value = True) + self.mark_block(index, begin, value=True) - def mark_piece(self, index, value = False): + def mark_piece(self, index, value=False): base = index * self.number_blocks block_count = self.piece_blocks(index) @@ -159,23 +159,25 @@ def mark_piece(self, index, value = False): self.bitfield[index] = value - def mark_block(self, index, begin, value = False): + def mark_block(self, index, begin, value=False): base = index * self.number_blocks block_index = begin // BLOCK_SIZE self.mask[base + block_index] = value self.trigger("block", self, index, begin) self.update_piece(index) - def update_block(self, index, mark = True): + def update_block(self, index, mark=True): base = index * self.number_blocks block_count = self.piece_blocks(index) for block_index in netius.legacy.xrange(block_count): state = self.mask[base + block_index] - if state == True: break + if state == True: + break begin = block_index * BLOCK_SIZE - if mark: self.mark_block(index, begin) + if mark: + self.mark_block(index, begin) return begin def update_piece(self, index): @@ -192,7 +194,8 @@ def update_piece(self, index): # unmarked (all the blocks unmarked accordingly) for block_index in netius.legacy.xrange(block_count): state = self.mask[base + block_index] - if state == False: continue + if state == False: + continue piece_state = True break @@ -200,7 +203,8 @@ def update_piece(self, index): # note that the false value indicates that the piece has been # unmarked (and this is considered the objective) self.bitfield[index] = piece_state - if piece_state == True: return + if piece_state == True: + return # triggers the piece event indicating that a new piece has # been completely unmarked according to rules @@ -210,7 +214,8 @@ def update_piece(self, index): # bit field to verify if the file has been completely unmarked # in case it did not returns the control flow to caller for bit in self.bitfield: - if bit == True: return + if bit == True: + return # triggers the complete event to any of the handlers indicating # that the current torrent file has been completely unmarked @@ -225,18 +230,21 @@ def total_pieces(self): def marked_pieces(self): counter = 0 for bit in self.bitfield: - if bit == True: continue + if bit == True: + continue counter += 1 return counter @property - def missing_pieces(self, max_missing = MAX_MISSING): + def missing_pieces(self, max_missing=MAX_MISSING): missing_count = self.total_pieces - self.marked_pieces - if missing_count > max_missing: return [] + if missing_count > max_missing: + return [] missing = [] for index in netius.legacy.xrange(self.total_pieces): bit = self.bitfield[index] - if bit == False: continue + if bit == False: + continue missing.append(index) return missing @@ -249,29 +257,35 @@ def total_blocks(self): def marked_blocks(self): counter = 0 for bit in self.mask: - if bit == True: continue + if bit == True: + continue counter += 1 return counter @property - def missing_blocks(self, max_missing = MAX_MISSING): + def missing_blocks(self, max_missing=MAX_MISSING): missing_count = self.total_blocks - self.marked_blocks - if missing_count > max_missing: return [] + if missing_count > max_missing: + return [] missing = [] for index in netius.legacy.xrange(self.total_blocks): bit = self.mask[index] - if bit == False: continue + if bit == False: + continue missing.append(index) return missing def _and(self, first, second): result = [] for _first, _second in zip(first, second): - if _first and _second: value = True - else: value = False + if _first and _second: + value = True + else: + value = False result.append(value) return result + class TorrentTask(netius.Observable): """ Describes a task (operation) that is going to be performed @@ -285,7 +299,7 @@ class TorrentTask(netius.Observable): a proper easily described interface. """ - def __init__(self, owner, target_path, torrent_path = None, info_hash = None): + def __init__(self, owner, target_path, torrent_path=None, info_hash=None): netius.Observable.__init__(self) self.owner = owner @@ -302,8 +316,10 @@ def __init__(self, owner, target_path, torrent_path = None, info_hash = None): self.peers_m = {} def load(self): - if self.torrent_path: self.info = self.load_info(self.torrent_path) - else: self.info = dict(info_hash = self.info_hash) + if self.torrent_path: + self.info = self.load_info(self.torrent_path) + else: + self.info = dict(info_hash=self.info_hash) self.pieces_tracker() self.peers_dht() @@ -325,7 +341,8 @@ def on_close(self, connection): self.unchoked -= 1 if is_unchoked else 0 def ticks(self): - if time.time() < self.next_refresh: return + if time.time() < self.next_refresh: + return self.refresh() def refresh(self): @@ -345,7 +362,8 @@ def on_block(self, pieces, index, begin): self.trigger("block", self, index, begin) def on_piece(self, pieces, index): - try: self.verify_piece(index) + try: + self.verify_piece(index) except netius.DataError: self.refute_piece(index) else: @@ -358,7 +376,8 @@ def on_complete(self, pieces): def on_dht(self, response): # verifies if the response is valid and in case it's not # returns immediately to avoid any erroneous parsing - if not response: return + if not response: + return # retrieves the payload for the response and then uses it # to retrieves the nodes part of the response for parsing @@ -379,12 +398,13 @@ def on_dht(self, response): chunk = netius.legacy.bytes(chunk) peer_id, address, port = struct.unpack("!20sLH", chunk) ip = netius.common.addr_to_ip4(address) - peer = dict(id = peer_id, ip = ip, port = port) + peer = dict(id=peer_id, ip=ip, port=port) peers.append(peer) # in case no valid peers have been parsed there's no need # to continue with the processing, nothing to be done - if not peers: return + if not peers: + return # extends the currently defined peers list in the current # torrent task with the ones that have been discovered @@ -405,7 +425,8 @@ def on_tracker(self, client, parser, result): # there're none of them continues the loop as there's nothing to be # processed from this tracker response (invalid response) data = result["data"] - if not data: return + if not data: + return # tries to decode the provided data from the tracker using the bencoder # and extracts the peers part of the message to be processed @@ -415,7 +436,8 @@ def on_tracker(self, client, parser, result): # verifies if the provided peers part is not compact (already a dictionary) # if that's the case there's nothing remaining to be done, otherwise extra # processing must be done to - if isinstance(peers, dict): self.extend_peers(peers) + if isinstance(peers, dict): + self.extend_peers(peers) # need to normalize the peer structure by decoding the peers string into a # set of address port sub strings (as defined in torrent specification) @@ -425,12 +447,14 @@ def on_tracker(self, client, parser, result): peer = netius.legacy.bytes(peer) address, port = struct.unpack("!LH", peer) ip = netius.common.addr_to_ip4(address) - peer = dict(ip = ip, port = port) + peer = dict(ip=ip, port=port) self.add_peer(peer) # prints a debug message about the peer loading that has just occurred, this # may be used for the purpose of development (and traceability) - self.owner.debug("Received %d peers from '%s'" % (len(peers), parser.owner.base)) + self.owner.debug( + "Received %d peers from '%s'" % (len(peers), parser.owner.base) + ) # refreshes the connection with the peers because new peers have been added # to the current task and there may be new connections pending @@ -438,16 +462,20 @@ def on_tracker(self, client, parser, result): def load_info(self, torrent_path): file = open(torrent_path, "rb") - try: data = file.read() - finally: file.close() + try: + data = file.read() + finally: + file.close() struct = netius.common.bdecode(data) struct["info_hash"] = self.info_hash = netius.common.info_hash(struct) return struct def load_file(self): - if self._is_single(): return self.load_single() - else: return self.load_multiple() + if self._is_single(): + return self.load_single() + else: + return self.load_multiple() def load_single(self): # retrieves the length of the current (single file) and @@ -464,7 +492,8 @@ def load_single(self): # not the case creates the appropriate directories so that # they area available for the file stream creation is_dir = os.path.isdir(target_path) - if not is_dir: os.makedirs(target_path) + if not is_dir: + os.makedirs(target_path) # creates the "final" file path from the target path and the # name of the file and then constructs a file stream with the @@ -484,13 +513,15 @@ def load_multiple(self): dir_path = os.path.join(target_path, name) is_dir = os.path.isdir(dir_path) - if not is_dir: os.makedirs(dir_path) + if not is_dir: + os.makedirs(dir_path) self.file = netius.common.FilesStream(dir_path, size, files) self.file.open() def unload_file(self): - if not self.file: return + if not self.file: + return self.file.close() self.file = None @@ -505,8 +536,10 @@ def load_pieces(self): self.stored.bind("complete", self.on_complete) def unload_pieces(self): - if self.requested: self.requested.destroy() - if self.stored: self.stored.destroy() + if self.requested: + self.requested.destroy() + if self.stored: + self.stored.destroy() self.requested = None self.stored = None @@ -533,7 +566,8 @@ def set_data(self, data, index, begin): # immediately as this is a duplicated block setting, possible # in the last part of the file retrieval (end game) block = self.stored.block(index, begin) - if not block: return + if not block: + return # retrieves the size of a piece and uses that value together # with the block begin offset to calculate the final file offset @@ -559,21 +593,24 @@ def set_dht(self, peer_t, port): # and in case it succeeds sets the proper DHT (port) value in the peer # so that it may latter be used for DHT based operations peer = self.peers_m.get(peer_t, None) - if not peer: return + if not peer: + return peer["dht"] = port def peers_dht(self): - if not self.info_hash: return + if not self.info_hash: + return for peer in self.peers: port = peer.get("dht", None) - if not port: continue + if not port: + continue host = peer["ip"] self.owner.dht_client.get_peers( - host = host, - port = port, - peer_id = self.owner.peer_id, - info_hash = self.info_hash, - callback = self.on_dht + host=host, + port=port, + peer_id=self.owner.peer_id, + info_hash=self.info_hash, + callback=self.on_dht, ) self.owner.debug("Requested peers from DHT peer '%s'" % host) @@ -601,7 +638,8 @@ def peers_tracker(self): # URL of it and then verifies that it references an HTTP based # tracker (as that's the only one supported) is_http = tracker_url.startswith(("http://", "https://")) - if not is_http: continue + if not is_http: + continue # runs the get HTTP retrieval call (blocking call) so that it's # possible to retrieve the contents for the announce of the tracker @@ -609,20 +647,20 @@ def peers_tracker(self): # called at the end of the process with the message self.owner.http_client.get( tracker_url, - params = dict( - info_hash = self.info_hash, - peer_id = self.owner.peer_id, - port = 6881, - uploaded = self.uploaded, - downloaded = self.downloaded, - left = self.left(), - compact = 1, - no_peer_id = 0, - event = "started", - numwant = 50, - key = self.owner.get_id() + params=dict( + info_hash=self.info_hash, + peer_id=self.owner.peer_id, + port=6881, + uploaded=self.uploaded, + downloaded=self.downloaded, + left=self.left(), + compact=1, + no_peer_id=0, + event="started", + numwant=50, + key=self.owner.get_id(), ), - on_result = self.on_tracker + on_result=self.on_tracker, ) # prints a debug message about the request for peer that was just @@ -633,24 +671,28 @@ def peers_file(self): for path in PEER_PATHS: path = os.path.expanduser(path) path = os.path.normpath(path) - if not os.path.exists(path): continue + if not os.path.exists(path): + continue file = open(path, "r") for line in file: line = line.strip() host, port = line.split(":", 1) port = int(port) - peer = dict(ip = host, port = port) + peer = dict(ip=host, port=port) self.add_peer(peer) def connect_peers(self): - for peer in self.peers: self.connect_peer(peer) + for peer in self.peers: + self.connect_peer(peer) def disconnect_peers(self): connections = copy.copy(self.connections) - for connection in connections: connection.close(flush = True) + for connection in connections: + connection.close(flush=True) def connect_peer(self, peer): - if not peer["new"]: return + if not peer["new"]: + return peer["new"] = False self.owner.debug("Connecting to peer '%s:%d'" % (peer["ip"], peer["port"])) connection = self.owner.client.peer(self, peer["ip"], peer["port"]) @@ -660,18 +702,22 @@ def connect_peer(self, peer): connection.bind("unchoked", self.on_unchoked) def info_string(self): - return "==== STATUS ====\n" +\ - "peers := %d\n" % len(self.peers) +\ - "connections := %d\n" % len(self.connections) +\ - "choked := %d\n" % (len(self.connections) - self.unchoked) +\ - "unchoked := %d\n" % self.unchoked +\ - "pieces := %d/%d\n" % (self.stored.marked_pieces, self.stored.total_pieces) +\ - "blocks := %d/%d\n" % (self.stored.marked_blocks, self.stored.total_blocks) +\ - "pieces miss := %s\n" % self.stored.missing_pieces +\ - "blocks miss := %s\n" % self.stored.missing_blocks +\ - "percent := %.2f % %\n" % self.percent() +\ - "left := %d/%d bytes\n" % (self.left(), self.info["length"]) +\ - "speed := %s/s" % self.speed_s() + return ( + "==== STATUS ====\n" + + "peers := %d\n" % len(self.peers) + + "connections := %d\n" % len(self.connections) + + "choked := %d\n" % (len(self.connections) - self.unchoked) + + "unchoked := %d\n" % self.unchoked + + "pieces := %d/%d\n" + % (self.stored.marked_pieces, self.stored.total_pieces) + + "blocks := %d/%d\n" + % (self.stored.marked_blocks, self.stored.total_blocks) + + "pieces miss := %s\n" % self.stored.missing_pieces + + "blocks miss := %s\n" % self.stored.missing_blocks + + "percent := %.2f % %\n" % self.percent() + + "left := %d/%d bytes\n" % (self.left(), self.info["length"]) + + "speed := %s/s" % self.speed_s() + ) def left(self): size = self.info["length"] @@ -695,11 +741,7 @@ def speed(self): return bytes_second def speed_s(self): - return netius.common.size_round_unit( - self.speed(), - space = True, - reduce = False - ) + return netius.common.size_round_unit(self.speed(), space=True, reduce=False) def percent(self): size = self.info["length"] @@ -709,11 +751,13 @@ def pop_block(self, bitfield): left = self.left() is_end = left < THRESHOLD_END structure = self.stored if is_end else self.requested - if not structure: return None - return structure.pop_block(bitfield, mark = not is_end) + if not structure: + return None + return structure.pop_block(bitfield, mark=not is_end) def push_block(self, index, begin): - if not self.requested: return + if not self.requested: + return self.requested.push_block(index, begin) def verify_piece(self, index): @@ -724,16 +768,18 @@ def confirm_piece(self, index): self.downloaded += piece_size def refute_piece(self, index): - self.requested.mark_piece(index, value = True) - self.stored.mark_piece(index, value = True) + self.requested.mark_piece(index, value=True) + self.stored.mark_piece(index, value=True) self.owner.warning("Refuted piece '%d' (probably invalid)" % index) def extend_peers(self, peers): - for peer in peers: self.add_peer(peer) + for peer in peers: + self.add_peer(peer) def add_peer(self, peer): peer_t = (peer["ip"], peer["port"]) - if peer_t in self.peers_m: return + if peer_t in self.peers_m: + return peer["time"] = time.time() peer["new"] = True self.peers_m[peer_t] = peer @@ -741,7 +787,8 @@ def add_peer(self, peer): def remove_peer(self, peer): peer_t = (peer["ip"], peer["port"]) - if not peer_t in self.peers_m: return + if not peer_t in self.peers_m: + return del self.peers_m[peer_t] self.peers.remove(peer) @@ -756,36 +803,27 @@ def _verify_piece(self, index, file): pending = self.stored.piece_size(index) hash = hashlib.sha1() while True: - if pending == 0: break + if pending == 0: + break count = BLOCK_SIZE if pending > BLOCK_SIZE else pending data = file.read(count) hash.update(data) pending -= count digest = hash.digest() piece = netius.legacy.bytes(piece) - if digest == piece: return + if digest == piece: + return raise netius.DataError("Verifying piece index '%d'" % index) + class TorrentServer(netius.ContainerServer): def __init__(self, *args, **kwargs): netius.ContainerServer.__init__(self, *args, **kwargs) self.peer_id = self._generate_id() - self.client = netius.clients.TorrentClient( - thread = False, - *args, - **kwargs - ) - self.http_client = netius.clients.HTTPClient( - thread = False, - *args, - **kwargs - ) - self.dht_client = netius.clients.DHTClient( - thread = False, - *args, - **kwargs - ) + self.client = netius.clients.TorrentClient(thread=False, *args, **kwargs) + self.http_client = netius.clients.HTTPClient(thread=False, *args, **kwargs) + self.dht_client = netius.clients.DHTClient(thread=False, *args, **kwargs) self.tasks = [] self.add_base(self.client) self.add_base(self.http_client) @@ -798,9 +836,10 @@ def cleanup(self): def ticks(self): netius.ContainerServer.ticks(self) - for task in self.tasks: task.ticks() + for task in self.tasks: + task.ticks() - def download(self, target_path, torrent_path = None, info_hash = None, close = False): + def download(self, target_path, torrent_path=None, info_hash=None, close=False): """ Starts the "downloading" process of a torrent associated file using the defined peer to peer torrent strategy using either @@ -835,13 +874,11 @@ def download(self, target_path, torrent_path = None, info_hash = None, close = F def on_complete(task): owner = task.owner self.remove_task(task) - if close: owner.close() + if close: + owner.close() task = TorrentTask( - self, - target_path, - torrent_path = torrent_path, - info_hash = info_hash + self, target_path, torrent_path=torrent_path, info_hash=info_hash ) task.load() task.connect_peers() @@ -858,7 +895,8 @@ def remove_task(self, task): def cleanup_tasks(self): tasks = copy.copy(self.tasks) - for task in tasks: self.remove_task(task) + for task in tasks: + self.remove_task(task) def _generate_id(self): random = str(uuid.uuid4()) @@ -868,14 +906,17 @@ def _generate_id(self): id = "-%s-%s" % (ID_STRING, digest[:12]) return id + if __name__ == "__main__": import logging - if len(sys.argv) > 1: file_path = sys.argv[1] - else: file_path = "\\file.torrent" + if len(sys.argv) > 1: + file_path = sys.argv[1] + else: + file_path = "\\file.torrent" def on_start(server): - task = server.download("~/Downloads", file_path, close = True) + task = server.download("~/Downloads", file_path, close=True) task.bind("piece", on_piece) task.bind("complete", on_complete) @@ -890,8 +931,8 @@ def on_piece(task, index): def on_complete(task): print("Download completed") - server = TorrentServer(level = logging.DEBUG) + server = TorrentServer(level=logging.DEBUG) server.bind("start", on_start) - server.serve(env = True) + server.serve(env=True) else: __path__ = [] diff --git a/src/netius/servers/ws.py b/src/netius/servers/ws.py index bc0356bf8..60106ad2d 100644 --- a/src/netius/servers/ws.py +++ b/src/netius/servers/ws.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -42,6 +33,7 @@ import netius.common + class WSConnection(netius.Connection): """ Connection based class for the websockets connection, @@ -62,21 +54,23 @@ def __init__(self, *args, **kwargs): self.headers = {} def send_ws(self, data): - encoded = netius.common.encode_ws(data, mask = False) + encoded = netius.common.encode_ws(data, mask=False) return self.send(encoded) - def recv_ws(self, size = netius.CHUNK_SIZE): - data = self.recv(size = size) + def recv_ws(self, size=netius.CHUNK_SIZE): + data = self.recv(size=size) decoded = netius.common.decode_ws(data) return decoded def add_buffer(self, data): self.buffer_l.append(data) - def get_buffer(self, delete = True): - if not self.buffer_l: return b"" + def get_buffer(self, delete=True): + if not self.buffer_l: + return b"" buffer = b"".join(self.buffer_l) - if delete: del self.buffer_l[:] + if delete: + del self.buffer_l[:] return buffer def do_handshake(self): @@ -85,16 +79,18 @@ def do_handshake(self): buffer = b"".join(self.buffer_l) end_index = buffer.find(b"\r\n\r\n") - if end_index == -1: raise netius.DataError("Missing data for handshake") + if end_index == -1: + raise netius.DataError("Missing data for handshake") - data = buffer[:end_index + 4] - remaining = buffer[end_index + 4:] + data = buffer[: end_index + 4] + remaining = buffer[end_index + 4 :] lines = data.split(b"\r\n") for line in lines[1:]: values = line.split(b":", 1) values_l = len(values) - if not values_l == 2: continue + if not values_l == 2: + continue key, value = values key = key.strip() @@ -110,7 +106,8 @@ def do_handshake(self): del self.buffer_l[:] self.handshake = True - if remaining: self.add_buffer(remaining) + if remaining: + self.add_buffer(remaining) def accept_key(self): socket_key = self.headers.get("Sec-WebSocket-Key", None) @@ -124,6 +121,7 @@ def accept_key(self): accept_key = netius.legacy.str(accept_key) return accept_key + class WSServer(netius.StreamServer): """ Base class for the creation of websocket server, should @@ -149,8 +147,11 @@ def on_data(self, connection, data): # a problem the (pending) data is added to the buffer buffer = connection.get_buffer() data = buffer + data - try: decoded, data = netius.common.decode_ws(data) - except netius.DataError: connection.add_buffer(data); break + try: + decoded, data = netius.common.decode_ws(data) + except netius.DataError: + connection.add_buffer(data) + break self.on_data_ws(connection, decoded) else: @@ -162,8 +163,10 @@ def on_data(self, connection, data): # current connection in case it fails due to an # handshake error must delay the execution to the # next iteration (not enough data) - try: connection.do_handshake() - except netius.DataError: return + try: + connection.do_handshake() + except netius.DataError: + return # retrieves (and computes) the accept key value for # the current request and sends it as the handshake @@ -181,16 +184,11 @@ def on_data(self, connection, data): # required so that the complete client buffer is flushed data = connection.get_buffer() - def build_connection(self, socket, address, ssl = False): - return WSConnection( - owner = self, - socket = socket, - address = address, - ssl = ssl - ) + def build_connection(self, socket, address, ssl=False): + return WSConnection(owner=self, socket=socket, address=address, ssl=ssl) def send_ws(self, connection, data): - encoded = netius.common.encode_ws(data, mask = False) + encoded = netius.common.encode_ws(data, mask=False) return connection.send(encoded) def on_data_ws(self, connection, data): @@ -215,8 +213,10 @@ def _handshake_response(self, accept_key): the specification and the provided accept key. """ - data = "HTTP/1.1 101 Switching Protocols\r\n" +\ - "Upgrade: websocket\r\n" +\ - "Connection: Upgrade\r\n" +\ - "Sec-WebSocket-Accept: %s\r\n\r\n" % accept_key + data = ( + "HTTP/1.1 101 Switching Protocols\r\n" + + "Upgrade: websocket\r\n" + + "Connection: Upgrade\r\n" + + "Sec-WebSocket-Accept: %s\r\n\r\n" % accept_key + ) return data diff --git a/src/netius/servers/wsgi.py b/src/netius/servers/wsgi.py index 65d2247a6..aa50c87ab 100644 --- a/src/netius/servers/wsgi.py +++ b/src/netius/servers/wsgi.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -54,6 +45,7 @@ content, this should ensure proper resource usage avoiding extreme high levels of resource usage for compression of large files """ + class WSGIServer(http2.HTTP2Server): """ Base class for the creation of a wsgi compliant server @@ -64,9 +56,9 @@ class WSGIServer(http2.HTTP2Server): def __init__( self, app, - mount = "", - decode = True, - compressed_limit = COMPRESSED_LIMIT, + mount="", + decode=True, + compressed_limit=COMPRESSED_LIMIT, *args, **kwargs ): @@ -91,12 +83,13 @@ def on_connection_d(self, connection): def on_serve(self): http2.HTTP2Server.on_serve(self) - if self.env: self.compressed_limit = self.get_env( - "COMPRESSED_LIMIT", self.compressed_limit, cast = int - ) + if self.env: + self.compressed_limit = self.get_env( + "COMPRESSED_LIMIT", self.compressed_limit, cast=int + ) self.info( - "Starting WSGI server with %d bytes limit on compression ..." %\ - self.compressed_limit + "Starting WSGI server with %d bytes limit on compression ..." + % self.compressed_limit ) def on_data_http(self, connection, parser): @@ -105,16 +98,17 @@ def on_data_http(self, connection, parser): # retrieves the path for the current request and then retrieves # the query string part for it also, after that computes the # path info value as the substring of the path without the mount - path = parser.get_path(normalize = True) + path = parser.get_path(normalize=True) query = parser.get_query() - path_info = path[self.mount_l:] + path_info = path[self.mount_l :] # verifies if the path and query values should be encoded and if # that's the case the decoding process should unquote the received # path and then convert it into a valid string representation, this # is especially relevant for the python 3 infra-structure, this is # a tricky process but is required for the wsgi compliance - if self.decode: path_info = self._decode(path_info) + if self.decode: + path_info = self._decode(path_info) # retrieves a possible forwarded protocol value from the request # headers and calculates the appropriate (final scheme value) @@ -127,17 +121,17 @@ def on_data_http(self, connection, parser): # variables that should enable the application to handle the request # and respond to it in accordance environ = dict( - REQUEST_METHOD = parser.method.upper(), - SCRIPT_NAME = self.mount, - PATH_INFO = path_info, - QUERY_STRING = query, - CONTENT_TYPE = parser.headers.get("content-type", ""), - CONTENT_LENGTH = "" if parser.content_l == -1 else parser.content_l, - SERVER_NAME = self.host, - SERVER_PORT = str(self.port), - SERVER_PROTOCOL = parser.version_s, - SERVER_SOFTWARE = SERVER_SOFTWARE, - REMOTE_ADDR = connection.address[0] + REQUEST_METHOD=parser.method.upper(), + SCRIPT_NAME=self.mount, + PATH_INFO=path_info, + QUERY_STRING=query, + CONTENT_TYPE=parser.headers.get("content-type", ""), + CONTENT_LENGTH="" if parser.content_l == -1 else parser.content_l, + SERVER_NAME=self.host, + SERVER_PORT=str(self.port), + SERVER_PROTOCOL=parser.version_s, + SERVER_SOFTWARE=SERVER_SOFTWARE, + REMOTE_ADDR=connection.address[0], ) # updates the environment map with all the structures referring @@ -145,7 +139,7 @@ def on_data_http(self, connection, parser): # as a buffer to be able to handle the file specific operations environ["wsgi.version"] = (1, 0) environ["wsgi.url_scheme"] = scheme - environ["wsgi.input"] = parser.get_message_b(copy = True) + environ["wsgi.input"] = parser.get_message_b(copy=True) environ["wsgi.errors"] = sys.stderr environ["wsgi.multithread"] = False environ["wsgi.multiprocess"] = False @@ -159,7 +153,8 @@ def on_data_http(self, connection, parser): # in the standard specification for key, value in netius.legacy.iteritems(parser.headers): key = "HTTP_" + key.replace("-", "_").upper() - if isinstance(value, (list, tuple)): value = ";".join(value) + if isinstance(value, (list, tuple)): + value = ";".join(value) environ[key] = value # verifies if the connection already has an iterator associated with @@ -167,7 +162,8 @@ def on_data_http(self, connection, parser): # request processing must be delayed for future processing, this is # typically associated with HTTP pipelining if hasattr(connection, "iterator") and connection.iterator: - if not hasattr(connection, "queue"): connection.queue = [] + if not hasattr(connection, "queue"): + connection.queue = [] connection.queue.append(environ) return @@ -202,8 +198,10 @@ def _next_queue(self, connection): # the queue structure that handles the queueing/pipelining of requests # if it does not or the queue is empty returns immediately, as there's # nothing currently pending to be done/processed - if not hasattr(connection, "queue"): return - if not connection.queue: return + if not hasattr(connection, "queue"): + return + if not connection.queue: + return # retrieves the current/first element in the connection queue to for # the processing and then runs the proper callback for the environ @@ -233,25 +231,29 @@ def _start_response(self, connection, status, headers): length = headers.get("Content-Length", -1) length = int(length) length = 0 if status_c in (204, 304) else length - if length == 0: connection.set_encoding(http.PLAIN_ENCODING) + if length == 0: + connection.set_encoding(http.PLAIN_ENCODING) # verifies if the length value of the message payload overflow # the currently defined limit, if that's the case the connection # is set as uncompressed to avoid unnecessary encoding that would # consume a lot of resources (mostly processor) - if length > self.compressed_limit: connection.set_uncompressed() + if length > self.compressed_limit: + connection.set_uncompressed() # tries to determine if the accept ranges value is set and if # that's the case forces the uncompressed encoding to avoid possible # range missmatch due to re-encoding of the content ranges = headers.get("Accept-Ranges", None) - if ranges == "bytes": connection.set_uncompressed() + if ranges == "bytes": + connection.set_uncompressed() # determines if the content range header is set, meaning that # a partial chunk value is being sent if that's the case the # uncompressed encoding is forced to avoid re-encoding issues content_range = headers.get("Content-Range", None) - if content_range: connection.set_uncompressed() + if content_range: + connection.set_uncompressed() # verifies if the current connection is using a chunked based # stream as this will affect some of the decisions that are @@ -264,7 +266,8 @@ def _start_response(self, connection, status, headers): # requires the content length to be defined or the target # encoding type to be chunked has_length = not length == -1 - if not has_length: parser.keep_alive = is_chunked + if not has_length: + parser.keep_alive = is_chunked # applies the base (static) headers to the headers map and then # applies the parser based values to the headers map, these @@ -279,10 +282,7 @@ def _start_response(self, connection, status, headers): # should serialize the various headers and send them through the # current connection according to the currently associated protocol connection.send_header( - headers = headers, - version = version_s, - code = status_c, - code_s = status_m + headers=headers, version=version_s, code=status_c, code_s=status_m ) def _send_part(self, connection): @@ -301,12 +301,15 @@ def _send_part(self, connection): # case the stop iteration is received sets the is final flag # so that no more data is sent through the connection and # releases the iterator from the connection - try: data = next(iterator) if iterator else None + try: + data = next(iterator) if iterator else None except StopIteration as exception: # tries to extract possible data coming from the exception # return value and sets the is final flag otherwise - if exception.args: data = exception.args[0] - else: is_final = True + if exception.args: + data = exception.args[0] + else: + is_final = True # releases the iterator from the connection as it's no longer # considered to be valid for the current connection context @@ -317,20 +320,20 @@ def _send_part(self, connection): # for the handling on the end of the iteration is_future = netius.is_future(data) if is_future: + def on_partial(future, value): - if not value: return - if not self.is_main(): return self.delay_s( - lambda: on_partial(future, value) - ) - connection.send_part(value, final = False) + if not value: + return + if not self.is_main(): + return self.delay_s(lambda: on_partial(future, value)) + connection.send_part(value, final=False) def on_done(future): # in case the current threads is not the main one (running using # thread pool) delays the current callback to be called upon # the next main event loop ticks - if not self.is_main(): return self.delay_s( - lambda: on_done(future) - ) + if not self.is_main(): + return self.delay_s(lambda: on_done(future)) # unsets the future from the connection as it has been # completely processed, not going to be used anymore @@ -353,10 +356,8 @@ def on_done(future): # otherwise runs the send part operation on the next tick so # that it gets handled as fast as possible, this should continue # the iteration on the overall async generator - else: self.delay( - lambda: self._send_part(connection), - immediately = True - ) + else: + self.delay(lambda: self._send_part(connection), immediately=True) def on_ready(): return connection.wready @@ -383,19 +384,18 @@ def on_closed(): # ensures that the provided data is a byte sequence as expected # by the underlying server infra-structure - if data: data = netius.legacy.bytes(data) + if data: + data = netius.legacy.bytes(data) # in case the final flag is set runs the flush operation in the # connection setting the proper callback method for it so that # the connection state is defined in the proper way (closed or # kept untouched) otherwise sends the retrieved data setting the # callback to the current method so that more that is sent - if is_final: connection.flush_s(callback = self._final) - else: connection.send_part( - data, - final = False, - callback = self._send_part - ) + if is_final: + connection.flush_s(callback=self._final) + else: + connection.send_part(data, final=False, callback=self._send_part) def _final(self, connection): # retrieves the parser of the current connection and then determines @@ -405,7 +405,9 @@ def _final(self, connection): # in case the connection is not meant to be kept alive must # must call the proper underlying close operation (expected) - if not keep_alive: self._close(connection); return + if not keep_alive: + self._close(connection) + return # the map of environment must be destroyed properly, avoiding # any possible memory leak for the current handling and then the @@ -415,7 +417,7 @@ def _final(self, connection): self._next_queue(connection) def _close(self, connection): - connection.close(flush = True) + connection.close(flush=True) def _release(self, connection): self._release_future(connection) @@ -427,7 +429,8 @@ def _release_future(self, connection): # verifies if there's a future associated/running under the # current connection, if that's not the case returns immediately future = hasattr(connection, "future") and connection.future - if not future: return + if not future: + return # runs the cancel operation on the future, note that this # operation is only performed in case the future is still @@ -443,13 +446,15 @@ def _release_iterator(self, connection): # in the connection so that it may be close in case that's # required, this is mandatory to avoid any memory leak iterator = hasattr(connection, "iterator") and connection.iterator - if not iterator: return + if not iterator: + return # verifies if the close attributes is defined in the iterator # and if that's the case calls the close method in order to # avoid any memory leak caused by the generator has_close = hasattr(iterator, "close") - if has_close: iterator.close() + if has_close: + iterator.close() # unsets the iterator attribute in the connection object so that # it may no longer be used by any chunk of logic code @@ -459,7 +464,8 @@ def _release_environ(self, connection): # tries to retrieve the map of environment for the current # connection and in case it does not exists returns immediately environ = hasattr(connection, "environ") and connection.environ - if not environ: return + if not environ: + return # retrieves the input stream (buffer) and closes it as there's # not going to be any further operation in it (avoids leak) @@ -482,7 +488,8 @@ def _release_queue(self, connection): # connection in case it does not exist returns immediately as # there's no queue element to be release/cleared queue = hasattr(connection, "queue") and connection.queue - if not queue: return + if not queue: + return # iterates over the complete set of queue elements (environ # based maps) to clear their elements properly @@ -523,6 +530,7 @@ def _decode(self, value): value = netius.legacy.str(value) return value + if __name__ == "__main__": import logging @@ -533,12 +541,12 @@ def app(environ, start_response): headers = ( ("Content-Length", content_l), ("Content-type", "text/plain"), - ("Connection", "keep-alive") + ("Connection", "keep-alive"), ) start_response(status, headers) yield contents - server = WSGIServer(app = app, level = logging.INFO) - server.serve(env = True) + server = WSGIServer(app=app, level=logging.INFO) + server.serve(env=True) else: __path__ = [] diff --git a/src/netius/sh/__init__.py b/src/netius/sh/__init__.py index fae40397c..7a1625c0d 100644 --- a/src/netius/sh/__init__.py +++ b/src/netius/sh/__init__.py @@ -19,15 +19,6 @@ # You should have received a copy of the Apache License along with # Hive Netius System. If not, see . -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ diff --git a/src/netius/sh/auth.py b/src/netius/sh/auth.py index b84188194..7aad16e07 100644 --- a/src/netius/sh/auth.py +++ b/src/netius/sh/auth.py @@ -19,15 +19,6 @@ # You should have received a copy of the Apache License along with # Hive Netius System. If not, see . -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -38,8 +29,10 @@ from . import base -def generate(password, type = "sha256", salt = "netius"): - print(netius.Auth.generate(password, type = type, salt = salt)) + +def generate(password, type="sha256", salt="netius"): + print(netius.Auth.generate(password, type=type, salt=salt)) + if __name__ == "__main__": base.sh_call(globals(), locals()) diff --git a/src/netius/sh/base.py b/src/netius/sh/base.py index cefe0acd5..adc2e832a 100644 --- a/src/netius/sh/base.py +++ b/src/netius/sh/base.py @@ -22,22 +22,14 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ __license__ = "Apache License, Version 2.0" """ The license for the module """ -def sh_call(globals = {}, locals = {}): + +def sh_call(globals={}, locals={}): import sys if not len(sys.argv) > 1: diff --git a/src/netius/sh/dkim.py b/src/netius/sh/dkim.py index 536485a9c..88cd75bfd 100644 --- a/src/netius/sh/dkim.py +++ b/src/netius/sh/dkim.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -41,24 +32,32 @@ from . import base -def generate(domain, suffix = None, number_bits = 1024): + +def generate(domain, suffix=None, number_bits=1024): number_bits = int(number_bits) - result = netius.common.dkim_generate(domain, suffix = suffix, number_bits = number_bits) + result = netius.common.dkim_generate(domain, suffix=suffix, number_bits=number_bits) print(result["dns_txt"]) print(result["private_pem"]) + def sign(email_path, key_path, selector, domain): file = open(email_path, "rb") - try: contents = file.read() - finally: file.close() + try: + contents = file.read() + finally: + file.close() contents = contents.lstrip() private_key = netius.common.open_private_key(key_path) signature = netius.common.dkim_sign(contents, selector, domain, private_key) file = open(email_path, "wb") - try: file.write(signature); file.write(contents) - finally: file.close() + try: + file.write(signature) + file.write(contents) + finally: + file.close() + if __name__ == "__main__": base.sh_call(globals(), locals()) diff --git a/src/netius/sh/rsa.py b/src/netius/sh/rsa.py index 19483828e..43049d93e 100644 --- a/src/netius/sh/rsa.py +++ b/src/netius/sh/rsa.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -43,19 +34,23 @@ from . import base + def read_private(path): private_key = netius.common.open_private_key(path) pprint.pprint(private_key) + def read_public(path): public_key = netius.common.open_public_key(path) pprint.pprint(public_key) + def private_to_public(private_path, public_path): private_key = netius.common.open_private_key(private_path) public_key = netius.common.private_to_public(private_key) netius.common.write_public_key(public_path, public_key) + if __name__ == "__main__": base.sh_call(globals(), locals()) else: diff --git a/src/netius/sh/smtp.py b/src/netius/sh/smtp.py index dc8c605cc..6290c250b 100644 --- a/src/netius/sh/smtp.py +++ b/src/netius/sh/smtp.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -41,31 +32,28 @@ import netius.clients + def send( - path, - sender, - receiver, - host = None, - port = 25, - username = None, - password = None, - stls = True + path, sender, receiver, host=None, port=25, username=None, password=None, stls=True ): file = open(path, "rb") - try: contents = file.read() - finally: file.close() - smtp_client = netius.clients.SMTPClient(auto_close = True) + try: + contents = file.read() + finally: + file.close() + smtp_client = netius.clients.SMTPClient(auto_close=True) smtp_client.message( [sender], [receiver], contents, - host = host, - port = port, - username = username, - password = password, - stls = stls + host=host, + port=port, + username=username, + password=password, + stls=stls, ) + if __name__ == "__main__": base.sh_call(globals(), locals()) else: diff --git a/src/netius/test/__init__.py b/src/netius/test/__init__.py index 52c2dd14c..81bce846f 100644 --- a/src/netius/test/__init__.py +++ b/src/netius/test/__init__.py @@ -19,15 +19,6 @@ # You should have received a copy of the Apache License along with # Hive Netius System. If not, see . -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ diff --git a/src/netius/test/auth/__init__.py b/src/netius/test/auth/__init__.py index 52c2dd14c..81bce846f 100644 --- a/src/netius/test/auth/__init__.py +++ b/src/netius/test/auth/__init__.py @@ -19,15 +19,6 @@ # You should have received a copy of the Apache License along with # Hive Netius System. If not, see . -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ diff --git a/src/netius/test/auth/allow.py b/src/netius/test/auth/allow.py index f15c0914f..7f27d6019 100644 --- a/src/netius/test/auth/allow.py +++ b/src/netius/test/auth/allow.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -41,6 +32,7 @@ import netius.auth + class AllowAuthTest(unittest.TestCase): def test_simple(self): diff --git a/src/netius/test/auth/deny.py b/src/netius/test/auth/deny.py index 711a42ff9..bb7a56933 100644 --- a/src/netius/test/auth/deny.py +++ b/src/netius/test/auth/deny.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -41,6 +32,7 @@ import netius.auth + class DenyAuthTest(unittest.TestCase): def test_simple(self): diff --git a/src/netius/test/auth/simple.py b/src/netius/test/auth/simple.py index 8857cb625..fc26ca7da 100644 --- a/src/netius/test/auth/simple.py +++ b/src/netius/test/auth/simple.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -41,6 +32,7 @@ import netius.auth + class SimpleAuthTest(unittest.TestCase): def test_simple(self): diff --git a/src/netius/test/base/__init__.py b/src/netius/test/base/__init__.py index 52c2dd14c..81bce846f 100644 --- a/src/netius/test/base/__init__.py +++ b/src/netius/test/base/__init__.py @@ -19,15 +19,6 @@ # You should have received a copy of the Apache License along with # Hive Netius System. If not, see . -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ diff --git a/src/netius/test/base/asynchronous.py b/src/netius/test/base/asynchronous.py index 2bb9108bc..082eda5c4 100644 --- a/src/netius/test/base/asynchronous.py +++ b/src/netius/test/base/asynchronous.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -41,15 +32,16 @@ import netius + class AsynchronousTest(unittest.TestCase): def test_basic(self): - loop = netius.get_loop(asyncio = False) + loop = netius.get_loop(asyncio=False) self.assertNotEqual(loop, None) self.assertEqual(isinstance(loop, netius.Base), True) - future = netius.build_future(compat = False, asyncio = False) + future = netius.build_future(compat=False, asyncio=False) self.assertNotEqual(future, None) self.assertEqual(isinstance(future, netius.Future), True) @@ -57,7 +49,7 @@ def test_basic(self): self.assertEqual(isinstance(future._loop, netius.Base), True) previous = loop - loop = netius.get_loop(_compat = True) + loop = netius.get_loop(_compat=True) self.assertNotEqual(loop, None) @@ -66,7 +58,7 @@ def test_basic(self): self.assertEqual(loop, previous._compat) self.assertEqual(loop._loop_ref(), previous) - loop = netius.get_loop(asyncio = True) + loop = netius.get_loop(asyncio=True) self.assertNotEqual(loop, None) diff --git a/src/netius/test/base/common.py b/src/netius/test/base/common.py index 4d6f8a6f2..81a30a4a7 100644 --- a/src/netius/test/base/common.py +++ b/src/netius/test/base/common.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -41,6 +32,7 @@ import netius + class BaseTest(unittest.TestCase): def test_resolve_hostname(self): diff --git a/src/netius/test/base/config.py b/src/netius/test/base/config.py index c3b3cedd3..d5da83bda 100644 --- a/src/netius/test/base/config.py +++ b/src/netius/test/base/config.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -41,6 +32,7 @@ import netius + class ConfigTest(unittest.TestCase): def test_basic(self): @@ -49,28 +41,28 @@ def test_basic(self): self.assertEqual(result, "name") - result = netius.conf("NAME", cast = str) + result = netius.conf("NAME", cast=str) self.assertEqual(result, "name") self.assertEqual(type(result), str) - result = netius.conf("NAME", cast = "str") + result = netius.conf("NAME", cast="str") self.assertEqual(result, "name") self.assertEqual(type(result), str) netius.conf_s("AGE", "10") - result = netius.conf("AGE", cast = int) + result = netius.conf("AGE", cast=int) self.assertEqual(result, 10) self.assertEqual(type(result), int) - result = netius.conf("AGE", cast = "int") + result = netius.conf("AGE", cast="int") self.assertEqual(result, 10) self.assertEqual(type(result), int) - result = netius.conf("AGE", cast = str) + result = netius.conf("AGE", cast=str) self.assertEqual(result, "10") self.assertEqual(type(result), str) @@ -81,10 +73,10 @@ def test_basic(self): def test_none(self): netius.conf_s("AGE", None) - result = netius.conf("AGE", cast = int) + result = netius.conf("AGE", cast=int) self.assertEqual(result, None) - result = netius.conf("HEIGHT", cast = int) + result = netius.conf("HEIGHT", cast=int) self.assertEqual(result, None) diff --git a/src/netius/test/base/tls.py b/src/netius/test/base/tls.py index e0cd92a95..7b9570797 100644 --- a/src/netius/test/base/tls.py +++ b/src/netius/test/base/tls.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -41,6 +32,7 @@ import netius.common + class TLSTest(unittest.TestCase): def test_fingerprint(self): @@ -48,26 +40,26 @@ def test_fingerprint(self): result = netius.fingerprint(key_der) self.assertEqual(result, "5b4e55fa5ba652a9cb0c3be2dcfa303b5ae647d6") - cer_der = netius.common.open_pem_key(netius.SSL_CER_PATH, token = "CERTIFICATE") + cer_der = netius.common.open_pem_key(netius.SSL_CER_PATH, token="CERTIFICATE") result = netius.fingerprint(cer_der) self.assertEqual(result, "55ed3769f523281134d87393ffda7f78c9dff786") def test_match_hostname(self): certificate = dict( - subject = ((("commonName", "domain.com"),),), - subjectAltName = ( + subject=((("commonName", "domain.com"),),), + subjectAltName=( ("DNS", "api.domain.com"), ("DNS", "embed.domain.com"), ("DNS", "instore.domain.com"), ("DNS", "domain.com"), - ("DNS", "www.domain.com") + ("DNS", "www.domain.com"), ), - version = 3 + version=3, ) netius.match_hostname(certificate, "domain.com") self.assertRaises( BaseException, - lambda: netius.match_hostname(certificate, "other.domain.com") + lambda: netius.match_hostname(certificate, "other.domain.com"), ) def test_dnsname_match(self): diff --git a/src/netius/test/base/transport.py b/src/netius/test/base/transport.py index 393854bbd..138541fca 100644 --- a/src/netius/test/base/transport.py +++ b/src/netius/test/base/transport.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -41,6 +32,7 @@ import netius + class TransportTest(unittest.TestCase): def test_write_closing(self): diff --git a/src/netius/test/clients/__init__.py b/src/netius/test/clients/__init__.py index 52c2dd14c..81bce846f 100644 --- a/src/netius/test/clients/__init__.py +++ b/src/netius/test/clients/__init__.py @@ -19,15 +19,6 @@ # You should have received a copy of the Apache License along with # Hive Netius System. If not, see . -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ diff --git a/src/netius/test/clients/http.py b/src/netius/test/clients/http.py index 5388e0639..9b67d5d06 100644 --- a/src/netius/test/clients/http.py +++ b/src/netius/test/clients/http.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -42,6 +33,7 @@ import netius.clients + class HTTPClientTest(unittest.TestCase): def setUp(self): @@ -50,18 +42,14 @@ def setUp(self): def test_simple(self): result = netius.clients.HTTPClient.method_s( - "GET", - "http://%s/get" % self.httpbin, - asynchronous = False + "GET", "http://%s/get" % self.httpbin, asynchronous=False ) self.assertEqual(result["code"], 200) self.assertNotEqual(len(result["data"]), 0) self.assertNotEqual(json.loads(result["data"].decode("utf-8")), None) result = netius.clients.HTTPClient.method_s( - "GET", - "https://%s/get" % self.httpbin, - asynchronous = False + "GET", "https://%s/get" % self.httpbin, asynchronous=False ) self.assertEqual(result["code"], 200) self.assertNotEqual(len(result["data"]), 0) @@ -69,19 +57,13 @@ def test_simple(self): def test_timeout(self): result = netius.clients.HTTPClient.method_s( - "GET", - "http://%s/delay/3" % self.httpbin, - timeout = 1, - asynchronous = False + "GET", "http://%s/delay/3" % self.httpbin, timeout=1, asynchronous=False ) self.assertEqual(result["error"], "timeout") self.assertEqual(result["message"].startswith("Timeout on receive"), True) result = netius.clients.HTTPClient.method_s( - "GET", - "http://%s/delay/1" % self.httpbin, - timeout = 30, - asynchronous = False + "GET", "http://%s/delay/1" % self.httpbin, timeout=30, asynchronous=False ) self.assertEqual(result.get("error", None), None) self.assertEqual(result.get("message", None), None) @@ -91,18 +73,14 @@ def test_timeout(self): def test_compression(self): result = netius.clients.HTTPClient.method_s( - "GET", - "http://%s/gzip" % self.httpbin, - asynchronous = False + "GET", "http://%s/gzip" % self.httpbin, asynchronous=False ) self.assertEqual(result["code"], 200) self.assertNotEqual(len(result["data"]), 0) self.assertNotEqual(json.loads(result["data"].decode("utf-8")), None) result = netius.clients.HTTPClient.method_s( - "GET", - "http://%s/deflate" % self.httpbin, - asynchronous = False + "GET", "http://%s/deflate" % self.httpbin, asynchronous=False ) self.assertEqual(result["code"], 200) self.assertNotEqual(len(result["data"]), 0) @@ -110,9 +88,7 @@ def test_compression(self): def test_headers(self): result = netius.clients.HTTPClient.method_s( - "GET", - "http://%s/headers" % self.httpbin, - asynchronous = False + "GET", "http://%s/headers" % self.httpbin, asynchronous=False ) payload = json.loads(result["data"].decode("utf-8")) headers = payload["headers"] @@ -123,9 +99,7 @@ def test_headers(self): self.assertNotEqual(headers.get("User-Agent", ""), "") result = netius.clients.HTTPClient.method_s( - "GET", - "http://%s/image/png" % self.httpbin, - asynchronous = False + "GET", "http://%s/image/png" % self.httpbin, asynchronous=False ) self.assertEqual(result["code"], 200) self.assertNotEqual(len(result["data"]), 0) diff --git a/src/netius/test/common/__init__.py b/src/netius/test/common/__init__.py index 52c2dd14c..81bce846f 100644 --- a/src/netius/test/common/__init__.py +++ b/src/netius/test/common/__init__.py @@ -19,15 +19,6 @@ # You should have received a copy of the Apache License along with # Hive Netius System. If not, see . -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ diff --git a/src/netius/test/common/calc.py b/src/netius/test/common/calc.py index b102a663f..4759e3fc7 100644 --- a/src/netius/test/common/calc.py +++ b/src/netius/test/common/calc.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -41,6 +32,7 @@ import netius.common + class CalcTest(unittest.TestCase): def test_jacobi_witness(self): diff --git a/src/netius/test/common/dkim.py b/src/netius/test/common/dkim.py index c74b4cae8..47dc56f36 100644 --- a/src/netius/test/common/dkim.py +++ b/src/netius/test/common/dkim.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -60,10 +51,10 @@ 5NGFN3d6+ZMDzTl9aPqoSAiRfLXFmXMNTFFfNerMSd4YukU8+32kbybY/SdkjNnG\ N3qtMUEjP3bw9X6lAgMBAAE=" -DNS_LABEL = b"20160523113052._domainkey.netius.hive.pt. IN TXT\ -\"k=rsa; p=MIGeMA0GCSqGSIb3DQEBAQUAA4GMADCBiAKBgIEVkl9OywdFc6Q8teLWmjtW/o7kFYU\ +DNS_LABEL = b'20160523113052._domainkey.netius.hive.pt. IN TXT\ +"k=rsa; p=MIGeMA0GCSqGSIb3DQEBAQUAA4GMADCBiAKBgIEVkl9OywdFc6Q8teLWmjtW/o7kFYU\ Kcisv38Wi7qaIcuFAzsRyVCM+480egfJfdyl1qsRuVGStgnbi+kyiRkGfoxpE5NGFN3d6+ZMDzTl9a\ -PqoSAiRfLXFmXMNTFFfNerMSd4YukU8+32kbybY/SdkjNnGN3qtMUEjP3bw9X6lAgMBAAE=\"" +PqoSAiRfLXFmXMNTFFfNerMSd4YukU8+32kbybY/SdkjNnGN3qtMUEjP3bw9X6lAgMBAAE="' MESSAGE = b"Header: Value\r\n\r\nHello World" @@ -71,6 +62,7 @@ i=@netius.hive.pt; l=13; q=dns/txt; s=20160523113052; t=1464003802;\r\n\ h=Header; bh=sIAi0xXPHrEtJmW97Q5q9AZTwKC+l1Iy+0m8vQIc/DY=; b=TTDenBUdjKRjBAORnX2mhIZLVdeK2R4xfLPYERKthDvKsDvfdFgv4znf0BpyV/7gjSc7v2VAoeDxZSeYueZ8xtI2XEU2VoJFRy9Ccm0aFnFLy5H3yldK3xye4pKQ+8goRfjrlL/AMfaoDNJsEXXw1+ZPaRYeKnB1OwNTOC2a194=\r\n" + class DKIMTest(unittest.TestCase): def test_simple(self): @@ -80,6 +72,6 @@ def test_simple(self): "20160523113052", "netius.hive.pt", private_key, - creation = 1464003802 + creation=1464003802, ) self.assertEqual(result, RESULT) diff --git a/src/netius/test/common/http.py b/src/netius/test/common/http.py index 7e9df404c..ec36e1a77 100644 --- a/src/netius/test/common/http.py +++ b/src/netius/test/common/http.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -103,14 +94,11 @@ \r\n\ Hello World" + class HTTPParserTest(unittest.TestCase): def test_simple(self): - parser = netius.common.HTTPParser( - self, - type = netius.common.REQUEST, - store = True - ) + parser = netius.common.HTTPParser(self, type=netius.common.REQUEST, store=True) try: parser.parse(SIMPLE_REQUEST) message = parser.get_message() @@ -127,11 +115,7 @@ def test_simple(self): parser.clear() def test_chunked(self): - parser = netius.common.HTTPParser( - self, - type = netius.common.REQUEST, - store = True - ) + parser = netius.common.HTTPParser(self, type=netius.common.REQUEST, store=True) try: parser.parse(CHUNKED_REQUEST) message = parser.get_message() @@ -148,11 +132,7 @@ def test_chunked(self): parser.clear() def test_malformed(self): - parser = netius.common.HTTPParser( - self, - type = netius.common.REQUEST, - store = True - ) + parser = netius.common.HTTPParser(self, type=netius.common.REQUEST, store=True) try: parser.parse(EXTRA_SPACES_REQUEST) message = parser.get_message() @@ -168,11 +148,7 @@ def test_malformed(self): finally: parser.clear() - parser = netius.common.HTTPParser( - self, - type = netius.common.REQUEST, - store = True - ) + parser = netius.common.HTTPParser(self, type=netius.common.REQUEST, store=True) try: parser.parse(INVALID_HEADERS_REQUEST) message = parser.get_message() @@ -189,112 +165,87 @@ def test_malformed(self): finally: parser.clear() - parser = netius.common.HTTPParser( - self, - type = netius.common.REQUEST, - store = True - ) + parser = netius.common.HTTPParser(self, type=netius.common.REQUEST, store=True) try: if hasattr(self, "assertRaisesRegexp"): self.assertRaisesRegexp( netius.ParserError, "Invalid header value", - lambda: parser.parse(INVALID_HEADERS_TAB_REQUEST) + lambda: parser.parse(INVALID_HEADERS_TAB_REQUEST), ) else: self.assertRaises( netius.ParserError, - lambda: parser.parse(INVALID_HEADERS_TAB_REQUEST) + lambda: parser.parse(INVALID_HEADERS_TAB_REQUEST), ) finally: parser.clear() - parser = netius.common.HTTPParser( - self, - type = netius.common.REQUEST, - store = True - ) + parser = netius.common.HTTPParser(self, type=netius.common.REQUEST, store=True) try: if hasattr(self, "assertRaisesRegexp"): self.assertRaisesRegexp( netius.ParserError, "Invalid header value", - lambda: parser.parse(INVALID_HEADERS_NEWLINE_REQUEST) + lambda: parser.parse(INVALID_HEADERS_NEWLINE_REQUEST), ) else: self.assertRaises( netius.ParserError, - lambda: parser.parse(INVALID_HEADERS_NEWLINE_REQUEST) + lambda: parser.parse(INVALID_HEADERS_NEWLINE_REQUEST), ) finally: parser.clear() - parser = netius.common.HTTPParser( - self, - type = netius.common.REQUEST, - store = True - ) + parser = netius.common.HTTPParser(self, type=netius.common.REQUEST, store=True) try: if hasattr(self, "assertRaisesRegexp"): self.assertRaisesRegexp( netius.ParserError, "Chunked encoding with content length set", - lambda: parser.parse(INVALID_CHUNKED_REQUEST) + lambda: parser.parse(INVALID_CHUNKED_REQUEST), ) else: self.assertRaises( - netius.ParserError, - lambda: parser.parse(INVALID_CHUNKED_REQUEST) + netius.ParserError, lambda: parser.parse(INVALID_CHUNKED_REQUEST) ) finally: parser.clear() - parser = netius.common.HTTPParser( - self, - type = netius.common.REQUEST, - store = True - ) + parser = netius.common.HTTPParser(self, type=netius.common.REQUEST, store=True) try: if hasattr(self, "assertRaisesRegexp"): self.assertRaisesRegexp( netius.ParserError, "Invalid transfer encoding", - lambda: parser.parse(INVALID_TRANSFER_ENCODING_REQUEST) + lambda: parser.parse(INVALID_TRANSFER_ENCODING_REQUEST), ) else: self.assertRaises( netius.ParserError, - lambda: parser.parse(INVALID_TRANSFER_ENCODING_REQUEST) + lambda: parser.parse(INVALID_TRANSFER_ENCODING_REQUEST), ) finally: parser.clear() - parser = netius.common.HTTPParser( - self, - type = netius.common.REQUEST, - store = True - ) + parser = netius.common.HTTPParser(self, type=netius.common.REQUEST, store=True) try: if hasattr(self, "assertRaisesRegexp"): self.assertRaisesRegexp( netius.ParserError, "Invalid status line ", - lambda: parser.parse(INVALID_STATUS_REQUEST) + lambda: parser.parse(INVALID_STATUS_REQUEST), ) else: self.assertRaises( - netius.ParserError, - lambda: parser.parse(INVALID_STATUS_REQUEST) + netius.ParserError, lambda: parser.parse(INVALID_STATUS_REQUEST) ) finally: parser.clear() def test_file(self): parser = netius.common.HTTPParser( - self, - type = netius.common.REQUEST, - store = True, - file_limit = -1 + self, type=netius.common.REQUEST, store=True, file_limit=-1 ) try: parser.parse(CHUNKED_REQUEST) @@ -312,10 +263,7 @@ def test_file(self): def test_no_store(self): parser = netius.common.HTTPParser( - self, - type = netius.common.REQUEST, - store = False, - file_limit = -1 + self, type=netius.common.REQUEST, store=False, file_limit=-1 ) try: parser.parse(CHUNKED_REQUEST) @@ -325,11 +273,7 @@ def test_no_store(self): parser.clear() def test_clear(self): - parser = netius.common.HTTPParser( - self, - type = netius.common.REQUEST, - store = True - ) + parser = netius.common.HTTPParser(self, type=netius.common.REQUEST, store=True) parser.parse(SIMPLE_REQUEST) parser.clear() self.assertEqual(parser.type, netius.common.REQUEST) diff --git a/src/netius/test/common/mime.py b/src/netius/test/common/mime.py index 48f41f8d7..7c257b70d 100644 --- a/src/netius/test/common/mime.py +++ b/src/netius/test/common/mime.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -41,6 +32,7 @@ import netius.common + class MimeTest(unittest.TestCase): def test_headers(self): diff --git a/src/netius/test/common/rsa.py b/src/netius/test/common/rsa.py index 4f763d3b7..8cde2c21a 100644 --- a/src/netius/test/common/rsa.py +++ b/src/netius/test/common/rsa.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -41,6 +32,7 @@ import netius.common + class RSATest(unittest.TestCase): def test_rsa_crypt(self): diff --git a/src/netius/test/common/setup.py b/src/netius/test/common/setup.py index 31a80ac78..460385289 100644 --- a/src/netius/test/common/setup.py +++ b/src/netius/test/common/setup.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -42,10 +33,11 @@ import netius.common + class CommonTest(unittest.TestCase): def test__download_ca(self): - netius.common.ensure_ca(path = "test.ca") + netius.common.ensure_ca(path="test.ca") file = open("test.ca", "rb") try: data = file.read() diff --git a/src/netius/test/common/util.py b/src/netius/test/common/util.py index 2957b1398..2c2153f04 100644 --- a/src/netius/test/common/util.py +++ b/src/netius/test/common/util.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -41,6 +32,7 @@ import netius.common + class UtilTest(unittest.TestCase): def test_is_ip4(self): @@ -123,34 +115,34 @@ def test_integer_to_bytes(self): self.assertEqual(result, b"Hello World") def test_size_round_unit(self): - result = netius.common.size_round_unit(209715200, space = True) + result = netius.common.size_round_unit(209715200, space=True) self.assertEqual(result, "200 MB") - result = netius.common.size_round_unit(20480, space = True) + result = netius.common.size_round_unit(20480, space=True) self.assertEqual(result, "20 KB") - result = netius.common.size_round_unit(2048, reduce = False, space = True) + result = netius.common.size_round_unit(2048, reduce=False, space=True) self.assertEqual(result, "2.00 KB") - result = netius.common.size_round_unit(2500, space = True) + result = netius.common.size_round_unit(2500, space=True) self.assertEqual(result, "2.44 KB") - result = netius.common.size_round_unit(2500, reduce = False, space = True) + result = netius.common.size_round_unit(2500, reduce=False, space=True) self.assertEqual(result, "2.44 KB") result = netius.common.size_round_unit(1) self.assertEqual(result, "1B") - result = netius.common.size_round_unit(2048, minimum = 2049, reduce = False) + result = netius.common.size_round_unit(2048, minimum=2049, reduce=False) self.assertEqual(result, "2048B") - result = netius.common.size_round_unit(2049, places = 4, reduce = False) + result = netius.common.size_round_unit(2049, places=4, reduce=False) self.assertEqual(result, "2.001KB") - result = netius.common.size_round_unit(2048, places = 0, reduce = False) + result = netius.common.size_round_unit(2048, places=0, reduce=False) self.assertEqual(result, "2KB") - result = netius.common.size_round_unit(2049, places = 0, reduce = False) + result = netius.common.size_round_unit(2049, places=0, reduce=False) self.assertEqual(result, "2KB") def test_verify(self): @@ -164,7 +156,7 @@ def test_verify(self): self.assertRaises( netius.NetiusError, - lambda: netius.common.verify(1 == 2, exception = netius.NetiusError) + lambda: netius.common.verify(1 == 2, exception=netius.NetiusError), ) def test_verify_equal(self): @@ -174,11 +166,13 @@ def test_verify_equal(self): result = netius.common.verify_equal("hello", "hello") self.assertEqual(result, None) - self.assertRaises(netius.AssertionError, lambda: netius.common.verify_equal(1, 2)) + self.assertRaises( + netius.AssertionError, lambda: netius.common.verify_equal(1, 2) + ) self.assertRaises( netius.NetiusError, - lambda: netius.common.verify_equal(1, 2, exception = netius.NetiusError) + lambda: netius.common.verify_equal(1, 2, exception=netius.NetiusError), ) def test_verify_not_equal(self): @@ -188,11 +182,13 @@ def test_verify_not_equal(self): result = netius.common.verify_not_equal("hello", "world") self.assertEqual(result, None) - self.assertRaises(netius.AssertionError, lambda: netius.common.verify_not_equal(1, 1)) + self.assertRaises( + netius.AssertionError, lambda: netius.common.verify_not_equal(1, 1) + ) self.assertRaises( netius.NetiusError, - lambda: netius.common.verify_not_equal(1, 1, exception = netius.NetiusError) + lambda: netius.common.verify_not_equal(1, 1, exception=netius.NetiusError), ) def test_verify_type(self): @@ -205,18 +201,25 @@ def test_verify_type(self): result = netius.common.verify_type(None, int) self.assertEqual(result, None) - self.assertRaises(netius.AssertionError, lambda: netius.common.verify_type(1, str)) + self.assertRaises( + netius.AssertionError, lambda: netius.common.verify_type(1, str) + ) self.assertRaises( netius.NetiusError, - lambda: netius.common.verify_type(1, str, exception = netius.NetiusError) + lambda: netius.common.verify_type(1, str, exception=netius.NetiusError), ) - self.assertRaises(netius.AssertionError, lambda: netius.common.verify_type(None, str, null = False)) + self.assertRaises( + netius.AssertionError, + lambda: netius.common.verify_type(None, str, null=False), + ) self.assertRaises( netius.NetiusError, - lambda: netius.common.verify_type(None, str, null = False, exception = netius.NetiusError) + lambda: netius.common.verify_type( + None, str, null=False, exception=netius.NetiusError + ), ) def test_verify_many(self): @@ -226,14 +229,17 @@ def test_verify_many(self): result = netius.common.verify_many(("hello" == "hello",)) self.assertEqual(result, None) - self.assertRaises(netius.AssertionError, lambda: netius.common.verify_many((1 == 2,))) + self.assertRaises( + netius.AssertionError, lambda: netius.common.verify_many((1 == 2,)) + ) - self.assertRaises(netius.AssertionError, lambda: netius.common.verify_many((1 == 1, 1 == 2))) + self.assertRaises( + netius.AssertionError, lambda: netius.common.verify_many((1 == 1, 1 == 2)) + ) self.assertRaises( netius.NetiusError, lambda: netius.common.verify_many( - (1 == 1, 1 == 2), - exception = netius.NetiusError - ) + (1 == 1, 1 == 2), exception=netius.NetiusError + ), ) diff --git a/src/netius/test/extra/__init__.py b/src/netius/test/extra/__init__.py index 52c2dd14c..81bce846f 100644 --- a/src/netius/test/extra/__init__.py +++ b/src/netius/test/extra/__init__.py @@ -19,15 +19,6 @@ # You should have received a copy of the Apache License along with # Hive Netius System. If not, see . -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ diff --git a/src/netius/test/extra/proxy_r.py b/src/netius/test/extra/proxy_r.py index 33499e4a7..8312a42e0 100644 --- a/src/netius/test/extra/proxy_r.py +++ b/src/netius/test/extra/proxy_r.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -42,17 +33,13 @@ import netius.extra + class ReverseProxyServerTest(unittest.TestCase): def setUp(self): unittest.TestCase.setUp(self) self.server = netius.extra.ReverseProxyServer( - hosts = { - "host.com" : "http://localhost" - }, - alias = { - "alias.host.com" : "host.com" - } + hosts={"host.com": "http://localhost"}, alias={"alias.host.com": "host.com"} ) def tearDown(self): @@ -61,6 +48,6 @@ def tearDown(self): def test_alias(self): Parser = collections.namedtuple("Parser", "headers") - parser = Parser(headers = dict(host = "alias.host.com")) + parser = Parser(headers=dict(host="alias.host.com")) result = self.server.rules_host(None, parser) self.assertEqual(result, ("http://localhost", None)) diff --git a/src/netius/test/extra/smtp_r.py b/src/netius/test/extra/smtp_r.py index b25609a1b..7e9b3cdb5 100644 --- a/src/netius/test/extra/smtp_r.py +++ b/src/netius/test/extra/smtp_r.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -63,13 +54,12 @@ bh=sIAi0xXPHrEtJmW97Q5q9AZTwKC+l1Iy+0m8vQIc/DY=; b=Pr7dVjQIX3ovG78v1X45seFwA/+uyIAofJbxn5iXTRBA5Mv+YVdiI9QMm/gU1ljoSGqqC+hvLS4iB2N1kC4fGuDxXOyNaApOLSA2hl/mBpzca6SNyu6CYvUDdhmfD+8TsYMe6Vy8UY9lWpPYNgfb9BhORqPvxiC8A8F9ScTVT/s=\r\nHeader: Value\r\n\r\nHello World" REGISTRY = { - "netius.hive.pt" : dict( - key_b64 = PRIVATE_KEY, - selector = "20160523113052", - domain = "netius.hive.pt" + "netius.hive.pt": dict( + key_b64=PRIVATE_KEY, selector="20160523113052", domain="netius.hive.pt" ) } + class RelaySMTPServerTest(unittest.TestCase): def setUp(self): @@ -83,9 +73,7 @@ def tearDown(self): def test_dkim(self): self.server.dkim = REGISTRY result = self.server.dkim_contents( - MESSAGE, - email = "email@netius.hive.pt", - creation = 1464003802 + MESSAGE, email="email@netius.hive.pt", creation=1464003802 ) self.assertEqual(result, RESULT) diff --git a/src/netius/test/middleware/__init__.py b/src/netius/test/middleware/__init__.py index 52c2dd14c..81bce846f 100644 --- a/src/netius/test/middleware/__init__.py +++ b/src/netius/test/middleware/__init__.py @@ -19,15 +19,6 @@ # You should have received a copy of the Apache License along with # Hive Netius System. If not, see . -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ diff --git a/src/netius/test/middleware/proxy.py b/src/netius/test/middleware/proxy.py index d414b3998..e641b2ef4 100644 --- a/src/netius/test/middleware/proxy.py +++ b/src/netius/test/middleware/proxy.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -43,11 +34,12 @@ import netius.common import netius.middleware + class ProxyMiddlewareTest(unittest.TestCase): def setUp(self): unittest.TestCase.setUp(self) - self.server = netius.Server(poll = netius.Poll) + self.server = netius.Server(poll=netius.Poll) self.server.poll.open() def tearDown(self): @@ -55,11 +47,9 @@ def tearDown(self): self.server.cleanup() def test_ipv4_v1(self): - instance = self.server.register_middleware( - netius.middleware.ProxyMiddleware - ) + instance = self.server.register_middleware(netius.middleware.ProxyMiddleware) - connection = netius.Connection(owner = self.server) + connection = netius.Connection(owner=self.server) connection.open() connection.restore(b"PROXY TCP4 192.168.1.1 192.168.1.2 32598 8080\r\n") @@ -69,25 +59,23 @@ def test_ipv4_v1(self): self.assertEqual(len(connection.restored), 0) def test_ipv6_v1(self): - instance = self.server.register_middleware( - netius.middleware.ProxyMiddleware - ) + instance = self.server.register_middleware(netius.middleware.ProxyMiddleware) - connection = netius.Connection(owner = self.server) + connection = netius.Connection(owner=self.server) connection.open() - connection.restore(b"PROXY TCP4 fe80::787f:f63f:3176:d61b fe80::787f:f63f:3176:d61c 32598 8080\r\n") + connection.restore( + b"PROXY TCP4 fe80::787f:f63f:3176:d61b fe80::787f:f63f:3176:d61c 32598 8080\r\n" + ) instance._proxy_handshake_v1(connection) self.assertEqual(connection.address, ("fe80::787f:f63f:3176:d61b", 32598)) self.assertEqual(len(connection.restored), 0) def test_starter_v1(self): - self.server.register_middleware( - netius.middleware.ProxyMiddleware - ) + self.server.register_middleware(netius.middleware.ProxyMiddleware) - connection = netius.Connection(owner = self.server) + connection = netius.Connection(owner=self.server) connection.open() connection.restore(b"PROXY TCP4 192.168.1.1 192.168.1.2 32598 8080\r\n") @@ -97,7 +85,7 @@ def test_starter_v1(self): self.assertEqual(connection.restored_s, 0) self.assertEqual(len(connection.restored), 0) - connection = netius.Connection(owner = self.server) + connection = netius.Connection(owner=self.server) connection.open() connection.restore(b"PROXY TCP4 192.168.1.3 ") @@ -108,7 +96,7 @@ def test_starter_v1(self): self.assertEqual(connection.restored_s, 0) self.assertEqual(len(connection.restored), 0) - connection = netius.Connection(owner = self.server) + connection = netius.Connection(owner=self.server) connection.open() connection.restore(b"PROXY TCP4 192.168.1.3 ") @@ -121,11 +109,9 @@ def test_starter_v1(self): self.assertEqual(len(connection.restored), 2) def test_starter_v2(self): - self.server.register_middleware( - netius.middleware.ProxyMiddleware, version = 2 - ) + self.server.register_middleware(netius.middleware.ProxyMiddleware, version=2) - connection = netius.Connection(owner = self.server) + connection = netius.Connection(owner=self.server) connection.open() body = struct.pack( @@ -133,15 +119,16 @@ def test_starter_v2(self): netius.common.ip4_to_addr("192.168.1.1"), netius.common.ip4_to_addr("192.168.1.2"), 32598, - 8080 + 8080, ) header = struct.pack( "!12sBBH", netius.middleware.ProxyMiddleware.HEADER_MAGIC_V2, (2 << 4) + (netius.middleware.ProxyMiddleware.TYPE_PROXY_V2), - (netius.middleware.ProxyMiddleware.AF_INET_v2 << 4) + (netius.middleware.ProxyMiddleware.PROTO_STREAM_v2), - len(body) + (netius.middleware.ProxyMiddleware.AF_INET_v2 << 4) + + (netius.middleware.ProxyMiddleware.PROTO_STREAM_v2), + len(body), ) connection.restore(header) diff --git a/src/netius/test/pool/__init__.py b/src/netius/test/pool/__init__.py index 52c2dd14c..81bce846f 100644 --- a/src/netius/test/pool/__init__.py +++ b/src/netius/test/pool/__init__.py @@ -19,15 +19,6 @@ # You should have received a copy of the Apache License along with # Hive Netius System. If not, see . -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ diff --git a/src/netius/test/pool/common.py b/src/netius/test/pool/common.py index 32900a246..f712646b7 100644 --- a/src/netius/test/pool/common.py +++ b/src/netius/test/pool/common.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -41,6 +32,7 @@ import netius.pool + class EventPoolTest(unittest.TestCase): def test_event(self): diff --git a/src/netius/test/servers/__init__.py b/src/netius/test/servers/__init__.py index 52c2dd14c..81bce846f 100644 --- a/src/netius/test/servers/__init__.py +++ b/src/netius/test/servers/__init__.py @@ -19,15 +19,6 @@ # You should have received a copy of the Apache License along with # Hive Netius System. If not, see . -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ diff --git a/src/netius/test/servers/http.py b/src/netius/test/servers/http.py index cc67db94a..0ff40b7cb 100644 --- a/src/netius/test/servers/http.py +++ b/src/netius/test/servers/http.py @@ -22,15 +22,6 @@ __author__ = "João Magalhães " """ The author(s) of the module """ -__version__ = "1.0.0" -""" The version of the module """ - -__revision__ = "$LastChangedRevision$" -""" The revision number of the module """ - -__date__ = "$LastChangedDate$" -""" The last change date of the module """ - __copyright__ = "Copyright (c) 2008-2020 Hive Solutions Lda." """ The copyright for the module """ @@ -41,52 +32,41 @@ import netius.servers + class HTTPServerTest(unittest.TestCase): def test__headers_upper(self): http_server = netius.servers.HTTPServer() - headers = { - "content-type" : "plain/text", - "content-length" : "12" - } + headers = {"content-type": "plain/text", "content-length": "12"} http_server._headers_upper(headers) - self.assertEqual(headers, { - "Content-Type" : "plain/text", - "Content-Length" : "12" - }) + self.assertEqual( + headers, {"Content-Type": "plain/text", "Content-Length": "12"} + ) - headers = { - "content-Type" : "plain/text", - "content-LEngtH" : "12" - } + headers = {"content-Type": "plain/text", "content-LEngtH": "12"} http_server._headers_upper(headers) - self.assertEqual(headers, { - "Content-Type" : "plain/text", - "Content-Length" : "12" - }) + self.assertEqual( + headers, {"Content-Type": "plain/text", "Content-Length": "12"} + ) def test__headers_normalize(self): http_server = netius.servers.HTTPServer() - headers = { - "Content-Type" : ["plain/text"], - "Content-Length" : ["12"] - } + headers = {"Content-Type": ["plain/text"], "Content-Length": ["12"]} http_server._headers_normalize(headers) - self.assertEqual(headers, { - "Content-Type" : "plain/text", - "Content-Length" : "12" - }) + self.assertEqual( + headers, {"Content-Type": "plain/text", "Content-Length": "12"} + ) headers = { - "Content-Type" : ["application/json", "charset=utf-8"], - "Content-Length" : "12" + "Content-Type": ["application/json", "charset=utf-8"], + "Content-Length": "12", } http_server._headers_normalize(headers) - self.assertEqual(headers, { - "Content-Type" : "application/json;charset=utf-8", - "Content-Length" : "12" - }) + self.assertEqual( + headers, + {"Content-Type": "application/json;charset=utf-8", "Content-Length": "12"}, + )