Skip to content

Commit

Permalink
Fix the missing DefaultParamsReadable in DocumentTokenSplitter [skip test]
Browse files Browse the repository at this point in the history
  • Loading branch information
maziyarpanahi committed Dec 27, 2023
1 parent dc18ef4 commit 4723d9c
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
/*
* Copyright 2017-2023 John Snow Labs
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.johnsnowlabs.nlp.annotators

import com.johnsnowlabs.nlp.functions.ExplodeAnnotations
Expand Down Expand Up @@ -272,6 +287,6 @@ class DocumentCharacterTextSplitter(override val uid: String)
}

/** This is the companion object of [[DocumentCharacterTextSplitter]]. Please refer to that class
  * for the documentation.
  */
object DocumentCharacterTextSplitter extends DefaultParamsReadable[DocumentCharacterTextSplitter]
Original file line number Diff line number Diff line change
@@ -1,12 +1,27 @@
/*
* Copyright 2017-2023 John Snow Labs
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.johnsnowlabs.nlp.annotators

import com.johnsnowlabs.nlp.functions.ExplodeAnnotations
import com.johnsnowlabs.nlp.{Annotation, AnnotatorModel, AnnotatorType, HasSimpleAnnotate}
import org.apache.spark.ml.param.{BooleanParam, IntParam}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable}
import org.apache.spark.sql.DataFrame

import scala.util.matching.Regex
import com.johnsnowlabs.nlp.functions.ExplodeAnnotations

/** Annotator that splits large documents into smaller documents based on the number of tokens in
* the text.
Expand Down Expand Up @@ -223,3 +238,8 @@ class DocumentTokenSplitter(override val uid: String)
if (getExplodeSplits) dataset.explodeAnnotationsCol(getOutputCol, getOutputCol) else dataset
}
}

/** This is the companion object of [[DocumentTokenSplitter]]. Please refer to that class for the
  * documentation.
  *
  * Extending [[org.apache.spark.ml.util.DefaultParamsReadable DefaultParamsReadable]] supplies
  * the `load` entry point used to restore a saved annotator from disk (exercised by the
  * serialization test in this change set).
  */
object DocumentTokenSplitter extends DefaultParamsReadable[DocumentTokenSplitter]
Original file line number Diff line number Diff line change
@@ -1,7 +1,21 @@
/*
* Copyright 2017-2023 John Snow Labs
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.johnsnowlabs.nlp.annotators

import com.johnsnowlabs.nlp.Annotation
import com.johnsnowlabs.nlp.annotator.DocumentCharacterTextSplitter
import com.johnsnowlabs.nlp.base.DocumentAssembler
import com.johnsnowlabs.nlp.util.io.ResourceHelper
import com.johnsnowlabs.tags.{FastTest, SlowTest}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,25 @@
/*
* Copyright 2017-2023 John Snow Labs
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.johnsnowlabs.nlp.annotators

import com.johnsnowlabs.nlp.Annotation
import com.johnsnowlabs.nlp.base.DocumentAssembler
import com.johnsnowlabs.nlp.util.io.ResourceHelper
import com.johnsnowlabs.tags.FastTest
import org.apache.spark.ml.Pipeline
import org.apache.spark.sql.DataFrame
import org.scalatest.flatspec.AnyFlatSpec

Expand Down Expand Up @@ -55,4 +71,42 @@ class DocumentTokenSplitterTest extends AnyFlatSpec {
}
}

it should "be serializable" taggedAs FastTest in {
  // Splitter producing 3-token chunks with one token of overlap between chunks.
  val splitter = new DocumentTokenSplitter()
    .setInputCols("document")
    .setOutputCol("splits")
    .setNumTokens(3)
    .setTokenOverlap(1)

  val fitted = new Pipeline()
    .setStages(Array(documentAssembler, splitter))
    .fit(textDf)

  // Persist the fitted splitter stage to disk, then read it back via the
  // companion object's DefaultParamsReadable.load.
  val splitterStage = fitted.stages.last.asInstanceOf[DocumentTokenSplitter]
  splitterStage.write.overwrite().save("./tmp_textSplitter")

  val restored = DocumentTokenSplitter.load("tmp_textSplitter")
  restored.transform(textDocument).select("splits").show(truncate = false)
}

it should "be exportable to pipeline" taggedAs FastTest in {
  // Same splitter configuration: 3-token chunks with an overlap of one token.
  val splitter = new DocumentTokenSplitter()
    .setInputCols("document")
    .setOutputCol("splits")
    .setNumTokens(3)
    .setTokenOverlap(1)

  // Save the *unfitted* pipeline containing the splitter, then reload and fit it.
  new Pipeline()
    .setStages(Array(documentAssembler, splitter))
    .write
    .overwrite()
    .save("tmp_textsplitter_pipe")

  val reloaded = Pipeline.load("tmp_textsplitter_pipe")
  reloaded.fit(textDf).transform(textDf).select("splits").show()
}

}

0 comments on commit 4723d9c

Please sign in to comment.