Merge pull request #184 from kyutai-labs/helium-mlx-import
Helium inference for MLX
LaurentMazare authored Jan 13, 2025
2 parents 6aba1ae + 29df729 commit 5a5d188
Showing 5 changed files with 184 additions and 6 deletions.
2 changes: 1 addition & 1 deletion moshi_mlx/moshi_mlx/models/__init__.py
@@ -6,5 +6,5 @@
Models for EnCodec, AudioGen, MusicGen, as well as the generic LMModel.
"""

-from .lm import Lm, LmConfig, config_v0_1, config1b_202412
+from .lm import Lm, LmConfig, config_v0_1, config1b_202412, config_helium_1_preview_2b
from .generate import LmGen
49 changes: 46 additions & 3 deletions moshi_mlx/moshi_mlx/models/lm.py
@@ -125,9 +125,13 @@ def __init__(self, cfg: LmConfig):
self.transformer_cache: list[RotatingKVCache] = (
self.transformer.make_rot_cache()
)
-        self.depformer_cache: list[KVCache] = self.depformer.slices[
-            0
-        ].transformer.make_cache()
+
+        if len(self.depformer.slices) > 0:
+            self.depformer_cache: list[KVCache] = self.depformer.slices[
+                0
+            ].transformer.make_cache()
+        else:
+            self.depformer_cache = []

def __call__(
self,
@@ -300,3 +304,42 @@ def config_v0_1() -> LmConfig:
audio_codebooks=16,
audio_delays=([0] + [1] * 7) * 2,
)


def config_helium_1_preview_2b() -> LmConfig:
transformer = TransformerConfig(
d_model=2560,
num_heads=20,
num_layers=24,
dim_feedforward=2560 * 4, # dim * hidden_scale
causal=True,
norm_first=True,
bias_ff=False,
bias_attn=False,
layer_scale=None,
context=4096,
max_period=100000,
use_conv_block=False,
use_conv_bias=True,
cross_attention=False,
gating=True,
norm="rms_norm",
positional_embedding="rope",
conv_layout=False,
conv_kernel_size=3,
kv_repeat=1,
max_seq_len=4096,
)
depformer = DepFormerConfig(
transformer=transformer,
num_slices=0,
)
return LmConfig(
transformer=transformer,
depformer=depformer,
audio_vocab_size=2049,
text_in_vocab_size=48000,
text_out_vocab_size=48000,
audio_codebooks=0,
audio_delays=[],
)
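
The new `config_helium_1_preview_2b` describes a text-only Helium model: no audio codebooks and a depformer with zero slices, which is exactly the case the `depformer_cache` guard above now handles. A minimal sketch of inspecting the config, assuming `moshi_mlx` is installed and that `LmConfig`/`DepFormerConfig` expose their fields as attributes (illustrative only, not part of the diff):

```python
from moshi_mlx import models

cfg = models.config_helium_1_preview_2b()
print(cfg.transformer.d_model)      # 2560
print(cfg.transformer.num_layers)   # 24
print(cfg.depformer.num_slices)     # 0 -> no depformer slices
print(cfg.audio_codebooks)          # 0 -> text-only, no audio streams

# Building models.Lm(cfg) with this config leaves model.depformer_cache == [],
# which is why the new `else` branch above is needed.
```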
62 changes: 62 additions & 0 deletions moshi_mlx/moshi_mlx/run_helium.py
@@ -0,0 +1,62 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import sentencepiece
import huggingface_hub
import mlx.core as mx
from moshi_mlx import models, utils


def main():
parser = argparse.ArgumentParser()
parser.add_argument("--tokenizer", type=str)
parser.add_argument("--weights", type=str)
parser.add_argument("--nsteps", type=int, default=20)
parser.add_argument("--hf-repo", type=str, default="kyutai/helium-1-preview-2b-mlx")
parser.add_argument("--prompt", type=str, default="Aujourd'hui, il est temps")
parser.add_argument("--verbose", action="store_true")
args = parser.parse_args()

weights = args.weights
if weights is None:
weights = huggingface_hub.hf_hub_download(
args.hf_repo, "helium-1-preview-2b-bf16.safetensors"
)
tokenizer = args.tokenizer
if tokenizer is None:
tokenizer = huggingface_hub.hf_hub_download(
args.hf_repo, "tokenizer_spm_48k_multi6_2.model"
)

mx.random.seed(299792458)
lm_config = models.config_helium_1_preview_2b()
model = models.Lm(lm_config)
model.set_dtype(mx.bfloat16)
model.load_weights(weights, strict=True)
sampler = utils.Sampler()
tokenizer = sentencepiece.SentencePieceProcessor(tokenizer) # type: ignore
if args.verbose:
print("prompt", args.prompt)
else:
print(args.prompt, end="", flush=True)
prompt_tokens = tokenizer.encode(args.prompt) # type: ignore
token = mx.array([[1] + prompt_tokens])
for step_idx in range(args.nsteps):
logits = model(token)
token, _ = sampler(logits[:, -1])
text_token = token.item()
_text = tokenizer.id_to_piece(text_token) # type: ignore
_text = _text.replace("▁", " ")
_text = _text.replace("<0x0A>", "\n")
if args.verbose:
print(step_idx, token, _text)
else:
print(_text, end="", flush=True)
token = token[None]
print()


if __name__ == "__main__":
main()
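
Usage note (not part of the diff): when `--weights` and `--tokenizer` are omitted, the script downloads `helium-1-preview-2b-bf16.safetensors` and `tokenizer_spm_48k_multi6_2.model` from the `kyutai/helium-1-preview-2b-mlx` Hugging Face repo, then samples `--nsteps` tokens after the prompt, feeding each sampled token back into the model. An invocation along the lines of `python -m moshi_mlx.run_helium --nsteps 100` should work, though the exact command depends on how the package is installed.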
4 changes: 2 additions & 2 deletions rust/moshi-core/src/lm.rs
@@ -588,7 +588,7 @@ pub fn load_streaming<P: AsRef<std::path::Path>>(
dev: &Device,
) -> Result<LmModel> {
let cfg = Config::v0_1_streaming(8);
-    let is_gguf = model_file.as_ref().extension().map_or(false, |v| v == "gguf");
+    let is_gguf = model_file.as_ref().extension().is_some_and(|v| v == "gguf");
let lm = if is_gguf {
let vb =
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(model_file, dev)?;
@@ -609,7 +609,7 @@ pub fn load_streaming_both_ways<P: AsRef<std::path::Path>>(
dev: &Device,
) -> Result<LmModel> {
let cfg = Config::v0_1_streaming(16);
-    let is_gguf = model_file.as_ref().extension().map_or(false, |v| v == "gguf");
+    let is_gguf = model_file.as_ref().extension().is_some_and(|v| v == "gguf");
let lm = if is_gguf {
let vb =
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(model_file, dev)?;
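
The two Rust changes swap `map_or(false, ..)` on the file-extension `Option` for the equivalent `Option::is_some_and`, a purely stylistic cleanup with no behavior change.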
73 changes: 73 additions & 0 deletions scripts/import_helium_mlx.py
@@ -0,0 +1,73 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import torch
from pathlib import Path
from safetensors import safe_open
from safetensors.torch import save_file
from huggingface_hub import hf_hub_download


def import_model(in_path: Path, out_path: Path, silent: bool = False) -> None:
with safe_open(in_path, framework="pt", device="cpu") as f:
tensors = { key: f.get_tensor(key) for key in f.keys() }
model = {
"text_emb.weight": tensors["model.embed_tokens.weight"],
"text_linear.weight": tensors["lm_head.weight"],
"out_norm.weight": tensors["model.norm.weight"],
}
n_layers = -1
for key in tensors.keys():
if key.startswith("model.layers."):
layer_idx = int(key.split(".")[2])
n_layers = max(layer_idx, n_layers)
n_layers += 1
if not silent:
print(f"found {n_layers} layers")
for layer_idx in range(n_layers):
dst_prefix = f"transformer.layers.{layer_idx}."
src_prefix = f"model.layers.{layer_idx}."
_model = {
"norm1.weight": "input_layernorm.weight",
"norm2.weight": "post_attention_layernorm.weight",
"self_attn.out_proj.weight": "self_attn.o_proj.weight",
"gating.linear_out.weight": "mlp.down_proj.weight",
}
for dst, src in _model.items():
model[dst_prefix + dst] = tensors[src_prefix + src]
gate_proj = tensors[src_prefix + "mlp.gate_proj.weight"]
up_proj = tensors[src_prefix + "mlp.up_proj.weight"]
linear_in = torch.cat([gate_proj, up_proj], dim=0)
model[dst_prefix + "gating.linear_in.weight"] = linear_in
q = tensors[src_prefix + "self_attn.q_proj.weight"]
k = tensors[src_prefix + "self_attn.k_proj.weight"]
v = tensors[src_prefix + "self_attn.v_proj.weight"]
in_proj = torch.cat([q, k, v], dim=0)
model[dst_prefix + "self_attn.in_proj.weight"] = in_proj

save_file(model, out_path)


def main():
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint", type=str, default="kyutai/helium-1-preview-2b", help="the transformers checkpoint to import")
parser.add_argument("--out", type=str, help="the mlx safetensors file to generate")
parser.add_argument(
"-s", "--silent", action="store_true", help="Only prints the checkpoint name"
)
args = parser.parse_args()

ckpt_path = Path(args.checkpoint)
if not ckpt_path.exists():
ckpt_path = hf_hub_download(repo_id=args.checkpoint, filename="model.safetensors")
out_path = Path(args.out)
if not out_path.exists():
import_model(ckpt_path, out_path, silent=args.silent)
print(out_path)


if __name__ == "__main__":
main()
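
The per-layer loop above fuses the per-projection Hugging Face weights into the fused layouts `moshi_mlx` expects: `gate_proj`/`up_proj` are stacked into `gating.linear_in.weight`, and `q_proj`/`k_proj`/`v_proj` into `self_attn.in_proj.weight`. A self-contained toy sketch of those two concatenations, with made-up small shapes (illustrative only, not part of the diff):

```python
import torch

d_model, d_ff = 8, 32  # toy sizes; Helium uses d_model=2560, d_ff=4*2560

gate_proj = torch.randn(d_ff, d_model)               # mlp.gate_proj.weight
up_proj = torch.randn(d_ff, d_model)                 # mlp.up_proj.weight
linear_in = torch.cat([gate_proj, up_proj], dim=0)   # -> gating.linear_in.weight
assert linear_in.shape == (2 * d_ff, d_model)

q = torch.randn(d_model, d_model)                    # self_attn.q_proj.weight
k = torch.randn(d_model, d_model)                    # self_attn.k_proj.weight (kv_repeat=1)
v = torch.randn(d_model, d_model)                    # self_attn.v_proj.weight
in_proj = torch.cat([q, k, v], dim=0)                # -> self_attn.in_proj.weight
assert in_proj.shape == (3 * d_model, d_model)
```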
