Skip to content

Commit

Permalink
Add stand-alone payment mode to test app.
Browse files Browse the repository at this point in the history
  • Loading branch information
danopato committed Nov 18, 2024
1 parent acc996c commit 46c5720
Showing 1 changed file with 132 additions and 43 deletions.
175 changes: 132 additions & 43 deletions gai-backend/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
@dataclass
class InferenceConfig:
provider: str
funder: str
secret: str
chainid: int
currency: str
rpc: str
funder: Optional[str] = None
secret: Optional[str] = None
chainid: Optional[int] = None
currency: Optional[str] = None
rpc: Optional[str] = None

@dataclass
class ProviderLocation:
Expand Down Expand Up @@ -55,7 +55,6 @@ class TestConfig:
def from_dict(cls, data: Dict) -> 'TestConfig':
messages = []
if 'prompt' in data:
# Handle legacy config with single prompt
messages = [Message(role="user", content=data['prompt'])]
elif 'messages' in data:
messages = [Message(**msg) for msg in data['messages']]
Expand Down Expand Up @@ -85,39 +84,62 @@ class ClientConfig:
def from_file(cls, config_path: str) -> 'ClientConfig':
with open(config_path) as f:
data = json.load(f)

# Make inference config fields optional
inference_data = data.get('inference', {})
if not isinstance(inference_data, dict):
inference_data = {}

return cls(
inference=InferenceConfig(**data['inference']),
inference=InferenceConfig(**inference_data),
location=LocationConfig.from_dict(data['location']),
test=TestConfig.from_dict(data['test']),
logging=LoggingConfig(**data['logging'])
)

class OrchidLLMTestClient:
def __init__(self, config_path: str, prompt: Optional[str] = None):
def __init__(self, config_path: str, wallet_only: bool = False, inference_only: bool = False,
inference_url: Optional[str] = None, auth_key: Optional[str] = None,
prompt: Optional[str] = None):
self.config = ClientConfig.from_file(config_path)
if prompt:
self.config.test.messages = [Message(role="user", content=prompt)]

self._setup_logging()
self.logger = logging.getLogger(__name__)

self.web3 = Web3(Web3.HTTPProvider(self.config.inference.rpc))
self.lottery = Lottery(
self.web3,
chain_id=self.config.inference.chainid
)
self.account = OrchidAccount(
self.lottery,
self.config.inference.funder,
self.config.inference.secret
)
self.wallet_only = wallet_only
self.inference_only = inference_only
self.cli_inference_url = inference_url
self.cli_auth_key = auth_key

if not inference_only:
if not self.config.inference.rpc:
raise Exception("RPC URL required in config for wallet mode")
if not self.config.inference.chainid:
raise Exception("Chain ID required in config for wallet mode")
if not self.config.inference.funder:
raise Exception("Funder address required in config for wallet mode")
if not self.config.inference.secret:
raise Exception("Secret required in config for wallet mode")

self.web3 = Web3(Web3.HTTPProvider(self.config.inference.rpc))
self.lottery = Lottery(
self.web3,
chain_id=self.config.inference.chainid
)
self.account = OrchidAccount(
self.lottery,
self.config.inference.funder,
self.config.inference.secret
)

self.ws = None
self.session_id = None
self.inference_url = None
self.message_queue = asyncio.Queue()
self._handler_task = None

def _setup_logging(self):
logging.basicConfig(
level=getattr(logging, self.config.logging.level.upper()),
Expand All @@ -131,17 +153,21 @@ async def _handle_invoice(self, invoice_data: Dict) -> None:
recipient = invoice_data['recipient']
commit = invoice_data['commit']

self.logger.info(f"Received invoice for {amount/1e18} tokens")

# Create and send ticket immediately
ticket_str = self.account.create_ticket(
amount=amount,
recipient=recipient,
commitment=commit
)

await self.ws.send(json.dumps({
payment = {
'type': 'payment',
'tickets': [ticket_str]
}))
self.logger.info(f"Sent payment ticket for {amount/1e18} tokens")
}
await self.ws.send(json.dumps(payment))
self.logger.info(f"Sent payment ticket")

except Exception as e:
self.logger.error(f"Failed to handle invoice: {e}")
Expand All @@ -154,21 +180,44 @@ async def _billing_handler(self) -> None:
self.logger.debug(f"Received WS message: {msg['type']}")

if msg['type'] == 'invoice':
# Handle invoice immediately
await self._handle_invoice(msg)
elif msg['type'] == 'auth_token':
self.session_id = msg['session_id']
self.inference_url = msg['inference_url']
await self.message_queue.put(('auth_received', self.session_id))
if self.wallet_only:
print(f"\nAuth Token: {self.session_id}")
print(f"Inference URL: {self.inference_url}")
print("\nWallet is active and handling payments. Press Ctrl+C to exit.")
else:
await self.message_queue.put(('auth_received', self.session_id))
elif msg['type'] == 'error':
self.logger.error(f"Received error: {msg['code']}")
await self.message_queue.put(('error', msg['code']))

except websockets.exceptions.ConnectionClosed:
self.logger.info("Billing WebSocket closed")
except Exception as e:
self.logger.error(f"Billing handler error: {e}")
await self.message_queue.put(('error', str(e)))

async def connect(self) -> None:
if self.inference_only:
if self.cli_auth_key:
self.session_id = self.cli_auth_key
else:
self.session_id = self.config.test.params.get('session_id')
if not self.session_id:
raise Exception("session_id required either in config or via --key parameter")

if self.cli_inference_url:
self.inference_url = self.cli_inference_url
else:
self.inference_url = self.config.test.params.get('inference_url')
if not self.inference_url:
raise Exception("inference_url required either in config or via --url parameter")
return

try:
provider = self.config.inference.provider
provider_config = self.config.location.providers.get(provider)
Expand All @@ -185,16 +234,26 @@ async def connect(self) -> None:
'orchid_account': self.config.inference.funder
}))

