diff --git a/.github/workflows/rust-app.yml b/.github/workflows/rust-app.yml index 67cd3e1..bf2e3bb 100644 --- a/.github/workflows/rust-app.yml +++ b/.github/workflows/rust-app.yml @@ -15,4 +15,4 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - run: docker build -t headline-predictor-rs \ No newline at end of file + - run: docker build -t headline-predictor-rs . \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index a36dd69..3aa9631 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,7 @@ edition = "2021" [[bin]] name = "inference" -path = "src/main.rs" +path = "src/serving/serve.rs" [[bin]] name = "training" @@ -27,4 +27,5 @@ candle-optimisers = "0.3.2" log = "0.4.20" env_logger = "0.11.1" anyhow = "1.0.0" +regex = "1.10.3" polars ={ version = "0.37.0", features=["lazy"] } \ No newline at end of file diff --git a/data/mock_vocab.json b/data/mock_vocab.json new file mode 100644 index 0000000..451a7d9 --- /dev/null +++ b/data/mock_vocab.json @@ -0,0 +1,9 @@ +{ + "vocabulary": [ + "this", + "is", + "an", + "example", + "vocabulary" + ] +} \ No newline at end of file diff --git a/src/common/convert.rs b/src/common/convert.rs deleted file mode 100644 index 6b45a74..0000000 --- a/src/common/convert.rs +++ /dev/null @@ -1,20 +0,0 @@ -use super::exception::InferenceError; - -pub fn pad_vector(mut vector: Vec, max_padding: usize, pad_value: T) -> Vec { - match vector.len() { - len if len > max_padding => vector.truncate(max_padding), - len if len < max_padding => { - vector.extend(std::iter::repeat(pad_value).take(max_padding - len)) - } - _ => (), - } - vector -} - -pub fn convert_to_array(vec: Vec) -> Result<[T; N], InferenceError> -where - T: Default + Clone, -{ - vec.try_into() - .map_err(|_| InferenceError::ArrayConversionError("Could not convert to array.")) -} diff --git a/src/common/encode.rs b/src/common/encode.rs new file mode 100644 index 0000000..8505588 --- /dev/null +++ b/src/common/encode.rs @@ -0,0 +1,188 @@ +use std::collections::HashMap; +use super::MultiHotEncodeError; + +/// Creates a mapping from vocabulary words to their corresponding indices. +/// +/// This function takes a reference to a vector of strings `vocabulary` and creates a `HashMap` +/// where each unique word in the vocabulary is associated with its index (a `u32` value). The +/// resulting mapping is used for tasks such as converting text data into sequences of indices. +/// +/// The special token "" is inserted at index 0 in the mapping to represent any unknown words. +/// When using this mapping to convert words to indices, if a word is not found in the vocabulary, +/// it is mapped to the index of "" (index 0). +/// +/// # Arguments +/// +/// * `vocabulary`: A reference to a vector of strings representing the vocabulary. +/// +/// # Returns +/// +/// A `HashMap` mapping words to their corresponding indices. +pub fn create_vocabulary_to_index_mapping(vocabulary: &Vec) -> HashMap { + let mut vocab_to_index = HashMap::::new(); + + // Special token for any unknown words. In map_to_indices, unwrap_or(0) is used to map any unknown words to this token index. + vocab_to_index.insert("".to_string(), 0); + + for (index, word) in vocabulary.iter().enumerate() { + vocab_to_index.insert(word.clone(), index as u32 + 1); + } + + vocab_to_index +} + +/// Maps a vector of words to their corresponding indices using a mapping. +/// +/// This function takes a vector of strings `words` and a reference to a mapping `mapping`, which +/// associates words with their corresponding indices. 
It maps each word in the input vector to its +/// index using the provided mapping. If a word is not found in the mapping, it is mapped to 0 by +/// default. +/// +/// # Arguments +/// +/// * `words`: A vector of strings representing the words to be mapped to indices. +/// * `mapping`: A reference to a `HashMap` mapping words to their corresponding indices. +/// +/// # Returns +/// +/// A vector of `u32` values representing the indices of the input words based on the provided mapping. +pub fn map_to_indices(words: Vec, mapping: &HashMap) -> Vec { + words + .iter() + .map(|word| mapping.get(word).copied().unwrap_or(0)) + .collect() +} + +/// Creates mappings between class names and their corresponding indices. +/// +/// This function takes a vector of class names `class_names` and creates two mappings: +/// - A mapping from class names (strings) to their corresponding indices (u32). +/// - A reverse mapping from indices to class names. +/// +/// These mappings are often used in machine learning tasks where class labels need to be +/// represented as indices for model training and evaluation. +/// +/// # Arguments +/// +/// * `class_names`: A vector of strings representing the class names. +/// +/// # Returns +/// +/// A tuple containing two `HashMap` instances: +/// - The first `HashMap` maps class names (strings) to their corresponding indices (u32). +/// - The second `HashMap` maps indices to their corresponding class names. +pub fn create_class_mappings_from_class_names( + class_names: Vec, +) -> (HashMap, HashMap) { + let mut index_to_class: HashMap = HashMap::new(); + let mut class_to_index: HashMap = HashMap::new(); + + for (index, word) in class_names.iter().enumerate() { + let index_u32 = index as u32; + index_to_class.insert(index_u32, word.to_string()); + class_to_index.insert(word.to_string(), index_u32); + } + + (class_to_index, index_to_class) +} + +/// Creates mappings between class labels and their corresponding indices. +/// +/// This function takes a reference to a vector of class labels `labels` and creates two mappings: +/// - A mapping from class labels (strings) to their corresponding indices (u32). +/// - A reverse mapping from indices to class labels. +/// +/// Class labels are often represented as strings separated by '|' characters, allowing multiple +/// labels to be associated with a single data point. This function processes the labels and assigns +/// unique indices to each unique label encountered. +/// +/// # Arguments +/// +/// * `labels`: A reference to a vector of strings representing class labels. +/// +/// # Returns +/// +/// A tuple containing two `HashMap` instances: +/// - The first `HashMap` maps class labels (strings) to their corresponding indices (u32). +/// - The second `HashMap` maps indices to their corresponding class labels. +pub fn create_class_mapping_from_labels( + labels: &Vec, +) -> (HashMap, HashMap) { + let mut class_to_index: HashMap = HashMap::new(); + let mut index_to_class: HashMap = HashMap::new(); + + let mut n_classes = 0; + + for word in labels.iter() { + if word == "" { + continue; + } + let labels: Vec<&str> = word.split('|').collect(); + + for label in labels { + if !class_to_index.contains_key(label) { + class_to_index.insert(label.to_string(), n_classes); + index_to_class.insert(n_classes, label.to_string()); + n_classes += 1; + } + } + } + + (class_to_index, index_to_class) +} + +/// Converts a list of labels into a multi-hot encoding using the provided mapping. 
+/// +/// Given a list of labels and a mapping of class names to their respective indices, +/// this function generates a multi-hot encoding where each label corresponds to a binary vector. +/// If a label is not found in the mapping, an error is returned. +/// +/// # Arguments +/// +/// * `labels` - A vector of labels to be encoded. +/// * `class_to_index` - A reference to a `HashMap` containing class names as keys and their +/// corresponding indices as values. +/// +/// # Errors +/// +/// If a label in `labels` is not found in `class_to_index`, an `Err` variant is returned +/// with an associated error message indicating which label was not found. +/// +/// # Returns +/// +/// Returns a `Result` where: +/// - `Ok(encodings)` contains the multi-hot encodings as a vector of u32 values. +/// - `Err(err)` contains a `MultiHotEncodeError` with a description of the error. +pub fn multi_hot_encode( + labels: Vec, + class_to_index: &HashMap, +) -> Result, MultiHotEncodeError> { + let n_classes = class_to_index.len(); + let mut all_encodings: Vec = Vec::new(); // Initialize a single encodings vector + + for label in labels { + let mut label_encodings = vec![0u32; n_classes]; + log::debug!("Label: {:?}", label); + if label == "" { + log::debug!("Encoding: {:?}", label_encodings); + all_encodings.append(&mut label_encodings); + // Skip empty labels + continue; + } + + let label_classes: Vec<&str> = label.split('|').collect(); + for label_class in label_classes { + if let Some(&index) = class_to_index.get(&label_class.to_string()) { + label_encodings[index as usize] = 1u32; + } else { + return Err(MultiHotEncodeError::new(&format!("Label not found: {}", label_class))); + } + } + log::debug!("Encoding: {:?}", label_encodings); + + all_encodings.append(&mut label_encodings); + } + log::debug!("All encodings {:?}", all_encodings); + + Ok(all_encodings) +} diff --git a/src/common/exception.rs b/src/common/exception.rs index 263017f..fbc4072 100644 --- a/src/common/exception.rs +++ b/src/common/exception.rs @@ -1,18 +1,52 @@ use std::{error::Error, fmt}; + #[derive(Debug)] -pub enum InferenceError { - ArrayConversionError(&'static str), +pub struct MultiHotEncodeError { + pub message: String, } -impl fmt::Display for InferenceError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - InferenceError::ArrayConversionError(error_message) => { - write!(f, "Inference Error - {}", error_message) - } +impl MultiHotEncodeError { + pub fn new(message: &str) -> MultiHotEncodeError { + MultiHotEncodeError { + message: message.to_string(), } } } -impl Error for InferenceError {} +impl fmt::Display for MultiHotEncodeError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.message) + } +} + +impl Error for MultiHotEncodeError {} + +#[derive(Debug)] +pub enum VocabularyLoadError { + Io(std::io::Error), + Json(serde_json::Error), +} + +impl From for VocabularyLoadError { + fn from(err: std::io::Error) -> Self { + VocabularyLoadError::Io(err) + } +} + +impl From for VocabularyLoadError { + fn from(err: serde_json::Error) -> Self { + VocabularyLoadError::Json(err) + } +} + +impl std::error::Error for VocabularyLoadError {} + +impl std::fmt::Display for VocabularyLoadError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + VocabularyLoadError::Io(err) => write!(f, "IO error: {}", err), + VocabularyLoadError::Json(err) => write!(f, "JSON error: {}", err), + } + } +} \ No newline at end of file diff --git a/src/common/mapping.rs 
b/src/common/mapping.rs deleted file mode 100644 index 0ed2931..0000000 --- a/src/common/mapping.rs +++ /dev/null @@ -1,97 +0,0 @@ -use std::{collections::HashMap, hash::Hash, ops::Index}; - -pub fn create_vocabulary_to_index_mapping(vocabulary: &Vec) -> HashMap { - let mut vocab_to_index = HashMap::::new(); - - vocab_to_index.insert("".to_string(), 0); - - for (index, word) in vocabulary.iter().enumerate() { - vocab_to_index.insert(word.clone(), index as u32 + 1); - } - - vocab_to_index -} - -pub fn map_words_to_indices(words: Vec, mapping: &HashMap) -> Vec { - words - .iter() - .map(|word| mapping.get(word).copied().unwrap_or(0)) - .collect() -} - -pub fn create_class_mappings_from_class_names( - class_names: Vec, -) -> (HashMap, HashMap) { - let mut index_to_class: HashMap = HashMap::new(); - let mut class_to_index: HashMap = HashMap::new(); - - for (index, word) in class_names.iter().enumerate() { - index_to_class.insert(index, word.to_string()); - class_to_index.insert(word.to_string(), index); - } - - (index_to_class, class_to_index) -} - -/// Create a class mapping from the labels of the trainig set. -/// -/// This function assumes that labels is a vector of strings, where each string can represent either no classes (empty string), a single class, or multiple classes (delimited by comma). -/// -/// -pub fn create_class_mapping_from_labels( - labels: &Vec, -) -> (HashMap, HashMap) { - let mut class_to_index: HashMap = HashMap::new(); - let mut index_to_class: HashMap = HashMap::new(); - - let mut n_classes = 0; - - for word in labels.iter() { - if word == "" { - continue; - } - let labels: Vec<&str> = word.split('|').collect(); - - for label in labels { - if !class_to_index.contains_key(label) { - class_to_index.insert(label.to_string(), n_classes); - index_to_class.insert(n_classes, label.to_string()); - n_classes += 1; - } - } - } - - (class_to_index, index_to_class) -} - -pub fn multi_hot_encode(labels: Vec, class_to_index: &HashMap) -> Vec { - let n_classes = class_to_index.len(); - let mut all_encodings: Vec = Vec::new(); // Initialize a single encodings vector - - for label in labels { - let mut label_encodings = vec![0u32; n_classes]; - log::debug!("Label: {:?}", label); - if label == "" { - log::debug!("Encoding: {:?}", label_encodings); - all_encodings.append(&mut label_encodings); - continue; // Skip empty labels - } - - let label_classes: Vec<&str> = label.split('|').collect(); - let label_class_indices: Vec = label_classes - .iter() - .map(|label| { - *class_to_index.get(&label.to_string()).unwrap() // Default to 0 if label not found - }) - .collect(); - - for &index in &label_class_indices { - label_encodings[index as usize] = 1u32; - } - log::debug!("Encoding: {:?}", label_encodings); - - all_encodings.append(&mut label_encodings); - } - log::debug!("All encodings {:?}", all_encodings); - all_encodings -} diff --git a/src/common/mod.rs b/src/common/mod.rs index e26f899..83d810f 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -1,14 +1,11 @@ -pub mod convert; -pub mod exception; -pub mod mapping; +pub mod encode; pub mod model; +pub mod preprocess; pub mod vocabulary; +mod exception; -pub use convert::{convert_to_array, pad_vector}; -pub use exception::InferenceError; -pub use mapping::{ - create_class_mapping_from_labels, create_class_mappings_from_class_names, - create_vocabulary_to_index_mapping, map_words_to_indices, multi_hot_encode, -}; -pub use model::{CategoriesPredictorModel, ModelConfig}; -pub use vocabulary::{make_mock_vocabulary, make_vocabulary}; 
+use exception::*; +pub use encode::*; +pub use model::*; +pub use preprocess::*; +pub use vocabulary::*; diff --git a/src/common/model.rs b/src/common/model.rs index 70c5b98..f50d797 100644 --- a/src/common/model.rs +++ b/src/common/model.rs @@ -1,9 +1,8 @@ use candle_core::{Result, Tensor}; -use candle_nn::ops::sigmoid; + use candle_nn::{embedding, linear, Embedding, Linear, Module, VarBuilder}; use candle_core::Device; -use candle_optimisers::Model; pub struct ModelConfig { pub device: Device, @@ -14,6 +13,8 @@ pub struct ModelConfig { pub max_seq_len: usize, } +pub const MAX_SEQ_LEN: usize = 128; + impl Default for ModelConfig { fn default() -> Self { Self { @@ -22,7 +23,7 @@ impl Default for ModelConfig { embedding_size: 40, hidden_size: 20, n_classes: 2, - max_seq_len: 128, + max_seq_len: MAX_SEQ_LEN, } } } diff --git a/src/common/preprocess.rs b/src/common/preprocess.rs new file mode 100644 index 0000000..6d983a3 --- /dev/null +++ b/src/common/preprocess.rs @@ -0,0 +1,27 @@ +/// Pads a vector with a specified padding value to reach a maximum length. +/// +/// This function takes an input vector, a maximum padding length, and a padding value. It then +/// modifies the input vector to ensure that it has the maximum length, either by truncating it +/// if it exceeds the maximum length or by adding elements with the padding value if it falls +/// short of the maximum length. +/// +/// # Arguments +/// +/// * `vector`: The input vector that you want to pad. +/// * `max_padding`: The maximum length that the input vector should have after padding. +/// * `pad_value`: The value to use for padding when extending the vector. +/// +/// # Returns +/// +/// A new vector that has been padded to reach the specified maximum length. +/// +pub fn pad_vector(mut vector: Vec, max_padding: usize, pad_value: T) -> Vec { + match vector.len() { + len if len > max_padding => vector.truncate(max_padding), + len if len < max_padding => { + vector.extend(std::iter::repeat(pad_value).take(max_padding - len)) + } + _ => (), + } + vector +} diff --git a/src/common/transform.rs b/src/common/transform.rs deleted file mode 100644 index e69de29..0000000 diff --git a/src/common/vocabulary.rs b/src/common/vocabulary.rs index 0a504c8..40fe291 100644 --- a/src/common/vocabulary.rs +++ b/src/common/vocabulary.rs @@ -1,10 +1,36 @@ -use std::collections::HashSet; +use std::{collections::HashSet, fs::File, io::Read}; +use regex::Regex; +use serde::Deserialize; +use super::exception::VocabularyLoadError; +#[derive(Deserialize, Debug)] +struct Vocabulary { + vocabulary: Vec +} + +/// Creates a vocabulary from a given corpus of sentences. +/// +/// The function takes a reference to a vector of strings representing a corpus +/// of sentences and returns a vector of unique words found in the corpus. Words +/// are separated by whitespace in each sentence. +/// +/// # Arguments +/// +/// * `corpus` - A reference to a vector of strings containing sentences. +/// +/// # Returns +/// +/// A vector of unique words found in the corpus. 
pub fn make_vocabulary(corpus: &Vec) -> Vec { let mut vocabulary: HashSet = HashSet::new(); + let punctuation_regex = Regex::new(r"[[:punct:]]").unwrap(); + for sentence in corpus { - let words: Vec<&str> = sentence.split_whitespace().collect(); + + let sentence_without_punctuation = punctuation_regex.replace_all(sentence, ""); + + let words: Vec<&str> = sentence_without_punctuation.split_whitespace().collect(); for word in words { vocabulary.insert(word.to_string()); @@ -13,20 +39,13 @@ pub fn make_vocabulary(corpus: &Vec) -> Vec { vocabulary.into_iter().collect::>() } -pub fn make_mock_vocabulary() -> Vec { - let vocabulary: Vec = vec![ - "this", - "is", - "an", - "example", - "vocabulary", - "just", - "for", - "show", - ] - .iter() - .map(|word| word.to_string()) - .collect(); - - vocabulary +pub fn load_vocabulary(file_path: &str) -> Result, VocabularyLoadError> { + let mut file = File::open(file_path)?; + let mut json_data = String::new(); + + file.read_to_string(&mut json_data)?; + + let vocabulary: Vocabulary = serde_json::from_str(&json_data)?; + + Ok(vocabulary.vocabulary) } diff --git a/src/lib.rs b/src/lib.rs index f62b4c6..07441e6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,3 @@ mod common; -pub use common::{ - convert_to_array, create_class_mapping_from_labels, create_class_mappings_from_class_names, - create_vocabulary_to_index_mapping, make_vocabulary, map_words_to_indices, multi_hot_encode, - pad_vector, CategoriesPredictorModel, ModelConfig, -}; +pub use common::*; diff --git a/src/main.rs b/src/main.rs deleted file mode 100644 index 170057d..0000000 --- a/src/main.rs +++ /dev/null @@ -1,59 +0,0 @@ -mod common; -mod serving; - -use std::sync::Arc; - -use common::{ - create_class_mappings_from_class_names, create_vocabulary_to_index_mapping, - make_mock_vocabulary, CategoriesPredictorModel, ModelConfig, -}; -use serving::processing::{get_predictions, map_to_class_names_with_scores}; -use serving::request_model::{PredictRequest, PredictResponse}; -use warp::Filter; - -#[tokio::main] -async fn main() { - env_logger::init(); - - let class_names: Vec = vec!["sport".to_string(), "weather".to_string()]; - let vocabulary = make_mock_vocabulary(); - - let mapping = Arc::new(create_vocabulary_to_index_mapping(&vocabulary)); - let (class_to_index, index_to_class) = create_class_mappings_from_class_names(class_names); - - let class_to_index = Arc::new(class_to_index); - let index_to_class = Arc::new(index_to_class); - - let model_config = ModelConfig::default(); - - let model = Arc::new(CategoriesPredictorModel::random(&model_config).unwrap()); - - let health_check_route = warp::get() - .and(warp::path("hc")) - .map(|| warp::reply::json(&serde_json::json!({"status": "healthy"}))); - - let predict_route = warp::post() - .and(warp::path("predict")) - .and(warp::body::json()) - .map({ - let mapping = Arc::clone(&mapping); // Clone Arc for the closure - let class_to_index = Arc::clone(&class_to_index); - let model = Arc::clone(&model); - move |body: PredictRequest| match get_predictions(&body.text, &mapping, &model) { - Ok(predictions) => { - let predicted_categories = - map_to_class_names_with_scores(predictions, &class_to_index, -1000.0); - let response = PredictResponse { - predictions: predicted_categories, - }; - // Todo: There's a missing sigmoid - warp::reply::json(&response) - } - Err(error) => warp::reply::json(&serde_json::json!({"error": error.to_string()})), - } - }); - - let routes = health_check_route.or(predict_route); - - warp::serve(routes).run(([127, 0, 0, 1], 
3030)).await; -} diff --git a/src/serving/processing.rs b/src/serving/inference.rs similarity index 55% rename from src/serving/processing.rs rename to src/serving/inference.rs index 441638f..3c0750d 100644 --- a/src/serving/processing.rs +++ b/src/serving/inference.rs @@ -1,47 +1,47 @@ -pub use crate::common::{ - convert_to_array, map_words_to_indices, pad_vector, CategoriesPredictorModel, -}; - +use anyhow::{anyhow, Error}; use candle_core::{Device, Tensor}; +use candle_nn::ops::sigmoid; +use common::{map_to_indices, pad_vector, CategoriesPredictorModel, MAX_SEQ_LEN}; use std::collections::HashMap; -use std::error::Error; -use std::sync::Arc; pub fn get_predictions( text: &str, - mapping: &Arc>, - model: &Arc, -) -> Result, Box> { + word_to_index: &HashMap, + model: &CategoriesPredictorModel, +) -> Result, Error> { let words = text .split_whitespace() .map(|s| s.to_string()) .collect::>(); - let indices = map_words_to_indices(words, mapping); + let indices = map_to_indices(words, word_to_index); - let padded_indices = pad_vector::(indices, 256, 0); + let padded_indices: Vec = pad_vector(indices, MAX_SEQ_LEN, 0); - let padded_indices_array = convert_to_array::(padded_indices)?; + let padded_indices_array: [u32; MAX_SEQ_LEN] = match padded_indices.try_into() { + Ok(array) => Ok(array), + Err(_) => Err(anyhow!("Failed to convert Vec to [u32; _]")), + }?; let tensor_indices = Tensor::new(&padded_indices_array, &Device::Cpu)?; let predictions = model.forward(&tensor_indices)?; - let predictions_vec = predictions.flatten(0, 1)?.to_vec1()?; + let predictions_vec = sigmoid(&predictions)?.flatten(0, 1)?.to_vec1()?; Ok(predictions_vec) } pub fn map_to_class_names_with_scores( logits: Vec, - class_name_mapping: &HashMap, + index_to_class: &HashMap, threshold: f32, ) -> Vec> { let mut class_names_with_logits: Vec> = Vec::new(); for (index, &logit) in logits.iter().enumerate() { if logit > threshold { - if let Some(class_name) = class_name_mapping.get(&index) { + if let Some(class_name) = index_to_class.get(&(index as u32)) { let mapping = vec![(class_name, logit)] .into_iter() .map(|(c, l)| (c.to_string(), l)) diff --git a/src/serving/serve.rs b/src/serving/serve.rs new file mode 100644 index 0000000..cd35708 --- /dev/null +++ b/src/serving/serve.rs @@ -0,0 +1,82 @@ +use common; +mod inference; +mod types; + +use std::collections::HashMap; +use std::sync::Arc; + +use common::{ + create_class_mappings_from_class_names, create_vocabulary_to_index_mapping, + load_vocabulary, CategoriesPredictorModel, ModelConfig, +}; +use inference::{get_predictions, map_to_class_names_with_scores}; +use types::{PredictRequest, PredictResponse}; +use warp::Filter; + +fn with_shared_data( + shared_data: SharedData, +) -> impl Filter + Clone { + warp::any().map(move || shared_data.clone()) +} + +#[derive(Clone)] +struct SharedData { + word_to_index: Arc>, + index_to_class: Arc>, + model: Arc +} + +#[tokio::main] +async fn main() -> anyhow::Result<()>{ + + env_logger::init(); + + let class_names: Vec = vec!["sport".to_string(), "weather".to_string()]; + let vocabulary = load_vocabulary("data/mock_vocab.json")?; + + let word_to_index = Arc::new(create_vocabulary_to_index_mapping(&vocabulary)); + let (_, index_to_class) = create_class_mappings_from_class_names(class_names); + + let index_to_class = Arc::new(index_to_class); + + let model_config = Arc::new(ModelConfig::default()); + + let model = Arc::new(CategoriesPredictorModel::random(&Arc::clone(&model_config))?); + + let shared_data = SharedData { + word_to_index: 
Arc::clone(&word_to_index), + index_to_class: Arc::clone(&index_to_class), + model: Arc::clone(&model), + }; + + let health_check_route = warp::get() + .and(warp::path("hc")) + .map(|| warp::reply::json(&serde_json::json!({"status": "healthy"}))); + + let predict_route = warp::post() + .and(warp::path("predict")) + .and(warp::body::json()) + .and(with_shared_data(shared_data)) + .and_then(|body: PredictRequest, data: SharedData| async move { + match get_predictions(&body.text, &data.word_to_index, &data.model) { + Ok(predictions) => { + let predicted_categories = + map_to_class_names_with_scores(predictions, &data.index_to_class, 0.3); + let response = PredictResponse { + predictions: predicted_categories, + }; + // Todo: There's a missing sigmoid + Ok::<_, warp::Rejection>(warp::reply::json(&response)) + } + Err(error) => Ok(warp::reply::json( + &serde_json::json!({"error": error.to_string()}), + )), + } + }); + + let routes = health_check_route.or(predict_route); + + warp::serve(routes).run(([127, 0, 0, 1], 3030)).await; + + Ok(()) +} diff --git a/src/serving/request_model.rs b/src/serving/types.rs similarity index 99% rename from src/serving/request_model.rs rename to src/serving/types.rs index eaba74b..49957b0 100644 --- a/src/serving/request_model.rs +++ b/src/serving/types.rs @@ -1,6 +1,5 @@ -use std::collections::HashMap; - use serde::{Deserialize, Serialize}; +use std::collections::HashMap; #[derive(Deserialize)] pub struct PredictRequest { diff --git a/src/training/dataset.rs b/src/training/dataset.rs index 05aba11..bad4a86 100644 --- a/src/training/dataset.rs +++ b/src/training/dataset.rs @@ -1,5 +1,3 @@ -use std::vec; - use anyhow::Error; use candle_core::Tensor; use polars::io::{csv::CsvReader, SerReader}; diff --git a/src/training/dummy.rs b/src/training/dummy.rs deleted file mode 100644 index b6204bf..0000000 --- a/src/training/dummy.rs +++ /dev/null @@ -1,20 +0,0 @@ -use anyhow::{Ok, Error}; -use polars::{error::PolarsError, io::SerReader, prelude::CsvReader}; - -pub fn main() -> Result<(), Error> { - let df = CsvReader::from_path("./data/train.csv")?.has_header(true).finish()?; - - let series = df.column("text").unwrap().to_owned(); - - let vec_of_strings: Vec<&str> = series.str().unwrap().into_iter().map(|optional| { - let x = match optional { - Some(val) => val, None => "fk" - }; - x - }).collect(); - - println!("{:?}", vec_of_strings); - - - Ok(()) -} \ No newline at end of file diff --git a/src/training/train.rs b/src/training/train.rs index 567e1b9..665c80d 100644 --- a/src/training/train.rs +++ b/src/training/train.rs @@ -112,8 +112,8 @@ pub fn main() -> Result<()> { log::debug!("Class to index {:?}", class_to_index); // Multi-hot encode the labels - let train_labels_encoded = multi_hot_encode(train_labels, &class_to_index); - let test_labels_encoded = multi_hot_encode(test_labels, &class_to_index); + let train_labels_encoded = multi_hot_encode(train_labels, &class_to_index)?; + let test_labels_encoded = multi_hot_encode(test_labels, &class_to_index)?; // Make the vocabulary and the vocabulary to index from the training data let vocabulary = make_vocabulary(&train_data); diff --git a/src/training/transform.rs b/src/training/transform.rs index 2017834..4bf299e 100644 --- a/src/training/transform.rs +++ b/src/training/transform.rs @@ -1,6 +1,6 @@ use anyhow::Error; -use candle_core::{DType, Device, Tensor}; -use common::{map_words_to_indices, pad_vector}; +use candle_core::{Device, Tensor}; +use common::{map_to_indices, pad_vector}; use std::collections::HashMap; pub fn 
encode( @@ -13,7 +13,7 @@ pub fn encode( .iter() .flat_map(|sentence| { let words: Vec = sentence.split_whitespace().map(|s| s.to_string()).collect(); - let indices = map_words_to_indices(words, vocabulary_index_mapping); + let indices = map_to_indices(words, vocabulary_index_mapping); pad_vector(indices, max_seq_len, 0) }) .collect(); diff --git a/tests/encode.rs b/tests/encode.rs new file mode 100644 index 0000000..5ab1497 --- /dev/null +++ b/tests/encode.rs @@ -0,0 +1,238 @@ +#[cfg(test)] +mod encode { + + use std::collections::HashMap; + + use common::encode::*; + + #[test] + fn test_create_vocabulary_to_index_mapping_multiple_words() { + let vocabulary = vec![ + "apple".to_string(), + "banana".to_string(), + "cherry".to_string(), + ]; + let result = create_vocabulary_to_index_mapping(&vocabulary); + let expected: HashMap = [ + ("".to_string(), 0), + ("apple".to_string(), 1), + ("banana".to_string(), 2), + ("cherry".to_string(), 3), + ] + .iter() + .cloned() + .collect(); + + assert_eq!(result, expected); + } + + #[test] + fn test_map_to_indices_word_not_in_mapping() { + let words = vec!["apple".to_string(), "banana".to_string(), "cherry".to_string()]; + let mapping: HashMap = [("apple".to_string(), 1), ("cherry".to_string(), 3)] + .iter() + .cloned() + .collect(); + + let result = map_to_indices(words, &mapping); + let expected: Vec = vec![1, 0, 3]; + + assert_eq!(result, expected); + } + + #[test] + fn test_map_to_indices_empty_words() { + let words: Vec = Vec::new(); + let mapping: HashMap = HashMap::new(); + + let result = map_to_indices(words, &mapping); + let expected: Vec = Vec::new(); + + assert_eq!(result, expected); + } + + #[test] + fn test_map_to_indices_word_in_mapping() { + let words = vec!["apple".to_string(), "banana".to_string(), "cherry".to_string()]; + let mapping: HashMap = [ + ("apple".to_string(), 1), + ("banana".to_string(), 2), + ("cherry".to_string(), 3), + ] + .iter() + .cloned() + .collect(); + + let result = map_to_indices(words, &mapping); + let expected: Vec = vec![1, 2, 3]; + + assert_eq!(result, expected); + } + + + #[test] + fn test_create_class_mappings_empty() { + let class_names: Vec = Vec::new(); + + let (class_to_index, index_to_class) = create_class_mappings_from_class_names(class_names); + + let expected_class_to_index: HashMap = HashMap::new(); + let expected_index_to_class: HashMap = HashMap::new(); + + assert_eq!(class_to_index, expected_class_to_index); + assert_eq!(index_to_class, expected_index_to_class); + } + + #[test] + fn test_create_class_mappings_single_class() { + let class_names = vec!["ClassA".to_string()]; + + let (class_to_index, index_to_class) = create_class_mappings_from_class_names(class_names); + + let mut expected_class_to_index: HashMap = HashMap::new(); + let mut expected_index_to_class: HashMap = HashMap::new(); + + expected_class_to_index.insert("ClassA".to_string(), 0); + expected_index_to_class.insert(0, "ClassA".to_string()); + + assert_eq!(class_to_index, expected_class_to_index); + assert_eq!(index_to_class, expected_index_to_class); + } + + #[test] + fn test_create_class_mappings_multiple_classes() { + let class_names = vec!["ClassA".to_string(), "ClassB".to_string(), "ClassC".to_string()]; + + let (class_to_index, index_to_class) = create_class_mappings_from_class_names(class_names); + + let mut expected_class_to_index: HashMap = HashMap::new(); + let mut expected_index_to_class: HashMap = HashMap::new(); + + expected_class_to_index.insert("ClassA".to_string(), 0); + 
expected_class_to_index.insert("ClassB".to_string(), 1); + expected_class_to_index.insert("ClassC".to_string(), 2); + + expected_index_to_class.insert(0, "ClassA".to_string()); + expected_index_to_class.insert(1, "ClassB".to_string()); + expected_index_to_class.insert(2, "ClassC".to_string()); + + assert_eq!(class_to_index, expected_class_to_index); + assert_eq!(index_to_class, expected_index_to_class); + } + + #[test] + fn test_create_class_mapping_empty_labels() { + let labels: Vec = Vec::new(); + + let (class_to_index, index_to_class) = create_class_mapping_from_labels(&labels); + + let expected_class_to_index: HashMap = HashMap::new(); + let expected_index_to_class: HashMap = HashMap::new(); + + assert_eq!(class_to_index, expected_class_to_index); + assert_eq!(index_to_class, expected_index_to_class); + } + + #[test] + fn test_create_class_mapping_single_label() { + let labels = vec!["ClassA".to_string()]; + + let (class_to_index, index_to_class) = create_class_mapping_from_labels(&labels); + + let mut expected_class_to_index: HashMap = HashMap::new(); + let mut expected_index_to_class: HashMap = HashMap::new(); + + expected_class_to_index.insert("ClassA".to_string(), 0); + expected_index_to_class.insert(0, "ClassA".to_string()); + + assert_eq!(class_to_index, expected_class_to_index); + assert_eq!(index_to_class, expected_index_to_class); + } + + #[test] + fn test_create_class_mapping_multiple_labels() { + let labels = vec!["ClassA|ClassB".to_string(), "ClassC".to_string(), "".to_string()]; + + let (class_to_index, index_to_class) = create_class_mapping_from_labels(&labels); + + let mut expected_class_to_index: HashMap = HashMap::new(); + let mut expected_index_to_class: HashMap = HashMap::new(); + + expected_class_to_index.insert("ClassA".to_string(), 0); + expected_class_to_index.insert("ClassB".to_string(), 1); + expected_class_to_index.insert("ClassC".to_string(), 2); + + expected_index_to_class.insert(0, "ClassA".to_string()); + expected_index_to_class.insert(1, "ClassB".to_string()); + expected_index_to_class.insert(2, "ClassC".to_string()); + + assert_eq!(class_to_index, expected_class_to_index); + assert_eq!(index_to_class, expected_index_to_class); + } + #[test] + fn test_multi_hot_encode_empty_labels() { + let labels: Vec = Vec::new(); + let class_to_index: HashMap = HashMap::new(); + + let result = multi_hot_encode(labels, &class_to_index); + + match result { + Ok(encodings) => assert_eq!(encodings, Vec::::new()), + _ => panic!("Expected Ok(Vec::new())"), + } + } + + #[test] + fn test_multi_hot_encode_single_label_not_found() { + let labels = vec!["ClassA".to_string()]; + let class_to_index: HashMap = HashMap::new(); + + let result = multi_hot_encode(labels, &class_to_index); + + match result { + Err(err) => assert_eq!( + err.to_string(), + "Label not found: ClassA".to_string() + ), + _ => panic!("Expected Err(\"Label not found: ClassA\")"), + } + } + + #[test] + fn test_multi_hot_encode_single_label_found() { + let labels = vec!["ClassA".to_string()]; + let mut class_to_index: HashMap = HashMap::new(); + class_to_index.insert("ClassA".to_string(), 0); + + let result = multi_hot_encode(labels, &class_to_index); + + match result { + Ok(encodings) => assert_eq!(encodings, vec![1]), + _ => panic!("Expected Ok(vec![1])"), + } + } + + #[test] + fn test_multi_hot_encode_multiple_labels() { + let labels = vec!["ClassA|ClassB".to_string(), "ClassC".to_string(), "".to_string()]; + let mut class_to_index: HashMap = HashMap::new(); + class_to_index.insert("ClassA".to_string(), 0); + 
class_to_index.insert("ClassB".to_string(), 1); + class_to_index.insert("ClassC".to_string(), 2); + + let result = multi_hot_encode(labels, &class_to_index); + + match result { + Ok(encodings) => assert_eq!( + encodings, + vec![ + 1, 1, 0, // ClassA, ClassB, empty label + 0, 0, 1, // ClassC, empty label, ClassC + 0, 0, 0, // empty label, empty label, empty label + ] + ), + _ => panic!("Expected Ok(encodings)"), + } + } + +} diff --git a/tests/preproces.rs b/tests/preproces.rs new file mode 100644 index 0000000..ed951a0 --- /dev/null +++ b/tests/preproces.rs @@ -0,0 +1,43 @@ +mod encode; + +#[cfg(test)] +mod test_convert { + + use common::preprocess::*; + + #[test] + fn test_pad_vector_smaller() { + let vector: Vec = Vec::new(); + let max_padding = 5_usize; + let pad_value = 1_u32; + + let expected_result: Vec = vec![1; 5]; + let actual_result = pad_vector(vector, max_padding, pad_value); + + assert_eq!(expected_result, actual_result); + } + + #[test] + fn test_pad_vector_larger() { + let vector: Vec = vec![1; 5]; + let max_padding = 4; + let pad_value = 1; + + let expected_result: Vec = vec![1; 4]; + let actual_result = pad_vector(vector, max_padding, pad_value); + + assert_eq!(expected_result, actual_result); + } + + #[test] + fn test_pad_vector_equal() { + let vector: Vec = vec![1; 5]; + let max_padding = 5; + let pad_value = 1; + let expected_result: Vec = vec![1; 5]; + + let actual_result = pad_vector(vector, max_padding, pad_value); + + assert_eq!(expected_result, actual_result); + } +} diff --git a/tests/vocabulary.rs b/tests/vocabulary.rs new file mode 100644 index 0000000..a68602b --- /dev/null +++ b/tests/vocabulary.rs @@ -0,0 +1,32 @@ +#[cfg(test)] +mod encode { + + use std::collections::HashSet; + + use common::vocabulary::*; + + #[test] + fn test_make_vocabulary_single_sentence() { + let corpus = vec!["Hello, world!".to_string()]; + let vocabulary = make_vocabulary(&corpus); + + let expected_words_in_vocabulary = vec!["Hello".to_string(), "world".to_string()]; + println!("{:?}", vocabulary); + for word in expected_words_in_vocabulary { + assert!(vocabulary.contains(&word)); + } + } + + #[test] + fn test_make_vocabulary_multiple_sentences() { + let corpus = vec![ + "This is a test.".to_string(), + "Another test.".to_string(), + ]; + let vocabulary = make_vocabulary(&corpus); + let expected_words_in_vocabulary: Vec = vec!["This".to_string(), "is".to_string(), "a".to_string(), "test".to_string(), "Another".to_string()]; + for word in expected_words_in_vocabulary { + assert!(vocabulary.contains(&word)); + } + } +} \ No newline at end of file