From 46c57209a21ab69edd035665fc5d0f730c9ad941 Mon Sep 17 00:00:00 2001 From: Dan Montgomery Date: Mon, 18 Nov 2024 14:20:24 -0500 Subject: [PATCH] Add stand-alone payment mode to test app. --- gai-backend/test.py | 175 +++++++++++++++++++++++++++++++++----------- 1 file changed, 132 insertions(+), 43 deletions(-) diff --git a/gai-backend/test.py b/gai-backend/test.py index 469f75084..69d6633ee 100644 --- a/gai-backend/test.py +++ b/gai-backend/test.py @@ -17,11 +17,11 @@ @dataclass class InferenceConfig: provider: str - funder: str - secret: str - chainid: int - currency: str - rpc: str + funder: Optional[str] = None + secret: Optional[str] = None + chainid: Optional[int] = None + currency: Optional[str] = None + rpc: Optional[str] = None @dataclass class ProviderLocation: @@ -55,7 +55,6 @@ class TestConfig: def from_dict(cls, data: Dict) -> 'TestConfig': messages = [] if 'prompt' in data: - # Handle legacy config with single prompt messages = [Message(role="user", content=data['prompt'])] elif 'messages' in data: messages = [Message(**msg) for msg in data['messages']] @@ -85,15 +84,23 @@ class ClientConfig: def from_file(cls, config_path: str) -> 'ClientConfig': with open(config_path) as f: data = json.load(f) + + # Make inference config fields optional + inference_data = data.get('inference', {}) + if not isinstance(inference_data, dict): + inference_data = {} + return cls( - inference=InferenceConfig(**data['inference']), + inference=InferenceConfig(**inference_data), location=LocationConfig.from_dict(data['location']), test=TestConfig.from_dict(data['test']), logging=LoggingConfig(**data['logging']) ) class OrchidLLMTestClient: - def __init__(self, config_path: str, prompt: Optional[str] = None): + def __init__(self, config_path: str, wallet_only: bool = False, inference_only: bool = False, + inference_url: Optional[str] = None, auth_key: Optional[str] = None, + prompt: Optional[str] = None): self.config = ClientConfig.from_file(config_path) if prompt: self.config.test.messages = [Message(role="user", content=prompt)] @@ -101,23 +108,38 @@ def __init__(self, config_path: str, prompt: Optional[str] = None): self._setup_logging() self.logger = logging.getLogger(__name__) - self.web3 = Web3(Web3.HTTPProvider(self.config.inference.rpc)) - self.lottery = Lottery( - self.web3, - chain_id=self.config.inference.chainid - ) - self.account = OrchidAccount( - self.lottery, - self.config.inference.funder, - self.config.inference.secret - ) + self.wallet_only = wallet_only + self.inference_only = inference_only + self.cli_inference_url = inference_url + self.cli_auth_key = auth_key + + if not inference_only: + if not self.config.inference.rpc: + raise Exception("RPC URL required in config for wallet mode") + if not self.config.inference.chainid: + raise Exception("Chain ID required in config for wallet mode") + if not self.config.inference.funder: + raise Exception("Funder address required in config for wallet mode") + if not self.config.inference.secret: + raise Exception("Secret required in config for wallet mode") + + self.web3 = Web3(Web3.HTTPProvider(self.config.inference.rpc)) + self.lottery = Lottery( + self.web3, + chain_id=self.config.inference.chainid + ) + self.account = OrchidAccount( + self.lottery, + self.config.inference.funder, + self.config.inference.secret + ) self.ws = None self.session_id = None self.inference_url = None self.message_queue = asyncio.Queue() self._handler_task = None - + def _setup_logging(self): logging.basicConfig( level=getattr(logging, self.config.logging.level.upper()), @@ -131,17 +153,21 @@ async def _handle_invoice(self, invoice_data: Dict) -> None: recipient = invoice_data['recipient'] commit = invoice_data['commit'] + self.logger.info(f"Received invoice for {amount/1e18} tokens") + + # Create and send ticket immediately ticket_str = self.account.create_ticket( amount=amount, recipient=recipient, commitment=commit ) - await self.ws.send(json.dumps({ + payment = { 'type': 'payment', 'tickets': [ticket_str] - })) - self.logger.info(f"Sent payment ticket for {amount/1e18} tokens") + } + await self.ws.send(json.dumps(payment)) + self.logger.info(f"Sent payment ticket") except Exception as e: self.logger.error(f"Failed to handle invoice: {e}") @@ -154,14 +180,21 @@ async def _billing_handler(self) -> None: self.logger.debug(f"Received WS message: {msg['type']}") if msg['type'] == 'invoice': + # Handle invoice immediately await self._handle_invoice(msg) elif msg['type'] == 'auth_token': self.session_id = msg['session_id'] self.inference_url = msg['inference_url'] - await self.message_queue.put(('auth_received', self.session_id)) + if self.wallet_only: + print(f"\nAuth Token: {self.session_id}") + print(f"Inference URL: {self.inference_url}") + print("\nWallet is active and handling payments. Press Ctrl+C to exit.") + else: + await self.message_queue.put(('auth_received', self.session_id)) elif msg['type'] == 'error': + self.logger.error(f"Received error: {msg['code']}") await self.message_queue.put(('error', msg['code'])) - + except websockets.exceptions.ConnectionClosed: self.logger.info("Billing WebSocket closed") except Exception as e: @@ -169,6 +202,22 @@ async def _billing_handler(self) -> None: await self.message_queue.put(('error', str(e))) async def connect(self) -> None: + if self.inference_only: + if self.cli_auth_key: + self.session_id = self.cli_auth_key + else: + self.session_id = self.config.test.params.get('session_id') + if not self.session_id: + raise Exception("session_id required either in config or via --key parameter") + + if self.cli_inference_url: + self.inference_url = self.cli_inference_url + else: + self.inference_url = self.config.test.params.get('inference_url') + if not self.inference_url: + raise Exception("inference_url required either in config or via --url parameter") + return + try: provider = self.config.inference.provider provider_config = self.config.location.providers.get(provider) @@ -185,16 +234,26 @@ async def connect(self) -> None: 'orchid_account': self.config.inference.funder })) - msg_type, session_id = await self.message_queue.get() - if msg_type != 'auth_received': - raise Exception(f"Authentication failed: {session_id}") - - self.logger.info("Successfully authenticated") + if not self.wallet_only: + msg_type, session_id = await self.message_queue.get() + if msg_type != 'auth_received': + raise Exception(f"Authentication failed: {session_id}") + + self.logger.info("Successfully authenticated") except Exception as e: self.logger.error(f"Connection failed: {e}") raise + async def run_wallet(self) -> None: + """Keep the wallet running and handling payments""" + try: + while True: + await asyncio.sleep(3600) + except asyncio.CancelledError: + self.logger.info("Wallet operation cancelled") + raise + async def send_inference_request(self, retry_count: int = 0) -> Dict: if not self.session_id: raise Exception("Not authenticated") @@ -234,6 +293,8 @@ async def send_inference_request(self, retry_count: int = 0) -> Dict: timeout=30 ) as response: if response.status == 402: + if self.inference_only: + raise Exception("Insufficient balance and unable to pay in inference-only mode") retry_delay = self.config.test.retry_delay self.logger.info(f"Insufficient balance, waiting {retry_delay}s for payment processing...") await asyncio.sleep(retry_delay) @@ -271,26 +332,35 @@ async def close(self) -> None: pass async def __aenter__(self): - """Async context manager support""" return self async def __aexit__(self, exc_type, exc_val, exc_tb): - """Ensure cleanup on context exit""" await self.close() -async def main(config_path: str, prompt: Optional[str] = None): - async with OrchidLLMTestClient(config_path, prompt) as client: +async def main(config_path: str, wallet_only: bool = False, inference_only: bool = False, + inference_url: Optional[str] = None, auth_key: Optional[str] = None, + prompt: Optional[str] = None): + async with OrchidLLMTestClient( + config_path, + wallet_only, + inference_only, + inference_url, + auth_key, + prompt + ) as client: try: await client.connect() - result = await client.send_inference_request() - - print("\nInference Results:") - messages = client.config.test.messages - print(f"Messages:") - for msg in messages: - print(f" {msg.role}: {msg.content}") - print(f"Response: {result['response']}") - print(f"Usage: {json.dumps(result['usage'], indent=2)}") + if wallet_only: + await client.run_wallet() + elif not wallet_only: + result = await client.send_inference_request() + print("\nInference Results:") + messages = client.config.test.messages + print(f"Messages:") + for msg in messages: + print(f" {msg.role}: {msg.content}") + print(f"Response: {result['response']}") + print(f"Usage: {json.dumps(result['usage'], indent=2)}") except Exception as e: print(f"Test failed: {e}") @@ -300,8 +370,27 @@ async def main(config_path: str, prompt: Optional[str] = None): import argparse parser = argparse.ArgumentParser() parser.add_argument("config", help="Path to config file") + parser.add_argument("--wallet", action="store_true", help="Run in wallet-only mode") + parser.add_argument("--inference", action="store_true", help="Run in inference-only mode") + parser.add_argument("--url", help="Override inference URL from config") + parser.add_argument("--key", help="Override auth key from config") parser.add_argument("prompt", nargs="*", help="Optional prompt to override config") args = parser.parse_args() + if args.wallet and args.inference: + print("Cannot specify both --wallet and --inference") + exit(1) + + if (args.url and not args.key) or (args.key and not args.url): + print("Must specify both --url and --key together") + exit(1) + prompt = " ".join(args.prompt) if args.prompt else None - asyncio.run(main(args.config, prompt)) + asyncio.run(main( + args.config, + args.wallet, + args.inference, + args.url, + args.key, + prompt + ))