Skip to content

Commit 2329805

Browse files
committed
Added support for withColumns in ZFrame
1 parent 921c6ae commit 2329805

File tree

4 files changed

+22
-12
lines changed

4 files changed

+22
-12
lines changed

common/client/src/main/java/zingg/common/client/ZFrame.java

+1
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ public interface ZFrame<D, R, C> {
7676
public ZFrame<D, R, C> unionByName(ZFrame<D, R, C> other, boolean flag);
7777

7878
public <A> ZFrame<D, R, C> withColumn(String s, A c);
79+
/**
 * Adds or replaces multiple columns in one pass.
 *
 * @param columns      names of the columns to add or replace; positionally
 *                     aligned with {@code columnValues} — presumably the two
 *                     arrays must be the same length (TODO confirm whether
 *                     implementations validate this)
 * @param columnValues the column expressions to bind to {@code columns}
 * @return a new frame with the given columns added or replaced
 */
public ZFrame<D, R, C> withColumns(String[] columns, C[] columnValues);
7980

8081

8182
public ZFrame<D, R, C> repartition(int num);

common/core/src/main/java/zingg/common/core/preprocess/casenormalize/CaseNormalizer.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ public abstract class CaseNormalizer<S,D,R,C,T> implements IMultiFieldPreprocess
1515

1616
private static final long serialVersionUID = 1L;
1717
private static final String STRING_TYPE = "string";
18-
protected static String name = "zingg.common.core.preprocess.caseNormalize.CaseNormalizer";
18+
protected static String name = "zingg.common.core.preprocess.casenormalize.CaseNormalizer";
1919
public static final Log LOG = LogFactory.getLog(CaseNormalizer.class);
2020

2121
private IContext<S, D, R, C, T> context;

spark/client/src/main/java/zingg/spark/client/SparkFrame.java

+14
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package zingg.spark.client;
22

3+
import java.util.Arrays;
34
import java.util.List;
45

56
import org.apache.spark.sql.Column;
@@ -10,6 +11,7 @@
1011
import org.apache.spark.sql.types.StructField;
1112

1213
import scala.collection.JavaConverters;
14+
import scala.collection.Seq;
1315
import zingg.common.client.FieldData;
1416
import zingg.common.client.ZFrame;
1517
import zingg.common.client.util.ColName;
@@ -203,6 +205,18 @@ public <A> ZFrame<Dataset<Row>, Row, Column> withColumn(String s, A c){
203205
return new SparkFrame(df.withColumn(s, functions.lit(c)));
204206
}
205207

208+
@Override
209+
public ZFrame<Dataset<Row>, Row, Column> withColumns(String[] columns, Column[] columnValues) {
210+
Seq<String> columnsSeq = JavaConverters.asScalaIteratorConverter(Arrays.asList(columns).iterator())
211+
.asScala()
212+
.toSeq();
213+
Seq<Column> columnValuesSeq = JavaConverters.asScalaIteratorConverter(Arrays.asList(columnValues).iterator())
214+
.asScala()
215+
.toSeq();
216+
217+
return new SparkFrame(df.withColumns(columnsSeq, columnValuesSeq));
218+
}
219+
206220
public ZFrame<Dataset<Row>, Row, Column> repartition(int nul){
207221
return new SparkFrame(df.repartition(nul));
208222
}

spark/core/src/main/java/zingg/spark/core/preprocess/caseNormalize/SparkCaseNormalizer.java

+6-11
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
public class SparkCaseNormalizer extends CaseNormalizer<SparkSession, Dataset<Row>, Row, Column, DataType> {
2222
private static final long serialVersionUID = 1L;
23-
protected static String name = "zingg.spark.core.preprocess.caseNormalize.SparkCaseNormalizer";
23+
protected static String name = "zingg.spark.core.preprocess.casenormalize.SparkCaseNormalizer";
2424

2525
public SparkCaseNormalizer() {
2626
super();
@@ -32,16 +32,11 @@ public SparkCaseNormalizer(IContext<SparkSession, Dataset<Row>, Row, Column, Dat
3232
@Override
3333
protected ZFrame<Dataset<Row>, Row, Column> applyCaseNormalizer(ZFrame<Dataset<Row>, Row, Column> incomingDataFrame, List<String> relevantFields) {
3434
String[] incomingDFColumns = incomingDataFrame.columns();
35-
Seq<String> columnsSeq = JavaConverters.asScalaIteratorConverter(relevantFields.iterator())
36-
.asScala()
37-
.toSeq();
38-
List<Column> caseNormalizedValues = new ArrayList<>();
39-
for (String relevantField : relevantFields) {
40-
caseNormalizedValues.add(lower(incomingDataFrame.col(relevantField)));
35+
Column[] caseNormalizedValues = new Column[relevantFields.size()];
36+
for (int idx = 0; idx < relevantFields.size(); idx++) {
37+
caseNormalizedValues[idx] = lower(incomingDataFrame.col(relevantFields.get(idx)));
4138
}
42-
Seq<Column> caseNormalizedSeq = JavaConverters.asScalaIteratorConverter(caseNormalizedValues.iterator())
43-
.asScala()
44-
.toSeq();
45-
return new SparkFrame(incomingDataFrame.df().withColumns(columnsSeq, caseNormalizedSeq)).select(incomingDFColumns);
39+
40+
return incomingDataFrame.withColumns(incomingDFColumns, caseNormalizedValues).select(incomingDFColumns);
4641
}
4742
}

0 commit comments

Comments
 (0)