Skip to content

Commit

Permalink
Merge branch 'main' into webgpu-abs
Browse files Browse the repository at this point in the history
  • Loading branch information
favilo authored Dec 27, 2023
2 parents 701cd7b + 4615ac1 commit 8edbaf0
Show file tree
Hide file tree
Showing 10 changed files with 679 additions and 28 deletions.
251 changes: 251 additions & 0 deletions dfdx-core/src/tensor/numpy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,166 @@ impl NumpyDtype for f64 {
}
}

impl NumpyDtype for u8 {
const NUMPY_DTYPE_STR: &'static str = "u1";
fn read_endian<R: Read>(r: &mut R, endian: Endian) -> io::Result<Self> {
let mut bytes = [0; 1];
r.read_exact(&mut bytes)?;
Ok(match endian {
Endian::Big => Self::from_be_bytes(bytes),
Endian::Little => Self::from_le_bytes(bytes),
Endian::Native => Self::from_ne_bytes(bytes),
})
}
fn write_endian<W: Write>(&self, w: &mut W, endian: Endian) -> io::Result<()> {
match endian {
Endian::Big => w.write_all(&self.to_be_bytes()),
Endian::Little => w.write_all(&self.to_le_bytes()),
Endian::Native => w.write_all(&self.to_ne_bytes()),
}
}
}

impl NumpyDtype for u16 {
const NUMPY_DTYPE_STR: &'static str = "u2";
fn read_endian<R: Read>(r: &mut R, endian: Endian) -> io::Result<Self> {
let mut bytes = [0; 2];
r.read_exact(&mut bytes)?;
Ok(match endian {
Endian::Big => Self::from_be_bytes(bytes),
Endian::Little => Self::from_le_bytes(bytes),
Endian::Native => Self::from_ne_bytes(bytes),
})
}
fn write_endian<W: Write>(&self, w: &mut W, endian: Endian) -> io::Result<()> {
match endian {
Endian::Big => w.write_all(&self.to_be_bytes()),
Endian::Little => w.write_all(&self.to_le_bytes()),
Endian::Native => w.write_all(&self.to_ne_bytes()),
}
}
}

impl NumpyDtype for u32 {
const NUMPY_DTYPE_STR: &'static str = "u4";
fn read_endian<R: Read>(r: &mut R, endian: Endian) -> io::Result<Self> {
let mut bytes = [0; 4];
r.read_exact(&mut bytes)?;
Ok(match endian {
Endian::Big => Self::from_be_bytes(bytes),
Endian::Little => Self::from_le_bytes(bytes),
Endian::Native => Self::from_ne_bytes(bytes),
})
}
fn write_endian<W: Write>(&self, w: &mut W, endian: Endian) -> io::Result<()> {
match endian {
Endian::Big => w.write_all(&self.to_be_bytes()),
Endian::Little => w.write_all(&self.to_le_bytes()),
Endian::Native => w.write_all(&self.to_ne_bytes()),
}
}
}

impl NumpyDtype for u64 {
const NUMPY_DTYPE_STR: &'static str = "u8";
fn read_endian<R: Read>(r: &mut R, endian: Endian) -> io::Result<Self> {
let mut bytes = [0; 8];
r.read_exact(&mut bytes)?;
Ok(match endian {
Endian::Big => Self::from_be_bytes(bytes),
Endian::Little => Self::from_le_bytes(bytes),
Endian::Native => Self::from_ne_bytes(bytes),
})
}
fn write_endian<W: Write>(&self, w: &mut W, endian: Endian) -> io::Result<()> {
match endian {
Endian::Big => w.write_all(&self.to_be_bytes()),
Endian::Little => w.write_all(&self.to_le_bytes()),
Endian::Native => w.write_all(&self.to_ne_bytes()),
}
}
}

impl NumpyDtype for i8 {
const NUMPY_DTYPE_STR: &'static str = "i1";
fn read_endian<R: Read>(r: &mut R, endian: Endian) -> io::Result<Self> {
let mut bytes = [0; 1];
r.read_exact(&mut bytes)?;
Ok(match endian {
Endian::Big => Self::from_be_bytes(bytes),
Endian::Little => Self::from_le_bytes(bytes),
Endian::Native => Self::from_ne_bytes(bytes),
})
}
fn write_endian<W: Write>(&self, w: &mut W, endian: Endian) -> io::Result<()> {
match endian {
Endian::Big => w.write_all(&self.to_be_bytes()),
Endian::Little => w.write_all(&self.to_le_bytes()),
Endian::Native => w.write_all(&self.to_ne_bytes()),
}
}
}

