diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/Prompt.java b/apps/opik-backend/src/main/java/com/comet/opik/api/Prompt.java index 21829957f4..44d929df4a 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/Prompt.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/Prompt.java @@ -15,6 +15,7 @@ import java.util.List; import java.util.UUID; +import static com.comet.opik.api.PromptType.MUSTACHE; import static com.comet.opik.utils.ValidationUtils.NULL_OR_NOT_BLANK; @Builder(toBuilder = true) @@ -33,6 +34,7 @@ public record Prompt( Prompt.View.Write.class}) @Pattern(regexp = NULL_OR_NOT_BLANK, message = "must not be blank") @Nullable String template, @JsonView({Prompt.View.Write.class}) @Nullable JsonNode metadata, @JsonView({Prompt.View.Write.class}) @Nullable String changeDescription, + @JsonView({Prompt.View.Write.class}) @Nullable PromptType type, @JsonView({Prompt.View.Public.class, Prompt.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant createdAt, @JsonView({Prompt.View.Public.class, @@ -74,4 +76,9 @@ public static Prompt.PromptPage empty(int page) { return new Prompt.PromptPage(page, 0, 0, List.of()); } } + + @Override + public PromptType type() { + return type == null ? MUSTACHE : type; + } } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/PromptType.java b/apps/opik-backend/src/main/java/com/comet/opik/api/PromptType.java new file mode 100644 index 0000000000..891299af4e --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/PromptType.java @@ -0,0 +1,27 @@ +package com.comet.opik.api; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonValue; +import lombok.Getter; +import lombok.RequiredArgsConstructor; + +import java.util.Arrays; + +@Getter +@RequiredArgsConstructor +public enum PromptType { + + MUSTACHE("mustache"), + JINJA2("jinja2"); + + @JsonValue + private final String value; + + @JsonCreator + public static PromptType fromString(String value) { + return Arrays.stream(values()) + .filter(promptType -> promptType.value.equals(value)) + .findFirst() + .orElseThrow(() -> new IllegalArgumentException("Unknown prompt type '%s'".formatted(value))); + } +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersion.java b/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersion.java index 0a8a47ae34..8a15e18aff 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersion.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersion.java @@ -18,6 +18,8 @@ import java.util.Set; import java.util.UUID; +import static com.comet.opik.api.PromptType.MUSTACHE; + @Builder(toBuilder = true) @JsonIgnoreProperties(ignoreUnknown = true) @JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) @@ -35,6 +37,8 @@ public record PromptVersion( PromptVersion.View.Detail.class}) @NotBlank String template, @Json @JsonView({PromptVersion.View.Public.class, Prompt.View.Detail.class, PromptVersion.View.Detail.class}) JsonNode metadata, + @JsonView({PromptVersion.View.Public.class, Prompt.View.Detail.class, + PromptVersion.View.Detail.class}) PromptType type, @JsonView({PromptVersion.View.Public.class, Prompt.View.Detail.class, PromptVersion.View.Detail.class}) String changeDescription, @JsonView({Prompt.View.Detail.class, @@ -68,4 +72,9 @@ public static PromptVersion.PromptVersionPage empty(int page) { return new PromptVersion.PromptVersionPage(page, 0, 0, List.of()); } } + + @Override + public PromptType type() { + return type == null ? MUSTACHE : type; + } } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/events/OnlineScoringEngine.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/events/OnlineScoringEngine.java index ca9733472d..73121e0685 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/events/OnlineScoringEngine.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/events/OnlineScoringEngine.java @@ -1,9 +1,10 @@ package com.comet.opik.api.resources.v1.events; import com.comet.opik.api.FeedbackScoreBatchItem; +import com.comet.opik.api.PromptType; import com.comet.opik.api.ScoreSource; import com.comet.opik.api.Trace; -import com.comet.opik.utils.MustacheUtils; +import com.comet.opik.utils.TemplateParseUtils; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; @@ -108,7 +109,8 @@ static List renderMessages(List templateMessages return templateMessages.stream() .map(templateMessage -> { // will convert all '{{key}}' into 'value' - var renderedMessage = MustacheUtils.render(templateMessage.content(), replacements); + var renderedMessage = TemplateParseUtils.render(templateMessage.content(), replacements, + PromptType.MUSTACHE); return switch (templateMessage.role()) { case USER -> UserMessage.from(renderedMessage); diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptDAO.java index 19575541b8..17d63431d8 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptDAO.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptDAO.java @@ -44,6 +44,7 @@ SELECT JSON_OBJECT( 'template', pv.template, 'metadata', pv.metadata, 'change_description', pv.change_description, + 'type', pv.type, 'created_at', pv.created_at, 'created_by', pv.created_by, 'last_updated_at', pv.last_updated_at, diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java index afdffdf177..a533b64705 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java @@ -2,11 +2,12 @@ import com.comet.opik.api.CreatePromptVersion; import com.comet.opik.api.Prompt; +import com.comet.opik.api.PromptType; import com.comet.opik.api.PromptVersion; import com.comet.opik.api.PromptVersion.PromptVersionPage; import com.comet.opik.api.error.EntityAlreadyExistsException; import com.comet.opik.infrastructure.auth.RequestContext; -import com.comet.opik.utils.MustacheUtils; +import com.comet.opik.utils.TemplateParseUtils; import com.google.inject.ImplementedBy; import io.dropwizard.jersey.errors.ErrorMessage; import jakarta.inject.Inject; @@ -119,6 +120,7 @@ private PromptVersion createPromptVersionFromPromptRequest(Prompt createdPrompt, .template(promptRequest.template()) .metadata(promptRequest.metadata()) .changeDescription(promptRequest.changeDescription()) + .type(promptRequest.type()) .createdBy(createdPrompt.createdBy()) .build(); @@ -347,7 +349,7 @@ private PromptVersion getById(String workspaceId, UUID id) { }); return promptVersion.toBuilder() - .variables(getVariables(promptVersion.template())) + .variables(getVariables(promptVersion.template(), promptVersion.type())) .build(); } @@ -368,7 +370,7 @@ public Prompt getById(@NonNull UUID id) { .latestVersion( Optional.ofNullable(prompt.latestVersion()) .map(promptVersion -> promptVersion.toBuilder() - .variables(getVariables(promptVersion.template())) + .variables(getVariables(promptVersion.template(), promptVersion.type())) .build()) .orElse(null)) .build(); @@ -388,7 +390,7 @@ public Mono> findVersionByIds(@NonNull Set ids) { return promptVersionDAO.findByIds(ids, workspaceId).stream() .collect(toMap(PromptVersion::id, promptVersion -> promptVersion.toBuilder() - .variables(getVariables(promptVersion.template())) + .variables(getVariables(promptVersion.template(), promptVersion.type())) .build())); }); }) @@ -413,16 +415,16 @@ public PromptVersion getVersionById(@NonNull String workspaceId, @NonNull UUID i } return promptVersion.toBuilder() - .variables(getVariables(promptVersion.template())) + .variables(getVariables(promptVersion.template(), promptVersion.type())) .build(); } - private Set getVariables(String template) { + private Set getVariables(String template, PromptType type) { if (template == null) { return null; } - return MustacheUtils.extractVariables(template); + return TemplateParseUtils.extractVariables(template, type); } private EntityAlreadyExistsException newConflict(String alreadyExists) { @@ -491,7 +493,7 @@ public PromptVersion retrievePromptVersion(@NonNull String name, String commit) } return promptVersion.toBuilder() - .variables(getVariables(promptVersion.template())) + .variables(getVariables(promptVersion.template(), promptVersion.type())) .build(); }); } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptVersionDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptVersionDAO.java index f4dd29068d..68acdd7e73 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptVersionDAO.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptVersionDAO.java @@ -20,9 +20,9 @@ @RegisterConstructorMapper(PromptVersionId.class) interface PromptVersionDAO { - @SqlUpdate("INSERT INTO prompt_versions (id, prompt_id, commit, template, metadata, change_description, created_by, workspace_id) " + @SqlUpdate("INSERT INTO prompt_versions (id, prompt_id, commit, template, metadata, change_description, type, created_by, workspace_id) " + - "VALUES (:bean.id, :bean.promptId, :bean.commit, :bean.template, :bean.metadata, :bean.changeDescription, :bean.createdBy, :workspace_id)") + "VALUES (:bean.id, :bean.promptId, :bean.commit, :bean.template, :bean.metadata, :bean.changeDescription, :bean.type, :bean.createdBy, :workspace_id)") void save(@Bind("workspace_id") String workspaceId, @BindMethods("bean") PromptVersion prompt); @SqlQuery("SELECT * FROM prompt_versions WHERE id IN () AND workspace_id = :workspace_id") diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/template/Jinja2Parser.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/template/Jinja2Parser.java new file mode 100644 index 0000000000..9c1af68072 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/template/Jinja2Parser.java @@ -0,0 +1,16 @@ +package com.comet.opik.domain.template; + +import java.util.Map; +import java.util.Set; + +public class Jinja2Parser implements TemplateParser { + @Override + public Set extractVariables(String template) { + return Set.of(); + } + + @Override + public String render(String template, Map context) { + throw new UnsupportedOperationException("Jinja2 template rendering is not supported"); + } +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/utils/MustacheUtils.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/template/MustacheParser.java similarity index 87% rename from apps/opik-backend/src/main/java/com/comet/opik/utils/MustacheUtils.java rename to apps/opik-backend/src/main/java/com/comet/opik/domain/template/MustacheParser.java index 7544c9bfbb..d559d0146f 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/utils/MustacheUtils.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/template/MustacheParser.java @@ -1,11 +1,10 @@ -package com.comet.opik.utils; +package com.comet.opik.domain.template; import com.github.mustachejava.Code; import com.github.mustachejava.DefaultMustacheFactory; import com.github.mustachejava.Mustache; import com.github.mustachejava.MustacheFactory; import com.github.mustachejava.codes.ValueCode; -import lombok.experimental.UtilityClass; import java.io.IOException; import java.io.StringReader; @@ -18,12 +17,12 @@ import java.util.Optional; import java.util.Set; -@UtilityClass -public class MustacheUtils { +public class MustacheParser implements TemplateParser { public static final MustacheFactory MF = new DefaultMustacheFactory(); - public static Set extractVariables(String template) { + @Override + public Set extractVariables(String template) { Set variables = new HashSet<>(); // Initialize Mustache Factory @@ -36,7 +35,8 @@ public static Set extractVariables(String template) { return variables; } - public static String render(String template, Map context) { + @Override + public String render(String template, Map context) { Mustache mustache = MF.compile(new StringReader(template), "template"); diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/template/TemplateParser.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/template/TemplateParser.java new file mode 100644 index 0000000000..66d9c2fa29 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/template/TemplateParser.java @@ -0,0 +1,11 @@ +package com.comet.opik.domain.template; + +import java.util.Map; +import java.util.Set; + +public interface TemplateParser { + + Set extractVariables(String template); + + String render(String template, Map context); +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/db/PromptVersionColumnMapper.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/db/PromptVersionColumnMapper.java index e27462769f..dc1a9e1b24 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/db/PromptVersionColumnMapper.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/db/PromptVersionColumnMapper.java @@ -1,8 +1,9 @@ package com.comet.opik.infrastructure.db; +import com.comet.opik.api.PromptType; import com.comet.opik.api.PromptVersion; import com.comet.opik.utils.JsonUtils; -import com.comet.opik.utils.MustacheUtils; +import com.comet.opik.utils.TemplateParseUtils; import com.fasterxml.jackson.databind.JsonNode; import org.jdbi.v3.core.mapper.ColumnMapper; import org.jdbi.v3.core.statement.StatementContext; @@ -35,7 +36,9 @@ private PromptVersion mapObject(JsonNode jsonNode) { .template(jsonNode.get("template").asText()) .metadata(jsonNode.get("metadata")) .changeDescription(jsonNode.get("change_description").asText()) - .variables(MustacheUtils.extractVariables(jsonNode.get("template").asText())) + .type(PromptType.fromString(jsonNode.get("type").asText())) + .variables(TemplateParseUtils.extractVariables(jsonNode.get("template").asText(), + PromptType.fromString(jsonNode.get("type").asText()))) .createdAt(Instant.from(FORMATTER.parse(jsonNode.get("created_at").asText()))) .createdBy(jsonNode.get("created_by").asText()) .build(); diff --git a/apps/opik-backend/src/main/java/com/comet/opik/utils/TemplateParseUtils.java b/apps/opik-backend/src/main/java/com/comet/opik/utils/TemplateParseUtils.java new file mode 100644 index 0000000000..9ef4147ae0 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/utils/TemplateParseUtils.java @@ -0,0 +1,30 @@ +package com.comet.opik.utils; + +import com.comet.opik.api.PromptType; +import com.comet.opik.domain.template.Jinja2Parser; +import com.comet.opik.domain.template.MustacheParser; +import com.comet.opik.domain.template.TemplateParser; +import lombok.NonNull; +import lombok.experimental.UtilityClass; + +import java.util.EnumMap; +import java.util.Map; +import java.util.Set; + +@UtilityClass +public class TemplateParseUtils { + private static final Map parsers = new EnumMap<>(PromptType.class); + + static { + parsers.put(PromptType.MUSTACHE, new MustacheParser()); + parsers.put(PromptType.JINJA2, new Jinja2Parser()); + } + + public static Set extractVariables(@NonNull String template, @NonNull PromptType type) { + return parsers.get(type).extractVariables(template); + } + + public static String render(@NonNull String template, @NonNull Map context, @NonNull PromptType type) { + return parsers.get(type).render(template, context); + } +} diff --git a/apps/opik-backend/src/main/resources/liquibase/db-app-state/migrations/000010_add_type_to_prompt_version.sql b/apps/opik-backend/src/main/resources/liquibase/db-app-state/migrations/000010_add_type_to_prompt_version.sql new file mode 100644 index 0000000000..a4c14ee31e --- /dev/null +++ b/apps/opik-backend/src/main/resources/liquibase/db-app-state/migrations/000010_add_type_to_prompt_version.sql @@ -0,0 +1,6 @@ +--liquibase formatted sql +--changeset BorisTkachenko:add_type_to_prompt_version + +ALTER TABLE prompt_versions ADD COLUMN type ENUM('mustache', 'jinja2') NOT NULL DEFAULT 'mustache'; + +--rollback ALTER TABLE prompt_versions DROP COLUMN type; diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceTest.java index e31eb0fac3..3efc16b47a 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceTest.java @@ -466,6 +466,8 @@ void getDatasets__whenApiKeyIsPresent__thenReturnProperResponse(String apiKey, b @DisplayName("Update dataset: when api key is present, then return proper response") void updateDataset__whenApiKeyIsPresent__thenReturnProperResponse(String apiKey, boolean shouldSucceed) { + mockTargetWorkspace(okApikey, TEST_WORKSPACE, WORKSPACE_ID); + var dataset = factory.manufacturePojo(Dataset.class).toBuilder() .id(null) .build(); diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java index ed72672cdb..7c998c1085 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java @@ -3,6 +3,7 @@ import com.comet.opik.api.BatchDelete; import com.comet.opik.api.CreatePromptVersion; import com.comet.opik.api.Prompt; +import com.comet.opik.api.PromptType; import com.comet.opik.api.PromptVersion; import com.comet.opik.api.PromptVersionRetrieve; import com.comet.opik.api.error.ErrorMessage; @@ -18,6 +19,7 @@ import com.comet.opik.infrastructure.DatabaseAnalyticsFactory; import com.comet.opik.infrastructure.auth.RequestContext; import com.comet.opik.podam.PodamFactoryUtils; +import com.comet.opik.utils.TemplateParseUtils; import com.github.tomakehurst.wiremock.client.WireMock; import com.redis.testcontainers.RedisContainer; import jakarta.ws.rs.client.Entity; @@ -41,7 +43,9 @@ import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.EnumSource; import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.NullSource; import org.testcontainers.clickhouse.ClickHouseContainer; import org.testcontainers.containers.MySQLContainer; import org.testcontainers.lifecycle.Startables; @@ -93,7 +97,8 @@ class PromptResourceTest { private static final TestDropwizardAppExtension app; private static final WireMockUtils.WireMockRuntime wireMock; - private static final String[] IGNORED_FIELDS = {"latestVersion", "template", "metadata", "changeDescription"}; + private static final String[] IGNORED_FIELDS = {"latestVersion", "template", "metadata", "changeDescription", + "type"}; static { Startables.deepStart(REDIS, CLICKHOUSE_CONTAINER, MYSQL).join(); @@ -833,14 +838,16 @@ private UUID createPrompt(Prompt prompt, String apiKey, String workspaceName) { @TestInstance(TestInstance.Lifecycle.PER_CLASS) class CreatePrompt { - @Test + @ParameterizedTest + @NullSource + @EnumSource(PromptType.class) @DisplayName("Success: should create prompt") - void shouldCreatePrompt() { + void shouldCreatePrompt(PromptType type) { var prompt = factory.manufacturePojo(Prompt.class).toBuilder() .lastUpdatedBy(USER) .createdBy(USER) - .template(null) + .type(type) .build(); var promptId = createPrompt(prompt, API_KEY, TEST_WORKSPACE); @@ -1354,14 +1361,17 @@ void when__fetchPromptsUsingPagination__thenReturnPromptsPaginated() { @TestInstance(TestInstance.Lifecycle.PER_CLASS) class GetPromptById { - @Test + @ParameterizedTest + @NullSource + @EnumSource(PromptType.class) @DisplayName("Success: should get prompt by id") - void shouldGetPromptById() { + void shouldGetPromptById(PromptType type) { var prompt = factory.manufacturePojo(Prompt.class).toBuilder() .lastUpdatedBy(USER) .createdBy(USER) .versionCount(1L) + .type(type) .build(); UUID promptId = createPrompt(prompt, API_KEY, TEST_WORKSPACE); @@ -1393,6 +1403,7 @@ void when__promptHasMultipleVersions__thenReturnPromptWithLatestVersion() { .template(promptVersion.template()) .metadata(promptVersion.metadata()) .changeDescription(promptVersion.changeDescription()) + .type(promptVersion.type()) .versionCount(2L) .build(); @@ -1456,6 +1467,7 @@ private void assertLatestVersion(Prompt actualPrompt, Prompt expectedPrompt, Set assertThat(promptVersion.template()).isEqualTo(expectedPrompt.template()); assertThat(promptVersion.metadata()).isEqualTo(expectedPrompt.metadata()); assertThat(promptVersion.changeDescription()).isEqualTo(expectedPrompt.changeDescription()); + assertThat(promptVersion.type()).isEqualTo(expectedPrompt.type()); assertThat(promptVersion.variables()).isEqualTo(expectedVariables); assertThat(promptVersion.createdBy()).isEqualTo(USER); assertThat(promptVersion.createdAt()).isBetween(expectedPrompt.createdAt(), Instant.now()); @@ -1824,9 +1836,11 @@ void when__promptHasNotVersions__thenReturnEmptyPage() { @TestInstance(TestInstance.Lifecycle.PER_CLASS) class GetPromptVersionById { - @Test + @ParameterizedTest + @NullSource + @EnumSource(PromptType.class) @DisplayName("Success: should get prompt version by id") - void shouldGetPromptVersionById() { + void shouldGetPromptVersionById(PromptType type) { var prompt = factory.manufacturePojo(Prompt.class).toBuilder() .lastUpdatedBy(USER) @@ -1839,6 +1853,7 @@ void shouldGetPromptVersionById() { var promptVersion = factory.manufacturePojo(PromptVersion.class).toBuilder() .createdBy(USER) .promptId(promptId) + .type(type) .build(); var request = new CreatePromptVersion(prompt.name(), promptVersion); @@ -2172,7 +2187,8 @@ private void assertPromptVersion(PromptVersion createdPromptVersion, PromptVersi assertThat(createdPromptVersion.promptId()).isEqualTo(promptId); assertThat(createdPromptVersion.template()).isEqualTo(promptVersion.template()); - assertThat(createdPromptVersion.variables()).isEqualTo(promptVersion.variables()); + assertThat(createdPromptVersion.variables()) + .isEqualTo(TemplateParseUtils.extractVariables(promptVersion.template(), promptVersion.type())); assertThat(createdPromptVersion.createdAt()).isBetween(promptVersion.createdAt(), Instant.now()); assertThat(createdPromptVersion.createdBy()).isEqualTo(USER); } diff --git a/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/PromptVersionManufacturer.java b/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/PromptVersionManufacturer.java index ec60a4f1a9..8164748eb8 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/PromptVersionManufacturer.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/PromptVersionManufacturer.java @@ -1,15 +1,16 @@ package com.comet.opik.podam.manufacturer; +import com.comet.opik.api.PromptType; import com.comet.opik.api.PromptVersion; import com.fasterxml.jackson.databind.JsonNode; import org.apache.commons.lang3.RandomStringUtils; import uk.co.jemos.podam.api.AttributeMetadata; import uk.co.jemos.podam.api.DataProviderStrategy; +import uk.co.jemos.podam.api.PodamUtils; import uk.co.jemos.podam.common.ManufacturingContext; import uk.co.jemos.podam.typeManufacturers.AbstractTypeManufacturer; import java.time.Instant; -import java.util.Set; import java.util.UUID; public class PromptVersionManufacturer extends AbstractTypeManufacturer { @@ -43,10 +44,14 @@ public PromptVersion getType(DataProviderStrategy strategy, AttributeMetadata me .template(template) .metadata(strategy.getTypeValue(metadata, context, JsonNode.class)) .changeDescription(strategy.getTypeValue(metadata, context, String.class)) - .variables(Set.of(variable1, variable2, variable3)) + .type(randomPromptType()) .promptId(strategy.getTypeValue(metadata, context, UUID.class)) .createdBy(strategy.getTypeValue(metadata, context, String.class)) .createdAt(Instant.now()) .build(); } + + public PromptType randomPromptType() { + return PromptType.values()[PodamUtils.getIntegerInRange(0, PromptType.values().length - 1)]; + } }