Skip to content

Commit

Permalink
Adding A100 compute. (#2806)
Browse files Browse the repository at this point in the history
  • Loading branch information
Narsil authored Dec 6, 2024
1 parent 5df8059 commit d96dcb1
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions launcher/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,9 @@ struct RawConfig {
vision_config: Option<VisionConfig>,
is_encoder_decoder: Option<bool>,
#[serde(rename = "num_experts_per_tok")]
experts: Option<usize>,
num_experts_per_token: Option<usize>,
#[serde(rename = "n_shared_experts")]
num_shared_experts: Option<usize>,
}

#[derive(Deserialize)]
Expand All @@ -196,7 +198,8 @@ struct Config {
model_type: Option<String>,
vision_config: Option<VisionConfig>,
is_encoder_decoder: bool,
experts: Option<usize>,
num_experts_per_token: usize,
num_shared_experts: usize,
}

impl Config {
Expand All @@ -210,11 +213,9 @@ impl Config {
let num_kv_heads = self.num_kv_heads? as u64;
let head_dim = self.head_dim? as u64;
let hidden_size = self.hidden_size? as u64;
let intermediate_size = if let Some(experts) = self.experts {
(self.intermediate_size? * experts) as u64
} else {
self.intermediate_size? as u64
};
let intermediate_size = (self.intermediate_size?
* (self.num_experts_per_token + self.num_shared_experts))
as u64;
let num_layers = self.num_layers? as u64;

let q_flops = 2 * num_heads * head_dim * hidden_size;
Expand Down Expand Up @@ -257,7 +258,8 @@ impl From<RawConfig> for Config {
let model_type = other.model_type;
let vision_config = other.vision_config;
let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);
let experts = other.experts;
let num_experts_per_token = other.num_experts_per_token.unwrap_or(1);
let num_shared_experts = other.num_shared_experts.unwrap_or(0);
Config {
max_position_embeddings,
quantize,
Expand All @@ -270,7 +272,8 @@ impl From<RawConfig> for Config {
num_kv_heads,
intermediate_size,
num_layers,
experts,
num_experts_per_token,
num_shared_experts,
}
}
}
Expand Down Expand Up @@ -1547,6 +1550,7 @@ impl ComputeType {
// https://www.techpowerup.com/gpu-specs/docs/nvidia-gh100-architecture.pdf
"nvidia-h100-80gb-hbm3" => Some(900 * 10u64.pow(12)),
// https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
"nvidia-a100-sxm4-80gb" => Some(312 * 10u64.pow(12)),
"nvidia-a100" => Some(312 * 10u64.pow(12)),
card => {
tracing::warn!("Unkown compute for card {card}");
Expand Down

0 comments on commit d96dcb1

Please sign in to comment.