Skip to content

Commit

Permalink
package cli
Browse files Browse the repository at this point in the history
  • Loading branch information
konstantinosKokos committed Jan 7, 2025
1 parent 4bde1da commit 4d34b7a
Show file tree
Hide file tree
Showing 9 changed files with 151 additions and 68 deletions.
27 changes: 0 additions & 27 deletions api/query.py

This file was deleted.

41 changes: 0 additions & 41 deletions api/serve.py

This file was deleted.

5 changes: 5 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,9 @@
'requests'
],
python_requires='>=3.11',
entry_points={
'console_scripts': [
'agda-quill=quill.api.cli:main'
]
}
)
Empty file added src/quill/api/__init__.py
Empty file.
16 changes: 16 additions & 0 deletions src/quill/api/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
def main(
    files: list[str],
    host: str,
    port: int
):
    """Send a batch of Agda export files to the server's /cache endpoint.

    Prints the server's confirmation message on success, or the status code
    and error payload on failure.
    """
    import requests
    from .schemas import CachePayload, read_json

    endpoint = f'http://{host}:{port}/cache'
    body = CachePayload(file_jsons=[read_json(path) for path in files])
    resp = requests.post(url=endpoint, json=body.model_dump())
    if resp.status_code != 200:
        print(f'Received status code {resp.status_code}')
        print(resp.json()['error'])
        return
    print(resp.json()['message'])
53 changes: 53 additions & 0 deletions src/quill/api/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from argparse import ArgumentParser

def main():
    """Entry point for the `agda-quill` console script.

    Builds the argument parser with three subcommands (serve / query / cache)
    and dispatches to the matching implementation module.
    """
    parser = ArgumentParser(prog='quill')
    commands = parser.add_subparsers(dest='command', required=True)

    sp = commands.add_parser('serve', help='Start the model server')
    sp.add_argument('-config', type=str, required=True, help="Path to the model config file")
    sp.add_argument('-weights', type=str, required=True, help="Path to the model weights file")
    sp.add_argument('--device', type=str, choices=('cuda', 'cpu'), default='cpu', help="Device to run inference on")
    sp.add_argument('--host', type=str, default='127.0.0.1', help="Server host address")
    sp.add_argument('--port', type=int, default=5000, help="Server port")

    qp = commands.add_parser('query', help='Query the model with an Agda export file')
    qp.add_argument('-file', type=str, required=True, help='Path to an Agda json export')
    qp.add_argument('--host', type=str, default='127.0.0.1', help='Server host address')
    qp.add_argument('--port', type=int, default=5000, help='Server port')
    qp.add_argument('--use_cache', action='store_true', help='Suggest lemmas outside the current scope')

    cp = commands.add_parser('cache', help='Cache lemmas from a collection of files')
    cp.add_argument('-files', type=str, nargs='*', required=True, help='Paths to Agda export files')
    cp.add_argument('--host', type=str, default='127.0.0.1', help='Server host address')
    cp.add_argument('--port', type=int, default=5000, help='Server port')

    args = parser.parse_args()

    # imports are deferred to dispatch time so that e.g. flask is only
    # required when actually serving
    if args.command == 'serve':
        from .serve import main as run_serve
        run_serve(
            config_path=args.config,
            weight_path=args.weights,
            device=args.device,
            host=args.host,
            port=args.port
        )
    elif args.command == 'query':
        from .query import main as run_query
        run_query(
            file=args.file,
            host=args.host,
            port=args.port,
            use_cache=args.use_cache
        )
    elif args.command == 'cache':
        from .cache import main as run_cache
        run_cache(
            files=args.files,
            host=args.host,
            port=args.port,
        )
    else:
        raise ValueError(f'Unrecognized command {args.command}')
