Skip to content

Commit

Permalink
Add support for version 8.13.0
Browse files Browse the repository at this point in the history
  • Loading branch information
mjanuszkiewicz-tt committed Mar 28, 2024
1 parent ba690ee commit eabe2c9
Show file tree
Hide file tree
Showing 10 changed files with 722 additions and 0 deletions.
6 changes: 6 additions & 0 deletions 8.13/build.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
buildPlugin(this, '8.13', revisions(0))

compileJava {
sourceCompatibility = JavaVersion.VERSION_17
targetCompatibility = JavaVersion.VERSION_17
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package com.traveltime.plugin.elasticsearch;


import com.traveltime.plugin.elasticsearch.query.TraveltimeFetchPhase;
import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryBuilder;
import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryParser;
import com.traveltime.plugin.elasticsearch.util.Util;
import com.traveltime.sdk.dto.requests.proto.Country;
import com.traveltime.sdk.dto.requests.proto.RequestType;
import com.traveltime.sdk.dto.requests.proto.Transportation;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.routing.allocation.AllocationService;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.env.Environment;
import org.elasticsearch.env.NodeEnvironment;
import org.elasticsearch.indices.IndicesService;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.SearchPlugin;
import org.elasticsearch.repositories.RepositoriesService;
import org.elasticsearch.script.ScriptService;
import org.elasticsearch.search.fetch.FetchSubPhase;
import org.elasticsearch.telemetry.TelemetryProvider;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.watcher.ResourceWatcherService;
import org.elasticsearch.xcontent.NamedXContentRegistry;

import java.net.URI;
import java.time.Duration;
import java.util.Collection;
import java.util.List;
import java.util.Optional;
import java.util.function.Supplier;

public class TraveltimePlugin extends Plugin implements SearchPlugin {
public static final Setting<String> APP_ID = Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope);
public static final Setting<String> API_KEY = Setting.simpleString("traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered);
public static final Setting<Optional<Transportation.Modes>> DEFAULT_MODE = new Setting<>("traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope);
public static final Setting<Optional<Country>> DEFAULT_COUNTRY = new Setting<>("traveltime.default.country", s -> "", Util::findCountryByName, Setting.Property.NodeScope);

public static final Setting<Optional<RequestType>> DEFAULT_REQUEST_TYPE = new Setting<>("traveltime.default.request_type", s -> RequestType.ONE_TO_MANY.name(), Util::findRequestTypeByName, Setting.Property.NodeScope);
public static final Setting<URI> API_URI = new Setting<>("traveltime.api.uri", s -> "https://proto.api.traveltimeapp.com/api/v2/", URI::create, Setting.Property.NodeScope);

private static final Setting<Integer> CACHE_CLEANUP_INTERVAL = Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope);
private static final Setting<Integer> CACHE_EXPIRY = Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope);
private static final Setting<Integer> CACHE_SIZE = Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope);

private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) {
TraveltimeCache.INSTANCE.cleanUp();
threadPool.scheduleUnlessShuttingDown(cleanupSeconds, threadPool.generic(), () -> cleanUpAndReschedule(threadPool, cleanupSeconds));
}

@Override
public Collection<?> createComponents(PluginServices pluginServices) {
TimeValue cleanupSeconds = TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(pluginServices.environment().settings()));
Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(pluginServices.environment().settings()));
Integer cacheSize = CACHE_SIZE.get(pluginServices.environment().settings());

TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry);
cleanUpAndReschedule(pluginServices.threadPool(), cleanupSeconds);

return super.createComponents(pluginServices);
}

@Override
public List<Setting<?>> getSettings() {
return List.of(APP_ID, API_KEY, DEFAULT_MODE, DEFAULT_COUNTRY, DEFAULT_REQUEST_TYPE, API_URI, CACHE_SIZE, CACHE_EXPIRY, CACHE_CLEANUP_INTERVAL);
}

@Override
public List<QuerySpec<?>> getQueries() {
return List.of(new QuerySpec<>(TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser()));
}

