-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
36 lines (32 loc) · 1004 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import matplotlib.pyplot as plt
# Plotting activation functions and their derivatives
def plot_fn(x, fn, fn_name="", save=False):
    """Plot a function and its derivative over the same inputs.

    Args:
        x: Values used for the horizontal axis and passed to ``fn``.
        fn: Object exposing ``forward(x)`` (drawn as "f(x)") and
            ``backward(x)`` (drawn as "d/dx f(x)").
        fn_name: Title of the figure; also the stem of the saved file.
        save: When true, write the figure to ``fn_name + ".png"`` before
            it is shown.
    """
    # Evaluate forward before backward, and both before any drawing starts.
    curves = (
        (fn.forward(x), "f(x)"),
        (fn.backward(x), "d/dx f(x)"),
    )
    for values, curve_label in curves:
        plt.plot(x, values, label=curve_label)

    plt.title(fn_name)
    plt.legend()

    # Save first, then show.
    if save:
        plt.savefig(fn_name + ".png")
    plt.show()
# Plotting loss
def plot_loss(train_loss_arr, valid_loss_arr, name, save=False):
    """Plot training and validation loss curves on one figure.

    Args:
        train_loss_arr: Loss values drawn first (legend entry "train").
        valid_loss_arr: Loss values drawn second (legend entry "valid").
        name: Appended to the title "Loss using " and used as the prefix
            of the saved file name.
        save: When true, write the figure to ``name + "loss.png"`` before
            it is shown.
    """
    # Draw order must match the legend label order given below.
    for losses in (train_loss_arr, valid_loss_arr):
        plt.plot(losses)

    plt.title("Loss using " + name)
    plt.xlabel("epochs")
    plt.ylabel("loss")
    plt.legend(["train", "valid"])

    if save:
        plt.savefig(name + "loss.png")
    plt.show()
# Plotting predictions
def plot_predictions(X_test, y_test, y_preds, name, save=False):
    """Scatter true targets and predictions against the same inputs.

    Args:
        X_test: Values for the horizontal axis of both scatter series.
        y_test: Values drawn with the legend label "true".
        y_preds: Values drawn with the legend label "pred".
        name: Appended to the title "Predictions using " and used as the
            prefix of the saved file name.
        save: When true, write the figure to ``name + "pred.png"`` before
            it is shown.
    """
    marker_size = 0.5
    series = ((y_test, "true"), (y_preds, "pred"))
    for values, series_label in series:
        plt.scatter(X_test, values, label=series_label, s=marker_size)

    plt.title("Predictions using " + name)
    plt.legend()

    if save:
        plt.savefig(name + "pred.png")
    plt.show()