Skip to content

Commit

Permalink
meh
Browse files Browse the repository at this point in the history
  • Loading branch information
evanchooly committed May 15, 2024
1 parent b3d23a6 commit 25fba6b
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 39 deletions.
122 changes: 85 additions & 37 deletions rewrite/src/main/java/dev/morphia/rewrite/recipes/PipelineRewrite.java
Original file line number Diff line number Diff line change
@@ -1,18 +1,60 @@
package dev.morphia.rewrite.recipes;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.StringJoiner;

import org.jetbrains.annotations.NotNull;
import org.openrewrite.ExecutionContext;
import org.openrewrite.Recipe;
import org.openrewrite.TreeVisitor;
import org.openrewrite.java.JavaIsoVisitor;
import org.openrewrite.java.JavaParser;
import org.openrewrite.java.JavaTemplate;
import org.openrewrite.java.MethodMatcher;
import org.openrewrite.java.tree.Expression;
import org.openrewrite.java.tree.J;
import org.openrewrite.java.tree.J.MethodInvocation;

public class PipelineRewrite extends Recipe {

private static final String AGGREGATION = "dev.morphia.aggregation.Aggregation";
static final String AGGREGATION = "dev.morphia.aggregation.Aggregation";
static final MethodMatcher pipeline = new MethodMatcher(PipelineRewrite.AGGREGATION + " pipeline(..)");
static final List<MethodMatcher> matchers = List.of(
new MethodMatcher(AGGREGATION + " addFields(dev.morphia.aggregation.stages.AddFields)"),
new MethodMatcher(AGGREGATION + " autoBucket(dev.morphia.aggregation.stages.AutoBucket)"),
new MethodMatcher(AGGREGATION + " bucket(dev.morphia.aggregation.stages.Bucket)"),
new MethodMatcher(AGGREGATION + " changeStream()"),
new MethodMatcher(AGGREGATION + " changeStream(dev.morphia.aggregation.stages.ChangeStream)"),
new MethodMatcher(AGGREGATION + " collStats(dev.morphia.aggregation.stages.CollectionStats)"),
new MethodMatcher(AGGREGATION + " count(dev.morphia.aggregation.stages.Count)"),
new MethodMatcher(AGGREGATION + " currentOp(dev.morphia.aggregation.stages.CountOp)"),
new MethodMatcher(AGGREGATION + " densify(dev.morphia.aggregation.stages.Densify)"),
new MethodMatcher(AGGREGATION + " documents(dev.morphia.aggregation.expressions.impls.DocumentExpression)"),
new MethodMatcher(AGGREGATION + " facet(dev.morphia.aggregation.stages.Facet)"),
new MethodMatcher(AGGREGATION + " fill(dev.morphia.aggregation.stages.Fill)"),
new MethodMatcher(AGGREGATION + " geoNear(dev.morphia.aggregation.stages.GeoNear)"),
new MethodMatcher(AGGREGATION + " graphLookup(dev.morphia.aggregation.stages.GraphLookup)"),
new MethodMatcher(AGGREGATION + " group(dev.morphia.aggregation.stages.Group)"),
new MethodMatcher(AGGREGATION + " indexStats(dev.morphia.aggregation.stages.IndexStats)"),
new MethodMatcher(AGGREGATION + " limit(long)"),
new MethodMatcher(AGGREGATION + " lookup(dev.morphia.aggregation.stages.Lookup)"),
new MethodMatcher(AGGREGATION + " match(dev.morphia.aggregation.stages.Match)"),
new MethodMatcher(AGGREGATION + " planCacheStats()"),
new MethodMatcher(AGGREGATION + " project(dev.morphia.aggregation.stages.Projection)"),
new MethodMatcher(AGGREGATION + " redact(dev.morphia.aggregation.stages.Redact)"),
new MethodMatcher(AGGREGATION + " replaceRoot(dev.morphia.aggregation.stages.ReplaceRoot)"),
new MethodMatcher(AGGREGATION + " replaceWith(dev.morphia.aggregation.stages.ReplaceWith)"),
new MethodMatcher(AGGREGATION + " sample(dev.morphia.aggregation.stages.Sample)"),
new MethodMatcher(AGGREGATION + " set(dev.morphia.aggregation.stages.Set)"),
new MethodMatcher(AGGREGATION + " skip(long)"),
new MethodMatcher(AGGREGATION + " sort(dev.morphia.aggregation.stages.Sort)"),
new MethodMatcher(AGGREGATION + " sortByCount(dev.morphia.aggregation.stages.SortByCount)"),
new MethodMatcher(AGGREGATION + " unionWith(Class,Stage...)"),
new MethodMatcher(AGGREGATION + " unionWith(String,Stage...)"),
new MethodMatcher(AGGREGATION + " unset(dev.morphia.aggregation.stages.Unset)"),
new MethodMatcher(AGGREGATION + " unwind(dev.morphia.aggregation.stages.Unwind)"));

@Override
public String getDisplayName() {
Expand All @@ -26,45 +68,16 @@ public String getDescription() {

@Override
public TreeVisitor<?, ExecutionContext> getVisitor() {
var matchers = List.of(
new MethodMatcher(AGGREGATION + " addFields(dev.morphia.aggregation.stages.AddFields)"),
new MethodMatcher(AGGREGATION + " autoBucket(dev.morphia.aggregation.stages.AutoBucket)"),
new MethodMatcher(AGGREGATION + " bucket(dev.morphia.aggregation.stages.Bucket)"),
new MethodMatcher(AGGREGATION + " changeStream()"),
new MethodMatcher(AGGREGATION + " changeStream(dev.morphia.aggregation.stages.ChangeStream)"),
new MethodMatcher(AGGREGATION + " collStats(dev.morphia.aggregation.stages.CollectionStats)"),
new MethodMatcher(AGGREGATION + " count(dev.morphia.aggregation.stages.Count)"),
new MethodMatcher(AGGREGATION + " currentOp(dev.morphia.aggregation.stages.CountOp)"),
new MethodMatcher(AGGREGATION + " densify(dev.morphia.aggregation.stages.Densify)"),
new MethodMatcher(AGGREGATION + " documents(dev.morphia.aggregation.expressions.impls.DocumentExpression)"),
new MethodMatcher(AGGREGATION + " facet(dev.morphia.aggregation.stages.Facet)"),
new MethodMatcher(AGGREGATION + " fill(dev.morphia.aggregation.stages.Fill)"),
new MethodMatcher(AGGREGATION + " geoNear(dev.morphia.aggregation.stages.GeoNear)"),
new MethodMatcher(AGGREGATION + " graphLookup(dev.morphia.aggregation.stages.GraphLookup)"),
new MethodMatcher(AGGREGATION + " group(dev.morphia.aggregation.stages.Group)"),
new MethodMatcher(AGGREGATION + " indexStats(dev.morphia.aggregation.stages.IndexStats)"),
new MethodMatcher(AGGREGATION + " limit(long)"),
new MethodMatcher(AGGREGATION + " lookup(dev.morphia.aggregation.stages.Lookup)"),
new MethodMatcher(AGGREGATION + " match(dev.morphia.aggregation.stages.Match)"),
new MethodMatcher(AGGREGATION + " planCacheStats()"),
new MethodMatcher(AGGREGATION + " project(dev.morphia.aggregation.stages.Projection)"),
new MethodMatcher(AGGREGATION + " redact(dev.morphia.aggregation.stages.Redact)"),
new MethodMatcher(AGGREGATION + " replaceRoot(dev.morphia.aggregation.stages.ReplaceRoot)"),
new MethodMatcher(AGGREGATION + " replaceWith(dev.morphia.aggregation.stages.ReplaceWith)"),
new MethodMatcher(AGGREGATION + " sample(dev.morphia.aggregation.stages.Sample)"),
new MethodMatcher(AGGREGATION + " set(dev.morphia.aggregation.stages.Set)"),
new MethodMatcher(AGGREGATION + " skip(long)"),
new MethodMatcher(AGGREGATION + " sort(dev.morphia.aggregation.stages.Sort)"),
new MethodMatcher(AGGREGATION + " sortByCount(dev.morphia.aggregation.stages.SortByCount)"),
new MethodMatcher(AGGREGATION + " unionWith(Class,Stage...)"),
new MethodMatcher(AGGREGATION + " unionWith(String,Stage...)"),
new MethodMatcher(AGGREGATION + " unset(dev.morphia.aggregation.stages.Unset)"),
new MethodMatcher(AGGREGATION + " unwind(dev.morphia.aggregation.stages.Unwind)"));

return new JavaIsoVisitor<>() {

return new JavaIsoVisitor<ExecutionContext>() {

@Override
public MethodInvocation visitMethodInvocation(MethodInvocation methodInvocation, @NotNull ExecutionContext context) {
return working(methodInvocation, context);
// return notWorking(methodInvocation, context);
}

public MethodInvocation working(MethodInvocation methodInvocation, @NotNull ExecutionContext context) {
if (matchers.stream().anyMatch(matcher -> matcher.matches(methodInvocation))) {
return super.visitMethodInvocation(methodInvocation
.withName(methodInvocation.getName().withSimpleName("pipeline")),
Expand All @@ -73,6 +86,41 @@ public MethodInvocation visitMethodInvocation(MethodInvocation methodInvocation,
return super.visitMethodInvocation(methodInvocation, context);
}
}

public MethodInvocation notWorking(MethodInvocation methodInvocation, @NotNull ExecutionContext context) {
if (matchers.stream().anyMatch(matcher -> matcher.matches(methodInvocation))) {
MethodInvocation mi = methodInvocation;
List<Expression> arguments = new ArrayList<>();
Expression expression = mi;
while (expression instanceof MethodInvocation invocation) {
arguments.add(invocation.getArguments().get(0));
mi = invocation;
expression = mi.getSelect();
}
Collections.reverse(arguments);

return applyTemplate(expression, arguments, expression);
} else {
return super.visitMethodInvocation(methodInvocation, context);

}
}

private J.MethodInvocation applyTemplate(Expression expression, List<Expression> arguments, Expression target) {
String code = buildTemplate(expression, arguments);
return JavaTemplate.builder(code)
.contextSensitive()
.javaParser(JavaParser.fromJavaVersion())
.build()
.apply(getCursor(), target.getCoordinates().replace());
}

private String buildTemplate(Expression toReplace, List<Expression> arguments) {
StringJoiner joiner = new StringJoiner(",\n\t", toReplace + ".pipeline(", ")");
arguments.forEach(argument -> joiner.add(argument.toString()));

return joiner.toString();
}
};
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package dev.morphia.rewrite.recipes;

import org.jetbrains.annotations.NotNull;
import org.openrewrite.ExecutionContext;
import org.openrewrite.java.JavaIsoVisitor;
import org.openrewrite.java.tree.Expression;
import org.openrewrite.java.tree.J.MethodInvocation;

class PipelineVisitor extends JavaIsoVisitor<ExecutionContext> {

@Override
public MethodInvocation visitMethodInvocation(MethodInvocation methodInvocation, @NotNull ExecutionContext context) {
if (PipelineRewrite.pipeline.matches(methodInvocation)) {
Expression select = methodInvocation.getSelect();
System.out.println("\n\nselect = " + select);
System.out.println("select.getSideEffects() = " + select.getSideEffects());
if (select instanceof MethodInvocation invocation) {
System.out.println("invocation.getArguments() = " + invocation.getArguments());
} else {
System.out.println("select.getType() = " + select.getType());
}
return super.visitMethodInvocation(methodInvocation, context);
}
return super.visitMethodInvocation(methodInvocation, context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ public class UnwrapTest {
public void update(Aggregation<?> aggregation) {
aggregation
.group(group(id("author")).field("count", sum(value(1))))
.sort(sort().ascending("_id"))
.sort(sort().ascending("1"))
.sort(sort().ascending("2"))
.sort(sort().ascending("3"))
.sort(sort().ascending("4"))
.execute(Document.class);
}
}
Expand All @@ -62,7 +65,10 @@ public class UnwrapTest {
public void update(Aggregation<?> aggregation) {
aggregation
.pipeline(group(id("author")).field("count", sum(value(1))))
.pipeline(sort().ascending("_id"))
.pipeline(sort().ascending("1"))
.pipeline(sort().ascending("2"))
.pipeline(sort().ascending("3"))
.pipeline(sort().ascending("4"))
.execute(Document.class);
}
}
Expand Down

0 comments on commit 25fba6b

Please sign in to comment.