diff --git a/outlines/models/llamacpp.py b/outlines/models/llamacpp.py
index 010aaafbf..40adc05b2 100644
--- a/outlines/models/llamacpp.py
+++ b/outlines/models/llamacpp.py
@@ -25,6 +25,7 @@ def __init__(self, model: "Llama"):
 
 def llamacpp(model_path: str, device: Optional[str] = None, **model_kwargs) -> LlamaCpp:
     from llama_cpp import Llama
+    model_kwargs = model_kwargs.get("model_kwargs", {})
 
     if device == "cuda":
         model_kwargs["n_gpu_layers"].setdefault(-1)
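
For context, a minimal usage sketch of the patched entry point, assuming the caller passes the llama.cpp options wrapped under a `model_kwargs` keyword (the nesting the added line unwraps). The GGUF path and option values below are hypothetical, not taken from the patch.

```python
from outlines.models.llamacpp import llamacpp

# Hypothetical model path and options: the llama.cpp settings arrive nested
# under the "model_kwargs" key, and the added line unwraps them into the
# keyword arguments used by the rest of the function.
model = llamacpp(
    "./models/mistral-7b-instruct.Q4_K_M.gguf",
    model_kwargs={"n_ctx": 2048, "verbose": False},
)
```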