Skip to content

Commit

Permalink
Merge pull request #17 from marigold-dev/branches-cleanup
Browse files Browse the repository at this point in the history
API routes update
  • Loading branch information
aguillon authored Nov 21, 2023
2 parents 5d19c3f + cfe7447 commit 7dc4cf0
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 73 deletions.
28 changes: 22 additions & 6 deletions src/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def get_contracts_by_credit(db: Session, credit_id: str):
models.Contract.credit_id == credit_id).all()


def get_contract(db: Session, address: str):
def get_contract_by_address(db: Session, address: str):
"""
Return a models.Contract or raise ContractNotFound exception
"""
Expand All @@ -65,21 +65,37 @@ def get_contract(db: Session, address: str):
raise ContractNotFound() from e


def get_entrypoints(db: Session,
contract_address: str) -> List[models.Entrypoint]:
def get_contract(db: Session, contract_id: str):
"""
Return a models.Contract or raise ContractNotFound exception
"""
try:
return db.query(models.Contract).get(contract_id)
except NoResultFound as e:
raise ContractNotFound() from e


def get_entrypoints(
db: Session,
contract_address_or_id: str
) -> List[models.Entrypoint]:
"""
Return a list of models.Contract or raise ContractNotFound exception
"""
contract = get_contract(db, contract_address)
if contract_address_or_id.startswith("KT"):
contract = get_contract_by_address(db, contract_address_or_id)
else:
contract = get_contract(db, contract_address_or_id)
return contract.entrypoints


def get_entrypoint(db: Session, contract_address: str,
def get_entrypoint(db: Session,
contract_address_or_id: str,
name: str) -> Optional[models.Entrypoint]:
"""
Return a models.Entrypoint or raise EntrypointNotFound exception
"""
entrypoints = get_entrypoints(db, contract_address)
entrypoints = get_entrypoints(db, contract_address_or_id)
entrypoint = [e for e in entrypoints if e.name == name] # type: ignore
if len(entrypoint) == 0:
raise EntrypointNotFound()
Expand Down
141 changes: 87 additions & 54 deletions src/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from sqlalchemy.orm import Session
from . import tezos
from pytezos.rpc.errors import MichelsonError
from pytezos.crypto.encoding import is_address
from .utils import ContractNotFound, CreditNotFound, EntrypointNotFound, UserNotFound


Expand Down Expand Up @@ -61,7 +62,7 @@ async def update_credits(
amount)
if not is_confirmed:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Could not find confirmation for {amount} with {op_hash}"
)
return crud.update_credits(db, credits)
Expand Down Expand Up @@ -95,14 +96,14 @@ async def withdraw_credits(
)
if credits.amount < withdraw.amount:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
status_code=status.HTTP_400_BAD_REQUEST,
detail="Not enough funds to withdraw."
)

expected_counter = credits.owner.withdraw_counter or 0
if expected_counter != withdraw.withdraw_counter:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
status_code=status.HTTP_400_BAD_REQUEST,
detail="Bad withdraw counter."
)

Expand All @@ -114,7 +115,7 @@ async def withdraw_credits(
public_key)
if not is_valid:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid signature."
)
# We increment the counter even if the withdraw fails to prevent
Expand All @@ -139,39 +140,31 @@ async def withdraw_credits(


# Users and credits getters
@router.get("/users/{user_address}", response_model=schemas.User)
async def get_user(user_address: str, db: Session = Depends(database.get_db)):
@router.get("/users/{address_or_id}", response_model=schemas.User)
async def get_user(address_or_id: str, db: Session = Depends(database.get_db)):
try:
return crud.get_user_by_address(db, user_address)
if is_address(address_or_id) and address_or_id.startswith("tz"):
return crud.get_user_by_address(db, address_or_id)
else:
return crud.get_user(db, address_or_id)
except UserNotFound:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"User not found.",
)


@router.get("/withdraw_counter/{user_address}",
response_model=schemas.WithdrawCounter)
async def get_withdraw_counter(user_address: str,
db: Session = Depends(database.get_db)):
try:
counter = crud.get_user_by_address(db, user_address).withdraw_counter
if counter is None:
counter = 0
return schemas.WithdrawCounter(counter=counter)
except UserNotFound:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"User not found.",
)


