Skip to content

Commit

Permalink
fixapiregistrationpolicy
Browse files Browse the repository at this point in the history
  • Loading branch information
Dacksus committed Jan 9, 2025
1 parent 08c29d5 commit ed49852
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 14 deletions.
6 changes: 5 additions & 1 deletion python/src/uagents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,7 +804,11 @@ async def register(self):
assert self._registration_policy is not None, "Agent has no registration policy"

await self._registration_policy.register(
self.address, list(self.protocols.keys()), self._endpoints, self._metadata
self.address,
self._identity,
list(self.protocols.keys()),
self._endpoints,
self._metadata,
)

async def _schedule_registration(self):
Expand Down
25 changes: 12 additions & 13 deletions python/src/uagents/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ class AgentRegistrationPolicy(ABC):
async def register(
self,
agent_address: str,
identity: Identity,
protocols: List[str],
endpoints: List[AgentEndpoint],
metadata: Optional[Dict[str, Any]] = None,
Expand All @@ -159,20 +160,19 @@ def add_agent(self, agent_info: AgentInfo, identity: Identity):
class AlmanacApiRegistrationPolicy(AgentRegistrationPolicy):
def __init__(
self,
identity: Identity,
*,
almanac_api: Optional[str] = None,
max_retries: int = ALMANAC_API_MAX_RETRIES,
logger: Optional[logging.Logger] = None,
):
self._almanac_api = almanac_api or ALMANAC_API_URL
self._max_retries = max_retries
self._identity = identity
self._logger = logger or logging.getLogger(__name__)

async def register(
self,
agent_address: str,
identity: Identity,
protocols: List[str],
endpoints: List[AgentEndpoint],
metadata: Optional[Dict[str, Any]] = None,
Expand All @@ -186,7 +186,7 @@ async def register(
)

# sign the attestation
attestation.sign(self._identity)
attestation.sign(identity)

success = await almanac_api_post(
f"{self._almanac_api}/agents", attestation, retries=self._max_retries
Expand Down Expand Up @@ -233,15 +233,13 @@ async def register(self):
class LedgerBasedRegistrationPolicy(AgentRegistrationPolicy):
def __init__(
self,
identity: Identity,
ledger: LedgerClient,
wallet: LocalWallet,
almanac_contract: AlmanacContract,
testnet: bool,
*,
logger: Optional[logging.Logger] = None,
):
self._identity = identity
self._wallet = wallet
self._ledger = ledger
self._testnet = testnet
Expand All @@ -265,6 +263,7 @@ def check_contract_version(self):
async def register(
self,
agent_address: str,
identity: Identity,
protocols: List[str],
endpoints: List[AgentEndpoint],
metadata: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -301,7 +300,7 @@ async def register(

current_time = int(time.time()) - ALMANAC_REGISTRATION_WAIT

signature = self._sign_registration(current_time)
signature = self._sign_registration(identity, current_time)
await self._almanac_contract.register(
self._ledger,
self._wallet,
Expand All @@ -318,7 +317,7 @@ async def register(
def _get_balance(self) -> int:
return self._ledger.query_bank_balance(Address(self._wallet.address()))

def _sign_registration(self, timestamp: int) -> str:
def _sign_registration(self, identity: Identity, timestamp: int) -> str:
"""
Sign the registration data for Almanac contract.
Expand All @@ -333,7 +332,7 @@ def _sign_registration(self, timestamp: int) -> str:
"""
assert self._almanac_contract.address is not None
return self._identity.sign_registration(
return identity.sign_registration(
str(self._almanac_contract.address),
timestamp,
str(self._wallet.address()),
Expand Down Expand Up @@ -400,7 +399,6 @@ async def register(self):
class DefaultRegistrationPolicy(AgentRegistrationPolicy):
def __init__(
self,
identity: Identity,
ledger: LedgerClient,
wallet: LocalWallet,
almanac_contract: Optional[AlmanacContract],
Expand All @@ -411,26 +409,27 @@ def __init__(
):
self._logger = logger or logging.getLogger(__name__)
self._api_policy = AlmanacApiRegistrationPolicy(
identity, almanac_api=almanac_api, logger=logger
almanac_api=almanac_api, logger=logger
)
if almanac_contract is None:
self._ledger_policy = None
else:
self._ledger_policy = LedgerBasedRegistrationPolicy(
identity, ledger, wallet, almanac_contract, testnet, logger=logger
ledger, wallet, almanac_contract, testnet, logger=logger
)

async def register(
self,
agent_address: str,
identity: Identity,
protocols: List[str],
endpoints: List[AgentEndpoint],
metadata: Optional[Dict[str, Any]] = None,
):
# prefer the API registration policy as it is faster
try:
await self._api_policy.register(
agent_address, protocols, endpoints, metadata
agent_address, identity, protocols, endpoints, metadata
)
except Exception as e:
self._logger.warning(
Expand All @@ -443,7 +442,7 @@ async def register(
# schedule the ledger registration
try:
await self._ledger_policy.register(
agent_address, protocols, endpoints, metadata
agent_address, identity, protocols, endpoints, metadata
)
except InsufficientFundsError:
self._logger.warning(
Expand Down

0 comments on commit ed49852

Please sign in to comment.