18 changes: 18 additions & 0 deletions src/quill/api/query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
def main(
    file: str,
    host: str,
    port: int,
    use_cache: bool
):
    """POST a single Agda json export to the server's /predict endpoint.

    Prints the server's prediction payload on success, or the status code
    and raw response body on failure.

    :param file: path to an Agda json export file
    :param host: server host address
    :param port: server port
    :param use_cache: if True, also suggest lemmas from the server's cache
    """
    import requests
    # dropped dead `from json import load`: parsing is done by read_json
    from .schemas import PredictPayload, read_json

    url = f'http://{host}:{port}/predict'
    payload = PredictPayload(file_json=read_json(file), use_cache=use_cache)
    response = requests.post(url=url, json=payload.model_dump())
    if response.status_code == 200:
        print(response.json())
    else:
        print(f'Received status code {response.status_code}')
        # raw content rather than .json(): error bodies may not be JSON
        print(response.content)
15 changes: 15 additions & 0 deletions src/quill/api/schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from json import load
from pydantic import BaseModel, ValidationError

class CachePayload(BaseModel):
    """Request body for the /cache endpoint: a batch of parsed Agda export files."""
    # each entry is the decoded JSON content of one Agda export file
    file_jsons: list[dict]


class PredictPayload(BaseModel):
    """Request body for the /predict endpoint."""
    # decoded JSON content of a single Agda export file
    file_json: dict
    # whether the server should also suggest lemmas from its precomputed cache
    use_cache: bool


def read_json(file: str):
    """Load and return the decoded JSON content of *file*.

    The file is opened with an explicit UTF-8 encoding so behaviour does not
    depend on the platform's default locale encoding.
    """
    with open(file, 'r', encoding='utf-8') as f:
        return load(f)
44 changes: 44 additions & 0 deletions src/quill/api/serve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
def main(
    config_path: str,
    weight_path: str,
    device: str,
    host: str,
    port: int,
):
    """Start the Flask inference server.

    Loads the model once at startup, then exposes two endpoints:
    POST /cache   — (re)populate or empty the server-side lemma cache
    POST /predict — premise selection on a single Agda export file

    :param config_path: path to the model config json
    :param weight_path: path to the model weights file
    :param device: device to run inference on ('cuda' or 'cpu')
    :param host: address to bind the server to
    :param port: port to bind the server to
    """
    from flask import Flask, request, jsonify

    from .schemas import PredictPayload, CachePayload, ValidationError, read_json
    from ..nn.inference import Inferer
    from ..data.agda.reader import parse_data

    print(f'Initializing model from {config_path}...')
    inferer = Inferer(model_config=read_json(config_path), cast_to=device).eval()
    print(f'Loading weights from {weight_path}...')
    inferer.load(path=weight_path, strict=True, map_location=device)
    print('Done.')

    app = Flask(__name__)

    @app.route('/cache', methods=['POST'])
    def precompute():
        try:
            payload = CachePayload(**request.json)
        except ValidationError as e:
            return jsonify({'error': e.errors()}), 400
        if len(payload.file_jsons) > 0:
            inferer.precompute(files=[parse_data(f, validate=True) for f in payload.file_jsons])
            return jsonify({'message': f'Cache updated with {len(inferer.cache)} lemmas'}), 200
        # an empty batch is interpreted as a request to clear the cache
        inferer.cache = []
        return jsonify({'message': 'Cache emptied'}), 200

    @app.route('/predict', methods=['POST'])
    def predict():
        try:
            payload = PredictPayload(**request.json)
        except ValidationError as e:
            return jsonify({'error': e.errors()}), 400
        output = inferer.select_premises(file=parse_data(payload.file_json, validate=True), use_cache=payload.use_cache)
        return jsonify(output)

    # announce BEFORE app.run(): app.run() blocks until the server shuts down,
    # so printing after it (as the original did) never showed the message
    # while the server was actually serving
    print(f'Serving on {host}:{port}')
    app.run(debug=False, host=host, port=port)

0 comments on commit 4d34b7a

Please sign in to comment.