diff --git a/prediction_market_agent/agents/microchain_agent/functions.py b/prediction_market_agent/agents/microchain_agent/functions.py index dd74e9fe..07ca21d9 100644 --- a/prediction_market_agent/agents/microchain_agent/functions.py +++ b/prediction_market_agent/agents/microchain_agent/functions.py @@ -2,17 +2,24 @@ import typing as t from decimal import Decimal +from eth_utils import to_checksum_address from microchain import Function from prediction_market_agent_tooling.markets.data_models import BetAmount, Currency from prediction_market_agent_tooling.markets.omen.data_models import ( OMEN_FALSE_OUTCOME, OMEN_TRUE_OUTCOME, + OmenUserPosition, get_boolean_outcome, ) from prediction_market_agent_tooling.markets.omen.omen import OmenAgentMarket +from prediction_market_agent_tooling.markets.omen.omen_subgraph_handler import ( + OmenSubgraphHandler, +) +from prediction_market_agent_tooling.tools.balances import get_balances from prediction_market_agent.agents.microchain_agent.utils import ( MicroMarket, + fetch_public_key_from_env, get_omen_binary_market_from_question, get_omen_binary_markets, get_omen_market_token_balance, @@ -100,6 +107,7 @@ def __call__(self) -> float: class BuyTokens(Function): def __init__(self, outcome: str): + self.user_address = fetch_public_key_from_env() self.outcome = outcome super().__init__() @@ -115,14 +123,23 @@ def __call__(self, market: str, amount: float) -> str: outcome_bool = get_boolean_outcome(self.outcome) market_obj: OmenAgentMarket = get_omen_binary_market_from_question(market) + outcome_index = market_obj.get_outcome_index(self.outcome) + market_index_set = outcome_index + 1 + before_balance = get_omen_market_token_balance( - market=market_obj, outcome=outcome_bool + user_address=self.user_address, + market_condition_id=market_obj.condition.id, + market_index_set=market_index_set, ) market_obj.place_bet( outcome_bool, BetAmount(amount=Decimal(amount), currency=Currency.xDai) ) tokens = ( - get_omen_market_token_balance(market=market_obj, outcome=outcome_bool) + get_omen_market_token_balance( + user_address=self.user_address, + market_condition_id=market_obj.condition.id, + market_index_set=market_index_set, + ) - before_balance ) return f"Bought {tokens} {self.outcome} outcome tokens of: {market}" @@ -209,6 +226,39 @@ def __call__(self, summary: str) -> str: return summary +class GetWalletBalance(Function): + @property + def description(self) -> str: + return "Use this function to fetch your balance, given in xDAI units." + + @property + def example_args(self) -> list[str]: + return [] + + def __call__(self, user_address: str) -> Decimal: + # We focus solely on xDAI balance for now to avoid the agent having to wrap/unwrap xDAI. + user_address_checksummed = to_checksum_address(user_address) + balance = get_balances(user_address_checksummed) + return balance.xdai + + +class GetUserPositions(Function): + @property + def description(self) -> str: + return ( + "Use this function to fetch the markets where the user has previously bet." + ) + + @property + def example_args(self) -> list[str]: + return ["0x2DD9f5678484C1F59F97eD334725858b938B4102"] + + def __call__(self, user_address: str) -> list[OmenUserPosition]: + return OmenSubgraphHandler().get_user_positions( + better_address=to_checksum_address(user_address) + ) + + ALL_FUNCTIONS = [ Sum, Product, @@ -221,4 +271,6 @@ def __call__(self, summary: str) -> str: SellNo, # BalanceToOutcomes, SummarizeLearning, + GetWalletBalance, + GetUserPositions, ] diff --git a/prediction_market_agent/agents/microchain_agent/utils.py b/prediction_market_agent/agents/microchain_agent/utils.py index f30530cd..dabea3e7 100644 --- a/prediction_market_agent/agents/microchain_agent/utils.py +++ b/prediction_market_agent/agents/microchain_agent/utils.py @@ -1,12 +1,23 @@ +import os from typing import List, cast +from eth_typing import ChecksumAddress from prediction_market_agent_tooling.markets.agent_market import ( AgentMarket, FilterBy, SortBy, ) from prediction_market_agent_tooling.markets.omen.omen import OmenAgentMarket -from pydantic import BaseModel +from prediction_market_agent_tooling.markets.omen.omen_contracts import ( + OmenConditionalTokenContract, +) +from prediction_market_agent_tooling.markets.omen.omen_subgraph_handler import ( + OmenSubgraphHandler, +) +from prediction_market_agent_tooling.tools.hexbytes_custom import HexBytes +from prediction_market_agent_tooling.tools.web3_utils import private_key_to_public_key +from pydantic import BaseModel, SecretStr +from web3.types import Wei class MicroMarket(BaseModel): @@ -42,6 +53,22 @@ def get_omen_binary_market_from_question(market: str) -> OmenAgentMarket: raise ValueError(f"Market '{market}' not found") -def get_omen_market_token_balance(market: OmenAgentMarket, outcome: bool) -> float: - # TODO implement this - return 7.3 +def get_omen_market_token_balance( + user_address: ChecksumAddress, market_condition_id: HexBytes, market_index_set: int +) -> Wei: + # We get the multiple positions for each market + positions = OmenSubgraphHandler().get_positions(market_condition_id) + # Find position matching market_outcome + position_for_index_set = next( + p for p in positions if market_index_set in p.indexSets + ) + position_as_int = int(position_for_index_set.id.hex(), 16) + balance = OmenConditionalTokenContract().balanceOf(user_address, position_as_int) + return balance + + +def fetch_public_key_from_env() -> ChecksumAddress: + private_key = os.environ.get("BET_FROM_PRIVATE_KEY") + if private_key is None: + raise EnvironmentError("Could not load private key using env var") + return private_key_to_public_key(SecretStr(private_key)) diff --git a/tests/agents/test_microchain.py b/tests/agents/test_microchain.py index cc024abb..90b7dfd9 100644 --- a/tests/agents/test_microchain.py +++ b/tests/agents/test_microchain.py @@ -1,15 +1,37 @@ +import typing as t + import pytest +from dotenv import load_dotenv +from eth_typing import HexAddress, HexStr +from prediction_market_agent_tooling.markets.omen.omen import OmenAgentMarket +from prediction_market_agent_tooling.markets.omen.omen_subgraph_handler import ( + OmenSubgraphHandler, +) +from prediction_market_agent_tooling.tools.hexbytes_custom import HexBytes +from web3 import Web3 from prediction_market_agent.agents.microchain_agent.functions import ( BuyNo, BuyYes, GetMarkets, + GetUserPositions, + GetWalletBalance, ) from prediction_market_agent.agents.microchain_agent.utils import ( get_omen_binary_markets, + get_omen_market_token_balance, ) from tests.utils import RUN_PAID_TESTS +REPLICATOR_ADDRESS = "0x993DFcE14768e4dE4c366654bE57C21D9ba54748" +AGENT_0_ADDRESS = "0x2DD9f5678484C1F59F97eD334725858b938B4102" + + +@pytest.fixture(scope="session", autouse=True) +def before_all_tests() -> t.Generator[None, None, None]: + load_dotenv() + yield None + def test_get_markets() -> None: get_markets = GetMarkets() @@ -28,3 +50,44 @@ def test_buy_no() -> None: market = get_omen_binary_markets()[0] buy_yes = BuyNo() print(buy_yes(market.question, 0.0001)) + + +def test_replicator_has_balance_gt_0() -> None: + balance = GetWalletBalance()(REPLICATOR_ADDRESS) + assert balance > 0 + + +def test_agent_0_has_bet_on_market() -> None: + user_positions = GetUserPositions()(AGENT_0_ADDRESS) + # Assert 3 conditionIds are included + expected_condition_ids = [ + HexBytes("0x9c7711bee0902cc8e6838179058726a7ba769cc97d4d0ea47b31370d2d7a117b"), + HexBytes("0xe2bf80af2a936cdabeef4f511620a2eec46f1caf8e75eb5dc189372367a9154c"), + HexBytes("0x3f8153364001b26b983dd92191a084de8230f199b5ad0b045e9e1df61089b30d"), + ] + unique_condition_ids = sum([u.position.conditionIds for u in user_positions], []) + assert set(expected_condition_ids).issubset(unique_condition_ids) + + +def test_balance_for_user_in_market() -> None: + user_address = AGENT_0_ADDRESS + subgraph_handler = OmenSubgraphHandler() + market_id = HexAddress( + HexStr("0x59975b067b0716fef6f561e1e30e44f606b08803") + ) # yes/no + market = subgraph_handler.get_omen_market(market_id) + omen_agent_market = OmenAgentMarket.from_data_model(market) + balance_yes = get_omen_market_token_balance( + user_address=Web3.to_checksum_address(user_address), + market_condition_id=omen_agent_market.condition.id, + market_index_set=market.condition.index_sets[0], + ) + + assert balance_yes == 1959903969410997 + + balance_no = get_omen_market_token_balance( + user_address=Web3.to_checksum_address(user_address), + market_condition_id=omen_agent_market.condition.id, + market_index_set=market.condition.index_sets[1], + ) + assert balance_no == 0