diff --git a/cirtorch/utils/evaluate.py b/cirtorch/utils/evaluate.py index 9b3ae60..431a962 100644 --- a/cirtorch/utils/evaluate.py +++ b/cirtorch/utils/evaluate.py @@ -102,7 +102,7 @@ def compute_map(ranks, gnd, kappas=[]): pos += 1 # get it to 1-based for j in np.arange(len(kappas)): kq = min(max(pos), kappas[j]); - prs[i, j] = (pos <= kq).sum() / kq + prs[i, j] = (pos <= kq).astype(float).sum() / kq pr = pr + prs[i, :] map = map / (nq - nempty) @@ -146,4 +146,4 @@ def compute_map_and_print(dataset, ranks, gnd, kappas=[1, 5, 10]): mapH, apsH, mprH, prsH = compute_map(ranks, gnd_t, kappas) print('>> {}: mAP E: {}, M: {}, H: {}'.format(dataset, np.around(mapE*100, decimals=2), np.around(mapM*100, decimals=2), np.around(mapH*100, decimals=2))) - print('>> {}: mP@k{} E: {}, M: {}, H: {}'.format(dataset, kappas, np.around(mprE*100, decimals=2), np.around(mprM*100, decimals=2), np.around(mprH*100, decimals=2))) \ No newline at end of file + print('>> {}: mP@k{} E: {}, M: {}, H: {}'.format(dataset, kappas, np.around(mprE*100, decimals=2), np.around(mprM*100, decimals=2), np.around(mprH*100, decimals=2)))