diff --git a/docs/en/annotator_entries/DocumentCharacterTextSplitter.md b/docs/en/annotator_entries/DocumentCharacterTextSplitter.md index 79c7934025e76c..19d15c7a6e973a 100644 --- a/docs/en/annotator_entries/DocumentCharacterTextSplitter.md +++ b/docs/en/annotator_entries/DocumentCharacterTextSplitter.md @@ -43,7 +43,7 @@ from sparknlp.annotator import * from pyspark.ml import Pipeline textDF = spark.read.text( - "/home/ducha/Workspace/scala/spark-nlp/src/test/resources/spell/sherlockholmes.txt", + "sherlockholmes.txt", wholetext=True ).toDF("text") diff --git a/docs/en/annotator_entries/DocumentTokenSplitter.md b/docs/en/annotator_entries/DocumentTokenSplitter.md new file mode 100644 index 00000000000000..97e05faf7d8225 --- /dev/null +++ b/docs/en/annotator_entries/DocumentTokenSplitter.md @@ -0,0 +1,149 @@ +{%- capture title -%} +DocumentTokenSplitter +{%- endcapture -%} + +{%- capture description -%} +Annotator that splits large documents into smaller documents based on the number of tokens in +the text. + +Currently, DocumentTokenSplitter splits the text by whitespaces to create the tokens. The +number of these tokens will then be used as a measure of the text length. In the future, other +tokenization techniques will be supported. + +For example, given 3 tokens and overlap 1: + +```python +"He was, I take it, the most perfect reasoning and observing machine that the world has seen." + +["He was, I", "I take it,", "it, the most", "most perfect reasoning", "reasoning and observing", "observing machine that", "that the world", "world has seen."] +``` + +Additionally, you can set + +- whether to trim whitespaces with setTrimWhitespace +- whether to explode the splits to individual rows with setExplodeSplits + +For extended examples of usage, see the +[DocumentTokenSplitterTest](https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/annotators/DocumentTokenSplitterTest.scala). +{%- endcapture -%} + +{%- capture input_anno -%} +DOCUMENT +{%- endcapture -%} + +{%- capture output_anno -%} +DOCUMENT +{%- endcapture -%} + +{%- capture python_example -%} +import sparknlp +from sparknlp.base import * +from sparknlp.annotator import * +from pyspark.ml import Pipeline + +textDF = spark.read.text( + "sherlockholmes.txt", + wholetext=True +).toDF("text") + +documentAssembler = DocumentAssembler().setInputCol("text") + +textSplitter = DocumentTokenSplitter() \ + .setInputCols(["document"]) \ + .setOutputCol("splits") \ + .setNumTokens(512) \ + .setTokenOverlap(10) \ + .setExplodeSplits(True) + +pipeline = Pipeline().setStages([documentAssembler, textSplitter]) + +result = pipeline.fit(textDF).transform(textDF) +result.selectExpr( + "splits.result as result", + "splits[0].begin as begin", + "splits[0].end as end", + "splits[0].end - splits[0].begin as length", + "splits[0].metadata.numTokens as tokens") \ + .show(8, truncate = 80) ++--------------------------------------------------------------------------------+-----+-----+------+------+ +| result|begin| end|length|tokens| ++--------------------------------------------------------------------------------+-----+-----+------+------+ +|[ Project Gutenberg's The Adventures of Sherlock Holmes, by Arthur Conan Doyl...| 0| 3018| 3018| 512| +|[study of crime, and occupied his\nimmense faculties and extraordinary powers...| 2950| 5707| 2757| 512| +|[but as I have changed my clothes I can't imagine how you\ndeduce it. As to M...| 5659| 8483| 2824| 512| +|[quarters received. 
Be in your chamber then at that hour, and do\nnot take it...| 8427|11241| 2814| 512| +|[a pity\nto miss it."\n\n"But your client--"\n\n"Never mind him. I may want y...|11188|13970| 2782| 512| +|[person who employs me wishes his agent to be unknown to\nyou, and I may conf...|13918|16898| 2980| 512| +|[letters back."\n\n"Precisely so. But how--"\n\n"Was there a secret marriage?...|16836|19744| 2908| 512| +|[seven hundred in\nnotes," he said.\n\nHolmes scribbled a receipt upon a shee...|19683|22551| 2868| 512| ++--------------------------------------------------------------------------------+-----+-----+------+------+ +{%- endcapture -%} + +{%- capture scala_example -%} +import com.johnsnowlabs.nlp.annotator._ +import com.johnsnowlabs.nlp.DocumentAssembler +import org.apache.spark.ml.Pipeline + +val textDF = + spark.read + .option("wholetext", "true") + .text("src/test/resources/spell/sherlockholmes.txt") + .toDF("text") + +val documentAssembler = new DocumentAssembler().setInputCol("text") +val textSplitter = new DocumentTokenSplitter() + .setInputCols("document") + .setOutputCol("splits") + .setNumTokens(512) + .setTokenOverlap(10) + .setExplodeSplits(true) + +val pipeline = new Pipeline().setStages(Array(documentAssembler, textSplitter)) +val result = pipeline.fit(textDF).transform(textDF) + +result + .selectExpr( + "splits.result as result", + "splits[0].begin as begin", + "splits[0].end as end", + "splits[0].end - splits[0].begin as length", + "splits[0].metadata.numTokens as tokens") + .show(8, truncate = 80) ++--------------------------------------------------------------------------------+-----+-----+------+------+ +| result|begin| end|length|tokens| ++--------------------------------------------------------------------------------+-----+-----+------+------+ +|[ Project Gutenberg's The Adventures of Sherlock Holmes, by Arthur Conan Doyl...| 0| 3018| 3018| 512| +|[study of crime, and occupied his\nimmense faculties and extraordinary powers...| 2950| 5707| 2757| 512| +|[but as I have changed my clothes I can't imagine how you\ndeduce it. As to M...| 5659| 8483| 2824| 512| +|[quarters received. Be in your chamber then at that hour, and do\nnot take it...| 8427|11241| 2814| 512| +|[a pity\nto miss it."\n\n"But your client--"\n\n"Never mind him. I may want y...|11188|13970| 2782| 512| +|[person who employs me wishes his agent to be unknown to\nyou, and I may conf...|13918|16898| 2980| 512| +|[letters back."\n\n"Precisely so. 
But how--"\n\n"Was there a secret marriage?...|16836|19744| 2908| 512| +|[seven hundred in\nnotes," he said.\n\nHolmes scribbled a receipt upon a shee...|19683|22551| 2868| 512| ++--------------------------------------------------------------------------------+-----+-----+------+------+ + +{%- endcapture -%} + +{%- capture api_link -%} +[DocumentTokenSplitter](/api/com/johnsnowlabs/nlp/annotators/DocumentTokenSplitter) +{%- endcapture -%} + +{%- capture python_api_link -%} +[DocumentTokenSplitter](/api/python/reference/autosummary/sparknlp/annotator/document_token_splitter/index.html#sparknlp.annotator.document_token_splitter.DocumentTokenSplitter) +{%- endcapture -%} + +{%- capture source_link -%} +[DocumentTokenSplitter](https://github.com/JohnSnowLabs/spark-nlp/tree/master/src/main/scala/com/johnsnowlabs/nlp/annotators/DocumentTokenSplitter.scala) +{%- endcapture -%} + +{% include templates/anno_template.md +title=title +description=description +input_anno=input_anno +output_anno=output_anno +python_example=python_example +scala_example=scala_example +api_link=api_link +python_api_link=python_api_link +source_link=source_link +%} \ No newline at end of file diff --git a/docs/en/annotators.md b/docs/en/annotators.md index 9399c0ccdf9ae0..802044bc181df9 100644 --- a/docs/en/annotators.md +++ b/docs/en/annotators.md @@ -61,6 +61,7 @@ There are two types of Annotators: {% include templates/anno_table_entry.md path="" name="DocumentCharacterTextSplitter" summary="Annotator which splits large documents into chunks of roughly given size."%} {% include templates/anno_table_entry.md path="" name="DocumentNormalizer" summary="Annotator which normalizes raw text from tagged text, e.g. scraped web pages or xml documents, from document type columns into Sentence."%} {% include templates/anno_table_entry.md path="" name="DocumentSimilarityRanker" summary="Annotator that uses LSH techniques present in Spark ML lib to execute approximate nearest neighbors search on top of sentence embeddings."%} +{% include templates/anno_table_entry.md path="" name="DocumentTokenSplitter" summary="Annotator that splits large documents into smaller documents based on the number of tokens in the text."%} {% include templates/anno_table_entry.md path="" name="EntityRuler" summary="Fits an Annotator to match exact strings or regex patterns provided in a file against a Document and assigns them an named entity."%} {% include templates/anno_table_entry.md path="" name="EmbeddingsFinisher" summary="Extracts embeddings from Annotations into a more easily usable form."%} {% include templates/anno_table_entry.md path="" name="Finisher" summary="Converts annotation results into a format that easier to use. 
It is useful to extract the results from Spark NLP Pipelines."%} diff --git a/python/sparknlp/annotator/__init__.py b/python/sparknlp/annotator/__init__.py index 4861acbbf41981..61ace12413b652 100755 --- a/python/sparknlp/annotator/__init__.py +++ b/python/sparknlp/annotator/__init__.py @@ -48,6 +48,7 @@ from sparknlp.annotator.openai import * from sparknlp.annotator.token2_chunk import * from sparknlp.annotator.document_character_text_splitter import * +from sparknlp.annotator.document_token_splitter import * if sys.version_info[0] == 2: raise ImportError( diff --git a/python/sparknlp/annotator/document_character_text_splitter.py b/python/sparknlp/annotator/document_character_text_splitter.py index 8da514ea24df1b..234a11ffaa6eec 100644 --- a/python/sparknlp/annotator/document_character_text_splitter.py +++ b/python/sparknlp/annotator/document_character_text_splitter.py @@ -72,7 +72,7 @@ class DocumentCharacterTextSplitter(AnnotatorModel): >>> from sparknlp.annotator import * >>> from pyspark.ml import Pipeline >>> textDF = spark.read.text( - ... "/home/ducha/Workspace/scala/spark-nlp/src/test/resources/spell/sherlockholmes.txt", + ... "sherlockholmes.txt", ... wholetext=True ... ).toDF("text") >>> documentAssembler = DocumentAssembler().setInputCol("text") diff --git a/python/sparknlp/annotator/document_token_splitter.py b/python/sparknlp/annotator/document_token_splitter.py new file mode 100644 index 00000000000000..802e41ac4cfb7a --- /dev/null +++ b/python/sparknlp/annotator/document_token_splitter.py @@ -0,0 +1,175 @@ +# Copyright 2017-2022 John Snow Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains classes for the DocumentTokenSplitter""" +from sparknlp.common import * + + +class DocumentTokenSplitter(AnnotatorModel): + """Annotator that splits large documents into smaller documents based on the number of tokens in + the text. + + Currently, DocumentTokenSplitter splits the text by whitespaces to create the tokens. The + number of these tokens will then be used as a measure of the text length. In the future, other + tokenization techniques will be supported. + + For example, given 3 tokens and overlap 1: + + .. code-block:: python + + He was, I take it, the most perfect reasoning and observing machine that the world has seen. + + ["He was, I", "I take it,", "it, the most", "most perfect reasoning", "reasoning and observing", "observing machine that", "that the world", "world has seen."] + + + Additionally, you can set + + - whether to trim whitespaces with setTrimWhitespace + - whether to explode the splits to individual rows with setExplodeSplits + + For extended examples of usage, see the + `DocumentTokenSplitterTest <https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/annotators/DocumentTokenSplitterTest.scala>`__. 
+ + ====================== ====================== + Input Annotation types Output Annotation type + ====================== ====================== + ``DOCUMENT`` ``DOCUMENT`` + ====================== ====================== + + Parameters + ---------- + + numTokens + Limit of the number of tokens in a text + tokenOverlap + Length of the token overlap between text chunks, by default `0`. + explodeSplits + Whether to explode split chunks to separate rows, by default `False`. + trimWhitespace + Whether to trim whitespaces of extracted chunks, by default `True`. + + Examples + -------- + >>> import sparknlp + >>> from sparknlp.base import * + >>> from sparknlp.annotator import * + >>> from pyspark.ml import Pipeline + >>> textDF = spark.read.text( + ... "sherlockholmes.txt", + ... wholetext=True + ... ).toDF("text") + >>> documentAssembler = DocumentAssembler().setInputCol("text") + >>> textSplitter = DocumentTokenSplitter() \\ + ... .setInputCols(["document"]) \\ + ... .setOutputCol("splits") \\ + ... .setNumTokens(512) \\ + ... .setTokenOverlap(10) \\ + ... .setExplodeSplits(True) + >>> pipeline = Pipeline().setStages([documentAssembler, textSplitter]) + >>> result = pipeline.fit(textDF).transform(textDF) + >>> result.selectExpr( + ... "splits.result as result", + ... "splits[0].begin as begin", + ... "splits[0].end as end", + ... "splits[0].end - splits[0].begin as length", + ... "splits[0].metadata.numTokens as tokens") \\ + ... .show(8, truncate = 80) + +--------------------------------------------------------------------------------+-----+-----+------+------+ + | result|begin| end|length|tokens| + +--------------------------------------------------------------------------------+-----+-----+------+------+ + |[ Project Gutenberg's The Adventures of Sherlock Holmes, by Arthur Conan Doyl...| 0| 3018| 3018| 512| + |[study of crime, and occupied his\nimmense faculties and extraordinary powers...| 2950| 5707| 2757| 512| + |[but as I have changed my clothes I can't imagine how you\ndeduce it. As to M...| 5659| 8483| 2824| 512| + |[quarters received. Be in your chamber then at that hour, and do\nnot take it...| 8427|11241| 2814| 512| + |[a pity\nto miss it."\n\n"But your client--"\n\n"Never mind him. I may want y...|11188|13970| 2782| 512| + |[person who employs me wishes his agent to be unknown to\nyou, and I may conf...|13918|16898| 2980| 512| + |[letters back."\n\n"Precisely so. 
But how--"\n\n"Was there a secret marriage?...|16836|19744| 2908| 512| + |[seven hundred in\nnotes," he said.\n\nHolmes scribbled a receipt upon a shee...|19683|22551| 2868| 512| + +--------------------------------------------------------------------------------+-----+-----+------+------+ + + """ + + inputAnnotatorTypes = [AnnotatorType.DOCUMENT] + + outputAnnotatorType = AnnotatorType.DOCUMENT + + numTokens = Param(Params._dummy(), + "numTokens", + "Limit of the number of tokens in a text", + typeConverter=TypeConverters.toInt) + tokenOverlap = Param(Params._dummy(), + "tokenOverlap", + "Length of the token overlap between text chunks", + typeConverter=TypeConverters.toInt) + explodeSplits = Param(Params._dummy(), + "explodeSplits", + "Whether to explode split chunks to separate rows", + typeConverter=TypeConverters.toBoolean) + trimWhitespace = Param(Params._dummy(), + "trimWhitespace", + "Whether to trim whitespaces of extracted chunks", + typeConverter=TypeConverters.toBoolean) + + @keyword_only + def __init__(self): + super(DocumentTokenSplitter, self).__init__( + classname="com.johnsnowlabs.nlp.annotators.DocumentTokenSplitter") + self._setDefault( + tokenOverlap=0, + explodeSplits=False, + trimWhitespace=True + ) + + def setNumTokens(self, value): + """Sets the limit of the number of tokens in a text + + Parameters + ---------- + value : int + Number of tokens in a text + """ + if value < 1: + raise ValueError("Number of tokens should be larger than 0.") + return self._set(numTokens=value) + + def setTokenOverlap(self, value): + """Length of the token overlap between text chunks, by default `0`. + + Parameters + ---------- + value : int + Length of the token overlap between text chunks + """ + if value > self.getOrDefault(self.numTokens): + raise ValueError("Token overlap can't be larger than number of tokens.") + return self._set(tokenOverlap=value) + + def setExplodeSplits(self, value): + """Sets whether to explode split chunks to separate rows, by default `False`. + + Parameters + ---------- + value : bool + Whether to explode split chunks to separate rows + """ + return self._set(explodeSplits=value) + + def setTrimWhitespace(self, value): + """Sets whether to trim whitespaces of extracted chunks, by default `True`. + + Parameters + ---------- + value : bool + Whether to trim whitespaces of extracted chunks + """ + return self._set(trimWhitespace=value) diff --git a/python/sparknlp/annotator/document_token_splitter_test.py b/python/sparknlp/annotator/document_token_splitter_test.py new file mode 100644 index 00000000000000..e3ee341375ee9a --- /dev/null +++ b/python/sparknlp/annotator/document_token_splitter_test.py @@ -0,0 +1,85 @@ +# Copyright 2017-2022 John Snow Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import unittest + +import pytest + +from sparknlp.annotator import * +from sparknlp.base import * +from test.util import SparkSessionForTest + + +@pytest.mark.fast +class DocumentTokenSplitterTestSpec(unittest.TestCase): + def setUp(self): + self.data = SparkSessionForTest.spark.createDataFrame( + [ + [ + ( + "All emotions, and that\none particularly, were abhorrent to his cold, precise" + " but\nadmirably balanced mind.\n\nHe was, I take it, the most perfect\nreasoning" + " and observing machine that the world has seen." + ) + ] + ] + ).toDF("text") + + def test_run(self): + df = self.data + + document_assembler = ( + DocumentAssembler().setInputCol("text").setOutputCol("document") + ) + + document_token_splitter = ( + DocumentTokenSplitter() + .setInputCols("document") + .setOutputCol("splits") + .setNumTokens(3) + .setTokenOverlap(1) + .setExplodeSplits(True) + .setTrimWhitespace(True) + ) + + pipeline = Pipeline().setStages([document_assembler, document_token_splitter]) + + pipeline_df = pipeline.fit(df).transform(df) + + results = pipeline_df.select("splits").collect() + + splits = [ + row["splits"][0].result.replace("\n\n", " ").replace("\n", " ") + for row in results + ] + + expected = [ + "All emotions, and", + "and that one", + "one particularly, were", + "were abhorrent to", + "to his cold,", + "cold, precise but", + "but admirably balanced", + "balanced mind. He", + "He was, I", + "I take it,", + "it, the most", + "most perfect reasoning", + "reasoning and observing", + "observing machine that", + "that the world", + "world has seen.", + ] + + assert splits == expected diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotator.scala b/src/main/scala/com/johnsnowlabs/nlp/annotator.scala index 4f2e65372dca1e..c67934c6466980 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotator.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotator.scala @@ -739,4 +739,9 @@ package object annotator { com.johnsnowlabs.nlp.annotators.DocumentCharacterTextSplitter object DocumentCharacterTextSplitter extends ParamsAndFeaturesReadable[DocumentCharacterTextSplitter] + + type DocumentTokenSplitter = + com.johnsnowlabs.nlp.annotators.DocumentTokenSplitter + + object DocumentTokenSplitter extends ParamsAndFeaturesReadable[DocumentTokenSplitter] } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/DocumentCharacterTextSplitter.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/DocumentCharacterTextSplitter.scala index 85fcba7a18cf3e..ca7f5c15d7705f 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/DocumentCharacterTextSplitter.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/DocumentCharacterTextSplitter.scala @@ -1,11 +1,11 @@ package com.johnsnowlabs.nlp.annotators +import com.johnsnowlabs.nlp.functions.ExplodeAnnotations import com.johnsnowlabs.nlp.{Annotation, AnnotatorModel, AnnotatorType, HasSimpleAnnotate} import org.apache.spark.ml.param.{BooleanParam, IntParam, StringArrayParam} import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.DataFrame -import scala.collection.mutable import scala.util.matching.Regex /** Annotator which splits large documents into chunks of roughly given size. 
@@ -186,7 +186,7 @@ class DocumentCharacterTextSplitter(override val uid: String) /** @group getParam */ def getKeepSeparators: Boolean = $(keepSeparators) - /** Whether to explode split chunks to separate rows + /** Whether to explode split chunks to separate rows (Default: `false`) * * @group param */ @@ -199,6 +199,9 @@ class DocumentCharacterTextSplitter(override val uid: String) /** @group getParam */ def getExplodeSplits: Boolean = $(explodeSplits) + /** Whether to trim whitespaces of extracted chunks (Default: `true`) + * @group param + */ val trimWhitespace: BooleanParam = new BooleanParam(this, "trimWhitespace", "Whether to trim whitespaces of extracted chunks") @@ -216,147 +219,7 @@ class DocumentCharacterTextSplitter(override val uid: String) splitPatterns -> Array("\n\n", "\n", " ", ""), trimWhitespace -> true) - private def joinDocs(currentDoc: Seq[String], separator: String): String = { - val joined = String.join(separator, currentDoc: _*) - - if (getTrimWhitespace) joined.trim else joined - } - - /** Splits the given text with the separator. - * - * The separator is assumed to be regex (which was optionally escaped). - * - * @param text - * Text to split - * @param separator - * Regex as String - * @return - */ - private def splitTextWithRegex(text: String, separator: String): Seq[String] = { - val splits: Seq[String] = if (separator.nonEmpty) { - val pattern = if (getKeepSeparators) f"(?=$separator)" else separator - text.split(pattern) - } else Seq(text) - - splits.filter(_.nonEmpty) - } - - /** Combines smaller text chunks into one that has about the size of chunk size. - * - * @param splits - * Splits from the previous separator - * @param separator - * The current separator - * @return - */ - private def mergeSplits(splits: Seq[String], separator: String): Seq[String] = { - val separatorLen = separator.length - - var docs: mutable.Seq[String] = mutable.Seq() - var currentDoc: mutable.Seq[String] = mutable.Seq() - var total: Int = 0 - - splits.foreach { d => - val len = d.length - - def separatorLenNonEmpty = if (currentDoc.nonEmpty) separatorLen else 0 - def separatorLenActualText = - if (currentDoc.length > 1) separatorLen - else 0 - - if (total + len + separatorLenNonEmpty > getChunkSize) { - if (currentDoc.nonEmpty) { - val doc = joinDocs(currentDoc, separator) - if (doc.nonEmpty) { - docs = docs :+ doc - } - - def mergeLargerThanChunkSize = - total + len + separatorLenNonEmpty > getChunkSize && total > 0 - - while (total > getChunkOverlap || mergeLargerThanChunkSize) { - total -= currentDoc.head.length + separatorLenActualText - currentDoc = currentDoc.drop(1) - } - } - } - - currentDoc = currentDoc :+ d - total += len + separatorLenActualText - } - - val doc = joinDocs(currentDoc, separator) - if (doc.nonEmpty) { - docs = docs :+ doc - } - - docs - } - - // noinspection RegExpRedundantEscape - private def escapeRegexIfNeeded(text: String) = - if (getPatternsAreRegex) text - else text.replaceAll("([\\\\\\.\\[\\{\\(\\*\\+\\?\\^\\$\\|])", "\\\\$1") - - /** Splits a text into chunks of roughly given chunk size. The separators are given in a list - * and will be used in order. - * - * Inspired by LangChain's RecursiveCharacterTextSplitter. 
- * - * @param text - * Text to split - * @param separators - * List of separators in decreasing priority - * @return - */ - private def splitText(text: String, separators: Seq[String]): Seq[String] = { - // Get appropriate separator to use - - val (separator: String, nextSeparators: Seq[String]) = separators - .map(escapeRegexIfNeeded) - .zipWithIndex - .collectFirst { - case (sep, _) if sep.length == 4 => - (sep, Seq.empty) - case (sep, i) if sep.r.findFirstIn(text).isDefined => - (sep, separators.drop(i + 1)) - } - .getOrElse(("", Seq.empty)) - - val splits = splitTextWithRegex(text, separator) - - // Now go merging things, recursively splitting longer texts. - var finalChunks: mutable.Seq[String] = mutable.Seq() - var goodSplits: mutable.Seq[String] = mutable.Seq.empty - val separatorStr = if (getKeepSeparators) "" else separator - - splits.foreach { s => - if (s.length < getChunkSize) { - goodSplits = goodSplits :+ s - } else { - if (goodSplits.nonEmpty) { - val mergedText = mergeSplits(goodSplits, separatorStr) - finalChunks = finalChunks ++ mergedText - goodSplits = mutable.Seq.empty - } - if (nextSeparators.isEmpty) { - finalChunks = finalChunks :+ s - } else { - val recursiveChunks = splitText(s, nextSeparators) - finalChunks = finalChunks ++ recursiveChunks - } - } - } - - if (goodSplits.nonEmpty) { - val mergedText = mergeSplits(goodSplits, separatorStr) - finalChunks = finalChunks ++ mergedText - } - - finalChunks - } - - /** takes a document and annotations and produces new annotations of this annotator's annotation + /** Takes a document and annotations and produces new annotations of this annotator's annotation * type * * @param annotations @@ -366,24 +229,33 @@ class DocumentCharacterTextSplitter(override val uid: String) * relationship */ override def annotate(annotations: Seq[Annotation]): Seq[Annotation] = { + val textSplitter = + new TextSplitter( + getChunkSize, + getChunkOverlap, + getKeepSeparators, + getPatternsAreRegex, + getTrimWhitespace) + annotations.zipWithIndex .flatMap { case (annotation, i) => val text = annotation.result - val textChunks = splitText(text, getSplitPatterns) + val textChunks = textSplitter.splitText(text, getSplitPatterns) textChunks.zipWithIndex.map { case (textChunk, index) => - val textChunkIndex = Regex.quote(textChunk).r.findFirstMatchIn(text) match { + val textChunkBegin = Regex.quote(textChunk).r.findFirstMatchIn(text) match { case Some(m) => m.start case None => -1 } + val textChunkEnd = if (textChunkBegin >= 0) textChunkBegin + textChunk.length else -1 ( i, new Annotation( AnnotatorType.DOCUMENT, - textChunkIndex, - textChunkIndex + textChunk.length, + textChunkBegin, + textChunkEnd, textChunk, annotation.metadata ++ Map("document" -> index.toString), annotation.embeddings)) @@ -394,27 +266,7 @@ class DocumentCharacterTextSplitter(override val uid: String) } override protected def afterAnnotate(dataset: DataFrame): DataFrame = { - explodeAnnotations(dataset) + if (getExplodeSplits) dataset.explodeAnnotationsCol(getOutputCol, getOutputCol) else dataset } - /** Explodes the text chunks into separate rows if set - * - * @param dataset - * Processed text chunks - * @return - * Dataset with each chunk on a separate row - */ - private def explodeAnnotations(dataset: DataFrame): DataFrame = { - import org.apache.spark.sql.functions.{array, col, explode} - if (getExplodeSplits) { - dataset - .select(dataset.columns.filterNot(_ == getOutputCol).map(col) :+ explode( - col(getOutputCol)).as("_tmp"): _*) - .withColumn( - getOutputCol, - 
array(col("_tmp")) - .as(getOutputCol, dataset.schema.fields.find(_.name == getOutputCol).get.metadata)) - .drop("_tmp") - } else dataset - } } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/DocumentTokenSplitter.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/DocumentTokenSplitter.scala new file mode 100644 index 00000000000000..1acfd42d710bca --- /dev/null +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/DocumentTokenSplitter.scala @@ -0,0 +1,225 @@ +package com.johnsnowlabs.nlp.annotators + +import com.johnsnowlabs.nlp.{Annotation, AnnotatorModel, AnnotatorType, HasSimpleAnnotate} +import org.apache.spark.ml.param.{BooleanParam, IntParam} +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.DataFrame + +import scala.util.matching.Regex +import com.johnsnowlabs.nlp.functions.ExplodeAnnotations + +/** Annotator that splits large documents into smaller documents based on the number of tokens in + * the text. + * + * Currently, DocumentTokenSplitter splits the text by whitespaces to create the tokens. The + * number of these tokens will then be used as a measure of the text length. In the future, other + * tokenization techniques will be supported. + * + * For example, given 3 tokens and overlap 1: + * {{{ + * He was, I take it, the most perfect reasoning and observing machine that the world has seen. + * + * ["He was, I", "I take it,", "it, the most", "most perfect reasoning", "reasoning and observing", "observing machine that", "that the world", "world has seen."] + * }}} + * + * Additionally, you can set + * + * - whether to trim whitespaces with [[setTrimWhitespace]] + * - whether to explode the splits to individual rows with [[setExplodeSplits]] + * + * For extended examples of usage, see the + * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/annotators/DocumentTokenSplitterTest.scala DocumentTokenSplitterTest]]. + * + * ==Example== + * {{{ + * import com.johnsnowlabs.nlp.annotator._ + * import com.johnsnowlabs.nlp.DocumentAssembler + * import org.apache.spark.ml.Pipeline + * + * val textDF = + * spark.read + * .option("wholetext", "true") + * .text("src/test/resources/spell/sherlockholmes.txt") + * .toDF("text") + * + * val documentAssembler = new DocumentAssembler().setInputCol("text") + * val textSplitter = new DocumentTokenSplitter() + * .setInputCols("document") + * .setOutputCol("splits") + * .setNumTokens(512) + * .setTokenOverlap(10) + * .setExplodeSplits(true) + * + * val pipeline = new Pipeline().setStages(Array(documentAssembler, textSplitter)) + * val result = pipeline.fit(textDF).transform(textDF) + * + * result + * .selectExpr( + * "splits.result as result", + * "splits[0].begin as begin", + * "splits[0].end as end", + * "splits[0].end - splits[0].begin as length", + * "splits[0].metadata.numTokens as tokens") + * .show(8, truncate = 80) + * +--------------------------------------------------------------------------------+-----+-----+------+------+ + * | result|begin| end|length|tokens| + * +--------------------------------------------------------------------------------+-----+-----+------+------+ + * |[ Project Gutenberg's The Adventures of Sherlock Holmes, by Arthur Conan Doyl...| 0| 3018| 3018| 512| + * |[study of crime, and occupied his\nimmense faculties and extraordinary powers...| 2950| 5707| 2757| 512| + * |[but as I have changed my clothes I can't imagine how you\ndeduce it. As to M...| 5659| 8483| 2824| 512| + * |[quarters received. 
Be in your chamber then at that hour, and do\nnot take it...| 8427|11241| 2814| 512| + * |[a pity\nto miss it."\n\n"But your client--"\n\n"Never mind him. I may want y...|11188|13970| 2782| 512| + * |[person who employs me wishes his agent to be unknown to\nyou, and I may conf...|13918|16898| 2980| 512| + * |[letters back."\n\n"Precisely so. But how--"\n\n"Was there a secret marriage?...|16836|19744| 2908| 512| + * |[seven hundred in\nnotes," he said.\n\nHolmes scribbled a receipt upon a shee...|19683|22551| 2868| 512| + * +--------------------------------------------------------------------------------+-----+-----+------+------+ + * }}} + * + * @param uid + * required uid for storing annotator to disk + * @groupname anno Annotator types + * @groupdesc anno + * Required input and expected output annotator types + * @groupname Ungrouped Members + * @groupname param Parameters + * @groupname setParam Parameter setters + * @groupname getParam Parameter getters + * @groupname Ungrouped Members + * @groupprio param 1 + * @groupprio anno 2 + * @groupprio Ungrouped 3 + * @groupprio setParam 4 + * @groupprio getParam 5 + * @groupdesc param + * A list of (hyper-)parameter keys this annotator can take. Users can set and get the + * parameter values through setters and getters, respectively. + */ +class DocumentTokenSplitter(override val uid: String) + extends AnnotatorModel[DocumentTokenSplitter] + with HasSimpleAnnotate[DocumentTokenSplitter] { + + def this() = this(Identifiable.randomUID("DocumentTokenSplitter")) + + override val inputAnnotatorTypes: Array[AnnotatorType] = Array(AnnotatorType.DOCUMENT) + override val outputAnnotatorType: AnnotatorType = AnnotatorType.DOCUMENT + + /** Limit of the number of tokens in a text + * + * @group param + */ + val numTokens: IntParam = + new IntParam(this, "numTokens", "Limit of the number of tokens in a text") + + /** @group setParam */ + def setNumTokens(value: Int): this.type = { + require(value > 0, "Number of tokens should be larger than 0.") + set(numTokens, value) + } + + /** @group getParam */ + def getNumTokens: Int = $(numTokens) + + /** Length of the token overlap between text chunks (Default: `0`) + * + * @group param + */ + val tokenOverlap: IntParam = + new IntParam(this, "tokenOverlap", "Length of the token overlap between text chunks") + + /** @group setParam */ + def setTokenOverlap(value: Int): this.type = { + require(value <= getNumTokens, "Token overlap can't be larger than number of tokens.") + set(tokenOverlap, value) + } + + /** @group getParam */ + def getTokenOverlap: Int = $(tokenOverlap) + + /** Whether to explode split chunks to separate rows (Default: `false`) + * + * @group param + */ + val explodeSplits: BooleanParam = + new BooleanParam(this, "explodeSplits", "Whether to explode split chunks to separate rows") + + /** @group setParam */ + def setExplodeSplits(value: Boolean): this.type = set(explodeSplits, value) + + /** @group getParam */ + def getExplodeSplits: Boolean = $(explodeSplits) + + /** Whether to trim whitespaces of extracted chunks (Default: `true`) + * + * @group param + */ + val trimWhitespace: BooleanParam = + new BooleanParam(this, "trimWhitespace", "Whether to trim whitespaces of extracted chunks") + + /** @group setParam */ + def setTrimWhitespace(value: Boolean): this.type = set(trimWhitespace, value) + + /** @group getParam */ + def getTrimWhitespace: Boolean = $(trimWhitespace) + + setDefault(tokenOverlap -> 0, explodeSplits -> false, trimWhitespace -> true) + + // To be replaced by the desired tokenizer in the future + private val 
tokenSplitPattern = "\\s+".r + + def lengthFromTokens(text: String): Int = + tokenSplitPattern.split(text).count(_.nonEmpty) + + /** Takes a Document and produces document splits based on a Tokenizers + * + * @param annotations + * Annotations that correspond to inputAnnotationCols generated by previous annotators if any + * @return + * any number of annotations processed for every input annotation. Not necessary one to one + * relationship + */ + override def annotate(annotations: Seq[Annotation]): Seq[Annotation] = { + val textSplitter = + new TextSplitter( + chunkSize = getNumTokens, + chunkOverlap = getTokenOverlap, + keepSeparators = true, + patternsAreRegex = true, + trimWhitespace = getTrimWhitespace, + lengthFunction = lengthFromTokens) + + val documentSplitPatterns = Array("\\s+") + + annotations.zipWithIndex + .flatMap { case (annotation, i) => + val text = annotation.result + + val textChunks = textSplitter.splitText(text, documentSplitPatterns) + + textChunks.zipWithIndex.map { case (textChunk, index) => + val textChunkBegin = Regex.quote(textChunk).r.findFirstMatchIn(text) match { + case Some(m) => m.start + case None => -1 + } + val textChunkEnd = if (textChunkBegin >= 0) textChunkBegin + textChunk.length else -1 + + ( + i, + new Annotation( + AnnotatorType.DOCUMENT, + textChunkBegin, + textChunkEnd, + textChunk, + annotation.metadata ++ Map( + "document" -> index.toString, + "numTokens" -> lengthFromTokens(textChunk).toString), + annotation.embeddings)) + } + } + .sortBy(_._1) + .map(_._2) + } + + override protected def afterAnnotate(dataset: DataFrame): DataFrame = { + if (getExplodeSplits) dataset.explodeAnnotationsCol(getOutputCol, getOutputCol) else dataset + } +} diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/TextSplitter.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/TextSplitter.scala new file mode 100644 index 00000000000000..a389ab0a50538f --- /dev/null +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/TextSplitter.scala @@ -0,0 +1,167 @@ +package com.johnsnowlabs.nlp.annotators +import scala.collection.mutable + +/** Splits texts recursively to match given length + * + * @param chunkSize + * Length of the text chunks, measured by `lengthFunction` + * @param chunkOverlap + * Overlap of the text chunks + * @param keepSeparators + * Whether to keep separators in the final chunks + * @param patternsAreRegex + * Whether to interpret split patterns as regex + * @param trimWhitespace + * Whether to trim the whitespace from the final chunks + * @param lengthFunction + * Function to measure chunk length + */ +class TextSplitter( + chunkSize: Int, + chunkOverlap: Int, + keepSeparators: Boolean, + patternsAreRegex: Boolean, + trimWhitespace: Boolean, + lengthFunction: String => Int = _.length) { + + def joinDocs(currentDoc: Seq[String], separator: String): String = { + val joinSeparator = if (patternsAreRegex && !keepSeparators) "" else separator + val joined = String.join(joinSeparator, currentDoc: _*) + + if (trimWhitespace) joined.trim else joined + } + + /** Splits the given text with the separator. + * + * The separator is assumed to be regex (which was optionally escaped). 
+ * + * @param text + * Text to split + * @param separator + * Regex as String + * @return + */ + def splitTextWithRegex(text: String, separator: String): Seq[String] = { + val splits: Seq[String] = if (separator.nonEmpty) { + val pattern = if (keepSeparators) f"(?=$separator)" else separator + text.split(pattern) + } else Seq(text) + + splits.filter(_.nonEmpty) + } + + /** Combines smaller text chunks into one that has about the size of chunk size. + * + * @param splits + * Splits from the previous separator + * @param separator + * The current separator + * @return + */ + def mergeSplits(splits: Seq[String], separator: String): Seq[String] = { + val separatorLen = lengthFunction(separator) + + var docs: mutable.Seq[String] = mutable.Seq() + var currentDoc: mutable.Seq[String] = mutable.Seq() + var total: Int = 0 + + splits.foreach { d => + val len = lengthFunction(d) + + def separatorLenNonEmpty = if (currentDoc.nonEmpty) separatorLen else 0 + + def separatorLenActualText = + if (currentDoc.length > 1) separatorLen + else 0 + + if (total + len + separatorLenNonEmpty > chunkSize) { + if (currentDoc.nonEmpty) { + val doc = joinDocs(currentDoc, separator) + if (doc.nonEmpty) { + docs = docs :+ doc + } + + def mergeLargerThanChunkSize = + total + len + separatorLenNonEmpty > chunkSize && total > 0 + + while (total > chunkOverlap || mergeLargerThanChunkSize) { + total -= lengthFunction(currentDoc.head) + separatorLenActualText + currentDoc = currentDoc.drop(1) + } + } + } + + currentDoc = currentDoc :+ d + total += len + separatorLenActualText + } + + val doc = joinDocs(currentDoc, separator) + if (doc.nonEmpty) { + docs = docs :+ doc + } + + docs + } + + // noinspection RegExpRedundantEscape + def escapeRegexIfNeeded(text: String) = + if (patternsAreRegex) text + else text.replaceAll("([\\\\\\.\\[\\{\\(\\*\\+\\?\\^\\$\\|])", "\\\\$1") + + /** Splits a text into chunks of roughly given chunk size. The separators are given in a list + * and will be used in order. + * + * Inspired by LangChain's RecursiveCharacterTextSplitter. + * + * @param text + * Text to split + * @param separators + * List of separators in decreasing priority + * @return + */ + def splitText(text: String, separators: Seq[String]): Seq[String] = { + // Get appropriate separator to use + val (separator: String, nextSeparators: Seq[String]) = separators + .map(escapeRegexIfNeeded) + .zipWithIndex + .collectFirst { + case (sep, _) if sep.isEmpty => + (sep, Seq.empty) + case (sep, i) if sep.r.findFirstIn(text).isDefined => + (sep, separators.drop(i + 1)) + } + .getOrElse(("", Seq.empty)) + + val splits = splitTextWithRegex(text, separator) + + // Now go merging things, recursively splitting longer texts. 
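+    // Splits shorter than chunkSize are buffered in goodSplits and merged up to the size limit;
+    // splits that are still too long are re-split recursively with the remaining, lower-priority separators.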
+ var finalChunks: mutable.Seq[String] = mutable.Seq() + var goodSplits: mutable.Seq[String] = mutable.Seq.empty + val separatorStr = if (keepSeparators) "" else separator + + splits.foreach { s => + if (lengthFunction(s) < chunkSize) { + goodSplits = goodSplits :+ s + } else { + if (goodSplits.nonEmpty) { + val mergedText = mergeSplits(goodSplits, separatorStr) + finalChunks = finalChunks ++ mergedText + goodSplits = mutable.Seq.empty + } + if (nextSeparators.isEmpty) { + finalChunks = finalChunks :+ s + } else { + val recursiveChunks = splitText(s, nextSeparators) + finalChunks = finalChunks ++ recursiveChunks + } + } + } + + if (goodSplits.nonEmpty) { + val mergedText = mergeSplits(goodSplits, separatorStr) + finalChunks = finalChunks ++ mergedText + } + + finalChunks + } +} diff --git a/src/test/scala/com/johnsnowlabs/nlp/annotators/DocumentCharacterTextSplitterTest.scala b/src/test/scala/com/johnsnowlabs/nlp/annotators/DocumentCharacterTextSplitterTest.scala index b245fbc863da92..e8179829b63e85 100644 --- a/src/test/scala/com/johnsnowlabs/nlp/annotators/DocumentCharacterTextSplitterTest.scala +++ b/src/test/scala/com/johnsnowlabs/nlp/annotators/DocumentCharacterTextSplitterTest.scala @@ -24,6 +24,7 @@ class DocumentCharacterTextSplitterTest extends AnyFlatSpec { val textDocument: DataFrame = documentAssembler.transform(splitTextDF) def assertResult(text: String, result: Array[Annotation], expected: Seq[String]): Unit = { + assert(expected.length == result.length, "Length of results don't match.") result.zip(expected).zipWithIndex foreach { case ((res, exChunk), i) => val chunk = res.result assert(chunk == exChunk, "Chunk was not equal") diff --git a/src/test/scala/com/johnsnowlabs/nlp/annotators/DocumentTokenSplitterTest.scala b/src/test/scala/com/johnsnowlabs/nlp/annotators/DocumentTokenSplitterTest.scala new file mode 100644 index 00000000000000..036205711dc924 --- /dev/null +++ b/src/test/scala/com/johnsnowlabs/nlp/annotators/DocumentTokenSplitterTest.scala @@ -0,0 +1,58 @@ +package com.johnsnowlabs.nlp.annotators + +import com.johnsnowlabs.nlp.Annotation +import com.johnsnowlabs.nlp.base.DocumentAssembler +import com.johnsnowlabs.nlp.util.io.ResourceHelper +import com.johnsnowlabs.tags.FastTest +import org.apache.spark.sql.DataFrame +import org.scalatest.flatspec.AnyFlatSpec + +class DocumentTokenSplitterTest extends AnyFlatSpec { + + val spark = ResourceHelper.spark + + import spark.implicits._ + + val text = + "All emotions, and that\none particularly, were abhorrent to his cold, precise but\nadmirably balanced mind.\n\n" + + "He was, I take it, the most perfect\nreasoning and observing machine that the world has seen." 
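+  // The passage above has 33 whitespace-separated tokens, so with numTokens = 3 it should yield
+  // 11 chunks without overlap and roughly (33 - 1) / (3 - 1) = 16 chunks with tokenOverlap = 1.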
+ + val textDf = Seq(text).toDF("text") + val documentAssembler = new DocumentAssembler().setInputCol("text") + val textDocument: DataFrame = documentAssembler.transform(textDf) + + behavior of "DocumentTokenSplitter" + + it should "split by number of tokens" taggedAs FastTest in { + val numTokens = 3 + val tokenTextSplitter = + new DocumentTokenSplitter() + .setInputCols("document") + .setOutputCol("splits") + .setNumTokens(numTokens) + + val splitDF = tokenTextSplitter.transform(textDocument) + val result = Annotation.collect(splitDF, "splits").head + + result.foreach(annotation => assert(annotation.metadata("numTokens").toInt == numTokens)) + } + + it should "split tokens with overlap" taggedAs FastTest in { + val numTokens = 3 + val tokenTextSplitter = + new DocumentTokenSplitter() + .setInputCols("document") + .setOutputCol("splits") + .setNumTokens(numTokens) + .setTokenOverlap(1) + + val splitDF = tokenTextSplitter.transform(textDocument) + val result = Annotation.collect(splitDF, "splits").head + + result.zipWithIndex.foreach { case (annotation, i) => + if (i < result.length - 1) // Last document is shorter + assert(annotation.metadata("numTokens").toInt == numTokens) + } + } + +}
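For a quick sanity check of the new annotator outside the test suite, the following minimal Python sketch reproduces the 3-token, 1-overlap example from the documentation above. It is not part of the patch; the `sparknlp.start()` session setup and the sample sentence are illustrative assumptions, while the annotator calls mirror the documented API.

```python
import sparknlp
from sparknlp.base import DocumentAssembler
from sparknlp.annotator import DocumentTokenSplitter
from pyspark.ml import Pipeline

# Illustrative setup; any existing Spark session with Spark NLP available works as well.
spark = sparknlp.start()

text = (
    "He was, I take it, the most perfect reasoning and observing machine "
    "that the world has seen."
)
df = spark.createDataFrame([[text]]).toDF("text")

documentAssembler = DocumentAssembler().setInputCol("text").setOutputCol("document")

# Chunks of 3 whitespace tokens with 1 token of overlap, exploded to one chunk per row.
textSplitter = (
    DocumentTokenSplitter()
    .setInputCols(["document"])
    .setOutputCol("splits")
    .setNumTokens(3)
    .setTokenOverlap(1)
    .setExplodeSplits(True)
)

pipeline = Pipeline().setStages([documentAssembler, textSplitter])
result = pipeline.fit(df).transform(df)

# After exploding, each row holds a single chunk; its token count is stored in the metadata.
result.selectExpr(
    "splits[0].result as chunk",
    "splits[0].metadata.numTokens as numTokens"
).show(truncate=False)
```

With this passage the splitter should produce the eight chunks listed in the description, each reporting a `numTokens` value of 3 in its metadata.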