
[Breaking] Combining separate device errors into single `dfdx::tensor::Error` enum (#875)

* [Breaking] Adding single Error enum

* Fixing example

* Fixing cuda kernels

* Fixing no-std
coreylowman authored Oct 25, 2023
1 parent 5e0c3dd commit 7df12c0
Showing 188 changed files with 844 additions and 1,069 deletions.
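Context for the diffs below: prior to #875, every fallible dfdx API returned the device's associated error type (`D::Err`), and optimizer updates wrapped it in a dedicated `OptimizerUpdateError` enum. This commit replaces both with the single `dfdx::tensor::Error` enum. The sketch below shows the shape of that refactor in isolation; it is not dfdx's actual definition, and apart from `UnusedTensors(Vec<UniqueId>)` (visible in the diff) the variant names are assumptions.

```rust
// Standalone sketch of the refactor pattern, not dfdx's real definitions:
// instead of each backend exposing its own associated error type, every
// fallible operation returns one shared enum.
use std::fmt;

type UniqueId = u64;

// The unified error: device failures and "unused tensor" bookkeeping
// failures live in one enum, so trait signatures no longer need D::Err
// or a wrapper like OptimizerUpdateError.
#[derive(Debug)]
pub enum Error {
    OutOfMemory,                  // assumed device-style variant
    UnusedTensors(Vec<UniqueId>), // variant confirmed by the diff below
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::OutOfMemory => write!(f, "out of memory"),
            Self::UnusedTensors(ids) => write!(f, "unused tensors: {ids:?}"),
        }
    }
}

impl std::error::Error for Error {}

// Callers now match on one type regardless of backend:
fn report(e: Error) {
    match e {
        Error::UnusedTensors(ids) => eprintln!("optimizer skipped: {ids:?}"),
        other => eprintln!("device error: {other}"),
    }
}
```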
56 changes: 14 additions & 42 deletions dfdx-core/src/nn_traits/mod.rs
@@ -3,17 +3,16 @@ mod vecs;
 
 use std::vec::Vec;
 
-use crate::prelude::{Device, Dtype, Gradients, Shape, Tensor, UniqueId};
+use crate::prelude::{Device, Dtype, Error, Gradients, Shape, Tensor, UniqueId};
 
 /// Mutable & Immutable forward of `Input` that produces [Module::Output].
 pub trait Module<X> {
     /// The type that this unit produces given `Input`.
     type Output;
-    type Error: std::fmt::Debug;
 
-    fn try_forward(&self, x: X) -> Result<Self::Output, Self::Error>;
+    fn try_forward(&self, x: X) -> Result<Self::Output, Error>;
 
-    fn try_forward_mut(&mut self, x: X) -> Result<Self::Output, Self::Error> {
+    fn try_forward_mut(&mut self, x: X) -> Result<Self::Output, Error> {
         self.try_forward(x)
     }
 
@@ -26,52 +25,25 @@ pub trait Module<X> {
     }
 }
 
-/// An error indicating that a parameter was not used in gradient
-/// computation, and was therefore not present in [Gradients]
-/// during an update.
-#[derive(Debug)]
-pub enum OptimizerUpdateError<Err> {
-    UnusedTensors(Vec<UniqueId>),
-    DeviceError(Err),
-}
-
-impl<Err: std::fmt::Display> std::fmt::Display for OptimizerUpdateError<Err> {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        match self {
-            Self::UnusedTensors(unused) => write!(f, "Unused tensors: {unused:?}"),
-            Self::DeviceError(err) => write!(f, "{err}"),
-        }
-    }
-}
-
-#[cfg(feature = "std")]
-impl<Err: std::fmt::Debug + std::fmt::Display> std::error::Error for OptimizerUpdateError<Err> {}
-
 /// Something that can update both tensors and a [UpdateParams]. At minimum [Optimizer::update_tensor()] must be implemented.
 pub trait Optimizer<M, E: Dtype, D: Device<E>>: Sized {
     fn update_tensor<S: Shape>(
         &mut self,
         t: &mut Tensor<S, E, D>,
         gradients: &Gradients<E, D>,
         missing_tensors: &mut Vec<UniqueId>,
-    ) -> Result<(), D::Err>;
+    ) -> Result<(), Error>;
 
-    fn update(
-        &mut self,
-        module: &mut M,
-        gradients: &Gradients<E, D>,
-    ) -> Result<(), OptimizerUpdateError<D::Err>>
+    fn update(&mut self, module: &mut M, gradients: &Gradients<E, D>) -> Result<(), Error>
     where
         M: UpdateParams<E, D>,
     {
         let mut missing_tensors = Vec::new();
-        module
-            .try_update_params(self, gradients, &mut missing_tensors)
-            .map_err(OptimizerUpdateError::DeviceError)?;
+        module.try_update_params(self, gradients, &mut missing_tensors)?;
         if missing_tensors.is_empty() {
             Ok(())
         } else {
-            Err(OptimizerUpdateError::UnusedTensors(missing_tensors))
+            Err(Error::UnusedTensors(missing_tensors))
         }
     }
 }
@@ -82,15 +54,15 @@ pub trait BuildOnDevice<E: Dtype, D: Device<E>>: Clone {
     fn build_on_device(&self, device: &D) -> Self::Built {
         self.try_build_on_device(device).unwrap()
     }
-    fn try_build_on_device(&self, device: &D) -> Result<Self::Built, D::Err>;
+    fn try_build_on_device(&self, device: &D) -> Result<Self::Built, crate::tensor::Error>;
 }
 
 /// Something that can have all of its parameters reset to a specific state (may be random or not random).
 pub trait ResetParams<E: Dtype, D: Device<E>> {
     fn reset_params(&mut self) {
         self.try_reset_params().unwrap()
     }
-    fn try_reset_params(&mut self) -> Result<(), D::Err>;
+    fn try_reset_params(&mut self) -> Result<(), crate::tensor::Error>;
 }
 
 /// Something that can have it's params updated with an [Optimizer] and a set of [Gradients].
@@ -109,7 +81,7 @@ pub trait UpdateParams<E: Dtype, D: Device<E>> {
         optimizer: &mut Optim,
         gradients: &Gradients<E, D>,
         missing_tensors: &mut Vec<UniqueId>,
-    ) -> Result<(), D::Err>;
+    ) -> Result<(), crate::tensor::Error>;
 }
 
 impl<S: Shape, E: Dtype, D: Device<E>> UpdateParams<E, D> for Tensor<S, E, D> {
@@ -118,7 +90,7 @@ impl<S: Shape, E: Dtype, D: Device<E>> UpdateParams<E, D> for Tensor<S, E, D> {
         optimizer: &mut Optim,
         gradients: &Gradients<E, D>,
         missing_tensors: &mut Vec<UniqueId>,
-    ) -> Result<(), <D>::Err> {
+    ) -> Result<(), crate::tensor::Error> {
         optimizer.update_tensor(self, gradients, missing_tensors)
     }
 }
@@ -128,12 +100,12 @@ pub trait ZeroGrads<E: Dtype, D: Device<E>> {
     fn zero_grads(&self, grads: &mut Gradients<E, D>) {
         self.try_zero_grads(grads).unwrap()
     }
-    fn try_zero_grads(&self, grads: &mut Gradients<E, D>) -> Result<(), D::Err>;
+    fn try_zero_grads(&self, grads: &mut Gradients<E, D>) -> Result<(), crate::tensor::Error>;
 
     fn alloc_grads(&self) -> Gradients<E, D> {
         self.try_alloc_grads().unwrap()
     }
-    fn try_alloc_grads(&self) -> Result<Gradients<E, D>, D::Err> {
+    fn try_alloc_grads(&self) -> Result<Gradients<E, D>, crate::tensor::Error> {
         let mut grads = Gradients::leaky();
         self.try_zero_grads(&mut grads)?;
         grads.retain_current_grads_as_leafs();
@@ -275,7 +247,7 @@ pub trait BuildModuleExt<M>: Sized {
         self.try_build_module(m).unwrap()
     }
 
-    fn try_build_module<E: Dtype>(&self, m: M) -> Result<M::Built, Self::Err>
+    fn try_build_module<E: Dtype>(&self, m: M) -> Result<M::Built, Error>
     where
         M: BuildOnDevice<E, Self>,
         M::Built: ResetParams<E, Self>,
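With the wrapper enum gone, the default `Optimizer::update` body above becomes straight-line code that ends in `Error::UnusedTensors`. A self-contained sketch of the same control flow, with the module and gradient machinery reduced to a hypothetical `params` slice:

```rust
type UniqueId = u64;

#[derive(Debug)]
enum Error {
    UnusedTensors(Vec<UniqueId>),
}

// Mirrors the default `Optimizer::update` body in the hunk above:
// run the parameter update, then fail if any tensor had no gradient.
// `params` pairs each tensor id with whether a gradient was found;
// in dfdx, `update_tensor` pushes the missing ids instead.
fn update(params: &[(UniqueId, bool)]) -> Result<(), Error> {
    let mut missing_tensors = Vec::new();
    for &(id, has_grad) in params {
        if !has_grad {
            missing_tensors.push(id);
        }
    }
    if missing_tensors.is_empty() {
        Ok(())
    } else {
        // Previously Err(OptimizerUpdateError::UnusedTensors(..));
        // now the same enum that carries device errors.
        Err(Error::UnusedTensors(missing_tensors))
    }
}
```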
21 changes: 12 additions & 9 deletions dfdx-core/src/nn_traits/tuples.rs
@@ -1,4 +1,8 @@
-use crate::{dtypes::Dtype, tensor::UniqueId, tensor_ops::Device};
+use crate::{
+    dtypes::Dtype,
+    tensor::{Error, UniqueId},
+    tensor_ops::Device,
+};
 
 use std::vec::Vec;
 
@@ -7,7 +11,7 @@ macro_rules! tuple_impls {
 
         impl<Dev: Device<Elem>, Elem: Dtype, $($name: crate::nn_traits::BuildOnDevice<Elem, Dev>),+> crate::nn_traits::BuildOnDevice<Elem, Dev> for ($($name,)+) {
             type Built = ($($name::Built, )+);
-            fn try_build_on_device(&self, device: &Dev) -> Result<Self::Built, Dev::Err> {
+            fn try_build_on_device(&self, device: &Dev) -> Result<Self::Built, Error> {
                 Ok(($(
                     self.$idx.try_build_on_device(device)?,
                 )+))
@@ -38,7 +42,7 @@ macro_rules! tuple_impls {
         }
 
         impl<Dev: Device<Elem>, Elem: Dtype, $($name: crate::nn_traits::ResetParams<Elem, Dev>),+> crate::nn_traits::ResetParams<Elem, Dev> for ($($name,)+) {
-            fn try_reset_params(&mut self) -> Result<(), Dev::Err> {
+            fn try_reset_params(&mut self) -> Result<(), Error> {
                 $(self.$idx.try_reset_params()?;)+
                 Ok(())
             }
@@ -50,14 +54,14 @@ macro_rules! tuple_impls {
                 optimizer: &mut Optim,
                 gradients: &crate::prelude::Gradients<Elem, Dev>,
                 missing_tensors: &mut Vec<UniqueId>,
-            ) -> Result<(), Dev::Err> {
+            ) -> Result<(), Error> {
                 $(self.$idx.try_update_params(optimizer, gradients, missing_tensors)?;)+
                 Ok(())
             }
         }
 
         impl<Dev: Device<Elem>, Elem: Dtype, $($name: crate::nn_traits::ZeroGrads<Elem, Dev>),+> crate::nn_traits::ZeroGrads<Elem, Dev> for ($($name,)+) {
-            fn try_zero_grads(&self, grads: &mut crate::prelude::Gradients<Elem, Dev>) -> Result<(), Dev::Err> {
+            fn try_zero_grads(&self, grads: &mut crate::prelude::Gradients<Elem, Dev>) -> Result<(), Error> {
                 $(self.$idx.try_zero_grads(grads)?;)+
                 Ok(())
             }
@@ -91,20 +95,19 @@ macro_rules! tuple_impls {
         impl<
             Input,
             $last:
-            $(crate::nn_traits::Module::<$rev_tail ::Output, Error=$rev_tail::Error>, $rev_tail: )*
+            $(crate::nn_traits::Module::<$rev_tail ::Output>, $rev_tail: )*
             crate::nn_traits::Module<Input>
         > crate::nn_traits::Module<Input> for ($($name,)+) {
             type Output = $last ::Output;
-            type Error = $last ::Error;
 
             /// Calls forward sequentially on each module in the tuple.
-            fn try_forward(&self, x: Input) -> Result<Self::Output, Self::Error> {
+            fn try_forward(&self, x: Input) -> Result<Self::Output, Error> {
                 $(let x = self.$idx.try_forward(x)?;)+
                 Ok(x)
             }
 
             /// Calls forward sequentially on each module in the tuple.
-            fn try_forward_mut(&mut self, x: Input) -> Result<Self::Output, Self::Error> {
+            fn try_forward_mut(&mut self, x: Input) -> Result<Self::Output, Error> {
                 $(let x = self.$idx.try_forward_mut(x)?;)+
                 Ok(x)
             }
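Worth noting for reviewers: the macro used to require `Error = $rev_tail::Error` so that `?` could thread a single error type through heterogeneous tuples; with one concrete error type that bound disappears. A standalone sketch of why composition now needs no error bounds (trait and module names are illustrative, not dfdx's):

```rust
#[derive(Debug)]
struct Error; // stand-in for the unified dfdx::tensor::Error

trait Module<X> {
    type Output;
    // No associated Error type anymore: one concrete error for everyone.
    fn try_forward(&self, x: X) -> Result<Self::Output, Error>;
}

struct AddOne;
struct Double;

impl Module<i32> for AddOne {
    type Output = i32;
    fn try_forward(&self, x: i32) -> Result<i32, Error> { Ok(x + 1) }
}

impl Module<i32> for Double {
    type Output = i32;
    fn try_forward(&self, x: i32) -> Result<i32, Error> { Ok(x * 2) }
}

// A two-tuple composes without any error-type equality bounds,
// mirroring the shape of the macro-generated impl above.
impl<X, A: Module<X>, B: Module<A::Output>> Module<X> for (A, B) {
    type Output = B::Output;
    fn try_forward(&self, x: X) -> Result<Self::Output, Error> {
        let x = self.0.try_forward(x)?;
        self.1.try_forward(x)
    }
}
```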
22 changes: 14 additions & 8 deletions dfdx-core/src/nn_traits/vecs.rs
@@ -1,12 +1,16 @@
-use crate::{dtypes::Dtype, tensor::UniqueId, tensor_ops::Device};
+use crate::{
+    dtypes::Dtype,
+    tensor::{Error, UniqueId},
+    tensor_ops::Device,
+};
 
 use std::vec::Vec;
 
 impl<E: Dtype, D: Device<E>, T: crate::nn_traits::BuildOnDevice<E, D>>
     crate::nn_traits::BuildOnDevice<E, D> for Vec<T>
 {
     type Built = Vec<T::Built>;
-    fn try_build_on_device(&self, device: &D) -> Result<Self::Built, <D>::Err> {
+    fn try_build_on_device(&self, device: &D) -> Result<Self::Built, crate::tensor::Error> {
         self.iter()
             .map(|m_i| m_i.try_build_on_device(device))
             .collect()
@@ -16,7 +20,7 @@ impl<E: Dtype, D: Device<E>, T: crate::nn_traits::BuildOnDevice<E, D>>
 impl<E: Dtype, D: Device<E>, T: crate::nn_traits::ResetParams<E, D>>
     crate::nn_traits::ResetParams<E, D> for Vec<T>
 {
-    fn try_reset_params(&mut self) -> Result<(), <D>::Err> {
+    fn try_reset_params(&mut self) -> Result<(), crate::tensor::Error> {
         for m_i in self.iter_mut() {
             m_i.try_reset_params()?;
         }
@@ -32,7 +36,7 @@ impl<E: Dtype, D: Device<E>, T: crate::nn_traits::UpdateParams<E, D>>
         optimizer: &mut Optim,
         gradients: &crate::tensor::Gradients<E, D>,
         missing_tensors: &mut Vec<UniqueId>,
-    ) -> Result<(), D::Err> {
+    ) -> Result<(), crate::tensor::Error> {
         for m_i in self.iter_mut() {
             m_i.try_update_params(optimizer, gradients, missing_tensors)?;
         }
@@ -43,7 +47,10 @@ impl<E: Dtype, D: Device<E>, T: crate::nn_traits::UpdateParams<E, D>>
 impl<E: Dtype, D: Device<E>, T: crate::nn_traits::ZeroGrads<E, D>> crate::nn_traits::ZeroGrads<E, D>
     for Vec<T>
 {
-    fn try_zero_grads(&self, grads: &mut crate::tensor::Gradients<E, D>) -> Result<(), <D>::Err> {
+    fn try_zero_grads(
+        &self,
+        grads: &mut crate::tensor::Gradients<E, D>,
+    ) -> Result<(), crate::tensor::Error> {
         for m_i in self.iter() {
             m_i.try_zero_grads(grads)?;
         }
@@ -82,15 +89,14 @@ impl<Input, T: crate::nn_traits::Module<Input, Output = Input>> crate::nn_traits
     for Vec<T>
 {
     type Output = T::Output;
-    type Error = T::Error;
 
-    fn try_forward(&self, mut x: Input) -> Result<Self::Output, T::Error> {
+    fn try_forward(&self, mut x: Input) -> Result<Self::Output, Error> {
         for m_i in self.iter() {
             x = m_i.try_forward(x)?;
         }
         Ok(x)
     }
-    fn try_forward_mut(&mut self, mut x: Input) -> Result<Self::Output, Self::Error> {
+    fn try_forward_mut(&mut self, mut x: Input) -> Result<Self::Output, Error> {
         for m_i in self.iter_mut() {
             x = m_i.try_forward_mut(x)?;
         }
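As before, `Vec<T>` only implements `Module` when `T::Output = Input`, because the same value is threaded through every element. A standalone sketch of that shape under the same kind of simplified trait (names are illustrative, not dfdx's):

```rust
#[derive(Debug)]
struct Error; // stand-in for the unified error type

trait Module<X> {
    type Output;
    fn try_forward(&self, x: X) -> Result<Self::Output, Error>;
}

struct Scale(f32);

impl Module<f32> for Scale {
    type Output = f32; // Output = Input, as the Vec impl requires
    fn try_forward(&self, x: f32) -> Result<f32, Error> { Ok(x * self.0) }
}

// Mirrors the Vec<T> impl above: thread one value through every element.
impl<X, T: Module<X, Output = X>> Module<X> for Vec<T> {
    type Output = X;
    fn try_forward(&self, mut x: X) -> Result<X, Error> {
        for m in self {
            x = m.try_forward(x)?;
        }
        Ok(x)
    }
}

fn main() {
    let layers = vec![Scale(2.0), Scale(3.0)];
    assert_eq!(layers.try_forward(1.0).unwrap(), 6.0);
}
```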
(Diffs for the remaining 185 changed files not shown.)
