From db9a8ea5f36a94a6d0b44b17d75a8c4da1ce8690 Mon Sep 17 00:00:00 2001 From: Shubh Agarwal Date: Tue, 2 Jul 2024 22:46:41 -0400 Subject: [PATCH 01/15] logging and writing to db --- src/balanceof_imbalances.py | 9 +-- src/config.py | 32 ++++++++++ src/daemon.py | 113 +++++++++++++++++++++++++----------- src/helper_functions.py | 23 ++++++++ src/imbalances_script.py | 30 +++++----- 5 files changed, 152 insertions(+), 55 deletions(-) create mode 100644 src/helper_functions.py diff --git a/src/balanceof_imbalances.py b/src/balanceof_imbalances.py index f317f8d..d91ea06 100644 --- a/src/balanceof_imbalances.py +++ b/src/balanceof_imbalances.py @@ -1,10 +1,3 @@ -# mypy: disable-error-code="call-overload, arg-type, operator" -import sys -import os - -# for debugging purposes -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) - from web3 import Web3 from web3.types import TxReceipt from eth_typing import ChecksumAddress @@ -124,4 +117,4 @@ def main(): if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/src/config.py b/src/config.py index 3d2e71b..8ee40b3 100644 --- a/src/config.py +++ b/src/config.py @@ -1,10 +1,42 @@ import os +import psycopg2 +from sqlalchemy import create_engine from dotenv import load_dotenv +from src.helper_functions import get_logger load_dotenv() ETHEREUM_NODE_URL = os.getenv("ETHEREUM_NODE_URL") GNOSIS_NODE_URL = os.getenv("GNOSIS_NODE_URL") +ARBITRUM_NODE_URL = os.getenv("ARBITRUM_NODE_URL") +SOLVER_SLIPPAGE_DB_URL = os.getenv("SOLVER_SLIPPAGE_DB_URL") + CHAIN_RPC_ENDPOINTS = {"Ethereum": ETHEREUM_NODE_URL, "Gnosis": GNOSIS_NODE_URL} +# sleep time can be configured here CHAIN_SLEEP_TIMES = {"Ethereum": 60, "Gnosis": 120} + + +def create_read_db_connection(chain_name: str): + """function that creates a connection to the CoW db.""" + if chain_name == "Ethereum": + read_db_url = os.getenv("ETHEREUM_DB_URL") + elif chain_name == "Gnosis": + read_db_url = os.getenv("GNOSIS_DB_URL") + + return create_engine(f"postgresql+psycopg2://{read_db_url}") + + +def create_write_db_connection(): + """Function that creates a connection to the write database.""" + write_db_connection = psycopg2.connect( + database="solver_slippage", + host=os.getenv("SOLVER_SLIPPAGE_HOST"), + user=os.getenv("SOLVER_SLIPPAGE_USER"), + password=os.getenv("SOLVER_SLIPPAGE_PASS"), + port=5432, + ) + return write_db_connection + + +logger = get_logger("raw_token_imbalances") diff --git a/src/daemon.py b/src/daemon.py index d1ff665..9199d29 100644 --- a/src/daemon.py +++ b/src/daemon.py @@ -1,14 +1,58 @@ -# mypy: disable-error-code="import, arg-type" -import os +""" +Running this daemon computes raw imbalances for finalized blocks by calling imbalances_script.py. +""" import time +import psycopg2 import pandas as pd from web3 import Web3 from typing import List from threading import Thread -from sqlalchemy import create_engine from sqlalchemy.engine import Engine from src.imbalances_script import RawTokenImbalances -from src.config import CHAIN_RPC_ENDPOINTS, CHAIN_SLEEP_TIMES +from src.config import ( + CHAIN_RPC_ENDPOINTS, + CHAIN_SLEEP_TIMES, + create_read_db_connection, + create_write_db_connection, + logger, +) + + +def write_token_imbalances_to_db( + chain_name: str, + write_db_connection, + auction_id: int, + tx_hash: str, + token_address: str, + imbalance, +): + try: + cursor = write_db_connection.cursor() + # Remove '0x' and then convert hex strings to bytes + tx_hash_bytes = bytes.fromhex(tx_hash[2:]) + token_address_bytes = bytes.fromhex(token_address[2:]) + + insert_sql = """ + INSERT INTO raw_token_imbalances (auction_id, chain_name, tx_hash, token_address, imbalance) + VALUES (%s, %s, %s, %s, %s) + """ + cursor.execute( + insert_sql, + ( + auction_id, + chain_name, + psycopg2.Binary(tx_hash_bytes), + psycopg2.Binary(token_address_bytes), + imbalance, + ), + ) + write_db_connection.commit() + + logger.info("Record inserted successfully.") + except psycopg2.Error as e: + logger.error(f"Error inserting record: {e}") + finally: + cursor.close() def get_web3_instance(chain_name: str) -> Web3: @@ -19,72 +63,75 @@ def get_finalized_block_number(web3: Web3) -> int: return web3.eth.block_number - 64 -def create_db_connection(chain_name: str): - """function that creates a connection to the CoW db.""" - if chain_name == "Ethereum": - db_url = os.getenv("ETHEREUM_DB_URL") - elif chain_name == "Gnosis": - db_url = os.getenv("GNOSIS_DB_URL") - - return create_engine(f"postgresql+psycopg2://{db_url}") - - def fetch_transaction_hashes( - db_connection: Engine, start_block: int, end_block: int + read_db_connection: Engine, start_block: int, end_block: int ) -> List[str]: """Fetch transaction hashes beginning start_block.""" query = f""" - SELECT tx_hash + SELECT tx_hash, auction_id FROM settlements WHERE block_number >= {start_block} AND block_number <= {end_block} """ - db_hashes = pd.read_sql(query, db_connection) + db_data = pd.read_sql(query, read_db_connection) # converts hashes at memory location to hex - db_hashes["tx_hash"] = db_hashes["tx_hash"].apply(lambda x: f"0x{x.hex()}") + db_data["tx_hash"] = db_data["tx_hash"].apply(lambda x: f"0x{x.hex()}") - return db_hashes["tx_hash"].tolist() + # return db_hashes['tx_hash'].tolist(), db_hashes['auction_id'].tolist() + tx_hashes_auction_ids = [ + (row["tx_hash"], row["auction_id"]) for index, row in db_data.iterrows() + ] + return tx_hashes_auction_ids def process_transactions(chain_name: str) -> None: web3 = get_web3_instance(chain_name) rt = RawTokenImbalances(web3, chain_name) sleep_time = CHAIN_SLEEP_TIMES.get(chain_name) - db_connection = create_db_connection(chain_name) - + read_db_connection = create_read_db_connection(chain_name) + write_db_connection = create_write_db_connection() previous_block = get_finalized_block_number(web3) - unprocessed_txs = [] # type: List + unprocessed_txs = [] - print(f"{chain_name} Daemon started.") + logger.info(f"{chain_name} Daemon started.") while True: try: latest_block = get_finalized_block_number(web3) new_txs = fetch_transaction_hashes( - db_connection, previous_block, latest_block + read_db_connection, previous_block, latest_block ) # add any unprocessed hashes for processing, then clear list of unprocessed all_txs = new_txs + unprocessed_txs unprocessed_txs.clear() - for tx in all_txs: - print(f"Processing transaction on {chain_name}: {tx}") + for tx, auction_id in all_txs: + logger.info(f"Processing transaction on {chain_name}: {tx}") try: imbalances = rt.compute_imbalances(tx) - print(f"Token Imbalances on {chain_name}:") + logger.info(f"Token Imbalances on {chain_name}:") for token_address, imbalance in imbalances.items(): - print(f"Token: {token_address}, Imbalance: {imbalance}") + write_token_imbalances_to_db( + chain_name, + write_db_connection, + auction_id, + tx, + token_address, + imbalance, + ) + logger.info(f"Token: {token_address}, Imbalance: {imbalance}") except ValueError as e: - print(e) + logger.error(e) unprocessed_txs.append(tx) - print("Done checks..") previous_block = latest_block + 1 except ConnectionError as e: - print(f"Connection error processing transactions on {chain_name}: {e}") + logger.error( + f"Connection error processing transactions on {chain_name}: {e}" + ) except Exception as e: - print(f"Error processing transactions on {chain_name}: {e}") + logger.error(f"Error processing transactions on {chain_name}: {e}") time.sleep(sleep_time) @@ -102,4 +149,4 @@ def main() -> None: if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/src/helper_functions.py b/src/helper_functions.py new file mode 100644 index 0000000..a211307 --- /dev/null +++ b/src/helper_functions.py @@ -0,0 +1,23 @@ +""" +This file contains some auxiliary functions +""" +from __future__ import annotations +import logging +from typing import Optional + + +def get_logger(filename: Optional[str] = None) -> logging.Logger: + """ + get_logger() returns a logger object that can write to a file, terminal or only file if needed. + """ + logging.basicConfig(format="%(levelname)s - %(message)s") + logger = logging.getLogger() + logger.setLevel(logging.INFO) + if filename: + file_handler = logging.FileHandler(filename + ".log", mode="w") + file_handler.setLevel(logging.INFO) + formatter = logging.Formatter("%(levelname)s - %(message)s") + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + return logger diff --git a/src/imbalances_script.py b/src/imbalances_script.py index cfd5d5b..4e8854d 100644 --- a/src/imbalances_script.py +++ b/src/imbalances_script.py @@ -20,11 +20,11 @@ 9. update_sdai_imbalance() is called in each iteration and only completes if there is an SDAI transfer involved which has special handling for its events. """ +from web3 import Web3 from web3.datastructures import AttributeDict from typing import Dict, List, Optional, Tuple -from web3 import Web3 from web3.types import TxReceipt -from src.config import CHAIN_RPC_ENDPOINTS +from src.config import CHAIN_RPC_ENDPOINTS, logger from src.constants import ( SETTLEMENT_CONTRACT_ADDRESS, NATIVE_ETH_TOKEN_ADDRESS, @@ -54,14 +54,14 @@ def find_chain_with_tx(tx_hash: str) -> Tuple[str, Web3]: for chain_name, url in CHAIN_RPC_ENDPOINTS.items(): web3 = Web3(Web3.HTTPProvider(url)) if not web3.is_connected(): - print(f"Could not connect to {chain_name}.") + logger.warning(f"Could not connect to {chain_name}.") continue try: web3.eth.get_transaction_receipt(tx_hash) - print(f"Transaction found on {chain_name}.") + logger.info(f"Transaction found on {chain_name}.") return chain_name, web3 except Exception as e: - print(f"Transaction not found on {chain_name}: {e}") + logger.debug(f"Transaction not found on {chain_name}: {e}") raise ValueError(f"Transaction hash {tx_hash} not found on any chain.") @@ -74,7 +74,8 @@ def _to_int(value: str | int) -> int: else int(value) ) except ValueError: - print(f"Error converting value {value} to integer.") + logger.error(f"Error converting value {value} to integer.") + class RawTokenImbalances: @@ -89,7 +90,7 @@ def get_transaction_receipt(self, tx_hash: str) -> Optional[TxReceipt]: try: return self.web3.eth.get_transaction_receipt(tx_hash) except Exception as e: - print(f"Error getting transaction receipt: {e}") + logger.error(f"Error getting transaction receipt: {e}") return None def get_transaction_trace(self, tx_hash: str) -> Optional[List[Dict]]: @@ -98,7 +99,7 @@ def get_transaction_trace(self, tx_hash: str) -> Optional[List[Dict]]: res = self.web3.tracing.trace_transaction(tx_hash) return res except Exception as err: - print(f"Error occurred while fetching transaction trace: {err}") + logger.error(f"Error occurred while fetching transaction trace: {err}") return None def extract_actions(self, traces: List[AttributeDict], address: str) -> List[Dict]: @@ -149,7 +150,7 @@ def extract_events(self, tx_receipt: Dict) -> Dict[str, List[Dict]]: k: v for k, v in event_topics.items() if k not in transfer_topics } - events = {name: [] for name in EVENT_TOPICS} # type: dict + events = {name: [] for name in EVENT_TOPICS} for log in tx_receipt["logs"]: log_topic = log["topics"][0].hex() if log_topic in transfer_topics.values(): @@ -187,7 +188,7 @@ def decode_event( else: # Withdrawal event return from_address, None, value except Exception as e: - print(f"Error decoding event: {str(e)}") + logger.error(f"Error decoding event: {str(e)}") return None, None, None def process_event( @@ -256,7 +257,7 @@ def decode_sdai_event(self, event: Dict) -> int | None: value = int(value_hex, 16) return value except Exception as e: - print(f"Error decoding sDAI event: {str(e)}") + logger.error(f"Error decoding sDAI event: {str(e)}") return None def process_sdai_event( @@ -324,11 +325,12 @@ def main() -> None: rt = RawTokenImbalances(web3, chain_name) try: imbalances = rt.compute_imbalances(tx_hash) - print(f"Token Imbalances on {chain_name}:") + logger.info(f"Token Imbalances on {chain_name}:") for token_address, imbalance in imbalances.items(): - print(f"Token: {token_address}, Imbalance: {imbalance}") + logger.info(f"Token: {token_address}, Imbalance: {imbalance}") except ValueError as e: - print(e) + logger.error(e) + if __name__ == "__main__": From 8b5abe0c227de6b068123483a51fffd9bf351f8c Mon Sep 17 00:00:00 2001 From: Shubh Agarwal Date: Wed, 3 Jul 2024 13:40:41 -0400 Subject: [PATCH 02/15] edit .envsample --- .env.sample | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.env.sample b/.env.sample index dd89db4..df489ff 100644 --- a/.env.sample +++ b/.env.sample @@ -8,5 +8,10 @@ GNOSIS_DB_URL= ETHEREUM_NODE_URL= GNOSIS_NODE_URL= +# credentials for writing to db connection +SOLVER_SLIPPAGE_HOST= +SOLVER_SLIPPAGE_USER= +SOLVER_SLIPPAGE_PASS= + # optional INFURA_KEY=infura_key_here From 4fde33be426b29efc70d203742bbb5dc2065ee32 Mon Sep 17 00:00:00 2001 From: Shubh Agarwal Date: Thu, 4 Jul 2024 13:59:27 -0400 Subject: [PATCH 03/15] pylint 1 --- contracts/erc20_abi.py | 3 +++ src/balanceof_imbalances.py | 2 +- src/constants.py | 1 + src/daemon.py | 39 +++++++++++++++++++++++++------------ src/imbalances_script.py | 38 ++++++++++++++++++++---------------- tests/basic_test.py | 4 ++++ 6 files changed, 57 insertions(+), 30 deletions(-) diff --git a/contracts/erc20_abi.py b/contracts/erc20_abi.py index 03a62f2..9a0c632 100644 --- a/contracts/erc20_abi.py +++ b/contracts/erc20_abi.py @@ -1,3 +1,6 @@ +""" +ERC20 ABI contract +""" erc20_abi = [ { "constant": True, diff --git a/src/balanceof_imbalances.py b/src/balanceof_imbalances.py index d91ea06..98cd700 100644 --- a/src/balanceof_imbalances.py +++ b/src/balanceof_imbalances.py @@ -117,4 +117,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/constants.py b/src/constants.py index bbe179e..63fa527 100644 --- a/src/constants.py +++ b/src/constants.py @@ -1,3 +1,4 @@ +""" Constants used for the token imbalances project """ from web3 import Web3 SETTLEMENT_CONTRACT_ADDRESS = Web3.to_checksum_address( diff --git a/src/daemon.py b/src/daemon.py index 9199d29..33f9f2e 100644 --- a/src/daemon.py +++ b/src/daemon.py @@ -2,11 +2,11 @@ Running this daemon computes raw imbalances for finalized blocks by calling imbalances_script.py. """ import time +from typing import List, Tuple +from threading import Thread import psycopg2 import pandas as pd from web3 import Web3 -from typing import List -from threading import Thread from sqlalchemy.engine import Engine from src.imbalances_script import RawTokenImbalances from src.config import ( @@ -26,6 +26,9 @@ def write_token_imbalances_to_db( token_address: str, imbalance, ): + """ + Write token imbalances to the database. + """ try: cursor = write_db_connection.cursor() # Remove '0x' and then convert hex strings to bytes @@ -50,23 +53,29 @@ def write_token_imbalances_to_db( logger.info("Record inserted successfully.") except psycopg2.Error as e: - logger.error(f"Error inserting record: {e}") + logger.error("Error inserting record: %s", e) finally: cursor.close() def get_web3_instance(chain_name: str) -> Web3: + """ + returns a Web3 instance for the given blockchain via chain name. + """ return Web3(Web3.HTTPProvider(CHAIN_RPC_ENDPOINTS[chain_name])) def get_finalized_block_number(web3: Web3) -> int: + """ + Get the number of the most recent finalized block. + """ return web3.eth.block_number - 64 def fetch_transaction_hashes( read_db_connection: Engine, start_block: int, end_block: int ) -> List[str]: - """Fetch transaction hashes beginning start_block.""" + """Fetch transaction hashes beginning from start_block to end_block. """ query = f""" SELECT tx_hash, auction_id FROM settlements @@ -86,6 +95,9 @@ def fetch_transaction_hashes( def process_transactions(chain_name: str) -> None: + """ + Process transactions to compute imbalances for a given blockchain via chain name. + """ web3 = get_web3_instance(chain_name) rt = RawTokenImbalances(web3, chain_name) sleep_time = CHAIN_SLEEP_TIMES.get(chain_name) @@ -94,7 +106,7 @@ def process_transactions(chain_name: str) -> None: previous_block = get_finalized_block_number(web3) unprocessed_txs = [] - logger.info(f"{chain_name} Daemon started.") + logger.info("%s Daemon started.", chain_name) while True: try: @@ -107,10 +119,10 @@ def process_transactions(chain_name: str) -> None: unprocessed_txs.clear() for tx, auction_id in all_txs: - logger.info(f"Processing transaction on {chain_name}: {tx}") + logger.info("Processing transaction on %s: %s", chain_name, tx) try: imbalances = rt.compute_imbalances(tx) - logger.info(f"Token Imbalances on {chain_name}:") + logger.info("Token Imbalances on %s:", chain_name) for token_address, imbalance in imbalances.items(): write_token_imbalances_to_db( chain_name, @@ -120,23 +132,26 @@ def process_transactions(chain_name: str) -> None: token_address, imbalance, ) - logger.info(f"Token: {token_address}, Imbalance: {imbalance}") + logger.info("Token: %s, Imbalance: %s", token_address, imbalance) except ValueError as e: - logger.error(e) + logger.error("ValueError: %s", e) unprocessed_txs.append(tx) previous_block = latest_block + 1 except ConnectionError as e: logger.error( - f"Connection error processing transactions on {chain_name}: {e}" + "Connection error processing transactions on %s: %s", chain_name, e ) except Exception as e: - logger.error(f"Error processing transactions on {chain_name}: {e}") + logger.error("Error processing transactions on %s: %s", chain_name, e) time.sleep(sleep_time) def main() -> None: + """ + Main function to start the daemon threads for each blockchain. + """ threads = [] for chain_name in CHAIN_RPC_ENDPOINTS.keys(): @@ -149,4 +164,4 @@ def main() -> None: if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/imbalances_script.py b/src/imbalances_script.py index 4e8854d..1523f56 100644 --- a/src/imbalances_script.py +++ b/src/imbalances_script.py @@ -15,15 +15,18 @@ adding the transfer value to existing inflow/outflow for the token addresses. 7. Returning to calculate_imbalances(), which finds the imbalance for all token addresses using inflow-outflow. -8. If actions are not None, it denotes an ETH transfer event, which involves reducing WETH withdrawal - amount- > update_weth_imbalance(). The ETH imbalance is also calculated via -> update_native_eth_imbalance(). -9. update_sdai_imbalance() is called in each iteration and only completes if there is an SDAI transfer - involved which has special handling for its events. +8. If actions are not None, it denotes an ETH transfer event, which involves reducing WETH + withdrawal amount- > update_weth_imbalance(). The ETH imbalance is also calculated + via -> update_native_eth_imbalance(). +9. update_sdai_imbalance() is called in each iteration and only completes if there is an SDAI + transfer involved which has special handling for its events. """ +from typing import Dict, List, Optional, Tuple + from web3 import Web3 from web3.datastructures import AttributeDict -from typing import Dict, List, Optional, Tuple from web3.types import TxReceipt + from src.config import CHAIN_RPC_ENDPOINTS, logger from src.constants import ( SETTLEMENT_CONTRACT_ADDRESS, @@ -54,17 +57,18 @@ def find_chain_with_tx(tx_hash: str) -> Tuple[str, Web3]: for chain_name, url in CHAIN_RPC_ENDPOINTS.items(): web3 = Web3(Web3.HTTPProvider(url)) if not web3.is_connected(): - logger.warning(f"Could not connect to {chain_name}.") + logger.warning("Could not connect to %s.", chain_name) continue try: web3.eth.get_transaction_receipt(tx_hash) - logger.info(f"Transaction found on {chain_name}.") + logger.info("Transaction found on %s.", chain_name) return chain_name, web3 - except Exception as e: - logger.debug(f"Transaction not found on {chain_name}: {e}") + except Exception as ex: + logger.debug("Transaction not found on %s: %s", chain_name, ex) raise ValueError(f"Transaction hash {tx_hash} not found on any chain.") + def _to_int(value: str | int) -> int: """Convert hex string or integer to integer.""" try: @@ -74,11 +78,12 @@ def _to_int(value: str | int) -> int: else int(value) ) except ValueError: - logger.error(f"Error converting value {value} to integer.") - + logger.error("Error converting value %s to integer.", value) class RawTokenImbalances: + """Class for computing token imbalances.""" + def __init__(self, web3: Web3, chain_name: str): self.web3 = web3 self.chain_name = chain_name @@ -89,8 +94,8 @@ def get_transaction_receipt(self, tx_hash: str) -> Optional[TxReceipt]: """ try: return self.web3.eth.get_transaction_receipt(tx_hash) - except Exception as e: - logger.error(f"Error getting transaction receipt: {e}") + except Exception as ex: + logger.error("Error getting transaction receipt: %s", ex) return None def get_transaction_trace(self, tx_hash: str) -> Optional[List[Dict]]: @@ -99,7 +104,7 @@ def get_transaction_trace(self, tx_hash: str) -> Optional[List[Dict]]: res = self.web3.tracing.trace_transaction(tx_hash) return res except Exception as err: - logger.error(f"Error occurred while fetching transaction trace: {err}") + logger.error("Error occurred while fetching transaction trace: %s", err) return None def extract_actions(self, traces: List[AttributeDict], address: str) -> List[Dict]: @@ -188,7 +193,7 @@ def decode_event( else: # Withdrawal event return from_address, None, value except Exception as e: - logger.error(f"Error decoding event: {str(e)}") + logger.error("Error decoding event: %s", str(e)) return None, None, None def process_event( @@ -318,8 +323,8 @@ def compute_imbalances(self, tx_hash: str) -> Dict[str, int]: return imbalances -# main method for finding imbalance for a single tx hash def main() -> None: + """ main function for finding imbalance for a single tx hash. """ tx_hash = input("Enter transaction hash: ") chain_name, web3 = find_chain_with_tx(tx_hash) rt = RawTokenImbalances(web3, chain_name) @@ -332,6 +337,5 @@ def main() -> None: logger.error(e) - if __name__ == "__main__": main() diff --git a/tests/basic_test.py b/tests/basic_test.py index 4598d5e..f83b857 100644 --- a/tests/basic_test.py +++ b/tests/basic_test.py @@ -1,3 +1,4 @@ +""" Runs a basic test for raw imbalance calculation edge-cases. """ import pytest from src.imbalances_script import RawTokenImbalances @@ -33,6 +34,9 @@ ], ) def test_imbalances(tx_hash, expected_imbalances): + """ + Asserts imbalances match for main script with test values provided. + """ rt = RawTokenImbalances() imbalances, _ = rt.compute_imbalances(tx_hash) for token_address, expected_imbalance in expected_imbalances.items(): From 25596c1d642468314ebacad670603c2863099be7 Mon Sep 17 00:00:00 2001 From: Shubh Agarwal Date: Thu, 4 Jul 2024 14:01:31 -0400 Subject: [PATCH 04/15] black --- src/daemon.py | 6 ++++-- src/imbalances_script.py | 3 +-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/daemon.py b/src/daemon.py index 33f9f2e..ac863ff 100644 --- a/src/daemon.py +++ b/src/daemon.py @@ -75,7 +75,7 @@ def get_finalized_block_number(web3: Web3) -> int: def fetch_transaction_hashes( read_db_connection: Engine, start_block: int, end_block: int ) -> List[str]: - """Fetch transaction hashes beginning from start_block to end_block. """ + """Fetch transaction hashes beginning from start_block to end_block.""" query = f""" SELECT tx_hash, auction_id FROM settlements @@ -132,7 +132,9 @@ def process_transactions(chain_name: str) -> None: token_address, imbalance, ) - logger.info("Token: %s, Imbalance: %s", token_address, imbalance) + logger.info( + "Token: %s, Imbalance: %s", token_address, imbalance + ) except ValueError as e: logger.error("ValueError: %s", e) unprocessed_txs.append(tx) diff --git a/src/imbalances_script.py b/src/imbalances_script.py index 1523f56..4b8420c 100644 --- a/src/imbalances_script.py +++ b/src/imbalances_script.py @@ -68,7 +68,6 @@ def find_chain_with_tx(tx_hash: str) -> Tuple[str, Web3]: raise ValueError(f"Transaction hash {tx_hash} not found on any chain.") - def _to_int(value: str | int) -> int: """Convert hex string or integer to integer.""" try: @@ -324,7 +323,7 @@ def compute_imbalances(self, tx_hash: str) -> Dict[str, int]: def main() -> None: - """ main function for finding imbalance for a single tx hash. """ + """main function for finding imbalance for a single tx hash.""" tx_hash = input("Enter transaction hash: ") chain_name, web3 = find_chain_with_tx(tx_hash) rt = RawTokenImbalances(web3, chain_name) From ab00ec47dbc394bbad71096349ff3ec0df1f535d Mon Sep 17 00:00:00 2001 From: Shubh Agarwal Date: Thu, 4 Jul 2024 14:25:42 -0400 Subject: [PATCH 05/15] mypy 1 --- src/balanceof_imbalances.py | 16 ++++++++-------- src/config.py | 2 +- src/daemon.py | 23 +++++++++++------------ src/imbalances_script.py | 2 +- 4 files changed, 21 insertions(+), 22 deletions(-) diff --git a/src/balanceof_imbalances.py b/src/balanceof_imbalances.py index 98cd700..c74bf8d 100644 --- a/src/balanceof_imbalances.py +++ b/src/balanceof_imbalances.py @@ -1,7 +1,7 @@ from web3 import Web3 -from web3.types import TxReceipt +from web3.types import TxReceipt, HexStr from eth_typing import ChecksumAddress -from typing import Dict, Optional, Set +from typing import Dict, Optional, Set, Any from src.config import ETHEREUM_NODE_URL from src.constants import SETTLEMENT_CONTRACT_ADDRESS, NATIVE_ETH_TOKEN_ADDRESS from contracts.erc20_abi import erc20_abi @@ -36,7 +36,7 @@ def get_eth_balance(self, account: str, block_identifier: int) -> Optional[int]: def extract_token_addresses(self, tx_receipt: Dict) -> Set[str]: """Extract unique token addresses from 'Transfer' events in a transaction receipt.""" - token_addresses = set() + token_addresses: Set[ChecksumAddress] = set() transfer_topics = { self.web3.keccak(text="Transfer(address,address,uint256)").hex(), self.web3.keccak(text="ERC20Transfer(address,address,uint256)").hex(), @@ -44,7 +44,7 @@ def extract_token_addresses(self, tx_receipt: Dict) -> Set[str]: } for log in tx_receipt["logs"]: if log["topics"][0].hex() in transfer_topics: - token_addresses.add(log["address"]) + token_addresses.add(self.web3.to_checksum_address(log["address"])) return token_addresses def get_transaction_receipt(self, tx_hash: str) -> Optional[TxReceipt]: @@ -59,7 +59,7 @@ def get_balances( self, token_addresses: Set[ChecksumAddress], block_number: int ) -> Dict[ChecksumAddress, Optional[int]]: """Get balances for all tokens at the given block number.""" - balances = {} + balances: Dict[ChecksumAddress, Optional[int]] = {} balances[NATIVE_ETH_TOKEN_ADDRESS] = self.get_eth_balance( SETTLEMENT_CONTRACT_ADDRESS, block_number ) @@ -73,8 +73,8 @@ def get_balances( def calculate_imbalances( self, - prev_balances: Dict[str, Optional[int]], - final_balances: Dict[str, Optional[int]], + prev_balances: Dict[ChecksumAddress, Optional[int]], + final_balances: Dict[ChecksumAddress, Optional[int]], ) -> Dict[str, int]: """Calculate imbalances between previous and final balances.""" imbalances = {} @@ -87,7 +87,7 @@ def calculate_imbalances( imbalances[token_address] = imbalance return imbalances - def compute_imbalances(self, tx_hash: str) -> Dict[str, int]: + def compute_imbalances(self, tx_hash: HexStr) -> Dict[ChecksumAddress, int]: """Compute token imbalances before and after a transaction.""" tx_receipt = self.get_transaction_receipt(tx_hash) if tx_receipt is None: diff --git a/src/config.py b/src/config.py index 8ee40b3..81378c8 100644 --- a/src/config.py +++ b/src/config.py @@ -14,7 +14,7 @@ CHAIN_RPC_ENDPOINTS = {"Ethereum": ETHEREUM_NODE_URL, "Gnosis": GNOSIS_NODE_URL} # sleep time can be configured here -CHAIN_SLEEP_TIMES = {"Ethereum": 60, "Gnosis": 120} +CHAIN_SLEEP_TIMES = {"Ethereum": 60.0, "Gnosis": 120.0} def create_read_db_connection(chain_name: str): diff --git a/src/daemon.py b/src/daemon.py index ac863ff..373b292 100644 --- a/src/daemon.py +++ b/src/daemon.py @@ -2,11 +2,12 @@ Running this daemon computes raw imbalances for finalized blocks by calling imbalances_script.py. """ import time -from typing import List, Tuple +from typing import List, Tuple, Dict, Any from threading import Thread import psycopg2 import pandas as pd from web3 import Web3 +from web3.types import ChecksumAddress, HexStr, TxReceipt from sqlalchemy.engine import Engine from src.imbalances_script import RawTokenImbalances from src.config import ( @@ -20,12 +21,12 @@ def write_token_imbalances_to_db( chain_name: str, - write_db_connection, + write_db_connection: Any, auction_id: int, tx_hash: str, token_address: str, - imbalance, -): + imbalance: float, +) -> None: """ Write token imbalances to the database. """ @@ -74,7 +75,7 @@ def get_finalized_block_number(web3: Web3) -> int: def fetch_transaction_hashes( read_db_connection: Engine, start_block: int, end_block: int -) -> List[str]: +) -> List[Tuple[str, int]]: """Fetch transaction hashes beginning from start_block to end_block.""" query = f""" SELECT tx_hash, auction_id @@ -87,11 +88,9 @@ def fetch_transaction_hashes( # converts hashes at memory location to hex db_data["tx_hash"] = db_data["tx_hash"].apply(lambda x: f"0x{x.hex()}") - # return db_hashes['tx_hash'].tolist(), db_hashes['auction_id'].tolist() - tx_hashes_auction_ids = [ - (row["tx_hash"], row["auction_id"]) for index, row in db_data.iterrows() - ] - return tx_hashes_auction_ids + # return (tx hash, auction id) as tx_data + tx_data = [(row["tx_hash"], row["auction_id"]) for index, row in db_data.iterrows()] + return tx_data def process_transactions(chain_name: str) -> None: @@ -104,7 +103,7 @@ def process_transactions(chain_name: str) -> None: read_db_connection = create_read_db_connection(chain_name) write_db_connection = create_write_db_connection() previous_block = get_finalized_block_number(web3) - unprocessed_txs = [] + unprocessed_txs: List[Tuple[str, int]] = [] logger.info("%s Daemon started.", chain_name) @@ -137,7 +136,7 @@ def process_transactions(chain_name: str) -> None: ) except ValueError as e: logger.error("ValueError: %s", e) - unprocessed_txs.append(tx) + unprocessed_txs.append((tx, auction_id)) previous_block = latest_block + 1 except ConnectionError as e: diff --git a/src/imbalances_script.py b/src/imbalances_script.py index 4b8420c..5c1508d 100644 --- a/src/imbalances_script.py +++ b/src/imbalances_script.py @@ -154,7 +154,7 @@ def extract_events(self, tx_receipt: Dict) -> Dict[str, List[Dict]]: k: v for k, v in event_topics.items() if k not in transfer_topics } - events = {name: [] for name in EVENT_TOPICS} + events: Dict[str, List[Dict]] = {name: [] for name in EVENT_TOPICS} for log in tx_receipt["logs"]: log_topic = log["topics"][0].hex() if log_topic in transfer_topics.values(): From 692692aadd89b8ae73c2f8e5abc43154384a8fde Mon Sep 17 00:00:00 2001 From: Shubh Agarwal Date: Thu, 4 Jul 2024 14:37:58 -0400 Subject: [PATCH 06/15] mypy 2 --- requirements.txt | 4 +++- src/balanceof_imbalances.py | 29 ++++++++++++++++++++--------- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/requirements.txt b/requirements.txt index 247a178..ebf26cf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,4 +8,6 @@ black==23.3.0 mypy==1.4.1 pylint==3.2.5 pytest==7.4.0 -setuptools \ No newline at end of file +setuptools +pandas-stubs +types-psycopg2 diff --git a/src/balanceof_imbalances.py b/src/balanceof_imbalances.py index c74bf8d..32c1160 100644 --- a/src/balanceof_imbalances.py +++ b/src/balanceof_imbalances.py @@ -14,7 +14,10 @@ def __init__(self, ETHEREUM_NODE_URL: str): self.web3 = Web3(Web3.HTTPProvider(ETHEREUM_NODE_URL)) def get_token_balance( - self, token_address: str, account: str, block_identifier: int + self, + token_address: ChecksumAddress, + account: ChecksumAddress, + block_identifier: int, ) -> Optional[int]: """Retrieve the ERC-20 token balance of an account at a given block.""" token_contract = self.web3.eth.contract(address=token_address, abi=erc20_abi) @@ -26,7 +29,9 @@ def get_token_balance( print(f"Error fetching balance for token {token_address}: {e}") return None - def get_eth_balance(self, account: str, block_identifier: int) -> Optional[int]: + def get_eth_balance( + self, account: ChecksumAddress, block_identifier: int + ) -> Optional[int]: """Get the ETH balance for a given account and block number.""" try: return self.web3.eth.get_balance(account, block_identifier=block_identifier) @@ -34,7 +39,9 @@ def get_eth_balance(self, account: str, block_identifier: int) -> Optional[int]: print(f"Error fetching ETH balance: {e}") return None - def extract_token_addresses(self, tx_receipt: Dict) -> Set[str]: + def extract_token_addresses( + self, tx_receipt: Dict[Any, Any] + ) -> Set[ChecksumAddress]: """Extract unique token addresses from 'Transfer' events in a transaction receipt.""" token_addresses: Set[ChecksumAddress] = set() transfer_topics = { @@ -47,7 +54,7 @@ def extract_token_addresses(self, tx_receipt: Dict) -> Set[str]: token_addresses.add(self.web3.to_checksum_address(log["address"])) return token_addresses - def get_transaction_receipt(self, tx_hash: str) -> Optional[TxReceipt]: + def get_transaction_receipt(self, tx_hash: HexStr) -> Optional[TxReceipt]: """Fetch the transaction receipt for the given hash.""" try: return self.web3.eth.get_transaction_receipt(tx_hash) @@ -60,13 +67,17 @@ def get_balances( ) -> Dict[ChecksumAddress, Optional[int]]: """Get balances for all tokens at the given block number.""" balances: Dict[ChecksumAddress, Optional[int]] = {} - balances[NATIVE_ETH_TOKEN_ADDRESS] = self.get_eth_balance( - SETTLEMENT_CONTRACT_ADDRESS, block_number + balances[ + self.web3.to_checksum_address(NATIVE_ETH_TOKEN_ADDRESS) + ] = self.get_eth_balance( + self.web3.to_checksum_address(SETTLEMENT_CONTRACT_ADDRESS), block_number ) for token_address in token_addresses: balances[token_address] = self.get_token_balance( - token_address, SETTLEMENT_CONTRACT_ADDRESS, block_number + token_address, + self.web3.to_checksum_address(SETTLEMENT_CONTRACT_ADDRESS), + block_number, ) return balances @@ -75,9 +86,9 @@ def calculate_imbalances( self, prev_balances: Dict[ChecksumAddress, Optional[int]], final_balances: Dict[ChecksumAddress, Optional[int]], - ) -> Dict[str, int]: + ) -> Dict[ChecksumAddress, int]: """Calculate imbalances between previous and final balances.""" - imbalances = {} + imbalances: Dict[ChecksumAddress, int] = {} for token_address in prev_balances: if ( prev_balances[token_address] is not None From bb0b4b41b5b6ce9ad4e7eff66c8ebe3e0c555738 Mon Sep 17 00:00:00 2001 From: Shubh Agarwal Date: Thu, 4 Jul 2024 14:51:55 -0400 Subject: [PATCH 07/15] mypy fix --- src/balanceof_imbalances.py | 2 +- src/daemon.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/balanceof_imbalances.py b/src/balanceof_imbalances.py index 32c1160..6c11cbb 100644 --- a/src/balanceof_imbalances.py +++ b/src/balanceof_imbalances.py @@ -40,7 +40,7 @@ def get_eth_balance( return None def extract_token_addresses( - self, tx_receipt: Dict[Any, Any] + self, tx_receipt: TxReceipt ) -> Set[ChecksumAddress]: """Extract unique token addresses from 'Transfer' events in a transaction receipt.""" token_addresses: Set[ChecksumAddress] = set() diff --git a/src/daemon.py b/src/daemon.py index 373b292..a7bbd33 100644 --- a/src/daemon.py +++ b/src/daemon.py @@ -145,8 +145,8 @@ def process_transactions(chain_name: str) -> None: ) except Exception as e: logger.error("Error processing transactions on %s: %s", chain_name, e) - - time.sleep(sleep_time) + if sleep_time is not None: + time.sleep(sleep_time) def main() -> None: From 11a586b2c25b6d55628d8fe4bcded623d5a43448 Mon Sep 17 00:00:00 2001 From: Shubh Agarwal Date: Thu, 4 Jul 2024 14:53:50 -0400 Subject: [PATCH 08/15] mypy fix --- src/balanceof_imbalances.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/balanceof_imbalances.py b/src/balanceof_imbalances.py index 6c11cbb..66a9418 100644 --- a/src/balanceof_imbalances.py +++ b/src/balanceof_imbalances.py @@ -39,9 +39,7 @@ def get_eth_balance( print(f"Error fetching ETH balance: {e}") return None - def extract_token_addresses( - self, tx_receipt: TxReceipt - ) -> Set[ChecksumAddress]: + def extract_token_addresses(self, tx_receipt: TxReceipt) -> Set[ChecksumAddress]: """Extract unique token addresses from 'Transfer' events in a transaction receipt.""" token_addresses: Set[ChecksumAddress] = set() transfer_topics = { From 6e266e9f2c700374489afb157c1134b00333ab13 Mon Sep 17 00:00:00 2001 From: Shubh Agarwal Date: Thu, 4 Jul 2024 14:59:54 -0400 Subject: [PATCH 09/15] mypy fixes --- src/balanceof_imbalances.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/balanceof_imbalances.py b/src/balanceof_imbalances.py index 66a9418..400be6e 100644 --- a/src/balanceof_imbalances.py +++ b/src/balanceof_imbalances.py @@ -92,7 +92,11 @@ def calculate_imbalances( prev_balances[token_address] is not None and final_balances[token_address] is not None ): - imbalance = final_balances[token_address] - prev_balances[token_address] + prev_balance = prev_balances[token_address] + assert prev_balance is not None # Ensure prev_balance is not None + final_balance = final_balances[token_address] + assert final_balance is not None # Ensure final_balance is not None + imbalance = final_balance - prev_balance imbalances[token_address] = imbalance return imbalances From 30c3a589722860aa73776f412f395dfc976067bc Mon Sep 17 00:00:00 2001 From: Shubh Agarwal Date: Sun, 7 Jul 2024 02:17:35 -0400 Subject: [PATCH 10/15] check if exists, small changes --- .env.sample | 10 +++--- src/config.py | 31 +++++++++-------- src/daemon.py | 92 ++++++++++++++++++++++++++++++++++++--------------- 3 files changed, 89 insertions(+), 44 deletions(-) diff --git a/.env.sample b/.env.sample index df489ff..a2922e8 100644 --- a/.env.sample +++ b/.env.sample @@ -8,10 +8,12 @@ GNOSIS_DB_URL= ETHEREUM_NODE_URL= GNOSIS_NODE_URL= -# credentials for writing to db connection -SOLVER_SLIPPAGE_HOST= -SOLVER_SLIPPAGE_USER= -SOLVER_SLIPPAGE_PASS= +# add credentials for connecting to solver slippage DB based on this format +SOLVER_SLIPPAGE_DB_URL=postgresql://username:password@hostname:port/database + +# configure chain sleep time +ETHEREUM_SLEEP_TIME= +GNOSIS_SLEEP_TIME= # optional INFURA_KEY=infura_key_here diff --git a/src/config.py b/src/config.py index 81378c8..184a829 100644 --- a/src/config.py +++ b/src/config.py @@ -1,23 +1,24 @@ import os import psycopg2 -from sqlalchemy import create_engine +from sqlalchemy import create_engine, Engine from dotenv import load_dotenv +from urllib.parse import urlparse +from psycopg2.extensions import connection as Psycopg2Connection from src.helper_functions import get_logger load_dotenv() ETHEREUM_NODE_URL = os.getenv("ETHEREUM_NODE_URL") GNOSIS_NODE_URL = os.getenv("GNOSIS_NODE_URL") -ARBITRUM_NODE_URL = os.getenv("ARBITRUM_NODE_URL") -SOLVER_SLIPPAGE_DB_URL = os.getenv("SOLVER_SLIPPAGE_DB_URL") - CHAIN_RPC_ENDPOINTS = {"Ethereum": ETHEREUM_NODE_URL, "Gnosis": GNOSIS_NODE_URL} -# sleep time can be configured here -CHAIN_SLEEP_TIMES = {"Ethereum": 60.0, "Gnosis": 120.0} +CHAIN_SLEEP_TIMES = { + "Ethereum": float(os.getenv("ETHEREUM_SLEEP_TIME")), + "Gnosis": float(os.getenv("GNOSIS_SLEEP_TIME")), +} -def create_read_db_connection(chain_name: str): +def create_read_db_connection(chain_name: str) -> Engine: """function that creates a connection to the CoW db.""" if chain_name == "Ethereum": read_db_url = os.getenv("ETHEREUM_DB_URL") @@ -27,14 +28,18 @@ def create_read_db_connection(chain_name: str): return create_engine(f"postgresql+psycopg2://{read_db_url}") -def create_write_db_connection(): +def create_write_db_connection() -> Psycopg2Connection: """Function that creates a connection to the write database.""" + + parsed_url = urlparse(os.getenv("SOLVER_SLIPPAGE_DB_URL")) + + # Connect to the database write_db_connection = psycopg2.connect( - database="solver_slippage", - host=os.getenv("SOLVER_SLIPPAGE_HOST"), - user=os.getenv("SOLVER_SLIPPAGE_USER"), - password=os.getenv("SOLVER_SLIPPAGE_PASS"), - port=5432, + database=parsed_url.path[1:], + user=parsed_url.username, + password=parsed_url.password, + host=parsed_url.hostname, + port=parsed_url.port, ) return write_db_connection diff --git a/src/daemon.py b/src/daemon.py index a7bbd33..ac0401c 100644 --- a/src/daemon.py +++ b/src/daemon.py @@ -1,6 +1,7 @@ """ Running this daemon computes raw imbalances for finalized blocks by calling imbalances_script.py. """ + import time from typing import List, Tuple, Dict, Any from threading import Thread @@ -19,46 +20,83 @@ ) -def write_token_imbalances_to_db( - chain_name: str, +def record_exists( write_db_connection: Any, - auction_id: int, - tx_hash: str, - token_address: str, - imbalance: float, -) -> None: + tx_hash_bytes: bytes, + token_address_bytes: bytes, +) -> bool: """ - Write token imbalances to the database. + Check if a record with the given (tx_hash, token_address) already exists in the database. """ try: cursor = write_db_connection.cursor() - # Remove '0x' and then convert hex strings to bytes - tx_hash_bytes = bytes.fromhex(tx_hash[2:]) - token_address_bytes = bytes.fromhex(token_address[2:]) - insert_sql = """ - INSERT INTO raw_token_imbalances (auction_id, chain_name, tx_hash, token_address, imbalance) - VALUES (%s, %s, %s, %s, %s) + # Check if the record exists + check_sql = """ + SELECT 1 FROM raw_token_imbalances + WHERE tx_hash = %s AND token_address = %s """ cursor.execute( - insert_sql, - ( - auction_id, - chain_name, - psycopg2.Binary(tx_hash_bytes), - psycopg2.Binary(token_address_bytes), - imbalance, - ), + check_sql, + (psycopg2.Binary(tx_hash_bytes), psycopg2.Binary(token_address_bytes)), ) - write_db_connection.commit() + record_exists = cursor.fetchone() - logger.info("Record inserted successfully.") + return record_exists is not None except psycopg2.Error as e: - logger.error("Error inserting record: %s", e) + logger.error("Error checking record existence: %s", e) + return False finally: cursor.close() +def write_token_imbalances_to_db( + chain_name: str, + write_db_connection: Any, + auction_id: int, + tx_hash: str, + token_address: str, + imbalance: float, +) -> None: + """ + Write token imbalances to the database if the (tx_hash, token_address) combination does not already exist. + """ + tx_hash_bytes = bytes.fromhex(tx_hash[2:]) + token_address_bytes = bytes.fromhex(token_address[2:]) + + if not record_exists(write_db_connection, tx_hash_bytes, token_address_bytes): + try: + cursor = write_db_connection.cursor() + # Convert hex strings to bytes for database insertion + + insert_sql = """ + INSERT INTO raw_token_imbalances (auction_id, chain_name, tx_hash, token_address, imbalance) + VALUES (%s, %s, %s, %s, %s) + """ + cursor.execute( + insert_sql, + ( + auction_id, + chain_name, + psycopg2.Binary(tx_hash_bytes), + psycopg2.Binary(token_address_bytes), + imbalance, + ), + ) + write_db_connection.commit() + logger.info("Record inserted successfully.") + except psycopg2.Error as e: + logger.error("Error inserting record: %s", e) + finally: + cursor.close() + else: + logger.info( + "Record with tx_hash %s and token_address %s already exists.", + tx_hash, + token_address, + ) + + def get_web3_instance(chain_name: str) -> Web3: """ returns a Web3 instance for the given blockchain via chain name. @@ -80,8 +118,8 @@ def fetch_transaction_hashes( query = f""" SELECT tx_hash, auction_id FROM settlements - WHERE block_number >= {start_block} - AND block_number <= {end_block} + WHERE block_number >= 20201300 + AND block_number <= 20201400 """ db_data = pd.read_sql(query, read_db_connection) From 1b0afa91b456703813adfa104e3f25fc1780ffcd Mon Sep 17 00:00:00 2001 From: Shubh Agarwal Date: Sun, 7 Jul 2024 02:31:40 -0400 Subject: [PATCH 11/15] mypy --- src/config.py | 23 +++++++++++++++++++++-- src/daemon.py | 7 +++---- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/src/config.py b/src/config.py index 184a829..e777df4 100644 --- a/src/config.py +++ b/src/config.py @@ -12,9 +12,22 @@ CHAIN_RPC_ENDPOINTS = {"Ethereum": ETHEREUM_NODE_URL, "Gnosis": GNOSIS_NODE_URL} + +# function for safe conversion to float (prevents None -> float conversion issues raised by mypy) +def get_env_float(var_name: str) -> float: + """Retrieve environment variable and convert to float. Raise an error if not set.""" + value = os.getenv(var_name) + if value is None: + raise ValueError(f"Environment variable {var_name} is not set.") + try: + return float(value) + except ValueError: + raise ValueError(f"Environment variable {var_name} must be a float.") + + CHAIN_SLEEP_TIMES = { - "Ethereum": float(os.getenv("ETHEREUM_SLEEP_TIME")), - "Gnosis": float(os.getenv("GNOSIS_SLEEP_TIME")), + "Ethereum": get_env_float("ETHEREUM_SLEEP_TIME"), + "Gnosis": get_env_float("GNOSIS_SLEEP_TIME"), } @@ -25,6 +38,9 @@ def create_read_db_connection(chain_name: str) -> Engine: elif chain_name == "Gnosis": read_db_url = os.getenv("GNOSIS_DB_URL") + if not read_db_url: + raise ValueError(f"No database URL found for chain: {chain_name}") + return create_engine(f"postgresql+psycopg2://{read_db_url}") @@ -33,6 +49,9 @@ def create_write_db_connection() -> Psycopg2Connection: parsed_url = urlparse(os.getenv("SOLVER_SLIPPAGE_DB_URL")) + if not parsed_url.hostname or not parsed_url.path: + raise ValueError("Invalid or missing write database URL") + # Connect to the database write_db_connection = psycopg2.connect( database=parsed_url.path[1:], diff --git a/src/daemon.py b/src/daemon.py index ac0401c..e4fc033 100644 --- a/src/daemon.py +++ b/src/daemon.py @@ -3,12 +3,11 @@ """ import time -from typing import List, Tuple, Dict, Any +from typing import List, Tuple, Any from threading import Thread import psycopg2 import pandas as pd from web3 import Web3 -from web3.types import ChecksumAddress, HexStr, TxReceipt from sqlalchemy.engine import Engine from src.imbalances_script import RawTokenImbalances from src.config import ( @@ -118,8 +117,8 @@ def fetch_transaction_hashes( query = f""" SELECT tx_hash, auction_id FROM settlements - WHERE block_number >= 20201300 - AND block_number <= 20201400 + WHERE block_number >= {start_block} + AND block_number <= {end_block} """ db_data = pd.read_sql(query, read_db_connection) From 6dd7ba93d18134d5b60140507085b49ee79a0f67 Mon Sep 17 00:00:00 2001 From: Shubh Agarwal Date: Mon, 8 Jul 2024 23:38:52 -0400 Subject: [PATCH 12/15] major changes --- src/config.py | 33 +++++++-- src/daemon.py | 191 +++++++++++++++++++++++++++++++++----------------- 2 files changed, 152 insertions(+), 72 deletions(-) diff --git a/src/config.py b/src/config.py index e777df4..78269cc 100644 --- a/src/config.py +++ b/src/config.py @@ -1,21 +1,27 @@ import os import psycopg2 +from typing import Any, Optional from sqlalchemy import create_engine, Engine from dotenv import load_dotenv from urllib.parse import urlparse from psycopg2.extensions import connection as Psycopg2Connection from src.helper_functions import get_logger + load_dotenv() ETHEREUM_NODE_URL = os.getenv("ETHEREUM_NODE_URL") GNOSIS_NODE_URL = os.getenv("GNOSIS_NODE_URL") CHAIN_RPC_ENDPOINTS = {"Ethereum": ETHEREUM_NODE_URL, "Gnosis": GNOSIS_NODE_URL} +logger = get_logger("raw_token_imbalances") + -# function for safe conversion to float (prevents None -> float conversion issues raised by mypy) def get_env_float(var_name: str) -> float: - """Retrieve environment variable and convert to float. Raise an error if not set.""" + """ + Function for safe conversion to float (prevents None -> float conversion issues raised by mypy) + Retrieve environment variable and convert to float. Raise an error if not set. + """ value = os.getenv(var_name) if value is None: raise ValueError(f"Environment variable {var_name} is not set.") @@ -31,7 +37,7 @@ def get_env_float(var_name: str) -> float: } -def create_read_db_connection(chain_name: str) -> Engine: +def create_backend_db_connection(chain_name: str) -> Engine: """function that creates a connection to the CoW db.""" if chain_name == "Ethereum": read_db_url = os.getenv("ETHEREUM_DB_URL") @@ -44,7 +50,7 @@ def create_read_db_connection(chain_name: str) -> Engine: return create_engine(f"postgresql+psycopg2://{read_db_url}") -def create_write_db_connection() -> Psycopg2Connection: +def create_solver_slippage_db_connection() -> Psycopg2Connection: """Function that creates a connection to the write database.""" parsed_url = urlparse(os.getenv("SOLVER_SLIPPAGE_DB_URL")) @@ -53,14 +59,27 @@ def create_write_db_connection() -> Psycopg2Connection: raise ValueError("Invalid or missing write database URL") # Connect to the database - write_db_connection = psycopg2.connect( + solver_slippage_connection = psycopg2.connect( database=parsed_url.path[1:], user=parsed_url.username, password=parsed_url.password, host=parsed_url.hostname, port=parsed_url.port, ) - return write_db_connection + return solver_slippage_connection -logger = get_logger("raw_token_imbalances") +def check_db_connection(connection: Any, chain_name: Optional[str] = None) -> Any: + """ + Check if the database connection is still active. If not, create a new one. + """ + try: + if connection.closed: + raise psycopg2.OperationalError("Connection is closed") + except (psycopg2.OperationalError, AttributeError): + connection = ( + create_backend_db_connection(chain_name) + if chain_name + else create_solver_slippage_db_connection() + ) + return connection diff --git a/src/daemon.py b/src/daemon.py index e4fc033..748a68d 100644 --- a/src/daemon.py +++ b/src/daemon.py @@ -13,24 +13,62 @@ from src.config import ( CHAIN_RPC_ENDPOINTS, CHAIN_SLEEP_TIMES, - create_read_db_connection, - create_write_db_connection, + create_backend_db_connection, + create_solver_slippage_db_connection, + check_db_connection, logger, ) +def get_web3_instance(chain_name: str) -> Web3: + """ + returns a Web3 instance for the given blockchain via chain name. + """ + return Web3(Web3.HTTPProvider(CHAIN_RPC_ENDPOINTS[chain_name])) + + +def get_finalized_block_number(web3: Web3) -> int: + """ + Get the number of the most recent finalized block. + """ + return web3.eth.block_number - 67 + + +def fetch_tx_data( + backend_db_connection: Engine, chain_name: str, start_block: int, end_block: int +) -> List[Tuple[str, int, int]]: + """Fetch transaction hashes beginning from start_block to end_block.""" + backend_db_connection = check_db_connection(backend_db_connection, chain_name) + query = f""" + SELECT tx_hash, auction_id, block_number + FROM settlements + WHERE block_number >= {start_block} + AND block_number <= {end_block} + """ + db_data = pd.read_sql(query, backend_db_connection) + # converts hashes at memory location to hex + db_data["tx_hash"] = db_data["tx_hash"].apply(lambda x: f"0x{x.hex()}") + + # return (tx hash, auction id) as tx_data + tx_data = [ + (row["tx_hash"], row["auction_id"], row["block_number"]) + for index, row in db_data.iterrows() + ] + return tx_data + + def record_exists( - write_db_connection: Any, + solver_slippage_db_connection: Any, tx_hash_bytes: bytes, token_address_bytes: bytes, ) -> bool: """ Check if a record with the given (tx_hash, token_address) already exists in the database. """ + solver_slippage_db_connection = check_db_connection(solver_slippage_db_connection) try: - cursor = write_db_connection.cursor() - - # Check if the record exists + cursor = solver_slippage_db_connection.cursor() + # check if the record exists check_sql = """ SELECT 1 FROM raw_token_imbalances WHERE tx_hash = %s AND token_address = %s @@ -40,7 +78,6 @@ def record_exists( (psycopg2.Binary(tx_hash_bytes), psycopg2.Binary(token_address_bytes)), ) record_exists = cursor.fetchone() - return record_exists is not None except psycopg2.Error as e: logger.error("Error checking record existence: %s", e) @@ -51,8 +88,9 @@ def record_exists( def write_token_imbalances_to_db( chain_name: str, - write_db_connection: Any, + solver_slippage_db_connection: Any, auction_id: int, + block_number: int, tx_hash: str, token_address: str, imbalance: float, @@ -60,30 +98,31 @@ def write_token_imbalances_to_db( """ Write token imbalances to the database if the (tx_hash, token_address) combination does not already exist. """ + solver_slippage_db_connection = check_db_connection(solver_slippage_db_connection) tx_hash_bytes = bytes.fromhex(tx_hash[2:]) token_address_bytes = bytes.fromhex(token_address[2:]) - - if not record_exists(write_db_connection, tx_hash_bytes, token_address_bytes): + if not record_exists( + solver_slippage_db_connection, tx_hash_bytes, token_address_bytes + ): try: - cursor = write_db_connection.cursor() - # Convert hex strings to bytes for database insertion - + cursor = solver_slippage_db_connection.cursor() insert_sql = """ - INSERT INTO raw_token_imbalances (auction_id, chain_name, tx_hash, token_address, imbalance) - VALUES (%s, %s, %s, %s, %s) + INSERT INTO raw_token_imbalances (auction_id, chain_name, block_number, tx_hash, token_address, imbalance) + VALUES (%s, %s, %s, %s, %s, %s) """ cursor.execute( insert_sql, ( auction_id, chain_name, + block_number, psycopg2.Binary(tx_hash_bytes), psycopg2.Binary(token_address_bytes), imbalance, ), ) - write_db_connection.commit() - logger.info("Record inserted successfully.") + solver_slippage_db_connection.commit() + logger.debug("Record inserted successfully.") except psycopg2.Error as e: logger.error("Error inserting record: %s", e) finally: @@ -96,38 +135,55 @@ def write_token_imbalances_to_db( ) -def get_web3_instance(chain_name: str) -> Web3: +def get_start_block( + chain_name: str, solver_slippage_db_connection: Any, web3: Web3 +) -> int: """ - returns a Web3 instance for the given blockchain via chain name. + Retrieve the most recent block already present in raw_token_imbalances table, + delete entries for that block, and return this block number as start_block. + If no entries are present, fallback to get_finalized_block_number(). """ - return Web3(Web3.HTTPProvider(CHAIN_RPC_ENDPOINTS[chain_name])) - + try: + solver_slippage_db_connection = check_db_connection( + solver_slippage_db_connection + ) -def get_finalized_block_number(web3: Web3) -> int: - """ - Get the number of the most recent finalized block. - """ - return web3.eth.block_number - 64 + # query to get the maximum block number present in the table for the given chain_name + query_max_block = """ + SELECT MAX(block_number) FROM raw_token_imbalances + WHERE chain_name = %s + """ + cursor = solver_slippage_db_connection.cursor() + cursor.execute(query_max_block, (chain_name,)) + max_block = cursor.fetchone()[0] # Fetch the maximum block number + if max_block is not None: + logger.debug(f"Fetched max block number from database: {max_block}") + # If no entries present, fallback to get_finalized_block_number() + if max_block is None: + cursor.close() + return get_finalized_block_number(web3) -def fetch_transaction_hashes( - read_db_connection: Engine, start_block: int, end_block: int -) -> List[Tuple[str, int]]: - """Fetch transaction hashes beginning from start_block to end_block.""" - query = f""" - SELECT tx_hash, auction_id - FROM settlements - WHERE block_number >= {start_block} - AND block_number <= {end_block} - """ + # delete entries for the max block from the table + delete_sql = """ + DELETE FROM raw_token_imbalances WHERE chain_name = %s AND block_number = %s + """ + try: + cursor.execute(delete_sql, (chain_name, max_block)) + solver_slippage_db_connection.commit() + logger.debug(f"Successfully deleted entries for block number: {max_block}") + except Exception as e: + logger.debug(f"Failed to delete entries for block number {max_block}: {e}") - db_data = pd.read_sql(query, read_db_connection) - # converts hashes at memory location to hex - db_data["tx_hash"] = db_data["tx_hash"].apply(lambda x: f"0x{x.hex()}") + cursor.close() + return max_block - # return (tx hash, auction id) as tx_data - tx_data = [(row["tx_hash"], row["auction_id"]) for index, row in db_data.iterrows()] - return tx_data + except psycopg2.Error as e: + logger.error("Error accessing database: %s", e) + # Fallback to get_finalized_block_number() in case of any error + return get_finalized_block_number(web3) + finally: + solver_slippage_db_connection.close() # Close the database connection def process_transactions(chain_name: str) -> None: @@ -137,43 +193,48 @@ def process_transactions(chain_name: str) -> None: web3 = get_web3_instance(chain_name) rt = RawTokenImbalances(web3, chain_name) sleep_time = CHAIN_SLEEP_TIMES.get(chain_name) - read_db_connection = create_read_db_connection(chain_name) - write_db_connection = create_write_db_connection() - previous_block = get_finalized_block_number(web3) - unprocessed_txs: List[Tuple[str, int]] = [] - - logger.info("%s Daemon started.", chain_name) + backend_db_connection = create_backend_db_connection(chain_name) + solver_slippage_db_connection = create_solver_slippage_db_connection() + start_block = get_start_block(chain_name, solver_slippage_db_connection, web3) + previous_block = start_block + unprocessed_txs: List[Tuple[str, int, int]] = [] + logger.info("%s Daemon started. Start block: %d", chain_name, start_block) while True: try: latest_block = get_finalized_block_number(web3) - new_txs = fetch_transaction_hashes( - read_db_connection, previous_block, latest_block + new_txs = fetch_tx_data( + backend_db_connection, chain_name, previous_block, latest_block ) - # add any unprocessed hashes for processing, then clear list of unprocessed + # add any unprocessed txs for processing, then clear list of unprocessed all_txs = new_txs + unprocessed_txs unprocessed_txs.clear() - for tx, auction_id in all_txs: + for tx, auction_id, block_number in all_txs: logger.info("Processing transaction on %s: %s", chain_name, tx) try: imbalances = rt.compute_imbalances(tx) - logger.info("Token Imbalances on %s:", chain_name) + # append imbalances to a single log message + log_message = [f"Token Imbalances on {chain_name} for tx {tx}:"] for token_address, imbalance in imbalances.items(): - write_token_imbalances_to_db( - chain_name, - write_db_connection, - auction_id, - tx, - token_address, - imbalance, - ) - logger.info( - "Token: %s, Imbalance: %s", token_address, imbalance - ) + # ignore tokens that have null imbalances + if imbalance != 0: + write_token_imbalances_to_db( + chain_name, + solver_slippage_db_connection, + auction_id, + block_number, + tx, + token_address, + imbalance, + ) + log_message.append( + f"Token: {token_address}, Imbalance: {imbalance}" + ) + logger.info("\n".join(log_message)) except ValueError as e: logger.error("ValueError: %s", e) - unprocessed_txs.append((tx, auction_id)) + unprocessed_txs.append((tx, auction_id, block_number)) previous_block = latest_block + 1 except ConnectionError as e: From 01a0dc36026c4c2c11503d84ac9ccea50b221fda Mon Sep 17 00:00:00 2001 From: Shubh Agarwal Date: Mon, 8 Jul 2024 23:56:35 -0400 Subject: [PATCH 13/15] addressed comments --- src/balanceof_imbalances.py | 6 ++++-- src/config.py | 14 +++++++------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/src/balanceof_imbalances.py b/src/balanceof_imbalances.py index 400be6e..22e5b9d 100644 --- a/src/balanceof_imbalances.py +++ b/src/balanceof_imbalances.py @@ -92,10 +92,12 @@ def calculate_imbalances( prev_balances[token_address] is not None and final_balances[token_address] is not None ): + # need to ensure prev_balance and final_balance contain values + # to prevent subtraction from None prev_balance = prev_balances[token_address] - assert prev_balance is not None # Ensure prev_balance is not None + assert prev_balance is not None final_balance = final_balances[token_address] - assert final_balance is not None # Ensure final_balance is not None + assert final_balance is not None imbalance = final_balance - prev_balance imbalances[token_address] = imbalance return imbalances diff --git a/src/config.py b/src/config.py index 78269cc..5b0fda7 100644 --- a/src/config.py +++ b/src/config.py @@ -17,23 +17,23 @@ logger = get_logger("raw_token_imbalances") -def get_env_float(var_name: str) -> float: +def get_env_int(var_name: str) -> int: """ - Function for safe conversion to float (prevents None -> float conversion issues raised by mypy) - Retrieve environment variable and convert to float. Raise an error if not set. + Function for safe conversion to int (prevents None -> int conversion issues raised by mypy) + Retrieve environment variable and convert to int. Raise an error if not set. """ value = os.getenv(var_name) if value is None: raise ValueError(f"Environment variable {var_name} is not set.") try: - return float(value) + return int(value) except ValueError: - raise ValueError(f"Environment variable {var_name} must be a float.") + raise ValueError(f"Environment variable {var_name} must be a int.") CHAIN_SLEEP_TIMES = { - "Ethereum": get_env_float("ETHEREUM_SLEEP_TIME"), - "Gnosis": get_env_float("GNOSIS_SLEEP_TIME"), + "Ethereum": get_env_int("ETHEREUM_SLEEP_TIME"), + "Gnosis": get_env_int("GNOSIS_SLEEP_TIME"), } From c12ab0c283c4fc565571c427a6da83dfeb97f94e Mon Sep 17 00:00:00 2001 From: Shubh Agarwal Date: Tue, 9 Jul 2024 21:04:14 -0400 Subject: [PATCH 14/15] both dbs to sqlalchemy --- src/config.py | 41 ++++++------- src/daemon.py | 155 +++++++++++++++++++++++++------------------------- 2 files changed, 95 insertions(+), 101 deletions(-) diff --git a/src/config.py b/src/config.py index 5b0fda7..6614e1d 100644 --- a/src/config.py +++ b/src/config.py @@ -1,10 +1,9 @@ import os -import psycopg2 -from typing import Any, Optional +from typing import Optional +from sqlalchemy import text +from sqlalchemy.exc import OperationalError from sqlalchemy import create_engine, Engine from dotenv import load_dotenv -from urllib.parse import urlparse -from psycopg2.extensions import connection as Psycopg2Connection from src.helper_functions import get_logger @@ -50,33 +49,27 @@ def create_backend_db_connection(chain_name: str) -> Engine: return create_engine(f"postgresql+psycopg2://{read_db_url}") -def create_solver_slippage_db_connection() -> Psycopg2Connection: - """Function that creates a connection to the write database.""" - - parsed_url = urlparse(os.getenv("SOLVER_SLIPPAGE_DB_URL")) - - if not parsed_url.hostname or not parsed_url.path: - raise ValueError("Invalid or missing write database URL") +def create_solver_slippage_db_connection() -> Engine: + """function that creates a connection to the CoW db.""" + solver_db_url = os.getenv("SOLVER_SLIPPAGE_DB_URL") + if not solver_db_url: + raise ValueError( + "Solver slippage database URL not found in environment variables." + ) - # Connect to the database - solver_slippage_connection = psycopg2.connect( - database=parsed_url.path[1:], - user=parsed_url.username, - password=parsed_url.password, - host=parsed_url.hostname, - port=parsed_url.port, - ) - return solver_slippage_connection + return create_engine(f"postgresql+psycopg2://{solver_db_url}") -def check_db_connection(connection: Any, chain_name: Optional[str] = None) -> Any: +def check_db_connection(connection: Engine, chain_name: Optional[str] = None) -> Engine: """ Check if the database connection is still active. If not, create a new one. """ try: - if connection.closed: - raise psycopg2.OperationalError("Connection is closed") - except (psycopg2.OperationalError, AttributeError): + if connection: + with connection.connect() as conn: # Use connection.connect() to get a Connection object + conn.execute(text("SELECT 1")) + except OperationalError: + # if connection is closed, create new one connection = ( create_backend_db_connection(chain_name) if chain_name diff --git a/src/daemon.py b/src/daemon.py index 748a68d..62f6a36 100644 --- a/src/daemon.py +++ b/src/daemon.py @@ -3,11 +3,11 @@ """ import time -from typing import List, Tuple, Any +from typing import List, Tuple from threading import Thread -import psycopg2 import pandas as pd from web3 import Web3 +from sqlalchemy import text from sqlalchemy.engine import Engine from src.imbalances_script import RawTokenImbalances from src.config import ( @@ -58,37 +58,35 @@ def fetch_tx_data( def record_exists( - solver_slippage_db_connection: Any, + solver_slippage_db_engine: Engine, tx_hash_bytes: bytes, token_address_bytes: bytes, ) -> bool: """ Check if a record with the given (tx_hash, token_address) already exists in the database. """ - solver_slippage_db_connection = check_db_connection(solver_slippage_db_connection) - try: - cursor = solver_slippage_db_connection.cursor() - # check if the record exists - check_sql = """ - SELECT 1 FROM raw_token_imbalances - WHERE tx_hash = %s AND token_address = %s + solver_slippage_db_engine = check_db_connection(solver_slippage_db_engine) + query = text( """ - cursor.execute( - check_sql, - (psycopg2.Binary(tx_hash_bytes), psycopg2.Binary(token_address_bytes)), - ) - record_exists = cursor.fetchone() - return record_exists is not None - except psycopg2.Error as e: + SELECT 1 FROM raw_token_imbalances + WHERE tx_hash = :tx_hash AND token_address = :token_address + """ + ) + try: + with solver_slippage_db_engine.connect() as connection: + result = connection.execute( + query, {"tx_hash": tx_hash_bytes, "token_address": token_address_bytes} + ) + record_exists = result.fetchone() is not None + return record_exists + except Exception as e: logger.error("Error checking record existence: %s", e) return False - finally: - cursor.close() def write_token_imbalances_to_db( chain_name: str, - solver_slippage_db_connection: Any, + solver_slippage_db_engine: Engine, auction_id: int, block_number: int, tx_hash: str, @@ -98,35 +96,33 @@ def write_token_imbalances_to_db( """ Write token imbalances to the database if the (tx_hash, token_address) combination does not already exist. """ - solver_slippage_db_connection = check_db_connection(solver_slippage_db_connection) + solver_slippage_db_engine = check_db_connection(solver_slippage_db_engine) tx_hash_bytes = bytes.fromhex(tx_hash[2:]) token_address_bytes = bytes.fromhex(token_address[2:]) - if not record_exists( - solver_slippage_db_connection, tx_hash_bytes, token_address_bytes - ): - try: - cursor = solver_slippage_db_connection.cursor() - insert_sql = """ - INSERT INTO raw_token_imbalances (auction_id, chain_name, block_number, tx_hash, token_address, imbalance) - VALUES (%s, %s, %s, %s, %s, %s) + if not record_exists(solver_slippage_db_engine, tx_hash_bytes, token_address_bytes): + insert_sql = text( """ - cursor.execute( - insert_sql, - ( - auction_id, - chain_name, - block_number, - psycopg2.Binary(tx_hash_bytes), - psycopg2.Binary(token_address_bytes), - imbalance, - ), - ) - solver_slippage_db_connection.commit() - logger.debug("Record inserted successfully.") - except psycopg2.Error as e: + INSERT INTO raw_token_imbalances (auction_id, chain_name, block_number, tx_hash, token_address, imbalance) + VALUES (:auction_id, :chain_name, :block_number, :tx_hash, :token_address, :imbalance) + """ + ) + try: + with solver_slippage_db_engine.connect() as connection: + connection.execute( + insert_sql, + { + "auction_id": auction_id, + "chain_name": chain_name, + "block_number": block_number, + "tx_hash": tx_hash_bytes, + "token_address": token_address_bytes, + "imbalance": imbalance, + }, + ) + connection.commit() + logger.debug("Record inserted successfully.") + except Exception as e: logger.error("Error inserting record: %s", e) - finally: - cursor.close() else: logger.info( "Record with tx_hash %s and token_address %s already exists.", @@ -136,7 +132,7 @@ def write_token_imbalances_to_db( def get_start_block( - chain_name: str, solver_slippage_db_connection: Any, web3: Web3 + chain_name: str, solver_slippage_db_engine: Engine, web3: Web3 ) -> int: """ Retrieve the most recent block already present in raw_token_imbalances table, @@ -144,46 +140,51 @@ def get_start_block( If no entries are present, fallback to get_finalized_block_number(). """ try: - solver_slippage_db_connection = check_db_connection( - solver_slippage_db_connection - ) + solver_slippage_db_engine = check_db_connection(solver_slippage_db_engine) - # query to get the maximum block number present in the table for the given chain_name - query_max_block = """ + query_max_block = text( + """ SELECT MAX(block_number) FROM raw_token_imbalances - WHERE chain_name = %s + WHERE chain_name = :chain_name """ + ) - cursor = solver_slippage_db_connection.cursor() - cursor.execute(query_max_block, (chain_name,)) - max_block = cursor.fetchone()[0] # Fetch the maximum block number - if max_block is not None: - logger.debug(f"Fetched max block number from database: {max_block}") - # If no entries present, fallback to get_finalized_block_number() - if max_block is None: - cursor.close() - return get_finalized_block_number(web3) + with solver_slippage_db_engine.connect() as connection: + result = connection.execute(query_max_block, {"chain_name": chain_name}) + row = result.fetchone() + max_block = ( + row[0] if row is not None else None + ) # Fetch the maximum block number + if max_block is not None: + logger.debug("Fetched max block number from database: %d", max_block) - # delete entries for the max block from the table - delete_sql = """ - DELETE FROM raw_token_imbalances WHERE chain_name = %s AND block_number = %s - """ - try: - cursor.execute(delete_sql, (chain_name, max_block)) - solver_slippage_db_connection.commit() - logger.debug(f"Successfully deleted entries for block number: {max_block}") - except Exception as e: - logger.debug(f"Failed to delete entries for block number {max_block}: {e}") + # If no entries present, fallback to get_finalized_block_number() + if max_block is None: + return get_finalized_block_number(web3) - cursor.close() - return max_block + # delete entries for the max block from the table + delete_sql = text( + """ + DELETE FROM raw_token_imbalances WHERE chain_name = :chain_name AND block_number = :block_number + """ + ) + try: + connection.execute( + delete_sql, {"chain_name": chain_name, "block_number": max_block} + ) + connection.commit() + logger.debug( + "Successfully deleted entries for block number: %s", max_block + ) + except Exception as e: + logger.debug( + "Failed to delete entries for block number %s: %s", max_block, e + ) - except psycopg2.Error as e: + return max_block + except Exception as e: logger.error("Error accessing database: %s", e) - # Fallback to get_finalized_block_number() in case of any error return get_finalized_block_number(web3) - finally: - solver_slippage_db_connection.close() # Close the database connection def process_transactions(chain_name: str) -> None: From dc64020420826cf49e86d837b17d434d9a6ea08c Mon Sep 17 00:00:00 2001 From: Shubh Agarwal Date: Tue, 9 Jul 2024 23:00:58 -0400 Subject: [PATCH 15/15] stdout, no threads --- .env.sample | 24 ++++++++++++------------ src/balanceof_imbalances.py | 10 +++++----- src/config.py | 21 +++++++++------------ src/daemon.py | 34 ++++++++++++++-------------------- src/helper_functions.py | 28 ++++++++++++++++++++++++++-- 5 files changed, 66 insertions(+), 51 deletions(-) diff --git a/.env.sample b/.env.sample index a2922e8..3def4ce 100644 --- a/.env.sample +++ b/.env.sample @@ -1,19 +1,19 @@ # .env.sample -# URLs for DB connection -ETHEREUM_DB_URL= -GNOSIS_DB_URL= +# DB connection +DB_URL= -# URLs for Node provider connection -ETHEREUM_NODE_URL= -GNOSIS_NODE_URL= +# Node provider connection +NODE_URL= -# add credentials for connecting to solver slippage DB based on this format -SOLVER_SLIPPAGE_DB_URL=postgresql://username:password@hostname:port/database +# connecting to Solver Slippage DB +SOLVER_SLIPPAGE_DB_URL= -# configure chain sleep time -ETHEREUM_SLEEP_TIME= -GNOSIS_SLEEP_TIME= +# configure chain sleep time, e.g. CHAIN_SLEEP_TIME=60 +CHAIN_SLEEP_TIME= + +# add chain name, e.g. CHAIN_NAME=Ethereum +CHAIN_NAME= # optional -INFURA_KEY=infura_key_here +INFURA_KEY= diff --git a/src/balanceof_imbalances.py b/src/balanceof_imbalances.py index 22e5b9d..c78f965 100644 --- a/src/balanceof_imbalances.py +++ b/src/balanceof_imbalances.py @@ -1,8 +1,8 @@ from web3 import Web3 from web3.types import TxReceipt, HexStr from eth_typing import ChecksumAddress -from typing import Dict, Optional, Set, Any -from src.config import ETHEREUM_NODE_URL +from typing import Dict, Optional, Set +from src.config import NODE_URL from src.constants import SETTLEMENT_CONTRACT_ADDRESS, NATIVE_ETH_TOKEN_ADDRESS from contracts.erc20_abi import erc20_abi @@ -10,8 +10,8 @@ class BalanceOfImbalances: - def __init__(self, ETHEREUM_NODE_URL: str): - self.web3 = Web3(Web3.HTTPProvider(ETHEREUM_NODE_URL)) + def __init__(self, NODE_URL: str): + self.web3 = Web3(Web3.HTTPProvider(NODE_URL)) def get_token_balance( self, @@ -124,7 +124,7 @@ def compute_imbalances(self, tx_hash: HexStr) -> Dict[ChecksumAddress, int]: def main(): tx_hash = input("Enter transaction hash: ") - bo = BalanceOfImbalances(ETHEREUM_NODE_URL) + bo = BalanceOfImbalances(NODE_URL) imbalances = bo.compute_imbalances(tx_hash) print("Token Imbalances:") for token_address, imbalance in imbalances.items(): diff --git a/src/config.py b/src/config.py index 6614e1d..4cdf822 100644 --- a/src/config.py +++ b/src/config.py @@ -8,13 +8,16 @@ load_dotenv() -ETHEREUM_NODE_URL = os.getenv("ETHEREUM_NODE_URL") -GNOSIS_NODE_URL = os.getenv("GNOSIS_NODE_URL") - -CHAIN_RPC_ENDPOINTS = {"Ethereum": ETHEREUM_NODE_URL, "Gnosis": GNOSIS_NODE_URL} +NODE_URL = os.getenv("NODE_URL") logger = get_logger("raw_token_imbalances") +# Utilized by imbalances_script for computing for single tx hash +CHAIN_RPC_ENDPOINTS = { + "Ethereum": os.getenv("ETHEREUM_NODE_URL"), + "Gnosis": os.getenv("GNOSIS_NODE_URL"), +} + def get_env_int(var_name: str) -> int: """ @@ -30,18 +33,12 @@ def get_env_int(var_name: str) -> int: raise ValueError(f"Environment variable {var_name} must be a int.") -CHAIN_SLEEP_TIMES = { - "Ethereum": get_env_int("ETHEREUM_SLEEP_TIME"), - "Gnosis": get_env_int("GNOSIS_SLEEP_TIME"), -} +CHAIN_SLEEP_TIME = get_env_int("CHAIN_SLEEP_TIME") def create_backend_db_connection(chain_name: str) -> Engine: """function that creates a connection to the CoW db.""" - if chain_name == "Ethereum": - read_db_url = os.getenv("ETHEREUM_DB_URL") - elif chain_name == "Gnosis": - read_db_url = os.getenv("GNOSIS_DB_URL") + read_db_url = os.getenv("DB_URL") if not read_db_url: raise ValueError(f"No database URL found for chain: {chain_name}") diff --git a/src/daemon.py b/src/daemon.py index 62f6a36..b3595b6 100644 --- a/src/daemon.py +++ b/src/daemon.py @@ -1,18 +1,17 @@ """ Running this daemon computes raw imbalances for finalized blocks by calling imbalances_script.py. """ - +import os import time from typing import List, Tuple -from threading import Thread import pandas as pd from web3 import Web3 from sqlalchemy import text from sqlalchemy.engine import Engine from src.imbalances_script import RawTokenImbalances from src.config import ( - CHAIN_RPC_ENDPOINTS, - CHAIN_SLEEP_TIMES, + CHAIN_SLEEP_TIME, + NODE_URL, create_backend_db_connection, create_solver_slippage_db_connection, check_db_connection, @@ -20,11 +19,11 @@ ) -def get_web3_instance(chain_name: str) -> Web3: +def get_web3_instance() -> Web3: """ returns a Web3 instance for the given blockchain via chain name. """ - return Web3(Web3.HTTPProvider(CHAIN_RPC_ENDPOINTS[chain_name])) + return Web3(Web3.HTTPProvider(NODE_URL)) def get_finalized_block_number(web3: Web3) -> int: @@ -191,9 +190,8 @@ def process_transactions(chain_name: str) -> None: """ Process transactions to compute imbalances for a given blockchain via chain name. """ - web3 = get_web3_instance(chain_name) + web3 = get_web3_instance() rt = RawTokenImbalances(web3, chain_name) - sleep_time = CHAIN_SLEEP_TIMES.get(chain_name) backend_db_connection = create_backend_db_connection(chain_name) solver_slippage_db_connection = create_solver_slippage_db_connection() start_block = get_start_block(chain_name, solver_slippage_db_connection, web3) @@ -244,23 +242,19 @@ def process_transactions(chain_name: str) -> None: ) except Exception as e: logger.error("Error processing transactions on %s: %s", chain_name, e) - if sleep_time is not None: - time.sleep(sleep_time) + if CHAIN_SLEEP_TIME is not None: + time.sleep(CHAIN_SLEEP_TIME) def main() -> None: """ - Main function to start the daemon threads for each blockchain. + Main function to start the daemon for a blockchain. """ - threads = [] - - for chain_name in CHAIN_RPC_ENDPOINTS.keys(): - thread = Thread(target=process_transactions, args=(chain_name,), daemon=True) - thread.start() - threads.append(thread) - - for thread in threads: - thread.join() + chain_name = os.getenv("CHAIN_NAME") + if chain_name is None: + logger.error("CHAIN_NAME environment variable is not set.") + return + process_transactions(chain_name) if __name__ == "__main__": diff --git a/src/helper_functions.py b/src/helper_functions.py index a211307..07cf062 100644 --- a/src/helper_functions.py +++ b/src/helper_functions.py @@ -2,6 +2,7 @@ This file contains some auxiliary functions """ from __future__ import annotations +import sys import logging from typing import Optional @@ -10,13 +11,36 @@ def get_logger(filename: Optional[str] = None) -> logging.Logger: """ get_logger() returns a logger object that can write to a file, terminal or only file if needed. """ - logging.basicConfig(format="%(levelname)s - %(message)s") logger = logging.getLogger() logger.setLevel(logging.INFO) + + # Clear any existing handlers to avoid duplicate logs + if logger.hasHandlers(): + logger.handlers.clear() + + # Create formatter + formatter = logging.Formatter("%(levelname)s - %(message)s") + + # Handler for stdout (INFO and lower) + stdout_handler = logging.StreamHandler(sys.stdout) + stdout_handler.setLevel(logging.INFO) + stdout_handler.setFormatter(formatter) + + # ERROR and above logs will not be logged to stdout + stdout_handler.addFilter(lambda record: record.levelno < logging.ERROR) + + # Handler for stderr (ERROR and higher) + stderr_handler = logging.StreamHandler(sys.stderr) + stderr_handler.setLevel(logging.ERROR) + stderr_handler.setFormatter(formatter) + + # Add handlers to the logger + logger.addHandler(stdout_handler) + logger.addHandler(stderr_handler) + if filename: file_handler = logging.FileHandler(filename + ".log", mode="w") file_handler.setLevel(logging.INFO) - formatter = logging.Formatter("%(levelname)s - %(message)s") file_handler.setFormatter(formatter) logger.addHandler(file_handler)