diff --git a/moshi_mlx/moshi_mlx/run_helium.py b/moshi_mlx/moshi_mlx/run_helium.py
index c0a560f..684cc48 100644
--- a/moshi_mlx/moshi_mlx/run_helium.py
+++ b/moshi_mlx/moshi_mlx/run_helium.py
@@ -6,6 +6,7 @@
 import sentencepiece
 import huggingface_hub
 import mlx.core as mx
+import mlx.nn as nn
 
 from moshi_mlx import models, utils
 
@@ -17,6 +18,9 @@ def main():
     parser.add_argument("--hf-repo", type=str, default="kyutai/helium-1-preview-2b-mlx")
     parser.add_argument("--prompt", type=str, default="Aujourd'hui, il est temps")
     parser.add_argument("--verbose", action="store_true")
+    parser.add_argument("--quantize-bits", type=int)
+    parser.add_argument("--save-quantized", type=str)
+    parser.add_argument("--quantize-group-size", type=int, default=64)
     args = parser.parse_args()
 
     weights = args.weights
@@ -35,6 +39,11 @@ def main():
     model = models.Lm(lm_config)
     model.set_dtype(mx.bfloat16)
     model.load_weights(weights, strict=True)
+    if args.quantize_bits is not None:
+        nn.quantize(model, bits=args.quantize_bits, group_size=args.quantize_group_size)
+    if args.save_quantized is not None:
+        print(f"saving quantized weights in {args.save_quantized}")
+        model.save_weights(args.save_quantized)
     sampler = utils.Sampler()
     tokenizer = sentencepiece.SentencePieceProcessor(tokenizer)  # type: ignore
     if args.verbose:
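
For reference, here is a minimal standalone sketch (not part of the patch) of the mlx.nn.quantize and save_weights flow the change above introduces. The toy model, its sizes, and the output filename are illustrative assumptions; only the bits/group_size keyword semantics mirror the new --quantize-bits and --quantize-group-size flags.

# Minimal sketch, separate from the patched script: quantize an MLX model
# in place and save the quantized weights. The toy model and the output
# path are hypothetical.
import mlx.core as mx
import mlx.nn as nn

model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))
model.set_dtype(mx.bfloat16)

# nn.quantize replaces quantizable submodules (e.g. nn.Linear) with their
# quantized counterparts in place, grouping weight columns by group_size.
nn.quantize(model, bits=4, group_size=64)

# The quantized parameters (packed weights, scales, biases) are saved like
# any other module parameters, which is what --save-quantized does above.
model.save_weights("model-q4.safetensors")

With the patch applied, the equivalent flow would be triggered by running run_helium.py with --quantize-bits 4 --quantize-group-size 64 --save-quantized helium-q4.safetensors (flag names taken from the diff; the quantized filename here is only an example).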