Skip to content

boncheolgu/tflite-rs

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

Rust

Rust bindings for TensorFlow Lite

This crate provides Rust bindings to the TensorFlow Lite APIs. Please read the API documentation on docs.rs.

Using the interpreter from a model file

The following example shows how to use the TensorFlow Lite interpreter with a TensorFlow Lite FlatBuffer model file. It also demonstrates how to run inference on input data.

use std::fs::{self, File};
use std::io::Read;

use tflite::ops::builtin::BuiltinOpResolver;
use tflite::{FlatBufferModel, InterpreterBuilder, Result};

/// Runs the quantized MNIST classifier in `model` over ten sample digits and
/// checks that each one is classified as expected.
///
/// The interpreter is built with the builtin op resolver. The model must have
/// exactly one input tensor of shape `[1, 28, 28, 1]` and exactly one output
/// tensor of shape `[1, 10]`. Samples are read one after another from
/// `data/mnist10.bin` directly into the input tensor. For sample `i`, the
/// index of the highest output score must equal `i`, so the file is assumed
/// to hold the digits 0..=9 in that order.
///
/// # Errors
///
/// Propagates any error from building the interpreter, allocating tensors,
/// accessing tensor data, running inference, or opening/reading the sample
/// file.
///
/// # Panics
///
/// Panics if any of the shape/count assertions above fail, or if a sample is
/// misclassified.
fn test_mnist(model: &FlatBufferModel) -> Result<()> {
    let resolver = BuiltinOpResolver::default();

    let builder = InterpreterBuilder::new(model, &resolver)?;
    let mut interpreter = builder.build()?;

    interpreter.allocate_tensors()?;

    // Owned copies of the index lists, so no borrow of `interpreter` is held
    // while it is used mutably further down.
    let inputs = interpreter.inputs().to_vec();
    assert_eq!(inputs.len(), 1);
    let input_index = inputs[0];

    let outputs = interpreter.outputs().to_vec();
    assert_eq!(outputs.len(), 1);
    let output_index = outputs[0];

    let input_tensor = interpreter
        .tensor_info(input_index)
        .expect("input tensor info must be available");
    assert_eq!(input_tensor.dims, vec![1, 28, 28, 1]);

    let output_tensor = interpreter
        .tensor_info(output_index)
        .expect("output tensor info must be available");
    assert_eq!(output_tensor.dims, vec![1, 10]);

    let mut input_file = File::open("data/mnist10.bin")?;
    for i in 0..10 {
        // `read_exact` fills the whole input tensor with the next sample.
        input_file.read_exact(interpreter.tensor_data_mut(input_index)?)?;

        interpreter.invoke()?;

        let output: &[u8] = interpreter.tensor_data(output_index)?;
        // Class with the highest score. On ties the last maximum wins, which
        // matches the previous `max_by` formulation.
        let guess = output
            .iter()
            .enumerate()
            .max_by_key(|&(_, &score)| score)
            .map(|(class, _)| class)
            .expect("output tensor must not be empty");

        println!("{}: {:?}", i, output);
        assert_eq!(i, guess);
    }
    Ok(())
}

#[test]
fn mobilenetv1_mnist() -> Result<()> {
    test_mnist(&FlatBufferModel::build_from_file("data/MNISTnet_uint8_quant.tflite")?)?;

    let buf = fs::read("data/MNISTnet_uint8_quant.tflite")?;
    test_mnist(&FlatBufferModel::build_from_buffer(buf)?)
}

#[test]
fn mobilenetv2_mnist() -> Result<()> {
    test_mnist(&FlatBufferModel::build_from_file("data/MNISTnet_v2_uint8_quant.tflite")?)?;

    let buf = fs::read("data/MNISTnet_v2_uint8_quant.tflite")?;
    test_mnist(&FlatBufferModel::build_from_buffer(buf)?)
}

Using the FlatBuffers model APIs

This crate also provides a limited set of FlatBuffers model APIs.

use tflite::model::stl::vector::{VectorInsert, VectorErase, VectorSlice};
use tflite::model::{BuiltinOperator, BuiltinOptions, Model, SoftmaxOptionsT};

