diff --git a/moshi_mlx/moshi_mlx/models/__init__.py b/moshi_mlx/moshi_mlx/models/__init__.py
index ce8fc8f..0046162 100644
--- a/moshi_mlx/moshi_mlx/models/__init__.py
+++ b/moshi_mlx/moshi_mlx/models/__init__.py
@@ -6,5 +6,5 @@
 Models for EnCodec, AudioGen, MusicGen, as well as the generic LMModel.
 """
 
-from .lm import Lm, LmConfig, config_v0_1, config1b_202412
+from .lm import Lm, LmConfig, config_v0_1, config1b_202412, config_helium_1_preview_2b
 from .generate import LmGen
diff --git a/moshi_mlx/moshi_mlx/models/lm.py b/moshi_mlx/moshi_mlx/models/lm.py
index 0c23197..6941c4c 100644
--- a/moshi_mlx/moshi_mlx/models/lm.py
+++ b/moshi_mlx/moshi_mlx/models/lm.py
@@ -125,9 +125,13 @@ def __init__(self, cfg: LmConfig):
         self.transformer_cache: list[RotatingKVCache] = (
             self.transformer.make_rot_cache()
         )
-        self.depformer_cache: list[KVCache] = self.depformer.slices[
-            0
-        ].transformer.make_cache()
+
+        if len(self.depformer.slices) > 0:
+            self.depformer_cache: list[KVCache] = self.depformer.slices[
+                0
+            ].transformer.make_cache()
+        else:
+            self.depformer_cache = []
 
     def __call__(
         self,
@@ -300,3 +304,42 @@ def config_v0_1() -> LmConfig:
         audio_codebooks=16,
         audio_delays=([0] + [1] * 7) * 2,
     )
+
+
+def config_helium_1_preview_2b() -> LmConfig:
+    transformer = TransformerConfig(
+        d_model=2560,
+        num_heads=20,
+        num_layers=24,
+        dim_feedforward=2560 * 4,  # dim * hidden_scale
+        causal=True,
+        norm_first=True,
+        bias_ff=False,
+        bias_attn=False,
+        layer_scale=None,
+        context=4096,
+        max_period=100000,
+        use_conv_block=False,
+        use_conv_bias=True,
+        cross_attention=False,
+        gating=True,
+        norm="rms_norm",
+        positional_embedding="rope",
+        conv_layout=False,
+        conv_kernel_size=3,
+        kv_repeat=1,
+        max_seq_len=4096,
+    )
+    depformer = DepFormerConfig(
+        transformer=transformer,
+        num_slices=0,
+    )
+    return LmConfig(
+        transformer=transformer,
+        depformer=depformer,
+        audio_vocab_size=2049,
+        text_in_vocab_size=48000,
+        text_out_vocab_size=48000,
+        audio_codebooks=0,
+        audio_delays=[],
+    )
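Aside, not part of the patch: config_helium_1_preview_2b() describes a text-only model (num_slices=0, audio_codebooks=0, empty audio_delays), which is exactly the case the new guard in Lm.__init__ above handles: with no depformer slices there is no slices[0] to build a cache from, so depformer_cache is simply left empty. A minimal sketch of how the config is consumed, mirroring the runner added below; the weight path is a placeholder, not something this hunk defines:

    import mlx.core as mx
    from moshi_mlx import models

    cfg = models.config_helium_1_preview_2b()
    model = models.Lm(cfg)  # depformer_cache ends up [] because num_slices == 0
    model.set_dtype(mx.bfloat16)
    model.load_weights("helium-1-preview-2b-bf16.safetensors", strict=True)  # placeholder path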
diff --git a/moshi_mlx/moshi_mlx/run_helium.py b/moshi_mlx/moshi_mlx/run_helium.py
new file mode 100644
index 0000000..c0a560f
--- /dev/null
+++ b/moshi_mlx/moshi_mlx/run_helium.py
@@ -0,0 +1,62 @@
+# Copyright (c) Kyutai, all rights reserved.
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import sentencepiece
+import huggingface_hub
+import mlx.core as mx
+from moshi_mlx import models, utils
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--tokenizer", type=str)
+    parser.add_argument("--weights", type=str)
+    parser.add_argument("--nsteps", type=int, default=20)
+    parser.add_argument("--hf-repo", type=str, default="kyutai/helium-1-preview-2b-mlx")
+    parser.add_argument("--prompt", type=str, default="Aujourd'hui, il est temps")
+    parser.add_argument("--verbose", action="store_true")
+    args = parser.parse_args()
+
+    weights = args.weights
+    if weights is None:
+        weights = huggingface_hub.hf_hub_download(
+            args.hf_repo, "helium-1-preview-2b-bf16.safetensors"
+        )
+    tokenizer = args.tokenizer
+    if tokenizer is None:
+        tokenizer = huggingface_hub.hf_hub_download(
+            args.hf_repo, "tokenizer_spm_48k_multi6_2.model"
+        )
+
+    mx.random.seed(299792458)
+    lm_config = models.config_helium_1_preview_2b()
+    model = models.Lm(lm_config)
+    model.set_dtype(mx.bfloat16)
+    model.load_weights(weights, strict=True)
+    sampler = utils.Sampler()
+    tokenizer = sentencepiece.SentencePieceProcessor(tokenizer)  # type: ignore
+    if args.verbose:
+        print("prompt", args.prompt)
+    else:
+        print(args.prompt, end="", flush=True)
+    prompt_tokens = tokenizer.encode(args.prompt)  # type: ignore
+    token = mx.array([[1] + prompt_tokens])
+    for step_idx in range(args.nsteps):
+        logits = model(token)
+        token, _ = sampler(logits[:, -1])
+        text_token = token.item()
+        _text = tokenizer.id_to_piece(text_token)  # type: ignore
+        _text = _text.replace("▁", " ")
+        _text = _text.replace("<0x0A>", "\n")
+        if args.verbose:
+            print(step_idx, token, _text)
+        else:
+            print(_text, end="", flush=True)
+        token = token[None]
+    print()
+
+
+if __name__ == "__main__":
+    main()
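Aside, not part of the patch: the decode loop above feeds the whole prompt (prefixed with token id 1, presumably the SentencePiece BOS id) to the model in a single call, then passes exactly one token per step; token = token[None] restores the (1, 1) batch shape expected by the next call, while the caches created in Lm.__init__ carry the context forward. As a hedged illustration only, a greedy variant of the sampling step could look like this (mx.argmax in place of utils.Sampler; names other than mx are hypothetical):

    import mlx.core as mx

    def greedy_step(logits: mx.array) -> mx.array:
        # Hypothetical greedy replacement for `token, _ = sampler(logits[:, -1])` above:
        # pick the most likely token at the last position instead of sampling.
        token = mx.argmax(logits[:, -1], axis=-1)  # shape (1,)
        return token[None]                         # shape (1, 1), ready for the next model call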
diff --git a/rust/moshi-core/src/lm.rs b/rust/moshi-core/src/lm.rs
index 5ee1e20..37bf384 100644
--- a/rust/moshi-core/src/lm.rs
+++ b/rust/moshi-core/src/lm.rs
@@ -588,7 +588,7 @@ pub fn load_streaming<P: AsRef<std::path::Path>>(
     dev: &Device,
 ) -> Result<LmModel> {
     let cfg = Config::v0_1_streaming(8);
-    let is_gguf = model_file.as_ref().extension().map_or(false, |v| v == "gguf");
+    let is_gguf = model_file.as_ref().extension().is_some_and(|v| v == "gguf");
     let lm = if is_gguf {
         let vb =
             candle_transformers::quantized_var_builder::VarBuilder::from_gguf(model_file, dev)?;
@@ -609,7 +609,7 @@ pub fn load_streaming_both_ways<P: AsRef<std::path::Path>>(
     dev: &Device,
 ) -> Result<LmModel> {
     let cfg = Config::v0_1_streaming(16);
-    let is_gguf = model_file.as_ref().extension().map_or(false, |v| v == "gguf");
+    let is_gguf = model_file.as_ref().extension().is_some_and(|v| v == "gguf");
     let lm = if is_gguf {
         let vb =
             candle_transformers::quantized_var_builder::VarBuilder::from_gguf(model_file, dev)?;
diff --git a/scripts/import_helium_mlx.py b/scripts/import_helium_mlx.py
new file mode 100644
index 0000000..d8add32
--- /dev/null
+++ b/scripts/import_helium_mlx.py
@@ -0,0 +1,73 @@
+# Copyright (c) Kyutai, all rights reserved.
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import torch
+from pathlib import Path
+from safetensors import safe_open
+from safetensors.torch import save_file
+from huggingface_hub import hf_hub_download
+
+
+def import_model(in_path: Path, out_path: Path, silent: bool = False) -> None:
+    with safe_open(in_path, framework="pt", device="cpu") as f:
+        tensors = { key: f.get_tensor(key) for key in f.keys() }
+    model = {
+        "text_emb.weight": tensors["model.embed_tokens.weight"],
+        "text_linear.weight": tensors["lm_head.weight"],
+        "out_norm.weight": tensors["model.norm.weight"],
+    }
+    n_layers = -1
+    for key in tensors.keys():
+        if key.startswith("model.layers."):
+            layer_idx = int(key.split(".")[2])
+            n_layers = max(layer_idx, n_layers)
+    n_layers += 1
+    if not silent:
+        print(f"found {n_layers} layers")
+    for layer_idx in range(n_layers):
+        dst_prefix = f"transformer.layers.{layer_idx}."
+        src_prefix = f"model.layers.{layer_idx}."
+        _model = {
+            "norm1.weight": "input_layernorm.weight",
+            "norm2.weight": "post_attention_layernorm.weight",
+            "self_attn.out_proj.weight": "self_attn.o_proj.weight",
+            "gating.linear_out.weight": "mlp.down_proj.weight",
+        }
+        for dst, src in _model.items():
+            model[dst_prefix + dst] = tensors[src_prefix + src]
+        gate_proj = tensors[src_prefix + "mlp.gate_proj.weight"]
+        up_proj = tensors[src_prefix + "mlp.up_proj.weight"]
+        linear_in = torch.cat([gate_proj, up_proj], dim=0)
+        model[dst_prefix + "gating.linear_in.weight"] = linear_in
+        q = tensors[src_prefix + "self_attn.q_proj.weight"]
+        k = tensors[src_prefix + "self_attn.k_proj.weight"]
+        v = tensors[src_prefix + "self_attn.v_proj.weight"]
+        in_proj = torch.cat([q, k, v], dim=0)
+        model[dst_prefix + "self_attn.in_proj.weight"] = in_proj
+
+    save_file(model, out_path)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--checkpoint", type=str, default="kyutai/helium-1-preview-2b", help="the transformers checkpoint to import")
+    parser.add_argument("--out", type=str, help="the mlx safetensors file to generate")
+    parser.add_argument(
+        "-s", "--silent", action="store_true", help="Only prints the checkpoint name"
+    )
+    args = parser.parse_args()
+
+    ckpt_path = Path(args.checkpoint)
+    if not ckpt_path.exists():
+        ckpt_path = hf_hub_download(repo_id=args.checkpoint, filename="model.safetensors")
+    out_path = Path(args.out)
+    if not out_path.exists():
+        import_model(ckpt_path, out_path, silent=args.silent)
+    print(out_path)
+
+
+if __name__ == "__main__":
+    main()
+
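Aside, not part of the patch: import_model() renames the Hugging Face transformer keys into the layout expected by moshi_mlx, fusing q/k/v into self_attn.in_proj.weight and gate/up into gating.linear_in.weight. A quick, hedged sanity check on a converted file; the file name below is just the one run_helium.py downloads by default, not something the import script enforces:

    from safetensors import safe_open

    # Inspect the converted checkpoint and confirm the fused keys written by import_model exist.
    with safe_open("helium-1-preview-2b-bf16.safetensors", framework="pt", device="cpu") as f:
        keys = set(f.keys())
    assert "text_emb.weight" in keys and "text_linear.weight" in keys and "out_norm.weight" in keys
    assert "transformer.layers.0.self_attn.in_proj.weight" in keys   # q, k and v stacked row-wise
    assert "transformer.layers.0.gating.linear_in.weight" in keys    # gate_proj and up_proj stacked row-wise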