diff --git a/cognitive/src/main/python/synapse/ml/services/openai/OpenAIDefaults.py b/cognitive/src/main/python/synapse/ml/services/openai/OpenAIDefaults.py new file mode 100644 index 0000000000..7aad53f842 --- /dev/null +++ b/cognitive/src/main/python/synapse/ml/services/openai/OpenAIDefaults.py @@ -0,0 +1,60 @@ +# Copyright (C) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE in project root for information. + +import sys + +if sys.version >= "3": + basestring = str + +import pyspark +from pyspark import SparkContext + + +def getOption(opt): + if opt.isDefined(): + return opt.get() + else: + return None + + +class OpenAIDefaults: + def __init__(self): + self.defaults = ( + SparkContext.getOrCreate()._jvm.com.microsoft.azure.synapse.ml.services.openai.OpenAIDefaults + ) + + def set_deployment_name(self, name): + self.defaults.setDeploymentName(name) + + def get_deployment_name(self): + return getOption(self.defaults.getDeploymentName()) + + def reset_deployment_name(self): + self.defaults.resetDeploymentName() + + def set_subscription_key(self, key): + self.defaults.setSubscriptionKey(key) + + def get_subscription_key(self): + return getOption(self.defaults.getSubscriptionKey()) + + def reset_subscription_key(self): + self.defaults.resetSubscriptionKey() + + def set_temperature(self, temp): + self.defaults.setTemperature(float(temp)) + + def get_temperature(self): + return getOption(self.defaults.getTemperature()) + + def reset_temperature(self): + self.defaults.resetTemperature() + + def set_URL(self, URL): + self.defaults.setURL(URL) + + def get_URL(self): + return getOption(self.defaults.getURL()) + + def reset_URL(self): + self.defaults.resetURL() diff --git a/cognitive/src/main/python/synapse/ml/services/openai/__init__.py b/cognitive/src/main/python/synapse/ml/services/openai/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala index fe32df2267..f8405fbe1b 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala @@ -5,17 +5,61 @@ package com.microsoft.azure.synapse.ml.services.openai import com.microsoft.azure.synapse.ml.param.GlobalParams import com.microsoft.azure.synapse.ml.services.OpenAISubscriptionKey +import com.microsoft.azure.synapse.ml.io.http.URLKey object OpenAIDefaults { def setDeploymentName(v: String): Unit = { GlobalParams.setGlobalParam(OpenAIDeploymentNameKey, Left(v)) } + def getDeploymentName: Option[String] = { + extractLeft(GlobalParams.getGlobalParam(OpenAIDeploymentNameKey)) + } + + def resetDeploymentName(): Unit = { + GlobalParams.resetGlobalParam(OpenAIDeploymentNameKey) + } + def setSubscriptionKey(v: String): Unit = { GlobalParams.setGlobalParam(OpenAISubscriptionKey, Left(v)) } + def getSubscriptionKey: Option[String] = { + extractLeft(GlobalParams.getGlobalParam(OpenAISubscriptionKey)) + } + + def resetSubscriptionKey(): Unit = { + GlobalParams.resetGlobalParam(OpenAISubscriptionKey) + } + def setTemperature(v: Double): Unit = { GlobalParams.setGlobalParam(OpenAITemperatureKey, Left(v)) } + + def getTemperature: Option[Double] = { + extractLeft(GlobalParams.getGlobalParam(OpenAITemperatureKey)) + } + + def resetTemperature(): Unit = { + GlobalParams.resetGlobalParam(OpenAITemperatureKey) + } + + def setURL(v: String): Unit = { + GlobalParams.setGlobalParam(URLKey, v) + } + + def getURL: Option[String] = { + GlobalParams.getGlobalParam(URLKey) + } + + def resetURL(): Unit = { + GlobalParams.resetGlobalParam(URLKey) + } + + private def extractLeft[T](optEither: Option[Either[T, String]]): Option[T] = { + optEither match { + case Some(Left(v)) => Some(v) + case _ => None + } + } } diff --git a/cognitive/src/test/python/synapsemltest/services/openai/test_OpenAIDefaults.py b/cognitive/src/test/python/synapsemltest/services/openai/test_OpenAIDefaults.py new file mode 100644 index 0000000000..beb86c49bb --- /dev/null +++ b/cognitive/src/test/python/synapsemltest/services/openai/test_OpenAIDefaults.py @@ -0,0 +1,104 @@ +# Copyright (C) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE in project root for information. + +from synapse.ml.services.openai.OpenAIDefaults import OpenAIDefaults +from synapse.ml.services.openai.OpenAIPrompt import OpenAIPrompt +import unittest, os, json, subprocess +from pyspark.sql import SQLContext +from pyspark.sql.functions import col + + +from synapse.ml.core.init_spark import * + +spark = init_spark() +sc = SQLContext(spark.sparkContext) + + +class TestOpenAIDefaults(unittest.TestCase): + def test_setters_and_getters(self): + defaults = OpenAIDefaults() + + defaults.set_deployment_name("Bing Bong") + defaults.set_subscription_key("SubKey") + defaults.set_temperature(0.05) + defaults.set_URL("Test URL") + + self.assertEqual(defaults.get_deployment_name(), "Bing Bong") + self.assertEqual(defaults.get_subscription_key(), "SubKey") + self.assertEqual(defaults.get_temperature(), 0.05) + self.assertEqual(defaults.get_URL(), "Test URL") + + def test_resetters(self): + defaults = OpenAIDefaults() + + defaults.set_deployment_name("Bing Bong") + defaults.set_subscription_key("SubKey") + defaults.set_temperature(0.05) + defaults.set_URL("Test URL") + + self.assertEqual(defaults.get_deployment_name(), "Bing Bong") + self.assertEqual(defaults.get_subscription_key(), "SubKey") + self.assertEqual(defaults.get_temperature(), 0.05) + self.assertEqual(defaults.get_URL(), "Test URL") + + defaults.reset_deployment_name() + defaults.reset_subscription_key() + defaults.reset_temperature() + defaults.reset_URL() + + self.assertEqual(defaults.get_deployment_name(), None) + self.assertEqual(defaults.get_subscription_key(), None) + self.assertEqual(defaults.get_temperature(), None) + self.assertEqual(defaults.get_URL(), None) + + def test_two_defaults(self): + defaults = OpenAIDefaults() + + defaults.set_deployment_name("Bing Bong") + self.assertEqual(defaults.get_deployment_name(), "Bing Bong") + + defaults2 = OpenAIDefaults() + defaults.set_deployment_name("Bing Bong") + defaults2.set_deployment_name("Vamos") + self.assertEqual(defaults.get_deployment_name(), "Vamos") + + defaults2.set_deployment_name("Test 2") + defaults.set_deployment_name("Test 1") + self.assertEqual(defaults.get_deployment_name(), "Test 1") + + def test_prompt_w_defaults(self): + + secretJson = subprocess.check_output( + "az keyvault secret show --vault-name mmlspark-build-keys --name openai-api-key-2", + shell=True, + ) + openai_api_key = json.loads(secretJson)["value"] + + df = spark.createDataFrame( + [ + ("apple", "fruits"), + ("mercedes", "cars"), + ("cake", "dishes"), + ], + ["text", "category"], + ) + + defaults = OpenAIDefaults() + defaults.set_deployment_name("gpt-35-turbo-0125") + defaults.set_subscription_key(openai_api_key) + defaults.set_temperature(0.05) + defaults.set_URL("https://synapseml-openai-2.openai.azure.com/") + + prompt = OpenAIPrompt() + prompt = prompt.setOutputCol("outParsed") + prompt = prompt.setPromptTemplate( + "Complete this comma separated list of 5 {category}: {text}, " + ) + results = prompt.transform(df) + results.select("outParsed").show(truncate=False) + nonNullCount = results.filter(col("outParsed").isNotNull()).count() + assert nonNullCount == 3 + + +if __name__ == "__main__": + result = unittest.main() diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaultsSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaultsSuite.scala index 139d586592..487e0345bc 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaultsSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaultsSuite.scala @@ -10,13 +10,7 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey { import spark.implicits._ - OpenAIDefaults.setDeploymentName(deploymentName) - OpenAIDefaults.setSubscriptionKey(openAIAPIKey) - OpenAIDefaults.setTemperature(0.05) - - def promptCompletion: OpenAICompletion = new OpenAICompletion() - .setCustomServiceName(openAIServiceName) .setMaxTokens(200) .setOutputCol("out") .setPromptCol("prompt") @@ -28,6 +22,11 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey { ).toDF("prompt") test("Completion w Globals") { + OpenAIDefaults.setDeploymentName(deploymentName) + OpenAIDefaults.setSubscriptionKey(openAIAPIKey) + OpenAIDefaults.setTemperature(0.05) + OpenAIDefaults.setURL(s"https://$openAIServiceName.openai.azure.com/") + val fromRow = CompletionResponse.makeFromRowConverter promptCompletion.transform(promptDF).collect().foreach(r => fromRow(r.getAs[Row]("out")).choices.foreach(c => @@ -35,7 +34,6 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey { } lazy val prompt: OpenAIPrompt = new OpenAIPrompt() - .setCustomServiceName(openAIServiceName) .setOutputCol("outParsed") lazy val df: DataFrame = Seq( @@ -46,6 +44,11 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey { ).toDF("text", "category") test("OpenAIPrompt w Globals") { + OpenAIDefaults.setDeploymentName(deploymentName) + OpenAIDefaults.setSubscriptionKey(openAIAPIKey) + OpenAIDefaults.setTemperature(0.05) + OpenAIDefaults.setURL(s"https://$openAIServiceName.openai.azure.com/") + val nonNullCount = prompt .setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ") .setPostProcessing("csv") @@ -55,11 +58,38 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey { .count(r => Option(r.getSeq[String](0)).isDefined) assert(nonNullCount == 3) - } - test("OpenAIPrompt Check Params") { assert(prompt.getDeploymentName == deploymentName) assert(prompt.getSubscriptionKey == openAIAPIKey) assert(prompt.getTemperature == 0.05) } + + test("Test Getters") { + OpenAIDefaults.setDeploymentName(deploymentName) + OpenAIDefaults.setSubscriptionKey(openAIAPIKey) + OpenAIDefaults.setTemperature(0.05) + OpenAIDefaults.setURL(s"https://$openAIServiceName.openai.azure.com/") + + assert(OpenAIDefaults.getDeploymentName.contains(deploymentName)) + assert(OpenAIDefaults.getSubscriptionKey.contains(openAIAPIKey)) + assert(OpenAIDefaults.getTemperature.contains(0.05)) + assert(OpenAIDefaults.getURL.contains(s"https://$openAIServiceName.openai.azure.com/")) + } + + test("Test Resetters") { + OpenAIDefaults.setDeploymentName(deploymentName) + OpenAIDefaults.setSubscriptionKey(openAIAPIKey) + OpenAIDefaults.setTemperature(0.05) + OpenAIDefaults.setURL(s"https://$openAIServiceName.openai.azure.com/") + + OpenAIDefaults.resetDeploymentName() + OpenAIDefaults.resetSubscriptionKey() + OpenAIDefaults.resetTemperature() + OpenAIDefaults.resetURL() + + assert(OpenAIDefaults.getDeploymentName.isEmpty) + assert(OpenAIDefaults.getSubscriptionKey.isEmpty) + assert(OpenAIDefaults.getTemperature.isEmpty) + assert(OpenAIDefaults.getURL.isEmpty) + } } diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPTransformer.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPTransformer.scala index 8d942e34b1..43ea24d112 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPTransformer.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPTransformer.scala @@ -7,7 +7,7 @@ import com.microsoft.azure.synapse.ml.codegen.Wrappable import com.microsoft.azure.synapse.ml.core.contracts.{HasInputCol, HasOutputCol} import com.microsoft.azure.synapse.ml.io.http.HandlingUtils.HandlerFunc import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging} -import com.microsoft.azure.synapse.ml.param.UDFParam +import com.microsoft.azure.synapse.ml.param.{GlobalKey, GlobalParams, UDFParam} import org.apache.http.impl.client.CloseableHttpClient import org.apache.spark.injections.UDFUtils import org.apache.spark.ml.param._ @@ -76,10 +76,14 @@ trait ConcurrencyParams extends Wrappable { setDefault(concurrency -> 1, timeout -> 60.0) } +case object URLKey extends GlobalKey[String] + trait HasURL extends Params { val url: Param[String] = new Param[String](this, "url", "Url of the service") + GlobalParams.registerParam(url, URLKey) + /** @group getParam */ def getUrl: String = $(url) diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/param/GlobalParams.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/param/GlobalParams.scala index 98f7eb33e6..ac6f6a8bcb 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/param/GlobalParams.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/param/GlobalParams.scala @@ -18,10 +18,14 @@ object GlobalParams { GlobalParams(key) = value } - private def getGlobalParam[T](key: GlobalKey[T]): Option[T] = { + def getGlobalParam[T](key: GlobalKey[T]): Option[T] = { GlobalParams.get(key.asInstanceOf[GlobalKey[Any]]).map(_.asInstanceOf[T]) } + def resetGlobalParam[T](key: GlobalKey[T]): Unit = { + GlobalParams -= key + } + def getParam[T](p: Param[T]): Option[T] = { ParamToKeyMap.get(p).flatMap { key => key match {