Skip to content

Commit

Permalink
When we set post-processing options, we can infer the post-processing …
Browse files Browse the repository at this point in the history
…type. For example, if we set the post-processing option "jsonSchema", then we can infer the post-processing type to be "json". In this changeset, I am adding this feature and also adding unit tests to validate it.
  • Loading branch information
FMasudMsft committed Nov 25, 2024
1 parent 79d5b58 commit 67f6498
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,32 @@ class OpenAIPrompt(override val uid: String) extends Transformer

/** Returns the value currently held by the `postProcessingOptions` param. */
def getPostProcessingOptions: Map[String, String] = getOrDefault(postProcessingOptions)

/** Stores `value` in the `postProcessingOptions` param as-is; no checks are applied to the map's keys. */
def setPostProcessingOptions(value: Map[String, String]): this.type = set(postProcessingOptions, value)
/**
 * Sets the post-processing options and derives the post-processing type from the keys present.
 *
 * Keys are checked in this order: "delimiter" selects "csv", then "jsonSchema" selects "json",
 * then "regex" selects "regex" (which additionally requires a "regexGroup" key). When
 * `postProcessing` has not been set yet it is set to the derived type; when it has already
 * been set it must equal the derived type.
 *
 * @param value options map; must contain at least one of the recognised keys
 * @return this instance, after storing `value` in `postProcessingOptions`
 * @throws IllegalArgumentException if no recognised key is present, if "regex" is supplied
 *                                  without "regexGroup", or if an already-set `postProcessing`
 *                                  differs from the derived type
 */
def setPostProcessingOptions(value: Map[String, String]): this.type = {
  // First matching key wins; the order of these checks is part of the contract.
  val inferredType: String =
    if (value.contains("delimiter")) {
      "csv"
    } else if (value.contains("jsonSchema")) {
      "json"
    } else if (value.contains("regex")) {
      require(value.contains("regexGroup"), "regexGroup must be specified with regex")
      "regex"
    } else {
      throw new IllegalArgumentException("Invalid post processing options")
    }

  // Either adopt the inferred type, or insist that an explicit choice agrees with it.
  if (!isSet(postProcessing)) {
    set(postProcessing, inferredType)
  } else {
    require(getPostProcessing == inferredType, s"postProcessing must be '$inferredType'")
  }

  // Options are stored only after every check above has passed.
  set(postProcessingOptions, value)
}

/**
 * Java-friendly overload of `setPostProcessingOptions`.
 *
 * Converts the Java map to an immutable Scala `Map` and delegates to the `Map[String, String]`
 * overload, so the same key validation and `postProcessing` inference apply regardless of
 * which overload a caller uses. Writing the param directly here would skip both.
 *
 * @param v options map supplied through the Java API
 * @return this instance
 * @throws IllegalArgumentException under the same conditions as the `Map[String, String]` overload
 */
def setPostProcessingOptions(v: java.util.HashMap[String, String]): this.type =
  setPostProcessingOptions(v.asScala.toMap)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,20 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
.foreach(r => assert(r.getStruct(0).getString(0).nonEmpty))
}

// Only the jsonSchema option is supplied here; postProcessing is never set explicitly.
test("Basic Usage JSON - Gpt 4 without explicit post-processing") {
  val template =
    """Split a word into prefix and postfix a respond in JSON
      |Cherry: {{"prefix": "Che", "suffix": "rry"}}
      |{text}:
      |""".stripMargin
  val jsonOptions = Map("jsonSchema" -> "prefix STRING, suffix STRING")

  val parsedRows = promptGpt4
    .setPromptTemplate(template)
    .setPostProcessingOptions(jsonOptions)
    .transform(df)
    .select("outParsed")
    .where(col("outParsed").isNotNull)
    .collect()

  // Every non-null parsed struct must carry a non-empty first field.
  parsedRows.foreach(row => assert(row.getStruct(0).getString(0).nonEmpty))
}

test("Setting and Keeping Messages Col - Gpt 4") {
promptGpt4.setMessagesCol("messages")
.setDropPrompt(false)
Expand Down Expand Up @@ -149,6 +163,46 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
.count(r => Option(r.getSeq[String](0)).isDefined)
}

// A "delimiter" key alone should select the csv post-processing type.
test("setPostProcessingOptions should set postProcessing to 'csv' for delimiter option") {
  val csvOptions = Map("delimiter" -> ",")
  val configured = new OpenAIPrompt().setPostProcessingOptions(csvOptions)
  assert(configured.getPostProcessing == "csv")
}

// A "jsonSchema" key alone should select the json post-processing type.
test("setPostProcessingOptions should set postProcessing to 'json' for jsonSchema option") {
  val jsonOptions = Map("jsonSchema" -> "schema")
  val configured = new OpenAIPrompt().setPostProcessingOptions(jsonOptions)
  assert(configured.getPostProcessing == "json")
}

// "regex" together with "regexGroup" should select the regex post-processing type.
test("setPostProcessingOptions should set postProcessing to 'regex' for regex option") {
  val regexOptions = Map("regex" -> ".*", "regexGroup" -> "0")
  val configured = new OpenAIPrompt().setPostProcessingOptions(regexOptions)
  assert(configured.getPostProcessing == "regex")
}

// A map with none of the recognised keys must be rejected.
test("setPostProcessingOptions should throw IllegalArgumentException for invalid options") {
  val freshPrompt = new OpenAIPrompt()
  val unknownOptions = Map("invalidOption" -> "value")
  intercept[IllegalArgumentException] {
    freshPrompt.setPostProcessingOptions(unknownOptions)
  }
}

// "regex" without an accompanying "regexGroup" must be rejected.
test("setPostProcessingOptions should validate regex options contain regexGroup key") {
  val freshPrompt = new OpenAIPrompt()
  val regexWithoutGroup = Map("regex" -> ".*")
  intercept[IllegalArgumentException] {
    freshPrompt.setPostProcessingOptions(regexWithoutGroup)
  }
}

// An explicitly chosen post-processing type must not be contradicted by the options' keys.
test("setPostProcessingOptions should validate existing postProcessing value") {
  val csvPrompt = new OpenAIPrompt()
  csvPrompt.setPostProcessing("csv")
  val conflictingOptions = Map("jsonSchema" -> "schema")
  intercept[IllegalArgumentException] {
    csvPrompt.setPostProcessingOptions(conflictingOptions)
  }
}

/**
 * Compares two DataFrames after removing the "out" and "outParsed" columns from both,
 * then defers to the parent implementation.
 */
override def assertDFEq(df1: DataFrame, df2: DataFrame)(implicit eq: Equality[DataFrame]): Unit = {
  // NOTE(review): presumably these columns hold model output that differs between runs — confirm.
  val ignoredCols = Seq("out", "outParsed")
  val left = df1.drop(ignoredCols: _*)
  val right = df2.drop(ignoredCols: _*)
  super.assertDFEq(left, right)(eq)
}
Expand Down

0 comments on commit 67f6498

Please sign in to comment.