-
Notifications
You must be signed in to change notification settings - Fork 0
/
figure.py
272 lines (269 loc) · 17.9 KB
/
figure.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
import numpy as np
import matplotlib.pyplot as plt
# Hard-coded per-epoch results for the first figure ("results.png").
# Each entry is (legend label, line color, accuracy values, loss values).
# The accuracy values are drawn on the axis labelled "Test Set Accuracy" and
# the loss values on the twin axis labelled "Softmax Loss".
# NOTE(review): presumably these numbers were copied from training logs;
# the script itself does not show where they come from.
OPTIMIZER_SERIES = (
    (
        "Adam",
        "red",
        [
            0.4282, 0.5434, 0.5833, 0.6082, 0.6222, 0.6557, 0.6765, 0.7026, 0.742, 0.7549,
            0.7803, 0.7794, 0.7823, 0.8047, 0.8191, 0.8199, 0.828, 0.9035, 0.9151, 0.9146,
            0.9323, 0.9354, 0.9336, 0.9368, 0.9346, 0.9424, 0.9481, 0.9488, 0.9421, 0.9529,
            0.9452, 0.9527, 0.9513, 0.9563, 0.9572, 0.9574, 0.9525, 0.9552, 0.9561, 0.9628,
            0.9626, 0.9559, 0.943, 0.9644, 0.9625, 0.9581, 0.9604, 0.9624, 0.9639, 0.9629,
            0.9631, 0.9645, 0.9618, 0.9654, 0.9661, 0.962, 0.9676, 0.9642, 0.9641, 0.9632,
            0.9655, 0.9653, 0.9652, 0.9681, 0.9675, 0.9656, 0.963, 0.9673, 0.9662, 0.9631,
            0.9665, 0.9672, 0.9654, 0.966, 0.9702, 0.9645, 0.9666, 0.9639, 0.9627, 0.9604,
            0.9639, 0.96, 0.9627, 0.9688, 0.9669, 0.9682, 0.9707, 0.9702, 0.9676, 0.9677,
            0.9669, 0.9616, 0.9679, 0.9676, 0.9666, 0.9575, 0.9661, 0.9668, 0.9634, 0.9683,
        ],
        [
            17.638753167353393, 1.4611125062563706, 1.2686002929983424, 1.1642130352450455,
            1.1550783845499821, 1.1203496179266188, 1.0261231441068435, 0.949457245186089,
            0.923810770968423, 0.842690442217491, 0.7568512851613662, 0.7016375557731317,
            0.6627037191311477, 0.6335841216271795, 0.5856236032634335, 0.5568709074627255,
            0.5104414835811666, 0.4758004404734112, 0.39792111750576187,
            0.31823228109547785, 0.2956461605812125, 0.2908556377556769, 0.2805683706600124,
            0.24699231101056315, 0.24610961354632566, 0.2019300276251271,
            0.18123250858908782, 0.2161421921751923, 0.1881447746462724, 0.1844780072980682,
            0.18392596155957452, 0.14933110109111516, 0.16294582351638498,
            0.13063897684256415, 0.13844430032892474, 0.13496596731901916,
            0.13600027163542064, 0.14116021718878785, 0.10839958031695968,
            0.11324059779719013, 0.13150161439234404, 0.14779817677301077,
            0.12504880987278708, 0.14731961783605677, 0.11139472451135855,
            0.10019857272134514, 0.1139701861665705, 0.09419851242075421,
            0.09749947281133646, 0.09866514941393349, 0.09690221518972084,
            0.0960868565838637, 0.11466731106910394, 0.08726781768648123,
            0.08231706973207431, 0.08874804814793678, 0.07207096816858456,
            0.09443558117946665, 0.08247703765619459, 0.07174226401801671,
            0.06342621039180317, 0.09202577674363216, 0.06064555472580168,
            0.08049971520662447, 0.08052614847873334, 0.06590331391041575,
            0.08902756837242065, 0.08518901728970035, 0.08009671497380648,
            0.0665459784419944, 0.07645867101524711, 0.06619392585262274,
            0.07319662435092372, 0.07819022119164427, 0.06135357899921504,
            0.07601133991109248, 0.06387704764633366, 0.06177634120000859,
            0.07753122463407253, 0.07372071304469666, 0.06637060713976908,
            0.06210055983933071, 0.07570546242672052, 0.06355839833063788,
            0.06630168548633161, 0.06119699554116554, 0.06357179699981327,
            0.060724676797754784, 0.06710840717405289, 0.06700742729052761,
            0.06600019527579638, 0.08041566842168277, 0.06137213053685028,
            0.06607688599749306, 0.07513599347769509, 0.06905322847937205,
            0.06052056025925136, 0.05364871534160272, 0.06564646861667474,
            0.053048787866461125,
        ],
    ),
    (
        "AdaGrad",
        "forestgreen",
        [
            0.74, 0.7805, 0.8037, 0.8147, 0.8232, 0.8296, 0.8342, 0.8383, 0.8412, 0.844,
            0.8466, 0.8491, 0.852, 0.8535, 0.8542, 0.8561, 0.8577, 0.8582, 0.8595, 0.8607,
            0.862, 0.8624, 0.8631, 0.8636, 0.864, 0.8654, 0.8659, 0.8672, 0.8678, 0.8687,
            0.8697, 0.8707, 0.8717, 0.8728, 0.8735, 0.874, 0.8746, 0.8754, 0.8756, 0.8759,
            0.8756, 0.8762, 0.877, 0.8778, 0.8782, 0.8785, 0.8793, 0.8797, 0.8798, 0.8803,
            0.8802, 0.8805, 0.8806, 0.8811, 0.8813, 0.882, 0.8822, 0.8826, 0.8827, 0.8829,
            0.883, 0.8832, 0.8833, 0.8837, 0.884, 0.8844, 0.8847, 0.8851, 0.8854, 0.8859,
            0.886, 0.8861, 0.8864, 0.8867, 0.8866, 0.8867, 0.8869, 0.8868, 0.8873, 0.8875,
            0.8878, 0.8878, 0.888, 0.8886, 0.8889, 0.8889, 0.8893, 0.8892, 0.8893, 0.8895,
            0.8898, 0.8898, 0.8902, 0.8905, 0.8906, 0.8909, 0.891, 0.8911, 0.8915, 0.8919,
        ],
        [
            156.3404308338504, 77.97794988004262, 62.07197208081414, 53.564440556379886,
            47.89076072752767, 43.777230782797396, 40.527691768893455, 37.891375251737735,
            35.723210855799685, 33.89456878440748, 32.33287221882116, 30.96881866316372,
            29.769614130359805, 28.700088936397346, 27.736258159438353, 26.85729343110033,
            26.049353855960124, 25.308882452495812, 24.623028038111787, 23.985506934385665,
            23.390110982412004, 22.83618847516832, 22.317530755414584, 21.83399140403265,
            21.377184016853484, 20.941926370239702, 20.529005490306805, 20.138329890813285,
            19.768331861008395, 19.419449085816883, 19.088527151497715, 18.774778640389048,
            18.474202236440114, 18.18621559047326, 17.90904511781225, 17.642978237052105,
            17.386622458485967, 17.14076435851403, 16.90320806012446, 16.67504446681901,
            16.45395621305918, 16.239228199516138, 16.030309773677565, 15.828978089903176,
            15.633057592689285, 15.443302801523942, 15.25921054714683, 15.080498230713188,
            14.90608494030209, 14.735788628860957, 14.56959181644553, 14.407930256700853,
            14.25119596617195, 14.098702671741531, 13.949977573139584, 13.80501703364188,
            13.663300216820971, 13.525686187334118, 13.392180664358063, 13.262432052795093,
            13.13561046914153, 13.011187638848073, 12.889469610130215, 12.770183174653596,
            12.653697609176037, 12.539083227944444, 12.427596504739972, 12.318804841107939,
            12.212797155744395, 12.109031429176543, 12.007619869550366, 11.907971681176226,
            11.81028036783613, 11.714363792407747, 11.620701670728756, 11.529042451650415,
            11.43924324631673, 11.351060735651625, 11.264811402944051, 11.18013407878843,
            11.096856576323436, 11.014807471951789, 10.934312334061982, 10.855211327874217,
            10.777353016497067, 10.701245734674487, 10.626441370317345, 10.552727926077832,
            10.479860703994797, 10.40832461743649, 10.337991981181796, 10.268801502009962,
            10.200494102906074, 10.133439098365487, 10.067507312132257, 10.002730520364436,
            9.938801890500969, 9.87616386626179, 9.814640741272243, 9.75412727120142,
        ],
    ),
    (
        "Momentum",
        "mediumblue",
        [
            0.7999, 0.8329, 0.8502, 0.8574, 0.8638, 0.8683, 0.8734, 0.8771, 0.8812, 0.8837,
            0.8855, 0.885, 0.8857, 0.8883, 0.8895, 0.8906, 0.892, 0.8926, 0.8937, 0.8942,
            0.8948, 0.8963, 0.897, 0.8968, 0.8974, 0.8975, 0.8971, 0.8973, 0.8983, 0.8995,
            0.9005, 0.9009, 0.9006, 0.9014, 0.9022, 0.9025, 0.9027, 0.9029, 0.9022, 0.9028,
            0.9031, 0.9036, 0.9037, 0.9035, 0.9044, 0.9045, 0.905, 0.906, 0.9059, 0.906,
            0.9062, 0.9067, 0.907, 0.9077, 0.9073, 0.9075, 0.9074, 0.9076, 0.908, 0.9081,
            0.9081, 0.9089, 0.9088, 0.9087, 0.9084, 0.909, 0.9092, 0.9094, 0.9098, 0.9099,
            0.91, 0.9102, 0.9101, 0.9105, 0.9101, 0.9105, 0.911, 0.9107, 0.9108, 0.9112,
            0.9114, 0.9114, 0.9118, 0.9118, 0.9118, 0.9119, 0.9116, 0.9119, 0.9119, 0.9117,
            0.9119, 0.9123, 0.9121, 0.9122, 0.9126, 0.912, 0.9123, 0.9123, 0.9121, 0.9123,
        ],
        [
            122.53869938911286, 39.10286923827787, 27.571173295727675, 21.765867829939733,
            18.23124220197405, 15.809492194630343, 14.040185892380094, 12.677242054653457,
            11.591455341619746, 10.70305130602964, 9.944308387921778, 9.288952902234925,
            8.720667064790437, 8.223932372944375, 7.77419203116216, 7.3767315721075075,
            7.0162154487787, 6.6881663283324695, 6.388004425443851, 6.112079687302442,
            5.8543611131571875, 5.618625043409915, 5.394382338874478, 5.191302756532406,
            4.997657703884758, 4.812677035725597, 4.640093709327104, 4.477416524806255,
            4.3285707086622915, 4.1854619531659445, 4.050021187930386, 3.921434203210238,
            3.798172537061354, 3.683662886182725, 3.5752667164749874, 3.4721277799205623,
            3.376767307901181, 3.2816660750593583, 3.192398769886103, 3.10820160527708,
            3.0252175054494037, 2.948376504859593, 2.8742872294542723, 2.8028065504448385,
            2.733111915483158, 2.6676015685694674, 2.6030697000469, 2.5415058178188965,
            2.4822919625489743, 2.424615350565343, 2.370343271476541, 2.317788157028338,
            2.268225683440432, 2.21856526567576, 2.1724614627109573, 2.12620534849964,
            2.081579303525062, 2.0381594680213695, 1.9969031479447255, 1.9562090637585359,
            1.915890250891256, 1.8780360126298175, 1.8410106761148535, 1.8050528150832117,
            1.7722319643126174, 1.738293826041444, 1.7059531454686478, 1.6745085917121396,
            1.6435659665705622, 1.6149213873932262, 1.5854107227897014, 1.558670127696046,
            1.5306107592432567, 1.5041282175596906, 1.4776029595835594, 1.451718159113183,
            1.4276460188868887, 1.404090497923498, 1.3796632240039663, 1.357351495536279,
            1.33409530608291, 1.3122586691929465, 1.2901763375007873, 1.27026858079657,
            1.2492278072655993, 1.2294119664280154, 1.209661445460447, 1.190730560617601,
            1.171863633566087, 1.1538520006664303, 1.1356033829033911, 1.1183012936128365,
            1.1013755315559721, 1.0847690181659675, 1.0686785208955263, 1.0522954574542314,
            1.036857606964357, 1.021280417385962, 1.0062389061594508, 0.991544976153664,
        ],
    ),
    (
        "SGD",
        "goldenrod",
        [
            0.8092, 0.8443, 0.8635, 0.8707, 0.878, 0.8823, 0.8871, 0.8919, 0.8928, 0.8952,
            0.8979, 0.8994, 0.901, 0.9027, 0.9035, 0.9059, 0.9064, 0.9071, 0.9085, 0.9089,
            0.9096, 0.911, 0.9119, 0.9125, 0.9129, 0.9136, 0.9143, 0.9157, 0.9163, 0.9168,
            0.9167, 0.917, 0.9173, 0.9175, 0.9175, 0.9177, 0.9181, 0.9182, 0.9182, 0.9188,
            0.9188, 0.919, 0.9196, 0.9201, 0.9198, 0.9201, 0.9205, 0.9204, 0.9204, 0.9202,
            0.9205, 0.9216, 0.9217, 0.9221, 0.9221, 0.9221, 0.922, 0.9222, 0.9227, 0.9224,
            0.9225, 0.9225, 0.9224, 0.9224, 0.9224, 0.9228, 0.9228, 0.9225, 0.9227, 0.9222,
            0.9222, 0.9225, 0.9229, 0.9222, 0.9223, 0.9223, 0.9219, 0.9222, 0.923, 0.9229,
            0.9233, 0.9232, 0.9231, 0.9229, 0.9231, 0.9229, 0.9233, 0.9234, 0.9228, 0.9232,
            0.9233, 0.9232, 0.9235, 0.9233, 0.9241, 0.924, 0.9235, 0.9239, 0.9237, 0.9242,
        ],
        [
            111.6402179127881, 39.144721602953176, 28.657987308657383, 23.327051205850594,
            19.929963780517674, 17.437365670004706, 15.54513384857947, 14.055855336636323,
            12.821059206172686, 11.79802847992573, 10.941114403189296, 10.198989641486142,
            9.556824671755175, 8.98924703590668, 8.489241485943847, 8.039599673275234,
            7.624349095648669, 7.2691602703847495, 6.944317341483267, 6.635929343793328,
            6.356975897920313, 6.096671785752297, 5.849362611406921, 5.623208271567161,
            5.413150434024526, 5.210557732258254, 5.027159446970861, 4.8545894105348655,
            4.688874494535888, 4.531462069037712, 4.382939410683336, 4.243666054823098,
            4.107375183583125, 3.980343741022354, 3.856516171429907, 3.7432004181104532,
            3.633533094506152, 3.5287606041246935, 3.4302107859107696, 3.3361860715265235,
            3.2480310315199974, 3.160200675929584, 3.076617693927485, 2.997923845674127,
            2.917037328922834, 2.8415934848938247, 2.7719046451141884, 2.701280033939557,
            2.6341258810802435, 2.569583085856545, 2.5077507147965057, 2.448686779799762,
            2.3879597813380773, 2.332982870296152, 2.276926957304362, 2.2247289540590174,
            2.1735664428064787, 2.1245372980448467, 2.0770756478368186, 2.028637012325451,
            1.9845724836295509, 1.9400806973183673, 1.8979583620981544, 1.85726290826594,
            1.8171065101031778, 1.7788836680910858, 1.7417370116997237, 1.7051872450985455,
            1.6707542441784247, 1.6376036961489078, 1.6048114604775652, 1.572532302570571,
            1.540653308095243, 1.510294604658901, 1.4802230408379142, 1.4513205580876565,
            1.422832514789876, 1.3946862147301884, 1.367922571107703, 1.3414382736977957,
            1.3146849633488478, 1.2900919008066667, 1.264828769796831, 1.2408320150872167,
            1.218776402820496, 1.1961258731739177, 1.1735076492566752, 1.1502482930411098,
            1.1292866539126545, 1.1086099986924831, 1.088729770362471, 1.0680055977746892,
            1.0475359552462193, 1.028750436984617, 1.009523014528036, 0.9905931057083488,
            0.972612642530658, 0.9550423991604257, 0.9381794020302665, 0.9212813107407406,
        ],
    ),
)

