diff --git a/common/client/src/main/java/zingg/common/client/Client.java b/common/client/src/main/java/zingg/common/client/Client.java index 62e61cbb6..8167c81a4 100644 --- a/common/client/src/main/java/zingg/common/client/Client.java +++ b/common/client/src/main/java/zingg/common/client/Client.java @@ -264,7 +264,7 @@ else if (options.get(ClientOptions.CONF).value.endsWith("env")) { public void init() throws ZinggClientException { zingg.setClientOptions(getOptions()); - zingg.init(getArguments(), getSession()); + zingg.init(getArguments(), getSession(),getOptions()); if (session != null) zingg.setSession(session); initializeListeners(); EventsListener.getInstance().fireEvent(new ZinggStartEvent()); diff --git a/common/client/src/main/java/zingg/common/client/IZingg.java b/common/client/src/main/java/zingg/common/client/IZingg.java index 61bd8133e..265b8c6b6 100644 --- a/common/client/src/main/java/zingg/common/client/IZingg.java +++ b/common/client/src/main/java/zingg/common/client/IZingg.java @@ -2,7 +2,7 @@ public interface IZingg { - public void init(IArguments args, S session) + public void init(IArguments args, S session, ClientOptions options) throws ZinggClientException; public void execute() throws ZinggClientException; @@ -31,7 +31,6 @@ public void init(IArguments args, S session) //public void setSession(S session); // method name will have to be changed in Client too - public void setClientOptions(ClientOptions clientOptions); public ClientOptions getClientOptions(); diff --git a/common/core/src/main/java/zingg/common/core/executor/FindAndLabeller.java b/common/core/src/main/java/zingg/common/core/executor/FindAndLabeller.java index b8eb3eff0..31c2031dc 100644 --- a/common/core/src/main/java/zingg/common/core/executor/FindAndLabeller.java +++ b/common/core/src/main/java/zingg/common/core/executor/FindAndLabeller.java @@ -3,6 +3,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import zingg.common.client.ClientOptions; import zingg.common.client.IArguments; import zingg.common.client.ZinggClientException; import zingg.common.client.options.ZinggOptions; @@ -20,10 +21,10 @@ public FindAndLabeller() { } @Override - public void init(IArguments args, S s) throws ZinggClientException { - finder.init(args,s); - labeller.init(args,s); - super.init(args,s); + public void init(IArguments args, S s, ClientOptions options) throws ZinggClientException { + finder.init(args,s,options); + labeller.init(args,s,options); + super.init(args,s,options); } @Override diff --git a/common/core/src/main/java/zingg/common/core/executor/TrainMatcher.java b/common/core/src/main/java/zingg/common/core/executor/TrainMatcher.java index b4fdfc97e..d810501e4 100644 --- a/common/core/src/main/java/zingg/common/core/executor/TrainMatcher.java +++ b/common/core/src/main/java/zingg/common/core/executor/TrainMatcher.java @@ -2,7 +2,9 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.http.impl.execchain.ClientExecChain; +import zingg.common.client.ClientOptions; import zingg.common.client.IArguments; import zingg.common.client.ZinggClientException; import zingg.common.client.options.ZinggOptions; @@ -21,11 +23,11 @@ public TrainMatcher() { } @Override - public void init(IArguments args, S s) + public void init(IArguments args, S s, ClientOptions options) throws ZinggClientException { - trainer.init(args,s); - matcher.init(args,s); - super.init(args,s); + trainer.init(args,s,options); + matcher.init(args,s,options); + super.init(args,s,options); } @Override diff --git a/common/core/src/main/java/zingg/common/core/executor/ZinggBase.java b/common/core/src/main/java/zingg/common/core/executor/ZinggBase.java index 1f1fd6cc0..8c622ca24 100644 --- a/common/core/src/main/java/zingg/common/core/executor/ZinggBase.java +++ b/common/core/src/main/java/zingg/common/core/executor/ZinggBase.java @@ -63,10 +63,11 @@ public ZinggBase() { @Override - public void init(IArguments args, S session) + public void init(IArguments args, S session, ClientOptions options) throws ZinggClientException { startTime = System.currentTimeMillis(); this.args = args; + this.clientOptions = options; } diff --git a/common/core/src/test/java/zingg/common/core/executor/TestExecutorsGeneric.java b/common/core/src/test/java/zingg/common/core/executor/TestExecutorsGeneric.java index 6de3c9813..a7c85f116 100644 --- a/common/core/src/test/java/zingg/common/core/executor/TestExecutorsGeneric.java +++ b/common/core/src/test/java/zingg/common/core/executor/TestExecutorsGeneric.java @@ -8,6 +8,7 @@ import org.junit.jupiter.api.Test; import zingg.common.client.ArgumentsUtil; +import zingg.common.client.ClientOptions; import zingg.common.client.IArguments; import zingg.common.client.ZinggClientException; @@ -16,9 +17,8 @@ public abstract class TestExecutorsGeneric { public static final Log LOG = LogFactory.getLog(TestExecutorsGeneric.class); protected IArguments args; - - protected S session; + protected ClientOptions options; public TestExecutorsGeneric() { @@ -50,33 +50,33 @@ public void testExecutors() throws ZinggClientException { List> executorTesterList = new ArrayList>(); TrainingDataFinder trainingDataFinder = getTrainingDataFinder(); - trainingDataFinder.init(args,session); + trainingDataFinder.init(args,session,options); TrainingDataFinderTester tdft = new TrainingDataFinderTester(trainingDataFinder); executorTesterList.add(tdft); Labeller labeller = getLabeller(); - labeller.init(args,session); + labeller.init(args,session,options); LabellerTester lt = new LabellerTester(labeller); executorTesterList.add(lt); // training and labelling needed twice to get sufficient data TrainingDataFinder trainingDataFinder2 = getTrainingDataFinder(); - trainingDataFinder2.init(args,session); + trainingDataFinder2.init(args,session,options); TrainingDataFinderTester tdft2 = new TrainingDataFinderTester(trainingDataFinder2); executorTesterList.add(tdft2); Labeller labeller2 = getLabeller(); - labeller2.init(args,session); + labeller2.init(args,session,options); LabellerTester lt2 = new LabellerTester(labeller2); executorTesterList.add(lt2); Trainer trainer = getTrainer(); - trainer.init(args,session); + trainer.init(args,session,options); TrainerTester tt = getTrainerTester(trainer); executorTesterList.add(tt); Matcher matcher = getMatcher(); - matcher.init(args,session); + matcher.init(args,session,options); MatcherTester mt = new MatcherTester(matcher); executorTesterList.add(mt); diff --git a/spark/core/src/main/java/zingg/spark/core/executor/SparkDocumenter.java b/spark/core/src/main/java/zingg/spark/core/executor/SparkDocumenter.java index 98e452c90..0f8652cf7 100644 --- a/spark/core/src/main/java/zingg/spark/core/executor/SparkDocumenter.java +++ b/spark/core/src/main/java/zingg/spark/core/executor/SparkDocumenter.java @@ -7,6 +7,7 @@ import org.apache.spark.sql.Row; import org.apache.spark.sql.types.DataType; +import zingg.common.client.ClientOptions; import zingg.common.client.IArguments; import zingg.common.client.ZinggClientException; import zingg.common.client.options.ZinggOptions; @@ -36,8 +37,8 @@ public SparkDocumenter(ZinggSparkContext sparkContext) { } @Override - public void init(IArguments args, SparkSession s) throws ZinggClientException { - super.init(args,s); + public void init(IArguments args, SparkSession s, ClientOptions options) throws ZinggClientException { + super.init(args,s,options); getContext().init(s); } diff --git a/spark/core/src/main/java/zingg/spark/core/executor/SparkFindAndLabeller.java b/spark/core/src/main/java/zingg/spark/core/executor/SparkFindAndLabeller.java index 0c0aeb550..462f95e7e 100644 --- a/spark/core/src/main/java/zingg/spark/core/executor/SparkFindAndLabeller.java +++ b/spark/core/src/main/java/zingg/spark/core/executor/SparkFindAndLabeller.java @@ -9,6 +9,7 @@ import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.SparkSession; +import zingg.common.client.ClientOptions; import zingg.common.client.IArguments; import zingg.common.client.ZinggClientException; import zingg.common.client.options.ZinggOptions; @@ -35,8 +36,8 @@ public SparkFindAndLabeller(ZinggSparkContext sparkContext) { } @Override - public void init(IArguments args, SparkSession s) throws ZinggClientException { - super.init(args,s); + public void init(IArguments args, SparkSession s, ClientOptions options) throws ZinggClientException { + super.init(args,s,options); getContext().init(s); } diff --git a/spark/core/src/main/java/zingg/spark/core/executor/SparkLabelUpdater.java b/spark/core/src/main/java/zingg/spark/core/executor/SparkLabelUpdater.java index 33dcbd706..d85e4b3c4 100644 --- a/spark/core/src/main/java/zingg/spark/core/executor/SparkLabelUpdater.java +++ b/spark/core/src/main/java/zingg/spark/core/executor/SparkLabelUpdater.java @@ -8,6 +8,7 @@ import org.apache.spark.sql.SaveMode; import org.apache.spark.sql.types.DataType; +import zingg.common.client.ClientOptions; import zingg.common.client.IArguments; import zingg.common.client.ZinggClientException; import zingg.common.client.options.ZinggOptions; @@ -39,8 +40,8 @@ public SparkLabelUpdater(ZinggSparkContext sparkContext) { } @Override - public void init(IArguments args, SparkSession s) throws ZinggClientException { - super.init(args,s); + public void init(IArguments args, SparkSession s, ClientOptions options) throws ZinggClientException { + super.init(args,s,options); getContext().init(s); } diff --git a/spark/core/src/main/java/zingg/spark/core/executor/SparkLabeller.java b/spark/core/src/main/java/zingg/spark/core/executor/SparkLabeller.java index e8aa8f6ec..7b3365162 100644 --- a/spark/core/src/main/java/zingg/spark/core/executor/SparkLabeller.java +++ b/spark/core/src/main/java/zingg/spark/core/executor/SparkLabeller.java @@ -8,6 +8,7 @@ import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.SparkSession; +import zingg.common.client.ClientOptions; import zingg.common.client.IArguments; import zingg.common.client.ZinggClientException; import zingg.common.client.options.ZinggOptions; @@ -37,8 +38,8 @@ public SparkLabeller(ZinggSparkContext sparkContext) { } @Override - public void init(IArguments args, SparkSession s) throws ZinggClientException { - super.init(args,s); + public void init(IArguments args, SparkSession s, ClientOptions options) throws ZinggClientException { + super.init(args,s,options); getContext().init(s); } diff --git a/spark/core/src/main/java/zingg/spark/core/executor/SparkLinker.java b/spark/core/src/main/java/zingg/spark/core/executor/SparkLinker.java index 85f442314..200684913 100644 --- a/spark/core/src/main/java/zingg/spark/core/executor/SparkLinker.java +++ b/spark/core/src/main/java/zingg/spark/core/executor/SparkLinker.java @@ -8,6 +8,7 @@ import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.DataType; +import zingg.common.client.ClientOptions; import zingg.common.client.IArguments; import zingg.common.client.ZinggClientException; import zingg.common.client.options.ZinggOptions; @@ -34,8 +35,8 @@ public SparkLinker(ZinggSparkContext sparkContext) { } @Override - public void init(IArguments args, SparkSession s) throws ZinggClientException { - super.init(args,s); + public void init(IArguments args, SparkSession s, ClientOptions options) throws ZinggClientException { + super.init(args,s,options); getContext().init(s); } diff --git a/spark/core/src/main/java/zingg/spark/core/executor/SparkMatcher.java b/spark/core/src/main/java/zingg/spark/core/executor/SparkMatcher.java index 6cb0bc1cd..dd6085662 100644 --- a/spark/core/src/main/java/zingg/spark/core/executor/SparkMatcher.java +++ b/spark/core/src/main/java/zingg/spark/core/executor/SparkMatcher.java @@ -8,6 +8,7 @@ import org.apache.spark.sql.Row; import org.apache.spark.sql.types.DataType; +import zingg.common.client.ClientOptions; import zingg.common.client.IArguments; import zingg.common.client.ZinggClientException; import zingg.common.client.options.ZinggOptions; @@ -40,8 +41,8 @@ public SparkMatcher(ZinggSparkContext sparkContext) { } @Override - public void init(IArguments args, SparkSession s) throws ZinggClientException { - super.init(args,s); + public void init(IArguments args, SparkSession s, ClientOptions options) throws ZinggClientException { + super.init(args,s,options); getContext().init(s); } diff --git a/spark/core/src/main/java/zingg/spark/core/executor/SparkPeekModel.java b/spark/core/src/main/java/zingg/spark/core/executor/SparkPeekModel.java index 115390b85..2667f052d 100644 --- a/spark/core/src/main/java/zingg/spark/core/executor/SparkPeekModel.java +++ b/spark/core/src/main/java/zingg/spark/core/executor/SparkPeekModel.java @@ -34,9 +34,9 @@ public SparkPeekModel() { } @Override - public void init(IArguments args, SparkSession s) + public void init(IArguments args, SparkSession s, ClientOptions options) throws ZinggClientException { - super.init(args,s); + super.init(args,s,options); getContext().setUtils(); //we wil not init here as we wnt py to drive //the spark session etc diff --git a/spark/core/src/main/java/zingg/spark/core/executor/SparkRecommender.java b/spark/core/src/main/java/zingg/spark/core/executor/SparkRecommender.java index cf608a6e9..39038a14f 100644 --- a/spark/core/src/main/java/zingg/spark/core/executor/SparkRecommender.java +++ b/spark/core/src/main/java/zingg/spark/core/executor/SparkRecommender.java @@ -7,6 +7,7 @@ import org.apache.spark.sql.Row; import org.apache.spark.sql.types.DataType; +import zingg.common.client.ClientOptions; import zingg.common.client.IArguments; import zingg.common.client.ZinggClientException; import zingg.common.client.options.ZinggOptions; @@ -38,8 +39,8 @@ public SparkRecommender(ZinggSparkContext sparkContext) { } @Override - public void init(IArguments args, SparkSession s) throws ZinggClientException { - super.init(args,s); + public void init(IArguments args, SparkSession s, ClientOptions options) throws ZinggClientException { + super.init(args,s,options); getContext().init(s); } diff --git a/spark/core/src/main/java/zingg/spark/core/executor/SparkTrainMatcher.java b/spark/core/src/main/java/zingg/spark/core/executor/SparkTrainMatcher.java index 699af83bf..77ba7a789 100644 --- a/spark/core/src/main/java/zingg/spark/core/executor/SparkTrainMatcher.java +++ b/spark/core/src/main/java/zingg/spark/core/executor/SparkTrainMatcher.java @@ -7,6 +7,7 @@ import org.apache.spark.sql.Row; import org.apache.spark.sql.types.DataType; +import zingg.common.client.ClientOptions; import zingg.common.client.IArguments; import zingg.common.client.ZinggClientException; import zingg.common.client.options.ZinggOptions; @@ -34,8 +35,8 @@ public SparkTrainMatcher(ZinggSparkContext sparkContext) { } @Override - public void init(IArguments args, SparkSession s) throws ZinggClientException { - super.init(args,s); + public void init(IArguments args, SparkSession s, ClientOptions options) throws ZinggClientException { + super.init(args,s,options); getContext().init(s); } diff --git a/spark/core/src/main/java/zingg/spark/core/executor/SparkTrainer.java b/spark/core/src/main/java/zingg/spark/core/executor/SparkTrainer.java index e23c5b043..902a16c1c 100644 --- a/spark/core/src/main/java/zingg/spark/core/executor/SparkTrainer.java +++ b/spark/core/src/main/java/zingg/spark/core/executor/SparkTrainer.java @@ -9,6 +9,7 @@ import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.SparkSession; +import zingg.common.client.ClientOptions; import zingg.common.client.IArguments; import zingg.common.client.ZinggClientException; import zingg.common.client.options.ZinggOptions; @@ -35,8 +36,8 @@ public SparkTrainer(ZinggSparkContext sparkContext) { } @Override - public void init(IArguments args, SparkSession s) throws ZinggClientException { - super.init(args,s); + public void init(IArguments args, SparkSession s, ClientOptions options) throws ZinggClientException { + super.init(args,s,options); getContext().init(s); } diff --git a/spark/core/src/main/java/zingg/spark/core/executor/SparkTrainingDataFinder.java b/spark/core/src/main/java/zingg/spark/core/executor/SparkTrainingDataFinder.java index 012effdab..2d5f3ad0d 100644 --- a/spark/core/src/main/java/zingg/spark/core/executor/SparkTrainingDataFinder.java +++ b/spark/core/src/main/java/zingg/spark/core/executor/SparkTrainingDataFinder.java @@ -7,6 +7,7 @@ import org.apache.spark.sql.Row; import org.apache.spark.sql.types.DataType; +import zingg.common.client.ClientOptions; import zingg.common.client.IArguments; import zingg.common.client.ZinggClientException; import zingg.common.client.options.ZinggOptions; @@ -32,8 +33,8 @@ public SparkTrainingDataFinder(ZinggSparkContext sparkContext) { } @Override - public void init(IArguments args, SparkSession s) throws ZinggClientException { - super.init(args,s); + public void init(IArguments args, SparkSession s, ClientOptions options) throws ZinggClientException { + super.init(args,s,options); getContext().init(s); } diff --git a/spark/core/src/test/java/zingg/TestFebrlDataset.java b/spark/core/src/test/java/zingg/TestFebrlDataset.java index a7ef49128..0cc33b05a 100644 --- a/spark/core/src/test/java/zingg/TestFebrlDataset.java +++ b/spark/core/src/test/java/zingg/TestFebrlDataset.java @@ -12,6 +12,7 @@ import org.junit.jupiter.api.Test; import zingg.common.client.Arguments; +import zingg.common.client.ClientOptions; import zingg.common.client.ZinggClientException; import zingg.common.client.pipe.FilePipe; import zingg.common.client.pipe.Pipe; @@ -50,7 +51,7 @@ public void setUp() throws Exception, ZinggClientException{ public void testModelAccuracy(){ TrainMatcher tm = new SparkTrainMatcher(); try { - tm.init(args,spark); + tm.init(args,spark,null); // tm.setSpark(spark); // tm.setCtx(ctx); tm.setArgs(args);