Skip to content

Commit

Permalink
Add --disable-sgmv flag (#639)
Browse files Browse the repository at this point in the history
  • Loading branch information
joseph-predibase authored Oct 16, 2024
1 parent 4c84b4f commit 3818e1a
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions launcher/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,9 @@ struct Args {
/// The embedding dimension to use for the model.
#[clap(long, env)]
embedding_dim: Option<usize>,

/// Disable SGMV.
/// When set, the launcher pushes `DISABLE_SGMV=1` onto the environment
/// entries assembled in `shard_manager`.
#[clap(long, env)]
disable_sgmv: bool,
}

#[derive(Debug)]
Expand Down Expand Up @@ -500,6 +503,7 @@ fn shard_manager(
shutdown: Arc<AtomicBool>,
_shutdown_sender: mpsc::Sender<()>,
embedding_dim: Option<usize>,
disable_sgmv: bool,
) {
// Enter shard-manager tracing span
let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered();
Expand Down Expand Up @@ -640,6 +644,10 @@ fn shard_manager(
envs.push(("FLASH_INFER".into(), "1".into()));
}

// Propagate the `--disable-sgmv` launcher flag as a `DISABLE_SGMV=1` env entry
if disable_sgmv {
envs.push(("DISABLE_SGMV".into(), "1".into()));
}

// Safetensors load fast
envs.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));

Expand Down Expand Up @@ -1088,6 +1096,7 @@ fn spawn_shards(
let merge_adapter_weights = args.merge_adapter_weights;
let backend = args.backend;
let embedding_dim = args.embedding_dim;
let disable_sgmv = args.disable_sgmv;
thread::spawn(move || {
shard_manager(
model_id,
Expand Down Expand Up @@ -1123,6 +1132,7 @@ fn spawn_shards(
shutdown,
shutdown_sender,
embedding_dim,
disable_sgmv,
)
});
}
Expand Down

0 comments on commit 3818e1a

Please sign in to comment.