-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ba690ee
commit eabe2c9
Showing
10 changed files
with
722 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
buildPlugin(this, '8.13', revisions(0)) | ||
|
||
compileJava { | ||
sourceCompatibility = JavaVersion.VERSION_17 | ||
targetCompatibility = JavaVersion.VERSION_17 | ||
} |
82 changes: 82 additions & 0 deletions
82
8.13/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
package com.traveltime.plugin.elasticsearch; | ||
|
||
|
||
import com.traveltime.plugin.elasticsearch.query.TraveltimeFetchPhase; | ||
import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryBuilder; | ||
import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryParser; | ||
import com.traveltime.plugin.elasticsearch.util.Util; | ||
import com.traveltime.sdk.dto.requests.proto.Country; | ||
import com.traveltime.sdk.dto.requests.proto.RequestType; | ||
import com.traveltime.sdk.dto.requests.proto.Transportation; | ||
import org.elasticsearch.client.internal.Client; | ||
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; | ||
import org.elasticsearch.cluster.routing.allocation.AllocationService; | ||
import org.elasticsearch.cluster.service.ClusterService; | ||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry; | ||
import org.elasticsearch.common.settings.Setting; | ||
import org.elasticsearch.core.TimeValue; | ||
import org.elasticsearch.env.Environment; | ||
import org.elasticsearch.env.NodeEnvironment; | ||
import org.elasticsearch.indices.IndicesService; | ||
import org.elasticsearch.plugins.Plugin; | ||
import org.elasticsearch.plugins.SearchPlugin; | ||
import org.elasticsearch.repositories.RepositoriesService; | ||
import org.elasticsearch.script.ScriptService; | ||
import org.elasticsearch.search.fetch.FetchSubPhase; | ||
import org.elasticsearch.telemetry.TelemetryProvider; | ||
import org.elasticsearch.threadpool.ThreadPool; | ||
import org.elasticsearch.watcher.ResourceWatcherService; | ||
import org.elasticsearch.xcontent.NamedXContentRegistry; | ||
|
||
import java.net.URI; | ||
import java.time.Duration; | ||
import java.util.Collection; | ||
import java.util.List; | ||
import java.util.Optional; | ||
import java.util.function.Supplier; | ||
|
||
public class TraveltimePlugin extends Plugin implements SearchPlugin { | ||
public static final Setting<String> APP_ID = Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); | ||
public static final Setting<String> API_KEY = Setting.simpleString("traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); | ||
public static final Setting<Optional<Transportation.Modes>> DEFAULT_MODE = new Setting<>("traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); | ||
public static final Setting<Optional<Country>> DEFAULT_COUNTRY = new Setting<>("traveltime.default.country", s -> "", Util::findCountryByName, Setting.Property.NodeScope); | ||
|
||
public static final Setting<Optional<RequestType>> DEFAULT_REQUEST_TYPE = new Setting<>("traveltime.default.request_type", s -> RequestType.ONE_TO_MANY.name(), Util::findRequestTypeByName, Setting.Property.NodeScope); | ||
public static final Setting<URI> API_URI = new Setting<>("traveltime.api.uri", s -> "https://proto.api.traveltimeapp.com/api/v2/", URI::create, Setting.Property.NodeScope); | ||
|
||
private static final Setting<Integer> CACHE_CLEANUP_INTERVAL = Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); | ||
private static final Setting<Integer> CACHE_EXPIRY = Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); | ||
private static final Setting<Integer> CACHE_SIZE = Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); | ||
|
||
private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { | ||
TraveltimeCache.INSTANCE.cleanUp(); | ||
threadPool.scheduleUnlessShuttingDown(cleanupSeconds, threadPool.generic(), () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); | ||
} | ||
|
||
@Override | ||
public Collection<?> createComponents(PluginServices pluginServices) { | ||
TimeValue cleanupSeconds = TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(pluginServices.environment().settings())); | ||
Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(pluginServices.environment().settings())); | ||
Integer cacheSize = CACHE_SIZE.get(pluginServices.environment().settings()); | ||
|
||
TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); | ||
cleanUpAndReschedule(pluginServices.threadPool(), cleanupSeconds); | ||
|
||
return super.createComponents(pluginServices); | ||
} | ||
|
||
@Override | ||
public List<Setting<?>> getSettings() { | ||
return List.of(APP_ID, API_KEY, DEFAULT_MODE, DEFAULT_COUNTRY, DEFAULT_REQUEST_TYPE, API_URI, CACHE_SIZE, CACHE_EXPIRY, CACHE_CLEANUP_INTERVAL); | ||
} | ||
|
||
@Override | ||
public List<QuerySpec<?>> getQueries() { | ||
return List.of(new QuerySpec<>(TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser())); | ||
} | ||
|
||
@Override | ||
public List<FetchSubPhase> getFetchSubPhases(FetchPhaseConstructionContext context) { | ||
return List.of(new TraveltimeFetchPhase()); | ||
} | ||
} |
77 changes: 77 additions & 0 deletions
77
8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
package com.traveltime.plugin.elasticsearch.query; | ||
|
||
import com.traveltime.plugin.elasticsearch.TraveltimeCache; | ||
import lombok.val; | ||
import org.apache.lucene.index.LeafReaderContext; | ||
import org.apache.lucene.search.Query; | ||
import org.apache.lucene.search.QueryVisitor; | ||
import org.elasticsearch.common.document.DocumentField; | ||
import org.elasticsearch.search.fetch.FetchContext; | ||
import org.elasticsearch.search.fetch.FetchSubPhase; | ||
import org.elasticsearch.search.fetch.FetchSubPhaseProcessor; | ||
import org.elasticsearch.search.fetch.StoredFieldsSpec; | ||
import org.elasticsearch.search.fetch.subphase.FieldAndFormat; | ||
import org.elasticsearch.search.fetch.subphase.FieldFetcher; | ||
|
||
import java.io.IOException; | ||
import java.util.ArrayList; | ||
import java.util.List; | ||
import java.util.Set; | ||
|
||
public class TraveltimeFetchPhase implements FetchSubPhase { | ||
|
||
private static class ParamFinder extends QueryVisitor { | ||
private final List<TraveltimeSearchQuery> paramList = new ArrayList<>(); | ||
|
||
@Override | ||
public void visitLeaf(Query query) { | ||
if (query instanceof TraveltimeSearchQuery) { | ||
if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { | ||
paramList.add(((TraveltimeSearchQuery) query)); | ||
} | ||
} | ||
} | ||
|
||
public TraveltimeSearchQuery getQuery() { | ||
if (paramList.size() == 1) return paramList.get(0); | ||
else return null; | ||
} | ||
} | ||
|
||
@Override | ||
public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { | ||
Query query = fetchContext.query(); | ||
val finder = new ParamFinder(); | ||
query.visit(finder); | ||
TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); | ||
if (traveltimeQuery == null) return null; | ||
TraveltimeQueryParameters params = traveltimeQuery.getParams(); | ||
final String output = traveltimeQuery.getOutput(); | ||
|
||
FieldFetcher fieldFetcher = FieldFetcher.create(fetchContext.getSearchExecutionContext(), List.of(new FieldAndFormat(params.getField(), null))); | ||
|
||
return new FetchSubPhaseProcessor() { | ||
|
||
@Override | ||
public void setNextReader(LeafReaderContext readerContext) { | ||
fieldFetcher.setNextReader(readerContext); | ||
} | ||
|
||
@Override | ||
public void process(HitContext hitContext) throws IOException { | ||
val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); | ||
docValues.advance(hitContext.docId()); | ||
Integer tt = TraveltimeCache.INSTANCE.get(params, docValues.nextValue()); | ||
|
||
if (tt >= 0) { | ||
hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); | ||
} | ||
} | ||
|
||
@Override | ||
public StoredFieldsSpec storedFieldsSpec() { | ||
return new StoredFieldsSpec(false, false, Set.of(params.getField())); | ||
} | ||
}; | ||
} | ||
} |
192 changes: 192 additions & 0 deletions
192
8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
package com.traveltime.plugin.elasticsearch.query; | ||
|
||
import com.traveltime.plugin.elasticsearch.TraveltimePlugin; | ||
import com.traveltime.plugin.elasticsearch.util.Util; | ||
import com.traveltime.sdk.dto.common.Coordinates; | ||
import com.traveltime.sdk.dto.requests.proto.Country; | ||
import com.traveltime.sdk.dto.requests.proto.RequestType; | ||
import com.traveltime.sdk.dto.requests.proto.Transportation; | ||
import lombok.NonNull; | ||
import lombok.Setter; | ||
import org.apache.lucene.search.Query; | ||
import org.elasticsearch.TransportVersion; | ||
import org.elasticsearch.TransportVersions; | ||
import org.elasticsearch.common.geo.GeoPoint; | ||
import org.elasticsearch.common.geo.GeoUtils; | ||
import org.elasticsearch.common.io.stream.StreamInput; | ||
import org.elasticsearch.common.io.stream.StreamOutput; | ||
import org.elasticsearch.index.mapper.GeoPointFieldMapper; | ||
import org.elasticsearch.index.mapper.MappedFieldType; | ||
import org.elasticsearch.index.query.*; | ||
import org.elasticsearch.xcontent.XContentBuilder; | ||
import org.elasticsearch.xcontent.XContentParser; | ||
|
||
import java.io.IOException; | ||
import java.net.URI; | ||
import java.util.Objects; | ||
import java.util.Optional; | ||
|
||
@Setter | ||
public class TraveltimeQueryBuilder extends AbstractQueryBuilder<TraveltimeQueryBuilder> { | ||
@NonNull | ||
private String field; | ||
@NonNull | ||
private GeoPoint origin; | ||
private int limit; | ||
private Transportation.Modes mode; | ||
private Country country; | ||
private RequestType requestType; | ||
private QueryBuilder prefilter; | ||
@NonNull | ||
private String output = ""; | ||
|
||
public TraveltimeQueryBuilder() { | ||
} | ||
|
||
public TraveltimeQueryBuilder(StreamInput in) throws IOException { | ||
super(in); | ||
field = in.readString(); | ||
origin = in.readGeoPoint(); | ||
limit = in.readInt(); | ||
mode = in.readOptionalEnum(Transportation.Modes.class); | ||
String c = in.readOptionalString(); | ||
if(c != null) country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); | ||
requestType = in.readOptionalEnum(RequestType.class); | ||
prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); | ||
output = in.readString(); | ||
} | ||
|
||
@Override | ||
protected void doWriteTo(StreamOutput out) throws IOException { | ||
out.writeString(field); | ||
out.writeGeoPoint(origin); | ||
out.writeInt(limit); | ||
out.writeOptionalEnum(mode); | ||
out.writeOptionalString(country == null ? null : country.getValue()); | ||
out.writeOptionalEnum(requestType); | ||
out.writeOptionalNamedWriteable(prefilter); | ||
out.writeString(output); | ||
} | ||
|
||
@Override | ||
protected void doXContent(XContentBuilder builder, Params params) throws IOException { | ||
builder.field("field", field); | ||
builder.field("origin", origin); | ||
builder.field("limit", limit); | ||
builder.field("mode", mode == null ? null : mode.getValue()); | ||
builder.field("country", country == null ? null : country.getValue()); | ||
builder.field("requestType", requestType == null ? null : requestType.name()); | ||
builder.field("prefilter", prefilter); | ||
builder.field("output", output); | ||
} | ||
|
||
@Override | ||
protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { | ||
if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); | ||
return super.doRewrite(queryRewriteContext); | ||
} | ||
|
||
@Override | ||
protected Query doToQuery(SearchExecutionContext context) throws IOException { | ||
MappedFieldType originMapping = context.getFieldType(field); | ||
if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { | ||
throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); | ||
} | ||
|
||
GeoUtils.normalizePoint(origin); | ||
if (!GeoUtils.isValidLatitude(origin.getLat())) { | ||
throw new QueryShardException(context, "latitude invalid for origin " + origin); | ||
} | ||
if (!GeoUtils.isValidLongitude(origin.getLon())) { | ||
throw new QueryShardException(context, "longitude invalid for origin " + origin); | ||
} | ||
|
||
URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); | ||
String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); | ||
String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); | ||
if (appId.isEmpty()) { | ||
throw new IllegalStateException("Traveltime app id must be set in the config"); | ||
} | ||
if (apiKey.isEmpty()) { | ||
throw new IllegalStateException("Traveltime api key must be set in the config"); | ||
} | ||
|
||
Optional<Transportation.Modes> defaultMode = TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); | ||
Optional<Country> defaultCountry = TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); | ||
Optional<RequestType> defaultRequestType = TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); | ||
|
||
Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); | ||
|
||
TraveltimeQueryParameters params = new TraveltimeQueryParameters(field, originCoord, limit, mode, country, requestType); | ||
if (params.getMode() == null) { | ||
if (defaultMode.isPresent()) { | ||
params = params.withMode(defaultMode.get()); | ||
} else { | ||
throw new IllegalStateException("Traveltime query requires either 'mode' field to be present or a default mode to be set in the config"); | ||
} | ||
} | ||
if (params.getCountry() == null) { | ||
if (defaultCountry.isPresent()) { | ||
params = params.withCountry(defaultCountry.get()); | ||
} else { | ||
throw new IllegalStateException("Traveltime query requires either 'country' field to be present or a default country to be set in the config"); | ||
} | ||
} | ||
if(params.getRequestType() == null) { | ||
if(defaultRequestType.isPresent()) { | ||
params = params.withRequestType(defaultRequestType.get()); | ||
} else { | ||
throw new IllegalStateException("Traveltime query requires either 'requestType' field to be present or a default request type to be set in the config"); | ||
} | ||
|
||
} | ||
if (params.getLimit() <= 0) { | ||
throw new IllegalStateException("Traveltime limit must be greater than zero"); | ||
} | ||
|
||
Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; | ||
|
||
return new TraveltimeSearchQuery(params, prefilterQuery, output, appUri, appId, apiKey); | ||
} | ||
|
||
@Override | ||
protected boolean doEquals(TraveltimeQueryBuilder other) { | ||
if (!Objects.equals(this.field, other.field)) return false; | ||
if (!Objects.equals(this.origin, other.origin)) return false; | ||
if (!Objects.equals(this.mode, other.mode)) return false; | ||
if (!Objects.equals(this.country, other.country)) return false; | ||
if (!Objects.equals(this.prefilter, other.prefilter)) return false; | ||
if (!Objects.equals(this.output, other.output)) return false; | ||
return this.limit == other.limit; | ||
} | ||
|
||
@Override | ||
protected int doHashCode() { | ||
final int PRIME = 59; | ||
int result = 1; | ||
result = result * PRIME + this.field.hashCode(); | ||
result = result * PRIME + this.origin.hashCode(); | ||
result = result * PRIME + Objects.hashCode(this.mode); | ||
result = result * PRIME + Objects.hashCode(this.country); | ||
result = result * PRIME + Objects.hashCode(this.prefilter); | ||
result = result * PRIME + Objects.hashCode(this.output); | ||
result = result * PRIME + this.limit; | ||
return result; | ||
} | ||
|
||
@Override | ||
public String getWriteableName() { | ||
return TraveltimeQueryParser.NAME; | ||
} | ||
|
||
@Override | ||
public TransportVersion getMinimalSupportedVersion() { | ||
return TransportVersions.MINIMUM_COMPATIBLE; | ||
} | ||
|
||
public static QueryBuilder parseInnerQueryBuilder(XContentParser parser) throws IOException { | ||
return AbstractQueryBuilder.parseInnerQueryBuilder(parser); | ||
} | ||
|
||
|
||
} |
Oops, something went wrong.