Skip to content

Commit

Permalink
use terms in set for concrete IPs keep disjunctions over ranges.
Browse files Browse the repository at this point in the history
Signed-off-by: mikhail-khludnev <[email protected]>
  • Loading branch information
mkhludnev authored and mikhail-khludnev committed Oct 17, 2024
1 parent 421a1cc commit da89a31
Show file tree
Hide file tree
Showing 3 changed files with 216 additions and 16 deletions.
57 changes: 44 additions & 13 deletions server/src/main/java/org/opensearch/index/mapper/IpFieldMapper.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@
import org.apache.lucene.document.SortedSetDocValuesField;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.ConstantScoreQuery;
import org.apache.lucene.search.IndexOrDocValuesQuery;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.PointRangeQuery;
Expand All @@ -58,6 +61,7 @@
import java.io.IOException;
import java.net.InetAddress;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
Expand Down Expand Up @@ -263,29 +267,56 @@ public Query termQuery(Object value, @Nullable QueryShardContext context) {
return query;
}

// @TODO check strings, byterefs, inetaddresses for concrete and masks
@Override
public Query termsQuery(List<?> values, QueryShardContext context) {
failIfNotIndexedAndNoDocValues();
InetAddress[] addresses = new InetAddress[values.size()];
int i = 0;
for (Object value : values) {
InetAddress address;
List<InetAddress> concreteIPs = new ArrayList<>();
List<Query> ranges = new ArrayList<>();
for (final Object value : values) {
if (value instanceof InetAddress) {
address = (InetAddress) value;
concreteIPs.add((InetAddress) value);
} else {
if (value instanceof BytesRef) {
value = ((BytesRef) value).utf8ToString();
}
if (value.toString().contains("/")) {
final String strVal = (value instanceof BytesRef) ? ((BytesRef) value).utf8ToString() : value.toString();
if (strVal.contains("/")) {
// the `terms` query contains some prefix queries, so we cannot create a set query
// and need to fall back to a disjunction of `term` queries
return super.termsQuery(values, context);
Query query = termQuery(strVal, context);
// would be great to have union on ranges over bare points
ranges.add(query);
} else {
concreteIPs.add(InetAddresses.forString(strVal));
}
}
}
if (!concreteIPs.isEmpty()) {
Supplier<Query> pointsQuery;
pointsQuery = () -> concreteIPs.size() == 1
? InetAddressPoint.newExactQuery(name(), concreteIPs.iterator().next())
: InetAddressPoint.newSetQuery(name(), concreteIPs.toArray(new InetAddress[0]));
if (hasDocValues()) {
List<BytesRef> set = new ArrayList<>(concreteIPs.size());
for (final InetAddress address : concreteIPs) {
set.add(new BytesRef(InetAddressPoint.encode(address)));
}
Query dvQuery = SortedSetDocValuesField.newSlowSetQuery(name(), set);
if (!isSearchable()) {
pointsQuery = () -> dvQuery;
} else {
Supplier<Query> wrap = pointsQuery;
pointsQuery = () -> new IndexOrDocValuesQuery(wrap.get(), dvQuery);
}
address = InetAddresses.forString(value.toString());
}
addresses[i++] = address;
ranges.add(pointsQuery.get());
}
if (ranges.size() == 1) {
return ranges.iterator().next(); // CSQ?
}
BooleanQuery.Builder union = new BooleanQuery.Builder();
for (Query q : ranges) {
union.add(q, BooleanClause.Occur.SHOULD);
}
return InetAddressPoint.newSetQuery(name(), addresses);
return new ConstantScoreQuery(union.build());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ public void testTermQuery() {
}

public void testTermsQuery() {
MappedFieldType ft = new IpFieldMapper.IpFieldType("field");
MappedFieldType ft = new IpFieldMapper.IpFieldType("field", true, false, false, null, Collections.emptyMap());

assertEquals(
InetAddressPoint.newSetQuery("field", InetAddresses.forString("::2"), InetAddresses.forString("::5")),
Expand All @@ -131,8 +131,8 @@ public void testTermsQuery() {
// if the list includes a prefix query we fallback to a bool query
assertEquals(
new ConstantScoreQuery(
new BooleanQuery.Builder().add(ft.termQuery("::42", null), Occur.SHOULD)
.add(ft.termQuery("::2/16", null), Occur.SHOULD)
new BooleanQuery.Builder().add(ft.termQuery("::2/16", null), Occur.SHOULD)
.add(ft.termQuery("::42", null), Occur.SHOULD)
.build()
),
ft.termsQuery(Arrays.asList("::42", "::2/16"), null)
Expand Down
169 changes: 169 additions & 0 deletions server/src/test/java/org/opensearch/search/SearchIpFieldTermsTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.search;

import org.opensearch.action.bulk.BulkRequestBuilder;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.common.network.InetAddresses;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.test.OpenSearchSingleNodeTestCase;
import org.hamcrest.MatcherAssert;

import java.io.IOException;
import java.net.InetAddress;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import static org.opensearch.action.support.WriteRequest.RefreshPolicy.IMMEDIATE;
import static org.hamcrest.Matchers.equalTo;

public class SearchIpFieldTermsTest extends OpenSearchSingleNodeTestCase {

public static final boolean IPv4_ONLY = true;
static String defaultIndexName = "test";

public void testMassive() throws Exception {
XContentBuilder xcb = createMapping();
client().admin().indices().prepareCreate(defaultIndexName).setMapping(xcb).get();
ensureGreen();

BulkRequestBuilder bulkRequestBuilder = client().prepareBulk();

int cidrs = 0;
int ips = 0;
List<String> toQuery = new ArrayList<>();
for (int i = 0; ips <= 1024 && i < 1000000; i++) {
final String ip;
final int prefix;
if (IPv4_ONLY) {
ip = generateRandomIPv4();
prefix = 8 + random().nextInt(24); // CIDR prefix for IPv4
} else {
ip = generateRandomIPv6();
prefix = 32 + random().nextInt(97); // CIDR prefix for IPv6
}

bulkRequestBuilder.add(client().prepareIndex(defaultIndexName).setSource(Map.of("addr", ip)));

final String termToQuery;
if (cidrs < 1024 - 1 && random().nextBoolean()) {
termToQuery = ip + "/" + prefix;
cidrs++;
} else {
termToQuery = ip;
ips++;
}
toQuery.add(termToQuery);
}
int addMatches = 0;
for (int i = 0; i < atLeast(100); i++) {
final String ip;
if (IPv4_ONLY) {
ip = generateRandomIPv4();
} else {
ip = generateRandomIPv6();
}
bulkRequestBuilder.add(client().prepareIndex(defaultIndexName).setSource(Map.of("addr", ip)));
boolean match = false;
for (String termQ : toQuery) {
boolean isCidr = termQ.contains("/");
if ((isCidr && isIPInCIDR(ip, termQ)) || (!isCidr && termQ.equals(ip))) {
match = true;
break;
}
}
if (match) {
addMatches++;
} else {
break; // single mismatch is enough.
}
}

bulkRequestBuilder.setRefreshPolicy(IMMEDIATE).get();
SearchResponse result = client().prepareSearch(defaultIndexName).setQuery(QueryBuilders.termsQuery("addr", toQuery)).get();
MatcherAssert.assertThat(Objects.requireNonNull(result.getHits().getTotalHits()).value, equalTo((long) cidrs + ips + addMatches));
}

// Converts an IP string (either IPv4 or IPv6) to a byte array
private static byte[] ipToBytes(String ip) {
InetAddress inetAddress = InetAddresses.forString(ip);
return inetAddress.getAddress();
}

// Checks if an IP is within a given CIDR (works for both IPv4 and IPv6)
private static boolean isIPInCIDR(String ip, String cidr) {
String[] cidrParts = cidr.split("/");
String cidrIp = cidrParts[0];
int prefixLength = Integer.parseInt(cidrParts[1]);

byte[] ipBytes = ipToBytes(ip);
byte[] cidrIpBytes = ipToBytes(cidrIp);

// Calculate how many full bytes and how many bits are in the mask
int fullBytes = prefixLength / 8;
int extraBits = prefixLength % 8;

// Compare full bytes
for (int i = 0; i < fullBytes; i++) {
if (ipBytes[i] != cidrIpBytes[i]) {
return false;
}
}

// Compare extra bits (if any)
if (extraBits > 0) {
int mask = 0xFF << (8 - extraBits);
return (ipBytes[fullBytes] & mask) == (cidrIpBytes[fullBytes] & mask);
}

return true;
}

// Generate a random IPv4 address
private static String generateRandomIPv4() {
return String.format("%d.%d.%d.%d", random().nextInt(256), random().nextInt(256), random().nextInt(256), random().nextInt(256));
}

// Generate a random IPv6 address
private static String generateRandomIPv6() {
StringBuilder ipv6 = new StringBuilder();
for (int i = 0; i < 8; i++) {
ipv6.append(Integer.toHexString(random().nextInt(0xFFFF + 1)));
if (i < 7) {
ipv6.append(":");
}
}
return ipv6.toString();
}

private XContentBuilder createMapping() throws IOException {
return XContentFactory.jsonBuilder()
.startObject()
.startObject("properties")
.startObject("addr")
.field("type", "ip")
.startObject("fields")
.startObject("idx")
.field("type", "ip")
.field("doc_values", false)
.endObject()
.startObject("dv")
.field("type", "ip")
.field("index", false)
.endObject()
.endObject()
.endObject()
.endObject()
.endObject();
}
}

0 comments on commit da89a31

Please sign in to comment.