impl NumpyDtype for i16 {
const NUMPY_DTYPE_STR: &'static str = "i2";
fn read_endian<R: Read>(r: &mut R, endian: Endian) -> io::Result<Self> {
let mut bytes = [0; 2];
r.read_exact(&mut bytes)?;
Ok(match endian {
Endian::Big => Self::from_be_bytes(bytes),
Endian::Little => Self::from_le_bytes(bytes),
Endian::Native => Self::from_ne_bytes(bytes),
})
}
fn write_endian<W: Write>(&self, w: &mut W, endian: Endian) -> io::Result<()> {
match endian {
Endian::Big => w.write_all(&self.to_be_bytes()),
Endian::Little => w.write_all(&self.to_le_bytes()),
Endian::Native => w.write_all(&self.to_ne_bytes()),
}
}
}

impl NumpyDtype for i32 {
const NUMPY_DTYPE_STR: &'static str = "i4";
fn read_endian<R: Read>(r: &mut R, endian: Endian) -> io::Result<Self> {
let mut bytes = [0; 4];
r.read_exact(&mut bytes)?;
Ok(match endian {
Endian::Big => Self::from_be_bytes(bytes),
Endian::Little => Self::from_le_bytes(bytes),
Endian::Native => Self::from_ne_bytes(bytes),
})
}
fn write_endian<W: Write>(&self, w: &mut W, endian: Endian) -> io::Result<()> {
match endian {
Endian::Big => w.write_all(&self.to_be_bytes()),
Endian::Little => w.write_all(&self.to_le_bytes()),
Endian::Native => w.write_all(&self.to_ne_bytes()),
}
}
}

impl NumpyDtype for i64 {
const NUMPY_DTYPE_STR: &'static str = "i8";
fn read_endian<R: Read>(r: &mut R, endian: Endian) -> io::Result<Self> {
let mut bytes = [0; 8];
r.read_exact(&mut bytes)?;
Ok(match endian {
Endian::Big => Self::from_be_bytes(bytes),
Endian::Little => Self::from_le_bytes(bytes),
Endian::Native => Self::from_ne_bytes(bytes),
})
}
fn write_endian<W: Write>(&self, w: &mut W, endian: Endian) -> io::Result<()> {
match endian {
Endian::Big => w.write_all(&self.to_be_bytes()),
Endian::Little => w.write_all(&self.to_le_bytes()),
Endian::Native => w.write_all(&self.to_ne_bytes()),
}
}
}

#[derive(Debug)]
pub enum NpyError {
/// Magic number did not match the expected value.
Expand Down Expand Up @@ -560,4 +720,95 @@ mod tests {
.load_from_npy(file.path())
.expect_err("");
}

#[test]
fn test_0d_u8_save() {
let dev: TestDevice = Default::default();

let x = dev.tensor(0u8);

let file = NamedTempFile::new().expect("failed to create tempfile");

x.save_to_npy(file.path()).expect("Saving failed");

let mut f = File::open(file.path()).expect("No file found");

let mut found = Vec::new();
f.read_to_end(&mut found).expect("Reading failed");

assert_eq!(
&found,
&[
147, 78, 85, 77, 80, 89, 1, 0, 64, 0, 123, 39, 100, 101, 115, 99, 114, 39, 58, 32,
39, 60, 117, 49, 39, 44, 32, 39, 102, 111, 114, 116, 114, 97, 110, 95, 111, 114,
100, 101, 114, 39, 58, 32, 70, 97, 108, 115, 101, 44, 32, 39, 115, 104, 97, 112,
101, 39, 58, 32, 40, 41, 44, 32, 125, 32, 32, 32, 32, 32, 32, 32, 32, 10, 0
]
);
}

#[test]
fn test_0d_u8_load() {
let dev: TestDevice = Default::default();
let x = dev.tensor(2u8);

let file = NamedTempFile::new().expect("failed to create tempfile");

x.save_to_npy(file.path()).expect("Saving failed");

let mut v = dev.tensor(0u8);
v.load_from_npy(file.path()).expect("Loading failed");
assert_eq!(v.array(), x.array());

dev.tensor(0u16).load_from_npy(file.path()).expect_err("");
dev.tensor([0u8; 1])
.load_from_npy(file.path())
.expect_err("");
}

#[test]
fn test_0d_i8_save() {
let dev: TestDevice = Default::default();

let x = dev.tensor(0i8);

let file = NamedTempFile::new().expect("failed to create tempfile");

x.save_to_npy(file.path()).expect("Saving failed");
x.save_to_npy("out.npy").expect("Saving failed");

let mut f = File::open(file.path()).expect("No file found");

let mut found = Vec::new();
f.read_to_end(&mut found).expect("Reading failed");

assert_eq!(
&found,
&[
147, 78, 85, 77, 80, 89, 1, 0, 64, 0, 123, 39, 100, 101, 115, 99, 114, 39, 58, 32,
39, 60, 105, 49, 39, 44, 32, 39, 102, 111, 114, 116, 114, 97, 110, 95, 111, 114,
100, 101, 114, 39, 58, 32, 70, 97, 108, 115, 101, 44, 32, 39, 115, 104, 97, 112,
101, 39, 58, 32, 40, 41, 44, 32, 125, 32, 32, 32, 32, 32, 32, 32, 32, 10, 0
]
);
}

#[test]
fn test_0d_i8_load() {
let dev: TestDevice = Default::default();
let x = dev.tensor(2i8);

let file = NamedTempFile::new().expect("failed to create tempfile");

x.save_to_npy(file.path()).expect("Saving failed");

let mut v = dev.tensor(0i8);
v.load_from_npy(file.path()).expect("Loading failed");
assert_eq!(v.array(), x.array());

dev.tensor(0i16).load_from_npy(file.path()).expect_err("");
dev.tensor([0i8; 1])
.load_from_npy(file.path())
.expect_err("");
}
}
31 changes: 5 additions & 26 deletions dfdx-core/src/tensor_ops/concat_along/mod.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
pub use super::concat_tensor_along::ConcatAlongKernel;
use crate::{shapes::*, tensor::*};

