Skip to content

Commit

Permalink
Update CamemBertForZeroShotClassification.scala
Browse files Browse the repository at this point in the history
  • Loading branch information
ahmedlone127 committed Sep 10, 2024
1 parent 10f1fe3 commit 35d5d80
Showing 1 changed file with 39 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package com.johnsnowlabs.nlp.annotators.classifier.dl

import com.johnsnowlabs.ml.ai.CamemBertClassification
import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel}
import com.johnsnowlabs.ml.openvino.{OpenvinoWrapper, ReadOpenvinoModel, WriteOpenvinoModel}
import com.johnsnowlabs.ml.tensorflow.{
ReadTensorflowModel,
TensorflowWrapper,
Expand All @@ -34,7 +35,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{
modelSanityCheck,
notSupportedEngineError
}
import com.johnsnowlabs.ml.util.{ONNX, TensorFlow}
import com.johnsnowlabs.ml.util.{ONNX, Openvino, TensorFlow}
import com.johnsnowlabs.nlp.annotators.common.{SentenceSplit, TokenizedWithSentence}
import com.johnsnowlabs.nlp.serialization.MapFeature
import com.johnsnowlabs.nlp.{
Expand All @@ -59,6 +60,7 @@ class CamemBertForZeroShotClassification(override val uid: String)
with HasBatchedAnnotate[CamemBertForZeroShotClassification]
with WriteTensorflowModel
with WriteOnnxModel
with WriteOpenvinoModel
with WriteSentencePieceModel
with HasCaseSensitiveProperties
with HasClassifierActivationProperties
Expand Down Expand Up @@ -178,13 +180,15 @@ class CamemBertForZeroShotClassification(override val uid: String)
spark: SparkSession,
tensorflowWrapper: Option[TensorflowWrapper],
onnxWrapper: Option[OnnxWrapper],
openvinoWrapper: Option[OpenvinoWrapper],
spp: SentencePieceWrapper): CamemBertForZeroShotClassification = {
if (_model.isEmpty) {
_model = Some(
spark.sparkContext.broadcast(
new CamemBertClassification(
tensorflowWrapper,
onnxWrapper,
openvinoWrapper,
spp,
configProtoBytes = None,
tags = $$(labels),
Expand Down Expand Up @@ -269,6 +273,15 @@ class CamemBertForZeroShotClassification(override val uid: String)
getModelIfNotSet.onnxWrapper.get,
suffix,
CamemBertForSequenceClassification.onnxFile)

case Openvino.name =>
writeOpenvinoModel(
path,
spark,
getModelIfNotSet.openvinoWrapper.get,
"openvino_model.xml",
CamemBertForSequenceClassification.openvinoFile)

}

writeSentencePieceModel(
Expand Down Expand Up @@ -305,11 +318,13 @@ trait ReadPretrainedCamemBertForZeroShotClassification
trait ReadCamemBertForZeroShotClassification
extends ReadTensorflowModel
with ReadOnnxModel
with ReadSentencePieceModel {
with ReadSentencePieceModel
with ReadOpenvinoModel {
this: ParamsAndFeaturesReadable[CamemBertForZeroShotClassification] =>

override val tfFile: String = "camembert_classification_tensorflow"
override val onnxFile: String = "camembert_classification_onnx"
override val openvinoFile: String = "camembert_classification_openvino"
override val sppFile: String = "camembert_spp"

def readModel(
Expand All @@ -322,7 +337,7 @@ trait ReadCamemBertForZeroShotClassification
instance.getEngine match {
case TensorFlow.name =>
val tfWrapper = readTensorflowModel(path, spark, "_camembert_classification_tf")
instance.setModelIfNotSet(spark, Some(tfWrapper), None, spp)
instance.setModelIfNotSet(spark, Some(tfWrapper), None, None, spp)
case ONNX.name =>
val onnxWrapper =
readOnnxModel(
Expand All @@ -332,11 +347,16 @@ trait ReadCamemBertForZeroShotClassification
zipped = true,
useBundle = false,
None)
instance.setModelIfNotSet(spark, None, Some(onnxWrapper), spp)
instance.setModelIfNotSet(spark, None, Some(onnxWrapper), None, spp)

case Openvino.name =>
val openvinoWrapper = readOpenvinoModel(path, spark, "_camembert_classification_ov")
instance.setModelIfNotSet(spark, None, None, Some(openvinoWrapper), spp)

case _ =>
throw new Exception(notSupportedEngineError)
}

}
}

addReader(readModel)
Expand Down Expand Up @@ -392,11 +412,23 @@ trait ReadCamemBertForZeroShotClassification
*/
annotatorModel
.setSignatures(_signatures)
.setModelIfNotSet(spark, Some(wrapper), None, spModel)
.setModelIfNotSet(spark, Some(wrapper), None, None, spModel)
case ONNX.name =>
val onnxWrapper =
OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true)
annotatorModel.setModelIfNotSet(spark, None, Some(onnxWrapper), spModel)
annotatorModel.setModelIfNotSet(spark, None, Some(onnxWrapper), None, spModel)

case Openvino.name =>
val ovWrapper: OpenvinoWrapper =
OpenvinoWrapper.read(
spark,
localModelPath,
zipped = false,
useBundle = true,
detectedEngine = detectedEngine)
annotatorModel
.setModelIfNotSet(spark, None, None, Some(ovWrapper), spModel)

case _ =>
throw new Exception(notSupportedEngineError)
}
Expand Down

0 comments on commit 35d5d80

Please sign in to comment.