Initial streaming aggs commit
Signed-off-by: Marc Handalian <[email protected]>
mch2 committed Jan 15, 2025
1 parent b076dd6 commit 364edbe
Showing 8 changed files with 327 additions and 46 deletions.
@@ -689,6 +689,82 @@ static int getTopDocsSize(SearchRequest request) {
: source.from());
}

public ReducedQueryPhase reducedAggsFromStream(List<StreamSearchResult> list) {
try (SessionContext context = new SessionContext()) {

List<byte[]> tickets = list.stream().flatMap(r -> r.getFlightTickets().stream())
.map(OSTicket::getBytes)
.collect(Collectors.toList());

// execute the query and get a dataframe
CompletableFuture<DataFrame> frame = DataFusion.query(tickets);

DataFrame dataFrame = null;
ArrowReader arrowReader = null;
try {
dataFrame = frame.get();
// NOTE: the RootAllocator created inline here is not explicitly closed anywhere in this method
arrowReader = dataFrame.collect(new RootAllocator()).get();
} catch (InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
}

int totalRows = 0;
List<ScoreDoc> scoreDocs = new ArrayList<>();
try {
while (arrowReader.loadNextBatch()) {
VectorSchemaRoot root = arrowReader.getVectorSchemaRoot();
int rowCount = root.getRowCount();
totalRows += rowCount;
System.out.println("Record Batch with " + rowCount + " rows:");

// Fetch the vectors once per batch, then iterate through the rows
FieldVector docID = root.getVector("docID");
Float4Vector score = (Float4Vector) root.getVector("score");
FieldVector shardID = root.getVector("shardID");
FieldVector nodeID = root.getVector("nodeID");
for (int row = 0; row < rowCount; row++) {
ShardId sid = ShardId.fromString((String) getValue(shardID, row));
int value = (int) getValue(docID, row);
System.out.println("DocID: " + value + " ShardID: " + sid + " NodeID: " + getValue(nodeID, row));
scoreDocs.add(new ScoreDoc(value, score.get(row), sid.id()));
}
}

TotalHits totalHits = new TotalHits(totalRows, Relation.EQUAL_TO);
return new ReducedQueryPhase(
totalHits,
totalRows,
1.0f,
false,
false,
null,
null,
null,
new SortedTopDocs(scoreDocs.toArray(ScoreDoc[]::new), false, null, null, null),
null,
1,
totalRows,
0,
totalRows == 0,
list.stream().flatMap(ssr -> ssr.getFlightTickets().stream()).collect(Collectors.toList())
);
} catch (IOException e) {
throw new RuntimeException(e);
} finally {
try {
arrowReader.close();
} catch (IOException e) {
throw new RuntimeException(e);
}
}
} catch (Exception e) {
throw new RuntimeException(e);
}
}
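// Illustrative sketch (not part of this commit): reducedAggsFromStream() above relies on a
// getValue(FieldVector, int) helper that is defined outside this hunk. Assuming it only needs
// to cover the vector types read above (numeric docID/score, varchar shardID/nodeID), and
// assuming IntVector and VarCharVector are available here, a minimal version could look like:
private static Object getValueSketch(FieldVector vector, int row) {
if (vector == null || vector.isNull(row)) {
return null;
}
if (vector instanceof IntVector) {
return ((IntVector) vector).get(row);
}
if (vector instanceof Float4Vector) {
return ((Float4Vector) vector).get(row);
}
if (vector instanceof VarCharVector) {
// VarCharVector.get returns the raw UTF-8 bytes for the row
return new String(((VarCharVector) vector).get(row), java.nio.charset.StandardCharsets.UTF_8);
}
throw new IllegalArgumentException("Unsupported vector type: " + vector.getClass().getName());
}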

public ReducedQueryPhase reducedFromStream(List<StreamSearchResult> list) {


@@ -33,9 +33,12 @@
package org.opensearch.action.search;

import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.arrow.StreamIterator;
import org.opensearch.arrow.StreamManager;
import org.opensearch.arrow.StreamProducer;
import org.opensearch.arrow.StreamTicket;
@@ -44,6 +47,7 @@
import org.opensearch.common.util.concurrent.AbstractRunnable;
import org.opensearch.common.util.concurrent.AtomicArray;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.datafusion.DataFrame;
import org.opensearch.datafusion.DataFrameStreamProducer;
import org.opensearch.datafusion.DataFusion;
@@ -61,6 +65,7 @@
import org.opensearch.telemetry.tracing.Tracer;
import org.opensearch.transport.Transport;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@@ -84,10 +89,10 @@ public StreamAsyncAction(Logger logger, SearchTransportService searchTransportSe
this.searchPhaseController = searchPhaseController;
}

@Override
protected SearchPhase getNextPhase(final SearchPhaseResults<SearchPhaseResult> results, SearchPhaseContext context) {
return new StreamSearchReducePhase("stream_reduce", context);
}
// @Override
// protected SearchPhase getNextPhase(final SearchPhaseResults<SearchPhaseResult> results, SearchPhaseContext context) {
// return new StreamSearchReducePhase("stream_reduce", context);
// }

class StreamSearchReducePhase extends SearchPhase {
private SearchPhaseContext context;
@@ -117,8 +122,11 @@ protected void doRun() throws Exception {
// fetch all the tickets (one byte[] per shard) and hand that off to Datafusion.Query
// this creates a single stream that we'll register with the streammanager on this coordinator.
List<SearchPhaseResult> results = StreamAsyncAction.this.results.getAtomicArray().asList();
List<byte[]> tickets = results.stream().flatMap(r -> ((StreamSearchResult) r).getFlightTickets().stream())
.map(OSTicket::getBytes)
// List<byte[]> tickets = results.stream().flatMap(r -> ((StreamSearchResult) r).getFlightTickets().stream())
// .map(OSTicket::getBytes)
// .collect(Collectors.toList());

List<OSTicket> tickets = results.stream().flatMap(r -> ((StreamSearchResult) r).getFlightTickets().stream())
.collect(Collectors.toList());

// This is additional metadata for the fetch phase that will be conducted on the coordinator
@@ -129,9 +137,25 @@
.map(r -> new StreamTargetResponse(r.queryResult(), r.getSearchShardTarget()))
.collect(Collectors.toList());

StreamManager streamManager = searchPhaseController.getStreamManager();
StreamTicket streamTicket = streamManager.registerStream(DataFrameStreamProducer.query(tickets));
InternalSearchResponse internalSearchResponse = new InternalSearchResponse(SearchHits.empty(), null, null, null, false, false, 1, Collections.emptyList(), List.of(new OSTicket(streamTicket.getTicketID(), streamTicket.getNodeID())), targets);
// StreamManager streamManager = searchPhaseController.getStreamManager();
// StreamIterator streamIterator = streamManager.getStreamIterator(StreamTicket.fromBytes(tickets.get(0)));
// List<TransportStreamedJoinAction.Hit> hits = new ArrayList<>();
// while (streamIterator.next()) {
// VectorSchemaRoot root = streamIterator.getRoot();
// int rowCount = root.getRowCount();
// // Iterate through rows
// for (int row = 0; row < rowCount; row++) {
// FieldVector ord = root.getVector("ord");
// FieldVector count = root.getVector("count");;
//
//
// int ordVal = (int) getValue(ord, row);
// int countVal = (int) getValue(count, row);
// logger.info("ORD {} COUNT {}", ordVal, countVal);
// }
// }
// StreamTicket streamTicket = streamManager.registerStream(DataFrameStreamProducer.query(tickets));
InternalSearchResponse internalSearchResponse = new InternalSearchResponse(SearchHits.empty(), null, null, null, false, false, 1, Collections.emptyList(), List.of(tickets.get(0)), targets);
context.sendSearchResponse(internalSearchResponse, StreamAsyncAction.this.results.getAtomicArray());
} catch (Exception e) {
logger.error("broken", e);
@@ -144,4 +168,15 @@ public void onFailure(Exception e) {
context.onPhaseFailure(phase, "", e);
}
}

// Minimal helper for the (currently commented-out) batch reader above: only IntVector is
// supported, and the literal string "null" is returned for nulls or other vector types.
private static Object getValue(FieldVector vector, int index) {
if (vector == null || vector.isNull(index)) {
return "null";
}

if (vector instanceof IntVector) {
return ((IntVector) vector).get(index);
}
return "null";
}
}
@@ -64,6 +64,8 @@ public abstract class Aggregator extends BucketCollector implements Releasable {

private final SetOnce<InternalAggregation> internalAggregation = new SetOnce<>();

/**
 * Resets any per-batch state so the aggregator can be reused for the next streamed batch.
 * The default implementation is a no-op.
 */
public void reset() {}

/**
* Parses the aggregation request and creates the appropriate aggregator factory for it.
*
@@ -206,6 +208,11 @@ public final InternalAggregation buildTopLevel() throws IOException {
return internalAggregation.get();
}

/**
 * Builds the top-level aggregation for the current batch. Unlike {@link #buildTopLevel()},
 * the result is not cached, so this can be invoked once per streamed batch.
 */
public final InternalAggregation buildBatchedAgg() throws IOException {
assert parent() == null;
return buildAggregations(new long[] { 0 })[0];
}
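// Illustrative sketch (not part of this commit): one way buildBatchedAgg() and reset() can pair
// up in a streaming flow. StreamingAggregator added in this commit does the equivalent inline
// (it calls buildAggregations(new long[]{0})[0] directly); `docs` and the write-out step here
// are placeholders.
static void drainBatchSketch(Aggregator aggregator, LeafBucketCollector leaf, int[] docs) throws IOException {
for (int doc : docs) {
leaf.collect(doc); // delegate per-document collection
}
InternalAggregation partial = aggregator.buildBatchedAgg(); // partial result for this batch only
// ... write `partial` out, e.g. into Arrow vectors, before discarding it ...
aggregator.reset(); // clear per-batch state before the next batch
}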

/**
* Build an empty aggregation.
*/
@@ -107,6 +107,11 @@ public final void grow(long maxBucketOrd) {
docCounts = bigArrays.grow(docCounts, maxBucketOrd);
}

@Override
public void reset() {
// release the current counts array before starting a fresh one for the next batch
docCounts.close();
docCounts = bigArrays.newLongArray(1, true);
}

/**
* Utility method to collect the given doc in the given bucket (identified by the bucket ordinal)
*/
@@ -466,6 +466,8 @@ public void collect(int doc, long owningBucketOrd) throws IOException {
}
int ord = singleValues.ordValue();
long docCount = docCountProvider.getDocCount(doc);
// resolve the segment ordinal to its term; the resolved string is not used further in this change
BytesRef bytesRef = singleValues.lookupOrd(ord);
String s = bytesRef.utf8ToString();
segmentDocCounts.increment(ord + 1, docCount);
}
});
@@ -33,6 +33,7 @@
package org.opensearch.search.aggregations.bucket.terms;

import org.apache.lucene.search.IndexSearcher;
import org.opensearch.action.search.SearchType;
import org.opensearch.core.ParseField;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.search.DocValueFormat;
@@ -441,6 +442,7 @@ Aggregator create(
&& includeExclude == null
&& cardinality == CardinalityUpperBound.ONE
&& ordinalsValuesSource.supportsGlobalOrdinalsMapping()
&& context.searchType() != SearchType.STREAM
&&
// we use the static COLLECT_SEGMENT_ORDS to allow tests to force specific optimizations
(COLLECT_SEGMENT_ORDS != null ? COLLECT_SEGMENT_ORDS.booleanValue() : ratio <= 0.5 && maxOrd <= 2048)) {
@@ -0,0 +1,142 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.search.aggregations.support;

import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.Float4Vector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.FilterCollector;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Scorable;
import org.opensearch.arrow.StreamProducer;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.search.aggregations.Aggregation;
import org.opensearch.search.aggregations.Aggregations;
import org.opensearch.search.aggregations.Aggregator;
import org.opensearch.search.aggregations.BucketCollectorProcessor;
import org.opensearch.search.aggregations.InternalAggregation;
import org.opensearch.search.aggregations.InternalAggregations;
import org.opensearch.search.aggregations.LeafBucketCollector;
import org.opensearch.search.aggregations.LeafBucketCollectorBase;
import org.opensearch.search.aggregations.bucket.terms.InternalMappedTerms;
import org.opensearch.search.aggregations.bucket.terms.InternalTerms;
import org.opensearch.search.internal.SearchContext;

import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;

public class StreamingAggregator extends FilterCollector {

private final Aggregator aggregator;
private final SearchContext searchContext;
private final VectorSchemaRoot root;
private final StreamProducer.FlushSignal flushSignal;
private final int batchSize;
private final ShardId shardId;
/**
 * Sole constructor.
 *
 * @param in the aggregator to delegate per-document collection to
 * @param searchContext the search context for this shard
 * @param root the Arrow root whose "ord"/"count" vectors are filled per batch
 * @param batchSize number of collected documents per flushed batch
 * @param flushSignal used to wait for the consumer after each flushed batch
 * @param shardId the shard this collector runs on
 */
public StreamingAggregator(
Aggregator in,
SearchContext searchContext,
VectorSchemaRoot root,
int batchSize,
StreamProducer.FlushSignal flushSignal,
ShardId shardId
) {
super(in);
this.aggregator = in;
this.searchContext = searchContext;
this.root = root;
this.batchSize = batchSize;
this.flushSignal = flushSignal;
this.shardId = shardId;
}

@Override
public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
Map<String, FieldVector> vectors = new HashMap<>();
vectors.put("ord", root.getVector("ord"));
vectors.put("count", root.getVector("count"));
final int[] currentRow = {0};
// obtain the delegate's leaf collector once per segment rather than once per collected document
final LeafBucketCollector leaf = aggregator.getLeafCollector(context);
return new LeafBucketCollector() {

@Override
public void setScorer(Scorable scorer) throws IOException {
// scores are not needed by this collector
}

@Override
public void collect(int doc, long owningBucketOrd) throws IOException {
leaf.collect(doc);
currentRow[0]++;
// once the batch size is hit, flush the partial aggregation and reset for the next batch
if (currentRow[0] == batchSize) {
flushBatch();
}
// note: a trailing partial batch (fewer than batchSize docs since the last flush) is not flushed here
}

private void flushBatch() throws IOException {
InternalAggregation agg = aggregator.buildAggregations(new long[]{0})[0];
int bucketsWritten = 0;
if (agg instanceof InternalMappedTerms) {
InternalMappedTerms<?, ?> terms = (InternalMappedTerms<?, ?>) agg;

FieldVector termVector = vectors.get("ord");
FieldVector countVector = vectors.get("count");
List<? extends InternalTerms.Bucket> buckets = terms.getBuckets();
for (InternalTerms.Bucket bucket : buckets) {
// Get key/value info
String key = bucket.getKeyAsString();
long docCount = bucket.getDocCount();

Aggregations aggregations = bucket.getAggregations();
for (Aggregation aggregation : aggregations) {
// TODO: subs
}

// Write one row per bucket into the Arrow vectors
((VarCharVector) termVector).setSafe(bucketsWritten, key.getBytes());
((Float4Vector) countVector).setSafe(bucketsWritten, docCount);
bucketsWritten++;
}

aggregator.reset();

// Also access high-level statistics
// long otherDocCount = terms.getSumOfOtherDocCounts();
// long docCountError = terms.getDocCountError();
}

// Publish the rows written for this batch and wait until the consumer has drained them
root.setRowCount(bucketsWritten);
flushSignal.awaitConsumption(1000);

// Reset the per-batch doc counter
currentRow[0] = 0;
}
};
}
}
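// Illustrative sketch (not part of this commit): what a consumer of the batches produced by
// flushBatch() above would read per batch. Note that this producer writes "ord" as a
// VarCharVector holding the term string and "count" as a Float4Vector, while the commented-out
// reader in StreamAsyncAction treats both as int vectors, so one of the two sides would
// presumably need to change.
static void readAggBatchSketch(VectorSchemaRoot batchRoot) {
VarCharVector terms = (VarCharVector) batchRoot.getVector("ord");
Float4Vector counts = (Float4Vector) batchRoot.getVector("count");
for (int row = 0; row < batchRoot.getRowCount(); row++) {
String term = new String(terms.get(row), java.nio.charset.StandardCharsets.UTF_8);
float count = counts.get(row);
System.out.println("term=" + term + " count=" + count);
}
}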