Skip to content

Commit

Permalink
refactor: project structure
Browse files Browse the repository at this point in the history
  • Loading branch information
radandreicristian-cnx committed Feb 2, 2024
1 parent dea83f6 commit 5e58936
Show file tree
Hide file tree
Showing 24 changed files with 736 additions and 268 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/rust-app.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- run: docker build -t headline-predictor-rs
- run: docker build -t headline-predictor-rs .
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ edition = "2021"

[[bin]]
name = "inference"
path = "src/main.rs"
path = "src/serving/serve.rs"

[[bin]]
name = "training"
Expand All @@ -27,4 +27,5 @@ candle-optimisers = "0.3.2"
log = "0.4.20"
env_logger = "0.11.1"
anyhow = "1.0.0"
regex = "1.10.3"
polars ={ version = "0.37.0", features=["lazy"] }
9 changes: 9 additions & 0 deletions data/mock_vocab.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"vocabulary": [
"this",
"is",
"an",
"example",
"vocabulary"
]
}
20 changes: 0 additions & 20 deletions src/common/convert.rs

This file was deleted.

188 changes: 188 additions & 0 deletions src/common/encode.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
use std::collections::HashMap;
use super::MultiHotEncodeError;

/// Builds a word-to-index lookup table for a vocabulary.
///
/// Index 0 is reserved for the special token `"<UNK>"`, which stands in for any word that is
/// not part of the vocabulary (lookups such as `map_to_indices` fall back to 0 for unknown
/// words). The word at position `i` of `vocabulary` is assigned index `i + 1`.
///
/// If a word occurs more than once, the index of its last occurrence is kept. Likewise, a
/// literal `"<UNK>"` entry inside `vocabulary` replaces the reserved index-0 entry.
///
/// # Arguments
///
/// * `vocabulary`: The vocabulary words, in index order. A slice is accepted, so existing
///   `&Vec<String>` call sites keep working through deref coercion.
///
/// # Returns
///
/// A `HashMap` mapping every vocabulary word (plus `"<UNK>"`) to its index.
pub fn create_vocabulary_to_index_mapping(vocabulary: &[String]) -> HashMap<String, u32> {
    // One extra slot for the reserved "<UNK>" token.
    let mut vocab_to_index = HashMap::with_capacity(vocabulary.len() + 1);

    // Special token for any unknown words; it owns index 0.
    vocab_to_index.insert("<UNK>".to_string(), 0);

    // Pairing the words with `1u32..` hands out indices starting at 1 without a
    // `usize as u32` cast.
    vocab_to_index.extend(vocabulary.iter().cloned().zip(1u32..));

    vocab_to_index
}

/// Translates words into indices by looking each one up in `mapping`.
///
/// The output has one entry per input word, in the same order. A word that has no entry in
/// `mapping` is translated to index 0.
///
/// # Arguments
///
/// * `words`: The words to translate.
/// * `mapping`: A reference to a `HashMap` associating words with their indices.
///
/// # Returns
///
/// A vector of `u32` indices, one per element of `words`.
pub fn map_to_indices(words: Vec<String>, mapping: &HashMap<String, u32>) -> Vec<u32> {
    let mut indices = Vec::with_capacity(words.len());

    for word in &words {
        let index = match mapping.get(word) {
            Some(&known) => known,
            // Words missing from the mapping collapse to index 0.
            None => 0,
        };
        indices.push(index);
    }

    indices
}

/// Builds forward and reverse lookup tables for a list of class names.
///
/// The name at position `i` of `class_names` receives index `i`. Two tables are produced:
/// - class name -> index
/// - index -> class name
///
/// When a name appears more than once, the name -> index table keeps the index of its last
/// occurrence, while the index -> name table still holds one entry per position.
///
/// # Arguments
///
/// * `class_names`: A vector of strings representing the class names, in index order.
///
/// # Returns
///
/// A tuple containing two `HashMap` instances:
/// - The first maps class names (strings) to their indices (u32).
/// - The second maps indices back to class names.
pub fn create_class_mappings_from_class_names(
    class_names: Vec<String>,
) -> (HashMap<String, u32>, HashMap<u32, String>) {
    let mut class_to_index = HashMap::new();
    let mut index_to_class = HashMap::new();

    // The vector is owned, so each name is moved into one table and cloned for the other.
    for (position, name) in class_names.into_iter().enumerate() {
        let id = position as u32;
        index_to_class.insert(id, name.clone());
        class_to_index.insert(name, id);
    }

    (class_to_index, index_to_class)
}

/// Derives forward and reverse class lookup tables from raw label strings.
///
/// Each element of `labels` may name several classes separated by `'|'` (for example
/// `"sports|politics"`). Every distinct class is assigned the next free index, starting at 0,
/// in order of first appearance. Two tables are produced:
/// - class name -> index
/// - index -> class name
///
/// Elements that are the empty string are ignored. Empty segments inside a non-empty element
/// (as in `"a||b"` or `"a|"`) are not filtered out and are registered as a class named `""`.
///
/// # Arguments
///
/// * `labels`: The raw label strings. A slice is accepted, so existing `&Vec<String>` call
///   sites keep working through deref coercion.
///
/// # Returns
///
/// A tuple containing two `HashMap` instances:
/// - The first maps class names (strings) to their indices (u32).
/// - The second maps indices back to class names.
pub fn create_class_mapping_from_labels(
    labels: &[String],
) -> (HashMap<String, u32>, HashMap<u32, String>) {
    let mut class_to_index: HashMap<String, u32> = HashMap::new();
    let mut index_to_class: HashMap<u32, String> = HashMap::new();

    // Next index to hand out; equals the number of distinct classes seen so far.
    let mut n_classes: u32 = 0;

    for entry in labels.iter().filter(|entry| !entry.is_empty()) {
        for class_name in entry.split('|') {
            // `contains_key` takes a `&str`, so already-known classes cost no allocation.
            if !class_to_index.contains_key(class_name) {
                class_to_index.insert(class_name.to_string(), n_classes);
                index_to_class.insert(n_classes, class_name.to_string());
                n_classes += 1;
            }
        }
    }

    (class_to_index, index_to_class)
}

