-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path helpfunctions.py
29 lines (23 loc) · 940 Bytes
/
helpfunctions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from sklearn.metrics import classification_report, f1_score, precision_score, recall_score
from sklearn.metrics import accuracy_score
import torch
import json
import os
def accuracy(output, target):
    """Return the classification accuracy of ``output`` against ``target``.

    Args:
        output: Predicted labels, one per sample.
        target: Ground-truth labels, aligned element-wise with ``output``.

    Returns:
        The value of ``sklearn.metrics.accuracy_score`` with its default
        ``normalize=True``, i.e. the fraction of correctly predicted samples.
    """
    # accuracy_score's signature is (y_true, y_pred). The score itself is
    # symmetric, but passing by keyword keeps the argument roles explicit
    # and consistent with the other metric helpers in this module.
    return accuracy_score(y_true=target, y_pred=output)
def evaluation_metrics(preds, target) -> tuple:
    """Compute macro-averaged precision/recall/F1 plus a per-class text report.

    Args:
        preds: Predicted labels, one per sample.
        target: Ground-truth labels, aligned element-wise with ``preds``.

    Returns:
        A 2-tuple ``(scores, report)``:
        ``scores`` is a dict with keys ``'precision'``, ``'recall'`` and
        ``'F1_score'`` holding macro-averaged values, and ``report`` is the
        string produced by ``sklearn.metrics.classification_report``.
    """
    # scikit-learn metrics take (y_true, y_pred). Passing predictions first
    # exchanges macro precision and recall and makes the report's "support"
    # column count predictions instead of true labels, so pass by keyword.
    eval_metrics = classification_report(y_true=target, y_pred=preds)
    precision = precision_score(y_true=target, y_pred=preds, average='macro')
    recall = recall_score(y_true=target, y_pred=preds, average='macro')
    F1_score = f1_score(y_true=target, y_pred=preds, average='macro')
    return ({'precision': precision, 'recall': recall, 'F1_score': F1_score}, eval_metrics)
def save_checkpoint(config, filename, model_name, args):
    """Write a model's weights and its configuration to disk.

    Args:
        config: Mapping serialised with ``json.dumps``; its ``'model_name'``
            entry names the config file.
        filename: Destination path handed to ``torch.save`` for the weights.
        model_name: Despite the name, the model object itself; its
            ``state_dict()`` is what gets saved.
        args: Object exposing a ``model_path`` attribute: the directory the
            config file is written into (it must already exist).

    Side effects:
        Creates/overwrites ``filename`` and
        ``<args.model_path>/<config['model_name']>.txt`` (config as JSON text).
    """
    # The model object arrives under the parameter name ``model_name``.
    weights = model_name.state_dict()
    torch.save(weights, filename)

    config_file = os.path.join(args.model_path, f"{config['model_name']}.txt")
    with open(config_file, 'w') as handle:
        handle.write(json.dumps(config))