Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow returning distance #32

Merged
merged 24 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
4cba261
Bump sdk version
mjanuszkiewicz-tt Aug 27, 2024
8c2af72
Implement distance querios for version 7.10
mjanuszkiewicz-tt Aug 28, 2024
ac8bae6
Implement distance querios for version 7.11
mjanuszkiewicz-tt Aug 28, 2024
5baa658
Implement distance querios for version 7.12
mjanuszkiewicz-tt Aug 28, 2024
6938c1a
Implement distance querios for version 7.13
mjanuszkiewicz-tt Aug 28, 2024
191816c
Implement distance querios for version 7.14
mjanuszkiewicz-tt Aug 28, 2024
e29f446
Implement distance querios for version 7.15
mjanuszkiewicz-tt Aug 28, 2024
c2bbce9
Implement distance querios for version 7.16
mjanuszkiewicz-tt Aug 28, 2024
1c7bffb
Implement distance querios for version 7.17
mjanuszkiewicz-tt Aug 28, 2024
c97f6da
Implement distance queries for version 8.0-8.3
mjanuszkiewicz-tt Aug 28, 2024
96d0807
Implement distance queries for version 8.5
mjanuszkiewicz-tt Aug 28, 2024
f64b2d9
Implement distance queries for version 8.5-8.6
mjanuszkiewicz-tt Aug 28, 2024
bf17e66
Implement distance queries for version 8.7-8.9
mjanuszkiewicz-tt Aug 28, 2024
a764626
Fix missaplied patch
mjanuszkiewicz-tt Aug 28, 2024
6486d7d
Fix fetch phase in 7.10
mjanuszkiewicz-tt Aug 29, 2024
abd1244
Fix fetch phase in 7.11-8.9
mjanuszkiewicz-tt Aug 29, 2024
454ec0d
Fix misapplied patch
mjanuszkiewicz-tt Aug 29, 2024
dec7dad
Apply patch to 8.10
mjanuszkiewicz-tt Aug 29, 2024
22cd3f9
Apply patch to 8.11
mjanuszkiewicz-tt Aug 29, 2024
2207d57
Apply patch to 8.12
mjanuszkiewicz-tt Aug 29, 2024
ce5fbe0
Apply patch to 8.13
mjanuszkiewicz-tt Aug 29, 2024
0aff981
Unify query builders
mjanuszkiewicz-tt Aug 29, 2024
b1bf8ff
Change distance output field name
mjanuszkiewicz-tt Aug 29, 2024
2e6dc2f
Fix misapplied patch
mjanuszkiewicz-tt Aug 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ public class TraveltimePlugin extends Plugin implements SearchPlugin {

private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) {
TraveltimeCache.INSTANCE.cleanUp();
TraveltimeCache.DISTANCE.cleanUp();
threadPool.scheduleUnlessShuttingDown(cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds));
}

Expand All @@ -57,6 +58,7 @@ public Collection<Object> createComponents(Client client, ClusterService cluster
Integer cacheSize = CACHE_SIZE.get(environment.settings());

TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry);
TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry);
cleanUpAndReschedule(threadPool, cleanupSeconds);

return super.createComponents(client, clusterService, threadPool, resourceWatcherService, scriptService, xContentRegistry, environment, nodeEnvironment, namedWriteableRegistry, indexNameExpressionResolver, repositoriesServiceSupplier);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) {
if (traveltimeQuery == null) return null;
TraveltimeQueryParameters params = traveltimeQuery.getParams();
final String output = traveltimeQuery.getOutput();
final String distanceOutput = traveltimeQuery.getDistanceOutput();

FieldFetcher fieldFetcher = FieldFetcher.create(fetchContext.mapperService(), fetchContext.searchLookup(), List.of(new FieldAndFormat(params.getField(), null)));

Expand All @@ -59,10 +60,19 @@ public void setNextReader(LeafReaderContext readerContext) {
public void process(HitContext hitContext) throws IOException {
val docValues = hitContext.reader().getSortedNumericDocValues(params.getField());
docValues.advance(hitContext.docId());
Integer tt = TraveltimeCache.INSTANCE.get(params, docValues.nextValue());
val point = docValues.nextValue();
if(!output.isEmpty()) {
Integer tt = TraveltimeCache.INSTANCE.get(params, point);
if (tt >= 0) {
hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt)));
}
}