@router.get("/credits/{user_id}", response_model=list[schemas.Credit])
@router.get("/credits/{user_address_or_id}",
response_model=list[schemas.Credit])
async def credits_for_user(
user_id: str, db: Session = Depends(database.get_db)
user_address_or_id: str, db: Session = Depends(database.get_db)
):
try:
return crud.get_user(db, user_id).credits
if is_address(user_address_or_id) \
and user_address_or_id.startswith("tz"):
return crud.get_user_by_address(db, user_address_or_id).credits
else:
return crud.get_user(db, user_address_or_id).credits
except UserNotFound:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
Expand All @@ -180,11 +173,13 @@ async def credits_for_user(


# Contracts
@router.get("/contracts/user/{user_address}", response_model=list[schemas.Contract])
async def get_user_contracts(user_address: str, db: Session = Depends(database.get_db)):
@router.get("/contracts/user/{user_address}",
response_model=list[schemas.Contract])
async def get_user_contracts(user_address: str,
db: Session = Depends(database.get_db)):
try:
return crud.get_contracts_by_user(db, user_address)
except UserNotFound as e:
except UserNotFound:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=f"User not found."
)
Expand All @@ -195,15 +190,19 @@ async def get_user_contracts(user_address: str, db: Session = Depends(database.g
async def get_credit(credit_id: str, db: Session = Depends(database.get_db)):
try:
return crud.get_contracts_by_credit(db, credit_id)
except CreditNotFound as e:
except CreditNotFound:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=f"Credit not found."
)


@router.get("/contracts/{address}", response_model=schemas.Contract)
async def get_contract(address: str, db: Session = Depends(database.get_db)):
contract = crud.get_contract(db, address)
@router.get("/contracts/{address_or_id}", response_model=schemas.Contract)
async def get_contract(address_or_id: str,
db: Session = Depends(database.get_db)):
if is_address(address_or_id) and address_or_id.startswith("KT"):
contract = crud.get_contract_by_address(db, address_or_id)
else:
contract = crud.get_contract(db, address_or_id)
if not contract:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=f"Contract not found."
Expand All @@ -212,24 +211,35 @@ async def get_contract(address: str, db: Session = Depends(database.get_db)):


# Entrypoints
@router.get("/entrypoints/{contract_address}", response_model=list[schemas.Entrypoint])
@router.get("/entrypoints/{contract_address_or_id}",
response_model=list[schemas.Entrypoint])
async def get_entrypoints(
contract_address: str, db: Session = Depends(database.get_db)
contract_address_or_id: str, db: Session = Depends(database.get_db)
):
try:
return crud.get_entrypoints(db, contract_address)
except ContractNotFound as e:
if contract_address_or_id.startswith("KT"):
assert is_address(contract_address_or_id)
return crud.get_entrypoints(db, contract_address_or_id)
except ContractNotFound:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=f"Contract not found."
)
except AssertionError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid address."
)

@router.get("/entrypoints/{contract_address}/{name}", response_model=schemas.Entrypoint)

@router.get("/entrypoints/{contract_address_or_id}/{name}",
response_model=schemas.Entrypoint)
async def get_entrypoint(
contract_address: str, name: str, db: Session = Depends(database.get_db)
contract_address_or_id: str,
name: str,
db: Session = Depends(database.get_db)
):
try:
return crud.get_entrypoint(db, contract_address, name)
except EntrypointNotFound as e:
return crud.get_entrypoint(db, contract_address_or_id, name)
except EntrypointNotFound:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=f"Entrypoint not found."
)
Expand All @@ -238,11 +248,11 @@ async def get_entrypoint(
# Operations
@router.post("/operation")
async def post_operation(
call_data: schemas.CallData, db: Session = Depends(database.get_db)
call_data: schemas.UnsignedCall, db: Session = Depends(database.get_db)
):
if len(call_data.operations) == 0:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Empty operations list",
)
# TODO: check that amount=0?
Expand All @@ -252,26 +262,24 @@ async def post_operation(
# Transfers to implicit accounts are always refused
if not contract_address.startswith("KT"):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Target {contract_address} is not allowed",
)
try:
contract = crud.get_contract(db, contract_address)
contract = crud.get_contract_by_address(db, contract_address)
except ContractNotFound:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Target {contract_address} is not allowed",
)

