diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py
index 08faa51ec..d81e88f99 100644
--- a/sharktank/sharktank/examples/export_paged_llm_v1.py
+++ b/sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -15,11 +15,11 @@
 from sharktank.types import *
 
 # TODO: Should be using a base class with the protocol supported.
-from sharktank.models.llama.llama import LlamaModelConfig, PagedLlamaModelV1
+from ..models.llama.llama import LlamaModelConfig, PagedLlamaModelV1
 
 
 def main():
-    from sharktank.utils import cli
+    from ..utils import cli
 
     parser = cli.create_parser()
     cli.add_input_dataset_options(parser)
@@ -49,7 +49,6 @@ def main():
     dataset_type = cli.get_input_data_files(args)
     dataset_type = "irpa" if "irpa" in dataset_type else "gguf"
     dataset = cli.get_input_dataset(args)
-    print('\n'.join([x for x in dataset.root_theta.flatten() if x.endswith(".weight")]))
     hp = configs.LlamaHParams.from_gguf_props(dataset.properties)
     llama_config = LlamaModelConfig(hp)
 
diff --git a/sharktank/sharktank/examples/paged_llm_v1.py b/sharktank/sharktank/examples/paged_llm_v1.py
index ba4c8adc6..a9d14195c 100644
--- a/sharktank/sharktank/examples/paged_llm_v1.py
+++ b/sharktank/sharktank/examples/paged_llm_v1.py
@@ -15,15 +15,14 @@
 
 import torch
 
-from sharktank.layers import *
-from sharktank.types import *
+from ..layers import *
+from ..types import *
 
 # TODO: Should be using a base class with the protocol supported.
-from sharktank.models.llama.llama import *
-from sharktank.utils.debugging import trace_tensor
-from sharktank.utils.tokenizer import InferenceTokenizer, load_tokenizer
-from sharktank.utils.patching import SaveModuleResultTensorsPatch
-from sharktank.models.punet.tools.sample_data import get_random_inputs, load_inputs, save_outputs
+from ..models.llama.llama import *
+from ..utils.debugging import trace_tensor
+from ..utils.tokenizer import InferenceTokenizer
+
 
 class TorchGenerator:
     """Generator that runs directly on the Torch model."""
@@ -50,15 +49,12 @@ def block_seq_stride(self) -> int:
         return self.model.cache.block_seq_stride
 
     def begin_batch(self, prompts: list[str]):
-        #token_ids, seq_lens = self.tokenizer.encode(
-        #    prompts, pad_to_multiple_of=self.model.cache.pad_sequence_stride
-        #)
-
-        #token_ids = torch.tensor(token_ids, device=self.model.device)
-        #seq_lens = torch.tensor(seq_lens, device=self.model.device)
-        with safe_open("/home/nod/batch.safetensors", framework="pt", device="cpu") as st:
-            token_ids=st.get_tensor("batch").to(device=self.model.device)
-        seq_lens = torch.tensor([2048]).to(device=self.model.device)
+        token_ids, seq_lens = self.tokenizer.encode(
+            prompts, pad_to_multiple_of=self.model.cache.pad_sequence_stride
+        )
+
+        token_ids = torch.tensor(token_ids, device=self.model.device)
+        seq_lens = torch.tensor(seq_lens, device=self.model.device)
         if self.shared_cache_state is not None:
             cache_state = self.shared_cache_state
         else:
@@ -264,23 +260,19 @@ def main():
         intermediates_saver.patch_child_modules(model)
     generator = TorchGenerator(model, tokenizer)
 
-    print(f":: Prompting:")
     for prompt in prompts:
         print(f"    {prompt.encode()}")
-
     batch = generator.begin_batch(prompts)
     print(f":: Prompt tokens: {batch.token_ids}")
     batch.prefill()
 
-    intermediates_saver.save_file("/home/nod/stank.safetensors")
     print(batch.detokenize())
 
     if args.save_intermediates_path:
         intermediates_saver.save_file(
             args.save_intermediates_path + "_prefill.safetensors"
         )
 
-    exit()
     counter = 0
     while not batch.done:
         batch.decode()
diff --git a/sharktank/sharktank/layers/configs/llm_configs.py b/sharktank/sharktank/layers/configs/llm_configs.py
index 5a829c6de..ab548e286 100644
--- a/sharktank/sharktank/layers/configs/llm_configs.py
+++ b/sharktank/sharktank/layers/configs/llm_configs.py
@@ -41,9 +41,6 @@ class LlamaHParams:
     attention_layer_norm_rms_epsilon: float
     attention_head_count_kv: int
 
-    # @staticmethod
-    # def from_hf_props(p: dict[str, Any]):
-
     @staticmethod
     def from_gguf_props(p: dict[str, Any]):
         attention_head_count = _int_prop(p, "llama.attention.head_count")
diff --git a/sharktank/sharktank/layers/kv_cache.py b/sharktank/sharktank/layers/kv_cache.py
index 2bf1a0725..319f2353c 100644
--- a/sharktank/sharktank/layers/kv_cache.py
+++ b/sharktank/sharktank/layers/kv_cache.py
@@ -100,7 +100,6 @@ def __init__(
         self.seq_length = seq_length
         self.device = device
         self.dtype = dtype
-        print("cache dtype = ", dtype)
 
     @property
     def pad_sequence_stride(self) -> int:
diff --git a/sharktank/sharktank/layers/linear.py b/sharktank/sharktank/layers/linear.py
index d5f6d8a3b..926664d18 100644
--- a/sharktank/sharktank/layers/linear.py
+++ b/sharktank/sharktank/layers/linear.py
@@ -7,11 +7,8 @@
 from typing import Optional
 
 import torch
-from safetensors.torch import save_file
-from torch.nn import functional as F
 from .. import ops
 from .base import Theta, ThetaLayer
-from ..types.layout_utils import saturate_cast
 from ..types import (
     DynamicScaledQuantizer,
     QuantizedTensor,
@@ -44,7 +41,7 @@ def __init__(
     ):
         super().__init__(theta)
         self._simulate_native_quant = True
-        self.weight = self.theta_tensor(weight_name)  # .to(device="cuda:0")
+        self.weight = self.theta_tensor(weight_name)
         self.bias = None
         if bias_name in self.theta.keys:
             self.bias = self.theta_tensor(bias_name)
@@ -73,9 +70,9 @@ def forward(self, x):
             # TODO: probably need a way to only do q_input if exporting.
             print("qdq input")
             x = qdq_input.quantize(x).unpack().dequant()
-            # from torch.nn import functional as F
 
         y = ops.linear(x, weight, bias)
+
         # Unconditionally dequantize.
         # TODO: Support a q_output specifier that signals the layer to let
         # the QuantizedTensor escape.
diff --git a/sharktank/sharktank/layers/norm.py b/sharktank/sharktank/layers/norm.py
index 04eafaaad..68818a2d7 100644
--- a/sharktank/sharktank/layers/norm.py
+++ b/sharktank/sharktank/layers/norm.py
@@ -8,7 +8,6 @@
 
 from .. import ops
 from .base import Theta, ThetaLayer
-from safetensors.torch import save_file
 
 
 class RMSNormLayer(ThetaLayer):
@@ -35,9 +34,6 @@ def __init__(
 
     def forward(self, x: torch.Tensor):
         orig_dtype = x.dtype
-        print("norm dtype: ", self.dtype)
-        print("orgi_dtype: ", orig_dtype)
-
         x = x.to(self.dtype)
         norm = ops.rms_norm(x, self.weight, epsilon=self.epsilon)
         # Will automatically upcast to the dtype of the weight, which is
diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py
index 0d7b717ee..ff3d9ed1a 100644
--- a/sharktank/sharktank/models/llama/llama.py
+++ b/sharktank/sharktank/models/llama/llama.py
@@ -8,7 +8,6 @@
 from dataclasses import dataclass
 import math
 
-from safetensors.torch import safe_open, save_file
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -36,7 +35,7 @@ class LlamaModelConfig:
     block_seq_stride: int = 16
 
     # Either "paged" or "direct".
-    kv_cache_type: str = "direct"
+    kv_cache_type: str = "paged"
 
     # The device on which to place intermediate state.
 device: Optional[torch.device] = None
@@ -114,17 +113,15 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
             activation_dtype=config.activation_dtype,
             attention_dtype=config.attention_dtype,
         )
-        self.hf = False
         self.config = config
         self.hp = hp
         self.cache = config.create_kv_cache()
         self.activation_dtype = config.activation_dtype
         self.use_hf = config.use_hf
 
-        key = "token_embd"
         self.add_module(
             "token_embedding",
-            TokenEmbeddingLayer(theta(key), dtype=config.activation_dtype),
+            TokenEmbeddingLayer(theta("token_embd"), dtype=config.activation_dtype),
         )
         self.add_module(
             "attention_embedding",
diff --git a/sharktank/sharktank/models/llama/tools/import_quark_dataset.py b/sharktank/sharktank/models/llama/tools/import_quark_dataset.py
index 3ef763779..fcf38571b 100644
--- a/sharktank/sharktank/models/llama/tools/import_quark_dataset.py
+++ b/sharktank/sharktank/models/llama/tools/import_quark_dataset.py
@@ -4,20 +4,14 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-"""Imports Brevitas pre-processed weights and quantization config into a
-Dataset.
+"""Imports quark pre-processed weights and quantization config into a
+Dataset of the gguf format.
 
 Usage:
-  python -m sharktank.models.punet.import_hf_dataset \
-    --output-irpa-file ~/models/punet/punet_fp16.irpa \
-    --config-json ~/models/stable-diffusion-xl-base-1.0/unet/config.json
+  python -m sharktank.models.llama.tools.import_quark_dataset \
+    --params=llama2-7b-fp8.safetensors --output-irpa-file=new.irpa \
+    --config-json=../llama2/config.json
 
-The resulting dataset has all tensors as nested in the original model.
-Properties are separated into a "meta" dict (for "_" prefixed props) and an
-"hparams" dict.
-
-Default flag values assume that there is a quant_param.json and
-params.safetensors adjacent to the HF config.json file.
 """
 
 from typing import Optional
diff --git a/sharktank/sharktank/ops/default_impls.py b/sharktank/sharktank/ops/default_impls.py
index d8a582035..a61f35591 100644
--- a/sharktank/sharktank/ops/default_impls.py
+++ b/sharktank/sharktank/ops/default_impls.py
@@ -182,10 +182,11 @@ def scaled_dot_product_attention(q, k, v, a) -> Tensor:
 @rms_norm.override(Tensor, Tensor)
 def rms_norm_default(x, weight, *, epsilon: float) -> Tensor:
     x = unbox_tensor(x)
-    weight = unbox_tensor(weight).to(device=x.device)
+    weight = unbox_tensor(weight)
     variance = x.pow(2).mean(-1, keepdim=True)
     output = x * torch.rsqrt(variance + epsilon)
-    output = weight * output.to(torch.float16)
+    # The cast here is to match the hf implementation, affects numerics
+    output = weight * output.to(weight.dtype)
     return output
 
 
diff --git a/sharktank/sharktank/types/quantizers.py b/sharktank/sharktank/types/quantizers.py
index bd976c7dc..6dcc35140 100644
--- a/sharktank/sharktank/types/quantizers.py
+++ b/sharktank/sharktank/types/quantizers.py
@@ -139,8 +139,7 @@ def _quantize_raw_tensor(self, t: torch.Tensor, *, name: str) -> QuantizedTensor
         if axis is None:
             # Per tensor.
             if offset is None:
-                print(self._scale)
-                print(self.dtype)
+                # Changed to t/reciprocal because narrow float types are garbage
                 qs = saturate_cast(
                     t / self._reciprocal_scale,
                     dtype=self.dtype,
@@ -148,7 +147,7 @@ def _quantize_raw_tensor(self, t: torch.Tensor, *, name: str) -> QuantizedTensor
                 )
             else:
                 qs = saturate_cast(
-                    t * self._scale + offset,
+                    t / self._reciprocal_scale + offset,
                     dtype=self.dtype,
                     disable_saturate=self._disable_saturate,
                 )
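
Note on the rms_norm change in default_impls.py above: as a standalone sketch (not part of the patch; the function name and the epsilon default below are illustrative), the overridden op now casts the normalized activations to the weight's dtype before scaling, instead of hard-coding float16.

import torch

def rms_norm_sketch(x: torch.Tensor, weight: torch.Tensor, epsilon: float = 1e-6) -> torch.Tensor:
    # Variance is computed in the activation dtype of x.
    variance = x.pow(2).mean(-1, keepdim=True)
    output = x * torch.rsqrt(variance + epsilon)
    # Cast to the weight dtype before scaling (mirroring the rms_norm_default
    # change above) rather than to a hard-coded torch.float16.
    return weight * output.to(weight.dtype)

With a float32 activation and, say, a bfloat16 weight, the result now comes back in the weight's dtype, which is the numerics difference the patch comment calls out.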