From 62442254d09f88e4bf3d6dacb458e92d4f6f6c1f Mon Sep 17 00:00:00 2001 From: Laurent Date: Thu, 14 Dec 2023 08:07:34 -0600 Subject: [PATCH] Tweaks. --- candle-nn/src/var_builder.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index 0b266a41ee..9d245f12ab 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -535,14 +535,15 @@ impl Backend for ShardedSafeTensors { fn get( &self, - target_shape: Shape, // The size is not checked for ShardedTensors + target_shape: Shape, // The size is only checked when the world size is 1. path: &str, h: Self::Hints, dtype: DType, dev: &Device, ) -> Result { - if h == Default::default() { - // no sharding + if h.world_size == 1 { + // There is no sharding to be applied here so we use the default backend to speed + // things up. return SimpleBackend::get(&self.0, target_shape, path, Default::default(), dtype, dev); }