diff --git a/dfdx-core/src/tensor/numpy.rs b/dfdx-core/src/tensor/numpy.rs
index b362be2a..0a22a97f 100644
--- a/dfdx-core/src/tensor/numpy.rs
+++ b/dfdx-core/src/tensor/numpy.rs
@@ -282,6 +282,166 @@ impl NumpyDtype for f64 {
     }
 }
 
+impl NumpyDtype for u8 {
+    const NUMPY_DTYPE_STR: &'static str = "u1";
+    fn read_endian<R: Read>(r: &mut R, endian: Endian) -> io::Result<Self> {
+        let mut bytes = [0; 1];
+        r.read_exact(&mut bytes)?;
+        Ok(match endian {
+            Endian::Big => Self::from_be_bytes(bytes),
+            Endian::Little => Self::from_le_bytes(bytes),
+            Endian::Native => Self::from_ne_bytes(bytes),
+        })
+    }
+    fn write_endian<W: Write>(&self, w: &mut W, endian: Endian) -> io::Result<()> {
+        match endian {
+            Endian::Big => w.write_all(&self.to_be_bytes()),
+            Endian::Little => w.write_all(&self.to_le_bytes()),
+            Endian::Native => w.write_all(&self.to_ne_bytes()),
+        }
+    }
+}
+
+impl NumpyDtype for u16 {
+    const NUMPY_DTYPE_STR: &'static str = "u2";
+    fn read_endian<R: Read>(r: &mut R, endian: Endian) -> io::Result<Self> {
+        let mut bytes = [0; 2];
+        r.read_exact(&mut bytes)?;
+        Ok(match endian {
+            Endian::Big => Self::from_be_bytes(bytes),
+            Endian::Little => Self::from_le_bytes(bytes),
+            Endian::Native => Self::from_ne_bytes(bytes),
+        })
+    }
+    fn write_endian<W: Write>(&self, w: &mut W, endian: Endian) -> io::Result<()> {
+        match endian {
+            Endian::Big => w.write_all(&self.to_be_bytes()),
+            Endian::Little => w.write_all(&self.to_le_bytes()),
+            Endian::Native => w.write_all(&self.to_ne_bytes()),
+        }
+    }
+}
+
+impl NumpyDtype for u32 {
+    const NUMPY_DTYPE_STR: &'static str = "u4";
+    fn read_endian<R: Read>(r: &mut R, endian: Endian) -> io::Result<Self> {
+        let mut bytes = [0; 4];
+        r.read_exact(&mut bytes)?;
+        Ok(match endian {
+            Endian::Big => Self::from_be_bytes(bytes),
+            Endian::Little => Self::from_le_bytes(bytes),
+            Endian::Native => Self::from_ne_bytes(bytes),
+        })
+    }
+    fn write_endian<W: Write>(&self, w: &mut W, endian: Endian) -> io::Result<()> {
+        match endian {
+            Endian::Big => w.write_all(&self.to_be_bytes()),
+            Endian::Little => w.write_all(&self.to_le_bytes()),
+            Endian::Native => w.write_all(&self.to_ne_bytes()),
+        }
+    }
+}
+
+impl NumpyDtype for u64 {
+    const NUMPY_DTYPE_STR: &'static str = "u8";
+    fn read_endian<R: Read>(r: &mut R, endian: Endian) -> io::Result<Self> {
+        let mut bytes = [0; 8];
+        r.read_exact(&mut bytes)?;
+        Ok(match endian {
+            Endian::Big => Self::from_be_bytes(bytes),
+            Endian::Little => Self::from_le_bytes(bytes),
+            Endian::Native => Self::from_ne_bytes(bytes),
+        })
+    }
+    fn write_endian<W: Write>(&self, w: &mut W, endian: Endian) -> io::Result<()> {
+        match endian {
+            Endian::Big => w.write_all(&self.to_be_bytes()),
+            Endian::Little => w.write_all(&self.to_le_bytes()),
+            Endian::Native => w.write_all(&self.to_ne_bytes()),
+        }
+    }
+}
+
+impl NumpyDtype for i8 {
+    const NUMPY_DTYPE_STR: &'static str = "i1";
+    fn read_endian<R: Read>(r: &mut R, endian: Endian) -> io::Result<Self> {
+        let mut bytes = [0; 1];
+        r.read_exact(&mut bytes)?;
+        Ok(match endian {
+            Endian::Big => Self::from_be_bytes(bytes),
+            Endian::Little => Self::from_le_bytes(bytes),
+            Endian::Native => Self::from_ne_bytes(bytes),
+        })
+    }
+    fn write_endian<W: Write>(&self, w: &mut W, endian: Endian) -> io::Result<()> {
+        match endian {
+            Endian::Big => w.write_all(&self.to_be_bytes()),
+            Endian::Little => w.write_all(&self.to_le_bytes()),
+            Endian::Native => w.write_all(&self.to_ne_bytes()),
+        }
+    }
+}
+
+impl NumpyDtype for i16 {
+    const NUMPY_DTYPE_STR: &'static str = "i2";
+    fn read_endian<R: Read>(r: &mut R, endian: Endian) -> io::Result<Self> {
+        let mut bytes = [0; 2];
+        r.read_exact(&mut bytes)?;
+        Ok(match endian {
+            Endian::Big => Self::from_be_bytes(bytes),
+            Endian::Little => Self::from_le_bytes(bytes),
+            Endian::Native => Self::from_ne_bytes(bytes),
+        })
+    }
+    fn write_endian<W: Write>(&self, w: &mut W, endian: Endian) -> io::Result<()> {
+        match endian {
+            Endian::Big => w.write_all(&self.to_be_bytes()),
+            Endian::Little => w.write_all(&self.to_le_bytes()),
+            Endian::Native => w.write_all(&self.to_ne_bytes()),
+        }
+    }
+}
+
+impl NumpyDtype for i32 {
+    const NUMPY_DTYPE_STR: &'static str = "i4";
+    fn read_endian<R: Read>(r: &mut R, endian: Endian) -> io::Result<Self> {
+        let mut bytes = [0; 4];
+        r.read_exact(&mut bytes)?;
+        Ok(match endian {
+            Endian::Big => Self::from_be_bytes(bytes),
+            Endian::Little => Self::from_le_bytes(bytes),
+            Endian::Native => Self::from_ne_bytes(bytes),
+        })
+    }
+    fn write_endian<W: Write>(&self, w: &mut W, endian: Endian) -> io::Result<()> {
+        match endian {
+            Endian::Big => w.write_all(&self.to_be_bytes()),
+            Endian::Little => w.write_all(&self.to_le_bytes()),
+            Endian::Native => w.write_all(&self.to_ne_bytes()),
+        }
+    }
+}
+
+impl NumpyDtype for i64 {
+    const NUMPY_DTYPE_STR: &'static str = "i8";
+    fn read_endian<R: Read>(r: &mut R, endian: Endian) -> io::Result<Self> {
+        let mut bytes = [0; 8];
+        r.read_exact(&mut bytes)?;
+        Ok(match endian {
+            Endian::Big => Self::from_be_bytes(bytes),
+            Endian::Little => Self::from_le_bytes(bytes),
+            Endian::Native => Self::from_ne_bytes(bytes),
+        })
+    }
+    fn write_endian<W: Write>(&self, w: &mut W, endian: Endian) -> io::Result<()> {
+        match endian {
+            Endian::Big => w.write_all(&self.to_be_bytes()),
+            Endian::Little => w.write_all(&self.to_le_bytes()),
+            Endian::Native => w.write_all(&self.to_ne_bytes()),
+        }
+    }
+}
+
 #[derive(Debug)]
 pub enum NpyError {
     /// Magic number did not match the expected value.
@@ -560,4 +720,95 @@ mod tests {
             .load_from_npy(file.path())
             .expect_err("");
     }
+
+    #[test]
+    fn test_0d_u8_save() {
+        let dev: TestDevice = Default::default();
+
+        let x = dev.tensor(0u8);
+
+        let file = NamedTempFile::new().expect("failed to create tempfile");
+
+        x.save_to_npy(file.path()).expect("Saving failed");
+
+        let mut f = File::open(file.path()).expect("No file found");
+
+        let mut found = Vec::new();
+        f.read_to_end(&mut found).expect("Reading failed");
+
+        assert_eq!(
+            &found,
+            &[
+                147, 78, 85, 77, 80, 89, 1, 0, 64, 0, 123, 39, 100, 101, 115, 99, 114, 39, 58, 32,
+                39, 60, 117, 49, 39, 44, 32, 39, 102, 111, 114, 116, 114, 97, 110, 95, 111, 114,
+                100, 101, 114, 39, 58, 32, 70, 97, 108, 115, 101, 44, 32, 39, 115, 104, 97, 112,
+                101, 39, 58, 32, 40, 41, 44, 32, 125, 32, 32, 32, 32, 32, 32, 32, 32, 10, 0
+            ]
+        );
+    }
+
+    #[test]
+    fn test_0d_u8_load() {
+        let dev: TestDevice = Default::default();
+        let x = dev.tensor(2u8);
+
+        let file = NamedTempFile::new().expect("failed to create tempfile");
+
+        x.save_to_npy(file.path()).expect("Saving failed");
+
+        let mut v = dev.tensor(0u8);
+        v.load_from_npy(file.path()).expect("Loading failed");
+        assert_eq!(v.array(), x.array());
+
+        dev.tensor(0u16).load_from_npy(file.path()).expect_err("");
+        dev.tensor([0u8; 1])
+            .load_from_npy(file.path())
+            .expect_err("");
+    }
+
+    #[test]
+    fn test_0d_i8_save() {
+        let dev: TestDevice = Default::default();
+
+        let x = dev.tensor(0i8);
+
+        let file = NamedTempFile::new().expect("failed to create tempfile");
+
+        x.save_to_npy(file.path()).expect("Saving failed");
+        x.save_to_npy("out.npy").expect("Saving failed");
+
+        let mut f = File::open(file.path()).expect("No file found");
+
+        let mut found = Vec::new();
+        f.read_to_end(&mut found).expect("Reading failed");
+
+        assert_eq!(
+            &found,
+            &[
+                147, 78, 85, 77, 80, 89, 1, 0, 64, 0, 123, 39, 100, 101, 115, 99, 114, 39, 58, 32,
+                39, 60, 105, 49, 39, 44, 32, 39, 102, 111, 114, 116, 114, 97, 110, 95, 111, 114,
+                100, 101, 114, 39, 58, 32, 70, 97, 108, 115, 101, 44, 32, 39, 115, 104, 97, 112,
+                101, 39, 58, 32, 40, 41, 44, 32, 125, 32, 32, 32, 32, 32, 32, 32, 32, 10, 0
+            ]
+        );
+    }
+
+    #[test]
+    fn test_0d_i8_load() {
+        let dev: TestDevice = Default::default();
+        let x = dev.tensor(2i8);
+
+        let file = NamedTempFile::new().expect("failed to create tempfile");
+
+        x.save_to_npy(file.path()).expect("Saving failed");
+
+        let mut v = dev.tensor(0i8);
+        v.load_from_npy(file.path()).expect("Loading failed");
+        assert_eq!(v.array(), x.array());
+
+        dev.tensor(0i16).load_from_npy(file.path()).expect_err("");
+        dev.tensor([0i8; 1])
+            .load_from_npy(file.path())
+            .expect_err("");
+    }
 }
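
Usage sketch (not part of the patch): with the integer `NumpyDtype` impls above, integer tensors round-trip through `.npy` files the same way `f32`/`f64` tensors already did. The file name below is illustrative; the calls mirror the tests above.

```rust
use dfdx_core::prelude::*;

fn main() {
    let dev: Cpu = Default::default();

    // A rank-1 u8 tensor; the written .npy header declares dtype "<u1".
    let x = dev.tensor([1u8, 2, 3]);
    x.save_to_npy("x_u8.npy").expect("Saving failed");

    // Loading requires a tensor with the same dtype and shape.
    let mut y = dev.tensor([0u8; 3]);
    y.load_from_npy("x_u8.npy").expect("Loading failed");
    assert_eq!(x.array(), y.array());

    // On the Python side, np.load("x_u8.npy") now yields a uint8 array.
}
```
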
diff --git a/dfdx-core/src/tensor_ops/concat_along/mod.rs b/dfdx-core/src/tensor_ops/concat_along/mod.rs
index 92816692..e632aeb2 100644
--- a/dfdx-core/src/tensor_ops/concat_along/mod.rs
+++ b/dfdx-core/src/tensor_ops/concat_along/mod.rs
@@ -1,11 +1,6 @@
+pub use super::concat_tensor_along::ConcatAlongKernel;
 use crate::{shapes::*, tensor::*};
 
-mod cpu_kernel;
-#[cfg(feature = "cuda")]
-mod cuda_kernel;
-#[cfg(feature = "webgpu")]
-mod webgpu_kernel;
-
 /// Concatenate two tensors along a given axis.
 ///
 /// **Pytorch equivalent** `torch.concat`.
@@ -48,6 +43,7 @@ mod webgpu_kernel;
 /// let b: Tensor<(Const<2>, usize), f32, _> = dev.zeros_like(&(Const, 4));
 /// let _: Tensor<Rank2<2, 6>, f32, _> = (a, b).concat_along(Axis::<1>).realize();
 /// ```
+#[deprecated = "Use TryConcatTensorAlong or TryConcatShapeAlong instead"]
 pub trait TryConcatAlong<Ax>: Sized {
     type Output;
 
@@ -59,26 +55,7 @@ pub trait TryConcatAlong<Ax>: Sized {
     fn try_concat_along(self, ax: Ax) -> Result<Self::Output, Error>;
 }
 
-pub trait ConcatAlongKernel<E: Dtype>: Storage<E> {
-    fn forward<A: Shape, B: Shape, C: Shape>(
-        &self,
-        ax: usize,
-        a: &Tensor<A, E, Self>,
-        b: &Tensor<B, E, Self>,
-        c: &mut Tensor<C, E, Self>,
-    ) -> Result<(), Error>;
-
-    fn backward<A: Shape, B: Shape>(
-        &self,
-        ax: usize,
-        a: &GhostTensor<A, E, Self>,
-        grad_a: &mut Self::Vec,
-        b: &GhostTensor<B, E, Self>,
-        grad_b: &mut Self::Vec,
-        grad_out: &Self::Vec,
-    ) -> Result<(), Error>;
-}
-
+#[allow(deprecated)]
 impl<A: Shape, B: Shape, Ax, E: Dtype, D, T: Tape<E, D>, R: Tape<E, D>> TryConcatAlong<Ax>
     for (Tensor<A, E, D, T>, Tensor<B, E, D, R>)
 where
@@ -123,6 +100,7 @@ where
 
 macro_rules! impl_concat {
     ($Ax:expr, $NumDims:expr, [$($Head:tt),*], [$($Tail:tt),*]) => {
+        #[allow(deprecated)]
         impl<A: Dim, B: Dim, $($Head: Dim, )* $($Tail: Dim, )*> TryConcatAlong<Axis<$Ax>>
             for (
                 ($($Head, )* A, $($Tail, )*),
@@ -183,6 +161,7 @@ impl_concat!(4, 6, [D0, D1, D2, D3], [D5]);
 impl_concat!(5, 6, [D0, D1, D2, D3, D4], []);
 
 #[cfg(test)]
+#[allow(deprecated)]
 mod tests {
     use super::*;
     use crate::{tensor_ops::*, tests::*};
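
Migration sketch (assumed user code, not part of the patch): `concat_along` still compiles but now emits a deprecation warning; its replacement splits into a tensor-level trait and a shape-level trait.

```rust
use dfdx_core::prelude::*;

fn main() {
    let dev: Cpu = Default::default();
    let a: Tensor<(usize, Const<3>), f32, _> = dev.zeros_like(&(2, Const));
    let b: Tensor<(usize, Const<3>), f32, _> = dev.zeros_like(&(4, Const));

    // Before (deprecated): let c = (a, b).concat_along(Axis::<0>);

    // Shape-only concatenation:
    let out_shape = (*a.shape(), *b.shape()).concat_shape_along(Axis::<0>);
    assert_eq!(out_shape, (6, Const::<3>));

    // Tensor concatenation:
    let c: Tensor<(usize, Const<3>), f32, _> = (a, b).concat_tensor_along(Axis::<0>);
    assert_eq!(c.shape(), &out_shape);
}
```
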
diff --git a/dfdx-core/src/tensor_ops/concat_shape_along/mod.rs b/dfdx-core/src/tensor_ops/concat_shape_along/mod.rs
new file mode 100644
index 00000000..f22044ac
--- /dev/null
+++ b/dfdx-core/src/tensor_ops/concat_shape_along/mod.rs
@@ -0,0 +1,153 @@
+use crate::{shapes::*, tensor::*};
+
+/// Concatenate two shapes along a given axis.
+///
+/// # [Const] dims **requires nightly**
+///
+/// Along Axis 0:
+/// ```ignore
+/// # use dfdx_core::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let a: Rank2<3, 4> = (Const, Const);
+/// let b: Rank2<3, 4> = (Const, Const);
+/// let _: Rank2<6, 4> = (a, b).concat_shape_along(Axis::<0>);
+/// ```
+///
+/// Along Axis 1:
+/// ```ignore
+/// # use dfdx_core::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let a: Rank2<3, 4> = (Const, Const);
+/// let b: Rank2<3, 4> = (Const, Const);
+/// let _: Rank2<3, 8> = (a, b).concat_shape_along(Axis::<1>);
+/// ```
+///
+/// # [usize] dims
+/// Along Axis 0:
+/// ```rust
+/// # use dfdx_core::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let a: (usize, Const<3>) = (2, Const);
+/// let b: (usize, Const<3>) = (4, Const);
+/// let c: (usize, Const<3>) = (a, b).concat_shape_along(Axis::<0>);
+/// assert_eq!(c, (6, Const::<3>));
+/// ```
+///
+/// Along Axis 1:
+/// ```rust
+/// # use dfdx_core::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let a: (Const<2>, usize) = (Const, 2);
+/// let b: (Const<2>, usize) = (Const, 4);
+/// let c: (Const<2>, usize) = (a, b).concat_shape_along(Axis::<1>);
+/// assert_eq!(c, (Const::<2>, 6));
+/// ```
+pub trait TryConcatShapeAlong<Ax>: Sized {
+    type Output: Shape;
+
+    /// Concatenates self along the given axis.
+    fn concat_shape_along(self, ax: Ax) -> Self::Output {
+        self.try_concat_shape_along(ax).unwrap()
+    }
+    /// Fallibly concatenates self along the given axis.
+    fn try_concat_shape_along(self, ax: Ax) -> Result<Self::Output, Error>;
+}
+
+macro_rules! impl_concat {
+    ($Ax:expr, $NumDims:expr, [$($Head:tt),*], [$($Tail:tt),*]) => {
+        impl<A: Dim, B: Dim, $($Head: Dim, )* $($Tail: Dim, )*> TryConcatShapeAlong<Axis<$Ax>>
+            for (
+                ($($Head, )* A, $($Tail, )*),
+                ($($Head, )* B, $($Tail, )*),
+            )
+        where
+            A: std::ops::Add<B>,
+            <A as std::ops::Add<B>>::Output: Dim,
+        {
+            type Output = (
+                $($Head, )*
+                <A as std::ops::Add<B>>::Output,
+                $($Tail, )*
+            );
+
+            fn try_concat_shape_along(self, _: Axis<$Ax>) -> Result<Self::Output, Error> {
+                let (lhs, rhs) = self;
+                let lhs_dims = lhs.concrete();
+                let rhs_dims = rhs.concrete();
+                for i in 0..$NumDims {
+                    if i != $Ax {
+                        assert_eq!(lhs_dims[i], rhs_dims[i]);
+                    }
+                }
+                let mut out_dims = lhs_dims;
+                out_dims[$Ax] += rhs_dims[$Ax];
+                Ok(Self::Output::from_concrete(&out_dims).unwrap())
+            }
+        }
+    };
+}
+
+impl_concat!(0, 1, [], []);
+impl_concat!(0, 2, [], [D1]);
+impl_concat!(0, 3, [], [D1, D2]);
+impl_concat!(0, 4, [], [D1, D2, D3]);
+impl_concat!(0, 5, [], [D1, D2, D3, D4]);
+impl_concat!(0, 6, [], [D1, D2, D3, D4, D5]);
+
+impl_concat!(1, 2, [D0], []);
+impl_concat!(1, 3, [D0], [D2]);
+impl_concat!(1, 4, [D0], [D2, D3]);
+impl_concat!(1, 5, [D0], [D2, D3, D4]);
+impl_concat!(1, 6, [D0], [D2, D3, D4, D5]);
+
+impl_concat!(2, 3, [D0, D1], []);
+impl_concat!(2, 4, [D0, D1], [D3]);
+impl_concat!(2, 5, [D0, D1], [D3, D4]);
+impl_concat!(2, 6, [D0, D1], [D3, D4, D5]);
+
+impl_concat!(3, 4, [D0, D1, D2], []);
+impl_concat!(3, 5, [D0, D1, D2], [D4]);
+impl_concat!(3, 6, [D0, D1, D2], [D4, D5]);
+
+impl_concat!(4, 5, [D0, D1, D2, D3], []);
+impl_concat!(4, 6, [D0, D1, D2, D3], [D5]);
+
+impl_concat!(5, 6, [D0, D1, D2, D3, D4], []);
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_concat_shape() {
+        let a: (usize, Const<5>) = (5, Const);
+        let b: (usize, Const<5>) = (3, Const);
+        assert_eq!((a, b).concat_shape_along(Axis::<0>), (8, Const::<5>));
+
+        let a: (Const<5>, Const<5>) = (Const, Const);
+        let b: (usize, Const<5>) = (3, Const);
+        assert_eq!((a, b).concat_shape_along(Axis::<0>), (8, Const::<5>));
+
+        let a: (usize, Const<5>) = (5, Const);
+        let b: (Const<3>, Const<5>) = (Const, Const);
+        assert_eq!((a, b).concat_shape_along(Axis::<0>), (8, Const::<5>));
+
+        #[cfg(feature = "nightly")]
+        {
+            let a: (Const<5>, Const<5>) = (Const, Const);
+            let b: (Const<3>, Const<5>) = (Const, Const);
+            assert_eq!(
+                (a, b).concat_shape_along(Axis::<0>),
+                (Const::<8>, Const::<5>)
+            );
+        }
+    }
+
+    #[test]
+    #[should_panic = "left: 10\n right: 7"]
+    fn test_concat_shape_fails() {
+        let a = (5, 10);
+        let b = (3, 7);
+        (a, b).concat_shape_along(Axis::<0>);
+    }
+}
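
A small sketch of how the shape-level trait is used (assumed user code, not part of the patch): the non-concatenated dims are checked at runtime, the concatenated dim is added via `A: std::ops::Add<B>`, and the resulting shape can be used to preallocate an output, which is exactly what the tensor-level op does before invoking its kernel.

```rust
use dfdx_core::prelude::*;

fn main() {
    let dev: Cpu = Default::default();

    let a: (usize, Const<3>) = (2, Const);
    let b: (usize, Const<3>) = (4, Const);

    // Runtime dims along axis 0 are simply added: 2 + 4 = 6.
    let out = (a, b).concat_shape_along(Axis::<0>);
    assert_eq!(out, (6, Const::<3>));

    // Preallocate a zeroed buffer of the concatenated shape.
    let buf: Tensor<(usize, Const<3>), f32, _> = dev.zeros_like(&out);
    assert_eq!(buf.shape(), &(6, Const::<3>));
}
```
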
diff --git a/dfdx-core/src/tensor_ops/concat_along/cpu_kernel.rs b/dfdx-core/src/tensor_ops/concat_tensor_along/cpu_kernel.rs
similarity index 100%
rename from dfdx-core/src/tensor_ops/concat_along/cpu_kernel.rs
rename to dfdx-core/src/tensor_ops/concat_tensor_along/cpu_kernel.rs
diff --git a/dfdx-core/src/tensor_ops/concat_along/cuda_kernel.rs b/dfdx-core/src/tensor_ops/concat_tensor_along/cuda_kernel.rs
similarity index 100%
rename from dfdx-core/src/tensor_ops/concat_along/cuda_kernel.rs
rename to dfdx-core/src/tensor_ops/concat_tensor_along/cuda_kernel.rs
diff --git a/dfdx-core/src/tensor_ops/concat_tensor_along/mod.rs b/dfdx-core/src/tensor_ops/concat_tensor_along/mod.rs
new file mode 100644
index 00000000..7462fd2b
--- /dev/null
+++ b/dfdx-core/src/tensor_ops/concat_tensor_along/mod.rs
@@ -0,0 +1,227 @@
+use super::concat_shape_along::TryConcatShapeAlong;
+use crate::{shapes::*, tensor::*};
+
+pub(crate) mod cpu_kernel;
+#[cfg(feature = "cuda")]
+pub(crate) mod cuda_kernel;
+#[cfg(feature = "webgpu")]
+mod webgpu_kernel;
+
+/// Concatenate two tensors along a given axis.
+///
+/// **Pytorch equivalent** `torch.concat`.
+///
+/// # [Const] dims **requires nightly**
+///
+/// Along Axis 0:
+/// ```ignore
+/// # use dfdx_core::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let a: Tensor<Rank2<3, 4>, f32, _> = dev.zeros();
+/// let b: Tensor<Rank2<3, 4>, f32, _> = dev.zeros();
+/// let _: Tensor<Rank2<6, 4>, f32, _> = (a, b).concat_tensor_along(Axis::<0>);
+/// ```
+///
+/// Along Axis 1:
+/// ```ignore
+/// # use dfdx_core::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let a: Tensor<Rank2<3, 4>, f32, _> = dev.zeros();
+/// let b: Tensor<Rank2<3, 4>, f32, _> = dev.zeros();
+/// let _: Tensor<Rank2<3, 8>, f32, _> = (a, b).concat_tensor_along(Axis::<1>);
+/// ```
+///
+/// # [usize] dims
+/// Along Axis 0:
+/// ```rust
+/// # use dfdx_core::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let a: Tensor<(usize, Const<3>), f32, _> = dev.zeros_like(&(2, Const));
+/// let b: Tensor<(usize, Const<3>), f32, _> = dev.zeros_like(&(4, Const));
+/// let _: Tensor<Rank2<6, 3>, f32, _> = (a, b).concat_tensor_along(Axis::<0>).realize();
+/// ```
+///
+/// Along Axis 1:
+/// ```rust
+/// # use dfdx_core::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let a: Tensor<(Const<2>, usize), f32, _> = dev.zeros_like(&(Const, 2));
+/// let b: Tensor<(Const<2>, usize), f32, _> = dev.zeros_like(&(Const, 4));
+/// let _: Tensor<Rank2<2, 6>, f32, _> = (a, b).concat_tensor_along(Axis::<1>).realize();
+/// ```
+pub trait TryConcatTensorAlong<Ax>: Sized {
+    type Output;
+
+    /// Concatenates self along the given axis.
+    fn concat_tensor_along(self, ax: Ax) -> Self::Output {
+        self.try_concat_tensor_along(ax).unwrap()
+    }
+    /// Fallibly concatenates self along the given axis.
+    fn try_concat_tensor_along(self, ax: Ax) -> Result<Self::Output, Error>;
+}
+
+pub trait ConcatAlongKernel<E: Dtype>: Storage<E> {
+    fn forward<A: Shape, B: Shape, C: Shape>(
+        &self,
+        ax: usize,
+        a: &Tensor<A, E, Self>,
+        b: &Tensor<B, E, Self>,
+        c: &mut Tensor<C, E, Self>,
+    ) -> Result<(), Error>;
+
+    fn backward<A: Shape, B: Shape>(
+        &self,
+        ax: usize,
+        a: &GhostTensor<A, E, Self>,
+        grad_a: &mut Self::Vec,
+        b: &GhostTensor<B, E, Self>,
+        grad_b: &mut Self::Vec,
+        grad_out: &Self::Vec,
+    ) -> Result<(), Error>;
+}
+
+impl<A: Shape, B: Shape, Ax, E: Dtype, D, T: Tape<E, D>, R: Tape<E, D>> TryConcatTensorAlong<Ax>
+    for (Tensor<A, E, D, T>, Tensor<B, E, D, R>)
+where
+    Ax: Axes<Array = [isize; 1]>,
+    D: ConcatAlongKernel<E> + ZerosTensor<E>,
+    A: Shape + HasAxes<Ax>,
+    B: Shape + HasAxes<Ax>,
+    (A, B): TryConcatShapeAlong<Ax>,
+    T: Merge<R>,
+{
+    type Output = Tensor<<(A, B) as TryConcatShapeAlong<Ax>>::Output, E, D, T>;
+
+    fn try_concat_tensor_along(self, ax: Ax) -> Result<Self::Output, Error> {
+        let (lhs, rhs) = self;
+
+        let out_shape = (*lhs.shape(), *rhs.shape()).concat_shape_along(ax);
+        let ax = Ax::as_array()[0] as usize;
+
+        let (lhs, tape) = lhs.split_tape();
+        let (rhs, rtape) = rhs.split_tape();
+        let mut tape = tape.merge(rtape);
+
+        let mut out = lhs.device.try_zeros_like(&out_shape)?;
+        lhs.device.forward(ax, &lhs, &rhs, &mut out)?;
+
+        let lhs_ghost = lhs.ghost();
+        let rhs_ghost = rhs.ghost();
+        let out_ghost = out.ghost();
+        tape.add_backward_op(move |grads| {
+            grads.try_alloc_for(&lhs_ghost)?;
+            grads.try_alloc_for(&rhs_ghost)?;
+            grads.try_alloc_for(&out_ghost)?;
+            let (lhs_grad, rhs_grad, out_grad) =
+                grads.muts_and_ref(&lhs_ghost, &rhs_ghost, &out_ghost);
+            lhs.device
+                .backward(ax, &lhs_ghost, lhs_grad, &rhs_ghost, rhs_grad, out_grad)
+        });
+        Ok(out.put_tape(tape))
+    }
+}
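
Usage sketch (assumed user code, not part of the patch): because the output shape comes from `TryConcatShapeAlong`, only the non-concatenated dims have to match; the concatenated dim may mix `Const` and `usize`, falling back to `usize` on stable.

```rust
use dfdx_core::prelude::*;

fn main() {
    let dev: Cpu = Default::default();

    // Compile-time batch of 2 concatenated with a runtime batch of 3.
    let a: Tensor<Rank2<2, 4>, f32, _> = dev.sample_normal();
    let b: Tensor<(usize, Const<4>), f32, _> = dev.zeros_like(&(3, Const));

    // Const<2> + usize = usize, so the output batch dim is a runtime 5.
    let c = (a, b).concat_tensor_along(Axis::<0>);
    assert_eq!(c.shape(), &(5, Const::<4>));
}
```
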
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::{tensor_ops::*, tests::*};
+
+    #[test]
+    fn test_concat_ax_0() {
+        let dev: TestDevice = Default::default();
+        let a: Tensor<Rank3<2, 3, 4>, TestDtype, _> = dev.sample_normal();
+        let b: Tensor<Rank3<3, 3, 4>, TestDtype, _> = dev.sample_normal();
+        let a_dyn = a
+            .leaky_trace()
+            .try_realize::<(usize, Const<3>, Const<4>)>()
+            .unwrap();
+        let b_dyn = b
+            .clone()
+            .try_realize::<(usize, Const<3>, Const<4>)>()
+            .unwrap();
+        let c = (a_dyn, b_dyn).concat_tensor_along(Axis::<0>);
+        let c = c.try_realize::<(Const<5>, Const<3>, Const<4>)>().unwrap();
+        let a_arr = a.array();
+        let b_arr = b.array();
+        let c_arr = c.array();
+        println!("{a_arr:?}");
+        println!("{b_arr:?}");
+        println!("{c_arr:?}");
+        assert_eq!(c_arr[0], a_arr[0]);
+        assert_eq!(c_arr[1], a_arr[1]);
+        assert_eq!(c_arr[2], b_arr[0]);
+        assert_eq!(c_arr[3], b_arr[1]);
+        assert_eq!(c_arr[4], b_arr[2]);
+        let concat_grads = c.exp().sum().backward();
+        let a_grads = a.leaky_trace().exp().sum().backward();
+        let b_grads = b.leaky_trace().exp().sum().backward();
+        assert_close_to_tensor!(concat_grads.get(&a), a_grads.get(&a));
+        assert_close_to_tensor!(concat_grads.get(&b), b_grads.get(&b));
+    }
+
+    #[test]
+    fn test_concat_ax_1() {
+        let dev: TestDevice = Default::default();
+        let a: Tensor<Rank3<2, 2, 4>, TestDtype, _> = dev.sample_normal();
+        let b: Tensor<Rank3<2, 3, 4>, TestDtype, _> = dev.sample_normal();
+        let a_dyn = a
+            .leaky_trace()
+            .try_realize::<(Const<2>, usize, Const<4>)>()
+            .unwrap();
+        let b_dyn = b
+            .clone()
+            .try_realize::<(Const<2>, usize, Const<4>)>()
+            .unwrap();
+        let c = (a_dyn, b_dyn).concat_tensor_along(Axis::<1>);
+        let c = c.try_realize::<(Const<2>, Const<5>, Const<4>)>().unwrap();
+        let a_arr = a.array();
+        let b_arr = b.array();
+        let c_arr = c.array();
+        for i in 0..2 {
+            assert_eq!(c_arr[i][0], a_arr[i][0]);
+            assert_eq!(c_arr[i][1], a_arr[i][1]);
+            assert_eq!(c_arr[i][2], b_arr[i][0]);
+            assert_eq!(c_arr[i][3], b_arr[i][1]);
+            assert_eq!(c_arr[i][4], b_arr[i][2]);
+        }
+        let concat_grads = c.exp().sum().backward();
+        let a_grads = a.leaky_trace().exp().sum().backward();
+        let b_grads = b.leaky_trace().exp().sum().backward();
+        assert_close_to_tensor!(concat_grads.get(&a), a_grads.get(&a));
+        assert_close_to_tensor!(concat_grads.get(&b), b_grads.get(&b));
+    }
+
+    #[test]
+    fn test_concat_ax_2() {
+        let dev: TestDevice = Default::default();
+        let a: Tensor<Rank3<2, 3, 2>, TestDtype, _> = dev.sample_normal();
+        let b: Tensor<Rank3<2, 3, 3>, TestDtype, _> = dev.sample_normal();
+        let a_dyn = a
+            .leaky_trace()
+            .try_realize::<(Const<2>, Const<3>, usize)>()
+            .unwrap();
+        let b_dyn = b
+            .clone()
+            .try_realize::<(Const<2>, Const<3>, usize)>()
+            .unwrap();
+        let c = (a_dyn, b_dyn).concat_tensor_along(Axis::<2>);
+        let c = c.try_realize::<(Const<2>, Const<3>, Const<5>)>().unwrap();
+        let a_arr = a.array();
+        let b_arr = b.array();
+        let c_arr = c.array();
+        for i in 0..2 {
+            for j in 0..3 {
+                assert_eq!(c_arr[i][j][0], a_arr[i][j][0]);
+                assert_eq!(c_arr[i][j][1], a_arr[i][j][1]);
+                assert_eq!(c_arr[i][j][2], b_arr[i][j][0]);
+                assert_eq!(c_arr[i][j][3], b_arr[i][j][1]);
+                assert_eq!(c_arr[i][j][4], b_arr[i][j][2]);
+            }
+        }
+        let concat_grads = c.exp().sum().backward();
+        let a_grads = a.leaky_trace().exp().sum().backward();
+        let b_grads = b.leaky_trace().exp().sum().backward();
+        assert_close_to_tensor!(concat_grads.get(&a), a_grads.get(&a));
+        assert_close_to_tensor!(concat_grads.get(&b), b_grads.get(&b));
+    }
+}
diff --git a/dfdx-core/src/tensor_ops/concat_along/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/concat_tensor_along/webgpu_kernel.rs
similarity index 100%
rename from dfdx-core/src/tensor_ops/concat_along/webgpu_kernel.rs
rename to dfdx-core/src/tensor_ops/concat_tensor_along/webgpu_kernel.rs
diff --git a/dfdx-core/src/tensor_ops/mod.rs b/dfdx-core/src/tensor_ops/mod.rs
index c51040ee..d934b678 100644
--- a/dfdx-core/src/tensor_ops/mod.rs
+++ b/dfdx-core/src/tensor_ops/mod.rs
@@ -163,6 +163,8 @@ mod clamp;
 mod cmp;
 mod concat;
 mod concat_along;
+mod concat_shape_along;
+mod concat_tensor_along;
 mod cos;
 mod div;
 mod dropout;
@@ -224,7 +226,10 @@ pub use clamp::clamp;
 pub use cmp::{eq, ge, gt, le, lt, ne, TryEq, TryGe, TryGt, TryLe, TryLt, TryNe};
 #[allow(deprecated)]
 pub use concat::TryConcat;
+#[allow(deprecated)]
 pub use concat_along::TryConcatAlong;
+pub use concat_shape_along::TryConcatShapeAlong;
+pub use concat_tensor_along::TryConcatTensorAlong;
 pub use cos::cos;
 pub use div::{div, TryDiv};
 pub use dropout::dropout;
diff --git a/dfdx-core/src/tensor_ops/utilities/device.rs b/dfdx-core/src/tensor_ops/utilities/device.rs
index dd89a186..00fa9502 100644
--- a/dfdx-core/src/tensor_ops/utilities/device.rs
+++ b/dfdx-core/src/tensor_ops/utilities/device.rs
@@ -18,8 +18,8 @@ pub trait Device<E: Dtype>:
     + super::super::stack::StackKernel<E>
     + super::super::concat::ConcatKernel<E>
     + super::super::concat::ConcatKernel<usize>
-    + super::super::concat_along::ConcatAlongKernel<E>
-    + super::super::concat_along::ConcatAlongKernel<usize>
+    + super::super::concat_tensor_along::ConcatAlongKernel<E>
+    + super::super::concat_tensor_along::ConcatAlongKernel<usize>
 
     // optimizers
     + super::super::adam::AdamKernel<E>
diff --git a/dfdx/tests/issue_tests.rs b/dfdx/tests/issue_tests.rs
new file mode 100644
index 00000000..23f53e71
--- /dev/null
+++ b/dfdx/tests/issue_tests.rs
@@ -0,0 +1,36 @@
+use dfdx::prelude::*;
+use std::fmt::Debug;
+
+#[test]
+fn test_issue_891() {
+    #[derive(Default, Debug, Clone, Copy, CustomModule)]
+    pub struct Id;
+
+    impl<Input> Module<Input> for Id {
+        type Output = Input;
+        fn try_forward(&self, x: Input) -> Result<Self::Output, Error> {
+            Ok(x)
+        }
+    }
+
+    #[derive(Default, Debug, Clone, Copy, dfdx_derives::CustomModule)]
+    struct ConcatTensorAlong<Ax: Axes<Array = [isize; 1]> + Debug>(pub Ax);
+
+    impl<Input, const N: isize> Module<Input> for ConcatTensorAlong<Axis<N>>
+    where
+        Input: TryConcatTensorAlong<Axis<N>>,
+    {
+        type Output = <Input as TryConcatTensorAlong<Axis<N>>>::Output;
+
+        fn try_forward(&self, x: Input) -> Result<Self::Output, Error> {
+            x.try_concat_tensor_along(Axis)
+        }
+    }
+
+    type Arch = (SplitInto<(Id, Id)>, ConcatTensorAlong<Axis<0>>);
+
+    let dev = Cpu::default();
+    let x = dev.tensor([1.]);
+    let m = dev.build_module::<f32>(Arch::default());
+    let _y: Tensor<Rank1<2>, _, _, _> = m.forward(x);
+}
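
The wrapper in `test_issue_891` is generic over the axis and input, so the same pattern also covers runtime dims, which sidesteps the nightly requirement for adding two `Const` dims. A sketch assuming the `ConcatTensorAlong` module defined above is in scope (illustrative only, not part of the patch):

```rust
use dfdx::prelude::*;

fn concat_runtime_batches() {
    let dev = Cpu::default();
    let a: Tensor<(usize, Const<2>), f32, _> = dev.zeros_like(&(3, Const));
    let b: Tensor<(usize, Const<2>), f32, _> = dev.zeros_like(&(5, Const));

    // The tuple of tensors is the module input, exactly as SplitInto produces it.
    let cat = ConcatTensorAlong(Axis::<0>);
    let y = cat.try_forward((a, b)).unwrap();
    assert_eq!(y.shape(), &(8, Const::<2>));
}
```
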