Skip to content

Commit

Permalink
refactoring init in code
Browse files Browse the repository at this point in the history
  • Loading branch information
sania-16 committed Sep 18, 2024
1 parent 5e299ba commit 47e4543
Show file tree
Hide file tree
Showing 18 changed files with 57 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ else if (options.get(ClientOptions.CONF).value.endsWith("env")) {

public void init() throws ZinggClientException {
zingg.setClientOptions(getOptions());
zingg.init(getArguments(), getSession());
zingg.init(getArguments(), getSession(),getOptions());
if (session != null) zingg.setSession(session);
initializeListeners();
EventsListener.getInstance().fireEvent(new ZinggStartEvent());
Expand Down
3 changes: 1 addition & 2 deletions common/client/src/main/java/zingg/common/client/IZingg.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

public interface IZingg<S,D,R,C> {

public void init(IArguments args, S session)
public void init(IArguments args, S session, ClientOptions options)
throws ZinggClientException;

public void execute() throws ZinggClientException;
Expand Down Expand Up @@ -31,7 +31,6 @@ public void init(IArguments args, S session)

//public void setSession(S session); // method name will have to be changed in Client too


public void setClientOptions(ClientOptions clientOptions);

public ClientOptions getClientOptions();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import zingg.common.client.ClientOptions;
import zingg.common.client.IArguments;
import zingg.common.client.ZinggClientException;
import zingg.common.client.options.ZinggOptions;
Expand All @@ -20,10 +21,10 @@ public FindAndLabeller() {
}

@Override
public void init(IArguments args, S s) throws ZinggClientException {
finder.init(args,s);
labeller.init(args,s);
super.init(args,s);
public void init(IArguments args, S s, ClientOptions options) throws ZinggClientException {
finder.init(args,s,options);
labeller.init(args,s,options);
super.init(args,s,options);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.http.impl.execchain.ClientExecChain;

import zingg.common.client.ClientOptions;
import zingg.common.client.IArguments;
import zingg.common.client.ZinggClientException;
import zingg.common.client.options.ZinggOptions;
Expand All @@ -21,11 +23,11 @@ public TrainMatcher() {
}

@Override
public void init(IArguments args, S s)
public void init(IArguments args, S s, ClientOptions options)
throws ZinggClientException {
trainer.init(args,s);
matcher.init(args,s);
super.init(args,s);
trainer.init(args,s,options);
matcher.init(args,s,options);
super.init(args,s,options);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,11 @@ public ZinggBase() {


@Override
public void init(IArguments args, S session)
public void init(IArguments args, S session, ClientOptions options)
throws ZinggClientException {
startTime = System.currentTimeMillis();
this.args = args;
this.clientOptions = options;
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import org.junit.jupiter.api.Test;

import zingg.common.client.ArgumentsUtil;
import zingg.common.client.ClientOptions;
import zingg.common.client.IArguments;
import zingg.common.client.ZinggClientException;

Expand All @@ -16,9 +17,8 @@ public abstract class TestExecutorsGeneric<S, D, R, C, T> {
public static final Log LOG = LogFactory.getLog(TestExecutorsGeneric.class);

protected IArguments args;


protected S session;
protected ClientOptions options;

public TestExecutorsGeneric() {

Expand Down Expand Up @@ -50,33 +50,33 @@ public void testExecutors() throws ZinggClientException {
List<ExecutorTester<S, D, R, C, T>> executorTesterList = new ArrayList<ExecutorTester<S, D, R, C, T>>();

TrainingDataFinder<S, D, R, C, T> trainingDataFinder = getTrainingDataFinder();
trainingDataFinder.init(args,session);
trainingDataFinder.init(args,session,options);
TrainingDataFinderTester<S, D, R, C, T> tdft = new TrainingDataFinderTester<S, D, R, C, T>(trainingDataFinder);
executorTesterList.add(tdft);

Labeller<S, D, R, C, T> labeller = getLabeller();
labeller.init(args,session);
labeller.init(args,session,options);
LabellerTester<S, D, R, C, T> lt = new LabellerTester<S, D, R, C, T>(labeller);
executorTesterList.add(lt);

// training and labelling needed twice to get sufficient data
TrainingDataFinder<S, D, R, C, T> trainingDataFinder2 = getTrainingDataFinder();
trainingDataFinder2.init(args,session);
trainingDataFinder2.init(args,session,options);
TrainingDataFinderTester<S, D, R, C, T> tdft2 = new TrainingDataFinderTester<S, D, R, C, T>(trainingDataFinder2);
executorTesterList.add(tdft2);

Labeller<S, D, R, C, T> labeller2 = getLabeller();
labeller2.init(args,session);
labeller2.init(args,session,options);
LabellerTester<S, D, R, C, T> lt2 = new LabellerTester<S, D, R, C, T>(labeller2);
executorTesterList.add(lt2);

Trainer<S, D, R, C, T> trainer = getTrainer();
trainer.init(args,session);
trainer.init(args,session,options);
TrainerTester<S, D, R, C, T> tt = getTrainerTester(trainer);
executorTesterList.add(tt);

Matcher<S, D, R, C, T> matcher = getMatcher();
matcher.init(args,session);
matcher.init(args,session,options);
MatcherTester<S, D, R, C, T> mt = new MatcherTester(matcher);
executorTesterList.add(mt);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.DataType;

import zingg.common.client.ClientOptions;
import zingg.common.client.IArguments;
import zingg.common.client.ZinggClientException;
import zingg.common.client.options.ZinggOptions;
Expand Down Expand Up @@ -36,8 +37,8 @@ public SparkDocumenter(ZinggSparkContext sparkContext) {
}

@Override
public void init(IArguments args, SparkSession s) throws ZinggClientException {
super.init(args,s);
public void init(IArguments args, SparkSession s, ClientOptions options) throws ZinggClientException {
super.init(args,s,options);
getContext().init(s);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.SparkSession;

import zingg.common.client.ClientOptions;
import zingg.common.client.IArguments;
import zingg.common.client.ZinggClientException;
import zingg.common.client.options.ZinggOptions;
Expand All @@ -35,8 +36,8 @@ public SparkFindAndLabeller(ZinggSparkContext sparkContext) {
}

@Override
public void init(IArguments args, SparkSession s) throws ZinggClientException {
super.init(args,s);
public void init(IArguments args, SparkSession s, ClientOptions options) throws ZinggClientException {
super.init(args,s,options);
getContext().init(s);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.types.DataType;

import zingg.common.client.ClientOptions;
import zingg.common.client.IArguments;
import zingg.common.client.ZinggClientException;
import zingg.common.client.options.ZinggOptions;
Expand Down Expand Up @@ -39,8 +40,8 @@ public SparkLabelUpdater(ZinggSparkContext sparkContext) {
}

@Override
public void init(IArguments args, SparkSession s) throws ZinggClientException {
super.init(args,s);
public void init(IArguments args, SparkSession s, ClientOptions options) throws ZinggClientException {
super.init(args,s,options);
getContext().init(s);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.SparkSession;

import zingg.common.client.ClientOptions;
import zingg.common.client.IArguments;
import zingg.common.client.ZinggClientException;
import zingg.common.client.options.ZinggOptions;
Expand Down Expand Up @@ -37,8 +38,8 @@ public SparkLabeller(ZinggSparkContext sparkContext) {
}

@Override
public void init(IArguments args, SparkSession s) throws ZinggClientException {
super.init(args,s);
public void init(IArguments args, SparkSession s, ClientOptions options) throws ZinggClientException {
super.init(args,s,options);
getContext().init(s);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataType;

import zingg.common.client.ClientOptions;
import zingg.common.client.IArguments;
import zingg.common.client.ZinggClientException;
import zingg.common.client.options.ZinggOptions;
Expand All @@ -34,8 +35,8 @@ public SparkLinker(ZinggSparkContext sparkContext) {
}

@Override
public void init(IArguments args, SparkSession s) throws ZinggClientException {
super.init(args,s);
public void init(IArguments args, SparkSession s, ClientOptions options) throws ZinggClientException {
super.init(args,s,options);
getContext().init(s);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.DataType;

import zingg.common.client.ClientOptions;
import zingg.common.client.IArguments;
import zingg.common.client.ZinggClientException;
import zingg.common.client.options.ZinggOptions;
Expand Down Expand Up @@ -40,8 +41,8 @@ public SparkMatcher(ZinggSparkContext sparkContext) {
}

@Override
public void init(IArguments args, SparkSession s) throws ZinggClientException {
super.init(args,s);
public void init(IArguments args, SparkSession s, ClientOptions options) throws ZinggClientException {
super.init(args,s,options);
getContext().init(s);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ public SparkPeekModel() {
}

@Override
public void init(IArguments args, SparkSession s)
public void init(IArguments args, SparkSession s, ClientOptions options)
throws ZinggClientException {
super.init(args,s);
super.init(args,s,options);
getContext().setUtils();
//we wil not init here as we wnt py to drive
//the spark session etc
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.DataType;

import zingg.common.client.ClientOptions;
import zingg.common.client.IArguments;
import zingg.common.client.ZinggClientException;
import zingg.common.client.options.ZinggOptions;
Expand Down Expand Up @@ -38,8 +39,8 @@ public SparkRecommender(ZinggSparkContext sparkContext) {
}

@Override
public void init(IArguments args, SparkSession s) throws ZinggClientException {
super.init(args,s);
public void init(IArguments args, SparkSession s, ClientOptions options) throws ZinggClientException {
super.init(args,s,options);
getContext().init(s);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.DataType;

import zingg.common.client.ClientOptions;
import zingg.common.client.IArguments;
import zingg.common.client.ZinggClientException;
import zingg.common.client.options.ZinggOptions;
Expand Down Expand Up @@ -34,8 +35,8 @@ public SparkTrainMatcher(ZinggSparkContext sparkContext) {
}

@Override
public void init(IArguments args, SparkSession s) throws ZinggClientException {
super.init(args,s);
public void init(IArguments args, SparkSession s, ClientOptions options) throws ZinggClientException {
super.init(args,s,options);
getContext().init(s);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.SparkSession;

import zingg.common.client.ClientOptions;
import zingg.common.client.IArguments;
import zingg.common.client.ZinggClientException;
import zingg.common.client.options.ZinggOptions;
Expand All @@ -35,8 +36,8 @@ public SparkTrainer(ZinggSparkContext sparkContext) {
}

@Override
public void init(IArguments args, SparkSession s) throws ZinggClientException {
super.init(args,s);
public void init(IArguments args, SparkSession s, ClientOptions options) throws ZinggClientException {
super.init(args,s,options);
getContext().init(s);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.DataType;

import zingg.common.client.ClientOptions;
import zingg.common.client.IArguments;
import zingg.common.client.ZinggClientException;
import zingg.common.client.options.ZinggOptions;
Expand All @@ -32,8 +33,8 @@ public SparkTrainingDataFinder(ZinggSparkContext sparkContext) {
}

@Override
public void init(IArguments args, SparkSession s) throws ZinggClientException {
super.init(args,s);
public void init(IArguments args, SparkSession s, ClientOptions options) throws ZinggClientException {
super.init(args,s,options);
getContext().init(s);
}

Expand Down
3 changes: 2 additions & 1 deletion spark/core/src/test/java/zingg/TestFebrlDataset.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.junit.jupiter.api.Test;

import zingg.common.client.Arguments;
import zingg.common.client.ClientOptions;
import zingg.common.client.ZinggClientException;
import zingg.common.client.pipe.FilePipe;
import zingg.common.client.pipe.Pipe;
Expand Down Expand Up @@ -50,7 +51,7 @@ public void setUp() throws Exception, ZinggClientException{
public void testModelAccuracy(){
TrainMatcher tm = new SparkTrainMatcher();
try {
tm.init(args,spark);
tm.init(args,spark,null);
// tm.setSpark(spark);
// tm.setCtx(ctx);
tm.setArgs(args);
Expand Down

0 comments on commit 47e4543

Please sign in to comment.