Skip to content

Commit

Permalink
refactor: store index to class and vocab after training
Browse files Browse the repository at this point in the history
  • Loading branch information
radandreicristian-cnx committed Feb 2, 2024
1 parent 5e58936 commit b5d5cbe
Show file tree
Hide file tree
Showing 9 changed files with 52 additions and 106 deletions.
1 change: 1 addition & 0 deletions data/index_to_class.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"mapping":{"1":"weather","0":"sports"}}
9 changes: 0 additions & 9 deletions data/mock_vocab.json

This file was deleted.

1 change: 1 addition & 0 deletions data/vocab.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"vocabulary":["midwest","device","over","medical","interrupted","competition","heatwave","sweeps","victory","hope","in","dominates","illness","unveils","to","win","wedding","by","looms","football","leaves","for","couples","surprise","across","discovery","promises","blizzard","team","giant","groundbreaking","south","leads","revolutionary","new","overtime","celebrity","player","fans","comeback","awe","game","tennis","warning","tech"]}
60 changes: 26 additions & 34 deletions src/common/encode.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use std::collections::HashMap;
use std::{collections::HashMap, fs::File, io::{Read, Write}};
use serde::{Deserialize, Serialize};
use anyhow::Error;
use super::MultiHotEncodeError;

/// Creates a mapping from vocabulary words to their corresponding indices.
Expand Down Expand Up @@ -53,39 +55,6 @@ pub fn map_to_indices(words: Vec<String>, mapping: &HashMap<String, u32>) -> Vec
.collect()
}

/// Creates mappings between class names and their corresponding indices.
///
/// This function takes a vector of class names `class_names` and creates two mappings:
/// - A mapping from class names (strings) to their corresponding indices (u32).
/// - A reverse mapping from indices to class names.
///
/// These mappings are often used in machine learning tasks where class labels need to be
/// represented as indices for model training and evaluation.
///
/// # Arguments
///
/// * `class_names`: A vector of strings representing the class names.
///
/// # Returns
///
/// A tuple containing two `HashMap` instances:
/// - The first `HashMap` maps class names (strings) to their corresponding indices (u32).
/// - The second `HashMap` maps indices to their corresponding class names.
pub fn create_class_mappings_from_class_names(
class_names: Vec<String>,
) -> (HashMap<String, u32>, HashMap<u32, String>) {
let mut index_to_class: HashMap<u32, String> = HashMap::new();
let mut class_to_index: HashMap<String, u32> = HashMap::new();

for (index, word) in class_names.iter().enumerate() {
let index_u32 = index as u32;
index_to_class.insert(index_u32, word.to_string());
class_to_index.insert(word.to_string(), index_u32);
}

(class_to_index, index_to_class)
}

/// Creates mappings between class labels and their corresponding indices.
///
/// This function takes a reference to a vector of class labels `labels` and creates two mappings:
Expand Down Expand Up @@ -186,3 +155,26 @@ pub fn multi_hot_encode(

Ok(all_encodings)
}


#[derive(Serialize, Deserialize, Debug)]
struct IndexToClassMapping {
mapping: HashMap<u32, String>
}

pub fn store_index_to_class_mapping(index_to_class: &HashMap<u32, String>, file_path: &str) -> Result<(), Error> {
let mut file = File::create(file_path)?;
file.write_all(serde_json::to_string(&IndexToClassMapping{mapping: index_to_class.to_owned()})?.as_bytes())?;
Ok(())
}

pub fn load_index_to_class_mapping(file_path: &str) -> Result<HashMap<u32, String>, Error> {
let mut file = File::open(file_path)?;
let mut json_data = String::new();

file.read_to_string(&mut json_data)?;

let mapping: IndexToClassMapping = serde_json::from_str(&json_data)?;

Ok(mapping.mapping)
}
14 changes: 11 additions & 3 deletions src/common/vocabulary.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use std::{collections::HashSet, fs::File, io::Read};
use std::{collections::HashSet, fs::File, io::{Read, Write}};
use regex::Regex;
use serde::Deserialize;
use serde::{Serialize, Deserialize};
use super::exception::VocabularyLoadError;

#[derive(Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug)]
struct Vocabulary {
vocabulary: Vec<String>
}
Expand Down Expand Up @@ -49,3 +49,11 @@ pub fn load_vocabulary(file_path: &str) -> Result<Vec<String>, VocabularyLoadErr

Ok(vocabulary.vocabulary)
}


pub fn store_vocabulary(vocabulary: &Vec<String>, file_path: &str) -> Result<(), anyhow::Error>{

let mut file = File::create(file_path)?;
file.write_all(serde_json::to_string(&Vocabulary{vocabulary: vocabulary.to_owned()})?.as_bytes())?;
Ok(())
}
7 changes: 3 additions & 4 deletions src/serving/serve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::collections::HashMap;
use std::sync::Arc;

