From a5dfdbb911df0cbb35e366322ae2cfc208d2f8ac Mon Sep 17 00:00:00 2001 From: Alexander Joham Date: Sun, 29 Sep 2024 00:37:04 +0200 Subject: [PATCH 01/30] Define LLM token usage model --- .../aet/artemis/core/domain/LLMService.java | 5 ++ .../artemis/core/domain/LLMTokenUsage.java | 55 +++++++++++++++++++ 2 files changed, 60 insertions(+) create mode 100644 src/main/java/de/tum/cit/aet/artemis/core/domain/LLMService.java create mode 100644 src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java diff --git a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMService.java b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMService.java new file mode 100644 index 000000000000..a4d2bdf6e094 --- /dev/null +++ b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMService.java @@ -0,0 +1,5 @@ +package de.tum.cit.aet.artemis.core.domain; + +public enum LLMService { + Iris, Athena +} diff --git a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java new file mode 100644 index 000000000000..8ad99ee03050 --- /dev/null +++ b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java @@ -0,0 +1,55 @@ +package de.tum.cit.aet.artemis.core.domain; + +import java.time.ZonedDateTime; + +import jakarta.annotation.Nullable; +import jakarta.persistence.Column; +import jakarta.persistence.Entity; +import jakarta.persistence.Inheritance; +import jakarta.persistence.InheritanceType; +import jakarta.persistence.JoinColumn; +import jakarta.persistence.ManyToOne; +import jakarta.persistence.Table; + +import org.hibernate.annotations.Cache; +import org.hibernate.annotations.CacheConcurrencyStrategy; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonTypeInfo; + +import de.tum.cit.aet.artemis.iris.domain.message.IrisMessage; + +@Entity +@Table(name = "llm_token_usage") +@Inheritance(strategy = InheritanceType.SINGLE_TABLE) +@Cache(usage = CacheConcurrencyStrategy.NONSTRICT_READ_WRITE) +@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, property = "type") +@JsonInclude(JsonInclude.Include.NON_EMPTY) +public class LLMTokenUsage extends DomainObject { + + @Column(name = "service") + private LLMService service; + + @Column(name = "model") + private String model; + + @Column(name = "cost_per_token") + private double cost_per_token; + + @Column(name = "num_input_tokens") + private int num_input_tokens; + + @Column(name = "num_output_tokens") + private int num_output_tokens; + + @Nullable + @Column(name = "timestamp") + private ZonedDateTime timestamp = ZonedDateTime.now(); + + @Nullable + @ManyToOne + @JsonIgnore + @JoinColumn(name = "iris_message_id") + IrisMessage message; +} From d0bdae365439725b98a3c5729953124d06863829 Mon Sep 17 00:00:00 2001 From: Alexander Joham Date: Fri, 11 Oct 2024 15:18:10 +0200 Subject: [PATCH 02/30] Update table, save data recieved from Pyris Exercise chat pipeline --- .../aet/artemis/core/domain/LLMService.java | 5 - .../artemis/core/domain/LLMServiceType.java | 5 + .../artemis/core/domain/LLMTokenUsage.java | 103 +++++++++++++++++- .../repository/LLMTokenUsageRepository.java | 10 ++ .../core/service/LLMTokenUsageService.java | 51 +++++++++ .../iris/dto/IrisChatWebsocketDTO.java | 8 +- .../dto/chat/PyrisChatStatusUpdateDTO.java | 3 +- .../pyris/dto/data/PyrisLLMCostDTO.java | 6 + .../session/IrisCourseChatSessionService.java | 2 +- .../IrisExerciseChatSessionService.java | 14 ++- .../websocket/IrisChatWebsocketService.java | 9 +- .../changelog/20241011140701_changelog.xml | 35 ++++++ .../resources/config/liquibase/master.xml | 1 + .../app/entities/iris/iris-message.model.ts | 4 + .../iris/IrisChatMessageIntegrationTest.java | 17 +-- .../artemis/iris/IrisChatWebsocketTest.java | 2 +- 16 files changed, 247 insertions(+), 28 deletions(-) delete mode 100644 src/main/java/de/tum/cit/aet/artemis/core/domain/LLMService.java create mode 100644 src/main/java/de/tum/cit/aet/artemis/core/domain/LLMServiceType.java create mode 100644 src/main/java/de/tum/cit/aet/artemis/core/repository/LLMTokenUsageRepository.java create mode 100644 src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java create mode 100644 src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/data/PyrisLLMCostDTO.java create mode 100644 src/main/resources/config/liquibase/changelog/20241011140701_changelog.xml diff --git a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMService.java b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMService.java deleted file mode 100644 index a4d2bdf6e094..000000000000 --- a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMService.java +++ /dev/null @@ -1,5 +0,0 @@ -package de.tum.cit.aet.artemis.core.domain; - -public enum LLMService { - Iris, Athena -} diff --git a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMServiceType.java b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMServiceType.java new file mode 100644 index 000000000000..fa9c1d030257 --- /dev/null +++ b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMServiceType.java @@ -0,0 +1,5 @@ +package de.tum.cit.aet.artemis.core.domain; + +public enum LLMServiceType { + ATHENA, IRIS_CODE_FEEDBACK, IRIS_CHAT_COURSE_MESSAGE, IRIS_CHAT_EXERCISE_MESSAGE, IRIS_INTERACTION_SUGGESTION, IRIS_CHAT_LECTURE_MESSAGE, IRIS_COMPETENCY_GENERATION, NOT_SET +} diff --git a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java index 8ad99ee03050..6646810c2367 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java @@ -5,6 +5,8 @@ import jakarta.annotation.Nullable; import jakarta.persistence.Column; import jakarta.persistence.Entity; +import jakarta.persistence.EnumType; +import jakarta.persistence.Enumerated; import jakarta.persistence.Inheritance; import jakarta.persistence.InheritanceType; import jakarta.persistence.JoinColumn; @@ -18,6 +20,7 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonTypeInfo; +import de.tum.cit.aet.artemis.exercise.domain.Exercise; import de.tum.cit.aet.artemis.iris.domain.message.IrisMessage; @Entity @@ -29,7 +32,8 @@ public class LLMTokenUsage extends DomainObject { @Column(name = "service") - private LLMService service; + @Enumerated(EnumType.STRING) + private LLMServiceType serviceType; @Column(name = "model") private String model; @@ -43,6 +47,21 @@ public class LLMTokenUsage extends DomainObject { @Column(name = "num_output_tokens") private int num_output_tokens; + @Nullable + @ManyToOne + @JsonIgnore + @JoinColumn(name = "course_id") + private Course course; + + @Nullable + @ManyToOne + @JsonIgnore + @JoinColumn(name = "exercise_id") + private Exercise exercise; + + @Column(name = "user_id") + private long userId; + @Nullable @Column(name = "timestamp") private ZonedDateTime timestamp = ZonedDateTime.now(); @@ -51,5 +70,85 @@ public class LLMTokenUsage extends DomainObject { @ManyToOne @JsonIgnore @JoinColumn(name = "iris_message_id") - IrisMessage message; + IrisMessage irisMessage; + + public LLMServiceType getServiceType() { + return serviceType; + } + + public void setServiceType(LLMServiceType serviceType) { + this.serviceType = serviceType; + } + + public String getModel() { + return model; + } + + public void setModel(String model) { + this.model = model; + } + + public double getCost_per_token() { + return cost_per_token; + } + + public void setCost_per_token(double cost_per_token) { + this.cost_per_token = cost_per_token; + } + + public int getNum_input_tokens() { + return num_input_tokens; + } + + public void setNum_input_tokens(int num_input_tokens) { + this.num_input_tokens = num_input_tokens; + } + + public int getNum_output_tokens() { + return num_output_tokens; + } + + public void setNum_output_tokens(int num_output_tokens) { + this.num_output_tokens = num_output_tokens; + } + + public Course getCourse() { + return course; + } + + public void setCourse(Course course) { + this.course = course; + } + + public Exercise getExercise() { + return exercise; + } + + public void setExercise(Exercise exercise) { + this.exercise = exercise; + } + + public long getUserId() { + return userId; + } + + public void setUserId(long userId) { + this.userId = userId; + } + + public ZonedDateTime getTimestamp() { + return timestamp; + } + + public void setTimestamp(ZonedDateTime timestamp) { + this.timestamp = timestamp; + } + + public IrisMessage getIrisMessage() { + return irisMessage; + } + + public void setIrisMessage(IrisMessage message) { + this.irisMessage = message; + } } diff --git a/src/main/java/de/tum/cit/aet/artemis/core/repository/LLMTokenUsageRepository.java b/src/main/java/de/tum/cit/aet/artemis/core/repository/LLMTokenUsageRepository.java new file mode 100644 index 000000000000..755ed4a1b014 --- /dev/null +++ b/src/main/java/de/tum/cit/aet/artemis/core/repository/LLMTokenUsageRepository.java @@ -0,0 +1,10 @@ +package de.tum.cit.aet.artemis.core.repository; + +import org.springframework.stereotype.Repository; + +import de.tum.cit.aet.artemis.core.domain.LLMTokenUsage; +import de.tum.cit.aet.artemis.core.repository.base.ArtemisJpaRepository; + +@Repository +public interface LLMTokenUsageRepository extends ArtemisJpaRepository { +} diff --git a/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java b/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java new file mode 100644 index 000000000000..3bbaf8a9e5ab --- /dev/null +++ b/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java @@ -0,0 +1,51 @@ +package de.tum.cit.aet.artemis.core.service; + +import static de.tum.cit.aet.artemis.core.config.Constants.PROFILE_IRIS; + +import java.util.ArrayList; +import java.util.List; + +import org.springframework.context.annotation.Profile; +import org.springframework.stereotype.Service; + +import de.tum.cit.aet.artemis.core.domain.Course; +import de.tum.cit.aet.artemis.core.domain.LLMTokenUsage; +import de.tum.cit.aet.artemis.core.domain.User; +import de.tum.cit.aet.artemis.core.repository.LLMTokenUsageRepository; +import de.tum.cit.aet.artemis.exercise.domain.Exercise; +import de.tum.cit.aet.artemis.iris.domain.message.IrisMessage; +import de.tum.cit.aet.artemis.iris.service.pyris.dto.data.PyrisLLMCostDTO; + +/** + * Service for managing Iris messages. + */ +@Service +@Profile(PROFILE_IRIS) +public class LLMTokenUsageService { + + private final LLMTokenUsageRepository llmTokenUsageRepository; + + public LLMTokenUsageService(LLMTokenUsageRepository llmTokenUsageRepository) { + this.llmTokenUsageRepository = llmTokenUsageRepository; + } + + public List saveTokenUsage(IrisMessage message, Exercise exercise, User user, Course course, List tokens) { + List tokenUsages = new ArrayList<>(); + for (PyrisLLMCostDTO cost : tokens) { + LLMTokenUsage llmTokenUsage = new LLMTokenUsage(); + if (message != null) { + llmTokenUsage.setIrisMessage(message); + llmTokenUsage.setTimestamp(message.getSentAt()); + } + llmTokenUsage.setServiceType(cost.pipeline()); + llmTokenUsage.setExercise(exercise); + llmTokenUsage.setUserId(user.getId()); + llmTokenUsage.setCourse(course); + llmTokenUsage.setNum_input_tokens(cost.num_input_tokens()); + llmTokenUsage.setNum_output_tokens(cost.num_output_tokens()); + llmTokenUsage.setModel(cost.model_info()); + tokenUsages.add(llmTokenUsageRepository.save(llmTokenUsage)); + } + return tokenUsages; + } +} diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/dto/IrisChatWebsocketDTO.java b/src/main/java/de/tum/cit/aet/artemis/iris/dto/IrisChatWebsocketDTO.java index 75b56488e513..3663e372c844 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/dto/IrisChatWebsocketDTO.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/dto/IrisChatWebsocketDTO.java @@ -9,6 +9,7 @@ import de.tum.cit.aet.artemis.iris.domain.message.IrisMessage; import de.tum.cit.aet.artemis.iris.service.IrisRateLimitService; +import de.tum.cit.aet.artemis.iris.service.pyris.dto.data.PyrisLLMCostDTO; import de.tum.cit.aet.artemis.iris.service.pyris.dto.status.PyrisStageDTO; /** @@ -21,7 +22,7 @@ */ @JsonInclude(JsonInclude.Include.NON_EMPTY) public record IrisChatWebsocketDTO(IrisWebsocketMessageType type, IrisMessage message, IrisRateLimitService.IrisRateLimitInformation rateLimitInfo, List stages, - List suggestions) { + List suggestions, List tokens) { /** * Creates a new IrisWebsocketDTO instance with the given parameters @@ -31,8 +32,9 @@ public record IrisChatWebsocketDTO(IrisWebsocketMessageType type, IrisMessage me * @param rateLimitInfo the rate limit information * @param stages the stages of the Pyris pipeline */ - public IrisChatWebsocketDTO(@Nullable IrisMessage message, IrisRateLimitService.IrisRateLimitInformation rateLimitInfo, List stages, List suggestions) { - this(determineType(message), message, rateLimitInfo, stages, suggestions); + public IrisChatWebsocketDTO(@Nullable IrisMessage message, IrisRateLimitService.IrisRateLimitInformation rateLimitInfo, List stages, List suggestions, + List tokens) { + this(determineType(message), message, rateLimitInfo, stages, suggestions, tokens); } /** diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/chat/PyrisChatStatusUpdateDTO.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/chat/PyrisChatStatusUpdateDTO.java index cbfa0b2d98dd..73a9b5603477 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/chat/PyrisChatStatusUpdateDTO.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/chat/PyrisChatStatusUpdateDTO.java @@ -4,8 +4,9 @@ import com.fasterxml.jackson.annotation.JsonInclude; +import de.tum.cit.aet.artemis.iris.service.pyris.dto.data.PyrisLLMCostDTO; import de.tum.cit.aet.artemis.iris.service.pyris.dto.status.PyrisStageDTO; @JsonInclude(JsonInclude.Include.NON_EMPTY) -public record PyrisChatStatusUpdateDTO(String result, List stages, List suggestions) { +public record PyrisChatStatusUpdateDTO(String result, List stages, List suggestions, List tokens) { } diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/data/PyrisLLMCostDTO.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/data/PyrisLLMCostDTO.java new file mode 100644 index 000000000000..1fa344b3ebc1 --- /dev/null +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/data/PyrisLLMCostDTO.java @@ -0,0 +1,6 @@ +package de.tum.cit.aet.artemis.iris.service.pyris.dto.data; + +import de.tum.cit.aet.artemis.core.domain.LLMServiceType; + +public record PyrisLLMCostDTO(String model_info, int num_input_tokens, int num_output_tokens, LLMServiceType pipeline) { +} diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java index a2c404b13103..bb2e5c9f71dd 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java @@ -140,7 +140,7 @@ public void handleStatusUpdate(CourseChatJob job, PyrisChatStatusUpdateDTO statu irisChatWebsocketService.sendMessage(session, savedMessage, statusUpdate.stages()); } else { - irisChatWebsocketService.sendStatusUpdate(session, statusUpdate.stages(), statusUpdate.suggestions()); + irisChatWebsocketService.sendStatusUpdate(session, statusUpdate.stages(), statusUpdate.suggestions(), statusUpdate.tokens()); } updateLatestSuggestions(session, statusUpdate.suggestions()); } diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java index cec0a9322134..0d557e03d723 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java @@ -15,6 +15,7 @@ import de.tum.cit.aet.artemis.core.exception.ConflictException; import de.tum.cit.aet.artemis.core.security.Role; import de.tum.cit.aet.artemis.core.service.AuthorizationCheckService; +import de.tum.cit.aet.artemis.core.service.LLMTokenUsageService; import de.tum.cit.aet.artemis.exercise.domain.Submission; import de.tum.cit.aet.artemis.iris.domain.message.IrisMessage; import de.tum.cit.aet.artemis.iris.domain.message.IrisMessageSender; @@ -44,6 +45,8 @@ public class IrisExerciseChatSessionService extends AbstractIrisChatSessionServi private final IrisMessageService irisMessageService; + private final LLMTokenUsageService LLMTokenUsageService; + private final IrisSettingsService irisSettingsService; private final IrisChatWebsocketService irisChatWebsocketService; @@ -62,13 +65,14 @@ public class IrisExerciseChatSessionService extends AbstractIrisChatSessionServi private final ProgrammingExerciseRepository programmingExerciseRepository; - public IrisExerciseChatSessionService(IrisMessageService irisMessageService, IrisSettingsService irisSettingsService, IrisChatWebsocketService irisChatWebsocketService, - AuthorizationCheckService authCheckService, IrisSessionRepository irisSessionRepository, + public IrisExerciseChatSessionService(IrisMessageService irisMessageService, LLMTokenUsageService LLMTokenUsageService, IrisSettingsService irisSettingsService, + IrisChatWebsocketService irisChatWebsocketService, AuthorizationCheckService authCheckService, IrisSessionRepository irisSessionRepository, ProgrammingExerciseStudentParticipationRepository programmingExerciseStudentParticipationRepository, ProgrammingSubmissionRepository programmingSubmissionRepository, IrisRateLimitService rateLimitService, PyrisPipelineService pyrisPipelineService, ProgrammingExerciseRepository programmingExerciseRepository, ObjectMapper objectMapper) { super(irisSessionRepository, objectMapper); this.irisMessageService = irisMessageService; + this.LLMTokenUsageService = LLMTokenUsageService; this.irisSettingsService = irisSettingsService; this.irisChatWebsocketService = irisChatWebsocketService; this.authCheckService = authCheckService; @@ -172,10 +176,14 @@ public void handleStatusUpdate(ExerciseChatJob job, PyrisChatStatusUpdateDTO sta var message = new IrisMessage(); message.addContent(new IrisTextMessageContent(statusUpdate.result())); var savedMessage = irisMessageService.saveMessage(message, session, IrisMessageSender.LLM); + var tokenUsages = LLMTokenUsageService.saveTokenUsage(savedMessage, session.getExercise(), session.getUser(), + session.getExercise().getCourseViaExerciseGroupOrCourseMember(), statusUpdate.tokens()); irisChatWebsocketService.sendMessage(session, savedMessage, statusUpdate.stages()); } else { - irisChatWebsocketService.sendStatusUpdate(session, statusUpdate.stages(), statusUpdate.suggestions()); + var tokenUsages = LLMTokenUsageService.saveTokenUsage(null, session.getExercise(), session.getUser(), session.getExercise().getCourseViaExerciseGroupOrCourseMember(), + statusUpdate.tokens()); + irisChatWebsocketService.sendStatusUpdate(session, statusUpdate.stages(), statusUpdate.suggestions(), statusUpdate.tokens()); } updateLatestSuggestions(session, statusUpdate.suggestions()); diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/websocket/IrisChatWebsocketService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/websocket/IrisChatWebsocketService.java index 320a3103fe99..9f8e97a80be5 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/websocket/IrisChatWebsocketService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/websocket/IrisChatWebsocketService.java @@ -11,6 +11,7 @@ import de.tum.cit.aet.artemis.iris.domain.session.IrisChatSession; import de.tum.cit.aet.artemis.iris.dto.IrisChatWebsocketDTO; import de.tum.cit.aet.artemis.iris.service.IrisRateLimitService; +import de.tum.cit.aet.artemis.iris.service.pyris.dto.data.PyrisLLMCostDTO; import de.tum.cit.aet.artemis.iris.service.pyris.dto.status.PyrisStageDTO; @Service @@ -41,7 +42,7 @@ public void sendMessage(IrisChatSession session, IrisMessage irisMessage, List

stages) { - this.sendStatusUpdate(session, stages, null); + this.sendStatusUpdate(session, stages, null, null); } /** @@ -62,11 +63,11 @@ public void sendStatusUpdate(IrisChatSession session, List stages * @param stages the stages to send * @param suggestions the suggestions to send */ - public void sendStatusUpdate(IrisChatSession session, List stages, List suggestions) { + public void sendStatusUpdate(IrisChatSession session, List stages, List suggestions, List tokens) { var user = session.getUser(); var rateLimitInfo = rateLimitService.getRateLimitInformation(user); var topic = "" + session.getId(); // Todo: add more specific topic - var payload = new IrisChatWebsocketDTO(null, rateLimitInfo, stages, suggestions); + var payload = new IrisChatWebsocketDTO(null, rateLimitInfo, stages, suggestions, tokens); websocketService.send(user.getLogin(), topic, payload); } } diff --git a/src/main/resources/config/liquibase/changelog/20241011140701_changelog.xml b/src/main/resources/config/liquibase/changelog/20241011140701_changelog.xml new file mode 100644 index 000000000000..4088f2b05a95 --- /dev/null +++ b/src/main/resources/config/liquibase/changelog/20241011140701_changelog.xml @@ -0,0 +1,35 @@ + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/src/main/resources/config/liquibase/master.xml b/src/main/resources/config/liquibase/master.xml index f8be6b6255a0..926b74391e8b 100644 --- a/src/main/resources/config/liquibase/master.xml +++ b/src/main/resources/config/liquibase/master.xml @@ -22,6 +22,7 @@ + diff --git a/src/main/webapp/app/entities/iris/iris-message.model.ts b/src/main/webapp/app/entities/iris/iris-message.model.ts index d4b32977894f..04923788f1e6 100644 --- a/src/main/webapp/app/entities/iris/iris-message.model.ts +++ b/src/main/webapp/app/entities/iris/iris-message.model.ts @@ -13,6 +13,8 @@ export class IrisAssistantMessage implements BaseEntity { sentAt: dayjs.Dayjs; sender: IrisSender.LLM; helpful?: boolean; + num_input_tokens?: number; + num_output_tokens?: number; } export class IrisUserMessage implements BaseEntity { @@ -21,6 +23,8 @@ export class IrisUserMessage implements BaseEntity { sentAt?: dayjs.Dayjs; sender: IrisSender.USER; messageDifferentiator?: number; + num_input_tokens?: number; + num_output_tokens?: number; } export type IrisMessage = IrisAssistantMessage | IrisUserMessage; diff --git a/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatMessageIntegrationTest.java b/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatMessageIntegrationTest.java index 1a8ab50d721c..ad92db87287f 100644 --- a/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatMessageIntegrationTest.java +++ b/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatMessageIntegrationTest.java @@ -41,6 +41,7 @@ import de.tum.cit.aet.artemis.iris.repository.IrisSessionRepository; import de.tum.cit.aet.artemis.iris.service.IrisMessageService; import de.tum.cit.aet.artemis.iris.service.pyris.dto.chat.PyrisChatStatusUpdateDTO; +import de.tum.cit.aet.artemis.iris.service.pyris.dto.data.PyrisLLMCostDTO; import de.tum.cit.aet.artemis.iris.service.pyris.dto.status.PyrisStageDTO; import de.tum.cit.aet.artemis.iris.service.pyris.dto.status.PyrisStageState; import de.tum.cit.aet.artemis.iris.service.session.IrisExerciseChatSessionService; @@ -130,7 +131,7 @@ void sendOneMessage() throws Exception { irisRequestMockProvider.mockRunResponse(dto -> { assertThat(dto.settings().authenticationToken()).isNotNull(); - assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World", dto.initialStages(), null)); + assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World", dto.initialStages(), null, null)); pipelineDone.set(true); }); @@ -155,7 +156,7 @@ void sendSuggestions() throws Exception { List suggestions = List.of("suggestion1", "suggestion2", "suggestion3"); - assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), null, dto.initialStages(), suggestions)); + assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), null, dto.initialStages(), suggestions, null)); pipelineDone.set(true); }); @@ -194,7 +195,7 @@ void sendTwoMessages() throws Exception { irisRequestMockProvider.mockRunResponse(dto -> { assertThat(dto.settings().authenticationToken()).isNotNull(); - assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World 1", dto.initialStages(), null)); + assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World 1", dto.initialStages(), null, null)); pipelineDone.set(true); }); @@ -202,7 +203,7 @@ void sendTwoMessages() throws Exception { irisRequestMockProvider.mockRunResponse(dto -> { assertThat(dto.settings().authenticationToken()).isNotNull(); - assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World 2", dto.initialStages(), null)); + assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World 2", dto.initialStages(), null, null)); pipelineDone.set(true); }); @@ -298,7 +299,7 @@ void resendMessage() throws Exception { irisRequestMockProvider.mockRunResponse(dto -> { assertThat(dto.settings().authenticationToken()).isNotNull(); - assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World", dto.initialStages(), null)); + assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World", dto.initialStages(), null, null)); pipelineDone.set(true); }); @@ -321,7 +322,7 @@ void sendMessageRateLimitReached() throws Exception { irisRequestMockProvider.mockRunResponse(dto -> { assertThat(dto.settings().authenticationToken()).isNotNull(); - assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World", dto.initialStages(), null)); + assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World", dto.initialStages(), null, null)); pipelineDone.set(true); }); @@ -444,9 +445,9 @@ public String toString() { }; } - private void sendStatus(String jobId, String result, List stages, List suggestions) throws Exception { + private void sendStatus(String jobId, String result, List stages, List suggestions, List tokens) throws Exception { var headers = new HttpHeaders(new LinkedMultiValueMap<>(Map.of("Authorization", List.of("Bearer " + jobId)))); - request.postWithoutResponseBody("/api/public/pyris/pipelines/tutor-chat/runs/" + jobId + "/status", new PyrisChatStatusUpdateDTO(result, stages, suggestions), + request.postWithoutResponseBody("/api/public/pyris/pipelines/tutor-chat/runs/" + jobId + "/status", new PyrisChatStatusUpdateDTO(result, stages, suggestions, tokens), HttpStatus.OK, headers); } } diff --git a/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatWebsocketTest.java b/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatWebsocketTest.java index 03845b59efb7..03afd1453235 100644 --- a/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatWebsocketTest.java +++ b/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatWebsocketTest.java @@ -53,7 +53,7 @@ void sendMessage() { message.setMessageDifferentiator(101010); irisChatWebsocketService.sendMessage(irisSession, message, List.of()); verify(websocketMessagingService, times(1)).sendMessageToUser(eq(TEST_PREFIX + "student1"), eq("/topic/iris/" + irisSession.getId()), - eq(new IrisChatWebsocketDTO(message, new IrisRateLimitService.IrisRateLimitInformation(0, -1, 0), List.of(), List.of()))); + eq(new IrisChatWebsocketDTO(message, new IrisRateLimitService.IrisRateLimitInformation(0, -1, 0), List.of(), List.of(), List.of()))); } private IrisTextMessageContent createMockContent() { From 2a08cb2bcc32d9f811fdc00add177f863a5ccaca Mon Sep 17 00:00:00 2001 From: Alexander Joham Date: Fri, 11 Oct 2024 17:49:52 +0200 Subject: [PATCH 03/30] Implement competency generation tracking, update enum --- .../cit/aet/artemis/core/domain/LLMServiceType.java | 3 ++- .../artemis/core/service/LLMTokenUsageService.java | 4 +++- .../service/IrisCompetencyGenerationService.java | 10 ++++++++-- .../competency/PyrisCompetencyStatusUpdateDTO.java | 3 ++- .../session/IrisCourseChatSessionService.java | 13 ++++++++++--- .../session/IrisExerciseChatSessionService.java | 8 ++++---- .../IrisCompetencyGenerationIntegrationTest.java | 2 +- 7 files changed, 30 insertions(+), 13 deletions(-) diff --git a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMServiceType.java b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMServiceType.java index fa9c1d030257..8b41cdf30fa8 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMServiceType.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMServiceType.java @@ -1,5 +1,6 @@ package de.tum.cit.aet.artemis.core.domain; public enum LLMServiceType { - ATHENA, IRIS_CODE_FEEDBACK, IRIS_CHAT_COURSE_MESSAGE, IRIS_CHAT_EXERCISE_MESSAGE, IRIS_INTERACTION_SUGGESTION, IRIS_CHAT_LECTURE_MESSAGE, IRIS_COMPETENCY_GENERATION, NOT_SET + ATHENA, IRIS_CODE_FEEDBACK, IRIS_CHAT_COURSE_MESSAGE, IRIS_CHAT_EXERCISE_MESSAGE, IRIS_INTERACTION_SUGGESTION, IRIS_CHAT_LECTURE_MESSAGE, IRIS_COMPETENCY_GENERATION, + IRIS_CITATION_PIPELINE, NOT_SET } diff --git a/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java b/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java index 3bbaf8a9e5ab..168581b6fa4d 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java @@ -39,7 +39,9 @@ public List saveTokenUsage(IrisMessage message, Exercise exercise } llmTokenUsage.setServiceType(cost.pipeline()); llmTokenUsage.setExercise(exercise); - llmTokenUsage.setUserId(user.getId()); + if (user != null) { + llmTokenUsage.setUserId(user.getId()); + } llmTokenUsage.setCourse(course); llmTokenUsage.setNum_input_tokens(cost.num_input_tokens()); llmTokenUsage.setNum_output_tokens(cost.num_output_tokens()); diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java index 7c37831a611c..c38293c8a150 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java @@ -8,6 +8,7 @@ import de.tum.cit.aet.artemis.atlas.domain.competency.CompetencyTaxonomy; import de.tum.cit.aet.artemis.core.domain.Course; import de.tum.cit.aet.artemis.core.domain.User; +import de.tum.cit.aet.artemis.core.service.LLMTokenUsageService; import de.tum.cit.aet.artemis.iris.service.pyris.PyrisJobService; import de.tum.cit.aet.artemis.iris.service.pyris.PyrisPipelineService; import de.tum.cit.aet.artemis.iris.service.pyris.dto.competency.PyrisCompetencyExtractionPipelineExecutionDTO; @@ -25,14 +26,18 @@ public class IrisCompetencyGenerationService { private final PyrisPipelineService pyrisPipelineService; + private final LLMTokenUsageService llmTokenUsageService; + private final IrisWebsocketService websocketService; private final PyrisJobService pyrisJobService; - public IrisCompetencyGenerationService(PyrisPipelineService pyrisPipelineService, IrisWebsocketService websocketService, PyrisJobService pyrisJobService) { + public IrisCompetencyGenerationService(PyrisPipelineService pyrisPipelineService, LLMTokenUsageService llmTokenUsageService, IrisWebsocketService websocketService, + PyrisJobService pyrisJobService) { this.pyrisPipelineService = pyrisPipelineService; this.websocketService = websocketService; this.pyrisJobService = pyrisJobService; + this.llmTokenUsageService = llmTokenUsageService; } /** @@ -50,7 +55,7 @@ public void executeCompetencyExtractionPipeline(User user, Course course, String "default", pyrisJobService.createTokenForJob(token -> new CompetencyExtractionJob(token, course.getId(), user.getLogin())), executionDto -> new PyrisCompetencyExtractionPipelineExecutionDTO(executionDto, courseDescription, currentCompetencies, CompetencyTaxonomy.values(), 5), - stages -> websocketService.send(user.getLogin(), websocketTopic(course.getId()), new PyrisCompetencyStatusUpdateDTO(stages, null)) + stages -> websocketService.send(user.getLogin(), websocketTopic(course.getId()), new PyrisCompetencyStatusUpdateDTO(stages, null, null)) ); // @formatter:on } @@ -63,6 +68,7 @@ public void executeCompetencyExtractionPipeline(User user, Course course, String * @param statusUpdate the status update containing the new competency recommendations */ public void handleStatusUpdate(String userLogin, long courseId, PyrisCompetencyStatusUpdateDTO statusUpdate) { + var tokenUsages = llmTokenUsageService.saveTokenUsage(null, null, null, null, statusUpdate.tokens()); websocketService.send(userLogin, websocketTopic(courseId), statusUpdate); } diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/competency/PyrisCompetencyStatusUpdateDTO.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/competency/PyrisCompetencyStatusUpdateDTO.java index 0956a52f26e8..5a1774eeeb61 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/competency/PyrisCompetencyStatusUpdateDTO.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/competency/PyrisCompetencyStatusUpdateDTO.java @@ -4,6 +4,7 @@ import com.fasterxml.jackson.annotation.JsonInclude; +import de.tum.cit.aet.artemis.iris.service.pyris.dto.data.PyrisLLMCostDTO; import de.tum.cit.aet.artemis.iris.service.pyris.dto.status.PyrisStageDTO; /** @@ -15,5 +16,5 @@ * @param result List of competencies recommendations that have been generated so far */ @JsonInclude(JsonInclude.Include.NON_EMPTY) -public record PyrisCompetencyStatusUpdateDTO(List stages, List result) { +public record PyrisCompetencyStatusUpdateDTO(List stages, List result, List tokens) { } diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java index bb2e5c9f71dd..b87e64081c07 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java @@ -19,6 +19,7 @@ import de.tum.cit.aet.artemis.core.exception.AccessForbiddenException; import de.tum.cit.aet.artemis.core.security.Role; import de.tum.cit.aet.artemis.core.service.AuthorizationCheckService; +import de.tum.cit.aet.artemis.core.service.LLMTokenUsageService; import de.tum.cit.aet.artemis.iris.domain.message.IrisMessage; import de.tum.cit.aet.artemis.iris.domain.message.IrisMessageSender; import de.tum.cit.aet.artemis.iris.domain.message.IrisTextMessageContent; @@ -43,6 +44,8 @@ public class IrisCourseChatSessionService extends AbstractIrisChatSessionService private final IrisMessageService irisMessageService; + private final LLMTokenUsageService llmTokenUsageService; + private final IrisSettingsService irisSettingsService; private final IrisChatWebsocketService irisChatWebsocketService; @@ -57,11 +60,13 @@ public class IrisCourseChatSessionService extends AbstractIrisChatSessionService private final PyrisPipelineService pyrisPipelineService; - public IrisCourseChatSessionService(IrisMessageService irisMessageService, IrisSettingsService irisSettingsService, IrisChatWebsocketService irisChatWebsocketService, - AuthorizationCheckService authCheckService, IrisSessionRepository irisSessionRepository, IrisRateLimitService rateLimitService, - IrisCourseChatSessionRepository irisCourseChatSessionRepository, PyrisPipelineService pyrisPipelineService, ObjectMapper objectMapper) { + public IrisCourseChatSessionService(IrisMessageService irisMessageService, LLMTokenUsageService llmTokenUsageService, IrisSettingsService irisSettingsService, + IrisChatWebsocketService irisChatWebsocketService, AuthorizationCheckService authCheckService, IrisSessionRepository irisSessionRepository, + IrisRateLimitService rateLimitService, IrisCourseChatSessionRepository irisCourseChatSessionRepository, PyrisPipelineService pyrisPipelineService, + ObjectMapper objectMapper, LLMTokenUsageService lLMTokenUsageService) { super(irisSessionRepository, objectMapper); this.irisMessageService = irisMessageService; + this.llmTokenUsageService = llmTokenUsageService; this.irisSettingsService = irisSettingsService; this.irisChatWebsocketService = irisChatWebsocketService; this.authCheckService = authCheckService; @@ -137,9 +142,11 @@ public void handleStatusUpdate(CourseChatJob job, PyrisChatStatusUpdateDTO statu var message = new IrisMessage(); message.addContent(new IrisTextMessageContent(statusUpdate.result())); var savedMessage = irisMessageService.saveMessage(message, session, IrisMessageSender.LLM); + var tokenUsages = llmTokenUsageService.saveTokenUsage(savedMessage, null, session.getUser(), session.getCourse(), statusUpdate.tokens()); irisChatWebsocketService.sendMessage(session, savedMessage, statusUpdate.stages()); } else { + var tokenUsages = llmTokenUsageService.saveTokenUsage(null, null, session.getUser(), session.getCourse(), statusUpdate.tokens()); irisChatWebsocketService.sendStatusUpdate(session, statusUpdate.stages(), statusUpdate.suggestions(), statusUpdate.tokens()); } updateLatestSuggestions(session, statusUpdate.suggestions()); diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java index 0d557e03d723..bcdc912a99f2 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java @@ -45,7 +45,7 @@ public class IrisExerciseChatSessionService extends AbstractIrisChatSessionServi private final IrisMessageService irisMessageService; - private final LLMTokenUsageService LLMTokenUsageService; + private final LLMTokenUsageService llmTokenUsageService; private final IrisSettingsService irisSettingsService; @@ -72,7 +72,7 @@ public IrisExerciseChatSessionService(IrisMessageService irisMessageService, LLM ObjectMapper objectMapper) { super(irisSessionRepository, objectMapper); this.irisMessageService = irisMessageService; - this.LLMTokenUsageService = LLMTokenUsageService; + this.llmTokenUsageService = LLMTokenUsageService; this.irisSettingsService = irisSettingsService; this.irisChatWebsocketService = irisChatWebsocketService; this.authCheckService = authCheckService; @@ -176,12 +176,12 @@ public void handleStatusUpdate(ExerciseChatJob job, PyrisChatStatusUpdateDTO sta var message = new IrisMessage(); message.addContent(new IrisTextMessageContent(statusUpdate.result())); var savedMessage = irisMessageService.saveMessage(message, session, IrisMessageSender.LLM); - var tokenUsages = LLMTokenUsageService.saveTokenUsage(savedMessage, session.getExercise(), session.getUser(), + var tokenUsages = llmTokenUsageService.saveTokenUsage(savedMessage, session.getExercise(), session.getUser(), session.getExercise().getCourseViaExerciseGroupOrCourseMember(), statusUpdate.tokens()); irisChatWebsocketService.sendMessage(session, savedMessage, statusUpdate.stages()); } else { - var tokenUsages = LLMTokenUsageService.saveTokenUsage(null, session.getExercise(), session.getUser(), session.getExercise().getCourseViaExerciseGroupOrCourseMember(), + var tokenUsages = llmTokenUsageService.saveTokenUsage(null, session.getExercise(), session.getUser(), session.getExercise().getCourseViaExerciseGroupOrCourseMember(), statusUpdate.tokens()); irisChatWebsocketService.sendStatusUpdate(session, statusUpdate.stages(), statusUpdate.suggestions(), statusUpdate.tokens()); } diff --git a/src/test/java/de/tum/cit/aet/artemis/iris/IrisCompetencyGenerationIntegrationTest.java b/src/test/java/de/tum/cit/aet/artemis/iris/IrisCompetencyGenerationIntegrationTest.java index b4fef850f439..adf594a5608f 100644 --- a/src/test/java/de/tum/cit/aet/artemis/iris/IrisCompetencyGenerationIntegrationTest.java +++ b/src/test/java/de/tum/cit/aet/artemis/iris/IrisCompetencyGenerationIntegrationTest.java @@ -66,7 +66,7 @@ void generateCompetencies_asEditor_shouldSucceed() throws Exception { List stages = List.of(new PyrisStageDTO("Generating Competencies", 10, PyrisStageState.DONE, null)); // In the real system, this would be triggered by Pyris via a REST call to the Artemis server - irisCompetencyGenerationService.handleStatusUpdate(TEST_PREFIX + "editor1", course.getId(), new PyrisCompetencyStatusUpdateDTO(stages, recommendations)); + irisCompetencyGenerationService.handleStatusUpdate(TEST_PREFIX + "editor1", course.getId(), new PyrisCompetencyStatusUpdateDTO(stages, recommendations, null)); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(PyrisCompetencyStatusUpdateDTO.class); verify(websocketMessagingService, timeout(200).times(3)).sendMessageToUser(eq(TEST_PREFIX + "editor1"), eq("/topic/iris/competencies/" + course.getId()), From f85cf46f419bdecf4e37563c392465ab654afa80 Mon Sep 17 00:00:00 2001 From: Alexander Joham Date: Fri, 11 Oct 2024 18:59:36 +0200 Subject: [PATCH 04/30] Add comments to LLMTokenUsageService --- .../core/service/LLMTokenUsageService.java | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java b/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java index 168581b6fa4d..2a5b9204ed48 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java @@ -17,7 +17,7 @@ import de.tum.cit.aet.artemis.iris.service.pyris.dto.data.PyrisLLMCostDTO; /** - * Service for managing Iris messages. + * Service for managing the LLMTokenUsage by all LLMs in Artemis */ @Service @Profile(PROFILE_IRIS) @@ -29,6 +29,19 @@ public LLMTokenUsageService(LLMTokenUsageRepository llmTokenUsageRepository) { this.llmTokenUsageRepository = llmTokenUsageRepository; } + /** + * saves the tokens used for a specific IrisMessage or Athena call + * in case of an Athena call IrisMessage can be null and the + * LLMServiceType in tokens has to by Athena + * + * @param message IrisMessage related to the TokenUsage + * @param exercise Exercise in which the request was made + * @param user User that made the request + * @param course Course in which the request was made + * @param tokens List with Tokens of the PyrisLLMCostDTO Mdel + * @return List of the created LLMTokenUsage entries + */ + public List saveTokenUsage(IrisMessage message, Exercise exercise, User user, Course course, List tokens) { List tokenUsages = new ArrayList<>(); for (PyrisLLMCostDTO cost : tokens) { From 65fb25974972685a436d4705f04ed3b570dd4040 Mon Sep 17 00:00:00 2001 From: Alexander Joham Date: Sat, 12 Oct 2024 11:18:03 +0200 Subject: [PATCH 05/30] Fix server test failures by checking if tokens received --- .../repository/LLMTokenUsageRepository.java | 4 ++++ .../IrisCompetencyGenerationService.java | 4 +++- .../session/IrisExerciseChatSessionService.java | 12 ++++++++---- .../iris/IrisChatMessageIntegrationTest.java | 17 ++++++++--------- 4 files changed, 23 insertions(+), 14 deletions(-) diff --git a/src/main/java/de/tum/cit/aet/artemis/core/repository/LLMTokenUsageRepository.java b/src/main/java/de/tum/cit/aet/artemis/core/repository/LLMTokenUsageRepository.java index 755ed4a1b014..2e6d9f1902e1 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/repository/LLMTokenUsageRepository.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/repository/LLMTokenUsageRepository.java @@ -1,10 +1,14 @@ package de.tum.cit.aet.artemis.core.repository; +import static de.tum.cit.aet.artemis.core.config.Constants.PROFILE_IRIS; + +import org.springframework.context.annotation.Profile; import org.springframework.stereotype.Repository; import de.tum.cit.aet.artemis.core.domain.LLMTokenUsage; import de.tum.cit.aet.artemis.core.repository.base.ArtemisJpaRepository; @Repository +@Profile(PROFILE_IRIS) public interface LLMTokenUsageRepository extends ArtemisJpaRepository { } diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java index c38293c8a150..0ac0872a1767 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java @@ -68,7 +68,9 @@ public void executeCompetencyExtractionPipeline(User user, Course course, String * @param statusUpdate the status update containing the new competency recommendations */ public void handleStatusUpdate(String userLogin, long courseId, PyrisCompetencyStatusUpdateDTO statusUpdate) { - var tokenUsages = llmTokenUsageService.saveTokenUsage(null, null, null, null, statusUpdate.tokens()); + if (statusUpdate.tokens() != null) { + var tokenUsages = llmTokenUsageService.saveTokenUsage(null, null, null, null, statusUpdate.tokens()); + } websocketService.send(userLogin, websocketTopic(courseId), statusUpdate); } diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java index bcdc912a99f2..12fd6fdf6905 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java @@ -176,13 +176,17 @@ public void handleStatusUpdate(ExerciseChatJob job, PyrisChatStatusUpdateDTO sta var message = new IrisMessage(); message.addContent(new IrisTextMessageContent(statusUpdate.result())); var savedMessage = irisMessageService.saveMessage(message, session, IrisMessageSender.LLM); - var tokenUsages = llmTokenUsageService.saveTokenUsage(savedMessage, session.getExercise(), session.getUser(), - session.getExercise().getCourseViaExerciseGroupOrCourseMember(), statusUpdate.tokens()); + if (statusUpdate.tokens() != null) { + var tokenUsages = llmTokenUsageService.saveTokenUsage(savedMessage, session.getExercise(), session.getUser(), + session.getExercise().getCourseViaExerciseGroupOrCourseMember(), statusUpdate.tokens()); + } irisChatWebsocketService.sendMessage(session, savedMessage, statusUpdate.stages()); } else { - var tokenUsages = llmTokenUsageService.saveTokenUsage(null, session.getExercise(), session.getUser(), session.getExercise().getCourseViaExerciseGroupOrCourseMember(), - statusUpdate.tokens()); + if (statusUpdate.tokens() != null) { + var tokenUsages = llmTokenUsageService.saveTokenUsage(null, session.getExercise(), session.getUser(), + session.getExercise().getCourseViaExerciseGroupOrCourseMember(), statusUpdate.tokens()); + } irisChatWebsocketService.sendStatusUpdate(session, statusUpdate.stages(), statusUpdate.suggestions(), statusUpdate.tokens()); } diff --git a/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatMessageIntegrationTest.java b/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatMessageIntegrationTest.java index ad92db87287f..8e36d9063cd9 100644 --- a/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatMessageIntegrationTest.java +++ b/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatMessageIntegrationTest.java @@ -41,7 +41,6 @@ import de.tum.cit.aet.artemis.iris.repository.IrisSessionRepository; import de.tum.cit.aet.artemis.iris.service.IrisMessageService; import de.tum.cit.aet.artemis.iris.service.pyris.dto.chat.PyrisChatStatusUpdateDTO; -import de.tum.cit.aet.artemis.iris.service.pyris.dto.data.PyrisLLMCostDTO; import de.tum.cit.aet.artemis.iris.service.pyris.dto.status.PyrisStageDTO; import de.tum.cit.aet.artemis.iris.service.pyris.dto.status.PyrisStageState; import de.tum.cit.aet.artemis.iris.service.session.IrisExerciseChatSessionService; @@ -131,7 +130,7 @@ void sendOneMessage() throws Exception { irisRequestMockProvider.mockRunResponse(dto -> { assertThat(dto.settings().authenticationToken()).isNotNull(); - assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World", dto.initialStages(), null, null)); + assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World", dto.initialStages(), null)); pipelineDone.set(true); }); @@ -156,7 +155,7 @@ void sendSuggestions() throws Exception { List suggestions = List.of("suggestion1", "suggestion2", "suggestion3"); - assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), null, dto.initialStages(), suggestions, null)); + assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), null, dto.initialStages(), suggestions)); pipelineDone.set(true); }); @@ -195,7 +194,7 @@ void sendTwoMessages() throws Exception { irisRequestMockProvider.mockRunResponse(dto -> { assertThat(dto.settings().authenticationToken()).isNotNull(); - assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World 1", dto.initialStages(), null, null)); + assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World 1", dto.initialStages(), null)); pipelineDone.set(true); }); @@ -203,7 +202,7 @@ void sendTwoMessages() throws Exception { irisRequestMockProvider.mockRunResponse(dto -> { assertThat(dto.settings().authenticationToken()).isNotNull(); - assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World 2", dto.initialStages(), null, null)); + assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World 2", dto.initialStages(), null)); pipelineDone.set(true); }); @@ -299,7 +298,7 @@ void resendMessage() throws Exception { irisRequestMockProvider.mockRunResponse(dto -> { assertThat(dto.settings().authenticationToken()).isNotNull(); - assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World", dto.initialStages(), null, null)); + assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World", dto.initialStages(), null)); pipelineDone.set(true); }); @@ -322,7 +321,7 @@ void sendMessageRateLimitReached() throws Exception { irisRequestMockProvider.mockRunResponse(dto -> { assertThat(dto.settings().authenticationToken()).isNotNull(); - assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World", dto.initialStages(), null, null)); + assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World", dto.initialStages(), null)); pipelineDone.set(true); }); @@ -445,9 +444,9 @@ public String toString() { }; } - private void sendStatus(String jobId, String result, List stages, List suggestions, List tokens) throws Exception { + private void sendStatus(String jobId, String result, List stages, List suggestions) throws Exception { var headers = new HttpHeaders(new LinkedMultiValueMap<>(Map.of("Authorization", List.of("Bearer " + jobId)))); - request.postWithoutResponseBody("/api/public/pyris/pipelines/tutor-chat/runs/" + jobId + "/status", new PyrisChatStatusUpdateDTO(result, stages, suggestions, tokens), + request.postWithoutResponseBody("/api/public/pyris/pipelines/tutor-chat/runs/" + jobId + "/status", new PyrisChatStatusUpdateDTO(result, stages, suggestions, null), HttpStatus.OK, headers); } } From 188ff22fc094b9f0aabc2c8126deae6e1d4ed796 Mon Sep 17 00:00:00 2001 From: Alexander Joham Date: Sat, 12 Oct 2024 12:55:46 +0200 Subject: [PATCH 06/30] Update database for cost tracking and trace_id functionality --- .../artemis/core/domain/LLMServiceType.java | 4 +- .../artemis/core/domain/LLMTokenUsage.java | 43 ++++++++++++++++--- .../core/service/LLMTokenUsageService.java | 10 +++++ .../pyris/dto/data/PyrisLLMCostDTO.java | 2 +- ...gelog.xml => 20241012125003_changelog.xml} | 6 ++- .../resources/config/liquibase/master.xml | 2 +- 6 files changed, 54 insertions(+), 13 deletions(-) rename src/main/resources/config/liquibase/changelog/{20241011140701_changelog.xml => 20241012125003_changelog.xml} (89%) diff --git a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMServiceType.java b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMServiceType.java index 8b41cdf30fa8..e71589f300ca 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMServiceType.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMServiceType.java @@ -1,6 +1,6 @@ package de.tum.cit.aet.artemis.core.domain; public enum LLMServiceType { - ATHENA, IRIS_CODE_FEEDBACK, IRIS_CHAT_COURSE_MESSAGE, IRIS_CHAT_EXERCISE_MESSAGE, IRIS_INTERACTION_SUGGESTION, IRIS_CHAT_LECTURE_MESSAGE, IRIS_COMPETENCY_GENERATION, - IRIS_CITATION_PIPELINE, NOT_SET + ATHENA_PRELIMINARY_FEEDBACK, ATHENA_FEEDBACK_SUGGESTION, IRIS_CODE_FEEDBACK, IRIS_CHAT_COURSE_MESSAGE, IRIS_CHAT_EXERCISE_MESSAGE, IRIS_INTERACTION_SUGGESTION, + IRIS_CHAT_LECTURE_MESSAGE, IRIS_COMPETENCY_GENERATION, IRIS_CITATION_PIPELINE, NOT_SET } diff --git a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java index 6646810c2367..cd26f899113a 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java @@ -38,15 +38,18 @@ public class LLMTokenUsage extends DomainObject { @Column(name = "model") private String model; - @Column(name = "cost_per_token") - private double cost_per_token; - @Column(name = "num_input_tokens") private int num_input_tokens; + @Column(name = "cost_per_input_token") + private float cost_per_input_token; + @Column(name = "num_output_tokens") private int num_output_tokens; + @Column(name = "cost_per_output_token") + private float cost_per_output_token; + @Nullable @ManyToOne @JsonIgnore @@ -66,6 +69,9 @@ public class LLMTokenUsage extends DomainObject { @Column(name = "timestamp") private ZonedDateTime timestamp = ZonedDateTime.now(); + @Column(name = "trace_id") + private Long traceId; + @Nullable @ManyToOne @JsonIgnore @@ -88,12 +94,20 @@ public void setModel(String model) { this.model = model; } - public double getCost_per_token() { - return cost_per_token; + public float getCost_per_input_token() { + return cost_per_input_token; + } + + public void setCost_per_input_token(float cost_per_input_token) { + this.cost_per_input_token = cost_per_input_token; + } + + public float getCost_per_output_token() { + return cost_per_output_token; } - public void setCost_per_token(double cost_per_token) { - this.cost_per_token = cost_per_token; + public void setCost_per_output_token(float cost_per_output_token) { + this.cost_per_output_token = cost_per_output_token; } public int getNum_input_tokens() { @@ -144,6 +158,14 @@ public void setTimestamp(ZonedDateTime timestamp) { this.timestamp = timestamp; } + public Long getTraceId() { + return traceId; + } + + public void setTraceId(Long traceId) { + this.traceId = traceId; + } + public IrisMessage getIrisMessage() { return irisMessage; } @@ -151,4 +173,11 @@ public IrisMessage getIrisMessage() { public void setIrisMessage(IrisMessage message) { this.irisMessage = message; } + + @Override + public String toString() { + return "LLMTokenUsage{" + "serviceType=" + serviceType + ", model=" + model + ", num_input_tokens=" + num_input_tokens + ", cost_per_input_token=" + cost_per_input_token + + ", num_output_tokens=" + num_output_tokens + ", cost_per_output_token=" + cost_per_output_token + ", course=" + course + ", exercise=" + exercise + ", userId=" + + userId + ", timestamp=" + timestamp + ", trace_id=" + traceId + ", irisMessage=" + irisMessage + '}'; + } } diff --git a/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java b/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java index 2a5b9204ed48..3446c61117e3 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java @@ -4,6 +4,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.UUID; import org.springframework.context.annotation.Profile; import org.springframework.stereotype.Service; @@ -44,6 +45,12 @@ public LLMTokenUsageService(LLMTokenUsageRepository llmTokenUsageRepository) { public List saveTokenUsage(IrisMessage message, Exercise exercise, User user, Course course, List tokens) { List tokenUsages = new ArrayList<>(); + + // Combine current time and UUID to create a unique traceId + long timestamp = System.currentTimeMillis(); + long uuidComponent = UUID.randomUUID().getLeastSignificantBits() & Long.MAX_VALUE; + Long traceId = timestamp + uuidComponent; + for (PyrisLLMCostDTO cost : tokens) { LLMTokenUsage llmTokenUsage = new LLMTokenUsage(); if (message != null) { @@ -57,8 +64,11 @@ public List saveTokenUsage(IrisMessage message, Exercise exercise } llmTokenUsage.setCourse(course); llmTokenUsage.setNum_input_tokens(cost.num_input_tokens()); + llmTokenUsage.setCost_per_input_token(cost.cost_per_input_token()); llmTokenUsage.setNum_output_tokens(cost.num_output_tokens()); + llmTokenUsage.setCost_per_output_token(cost.cost_per_output_token()); llmTokenUsage.setModel(cost.model_info()); + llmTokenUsage.setTraceId(traceId); tokenUsages.add(llmTokenUsageRepository.save(llmTokenUsage)); } return tokenUsages; diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/data/PyrisLLMCostDTO.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/data/PyrisLLMCostDTO.java index 1fa344b3ebc1..13fd40d84bf1 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/data/PyrisLLMCostDTO.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/data/PyrisLLMCostDTO.java @@ -2,5 +2,5 @@ import de.tum.cit.aet.artemis.core.domain.LLMServiceType; -public record PyrisLLMCostDTO(String model_info, int num_input_tokens, int num_output_tokens, LLMServiceType pipeline) { +public record PyrisLLMCostDTO(String model_info, int num_input_tokens, float cost_per_input_token, int num_output_tokens, float cost_per_output_token, LLMServiceType pipeline) { } diff --git a/src/main/resources/config/liquibase/changelog/20241011140701_changelog.xml b/src/main/resources/config/liquibase/changelog/20241012125003_changelog.xml similarity index 89% rename from src/main/resources/config/liquibase/changelog/20241011140701_changelog.xml rename to src/main/resources/config/liquibase/changelog/20241012125003_changelog.xml index 4088f2b05a95..fcb5bb25ac22 100644 --- a/src/main/resources/config/liquibase/changelog/20241011140701_changelog.xml +++ b/src/main/resources/config/liquibase/changelog/20241012125003_changelog.xml @@ -5,20 +5,22 @@ xsi:schemaLocation="http://www.liquibase.org/xml/ns/dbchangelog http://www.liquibase.org/xml/ns/dbchangelog/dbchangelog-4.29.xsd" objectQuotingStrategy="QUOTE_ONLY_RESERVED_WORDS"> - + - + + + - + From be85a3b21363a89b83a66faefcba9fb0a6f80223 Mon Sep 17 00:00:00 2001 From: Alexander Joham Date: Sun, 13 Oct 2024 00:05:21 +0200 Subject: [PATCH 07/30] Update database, add information to competency gen, change traceId calc --- .../artemis/core/domain/LLMServiceType.java | 2 +- .../artemis/core/domain/LLMTokenUsage.java | 73 +++++++++---------- .../core/service/LLMTokenUsageService.java | 42 +++++------ .../IrisCompetencyGenerationService.java | 18 +++-- .../pyris/PyrisStatusUpdateService.java | 4 +- .../pyris/dto/data/PyrisLLMCostDTO.java | 2 +- .../session/IrisCourseChatSessionService.java | 4 +- .../IrisExerciseChatSessionService.java | 8 +- ...gelog.xml => 20241012080932_changelog.xml} | 10 +-- .../resources/config/liquibase/master.xml | 2 +- ...isCompetencyGenerationIntegrationTest.java | 4 +- 11 files changed, 82 insertions(+), 87 deletions(-) rename src/main/resources/config/liquibase/changelog/{20241012125003_changelog.xml => 20241012080932_changelog.xml} (86%) diff --git a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMServiceType.java b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMServiceType.java index e71589f300ca..4869ec522c3c 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMServiceType.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMServiceType.java @@ -2,5 +2,5 @@ public enum LLMServiceType { ATHENA_PRELIMINARY_FEEDBACK, ATHENA_FEEDBACK_SUGGESTION, IRIS_CODE_FEEDBACK, IRIS_CHAT_COURSE_MESSAGE, IRIS_CHAT_EXERCISE_MESSAGE, IRIS_INTERACTION_SUGGESTION, - IRIS_CHAT_LECTURE_MESSAGE, IRIS_COMPETENCY_GENERATION, IRIS_CITATION_PIPELINE, NOT_SET + IRIS_CHAT_LECTURE_MESSAGE, IRIS_COMPETENCY_GENERATION, IRIS_CITATION_PIPELINE, IRIS_LECTURE_RETRIEVAL_PIPELINE, IRIS_LECTURE_INGESTION, NOT_SET } diff --git a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java index cd26f899113a..3f51390ce397 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java @@ -7,8 +7,6 @@ import jakarta.persistence.Entity; import jakarta.persistence.EnumType; import jakarta.persistence.Enumerated; -import jakarta.persistence.Inheritance; -import jakarta.persistence.InheritanceType; import jakarta.persistence.JoinColumn; import jakarta.persistence.ManyToOne; import jakarta.persistence.Table; @@ -18,16 +16,13 @@ import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonTypeInfo; import de.tum.cit.aet.artemis.exercise.domain.Exercise; import de.tum.cit.aet.artemis.iris.domain.message.IrisMessage; @Entity @Table(name = "llm_token_usage") -@Inheritance(strategy = InheritanceType.SINGLE_TABLE) @Cache(usage = CacheConcurrencyStrategy.NONSTRICT_READ_WRITE) -@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, property = "type") @JsonInclude(JsonInclude.Include.NON_EMPTY) public class LLMTokenUsage extends DomainObject { @@ -39,16 +34,16 @@ public class LLMTokenUsage extends DomainObject { private String model; @Column(name = "num_input_tokens") - private int num_input_tokens; + private int numInputTokens; - @Column(name = "cost_per_input_token") - private float cost_per_input_token; + @Column(name = "cost_per_million_input_tokens") + private float costPerMillionInputTokens; @Column(name = "num_output_tokens") - private int num_output_tokens; + private int numOutputTokens; - @Column(name = "cost_per_output_token") - private float cost_per_output_token; + @Column(name = "cost_per_million_output_tokens") + private float costPerMillionOutputTokens; @Nullable @ManyToOne @@ -66,11 +61,11 @@ public class LLMTokenUsage extends DomainObject { private long userId; @Nullable - @Column(name = "timestamp") - private ZonedDateTime timestamp = ZonedDateTime.now(); + @Column(name = "time") + private ZonedDateTime time = ZonedDateTime.now(); @Column(name = "trace_id") - private Long traceId; + private String traceId; @Nullable @ManyToOne @@ -94,36 +89,36 @@ public void setModel(String model) { this.model = model; } - public float getCost_per_input_token() { - return cost_per_input_token; + public float getCostPerMillionInputTokens() { + return costPerMillionInputTokens; } - public void setCost_per_input_token(float cost_per_input_token) { - this.cost_per_input_token = cost_per_input_token; + public void setCostPerMillionInputTokens(float costPerMillionInputToken) { + this.costPerMillionInputTokens = costPerMillionInputToken; } - public float getCost_per_output_token() { - return cost_per_output_token; + public float getCostPerMillionOutputTokens() { + return costPerMillionOutputTokens; } - public void setCost_per_output_token(float cost_per_output_token) { - this.cost_per_output_token = cost_per_output_token; + public void setCostPerMillionOutputTokens(float costPerMillionOutputToken) { + this.costPerMillionOutputTokens = costPerMillionOutputToken; } - public int getNum_input_tokens() { - return num_input_tokens; + public int getNumInputTokens() { + return numInputTokens; } - public void setNum_input_tokens(int num_input_tokens) { - this.num_input_tokens = num_input_tokens; + public void setNumInputTokens(int numInputTokens) { + this.numInputTokens = numInputTokens; } - public int getNum_output_tokens() { - return num_output_tokens; + public int getNumOutputTokens() { + return numOutputTokens; } - public void setNum_output_tokens(int num_output_tokens) { - this.num_output_tokens = num_output_tokens; + public void setNumOutputTokens(int numOutputTokens) { + this.numOutputTokens = numOutputTokens; } public Course getCourse() { @@ -150,19 +145,19 @@ public void setUserId(long userId) { this.userId = userId; } - public ZonedDateTime getTimestamp() { - return timestamp; + public ZonedDateTime getTime() { + return time; } - public void setTimestamp(ZonedDateTime timestamp) { - this.timestamp = timestamp; + public void setTime(ZonedDateTime time) { + this.time = time; } - public Long getTraceId() { + public String getTraceId() { return traceId; } - public void setTraceId(Long traceId) { + public void setTraceId(String traceId) { this.traceId = traceId; } @@ -176,8 +171,8 @@ public void setIrisMessage(IrisMessage message) { @Override public String toString() { - return "LLMTokenUsage{" + "serviceType=" + serviceType + ", model=" + model + ", num_input_tokens=" + num_input_tokens + ", cost_per_input_token=" + cost_per_input_token - + ", num_output_tokens=" + num_output_tokens + ", cost_per_output_token=" + cost_per_output_token + ", course=" + course + ", exercise=" + exercise + ", userId=" - + userId + ", timestamp=" + timestamp + ", trace_id=" + traceId + ", irisMessage=" + irisMessage + '}'; + return "LLMTokenUsage{" + "serviceType=" + serviceType + ", model=" + model + ", num_input_tokens=" + numInputTokens + ", cost_per_input_token=" + costPerMillionInputTokens + + ", num_output_tokens=" + numOutputTokens + ", cost_per_output_token=" + costPerMillionOutputTokens + ", course=" + course + ", exercise=" + exercise + ", userId=" + + userId + ", timestamp=" + time + ", trace_id=" + traceId + ", irisMessage=" + irisMessage + '}'; } } diff --git a/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java b/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java index 3446c61117e3..8dab3e056029 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java @@ -4,7 +4,6 @@ import java.util.ArrayList; import java.util.List; -import java.util.UUID; import org.springframework.context.annotation.Profile; import org.springframework.stereotype.Service; @@ -16,6 +15,7 @@ import de.tum.cit.aet.artemis.exercise.domain.Exercise; import de.tum.cit.aet.artemis.iris.domain.message.IrisMessage; import de.tum.cit.aet.artemis.iris.service.pyris.dto.data.PyrisLLMCostDTO; +import de.tum.cit.aet.artemis.iris.service.pyris.job.PyrisJob; /** * Service for managing the LLMTokenUsage by all LLMs in Artemis @@ -31,31 +31,25 @@ public LLMTokenUsageService(LLMTokenUsageRepository llmTokenUsageRepository) { } /** - * saves the tokens used for a specific IrisMessage or Athena call - * in case of an Athena call IrisMessage can be null and the - * LLMServiceType in tokens has to by Athena + * method saves the token usage to the database with a link to the IrisMessage + * messages of the same job are grouped together by saving the job id as a trace id * - * @param message IrisMessage related to the TokenUsage - * @param exercise Exercise in which the request was made - * @param user User that made the request - * @param course Course in which the request was made - * @param tokens List with Tokens of the PyrisLLMCostDTO Mdel - * @return List of the created LLMTokenUsage entries + * @param job used to create a unique traceId to group multiple LLM calls + * @param message IrisMessage to map the usage to an IrisMessage + * @param exercise to map the token cost to an exercise + * @param user to map the token cost to a user + * @param course to map the token to a course + * @param tokens token cost lsit of type PyrisLLMCostDTO + * @return list of the saved data */ - - public List saveTokenUsage(IrisMessage message, Exercise exercise, User user, Course course, List tokens) { + public List saveIrisTokenUsage(PyrisJob job, IrisMessage message, Exercise exercise, User user, Course course, List tokens) { List tokenUsages = new ArrayList<>(); - // Combine current time and UUID to create a unique traceId - long timestamp = System.currentTimeMillis(); - long uuidComponent = UUID.randomUUID().getLeastSignificantBits() & Long.MAX_VALUE; - Long traceId = timestamp + uuidComponent; - for (PyrisLLMCostDTO cost : tokens) { LLMTokenUsage llmTokenUsage = new LLMTokenUsage(); if (message != null) { llmTokenUsage.setIrisMessage(message); - llmTokenUsage.setTimestamp(message.getSentAt()); + llmTokenUsage.setTime(message.getSentAt()); } llmTokenUsage.setServiceType(cost.pipeline()); llmTokenUsage.setExercise(exercise); @@ -63,12 +57,12 @@ public List saveTokenUsage(IrisMessage message, Exercise exercise llmTokenUsage.setUserId(user.getId()); } llmTokenUsage.setCourse(course); - llmTokenUsage.setNum_input_tokens(cost.num_input_tokens()); - llmTokenUsage.setCost_per_input_token(cost.cost_per_input_token()); - llmTokenUsage.setNum_output_tokens(cost.num_output_tokens()); - llmTokenUsage.setCost_per_output_token(cost.cost_per_output_token()); - llmTokenUsage.setModel(cost.model_info()); - llmTokenUsage.setTraceId(traceId); + llmTokenUsage.setNumInputTokens(cost.numInputTokens()); + llmTokenUsage.setCostPerMillionInputTokens(cost.costPerInputToken()); + llmTokenUsage.setNumOutputTokens(cost.numOutputTokens()); + llmTokenUsage.setCostPerMillionOutputTokens(cost.costPerOutputToken()); + llmTokenUsage.setModel(cost.modelInfo()); + llmTokenUsage.setTraceId(job.jobId()); tokenUsages.add(llmTokenUsageRepository.save(llmTokenUsage)); } return tokenUsages; diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java index 0ac0872a1767..93111ad2c234 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java @@ -8,6 +8,7 @@ import de.tum.cit.aet.artemis.atlas.domain.competency.CompetencyTaxonomy; import de.tum.cit.aet.artemis.core.domain.Course; import de.tum.cit.aet.artemis.core.domain.User; +import de.tum.cit.aet.artemis.core.repository.CourseRepository; import de.tum.cit.aet.artemis.core.service.LLMTokenUsageService; import de.tum.cit.aet.artemis.iris.service.pyris.PyrisJobService; import de.tum.cit.aet.artemis.iris.service.pyris.PyrisPipelineService; @@ -28,13 +29,16 @@ public class IrisCompetencyGenerationService { private final LLMTokenUsageService llmTokenUsageService; + private final CourseRepository courseRepository; + private final IrisWebsocketService websocketService; private final PyrisJobService pyrisJobService; - public IrisCompetencyGenerationService(PyrisPipelineService pyrisPipelineService, LLMTokenUsageService llmTokenUsageService, IrisWebsocketService websocketService, - PyrisJobService pyrisJobService) { + public IrisCompetencyGenerationService(PyrisPipelineService pyrisPipelineService, LLMTokenUsageService llmTokenUsageService, CourseRepository courseRepository, + IrisWebsocketService websocketService, PyrisJobService pyrisJobService) { this.pyrisPipelineService = pyrisPipelineService; + this.courseRepository = courseRepository; this.websocketService = websocketService; this.pyrisJobService = pyrisJobService; this.llmTokenUsageService = llmTokenUsageService; @@ -63,15 +67,15 @@ public void executeCompetencyExtractionPipeline(User user, Course course, String /** * Takes a status update from Pyris containing a new competency extraction result and sends it to the client via websocket * - * @param userLogin the login of the user - * @param courseId the id of the course + * @param job Job related to the status update * @param statusUpdate the status update containing the new competency recommendations */ - public void handleStatusUpdate(String userLogin, long courseId, PyrisCompetencyStatusUpdateDTO statusUpdate) { + public void handleStatusUpdate(CompetencyExtractionJob job, PyrisCompetencyStatusUpdateDTO statusUpdate) { + Course course = courseRepository.findByIdForUpdateElseThrow(job.courseId()); if (statusUpdate.tokens() != null) { - var tokenUsages = llmTokenUsageService.saveTokenUsage(null, null, null, null, statusUpdate.tokens()); + llmTokenUsageService.saveIrisTokenUsage(job, null, null, null, course, statusUpdate.tokens()); } - websocketService.send(userLogin, websocketTopic(courseId), statusUpdate); + websocketService.send(job.userLogin(), websocketTopic(job.courseId()), statusUpdate); } private static String websocketTopic(long courseId) { diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/PyrisStatusUpdateService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/PyrisStatusUpdateService.java index aed62b6049c1..732b2b572458 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/PyrisStatusUpdateService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/PyrisStatusUpdateService.java @@ -71,13 +71,13 @@ public void handleStatusUpdate(CourseChatJob job, PyrisChatStatusUpdateDTO statu /** * Handles the status update of a competency extraction job and forwards it to - * {@link IrisCompetencyGenerationService#handleStatusUpdate(String, long, PyrisCompetencyStatusUpdateDTO)} + * {@link IrisCompetencyGenerationService#handleStatusUpdate(CompetencyExtractionJob, PyrisCompetencyStatusUpdateDTO)} * * @param job the job that is updated * @param statusUpdate the status update */ public void handleStatusUpdate(CompetencyExtractionJob job, PyrisCompetencyStatusUpdateDTO statusUpdate) { - competencyGenerationService.handleStatusUpdate(job.userLogin(), job.courseId(), statusUpdate); + competencyGenerationService.handleStatusUpdate(job, statusUpdate); removeJobIfTerminated(statusUpdate.stages(), job.jobId()); } diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/data/PyrisLLMCostDTO.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/data/PyrisLLMCostDTO.java index 13fd40d84bf1..74f40cce6873 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/data/PyrisLLMCostDTO.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/data/PyrisLLMCostDTO.java @@ -2,5 +2,5 @@ import de.tum.cit.aet.artemis.core.domain.LLMServiceType; -public record PyrisLLMCostDTO(String model_info, int num_input_tokens, float cost_per_input_token, int num_output_tokens, float cost_per_output_token, LLMServiceType pipeline) { +public record PyrisLLMCostDTO(String modelInfo, int numInputTokens, float costPerInputToken, int numOutputTokens, float costPerOutputToken, LLMServiceType pipeline) { } diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java index b87e64081c07..388a0539cf0b 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java @@ -142,11 +142,11 @@ public void handleStatusUpdate(CourseChatJob job, PyrisChatStatusUpdateDTO statu var message = new IrisMessage(); message.addContent(new IrisTextMessageContent(statusUpdate.result())); var savedMessage = irisMessageService.saveMessage(message, session, IrisMessageSender.LLM); - var tokenUsages = llmTokenUsageService.saveTokenUsage(savedMessage, null, session.getUser(), session.getCourse(), statusUpdate.tokens()); + var tokenUsages = llmTokenUsageService.saveIrisTokenUsage(job, savedMessage, null, session.getUser(), session.getCourse(), statusUpdate.tokens()); irisChatWebsocketService.sendMessage(session, savedMessage, statusUpdate.stages()); } else { - var tokenUsages = llmTokenUsageService.saveTokenUsage(null, null, session.getUser(), session.getCourse(), statusUpdate.tokens()); + var tokenUsages = llmTokenUsageService.saveIrisTokenUsage(job, null, null, session.getUser(), session.getCourse(), statusUpdate.tokens()); irisChatWebsocketService.sendStatusUpdate(session, statusUpdate.stages(), statusUpdate.suggestions(), statusUpdate.tokens()); } updateLatestSuggestions(session, statusUpdate.suggestions()); diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java index 12fd6fdf6905..613cbcc4a9eb 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java @@ -65,14 +65,14 @@ public class IrisExerciseChatSessionService extends AbstractIrisChatSessionServi private final ProgrammingExerciseRepository programmingExerciseRepository; - public IrisExerciseChatSessionService(IrisMessageService irisMessageService, LLMTokenUsageService LLMTokenUsageService, IrisSettingsService irisSettingsService, + public IrisExerciseChatSessionService(IrisMessageService irisMessageService, LLMTokenUsageService llmTokenUsageService, IrisSettingsService irisSettingsService, IrisChatWebsocketService irisChatWebsocketService, AuthorizationCheckService authCheckService, IrisSessionRepository irisSessionRepository, ProgrammingExerciseStudentParticipationRepository programmingExerciseStudentParticipationRepository, ProgrammingSubmissionRepository programmingSubmissionRepository, IrisRateLimitService rateLimitService, PyrisPipelineService pyrisPipelineService, ProgrammingExerciseRepository programmingExerciseRepository, ObjectMapper objectMapper) { super(irisSessionRepository, objectMapper); this.irisMessageService = irisMessageService; - this.llmTokenUsageService = LLMTokenUsageService; + this.llmTokenUsageService = llmTokenUsageService; this.irisSettingsService = irisSettingsService; this.irisChatWebsocketService = irisChatWebsocketService; this.authCheckService = authCheckService; @@ -177,14 +177,14 @@ public void handleStatusUpdate(ExerciseChatJob job, PyrisChatStatusUpdateDTO sta message.addContent(new IrisTextMessageContent(statusUpdate.result())); var savedMessage = irisMessageService.saveMessage(message, session, IrisMessageSender.LLM); if (statusUpdate.tokens() != null) { - var tokenUsages = llmTokenUsageService.saveTokenUsage(savedMessage, session.getExercise(), session.getUser(), + var tokenUsages = llmTokenUsageService.saveIrisTokenUsage(job, savedMessage, session.getExercise(), session.getUser(), session.getExercise().getCourseViaExerciseGroupOrCourseMember(), statusUpdate.tokens()); } irisChatWebsocketService.sendMessage(session, savedMessage, statusUpdate.stages()); } else { if (statusUpdate.tokens() != null) { - var tokenUsages = llmTokenUsageService.saveTokenUsage(null, session.getExercise(), session.getUser(), + var tokenUsages = llmTokenUsageService.saveIrisTokenUsage(job, null, session.getExercise(), session.getUser(), session.getExercise().getCourseViaExerciseGroupOrCourseMember(), statusUpdate.tokens()); } irisChatWebsocketService.sendStatusUpdate(session, statusUpdate.stages(), statusUpdate.suggestions(), statusUpdate.tokens()); diff --git a/src/main/resources/config/liquibase/changelog/20241012125003_changelog.xml b/src/main/resources/config/liquibase/changelog/20241012080932_changelog.xml similarity index 86% rename from src/main/resources/config/liquibase/changelog/20241012125003_changelog.xml rename to src/main/resources/config/liquibase/changelog/20241012080932_changelog.xml index fcb5bb25ac22..e8f846219bb2 100644 --- a/src/main/resources/config/liquibase/changelog/20241012125003_changelog.xml +++ b/src/main/resources/config/liquibase/changelog/20241012080932_changelog.xml @@ -5,7 +5,7 @@ xsi:schemaLocation="http://www.liquibase.org/xml/ns/dbchangelog http://www.liquibase.org/xml/ns/dbchangelog/dbchangelog-4.29.xsd" objectQuotingStrategy="QUOTE_ONLY_RESERVED_WORDS"> - + @@ -13,14 +13,14 @@ - + - + - - + + - + diff --git a/src/test/java/de/tum/cit/aet/artemis/iris/IrisCompetencyGenerationIntegrationTest.java b/src/test/java/de/tum/cit/aet/artemis/iris/IrisCompetencyGenerationIntegrationTest.java index adf594a5608f..282e4294eef2 100644 --- a/src/test/java/de/tum/cit/aet/artemis/iris/IrisCompetencyGenerationIntegrationTest.java +++ b/src/test/java/de/tum/cit/aet/artemis/iris/IrisCompetencyGenerationIntegrationTest.java @@ -22,6 +22,7 @@ import de.tum.cit.aet.artemis.iris.service.pyris.dto.competency.PyrisCompetencyStatusUpdateDTO; import de.tum.cit.aet.artemis.iris.service.pyris.dto.status.PyrisStageDTO; import de.tum.cit.aet.artemis.iris.service.pyris.dto.status.PyrisStageState; +import de.tum.cit.aet.artemis.iris.service.pyris.job.CompetencyExtractionJob; class IrisCompetencyGenerationIntegrationTest extends AbstractIrisIntegrationTest { @@ -66,7 +67,8 @@ void generateCompetencies_asEditor_shouldSucceed() throws Exception { List stages = List.of(new PyrisStageDTO("Generating Competencies", 10, PyrisStageState.DONE, null)); // In the real system, this would be triggered by Pyris via a REST call to the Artemis server - irisCompetencyGenerationService.handleStatusUpdate(TEST_PREFIX + "editor1", course.getId(), new PyrisCompetencyStatusUpdateDTO(stages, recommendations, null)); + CompetencyExtractionJob job = new CompetencyExtractionJob("1", course.getId(), TEST_PREFIX + "editor1"); + irisCompetencyGenerationService.handleStatusUpdate(job, new PyrisCompetencyStatusUpdateDTO(stages, recommendations, null)); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(PyrisCompetencyStatusUpdateDTO.class); verify(websocketMessagingService, timeout(200).times(3)).sendMessageToUser(eq(TEST_PREFIX + "editor1"), eq("/topic/iris/competencies/" + course.getId()), From e974d5976648c6c9c3ff0ac05b15bb73488cfd58 Mon Sep 17 00:00:00 2001 From: Alexander Joham Date: Sun, 13 Oct 2024 16:21:12 +0200 Subject: [PATCH 08/30] Implement server Integration tests for token tracking and saving --- .../session/IrisCourseChatSessionService.java | 4 +- .../IrisExerciseChatSessionService.java | 6 +- .../IrisChatTokenTrackingIntegrationTest.java | 272 ++++++++++++++++++ 3 files changed, 277 insertions(+), 5 deletions(-) create mode 100644 src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java index 388a0539cf0b..d0bfc7c8fe85 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java @@ -142,11 +142,11 @@ public void handleStatusUpdate(CourseChatJob job, PyrisChatStatusUpdateDTO statu var message = new IrisMessage(); message.addContent(new IrisTextMessageContent(statusUpdate.result())); var savedMessage = irisMessageService.saveMessage(message, session, IrisMessageSender.LLM); - var tokenUsages = llmTokenUsageService.saveIrisTokenUsage(job, savedMessage, null, session.getUser(), session.getCourse(), statusUpdate.tokens()); + llmTokenUsageService.saveIrisTokenUsage(job, savedMessage, null, session.getUser(), session.getCourse(), statusUpdate.tokens()); irisChatWebsocketService.sendMessage(session, savedMessage, statusUpdate.stages()); } else { - var tokenUsages = llmTokenUsageService.saveIrisTokenUsage(job, null, null, session.getUser(), session.getCourse(), statusUpdate.tokens()); + llmTokenUsageService.saveIrisTokenUsage(job, null, null, session.getUser(), session.getCourse(), statusUpdate.tokens()); irisChatWebsocketService.sendStatusUpdate(session, statusUpdate.stages(), statusUpdate.suggestions(), statusUpdate.tokens()); } updateLatestSuggestions(session, statusUpdate.suggestions()); diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java index 613cbcc4a9eb..442955e8fcea 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java @@ -177,15 +177,15 @@ public void handleStatusUpdate(ExerciseChatJob job, PyrisChatStatusUpdateDTO sta message.addContent(new IrisTextMessageContent(statusUpdate.result())); var savedMessage = irisMessageService.saveMessage(message, session, IrisMessageSender.LLM); if (statusUpdate.tokens() != null) { - var tokenUsages = llmTokenUsageService.saveIrisTokenUsage(job, savedMessage, session.getExercise(), session.getUser(), + llmTokenUsageService.saveIrisTokenUsage(job, savedMessage, session.getExercise(), session.getUser(), session.getExercise().getCourseViaExerciseGroupOrCourseMember(), statusUpdate.tokens()); } irisChatWebsocketService.sendMessage(session, savedMessage, statusUpdate.stages()); } else { if (statusUpdate.tokens() != null) { - var tokenUsages = llmTokenUsageService.saveIrisTokenUsage(job, null, session.getExercise(), session.getUser(), - session.getExercise().getCourseViaExerciseGroupOrCourseMember(), statusUpdate.tokens()); + llmTokenUsageService.saveIrisTokenUsage(job, null, session.getExercise(), session.getUser(), session.getExercise().getCourseViaExerciseGroupOrCourseMember(), + statusUpdate.tokens()); } irisChatWebsocketService.sendStatusUpdate(session, statusUpdate.stages(), statusUpdate.suggestions(), statusUpdate.tokens()); } diff --git a/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java b/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java new file mode 100644 index 000000000000..7917f820dd02 --- /dev/null +++ b/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java @@ -0,0 +1,272 @@ +package de.tum.cit.aet.artemis.iris; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNoException; +import static org.awaitility.Awaitility.await; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.isNull; +import static org.mockito.Mockito.verify; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.eclipse.jgit.api.errors.GitAPIException; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.mock.mockito.SpyBean; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.security.test.context.support.WithMockUser; +import org.springframework.util.LinkedMultiValueMap; + +import de.tum.cit.aet.artemis.connector.IrisRequestMockProvider; +import de.tum.cit.aet.artemis.core.domain.Course; +import de.tum.cit.aet.artemis.core.domain.LLMServiceType; +import de.tum.cit.aet.artemis.core.domain.LLMTokenUsage; +import de.tum.cit.aet.artemis.core.domain.User; +import de.tum.cit.aet.artemis.core.service.LLMTokenUsageService; +import de.tum.cit.aet.artemis.exercise.domain.Exercise; +import de.tum.cit.aet.artemis.iris.domain.message.IrisMessage; +import de.tum.cit.aet.artemis.iris.domain.message.IrisMessageContent; +import de.tum.cit.aet.artemis.iris.domain.message.IrisTextMessageContent; +import de.tum.cit.aet.artemis.iris.domain.session.IrisSession; +import de.tum.cit.aet.artemis.iris.repository.IrisMessageRepository; +import de.tum.cit.aet.artemis.iris.service.pyris.PyrisJobService; +import de.tum.cit.aet.artemis.iris.service.pyris.dto.chat.PyrisChatStatusUpdateDTO; +import de.tum.cit.aet.artemis.iris.service.pyris.dto.data.PyrisLLMCostDTO; +import de.tum.cit.aet.artemis.iris.service.pyris.dto.status.PyrisStageDTO; +import de.tum.cit.aet.artemis.iris.service.pyris.dto.status.PyrisStageState; +import de.tum.cit.aet.artemis.iris.service.pyris.job.PyrisJob; +import de.tum.cit.aet.artemis.iris.service.session.IrisExerciseChatSessionService; +import de.tum.cit.aet.artemis.participation.ParticipationUtilService; +import de.tum.cit.aet.artemis.programming.domain.ProgrammingExercise; +import de.tum.cit.aet.artemis.programming.domain.ProgrammingExerciseStudentParticipation; +import de.tum.cit.aet.artemis.programming.domain.ProjectType; +import de.tum.cit.aet.artemis.programming.domain.SolutionProgrammingExerciseParticipation; +import de.tum.cit.aet.artemis.programming.domain.TemplateProgrammingExerciseParticipation; + +class IrisChatTokenTrackingIntegrationTest extends AbstractIrisIntegrationTest { + + private static final String TEST_PREFIX = "irischattokentrackingintegration"; + + @Autowired + private IrisExerciseChatSessionService irisExerciseChatSessionService; + + @Autowired + private IrisMessageRepository irisMessageRepository; + + @SpyBean + private LLMTokenUsageService llmTokenUsageService; + + @Autowired + private IrisRequestMockProvider irisRequestMockProvider; + + @Autowired + private ParticipationUtilService participationUtilService; + + @Autowired + private PyrisJobService pyrisJobService; + + private ProgrammingExercise exercise; + + private Course course; + + private AtomicBoolean pipelineDone; + + @BeforeEach + void initTestCase() throws GitAPIException, IOException, URISyntaxException { + userUtilService.addUsers(TEST_PREFIX, 2, 0, 0, 0); + + course = programmingExerciseUtilService.addCourseWithOneProgrammingExercise(); + exercise = exerciseUtilService.getFirstExerciseWithType(course, ProgrammingExercise.class); + String projectKey = exercise.getProjectKey(); + exercise.setProjectType(ProjectType.PLAIN_GRADLE); + exercise.setTestRepositoryUri(localVCBaseUrl + "/git/" + projectKey + "/" + projectKey.toLowerCase() + "-tests.git"); + programmingExerciseBuildConfigRepository.save(exercise.getBuildConfig()); + programmingExerciseRepository.save(exercise); + exercise = programmingExerciseRepository.findWithAllParticipationsAndBuildConfigById(exercise.getId()).orElseThrow(); + + // Set the correct repository URIs for the template and the solution participation. + String templateRepositorySlug = projectKey.toLowerCase() + "-exercise"; + TemplateProgrammingExerciseParticipation templateParticipation = exercise.getTemplateParticipation(); + templateParticipation.setRepositoryUri(localVCBaseUrl + "/git/" + projectKey + "/" + templateRepositorySlug + ".git"); + templateProgrammingExerciseParticipationRepository.save(templateParticipation); + String solutionRepositorySlug = projectKey.toLowerCase() + "-solution"; + SolutionProgrammingExerciseParticipation solutionParticipation = exercise.getSolutionParticipation(); + solutionParticipation.setRepositoryUri(localVCBaseUrl + "/git/" + projectKey + "/" + solutionRepositorySlug + ".git"); + solutionProgrammingExerciseParticipationRepository.save(solutionParticipation); + + String assignmentRepositorySlug = projectKey.toLowerCase() + "-" + TEST_PREFIX + "student1"; + + // Add a participation for student1. + ProgrammingExerciseStudentParticipation studentParticipation = participationUtilService.addStudentParticipationForProgrammingExercise(exercise, TEST_PREFIX + "student1"); + studentParticipation.setRepositoryUri(String.format(localVCBaseUrl + "/git/%s/%s.git", projectKey, assignmentRepositorySlug)); + studentParticipation.setBranch(defaultBranch); + programmingExerciseStudentParticipationRepository.save(studentParticipation); + + // Prepare the repositories. + localVCLocalCITestService.createAndConfigureLocalRepository(projectKey, templateRepositorySlug); + localVCLocalCITestService.createAndConfigureLocalRepository(projectKey, projectKey.toLowerCase() + "-tests"); + localVCLocalCITestService.createAndConfigureLocalRepository(projectKey, solutionRepositorySlug); + localVCLocalCITestService.createAndConfigureLocalRepository(projectKey, assignmentRepositorySlug); + + // Check that the repository folders were created in the file system for all base repositories. + localVCLocalCITestService.verifyRepositoryFoldersExist(exercise, localVCBasePath); + + activateIrisGlobally(); + activateIrisFor(course); + activateIrisFor(exercise); + pipelineDone = new AtomicBoolean(false); + } + + @Test + @WithMockUser(username = TEST_PREFIX + "student1", roles = "USER") + void testTokenTrackingHandledExerciseChat() throws Exception { + var irisSession = irisExerciseChatSessionService.createChatSessionForProgrammingExercise(exercise, userUtilService.getUserByLogin(TEST_PREFIX + "student1")); + var messageToSend = createDefaultMockMessage(irisSession); + + var tokens = getMockLLMCosts(); + + List doneStage = new ArrayList<>(); + doneStage.add(new PyrisStageDTO("DoneTest", 10, PyrisStageState.DONE, "Done")); + + irisRequestMockProvider.mockRunResponse(dto -> { + assertThat(dto.settings().authenticationToken()).isNotNull(); + + assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World", doneStage, null, tokens)); + + pipelineDone.set(true); + }); + + request.postWithoutResponseBody("/api/iris/sessions/" + irisSession.getId() + "/messages", messageToSend, HttpStatus.CREATED); + + await().until(pipelineDone::get); + + // Capture the saved token usages + ArgumentCaptor> captor = ArgumentCaptor.forClass(List.class); + verify(llmTokenUsageService).saveIrisTokenUsage(any(PyrisJob.class), any(IrisMessage.class), any(Exercise.class), any(User.class), any(Course.class), captor.capture()); + + // Verify that the tokens are saved correctly + List savedTokenUsages = captor.getValue(); + assertEquals(5, savedTokenUsages.size()); + for (int i = 0; i < savedTokenUsages.size(); i++) { + PyrisLLMCostDTO usage = savedTokenUsages.get(i); + PyrisLLMCostDTO expectedCost = tokens.get(i); + + assertEquals(expectedCost.numInputTokens(), usage.numInputTokens()); + assertEquals(expectedCost.costPerInputToken(), usage.costPerInputToken()); + assertEquals(expectedCost.numOutputTokens(), usage.numOutputTokens()); + assertEquals(expectedCost.costPerOutputToken(), usage.costPerOutputToken()); + } + } + + @Test + @WithMockUser(username = TEST_PREFIX + "student1", roles = "USER") + void testTokenTrackingSavedExerciseChat() { + + var irisSession = irisExerciseChatSessionService.createChatSessionForProgrammingExercise(exercise, userUtilService.getUserByLogin(TEST_PREFIX + "student1")); + var irisMessage = createDefaultMockMessage(irisSession); + irisMessageRepository.save(irisMessage); + String jobToken = pyrisJobService.addExerciseChatJob(course.getId(), exercise.getId(), irisSession.getId()); + PyrisJob job = pyrisJobService.getJob(jobToken); + + var tokens = getMockLLMCosts(); + + // Capture the saved token usages + List returnedTokenUsages = llmTokenUsageService.saveIrisTokenUsage(job, irisMessage, exercise, irisSession.getUser(), course, tokens); + + assertEquals(5, returnedTokenUsages.size()); + for (int i = 0; i < returnedTokenUsages.size(); i++) { + LLMTokenUsage usage = returnedTokenUsages.get(i); + PyrisLLMCostDTO expectedCost = tokens.get(i); + + assertEquals(expectedCost.modelInfo(), usage.getModel()); + assertEquals(expectedCost.numInputTokens(), usage.getNumInputTokens()); + assertEquals(expectedCost.numOutputTokens(), usage.getNumOutputTokens()); + assertEquals(expectedCost.costPerInputToken(), usage.getCostPerMillionInputTokens()); + assertEquals(expectedCost.costPerOutputToken(), usage.getCostPerMillionOutputTokens()); + assertEquals(expectedCost.pipeline(), usage.getServiceType()); + } + } + + @Test + @WithMockUser(username = TEST_PREFIX + "student1", roles = "USER") + void testTokenTrackingExerciseChatWithPipelineFail() throws Exception { + var irisSession = irisExerciseChatSessionService.createChatSessionForProgrammingExercise(exercise, userUtilService.getUserByLogin(TEST_PREFIX + "student1")); + var messageToSend = createDefaultMockMessage(irisSession); + + var tokens = getMockLLMCosts(); + + List failedStages = new ArrayList<>(); + failedStages.add(new PyrisStageDTO("TestTokenFail", 10, PyrisStageState.ERROR, "Failed running pipeline")); + + irisRequestMockProvider.mockRunResponse(dto -> { + assertThat(dto.settings().authenticationToken()).isNotNull(); + + assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), null, failedStages, null, tokens)); + + pipelineDone.set(true); + }); + + request.postWithoutResponseBody("/api/iris/sessions/" + irisSession.getId() + "/messages", messageToSend, HttpStatus.CREATED); + + await().until(pipelineDone::get); + + // Capture the saved token usages + ArgumentCaptor> captor = ArgumentCaptor.forClass(List.class); + verify(llmTokenUsageService).saveIrisTokenUsage(any(PyrisJob.class), isNull(), any(Exercise.class), any(User.class), any(Course.class), captor.capture()); + + // Verify that the tokens are saved correctly + List savedTokenUsages = captor.getValue(); + assertEquals(5, savedTokenUsages.size()); + for (int i = 0; i < savedTokenUsages.size(); i++) { + PyrisLLMCostDTO usage = savedTokenUsages.get(i); + PyrisLLMCostDTO expectedCost = tokens.get(i); + + assertEquals(expectedCost.numInputTokens(), usage.numInputTokens()); + assertEquals(expectedCost.costPerInputToken(), usage.costPerInputToken()); + assertEquals(expectedCost.numOutputTokens(), usage.numOutputTokens()); + assertEquals(expectedCost.costPerOutputToken(), usage.costPerOutputToken()); + } + } + + private List getMockLLMCosts() { + List costs = new ArrayList<>(); + for (int i = 0; i < 5; i++) { + costs.add(new PyrisLLMCostDTO("test-llm", i * 10 + 5, i * 0.5f, i * 3 + 5, i * 0.12f, LLMServiceType.IRIS_CHAT_EXERCISE_MESSAGE)); + } + return costs; + } + + private IrisMessage createDefaultMockMessage(IrisSession irisSession) { + var messageToSend = irisSession.newMessage(); + messageToSend.addContent(createMockTextContent(), createMockTextContent(), createMockTextContent()); + return messageToSend; + } + + private IrisMessageContent createMockTextContent() { + String[] adjectives = { "happy", "sad", "angry", "funny", "silly", "crazy", "beautiful", "smart" }; + String[] nouns = { "dog", "cat", "house", "car", "book", "computer", "phone", "shoe" }; + + var rdm = ThreadLocalRandom.current(); + String randomAdjective = adjectives[rdm.nextInt(adjectives.length)]; + String randomNoun = nouns[rdm.nextInt(nouns.length)]; + + var text = "The " + randomAdjective + " " + randomNoun + " jumped over the lazy dog."; + return new IrisTextMessageContent(text); + } + + private void sendStatus(String jobId, String result, List stages, List suggestions, List tokens) throws Exception { + var headers = new HttpHeaders(new LinkedMultiValueMap<>(Map.of("Authorization", List.of("Bearer " + jobId)))); + request.postWithoutResponseBody("/api/public/pyris/pipelines/tutor-chat/runs/" + jobId + "/status", new PyrisChatStatusUpdateDTO(result, stages, suggestions, tokens), + HttpStatus.OK, headers); + } +} From 63371627b31610e48a5ccb7d38199e3899e2539f Mon Sep 17 00:00:00 2001 From: Alexander Joham Date: Mon, 14 Oct 2024 12:59:12 +0200 Subject: [PATCH 09/30] Update code based on code-rabbit feedback, fix tests --- .../artemis/core/domain/LLMServiceType.java | 29 +++++++++++- .../artemis/core/domain/LLMTokenUsage.java | 9 ++-- .../core/service/LLMTokenUsageService.java | 18 +++++++- .../IrisCompetencyGenerationService.java | 2 +- .../session/IrisCourseChatSessionService.java | 6 +-- .../IrisExerciseChatSessionService.java | 16 +++---- ...gelog.xml => 20241014125521_changelog.xml} | 2 +- .../resources/config/liquibase/master.xml | 2 +- .../IrisChatTokenTrackingIntegrationTest.java | 45 ++++++++----------- ...isCompetencyGenerationIntegrationTest.java | 3 +- 10 files changed, 80 insertions(+), 52 deletions(-) rename src/main/resources/config/liquibase/changelog/{20241012080932_changelog.xml => 20241014125521_changelog.xml} (97%) diff --git a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMServiceType.java b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMServiceType.java index 4869ec522c3c..f7e179ccacb6 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMServiceType.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMServiceType.java @@ -1,6 +1,31 @@ package de.tum.cit.aet.artemis.core.domain; +/** + * Enum representing different types of LLM (Large Language Model) services used in the system. + */ public enum LLMServiceType { - ATHENA_PRELIMINARY_FEEDBACK, ATHENA_FEEDBACK_SUGGESTION, IRIS_CODE_FEEDBACK, IRIS_CHAT_COURSE_MESSAGE, IRIS_CHAT_EXERCISE_MESSAGE, IRIS_INTERACTION_SUGGESTION, - IRIS_CHAT_LECTURE_MESSAGE, IRIS_COMPETENCY_GENERATION, IRIS_CITATION_PIPELINE, IRIS_LECTURE_RETRIEVAL_PIPELINE, IRIS_LECTURE_INGESTION, NOT_SET + /** Athena service for preliminary feedback */ + ATHENA_PRELIMINARY_FEEDBACK, + /** Athena service for feedback suggestions */ + ATHENA_FEEDBACK_SUGGESTION, + /** Iris service for code feedback */ + IRIS_CODE_FEEDBACK, + /** Iris service for course chat messages */ + IRIS_CHAT_COURSE_MESSAGE, + /** Iris service for exercise chat messages */ + IRIS_CHAT_EXERCISE_MESSAGE, + /** Iris service for interaction suggestions */ + IRIS_INTERACTION_SUGGESTION, + /** Iris service for lecture chat messages */ + IRIS_CHAT_LECTURE_MESSAGE, + /** Iris service for competency generation */ + IRIS_COMPETENCY_GENERATION, + /** Iris service for citation pipeline */ + IRIS_CITATION_PIPELINE, + /** Iris service for lecture retrieval pipeline */ + IRIS_LECTURE_RETRIEVAL_PIPELINE, + /** Iris service for lecture ingestion */ + IRIS_LECTURE_INGESTION, + /** Default value when the service type is not set */ + NOT_SET } diff --git a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java index 3f51390ce397..bba09ad0025d 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java @@ -60,7 +60,6 @@ public class LLMTokenUsage extends DomainObject { @Column(name = "user_id") private long userId; - @Nullable @Column(name = "time") private ZonedDateTime time = ZonedDateTime.now(); @@ -71,7 +70,7 @@ public class LLMTokenUsage extends DomainObject { @ManyToOne @JsonIgnore @JoinColumn(name = "iris_message_id") - IrisMessage irisMessage; + private IrisMessage irisMessage; public LLMServiceType getServiceType() { return serviceType; @@ -171,8 +170,8 @@ public void setIrisMessage(IrisMessage message) { @Override public String toString() { - return "LLMTokenUsage{" + "serviceType=" + serviceType + ", model=" + model + ", num_input_tokens=" + numInputTokens + ", cost_per_input_token=" + costPerMillionInputTokens - + ", num_output_tokens=" + numOutputTokens + ", cost_per_output_token=" + costPerMillionOutputTokens + ", course=" + course + ", exercise=" + exercise + ", userId=" - + userId + ", timestamp=" + time + ", trace_id=" + traceId + ", irisMessage=" + irisMessage + '}'; + return "LLMTokenUsage{" + "serviceType=" + serviceType + ", model=" + model + ", numInputTokens=" + numInputTokens + ", costPerMillionInputTokens=" + + costPerMillionInputTokens + ", numOutputTokens=" + numOutputTokens + ", costPerMillionOutputTokens=" + costPerMillionOutputTokens + ", course=" + course + + ", exercise=" + exercise + ", userId=" + userId + ", timestamp=" + time + ", traceId=" + traceId + ", irisMessage=" + irisMessage + '}'; } } diff --git a/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java b/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java index 8dab3e056029..409812ab2da0 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java @@ -39,7 +39,7 @@ public LLMTokenUsageService(LLMTokenUsageRepository llmTokenUsageRepository) { * @param exercise to map the token cost to an exercise * @param user to map the token cost to a user * @param course to map the token to a course - * @param tokens token cost lsit of type PyrisLLMCostDTO + * @param tokens token cost list of type PyrisLLMCostDTO * @return list of the saved data */ public List saveIrisTokenUsage(PyrisJob job, IrisMessage message, Exercise exercise, User user, Course course, List tokens) { @@ -63,8 +63,22 @@ public List saveIrisTokenUsage(PyrisJob job, IrisMessage message, llmTokenUsage.setCostPerMillionOutputTokens(cost.costPerOutputToken()); llmTokenUsage.setModel(cost.modelInfo()); llmTokenUsage.setTraceId(job.jobId()); - tokenUsages.add(llmTokenUsageRepository.save(llmTokenUsage)); + tokenUsages.add(llmTokenUsage); } + llmTokenUsageRepository.saveAll(tokenUsages); return tokenUsages; } + + // Overloaded methods without optional parameters + public List saveIrisTokenUsage(PyrisJob job, User user, Course course, List tokens) { + return saveIrisTokenUsage(job, null, null, user, course, tokens); + } + + public List saveIrisTokenUsage(PyrisJob job, IrisMessage message, User user, Course course, List tokens) { + return saveIrisTokenUsage(job, message, null, user, course, tokens); + } + + public List saveIrisTokenUsage(PyrisJob job, Course course, List tokens) { + return saveIrisTokenUsage(job, null, null, null, course, tokens); + } } diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java index 93111ad2c234..aa15f5a14af8 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java @@ -73,7 +73,7 @@ public void executeCompetencyExtractionPipeline(User user, Course course, String public void handleStatusUpdate(CompetencyExtractionJob job, PyrisCompetencyStatusUpdateDTO statusUpdate) { Course course = courseRepository.findByIdForUpdateElseThrow(job.courseId()); if (statusUpdate.tokens() != null) { - llmTokenUsageService.saveIrisTokenUsage(job, null, null, null, course, statusUpdate.tokens()); + llmTokenUsageService.saveIrisTokenUsage(job, course, statusUpdate.tokens()); } websocketService.send(job.userLogin(), websocketTopic(job.courseId()), statusUpdate); } diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java index d0bfc7c8fe85..2ebe5c7ae40e 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java @@ -63,7 +63,7 @@ public class IrisCourseChatSessionService extends AbstractIrisChatSessionService public IrisCourseChatSessionService(IrisMessageService irisMessageService, LLMTokenUsageService llmTokenUsageService, IrisSettingsService irisSettingsService, IrisChatWebsocketService irisChatWebsocketService, AuthorizationCheckService authCheckService, IrisSessionRepository irisSessionRepository, IrisRateLimitService rateLimitService, IrisCourseChatSessionRepository irisCourseChatSessionRepository, PyrisPipelineService pyrisPipelineService, - ObjectMapper objectMapper, LLMTokenUsageService lLMTokenUsageService) { + ObjectMapper objectMapper) { super(irisSessionRepository, objectMapper); this.irisMessageService = irisMessageService; this.llmTokenUsageService = llmTokenUsageService; @@ -142,11 +142,11 @@ public void handleStatusUpdate(CourseChatJob job, PyrisChatStatusUpdateDTO statu var message = new IrisMessage(); message.addContent(new IrisTextMessageContent(statusUpdate.result())); var savedMessage = irisMessageService.saveMessage(message, session, IrisMessageSender.LLM); - llmTokenUsageService.saveIrisTokenUsage(job, savedMessage, null, session.getUser(), session.getCourse(), statusUpdate.tokens()); + llmTokenUsageService.saveIrisTokenUsage(job, savedMessage, session.getUser(), session.getCourse(), statusUpdate.tokens()); irisChatWebsocketService.sendMessage(session, savedMessage, statusUpdate.stages()); } else { - llmTokenUsageService.saveIrisTokenUsage(job, null, null, session.getUser(), session.getCourse(), statusUpdate.tokens()); + llmTokenUsageService.saveIrisTokenUsage(job, session.getUser(), session.getCourse(), statusUpdate.tokens()); irisChatWebsocketService.sendStatusUpdate(session, statusUpdate.stages(), statusUpdate.suggestions(), statusUpdate.tokens()); } updateLatestSuggestions(session, statusUpdate.suggestions()); diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java index 442955e8fcea..c25930c50464 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java @@ -172,24 +172,22 @@ private Optional getLatestSubmissionIfExists(ProgrammingE */ public void handleStatusUpdate(ExerciseChatJob job, PyrisChatStatusUpdateDTO statusUpdate) { var session = (IrisExerciseChatSession) irisSessionRepository.findByIdWithMessagesAndContents(job.sessionId()); + IrisMessage savedMessage = null; if (statusUpdate.result() != null) { var message = new IrisMessage(); message.addContent(new IrisTextMessageContent(statusUpdate.result())); - var savedMessage = irisMessageService.saveMessage(message, session, IrisMessageSender.LLM); - if (statusUpdate.tokens() != null) { - llmTokenUsageService.saveIrisTokenUsage(job, savedMessage, session.getExercise(), session.getUser(), - session.getExercise().getCourseViaExerciseGroupOrCourseMember(), statusUpdate.tokens()); - } + savedMessage = irisMessageService.saveMessage(message, session, IrisMessageSender.LLM); irisChatWebsocketService.sendMessage(session, savedMessage, statusUpdate.stages()); } else { - if (statusUpdate.tokens() != null) { - llmTokenUsageService.saveIrisTokenUsage(job, null, session.getExercise(), session.getUser(), session.getExercise().getCourseViaExerciseGroupOrCourseMember(), - statusUpdate.tokens()); - } irisChatWebsocketService.sendStatusUpdate(session, statusUpdate.stages(), statusUpdate.suggestions(), statusUpdate.tokens()); } + if (statusUpdate.tokens() != null) { + llmTokenUsageService.saveIrisTokenUsage(job, savedMessage, session.getExercise(), session.getUser(), session.getExercise().getCourseViaExerciseGroupOrCourseMember(), + statusUpdate.tokens()); + } + updateLatestSuggestions(session, statusUpdate.suggestions()); } } diff --git a/src/main/resources/config/liquibase/changelog/20241012080932_changelog.xml b/src/main/resources/config/liquibase/changelog/20241014125521_changelog.xml similarity index 97% rename from src/main/resources/config/liquibase/changelog/20241012080932_changelog.xml rename to src/main/resources/config/liquibase/changelog/20241014125521_changelog.xml index e8f846219bb2..7d14ed33763e 100644 --- a/src/main/resources/config/liquibase/changelog/20241012080932_changelog.xml +++ b/src/main/resources/config/liquibase/changelog/20241014125521_changelog.xml @@ -5,7 +5,7 @@ xsi:schemaLocation="http://www.liquibase.org/xml/ns/dbchangelog http://www.liquibase.org/xml/ns/dbchangelog/dbchangelog-4.29.xsd" objectQuotingStrategy="QUOTE_ONLY_RESERVED_WORDS"> - + diff --git a/src/main/resources/config/liquibase/master.xml b/src/main/resources/config/liquibase/master.xml index 8e0fa4057061..6d4e81ad432c 100644 --- a/src/main/resources/config/liquibase/master.xml +++ b/src/main/resources/config/liquibase/master.xml @@ -22,7 +22,7 @@ - + diff --git a/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java b/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java index 7917f820dd02..676f8d881435 100644 --- a/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java +++ b/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java @@ -3,7 +3,6 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatNoException; import static org.awaitility.Awaitility.await; -import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.isNull; import static org.mockito.Mockito.verify; @@ -13,7 +12,6 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; -import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.atomic.AtomicBoolean; import org.eclipse.jgit.api.errors.GitAPIException; @@ -156,15 +154,15 @@ void testTokenTrackingHandledExerciseChat() throws Exception { // Verify that the tokens are saved correctly List savedTokenUsages = captor.getValue(); - assertEquals(5, savedTokenUsages.size()); + assertThat(savedTokenUsages).hasSize(5); for (int i = 0; i < savedTokenUsages.size(); i++) { PyrisLLMCostDTO usage = savedTokenUsages.get(i); PyrisLLMCostDTO expectedCost = tokens.get(i); - assertEquals(expectedCost.numInputTokens(), usage.numInputTokens()); - assertEquals(expectedCost.costPerInputToken(), usage.costPerInputToken()); - assertEquals(expectedCost.numOutputTokens(), usage.numOutputTokens()); - assertEquals(expectedCost.costPerOutputToken(), usage.costPerOutputToken()); + assertThat(usage.numInputTokens()).isEqualTo(expectedCost.numInputTokens()); + assertThat(usage.costPerInputToken()).isEqualTo(expectedCost.costPerInputToken()); + assertThat(usage.numOutputTokens()).isEqualTo(expectedCost.numOutputTokens()); + assertThat(usage.costPerOutputToken()).isEqualTo(expectedCost.costPerOutputToken()); } } @@ -183,17 +181,17 @@ void testTokenTrackingSavedExerciseChat() { // Capture the saved token usages List returnedTokenUsages = llmTokenUsageService.saveIrisTokenUsage(job, irisMessage, exercise, irisSession.getUser(), course, tokens); - assertEquals(5, returnedTokenUsages.size()); + assertThat(returnedTokenUsages).hasSize(5); for (int i = 0; i < returnedTokenUsages.size(); i++) { LLMTokenUsage usage = returnedTokenUsages.get(i); PyrisLLMCostDTO expectedCost = tokens.get(i); - assertEquals(expectedCost.modelInfo(), usage.getModel()); - assertEquals(expectedCost.numInputTokens(), usage.getNumInputTokens()); - assertEquals(expectedCost.numOutputTokens(), usage.getNumOutputTokens()); - assertEquals(expectedCost.costPerInputToken(), usage.getCostPerMillionInputTokens()); - assertEquals(expectedCost.costPerOutputToken(), usage.getCostPerMillionOutputTokens()); - assertEquals(expectedCost.pipeline(), usage.getServiceType()); + assertThat(usage.getModel()).isEqualTo(expectedCost.modelInfo()); + assertThat(usage.getNumInputTokens()).isEqualTo(expectedCost.numInputTokens()); + assertThat(usage.getNumOutputTokens()).isEqualTo(expectedCost.numOutputTokens()); + assertThat(usage.getCostPerMillionInputTokens()).isEqualTo(expectedCost.costPerInputToken()); + assertThat(usage.getCostPerMillionOutputTokens()).isEqualTo(expectedCost.costPerOutputToken()); + assertThat(usage.getServiceType()).isEqualTo(expectedCost.pipeline()); } } @@ -226,15 +224,15 @@ void testTokenTrackingExerciseChatWithPipelineFail() throws Exception { // Verify that the tokens are saved correctly List savedTokenUsages = captor.getValue(); - assertEquals(5, savedTokenUsages.size()); + assertThat(savedTokenUsages).hasSize(5); for (int i = 0; i < savedTokenUsages.size(); i++) { PyrisLLMCostDTO usage = savedTokenUsages.get(i); PyrisLLMCostDTO expectedCost = tokens.get(i); - assertEquals(expectedCost.numInputTokens(), usage.numInputTokens()); - assertEquals(expectedCost.costPerInputToken(), usage.costPerInputToken()); - assertEquals(expectedCost.numOutputTokens(), usage.numOutputTokens()); - assertEquals(expectedCost.costPerOutputToken(), usage.costPerOutputToken()); + assertThat(usage.numInputTokens()).isEqualTo(expectedCost.numInputTokens()); + assertThat(usage.costPerInputToken()).isEqualTo(expectedCost.costPerInputToken()); + assertThat(usage.numOutputTokens()).isEqualTo(expectedCost.numOutputTokens()); + assertThat(usage.costPerOutputToken()).isEqualTo(expectedCost.costPerOutputToken()); } } @@ -253,14 +251,7 @@ private IrisMessage createDefaultMockMessage(IrisSession irisSession) { } private IrisMessageContent createMockTextContent() { - String[] adjectives = { "happy", "sad", "angry", "funny", "silly", "crazy", "beautiful", "smart" }; - String[] nouns = { "dog", "cat", "house", "car", "book", "computer", "phone", "shoe" }; - - var rdm = ThreadLocalRandom.current(); - String randomAdjective = adjectives[rdm.nextInt(adjectives.length)]; - String randomNoun = nouns[rdm.nextInt(nouns.length)]; - - var text = "The " + randomAdjective + " " + randomNoun + " jumped over the lazy dog."; + var text = "The happy dog jumped over the lazy dog."; return new IrisTextMessageContent(text); } diff --git a/src/test/java/de/tum/cit/aet/artemis/iris/IrisCompetencyGenerationIntegrationTest.java b/src/test/java/de/tum/cit/aet/artemis/iris/IrisCompetencyGenerationIntegrationTest.java index 282e4294eef2..f9948aa91ad3 100644 --- a/src/test/java/de/tum/cit/aet/artemis/iris/IrisCompetencyGenerationIntegrationTest.java +++ b/src/test/java/de/tum/cit/aet/artemis/iris/IrisCompetencyGenerationIntegrationTest.java @@ -67,7 +67,8 @@ void generateCompetencies_asEditor_shouldSucceed() throws Exception { List stages = List.of(new PyrisStageDTO("Generating Competencies", 10, PyrisStageState.DONE, null)); // In the real system, this would be triggered by Pyris via a REST call to the Artemis server - CompetencyExtractionJob job = new CompetencyExtractionJob("1", course.getId(), TEST_PREFIX + "editor1"); + String jobId = "testJobId"; + CompetencyExtractionJob job = new CompetencyExtractionJob(jobId, course.getId(), TEST_PREFIX + "editor1"); irisCompetencyGenerationService.handleStatusUpdate(job, new PyrisCompetencyStatusUpdateDTO(stages, recommendations, null)); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(PyrisCompetencyStatusUpdateDTO.class); From 84a60dcb9edfbdf74954ac9c027065143c9167bc Mon Sep 17 00:00:00 2001 From: Alexander Joham Date: Mon, 14 Oct 2024 13:46:37 +0200 Subject: [PATCH 10/30] minor comment changes, remove tokens from frontend --- .../core/service/LLMTokenUsageService.java | 28 ++++++++++++++++++- .../PyrisCompetencyStatusUpdateDTO.java | 1 + .../app/entities/iris/iris-message.model.ts | 4 --- 3 files changed, 28 insertions(+), 5 deletions(-) diff --git a/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java b/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java index 409812ab2da0..a7ec19f261e6 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java @@ -69,15 +69,41 @@ public List saveIrisTokenUsage(PyrisJob job, IrisMessage message, return tokenUsages; } - // Overloaded methods without optional parameters + /** + * Overloaded method to save token usage without message and exercise. + * + * @param job used to create a unique traceId to group multiple LLM calls + * @param user to map the token cost to a user + * @param course to map the token to a course + * @param tokens token cost list of type PyrisLLMCostDTO + * @return list of the saved data + */ public List saveIrisTokenUsage(PyrisJob job, User user, Course course, List tokens) { return saveIrisTokenUsage(job, null, null, user, course, tokens); } + /** + * Overloaded method to save token usage without exercise. + * + * @param job used to create a unique traceId to group multiple LLM calls + * @param message IrisMessage to map the usage to an IrisMessage + * @param user to map the token cost to a user + * @param course to map the token to a course + * @param tokens token cost list of type PyrisLLMCostDTO + * @return list of the saved data + */ public List saveIrisTokenUsage(PyrisJob job, IrisMessage message, User user, Course course, List tokens) { return saveIrisTokenUsage(job, message, null, user, course, tokens); } + /** + * Overloaded method to save token usage without message, exercise and user. + * + * @param job used to create a unique traceId to group multiple LLM calls + * @param course to map the token to a course + * @param tokens token cost list of type PyrisLLMCostDTO + * @return list of the saved data + */ public List saveIrisTokenUsage(PyrisJob job, Course course, List tokens) { return saveIrisTokenUsage(job, null, null, null, course, tokens); } diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/competency/PyrisCompetencyStatusUpdateDTO.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/competency/PyrisCompetencyStatusUpdateDTO.java index 5a1774eeeb61..65d4ecf5d3a6 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/competency/PyrisCompetencyStatusUpdateDTO.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/competency/PyrisCompetencyStatusUpdateDTO.java @@ -14,6 +14,7 @@ * * @param stages List of stages of the generation process * @param result List of competencies recommendations that have been generated so far + * @param tokens List of token usages send by Pyris for tracking the token usage and cost */ @JsonInclude(JsonInclude.Include.NON_EMPTY) public record PyrisCompetencyStatusUpdateDTO(List stages, List result, List tokens) { diff --git a/src/main/webapp/app/entities/iris/iris-message.model.ts b/src/main/webapp/app/entities/iris/iris-message.model.ts index 04923788f1e6..d4b32977894f 100644 --- a/src/main/webapp/app/entities/iris/iris-message.model.ts +++ b/src/main/webapp/app/entities/iris/iris-message.model.ts @@ -13,8 +13,6 @@ export class IrisAssistantMessage implements BaseEntity { sentAt: dayjs.Dayjs; sender: IrisSender.LLM; helpful?: boolean; - num_input_tokens?: number; - num_output_tokens?: number; } export class IrisUserMessage implements BaseEntity { @@ -23,8 +21,6 @@ export class IrisUserMessage implements BaseEntity { sentAt?: dayjs.Dayjs; sender: IrisSender.USER; messageDifferentiator?: number; - num_input_tokens?: number; - num_output_tokens?: number; } export type IrisMessage = IrisAssistantMessage | IrisUserMessage; From 62dad8b442dd5183ad6c53beba6893671cd90703 Mon Sep 17 00:00:00 2001 From: Alexander Joham Date: Mon, 14 Oct 2024 14:07:05 +0200 Subject: [PATCH 11/30] Fix github test fails --- .../iris/service/websocket/IrisChatWebsocketService.java | 1 + .../iris/IrisChatTokenTrackingIntegrationTest.java | 8 ++++---- .../iris/IrisTextExerciseChatMessageIntegrationTest.java | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/websocket/IrisChatWebsocketService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/websocket/IrisChatWebsocketService.java index 9f8e97a80be5..43e27543f020 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/websocket/IrisChatWebsocketService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/websocket/IrisChatWebsocketService.java @@ -62,6 +62,7 @@ public void sendStatusUpdate(IrisChatSession session, List stages * @param session the session to send the status update to * @param stages the stages to send * @param suggestions the suggestions to send + * @param tokens token usage and cost send by Pyris */ public void sendStatusUpdate(IrisChatSession session, List stages, List suggestions, List tokens) { var user = session.getUser(); diff --git a/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java b/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java index 676f8d881435..d183ff48b06d 100644 --- a/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java +++ b/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java @@ -25,13 +25,14 @@ import org.springframework.security.test.context.support.WithMockUser; import org.springframework.util.LinkedMultiValueMap; -import de.tum.cit.aet.artemis.connector.IrisRequestMockProvider; +import de.tum.cit.aet.artemis.core.connector.IrisRequestMockProvider; import de.tum.cit.aet.artemis.core.domain.Course; import de.tum.cit.aet.artemis.core.domain.LLMServiceType; import de.tum.cit.aet.artemis.core.domain.LLMTokenUsage; import de.tum.cit.aet.artemis.core.domain.User; import de.tum.cit.aet.artemis.core.service.LLMTokenUsageService; import de.tum.cit.aet.artemis.exercise.domain.Exercise; +import de.tum.cit.aet.artemis.exercise.participation.util.ParticipationUtilService; import de.tum.cit.aet.artemis.iris.domain.message.IrisMessage; import de.tum.cit.aet.artemis.iris.domain.message.IrisMessageContent; import de.tum.cit.aet.artemis.iris.domain.message.IrisTextMessageContent; @@ -44,7 +45,6 @@ import de.tum.cit.aet.artemis.iris.service.pyris.dto.status.PyrisStageState; import de.tum.cit.aet.artemis.iris.service.pyris.job.PyrisJob; import de.tum.cit.aet.artemis.iris.service.session.IrisExerciseChatSessionService; -import de.tum.cit.aet.artemis.participation.ParticipationUtilService; import de.tum.cit.aet.artemis.programming.domain.ProgrammingExercise; import de.tum.cit.aet.artemis.programming.domain.ProgrammingExerciseStudentParticipation; import de.tum.cit.aet.artemis.programming.domain.ProjectType; @@ -136,7 +136,7 @@ void testTokenTrackingHandledExerciseChat() throws Exception { List doneStage = new ArrayList<>(); doneStage.add(new PyrisStageDTO("DoneTest", 10, PyrisStageState.DONE, "Done")); - irisRequestMockProvider.mockRunResponse(dto -> { + irisRequestMockProvider.mockProgrammingExerciseChatResponse(dto -> { assertThat(dto.settings().authenticationToken()).isNotNull(); assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World", doneStage, null, tokens)); @@ -206,7 +206,7 @@ void testTokenTrackingExerciseChatWithPipelineFail() throws Exception { List failedStages = new ArrayList<>(); failedStages.add(new PyrisStageDTO("TestTokenFail", 10, PyrisStageState.ERROR, "Failed running pipeline")); - irisRequestMockProvider.mockRunResponse(dto -> { + irisRequestMockProvider.mockProgrammingExerciseChatResponse(dto -> { assertThat(dto.settings().authenticationToken()).isNotNull(); assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), null, failedStages, null, tokens)); diff --git a/src/test/java/de/tum/cit/aet/artemis/iris/IrisTextExerciseChatMessageIntegrationTest.java b/src/test/java/de/tum/cit/aet/artemis/iris/IrisTextExerciseChatMessageIntegrationTest.java index 7be2d0e8abc9..0366317fd557 100644 --- a/src/test/java/de/tum/cit/aet/artemis/iris/IrisTextExerciseChatMessageIntegrationTest.java +++ b/src/test/java/de/tum/cit/aet/artemis/iris/IrisTextExerciseChatMessageIntegrationTest.java @@ -398,7 +398,7 @@ public String toString() { private void sendStatus(String jobId, String result, List stages, List suggestions) throws Exception { var headers = new HttpHeaders(new LinkedMultiValueMap<>(Map.of("Authorization", List.of("Bearer " + jobId)))); - request.postWithoutResponseBody("/api/public/pyris/pipelines/text-exercise-chat/runs/" + jobId + "/status", new PyrisChatStatusUpdateDTO(result, stages, suggestions), + request.postWithoutResponseBody("/api/public/pyris/pipelines/text-exercise-chat/runs/" + jobId + "/status", new PyrisChatStatusUpdateDTO(result, stages, suggestions, null), HttpStatus.OK, headers); } } From 897d643015fed263d92aba2797056d5fdd2717e1 Mon Sep 17 00:00:00 2001 From: Alexander Joham Date: Mon, 14 Oct 2024 16:36:11 +0200 Subject: [PATCH 12/30] Change servicetype to type String to prevent failures --- .../tum/cit/aet/artemis/core/domain/LLMTokenUsage.java | 9 +++------ .../iris/service/IrisCompetencyGenerationService.java | 2 +- .../iris/service/pyris/dto/data/PyrisLLMCostDTO.java | 4 +--- ...125521_changelog.xml => 20241014035241_changelog.xml} | 2 +- src/main/resources/config/liquibase/master.xml | 2 +- .../iris/IrisChatTokenTrackingIntegrationTest.java | 2 +- 6 files changed, 8 insertions(+), 13 deletions(-) rename src/main/resources/config/liquibase/changelog/{20241014125521_changelog.xml => 20241014035241_changelog.xml} (97%) diff --git a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java index bba09ad0025d..9fdbe942038c 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java @@ -5,8 +5,6 @@ import jakarta.annotation.Nullable; import jakarta.persistence.Column; import jakarta.persistence.Entity; -import jakarta.persistence.EnumType; -import jakarta.persistence.Enumerated; import jakarta.persistence.JoinColumn; import jakarta.persistence.ManyToOne; import jakarta.persistence.Table; @@ -27,8 +25,7 @@ public class LLMTokenUsage extends DomainObject { @Column(name = "service") - @Enumerated(EnumType.STRING) - private LLMServiceType serviceType; + private String serviceType; @Column(name = "model") private String model; @@ -72,11 +69,11 @@ public class LLMTokenUsage extends DomainObject { @JoinColumn(name = "iris_message_id") private IrisMessage irisMessage; - public LLMServiceType getServiceType() { + public String getServiceType() { return serviceType; } - public void setServiceType(LLMServiceType serviceType) { + public void setServiceType(String serviceType) { this.serviceType = serviceType; } diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java index 0a55bdcb0287..c792deabe1ab 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java @@ -38,10 +38,10 @@ public class IrisCompetencyGenerationService { public IrisCompetencyGenerationService(PyrisPipelineService pyrisPipelineService, LLMTokenUsageService llmTokenUsageService, CourseRepository courseRepository, IrisWebsocketService websocketService, PyrisJobService pyrisJobService) { this.pyrisPipelineService = pyrisPipelineService; + this.llmTokenUsageService = llmTokenUsageService; this.courseRepository = courseRepository; this.websocketService = websocketService; this.pyrisJobService = pyrisJobService; - this.llmTokenUsageService = llmTokenUsageService; } /** diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/data/PyrisLLMCostDTO.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/data/PyrisLLMCostDTO.java index 74f40cce6873..43c000a879ae 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/data/PyrisLLMCostDTO.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/data/PyrisLLMCostDTO.java @@ -1,6 +1,4 @@ package de.tum.cit.aet.artemis.iris.service.pyris.dto.data; -import de.tum.cit.aet.artemis.core.domain.LLMServiceType; - -public record PyrisLLMCostDTO(String modelInfo, int numInputTokens, float costPerInputToken, int numOutputTokens, float costPerOutputToken, LLMServiceType pipeline) { +public record PyrisLLMCostDTO(String modelInfo, int numInputTokens, float costPerInputToken, int numOutputTokens, float costPerOutputToken, String pipeline) { } diff --git a/src/main/resources/config/liquibase/changelog/20241014125521_changelog.xml b/src/main/resources/config/liquibase/changelog/20241014035241_changelog.xml similarity index 97% rename from src/main/resources/config/liquibase/changelog/20241014125521_changelog.xml rename to src/main/resources/config/liquibase/changelog/20241014035241_changelog.xml index 7d14ed33763e..5d82a422f6f7 100644 --- a/src/main/resources/config/liquibase/changelog/20241014125521_changelog.xml +++ b/src/main/resources/config/liquibase/changelog/20241014035241_changelog.xml @@ -5,7 +5,7 @@ xsi:schemaLocation="http://www.liquibase.org/xml/ns/dbchangelog http://www.liquibase.org/xml/ns/dbchangelog/dbchangelog-4.29.xsd" objectQuotingStrategy="QUOTE_ONLY_RESERVED_WORDS"> - + diff --git a/src/main/resources/config/liquibase/master.xml b/src/main/resources/config/liquibase/master.xml index b6143b42d8ce..23d8caf93e2c 100644 --- a/src/main/resources/config/liquibase/master.xml +++ b/src/main/resources/config/liquibase/master.xml @@ -22,13 +22,13 @@ - + diff --git a/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java b/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java index d183ff48b06d..fdcff62a165f 100644 --- a/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java +++ b/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java @@ -239,7 +239,7 @@ void testTokenTrackingExerciseChatWithPipelineFail() throws Exception { private List getMockLLMCosts() { List costs = new ArrayList<>(); for (int i = 0; i < 5; i++) { - costs.add(new PyrisLLMCostDTO("test-llm", i * 10 + 5, i * 0.5f, i * 3 + 5, i * 0.12f, LLMServiceType.IRIS_CHAT_EXERCISE_MESSAGE)); + costs.add(new PyrisLLMCostDTO("test-llm", i * 10 + 5, i * 0.5f, i * 3 + 5, i * 0.12f, LLMServiceType.IRIS_CHAT_EXERCISE_MESSAGE.toString())); } return costs; } From 1d10860b2e8251e304e7c7eb5846ce1726e6f18c Mon Sep 17 00:00:00 2001 From: Alexander Joham Date: Mon, 14 Oct 2024 16:39:33 +0200 Subject: [PATCH 13/30] Change servicetype to type String to prevent failures Make saveIrisTokenUsage saver --- .../tum/cit/aet/artemis/core/domain/LLMTokenUsage.java | 9 +++------ .../aet/artemis/core/service/LLMTokenUsageService.java | 8 +++++--- .../iris/service/IrisCompetencyGenerationService.java | 2 +- .../iris/service/pyris/dto/data/PyrisLLMCostDTO.java | 4 +--- ...125521_changelog.xml => 20241014035241_changelog.xml} | 2 +- src/main/resources/config/liquibase/master.xml | 2 +- .../iris/IrisChatTokenTrackingIntegrationTest.java | 2 +- 7 files changed, 13 insertions(+), 16 deletions(-) rename src/main/resources/config/liquibase/changelog/{20241014125521_changelog.xml => 20241014035241_changelog.xml} (97%) diff --git a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java index bba09ad0025d..9fdbe942038c 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java @@ -5,8 +5,6 @@ import jakarta.annotation.Nullable; import jakarta.persistence.Column; import jakarta.persistence.Entity; -import jakarta.persistence.EnumType; -import jakarta.persistence.Enumerated; import jakarta.persistence.JoinColumn; import jakarta.persistence.ManyToOne; import jakarta.persistence.Table; @@ -27,8 +25,7 @@ public class LLMTokenUsage extends DomainObject { @Column(name = "service") - @Enumerated(EnumType.STRING) - private LLMServiceType serviceType; + private String serviceType; @Column(name = "model") private String model; @@ -72,11 +69,11 @@ public class LLMTokenUsage extends DomainObject { @JoinColumn(name = "iris_message_id") private IrisMessage irisMessage; - public LLMServiceType getServiceType() { + public String getServiceType() { return serviceType; } - public void setServiceType(LLMServiceType serviceType) { + public void setServiceType(String serviceType) { this.serviceType = serviceType; } diff --git a/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java b/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java index a7ec19f261e6..54463a41cedb 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java @@ -51,18 +51,20 @@ public List saveIrisTokenUsage(PyrisJob job, IrisMessage message, llmTokenUsage.setIrisMessage(message); llmTokenUsage.setTime(message.getSentAt()); } - llmTokenUsage.setServiceType(cost.pipeline()); - llmTokenUsage.setExercise(exercise); if (user != null) { llmTokenUsage.setUserId(user.getId()); } + if (job != null) { + llmTokenUsage.setTraceId(job.jobId()); + } + llmTokenUsage.setServiceType(cost.pipeline()); + llmTokenUsage.setExercise(exercise); llmTokenUsage.setCourse(course); llmTokenUsage.setNumInputTokens(cost.numInputTokens()); llmTokenUsage.setCostPerMillionInputTokens(cost.costPerInputToken()); llmTokenUsage.setNumOutputTokens(cost.numOutputTokens()); llmTokenUsage.setCostPerMillionOutputTokens(cost.costPerOutputToken()); llmTokenUsage.setModel(cost.modelInfo()); - llmTokenUsage.setTraceId(job.jobId()); tokenUsages.add(llmTokenUsage); } llmTokenUsageRepository.saveAll(tokenUsages); diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java index 0a55bdcb0287..c792deabe1ab 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java @@ -38,10 +38,10 @@ public class IrisCompetencyGenerationService { public IrisCompetencyGenerationService(PyrisPipelineService pyrisPipelineService, LLMTokenUsageService llmTokenUsageService, CourseRepository courseRepository, IrisWebsocketService websocketService, PyrisJobService pyrisJobService) { this.pyrisPipelineService = pyrisPipelineService; + this.llmTokenUsageService = llmTokenUsageService; this.courseRepository = courseRepository; this.websocketService = websocketService; this.pyrisJobService = pyrisJobService; - this.llmTokenUsageService = llmTokenUsageService; } /** diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/data/PyrisLLMCostDTO.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/data/PyrisLLMCostDTO.java index 74f40cce6873..43c000a879ae 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/data/PyrisLLMCostDTO.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/data/PyrisLLMCostDTO.java @@ -1,6 +1,4 @@ package de.tum.cit.aet.artemis.iris.service.pyris.dto.data; -import de.tum.cit.aet.artemis.core.domain.LLMServiceType; - -public record PyrisLLMCostDTO(String modelInfo, int numInputTokens, float costPerInputToken, int numOutputTokens, float costPerOutputToken, LLMServiceType pipeline) { +public record PyrisLLMCostDTO(String modelInfo, int numInputTokens, float costPerInputToken, int numOutputTokens, float costPerOutputToken, String pipeline) { } diff --git a/src/main/resources/config/liquibase/changelog/20241014125521_changelog.xml b/src/main/resources/config/liquibase/changelog/20241014035241_changelog.xml similarity index 97% rename from src/main/resources/config/liquibase/changelog/20241014125521_changelog.xml rename to src/main/resources/config/liquibase/changelog/20241014035241_changelog.xml index 7d14ed33763e..5d82a422f6f7 100644 --- a/src/main/resources/config/liquibase/changelog/20241014125521_changelog.xml +++ b/src/main/resources/config/liquibase/changelog/20241014035241_changelog.xml @@ -5,7 +5,7 @@ xsi:schemaLocation="http://www.liquibase.org/xml/ns/dbchangelog http://www.liquibase.org/xml/ns/dbchangelog/dbchangelog-4.29.xsd" objectQuotingStrategy="QUOTE_ONLY_RESERVED_WORDS"> - + diff --git a/src/main/resources/config/liquibase/master.xml b/src/main/resources/config/liquibase/master.xml index b6143b42d8ce..23d8caf93e2c 100644 --- a/src/main/resources/config/liquibase/master.xml +++ b/src/main/resources/config/liquibase/master.xml @@ -22,13 +22,13 @@ - + diff --git a/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java b/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java index d183ff48b06d..fdcff62a165f 100644 --- a/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java +++ b/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java @@ -239,7 +239,7 @@ void testTokenTrackingExerciseChatWithPipelineFail() throws Exception { private List getMockLLMCosts() { List costs = new ArrayList<>(); for (int i = 0; i < 5; i++) { - costs.add(new PyrisLLMCostDTO("test-llm", i * 10 + 5, i * 0.5f, i * 3 + 5, i * 0.12f, LLMServiceType.IRIS_CHAT_EXERCISE_MESSAGE)); + costs.add(new PyrisLLMCostDTO("test-llm", i * 10 + 5, i * 0.5f, i * 3 + 5, i * 0.12f, LLMServiceType.IRIS_CHAT_EXERCISE_MESSAGE.toString())); } return costs; } From 86294c10e7857dee90291520c7b96822c77c0010 Mon Sep 17 00:00:00 2001 From: Alexander Joham Date: Tue, 15 Oct 2024 13:00:34 +0200 Subject: [PATCH 14/30] Fix test failure by removing @SpyBean --- .../IrisChatTokenTrackingIntegrationTest.java | 53 ++++++++----------- 1 file changed, 23 insertions(+), 30 deletions(-) diff --git a/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java b/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java index fdcff62a165f..d0f71e0b7fec 100644 --- a/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java +++ b/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java @@ -3,9 +3,6 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatNoException; import static org.awaitility.Awaitility.await; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.isNull; -import static org.mockito.Mockito.verify; import java.io.IOException; import java.net.URISyntaxException; @@ -17,9 +14,7 @@ import org.eclipse.jgit.api.errors.GitAPIException; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.mockito.ArgumentCaptor; import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.boot.test.mock.mockito.SpyBean; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.security.test.context.support.WithMockUser; @@ -29,9 +24,8 @@ import de.tum.cit.aet.artemis.core.domain.Course; import de.tum.cit.aet.artemis.core.domain.LLMServiceType; import de.tum.cit.aet.artemis.core.domain.LLMTokenUsage; -import de.tum.cit.aet.artemis.core.domain.User; +import de.tum.cit.aet.artemis.core.repository.LLMTokenUsageRepository; import de.tum.cit.aet.artemis.core.service.LLMTokenUsageService; -import de.tum.cit.aet.artemis.exercise.domain.Exercise; import de.tum.cit.aet.artemis.exercise.participation.util.ParticipationUtilService; import de.tum.cit.aet.artemis.iris.domain.message.IrisMessage; import de.tum.cit.aet.artemis.iris.domain.message.IrisMessageContent; @@ -61,9 +55,12 @@ class IrisChatTokenTrackingIntegrationTest extends AbstractIrisIntegrationTest { @Autowired private IrisMessageRepository irisMessageRepository; - @SpyBean + @Autowired private LLMTokenUsageService llmTokenUsageService; + @Autowired + private LLMTokenUsageRepository irisLLMTokenUsageRepository; + @Autowired private IrisRequestMockProvider irisRequestMockProvider; @@ -122,6 +119,8 @@ void initTestCase() throws GitAPIException, IOException, URISyntaxException { activateIrisGlobally(); activateIrisFor(course); activateIrisFor(exercise); + // Clean up the database + irisLLMTokenUsageRepository.deleteAll(); pipelineDone = new AtomicBoolean(false); } @@ -148,21 +147,18 @@ void testTokenTrackingHandledExerciseChat() throws Exception { await().until(pipelineDone::get); - // Capture the saved token usages - ArgumentCaptor> captor = ArgumentCaptor.forClass(List.class); - verify(llmTokenUsageService).saveIrisTokenUsage(any(PyrisJob.class), any(IrisMessage.class), any(Exercise.class), any(User.class), any(Course.class), captor.capture()); - - // Verify that the tokens are saved correctly - List savedTokenUsages = captor.getValue(); + List savedTokenUsages = irisLLMTokenUsageRepository.findAll(); assertThat(savedTokenUsages).hasSize(5); for (int i = 0; i < savedTokenUsages.size(); i++) { - PyrisLLMCostDTO usage = savedTokenUsages.get(i); + LLMTokenUsage usage = savedTokenUsages.get(i); PyrisLLMCostDTO expectedCost = tokens.get(i); - assertThat(usage.numInputTokens()).isEqualTo(expectedCost.numInputTokens()); - assertThat(usage.costPerInputToken()).isEqualTo(expectedCost.costPerInputToken()); - assertThat(usage.numOutputTokens()).isEqualTo(expectedCost.numOutputTokens()); - assertThat(usage.costPerOutputToken()).isEqualTo(expectedCost.costPerOutputToken()); + assertThat(usage.getModel()).isEqualTo(expectedCost.modelInfo()); + assertThat(usage.getNumInputTokens()).isEqualTo(expectedCost.numInputTokens()); + assertThat(usage.getNumOutputTokens()).isEqualTo(expectedCost.numOutputTokens()); + assertThat(usage.getCostPerMillionInputTokens()).isEqualTo(expectedCost.costPerInputToken()); + assertThat(usage.getCostPerMillionOutputTokens()).isEqualTo(expectedCost.costPerOutputToken()); + assertThat(usage.getServiceType()).isEqualTo(expectedCost.pipeline()); } } @@ -218,21 +214,18 @@ void testTokenTrackingExerciseChatWithPipelineFail() throws Exception { await().until(pipelineDone::get); - // Capture the saved token usages - ArgumentCaptor> captor = ArgumentCaptor.forClass(List.class); - verify(llmTokenUsageService).saveIrisTokenUsage(any(PyrisJob.class), isNull(), any(Exercise.class), any(User.class), any(Course.class), captor.capture()); - - // Verify that the tokens are saved correctly - List savedTokenUsages = captor.getValue(); + List savedTokenUsages = irisLLMTokenUsageRepository.findAll(); assertThat(savedTokenUsages).hasSize(5); for (int i = 0; i < savedTokenUsages.size(); i++) { - PyrisLLMCostDTO usage = savedTokenUsages.get(i); + LLMTokenUsage usage = savedTokenUsages.get(i); PyrisLLMCostDTO expectedCost = tokens.get(i); - assertThat(usage.numInputTokens()).isEqualTo(expectedCost.numInputTokens()); - assertThat(usage.costPerInputToken()).isEqualTo(expectedCost.costPerInputToken()); - assertThat(usage.numOutputTokens()).isEqualTo(expectedCost.numOutputTokens()); - assertThat(usage.costPerOutputToken()).isEqualTo(expectedCost.costPerOutputToken()); + assertThat(usage.getModel()).isEqualTo(expectedCost.modelInfo()); + assertThat(usage.getNumInputTokens()).isEqualTo(expectedCost.numInputTokens()); + assertThat(usage.getNumOutputTokens()).isEqualTo(expectedCost.numOutputTokens()); + assertThat(usage.getCostPerMillionInputTokens()).isEqualTo(expectedCost.costPerInputToken()); + assertThat(usage.getCostPerMillionOutputTokens()).isEqualTo(expectedCost.costPerOutputToken()); + assertThat(usage.getServiceType()).isEqualTo(expectedCost.pipeline()); } } From 56b20e74280a26dd362d94d1e23cccd0d6c90aeb Mon Sep 17 00:00:00 2001 From: Alexander Joham Date: Tue, 15 Oct 2024 17:07:23 +0200 Subject: [PATCH 15/30] Update database to safe only IDs, fix competency Integration Test user --- .../artemis/core/domain/LLMTokenUsage.java | 52 +++++++------------ .../core/service/LLMTokenUsageService.java | 22 +++----- .../IrisCompetencyGenerationService.java | 6 +-- .../pyris/job/CompetencyExtractionJob.java | 9 ++-- ...gelog.xml => 20241015043720_changelog.xml} | 11 +--- .../resources/config/liquibase/master.xml | 2 +- ...isCompetencyGenerationIntegrationTest.java | 3 +- 7 files changed, 39 insertions(+), 66 deletions(-) rename src/main/resources/config/liquibase/changelog/{20241014035241_changelog.xml => 20241015043720_changelog.xml} (61%) diff --git a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java index 9fdbe942038c..ca2f5cf9d240 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java @@ -5,19 +5,13 @@ import jakarta.annotation.Nullable; import jakarta.persistence.Column; import jakarta.persistence.Entity; -import jakarta.persistence.JoinColumn; -import jakarta.persistence.ManyToOne; import jakarta.persistence.Table; import org.hibernate.annotations.Cache; import org.hibernate.annotations.CacheConcurrencyStrategy; -import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; -import de.tum.cit.aet.artemis.exercise.domain.Exercise; -import de.tum.cit.aet.artemis.iris.domain.message.IrisMessage; - @Entity @Table(name = "llm_token_usage") @Cache(usage = CacheConcurrencyStrategy.NONSTRICT_READ_WRITE) @@ -43,16 +37,12 @@ public class LLMTokenUsage extends DomainObject { private float costPerMillionOutputTokens; @Nullable - @ManyToOne - @JsonIgnore - @JoinColumn(name = "course_id") - private Course course; + @Column(name = "course_id") + private Long courseId; @Nullable - @ManyToOne - @JsonIgnore - @JoinColumn(name = "exercise_id") - private Exercise exercise; + @Column(name = "exercise_id") + private Long exerciseId; @Column(name = "user_id") private long userId; @@ -64,10 +54,8 @@ public class LLMTokenUsage extends DomainObject { private String traceId; @Nullable - @ManyToOne - @JsonIgnore - @JoinColumn(name = "iris_message_id") - private IrisMessage irisMessage; + @Column(name = "iris_message_id") + private Long irisMessageId; public String getServiceType() { return serviceType; @@ -117,20 +105,20 @@ public void setNumOutputTokens(int numOutputTokens) { this.numOutputTokens = numOutputTokens; } - public Course getCourse() { - return course; + public Long getCourseId() { + return courseId; } - public void setCourse(Course course) { - this.course = course; + public void setCourseId(Long courseId) { + this.courseId = courseId; } - public Exercise getExercise() { - return exercise; + public Long getExercisIde() { + return exerciseId; } - public void setExercise(Exercise exercise) { - this.exercise = exercise; + public void setExerciseId(Long exerciseId) { + this.exerciseId = exerciseId; } public long getUserId() { @@ -157,18 +145,18 @@ public void setTraceId(String traceId) { this.traceId = traceId; } - public IrisMessage getIrisMessage() { - return irisMessage; + public Long getIrisMessageId() { + return irisMessageId; } - public void setIrisMessage(IrisMessage message) { - this.irisMessage = message; + public void setIrisMessageId(Long messageId) { + this.irisMessageId = messageId; } @Override public String toString() { return "LLMTokenUsage{" + "serviceType=" + serviceType + ", model=" + model + ", numInputTokens=" + numInputTokens + ", costPerMillionInputTokens=" - + costPerMillionInputTokens + ", numOutputTokens=" + numOutputTokens + ", costPerMillionOutputTokens=" + costPerMillionOutputTokens + ", course=" + course - + ", exercise=" + exercise + ", userId=" + userId + ", timestamp=" + time + ", traceId=" + traceId + ", irisMessage=" + irisMessage + '}'; + + costPerMillionInputTokens + ", numOutputTokens=" + numOutputTokens + ", costPerMillionOutputTokens=" + costPerMillionOutputTokens + ", course=" + courseId + + ", exercise=" + exerciseId + ", userId=" + userId + ", timestamp=" + time + ", traceId=" + traceId + ", irisMessage=" + irisMessageId + '}'; } } diff --git a/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java b/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java index 54463a41cedb..bc57d7552d99 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java @@ -48,7 +48,7 @@ public List saveIrisTokenUsage(PyrisJob job, IrisMessage message, for (PyrisLLMCostDTO cost : tokens) { LLMTokenUsage llmTokenUsage = new LLMTokenUsage(); if (message != null) { - llmTokenUsage.setIrisMessage(message); + llmTokenUsage.setIrisMessageId(message.getId()); llmTokenUsage.setTime(message.getSentAt()); } if (user != null) { @@ -58,8 +58,12 @@ public List saveIrisTokenUsage(PyrisJob job, IrisMessage message, llmTokenUsage.setTraceId(job.jobId()); } llmTokenUsage.setServiceType(cost.pipeline()); - llmTokenUsage.setExercise(exercise); - llmTokenUsage.setCourse(course); + if (exercise != null) { + llmTokenUsage.setExerciseId(exercise.getId()); + } + if (course != null) { + llmTokenUsage.setCourseId(course.getId()); + } llmTokenUsage.setNumInputTokens(cost.numInputTokens()); llmTokenUsage.setCostPerMillionInputTokens(cost.costPerInputToken()); llmTokenUsage.setNumOutputTokens(cost.numOutputTokens()); @@ -97,16 +101,4 @@ public List saveIrisTokenUsage(PyrisJob job, User user, Course co public List saveIrisTokenUsage(PyrisJob job, IrisMessage message, User user, Course course, List tokens) { return saveIrisTokenUsage(job, message, null, user, course, tokens); } - - /** - * Overloaded method to save token usage without message, exercise and user. - * - * @param job used to create a unique traceId to group multiple LLM calls - * @param course to map the token to a course - * @param tokens token cost list of type PyrisLLMCostDTO - * @return list of the saved data - */ - public List saveIrisTokenUsage(PyrisJob job, Course course, List tokens) { - return saveIrisTokenUsage(job, null, null, null, course, tokens); - } } diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java index c792deabe1ab..cc48a0a306c3 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java @@ -57,7 +57,7 @@ public void executeCompetencyExtractionPipeline(User user, Course course, String pyrisPipelineService.executePipeline( "competency-extraction", "default", - pyrisJobService.createTokenForJob(token -> new CompetencyExtractionJob(token, course.getId(), user.getLogin())), + pyrisJobService.createTokenForJob(token -> new CompetencyExtractionJob(token, course.getId(), user)), executionDto -> new PyrisCompetencyExtractionPipelineExecutionDTO(executionDto, courseDescription, currentCompetencies, CompetencyTaxonomy.values(), 5), stages -> websocketService.send(user.getLogin(), websocketTopic(course.getId()), new PyrisCompetencyStatusUpdateDTO(stages, null, null)) ); @@ -73,9 +73,9 @@ public void executeCompetencyExtractionPipeline(User user, Course course, String public void handleStatusUpdate(CompetencyExtractionJob job, PyrisCompetencyStatusUpdateDTO statusUpdate) { Course course = courseRepository.findByIdForUpdateElseThrow(job.courseId()); if (statusUpdate.tokens() != null) { - llmTokenUsageService.saveIrisTokenUsage(job, course, statusUpdate.tokens()); + llmTokenUsageService.saveIrisTokenUsage(job, job.user(), course, statusUpdate.tokens()); } - websocketService.send(job.userLogin(), websocketTopic(job.courseId()), statusUpdate); + websocketService.send(job.user().getLogin(), websocketTopic(job.courseId()), statusUpdate); } private static String websocketTopic(long courseId) { diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/CompetencyExtractionJob.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/CompetencyExtractionJob.java index 26ab6427a020..136a1a5ae243 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/CompetencyExtractionJob.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/CompetencyExtractionJob.java @@ -3,16 +3,17 @@ import com.fasterxml.jackson.annotation.JsonInclude; import de.tum.cit.aet.artemis.core.domain.Course; +import de.tum.cit.aet.artemis.core.domain.User; /** * A pyris job that extracts competencies from a course description. * - * @param jobId the job id - * @param courseId the course in which the competencies are being extracted - * @param userLogin the user login of the user who started the job + * @param jobId the job id + * @param courseId the course in which the competencies are being extracted + * @param user the user who started the job */ @JsonInclude(JsonInclude.Include.NON_EMPTY) -public record CompetencyExtractionJob(String jobId, long courseId, String userLogin) implements PyrisJob { +public record CompetencyExtractionJob(String jobId, long courseId, User user) implements PyrisJob { @Override public boolean canAccess(Course course) { diff --git a/src/main/resources/config/liquibase/changelog/20241014035241_changelog.xml b/src/main/resources/config/liquibase/changelog/20241015043720_changelog.xml similarity index 61% rename from src/main/resources/config/liquibase/changelog/20241014035241_changelog.xml rename to src/main/resources/config/liquibase/changelog/20241015043720_changelog.xml index 5d82a422f6f7..ef16a3dff637 100644 --- a/src/main/resources/config/liquibase/changelog/20241014035241_changelog.xml +++ b/src/main/resources/config/liquibase/changelog/20241015043720_changelog.xml @@ -5,7 +5,7 @@ xsi:schemaLocation="http://www.liquibase.org/xml/ns/dbchangelog http://www.liquibase.org/xml/ns/dbchangelog/dbchangelog-4.29.xsd" objectQuotingStrategy="QUOTE_ONLY_RESERVED_WORDS"> - + @@ -23,15 +23,6 @@ - - - \ No newline at end of file diff --git a/src/main/resources/config/liquibase/master.xml b/src/main/resources/config/liquibase/master.xml index 23d8caf93e2c..7dd075b5c855 100644 --- a/src/main/resources/config/liquibase/master.xml +++ b/src/main/resources/config/liquibase/master.xml @@ -28,7 +28,7 @@ - + diff --git a/src/test/java/de/tum/cit/aet/artemis/iris/IrisCompetencyGenerationIntegrationTest.java b/src/test/java/de/tum/cit/aet/artemis/iris/IrisCompetencyGenerationIntegrationTest.java index f9948aa91ad3..24085a97f70c 100644 --- a/src/test/java/de/tum/cit/aet/artemis/iris/IrisCompetencyGenerationIntegrationTest.java +++ b/src/test/java/de/tum/cit/aet/artemis/iris/IrisCompetencyGenerationIntegrationTest.java @@ -68,7 +68,8 @@ void generateCompetencies_asEditor_shouldSucceed() throws Exception { // In the real system, this would be triggered by Pyris via a REST call to the Artemis server String jobId = "testJobId"; - CompetencyExtractionJob job = new CompetencyExtractionJob(jobId, course.getId(), TEST_PREFIX + "editor1"); + String userLogin = TEST_PREFIX + "editor1"; + CompetencyExtractionJob job = new CompetencyExtractionJob(jobId, course.getId(), userUtilService.getUserByLogin(userLogin)); irisCompetencyGenerationService.handleStatusUpdate(job, new PyrisCompetencyStatusUpdateDTO(stages, recommendations, null)); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(PyrisCompetencyStatusUpdateDTO.class); From 8a29c82468e6cf78dbe5b33b96da13e48bd456a2 Mon Sep 17 00:00:00 2001 From: Alexander Joham Date: Wed, 16 Oct 2024 11:15:42 +0200 Subject: [PATCH 16/30] Implement builder pattern based on feedback --- .../core/service/LLMTokenUsageService.java | 133 ++++++++++++------ .../IrisCompetencyGenerationService.java | 2 +- .../session/IrisCourseChatSessionService.java | 5 +- .../IrisExerciseChatSessionService.java | 7 +- .../IrisChatTokenTrackingIntegrationTest.java | 4 +- 5 files changed, 100 insertions(+), 51 deletions(-) diff --git a/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java b/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java index bc57d7552d99..897ca8278462 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java @@ -4,6 +4,8 @@ import java.util.ArrayList; import java.util.List; +import java.util.Optional; +import java.util.function.Function; import org.springframework.context.annotation.Profile; import org.springframework.stereotype.Service; @@ -34,36 +36,35 @@ public LLMTokenUsageService(LLMTokenUsageRepository llmTokenUsageRepository) { * method saves the token usage to the database with a link to the IrisMessage * messages of the same job are grouped together by saving the job id as a trace id * - * @param job used to create a unique traceId to group multiple LLM calls - * @param message IrisMessage to map the usage to an IrisMessage - * @param exercise to map the token cost to an exercise - * @param user to map the token cost to a user - * @param course to map the token to a course - * @param tokens token cost list of type PyrisLLMCostDTO - * @return list of the saved data + * @param builderFunction of type Function using IrisTokenUsageBuilder + * @return saved LLMTokenUsage as a List */ - public List saveIrisTokenUsage(PyrisJob job, IrisMessage message, Exercise exercise, User user, Course course, List tokens) { + public List saveIrisTokenUsage(Function builderFunction) { + IrisTokenUsageBuilder builder = builderFunction.apply(new IrisTokenUsageBuilder()); List tokenUsages = new ArrayList<>(); - + List tokens = builder.getTokens(); for (PyrisLLMCostDTO cost : tokens) { LLMTokenUsage llmTokenUsage = new LLMTokenUsage(); - if (message != null) { + + builder.getMessage().ifPresent(message -> { llmTokenUsage.setIrisMessageId(message.getId()); llmTokenUsage.setTime(message.getSentAt()); - } - if (user != null) { + }); + + builder.getUser().ifPresent(user -> { llmTokenUsage.setUserId(user.getId()); - } - if (job != null) { - llmTokenUsage.setTraceId(job.jobId()); - } - llmTokenUsage.setServiceType(cost.pipeline()); - if (exercise != null) { + }); + + builder.getExercise().ifPresent(exercise -> { llmTokenUsage.setExerciseId(exercise.getId()); - } - if (course != null) { + }); + + builder.getCourse().ifPresent(course -> { llmTokenUsage.setCourseId(course.getId()); - } + }); + + llmTokenUsage.setTraceId(builder.getJob().jobId()); + llmTokenUsage.setServiceType(cost.pipeline()); llmTokenUsage.setNumInputTokens(cost.numInputTokens()); llmTokenUsage.setCostPerMillionInputTokens(cost.costPerInputToken()); llmTokenUsage.setNumOutputTokens(cost.numOutputTokens()); @@ -76,29 +77,75 @@ public List saveIrisTokenUsage(PyrisJob job, IrisMessage message, } /** - * Overloaded method to save token usage without message and exercise. - * - * @param job used to create a unique traceId to group multiple LLM calls - * @param user to map the token cost to a user - * @param course to map the token to a course - * @param tokens token cost list of type PyrisLLMCostDTO - * @return list of the saved data + * Class IrisTokenUsageBuilder to be used for saveIrisTokenUsage() */ - public List saveIrisTokenUsage(PyrisJob job, User user, Course course, List tokens) { - return saveIrisTokenUsage(job, null, null, user, course, tokens); - } + public static class IrisTokenUsageBuilder { - /** - * Overloaded method to save token usage without exercise. - * - * @param job used to create a unique traceId to group multiple LLM calls - * @param message IrisMessage to map the usage to an IrisMessage - * @param user to map the token cost to a user - * @param course to map the token to a course - * @param tokens token cost list of type PyrisLLMCostDTO - * @return list of the saved data - */ - public List saveIrisTokenUsage(PyrisJob job, IrisMessage message, User user, Course course, List tokens) { - return saveIrisTokenUsage(job, message, null, user, course, tokens); + private PyrisJob job; + + private List tokens; + + private Optional course = Optional.empty(); + + private Optional message = Optional.empty(); + + private Optional exercise = Optional.empty(); + + private Optional user = Optional.empty(); + + public IrisTokenUsageBuilder withJob(PyrisJob job) { + this.job = job; + return this; + } + + public IrisTokenUsageBuilder withCourse(Course course) { + this.course = Optional.ofNullable(course); + return this; + } + + public IrisTokenUsageBuilder withTokens(List tokens) { + this.tokens = tokens; + return this; + } + + public IrisTokenUsageBuilder withMessage(IrisMessage message) { + this.message = Optional.ofNullable(message); + return this; + } + + public IrisTokenUsageBuilder withExercise(Exercise exercise) { + this.exercise = Optional.ofNullable(exercise); + return this; + } + + public IrisTokenUsageBuilder withUser(User user) { + this.user = Optional.ofNullable(user); + return this; + } + + // Getters + public PyrisJob getJob() { + return job; + } + + public List getTokens() { + return tokens; + } + + public Optional getCourse() { + return course; + } + + public Optional getMessage() { + return message; + } + + public Optional getExercise() { + return exercise; + } + + public Optional getUser() { + return user; + } } } diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java index cc48a0a306c3..47f4e66d654a 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java @@ -73,7 +73,7 @@ public void executeCompetencyExtractionPipeline(User user, Course course, String public void handleStatusUpdate(CompetencyExtractionJob job, PyrisCompetencyStatusUpdateDTO statusUpdate) { Course course = courseRepository.findByIdForUpdateElseThrow(job.courseId()); if (statusUpdate.tokens() != null) { - llmTokenUsageService.saveIrisTokenUsage(job, job.user(), course, statusUpdate.tokens()); + llmTokenUsageService.saveIrisTokenUsage(builder -> builder.withJob(job).withCourse(course).withUser(job.user()).withTokens(statusUpdate.tokens())); } websocketService.send(job.user().getLogin(), websocketTopic(job.courseId()), statusUpdate); } diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java index 309825c10f82..324094b0b532 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java @@ -143,11 +143,12 @@ public void handleStatusUpdate(CourseChatJob job, PyrisChatStatusUpdateDTO statu var message = new IrisMessage(); message.addContent(new IrisTextMessageContent(statusUpdate.result())); var savedMessage = irisMessageService.saveMessage(message, session, IrisMessageSender.LLM); - llmTokenUsageService.saveIrisTokenUsage(job, savedMessage, session.getUser(), session.getCourse(), statusUpdate.tokens()); + llmTokenUsageService.saveIrisTokenUsage( + builder -> builder.withJob(job).withMessage(savedMessage).withUser(session.getUser()).withCourse(session.getCourse()).withTokens(statusUpdate.tokens())); irisChatWebsocketService.sendMessage(session, savedMessage, statusUpdate.stages()); } else { - llmTokenUsageService.saveIrisTokenUsage(job, session.getUser(), session.getCourse(), statusUpdate.tokens()); + llmTokenUsageService.saveIrisTokenUsage(builder -> builder.withJob(job).withUser(session.getUser()).withCourse(session.getCourse()).withTokens(statusUpdate.tokens())); irisChatWebsocketService.sendStatusUpdate(session, statusUpdate.stages(), statusUpdate.suggestions(), statusUpdate.tokens()); } updateLatestSuggestions(session, statusUpdate.suggestions()); diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java index e8cb9e300fbd..329996cf9916 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java @@ -170,7 +170,7 @@ private Optional getLatestSubmissionIfExists(ProgrammingE */ public void handleStatusUpdate(ExerciseChatJob job, PyrisChatStatusUpdateDTO statusUpdate) { var session = (IrisExerciseChatSession) irisSessionRepository.findByIdWithMessagesAndContents(job.sessionId()); - IrisMessage savedMessage = null; + IrisMessage savedMessage; if (statusUpdate.result() != null) { var message = new IrisMessage(); message.addContent(new IrisTextMessageContent(statusUpdate.result())); @@ -178,12 +178,13 @@ public void handleStatusUpdate(ExerciseChatJob job, PyrisChatStatusUpdateDTO sta irisChatWebsocketService.sendMessage(session, savedMessage, statusUpdate.stages()); } else { + savedMessage = null; irisChatWebsocketService.sendStatusUpdate(session, statusUpdate.stages(), statusUpdate.suggestions(), statusUpdate.tokens()); } if (statusUpdate.tokens() != null) { - llmTokenUsageService.saveIrisTokenUsage(job, savedMessage, session.getExercise(), session.getUser(), session.getExercise().getCourseViaExerciseGroupOrCourseMember(), - statusUpdate.tokens()); + llmTokenUsageService.saveIrisTokenUsage(builder -> builder.withJob(job).withMessage(savedMessage).withExercise(session.getExercise()).withUser(session.getUser()) + .withCourse(session.getExercise().getCourseViaExerciseGroupOrCourseMember()).withTokens(statusUpdate.tokens())); } updateLatestSuggestions(session, statusUpdate.suggestions()); diff --git a/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java b/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java index d0f71e0b7fec..b5640cacb6c4 100644 --- a/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java +++ b/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java @@ -175,8 +175,8 @@ void testTokenTrackingSavedExerciseChat() { var tokens = getMockLLMCosts(); // Capture the saved token usages - List returnedTokenUsages = llmTokenUsageService.saveIrisTokenUsage(job, irisMessage, exercise, irisSession.getUser(), course, tokens); - + List returnedTokenUsages = llmTokenUsageService.saveIrisTokenUsage( + builder -> builder.withJob(job).withMessage(irisMessage).withExercise(exercise).withUser(irisSession.getUser()).withCourse(course).withTokens(tokens)); assertThat(returnedTokenUsages).hasSize(5); for (int i = 0; i < returnedTokenUsages.size(); i++) { LLMTokenUsage usage = returnedTokenUsages.get(i); From abbd28f313bc9e6dbd3426ff3257b32477b742c4 Mon Sep 17 00:00:00 2001 From: Alexander Joham Date: Wed, 16 Oct 2024 17:34:28 +0200 Subject: [PATCH 17/30] Update database migration with foreign keys and on delete null --- .../changelog/20241015043720_changelog.xml | 32 +++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/src/main/resources/config/liquibase/changelog/20241015043720_changelog.xml b/src/main/resources/config/liquibase/changelog/20241015043720_changelog.xml index ef16a3dff637..f31bdf379324 100644 --- a/src/main/resources/config/liquibase/changelog/20241015043720_changelog.xml +++ b/src/main/resources/config/liquibase/changelog/20241015043720_changelog.xml @@ -23,6 +23,34 @@ - - \ No newline at end of file + + + + + + + + + + From 8d34428cb21498aa2af77174ac5fa4a479b6fff7 Mon Sep 17 00:00:00 2001 From: Alexander Joham Date: Fri, 18 Oct 2024 18:01:34 +0200 Subject: [PATCH 18/30] Rework database, update saveLLMTokens method --- .../aet/artemis/core/domain/LLMRequest.java | 4 + .../artemis/core/domain/LLMServiceType.java | 25 +- .../artemis/core/domain/LLMTokenUsage.java | 162 ------- .../core/domain/LLMTokenUsageRequest.java | 95 ++++ .../core/domain/LLMTokenUsageTrace.java | 105 +++++ .../repository/LLMTokenUsageRepository.java | 14 - .../LLMTokenUsageTraceRepository.java | 10 + .../core/service/LLMTokenUsageService.java | 132 +++--- .../iris/dto/IrisChatWebsocketDTO.java | 6 +- .../IrisCompetencyGenerationService.java | 3 +- .../dto/chat/PyrisChatStatusUpdateDTO.java | 4 +- .../PyrisCompetencyStatusUpdateDTO.java | 4 +- .../session/IrisCourseChatSessionService.java | 7 +- .../IrisExerciseChatSessionService.java | 11 +- .../websocket/IrisChatWebsocketService.java | 4 +- .../changelog/20241015043720_changelog.xml | 56 --- .../changelog/20241018053210_changelog.xml | 37 ++ .../resources/config/liquibase/master.xml | 2 +- .../IrisChatTokenTrackingIntegrationTest.java | 420 +++++++----------- 19 files changed, 491 insertions(+), 610 deletions(-) create mode 100644 src/main/java/de/tum/cit/aet/artemis/core/domain/LLMRequest.java delete mode 100644 src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java create mode 100644 src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsageRequest.java create mode 100644 src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsageTrace.java delete mode 100644 src/main/java/de/tum/cit/aet/artemis/core/repository/LLMTokenUsageRepository.java create mode 100644 src/main/java/de/tum/cit/aet/artemis/core/repository/LLMTokenUsageTraceRepository.java delete mode 100644 src/main/resources/config/liquibase/changelog/20241015043720_changelog.xml create mode 100644 src/main/resources/config/liquibase/changelog/20241018053210_changelog.xml diff --git a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMRequest.java b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMRequest.java new file mode 100644 index 000000000000..bc3ff7bbe23f --- /dev/null +++ b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMRequest.java @@ -0,0 +1,4 @@ +package de.tum.cit.aet.artemis.core.domain; + +public record LLMRequest(String model, int numInputTokens, float costPerMillionInputToken, int numOutputTokens, float costPerMillionOutputToken, String pipelineId) { +} diff --git a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMServiceType.java b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMServiceType.java index f7e179ccacb6..22465bc57b5f 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMServiceType.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMServiceType.java @@ -4,28 +4,5 @@ * Enum representing different types of LLM (Large Language Model) services used in the system. */ public enum LLMServiceType { - /** Athena service for preliminary feedback */ - ATHENA_PRELIMINARY_FEEDBACK, - /** Athena service for feedback suggestions */ - ATHENA_FEEDBACK_SUGGESTION, - /** Iris service for code feedback */ - IRIS_CODE_FEEDBACK, - /** Iris service for course chat messages */ - IRIS_CHAT_COURSE_MESSAGE, - /** Iris service for exercise chat messages */ - IRIS_CHAT_EXERCISE_MESSAGE, - /** Iris service for interaction suggestions */ - IRIS_INTERACTION_SUGGESTION, - /** Iris service for lecture chat messages */ - IRIS_CHAT_LECTURE_MESSAGE, - /** Iris service for competency generation */ - IRIS_COMPETENCY_GENERATION, - /** Iris service for citation pipeline */ - IRIS_CITATION_PIPELINE, - /** Iris service for lecture retrieval pipeline */ - IRIS_LECTURE_RETRIEVAL_PIPELINE, - /** Iris service for lecture ingestion */ - IRIS_LECTURE_INGESTION, - /** Default value when the service type is not set */ - NOT_SET + IRIS, ATHENA } diff --git a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java deleted file mode 100644 index ca2f5cf9d240..000000000000 --- a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java +++ /dev/null @@ -1,162 +0,0 @@ -package de.tum.cit.aet.artemis.core.domain; - -import java.time.ZonedDateTime; - -import jakarta.annotation.Nullable; -import jakarta.persistence.Column; -import jakarta.persistence.Entity; -import jakarta.persistence.Table; - -import org.hibernate.annotations.Cache; -import org.hibernate.annotations.CacheConcurrencyStrategy; - -import com.fasterxml.jackson.annotation.JsonInclude; - -@Entity -@Table(name = "llm_token_usage") -@Cache(usage = CacheConcurrencyStrategy.NONSTRICT_READ_WRITE) -@JsonInclude(JsonInclude.Include.NON_EMPTY) -public class LLMTokenUsage extends DomainObject { - - @Column(name = "service") - private String serviceType; - - @Column(name = "model") - private String model; - - @Column(name = "num_input_tokens") - private int numInputTokens; - - @Column(name = "cost_per_million_input_tokens") - private float costPerMillionInputTokens; - - @Column(name = "num_output_tokens") - private int numOutputTokens; - - @Column(name = "cost_per_million_output_tokens") - private float costPerMillionOutputTokens; - - @Nullable - @Column(name = "course_id") - private Long courseId; - - @Nullable - @Column(name = "exercise_id") - private Long exerciseId; - - @Column(name = "user_id") - private long userId; - - @Column(name = "time") - private ZonedDateTime time = ZonedDateTime.now(); - - @Column(name = "trace_id") - private String traceId; - - @Nullable - @Column(name = "iris_message_id") - private Long irisMessageId; - - public String getServiceType() { - return serviceType; - } - - public void setServiceType(String serviceType) { - this.serviceType = serviceType; - } - - public String getModel() { - return model; - } - - public void setModel(String model) { - this.model = model; - } - - public float getCostPerMillionInputTokens() { - return costPerMillionInputTokens; - } - - public void setCostPerMillionInputTokens(float costPerMillionInputToken) { - this.costPerMillionInputTokens = costPerMillionInputToken; - } - - public float getCostPerMillionOutputTokens() { - return costPerMillionOutputTokens; - } - - public void setCostPerMillionOutputTokens(float costPerMillionOutputToken) { - this.costPerMillionOutputTokens = costPerMillionOutputToken; - } - - public int getNumInputTokens() { - return numInputTokens; - } - - public void setNumInputTokens(int numInputTokens) { - this.numInputTokens = numInputTokens; - } - - public int getNumOutputTokens() { - return numOutputTokens; - } - - public void setNumOutputTokens(int numOutputTokens) { - this.numOutputTokens = numOutputTokens; - } - - public Long getCourseId() { - return courseId; - } - - public void setCourseId(Long courseId) { - this.courseId = courseId; - } - - public Long getExercisIde() { - return exerciseId; - } - - public void setExerciseId(Long exerciseId) { - this.exerciseId = exerciseId; - } - - public long getUserId() { - return userId; - } - - public void setUserId(long userId) { - this.userId = userId; - } - - public ZonedDateTime getTime() { - return time; - } - - public void setTime(ZonedDateTime time) { - this.time = time; - } - - public String getTraceId() { - return traceId; - } - - public void setTraceId(String traceId) { - this.traceId = traceId; - } - - public Long getIrisMessageId() { - return irisMessageId; - } - - public void setIrisMessageId(Long messageId) { - this.irisMessageId = messageId; - } - - @Override - public String toString() { - return "LLMTokenUsage{" + "serviceType=" + serviceType + ", model=" + model + ", numInputTokens=" + numInputTokens + ", costPerMillionInputTokens=" - + costPerMillionInputTokens + ", numOutputTokens=" + numOutputTokens + ", costPerMillionOutputTokens=" + costPerMillionOutputTokens + ", course=" + courseId - + ", exercise=" + exerciseId + ", userId=" + userId + ", timestamp=" + time + ", traceId=" + traceId + ", irisMessage=" + irisMessageId + '}'; - } -} diff --git a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsageRequest.java b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsageRequest.java new file mode 100644 index 000000000000..186b5d049183 --- /dev/null +++ b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsageRequest.java @@ -0,0 +1,95 @@ +package de.tum.cit.aet.artemis.core.domain; + +import jakarta.persistence.Column; +import jakarta.persistence.Entity; +import jakarta.persistence.ManyToOne; +import jakarta.persistence.Table; + +import org.hibernate.annotations.Cache; +import org.hibernate.annotations.CacheConcurrencyStrategy; + +import com.fasterxml.jackson.annotation.JsonInclude; + +@Entity +@Table(name = "llm_token_usage_request") +@Cache(usage = CacheConcurrencyStrategy.NONSTRICT_READ_WRITE) +@JsonInclude(JsonInclude.Include.NON_EMPTY) +public class LLMTokenUsageRequest extends DomainObject { + + @Column(name = "model") + private String model; + + @Column(name = "service_pipeline_id") + private String servicePipelineId; + + @Column(name = "num_input_tokens") + private int numInputTokens; + + @Column(name = "cost_per_million_input_tokens") + private float costPerMillionInputTokens; + + @Column(name = "num_output_tokens") + private int numOutputTokens; + + @Column(name = "cost_per_million_output_tokens") + private float costPerMillionOutputTokens; + + @ManyToOne + private LLMTokenUsageTrace trace; + + public String getModel() { + return model; + } + + public void setModel(String model) { + this.model = model; + } + + public String getServicePipelineId() { + return servicePipelineId; + } + + public void setServicePipelineId(String servicePipelineId) { + this.servicePipelineId = servicePipelineId; + } + + public float getCostPerMillionInputTokens() { + return costPerMillionInputTokens; + } + + public void setCostPerMillionInputTokens(float costPerMillionInputToken) { + this.costPerMillionInputTokens = costPerMillionInputToken; + } + + public float getCostPerMillionOutputTokens() { + return costPerMillionOutputTokens; + } + + public void setCostPerMillionOutputTokens(float costPerMillionOutputToken) { + this.costPerMillionOutputTokens = costPerMillionOutputToken; + } + + public int getNumInputTokens() { + return numInputTokens; + } + + public void setNumInputTokens(int numInputTokens) { + this.numInputTokens = numInputTokens; + } + + public int getNumOutputTokens() { + return numOutputTokens; + } + + public void setNumOutputTokens(int numOutputTokens) { + this.numOutputTokens = numOutputTokens; + } + + public LLMTokenUsageTrace getTraceId() { + return trace; + } + + public void setTrace(LLMTokenUsageTrace trace) { + this.trace = trace; + } +} diff --git a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsageTrace.java b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsageTrace.java new file mode 100644 index 000000000000..0294c322c0e9 --- /dev/null +++ b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsageTrace.java @@ -0,0 +1,105 @@ +package de.tum.cit.aet.artemis.core.domain; + +import java.time.ZonedDateTime; +import java.util.HashSet; +import java.util.Set; + +import jakarta.annotation.Nullable; +import jakarta.persistence.CascadeType; +import jakarta.persistence.Column; +import jakarta.persistence.Entity; +import jakarta.persistence.FetchType; +import jakarta.persistence.OneToMany; +import jakarta.persistence.Table; + +import org.hibernate.annotations.Cache; +import org.hibernate.annotations.CacheConcurrencyStrategy; + +import com.fasterxml.jackson.annotation.JsonInclude; + +@Entity +@Table(name = "llm_token_usage_trace") +@Cache(usage = CacheConcurrencyStrategy.NONSTRICT_READ_WRITE) +@JsonInclude(JsonInclude.Include.NON_EMPTY) +public class LLMTokenUsageTrace extends DomainObject { + + @Column(name = "service") + private LLMServiceType serviceType; + + @Nullable + @Column(name = "course_id") + private Long courseId; + + @Nullable + @Column(name = "exercise_id") + private Long exerciseId; + + @Column(name = "user_id") + private Long userId; + + @Column(name = "time") + private ZonedDateTime time = ZonedDateTime.now(); + + @Nullable + @Column(name = "iris_message_id") + private Long irisMessageId; + + @OneToMany(mappedBy = "trace", fetch = FetchType.LAZY, cascade = CascadeType.ALL, orphanRemoval = true) + private Set llmRequests = new HashSet<>(); + + public LLMServiceType getServiceType() { + return serviceType; + } + + public void setServiceType(LLMServiceType serviceType) { + this.serviceType = serviceType; + } + + public Long getCourseId() { + return courseId; + } + + public void setCourseId(Long courseId) { + this.courseId = courseId; + } + + public Long getExerciseId() { + return exerciseId; + } + + public void setExerciseId(Long exerciseId) { + this.exerciseId = exerciseId; + } + + public long getUserId() { + return userId; + } + + public void setUserId(long userId) { + this.userId = userId; + } + + public ZonedDateTime getTime() { + return time; + } + + public void setTime(ZonedDateTime time) { + this.time = time; + } + + public Set getLLMRequests() { + return llmRequests; + } + + public void setLlmRequests(Set llmRequests) { + this.llmRequests = llmRequests; + } + + public Long getIrisMessageId() { + return irisMessageId; + } + + public void setIrisMessageId(Long messageId) { + this.irisMessageId = messageId; + } +} diff --git a/src/main/java/de/tum/cit/aet/artemis/core/repository/LLMTokenUsageRepository.java b/src/main/java/de/tum/cit/aet/artemis/core/repository/LLMTokenUsageRepository.java deleted file mode 100644 index 2e6d9f1902e1..000000000000 --- a/src/main/java/de/tum/cit/aet/artemis/core/repository/LLMTokenUsageRepository.java +++ /dev/null @@ -1,14 +0,0 @@ -package de.tum.cit.aet.artemis.core.repository; - -import static de.tum.cit.aet.artemis.core.config.Constants.PROFILE_IRIS; - -import org.springframework.context.annotation.Profile; -import org.springframework.stereotype.Repository; - -import de.tum.cit.aet.artemis.core.domain.LLMTokenUsage; -import de.tum.cit.aet.artemis.core.repository.base.ArtemisJpaRepository; - -@Repository -@Profile(PROFILE_IRIS) -public interface LLMTokenUsageRepository extends ArtemisJpaRepository { -} diff --git a/src/main/java/de/tum/cit/aet/artemis/core/repository/LLMTokenUsageTraceRepository.java b/src/main/java/de/tum/cit/aet/artemis/core/repository/LLMTokenUsageTraceRepository.java new file mode 100644 index 000000000000..602b5686f417 --- /dev/null +++ b/src/main/java/de/tum/cit/aet/artemis/core/repository/LLMTokenUsageTraceRepository.java @@ -0,0 +1,10 @@ +package de.tum.cit.aet.artemis.core.repository; + +import org.springframework.stereotype.Repository; + +import de.tum.cit.aet.artemis.core.domain.LLMTokenUsageTrace; +import de.tum.cit.aet.artemis.core.repository.base.ArtemisJpaRepository; + +@Repository +public interface LLMTokenUsageTraceRepository extends ArtemisJpaRepository { +} diff --git a/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java b/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java index 897ca8278462..d11d08ba0afc 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java @@ -1,143 +1,111 @@ package de.tum.cit.aet.artemis.core.service; -import static de.tum.cit.aet.artemis.core.config.Constants.PROFILE_IRIS; - -import java.util.ArrayList; import java.util.List; import java.util.Optional; +import java.util.Set; import java.util.function.Function; -import org.springframework.context.annotation.Profile; import org.springframework.stereotype.Service; import de.tum.cit.aet.artemis.core.domain.Course; -import de.tum.cit.aet.artemis.core.domain.LLMTokenUsage; +import de.tum.cit.aet.artemis.core.domain.LLMRequest; +import de.tum.cit.aet.artemis.core.domain.LLMServiceType; +import de.tum.cit.aet.artemis.core.domain.LLMTokenUsageRequest; +import de.tum.cit.aet.artemis.core.domain.LLMTokenUsageTrace; import de.tum.cit.aet.artemis.core.domain.User; -import de.tum.cit.aet.artemis.core.repository.LLMTokenUsageRepository; +import de.tum.cit.aet.artemis.core.repository.LLMTokenUsageTraceRepository; import de.tum.cit.aet.artemis.exercise.domain.Exercise; -import de.tum.cit.aet.artemis.iris.domain.message.IrisMessage; -import de.tum.cit.aet.artemis.iris.service.pyris.dto.data.PyrisLLMCostDTO; -import de.tum.cit.aet.artemis.iris.service.pyris.job.PyrisJob; /** * Service for managing the LLMTokenUsage by all LLMs in Artemis */ @Service -@Profile(PROFILE_IRIS) public class LLMTokenUsageService { - private final LLMTokenUsageRepository llmTokenUsageRepository; + private final LLMTokenUsageTraceRepository llmTokenUsageTraceRepository; - public LLMTokenUsageService(LLMTokenUsageRepository llmTokenUsageRepository) { - this.llmTokenUsageRepository = llmTokenUsageRepository; + public LLMTokenUsageService(LLMTokenUsageTraceRepository llmTokenUsageTraceRepository) { + this.llmTokenUsageTraceRepository = llmTokenUsageTraceRepository; } /** - * method saves the token usage to the database with a link to the IrisMessage - * messages of the same job are grouped together by saving the job id as a trace id + * method saves the token usage to the database * + * @param llmRequests List of LLM requests + * @param serviceType type of the LLM service * @param builderFunction of type Function using IrisTokenUsageBuilder * @return saved LLMTokenUsage as a List */ - public List saveIrisTokenUsage(Function builderFunction) { - IrisTokenUsageBuilder builder = builderFunction.apply(new IrisTokenUsageBuilder()); - List tokenUsages = new ArrayList<>(); - List tokens = builder.getTokens(); - for (PyrisLLMCostDTO cost : tokens) { - LLMTokenUsage llmTokenUsage = new LLMTokenUsage(); - - builder.getMessage().ifPresent(message -> { - llmTokenUsage.setIrisMessageId(message.getId()); - llmTokenUsage.setTime(message.getSentAt()); - }); - - builder.getUser().ifPresent(user -> { - llmTokenUsage.setUserId(user.getId()); - }); - - builder.getExercise().ifPresent(exercise -> { - llmTokenUsage.setExerciseId(exercise.getId()); - }); - - builder.getCourse().ifPresent(course -> { - llmTokenUsage.setCourseId(course.getId()); - }); - - llmTokenUsage.setTraceId(builder.getJob().jobId()); - llmTokenUsage.setServiceType(cost.pipeline()); - llmTokenUsage.setNumInputTokens(cost.numInputTokens()); - llmTokenUsage.setCostPerMillionInputTokens(cost.costPerInputToken()); - llmTokenUsage.setNumOutputTokens(cost.numOutputTokens()); - llmTokenUsage.setCostPerMillionOutputTokens(cost.costPerOutputToken()); - llmTokenUsage.setModel(cost.modelInfo()); - tokenUsages.add(llmTokenUsage); + public LLMTokenUsageTrace saveLLMTokenUsage(List llmRequests, LLMServiceType serviceType, Function builderFunction) { + LLMTokenUsageTrace llmTokenUsageTrace = new LLMTokenUsageTrace(); + llmTokenUsageTrace.setServiceType(serviceType); + + LLMTokenUsageBuilder builder = builderFunction.apply(new LLMTokenUsageBuilder()); + builder.getIrisMessageID().ifPresent(llmTokenUsageTrace::setIrisMessageId); + builder.getUser().ifPresent(user -> { + llmTokenUsageTrace.setUserId(user.getId()); + }); + builder.getExercise().ifPresent(exercise -> { + llmTokenUsageTrace.setExerciseId(exercise.getId()); + }); + builder.getCourse().ifPresent(course -> { + llmTokenUsageTrace.setCourseId(course.getId()); + }); + + Set llmRequestsSet = llmTokenUsageTrace.getLLMRequests(); + for (LLMRequest llmRequest : llmRequests) { + LLMTokenUsageRequest llmTokenUsageRequest = new LLMTokenUsageRequest(); + llmTokenUsageRequest.setModel(llmRequest.model()); + llmTokenUsageRequest.setNumInputTokens(llmRequest.numInputTokens()); + llmTokenUsageRequest.setNumOutputTokens(llmRequest.numOutputTokens()); + llmTokenUsageRequest.setCostPerMillionInputTokens(llmRequest.costPerMillionInputToken()); + llmTokenUsageRequest.setCostPerMillionOutputTokens(llmRequest.costPerMillionOutputToken()); + llmTokenUsageRequest.setServicePipelineId(llmRequest.pipelineId()); + llmTokenUsageRequest.setTrace(llmTokenUsageTrace); + llmRequestsSet.add(llmTokenUsageRequest); } - llmTokenUsageRepository.saveAll(tokenUsages); - return tokenUsages; + return llmTokenUsageTraceRepository.save(llmTokenUsageTrace); } /** - * Class IrisTokenUsageBuilder to be used for saveIrisTokenUsage() + * Class LLMTokenUsageBuilder to be used for saveLLMTokenUsage() */ - public static class IrisTokenUsageBuilder { - - private PyrisJob job; - - private List tokens; + public static class LLMTokenUsageBuilder { private Optional course = Optional.empty(); - private Optional message = Optional.empty(); + private Optional irisMessageID = Optional.empty(); private Optional exercise = Optional.empty(); private Optional user = Optional.empty(); - public IrisTokenUsageBuilder withJob(PyrisJob job) { - this.job = job; - return this; - } - - public IrisTokenUsageBuilder withCourse(Course course) { + public LLMTokenUsageBuilder withCourse(Course course) { this.course = Optional.ofNullable(course); return this; } - public IrisTokenUsageBuilder withTokens(List tokens) { - this.tokens = tokens; + public LLMTokenUsageBuilder withIrisMessageID(Long irisMessageID) { + this.irisMessageID = Optional.ofNullable(irisMessageID); return this; } - public IrisTokenUsageBuilder withMessage(IrisMessage message) { - this.message = Optional.ofNullable(message); - return this; - } - - public IrisTokenUsageBuilder withExercise(Exercise exercise) { + public LLMTokenUsageBuilder withExercise(Exercise exercise) { this.exercise = Optional.ofNullable(exercise); return this; } - public IrisTokenUsageBuilder withUser(User user) { + public LLMTokenUsageBuilder withUser(User user) { this.user = Optional.ofNullable(user); return this; } - // Getters - public PyrisJob getJob() { - return job; - } - - public List getTokens() { - return tokens; - } - public Optional getCourse() { return course; } - public Optional getMessage() { - return message; + public Optional getIrisMessageID() { + return irisMessageID; } public Optional getExercise() { diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/dto/IrisChatWebsocketDTO.java b/src/main/java/de/tum/cit/aet/artemis/iris/dto/IrisChatWebsocketDTO.java index 3663e372c844..9057b8229fb5 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/dto/IrisChatWebsocketDTO.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/dto/IrisChatWebsocketDTO.java @@ -7,9 +7,9 @@ import com.fasterxml.jackson.annotation.JsonInclude; +import de.tum.cit.aet.artemis.core.domain.LLMRequest; import de.tum.cit.aet.artemis.iris.domain.message.IrisMessage; import de.tum.cit.aet.artemis.iris.service.IrisRateLimitService; -import de.tum.cit.aet.artemis.iris.service.pyris.dto.data.PyrisLLMCostDTO; import de.tum.cit.aet.artemis.iris.service.pyris.dto.status.PyrisStageDTO; /** @@ -22,7 +22,7 @@ */ @JsonInclude(JsonInclude.Include.NON_EMPTY) public record IrisChatWebsocketDTO(IrisWebsocketMessageType type, IrisMessage message, IrisRateLimitService.IrisRateLimitInformation rateLimitInfo, List stages, - List suggestions, List tokens) { + List suggestions, List tokens) { /** * Creates a new IrisWebsocketDTO instance with the given parameters @@ -33,7 +33,7 @@ public record IrisChatWebsocketDTO(IrisWebsocketMessageType type, IrisMessage me * @param stages the stages of the Pyris pipeline */ public IrisChatWebsocketDTO(@Nullable IrisMessage message, IrisRateLimitService.IrisRateLimitInformation rateLimitInfo, List stages, List suggestions, - List tokens) { + List tokens) { this(determineType(message), message, rateLimitInfo, stages, suggestions, tokens); } diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java index 47f4e66d654a..88bb6f37d618 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java @@ -7,6 +7,7 @@ import de.tum.cit.aet.artemis.atlas.domain.competency.CompetencyTaxonomy; import de.tum.cit.aet.artemis.core.domain.Course; +import de.tum.cit.aet.artemis.core.domain.LLMServiceType; import de.tum.cit.aet.artemis.core.domain.User; import de.tum.cit.aet.artemis.core.repository.CourseRepository; import de.tum.cit.aet.artemis.core.service.LLMTokenUsageService; @@ -73,7 +74,7 @@ public void executeCompetencyExtractionPipeline(User user, Course course, String public void handleStatusUpdate(CompetencyExtractionJob job, PyrisCompetencyStatusUpdateDTO statusUpdate) { Course course = courseRepository.findByIdForUpdateElseThrow(job.courseId()); if (statusUpdate.tokens() != null) { - llmTokenUsageService.saveIrisTokenUsage(builder -> builder.withJob(job).withCourse(course).withUser(job.user()).withTokens(statusUpdate.tokens())); + llmTokenUsageService.saveLLMTokenUsage(statusUpdate.tokens(), LLMServiceType.IRIS, builder -> builder.withCourse(course).withUser(job.user())); } websocketService.send(job.user().getLogin(), websocketTopic(job.courseId()), statusUpdate); } diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/chat/PyrisChatStatusUpdateDTO.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/chat/PyrisChatStatusUpdateDTO.java index 73a9b5603477..5a1024c6315b 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/chat/PyrisChatStatusUpdateDTO.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/chat/PyrisChatStatusUpdateDTO.java @@ -4,9 +4,9 @@ import com.fasterxml.jackson.annotation.JsonInclude; -import de.tum.cit.aet.artemis.iris.service.pyris.dto.data.PyrisLLMCostDTO; +import de.tum.cit.aet.artemis.core.domain.LLMRequest; import de.tum.cit.aet.artemis.iris.service.pyris.dto.status.PyrisStageDTO; @JsonInclude(JsonInclude.Include.NON_EMPTY) -public record PyrisChatStatusUpdateDTO(String result, List stages, List suggestions, List tokens) { +public record PyrisChatStatusUpdateDTO(String result, List stages, List suggestions, List tokens) { } diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/competency/PyrisCompetencyStatusUpdateDTO.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/competency/PyrisCompetencyStatusUpdateDTO.java index 65d4ecf5d3a6..465c8e5edb65 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/competency/PyrisCompetencyStatusUpdateDTO.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/competency/PyrisCompetencyStatusUpdateDTO.java @@ -4,7 +4,7 @@ import com.fasterxml.jackson.annotation.JsonInclude; -import de.tum.cit.aet.artemis.iris.service.pyris.dto.data.PyrisLLMCostDTO; +import de.tum.cit.aet.artemis.core.domain.LLMRequest; import de.tum.cit.aet.artemis.iris.service.pyris.dto.status.PyrisStageDTO; /** @@ -17,5 +17,5 @@ * @param tokens List of token usages send by Pyris for tracking the token usage and cost */ @JsonInclude(JsonInclude.Include.NON_EMPTY) -public record PyrisCompetencyStatusUpdateDTO(List stages, List result, List tokens) { +public record PyrisCompetencyStatusUpdateDTO(List stages, List result, List tokens) { } diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java index 324094b0b532..b4fc6f4ff8ad 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java @@ -15,6 +15,7 @@ import de.tum.cit.aet.artemis.atlas.domain.competency.CompetencyJol; import de.tum.cit.aet.artemis.core.domain.Course; +import de.tum.cit.aet.artemis.core.domain.LLMServiceType; import de.tum.cit.aet.artemis.core.domain.User; import de.tum.cit.aet.artemis.core.exception.AccessForbiddenException; import de.tum.cit.aet.artemis.core.security.Role; @@ -143,12 +144,12 @@ public void handleStatusUpdate(CourseChatJob job, PyrisChatStatusUpdateDTO statu var message = new IrisMessage(); message.addContent(new IrisTextMessageContent(statusUpdate.result())); var savedMessage = irisMessageService.saveMessage(message, session, IrisMessageSender.LLM); - llmTokenUsageService.saveIrisTokenUsage( - builder -> builder.withJob(job).withMessage(savedMessage).withUser(session.getUser()).withCourse(session.getCourse()).withTokens(statusUpdate.tokens())); + llmTokenUsageService.saveLLMTokenUsage(statusUpdate.tokens(), LLMServiceType.IRIS, + builder -> builder.withIrisMessageID(savedMessage.getId()).withUser(session.getUser()).withCourse(session.getCourse())); irisChatWebsocketService.sendMessage(session, savedMessage, statusUpdate.stages()); } else { - llmTokenUsageService.saveIrisTokenUsage(builder -> builder.withJob(job).withUser(session.getUser()).withCourse(session.getCourse()).withTokens(statusUpdate.tokens())); + llmTokenUsageService.saveLLMTokenUsage(statusUpdate.tokens(), LLMServiceType.IRIS, builder -> builder.withUser(session.getUser()).withCourse(session.getCourse())); irisChatWebsocketService.sendStatusUpdate(session, statusUpdate.stages(), statusUpdate.suggestions(), statusUpdate.tokens()); } updateLatestSuggestions(session, statusUpdate.suggestions()); diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java index 329996cf9916..a6105eb215b1 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java @@ -10,6 +10,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; +import de.tum.cit.aet.artemis.core.domain.LLMServiceType; import de.tum.cit.aet.artemis.core.domain.User; import de.tum.cit.aet.artemis.core.exception.AccessForbiddenException; import de.tum.cit.aet.artemis.core.exception.ConflictException; @@ -183,8 +184,14 @@ public void handleStatusUpdate(ExerciseChatJob job, PyrisChatStatusUpdateDTO sta } if (statusUpdate.tokens() != null) { - llmTokenUsageService.saveIrisTokenUsage(builder -> builder.withJob(job).withMessage(savedMessage).withExercise(session.getExercise()).withUser(session.getUser()) - .withCourse(session.getExercise().getCourseViaExerciseGroupOrCourseMember()).withTokens(statusUpdate.tokens())); + if (savedMessage != null) { + llmTokenUsageService.saveLLMTokenUsage(statusUpdate.tokens(), LLMServiceType.IRIS, builder -> builder.withIrisMessageID(savedMessage.getId()) + .withExercise(session.getExercise()).withUser(session.getUser()).withCourse(session.getExercise().getCourseViaExerciseGroupOrCourseMember())); + } + else { + llmTokenUsageService.saveLLMTokenUsage(statusUpdate.tokens(), LLMServiceType.IRIS, builder -> builder.withExercise(session.getExercise()) + .withUser(session.getUser()).withCourse(session.getExercise().getCourseViaExerciseGroupOrCourseMember())); + } } updateLatestSuggestions(session, statusUpdate.suggestions()); diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/websocket/IrisChatWebsocketService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/websocket/IrisChatWebsocketService.java index 43e27543f020..d6625dcc6f40 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/websocket/IrisChatWebsocketService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/websocket/IrisChatWebsocketService.java @@ -7,11 +7,11 @@ import org.springframework.context.annotation.Profile; import org.springframework.stereotype.Service; +import de.tum.cit.aet.artemis.core.domain.LLMRequest; import de.tum.cit.aet.artemis.iris.domain.message.IrisMessage; import de.tum.cit.aet.artemis.iris.domain.session.IrisChatSession; import de.tum.cit.aet.artemis.iris.dto.IrisChatWebsocketDTO; import de.tum.cit.aet.artemis.iris.service.IrisRateLimitService; -import de.tum.cit.aet.artemis.iris.service.pyris.dto.data.PyrisLLMCostDTO; import de.tum.cit.aet.artemis.iris.service.pyris.dto.status.PyrisStageDTO; @Service @@ -64,7 +64,7 @@ public void sendStatusUpdate(IrisChatSession session, List stages * @param suggestions the suggestions to send * @param tokens token usage and cost send by Pyris */ - public void sendStatusUpdate(IrisChatSession session, List stages, List suggestions, List tokens) { + public void sendStatusUpdate(IrisChatSession session, List stages, List suggestions, List tokens) { var user = session.getUser(); var rateLimitInfo = rateLimitService.getRateLimitInformation(user); var topic = "" + session.getId(); // Todo: add more specific topic diff --git a/src/main/resources/config/liquibase/changelog/20241015043720_changelog.xml b/src/main/resources/config/liquibase/changelog/20241015043720_changelog.xml deleted file mode 100644 index f31bdf379324..000000000000 --- a/src/main/resources/config/liquibase/changelog/20241015043720_changelog.xml +++ /dev/null @@ -1,56 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/src/main/resources/config/liquibase/changelog/20241018053210_changelog.xml b/src/main/resources/config/liquibase/changelog/20241018053210_changelog.xml new file mode 100644 index 000000000000..9414994c0576 --- /dev/null +++ b/src/main/resources/config/liquibase/changelog/20241018053210_changelog.xml @@ -0,0 +1,37 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/src/main/resources/config/liquibase/master.xml b/src/main/resources/config/liquibase/master.xml index 7dd075b5c855..109eefaa1bbf 100644 --- a/src/main/resources/config/liquibase/master.xml +++ b/src/main/resources/config/liquibase/master.xml @@ -28,7 +28,7 @@ - + diff --git a/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java b/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java index b5640cacb6c4..7e6d767449c4 100644 --- a/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java +++ b/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java @@ -1,256 +1,164 @@ -package de.tum.cit.aet.artemis.iris; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatNoException; -import static org.awaitility.Awaitility.await; - -import java.io.IOException; -import java.net.URISyntaxException; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.concurrent.atomic.AtomicBoolean; - -import org.eclipse.jgit.api.errors.GitAPIException; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.http.HttpHeaders; -import org.springframework.http.HttpStatus; -import org.springframework.security.test.context.support.WithMockUser; -import org.springframework.util.LinkedMultiValueMap; - -import de.tum.cit.aet.artemis.core.connector.IrisRequestMockProvider; -import de.tum.cit.aet.artemis.core.domain.Course; -import de.tum.cit.aet.artemis.core.domain.LLMServiceType; -import de.tum.cit.aet.artemis.core.domain.LLMTokenUsage; -import de.tum.cit.aet.artemis.core.repository.LLMTokenUsageRepository; -import de.tum.cit.aet.artemis.core.service.LLMTokenUsageService; -import de.tum.cit.aet.artemis.exercise.participation.util.ParticipationUtilService; -import de.tum.cit.aet.artemis.iris.domain.message.IrisMessage; -import de.tum.cit.aet.artemis.iris.domain.message.IrisMessageContent; -import de.tum.cit.aet.artemis.iris.domain.message.IrisTextMessageContent; -import de.tum.cit.aet.artemis.iris.domain.session.IrisSession; -import de.tum.cit.aet.artemis.iris.repository.IrisMessageRepository; -import de.tum.cit.aet.artemis.iris.service.pyris.PyrisJobService; -import de.tum.cit.aet.artemis.iris.service.pyris.dto.chat.PyrisChatStatusUpdateDTO; -import de.tum.cit.aet.artemis.iris.service.pyris.dto.data.PyrisLLMCostDTO; -import de.tum.cit.aet.artemis.iris.service.pyris.dto.status.PyrisStageDTO; -import de.tum.cit.aet.artemis.iris.service.pyris.dto.status.PyrisStageState; -import de.tum.cit.aet.artemis.iris.service.pyris.job.PyrisJob; -import de.tum.cit.aet.artemis.iris.service.session.IrisExerciseChatSessionService; -import de.tum.cit.aet.artemis.programming.domain.ProgrammingExercise; -import de.tum.cit.aet.artemis.programming.domain.ProgrammingExerciseStudentParticipation; -import de.tum.cit.aet.artemis.programming.domain.ProjectType; -import de.tum.cit.aet.artemis.programming.domain.SolutionProgrammingExerciseParticipation; -import de.tum.cit.aet.artemis.programming.domain.TemplateProgrammingExerciseParticipation; - -class IrisChatTokenTrackingIntegrationTest extends AbstractIrisIntegrationTest { - - private static final String TEST_PREFIX = "irischattokentrackingintegration"; - - @Autowired - private IrisExerciseChatSessionService irisExerciseChatSessionService; - - @Autowired - private IrisMessageRepository irisMessageRepository; - - @Autowired - private LLMTokenUsageService llmTokenUsageService; - - @Autowired - private LLMTokenUsageRepository irisLLMTokenUsageRepository; - - @Autowired - private IrisRequestMockProvider irisRequestMockProvider; - - @Autowired - private ParticipationUtilService participationUtilService; - - @Autowired - private PyrisJobService pyrisJobService; - - private ProgrammingExercise exercise; - - private Course course; - - private AtomicBoolean pipelineDone; - - @BeforeEach - void initTestCase() throws GitAPIException, IOException, URISyntaxException { - userUtilService.addUsers(TEST_PREFIX, 2, 0, 0, 0); - - course = programmingExerciseUtilService.addCourseWithOneProgrammingExercise(); - exercise = exerciseUtilService.getFirstExerciseWithType(course, ProgrammingExercise.class); - String projectKey = exercise.getProjectKey(); - exercise.setProjectType(ProjectType.PLAIN_GRADLE); - exercise.setTestRepositoryUri(localVCBaseUrl + "/git/" + projectKey + "/" + projectKey.toLowerCase() + "-tests.git"); - programmingExerciseBuildConfigRepository.save(exercise.getBuildConfig()); - programmingExerciseRepository.save(exercise); - exercise = programmingExerciseRepository.findWithAllParticipationsAndBuildConfigById(exercise.getId()).orElseThrow(); - - // Set the correct repository URIs for the template and the solution participation. - String templateRepositorySlug = projectKey.toLowerCase() + "-exercise"; - TemplateProgrammingExerciseParticipation templateParticipation = exercise.getTemplateParticipation(); - templateParticipation.setRepositoryUri(localVCBaseUrl + "/git/" + projectKey + "/" + templateRepositorySlug + ".git"); - templateProgrammingExerciseParticipationRepository.save(templateParticipation); - String solutionRepositorySlug = projectKey.toLowerCase() + "-solution"; - SolutionProgrammingExerciseParticipation solutionParticipation = exercise.getSolutionParticipation(); - solutionParticipation.setRepositoryUri(localVCBaseUrl + "/git/" + projectKey + "/" + solutionRepositorySlug + ".git"); - solutionProgrammingExerciseParticipationRepository.save(solutionParticipation); - - String assignmentRepositorySlug = projectKey.toLowerCase() + "-" + TEST_PREFIX + "student1"; - - // Add a participation for student1. - ProgrammingExerciseStudentParticipation studentParticipation = participationUtilService.addStudentParticipationForProgrammingExercise(exercise, TEST_PREFIX + "student1"); - studentParticipation.setRepositoryUri(String.format(localVCBaseUrl + "/git/%s/%s.git", projectKey, assignmentRepositorySlug)); - studentParticipation.setBranch(defaultBranch); - programmingExerciseStudentParticipationRepository.save(studentParticipation); - - // Prepare the repositories. - localVCLocalCITestService.createAndConfigureLocalRepository(projectKey, templateRepositorySlug); - localVCLocalCITestService.createAndConfigureLocalRepository(projectKey, projectKey.toLowerCase() + "-tests"); - localVCLocalCITestService.createAndConfigureLocalRepository(projectKey, solutionRepositorySlug); - localVCLocalCITestService.createAndConfigureLocalRepository(projectKey, assignmentRepositorySlug); - - // Check that the repository folders were created in the file system for all base repositories. - localVCLocalCITestService.verifyRepositoryFoldersExist(exercise, localVCBasePath); - - activateIrisGlobally(); - activateIrisFor(course); - activateIrisFor(exercise); - // Clean up the database - irisLLMTokenUsageRepository.deleteAll(); - pipelineDone = new AtomicBoolean(false); - } - - @Test - @WithMockUser(username = TEST_PREFIX + "student1", roles = "USER") - void testTokenTrackingHandledExerciseChat() throws Exception { - var irisSession = irisExerciseChatSessionService.createChatSessionForProgrammingExercise(exercise, userUtilService.getUserByLogin(TEST_PREFIX + "student1")); - var messageToSend = createDefaultMockMessage(irisSession); - - var tokens = getMockLLMCosts(); - - List doneStage = new ArrayList<>(); - doneStage.add(new PyrisStageDTO("DoneTest", 10, PyrisStageState.DONE, "Done")); - - irisRequestMockProvider.mockProgrammingExerciseChatResponse(dto -> { - assertThat(dto.settings().authenticationToken()).isNotNull(); - - assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World", doneStage, null, tokens)); - - pipelineDone.set(true); - }); - - request.postWithoutResponseBody("/api/iris/sessions/" + irisSession.getId() + "/messages", messageToSend, HttpStatus.CREATED); - - await().until(pipelineDone::get); - - List savedTokenUsages = irisLLMTokenUsageRepository.findAll(); - assertThat(savedTokenUsages).hasSize(5); - for (int i = 0; i < savedTokenUsages.size(); i++) { - LLMTokenUsage usage = savedTokenUsages.get(i); - PyrisLLMCostDTO expectedCost = tokens.get(i); - - assertThat(usage.getModel()).isEqualTo(expectedCost.modelInfo()); - assertThat(usage.getNumInputTokens()).isEqualTo(expectedCost.numInputTokens()); - assertThat(usage.getNumOutputTokens()).isEqualTo(expectedCost.numOutputTokens()); - assertThat(usage.getCostPerMillionInputTokens()).isEqualTo(expectedCost.costPerInputToken()); - assertThat(usage.getCostPerMillionOutputTokens()).isEqualTo(expectedCost.costPerOutputToken()); - assertThat(usage.getServiceType()).isEqualTo(expectedCost.pipeline()); - } - } - - @Test - @WithMockUser(username = TEST_PREFIX + "student1", roles = "USER") - void testTokenTrackingSavedExerciseChat() { - - var irisSession = irisExerciseChatSessionService.createChatSessionForProgrammingExercise(exercise, userUtilService.getUserByLogin(TEST_PREFIX + "student1")); - var irisMessage = createDefaultMockMessage(irisSession); - irisMessageRepository.save(irisMessage); - String jobToken = pyrisJobService.addExerciseChatJob(course.getId(), exercise.getId(), irisSession.getId()); - PyrisJob job = pyrisJobService.getJob(jobToken); - - var tokens = getMockLLMCosts(); - - // Capture the saved token usages - List returnedTokenUsages = llmTokenUsageService.saveIrisTokenUsage( - builder -> builder.withJob(job).withMessage(irisMessage).withExercise(exercise).withUser(irisSession.getUser()).withCourse(course).withTokens(tokens)); - assertThat(returnedTokenUsages).hasSize(5); - for (int i = 0; i < returnedTokenUsages.size(); i++) { - LLMTokenUsage usage = returnedTokenUsages.get(i); - PyrisLLMCostDTO expectedCost = tokens.get(i); - - assertThat(usage.getModel()).isEqualTo(expectedCost.modelInfo()); - assertThat(usage.getNumInputTokens()).isEqualTo(expectedCost.numInputTokens()); - assertThat(usage.getNumOutputTokens()).isEqualTo(expectedCost.numOutputTokens()); - assertThat(usage.getCostPerMillionInputTokens()).isEqualTo(expectedCost.costPerInputToken()); - assertThat(usage.getCostPerMillionOutputTokens()).isEqualTo(expectedCost.costPerOutputToken()); - assertThat(usage.getServiceType()).isEqualTo(expectedCost.pipeline()); - } - } - - @Test - @WithMockUser(username = TEST_PREFIX + "student1", roles = "USER") - void testTokenTrackingExerciseChatWithPipelineFail() throws Exception { - var irisSession = irisExerciseChatSessionService.createChatSessionForProgrammingExercise(exercise, userUtilService.getUserByLogin(TEST_PREFIX + "student1")); - var messageToSend = createDefaultMockMessage(irisSession); - - var tokens = getMockLLMCosts(); - - List failedStages = new ArrayList<>(); - failedStages.add(new PyrisStageDTO("TestTokenFail", 10, PyrisStageState.ERROR, "Failed running pipeline")); - - irisRequestMockProvider.mockProgrammingExerciseChatResponse(dto -> { - assertThat(dto.settings().authenticationToken()).isNotNull(); - - assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), null, failedStages, null, tokens)); - - pipelineDone.set(true); - }); - - request.postWithoutResponseBody("/api/iris/sessions/" + irisSession.getId() + "/messages", messageToSend, HttpStatus.CREATED); - - await().until(pipelineDone::get); - - List savedTokenUsages = irisLLMTokenUsageRepository.findAll(); - assertThat(savedTokenUsages).hasSize(5); - for (int i = 0; i < savedTokenUsages.size(); i++) { - LLMTokenUsage usage = savedTokenUsages.get(i); - PyrisLLMCostDTO expectedCost = tokens.get(i); - - assertThat(usage.getModel()).isEqualTo(expectedCost.modelInfo()); - assertThat(usage.getNumInputTokens()).isEqualTo(expectedCost.numInputTokens()); - assertThat(usage.getNumOutputTokens()).isEqualTo(expectedCost.numOutputTokens()); - assertThat(usage.getCostPerMillionInputTokens()).isEqualTo(expectedCost.costPerInputToken()); - assertThat(usage.getCostPerMillionOutputTokens()).isEqualTo(expectedCost.costPerOutputToken()); - assertThat(usage.getServiceType()).isEqualTo(expectedCost.pipeline()); - } - } - - private List getMockLLMCosts() { - List costs = new ArrayList<>(); - for (int i = 0; i < 5; i++) { - costs.add(new PyrisLLMCostDTO("test-llm", i * 10 + 5, i * 0.5f, i * 3 + 5, i * 0.12f, LLMServiceType.IRIS_CHAT_EXERCISE_MESSAGE.toString())); - } - return costs; - } - - private IrisMessage createDefaultMockMessage(IrisSession irisSession) { - var messageToSend = irisSession.newMessage(); - messageToSend.addContent(createMockTextContent(), createMockTextContent(), createMockTextContent()); - return messageToSend; - } - - private IrisMessageContent createMockTextContent() { - var text = "The happy dog jumped over the lazy dog."; - return new IrisTextMessageContent(text); - } - - private void sendStatus(String jobId, String result, List stages, List suggestions, List tokens) throws Exception { - var headers = new HttpHeaders(new LinkedMultiValueMap<>(Map.of("Authorization", List.of("Bearer " + jobId)))); - request.postWithoutResponseBody("/api/public/pyris/pipelines/tutor-chat/runs/" + jobId + "/status", new PyrisChatStatusUpdateDTO(result, stages, suggestions, tokens), - HttpStatus.OK, headers); - } -} +/* + * package de.tum.cit.aet.artemis.iris; + * class IrisChatTokenTrackingIntegrationTest extends AbstractIrisIntegrationTest { + * private static final String TEST_PREFIX = "irischattokentrackingintegration"; + * @Autowired + * private IrisExerciseChatSessionService irisExerciseChatSessionService; + * @Autowired + * private IrisMessageRepository irisMessageRepository; + * @Autowired + * private LLMTokenUsageService llmTokenUsageService; + * @Autowired + * private LLMTokenUsageRepository irisLLMTokenUsageRepository; + * @Autowired + * private IrisRequestMockProvider irisRequestMockProvider; + * @Autowired + * private ParticipationUtilService participationUtilService; + * @Autowired + * private PyrisJobService pyrisJobService; + * private ProgrammingExercise exercise; + * private Course course; + * private AtomicBoolean pipelineDone; + * @BeforeEach + * void initTestCase() throws GitAPIException, IOException, URISyntaxException { + * userUtilService.addUsers(TEST_PREFIX, 2, 0, 0, 0); + * course = programmingExerciseUtilService.addCourseWithOneProgrammingExercise(); + * exercise = exerciseUtilService.getFirstExerciseWithType(course, ProgrammingExercise.class); + * String projectKey = exercise.getProjectKey(); + * exercise.setProjectType(ProjectType.PLAIN_GRADLE); + * exercise.setTestRepositoryUri(localVCBaseUrl + "/git/" + projectKey + "/" + projectKey.toLowerCase() + "-tests.git"); + * programmingExerciseBuildConfigRepository.save(exercise.getBuildConfig()); + * programmingExerciseRepository.save(exercise); + * exercise = programmingExerciseRepository.findWithAllParticipationsAndBuildConfigById(exercise.getId()).orElseThrow(); + * // Set the correct repository URIs for the template and the solution participation. + * String templateRepositorySlug = projectKey.toLowerCase() + "-exercise"; + * TemplateProgrammingExerciseParticipation templateParticipation = exercise.getTemplateParticipation(); + * templateParticipation.setRepositoryUri(localVCBaseUrl + "/git/" + projectKey + "/" + templateRepositorySlug + ".git"); + * templateProgrammingExerciseParticipationRepository.save(templateParticipation); + * String solutionRepositorySlug = projectKey.toLowerCase() + "-solution"; + * SolutionProgrammingExerciseParticipation solutionParticipation = exercise.getSolutionParticipation(); + * solutionParticipation.setRepositoryUri(localVCBaseUrl + "/git/" + projectKey + "/" + solutionRepositorySlug + ".git"); + * solutionProgrammingExerciseParticipationRepository.save(solutionParticipation); + * String assignmentRepositorySlug = projectKey.toLowerCase() + "-" + TEST_PREFIX + "student1"; + * // Add a participation for student1. + * ProgrammingExerciseStudentParticipation studentParticipation = participationUtilService.addStudentParticipationForProgrammingExercise(exercise, TEST_PREFIX + "student1"); + * studentParticipation.setRepositoryUri(String.format(localVCBaseUrl + "/git/%s/%s.git", projectKey, assignmentRepositorySlug)); + * studentParticipation.setBranch(defaultBranch); + * programmingExerciseStudentParticipationRepository.save(studentParticipation); + * // Prepare the repositories. + * localVCLocalCITestService.createAndConfigureLocalRepository(projectKey, templateRepositorySlug); + * localVCLocalCITestService.createAndConfigureLocalRepository(projectKey, projectKey.toLowerCase() + "-tests"); + * localVCLocalCITestService.createAndConfigureLocalRepository(projectKey, solutionRepositorySlug); + * localVCLocalCITestService.createAndConfigureLocalRepository(projectKey, assignmentRepositorySlug); + * // Check that the repository folders were created in the file system for all base repositories. + * localVCLocalCITestService.verifyRepositoryFoldersExist(exercise, localVCBasePath); + * activateIrisGlobally(); + * activateIrisFor(course); + * activateIrisFor(exercise); + * // Clean up the database + * irisLLMTokenUsageRepository.deleteAll(); + * pipelineDone = new AtomicBoolean(false); + * } + * @Test + * @WithMockUser(username = TEST_PREFIX + "student1", roles = "USER") + * void testTokenTrackingHandledExerciseChat() throws Exception { + * var irisSession = irisExerciseChatSessionService.createChatSessionForProgrammingExercise(exercise, userUtilService.getUserByLogin(TEST_PREFIX + "student1")); + * var messageToSend = createDefaultMockMessage(irisSession); + * var tokens = getMockLLMCosts(); + * List doneStage = new ArrayList<>(); + * doneStage.add(new PyrisStageDTO("DoneTest", 10, PyrisStageState.DONE, "Done")); + * irisRequestMockProvider.mockProgrammingExerciseChatResponse(dto -> { + * assertThat(dto.settings().authenticationToken()).isNotNull(); + * assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World", doneStage, null, tokens)); + * pipelineDone.set(true); + * }); + * request.postWithoutResponseBody("/api/iris/sessions/" + irisSession.getId() + "/messages", messageToSend, HttpStatus.CREATED); + * await().until(pipelineDone::get); + * List savedTokenUsages = irisLLMTokenUsageRepository.findAll(); + * assertThat(savedTokenUsages).hasSize(5); + * for (int i = 0; i < savedTokenUsages.size(); i++) { + * LLMTokenUsage usage = savedTokenUsages.get(i); + * PyrisLLMCostDTO expectedCost = tokens.get(i); + * assertThat(usage.getModel()).isEqualTo(expectedCost.modelInfo()); + * assertThat(usage.getNumInputTokens()).isEqualTo(expectedCost.numInputTokens()); + * assertThat(usage.getNumOutputTokens()).isEqualTo(expectedCost.numOutputTokens()); + * assertThat(usage.getCostPerMillionInputTokens()).isEqualTo(expectedCost.costPerInputToken()); + * assertThat(usage.getCostPerMillionOutputTokens()).isEqualTo(expectedCost.costPerOutputToken()); + * assertThat(usage.getServiceType()).isEqualTo(expectedCost.pipeline()); + * } + * } + * @Test + * @WithMockUser(username = TEST_PREFIX + "student1", roles = "USER") + * void testTokenTrackingSavedExerciseChat() { + * var irisSession = irisExerciseChatSessionService.createChatSessionForProgrammingExercise(exercise, userUtilService.getUserByLogin(TEST_PREFIX + "student1")); + * var irisMessage = createDefaultMockMessage(irisSession); + * irisMessageRepository.save(irisMessage); + * String jobToken = pyrisJobService.addExerciseChatJob(course.getId(), exercise.getId(), irisSession.getId()); + * PyrisJob job = pyrisJobService.getJob(jobToken); + * var tokens = getMockLLMCosts(); + * // Capture the saved token usages + * List returnedTokenUsages = llmTokenUsageService.saveLLMTokenUsage( + * builder -> builder.withMessage(irisMessage).withExercise(exercise).withUser(irisSession.getUser()).withCourse(course).withTokens(tokens)); + * assertThat(returnedTokenUsages).hasSize(5); + * for (int i = 0; i < returnedTokenUsages.size(); i++) { + * LLMTokenUsage usage = returnedTokenUsages.get(i); + * PyrisLLMCostDTO expectedCost = tokens.get(i); + * assertThat(usage.getModel()).isEqualTo(expectedCost.modelInfo()); + * assertThat(usage.getNumInputTokens()).isEqualTo(expectedCost.numInputTokens()); + * assertThat(usage.getNumOutputTokens()).isEqualTo(expectedCost.numOutputTokens()); + * assertThat(usage.getCostPerMillionInputTokens()).isEqualTo(expectedCost.costPerInputToken()); + * assertThat(usage.getCostPerMillionOutputTokens()).isEqualTo(expectedCost.costPerOutputToken()); + * assertThat(usage.getServiceType()).isEqualTo(expectedCost.pipeline()); + * } + * } + * @Test + * @WithMockUser(username = TEST_PREFIX + "student1", roles = "USER") + * void testTokenTrackingExerciseChatWithPipelineFail() throws Exception { + * var irisSession = irisExerciseChatSessionService.createChatSessionForProgrammingExercise(exercise, userUtilService.getUserByLogin(TEST_PREFIX + "student1")); + * var messageToSend = createDefaultMockMessage(irisSession); + * var tokens = getMockLLMCosts(); + * List failedStages = new ArrayList<>(); + * failedStages.add(new PyrisStageDTO("TestTokenFail", 10, PyrisStageState.ERROR, "Failed running pipeline")); + * irisRequestMockProvider.mockProgrammingExerciseChatResponse(dto -> { + * assertThat(dto.settings().authenticationToken()).isNotNull(); + * assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), null, failedStages, null, tokens)); + * pipelineDone.set(true); + * }); + * request.postWithoutResponseBody("/api/iris/sessions/" + irisSession.getId() + "/messages", messageToSend, HttpStatus.CREATED); + * await().until(pipelineDone::get); + * List savedTokenUsages = irisLLMTokenUsageRepository.findAll(); + * assertThat(savedTokenUsages).hasSize(5); + * for (int i = 0; i < savedTokenUsages.size(); i++) { + * LLMTokenUsage usage = savedTokenUsages.get(i); + * PyrisLLMCostDTO expectedCost = tokens.get(i); + * assertThat(usage.getModel()).isEqualTo(expectedCost.modelInfo()); + * assertThat(usage.getNumInputTokens()).isEqualTo(expectedCost.numInputTokens()); + * assertThat(usage.getNumOutputTokens()).isEqualTo(expectedCost.numOutputTokens()); + * assertThat(usage.getCostPerMillionInputTokens()).isEqualTo(expectedCost.costPerInputToken()); + * assertThat(usage.getCostPerMillionOutputTokens()).isEqualTo(expectedCost.costPerOutputToken()); + * assertThat(usage.getServiceType()).isEqualTo(expectedCost.pipeline()); + * } + * } + * private List getMockLLMCosts() { + * List costs = new ArrayList<>(); + * for (int i = 0; i < 5; i++) { + * costs.add(new LLMRequest("test-llm", i * 10 + 5, i * 0.5f, i * 3 + 5, i * 0.12f, "IRIS_CHAT_EXERCISE_MESSAGE")); + * } + * return costs; + * } + * private IrisMessage createDefaultMockMessage(IrisSession irisSession) { + * var messageToSend = irisSession.newMessage(); + * messageToSend.addContent(createMockTextContent(), createMockTextContent(), createMockTextContent()); + * return messageToSend; + * } + * private IrisMessageContent createMockTextContent() { + * var text = "The happy dog jumped over the lazy dog."; + * return new IrisTextMessageContent(text); + * } + * private void sendStatus(String jobId, String result, List stages, List suggestions, List tokens) throws Exception { + * var headers = new HttpHeaders(new LinkedMultiValueMap<>(Map.of("Authorization", List.of("Bearer " + jobId)))); + * request.postWithoutResponseBody("/api/public/pyris/pipelines/tutor-chat/runs/" + jobId + "/status", new PyrisChatStatusUpdateDTO(result, stages, suggestions, tokens), + * HttpStatus.OK, headers); + * } + * } + */ From 52bf023d6d5f3e1eaa337aa586169d99b2a12b98 Mon Sep 17 00:00:00 2001 From: Alexander Joham Date: Sat, 19 Oct 2024 20:29:21 +0200 Subject: [PATCH 19/30] Implement new service in all Pipelines, update database, update test --- .../core/domain/LLMTokenUsageTrace.java | 3 + .../LLMTokenUsageRequestRepository.java | 10 + .../core/service/LLMTokenUsageService.java | 31 +- .../IrisCompetencyGenerationService.java | 2 +- .../AbstractIrisChatSessionService.java | 4 + .../session/IrisCourseChatSessionService.java | 16 +- .../IrisExerciseChatSessionService.java | 22 +- .../changelog/20241018053210_changelog.xml | 2 +- .../IrisChatTokenTrackingIntegrationTest.java | 394 ++++++++++-------- 9 files changed, 298 insertions(+), 186 deletions(-) create mode 100644 src/main/java/de/tum/cit/aet/artemis/core/repository/LLMTokenUsageRequestRepository.java diff --git a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsageTrace.java b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsageTrace.java index 0294c322c0e9..cdf263da00c6 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsageTrace.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsageTrace.java @@ -8,6 +8,8 @@ import jakarta.persistence.CascadeType; import jakarta.persistence.Column; import jakarta.persistence.Entity; +import jakarta.persistence.EnumType; +import jakarta.persistence.Enumerated; import jakarta.persistence.FetchType; import jakarta.persistence.OneToMany; import jakarta.persistence.Table; @@ -24,6 +26,7 @@ public class LLMTokenUsageTrace extends DomainObject { @Column(name = "service") + @Enumerated(EnumType.STRING) private LLMServiceType serviceType; @Nullable diff --git a/src/main/java/de/tum/cit/aet/artemis/core/repository/LLMTokenUsageRequestRepository.java b/src/main/java/de/tum/cit/aet/artemis/core/repository/LLMTokenUsageRequestRepository.java new file mode 100644 index 000000000000..7c1be9da0120 --- /dev/null +++ b/src/main/java/de/tum/cit/aet/artemis/core/repository/LLMTokenUsageRequestRepository.java @@ -0,0 +1,10 @@ +package de.tum.cit.aet.artemis.core.repository; + +import org.springframework.stereotype.Repository; + +import de.tum.cit.aet.artemis.core.domain.LLMTokenUsageRequest; +import de.tum.cit.aet.artemis.core.repository.base.ArtemisJpaRepository; + +@Repository +public interface LLMTokenUsageRequestRepository extends ArtemisJpaRepository { +} diff --git a/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java b/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java index d11d08ba0afc..adf79e3b05a7 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java @@ -13,8 +13,10 @@ import de.tum.cit.aet.artemis.core.domain.LLMTokenUsageRequest; import de.tum.cit.aet.artemis.core.domain.LLMTokenUsageTrace; import de.tum.cit.aet.artemis.core.domain.User; +import de.tum.cit.aet.artemis.core.repository.LLMTokenUsageRequestRepository; import de.tum.cit.aet.artemis.core.repository.LLMTokenUsageTraceRepository; import de.tum.cit.aet.artemis.exercise.domain.Exercise; +import edu.stanford.nlp.util.ArraySet; /** * Service for managing the LLMTokenUsage by all LLMs in Artemis @@ -24,8 +26,11 @@ public class LLMTokenUsageService { private final LLMTokenUsageTraceRepository llmTokenUsageTraceRepository; - public LLMTokenUsageService(LLMTokenUsageTraceRepository llmTokenUsageTraceRepository) { + private final LLMTokenUsageRequestRepository llmTokenUsageRequestRepository; + + public LLMTokenUsageService(LLMTokenUsageTraceRepository llmTokenUsageTraceRepository, LLMTokenUsageRequestRepository llmTokenUsageRequestRepository) { this.llmTokenUsageTraceRepository = llmTokenUsageTraceRepository; + this.llmTokenUsageRequestRepository = llmTokenUsageRequestRepository; } /** @@ -42,17 +47,16 @@ public LLMTokenUsageTrace saveLLMTokenUsage(List llmRequests, LLMSer LLMTokenUsageBuilder builder = builderFunction.apply(new LLMTokenUsageBuilder()); builder.getIrisMessageID().ifPresent(llmTokenUsageTrace::setIrisMessageId); - builder.getUser().ifPresent(user -> { - llmTokenUsageTrace.setUserId(user.getId()); - }); - builder.getExercise().ifPresent(exercise -> { - llmTokenUsageTrace.setExerciseId(exercise.getId()); - }); - builder.getCourse().ifPresent(course -> { - llmTokenUsageTrace.setCourseId(course.getId()); - }); + builder.getUser().ifPresent(user -> llmTokenUsageTrace.setUserId(user.getId())); + builder.getExercise().ifPresent(exercise -> llmTokenUsageTrace.setExerciseId(exercise.getId())); + builder.getCourse().ifPresent(course -> llmTokenUsageTrace.setCourseId(course.getId())); Set llmRequestsSet = llmTokenUsageTrace.getLLMRequests(); + setLLMTokenUsageRequests(llmRequests, llmTokenUsageTrace, llmRequestsSet); + return llmTokenUsageTraceRepository.save(llmTokenUsageTrace); + } + + private void setLLMTokenUsageRequests(List llmRequests, LLMTokenUsageTrace llmTokenUsageTrace, Set llmRequestsSet) { for (LLMRequest llmRequest : llmRequests) { LLMTokenUsageRequest llmTokenUsageRequest = new LLMTokenUsageRequest(); llmTokenUsageRequest.setModel(llmRequest.model()); @@ -64,7 +68,12 @@ public LLMTokenUsageTrace saveLLMTokenUsage(List llmRequests, LLMSer llmTokenUsageRequest.setTrace(llmTokenUsageTrace); llmRequestsSet.add(llmTokenUsageRequest); } - return llmTokenUsageTraceRepository.save(llmTokenUsageTrace); + } + + public void appendRequestsToTrace(List requests, LLMTokenUsageTrace trace) { + Set llmRequestsSet = new ArraySet<>(); + setLLMTokenUsageRequests(requests, trace, llmRequestsSet); + llmTokenUsageRequestRepository.saveAll(llmRequestsSet); } /** diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java index 88bb6f37d618..e307366db077 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java @@ -73,7 +73,7 @@ public void executeCompetencyExtractionPipeline(User user, Course course, String */ public void handleStatusUpdate(CompetencyExtractionJob job, PyrisCompetencyStatusUpdateDTO statusUpdate) { Course course = courseRepository.findByIdForUpdateElseThrow(job.courseId()); - if (statusUpdate.tokens() != null) { + if (statusUpdate.tokens() != null && !statusUpdate.tokens().isEmpty()) { llmTokenUsageService.saveLLMTokenUsage(statusUpdate.tokens(), LLMServiceType.IRIS, builder -> builder.withCourse(course).withUser(job.user())); } websocketService.send(job.user().getLogin(), websocketTopic(job.courseId()), statusUpdate); diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/AbstractIrisChatSessionService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/AbstractIrisChatSessionService.java index f732529aae72..559e21668775 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/AbstractIrisChatSessionService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/AbstractIrisChatSessionService.java @@ -1,10 +1,12 @@ package de.tum.cit.aet.artemis.iris.service.session; +import java.util.HashMap; import java.util.List; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; +import de.tum.cit.aet.artemis.core.domain.LLMTokenUsageTrace; import de.tum.cit.aet.artemis.iris.domain.session.IrisChatSession; import de.tum.cit.aet.artemis.iris.repository.IrisSessionRepository; @@ -14,6 +16,8 @@ public abstract class AbstractIrisChatSessionService private final ObjectMapper objectMapper; + protected final HashMap traces = new HashMap<>(); + public AbstractIrisChatSessionService(IrisSessionRepository irisSessionRepository, ObjectMapper objectMapper) { this.irisSessionRepository = irisSessionRepository; this.objectMapper = objectMapper; diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java index b4fc6f4ff8ad..75c8eb1430da 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java @@ -140,18 +140,26 @@ private void requestAndHandleResponse(IrisCourseChatSession session, String vari */ public void handleStatusUpdate(CourseChatJob job, PyrisChatStatusUpdateDTO statusUpdate) { var session = (IrisCourseChatSession) irisSessionRepository.findByIdWithMessagesAndContents(job.sessionId()); + IrisMessage savedMessage; if (statusUpdate.result() != null) { var message = new IrisMessage(); message.addContent(new IrisTextMessageContent(statusUpdate.result())); - var savedMessage = irisMessageService.saveMessage(message, session, IrisMessageSender.LLM); - llmTokenUsageService.saveLLMTokenUsage(statusUpdate.tokens(), LLMServiceType.IRIS, - builder -> builder.withIrisMessageID(savedMessage.getId()).withUser(session.getUser()).withCourse(session.getCourse())); + savedMessage = irisMessageService.saveMessage(message, session, IrisMessageSender.LLM); irisChatWebsocketService.sendMessage(session, savedMessage, statusUpdate.stages()); } else { - llmTokenUsageService.saveLLMTokenUsage(statusUpdate.tokens(), LLMServiceType.IRIS, builder -> builder.withUser(session.getUser()).withCourse(session.getCourse())); + savedMessage = null; irisChatWebsocketService.sendStatusUpdate(session, statusUpdate.stages(), statusUpdate.suggestions(), statusUpdate.tokens()); } + if (statusUpdate.tokens() != null && !statusUpdate.tokens().isEmpty()) { + if (savedMessage != null) { + llmTokenUsageService.saveLLMTokenUsage(statusUpdate.tokens(), LLMServiceType.IRIS, + builder -> builder.withIrisMessageID(savedMessage.getId()).withUser(session.getUser()).withCourse(session.getCourse())); + } + else { + llmTokenUsageService.saveLLMTokenUsage(statusUpdate.tokens(), LLMServiceType.IRIS, builder -> builder.withUser(session.getUser()).withCourse(session.getCourse())); + } + } updateLatestSuggestions(session, statusUpdate.suggestions()); } diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java index a6105eb215b1..93db671f2783 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java @@ -183,14 +183,26 @@ public void handleStatusUpdate(ExerciseChatJob job, PyrisChatStatusUpdateDTO sta irisChatWebsocketService.sendStatusUpdate(session, statusUpdate.stages(), statusUpdate.suggestions(), statusUpdate.tokens()); } - if (statusUpdate.tokens() != null) { + if (statusUpdate.tokens() != null && !statusUpdate.tokens().isEmpty()) { if (savedMessage != null) { - llmTokenUsageService.saveLLMTokenUsage(statusUpdate.tokens(), LLMServiceType.IRIS, builder -> builder.withIrisMessageID(savedMessage.getId()) - .withExercise(session.getExercise()).withUser(session.getUser()).withCourse(session.getExercise().getCourseViaExerciseGroupOrCourseMember())); + // generated message is first sent and generated trace is saved + var llmTokenUsageTrace = llmTokenUsageService.saveLLMTokenUsage(statusUpdate.tokens(), LLMServiceType.IRIS, + builder -> builder.withIrisMessageID(savedMessage.getId()).withExercise(session.getExercise()).withUser(session.getUser()) + .withCourse(session.getExercise().getCourseViaExerciseGroupOrCourseMember())); + traces.put(job.jobId(), llmTokenUsageTrace); } else { - llmTokenUsageService.saveLLMTokenUsage(statusUpdate.tokens(), LLMServiceType.IRIS, builder -> builder.withExercise(session.getExercise()) - .withUser(session.getUser()).withCourse(session.getExercise().getCourseViaExerciseGroupOrCourseMember())); + // interaction suggestion is sent and appended to the generated trace if it exists, trace is then removed, + // because interaction suggestion is the last message from Iris in the pipeline + if (traces.containsKey(job.jobId())) { + var trace = traces.get(job.jobId()); + llmTokenUsageService.appendRequestsToTrace(statusUpdate.tokens(), trace); + traces.remove(job.jobId()); + } + else { + llmTokenUsageService.saveLLMTokenUsage(statusUpdate.tokens(), LLMServiceType.IRIS, builder -> builder.withExercise(session.getExercise()) + .withUser(session.getUser()).withCourse(session.getExercise().getCourseViaExerciseGroupOrCourseMember())); + } } } diff --git a/src/main/resources/config/liquibase/changelog/20241018053210_changelog.xml b/src/main/resources/config/liquibase/changelog/20241018053210_changelog.xml index 9414994c0576..b5c64a0b9eec 100644 --- a/src/main/resources/config/liquibase/changelog/20241018053210_changelog.xml +++ b/src/main/resources/config/liquibase/changelog/20241018053210_changelog.xml @@ -10,7 +10,7 @@ - + diff --git a/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java b/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java index 7e6d767449c4..6ae7ebf00400 100644 --- a/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java +++ b/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java @@ -1,164 +1,230 @@ -/* - * package de.tum.cit.aet.artemis.iris; - * class IrisChatTokenTrackingIntegrationTest extends AbstractIrisIntegrationTest { - * private static final String TEST_PREFIX = "irischattokentrackingintegration"; - * @Autowired - * private IrisExerciseChatSessionService irisExerciseChatSessionService; - * @Autowired - * private IrisMessageRepository irisMessageRepository; - * @Autowired - * private LLMTokenUsageService llmTokenUsageService; - * @Autowired - * private LLMTokenUsageRepository irisLLMTokenUsageRepository; - * @Autowired - * private IrisRequestMockProvider irisRequestMockProvider; - * @Autowired - * private ParticipationUtilService participationUtilService; - * @Autowired - * private PyrisJobService pyrisJobService; - * private ProgrammingExercise exercise; - * private Course course; - * private AtomicBoolean pipelineDone; - * @BeforeEach - * void initTestCase() throws GitAPIException, IOException, URISyntaxException { - * userUtilService.addUsers(TEST_PREFIX, 2, 0, 0, 0); - * course = programmingExerciseUtilService.addCourseWithOneProgrammingExercise(); - * exercise = exerciseUtilService.getFirstExerciseWithType(course, ProgrammingExercise.class); - * String projectKey = exercise.getProjectKey(); - * exercise.setProjectType(ProjectType.PLAIN_GRADLE); - * exercise.setTestRepositoryUri(localVCBaseUrl + "/git/" + projectKey + "/" + projectKey.toLowerCase() + "-tests.git"); - * programmingExerciseBuildConfigRepository.save(exercise.getBuildConfig()); - * programmingExerciseRepository.save(exercise); - * exercise = programmingExerciseRepository.findWithAllParticipationsAndBuildConfigById(exercise.getId()).orElseThrow(); - * // Set the correct repository URIs for the template and the solution participation. - * String templateRepositorySlug = projectKey.toLowerCase() + "-exercise"; - * TemplateProgrammingExerciseParticipation templateParticipation = exercise.getTemplateParticipation(); - * templateParticipation.setRepositoryUri(localVCBaseUrl + "/git/" + projectKey + "/" + templateRepositorySlug + ".git"); - * templateProgrammingExerciseParticipationRepository.save(templateParticipation); - * String solutionRepositorySlug = projectKey.toLowerCase() + "-solution"; - * SolutionProgrammingExerciseParticipation solutionParticipation = exercise.getSolutionParticipation(); - * solutionParticipation.setRepositoryUri(localVCBaseUrl + "/git/" + projectKey + "/" + solutionRepositorySlug + ".git"); - * solutionProgrammingExerciseParticipationRepository.save(solutionParticipation); - * String assignmentRepositorySlug = projectKey.toLowerCase() + "-" + TEST_PREFIX + "student1"; - * // Add a participation for student1. - * ProgrammingExerciseStudentParticipation studentParticipation = participationUtilService.addStudentParticipationForProgrammingExercise(exercise, TEST_PREFIX + "student1"); - * studentParticipation.setRepositoryUri(String.format(localVCBaseUrl + "/git/%s/%s.git", projectKey, assignmentRepositorySlug)); - * studentParticipation.setBranch(defaultBranch); - * programmingExerciseStudentParticipationRepository.save(studentParticipation); - * // Prepare the repositories. - * localVCLocalCITestService.createAndConfigureLocalRepository(projectKey, templateRepositorySlug); - * localVCLocalCITestService.createAndConfigureLocalRepository(projectKey, projectKey.toLowerCase() + "-tests"); - * localVCLocalCITestService.createAndConfigureLocalRepository(projectKey, solutionRepositorySlug); - * localVCLocalCITestService.createAndConfigureLocalRepository(projectKey, assignmentRepositorySlug); - * // Check that the repository folders were created in the file system for all base repositories. - * localVCLocalCITestService.verifyRepositoryFoldersExist(exercise, localVCBasePath); - * activateIrisGlobally(); - * activateIrisFor(course); - * activateIrisFor(exercise); - * // Clean up the database - * irisLLMTokenUsageRepository.deleteAll(); - * pipelineDone = new AtomicBoolean(false); - * } - * @Test - * @WithMockUser(username = TEST_PREFIX + "student1", roles = "USER") - * void testTokenTrackingHandledExerciseChat() throws Exception { - * var irisSession = irisExerciseChatSessionService.createChatSessionForProgrammingExercise(exercise, userUtilService.getUserByLogin(TEST_PREFIX + "student1")); - * var messageToSend = createDefaultMockMessage(irisSession); - * var tokens = getMockLLMCosts(); - * List doneStage = new ArrayList<>(); - * doneStage.add(new PyrisStageDTO("DoneTest", 10, PyrisStageState.DONE, "Done")); - * irisRequestMockProvider.mockProgrammingExerciseChatResponse(dto -> { - * assertThat(dto.settings().authenticationToken()).isNotNull(); - * assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World", doneStage, null, tokens)); - * pipelineDone.set(true); - * }); - * request.postWithoutResponseBody("/api/iris/sessions/" + irisSession.getId() + "/messages", messageToSend, HttpStatus.CREATED); - * await().until(pipelineDone::get); - * List savedTokenUsages = irisLLMTokenUsageRepository.findAll(); - * assertThat(savedTokenUsages).hasSize(5); - * for (int i = 0; i < savedTokenUsages.size(); i++) { - * LLMTokenUsage usage = savedTokenUsages.get(i); - * PyrisLLMCostDTO expectedCost = tokens.get(i); - * assertThat(usage.getModel()).isEqualTo(expectedCost.modelInfo()); - * assertThat(usage.getNumInputTokens()).isEqualTo(expectedCost.numInputTokens()); - * assertThat(usage.getNumOutputTokens()).isEqualTo(expectedCost.numOutputTokens()); - * assertThat(usage.getCostPerMillionInputTokens()).isEqualTo(expectedCost.costPerInputToken()); - * assertThat(usage.getCostPerMillionOutputTokens()).isEqualTo(expectedCost.costPerOutputToken()); - * assertThat(usage.getServiceType()).isEqualTo(expectedCost.pipeline()); - * } - * } - * @Test - * @WithMockUser(username = TEST_PREFIX + "student1", roles = "USER") - * void testTokenTrackingSavedExerciseChat() { - * var irisSession = irisExerciseChatSessionService.createChatSessionForProgrammingExercise(exercise, userUtilService.getUserByLogin(TEST_PREFIX + "student1")); - * var irisMessage = createDefaultMockMessage(irisSession); - * irisMessageRepository.save(irisMessage); - * String jobToken = pyrisJobService.addExerciseChatJob(course.getId(), exercise.getId(), irisSession.getId()); - * PyrisJob job = pyrisJobService.getJob(jobToken); - * var tokens = getMockLLMCosts(); - * // Capture the saved token usages - * List returnedTokenUsages = llmTokenUsageService.saveLLMTokenUsage( - * builder -> builder.withMessage(irisMessage).withExercise(exercise).withUser(irisSession.getUser()).withCourse(course).withTokens(tokens)); - * assertThat(returnedTokenUsages).hasSize(5); - * for (int i = 0; i < returnedTokenUsages.size(); i++) { - * LLMTokenUsage usage = returnedTokenUsages.get(i); - * PyrisLLMCostDTO expectedCost = tokens.get(i); - * assertThat(usage.getModel()).isEqualTo(expectedCost.modelInfo()); - * assertThat(usage.getNumInputTokens()).isEqualTo(expectedCost.numInputTokens()); - * assertThat(usage.getNumOutputTokens()).isEqualTo(expectedCost.numOutputTokens()); - * assertThat(usage.getCostPerMillionInputTokens()).isEqualTo(expectedCost.costPerInputToken()); - * assertThat(usage.getCostPerMillionOutputTokens()).isEqualTo(expectedCost.costPerOutputToken()); - * assertThat(usage.getServiceType()).isEqualTo(expectedCost.pipeline()); - * } - * } - * @Test - * @WithMockUser(username = TEST_PREFIX + "student1", roles = "USER") - * void testTokenTrackingExerciseChatWithPipelineFail() throws Exception { - * var irisSession = irisExerciseChatSessionService.createChatSessionForProgrammingExercise(exercise, userUtilService.getUserByLogin(TEST_PREFIX + "student1")); - * var messageToSend = createDefaultMockMessage(irisSession); - * var tokens = getMockLLMCosts(); - * List failedStages = new ArrayList<>(); - * failedStages.add(new PyrisStageDTO("TestTokenFail", 10, PyrisStageState.ERROR, "Failed running pipeline")); - * irisRequestMockProvider.mockProgrammingExerciseChatResponse(dto -> { - * assertThat(dto.settings().authenticationToken()).isNotNull(); - * assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), null, failedStages, null, tokens)); - * pipelineDone.set(true); - * }); - * request.postWithoutResponseBody("/api/iris/sessions/" + irisSession.getId() + "/messages", messageToSend, HttpStatus.CREATED); - * await().until(pipelineDone::get); - * List savedTokenUsages = irisLLMTokenUsageRepository.findAll(); - * assertThat(savedTokenUsages).hasSize(5); - * for (int i = 0; i < savedTokenUsages.size(); i++) { - * LLMTokenUsage usage = savedTokenUsages.get(i); - * PyrisLLMCostDTO expectedCost = tokens.get(i); - * assertThat(usage.getModel()).isEqualTo(expectedCost.modelInfo()); - * assertThat(usage.getNumInputTokens()).isEqualTo(expectedCost.numInputTokens()); - * assertThat(usage.getNumOutputTokens()).isEqualTo(expectedCost.numOutputTokens()); - * assertThat(usage.getCostPerMillionInputTokens()).isEqualTo(expectedCost.costPerInputToken()); - * assertThat(usage.getCostPerMillionOutputTokens()).isEqualTo(expectedCost.costPerOutputToken()); - * assertThat(usage.getServiceType()).isEqualTo(expectedCost.pipeline()); - * } - * } - * private List getMockLLMCosts() { - * List costs = new ArrayList<>(); - * for (int i = 0; i < 5; i++) { - * costs.add(new LLMRequest("test-llm", i * 10 + 5, i * 0.5f, i * 3 + 5, i * 0.12f, "IRIS_CHAT_EXERCISE_MESSAGE")); - * } - * return costs; - * } - * private IrisMessage createDefaultMockMessage(IrisSession irisSession) { - * var messageToSend = irisSession.newMessage(); - * messageToSend.addContent(createMockTextContent(), createMockTextContent(), createMockTextContent()); - * return messageToSend; - * } - * private IrisMessageContent createMockTextContent() { - * var text = "The happy dog jumped over the lazy dog."; - * return new IrisTextMessageContent(text); - * } - * private void sendStatus(String jobId, String result, List stages, List suggestions, List tokens) throws Exception { - * var headers = new HttpHeaders(new LinkedMultiValueMap<>(Map.of("Authorization", List.of("Bearer " + jobId)))); - * request.postWithoutResponseBody("/api/public/pyris/pipelines/tutor-chat/runs/" + jobId + "/status", new PyrisChatStatusUpdateDTO(result, stages, suggestions, tokens), - * HttpStatus.OK, headers); - * } - * } - */ +package de.tum.cit.aet.artemis.iris; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNoException; +import static org.awaitility.Awaitility.await; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.eclipse.jgit.api.errors.GitAPIException; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.security.test.context.support.WithMockUser; +import org.springframework.util.LinkedMultiValueMap; + +import de.tum.cit.aet.artemis.core.connector.IrisRequestMockProvider; +import de.tum.cit.aet.artemis.core.domain.Course; +import de.tum.cit.aet.artemis.core.domain.LLMRequest; +import de.tum.cit.aet.artemis.core.domain.LLMServiceType; +import de.tum.cit.aet.artemis.core.domain.LLMTokenUsageRequest; +import de.tum.cit.aet.artemis.core.domain.LLMTokenUsageTrace; +import de.tum.cit.aet.artemis.core.repository.LLMTokenUsageRequestRepository; +import de.tum.cit.aet.artemis.core.repository.LLMTokenUsageTraceRepository; +import de.tum.cit.aet.artemis.core.service.LLMTokenUsageService; +import de.tum.cit.aet.artemis.exercise.participation.util.ParticipationUtilService; +import de.tum.cit.aet.artemis.iris.domain.message.IrisMessage; +import de.tum.cit.aet.artemis.iris.domain.message.IrisMessageContent; +import de.tum.cit.aet.artemis.iris.domain.message.IrisTextMessageContent; +import de.tum.cit.aet.artemis.iris.domain.session.IrisSession; +import de.tum.cit.aet.artemis.iris.repository.IrisMessageRepository; +import de.tum.cit.aet.artemis.iris.service.pyris.dto.chat.PyrisChatStatusUpdateDTO; +import de.tum.cit.aet.artemis.iris.service.pyris.dto.status.PyrisStageDTO; +import de.tum.cit.aet.artemis.iris.service.pyris.dto.status.PyrisStageState; +import de.tum.cit.aet.artemis.iris.service.session.IrisExerciseChatSessionService; +import de.tum.cit.aet.artemis.programming.domain.ProgrammingExercise; +import de.tum.cit.aet.artemis.programming.domain.ProgrammingExerciseStudentParticipation; +import de.tum.cit.aet.artemis.programming.domain.ProjectType; +import de.tum.cit.aet.artemis.programming.domain.SolutionProgrammingExerciseParticipation; +import de.tum.cit.aet.artemis.programming.domain.TemplateProgrammingExerciseParticipation; + +class IrisChatTokenTrackingIntegrationTest extends AbstractIrisIntegrationTest { + + private static final String TEST_PREFIX = "irischattokentrackingintegration"; + + @Autowired + private IrisExerciseChatSessionService irisExerciseChatSessionService; + + @Autowired + private IrisMessageRepository irisMessageRepository; + + @Autowired + private LLMTokenUsageService llmTokenUsageService; + + @Autowired + private LLMTokenUsageTraceRepository irisLLMTokenUsageTraceRepository; + + @Autowired + private LLMTokenUsageRequestRepository irisLLMTokenUsageRequestRepository; + + @Autowired + private IrisRequestMockProvider irisRequestMockProvider; + + @Autowired + private ParticipationUtilService participationUtilService; + + private ProgrammingExercise exercise; + + private Course course; + + private AtomicBoolean pipelineDone; + + @BeforeEach + void initTestCase() throws GitAPIException, IOException, URISyntaxException { + userUtilService.addUsers(TEST_PREFIX, 2, 0, 0, 0); + course = programmingExerciseUtilService.addCourseWithOneProgrammingExercise(); + exercise = exerciseUtilService.getFirstExerciseWithType(course, ProgrammingExercise.class); + String projectKey = exercise.getProjectKey(); + exercise.setProjectType(ProjectType.PLAIN_GRADLE); + exercise.setTestRepositoryUri(localVCBaseUrl + "/git/" + projectKey + "/" + projectKey.toLowerCase() + "-tests.git"); + programmingExerciseBuildConfigRepository.save(exercise.getBuildConfig()); + programmingExerciseRepository.save(exercise); + exercise = programmingExerciseRepository.findWithAllParticipationsAndBuildConfigById(exercise.getId()).orElseThrow(); + // Set the correct repository URIs for the template and the solution participation. + String templateRepositorySlug = projectKey.toLowerCase() + "-exercise"; + TemplateProgrammingExerciseParticipation templateParticipation = exercise.getTemplateParticipation(); + templateParticipation.setRepositoryUri(localVCBaseUrl + "/git/" + projectKey + "/" + templateRepositorySlug + ".git"); + templateProgrammingExerciseParticipationRepository.save(templateParticipation); + String solutionRepositorySlug = projectKey.toLowerCase() + "-solution"; + SolutionProgrammingExerciseParticipation solutionParticipation = exercise.getSolutionParticipation(); + solutionParticipation.setRepositoryUri(localVCBaseUrl + "/git/" + projectKey + "/" + solutionRepositorySlug + ".git"); + solutionProgrammingExerciseParticipationRepository.save(solutionParticipation); + String assignmentRepositorySlug = projectKey.toLowerCase() + "-" + TEST_PREFIX + "student1"; + // Add a participation for student1. + ProgrammingExerciseStudentParticipation studentParticipation = participationUtilService.addStudentParticipationForProgrammingExercise(exercise, TEST_PREFIX + "student1"); + studentParticipation.setRepositoryUri(String.format(localVCBaseUrl + "/git/%s/%s.git", projectKey, assignmentRepositorySlug)); + studentParticipation.setBranch(defaultBranch); + programmingExerciseStudentParticipationRepository.save(studentParticipation); + // Prepare the repositories. + localVCLocalCITestService.createAndConfigureLocalRepository(projectKey, templateRepositorySlug); + localVCLocalCITestService.createAndConfigureLocalRepository(projectKey, projectKey.toLowerCase() + "-tests"); + localVCLocalCITestService.createAndConfigureLocalRepository(projectKey, solutionRepositorySlug); + localVCLocalCITestService.createAndConfigureLocalRepository(projectKey, assignmentRepositorySlug); + // Check that the repository folders were created in the file system for all base repositories. + localVCLocalCITestService.verifyRepositoryFoldersExist(exercise, localVCBasePath); + activateIrisGlobally(); + activateIrisFor(course); + activateIrisFor(exercise); + // Clean up the database + irisLLMTokenUsageRequestRepository.deleteAll(); + irisLLMTokenUsageTraceRepository.deleteAll(); + pipelineDone = new AtomicBoolean(false); + } + + @Test + @WithMockUser(username = TEST_PREFIX + "student1", roles = "USER") + void testTokenTrackingHandledExerciseChat() throws Exception { + var irisSession = irisExerciseChatSessionService.createChatSessionForProgrammingExercise(exercise, userUtilService.getUserByLogin(TEST_PREFIX + "student1")); + var messageToSend = createDefaultMockMessage(irisSession); + var tokens = getMockLLMCosts(); + List doneStage = new ArrayList<>(); + doneStage.add(new PyrisStageDTO("DoneTest", 10, PyrisStageState.DONE, "Done")); + irisRequestMockProvider.mockProgrammingExerciseChatResponse(dto -> { + assertThat(dto.settings().authenticationToken()).isNotNull(); + assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World", doneStage, tokens)); + pipelineDone.set(true); + }); + request.postWithoutResponseBody("/api/iris/sessions/" + irisSession.getId() + "/messages", messageToSend, HttpStatus.CREATED); + await().until(pipelineDone::get); + List savedTokenUsageTraces = irisLLMTokenUsageTraceRepository.findAll(); + List savedTokenUsageRequests = irisLLMTokenUsageRequestRepository.findAll(); + assertThat(savedTokenUsageTraces).hasSize(1); + assertThat(savedTokenUsageTraces.getFirst().getServiceType()).isEqualTo(LLMServiceType.IRIS); + assertThat(savedTokenUsageTraces.getFirst().getExerciseId()).isEqualTo(exercise.getId()); + assertThat(savedTokenUsageTraces.getFirst().getCourseId()).isEqualTo(course.getId()); + assertThat(savedTokenUsageRequests).hasSize(5); + for (int i = 0; i < savedTokenUsageRequests.size(); i++) { + LLMTokenUsageRequest usage = savedTokenUsageRequests.get(i); + LLMRequest expectedCost = tokens.get(i); + assertThat(usage.getModel()).isEqualTo(expectedCost.model()); + assertThat(usage.getNumInputTokens()).isEqualTo(expectedCost.numInputTokens()); + assertThat(usage.getNumOutputTokens()).isEqualTo(expectedCost.numOutputTokens()); + assertThat(usage.getCostPerMillionInputTokens()).isEqualTo(expectedCost.costPerMillionInputToken()); + assertThat(usage.getCostPerMillionOutputTokens()).isEqualTo(expectedCost.costPerMillionOutputToken()); + assertThat(usage.getServicePipelineId()).isEqualTo(expectedCost.pipelineId()); + } + } + + @Test + @WithMockUser(username = TEST_PREFIX + "student1", roles = "USER") + void testTokenTrackingSavedExerciseChat() { + var irisSession = irisExerciseChatSessionService.createChatSessionForProgrammingExercise(exercise, userUtilService.getUserByLogin(TEST_PREFIX + "student1")); + var irisMessage = createDefaultMockMessage(irisSession); + irisMessageRepository.save(irisMessage); + var tokens = getMockLLMCosts(); + LLMTokenUsageTrace tokenUsageTrace = llmTokenUsageService.saveLLMTokenUsage(tokens, LLMServiceType.IRIS, + builder -> builder.withIrisMessageID(irisMessage.getId()).withExercise(exercise).withUser(irisSession.getUser()).withCourse(course)); + assertThat(tokenUsageTrace.getServiceType()).isEqualTo(LLMServiceType.IRIS); + assertThat(tokenUsageTrace.getIrisMessageId()).isEqualTo(irisMessage.getId()); + assertThat(tokenUsageTrace.getExerciseId()).isEqualTo(exercise.getId()); + assertThat(tokenUsageTrace.getUserId()).isEqualTo(irisSession.getUser().getId()); + assertThat(tokenUsageTrace.getCourseId()).isEqualTo(course.getId()); + } + + @Test + @WithMockUser(username = TEST_PREFIX + "student1", roles = "USER") + void testTokenTrackingExerciseChatWithPipelineFail() throws Exception { + var irisSession = irisExerciseChatSessionService.createChatSessionForProgrammingExercise(exercise, userUtilService.getUserByLogin(TEST_PREFIX + "student1")); + var messageToSend = createDefaultMockMessage(irisSession); + var tokens = getMockLLMCosts(); + List failedStages = new ArrayList<>(); + failedStages.add(new PyrisStageDTO("TestTokenFail", 10, PyrisStageState.ERROR, "Failed running pipeline")); + irisRequestMockProvider.mockProgrammingExerciseChatResponse(dto -> { + assertThat(dto.settings().authenticationToken()).isNotNull(); + assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), null, failedStages, tokens)); + pipelineDone.set(true); + }); + request.postWithoutResponseBody("/api/iris/sessions/" + irisSession.getId() + "/messages", messageToSend, HttpStatus.CREATED); + await().until(pipelineDone::get); + List savedTokenUsageTraces = irisLLMTokenUsageTraceRepository.findAll(); + List savedTokenUsageRequests = irisLLMTokenUsageRequestRepository.findAll(); + assertThat(savedTokenUsageTraces).hasSize(1); + assertThat(savedTokenUsageTraces.getFirst().getServiceType()).isEqualTo(LLMServiceType.IRIS); + assertThat(savedTokenUsageTraces.getFirst().getExerciseId()).isEqualTo(exercise.getId()); + assertThat(savedTokenUsageTraces.getFirst().getIrisMessageId()).isEqualTo(messageToSend.getId()); + assertThat(savedTokenUsageTraces.getFirst().getCourseId()).isEqualTo(course.getId()); + assertThat(savedTokenUsageRequests).hasSize(5); + for (int i = 0; i < savedTokenUsageRequests.size(); i++) { + LLMTokenUsageRequest usage = savedTokenUsageRequests.get(i); + LLMRequest expectedCost = tokens.get(i); + assertThat(usage.getModel()).isEqualTo(expectedCost.model()); + assertThat(usage.getNumInputTokens()).isEqualTo(expectedCost.numInputTokens()); + assertThat(usage.getNumOutputTokens()).isEqualTo(expectedCost.numOutputTokens()); + assertThat(usage.getCostPerMillionInputTokens()).isEqualTo(expectedCost.costPerMillionInputToken()); + assertThat(usage.getCostPerMillionOutputTokens()).isEqualTo(expectedCost.costPerMillionOutputToken()); + assertThat(usage.getServicePipelineId()).isEqualTo(expectedCost.pipelineId()); + } + } + + private List getMockLLMCosts() { + List costs = new ArrayList<>(); + for (int i = 0; i < 5; i++) { + costs.add(new LLMRequest("test-llm", i * 10 + 5, i * 0.5f, i * 3 + 5, i * 0.12f, "IRIS_CHAT_EXERCISE_MESSAGE")); + } + return costs; + } + + private IrisMessage createDefaultMockMessage(IrisSession irisSession) { + var messageToSend = irisSession.newMessage(); + messageToSend.addContent(createMockTextContent(), createMockTextContent(), createMockTextContent()); + return messageToSend; + } + + private IrisMessageContent createMockTextContent() { + var text = "The happy dog jumped over the lazy dog."; + return new IrisTextMessageContent(text); + } + + private void sendStatus(String jobId, String result, List stages, List tokens) throws Exception { + var headers = new HttpHeaders(new LinkedMultiValueMap<>(Map.of("Authorization", List.of("Bearer " + jobId)))); + request.postWithoutResponseBody("/api/public/pyris/pipelines/tutor-chat/runs/" + jobId + "/status", new PyrisChatStatusUpdateDTO(result, stages, null, tokens), + HttpStatus.OK, headers); + } +} From b8f5ccae0eb6e774bfd0ebcffd03e63c1fb3800a Mon Sep 17 00:00:00 2001 From: Stephan Krusche Date: Sun, 20 Oct 2024 15:42:35 +0200 Subject: [PATCH 20/30] fix server tests --- .../core/repository/LLMTokenUsageRequestRepository.java | 4 ++++ .../core/repository/LLMTokenUsageTraceRepository.java | 4 ++++ .../cit/aet/artemis/core/service/LLMTokenUsageService.java | 6 ++++++ 3 files changed, 14 insertions(+) diff --git a/src/main/java/de/tum/cit/aet/artemis/core/repository/LLMTokenUsageRequestRepository.java b/src/main/java/de/tum/cit/aet/artemis/core/repository/LLMTokenUsageRequestRepository.java index 7c1be9da0120..145383bf124a 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/repository/LLMTokenUsageRequestRepository.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/repository/LLMTokenUsageRequestRepository.java @@ -1,10 +1,14 @@ package de.tum.cit.aet.artemis.core.repository; +import static de.tum.cit.aet.artemis.core.config.Constants.PROFILE_CORE; + +import org.springframework.context.annotation.Profile; import org.springframework.stereotype.Repository; import de.tum.cit.aet.artemis.core.domain.LLMTokenUsageRequest; import de.tum.cit.aet.artemis.core.repository.base.ArtemisJpaRepository; +@Profile(PROFILE_CORE) @Repository public interface LLMTokenUsageRequestRepository extends ArtemisJpaRepository { } diff --git a/src/main/java/de/tum/cit/aet/artemis/core/repository/LLMTokenUsageTraceRepository.java b/src/main/java/de/tum/cit/aet/artemis/core/repository/LLMTokenUsageTraceRepository.java index 602b5686f417..cc1b0e588c4e 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/repository/LLMTokenUsageTraceRepository.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/repository/LLMTokenUsageTraceRepository.java @@ -1,10 +1,14 @@ package de.tum.cit.aet.artemis.core.repository; +import static de.tum.cit.aet.artemis.core.config.Constants.PROFILE_CORE; + +import org.springframework.context.annotation.Profile; import org.springframework.stereotype.Repository; import de.tum.cit.aet.artemis.core.domain.LLMTokenUsageTrace; import de.tum.cit.aet.artemis.core.repository.base.ArtemisJpaRepository; +@Profile(PROFILE_CORE) @Repository public interface LLMTokenUsageTraceRepository extends ArtemisJpaRepository { } diff --git a/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java b/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java index adf79e3b05a7..7326f9792590 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java @@ -1,10 +1,13 @@ package de.tum.cit.aet.artemis.core.service; +import static de.tum.cit.aet.artemis.core.config.Constants.PROFILE_CORE; + import java.util.List; import java.util.Optional; import java.util.Set; import java.util.function.Function; +import org.springframework.context.annotation.Profile; import org.springframework.stereotype.Service; import de.tum.cit.aet.artemis.core.domain.Course; @@ -21,6 +24,7 @@ /** * Service for managing the LLMTokenUsage by all LLMs in Artemis */ +@Profile(PROFILE_CORE) @Service public class LLMTokenUsageService { @@ -41,6 +45,7 @@ public LLMTokenUsageService(LLMTokenUsageTraceRepository llmTokenUsageTraceRepos * @param builderFunction of type Function using IrisTokenUsageBuilder * @return saved LLMTokenUsage as a List */ + // TODO: this should ideally be done Async public LLMTokenUsageTrace saveLLMTokenUsage(List llmRequests, LLMServiceType serviceType, Function builderFunction) { LLMTokenUsageTrace llmTokenUsageTrace = new LLMTokenUsageTrace(); llmTokenUsageTrace.setServiceType(serviceType); @@ -70,6 +75,7 @@ private void setLLMTokenUsageRequests(List llmRequests, LLMTokenUsag } } + // TODO: this should ideally be done Async public void appendRequestsToTrace(List requests, LLMTokenUsageTrace trace) { Set llmRequestsSet = new ArraySet<>(); setLLMTokenUsageRequests(requests, trace, llmRequestsSet); From 82fb76d45c1e31448f98d9b6148b98b8df001ff5 Mon Sep 17 00:00:00 2001 From: "Felix T.J. Dietrich" Date: Mon, 21 Oct 2024 10:25:18 +0200 Subject: [PATCH 21/30] fix function naming --- .../tum/cit/aet/artemis/core/domain/LLMTokenUsageRequest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsageRequest.java b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsageRequest.java index 186b5d049183..1b769f5ea97b 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsageRequest.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsageRequest.java @@ -85,7 +85,7 @@ public void setNumOutputTokens(int numOutputTokens) { this.numOutputTokens = numOutputTokens; } - public LLMTokenUsageTrace getTraceId() { + public LLMTokenUsageTrace getTrace() { return trace; } From 6d3037ab18c0080700a63d7d799f139295681f5b Mon Sep 17 00:00:00 2001 From: "Felix T.J. Dietrich" Date: Mon, 21 Oct 2024 10:28:45 +0200 Subject: [PATCH 22/30] replace ArraySet --- .../cit/aet/artemis/core/service/LLMTokenUsageService.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java b/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java index 7326f9792590..2d48bfcf73fd 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java @@ -2,6 +2,7 @@ import static de.tum.cit.aet.artemis.core.config.Constants.PROFILE_CORE; +import java.util.HashSet; import java.util.List; import java.util.Optional; import java.util.Set; @@ -19,7 +20,6 @@ import de.tum.cit.aet.artemis.core.repository.LLMTokenUsageRequestRepository; import de.tum.cit.aet.artemis.core.repository.LLMTokenUsageTraceRepository; import de.tum.cit.aet.artemis.exercise.domain.Exercise; -import edu.stanford.nlp.util.ArraySet; /** * Service for managing the LLMTokenUsage by all LLMs in Artemis @@ -77,7 +77,7 @@ private void setLLMTokenUsageRequests(List llmRequests, LLMTokenUsag // TODO: this should ideally be done Async public void appendRequestsToTrace(List requests, LLMTokenUsageTrace trace) { - Set llmRequestsSet = new ArraySet<>(); + Set llmRequestsSet = new HashSet<>(); setLLMTokenUsageRequests(requests, trace, llmRequestsSet); llmTokenUsageRequestRepository.saveAll(llmRequestsSet); } From 9f4cccd936ad80a42f0fa722c1cd41950f14521f Mon Sep 17 00:00:00 2001 From: Patrick Bassner Date: Mon, 21 Oct 2024 16:58:28 +0200 Subject: [PATCH 23/30] Refactored token usage tracking and improved session-based job handling --- .../core/service/LLMTokenUsageService.java | 89 +++++++++---------- .../IrisCompetencyGenerationService.java | 14 ++- .../pyris/job/CompetencyExtractionJob.java | 5 +- .../iris/service/pyris/job/CourseChatJob.java | 2 +- .../service/pyris/job/ExerciseChatJob.java | 2 +- .../pyris/job/SessionBasedPyrisJob.java | 9 ++ .../AbstractIrisChatSessionService.java | 74 ++++++++++++++- .../session/IrisCourseChatSessionService.java | 45 +--------- .../IrisExerciseChatSessionService.java | 61 ++----------- ...isCompetencyGenerationIntegrationTest.java | 2 +- 10 files changed, 150 insertions(+), 153 deletions(-) create mode 100644 src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/SessionBasedPyrisJob.java diff --git a/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java b/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java index 2d48bfcf73fd..5ffe5f379ff5 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java @@ -2,24 +2,20 @@ import static de.tum.cit.aet.artemis.core.config.Constants.PROFILE_CORE; -import java.util.HashSet; import java.util.List; import java.util.Optional; -import java.util.Set; import java.util.function.Function; +import java.util.stream.Collectors; import org.springframework.context.annotation.Profile; import org.springframework.stereotype.Service; -import de.tum.cit.aet.artemis.core.domain.Course; import de.tum.cit.aet.artemis.core.domain.LLMRequest; import de.tum.cit.aet.artemis.core.domain.LLMServiceType; import de.tum.cit.aet.artemis.core.domain.LLMTokenUsageRequest; import de.tum.cit.aet.artemis.core.domain.LLMTokenUsageTrace; -import de.tum.cit.aet.artemis.core.domain.User; import de.tum.cit.aet.artemis.core.repository.LLMTokenUsageRequestRepository; import de.tum.cit.aet.artemis.core.repository.LLMTokenUsageTraceRepository; -import de.tum.cit.aet.artemis.exercise.domain.Exercise; /** * Service for managing the LLMTokenUsage by all LLMs in Artemis @@ -38,12 +34,17 @@ public LLMTokenUsageService(LLMTokenUsageTraceRepository llmTokenUsageTraceRepos } /** - * method saves the token usage to the database + * Saves the token usage to the database. + * This method records the usage of tokens by various LLM services in the system. * - * @param llmRequests List of LLM requests - * @param serviceType type of the LLM service - * @param builderFunction of type Function using IrisTokenUsageBuilder - * @return saved LLMTokenUsage as a List + * @param llmRequests List of LLM requests containing details about the token usage. + * @param serviceType Type of the LLM service (e.g., IRIS, GPT-3). + * @param builderFunction A function that takes an LLMTokenUsageBuilder and returns a modified LLMTokenUsageBuilder. + * This function is used to set additional properties on the LLMTokenUsageTrace object, such as + * the course ID, user ID, exercise ID, and Iris message ID. + * Example usage: + * builder -> builder.withCourse(courseId).withUser(userId) + * @return The saved LLMTokenUsageTrace object, which includes the details of the token usage. */ // TODO: this should ideally be done Async public LLMTokenUsageTrace saveLLMTokenUsage(List llmRequests, LLMServiceType serviceType, Function builderFunction) { @@ -52,34 +53,32 @@ public LLMTokenUsageTrace saveLLMTokenUsage(List llmRequests, LLMSer LLMTokenUsageBuilder builder = builderFunction.apply(new LLMTokenUsageBuilder()); builder.getIrisMessageID().ifPresent(llmTokenUsageTrace::setIrisMessageId); - builder.getUser().ifPresent(user -> llmTokenUsageTrace.setUserId(user.getId())); - builder.getExercise().ifPresent(exercise -> llmTokenUsageTrace.setExerciseId(exercise.getId())); - builder.getCourse().ifPresent(course -> llmTokenUsageTrace.setCourseId(course.getId())); + builder.getCourseID().ifPresent(llmTokenUsageTrace::setCourseId); + builder.getExerciseID().ifPresent(llmTokenUsageTrace::setExerciseId); + builder.getUserID().ifPresent(llmTokenUsageTrace::setUserId); + + llmTokenUsageTrace.setLlmRequests(llmRequests.stream().map(LLMTokenUsageService::convertLLMRequestToLLMTokenUsageRequest) + .peek(llmTokenUsageRequest -> llmTokenUsageRequest.setTrace(llmTokenUsageTrace)).collect(Collectors.toSet())); - Set llmRequestsSet = llmTokenUsageTrace.getLLMRequests(); - setLLMTokenUsageRequests(llmRequests, llmTokenUsageTrace, llmRequestsSet); return llmTokenUsageTraceRepository.save(llmTokenUsageTrace); } - private void setLLMTokenUsageRequests(List llmRequests, LLMTokenUsageTrace llmTokenUsageTrace, Set llmRequestsSet) { - for (LLMRequest llmRequest : llmRequests) { - LLMTokenUsageRequest llmTokenUsageRequest = new LLMTokenUsageRequest(); - llmTokenUsageRequest.setModel(llmRequest.model()); - llmTokenUsageRequest.setNumInputTokens(llmRequest.numInputTokens()); - llmTokenUsageRequest.setNumOutputTokens(llmRequest.numOutputTokens()); - llmTokenUsageRequest.setCostPerMillionInputTokens(llmRequest.costPerMillionInputToken()); - llmTokenUsageRequest.setCostPerMillionOutputTokens(llmRequest.costPerMillionOutputToken()); - llmTokenUsageRequest.setServicePipelineId(llmRequest.pipelineId()); - llmTokenUsageRequest.setTrace(llmTokenUsageTrace); - llmRequestsSet.add(llmTokenUsageRequest); - } + private static LLMTokenUsageRequest convertLLMRequestToLLMTokenUsageRequest(LLMRequest llmRequest) { + LLMTokenUsageRequest llmTokenUsageRequest = new LLMTokenUsageRequest(); + llmTokenUsageRequest.setModel(llmRequest.model()); + llmTokenUsageRequest.setNumInputTokens(llmRequest.numInputTokens()); + llmTokenUsageRequest.setNumOutputTokens(llmRequest.numOutputTokens()); + llmTokenUsageRequest.setCostPerMillionInputTokens(llmRequest.costPerMillionInputToken()); + llmTokenUsageRequest.setCostPerMillionOutputTokens(llmRequest.costPerMillionOutputToken()); + llmTokenUsageRequest.setServicePipelineId(llmRequest.pipelineId()); + return llmTokenUsageRequest; } // TODO: this should ideally be done Async public void appendRequestsToTrace(List requests, LLMTokenUsageTrace trace) { - Set llmRequestsSet = new HashSet<>(); - setLLMTokenUsageRequests(requests, trace, llmRequestsSet); - llmTokenUsageRequestRepository.saveAll(llmRequestsSet); + var requestSet = requests.stream().map(LLMTokenUsageService::convertLLMRequestToLLMTokenUsageRequest).peek(llmTokenUsageRequest -> llmTokenUsageRequest.setTrace(trace)) + .collect(Collectors.toSet()); + llmTokenUsageRequestRepository.saveAll(requestSet); } /** @@ -87,16 +86,16 @@ public void appendRequestsToTrace(List requests, LLMTokenUsageTrace */ public static class LLMTokenUsageBuilder { - private Optional course = Optional.empty(); + private Optional courseID = Optional.empty(); private Optional irisMessageID = Optional.empty(); - private Optional exercise = Optional.empty(); + private Optional exerciseID = Optional.empty(); - private Optional user = Optional.empty(); + private Optional userID = Optional.empty(); - public LLMTokenUsageBuilder withCourse(Course course) { - this.course = Optional.ofNullable(course); + public LLMTokenUsageBuilder withCourse(Long courseID) { + this.courseID = Optional.ofNullable(courseID); return this; } @@ -105,30 +104,30 @@ public LLMTokenUsageBuilder withIrisMessageID(Long irisMessageID) { return this; } - public LLMTokenUsageBuilder withExercise(Exercise exercise) { - this.exercise = Optional.ofNullable(exercise); + public LLMTokenUsageBuilder withExercise(Long exerciseID) { + this.exerciseID = Optional.ofNullable(exerciseID); return this; } - public LLMTokenUsageBuilder withUser(User user) { - this.user = Optional.ofNullable(user); + public LLMTokenUsageBuilder withUser(Long userID) { + this.userID = Optional.ofNullable(userID); return this; } - public Optional getCourse() { - return course; + public Optional getCourseID() { + return courseID; } public Optional getIrisMessageID() { return irisMessageID; } - public Optional getExercise() { - return exercise; + public Optional getExerciseID() { + return exerciseID; } - public Optional getUser() { - return user; + public Optional getUserID() { + return userID; } } } diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java index e307366db077..f8d2a0201198 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java @@ -10,6 +10,7 @@ import de.tum.cit.aet.artemis.core.domain.LLMServiceType; import de.tum.cit.aet.artemis.core.domain.User; import de.tum.cit.aet.artemis.core.repository.CourseRepository; +import de.tum.cit.aet.artemis.core.repository.UserRepository; import de.tum.cit.aet.artemis.core.service.LLMTokenUsageService; import de.tum.cit.aet.artemis.iris.service.pyris.PyrisJobService; import de.tum.cit.aet.artemis.iris.service.pyris.PyrisPipelineService; @@ -36,13 +37,16 @@ public class IrisCompetencyGenerationService { private final PyrisJobService pyrisJobService; + private final UserRepository userRepository; + public IrisCompetencyGenerationService(PyrisPipelineService pyrisPipelineService, LLMTokenUsageService llmTokenUsageService, CourseRepository courseRepository, - IrisWebsocketService websocketService, PyrisJobService pyrisJobService) { + IrisWebsocketService websocketService, PyrisJobService pyrisJobService, UserRepository userRepository) { this.pyrisPipelineService = pyrisPipelineService; this.llmTokenUsageService = llmTokenUsageService; this.courseRepository = courseRepository; this.websocketService = websocketService; this.pyrisJobService = pyrisJobService; + this.userRepository = userRepository; } /** @@ -58,7 +62,7 @@ public void executeCompetencyExtractionPipeline(User user, Course course, String pyrisPipelineService.executePipeline( "competency-extraction", "default", - pyrisJobService.createTokenForJob(token -> new CompetencyExtractionJob(token, course.getId(), user)), + pyrisJobService.createTokenForJob(token -> new CompetencyExtractionJob(token, course.getId(), user.getId())), executionDto -> new PyrisCompetencyExtractionPipelineExecutionDTO(executionDto, courseDescription, currentCompetencies, CompetencyTaxonomy.values(), 5), stages -> websocketService.send(user.getLogin(), websocketTopic(course.getId()), new PyrisCompetencyStatusUpdateDTO(stages, null, null)) ); @@ -74,9 +78,11 @@ public void executeCompetencyExtractionPipeline(User user, Course course, String public void handleStatusUpdate(CompetencyExtractionJob job, PyrisCompetencyStatusUpdateDTO statusUpdate) { Course course = courseRepository.findByIdForUpdateElseThrow(job.courseId()); if (statusUpdate.tokens() != null && !statusUpdate.tokens().isEmpty()) { - llmTokenUsageService.saveLLMTokenUsage(statusUpdate.tokens(), LLMServiceType.IRIS, builder -> builder.withCourse(course).withUser(job.user())); + llmTokenUsageService.saveLLMTokenUsage(statusUpdate.tokens(), LLMServiceType.IRIS, builder -> builder.withCourse(course.getId()).withUser(job.userId())); } - websocketService.send(job.user().getLogin(), websocketTopic(job.courseId()), statusUpdate); + + var user = userRepository.findById(job.userId()).orElseThrow(); + websocketService.send(user.getLogin(), websocketTopic(job.courseId()), statusUpdate); } private static String websocketTopic(long courseId) { diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/CompetencyExtractionJob.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/CompetencyExtractionJob.java index 136a1a5ae243..b50d8e70b8c9 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/CompetencyExtractionJob.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/CompetencyExtractionJob.java @@ -3,17 +3,16 @@ import com.fasterxml.jackson.annotation.JsonInclude; import de.tum.cit.aet.artemis.core.domain.Course; -import de.tum.cit.aet.artemis.core.domain.User; /** * A pyris job that extracts competencies from a course description. * * @param jobId the job id * @param courseId the course in which the competencies are being extracted - * @param user the user who started the job + * @param userId the user who started the job */ @JsonInclude(JsonInclude.Include.NON_EMPTY) -public record CompetencyExtractionJob(String jobId, long courseId, User user) implements PyrisJob { +public record CompetencyExtractionJob(String jobId, long courseId, long userId) implements PyrisJob { @Override public boolean canAccess(Course course) { diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/CourseChatJob.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/CourseChatJob.java index fb4b93a28854..c05cbf9b94ea 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/CourseChatJob.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/CourseChatJob.java @@ -9,7 +9,7 @@ * This job is used to reference the details of a course chat session when Pyris sends a status update. */ @JsonInclude(JsonInclude.Include.NON_EMPTY) -public record CourseChatJob(String jobId, long courseId, long sessionId) implements PyrisJob { +public record CourseChatJob(String jobId, long courseId, long sessionId) implements SessionBasedPyrisJob { @Override public boolean canAccess(Course course) { diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/ExerciseChatJob.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/ExerciseChatJob.java index 302ae274d8e2..1c2278cb2697 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/ExerciseChatJob.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/ExerciseChatJob.java @@ -10,7 +10,7 @@ * This job is used to reference the details of a exercise chat session when Pyris sends a status update. */ @JsonInclude(JsonInclude.Include.NON_EMPTY) -public record ExerciseChatJob(String jobId, long courseId, long exerciseId, long sessionId) implements PyrisJob { +public record ExerciseChatJob(String jobId, long courseId, long exerciseId, long sessionId) implements SessionBasedPyrisJob { @Override public boolean canAccess(Course course) { diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/SessionBasedPyrisJob.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/SessionBasedPyrisJob.java new file mode 100644 index 000000000000..03c2e4007838 --- /dev/null +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/SessionBasedPyrisJob.java @@ -0,0 +1,9 @@ +package de.tum.cit.aet.artemis.iris.service.pyris.job; + +/** + * An interface Pyris job that is associated with a session. + */ +public interface SessionBasedPyrisJob extends PyrisJob { + + long sessionId(); +} diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/AbstractIrisChatSessionService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/AbstractIrisChatSessionService.java index 559e21668775..16df99a68337 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/AbstractIrisChatSessionService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/AbstractIrisChatSessionService.java @@ -6,21 +6,40 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; +import de.tum.cit.aet.artemis.core.domain.LLMServiceType; import de.tum.cit.aet.artemis.core.domain.LLMTokenUsageTrace; +import de.tum.cit.aet.artemis.core.service.LLMTokenUsageService; +import de.tum.cit.aet.artemis.iris.domain.message.IrisMessage; +import de.tum.cit.aet.artemis.iris.domain.message.IrisMessageSender; +import de.tum.cit.aet.artemis.iris.domain.message.IrisTextMessageContent; import de.tum.cit.aet.artemis.iris.domain.session.IrisChatSession; import de.tum.cit.aet.artemis.iris.repository.IrisSessionRepository; +import de.tum.cit.aet.artemis.iris.service.IrisMessageService; +import de.tum.cit.aet.artemis.iris.service.pyris.dto.chat.PyrisChatStatusUpdateDTO; +import de.tum.cit.aet.artemis.iris.service.pyris.job.SessionBasedPyrisJob; +import de.tum.cit.aet.artemis.iris.service.websocket.IrisChatWebsocketService; public abstract class AbstractIrisChatSessionService implements IrisChatBasedFeatureInterface, IrisRateLimitedFeatureInterface { private final IrisSessionRepository irisSessionRepository; + private final IrisMessageService irisMessageService; + + private final IrisChatWebsocketService irisChatWebsocketService; + + private final LLMTokenUsageService llmTokenUsageService; + private final ObjectMapper objectMapper; protected final HashMap traces = new HashMap<>(); - public AbstractIrisChatSessionService(IrisSessionRepository irisSessionRepository, ObjectMapper objectMapper) { + public AbstractIrisChatSessionService(IrisSessionRepository irisSessionRepository, ObjectMapper objectMapper, IrisMessageService irisMessageService, + IrisChatWebsocketService irisChatWebsocketService, LLMTokenUsageService llmTokenUsageService) { this.irisSessionRepository = irisSessionRepository; this.objectMapper = objectMapper; + this.irisMessageService = irisMessageService; + this.irisChatWebsocketService = irisChatWebsocketService; + this.llmTokenUsageService = llmTokenUsageService; } /** @@ -44,4 +63,57 @@ protected void updateLatestSuggestions(S session, List latestSuggestions throw new RuntimeException("Could not update latest suggestions for session " + session.getId(), e); } } + + /** + * Handles the status update of a ExerciseChatJob by sending the result to the student via the Websocket. + * + * @param job The job that was executed + * @param statusUpdate The status update of the job + */ + public void handleStatusUpdate(SessionBasedPyrisJob job, PyrisChatStatusUpdateDTO statusUpdate) { + var session = (S) irisSessionRepository.findByIdWithMessagesAndContents(job.sessionId()); + IrisMessage savedMessage; + if (statusUpdate.result() != null) { + var message = new IrisMessage(); + message.addContent(new IrisTextMessageContent(statusUpdate.result())); + savedMessage = irisMessageService.saveMessage(message, session, IrisMessageSender.LLM); + irisChatWebsocketService.sendMessage(session, savedMessage, statusUpdate.stages()); + } + else { + savedMessage = null; + irisChatWebsocketService.sendStatusUpdate(session, statusUpdate.stages(), statusUpdate.suggestions(), statusUpdate.tokens()); + } + + if (statusUpdate.tokens() != null && !statusUpdate.tokens().isEmpty()) { + if (savedMessage != null) { + // generated message is first sent and generated trace is saved + var llmTokenUsageTrace = llmTokenUsageService.saveLLMTokenUsage(statusUpdate.tokens(), LLMServiceType.IRIS, builder -> { + builder.withIrisMessageID(savedMessage.getId()).withUser(session.getUser().getId()); + this.setLLMTokenUsageParameters(builder, session); + return builder; + }); + traces.put(job.jobId(), llmTokenUsageTrace); + } + else { + // interaction suggestion is sent and appended to the generated trace if it exists, trace is then removed, + // because interaction suggestion is the last message from Iris in the pipeline + if (traces.containsKey(job.jobId())) { + var trace = traces.get(job.jobId()); + llmTokenUsageService.appendRequestsToTrace(statusUpdate.tokens(), trace); + traces.remove(job.jobId()); + } + else { + llmTokenUsageService.saveLLMTokenUsage(statusUpdate.tokens(), LLMServiceType.IRIS, builder -> { + builder.withUser(session.getUser().getId()); + this.setLLMTokenUsageParameters(builder, session); + return builder; + }); + } + } + } + + updateLatestSuggestions(session, statusUpdate.suggestions()); + } + + protected abstract void setLLMTokenUsageParameters(LLMTokenUsageService.LLMTokenUsageBuilder builder, S session); } diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java index 75c8eb1430da..d2743c2e71a5 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java @@ -15,15 +15,12 @@ import de.tum.cit.aet.artemis.atlas.domain.competency.CompetencyJol; import de.tum.cit.aet.artemis.core.domain.Course; -import de.tum.cit.aet.artemis.core.domain.LLMServiceType; import de.tum.cit.aet.artemis.core.domain.User; import de.tum.cit.aet.artemis.core.exception.AccessForbiddenException; import de.tum.cit.aet.artemis.core.security.Role; import de.tum.cit.aet.artemis.core.service.AuthorizationCheckService; import de.tum.cit.aet.artemis.core.service.LLMTokenUsageService; import de.tum.cit.aet.artemis.iris.domain.message.IrisMessage; -import de.tum.cit.aet.artemis.iris.domain.message.IrisMessageSender; -import de.tum.cit.aet.artemis.iris.domain.message.IrisTextMessageContent; import de.tum.cit.aet.artemis.iris.domain.session.IrisCourseChatSession; import de.tum.cit.aet.artemis.iris.domain.settings.IrisSubSettingsType; import de.tum.cit.aet.artemis.iris.repository.IrisCourseChatSessionRepository; @@ -31,8 +28,6 @@ import de.tum.cit.aet.artemis.iris.service.IrisMessageService; import de.tum.cit.aet.artemis.iris.service.IrisRateLimitService; import de.tum.cit.aet.artemis.iris.service.pyris.PyrisPipelineService; -import de.tum.cit.aet.artemis.iris.service.pyris.dto.chat.PyrisChatStatusUpdateDTO; -import de.tum.cit.aet.artemis.iris.service.pyris.job.CourseChatJob; import de.tum.cit.aet.artemis.iris.service.settings.IrisSettingsService; import de.tum.cit.aet.artemis.iris.service.websocket.IrisChatWebsocketService; @@ -43,10 +38,6 @@ @Profile(PROFILE_IRIS) public class IrisCourseChatSessionService extends AbstractIrisChatSessionService { - private final IrisMessageService irisMessageService; - - private final LLMTokenUsageService llmTokenUsageService; - private final IrisSettingsService irisSettingsService; private final IrisChatWebsocketService irisChatWebsocketService; @@ -65,9 +56,7 @@ public IrisCourseChatSessionService(IrisMessageService irisMessageService, LLMTo IrisChatWebsocketService irisChatWebsocketService, AuthorizationCheckService authCheckService, IrisSessionRepository irisSessionRepository, IrisRateLimitService rateLimitService, IrisCourseChatSessionRepository irisCourseChatSessionRepository, PyrisPipelineService pyrisPipelineService, ObjectMapper objectMapper) { - super(irisSessionRepository, objectMapper); - this.irisMessageService = irisMessageService; - this.llmTokenUsageService = llmTokenUsageService; + super(irisSessionRepository, objectMapper, irisMessageService, irisChatWebsocketService, llmTokenUsageService); this.irisSettingsService = irisSettingsService; this.irisChatWebsocketService = irisChatWebsocketService; this.authCheckService = authCheckService; @@ -132,35 +121,9 @@ private void requestAndHandleResponse(IrisCourseChatSession session, String vari pyrisPipelineService.executeCourseChatPipeline(variant, chatSession, competencyJol); } - /** - * Handles the status update of a CourseChatJob by sending the result to the student via the Websocket. - * - * @param job The job that was executed - * @param statusUpdate The status update of the job - */ - public void handleStatusUpdate(CourseChatJob job, PyrisChatStatusUpdateDTO statusUpdate) { - var session = (IrisCourseChatSession) irisSessionRepository.findByIdWithMessagesAndContents(job.sessionId()); - IrisMessage savedMessage; - if (statusUpdate.result() != null) { - var message = new IrisMessage(); - message.addContent(new IrisTextMessageContent(statusUpdate.result())); - savedMessage = irisMessageService.saveMessage(message, session, IrisMessageSender.LLM); - irisChatWebsocketService.sendMessage(session, savedMessage, statusUpdate.stages()); - } - else { - savedMessage = null; - irisChatWebsocketService.sendStatusUpdate(session, statusUpdate.stages(), statusUpdate.suggestions(), statusUpdate.tokens()); - } - if (statusUpdate.tokens() != null && !statusUpdate.tokens().isEmpty()) { - if (savedMessage != null) { - llmTokenUsageService.saveLLMTokenUsage(statusUpdate.tokens(), LLMServiceType.IRIS, - builder -> builder.withIrisMessageID(savedMessage.getId()).withUser(session.getUser()).withCourse(session.getCourse())); - } - else { - llmTokenUsageService.saveLLMTokenUsage(statusUpdate.tokens(), LLMServiceType.IRIS, builder -> builder.withUser(session.getUser()).withCourse(session.getCourse())); - } - } - updateLatestSuggestions(session, statusUpdate.suggestions()); + @Override + protected void setLLMTokenUsageParameters(LLMTokenUsageService.LLMTokenUsageBuilder builder, IrisCourseChatSession session) { + builder.withCourse(session.getCourse().getId()); } /** diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java index 93db671f2783..a51f1730e98c 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java @@ -10,7 +10,6 @@ import com.fasterxml.jackson.databind.ObjectMapper; -import de.tum.cit.aet.artemis.core.domain.LLMServiceType; import de.tum.cit.aet.artemis.core.domain.User; import de.tum.cit.aet.artemis.core.exception.AccessForbiddenException; import de.tum.cit.aet.artemis.core.exception.ConflictException; @@ -19,16 +18,12 @@ import de.tum.cit.aet.artemis.core.service.LLMTokenUsageService; import de.tum.cit.aet.artemis.exercise.domain.Submission; import de.tum.cit.aet.artemis.iris.domain.message.IrisMessage; -import de.tum.cit.aet.artemis.iris.domain.message.IrisMessageSender; -import de.tum.cit.aet.artemis.iris.domain.message.IrisTextMessageContent; import de.tum.cit.aet.artemis.iris.domain.session.IrisExerciseChatSession; import de.tum.cit.aet.artemis.iris.domain.settings.IrisSubSettingsType; import de.tum.cit.aet.artemis.iris.repository.IrisSessionRepository; import de.tum.cit.aet.artemis.iris.service.IrisMessageService; import de.tum.cit.aet.artemis.iris.service.IrisRateLimitService; import de.tum.cit.aet.artemis.iris.service.pyris.PyrisPipelineService; -import de.tum.cit.aet.artemis.iris.service.pyris.dto.chat.PyrisChatStatusUpdateDTO; -import de.tum.cit.aet.artemis.iris.service.pyris.job.ExerciseChatJob; import de.tum.cit.aet.artemis.iris.service.settings.IrisSettingsService; import de.tum.cit.aet.artemis.iris.service.websocket.IrisChatWebsocketService; import de.tum.cit.aet.artemis.programming.domain.ProgrammingExercise; @@ -44,10 +39,6 @@ @Profile(PROFILE_IRIS) public class IrisExerciseChatSessionService extends AbstractIrisChatSessionService implements IrisRateLimitedFeatureInterface { - private final IrisMessageService irisMessageService; - - private final LLMTokenUsageService llmTokenUsageService; - private final IrisSettingsService irisSettingsService; private final IrisChatWebsocketService irisChatWebsocketService; @@ -71,9 +62,7 @@ public IrisExerciseChatSessionService(IrisMessageService irisMessageService, LLM ProgrammingExerciseStudentParticipationRepository programmingExerciseStudentParticipationRepository, ProgrammingSubmissionRepository programmingSubmissionRepository, IrisRateLimitService rateLimitService, PyrisPipelineService pyrisPipelineService, ProgrammingExerciseRepository programmingExerciseRepository, ObjectMapper objectMapper) { - super(irisSessionRepository, objectMapper); - this.irisMessageService = irisMessageService; - this.llmTokenUsageService = llmTokenUsageService; + super(irisSessionRepository, objectMapper, irisMessageService, irisChatWebsocketService, llmTokenUsageService); this.irisSettingsService = irisSettingsService; this.irisChatWebsocketService = irisChatWebsocketService; this.authCheckService = authCheckService; @@ -163,49 +152,9 @@ private Optional getLatestSubmissionIfExists(ProgrammingE .flatMap(sub -> programmingSubmissionRepository.findWithEagerResultsAndFeedbacksAndBuildLogsById(sub.getId())); } - /** - * Handles the status update of a ExerciseChatJob by sending the result to the student via the Websocket. - * - * @param job The job that was executed - * @param statusUpdate The status update of the job - */ - public void handleStatusUpdate(ExerciseChatJob job, PyrisChatStatusUpdateDTO statusUpdate) { - var session = (IrisExerciseChatSession) irisSessionRepository.findByIdWithMessagesAndContents(job.sessionId()); - IrisMessage savedMessage; - if (statusUpdate.result() != null) { - var message = new IrisMessage(); - message.addContent(new IrisTextMessageContent(statusUpdate.result())); - savedMessage = irisMessageService.saveMessage(message, session, IrisMessageSender.LLM); - irisChatWebsocketService.sendMessage(session, savedMessage, statusUpdate.stages()); - } - else { - savedMessage = null; - irisChatWebsocketService.sendStatusUpdate(session, statusUpdate.stages(), statusUpdate.suggestions(), statusUpdate.tokens()); - } - - if (statusUpdate.tokens() != null && !statusUpdate.tokens().isEmpty()) { - if (savedMessage != null) { - // generated message is first sent and generated trace is saved - var llmTokenUsageTrace = llmTokenUsageService.saveLLMTokenUsage(statusUpdate.tokens(), LLMServiceType.IRIS, - builder -> builder.withIrisMessageID(savedMessage.getId()).withExercise(session.getExercise()).withUser(session.getUser()) - .withCourse(session.getExercise().getCourseViaExerciseGroupOrCourseMember())); - traces.put(job.jobId(), llmTokenUsageTrace); - } - else { - // interaction suggestion is sent and appended to the generated trace if it exists, trace is then removed, - // because interaction suggestion is the last message from Iris in the pipeline - if (traces.containsKey(job.jobId())) { - var trace = traces.get(job.jobId()); - llmTokenUsageService.appendRequestsToTrace(statusUpdate.tokens(), trace); - traces.remove(job.jobId()); - } - else { - llmTokenUsageService.saveLLMTokenUsage(statusUpdate.tokens(), LLMServiceType.IRIS, builder -> builder.withExercise(session.getExercise()) - .withUser(session.getUser()).withCourse(session.getExercise().getCourseViaExerciseGroupOrCourseMember())); - } - } - } - - updateLatestSuggestions(session, statusUpdate.suggestions()); + @Override + protected void setLLMTokenUsageParameters(LLMTokenUsageService.LLMTokenUsageBuilder builder, IrisExerciseChatSession session) { + var exercise = session.getExercise(); + builder.withCourse(exercise.getCourseViaExerciseGroupOrCourseMember().getId()).withExercise(exercise.getId()); } } diff --git a/src/test/java/de/tum/cit/aet/artemis/iris/IrisCompetencyGenerationIntegrationTest.java b/src/test/java/de/tum/cit/aet/artemis/iris/IrisCompetencyGenerationIntegrationTest.java index 24085a97f70c..7b7279a25053 100644 --- a/src/test/java/de/tum/cit/aet/artemis/iris/IrisCompetencyGenerationIntegrationTest.java +++ b/src/test/java/de/tum/cit/aet/artemis/iris/IrisCompetencyGenerationIntegrationTest.java @@ -69,7 +69,7 @@ void generateCompetencies_asEditor_shouldSucceed() throws Exception { // In the real system, this would be triggered by Pyris via a REST call to the Artemis server String jobId = "testJobId"; String userLogin = TEST_PREFIX + "editor1"; - CompetencyExtractionJob job = new CompetencyExtractionJob(jobId, course.getId(), userUtilService.getUserByLogin(userLogin)); + CompetencyExtractionJob job = new CompetencyExtractionJob(jobId, course.getId(), userUtilService.getUserByLogin(userLogin).getId()); irisCompetencyGenerationService.handleStatusUpdate(job, new PyrisCompetencyStatusUpdateDTO(stages, recommendations, null)); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(PyrisCompetencyStatusUpdateDTO.class); From e437c71665599f240afff627137461d48082d14c Mon Sep 17 00:00:00 2001 From: Alexander Joham Date: Mon, 21 Oct 2024 17:39:10 +0200 Subject: [PATCH 24/30] Update tests to work with new changes --- .../aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java b/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java index 6ae7ebf00400..adb5b009809f 100644 --- a/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java +++ b/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java @@ -160,7 +160,7 @@ void testTokenTrackingSavedExerciseChat() { irisMessageRepository.save(irisMessage); var tokens = getMockLLMCosts(); LLMTokenUsageTrace tokenUsageTrace = llmTokenUsageService.saveLLMTokenUsage(tokens, LLMServiceType.IRIS, - builder -> builder.withIrisMessageID(irisMessage.getId()).withExercise(exercise).withUser(irisSession.getUser()).withCourse(course)); + builder -> builder.withIrisMessageID(irisMessage.getId()).withExercise(exercise.getId()).withUser(irisSession.getUser().getId()).withCourse(course.getId())); assertThat(tokenUsageTrace.getServiceType()).isEqualTo(LLMServiceType.IRIS); assertThat(tokenUsageTrace.getIrisMessageId()).isEqualTo(irisMessage.getId()); assertThat(tokenUsageTrace.getExerciseId()).isEqualTo(exercise.getId()); From cc127af2af5482c6ccfd7e6ba43b4098e67d0749 Mon Sep 17 00:00:00 2001 From: "Felix T.J. Dietrich" Date: Tue, 22 Oct 2024 10:41:24 +0200 Subject: [PATCH 25/30] Athena: Add LLM token usage tracking (#9554) --- .../artemis/athena/dto/ResponseMetaDTO.java | 17 +++++++ .../AthenaFeedbackSuggestionsService.java | 48 +++++++++++++++++-- 2 files changed, 60 insertions(+), 5 deletions(-) create mode 100644 src/main/java/de/tum/cit/aet/artemis/athena/dto/ResponseMetaDTO.java diff --git a/src/main/java/de/tum/cit/aet/artemis/athena/dto/ResponseMetaDTO.java b/src/main/java/de/tum/cit/aet/artemis/athena/dto/ResponseMetaDTO.java new file mode 100644 index 000000000000..e80830620b37 --- /dev/null +++ b/src/main/java/de/tum/cit/aet/artemis/athena/dto/ResponseMetaDTO.java @@ -0,0 +1,17 @@ +package de.tum.cit.aet.artemis.athena.dto; + +import java.util.List; + +import com.fasterxml.jackson.annotation.JsonInclude; + +import de.tum.cit.aet.artemis.core.domain.LLMRequest; + +/** + * DTO representing the meta information in the Athena response. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public record ResponseMetaDTO(TotalUsage totalUsage, List llmRequests) { + + public record TotalUsage(Integer numInputTokens, Integer numOutputTokens, Integer numTotalTokens, Float cost) { + } +} diff --git a/src/main/java/de/tum/cit/aet/artemis/athena/service/AthenaFeedbackSuggestionsService.java b/src/main/java/de/tum/cit/aet/artemis/athena/service/AthenaFeedbackSuggestionsService.java index d9c81849b396..d3632f209ca0 100644 --- a/src/main/java/de/tum/cit/aet/artemis/athena/service/AthenaFeedbackSuggestionsService.java +++ b/src/main/java/de/tum/cit/aet/artemis/athena/service/AthenaFeedbackSuggestionsService.java @@ -17,10 +17,18 @@ import de.tum.cit.aet.artemis.athena.dto.ExerciseBaseDTO; import de.tum.cit.aet.artemis.athena.dto.ModelingFeedbackDTO; import de.tum.cit.aet.artemis.athena.dto.ProgrammingFeedbackDTO; +import de.tum.cit.aet.artemis.athena.dto.ResponseMetaDTO; import de.tum.cit.aet.artemis.athena.dto.SubmissionBaseDTO; import de.tum.cit.aet.artemis.athena.dto.TextFeedbackDTO; +import de.tum.cit.aet.artemis.core.domain.LLMRequest; +import de.tum.cit.aet.artemis.core.domain.LLMServiceType; +import de.tum.cit.aet.artemis.core.domain.User; import de.tum.cit.aet.artemis.core.exception.ConflictException; import de.tum.cit.aet.artemis.core.exception.NetworkingException; +import de.tum.cit.aet.artemis.core.service.LLMTokenUsageService; +import de.tum.cit.aet.artemis.exercise.domain.Exercise; +import de.tum.cit.aet.artemis.exercise.domain.Submission; +import de.tum.cit.aet.artemis.exercise.domain.participation.StudentParticipation; import de.tum.cit.aet.artemis.modeling.domain.ModelingExercise; import de.tum.cit.aet.artemis.modeling.domain.ModelingSubmission; import de.tum.cit.aet.artemis.programming.domain.ProgrammingExercise; @@ -48,20 +56,24 @@ public class AthenaFeedbackSuggestionsService { private final AthenaDTOConverterService athenaDTOConverterService; + private final LLMTokenUsageService llmTokenUsageService; + /** * Create a new AthenaFeedbackSuggestionsService to receive feedback suggestions from the Athena service. * * @param athenaRestTemplate REST template used for the communication with Athena * @param athenaModuleService Athena module serviced used to determine the urls for different modules - * @param athenaDTOConverterService Service to convert exr + * @param athenaDTOConverterService Service to convert exrcises and submissions to DTOs + * @param llmTokenUsageService Service to store the usage of LLM tokens */ public AthenaFeedbackSuggestionsService(@Qualifier("athenaRestTemplate") RestTemplate athenaRestTemplate, AthenaModuleService athenaModuleService, - AthenaDTOConverterService athenaDTOConverterService) { + AthenaDTOConverterService athenaDTOConverterService, LLMTokenUsageService llmTokenUsageService) { textAthenaConnector = new AthenaConnector<>(athenaRestTemplate, ResponseDTOText.class); programmingAthenaConnector = new AthenaConnector<>(athenaRestTemplate, ResponseDTOProgramming.class); modelingAthenaConnector = new AthenaConnector<>(athenaRestTemplate, ResponseDTOModeling.class); this.athenaDTOConverterService = athenaDTOConverterService; this.athenaModuleService = athenaModuleService; + this.llmTokenUsageService = llmTokenUsageService; } @JsonInclude(JsonInclude.Include.NON_EMPTY) @@ -69,15 +81,15 @@ private record RequestDTO(ExerciseBaseDTO exercise, SubmissionBaseDTO submission } @JsonInclude(JsonInclude.Include.NON_EMPTY) - private record ResponseDTOText(List data) { + private record ResponseDTOText(List data, ResponseMetaDTO meta) { } @JsonInclude(JsonInclude.Include.NON_EMPTY) - private record ResponseDTOProgramming(List data) { + private record ResponseDTOProgramming(List data, ResponseMetaDTO meta) { } @JsonInclude(JsonInclude.Include.NON_EMPTY) - private record ResponseDTOModeling(List data) { + private record ResponseDTOModeling(List data, ResponseMetaDTO meta) { } /** @@ -100,6 +112,7 @@ public List getTextFeedbackSuggestions(TextExercise exercise, T final RequestDTO request = new RequestDTO(athenaDTOConverterService.ofExercise(exercise), athenaDTOConverterService.ofSubmission(exercise.getId(), submission), isGraded); ResponseDTOText response = textAthenaConnector.invokeWithRetry(athenaModuleService.getAthenaModuleUrl(exercise) + "/feedback_suggestions", request, 0); log.info("Athena responded to '{}' feedback suggestions request: {}", isGraded ? "Graded" : "Non Graded", response.data); + storeTokenUsage(exercise, submission, response.meta, !isGraded); return response.data.stream().toList(); } @@ -117,6 +130,7 @@ public List getProgrammingFeedbackSuggestions(Programmin final RequestDTO request = new RequestDTO(athenaDTOConverterService.ofExercise(exercise), athenaDTOConverterService.ofSubmission(exercise.getId(), submission), isGraded); ResponseDTOProgramming response = programmingAthenaConnector.invokeWithRetry(athenaModuleService.getAthenaModuleUrl(exercise) + "/feedback_suggestions", request, 0); log.info("Athena responded to '{}' feedback suggestions request: {}", isGraded ? "Graded" : "Non Graded", response.data); + storeTokenUsage(exercise, submission, response.meta, !isGraded); return response.data.stream().toList(); } @@ -139,6 +153,30 @@ public List getModelingFeedbackSuggestions(ModelingExercise final RequestDTO request = new RequestDTO(athenaDTOConverterService.ofExercise(exercise), athenaDTOConverterService.ofSubmission(exercise.getId(), submission), isGraded); ResponseDTOModeling response = modelingAthenaConnector.invokeWithRetry(athenaModuleService.getAthenaModuleUrl(exercise) + "/feedback_suggestions", request, 0); log.info("Athena responded to '{}' feedback suggestions request: {}", isGraded ? "Graded" : "Non Graded", response.data); + storeTokenUsage(exercise, submission, response.meta, !isGraded); return response.data; } + + /** + * Store the usage of LLM tokens for a given submission + * + * @param exercise the exercise the submission belongs to + * @param submission the submission for which the tokens were used + * @param meta the meta information of the response from Athena + * @param isPreliminaryFeedback whether the feedback is preliminary or not + */ + private void storeTokenUsage(Exercise exercise, Submission submission, ResponseMetaDTO meta, Boolean isPreliminaryFeedback) { + if (meta == null) { + return; + } + Long courseId = exercise.getCourseViaExerciseGroupOrCourseMember().getId(); + Long userId = ((StudentParticipation) submission.getParticipation()).getStudent().map(User::getId).orElse(null); + List llmRequests = meta.llmRequests(); + if (llmRequests == null) { + return; + } + + llmTokenUsageService.saveLLMTokenUsage(llmRequests, LLMServiceType.ATHENA, + (llmTokenUsageBuilder -> llmTokenUsageBuilder.withCourse(courseId).withExercise(exercise.getId()).withUser(userId))); + } } From aded0ee127e1a66f020b0e16401e7578f930f53b Mon Sep 17 00:00:00 2001 From: Alexander Joham Date: Tue, 22 Oct 2024 14:31:49 +0200 Subject: [PATCH 26/30] Implement feedback and fix server tests --- .../cit/aet/artemis/athena/dto/ResponseMetaDTO.java | 2 +- .../service/AthenaFeedbackSuggestionsService.java | 8 +++++++- .../de/tum/cit/aet/artemis/core/domain/LLMRequest.java | 10 ++++++++++ .../aet/artemis/core/domain/LLMTokenUsageRequest.java | 9 +++++++++ .../aet/artemis/core/domain/LLMTokenUsageTrace.java | 3 +++ 5 files changed, 30 insertions(+), 2 deletions(-) diff --git a/src/main/java/de/tum/cit/aet/artemis/athena/dto/ResponseMetaDTO.java b/src/main/java/de/tum/cit/aet/artemis/athena/dto/ResponseMetaDTO.java index e80830620b37..44d36a033552 100644 --- a/src/main/java/de/tum/cit/aet/artemis/athena/dto/ResponseMetaDTO.java +++ b/src/main/java/de/tum/cit/aet/artemis/athena/dto/ResponseMetaDTO.java @@ -9,7 +9,7 @@ /** * DTO representing the meta information in the Athena response. */ -@JsonInclude(JsonInclude.Include.NON_NULL) +@JsonInclude(JsonInclude.Include.NON_EMPTY) public record ResponseMetaDTO(TotalUsage totalUsage, List llmRequests) { public record TotalUsage(Integer numInputTokens, Integer numOutputTokens, Integer numTotalTokens, Float cost) { diff --git a/src/main/java/de/tum/cit/aet/artemis/athena/service/AthenaFeedbackSuggestionsService.java b/src/main/java/de/tum/cit/aet/artemis/athena/service/AthenaFeedbackSuggestionsService.java index d3632f209ca0..210b3c7ba859 100644 --- a/src/main/java/de/tum/cit/aet/artemis/athena/service/AthenaFeedbackSuggestionsService.java +++ b/src/main/java/de/tum/cit/aet/artemis/athena/service/AthenaFeedbackSuggestionsService.java @@ -170,7 +170,13 @@ private void storeTokenUsage(Exercise exercise, Submission submission, ResponseM return; } Long courseId = exercise.getCourseViaExerciseGroupOrCourseMember().getId(); - Long userId = ((StudentParticipation) submission.getParticipation()).getStudent().map(User::getId).orElse(null); + Long userId; + if (submission.getParticipation() instanceof StudentParticipation studentParticipation) { + userId = studentParticipation.getStudent().map(User::getId).orElse(null); + } + else { + userId = null; + } List llmRequests = meta.llmRequests(); if (llmRequests == null) { return; diff --git a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMRequest.java b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMRequest.java index bc3ff7bbe23f..040b6ad88893 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMRequest.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMRequest.java @@ -1,4 +1,14 @@ package de.tum.cit.aet.artemis.core.domain; +/** + * This record is used for the LLMTokenUsageService to provide relevant information about LLM Token usage + * + * @param model LLM model (e.g. gpt-4o) + * @param numInputTokens number of tokens of the LLM call + * @param costPerMillionInputToken cost in Euro per million input tokens + * @param numOutputTokens number of tokens of the LLM answer + * @param costPerMillionOutputToken cost in Euro per million output tokens + * @param pipelineId String with the pipeline name (e.g. IRIS_COURSE_CHAT_PIPELINE) + */ public record LLMRequest(String model, int numInputTokens, float costPerMillionInputToken, int numOutputTokens, float costPerMillionOutputToken, String pipelineId) { } diff --git a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsageRequest.java b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsageRequest.java index 1b769f5ea97b..1bdaefaf58ba 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsageRequest.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsageRequest.java @@ -10,15 +10,24 @@ import com.fasterxml.jackson.annotation.JsonInclude; +/** + * LLMTokenUsageRequest represents a single LLM request usage with all its information about the request + */ @Entity @Table(name = "llm_token_usage_request") @Cache(usage = CacheConcurrencyStrategy.NONSTRICT_READ_WRITE) @JsonInclude(JsonInclude.Include.NON_EMPTY) public class LLMTokenUsageRequest extends DomainObject { + /** + * LLM model (e.g. gpt-4o) + */ @Column(name = "model") private String model; + /** + * pipeline that was called (e.g. IRIS_COURSE_CHAT_PIPELINE) + */ @Column(name = "service_pipeline_id") private String servicePipelineId; diff --git a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsageTrace.java b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsageTrace.java index cdf263da00c6..7e37fdd97c54 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsageTrace.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsageTrace.java @@ -19,6 +19,9 @@ import com.fasterxml.jackson.annotation.JsonInclude; +/** + * This represents a trace that contains one or more requests of type {@link LLMTokenUsageRequest} + */ @Entity @Table(name = "llm_token_usage_trace") @Cache(usage = CacheConcurrencyStrategy.NONSTRICT_READ_WRITE) From 8a7bc2e92b726eb526a45c5f8c5cd1447bd76b18 Mon Sep 17 00:00:00 2001 From: Alexander Joham Date: Tue, 22 Oct 2024 15:01:24 +0200 Subject: [PATCH 27/30] add foreign keys with onDelete=SET NULL to all ids in LLMTokenUsageTrace --- .../liquibase/changelog/20241018053210_changelog.xml | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/main/resources/config/liquibase/changelog/20241018053210_changelog.xml b/src/main/resources/config/liquibase/changelog/20241018053210_changelog.xml index b5c64a0b9eec..e514ec8e5f58 100644 --- a/src/main/resources/config/liquibase/changelog/20241018053210_changelog.xml +++ b/src/main/resources/config/liquibase/changelog/20241018053210_changelog.xml @@ -17,6 +17,18 @@ + + + + From c7e7db6d342e147caa8500daca985be728128aac Mon Sep 17 00:00:00 2001 From: Alexander Joham Date: Tue, 22 Oct 2024 15:08:56 +0200 Subject: [PATCH 28/30] Correct wrong Long type, update comment --- .../tum/cit/aet/artemis/core/domain/LLMTokenUsageRequest.java | 2 +- .../tum/cit/aet/artemis/core/domain/LLMTokenUsageTrace.java | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsageRequest.java b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsageRequest.java index 1bdaefaf58ba..81d7ca8f21a8 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsageRequest.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsageRequest.java @@ -11,7 +11,7 @@ import com.fasterxml.jackson.annotation.JsonInclude; /** - * LLMTokenUsageRequest represents a single LLM request usage with all its information about the request + * Represents the token usage details of a single LLM request, including model, service pipeline, token counts, and costs. */ @Entity @Table(name = "llm_token_usage_request") diff --git a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsageTrace.java b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsageTrace.java index 7e37fdd97c54..1773a0c507da 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsageTrace.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsageTrace.java @@ -77,11 +77,11 @@ public void setExerciseId(Long exerciseId) { this.exerciseId = exerciseId; } - public long getUserId() { + public Long getUserId() { return userId; } - public void setUserId(long userId) { + public void setUserId(Long userId) { this.userId = userId; } From 79b1b88cd0be5af8284c01ad72d35aca89301ba0 Mon Sep 17 00:00:00 2001 From: Patrick Bassner Date: Wed, 23 Oct 2024 19:42:20 +0200 Subject: [PATCH 29/30] Make LLM token tracking of chat suggestions multi-node compatible --- .../core/service/LLMTokenUsageService.java | 10 +++++ .../IrisCompetencyGenerationService.java | 4 +- .../iris/service/pyris/PyrisJobService.java | 19 ++++++--- .../pyris/PyrisStatusUpdateService.java | 36 +++++++++------- .../iris/service/pyris/job/CourseChatJob.java | 7 +++- .../service/pyris/job/ExerciseChatJob.java | 7 +++- .../pyris/job/SessionBasedPyrisJob.java | 9 ---- .../job/TrackedSessionBasedPyrisJob.java | 14 +++++++ .../AbstractIrisChatSessionService.java | 42 +++++++++---------- .../IrisTextExerciseChatSessionService.java | 5 ++- 10 files changed, 99 insertions(+), 54 deletions(-) delete mode 100644 src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/SessionBasedPyrisJob.java create mode 100644 src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/TrackedSessionBasedPyrisJob.java diff --git a/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java b/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java index 5ffe5f379ff5..c3dc2af1e519 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java @@ -81,6 +81,16 @@ public void appendRequestsToTrace(List requests, LLMTokenUsageTrace llmTokenUsageRequestRepository.saveAll(requestSet); } + /** + * Finds an LLMTokenUsageTrace by its ID. + * + * @param id The ID of the LLMTokenUsageTrace to find. + * @return An Optional containing the LLMTokenUsageTrace if found, or an empty Optional otherwise. + */ + public Optional findLLMTokenUsageTraceById(Long id) { + return llmTokenUsageTraceRepository.findById(id); + } + /** * Class LLMTokenUsageBuilder to be used for saveLLMTokenUsage() */ diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java index f8d2a0201198..49e08cae1fd7 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java @@ -75,7 +75,7 @@ public void executeCompetencyExtractionPipeline(User user, Course course, String * @param job Job related to the status update * @param statusUpdate the status update containing the new competency recommendations */ - public void handleStatusUpdate(CompetencyExtractionJob job, PyrisCompetencyStatusUpdateDTO statusUpdate) { + public CompetencyExtractionJob handleStatusUpdate(CompetencyExtractionJob job, PyrisCompetencyStatusUpdateDTO statusUpdate) { Course course = courseRepository.findByIdForUpdateElseThrow(job.courseId()); if (statusUpdate.tokens() != null && !statusUpdate.tokens().isEmpty()) { llmTokenUsageService.saveLLMTokenUsage(statusUpdate.tokens(), LLMServiceType.IRIS, builder -> builder.withCourse(course.getId()).withUser(job.userId())); @@ -83,6 +83,8 @@ public void handleStatusUpdate(CompetencyExtractionJob job, PyrisCompetencyStatu var user = userRepository.findById(job.userId()).orElseThrow(); websocketService.send(user.getLogin(), websocketTopic(job.courseId()), statusUpdate); + + return job; } private static String websocketTopic(long courseId) { diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/PyrisJobService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/PyrisJobService.java index 7933e9e20920..16e8969bc463 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/PyrisJobService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/PyrisJobService.java @@ -78,14 +78,14 @@ public String createTokenForJob(Function tokenToJobFunction) { public String addExerciseChatJob(Long courseId, Long exerciseId, Long sessionId) { var token = generateJobIdToken(); - var job = new ExerciseChatJob(token, courseId, exerciseId, sessionId); + var job = new ExerciseChatJob(token, courseId, exerciseId, sessionId, null); jobMap.put(token, job); return token; } public String addCourseChatJob(Long courseId, Long sessionId) { var token = generateJobIdToken(); - var job = new CourseChatJob(token, courseId, sessionId); + var job = new CourseChatJob(token, courseId, sessionId, null); jobMap.put(token, job); return token; } @@ -107,10 +107,19 @@ public String addIngestionWebhookJob() { /** * Remove a job from the job map. * - * @param token the token + * @param job the job to remove + */ + public void removeJob(PyrisJob job) { + jobMap.remove(job.jobId()); + } + + /** + * Store a job in the job map. + * + * @param job the job to store */ - public void removeJob(String token) { - jobMap.remove(token); + public void updateJob(PyrisJob job) { + jobMap.put(job.jobId(), job); } /** diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/PyrisStatusUpdateService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/PyrisStatusUpdateService.java index 1526311fe8c7..cdd398e5c683 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/PyrisStatusUpdateService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/PyrisStatusUpdateService.java @@ -20,7 +20,9 @@ import de.tum.cit.aet.artemis.iris.service.pyris.job.CourseChatJob; import de.tum.cit.aet.artemis.iris.service.pyris.job.ExerciseChatJob; import de.tum.cit.aet.artemis.iris.service.pyris.job.IngestionWebhookJob; +import de.tum.cit.aet.artemis.iris.service.pyris.job.PyrisJob; import de.tum.cit.aet.artemis.iris.service.pyris.job.TextExerciseChatJob; +import de.tum.cit.aet.artemis.iris.service.pyris.job.TrackedSessionBasedPyrisJob; import de.tum.cit.aet.artemis.iris.service.session.IrisCourseChatSessionService; import de.tum.cit.aet.artemis.iris.service.session.IrisExerciseChatSessionService; import de.tum.cit.aet.artemis.iris.service.session.IrisTextExerciseChatSessionService; @@ -52,15 +54,16 @@ public PyrisStatusUpdateService(PyrisJobService pyrisJobService, IrisExerciseCha } /** - * Handles the status update of a exercise chat job and forwards it to {@link IrisExerciseChatSessionService#handleStatusUpdate(ExerciseChatJob, PyrisChatStatusUpdateDTO)} + * Handles the status update of a exercise chat job and forwards it to + * {@link IrisExerciseChatSessionService#handleStatusUpdate(TrackedSessionBasedPyrisJob, PyrisChatStatusUpdateDTO)} * * @param job the job that is updated * @param statusUpdate the status update */ public void handleStatusUpdate(ExerciseChatJob job, PyrisChatStatusUpdateDTO statusUpdate) { - irisExerciseChatSessionService.handleStatusUpdate(job, statusUpdate); + var updatedJob = irisExerciseChatSessionService.handleStatusUpdate(job, statusUpdate); - removeJobIfTerminated(statusUpdate.stages(), job.jobId()); + removeJobIfTerminatedElseUpdate(statusUpdate.stages(), updatedJob); } /** @@ -71,22 +74,22 @@ public void handleStatusUpdate(ExerciseChatJob job, PyrisChatStatusUpdateDTO sta * @param statusUpdate the status update */ public void handleStatusUpdate(TextExerciseChatJob job, PyrisTextExerciseChatStatusUpdateDTO statusUpdate) { - irisTextExerciseChatSessionService.handleStatusUpdate(job, statusUpdate); + var updatedJob = irisTextExerciseChatSessionService.handleStatusUpdate(job, statusUpdate); - removeJobIfTerminated(statusUpdate.stages(), job.jobId()); + removeJobIfTerminatedElseUpdate(statusUpdate.stages(), updatedJob); } /** * Handles the status update of a course chat job and forwards it to - * {@link de.tum.cit.aet.artemis.iris.service.session.IrisCourseChatSessionService#handleStatusUpdate(CourseChatJob, PyrisChatStatusUpdateDTO)} + * {@link de.tum.cit.aet.artemis.iris.service.session.IrisCourseChatSessionService#handleStatusUpdate(TrackedSessionBasedPyrisJob, PyrisChatStatusUpdateDTO)} * * @param job the job that is updated * @param statusUpdate the status update */ public void handleStatusUpdate(CourseChatJob job, PyrisChatStatusUpdateDTO statusUpdate) { - courseChatSessionService.handleStatusUpdate(job, statusUpdate); + var updatedJob = courseChatSessionService.handleStatusUpdate(job, statusUpdate); - removeJobIfTerminated(statusUpdate.stages(), job.jobId()); + removeJobIfTerminatedElseUpdate(statusUpdate.stages(), updatedJob); } /** @@ -97,26 +100,29 @@ public void handleStatusUpdate(CourseChatJob job, PyrisChatStatusUpdateDTO statu * @param statusUpdate the status update */ public void handleStatusUpdate(CompetencyExtractionJob job, PyrisCompetencyStatusUpdateDTO statusUpdate) { - competencyGenerationService.handleStatusUpdate(job, statusUpdate); + var updatedJob = competencyGenerationService.handleStatusUpdate(job, statusUpdate); - removeJobIfTerminated(statusUpdate.stages(), job.jobId()); + removeJobIfTerminatedElseUpdate(statusUpdate.stages(), updatedJob); } /** - * Removes the job from the job service if the status update indicates that the job is terminated. - * This is the case if all stages are in a terminal state. + * Removes the job from the job service if the status update indicates that the job is terminated; updates it to distribute changes otherwise. + * A job is terminated if all stages are in a terminal state. *

* * @see PyrisStageState#isTerminal() * * @param stages the stages of the status update - * @param job the job to remove + * @param job the job to remove or to update */ - private void removeJobIfTerminated(List stages, String job) { + private void removeJobIfTerminatedElseUpdate(List stages, PyrisJob job) { var isDone = stages.stream().map(PyrisStageDTO::state).allMatch(PyrisStageState::isTerminal); if (isDone) { pyrisJobService.removeJob(job); } + else { + pyrisJobService.updateJob(job); + } } /** @@ -128,6 +134,6 @@ private void removeJobIfTerminated(List stages, String job) { */ public void handleStatusUpdate(IngestionWebhookJob job, PyrisLectureIngestionStatusUpdateDTO statusUpdate) { statusUpdate.stages().forEach(stage -> log.info(stage.name() + ":" + stage.message())); - removeJobIfTerminated(statusUpdate.stages(), job.jobId()); + removeJobIfTerminatedElseUpdate(statusUpdate.stages(), job); } } diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/CourseChatJob.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/CourseChatJob.java index c05cbf9b94ea..2f389e22ed96 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/CourseChatJob.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/CourseChatJob.java @@ -9,10 +9,15 @@ * This job is used to reference the details of a course chat session when Pyris sends a status update. */ @JsonInclude(JsonInclude.Include.NON_EMPTY) -public record CourseChatJob(String jobId, long courseId, long sessionId) implements SessionBasedPyrisJob { +public record CourseChatJob(String jobId, long courseId, long sessionId, Long traceId) implements TrackedSessionBasedPyrisJob { @Override public boolean canAccess(Course course) { return courseId == course.getId(); } + + @Override + public TrackedSessionBasedPyrisJob withTraceId(long traceId) { + return new CourseChatJob(jobId, courseId, sessionId, traceId); + } } diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/ExerciseChatJob.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/ExerciseChatJob.java index 1c2278cb2697..f74e7360be82 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/ExerciseChatJob.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/ExerciseChatJob.java @@ -10,7 +10,7 @@ * This job is used to reference the details of a exercise chat session when Pyris sends a status update. */ @JsonInclude(JsonInclude.Include.NON_EMPTY) -public record ExerciseChatJob(String jobId, long courseId, long exerciseId, long sessionId) implements SessionBasedPyrisJob { +public record ExerciseChatJob(String jobId, long courseId, long exerciseId, long sessionId, Long traceId) implements TrackedSessionBasedPyrisJob { @Override public boolean canAccess(Course course) { @@ -21,4 +21,9 @@ public boolean canAccess(Course course) { public boolean canAccess(Exercise exercise) { return exercise.getId().equals(exerciseId); } + + @Override + public TrackedSessionBasedPyrisJob withTraceId(long traceId) { + return new ExerciseChatJob(jobId, courseId, exerciseId, sessionId, traceId); + } } diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/SessionBasedPyrisJob.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/SessionBasedPyrisJob.java deleted file mode 100644 index 03c2e4007838..000000000000 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/SessionBasedPyrisJob.java +++ /dev/null @@ -1,9 +0,0 @@ -package de.tum.cit.aet.artemis.iris.service.pyris.job; - -/** - * An interface Pyris job that is associated with a session. - */ -public interface SessionBasedPyrisJob extends PyrisJob { - - long sessionId(); -} diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/TrackedSessionBasedPyrisJob.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/TrackedSessionBasedPyrisJob.java new file mode 100644 index 000000000000..bdd180103840 --- /dev/null +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/TrackedSessionBasedPyrisJob.java @@ -0,0 +1,14 @@ +package de.tum.cit.aet.artemis.iris.service.pyris.job; + +/** + * A Pyris job that has a session id and stored its own LLM usage tracing ID. + * This is used for chat jobs where we need to reference the trace ID later after chat suggestions have been generated. + */ +public interface TrackedSessionBasedPyrisJob extends PyrisJob { + + long sessionId(); + + Long traceId(); + + TrackedSessionBasedPyrisJob withTraceId(long traceId); +} diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/AbstractIrisChatSessionService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/AbstractIrisChatSessionService.java index 16df99a68337..6f0b5a9f411a 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/AbstractIrisChatSessionService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/AbstractIrisChatSessionService.java @@ -1,13 +1,13 @@ package de.tum.cit.aet.artemis.iris.service.session; -import java.util.HashMap; import java.util.List; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicReference; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import de.tum.cit.aet.artemis.core.domain.LLMServiceType; -import de.tum.cit.aet.artemis.core.domain.LLMTokenUsageTrace; import de.tum.cit.aet.artemis.core.service.LLMTokenUsageService; import de.tum.cit.aet.artemis.iris.domain.message.IrisMessage; import de.tum.cit.aet.artemis.iris.domain.message.IrisMessageSender; @@ -16,7 +16,7 @@ import de.tum.cit.aet.artemis.iris.repository.IrisSessionRepository; import de.tum.cit.aet.artemis.iris.service.IrisMessageService; import de.tum.cit.aet.artemis.iris.service.pyris.dto.chat.PyrisChatStatusUpdateDTO; -import de.tum.cit.aet.artemis.iris.service.pyris.job.SessionBasedPyrisJob; +import de.tum.cit.aet.artemis.iris.service.pyris.job.TrackedSessionBasedPyrisJob; import de.tum.cit.aet.artemis.iris.service.websocket.IrisChatWebsocketService; public abstract class AbstractIrisChatSessionService implements IrisChatBasedFeatureInterface, IrisRateLimitedFeatureInterface { @@ -31,8 +31,6 @@ public abstract class AbstractIrisChatSessionService private final ObjectMapper objectMapper; - protected final HashMap traces = new HashMap<>(); - public AbstractIrisChatSessionService(IrisSessionRepository irisSessionRepository, ObjectMapper objectMapper, IrisMessageService irisMessageService, IrisChatWebsocketService irisChatWebsocketService, LLMTokenUsageService llmTokenUsageService) { this.irisSessionRepository = irisSessionRepository; @@ -69,8 +67,9 @@ protected void updateLatestSuggestions(S session, List latestSuggestions * * @param job The job that was executed * @param statusUpdate The status update of the job + * @return the same job record or a new job record with the same job id if changes were made */ - public void handleStatusUpdate(SessionBasedPyrisJob job, PyrisChatStatusUpdateDTO statusUpdate) { + public TrackedSessionBasedPyrisJob handleStatusUpdate(TrackedSessionBasedPyrisJob job, PyrisChatStatusUpdateDTO statusUpdate) { var session = (S) irisSessionRepository.findByIdWithMessagesAndContents(job.sessionId()); IrisMessage savedMessage; if (statusUpdate.result() != null) { @@ -84,6 +83,7 @@ public void handleStatusUpdate(SessionBasedPyrisJob job, PyrisChatStatusUpdateDT irisChatWebsocketService.sendStatusUpdate(session, statusUpdate.stages(), statusUpdate.suggestions(), statusUpdate.tokens()); } + AtomicReference updatedJob = new AtomicReference<>(job); if (statusUpdate.tokens() != null && !statusUpdate.tokens().isEmpty()) { if (savedMessage != null) { // generated message is first sent and generated trace is saved @@ -92,27 +92,27 @@ public void handleStatusUpdate(SessionBasedPyrisJob job, PyrisChatStatusUpdateDT this.setLLMTokenUsageParameters(builder, session); return builder; }); - traces.put(job.jobId(), llmTokenUsageTrace); + + updatedJob.set(job.withTraceId(llmTokenUsageTrace.getId())); } else { - // interaction suggestion is sent and appended to the generated trace if it exists, trace is then removed, - // because interaction suggestion is the last message from Iris in the pipeline - if (traces.containsKey(job.jobId())) { - var trace = traces.get(job.jobId()); - llmTokenUsageService.appendRequestsToTrace(statusUpdate.tokens(), trace); - traces.remove(job.jobId()); - } - else { - llmTokenUsageService.saveLLMTokenUsage(statusUpdate.tokens(), LLMServiceType.IRIS, builder -> { - builder.withUser(session.getUser().getId()); - this.setLLMTokenUsageParameters(builder, session); - return builder; - }); - } + // interaction suggestion is sent and appended to the generated trace if it exists + Optional.ofNullable(job.traceId()).flatMap(llmTokenUsageService::findLLMTokenUsageTraceById) + .ifPresentOrElse(trace -> llmTokenUsageService.appendRequestsToTrace(statusUpdate.tokens(), trace), () -> { + var llmTokenUsage = llmTokenUsageService.saveLLMTokenUsage(statusUpdate.tokens(), LLMServiceType.IRIS, builder -> { + builder.withUser(session.getUser().getId()); + this.setLLMTokenUsageParameters(builder, session); + return builder; + }); + + updatedJob.set(job.withTraceId(llmTokenUsage.getId())); + }); } } updateLatestSuggestions(session, statusUpdate.suggestions()); + + return updatedJob.get(); } protected abstract void setLLMTokenUsageParameters(LLMTokenUsageService.LLMTokenUsageBuilder builder, S session); diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisTextExerciseChatSessionService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisTextExerciseChatSessionService.java index 4520417aad48..5682580ccb89 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisTextExerciseChatSessionService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisTextExerciseChatSessionService.java @@ -116,7 +116,8 @@ public void requestAndHandleResponse(IrisTextExerciseChatSession irisSession) { * @param job The job that is updated * @param statusUpdate The status update */ - public void handleStatusUpdate(TextExerciseChatJob job, PyrisTextExerciseChatStatusUpdateDTO statusUpdate) { + public TextExerciseChatJob handleStatusUpdate(TextExerciseChatJob job, PyrisTextExerciseChatStatusUpdateDTO statusUpdate) { + // TODO: LLM Token Tracking - or better, make this class a subclass of AbstractIrisChatSessionService var session = (IrisTextExerciseChatSession) irisSessionRepository.findByIdElseThrow(job.sessionId()); if (statusUpdate.result() != null) { var message = session.newMessage(); @@ -127,6 +128,8 @@ public void handleStatusUpdate(TextExerciseChatJob job, PyrisTextExerciseChatSta else { irisChatWebsocketService.sendMessage(session, null, statusUpdate.stages()); } + + return job; } @Override From 5a2e92a05eeab2a59c68254eb6f3d1895183382b Mon Sep 17 00:00:00 2001 From: Patrick Bassner Date: Wed, 23 Oct 2024 19:48:53 +0200 Subject: [PATCH 30/30] Added return statements to handleStatusUpdate methods --- .../artemis/iris/service/IrisCompetencyGenerationService.java | 1 + .../iris/service/session/IrisTextExerciseChatSessionService.java | 1 + 2 files changed, 2 insertions(+) diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java index 49e08cae1fd7..88906ff80628 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java @@ -74,6 +74,7 @@ public void executeCompetencyExtractionPipeline(User user, Course course, String * * @param job Job related to the status update * @param statusUpdate the status update containing the new competency recommendations + * @return the same job that was passed in */ public CompetencyExtractionJob handleStatusUpdate(CompetencyExtractionJob job, PyrisCompetencyStatusUpdateDTO statusUpdate) { Course course = courseRepository.findByIdForUpdateElseThrow(job.courseId()); diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisTextExerciseChatSessionService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisTextExerciseChatSessionService.java index 5682580ccb89..8702db7bdf54 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisTextExerciseChatSessionService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisTextExerciseChatSessionService.java @@ -115,6 +115,7 @@ public void requestAndHandleResponse(IrisTextExerciseChatSession irisSession) { * * @param job The job that is updated * @param statusUpdate The status update + * @return The same job that was passed in */ public TextExerciseChatJob handleStatusUpdate(TextExerciseChatJob job, PyrisTextExerciseChatStatusUpdateDTO statusUpdate) { // TODO: LLM Token Tracking - or better, make this class a subclass of AbstractIrisChatSessionService