diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index c5d2d0b7813d1..918a90e584b57 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -23,7 +23,7 @@
 sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
 import gguf
 
-from convert import HfVocab
+from convert import LlamaHfVocab
 
 
 ###### MODEL DEFINITIONS ######
@@ -230,7 +230,7 @@ def _get_part_names(self):
     def _set_vocab_gpt2(self):
         dir_model = self.dir_model
         hparams = self.hparams
-        tokens: list[bytearray] = []
+        tokens: list[str] = []
         toktypes: list[int] = []
 
         from transformers import AutoTokenizer
@@ -243,8 +243,7 @@ def _set_vocab_gpt2(self):
 
         for i in range(vocab_size):
             if i not in reverse_vocab:
-                pad_token = f"[PAD{i}]".encode('utf-8')
-                tokens.append(bytearray(pad_token))
+                tokens.append(f"[PAD{i}]")
                 toktypes.append(gguf.TokenType.USER_DEFINED)
             elif reverse_vocab[i] in added_vocab:
                 tokens.append(reverse_vocab[i])
@@ -266,7 +265,7 @@ def _set_vocab_qwen(self):
         dir_model = self.dir_model
         hparams = self.hparams
-        tokens: list[bytearray] = []
+        tokens: list[str] = []
         toktypes: list[int] = []
 
         from transformers import AutoTokenizer
@@ -291,8 +290,7 @@ def _set_vocab_qwen(self):
 
         for i in range(vocab_size):
             if i not in reverse_vocab:
-                pad_token = f"[PAD{i}]".encode("utf-8")
-                tokens.append(bytearray(pad_token))
+                tokens.append(f"[PAD{i}]")
                 toktypes.append(gguf.TokenType.USER_DEFINED)
             elif reverse_vocab[i] in added_vocab:
                 tokens.append(reverse_vocab[i])
@@ -372,12 +370,8 @@ def _set_vocab_sentencepiece(self):
         special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
         special_vocab.add_to_gguf(self.gguf_writer)
 
-    def _set_vocab_hf(self):
-        path = self.dir_model
-        added_tokens_path = self.dir_model
-        vocab = HfVocab(
-            path, added_tokens_path if added_tokens_path.exists() else None
-        )
+    def _set_vocab_llama_hf(self):
+        vocab = LlamaHfVocab(self.dir_model)
         tokens = []
         scores = []
         toktypes = []
@@ -1099,7 +1093,7 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_file_type(self.ftype)
 
     def set_vocab(self):
-        self._set_vocab_hf()
+        self._set_vocab_llama_hf()
 
     def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor:
         if n_kv_head is not None and n_head != n_kv_head:
@@ -1700,11 +1694,8 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_pooling_type(pooling_type)
 
     def set_vocab(self):
-        path = self.dir_model
-        added_tokens_path = self.dir_model if self.dir_model.exists() else None
-
-        # use huggingface vocab to get all tokens
-        vocab = HfVocab(path, added_tokens_path)
+        vocab = LlamaHfVocab(self.dir_model, ignore_nonllama=True)
         tokens, scores, toktypes = zip(*vocab.all_tokens())
         assert len(tokens) == vocab.vocab_size
         self.vocab_size = vocab.vocab_size
diff --git a/convert-persimmon-to-gguf.py b/convert-persimmon-to-gguf.py
index def210531e27b..ccb99279e20a8 100755
--- a/convert-persimmon-to-gguf.py
+++ b/convert-persimmon-to-gguf.py
@@ -106,12 +106,12 @@ def main():
     tensor_map = gguf.get_tensor_name_map(arch, block_count)
     print(tensor_map)
     for name in tensors.keys():
-        data = tensors[name]
+        data_torch = tensors[name]
         if name.endswith(".self_attention.rotary_emb.inv_freq"):
            continue
-        old_dtype = data.dtype
+        old_dtype = data_torch.dtype
         # TODO: FP16 conversion produces garbage outputs. (Q8_0 does not, so..?)
-        data = data.to(torch.float32).squeeze().numpy()
+        data = data_torch.to(torch.float32).squeeze().numpy()
         new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
         if new_name is None:
             print("Can not map tensor '" + name + "'")
diff --git a/convert.py b/convert.py
index 817cb66123a8f..d3a9ccaf21e61 100755
--- a/convert.py
+++ b/convert.py
@@ -16,13 +16,14 @@
 import signal
 import struct
 import sys
+import textwrap
 import time
 import zipfile
-from abc import ABCMeta, abstractmethod
+from abc import ABC, abstractmethod
 from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
 from dataclasses import dataclass
 from pathlib import Path
-from typing import IO, TYPE_CHECKING, Any, Callable, Iterable, Literal, TypeVar
+from typing import TYPE_CHECKING, Any, Callable, ClassVar, IO, Iterable, Literal, Protocol, TypeVar, runtime_checkable
 
 import numpy as np
 from sentencepiece import SentencePieceProcessor
@@ -43,6 +44,9 @@
 
 DEFAULT_CONCURRENCY = 8
 
+ADDED_TOKENS_FILE = 'added_tokens.json'
+FAST_TOKENIZER_FILE = 'tokenizer.json'
+
 #
 # data types
 #
@@ -188,8 +192,10 @@ def guessed(model: LazyModel) -> Params:
         n_layer = next(i for i in itertools.count() if f"layers.{i}.attention.wq.weight" not in model)
 
         if n_layer < 1:
-            raise Exception("failed to guess 'n_layer'. This model is unknown or unsupported.\n"
-                            "Suggestion: provide 'config.json' of the model in the same directory containing model files.")
+            msg = """\
+                failed to guess 'n_layer'. This model is unknown or unsupported.
+                Suggestion: provide 'config.json' of the model in the same directory containing model files."""
+            raise KeyError(textwrap.dedent(msg))
 
         n_head = n_embd // 128 # guessed
         n_mult = 256           # guessed
@@ -211,7 +217,8 @@ def guessed(model: LazyModel) -> Params:
 
     @staticmethod
     def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params:
-        config = json.load(open(config_path))
+        with open(config_path) as f:
+            config = json.load(f)
 
         rope_scaling_type = f_rope_scale = n_orig_ctx = rope_finetuned = None
         rope_scaling = config.get("rope_scaling")
@@ -233,8 +240,10 @@ def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params:
         elif "max_position_embeddings" in config:
             n_ctx = config["max_position_embeddings"]
         else:
-            raise Exception("failed to guess 'n_ctx'. This model is unknown or unsupported.\n"
-                            "Suggestion: provide 'config.json' of the model in the same directory containing model files.")
+            msg = """\
+                failed to guess 'n_ctx'. This model is unknown or unsupported.
+                Suggestion: provide 'config.json' of the model in the same directory containing model files."""
+            raise KeyError(textwrap.dedent(msg))
 
         n_experts = None
         n_experts_used = None
@@ -265,7 +274,8 @@ def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params:
     # {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": -1}
     @staticmethod
     def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params:
-        config = json.load(open(config_path))
+        with open(config_path) as f:
+            config = json.load(f)
 
         n_experts = None
         n_experts_used = None
@@ -331,47 +341,86 @@ def load(model_plus: ModelPlus) -> Params:
 # vocab
 #
 
-class BpeVocab:
+@runtime_checkable
+class BaseVocab(Protocol):
+    tokenizer_model: ClassVar[str]
+    name: ClassVar[str]
+
+
+class NoVocab(BaseVocab):
+    tokenizer_model = "no_vocab"
+    name = "no_vocab"
+
+    def __repr__(self) -> str:
+        return "<NoVocab for a model without integrated vocabulary>"
+
+
+@runtime_checkable
+class Vocab(BaseVocab, Protocol):
+    vocab_size: int
+    added_tokens_dict: dict[str, int]
+    added_tokens_list: list[str]
+    fname_tokenizer: Path
+
+    def __init__(self, base_path: Path): ...
+    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: ...
+
+
+class BpeVocab(Vocab):
     tokenizer_model = "gpt2"
     name = "bpe"
 
-    def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None:
-        self.bpe_tokenizer = json.loads(open(str(fname_tokenizer), encoding="utf-8").read())
-        if isinstance(self.bpe_tokenizer.get('model'), dict):
-            self.vocab = self.bpe_tokenizer["model"]["vocab"]
-        else:
-            self.vocab = self.bpe_tokenizer
-        added_tokens: dict[str, int]
-        if fname_added_tokens is not None:
-            # FIXME: Verify that added tokens here _cannot_ overlap with the main vocab.
-            added_tokens = json.load(open(fname_added_tokens, encoding="utf-8"))
+    def __init__(self, base_path: Path):
+        added_tokens: dict[str, int] = {}
+
+        if (fname_tokenizer := base_path / 'vocab.json').exists():
+            # "slow" tokenizer
+            with open(fname_tokenizer, encoding="utf-8") as f:
+                self.vocab = json.load(f)
+
+            try:
+                # FIXME: Verify that added tokens here _cannot_ overlap with the main vocab.
+                with open(base_path / ADDED_TOKENS_FILE, encoding="utf-8") as f:
+                    added_tokens = json.load(f)
+            except FileNotFoundError:
+                pass
         else:
-            # Fall back to trying to find the added tokens in tokenizer.json
-            tokenizer_json_file = fname_tokenizer.parent / 'tokenizer.json'
-            if not tokenizer_json_file.is_file():
-                added_tokens = {}
-            else:
-                tokenizer_json = json.load(open(tokenizer_json_file, encoding="utf-8"))
-                added_tokens = dict(
-                    (item['content'], item['id'])
-                    for item in tokenizer_json.get('added_tokens', [])
-                    # Added tokens here can be duplicates of the main vocabulary.
-                    if item['content'] not in self.bpe_tokenizer)
-
-        vocab_size: int = len(self.vocab)
-        expected_ids    = list(range(vocab_size, vocab_size + len(added_tokens)))
-        actual_ids      = sorted(added_tokens.values())
+            # "fast" tokenizer
+            fname_tokenizer = base_path / FAST_TOKENIZER_FILE
+
+            # if this fails, FileNotFoundError propagates to caller
+            with open(fname_tokenizer, encoding="utf-8") as f:
+                tokenizer_json = json.load(f)
+
+            tokenizer_model: dict[str, Any] = tokenizer_json['model']
+            if (
+                tokenizer_model['type'] != 'BPE' or tokenizer_model.get('byte_fallback', False)
+                or tokenizer_json['decoder']['type'] != 'ByteLevel'
+            ):
+                raise FileNotFoundError('Cannot find GPT-2 BPE tokenizer')
+
+            self.vocab = tokenizer_model["vocab"]
+
+            if (added := tokenizer_json.get('added_tokens')) is not None:
+                # Added tokens here can be duplicates of the main vocabulary.
+                added_tokens = {item['content']: item['id']
+                                for item in added
+                                if item['content'] not in self.vocab}
+
+        vocab_size = len(self.vocab)
+        expected_ids = list(range(vocab_size, vocab_size + len(added_tokens)))
+        actual_ids = sorted(added_tokens.values())
         if expected_ids != actual_ids:
             expected_end_id = vocab_size + len(actual_ids) - 1
-            raise Exception(f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range {vocab_size} - {expected_end_id}; got {actual_ids}")
+            raise ValueError(f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range "
+                             f"{vocab_size} - {expected_end_id}; got {actual_ids}")
 
         items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1])
         self.added_tokens_dict = added_tokens
         self.added_tokens_list = [text for (text, idx) in items]
-        self.vocab_size_base: int = vocab_size
-        self.vocab_size: int      = self.vocab_size_base + len(self.added_tokens_list)
+        self.vocab_size_base = vocab_size
+        self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
         self.fname_tokenizer = fname_tokenizer
-        self.fname_added_tokens = fname_added_tokens
 
     def bpe_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
         reverse_vocab = {id: encoded_tok for encoded_tok, id in self.vocab.items()}
@@ -392,19 +441,25 @@ def __repr__(self) -> str:
     def __repr__(self) -> str:
         return f"<BpeVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
 
 
-class SentencePieceVocab:
+class SentencePieceVocab(Vocab):
     tokenizer_model = "llama"
     name = "spm"
 
-    def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None:
-        self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
-        added_tokens: dict[str, int]
-        if fname_added_tokens is not None:
-            added_tokens = json.load(open(fname_added_tokens, encoding="utf-8"))
-        else:
-            added_tokens = {}
+    def __init__(self, base_path: Path):
+        added_tokens: dict[str, int] = {}
+        if (fname_tokenizer := base_path / 'tokenizer.model').exists():
+            # normal location
+            try:
+                with open(base_path / ADDED_TOKENS_FILE, encoding="utf-8") as f:
+                    added_tokens = json.load(f)
+            except FileNotFoundError:
+                pass
+        elif not (fname_tokenizer := base_path.parent / 'tokenizer.model').exists():
+            # not found in alternate location either
+            raise FileNotFoundError('Cannot find tokenizer.model')
 
-        vocab_size: int = self.sentencepiece_tokenizer.vocab_size()
+        self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
+        vocab_size = self.sentencepiece_tokenizer.vocab_size()
 
         new_tokens = {id: piece for piece, id in added_tokens.items() if id >= vocab_size}
         expected_new_ids = list(range(vocab_size, vocab_size + len(new_tokens)))
@@ -414,18 +469,17 @@ def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None:
             raise ValueError(f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}")
 
         # Token pieces that were added to the base vocabulary.
-        self.added_tokens_dict  = added_tokens
+        self.added_tokens_dict = added_tokens
         self.added_tokens_list = [new_tokens[id] for id in actual_new_ids]
         self.vocab_size_base = vocab_size
         self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
         self.fname_tokenizer = fname_tokenizer
-        self.fname_added_tokens = fname_added_tokens
 
     def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
         tokenizer = self.sentencepiece_tokenizer
         for i in range(tokenizer.vocab_size()):
             piece = tokenizer.id_to_piece(i)
-            text: bytes = piece.encode("utf-8")
+            text = piece.encode("utf-8")
             score: float = tokenizer.get_score(i)
 
             toktype = gguf.TokenType.NORMAL
@@ -458,27 +512,42 @@ def __repr__(self) -> str:
     def __repr__(self) -> str:
         return f"<SentencePieceVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
 
 
-class HfVocab:
+class LlamaHfVocab(Vocab):
     tokenizer_model = "llama"
     name = "hfft"
 
-    def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None = None) -> None:
+    def __init__(self, base_path: Path, ignore_nonllama: bool = False):
+        fname_tokenizer = base_path / FAST_TOKENIZER_FILE
+        # if this fails, FileNotFoundError propagates to caller
+        with open(fname_tokenizer, encoding='utf-8') as f:
+            tokenizer_json = json.load(f)
+
+        # pre-check so we know if we need transformers
+        tokenizer_model: dict[str, Any] = tokenizer_json['model']
+        if ignore_nonllama:
+            pass  # workaround incorrect use of this class for WordPiece
+        elif (
+            tokenizer_model['type'] != 'BPE' or not tokenizer_model.get('byte_fallback', False)
+            or tokenizer_json['decoder']['type'] != 'Sequence'
+        ):
+            raise FileNotFoundError('Cannot find Llama BPE tokenizer')
+
         try:
             from transformers import AutoTokenizer
         except ImportError as e:
             raise ImportError(
-                "To use HfVocab, please install the `transformers` package. "
+                "To use LlamaHfVocab, please install the `transformers` package. "
                 "You can install it with `pip install transformers`."
             ) from e
 
-        print("fname_tokenizer:", fname_tokenizer)
         # Allow the tokenizer to default to slow or fast versions.
         # Explicitly set tokenizer to use local paths.
         self.tokenizer = AutoTokenizer.from_pretrained(
-            fname_tokenizer,
-            cache_dir=fname_tokenizer,
+            base_path,
+            cache_dir=base_path,
             local_files_only=True,
         )
+        assert self.tokenizer.is_fast  # assume tokenizer.json is used
 
         # Initialize lists and dictionaries for added tokens
         self.added_tokens_list = []
@@ -506,8 +575,7 @@ def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None = None) -> None:
         self.vocab_size_base = self.tokenizer.vocab_size
         self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
 
-        self.fname_tokenizer = fname_tokenizer
-        self.fname_added_tokens = fname_added_tokens
+        self.fname_tokenizer = fname_tokenizer
 
     def hf_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
         reverse_vocab = {
@@ -559,18 +627,7 @@ def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
         yield from self.added_tokens()
 
     def __repr__(self) -> str:
-        return f"<HfVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
-
-
-class NoVocab:
-    tokenizer_model = "no_vocab"
-    name = "no_vocab"
-
-    def __repr__(self) -> str:
-        return "<NoVocab for a model without integrated vocabulary>"
-
-
-Vocab: TypeAlias = "BpeVocab | SentencePieceVocab | HfVocab | NoVocab"
+        return f"<LlamaHfVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
 
 
 #
@@ -588,7 +645,7 @@ def permute(weights: NDArray, n_head: int, n_head_kv: int) -> NDArray:
             .reshape(weights.shape))
 
 
-class Tensor(metaclass=ABCMeta):
+class Tensor(ABC):
     data_type: DataType
 
     @abstractmethod
@@ -610,7 +667,7 @@ def bf16_to_fp32(bf16_arr: np.ndarray[Any, np.dtype[np.uint16]]) -> NDArray:
 
 
 class UnquantizedTensor(Tensor):
-    def __init__(self, ndarray: NDArray) -> None:
+    def __init__(self, ndarray: NDArray):
         assert isinstance(ndarray, np.ndarray)
         self.ndarray = ndarray
         self.data_type = NUMPY_TYPE_TO_DATA_TYPE[ndarray.dtype]
@@ -689,7 +746,7 @@ class ModelPlus:
     model: LazyModel
     paths: list[Path]  # Where this was read from.
     format: Literal['ggml', 'torch', 'safetensors', 'none']
-    vocab: Vocab | None  # For GGML models (which have vocab built in), the vocab.
+    vocab: BaseVocab | None  # For GGML models (which have vocab built in), the vocab.
 
 
 def merge_sharded(models: list[LazyModel]) -> LazyModel:
@@ -698,7 +755,7 @@ def merge_sharded(models: list[LazyModel]) -> LazyModel:
     names = {name: None for model in models for name in model}
 
     def convert(name: str) -> LazyTensor:
-        lazy_tensors: list[LazyTensor] = [model[name] for model in models]
+        lazy_tensors = [model[name] for model in models]
         if len(lazy_tensors) == 1:
             # only one file; don't go through this procedure since there might
             # be quantized tensors
@@ -719,7 +776,7 @@ def convert(name: str) -> LazyTensor:
 
         def load() -> UnquantizedTensor:
             ndarrays = [load_unquantized(tensor) for tensor in lazy_tensors]
-            concatenated: NDArray = np.concatenate(ndarrays, axis=axis)
+            concatenated = np.concatenate(ndarrays, axis=axis)
             return UnquantizedTensor(concatenated)
         description = 'concatenated[[' + '] | ['.join(lt.description for lt in lazy_tensors) + ']]'
         return LazyTensor(load, concatenated_shape, lazy_tensors[0].data_type, description)
@@ -807,10 +864,10 @@ def persistent_load(self, pid: Any) -> Any:
 
         def load(offset: int, elm_count: int) -> NDArray:
             dtype = data_type.dtype
-            fp = self.zip_file.open(info)
-            fp.seek(offset * dtype.itemsize)
-            size = elm_count * dtype.itemsize
-            data = fp.read(size)
+            with self.zip_file.open(info) as fp:
+                fp.seek(offset * dtype.itemsize)
+                size = elm_count * dtype.itemsize
+                data = fp.read(size)
             assert len(data) == size
             return np.frombuffer(data, dtype)
         description = f'storage data_type={data_type} path-in-zip={filename} path={self.zip_file.filename}'
@@ -831,7 +888,7 @@ def load() -> UnquantizedTensor:
     def rebuild_from_type_v2(func, new_type, args, state):
         return func(*args)
 
-    CLASSES: dict[tuple[str, str], Any] = {
+    CLASSES = {
         # getattr used here as a workaround for mypy not being smart enough to determine
         # the staticmethods have a __func__ attribute.
         ('torch._tensor', '_rebuild_from_type_v2'): getattr(rebuild_from_type_v2, '__func__'),
@@ -890,7 +947,7 @@ def load() -> UnquantizedTensor:
 
 def must_read(fp: IO[bytes], length: int) -> bytes:
     ret = fp.read(length)
     if len(ret) < length:
-        raise Exception("unexpectedly reached end of file")
+        raise EOFError("unexpectedly reached end of file")
     return ret
 
@@ -948,13 +1005,14 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc
             yield result
 
 
-def check_vocab_size(params: Params, vocab: Vocab, pad_vocab: bool = False) -> None:
+def check_vocab_size(params: Params, vocab: BaseVocab, pad_vocab: bool = False) -> None:
     # Handle special case where the model's vocab size is not set
     if params.n_vocab == -1:
         raise ValueError(
-            f"The model's vocab size is set to -1 in params.json. Please update it manually.{f' Maybe {vocab.vocab_size}?' if hasattr(vocab, 'vocab_size') else ''}"
+            "The model's vocab size is set to -1 in params.json. Please update it manually."
+            + (f" Maybe {vocab.vocab_size}?" if isinstance(vocab, Vocab) else ""),
         )
-    if isinstance(vocab, NoVocab):
+    if not isinstance(vocab, Vocab):
         return  # model has no vocab
 
     # Check for a vocab size mismatch
@@ -979,11 +1037,11 @@ def check_vocab_size(params: Params, vocab: Vocab, pad_vocab: bool = False) -> None:
 
     if vocab.vocab_size < params.n_vocab:
         msg += " Add the --pad-vocab option and try again."
-    raise Exception(msg)
+    raise ValueError(msg)
 
 
 class OutputFile:
-    def __init__(self, fname_out: Path, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE) -> None:
+    def __init__(self, fname_out: Path, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE):
         self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess)
 
     def add_meta_arch(self, params: Params) -> None:
@@ -1034,8 +1092,6 @@ def add_meta_arch(self, params: Params) -> None:
             self.gguf.add_file_type(params.ftype)
 
     def extract_vocabulary_from_model(self, vocab: Vocab) -> tuple[list[bytes], list[float], list[gguf.TokenType]]:
-        assert not isinstance(vocab, NoVocab)
-
         tokens = []
         scores = []
         toktypes = []
@@ -1135,7 +1191,7 @@ def maybe_do_quantize(item: tuple[DataType, NDArray]) -> NDArray:
 
     @staticmethod
     def write_all(
-        fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, svocab: gguf.SpecialVocab,
+        fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: BaseVocab, svocab: gguf.SpecialVocab,
         concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
         pad_vocab: bool = False,
     ) -> None:
@@ -1145,11 +1201,11 @@ def write_all(
 
         # meta data
         of.add_meta_arch(params)
-        if isinstance(vocab, NoVocab):
-            of.gguf.add_tokenizer_model(vocab.tokenizer_model)
-        else:
+        if isinstance(vocab, Vocab):
             of.add_meta_vocab(vocab)
             of.add_meta_special_vocab(svocab)
+        else:  # NoVocab
+            of.gguf.add_tokenizer_model(vocab.tokenizer_model)
 
         # tensor info
         for name, lazy_tensor in model.items():
@@ -1176,7 +1232,7 @@ def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileType:
 
     name_to_type = {name: lazy_tensor.data_type for (name, lazy_tensor) in model.items()}
 
-    raise Exception(f"Unexpected combination of types: {name_to_type}")
+    raise ValueError(f"Unexpected combination of types: {name_to_type}")
 
 
 def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyModel:
@@ -1186,7 +1242,7 @@ def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyModel:
 
 
 def convert_model_names(model: LazyModel, params: Params, skip_unknown: bool) -> LazyModel:
     tmap = gguf.TensorNameMap(ARCH, params.n_layer)
-    should_skip: set[gguf.MODEL_TENSOR] = set(gguf.MODEL_TENSOR_SKIP.get(ARCH, []))
+    should_skip = set(gguf.MODEL_TENSOR_SKIP.get(ARCH, []))
 
     tmp = model
@@ -1213,8 +1269,7 @@ def convert_model_names(model: LazyModel, params: Params, skip_unknown: bool) -> LazyModel:
             if skip_unknown:
                 print(f"Unexpected tensor name: {name} - skipping")
                 continue
-            else:
-                raise Exception(f"Unexpected tensor name: {name}. Use --skip-unknown to ignore it (e.g. LLaVA)")
+            raise ValueError(f"Unexpected tensor name: {name}. Use --skip-unknown to ignore it (e.g. LLaVA)")
 
         if tensor_type in should_skip:
             print(f"skipping tensor {name_new}")
@@ -1231,7 +1286,7 @@ def nth_multifile_path(path: Path, n: int) -> Path | None:
     the nth path in the model.
     '''
     # Support the following patterns:
-    patterns: list[tuple[str, str]] = [
+    patterns = [
         # - x.00.pth, x.01.pth, etc.
         (r'\.[0-9]{2}\.pth$', f'.{n:02}.pth'),
         # - x-00001-of-00002.bin, x-00002-of-00002.bin, etc.
@@ -1277,9 +1332,9 @@ def load_some_model(path: Path) -> ModelPlus:
         globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt", "pytorch_model.bin"]
         files = [file for glob in globs for file in path.glob(glob)]
         if not files:
-            raise Exception(f"Can't find model in directory {path}")
+            raise FileNotFoundError(f"Can't find model in directory {path}")
         if len(files) > 1:
-            raise Exception(f"Found multiple models in {path}, not sure which to pick: {files}")
+            raise ValueError(f"Found multiple models in {path}, not sure which to pick: {files}")
         path = files[0]
 
     paths = find_multifile_paths(path)
@@ -1293,36 +1348,14 @@ def load_some_model(path: Path) -> ModelPlus:
 
 
 class VocabFactory:
-    _FILES = {"spm": "tokenizer.model", "bpe": "vocab.json", "hfft": "tokenizer.json"}
+    _VOCAB_CLASSES: list[type[Vocab]] = [SentencePieceVocab, BpeVocab, LlamaHfVocab]
 
     def __init__(self, path: Path):
         self.path = path
-        self.file_paths = self._detect_files()
-        print(f"Found vocab files: {self.file_paths}")
-
-    def _detect_files(self) -> dict[str, Path | None]:
-        def locate(file: str) -> Path | None:
-            if (path := self.path / file).exists():
-                return path
-            if (path := self.path.parent / file).exists():
-                return path
-            return None
-
-        return {vt: locate(f) for vt, f in self._FILES.items()}
-
-    def _select_file(self, vocab_types: list[str]) -> tuple[str, Path]:
-        for vtype in vocab_types:
-            try:
-                path = self.file_paths[vtype]
-            except KeyError:
-                raise ValueError(f"Unsupported vocabulary type {vtype}") from None
-            if path is not None:
-                return vtype, path
-        raise FileNotFoundError(f"Could not find any of {[self._FILES[vt] for vt in vocab_types]}")
 
-    def _create_special_vocab(self, vocab: Vocab, model_parent_path: Path) -> gguf.SpecialVocab:
+    def _create_special_vocab(self, vocab: BaseVocab, model_parent_path: Path) -> gguf.SpecialVocab:
         load_merges = vocab.name == "bpe"
-        n_vocab = vocab.vocab_size if hasattr(vocab, "vocab_size") else None
+        n_vocab = vocab.vocab_size if isinstance(vocab, Vocab) else None
         return gguf.SpecialVocab(
             model_parent_path,
             load_merges=load_merges,
@@ -1331,27 +1364,29 @@ def _create_special_vocab(self, vocab: Vocab, model_parent_path: Path) -> gguf.SpecialVocab:
         )
 
     def _create_vocab_by_path(self, vocab_types: list[str]) -> Vocab:
-        vocab_type, path = self._select_file(vocab_types)
-        print(f"Loading vocab file {path!r}, type {vocab_type!r}")
+        vocab_classes: dict[str, type[Vocab]] = {cls.name: cls for cls in self._VOCAB_CLASSES}
+        selected_vocabs: dict[str, type[Vocab]] = {}
+        for vtype in vocab_types:
+            try:
+                selected_vocabs[vtype] = vocab_classes[vtype]
+            except KeyError:
+                raise ValueError(f"Unsupported vocabulary type {vtype}") from None
 
-        added_tokens_path = path.parent / "added_tokens.json"
-        if vocab_type == "bpe":
-            return BpeVocab(
-                path, added_tokens_path if added_tokens_path.exists() else None
-            )
-        if vocab_type == "spm":
-            return SentencePieceVocab(
-                path, added_tokens_path if added_tokens_path.exists() else None
-            )
-        if vocab_type == "hfft":
-            return HfVocab(
-                path.parent, added_tokens_path if added_tokens_path.exists() else None
-            )
-        raise ValueError(vocab_type)
+        for vtype, cls in selected_vocabs.items():
+            try:
+                vocab = cls(self.path)
+                break
+            except FileNotFoundError:
+                pass  # ignore unavailable tokenizers
+        else:
+            raise FileNotFoundError(f"Could not find a tokenizer matching any of {vocab_types}")
+
+        print(f"Loaded vocab file {vocab.fname_tokenizer!r}, type {vocab.name!r}")
+        return vocab
 
-    def load_vocab(self, vocab_types: list[str], model_parent_path: Path) -> tuple[Vocab, gguf.SpecialVocab]:
-        vocab: Vocab
-        if len(vocab_types) == 1 and "no_vocab" in vocab_types:
+    def load_vocab(self, vocab_types: list[str] | None, model_parent_path: Path) -> tuple[BaseVocab, gguf.SpecialVocab]:
+        vocab: BaseVocab
+        if vocab_types is None:
             vocab = NoVocab()
         else:
             vocab = self._create_vocab_by_path(vocab_types)
@@ -1408,10 +1443,8 @@ def main(args_in: list[str] | None = None) -> None:
     parser.add_argument("--skip-unknown", action="store_true", help="skip unknown tensor names instead of failing")
     args = parser.parse_args(args_in)
 
-    if args.no_vocab:
-        if args.vocab_only:
-            raise ValueError("no need to specify --vocab-only if using --no-vocab")
-        args.vocab_type = "no_vocab"
+    if args.no_vocab and args.vocab_only:
+        raise ValueError("--vocab-only does not make sense with --no-vocab")
 
     if args.dump_single:
         model_plus = lazy_load_file(args.model)
@@ -1433,10 +1466,12 @@ def main(args_in: list[str] | None = None) -> None:
         params = Params.load(model_plus)
         if params.n_ctx == -1:
             if args.ctx is None:
-                raise Exception("The model doesn't have a context size, and you didn't specify one with --ctx\n"
-                                "Please specify one with --ctx:\n"
-                                " - LLaMA v1: --ctx 2048\n"
-                                " - LLaMA v2: --ctx 4096\n")
+                msg = """\
+                    The model doesn't have a context size, and you didn't specify one with --ctx
+                    Please specify one with --ctx:
+                     - LLaMA v1: --ctx 2048
+                     - LLaMA v2: --ctx 4096"""
+                parser.error(textwrap.dedent(msg))
             params.n_ctx = args.ctx
 
     if args.outtype:
@@ -1451,9 +1486,11 @@ def main(args_in: list[str] | None = None) -> None:
     model_parent_path = model_plus.paths[0].parent
     vocab_path = Path(args.vocab_dir or args.model or model_parent_path)
     vocab_factory = VocabFactory(vocab_path)
-    vocab, special_vocab = vocab_factory.load_vocab(args.vocab_type.split(","), model_parent_path)
+    vocab_types = None if args.no_vocab else args.vocab_type.split(",")
+    vocab, special_vocab = vocab_factory.load_vocab(vocab_types, model_parent_path)
 
     if args.vocab_only:
+        assert isinstance(vocab, Vocab)
         if not args.outfile:
             raise ValueError("need --outfile if using --vocab-only")
         outfile = args.outfile
diff --git a/llama.h b/llama.h
index 1fe4af495820f..f061d014ca8eb 100644
--- a/llama.h
+++ b/llama.h
@@ -60,9 +60,9 @@ extern "C" {
 
     enum llama_vocab_type {
         LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab
-        LLAMA_VOCAB_TYPE_SPM  = 1, // SentencePiece
-        LLAMA_VOCAB_TYPE_BPE  = 2, // Byte Pair Encoding
-        LLAMA_VOCAB_TYPE_WPM  = 3, // WordPiece
+        LLAMA_VOCAB_TYPE_SPM  = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback
+        LLAMA_VOCAB_TYPE_BPE  = 2, // GPT-2 tokenizer based on byte-level BPE
+        LLAMA_VOCAB_TYPE_WPM  = 3, // BERT tokenizer based on WordPiece
     };
 
     // note: these values should be synchronized with ggml_rope
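
The Protocol-based hierarchy introduced in convert.py replaces the old `Vocab` TypeAlias union and the `hasattr(vocab, "vocab_size")` checks with `isinstance` checks. A minimal, self-contained sketch of how those runtime checks behave; the `ToyVocab` class is hypothetical and only the attribute names follow the patch:

from typing import ClassVar, Protocol, runtime_checkable


@runtime_checkable
class BaseVocab(Protocol):
    tokenizer_model: ClassVar[str]
    name: ClassVar[str]


@runtime_checkable
class Vocab(BaseVocab, Protocol):
    vocab_size: int


class NoVocab(BaseVocab):
    tokenizer_model = "no_vocab"
    name = "no_vocab"


class ToyVocab(Vocab):
    tokenizer_model = "llama"
    name = "toy"

    def __init__(self) -> None:
        self.vocab_size = 32000


# isinstance() on a runtime_checkable Protocol only checks that the members exist,
# so NoVocab() satisfies BaseVocab but not Vocab (it never sets vocab_size).
print(isinstance(NoVocab(), BaseVocab), isinstance(NoVocab(), Vocab))    # True False
print(isinstance(ToyVocab(), BaseVocab), isinstance(ToyVocab(), Vocab))  # True True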
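
`VocabFactory._create_vocab_by_path` now picks a tokenizer by simply trying each selected class and treating `FileNotFoundError` as "this flavour is not present". A toy version of that for/else fallback loop, with hypothetical stub classes standing in for the real vocab implementations:

from pathlib import Path


class SpmStub:
    name = "spm"

    def __init__(self, base_path: Path):
        if not (base_path / "tokenizer.model").exists():
            raise FileNotFoundError("Cannot find tokenizer.model")


class HfftStub:
    name = "hfft"

    def __init__(self, base_path: Path):
        if not (base_path / "tokenizer.json").exists():
            raise FileNotFoundError("Cannot find tokenizer.json")


def pick_vocab(base_path: Path, candidates=(SpmStub, HfftStub)):
    for cls in candidates:
        try:
            vocab = cls(base_path)
            break  # first tokenizer flavour that loads wins
        except FileNotFoundError:
            pass  # ignore unavailable tokenizers, try the next class
    else:
        raise FileNotFoundError(f"Could not find a tokenizer in {base_path}")
    return vocab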
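
Both `BpeVocab` and `LlamaHfVocab` now decide from `tokenizer.json` alone whether they are looking at their kind of BPE before anything heavier is loaded. A rough standalone illustration of that pre-check, reading the same fields the patch reads; the `classify` helper itself is invented for this sketch:

import json
from pathlib import Path
from typing import Any


def classify(base_path: Path) -> str:
    # tokenizer.json is the HF "fast" tokenizer definition
    with open(base_path / "tokenizer.json", encoding="utf-8") as f:
        tokenizer_json = json.load(f)

    model: dict[str, Any] = tokenizer_json["model"]
    decoder_type = tokenizer_json["decoder"]["type"]

    if model["type"] == "BPE" and not model.get("byte_fallback", False) and decoder_type == "ByteLevel":
        return "gpt2"   # what BpeVocab accepts
    if model["type"] == "BPE" and model.get("byte_fallback", False) and decoder_type == "Sequence":
        return "llama"  # what LlamaHfVocab accepts
    return "other"      # e.g. WordPiece models, handled elsewhere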