entrypoint_name = operation["parameters"]["entrypoint"]
print(contract_address, entrypoint_name)
try:
entrypoint = crud.get_entrypoint(db,
str(contract.address),
entrypoint_name)
crud.get_entrypoint(db, str(contract.address), entrypoint_name)
except EntrypointNotFound:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Entrypoint {entrypoint_name} is not allowed",
)

Expand All @@ -284,17 +292,19 @@ async def post_operation(
estimated_fees = tezos.group_fees(op_estimated_fees)
if not tezos.check_credits(db, estimated_fees):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
status_code=status.HTTP_400_BAD_REQUEST,
detail="Not enough funds."
)
result = await tezos.tezos_manager.queue_operation(call_data.sender,
op)
result = await tezos.tezos_manager.queue_operation(
call_data.sender_address,
op
)
except MichelsonError as e:
print("Received failing operation, discarding")
print(e)
raise HTTPException(
# FIXME? Is this the best one?
status_code=status.HTTP_409_CONFLICT,
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Operation is invalid",
)
except Exception:
Expand All @@ -304,3 +314,26 @@ async def post_operation(
detail=f"Unknown exception raised.",
)
return result


@router.post("/signed_operation")
async def signed_operation(
call_data: schemas.SignedCall, db: Session = Depends(database.get_db)
):
# In order for the user to sign Micheline, we need to
# FIXME: this is a serious issue, we should sign the contract address too.
signed_data = [x["parameters"]["value"] for x in call_data.operations]
if not tezos.check_signature(
signed_data,
call_data.signature,
call_data.sender_key,
call_data.micheline_type
):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid signature."
)
address = tezos.public_key_hash(call_data.sender_key)
call_data = schemas.UnsignedCall(sender_address=address,
operations=call_data.operations)
return await post_operation(call_data, db)
13 changes: 10 additions & 3 deletions src/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,15 @@ class ContractCreation(ContractBase):
credit_id: UUID4

# Operations
# TODO: right now the sender isn't checked, as we use permits anyway
class CallData(BaseModel):
sender: str
class UnsignedCall(BaseModel):
"""Data sent when posting an operation. The sender is mandatory."""
sender_address: str
operations: list[dict[str, Any]]


class SignedCall(BaseModel):
"""Data sent when posting an operation. The signature"""
sender_key: str
operations: list[dict[str, Any]]
signature: str
micheline_type: Any
27 changes: 17 additions & 10 deletions src/tezos.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,16 +123,17 @@ def get_public_key(address):
return key


def check_signature(pair_data, signature, public_key):
# Type of a withdraw operation
pair_type = {
"prim": 'pair',
"args": [
{"prim": 'string'},
{"prim": "int"},
{"prim": 'mutez'}
]
}
def check_signature(pair_data, signature, public_key, pair_type=None):
if pair_type is None:
# Type of a withdraw operation
pair_type = {
"prim": 'pair',
"args": [
{"prim": 'string'},
{"prim": "int"},
{"prim": 'mutez'}
]
}
public_key = pytezos.Key.from_encoded_key(public_key)
matcher = MichelsonType.match(pair_type)
packed_pair = matcher.from_micheline_value(pair_data).pack()
Expand All @@ -143,6 +144,11 @@ def check_signature(pair_data, signature, public_key):
return False


def public_key_hash(public_key: str):
key = pytezos.Key.from_encoded_key(public_key)
return key.public_key_hash()


async def withdraw(tezos_manager, to, amount):
op = ptz.transaction(source=ptz.key.public_key_hash(),
destination=to,
Expand All @@ -160,6 +166,7 @@ def __init__(self, ptz):
# Receive an operation from sender and add it to the waiting queue;
# blocks until there is a result in self.results
async def queue_operation(self, sender, operation):
print(operation)
self.results[sender] = "waiting"
self.ops_queue[sender] = operation
while self.results[sender] == "waiting":
Expand Down

0 comments on commit 7dc4cf0

Please sign in to comment.