
General: Track token usage of LLM service requests #9455

Merged
merged 34 commits into from Oct 23, 2024

Commits (34)
a5dfdbb
Define LLM token usage model
alexjoham Sep 28, 2024
d0bdae3
Update table, save data received from Pyris Exercise chat pipeline
alexjoham Oct 11, 2024
2a08cb2
Implement competency generation tracking, update enum
alexjoham Oct 11, 2024
f85cf46
Add comments to LLMTokenUsageService
alexjoham Oct 11, 2024
65fb259
Fix server test failures by checking if tokens received
alexjoham Oct 12, 2024
188ff22
Update database for cost tracking and trace_id functionality
alexjoham Oct 12, 2024
be85a3b
Update database, add information to competency gen, change traceId calc
alexjoham Oct 12, 2024
e974d59
Implement server Integration tests for token tracking and saving
alexjoham Oct 13, 2024
6337162
Update code based on code-rabbit feedback, fix tests
alexjoham Oct 14, 2024
84a60dc
minor comment changes, remove tokens from frontend
alexjoham Oct 14, 2024
5b0ab48
Merge branch 'develop' into feature/track-usage-of-iris-requests
alexjoham Oct 14, 2024
62dad8b
Fix github test fails
alexjoham Oct 14, 2024
897d643
Change servicetype to type String to prevent failures
alexjoham Oct 14, 2024
1d10860
Change servicetype to type String to prevent failures
alexjoham Oct 14, 2024
8b27861
Merge remote-tracking branch 'origin/feature/track-usage-of-iris-requ…
alexjoham Oct 14, 2024
86294c1
Fix test failure by removing @SpyBean
alexjoham Oct 15, 2024
56b20e7
Update database to save only IDs, fix competency Integration Test user
alexjoham Oct 15, 2024
8a29c82
Implement builder pattern based on feedback
alexjoham Oct 16, 2024
abbd28f
Update database migration with foreign keys and on delete null
alexjoham Oct 16, 2024
8d34428
Rework database, update saveLLMTokens method
alexjoham Oct 18, 2024
52bf023
Implement new service in all Pipelines, update database, update test
alexjoham Oct 19, 2024
b8f5cca
fix server tests
krusche Oct 20, 2024
82fb76d
fix function naming
FelixTJDietrich Oct 21, 2024
6d3037a
replace ArraySet
FelixTJDietrich Oct 21, 2024
c785417
Merge branch 'develop' into feature/track-usage-of-iris-requests
FelixTJDietrich Oct 21, 2024
9f4cccd
Refactored token usage tracking and improved session-based job handling
bassner Oct 21, 2024
e437c71
Update tests to work with new changes
alexjoham Oct 21, 2024
cc127af
Athena: Add LLM token usage tracking (#9554)
FelixTJDietrich Oct 22, 2024
aded0ee
Implement feedback and fix server tests
alexjoham Oct 22, 2024
8a7bc2e
add foreign keys with onDelete=SET NULL to all ids in LLMTokenUsageTrace
alexjoham Oct 22, 2024
c7e7db6
Correct wrong Long type, update comment
alexjoham Oct 22, 2024
6a90887
Merge branch 'develop' into feature/track-usage-of-iris-requests
FelixTJDietrich Oct 23, 2024
79b1b88
Make LLM token tracking of chat suggestions multi-node compatible
bassner Oct 23, 2024
5a2e92a
Added return statements to handleStatusUpdate methods
bassner Oct 23, 2024
@@ -0,0 +1,17 @@
package de.tum.cit.aet.artemis.athena.dto;

import java.util.List;

import com.fasterxml.jackson.annotation.JsonInclude;

import de.tum.cit.aet.artemis.core.domain.LLMRequest;

/**
* DTO representing the meta information in the Athena response.
*/
@JsonInclude(JsonInclude.Include.NON_EMPTY)
public record ResponseMetaDTO(TotalUsage totalUsage, List<LLMRequest> llmRequests) {

public record TotalUsage(Integer numInputTokens, Integer numOutputTokens, Integer numTotalTokens, Float cost) {
}
}
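As context for this new DTO (a sketch for illustration, not part of the diff): `numTotalTokens` presumably equals the sum of input and output tokens, an invariant a consumer of the Athena response could sanity-check:

```java
// Sketch only: a local stand-in for ResponseMetaDTO.TotalUsage illustrating the
// presumed invariant numTotalTokens == numInputTokens + numOutputTokens.
public class Main {

    record TotalUsage(Integer numInputTokens, Integer numOutputTokens, Integer numTotalTokens, Float cost) {
    }

    public static void main(String[] args) {
        TotalUsage usage = new TotalUsage(1200, 300, 1500, 0.0045f);
        // Sanity check a caller could apply before aggregating usage numbers
        System.out.println(usage.numInputTokens() + usage.numOutputTokens() == usage.numTotalTokens());
    }
}
```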
@@ -17,10 +17,18 @@
import de.tum.cit.aet.artemis.athena.dto.ExerciseBaseDTO;
import de.tum.cit.aet.artemis.athena.dto.ModelingFeedbackDTO;
import de.tum.cit.aet.artemis.athena.dto.ProgrammingFeedbackDTO;
import de.tum.cit.aet.artemis.athena.dto.ResponseMetaDTO;
import de.tum.cit.aet.artemis.athena.dto.SubmissionBaseDTO;
import de.tum.cit.aet.artemis.athena.dto.TextFeedbackDTO;
import de.tum.cit.aet.artemis.core.domain.LLMRequest;
import de.tum.cit.aet.artemis.core.domain.LLMServiceType;
import de.tum.cit.aet.artemis.core.domain.User;
import de.tum.cit.aet.artemis.core.exception.ConflictException;
import de.tum.cit.aet.artemis.core.exception.NetworkingException;
import de.tum.cit.aet.artemis.core.service.LLMTokenUsageService;
import de.tum.cit.aet.artemis.exercise.domain.Exercise;
import de.tum.cit.aet.artemis.exercise.domain.Submission;
import de.tum.cit.aet.artemis.exercise.domain.participation.StudentParticipation;
import de.tum.cit.aet.artemis.modeling.domain.ModelingExercise;
import de.tum.cit.aet.artemis.modeling.domain.ModelingSubmission;
import de.tum.cit.aet.artemis.programming.domain.ProgrammingExercise;
@@ -48,36 +56,40 @@ public class AthenaFeedbackSuggestionsService {

private final AthenaDTOConverterService athenaDTOConverterService;

private final LLMTokenUsageService llmTokenUsageService;

/**
* Create a new AthenaFeedbackSuggestionsService to receive feedback suggestions from the Athena service.
*
* @param athenaRestTemplate REST template used for the communication with Athena
* @param athenaModuleService Athena module serviced used to determine the urls for different modules
* @param athenaDTOConverterService Service to convert exercises and submissions to DTOs
* @param llmTokenUsageService Service to store the usage of LLM tokens
*/
public AthenaFeedbackSuggestionsService(@Qualifier("athenaRestTemplate") RestTemplate athenaRestTemplate, AthenaModuleService athenaModuleService,
AthenaDTOConverterService athenaDTOConverterService, LLMTokenUsageService llmTokenUsageService) {
textAthenaConnector = new AthenaConnector<>(athenaRestTemplate, ResponseDTOText.class);
programmingAthenaConnector = new AthenaConnector<>(athenaRestTemplate, ResponseDTOProgramming.class);
modelingAthenaConnector = new AthenaConnector<>(athenaRestTemplate, ResponseDTOModeling.class);
this.athenaDTOConverterService = athenaDTOConverterService;
this.athenaModuleService = athenaModuleService;
this.llmTokenUsageService = llmTokenUsageService;
}

@JsonInclude(JsonInclude.Include.NON_EMPTY)
private record RequestDTO(ExerciseBaseDTO exercise, SubmissionBaseDTO submission, boolean isGraded) {
}

@JsonInclude(JsonInclude.Include.NON_EMPTY)
private record ResponseDTOText(List<TextFeedbackDTO> data, ResponseMetaDTO meta) {
}

@JsonInclude(JsonInclude.Include.NON_EMPTY)
private record ResponseDTOProgramming(List<ProgrammingFeedbackDTO> data, ResponseMetaDTO meta) {
}

@JsonInclude(JsonInclude.Include.NON_EMPTY)
private record ResponseDTOModeling(List<ModelingFeedbackDTO> data, ResponseMetaDTO meta) {
}

/**
@@ -100,6 +112,7 @@ public List<TextFeedbackDTO> getTextFeedbackSuggestions(TextExercise exercise, T
final RequestDTO request = new RequestDTO(athenaDTOConverterService.ofExercise(exercise), athenaDTOConverterService.ofSubmission(exercise.getId(), submission), isGraded);
ResponseDTOText response = textAthenaConnector.invokeWithRetry(athenaModuleService.getAthenaModuleUrl(exercise) + "/feedback_suggestions", request, 0);
log.info("Athena responded to '{}' feedback suggestions request: {}", isGraded ? "Graded" : "Non Graded", response.data);
storeTokenUsage(exercise, submission, response.meta, !isGraded);
return response.data.stream().toList();
}

@@ -117,6 +130,7 @@ public List<ProgrammingFeedbackDTO> getProgrammingFeedbackSuggestions(Programmin
final RequestDTO request = new RequestDTO(athenaDTOConverterService.ofExercise(exercise), athenaDTOConverterService.ofSubmission(exercise.getId(), submission), isGraded);
ResponseDTOProgramming response = programmingAthenaConnector.invokeWithRetry(athenaModuleService.getAthenaModuleUrl(exercise) + "/feedback_suggestions", request, 0);
log.info("Athena responded to '{}' feedback suggestions request: {}", isGraded ? "Graded" : "Non Graded", response.data);
storeTokenUsage(exercise, submission, response.meta, !isGraded);
return response.data.stream().toList();
}

@@ -139,6 +153,36 @@ public List<ModelingFeedbackDTO> getModelingFeedbackSuggestions(ModelingExercise
final RequestDTO request = new RequestDTO(athenaDTOConverterService.ofExercise(exercise), athenaDTOConverterService.ofSubmission(exercise.getId(), submission), isGraded);
ResponseDTOModeling response = modelingAthenaConnector.invokeWithRetry(athenaModuleService.getAthenaModuleUrl(exercise) + "/feedback_suggestions", request, 0);
log.info("Athena responded to '{}' feedback suggestions request: {}", isGraded ? "Graded" : "Non Graded", response.data);
storeTokenUsage(exercise, submission, response.meta, !isGraded);
return response.data;
}

/**
* Store the usage of LLM tokens for a given submission
*
* @param exercise the exercise the submission belongs to
* @param submission the submission for which the tokens were used
* @param meta the meta information of the response from Athena
* @param isPreliminaryFeedback whether the feedback is preliminary or not
*/
private void storeTokenUsage(Exercise exercise, Submission submission, ResponseMetaDTO meta, Boolean isPreliminaryFeedback) {
if (meta == null) {
return;
}
Long courseId = exercise.getCourseViaExerciseGroupOrCourseMember().getId();
Long userId;
if (submission.getParticipation() instanceof StudentParticipation studentParticipation) {
userId = studentParticipation.getStudent().map(User::getId).orElse(null);
}
else {
userId = null;
}
List<LLMRequest> llmRequests = meta.llmRequests();
if (llmRequests == null) {
return;
}

llmTokenUsageService.saveLLMTokenUsage(llmRequests, LLMServiceType.ATHENA,
(llmTokenUsageBuilder -> llmTokenUsageBuilder.withCourse(courseId).withExercise(exercise.getId()).withUser(userId)));
}
}
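The `saveLLMTokenUsage(..., builder -> ...)` call above relies on the builder introduced earlier in this PR ("Implement builder pattern based on feedback"). A minimal stand-in sketch of that pattern follows — the `withCourse`/`withExercise`/`withUser` names mirror the diff, but everything else is an assumption; the real service also attaches the `LLMRequest` list and persists the resulting trace:

```java
import java.util.function.Function;

// Minimal stand-in for the builder-based saveLLMTokenUsage API used above.
// Names mirror the diff; persistence and the LLMRequest list are omitted.
public class Main {

    static class Trace {
        Long courseId;
        Long exerciseId;
        Long userId;
    }

    static class Builder {
        private final Trace trace = new Trace();

        Builder withCourse(Long id) { trace.courseId = id; return this; }
        Builder withExercise(Long id) { trace.exerciseId = id; return this; }
        Builder withUser(Long id) { trace.userId = id; return this; }

        Trace build() { return trace; }
    }

    static Trace saveLLMTokenUsage(Function<Builder, Builder> customizer) {
        // Callers customize the trace via the builder callback, as in the diff:
        // builder -> builder.withCourse(courseId).withExercise(exerciseId).withUser(userId)
        return customizer.apply(new Builder()).build();
    }

    public static void main(String[] args) {
        Trace trace = saveLLMTokenUsage(builder -> builder.withCourse(42L).withExercise(7L).withUser(null));
        System.out.println(trace.courseId + " " + trace.exerciseId + " " + trace.userId);
    }
}
```

The callback keeps the service's signature stable while letting each pipeline attach only the context IDs it actually has, which is why `withUser(null)` is valid for non-student participations.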
@@ -0,0 +1,14 @@
package de.tum.cit.aet.artemis.core.domain;

/**
* This record is used by the LLMTokenUsageService to provide relevant information about LLM token usage
*
* @param model LLM model (e.g. gpt-4o)
* @param numInputTokens number of input tokens sent to the LLM in the call
* @param costPerMillionInputToken cost in Euro per million input tokens
* @param numOutputTokens number of output tokens in the LLM answer
* @param costPerMillionOutputToken cost in Euro per million output tokens
* @param pipelineId String with the pipeline name (e.g. IRIS_COURSE_CHAT_PIPELINE)
*/
public record LLMRequest(String model, int numInputTokens, float costPerMillionInputToken, int numOutputTokens, float costPerMillionOutputToken, String pipelineId) {
}
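From the per-million-token prices carried by this record, the cost of a single request can be derived. The helper below is hypothetical (the PR does not define it), and the record is re-declared locally so the sketch compiles on its own:

```java
// Hypothetical helper (not part of the PR): deriving the cost of one request
// from its per-million-token prices.
public class Main {

    record LLMRequest(String model, int numInputTokens, float costPerMillionInputToken,
            int numOutputTokens, float costPerMillionOutputToken, String pipelineId) {
    }

    static float costInEuro(LLMRequest request) {
        return request.numInputTokens() / 1_000_000f * request.costPerMillionInputToken()
                + request.numOutputTokens() / 1_000_000f * request.costPerMillionOutputToken();
    }

    public static void main(String[] args) {
        // Prices are made-up example values, not real model pricing
        LLMRequest request = new LLMRequest("gpt-4o", 1_000_000, 2.5f, 500_000, 10.0f, "IRIS_COURSE_CHAT_PIPELINE");
        System.out.println(costInEuro(request));
    }
}
```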
@@ -0,0 +1,8 @@
package de.tum.cit.aet.artemis.core.domain;

/**
* Enum representing different types of LLM (Large Language Model) services used in the system.
*/
public enum LLMServiceType {
IRIS, ATHENA
}
@@ -0,0 +1,104 @@
package de.tum.cit.aet.artemis.core.domain;

import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.ManyToOne;
import jakarta.persistence.Table;

import org.hibernate.annotations.Cache;
import org.hibernate.annotations.CacheConcurrencyStrategy;

import com.fasterxml.jackson.annotation.JsonInclude;

/**
* Represents the token usage details of a single LLM request, including model, service pipeline, token counts, and costs.
*/
@Entity
@Table(name = "llm_token_usage_request")
@Cache(usage = CacheConcurrencyStrategy.NONSTRICT_READ_WRITE)
@JsonInclude(JsonInclude.Include.NON_EMPTY)
public class LLMTokenUsageRequest extends DomainObject {

/**
* LLM model (e.g. gpt-4o)
*/
@Column(name = "model")
private String model;

/**
* pipeline that was called (e.g. IRIS_COURSE_CHAT_PIPELINE)
*/
@Column(name = "service_pipeline_id")
private String servicePipelineId;

@Column(name = "num_input_tokens")
private int numInputTokens;

@Column(name = "cost_per_million_input_tokens")
private float costPerMillionInputTokens;

@Column(name = "num_output_tokens")
private int numOutputTokens;

@Column(name = "cost_per_million_output_tokens")
private float costPerMillionOutputTokens;

@ManyToOne
private LLMTokenUsageTrace trace;

public String getModel() {
return model;
}

public void setModel(String model) {
this.model = model;
}

public String getServicePipelineId() {
return servicePipelineId;
}

public void setServicePipelineId(String servicePipelineId) {
this.servicePipelineId = servicePipelineId;
}

public float getCostPerMillionInputTokens() {
return costPerMillionInputTokens;
}

public void setCostPerMillionInputTokens(float costPerMillionInputToken) {
this.costPerMillionInputTokens = costPerMillionInputToken;
}

public float getCostPerMillionOutputTokens() {
return costPerMillionOutputTokens;
}

public void setCostPerMillionOutputTokens(float costPerMillionOutputToken) {
this.costPerMillionOutputTokens = costPerMillionOutputToken;
}

public int getNumInputTokens() {
return numInputTokens;
}

public void setNumInputTokens(int numInputTokens) {
this.numInputTokens = numInputTokens;
}

public int getNumOutputTokens() {
return numOutputTokens;
}

public void setNumOutputTokens(int numOutputTokens) {
this.numOutputTokens = numOutputTokens;
}

public LLMTokenUsageTrace getTrace() {
return trace;
}

public void setTrace(LLMTokenUsageTrace trace) {
this.trace = trace;
}
}
@@ -0,0 +1,111 @@
package de.tum.cit.aet.artemis.core.domain;

import java.time.ZonedDateTime;
import java.util.HashSet;
import java.util.Set;

import jakarta.annotation.Nullable;
import jakarta.persistence.CascadeType;
import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.EnumType;
import jakarta.persistence.Enumerated;
import jakarta.persistence.FetchType;
import jakarta.persistence.OneToMany;
import jakarta.persistence.Table;

import org.hibernate.annotations.Cache;
import org.hibernate.annotations.CacheConcurrencyStrategy;

import com.fasterxml.jackson.annotation.JsonInclude;

/**
* This represents a trace that contains one or more requests of type {@link LLMTokenUsageRequest}
*/
@Entity
@Table(name = "llm_token_usage_trace")
@Cache(usage = CacheConcurrencyStrategy.NONSTRICT_READ_WRITE)
@JsonInclude(JsonInclude.Include.NON_EMPTY)
public class LLMTokenUsageTrace extends DomainObject {

@Column(name = "service")
@Enumerated(EnumType.STRING)
private LLMServiceType serviceType;

@Nullable
@Column(name = "course_id")
private Long courseId;

@Nullable
@Column(name = "exercise_id")
private Long exerciseId;

@Column(name = "user_id")
private Long userId;

@Column(name = "time")
private ZonedDateTime time = ZonedDateTime.now();

@Nullable
@Column(name = "iris_message_id")
private Long irisMessageId;

@OneToMany(mappedBy = "trace", fetch = FetchType.LAZY, cascade = CascadeType.ALL, orphanRemoval = true)
private Set<LLMTokenUsageRequest> llmRequests = new HashSet<>();

public LLMServiceType getServiceType() {
return serviceType;
}

public void setServiceType(LLMServiceType serviceType) {
this.serviceType = serviceType;
}

public Long getCourseId() {
return courseId;
}

public void setCourseId(Long courseId) {
this.courseId = courseId;
}

public Long getExerciseId() {
return exerciseId;
}

public void setExerciseId(Long exerciseId) {
this.exerciseId = exerciseId;
}

public Long getUserId() {
return userId;
}

public void setUserId(Long userId) {
this.userId = userId;
}

public ZonedDateTime getTime() {
return time;
}

public void setTime(ZonedDateTime time) {
this.time = time;
}

public Set<LLMTokenUsageRequest> getLLMRequests() {
return llmRequests;
}

public void setLlmRequests(Set<LLMTokenUsageRequest> llmRequests) {
this.llmRequests = llmRequests;
}

public Long getIrisMessageId() {
return irisMessageId;
}

public void setIrisMessageId(Long messageId) {
this.irisMessageId = messageId;
}
}
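Since `LLMTokenUsageRequest.trace` is the owning `@ManyToOne` side while this entity holds the `mappedBy` `@OneToMany` set, callers must keep both sides of the association in sync before persisting. A plain-Java sketch of that wiring — JPA annotations omitted, and `addRequest` is a hypothetical helper that the diff itself does not define:

```java
import java.util.HashSet;
import java.util.Set;

// Plain-Java sketch of keeping both sides of the trace <-> request association
// in sync; persistence details are omitted.
public class Main {

    static class Request {
        Trace trace; // stands in for the owning @ManyToOne side
    }

    static class Trace {
        Set<Request> llmRequests = new HashSet<>(); // stands in for the mappedBy @OneToMany side

        void addRequest(Request request) {
            llmRequests.add(request);
            request.trace = this; // set the owning side, or the foreign key is never written
        }
    }

    public static void main(String[] args) {
        Trace trace = new Trace();
        Request request = new Request();
        trace.addRequest(request);
        System.out.println(trace.llmRequests.size() + " " + (request.trace == trace));
    }
}
```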
@@ -0,0 +1,14 @@
package de.tum.cit.aet.artemis.core.repository;

import static de.tum.cit.aet.artemis.core.config.Constants.PROFILE_CORE;

import org.springframework.context.annotation.Profile;
import org.springframework.stereotype.Repository;

import de.tum.cit.aet.artemis.core.domain.LLMTokenUsageRequest;
import de.tum.cit.aet.artemis.core.repository.base.ArtemisJpaRepository;

@Profile(PROFILE_CORE)
@Repository
public interface LLMTokenUsageRequestRepository extends ArtemisJpaRepository<LLMTokenUsageRequest, Long> {
}