/// Converts a list of labels into a flattened multi-hot encoding.
///
/// Every element of `labels` may name several classes separated by `'|'`. For each element a
/// row of `class_to_index.len()` entries is produced, holding 1 at the index of every named
/// class and 0 elsewhere. The rows are concatenated, in input order, into a single vector of
/// `labels.len() * class_to_index.len()` values.
///
/// An empty label is not skipped: it contributes a row made entirely of zeros.
///
/// # Arguments
///
/// * `labels` - A vector of labels to be encoded.
/// * `class_to_index` - A reference to a `HashMap` containing class names as keys and their
///   corresponding indices as values. Every index must be smaller than the number of entries
///   in the map.
///
/// # Errors
///
/// Returns a `MultiHotEncodeError` when:
/// - a class named in `labels` has no entry in `class_to_index`, or
/// - the index stored for a class is not smaller than `class_to_index.len()`, so it does not
///   fit in a row (previously this case panicked on an out-of-bounds write).
///
/// # Returns
///
/// Returns a `Result` where:
/// - `Ok(encodings)` contains the flattened multi-hot encodings as a vector of u32 values.
/// - `Err(err)` contains a `MultiHotEncodeError` with a description of the error.
pub fn multi_hot_encode(
    labels: Vec<String>,
    class_to_index: &HashMap<String, u32>,
) -> Result<Vec<u32>, MultiHotEncodeError> {
    let n_classes = class_to_index.len();
    // One row of `n_classes` entries per label, stored back to back.
    let mut all_encodings: Vec<u32> = Vec::with_capacity(labels.len() * n_classes);

    for label in labels {
        let mut label_encodings = vec![0u32; n_classes];
        log::debug!("Label: {:?}", label);

        // An empty label names no class; its row stays all zeros.
        if !label.is_empty() {
            for label_class in label.split('|') {
                let index = class_to_index.get(label_class).copied().ok_or_else(|| {
                    MultiHotEncodeError::new(&format!("Label not found: {}", label_class))
                })?;

                // Checked access: a mapping with gaps in its indices (for example one built
                // from duplicate class names) must yield an error rather than a panic.
                let slot = label_encodings.get_mut(index as usize).ok_or_else(|| {
                    MultiHotEncodeError::new(&format!(
                        "Class index {} for label {} is out of range for {} classes",
                        index, label_class, n_classes
                    ))
                })?;
                *slot = 1;
            }
        }
        log::debug!("Encoding: {:?}", label_encodings);

        all_encodings.append(&mut label_encodings);
    }
    log::debug!("All encodings {:?}", all_encodings);

    Ok(all_encodings)
}
52 changes: 43 additions & 9 deletions src/common/exception.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,52 @@
use std::{error::Error, fmt};


#[derive(Debug)]
pub enum InferenceError {
ArrayConversionError(&'static str),
pub struct MultiHotEncodeError {
pub message: String,
}

impl fmt::Display for InferenceError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
InferenceError::ArrayConversionError(error_message) => {
write!(f, "Inference Error - {}", error_message)
}
impl MultiHotEncodeError {
pub fn new(message: &str) -> MultiHotEncodeError {
MultiHotEncodeError {
message: message.to_string(),
}
}
}

impl Error for InferenceError {}
impl fmt::Display for MultiHotEncodeError {
    /// Writes the stored message verbatim, without any prefix or decoration.
    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str(&self.message)
    }
}

// Empty impl: no methods are overridden, so the `Error` trait's default implementations
// apply. This lets `MultiHotEncodeError` be used wherever a `std::error::Error` is expected.
impl Error for MultiHotEncodeError {}

/// Error type that wraps either an I/O failure or a JSON failure.
///
/// NOTE(review): the name suggests it is returned while loading a vocabulary file; the code
/// that produces it is not visible here — confirm against the caller.
#[derive(Debug)]
pub enum VocabularyLoadError {
    /// Wraps a `std::io::Error`.
    Io(std::io::Error),
    /// Wraps a `serde_json::Error`.
    Json(serde_json::Error),
}

impl From<std::io::Error> for VocabularyLoadError {
fn from(err: std::io::Error) -> Self {
VocabularyLoadError::Io(err)
}
}

impl From<serde_json::Error> for VocabularyLoadError {
fn from(err: serde_json::Error) -> Self {
VocabularyLoadError::Json(err)
}
}

// Empty impl: no methods are overridden, so the `Error` trait's default implementations
// apply. This lets `VocabularyLoadError` be used wherever a `std::error::Error` is expected.
impl std::error::Error for VocabularyLoadError {}

impl std::fmt::Display for VocabularyLoadError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
VocabularyLoadError::Io(err) => write!(f, "IO error: {}", err),
VocabularyLoadError::Json(err) => write!(f, "JSON error: {}", err),
}
}
}
Loading

0 comments on commit 5e58936

Please sign in to comment.