Skip to content

Commit

Permalink
Add tests for layer backpropagation.
Browse files Browse the repository at this point in the history
  • Loading branch information
LordSaumya committed Jul 8, 2024
1 parent bf77bd9 commit 69ddaae
Showing 1 changed file with 111 additions and 0 deletions.
111 changes: 111 additions & 0 deletions rusty_kan/src/tests/layer_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,3 +220,114 @@ fn layer_forward_more_cols_fail() {
layer.forward(input);
}

#[test]
fn layer_backward_pass() {
// Node 1
let incoming_edge_11: Rc<RefCell<Edge>> = Rc::new(RefCell::new(Edge::new(0, 0, BSpline::new(Vector::new(vec![1.0, 2.0, 3.0]), 2), 0)));
let incoming_edge_12: Rc<RefCell<Edge>> = Rc::new(RefCell::new(Edge::new(1, 0, BSpline::new(Vector::new(vec![1.5, 2.5, 3.5]), 2), 0)));

let outgoing_edge_11: Rc<RefCell<Edge>> = Rc::new(RefCell::new(Edge::new(0, 0, BSpline::new(Vector::new(vec![0.5, 1.5, 2.5]), 2), 0)));

let node_1: Node = Node::new(vec![incoming_edge_11, incoming_edge_12], vec![outgoing_edge_11], 0);

// Node 2
let incoming_edge_21: Rc<RefCell<Edge>> = Rc::new(RefCell::new(Edge::new(0, 1, BSpline::new(Vector::new(vec![0.0, 1.0, 2.0]), 2), 0)));
let incoming_edge_22: Rc<RefCell<Edge>> = Rc::new(RefCell::new(Edge::new(1, 1, BSpline::new(Vector::new(vec![0.5, 1.5, 2.5]), 2), 0)));

let outgoing_edge_21: Rc<RefCell<Edge>> = Rc::new(RefCell::new(Edge::new(0, 1, BSpline::new(Vector::new(vec![0.0, 1.0, 2.0]), 2), 0)));

let node_2: Node = Node::new(vec![incoming_edge_21, incoming_edge_22], vec![outgoing_edge_21], 0);

// Layer
let nodes: Vec<Rc<RefCell<Node>>> = vec![node_1.clone(), node_2.clone()].iter().map(|node| Rc::new(RefCell::new(node.clone()))).collect();
let layer: Layer = Layer::new(nodes);

let inputs: Matrix = Matrix::new(vec![Vector::from(vec![0.1, 0.2]), Vector::from(vec![0.3, 0.4])]);
let upstream_gradient: Vector = Vector::from(vec![0.4, 0.8]);

layer.backward(inputs.clone(), upstream_gradient.clone()).unwrap();

// Node 1:
// Incoming edge 1
let incoming_edge_1: RefMut<Edge> = node_1.incoming[0].borrow_mut();
assert_is_close!(incoming_edge_1.gradient[0], incoming_edge_1.clone().spline.basis(0, incoming_edge_1.spline.degree, inputs[0][0]) * upstream_gradient[0], 1e-3);
assert_is_close!(incoming_edge_1.gradient[1], incoming_edge_1.clone().spline.basis(1, incoming_edge_1.spline.degree, inputs[0][0]) * upstream_gradient[0], 1e-3);
assert_is_close!(incoming_edge_1.gradient[2], incoming_edge_1.clone().spline.basis(2, incoming_edge_1.spline.degree, inputs[0][0]) * upstream_gradient[0], 1e-3);

// Incoming edge 2
let incoming_edge_2: RefMut<Edge> = node_1.incoming[1].borrow_mut();
assert_is_close!(incoming_edge_2.gradient[0], incoming_edge_2.clone().spline.basis(0, incoming_edge_2.spline.degree, inputs[0][1]) * upstream_gradient[0], 1e-3);
assert_is_close!(incoming_edge_2.gradient[1], incoming_edge_2.clone().spline.basis(1, incoming_edge_2.spline.degree, inputs[0][1]) * upstream_gradient[0], 1e-3);
assert_is_close!(incoming_edge_2.gradient[2], incoming_edge_2.clone().spline.basis(2, incoming_edge_2.spline.degree, inputs[0][1]) * upstream_gradient[0], 1e-3);

// Node 2:
// Incoming edge 1
let incoming_edge_1: RefMut<Edge> = node_2.incoming[0].borrow_mut();
assert_is_close!(incoming_edge_1.gradient[0], incoming_edge_1.clone().spline.basis(0, incoming_edge_1.spline.degree, inputs[1][0]) * upstream_gradient[1], 1e-3);
assert_is_close!(incoming_edge_1.gradient[1], incoming_edge_1.clone().spline.basis(1, incoming_edge_1.spline.degree, inputs[1][0]) * upstream_gradient[1], 1e-3);
assert_is_close!(incoming_edge_1.gradient[2], incoming_edge_1.clone().spline.basis(2, incoming_edge_1.spline.degree, inputs[1][0]) * upstream_gradient[1], 1e-3);

// Incoming edge 2
let incoming_edge_2: RefMut<Edge> = node_2.incoming[1].borrow_mut();
assert_is_close!(incoming_edge_2.gradient[0], incoming_edge_2.clone().spline.basis(0, incoming_edge_2.spline.degree, inputs[1][1]) * upstream_gradient[1], 1e-3);
assert_is_close!(incoming_edge_2.gradient[1], incoming_edge_2.clone().spline.basis(1, incoming_edge_2.spline.degree, inputs[1][1]) * upstream_gradient[1], 1e-3);
assert_is_close!(incoming_edge_2.gradient[2], incoming_edge_2.clone().spline.basis(2, incoming_edge_2.spline.degree, inputs[1][1]) * upstream_gradient[1], 1e-3);
}

