Skip to content

Commit

Permalink
Merge pull request #151 from cirnoooo123/main
Browse files Browse the repository at this point in the history
implement SecretUnion
  • Loading branch information
SongY123 authored Apr 21, 2023
2 parents 6310b0d + 6518039 commit 751cbb4
Show file tree
Hide file tree
Showing 8 changed files with 609 additions and 3 deletions.
11 changes: 11 additions & 0 deletions data/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,17 @@
<groupId>org.locationtech.jts</groupId>
<artifactId>jts-core</artifactId>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>com.hufudb.openhufu</groupId>
<artifactId>openhufu-common</artifactId>
<version>${project.version}</version>
<scope>compile</scope>
</dependency>
</dependencies>

<build>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ public class ArrayDataSet implements MaterializedDataSet {
final List<ArrayRow> rows;
final int rowCount;

ArrayDataSet(Schema schema, List<ArrayRow> rows) {
public ArrayDataSet(Schema schema, List<ArrayRow> rows) {
this.schema = schema;
this.rows = rows;
this.rowCount = rows.size();
Expand Down Expand Up @@ -50,6 +50,10 @@ public int rowCount() {
return rowCount;
}

public List<ArrayRow> getRows() {
return rows;
}

class Iterator implements DataSetIterator {
int pointer;

Expand Down
171 changes: 171 additions & 0 deletions data/src/main/java/com/hufudb/openhufu/data/storage/RandomDataSet.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
package com.hufudb.openhufu.data.storage;

import com.hufudb.openhufu.common.exception.ErrorCode;
import com.hufudb.openhufu.common.exception.OpenHuFuException;
import com.hufudb.openhufu.data.schema.Schema;
import com.hufudb.openhufu.proto.OpenHuFuData;
import org.apache.commons.lang3.RandomStringUtils;
import org.apache.commons.math3.distribution.LaplaceDistribution;
import org.locationtech.jts.geom.Coordinate;
import org.locationtech.jts.geom.Geometry;
import org.locationtech.jts.geom.GeometryFactory;
import org.locationtech.jts.geom.Point;

import java.util.*;
import java.util.function.Function;

public class RandomDataSet {
public final static GeometryFactory geoFactory = new GeometryFactory();
private final static double RANDOM_SET_SCALE = 0.5;
private final static double EPS = 1.0;
public final static int RANDOM_SET_OFFSET = 10;
private final static LaplaceDistribution lap = new LaplaceDistribution(0, 1 / EPS);
private final static Random random = new Random();
private final Schema schema;
private final DataSet source;
private final List<ArrayRow> originRows;
private final int originSize;
private final int resultSize;
private final List<ArrayRow> randomRows;
private final Map<Object, List<ArrayRow>> randomRowMap;

public RandomDataSet(DataSet source) {
this.schema = source.getSchema();
this.source = source;
this.originRows = ArrayDataSet.materialize(source).rows;
this.originSize = originRows.size();
this.randomRows = new ArrayList<>();
this.randomRowMap = new HashMap<>();
int size = (int) Math.ceil((double) originSize * RANDOM_SET_SCALE + lap.sample());
this.resultSize = size > 0 ? size : (int) Math.abs(lap.sample());
if (originSize == 0) {
init(this::getRandomValue);
} else {
init(this::getRandomValueFromData);
}
this.mix();
}

private void init(Function<Integer, Object> randomFunc) {
final int headerSize = schema.size();
if (headerSize == 0) {
return;
}
for (int i = 0; i < resultSize; ++i) {
Object[] objects = new Object[headerSize];
Object key = randomFunc.apply(0);
objects[0] = key;
for (int j = 1; j < headerSize; ++j) {
Object value = randomFunc.apply(j);
objects[j] = value;
}
ArrayRow row = new ArrayRow(objects);
randomRows.add(row);
recordRandomRow(key, row);
}
}

private void recordRandomRow(Object key, ArrayRow value) {
if (randomRowMap.containsKey(key)) {
randomRowMap.get(key).add(value);
} else {
randomRowMap.put(key, new LinkedList<>(Arrays.asList(value)));
}
}

private void mix() {
//todo index insert for ArrayList may be slow
for (ArrayRow row : originRows) {
int idx = (int) Math.ceil(random.nextDouble() * randomRows.size());
randomRows.add(idx, row);
}
}

public ArrayDataSet getRandomSet() {
return new ArrayDataSet(schema, randomRows);
}

public ArrayDataSet removeRandom(DataSet dataSet) {
List<ArrayRow> newRows = new ArrayList<>();
for (ArrayRow row : ArrayDataSet.materialize(dataSet).rows) {
Object key = row.get(0);
List<ArrayRow> rows = randomRowMap.get(key);
if (rows == null || rows.isEmpty()) {
newRows.add(row);
continue;
}
int idx = rows.indexOf(row);
if (idx == -1) {
newRows.add(row);
} else {
rows.remove(idx);
}
}
return new ArrayDataSet(schema, newRows);
}

private Object getRandomValueFromData(int columnIndex) {
OpenHuFuData.ColumnType type = schema.getType((columnIndex));
int r = (int) (random.nextDouble() * originSize);
switch (type) {
case BYTE:
return (byte) originRows.get(r).get(columnIndex) + (byte) lap.sample();
case SHORT:
return (short) originRows.get(r).get(columnIndex) + (short) lap.sample();
case INT:
return (int) originRows.get(r).get(columnIndex) + (int) lap.sample();
case LONG:
case DATE:
case TIME:
case TIMESTAMP:
return (long) originRows.get(r).get(columnIndex) + (long) lap.sample();
case FLOAT:
return (float) originRows.get(r).get(columnIndex) + (float) lap.sample();
case DOUBLE:
return (double) originRows.get(r).get(columnIndex) + lap.sample();
case BOOLEAN:
return lap.sample() > 0.0;
case GEOMETRY:
Geometry geometry = (Geometry) originRows.get(r).get(columnIndex);
if (geometry instanceof Point) {
Point p = (Point) geometry;
return geoFactory.createPoint(new Coordinate(p.getX() + lap.sample(), p.getX() + lap.sample()));
} else {
throw new OpenHuFuException(ErrorCode.DATA_TYPE_NOT_SUPPORT, type);
}
case STRING:
return originRows.get(r).get(columnIndex);
default:
throw new OpenHuFuException(ErrorCode.DATA_TYPE_NOT_SUPPORT, type);
}
}

private Object getRandomValue(int columnIndex) {
OpenHuFuData.ColumnType type = schema.getType((columnIndex));
switch (type) {
case BYTE:
return (byte) lap.sample();
case SHORT:
return (short) lap.sample();
case INT:
return (int) lap.sample();
case LONG:
case DATE:
case TIME:
case TIMESTAMP:
return (long) lap.sample();
case FLOAT:
return (float) lap.sample();
case DOUBLE:
return lap.sample();
case BOOLEAN:
return lap.sample() > 0.0;
case GEOMETRY:
return geoFactory.createPoint(new Coordinate(lap.sample(), lap.sample()));
case STRING:
return RandomStringUtils.randomAlphanumeric(RANDOM_SET_OFFSET);
default:
throw new OpenHuFuException(ErrorCode.DATA_TYPE_NOT_SUPPORT, type);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import java.sql.SQLException;
import java.sql.Time;
import java.sql.Timestamp;
import java.time.Year;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
Expand Down Expand Up @@ -533,4 +534,47 @@ public void testDataSetWithPoint() {
}
assertFalse(it.next());
}

@Test
public void testRandomDataSet() {
Schema schema = Schema.newBuilder().add("A", ColumnType.GEOMETRY, Modifier.PUBLIC).build();
Random random = new Random();
List<Geometry> ps = new ArrayList<>();
ProtoDataSet.Builder dBuilder = ProtoDataSet.newBuilder(schema);
for (int i = 0; i < 10; ++i) {
Geometry p = GeometryUtils.fromString(String.format("POINT(%f %f)", random.nextDouble(), random.nextDouble()));
ArrayRow.Builder rBuilder = ArrayRow.newBuilder(1);
rBuilder.set(0, p);
dBuilder.addRow(rBuilder.build());
ps.add(p);
}
ProtoDataSet dataset = dBuilder.build();
RandomDataSet randomDataSet = new RandomDataSet(dataset);
DataSet newDataSet = randomDataSet.getRandomSet();
DataSetIterator it = newDataSet.getIterator();
for (int i = 0; i < 10; ++i) {
assertTrue(it.next());
}
assertTrue(it.next());

it = randomDataSet.removeRandom(newDataSet).getIterator();
for (int i = 0; i < 10; ++i) {
assertTrue(it.next());
boolean has = false;
for (int j = 0; j < 10; ++j) {
if (ps.get(j).equals(it.get(0))) {
has = true;
break;
}
}
assertTrue(has);
}
assertFalse(it.next());
}
void printAll(DataSet dataSet) {
DataSetIterator it = dataSet.getIterator();
while (it.next()) {
System.out.println(it.get(0));
}
}
}
4 changes: 2 additions & 2 deletions mpc/src/main/java/com/hufudb/openhufu/mpc/ProtocolType.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ public enum ProtocolType {
GMW("GMW", 100, true),
SS("SS", 101, true),
HASH_PSI("PSI", 200, true),
ABY("ABY", 300, true);

ABY("ABY", 300, true),
SECRET_UNION("SECRET_UNION", 400, true);
private static final ImmutableMap<Integer, ProtocolType> MAP;

static {
Expand Down
Loading

0 comments on commit 751cbb4

Please sign in to comment.