
Commit 47e5f44
some cleanup
dan-garvey committed Aug 28, 2024
1 parent 7acf9d6 commit 47e5f44
Showing 4 changed files with 23 additions and 127 deletions.
25 changes: 2 additions & 23 deletions sharktank/sharktank/layers/configs/llm_configs.py
@@ -41,32 +41,11 @@ class LlamaHParams:
attention_layer_norm_rms_epsilon: float
attention_head_count_kv: int

# @staticmethod
# def from_hf_props(p: dict[str, Any]):

# @staticmethod
# def from_hf_props(p: dict[str, Any]):

@staticmethod
def from_gguf_props(p: dict[str, Any]):
if "hparams" in p:
hp = p["hparams"]
print("num_attention_heads: ", _int_prop(hp, "num_attention_heads"))
attention_head_count=_int_prop(hp, "num_attention_heads")
print("head_count_kv: ", _optional_int_prop(hp, "num_key_value_heads", attention_head_count))
attn_head_dim=int(_int_prop(hp, "hidden_size") // _int_prop(hp, "num_attention_heads"))
return LlamaHParams(
context_length=_int_prop(hp, "max_position_embeddings"),
embedding_length=_int_prop(hp, "hidden_size"),
block_count=_int_prop(hp, "num_hidden_layers"),
feed_forward_length=_int_prop(hp, "intermediate_size"),
attn_head_dim=attn_head_dim,
rope_dimension_count=attn_head_dim,
attention_head_count=attention_head_count,
attention_layer_norm_rms_epsilon=_float_prop(hp, "rms_norm_eps"),
attention_head_count_kv=_optional_int_prop(hp, "num_key_value_heads", attention_head_count),
)



attention_head_count = _int_prop(p, "llama.attention.head_count")
return LlamaHParams(
context_length=_int_prop(p, "llama.context_length"),
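For context, the surviving GGUF path in from_gguf_props reads hyperparameters out of a flat property dict through helpers such as _int_prop and _optional_int_prop. A minimal sketch of what helpers with those call signatures might look like (hypothetical; not the actual sharktank implementations):

    from typing import Any


    def _int_prop(p: dict[str, Any], name: str) -> int:
        # Hypothetical helper: fetch a required property and coerce it to int.
        try:
            return int(p[name])
        except KeyError:
            raise KeyError(f"Missing required hyperparameter: {name}") from None


    def _optional_int_prop(p: dict[str, Any], name: str, default: int) -> int:
        # Hypothetical helper: fall back to the supplied default when the key is absent.
        return int(p.get(name, default))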
20 changes: 3 additions & 17 deletions sharktank/sharktank/layers/linear.py
@@ -41,7 +41,6 @@ def __init__(
*,
weight_name: str = "weight",
bias_name: str = "bias",
debug_save_file=None,
):
super().__init__(theta)
self._simulate_native_quant = True
@@ -57,7 +56,6 @@ def __init__(
if self.q_input is not None and self.qdq_input is not None:
raise AssertionError(f"LinearLayer cannot have both q_input and qdq_input")
self.qdq_output: Optional[QuantizedTensor] = theta.optional_tensor("qdq_output")
self.debug_save_file = debug_save_file

def forward(self, x):
weight = self.weight
@@ -72,6 +70,7 @@ def forward(self, x):
if q_input is not None:
x = q_input.quantize(x)
elif qdq_input is not None:
# TODO: probably need a way to only do q_input if exporting.
print("qdq input")
x = qdq_input.quantize(x).unpack().dequant()
# from torch.nn import functional as F
@@ -80,22 +79,9 @@
# Unconditionally dequantize.
# TODO: Support a q_output specifier that signals the layer to let
# the QuantizedTensor escape.
qdq_y = None
if isinstance(y, QuantizedTensor):
y = y.unpack().dequant()
if qdq_output is not None:
qdq_y = qdq_output.quantize(y).unpack().dequant()
if self.debug_save_file != None:
print(f"debug save file: {self.debug_save_file}")
save_file(
{
"qdq_y": qdq_y,
"y": y,
"qdq_i": x,
"input": original_input,
"weight": weight.unpack().dequant(),
},
self.debug_save_file,
)
y = qdq_y if qdq_y is not None else y
# TODO: same as above.
y = qdq_output.quantize(y).unpack().dequant()
return y
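The trimmed forward keeps the fake-quantization round trip on both the qdq_input and qdq_output paths: tensors are quantized and immediately dequantized so quantization error is modeled while the values stay in floating point. A standalone sketch of that round trip (a simplified stand-in, not the QuantizedTensor API used above):

    import torch


    def fake_quant_roundtrip(x: torch.Tensor, scale: float, dtype=torch.int8) -> torch.Tensor:
        # Round and clamp onto the integer quantization grid, then rescale back,
        # mirroring the qdq_input / qdq_output quantize().unpack().dequant() pattern.
        info = torch.iinfo(dtype)
        q = torch.clamp(torch.round(x / scale), info.min, info.max)
        return q * scale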
11 changes: 0 additions & 11 deletions sharktank/sharktank/layers/norm.py
@@ -32,7 +32,6 @@ def __init__(
self.weight = self.theta_tensor(weight_name)
self.epsilon = epsilon
self.dtype = dtype
self.debug_save_file = debug_save_file

def forward(self, x: torch.Tensor):
orig_dtype = x.dtype
@@ -44,14 +43,4 @@ def forward(self, x: torch.Tensor):
# Will automatically upcast to the dtype of the weight, which is
# often in higher precision. Downcast back to expected.
norm = norm.to(orig_dtype)
if self.debug_save_file is not None:
save_file(
{
"input": x,
"variance": torch.tensor(self.epsilon),
"weight": self.weight.as_torch(),
"output": norm,
},
self.debug_save_file,
)
return norm
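With the debug dump gone, forward is a plain RMSNorm: compute the mean of squares over the last dimension, scale by the reciprocal root, multiply by the weight, and cast back to the caller's dtype. A minimal reference sketch of that computation (assuming the standard RMSNorm formula; not the sharktank code itself):

    import torch


    def rms_norm(x: torch.Tensor, weight: torch.Tensor, epsilon: float = 1e-6) -> torch.Tensor:
        # Normalize by the root-mean-square of the last dimension, then scale by weight.
        orig_dtype = x.dtype
        x = x.to(weight.dtype)  # compute in the (often higher precision) weight dtype
        variance = x.pow(2).mean(-1, keepdim=True)
        norm = x * torch.rsqrt(variance + epsilon) * weight
        return norm.to(orig_dtype)  # downcast back to the expected dtype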
94 changes: 18 additions & 76 deletions sharktank/sharktank/models/llama/llama.py
@@ -135,18 +135,17 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
use_hf=self.use_hf,
),
)
key = "output_norm" if "output_norm" in list(theta.keys) else "model.norm"
self.add_module(
"output_norm",
RMSNormLayer(theta(key), epsilon=self.hp.attention_layer_norm_rms_epsilon),
RMSNormLayer(
theta("output_norm"), epsilon=self.hp.attention_layer_norm_rms_epsilon
),
)
key = "output" if "output" in list(theta.keys) else "lm_head"
self.add_module("output_lm_head", LinearLayer(theta(key)))
key = "blk" if "blk" in list(theta.keys) else "model.layers"
self.add_module("output_lm_head", LinearLayer(theta("output")))
self.attn_blocks = nn.ModuleList(
[
PagedLlamaAttentionBlock(
theta(key, n),
theta("blk", n),
block_index=n,
cache=self.cache,
head_count=hp.attention_head_count,
@@ -290,76 +289,19 @@ def __init__(
use_hf: bool = False,
):
super().__init__(theta)
if "input_layernorm" in list(theta.keys):
self.add_module(
"attn_norm",
RMSNormLayer(
theta("input_layernorm"),
epsilon=rms_epsilon,
debug_save_file=f"input_layernorm_{block_index}.safetensors",
),
)
self.add_module(
"attn_q",
LinearLayer(
theta("self_attn.q_proj"),
debug_save_file=f"attn_q_{block_index}.safetensors",
),
)
self.add_module(
"attn_k",
LinearLayer(
theta("self_attn.k_proj"),
debug_save_file=f"attn_k_{block_index}.safetensors",
),
)
self.add_module(
"attn_v",
LinearLayer(
theta("self_attn.v_proj"),
debug_save_file=f"attn_v_{block_index}.safetensors",
),
)
self.add_module(
"attn_output",
LinearLayer(
theta("self_attn.o_proj"),
debug_save_file=f"attn_output_{block_index}.safetensors",
),
)
self.add_module(
"ffn_norm",
RMSNormLayer(theta("post_attention_layernorm"), epsilon=rms_epsilon),
)
self.add_module("ffn_gate", LinearLayer(theta("mlp.gate_proj")))
self.add_module(
"ffn_up",
LinearLayer(
theta("mlp.up_proj"),
debug_save_file=f"ffn_up_{block_index}.safetensors",
),
)
self.add_module(
"ffn_down",
LinearLayer(
theta("mlp.down_proj"),
debug_save_file=f"ffn_down_{block_index}.safetensors",
),
)
else:
self.add_module(
"attn_norm", RMSNormLayer(theta("attn_norm"), epsilon=rms_epsilon)
)
self.add_module("attn_q", LinearLayer(theta("attn_q")))
self.add_module("attn_k", LinearLayer(theta("attn_k")))
self.add_module("attn_v", LinearLayer(theta("attn_v")))
self.add_module("attn_output", LinearLayer(theta("attn_output")))
self.add_module(
"ffn_norm", RMSNormLayer(theta("ffn_norm"), epsilon=rms_epsilon)
)
self.add_module("ffn_gate", LinearLayer(theta("ffn_gate")))
self.add_module("ffn_up", LinearLayer(theta("ffn_up")))
self.add_module("ffn_down", LinearLayer(theta("ffn_down")))
self.add_module(
"attn_norm", RMSNormLayer(theta("attn_norm"), epsilon=rms_epsilon)
)
self.add_module("attn_q", LinearLayer(theta("attn_q")))
self.add_module("attn_k", LinearLayer(theta("attn_k")))
self.add_module("attn_v", LinearLayer(theta("attn_v")))
self.add_module("attn_output", LinearLayer(theta("attn_output")))
self.add_module(
"ffn_norm", RMSNormLayer(theta("ffn_norm"), epsilon=rms_epsilon)
)
self.add_module("ffn_gate", LinearLayer(theta("ffn_gate")))
self.add_module("ffn_up", LinearLayer(theta("ffn_up")))
self.add_module("ffn_down", LinearLayer(theta("ffn_down")))

self.block_index = block_index
self.cache = cache
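The deleted branch distinguished HF-style theta names from the GGUF-style names that the retained construction expects. The correspondence, collected from the diff above into a single mapping (an illustration only, not part of this commit or the sharktank API):

    # HF-style tensor names (left) versus the GGUF-style names (right) used by
    # the retained PagedLlamaAttentionBlock construction. Illustrative only.
    HF_TO_GGUF_KEYS = {
        "input_layernorm": "attn_norm",
        "self_attn.q_proj": "attn_q",
        "self_attn.k_proj": "attn_k",
        "self_attn.v_proj": "attn_v",
        "self_attn.o_proj": "attn_output",
        "post_attention_layernorm": "ffn_norm",
        "mlp.gate_proj": "ffn_gate",
        "mlp.up_proj": "ffn_up",
        "mlp.down_proj": "ffn_down",
    }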
