diff --git a/dfdx-core/src/nn_traits/mod.rs b/dfdx-core/src/nn_traits/mod.rs index 20c55da2..52203373 100644 --- a/dfdx-core/src/nn_traits/mod.rs +++ b/dfdx-core/src/nn_traits/mod.rs @@ -116,12 +116,13 @@ pub trait ZeroGrads> { #[cfg(feature = "safetensors")] /// Something that can be saved to a .safetensors file. pub trait SaveSafeTensors { - fn save_safetensors>( + fn save_safetensors_with, F: FnMut(String) -> String>( &self, path: P, + key_map: &mut F, ) -> Result<(), safetensors::SafeTensorError> { let mut tensors = Vec::new(); - self.write_safetensors("", &mut tensors); + self.write_safetensors_with("", &mut tensors, key_map); let data = tensors.iter().map(|(k, dtype, shape, data)| { ( k.clone(), @@ -131,53 +132,88 @@ pub trait SaveSafeTensors { safetensors::serialize_to_file(data, &None, path.as_ref()) } - fn write_safetensors( + fn save_safetensors>( + &self, + path: P, + ) -> Result<(), safetensors::SafeTensorError> { + self.save_safetensors_with(path, &mut core::convert::identity) + } + fn write_safetensors_with String>( &self, location: &str, tensors: &mut Vec<(String, safetensors::Dtype, Vec, Vec)>, + key_map: &mut F, ); + fn write_safetensors( + &self, + location: &str, + tensors: &mut Vec<(String, safetensors::Dtype, Vec, Vec)>, + ) { + self.write_safetensors_with(location, tensors, &mut core::convert::identity) + } } #[cfg(feature = "safetensors")] /// Something that can be loaded from a .safetensors file. pub trait LoadSafeTensors { - fn load_safetensors>( + fn load_safetensors_with, F: FnMut(String) -> String>( &mut self, path: P, + skip_missing: bool, + key_map: &mut F, ) -> Result<(), safetensors::SafeTensorError> { let f = std::fs::File::open(path)?; let buffer = unsafe { memmap2::MmapOptions::new().map(&f)? }; let tensors = safetensors::SafeTensors::deserialize(&buffer)?; - self.read_safetensors("", &tensors) + self.read_safetensors_with("", &tensors, skip_missing, key_map) + } + fn load_safetensors>( + &mut self, + path: P, + ) -> Result<(), safetensors::SafeTensorError> { + self.load_safetensors_with(path, false, &mut core::convert::identity) } - fn read_safetensors( + fn read_safetensors_with String>( &mut self, location: &str, tensors: &safetensors::SafeTensors, + skip_missing: bool, + key_map: &mut F, ) -> Result<(), safetensors::SafeTensorError>; + fn read_safetensors( + &mut self, + location: &str, + tensors: &safetensors::SafeTensors, + ) -> Result<(), safetensors::SafeTensorError> { + self.read_safetensors_with(location, tensors, false, &mut core::convert::identity) + } } #[cfg(feature = "safetensors")] impl, T> LoadSafeTensors for Tensor { - fn read_safetensors( + fn read_safetensors_with String>( &mut self, location: &str, tensors: &safetensors::SafeTensors, + skip_missing: bool, + key_map: &mut F, ) -> Result<(), safetensors::SafeTensorError> { - self.load_safetensor(tensors, location) + self.load_safetensor(tensors, location, skip_missing, key_map) } } #[cfg(feature = "safetensors")] impl, T> SaveSafeTensors for Tensor { - fn write_safetensors( + fn write_safetensors_with String>( &self, location: &str, tensors: &mut Vec<(String, safetensors::Dtype, Vec, Vec)>, + key_map: &mut F, ) { + let location = key_map(location.to_string()); tensors.push(( - location.to_string(), + location, ::DTYPE, self.shape.concrete().into(), self.as_vec().iter().flat_map(|e| e.to_le_bytes()).collect(), @@ -189,15 +225,17 @@ macro_rules! unit_safetensors { ($Ty:ty) => { #[cfg(feature = "safetensors")] impl SaveSafeTensors for $Ty { - fn write_safetensors( + fn write_safetensors_with String>( &self, location: &str, tensors: &mut Vec<(String, safetensors::Dtype, Vec, Vec)>, + key_map: &mut F, ) { + let location = key_map(location.to_string()); #[allow(unused_imports)] use crate::dtypes::ToLeBytes; tensors.push(( - location.to_string(), + location, <$Ty as crate::dtypes::SafeTensorsDtype>::DTYPE, Vec::new(), self.to_le_bytes().to_vec(), @@ -207,14 +245,23 @@ macro_rules! unit_safetensors { #[cfg(feature = "safetensors")] impl LoadSafeTensors for $Ty { - fn read_safetensors( + fn read_safetensors_with String>( &mut self, location: &str, tensors: &safetensors::SafeTensors, + skip_missing: bool, + key_map: &mut F, ) -> Result<(), safetensors::SafeTensorError> { + let location = key_map(location.to_string()); #[allow(unused_imports)] use crate::dtypes::FromLeBytes; - let view = tensors.tensor(location)?; + let view = match tensors.tensor(&location) { + Ok(ok) => ok, + Err(safetensors::SafeTensorError::TensorNotFound(_name)) if skip_missing => { + return Ok(()); + } + Err(e) => return Err(e), + }; *self = Self::from_le_bytes(view.data().try_into().unwrap()); Ok(()) } diff --git a/dfdx-core/src/nn_traits/tuples.rs b/dfdx-core/src/nn_traits/tuples.rs index 205c0419..7f267482 100644 --- a/dfdx-core/src/nn_traits/tuples.rs +++ b/dfdx-core/src/nn_traits/tuples.rs @@ -20,23 +20,32 @@ macro_rules! tuple_impls { #[cfg(feature = "safetensors")] impl<$($name: crate::nn_traits::SaveSafeTensors, )+> crate::nn_traits::SaveSafeTensors for ($($name,)+) { - fn write_safetensors( + fn write_safetensors_with String>( &self, location: &str, tensors: &mut Vec<(String, safetensors::Dtype, Vec, Vec)>, + key_map: &mut F, ) { - $(self.$idx.write_safetensors(&format!("{location}.{}", $idx), tensors);)+ + $( + let name = &format!("{location}.{}", $idx); + self.$idx.write_safetensors_with(name, tensors, key_map); + )+ } } #[cfg(feature = "safetensors")] impl<$($name: crate::nn_traits::LoadSafeTensors, )+> crate::nn_traits::LoadSafeTensors for ($($name,)+) { - fn read_safetensors( + fn read_safetensors_with String>( &mut self, location: &str, tensors: &safetensors::SafeTensors, + skip_missing: bool, + key_map: &mut F, ) -> Result<(), safetensors::SafeTensorError> { - $(self.$idx.read_safetensors(&format!("{location}.{}", $idx), tensors)?;)+ + $( + let name = &format!("{location}.{}", $idx); + self.$idx.read_safetensors_with(name, tensors, skip_missing, key_map)?; + )+ Ok(()) } } diff --git a/dfdx-core/src/nn_traits/vecs.rs b/dfdx-core/src/nn_traits/vecs.rs index 593b1a55..201dd932 100644 --- a/dfdx-core/src/nn_traits/vecs.rs +++ b/dfdx-core/src/nn_traits/vecs.rs @@ -60,26 +60,31 @@ impl, T: crate::nn_traits::ZeroGrads> crate::nn_tra #[cfg(feature = "safetensors")] impl crate::nn_traits::SaveSafeTensors for Vec { - fn write_safetensors( + fn write_safetensors_with String>( &self, location: &str, tensors: &mut Vec<(String, safetensors::Dtype, Vec, Vec)>, + key_map: &mut F, ) { for (i, t) in self.iter().enumerate() { - t.write_safetensors(&format!("{location}.{i}"), tensors); + let name = &format!("{location}.{i}"); + t.write_safetensors_with(name, tensors, key_map); } } } #[cfg(feature = "safetensors")] impl crate::nn_traits::LoadSafeTensors for Vec { - fn read_safetensors( + fn read_safetensors_with String>( &mut self, location: &str, tensors: &safetensors::SafeTensors, + skip_missing: bool, + key_map: &mut F, ) -> Result<(), safetensors::SafeTensorError> { for (i, t) in self.iter_mut().enumerate() { - t.read_safetensors(&format!("{location}.{i}"), tensors)?; + let name = &format!("{location}.{i}"); + t.read_safetensors_with(name, tensors, skip_missing, key_map)?; } Ok(()) } diff --git a/dfdx-core/src/tensor/safetensors.rs b/dfdx-core/src/tensor/safetensors.rs index c0566c40..626eaeaa 100644 --- a/dfdx-core/src/tensor/safetensors.rs +++ b/dfdx-core/src/tensor/safetensors.rs @@ -5,12 +5,21 @@ use std::vec::Vec; impl, T> Tensor { /// Loads data from the [SafeTensors] `Storage` with the given `key` - pub fn load_safetensor( + pub fn load_safetensor String>( &mut self, tensors: &SafeTensors, key: &str, + skip_missing: bool, + key_map: &mut F, ) -> Result<(), SafeTensorError> { - let tensor_view = tensors.tensor(key)?; + let key = key_map(key.to_string()); + let tensor_view = match tensors.tensor(&key) { + Ok(ok) => ok, + Err(safetensors::SafeTensorError::TensorNotFound(_name)) if skip_missing => { + return Ok(()); + } + Err(e) => return Err(e), + }; let v = tensor_view.data(); let num_bytes = std::mem::size_of::(); assert_eq!( diff --git a/dfdx-derives/src/lib.rs b/dfdx-derives/src/lib.rs index 7af885f9..3c68fcb3 100644 --- a/dfdx-derives/src/lib.rs +++ b/dfdx-derives/src/lib.rs @@ -196,18 +196,21 @@ pub fn custom_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream let safetensors_impls = if cfg!(feature = "safetensors") { quote! { impl #built_impl ::dfdx::nn_traits::SaveSafeTensors for #builder_name #built_ty #built_where { - fn write_safetensors( + fn write_safetensors_with String>( &self, location: &str, tensors: &mut Vec<(String, ::dfdx::safetensors::Dtype, Vec, Vec)>, + key_map: &mut KeyMap, ) {} } impl #built_impl ::dfdx::nn_traits::LoadSafeTensors for #builder_name #built_ty #built_where { - fn read_safetensors<'a>( + fn read_safetensors_with<'a, KeyMap: FnMut(String) -> String>( &mut self, location: &str, tensors: &::dfdx::safetensors::SafeTensors<'a>, + skip_missing: bool, + key_map: &mut KeyMap, ) -> Result<(), ::dfdx::safetensors::SafeTensorError> { Ok(()) } @@ -850,9 +853,10 @@ pub fn save_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStre where_clause .predicates .push(parse_quote!(#ty: ::dfdx::nn_traits::SaveSafeTensors)); - quote_spanned!(f.span()=>self.#name.write_safetensors( + quote_spanned!(f.span()=>self.#name.write_safetensors_with( &format!("{location}{}{}", if location.is_empty() { "" } else { "." }, #name_str), - tensors + tensors, + key_map );) } else { Default::default() @@ -869,9 +873,10 @@ pub fn save_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStre where_clause .predicates .push(parse_quote!(#ty: ::dfdx::nn_traits::SaveSafeTensors)); - quote_spanned!(f.span()=>self.#index.write_safetensors( + quote_spanned!(f.span()=>self.#index.write_safetensors_with( &format!("{location}{}{}", if location.is_empty() { "" } else { "." }, #index), - tensors + tensors, + key_map );) } else { Default::default() @@ -890,10 +895,11 @@ pub fn save_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStre proc_macro::TokenStream::from(quote! { // note: SaveSafeTensors definition is already gated by the safetensors feature impl #impl_generics ::dfdx::nn_traits::SaveSafeTensors for #name #ty_generics #where_clause { - fn write_safetensors( + fn write_safetensors_with String>( &self, location: &str, tensors: &mut Vec<(String, ::dfdx::safetensors::Dtype, Vec, Vec)>, + key_map: &mut KeyMap, ) { #save_fields } @@ -919,9 +925,11 @@ pub fn load_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStre where_clause .predicates .push(parse_quote!(#ty: ::dfdx::nn_traits::LoadSafeTensors)); - quote_spanned!(f.span()=>self.#name.read_safetensors( + quote_spanned!(f.span()=>self.#name.read_safetensors_with( &format!("{location}{}{}", if location.is_empty() { "" } else { "." }, #name_str), - tensors + tensors, + skip_missing, + key_map )?;) } else { Default::default() @@ -937,9 +945,11 @@ pub fn load_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStre where_clause .predicates .push(parse_quote!(#ty: ::dfdx::nn_traits::LoadSafeTensors)); - quote_spanned!(f.span()=>self.#index.read_safetensors( + quote_spanned!(f.span()=>self.#index.read_safetensors_with( &format!("{location}{}{}", if location.is_empty() { "" } else { "." }, #index), - tensors + tensors, + skip_missing, + key_map )?;) } else { Default::default() @@ -958,10 +968,12 @@ pub fn load_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStre proc_macro::TokenStream::from(quote! { // note: LoadSafeTensors definition is already gated by the safetensors feature impl #impl_generics ::dfdx::nn_traits::LoadSafeTensors for #name #ty_generics #where_clause { - fn read_safetensors<'a>( + fn read_safetensors_with<'a, KeyMap: FnMut(String) -> String>( &mut self, location: &str, tensors: &::dfdx::safetensors::SafeTensors<'a>, + skip_missing: bool, + key_map: &mut KeyMap, ) -> Result<(), ::dfdx::safetensors::SafeTensorError> { #load_fields Ok(())