Skip to content

Commit

Permalink
Added support for includes/excludes source (#1023)
Browse files Browse the repository at this point in the history
* Added support for includes/excluses source

* Fixed comment grammar

* PR fixes

* Fixed tests
  • Loading branch information
kyle-sammons authored Aug 12, 2024
1 parent 0cf0a22 commit f381f04
Show file tree
Hide file tree
Showing 27 changed files with 957 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,8 @@ public SearchResult<T> query(SearchQuery query) {
searchEndTime,
query.howMany,
query.aggBuilder,
query.queryBuilder);
query.queryBuilder,
query.sourceFieldFilter);
} else {
return (SearchResult<T>) SearchResult.empty();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,8 @@ public SearchResult<T> query(SearchQuery query) {
query.endTimeEpochMs,
query.howMany,
query.aggBuilder,
query.queryBuilder);
query.queryBuilder,
query.sourceFieldFilter);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.google.common.collect.Iterators;
import com.google.common.collect.Lists;
import com.slack.astra.logstore.search.SearchResultUtils;
Expand Down Expand Up @@ -64,11 +65,82 @@ public List<AstraSearch.SearchRequest> parseHttpPostBody(String postBody)
.setEndTimeEpochMs(getEndTimeEpochMs(body))
.setAggregations(getAggregations(body))
.setQuery(getQuery(body))
.setSourceFieldFilter(getSourceFieldFilter(body))
.build());
}
return searchRequests;
}

private static AstraSearch.SearchRequest.SourceFieldFilter getSourceFieldFilter(JsonNode body) {
if (body.has("_source") && body.get("_source") != null) {
JsonNode sourceNode = body.get("_source");
if (sourceNode.isBoolean()) {
return AstraSearch.SearchRequest.SourceFieldFilter.newBuilder()
.setIncludeAll(sourceNode.booleanValue())
.build();

} else if (sourceNode.isTextual()) {
return AstraSearch.SearchRequest.SourceFieldFilter.newBuilder()
.addIncludeWildcards(sourceNode.textValue())
.build();
} else if (sourceNode.isArray()) {
ArrayNode includeArrayNode = (ArrayNode) sourceNode;
HashMap<String, Boolean> includes = new HashMap<>();

AstraSearch.SearchRequest.SourceFieldFilter.Builder fieldInclusionBuilder =
AstraSearch.SearchRequest.SourceFieldFilter.newBuilder();

for (JsonNode jsonNode : includeArrayNode) {
String fieldname = jsonNode.asText();
if (fieldname.contains("*")) {
fieldInclusionBuilder.addIncludeWildcards(fieldname);
} else {
includes.put(fieldname, true);
}
}

return fieldInclusionBuilder.putAllIncludeFields(includes).build();

} else if (sourceNode.isObject()) {
AstraSearch.SearchRequest.SourceFieldFilter.Builder sourceFieldFilterBuilder =
AstraSearch.SearchRequest.SourceFieldFilter.newBuilder();

if (sourceNode.has("includes")) {
ArrayNode includeArrayNode = (ArrayNode) sourceNode.get("includes");
HashMap<String, Boolean> includes = new HashMap<>();

for (JsonNode jsonNode : includeArrayNode) {
String fieldname = jsonNode.asText();
if (fieldname.contains("*")) {
sourceFieldFilterBuilder.addIncludeWildcards(fieldname);
} else {
includes.put(fieldname, true);
}
}

sourceFieldFilterBuilder.putAllIncludeFields(includes);
}

if (sourceNode.has("excludes")) {
ArrayNode includeArrayNode = (ArrayNode) sourceNode.get("excludes");
HashMap<String, Boolean> excludes = new HashMap<>();

for (JsonNode jsonNode : includeArrayNode) {
String fieldname = jsonNode.asText();
if (fieldname.contains("*")) {
sourceFieldFilterBuilder.addExcludeWildcards(fieldname);
} else {
excludes.put(fieldname, true);
}
}
sourceFieldFilterBuilder.putAllExcludeFields(excludes);
}
return sourceFieldFilterBuilder.build();
}
}
return AstraSearch.SearchRequest.SourceFieldFilter.newBuilder().build();
}

