remove some device calls, add some comments
undo some unnecessary changes
dan-garvey committed Aug 28, 2024
1 parent 47e5f44 commit 7542ac4
Showing 10 changed files with 28 additions and 57 deletions.
5 changes: 2 additions & 3 deletions sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -15,11 +15,11 @@
from sharktank.types import *

# TODO: Should be using a base class with the protocol supported.
from sharktank.models.llama.llama import LlamaModelConfig, PagedLlamaModelV1
from ..models.llama.llama import LlamaModelConfig, PagedLlamaModelV1


def main():
from sharktank.utils import cli
from ..utils import cli

parser = cli.create_parser()
cli.add_input_dataset_options(parser)
@@ -49,7 +49,6 @@ def main():
dataset_type = cli.get_input_data_files(args)
dataset_type = "irpa" if "irpa" in dataset_type else "gguf"
dataset = cli.get_input_dataset(args)
print('\n'.join([x for x in dataset.root_theta.flatten() if x.endswith(".weight")]))

hp = configs.LlamaHParams.from_gguf_props(dataset.properties)
llama_config = LlamaModelConfig(hp)
32 changes: 12 additions & 20 deletions sharktank/sharktank/examples/paged_llm_v1.py
@@ -15,15 +15,14 @@

import torch

from sharktank.layers import *
from sharktank.types import *
from ..layers import *
from ..types import *

# TODO: Should be using a base class with the protocol supported.
from sharktank.models.llama.llama import *
from sharktank.utils.debugging import trace_tensor
from sharktank.utils.tokenizer import InferenceTokenizer, load_tokenizer
from sharktank.utils.patching import SaveModuleResultTensorsPatch
from sharktank.models.punet.tools.sample_data import get_random_inputs, load_inputs, save_outputs
from ..models.llama.llama import *
from ..utils.debugging import trace_tensor
from ..utils.tokenizer import InferenceTokenizer


class TorchGenerator:
"""Generator that runs directly on the Torch model."""
@@ -50,15 +49,12 @@ def block_seq_stride(self) -> int:
return self.model.cache.block_seq_stride

def begin_batch(self, prompts: list[str]):
#token_ids, seq_lens = self.tokenizer.encode(
# prompts, pad_to_multiple_of=self.model.cache.pad_sequence_stride
#)

#token_ids = torch.tensor(token_ids, device=self.model.device)
#seq_lens = torch.tensor(seq_lens, device=self.model.device)
with safe_open("/home/nod/batch.safetensors", framework="pt", device="cpu") as st:
token_ids=st.get_tensor("batch").to(device=self.model.device)
seq_lens = torch.tensor([2048]).to(device=self.model.device)
token_ids, seq_lens = self.tokenizer.encode(
prompts, pad_to_multiple_of=self.model.cache.pad_sequence_stride
)

token_ids = torch.tensor(token_ids, device=self.model.device)
seq_lens = torch.tensor(seq_lens, device=self.model.device)
if self.shared_cache_state is not None:
cache_state = self.shared_cache_state
else:
@@ -264,23 +260,19 @@ def main():
intermediates_saver.patch_child_modules(model)
generator = TorchGenerator(model, tokenizer)


print(f":: Prompting:")
for prompt in prompts:
print(f" {prompt.encode()}")


batch = generator.begin_batch(prompts)
print(f":: Prompt tokens: {batch.token_ids}")
batch.prefill()
intermediates_saver.save_file("/home/nod/stank.safetensors")
print(batch.detokenize())

if args.save_intermediates_path:
intermediates_saver.save_file(
args.save_intermediates_path + "_prefill.safetensors"
)
exit()
counter = 0
while not batch.done:
batch.decode()
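For context (a sketch, not part of this commit): with the hard-coded safetensors shortcut removed, begin_batch goes back to tokenizing the prompts, padding to the cache's sequence stride, and placing the tensors on the model device; the generator is then driven through prefill and a decode loop as in the main() hunk above. A minimal usage sketch, assuming an already constructed PagedLlamaModelV1 ("model") and InferenceTokenizer ("tokenizer"):

    generator = TorchGenerator(model, tokenizer)
    batch = generator.begin_batch(["The capital of France is"])
    batch.prefill()                 # one forward pass over the padded prompt tokens
    while not batch.done:
        batch.decode()              # one token per step, reusing the cache state
    print(batch.detokenize())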
3 changes: 0 additions & 3 deletions sharktank/sharktank/layers/configs/llm_configs.py
@@ -41,9 +41,6 @@ class LlamaHParams:
attention_layer_norm_rms_epsilon: float
attention_head_count_kv: int

# @staticmethod
# def from_hf_props(p: dict[str, Any]):

@staticmethod
def from_gguf_props(p: dict[str, Any]):
attention_head_count = _int_prop(p, "llama.attention.head_count")
1 change: 0 additions & 1 deletion sharktank/sharktank/layers/kv_cache.py
@@ -100,7 +100,6 @@ def __init__(
self.seq_length = seq_length
self.device = device
self.dtype = dtype
print("cache dtype = ", dtype)

@property
def pad_sequence_stride(self) -> int:
7 changes: 2 additions & 5 deletions sharktank/sharktank/layers/linear.py
@@ -7,11 +7,8 @@
from typing import Optional

import torch
from safetensors.torch import save_file
from torch.nn import functional as F
from .. import ops
from .base import Theta, ThetaLayer
from ..types.layout_utils import saturate_cast
from ..types import (
DynamicScaledQuantizer,
QuantizedTensor,
@@ -44,7 +41,7 @@ def __init__(
):
super().__init__(theta)
self._simulate_native_quant = True
self.weight = self.theta_tensor(weight_name) # .to(device="cuda:0")
self.weight = self.theta_tensor(weight_name)
self.bias = None
if bias_name in self.theta.keys:
self.bias = self.theta_tensor(bias_name)
@@ -73,9 +70,9 @@ def forward(self, x):
# TODO: probably need a way to only do q_input if exporting.
print("qdq input")
x = qdq_input.quantize(x).unpack().dequant()
# from torch.nn import functional as F

y = ops.linear(x, weight, bias)

# Unconditionally dequantize.
# TODO: Support a q_output specifier that signals the layer to let
# the QuantizedTensor escape.
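For context (a sketch, not part of this commit): the forward path above keeps the quantize → unpack → dequant sequence, so the activation is snapped onto the quantized grid and immediately returned to floating point before ops.linear runs — a QDQ / fake-quant simulation, consistent with the _simulate_native_quant flag above. A self-contained illustration in plain torch, with an explicit scale standing in for the sharktank quantizer objects:

    import torch

    def fake_quant_int8(x: torch.Tensor, scale: float) -> torch.Tensor:
        # Quantize onto the int8 grid, then dequantize back to floats.
        info = torch.iinfo(torch.int8)
        q = torch.clamp(torch.round(x / scale), info.min, info.max)
        return q * scale

    x = torch.randn(4, 16)
    w = torch.randn(8, 16)
    y = torch.nn.functional.linear(fake_quant_int8(x, scale=0.05), w)  # matmul sees quantization error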
4 changes: 0 additions & 4 deletions sharktank/sharktank/layers/norm.py
@@ -8,7 +8,6 @@

from .. import ops
from .base import Theta, ThetaLayer
from safetensors.torch import save_file


class RMSNormLayer(ThetaLayer):
@@ -35,9 +34,6 @@ def __init__(

def forward(self, x: torch.Tensor):
orig_dtype = x.dtype
print("norm dtype: ", self.dtype)
print("orgi_dtype: ", orig_dtype)

x = x.to(self.dtype)
norm = ops.rms_norm(x, self.weight, epsilon=self.epsilon)
# Will automatically upcast to the dtype of the weight, which is
7 changes: 2 additions & 5 deletions sharktank/sharktank/models/llama/llama.py
@@ -8,7 +8,6 @@

from dataclasses import dataclass
import math
from safetensors.torch import safe_open, save_file
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -36,7 +35,7 @@ class LlamaModelConfig:
block_seq_stride: int = 16

# Either "paged" or "direct".
kv_cache_type: str = "direct"
kv_cache_type: str = "paged"

# The device on which to place intermediate state.
device: Optional[torch.device] = None
@@ -114,17 +113,15 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
activation_dtype=config.activation_dtype,
attention_dtype=config.attention_dtype,
)
self.hf = False
self.config = config
self.hp = hp
self.cache = config.create_kv_cache()
self.activation_dtype = config.activation_dtype
self.use_hf = config.use_hf

key = "token_embd"
self.add_module(
"token_embedding",
TokenEmbeddingLayer(theta(key), dtype=config.activation_dtype),
TokenEmbeddingLayer(theta("token_embd"), dtype=config.activation_dtype),
)
self.add_module(
"attention_embedding",
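For context (a sketch, not part of this commit): two things change in this file — the config default flips from the direct KV cache back to the paged cache, and the token embedding is wired from theta("token_embd") directly instead of a local key variable. A minimal construction sketch, assuming hp (LlamaHParams) and theta (the model weights) already exist:

    config = LlamaModelConfig(hp)                           # kv_cache_type now defaults to "paged"
    direct = LlamaModelConfig(hp, kv_cache_type="direct")   # the previous default, still selectable
    model = PagedLlamaModelV1(theta, config)                 # builds the cache via config.create_kv_cache()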
16 changes: 5 additions & 11 deletions sharktank/sharktank/models/llama/tools/import_quark_dataset.py
@@ -4,20 +4,14 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""Imports Brevitas pre-processed weights and quantization config into a
Dataset.
"""Imports quark pre-processed weights and quantization config into a
Dataset of the gguf format.
Usage:
python -m sharktank.models.punet.import_hf_dataset \
--output-irpa-file ~/models/punet/punet_fp16.irpa \
--config-json ~/models/stable-diffusion-xl-base-1.0/unet/config.json
python -m sharktank.models.llama.tools.import_quark_dataset \
--params=llama2-7b-fp8.safetensors --output-irpa-file=new.irpa \
--config-json=../llama2/config.json
The resulting dataset has all tensors as nested in the original model.
Properties are separated into a "meta" dict (for "_" prefixed props) and an
"hparams" dict.
Default flag values assume that there is a quant_param.json and
params.safetensors adjacent to the HF config.json file.
"""
from typing import Optional

5 changes: 3 additions & 2 deletions sharktank/sharktank/ops/default_impls.py
@@ -182,10 +182,11 @@ def scaled_dot_product_attention(q, k, v, a) -> Tensor:
@rms_norm.override(Tensor, Tensor)
def rms_norm_default(x, weight, *, epsilon: float) -> Tensor:
x = unbox_tensor(x)
weight = unbox_tensor(weight).to(device=x.device)
weight = unbox_tensor(weight)
variance = x.pow(2).mean(-1, keepdim=True)
output = x * torch.rsqrt(variance + epsilon)
output = weight * output.to(torch.float16)
# The cast here is to match the hf implementation, affects numerics
output = weight * output.to(weight.dtype)
return output


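For context (a sketch, not part of this commit): the weight is no longer moved to x's device ("remove some device calls"), and the normalized activations are now cast to the weight's dtype before the elementwise scale instead of being hard-cast to float16 — per the new comment, to match the HF implementation's numerics. A self-contained sketch of the resulting computation (a paraphrase under those assumptions, not a copy of either implementation):

    import torch

    def rms_norm(x: torch.Tensor, weight: torch.Tensor, epsilon: float = 1e-6) -> torch.Tensor:
        variance = x.pow(2).mean(-1, keepdim=True)
        output = x * torch.rsqrt(variance + epsilon)
        # Cast to the weight dtype before scaling, so the output dtype follows the weight.
        return weight * output.to(weight.dtype)

    x = torch.randn(2, 8, dtype=torch.float32)
    w = torch.ones(8, dtype=torch.bfloat16)
    y = rms_norm(x, w)      # bfloat16 output, rather than an unconditional float16 cast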
5 changes: 2 additions & 3 deletions sharktank/sharktank/types/quantizers.py
@@ -139,16 +139,15 @@ def _quantize_raw_tensor(self, t: torch.Tensor, *, name: str) -> QuantizedTensor
if axis is None:
# Per tensor.
if offset is None:
print(self._scale)
print(self.dtype)
# Changed to t/reciprocal because narrow float types are garbage
qs = saturate_cast(
t / self._reciprocal_scale,
dtype=self.dtype,
disable_saturate=self._disable_saturate,
)
else:
qs = saturate_cast(
t * self._scale + offset,
t / self._reciprocal_scale + offset,
dtype=self.dtype,
disable_saturate=self._disable_saturate,
)
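For context (a sketch, not part of this commit): per-tensor quantization now divides by the stored reciprocal scale instead of multiplying by the scale, in both the plain and offset branches; per the new comment, this avoids precision problems when scales live in narrow float formats. A self-contained sketch of per-tensor symmetric int8 quantization with a reciprocal scale (a plain clamp-and-cast standing in for saturate_cast):

    import torch
    from typing import Optional

    def quantize_per_tensor(t: torch.Tensor, reciprocal_scale: torch.Tensor,
                            offset: Optional[torch.Tensor] = None,
                            dtype: torch.dtype = torch.int8) -> torch.Tensor:
        qs = t / reciprocal_scale          # same math as t * scale, using the stored reciprocal directly
        if offset is not None:
            qs = qs + offset
        info = torch.iinfo(dtype)
        return torch.clamp(torch.round(qs), info.min, info.max).to(dtype)

    t = torch.randn(16)
    rscale = t.abs().max() / 127           # reciprocal of the per-tensor scale
    q = quantize_per_tensor(t, rscale)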
