diff --git a/python/sparknlp/annotator/seq2seq/__init__.py b/python/sparknlp/annotator/seq2seq/__init__.py index f55474504816ee..76e34a8c774969 100644 --- a/python/sparknlp/annotator/seq2seq/__init__.py +++ b/python/sparknlp/annotator/seq2seq/__init__.py @@ -19,4 +19,5 @@ from sparknlp.annotator.seq2seq.bart_transformer import * from sparknlp.annotator.seq2seq.llama2_transformer import * from sparknlp.annotator.seq2seq.m2m100_transformer import * +from sparknlp.annotator.seq2seq.phi2_transformer import * from sparknlp.annotator.seq2seq.mistral_transformer import * diff --git a/python/sparknlp/annotator/seq2seq/m2m100_transformer.py b/python/sparknlp/annotator/seq2seq/m2m100_transformer.py index effed4ad82d6ad..bdef4546f49946 100644 --- a/python/sparknlp/annotator/seq2seq/m2m100_transformer.py +++ b/python/sparknlp/annotator/seq2seq/m2m100_transformer.py @@ -350,7 +350,7 @@ def __init__(self, classname="com.johnsnowlabs.nlp.annotators.seq2seq.M2M100Tran tgtLang="fr") @staticmethod - def loadSavedModel(folder, spark_session): + def loadSavedModel(folder, spark_session, use_openvino=False): """Loads a locally saved model. Parameters @@ -366,7 +366,7 @@ def loadSavedModel(folder, spark_session): The restored model """ from sparknlp.internal import _M2M100Loader - jModel = _M2M100Loader(folder, spark_session._jsparkSession)._java_obj + jModel = _M2M100Loader(folder, spark_session._jsparkSession, use_openvino)._java_obj return M2M100Transformer(java_model=jModel) @staticmethod diff --git a/python/sparknlp/annotator/seq2seq/phi2_transformer.py b/python/sparknlp/annotator/seq2seq/phi2_transformer.py new file mode 100644 index 00000000000000..e7cf7604da03c4 --- /dev/null +++ b/python/sparknlp/annotator/seq2seq/phi2_transformer.py @@ -0,0 +1,326 @@ +# Copyright 2017-2022 John Snow Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains classes for the Phi2Transformer.""" + +from sparknlp.common import * + + +class Phi2Transformer(AnnotatorModel, HasBatchedAnnotate, HasEngine): + """Phi-2: Textbooks Are All You Need. + + Phi-2 is a Transformer with 2.7 billion parameters. It was trained using the same data sources as Phi-1.5, + augmented with a new data source that consists of various NLP synthetic texts and filtered websites + (for safety and educational value). When assessed against benchmarks testing common sense, language understanding, + and logical reasoning, Phi-2 showcased a nearly state-of-the-art performance among models with less than 13 billion + parameters. + + Phi-2 hasn't been fine-tuned through reinforcement learning from human feedback. The intention behind crafting + this open-source model is to provide the research community with a non-restricted small model to explore vital + safety challenges, such as reducing toxicity, understanding societal biases, enhancing controllability, and more. + + Pretrained models can be loaded with :meth:`.pretrained` of the companion + object: + + >>> phi2 = Phi2Transformer.pretrained() \\ + ... .setInputCols(["document"]) \\ + ... 
.setOutputCol("generation")
+
+
+    The default model is ``"phi2-7b"``, if no name is provided. For available
+    pretrained models please see the `Models Hub
+    <https://sparknlp.org/models?q=Phi2>`__.
+
+    ====================== ======================
+    Input Annotation types Output Annotation type
+    ====================== ======================
+    ``DOCUMENT``           ``DOCUMENT``
+    ====================== ======================
+
+    Parameters
+    ----------
+    configProtoBytes
+        ConfigProto from tensorflow, serialized into byte array.
+    minOutputLength
+        Minimum length of the sequence to be generated, by default 0
+    maxOutputLength
+        Maximum length of output text, by default 20
+    doSample
+        Whether or not to use sampling; use greedy decoding otherwise, by default False
+    temperature
+        The value used to modulate the next token probabilities, by default 0.6
+    topK
+        The number of highest probability vocabulary tokens to keep for
+        top-k-filtering, by default 50
+    topP
+        Top cumulative probability for vocabulary tokens, by default 0.9
+
+        If set to float < 1, only the most probable tokens with probabilities
+        that add up to ``topP`` or higher are kept for generation.
+    repetitionPenalty
+        The parameter for repetition penalty, 1.0 means no penalty, by default
+        1.0
+    noRepeatNgramSize
+        If set to int > 0, all ngrams of that size can only occur once, by
+        default 0
+    ignoreTokenIds
+        A list of token ids which are ignored in the decoder's output, by
+        default []
+
+    Notes
+    -----
+    This is a very computationally expensive module, especially on larger
+    sequences. The use of an accelerator such as GPU is recommended.
+
+    References
+    ----------
+    - `Phi-2: Textbooks Are All You Need.
+      <https://www.microsoft.com/en-us/research/blog/phi-2-the-surprising-power-of-small-language-models/>`__
+    - https://huggingface.co/microsoft/phi-2
+
+    **Paper Abstract:**
+
+    *The massive increase in the size of language models to hundreds of billions of
+    parameters has unlocked a host of emerging capabilities that have redefined the
+    landscape of natural language processing. A question remains whether such emergent
+    abilities can be achieved at a smaller scale using strategic choices for training,
+    e.g., data selection. Our line of work with the Phi models aims to answer this
+    question by training SLMs that achieve performance on par with models of much
+    higher scale (yet still far from the frontier models).*
+
+    Examples
+    --------
+    >>> import sparknlp
+    >>> from sparknlp.base import *
+    >>> from sparknlp.annotator import *
+    >>> from pyspark.ml import Pipeline
+    >>> documentAssembler = DocumentAssembler() \\
+    ...     .setInputCol("text") \\
+    ...     .setOutputCol("documents")
+    >>> phi2 = Phi2Transformer.pretrained("phi2-7b") \\
+    ...     .setInputCols(["documents"]) \\
+    ...     .setMaxOutputLength(50) \\
+    ...     .setOutputCol("generation")
+    >>> pipeline = Pipeline().setStages([documentAssembler, phi2])
+    >>> data = spark.createDataFrame([["My name is Leonardo."]]).toDF("text")
+    >>> result = pipeline.fit(data).transform(data)
+    >>> result.select("generation.result").show(truncate=False)
+    +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
+    |result                                                                                                                                                                                              |
+    +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
+    |[My name is Leonardo . I am a student of the University of California, Berkeley. I am interested in the field of Artificial Intelligence and its applications in the real world. I have a strong    |
+    | passion for learning and am always looking for ways to improve my knowledge and skills]                                                                                                            |
+    +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
+    """
+
+    name = "Phi2Transformer"
+
+    inputAnnotatorTypes = [AnnotatorType.DOCUMENT]
+
+    outputAnnotatorType = AnnotatorType.DOCUMENT
+
+    configProtoBytes = Param(Params._dummy(), "configProtoBytes",
+                             "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()",
+                             TypeConverters.toListInt)
+
+    minOutputLength = Param(Params._dummy(), "minOutputLength", "Minimum length of the sequence to be generated",
+                            typeConverter=TypeConverters.toInt)
+
+    maxOutputLength = Param(Params._dummy(), "maxOutputLength", "Maximum length of output text",
+                            typeConverter=TypeConverters.toInt)
+
+    doSample = Param(Params._dummy(), "doSample", "Whether or not to use sampling; use greedy decoding otherwise",
+                     typeConverter=TypeConverters.toBoolean)
+
+    temperature = Param(Params._dummy(), "temperature", "The value used to modulate the next token probabilities",
+                        typeConverter=TypeConverters.toFloat)
+
+    topK = Param(Params._dummy(), "topK",
+                 "The number of highest probability vocabulary tokens to keep for top-k-filtering",
+                 typeConverter=TypeConverters.toInt)
+
+    topP = Param(Params._dummy(), "topP",
+                 "If set to float < 1, only the most probable tokens with probabilities that add up to ``top_p`` or higher are kept for generation",
+                 typeConverter=TypeConverters.toFloat)
+
+    repetitionPenalty = Param(Params._dummy(), "repetitionPenalty",
+                              "The parameter for repetition penalty. 1.0 means no penalty. See `this paper <https://arxiv.org/abs/1909.05858>`__ for more details",
+                              typeConverter=TypeConverters.toFloat)
+
+    noRepeatNgramSize = Param(Params._dummy(), "noRepeatNgramSize",
+                              "If set to int > 0, all ngrams of that size can only occur once",
+                              typeConverter=TypeConverters.toInt)
+
+    ignoreTokenIds = Param(Params._dummy(), "ignoreTokenIds",
+                           "A list of token ids which are ignored in the decoder's output",
+                           typeConverter=TypeConverters.toListInt)
+
+    def setIgnoreTokenIds(self, value):
+        """A list of token ids which are ignored in the decoder's output.
+
+        Parameters
+        ----------
+        value : List[int]
+            The words to be filtered out
+        """
+        return self._set(ignoreTokenIds=value)
+
+    def setConfigProtoBytes(self, b):
+        """Sets configProto from tensorflow, serialized into byte array.
+
+        Parameters
+        ----------
+        b : List[int]
+            ConfigProto from tensorflow, serialized into byte array
+        """
+        return self._set(configProtoBytes=b)
+
+    def setMinOutputLength(self, value):
+        """Sets minimum length of the sequence to be generated.
+
+        Parameters
+        ----------
+        value : int
+            Minimum length of the sequence to be generated
+        """
+        return self._set(minOutputLength=value)
+
+    def setMaxOutputLength(self, value):
+        """Sets maximum length of output text.
+
+        Parameters
+        ----------
+        value : int
+            Maximum length of output text
+        """
+        return self._set(maxOutputLength=value)
+
+    def setDoSample(self, value):
+        """Sets whether or not to use sampling, use greedy decoding otherwise.
+
+        Parameters
+        ----------
+        value : bool
+            Whether or not to use sampling; use greedy decoding otherwise
+        """
+        return self._set(doSample=value)
+
+    def setTemperature(self, value):
+        """Sets the value used to modulate the next token probabilities.
+
+        Parameters
+        ----------
+        value : float
+            The value used to modulate the next token probabilities
+        """
+        return self._set(temperature=value)
+
+    def setTopK(self, value):
+        """Sets the number of highest probability vocabulary tokens to keep for
+        top-k-filtering.
+
+        Parameters
+        ----------
+        value : int
+            Number of highest probability vocabulary tokens to keep
+        """
+        return self._set(topK=value)
+
+    def setTopP(self, value):
+        """Sets the top cumulative probability for vocabulary tokens.
+
+        If set to float < 1, only the most probable tokens with probabilities
+        that add up to ``topP`` or higher are kept for generation.
+
+        Parameters
+        ----------
+        value : float
+            Cumulative probability for vocabulary tokens
+        """
+        return self._set(topP=value)
+
+    def setRepetitionPenalty(self, value):
+        """Sets the parameter for repetition penalty. 1.0 means no penalty.
+
+        Parameters
+        ----------
+        value : float
+            The repetition penalty
+
+        References
+        ----------
+        See `Ctrl: A Conditional Transformer Language Model For Controllable
+        Generation <https://arxiv.org/abs/1909.05858>`__ for more details.
+        """
+        return self._set(repetitionPenalty=value)
+
+    def setNoRepeatNgramSize(self, value):
+        """Sets size of n-grams that can only occur once.
+
+        If set to int > 0, all ngrams of that size can only occur once.
+
+        Parameters
+        ----------
+        value : int
+            N-gram size can only occur once
+        """
+        return self._set(noRepeatNgramSize=value)
+
+    @keyword_only
+    def __init__(self, classname="com.johnsnowlabs.nlp.annotators.seq2seq.Phi2Transformer", java_model=None):
+        super(Phi2Transformer, self).__init__(classname=classname, java_model=java_model)
+        self._setDefault(minOutputLength=0, maxOutputLength=20, doSample=False, temperature=0.6, topK=50, topP=0.9,
+                         repetitionPenalty=1.0, noRepeatNgramSize=0, ignoreTokenIds=[], batchSize=1)
+
+    @staticmethod
+    def loadSavedModel(folder, spark_session, use_openvino=False):
+        """Loads a locally saved model.
+
+        Parameters
+        ----------
+        folder : str
+            Folder of the saved model
+        spark_session : pyspark.sql.SparkSession
+            The current SparkSession
+        use_openvino : bool, optional
+            Whether to use the OpenVINO engine, by default False
+
+        Returns
+        -------
+        Phi2Transformer
+            The restored model
+        """
+        from sparknlp.internal import _Phi2Loader
+        jModel = _Phi2Loader(folder, spark_session._jsparkSession, use_openvino)._java_obj
+        return Phi2Transformer(java_model=jModel)
+
+    @staticmethod
+    def pretrained(name="phi2-7b", lang="en", remote_loc=None):
+        """Downloads and loads a pretrained model.
+
+        Parameters
+        ----------
+        name : str, optional
+            Name of the pretrained model, by default "phi2-7b"
+        lang : str, optional
+            Language of the pretrained model, by default "en"
+        remote_loc : str, optional
+            Optional remote address of the resource, by default None. Will use
+            Spark NLP's repositories otherwise.
+
+        Returns
+        -------
+        Phi2Transformer
+            The restored model
+        """
+        from sparknlp.pretrained import ResourceDownloader
+        return ResourceDownloader.downloadModel(Phi2Transformer, name, lang, remote_loc)
diff --git a/python/sparknlp/internal/__init__.py b/python/sparknlp/internal/__init__.py
index c76d830e682658..deeff9c5189f52 100644
--- a/python/sparknlp/internal/__init__.py
+++ b/python/sparknlp/internal/__init__.py
@@ -268,7 +268,8 @@ def __init__(self, path, jspark):
 class _M2M100Loader(ExtendedJavaWrapper):
-    def __init__(self, path, jspark):
+    def __init__(self, path, jspark, use_openvino=False):
         super(_M2M100Loader, self).__init__(
             "com.johnsnowlabs.nlp.annotators.seq2seq.M2M100Transformer.loadSavedModel",
             path,
-            jspark)
+            jspark,
+            use_openvino)
@@ -279,7 +279,12 @@ def __init__(self, path, jspark):
 class _MistralLoader(ExtendedJavaWrapper):
     def __init__(self, path, jspark, use_openvino=False):
         super(_MistralLoader, self).__init__(
-            "com.johnsnowlabs.nlp.annotators.seq2seq.MistralTransformer.loadSavedModel", path, jspark, use_openvino)
+            "com.johnsnowlabs.nlp.annotators.seq2seq.MistralTransformer.loadSavedModel",
+            path,
+            jspark,
+            use_openvino,
+        )
+
 class _MarianLoader(ExtendedJavaWrapper):
     def __init__(self, path, jspark):
@@ -299,6 +304,16 @@ def __init__(self, path, jspark):
         )
+
+class _Phi2Loader(ExtendedJavaWrapper):
+    def __init__(self, path, jspark, use_openvino=False):
+        super(_Phi2Loader, self).__init__(
+            "com.johnsnowlabs.nlp.annotators.seq2seq.Phi2Transformer.loadSavedModel",
+            path,
+            jspark,
+            use_openvino,
+        )
+
+
 class _RoBertaLoader(ExtendedJavaWrapper):
     def __init__(self, path, jspark, use_openvino=False):
         super(_RoBertaLoader, self).__init__(
diff --git a/python/test/annotator/seq2seq/phi2_transformer_test.py b/python/test/annotator/seq2seq/phi2_transformer_test.py
new file mode 100644
index 00000000000000..b434424c655b58
--- /dev/null
+++ b/python/test/annotator/seq2seq/phi2_transformer_test.py
@@ -0,0 +1,47 @@
+# Copyright 2017-2022 John Snow Labs
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
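+
+# This spec downloads the pretrained Phi-2 model on first run, hence the
+# @pytest.mark.slow marker on the test class below.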
+import unittest + +import pytest + +from sparknlp.annotator import * +from sparknlp.base import * +from test.util import SparkContextForTest + + +@pytest.mark.slow +class Phi2TransformerTextGenerationTestSpec(unittest.TestCase): + def setUp(self): + self.spark = SparkContextForTest.spark + + def runTest(self): + data = self.spark.createDataFrame([ + [1, """Leonardo Da Vinci invented the microscope?""".strip().replace("\n", " ")]]).toDF("id", "text") + + document_assembler = DocumentAssembler() \ + .setInputCol("text") \ + .setOutputCol("documents") + + phi2 = Phi2Transformer \ + .pretrained() \ + .setMaxOutputLength(50) \ + .setDoSample(False) \ + .setInputCols(["documents"]) \ + .setOutputCol("generation") + + pipeline = Pipeline().setStages([document_assembler, phi2]) + results = pipeline.fit(data).transform(data) + + results.select("generation.result").show(truncate=False) + diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/M2M100.scala b/src/main/scala/com/johnsnowlabs/ml/ai/M2M100.scala index 7cc5f4ff8cc302..0ae56e53768e3b 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/M2M100.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/M2M100.scala @@ -20,17 +20,23 @@ import ai.onnxruntime.{OnnxTensor, OrtEnvironment, OrtSession} import com.johnsnowlabs.ml.ai.util.Generation.{Generate, GenerationConfig} import com.johnsnowlabs.ml.onnx.OnnxSession import com.johnsnowlabs.ml.onnx.OnnxWrapper.EncoderDecoderWithoutPastWrappers +import com.johnsnowlabs.ml.openvino.OpenvinoWrapper.{ + EncoderDecoderWithoutPastWrappers => OpenvinoEncoderDecoderWithoutPastWrappers +} import com.johnsnowlabs.ml.onnx.TensorResources.implicits._ +import com.johnsnowlabs.ml.openvino.OpenvinoWrapper import com.johnsnowlabs.ml.tensorflow.sentencepiece.SentencePieceWrapper import com.johnsnowlabs.nlp.Annotation +import com.johnsnowlabs.ml.util.{ONNX, Openvino, TensorFlow} + +import scala.collection.JavaConverters._ import com.johnsnowlabs.nlp.AnnotatorType.DOCUMENT import org.intel.openvino.InferRequest import org.tensorflow.{Session, Tensor} -import scala.collection.JavaConverters._ - private[johnsnowlabs] class M2M100( - val onnxWrappers: EncoderDecoderWithoutPastWrappers, + val onnxWrappers: Option[EncoderDecoderWithoutPastWrappers], + val openvinoWrapper: Option[OpenvinoEncoderDecoderWithoutPastWrappers], val spp: SentencePieceWrapper, generationConfig: GenerationConfig, vocab: Map[String, Int]) @@ -38,6 +44,14 @@ private[johnsnowlabs] class M2M100( with Generate { private val onnxSessionOptions: Map[String, String] = new OnnxSession().getSessionOptions + private var nextPositionId: Option[Array[Long]] = None + private var decoderEncoderStateTensorsOV: Option[org.intel.openvino.Tensor] = None + private var encoderAttentionMaskTensorsOV: Option[org.intel.openvino.Tensor] = None + + val detectedEngine: String = + if (onnxWrappers.isDefined) ONNX.name + else if (openvinoWrapper.isDefined) Openvino.name + else ONNX.name private val GenerationConfig( bosTokenId: Int, @@ -133,8 +147,7 @@ private[johnsnowlabs] class M2M100( maxInputLength: Int, srcLangToken: Int, tgtLangToken: Int): Array[Array[Int]] = { - val (encoderSession, encoderEnv) = onnxWrappers.encoder.getSession(onnxSessionOptions) - val (decoderSession, decoderEnv) = onnxWrappers.decoder.getSession(onnxSessionOptions) + val ignoreTokenIdsInt = ignoreTokenIds val expandedEncoderInputsVals = batch.flatMap(x => List.fill(beamSize)(x.take(maxInputLength))).toArray @@ -162,16 +175,49 @@ private[johnsnowlabs] class M2M100( effectiveBatch_size = 
expandedEncoderInputsVals.length effectiveBatch_mult = 1 } + var decoderEncoderStateTensors: Either[Tensor, OnnxTensor] = null + var encoderAttentionMaskTensors: Either[Tensor, OnnxTensor] = null + + var (encoderSession, encoderEnv): (OrtSession, OrtEnvironment) = (null, null) + var (decoderSession, decoderEnv): (OrtSession, OrtEnvironment) = (null, null) + val ovInferRequest: Option[InferRequest] = detectedEngine match { + case ONNX.name => None + case Openvino.name => + Some(openvinoWrapper.get.decoder.getCompiledModel().create_infer_request()) + } + + if (detectedEngine == TensorFlow.name) { + // not implemented yet + return Array() + } else if (detectedEngine == ONNX.name) { + val (_encoderSession, _encoderEnv) = onnxWrappers.get.encoder.getSession(onnxSessionOptions) + val (_decoderSession, _decoderEnv) = onnxWrappers.get.decoder.getSession(onnxSessionOptions) - // run encoder - val decoderEncoderStateTensors = - getEncoderOutput(expandedEncoderInputsVals, Right((encoderEnv, encoderSession))) + encoderSession = _encoderSession + encoderEnv = _encoderEnv + decoderSession = _decoderSession + decoderEnv = _decoderEnv - val encoderAttentionMaskTensors = - Right( + // run encoder + decoderEncoderStateTensors = + getEncoderOutput(expandedEncoderInputsVals, Right((encoderEnv, encoderSession))) + + encoderAttentionMaskTensors = Right( OnnxTensor .createTensor(decoderEnv, expandedEncoderInputsVals.toArray.map(_.map(_ => 1L)))) + } else if (detectedEngine == Openvino.name) { + val encoderInferRequest = + openvinoWrapper.get.encoder.getCompiledModel().create_infer_request() + decoderEncoderStateTensorsOV = Some( + getEncoderOutputOv(expandedEncoderInputsVals, encoderInferRequest)) + + encoderAttentionMaskTensorsOV = Some( + new org.intel.openvino.Tensor( + Array(expandedEncoderInputsVals.length, expandedEncoderInputsVals.head.length), + expandedEncoderInputsVals.flatMap { tokenIds => tokenIds.map(_ => 1L) })) + + } // output with beam search val modelOutputs = generate( batch, @@ -194,7 +240,8 @@ private[johnsnowlabs] class M2M100( randomSeed, ignoreTokenIdsInt, Right((decoderEnv, decoderSession)), - applySoftmax = false) + applySoftmax = false, + ovInferRequest = ovInferRequest) // Run the prompt through the decoder and get the past // val decoderOutputs = @@ -204,21 +251,23 @@ private[johnsnowlabs] class M2M100( // encoderAttentionMaskTensors, // onnxSession = (decoderSession, decoderEnv)) - // close sessions - decoderEncoderStateTensors.fold( - tfTensor => { - // not implemented yet - }, - onnxTensor => onnxTensor.close()) + if (detectedEngine == ONNX.name) { + // close sessions + decoderEncoderStateTensors.fold( + tfTensor => { + // not implemented yet + }, + onnxTensor => onnxTensor.close()) - encoderAttentionMaskTensors.fold( - tfTensor => { - // not implemented yet - }, - onnxTensor => onnxTensor.close()) + encoderAttentionMaskTensors.fold( + tfTensor => { + // not implemented yet + }, + onnxTensor => onnxTensor.close()) - encoderEnv.close() - decoderEnv.close() + encoderEnv.close() + decoderEnv.close() + } // decoderOutputs modelOutputs @@ -372,6 +421,34 @@ private[johnsnowlabs] class M2M100( }) } + private def getEncoderOutputOv( + encoderInputIds: Seq[Array[Int]], + inferRequest: InferRequest): org.intel.openvino.Tensor = { + + val encoderAttentionMask: Array[Long] = + encoderInputIds.flatMap { tokenIds => tokenIds.map(_ => 1L) }(collection.breakOut) + val encoderAttentionMaskTensor: org.intel.openvino.Tensor = + new org.intel.openvino.Tensor( + Array(encoderInputIds.length, 
encoderInputIds.head.length), + encoderAttentionMask) + + val encoderInputIdsLong: Array[Long] = + encoderInputIds.flatMap { tokenIds => tokenIds.map(_.toLong) }(collection.breakOut) + + val encoderInputIdsLongTensor: org.intel.openvino.Tensor = + new org.intel.openvino.Tensor( + Array(encoderInputIds.length, encoderInputIds.head.length), + encoderInputIdsLong) + + inferRequest.set_tensor(OpenVinoSignatures.encoderInputIDs, encoderInputIdsLongTensor) + inferRequest.set_tensor(OpenVinoSignatures.encoderAttentionMask, encoderAttentionMaskTensor) + + inferRequest.infer() + + val result = inferRequest.get_tensor(OpenVinoSignatures.encoderOutput) + result + } + /** Gets the model output * @param encoderInputIds * Input IDs for the Encoder @@ -397,13 +474,27 @@ private[johnsnowlabs] class M2M100( session: Either[Session, (OrtEnvironment, OrtSession)], ovInferRequest: Option[InferRequest]): Array[Array[Float]] = { - session.fold( - tfSession => { +// session.fold( +// tfSession => { +// // not implemented yet +// Array() +// }, +// onnxSession => { +// val (env, decoderSession) = onnxSession +// val decoderOutputs = +// getDecoderOutputs( +// decoderInputIds.toArray, +// decoderEncoderStateTensors, +// encoderAttentionMaskTensors, +// onnxSession = (decoderSession, env)) +// decoderOutputs +// }) + detectedEngine match { + case TensorFlow.name => // not implemented yet Array() - }, - onnxSession => { - val (env, decoderSession) = onnxSession + case ONNX.name => + val (env, decoderSession) = session.right.get val decoderOutputs = getDecoderOutputs( decoderInputIds.toArray, @@ -411,7 +502,15 @@ private[johnsnowlabs] class M2M100( encoderAttentionMaskTensors, onnxSession = (decoderSession, env)) decoderOutputs - }) + case Openvino.name => + val decoderOutputs = + getDecoderOutputsOv( + decoderInputIds.toArray, + decoderEncoderStateTensorsOV.get, + encoderAttentionMaskTensorsOV.get, + ovInferRequest.get) + decoderOutputs + } } @@ -473,6 +572,51 @@ private[johnsnowlabs] class M2M100( decoderOutputs.toArray } + private def getDecoderOutputsOv( + inputIds: Array[Array[Int]], + decoderEncoderStateTensors: org.intel.openvino.Tensor, + encoderAttentionMaskTensors: org.intel.openvino.Tensor, + inferRequest: InferRequest): (Array[Array[Float]]) = { + val (inputIdsLong, inputPositionIDsLong): (Array[Long], Array[Long]) = + if (nextPositionId.isDefined) { + val inpIdsLong = inputIds.flatMap { tokenIds => tokenIds.map(_.toLong) } + (inpIdsLong, nextPositionId.get) + } else { + val inpIdsLong = inputIds.flatMap { tokenIds => tokenIds.map(_.toLong) } + val posIdsLong = inputIds.flatMap { tokenIds => + tokenIds.zipWithIndex.map { case (_, i) => + i.toLong + } + } + (inpIdsLong, posIdsLong) + } + + val batchSize: Int = inputIds.length + val shape: Array[Int] = Array(batchSize, inputIdsLong.length / batchSize) + + val inputIdsLongTensor: org.intel.openvino.Tensor = + new org.intel.openvino.Tensor(shape, inputIdsLong) + + inferRequest.set_tensor("input_ids", inputIdsLongTensor) + inferRequest.set_tensor("encoder_hidden_states", decoderEncoderStateTensors) + inferRequest.set_tensor("encoder_attention_mask", encoderAttentionMaskTensors) + + inferRequest.infer() + + val result = inferRequest.get_tensor("logits") + val logitsRaw = result.data() + nextPositionId = Some(inputIds.map(tokenIds => tokenIds.length.toLong)) + + val sequenceLength = inputIdsLong.length / batchSize + val decoderOutputs = (0 until batchSize).map(i => { + logitsRaw + .slice( + i * sequenceLength * vocabSize + (sequenceLength - 1) * vocabSize, + i * 
sequenceLength * vocabSize + sequenceLength * vocabSize) + }) + decoderOutputs.toArray + } + /** Gets the index with the highest score * * @param scores @@ -539,4 +683,18 @@ private[johnsnowlabs] class M2M100( val decoderOutput: String = "logits" } + private object OpenVinoSignatures { + val encoderInputIDs: String = "input_ids" + val encoderAttentionMask: String = "attention_mask" + + val encoderOutput: String = "last_hidden_state" + + val decoderInputIDs: String = "input_ids" + val decoderEncoderAttentionMask: String = "encoder_attention_mask" + val decoderAttentionMask: String = "attention_mask" + val decoderEncoderState: String = "encoder_hidden_states" + + val decoderOutput: String = "logits" + } + } diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/Phi2.scala b/src/main/scala/com/johnsnowlabs/ml/ai/Phi2.scala new file mode 100644 index 00000000000000..400a103abb22cd --- /dev/null +++ b/src/main/scala/com/johnsnowlabs/ml/ai/Phi2.scala @@ -0,0 +1,454 @@ +/* + * Copyright 2017 - 2023 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.johnsnowlabs.ml.ai + +import ai.onnxruntime.{OnnxTensor, OrtEnvironment, OrtSession} +import com.johnsnowlabs.ml.ai.util.Generation.{Generate, GenerationConfig} +import com.johnsnowlabs.ml.onnx.OnnxSession +import com.johnsnowlabs.ml.onnx.OnnxWrapper.DecoderWrappers +import com.johnsnowlabs.ml.onnx.TensorResources.implicits._ +import com.johnsnowlabs.ml.openvino.OpenvinoWrapper +import com.johnsnowlabs.ml.tensorflow.sentencepiece.SentencePieceWrapper +import com.johnsnowlabs.ml.util.{ONNX, Openvino, TensorFlow} +import com.johnsnowlabs.nlp.Annotation +import com.johnsnowlabs.nlp.AnnotatorType.DOCUMENT +import com.johnsnowlabs.nlp.annotators.common.SentenceSplit +import com.johnsnowlabs.nlp.annotators.tokenizer.bpe.{BpeTokenizer, Phi2Tokenizer} +import org.intel.openvino.InferRequest +import org.tensorflow.{Session, Tensor} + +import scala.collection.JavaConverters._ + +private[johnsnowlabs] class Phi2( + val onnxWrappers: Option[DecoderWrappers], + val openvinoWrapper: Option[OpenvinoWrapper], + merges: Map[(String, String), Int], + vocabulary: Map[String, Int], + generationConfig: GenerationConfig) + extends Serializable + with Generate { + + private val onnxSessionOptions: Map[String, String] = new OnnxSession().getSessionOptions + val detectedEngine: String = + if (onnxWrappers.isDefined) ONNX.name + else if (openvinoWrapper.isDefined) Openvino.name + else ONNX.name + private var nextPositionId: Option[Array[Long]] = None + val bpeTokenizer: Phi2Tokenizer = BpeTokenizer + .forModel("phi2", merges = merges, vocab = vocabulary, padWithSequenceTokens = false) + .asInstanceOf[Phi2Tokenizer] + + private val GenerationConfig( + bosTokenId: Int, + paddingTokenId: Int, + eosTokenId: Int, + vocabSize: Int, + beginSuppressTokens, + suppressTokenIds, + forcedDecoderIds) = + generationConfig + + /** Decode a sequence of sentences + * @param sentences + * Sequence of sentences + * @return + * Sequence of decoded sentences + */ 
+ def decode(sentences: Array[Array[Int]]): Seq[String] = { + sentences.map(s => bpeTokenizer.decodeTokens(s.map(_.toInt))) + } + + /** Encode a sequence of sentences + * @param sentences + * Sequence of sentences + * @return + * Sequence of encoded sentences + */ + def encode(sentences: Seq[Annotation]): Seq[Array[Int]] = { + SentenceSplit + .unpack(sentences) + .map(s => { + val sentWithTask = s + bpeTokenizer + .tokenize(sentWithTask) + .map(bpeTokenizer.encode) + .flatMap(_.map(_.pieceId)) + }) + } + + def tag( + batch: Seq[Array[Int]], + minOutputLength: Int, + maxOutputLength: Int, + doSample: Boolean, + temperature: Double, + topK: Int, + topP: Double, + repetitionPenalty: Double, + noRepeatNgramSize: Int, + randomSeed: Option[Long], + ignoreTokenIds: Array[Int] = Array(), + beamSize: Int, + maxInputLength: Int): Array[Array[Int]] = { + val ignoreTokenIdsInt = ignoreTokenIds + val expandedDecoderInputsVals = batch + val sequencesLength = expandedDecoderInputsVals.map(x => x.length).toArray + val maxSentenceLength = sequencesLength.max // - curLen + + val numReturn_sequences = 1 + // from config + + var effectiveBatch_size = 1 + var effectiveBatch_mult = 1 + + if (doSample) { + effectiveBatch_size = expandedDecoderInputsVals.length * numReturn_sequences + effectiveBatch_mult = numReturn_sequences + } else { + effectiveBatch_size = expandedDecoderInputsVals.length + effectiveBatch_mult = 1 + } + + // Run the prompt through the decoder and get the past +// val decoderOutputs = +// generateGreedyOnnx( +// expandedDecoderInputsVals.toArray, +// (encoderSession, env), +// maxOutputLength) + val (decoderEncoderStateTensors, encoderAttentionMaskTensors, session) = + detectedEngine match { + case ONNX.name => + // dummy tensors for decoder encode state and attention mask + val (encoderSession, env) = onnxWrappers.get.decoder.getSession(onnxSessionOptions) + ( + Right(OnnxTensor.createTensor(env, Array(0))), + Right(OnnxTensor.createTensor(env, Array(1))), + Right((env, encoderSession))) + case Openvino.name => + // not needed + (null, null, null) + } + val ovInferRequest: Option[InferRequest] = detectedEngine match { + case ONNX.name => None + case Openvino.name => Some(openvinoWrapper.get.getCompiledModel().create_infer_request()) + } + // output with beam search + val modelOutputs = generate( + batch, + decoderEncoderStateTensors, + encoderAttentionMaskTensors, + expandedDecoderInputsVals.toArray, + maxOutputLength + maxSentenceLength, + minOutputLength, + doSample, + beamSize, + 1, + temperature, + topK, + topP, + repetitionPenalty, + noRepeatNgramSize, + this.vocabSize, + this.eosTokenId, + this.paddingTokenId, + randomSeed, + ignoreTokenIdsInt, + session, + applySoftmax = false, + ovInferRequest = ovInferRequest) + +// decoderOutputs + modelOutputs + } + + def predict( + sentences: Seq[Annotation], + batchSize: Int, + minOutputLength: Int, + maxOutputLength: Int, + doSample: Boolean, + temperature: Double, + topK: Int, + topP: Double, + repetitionPenalty: Double, + noRepeatNgramSize: Int, + randomSeed: Option[Long] = None, + ignoreTokenIds: Array[Int] = Array(), + beamSize: Int, + maxInputLength: Int): Seq[Annotation] = { + + val batchDecoder = sentences.grouped(batchSize).toArray.flatMap { batch => + val batchSP = encode(batch) + val spIds = tag( + batchSP, + minOutputLength, + maxOutputLength, + doSample, + temperature, + topK, + topP, + repetitionPenalty, + noRepeatNgramSize, + randomSeed, + ignoreTokenIds, + beamSize, + maxInputLength) + + decode(spIds) + + } + + var sentBegin, 
nextSentEnd = 0 + val annotations = batchDecoder.zip(sentences).map { case (content, sent) => + nextSentEnd += content.length - 1 + val annots = new Annotation( + annotatorType = DOCUMENT, + begin = sentBegin, + end = nextSentEnd, + result = content, + metadata = sent.metadata) + sentBegin += nextSentEnd + 1 + annots + } + annotations + } + + private def getDecoderOutputsWithPast( + inputIds: Array[Array[Int]], + decoderPast: Map[String, OnnxTensor], + onnxSession: (OrtSession, OrtEnvironment)) + : (Array[Array[Float]], Map[String, OnnxTensor]) = { + val (session, env) = onnxSession + + val lastTokens: Array[Array[Long]] = + inputIds.map { tokenIds => + Array(tokenIds.last.toLong) + } + + val lastTokensTensor: OnnxTensor = + OnnxTensor.createTensor(env, lastTokens) + val decoderAttentionMask: OnnxTensor = + OnnxTensor.createTensor(env, lastTokens.map(_.map(_ => 1L))) + val decoderWithPastInputs: java.util.Map[String, OnnxTensor] = (Map( + OnnxSignatures.decoderInputIDs -> lastTokensTensor, + OnnxSignatures.decoderAttentionMask -> decoderAttentionMask) ++ decoderPast).asJava + val sessionOutput = session.run(decoderWithPastInputs) + val logits = sessionOutput.getFloatArray(OnnxSignatures.decoderOutput) + val decoderPresent = sessionOutput.getOnnxTensors(OnnxSignatures.decoderPresent) + lastTokensTensor.close() + val batchLogits = logits.grouped(vocabSize).toArray + (batchLogits, decoderPresent) + + } + + override def getModelOutput( + encoderInputIds: Seq[Array[Int]], + decoderInputIds: Seq[Array[Int]], + decoderEncoderStateTensors: Either[Tensor, OnnxTensor], + encoderAttentionMaskTensors: Either[Tensor, OnnxTensor], + maxLength: Int, + session: Either[Session, (OrtEnvironment, OrtSession)], + ovInferRequest: Option[InferRequest]): Array[Array[Float]] = { + + detectedEngine match { + case TensorFlow.name => + // not implemented yet + Array() + case ONNX.name => + val (env, decoderSession) = session.right.get + val decoderOutputs = + getDecoderOutputs(decoderInputIds.toArray, onnxSession = (decoderSession, env)) + decoderOutputs + case Openvino.name => + val decoderOutputs = + getDecoderOutputsOv(decoderInputIds.toArray, ovInferRequest.get) + decoderOutputs + } + } + + private def getDecoderOutputsOv( + inputIds: Array[Array[Int]], + inferRequest: InferRequest): (Array[Array[Float]]) = { + val (inputIdsLong, inputPositionIDsLong): (Array[Long], Array[Long]) = + if (nextPositionId.isDefined) { + val inpIdsLong = inputIds.map { tokenIds => tokenIds.last.toLong } + (inpIdsLong, nextPositionId.get) + } else { + val inpIdsLong = inputIds.flatMap { tokenIds => tokenIds.map(_.toLong) } + val posIdsLong = inputIds.flatMap { tokenIds => + tokenIds.zipWithIndex.map { case (_, i) => + i.toLong + } + } + (inpIdsLong, posIdsLong) + } + val attentionMask: Array[Long] = + inputIds.flatMap { tokenIds => tokenIds.map(_ => 1L) } + + val batchSize: Int = inputIds.length + val beamIdx: Array[Int] = new Array[Int](batchSize) + val shape: Array[Int] = Array(batchSize, inputIdsLong.length / batchSize) + + val inputIdsLongTensor: org.intel.openvino.Tensor = + new org.intel.openvino.Tensor(shape, inputIdsLong) + val decoderAttentionMask: org.intel.openvino.Tensor = + new org.intel.openvino.Tensor(Array(batchSize, inputIds.head.length), attentionMask) + val decoderPositionIDs: org.intel.openvino.Tensor = + new org.intel.openvino.Tensor(shape, inputPositionIDsLong) + val beamIdxTensor: org.intel.openvino.Tensor = + new org.intel.openvino.Tensor(Array(batchSize), beamIdx) + + 
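+    // Bind the decoder inputs: once a past exists, input_ids carries only the newest token
+    // per sequence (its position comes from nextPositionId), while beam_idx, which stateful
+    // OpenVINO decoder exports use to reorder the key/value cache, is left as zeros here.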
inferRequest.set_tensor(OpenVinoSignatures.decoderInputIDs, inputIdsLongTensor) + inferRequest.set_tensor(OpenVinoSignatures.decoderAttentionMask, decoderAttentionMask) + inferRequest.set_tensor(OpenVinoSignatures.decoderPositionIDs, decoderPositionIDs) + inferRequest.set_tensor(OpenVinoSignatures.decoderBeamIdx, beamIdxTensor) + + inferRequest.infer() + + val result = inferRequest.get_tensor(OpenVinoSignatures.decoderOutput) + val logitsRaw = result.data() + nextPositionId = Some(inputIds.map(tokenIds => tokenIds.length.toLong)) + + val sequenceLength = inputIdsLong.length / batchSize + val decoderOutputs = (0 until batchSize).map(i => { + logitsRaw + .slice( + i * sequenceLength * vocabSize + (sequenceLength - 1) * vocabSize, + i * sequenceLength * vocabSize + sequenceLength * vocabSize) + }) + decoderOutputs.toArray + } + private def getDecoderOutputs( + inputIds: Array[Array[Int]], + onnxSession: (OrtSession, OrtEnvironment)): (Array[Array[Float]]) = { + val (session, env) = onnxSession + + val inputIdsLong: Array[Array[Long]] = + inputIds.map { tokenIds => tokenIds.map(_.toLong) } + + val inputPositionIDsLong: Array[Array[Long]] = + inputIds.map { tokenIds => + tokenIds.zipWithIndex.map { case (_, i) => + i.toLong + } + } + + val inputIdsLongTensor: OnnxTensor = + OnnxTensor.createTensor(env, inputIdsLong) + val decoderAttentionMask: OnnxTensor = + OnnxTensor.createTensor(env, inputIdsLong.map(_.map(_ => 1L))) + val decoderPositionIDs: OnnxTensor = + OnnxTensor.createTensor(env, inputPositionIDsLong) + + val decoderInputs: java.util.Map[String, OnnxTensor] = Map( + OnnxSignatures.decoderInputIDs -> inputIdsLongTensor, + OnnxSignatures.decoderAttentionMask -> decoderAttentionMask, + OnnxSignatures.decoderPositionIDs -> decoderPositionIDs).asJava + val sessionOutput = session.run(decoderInputs) + + val sequenceLength = inputIds.head.length + val batchSize = inputIds.length + +// val logits = sessionOutput.getFloatArray(OnnxSignatures.decoderOutput) +// inputIdsLongTensor.close() +// decoderPositionIDs.close() +// decoderAttentionMask.close() +// val batchLogits = logits.grouped(vocabSize).toArray +// batchLogits + + val logitsRaw = sessionOutput.getFloatArray(OnnxSignatures.decoderOutput) + val decoderOutputs = (0 until batchSize).map(i => { + logitsRaw + .slice( + i * sequenceLength * vocabSize + (sequenceLength - 1) * vocabSize, + i * sequenceLength * vocabSize + sequenceLength * vocabSize) + }) + decoderOutputs.toArray + } + + /** Gets the index with the highest score + * + * @param scores + * Array of Scores to max + * @return + * Index of the highest score + */ + private def argmax(scores: Array[Float]): Int = + scores.zipWithIndex.maxBy { case (score, _) => + score + }._2 + private def greedyGenerationFinished( + decoderIds: Seq[Array[Int]], + eosTokenId: Int, + maxOutputLength: Int): Boolean = + decoderIds.map(_.last).forall(_ == eosTokenId) || decoderIds.head.length == maxOutputLength + + private def generateGreedyOnnx( + inputIds: Array[Array[Int]], + onnxSession: (OrtSession, OrtEnvironment), + maxOutputLength: Int): (Array[Array[Int]]) = { + + val sequencesLength = inputIds.map(x => x.length).toArray + val maxSentenceLength = sequencesLength.max // - curLen + var generatedIds: Array[Array[Int]] = inputIds + while (!greedyGenerationFinished( + generatedIds, + eosTokenId, + maxOutputLength + maxSentenceLength)) { + + val (batchLogits: Array[Array[Float]]) = + Array(getDecoderOutputs(generatedIds, onnxSession).last) + + val nextTokenIds: Array[Int] = batchLogits.map(argmax) + 
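+      // Greedy step: append the argmax token id to each sequence; the surrounding while
+      // loop stops once every sequence ends with EOS or the output length budget is reached.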
generatedIds = + generatedIds.zip(nextTokenIds).map { case (currentIds: Array[Int], nextId: Int) => + currentIds ++ Array(nextId) + } + } + generatedIds + } + + private object OnnxSignatures { + val decoderInputIDs: String = "input_ids" + val decoderAttentionMask: String = "attention_mask" + val decoderPositionIDs: String = "position_ids" + + // create decoder past for 32 layers of key and value eg. past_key_values.0.key and past_key_values.0.value + val decoderPast: Array[String] = (0 until 32) + .flatMap(i => Seq(s"past_key_values.$i.key", s"past_key_values.$i.value")) + .toArray + val decoderOutput: String = "logits" + val decoderPresent: Array[String] = + (0 until 32).flatMap(i => Seq(s"present.$i.key", s"present.$i.value")).toArray + } + + private object OpenVinoSignatures { + val encoderInputIDs: String = "input_ids" + val encoderAttentionMask: String = "attention_mask" + + val encoderOutput: String = "last_hidden_state" + + val decoderInputIDs: String = "input_ids" + val decoderEncoderAttentionMask: String = "encoder_attention_mask" + val decoderAttentionMask: String = "attention_mask" + val decoderPositionIDs: String = "position_ids" + val decoderBeamIdx: String = "beam_idx" + val decoderEncoderState: String = "encoder_hidden_states" + + val decoderOutput: String = "logits" + } +} diff --git a/src/main/scala/com/johnsnowlabs/ml/openvino/OpenvinoWrapper.scala b/src/main/scala/com/johnsnowlabs/ml/openvino/OpenvinoWrapper.scala index 642143b5adfa7f..dd8b5f466a2927 100644 --- a/src/main/scala/com/johnsnowlabs/ml/openvino/OpenvinoWrapper.scala +++ b/src/main/scala/com/johnsnowlabs/ml/openvino/OpenvinoWrapper.scala @@ -201,4 +201,6 @@ object OpenvinoWrapper { encoder: OpenvinoWrapper, decoder: OpenvinoWrapper, decoderWithPast: OpenvinoWrapper) + case class DecoderWrappers(decoder: OpenvinoWrapper) + case class EncoderDecoderWithoutPastWrappers(encoder: OpenvinoWrapper, decoder: OpenvinoWrapper) } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/M2M100Transformer.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/M2M100Transformer.scala index d17ec3bdafe696..0b7bc74a3c30ef 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/M2M100Transformer.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/M2M100Transformer.scala @@ -18,14 +18,18 @@ package com.johnsnowlabs.nlp.annotators.seq2seq import com.johnsnowlabs.ml.ai.util.Generation.GenerationConfig import com.johnsnowlabs.ml.ai.M2M100 import com.johnsnowlabs.ml.onnx.OnnxWrapper.EncoderDecoderWithoutPastWrappers +import com.johnsnowlabs.ml.openvino.OpenvinoWrapper.{ + EncoderDecoderWithoutPastWrappers => OpenvinoEncoderDecoderWithoutPastWrappers +} import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel} +import com.johnsnowlabs.ml.openvino.{OpenvinoWrapper, ReadOpenvinoModel, WriteOpenvinoModel} import com.johnsnowlabs.ml.util.LoadExternalModel.{ loadJsonStringAsset, loadSentencePieceAsset, modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ONNX +import com.johnsnowlabs.ml.util.{ONNX, Openvino} import com.johnsnowlabs.nlp.AnnotatorType.DOCUMENT import com.johnsnowlabs.nlp._ import com.johnsnowlabs.ml.tensorflow.sentencepiece.{ @@ -159,6 +163,7 @@ class M2M100Transformer(override val uid: String) with HasBatchedAnnotate[M2M100Transformer] with ParamsAndFeaturesWritable with WriteOnnxModel + with WriteOpenvinoModel with HasGeneratorProperties with WriteSentencePieceModel with HasEngine { @@ -364,13 +369,15 @@ class 
M2M100Transformer(override val uid: String) /** @group setParam */ def setModelIfNotSet( spark: SparkSession, - onnxWrappers: EncoderDecoderWithoutPastWrappers, + onnxWrappers: Option[EncoderDecoderWithoutPastWrappers], + openvinoWrapper: Option[OpenvinoEncoderDecoderWithoutPastWrappers], spp: SentencePieceWrapper): this.type = { if (_model.isEmpty) { _model = Some( spark.sparkContext.broadcast( new M2M100( onnxWrappers, + openvinoWrapper, spp = spp, generationConfig = getGenerationConfig, vocab = $$(vocabulary)))) @@ -447,13 +454,32 @@ class M2M100Transformer(override val uid: String) writeOnnxModels( path, spark, - Seq((wrappers.encoder, "encoder_model.onnx")), + Seq((wrappers.get.encoder, "encoder_model.onnx")), M2M100Transformer.suffix) writeOnnxModels( path, spark, - Seq((wrappers.decoder, "decoder_model.onnx")), + Seq((wrappers.get.decoder, "decoder_model.onnx")), + M2M100Transformer.suffix) + writeSentencePieceModel( + path, + spark, + obj.spp, + M2M100Transformer.suffix, + M2M100Transformer.sppFile) + case Openvino.name => + val wrappers = getModelIfNotSet.openvinoWrapper + writeOpenvinoModels( + path, + spark, + Seq((wrappers.get.encoder, "openvino_encoder_model.xml")), M2M100Transformer.suffix) + writeOpenvinoModels( + path, + spark, + Seq((wrappers.get.decoder, "openvino_decoder_model.xml")), + M2M100Transformer.suffix) + val obj = getModelIfNotSet writeSentencePieceModel( path, spark, @@ -482,12 +508,16 @@ trait ReadablePretrainedM2M100TransformerModel super.pretrained(name, lang, remoteLoc) } -trait ReadM2M100TransformerDLModel extends ReadOnnxModel with ReadSentencePieceModel { +trait ReadM2M100TransformerDLModel + extends ReadOnnxModel + with ReadOpenvinoModel + with ReadSentencePieceModel { this: ParamsAndFeaturesReadable[M2M100Transformer] => override val onnxFile: String = "m2m100_onnx" val suffix: String = "_m2m100" override val sppFile: String = "m2m100_spp" + override val openvinoFile: String = "m2m100_openvino" def readModel(instance: M2M100Transformer, path: String, spark: SparkSession): Unit = { instance.getEngine match { @@ -501,7 +531,19 @@ trait ReadM2M100TransformerDLModel extends ReadOnnxModel with ReadSentencePieceM decoder = decoderWrappers("decoder_model.onnx"), encoder = encoderWrappers("encoder_model.onnx")) val spp = readSentencePieceModel(path, spark, "_m2m100_spp", sppFile) - instance.setModelIfNotSet(spark, onnxWrappers, spp) + instance.setModelIfNotSet(spark, Some(onnxWrappers), None, spp) + case Openvino.name => + val decoderWrappers = + readOpenvinoModels(path, spark, Seq("openvino_decoder_model.xml"), suffix) + val encoderWrappers = + readOpenvinoModels(path, spark, Seq("openvino_encoder_model.xml"), suffix) + val ovWrapper = { + OpenvinoEncoderDecoderWithoutPastWrappers( + encoder = encoderWrappers("openvino_encoder_model.xml"), + decoder = decoderWrappers("openvino_decoder_model.xml")) + } + val spp = readSentencePieceModel(path, spark, "_m2m100_spp", sppFile) + instance.setModelIfNotSet(spark, None, Some(ovWrapper), spp) case _ => throw new Exception(notSupportedEngineError) } @@ -509,10 +551,13 @@ trait ReadM2M100TransformerDLModel extends ReadOnnxModel with ReadSentencePieceM addReader(readModel) - def loadSavedModel(modelPath: String, spark: SparkSession): M2M100Transformer = { + def loadSavedModel( + modelPath: String, + spark: SparkSession, + useOpenvino: Boolean = false): M2M100Transformer = { implicit val formats: DefaultFormats.type = DefaultFormats // for json4 val (localModelPath, detectedEngine) = - modelSanityCheck(modelPath, 
isDecoder = true) + modelSanityCheck(modelPath, isEncoderDecoder = true) val modelConfig: JValue = parse(loadJsonStringAsset(localModelPath, "config.json")) @@ -547,10 +592,16 @@ trait ReadM2M100TransformerDLModel extends ReadOnnxModel with ReadSentencePieceM parse(loadJsonStringAsset(localModelPath, "vocab.json")) // convert to map val vocab = vocabulary.extract[Map[String, Int]] + + val modelEngine = + if (useOpenvino) + Openvino.name + else + detectedEngine annotatorModel.setVocabulary(vocab) - annotatorModel.set(annotatorModel.engine, detectedEngine) + annotatorModel.set(annotatorModel.engine, modelEngine) - detectedEngine match { + modelEngine match { case ONNX.name => val onnxWrapperEncoder = OnnxWrapper.read( @@ -575,7 +626,30 @@ trait ReadM2M100TransformerDLModel extends ReadOnnxModel with ReadSentencePieceM decoder = onnxWrapperDecoder) annotatorModel - .setModelIfNotSet(spark, onnxWrappers, spModel) + .setModelIfNotSet(spark, Some(onnxWrappers), None, spModel) + case Openvino.name => + val openvinoEncoderWrapper = + OpenvinoWrapper.read( + spark, + localModelPath, + zipped = false, + useBundle = true, + detectedEngine = detectedEngine, + modelName = "openvino_encoder_model") + val openvinoDecoderWrapper = + OpenvinoWrapper.read( + spark, + localModelPath, + zipped = false, + useBundle = true, + detectedEngine = detectedEngine, + modelName = "openvino_decoder_model") + val openvinoWrapper = + OpenvinoEncoderDecoderWithoutPastWrappers( + encoder = openvinoEncoderWrapper, + decoder = openvinoDecoderWrapper) + annotatorModel.setModelIfNotSet(spark, None, Some(openvinoWrapper), spModel) + case _ => throw new Exception(notSupportedEngineError) } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/Phi2Transformer.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/Phi2Transformer.scala new file mode 100644 index 00000000000000..9f7657eeeac09c --- /dev/null +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/Phi2Transformer.scala @@ -0,0 +1,472 @@ +/* + * Copyright 2017-2024 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.johnsnowlabs.nlp.annotators.seq2seq + +import com.johnsnowlabs.ml.ai.util.Generation.GenerationConfig +import com.johnsnowlabs.ml.ai.Phi2 +import com.johnsnowlabs.ml.onnx.OnnxWrapper.DecoderWrappers +import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel} +import com.johnsnowlabs.ml.openvino.{OpenvinoWrapper, ReadOpenvinoModel, WriteOpenvinoModel} +import com.johnsnowlabs.ml.util.LoadExternalModel.{ + loadJsonStringAsset, + loadSentencePieceAsset, + loadTextAsset, + modelSanityCheck, + notSupportedEngineError +} +import com.johnsnowlabs.ml.util.{ONNX, Openvino} +import com.johnsnowlabs.nlp.AnnotatorType.DOCUMENT +import com.johnsnowlabs.nlp._ +import com.johnsnowlabs.ml.tensorflow.sentencepiece.{ + ReadSentencePieceModel, + SentencePieceWrapper, + WriteSentencePieceModel +} +import com.johnsnowlabs.nlp.serialization.MapFeature +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.param._ +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.SparkSession +import com.johnsnowlabs.nlp.serialization.{MapFeature, StructFeature} +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + +/** Phi-2: Textbooks Are All You Need. + * + * Phi-2 is a Transformer with 2.7 billion parameters. It was trained using the same data sources + * as Phi-1.5, augmented with a new data source that consists of various NLP synthetic texts and + * filtered websites (for safety and educational value). When assessed against benchmarks testing + * common sense, language understanding, and logical reasoning, Phi-2 showcased a nearly + * state-of-the-art performance among models with less than 13 billion parameters. + * + * Phi-2 hasn't been fine-tuned through reinforcement learning from human feedback. The intention + * behind crafting this open-source model is to provide the research community with a + * non-restricted small model to explore vital safety challenges, such as reducing toxicity, + * understanding societal biases, enhancing controllability, and more. + * + * Pretrained models can be loaded with `pretrained` of the companion object: + * {{{ + * val Phi2 = Phi2Transformer.pretrained() + * .setInputCols("document") + * .setOutputCol("generation") + * }}} + * The default model is `"Phi2-13b"`, if no name is provided. For available pretrained models + * please see the [[https://sparknlp.org/models?q=Phi2 Models Hub]]. + * + * For extended examples of usage, see + * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/Phi2TestSpec.scala Phi2TestSpec]]. + * + * '''References:''' + * - [[https://www.microsoft.com/en-us/research/blog/phi-2-the-surprising-power-of-small-language-models/ Phi-2: Textbooks Are All You Need.]] + * - [[https://huggingface.co/microsoft/phi-2]] + * + * '''Paper Abstract:''' + * + * ''The massive increase in the size of language models to hundreds of billions of parameters + * has unlocked a host of emerging capabilities that have redefined the landscape of natural + * language processing. A question remains whether such emergent abilities can be achieved at a + * smaller scale using strategic choices for training, e.g., data selection.'' + * + * ''Our line of work with the Phi models aims to answer this question by training SLMs that + * achieve performance on par with models of much higher scale (yet still far from the frontier + * models). 
Our key insights for breaking the conventional language model scaling laws with Phi-2 + * are twofold:'' + * + * ''Firstly, training data quality plays a critical role in model performance. This has been + * known for decades, but we take this insight to its extreme by focusing on “textbook-quality” + * data, following upon our prior work “Textbooks Are All You Need.” Our training data mixture + * contains synthetic datasets specifically created to teach the model common sense reasoning and + * general knowledge, including science, daily activities, and theory of mind, among others. We + * further augment our training corpus with carefully selected web data that is filtered based on + * educational value and content quality. Secondly, we use innovative techniques to scale up, + * starting from our 1.3 billion parameter model, Phi-1.5, and embedding its knowledge within the + * 2.7 billion parameter Phi-2. This scaled knowledge transfer not only accelerates training + * convergence but shows clear boost in Phi-2 benchmark scores.'' + * + * '''Note:''' + * + * This is a very computationally expensive module especially on larger sequence. The use of an + * accelerator such as GPU is recommended. + * + * ==Example== + * {{{ + * import spark.implicits._ + * import com.johnsnowlabs.nlp.base.DocumentAssembler + * import com.johnsnowlabs.nlp.annotators.seq2seq.Phi2Transformer + * import org.apache.spark.ml.Pipeline + * + * val documentAssembler = new DocumentAssembler() + * .setInputCol("text") + * .setOutputCol("documents") + * + * val Phi2 = Phi2Transformer.pretrained("Phi2-7b") + * .setInputCols(Array("documents")) + * .setMinOutputLength(10) + * .setMaxOutputLength(50) + * .setDoSample(false) + * .setTopK(50) + * .setNoRepeatNgramSize(3) + * .setOutputCol("generation") + * + * val pipeline = new Pipeline().setStages(Array(documentAssembler, Phi2)) + * + * val data = Seq( + * "My name is Leonardo." + * ).toDF("text") + * val result = pipeline.fit(data).transform(data) + * + * results.select("generation.result").show(truncate = false) + * +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + * |result | + * +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + * |[ My name is Leonardo . I am a student of the University of California, Berkeley. I am interested in the field of Artificial Intelligence and its applications in the real world. 
I have a strong | + * | passion for learning and am always looking for ways to improve my knowledge and skills] | + * +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + * }}} + * + * @param uid + * required uid for storing annotator to disk + * @groupname anno Annotator types + * @groupdesc anno + * Required input and expected output annotator types + * @groupname Ungrouped Members + * @groupname param Parameters + * @groupname setParam Parameter setters + * @groupname getParam Parameter getters + * @groupname Ungrouped Members + * @groupprio param 1 + * @groupprio anno 2 + * @groupprio Ungrouped 3 + * @groupprio setParam 4 + * @groupprio getParam 5 + * @groupdesc param + * A list of (hyper-)parameter keys this annotator can take. Users can set and get the + * parameter values through setters and getters, respectively. + */ +class Phi2Transformer(override val uid: String) + extends AnnotatorModel[Phi2Transformer] + with HasBatchedAnnotate[Phi2Transformer] + with ParamsAndFeaturesWritable + with WriteOnnxModel + with WriteOpenvinoModel + with HasGeneratorProperties + with HasEngine { + + def this() = this(Identifiable.randomUID("Phi2TRANSFORMER")) + + /** Input annotator type : DOCUMENT + * + * @group param + */ + override val inputAnnotatorTypes: Array[AnnotatorType] = Array(DOCUMENT) + + /** Output annotator type : DOCUMENT + * + * @group param + */ + override val outputAnnotatorType: String = DOCUMENT + + /** @group setParam */ + def setRandomSeed(value: Int): Phi2Transformer.this.type = { + if (randomSeed.isEmpty) { + this.randomSeed = Some(value) + } + this + } + + /** A list of token ids which are ignored in the decoder's output (Default: `Array()`) + * + * @group param + */ + var ignoreTokenIds = new IntArrayParam( + this, + "ignoreTokenIds", + "A list of token ids which are ignored in the decoder's output") + + /** @group setParam */ + def setIgnoreTokenIds(tokenIds: Array[Int]): Phi2Transformer.this.type = { + set(ignoreTokenIds, tokenIds) + } + + /** @group getParam */ + def getIgnoreTokenIds: Array[Int] = $(ignoreTokenIds) + + /** Vocabulary used to encode the words to ids with bpeTokenizer.encode + * + * @group param + */ + val vocabulary: MapFeature[String, Int] = new MapFeature(this, "vocabulary").setProtected() + + /** @group setParam */ + def setVocabulary(value: Map[String, Int]): this.type = set(vocabulary, value) + + /** Holding merges.txt coming from RoBERTa model + * + * @group param + */ + val merges: MapFeature[(String, String), Int] = new MapFeature(this, "merges").setProtected() + + /** @group setParam */ + def setMerges(value: Map[(String, String), Int]): this.type = set(merges, value) + + private var _model: Option[Broadcast[Phi2]] = None + + val generationConfig: StructFeature[GenerationConfig] = + new StructFeature(this, "generationConfig").setProtected() + + def setGenerationConfig(value: GenerationConfig): this.type = + set(generationConfig, value) + + def getGenerationConfig: GenerationConfig = $$(generationConfig) + + /** @group setParam */ + def setModelIfNotSet( + spark: SparkSession, + onnxWrappers: Option[DecoderWrappers], + openvinoWrapper: Option[OpenvinoWrapper]): this.type = { + if (_model.isEmpty) { + _model = Some( + spark.sparkContext.broadcast( + new Phi2( + onnxWrappers, + openvinoWrapper, + $$(merges), + $$(vocabulary), + generationConfig = getGenerationConfig))) + } + this + } + + /** @group 
+  /** @group getParam */
+  def getModelIfNotSet: Phi2 = _model.get.value
+
+  setDefault(
+    minOutputLength -> 0,
+    maxOutputLength -> 20,
+    doSample -> false,
+    temperature -> 0.6,
+    topK -> 50,
+    topP -> 0.9,
+    repetitionPenalty -> 1.0,
+    noRepeatNgramSize -> 3,
+    ignoreTokenIds -> Array(),
+    batchSize -> 1,
+    beamSize -> 1,
+    maxInputLength -> 4096)
+
+  /** Takes a document and annotations and produces new annotations of this annotator's annotation
+    * type
+    *
+    * @param batchedAnnotations
+    *   Annotations that correspond to inputAnnotationCols generated by previous annotators if any
+    * @return
+    *   any number of annotations processed for every input annotation. Not necessarily a
+    *   one-to-one relationship
+    */
+  override def batchAnnotate(batchedAnnotations: Seq[Array[Annotation]]): Seq[Seq[Annotation]] = {
+
+    val allAnnotations = batchedAnnotations
+      .filter(_.nonEmpty)
+      .zipWithIndex
+      .flatMap { case (annotations, i) =>
+        annotations.filter(_.result.nonEmpty).map(x => (x, i))
+      }
+    val processedAnnotations = if (allAnnotations.nonEmpty) {
+      this.getModelIfNotSet.predict(
+        sentences = allAnnotations.map(_._1),
+        batchSize = $(batchSize),
+        minOutputLength = $(minOutputLength),
+        maxOutputLength = $(maxOutputLength),
+        doSample = $(doSample),
+        temperature = $(temperature),
+        topK = $(topK),
+        topP = $(topP),
+        repetitionPenalty = $(repetitionPenalty),
+        noRepeatNgramSize = $(noRepeatNgramSize),
+        randomSeed = this.randomSeed,
+        ignoreTokenIds = $(ignoreTokenIds),
+        beamSize = $(beamSize),
+        maxInputLength = $(maxInputLength))
+    } else {
+      Seq()
+    }
+    Seq(processedAnnotations)
+  }
+
+  override def onWrite(path: String, spark: SparkSession): Unit = {
+    super.onWrite(path, spark)
+    getEngine match {
+      case ONNX.name =>
+        val wrappers = getModelIfNotSet.onnxWrappers
+        writeOnnxModels(
+          path,
+          spark,
+          Seq((wrappers.get.decoder, "decoder_model.onnx")),
+          Phi2Transformer.suffix)
+      case Openvino.name =>
+        val wrappers = getModelIfNotSet.openvinoWrapper
+        writeOpenvinoModel(
+          path,
+          spark,
+          wrappers.get,
+          Phi2Transformer.suffix,
+          Phi2Transformer.openvinoFile)
+    }
+  }
+}
+
+trait ReadablePretrainedPhi2TransformerModel
+  extends ParamsAndFeaturesReadable[Phi2Transformer]
+    with HasPretrained[Phi2Transformer] {
+  override val defaultModelName: Some[String] = Some("Phi2-7b")
+
+  /** Java compliant-overrides */
+  override def pretrained(): Phi2Transformer = super.pretrained()
+
+  override def pretrained(name: String): Phi2Transformer = super.pretrained(name)
+
+  override def pretrained(name: String, lang: String): Phi2Transformer =
+    super.pretrained(name, lang)
+
+  override def pretrained(name: String, lang: String, remoteLoc: String): Phi2Transformer =
+    super.pretrained(name, lang, remoteLoc)
+}
+
+trait ReadPhi2TransformerDLModel extends ReadOnnxModel with ReadOpenvinoModel {
+  this: ParamsAndFeaturesReadable[Phi2Transformer] =>
+
+  override val onnxFile: String = "phi2_onnx"
+  val suffix: String = "_phi2"
+  override val openvinoFile: String = "phi2_openvino"
+
+  def readModel(instance: Phi2Transformer, path: String, spark: SparkSession): Unit = {
+    instance.getEngine match {
+      case ONNX.name =>
+        val wrappers =
+          readOnnxModels(path, spark, Seq("decoder_model.onnx"), suffix)
+        val onnxWrappers =
+          DecoderWrappers(decoder = wrappers("decoder_model.onnx"))
+        instance.setModelIfNotSet(spark, Some(onnxWrappers), None)
+      case Openvino.name =>
+        val ovWrapper =
+          readOpenvinoModel(path, spark, "_phi2_ov")
+        instance.setModelIfNotSet(spark, None, Some(ovWrapper))
+      case _ =>
+        throw new Exception(notSupportedEngineError)
+    }
+  }
+
+  addReader(readModel)
+
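+  /** Loads a locally saved model into a [[Phi2Transformer]]. Assumed export layout, based on
+    * the assets read below: `config.json`, `vocab.txt` and `merges.txt` alongside the exported
+    * decoder (`decoder_model.onnx` for ONNX, or an OpenVINO model when `useOpenvino` is set).
+    */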
+  def loadSavedModel(
+      modelPath: String,
+      spark: SparkSession,
+      useOpenvino: Boolean = false): Phi2Transformer = {
+    implicit val formats: DefaultFormats.type = DefaultFormats // needed for json4s extraction
+    val (localModelPath, detectedEngine) =
+      modelSanityCheck(modelPath, isDecoder = true)
+    val modelConfig: JValue =
+      parse(loadJsonStringAsset(localModelPath, "config.json"))
+
+    val beginSuppressTokens: Array[Int] =
+      (modelConfig \ "begin_suppress_tokens").extract[Array[Int]]
+
+    val suppressTokenIds: Array[Int] =
+      (modelConfig \ "suppress_tokens").extract[Array[Int]]
+
+    val forcedDecoderIds: Array[(Int, Int)] =
+      (modelConfig \ "forced_decoder_ids").extract[Array[Array[Int]]].map {
+        case idxWithTokenId: Array[Int] if idxWithTokenId.length == 2 =>
+          (idxWithTokenId(0), idxWithTokenId(1))
+        case _ =>
+          throw new Exception(
+            "Could not extract forced_decoder_ids. Should be a list of tuples with 2 entries.")
+      }
+
+    def arrayOrNone[T](array: Array[T]): Option[Array[T]] =
+      if (array.nonEmpty) Some(array) else None
+
+    val bosTokenId = (modelConfig \ "bos_token_id").extract[Int]
+    val eosTokenId = (modelConfig \ "eos_token_id").extract[Int]
+    // Phi-2 defines no dedicated padding token, so padding falls back to the EOS token id.
+    val padTokenId = (modelConfig \ "eos_token_id").extract[Int]
+    val vocabSize = (modelConfig \ "vocab_size").extract[Int]
+
+    val vocabs = loadTextAsset(localModelPath, "vocab.txt").zipWithIndex.toMap
+
+    val bytePairs = loadTextAsset(localModelPath, "merges.txt")
+      .map(_.split(" "))
+      .filter(w => w.length == 2)
+      .map { case Array(c1, c2) => (c1, c2) }
+      .zipWithIndex
+      .toMap
+
+    val annotatorModel = new Phi2Transformer()
+      .setGenerationConfig(
+        GenerationConfig(
+          bosTokenId,
+          padTokenId,
+          eosTokenId,
+          vocabSize,
+          arrayOrNone(beginSuppressTokens),
+          arrayOrNone(suppressTokenIds),
+          arrayOrNone(forcedDecoderIds)))
+      .setVocabulary(vocabs)
+      .setMerges(bytePairs)
+
+    val modelEngine =
+      if (useOpenvino)
+        Openvino.name
+      else
+        detectedEngine
+    annotatorModel.set(annotatorModel.engine, modelEngine)
+
+    detectedEngine match {
+      case ONNX.name =>
+        val onnxWrapperDecoder =
+          OnnxWrapper.read(
+            spark,
+            localModelPath,
+            zipped = false,
+            useBundle = true,
+            modelName = "decoder_model")
+
+        val onnxWrappers = DecoderWrappers(onnxWrapperDecoder)
+
+        annotatorModel
+          .setModelIfNotSet(spark, Some(onnxWrappers), None)
+      case Openvino.name =>
+        val openvinoWrapper =
+          OpenvinoWrapper.read(
+            spark,
+            localModelPath,
+            zipped = false,
+            useBundle = true,
+            detectedEngine = detectedEngine)
+        annotatorModel.setModelIfNotSet(spark, None, Some(openvinoWrapper))
+
+      case _ =>
+        throw new Exception(notSupportedEngineError)
+    }
+
+    annotatorModel
+  }
+
+}
+
+object Phi2Transformer
+  extends ReadablePretrainedPhi2TransformerModel
+    with ReadPhi2TransformerDLModel
diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/tokenizer/bpe/BpeSpecialTokens.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/tokenizer/bpe/BpeSpecialTokens.scala
index 691ca8522dddb5..e7a15439eb47e8 100644
--- a/src/main/scala/com/johnsnowlabs/nlp/annotators/tokenizer/bpe/BpeSpecialTokens.scala
+++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/tokenizer/bpe/BpeSpecialTokens.scala
@@ -145,6 +145,14 @@ private[johnsnowlabs] object SpecialTokens {
         unkTokenString = "<|endoftext|>",
         maskTokenString = "<|endoftext|>",
         padTokenString = "<|endoftext|>")
+      case "phi2" =>
+        SpecialTokens(
+          vocab,
+          startTokenString = "<|endoftext|>",
+          endTokenString = "<|endoftext|>",
+          unkTokenString = "<|endoftext|>",
+          maskTokenString = "<|endoftext|>",
+          padTokenString = "<|endoftext|>")
     }
   }
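Note on the `"phi2"` case above: Phi-2 ships a GPT-2-style vocabulary whose only special token is `<|endoftext|>`, so every role (start, end, unk, mask, pad) resolves to that single entry. A toy illustration (the vocab and the token id are hypothetical, for this sketch only):

{{{
// With only one special token defined, all five roles share one token id.
val vocab = Map("<|endoftext|>" -> 50256)
val roleIds = Seq("start", "end", "unk", "mask", "pad").map(_ => vocab("<|endoftext|>"))
assert(roleIds.distinct == Seq(50256))
}}}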
= "<|endoftext|>", + padTokenString = "<|endoftext|>") } } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/tokenizer/bpe/BpeTokenizer.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/tokenizer/bpe/BpeTokenizer.scala index c948661b0e039b..a75457758dc813 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/tokenizer/bpe/BpeTokenizer.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/tokenizer/bpe/BpeTokenizer.scala @@ -354,6 +354,13 @@ object BpeTokenizer { addPrefixSpaceToSentence = addPrefixSpaceToSentence) case "clip" => new CLIPTokenizer(merges, vocab, modelSpecialTokens()) + case "phi2" => + new Phi2Tokenizer( + merges, + vocab, + modelSpecialTokens(), + padWithSequenceTokens, + addPrefixSpaceToSentence = addPrefixSpaceToSentence) case _ => throw new IllegalArgumentException("Model type \"" + modelType + "\" not supported yet.") } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/tokenizer/bpe/Phi2Tokenizer.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/tokenizer/bpe/Phi2Tokenizer.scala new file mode 100644 index 00000000000000..a46f6c53e6780e --- /dev/null +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/tokenizer/bpe/Phi2Tokenizer.scala @@ -0,0 +1,31 @@ +/* + * Copyright 2017-2023 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.johnsnowlabs.nlp.annotators.tokenizer.bpe + +class Phi2Tokenizer( + merges: Map[(String, String), Int], + vocab: Map[String, Int], + specialTokens: SpecialTokens, + padWithSequenceTokens: Boolean = false, + addPrefixSpaceToSentence: Boolean = false) + extends Gpt2Tokenizer( + merges, + vocab, + specialTokens, + padWithSequenceTokens, + prependString = "Ġ", + addPrefixSpaceToSentence) diff --git a/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/Phi2TestSpec.scala b/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/Phi2TestSpec.scala new file mode 100644 index 00000000000000..55e6bdb7b394e8 --- /dev/null +++ b/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/Phi2TestSpec.scala @@ -0,0 +1,52 @@ +/* + * Copyright 2017-2023 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
diff --git a/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/Phi2TestSpec.scala b/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/Phi2TestSpec.scala
new file mode 100644
index 00000000000000..55e6bdb7b394e8
--- /dev/null
+++ b/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/Phi2TestSpec.scala
@@ -0,0 +1,52 @@
+/*
+ * Copyright 2017-2023 John Snow Labs
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.johnsnowlabs.nlp.annotators.seq2seq
+
+import com.johnsnowlabs.nlp.base.DocumentAssembler
+import com.johnsnowlabs.nlp.util.io.ResourceHelper
+import com.johnsnowlabs.tags.SlowTest
+import org.apache.spark.ml.Pipeline
+import org.scalatest.flatspec.AnyFlatSpec
+
+class Phi2TestSpec extends AnyFlatSpec {
+
+  "phi2" should "handle temperature=0 correctly and not crash when predicting more than 1 element with doSample=true" taggedAs SlowTest in {
+    // Even though the paper states temperature lies in the interval [0,1), temperature=0 causes a division by zero.
+    // doSample=true can also yield infinities, leaving distFiltered.length == 0, which throws unless 0 is returned internally.
+    val testData = ResourceHelper.spark
+      .createDataFrame(Seq((1, "My name is Leonardo.")))
+      .toDF("id", "text")
+      .repartition(1)
+    val documentAssembler = new DocumentAssembler()
+      .setInputCol("text")
+      .setOutputCol("documents")
+
+    val phi2 = Phi2Transformer
+      .pretrained()
+      .setInputCols(Array("documents"))
+      .setDoSample(false)
+      .setMaxOutputLength(50)
+      .setOutputCol("generation")
+      .setBeamSize(1)
+    new Pipeline()
+      .setStages(Array(documentAssembler, phi2))
+      .fit(testData)
+      .transform(testData)
+      .show(truncate = false)
+
+  }
+}
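A minimal end-to-end sketch of the new Scala loader, for local testing of this PR. The `"/tmp/phi2_export"` path is hypothetical; per `loadSavedModel` above, the folder must contain `config.json`, `vocab.txt`, `merges.txt` and the exported decoder (ONNX, or an OpenVINO model when `useOpenvino = true` is passed):

{{{
import com.johnsnowlabs.nlp.base.DocumentAssembler
import com.johnsnowlabs.nlp.annotators.seq2seq.Phi2Transformer
import com.johnsnowlabs.nlp.util.io.ResourceHelper
import org.apache.spark.ml.Pipeline

val spark = ResourceHelper.spark
import spark.implicits._

val documentAssembler = new DocumentAssembler()
  .setInputCol("text")
  .setOutputCol("documents")

// useOpenvino = true would force the OpenVINO engine instead of the one detected on disk.
val phi2 = Phi2Transformer
  .loadSavedModel("/tmp/phi2_export", spark, useOpenvino = false)
  .setInputCols(Array("documents"))
  .setMaxOutputLength(50)
  .setOutputCol("generation")

val data = Seq("My name is Leonardo.").toDF("text")
new Pipeline().setStages(Array(documentAssembler, phi2)).fit(data).transform(data).show()
}}}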