@Override
public List<FetchSubPhase> getFetchSubPhases(FetchPhaseConstructionContext context) {
return List.of(new TraveltimeFetchPhase());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package com.traveltime.plugin.elasticsearch.query;

import com.traveltime.plugin.elasticsearch.TraveltimeCache;
import lombok.val;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.elasticsearch.common.document.DocumentField;
import org.elasticsearch.search.fetch.FetchContext;
import org.elasticsearch.search.fetch.FetchSubPhase;
import org.elasticsearch.search.fetch.FetchSubPhaseProcessor;
import org.elasticsearch.search.fetch.StoredFieldsSpec;
import org.elasticsearch.search.fetch.subphase.FieldAndFormat;
import org.elasticsearch.search.fetch.subphase.FieldFetcher;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;

public class TraveltimeFetchPhase implements FetchSubPhase {

private static class ParamFinder extends QueryVisitor {
private final List<TraveltimeSearchQuery> paramList = new ArrayList<>();

@Override
public void visitLeaf(Query query) {
if (query instanceof TraveltimeSearchQuery) {
if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) {
paramList.add(((TraveltimeSearchQuery) query));
}
}
}

public TraveltimeSearchQuery getQuery() {
if (paramList.size() == 1) return paramList.get(0);
else return null;
}
}

@Override
public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) {
Query query = fetchContext.query();
val finder = new ParamFinder();
query.visit(finder);
TraveltimeSearchQuery traveltimeQuery = finder.getQuery();
if (traveltimeQuery == null) return null;
TraveltimeQueryParameters params = traveltimeQuery.getParams();
final String output = traveltimeQuery.getOutput();

FieldFetcher fieldFetcher = FieldFetcher.create(fetchContext.getSearchExecutionContext(), List.of(new FieldAndFormat(params.getField(), null)));

return new FetchSubPhaseProcessor() {

@Override
public void setNextReader(LeafReaderContext readerContext) {
fieldFetcher.setNextReader(readerContext);
}

@Override
public void process(HitContext hitContext) throws IOException {
val docValues = hitContext.reader().getSortedNumericDocValues(params.getField());
docValues.advance(hitContext.docId());
Integer tt = TraveltimeCache.INSTANCE.get(params, docValues.nextValue());

if (tt >= 0) {
hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt)));
}
}

@Override
public StoredFieldsSpec storedFieldsSpec() {
return new StoredFieldsSpec(false, false, Set.of(params.getField()));
}
};
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
package com.traveltime.plugin.elasticsearch.query;

import com.traveltime.plugin.elasticsearch.TraveltimePlugin;
import com.traveltime.plugin.elasticsearch.util.Util;
import com.traveltime.sdk.dto.common.Coordinates;
import com.traveltime.sdk.dto.requests.proto.Country;
import com.traveltime.sdk.dto.requests.proto.RequestType;
import com.traveltime.sdk.dto.requests.proto.Transportation;
import lombok.NonNull;
import lombok.Setter;
import org.apache.lucene.search.Query;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.geo.GeoPoint;
import org.elasticsearch.common.geo.GeoUtils;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.index.mapper.GeoPointFieldMapper;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.query.*;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;

import java.io.IOException;
import java.net.URI;
import java.util.Objects;
import java.util.Optional;

@Setter
public class TraveltimeQueryBuilder extends AbstractQueryBuilder<TraveltimeQueryBuilder> {
@NonNull
private String field;
@NonNull
private GeoPoint origin;
private int limit;
private Transportation.Modes mode;
private Country country;
private RequestType requestType;
private QueryBuilder prefilter;
@NonNull
private String output = "";

public TraveltimeQueryBuilder() {
}

public TraveltimeQueryBuilder(StreamInput in) throws IOException {
super(in);
field = in.readString();
origin = in.readGeoPoint();
limit = in.readInt();
mode = in.readOptionalEnum(Transportation.Modes.class);
String c = in.readOptionalString();
if(c != null) country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c));
requestType = in.readOptionalEnum(RequestType.class);
prefilter = in.readOptionalNamedWriteable(QueryBuilder.class);
output = in.readString();
}

@Override
protected void doWriteTo(StreamOutput out) throws IOException {
out.writeString(field);
out.writeGeoPoint(origin);
out.writeInt(limit);
out.writeOptionalEnum(mode);
out.writeOptionalString(country == null ? null : country.getValue());
out.writeOptionalEnum(requestType);
out.writeOptionalNamedWriteable(prefilter);
out.writeString(output);
}

