This repo contains a simple gRPC server implementation for serving MNIST digit classification requests using a Rust stack. The purpose of this project is personal learning and experimentation with tonic, a Rust implementation of gRPC, and candle, an ML/tensor framework in Rust.
One nice feature of using Rust for ML inference is lightweight deployments. While Python deployments with deep learning frameworks like PyTorch often result in container sizes of multiple GBs, with candle the release binary of this server is only ~6 MB (without model weights).
This project supports two (rather trivial) neural network architectures for MNIST classification:

- Multi-Layer Perceptron (MLP)
  - 3 fully connected layers: 784 → 128 → 64 → 10
  - Simple feedforward network
  - Good baseline performance
- Convolutional Neural Network (ConvNet)
  - 2 convolutional layers (1→32→64 channels) with ReLU and max pooling
  - 2 fully connected layers: 3136 → 128 → 10
  - Better feature extraction and higher accuracy

These models are defined in the `mnist` sub-crate.
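For reference, here is a minimal sketch of how the MLP could be expressed with candle's `candle_nn` building blocks. The struct name `Mlp` and the `fc1`/`fc2`/`fc3` weight prefixes are illustrative assumptions, not necessarily what the `mnist` sub-crate uses:

```rust
use candle_core::{Result, Tensor};
use candle_nn::{linear, Linear, Module, VarBuilder};

// Illustrative MLP: 784 -> 128 -> 64 -> 10.
struct Mlp {
    fc1: Linear,
    fc2: Linear,
    fc3: Linear,
}

impl Mlp {
    fn new(vb: VarBuilder) -> Result<Self> {
        Ok(Self {
            fc1: linear(784, 128, vb.pp("fc1"))?,
            fc2: linear(128, 64, vb.pp("fc2"))?,
            fc3: linear(64, 10, vb.pp("fc3"))?,
        })
    }
}

impl Module for Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // Expects a flattened (batch, 784) input and returns raw logits.
        let xs = self.fc1.forward(xs)?.relu()?;
        let xs = self.fc2.forward(&xs)?.relu()?;
        self.fc3.forward(&xs)
    }
}
```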
- Rust (latest stable version)
- Python 3.8+ (for training)
- uv (Python package manager)
Model training is done in Python with PyTorch:

- Navigate to the training directory: `cd training`
- Install Python dependencies: `uv sync`
- Train the model and save the weights: `uv run python train.py --output ../models/mnist_convnet.safetensors`
The training script will:
- Download the MNIST dataset automatically
- Train a ConvNet for 3 epochs
- Display training progress and final test accuracy
- Save the model weights in SafeTensors format
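On the Rust side, candle can memory-map such a SafeTensors file into a `VarBuilder` that the model constructors consume. A hedged sketch (the exact loading code in this repo may differ):

```rust
use candle_core::{DType, Device, Result};
use candle_nn::VarBuilder;

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Safety: the weights file must not be modified while it is memory-mapped.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(
            &["models/mnist_convnet.safetensors"],
            DType::F32,
            &device,
        )?
    };
    // `vb` can then be handed to the model constructor, e.g. a ConvNet or MLP.
    let _ = vb;
    Ok(())
}
```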
- Build the server: `cargo build --release`
- Start the gRPC server: `cargo run --release --bin grpc-server -- --model-architecture conv --model-weights models/mnist_convnet.safetensors`

The server will start on `[::1]:50051` by default. You can check the other available CLI args with `--help`.
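For orientation, the tonic service wiring looks roughly like the sketch below. The generated `mnist` module and the `PredictRequest`/`PredictResponse` field names are assumptions based on `proto/mnist.proto` and the request/response examples further down, and the model call is left as a placeholder:

```rust
use tonic::{transport::Server, Request, Response, Status};

// Code generated by tonic-build from proto/mnist.proto (package `mnist`).
pub mod mnist {
    tonic::include_proto!("mnist");
}
use mnist::mnist_server::{Mnist, MnistServer};
use mnist::{PredictRequest, PredictResponse};

#[derive(Default)]
struct MnistService;

#[tonic::async_trait]
impl Mnist for MnistService {
    async fn predict(
        &self,
        request: Request<PredictRequest>,
    ) -> Result<Response<PredictResponse>, Status> {
        let _image_bytes = request.into_inner().data;
        // Decode the image and run the candle model here; placeholder values for illustration.
        Ok(Response::new(PredictResponse {
            label: 0,
            probabilities: vec![0.1; 10],
        }))
    }
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let addr = "[::1]:50051".parse()?;
    Server::builder()
        .add_service(MnistServer::new(MnistService))
        .serve(addr)
        .await?;
    Ok(())
}
```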
As the protocol expects images to be sent as raw bytes, one can base64-encode an image and create a request in JSON format:

    echo '{"data": "'$(base64 -w 0 -i ~/Desktop/four.png)'"}' > test_request.json
Using grpcurl, such requests can be sent to the server:

    grpcurl -plaintext -proto ./proto/mnist.proto \
      -d @ \
      '[::1]:50051' mnist.Mnist.Predict \
      < test_request.json
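Alternatively, the same request can be sent from Rust with a small tonic client (again assuming the generated `mnist` module; the image path is illustrative):

```rust
use tonic::Request;

pub mod mnist {
    tonic::include_proto!("mnist");
}
use mnist::{mnist_client::MnistClient, PredictRequest};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Connect to the server started above.
    let mut client = MnistClient::connect("http://[::1]:50051").await?;
    // Raw image bytes go straight into the request; base64 is only needed for grpcurl's JSON encoding.
    let data = std::fs::read("four.png")?;
    let reply = client
        .predict(Request::new(PredictRequest { data }))
        .await?
        .into_inner();
    println!("label={} probabilities={:?}", reply.label, reply.probabilities);
    Ok(())
}
```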
This should respond with something like:
    {
      "label": 4,
      "probabilities": [0.001, 0.002, 0.003, 0.004, 0.985, 0.002, 0.001, 0.001, 0.001, 0.000]
    }
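The probabilities are the softmaxed model outputs; in candle, turning logits into the vector returned in the response looks roughly like this (function name and tensor shape are assumptions):

```rust
use candle_core::{D, Result, Tensor};

// Convert a (1, 10) logits tensor into the probability vector for the response.
fn to_probabilities(logits: &Tensor) -> Result<Vec<f32>> {
    let probs = candle_nn::ops::softmax(logits, D::Minus1)?;
    probs.squeeze(0)?.to_vec1::<f32>()
}
```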