Skip to content

Commit

Permalink
Make MEMORY_WIGGLE_ROOM a custom arg (#701)
Browse files Browse the repository at this point in the history
  • Loading branch information
magdyksaleh authored Dec 3, 2024
1 parent c96ff88 commit 5e0b003
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 2 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 12 additions & 0 deletions launcher/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,9 @@ struct Args {

#[clap(long, env)]
disable_sgmv: bool,

#[clap(default_value = "0.8", long, env)]
memory_wiggle_room: f32,
}

#[derive(Debug)]
Expand Down Expand Up @@ -680,6 +683,7 @@ fn shard_manager(
_shutdown_sender: mpsc::Sender<()>,
embedding_dim: Option<usize>,
disable_sgmv: bool,
memory_wiggle_room: f32,
) {
// Enter shard-manager tracing span
let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered();
Expand Down Expand Up @@ -843,6 +847,12 @@ fn shard_manager(
envs.push(("DISABLE_SGMV".into(), "1".into()))
}

// Memory wiggle room
envs.push((
"MEMORY_WIGGLE_ROOM".into(),
memory_wiggle_room.to_string().into(),
));

// Safetensors load fast
envs.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));

Expand Down Expand Up @@ -1303,6 +1313,7 @@ fn spawn_shards(
let backend = args.backend;
let embedding_dim = args.embedding_dim;
let disable_sgmv = args.disable_sgmv;
let memory_wiggle_room = args.memory_wiggle_room;
thread::spawn(move || {
shard_manager(
model_id,
Expand Down Expand Up @@ -1343,6 +1354,7 @@ fn spawn_shards(
shutdown_sender,
embedding_dim,
disable_sgmv,
memory_wiggle_room,
)
});
}
Expand Down
4 changes: 3 additions & 1 deletion router/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,12 @@ itertools = "0.12.1"
async-trait = "0.1.80"
minijinja = { version = "2.2.0", features = ["json"] }
minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
- image = "0.25.1"
+ image = "=0.25.5"
rustls = "0.22.4"
webpki = "0.22.2"
base64 = "0.22.0"
wasm-bindgen = "=0.2.95"
wasm-bindgen-macro = "=0.2.95"

[build-dependencies]
vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
Expand Down
2 changes: 1 addition & 1 deletion server/lorax_server/utils/dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# CUDA memory fraction
MEMORY_FRACTION = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0"))

- MEMORY_WIGGLE_ROOM = float(os.getenv("MEMORY_WIGGLE_ROOM", "0.95"))
+ MEMORY_WIGGLE_ROOM = float(os.getenv("MEMORY_WIGGLE_ROOM", "0.8"))


class FakeBarrier:
Expand Down

0 comments on commit 5e0b003

Please sign in to comment.