@Override
protected void doXContent(XContentBuilder builder, Params params) throws IOException {
builder.field("field", field);
builder.field("origin", origin);
builder.field("limit", limit);
builder.field("mode", mode == null ? null : mode.getValue());
builder.field("country", country == null ? null : country.getValue());
builder.field("requestType", requestType == null ? null : requestType.name());
builder.field("prefilter", prefilter);
builder.field("output", output);
}

@Override
protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException {
if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext);
return super.doRewrite(queryRewriteContext);
}

@Override
protected Query doToQuery(SearchExecutionContext context) throws IOException {
MappedFieldType originMapping = context.getFieldType(field);
if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) {
throw new QueryShardException(context, "field [" + field + "] is not a geo_point field");
}

GeoUtils.normalizePoint(origin);
if (!GeoUtils.isValidLatitude(origin.getLat())) {
throw new QueryShardException(context, "latitude invalid for origin " + origin);
}
if (!GeoUtils.isValidLongitude(origin.getLon())) {
throw new QueryShardException(context, "longitude invalid for origin " + origin);
}

URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings());
String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings());
String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings());
if (appId.isEmpty()) {
throw new IllegalStateException("Traveltime app id must be set in the config");
}
if (apiKey.isEmpty()) {
throw new IllegalStateException("Traveltime api key must be set in the config");
}

Optional<Transportation.Modes> defaultMode = TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings());
Optional<Country> defaultCountry = TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings());
Optional<RequestType> defaultRequestType = TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings());

Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build();

TraveltimeQueryParameters params = new TraveltimeQueryParameters(field, originCoord, limit, mode, country, requestType);
if (params.getMode() == null) {
if (defaultMode.isPresent()) {
params = params.withMode(defaultMode.get());
} else {
throw new IllegalStateException("Traveltime query requires either 'mode' field to be present or a default mode to be set in the config");
}
}
if (params.getCountry() == null) {
if (defaultCountry.isPresent()) {
params = params.withCountry(defaultCountry.get());
} else {
throw new IllegalStateException("Traveltime query requires either 'country' field to be present or a default country to be set in the config");
}
}
if(params.getRequestType() == null) {
if(defaultRequestType.isPresent()) {
params = params.withRequestType(defaultRequestType.get());
} else {
throw new IllegalStateException("Traveltime query requires either 'requestType' field to be present or a default request type to be set in the config");
}

}
if (params.getLimit() <= 0) {
throw new IllegalStateException("Traveltime limit must be greater than zero");
}

Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null;

return new TraveltimeSearchQuery(params, prefilterQuery, output, appUri, appId, apiKey);
}

@Override
protected boolean doEquals(TraveltimeQueryBuilder other) {
if (!Objects.equals(this.field, other.field)) return false;
if (!Objects.equals(this.origin, other.origin)) return false;
if (!Objects.equals(this.mode, other.mode)) return false;
if (!Objects.equals(this.country, other.country)) return false;
if (!Objects.equals(this.prefilter, other.prefilter)) return false;
if (!Objects.equals(this.output, other.output)) return false;
return this.limit == other.limit;
}

@Override
protected int doHashCode() {
final int PRIME = 59;
int result = 1;
result = result * PRIME + this.field.hashCode();
result = result * PRIME + this.origin.hashCode();
result = result * PRIME + Objects.hashCode(this.mode);
result = result * PRIME + Objects.hashCode(this.country);
result = result * PRIME + Objects.hashCode(this.prefilter);
result = result * PRIME + Objects.hashCode(this.output);
result = result * PRIME + this.limit;
return result;
}

@Override
public String getWriteableName() {
return TraveltimeQueryParser.NAME;
}

@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.MINIMUM_COMPATIBLE;
}

public static QueryBuilder parseInnerQueryBuilder(XContentParser parser) throws IOException {
return AbstractQueryBuilder.parseInnerQueryBuilder(parser);
}


}
Loading

0 comments on commit eabe2c9

Please sign in to comment.