Skip to content

Commit

Permalink
Make LLM token tracking of chat suggestions multi-node compatible
Browse files Browse the repository at this point in the history
  • Loading branch information
bassner committed Oct 23, 2024
1 parent 6a90887 commit 8671e35
Show file tree
Hide file tree
Showing 10 changed files with 92 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,16 @@ public void appendRequestsToTrace(List<LLMRequest> requests, LLMTokenUsageTrace
llmTokenUsageRequestRepository.saveAll(requestSet);
}

/**
* Finds an LLMTokenUsageTrace by its ID.
* <p>
* This is a direct lookup through the trace repository's {@code findById}; no further filtering is applied here.
*
* @param id The ID of the LLMTokenUsageTrace to find.
* @return An Optional containing the LLMTokenUsageTrace if found, or an empty Optional otherwise.
*/
public Optional<LLMTokenUsageTrace> findLLMTokenUsageTraceById(Long id) {
return llmTokenUsageTraceRepository.findById(id);
}

/**
* Class LLMTokenUsageBuilder to be used for saveLLMTokenUsage()
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,16 @@ public void executeCompetencyExtractionPipeline(User user, Course course, String
* @param job Job related to the status update
* @param statusUpdate the status update containing the new competency recommendations
*/
public void handleStatusUpdate(CompetencyExtractionJob job, PyrisCompetencyStatusUpdateDTO statusUpdate) {
public CompetencyExtractionJob handleStatusUpdate(CompetencyExtractionJob job, PyrisCompetencyStatusUpdateDTO statusUpdate) {
Course course = courseRepository.findByIdForUpdateElseThrow(job.courseId());
if (statusUpdate.tokens() != null && !statusUpdate.tokens().isEmpty()) {
llmTokenUsageService.saveLLMTokenUsage(statusUpdate.tokens(), LLMServiceType.IRIS, builder -> builder.withCourse(course.getId()).withUser(job.userId()));
}

var user = userRepository.findById(job.userId()).orElseThrow();
websocketService.send(user.getLogin(), websocketTopic(job.courseId()), statusUpdate);

return job;
}

private static String websocketTopic(long courseId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,14 @@ public String createTokenForJob(Function<String, PyrisJob> tokenToJobFunction) {

public String addExerciseChatJob(Long courseId, Long exerciseId, Long sessionId) {
var token = generateJobIdToken();
var job = new ExerciseChatJob(token, courseId, exerciseId, sessionId);
var job = new ExerciseChatJob(token, courseId, exerciseId, sessionId, null);
jobMap.put(token, job);
return token;
}

public String addCourseChatJob(Long courseId, Long sessionId) {
var token = generateJobIdToken();
var job = new CourseChatJob(token, courseId, sessionId);
var job = new CourseChatJob(token, courseId, sessionId, null);
jobMap.put(token, job);
return token;
}
Expand All @@ -107,10 +107,19 @@ public String addIngestionWebhookJob() {
/**
* Remove a job from the job map.
* The entry to remove is identified by the job's own id ({@code job.jobId()}).
*
* @param job the job to remove
*/
public void removeJob(PyrisJob job) {
jobMap.remove(job.jobId());
}

/**
* Store a job in the job map.
*
* @param job the job to store
*/
public void removeJob(String token) {
jobMap.remove(token);
public void updateJob(PyrisJob job) {
// The job is stored under its own id; presumably this replaces any entry previously stored under that id (verify against the jobMap type).
jobMap.put(job.jobId(), job);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
import de.tum.cit.aet.artemis.iris.service.pyris.job.CourseChatJob;
import de.tum.cit.aet.artemis.iris.service.pyris.job.ExerciseChatJob;
import de.tum.cit.aet.artemis.iris.service.pyris.job.IngestionWebhookJob;
import de.tum.cit.aet.artemis.iris.service.pyris.job.PyrisJob;
import de.tum.cit.aet.artemis.iris.service.pyris.job.TextExerciseChatJob;
import de.tum.cit.aet.artemis.iris.service.pyris.job.TrackedSessionBasedPyrisJob;
import de.tum.cit.aet.artemis.iris.service.session.IrisCourseChatSessionService;
import de.tum.cit.aet.artemis.iris.service.session.IrisExerciseChatSessionService;
import de.tum.cit.aet.artemis.iris.service.session.IrisTextExerciseChatSessionService;
Expand Down Expand Up @@ -52,15 +54,16 @@ public PyrisStatusUpdateService(PyrisJobService pyrisJobService, IrisExerciseCha
}

/**
* Handles the status update of an exercise chat job and forwards it to {@link IrisExerciseChatSessionService#handleStatusUpdate(ExerciseChatJob, PyrisChatStatusUpdateDTO)}
* Handles the status update of an exercise chat job and forwards it to
* {@link IrisExerciseChatSessionService#handleStatusUpdate(TrackedSessionBasedPyrisJob, PyrisChatStatusUpdateDTO)}
*
* @param job the job that is updated
* @param statusUpdate the status update
*/
public void handleStatusUpdate(ExerciseChatJob job, PyrisChatStatusUpdateDTO statusUpdate) {
irisExerciseChatSessionService.handleStatusUpdate(job, statusUpdate);
var updatedJob = irisExerciseChatSessionService.handleStatusUpdate(job, statusUpdate);

removeJobIfTerminated(statusUpdate.stages(), job.jobId());
removeJobIfTerminatedElseUpdate(statusUpdate.stages(), updatedJob);
}

/**
Expand All @@ -71,22 +74,22 @@ public void handleStatusUpdate(ExerciseChatJob job, PyrisChatStatusUpdateDTO sta
* @param statusUpdate the status update
*/
public void handleStatusUpdate(TextExerciseChatJob job, PyrisTextExerciseChatStatusUpdateDTO statusUpdate) {
irisTextExerciseChatSessionService.handleStatusUpdate(job, statusUpdate);
var updatedJob = irisTextExerciseChatSessionService.handleStatusUpdate(job, statusUpdate);

removeJobIfTerminated(statusUpdate.stages(), job.jobId());
removeJobIfTerminatedElseUpdate(statusUpdate.stages(), updatedJob);
}

/**
* Handles the status update of a course chat job and forwards it to
* {@link de.tum.cit.aet.artemis.iris.service.session.IrisCourseChatSessionService#handleStatusUpdate(CourseChatJob, PyrisChatStatusUpdateDTO)}
* {@link de.tum.cit.aet.artemis.iris.service.session.IrisCourseChatSessionService#handleStatusUpdate(TrackedSessionBasedPyrisJob, PyrisChatStatusUpdateDTO)}
*
* @param job the job that is updated
* @param statusUpdate the status update
*/
public void handleStatusUpdate(CourseChatJob job, PyrisChatStatusUpdateDTO statusUpdate) {
courseChatSessionService.handleStatusUpdate(job, statusUpdate);
var updatedJob = courseChatSessionService.handleStatusUpdate(job, statusUpdate);

removeJobIfTerminated(statusUpdate.stages(), job.jobId());
removeJobIfTerminatedElseUpdate(statusUpdate.stages(), updatedJob);
}

/**
Expand All @@ -97,26 +100,29 @@ public void handleStatusUpdate(CourseChatJob job, PyrisChatStatusUpdateDTO statu
* @param statusUpdate the status update
*/
public void handleStatusUpdate(CompetencyExtractionJob job, PyrisCompetencyStatusUpdateDTO statusUpdate) {
competencyGenerationService.handleStatusUpdate(job, statusUpdate);
var updatedJob = competencyGenerationService.handleStatusUpdate(job, statusUpdate);

removeJobIfTerminated(statusUpdate.stages(), job.jobId());
removeJobIfTerminatedElseUpdate(statusUpdate.stages(), updatedJob);
}

/**
* Removes the job from the job service if the status update indicates that the job is terminated.
* This is the case if all stages are in a terminal state.
* Removes the job from the job service if the status update indicates that the job is terminated; updates it to distribute changes otherwise.
* A job is terminated if all stages are in a terminal state.
* <p>
*
* @see PyrisStageState#isTerminal()
*
* @param stages the stages of the status update
* @param job the job to remove
* @param job the job to remove or to update
*/
private void removeJobIfTerminated(List<PyrisStageDTO> stages, String job) {
private void removeJobIfTerminatedElseUpdate(List<PyrisStageDTO> stages, PyrisJob job) {
    // The job counts as finished only once every reported stage is in a terminal state.
    boolean allStagesTerminal = stages.stream().allMatch(stage -> stage.state().isTerminal());
    if (!allStagesTerminal) {
        // Still running: hand the (possibly changed) job back to the job service.
        pyrisJobService.updateJob(job);
        return;
    }
    pyrisJobService.removeJob(job);
}

/**
Expand All @@ -128,6 +134,6 @@ private void removeJobIfTerminated(List<PyrisStageDTO> stages, String job) {
*/
public void handleStatusUpdate(IngestionWebhookJob job, PyrisLectureIngestionStatusUpdateDTO statusUpdate) {
statusUpdate.stages().forEach(stage -> log.info(stage.name() + ":" + stage.message()));
removeJobIfTerminated(statusUpdate.stages(), job.jobId());
removeJobIfTerminatedElseUpdate(statusUpdate.stages(), job);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,15 @@
* This job is used to reference the details of a course chat session when Pyris sends a status update.
*/
@JsonInclude(JsonInclude.Include.NON_EMPTY)
public record CourseChatJob(String jobId, long courseId, long sessionId) implements SessionBasedPyrisJob {
public record CourseChatJob(String jobId, long courseId, long sessionId, Long traceId) implements TrackedSessionBasedPyrisJob {

// Note: traceId is the only boxed component, so it may be null, while the other ids are primitives.

/**
* Grants access only for the course this job refers to.
*
* @param course the course to check
* @return true if the id of the given course equals this job's courseId
*/
@Override
public boolean canAccess(Course course) {
return courseId == course.getId();
}

/**
* Creates a copy of this job that carries the given trace ID; this record itself is not modified.
*
* @param traceId the trace ID to set on the copy
* @return a new CourseChatJob with the same jobId, courseId and sessionId and the given traceId
*/
@Override
public TrackedSessionBasedPyrisJob withTraceId(long traceId) {
return new CourseChatJob(jobId, courseId, sessionId, traceId);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
* This job is used to reference the details of an exercise chat session when Pyris sends a status update.
*/
@JsonInclude(JsonInclude.Include.NON_EMPTY)
public record ExerciseChatJob(String jobId, long courseId, long exerciseId, long sessionId) implements SessionBasedPyrisJob {
public record ExerciseChatJob(String jobId, long courseId, long exerciseId, long sessionId, Long traceId) implements TrackedSessionBasedPyrisJob {

@Override
public boolean canAccess(Course course) {
Expand All @@ -21,4 +21,9 @@ public boolean canAccess(Course course) {
public boolean canAccess(Exercise exercise) {
return exercise.getId().equals(exerciseId);
}

@Override
public TrackedSessionBasedPyrisJob withTraceId(long traceId) {
return new ExerciseChatJob(jobId, courseId, exerciseId, sessionId, traceId);
}
}

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package de.tum.cit.aet.artemis.iris.service.pyris.job;

/**
* A Pyris job that has a session id and stores its own LLM usage tracing ID.
* This is used for chat jobs where we need to reference the trace ID later after chat suggestions have been generated.
*/
public interface TrackedSessionBasedPyrisJob extends PyrisJob {

/**
* @return the session id of this job
*/
long sessionId();

/**
* @return the LLM usage tracing ID stored in this job; declared as a boxed {@code Long}, so it may be {@code null} (presumably while no trace has been recorded yet)
*/
Long traceId();

/**
* Returns a job that carries the given trace ID.
* NOTE(review): implementations are expected to return a new instance rather than mutate this one; confirm against the implementing records.
*
* @param traceId the trace ID to attach
* @return a job with the given trace ID
*/
TrackedSessionBasedPyrisJob withTraceId(long traceId);
}
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
package de.tum.cit.aet.artemis.iris.service.session;

import java.util.HashMap;
import java.util.List;
import java.util.Optional;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;

import de.tum.cit.aet.artemis.core.domain.LLMServiceType;
import de.tum.cit.aet.artemis.core.domain.LLMTokenUsageTrace;
import de.tum.cit.aet.artemis.core.service.LLMTokenUsageService;
import de.tum.cit.aet.artemis.iris.domain.message.IrisMessage;
import de.tum.cit.aet.artemis.iris.domain.message.IrisMessageSender;
Expand All @@ -16,7 +15,7 @@
import de.tum.cit.aet.artemis.iris.repository.IrisSessionRepository;
import de.tum.cit.aet.artemis.iris.service.IrisMessageService;
import de.tum.cit.aet.artemis.iris.service.pyris.dto.chat.PyrisChatStatusUpdateDTO;
import de.tum.cit.aet.artemis.iris.service.pyris.job.SessionBasedPyrisJob;
import de.tum.cit.aet.artemis.iris.service.pyris.job.TrackedSessionBasedPyrisJob;
import de.tum.cit.aet.artemis.iris.service.websocket.IrisChatWebsocketService;

public abstract class AbstractIrisChatSessionService<S extends IrisChatSession> implements IrisChatBasedFeatureInterface<S>, IrisRateLimitedFeatureInterface {
Expand All @@ -31,8 +30,6 @@ public abstract class AbstractIrisChatSessionService<S extends IrisChatSession>

private final ObjectMapper objectMapper;

protected final HashMap<String, LLMTokenUsageTrace> traces = new HashMap<>();

public AbstractIrisChatSessionService(IrisSessionRepository irisSessionRepository, ObjectMapper objectMapper, IrisMessageService irisMessageService,
IrisChatWebsocketService irisChatWebsocketService, LLMTokenUsageService llmTokenUsageService) {
this.irisSessionRepository = irisSessionRepository;
Expand Down Expand Up @@ -69,8 +66,9 @@ protected void updateLatestSuggestions(S session, List<String> latestSuggestions
*
* @param job The job that was executed
* @param statusUpdate The status update of the job
* @return the same job record or a new job record with the same job id if changes were made
*/
public void handleStatusUpdate(SessionBasedPyrisJob job, PyrisChatStatusUpdateDTO statusUpdate) {
public TrackedSessionBasedPyrisJob handleStatusUpdate(TrackedSessionBasedPyrisJob job, PyrisChatStatusUpdateDTO statusUpdate) {
var session = (S) irisSessionRepository.findByIdWithMessagesAndContents(job.sessionId());
IrisMessage savedMessage;
if (statusUpdate.result() != null) {
Expand All @@ -84,6 +82,7 @@ public void handleStatusUpdate(SessionBasedPyrisJob job, PyrisChatStatusUpdateDT
irisChatWebsocketService.sendStatusUpdate(session, statusUpdate.stages(), statusUpdate.suggestions(), statusUpdate.tokens());
}

TrackedSessionBasedPyrisJob updatedJob = job;
if (statusUpdate.tokens() != null && !statusUpdate.tokens().isEmpty()) {
if (savedMessage != null) {
// generated message is first sent and generated trace is saved
Expand All @@ -92,27 +91,27 @@ public void handleStatusUpdate(SessionBasedPyrisJob job, PyrisChatStatusUpdateDT
this.setLLMTokenUsageParameters(builder, session);
return builder;
});
traces.put(job.jobId(), llmTokenUsageTrace);

updatedJob = job.withTraceId(llmTokenUsageTrace.getId());
}
else {
// interaction suggestion is sent and appended to the generated trace if it exists, trace is then removed,
// because interaction suggestion is the last message from Iris in the pipeline
if (traces.containsKey(job.jobId())) {
var trace = traces.get(job.jobId());
// interaction suggestion is sent and appended to the generated trace if it exists
Optional.ofNullable(job.traceId()).flatMap(llmTokenUsageService::findLLMTokenUsageTraceById).ifPresentOrElse(trace -> {
llmTokenUsageService.appendRequestsToTrace(statusUpdate.tokens(), trace);
traces.remove(job.jobId());
}
else {
irisSessionRepository.save(session);
}, () -> {
llmTokenUsageService.saveLLMTokenUsage(statusUpdate.tokens(), LLMServiceType.IRIS, builder -> {
builder.withUser(session.getUser().getId());
this.setLLMTokenUsageParameters(builder, session);
return builder;
});
}
});
}
}

updateLatestSuggestions(session, statusUpdate.suggestions());

return updatedJob;
}

protected abstract void setLLMTokenUsageParameters(LLMTokenUsageService.LLMTokenUsageBuilder builder, S session);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ public void requestAndHandleResponse(IrisTextExerciseChatSession irisSession) {
* @param job The job that is updated
* @param statusUpdate The status update
*/
public void handleStatusUpdate(TextExerciseChatJob job, PyrisTextExerciseChatStatusUpdateDTO statusUpdate) {
public TextExerciseChatJob handleStatusUpdate(TextExerciseChatJob job, PyrisTextExerciseChatStatusUpdateDTO statusUpdate) {
// TODO: LLM Token Tracking - or better, make this class a subclass of AbstractIrisChatSessionService
var session = (IrisTextExerciseChatSession) irisSessionRepository.findByIdElseThrow(job.sessionId());
if (statusUpdate.result() != null) {
var message = session.newMessage();
Expand All @@ -127,6 +128,8 @@ public void handleStatusUpdate(TextExerciseChatJob job, PyrisTextExerciseChatSta
else {
irisChatWebsocketService.sendMessage(session, null, statusUpdate.stages());
}

return job;
}

@Override
Expand Down

0 comments on commit 8671e35

Please sign in to comment.