Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OPIK-855] Add type to prompt version #1179

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ public record Prompt(
Prompt.View.Write.class}) @Pattern(regexp = NULL_OR_NOT_BLANK, message = "must not be blank") @Nullable String template,
@JsonView({Prompt.View.Write.class}) @Nullable JsonNode metadata,
@JsonView({Prompt.View.Write.class}) @Nullable String changeDescription,
@JsonView({Prompt.View.Write.class}) @Nullable PromptType type,
thiagohora marked this conversation as resolved.
Show resolved Hide resolved
@JsonView({Prompt.View.Public.class,
Prompt.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant createdAt,
@JsonView({Prompt.View.Public.class,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package com.comet.opik.api;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonValue;
import lombok.Getter;
import lombok.RequiredArgsConstructor;

import java.util.Arrays;

@Getter
@RequiredArgsConstructor
public enum PromptType {

MUSTACHE("mustache"),
JINJA2("jinja2");

@JsonValue
private final String value;

@JsonCreator
public static PromptType fromString(String value) {
if (value == null) return MUSTACHE;
return Arrays.stream(values())
.filter(promptType -> promptType.value.equals(value))
.findFirst()
.orElseThrow(() -> new IllegalArgumentException("Unknown prompt type '%s'".formatted(value)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ public record PromptVersion(
PromptVersion.View.Detail.class}) @NotBlank String template,
@Json @JsonView({PromptVersion.View.Public.class, Prompt.View.Detail.class,
PromptVersion.View.Detail.class}) JsonNode metadata,
@JsonView({PromptVersion.View.Public.class, Prompt.View.Detail.class,
PromptVersion.View.Detail.class}) PromptType type,
@JsonView({PromptVersion.View.Public.class, Prompt.View.Detail.class,
PromptVersion.View.Detail.class}) String changeDescription,
@JsonView({Prompt.View.Detail.class,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ SELECT JSON_OBJECT(
'template', pv.template,
'metadata', pv.metadata,
'change_description', pv.change_description,
'type', pv.type,
'created_at', pv.created_at,
'created_by', pv.created_by,
'last_updated_at', pv.last_updated_at,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ private PromptVersion createPromptVersionFromPromptRequest(Prompt createdPrompt,
.template(promptRequest.template())
.metadata(promptRequest.metadata())
.changeDescription(promptRequest.changeDescription())
.type(promptRequest.type())
.createdBy(createdPrompt.createdBy())
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
@RegisterConstructorMapper(PromptVersionId.class)
interface PromptVersionDAO {

@SqlUpdate("INSERT INTO prompt_versions (id, prompt_id, commit, template, metadata, change_description, created_by, workspace_id) "
@SqlUpdate("INSERT INTO prompt_versions (id, prompt_id, commit, template, metadata, change_description, type, created_by, workspace_id) "
+
"VALUES (:bean.id, :bean.promptId, :bean.commit, :bean.template, :bean.metadata, :bean.changeDescription, :bean.createdBy, :workspace_id)")
"VALUES (:bean.id, :bean.promptId, :bean.commit, :bean.template, :bean.metadata, :bean.changeDescription, :bean.type, :bean.createdBy, :workspace_id)")
void save(@Bind("workspace_id") String workspaceId, @BindMethods("bean") PromptVersion prompt);

@SqlQuery("SELECT * FROM prompt_versions WHERE id IN (<ids>) AND workspace_id = :workspace_id")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.comet.opik.infrastructure.db;

import com.comet.opik.api.PromptType;
import com.comet.opik.api.PromptVersion;
import com.comet.opik.utils.JsonUtils;
import com.comet.opik.utils.MustacheUtils;
Expand Down Expand Up @@ -35,6 +36,7 @@ private PromptVersion mapObject(JsonNode jsonNode) {
.template(jsonNode.get("template").asText())
.metadata(jsonNode.get("metadata"))
.changeDescription(jsonNode.get("change_description").asText())
.type(PromptType.fromString(jsonNode.get("type").asText()))
.variables(MustacheUtils.extractVariables(jsonNode.get("template").asText()))
BorisTkachenko marked this conversation as resolved.
Show resolved Hide resolved
.createdAt(Instant.from(FORMATTER.parse(jsonNode.get("created_at").asText())))
.createdBy(jsonNode.get("created_by").asText())
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
--liquibase formatted sql
--changeset BorisTkachenko:add_type_to_prompt_version

ALTER TABLE prompt_versions ADD COLUMN type ENUM('mustache', 'jinja2') NOT NULL DEFAULT 'mustache';

--rollback ALTER TABLE prompt_versions DROP COLUMN type;
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ class PromptResourceTest {
private static final TestDropwizardAppExtension app;

private static final WireMockUtils.WireMockRuntime wireMock;
private static final String[] IGNORED_FIELDS = {"latestVersion", "template", "metadata", "changeDescription"};
private static final String[] IGNORED_FIELDS = {"latestVersion", "template", "metadata", "changeDescription",
"type"};

static {
Startables.deepStart(REDIS, CLICKHOUSE_CONTAINER, MYSQL).join();
Expand Down Expand Up @@ -1393,6 +1394,7 @@ void when__promptHasMultipleVersions__thenReturnPromptWithLatestVersion() {
.template(promptVersion.template())
.metadata(promptVersion.metadata())
.changeDescription(promptVersion.changeDescription())
.type(promptVersion.type())
.versionCount(2L)
.build();

Expand Down Expand Up @@ -1456,6 +1458,7 @@ private void assertLatestVersion(Prompt actualPrompt, Prompt expectedPrompt, Set
assertThat(promptVersion.template()).isEqualTo(expectedPrompt.template());
assertThat(promptVersion.metadata()).isEqualTo(expectedPrompt.metadata());
assertThat(promptVersion.changeDescription()).isEqualTo(expectedPrompt.changeDescription());
assertThat(promptVersion.type()).isEqualTo(expectedPrompt.type());
assertThat(promptVersion.variables()).isEqualTo(expectedVariables);
assertThat(promptVersion.createdBy()).isEqualTo(USER);
assertThat(promptVersion.createdAt()).isBetween(expectedPrompt.createdAt(), Instant.now());
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package com.comet.opik.podam.manufacturer;

import com.comet.opik.api.PromptType;
import com.comet.opik.api.PromptVersion;
import com.fasterxml.jackson.databind.JsonNode;
import org.apache.commons.lang3.RandomStringUtils;
import uk.co.jemos.podam.api.AttributeMetadata;
import uk.co.jemos.podam.api.DataProviderStrategy;
import uk.co.jemos.podam.api.PodamUtils;
import uk.co.jemos.podam.common.ManufacturingContext;
import uk.co.jemos.podam.typeManufacturers.AbstractTypeManufacturer;

Expand Down Expand Up @@ -43,10 +45,15 @@ public PromptVersion getType(DataProviderStrategy strategy, AttributeMetadata me
.template(template)
.metadata(strategy.getTypeValue(metadata, context, JsonNode.class))
.changeDescription(strategy.getTypeValue(metadata, context, String.class))
.type(randomPromptType())
.variables(Set.of(variable1, variable2, variable3))
.promptId(strategy.getTypeValue(metadata, context, UUID.class))
.createdBy(strategy.getTypeValue(metadata, context, String.class))
.createdAt(Instant.now())
.build();
}

public PromptType randomPromptType() {
return PromptType.values()[PodamUtils.getIntegerInRange(0, PromptType.values().length - 1)];
}
}
Loading