Skip to content

Commit

Permalink
[OPIK-855] Add type to prompt version (#1179)
Browse files Browse the repository at this point in the history
  • Loading branch information
BorisTkachenko authored Jan 31, 2025
1 parent 6f65206 commit 4f17d04
Show file tree
Hide file tree
Showing 16 changed files with 168 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import java.util.List;
import java.util.UUID;

import static com.comet.opik.api.PromptType.MUSTACHE;
import static com.comet.opik.utils.ValidationUtils.NULL_OR_NOT_BLANK;

@Builder(toBuilder = true)
Expand All @@ -33,6 +34,7 @@ public record Prompt(
Prompt.View.Write.class}) @Pattern(regexp = NULL_OR_NOT_BLANK, message = "must not be blank") @Nullable String template,
@JsonView({Prompt.View.Write.class}) @Nullable JsonNode metadata,
@JsonView({Prompt.View.Write.class}) @Nullable String changeDescription,
@JsonView({Prompt.View.Write.class}) @Nullable PromptType type,
@JsonView({Prompt.View.Public.class,
Prompt.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant createdAt,
@JsonView({Prompt.View.Public.class,
Expand Down Expand Up @@ -74,4 +76,9 @@ public static Prompt.PromptPage empty(int page) {
return new Prompt.PromptPage(page, 0, 0, List.of());
}
}

@Override
public PromptType type() {
return type == null ? MUSTACHE : type;
}
}
27 changes: 27 additions & 0 deletions apps/opik-backend/src/main/java/com/comet/opik/api/PromptType.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package com.comet.opik.api;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonValue;
import lombok.Getter;
import lombok.RequiredArgsConstructor;

import java.util.Arrays;

@Getter
@RequiredArgsConstructor
public enum PromptType {

MUSTACHE("mustache"),
JINJA2("jinja2");

@JsonValue
private final String value;

@JsonCreator
public static PromptType fromString(String value) {
return Arrays.stream(values())
.filter(promptType -> promptType.value.equals(value))
.findFirst()
.orElseThrow(() -> new IllegalArgumentException("Unknown prompt type '%s'".formatted(value)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import java.util.Set;
import java.util.UUID;

import static com.comet.opik.api.PromptType.MUSTACHE;

@Builder(toBuilder = true)
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
Expand All @@ -35,6 +37,8 @@ public record PromptVersion(
PromptVersion.View.Detail.class}) @NotBlank String template,
@Json @JsonView({PromptVersion.View.Public.class, Prompt.View.Detail.class,
PromptVersion.View.Detail.class}) JsonNode metadata,
@JsonView({PromptVersion.View.Public.class, Prompt.View.Detail.class,
PromptVersion.View.Detail.class}) PromptType type,
@JsonView({PromptVersion.View.Public.class, Prompt.View.Detail.class,
PromptVersion.View.Detail.class}) String changeDescription,
@JsonView({Prompt.View.Detail.class,
Expand Down Expand Up @@ -68,4 +72,9 @@ public static PromptVersion.PromptVersionPage empty(int page) {
return new PromptVersion.PromptVersionPage(page, 0, 0, List.of());
}
}

@Override
public PromptType type() {
return type == null ? MUSTACHE : type;
}
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package com.comet.opik.api.resources.v1.events;

import com.comet.opik.api.FeedbackScoreBatchItem;
import com.comet.opik.api.PromptType;
import com.comet.opik.api.ScoreSource;
import com.comet.opik.api.Trace;
import com.comet.opik.utils.MustacheUtils;
import com.comet.opik.utils.TemplateParseUtils;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
Expand Down Expand Up @@ -108,7 +109,8 @@ static List<ChatMessage> renderMessages(List<LlmAsJudgeMessage> templateMessages
return templateMessages.stream()
.map(templateMessage -> {
// will convert all '{{key}}' into 'value'
var renderedMessage = MustacheUtils.render(templateMessage.content(), replacements);
var renderedMessage = TemplateParseUtils.render(templateMessage.content(), replacements,
PromptType.MUSTACHE);

return switch (templateMessage.role()) {
case USER -> UserMessage.from(renderedMessage);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ SELECT JSON_OBJECT(
'template', pv.template,
'metadata', pv.metadata,
'change_description', pv.change_description,
'type', pv.type,
'created_at', pv.created_at,
'created_by', pv.created_by,
'last_updated_at', pv.last_updated_at,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

import com.comet.opik.api.CreatePromptVersion;
import com.comet.opik.api.Prompt;
import com.comet.opik.api.PromptType;
import com.comet.opik.api.PromptVersion;
import com.comet.opik.api.PromptVersion.PromptVersionPage;
import com.comet.opik.api.error.EntityAlreadyExistsException;
import com.comet.opik.infrastructure.auth.RequestContext;
import com.comet.opik.utils.MustacheUtils;
import com.comet.opik.utils.TemplateParseUtils;
import com.google.inject.ImplementedBy;
import io.dropwizard.jersey.errors.ErrorMessage;
import jakarta.inject.Inject;
Expand Down Expand Up @@ -119,6 +120,7 @@ private PromptVersion createPromptVersionFromPromptRequest(Prompt createdPrompt,
.template(promptRequest.template())
.metadata(promptRequest.metadata())
.changeDescription(promptRequest.changeDescription())
.type(promptRequest.type())
.createdBy(createdPrompt.createdBy())
.build();

Expand Down Expand Up @@ -347,7 +349,7 @@ private PromptVersion getById(String workspaceId, UUID id) {
});

return promptVersion.toBuilder()
.variables(getVariables(promptVersion.template()))
.variables(getVariables(promptVersion.template(), promptVersion.type()))
.build();
}

Expand All @@ -368,7 +370,7 @@ public Prompt getById(@NonNull UUID id) {
.latestVersion(
Optional.ofNullable(prompt.latestVersion())
.map(promptVersion -> promptVersion.toBuilder()
.variables(getVariables(promptVersion.template()))
.variables(getVariables(promptVersion.template(), promptVersion.type()))
.build())
.orElse(null))
.build();
Expand All @@ -388,7 +390,7 @@ public Mono<Map<UUID, PromptVersion>> findVersionByIds(@NonNull Set<UUID> ids) {

return promptVersionDAO.findByIds(ids, workspaceId).stream()
.collect(toMap(PromptVersion::id, promptVersion -> promptVersion.toBuilder()
.variables(getVariables(promptVersion.template()))
.variables(getVariables(promptVersion.template(), promptVersion.type()))
.build()));
});
})
Expand All @@ -413,16 +415,16 @@ public PromptVersion getVersionById(@NonNull String workspaceId, @NonNull UUID i
}

return promptVersion.toBuilder()
.variables(getVariables(promptVersion.template()))
.variables(getVariables(promptVersion.template(), promptVersion.type()))
.build();
}

private Set<String> getVariables(String template) {
private Set<String> getVariables(String template, PromptType type) {
if (template == null) {
return null;
}

return MustacheUtils.extractVariables(template);
return TemplateParseUtils.extractVariables(template, type);
}

private EntityAlreadyExistsException newConflict(String alreadyExists) {
Expand Down Expand Up @@ -491,7 +493,7 @@ public PromptVersion retrievePromptVersion(@NonNull String name, String commit)
}

return promptVersion.toBuilder()
.variables(getVariables(promptVersion.template()))
.variables(getVariables(promptVersion.template(), promptVersion.type()))
.build();
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
@RegisterConstructorMapper(PromptVersionId.class)
interface PromptVersionDAO {

@SqlUpdate("INSERT INTO prompt_versions (id, prompt_id, commit, template, metadata, change_description, created_by, workspace_id) "
@SqlUpdate("INSERT INTO prompt_versions (id, prompt_id, commit, template, metadata, change_description, type, created_by, workspace_id) "
+
"VALUES (:bean.id, :bean.promptId, :bean.commit, :bean.template, :bean.metadata, :bean.changeDescription, :bean.createdBy, :workspace_id)")
"VALUES (:bean.id, :bean.promptId, :bean.commit, :bean.template, :bean.metadata, :bean.changeDescription, :bean.type, :bean.createdBy, :workspace_id)")
void save(@Bind("workspace_id") String workspaceId, @BindMethods("bean") PromptVersion prompt);

@SqlQuery("SELECT * FROM prompt_versions WHERE id IN (<ids>) AND workspace_id = :workspace_id")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package com.comet.opik.domain.template;

import java.util.Map;
import java.util.Set;

public class Jinja2Parser implements TemplateParser {
@Override
public Set<String> extractVariables(String template) {
return Set.of();
}

@Override
public String render(String template, Map<String, ?> context) {
throw new UnsupportedOperationException("Jinja2 template rendering is not supported");
}
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
package com.comet.opik.utils;
package com.comet.opik.domain.template;

import com.github.mustachejava.Code;
import com.github.mustachejava.DefaultMustacheFactory;
import com.github.mustachejava.Mustache;
import com.github.mustachejava.MustacheFactory;
import com.github.mustachejava.codes.ValueCode;
import lombok.experimental.UtilityClass;

import java.io.IOException;
import java.io.StringReader;
Expand All @@ -18,12 +17,12 @@
import java.util.Optional;
import java.util.Set;

@UtilityClass
public class MustacheUtils {
public class MustacheParser implements TemplateParser {

public static final MustacheFactory MF = new DefaultMustacheFactory();

public static Set<String> extractVariables(String template) {
@Override
public Set<String> extractVariables(String template) {
Set<String> variables = new HashSet<>();

// Initialize Mustache Factory
Expand All @@ -36,7 +35,8 @@ public static Set<String> extractVariables(String template) {
return variables;
}

public static String render(String template, Map<String, ?> context) {
@Override
public String render(String template, Map<String, ?> context) {

Mustache mustache = MF.compile(new StringReader(template), "template");

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package com.comet.opik.domain.template;

import java.util.Map;
import java.util.Set;

public interface TemplateParser {

Set<String> extractVariables(String template);

String render(String template, Map<String, ?> context);
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package com.comet.opik.infrastructure.db;

import com.comet.opik.api.PromptType;
import com.comet.opik.api.PromptVersion;
import com.comet.opik.utils.JsonUtils;
import com.comet.opik.utils.MustacheUtils;
import com.comet.opik.utils.TemplateParseUtils;
import com.fasterxml.jackson.databind.JsonNode;
import org.jdbi.v3.core.mapper.ColumnMapper;
import org.jdbi.v3.core.statement.StatementContext;
Expand Down Expand Up @@ -35,7 +36,9 @@ private PromptVersion mapObject(JsonNode jsonNode) {
.template(jsonNode.get("template").asText())
.metadata(jsonNode.get("metadata"))
.changeDescription(jsonNode.get("change_description").asText())
.variables(MustacheUtils.extractVariables(jsonNode.get("template").asText()))
.type(PromptType.fromString(jsonNode.get("type").asText()))
.variables(TemplateParseUtils.extractVariables(jsonNode.get("template").asText(),
PromptType.fromString(jsonNode.get("type").asText())))
.createdAt(Instant.from(FORMATTER.parse(jsonNode.get("created_at").asText())))
.createdBy(jsonNode.get("created_by").asText())
.build();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package com.comet.opik.utils;

import com.comet.opik.api.PromptType;
import com.comet.opik.domain.template.Jinja2Parser;
import com.comet.opik.domain.template.MustacheParser;
import com.comet.opik.domain.template.TemplateParser;
import lombok.NonNull;
import lombok.experimental.UtilityClass;

import java.util.EnumMap;
import java.util.Map;
import java.util.Set;

@UtilityClass
public class TemplateParseUtils {
private static final Map<PromptType, TemplateParser> parsers = new EnumMap<>(PromptType.class);

static {
parsers.put(PromptType.MUSTACHE, new MustacheParser());
parsers.put(PromptType.JINJA2, new Jinja2Parser());
}

public static Set<String> extractVariables(@NonNull String template, @NonNull PromptType type) {
return parsers.get(type).extractVariables(template);
}

public static String render(@NonNull String template, @NonNull Map<String, ?> context, @NonNull PromptType type) {
return parsers.get(type).render(template, context);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
--liquibase formatted sql
--changeset BorisTkachenko:add_type_to_prompt_version

ALTER TABLE prompt_versions ADD COLUMN type ENUM('mustache', 'jinja2') NOT NULL DEFAULT 'mustache';

--rollback ALTER TABLE prompt_versions DROP COLUMN type;
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,8 @@ void getDatasets__whenApiKeyIsPresent__thenReturnProperResponse(String apiKey, b
@DisplayName("Update dataset: when api key is present, then return proper response")
void updateDataset__whenApiKeyIsPresent__thenReturnProperResponse(String apiKey, boolean shouldSucceed) {

mockTargetWorkspace(okApikey, TEST_WORKSPACE, WORKSPACE_ID);

var dataset = factory.manufacturePojo(Dataset.class).toBuilder()
.id(null)
.build();
Expand Down
Loading

0 comments on commit 4f17d04

Please sign in to comment.