Skip to content

Commit

Permalink
dev: VRF Listener Refactoring (#199)
Browse files Browse the repository at this point in the history
* dev(refacto_vrf_listener):

* dev(refacto_vrf_listener):
  • Loading branch information
akhercha authored Aug 29, 2024
1 parent 0c9d92e commit 84c989d
Show file tree
Hide file tree
Showing 5 changed files with 226 additions and 108 deletions.
10 changes: 9 additions & 1 deletion pragma-sdk/pragma_sdk/onchain/types/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,16 @@ class RandomnessRequest:
num_words: int
calldata: List[int]

def __hash__(self):
    # Identity of a request is (request_id, caller_address): two events
    # with the same pair are treated as the same request when deduping.
    return hash((self.request_id, self.caller_address))

def __repr__(self) -> str:
    """Return a compact debug representation of the request.

    Only the identifying fields (caller, id, minimum block) are shown;
    seed/calldata are omitted to keep log lines short.
    """
    # The diff extraction had left both the old (unclosed-paren) and the new
    # version of the last fragment; only the corrected line is kept.
    return (
        f"Request(caller_address={self.caller_address},request_id={self.request_id},"
        f"minimum_block_number={self.minimum_block_number})"
    )
97 changes: 97 additions & 0 deletions vrf-listener/vrf_listener/indexer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import logging

from typing import Optional

from grpc import ssl_channel_credentials
from grpc.aio import secure_channel
from apibara.protocol import StreamAddress, StreamService, credentials_with_auth_token
from apibara.protocol.proto.stream_pb2 import DataFinality
from apibara.starknet import Block, EventFilter, Filter, felt, starknet_cursor

from pragma_sdk.onchain.types import RandomnessRequest
from pragma_sdk.onchain.constants import RANDOMNESS_REQUEST_EVENT_SELECTOR
from pragma_sdk.onchain.client import PragmaOnChainClient

from vrf_listener.safe_queue import ThreadSafeQueue

EVENT_INDEX_TO_SPLIT = 7

logger = logging.getLogger(__name__)


class Indexer:
    """Indexes Starknet VRF request events through an Apibara GRPC stream.

    Matching events are decoded into ``RandomnessRequest`` objects and pushed
    onto a shared ``requests_queue`` consumed by the listener.
    """

    # Apibara data stream yielding blocks that match the VRF event filter.
    stream: StreamService
    # Shared queue the listener drains.
    requests_queue: ThreadSafeQueue

    def __init__(
        self,
        stream: StreamService,
        requests_queue: ThreadSafeQueue,
    ) -> None:
        self.stream = stream
        self.requests_queue = requests_queue

    @classmethod
    async def from_client(
        cls,
        pragma_client: PragmaOnChainClient,
        apibara_api_key: Optional[str],
        requests_queue: ThreadSafeQueue,
        from_block: Optional[int] = None,
    ):
        """
        Creates an Indexer from a PragmaOnChainClient.

        Args:
            pragma_client: client used for the network name, the randomness
                contract address and the current block number.
            apibara_api_key: Apibara DNA API key; required.
            requests_queue: queue the indexer fills with decoded requests.
            from_block: optional starting block; defaults to the chain head.

        Raises:
            ValueError: if no Apibara API key was provided.
        """
        if apibara_api_key is None:
            raise ValueError("--apibara-api-key not provided.")

        # Pick the GRPC endpoint matching the client's network.
        channel = secure_channel(
            StreamAddress.StarkNet.Sepolia
            if pragma_client.network == "sepolia"
            else StreamAddress.StarkNet.Mainnet,
            credentials_with_auth_token(apibara_api_key, ssl_channel_credentials()),
        )
        # Only match RandomnessRequest events emitted by the VRF contract;
        # receipts/transactions are excluded to keep stream payloads small.
        # (Renamed from `filter` so the builtin is not shadowed.)
        events_filter = (
            Filter()
            .with_header(weak=False)
            .add_event(
                EventFilter()
                .with_from_address(felt.from_int(pragma_client.randomness.address))
                .with_keys([felt.from_hex(RANDOMNESS_REQUEST_EVENT_SELECTOR)])
                .with_include_receipt(False)
                .with_include_transaction(False)
            )
            .encode()
        )
        # Default to the current chain head when no starting block is given.
        if from_block:
            current_block = from_block
        else:
            current_block = await pragma_client.get_block_number()

        stream = StreamService(channel).stream_data_immutable(
            filter=events_filter,
            finality=DataFinality.DATA_STATUS_PENDING,
            batch_size=1,
            cursor=starknet_cursor(current_block),
        )
        return cls(stream=stream, requests_queue=requests_queue)

    async def run_forever(self) -> None:
        """
        Index forever using Apibara and fill the requests_queue when encountering a
        VRF request.
        """
        logger.info("👩‍💻 Indexing VRF requests using apibara...")
        block = Block()
        async for message in self.stream:
            if message.data is None:
                continue
            for batch in message.data.data:
                block.ParseFromString(batch)
                if len(block.events) == 0:
                    continue
                for event in block.events:
                    data = list(map(felt.to_int, event.event.data))
                    # Drop the felt at EVENT_INDEX_TO_SPLIT (presumably the
                    # calldata length) and wrap the tail into one calldata
                    # list matching RandomnessRequest's last field.
                    # NOTE(review): assumes the event data layout matches the
                    # RandomnessRequest field order — confirm with the ABI.
                    data = data[:EVENT_INDEX_TO_SPLIT] + [data[EVENT_INDEX_TO_SPLIT + 1 :]]
                    await self.requests_queue.put(RandomnessRequest(*data))
75 changes: 75 additions & 0 deletions vrf-listener/vrf_listener/listener.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import asyncio
import logging

from typing import Optional, List, Set


from pragma_sdk.onchain.types import RandomnessRequest
from pragma_sdk.onchain.client import PragmaOnChainClient

from vrf_listener.indexer import Indexer
from vrf_listener.safe_queue import ThreadSafeQueue

logger = logging.getLogger(__name__)


class Listener:
    """Periodically handles pending VRF requests through the Pragma client.

    When an ``Indexer`` is attached, requests are taken from the shared
    queue; otherwise ``handle_random`` is called without explicit events.
    """

    pragma_client: PragmaOnChainClient
    # Private key as an int (hex strings are converted in __init__).
    private_key: int
    requests_queue: ThreadSafeQueue
    # Seconds to sleep between two handling rounds.
    check_requests_interval: int
    # Requests already forwarded this round (cleared after each round).
    processed_requests: Set[RandomnessRequest]
    indexer: Optional[Indexer] = None

    def __init__(
        self,
        pragma_client: PragmaOnChainClient,
        private_key: int | str,
        requests_queue: ThreadSafeQueue,
        check_requests_interval: int,
        ignore_request_threshold: int,
        indexer: Optional[Indexer] = None,
    ) -> None:
        self.pragma_client = pragma_client
        # Accept the key either as an int or as a hex string.
        if isinstance(private_key, str):
            private_key = int(private_key, 16)
        self.private_key = private_key
        self.requests_queue = requests_queue
        self.check_requests_interval = check_requests_interval
        self.ignore_request_threshold = ignore_request_threshold
        self.processed_requests = set()
        self.indexer = indexer

    async def run_forever(self) -> None:
        """
        Handle VRF requests forever.
        """
        logger.info("👂 Listening for VRF requests!")
        while True:
            if self.indexer is not None:
                events = await self._consume_requests_queue()
            try:
                await self.pragma_client.handle_random(
                    private_key=self.private_key,
                    ignore_request_threshold=self.ignore_request_threshold,
                    requests_events=events if self.indexer is not None else None,
                )
            except Exception as e:
                # Best-effort: log and keep listening rather than crash.
                logger.error(f"⛔ Error while handling randomness request: {e}")
            self.processed_requests.clear()
            await asyncio.sleep(self.check_requests_interval)

    async def _consume_requests_queue(self) -> List[RandomnessRequest]:
        """
        Consumes the whole requests_queue and return the requests.

        Duplicates already seen this round are skipped via
        ``processed_requests``.
        """
        events: List[RandomnessRequest] = []
        # ThreadSafeQueue.get() awaits instead of raising QueueEmpty (only
        # get_nowait() raises it), so the original try/except QueueEmpty was
        # dead code — emptiness is checked up front instead.
        while not self.requests_queue.empty():
            request: RandomnessRequest = await self.requests_queue.get()
            if request not in self.processed_requests:
                events.append(request)
                self.processed_requests.add(request)
        return events
134 changes: 27 additions & 107 deletions vrf-listener/vrf_listener/main.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
import asyncio
import click
import os
import logging

from pydantic import HttpUrl
from typing import Optional, Literal, List

from grpc import ssl_channel_credentials
from grpc.aio import secure_channel
from apibara.protocol import StreamAddress, StreamService, credentials_with_auth_token
from apibara.protocol.proto.stream_pb2 import DataFinality
from apibara.starknet import Block, EventFilter, Filter, felt, starknet_cursor
from typing import Optional, Literal

from pragma_utils.logger import setup_logging
from pragma_utils.cli import load_private_key_from_cli_arg
from pragma_sdk.onchain.types import ContractAddresses, RandomnessRequest
from pragma_sdk.onchain.constants import RANDOMNESS_REQUEST_EVENT_SELECTOR
from pragma_sdk.onchain.types import ContractAddresses
from pragma_sdk.onchain.client import PragmaOnChainClient

from vrf_listener.safe_queue import ThreadSafeQueue
from vrf_listener.indexer import Indexer
from vrf_listener.listener import Listener

logger = logging.getLogger(__name__)

EVENT_INDEX_TO_SPLIT = 7

# Related to Apibara GRPC - need to disable fork support because of async.
os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "0"


async def main(
Expand All @@ -34,6 +34,7 @@ async def main(
apibara_api_key: Optional[str] = None,
) -> None:
logger.info("🧩 Starting VRF listener...")

client = _create_pragma_client(
network=network,
vrf_address=vrf_address,
Expand All @@ -42,31 +43,26 @@ async def main(
rpc_url=rpc_url,
)

requests_queue = ThreadSafeQueue()

if index_with_apibara:
logger.info("👩‍💻 Indexing VRF requests using apibara...")
stream = await _create_apibara_stream(
client=client,
network=network,
indexer = await Indexer.from_client(
pragma_client=client,
apibara_api_key=apibara_api_key,
vrf_address=vrf_address,
requests_queue=requests_queue,
)
requests_queue: asyncio.Queue = asyncio.Queue()
asyncio.create_task(_index_using_apibara(stream, requests_queue))
asyncio.create_task(indexer.run_forever())

listener = Listener(
pragma_client=client,
private_key=private_key,
requests_queue=requests_queue,
check_requests_interval=check_requests_interval,
ignore_request_threshold=ignore_request_threshold,
indexer=indexer if index_with_apibara else None,
)

logger.info("👂 Listening for VRF requests!")
while True:
if index_with_apibara:
events = await _consume_full_queue(requests_queue)
try:
await client.handle_random(
private_key=int(private_key, 16),
ignore_request_threshold=ignore_request_threshold,
requests_events=events if index_with_apibara else None,
)
except Exception as e:
logger.error(f"⛔ Error while handling randomness request: {e}")
raise e
await asyncio.sleep(check_requests_interval)
await listener.run_forever()


def _create_pragma_client(
Expand Down Expand Up @@ -94,82 +90,6 @@ def _create_pragma_client(
return client


async def _create_apibara_stream(
    client: PragmaOnChainClient,
    network: Literal["mainnet", "sepolia"],
    apibara_api_key: Optional[str],
    vrf_address: str,
) -> StreamService:
    """
    Creates an apibara stream filter that will index VRF requests events.

    Raises ValueError if no Apibara API key was provided.
    """
    if apibara_api_key is None:
        raise ValueError("--apibara-api-key not provided.")

    # Pick the GRPC endpoint matching the target network.
    channel = secure_channel(
        StreamAddress.StarkNet.Sepolia if network == "sepolia" else StreamAddress.StarkNet.Mainnet,
        credentials_with_auth_token(apibara_api_key, ssl_channel_credentials()),
    )
    # Only match RandomnessRequest events emitted by the VRF contract;
    # receipts and transactions are excluded to keep stream payloads small.
    filter = (
        Filter()
        .with_header(weak=False)
        .add_event(
            EventFilter()
            .with_from_address(felt.from_hex(vrf_address))
            .with_keys([felt.from_hex(RANDOMNESS_REQUEST_EVENT_SELECTOR)])
            .with_include_receipt(False)
            .with_include_transaction(False)
        )
        .encode()
    )
    # Start streaming from the current chain head.
    current_block = await client.get_block_number()
    return StreamService(channel).stream_data_immutable(
        filter=filter,
        finality=DataFinality.DATA_STATUS_PENDING,
        batch_size=1,
        cursor=starknet_cursor(current_block),
    )


async def _index_using_apibara(
    stream: StreamService,
    requests_queue: asyncio.Queue,
) -> None:
    """
    Consumes the apibara stream until empty and extract found vrf requests.

    Each matching event's data felts are converted to ints, packed into a
    RandomnessRequest and pushed onto *requests_queue*.
    """
    block = Block()
    async for message in stream:
        # Guard clause — the redundant `else:` after `continue` is removed.
        if message.data is None:
            continue
        for batch in message.data.data:
            block.ParseFromString(batch)
            if len(block.events) == 0:
                continue
            for event in block.events:
                from_data = list(map(felt.to_int, event.event.data))
                # Drop the felt at EVENT_INDEX_TO_SPLIT (presumably the
                # calldata length) and wrap the tail into a single calldata
                # list — confirm against the event ABI.
                from_data = from_data[:EVENT_INDEX_TO_SPLIT] + [
                    from_data[EVENT_INDEX_TO_SPLIT + 1 :]
                ]
                await requests_queue.put(RandomnessRequest(*from_data))


async def _consume_full_queue(requests_queue: asyncio.Queue) -> List[RandomnessRequest]:
    """
    Consume the whole requests_queue and return the requests.
    """
    drained: List[RandomnessRequest] = []
    # EAFP: pop without waiting until the queue signals it is empty.
    while True:
        try:
            drained.append(requests_queue.get_nowait())
        except asyncio.QueueEmpty:
            return drained


@click.command()
@click.option(
"--log-level",
Expand Down
18 changes: 18 additions & 0 deletions vrf-listener/vrf_listener/safe_queue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import asyncio


class ThreadSafeQueue:
    """
    Async FIFO queue shared between the indexer (producer) and the
    listener (consumer).

    ``asyncio.Queue`` is already safe for concurrent use within a single
    event loop, so the lock only serializes producers.

    NOTE(review): despite its name this is only safe within one event loop,
    not across OS threads — confirm all users share the same loop.
    """

    def __init__(self):
        self.queue = asyncio.Queue()
        self.lock = asyncio.Lock()

    async def put(self, item):
        """Enqueue *item*."""
        async with self.lock:
            await self.queue.put(item)

    async def get(self):
        """Dequeue and return the oldest item, waiting until one exists.

        Deliberately does NOT take the lock: the original version awaited
        ``queue.get()`` while holding ``self.lock``, so a ``get()`` on an
        empty queue kept the lock forever and ``put()`` could never acquire
        it to deliver an item — a guaranteed deadlock.
        """
        return await self.queue.get()

    def empty(self):
        """Return True when no items are queued."""
        return self.queue.empty()

0 comments on commit 84c989d

Please sign in to comment.