From 47e5f44d6683090fd8d2f9c0d082d7253932f463 Mon Sep 17 00:00:00 2001 From: dan Date: Wed, 28 Aug 2024 13:14:45 -0700 Subject: [PATCH] some cleanup --- .../sharktank/layers/configs/llm_configs.py | 25 +---- sharktank/sharktank/layers/linear.py | 20 +--- sharktank/sharktank/layers/norm.py | 11 --- sharktank/sharktank/models/llama/llama.py | 94 ++++--------------- 4 files changed, 23 insertions(+), 127 deletions(-) diff --git a/sharktank/sharktank/layers/configs/llm_configs.py b/sharktank/sharktank/layers/configs/llm_configs.py index e02cf740e..5a829c6de 100644 --- a/sharktank/sharktank/layers/configs/llm_configs.py +++ b/sharktank/sharktank/layers/configs/llm_configs.py @@ -41,32 +41,11 @@ class LlamaHParams: attention_layer_norm_rms_epsilon: float attention_head_count_kv: int -# @staticmethod -# def from_hf_props(p: dict[str, Any]): - + # @staticmethod + # def from_hf_props(p: dict[str, Any]): @staticmethod def from_gguf_props(p: dict[str, Any]): - if "hparams" in p: - hp = p["hparams"] - print("num_attention_heads: ", _int_prop(hp, "num_attention_heads")) - attention_head_count=_int_prop(hp, "num_attention_heads") - print("head_count_kv: ", _optional_int_prop(hp, "num_key_value_heads", attention_head_count)) - attn_head_dim=int(_int_prop(hp, "hidden_size") // _int_prop(hp, "num_attention_heads")) - return LlamaHParams( - context_length=_int_prop(hp, "max_position_embeddings"), - embedding_length=_int_prop(hp, "hidden_size"), - block_count=_int_prop(hp, "num_hidden_layers"), - feed_forward_length=_int_prop(hp, "intermediate_size"), - attn_head_dim=attn_head_dim, - rope_dimension_count=attn_head_dim, - attention_head_count=attention_head_count, - attention_layer_norm_rms_epsilon=_float_prop(hp, "rms_norm_eps"), - attention_head_count_kv=_optional_int_prop(hp, "num_key_value_heads", attention_head_count), - ) - - - attention_head_count = _int_prop(p, "llama.attention.head_count") return LlamaHParams( context_length=_int_prop(p, "llama.context_length"), diff --git a/sharktank/sharktank/layers/linear.py b/sharktank/sharktank/layers/linear.py index 35e213b7c..d5f6d8a3b 100644 --- a/sharktank/sharktank/layers/linear.py +++ b/sharktank/sharktank/layers/linear.py @@ -41,7 +41,6 @@ def __init__( *, weight_name: str = "weight", bias_name: str = "bias", - debug_save_file=None, ): super().__init__(theta) self._simulate_native_quant = True @@ -57,7 +56,6 @@ def __init__( if self.q_input is not None and self.qdq_input is not None: raise AssertionError(f"LinearLayer cannot have both q_input and qdq_input") self.qdq_output: Optional[QuantizedTensor] = theta.optional_tensor("qdq_output") - self.debug_save_file = debug_save_file def forward(self, x): weight = self.weight @@ -72,6 +70,7 @@ def forward(self, x): if q_input is not None: x = q_input.quantize(x) elif qdq_input is not None: + # TODO: probably need a way to only do q_input if exporting. print("qdq input") x = qdq_input.quantize(x).unpack().dequant() # from torch.nn import functional as F @@ -80,22 +79,9 @@ def forward(self, x): # Unconditionally dequantize. # TODO: Support a q_output specifier that signals the layer to let # the QuantizedTensor escape. 
- qdq_y = None if isinstance(y, QuantizedTensor): y = y.unpack().dequant() if qdq_output is not None: - qdq_y = qdq_output.quantize(y).unpack().dequant() - if self.debug_save_file != None: - print(f"debug save file: {self.debug_save_file}") - save_file( - { - "qdq_y": qdq_y, - "y": y, - "qdq_i": x, - "input": original_input, - "weight": weight.unpack().dequant(), - }, - self.debug_save_file, - ) - y = qdq_y if qdq_y is not None else y + # TODO: same as above. + y = qdq_output.quantize(y).unpack().dequant() return y diff --git a/sharktank/sharktank/layers/norm.py b/sharktank/sharktank/layers/norm.py index ee20f318a..04eafaaad 100644 --- a/sharktank/sharktank/layers/norm.py +++ b/sharktank/sharktank/layers/norm.py @@ -32,7 +32,6 @@ def __init__( self.weight = self.theta_tensor(weight_name) self.epsilon = epsilon self.dtype = dtype - self.debug_save_file = debug_save_file def forward(self, x: torch.Tensor): orig_dtype = x.dtype @@ -44,14 +43,4 @@ def forward(self, x: torch.Tensor): # Will automatically upcast to the dtype of the weight, which is # often in higher precision. Downcast back to expected. norm = norm.to(orig_dtype) - if self.debug_save_file is not None: - save_file( - { - "input": x, - "variance": torch.tensor(self.epsilon), - "weight": self.weight.as_torch(), - "output": norm, - }, - self.debug_save_file, - ) return norm diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py index 817d3d51f..0d7b717ee 100644 --- a/sharktank/sharktank/models/llama/llama.py +++ b/sharktank/sharktank/models/llama/llama.py @@ -135,18 +135,17 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): use_hf=self.use_hf, ), ) - key = "output_norm" if "output_norm" in list(theta.keys) else "model.norm" self.add_module( "output_norm", - RMSNormLayer(theta(key), epsilon=self.hp.attention_layer_norm_rms_epsilon), + RMSNormLayer( + theta("output_norm"), epsilon=self.hp.attention_layer_norm_rms_epsilon + ), ) - key = "output" if "output" in list(theta.keys) else "lm_head" - self.add_module("output_lm_head", LinearLayer(theta(key))) - key = "blk" if "blk" in list(theta.keys) else "model.layers" + self.add_module("output_lm_head", LinearLayer(theta("output"))) self.attn_blocks = nn.ModuleList( [ PagedLlamaAttentionBlock( - theta(key, n), + theta("blk", n), block_index=n, cache=self.cache, head_count=hp.attention_head_count, @@ -290,76 +289,19 @@ def __init__( use_hf: bool = False, ): super().__init__(theta) - if "input_layernorm" in list(theta.keys): - self.add_module( - "attn_norm", - RMSNormLayer( - theta("input_layernorm"), - epsilon=rms_epsilon, - debug_save_file=f"input_layernorm_{block_index}.safetensors", - ), - ) - self.add_module( - "attn_q", - LinearLayer( - theta("self_attn.q_proj"), - debug_save_file=f"attn_q_{block_index}.safetensors", - ), - ) - self.add_module( - "attn_k", - LinearLayer( - theta("self_attn.k_proj"), - debug_save_file=f"attn_k_{block_index}.safetensors", - ), - ) - self.add_module( - "attn_v", - LinearLayer( - theta("self_attn.v_proj"), - debug_save_file=f"attn_v_{block_index}.safetensors", - ), - ) - self.add_module( - "attn_output", - LinearLayer( - theta("self_attn.o_proj"), - debug_save_file=f"attn_output_{block_index}.safetensors", - ), - ) - self.add_module( - "ffn_norm", - RMSNormLayer(theta("post_attention_layernorm"), epsilon=rms_epsilon), - ) - self.add_module("ffn_gate", LinearLayer(theta("mlp.gate_proj"))) - self.add_module( - "ffn_up", - LinearLayer( - theta("mlp.up_proj"), - 
debug_save_file=f"ffn_up_{block_index}.safetensors", - ), - ) - self.add_module( - "ffn_down", - LinearLayer( - theta("mlp.down_proj"), - debug_save_file=f"ffn_down_{block_index}.safetensors", - ), - ) - else: - self.add_module( - "attn_norm", RMSNormLayer(theta("attn_norm"), epsilon=rms_epsilon) - ) - self.add_module("attn_q", LinearLayer(theta("attn_q"))) - self.add_module("attn_k", LinearLayer(theta("attn_k"))) - self.add_module("attn_v", LinearLayer(theta("attn_v"))) - self.add_module("attn_output", LinearLayer(theta("attn_output"))) - self.add_module( - "ffn_norm", RMSNormLayer(theta("ffn_norm"), epsilon=rms_epsilon) - ) - self.add_module("ffn_gate", LinearLayer(theta("ffn_gate"))) - self.add_module("ffn_up", LinearLayer(theta("ffn_up"))) - self.add_module("ffn_down", LinearLayer(theta("ffn_down"))) + self.add_module( + "attn_norm", RMSNormLayer(theta("attn_norm"), epsilon=rms_epsilon) + ) + self.add_module("attn_q", LinearLayer(theta("attn_q"))) + self.add_module("attn_k", LinearLayer(theta("attn_k"))) + self.add_module("attn_v", LinearLayer(theta("attn_v"))) + self.add_module("attn_output", LinearLayer(theta("attn_output"))) + self.add_module( + "ffn_norm", RMSNormLayer(theta("ffn_norm"), epsilon=rms_epsilon) + ) + self.add_module("ffn_gate", LinearLayer(theta("ffn_gate"))) + self.add_module("ffn_up", LinearLayer(theta("ffn_up"))) + self.add_module("ffn_down", LinearLayer(theta("ffn_down"))) self.block_index = block_index self.cache = cache