forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Activation.cpp
831 lines (738 loc) · 27.3 KB
/
Activation.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/Activation.h>
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorIterator.h>
#include <ATen/TensorOperators.h>
#include <ATen/OpMathType.h>
#include <ATen/Parallel.h>
#include <ATen/ScalarOps.h>
#include <ATen/native/Resize.h>
#if defined(C10_MOBILE) && defined(USE_XNNPACK)
#include <ATen/native/xnnpack/Engine.h>
#endif
#include <ATen/core/DistributionsHelper.h>
#include <c10/util/irange.h>
#include <c10/core/ScalarType.h>
#if AT_MKLDNN_ENABLED()
#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ATen/native/mkldnn/Utils.h>
#endif
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/celu_native.h>
#include <ATen/ops/clamp.h>
#include <ATen/ops/clamp_min.h>
#include <ATen/ops/elu.h>
#include <ATen/ops/elu_backward_native.h>
#include <ATen/ops/elu_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/gelu_backward_native.h>
#include <ATen/ops/gelu_native.h>
#include <ATen/ops/hardshrink_backward_native.h>
#include <ATen/ops/hardshrink_native.h>
#include <ATen/ops/hardsigmoid_backward_native.h>
#include <ATen/ops/hardsigmoid_native.h>
#include <ATen/ops/hardswish_backward_native.h>
#include <ATen/ops/hardswish_native.h>
#include <ATen/ops/hardtanh.h>
#include <ATen/ops/hardtanh_backward_native.h>
#include <ATen/ops/hardtanh_native.h>
#include <ATen/ops/infinitely_differentiable_gelu_backward_native.h>
#include <ATen/ops/leaky_relu.h>
#include <ATen/ops/leaky_relu_backward.h>
#include <ATen/ops/leaky_relu_backward_native.h>
#include <ATen/ops/leaky_relu_native.h>
#include <ATen/ops/log_sigmoid_backward_native.h>
#include <ATen/ops/log_sigmoid_forward.h>
#include <ATen/ops/log_sigmoid_forward_native.h>
#include <ATen/ops/log_sigmoid_native.h>
#include <ATen/ops/mish_backward_native.h>
#include <ATen/ops/mish_native.h>
#include <ATen/ops/prelu_native.h>
#include <ATen/ops/_prelu_kernel.h>
#include <ATen/ops/_prelu_kernel_native.h>
#include <ATen/ops/_prelu_kernel_backward_native.h>
#include <ATen/ops/relu6_native.h>
#include <ATen/ops/relu_native.h>
#include <ATen/ops/rrelu_native.h>
#include <ATen/ops/rrelu_with_noise.h>
#include <ATen/ops/rrelu_with_noise_backward_native.h>
#include <ATen/ops/rrelu_with_noise_native.h>
#include <ATen/ops/selu_native.h>
#include <ATen/ops/sigmoid.h>
#include <ATen/ops/silu_backward_native.h>
#include <ATen/ops/silu_native.h>
#include <ATen/ops/softplus.h>
#include <ATen/ops/softplus_backward_native.h>
#include <ATen/ops/softplus_native.h>
#include <ATen/ops/softshrink_backward_native.h>
#include <ATen/ops/softshrink_native.h>
#include <ATen/ops/tanh.h>
#include <ATen/ops/threshold_backward_native.h>
#include <ATen/ops/threshold_native.h>
#include <ATen/ops/zeros_like.h>
#include <utility>
#endif
namespace at::meta {
// Both threshold meta functions configure an iterator that evaluates
//   out = (x <= threshold) ? value : other
// threshold():          x = self, other = self
// threshold_backward(): x = self, other = grad (value is chosen by the impl)
// Memory-overlap checking is deliberately disabled: thresholding is
// idempotent, so an output that aliases an input is harmless.
TORCH_META_FUNC(threshold)(const Tensor& self, const Scalar& threshold, const Scalar& value) {
  const Tensor& result = maybe_get_output();
  TensorIteratorConfig config;
  config.set_check_mem_overlap(false)
      .add_output(result)
      .add_input(self)  // compared against `threshold`
      .add_input(self)  // `other`: passed through where x > threshold
      .allow_cpu_scalars(true)
      .promote_inputs_to_common_dtype(true)
      .cast_common_dtype_to_outputs(true)
      .enforce_safe_casting_to_output(true);
  build(config);
}
TORCH_META_FUNC(threshold_backward)(const Tensor& grad, const Tensor& self, const Scalar& threshold) {
  const Tensor& gradInput = maybe_get_output();
  TensorIteratorConfig config;
  config.set_check_mem_overlap(false)
      .add_output(gradInput)
      .add_input(self)  // compared against `threshold`
      .add_input(grad)  // `other`: passed through where self > threshold
      .allow_cpu_scalars(true)
      .promote_inputs_to_common_dtype(true)
      .cast_common_dtype_to_outputs(true)
      .enforce_safe_casting_to_output(true);
  build(config);
}
// elu forward: a plain unary op; alpha/scale/input_scale are only consumed by the kernel.
TORCH_META_FUNC(elu) (
  const Tensor& self, const Scalar& alpha, const Scalar& scale, const Scalar& input_scale
) {
  const Tensor& result = maybe_get_output();
  build_unary_op(result, self);
}
// elu backward: binary op over (grad_output, self_or_result). When is_result
// is true the saved tensor is the forward *output*, which is only accepted
// for a non-negative alpha.
TORCH_META_FUNC(elu_backward) (
  const Tensor& grad_output,
  const Scalar& alpha,
  const Scalar& scale,
  const Scalar& input_scale,
  bool is_result,
  const Tensor& self_or_result
) {
  if (is_result) {
    TORCH_CHECK(
      alpha.to<double>() >= 0.0,
      "In-place elu backward calculation is triggered with a negative slope which is not supported. "
      "This is caused by calling in-place forward function with a negative slope, "
      "please call out-of-place version instead.");
  }
  const Tensor& grad_input = maybe_get_output();
  build_borrowing_binary_op(grad_input, grad_output, self_or_result);
}
// silu / mish forward are unary ops; silu backward is a binary op over (grad_output, input).
TORCH_META_FUNC(silu) (const Tensor& self) {
  const Tensor& result = maybe_get_output();
  build_unary_op(result, self);
}
TORCH_META_FUNC(silu_backward) (
  const Tensor& grad_output, const Tensor& input
) {
  const Tensor& grad_input = maybe_get_output();
  build_borrowing_binary_op(grad_input, grad_output, input);
}
TORCH_META_FUNC(mish) (const Tensor& self) {
  const Tensor& result = maybe_get_output();
  build_unary_op(result, self);
}
// softplus meta functions: beta/threshold are not needed to determine the
// output shape/dtype, so only the tensors are registered with the iterator.
TORCH_META_FUNC(softplus) (
  const Tensor& self, const Scalar& beta, const Scalar& threshold
) {
  const Tensor& result = maybe_get_output();
  build_unary_op(result, self);
}
TORCH_META_FUNC(softplus_backward) (
  const Tensor& grad_output,
  const Tensor& self,
  const Scalar& beta,
  const Scalar& threshold
) {
  const Tensor& grad_input = maybe_get_output();
  build_borrowing_binary_op(grad_input, grad_output, self);
}
// leaky_relu forward: plain unary op; the slope is consumed by the kernel only.
TORCH_META_FUNC(leaky_relu) (
  const Tensor& self, const Scalar& negval
) {
  const Tensor& result = maybe_get_output();
  build_unary_op(result, self);
}
// leaky_relu backward cannot be derived from the saved forward *result*
// (is_result == true) when the slope is negative. An in-place forward makes
// autograd keep the output rather than the input, and with a negative slope
// the sign of the original input is lost. Example: with slope -0.2 a saved
// value of 2 could come from an input of 2 or of -10, and those need
// different gradients.
TORCH_META_FUNC(leaky_relu_backward) (
  const Tensor& grad_output,
  const Tensor& self_or_result,
  const Scalar& negval,
  bool is_result
) {
  if (is_result) {
    TORCH_CHECK(
      negval.to<double>() >= 0.0,
      "In-place leakyReLu backward calculation is triggered with a negative slope which is not supported. "
      "This is caused by calling in-place forward function with a negative slope, "
      "please call out-of-place version instead. File an issue at https://github.com/pytorch/pytorch if you do "
      "require supporting in-place leakRelu backward calculation with negative slope");
  }
  // Note the operand order: (self_or_result, grad_output).
  const Tensor& grad_input = maybe_get_output();
  build_borrowing_binary_op(grad_input, self_or_result, grad_output);
}
// hardsigmoid / hardshrink meta functions. Shape and dtype come from the
// standard unary/binary builders; `lambd` is not needed here.
TORCH_META_FUNC(hardsigmoid) (const Tensor& self) {
  const Tensor& result = maybe_get_output();
  build_unary_op(result, self);
}
TORCH_META_FUNC(hardsigmoid_backward) (const Tensor& grad_output, const Tensor& self) {
  const Tensor& grad_input = maybe_get_output();
  build_borrowing_binary_op(grad_input, grad_output, self);
}
TORCH_META_FUNC(hardshrink) (const Tensor & self, const Scalar& lambd) {
  const Tensor& result = maybe_get_output();
  build_unary_op(result, self);
}
TORCH_META_FUNC(hardshrink_backward) (
  const Tensor & grad, const Tensor & self, const Scalar& lambd
) {
  const Tensor& grad_input = maybe_get_output();
  build_borrowing_binary_op(grad_input, grad, self);
}
// Rejects a negative (or NaN) lambda for softshrink.
static inline void softshrink_check(const Scalar& lambd) {
  const double lambda_value = lambd.to<double>();
  TORCH_CHECK(lambda_value >= 0, "lambda must be greater or equal to 0, but found to be ", lambda_value, ".");
}
// softshrink forward validates lambda before building the unary iterator.
TORCH_META_FUNC(softshrink) (
  const Tensor & self, const Scalar& lambd
) {
  softshrink_check(lambd);
  const Tensor& result = maybe_get_output();
  build_unary_op(result, self);
}
// softshrink backward: binary op over (grad, self); no lambda validation here.
TORCH_META_FUNC(softshrink_backward) (
  const Tensor & grad, const Tensor & self, const Scalar& lambd
) {
  const Tensor& grad_input = maybe_get_output();
  build_borrowing_binary_op(grad_input, grad, self);
}
// gelu meta functions. `approximate` does not affect shape/dtype, so it is
// only consumed by the impl functions.
TORCH_META_FUNC(gelu) (const Tensor & self, c10::string_view approximate) {
  const Tensor& result = maybe_get_output();
  build_unary_op(result, self);
}
TORCH_META_FUNC(gelu_backward) (
  const Tensor& grad, const Tensor& self, c10::string_view approximate
) {
  const Tensor& grad_input = maybe_get_output();
  build_borrowing_binary_op(grad_input, grad, self);
}
} // namespace at::meta
namespace at::native {
// Fixed SELU parameters: selu(x) is computed as elu(x) with these values as
// alpha and scale. constexpr makes them compile-time constants.
static constexpr double SELU_ALPHA = 1.6732632423543772848170429916717;
static constexpr double SELU_SCALE = 1.0507009873554804934193349852946;
// Definitions of the per-device dispatch stubs used by the activation
// functions in this file. Each stub is invoked with a device type
// (e.g. `elu_stub(device_type(), *this, ...)`) and forwards to the kernel
// registered for that device; those registrations are not in this file.
// Note: hardshrink and softshrink share `shrink_backward_stub` for backward.
DEFINE_DISPATCH(elu_stub);
DEFINE_DISPATCH(elu_backward_stub);
DEFINE_DISPATCH(softplus_stub);
DEFINE_DISPATCH(softplus_backward_stub);
DEFINE_DISPATCH(log_sigmoid_cpu_stub);
DEFINE_DISPATCH(log_sigmoid_backward_stub);
DEFINE_DISPATCH(threshold_stub);
DEFINE_DISPATCH(hardtanh_backward_stub);
DEFINE_DISPATCH(hardsigmoid_stub);
DEFINE_DISPATCH(hardsigmoid_backward_stub);
DEFINE_DISPATCH(hardswish_stub);
DEFINE_DISPATCH(hardswish_backward_stub);
DEFINE_DISPATCH(hardshrink_stub);
DEFINE_DISPATCH(softshrink_stub);
DEFINE_DISPATCH(shrink_backward_stub);
DEFINE_DISPATCH(leaky_relu_stub);
DEFINE_DISPATCH(leaky_relu_backward_stub);
DEFINE_DISPATCH(silu_stub);
DEFINE_DISPATCH(silu_backward_stub);
DEFINE_DISPATCH(mish_stub);
DEFINE_DISPATCH(mish_backward_stub);
DEFINE_DISPATCH(prelu_stub);
DEFINE_DISPATCH(prelu_backward_stub);
// elu impls: the meta function has already configured the iterator (*this);
// the impl forwards it and the scalar parameters to the device stub.
TORCH_IMPL_FUNC(elu_out) (
  const Tensor& self, const Scalar& alpha, const Scalar& scale, const Scalar& input_scale, const Tensor& result
) {
  TensorIteratorBase& iter = *this;
  elu_stub(iter.device_type(), iter, alpha, scale, input_scale);
}
TORCH_IMPL_FUNC(elu_backward_out) (
  const Tensor& grad_output,
  const Scalar& alpha,
  const Scalar& scale,
  const Scalar& input_scale,
  bool is_result,
  const Tensor& self_or_result,
  const Tensor& grad_input
) {
  TensorIteratorBase& iter = *this;
  elu_backward_stub(iter.device_type(), iter, alpha, scale, input_scale, is_result);
}
// silu / mish impls: run the stub on the iterator prepared by the meta function.
TORCH_IMPL_FUNC(silu_out) (
  const Tensor& self, const Tensor& result
) {
  TensorIteratorBase& iter = *this;
  silu_stub(iter.device_type(), iter);
}
TORCH_IMPL_FUNC(silu_backward_out) (
  const Tensor& grad_output, const Tensor& input, const Tensor& grad_input
) {
  TensorIteratorBase& iter = *this;
  silu_backward_stub(iter.device_type(), iter);
}
TORCH_IMPL_FUNC(mish_out) (
  const Tensor& self, const Tensor& result
) {
  TensorIteratorBase& iter = *this;
  mish_stub(iter.device_type(), iter);
}
// softplus impls: beta and threshold are handed to the kernel unchanged.
TORCH_IMPL_FUNC(softplus_out) (
  const Tensor& self, const Scalar& beta, const Scalar& threshold, const Tensor& result
) {
  TensorIteratorBase& iter = *this;
  softplus_stub(iter.device_type(), iter, beta, threshold);
}
TORCH_IMPL_FUNC(softplus_backward_out) (
  const Tensor& grad_output,
  const Tensor& self,
  const Scalar& beta,
  const Scalar& threshold,
  const Tensor& grad_input
) {
  TensorIteratorBase& iter = *this;
  softplus_backward_stub(iter.device_type(), iter, beta, threshold);
}
// leaky_relu impls: only the slope is forwarded. The in-place/negative-slope
// restriction is enforced in the backward meta function, not here.
TORCH_IMPL_FUNC(leaky_relu_out) (
  const Tensor& self, const Scalar& negval, const Tensor& result
) {
  TensorIteratorBase& iter = *this;
  leaky_relu_stub(iter.device_type(), iter, negval);
}
TORCH_IMPL_FUNC(leaky_relu_backward_out) (
  const Tensor& grad_output,
  const Tensor& self_or_result,
  const Scalar& negval,
  bool is_result,
  const Tensor& grad_input
) {
  TensorIteratorBase& iter = *this;
  leaky_relu_backward_stub(iter.device_type(), iter, negval);
}
// hardsigmoid impls: parameter-free stubs over the meta-configured iterator.
TORCH_IMPL_FUNC(hardsigmoid_out) (
  const Tensor& self, const Tensor& result
) {
  TensorIteratorBase& iter = *this;
  hardsigmoid_stub(iter.device_type(), iter);
}
TORCH_IMPL_FUNC(hardsigmoid_backward_out) (
  const Tensor& grad_output, const Tensor& self, const Tensor& grad_input
) {
  TensorIteratorBase& iter = *this;
  hardsigmoid_backward_stub(iter.device_type(), iter);
}
// hardshrink / softshrink impls. Forward uses a dedicated stub per op; both
// backward passes go through the shared shrink_backward_stub.
TORCH_IMPL_FUNC(hardshrink_out) (
  const Tensor & self, const Scalar& lambd, const Tensor& result
) {
  TensorIteratorBase& iter = *this;
  hardshrink_stub(iter.device_type(), iter, lambd);
}
TORCH_IMPL_FUNC(hardshrink_backward_out) (
  const Tensor & grad, const Tensor & self, const Scalar& lambd, const Tensor& grad_input
) {
  TensorIteratorBase& iter = *this;
  shrink_backward_stub(iter.device_type(), iter, lambd);
}
TORCH_IMPL_FUNC(softshrink_out) (
  const Tensor & self, const Scalar& lambd, const Tensor& result
) {
  TensorIteratorBase& iter = *this;
  softshrink_stub(iter.device_type(), iter, lambd);
}
TORCH_IMPL_FUNC(softshrink_backward_out) (
  const Tensor & grad, const Tensor & self, const Scalar& lambd, const Tensor& grad_input
) {
  TensorIteratorBase& iter = *this;
  shrink_backward_stub(iter.device_type(), iter, lambd);
}
#if AT_MKLDNN_ENABLED()
// Decides whether the gelu CPU impls may take the oneDNN (ideep) path.
// Requirements: oneDNN enabled by the user, a contiguous input with more than
// one element, and either an mkldnn tensor or a dense CPU float32 tensor
// (or bfloat16 when mkldnn_bf16_device_check() passes).
static bool use_mkldnn(const Tensor& input) {
  if (!at::globalContext().userEnabledMkldnn()) {
    return false;
  }
  if (!input.is_contiguous() || input.numel() <= 1) {
    return false;
  }
  if (input.is_mkldnn()) {
    return true;
  }
  if (!input.device().is_cpu()) {
    return false;
  }
  const auto dtype = input.scalar_type();
  return (dtype == kBFloat16 && mkldnn_bf16_device_check()) || dtype == kFloat;
}
#endif
// CPU gelu forward. When available and applicable, the exact (erf-based) form
// is computed with oneDNN; every other case runs GeluKernel.
TORCH_IMPL_FUNC(gelu_out_cpu) (
  const Tensor& self, c10::string_view approximate, const Tensor& result
) {
  const auto approximate_type = get_gelutype_enum(approximate);
#if AT_MKLDNN_ENABLED()
  // The oneDNN branch only implements the erf form (GeluType::None).
  if (use_mkldnn(self) && approximate_type == GeluType::None) {
    const ideep::tensor& x = itensor_from_tensor(self);
    ideep::tensor y = itensor_from_tensor(result);
    ideep::eltwise_forward::compute(
      x, y, ideep::algorithm::eltwise_gelu_erf, ideep::prop_kind::forward_training, /*alpha*/ 0.0);
    return;
  }
#endif
  GeluKernel(kCPU, *this, approximate_type);
}
// CPU gelu backward. Mirrors gelu_out_cpu: the oneDNN erf path when
// applicable, otherwise GeluBackwardKernel.
TORCH_IMPL_FUNC(gelu_backward_out_cpu) (
  const Tensor& grad, const Tensor& self, c10::string_view approximate, const Tensor& grad_input
) {
  const auto approximate_type = get_gelutype_enum(approximate);
#if AT_MKLDNN_ENABLED()
  if (use_mkldnn(self) && approximate_type == GeluType::None) {
    const ideep::tensor& x = itensor_from_tensor(self);
    ideep::tensor grady = itensor_from_tensor(grad);
    ideep::tensor gradx = itensor_from_tensor(grad_input);
    ideep::eltwise_backward::compute(x, grady, gradx,
      ideep::algorithm::eltwise_gelu_erf, /*alpha*/ 0.0);
    return;
  }
#endif
  GeluBackwardKernel(kCPU, *this, approximate_type);
}
// Functional hardtanh: allocate an output shaped like self and delegate to the out= variant.
Tensor hardtanh(const Tensor& self, const Scalar& min, const Scalar& max) {
  Tensor out = at::empty_like(self);
  return at::hardtanh_out(out, self, min, max);
}
// hardtanh is clamp(self, min, max). Bool inputs are rejected. For integral
// inputs the bounds are truncated to integers so they do not promote the
// result dtype (legacy behaviour), and uint8 inputs require non-negative bounds.
Tensor& hardtanh_out(const Tensor& self, const Scalar& min, const Scalar& max, Tensor& result) {
  TORCH_CHECK(self.scalar_type() != at::kBool,
              "Bool inputs not supported for hardtanh");
  if (!at::isIntegralType(self.scalar_type(), /*include_bool*/false)) {
    return at::clamp_out(result, self, min, max);
  }
  const int64_t lo = min.toLong();
  const int64_t hi = max.toLong();
  TORCH_CHECK(self.dtype() != at::kByte || (lo >= 0 && hi >= 0),
              "cannot do hardtanh on an unsigned type with negative limits");
  return at::clamp_out(result, self, Scalar(lo), Scalar(hi));
}
// In-place hardtanh: self is both the input and the destination.
Tensor& hardtanh_(Tensor& self, const Scalar& min, const Scalar& max) {
  return at::hardtanh_out(self, self, min, max);
}
// out= variant: writes the hardtanh gradient for (grad_output, self) into grad_input.
Tensor& hardtanh_backward_out(const Tensor& grad_output, const Tensor& self, const Scalar& min, const Scalar& max, Tensor& grad_input) {
  auto iter = TensorIterator::borrowing_binary_op(grad_input, grad_output, self);
  const auto device = iter.device_type();
  hardtanh_backward_stub(device, iter, min, max);
  return grad_input;
}
// Functional variant: the output is left undefined so the iterator allocates
// it, and the allocated tensor is returned via iter.output().
Tensor hardtanh_backward(const Tensor& grad_output, const Tensor& self, const Scalar& min, const Scalar& max) {
  Tensor grad_input;
  auto iter = TensorIterator::borrowing_binary_op(grad_input, grad_output, self);
  const auto device = iter.device_type();
  hardtanh_backward_stub(device, iter, min, max);
  return iter.output();
}
// hardswish entry points. On mobile builds with XNNPACK, eligible inputs are
// handed to XNNPACK; otherwise a TensorIterator runs hardswish_stub.
Tensor hardswish(const Tensor& self) {
#if defined(C10_MOBILE) && defined(USE_XNNPACK)
  if (xnnpack::use_hardswish(self)) {
    return xnnpack::hardswish(self);
  }
#endif
  // Left undefined: the iterator allocates the output, returned via iter.output().
  Tensor allocated_by_iter;
  auto iter = TensorIterator::unary_op(allocated_by_iter, self);
  hardswish_stub(iter.device_type(), iter);
  return iter.output();
}
Tensor& hardswish_out(const Tensor& self, Tensor& result) {
  auto iter = TensorIterator::unary_op(result, self);
  const auto device = iter.device_type();
  hardswish_stub(device, iter);
  return result;
}
Tensor& hardswish_(Tensor& self) {
#if defined(C10_MOBILE) && defined(USE_XNNPACK)
  if (xnnpack::use_hardswish(self)) {
    xnnpack::hardswish_(self);
    return self;
  }
#endif
  // In-place: self is both the iterator's output and its input.
  auto iter = TensorIterator::unary_op(self, self);
  const auto device = iter.device_type();
  hardswish_stub(device, iter);
  return self;
}
Tensor hardswish_backward(const Tensor& grad_output, const Tensor& self) {
  Tensor allocated_by_iter;
  auto iter = TensorIterator::borrowing_binary_op(allocated_by_iter, grad_output, self);
  hardswish_backward_stub(iter.device_type(), iter);
  return iter.output();
}
Tensor relu(const Tensor & self) {
TORCH_CHECK(self.scalar_type() != at::kBool, "Boolean inputs not supported for relu");
return at::clamp_min(self, 0);
}
Tensor & relu_(Tensor & self) {
TORCH_CHECK(self.scalar_type() != at::kBool, "Boolean inputs not supported for relu");
return at::clamp_min_(self, 0);
}
Tensor selu(const Tensor & self) {
return at::elu(self, SELU_ALPHA, SELU_SCALE);
}
Tensor relu6(const Tensor & self) {
return at::hardtanh(self, /*min_val=*/0, /*max_val=*/6);
}
Tensor & selu_(Tensor & self) {
return at::elu_(self, SELU_ALPHA, SELU_SCALE);
}
Tensor & relu6_(Tensor & self) {
return at::hardtanh_(self, /*min_val=*/0, /*max_val=*/6);
}
Tensor celu(const Tensor & self, const Scalar& alpha) {
TORCH_CHECK(alpha.to<double>() != 0,
"ZeroDivisionError: alpha cannot be 0 for CELU");
double inv_alpha = 1. / alpha.to<double>();
return at::elu(self, alpha, Scalar(1.0), Scalar(inv_alpha));
}
Tensor & celu_(Tensor & self, const Scalar& alpha) {
TORCH_CHECK(alpha.to<double>() != 0,
"ZeroDivisionError: alpha cannot be 0 for CELU");
double inv_alpha = 1. / alpha.to<double>();
return at::elu_(self, alpha, Scalar(1.0), Scalar(inv_alpha));
}
// Composite silu gradient:
//   grad_output * sigmoid(x) * (1 + x * (1 - sigmoid(x)))
Tensor math_silu_backward(
    const Tensor& grad_output,
    const Tensor& input) {
  const Tensor sig = at::sigmoid(input);
  const Tensor one_minus_sig = 1 - sig;
  const Tensor inner = 1 + input * one_minus_sig;
  return grad_output * (sig * inner);
}
// Kernel-based mish gradient: binary iterator over (grad_output, input) into
// a buffer created with input's options.
Tensor mish_backward(
    const Tensor& grad_output,
    const Tensor& input) {
  Tensor grad_input = at::empty({0}, input.options());
  auto iter = TensorIterator::binary_op(grad_input, grad_output, input);
  const auto device = iter.device_type();
  mish_backward_stub(device, iter);
  return grad_input;
}
// Composite mish gradient, with t = tanh(softplus(x)):
//   grad_output * (t + x * sigmoid(x) * (1 - t * t))
Tensor math_mish_backward(
    const Tensor& grad_output,
    const Tensor& input) {
  const Tensor tsp = at::tanh(at::softplus(input));
  const Tensor sig = at::sigmoid(input);
  const Tensor one_minus_tsp_sq = 1 - tsp * tsp;
  return grad_output * (tsp + (input * sig * one_minus_tsp_sq));
}
// Training-mode RReLU on CPU.
// For every element with input <= 0 a slope r is drawn from a uniform
// distribution over [lower, upper]; output = input * r and noise = r.
// Positive elements are copied through with noise = 1.
//
// `input` is indexed linearly, so the caller must pass a contiguous tensor
// (rrelu_with_noise_out_cpu passes self.contiguous()). `output` and `noise`
// may be non-contiguous: both are filled through contiguous staging tensors
// and copied back afterwards. Previously `noise` was written in place by
// linear index, which scrambled the slopes for non-contiguous noise and
// could write out of bounds for expanded noise.
template <typename scalar_t>
inline void _rrelu_with_noise_train(
    Tensor& output,
    const Tensor& input,
    const Tensor& noise,
    const Scalar& lower_,
    const Scalar& upper_,
    c10::optional<Generator> generator) {
  using opmath_t = at::opmath_type<scalar_t>;
  opmath_t lower = lower_.to<opmath_t>();
  opmath_t upper = upper_.to<opmath_t>();
  // contiguous() returns the same tensor when it already is contiguous, so the
  // staging copies only cost anything for strided output/noise.
  Tensor tmp_tensor = output.contiguous();
  Tensor tmp_noise = noise.contiguous();
  scalar_t* output_data = tmp_tensor.data_ptr<scalar_t>();
  scalar_t* input_data = input.data_ptr<scalar_t>();
  scalar_t* noise_data = tmp_noise.data_ptr<scalar_t>();
  auto gen = at::get_generator_or_default<CPUGeneratorImpl>(generator, detail::getDefaultCPUGenerator());
  // The generator is shared state; hold its mutex while drawing samples.
  std::lock_guard<std::mutex> lock(gen->mutex_);
  for (const auto i : c10::irange(input.numel())) {
    if (input_data[i] <= 0) {
      at::uniform_real_distribution<double> uniform(lower, upper);
      const opmath_t r = (opmath_t)uniform(gen);
      output_data[i] = input_data[i] * r;
      noise_data[i] = r;
    } else {
      noise_data[i] = 1;
      output_data[i] = input_data[i];
    }
  }
  if (!output.is_contiguous()) {
    output.copy_(tmp_tensor);
  }
  if (!noise.is_contiguous()) {
    noise.copy_(tmp_noise);
  }
}
// CPU rrelu_with_noise writing into `output`.
//
// training == true : each element of self that is <= 0 is multiplied by a
//   slope drawn uniformly from [lower, upper]; the slopes (1 for positive
//   elements) are stored in `noise` for use by the backward pass.
// training == false: deterministic leaky_relu with slope (lower + upper) / 2;
//   `noise` is not touched.
Tensor& rrelu_with_noise_out_cpu(const Tensor& self,
    const Tensor& noise,
    const Scalar& lower,
    const Scalar& upper,
    bool training,
    c10::optional<Generator> generator,
    Tensor& output) {
  if (training) {
    // The training kernel writes noise[i] and output[i] for every
    // i < self.numel(). A noise (or out=) tensor with fewer elements would be
    // written past the end of its storage, so validate/resize first.
    TORCH_CHECK(self.sizes() == noise.sizes(),
        "noise tensor shape must match self tensor shape. Got self.shape = ", self.sizes(),
        " noise.shape = ", noise.sizes());
    // out= semantics: no-op when output already has self's shape (as it does
    // for the functional and in-place callers).
    at::native::resize_output(output, self.sizes());
    AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, self.scalar_type(), "rrelu_with_noise_out_cpu", [&] {
      _rrelu_with_noise_train<scalar_t>(output, self.contiguous(), noise, lower, upper, generator);
    });
    return output;
  } else {
    auto lower_tensor = scalar_to_tensor(lower);
    auto upper_tensor = scalar_to_tensor(upper);
    auto negative = (lower_tensor + upper_tensor) / 2;
    Scalar negative_slope = negative.item();
    return at::leaky_relu_out(output, self, negative_slope);
  }
}
// Functional variant: allocates a (legacy-contiguous) output shaped like self
// and delegates to rrelu_with_noise_out_cpu.
Tensor rrelu_with_noise_cpu(
    const Tensor& self,
    const Tensor& noise,
    const Scalar& lower,
    const Scalar& upper,
    bool training,
    c10::optional<Generator> generator) {
  Tensor result = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  return at::native::rrelu_with_noise_out_cpu(
      self, noise, lower, upper, training, std::move(generator), result);
}
// In-place variant: self is both the input and the destination.
Tensor& rrelu_with_noise_cpu_(
    Tensor& self,
    const Tensor& noise,
    const Scalar& lower,
    const Scalar& upper,
    bool training,
    c10::optional<Generator> generator) {
  Tensor& destination = self;
  return at::native::rrelu_with_noise_out_cpu(
      self, noise, lower, upper, training, std::move(generator), destination);
}
// Gradient of rrelu_with_noise.
// Training: the per-element slopes recorded in `noise` scale grad_output.
// Eval: same as leaky_relu_backward with slope (lower + upper) / 2.
Tensor rrelu_with_noise_backward(
    const Tensor& grad_output,
    const Tensor& self_or_result,
    const Tensor& noise,
    const Scalar& lower,
    const Scalar& upper,
    bool training,
    bool is_result) {
  if (training) {
    return noise * grad_output;
  }
  const double midpoint = (lower.toDouble() + upper.toDouble()) / 2.;
  return at::leaky_relu_backward(grad_output, self_or_result, midpoint, is_result);
}
// rrelu / rrelu_: validate the bounds, then call rrelu_with_noise with a fresh
// legacy-contiguous noise buffer shaped like self.
Tensor rrelu(const Tensor & self, const Scalar& lower, const Scalar& upper, bool training, c10::optional<Generator> generator) {
  TORCH_CHECK(lower.to<double>() <= upper.to<double>(), "Lower bound should be less than or equal to the upper bound");
  Tensor noise = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  return at::rrelu_with_noise(self, noise, lower, upper, training, std::move(generator));
}
Tensor & rrelu_(Tensor & self, const Scalar& lower, const Scalar& upper, bool training, c10::optional<Generator> generator) {
  TORCH_CHECK(lower.to<double>() <= upper.to<double>(), "Lower bound should be less than or equal to the upper bound");
  Tensor noise = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  return at::rrelu_with_noise_(self, noise, lower, upper, training, std::move(generator));
}
// Forward: out = (self <= threshold) ? value : self.
TORCH_IMPL_FUNC(threshold_out)(const Tensor& self, const Scalar& threshold, const Scalar& value, const Tensor& result) {
  TensorIteratorBase& iter = *this;
  threshold_stub(iter.device_type(), iter, threshold, value);
}
// Backward reuses the forward stub: the meta function registered `grad` as the
// pass-through operand, so the result is (self <= threshold) ? 0 : grad.
TORCH_IMPL_FUNC(threshold_backward_out)(const Tensor& grad, const Tensor& self, const Scalar& threshold, const Tensor& gradInput) {
  TensorIteratorBase& iter = *this;
  threshold_stub(iter.device_type(), iter, threshold, 0);
}
// prelu(self, weight): checks that weight has self's dtype, is a scalar or 1-D
// tensor, and has either one element or as many elements as self's channel
// dimension (dim 1, or 1 for 1-D inputs). It then reshapes weight so it
// broadcasts against self and calls _prelu_kernel.
Tensor prelu(const Tensor& self, const Tensor& weight_) {
  TORCH_INTERNAL_ASSERT(weight_.defined());
  auto self_dim = self.dim();
  TORCH_CHECK(self.scalar_type() == weight_.scalar_type(),
              "prelu: Type promoting not supported. Got ",
              self.scalar_type(), " and ", weight_.scalar_type());
  if (weight_.sym_numel() != 1) {
    TORCH_CHECK(self_dim > 0, "Not allow zero-dim input tensor.");

    auto channel_size = self_dim > 1 ? self.sym_size(1) : 1; // channel_size default to 1
    // Report sym_numel() (matching the comparison) rather than numel():
    // numel() is not usable on tensors with symbolic sizes, so building the
    // error message could itself fail while tracing with dynamic shapes.
    TORCH_CHECK(channel_size == weight_.sym_numel(),
      "Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_.sym_numel(),
      " and channel size = ", channel_size, ".");
  }

  TORCH_CHECK(
    weight_.dim() <= 1,
    "prelu: Expected `weight` to be a scalar or 1D tensor, but got: ndim = ", weight_.dim());
  // Adjust weight to broadcast over self and have weight.ndim == self.ndim
  auto weight = weight_;
  if (self_dim != weight.dim()) {
    SymDimVector dim_w(self_dim, 1);
    if (self_dim > 1) {
      dim_w[1] = weight_.sym_numel();
    }
    // This will always be a view in CPU/CUDA, but some backends
    // like MKLDNN do not support views
    weight = weight.reshape_symint(dim_w);
  }
  return at::_prelu_kernel(self, weight);
}
Tensor _prelu_kernel(const Tensor& self, const Tensor& weight) {
// Weight broadcasts over self and they have the same dtype
auto result = at::empty_like(self);
auto iter = TensorIteratorConfig()
.add_output(result)
.add_input(self)
.add_input(weight)
.build();
prelu_stub(iter.device_type(), iter);
return result;
}
std::tuple<Tensor, Tensor> _prelu_kernel_backward(const Tensor& grad_out, const Tensor& self, const Tensor& weight) {
Tensor grad_self = at::empty({0}, self.options());
Tensor grad_weight = at::empty({0}, weight.options());
auto iter = TensorIteratorConfig()
.add_output(grad_self)
.add_output(grad_weight)
.add_input(self)
.add_input(weight)
.add_input(grad_out)
.build();
prelu_backward_stub(iter.device_type(), iter);
return {grad_self, grad_weight};
}
// Exact-GELU gradient built only from differentiable tensor ops:
//   grad * (cdf(x) + x * pdf(x))
// where cdf(x) = 0.5 * (1 + erf(x / sqrt(2))) and
//       pdf(x) = exp(-x^2 / 2) / sqrt(2*pi).
Tensor infinitely_differentiable_gelu_backward(
    const Tensor& grad,
    const Tensor& self) {
  // (2/sqrt(pi)) * (1/sqrt(2)) * 0.5 == 1/sqrt(2*pi)
  constexpr double kInvSqrt2Pi = M_2_SQRTPI * M_SQRT1_2 * 0.5;
  Tensor gauss_cdf = (1.0 + (self * M_SQRT1_2).erf_()).mul_(0.5);
  Tensor unnormalized_pdf = (-0.5 * self * self).exp_();
  // gauss_cdf += kInvSqrt2Pi * self * unnormalized_pdf
  gauss_cdf.addcmul_(self, unnormalized_pdf, kInvSqrt2Pi);
  return gauss_cdf.mul_(grad);
}
std::tuple<Tensor, Tensor> log_sigmoid_forward_cpu(const Tensor& input) {
// FIXME: do these actually need to be zeros_like or can they be empty_like?
auto result = at::zeros_like(input, at::MemoryFormat::Contiguous);
auto buffer = at::zeros_like(input, at::MemoryFormat::Contiguous);
log_sigmoid_cpu_stub(kCPU, result, buffer, input.contiguous());
return std::make_tuple(result, buffer);
}
std::tuple<Tensor&, Tensor&> log_sigmoid_forward_out_cpu(const Tensor& input, Tensor& result, Tensor& buffer) {
result.resize_as_(input);
buffer.resize_as_(input, at::MemoryFormat::Contiguous);
TORCH_CHECK(buffer.is_contiguous(), "Contiguous buffer required for log_sigmoid with out parameter");
Tensor result_tmp = result.is_contiguous() ? result : at::empty_like(result, at::MemoryFormat::Contiguous);
log_sigmoid_cpu_stub(kCPU, result_tmp, buffer, input.contiguous());
if (!result.is_contiguous()) {
result.copy_(result_tmp);
}
return std::forward_as_tuple(result, buffer);
}
Tensor & log_sigmoid_out(const Tensor & self, Tensor & output) {
Tensor buffer = at::empty({0}, self.options());
return std::get<0>(at::log_sigmoid_forward_out(output, buffer, self));
}
Tensor log_sigmoid(const Tensor & self) {
return std::get<0>(at::log_sigmoid_forward(self));
}
Tensor log_sigmoid_backward_cuda(const Tensor& grad_output, const Tensor& input, const Tensor& buffer) {
auto grad_input = at::empty_like(grad_output);
// NOTE: buffer is only used by CPU dispatch, we just ignore it here
auto iter = at::TensorIteratorConfig()
.add_output(grad_input)
.add_input(input)
.add_input(grad_output)
.build();
log_sigmoid_backward_stub(kCUDA, iter);
return iter.output();
}
Tensor log_sigmoid_backward_cpu(const Tensor& grad_output, const Tensor& input, const Tensor& buffer) {
auto grad_input = at::empty_like(grad_output);
auto iter = at::TensorIteratorConfig()
.add_output(grad_input)
.add_input(input)
.add_input(buffer)
.add_input(grad_output)
.build();
log_sigmoid_backward_stub(kCPU, iter);
return iter.output();
}
Tensor& log_sigmoid_backward_cuda_out(const Tensor& grad_output, const Tensor& input,
const Tensor& buffer, Tensor& grad_input) {
auto iter = TensorIteratorConfig()
.add_output(grad_input)
.add_input(input)
.add_input(grad_output)
.build();
log_sigmoid_backward_stub(kCUDA, iter);
return grad_input;
}
Tensor& log_sigmoid_backward_cpu_out(const Tensor& grad_output,
const Tensor& input,
const Tensor& buffer,
Tensor& grad_input) {
auto iter = TensorIteratorConfig()
.add_output(grad_input)
.add_input(input)
.add_input(buffer)
.add_input(grad_output)
.build();
log_sigmoid_backward_stub(kCPU, iter);
return grad_input;
}
// Dispatch stubs for the native GELU forward/backward kernels, invoked by
// gelu_out_cpu / gelu_backward_out_cpu as GeluKernel(kCPU, ...) and
// GeluBackwardKernel(kCPU, ...) when the oneDNN path is not taken.
DEFINE_DISPATCH(GeluKernel);
DEFINE_DISPATCH(GeluBackwardKernel);
} // namespace at::native