Skip to content

Commit

Permalink
Merge branch 'master' into response-format-03
Browse files Browse the repository at this point in the history
  • Loading branch information
mhamilton723 authored Dec 2, 2024
2 parents 4012f98 + 4a6a041 commit efb9b4d
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ import com.microsoft.azure.synapse.ml.fabric.FabricClient
import com.microsoft.azure.synapse.ml.io.http._
import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging
import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails
import com.microsoft.azure.synapse.ml.param.{GlobalParams, HasGlobalParams, ServiceParam}
import com.microsoft.azure.synapse.ml.param.{GlobalKey, GlobalParams, HasGlobalParams, ServiceParam}
import com.microsoft.azure.synapse.ml.services.openai.OpenAIDeploymentNameKey
import com.microsoft.azure.synapse.ml.stages.{DropColumns, Lambda}
import org.apache.http.NameValuePair
import org.apache.http.client.methods.{HttpEntityEnclosingRequestBase, HttpPost, HttpRequestBase}
Expand Down Expand Up @@ -128,10 +129,14 @@ trait HasServiceParams extends Params {
}
}

case object OpenAISubscriptionKey extends GlobalKey[Either[String, String]]

trait HasSubscriptionKey extends HasServiceParams {
val subscriptionKey = new ServiceParam[String](
this, "subscriptionKey", "the API key to use")

GlobalParams.registerParam(subscriptionKey, OpenAISubscriptionKey)

def getSubscriptionKey: String = getScalarParam(subscriptionKey)

def setSubscriptionKey(v: String): this.type = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ trait HasOpenAIEmbeddingParams extends HasOpenAISharedParams with HasAPIVersion
}
}

case object OpenAITemperatureKey extends GlobalKey[Either[Double, String]]