# Hard-coded per-epoch results for the second figure ("gdescent.png"), in the
# same (legend label, line color, accuracy values, loss values) layout.
GRADIENT_DESCENT_SERIES = (
    (
        "GD w/ regularization",
        "darkviolet",
        [
            0.1671, 0.1894, 0.3255, 0.4054, 0.464, 0.5114, 0.5363, 0.5557, 0.5737, 0.5879,
            0.5978, 0.6061, 0.6153, 0.6237, 0.6303, 0.6372, 0.6435, 0.6489, 0.654, 0.659,
            0.6632, 0.6669, 0.6704, 0.6747, 0.6784, 0.6811, 0.6832, 0.6861, 0.6881, 0.6901,
            0.6925, 0.6955, 0.6981, 0.6988, 0.7008, 0.7026, 0.7051, 0.7064, 0.7083, 0.7099,
            0.7115, 0.7135, 0.7151, 0.7162, 0.7182, 0.7193, 0.7208, 0.723, 0.7236, 0.7249,
            0.7257, 0.7268, 0.7272, 0.7279, 0.7299, 0.7309, 0.7318, 0.7329, 0.7341, 0.7347,
            0.7361, 0.7366, 0.7379, 0.7391, 0.74, 0.7409, 0.742, 0.7432, 0.7439, 0.7449,
            0.7461, 0.7475, 0.7483, 0.7493, 0.7506, 0.7511, 0.7531, 0.7545, 0.7563, 0.757,
            0.7594, 0.7615, 0.7632, 0.7651, 0.7669, 0.7695, 0.7716, 0.7735, 0.7762, 0.7779,
            0.7795, 0.781, 0.7823, 0.7838, 0.7856, 0.7873, 0.7884, 0.7896, 0.7896, 0.7915,
        ],
        [
            1877.0367720186434, 1418.842929312803, 1061.5871486681501, 458.37080198057004,
            192.84111755852675, 127.6644843253273, 100.06248743068268, 84.67107951446866,
            73.71034411509697, 65.01070755163548, 57.84144578385049, 51.82709064909877,
            46.66638528890774, 42.18997386297939, 38.27901768403263, 34.83017961257465,
            31.773781123149497, 29.049288938129965, 26.612201350111818, 24.42488009709727,
            22.454910808572677, 20.67418836774094, 19.059509602111856, 17.591277083029155,
            16.25309208556401, 15.029909527036711, 13.910725162635016, 12.885441571806764,
            11.94499319522279, 11.081260012346675, 10.28706254724624, 9.556264520074498,
            8.88297307206089, 8.261841153812398, 7.688319648845783, 7.1583304303886734,
            6.668181684961948, 6.2146527258697875, 5.794779059338173, 5.4056988010164755,
            5.044999082019963, 4.710526264299936, 4.40025290315759, 4.112263139329426,
            3.844889170333134, 3.5966874263705155, 3.3662102675275944, 3.152170616234196,
            2.953384097464638, 2.7687416239026788, 2.5972302229940163, 2.437949981481326,
            2.2900498661367448, 2.1527595250126303, 2.0253766675797515, 1.9072437834686178,
            1.7977630476632689, 1.696376397735726, 1.6026007162756064, 1.5159603290237769,
            1.436024567327941, 1.3624097345751334, 1.2947614122472677, 1.23274801608114,
            1.1760719771504327, 1.1244590355531912, 1.077657044302921, 1.0354340413224201,
            0.9975804717994191, 0.9639018681876692, 0.9342137942405851, 0.9083462631988327,
            0.8861386806666476, 0.8674348781342986, 0.8520848374385952, 0.8399515554060907,
            0.8308921027494354, 0.8247751208417256, 0.8214656075392349, 0.8208204848665908,
            0.8227028448614414, 0.8269616576732391, 0.8334607852973361, 0.8420654612659392,
            0.8526203036072546, 0.8650120478207002, 0.8790772960283533, 0.8946730222655197,
            0.9116945456905604, 0.9299994973292277, 0.9494539649577037, 0.9699483076695159,
            0.9913721557674188, 1.013621378470521, 1.036593085082425, 1.0601828574175964,
            1.0842973720160294, 1.1088648720150422, 1.1338159956139073, 1.1590854893990317,
        ],
    ),
    (
        "GD w/o regularization",
        "indigo",
        [
            0.1627, 0.2185, 0.2371, 0.2723, 0.3382, 0.4267, 0.4982, 0.5079, 0.5255, 0.5342,
            0.5441, 0.5497, 0.5575, 0.5606, 0.5631, 0.5671, 0.5709, 0.5757, 0.5763, 0.5775,
            0.5779, 0.5781, 0.5787, 0.5792, 0.5794, 0.5775, 0.5756, 0.5737, 0.5699, 0.5672,
            0.562, 0.5605, 0.5551, 0.549, 0.5457, 0.5395, 0.5298, 0.5233, 0.5161, 0.5108,
            0.5056, 0.4979, 0.4935, 0.487, 0.4779, 0.4717, 0.464, 0.458, 0.4502, 0.4434,
            0.4362, 0.4301, 0.422, 0.4156, 0.4111, 0.4045, 0.3987, 0.3914, 0.3853, 0.3797,
            0.3738, 0.3693, 0.3648, 0.3612, 0.3548, 0.3505, 0.3475, 0.3433, 0.3378, 0.3327,
            0.3292, 0.3271, 0.3232, 0.3195, 0.3162, 0.3134, 0.3147, 0.3111, 0.3078, 0.3053,
            0.3016, 0.2995, 0.2972, 0.2947, 0.2919, 0.2908, 0.2897, 0.2862, 0.2852, 0.2825,
            0.2805, 0.28, 0.2786, 0.2779, 0.2768, 0.2757, 0.2749, 0.2738, 0.2721, 0.2704,
        ],
        [
            1654.6651816185495, 1180.3038635635266, 849.87095725426, 489.6313456670441,
            280.5961094762227, 261.06906186821084, 148.96093607860485, 115.919106267023,
            100.51183606391425, 92.25675094867951, 86.0492141819225, 80.89568305327379,
            76.44641902268548, 72.52289061382882, 69.02417886807777, 65.86766912267656,
            63.023277608844225, 60.43643730158372, 58.0686337438384, 55.868369315271025,
            53.82464804426844, 51.93595032446823, 50.14913771873821, 48.44972960843327,
            46.84839806998482, 45.327695867949686, 43.86216800563938, 42.45785645418152,
            41.0939454957971, 39.76842774188204, 38.48436992380297, 37.2303019984631,
            35.99683482577919, 34.777604955904614, 33.55330992428788, 32.343740468535714,
            31.15055618899411, 29.96980576660437, 28.80737207844531, 27.666638355071456,
            26.54886889095146, 25.457865578244967, 24.394416651439858, 23.38090185915531,
            22.41380450497015, 21.485689561668575, 20.604241215122094, 19.761110005260498,
            18.954966235288328, 18.195467685530833, 17.472835991018897, 16.795839324145934,
            16.14127446546818, 15.516506662419285, 14.927813341812742, 14.37888299809189,
            13.852775404647554, 13.355400176800996, 12.882532657805813, 12.437687762911,
            12.024358116335728, 11.632340108584485, 11.259728183107411, 10.90702601476453,
            10.571957905048967, 10.255884460687659, 9.956290105624566, 9.6714409613439,
            9.405057188963509, 9.152105687750876, 8.911972405450157, 8.68438841241,
            8.471179685463834, 8.268995119604442, 8.074214451152718, 7.887491568595448,
            7.708578313514506, 7.5385502438415335, 7.375582256992636, 7.2201741169175255,
            7.071642539049913, 6.930995682449079, 6.797058804539569, 6.669816364682467,
            6.549325601201054, 6.433948679673967, 6.322830568323524, 6.213997713415428,
            6.108894102127925, 6.0070580295424705, 5.908741686001253, 5.814616311153365,
            5.724015790400395, 5.63620890820183, 5.553247900579436, 5.4741972163166635,
            5.398129648643919, 5.325151390954768, 5.254021715469648, 5.185071589712148,
        ],
    ),
)


