diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-cassandra/src/test/java/org/springframework/ai/model/chat/memory/cassandra/autoconfigure/CassandraChatMemoryAutoConfigurationIT.java b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-cassandra/src/test/java/org/springframework/ai/model/chat/memory/cassandra/autoconfigure/CassandraChatMemoryAutoConfigurationIT.java index 76bac2e38a7..cdc5fea23c9 100644 --- a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-cassandra/src/test/java/org/springframework/ai/model/chat/memory/cassandra/autoconfigure/CassandraChatMemoryAutoConfigurationIT.java +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-cassandra/src/test/java/org/springframework/ai/model/chat/memory/cassandra/autoconfigure/CassandraChatMemoryAutoConfigurationIT.java @@ -21,7 +21,7 @@ import com.datastax.driver.core.utils.UUIDs; import org.junit.jupiter.api.Test; -import org.testcontainers.containers.CassandraContainer; +import org.testcontainers.cassandra.CassandraContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; import org.testcontainers.utility.DockerImageName; diff --git a/memory/spring-ai-model-chat-memory-cassandra/src/main/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemory.java b/memory/spring-ai-model-chat-memory-cassandra/src/main/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemory.java index 035aed0ca5e..7fcee99bc45 100644 --- a/memory/spring-ai-model-chat-memory-cassandra/src/main/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemory.java +++ b/memory/spring-ai-model-chat-memory-cassandra/src/main/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemory.java @@ -18,6 +18,7 @@ import java.time.Instant; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.concurrent.atomic.AtomicLong; @@ -158,6 +159,7 @@ public List get(String sessionId, int lastN) { messages.add(new UserMessage(user)); } } + Collections.reverse(messages); return messages; } diff --git a/memory/spring-ai-model-chat-memory-cassandra/src/test/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryIT.java b/memory/spring-ai-model-chat-memory-cassandra/src/test/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryIT.java index c4cf7a8eed4..8fbf0059cc7 100644 --- a/memory/spring-ai-model-chat-memory-cassandra/src/test/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryIT.java +++ b/memory/spring-ai-model-chat-memory-cassandra/src/test/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryIT.java @@ -17,21 +17,33 @@ package org.springframework.ai.chat.memory.cassandra; import java.time.Duration; +import java.util.List; +import java.util.UUID; import com.datastax.oss.driver.api.core.CqlSession; import com.datastax.oss.driver.api.core.CqlSessionBuilder; +import com.datastax.oss.driver.api.core.cql.ResultSet; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; -import org.testcontainers.containers.CassandraContainer; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.testcontainers.cassandra.CassandraContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; +import org.springframework.ai.chat.memory.ChatMemory; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.MessageType; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.autoconfigure.EnableAutoConfiguration; import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; +import static org.assertj.core.api.Assertions.assertThat; + /** * Use `mvn failsafe:integration-test -Dit.test=CassandraChatMemoryIT` * @@ -43,7 +55,7 @@ class CassandraChatMemoryIT { @Container - static CassandraContainer cassandraContainer = new CassandraContainer<>(CassandraImage.DEFAULT_IMAGE); + static CassandraContainer cassandraContainer = new CassandraContainer(CassandraImage.DEFAULT_IMAGE); private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withUserConfiguration(CassandraChatMemoryIT.TestApplication.class); @@ -57,6 +69,163 @@ void ensureBeanGetsCreated() { }); } + @ParameterizedTest + @CsvSource({ "Message from assistant,ASSISTANT", "Message from user,USER" }) + void add_shouldInsertSingleMessage(String content, MessageType messageType) { + this.contextRunner.run(context -> { + var chatMemory = context.getBean(ChatMemory.class); + var sessionId = UUID.randomUUID().toString(); + var message = switch (messageType) { + case ASSISTANT -> new AssistantMessage(content); + case USER -> new UserMessage(content); + default -> throw new IllegalArgumentException("Type not supported: " + messageType); + }; + + chatMemory.add(sessionId, message); + + var cqlSession = context.getBean(CqlSession.class); + var query = """ + SELECT session_id, message_timestamp, a, u + FROM test_springframework.ai_chat_memory + WHERE session_id = ? + """; + ResultSet resultSet = cqlSession.execute(query, sessionId); + var rows = resultSet.all(); + + assertThat(rows.size()).isEqualTo(1); + + var firstRow = rows.get(0); + + assertThat(firstRow.getString("session_id")).isEqualTo(sessionId); + assertThat(firstRow.getInstant("message_timestamp")).isNotNull(); + if (messageType == MessageType.ASSISTANT) { + assertThat(firstRow.getString("a")).isEqualTo(content); + assertThat(firstRow.getString("u")).isNull(); + } + else if (messageType == MessageType.USER) { + assertThat(firstRow.getString("a")).isNull(); + assertThat(firstRow.getString("u")).isEqualTo(content); + } + }); + } + + @Test + void add_shouldInsertMessages() { + this.contextRunner.run(context -> { + var chatMemory = context.getBean(ChatMemory.class); + var sessionId = UUID.randomUUID().toString(); + var messages = List.of(new AssistantMessage("Message from assistant"), + new UserMessage("Message from user")); + + chatMemory.add(sessionId, messages); + + var cqlSession = context.getBean(CqlSession.class); + var query = """ + SELECT session_id, message_timestamp, a, u + FROM test_springframework.ai_chat_memory + WHERE session_id = ? + ORDER BY message_timestamp ASC + """; + ResultSet resultSet = cqlSession.execute(query, sessionId); + var rows = resultSet.all(); + + assertThat(rows.size()).isEqualTo(messages.size()); + + for (var i = 0; i < messages.size(); i++) { + var message = messages.get(i); + var result = rows.get(i); + + assertThat(result.getString("session_id")).isNotNull(); + assertThat(result.getString("session_id")).isEqualTo(sessionId); + if (message.getMessageType() == MessageType.ASSISTANT) { + assertThat(result.getString("a")).isEqualTo(message.getText()); + assertThat(result.getString("u")).isNull(); + } + else if (message.getMessageType() == MessageType.USER) { + assertThat(result.getString("a")).isNull(); + assertThat(result.getString("u")).isEqualTo(message.getText()); + } + assertThat(result.getInstant("message_timestamp")).isNotNull(); + } + }); + } + + @Test + void get_shouldReturnMessages() { + this.contextRunner.run(context -> { + var chatMemory = context.getBean(ChatMemory.class); + var sessionId = UUID.randomUUID().toString(); + var messages = List.of(new AssistantMessage("Message from assistant 1 - " + sessionId), + new AssistantMessage("Message from assistant 2 - " + sessionId), + new UserMessage("Message from user - " + sessionId)); + + chatMemory.add(sessionId, messages); + + var results = chatMemory.get(sessionId, Integer.MAX_VALUE); + + assertThat(results.size()).isEqualTo(messages.size()); + + for (var i = 0; i < messages.size(); i++) { + var message = messages.get(i); + var result = results.get(i); + + assertThat(result.getMessageType()).isEqualTo(message.getMessageType()); + assertThat(result.getText()).isEqualTo(message.getText()); + } + }); + } + + @Test + void get_afterMultipleAdds_shouldReturnMessagesInSameOrder() { + this.contextRunner.run(context -> { + var chatMemory = context.getBean(ChatMemory.class); + var sessionId = UUID.randomUUID().toString(); + var userMessage = new UserMessage("Message from user - " + sessionId); + var assistantMessage = new AssistantMessage("Message from assistant - " + sessionId); + + chatMemory.add(sessionId, userMessage); + chatMemory.add(sessionId, assistantMessage); + + var results = chatMemory.get(sessionId, Integer.MAX_VALUE); + + assertThat(results.size()).isEqualTo(2); + + var messages = List.of(userMessage, assistantMessage); + for (var i = 0; i < messages.size(); i++) { + var message = messages.get(i); + var result = results.get(i); + + assertThat(result.getMessageType()).isEqualTo(message.getMessageType()); + assertThat(result.getText()).isEqualTo(message.getText()); + } + }); + } + + @Test + void clear_shouldDeleteMessages() { + this.contextRunner.run(context -> { + var chatMemory = context.getBean(ChatMemory.class); + var sessionId = UUID.randomUUID().toString(); + var messages = List.of(new AssistantMessage("Message from assistant - " + sessionId), + new UserMessage("Message from user - " + sessionId)); + + chatMemory.add(sessionId, messages); + + chatMemory.clear(sessionId); + + var cqlSession = context.getBean(CqlSession.class); + var query = """ + SELECT COUNT(*) + FROM test_springframework.ai_chat_memory + WHERE session_id = ? + """; + ResultSet resultSet = cqlSession.execute(query, sessionId); + var count = resultSet.all().get(0).getLong(0); + + assertThat(count).isZero(); + }); + } + @SpringBootConfiguration @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class }) public static class TestApplication {