trait HasOpenAITextParams extends HasOpenAISharedParams {
val maxTokens: ServiceParam[Int] = new ServiceParam[Int](
this, "maxTokens",
Expand All @@ -126,6 +128,8 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
" We generally recommend using this or `top_p` but not both. Minimum of 0 and maximum of 2 allowed.",
isRequired = false)

GlobalParams.registerParam(temperature, OpenAITemperatureKey)

def getTemperature: Double = getScalarParam(temperature)

def setTemperature(v: Double): this.type = setScalarParam(temperature, v)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,18 @@
package com.microsoft.azure.synapse.ml.services.openai

import com.microsoft.azure.synapse.ml.param.GlobalParams
import com.microsoft.azure.synapse.ml.services.OpenAISubscriptionKey

object OpenAIDefaults {
def setDeploymentName(v: String): Unit = {
GlobalParams.setGlobalParam(OpenAIDeploymentNameKey, Left(v))
}

def setSubscriptionKey(v: String): Unit = {
GlobalParams.setGlobalParam(OpenAISubscriptionKey, Left(v))
}

def setTemperature(v: Double): Unit = {
GlobalParams.setGlobalParam(OpenAITemperatureKey, Left(v))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,32 @@ class OpenAIPrompt(override val uid: String) extends Transformer

def getPostProcessingOptions: Map[String, String] = $(postProcessingOptions)

def setPostProcessingOptions(value: Map[String, String]): this.type = set(postProcessingOptions, value)
def setPostProcessingOptions(value: Map[String, String]): this.type = {
// Helper method to set or validate the postProcessing parameter
def setOrValidatePostProcessing(expected: String): Unit = {
if (isSet(postProcessing)) {
require(getPostProcessing == expected, s"postProcessing must be '$expected'")
} else {
set(postProcessing, expected)
}
}

// Match on the keys in the provided value map to set the appropriate post-processing option
value match {
case v if v.contains("delimiter") =>
setOrValidatePostProcessing("csv")
case v if v.contains("jsonSchema") =>
setOrValidatePostProcessing("json")
case v if v.contains("regex") =>
require(v.contains("regexGroup"), "regexGroup must be specified with regex")
setOrValidatePostProcessing("regex")
case _ =>
throw new IllegalArgumentException("Invalid post processing options")
}

// Set the postProcessingOptions parameter with the provided value map
set(postProcessingOptions, value)
}

def setPostProcessingOptions(v: java.util.HashMap[String, String]): this.type =
set(postProcessingOptions, v.asScala.toMap)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class AzMapsSearchAddressSuite extends TransformerFuzzing[AddressGeocoder] with

assert(flattenedResults != null)
assert(flattenedResults.length == 15)
assert(flattenedResults.toSeq.head.get(1) == 47.64188)
assert(flattenedResults.toSeq.head.get(1).toString.startsWith("47.6418"))
}

override def assertDFEq(df1: DataFrame, df2: DataFrame)(implicit eq: Equality[DataFrame]): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey {
import spark.implicits._

OpenAIDefaults.setDeploymentName(deploymentName)
OpenAIDefaults.setSubscriptionKey(openAIAPIKey)
OpenAIDefaults.setTemperature(0.05)


def promptCompletion: OpenAICompletion = new OpenAICompletion()
.setCustomServiceName(openAIServiceName)
.setMaxTokens(200)
.setOutputCol("out")
.setSubscriptionKey(openAIAPIKey)
.setPromptCol("prompt")

lazy val promptDF: DataFrame = Seq(
Expand All @@ -33,10 +35,8 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey {
}

lazy val prompt: OpenAIPrompt = new OpenAIPrompt()
.setSubscriptionKey(openAIAPIKey)
.setCustomServiceName(openAIServiceName)
.setOutputCol("outParsed")
.setTemperature(0)

lazy val df: DataFrame = Seq(
("apple", "fruits"),
Expand All @@ -56,4 +56,10 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey {

assert(nonNullCount == 3)
}

test("OpenAIPrompt Check Params") {
assert(prompt.getDeploymentName == deploymentName)
assert(prompt.getSubscriptionKey == openAIAPIKey)
assert(prompt.getTemperature == 0.05)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,20 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
.foreach(r => assert(r.getStruct(0).getString(0).nonEmpty))
}

test("Basic Usage JSON - Gpt 4 without explicit post-processing") {
promptGpt4.setPromptTemplate(
"""Split a word into prefix and postfix a respond in JSON
|Cherry: {{"prefix": "Che", "suffix": "rry"}}
|{text}:
|""".stripMargin)
.setPostProcessingOptions(Map("jsonSchema" -> "prefix STRING, suffix STRING"))
.transform(df)
.select("outParsed")
.where(col("outParsed").isNotNull)
.collect()
.foreach(r => assert(r.getStruct(0).getString(0).nonEmpty))
}

test("Setting and Keeping Messages Col - Gpt 4") {
promptGpt4.setMessagesCol("messages")
.setDropPrompt(false)
Expand Down Expand Up @@ -173,6 +187,46 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
.count(r => Option(r.getSeq[String](0)).isDefined)
}

test("setPostProcessingOptions should set postProcessing to 'csv' for delimiter option") {
val prompt = new OpenAIPrompt()
prompt.setPostProcessingOptions(Map("delimiter" -> ","))
assert(prompt.getPostProcessing == "csv")
}

test("setPostProcessingOptions should set postProcessing to 'json' for jsonSchema option") {
val prompt = new OpenAIPrompt()
prompt.setPostProcessingOptions(Map("jsonSchema" -> "schema"))
assert(prompt.getPostProcessing == "json")
}

test("setPostProcessingOptions should set postProcessing to 'regex' for regex option") {
val prompt = new OpenAIPrompt()
prompt.setPostProcessingOptions(Map("regex" -> ".*", "regexGroup" -> "0"))
assert(prompt.getPostProcessing == "regex")
}

test("setPostProcessingOptions should throw IllegalArgumentException for invalid options") {
val prompt = new OpenAIPrompt()
intercept[IllegalArgumentException] {
prompt.setPostProcessingOptions(Map("invalidOption" -> "value"))
}
}

test("setPostProcessingOptions should validate regex options contain regexGroup key") {
val prompt = new OpenAIPrompt()
intercept[IllegalArgumentException] {
prompt.setPostProcessingOptions(Map("regex" -> ".*"))
}
}

test("setPostProcessingOptions should validate existing postProcessing value") {
val prompt = new OpenAIPrompt()
prompt.setPostProcessing("csv")
intercept[IllegalArgumentException] {
prompt.setPostProcessingOptions(Map("jsonSchema" -> "schema"))
}
}

override def assertDFEq(df1: DataFrame, df2: DataFrame)(implicit eq: Equality[DataFrame]): Unit = {
super.assertDFEq(df1.drop("out", "outParsed"), df2.drop("out", "outParsed"))(eq)
}
Expand Down

0 comments on commit efb9b4d

Please sign in to comment.