Skip to content

Commit

Permalink
Merge pull request #11 from cowprotocol/db-fix
Browse files Browse the repository at this point in the history
addressed issue #7 & trace handling
  • Loading branch information
harisang authored Jul 11, 2024
2 parents bd08e86 + d00a751 commit 72e101d
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 90 deletions.
41 changes: 16 additions & 25 deletions src/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
from typing import Optional
from sqlalchemy import text
from sqlalchemy.exc import OperationalError
from sqlalchemy import create_engine, Engine
Expand All @@ -18,6 +17,11 @@
"Gnosis": os.getenv("GNOSIS_NODE_URL"),
}

CREATE_DB_URLS = {
"backend": os.getenv("DB_URL"),
"solver_slippage": os.getenv("SOLVER_SLIPPAGE_DB_URL"),
}


def get_env_int(var_name: str) -> int:
"""
Expand All @@ -36,28 +40,19 @@ def get_env_int(var_name: str) -> int:
CHAIN_SLEEP_TIME = get_env_int("CHAIN_SLEEP_TIME")


def create_backend_db_connection(chain_name: str) -> Engine:
"""function that creates a connection to the CoW db."""
read_db_url = os.getenv("DB_URL")

if not read_db_url:
raise ValueError(f"No database URL found for chain: {chain_name}")

return create_engine(f"postgresql+psycopg2://{read_db_url}")


def create_solver_slippage_db_connection() -> Engine:
"""function that creates a connection to the CoW db."""
solver_db_url = os.getenv("SOLVER_SLIPPAGE_DB_URL")
if not solver_db_url:
raise ValueError(
"Solver slippage database URL not found in environment variables."
)
def create_db_connection(db_type: str) -> Engine:
"""
Function that creates a connection to the specified database.
db_type should be either "backend" or "solver_slippage".
"""
db_url = CREATE_DB_URLS.get(db_type)
if not db_url:
raise ValueError(f"{db_type} database URL not found in environment variables.")

return create_engine(f"postgresql+psycopg2://{solver_db_url}")
return create_engine(f"postgresql+psycopg2://{db_url}")


def check_db_connection(connection: Engine, chain_name: Optional[str] = None) -> Engine:
def check_db_connection(connection: Engine, db_type: str) -> Engine:
"""
Check if the database connection is still active. If not, create a new one.
"""
Expand All @@ -67,9 +62,5 @@ def check_db_connection(connection: Engine, chain_name: Optional[str] = None) ->
conn.execute(text("SELECT 1"))
except OperationalError:
# if connection is closed, create new one
connection = (
create_backend_db_connection(chain_name)
if chain_name
else create_solver_slippage_db_connection()
)
connection = create_db_connection(db_type)
return connection
92 changes: 52 additions & 40 deletions src/daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
from src.config import (
CHAIN_SLEEP_TIME,
NODE_URL,
create_backend_db_connection,
create_solver_slippage_db_connection,
create_db_connection,
check_db_connection,
logger,
)
Expand All @@ -34,10 +33,12 @@ def get_finalized_block_number(web3: Web3) -> int:


def fetch_tx_data(
backend_db_connection: Engine, chain_name: str, start_block: int, end_block: int
backend_db_connection: Engine, start_block: int, end_block: int
) -> List[Tuple[str, int, int]]:
"""Fetch transaction hashes beginning from start_block to end_block."""
backend_db_connection = check_db_connection(backend_db_connection, chain_name)

backend_db_connection = check_db_connection(backend_db_connection, "backend")

query = f"""
SELECT tx_hash, auction_id, block_number
FROM settlements
Expand All @@ -57,22 +58,25 @@ def fetch_tx_data(


def record_exists(
solver_slippage_db_engine: Engine,
solver_slippage_connection: Engine,
tx_hash_bytes: bytes,
token_address_bytes: bytes,
) -> bool:
"""
Check if a record with the given (tx_hash, token_address) already exists in the database.
"""
solver_slippage_db_engine = check_db_connection(solver_slippage_db_engine)
solver_slippage_connection = check_db_connection(
solver_slippage_connection, "solver_slippage"
)

query = text(
"""
SELECT 1 FROM raw_token_imbalances
WHERE tx_hash = :tx_hash AND token_address = :token_address
"""
)
try:
with solver_slippage_db_engine.connect() as connection:
with solver_slippage_connection.connect() as connection:
result = connection.execute(
query, {"tx_hash": tx_hash_bytes, "token_address": token_address_bytes}
)
Expand All @@ -85,7 +89,7 @@ def record_exists(

def write_token_imbalances_to_db(
chain_name: str,
solver_slippage_db_engine: Engine,
solver_slippage_connection: Engine,
auction_id: int,
block_number: int,
tx_hash: str,
Expand All @@ -95,18 +99,23 @@ def write_token_imbalances_to_db(
"""
Write token imbalances to the database if the (tx_hash, token_address) combination does not already exist.
"""
solver_slippage_db_engine = check_db_connection(solver_slippage_db_engine)
solver_slippage_connection = check_db_connection(
solver_slippage_connection, "solver_slippage"
)

tx_hash_bytes = bytes.fromhex(tx_hash[2:])
token_address_bytes = bytes.fromhex(token_address[2:])
if not record_exists(solver_slippage_db_engine, tx_hash_bytes, token_address_bytes):
if not record_exists(
solver_slippage_connection, tx_hash_bytes, token_address_bytes
):
insert_sql = text(
"""
INSERT INTO raw_token_imbalances (auction_id, chain_name, block_number, tx_hash, token_address, imbalance)
VALUES (:auction_id, :chain_name, :block_number, :tx_hash, :token_address, :imbalance)
"""
)
try:
with solver_slippage_db_engine.connect() as connection:
with solver_slippage_connection.connect() as connection:
connection.execute(
insert_sql,
{
Expand All @@ -131,15 +140,17 @@ def write_token_imbalances_to_db(


def get_start_block(
chain_name: str, solver_slippage_db_engine: Engine, web3: Web3
chain_name: str, solver_slippage_connection: Engine, web3: Web3
) -> int:
"""
Retrieve the most recent block already present in raw_token_imbalances table,
delete entries for that block, and return this block number as start_block.
If no entries are present, fallback to get_finalized_block_number().
"""
try:
solver_slippage_db_engine = check_db_connection(solver_slippage_db_engine)
solver_slippage_connection = check_db_connection(
solver_slippage_connection, "solver_slippage"
)

query_max_block = text(
"""
Expand All @@ -148,7 +159,7 @@ def get_start_block(
"""
)

with solver_slippage_db_engine.connect() as connection:
with solver_slippage_connection.connect() as connection:
result = connection.execute(query_max_block, {"chain_name": chain_name})
row = result.fetchone()
max_block = (
Expand Down Expand Up @@ -176,7 +187,7 @@ def get_start_block(
"Successfully deleted entries for block number: %s", max_block
)
except Exception as e:
logger.debug(
logger.error(
"Failed to delete entries for block number %s: %s", max_block, e
)

Expand All @@ -192,8 +203,8 @@ def process_transactions(chain_name: str) -> None:
"""
web3 = get_web3_instance()
rt = RawTokenImbalances(web3, chain_name)
backend_db_connection = create_backend_db_connection(chain_name)
solver_slippage_db_connection = create_solver_slippage_db_connection()
backend_db_connection = create_db_connection("backend")
solver_slippage_db_connection = create_db_connection("solver_slippage")
start_block = get_start_block(chain_name, solver_slippage_db_connection, web3)
previous_block = start_block
unprocessed_txs: List[Tuple[str, int, int]] = []
Expand All @@ -202,46 +213,47 @@ def process_transactions(chain_name: str) -> None:
while True:
try:
latest_block = get_finalized_block_number(web3)
new_txs = fetch_tx_data(
backend_db_connection, chain_name, previous_block, latest_block
)
# add any unprocessed txs for processing, then clear list of unprocessed
new_txs = fetch_tx_data(backend_db_connection, previous_block, latest_block)
# Add any unprocessed txs for processing, then clear list of unprocessed
all_txs = new_txs + unprocessed_txs
unprocessed_txs.clear()

for tx, auction_id, block_number in all_txs:
logger.info("Processing transaction on %s: %s", chain_name, tx)
try:
imbalances = rt.compute_imbalances(tx)
# append imbalances to a single log message
log_message = [f"Token Imbalances on {chain_name} for tx {tx}:"]
for token_address, imbalance in imbalances.items():
# ignore tokens that have null imbalances
if imbalance != 0:
write_token_imbalances_to_db(
chain_name,
solver_slippage_db_connection,
auction_id,
block_number,
tx,
token_address,
imbalance,
)
log_message.append(
f"Token: {token_address}, Imbalance: {imbalance}"
)
logger.info("\n".join(log_message))
# Append imbalances to a single log message
if imbalances is not None:
log_message = [f"Token Imbalances on {chain_name} for tx {tx}:"]
for token_address, imbalance in imbalances.items():
# Ignore tokens that have null imbalances
if imbalance != 0:
write_token_imbalances_to_db(
chain_name,
solver_slippage_db_connection,
auction_id,
block_number,
tx,
token_address,
imbalance,
)
log_message.append(
f"Token: {token_address}, Imbalance: {imbalance}"
)
logger.info("\n".join(log_message))
else:
raise ValueError("Imbalances computation returned None.")
except ValueError as e:
logger.error("ValueError: %s", e)
unprocessed_txs.append((tx, auction_id, block_number))

previous_block = latest_block + 1
except ConnectionError as e:
logger.error(
"Connection error processing transactions on %s: %s", chain_name, e
)
except Exception as e:
logger.error("Error processing transactions on %s: %s", chain_name, e)

if CHAIN_SLEEP_TIME is not None:
time.sleep(CHAIN_SLEEP_TIME)

Expand Down
62 changes: 37 additions & 25 deletions src/imbalances_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
9. update_sdai_imbalance() is called in each iteration and only completes if there is an SDAI
transfer involved which has special handling for its events.
"""

from typing import Dict, List, Optional, Tuple

from web3 import Web3
Expand Down Expand Up @@ -291,35 +292,43 @@ def update_sdai_imbalance(
if event["address"] == SDAI_TOKEN_ADDRESS:
self.process_sdai_event(event, imbalances, is_deposit=False)

def compute_imbalances(self, tx_hash: str) -> Dict[str, int]:
"""Compute token imbalances for a given transaction hash."""
tx_receipt = self.get_transaction_receipt(tx_hash)
if tx_receipt is None:
raise ValueError(
f"Transaction hash {tx_hash} not found on chain {self.chain_name}."
)
# find trace and actions from trace to track native ETH events
traces = self.get_transaction_trace(tx_hash)
native_eth_imbalance = None
actions = []
if traces is not None:
def compute_imbalances(self, tx_hash: str) -> Optional[Dict[str, int]]:
try:
tx_receipt = self.get_transaction_receipt(tx_hash)
if not tx_receipt:
logger.error("No transaction receipt found for %s", tx_hash)
return None

traces = self.get_transaction_trace(tx_hash)
if traces is None:
logger.error(
"Error fetching transaction trace for %s. Marking transaction as unprocessed.",
tx_hash,
)
return None

events = self.extract_events(tx_receipt)
imbalances = self.calculate_imbalances(events, SETTLEMENT_CONTRACT_ADDRESS)

native_eth_imbalance = None
actions = []
actions = self.extract_actions(traces, SETTLEMENT_CONTRACT_ADDRESS)
native_eth_imbalance = self.calculate_native_eth_imbalance(
actions, SETTLEMENT_CONTRACT_ADDRESS
)

events = self.extract_events(tx_receipt)
imbalances = self.calculate_imbalances(events, SETTLEMENT_CONTRACT_ADDRESS)

if actions:
self.update_weth_imbalance(
events, actions, imbalances, SETTLEMENT_CONTRACT_ADDRESS
)
self.update_native_eth_imbalance(imbalances, native_eth_imbalance)
if actions:
self.update_weth_imbalance(
events, actions, imbalances, SETTLEMENT_CONTRACT_ADDRESS
)
self.update_native_eth_imbalance(imbalances, native_eth_imbalance)

self.update_sdai_imbalance(events, imbalances)
self.update_sdai_imbalance(events, imbalances)
return imbalances

return imbalances
except Exception as e:
logger.error("Error computing imbalances for %s: %s", tx_hash, e)
return None


def main() -> None:
Expand All @@ -329,9 +338,12 @@ def main() -> None:
rt = RawTokenImbalances(web3, chain_name)
try:
imbalances = rt.compute_imbalances(tx_hash)
logger.info(f"Token Imbalances on {chain_name}:")
for token_address, imbalance in imbalances.items():
logger.info(f"Token: {token_address}, Imbalance: {imbalance}")
if imbalances is not None:
logger.info(f"Token Imbalances on {chain_name}:")
for token_address, imbalance in imbalances.items():
logger.info(f"Token: {token_address}, Imbalance: {imbalance}")
else:
raise ValueError("Imbalances computation returned None.")
except ValueError as e:
logger.error(e)

Expand Down

0 comments on commit 72e101d

Please sign in to comment.