Skip to content

Commit

Permalink
Merge pull request #668 from jmartisk/issue-667
Browse files Browse the repository at this point in the history
Avoid trying to retrieve the default ds if a named one should be used
  • Loading branch information
geoand authored Jun 12, 2024
2 parents 741c90f + 4d8a69f commit 470a775
Show file tree
Hide file tree
Showing 9 changed files with 145 additions and 71 deletions.
2 changes: 1 addition & 1 deletion embedding-stores/pgvector/deployment/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-junit5</artifactId>
<artifactId>quarkus-junit5-internal</artifactId>
<scope>test</scope>
</dependency>
<dependency>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,27 +1,28 @@
package io.quarkiverse.langchain4j.pgvector.test;

import java.util.Map;

import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.asset.StringAsset;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.test.junit.QuarkusTest;
import io.quarkus.test.junit.QuarkusTestProfile;
import io.quarkus.test.junit.TestProfile;
import io.quarkus.test.QuarkusUnitTest;

@QuarkusTest
@TestProfile(ColumnsTest.TestProfile.class)
class ColumnsTest extends LangChain4jPgVectorBaseTest {

public static class TestProfile implements QuarkusTestProfile {
@Override
public Map<String, String> getConfigOverrides() {
return Map.of(
"quarkus.langchain4j.pgvector.metadata.storage-mode", "COLUMN_PER_KEY",
"quarkus.langchain4j.pgvector.metadata.column-definitions", "key text NULL, name text NULL, " +
"age float NULL, city varchar null, country varchar null",
"quarkus.langchain4j.pgvector.metadata.indexes", "key, name, age");
}
}
@RegisterExtension
static final QuarkusUnitTest test = new QuarkusUnitTest()
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
.addAsResource(new StringAsset(
"quarkus.langchain4j.pgvector.dimension=384\n" +
"quarkus.langchain4j.pgvector.drop-table-first=true\n" +
"quarkus.class-loading.parent-first-artifacts=ai.djl.huggingface:tokenizers\n" +
"quarkus.log.category.\"io.quarkiverse.langchain4j.pgvector\".level=DEBUG\n\n" +
"quarkus.langchain4j.pgvector.metadata.storage-mode=COLUMN_PER_KEY\n" +
"quarkus.langchain4j.pgvector.metadata.column-definitions=key text NULL, name text NULL, " +
"age float NULL, city varchar null, country varchar null\n" +
"quarkus.langchain4j.pgvector.metadata.indexes=key, name, age"),
"application.properties"));

@Test
// do not test parent method to avoid defining all the metadata fields
Expand Down
Original file line number Diff line number Diff line change
@@ -1,26 +1,27 @@
package io.quarkiverse.langchain4j.pgvector.test;

import java.util.Map;
import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.asset.StringAsset;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.test.junit.QuarkusTest;
import io.quarkus.test.junit.QuarkusTestProfile;
import io.quarkus.test.junit.TestProfile;
import io.quarkus.test.QuarkusUnitTest;

@QuarkusTest
@TestProfile(JSONBMultiIndexTest.TestProfile.class)
public class JSONBMultiIndexTest extends LangChain4jPgVectorBaseTest {

public static class TestProfile implements QuarkusTestProfile {

@Override
public Map<String, String> getConfigOverrides() {
return Map.of(
"quarkus.langchain4j.pgvector.metadata.storage-mode", "COMBINED_JSONB",
"quarkus.langchain4j.pgvector.metadata.column-definitions", "metadata_b JSONB NULL",
"quarkus.langchain4j.pgvector.metadata.indexes",
"(metadata_b->'key'), (metadata_b->'name'), (metadata_b->'age')",
"quarkus.langchain4j.pgvector.metadata.index-type", "GIN");
}
}
@RegisterExtension
static final QuarkusUnitTest test = new QuarkusUnitTest()
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
.addAsResource(new StringAsset(
"quarkus.langchain4j.pgvector.dimension=384\n" +
"quarkus.langchain4j.pgvector.drop-table-first=true\n" +
"quarkus.class-loading.parent-first-artifacts=ai.djl.huggingface:tokenizers\n" +
"quarkus.log.category.\"io.quarkiverse.langchain4j.pgvector\".level=DEBUG\n\n" +
"quarkus.langchain4j.pgvector.metadata.storage-mode=COMBINED_JSONB\n" +
"quarkus.langchain4j.pgvector.metadata.column-definitions=metadata_b JSONB NULL\n" +
"quarkus.langchain4j.pgvector.metadata.indexes=(metadata_b->'key'), (metadata_b->'name'), (metadata_b->'age')\n"
+
"quarkus.langchain4j.pgvector.metadata.index-type=GIN"),
"application.properties"));

}
Original file line number Diff line number Diff line change
@@ -1,24 +1,25 @@
package io.quarkiverse.langchain4j.pgvector.test;

import java.util.Map;
import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.asset.StringAsset;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.test.junit.QuarkusTest;
import io.quarkus.test.junit.QuarkusTestProfile;
import io.quarkus.test.junit.TestProfile;
import io.quarkus.test.QuarkusUnitTest;

