JNapoli/aether


Aether

This repository contains my Aether Biomachines take-home assignment: a neural network training pipeline for MNIST handwritten digit classification, implemented in PyTorch.

Overview

This repo contains scripts for training a convolutional neural network on the MNIST dataset and for launching an inference server that serves the resulting model. The aether package in this directory contains classes that simplify the training pipeline.

Installation

The following steps have been verified to be reproducible on macOS Catalina. The code requires Python 3.7.5 and can be set up as follows:

  1. Clone this repository:
     git clone https://github.com/JNapoli/aether.git
     cd aether/
  2. Create and activate a Python environment (Miniconda is recommended), installing the dependencies listed in requirements-conda.txt (newer conda versions use conda activate aether in place of source activate aether):
     conda create --name aether python=3.7.5 --file requirements-conda.txt
     source activate aether
  3. Source the env.sh file in the repository root. This adds the aether package to your PYTHONPATH environment variable:
     source env.sh
  4. Confirm that the PYTHONPATH environment variable was modified correctly:
     echo $PYTHONPATH

You should be good to go!
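As a complementary check from inside Python, the snippet below reports whether the aether package can actually be found on the import path. It is a minimal standard-library sketch; the helper name package_on_path is hypothetical and not part of this repo.

```python
import importlib.util
import os


def package_on_path(name: str) -> bool:
    """Return True if the top-level package `name` is importable from sys.path.

    Hypothetical helper, not part of the aether repo.
    """
    # find_spec returns None (rather than raising) for a missing top-level name.
    return importlib.util.find_spec(name) is not None


if __name__ == "__main__":
    # Entries in PYTHONPATH are added to sys.path when the interpreter starts.
    print("PYTHONPATH =", os.environ.get("PYTHONPATH", "<unset>"))
    print("aether importable:", package_on_path("aether"))
```

If this prints False after sourcing env.sh, the shell that runs Python is probably not the one in which env.sh was sourced.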

Usage

Re-fitting the model

A pre-trained model can be found in ./models/pretrained/. To fit a new model, run train_model.py in the bin directory:

python train_model.py /path/to/output/model.pt /path/for/storing/datasets/ \
       /path/to/directory/for/saving/output/data/ -epochs 10

For more details, you can run:

python train_model.py -h

The first command above trains the model for 10 epochs on the MNIST dataset and writes the trained model and output data to the specified locations. Please use absolute paths for all files and directories. After training, the script prints the accuracy of the resulting model on the test set.
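The training command shown earlier takes three positional arguments (the output model path, the dataset directory, and the output-data directory) followed by an optional -epochs flag. A command-line interface of that shape could be declared with argparse roughly as below. This is a hypothetical sketch, not the actual contents of train_model.py: the destination names (model_path, data_dir, output_dir) and the default epoch count are assumptions, and python train_model.py -h remains the authoritative reference. The absolute-path check mirrors the request to use full paths.

```python
import argparse
import os


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical reconstruction of the train_model.py CLI; names are assumptions.
    parser = argparse.ArgumentParser(description="Train a CNN on MNIST.")
    parser.add_argument("model_path", help="where to write the trained model (.pt)")
    parser.add_argument("data_dir", help="directory for downloading/storing MNIST")
    parser.add_argument("output_dir", help="directory for saving output data")
    # Single-dash long option, matching the "-epochs 10" usage; default is assumed.
    parser.add_argument("-epochs", type=int, default=10,
                        help="number of training epochs")
    return parser


def parse_args(argv=None) -> argparse.Namespace:
    parser = build_parser()
    args = parser.parse_args(argv)
    # Full paths are requested, so reject relative ones early.
    for path in (args.model_path, args.data_dir, args.output_dir):
        if not os.path.isabs(path):
            parser.error(f"expected an absolute path, got {path!r}")  # exits with status 2
    return args
```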

The pre-trained model included here was trained for 50 epochs on the torchvision MNIST training dataset and achieves 98% accuracy on the MNIST test set. This far exceeds the 10% accuracy expected from random guessing over the ten digit classes, and because the test set is held out from training, it indicates that the model generalizes well rather than pathologically overfitting the training data.
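Test accuracy is the fraction of test images whose predicted digit matches the label, and the 10% baseline is the expected accuracy of uniform random guessing over ten classes (1/10). A minimal, framework-independent sketch of both quantities follows; the function names are illustrative and not taken from the repo, and the toy predictions are made up for the example.

```python
from typing import Sequence


def accuracy(predicted: Sequence[int], labels: Sequence[int]) -> float:
    """Fraction of predictions that match their labels."""
    if len(predicted) != len(labels) or not labels:
        raise ValueError("need two non-empty sequences of equal length")
    correct = sum(p == y for p, y in zip(predicted, labels))
    return correct / len(labels)


def chance_accuracy(num_classes: int = 10) -> float:
    """Expected accuracy of uniform random guessing over num_classes classes."""
    return 1.0 / num_classes


if __name__ == "__main__":
    preds = [7, 2, 1, 0, 4]  # toy predictions
    truth = [7, 2, 1, 0, 9]  # toy labels: 4 of 5 match
    print(accuracy(preds, truth))  # 0.8
    print(chance_accuracy())       # 0.1
```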

Serving the model

The resulting model can be served using the Flask app in app.py, for example:

python app.py /repo_root_dir/models/pretrained/pretrained.pt

Running this command launches the app and prints the address it is listening on (on my machine, for example, it reports Running on http://127.0.0.1:5000/).
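Before sending requests, it can help to confirm that something is actually listening at the reported host and port. The sketch below does this with a plain TCP connection attempt from the standard library. The helper name wait_for_server is hypothetical, and the default of 127.0.0.1:5000 is taken from the example output above; a successful connection only shows that a process is listening, not that the model loaded correctly.

```python
import socket
import time


def wait_for_server(host: str = "127.0.0.1", port: int = 5000,
                    timeout: float = 10.0, interval: float = 0.25) -> bool:
    """Poll until a TCP connection to (host, port) succeeds or `timeout` expires."""
    deadline = time.monotonic() + timeout
    while True:
        try:
            # create_connection raises OSError (e.g. ConnectionRefusedError)
            # when nothing is listening on the port.
            with socket.create_connection((host, port), timeout=interval):
                return True
        except OSError:
            if time.monotonic() >= deadline:
                return False
            time.sleep(interval)
```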

Inference

Inference can now be performed on new images by sending POST requests to the /predict endpoint at the address above. For example, the following short Python script:

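import requests

# Upload the image as multipart/form-data under the 'file' field;
# the with-block closes the file once the request has been sent.
with open('./mnist_image.jpg', 'rb') as image_file:
    response = requests.post(
        'http://127.0.0.1:5000/predict',
        files={'file': image_file}
    )
response.raise_for_status()  # fail loudly on HTTP errors

result = response.json()  # parse the JSON body once
print('Result:')
print(result['class_name'])
print(result['class_prob'])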

This script runs inference on the image mnist_image.jpg and prints the result. response.json() returns a dictionary containing both the predicted class of the digit in the image and the probability associated with that prediction.
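On the server side, a dictionary like this is commonly produced by turning the network's ten output scores into probabilities with a softmax and picking the most probable digit. The stdlib-only sketch below shows that post-processing step under stated assumptions: that the model emits one raw score (logit) per digit 0 to 9, and that class_name is the digit as a string. Only the two key names come from this README; app.py may compute them differently.

```python
import math
from typing import Dict, List, Sequence, Union


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def to_response(logits: Sequence[float]) -> Dict[str, Union[str, float]]:
    """Build a {'class_name', 'class_prob'} dict from per-digit logits.

    Assumes logits[i] is the score for digit i (an assumption, not confirmed).
    """
    probs = softmax(logits)
    # Index of the most probable class; ties resolve to the lowest digit.
    best = max(range(len(probs)), key=probs.__getitem__)
    return {"class_name": str(best), "class_prob": probs[best]}
```

With PyTorch tensors, torch.softmax followed by torch.max would perform the same step.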

In this repo, run_inference.py is a script that runs inference over a collection of jpg images contained in a directory. Running:

python run_inference.py /path/to/MNIST_jpeg_images/ http://127.0.0.1:5000/predict \
       /path/to/result.csv

will run inference on each jpg image in the provided directory and write a CSV file containing the image name, the predicted class, and the probability. On my machine, inference was benchmarked at 96 images per second. Further performance gains may be possible by serving the app with a production-grade WSGI server instead of Flask's built-in development server.
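The batch workflow described above (one request per JPEG, then a CSV of image name, predicted class, and probability) can be sketched as follows. This is not the repo's run_inference.py: the function names, the CSV header row, and the injectable predict callable are assumptions, chosen so the HTTP call can be swapped out, for example in tests. http_predict reuses the request from the single-image example, with requests imported inside the function so the rest of the module runs without that package. The returned rate is the same kind of images-per-second figure quoted above.

```python
import csv
import time
from pathlib import Path
from typing import Callable, Dict


def http_predict(image_path: Path, url: str) -> Dict:
    """POST one image to the /predict endpoint and return the JSON payload."""
    import requests  # imported lazily; only needed for real HTTP calls

    with open(image_path, "rb") as f:
        response = requests.post(url, files={"file": f})
    response.raise_for_status()
    return response.json()


def run_batch(image_dir: str, out_csv: str,
              predict: Callable[[Path], Dict]) -> float:
    """Predict every .jpg in image_dir, write a CSV, and return images per second."""
    images = sorted(Path(image_dir).glob("*.jpg"))
    start = time.perf_counter()
    with open(out_csv, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["image", "class_name", "class_prob"])  # assumed header
        for path in images:
            result = predict(path)
            writer.writerow([path.name, result["class_name"], result["class_prob"]])
    elapsed = time.perf_counter() - start
    return len(images) / elapsed if elapsed > 0 else 0.0
```

Usage against a running server: run_batch("/path/to/MNIST_jpeg_images/", "/path/to/result.csv", lambda p: http_predict(p, "http://127.0.0.1:5000/predict")).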

For convenience, a zipped directory is included in this repo which contains MNIST images that can be used to test inference. Inference has only been tested for jpg images, though support for other file types may be added later.
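A minimal standard-library sketch for unpacking the sample archive and listing the JPEG files (the only type tested so far) is shown below. The function name extract_jpegs is hypothetical, and the archive path is left to the caller since its exact name is not given here.

```python
import zipfile
from pathlib import Path
from typing import List


def extract_jpegs(archive: str, dest: str) -> List[Path]:
    """Extract `archive` into `dest` and return the extracted .jpg/.jpeg files.

    Hypothetical helper; `dest` is created if it does not exist.
    """
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(dest)
    # Only JPEGs have been tested with the server, so skip other file types.
    return sorted(p for p in Path(dest).rglob("*")
                  if p.is_file() and p.suffix.lower() in {".jpg", ".jpeg"})
```

If the archive contains a top-level folder, the images end up in a subdirectory of dest; pass that subdirectory to run_inference.py.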

Contributing

Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change.

License

MIT

About

Take-home assignment for Aetherbio