private static String getQuery(JsonNode body) {
if (!body.get("query").isNull() && !body.get("query").isEmpty()) {
return body.get("query").toString();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@ SearchResult<T> search(
Long maxTime,
int howMany,
AggBuilder aggBuilder,
QueryBuilder queryBuilder);
QueryBuilder queryBuilder,
SourceFieldFilter sourceFieldFilter);
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MultiCollectorManager;
Expand Down Expand Up @@ -53,6 +55,8 @@ public class LogIndexSearcherImpl implements LogIndexSearcher<LogMessage> {

private final ReferenceManager.RefreshListener refreshListener;

private final boolean allowIncludeAndExcludeSource;

@VisibleForTesting
public static SearcherManager searcherManagerFromPath(Path path) throws IOException {
MMapDirectory directory = new MMapDirectory(path);
Expand All @@ -79,6 +83,9 @@ public void afterRefresh(boolean didRefresh) {

// initialize the adapter with whatever the default schema is
openSearchAdapter.reloadSchema();
allowIncludeAndExcludeSource =
Boolean.parseBoolean(
System.getProperty("astra.query.allowIncludeAndExcludeSource", "false"));
}

@Override
Expand All @@ -89,7 +96,8 @@ public SearchResult<LogMessage> search(
Long endTimeMsEpoch,
int howMany,
AggBuilder aggBuilder,
QueryBuilder queryBuilder) {
QueryBuilder queryBuilder,
SourceFieldFilter sourceFieldFilter) {

ensureNonEmptyString(dataset, "dataset should be a non-empty string");
ensureNonNullString(queryStr, "query should be a non-empty string");
Expand Down Expand Up @@ -138,7 +146,7 @@ public SearchResult<LogMessage> search(
ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;
results = new ArrayList<>(hits.length);
for (ScoreDoc hit : hits) {
results.add(buildLogMessage(searcher, hit));
results.add(buildLogMessage(searcher, hit, sourceFieldFilter));
}
if (aggBuilder != null) {
internalAggregation = (InternalAggregation) collector[1];
Expand All @@ -164,17 +172,36 @@ public SearchResult<LogMessage> search(
}
}

private LogMessage buildLogMessage(IndexSearcher searcher, ScoreDoc hit) {
private LogMessage buildLogMessage(
IndexSearcher searcher, ScoreDoc hit, SourceFieldFilter sourceFieldFilter) {
String s = "";
try {
s = searcher.doc(hit.doc).get(SystemField.SOURCE.fieldName);
LogWireMessage wireMessage = JsonUtil.read(s, LogWireMessage.class);
Map<String, Object> source = wireMessage.getSource();

if (allowIncludeAndExcludeSource
&& sourceFieldFilter != null
&& sourceFieldFilter.getFilterType() == SourceFieldFilter.FilterType.INCLUDE) {
source =
wireMessage.getSource().keySet().stream()
.filter(sourceFieldFilter::appliesToField)
.collect(Collectors.toMap((key) -> key, (key) -> wireMessage.getSource().get(key)));
} else if (allowIncludeAndExcludeSource
&& sourceFieldFilter != null
&& sourceFieldFilter.getFilterType() == SourceFieldFilter.FilterType.EXCLUDE) {
source =
wireMessage.getSource().keySet().stream()
.filter((key) -> !sourceFieldFilter.appliesToField(key))
.collect(Collectors.toMap((key) -> key, (key) -> wireMessage.getSource().get(key)));
}

return new LogMessage(
wireMessage.getIndex(),
wireMessage.getType(),
wireMessage.getId(),
wireMessage.getTimestamp(),
wireMessage.getSource());
source);
} catch (Exception e) {
throw new IllegalStateException("Error fetching and parsing a result from index: " + s, e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ public class SearchQuery {
public final int howMany;
public final AggBuilder aggBuilder;
public final List<String> chunkIds;
public final SourceFieldFilter sourceFieldFilter;

public SearchQuery(
String dataset,
Expand All @@ -25,7 +26,8 @@ public SearchQuery(
int howMany,
AggBuilder aggBuilder,
List<String> chunkIds,
QueryBuilder queryBuilder) {
QueryBuilder queryBuilder,
SourceFieldFilter sourceFieldFilter) {
this.dataset = dataset;
this.queryStr = queryStr;
this.startTimeEpochMs = startTimeEpochMs;
Expand All @@ -34,6 +36,7 @@ public SearchQuery(
this.aggBuilder = aggBuilder;
this.chunkIds = chunkIds;
this.queryBuilder = queryBuilder;
this.sourceFieldFilter = sourceFieldFilter;
}

@Override
Expand All @@ -57,6 +60,8 @@ public String toString() {
+ aggBuilder
+ ", queryBuilder="
+ queryBuilder
+ ", sourceFieldFilter="
+ sourceFieldFilter
+ '}';
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -691,7 +691,8 @@ public static SearchQuery fromSearchRequest(AstraSearch.SearchRequest searchRequ
searchRequest.getHowMany(),
fromSearchAggregations(searchRequest.getAggregations()),
searchRequest.getChunkIdsList(),
queryBuilder);
queryBuilder,
SourceFieldFilter.fromProto(searchRequest.getSourceFieldFilter()));
}

public static SearchResult<LogMessage> fromSearchResultProtoOrEmpty(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package com.slack.astra.logstore.search;

import com.slack.astra.proto.service.AstraSearch;
import java.util.List;
import java.util.Map;
import java.util.Optional;

public class SourceFieldFilter {
private final Optional<Boolean> includeAll;
private final Optional<Boolean> excludeAll;
private final Map<String, Boolean> includeFields;
private final Map<String, Boolean> excludeFields;
private final List<String> includeWildcards;
private final List<String> excludeWildcards;
private final FilterType filterType;

public enum FilterType {
INCLUDE,
EXCLUDE
}

public SourceFieldFilter(
Optional<Boolean> includeAll,
Optional<Boolean> excludeAll,
Map<String, Boolean> includeFields,
Map<String, Boolean> excludeFields,
List<String> includeWildcards,
List<String> excludeWildcards) {
this.includeAll = includeAll;
this.excludeAll = excludeAll;
this.includeFields = includeFields;
this.excludeFields = excludeFields;
this.includeWildcards = includeWildcards;
this.excludeWildcards = excludeWildcards;

if (includeAll.isPresent() || !includeFields.isEmpty() || !includeWildcards.isEmpty()) {
this.filterType = FilterType.INCLUDE;
} else {
this.filterType = FilterType.EXCLUDE;
}
}

public static SourceFieldFilter fromProto(
AstraSearch.SearchRequest.SourceFieldFilter sourceFieldFilterProto) {
Optional<Boolean> includeAll = Optional.empty();
Optional<Boolean> excludeAll = Optional.empty();
if (sourceFieldFilterProto.hasIncludeAll()) {
includeAll = Optional.of(sourceFieldFilterProto.getIncludeAll());
}
if (sourceFieldFilterProto.hasExcludeAll()) {
excludeAll = Optional.of(sourceFieldFilterProto.getExcludeAll());
}

return new SourceFieldFilter(
includeAll,
excludeAll,
sourceFieldFilterProto.getIncludeFieldsMap(),
sourceFieldFilterProto.getExcludeFieldsMap(),
sourceFieldFilterProto.getIncludeWildcardsList(),
sourceFieldFilterProto.getExcludeWildcardsList());
}

public boolean appliesToField(String fieldname) {
Optional<Boolean> all =
this.filterType == FilterType.INCLUDE ? this.includeAll : this.excludeAll;
Map<String, Boolean> fields =
this.filterType == FilterType.INCLUDE ? this.includeFields : this.excludeFields;
List<String> wildcards =
this.filterType == FilterType.INCLUDE ? this.includeWildcards : this.excludeWildcards;

if (all.isPresent()) {
return all.get();
}

if (fields.containsKey(fieldname)) {
return fields.get(fieldname);
}

for (String wildcard : wildcards) {
if (fieldname.matches(wildcard)) {
return true;
}
}

return false;
}

public List<String> getExcludeWildcards() {
return excludeWildcards;
}

public List<String> getIncludeWildcards() {
return includeWildcards;
}

public Map<String, Boolean> getExcludeFields() {
return excludeFields;
}

public Map<String, Boolean> getIncludeFields() {
return includeFields;
}

public Optional<Boolean> getExcludeAll() {
return excludeAll;
}

public Optional<Boolean> getIncludeAll() {
return includeAll;
}

public FilterType getFilterType() {
return filterType;
}
}
17 changes: 17 additions & 0 deletions astra/src/main/proto/astra_search.proto
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,23 @@ message SearchRequest {
// The fully query object, to be used in the leaf nodes and parsed by OpenSearch
string query = 8;

// An explicit list of fields to include/exclude return per document
SourceFieldFilter source_field_filter = 9;

message SourceFieldFilter {
// Whether or not to include/exclude _all_ fields
optional bool includeAll = 1;
optional bool excludeAll = 2;

// What fields to explicitly include/exclude
map<string, bool> includeFields = 3;
map<string, bool> excludeFields = 4;

// Wildcarded field names to include/exclude
repeated string includeWildcards = 5;
repeated string excludeWildcards = 6;
}

message SearchAggregation {
// The type of aggregation (ie, avg, date_histogram, etc)
string type = 1;
Expand Down
Loading

0 comments on commit f381f04

Please sign in to comment.