forked from chenzomi12/AISystem
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path06.srt
1442 lines (1090 loc) · 28 KB
/
06.srt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1
00:00:00,900 --> 00:00:04,425
字幕组:赵含霖 谢鑫鑫
2
00:00:05,075 --> 00:00:07,520
Hello 大家好,我是ZOMI酱
3
00:00:07,520 --> 00:00:09,938
又来到了没什么观看
4
00:00:09,938 --> 00:00:16,080
但是我依然在坚持的一节自动微分的系列的课堂当中
5
00:00:16,080 --> 00:00:21,760
这一节主要是讲反向操作符重载去实现自动微分
6
00:00:21,960 --> 00:00:26,915
那这个自动微分的方式更类似于PyTorch这个AI框架
7
00:00:26,915 --> 00:00:30,760
就是使用反向操作符重载的自动微分
8
00:00:30,760 --> 00:00:34,240
那一起来回顾一下什么叫做操作符重载
9
00:00:34,240 --> 00:00:37,304
下面这个操作符重载的这句话
10
00:00:37,304 --> 00:00:40,706
其实是我在Wiki或者百度上面粘过来的
11
00:00:40,706 --> 00:00:43,160
具体在哪粘我已经忘了
12
00:00:43,160 --> 00:00:43,898
简单的来说
13
00:00:43,898 --> 00:00:46,793
其实它只是利用了语言的多态性
14
00:00:46,793 --> 00:00:50,040
然后进行了一个重载
15
00:00:50,040 --> 00:00:53,420
下面这一段反倒是没什么用
16
00:00:53,420 --> 00:00:55,592
但是依旧在那放着的一句话
17
00:00:55,592 --> 00:00:58,570
就是讲操作符自动重载的微分的方式
18
00:00:58,570 --> 00:01:01,160
的一些过去的AI框架
19
00:01:01,160 --> 00:01:05,520
那最典型的一个代表就是经常用到的Pytorch
20
00:01:05,520 --> 00:01:09,281
其中最重要的就是使用数据结构Tap
21
00:01:09,281 --> 00:01:11,430
来记录整个计算流程
22
00:01:11,430 --> 00:01:13,960
也就是理解的计算图
23
00:01:13,960 --> 00:01:17,680
但是在Pytorch里面,它没有一个现实的计算图
24
00:01:18,000 --> 00:01:21,188
然后在反向求解梯度的时候去replay
25
00:01:21,538 --> 00:01:23,106
去重放我的操作
26
00:01:23,106 --> 00:01:24,456
这么一种方式
27
00:01:24,456 --> 00:01:29,055
现在来简单的去回顾一下操作符重载的基本流程
28
00:01:29,055 --> 00:01:34,025
首先就是需要用语言的多态性对操作符进行重载
29
00:01:34,025 --> 00:01:36,437
定一个特殊的数据结构
30
00:01:36,437 --> 00:01:40,120
并且对每个计算进行重载的操作
31
00:01:40,795 --> 00:01:43,164
第二个就是有一个Tap的一个数据结构
32
00:01:43,164 --> 00:01:48,080
对数据输出和计算进行记录
33
00:01:48,080 --> 00:01:51,218
接着记录了每一次操作之后
34
00:01:51,218 --> 00:01:53,701
需要对每次操作进行遍历
35
00:01:53,701 --> 00:01:56,440
然后计算它的微分方式
36
00:01:56,440 --> 00:01:59,128
最后就是使用链式求导法则
37
00:01:59,128 --> 00:02:05,720
把刚才遍历得到的微分的结果进行累积
38
00:02:05,720 --> 00:02:09,400
这个就完成了整个操作符重载的流程了
39
00:02:10,120 --> 00:02:13,100
操作符重载其实已经多次讲到了
40
00:02:13,100 --> 00:02:17,800
它的优点就是实现起来只需要语言去提供多态的性能
41
00:02:17,800 --> 00:02:19,952
第二个就是它的应用性非常高
42
00:02:19,952 --> 00:02:24,189
操作符重载之后跟原生语言的编程方式是类似的
43
00:02:24,680 --> 00:02:25,872
所以大家都会说
44
00:02:25,872 --> 00:02:29,882
极度的模仿操作符重载的PyTorch的方式
45
00:02:29,882 --> 00:02:33,776
非常方便去理解和使用
46
00:02:33,776 --> 00:02:35,451
跟理解python代码的方式一样
47
00:02:35,451 --> 00:02:37,080
这个就是PyTorch的优点
48
00:02:37,400 --> 00:02:39,374
它的缺点也是非常明显
49
00:02:39,374 --> 00:02:43,474
上面用了一个Tape去记录大量的操作
50
00:02:44,200 --> 00:02:48,589
这个时候就需要对特殊的数据结构进行大量的读和写了
51
00:02:48,589 --> 00:02:50,044
遍历等操作了
52
00:02:50,044 --> 00:02:52,840
非常不利于高阶的微分实现
53
00:02:52,840 --> 00:02:54,790
高阶微分可能会在
54
00:02:54,790 --> 00:02:59,700
动力学、生物分子建模、物理方程模拟等
55
00:02:59,700 --> 00:03:02,920
非常常见的一些科学计算场景经常用到
56
00:03:02,920 --> 00:03:07,120
这个时候这种自动微分的方式非常不利于求解
57
00:03:07,440 --> 00:03:10,808
第二个就是类似于While,If else这些控制表达
58
00:03:10,808 --> 00:03:14,133
其实很难通过操作符去重载的
59
00:03:15,080 --> 00:03:17,480
下面来看看反向模式
60
00:03:18,800 --> 00:03:22,360
反向模式一般来说是比较简单好理解的
61
00:03:22,600 --> 00:03:25,166
又回到熟悉的图里面
62
00:03:25,166 --> 00:03:28,920
正向模式,假设我现在有一个X的输入
63
00:03:28,920 --> 00:03:33,052
然后我正向的就是每一次去计算每一个节点
64
00:03:33,052 --> 00:03:36,000
然后去计算中间变量的导数
65
00:03:36,000 --> 00:03:37,581
最后一个个计算
66
00:03:37,581 --> 00:03:41,796
然后得到最终输出的f(x1,x2)这个输出
67
00:03:41,796 --> 00:03:43,653
对于X的导数
68
00:03:43,653 --> 00:03:46,120
这个就是每次正向计算的
69
00:03:46,120 --> 00:03:49,014
那反向计算就是我从最后一个
70
00:03:49,014 --> 00:03:52,480
每个中间变量关于最初的一个导数
71
00:03:52,480 --> 00:03:55,085
那从反向开始就是从后面开始
72
00:03:55,085 --> 00:03:59,160
计算每一条路径关于逆向输入的一个导数
73
00:03:59,160 --> 00:04:06,160
最后我就求得了δf关于x2和δf关于x1的所有的导数形式
74
00:04:06,160 --> 00:04:08,293
那在机器学习里面
75
00:04:08,293 --> 00:04:11,143
因为我的输入神经元非常的大量
76
00:04:11,143 --> 00:04:13,600
而我的输出类别有限
77
00:04:13,600 --> 00:04:14,980
在机器学习里面
78
00:04:14,980 --> 00:04:20,280
所以一般都会用到反向模式的自动微分的方式去实现
79
00:04:20,280 --> 00:04:25,720
那这个也是反向传播的一个最原始的idea或者数学原理
80
00:04:25,960 --> 00:04:27,673
后面了解了这一点
81
00:04:27,673 --> 00:04:30,857
看反向传播这个算法可能会更有感觉
82
00:04:31,960 --> 00:04:35,273
下面我想通过简单的几分钟的了解
83
00:04:35,273 --> 00:04:38,728
去跟着大家一起去回顾或者学习一下
84
00:04:38,728 --> 00:04:41,960
Pytorch的AutoDiff是怎么去实现的
85
00:04:41,960 --> 00:04:43,729
这里面的所有的操作方式
86
00:04:43,729 --> 00:04:49,809
都是根据Pytorch的最核心的框架的一个原始理念
87
00:04:49,809 --> 00:04:51,809
然后去复现的
88
00:04:51,809 --> 00:04:58,273
首先需要from typing Import List
NameTuple,Callable,Dict,Operational
89
00:04:58,273 --> 00:05:02,200
这一些简单的操作方便下面去一个加载的
90
00:05:02,200 --> 00:05:04,426
那这个fresh_name有什么用呢?
91
00:05:04,426 --> 00:05:04,431
fresh_name这个函数是用来打印跟Tape相关的变量
92
00:05:04,431 --> 00:05:08,520
fresh_name这个函数是用来打印跟Tape相关的变量
93
00:05:08,520 --> 00:05:12,153
这是我这个f'v{name}
94
00:05:12,153 --> 00:05:15,480
这个Name就是记录下面的每一条Tape
95
00:05:15,480 --> 00:05:19,736
假设我x1等于V-1,x2等于V0
96
00:05:19,736 --> 00:05:22,073
V-1又通过一个计算
97
00:05:22,073 --> 00:05:27,000
每一行每一次计算都有一个Tape去记录的
98
00:05:27,000 --> 00:05:29,629
所以我这里面通过fresh_name
99
00:05:29,629 --> 00:05:33,160
去记录我每一次Tape到底是第几个
100
00:05:33,160 --> 00:05:36,277
然后_Name第一个就是1
101
00:05:36,277 --> 00:05:39,216
从1开始不断的去累积
102
00:05:39,216 --> 00:05:42,840
然后返回V等于多少个
103
00:05:42,840 --> 00:05:45,315
为了更加好的理解Pytorch里面的
104
00:05:45,315 --> 00:05:47,413
反向模式自动微分的实现
105
00:05:47,413 --> 00:05:49,481
实现的代码过程当中
106
00:05:49,481 --> 00:05:53,080
完全不依赖于Pytorch的AutoGrid的方式
107
00:05:53,080 --> 00:05:56,912
反倒是引入了一个新的类
108
00:05:56,912 --> 00:05:58,912
这个类叫做Variable
109
00:05:58,912 --> 00:06:01,720
也就是类似于Pytorch里面的Tensor
110
00:06:01,720 --> 00:06:03,337
在计算的时候
111
00:06:03,337 --> 00:06:05,976
实际上是从最后的损失函数
112
00:06:05,976 --> 00:06:09,080
或者l来去进行一个计算的
113
00:06:09,080 --> 00:06:12,030
程序当中每算一个张量X的值
114
00:06:12,030 --> 00:06:13,334
就是它的梯度的时候
115
00:06:13,334 --> 00:06:16,680
都会去计算dl到dx的一个导数
116
00:06:16,680 --> 00:06:21,028
然后反向模式就是从dl对dl自身的导数开始
117
00:06:21,028 --> 00:06:23,455
也就是dl对dl的导数等于1
118
00:06:23,960 --> 00:06:26,079
回头看看上面这条公式
119
00:06:26,079 --> 00:06:30,280
V5就是我的y对y对自身的导数是1
120
00:06:30,280 --> 00:06:31,833
从这个讲式开始
121
00:06:31,833 --> 00:06:35,342
然后使用偏导数和链式法则进行传播
122
00:06:35,342 --> 00:06:37,450
也就是下面这条公式
123
00:06:37,450 --> 00:06:39,320
然后一步步的去算的
124
00:06:39,320 --> 00:06:42,878
下面代码实现可能还是比较简单
125
00:06:42,878 --> 00:06:46,440
我的Variable可以理解为简单的张量
126
00:06:46,440 --> 00:06:50,242
对于张量,一开始会初始化一个值叫做Value
127
00:06:50,242 --> 00:06:53,800
通过这个值变成张量的成员变量
128
00:06:53,800 --> 00:06:56,908
然后self.name就是刚才的中间变量
129
00:06:56,908 --> 00:06:59,529
如果一开始没有输入name
130
00:06:59,529 --> 00:07:01,617
它可能就直接使用fresh_name()
131
00:07:01,617 --> 00:07:04,762
就是刚才上面的一个函数
132
00:07:04,762 --> 00:07:07,490
fresh_name(),然后不断的去累加1
133
00:07:08,760 --> 00:07:11,586
接着下面这几个就比较有意思了
134
00:07:11,596 --> 00:07:13,315
Constant其实是比较方便
135
00:07:13,315 --> 00:07:16,365
去打印查看的一个过程
136
00:07:16,365 --> 00:07:18,493
会通过Constant然后上下文
137
00:07:18,493 --> 00:07:21,216
去把当前的一个值打印出来
138
00:07:21,216 --> 00:07:23,160
还有当前的Name打印出来
139
00:07:23,160 --> 00:07:26,917
下面这几个就是回到一开始去实现
140
00:07:26,917 --> 00:07:30,376
或者上两节分享内容里面的一个实现
141
00:07:30,376 --> 00:07:32,356
只有一条简单的公式
142
00:07:32,356 --> 00:07:32,376
这里面有5个操作
143
00:07:32,376 --> 00:07:33,856
这里面有5个操作
144
00:07:33,856 --> 00:07:38,062
第一个就是*、+、-、sin和log
145
00:07:38,062 --> 00:07:42,646
一开始并没有去实现这几个函数
146
00:07:42,646 --> 00:07:47,480
而是返回了ops_mul、ops_add、ops_sub
147
00:07:47,480 --> 00:07:49,550
在反向自动微分的时候
148
00:07:49,550 --> 00:07:51,778
其实最核心的就是一个Tape
149
00:07:51,778 --> 00:07:55,798
用来跟踪Variable的所有的计算
150
00:07:55,798 --> 00:07:58,680
以便于后面用链式求导法则的
151
00:07:58,680 --> 00:08:01,426
这里面就出现了一个Tape的类
152
00:08:01,426 --> 00:08:03,426
Tape的类的数就是NameTuple
153
00:08:03,426 --> 00:08:05,400
它是一个String
154
00:08:05,400 --> 00:08:08,914
我的输入或者我的记录的类有两个
155
00:08:08,914 --> 00:08:10,914
第一个是Input,第二个是Output
156
00:08:10,914 --> 00:08:14,840
那Propagation就是应用链式求导法则的
157
00:08:14,840 --> 00:08:17,225
告诉我的输入是什么,输出是什么
158
00:08:17,225 --> 00:08:23,285
值得注意的是这里面的输入是我的dl到doutput
159
00:08:23,285 --> 00:08:27,000
输出是dl除以dInput
160
00:08:27,000 --> 00:08:31,174
Tape把所有原始计算的累积的List列表
161
00:08:31,174 --> 00:08:35,716
就是我要把所有的计算逆向的过程记录下来
162
00:08:35,716 --> 00:08:41,880
最终通过遍历的方式求得每一次反向的自动微分的操作
163
00:08:41,880 --> 00:08:46,480
下面有另外一个函数,叫做reset_tape
164
00:08:46,480 --> 00:08:53,256
这个函数很简单,就是重新初始化整个gradient_tape
165
00:08:53,256 --> 00:08:56,310
把gradient_tape重新初始化一遍
166
00:08:57,160 --> 00:09:02,090
下面来看看具体的每个原子操作怎么去实现
167
00:09:02,090 --> 00:09:05,325
刚才在Variable或者Tensor里面
168
00:09:05,325 --> 00:09:09,785
重载mul、add、sub这些原始操作的时候
169
00:09:09,785 --> 00:09:13,235
返回的是一个ops_sub这个原子操作
170
00:09:13,235 --> 00:09:17,160
看看现在这个原子操作具体实现了哪些功能
171
00:09:17,160 --> 00:09:20,300
正向的时候的计算比较简单
172
00:09:20,300 --> 00:09:24,624
首第一个传进来的other它也是一个Variable或者一个张量
173
00:09:24,624 --> 00:09:28,189
这里面自身其实它是一个张量
174
00:09:28,189 --> 00:09:32,716
所以两个张量相乘,需要通过Variable把它们包起来
175
00:09:32,716 --> 00:09:34,625
最后返回一个X
176
00:09:34,625 --> 00:09:38,045
这个X正向计算的时候直接返回出去
177
00:09:38,045 --> 00:09:42,905
中间的这一坨就是为了在反向的时候去计算的
178
00:09:42,905 --> 00:09:46,710
反向的时候先不要去看反向的计算
179
00:09:46,710 --> 00:09:49,841
而是去看一下Tape具体做了哪些工作
180
00:09:49,841 --> 00:09:51,886
这个就是Tape
181
00:09:51,886 --> 00:09:58,272
Tape就是记录输入输出还有反向操作的一个闭包函数
182
00:09:58,272 --> 00:10:01,664
Tape刚才也是重新声明了
183
00:10:01,664 --> 00:10:05,990
只是记录输入输出还有对应的操作
184
00:10:05,990 --> 00:10:08,276
对应的反向操作就是这个
185
00:10:08,276 --> 00:10:12,880
然后通过Gradient的Tape把当前的Tape Append进去
186
00:10:12,880 --> 00:10:17,160
就通过一个列表List来记录我所有的操作
187
00:10:17,160 --> 00:10:19,741
然后就会去遍历这个Gradient的Tape
188
00:10:19,741 --> 00:10:23,578
去把每一次的操作逆向的求出来
189
00:10:23,578 --> 00:10:26,772
就把所有的正向的计算操作求一遍
190
00:10:26,772 --> 00:10:28,948
把反向的计算操作求一遍
191
00:10:28,948 --> 00:10:33,354
就求得了最终的dl对dx1、dx2、dx3
192
00:10:33,354 --> 00:10:38,266
反向的微分的时候有一个函数叫做propagate
193
00:10:38,266 --> 00:10:41,374
它的输入是dl对doutput的一个值
194
00:10:41,374 --> 00:10:44,828
这个反向就是我的损失函数对输出的一个导数
195
00:10:44,828 --> 00:10:47,347
那dl对dx就是我当前的一个值
196
00:10:47,347 --> 00:10:51,494
接着就是dx对dself的一个值就是other
197
00:10:51,494 --> 00:10:54,544
dx对dother的值就是我当前的数
198
00:10:54,544 --> 00:10:59,282
乘法里面可以看到根据乘数的求导法则
199
00:10:59,282 --> 00:11:00,288
就是这两个
200
00:11:00,288 --> 00:11:04,624
然后再求dl对dself还有dl对dother的一个值
201
00:11:04,624 --> 00:11:07,669
最后就把我的输出扔出来
202
00:11:07,669 --> 00:11:09,085
因为这里面有两个输出
203
00:11:09,085 --> 00:11:12,988
所以会把两个输出都同时返回出去
204
00:11:12,988 --> 00:11:17,290
那同样的,add操作也是相同的方式去处理
205
00:11:17,290 --> 00:11:19,160
我的sub操作也是相同的
206
00:11:19,160 --> 00:11:23,722
那加减乘除里面可能会简单一点的,就是加和减
207
00:11:23,722 --> 00:11:25,753
加和减无论你怎么算
208
00:11:25,753 --> 00:11:30,410
它里面就是对自身的数进行求导等于1
209
00:11:30,410 --> 00:11:32,748
如果你对另外一个数进行求导
210
00:11:32,748 --> 00:11:35,362
那就是减号,保留减号就等于-1
211
00:11:35,362 --> 00:11:39,581
sin还有log这两个也是比较简单的
212
00:11:39,581 --> 00:11:42,772
log就是1除以self.value就可以了
213
00:11:42,772 --> 00:11:45,460
然后如果你对另外一个数进行求导
214
00:11:45,460 --> 00:11:48,056
就是对dx乘以dself
215
00:11:51,160 --> 00:11:54,392
在Pytorch,TensorFlow或者MindSpore里面
216
00:11:54,392 --> 00:11:57,828
如果不显式的去设置self.autogrid
217
00:11:57,828 --> 00:12:00,104
或者实现一个自动微分的时候
218
00:12:00,104 --> 00:12:02,104
其实只是做了一个正向的计算
219
00:12:02,160 --> 00:12:05,158
在实际上需要反向计算的时候
220
00:12:05,158 --> 00:12:08,160
就需要去声明我这个函数需要进行反向
221
00:12:08,160 --> 00:12:11,160
那这里面的反向模式也是一样的
222
00:12:11,160 --> 00:12:14,242
首先通过一个函数grad
223
00:12:14,242 --> 00:12:17,800
然后去声明我需要进行一个反向梯度的求解
224
00:12:17,800 --> 00:12:21,840
那输入有两个,第一个是l,第二个是results
225
00:12:21,840 --> 00:12:25,022
输出results它是一个x
226
00:12:25,022 --> 00:12:26,174
代表它是一个List
227
00:12:26,174 --> 00:12:32,217
里面就对应于需要求导的所有的函数
228
00:12:32,217 --> 00:12:34,713
l就是我最后的V0
229
00:12:34,713 --> 00:12:37,009
可以看到这里面的公式
230
00:12:37,009 --> 00:12:42,001
对应的是这个results,从下往上求
231
00:12:42,001 --> 00:12:44,305
我的l就是V5
232
00:12:44,305 --> 00:12:50,321
我的results就是x1和x2,对应的V-1和V0
233
00:12:52,160 --> 00:12:54,855
回到最核心的这个里面
234
00:12:54,855 --> 00:12:58,294
首先创建一个字典dl_d
235
00:12:58,294 --> 00:13:00,954
dl_d它是一个字典
236
00:13:00,954 --> 00:13:03,402
里面就记录了每个dl对dx
237
00:13:03,402 --> 00:13:07,160
或者d中间的变量的所有的名字和数值
238
00:13:07,160 --> 00:13:11,196
然后最后一个variable等于1
239
00:13:11,196 --> 00:13:13,577
所以把最后一个l的name拿出来
240
00:13:13,577 --> 00:13:15,566
然后丢给它作为1
241
00:13:15,566 --> 00:13:19,406
可以看到这里面最后一个假设是1
242
00:13:19,406 --> 00:13:21,093
这个是所有的前提
243
00:13:22,821 --> 00:13:24,869
gather_grad这个内联函数呢
244
00:13:24,869 --> 00:13:27,058
主要是去把所有的entry
245
00:13:27,058 --> 00:13:30,946
就是我的grad里面的所有的Tape的数值都记录下来
246
00:13:30,946 --> 00:13:32,794
丢给我的dl_d
247
00:13:32,794 --> 00:13:37,978
也就是把所有的数值或者我的计算的过程放在我的dl_d里面
248
00:13:37,978 --> 00:13:40,970
为的就是方便我进行打印的时候操作
249
00:13:41,160 --> 00:13:44,487
这个时候可以看到dl_d
250
00:13:44,487 --> 00:13:48,842
主要是去记录所有的dl,d0
251
00:13:48,842 --> 00:13:51,160
这个具体的计算公式
252
00:13:51,160 --> 00:13:53,727