def epoch_axis(accuracy, loss):
    """Return the x values ``[0, 1, ..., n - 1]`` shared by one series.

    Args:
        accuracy: Sequence of per-epoch accuracy values.
        loss: Sequence of per-epoch loss values.

    Returns:
        A list of consecutive integers starting at 0, one per data point.

    Raises:
        ValueError: If the two sequences have different lengths, since they
            are drawn against the same epoch axis.
    """
    if len(accuracy) != len(loss):
        raise ValueError(
            f"accuracy has {len(accuracy)} points but loss has {len(loss)}"
        )
    return list(range(len(accuracy)))


def render_figure(series, title, filename):
    """Draw one accuracy/loss figure and write it to *filename*.

    Each series contributes a solid accuracy curve on the left axis and a
    dashed loss curve in the same color on a twin right axis. Only the
    accuracy curves carry a label, so the legend has one entry per series.

    Args:
        series: Iterable of ``(label, color, accuracy, loss)`` tuples.
        title: Title placed above the plot.
        filename: Path of the image file to write.

    Raises:
        ValueError: If a series' accuracy and loss lengths differ.
    """
    fig, accuracy_ax = plt.subplots(figsize=(9.5, 6.5), dpi=300)
    loss_ax = accuracy_ax.twinx()

    for label, color, accuracy, loss in series:
        epochs = epoch_axis(accuracy, loss)
        accuracy_ax.plot(epochs, accuracy, color=color, linestyle="-", label=label)
        loss_ax.plot(epochs, loss, color=color, linestyle="--")

    accuracy_ax.set_title(title)
    accuracy_ax.set_xlabel("Epoch")
    accuracy_ax.set_ylabel("Test Set Accuracy")
    loss_ax.set_ylabel("Softmax Loss")
    accuracy_ax.legend(loc="center right")

    # Save through the figure object rather than pyplot's implicit "current
    # figure", then close it so the 300-dpi canvas does not stay in memory.
    fig.savefig(filename, bbox_inches="tight")
    plt.close(fig)


def main():
    """Render both figures: ``results.png`` and ``gdescent.png``."""
    render_figure(
        OPTIMIZER_SERIES,
        "Test Set Accuracy and Loss versus Epoch",
        "results.png",
    )
    render_figure(
        GRADIENT_DESCENT_SERIES,
        "Gradient Descent and Regularization",
        "gdescent.png",
    )


if __name__ == "__main__":
    main()