Skip to content

Commit

Permalink
Added {load/read/save/write}_safetensor_with methods
Browse files Browse the repository at this point in the history
This alternative method:
- Requires load/read to decide whether it should skip missing tensors;
- Requires load/read/save/write to decide how should keys be mapped.
  • Loading branch information
swfsql committed Mar 1, 2024
1 parent d971e90 commit 4b9824e
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 36 deletions.
75 changes: 61 additions & 14 deletions dfdx-core/src/nn_traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,13 @@ pub trait ZeroGrads<E: Dtype, D: Device<E>> {
#[cfg(feature = "safetensors")]
/// Something that can be saved to a .safetensors file.
pub trait SaveSafeTensors {
fn save_safetensors<P: AsRef<std::path::Path>>(
fn save_safetensors_with<P: AsRef<std::path::Path>, F: FnMut(String) -> String>(
&self,
path: P,
key_map: &mut F,
) -> Result<(), safetensors::SafeTensorError> {
let mut tensors = Vec::new();
self.write_safetensors("", &mut tensors);
self.write_safetensors_with("", &mut tensors, key_map);
let data = tensors.iter().map(|(k, dtype, shape, data)| {
(
k.clone(),
Expand All @@ -131,53 +132,88 @@ pub trait SaveSafeTensors {

safetensors::serialize_to_file(data, &None, path.as_ref())
}
fn write_safetensors(
fn save_safetensors<P: AsRef<std::path::Path>>(
&self,
path: P,
) -> Result<(), safetensors::SafeTensorError> {
self.save_safetensors_with(path, &mut core::convert::identity)
}
fn write_safetensors_with<F: FnMut(String) -> String>(
&self,
location: &str,
tensors: &mut Vec<(String, safetensors::Dtype, Vec<usize>, Vec<u8>)>,
key_map: &mut F,
);
fn write_safetensors(
&self,
location: &str,
tensors: &mut Vec<(String, safetensors::Dtype, Vec<usize>, Vec<u8>)>,
) {
self.write_safetensors_with(location, tensors, &mut core::convert::identity)
}
}

#[cfg(feature = "safetensors")]
/// Something that can be loaded from a .safetensors file.
pub trait LoadSafeTensors {
fn load_safetensors<P: AsRef<std::path::Path>>(
fn load_safetensors_with<P: AsRef<std::path::Path>, F: FnMut(String) -> String>(
&mut self,
path: P,
skip_missing: bool,
key_map: &mut F,
) -> Result<(), safetensors::SafeTensorError> {
let f = std::fs::File::open(path)?;
let buffer = unsafe { memmap2::MmapOptions::new().map(&f)? };
let tensors = safetensors::SafeTensors::deserialize(&buffer)?;
self.read_safetensors("", &tensors)
self.read_safetensors_with("", &tensors, skip_missing, key_map)
}
fn load_safetensors<P: AsRef<std::path::Path>>(
&mut self,
path: P,
) -> Result<(), safetensors::SafeTensorError> {
self.load_safetensors_with(path, false, &mut core::convert::identity)
}

fn read_safetensors(
fn read_safetensors_with<F: FnMut(String) -> String>(
&mut self,
location: &str,
tensors: &safetensors::SafeTensors,
skip_missing: bool,
key_map: &mut F,
) -> Result<(), safetensors::SafeTensorError>;
fn read_safetensors(
&mut self,
location: &str,
tensors: &safetensors::SafeTensors,
) -> Result<(), safetensors::SafeTensorError> {
self.read_safetensors_with(location, tensors, false, &mut core::convert::identity)
}
}

#[cfg(feature = "safetensors")]
impl<S: Shape, E: Dtype, D: Device<E>, T> LoadSafeTensors for Tensor<S, E, D, T> {
fn read_safetensors(
fn read_safetensors_with<F: FnMut(String) -> String>(
&mut self,
location: &str,
tensors: &safetensors::SafeTensors,
skip_missing: bool,
key_map: &mut F,
) -> Result<(), safetensors::SafeTensorError> {
self.load_safetensor(tensors, location)
self.load_safetensor(tensors, location, skip_missing, key_map)
}
}

#[cfg(feature = "safetensors")]
impl<S: Shape, E: Dtype, D: Device<E>, T> SaveSafeTensors for Tensor<S, E, D, T> {
fn write_safetensors(
fn write_safetensors_with<F: FnMut(String) -> String>(
&self,
location: &str,
tensors: &mut Vec<(String, safetensors::Dtype, Vec<usize>, Vec<u8>)>,
key_map: &mut F,
) {
let location = key_map(location.to_string());
tensors.push((
location.to_string(),
location,
<E as crate::dtypes::SafeTensorsDtype>::DTYPE,
self.shape.concrete().into(),
self.as_vec().iter().flat_map(|e| e.to_le_bytes()).collect(),
Expand All @@ -189,15 +225,17 @@ macro_rules! unit_safetensors {
($Ty:ty) => {
#[cfg(feature = "safetensors")]
impl SaveSafeTensors for $Ty {
fn write_safetensors(
fn write_safetensors_with<F: FnMut(String) -> String>(
&self,
location: &str,
tensors: &mut Vec<(String, safetensors::Dtype, Vec<usize>, Vec<u8>)>,
key_map: &mut F,
) {
let location = key_map(location.to_string());
#[allow(unused_imports)]
use crate::dtypes::ToLeBytes;
tensors.push((
location.to_string(),
location,
<$Ty as crate::dtypes::SafeTensorsDtype>::DTYPE,
Vec::new(),
self.to_le_bytes().to_vec(),
Expand All @@ -207,14 +245,23 @@ macro_rules! unit_safetensors {

#[cfg(feature = "safetensors")]
impl LoadSafeTensors for $Ty {
fn read_safetensors(
fn read_safetensors_with<F: FnMut(String) -> String>(
&mut self,
location: &str,
tensors: &safetensors::SafeTensors,
skip_missing: bool,
key_map: &mut F,
) -> Result<(), safetensors::SafeTensorError> {
let location = key_map(location.to_string());
#[allow(unused_imports)]
use crate::dtypes::FromLeBytes;
let view = tensors.tensor(location)?;
let view = match tensors.tensor(&location) {
Ok(ok) => ok,
Err(safetensors::SafeTensorError::TensorNotFound(_name)) if skip_missing => {
return Ok(());
}
Err(e) => return Err(e),
};
*self = Self::from_le_bytes(view.data().try_into().unwrap());
Ok(())
}
Expand Down
17 changes: 13 additions & 4 deletions dfdx-core/src/nn_traits/tuples.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,32 @@ macro_rules! tuple_impls {

#[cfg(feature = "safetensors")]
impl<$($name: crate::nn_traits::SaveSafeTensors, )+> crate::nn_traits::SaveSafeTensors for ($($name,)+) {
fn write_safetensors(
fn write_safetensors_with<F: FnMut(String) -> String>(
&self,
location: &str,
tensors: &mut Vec<(String, safetensors::Dtype, Vec<usize>, Vec<u8>)>,
key_map: &mut F,
) {
$(self.$idx.write_safetensors(&format!("{location}.{}", $idx), tensors);)+
$(
let name = &format!("{location}.{}", $idx);
self.$idx.write_safetensors_with(name, tensors, key_map);
)+
}
}

#[cfg(feature = "safetensors")]
impl<$($name: crate::nn_traits::LoadSafeTensors, )+> crate::nn_traits::LoadSafeTensors for ($($name,)+) {
fn read_safetensors(
fn read_safetensors_with<F: FnMut(String) -> String>(
&mut self,
location: &str,
tensors: &safetensors::SafeTensors,
skip_missing: bool,
key_map: &mut F,
) -> Result<(), safetensors::SafeTensorError> {
$(self.$idx.read_safetensors(&format!("{location}.{}", $idx), tensors)?;)+
$(
let name = &format!("{location}.{}", $idx);
self.$idx.read_safetensors_with(name, tensors, skip_missing, key_map)?;
)+
Ok(())
}
}
Expand Down
13 changes: 9 additions & 4 deletions dfdx-core/src/nn_traits/vecs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,26 +60,31 @@ impl<E: Dtype, D: Device<E>, T: crate::nn_traits::ZeroGrads<E, D>> crate::nn_tra

#[cfg(feature = "safetensors")]
impl<T: crate::nn_traits::SaveSafeTensors> crate::nn_traits::SaveSafeTensors for Vec<T> {
fn write_safetensors(
fn write_safetensors_with<F: FnMut(String) -> String>(
&self,
location: &str,
tensors: &mut Vec<(String, safetensors::Dtype, Vec<usize>, Vec<u8>)>,
key_map: &mut F,
) {
for (i, t) in self.iter().enumerate() {
t.write_safetensors(&format!("{location}.{i}"), tensors);
let name = &format!("{location}.{i}");
t.write_safetensors_with(name, tensors, key_map);
}
}
}

#[cfg(feature = "safetensors")]
impl<T: crate::nn_traits::LoadSafeTensors> crate::nn_traits::LoadSafeTensors for Vec<T> {
fn read_safetensors(
fn read_safetensors_with<F: FnMut(String) -> String>(
&mut self,
location: &str,
tensors: &safetensors::SafeTensors,
skip_missing: bool,
key_map: &mut F,
) -> Result<(), safetensors::SafeTensorError> {
for (i, t) in self.iter_mut().enumerate() {
t.read_safetensors(&format!("{location}.{i}"), tensors)?;
let name = &format!("{location}.{i}");
t.read_safetensors_with(name, tensors, skip_missing, key_map)?;
}
Ok(())
}
Expand Down
13 changes: 11 additions & 2 deletions dfdx-core/src/tensor/safetensors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,21 @@ use std::vec::Vec;

impl<S: Shape, E: Dtype, D: CopySlice<E>, T> Tensor<S, E, D, T> {
/// Loads data from the [SafeTensors] `Storage<E>` with the given `key`
pub fn load_safetensor(
pub fn load_safetensor<F: FnMut(String) -> String>(
&mut self,
tensors: &SafeTensors,
key: &str,
skip_missing: bool,
key_map: &mut F,
) -> Result<(), SafeTensorError> {
let tensor_view = tensors.tensor(key)?;
let key = key_map(key.to_string());
let tensor_view = match tensors.tensor(&key) {
Ok(ok) => ok,
Err(safetensors::SafeTensorError::TensorNotFound(_name)) if skip_missing => {
return Ok(());
}
Err(e) => return Err(e),
};
let v = tensor_view.data();
let num_bytes = std::mem::size_of::<E>();
assert_eq!(
Expand Down
36 changes: 24 additions & 12 deletions dfdx-derives/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,18 +196,21 @@ pub fn custom_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream
let safetensors_impls = if cfg!(feature = "safetensors") {
quote! {
impl #built_impl ::dfdx::nn_traits::SaveSafeTensors for #builder_name #built_ty #built_where {
fn write_safetensors(
fn write_safetensors_with<KeyMap: FnMut(String) -> String>(
&self,
location: &str,
tensors: &mut Vec<(String, ::dfdx::safetensors::Dtype, Vec<usize>, Vec<u8>)>,
key_map: &mut KeyMap,
) {}
}

impl #built_impl ::dfdx::nn_traits::LoadSafeTensors for #builder_name #built_ty #built_where {
fn read_safetensors<'a>(
fn read_safetensors_with<'a, KeyMap: FnMut(String) -> String>(
&mut self,
location: &str,
tensors: &::dfdx::safetensors::SafeTensors<'a>,
skip_missing: bool,
key_map: &mut KeyMap,
) -> Result<(), ::dfdx::safetensors::SafeTensorError> {
Ok(())
}
Expand Down Expand Up @@ -850,9 +853,10 @@ pub fn save_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStre
where_clause
.predicates
.push(parse_quote!(#ty: ::dfdx::nn_traits::SaveSafeTensors));
quote_spanned!(f.span()=>self.#name.write_safetensors(
quote_spanned!(f.span()=>self.#name.write_safetensors_with(
&format!("{location}{}{}", if location.is_empty() { "" } else { "." }, #name_str),
tensors
tensors,
key_map
);)
} else {
Default::default()
Expand All @@ -869,9 +873,10 @@ pub fn save_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStre
where_clause
.predicates
.push(parse_quote!(#ty: ::dfdx::nn_traits::SaveSafeTensors));
quote_spanned!(f.span()=>self.#index.write_safetensors(
quote_spanned!(f.span()=>self.#index.write_safetensors_with(
&format!("{location}{}{}", if location.is_empty() { "" } else { "." }, #index),
tensors
tensors,
key_map
);)
} else {
Default::default()
Expand All @@ -890,10 +895,11 @@ pub fn save_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStre
proc_macro::TokenStream::from(quote! {
// note: SaveSafeTensors definition is already gated by the safetensors feature
impl #impl_generics ::dfdx::nn_traits::SaveSafeTensors for #name #ty_generics #where_clause {
fn write_safetensors(
fn write_safetensors_with<KeyMap: FnMut(String) -> String>(
&self,
location: &str,
tensors: &mut Vec<(String, ::dfdx::safetensors::Dtype, Vec<usize>, Vec<u8>)>,
key_map: &mut KeyMap,
) {
#save_fields
}
Expand All @@ -919,9 +925,11 @@ pub fn load_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStre
where_clause
.predicates
.push(parse_quote!(#ty: ::dfdx::nn_traits::LoadSafeTensors));
quote_spanned!(f.span()=>self.#name.read_safetensors(
quote_spanned!(f.span()=>self.#name.read_safetensors_with(
&format!("{location}{}{}", if location.is_empty() { "" } else { "." }, #name_str),
tensors
tensors,
skip_missing,
key_map
)?;)
} else {
Default::default()
Expand All @@ -937,9 +945,11 @@ pub fn load_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStre
where_clause
.predicates
.push(parse_quote!(#ty: ::dfdx::nn_traits::LoadSafeTensors));
quote_spanned!(f.span()=>self.#index.read_safetensors(
quote_spanned!(f.span()=>self.#index.read_safetensors_with(
&format!("{location}{}{}", if location.is_empty() { "" } else { "." }, #index),
tensors
tensors,
skip_missing,
key_map
)?;)
} else {
Default::default()
Expand All @@ -958,10 +968,12 @@ pub fn load_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStre
proc_macro::TokenStream::from(quote! {
// note: LoadSafeTensors definition is already gated by the safetensors feature
impl #impl_generics ::dfdx::nn_traits::LoadSafeTensors for #name #ty_generics #where_clause {
fn read_safetensors<'a>(
fn read_safetensors_with<'a, KeyMap: FnMut(String) -> String>(
&mut self,
location: &str,
tensors: &::dfdx::safetensors::SafeTensors<'a>,
skip_missing: bool,
key_map: &mut KeyMap,
) -> Result<(), ::dfdx::safetensors::SafeTensorError> {
#load_fields
Ok(())
Expand Down

0 comments on commit 4b9824e

Please sign in to comment.