-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
dea83f6
commit 5e58936
Showing
24 changed files
with
736 additions
and
268 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
{ | ||
"vocabulary": [ | ||
"this", | ||
"is", | ||
"an", | ||
"example", | ||
"vocabulary" | ||
] | ||
} |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,188 @@ | ||
use std::collections::HashMap; | ||
use super::MultiHotEncodeError; | ||
|
||
/// Builds a word-to-index lookup table from a vocabulary list.
///
/// Index 0 is reserved for the special token `"<UNK>"`, which stands in for any word that is
/// not part of the vocabulary (`map_to_indices` falls back to 0 for unknown words). The word
/// at position `i` of `vocabulary` is assigned index `i + 1`.
///
/// If a word occurs more than once, the index of its last occurrence is kept. A literal
/// `"<UNK>"` entry inside `vocabulary` still occupies a position (so later words keep their
/// indices) but never displaces the reserved `"<UNK>" -> 0` mapping.
///
/// # Arguments
///
/// * `vocabulary`: The vocabulary words, in index order. A `&Vec<String>` can be passed
///   directly thanks to deref coercion.
///
/// # Returns
///
/// A `HashMap` mapping each word (plus `"<UNK>"`) to its index.
pub fn create_vocabulary_to_index_mapping(vocabulary: &[String]) -> HashMap<String, u32> {
    const UNKNOWN_TOKEN: &str = "<UNK>";

    // One extra slot for the reserved unknown-word token.
    let mut vocab_to_index = HashMap::with_capacity(vocabulary.len() + 1);
    vocab_to_index.insert(UNKNOWN_TOKEN.to_string(), 0);

    // Real words start at 1 because 0 is taken by the unknown-word token.
    for (index, word) in (1u32..).zip(vocabulary) {
        // Keep the reserved token pinned to 0 even if the vocabulary lists it explicitly.
        if word.as_str() != UNKNOWN_TOKEN {
            vocab_to_index.insert(word.clone(), index);
        }
    }

    vocab_to_index
}
|
||
/// Translates words into their numeric indices.
///
/// Every entry of `words` is looked up in `mapping`. Words that have no entry in the mapping
/// are translated to 0. The output has the same length and order as `words`.
///
/// # Arguments
///
/// * `words`: The words to translate.
/// * `mapping`: A reference to a `HashMap` from word to index.
///
/// # Returns
///
/// A vector with one `u32` index per input word.
pub fn map_to_indices(words: Vec<String>, mapping: &HashMap<String, u32>) -> Vec<u32> {
    let mut indices = Vec::with_capacity(words.len());

    for word in &words {
        let index = match mapping.get(word) {
            Some(&known) => known,
            // Unknown words fall back to index 0.
            None => 0,
        };
        indices.push(index);
    }

    indices
}
|
||
/// Creates forward and reverse mappings between class names and indices.
///
/// The class name at position `i` of `class_names` is assigned index `i`. Two maps are built:
/// - class name -> index
/// - index -> class name
///
/// If a name appears more than once, every position still gets an entry in the
/// index -> name map, while the name -> index map keeps the index of the last occurrence.
///
/// # Arguments
///
/// * `class_names`: The class names, in index order. The vector is consumed.
///
/// # Returns
///
/// A tuple `(class_to_index, index_to_class)`.
pub fn create_class_mappings_from_class_names(
    class_names: Vec<String>,
) -> (HashMap<String, u32>, HashMap<u32, String>) {
    let mut class_to_index: HashMap<String, u32> = HashMap::with_capacity(class_names.len());
    let mut index_to_class: HashMap<u32, String> = HashMap::with_capacity(class_names.len());

    // The vector is owned, so each name is moved into one map and cloned only once for the
    // other.
    for (index, class_name) in (0u32..).zip(class_names) {
        index_to_class.insert(index, class_name.clone());
        class_to_index.insert(class_name, index);
    }

    (class_to_index, index_to_class)
}
|
||
/// Derives forward and reverse class/index mappings from raw label strings.
///
/// Each entry of `labels` may name several classes separated by `'|'` (for example
/// `"cat|dog"`). Entries are scanned in order and every class name not seen before is given
/// the next free index, starting at 0.
///
/// Entries that are entirely empty are skipped. An empty segment inside a non-empty entry
/// (for example the middle of `"a||b"`) is not skipped and is registered as a class whose
/// name is the empty string.
///
/// # Arguments
///
/// * `labels`: The label strings to scan. A `&Vec<String>` can be passed directly thanks to
///   deref coercion.
///
/// # Returns
///
/// A tuple `(class_to_index, index_to_class)`:
/// - The first `HashMap` maps class names to their indices.
/// - The second `HashMap` maps indices back to class names.
pub fn create_class_mapping_from_labels(
    labels: &[String],
) -> (HashMap<String, u32>, HashMap<u32, String>) {
    let mut class_to_index: HashMap<String, u32> = HashMap::new();
    let mut index_to_class: HashMap<u32, String> = HashMap::new();

    // Next index to hand out; equals the number of distinct classes seen so far.
    let mut n_classes: u32 = 0;

    let class_names = labels
        .iter()
        .filter(|label| !label.is_empty())
        .flat_map(|label| label.split('|'));

    for class_name in class_names {
        // Checked by `&str` first so that already-known classes cost no allocation.
        if class_to_index.contains_key(class_name) {
            continue;
        }
        class_to_index.insert(class_name.to_string(), n_classes);
        index_to_class.insert(n_classes, class_name.to_string());
        n_classes += 1;
    }

    (class_to_index, index_to_class)
}
|
||
/// Converts a list of labels into a multi-hot encoding using the provided mapping. | ||
/// | ||
/// Given a list of labels and a mapping of class names to their respective indices, | ||
/// this function generates a multi-hot encoding where each label corresponds to a binary vector. | ||
/// If a label is not found in the mapping, an error is returned. | ||
/// | ||
/// # Arguments | ||
/// | ||
/// * `labels` - A vector of labels to be encoded. | ||
/// * `class_to_index` - A reference to a `HashMap` containing class names as keys and their | ||
/// corresponding indices as values. | ||
/// | ||
/// # Errors | ||
/// | ||
/// If a label in `labels` is not found in `class_to_index`, an `Err` variant is returned | ||
/// with an associated error message indicating which label was not found. | ||
/// | ||
/// # Returns | ||
/// | ||
/// Returns a `Result` where: | ||
/// - `Ok(encodings)` contains the multi-hot encodings as a vector of u32 values. | ||
/// - `Err(err)` contains a `MultiHotEncodeError` with a description of the error. | ||
pub fn multi_hot_encode( | ||
labels: Vec<String>, | ||
class_to_index: &HashMap<String, u32>, | ||
) -> Result<Vec<u32>, MultiHotEncodeError> { | ||
let n_classes = class_to_index.len(); | ||
let mut all_encodings: Vec<u32> = Vec::new(); // Initialize a single encodings vector | ||
|
||
for label in labels { | ||
let mut label_encodings = vec![0u32; n_classes]; | ||
log::debug!("Label: {:?}", label); | ||
if label == "" { | ||
log::debug!("Encoding: {:?}", label_encodings); | ||
all_encodings.append(&mut label_encodings); | ||
// Skip empty labels | ||
continue; | ||
} | ||
|
||
let label_classes: Vec<&str> = label.split('|').collect(); | ||
for label_class in label_classes { | ||
if let Some(&index) = class_to_index.get(&label_class.to_string()) { | ||
label_encodings[index as usize] = 1u32; | ||
} else { | ||
return Err(MultiHotEncodeError::new(&format!("Label not found: {}", label_class))); | ||
} | ||
} | ||
log::debug!("Encoding: {:?}", label_encodings); | ||
|
||
all_encodings.append(&mut label_encodings); | ||
} | ||
log::debug!("All encodings {:?}", all_encodings); | ||
|
||
Ok(all_encodings) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +1,52 @@ | ||
use std::{error::Error, fmt}; | ||
|
||
|
||
// NOTE(review): the original span interleaved two definitions (`InferenceError` and
// `MultiHotEncodeError`) line by line, which looks like the removed and added sides of a diff
// merged together. Both are restored here as complete definitions; `InferenceError` is
// presumably the removed side — confirm against the repository before deleting it.

/// Error type with a single variant, `ArrayConversionError`, that carries a static message.
#[derive(Debug)]
pub enum InferenceError {
    /// Holds the static message that is appended to the `Display` output.
    ArrayConversionError(&'static str),
}

impl fmt::Display for InferenceError {
    /// Writes `"Inference Error - "` followed by the variant's message.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match *self {
            InferenceError::ArrayConversionError(error_message) => {
                write!(f, "Inference Error - {}", error_message)
            }
        }
    }
}

