Skip to content

Commit

Permalink
use PrefixTree for :in queries (#1107)
Browse files Browse the repository at this point in the history
Updates the query index to leverage the prefix tree for
`:in` queries. The `:in` query will only be returned in
the result set if it is an exact match, so no further
checks are needed.
  • Loading branch information
brharrington authored Jan 4, 2024
1 parent 8061303 commit cf0b717
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 111 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,55 @@
* prefix or search key, then the prefix match will only check up to the unsupported
* character and the caller will need to perform further checks on the returned value.
*/
final class PrefixTree<T> {
final class PrefixTree {

private final Lock lock = new ReentrantLock();

private volatile Node root;
private final Set<T> values;
private final Set<Query.KeyQuery> otherQueries;

/** Create a new instance. */
PrefixTree() {
values = newSet();
this.otherQueries = newSet();
}

private Node addQuery(Node node, Query.KeyQuery query) {
if (query instanceof Query.In) {
Query.In q = (Query.In) query;
node.inQueries.add(q);
} else {
node.otherQueries.add(query);
}
return node;
}

private boolean removeQuery(Node node, Query.KeyQuery query) {
if (query instanceof Query.In) {
Query.In q = (Query.In) query;
return node.inQueries.remove(q);
} else {
return node.otherQueries.remove(query);
}
}

/**
* Put a query into the tree.
*
* @param query
* Query to add, the prefix will be extracted from the query clause.
*/
void put(Query.KeyQuery query) {
if (query instanceof Query.In) {
Query.In q = (Query.In) query;
for (String v : q.values()) {
put(v, q);
}
} else if (query instanceof Query.Regex) {
Query.Regex q = (Query.Regex) query;
put(q.pattern().prefix(), q);
} else {
otherQueries.add(query);
}
}

/**
Expand All @@ -54,15 +93,16 @@ final class PrefixTree<T> {
* @param value
* Value to associate with the prefix.
*/
void put(String prefix, T value) {
void put(String prefix, Query.KeyQuery value) {
if (prefix == null) {
values.add(value);
otherQueries.add(value);
} else {
lock.lock();
try {
Node node = root;
if (node == null) {
root = new Node(prefix, EMPTY, asSet(value));
root = new Node(prefix, EMPTY);
addQuery(root, value);
} else {
root = putImpl(node, prefix, 0, value);
}
Expand All @@ -72,17 +112,17 @@ void put(String prefix, T value) {
}
}

private Node putImpl(Node node, String key, int offset, T value) {
private Node putImpl(Node node, String key, int offset, Query.KeyQuery value) {
final int prefixLength = node.prefix.length();
final int keyLength = key.length() - offset;
final int commonLength = commonPrefixLength(node.prefix, key, offset);
if (commonLength == 0 && prefixLength > 0) {
// No common prefix
Node n = new Node(key.substring(offset), EMPTY, asSet(value));
Node n = addQuery(new Node(key.substring(offset), EMPTY), value);
return new Node("", new Node[] {n, node});
} else if (keyLength == prefixLength && commonLength == prefixLength) {
// Fully matches, add the value to this node
node.values.add(value);
addQuery(node, value);
return node;
} else if (keyLength > prefixLength && commonLength == prefixLength) {
// key.startsWith(prefix), put the value into a child
Expand All @@ -92,25 +132,49 @@ private Node putImpl(Node node, String key, int offset, T value) {
Node n = putImpl(node.children[pos], key, childOffset, value);
return node.replaceChild(n, pos);
} else {
Node n = new Node(key.substring(childOffset), EMPTY, asSet(value));
Node n = addQuery(new Node(key.substring(childOffset), EMPTY), value);
return node.addChild(n);
}
} else if (prefixLength > keyLength && commonLength == keyLength) {
// prefix.startsWith(key), make new parent node and add this node as a child
int childOffset = offset + commonLength;
Node n = new Node(node.prefix.substring(commonLength), node.children, node.values);
return new Node(key.substring(offset, childOffset), new Node[] {n}, asSet(value));
Node n = new Node(node.prefix.substring(commonLength), node.children, node.inQueries, node.otherQueries);
return addQuery(new Node(key.substring(offset, childOffset), new Node[] {n}), value);
} else {
// Common prefix is a subset of both
int childOffset = offset + commonLength;
Node[] children = {
new Node(node.prefix.substring(commonLength), node.children, node.values),
new Node(key.substring(childOffset), EMPTY, asSet(value))
new Node(node.prefix.substring(commonLength), node.children, node.inQueries, node.otherQueries),
addQuery(new Node(key.substring(childOffset), EMPTY), value)
};
return new Node(node.prefix.substring(0, commonLength), children);
}
}

/**
* Remove a value from the tree with the associated prefix.
*
* @param query
* Query to remove, the prefix will be extracted from the query clause.
* @return
* Returns true if a value was removed from the tree.
*/
boolean remove(Query.KeyQuery query) {
if (query instanceof Query.In) {
boolean removed = false;
Query.In q = (Query.In) query;
for (String v : q.values()) {
removed |= remove(v, q);
}
return removed;
} else if (query instanceof Query.Regex) {
Query.Regex q = (Query.Regex) query;
return remove(q.pattern().prefix(), q);
} else {
return otherQueries.remove(query);
}
}

/**
* Remove a value from the tree with the associated prefix.
*
Expand All @@ -121,9 +185,9 @@ private Node putImpl(Node node, String key, int offset, T value) {
* @return
* Returns true if a value was removed from the tree.
*/
boolean remove(String prefix, T value) {
boolean remove(String prefix, Query.KeyQuery value) {
if (prefix == null) {
return values.remove(value);
return otherQueries.remove(value);
} else {
lock.lock();
try {
Expand All @@ -143,13 +207,13 @@ boolean remove(String prefix, T value) {
}
}

private boolean removeImpl(Node node, String key, int offset, T value) {
private boolean removeImpl(Node node, String key, int offset, Query.KeyQuery value) {
final int prefixLength = node.prefix.length();
final int keyLength = key.length() - offset;
final int commonLength = commonPrefixLength(node.prefix, key, offset);
if (keyLength == prefixLength && commonLength == prefixLength) {
// Fully matches, remove the value from this node
return node.values.remove(value);
return removeQuery(node, value);
} else if (keyLength > prefixLength && commonLength == prefixLength) {
// Try to remove from children
int childOffset = offset + commonLength;
Expand All @@ -168,8 +232,8 @@ private boolean removeImpl(Node node, String key, int offset, T value) {
* @return
* Values associated with a matching prefix.
*/
List<T> get(String key) {
List<T> result = new ArrayList<>();
List<Query.KeyQuery> get(String key) {
List<Query.KeyQuery> result = new ArrayList<>();
forEach(key, result::add);
return result;
}
Expand All @@ -182,25 +246,23 @@ List<T> get(String key) {
* @param consumer
* Function to call for matching values.
*/
void forEach(String key, Consumer<T> consumer) {
values.forEach(consumer);
void forEach(String key, Consumer<Query.KeyQuery> consumer) {
// In queries cannot have an empty value, so cannot be in the root set
otherQueries.forEach(consumer);
Node node = root;
if (node != null) {
forEachImpl(node, key, 0, consumer);
}
}

@SuppressWarnings("unchecked")
private void forEachImpl(Node node, String key, int offset, Consumer<T> consumer) {
private void forEachImpl(Node node, String key, int offset, Consumer<Query.KeyQuery> consumer) {
final int prefixLength = node.prefix.length();
final int keyLength = key.length() - offset;
final int commonLength = commonPrefixLength(node.prefix, key, offset);

if (commonLength == prefixLength) {
// Prefix matches, consume values for this node
for (Object value : node.values) {
consumer.accept((T) value);
}
// Prefix matches, consume other queries
node.otherQueries.forEach(consumer);

if (commonLength < keyLength) {
// There is more to the key, check if there are also matches for child nodes
Expand All @@ -209,6 +271,9 @@ private void forEachImpl(Node node, String key, int offset, Consumer<T> consumer
if (pos >= 0) {
forEachImpl(node.children[pos], key, childOffset, consumer);
}
} else {
// It is an exact match, consume in queries
node.inQueries.forEach(consumer);
}
}
}
Expand All @@ -217,7 +282,7 @@ private void forEachImpl(Node node, String key, int offset, Consumer<T> consumer
* Returns true if the tree is empty.
*/
boolean isEmpty() {
return values.isEmpty() && (root == null || root.isEmpty());
return otherQueries.isEmpty() && (root == null || root.isEmpty());
}

/**
Expand All @@ -226,7 +291,7 @@ boolean isEmpty() {
*/
int size() {
Node r = root;
return (r == null ? 0 : r.size()) + values.size();
return (r == null ? 0 : r.size()) + otherQueries.size();
}

/**
Expand Down Expand Up @@ -273,8 +338,8 @@ private static <T> Set<T> newSet() {
return new CopyOnWriteArraySet<>();
}

private static Set<Object> asSet(Object value) {
Set<Object> set = newSet();
private static Set<Query.KeyQuery> asSet(Query.KeyQuery value) {
Set<Query.KeyQuery> set = newSet();
set.add(value);
return set;
}
Expand All @@ -285,46 +350,48 @@ private static Set<Object> asSet(Object value) {
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
PrefixTree<?> that = (PrefixTree<?>) o;
PrefixTree that = (PrefixTree) o;
return Objects.equals(root, that.root)
&& values.equals(that.values);
&& otherQueries.equals(that.otherQueries);
}

@Override
public int hashCode() {
return Objects.hash(root, values);
return Objects.hash(root, otherQueries);
}

private static class Node {

final String prefix;
final Node[] children;
final Set<Object> values;
final Set<Query.In> inQueries;
final Set<Query.KeyQuery> otherQueries;

Node(String prefix, Node[] children, Set<Object> values) {
Node(String prefix, Node[] children, Set<Query.In> inQueries, Set<Query.KeyQuery> otherQueries) {
this.prefix = Preconditions.checkNotNull(prefix, "prefix");
this.children = Preconditions.checkNotNull(children, "children");
this.values = Preconditions.checkNotNull(values, "values");
this.inQueries = Preconditions.checkNotNull(inQueries, "inQueries");
this.otherQueries = Preconditions.checkNotNull(otherQueries, "otherQueries");
Arrays.sort(children, Comparator.comparing(n -> n.prefix));
}

Node(String prefix, Node[] children) {
this(prefix, children, newSet());
this(prefix, children, newSet(), newSet());
}

Node replaceChild(Node n, int i) {
Node[] cs = new Node[children.length];
System.arraycopy(children, 0, cs, 0, i);
cs[i] = n;
System.arraycopy(children, i + 1, cs, i + 1, children.length - i - 1);
return new Node(prefix, cs, values);
return new Node(prefix, cs, inQueries, otherQueries);
}

Node addChild(Node n) {
Node[] cs = new Node[children.length + 1];
System.arraycopy(children, 0, cs, 0, children.length);
cs[children.length] = n;
return new Node(prefix, cs, values);
return new Node(prefix, cs, inQueries, otherQueries);
}

Node compress() {
Expand Down Expand Up @@ -352,17 +419,17 @@ Node compress() {
// Return compressed node. Merge nodes if intermediates have no values.
if (cs == null) {
return this;
} else if (values.isEmpty() && cs.size() == 1) {
} else if (inQueries.isEmpty() && otherQueries.isEmpty() && cs.size() == 1) {
Node c = cs.get(0);
String p = prefix + c.prefix;
return new Node(p, EMPTY, c.values);
return new Node(p, EMPTY, c.inQueries, c.otherQueries);
} else {
return new Node(prefix, cs.toArray(EMPTY), values);
return new Node(prefix, cs.toArray(EMPTY), inQueries, otherQueries);
}
}

boolean isEmpty() {
return values.isEmpty() && areAllChildrenEmpty();
return inQueries.isEmpty() && otherQueries.isEmpty() && areAllChildrenEmpty();
}

private boolean areAllChildrenEmpty() {
Expand All @@ -375,7 +442,7 @@ private boolean areAllChildrenEmpty() {
}

int size() {
int sz = values.size();
int sz = inQueries.size() + otherQueries.size();
for (Node child : children) {
sz += child.size();
}
Expand All @@ -389,12 +456,13 @@ public boolean equals(Object o) {
Node node = (Node) o;
return prefix.equals(node.prefix)
&& Arrays.equals(children, node.children)
&& values.equals(node.values);
&& inQueries.equals(node.inQueries)
&& otherQueries.equals(node.otherQueries);
}

@Override
public int hashCode() {
int result = Objects.hash(prefix, values);
int result = Objects.hash(prefix, inQueries, otherQueries);
result = 31 * result + Arrays.hashCode(children);
return result;
}
Expand Down
Loading

0 comments on commit cf0b717

Please sign in to comment.