Skip to content

Commit

Permalink
Add a new llama_cpp engine (#14436)
Browse files Browse the repository at this point in the history
  • Loading branch information
maziyarpanahi authored Oct 18, 2024
1 parent 7eadf9e commit b314f62
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
5 changes: 5 additions & 0 deletions src/main/scala/com/johnsnowlabs/ml/util/ModelEngine.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ final case object ONNX extends ModelEngine {
val decoderModel = "decoder_model.onnx"
val decoderWithPastModel = "decoder_with_past_model.onnx"
}

final case object Openvino extends ModelEngine {
val name = "openvino"
val ovModel = "openvino_model"
Expand All @@ -41,6 +42,10 @@ final case object Openvino extends ModelEngine {
val decoderModelWithPast = "openvino_decoder_with_past_model"
}

final case object LlamaCPP extends ModelEngine {
val name = "llama_cpp"
}

final case object Unknown extends ModelEngine {
val name = "unk"
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
package com.johnsnowlabs.nlp.annotators.seq2seq

import com.johnsnowlabs.ml.gguf.GGUFWrapper
import com.johnsnowlabs.ml.util.LlamaCPP
import com.johnsnowlabs.nlp._
import com.johnsnowlabs.nlp.util.io.ResourceHelper
import com.johnsnowlabs.nlp.llama.LlamaModel
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.SparkSession
Expand Down Expand Up @@ -153,6 +153,8 @@ class AutoGGUFModel(override val uid: String)
this
}

private[johnsnowlabs] def setEngine(engineName: String): this.type = set(engine, engineName)

override def onWrite(path: String, spark: SparkSession): Unit = {
super.onWrite(path, spark)
getModelIfNotSet.saveToFile(path)
Expand Down Expand Up @@ -261,6 +263,7 @@ trait ReadAutoGGUFModel {
val annotatorModel = new AutoGGUFModel()
annotatorModel
.setModelIfNotSet(spark, GGUFWrapper.read(spark, localPath))
.setEngine(LlamaCPP.name)

val metadata = LlamaModel.getMetadataFromFile(localPath)
if (metadata.nonEmpty) annotatorModel.setMetadata(metadata)
Expand Down

0 comments on commit b314f62

Please sign in to comment.