From d87a0d7d7c5fa709389f7c22890426bf64a25320 Mon Sep 17 00:00:00 2001 From: lykimq Date: Fri, 26 Jan 2024 11:55:17 +0100 Subject: [PATCH] Src: limit the number of calls per new users --- src/crud.py | 35 ++++++++++++++++++++--------------- src/models.py | 14 ++++---------- src/routes.py | 18 ++++++++++++++---- src/schemas.py | 5 ++--- 4 files changed, 40 insertions(+), 32 deletions(-) diff --git a/src/crud.py b/src/crud.py index d9d3cda..0ddf092 100644 --- a/src/crud.py +++ b/src/crud.py @@ -321,22 +321,24 @@ def check_calls_per_month(db, contract_id): def create_max_calls_per_sponsee_condition( db: Session, condition: schemas.CreateMaxCallsPerSponseeCondition ): - # If a condition still exists, do not create a new one + """ + Function to limit the number of calls for new users + rather than for a specific sponsee. + """ existing_condition = ( db.query(models.Condition) - .filter(models.Condition.sponsee_address == condition.sponsee_address) .filter(models.Condition.vault_id == condition.vault_id) .filter(models.Condition.current < models.Condition.max) .one_or_none() ) if existing_condition is not None: raise ConditionAlreadyExists( - "A condition with maximum calls per sponsee already exists and the maximum is not reached. Cannot create a new one." + "A condition with maximum calls per sponsee already exists for this vault. Cannot create a new one." ) + # Create a new condition for the max calls per sponsee db_condition = models.Condition( **{ "type": schemas.ConditionType.MAX_CALLS_PER_SPONSEE, - "sponsee_address": condition.sponsee_address, "vault_id": condition.vault_id, "max": condition.max, "current": 0, @@ -346,7 +348,6 @@ def create_max_calls_per_sponsee_condition( db.commit() db.refresh(db_condition) return schemas.MaxCallsPerSponseeCondition( - sponsee_address=db_condition.sponsee_address, vault_id=db_condition.vault_id, max=db_condition.max, current=db_condition.current, @@ -397,11 +398,15 @@ def create_max_calls_per_entrypoint_condition( ) -def check_max_calls_per_sponsee(db: Session, sponsee_address: str, vault_id: UUID4): +def check_max_calls_per_sponsee(db: Session, vault_id: UUID4): + """ + The functions retrieves the condition that sets the maximum limit + on the number of calls for any address or new user associated with the + specified vault_id. + """ return ( db.query(models.Condition) .filter(models.Condition.type == schemas.ConditionType.MAX_CALLS_PER_SPONSEE) - .filter(models.Condition.sponsee_address == sponsee_address) .filter(models.Condition.vault_id == vault_id) .one_or_none() ) @@ -422,20 +427,20 @@ def check_max_calls_per_entrypoint( def check_conditions(db: Session, datas: schemas.CheckConditions): print(datas) - sponsee_condition = check_max_calls_per_sponsee( - db, datas.sponsee_address, datas.vault_id - ) + + max_calls_condition = check_max_calls_per_sponsee(db, datas.vault_id) + entrypoint_condition = check_max_calls_per_entrypoint( db, datas.contract_id, datas.entrypoint_id, datas.vault_id ) # No condition registered - if sponsee_condition is None and entrypoint_condition is None: + if max_calls_condition is None and entrypoint_condition is None: return True # One of condition is excedeed if ( - sponsee_condition is not None - and (sponsee_condition.current >= sponsee_condition.max) + max_calls_condition is not None + and (max_calls_condition.current_calls_per_user >= max_calls_condition.max_calls_per_user) ) or ( entrypoint_condition is not None and (entrypoint_condition.current >= entrypoint_condition.max) @@ -445,8 +450,8 @@ def check_conditions(db: Session, datas: schemas.CheckConditions): # Update conditions # TODO - Rewrite with list - if sponsee_condition: - update_condition(db, sponsee_condition) + if max_calls_condition: + update_condition(db, max_calls_condition) if entrypoint_condition: update_condition(db, entrypoint_condition) return True diff --git a/src/models.py b/src/models.py index afa0d55..ec407cf 100644 --- a/src/models.py +++ b/src/models.py @@ -133,14 +133,6 @@ class Condition(Base): id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) type = Column(Enum(ConditionType)) - sponsee_address = Column( - String, - CheckConstraint( - "(type = 'MAX_CALLS_PER_SPONSEE') = (sponsee_address IS NOT NULL)", - name="sponsee_address_not_null_constraint", - ), - nullable=True, - ) contract_id = Column( UUID(as_uuid=True), ForeignKey("contracts.id"), @@ -153,8 +145,10 @@ class Condition(Base): ) vault_id = Column(UUID(as_uuid=True), ForeignKey( "credits.id"), nullable=False) - max = Column(Integer, nullable=False) - current = Column(Integer, nullable=False) + # New field to store the maximum allowed calls per user + max_calls_per_user = Column(Integer, nullable=False) + # New field to track current calls per user + current_calls_per_user = Column(Integer, nullable=False, default=0) created_at = Column( DateTime(timezone=True), default=datetime.datetime.utcnow(), nullable=False ) diff --git a/src/routes.py b/src/routes.py index 4835a08..6a70371 100644 --- a/src/routes.py +++ b/src/routes.py @@ -292,6 +292,8 @@ async def post_operation( detail=f"Target {contract_address} is not allowed", ) try: + # try to retrieve the contract associated with the destination address + # of each operation from the database contract = crud.get_contract_by_address(db, contract_address) except ContractNotFound: logging.warning(f"{contract_address} is not found") @@ -303,15 +305,19 @@ async def post_operation( entrypoint_name = operation["parameters"]["entrypoint"] try: + # check if the specified entrypoint for each operation exists + # and is enabled in the contract entrypoint = crud.get_entrypoint( db, str(contract.address), entrypoint_name) if not entrypoint.is_enabled: raise EntrypointDisabled() + # Check if certain conditions are met for each operation, + # if not, raises corresponding exceptions if not crud.check_conditions( db, schemas.CheckConditions( - sponsee_address=call_data.sender_address, + sponsee_address=None, # Remove sponsee address contract_id=contract.id, entrypoint_id=entrypoint.id, vault_id=contract.credit_id, @@ -338,7 +344,7 @@ async def post_operation( ) try: - # Simulate the operation alone without sending it + # Simulate the transaction to estimate fees without actually sending it # TODO: log the result op = tezos.simulate_transaction(call_data.operations) @@ -350,16 +356,20 @@ async def post_operation( logging.debug(f"Estimated fees: {estimated_fees}") + # Check if there are enough funds to pay the estimated fees. if not tezos.check_credits(db, estimated_fees): logging.warning(f"Not enough funds to pay estimated fees.") raise NotEnoughFunds( f"Estimated fees : {estimated_fees[str(contract.address)]} mutez" ) + # Check if there have been too many calls made for the contract in the current month if not crud.check_calls_per_month(db, contract.id): # type: ignore logging.warning( f"Too many calls made for this contract this month.") raise TooManyCallsForThisMonth() + # If everything succeeds, it queues the operation for execution, records the + # operation in the database, and returns the result result = await tezos.tezos_manager.queue_operation(call_data.sender_address, op) crud.create_operation( @@ -369,6 +379,7 @@ async def post_operation( user_address=call_data.sender_address, contract_id=str(contract.id), entrypoint_id=str(entrypoint.id), hash=result["transaction_hash"], status=result["result"] ), ) + # Handle exceptions except MichelsonError as e: print("Received failing operation, discarding") logging.error(f"Invalid operation {e}") @@ -453,12 +464,11 @@ async def create_condition( ) elif ( body.type == ConditionType.MAX_CALLS_PER_SPONSEE - and body.sponsee_address is not None ): + # Adjusted to create a condition for maximum calls per user return crud.create_max_calls_per_sponsee_condition( db, schemas.CreateMaxCallsPerSponseeCondition( - sponsee_address=body.sponsee_address, vault_id=body.vault_id, max=body.max, ), diff --git a/src/schemas.py b/src/schemas.py index c7db896..b4a3bcf 100644 --- a/src/schemas.py +++ b/src/schemas.py @@ -154,13 +154,12 @@ class CreateMaxCallsPerEntrypointCondition(BaseModel): class CreateMaxCallsPerSponseeCondition(BaseModel): - sponsee_address: str vault_id: UUID4 max: int class CheckConditions(BaseModel): - sponsee_address: str + sponsee_address: Optional[str] contract_id: UUID4 entrypoint_id: UUID4 vault_id: UUID4 @@ -181,4 +180,4 @@ class MaxCallsPerEntrypointCondition(ConditionBase): class MaxCallsPerSponseeCondition(ConditionBase): - sponsee_address: str + pass