mod cpu_kernel;
#[cfg(feature = "cuda")]
mod cuda_kernel;
#[cfg(feature = "webgpu")]
mod webgpu_kernel;

/// Concatenate two tensors along a given axis.
///
/// **Pytorch equivalent** `torch.concat`.
Expand Down Expand Up @@ -48,6 +43,7 @@ mod webgpu_kernel;
/// let b: Tensor<(Const<2>, usize), f32, _> = dev.zeros_like(&(Const, 4));
/// let _: Tensor<Rank2<2, 6>, f32, _> = (a, b).concat_along(Axis::<1>).realize();
/// ```
#[deprecated = "Use TryConcatTensorAlong or TryConcatShapeAlong instead"]
pub trait TryConcatAlong<Ax>: Sized {
type Output;

Expand All @@ -59,26 +55,7 @@ pub trait TryConcatAlong<Ax>: Sized {
fn try_concat_along(self, ax: Ax) -> Result<Self::Output, Error>;
}

pub trait ConcatAlongKernel<E: Dtype>: Storage<E> {
fn forward<A: Shape, B: Shape, C: Shape>(
&self,
ax: usize,
a: &Tensor<A, E, Self>,
b: &Tensor<B, E, Self>,
c: &mut Tensor<C, E, Self>,
) -> Result<(), Error>;

fn backward<A: Shape, B: Shape>(
&self,
ax: usize,
a: &GhostTensor<A, E, Self>,
grad_a: &mut Self::Vec,
b: &GhostTensor<B, E, Self>,
grad_b: &mut Self::Vec,
grad_out: &Self::Vec,
) -> Result<(), Error>;
}

#[allow(deprecated)]
impl<A, B, Ax, E: Dtype, D, T: Tape<E, D>, R: Tape<E, D>> TryConcatAlong<Ax>
for (Tensor<A, E, D, T>, Tensor<B, E, D, R>)
where
Expand Down Expand Up @@ -123,6 +100,7 @@ where

macro_rules! impl_concat {
($Ax:expr, $NumDims:expr, [$($Head:tt),*], [$($Tail:tt),*]) => {
#[allow(deprecated)]
impl<A: Dim, B: Dim, $($Head: Dim, )* $($Tail: Dim, )*> TryConcatAlong<Axis<$Ax>>
for (
($($Head, )* A, $($Tail, )*),
Expand Down Expand Up @@ -183,6 +161,7 @@ impl_concat!(4, 6, [D0, D1, D2, D3], [D5]);
impl_concat!(5, 6, [D0, D1, D2, D3, D4], []);

#[cfg(test)]
#[allow(deprecated)]
mod tests {
use super::*;
use crate::{tensor_ops::*, tests::*};
Expand Down
Loading

0 comments on commit 8edbaf0

Please sign in to comment.