diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/CognitiveServiceBase.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/CognitiveServiceBase.scala index 850b2915a7..007e1fdf19 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/CognitiveServiceBase.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/CognitiveServiceBase.scala @@ -10,7 +10,8 @@ import com.microsoft.azure.synapse.ml.fabric.FabricClient import com.microsoft.azure.synapse.ml.io.http._ import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails -import com.microsoft.azure.synapse.ml.param.{GlobalParams, HasGlobalParams, ServiceParam} +import com.microsoft.azure.synapse.ml.param.{GlobalKey, GlobalParams, HasGlobalParams, ServiceParam} +import com.microsoft.azure.synapse.ml.services.openai.OpenAIDeploymentNameKey import com.microsoft.azure.synapse.ml.stages.{DropColumns, Lambda} import org.apache.http.NameValuePair import org.apache.http.client.methods.{HttpEntityEnclosingRequestBase, HttpPost, HttpRequestBase} @@ -128,10 +129,14 @@ trait HasServiceParams extends Params { } } +case object OpenAISubscriptionKey extends GlobalKey[Either[String, String]] + trait HasSubscriptionKey extends HasServiceParams { val subscriptionKey = new ServiceParam[String]( this, "subscriptionKey", "the API key to use") + GlobalParams.registerParam(subscriptionKey, OpenAISubscriptionKey) + def getSubscriptionKey: String = getScalarParam(subscriptionKey) def setSubscriptionKey(v: String): this.type = { diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala index 58fe0ece79..e1e35fc020 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala @@ -103,6 +103,8 @@ trait HasOpenAIEmbeddingParams extends HasOpenAISharedParams with HasAPIVersion } } +case object OpenAITemperatureKey extends GlobalKey[Either[Double, String]] + trait HasOpenAITextParams extends HasOpenAISharedParams { val maxTokens: ServiceParam[Int] = new ServiceParam[Int]( this, "maxTokens", @@ -126,6 +128,8 @@ trait HasOpenAITextParams extends HasOpenAISharedParams { " We generally recommend using this or `top_p` but not both. Minimum of 0 and maximum of 2 allowed.", isRequired = false) + GlobalParams.registerParam(temperature, OpenAITemperatureKey) + def getTemperature: Double = getScalarParam(temperature) def setTemperature(v: Double): this.type = setScalarParam(temperature, v) diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala index 8b91625064..fe32df2267 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala @@ -4,9 +4,18 @@ package com.microsoft.azure.synapse.ml.services.openai import com.microsoft.azure.synapse.ml.param.GlobalParams +import com.microsoft.azure.synapse.ml.services.OpenAISubscriptionKey object OpenAIDefaults { def setDeploymentName(v: String): Unit = { GlobalParams.setGlobalParam(OpenAIDeploymentNameKey, Left(v)) } + + def setSubscriptionKey(v: String): Unit = { + GlobalParams.setGlobalParam(OpenAISubscriptionKey, Left(v)) + } + + def setTemperature(v: Double): Unit = { + GlobalParams.setGlobalParam(OpenAITemperatureKey, Left(v)) + } } diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaultsSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaultsSuite.scala index 0f156d02f2..139d586592 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaultsSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaultsSuite.scala @@ -11,12 +11,14 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey { import spark.implicits._ OpenAIDefaults.setDeploymentName(deploymentName) + OpenAIDefaults.setSubscriptionKey(openAIAPIKey) + OpenAIDefaults.setTemperature(0.05) + def promptCompletion: OpenAICompletion = new OpenAICompletion() .setCustomServiceName(openAIServiceName) .setMaxTokens(200) .setOutputCol("out") - .setSubscriptionKey(openAIAPIKey) .setPromptCol("prompt") lazy val promptDF: DataFrame = Seq( @@ -33,10 +35,8 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey { } lazy val prompt: OpenAIPrompt = new OpenAIPrompt() - .setSubscriptionKey(openAIAPIKey) .setCustomServiceName(openAIServiceName) .setOutputCol("outParsed") - .setTemperature(0) lazy val df: DataFrame = Seq( ("apple", "fruits"), @@ -56,4 +56,10 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey { assert(nonNullCount == 3) } + + test("OpenAIPrompt Check Params") { + assert(prompt.getDeploymentName == deploymentName) + assert(prompt.getSubscriptionKey == openAIAPIKey) + assert(prompt.getTemperature == 0.05) + } }