@QuarkusTest
@TestProfile(JSONBTest.TestProfile.class)
public class JSONBTest extends LangChain4jPgVectorBaseTest {

public static class TestProfile implements QuarkusTestProfile {

@Override
public Map<String, String> getConfigOverrides() {
return Map.of(
"quarkus.langchain4j.pgvector.metadata.storage-mode", "COMBINED_JSONB",
"quarkus.langchain4j.pgvector.metadata.column-definitions", "metadata JSONB NULL",
"quarkus.langchain4j.pgvector.metadata.indexes", "metadata");
}
}
@RegisterExtension
static final QuarkusUnitTest test = new QuarkusUnitTest()
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
.addAsResource(new StringAsset(
"quarkus.langchain4j.pgvector.dimension=384\n" +
"quarkus.langchain4j.pgvector.drop-table-first=true\n" +
"quarkus.class-loading.parent-first-artifacts=ai.djl.huggingface:tokenizers\n" +
"quarkus.log.category.\"io.quarkiverse.langchain4j.pgvector\".level=DEBUG\n\n" +
"quarkus.langchain4j.pgvector.metadata.storage-mode=COMBINED_JSONB\n" +
"quarkus.langchain4j.pgvector.metadata.column-definitions=metadata JSONB NULL\n" +
"quarkus.langchain4j.pgvector.metadata.indexes=metadata"),
"application.properties"));

}
Original file line number Diff line number Diff line change
@@ -1,8 +1,23 @@
package io.quarkiverse.langchain4j.pgvector.test;

import io.quarkus.test.junit.QuarkusTest;
import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.asset.StringAsset;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.test.QuarkusUnitTest;

@QuarkusTest
public class JSONTest extends LangChain4jPgVectorBaseTest {

@RegisterExtension
static final QuarkusUnitTest test = new QuarkusUnitTest()
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
.addAsResource(new StringAsset(
"quarkus.langchain4j.pgvector.dimension=384\n" +
"quarkus.langchain4j.pgvector.drop-table-first=true\n" +
"quarkus.class-loading.parent-first-artifacts=ai.djl.huggingface:tokenizers\n" +
"quarkus.log.category.\"io.quarkiverse.langchain4j.pgvector\".level=DEBUG\n\n"),
"application.properties"));

// Default behavior
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,15 @@
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreWithFilteringIT;
import dev.langchain4j.store.embedding.EmbeddingStoreIT;
import dev.langchain4j.store.embedding.filter.MetadataFilterBuilder;
import io.quarkus.logging.Log;

abstract class LangChain4jPgVectorBaseTest extends EmbeddingStoreWithFilteringIT {
// FIXME: this should extend EmbeddingStoreWithFilteringIT, but that class
// contains tests parametrized through @MethodSource, which is not supported
// by the quarkus-junit5-internal testing framework
abstract class LangChain4jPgVectorBaseTest extends EmbeddingStoreIT {

private static final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
@Inject
protected EmbeddingStore<TextSegment> pgvectorEmbeddingStore;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package io.quarkiverse.langchain4j.pgvector.test;

import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;

import javax.sql.DataSource;

import jakarta.inject.Inject;

import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.asset.StringAsset;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.EmbeddingStore;
import io.quarkus.test.QuarkusUnitTest;

/**
* Verify use of a non-default postgresql datasource
*/
public class PgVectorDataSourceTest {

@RegisterExtension
static final QuarkusUnitTest test = new QuarkusUnitTest()
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
.addAsResource(new StringAsset(
// DevServicesConfigBuilderCustomizer overrides the image-name only
// for the default DS, so in this case we have to override it manually
"quarkus.datasource.embeddings-ds.devservices.image-name=pgvector/pgvector:pg16\n" +
"quarkus.langchain4j.pgvector.datasource=embeddings-ds\n" +
"quarkus.langchain4j.pgvector.dimension=1536\n" +
"quarkus.datasource.embeddings-ds.db-kind=postgresql\n"),
"application.properties"));

@io.quarkus.agroal.DataSource("embeddings-ds")
DataSource ds;

@Inject
EmbeddingStore<TextSegment> embeddingStore;

@Test
public void verifyThatEmbeddingsTableIsCreated() throws SQLException {
// make sure the store is initialized...
embeddingStore.toString();
try (Connection connection = ds.getConnection()) {
try (Statement statement = connection.createStatement()) {
try (ResultSet rs = statement.executeQuery(
"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'embeddings')")) {
rs.next();
Assertions.assertTrue(rs.getBoolean(1));
}
}
}
}
}

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package io.quarkiverse.langchain4j.pgvector.runtime;

import java.util.List;
import java.util.Optional;
import java.util.function.Function;

import jakarta.enterprise.inject.Default;
Expand All @@ -19,10 +18,13 @@ public class PgVectorEmbeddingStoreRecorder {
public Function<SyntheticCreationalContext<PgVectorEmbeddingStore>, PgVectorEmbeddingStore> embeddingStoreFunction(
PgVectorEmbeddingStoreConfig config, String datasourceName) {
return context -> {
AgroalDataSource dataSource = Optional.ofNullable(datasourceName)
.map(DataSourceLiteral::new)
.map(dl -> context.getInjectedReference(AgroalDataSource.class, dl))
.orElse(context.getInjectedReference(AgroalDataSource.class, new Default.Literal()));
AgroalDataSource dataSource = null;
if (datasourceName != null) {
dataSource = context.getInjectedReference(AgroalDataSource.class,
new DataSourceLiteral(datasourceName));
} else {
dataSource = context.getInjectedReference(AgroalDataSource.class, new Default.Literal());
}

dataSource.flush(AgroalDataSource.FlushMode.GRACEFUL);
dataSource.setPoolInterceptors(List.of(new PgVectorAgroalPoolInterceptor()));
Expand Down

0 comments on commit 470a775

Please sign in to comment.