Skip to content

Commit

Permalink
Adding tests and fixing style
Browse files Browse the repository at this point in the history
  • Loading branch information
sss04 committed Dec 17, 2024
1 parent cf93266 commit 744f83f
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,19 @@
import pyspark
from pyspark import SparkContext


def getOption(opt):
if opt.isDefined():
return opt.get()
else:
return None


class OpenAIDefaults:
def __init__(self):
self.defaults = SparkContext.getOrCreate()._jvm.com.microsoft.azure.synapse.ml.services.openai.OpenAIDefaults
self.defaults = (
SparkContext.getOrCreate()._jvm.com.microsoft.azure.synapse.ml.services.openai.OpenAIDefaults
)

def set_deployment_name(self, name):
self.defaults.setDeploymentName(name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,20 @@
spark = init_spark()
sc = SQLContext(spark.sparkContext)


class TestOpenAIDefaults(unittest.TestCase):
def test_OpenAIDefaults(self):
def test_setters_and_getters(self):
defaults = OpenAIDefaults()

defaults.set_deployment_name("Bing Bong")
defaults.set_subscription_key("SubKey")
defaults.set_temperature(0.05)

self.assertEqual(defaults.get_deployment_name(), "Bing Bong")
self.assertEqual(defaults.get_subscription_key(), "SubKey")
self.assertEqual(defaults.get_temperature(), 0.05)

def test_resetters(self):
defaults = OpenAIDefaults()

defaults.set_deployment_name("Bing Bong")
Expand All @@ -23,5 +35,29 @@ def test_OpenAIDefaults(self):
self.assertEqual(defaults.get_subscription_key(), "SubKey")
self.assertEqual(defaults.get_temperature(), 0.05)

defaults.reset_deployment_name()
defaults.reset_subscription_key()
defaults.reset_temperature()

self.assertEqual(defaults.get_deployment_name(), None)
self.assertEqual(defaults.get_subscription_key(), None)
self.assertEqual(defaults.get_temperature(), None)

def test_two_defaults(self):
defaults = OpenAIDefaults()

defaults.set_deployment_name("Bing Bong")
self.assertEqual(defaults.get_deployment_name(), "Bing Bong")

defaults2 = OpenAIDefaults()
defaults.set_deployment_name("Bing Bong")
defaults2.set_deployment_name("Vamos")
self.assertEqual(defaults.get_deployment_name(), "Vamos")

defaults2.set_deployment_name("Test 2")
defaults.set_deployment_name("Test 1")
self.assertEqual(defaults.get_deployment_name(), "Test 1")


if __name__ == "__main__":
result = unittest.main()

0 comments on commit 744f83f

Please sign in to comment.