-
Notifications
You must be signed in to change notification settings - Fork 27
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
453 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
# Copyright 2024 Hathor Labs | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import annotations | ||
|
||
import time | ||
|
||
from twisted.internet.interfaces import IAddress | ||
from twisted.internet.protocol import ServerFactory | ||
from twisted.protocols.basic import LineReceiver | ||
|
||
from hathor.reactor import initialize_global_reactor | ||
|
||
|
||
class HathorProtocol: | ||
def __init__(self, manager: HathorManager) -> None: | ||
self._manager = manager | ||
|
||
def do_something(self, data: bytes) -> None: | ||
print('printing HathorManager data from HathorProtocol: ', self._manager.read_storage()) | ||
time.sleep(5) | ||
self._manager.save_storage(data) | ||
self.send_line(b'some line ' + data) | ||
|
||
def send_line(self, data: bytes) -> None: | ||
raise NotImplementedError | ||
|
||
|
||
class MyLineReceiver(LineReceiver, HathorProtocol): | ||
def lineReceived(self, data: bytes) -> None: | ||
self.do_something(data) | ||
|
||
def send_line(self, data: bytes) -> None: | ||
self.sendLine(data) | ||
|
||
|
||
class MyFactory(ServerFactory): | ||
def __init__(self, manager: HathorManager) -> None: | ||
self._manager = manager | ||
|
||
def buildProtocol(self, addr: IAddress) -> MyLineReceiver: | ||
return MyLineReceiver(self._manager) | ||
|
||
|
||
class HathorManager: | ||
def __init__(self, *, storage: bytes): | ||
self._storage = storage | ||
|
||
def read_storage(self) -> bytes: | ||
return self._storage | ||
|
||
def save_storage(self, data: bytes) -> None: | ||
print('printing from HathorManager.save_storage: ', data) | ||
|
||
|
||
def main() -> None: | ||
port = 8080 | ||
reactor = initialize_global_reactor() | ||
manager = HathorManager(storage=b'manager storage') | ||
factory = MyFactory(manager) | ||
reactor.listenTCP(port, factory) | ||
print(f'Server running on port {port}') | ||
reactor.run() | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,216 @@ | ||
# Copyright 2024 Hathor Labs | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import annotations | ||
|
||
import multiprocessing | ||
from abc import ABC, abstractmethod | ||
from multiprocessing import Pipe, Process | ||
from multiprocessing.connection import Connection | ||
from multiprocessing.sharedctypes import Synchronized | ||
from typing import Any, Callable, Coroutine, NamedTuple, TypeAlias, TypeVar, Union | ||
|
||
from twisted.internet.defer import Deferred | ||
from twisted.internet.task import LoopingCall | ||
from typing_extensions import Self | ||
|
||
from hathor.reactor import ReactorProtocol, initialize_global_reactor | ||
from hathor.transaction.util import bytes_to_int, int_to_bytes | ||
|
||
POLLING_INTERVAL: float = 0.001 | ||
MESSAGE_SEPARATOR: bytes = b' ' | ||
MAX_MESSAGE_ID: int = 2**64-1 | ||
|
||
ClientT = TypeVar('ClientT', bound='IpcClient') | ||
|
||
|
||
def connect( | ||
*, | ||
main_reactor: ReactorProtocol, | ||
main_client: IpcClient, | ||
main_server: IpcServer, | ||
subprocess_client_builder: Callable[[], ClientT], | ||
subprocess_server_builder: Callable[[ClientT], IpcServer], | ||
subprocess_name: str, | ||
) -> None: | ||
conn1: Connection | ||
conn2: Connection | ||
conn1, conn2 = Pipe() | ||
message_id = multiprocessing.Value('L', 0) | ||
|
||
main_ipc_conn = _IpcConnection( | ||
reactor=main_reactor, name='main', conn=conn1, message_id=message_id, server=main_server | ||
) | ||
main_client.set_ipc_conn(main_ipc_conn) | ||
main_ipc_conn.start_listening() | ||
|
||
subprocess = Process( | ||
name=subprocess_name, | ||
target=_run_subprocess, | ||
kwargs=dict( | ||
name=subprocess_name, | ||
conn=conn2, | ||
client_builder=subprocess_client_builder, | ||
server_builder=subprocess_server_builder, | ||
message_id=message_id, | ||
), | ||
) | ||
subprocess.start() | ||
|
||
|
||
def _run_subprocess( | ||
*, | ||
name: str, | ||
conn: Connection, | ||
client_builder: Callable[[], IpcClient], | ||
server_builder: Callable[[IpcClient], IpcServer], | ||
message_id: Synchronized, | ||
) -> None: | ||
subprocess_reactor = initialize_global_reactor() | ||
client = client_builder() | ||
server = server_builder(client) | ||
subprocess_ipc_conn = _IpcConnection( | ||
reactor=subprocess_reactor, name=name, conn=conn, server=server, message_id=message_id | ||
) | ||
client.set_ipc_conn(subprocess_ipc_conn) | ||
subprocess_ipc_conn.start_listening() | ||
subprocess_reactor.run() | ||
|
||
|
||
IpcCommand: TypeAlias = Union[ | ||
Callable[[bytes], Coroutine[Deferred[bytes], Any, bytes]], | ||
Callable[[bytes], Coroutine[Deferred[None], Any, None]], | ||
] | ||
|
||
|
||
class IpcServer(ABC): | ||
@abstractmethod | ||
def get_cmd_map(self) -> dict[bytes, IpcCommand]: | ||
raise NotImplementedError | ||
|
||
async def handle_request(self, request: bytes) -> bytes: | ||
cmd_name, _, data = request.partition(MESSAGE_SEPARATOR) | ||
cmd_map = self.get_cmd_map() | ||
cmd = cmd_map.get(cmd_name) | ||
assert cmd is not None, cmd_name | ||
result = await cmd(data) | ||
return result if result is not None else b'success' | ||
|
||
|
||
class IpcClient(ABC): | ||
__slots__ = ('_ipc_conn',) | ||
|
||
def __init__(self) -> None: | ||
self._ipc_conn: _IpcConnection | None = None | ||
|
||
def set_ipc_conn(self, ipc_conn: _IpcConnection) -> None: | ||
assert self._ipc_conn is None | ||
self._ipc_conn = ipc_conn | ||
|
||
def call(self, cmd: bytes, request: bytes | None = None) -> Deferred[bytes]: | ||
assert self._ipc_conn is not None | ||
return self._ipc_conn.call(cmd, request) | ||
|
||
|
||
class _Message(NamedTuple): | ||
id: int | ||
data: bytes | ||
|
||
def serialize(self) -> bytes: | ||
return int_to_bytes(self.id, size=8) + MESSAGE_SEPARATOR + self.data | ||
|
||
@classmethod | ||
def deserialize(cls, data: bytes) -> Self: | ||
id_, separator, data = data.partition(MESSAGE_SEPARATOR) | ||
assert separator == MESSAGE_SEPARATOR | ||
return cls( | ||
id=bytes_to_int(id_), | ||
data=data, | ||
) | ||
|
||
|
||
class _IpcConnection: | ||
__slots__ = ( | ||
'_name', | ||
'_conn', | ||
'_server', | ||
'_message_id', | ||
'_poll_lc', | ||
'_pending_calls', | ||
) | ||
|
||
def __init__( | ||
self, | ||
*, | ||
reactor: ReactorProtocol, | ||
name: str, | ||
conn: Connection, | ||
server: IpcServer, | ||
message_id: Synchronized, | ||
) -> None: | ||
self._name = name | ||
self._conn = conn | ||
self._server = server | ||
self._message_id = message_id | ||
self._poll_lc = LoopingCall(self._safe_poll) | ||
self._poll_lc.clock = reactor | ||
self._pending_calls: dict[int, Deferred[bytes]] = {} | ||
|
||
def start_listening(self) -> None: | ||
self._poll_lc.start(POLLING_INTERVAL, now=False) | ||
|
||
def call(self, cmd: bytes, request: bytes | None) -> Deferred[bytes]: | ||
data = cmd if request is None else cmd + MESSAGE_SEPARATOR + request | ||
message = self._send_message(data) | ||
deferred: Deferred[bytes] = Deferred() | ||
self._pending_calls[message.id] = deferred | ||
return deferred | ||
|
||
def _send_message(self, data: bytes, request_id: int | None = None) -> _Message: | ||
message_id = self._get_new_message_id() if request_id is None else request_id | ||
message = _Message(id=message_id, data=data) | ||
self._conn.send_bytes(message.serialize()) | ||
return message | ||
|
||
def _get_new_message_id(self) -> int: | ||
with self._message_id.get_lock(): | ||
message_id = self._message_id.value | ||
assert message_id < MAX_MESSAGE_ID | ||
self._message_id.value += 1 | ||
return message_id | ||
|
||
def _safe_poll(self) -> None: | ||
try: | ||
self._unsafe_poll() | ||
except Exception as e: | ||
print('error', e) | ||
|
||
def _unsafe_poll(self) -> None: | ||
if not self._conn.poll(): | ||
return | ||
|
||
message_bytes = self._conn.recv_bytes() | ||
message = _Message.deserialize(message_bytes) | ||
|
||
if pending_call := self._pending_calls.pop(message.id, None): | ||
# The received message is a response for one of our own requests | ||
# print(f'res({self._name}): {message_data}') | ||
pending_call.callback(message.data) | ||
return | ||
|
||
# The received message is a new request | ||
# print(f'req({self._name}): {message_data}') | ||
coro = self._server.handle_request(message.data) | ||
deferred = Deferred.fromCoroutine(coro) | ||
deferred.addCallback(lambda response: self._send_message(response, request_id=message.id)) |
Oops, something went wrong.