Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DRAFT] [DO NOT MERGE] Shyamsai/revert1 #2330

Closed
wants to merge 9 commits into from
Prev Previous commit
Next Next commit
Add python tests
sss04 committed Dec 20, 2024

Verified

This commit was signed with the committer’s verified signature.
Julusian Julian Waller
commit c239e9274fbd020f1aeb0aaeac3a1c278c3adbd1
Original file line number Diff line number Diff line change
@@ -42,7 +42,7 @@ def reset_subscription_key(self):
self.defaults.resetSubscriptionKey()

def set_temperature(self, temp):
self.defaults.setTemperature(temp)
self.defaults.setTemperature(float(temp))

def get_temperature(self):
return getOption(self.defaults.getTemperature())
Original file line number Diff line number Diff line change
@@ -3,8 +3,10 @@

from synapse.ml.services.openai.OpenAIDefaults import OpenAIDefaults
from synapse.ml.services.openai.OpenAIPrompt import OpenAIPrompt
import unittest
import unittest,os, json, subprocess
from pyspark.sql import SQLContext
from pyspark.sql.functions import col


from synapse.ml.core.init_spark import *

@@ -58,6 +60,34 @@ def test_two_defaults(self):
defaults.set_deployment_name("Test 1")
self.assertEqual(defaults.get_deployment_name(), "Test 1")

def test_prompt_w_defaults(self):

secretJson = subprocess.check_output(
"az keyvault secret show --vault-name mmlspark-build-keys --name openai-api-key-2",
shell=True,
)
openai_api_key = json.loads(secretJson)["value"]

df = spark.createDataFrame([
("apple", "fruits"),
("mercedes", "cars"),
("cake", "dishes"),
], ["text", "category"])

defaults = OpenAIDefaults()
defaults.set_deployment_name("gpt-35-turbo-0125")
defaults.set_subscription_key(openai_api_key)
defaults.set_temperature(0.05)

prompt = OpenAIPrompt()
prompt = prompt.setOutputCol("outParsed")
prompt = prompt.setCustomServiceName("synapseml-openai-2")
prompt = prompt.setPromptTemplate("Complete this comma separated list of 5 {category}: {text}, ")
results = prompt.transform(df)
results.select("outParsed").show(truncate = False)
nonNullCount = results.filter(col("outParsed").isNotNull()).count()
assert (nonNullCount == 3)


if __name__ == "__main__":
result = unittest.main()