Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use PrefixTree for :in queries #1107

Merged
merged 1 commit into from
Jan 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,55 @@
* prefix or search key, then the prefix match will only check up to the unsupported
* character and the caller will need to perform further checks on the returned value.
*/
final class PrefixTree<T> {
final class PrefixTree {

private final Lock lock = new ReentrantLock();

private volatile Node root;
private final Set<T> values;
private final Set<Query.KeyQuery> otherQueries;

/** Create a new instance. */
PrefixTree() {
values = newSet();
this.otherQueries = newSet();
}

/**
 * Attach a query to a node, routing it to the set that matches its type.
 *
 * @param node
 *     Node that should hold the query.
 * @param query
 *     Query to store. {@code Query.In} instances go into the node's
 *     {@code inQueries} set, everything else into {@code otherQueries}.
 * @return
 *     The same node that was passed in, so calls can be chained.
 */
private Node addQuery(Node node, Query.KeyQuery query) {
  final boolean isInQuery = query instanceof Query.In;
  if (!isInQuery) {
    node.otherQueries.add(query);
    return node;
  }
  node.inQueries.add((Query.In) query);
  return node;
}

/**
 * Detach a query from a node, looking in the set that matches its type.
 *
 * @param node
 *     Node to remove the query from.
 * @param query
 *     Query to remove. {@code Query.In} instances are removed from the node's
 *     {@code inQueries} set, everything else from {@code otherQueries}.
 * @return
 *     True if the selected set contained the query and it was removed.
 */
private boolean removeQuery(Node node, Query.KeyQuery query) {
  return (query instanceof Query.In)
      ? node.inQueries.remove((Query.In) query)
      : node.otherQueries.remove(query);
}

/**
 * Add a query to the tree, deriving the key(s) to index it under from the
 * query itself.
 *
 * <ul>
 *   <li>{@code Query.In}: indexed once under each of its values.</li>
 *   <li>{@code Query.Regex}: indexed under the prefix of its pattern.</li>
 *   <li>Any other query: stored in the tree-level {@code otherQueries} set.</li>
 * </ul>
 *
 * @param query
 *     Query to add, the prefix will be extracted from the query clause.
 */
void put(Query.KeyQuery query) {
  if (query instanceof Query.In) {
    final Query.In inQuery = (Query.In) query;
    for (String candidate : inQuery.values()) {
      put(candidate, inQuery);
    }
    return;
  }

  if (query instanceof Query.Regex) {
    final Query.Regex regexQuery = (Query.Regex) query;
    put(regexQuery.pattern().prefix(), regexQuery);
    return;
  }

  // No usable prefix can be extracted from this query type.
  otherQueries.add(query);
}

/**
Expand All @@ -54,15 +93,16 @@ final class PrefixTree<T> {
* @param value
* Value to associate with the prefix.
*/
void put(String prefix, T value) {
void put(String prefix, Query.KeyQuery value) {
if (prefix == null) {
values.add(value);
otherQueries.add(value);
} else {
lock.lock();
try {
Node node = root;
if (node == null) {
root = new Node(prefix, EMPTY, asSet(value));
root = new Node(prefix, EMPTY);
addQuery(root, value);
} else {
root = putImpl(node, prefix, 0, value);
}
Expand All @@ -72,17 +112,17 @@ void put(String prefix, T value) {
}
}

private Node putImpl(Node node, String key, int offset, T value) {
private Node putImpl(Node node, String key, int offset, Query.KeyQuery value) {
final int prefixLength = node.prefix.length();
final int keyLength = key.length() - offset;
final int commonLength = commonPrefixLength(node.prefix, key, offset);
if (commonLength == 0 && prefixLength > 0) {
// No common prefix
Node n = new Node(key.substring(offset), EMPTY, asSet(value));
Node n = addQuery(new Node(key.substring(offset), EMPTY), value);
return new Node("", new Node[] {n, node});
} else if (keyLength == prefixLength && commonLength == prefixLength) {
// Fully matches, add the value to this node
node.values.add(value);
addQuery(node, value);
return node;
} else if (keyLength > prefixLength && commonLength == prefixLength) {
// key.startsWith(prefix), put the value into a child
Expand All @@ -92,25 +132,49 @@ private Node putImpl(Node node, String key, int offset, T value) {
Node n = putImpl(node.children[pos], key, childOffset, value);
return node.replaceChild(n, pos);
} else {
Node n = new Node(key.substring(childOffset), EMPTY, asSet(value));
Node n = addQuery(new Node(key.substring(childOffset), EMPTY), value);
return node.addChild(n);
}
} else if (prefixLength > keyLength && commonLength == keyLength) {
// prefix.startsWith(key), make new parent node and add this node as a child
int childOffset = offset + commonLength;
Node n = new Node(node.prefix.substring(commonLength), node.children, node.values);
return new Node(key.substring(offset, childOffset), new Node[] {n}, asSet(value));
Node n = new Node(node.prefix.substring(commonLength), node.children, node.inQueries, node.otherQueries);
return addQuery(new Node(key.substring(offset, childOffset), new Node[] {n}), value);
} else {
// Common prefix is a subset of both
int childOffset = offset + commonLength;
Node[] children = {
new Node(node.prefix.substring(commonLength), node.children, node.values),
new Node(key.substring(childOffset), EMPTY, asSet(value))
new Node(node.prefix.substring(commonLength), node.children, node.inQueries, node.otherQueries),
addQuery(new Node(key.substring(childOffset), EMPTY), value)
};
return new Node(node.prefix.substring(0, commonLength), children);
}
}

/**
 * Remove a query from the tree, deriving the key(s) it was indexed under
 * from the query itself.
 *
 * <ul>
 *   <li>{@code Query.In}: removal is attempted for each of its values.</li>
 *   <li>{@code Query.Regex}: removed using the prefix of its pattern.</li>
 *   <li>Any other query: removed from the tree-level {@code otherQueries} set.</li>
 * </ul>
 *
 * @param query
 *     Query to remove, the prefix will be extracted from the query clause.
 * @return
 *     Returns true if a value was removed from the tree. For an
 *     {@code Query.In} this is true if removal succeeded for at least one value.
 */
boolean remove(Query.KeyQuery query) {
  if (query instanceof Query.In) {
    final Query.In inQuery = (Query.In) query;
    boolean anyRemoved = false;
    for (String candidate : inQuery.values()) {
      // Every value is attempted, even after an earlier one succeeded.
      if (remove(candidate, inQuery)) {
        anyRemoved = true;
      }
    }
    return anyRemoved;
  }

  if (query instanceof Query.Regex) {
    final Query.Regex regexQuery = (Query.Regex) query;
    return remove(regexQuery.pattern().prefix(), regexQuery);
  }

  return otherQueries.remove(query);
}

/**
* Remove a value from the tree with the associated prefix.
*
Expand All @@ -121,9 +185,9 @@ private Node putImpl(Node node, String key, int offset, T value) {
* @return
* Returns true if a value was removed from the tree.
*/
boolean remove(String prefix, T value) {
boolean remove(String prefix, Query.KeyQuery value) {
if (prefix == null) {
return values.remove(value);
return otherQueries.remove(value);
} else {
lock.lock();
try {
Expand All @@ -143,13 +207,13 @@ boolean remove(String prefix, T value) {
}
}

private boolean removeImpl(Node node, String key, int offset, T value) {
private boolean removeImpl(Node node, String key, int offset, Query.KeyQuery value) {
final int prefixLength = node.prefix.length();
final int keyLength = key.length() - offset;
final int commonLength = commonPrefixLength(node.prefix, key, offset);
if (keyLength == prefixLength && commonLength == prefixLength) {
// Fully matches, remove the value from this node
return node.values.remove(value);
return removeQuery(node, value);
} else if (keyLength > prefixLength && commonLength == prefixLength) {
// Try to remove from children
int childOffset = offset + commonLength;
Expand All @@ -168,8 +232,8 @@ private boolean removeImpl(Node node, String key, int offset, T value) {
* @return
* Values associated with a matching prefix.
*/
List<T> get(String key) {
List<T> result = new ArrayList<>();
List<Query.KeyQuery> get(String key) {
List<Query.KeyQuery> result = new ArrayList<>();
forEach(key, result::add);
return result;
}
Expand All @@ -182,25 +246,23 @@ List<T> get(String key) {
* @param consumer
* Function to call for matching values.
*/
void forEach(String key, Consumer<T> consumer) {
values.forEach(consumer);
void forEach(String key, Consumer<Query.KeyQuery> consumer) {
// In queries cannot have an empty value, so cannot be in the root set
otherQueries.forEach(consumer);
Node node = root;
if (node != null) {
forEachImpl(node, key, 0, consumer);
}
}

@SuppressWarnings("unchecked")
private void forEachImpl(Node node, String key, int offset, Consumer<T> consumer) {
private void forEachImpl(Node node, String key, int offset, Consumer<Query.KeyQuery> consumer) {
final int prefixLength = node.prefix.length();
final int keyLength = key.length() - offset;
final int commonLength = commonPrefixLength(node.prefix, key, offset);

if (commonLength == prefixLength) {
// Prefix matches, consume values for this node
for (Object value : node.values) {
consumer.accept((T) value);
}
// Prefix matches, consume other queries
node.otherQueries.forEach(consumer);

if (commonLength < keyLength) {
// There is more to the key, check if there are also matches for child nodes
Expand All @@ -209,6 +271,9 @@ private void forEachImpl(Node node, String key, int offset, Consumer<T> consumer
if (pos >= 0) {
forEachImpl(node.children[pos], key, childOffset, consumer);
}
} else {
// It is an exact match, consume in queries
node.inQueries.forEach(consumer);
}
}
}
Expand All @@ -217,7 +282,7 @@ private void forEachImpl(Node node, String key, int offset, Consumer<T> consumer
* Returns true if the tree is empty.
*/
boolean isEmpty() {
return values.isEmpty() && (root == null || root.isEmpty());
return otherQueries.isEmpty() && (root == null || root.isEmpty());
}

/**
Expand All @@ -226,7 +291,7 @@ boolean isEmpty() {
*/
int size() {
Node r = root;
return (r == null ? 0 : r.size()) + values.size();
return (r == null ? 0 : r.size()) + otherQueries.size();
}

/**
Expand Down Expand Up @@ -273,8 +338,8 @@ private static <T> Set<T> newSet() {
return new CopyOnWriteArraySet<>();
}

private static Set<Object> asSet(Object value) {
Set<Object> set = newSet();
private static Set<Query.KeyQuery> asSet(Query.KeyQuery value) {
Set<Query.KeyQuery> set = newSet();
set.add(value);
return set;
}
Expand All @@ -285,46 +350,48 @@ private static Set<Object> asSet(Object value) {
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
PrefixTree<?> that = (PrefixTree<?>) o;
PrefixTree that = (PrefixTree) o;
return Objects.equals(root, that.root)
&& values.equals(that.values);
&& otherQueries.equals(that.otherQueries);
}

@Override
public int hashCode() {
return Objects.hash(root, values);
return Objects.hash(root, otherQueries);
}

private static class Node {

final String prefix;
final Node[] children;
final Set<Object> values;
final Set<Query.In> inQueries;
final Set<Query.KeyQuery> otherQueries;

Node(String prefix, Node[] children, Set<Object> values) {
Node(String prefix, Node[] children, Set<Query.In> inQueries, Set<Query.KeyQuery> otherQueries) {
this.prefix = Preconditions.checkNotNull(prefix, "prefix");
this.children = Preconditions.checkNotNull(children, "children");
this.values = Preconditions.checkNotNull(values, "values");
this.inQueries = Preconditions.checkNotNull(inQueries, "inQueries");
this.otherQueries = Preconditions.checkNotNull(otherQueries, "otherQueries");
Arrays.sort(children, Comparator.comparing(n -> n.prefix));
}

Node(String prefix, Node[] children) {
this(prefix, children, newSet());
this(prefix, children, newSet(), newSet());
}

Node replaceChild(Node n, int i) {
Node[] cs = new Node[children.length];
System.arraycopy(children, 0, cs, 0, i);
cs[i] = n;
System.arraycopy(children, i + 1, cs, i + 1, children.length - i - 1);
return new Node(prefix, cs, values);
return new Node(prefix, cs, inQueries, otherQueries);
}

Node addChild(Node n) {
Node[] cs = new Node[children.length + 1];
System.arraycopy(children, 0, cs, 0, children.length);
cs[children.length] = n;
return new Node(prefix, cs, values);
return new Node(prefix, cs, inQueries, otherQueries);
}

Node compress() {
Expand Down Expand Up @@ -352,17 +419,17 @@ Node compress() {
// Return compressed node. Merge nodes if intermediates have no values.
if (cs == null) {
return this;
} else if (values.isEmpty() && cs.size() == 1) {
} else if (inQueries.isEmpty() && otherQueries.isEmpty() && cs.size() == 1) {
Node c = cs.get(0);
String p = prefix + c.prefix;
return new Node(p, EMPTY, c.values);
return new Node(p, EMPTY, c.inQueries, c.otherQueries);
} else {
return new Node(prefix, cs.toArray(EMPTY), values);
return new Node(prefix, cs.toArray(EMPTY), inQueries, otherQueries);
}
}

boolean isEmpty() {
return values.isEmpty() && areAllChildrenEmpty();
return inQueries.isEmpty() && otherQueries.isEmpty() && areAllChildrenEmpty();
}

private boolean areAllChildrenEmpty() {
Expand All @@ -375,7 +442,7 @@ private boolean areAllChildrenEmpty() {
}

int size() {
int sz = values.size();
int sz = inQueries.size() + otherQueries.size();
for (Node child : children) {
sz += child.size();
}
Expand All @@ -389,12 +456,13 @@ public boolean equals(Object o) {
Node node = (Node) o;
return prefix.equals(node.prefix)
&& Arrays.equals(children, node.children)
&& values.equals(node.values);
&& inQueries.equals(node.inQueries)
&& otherQueries.equals(node.otherQueries);
}

@Override
public int hashCode() {
int result = Objects.hash(prefix, values);
int result = Objects.hash(prefix, inQueries, otherQueries);
result = 31 * result + Arrays.hashCode(children);
return result;
}
Expand Down
Loading