#[test]
#[should_panic]
fn layer_backward_wrong_input_dims_fail() {
// Node 1
let incoming_edge_11: Rc<RefCell<Edge>> = Rc::new(RefCell::new(Edge::new(0, 0, BSpline::new(Vector::new(vec![1.0, 2.0, 3.0]), 2), 0)));
let incoming_edge_12: Rc<RefCell<Edge>> = Rc::new(RefCell::new(Edge::new(1, 0, BSpline::new(Vector::new(vec![1.5, 2.5, 3.5]), 2), 0)));

let outgoing_edge_11: Rc<RefCell<Edge>> = Rc::new(RefCell::new(Edge::new(0, 0, BSpline::new(Vector::new(vec![0.5, 1.5, 2.5]), 2), 0)));

let node_1: Node = Node::new(vec![incoming_edge_11, incoming_edge_12], vec![outgoing_edge_11], 0);

// Node 2
let incoming_edge_21: Rc<RefCell<Edge>> = Rc::new(RefCell::new(Edge::new(0, 1, BSpline::new(Vector::new(vec![0.0, 1.0, 2.0]), 2), 0)));
let incoming_edge_22: Rc<RefCell<Edge>> = Rc::new(RefCell::new(Edge::new(1, 1, BSpline::new(Vector::new(vec![0.5, 1.5, 2.5]), 2), 0)));

let outgoing_edge_21: Rc<RefCell<Edge>> = Rc::new(RefCell::new(Edge::new(0, 1, BSpline::new(Vector::new(vec![0.0, 1.0, 2.0]), 2), 0)));

let node_2: Node = Node::new(vec![incoming_edge_21, incoming_edge_22], vec![outgoing_edge_21], 0);

// Layer
let nodes: Vec<Rc<RefCell<Node>>> = vec![node_1.clone(), node_2.clone()].iter().map(|node| Rc::new(RefCell::new(node.clone()))).collect();
let layer: Layer = Layer::new(nodes);

let inputs: Matrix = Matrix::new(vec![Vector::from(vec![0.1, 0.2])]);
let upstream_gradient: Vector = Vector::from(vec![0.4, 0.8]);

layer.backward(inputs.clone(), upstream_gradient.clone()).unwrap();
}

#[test]
#[should_panic]
fn layer_backward_wrong_gradient_dims_fail() {
// Node 1
let incoming_edge_11: Rc<RefCell<Edge>> = Rc::new(RefCell::new(Edge::new(0, 0, BSpline::new(Vector::new(vec![1.0, 2.0, 3.0]), 2), 0)));
let incoming_edge_12: Rc<RefCell<Edge>> = Rc::new(RefCell::new(Edge::new(1, 0, BSpline::new(Vector::new(vec![1.5, 2.5, 3.5]), 2), 0)));

let outgoing_edge_11: Rc<RefCell<Edge>> = Rc::new(RefCell::new(Edge::new(0, 0, BSpline::new(Vector::new(vec![0.5, 1.5, 2.5]), 2), 0)));

let node_1: Node = Node::new(vec![incoming_edge_11, incoming_edge_12], vec![outgoing_edge_11], 0);

// Node 2
let incoming_edge_21: Rc<RefCell<Edge>> = Rc::new(RefCell::new(Edge::new(0, 1, BSpline::new(Vector::new(vec![0.0, 1.0, 2.0]), 2), 0)));
let incoming_edge_22: Rc<RefCell<Edge>> = Rc::new(RefCell::new(Edge::new(1, 1, BSpline::new(Vector::new(vec![0.5, 1.5, 2.5]), 2), 0)));

let outgoing_edge_21: Rc<RefCell<Edge>> = Rc::new(RefCell::new(Edge::new(0, 1, BSpline::new(Vector::new(vec![0.0, 1.0, 2.0]), 2), 0)));

let node_2: Node = Node::new(vec![incoming_edge_21, incoming_edge_22], vec![outgoing_edge_21], 0);

// Layer
let nodes: Vec<Rc<RefCell<Node>>> = vec![node_1.clone(), node_2.clone()].iter().map(|node| Rc::new(RefCell::new(node.clone()))).collect();
let layer: Layer = Layer::new(nodes);

let inputs: Matrix = Matrix::new(vec![Vector::from(vec![0.1, 0.2]), Vector::from(vec![0.3, 0.4])]);
let upstream_gradient: Vector = Vector::from(vec![0.4]);

layer.backward(inputs.clone(), upstream_gradient.clone()).unwrap();
}

0 comments on commit 69ddaae

Please sign in to comment.