if (tt >= 0) {
hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt)));
if(!distanceOutput.isEmpty()) {
Integer td = TraveltimeCache.DISTANCE.get(params, point);
if (td >= 0) {
hitContext.hit().setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td)));
}
}
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ public class TraveltimeQueryBuilder extends AbstractQueryBuilder<TraveltimeQuery
private QueryBuilder prefilter;
@NonNull
private String output = "";
@NonNull
private String distanceOutput = "";

public TraveltimeQueryBuilder() {
}
Expand Down Expand Up @@ -63,6 +65,7 @@ public TraveltimeQueryBuilder(StreamInput in) throws IOException {
}
prefilter = in.readOptionalNamedWriteable(QueryBuilder.class);
output = in.readString();
distanceOutput = in.readString();
}

@Override
Expand All @@ -78,6 +81,7 @@ protected void doWriteTo(StreamOutput out) throws IOException {
if(requestType != null) out.writeEnum(requestType);
out.writeOptionalNamedWriteable(prefilter);
out.writeString(output);
out.writeString(distanceOutput);
}

@Override
Expand All @@ -89,6 +93,7 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep
builder.field("country", country == null ? null : country.getValue());
builder.field("prefilter", prefilter);
builder.field("output", output);
builder.field("distanceOutput", distanceOutput);
}

