diff --git a/examples/python/transformers/onnx/HuggingFace_ONNX_in_Spark_NLP_AlbertForQuestionAnswering.ipynb b/examples/python/transformers/onnx/HuggingFace_ONNX_in_Spark_NLP_AlbertForQuestionAnswering.ipynb index 88ec07eec89f47..24fda0354440ba 100644 --- a/examples/python/transformers/onnx/HuggingFace_ONNX_in_Spark_NLP_AlbertForQuestionAnswering.ipynb +++ b/examples/python/transformers/onnx/HuggingFace_ONNX_in_Spark_NLP_AlbertForQuestionAnswering.ipynb @@ -75,7 +75,7 @@ } ], "source": [ - "!pip install -q --upgrade transformers[onnx]==4.29.1 optimum sentencepiece" + "!pip install -q --upgrade transformers[onnx]==4.29.1 optimum sentencepiece tensorflow" ] }, { diff --git a/examples/python/transformers/onnx/HuggingFace_ONNX_in_Spark_NLP_BERT.ipynb b/examples/python/transformers/onnx/HuggingFace_ONNX_in_Spark_NLP_BERT.ipynb index daadf8baf6eeb5..2eb92dde43fe45 100644 --- a/examples/python/transformers/onnx/HuggingFace_ONNX_in_Spark_NLP_BERT.ipynb +++ b/examples/python/transformers/onnx/HuggingFace_ONNX_in_Spark_NLP_BERT.ipynb @@ -70,7 +70,7 @@ } ], "source": [ - "!pip install -q --upgrade transformers[onnx]==4.29.1 optimum" + "!pip install -q --upgrade transformers[onnx]==4.29.1 optimum tensorflow" ] }, { @@ -497,11 +497,21 @@ "provenance": [] }, "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", + "language": "python", "name": "python3" }, "language_info": { - "name": "python" + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" }, "widgets": { "application/vnd.jupyter.widget-state+json": { @@ -2219,5 +2229,5 @@ } }, "nbformat": 4, - "nbformat_minor": 0 + "nbformat_minor": 1 } diff --git a/examples/python/transformers/onnx/HuggingFace_ONNX_in_Spark_NLP_BertForQuestionAnswering.ipynb b/examples/python/transformers/onnx/HuggingFace_ONNX_in_Spark_NLP_BertForQuestionAnswering.ipynb index f2f777cc060588..7972181e127b7e 100644 --- a/examples/python/transformers/onnx/HuggingFace_ONNX_in_Spark_NLP_BertForQuestionAnswering.ipynb +++ b/examples/python/transformers/onnx/HuggingFace_ONNX_in_Spark_NLP_BertForQuestionAnswering.ipynb @@ -6,7 +6,7 @@ "source": [ "![JohnSnowLabs](https://sparknlp.org/assets/images/logo.png)\n", "\n", - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp/blob/master/examples/python/transformers/onnx/HuggingFace%20ONNX%20in%20Spark%20NLP%20-%20AlbertForQuestionAnswering.ipynb)" + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp/blob/master/examples/python/transformers/onnx/HuggingFace_ONNX_in_Spark_NLP_BertForQuestionAnswering.ipynb)" ] }, { diff --git a/examples/python/transformers/onnx/HuggingFace_ONNX_in_Spark_NLP_DistilBERT.ipynb b/examples/python/transformers/onnx/HuggingFace_ONNX_in_Spark_NLP_DistilBERT.ipynb index 678d8840ab629b..95d8c86216561d 100644 --- a/examples/python/transformers/onnx/HuggingFace_ONNX_in_Spark_NLP_DistilBERT.ipynb +++ b/examples/python/transformers/onnx/HuggingFace_ONNX_in_Spark_NLP_DistilBERT.ipynb @@ -70,7 +70,7 @@ } ], "source": [ - "!pip install -q --upgrade transformers[onnx]==4.29.1 optimum" + "!pip install -q --upgrade transformers[onnx]==4.29.1 optimum tensorflow" ] }, { @@ -503,11 +503,21 @@ "provenance": [] }, "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", + "language": "python", "name": "python3" }, "language_info": { - "name": "python" + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" }, "widgets": { "application/vnd.jupyter.widget-state+json": { @@ -2225,5 +2235,5 @@ } }, "nbformat": 4, - "nbformat_minor": 0 + "nbformat_minor": 1 } diff --git a/examples/python/transformers/onnx/HuggingFace_ONNX_in_Spark_NLP_E5.ipynb b/examples/python/transformers/onnx/HuggingFace_ONNX_in_Spark_NLP_E5.ipynb index a0f8755c0a48f2..8e2b4ccfabb266 100644 --- a/examples/python/transformers/onnx/HuggingFace_ONNX_in_Spark_NLP_E5.ipynb +++ b/examples/python/transformers/onnx/HuggingFace_ONNX_in_Spark_NLP_E5.ipynb @@ -66,7 +66,7 @@ } ], "source": [ - "!pip install -q --upgrade transformers[onnx]==4.29.1 optimum" + "!pip install -q --upgrade transformers[onnx]==4.29.1 optimum tensorflow" ] }, { @@ -376,13 +376,23 @@ "provenance": [] }, "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", + "language": "python", "name": "python3" }, "language_info": { - "name": "python" + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" } }, "nbformat": 4, - "nbformat_minor": 0 + "nbformat_minor": 1 } diff --git a/examples/python/transformers/onnx/HuggingFace_ONNX_in_Spark_NLP_MPNet.ipynb b/examples/python/transformers/onnx/HuggingFace_ONNX_in_Spark_NLP_MPNet.ipynb index 02c6a6b39d4c3a..af5cebcfe3e9ba 100644 --- a/examples/python/transformers/onnx/HuggingFace_ONNX_in_Spark_NLP_MPNet.ipynb +++ b/examples/python/transformers/onnx/HuggingFace_ONNX_in_Spark_NLP_MPNet.ipynb @@ -66,7 +66,7 @@ } ], "source": [ - "!pip install -q --upgrade transformers[onnx]==4.33.1 optimum" + "!pip install -q --upgrade transformers[onnx]==4.33.1 optimum tensorflow" ] }, { @@ -465,11 +465,21 @@ "toc_visible": true }, "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", + "language": "python", "name": "python3" }, "language_info": { - "name": "python" + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" }, "widgets": { "application/vnd.jupyter.widget-state+json": { @@ -2529,5 +2539,5 @@ } }, "nbformat": 4, - "nbformat_minor": 0 + "nbformat_minor": 1 } diff --git a/examples/python/transformers/onnx/HuggingFace_ONNX_in_Spark_NLP_Whisper.ipynb b/examples/python/transformers/onnx/HuggingFace_ONNX_in_Spark_NLP_Whisper.ipynb index 29079154b073d9..262cad47cf6a91 100644 --- a/examples/python/transformers/onnx/HuggingFace_ONNX_in_Spark_NLP_Whisper.ipynb +++ b/examples/python/transformers/onnx/HuggingFace_ONNX_in_Spark_NLP_Whisper.ipynb @@ -64,7 +64,7 @@ } ], "source": [ - "!pip install -q --upgrade \"transformers[onnx]==4.31.0\" optimum" + "!pip install -q --upgrade \"transformers[onnx]==4.31.0\" optimum tensorflow" ] }, { @@ -550,4 +550,4 @@ }, "nbformat": 4, "nbformat_minor": 0 -} +} \ No newline at end of file diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/Albert.scala b/src/main/scala/com/johnsnowlabs/ml/ai/Albert.scala index bd4846945dc4bc..51e1b4b847011b 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/Albert.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/Albert.scala @@ -18,7 +18,7 @@ package com.johnsnowlabs.ml.ai import ai.onnxruntime.OnnxTensor import com.johnsnowlabs.ml.ai.util.PrepareEmbeddings -import com.johnsnowlabs.ml.onnx.OnnxWrapper +import com.johnsnowlabs.ml.onnx.{OnnxSession, OnnxWrapper} import com.johnsnowlabs.ml.tensorflow.sentencepiece.{SentencePieceWrapper, SentencepieceEncoder} import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager} import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} @@ -83,6 +83,7 @@ private[johnsnowlabs] class Albert( if (tensorflowWrapper.isDefined) TensorFlow.name else if (onnxWrapper.isDefined) ONNX.name else TensorFlow.name + private val onnxSessionOptions: Map[String, String] = new OnnxSession().getSessionOptions // keys representing the input and output tensors of the ALBERT model private val SentenceStartTokenId = spp.getSppModel.pieceToId("[CLS]") @@ -107,7 +108,7 @@ private[johnsnowlabs] class Albert( case ONNX.name => // [nb of encoded sentences , maxSentenceLength] - val (runner, env) = onnxWrapper.get.getSession() + val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions) val tokenTensors = OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray) diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/AlbertClassification.scala b/src/main/scala/com/johnsnowlabs/ml/ai/AlbertClassification.scala index 4a0219b6c7a630..11e29c8c004474 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/AlbertClassification.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/AlbertClassification.scala @@ -17,7 +17,7 @@ package com.johnsnowlabs.ml.ai import ai.onnxruntime.OnnxTensor -import com.johnsnowlabs.ml.onnx.OnnxWrapper +import com.johnsnowlabs.ml.onnx.{OnnxSession, OnnxWrapper} import com.johnsnowlabs.ml.tensorflow.sentencepiece.{SentencePieceWrapper, SentencepieceEncoder} import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager} import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} @@ -56,6 +56,7 @@ private[johnsnowlabs] class AlbertClassification( if (tensorflowWrapper.isDefined) TensorFlow.name else if (onnxWrapper.isDefined) ONNX.name else TensorFlow.name + private val onnxSessionOptions: Map[String, String] = new OnnxSession().getSessionOptions // keys representing the input and output tensors of the ALBERT model protected val sentencePadTokenId: Int = spp.getSppModel.pieceToId("[pad]") @@ -210,7 +211,7 @@ private[johnsnowlabs] class AlbertClassification( val output = if (sequence) "logits" else "last_hidden_state" // [nb of encoded sentences , maxSentenceLength] - val (runner, env) = onnxWrapper.get.getSession() + val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions) val tokenTensors = OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray) @@ -348,7 +349,7 @@ private[johnsnowlabs] class AlbertClassification( batch: Seq[Array[Int]], maxSentenceLength: Int): (Array[Float], Array[Float]) = { // [nb of encoded sentences , maxSentenceLength] - val (runner, env) = onnxWrapper.get.getSession() + val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions) val tokenTensors = OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray) diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/Bert.scala b/src/main/scala/com/johnsnowlabs/ml/ai/Bert.scala index 849e0a380bfd48..e63c5d4e0851d5 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/Bert.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/Bert.scala @@ -18,7 +18,7 @@ package com.johnsnowlabs.ml.ai import ai.onnxruntime.OnnxTensor import com.johnsnowlabs.ml.ai.util.PrepareEmbeddings -import com.johnsnowlabs.ml.onnx.OnnxWrapper +import com.johnsnowlabs.ml.onnx.{OnnxSession, OnnxWrapper} import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager} import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} import com.johnsnowlabs.ml.util.{ModelArch, ONNX, TensorFlow} @@ -66,6 +66,7 @@ private[johnsnowlabs] class Bert( if (tensorflowWrapper.isDefined) TensorFlow.name else if (onnxWrapper.isDefined) ONNX.name else TensorFlow.name + private val onnxSessionOptions: Map[String, String] = new OnnxSession().getSessionOptions private def sessionWarmup(): Unit = { val dummyInput = @@ -90,7 +91,7 @@ private[johnsnowlabs] class Bert( case ONNX.name => // [nb of encoded sentences , maxSentenceLength] - val (runner, env) = onnxWrapper.get.getSession() + val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions) val tokenTensors = OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray) @@ -193,7 +194,7 @@ private[johnsnowlabs] class Bert( val embeddings = detectedEngine match { case ONNX.name => // [nb of encoded sentences , maxSentenceLength] - val (runner, env) = onnxWrapper.get.getSession() + val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions) val tokenTensors = OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray) diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/BertClassification.scala b/src/main/scala/com/johnsnowlabs/ml/ai/BertClassification.scala index cce6f83988fcdd..7b4dfaf233f879 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/BertClassification.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/BertClassification.scala @@ -17,7 +17,7 @@ package com.johnsnowlabs.ml.ai import ai.onnxruntime.OnnxTensor -import com.johnsnowlabs.ml.onnx.OnnxWrapper +import com.johnsnowlabs.ml.onnx.{OnnxSession, OnnxWrapper} import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager} import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} import com.johnsnowlabs.ml.util.{ONNX, TensorFlow} @@ -62,6 +62,7 @@ private[johnsnowlabs] class BertClassification( if (tensorflowWrapper.isDefined) TensorFlow.name else if (onnxWrapper.isDefined) ONNX.name else TensorFlow.name + private val onnxSessionOptions: Map[String, String] = new OnnxSession().getSessionOptions def tokenizeWithAlignment( sentences: Seq[TokenizedSentence], @@ -222,7 +223,7 @@ private[johnsnowlabs] class BertClassification( maxSentenceLength: Int): Array[Float] = { // [nb of encoded sentences , maxSentenceLength] - val (runner, env) = onnxWrapper.get.getSession() + val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions) val tokenTensors = OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray) @@ -452,7 +453,7 @@ private[johnsnowlabs] class BertClassification( batch: Seq[Array[Int]], maxSentenceLength: Int): (Array[Float], Array[Float]) = { // [nb of encoded sentences , maxSentenceLength] - val (runner, env) = onnxWrapper.get.getSession() + val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions) val tokenTensors = OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray) diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/CamemBert.scala b/src/main/scala/com/johnsnowlabs/ml/ai/CamemBert.scala index eb1d421b70ce0b..d8995f67243383 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/CamemBert.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/CamemBert.scala @@ -18,7 +18,7 @@ package com.johnsnowlabs.ml.ai import ai.onnxruntime.OnnxTensor import com.johnsnowlabs.ml.ai.util.PrepareEmbeddings -import com.johnsnowlabs.ml.onnx.OnnxWrapper +import com.johnsnowlabs.ml.onnx.{OnnxSession, OnnxWrapper} import com.johnsnowlabs.ml.tensorflow.sentencepiece.{SentencePieceWrapper, SentencepieceEncoder} import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager} import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} @@ -54,6 +54,7 @@ private[johnsnowlabs] class CamemBert( if (tensorflowWrapper.isDefined) TensorFlow.name else if (onnxWrapper.isDefined) ONNX.name else TensorFlow.name + private val onnxSessionOptions: Map[String, String] = new OnnxSession().getSessionOptions /** HACK: These tokens were added by fairseq but don't seem to be actually used when duplicated * in the actual # sentencepiece vocabulary (this is the case for '''''' and '''''') @@ -82,7 +83,7 @@ private[johnsnowlabs] class CamemBert( case ONNX.name => // [nb of encoded sentences , maxSentenceLength] - val (runner, env) = onnxWrapper.get.getSession() + val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions) val tokenTensors = OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray) diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/DeBerta.scala b/src/main/scala/com/johnsnowlabs/ml/ai/DeBerta.scala index bbf4ac83b1862b..94e28f264ef471 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/DeBerta.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/DeBerta.scala @@ -18,7 +18,7 @@ package com.johnsnowlabs.ml.ai import ai.onnxruntime.OnnxTensor import com.johnsnowlabs.ml.ai.util.PrepareEmbeddings -import com.johnsnowlabs.ml.onnx.OnnxWrapper +import com.johnsnowlabs.ml.onnx.{OnnxSession, OnnxWrapper} import com.johnsnowlabs.ml.tensorflow.sentencepiece._ import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager} import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} @@ -52,6 +52,7 @@ class DeBerta( if (tensorflowWrapper.isDefined) TensorFlow.name else if (onnxWrapper.isDefined) ONNX.name else TensorFlow.name + private val onnxSessionOptions: Map[String, String] = new OnnxSession().getSessionOptions // keys representing the input and output tensors of the DeBERTa model private val SentenceStartTokenId = spp.getSppModel.pieceToId("[CLS]") @@ -68,7 +69,7 @@ class DeBerta( case ONNX.name => // [nb of encoded sentences , maxSentenceLength] - val (runner, env) = onnxWrapper.get.getSession() + val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions) val tokenTensors = OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray) diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/DistilBert.scala b/src/main/scala/com/johnsnowlabs/ml/ai/DistilBert.scala index afa6a3b8bb29d5..439bbd3d53f162 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/DistilBert.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/DistilBert.scala @@ -18,7 +18,7 @@ package com.johnsnowlabs.ml.ai import ai.onnxruntime.OnnxTensor import com.johnsnowlabs.ml.ai.util.PrepareEmbeddings -import com.johnsnowlabs.ml.onnx.OnnxWrapper +import com.johnsnowlabs.ml.onnx.{OnnxSession, OnnxWrapper} import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager} import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} import com.johnsnowlabs.ml.util.{ModelArch, ONNX, TensorFlow} @@ -82,6 +82,7 @@ private[johnsnowlabs] class DistilBert( if (tensorflowWrapper.isDefined) TensorFlow.name else if (onnxWrapper.isDefined) ONNX.name else TensorFlow.name + private val onnxSessionOptions: Map[String, String] = new OnnxSession().getSessionOptions private def sessionWarmup(): Unit = { val dummyInput = @@ -103,7 +104,7 @@ private[johnsnowlabs] class DistilBert( val embeddings = detectedEngine match { case ONNX.name => // [nb of encoded sentences , maxSentenceLength] - val (runner, env) = onnxWrapper.get.getSession() + val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions) val tokenTensors = OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray) diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/DistilBertClassification.scala b/src/main/scala/com/johnsnowlabs/ml/ai/DistilBertClassification.scala index 2408ab36ee194c..099622429ecf80 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/DistilBertClassification.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/DistilBertClassification.scala @@ -17,7 +17,7 @@ package com.johnsnowlabs.ml.ai import ai.onnxruntime.OnnxTensor -import com.johnsnowlabs.ml.onnx.OnnxWrapper +import com.johnsnowlabs.ml.onnx.{OnnxSession, OnnxWrapper} import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager} import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} import com.johnsnowlabs.ml.util.{ONNX, TensorFlow} @@ -60,6 +60,7 @@ private[johnsnowlabs] class DistilBertClassification( if (tensorflowWrapper.isDefined) TensorFlow.name else if (onnxWrapper.isDefined) ONNX.name else TensorFlow.name + private val onnxSessionOptions: Map[String, String] = new OnnxSession().getSessionOptions protected val sentencePadTokenId = 0 protected val sigmoidThreshold: Float = threshold @@ -212,7 +213,7 @@ private[johnsnowlabs] class DistilBertClassification( private def getRowScoresWithOnnx(batch: Seq[Array[Int]]): Array[Float] = { - val (runner, env) = onnxWrapper.get.getSession() + val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions) val tokenTensors = OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray) @@ -409,7 +410,7 @@ private[johnsnowlabs] class DistilBertClassification( } private def computeLogitsWithOnnx(batch: Seq[Array[Int]]): (Array[Float], Array[Float]) = { - val (runner, env) = onnxWrapper.get.getSession() + val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions) val tokenTensors = OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray) diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/E5.scala b/src/main/scala/com/johnsnowlabs/ml/ai/E5.scala index 71b6f37d5d2100..0ffbdd980cfbe3 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/E5.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/E5.scala @@ -16,8 +16,8 @@ package com.johnsnowlabs.ml.ai -import ai.onnxruntime.{OnnxTensor, TensorInfo} -import com.johnsnowlabs.ml.onnx.OnnxWrapper +import ai.onnxruntime.OnnxTensor +import com.johnsnowlabs.ml.onnx.{OnnxSession, OnnxWrapper} import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager} import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} import com.johnsnowlabs.ml.util.{LinAlg, ONNX, TensorFlow} @@ -55,6 +55,7 @@ private[johnsnowlabs] class E5( if (tensorflowWrapper.isDefined) TensorFlow.name else if (onnxWrapper.isDefined) ONNX.name else TensorFlow.name + private val onnxSessionOptions: Map[String, String] = new OnnxSession().getSessionOptions /** Get sentence embeddings for a batch of sentences * @param batch @@ -167,10 +168,8 @@ private[johnsnowlabs] class E5( val tokenTensors = OnnxTensor.createTensor(env, inputIds) val maskTensors = OnnxTensor.createTensor(env, attentionMask) - val segmentTensors = OnnxTensor.createTensor(env, batch.map(x => Array.fill(maxSentenceLength)(0L)).toArray) - val inputs = Map( "input_ids" -> tokenTensors, diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/MPNet.scala b/src/main/scala/com/johnsnowlabs/ml/ai/MPNet.scala index 610cb1072a0fd2..3d48e622f908f2 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/MPNet.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/MPNet.scala @@ -17,7 +17,7 @@ package com.johnsnowlabs.ml.ai import ai.onnxruntime.{OnnxTensor, TensorInfo} -import com.johnsnowlabs.ml.onnx.OnnxWrapper +import com.johnsnowlabs.ml.onnx.{OnnxSession, OnnxWrapper} import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager} import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} import com.johnsnowlabs.ml.util.{LinAlg, ONNX, TensorFlow} @@ -56,6 +56,7 @@ private[johnsnowlabs] class MPNet( if (tensorflowWrapper.isDefined) TensorFlow.name else if (onnxWrapper.isDefined) ONNX.name else TensorFlow.name + private val onnxSessionOptions: Map[String, String] = new OnnxSession().getSessionOptions /** Get sentence embeddings for a batch of sentences * @param batch @@ -165,14 +166,12 @@ private[johnsnowlabs] class MPNet( } private def getSentenceEmbeddingFromOnnx(batch: Seq[Array[Int]]): Array[Array[Float]] = { - val inputIds = batch.map(x => x.map(x => x.toLong)).toArray val attentionMask = batch.map(sentence => sentence.map(x => if (x < 0L) 0L else 1L)).toArray - val (runner, env) = onnxWrapper.get.getSession() + val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions) val tokenTensors = OnnxTensor.createTensor(env, inputIds) val maskTensors = OnnxTensor.createTensor(env, attentionMask) - val inputs = Map("input_ids" -> tokenTensors, "attention_mask" -> maskTensors).asJava diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/RoBerta.scala b/src/main/scala/com/johnsnowlabs/ml/ai/RoBerta.scala index 1e903ff0d4a345..d9e0d1a96e62f0 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/RoBerta.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/RoBerta.scala @@ -18,7 +18,7 @@ package com.johnsnowlabs.ml.ai import ai.onnxruntime.OnnxTensor import com.johnsnowlabs.ml.ai.util.PrepareEmbeddings -import com.johnsnowlabs.ml.onnx.OnnxWrapper +import com.johnsnowlabs.ml.onnx.{OnnxSession, OnnxWrapper} import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager} import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} import com.johnsnowlabs.ml.util.{ModelArch, ONNX, TensorFlow} @@ -57,6 +57,7 @@ private[johnsnowlabs] class RoBerta( if (tensorflowWrapper.isDefined) TensorFlow.name else if (onnxWrapper.isDefined) ONNX.name else TensorFlow.name + private val onnxSessionOptions: Map[String, String] = new OnnxSession().getSessionOptions private def sessionWarmup(): Unit = { val dummyInput = @@ -79,7 +80,7 @@ private[johnsnowlabs] class RoBerta( case ONNX.name => // [nb of encoded sentences , maxSentenceLength] - val (runner, env) = onnxWrapper.get.getSession() + val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions) val tokenTensors = OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray) diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/RoBertaClassification.scala b/src/main/scala/com/johnsnowlabs/ml/ai/RoBertaClassification.scala index 3a2b18a9d68868..054d1eff76f2d5 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/RoBertaClassification.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/RoBertaClassification.scala @@ -17,7 +17,7 @@ package com.johnsnowlabs.ml.ai import ai.onnxruntime.OnnxTensor -import com.johnsnowlabs.ml.onnx.OnnxWrapper +import com.johnsnowlabs.ml.onnx.{OnnxSession, OnnxWrapper} import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager} import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} import com.johnsnowlabs.ml.util.{ONNX, TensorFlow} @@ -63,6 +63,7 @@ private[johnsnowlabs] class RoBertaClassification( if (tensorflowWrapper.isDefined) TensorFlow.name else if (onnxWrapper.isDefined) ONNX.name else TensorFlow.name + private val onnxSessionOptions: Map[String, String] = new OnnxSession().getSessionOptions protected val sigmoidThreshold: Float = threshold @@ -209,7 +210,7 @@ private[johnsnowlabs] class RoBertaClassification( private def getRowScoresWithOnnx(batch: Seq[Array[Int]]): Array[Float] = { // [nb of encoded sentences , maxSentenceLength] - val (runner, env) = onnxWrapper.get.getSession() + val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions) val tokenTensors = OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray) @@ -407,7 +408,7 @@ private[johnsnowlabs] class RoBertaClassification( private def computeLogitsWithOnnx(batch: Seq[Array[Int]]): (Array[Float], Array[Float]) = { // [nb of encoded sentences , maxSentenceLength] - val (runner, env) = onnxWrapper.get.getSession() + val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions) val tokenTensors = OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray) diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/Whisper.scala b/src/main/scala/com/johnsnowlabs/ml/ai/Whisper.scala index aa45f640a7a6b7..7da6d8253eb25f 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/Whisper.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/Whisper.scala @@ -24,6 +24,7 @@ import com.johnsnowlabs.ml.ai.util.Generation.Logit.LogitProcess.{ SuppressLogitProcessor } import com.johnsnowlabs.ml.ai.util.Generation.Logit.LogitProcessorList +import com.johnsnowlabs.ml.onnx.OnnxSession import com.johnsnowlabs.ml.onnx.OnnxWrapper.EncoderDecoderWrappers import com.johnsnowlabs.ml.onnx.TensorResources.implicits._ import com.johnsnowlabs.ml.tensorflow @@ -102,7 +103,7 @@ private[johnsnowlabs] class Whisper( else throw new IllegalArgumentException("No model engine defined.") private val tfTensorResources = new tensorflow.TensorResources() -// val onnxTensorResources = new onnx.TensorResources(OrtEnvironment.getEnvironment()) + private val onnxSessionOptions: Map[String, String] = new OnnxSession().getSessionOptions private object TfSignatures { object InputOps { @@ -323,9 +324,10 @@ private[johnsnowlabs] class Whisper( tokenIds case ONNX.name => - val (encoderSession, env) = onnxWrappers.get.encoder.getSession() - val decoderSession = onnxWrappers.get.decoder.getSession()._1 - val decoderWithPastSession = onnxWrappers.get.decoderWithPast.getSession()._1 + val (encoderSession, env) = onnxWrappers.get.encoder.getSession(onnxSessionOptions) + val decoderSession = onnxWrappers.get.decoder.getSession(onnxSessionOptions)._1 + val decoderWithPastSession = + onnxWrappers.get.decoderWithPast.getSession(onnxSessionOptions)._1 val encodedBatchTensor: OnnxTensor = encode(featuresBatch, None, Some((encoderSession, env))).asInstanceOf[OnnxTensor] diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/XlmRoberta.scala b/src/main/scala/com/johnsnowlabs/ml/ai/XlmRoberta.scala index 9f26ede6370c72..0ecd3d70ef8962 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/XlmRoberta.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/XlmRoberta.scala @@ -18,7 +18,7 @@ package com.johnsnowlabs.ml.ai import ai.onnxruntime.OnnxTensor import com.johnsnowlabs.ml.ai.util.PrepareEmbeddings -import com.johnsnowlabs.ml.onnx.OnnxWrapper +import com.johnsnowlabs.ml.onnx.{OnnxSession, OnnxWrapper} import com.johnsnowlabs.ml.tensorflow.sentencepiece.{SentencePieceWrapper, SentencepieceEncoder} import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager} import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} @@ -88,6 +88,7 @@ private[johnsnowlabs] class XlmRoberta( if (tensorflowWrapper.isDefined) TensorFlow.name else if (onnxWrapper.isDefined) ONNX.name else TensorFlow.name + private val onnxSessionOptions: Map[String, String] = new OnnxSession().getSessionOptions private val SentenceStartTokenId = 0 private val SentenceEndTokenId = 2 @@ -115,7 +116,7 @@ private[johnsnowlabs] class XlmRoberta( case ONNX.name => // [nb of encoded sentences , maxSentenceLength] - val (runner, env) = onnxWrapper.get.getSession() + val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions) val tokenTensors = OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray) diff --git a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala index 5ea77ed02dafca..c9e2f2890ee72f 100644 --- a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala +++ b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala @@ -95,11 +95,7 @@ trait ReadOnnxModel { val localPath = new Path(tmpFolder, onnxFile).toString // 3. Read ONNX state - val onnxWrapper = OnnxWrapper.read( - localPath, - zipped = zipped, - useBundle = useBundle, - sessionOptions = sessionOptions) + val onnxWrapper = OnnxWrapper.read(localPath, zipped = zipped, useBundle = useBundle) // 4. Remove tmp folder FileHelper.delete(tmpFolder) @@ -113,8 +109,7 @@ trait ReadOnnxModel { modelNames: Seq[String], suffix: String, zipped: Boolean = true, - useBundle: Boolean = false, - sessionOptions: Option[SessionOptions] = None): Map[String, OnnxWrapper] = { + useBundle: Boolean = false): Map[String, OnnxWrapper] = { val uri = new java.net.URI(path.replaceAllLiterally("\\", "/")) val fs = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration) @@ -133,11 +128,7 @@ trait ReadOnnxModel { val localPath = new Path(tmpFolder, localModelFile).toString // 3. Read ONNX state - val onnxWrapper = OnnxWrapper.read( - localPath, - zipped = zipped, - useBundle = useBundle, - sessionOptions = sessionOptions) + val onnxWrapper = OnnxWrapper.read(localPath, zipped = zipped, useBundle = useBundle) (modelName, onnxWrapper) }).toMap diff --git a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSession.scala b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSession.scala new file mode 100644 index 00000000000000..a615856122993d --- /dev/null +++ b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSession.scala @@ -0,0 +1,54 @@ +/* + * Copyright 2017-2023 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.johnsnowlabs.ml.onnx + +import ai.onnxruntime.OrtEnvironment +import com.johnsnowlabs.util.{ConfigHelper, ConfigLoader} +import org.slf4j.{Logger, LoggerFactory} + +import java.io.Serializable + +class OnnxSession extends Serializable { + + // Important for serialization on none-kyro serializers + @transient val logger: Logger = LoggerFactory.getLogger("OnnxSession") + + def getSessionOptions: Map[String, String] = { + val providers = OrtEnvironment.getAvailableProviders + if (providers.toArray.map(x => x.toString).contains("CUDA")) { + getCUDASessionConfig + } else getCPUSessionConfig + } + + private def getCUDASessionConfig: Map[String, String] = { + val gpuDeviceId = ConfigLoader.getConfigIntValue(ConfigHelper.onnxGpuDeviceId) + Map(ConfigHelper.onnxGpuDeviceId -> gpuDeviceId.toString) + } + + private def getCPUSessionConfig: Map[String, String] = { + val intraOpNumThreads = + ConfigLoader.getConfigIntValue(ConfigHelper.onnxIntraOpNumThreads) + val optimizationLevel = + ConfigLoader.getConfigStringValue(ConfigHelper.onnxOptimizationLevel) + val executionMode = + ConfigLoader.getConfigStringValue(ConfigHelper.onnxExecutionMode) + + Map(ConfigHelper.onnxIntraOpNumThreads -> intraOpNumThreads.toString) ++ + Map(ConfigHelper.onnxOptimizationLevel -> optimizationLevel) ++ + Map(ConfigHelper.onnxExecutionMode -> executionMode) + } + +} diff --git a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala index 4d09c4d3221b09..7f4fb80fcff0e5 100644 --- a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala +++ b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala @@ -20,7 +20,7 @@ import ai.onnxruntime.OrtSession.SessionOptions import ai.onnxruntime.OrtSession.SessionOptions.{ExecutionMode, OptLevel} import ai.onnxruntime.providers.OrtCUDAProviderOptions import ai.onnxruntime.{OrtEnvironment, OrtSession} -import com.johnsnowlabs.util.{ConfigHelper, ConfigLoader, FileHelper, ZipArchiveUtil} +import com.johnsnowlabs.util.{ConfigHelper, FileHelper, ZipArchiveUtil} import org.apache.commons.io.FileUtils import org.slf4j.{Logger, LoggerFactory} @@ -37,18 +37,18 @@ class OnnxWrapper(var onnxModel: Array[Byte]) extends Serializable { } // Important for serialization on none-kyro serializers - @transient private var m_session: OrtSession = _ - @transient private var m_env: OrtEnvironment = _ - @transient private val logger = LoggerFactory.getLogger("OnnxWrapper") + @transient private var ortSession: OrtSession = _ + @transient private var ortEnv: OrtEnvironment = _ - def getSession(sessionOptions: Option[SessionOptions] = None): (OrtSession, OrtEnvironment) = + def getSession(onnxSessionOptions: Map[String, String]): (OrtSession, OrtEnvironment) = this.synchronized { - if (m_session == null && m_env == null) { - val (session, env) = OnnxWrapper.withSafeOnnxModelLoader(onnxModel, sessionOptions) - m_env = env - m_session = session + // TODO: After testing it works remove the Map.empty + if (ortSession == null && ortEnv == null) { + val (session, env) = OnnxWrapper.withSafeOnnxModelLoader(onnxModel, onnxSessionOptions) + ortEnv = env + ortSession = session } - (m_session, m_env) + (ortSession, ortEnv) } def saveToFile(file: String, zip: Boolean = true): Unit = { @@ -81,18 +81,16 @@ object OnnxWrapper { // TODO: make sure this.synchronized is needed or it's not a bottleneck private def withSafeOnnxModelLoader( onnxModel: Array[Byte], - sessionOptions: Option[SessionOptions] = None): (OrtSession, OrtEnvironment) = + sessionOptions: Map[String, String]): (OrtSession, OrtEnvironment) = this.synchronized { val env = OrtEnvironment.getEnvironment() - val providers = OrtEnvironment.getAvailableProviders - - val sessionOptionsConfig = if (providers.toArray.map(x => x.toString).contains("CUDA")) { - getCUDASessionConfig + val sessionOptionsObject = if (sessionOptions.isEmpty) { + new SessionOptions() } else { - getCPUSessionConfig + mapToSessionOptionsObject(sessionOptions) } - val session = env.createSession(onnxModel, sessionOptionsConfig) + val session = env.createSession(onnxModel, sessionOptionsObject) (session, env) } @@ -100,8 +98,7 @@ object OnnxWrapper { modelPath: String, zipped: Boolean = true, useBundle: Boolean = false, - modelName: String = "model", - sessionOptions: Option[SessionOptions] = None): OnnxWrapper = { + modelName: String = "model"): OnnxWrapper = { // 1. Create tmp folder val tmpFolder = Files @@ -116,39 +113,39 @@ object OnnxWrapper { else modelPath - // TODO: simplify this logic of useBundle - val (session, env, modelBytes) = - if (useBundle) { - val onnxFile = Paths.get(modelPath, s"$modelName.onnx").toString - val modelFile = new File(onnxFile) - val modelBytes = FileUtils.readFileToByteArray(modelFile) - val (session, env) = withSafeOnnxModelLoader(modelBytes, sessionOptions) - (session, env, modelBytes) - } else { - val modelFile = new File(folder).list().head - val fullPath = Paths.get(folder, modelFile).toFile - val modelBytes = FileUtils.readFileToByteArray(fullPath) - val (session, env) = withSafeOnnxModelLoader(modelBytes, sessionOptions) - (session, env, modelBytes) - } + val sessionOptions = new OnnxSession().getSessionOptions + val onnxFile = + if (useBundle) Paths.get(modelPath, s"$modelName.onnx").toString + else Paths.get(folder, new File(folder).list().head).toString + val modelFile = new File(onnxFile) + val modelBytes = FileUtils.readFileToByteArray(modelFile) + val (session, env) = withSafeOnnxModelLoader(modelBytes, sessionOptions) // 4. Remove tmp folder FileHelper.delete(tmpFolder) val onnxWrapper = new OnnxWrapper(modelBytes) - onnxWrapper.m_session = session - onnxWrapper.m_env = env + onnxWrapper.ortSession = session + onnxWrapper.ortEnv = env onnxWrapper } - private def getCUDASessionConfig: SessionOptions = { + private def mapToSessionOptionsObject(sessionOptions: Map[String, String]): SessionOptions = { + val providers = OrtEnvironment.getAvailableProviders + if (providers.toArray.map(x => x.toString).contains("CUDA")) { + mapToCUDASessionConfig(sessionOptions) + } else mapToCPUSessionConfig(sessionOptions) + } + + private def mapToCUDASessionConfig(sessionOptionsMap: Map[String, String]): SessionOptions = { logger.info("Using CUDA") + println("Using CUDA") // it seems there is no easy way to use multiple GPUs // at least not without using multiple threads // TODO: add support for multiple GPUs - val gpuDeviceId = ConfigLoader.getConfigIntValue(ConfigHelper.onnxGpuDeviceId) + val gpuDeviceId = sessionOptionsMap(ConfigHelper.onnxGpuDeviceId).toInt val sessionOptions = new OrtSession.SessionOptions() logger.info(s"ONNX session option gpuDeviceId=$gpuDeviceId") @@ -158,7 +155,7 @@ object OnnxWrapper { sessionOptions } - private def getCPUSessionConfig: SessionOptions = { + private def mapToCPUSessionConfig(sessionOptionsMap: Map[String, String]): SessionOptions = { val defaultExecutionMode = ExecutionMode.SEQUENTIAL val defaultOptLevel = OptLevel.ALL_OPT @@ -186,17 +183,16 @@ object OnnxWrapper { } logger.info("Using CPUs") + println("Using CPUs") // TODO: the following configs can be tested for performance // However, so far, they seem to be slower than the ones used // opts.setIntraOpNumThreads(Runtime.getRuntime.availableProcessors()) // opts.setMemoryPatternOptimization(true) // opts.setCPUArenaAllocator(false) - val intraOpNumThreads = ConfigLoader.getConfigIntValue(ConfigHelper.onnxIntraOpNumThreads) - val optimizationLevel = getOptLevel( - ConfigLoader.getConfigStringValue(ConfigHelper.onnxOptimizationLevel)) - val executionMode = getExecutionMode( - ConfigLoader.getConfigStringValue(ConfigHelper.onnxExecutionMode)) + val intraOpNumThreads = sessionOptionsMap(ConfigHelper.onnxIntraOpNumThreads).toInt + val optimizationLevel = getOptLevel(sessionOptionsMap(ConfigHelper.onnxOptimizationLevel)) + val executionMode = getExecutionMode(sessionOptionsMap(ConfigHelper.onnxExecutionMode)) val sessionOptions = new OrtSession.SessionOptions() logger.info(s"ONNX session option intraOpNumThreads=$intraOpNumThreads") diff --git a/src/test/scala/com/johnsnowlabs/nlp/embeddings/BertEmbeddingsTestSpec.scala b/src/test/scala/com/johnsnowlabs/nlp/embeddings/BertEmbeddingsTestSpec.scala index 4cab3dc9791bf7..634a3717909ec6 100644 --- a/src/test/scala/com/johnsnowlabs/nlp/embeddings/BertEmbeddingsTestSpec.scala +++ b/src/test/scala/com/johnsnowlabs/nlp/embeddings/BertEmbeddingsTestSpec.scala @@ -229,4 +229,5 @@ class BertEmbeddingsTestSpec extends AnyFlatSpec { assert(totalTokens == totalEmbeddings) } + }