Skip to content

Commit

Permalink
Merge pull request #32 from traveltime-dev/allow-returning-distance
Browse files Browse the repository at this point in the history
Allow returning distance
  • Loading branch information
mjanuszkiewicz-tt authored Aug 29, 2024
2 parents d716701 + 2e6dc2f commit 326b8ad
Show file tree
Hide file tree
Showing 137 changed files with 1,568 additions and 414 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ public class TraveltimePlugin extends Plugin implements SearchPlugin {

private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) {
TraveltimeCache.INSTANCE.cleanUp();
TraveltimeCache.DISTANCE.cleanUp();
threadPool.scheduleUnlessShuttingDown(cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds));
}

Expand All @@ -57,6 +58,7 @@ public Collection<Object> createComponents(Client client, ClusterService cluster
Integer cacheSize = CACHE_SIZE.get(environment.settings());

TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry);
TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry);
cleanUpAndReschedule(threadPool, cleanupSeconds);

return super.createComponents(client, clusterService, threadPool, resourceWatcherService, scriptService, xContentRegistry, environment, nodeEnvironment, namedWriteableRegistry, indexNameExpressionResolver, repositoriesServiceSupplier);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) {
if (traveltimeQuery == null) return null;
TraveltimeQueryParameters params = traveltimeQuery.getParams();
final String output = traveltimeQuery.getOutput();
final String distanceOutput = traveltimeQuery.getDistanceOutput();

FieldFetcher fieldFetcher = FieldFetcher.create(fetchContext.mapperService(), fetchContext.searchLookup(), List.of(new FieldAndFormat(params.getField(), null)));

Expand All @@ -59,10 +60,19 @@ public void setNextReader(LeafReaderContext readerContext) {
public void process(HitContext hitContext) throws IOException {
val docValues = hitContext.reader().getSortedNumericDocValues(params.getField());
docValues.advance(hitContext.docId());
Integer tt = TraveltimeCache.INSTANCE.get(params, docValues.nextValue());
val point = docValues.nextValue();
if(!output.isEmpty()) {
Integer tt = TraveltimeCache.INSTANCE.get(params, point);
if (tt >= 0) {
hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt)));
}
}

if (tt >= 0) {
hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt)));
if(!distanceOutput.isEmpty()) {
Integer td = TraveltimeCache.DISTANCE.get(params, point);
if (td >= 0) {
hitContext.hit().setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td)));
}
}
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ public class TraveltimeQueryBuilder extends AbstractQueryBuilder<TraveltimeQuery
private QueryBuilder prefilter;
@NonNull
private String output = "";
@NonNull
private String distanceOutput = "";

public TraveltimeQueryBuilder() {
}
Expand Down Expand Up @@ -63,6 +65,7 @@ public TraveltimeQueryBuilder(StreamInput in) throws IOException {
}
prefilter = in.readOptionalNamedWriteable(QueryBuilder.class);
output = in.readString();
distanceOutput = in.readString();
}

@Override
Expand All @@ -78,6 +81,7 @@ protected void doWriteTo(StreamOutput out) throws IOException {
if(requestType != null) out.writeEnum(requestType);
out.writeOptionalNamedWriteable(prefilter);
out.writeString(output);
out.writeString(distanceOutput);
}

@Override
Expand All @@ -89,6 +93,7 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep
builder.field("country", country == null ? null : country.getValue());
builder.field("prefilter", prefilter);
builder.field("output", output);
builder.field("distanceOutput", distanceOutput);
}