@Override
Expand Down Expand Up @@ -128,14 +133,19 @@ protected Query doToQuery(QueryShardContext context) throws IOException {

Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build();

TraveltimeQueryParameters params = new TraveltimeQueryParameters(field, originCoord, limit, mode, country, requestType);
boolean includeDistance = !distanceOutput.isEmpty();

TraveltimeQueryParameters params = new TraveltimeQueryParameters(field, originCoord, limit, mode, country, requestType, includeDistance);
if (params.getMode() == null) {
if (defaultMode.isPresent()) {
params = params.withMode(defaultMode.get());
} else {
throw new IllegalStateException("Traveltime query requires either 'mode' field to be present or a default mode to be set in the config");
}
}
if(params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) {
throw new IllegalStateException("Traveltime query with distance output cannot be used with public transportation mode");
}
if (params.getCountry() == null) {
if (defaultCountry.isPresent()) {
params = params.withCountry(defaultCountry.get());
Expand All @@ -156,7 +166,7 @@ protected Query doToQuery(QueryShardContext context) throws IOException {

Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null;

return new TraveltimeSearchQuery(params, prefilterQuery, output, appUri, appId, apiKey);
return new TraveltimeSearchQuery(params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ public class TraveltimeQueryParser implements QueryParser<TraveltimeQueryBuilder
private final ParseField requestType = new ParseField("requestType");
private final ParseField prefilter = new ParseField("prefilter");
private final ParseField output = new ParseField("output");
private final ParseField distanceOutput = new ParseField("distanceOutput");

private final ContextParser<Void, QueryBuilder> prefilterParser = (p, c) -> AbstractQueryBuilder.parseInnerQueryBuilder(p);

Expand All @@ -36,9 +37,10 @@ public class TraveltimeQueryParser implements QueryParser<TraveltimeQueryBuilder
queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit);
queryParser.declareString((qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), mode);
queryParser.declareString((qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), country);
queryParser.declareString((qb, s) -> qb.setRequestType(findByNameOrError("country", s, Util::findRequestTypeByName)), requestType);
queryParser.declareString((qb, s) -> qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), requestType);
queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter);
queryParser.declareString(TraveltimeQueryBuilder::setOutput, output);
queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput);

queryParser.declareRequiredFieldSet(field.toString());
queryParser.declareRequiredFieldSet(origin.toString());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ public class TraveltimeSearchQuery extends Query {
private final TraveltimeQueryParameters params;
private final Query prefilter;
private final String output;
private final String distanceOutput;
private final URI appUri;
private final String appId;
private final String apiKey;
Expand Down Expand Up @@ -45,7 +46,7 @@ public Query rewrite(IndexReader reader) throws IOException {
if (newPrefilter == prefilter) {
return super.rewrite(reader);
} else {
return new TraveltimeSearchQuery(params, newPrefilter, output, appUri, appId, apiKey);
return new TraveltimeSearchQuery(params, newPrefilter, output, distanceOutput, appUri, appId, apiKey);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -122,22 +122,49 @@ public Scorer scorer(LeafReaderContext context) throws IOException {

val pointToTime = new Long2IntOpenHashMap(valueArray.size());

val results = protoFetcher.getTimes(
ttQuery.getParams().getOrigin(),
decodedArray,
ttQuery.getParams().getLimit(),
ttQuery.getParams().getMode(),
ttQuery.getParams().getCountry(),
ttQuery.getParams().getRequestType()
);

for (int index = 0; index < results.size(); index++) {
if(results.get(index) >= 0) {
pointToTime.put(valueArray.getLong(index), results.get(index).intValue());
if (ttQuery.getParams().isIncludeDistance()) {
val pointToDistance = new Long2IntOpenHashMap(valueArray.size());

val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode());

val timeDistance = protoFetcher.getTimesAndDistances(
ttQuery.getParams().getOrigin(),
decodedArray,
ttQuery.getParams().getLimit(),
mode,
ttQuery.getParams().getCountry(),
ttQuery.getParams().getRequestType()
);

val times = timeDistance.getLeft();
val distances = timeDistance.getRight();

for (int index = 0; index < times.size(); index++) {
if (times.get(index) >= 0) {
pointToTime.put(valueArray.getLong(index), times.get(index).intValue());
pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue());
}
}

TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance);
} else {
val results = protoFetcher.getTimes(
ttQuery.getParams().getOrigin(),
decodedArray,
ttQuery.getParams().getLimit(),
ttQuery.getParams().getMode(),
ttQuery.getParams().getCountry(),
ttQuery.getParams().getRequestType()
);

for (int index = 0; index < results.size(); index++) {
if (results.get(index) >= 0) {
pointToTime.put(valueArray.getLong(index), results.get(index).intValue());
}
}
}

if(hasOutput) {
if (hasOutput) {
TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ public class TraveltimePlugin extends Plugin implements SearchPlugin {

private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) {
TraveltimeCache.INSTANCE.cleanUp();
TraveltimeCache.DISTANCE.cleanUp();
threadPool.scheduleUnlessShuttingDown(cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds));
}

Expand All @@ -57,6 +58,7 @@ public Collection<Object> createComponents(Client client, ClusterService cluster
Integer cacheSize = CACHE_SIZE.get(environment.settings());

TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry);
TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry);
cleanUpAndReschedule(threadPool, cleanupSeconds);

return super.createComponents(client, clusterService, threadPool, resourceWatcherService, scriptService, xContentRegistry, environment, nodeEnvironment, namedWriteableRegistry, indexNameExpressionResolver, repositoriesServiceSupplier);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) {
if (traveltimeQuery == null) return null;
TraveltimeQueryParameters params = traveltimeQuery.getParams();
final String output = traveltimeQuery.getOutput();
final String distanceOutput = traveltimeQuery.getDistanceOutput();

FieldFetcher fieldFetcher = FieldFetcher.create(fetchContext.getQueryShardContext(), fetchContext.searchLookup(), List.of(new FieldAndFormat(params.getField(), null)));

Expand All @@ -59,10 +60,19 @@ public void setNextReader(LeafReaderContext readerContext) {
public void process(HitContext hitContext) throws IOException {
val docValues = hitContext.reader().getSortedNumericDocValues(params.getField());
docValues.advance(hitContext.docId());
Integer tt = TraveltimeCache.INSTANCE.get(params, docValues.nextValue());
val point = docValues.nextValue();
if(!output.isEmpty()) {
Integer tt = TraveltimeCache.INSTANCE.get(params, point);
if (tt >= 0) {
hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt)));
}
}

if (tt >= 0) {
hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt)));
if(!distanceOutput.isEmpty()) {
Integer td = TraveltimeCache.DISTANCE.get(params, point);
if (td >= 0) {
hitContext.hit().setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td)));
}
}
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ public class TraveltimeQueryBuilder extends AbstractQueryBuilder<TraveltimeQuery
private QueryBuilder prefilter;
@NonNull
private String output = "";
@NonNull
private String distanceOutput = "";

