
Commit

run.sh modified
astorfi committed Jul 30, 2017
1 parent 807cf79 commit a7f7249
Showing 4 changed files with 105 additions and 13 deletions.
36 changes: 33 additions & 3 deletions code/4-ROC_PR_curve/PlotHIST.py
@@ -10,8 +10,9 @@
 import scipy.io as sio
 from sklearn import *
 import matplotlib.pyplot as plt
+import os
 
-def Plot_HIST_Fn(label,distance, phase, num_bins = 50):
+def Plot_HIST_Fn(label,distance, save_path, num_bins = 50):
 
     dissimilarity = distance[:]
     gen_dissimilarity_original = []
@@ -27,6 +28,35 @@ def Plot_HIST_Fn(label,distance, phase, num_bins = 50):
     plt.hist(gen_dissimilarity_original, bins, alpha=0.5, facecolor='blue', normed=False, label='gen_dist_original')
     plt.hist(imp_dissimilarity_original, bins, alpha=0.5, facecolor='red', normed=False, label='imp_dist_original')
     plt.legend(loc='upper right')
-    plt.title(phase + '_' + 'OriginalFeatures_Histogram.jpg')
+    plt.title('OriginalFeatures_Histogram.jpg')
     plt.show()
-    fig.savefig(phase + '_' + 'OriginalFeatures_Histogram.jpg')
+    fig.savefig(save_path)
+
+if __name__ == '__main__':
+
+    tf.app.flags.DEFINE_string(
+        'evaluation_dir', '../../results/SCORES',
+        'Directory where the evaluation score vectors are stored.')
+
+    tf.app.flags.DEFINE_string(
+        'plot_dir', '../../results/PLOTS',
+        'Directory where plots are saved to.')
+
+    tf.app.flags.DEFINE_integer(
+        'num_bins', 50,
+        'Number of bins for plotting the histogram.')
+
+    # Store all elements in the FLAGS structure.
+    FLAGS = tf.app.flags.FLAGS
+
+    # Loading the necessary data.
+    score = np.load(os.path.join(FLAGS.evaluation_dir,'score_vector.npy'))
+    label = np.load(os.path.join(FLAGS.evaluation_dir,'target_label_vector.npy'))
+    save_path = os.path.join(FLAGS.plot_dir,'Histogram.jpg')
+
+    # Creating the output directory if it does not exist.
+    if not os.path.exists(FLAGS.plot_dir):
+        os.makedirs(FLAGS.plot_dir)
+
+    Plot_HIST_Fn(label, score, save_path, FLAGS.num_bins)
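
The loop that fills gen_dissimilarity_original and imp_dissimilarity_original lies outside this hunk, so only its inputs and outputs are visible here. A minimal sketch of the split it appears to perform, assuming label == 1 marks genuine (same-speaker) pairs and label == 0 marks impostor pairs (the helper name split_by_label is not in the repository):

    import numpy as np

    def split_by_label(label, distance):
        # Assumed convention: 1 = genuine pair, 0 = impostor pair.
        label = np.asarray(label).ravel()
        distance = np.asarray(distance).ravel()
        gen_dissimilarity_original = distance[label == 1].tolist()
        imp_dissimilarity_original = distance[label == 0].tolist()
        return gen_dissimilarity_original, imp_dissimilarity_original

Either list can then be passed straight to plt.hist with the chosen number of bins, as in the hunk above.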

30 changes: 27 additions & 3 deletions code/4-ROC_PR_curve/PlotPR.py
@@ -10,11 +10,12 @@
 import scipy.io as sio
 from sklearn import *
 import matplotlib.pyplot as plt
+import os
 
-def Plot_PR_Fn(label,distance,phase):
+def Plot_PR_Fn(label,distance,save_path):
 
-    precision, recall, thresholds = metrics.precision_recall_curve(label, -distance, pos_label=1, sample_weight=None)
-    AP = metrics.average_precision_score(label, -distance, average='macro', sample_weight=None)
+    precision, recall, thresholds = metrics.precision_recall_curve(label, distance, pos_label=1, sample_weight=None)
+    AP = metrics.average_precision_score(label, distance, average='macro', sample_weight=None)
 
     # AP(average precision) calculation.
     # This score corresponds to the area under the precision-recall curve.
@@ -38,5 +39,28 @@ def Plot_PR_Fn(label,distance,phase):
     # plt.text(0.5, 0.5, 'AP = ' + str(AP), fontdict=None)
     plt.grid()
     plt.show()
-    fig.savefig(phase + '_' + 'PR.jpg')
+    fig.savefig(save_path)
+
+if __name__ == '__main__':
+
+    tf.app.flags.DEFINE_string(
+        'evaluation_dir', '../../results/SCORES',
+        'Directory where the evaluation score vectors are stored.')
+
+    tf.app.flags.DEFINE_string(
+        'plot_dir', '../../results/PLOTS',
+        'Directory where plots are saved to.')
+
+    # Store all elements in the FLAGS structure.
+    FLAGS = tf.app.flags.FLAGS
+
+    # Loading the necessary data.
+    score = np.load(os.path.join(FLAGS.evaluation_dir,'score_vector.npy'))
+    label = np.load(os.path.join(FLAGS.evaluation_dir,'target_label_vector.npy'))
+    save_path = os.path.join(FLAGS.plot_dir,'PR.jpg')
+
+    # Creating the output directory if it does not exist.
+    if not os.path.exists(FLAGS.plot_dir):
+        os.makedirs(FLAGS.plot_dir)
+
+    Plot_PR_Fn(label, score, save_path)
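
Dropping the minus sign in front of distance assumes that score_vector.npy now holds scores that increase for the pos_label=1 (genuine) class; sklearn's precision_recall_curve and average_precision_score both expect higher scores to indicate the positive class. A quick sanity check for whichever orientation the saved scores actually have (this snippet is not part of the repository):

    import numpy as np
    from sklearn import metrics

    score = np.load('results/SCORES/score_vector.npy')
    label = np.load('results/SCORES/target_label_vector.npy')

    auc = metrics.roc_auc_score(label, score)
    # An AUC well below 0.5 suggests the scores still behave like
    # distances (smaller = more similar) and should be negated
    # before being handed to the plotting functions.
    print('AUC with the current sign convention: %.3f' % auc)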
37 changes: 32 additions & 5 deletions code/4-ROC_PR_curve/PlotROC.py
@@ -10,12 +10,14 @@
 import scipy.io as sio
 from sklearn import *
 import matplotlib.pyplot as plt
+import os
 
 
-def Plot_ROC_Fn(label,distance,phase):
-
-    fpr, tpr, thresholds = metrics.roc_curve(label, -distance, pos_label=1)
-    AUC = metrics.roc_auc_score(label, -distance, average='macro', sample_weight=None)
+def Plot_ROC_Fn(label,distance,save_path):
+
+    fpr, tpr, thresholds = metrics.roc_curve(label, distance, pos_label=1)
+    AUC = metrics.roc_auc_score(label, distance, average='macro', sample_weight=None)
     # AP = metrics.average_precision_score(label, -distance, average='macro', sample_weight=None)
 
     # Calculating EER
@@ -37,7 +39,7 @@ def Plot_ROC_Fn(label,distance,phase):
     plt.setp(lines, linewidth=2, color='r')
     ax.set_xticks(np.arange(0, 1.1, 0.1))
     ax.set_yticks(np.arange(0, 1.1, 0.1))
-    plt.title(phase + '_' + 'ROC.jpg')
+    plt.title('ROC.jpg')
     plt.xlabel('False Positive Rate')
     plt.ylabel('True Positive Rate')
 
@@ -52,7 +54,32 @@ def Plot_ROC_Fn(label,distance,phase):
     # plt.text(0.5, 0.4, 'EER = ' + str(EER), fontdict=None)
     plt.grid()
     plt.show()
-    fig.savefig(phase + '_' + 'ROC.jpg')
+    fig.savefig(save_path)
+
+if __name__ == '__main__':
+
+    tf.app.flags.DEFINE_string(
+        'evaluation_dir', '../../results/SCORES',
+        'Directory where the evaluation score vectors are stored.')
+
+    tf.app.flags.DEFINE_string(
+        'plot_dir', '../../results/PLOTS',
+        'Directory where plots are saved to.')
+
+    # Store all elements in the FLAGS structure.
+    FLAGS = tf.app.flags.FLAGS
+
+    # Loading the scores and labels.
+    score = np.load(os.path.join(FLAGS.evaluation_dir,'score_vector.npy'))
+    label = np.load(os.path.join(FLAGS.evaluation_dir,'target_label_vector.npy'))
+    save_path = os.path.join(FLAGS.plot_dir,'ROC.jpg')
+
+    # Creating the output directory if it does not exist.
+    if not os.path.exists(FLAGS.plot_dir):
+        os.makedirs(FLAGS.plot_dir)
+
+    Plot_ROC_Fn(label, score, save_path)
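
The "# Calculating EER" step referenced above lies outside the changed hunks, so its exact implementation is not shown here. One common way to obtain the equal error rate from the fpr/tpr arrays that metrics.roc_curve returns is to find the operating point where the false positive rate equals the false negative rate; a sketch under that assumption (not necessarily the code in the file):

    import numpy as np
    from scipy.interpolate import interp1d
    from scipy.optimize import brentq
    from sklearn import metrics

    def compute_eer(label, score):
        fpr, tpr, _ = metrics.roc_curve(label, score, pos_label=1)
        # EER is where FPR == 1 - TPR on the ROC curve.
        eer = brentq(lambda x: 1.0 - x - interp1d(fpr, tpr)(x), 0.0, 1.0)
        return eer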




15 changes: 13 additions & 2 deletions run.sh
@@ -19,10 +19,21 @@ if [ $do_training = 'train' ]; then
 python -u ./code/2-enrollment/enrollment.py --development_dataset_path=$development_dataset --enrollment_dataset_path=$enrollment_dataset --checkpoint_dir=results/TRAIN_CNN_3D/ --enrollment_dir=results/Model
 
 # evaluation
-python -u ./code/3-evaluation/evaluation.py --development_dataset_path=$development_dataset --evaluation_dataset_path=$evaluation_dataset --checkpoint_dir=results/TRAIN_CNN_3D/ --evaluation_dir=results/ROC --enrollment_dir=results/Model
+python -u ./code/3-evaluation/evaluation.py --development_dataset_path=$development_dataset --evaluation_dataset_path=$evaluation_dataset --checkpoint_dir=results/TRAIN_CNN_3D/ --evaluation_dir=results/SCORES --enrollment_dir=results/Model
 
 # ROC curve
-python -u ./code/4-ROC_PR_curve/calculate_roc.py --evaluation_dir=results/ROC
+python -u ./code/4-ROC_PR_curve/calculate_roc.py --evaluation_dir=results/SCORES
+
+# Plot ROC
+python -u ./code/4-ROC_PR_curve/PlotROC.py --evaluation_dir=results/SCORES --plot_dir=results/PLOTS
+
+# Plot PR
+python -u ./code/4-ROC_PR_curve/PlotPR.py --evaluation_dir=results/SCORES --plot_dir=results/PLOTS
+
+# Plot HIST
+python -u ./code/4-ROC_PR_curve/PlotHIST.py --evaluation_dir=results/SCORES --plot_dir=results/PLOTS --num_bins=5
+
+
 else
