Skip to content

Commit

Permalink
OPIK-855 Add type to prompt version
Browse files Browse the repository at this point in the history
  • Loading branch information
Borys Tkachenko authored and Borys Tkachenko committed Jan 30, 2025
1 parent 6fd37cb commit b4f2139
Show file tree
Hide file tree
Showing 10 changed files with 54 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ public record Prompt(
Prompt.View.Write.class}) @Pattern(regexp = NULL_OR_NOT_BLANK, message = "must not be blank") @Nullable String template,
@JsonView({Prompt.View.Write.class}) @Nullable JsonNode metadata,
@JsonView({Prompt.View.Write.class}) @Nullable String changeDescription,
@JsonView({Prompt.View.Write.class}) @Nullable PromptType type,
@JsonView({Prompt.View.Public.class,
Prompt.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant createdAt,
@JsonView({Prompt.View.Public.class,
Expand Down
28 changes: 28 additions & 0 deletions apps/opik-backend/src/main/java/com/comet/opik/api/PromptType.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package com.comet.opik.api;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonValue;
import lombok.Getter;
import lombok.RequiredArgsConstructor;

import java.util.Arrays;

@Getter
@RequiredArgsConstructor
public enum PromptType {

MUSTACHE("mustache"),
JINJA2("jinja2");

@JsonValue
private final String value;

@JsonCreator
public static PromptType fromString(String value) {
if (value == null) return MUSTACHE;
return Arrays.stream(values())
.filter(promptType -> promptType.value.equals(value))
.findFirst()
.orElseThrow(() -> new IllegalArgumentException("Unknown prompt type '%s'".formatted(value)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ public record PromptVersion(
PromptVersion.View.Detail.class}) @NotBlank String template,
@Json @JsonView({PromptVersion.View.Public.class, Prompt.View.Detail.class,
PromptVersion.View.Detail.class}) JsonNode metadata,
@JsonView({PromptVersion.View.Public.class, Prompt.View.Detail.class,
PromptVersion.View.Detail.class}) PromptType type,
@JsonView({PromptVersion.View.Public.class, Prompt.View.Detail.class,
PromptVersion.View.Detail.class}) String changeDescription,
@JsonView({Prompt.View.Detail.class,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ SELECT JSON_OBJECT(
'template', pv.template,
'metadata', pv.metadata,
'change_description', pv.change_description,
'type', pv.type,
'created_at', pv.created_at,
'created_by', pv.created_by,
'last_updated_at', pv.last_updated_at,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ private PromptVersion createPromptVersionFromPromptRequest(Prompt createdPrompt,
.template(promptRequest.template())
.metadata(promptRequest.metadata())
.changeDescription(promptRequest.changeDescription())
.type(promptRequest.type())
.createdBy(createdPrompt.createdBy())
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
@RegisterConstructorMapper(PromptVersionId.class)
interface PromptVersionDAO {

@SqlUpdate("INSERT INTO prompt_versions (id, prompt_id, commit, template, metadata, change_description, created_by, workspace_id) "
@SqlUpdate("INSERT INTO prompt_versions (id, prompt_id, commit, template, metadata, change_description, type, created_by, workspace_id) "
+
"VALUES (:bean.id, :bean.promptId, :bean.commit, :bean.template, :bean.metadata, :bean.changeDescription, :bean.createdBy, :workspace_id)")
"VALUES (:bean.id, :bean.promptId, :bean.commit, :bean.template, :bean.metadata, :bean.changeDescription, :bean.type, :bean.createdBy, :workspace_id)")
void save(@Bind("workspace_id") String workspaceId, @BindMethods("bean") PromptVersion prompt);

@SqlQuery("SELECT * FROM prompt_versions WHERE id IN (<ids>) AND workspace_id = :workspace_id")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.comet.opik.infrastructure.db;

import com.comet.opik.api.PromptType;
import com.comet.opik.api.PromptVersion;
import com.comet.opik.utils.JsonUtils;
import com.comet.opik.utils.MustacheUtils;
Expand Down Expand Up @@ -35,6 +36,7 @@ private PromptVersion mapObject(JsonNode jsonNode) {
.template(jsonNode.get("template").asText())
.metadata(jsonNode.get("metadata"))
.changeDescription(jsonNode.get("change_description").asText())
.type(PromptType.fromString(jsonNode.get("type").asText()))
.variables(MustacheUtils.extractVariables(jsonNode.get("template").asText()))
.createdAt(Instant.from(FORMATTER.parse(jsonNode.get("created_at").asText())))
.createdBy(jsonNode.get("created_by").asText())
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
--liquibase formatted sql
--changeset BorisTkachenko:add_type_to_prompt_version

ALTER TABLE prompt_versions ADD COLUMN type ENUM('mustache', 'jinja2') NOT NULL DEFAULT 'mustache';

--rollback ALTER TABLE prompt_versions DROP COLUMN type;
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ class PromptResourceTest {
private static final TestDropwizardAppExtension app;

private static final WireMockUtils.WireMockRuntime wireMock;
private static final String[] IGNORED_FIELDS = {"latestVersion", "template", "metadata", "changeDescription"};
private static final String[] IGNORED_FIELDS = {"latestVersion", "template", "metadata", "changeDescription",
"type"};

static {
Startables.deepStart(REDIS, CLICKHOUSE_CONTAINER, MYSQL).join();
Expand Down Expand Up @@ -1393,6 +1394,7 @@ void when__promptHasMultipleVersions__thenReturnPromptWithLatestVersion() {
.template(promptVersion.template())
.metadata(promptVersion.metadata())
.changeDescription(promptVersion.changeDescription())
.type(promptVersion.type())
.versionCount(2L)
.build();

Expand Down Expand Up @@ -1456,6 +1458,7 @@ private void assertLatestVersion(Prompt actualPrompt, Prompt expectedPrompt, Set
assertThat(promptVersion.template()).isEqualTo(expectedPrompt.template());
assertThat(promptVersion.metadata()).isEqualTo(expectedPrompt.metadata());
assertThat(promptVersion.changeDescription()).isEqualTo(expectedPrompt.changeDescription());
assertThat(promptVersion.type()).isEqualTo(expectedPrompt.type());
assertThat(promptVersion.variables()).isEqualTo(expectedVariables);
assertThat(promptVersion.createdBy()).isEqualTo(USER);
assertThat(promptVersion.createdAt()).isBetween(expectedPrompt.createdAt(), Instant.now());
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package com.comet.opik.podam.manufacturer;

import com.comet.opik.api.PromptType;
import com.comet.opik.api.PromptVersion;
import com.fasterxml.jackson.databind.JsonNode;
import org.apache.commons.lang3.RandomStringUtils;
import uk.co.jemos.podam.api.AttributeMetadata;
import uk.co.jemos.podam.api.DataProviderStrategy;
import uk.co.jemos.podam.api.PodamUtils;
import uk.co.jemos.podam.common.ManufacturingContext;
import uk.co.jemos.podam.typeManufacturers.AbstractTypeManufacturer;

Expand Down Expand Up @@ -43,10 +45,15 @@ public PromptVersion getType(DataProviderStrategy strategy, AttributeMetadata me
.template(template)
.metadata(strategy.getTypeValue(metadata, context, JsonNode.class))
.changeDescription(strategy.getTypeValue(metadata, context, String.class))
.type(randomPromptType())
.variables(Set.of(variable1, variable2, variable3))
.promptId(strategy.getTypeValue(metadata, context, UUID.class))
.createdBy(strategy.getTypeValue(metadata, context, String.class))
.createdAt(Instant.now())
.build();
}

public PromptType randomPromptType() {
return PromptType.values()[PodamUtils.getIntegerInRange(0, PromptType.values().length - 1)];
}
}

0 comments on commit b4f2139

Please sign in to comment.