-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_for_paper.py
165 lines (153 loc) · 10.4 KB
/
test_for_paper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
from matplotlib import pyplot
# Recorded loss curves for the "dibert" and "bert" runs, ten values per list.
# The commented-out plotting code further down draws them against epochs
# 1..10, labelling the *_cls series "nsp", the *_mlm series "mlm" and the
# *_pp series "pp".

# --- dibert: training losses ---
dibert_train_loss_cls = [
    0.48430071738692204,
    0.27730153397313245,
    0.22541556088869832,
    0.19596574675303066,
    0.17347608814839732,
    0.1533835646570595,
    0.1350843780798866,
    0.11693538050599822,
    0.10026465734737253,
    0.08407624585322429,
]
dibert_train_loss_mlm = [
    5.491347128185911,
    4.009450922119595,
    3.6397895327396648,
    3.4403899848028776,
    3.301886940995798,
    3.1912466183986123,
    3.0961167443516895,
    3.0098005671304837,
    2.9289797277814826,
    2.8540517531392173,
]
dibert_train_loss_pp = [
    3.5331437302161115,
    1.6951116268025193,
    1.4921578197263674,
    1.3958533563132551,
    1.3340558741198774,
    1.2872319190962873,
    1.2480977450088622,
    1.2134283256167906,
    1.1815208615803547,
    1.152496820954521,
]

# --- dibert: validation losses ---
dibert_valid_loss_cls = [
    0.3494640473371897,
    0.2610560226134765,
    0.23463749027787112,
    0.2275528297210351,
    0.24554021456875863,
    0.22320102334786684,
    0.21512626419082667,
    0.24130792909134657,
    0.25086519342775526,
    0.2650689552323176,
]
dibert_valid_loss_mlm = [
    4.085183953016232,
    3.5909610772744203,
    3.403571170415634,
    3.3098633570548817,
    3.24093768902314,
    3.1814246422205215,
    3.1446786623734693,
    3.1120222128354587,
    3.082875459622114,
    3.0752777270781686,
]
dibert_valid_loss_pp = [
    1.6581118253561167,
    1.3952396832979643,
    1.300964830777584,
    1.2636633762946496,
    1.228268465323326,
    1.2084958599163935,
    1.1889893382023542,
    1.1719615153777294,
    1.1590523787033864,
    1.1547587474187215,
]

# --- bert: training losses ---
bert_train_loss_cls = [
    0.49274437086129486,
    0.2776497194654,
    0.22143692823858196,
    0.18880421397829808,
    0.1638111817235851,
    0.14140862471078494,
    0.1208944835866969,
    0.10145956052502882,
    0.08360698709181565,
    0.06753760461765418,
]
bert_train_loss_mlm = [
    5.494684397538439,
    3.942632295746839,
    3.577555536184287,
    3.378657437031348,
    3.2338501089702874,
    3.1162113070817963,
    3.0131844703244166,
    2.9196099015384727,
    2.833291628005274,
    2.755476903805583,
]

