Commit
[SPARKNLP-1059] Adding aggressiveMatching parameter to DocumentSimilarityRanker
danilojsl committed Aug 16, 2024
1 parent 49b37a5 commit 3e53d4d
Showing 7 changed files with 274 additions and 19 deletions.
22 changes: 22 additions & 0 deletions python/sparknlp/annotator/similarity/document_similarity_ranker.py
@@ -158,6 +158,12 @@ class DocumentSimilarityRankerApproach(AnnotatorApproach, HasEnableCachingProperties):
"(Default: `empty`)",
typeConverter=TypeConverters.toString)

aggregationMethod = Param(Params._dummy(),
"aggregationMethod",
"Specifies the method used to aggregate multiple sentence embeddings into a single vector representation. Options: 'AVERAGE', 'FIRST', 'MAX' (Default: 'AVERAGE').",
typeConverter=TypeConverters.toString)

def setSimilarityMethod(self, value):
"""Sets the similarity method used to calculate the neighbours.
(Default: `"brp"`, Bucketed Random Projection for Euclidean Distance)
@@ -233,6 +239,22 @@ def asRetriever(self, value):
"""
return self._set(asRetrieverQuery=value)

def setAggregationMethod(self, value):
"""Sets the method used to aggregate multiple sentence embeddings into a single vector
representation.

Parameters
----------
value : str
The aggregation method. Options:
'AVERAGE' (compute the mean of all embeddings),
'FIRST' (use the first embedding only),
'MAX' (compute the element-wise maximum across embeddings)
(Default: 'AVERAGE')
"""
return self._set(aggregationMethod=value)

@keyword_only
def __init__(self):
super(DocumentSimilarityRankerApproach, self)\
src/main/scala/com/johnsnowlabs/nlp/annotators/similarity/DocumentSimilarityRankerApproach.scala
@@ -1,18 +1,17 @@
package com.johnsnowlabs.nlp.annotators.similarity

import com.johnsnowlabs.nlp.AnnotatorType.{DOC_SIMILARITY_RANKINGS, SENTENCE_EMBEDDINGS}
import com.johnsnowlabs.nlp.annotators.similarity.DocumentSimilarityUtil._
import com.johnsnowlabs.nlp.{AnnotatorApproach, HasEnableCachingProperties}
import com.johnsnowlabs.util.spark.SparkUtil.retrieveColumnName
import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.feature.{BucketedRandomProjectionLSH, MinHashLSH}
import org.apache.spark.ml.functions.array_to_vector
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.{BooleanParam, Param}
import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable}
import org.apache.spark.sql.functions.{col, flatten, udf}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.{DataFrame, Dataset}

import scala.util.hashing.MurmurHash3

sealed trait NeighborAnnotation {
def neighbors: Array[_]
}
@@ -152,7 +151,7 @@ class DocumentSimilarityRankerApproach(override val uid: String)

val DISTANCE = "distCol"

val INPUT_EMBEDDINGS = "sentence_embeddings.embeddings"
// val INPUT_EMBEDDINGS = "sentence_embeddings.embeddings"

val TEXT = "text"

@@ -233,14 +232,47 @@ class DocumentSimilarityRankerApproach(override val uid: String)

def getAsRetrieverQuery: String = $(asRetrieverQuery)

/** Specifies the method used to aggregate multiple sentence embeddings into a single vector
* representation. Options include 'AVERAGE' (compute the mean of all embeddings), 'FIRST' (use
* the first embedding only), and 'MAX' (compute the element-wise maximum across embeddings).
*
* Default: 'AVERAGE'
*
* @group param
*/
val aggregationMethod = new Param[String](
this,
"aggregationMethod",
"Specifies the method used to aggregate multiple sentence embeddings into a single vector representation: AVERAGE, FIRST, or MAX.")

/** Sets the method used to aggregate multiple sentence embeddings into a single vector
* representation. Options include 'AVERAGE' (compute the mean of all embeddings), 'FIRST' (use
* the first embedding only), and 'MAX' (compute the element-wise maximum across embeddings).
*
* Default: 'AVERAGE'
*
* @group setParam
*/
def setAggregationMethod(strategy: String): this.type = {
strategy.toLowerCase() match {
case "average" => set(aggregationMethod, "AVERAGE")
case "first" => set(aggregationMethod, "FIRST")
case "max" => set(aggregationMethod, "MAX")
case _ =>
throw new IllegalArgumentException("aggregationMethod must be AVERAGE, FIRST or MAX")
}
}

def getAggregationMethod: String = $(aggregationMethod)

setDefault(
similarityMethod -> "brp",
numberOfNeighbours -> 10,
bucketLength -> 2.0,
numHashTables -> 3,
visibleDistances -> false,
identityRanking -> false,
asRetrieverQuery -> "")
asRetrieverQuery -> "",
aggregationMethod -> "AVERAGE")

def getNeighborsResultSet(
query: (Int, Vector),
@@ -300,11 +332,22 @@ class DocumentSimilarityRankerApproach(override val uid: String)
embeddingsDataset: Dataset[_],
recursivePipeline: Option[PipelineModel]): DocumentSimilarityRankerModel = {

val similarityDataset: DataFrame = embeddingsDataset
.withColumn(s"$LSH_INPUT_COL_NAME", array_to_vector(flatten(col(INPUT_EMBEDDINGS))))

val mh3Func = (s: String) => MurmurHash3.stringHash(s, MurmurHash3.stringSeed)
val mh3UDF = udf { mh3Func }
val inputEmbeddingsColumn =
s"${retrieveColumnName(embeddingsDataset, SENTENCE_EMBEDDINGS)}.embeddings"

val similarityDataset: DataFrame = getAggregationMethod match {
case "AVERAGE" =>
embeddingsDataset
.withColumn(s"$LSH_INPUT_COL_NAME", averageAggregation(col(inputEmbeddingsColumn)))
case "FIRST" =>
embeddingsDataset
.withColumn(
s"$LSH_INPUT_COL_NAME",
firstEmbeddingAggregation(col(inputEmbeddingsColumn)))
case "MAX" =>
embeddingsDataset
.withColumn(s"$LSH_INPUT_COL_NAME", maxAggregation(col(inputEmbeddingsColumn)))
}

val similarityDatasetWithHashIndex =
similarityDataset.withColumn(INDEX_COL_NAME, mh3UDF(col(TEXT)))
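For intuition, here is a plain-Scala sketch (toy values, not part of the commit) of what each aggregation strategy computes on two hypothetical 2-dimensional sentence embeddings:

val embeddings = Seq(Seq(1.0, 4.0), Seq(3.0, 2.0))

// AVERAGE: element-wise mean of all embeddings -> List(2.0, 3.0)
val average = embeddings.transpose.map(_.sum / embeddings.length)

// FIRST: keep only the first embedding -> List(1.0, 4.0)
val first = embeddings.head

// MAX: element-wise maximum across embeddings -> List(3.0, 4.0)
val max = embeddings.transpose.map(_.max)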
src/main/scala/com/johnsnowlabs/nlp/annotators/similarity/DocumentSimilarityUtil.scala
@@ -0,0 +1,29 @@
package com.johnsnowlabs.nlp.annotators.similarity

import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.expressions.UserDefinedFunction

import scala.util.hashing.MurmurHash3

object DocumentSimilarityUtil {

import org.apache.spark.sql.functions._

val mh3Func: String => Int = (s: String) => MurmurHash3.stringHash(s, MurmurHash3.stringSeed)
val mh3UDF: UserDefinedFunction = udf { mh3Func }

val averageAggregation: UserDefinedFunction = udf((embeddings: Seq[Seq[Double]]) => {
val summed = embeddings.transpose.map(_.sum)
val averaged = summed.map(_ / embeddings.length)
Vectors.dense(averaged.toArray)
})

val firstEmbeddingAggregation: UserDefinedFunction = udf((embeddings: Seq[Seq[Double]]) => {
Vectors.dense(embeddings.head.toArray)
})

val maxAggregation: UserDefinedFunction = udf((embeddings: Seq[Seq[Double]]) => {
Vectors.dense(embeddings.transpose.map(_.max).toArray)
})

}
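A minimal usage sketch for these UDFs, assuming a local SparkSession and the same toy embeddings as above (the DataFrame, column name, and values are illustrative, not from the commit):

import org.apache.spark.sql.SparkSession
import com.johnsnowlabs.nlp.annotators.similarity.DocumentSimilarityUtil._

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

// One document whose sentence-embeddings column holds two 2-dimensional vectors.
val df = Seq(Tuple1(Seq(Seq(1.0, 4.0), Seq(3.0, 2.0)))).toDF("embeddings")

val aggregated = df.select(
  averageAggregation($"embeddings").as("avg"), // [2.0, 3.0]
  firstEmbeddingAggregation($"embeddings").as("first"), // [1.0, 4.0]
  maxAggregation($"embeddings").as("max")) // [3.0, 4.0]
aggregated.show(false)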
13 changes: 8 additions & 5 deletions src/main/scala/com/johnsnowlabs/util/spark/SparkUtil.scala
@@ -1,13 +1,16 @@
package com.johnsnowlabs.util.spark

import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.Dataset

object SparkUtil {

// Helper UDF function to flatten arrays for Spark < 2.4.0
def flattenArrays: UserDefinedFunction = udf { (arrayColumn: Seq[Seq[String]]) =>
arrayColumn.flatten.distinct
def retrieveColumnName(dataset: Dataset[_], annotatorType: String): String = {
val structFields = dataset.schema.fields
.filter(field => field.metadata.contains("annotatorType"))
.filter(field => field.metadata.getString("annotatorType") == annotatorType)
val columnNames = structFields.map(structField => structField.name)

columnNames.head
}

}
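A short sketch of how retrieveColumnName resolves a column by the "annotatorType" metadata that Spark NLP annotators attach to their output columns (the column name here is illustrative, not from the commit):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.types.MetadataBuilder
import com.johnsnowlabs.util.spark.SparkUtil

val spark = SparkSession.builder().master("local[*]").getOrCreate()

// Tag a column with annotatorType metadata, as Spark NLP annotators do.
val meta = new MetadataBuilder().putString("annotatorType", "token").build()
val df = spark.range(1).select(lit("dummy").as("my_token_col", meta))

SparkUtil.retrieveColumnName(df, "token") // returns "my_token_col"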
src/test/scala/com/johnsnowlabs/nlp/annotators/SparkSessionTest.scala
@@ -35,8 +35,6 @@ trait SparkSessionTest extends BeforeAndAfterAll { this: Suite =>
val emptyDataSet: Dataset[_] = PipelineModels.dummyDataset
val pipeline = new Pipeline()

println(s"Spark version: ${spark.version}")

override def beforeAll(): Unit = {
super.beforeAll()

DocumentSimilarityRankerTestSpec.scala
@@ -6,7 +6,11 @@ import com.johnsnowlabs.nlp.annotators.Tokenizer
import com.johnsnowlabs.nlp.annotators.sbd.pragmatic.SentenceDetector
import com.johnsnowlabs.nlp.annotators.similarity.DocumentSimilarityRankerApproach
import com.johnsnowlabs.nlp.base.DocumentAssembler
import com.johnsnowlabs.nlp.embeddings.{AlbertEmbeddings, SentenceEmbeddings}
import com.johnsnowlabs.nlp.embeddings.{
AlbertEmbeddings,
BertSentenceEmbeddings,
SentenceEmbeddings
}
import com.johnsnowlabs.nlp.finisher.DocumentSimilarityRankerFinisher
import com.johnsnowlabs.nlp.util.io.ResourceHelper
import com.johnsnowlabs.tags.SlowTest
@@ -302,4 +306,124 @@ class DocumentSimilarityRankerTestSpec extends AnyFlatSpec
assert(transformed.columns.contains("nearest_neighbor_id"))
assert(transformed.columns.contains("nearest_neighbor_distance"))
}

it should "work when setting aggregation method" taggedAs SlowTest in {
val documentAssembler = new DocumentAssembler()
.setInputCol("text")
.setOutputCol("document")

val sentenceDetector = new SentenceDetector()
.setInputCols("document")
.setOutputCol("sentence")

val embeddings = BertSentenceEmbeddings
.pretrained("sent_biobert_clinical_base_cased", "en")
.setInputCols("sentence")
.setOutputCol("sentence_embeddings")

val document_similarity_ranker = new DocumentSimilarityRankerApproach()
.setInputCols("sentence_embeddings")
.setOutputCol("doc_similarity_rankings")
.setSimilarityMethod("brp")
.setNumberOfNeighbours(1)
.setBucketLength(2.0)
.setNumHashTables(3)
.setVisibleDistances(true)
.setIdentityRanking(false)
.setAggregationMethod("MAX")

val document_similarity_ranker_finisher = new DocumentSimilarityRankerFinisher()
.setInputCols("doc_similarity_rankings")
.setOutputCols(
"finished_doc_similarity_rankings_id",
"finished_doc_similarity_rankings_neighbors")
.setExtractNearestNeighbor(true)

val pipeline = new Pipeline()
.setStages(
Array(
documentAssembler,
sentenceDetector,
embeddings,
document_similarity_ranker,
document_similarity_ranker_finisher))

val transformed = pipeline.fit(smallCorpus).transform(smallCorpus)

transformed
.select(
"doc_similarity_rankings",
"finished_doc_similarity_rankings_id",
"finished_doc_similarity_rankings_neighbors")
.show(10, false)
}

"Pipeline" should "should not fail if I use the outputCol and inputCols feature" taggedAs SlowTest in {
val nbOfNeighbors = 3

val documentAssembler = new DocumentAssembler()
.setInputCol("text")
.setOutputCol("document")

val sentence = new SentenceDetector()
.setInputCols("document")
.setOutputCol("sentence")

val tokenizer = new Tokenizer()
.setInputCols(Array("document"))
.setOutputCol("token")

val embeddings = AlbertEmbeddings
.pretrained()
.setInputCols("sentence", "token")
.setOutputCol("embeddings")

val embeddingsSentence = new SentenceEmbeddings()
.setInputCols(Array("document", "embeddings"))
.setOutputCol("my_sentence_emb")
.setPoolingStrategy("AVERAGE")

val sentenceFinisher = new EmbeddingsFinisher()
.setInputCols("my_sentence_emb")
.setOutputCols("finished_sentence_embeddings")
.setCleanAnnotations(false)

val query = "Fifth document, Florence in Italy, is among the most beautiful cities in Europe."

val docSimilarityRanker = new DocumentSimilarityRankerApproach()
.setInputCols("my_sentence_emb")
.setOutputCol(DOC_SIMILARITY_RANKINGS)
.setSimilarityMethod("brp")
.setNumberOfNeighbours(nbOfNeighbors)
.setVisibleDistances(true)
.setIdentityRanking(true)
.asRetriever(query)

val documentSimilarityFinisher = new DocumentSimilarityRankerFinisher()
.setInputCols("doc_similarity_rankings")
.setOutputCols(
"finished_doc_similarity_rankings_id",
"finished_doc_similarity_rankings_neighbors")

val pipeline = new Pipeline()
.setStages(
Array(
documentAssembler,
sentence,
tokenizer,
embeddings,
embeddingsSentence,
sentenceFinisher,
docSimilarityRanker,
documentSimilarityFinisher))

val transformed = pipeline.fit(smallCorpus).transform(smallCorpus)

transformed.show(false)

assert(transformed.count() === 3)
assert(transformed.columns.contains("nearest_neighbor_id"))
assert(transformed.columns.contains("nearest_neighbor_distance"))
}

}
36 changes: 36 additions & 0 deletions src/test/scala/com/johnsnowlabs/util/SparkUtilTest.scala
@@ -0,0 +1,36 @@
package com.johnsnowlabs.util

import com.johnsnowlabs.nlp.AnnotatorType.{DOCUMENT, TOKEN}
import com.johnsnowlabs.nlp.annotator.SentenceDetector
import com.johnsnowlabs.nlp.annotators.{SparkSessionTest, Tokenizer}
import com.johnsnowlabs.tags.FastTest
import com.johnsnowlabs.util.spark.SparkUtil
import org.apache.spark.ml.Pipeline
import org.scalatest.flatspec.AnyFlatSpec

class SparkUtilTest extends AnyFlatSpec with SparkSessionTest {

"SparkUtil" should "retrieve column name for Token annotator type " taggedAs FastTest in {
val expectedColumn = "token"
val testDataset = tokenizerPipeline.fit(emptyDataSet).transform(emptyDataSet)

val actualColumn = SparkUtil.retrieveColumnName(testDataset, TOKEN)

assert(expectedColumn == actualColumn)
}

it should "retrieve custom column name for Token annotator type " taggedAs FastTest in {
val customColumnName = "my_custom_token_col"
val tokenizer = new Tokenizer()
.setInputCols("document")
.setOutputCol(customColumnName)

val pipeline = new Pipeline().setStages(Array(documentAssembler, tokenizer))
val testDataset = pipeline.fit(emptyDataSet).transform(emptyDataSet)

val actualColumn = SparkUtil.retrieveColumnName(testDataset, TOKEN)

assert(customColumnName == actualColumn)
}

}