/// Loads the quantized MNIST model through the FlatBuffers object API and
/// checks, in order:
/// - the top-level metadata (version, vector sizes, description);
/// - the list of operator codes;
/// - the single subgraph's tensor/operator counts and I/O indices;
/// - the SOFTMAX operator's inputs, outputs, and options (`beta == 1.0`).
#[test]
fn flatbuffer_model_apis_inspect() {
    let model = Model::from_file("data/MNISTnet_uint8_quant.tflite").unwrap();
    let codes = &model.operator_codes;

    // Top-level model metadata.
    assert_eq!(model.version, 3);
    assert_eq!(codes.size(), 5);
    assert_eq!(model.subgraphs.size(), 1);
    assert_eq!(model.buffers.size(), 24);
    let description = model.description.c_str().to_string_lossy();
    assert_eq!(description, "TOCO Converted.");

    // Operator codes, checked first individually and then as a whole list.
    assert_eq!(
        codes[0].builtin_code,
        BuiltinOperator::BuiltinOperator_AVERAGE_POOL_2D
    );
    let builtin_codes: Vec<_> = codes.iter().map(|code| code.builtin_code).collect();
    assert_eq!(
        builtin_codes,
        [
            BuiltinOperator::BuiltinOperator_AVERAGE_POOL_2D,
            BuiltinOperator::BuiltinOperator_CONV_2D,
            BuiltinOperator::BuiltinOperator_DEPTHWISE_CONV_2D,
            BuiltinOperator::BuiltinOperator_SOFTMAX,
            BuiltinOperator::BuiltinOperator_RESHAPE,
        ]
    );

    // Shape of the only subgraph.
    let graph = &model.subgraphs[0];
    assert_eq!(graph.tensors.size(), 23);
    assert_eq!(graph.operators.size(), 9);
    assert_eq!(graph.inputs.as_slice(), &[22]);
    assert_eq!(graph.outputs.as_slice(), &[21]);

    // The first operator whose opcode resolves to SOFTMAX.
    let softmax_op = graph
        .operators
        .iter()
        .find(|op| {
            codes[op.opcode_index as usize].builtin_code
                == BuiltinOperator::BuiltinOperator_SOFTMAX
        })
        .unwrap();

    assert_eq!(softmax_op.inputs.as_slice(), &[4]);
    assert_eq!(softmax_op.outputs.as_slice(), &[21]);
    assert_eq!(
        softmax_op.builtin_options.type_,
        BuiltinOptions::BuiltinOptions_SoftmaxOptions
    );

    let options: &SoftmaxOptionsT = softmax_op.builtin_options.as_ref();
    assert_eq!(options.beta, 1.);
}

/// Mutates a loaded model through the FlatBuffers object API, serializes it
/// with `to_buffer`, reloads it with `Model::from_buffer`, and checks that
/// every mutation survived the round trip.
///
/// Mutations applied:
/// - the version is set to 2;
/// - operator code 4 is removed (5 -> 4 codes);
/// - the last two buffers are removed (24 -> 22 buffers);
/// - the description is replaced with "flatbuffer";
/// - the subgraph's only input is removed and its outputs become `[1, 2, 3, 4]`.
#[test]
fn flatbuffer_model_apis_mutate() {
    let mut model = Model::from_file("data/MNISTnet_uint8_quant.tflite").unwrap();
    model.version = 2;
    model.operator_codes.erase(4);
    // Remove the last two of the original 24 buffers (indices 22 and 23).
    // Erase the higher index first. Erasing 22 first would shrink the vector
    // to 23 elements and leave index 23 out of range for the second erase.
    model.buffers.erase(23);
    model.buffers.erase(22);
    // Fully qualified so the snippet compiles without an extra `use`.
    model
        .description
        .assign(std::ffi::CString::new("flatbuffer").unwrap());

    {
        // Scoped so the mutable borrow of `model` ends before serialization.
        let subgraph = &mut model.subgraphs[0];
        subgraph.inputs.erase(0);
        subgraph.outputs.assign(vec![1, 2, 3, 4]);
    }

    let model_buffer = model.to_buffer();
    let model = Model::from_buffer(&model_buffer);
    assert_eq!(model.version, 2);
    assert_eq!(model.operator_codes.size(), 4);
    assert_eq!(model.subgraphs.size(), 1);
    assert_eq!(model.buffers.size(), 22);
    assert_eq!(model.description.c_str().to_string_lossy(), "flatbuffer");

    let subgraph = &model.subgraphs[0];
    assert_eq!(subgraph.tensors.size(), 23);
    assert_eq!(subgraph.operators.size(), 9);
    assert!(subgraph.inputs.as_slice().is_empty());
    assert_eq!(subgraph.outputs.as_slice(), &[1, 2, 3, 4]);
}

About

No description, website, or topics provided.

Resources

License

Apache-2.0, MIT licenses found

Licenses found

Apache-2.0
LICENSE-APACHE
MIT
LICENSE-MIT

Stars

Watchers

Forks

Packages

No packages published

Languages