impl Error for InferenceError {}

/// Error carrying a free-form, human-readable message.
#[derive(Debug)]
pub struct MultiHotEncodeError {
    /// The message; it is also exactly what `Display` prints.
    pub message: String,
}

impl MultiHotEncodeError {
    /// Creates an error that owns a copy of `message`.
    pub fn new(message: &str) -> MultiHotEncodeError {
        MultiHotEncodeError {
            message: message.to_string(),
        }
    }
}

impl fmt::Display for MultiHotEncodeError {
    /// Writes the stored message verbatim.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{}", self.message)
    }
}

impl Error for MultiHotEncodeError {}
|
||
#[derive(Debug)] | ||
pub enum VocabularyLoadError { | ||
Io(std::io::Error), | ||
Json(serde_json::Error), | ||
} | ||
|
||
impl From<std::io::Error> for VocabularyLoadError { | ||
fn from(err: std::io::Error) -> Self { | ||
VocabularyLoadError::Io(err) | ||
} | ||
} | ||
|
||
impl From<serde_json::Error> for VocabularyLoadError { | ||
fn from(err: serde_json::Error) -> Self { | ||
VocabularyLoadError::Json(err) | ||
} | ||
} | ||
|
||
impl std::error::Error for VocabularyLoadError {} | ||
|
||
impl std::fmt::Display for VocabularyLoadError { | ||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||
match self { | ||
VocabularyLoadError::Io(err) => write!(f, "IO error: {}", err), | ||
VocabularyLoadError::Json(err) => write!(f, "JSON error: {}", err), | ||
} | ||
} | ||
} |
Oops, something went wrong.