msg_type, session_id = await self.message_queue.get()
if msg_type != 'auth_received':
raise Exception(f"Authentication failed: {session_id}")

self.logger.info("Successfully authenticated")
if not self.wallet_only:
msg_type, session_id = await self.message_queue.get()
if msg_type != 'auth_received':
raise Exception(f"Authentication failed: {session_id}")

self.logger.info("Successfully authenticated")

except Exception as e:
self.logger.error(f"Connection failed: {e}")
raise

async def run_wallet(self) -> None:
"""Keep the wallet running and handling payments"""
try:
while True:
await asyncio.sleep(3600)
except asyncio.CancelledError:
self.logger.info("Wallet operation cancelled")
raise

async def send_inference_request(self, retry_count: int = 0) -> Dict:
if not self.session_id:
raise Exception("Not authenticated")
Expand Down Expand Up @@ -234,6 +293,8 @@ async def send_inference_request(self, retry_count: int = 0) -> Dict:
timeout=30
) as response:
if response.status == 402:
if self.inference_only:
raise Exception("Insufficient balance and unable to pay in inference-only mode")
retry_delay = self.config.test.retry_delay
self.logger.info(f"Insufficient balance, waiting {retry_delay}s for payment processing...")
await asyncio.sleep(retry_delay)
Expand Down Expand Up @@ -271,26 +332,35 @@ async def close(self) -> None:
pass

async def __aenter__(self):
"""Async context manager support"""
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Ensure cleanup on context exit"""
await self.close()

async def main(config_path: str, prompt: Optional[str] = None):
async with OrchidLLMTestClient(config_path, prompt) as client:
async def main(config_path: str, wallet_only: bool = False, inference_only: bool = False,
inference_url: Optional[str] = None, auth_key: Optional[str] = None,
prompt: Optional[str] = None):
async with OrchidLLMTestClient(
config_path,
wallet_only,
inference_only,
inference_url,
auth_key,
prompt
) as client:
try:
await client.connect()
result = await client.send_inference_request()

print("\nInference Results:")
messages = client.config.test.messages
print(f"Messages:")
for msg in messages:
print(f" {msg.role}: {msg.content}")
print(f"Response: {result['response']}")
print(f"Usage: {json.dumps(result['usage'], indent=2)}")
if wallet_only:
await client.run_wallet()
elif not wallet_only:
result = await client.send_inference_request()
print("\nInference Results:")
messages = client.config.test.messages
print(f"Messages:")
for msg in messages:
print(f" {msg.role}: {msg.content}")
print(f"Response: {result['response']}")
print(f"Usage: {json.dumps(result['usage'], indent=2)}")

except Exception as e:
print(f"Test failed: {e}")
Expand All @@ -300,8 +370,27 @@ async def main(config_path: str, prompt: Optional[str] = None):
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("config", help="Path to config file")
parser.add_argument("--wallet", action="store_true", help="Run in wallet-only mode")
parser.add_argument("--inference", action="store_true", help="Run in inference-only mode")
parser.add_argument("--url", help="Override inference URL from config")
parser.add_argument("--key", help="Override auth key from config")
parser.add_argument("prompt", nargs="*", help="Optional prompt to override config")
args = parser.parse_args()

if args.wallet and args.inference:
print("Cannot specify both --wallet and --inference")
exit(1)

if (args.url and not args.key) or (args.key and not args.url):
print("Must specify both --url and --key together")
exit(1)

prompt = " ".join(args.prompt) if args.prompt else None
asyncio.run(main(args.config, prompt))
asyncio.run(main(
args.config,
args.wallet,
args.inference,
args.url,
args.key,
prompt
))

0 comments on commit 46c5720

Please sign in to comment.