@Override
Expand Down Expand Up @@ -128,14 +133,19 @@ protected Query doToQuery(QueryShardContext context) throws IOException {

Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build();

TraveltimeQueryParameters params = new TraveltimeQueryParameters(field, originCoord, limit, mode, country, requestType);
boolean includeDistance = !distanceOutput.isEmpty();

TraveltimeQueryParameters params = new TraveltimeQueryParameters(field, originCoord, limit, mode, country, requestType, includeDistance);
if (params.getMode() == null) {
if (defaultMode.isPresent()) {
params = params.withMode(defaultMode.get());
} else {
throw new IllegalStateException("Traveltime query requires either 'mode' field to be present or a default mode to be set in the config");
}
}
if(params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) {
throw new IllegalStateException("Traveltime query with distance output cannot be used with public transportation mode");
}
if (params.getCountry() == null) {
if (defaultCountry.isPresent()) {
params = params.withCountry(defaultCountry.get());
Expand All @@ -156,7 +166,7 @@ protected Query doToQuery(QueryShardContext context) throws IOException {

Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null;

return new TraveltimeSearchQuery(params, prefilterQuery, output, appUri, appId, apiKey);
return new TraveltimeSearchQuery(params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ public class TraveltimeQueryParser implements QueryParser<TraveltimeQueryBuilder
private final ParseField requestType = new ParseField("requestType");
private final ParseField prefilter = new ParseField("prefilter");
private final ParseField output = new ParseField("output");
private final ParseField distanceOutput = new ParseField("distanceOutput");

private final ContextParser<Void, QueryBuilder> prefilterParser = (p, c) -> AbstractQueryBuilder.parseInnerQueryBuilder(p);

Expand All @@ -36,9 +37,10 @@ public class TraveltimeQueryParser implements QueryParser<TraveltimeQueryBuilder
queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit);
queryParser.declareString((qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), mode);
queryParser.declareString((qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), country);
queryParser.declareString((qb, s) -> qb.setRequestType(findByNameOrError("country", s, Util::findRequestTypeByName)), requestType);
queryParser.declareString((qb, s) -> qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), requestType);
queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter);
queryParser.declareString(TraveltimeQueryBuilder::setOutput, output);
queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput);

queryParser.declareRequiredFieldSet(field.toString());
queryParser.declareRequiredFieldSet(origin.toString());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ public class TraveltimeSearchQuery extends Query {
private final TraveltimeQueryParameters params;
private final Query prefilter;
private final String output;
private final String distanceOutput;
private final URI appUri;
private final String appId;
private final String apiKey;
Expand Down Expand Up @@ -45,7 +46,7 @@ public Query rewrite(IndexReader reader) throws IOException {
if (newPrefilter == prefilter) {
return super.rewrite(reader);
} else {
return new TraveltimeSearchQuery(params, newPrefilter, output, appUri, appId, apiKey);
return new TraveltimeSearchQuery(params, newPrefilter, output, distanceOutput, appUri, appId, apiKey);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -122,22 +122,49 @@ public Scorer scorer(LeafReaderContext context) throws IOException {

val pointToTime = new Long2IntOpenHashMap(valueArray.size());

val results = protoFetcher.getTimes(
ttQuery.getParams().getOrigin(),
decodedArray,
ttQuery.getParams().getLimit(),
ttQuery.getParams().getMode(),
ttQuery.getParams().getCountry(),
ttQuery.getParams().getRequestType()
);

for (int index = 0; index < results.size(); index++) {
if(results.get(index) >= 0) {
pointToTime.put(valueArray.getLong(index), results.get(index).intValue());
if (ttQuery.getParams().isIncludeDistance()) {
val pointToDistance = new Long2IntOpenHashMap(valueArray.size());

val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode());

val timeDistance = protoFetcher.getTimesAndDistances(
ttQuery.getParams().getOrigin(),
decodedArray,
ttQuery.getParams().getLimit(),
mode,
ttQuery.getParams().getCountry(),
ttQuery.getParams().getRequestType()
);

val times = timeDistance.getLeft();
val distances = timeDistance.getRight();

for (int index = 0; index < times.size(); index++) {
if (times.get(index) >= 0) {
pointToTime.put(valueArray.getLong(index), times.get(index).intValue());
pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue());
}
}

TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance);
} else {
val results = protoFetcher.getTimes(
ttQuery.getParams().getOrigin(),
decodedArray,
ttQuery.getParams().getLimit(),
ttQuery.getParams().getMode(),
ttQuery.getParams().getCountry(),
ttQuery.getParams().getRequestType()
);

for (int index = 0; index < results.size(); index++) {
if (results.get(index) >= 0) {
pointToTime.put(valueArray.getLong(index), results.get(index).intValue());
}
}
}

if(hasOutput) {
if (hasOutput) {
TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ public class TraveltimePlugin extends Plugin implements SearchPlugin {

private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) {
TraveltimeCache.INSTANCE.cleanUp();
TraveltimeCache.DISTANCE.cleanUp();
threadPool.scheduleUnlessShuttingDown(cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds));
}

Expand All @@ -57,6 +58,7 @@ public Collection<Object> createComponents(Client client, ClusterService cluster
Integer cacheSize = CACHE_SIZE.get(environment.settings());

TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry);
TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry);
cleanUpAndReschedule(threadPool, cleanupSeconds);

return super.createComponents(client, clusterService, threadPool, resourceWatcherService, scriptService, xContentRegistry, environment, nodeEnvironment, namedWriteableRegistry, indexNameExpressionResolver, repositoriesServiceSupplier);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) {
if (traveltimeQuery == null) return null;
TraveltimeQueryParameters params = traveltimeQuery.getParams();
final String output = traveltimeQuery.getOutput();
final String distanceOutput = traveltimeQuery.getDistanceOutput();

FieldFetcher fieldFetcher = FieldFetcher.create(fetchContext.getQueryShardContext(), fetchContext.searchLookup(), List.of(new FieldAndFormat(params.getField(), null)));

Expand All @@ -59,10 +60,19 @@ public void setNextReader(LeafReaderContext readerContext) {
public void process(HitContext hitContext) throws IOException {
val docValues = hitContext.reader().getSortedNumericDocValues(params.getField());
docValues.advance(hitContext.docId());
Integer tt = TraveltimeCache.INSTANCE.get(params, docValues.nextValue());
val point = docValues.nextValue();
if(!output.isEmpty()) {
Integer tt = TraveltimeCache.INSTANCE.get(params, point);
if (tt >= 0) {
hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt)));
}
}

if (tt >= 0) {
hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt)));
if(!distanceOutput.isEmpty()) {
Integer td = TraveltimeCache.DISTANCE.get(params, point);
if (td >= 0) {
hitContext.hit().setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td)));
}
}
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ public class TraveltimeQueryBuilder extends AbstractQueryBuilder<TraveltimeQuery
private QueryBuilder prefilter;
@NonNull
private String output = "";
@NonNull
private String distanceOutput = "";

