Skip to content

Commit

Permalink
Added basic framework for json_merge_aggr
Browse files Browse the repository at this point in the history
  • Loading branch information
shigarg1 committed Oct 14, 2024
1 parent ab0d6eb commit e7af5fb
Show file tree
Hide file tree
Showing 7 changed files with 430 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ public class ExpressionModule implements Module
.add(HyperUniqueExpressions.HllRoundEstimateExprMacro.class)
.add(NestedDataExpressions.JsonObjectExprMacro.class)
.add(NestedDataExpressions.JsonMergeExprMacro.class)
.add(NestedDataExpressions.JsonMergeAggrExprMacro.class)
.add(NestedDataExpressions.JsonKeysExprMacro.class)
.add(NestedDataExpressions.JsonPathsExprMacro.class)
.add(NestedDataExpressions.JsonValueExprMacro.class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,13 @@
package org.apache.druid.query.expression;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonMappingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.ObjectReader;
import org.apache.druid.guice.annotations.Json;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExprMacroTable;
Expand All @@ -38,10 +42,7 @@

import javax.annotation.Nullable;
import javax.inject.Inject;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.*;
import java.util.stream.Collectors;

public class NestedDataExpressions
Expand Down Expand Up @@ -149,7 +150,7 @@ public ExprEval eval(ObjectBinding bindings)
}

try {
obj = jsonMapper.readValue(getArgAsJson(arg), Object.class);
obj = jsonMapper.readValue(getArgAsJson(JsonMergeExprMacro.this, jsonMapper, arg), Object.class);
}
catch (JsonProcessingException e) {
throw JsonMergeExprMacro.this.processingFailed(e, "bad string input [%s]", arg.asString());
Expand All @@ -161,7 +162,7 @@ public ExprEval eval(ObjectBinding bindings)
ExprEval argSub = args.get(i).eval(bindings);

try {
String str = getArgAsJson(argSub);
String str = getArgAsJson(JsonMergeExprMacro.this, jsonMapper, argSub);
if (str != null) {
obj = updater.readValue(str);
}
Expand All @@ -181,36 +182,217 @@ public ExpressionType getOutputType(InputBindingInspector inspector)
return ExpressionType.NESTED_DATA;
}

private String getArgAsJson(ExprEval arg)

}
return new ParseJsonExpr(args);
}
}

