Skip to content

Commit

Permalink
Add matrix addition
Browse files Browse the repository at this point in the history
  • Loading branch information
opixelum committed Feb 6, 2024
1 parent cf952c4 commit f35ab4c
Showing 1 changed file with 45 additions and 3 deletions.
48 changes: 45 additions & 3 deletions rust/src/math/matrixes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use num_traits::cast::NumCast;
use std::iter::Sum;
use std::ops::{Add, AddAssign, Div, Mul};

#[derive(Debug, PartialEq)]
pub struct Matrix<T> {
pub data: Vec<T>,
pub shape: Vec<usize>,
Expand Down Expand Up @@ -47,7 +48,24 @@ where
}
}

pub fn matrix_multiply<T>(a: &Matrix<T>, b: &Matrix<T>) -> Result<Matrix<T>, &'static str>
pub fn matrix_addition<T>(a: &Matrix<T>, b: &Matrix<T>) -> Result<Matrix<T>, &'static str>
where
T: Default + Clone + Add<Output = T> + Mul<Output = T> + Sum,
{
if a.shape != b.shape {
return Err("Dimension mismatch for matrix addition");
}

let mut result = Matrix::new(a.shape.clone());

for (i, (a_elem, b_elem)) in a.data.iter().zip(b.data.iter()).enumerate() {
result.data[i] = a_elem.clone() + b_elem.clone();
}

return Ok(result);
}

pub fn matrix_multiplication<T>(a: &Matrix<T>, b: &Matrix<T>) -> Result<Matrix<T>, &'static str>
where
T: Default + Clone + Add<Output = T> + Mul<Output = T> + Sum,
{
Expand Down Expand Up @@ -203,7 +221,31 @@ mod test {
}

#[test]
fn test_matrix_multiply() {
fn test_matrix_addition() {
// Should work
let a = Matrix {
data: vec![1, 2, 3, 4, 5, 6],
shape: vec![2, 3],
};
let b = Matrix {
data: vec![7, 8, 9, 10, 11, 12],
shape: vec![2, 3],
};
let result = matrix_addition(&a, &b).unwrap();
assert_eq!(result.data, vec![8, 10, 12, 14, 16, 18]);
assert_eq!(result.shape, vec![2, 3]);

// Should fail due to dimension mismatch
let c = Matrix {
data: vec![1, 2, 3, 4],
shape: vec![2, 2],
};
let result = matrix_addition(&a, &c);
assert_eq!(result, Err("Dimension mismatch for matrix addition"));
}

#[test]
fn test_matrix_multiplication() {
let a = Matrix {
data: vec![1, 2, 3, 4, 5, 6],
shape: vec![2, 3],
Expand All @@ -212,7 +254,7 @@ mod test {
data: vec![7, 8, 9, 10, 11, 12],
shape: vec![3, 2],
};
let result = matrix_multiply(&a, &b).unwrap();
let result = matrix_multiplication(&a, &b).unwrap();
assert_eq!(result.data, vec![58, 64, 139, 154]);
assert_eq!(result.shape, vec![2, 2]);
}
Expand Down

0 comments on commit f35ab4c

Please sign in to comment.