-
-
Notifications
You must be signed in to change notification settings - Fork 105
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
160 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
use crate::tensor_ops::cpu_kernels::UnaryDerivative; | ||
|
||
impl<F: num_traits::Float> UnaryDerivative<F> for super::SiLUKernelOp { | ||
const DF_USES_FX: bool = false; | ||
const HAS_CONST_DF: bool = false; | ||
|
||
// x / (1 + e^-x) | ||
#[inline(always)] | ||
fn f(&self, x: &F) -> F { | ||
*x / (F::one() + x.neg().exp()) | ||
} | ||
|
||
// (1 + e^-x + x * e^-x) / (1 + e^-x)^2 | ||
// alternative: (e^x (1 + e^x + x)) / (1 + e^x)^2 | ||
#[inline(always)] | ||
fn df(&self, x: &F) -> F { | ||
let exp_nx = x.neg().exp(); | ||
F::one() + exp_nx + *x * exp_nx / (F::one() + exp_nx).powi(2) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
use super::SiLUKernelOp; | ||
#[allow(unused_imports)] | ||
use crate::dtypes::*; | ||
use crate::tensor_ops::cuda_kernels::cuda_unary; | ||
|
||
unsafe impl cudarc::driver::DeviceRepr for SiLUKernelOp {} | ||
|
||
const PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/silu.ptx")); | ||
|
||
#[cfg(feature = "f16")] | ||
cuda_unary!(SiLUKernelOp, f16, PTX, "silu_fwd_f16", "silu_bwd_f16"); | ||
#[cfg(feature = "f16")] | ||
cuda_unary!(SiLUKernelOp, AMP<f16>, PTX, "silu_fwd_f16", "silu_bwd_f16"); | ||
cuda_unary!(SiLUKernelOp, f32, PTX, "silu_fwd_f32", "silu_bwd_f32"); | ||
cuda_unary!(SiLUKernelOp, f64, PTX, "silu_fwd_f64", "silu_bwd_f64"); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
mod cpu_kernel; | ||
|
||
#[cfg(feature = "cuda")] | ||
mod cuda_kernel; | ||
|
||
#[cfg(feature = "webgpu")] | ||
mod webgpu_kernel; | ||
|
||
use super::ops::{try_unary_op, UnaryKernel}; | ||
use crate::{shapes::*, tensor::*}; | ||
|
||
#[repr(C)] | ||
#[derive(Debug, Default, Copy, Clone)] | ||
pub struct SiLUKernelOp; | ||
|
||
/// [Sigmoid-Weighted Linear Unit (SiLU)](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)). `x * x.sigmoid()` | ||
/// | ||
/// The derivative is `x * sigmoid'(x) + sigmoid(x)`. | ||
/// | ||
/// Examples: | ||
/// ```rust | ||
/// # use dfdx_core::prelude::*; | ||
/// # let dev: Cpu = Default::default(); | ||
/// let t = dev.tensor([-1.0, 0.0, 1.0, 2.0]); | ||
/// let r = t.silu(); | ||
/// ``` | ||
pub fn silu<S: Shape, E: Dtype, D: UnaryKernel<SiLUKernelOp, E>, T: Tape<E, D>>( | ||
t: Tensor<S, E, D, T>, | ||
) -> Tensor<S, E, D, T> { | ||
t.silu() | ||
} | ||
|
||
impl<S: Shape, E: Dtype, D: UnaryKernel<SiLUKernelOp, E>, T: Tape<E, D>> Tensor<S, E, D, T> { | ||
/// See [silu] | ||
pub fn silu(self) -> Self { | ||
self.try_silu().unwrap() | ||
} | ||
/// See [silu] | ||
pub fn try_silu(self) -> Result<Self, crate::tensor::Error> { | ||
try_unary_op(SiLUKernelOp, self) | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use crate::{tensor::*, tensor_ops::*, tests::*}; | ||
|
||
#[test] | ||
fn test_silu() { | ||
let dev: TestDevice = Default::default(); | ||
let x = dev | ||
.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]) | ||
.to_dtype::<TestDtype>(); | ||
let r = x.leaky_trace().silu(); | ||
assert_close_to_literal!(r, [-0.23840584, -0.26894143, 0.0, 0.7310586, 1.761594]); | ||
let g = r.mean().backward(); | ||
assert_close_to_literal!( | ||
g.get(&x), | ||
[1.635814, 0.70433396, 0.4, 0.31289828, 0.26906452] | ||
); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
#include "unary_op_macros.cuh" | ||
|
||
struct SiLUKernelOp {}; | ||
|
||
// x / (1 + e^-x) | ||
template<typename T> | ||
__device__ __forceinline__ T silu_fwd(T x) { | ||
T one = 1.0; | ||
return x / (one + expg(-x)); | ||
} | ||
|
||
// (1 + e^-x + x * e^-x) / (1 + e^-x)^2 | ||
// alternative: (e^x (1 + e^x + x)) / (1 + e^x)^2 | ||
template<typename T> | ||
__device__ __forceinline__ T silu_bwd(T x) { | ||
T one = 1.0; | ||
T exp_nx = expg(-x); | ||
T denom_sqrt = (one + exp_nx); | ||
return (one + exp_nx + x * exp_nx) / (denom_sqrt * denom_sqrt); | ||
} | ||
|
||
UNARY_OP(__half, silu_fwd_f16, silu_bwd_f16, SiLUKernelOp, | ||
silu_fwd(x), | ||
silu_bwd(x)) | ||
|
||
UNARY_OP(float, silu_fwd_f32, silu_bwd_f32, SiLUKernelOp, | ||
silu_fwd(x), | ||
silu_bwd(x)) | ||
|
||
UNARY_OP(double, silu_fwd_f64, silu_bwd_f64, SiLUKernelOp, | ||
silu_fwd(x), | ||
silu_bwd(x)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
use std::borrow::Cow; | ||
|
||
use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu}; | ||
|
||
impl<E: Dtype> UnaryKernel<super::SiLUKernelOp, E> for Webgpu { | ||
const BACKWARD_WITHOUT_INP: bool = false; | ||
|
||
const BACKWARD_WITHOUT_DATA: bool = false; | ||
|
||
fn forward<S: crate::prelude::Shape>( | ||
&self, | ||
op: super::SiLUKernelOp, | ||
inp: Cow<crate::prelude::Tensor<S, E, Self>>, | ||
) -> Result<crate::prelude::Tensor<S, E, Self>, crate::prelude::Error> { | ||
todo!() | ||
} | ||
|
||
fn backward<S: crate::prelude::Shape>( | ||
&self, | ||
op: super::SiLUKernelOp, | ||
inp: &impl crate::prelude::Tensorlike<S, E, Self>, | ||
grad_inp: &mut Self::Vec, | ||
out: &impl crate::prelude::Tensorlike<S, E, Self>, | ||
grad_out: &Self::Vec, | ||
) -> Result<(), crate::prelude::Error> { | ||
todo!() | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters