Skip to content

Commit

Permalink
embeddings similarities
Browse files Browse the repository at this point in the history
  • Loading branch information
mikecvet committed Sep 19, 2023
1 parent b68ba27 commit b3c917d
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 36 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
debug/
target/
.vscode
model.out

# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
Expand Down
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ opt-level = 3
clap = { version = "4.2.7", features = ["derive"] }
ndarray = {version = "0.15.6", features = ["serde"]}
ndarray-rand = "0.14.0"
ordered-float = "2.7"
rand = "0.8.5"
regex = "1.9.3"
serde = { version = "1.0", features = ["derive"] }
Expand Down
3 changes: 3 additions & 0 deletions sf_ba_oak.txt

Large diffs are not rendered by default.

107 changes: 75 additions & 32 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ use std::{fs, io::Read};
use ndarray::{Array, Array1, Array2, ArrayView, Axis, Ix2};
use ndarray_rand::RandomExt;
use ndarray_rand::rand_distr::Uniform;
use ordered_float::OrderedFloat;
use rand::distributions::Standard;
use rand::prelude::*;
use rand::{thread_rng, Rng};
use regex::Regex;
use std::collections::{HashMap, HashSet};
use std::collections::{BinaryHeap, HashMap, HashSet};
use std::cmp::Reverse;
use std::ops::{Mul, Sub, SubAssign};
use word2vec_rs::*;

Expand All @@ -22,8 +24,14 @@ const WINDOW_SIZE: i32 = 3;
const LEARNING_RATE: f64 = 0.0003;