# --- bert: validation losses ---
bert_valid_loss_cls = [
    0.324016851072128,
    0.25400943694970546,
    0.23388616318504016,
    0.224633837529482,
    0.23575505383121662,
    0.21916134444375832,
    0.22403994269955615,
    0.24454388925089285,
    0.25829780373053673,
    0.27586218988379607,
]
bert_valid_loss_mlm = [
    4.040164391199748,
    3.5502464905763285,
    3.372568300442818,
    3.2828326971102983,
    3.216287793868627,
    3.155730175360655,
    3.126622574145977,
    3.097288191624177,
    3.0712033638587366,
    3.063728133226052,
]
# epoch_vals = [i + 1 for i in range(10)]
#
# pyplot.subplot(111)
# pyplot.title("NSP loss")
# #pyplot.plot(epoch_vals, dibert_train_loss_cls, 'm--', label='nsp')
# pyplot.plot(epoch_vals, dibert_train_loss_cls,'c', label='dibert')
# pyplot.plot(epoch_vals, bert_train_loss_cls,'y', label='bert')
# pyplot.legend()
# pyplot.xticks(epoch_vals)
# pyplot.show()
# pyplot.subplot(211)
# pyplot.title("bert-base (training loss)")
# pyplot.plot(epoch_vals, bert_train_loss_cls, 'm--', label='nsp')
# pyplot.plot(epoch_vals, bert_train_loss_mlm,'g', label='mlm')
# pyplot.legend()
# pyplot.xticks(epoch_vals)
#
#
# pyplot.subplot(212)
# pyplot.title("bert-base (validation loss)")
# pyplot.plot(epoch_vals, bert_valid_loss_cls, 'm--',label='nsp')
# pyplot.plot(epoch_vals, bert_valid_loss_mlm,'g', label='mlm')
# pyplot.legend()
# pyplot.xticks(epoch_vals)
# pyplot.show()
# pyplot.subplot(211)
# pyplot.title("dibert (training loss)")
# pyplot.plot(epoch_vals, dibert_train_loss_cls, 'm--', label='nsp')
# pyplot.plot(epoch_vals, dibert_train_loss_mlm,'g', label='mlm')
# pyplot.plot(epoch_vals, dibert_train_loss_pp,'r', label='pp')
# pyplot.legend()
# pyplot.xticks(epoch_vals)
#
#
# pyplot.subplot(212)
# pyplot.title("dibert (validation loss)")
# pyplot.plot(epoch_vals, dibert_valid_loss_cls, 'm--',label='nsp')
# pyplot.plot(epoch_vals, dibert_valid_loss_mlm,'g', label='mlm')
# pyplot.plot(epoch_vals, dibert_valid_loss_pp,'r', label='pp')
# pyplot.legend()
# pyplot.xticks(epoch_vals)
# pyplot.show()
#----------------------imdb----------------------------------
# imdb_f1_bert = [0.8408172063345226, 0.8572341512297569, 0.8564364904925332, 0.8641966873476424, 0.8674390128488643, 0.8646332996383781
# ,0.8526200634974098, 0.8685884027088417, 0.869756559112654, 0.8656944197888126, 0.8626968126851379, 0.8626519645435392,
# 0.8706458145352924, 0.869987089782801, 0.8709972953230705]
# imdb_f1_dibert = [0.8258576987994131, 0.847631289701772 , 0.8650821713927302, 0.8562090277953578, 0.865538673215382, 0.8738399006857795,
# 0.876353175036712, 0.8614125050477766, 0.8617090379690326, 0.8706414546198776, 0.8735929905875562, 0.8684409937753459,
# 0.8563257754589373, 0.8561276722394221, 0.8787846377214982]
#
# epoch_vals = [i + 1 for i in range(15)]
#
# pyplot.subplot(111)
# pyplot.title("IMDb validation f1")
# #pyplot.plot(epoch_vals, dibert_train_loss_cls, 'm--', label='nsp')
# pyplot.plot(epoch_vals, imdb_f1_dibert,'c', label='dibert')
# pyplot.plot(epoch_vals, imdb_f1_bert,'y', label='bert')
# pyplot.legend()
# pyplot.xticks(epoch_vals)
# pyplot.show()
#--------------------------------------------------------LIAR-----------------------------------------------------------------------
# LIAR_f1_bert = [0.2141511610705922, 0.24549419694299876, 0.24967968834484128, 0.27142660504838734, 0.28863574178286155, 0.2868486956076384, 0.28653782227396596,
# 0.2848395978792812, 0.2844129952171855, 0.2779092366154981, 0.2698443784498295, 0.27450889110697424, 0.2702730591415788, 0.25135198559826466,
# 0.26346206901351504]
# LIAR_f1_dibert = [0.2058103037509316, 0.2234439347525856, 0.2404510649779266, 0.26582794925315395, 0.25956675181817024, 0.26999219272084024, 0.2805245320365466,
# 0.2953901962676256, 0.29326074256169493, 0.29919507672752943, 0.28452609646893334, 0.2930664754611869, 0.2761156124841313, 0.28838163266515515,
# 0.28737331811778577]
#
# epoch_vals = [i + 1 for i in range(15)]
#
# pyplot.subplot(111)
# pyplot.title("LIAR validation f1")
# #pyplot.plot(epoch_vals, dibert_train_loss_cls, 'm--', label='nsp')
# pyplot.plot(epoch_vals, LIAR_f1_dibert,'c', label='dibert')
# pyplot.plot(epoch_vals, LIAR_f1_bert,'y', label='bert')
# pyplot.legend()
# pyplot.xticks(epoch_vals)
# pyplot.show()
#-------------------------------------MRPC-------------------------------------------------------------------------
# MRPC_f1_bert = [0.6392442836360923, 0.6358449174781146, 0.6490738945607722, 0.6691607521071038, 0.6848115097432161, 0.6771827598964113, 0.7037966628959276, 0.6979447288844097,
# 0.6918167973279339, 0.7249419504643964, 0.7045722021194409, 0.6918167973279339, 0.6751240212138732, 0.6799055361796579, 0.6932044832102882]
#
# MRPC_f1_dibert = [0.6347037734577972, 0.6341491386462349, 0.6426550079491257, 0.678974685849572, 0.6691607521071038, 0.7091569659577409, 0.7055000616598842, 0.720451841066178,
# 0.7298672588255077, 0.723790620816208, 0.720793465901212, 0.7184789019857505, 0.7131671831552928, 0.7058043732676318, 0.6982793429384235]
#
# epoch_vals = [i + 1 for i in range(15)]
#
# pyplot.subplot(111)
# pyplot.title("MRPC validation f1")
# #pyplot.plot(epoch_vals, dibert_train_loss_cls, 'm--', label='nsp')
# pyplot.plot(epoch_vals, MRPC_f1_dibert,'c', label='dibert')
# pyplot.plot(epoch_vals, MRPC_f1_bert,'y', label='bert')
# pyplot.legend()
# pyplot.xticks(epoch_vals)
# pyplot.show()
#---------------------------------------QNLI-----------------------------------------------
# QNLI_f1_bert = [0.7249945802361246, 0.7321657964478685, 0.7408711892034531, 0.749497741236236, 0.7441692232284856, 0.7478064991571008, 0.748988216870349, 0.7469680694417994,
# 0.7437459878047123, 0.7419442323373805, 0.7494922470767863, 0.7480913173408031, 0.7437968400945972, 0.7458135178937179, 0.743540244731723]
# QNLI_f1_dibert = [0.7184768746697954, 0.7533232759331996, 0.7641757767252148, 0.7748529674129947, 0.7827192289389915, 0.7761366324800218, 0.7768448533949996, 0.7752496449909243,
# 0.7764310744988681, 0.7733141590996597, 0.7721707954170307, 0.7739281804521928, 0.7750686305448008, 0.7750091417076566, 0.7718832643910333]
#
# epoch_vals = [i + 1 for i in range(15)]
#
# pyplot.subplot(111)
# pyplot.title("QNLI validation f1")
# #pyplot.plot(epoch_vals, dibert_train_loss_cls, 'm--', label='nsp')
# pyplot.plot(epoch_vals, QNLI_f1_dibert,'c', label='dibert')
# pyplot.plot(epoch_vals, QNLI_f1_bert,'y', label='bert')
# pyplot.legend()
# pyplot.xticks(epoch_vals)
# pyplot.show()
#----------------------------------SciTail--------------------------------------------------------------
# SCITAIL_f1_bert = [0.6450997604081222, 0.7162908595459659, 0.7356068811312184, 0.7829602695397339, 0.7736844037164956, 0.7961258496612589, 0.7738163841972476, 0.7835895724242559,
# 0.7958089287550607, 0.7982090244656413, 0.8073368370994426, 0.7918238559701313, 0.7899333639661358, 0.7936972173363549, 0.7891130339307687]
#
# SCITAIL_f1_dibert = [0.6945043625120062, 0.7584071407865165, 0.7486396019283442, 0.7129208417152461, 0.8094545436977798, 0.8350692614726485, 0.7951622538107308, 0.8041854766851629,
# 0.8166855531620486, 0.8220800302491077, 0.8310896156395375, 0.8257419840364942, 0.8165957755254736, 0.8134405828863074, 0.7932167642704645]
#
# epoch_vals = [i + 1 for i in range(15)]
#
# pyplot.subplot(111)
# pyplot.title("SciTail validation f1")
# #pyplot.plot(epoch_vals, dibert_train_loss_cls, 'm--', label='nsp')
# pyplot.plot(epoch_vals, SCITAIL_f1_dibert,'c', label='dibert')
# pyplot.plot(epoch_vals, SCITAIL_f1_bert,'y', label='bert')
# pyplot.legend()
# pyplot.xticks(epoch_vals)
# pyplot.show()
#--------------------------------SST-2--------------------------------------------------------------------------------
# Scores shown in the "SST2 validation f1" figure: fifteen values per model,
# drawn against epoch_vals (1..15) by the plotting code that follows.
SST2_f1_bert = [
    0.8302457643368804,
    0.8370839311854139,
    0.8371619598216342,
    0.8509017458856715,
    0.842752284635203,
    0.8246347286309611,
    0.8440424403925512,
    0.8394554548384772,
    0.8383161508026007,
    0.8463035974401484,
    0.8463399776217456,
    0.852020409513973,
    0.8417356234792666,
    0.8348728100174382,
    0.8381832833704805,
]
SST2_f1_dibert = [
    0.8173101451064136,
    0.8427290754126229,
    0.849381770115715,
    0.84745599039194,
    0.8463359339161902,
    0.8343738950816012,
    0.8554893767922878,
    0.8497119074877,
    0.8497736057614712,
    0.8388894697038612,
    0.8472145995679639,
    0.8497815103023197,
    0.8554680738272461,
    0.8459855806163166,
    0.8415428954797012,
]
# Draw both SST-2 series on a single set of axes and display the figure.
epoch_vals = list(range(1, 16))  # x positions 1..15, one per recorded value

pyplot.subplot(1, 1, 1)  # a 1x1 grid: the whole figure is one plot
pyplot.title("SST2 validation f1")
pyplot.plot(epoch_vals, SST2_f1_dibert, "c", label="dibert")
pyplot.plot(epoch_vals, SST2_f1_bert, "y", label="bert")
pyplot.legend()
pyplot.xticks(epoch_vals)  # put a tick at every value in epoch_vals
pyplot.show()