Skip to content

Commit

Permalink
Create helpers.py
Browse files Browse the repository at this point in the history
  • Loading branch information
KOSASIH authored Aug 4, 2024
1 parent 840da5b commit dfedde9
Showing 1 changed file with 66 additions and 0 deletions.
66 changes: 66 additions & 0 deletions blockchain_integration/pi_network/daictd/utils/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import os
import logging
import torch
import numpy as np
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

def create_logger(name, level=logging.INFO):
logger = logging.getLogger(name)
logger.setLevel(level)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger

def ensure_dir(path):
if not os.path.exists(path):
os.makedirs(path)

def save_checkpoint(model, optimizer, epoch, path):
torch.save({
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'epoch': epoch
}, path)

def load_checkpoint(model, optimizer, path):
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
return checkpoint['epoch']

def calculate_metrics(y_true, y_pred):
accuracy = accuracy_score(y_true, y_pred)
report = classification_report(y_true, y_pred)
matrix = confusion_matrix(y_true, y_pred)
return accuracy, report, matrix

def plot_confusion_matrix(matrix, classes, path):
import matplotlib.pyplot as plt
plt.imshow(matrix, interpolation='nearest', cmap='Blues')
plt.title('Confusion Matrix')
plt.colorbar()
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, rotation=45)
plt.yticks(tick_marks, classes)
plt.tight_layout()
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.savefig(path)

def plot_loss_curve(losses, path):
import matplotlib.pyplot as plt
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Loss Curve')
plt.savefig(path)

def plot_accuracy_curve(accuracies, path):
import matplotlib.pyplot as plt
plt.plot(accuracies)
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Accuracy Curve')
plt.savefig(path)

0 comments on commit dfedde9

Please sign in to comment.