Skip to content

Commit

Permalink
NDEV-3036: Create a new service for the Private RPC API
Browse files Browse the repository at this point in the history
  • Loading branch information
alfiedotwtf committed Jun 14, 2024
1 parent fbd3667 commit b2d28c8
Show file tree
Hide file tree
Showing 26 changed files with 494 additions and 164 deletions.
12 changes: 12 additions & 0 deletions common/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ class Config:
# Statistic configuration
gather_stat_name: Final[str] = "GATHER_STATISTICS"
# Proxy configuration
rpc_private_ip_name: Final[str] = "RPC_PRIVATE_IP"
rpc_private_port_name: Final[str] = "RPC_PRIVATE_PORT"
rpc_public_port_name: Final[str] = "RPC_PUBLIC_PORT"
rpc_process_cnt_name: Final[str] = "RPC_PROCESS_COUNT"
rpc_worker_cnt_name: Final[str] = "RPC_WORKER_COUNT"
Expand Down Expand Up @@ -431,6 +433,14 @@ def gather_stat(self) -> bool:
#########################
# Proxy configuration

@cached_property
def rpc_private_ip(self) -> str:
return os.environ.get(self.rpc_private_ip_name, self.base_service_ip)

@cached_property
def rpc_private_port(self) -> int:
return self._env_num(self.rpc_private_port_name, self.rpc_public_port + 1, 8000, 25000)

