Merge pull request #932 from Nitish1814/single-sparksession-911
added singleton SparkSessionProvider
sonalgoyal authored Nov 12, 2024
2 parents d292696 + e470624 commit 04f3b2a
Showing 31 changed files with 423 additions and 248 deletions.
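The core of this change is a new shared-session provider, zingg.spark.core.session.SparkSessionProvider, which the reworked TestSparkBase extension (further down in this diff) pulls from instead of each test bootstrapping its own session via ZinggSparkTester. The provider class itself is not visible in the rendered portion of the diff. The following is a minimal sketch of what such a singleton could look like, reconstructed only from the calls the diff makes against it (getInstance(), getSparkSession(), getJavaSparkContext(), getArgs(), getZinggSparkContext()); the session configuration, the Arguments construction, and the wiring of ZinggSparkContext are assumptions, not the PR's actual implementation.

package zingg.spark.core.session;

import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;

import zingg.common.client.Arguments;
import zingg.common.client.IArguments;
import zingg.spark.core.context.ZinggSparkContext;

public class SparkSessionProvider {

    // Single shared instance, created lazily on first access.
    private static SparkSessionProvider instance;

    private final SparkSession sparkSession;
    private final JavaSparkContext javaSparkContext;
    private final IArguments args;
    private final ZinggSparkContext zinggSparkContext;

    // Private constructor: the session is only ever built here, once.
    private SparkSessionProvider() {
        // Master and app name are placeholders for whatever configuration
        // the real provider applies.
        sparkSession = SparkSession.builder()
                .master("local[*]")
                .appName("ZinggJunit")
                .getOrCreate();
        javaSparkContext = new JavaSparkContext(sparkSession.sparkContext());
        // Assumption: a default Arguments instance; the real provider may
        // populate this differently.
        args = new Arguments();
        zinggSparkContext = new ZinggSparkContext();
        // The real provider presumably also wires the session into the
        // Zingg context; the exact call is not shown in this diff.
    }

    public static synchronized SparkSessionProvider getInstance() {
        if (instance == null) {
            instance = new SparkSessionProvider();
        }
        return instance;
    }

    public SparkSession getSparkSession() { return sparkSession; }

    public JavaSparkContext getJavaSparkContext() { return javaSparkContext; }

    public IArguments getArgs() { return args; }

    public ZinggSparkContext getZinggSparkContext() { return zinggSparkContext; }
}

A lazily created, synchronized getInstance() guarantees that every test class sees the same SparkSession, which is the point of the PR: one session for the whole test run instead of one per tester subclass.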
DataDocumenter.java
@@ -16,7 +16,7 @@

 public abstract class DataDocumenter<S,D,R,C,T> extends DocumenterBase<S,D,R,C,T> {
     protected static String name = "zingg.DataDocumenter";
-    protected static String TEMPLATE_TITLE = "Data Documentation";
+    public static String TEMPLATE_TITLE = "Data Documentation";
     private final String DATA_DOC_TEMPLATE = "dataDocTemplate.ftlh";

     public static final Log LOG = LogFactory.getLog(DataDocumenter.class);
@@ -61,7 +61,7 @@ protected void writeModelDocument(Map<String, Object> root) throws ZinggClientEx
         writeDocument(DATA_DOC_TEMPLATE, root, args.getZinggDataDocFile());
     }

-    protected Map<String, Object> populateTemplateData() {
+    public Map<String, Object> populateTemplateData() {
         Map<String, Object> root = new HashMap<String, Object>();
         root.put(TemplateFields.TITLE, TEMPLATE_TITLE);
         root.put(TemplateFields.MODEL_ID, args.getModelId());
@@ -84,6 +84,14 @@ protected List<String[]> getFieldDataList() {
         return list;
     }

+    public ZFrame<D,R,C> getData() {
+        return this.data;
+    }
+
+    public void setData(ZFrame<D, R, C> data) {
+        this.data = data;
+    }
+
     @Override
     public void execute() throws ZinggClientException {
         // TODO Auto-generated method stub
DocumenterBase.java
@@ -53,7 +53,7 @@ private Configuration createConfigurationObject() {
         return cfg;
     }

-    protected void writeDocument(String template, Map<String, Object> root, String fileName)
+    public void writeDocument(String template, Map<String, Object> root, String fileName)
             throws ZinggClientException {
         try {
             Configuration cfg = getTemplateConfig();
@@ -67,7 +67,7 @@ protected void writeDocument(String template, Map<String, Object> root, String f
         }
     }

-    protected void checkAndCreateDir(String dirName) {
+    public void checkAndCreateDir(String dirName) {
         File directory = new File(dirName);
         if (!directory.exists()) {
             directory.mkdirs();
ModelDocumenter.java
@@ -46,7 +46,7 @@ public void process() throws ZinggClientException {
         modelColDoc.process(markedRecords);
     }

-    protected void createModelDocument() throws ZinggClientException {
+    public void createModelDocument() throws ZinggClientException {
         try {
             LOG.info("Model document generation starts");

@@ -68,7 +68,7 @@ private void writeModelDocument(Map<String, Object> root) throws ZinggClientExce
         writeDocument(MODEL_TEMPLATE, root, args.getZinggModelDocFile());
     }

-    protected Map<String, Object> populateTemplateData() {
+    public Map<String, Object> populateTemplateData() {
         /* Create a data-model */
         Map<String, Object> root = new HashMap<String, Object>();
         root.put(TemplateFields.MODEL_ID, args.getModelId());
@@ -166,6 +166,14 @@ private void putSummaryCounts(Map<String, Object> root) {


     }
+
+    public void setMarkedRecords(ZFrame<D, R, C> markedRecords) {
+        this.markedRecords = markedRecords;
+    }
+
+    public void setUnmarkedRecords(ZFrame<D, R, C> unmarkedRecords) {
+        this.unmarkedRecords = unmarkedRecords;
+    }

     @Override
     public void execute() throws ZinggClientException {
TestDataType.java
@@ -1,4 +1,4 @@
-package zingg;
+package zingg.spark.core;


 import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -7,9 +7,7 @@
 import org.apache.spark.sql.types.DataTypes;
 import org.junit.jupiter.api.Test;

-import zingg.spark.core.executor.ZinggSparkTester;
-
-public class TestDataType extends ZinggSparkTester{
+public class TestDataType {

     @Test
     public void testDataType() {
TestDocumenter.java
@@ -1,18 +1,21 @@
-package zingg;
+package zingg.spark.core;
 import static org.junit.jupiter.api.Assertions.fail;

 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.junit.jupiter.api.BeforeEach;

-import zingg.common.client.Arguments;
-import zingg.spark.core.executor.ZinggSparkTester;
+import zingg.common.client.ArgumentsUtil;
+import zingg.common.client.IArguments;

-public class TestDocumenter extends ZinggSparkTester{
-
+public class TestDocumenter {

     public static final Log LOG = LogFactory.getLog(TestDocumenter.class);
     @BeforeEach
     public void setUp(){

         try {
-            args = argsUtil.createArgumentsFromJSON(getClass().getResource("/testDocumenter/config.json").getFile());
+            ArgumentsUtil argsUtil = new ArgumentsUtil();
+            IArguments args = argsUtil.createArgumentsFromJSON(getClass().getResource("/testDocumenter/config.json").getFile());
             //fail("Exception was expected for missing config file");
         } catch (Throwable e) {
             e.printStackTrace();
TestFebrlDataset.java
@@ -1,4 +1,4 @@
-package zingg;
+package zingg.spark.core;

 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -8,36 +8,46 @@
 import org.apache.commons.logging.LogFactory;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;

-import zingg.common.client.Arguments;
-import zingg.common.client.ClientOptions;
+import org.junit.jupiter.api.extension.ExtendWith;
+import zingg.common.client.ArgumentsUtil;
+import zingg.common.client.IArguments;
 import zingg.common.client.ZinggClientException;
 import zingg.common.client.pipe.FilePipe;
 import zingg.common.client.pipe.Pipe;
 import zingg.common.client.util.ColName;
 import zingg.common.core.executor.TrainMatcher;
 import zingg.spark.client.pipe.SparkPipe;
 import zingg.spark.core.executor.SparkTrainMatcher;
-import zingg.spark.core.executor.ZinggSparkTester;
 /**end to end integration test*/
-public class TestFebrlDataset extends ZinggSparkTester{

+@ExtendWith(TestSparkBase.class)
+public class TestFebrlDataset {
     public static final Log LOG = LogFactory.getLog(TestFebrlDataset.class);

+    private final SparkSession sparkSession;
+
+    public TestFebrlDataset(SparkSession sparkSession) {
+        this.sparkSession = sparkSession;
+    }
+
     SparkPipe outputPipe;

+    ArgumentsUtil argsUtil = new ArgumentsUtil();
+    IArguments args;
+
     @BeforeEach
     public void setUp() throws Exception, ZinggClientException{
-        String configFilePath = getClass().getResource("../testFebrl/config.json").getFile();
+        String configFilePath = getClass().getResource("/testFebrl/config.json").getFile();
         System.out.println("configFilePath "+configFilePath);
         args = argsUtil.createArgumentsFromJSON(configFilePath);
-        String modelPath = getClass().getResource("../testFebrl/models").getPath();
+        String modelPath = getClass().getResource("/testFebrl/models").getPath();
         System.out.println("modelPath "+modelPath);
         args.setZinggDir(modelPath);
         Pipe dataPipe = args.getData()[0];
-        String csvPath = getClass().getResource("../testFebrl/test.csv").getPath();
+        String csvPath = getClass().getResource("/testFebrl/test.csv").getPath();
         System.out.println("csvPath "+csvPath);
         dataPipe.setProp(FilePipe.LOCATION, csvPath);
        args.setData(new Pipe[]{dataPipe});
@@ -51,7 +61,7 @@ public void setUp() throws Exception, ZinggClientException{
     public void testModelAccuracy(){
         TrainMatcher tm = new SparkTrainMatcher();
         try {
-            tm.init(args,spark,null);
+            tm.init(args,sparkSession,null);
             // tm.setSpark(spark);
             // tm.setCtx(ctx);
             tm.setArgs(args);
TestImageType.java
@@ -1,4 +1,4 @@
-package zingg;
+package zingg.spark.core;

 import static org.apache.spark.sql.functions.callUDF;
 import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -10,21 +10,28 @@
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.types.DataType;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.Metadata;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 import org.junit.jupiter.api.Test;

+import org.junit.jupiter.api.extension.ExtendWith;
 import zingg.common.core.similarity.function.ArrayDoubleSimilarityFunction;
-import zingg.spark.core.executor.ZinggSparkTester;
 import zingg.spark.core.util.SparkFnRegistrar;

-public class TestImageType extends ZinggSparkTester{
+@ExtendWith(TestSparkBase.class)
+public class TestImageType {


     private static final double SMALL_DELTA = 0.0000000001;
+    private final SparkSession sparkSession;
+
+    public TestImageType(SparkSession sparkSession) {
+        this.sparkSession = sparkSession;
+    }


     @Test
@@ -91,7 +98,7 @@ public void testUDFArray() {
         df.printSchema();
         // register ArrayDoubleSimilarityFunction as a UDF
         TestUDFDoubleArr testUDFDoubleArr = new TestUDFDoubleArr();
-        SparkFnRegistrar.registerUDF2(spark, "testUDFDoubleArr", testUDFDoubleArr, DataTypes.DoubleType);
+        SparkFnRegistrar.registerUDF2(sparkSession, "testUDFDoubleArr", testUDFDoubleArr, DataTypes.DoubleType);
         // call the UDF from select clause of DF
         df = df.withColumn("cosine",
                 callUDF("testUDFDoubleArr", df.col("image_embedding"), df.col("image_embedding")));
@@ -117,7 +124,7 @@ public void testUDFList() {

         // register ArrayDoubleSimilarityFunction as a UDF
         TestUDFDoubleList testUDFDoubleList = new TestUDFDoubleList();
-        SparkFnRegistrar.registerUDF2(spark, "testUDFDoubleList", testUDFDoubleList, DataTypes.DoubleType);
+        SparkFnRegistrar.registerUDF2(sparkSession, "testUDFDoubleList", testUDFDoubleList, DataTypes.DoubleType);

         // call the UDF from select clause of DF
         df = df.withColumn("cosine", callUDF("testUDFDoubleList",df.col("image_embedding"),df.col("image_embedding")));
@@ -143,7 +150,7 @@ public void testUDFSeq() {

         // register ArrayDoubleSimilarityFunction as a UDF
         TestUDFDoubleSeq testUDFDoubleSeq = new TestUDFDoubleSeq();
-        SparkFnRegistrar.registerUDF2(spark, "testUDFDoubleSeq", testUDFDoubleSeq, DataTypes.DoubleType);
+        SparkFnRegistrar.registerUDF2(sparkSession, "testUDFDoubleSeq", testUDFDoubleSeq, DataTypes.DoubleType);

         // call the UDF from select clause of DF
         df = df.withColumn("cosine", callUDF("testUDFDoubleSeq",df.col("image_embedding"),df.col("image_embedding")));
@@ -169,7 +176,7 @@ public void testUDFWrappedArr() {

         // register ArrayDoubleSimilarityFunction as a UDF
         TestUDFDoubleWrappedArr testUDFDoubleWrappedArr = new TestUDFDoubleWrappedArr();
-        SparkFnRegistrar.registerUDF2(spark, "testUDFDoubleWrappedArr", testUDFDoubleWrappedArr, DataTypes.DoubleType);
+        SparkFnRegistrar.registerUDF2(sparkSession, "testUDFDoubleWrappedArr", testUDFDoubleWrappedArr, DataTypes.DoubleType);

         // call the UDF from select clause of DF
         df = df.withColumn("cosine", callUDF("testUDFDoubleWrappedArr",df.col("image_embedding"),df.col("image_embedding")));
@@ -198,7 +205,7 @@ public void testUDFObj() {

         // register ArrayDoubleSimilarityFunction as a UDF
         TestUDFDoubleObj testUDFDoubleObj = new TestUDFDoubleObj();
-        SparkFnRegistrar.registerUDF2(spark, "testUDFDoubleObj", testUDFDoubleObj, DataTypes.DoubleType);
+        SparkFnRegistrar.registerUDF2(sparkSession, "testUDFDoubleObj", testUDFDoubleObj, DataTypes.DoubleType);

         // call the UDF from select clause of DF
         df = df.withColumn("cosine", callUDF("testUDFDoubleObj",df.col("image_embedding"),df.col("image_embedding")));
@@ -227,7 +234,7 @@ protected Dataset<Row> createSampleDataset() {
                 });


-        Dataset<Row> sample = spark.createDataFrame(Arrays.asList(
+        Dataset<Row> sample = sparkSession.createDataFrame(Arrays.asList(
                 RowFactory.create("07317257", "erjc", "henson", "hendersonville", "2873g",new Double[]{0.1123,10.456,110.789}),
                 RowFactory.create("03102490", "jhon", "kozak", "henders0nville", "28792",new Double[]{0.2123,20.456,220.789}),
                 RowFactory.create("02890805", "david", "pisczek", "durham", "27717",new Double[]{0.3123,30.456,330.789}),
TestSparkBase.java
@@ -1,19 +1,23 @@
-package zingg;
+package zingg.spark.core;

+import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.sql.SparkSession;
 import org.junit.jupiter.api.extension.AfterAllCallback;
 import org.junit.jupiter.api.extension.BeforeAllCallback;
 import org.junit.jupiter.api.extension.ExtensionContext;
 import org.junit.jupiter.api.extension.ParameterContext;
 import org.junit.jupiter.api.extension.ParameterResolutionException;
 import org.junit.jupiter.api.extension.ParameterResolver;
-import zingg.spark.core.executor.ZinggSparkTester;
+import zingg.common.client.IArguments;
+import zingg.spark.core.session.SparkSessionProvider;
+import zingg.spark.core.context.ZinggSparkContext;

-public class TestSparkBase extends ZinggSparkTester implements BeforeAllCallback, AfterAllCallback, ParameterResolver{
-
-    public SparkSession sparkSession;
-
-    static boolean isSetUp;
+public class TestSparkBase implements BeforeAllCallback, AfterAllCallback, ParameterResolver {
+    public static IArguments args;
+    public static JavaSparkContext ctx;
+    public static SparkSession spark;
+    public static ZinggSparkContext zsCTX;
+    static boolean isSetUp = false;

     @Override
     public boolean supportsParameter(ParameterContext parameterContext, ExtensionContext extensionContext)
@@ -25,7 +29,7 @@ public boolean supportsParameter(ParameterContext parameterContext, ExtensionCon
     @Override
     public Object resolveParameter(ParameterContext parameterContext, ExtensionContext extensionContext)
             throws ParameterResolutionException {
-        return sparkSession;
+        return spark;
     }

     @Override
@@ -35,9 +39,12 @@ public void afterAll(ExtensionContext context) {

     @Override
     public void beforeAll(ExtensionContext context) {
-        if (!isSetUp || sparkSession == null) {
-            super.setup();
-            sparkSession = ZinggSparkTester.spark;
+        if (!isSetUp || spark == null) {
+            SparkSessionProvider sparkSessionProvider = SparkSessionProvider.getInstance();
+            spark = sparkSessionProvider.getSparkSession();
+            ctx = sparkSessionProvider.getJavaSparkContext();
+            args = sparkSessionProvider.getArgs();
+            zsCTX = sparkSessionProvider.getZinggSparkContext();
        }
        isSetUp = true;
    }
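For reference, this is the pattern by which a test now consumes the shared session, as TestFebrlDataset and TestImageType above illustrate: TestSparkBase is registered via @ExtendWith, JUnit 5 calls its supportsParameter()/resolveParameter() hooks, and the shared SparkSession is injected into the test's constructor. The class below is a hypothetical minimal consumer written to show the mechanics; it is not a file in this PR.

package zingg.spark.core;

import static org.junit.jupiter.api.Assertions.assertNotNull;

import org.apache.spark.sql.SparkSession;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;

// Registering the extension makes its beforeAll() run once and its
// ParameterResolver hooks available for constructor injection.
@ExtendWith(TestSparkBase.class)
public class TestSharedSession {

    private final SparkSession sparkSession;

    // TestSparkBase.resolveParameter() supplies the provider-backed session here.
    public TestSharedSession(SparkSession sparkSession) {
        this.sparkSession = sparkSession;
    }

    @Test
    public void testSessionIsAvailable() {
        // Every test class built this way sees the same singleton-backed session.
        assertNotNull(sparkSession);
    }
}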
TestUDFDoubleArr.java
@@ -1,4 +1,4 @@
-package zingg;
+package zingg.spark.core;

 import org.apache.spark.sql.api.java.UDF2;

TestUDFDoubleList.java
@@ -1,4 +1,4 @@
-package zingg;
+package zingg.spark.core;

 import java.util.List;

TestUDFDoubleObj.java
@@ -1,4 +1,4 @@
-package zingg;
+package zingg.spark.core;

 import org.apache.spark.sql.api.java.UDF2;

TestUDFDoubleSeq.java
@@ -1,4 +1,4 @@
-package zingg;
+package zingg.spark.core;

 import org.apache.spark.sql.api.java.UDF2;

TestUDFDoubleWrappedArr.java
@@ -1,4 +1,4 @@
-package zingg;
+package zingg.spark.core;

 import org.apache.spark.sql.api.java.UDF2;

TestSparkBlock.java
@@ -1,12 +1,13 @@
-package zingg.common.core.block;
+package zingg.spark.core.block;

 import org.apache.spark.sql.Column;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.types.DataType;
 import org.junit.jupiter.api.extension.ExtendWith;
-import zingg.TestSparkBase;
+import zingg.common.core.block.TestBlockBase;
+import zingg.spark.core.TestSparkBase;
 import zingg.common.client.ZinggClientException;
 import zingg.common.client.util.IWithSession;
 import zingg.common.client.util.WithSession;
(The remaining files of the 31 changed did not render on this page.)
