-
Notifications
You must be signed in to change notification settings - Fork 0
/
gather.py
77 lines (63 loc) · 2.54 KB
/
gather.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# -*- encoding:utf-8 -*-
"""
Author: Zhaopeng Qiu
Date: create at 2020-10-02
training script
CUDA_VISIBLE_DEVICES=0,1,2 python training.py training.gpus=3
"""
import os
import argparse
from tqdm import tqdm
import json
import scipy.stats as ss
import numpy as np
import pandas as pd
import math
import torch
from utils.log_util import NEPTUNE_API_TOKEN
from utils.log_util import convert_omegaconf_to_dict
from utils.train_util import set_seed
from utils.train_util import save_checkpoint_by_epoch
from utils.eval_util import group_labels
from utils.eval_util import cal_metric
def gather(output_path, filenum, validate=False, save=True):
    """Merge per-process prediction shards, then optionally score and export them.

    Reads ``output_path + 'tmp_<i>.json'`` for every ``i`` in
    ``range(filenum)``. Each shard is a JSON object holding three parallel
    lists under the keys ``'imp'``, ``'labels'`` and ``'preds'``. Entries from
    all shards are regrouped by their ``'imp'`` value.

    Args:
        output_path: String prefix for every file that is read or written.
            It is concatenated as-is, so a directory must include its
            trailing path separator.
        filenum: Number of ``tmp_<i>.json`` shards to read.
        validate: When true, pass the grouped labels and predictions to
            ``cal_metric`` and print every metric it returns.
        save: When true, write two space-separated files next to the shards:
            ``score.txt`` (impression + comma-joined predictions, with a
            header row) and ``result.txt`` (impression + bracketed ranks,
            without a header row).

    Raises:
        ValueError: If the three lists of a shard differ in length.
    """
    preds = []
    labels = []
    imp_indexes = []
    for i in range(filenum):
        shard_path = output_path + 'tmp_{}.json'.format(i)
        with open(shard_path, 'r', encoding='utf-8') as f:
            cur_result = json.load(f)

        shard_imps = cur_result['imp']
        shard_labels = cur_result['labels']
        shard_preds = cur_result['preds']
        # The lists are consumed in lockstep below. zip() would silently drop
        # the surplus entries and misalign every later shard, so fail loudly.
        if not (len(shard_imps) == len(shard_labels) == len(shard_preds)):
            raise ValueError(
                "{}: 'imp', 'labels' and 'preds' must have equal lengths, "
                "got {}, {} and {}".format(
                    shard_path, len(shard_imps), len(shard_labels), len(shard_preds)
                )
            )

        imp_indexes += shard_imps
        labels += shard_labels
        preds += shard_preds

    print(len(preds))

    # De-duplicate without going through a set: set iteration follows hash
    # order, which left the row order of the output files unspecified.
    unique_keys = list(dict.fromkeys(imp_indexes))
    try:
        all_keys = sorted(unique_keys)
    except TypeError:
        # Keys of mixed, non-comparable types: keep first-seen order instead.
        all_keys = unique_keys

    labels_by_imp = {key: [] for key in all_keys}
    preds_by_imp = {key: [] for key in all_keys}
    for label, pred, key in zip(labels, preds, imp_indexes):
        labels_by_imp[key].append(label)
        preds_by_imp[key].append(pred)

    if validate:
        all_labels = [labels_by_imp[key] for key in all_keys]
        all_preds = [preds_by_imp[key] for key in all_keys]

        metric_list = [x.strip() for x in "group_auc || mean_mrr || ndcg@5;10".split("||")]
        ret = cal_metric(all_labels, all_preds, metric_list)
        for metric, val in ret.items():
            print("Epoch: {}, {}: {}".format(1, metric, val))

    if save:
        final_arr = []
        for key in all_keys:
            group_scores = np.array(preds_by_imp[key])
            label_str = ','.join(map(str, np.array(labels_by_imp[key]).astype(int)))
            pred_str = ','.join(map(str, group_scores.astype(float)))

            # Negate so that the highest score receives rank 1. Tied scores
            # get rankdata's default (averaged) rank, truncated by astype(int).
            rank = ss.rankdata(-group_scores).astype(int).tolist()
            rank_str = '[' + ','.join(map(str, rank)) + ']'

            final_arr.append([key, label_str, pred_str, rank_str])

        fdf = pd.DataFrame(final_arr, columns=['impression', 'labels', 'preds', 'ranks'])
        fdf.drop(columns=['labels', 'ranks']).to_csv(output_path + 'score.txt', sep=' ', index=False)
        fdf.drop(columns=['labels', 'preds']).to_csv(output_path + 'result.txt', header=False, sep=' ', index=False)