From b314f626de99e60bace99a6bbb2c3237b16d2d9c Mon Sep 17 00:00:00 2001 From: Maziyar Panahi Date: Fri, 18 Oct 2024 18:16:34 +0200 Subject: [PATCH] Add a new llama_cpp engine (#14436) --- src/main/scala/com/johnsnowlabs/ml/util/ModelEngine.scala | 5 +++++ .../johnsnowlabs/nlp/annotators/seq2seq/AutoGGUFModel.scala | 5 ++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/main/scala/com/johnsnowlabs/ml/util/ModelEngine.scala b/src/main/scala/com/johnsnowlabs/ml/util/ModelEngine.scala index 02ecbc1d626082..e75a3ce29c61a9 100644 --- a/src/main/scala/com/johnsnowlabs/ml/util/ModelEngine.scala +++ b/src/main/scala/com/johnsnowlabs/ml/util/ModelEngine.scala @@ -33,6 +33,7 @@ final case object ONNX extends ModelEngine { val decoderModel = "decoder_model.onnx" val decoderWithPastModel = "decoder_with_past_model.onnx" } + final case object Openvino extends ModelEngine { val name = "openvino" val ovModel = "openvino_model" @@ -41,6 +42,10 @@ final case object Openvino extends ModelEngine { val decoderModelWithPast = "openvino_decoder_with_past_model" } +final case object LlamaCPP extends ModelEngine { + val name = "llama_cpp" +} + final case object Unknown extends ModelEngine { val name = "unk" } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/AutoGGUFModel.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/AutoGGUFModel.scala index e681ce99888010..8049ba6b642473 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/AutoGGUFModel.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/AutoGGUFModel.scala @@ -16,9 +16,9 @@ package com.johnsnowlabs.nlp.annotators.seq2seq import com.johnsnowlabs.ml.gguf.GGUFWrapper +import com.johnsnowlabs.ml.util.LlamaCPP import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.util.io.ResourceHelper -import com.johnsnowlabs.nlp.llama.LlamaModel import org.apache.spark.broadcast.Broadcast import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.SparkSession @@ -153,6 +153,8 @@ class AutoGGUFModel(override val uid: String) this } + private[johnsnowlabs] def setEngine(engineName: String): this.type = set(engine, engineName) + override def onWrite(path: String, spark: SparkSession): Unit = { super.onWrite(path, spark) getModelIfNotSet.saveToFile(path) @@ -261,6 +263,7 @@ trait ReadAutoGGUFModel { val annotatorModel = new AutoGGUFModel() annotatorModel .setModelIfNotSet(spark, GGUFWrapper.read(spark, localPath)) + .setEngine(LlamaCPP.name) val metadata = LlamaModel.getMetadataFromFile(localPath) if (metadata.nonEmpty) annotatorModel.setMetadata(metadata)