Skip to content

Commit

Permalink
[FSTORE-1135] Preview embedding in vector db in UI (#1698)
Browse files Browse the repository at this point in the history
* preview embedding

* check prefix

* modify error message and remove field constructor

* fix style

(cherry picked from commit 4131e86)
  • Loading branch information
kennethmhc committed Feb 21, 2024
1 parent 71d009a commit e16c2dd
Show file tree
Hide file tree
Showing 8 changed files with 130 additions and 3 deletions.
13 changes: 13 additions & 0 deletions hopsworks-IT/src/test/ruby/spec/similarity_search_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,18 @@
expect(parsed_json.key?("embeddingIndex")).to be true
expect(parsed_json["embeddingIndex"]["indexName"]).to eq("#{project.id}__embedding_default_project_embedding_0")
expect(parsed_json["embeddingIndex"]["features"].length).to be 1
expect(parsed_json["embeddingIndex"]["colPrefix"]).to eq("#{parsed_json["id"].to_s}_")
end

it "should be able to preview data from vector db" do
project = get_project
featurestore_id = get_featurestore_id(project.id)
json_result, featuregroup_name = create_cached_featuregroup(project.id, featurestore_id, online:true, embedding_index_name: "")
parsed_json = JSON.parse(json_result)
expect_status_details(201)
featuregroup_id = parsed_json["id"]
get "#{ENV['HOPSWORKS_API']}/project/#{project.id.to_s}/featurestores/#{featurestore_id.to_s}/featuregroups/#{featuregroup_id.to_s}/preview?storage=online"
expect_status_details(200)
end

it "should be able to delete a feature group with embedding and project index" do
Expand Down Expand Up @@ -76,6 +88,7 @@
expect(parsed_json.key?("embeddingIndex")).to be true
expect(parsed_json["embeddingIndex"]["indexName"]).to eq("#{project.id}__embedding_test_index")
expect(parsed_json["embeddingIndex"]["features"].length).to be 1
expect(parsed_json["embeddingIndex"]["colPrefix"]).to eq("")
end

it "should be able to delete a feature group with embedding and custom index" do
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import com.google.common.collect.Sets;
import io.hops.hopsworks.common.featurestore.featuregroup.EmbeddingDTO;
import io.hops.hopsworks.common.featurestore.featuregroup.FeaturegroupController;
import io.hops.hopsworks.common.hdfs.Utils;
import io.hops.hopsworks.common.models.ModelFacade;
import io.hops.hopsworks.common.models.version.ModelVersionFacade;
import io.hops.hopsworks.common.util.Settings;
Expand Down Expand Up @@ -85,6 +84,12 @@ private ModelVersion getModel(Integer projectId, String modelName, Integer model
return modelVersionFacade.findByProjectAndMlId(model.getId(), modelVersion);
}

public String getFieldName(Embedding embedding, String featureName) {
return embedding.getColPrefix() == null
? featureName
: embedding.getColPrefix() + featureName;
}

public Embedding getEmbedding(Project project, EmbeddingDTO embeddingDTO, Featuregroup featuregroup)
throws FeaturestoreException {
Embedding embedding = new Embedding();
Expand Down Expand Up @@ -247,7 +252,8 @@ private String getVectorDbIndexPrefix(Project project) {
}

private String getVectorDbColPrefix(Featuregroup featuregroup) {
return Utils.getFeaturegroupName(featuregroup) + "_";
// Should use id as prefix instead of name + version since users can recreate fg with the same name and version
return featuregroup.getId() + "_";
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ public void addValue(Pair<String, String> value) {
values.add(value);
}

public void addValue(String col, String value) {
values.add(new Pair<>(col, value));
}

public List<Pair<String, String>> getValues() {
return values;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.logicalclocks.shaded.org.apache.commons.lang3.StringUtils;
import io.hops.hopsworks.common.dao.kafka.TopicDTO;
import io.hops.hopsworks.common.featurestore.embedding.EmbeddingController;
import io.hops.hopsworks.common.featurestore.embedding.VectorDatabaseClient;
import io.hops.hopsworks.common.featurestore.feature.FeatureGroupFeatureDTO;
import io.hops.hopsworks.common.featurestore.featuregroup.FeaturegroupController;
import io.hops.hopsworks.common.featurestore.featuregroup.cached.FeaturegroupPreview;
Expand All @@ -46,6 +47,9 @@
import io.hops.hopsworks.persistence.entity.project.Project;
import io.hops.hopsworks.persistence.entity.user.Users;
import io.hops.hopsworks.restutils.RESTCodes;
import io.hops.hopsworks.vectordb.Field;
import io.hops.hopsworks.vectordb.Index;
import io.hops.hopsworks.vectordb.VectorDatabaseException;
import org.apache.calcite.sql.SqlDialect;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlLiteral;
Expand All @@ -63,6 +67,7 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.logging.Level;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -98,6 +103,8 @@ public class OnlineFeaturegroupController {
private FeaturegroupController featuregroupController;
@EJB
private EmbeddingController embeddingController;
@EJB
private VectorDatabaseClient vectorDatabaseClient;

private final static List<String> SUPPORTED_MYSQL_TYPES = Arrays.asList("INT", "TINYINT",
"SMALLINT", "BIGINT", "FLOAT", "DOUBLE", "DECIMAL", "DATE", "TIMESTAMP");
Expand Down Expand Up @@ -361,6 +368,47 @@ public String getOnlineType(FeatureGroupFeatureDTO featureGroupFeatureDTO) {
*/
public FeaturegroupPreview getFeaturegroupPreview(Featuregroup featuregroup, Project project, Users user, int limit)
throws FeaturestoreException {
if (featuregroup.getEmbedding() == null) {
return getFeaturegroupPreviewRonDb(featuregroup, project, user, limit);
} else {
return getFeaturegroupPreviewVectorDb(featuregroup, project, user, limit);
}
}

private FeaturegroupPreview getFeaturegroupPreviewVectorDb(Featuregroup featuregroup, Project project, Users user,
int limit) throws FeaturestoreException {
try {
List<FeatureGroupFeatureDTO> features = featuregroupController.getFeatures(featuregroup, project, user);
Index index = new Index(featuregroup.getEmbedding().getVectorDbIndexName());
Set<String> primaryKeyFields = featuregroupController.getPrimaryKey(featuregroup)
.stream()
// fetching documents where pk column of the fg is not null
.map(pk -> embeddingController.getFieldName(featuregroup.getEmbedding(), pk.getName()))
.collect(Collectors.toSet());
Set<Field> targetFields = vectorDatabaseClient.getClient().getSchema(index)
.stream()
.filter(field -> primaryKeyFields.contains(field.getName()))
.collect(Collectors.toSet());
FeaturegroupPreview preview = new FeaturegroupPreview();
vectorDatabaseClient.getClient().preview(index, targetFields, limit)
.forEach(result -> {
FeaturegroupPreview.Row row = new FeaturegroupPreview.Row();
features.forEach(feature -> row.addValue(
feature.getName(),
result.getOrDefault(
embeddingController.getFieldName(
featuregroup.getEmbedding(), feature.getName()), "").toString()));
preview.addRow(row);
});
return preview;
} catch (VectorDatabaseException e) {
throw new FeaturestoreException(
RESTCodes.FeaturestoreErrorCode.COULD_NOT_PREVIEW_DATA_IN_VECTOR_DB, Level.SEVERE, "", e.getMessage(), e);
}
}

private FeaturegroupPreview getFeaturegroupPreviewRonDb(Featuregroup featuregroup, Project project,
Users user, int limit) throws FeaturestoreException {
String tbl = featuregroupController.getTblName(featuregroup.getName(), featuregroup.getVersion());

List<FeatureGroupFeatureDTO> features = featuregroupController.getFeatures(featuregroup, project, user);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package io.hops.hopsworks.persistence.entity.featurestore.featuregroup;

import io.hops.hopsworks.persistence.entity.models.version.ModelVersion;
import com.fasterxml.jackson.annotation.JsonIgnore;

import javax.persistence.Basic;
import javax.persistence.Column;
Expand Down Expand Up @@ -104,4 +105,12 @@ public String getSimilarityFunctionType() {
public ModelVersion getModelVersion() {
return modelVersion;
}

@JsonIgnore
public String getFieldName() {
return embedding.getColPrefix() == null
? name
: embedding.getColPrefix() + name;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -1698,7 +1698,11 @@ public enum FeaturestoreErrorCode implements RESTErrorCode {
Response.Status.INTERNAL_SERVER_ERROR),
FEATURE_MONITORING_ENTITY_NOT_FOUND(233, "Feature Monitoring entity not found.",
Response.Status.NOT_FOUND),
FEATURE_MONITORING_NOT_ENABLED(234, "Feature monitoring is not enabled.", Response.Status.BAD_REQUEST);
FEATURE_MONITORING_NOT_ENABLED(234, "Feature monitoring is not enabled.", Response.Status.BAD_REQUEST),
FEATURE_NOT_FOUND_IN_VECTOR_DB(235, "Feature not found in vector db.",
Response.Status.INTERNAL_SERVER_ERROR),
COULD_NOT_PREVIEW_DATA_IN_VECTOR_DB(236, "Could not preview data in vector database.",
Response.Status.INTERNAL_SERVER_ERROR);

private int code;
private String message;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import org.opensearch.action.bulk.BulkResponse;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.master.AcknowledgedResponse;
import org.opensearch.client.Request;
import org.opensearch.client.RequestOptions;
Expand All @@ -36,10 +38,14 @@
import org.opensearch.client.indices.GetIndexResponse;
import org.opensearch.client.indices.PutMappingRequest;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.QueryStringQueryBuilder;
import org.opensearch.index.reindex.BulkByScrollResponse;
import org.opensearch.index.reindex.DeleteByQueryRequest;
import org.opensearch.rest.RestStatus;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;

import java.io.IOException;
import java.util.List;
Expand Down Expand Up @@ -193,6 +199,42 @@ public void batchWrite(Index index, Map<String, String> data) throws VectorDatab
bulkRequest(bulkRequest);
}

@Override
public List<Map<String, Object>> preview(Index index, Set<Field> fields, int n) throws VectorDatabaseException {
List<Map<String, Object>> results = Lists.newArrayList();

// Create a bool query
BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery();

// Add exists queries for each field
for (Field field : fields) {
boolQueryBuilder.must(QueryBuilders.existsQuery(field.getName()));
}

// Create a SearchSourceBuilder to define the search query
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
sourceBuilder.query(boolQueryBuilder);
sourceBuilder.size(n);

// Create a SearchRequest with the specified index and source builder
SearchRequest searchRequest = new SearchRequest(index.getName());
searchRequest.source(sourceBuilder);

try {
// Execute the search request
SearchResponse searchResponse = client.search(searchRequest, RequestOptions.DEFAULT);
// Process the search hits
for (SearchHit hit : searchResponse.getHits().getHits()) {
results.add(hit.getSourceAsMap());
}

} catch (IOException e) {
throw new VectorDatabaseException("Error occurred while querying OpenSearch index");
}

return results;
}

private void bulkRequest(BulkRequest bulkRequest) throws VectorDatabaseException {
try {
BulkResponse response = client.bulk(bulkRequest, RequestOptions.DEFAULT);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,6 @@ public interface VectorDatabase {
void batchWrite(Index index, List<String> data) throws VectorDatabaseException;
void batchWrite(Index index, Map<String, String> data) throws VectorDatabaseException;
void deleteByQuery(Index index, String query) throws VectorDatabaseException;
List<Map<String, Object>> preview(Index index, Set<Field> fields, int n) throws VectorDatabaseException;
void close();
}

0 comments on commit e16c2dd

Please sign in to comment.