public TraveltimeQueryBuilder() {
}
Expand Down Expand Up @@ -63,6 +65,7 @@ public TraveltimeQueryBuilder(StreamInput in) throws IOException {
}
prefilter = in.readOptionalNamedWriteable(QueryBuilder.class);
output = in.readString();
distanceOutput = in.readString();
}

@Override
Expand All @@ -78,6 +81,7 @@ protected void doWriteTo(StreamOutput out) throws IOException {
if(requestType != null) out.writeEnum(requestType);
out.writeOptionalNamedWriteable(prefilter);
out.writeString(output);
out.writeString(distanceOutput);
}

@Override
Expand All @@ -89,6 +93,7 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep
builder.field("country", country == null ? null : country.getValue());
builder.field("prefilter", prefilter);
builder.field("output", output);
builder.field("distanceOutput", distanceOutput);
}

@Override
Expand Down Expand Up @@ -128,14 +133,19 @@ protected Query doToQuery(QueryShardContext context) throws IOException {

Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build();

TraveltimeQueryParameters params = new TraveltimeQueryParameters(field, originCoord, limit, mode, country, requestType);
boolean includeDistance = !distanceOutput.isEmpty();

TraveltimeQueryParameters params = new TraveltimeQueryParameters(field, originCoord, limit, mode, country, requestType, includeDistance);
if (params.getMode() == null) {
if (defaultMode.isPresent()) {
params = params.withMode(defaultMode.get());
} else {
throw new IllegalStateException("Traveltime query requires either 'mode' field to be present or a default mode to be set in the config");
}
}
if(params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) {
throw new IllegalStateException("Traveltime query with distance output cannot be used with public transportation mode");
}
if (params.getCountry() == null) {
if (defaultCountry.isPresent()) {
params = params.withCountry(defaultCountry.get());
Expand All @@ -156,7 +166,7 @@ protected Query doToQuery(QueryShardContext context) throws IOException {

Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null;

return new TraveltimeSearchQuery(params, prefilterQuery, output, appUri, appId, apiKey);
return new TraveltimeSearchQuery(params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ public class TraveltimeQueryParser implements QueryParser<TraveltimeQueryBuilder
private final ParseField requestType = new ParseField("requestType");
private final ParseField prefilter = new ParseField("prefilter");
private final ParseField output = new ParseField("output");
private final ParseField distanceOutput = new ParseField("distanceOutput");

private final ContextParser<Void, QueryBuilder> prefilterParser = (p, c) -> AbstractQueryBuilder.parseInnerQueryBuilder(p);

Expand All @@ -36,9 +37,10 @@ public class TraveltimeQueryParser implements QueryParser<TraveltimeQueryBuilder
queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit);
queryParser.declareString((qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), mode);
queryParser.declareString((qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), country);
queryParser.declareString((qb, s) -> qb.setRequestType(findByNameOrError("country", s, Util::findRequestTypeByName)), requestType);
queryParser.declareString((qb, s) -> qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), requestType);
queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter);
queryParser.declareString(TraveltimeQueryBuilder::setOutput, output);
queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput);

queryParser.declareRequiredFieldSet(field.toString());
queryParser.declareRequiredFieldSet(origin.toString());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ public class TraveltimeSearchQuery extends Query {
private final TraveltimeQueryParameters params;
private final Query prefilter;
private final String output;
private final String distanceOutput;
private final URI appUri;
private final String appId;
private final String apiKey;
Expand Down Expand Up @@ -45,7 +46,7 @@ public Query rewrite(IndexReader reader) throws IOException {
if (newPrefilter == prefilter) {
return super.rewrite(reader);
} else {
return new TraveltimeSearchQuery(params, newPrefilter, output, appUri, appId, apiKey);
return new TraveltimeSearchQuery(params, newPrefilter, output, distanceOutput, appUri, appId, apiKey);
}
}
}
Loading

0 comments on commit 326b8ad

Please sign in to comment.