Skip to content

Commit

Permalink
Convert HF models with sparse threshold specified
Browse files Browse the repository at this point in the history
  • Loading branch information
Szy0127 authored and hodlen committed Jan 4, 2024
1 parent c2cfbf4 commit e4560a5
Showing 1 changed file with 35 additions and 3 deletions.
38 changes: 35 additions & 3 deletions convert-hf-to-powerinfer-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import torch
import torch.nn as tnn

from dataclasses import dataclass

if TYPE_CHECKING:
from torch import Tensor

Expand Down Expand Up @@ -333,7 +335,7 @@ class LlamaModel(Model):
def set_vocab(self):
self._set_vocab_sentencepiece()

def set_gguf_parameters(self):
def set_gguf_parameters(self, params: PredictorParams):
self.gguf_writer.add_name("Llama")
self.gguf_writer.add_context_length(2048) # not in config.json
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
Expand All @@ -348,6 +350,9 @@ def set_gguf_parameters(self):
self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
self.gguf_writer.add_file_type(self.ftype)

if params.sparse_threshold is not None:
self.gguf_writer.add_sparse_threshold(params.sparse_threshold)

def write_tensors(self):
for name, data_torch in self.get_tensors():
# we don't need these
Expand Down Expand Up @@ -406,7 +411,7 @@ def write_tensors(self):


class FalconModel(Model):
def set_gguf_parameters(self):
def set_gguf_parameters(self, params: PredictorParams):
block_count = self.hparams.get("num_hidden_layers")
if block_count is None:
block_count = self.hparams["n_layer"] # old name
Expand All @@ -430,6 +435,9 @@ def set_gguf_parameters(self):
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
self.gguf_writer.add_file_type(self.ftype)

if params.sparse_threshold is not None:
self.gguf_writer.add_sparse_threshold(params.sparse_threshold)

def write_tensors(self):
n_head = self.hparams.get("num_attention_heads")
if n_head is None:
Expand Down Expand Up @@ -506,6 +514,29 @@ def write_tensors(self):
self.gguf_writer.add_tensor(new_name, data)



@dataclass
class PredictorParams:
sparse_threshold: float | None = None

@staticmethod
def loadPredictorJson(config_path: Path) -> PredictorParams:
config = json.load(open(config_path))
return PredictorParams(
sparse_threshold = config.get("sparse_threshold"),
)

@staticmethod
def load(model_instance: Model) -> PredictorParams:
config_path = model_instance.dir_mlp_pred / "config.json"

if config_path.exists():
params = PredictorParams.loadPredictorJson(config_path)
else:
params = PredictorParams()

return params

###### CONVERSION LOGIC ######


Expand Down Expand Up @@ -581,7 +612,8 @@ def parse_args() -> argparse.Namespace:
)

print("Set model parameters")
model_instance.set_gguf_parameters()
params = PredictorParams.load(model_instance)
model_instance.set_gguf_parameters(params)

print("Set model tokenizer")
model_instance.set_vocab()
Expand Down

0 comments on commit e4560a5

Please sign in to comment.