use common::{
create_class_mappings_from_class_names, create_vocabulary_to_index_mapping,
load_index_to_class_mapping, create_vocabulary_to_index_mapping,
load_vocabulary, CategoriesPredictorModel, ModelConfig,
};
use inference::{get_predictions, map_to_class_names_with_scores};
Expand All @@ -31,11 +31,10 @@ async fn main() -> anyhow::Result<()>{

env_logger::init();

let class_names: Vec<String> = vec!["sport".to_string(), "weather".to_string()];
let vocabulary = load_vocabulary("data/mock_vocab.json")?;
let vocabulary = load_vocabulary("data/vocab.json")?;

let word_to_index = Arc::new(create_vocabulary_to_index_mapping(&vocabulary));
let (_, index_to_class) = create_class_mappings_from_class_names(class_names);
let index_to_class = load_index_to_class_mapping("data/index_to_class.json")?;

let index_to_class = Arc::new(index_to_class);

Expand Down
11 changes: 9 additions & 2 deletions src/training/train.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use candle_optimisers::adam;
use candle_optimisers::adam::ParamsAdam;
use common::{
create_class_mapping_from_labels, create_vocabulary_to_index_mapping, make_vocabulary,
multi_hot_encode,
multi_hot_encode, store_vocabulary, store_index_to_class_mapping
};
use common::{CategoriesPredictorModel, ModelConfig};
use config::TrainConfig;
Expand Down Expand Up @@ -107,7 +107,11 @@ pub fn main() -> Result<()> {
log::debug!("Train data sample: {:?}", train_data[0]);

// Create the class to index mapping
let (class_to_index, _) = create_class_mapping_from_labels(&train_labels);
let (class_to_index, index_to_class) = create_class_mapping_from_labels(&train_labels);


// Store the index to class mapping for inference
store_index_to_class_mapping(&index_to_class, "data/index_to_class.json")?;

log::debug!("Class to index {:?}", class_to_index);

Expand All @@ -118,6 +122,9 @@ pub fn main() -> Result<()> {
// Make the vocabulary and the vocabulary to index from the training data
let vocabulary = make_vocabulary(&train_data);

// Store the vocabulary to be loaded during inference
store_vocabulary(&vocabulary, "data/vocab.json")?;

let vocabulary_index_mapping = create_vocabulary_to_index_mapping(&vocabulary);

let max_seq_len = model_config.max_seq_len;
Expand Down
53 changes: 1 addition & 52 deletions tests/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,58 +68,7 @@ mod encode {

assert_eq!(result, expected);
}


#[test]
fn test_create_class_mappings_empty() {
let class_names: Vec<String> = Vec::new();

let (class_to_index, index_to_class) = create_class_mappings_from_class_names(class_names);

let expected_class_to_index: HashMap<String, u32> = HashMap::new();
let expected_index_to_class: HashMap<u32, String> = HashMap::new();

assert_eq!(class_to_index, expected_class_to_index);
assert_eq!(index_to_class, expected_index_to_class);
}

#[test]
fn test_create_class_mappings_single_class() {
let class_names = vec!["ClassA".to_string()];

let (class_to_index, index_to_class) = create_class_mappings_from_class_names(class_names);

let mut expected_class_to_index: HashMap<String, u32> = HashMap::new();
let mut expected_index_to_class: HashMap<u32, String> = HashMap::new();

expected_class_to_index.insert("ClassA".to_string(), 0);
expected_index_to_class.insert(0, "ClassA".to_string());

assert_eq!(class_to_index, expected_class_to_index);
assert_eq!(index_to_class, expected_index_to_class);
}

#[test]
fn test_create_class_mappings_multiple_classes() {
let class_names = vec!["ClassA".to_string(), "ClassB".to_string(), "ClassC".to_string()];

let (class_to_index, index_to_class) = create_class_mappings_from_class_names(class_names);

let mut expected_class_to_index: HashMap<String, u32> = HashMap::new();
let mut expected_index_to_class: HashMap<u32, String> = HashMap::new();

expected_class_to_index.insert("ClassA".to_string(), 0);
expected_class_to_index.insert("ClassB".to_string(), 1);
expected_class_to_index.insert("ClassC".to_string(), 2);

expected_index_to_class.insert(0, "ClassA".to_string());
expected_index_to_class.insert(1, "ClassB".to_string());
expected_index_to_class.insert(2, "ClassC".to_string());

assert_eq!(class_to_index, expected_class_to_index);
assert_eq!(index_to_class, expected_index_to_class);
}


#[test]
fn test_create_class_mapping_empty_labels() {
let labels: Vec<String> = Vec::new();
Expand Down
2 changes: 0 additions & 2 deletions tests/vocabulary.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
#[cfg(test)]
mod encode {

use std::collections::HashSet;

use common::vocabulary::*;

#[test]
Expand Down

0 comments on commit b5d5cbe

Please sign in to comment.