Skip to content

Commit

Permalink
remove some default values
Browse files Browse the repository at this point in the history
  • Loading branch information
dan-garvey committed Aug 28, 2024
1 parent 7542ac4 commit ed1195a
Showing 1 changed file with 13 additions and 18 deletions.
31 changes: 13 additions & 18 deletions sharktank/sharktank/models/llama/tools/import_quark_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,6 @@
_int_prop,
)

# It is possible to import quant params from stock unet weights for testing.
# Quality won't be great, and SMOOTHQUANT prescaling must be disabled for the
# import to work at all.
IMPORT_SMOOTHQUANT_PRESCALE = False

# Quantizing the bias can produce better fusions but puts more pressure on
# datatype ranges.
QUANTIZE_BIAS = True


def _load_json(p: Path):
print(f"Loading {p}")
Expand Down Expand Up @@ -106,9 +97,9 @@ def apply_per_layer_quant(
root_theta: Theta,
layer_name: str,
updated_tensors: dict[str, InferenceTensor],
n_head=32,
split_sizes=[4096, 4096, 4096],
):
n_head: int,
split_sizes: list[int],
) -> dict[str, InferenceTensor]:

layer_theta = root_theta(layer_name)

Expand Down Expand Up @@ -234,7 +225,6 @@ def quantize_weight(

# Remove the updated tensor from the original tree.
root_theta.pop(layer_name)
return updated_tensors


def convert_hf_hparams_to_gguf(hf_hparams: dict[str, any]) -> dict[str, any]:
Expand Down Expand Up @@ -275,7 +265,6 @@ def update_norm_layer(
name=new_name + ".kv_cache_scaling_factor", data=kv_cache_scale
)
updated_tensors[new_name] = kv_cache_scale
return updated_tensors


def single_replace(
Expand All @@ -286,7 +275,6 @@ def single_replace(
):
data = quant_theta(layer_name).tensor("weight").as_torch()
updated_tensors[gguf_name] = DefaultPrimitiveTensor(name=gguf_name, data=data)
return updated_tensors


def main(argv):
Expand Down Expand Up @@ -329,6 +317,7 @@ def main(argv):

updated_tensors: dict[str, InferenceTensor] = {}
model_layers = [f"model.layers.{i}" for i in range(num_layers)]

sub_layers = [
"mlp.gate_proj",
"mlp.down_proj",
Expand All @@ -339,13 +328,19 @@ def main(argv):
for layer in model_layers:
for sub in sub_layers:
layer_name = layer + "." + sub
updated_tensors = apply_per_layer_quant(
apply_per_layer_quant(
quant_theta, layer_name, updated_tensors, split_sizes=split_sizes
)

# Update the non quantized weights (norm layers)
for layer_idx in model_layers:
updated_tensors = update_norm_layer(quant_theta, layer_idx, updated_tensors)
update_norm_layer(
quant_theta,
layer_idx,
updated_tensors,
head_count=updated_properties["llama.attention.head_count"],
split_sizes=split_sizes,
)

# The stragglers
stragglers = [
Expand All @@ -354,7 +349,7 @@ def main(argv):
("lm_head", "output.weight"),
]
for layer, new_name in stragglers:
updated_tensors = single_replace(quant_theta, layer, new_name, updated_tensors)
single_replace(quant_theta, layer, new_name, updated_tensors)

new_theta = Theta(updated_tensors)
# Make a new Dataset from the updated properties and tensors.
Expand Down

0 comments on commit ed1195a

Please sign in to comment.