public TraveltimeQueryBuilder() {
}
Expand Down Expand Up @@ -63,6 +65,7 @@ public TraveltimeQueryBuilder(StreamInput in) throws IOException {
}
prefilter = in.readOptionalNamedWriteable(QueryBuilder.class);
output = in.readString();
distanceOutput = in.readString();
}

@Override
Expand All @@ -78,6 +81,7 @@ protected void doWriteTo(StreamOutput out) throws IOException {
if(requestType != null) out.writeEnum(requestType);
out.writeOptionalNamedWriteable(prefilter);
out.writeString(output);
out.writeString(distanceOutput);
}

@Override
Expand All @@ -89,6 +93,7 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep
builder.field("country", country == null ? null : country.getValue());
builder.field("prefilter", prefilter);
builder.field("output", output);
builder.field("distanceOutput", distanceOutput);
}

@Override
Expand Down Expand Up @@ -128,14 +133,19 @@ protected Query doToQuery(QueryShardContext context) throws IOException {

Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build();

TraveltimeQueryParameters params = new TraveltimeQueryParameters(field, originCoord, limit, mode, country, requestType);
boolean includeDistance = !distanceOutput.isEmpty();

TraveltimeQueryParameters params = new TraveltimeQueryParameters(field, originCoord, limit, mode, country, requestType, includeDistance);
if (params.getMode() == null) {
if (defaultMode.isPresent()) {
params = params.withMode(defaultMode.get());
} else {
throw new IllegalStateException("Traveltime query requires either 'mode' field to be present or a default mode to be set in the config");
}
}
if(params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) {
throw new IllegalStateException("Traveltime query with distance output cannot be used with public transportation mode");
}
if (params.getCountry() == null) {
if (defaultCountry.isPresent()) {
params = params.withCountry(defaultCountry.get());
Expand All @@ -156,7 +166,7 @@ protected Query doToQuery(QueryShardContext context) throws IOException {

Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null;

return new TraveltimeSearchQuery(params, prefilterQuery, output, appUri, appId, apiKey);
return new TraveltimeSearchQuery(params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ public class TraveltimeQueryParser implements QueryParser<TraveltimeQueryBuilder
private final ParseField requestType = new ParseField("requestType");
private final ParseField prefilter = new ParseField("prefilter");
private final ParseField output = new ParseField("output");
private final ParseField distanceOutput = new ParseField("distanceOutput");

private final ContextParser<Void, QueryBuilder> prefilterParser = (p, c) -> AbstractQueryBuilder.parseInnerQueryBuilder(p);

Expand All @@ -36,9 +37,10 @@ public class TraveltimeQueryParser implements QueryParser<TraveltimeQueryBuilder
queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit);
queryParser.declareString((qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), mode);
queryParser.declareString((qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), country);
queryParser.declareString((qb, s) -> qb.setRequestType(findByNameOrError("country", s, Util::findRequestTypeByName)), requestType);
queryParser.declareString((qb, s) -> qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), requestType);
queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter);
queryParser.declareString(TraveltimeQueryBuilder::setOutput, output);
queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput);

queryParser.declareRequiredFieldSet(field.toString());
queryParser.declareRequiredFieldSet(origin.toString());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ public class TraveltimeSearchQuery extends Query {
private final TraveltimeQueryParameters params;
private final Query prefilter;
private final String output;
private final String distanceOutput;
private final URI appUri;
private final String appId;
private final String apiKey;
Expand Down Expand Up @@ -45,7 +46,7 @@ public Query rewrite(IndexReader reader) throws IOException {
if (newPrefilter == prefilter) {
return super.rewrite(reader);
} else {
return new TraveltimeSearchQuery(params, newPrefilter, output, appUri, appId, apiKey);
return new TraveltimeSearchQuery(params, newPrefilter, output, distanceOutput, appUri, appId, apiKey);
}
}
}
Loading
Loading