Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
mhamilton723 committed Jul 19, 2024
1 parent ced6547 commit e8ee16c
Show file tree
Hide file tree
Showing 12 changed files with 27 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
spark = init_spark()
sc = SQLContext(spark.sparkContext)


class SimpleHTTPTransformerSmokeTest(unittest.TestCase):
def test_simple(self):
df = spark.createDataFrame([("foo",) for x in range(20)], ["data"]).withColumn(
Expand Down
28 changes: 16 additions & 12 deletions core/src/main/python/synapse/ml/core/init_spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,22 @@

from synapse.ml.core import __spark_package_version__


def init_spark():
    """Create (or return the already-running) local SparkSession for tests.

    Configures a ``local[*]`` session named "PysparkTests" that pulls the
    SynapseML JAR matching this package's ``__spark_package_version__`` plus
    the spark-avro package, and applies test-friendly settings (fewer shuffle
    partitions, cross joins enabled, longer executor heartbeat).

    Returns:
        pyspark.sql.SparkSession: the shared session (``getOrCreate`` reuses
        an existing one, so repeated calls are cheap and idempotent).
    """
    # Import locally so merely importing this module does not require pyspark.
    # NOTE: the previous revision also imported SQLContext here, but it was
    # never used inside this function, so it has been dropped.
    from pyspark.sql import SparkSession

    # Fix: the prior text contained two complete function bodies back to back
    # (a diff artifact), leaving the second `return` unreachable; only the
    # single, reformatted builder chain is kept.
    return (
        SparkSession.builder.master("local[*]")
        .appName("PysparkTests")
        .config(
            "spark.jars.packages",
            "com.microsoft.azure:synapseml_2.12:"
            + __spark_package_version__
            + ",org.apache.spark:spark-avro_2.12:3.4.1",
        )
        .config("spark.jars.repositories", "https://mmlspark.azureedge.net/maven")
        .config("spark.executor.heartbeatInterval", "60s")
        .config("spark.sql.shuffle.partitions", 10)
        .config("spark.sql.crossJoin.enabled", "true")
        .getOrCreate()
    )
1 change: 1 addition & 0 deletions core/src/test/python/synapsemltest/core/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
spark = init_spark()
sc = SQLContext(spark.sparkContext)


class SampleTransformer(SynapseMLLogger):
def __init__(self):
super().__init__(log_level=logging.DEBUG)
Expand Down
1 change: 1 addition & 0 deletions core/src/test/python/synapsemltest/core/test_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
spark = init_spark()
sc = SQLContext(spark.sparkContext)


class TemplateSpec(unittest.TestCase):
def create_sample_dataframe(self):
schema = t.StructType(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
spark = init_spark()
sc = SQLContext(spark.sparkContext)


class TestComplementAccessTransformer(unittest.TestCase):
def create_dataframe(self) -> DataFrame:
schema = t.StructType(
Expand Down
1 change: 1 addition & 0 deletions core/src/test/python/synapsemltest/cyber/explain_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
spark = init_spark()
sc = SQLContext(spark.sparkContext)


class ExplainTester:
def check_explain(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
spark = init_spark()
sc = SQLContext(spark.sparkContext)


class TestIndexers(unittest.TestCase):
def create_sample_dataframe(self):
schema = t.StructType(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
spark = init_spark()
sc = SQLContext(spark.sparkContext)


class TestScalers(unittest.TestCase):
def create_sample_dataframe(self):
schema = t.StructType(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
spark = init_spark()
sc = SQLContext(spark.sparkContext)


class TestDataFrameUtils(unittest.TestCase):
def create_sample_dataframe(self):
dataframe = sc.createDataFrame(
Expand Down
1 change: 1 addition & 0 deletions core/src/test/python/synapsemltest/nn/test_ball_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
spark = init_spark()
sc = SQLContext(spark.sparkContext)


class NNSpec(unittest.TestCase):
def test_bindings(self):
cbt = ConditionalBallTree([[1.0, 2.0], [2.0, 3.0]], [1, 2], ["foo", "bar"], 50)
Expand Down
1 change: 1 addition & 0 deletions vw/src/test/python/synapsemltest/vw/test_vw.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
spark = init_spark()
sc = SQLContext(spark.sparkContext)


class VowpalWabbitSpec(unittest.TestCase):
def get_data(self):
# create sample data
Expand Down
1 change: 1 addition & 0 deletions vw/src/test/python/synapsemltest/vw/test_vw_cb.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
spark = init_spark()
sc = SQLContext(spark.sparkContext)


def has_column(df, col):
try:
df[col]
Expand Down

0 comments on commit e8ee16c

Please sign in to comment.