diff --git a/Cargo.toml b/Cargo.toml
index 68cc915c..97427663 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -8,4 +8,13 @@ safetensors = { version = "0.4.0", default-features = false }
 memmap2 = { version = "0.9.0", default-features = false }
 rand = { version = "0.8.5", default-features = false, features = ["std_rng"] }
 rand_distr = { version = "0.4.3", default-features = false }
-libm = "0.2.8"
\ No newline at end of file
+libm = "0.2.8"
+
+[patch.crates-io]
+crossbeam = { git = "https://github.com/crossbeam-rs/crossbeam", rev = "a57e655eef415c21babddc4ba0217b6ca7acd0a2" }
+crossbeam-epoch = { git = "https://github.com/crossbeam-rs/crossbeam", rev = "a57e655eef415c21babddc4ba0217b6ca7acd0a2" }
+crossbeam-channel = { git = "https://github.com/crossbeam-rs/crossbeam", rev = "a57e655eef415c21babddc4ba0217b6ca7acd0a2" }
+crossbeam-deque = { git = "https://github.com/crossbeam-rs/crossbeam", rev = "a57e655eef415c21babddc4ba0217b6ca7acd0a2" }
+crossbeam-queue = { git = "https://github.com/crossbeam-rs/crossbeam", rev = "a57e655eef415c21babddc4ba0217b6ca7acd0a2" }
+crossbeam-skiplist = { git = "https://github.com/crossbeam-rs/crossbeam", rev = "a57e655eef415c21babddc4ba0217b6ca7acd0a2" }
+crossbeam-utils = { git = "https://github.com/crossbeam-rs/crossbeam", rev = "a57e655eef415c21babddc4ba0217b6ca7acd0a2" }
diff --git a/dfdx-core/Cargo.toml b/dfdx-core/Cargo.toml
index 5309ef7c..0f6cd5c6 100644
--- a/dfdx-core/Cargo.toml
+++ b/dfdx-core/Cargo.toml
@@ -35,7 +35,7 @@ num-traits = { workspace = true }
 safetensors = { workspace = true, optional = true }
 memmap2 = { workspace = true, optional = true }
 half = { version = "2.3.1", optional = true, features = ["num-traits", "rand_distr"] }
-gemm = { version = "0.16.14", default-features = false, optional = true, features = ["rayon"] }
+gemm = { version = "0.17.1", default-features = false, optional = true, features = ["rayon"] }
 rayon = { version = "1.7.0", optional = true }
 libm = { workspace = true }
 wgpu = { version = "0.18.0", features = ["glsl", "spirv"], optional = true }
diff --git a/dfdx-core/src/data/collate.rs b/dfdx-core/src/data/collate.rs
index d38a2a67..5f52d636 100644
--- a/dfdx-core/src/data/collate.rs
+++ b/dfdx-core/src/data/collate.rs
@@ -55,6 +55,7 @@ impl<A, B> Collate for Vec<(A, B)> {
 impl<'a, A, B> Collate for Vec<&'a (A, B)> {
     type Collated = (Vec<&'a A>, Vec<&'a B>);
     fn collated(self) -> Self::Collated {
+        #[allow(clippy::map_identity)]
         self.into_iter().map(|(a, b)| (a, b)).unzip()
     }
 }
diff --git a/dfdx-core/src/lib.rs b/dfdx-core/src/lib.rs
index 31e61643..c126db2c 100644
--- a/dfdx-core/src/lib.rs
+++ b/dfdx-core/src/lib.rs
@@ -128,44 +128,6 @@ pub mod prelude {
     pub use crate::tensor_ops::*;
 }
-/// Sets a CPU `sse` flag to flush denormal floating point numbers to zero. The opposite of this is [keep_denormals()].
-///
-/// Some resources:
-/// 1. [Effects of Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/the-effects-of-using-flush-to-zero-mode?lang=en)
-/// 2. 
[When to use Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/when-to-use-flush-to-zero-mode?lang=en) -pub fn flush_denormals_to_zero() { - #[cfg(all(target_arch = "x86", target_feature = "sse"))] - { - use std::arch::x86::{_MM_FLUSH_ZERO_ON, _MM_SET_FLUSH_ZERO_MODE}; - unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON) } - } - - #[cfg(all(target_arch = "x86_64", target_feature = "sse"))] - { - use std::arch::x86_64::{_MM_FLUSH_ZERO_ON, _MM_SET_FLUSH_ZERO_MODE}; - unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON) } - } -} - -/// Sets a CPU flag to keep denormal floating point numbers. The opposite of this is [flush_denormals_to_zero()]. -/// -/// Some resources: -/// 1. [Effects of Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/the-effects-of-using-flush-to-zero-mode?lang=en) -/// 2. [When to use Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/when-to-use-flush-to-zero-mode?lang=en) -pub fn keep_denormals() { - #[cfg(all(target_arch = "x86", target_feature = "sse"))] - { - use std::arch::x86::{_MM_FLUSH_ZERO_OFF, _MM_SET_FLUSH_ZERO_MODE}; - unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_OFF) } - } - - #[cfg(all(target_arch = "x86_64", target_feature = "sse"))] - { - use std::arch::x86_64::{_MM_FLUSH_ZERO_OFF, _MM_SET_FLUSH_ZERO_MODE}; - unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_OFF) } - } -} - #[cfg(test)] pub(crate) mod tests { pub use num_traits::{Float, NumCast, Zero}; diff --git a/dfdx-core/src/tensor/cache.rs b/dfdx-core/src/tensor/cache.rs index e785cb64..ad18e5c1 100644 --- a/dfdx-core/src/tensor/cache.rs +++ b/dfdx-core/src/tensor/cache.rs @@ -33,21 +33,35 @@ pub(crate) struct AllocationKey { /// valid allocation. When the last value is removed from the list, the key /// is removed. #[derive(Debug)] -pub(crate) struct TensorCache { +pub(crate) struct TensorCache, DeviceDev = ()> { pub(crate) allocations: RwLock>>, pub(crate) enabled: RwLock, + device_dev: DeviceDev, } -impl Default for TensorCache { +impl, DeviceDev: Default> Default for TensorCache { fn default() -> Self { Self { allocations: Default::default(), enabled: RwLock::new(false), + device_dev: DeviceDev::default(), } } } -impl TensorCache { +#[allow(dead_code)] +impl, DeviceDev> TensorCache { + /// Initiate an empty [TensorCache] with a given `device_dev`. + pub(crate) fn new(device_dev: DeviceDev) -> Self { + Self { + allocations: Default::default(), + enabled: RwLock::new(false), + device_dev, + } + } +} + +impl, DeviceDev> TensorCache { /// Returns the number of allocations in the cache. #[allow(unused)] pub(crate) fn len(&self) -> usize { @@ -183,6 +197,60 @@ impl TensorCache { } } +impl, DeviceDev> TensorCache { + /// Deallocates all cached memory on the device and empties the cache. + pub(crate) fn try_clear(&self) -> Result<(), crate::prelude::Error> { + let mut cache = { + #[cfg(not(feature = "no-std"))] + { + self.allocations.write().unwrap() + } + #[cfg(feature = "no-std")] + { + self.allocations.write() + } + }; + + for (&key, allocations) in cache.iter_mut() { + for alloc in allocations.drain(..) { + alloc.dealloc(&key, &self.device_dev); + } + } + cache.clear(); + Ok(()) + } +} + +impl, DeviceDev> Drop for TensorCache { + fn drop(&mut self) { + self.try_clear().unwrap(); + } +} + +/// Functionality internalized by the pointer. +pub(crate) trait CachePtr: Sized { + // by default no deallocation is made for any cache ptr + // ie. 
they leak + /// Deallocates the memory referred by this pointer. + fn dealloc(self, _key: &AllocationKey, _dev: &Dev) {} +} + +impl CachePtr for bool {} +impl CachePtr for u8 {} +impl CachePtr for u16 {} +impl CachePtr for u32 {} +impl CachePtr for u64 {} +impl CachePtr for u128 {} +impl CachePtr for usize {} +impl CachePtr for i8 {} +impl CachePtr for i16 {} +impl CachePtr for i32 {} +impl CachePtr for i64 {} +impl CachePtr for i128 {} +impl CachePtr for isize {} +impl CachePtr for f32 {} +impl CachePtr for f64 {} + #[cfg(test)] mod test { use super::*; diff --git a/dfdx-core/src/tensor/cpu/device.rs b/dfdx-core/src/tensor/cpu/device.rs index d3ce936f..1c6789fc 100644 --- a/dfdx-core/src/tensor/cpu/device.rs +++ b/dfdx-core/src/tensor/cpu/device.rs @@ -25,7 +25,7 @@ pub struct Cpu { /// A thread safe random number generator. pub(crate) rng: Arc>, /// A thread safe cache of memory allocations that can be reused. - pub(crate) cache: Arc>, + pub(crate) cache: Arc>, } impl Default for Cpu { @@ -47,6 +47,45 @@ impl Cpu { } } +/// Unit struct to represent information needed for managing allocations on the Cpu. +#[derive(Clone, Debug, Default)] +pub(crate) struct CpuDevice; + +impl crate::tensor::cache::CachePtr for BytesPtr { + fn dealloc(self, key: &crate::tensor::cache::AllocationKey, _dev: &CpuDevice) { + assert!(key.num_bytes % key.size == 0); + assert!(key.num_bytes < isize::MAX as usize); + let len = key.num_bytes / key.size; + let cap = len; + // SAFETY: + // - "ptr must have been allocated using the global allocator, such as via the alloc::alloc function." + // - ✅ cpu uses global allocator + // - "T needs to have the same alignment as what ptr was allocated with." + // - ✅ we are matching on the alignment below + // - "The size of T times the capacity needs to be the same size as the pointer was allocated with." + // - ✅ covered by `key.num_bytes / key.size` and the `key.num_bytes % key.size == 0` assertion above + // - "length needs to be less than or equal to capacity." + // - ✅ they are equal + // - "The first length values must be properly initialized values of type T." + // - ✅ any bit pattern is valid for unsigned ints used below + // - "capacity needs to be the capacity that the pointer was allocated with." + // - ✅ handled by assertion above (key.num_bytes % key.size == 0) + // - "The allocated size in bytes must be no larger than isize::MAX. See the safety documentation of pointer::offset." + // - ✅ handled by assertion above + debug_assert_eq!(std::alloc::Layout::new::().align(), 1); + debug_assert_eq!(std::alloc::Layout::new::().align(), 2); + debug_assert_eq!(std::alloc::Layout::new::().align(), 4); + debug_assert_eq!(std::alloc::Layout::new::().align(), 8); + match key.alignment { + 1 => unsafe { drop(Vec::from_raw_parts(self.0, len, cap)) }, + 2 => unsafe { drop(Vec::from_raw_parts(self.0 as *mut u16, len, cap)) }, + 4 => unsafe { drop(Vec::from_raw_parts(self.0 as *mut u32, len, cap)) }, + 8 => unsafe { drop(Vec::from_raw_parts(self.0 as *mut u64, len, cap)) }, + _ => unreachable!(), + }; + } +} + /// A [Vec] that can be cloned without allocating new memory. /// When [Drop]ed it will insert it's data into the cache. #[derive(Debug)] @@ -54,7 +93,7 @@ pub struct CachableVec { /// The data stored in this vector. pub(crate) data: Vec, /// A cache of memory allocations that can be reused. 
- pub(crate) cache: Arc>, + pub(crate) cache: Arc>, } impl Clone for CachableVec { @@ -166,45 +205,6 @@ impl Cache for Cpu { } fn try_empty_cache(&self) -> Result<(), Error> { - #[cfg(not(feature = "no-std"))] - let mut cache = self.cache.allocations.write().unwrap(); - #[cfg(feature = "no-std")] - let mut cache = self.cache.allocations.write(); - for (&key, allocations) in cache.iter_mut() { - assert!(key.num_bytes % key.size == 0); - assert!(key.num_bytes < isize::MAX as usize); - let len = key.num_bytes / key.size; - let cap = len; - for alloc in allocations.drain(..) { - // SAFETY: - // - "ptr must have been allocated using the global allocator, such as via the alloc::alloc function." - // - ✅ cpu uses global allocator - // - "T needs to have the same alignment as what ptr was allocated with." - // - ✅ we are matching on the alignment below - // - "The size of T times the capacity needs to be the same size as the pointer was allocated with." - // - ✅ covered by `key.num_bytes / key.size` and the `key.num_bytes % key.size == 0` assertion above - // - "length needs to be less than or equal to capacity." - // - ✅ they are equal - // - "The first length values must be properly initialized values of type T." - // - ✅ any bit pattern is valid for unsigned ints used below - // - "capacity needs to be the capacity that the pointer was allocated with." - // - ✅ handled by assertion above (key.num_bytes % key.size == 0) - // - "The allocated size in bytes must be no larger than isize::MAX. See the safety documentation of pointer::offset." - // - ✅ handled by assertion above - debug_assert_eq!(std::alloc::Layout::new::().align(), 1); - debug_assert_eq!(std::alloc::Layout::new::().align(), 2); - debug_assert_eq!(std::alloc::Layout::new::().align(), 4); - debug_assert_eq!(std::alloc::Layout::new::().align(), 8); - match key.alignment { - 1 => unsafe { drop(Vec::from_raw_parts(alloc.0, len, cap)) }, - 2 => unsafe { drop(Vec::from_raw_parts(alloc.0 as *mut u16, len, cap)) }, - 4 => unsafe { drop(Vec::from_raw_parts(alloc.0 as *mut u32, len, cap)) }, - 8 => unsafe { drop(Vec::from_raw_parts(alloc.0 as *mut u64, len, cap)) }, - _ => unreachable!(), - }; - } - } - cache.clear(); - Ok(()) + self.cache.try_clear() } } diff --git a/dfdx-core/src/tensor/cuda/device.rs b/dfdx-core/src/tensor/cuda/device.rs index de6f7196..fc9c8225 100644 --- a/dfdx-core/src/tensor/cuda/device.rs +++ b/dfdx-core/src/tensor/cuda/device.rs @@ -29,7 +29,7 @@ pub struct Cuda { /// A second stream for kernels to optionally execute on. pub(crate) par_stream: Arc, pub(crate) workspace: Arc>>, - pub(crate) cache: Arc>, + pub(crate) cache: Arc>>, } impl From for Error { @@ -77,6 +77,7 @@ impl Cuda { let cudnn = cudarc::cudnn::Cudnn::new(dev.clone())?; let par_stream = Arc::new(dev.fork_default_stream()?); let workspace = Arc::new(Mutex::new(dev.alloc_zeros::(1)?)); + let cache = Arc::new(TensorCache::new(Arc::clone(&dev))); Ok(Self { cpu, dev, @@ -85,7 +86,7 @@ impl Cuda { cudnn, par_stream, workspace, - cache: Default::default(), + cache, }) } } @@ -100,7 +101,7 @@ impl Cuda { ) -> Result, Error> { let data = self.cache.try_pop::(len).map_or_else( || self.dev.alloc::(len), - |ptr| Ok(self.dev.upgrade_device_ptr(ptr, len)), + |ptr| Ok(self.dev.upgrade_device_ptr(ptr.0, len)), )?; Ok(data) } @@ -122,6 +123,18 @@ impl Cuda { } } +/// A pointer to a bytes on the Cuda device. Used in conjunction with [TensorCache]. 
+#[repr(transparent)] +#[derive(Clone, Debug)] +pub struct CudaBytesPtr(pub(crate) CUdeviceptr); + +impl crate::tensor::cache::CachePtr> for CudaBytesPtr { + fn dealloc(self, key: &crate::tensor::cache::AllocationKey, dev: &Arc) { + let data = unsafe { dev.upgrade_device_ptr::(self.0, key.num_bytes) }; + drop(data); + } +} + /// A [CudaSlice] that can be cloned without allocating new memory. /// When [Drop]ed it will insert it's data into the cache. #[derive(Debug)] @@ -129,7 +142,7 @@ pub struct CachableCudaSlice { /// The actual data. pub(crate) data: CudaSlice, /// A cache of device pointers that can be reused. - pub(crate) cache: Arc>, + pub(crate) cache: Arc>>, } impl Clone for CachableCudaSlice { @@ -142,7 +155,7 @@ impl Clone for CachableCudaSlice { // SAFETY: // 1. we know that ptr is valid for `num_bytes` because it was registered for that. // 2. we are about to set the memory with dtod_copy - let mut slice = unsafe { dev.upgrade_device_ptr(ptr, len) }; + let mut slice = unsafe { dev.upgrade_device_ptr(ptr.0, len) }; dev.dtod_copy(&self.data, &mut slice).unwrap(); slice }, @@ -209,7 +222,7 @@ impl Drop for CachableCudaSlice { let numel = data.len(); // Get access to the raw pointer without freeing it. let ptr = data.leak(); - self.cache.insert::(numel, ptr); + self.cache.insert::(numel, CudaBytesPtr(ptr)); } } } @@ -232,18 +245,7 @@ impl Cache for Cuda { } fn try_empty_cache(&self) -> Result<(), Error> { - #[cfg(not(feature = "no-std"))] - let mut cache = self.cache.allocations.write().unwrap(); - #[cfg(feature = "no-std")] - let mut cache = self.cache.allocations.write(); - for (&key, allocations) in cache.iter_mut() { - for alloc in allocations.drain(..) { - let data = unsafe { self.dev.upgrade_device_ptr::(alloc, key.num_bytes) }; - drop(data); - } - } - cache.clear(); - Ok(()) + self.cache.try_clear() } } diff --git a/dfdx-core/src/tensor/gradients.rs b/dfdx-core/src/tensor/gradients.rs index 86974ec6..d24e2e32 100644 --- a/dfdx-core/src/tensor/gradients.rs +++ b/dfdx-core/src/tensor/gradients.rs @@ -153,7 +153,7 @@ impl> Gradients { #[inline] pub(crate) fn many_and_ref( &mut self, - ls: &Vec>, + ls: &[impl Tensorlike], r: &impl Tensorlike, ) -> (Vec<&mut D::Vec>, &D::Vec) { for i in 0..ls.len() { diff --git a/dfdx-core/src/tensor/webgpu/device.rs b/dfdx-core/src/tensor/webgpu/device.rs index 1c23989b..52acf5e5 100644 --- a/dfdx-core/src/tensor/webgpu/device.rs +++ b/dfdx-core/src/tensor/webgpu/device.rs @@ -109,7 +109,7 @@ pub struct Webgpu { pub(crate) dev: Arc, pub(crate) queue: Arc, - pub(crate) cache: Arc>, + pub(crate) cache: Arc>, pub(crate) cs_cache: Arc>>>, } @@ -297,12 +297,22 @@ impl Webgpu { // } } +/// Unit struct to represent information needed for managing allocations on the WebGpu. +#[derive(Clone, Debug, Default)] +pub(crate) struct WebGpuDevice; + +impl crate::tensor::cache::CachePtr for Buffer { + fn dealloc(self, _key: &crate::tensor::cache::AllocationKey, _dev: &WebGpuDevice) { + drop(self) + } +} + #[derive(Debug)] pub struct CachableBuffer { pub(crate) dev: Arc, pub(crate) queue: Arc, pub(crate) data: Buffer, - pub(crate) cache: Arc>, + pub(crate) cache: Arc>, pub(crate) _phantom: PhantomData, } @@ -397,17 +407,7 @@ impl Cache for Webgpu { } fn try_empty_cache(&self) -> Result<(), Error> { - #[cfg(not(feature = "no-std"))] - let mut cache = self.cache.allocations.write().unwrap(); - #[cfg(feature = "no-std")] - let mut cache = self.cache.allocations.write(); - for (&_key, allocations) in cache.iter_mut() { - for alloc in allocations.drain(..) 
{ - drop(alloc); - } - } - cache.clear(); - Ok(()) + self.cache.try_clear() } } diff --git a/dfdx-core/src/tensor_ops/conv2d/tests.rs b/dfdx-core/src/tensor_ops/conv2d/tests.rs index 85603de8..b7110a22 100644 --- a/dfdx-core/src/tensor_ops/conv2d/tests.rs +++ b/dfdx-core/src/tensor_ops/conv2d/tests.rs @@ -218,10 +218,10 @@ fn test_conv2d_s4p3k2() { #[test] fn test_batched_conv2d() { let dev: TestDevice = Default::default(); - let x: Tensor, TestDtype, _> = dev.sample_normal(); + let x: Tensor, TestDtype, _> = dev.sample_normal(); let w: Tensor, TestDtype, _> = dev.sample_normal(); - let y: Tensor, _, _, _> = + let y: Tensor, _, _, _> = (x.leaky_trace(), w.clone()).conv2d(Const::<3>, Const::<2>, Const::<1>, Const::<1>); let y0 = y.retaped::(); let grads0 = y.square().mean().backward(); @@ -229,11 +229,11 @@ fn test_batched_conv2d() { let w0 = grads0.get(&w); let x = x - .broadcast::, _>() - .reshape::>(); + .broadcast::, _>() + .reshape::>(); assert_eq!(x.strides, x.shape.strides()); - let y: Tensor, _, _, _> = + let y: Tensor, _, _, _> = (x.leaky_trace(), w.clone()).conv2d(Const::<3>, Const::<2>, Const::<1>, Const::<1>); for i in 0..10 { assert_close_to_tensor!(y0, y.retaped::().select(dev.tensor(i))); @@ -245,7 +245,7 @@ fn test_batched_conv2d() { let x_grad = grads.get(&x) * 10.0; for i in 0..10 { - assert_close_to_tensor!(x0, x_grad.clone().select(dev.tensor(i))); + assert_close_to_tensor!(x0, x_grad.clone().select(dev.tensor(i)), 3e-6); } } @@ -405,7 +405,7 @@ fn test_conv2d_grouped() { fn test_conv2d_grouped_slices() { const NUM_GROUPS: usize = 3; let dev: TestDevice = Default::default(); - let x: Tensor, TestDtype, _> = dev.sample_normal(); + let x: Tensor, TestDtype, _> = dev.sample_normal(); let w: Tensor, TestDtype, _> = dev.sample_normal(); let y = (x.leaky_trace(), w.clone()).conv2d( @@ -419,7 +419,7 @@ fn test_conv2d_grouped_slices() { let x_group = x .clone() .slice((.., 3 * i..3 * (i + 1), .., ..)) - .realize::<(Const<2>, Const<3>, Const<14>, Const<14>)>(); + .realize::<(Const<2>, Const<3>, Const<3>, Const<3>)>(); let w_group = w .clone() .slice((5 * i..5 * (i + 1), .., .., ..)) @@ -428,7 +428,7 @@ fn test_conv2d_grouped_slices() { let y_group_true = y .retaped::() .slice((.., 5 * i..5 * (i + 1), .., ..)) - .realize::<(Const<2>, Const<5>, Const<12>, Const<12>)>(); + .realize::<(Const<2>, Const<5>, Const<1>, Const<1>)>(); assert_close_to_tensor!(y_group, y_group_true); } @@ -440,7 +440,7 @@ fn test_conv2d_grouped_slices() { let x_group = x .clone() .slice((.., 3 * i..3 * (i + 1), .., ..)) - .realize::<(Const<2>, Const<3>, Const<14>, Const<14>)>(); + .realize::<(Const<2>, Const<3>, Const<3>, Const<3>)>(); let w_group = w .clone() .slice((5 * i..5 * (i + 1), .., .., ..)) @@ -452,7 +452,7 @@ fn test_conv2d_grouped_slices() { let x_grad_group_true = x_grad .clone() .slice((.., 3 * i..3 * (i + 1), .., ..)) - .realize::<(Const<2>, Const<3>, Const<14>, Const<14>)>(); + .realize::<(Const<2>, Const<3>, Const<3>, Const<3>)>(); let w_grad_group_true = w_grad .clone() .slice((5 * i..5 * (i + 1), .., .., ..)) diff --git a/dfdx-core/src/tensor_ops/convtrans2d/tests.rs b/dfdx-core/src/tensor_ops/convtrans2d/tests.rs index 3d64acbf..c3670294 100644 --- a/dfdx-core/src/tensor_ops/convtrans2d/tests.rs +++ b/dfdx-core/src/tensor_ops/convtrans2d/tests.rs @@ -280,10 +280,10 @@ fn test_convtrans2d_padded() { #[test] fn test_convtrans2d_batched() { let dev: TestDevice = Default::default(); - let x: Tensor, TestDtype, _> = dev.sample_normal(); + let x: Tensor, TestDtype, _> = dev.sample_normal(); let 
w: Tensor, TestDtype, _> = dev.sample_normal(); - let y: Tensor, _, _, _> = + let y: Tensor, _, _, _> = (x.leaky_trace(), w.clone()).convtrans2d(Const::<3>, Const::<2>, Const::<1>, Const::<1>); let y0 = y.retaped::(); let grads0 = y.square().mean().backward(); @@ -291,10 +291,10 @@ fn test_convtrans2d_batched() { let w0 = grads0.get(&w); let x = x - .broadcast::, _>() - .reshape::>(); + .broadcast::, _>() + .reshape::>(); - let y: Tensor, _, _, _> = + let y: Tensor, _, _, _> = (x.leaky_trace(), w.clone()).convtrans2d(Const::<3>, Const::<2>, Const::<1>, Const::<1>); for i in 0..10 { assert_close_to_tensor!(y0, y.retaped::().select(dev.tensor(i)), 1e-5); diff --git a/dfdx-core/src/tensor_ops/log_softmax.rs b/dfdx-core/src/tensor_ops/log_softmax.rs index 487c33e5..d98bc330 100644 --- a/dfdx-core/src/tensor_ops/log_softmax.rs +++ b/dfdx-core/src/tensor_ops/log_softmax.rs @@ -81,7 +81,7 @@ mod tests { #[test] fn test_log_softmax_equivalence() { let dev: TestDevice = Default::default(); - let t: Tensor, TestDtype, _> = dev.sample_normal(); + let t: Tensor, TestDtype, _> = dev.sample_normal(); let p = t.leaky_trace().log_softmax::>(); let p_truth = t.leaky_trace() - t.leaky_trace().logsumexp::<_, Axis<3>>().broadcast(); // we can't create an array as it will overflow the stack diff --git a/dfdx-core/src/tensor_ops/matmul/cpu_kernel.rs b/dfdx-core/src/tensor_ops/matmul/cpu_kernel.rs index bf3e6ce0..e22af5a0 100644 --- a/dfdx-core/src/tensor_ops/matmul/cpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/matmul/cpu_kernel.rs @@ -90,7 +90,7 @@ impl MatMulImpl> for Cpu { false, false, false, - gemm::Parallelism::Rayon(rayon::current_num_threads()), + gemm::Parallelism::None, ) } } @@ -138,7 +138,7 @@ impl MatMulImpl for Cpu { false, false, false, - gemm::Parallelism::Rayon(rayon::current_num_threads()), + gemm::Parallelism::None, ) } } @@ -180,7 +180,7 @@ impl MatMulImpl for Cpu { false, false, false, - gemm::Parallelism::Rayon(rayon::current_num_threads()), + gemm::Parallelism::None, ) } } @@ -222,7 +222,7 @@ impl MatMulImpl for Cpu { false, false, false, - gemm::Parallelism::Rayon(rayon::current_num_threads()), + gemm::Parallelism::None, ) } } diff --git a/dfdx-core/src/tensor_ops/matmul/mod.rs b/dfdx-core/src/tensor_ops/matmul/mod.rs index 5e4d03b3..d133b9ab 100644 --- a/dfdx-core/src/tensor_ops/matmul/mod.rs +++ b/dfdx-core/src/tensor_ops/matmul/mod.rs @@ -346,21 +346,21 @@ mod tests { } { - let a: Tensor, TestDtype, _> = dev.zeros(); + let a: Tensor, TestDtype, _> = dev.zeros(); let b: Tensor, TestDtype, _> = dev.zeros(); - let _: Tensor, TestDtype, _> = a.matmul(b); + let _: Tensor, TestDtype, _> = a.matmul(b); } { - let a: Tensor, TestDtype, _> = dev.zeros(); - let b: Tensor, TestDtype, _> = dev.zeros(); - let _: Tensor, TestDtype, _> = a.matmul(b); + let a: Tensor, TestDtype, _> = dev.zeros(); + let b: Tensor, TestDtype, _> = dev.zeros(); + let _: Tensor, TestDtype, _> = a.matmul(b); } { - let a: Tensor, TestDtype, _> = dev.zeros(); - let b: Tensor, TestDtype, _> = dev.zeros(); - let _: Tensor, TestDtype, _> = a.matmul(b); + let a: Tensor, TestDtype, _> = dev.zeros(); + let b: Tensor, TestDtype, _> = dev.zeros(); + let _: Tensor, TestDtype, _> = a.matmul(b); } } @@ -427,7 +427,7 @@ mod tests { #[test] fn test_matmul_broadcast() { - const N: usize = 5; + const N: usize = 2; let dev: TestDevice = Default::default(); let a: Tensor, TestDtype, _> = dev.sample_normal(); let a_array = a.array(); @@ -458,7 +458,7 @@ mod tests { #[test] fn test_matmul_broadcast_actual() { - const N: usize = 5; + const N: 
usize = 2; let dev: TestDevice = Default::default(); let a: Tensor, TestDtype, _> = dev.sample_normal(); let b: Tensor, TestDtype, _> = dev.sample_normal(); @@ -476,9 +476,9 @@ mod tests { fn test_matmul_batched_3d() { let dev: TestDevice = Default::default(); - let a: Tensor, TestDtype, _> = dev.sample_normal(); + let a: Tensor, TestDtype, _> = dev.sample_normal(); let a_array = a.array(); - let b: Tensor, TestDtype, _> = dev.sample_normal(); + let b: Tensor, TestDtype, _> = dev.sample_normal(); let b_array = b.array(); let c = a.leaky_trace().matmul(b.clone()); let c_array = c.array(); @@ -487,7 +487,7 @@ mod tests { let g_a = g.get(&a).array(); let g_b = g.get(&b).array(); - for i in 0..5 { + for i in 0..2 { let sub_a = dev.tensor(a_array[i]); let sub_b = dev.tensor(b_array[i]); let sub_c = sub_a.leaky_trace().matmul(sub_b.clone()); @@ -502,9 +502,9 @@ mod tests { fn test_matmul_batched_4d() { let dev: TestDevice = Default::default(); - let a: Tensor, TestDtype, _> = dev.sample_normal(); + let a: Tensor, TestDtype, _> = dev.sample_normal(); let a_array = a.array(); - let b: Tensor, TestDtype, _> = dev.sample_normal(); + let b: Tensor, TestDtype, _> = dev.sample_normal(); let b_array = b.array(); let c = a.leaky_trace().matmul(b.clone()); let c_array = c.array(); @@ -513,8 +513,8 @@ mod tests { let g_a = g.get(&a).array(); let g_b = g.get(&b).array(); - for i in 0..7 { - for j in 0..5 { + for i in 0..2 { + for j in 0..3 { let sub_a = dev.tensor(a_array[i][j]); let sub_b = dev.tensor(b_array[i][j]); let sub_c = sub_a.leaky_trace().matmul(sub_b.clone()); diff --git a/dfdx-core/src/tensor_ops/softmax.rs b/dfdx-core/src/tensor_ops/softmax.rs index 0a6ec8aa..a45436c8 100644 --- a/dfdx-core/src/tensor_ops/softmax.rs +++ b/dfdx-core/src/tensor_ops/softmax.rs @@ -91,7 +91,7 @@ mod tests { #[test] fn test_softmax_equivalence() { let dev: TestDevice = Default::default(); - let t: Tensor, TestDtype, _> = dev.sample_normal(); + let t: Tensor, TestDtype, _> = dev.sample_normal(); let p = t.leaky_trace().softmax::>(); let p_truth = t.leaky_trace().log_softmax::>().exp(); // we can't create an array as it will overflow the stack diff --git a/dfdx-core/src/tensor_ops/utilities/device.rs b/dfdx-core/src/tensor_ops/utilities/device.rs index 8cbc2137..91f87cf6 100644 --- a/dfdx-core/src/tensor_ops/utilities/device.rs +++ b/dfdx-core/src/tensor_ops/utilities/device.rs @@ -114,25 +114,49 @@ pub trait Device: + crate::tensor_ops::axpy::AxpyKernel // conv1d - + super::super::conv1d::Conv1DKernel + + NonCudnnCuda +{ +} + +#[cfg(feature = "cudnn")] +pub trait NonCudnnCuda {} + +#[cfg(not(feature = "cudnn"))] +pub trait NonCudnnCuda: + // conv1d + super::super::conv1d::Conv1DKernel { } #[cfg(feature = "f16")] -impl Device for crate::tensor::Cpu {} -#[cfg(feature = "f16")] -impl Device> for crate::tensor::Cpu {} +mod f16_ { + use super::*; + impl Device for crate::tensor::Cpu {} + impl NonCudnnCuda for crate::tensor::Cpu {} + impl Device> for crate::tensor::Cpu {} + impl NonCudnnCuda> for crate::tensor::Cpu {} +} impl Device for crate::tensor::Cpu {} +impl NonCudnnCuda for crate::tensor::Cpu {} impl Device for crate::tensor::Cpu {} +impl NonCudnnCuda for crate::tensor::Cpu {} #[cfg(all(feature = "cuda", feature = "f16"))] -impl Device for crate::tensor::Cuda {} -#[cfg(all(feature = "cuda", feature = "f16"))] -impl Device> for crate::tensor::Cuda {} -#[cfg(feature = "cuda")] -impl Device for crate::tensor::Cuda {} +mod cuda_f16 { + use super::*; + impl Device for crate::tensor::Cuda {} + impl NonCudnnCuda for 
crate::tensor::Cuda {} + impl Device> for crate::tensor::Cuda {} + impl NonCudnnCuda> for crate::tensor::Cuda {} +} #[cfg(feature = "cuda")] -impl Device for crate::tensor::Cuda {} +mod cuda { + use super::*; + impl Device for crate::tensor::Cuda {} + impl NonCudnnCuda for crate::tensor::Cuda {} + impl Device for crate::tensor::Cuda {} + impl NonCudnnCuda for crate::tensor::Cuda {} +} // TODO: How can we implement this for f16 when WGSL doesn't support f16 yet? // #[cfg(all(feature = "webgpu", feature = "f16"))] @@ -140,7 +164,11 @@ impl Device for crate::tensor::Cuda {} // #[cfg(all(feature = "webgpu", feature = "f16"))] // impl Device> for crate::tensor::Webgpu {} #[cfg(feature = "webgpu")] -impl Device for crate::tensor::Webgpu {} +mod webgpu { + use super::*; + impl Device for crate::tensor::Webgpu {} + impl NonCudnnCuda for crate::tensor::Webgpu {} +} // TODO: How can we implement this for f64 when WGSL doesn't support f64 yet? // #[cfg(feature = "webgpu")] diff --git a/dfdx/examples/12-mnist.rs b/dfdx/examples/12-mnist.rs index 705d14c8..00d43452 100644 --- a/dfdx/examples/12-mnist.rs +++ b/dfdx/examples/12-mnist.rs @@ -62,9 +62,6 @@ type Mlp = ( const BATCH_SIZE: usize = 32; fn main() { - // ftz substantially improves performance - dfdx::flush_denormals_to_zero(); - let mnist_path = std::env::args() .nth(1) .unwrap_or_else(|| "./datasets/MNIST/raw".to_string()); diff --git a/dfdx/src/nn/layers/conv1d.rs b/dfdx/src/nn/layers/conv1d.rs index 5241b912..0986d1af 100644 --- a/dfdx/src/nn/layers/conv1d.rs +++ b/dfdx/src/nn/layers/conv1d.rs @@ -174,47 +174,50 @@ mod tests { fn test_grouped_forward_sizes() { let dev: TestDevice = Default::default(); - let x = dev.ones::>(); + let x = dev.ones::>(); - let m = dev.build_module::(>::default()); - let _: Tensor, _, _> = m.weight; - let _: Tensor, _, _> = m.forward(x.clone()); + let m = dev.build_module::(>::default()); + let _: Tensor, _, _> = m.weight; + let _: Tensor, _, _> = m.forward(x.clone()); - let m = - dev.build_module::(>::default()); - let _: Tensor, _, _> = m.weight; - let _: Tensor, _, _> = m.forward(x.clone()); + let m = dev.build_module::(>::default()); + let _: Tensor, _, _> = m.weight; + let _: Tensor, _, _> = m.forward(x.clone()); - let m = - dev.build_module::(>::default()); - let _: Tensor, _, _> = m.weight; - let _: Tensor, _, _> = m.forward(x.clone()); + let x = dev.ones::>(); - let m = - dev.build_module::(>::default()); - let _: Tensor, _, _> = m.weight; - let _: Tensor, _, _> = m.forward(x.clone()); + let m = dev.build_module::(>::default()); + let _: Tensor, _, _> = m.weight; + let _: Tensor, _, _> = m.forward(x.clone()); + + let x = dev.ones::>(); + + let m = dev.build_module::(>::default()); + let _: Tensor, _, _> = m.weight; + let _: Tensor, _, _> = m.forward(x.clone()); + + let x = dev.ones::>(); let m = dev.build_module::(>::default()); let _: Tensor, _, _> = m.weight; - let _: Tensor, _, _> = m.forward(x); + let _: Tensor, _, _> = m.forward(x); } #[rustfmt::skip] #[test] fn test_forward_4d_sizes() { let dev: TestDevice = Default::default(); - let x = dev.zeros::>(); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = 
dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let x = dev.zeros::>(); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); } #[test] @@ -248,7 +251,7 @@ mod tests { let weight_init = m.weight.clone(); let mut opt = crate::nn::optim::Sgd::new(&m, Default::default()); - let out = m.forward(dev.sample_normal::>().leaky_trace()); + let out = m.forward(dev.sample_normal::>().leaky_trace()); let g = out.square().mean().backward(); assert_ne!(g.get(&m.weight).array(), [[[TestDtype::zero(); 3]; 2]; 4]); diff --git a/dfdx/src/nn/layers/conv2d.rs b/dfdx/src/nn/layers/conv2d.rs index c88ea821..a4cd0b5e 100644 --- a/dfdx/src/nn/layers/conv2d.rs +++ b/dfdx/src/nn/layers/conv2d.rs @@ -197,48 +197,53 @@ mod tests { fn test_grouped_forward_sizes() { let dev: TestDevice = Default::default(); - let x = dev.zeros::>(); + let x = dev.zeros::>(); - let m = - dev.build_module::(>::default()); - let _: Tensor, _, _> = m.weight; - let _: Tensor, _, _> = m.forward(x.clone()); + let m = dev.build_module::(>::default()); + let _: Tensor, _, _> = m.weight; + let _: Tensor, _, _> = m.forward(x.clone()); - let m = - dev.build_module::(>::default()); - let _: Tensor, _, _> = m.weight; - let _: Tensor, _, _> = m.forward(x.clone()); + let m = dev.build_module::(>::default()); + let _: Tensor, _, _> = m.weight; + let _: Tensor, _, _> = m.forward(x.clone()); - let m = - dev.build_module::(>::default()); - let _: Tensor, _, _> = m.weight; - let _: Tensor, _, _> = m.forward(x.clone()); + let x = dev.zeros::>(); - let m = - dev.build_module::(>::default()); - let _: Tensor, _, _> = m.weight; - let _: Tensor, _, _> = m.forward(x.clone()); + let m = dev.build_module::(>::default()); + let _: Tensor, _, _> = m.weight; + let _: Tensor, _, _> = m.forward(x.clone()); + + let x = dev.zeros::>(); + + let m = dev.build_module::(>::default()); + let _: Tensor, _, _> = m.weight; + let _: Tensor, _, _> = m.forward(x.clone()); + + let x = dev.zeros::>(); let m = dev.build_module::(>::default()); let _: Tensor, _, _> = m.weight; - let _: Tensor, _, _> = m.forward(x); + let _: Tensor, _, _> = m.forward(x); } #[rustfmt::skip] #[test] fn test_forward_4d_sizes() { let dev: TestDevice = Default::default(); - let x = dev.zeros::>(); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let 
_: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let x = dev.zeros::>(); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let x = dev.zeros::>(); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let x = dev.zeros::>(); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let x = dev.zeros::>(); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); } #[test] @@ -267,17 +272,17 @@ mod tests { fn test_conv_with_optimizer() { let dev: TestDevice = Default::default(); - let mut m = dev.build_module::(Conv2DConstConfig::<2, 4, 3>::default()); + let mut m = dev.build_module::(Conv2DConstConfig::<2, 3, 2>::default()); let weight_init = m.weight.clone(); let mut opt = crate::nn::optim::Sgd::new(&m, Default::default()); - let out = m.forward(dev.sample_normal::>().leaky_trace()); + let out = m.forward(dev.sample_normal::>().leaky_trace()); let g = out.square().mean().backward(); assert_ne!( g.get(&m.weight).array(), - [[[[TestDtype::zero(); 3]; 3]; 2]; 4] + [[[[TestDtype::zero(); 2]; 2]; 2]; 3] ); opt.update(&mut m, &g).expect("unused params"); diff --git a/dfdx/src/nn/layers/conv_trans2d.rs b/dfdx/src/nn/layers/conv_trans2d.rs index b7683676..943f3d85 100644 --- a/dfdx/src/nn/layers/conv_trans2d.rs +++ b/dfdx/src/nn/layers/conv_trans2d.rs @@ -180,16 +180,24 @@ mod tests { #[test] fn test_forward_4d_sizes() { let dev: TestDevice = Default::default(); - let x = dev.zeros::>(); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + + let x = dev.zeros::>(); + + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + + let x = dev.zeros::>(); + + let _: 
Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + + let x = dev.zeros::>(); + + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); } #[test] @@ -225,7 +233,7 @@ mod tests { let weight_init = m.weight.clone(); let mut opt = crate::nn::optim::Sgd::new(&m, Default::default()); - let out = m.forward(dev.sample_normal::>().leaky_trace()); + let out = m.forward(dev.sample_normal::>().leaky_trace()); let g = out.square().mean().backward(); assert_ne!( diff --git a/dfdx/src/nn/layers/multi_head_attention.rs b/dfdx/src/nn/layers/multi_head_attention.rs index 1232b433..fba2d08c 100644 --- a/dfdx/src/nn/layers/multi_head_attention.rs +++ b/dfdx/src/nn/layers/multi_head_attention.rs @@ -208,11 +208,11 @@ mod tests { fn test_mha_batched() { let dev = TestDevice::seed_from_u64(1); - const BATCH: usize = 5; - const M: usize = 8; + const BATCH: usize = 2; + const M: usize = 4; const NUM_HEADS: usize = 2; const S1: usize = 3; - const S2: usize = 4; + const S2: usize = 2; type Dtype = f32; @@ -224,42 +224,37 @@ mod tests { let k: Tensor, Dtype, _> = dev.sample_normal(); let v: Tensor, Dtype, _> = dev.sample_normal(); - let y = mha.forward((q, k, v)); + // uncomment to save for this specific test params and inputs + // + // mha.save_safetensors("mha.safetensor").unwrap(); + // q.save_safetensors("q.safetensor").unwrap(); + // k.save_safetensors("k.safetensor").unwrap(); + // v.save_safetensors("v.safetensor").unwrap(); + + let y = mha.forward((q.clone(), k.clone(), v.clone())); + + // uncomment to save for this specific test params and inputs + // + // y.save_safetensors("y.safetensor").unwrap(); // This expected y was generated by: // 1. saving `mha` parameters, `q`, `k`, `v` to a file // 2. Running pytorch with the same values // 3. 
printing out the output // See https://github.com/coreylowman/dfdx/wiki/Exporting-MultiHeadAttention-to-pytorch-for-unit-tests - #[rustfmt::skip] assert_close_to_literal!( y, [ [ - [-0.32666653, 0.23977730, 0.25563523,-0.46537930, 0.19651681,-0.37467819, 0.44978297, 0.04501118], - [-0.32847843, 0.22905068, 0.24268147,-0.49660331, 0.17547092,-0.41919118, 0.45197228,-0.01052883], - [-0.28976738, 0.26420441, 0.24134403,-0.41927847, 0.21895495,-0.35072452, 0.44843924, 0.07374063], + [-0.16630043, 0.01757687, 0.22978050, 0.50355506], + [-0.19439587, 0.02942148, 0.23266082, 0.48612449], + [-0.19675586, 0.06542480, 0.18101424, 0.43833256] ], [ - [-0.10029950, 0.15455982, 0.23578438,-0.36703593, 0.03778699,-0.41743413, 0.50207543, 0.11432818], - [-0.04076880, 0.24567264, 0.23325926,-0.19454414, 0.11575195,-0.22209120, 0.49752438, 0.30388331], - [-0.06600001, 0.20277922, 0.24651963,-0.24732135, 0.08645092,-0.28015324, 0.49499762, 0.23243824], - ], - [ - [-0.18352799, 0.15783942, 0.36657059,-0.24797240, 0.11065251,-0.22565264, 0.46300891, 0.18687661], - [-0.15986431, 0.26687002, 0.30500177,-0.22695602, 0.18453379,-0.21377291, 0.46498343, 0.30064404], - [-0.09165541, 0.31019136, 0.20057595,-0.29627919, 0.15811513,-0.33667034, 0.48559439, 0.32546705], - ], - [ - [-0.45827997, 0.08988418, 0.44279462,-0.45245945, 0.16884868,-0.26618001, 0.40024126, 0.01272556], - [-0.43258160, 0.11801003, 0.42784777,-0.41539627, 0.19628736,-0.23836099, 0.39999473, 0.05304383], - [-0.44729146, 0.09233949, 0.45179683,-0.41795415, 0.16631508,-0.22713992, 0.39473629, 0.04260518], - ], - [ - [-0.51776350, 0.05404706, 0.39951840,-0.61738086, 0.21067555,-0.51225299, 0.41040331,-0.25894681], - [-0.47914022, 0.09410305, 0.36355501,-0.59280866, 0.24956036,-0.50058168, 0.40235144,-0.16756263], - [-0.55189615,-0.06088167, 0.41224611,-0.76746291, 0.09680001,-0.70136547, 0.40278757,-0.45541200], - ], + [-0.23499183, -0.21414454, 0.32811928, 0.46780989], + [-0.25318044, -0.20085460, 0.37180322, 0.52941465], + [-0.22117066, -0.23581570, 0.36783585, 0.53560883] + ] ] ); } @@ -269,11 +264,11 @@ mod tests { let dev: TestDevice = Default::default(); let mut mha = dev - .build_module::(, Const<4>>>::default()); + .build_module::(, Const<2>>>::default()); - let q: Tensor, TestDtype, _> = dev.sample_normal(); - let k: Tensor, TestDtype, _> = dev.sample_normal(); - let v: Tensor, TestDtype, _> = dev.sample_normal(); + let q: Tensor, TestDtype, _> = dev.sample_normal(); + let k: Tensor, TestDtype, _> = dev.sample_normal(); + let v: Tensor, TestDtype, _> = dev.sample_normal(); let y = mha.forward((q.leaky_trace(), k, v)); let g = y.square().mean().backward(); diff --git a/dfdx/src/nn/layers/transformer.rs b/dfdx/src/nn/layers/transformer.rs index fa7ab76a..a0e50a30 100644 --- a/dfdx/src/nn/layers/transformer.rs +++ b/dfdx/src/nn/layers/transformer.rs @@ -204,38 +204,31 @@ mod tests { fn test_transformer_forward() { let dev = TestDevice::seed_from_u64(0); let mut t = dev.build_module::(TransformerConfig::new( - Const::<16>, - Const::<4>, - Const::<8>, - 3, - 3, + Const::<2>, Const::<2>, Const::<2>, 2, 2, )); // unbatched - let src = dev.sample_normal::>(); - let tgt = dev.sample_normal::>(); - let _: Tensor, _, _, _> = t.forward_mut((src, tgt)); + let src = dev.sample_normal::>(); + let tgt = dev.sample_normal::>(); + let _: Tensor, _, _, _> = t.forward_mut((src, tgt)); // batched - let src = dev.sample_normal::>(); - let tgt = dev.sample_normal::>(); - let _: Tensor, _, _, _> = t.forward_mut((src, tgt)); + let src = dev.sample_normal::>(); + let 
tgt = dev.sample_normal::>(); + let _: Tensor, _, _, _> = t.forward_mut((src, tgt)); } #[test] fn test_transformer_backward() { let dev = TestDevice::seed_from_u64(0); + let mut t = dev.build_module::(TransformerConfig::new( - Const::<16>, - Const::<4>, - Const::<8>, - 3, - 3, + Const::<2>, Const::<2>, Const::<2>, 2, 2, )); - let src = dev.sample_normal::>(); - let tgt = dev.sample_normal::>(); - let out: Tensor, _, _, _> = t.forward_mut((src.leaky_trace(), tgt)); + let src = dev.sample_normal::>(); + let tgt = dev.sample_normal::>(); + let out: Tensor, _, _, _> = t.forward_mut((src.leaky_trace(), tgt)); let g = out.mean().backward(); let mut opt = crate::nn::optim::Sgd::new(&t, Default::default()); @@ -246,11 +239,11 @@ mod tests { fn test_encoder_block_forward() { let dev = TestDevice::seed_from_u64(2); - const BATCH: usize = 3; - const SEQ_LEN: usize = 5; - const EMBED_DIM: usize = 9; - const NUM_HEADS: usize = 3; - const FF_DIM: usize = 16; + const BATCH: usize = 2; + const SEQ_LEN: usize = 3; + const EMBED_DIM: usize = 4; + const NUM_HEADS: usize = 2; + const FF_DIM: usize = 2; type Dtype = f32; @@ -261,38 +254,36 @@ mod tests { )); let x: Tensor, Dtype, _> = dev.sample_normal(); + + // uncomment to save for this specific test params and inputs + // + // encoder.save_safetensors("encoder.safetensor").unwrap(); + // x.save_safetensors("x.safetensor").unwrap(); + let y = encoder.forward(x); + // uncomment to save for this specific test params and inputs + // + // y.save_safetensors("y.safetensor").unwrap(); + // This expected y was generated by: // 1. saving `encoder` parameters, `x` and `y` to a npz files // 2. Running pytorch with the same values // 3. printing out the output // See https://github.com/coreylowman/dfdx/wiki/Exporting-MultiHeadAttention-to-pytorch-for-unit-tests - #[rustfmt::skip] assert_close_to_literal!( y, [ [ - [0.83316803, 0.85057360, 0.37431455, 1.48506296,-0.38405111,-1.89352179,-1.07049453,-0.50913972, 0.31408834], - [-0.57205188, 0.64078861,-0.56589824, 0.67155081, 0.65419787, 0.28409126,-1.75282931, 1.68111539,-1.04096484], - [-0.01414229, 1.34985816, 0.09684382, 0.13165890,-1.39875984,-1.61741352, 1.28747427, 0.75574619,-0.59126562], - [0.12542287, 2.60457349, 0.21064451,-0.81285846,-0.15861531,-0.87273139,-0.81707120,-0.17004849,-0.10931605], - [-1.54970682,-0.77183282, 1.37495196,-0.69562960,-0.66684282, 0.24720824, 1.38581741,-0.35962212, 1.03565681], + [-1.7209842, 0.6216407, 0.7037436, 0.39559996], + [0.53576326, -1.4666773, 1.2166189, -0.28570476], + [-1.3280064, 0.42387456, -0.45566577, 1.3597975] ], [ - [-0.15229249,-0.90768278,-0.85165489, 0.12768827, 1.61459768, 1.25826979,-0.46860829, 0.87496787,-1.49528503], - [-1.35595357, 1.13305736,-0.08542954, 1.01601434,-0.04678532,-1.69470263, 0.76144469,-0.68443829, 0.95679283], - [-1.49877191, 0.64559501, 0.33383703, 1.73698330,-0.14289393, 1.17869902,-1.01659226,-0.61038357,-0.62647283], - [0.78263682, 0.78481543,-0.16064386, 1.03396618, 1.49144781,-1.55002558,-1.11833119,-0.62120575,-0.64265978], - [-1.58957553, 1.75000548, 0.01272983, 0.11212827,-0.34744453,-1.45086825, 0.95842224, 0.50071126, 0.05389150], - ], - [ - [-1.13160479,-0.21202824, 0.25907388,-0.64313424,-0.76302397,-0.16797650,-0.75345570, 2.01765633, 1.39449334], - [-0.16463053,-0.73241645,-0.69120175, 0.13771832, 0.72443259,-2.06525135, 1.02475107, 1.40244913, 0.36414924], - [0.38766465,-0.19543301,-1.80767059, 1.11545098, 0.21692322,-1.22834778, 0.13580292, 1.63094711,-0.25533777], - [1.22877085, 0.05472810, 0.65142977, 
0.73869365,-0.74706972,-1.29277837, 1.07350135, 0.06228387,-1.76955938], - [-0.01733636,-1.57447529, 0.79691470, 1.00687420, 1.65637493,-0.75668150,-0.54616517, 0.45799020,-1.02349579], - ], + [0.89139193, -1.2803736, 1.0577338, -0.668752], + [-0.41001588, 1.6245831, -1.084222, -0.13034514], + [0.9247901, -1.1639801, -0.8187512, 1.0579412] + ] ] ); } @@ -301,11 +292,11 @@ mod tests { fn test_decoder_block_forward() { let dev = TestDevice::seed_from_u64(2); - const BATCH: usize = 4; - const S1: usize = 8; - const S2: usize = 6; - const EMBED_DIM: usize = 12; - const NUM_HEADS: usize = 6; + const BATCH: usize = 2; + const S1: usize = 3; + const S2: usize = 2; + const EMBED_DIM: usize = 4; + const NUM_HEADS: usize = 2; const FF_DIM: usize = 2; type Dtype = f32; @@ -318,57 +309,39 @@ mod tests { let tgt: Tensor, Dtype, _> = dev.sample_normal(); let mem: Tensor, Dtype, _> = dev.sample_normal(); + + // uncomment to save for this specific test params and inputs + // + // decoder.save_safetensors("decoder.safetensor").unwrap(); + // tgt.save_safetensors("tgt.safetensor").unwrap(); + // mem.save_safetensors("mem.safetensor").unwrap(); + let y = decoder.forward((tgt, mem)); + // uncomment to save for this specific test params and inputs + // + // y.save_safetensors("y.safetensor").unwrap(); + + println!("{:?}", y.array()); + // This expected y was generated by: // 1. saving `decoder` parameters, `tgt`, `mem` and `y` to a npz files // 2. Running pytorch with the same values // 3. printing out the output // See https://github.com/coreylowman/dfdx/wiki/Exporting-MultiHeadAttention-to-pytorch-for-unit-tests - #[rustfmt::skip] assert_close_to_literal!( y, [ [ - [-1.87558722, 0.45965099, 0.20498508,-1.73645127, 1.19475269,-0.07198015, 1.87802076, 0.18534835, 0.09591459,-0.19824848,-0.35261178, 0.21620668], - [-1.65146410, 0.36979428, 2.44077325, 0.06124005,-1.35236311, 0.06834260, 0.15826070,-0.82507777, 0.37757808, 0.65084165,-0.26028851,-0.03763753], - [-0.30696073,-0.83636290, 1.20258296, 0.11318116, 2.23617601,-0.58318114, 0.66371393,-0.26198950,-0.46798199,-1.64899850, 0.63527161,-0.74545103], - [-0.23854624,-1.12693906, 1.16869855,-0.19282928, 1.83873713,-0.11721543, 1.00944722,-0.97332841,-0.75959450,-0.69980252, 1.23692346,-1.14555120], - [1.36781275,-1.00360036,-0.45941362, 1.16563404, 0.24138503, 0.51682448,-0.20305091,-0.68849629, 0.21949562,-2.32909155, 1.11119950, 0.06130134], - [-0.70381856, 1.24304760, 1.32746470, 0.43500248,-1.45963287,-0.33785006, 0.95192397,-0.72454590,-0.56011575,-1.33778274, 1.46311414,-0.29680732], - [-0.72720474,-1.29362297, 0.24656427, 0.25788289,-1.20061839, 0.20161679,-0.18183309,-0.28182927, 1.85331190,-0.41204709, 2.05122447,-0.51344484], - [-0.45356780, 1.31273413, 0.69735909,-1.96937740, 0.33488208,-0.99047261, 0.59060574,-0.65752614, 1.89437556,-0.41522720,-0.09553659,-0.24824893], - ], - [ - [0.92695564,-0.37954834, 0.74523187, 0.91893858, 0.26190025,-1.12540352, 0.87693417,-0.56255865, 0.20910029,-2.21528411, 1.21251309,-0.86877924], - [-0.94927889,-1.28225541, 1.38664925,-0.47819123, 1.60083365,-0.25243780, 1.21168947,-0.77403182, 0.60282439,-0.67139530, 0.72949010,-1.12389636], - [0.32318670, 0.44635653, 0.69037175,-2.00356507, 0.31796345,-1.09540510, 1.65720248, 0.18892130, 0.52996045,-0.80869401, 0.91539401,-1.16169262], - [-0.93624949, 0.90174866,-0.35485053, 0.28630549,-0.67549163,-1.74944031, 0.75101191, 0.73161471, 2.11734390,-0.91214812, 0.20135719,-0.36120197], - [-0.12938653,-0.65747797, 2.05397773,-1.01142454,-0.12065405,-2.02726126, 0.42845321, 
0.56529117, 1.02239680, 0.41882706, 0.12460811,-0.66735017], - [1.61325872, 1.18383896, 0.58100909,-1.39098096,-0.86362296, 0.16341744,-0.44804084,-0.85499638,-0.94598162, 0.20620863, 1.56031752,-0.80442756], - [0.15400597, 0.30694833,-0.10923728,-1.54726267, 2.59482384,-0.72448921,-0.47337827, 0.94458705,-0.74652761, 0.43154043,-0.49556813,-0.33544219], - [0.06703589,-1.33028281, 1.29519308, 0.01789100, 1.73138475, 0.11349702, 0.98292470,-1.37452459,-0.57708341,-0.04158162, 0.54672015,-1.43117404], - ], - [ - [-1.13928354,-0.41951340, 1.02809525, 1.10831285,-0.37338197, 0.62760144,-0.49609870, 0.89603722, 0.28748062,-2.46635914, 0.32486960, 0.62223953], - [0.66343045, 0.17840990,-0.32520610,-0.91180247,-1.24669814, 0.98684084, 1.03520977,-0.66813290, 2.06043386,-1.47457957, 0.05163103,-0.34953672], - [0.70942575,-1.41629028, 0.57625329, 1.22837853, 0.26442787,-1.24242258,-0.38967255,-0.10485345, 1.34950197,-1.88799143, 0.64463151, 0.26861122], - [-0.90124643, 2.06094766, 0.20568365, 0.06078637, 1.68658400,-0.19301027,-0.56969130,-0.80906254,-1.20984066, 0.12565698, 0.62286967,-1.07967734], - [-0.58323914,-0.91550159, 2.76294446,-0.23104562, 1.03537095,-0.79180622,-0.30585235,-0.37028444, 0.06941666,-0.66646379, 0.61295509,-0.61649406], - [-0.69953281,-0.53587002, 0.10623999,-1.43030167,-1.28995168,-0.84757996,-0.18267554,-0.03703059, 1.55741370, 1.54363191, 0.52537125, 1.29028559], - [-0.70696884,-0.75943643, 1.45195222,-0.89612883,-0.74769866, 0.21710433,-0.64992350,-1.06435382,-0.16617794, 2.16994262, 1.05082333, 0.10086535], - [-0.37381354,-0.70111430, 1.83576059, 0.72364914,-1.35405958, 0.72988695, 0.52067578,-0.01720174,-0.46059695, 1.23575497,-0.43288255,-1.70605886], + [0.94532686, -0.46526614, 0.93781346, -1.4178741], + [1.6348482, -1.0348053, -0.49546495, -0.10457793], + [0.8033758, 1.1668185, -0.823479, -1.146715] ], [ - [-1.20804095, 0.38654494, 1.65309286,-1.20736289, 1.07261550, 0.46114275, 0.83086872,-0.01955486,-1.26059496,-0.11887560, 0.79357809,-1.38341355], - [-0.56300515,-0.59784967, 2.81054258,-0.37848800,-0.41372916,-0.90938121, 0.82510620, 0.12329611, 0.14460202, 0.12636989,-1.24349451, 0.07603064], - [-1.36658132,-1.11734688, 1.74118745, 0.56276298, 0.35426524, 0.82628661,-1.63426054,-0.80171925, 0.09229738, 0.71951282,-0.27681157, 0.90040714], - [-0.47256982,-0.39320827,-1.71228957, 0.24000385, 0.71217608, 1.75911832,-1.24219942,-0.00148612, 0.80727738,-1.04095078, 0.02052352, 1.32360506], - [-0.00462395, 0.10117173, 1.83498573,-0.69001645, 0.46190643,-1.00014806, 1.14456511, 0.55384815, 0.36776620,-0.55358148,-0.00812254,-2.20775104], - [-0.59229124,-1.63409364, 1.70002937, 0.40580338, 0.76335514,-0.50594056, 0.32149875, 1.17081654,-1.73462892, 0.50679129,-0.56456679, 0.16322602], - [-0.28135568, 0.12212670, 1.39109802,-1.15742660, 0.81334966, 0.21747869,-0.01345161, 0.15832950, 0.68586451,-1.60281539, 1.38292646,-1.71612430], - [0.52762824,-1.20023167, 1.34064293,-0.40414453, 0.61767668,-0.24842866, 0.06679908, 1.13988364,-0.66101944,-0.71850598, 1.43029106,-1.89059174], - ], + [1.2232355, -1.5628394, 0.2116476, 0.12795626], + [0.99152863, -0.98818815, 1.0083598, -1.0117002], + [-1.4775288, 0.47518563, -0.23662777, 1.2389709] + ] ] ); }
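
The #[allow(clippy::map_identity)] added in dfdx-core/src/data/collate.rs above suppresses a false positive: for Vec<&'a (A, B)> the closure |(a, b)| (a, b) is not an identity, because matching on a &(A, B) reborrows the fields and yields (&A, &B), which is what unzip needs. A standalone sketch of the same pattern with concrete types (not dfdx code):

fn main() {
    let pairs: Vec<(i32, char)> = vec![(1, 'a'), (2, 'b')];
    let refs: Vec<&(i32, char)> = pairs.iter().collect();

    // Clippy flags this map as an identity, but it converts &(i32, char) into
    // (&i32, &char) so that unzip can build two vectors of references.
    #[allow(clippy::map_identity)]
    let (xs, ys): (Vec<&i32>, Vec<&char>) = refs.into_iter().map(|(a, b)| (a, b)).unzip();

    assert_eq!(xs, vec![&1, &2]);
    assert_eq!(ys, vec![&'a', &'b']);
}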
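
The cache.rs refactor above moves deallocation behind a CachePtr trait: each device-specific pointer type decides how to free itself, the cache carries whatever device handle it needs, and try_clear/Drop drain every allocation through that hook instead of each device reimplementing try_empty_cache. A minimal, self-contained sketch of that shape, using simplified stand-ins rather than dfdx's actual types (no enabled flag, no no-std paths, plain () errors):

use std::collections::BTreeMap;
use std::sync::RwLock;

#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)]
struct AllocationKey {
    num_bytes: usize,
}

// Each pointer type knows how to free itself for a given device handle.
// The default is a no-op: plain value types simply drop.
trait CachePtr<Dev>: Sized {
    fn dealloc(self, _key: &AllocationKey, _dev: &Dev) {}
}

struct TensorCache<Ptr: CachePtr<Dev>, Dev = ()> {
    allocations: RwLock<BTreeMap<AllocationKey, Vec<Ptr>>>,
    device_dev: Dev,
}

impl<Ptr: CachePtr<Dev>, Dev> TensorCache<Ptr, Dev> {
    fn new(device_dev: Dev) -> Self {
        Self {
            allocations: RwLock::new(BTreeMap::new()),
            device_dev,
        }
    }

    // Drain every cached allocation through the device-aware dealloc hook.
    fn try_clear(&self) -> Result<(), ()> {
        let mut cache = self.allocations.write().unwrap();
        for (key, allocations) in cache.iter_mut() {
            for alloc in allocations.drain(..) {
                alloc.dealloc(key, &self.device_dev);
            }
        }
        cache.clear();
        Ok(())
    }
}

// Clearing on drop means a device no longer leaks its cache when it goes away.
impl<Ptr: CachePtr<Dev>, Dev> Drop for TensorCache<Ptr, Dev> {
    fn drop(&mut self) {
        self.try_clear().unwrap();
    }
}

// A plain integer is a valid cache "pointer" that frees nothing.
impl CachePtr<()> for usize {}

fn main() {
    let cache: TensorCache<usize> = TensorCache::new(());
    cache
        .allocations
        .write()
        .unwrap()
        .entry(AllocationKey { num_bytes: 64 })
        .or_default()
        .push(0xdead);
    cache.try_clear().unwrap();
}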
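
The CachePtr impl for the CPU's BytesPtr above frees a cached allocation by rebuilding the Vec it was originally carved from, choosing an element type whose alignment matches the recorded one (for the unsigned integer types involved, size equals alignment). A self-contained sketch of that round trip, with local variables standing in for dfdx's AllocationKey fields:

fn main() {
    // Pretend this Vec was handed to the cache: record its layout, then leak it.
    let v: Vec<u32> = vec![1, 2, 3, 4];
    let len = v.len();
    let num_bytes = len * std::mem::size_of::<u32>();
    let alignment = std::mem::align_of::<u32>();
    let ptr: *mut u8 = v.leak().as_mut_ptr() as *mut u8;

    // Later, "deallocation" rebuilds a Vec over the same memory and drops it.
    // SAFETY: ptr came from a Vec allocated by the global allocator, the element
    // type chosen below has the same alignment as the original, and len == capacity.
    assert!(num_bytes % alignment == 0);
    let elems = num_bytes / alignment;
    match alignment {
        1 => unsafe { drop(Vec::from_raw_parts(ptr, elems, elems)) },
        2 => unsafe { drop(Vec::from_raw_parts(ptr as *mut u16, elems, elems)) },
        4 => unsafe { drop(Vec::from_raw_parts(ptr as *mut u32, elems, elems)) },
        8 => unsafe { drop(Vec::from_raw_parts(ptr as *mut u64, elems, elems)) },
        _ => unreachable!(),
    }
}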
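
In gradients.rs, many_and_ref now takes a slice of impl Tensorlike values instead of &Vec<...>, so callers can pass a &Vec, an array, or a subslice without building a fresh Vec. The same signature change with ordinary types (hypothetical sum_lens, not a dfdx function):

// Accepting &[T] instead of &Vec<T> widens what callers can pass.
fn sum_lens(items: &[impl AsRef<str>]) -> usize {
    items.iter().map(|s| s.as_ref().len()).sum()
}

fn main() {
    let owned: Vec<String> = vec!["ab".into(), "cde".into()];
    assert_eq!(sum_lens(&owned), 5); // &Vec<String> coerces to &[String]
    assert_eq!(sum_lens(&["x", "yz"]), 3); // a fixed-size array works too
}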
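
The utilities/device.rs change above splits the Device supertrait: the conv1d kernel bound moves into a helper trait, NonCudnnCuda, which only carries that requirement when the cudnn feature is off. The shape of the pattern, with illustrative trait names rather than dfdx's real kernel traits:

// Stand-in for the kernel bound that cudnn builds cannot satisfy.
trait ConvKernel {}

// With the feature enabled the helper trait demands nothing extra...
#[cfg(feature = "cudnn")]
trait NonCudnn {}

// ...and without it, the helper trait carries the kernel requirement.
#[cfg(not(feature = "cudnn"))]
trait NonCudnn: ConvKernel {}

// The main trait always requires the helper, so its definition stays feature-free.
trait Device: NonCudnn {}

struct Cpu;
impl ConvKernel for Cpu {}
impl NonCudnn for Cpu {}
impl Device for Cpu {}

fn main() {
    fn assert_device<D: Device>(_d: &D) {}
    assert_device(&Cpu);
}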
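
The flush_denormals_to_zero and keep_denormals helpers are removed from dfdx-core above, and the MNIST example no longer calls them. A program that still wants SSE flush-to-zero behaviour can set the flag itself; this sketch mirrors the removed x86_64 branch:

fn flush_denormals_to_zero() {
    #[cfg(all(target_arch = "x86_64", target_feature = "sse"))]
    {
        use std::arch::x86_64::{_MM_FLUSH_ZERO_ON, _MM_SET_FLUSH_ZERO_MODE};
        unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON) }
    }
    // On other targets this is a no-op; the removed code also had an x86 + sse branch.
}

fn main() {
    flush_denormals_to_zero();
    // ... training code that benefits from denormals being flushed to zero ...
}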