@cached_property
def rpc_public_port(self) -> int:
return self._env_num(self.rpc_public_port_name, 9090, 8000, 25000)
Expand Down Expand Up @@ -827,6 +837,8 @@ def to_string(self) -> str:
self.gather_stat_name: self.gather_stat,
self.debug_cmd_line_name: self.debug_cmd_line,
# Proxy configuration
self.rpc_private_ip_name: self.rpc_private_ip,
self.rpc_private_port_name: self.rpc_private_port,
self.rpc_public_port_name: self.rpc_public_port,
self.rpc_process_cnt_name: self.rpc_process_cnt,
self.rpc_worker_cnt_name: self.rpc_worker_cnt,
Expand Down
12 changes: 12 additions & 0 deletions common/http/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import hashlib
import inspect
import re
import time
Expand Down Expand Up @@ -66,6 +67,17 @@ def from_raw(

return self

@cached_property
def ctx_id(self) -> str:
if ctx_id := getattr(self, "_ctx_id", None):
return ctx_id

size = len(self.request.body)
raw_value = f"{self.ip_addr}:{size}:{self.start_time_nsec}"
ctx_id = hashlib.md5(bytes(raw_value, "utf-8")).hexdigest()[:8]
self.set_property_value("_ctx_id", ctx_id)
return ctx_id

@cached_property
def body(self) -> str:
value = self.request.body
Expand Down
7 changes: 7 additions & 0 deletions common/neon/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import random
from typing import Final, Annotated, Union

import eth_account
import eth_keys
import eth_utils
from pydantic.functional_serializers import PlainSerializer
Expand Down Expand Up @@ -145,6 +146,12 @@ def private_key(self) -> eth_keys.keys.PrivateKey:
assert self._private_key
return self._private_key

def sign_msg(self, data: bytes) -> eth_keys.datatypes.Signature:
return self.private_key.sign_msg(data)

def sign_transaction(self, tx: dict[str, str]) -> eth_account.datastructures.SignedTransaction:
return eth_account.Account.sign_transaction(tx, self.private_key)

def __str__(self) -> str:
return self.to_string()

Expand Down
4 changes: 2 additions & 2 deletions proxy/base/mp_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,5 +221,5 @@ class MpGetTxResp(BaseModel):


class MpTxPoolContentResp(BaseModel):
pending_list: tuple[NeonTxModel, ...]
queued_list: tuple[NeonTxModel, ...]
pending_list: list[NeonTxModel]
queued_list: list[NeonTxModel]
20 changes: 20 additions & 0 deletions proxy/base/op_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing_extensions import Self, ClassVar

from common.ethereum.hash import EthAddressField, EthAddress
from common.neon.transaction_model import NeonTxModel
from common.solana.pubkey import SolPubKey, SolPubKeyField
from common.solana.transaction_model import SolTxModel
from common.utils.cached import cached_method
Expand Down Expand Up @@ -79,6 +80,25 @@ class OpTokenSolAddressModel(BaseModel):
token_sol_address: SolPubKeyField


class OpSignEthMessageRequest(BaseModel):
ctx_id: str
eth_address: EthAddressField
data: str


class OpSignEthMessageResp(BaseModel):
signed_message: str


class OpSignEthTxRequest(BaseModel):
ctx_id: str
tx: NeonTxModel


class OpSignEthTxResp(BaseModel):
signed_tx: NeonTxModel


class OpSignSolTxListRequest(BaseModel):
req_id: dict
owner: SolPubKeyField
Expand Down
19 changes: 19 additions & 0 deletions proxy/base/op_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Sequence

from common.app_data.client import AppDataClient
from common.neon.transaction_model import NeonTxModel
from common.solana.pubkey import SolPubKey
from common.solana.transaction import SolTx
from common.solana.transaction_model import SolTxModel
Expand All @@ -14,6 +15,10 @@
OpResourceResp,
OpTokenSolAddressModel,
OpGetTokenSolAddressRequest,
OpSignEthMessageRequest,
OpSignEthMessageResp,
OpSignEthTxRequest,
OpSignEthTxResp,
OpSignSolTxListRequest,
OpSolTxListResp,
OpGetSignerKeyListRequest,
Expand Down Expand Up @@ -48,6 +53,14 @@ async def get_token_sol_address(self, req_id: dict, owner: SolPubKey, chain_id:
resp = await self._get_token_sol_address(req)
return resp.token_sol_address

async def sign_eth_message(self, ctx_id: str, eth_address: str, data: str) -> OpSignEthMessageResp:
req = OpSignEthMessageRequest(ctx_id=ctx_id, eth_address=eth_address, data=data)
return await self._sign_eth_message(req)

async def sign_eth_transaction(self, ctx_id: str, neon_tx: NeonTxModel) -> OpSignEthTxResp:
req = OpSignEthTxRequest(ctx_id=ctx_id, tx=neon_tx)
return await self._sign_eth_transaction(req)

async def sign_sol_tx_list(self, req_id: dict, owner: SolPubKey, tx_list: Sequence[SolTx]) -> tuple[SolTx, ...]:
model_list = [SolTxModel.from_raw(tx) for tx in tx_list]
req = OpSignSolTxListRequest(req_id=req_id, owner=owner, tx_list=model_list)
Expand Down Expand Up @@ -78,6 +91,12 @@ async def _free_resource(self, request: OpFreeResourceRequest) -> OpResourceResp
@AppDataClient.method(name="getOperatorTokenAddress")
async def _get_token_sol_address(self, request: OpGetTokenSolAddressRequest) -> OpTokenSolAddressModel: ...

@AppDataClient.method(name="signEthMessage")
async def _sign_eth_message(self, request: OpSignEthMessageRequest) -> OpSignEthMessageResp: ...

@AppDataClient.method(name="signEthTransaction")
async def _sign_eth_transaction(self, request: OpSignEthTxRequest) -> OpSignEthTxResp: ...

@AppDataClient.method(name="signSolanaTransactionList")
async def _sign_sol_tx_list(self, request: OpSignSolTxListRequest) -> OpSolTxListResp: ...

Expand Down
126 changes: 126 additions & 0 deletions proxy/base/rpc_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import abc
import logging
from typing_extensions import Self

from common.config.config import Config
from common.http.errors import HttpRouteError
from common.http.utils import HttpRequestCtx
from common.jsonrpc.api import JsonRpcListRequest, JsonRpcListResp, JsonRpcRequest, JsonRpcResp
from common.jsonrpc.server import JsonRpcApi, JsonRpcServer
from common.neon.neon_program import NeonProg
from common.neon_rpc.api import EvmConfigModel
from common.stat.api import RpcCallData
from common.utils.cached import ttl_cached_method
from common.utils.json_logger import logging_context, log_msg
from typing import Callable
from ..base.mp_client import MempoolClient
from ..stat.client import StatClient

_LOG = logging.getLogger(__name__)


class RpcServer(JsonRpcServer):
def __init__(self, cfg: Config, mp_client: MempoolClient, stat_client: StatClient) -> None:
super().__init__(cfg)
self._mp_client = mp_client
self._stat_client = stat_client

@abc.abstractmethod
def endpoints(_cls) -> list[str]: ...

def _add_api(self, api: JsonRpcApi) -> Self:
_LOG.info(log_msg(f"Adding API {api.name}"))

for endpoint in self.endpoints():
_LOG.info(log_msg(f"Adding API {api.name} to endpoint {endpoint}"))
super().add_api(api, endpoint=endpoint)
return self

async def on_request_list(self, ctx: HttpRequestCtx, request: JsonRpcListRequest) -> None:
await self._validate_chain_id(ctx)
with logging_context(ctx=ctx.ctx_id):
_LOG.info(log_msg("handle BIG request <<< {IP} size={Size}", IP=ctx.ip_addr, Size=len(request.root)))

async def _validate_chain_id(self, ctx: HttpRequestCtx) -> None:
NeonProg.validate_protocol()

if not getattr(ctx, "chain_id", None):
await self._set_chain_id(ctx)

async def _set_chain_id(self, ctx: HttpRequestCtx) -> int:
evm_cfg = await self.get_evm_cfg()
if not (token_name := ctx.request.path_params.get("token", "").strip().upper()):
chain_id = evm_cfg.default_chain_id
ctx.set_property_value("is_default_chain_id", True)
elif token := evm_cfg.token_dict.get(token_name, None):
chain_id = token.chain_id
ctx.set_property_value("is_default_chain_id", token.is_default)
else:
raise HttpRouteError()

ctx.set_property_value("chain_id", chain_id)
return chain_id

@ttl_cached_method(ttl_sec=1)
async def get_evm_cfg(self) -> EvmConfigModel:
# forwarding request to mempool allows to limit the number of requests to Solana to maximum 1 time per second
# for details, see the mempool_server::get_evm_cfg() implementation
evm_cfg = await self._mp_client.get_evm_cfg()
NeonProg.init_prog(evm_cfg.treasury_pool_cnt, evm_cfg.treasury_pool_seed, evm_cfg.version)
return evm_cfg

def on_response_list(self, ctx: HttpRequestCtx, resp: JsonRpcListResp) -> None:
with logging_context(ctx=ctx.ctx_id, class_name=self.__class__.__name__):
msg = log_msg(
"done BIG request >>> {IP} size={Size} resp_time={TimeMS} msec",
IP=ctx.ip_addr,
Size=len(resp),
TimeMS=ctx.process_time_msec,
)
_LOG.info(msg)

stat = RpcCallData(service=self.stat_name, method="BIG", time_nsec=ctx.process_time_nsec, is_error=False)
self._stat_client.commit_rpc_call(stat)

def on_bad_request(self, ctx: HttpRequestCtx) -> None:
_LOG.warning(log_msg("BAD request from {IP} with size {Size}", IP=ctx.ip_addr, Size=len(ctx.request.body)))

stat = RpcCallData(service=self.stat_name, method="UNKNOWN", time_nsec=ctx.process_time_nsec, is_error=True)
self._stat_client.commit_rpc_call(stat)

async def handle_request(
self,
ctx: HttpRequestCtx,
request: JsonRpcRequest,
handler: Callable,
) -> JsonRpcResp:
await self._validate_chain_id(ctx)

info = dict(IP=ctx.ip_addr, ReqID=request.id, Method=request.method)
with logging_context(ctx=ctx.ctx_id, class_name=self.__class__.__name__):
_LOG.info(log_msg("handle request <<< {IP} req={ReqID} {Method} {Params}", Params=request.params, **info))

resp = await handler(ctx, request)
if resp.is_error:
msg = log_msg(
"error on request >>> {IP} req={ReqID} {Method} {Error} resp_time={TimeMS} msec",
Error=resp.error,
**info,
)
else:
msg = log_msg(
"done request >>> {IP} req={ReqID} {Method} {Result} resp_time={TimeMS} msec",
Result=resp.result,
**info,
)
_LOG.info(dict(**msg, TimeMS=ctx.process_time_msec))

stat = RpcCallData(
service=self.stat_name,
method=request.method,
time_nsec=ctx.process_time_nsec,
is_error=resp.is_error,
)
self._stat_client.commit_rpc_call(stat)

return resp
2 changes: 1 addition & 1 deletion proxy/mempool/transaction_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def get_content(self) -> MpTxPoolContentResp:
cont = tx_schedule.get_content()
pending_list.extend(cont.pending_list)
queued_list.extend(cont.queued_list)
return MpTxPoolContentResp(pending_list=tuple(pending_list), queued_list=tuple(queued_list))
return MpTxPoolContentResp(pending_list=pending_list, queued_list=queued_list)

async def _update_tx_order(self, tx: MpTxModel) -> MpTxResp | None:
if not tx.neon_tx.has_chain_id:
Expand Down
2 changes: 1 addition & 1 deletion proxy/mempool/transaction_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ def get_content(self) -> MpTxPoolContentResp:
pending_list.extend(tx_list[:pending_stop_pos])
queued_list.extend(tx_list[pending_stop_pos:])

return MpTxPoolContentResp(pending_list=tuple(pending_list), queued_list=tuple(queued_list))
return MpTxPoolContentResp(pending_list=pending_list, queued_list=queued_list)

# protected:

Expand Down
18 changes: 17 additions & 1 deletion proxy/neon_proxy_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .executor.server import ExecutorServer
from .mempool.server import MempoolServer
from .operator_resource.server import OpResourceServer
from .private_rpc.server import PrivateRpcServer
from .rpc.server import NeonProxy
from .stat.client import StatClient
from .stat.server import StatServer
Expand All @@ -31,8 +32,8 @@ def __init__(self):
cfg = Config()
_LOG.info("running NeonProxy %s with the cfg: %s", NEON_PROXY_VER, cfg.to_string())

self._enable_private_rpc_server = cfg.enable_private_api
self._recv_sig_num = signal.SIG_DFL

self._msg_filter = LogMsgFilter(cfg)

# Init Solana client
Expand Down Expand Up @@ -85,6 +86,15 @@ def __init__(self):
# Init Prometheus stat
self._stat_server = StatServer(cfg=cfg)

# Init private RPC API
if self._enable_private_rpc_server:
self._private_rpc_server = PrivateRpcServer(
cfg=cfg,
mp_client=mp_client,
op_client=op_client,
stat_client=stat_client,
)

# Init external RPC API
self._proxy_server = NeonProxy(
cfg=cfg,
Expand All @@ -105,10 +115,16 @@ def start(self) -> int:
self._stat_server.start()
self._proxy_server.start()

if self._enable_private_rpc_server:
self._private_rpc_server.start()

self._register_term_signal_handler()
while self._recv_sig_num == signal.SIG_DFL:
time.sleep(1)

if self._enable_private_rpc_server:
self._private_rpc_server.stop()

self._proxy_server.stop()
self._stat_server.stop()
self._mp_server.stop()
Expand Down
Loading

0 comments on commit b2d28c8

Please sign in to comment.