diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index 0b266a41ee..9d245f12ab 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -535,14 +535,15 @@ impl Backend for ShardedSafeTensors { fn get( &self, - target_shape: Shape, // The size is not checked for ShardedTensors + target_shape: Shape, // The size is only checked when the world size is 1. path: &str, h: Self::Hints, dtype: DType, dev: &Device, ) -> Result { - if h == Default::default() { - // no sharding + if h.world_size == 1 { + // There is no sharding to be applied here so we use the default backend to speed + // things up. return SimpleBackend::get(&self.0, target_shape, path, Default::default(), dtype, dev); }