fn
run (epochs: usize, print_entropy: bool, query: Option<&str>, model_path: Option<&String>, text: &str)
{
run (
epochs: usize,
print_entropy: bool,
query: Option<&str>,
model_path: Option<&String>,
save: bool,
text: &str
) {
let metadata = Metadata::init(text);

let mut model = Model::new(metadata.vocab_size, EMBEDDINGS_SIZE);
Expand All @@ -39,35 +47,38 @@ run (epochs: usize, print_entropy: bool, query: Option<&str>, model_path: Option

println!("model trained");

match model.save_to_file("./model.out") {
Err(error) => println!("error: {}", error),
_ => ()
if save {
match model.save_to_file("./model.out") {
Err(error) => println!("error: {}", error),
_ => ()
}
}

match query {
Some(q) => {
predict(q, &model, &metadata.token_to_index, &metadata.index_to_token, metadata.vocab_size)
nn_forward_propagation(q, &model, &metadata.token_to_index, &metadata.index_to_token, metadata.vocab_size);

let mut word_embeddings: HashMap<String, Vec<f64>> = HashMap::new();
for entry in metadata.token_to_index.iter() {
word_embeddings.insert(
entry.0.clone(),
get_embedding(&model, entry.0, &metadata.token_to_index).unwrap()
);
}

closest_embeddings (&word_embeddings, q);

match word_analogy(&word_embeddings, q, "oakland", "clara") {
Some(analogy) => println!("best word analogy: {}", analogy),
_ => println!("could not find an analogy for {}", ""),
}
},
_ => ()
}

println!("query done");

let mut word_embeddings: HashMap<String, Vec<f64>> = HashMap::new();
for entry in metadata.token_to_index.iter() {
let embedding_matrix = get_embedding(&model, entry.0, &metadata.token_to_index).unwrap();
let v = embedding_matrix.into_raw_vec();
word_embeddings.insert(entry.0.clone(), v);
}

match word_analogy(&word_embeddings, "alameda", "oakland", "jose") {
Some(analogy) => println!("best word analogy: {}", analogy),
_ => println!("could not find an analogy for {}", ""),
}
}

fn
predict (
nn_forward_propagation (
query: &str,
model: &Model,
token_to_index: &HashMap<String, usize>,
Expand All @@ -86,22 +97,43 @@ predict (
indices.sort_by(|&a, &b| probabilities[b].partial_cmp(&probabilities[a]).unwrap());
let sorted_values: Vec<f64> = indices.iter().map(|&i| probabilities[i]).collect();

println!("Most similar nearby tokens to [{}]:\n", query);
println!("Most similar nearby tokens to [{}] via nn forward propagation:\n", query);

let mut i = 0;
for iter in indices.iter().zip(sorted_values.iter()) {
println!("[{}]: {}\t| probability: {}", i, index_to_token[iter.0], iter.1);
i += 1;

if i >= 20 {
if i >= 10 {
break;
}
}

print_embedding(
query,
&get_embedding(&model, query, &token_to_index).unwrap()
);
// print_embedding(
// query,
// &get_embedding(&model, query, &token_to_index).unwrap()
// );
}

fn
closest_embeddings (embeddings: &HashMap<String, Vec<f64>>, query: &str)
{
  // Prints the (up to) 10 vocabulary tokens whose embedding vectors are most
  // similar to `query`'s embedding, ranked by cosine similarity.
  //
  // Handle an out-of-vocabulary query gracefully instead of panicking on unwrap().
  let query_vector = match embeddings.get(query) {
    Some(v) => v,
    None => {
      println!("no embedding found for [{}]", query);
      return;
    }
  };

  // Max-heap keyed on similarity; OrderedFloat supplies the total ordering
  // that f64 lacks. (Removed the debug leftover that seeded the heap with a
  // phantom ("a", 0.4) entry, which polluted the reported results.)
  let mut heap: BinaryHeap<(OrderedFloat<f64>, String)> = BinaryHeap::new();

  for (word, vector) in embeddings.iter() {
    // Skip the query token itself; its self-similarity of ~1.0 carries no information.
    if word != query {
      let similarity = cosine_similarity(query_vector, vector);
      heap.push((OrderedFloat(similarity), word.to_string()));
    }
  }

  println!("\nMost similar nearby tokens to [{}] via embeddings cosine similarity:\n", query);

  // Pop at most 10 entries; `while let` guards against vocabularies smaller
  // than 10, where the original `heap.pop().unwrap()` would panic.
  let mut rank = 0;
  while let Some((similarity, word)) = heap.pop() {
    // The value is a cosine similarity, not a probability — label it accordingly.
    println!("[{}]: {}\t| similarity: {}", rank, word, similarity.0);
    rank += 1;

    if rank >= 10 {
      break;
    }
  }
}

fn
Expand All @@ -127,7 +159,7 @@ cosine_similarity (v1: &[f64], v2: &[f64]) -> f64
fn
word_analogy (embeddings: &HashMap<String, Vec<f64>>, a: &str, b: &str, c: &str) -> Option<String>
{
println!("computing analogy for {} => {}, {}?", a, b, c);
println!("\nComputing analogy for {} - {} + {} = ?", a, b, c);
let v_a = embeddings.get(a)?;
let v_b = embeddings.get(b)?;
let v_c = embeddings.get(c)?;
Expand All @@ -144,6 +176,10 @@ word_analogy (embeddings: &HashMap<String, Vec<f64>>, a: &str, b: &str, c: &str)
if word != a && word != b && word != c {
let similarity = cosine_similarity(&target_vector, vector);

if word.eq("clara") {
println!(">> clara: {}", similarity);
}

if similarity > max_similarity {
max_similarity = similarity;
best_word = Some(word.clone());
Expand All @@ -166,13 +202,15 @@ main ()
.arg(arg!(--epochs <VALUE>).required(false))
.arg(arg!(--predict <VALUE>).required(false))
.arg(arg!(--load <VALUE>).required(false))
.arg(arg!(--save).required(false))
.get_matches();

let entropy_opt = matches.get_one::<bool>("entropy");
let input_opt = matches.get_one::<String>("input");
let epochs_opt = matches.get_one::<String>("epochs");
let predict_opt = matches.get_one::<String>("predict");
let load_opt = matches.get_one::<String>("load");
let save_opt = matches.get_one::<bool>("save");

let epochs = match epochs_opt.as_deref() {
Some(epoch_string) => epoch_string.parse::<usize>().unwrap(),
Expand All @@ -189,15 +227,20 @@ main ()
}
};

let save = match save_opt {
Some(s) => s,
_ => &true
};

match (entropy_opt, predict_opt) {
(Some(true), None) => {
run(epochs, true, None, load_opt, &text);
run(epochs, true, None, load_opt, *save, &text);
},
(Some(true), Some(query)) => {
run(epochs, true, Some(query), load_opt, &text)
run(epochs, true, Some(query), load_opt, *save, &text)
},
(Some(false), Some(query)) => {
run(epochs, false, Some(query), load_opt, &text)
run(epochs, false, Some(query), load_opt, *save, &text)
},
_ => {
println!("no options provided");
Expand Down
6 changes: 2 additions & 4 deletions src/train.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,13 +197,11 @@ cross_entropy (z: &Array2<f64>, y: &Array2<f64>) -> f64
}

pub fn
get_embedding (model: &Model, token: &str, token_to_index: &HashMap<String, usize>) -> Option<Array2<f64>>
get_embedding (model: &Model, token: &str, token_to_index: &HashMap<String, usize>) -> Option<Vec<f64>>
{
match token_to_index.get(token) {
Some(indx) => {
Some(
forward_propagation(model, &encode(*indx, token_to_index.len())).0
)
Some(model.w1.row(*indx).to_owned().into_raw_vec())
},

None => None
Expand Down

0 comments on commit b3c917d

Please sign in to comment.