public static class JsonMergeAggrExprMacro implements ExprMacroTable.ExprMacro
{
public static final String NAME = "json_merge_aggr";
static final Logger logger = new Logger(NestedDataExpressions.class);

private final ObjectMapper jsonMapper;

@Inject
public JsonMergeAggrExprMacro(
@Json ObjectMapper jsonMapper
)
{
this.jsonMapper = jsonMapper;
}

public enum Aggregate
{
ADD,
MULTIPLY
}

interface Aggregator
{
ExprEval aggregate(ExprEval sourceValue, ExprEval targetValue);
}

public static class AddAggregator implements Aggregator
{
@Override
public ExprEval aggregate(ExprEval sourceValue, ExprEval targetValue)
{
if (sourceValue.type().isNumeric() && targetValue.type().isNumeric()) {
if (sourceValue.type().equals(ExpressionType.LONG) && targetValue.type().equals(ExpressionType.LONG)) {
return ExprEval.of(sourceValue.asLong() + targetValue.asLong());
}

return ExprEval.of(sourceValue.asDouble() + targetValue.asDouble());
} else if (sourceValue.isArray() && targetValue.isArray()) {
ArrayList<Object> sourceList = new ArrayList<>(Arrays.asList(sourceValue.asArray()));
sourceList.addAll(Arrays.asList(targetValue.asArray()));

return ExprEval.ofComplex(ExpressionType.NESTED_DATA, sourceList);
}
return targetValue;
}
}

public static class MultiplyAggregator implements Aggregator
{
@Override
public ExprEval aggregate(ExprEval sourceValue, ExprEval targetValue)
{
if (sourceValue.type().isNumeric() && targetValue.type().isNumeric()) {
if (sourceValue.type().equals(ExpressionType.LONG) && targetValue.type().equals(ExpressionType.LONG)) {
return ExprEval.of(sourceValue.asLong() * targetValue.asLong());
}

return ExprEval.of(sourceValue.asDouble() * targetValue.asDouble());
} else if (sourceValue.isArray() && targetValue.isArray()) {
ArrayList<Object> sourceList = new ArrayList<>(Arrays.asList(sourceValue.asArray()));
sourceList.addAll(Arrays.asList(targetValue.asArray()));

return ExprEval.ofComplex(ExpressionType.NESTED_DATA, sourceList);
}
return targetValue;
}
}

private Aggregator getAggregator(Aggregate aggregate)
{
switch (aggregate) {
case ADD:
return new AddAggregator();
case MULTIPLY:
return new MultiplyAggregator();
default:
throw new IllegalArgumentException("Unknown aggregator: " + aggregate);
}
}

private boolean isMap(ExprEval exprEval) {
return !exprEval.type().isPrimitive() && !exprEval.isArray();
}

private void merge(Object source, Object target, Aggregator aggregator)
{
ExprEval sourceValue = ExprEval.bestEffortOf(source);
ExprEval targetValue = ExprEval.bestEffortOf(target);

if (isMap(sourceValue) && isMap(targetValue)) {
mergeJson((Map<String, Object>) source, (Map<String, Object>) target, aggregator);
} else if (sourceValue.isArray() && targetValue.isArray()) {
mergeArray((List<Object>) source, (List<Object>) target);
} else {
throw JsonMergeAggrExprMacro.this.validationFailed(
"Unsupported type for merge"
);
}
}

private void mergeArray(List<Object> source, List<Object> target) {
source.addAll(target);
}

private void mergeJson(Map<String, Object> source, Map<String, Object> target, Aggregator aggregator)
{
for (String key : target.keySet()) {
if (source.containsKey(key)) {
if (source.get(key) instanceof Map && target.get(key) instanceof Map) {
mergeJson((Map<String, Object>) source.get(key), (Map<String, Object>) target.get(key), aggregator);
} else {
ExprEval sourceValue = ExprEval.bestEffortOf(source.get(key));
ExprEval targetValue = ExprEval.bestEffortOf(target.get(key));

ExprEval newValue = aggregator.aggregate(sourceValue, targetValue);
source.put(key, unwrap(newValue));
}
} else {
source.put(key, target.get(key));
}
}
}

@Override
public Expr apply(List<Expr> args)
{
if (args.size() < 3) {
throw validationFailed("must have at least three arguments");
}

final class JsonMergeAggExpr extends ExprMacroTable.BaseScalarMacroFunctionExpr
{
public JsonMergeAggExpr(List<Expr> args)
{
super(JsonMergeAggrExprMacro.this, args);
}

@Override
public ExprEval eval(ObjectBinding bindings)
{
if (!args.get(0).isLiteral() || args.get(0).getLiteralValue() == null) {
throw validationFailed("aggregator arg must be literal");
}

final Aggregator aggregator = getAggregator(Aggregate.valueOf(
StringUtils.toUpperCase((String) args.get(0).getLiteralValue())
));

ExprEval arg = args.get(1).eval(bindings);

if (arg.value() == null) {
return null;
throw JsonMergeAggrExprMacro.this.validationFailed(
"invalid input expected %s but got %s instead",
ExpressionType.STRING,
arg.type()
);
}

if (arg.type().is(ExprType.STRING)) {
return arg.asString();
}

if (arg.type().is(ExprType.COMPLEX)) {
Object source;
try {
String str = getArgAsJson(JsonMergeAggrExprMacro.this, jsonMapper, arg);
source = jsonMapper.readValue(str, Object.class);
}
catch (JsonProcessingException e) {
throw JsonMergeAggrExprMacro.this.processingFailed(e,
"bad string input [%s]", arg.asString());
}

for (int i = 2; i < args.size(); i++) {
ExprEval argSub = args.get(i).eval(bindings);

try {
return jsonMapper.writeValueAsString(unwrap(arg));
String str = getArgAsJson(JsonMergeAggrExprMacro.this, jsonMapper, argSub);
if (str != null) {
Object target = jsonMapper.readValue(str, Object.class);
merge(source, target, aggregator);
}
}
catch (JsonProcessingException e) {
throw JsonMergeExprMacro.this.processingFailed(e, "bad complex input [%s]", arg.asString());
}
}

throw JsonMergeExprMacro.this.validationFailed(
"invalid input expected %s but got %s instead",
ExpressionType.STRING,
arg.type()
);
throw JsonMergeAggrExprMacro.this.processingFailed(e, "bad string input [%s]", argSub.asString());
}
}
return ExprEval.ofComplex(ExpressionType.NESTED_DATA, source);
}

@Nullable
@Override
public ExpressionType getOutputType(InputBindingInspector inspector)
{
return ExpressionType.NESTED_DATA;
}
}
return new ParseJsonExpr(args);
return new JsonMergeAggExpr(args);
}

@Override
public String name()
{
return NAME;
}


}


public static class ToJsonStringExprMacro implements ExprMacroTable.ExprMacro
{
public static final String NAME = "to_json_string";
Expand Down Expand Up @@ -853,4 +1035,31 @@ static List<NestedPathPart> getJsonPathPartsFromLiteral(NamedFunction fn, Expr a
);
return parts;
}


static String getArgAsJson(NamedFunction fn, ObjectMapper jsonMapper, ExprEval arg)
{
if (arg.value() == null) {
return null;
}

if (arg.type().is(ExprType.STRING)) {
return arg.asString();
}

if (arg.type().is(ExprType.COMPLEX)) {
try {
return jsonMapper.writeValueAsString(unwrap(arg));
}
catch (JsonProcessingException e) {
throw fn.processingFailed(e, "bad complex input [%s]", arg.asString());
}
}

throw fn.validationFailed(
"invalid input expected %s but got %s instead",
ExpressionType.STRING,
arg.type()
);
}
}
Loading

0 comments on commit e7af5fb

Please sign in to comment.