Commit 4d34b7a (1 parent: 4bde1da)
Showing 9 changed files with 151 additions and 68 deletions.
This file was deleted.
This file was deleted.
Empty file.
@@ -0,0 +1,16 @@
def main(
    files: list[str],
    host: str,
    port: int
):
    import requests
    from .schemas import CachePayload, read_json

    url = f'http://{host}:{port}/cache'
    payload = CachePayload(file_jsons=list(map(read_json, files)))
    response = requests.post(url=url, json=payload.model_dump())
    if response.status_code == 200:
        print(response.json()['message'])
    else:
        print(f'Received status code {response.status_code}')
        print(response.json()['error'])
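For illustration, a minimal sketch of the request this client builds, written out by hand with requests. The host and port are the CLI defaults, and the empty-list behaviour matches the /cache handler further down in this commit (an empty file_jsons list clears the server-side cache).

import requests

# Same request the cache client assembles; an empty list asks the server
# to drop its cached lemmas instead of adding to them.
response = requests.post('http://127.0.0.1:5000/cache', json={'file_jsons': []})
print(response.json()['message'])  # e.g. 'Cache emptied'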
@@ -0,0 +1,53 @@
from argparse import ArgumentParser

def main():
    parser = ArgumentParser(prog='quill')
    subparsers = parser.add_subparsers(dest='command', required=True)

    serve_parser = subparsers.add_parser('serve', help='Start the model server')
    serve_parser.add_argument('-config', type=str, required=True, help="Path to the model config file")
    serve_parser.add_argument('-weights', type=str, required=True, help="Path to the model weights file")
    serve_parser.add_argument('--device', type=str, choices=('cuda', 'cpu'), default='cpu', help="Device to run inference on")
    serve_parser.add_argument('--host', type=str, default='127.0.0.1', help="Server host address")
    serve_parser.add_argument('--port', type=int, default=5000, help="Server port")

    query_parser = subparsers.add_parser('query', help='Query the model with an Agda export file')
    query_parser.add_argument('-file', type=str, required=True, help='Path to an Agda json export')
    query_parser.add_argument('--host', type=str, default='127.0.0.1', help='Server host address')
    query_parser.add_argument('--port', type=int, default=5000, help='Server port')
    query_parser.add_argument('--use_cache', action='store_true', help='Suggest lemmas outside the current scope')

    cache_parser = subparsers.add_parser('cache', help='Cache lemmas from a collection of files')
    cache_parser.add_argument('-files', type=str, nargs='*', required=True, help='Paths to Agda export files')
    cache_parser.add_argument('--host', type=str, default='127.0.0.1', help='Server host address')
    cache_parser.add_argument('--port', type=int, default=5000, help='Server port')

    args = parser.parse_args()

    match args.command:
        case 'serve':
            from .serve import main
            main(
                config_path=args.config,
                weight_path=args.weights,
                device=args.device,
                host=args.host,
                port=args.port
            )
        case 'query':
            from .query import main
            main(
                file=args.file,
                host=args.host,
                port=args.port,
                use_cache=args.use_cache
            )
        case 'cache':
            from .cache import main
            main(
                files=args.files,
                host=args.host,
                port=args.port,
            )
        case _:
            raise ValueError(f'Unrecognized command {args.command}')
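For reference, the three sub-commands defined above would be invoked roughly as follows, assuming the package is exposed as a quill console script (the entry point is not part of this commit); all file paths are placeholders.

quill serve -config config.json -weights weights.pt --device cuda --host 127.0.0.1 --port 5000
quill query -file Export.json --use_cache
quill cache -files A.json B.json C.json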
@@ -0,0 +1,18 @@
def main(
    file: str,
    host: str,
    port: int,
    use_cache: bool
):
    import requests
    from .schemas import PredictPayload, read_json

    url = f'http://{host}:{port}/predict'
    payload = PredictPayload(file_json=read_json(file), use_cache=use_cache)
    response = requests.post(url=url, json=payload.model_dump())
    if response.status_code == 200:
        print(response.json())
    else:
        print(f'Received status code {response.status_code}')
        print(response.content)
@@ -0,0 +1,15 @@
from json import load
from pydantic import BaseModel, ValidationError  # ValidationError is re-exported for the server module


class CachePayload(BaseModel):
    file_jsons: list[dict]


class PredictPayload(BaseModel):
    file_json: dict
    use_cache: bool


def read_json(file):
    with open(file, 'r') as f:
        return load(f)
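A minimal sketch of how these payloads round-trip, assuming pydantic v2 (the clients above rely on model_dump); the export path is a placeholder.

from .schemas import PredictPayload, read_json

payload = PredictPayload(file_json=read_json('Export.json'), use_cache=True)
body = payload.model_dump()        # plain dict, suitable as a requests json= body
restored = PredictPayload(**body)  # re-validation; raises ValidationError on malformed input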
@@ -0,0 +1,44 @@
def main(
    config_path: str,
    weight_path: str,
    device: str,
    host: str,
    port: int,
):
    from flask import Flask, request, jsonify

    from .schemas import PredictPayload, CachePayload, ValidationError, read_json
    from ..nn.inference import Inferer
    from ..data.agda.reader import parse_data

    print(f'Initializing model from {config_path}...')
    inferer = Inferer(model_config=read_json(config_path), cast_to=device).eval()
    print(f'Loading weights from {weight_path}...')
    inferer.load(path=weight_path, strict=True, map_location=device)
    print('Done.')

    app = Flask(__name__)

    # POST /cache: precompute and cache lemma representations for the given
    # export files; an empty list clears the cache instead.
    @app.route('/cache', methods=['POST'])
    def precompute():
        try:
            payload = CachePayload(**request.json)
        except ValidationError as e:
            return jsonify({'error': e.errors()}), 400
        if len(payload.file_jsons) > 0:
            inferer.precompute(files=[parse_data(f, validate=True) for f in payload.file_jsons])
            return jsonify({'message': f'Cache updated with {len(inferer.cache)} lemmas'}), 200
        inferer.cache = []
        return jsonify({'message': 'Cache emptied'}), 200

    # POST /predict: run premise selection on a single export file.
    @app.route('/predict', methods=['POST'])
    def predict():
        try:
            payload = PredictPayload(**request.json)
        except ValidationError as e:
            return jsonify({'error': e.errors()}), 400
        output = inferer.select_premises(file=parse_data(payload.file_json, validate=True), use_cache=payload.use_cache)
        return jsonify(output)

    # announce before app.run, which blocks until the server shuts down
    print(f'Serving on {host}:{port}')
    app.run(debug=False, host=host, port=port)
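A minimal sketch of querying the running server directly with requests, equivalent to the query client earlier in this commit; the export path and the default host/port are assumptions taken from the CLI defaults.

import requests
from json import load

with open('Export.json', 'r') as f:  # placeholder Agda json export
    file_json = load(f)

response = requests.post(
    'http://127.0.0.1:5000/predict',
    json={'file_json': file_json, 'use_cache': False},
)
if response.status_code == 200:
    print(response.json())  # premise selection output
else:
    print(response.status_code, response.content)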