Skip to content

Commit

Permalink
plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
mikecvet committed Sep 21, 2023
1 parent 0fd4fb3 commit f86fd48
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ version = "0.1.0"
edition = "2021"

[profile.release]
debug = true
lto = "thin"
opt-level = 3

[dependencies]
Expand Down
27 changes: 27 additions & 0 deletions plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import matplotlib.pyplot as plt
import numpy as np

# Read one cross-entropy value per line from the training output file.
# Iterate the file handle directly (no need for readlines()) and skip
# blank lines — a trailing newline would otherwise crash float('').
with open('entropy2.out', 'r') as f:
    numbers = [float(line) for line in f if line.strip()]

# X axis: 1-based epoch indices, one per recorded entropy value
x_values = list(range(1, len(numbers) + 1))

# Fit a degree-3 polynomial to smooth noise in the raw entropy curve
# (degree is a tunable choice; 3 captures the general decay shape)
z = np.polyfit(x_values, numbers, 3)
p = np.poly1d(z)

# Plot the raw data points
plt.plot(x_values, numbers, marker='o', label='Data Points')

# Plot the smoothed trend line on top of the raw series
plt.plot(x_values, p(x_values), 'r-', label='Trend Line')

plt.title("Visualization of cross-entropy error")
plt.xlabel("Epochs")
plt.ylabel("Entropy")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()
12 changes: 6 additions & 6 deletions src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,16 +108,12 @@ impl Model {
let std_w2 = (2.0 / self.w2.len() as f64).sqrt();
let distribution_w2 = Normal::new(0.0, std_w2).expect("Failed to create distribution");

println!("distriubtion 1 {:?} 2 {:?}", distribution_w1, distribution_w2);

for val in self.w2.iter_mut() {
*val = rng.sample(distribution_w2);
}

println!("initialized He weights: {}\n{}", self.w1, self.w2);
}

/// Runs forward propagation against this neural network.
/// Runs forward propagation against this neural network. Collects predicted output
pub fn
forward_propagation (&self, x: &Array2<f64>) -> (Array2<f64>, Array2<f64>)
{
Expand All @@ -127,10 +123,13 @@ impl Model {
(a1, probabilities)
}

/// Runs back propagation against this neural network.
/// Runs back propagation against this neural network. Compute the gradient of the
/// loss function with respect to the weights and biases. This gradient is then
/// used to update the weights and biases using gradient descent.
pub fn
back_propagation (&mut self, training_data: TrainingData, rate: f64) -> f64
{
// Run prediction
let (a, probabilities) = self.forward_propagation(&training_data.x);

// Compute the cross-entropy loss from the forward propagation step
Expand Down Expand Up @@ -173,6 +172,7 @@ impl Model {
Some(self.w1.row(*indx).to_owned().into_raw_vec())
},

// The token was not found in the map
None => None
}
}
Expand Down

0 comments on commit f86fd48

Please sign in to comment.