-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
dev: VRF Listener Refactoring (#199)
* dev(refacto_vrf_listener): * dev(refacto_vrf_listener):
- Loading branch information
Showing
5 changed files
with
226 additions
and
108 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
import logging | ||
|
||
from typing import Optional | ||
|
||
from grpc import ssl_channel_credentials | ||
from grpc.aio import secure_channel | ||
from apibara.protocol import StreamAddress, StreamService, credentials_with_auth_token | ||
from apibara.protocol.proto.stream_pb2 import DataFinality | ||
from apibara.starknet import Block, EventFilter, Filter, felt, starknet_cursor | ||
|
||
from pragma_sdk.onchain.types import RandomnessRequest | ||
from pragma_sdk.onchain.constants import RANDOMNESS_REQUEST_EVENT_SELECTOR | ||
from pragma_sdk.onchain.client import PragmaOnChainClient | ||
|
||
from vrf_listener.safe_queue import ThreadSafeQueue | ||
|
||
EVENT_INDEX_TO_SPLIT = 7 | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class Indexer: | ||
stream: StreamService | ||
requests_queue: ThreadSafeQueue | ||
|
||
def __init__( | ||
self, | ||
stream: StreamService, | ||
requests_queue: ThreadSafeQueue, | ||
) -> None: | ||
self.stream = stream | ||
self.requests_queue = requests_queue | ||
|
||
@classmethod | ||
async def from_client( | ||
cls, | ||
pragma_client: PragmaOnChainClient, | ||
apibara_api_key: Optional[str], | ||
requests_queue: ThreadSafeQueue, | ||
from_block: Optional[int] = None, | ||
): | ||
""" | ||
Creates an Indexer from a PragmaOnChainClient. | ||
""" | ||
if apibara_api_key is None: | ||
raise ValueError("--apibara-api-key not provided.") | ||
|
||
channel = secure_channel( | ||
StreamAddress.StarkNet.Sepolia | ||
if pragma_client.network == "sepolia" | ||
else StreamAddress.StarkNet.Mainnet, | ||
credentials_with_auth_token(apibara_api_key, ssl_channel_credentials()), | ||
) | ||
filter = ( | ||
Filter() | ||
.with_header(weak=False) | ||
.add_event( | ||
EventFilter() | ||
.with_from_address(felt.from_int(pragma_client.randomness.address)) | ||
.with_keys([felt.from_hex(RANDOMNESS_REQUEST_EVENT_SELECTOR)]) | ||
.with_include_receipt(False) | ||
.with_include_transaction(False) | ||
) | ||
.encode() | ||
) | ||
if from_block: | ||
current_block = from_block | ||
else: | ||
current_block = await pragma_client.get_block_number() | ||
|
||
stream = StreamService(channel).stream_data_immutable( | ||
filter=filter, | ||
finality=DataFinality.DATA_STATUS_PENDING, | ||
batch_size=1, | ||
cursor=starknet_cursor(current_block), | ||
) | ||
return cls(stream=stream, requests_queue=requests_queue) | ||
|
||
async def run_forever(self) -> None: | ||
""" | ||
Index forever using Apibara and fill the requests_queue when encountering a | ||
VRF request. | ||
""" | ||
logger.info("👩💻 Indexing VRF requests using apibara...") | ||
block = Block() | ||
async for message in self.stream: | ||
if message.data is None: | ||
continue | ||
for batch in message.data.data: | ||
block.ParseFromString(batch) | ||
if len(block.events) == 0: | ||
continue | ||
events = block.events | ||
for event in events: | ||
data = list(map(felt.to_int, event.event.data)) | ||
data = data[:EVENT_INDEX_TO_SPLIT] + [data[EVENT_INDEX_TO_SPLIT + 1 :]] | ||
await self.requests_queue.put(RandomnessRequest(*data)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
import asyncio | ||
import logging | ||
|
||
from typing import Optional, List, Set | ||
|
||
|
||
from pragma_sdk.onchain.types import RandomnessRequest | ||
from pragma_sdk.onchain.client import PragmaOnChainClient | ||
|
||
from vrf_listener.indexer import Indexer | ||
from vrf_listener.safe_queue import ThreadSafeQueue | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class Listener: | ||
pragma_client: PragmaOnChainClient | ||
private_key: int | ||
requests_queue: ThreadSafeQueue | ||
check_requests_interval: int | ||
processed_requests: Set[RandomnessRequest] | ||
indexer: Optional[Indexer] = None | ||
|
||
def __init__( | ||
self, | ||
pragma_client: PragmaOnChainClient, | ||
private_key: int | str, | ||
requests_queue: ThreadSafeQueue, | ||
check_requests_interval: int, | ||
ignore_request_threshold: int, | ||
indexer: Optional[Indexer] = None, | ||
) -> None: | ||
self.pragma_client = pragma_client | ||
if isinstance(private_key, str): | ||
private_key = int(private_key, 16) | ||
self.private_key = private_key | ||
self.requests_queue = requests_queue | ||
self.check_requests_interval = check_requests_interval | ||
self.ignore_request_threshold = ignore_request_threshold | ||
self.processed_requests = set() | ||
self.indexer = indexer | ||
|
||
async def run_forever(self) -> None: | ||
""" | ||
Handle VRF requests forever. | ||
""" | ||
logger.info("👂 Listening for VRF requests!") | ||
while True: | ||
if self.indexer is not None: | ||
events = await self._consume_requests_queue() | ||
try: | ||
await self.pragma_client.handle_random( | ||
private_key=self.private_key, | ||
ignore_request_threshold=self.ignore_request_threshold, | ||
requests_events=events if self.indexer is not None else None, | ||
) | ||
except Exception as e: | ||
logger.error(f"⛔ Error while handling randomness request: {e}") | ||
self.processed_requests.clear() | ||
await asyncio.sleep(self.check_requests_interval) | ||
|
||
async def _consume_requests_queue(self) -> List[RandomnessRequest]: | ||
""" | ||
Consumes the whole requests_queue and return the requests. | ||
""" | ||
events = [] | ||
while not self.requests_queue.empty(): | ||
try: | ||
request: RandomnessRequest = await self.requests_queue.get() | ||
if request not in self.processed_requests: | ||
events.append(request) | ||
self.processed_requests.add(request) | ||
except asyncio.QueueEmpty: | ||
break | ||
return events |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import asyncio | ||
|
||
|
||
class ThreadSafeQueue: | ||
def __init__(self): | ||
self.queue = asyncio.Queue() | ||
self.lock = asyncio.Lock() | ||
|
||
async def put(self, item): | ||
async with self.lock: | ||
await self.queue.put(item) | ||
|
||
async def get(self): | ||
async with self.lock: | ||
return await self.queue.get() | ||
|
||
def empty(self): | ||
return self.queue.empty() |